Source code for svcco.branch_addition.add_branches_v3

import numpy as np
import os #remove before final
import time #not neccessary?
from .check import *
from .close import *
from .local_func_v7 import *
from ..collision.sphere_proximity import *
from ..collision.collision import *
from .add_bifurcation import *
from .sample_triad import *
from .triangle import * #might not need
from .basis import *
from scipy import interpolate
from scipy.spatial import KDTree
import matplotlib.pyplot as plt #remove before final
from .get_point import *
from mpl_toolkits.mplot3d import proj3d #remove before final
from .geodesic import extract_surface,geodesic
from ..implicit.visualize.visualize import show_mesh
from .finite_difference import finite_difference
from scipy.sparse.csgraph import shortest_path
from .add_geodesic_path import *
import pyvista as pv
import random
from scipy.interpolate import griddata
from time import perf_counter

[docs]def add_branch(tree,low,high,threshold_exponent=1.5,threshold_adjuster=0.75, all_max_attempts=40,max_attemps=10,sampling=20,max_skip=8, flow_ratio=None,radius_buffer=0,isforest=False,threshold=None,radius_scale=4, method='L-BFGS-B'): number_edges = tree.parameters['edge_num'] if threshold is None: threshold = ((tree.boundary.volume)**(1/3)/(number_edges**threshold_exponent)) mu = tree.parameters['mu'] lam = tree.parameters['lambda'] gamma = tree.parameters['gamma'] nu = tree.parameters['nu'] Qterm = tree.parameters['Qterm'] Pperm = tree.parameters['Pperm'] Pterm = tree.parameters['Pterm'] new_branch_found = False if tree.homogeneous: subb = None search_time = 0 close_time = 0 constraint_time = 0 local_time = 0 collision_time = 0 time1 = 0 time2 = 0 time3 = 0 time4 = 0 add_time = 0 total_time = 0 start_total = time.time() start_rng = time.time() nonconvex_solve = True if len(tree.rng_points) == 0: #print('repick') rng_points,_ = tree.boundary.pick(size=len(tree.boundary.tet_verts),homogeneous=True,replacement=False) rng_points = rng_points.tolist() else: rng_points = tree.rng_points total_attempts = 0 attempt = 0 while not new_branch_found: start = time.time() #print(threshold_distance) #total_attempts = 0 #attempt = 0 for att in range(all_max_attempts): #print(attempt) start1 = time.time() #point,_ = tree.boundary.pick(homogeneous=True) #time1 += time.time() - start1 if len(rng_points) == 1: #print('repick_low') rng_points,_ = tree.boundary.pick(size=len(tree.boundary.tet_verts),homogeneous=True,replacement=False) rng_points = rng_points.tolist() point = np.array(rng_points.pop(0)) if number_edges < 200: vessel, line_distances = close_exact(tree.data,point) else: vessel, line_distances = close(tree.data,point) time1 += time.time() - start1 start2 = time.time() line_distances_below_threshold = sum(line_distances < threshold) minimum_line_distance = min(line_distances) time2 += time.time() - start2 start3 = time.time() ################ # check to see the attempt number ################ if attempt < max_attemps: attempt += 1 total_attempts += 1 else: if not tree.convex: cell_list_outter = set([]) ptss = [] for i in range(tree.data.shape[0]): first = tree.data[i,0:3] third = tree.data[i,3:6] second = (first+third)/2 ptss.append(second.tolist()) #cell_list_outter.update(set(tree.boundary.cell_lookup.query(first,k=tree.boundary.tet_verts.shape[0]//2))) #cell_list_outter.update(set(tree.boundary.cell_lookup.query(third,k=tree.boundary.tet_verts.shape[0]//2))) #cell_list_outter.update(set(tree.boundary.cell_lookup.query(second,k=tree.boundary.tet_verts.shape[0]//2))) #cell_list_outter.update(set(tree.boundary.cell_lookup.query_ball_point(first,threshold+threshold*0.5))) #cell_list_outter.update(set(tree.boundary.cell_lookup.query_ball_point(third,threshold+threshold*0.5))) #d,id = tree.boundary.cell_lookup.query(second,5000) #id = id.tolist() #id = id[::-1] #cell_list_outter.update(set(id[:200])) cell_list_outter.update(set(tree.boundary.cell_lookup.query_ball_point(second,threshold))) if threshold > ((tree.boundary.volume)**(1/3)/(number_edges**threshold_exponent))*(threshold_adjuster**5): #pass #cell_list_outter.difference_update(set(tree.boundary.cell_lookup.query_ball_point(first,threshold))) #cell_list_outter.difference_update(set(tree.boundary.cell_lookup.query_ball_point(third,threshold))) cell_list_outter.difference_update(set(tree.boundary.cell_lookup.query_ball_point(second,threshold*threshold_adjuster))) ball_size = 2*threshold while len(cell_list_outter) < 1: cell_list_outter.update(set(tree.boundary.cell_lookup.query_ball_point(second,ball_size))) ball_size *= 2 cell_list = list(cell_list_outter) #plotter = pv.Plotter() subb = tree.boundary.tet.grid.extract_cells(cell_list) #plotter.add_mesh(tree.boundary.tet.grid,opacity=0.5) #plotter.add_mesh(subb) #point_poly = pv.PolyData(np.array(ptss)) #plotter.add_mesh(point_poly,color='red') #plotter.show() rng_points = [] max_attemps = min(100,len(cell_list)) for j in range(len(cell_list)): p,_ = tree.boundary.pick_in_cell(cell_list[j]) rng_points.append(p[0].tolist()) random.shuffle(rng_points) threshold *= threshold_adjuster attempt = 0 #print(threshold) ################## ################## #print("Below Threshold: {}".format(line_distances_below_threshold)) #print("Minimum Distances: {} Threshold: {}".format(minimum_line_distance,radius_scale*tree.data[vessel[0],21])) if (line_distances_below_threshold == 0 and minimum_line_distance > radius_scale*tree.data[vessel[0],21]): escape = False start3 = time.time() for i in range(max_skip): if flow_ratio is not None: if tree.data[vessel[i],22] < flow_ratio*Qterm: vessel = vessel[i] escape = True break else: vessel = vessel[i] escape = True break time3 += time.time() - start3 if escape: #print('viable') break """ if attempt < max_attemps: attempt += 1 total_attempts += 1 else: #print('adjusting threshold') if not tree.convex: p0 = (tree.data[vessel[i],0:3]+tree.data[vessel[i],3:6])/2 p1 = point path,lengths,res,pf,f,rm = boundary.find_best_path(p0,p1,niter=10) path = np.array(path) data = np.vstack((t.data,np.zeros((path,tree.data.shape[1])))) data,sub_div_ind,sub_div_map = add_geodesic_path(data,path,lengths,vessel[i], sub_division_index,sub_division_map) return vessel[i],data,sub_div_ind,sub_div_map threshold_distance *= threshold_adjuster attempt = 0 """ else: #print('line distance') start4 = time.time() if attempt < max_attemps: attempt += 1 total_attempts += 1 else: #print('adjusting threshold') attempt = 0 threshold *= threshold_adjuster continue #attempt = 0 time4 += time.time() - start4 search_time += time.time()-start start = time.time() if not isinstance(vessel,np.int64): if len(vessel) > 1: vessel = vessel[0] proximal = tree.data[vessel,0:3] distal = tree.data[vessel,3:6] terminal = point #print(distal) #print(type(distal)) #print(terminal) #print(type(terminal)) if np.all(terminal.shape != distal.shape): terminal = terminal.flatten() points = get_local_points(tree.data,vessel,terminal,sampling,tree.clamped_root) #high_res_points = get_local_points(tree.data,vessel,terminal,10*sampling,tree.clamped_root) #points = np.array(relative_length_constraint(points,proximal,distal,terminal,0.25)) #P = forest.show() #P.add_points(points,render_points_as_spheres=True,point_size=20) #P.add_points(terminal,render_points_as_spheres=True,point_size=20,color='g') #P.show() points = np.array(relative_length_constraint(points,proximal,distal,terminal,0.25)) #high_res_points = np.array(relative_length_constraint(high_res_points,proximal,distal,terminal,0.25)) #P = forest.show() #P.add_points(points,render_points_as_spheres=True,point_size=20) #P.add_points(terminal,render_points_as_spheres=True,point_size=20,color='g') #P.show() if not tree.convex: points = boundary_constraint(points,tree.boundary,2) #high_res_points = boundary_constraint(high_res_points,tree.boundary,2) if len(points) == 0: attempt += 1 #print('constraint 1') continue if vessel != 0 and not tree.clamped_root: points = np.array(angle_constraint(points,terminal,distal,-0.4,True)) #high_res_points = np.array(angle_constraint(high_res_points,terminal,distal,-0.4,True)) if len(points) == 0: attempt += 1 #print('constraint 2') continue if vessel != 0 and not tree.clamped_root: points = np.array(angle_constraint(points,terminal,distal,0.75,False)) #high_res_points = np.array(angle_constraint(high_res_points,terminal,distal,0.75,False)) if len(points) == 0: attempt += 1 #print('constraint 3') continue if vessel != 0 and not tree.clamped_root: points = np.array(angle_constraint(points,terminal,proximal,0,False)) #high_res_points = np.array(angle_constraint(high_res_points,terminal,proximal,0.2,False)) if len(points) == 0: attempt += 1 #print('constraint 4') continue if vessel != 0 and not tree.clamped_root: points = np.array(angle_constraint(points,distal,proximal,0,False)) #high_res_points = np.array(angle_constraint(high_res_points,distal,proximal,0.2,False)) if len(points) == 0: attempt += 1 #print('constraint 5') continue """ if vessel != 0 and not tree.clamped_root: points = np.array(angle_constraint(points,distal,proximal,0.2,False)) #high_res_points = np.array(angle_constraint(high_res_points,distal,proximal,0.2,False)) if len(points) == 0: attempt += 1 #print('constraint 5') continue if vessel != 0 and not tree.clamped_root: points = np.array(angle_constraint(points,distal,proximal,0.2,False)) #high_res_points = np.array(angle_constraint(high_res_points,distal,proximal,0.2,False)) if len(points) == 0: attempt += 1 #print('constraint 5') continue """ if tree.data[vessel,17] >= 0: p_vessel = int(tree.data[vessel,17]) vector_1 = -tree.data[p_vessel,12:15] vector_2 = (points - proximal)/np.linalg.norm(points - proximal,axis=1).reshape(-1,1) #vector_3 = (high_res_points - proximal)/np.linalg.norm(high_res_points - proximal,axis=1).reshape(-1,1) angle = np.array([np.dot(vector_1,vector_2[i]) for i in range(len(vector_2))]) #high_res_angle = np.array([np.dot(vector_1,vector_3[i]) for i in range(len(vector_3))]) points = points[angle<0] #high_res_points = high_res_points[high_res_angle<0] if len(points) == 0: attempt += 1 #print('constraint 6') continue tmp_points = [] #high_res_tmp_points = [] if not tree.convex: for pt in range(points.shape[0]): if tree.boundary.within(points[pt,0],points[pt,1],points[pt,2],2): tmp_points.append(points[pt,:]) points = np.array(tmp_points) #for pt in range(high_res_points.shape[0]): # if tree.boundary.within(high_res_points[pt,0],high_res_points[pt,1],high_res_points[pt,2],2): # high_res_tmp_points.append(high_res_points[pt,:]) #high_res_points = np.array(high_res_tmp_points) if len(points) == 0: attempt += 1 #print('constraint 7') continue tmp_points = [] subdivision = 5 #plotter = pv.Plotter() #polys = pv.PolyData(points) #term_poly = pv.PolyData(terminal) #plotter.add_mesh(tree.boundary.tet.grid,opacity=0.25) #plotter.add_mesh(polys,color='red') #plotter.add_mesh(term_poly,color='green') #if subb is not None: # plotter.add_mesh(subb,color='blue',opacity=0.25) #plotter.show() if not tree.convex: for pt in range(points.shape[0]): include = True for sub in range(1,2*subdivision): mid_proximal = points[pt,:]*(sub/(2*subdivision))+proximal*(1-sub/(2*subdivision)) mid_distal = points[pt,:]*(sub/(2*subdivision))+distal*(1-sub/(2*subdivision)) mid_terminal = points[pt,:]*(sub/(2*subdivision))+terminal*(1-sub/(2*subdivision)) mid_proximal = mid_proximal.flatten() mid_distal = mid_distal.flatten() mid_terminal = mid_terminal.flatten() if vessel != 0: if not tree.boundary.DD[0]((mid_proximal[0],mid_proximal[1],mid_proximal[2],len(tree.boundary.patches)//10)) < 0.1: #print('proximal') include = False break val = tree.boundary.DD[0]((mid_distal[0],mid_distal[1],mid_distal[2],len(tree.boundary.patches)//10)) if not val < 0.01: #print(val) #plotter = pv.Plotter() #plotter.add_mesh(tree.boundary.tet.grid,opacity=0.5) #val_poly = pv.PolyData(mid_distal) #plotter.add_mesh(sub) #plotter.add_mesh(val_poly,color='yellow') #plotter.show() #print('distal') include = False break if not tree.boundary.DD[0]((mid_terminal[0],mid_terminal[1],mid_terminal[2],len(tree.boundary.patches)//10)) < 0.01: #print('terminal') include = False break if include: tmp_points.append(points[pt,:]) points = np.array(tmp_points) #plotter.show() if len(points) == 0: attempt += 1 nonconvex_solve=False #print('constraint 8') continue #print('passed all constraints') constraint_time += time.time()-start start = perf_counter() #construct_results = constructor(tree.data,terminal, # vessel,gamma,nu,Qterm,Pperm,Pterm,lam,mu,sampling,method=method) #print('Constructor Time: {}'.format(perf_counter()-start)) construct_time=perf_counter()-start #print("Minimize x: {} Value: {}".format(results[5],np.pi*results[0]**lam*results[1]**mu)) #start = perf_counter() #fd_point,fd_idx,fd_volume,fd_trial = finite_difference(tree.data,points,terminal, # vessel,gamma,nu,Qterm,Pperm,Pterm) #fd_time = perf_counter() #print('Finite Difference Time: {}'.format(perf_counter()-start)) #print("Finite difference: {} Value: {}".format(fd_point,fd_volume[fd_idx])) #start = time.time() start = perf_counter() results = fast_local_function(tree.data,points,terminal, vessel,gamma,nu,Qterm,Pperm,Pterm) brute_time = perf_counter()-start #truth_results = fast_local_function(tree.data,high_res_points,terminal, # vessel,gamma,nu,Qterm,Pperm,Pterm) #print('Brute Time: {}'.format(perf_counter()-start)) local_time += time.time()-start volume = np.pi*(results[0]**lam)*(results[1]**mu) brute = volume idx = np.argmin(volume) bif = results[5][idx] #truth_volume = np.pi*(truth_results[0]**lam)*(truth_results[1]**mu) #truth_idx = np.argmin(truth_volume) #truth_bif = truth_results[5][truth_idx] #truth_best_vol = truth_volume[truth_idx] #print("Brute x: {} Value: {}".format(bif,min(volume))) start = time.time() no_collision = collision_free(tree.data,results,idx,terminal, vessel,radius_buffer) collision_time += time.time()-start if no_collision: new_branch_found = True start = time.time() data,sub_division_map,sub_division_index = add_bifurcation(tree,vessel,terminal, results,idx,isforest=isforest) #brute = np.sum(np.pi*data[:,21]**lam*data[:,20]**mu) #print("Brute x: {} Value: {}".format(bif,brute[idx])) #fig = plt.figure() #ax1 = fig.add_subplot(121) #x = np.linspace(min(points[:,0]),max(points[:,0]),100) #y = np.linspace(min(points[:,1]),max(points[:,1]),100) #X, Y = np.meshgrid(x,y) #fd_Ti = griddata((points[:,0],points[:,1]),fd_volume,(X,Y),method='linear') #ax1.contour(X,Y,fd_Ti,linewidths=0.5,colors='k') #ax1.pcolormesh(X,Y,fd_Ti,shading='auto',cmap=plt.get_cmap('rainbow')) #ax1.colorbar() #brute_Ti = griddata((points[:,0],points[:,1]),brute,(X,Y),method='linear') #ax2 = fig.add_subplot(122) #ax2.contour(X,Y,brute_Ti,linewidths=0.5,colors='k') #ax2.pcolormesh(X,Y,brute_Ti,shading='auto',cmap=plt.get_cmap('rainbow')) #ax2.colorbar() #plt.colorbar() #plt.show() add_time += time.time()-start total_time += time.time()-start_total tree.time['search'].append(search_time) tree.time['constraints'].append(constraint_time) tree.time['local_optimize'].append(local_time) tree.time['collision'].append(collision_time) tree.time['close_time'].append(close_time) tree.time['search_1'].append(time1) tree.time['search_2'].append(time2) tree.time['search_3'].append(time3) tree.time['search_4'].append(time4) tree.time['add_time'].append(add_time) tree.time['total'].append(total_time) #tree.time['brute_time'].append(brute_time) #tree.time['method_time'].append(construct_time) #tree.time['depth'].append(data[vessel,26]) #tree.time['method_x_value'].append(construct_results[5].flatten()) #tree.time['method_value'].append(np.pi*construct_results[0]**lam*construct_results[1]**mu) #tree.time['brute_x_value'].append(bif) #tree.time['brute_value'].append(brute[idx]) #tree.time['truth_x_value'].append(truth_bif) #tree.time['truth_value'].append(truth_best_vol) if nonconvex_solve: tree.nonconvex_counter += 1 else: tree.nonconvex_counter = 0 if tree.nonconvex_counter > 100 or tree.convex: tree.convex =True else: tree.convex =False #print("returning") return vessel,data,sub_division_map,sub_division_index,threshold else: #print('collision') attempt += 1 continue else: reduced_data = tree.data[tree.data[:,-1]>-1] segment_data = tree.data[tree.data[:,-1]==-1] vessel = np.random.choice(list(range(reduced_data.shape[0]))) vessel_path = segment_data[segment_data[:,29].astype(int)==vessel] other_vessels = segment_data[segment_data[:,29].astype(int)!=vessel] if reduced_data.shape[0] > 1: other_KDTree = KDTree((other_vessels[:,0:3]+other_vessels[:,3:6])/2) else: other_KDTree = None mesh,pa,cp,cd = tree.boundary.mesh(vessel_path[1:,0:3],threshold,threshold//fraction,dive=0,others=other_KDTree) D,PR = shortest_path(graph,directed=False,method="D",return_predecessors=True) bif_idx = set(list(range(mesh.shape[0])))