Source code for scilpy.tractograms.intersection_finder

import time
import logging
import numpy as np

from scipy.spatial import KDTree
from scilpy.tracking.fibertube_utils import (streamlines_to_segments,
                                             dist_segment_segment)
from dipy.io.stateful_tractogram import StatefulTractogram
from scilpy.tracking.utils import tqdm_if_verbose


class IntersectionFinder:
    """
    Utility class for finding intersections in a given StatefulTractogram
    with a diameter for each streamline.
    """

    FLOAT_EPSILON = 1e-7

    def __init__(self, in_sft: StatefulTractogram, diameters: list,
                 shuffle_segments=True, rng_seed=0, verbose=False):
        """
        Builds a KDTree from all the tractogram's segments and stores data
        required later for filtering.

        Parameters
        ----------
        in_sft : StatefulTractogram
            Stateful Tractogram object containing streamlines to filter.
        diameters : list
            Diameters of each streamline of the tractogram.
        shuffle_segments: bool
            Should pick streamline segments randomly. If set to false, they
            will be picked in order from the first segment of the first
            streamline to the last segment of the last streamline.
        rng_seed : int
            Seed to be used for random number generation.
        verbose : bool
            Should produce verbose output.
        """
        self.diameters = diameters
        self.max_diameter = np.max(diameters)
        self.rng_seed = rng_seed
        self.verbose = verbose
        self.in_sft = in_sft
        self.streamlines = in_sft.streamlines

        # seg_indices[k] holds (streamline index, point index) of segment k;
        # the segment spans points [point index] and [point index + 1].
        self.seg_centers, self.seg_indices, self.max_seg_length = (
            streamlines_to_segments(self.streamlines, verbose=verbose))

        if shuffle_segments:
            logging.debug("Shuffling streamline segments")
            indexes = list(range(len(self.seg_centers)))
            gen = np.random.default_rng(rng_seed)
            gen.shuffle(indexes)

            # Centers and indices are permuted together so they stay aligned.
            self.seg_centers = self.seg_centers[indexes]
            self.seg_indices = self.seg_indices[indexes]

        self.tree = KDTree(self.seg_centers)

        # Filled by find_intersections(); empty until it has been called.
        self._invalid = []
        self._collisions = []
        self._obstacle = []
        self._excluded = []

        if self.max_seg_length >= 0.3:
            logging.warning("The longest streamline segment is over 0.3mm. "
                            "Performance may drop significantly. "
                            "Resampling to ~0.2mm is recommended. \n"
                            "(See scil_tractogram_resample_nb_points.py)")

    @property
    def invalid(self):
        """Streamlines that hit another streamline and should be filtered
        out."""
        return self._invalid

    @property
    def collisions(self):
        """Collision point of each invalid streamline (one row per
        streamline; rows of non-invalid streamlines are left at zero)."""
        return self._collisions

    @property
    def obstacle(self):
        """Streamlines hit by an invalid streamline. They should not be
        filtered and are saved separately merely for visualization."""
        return self._obstacle

    @property
    def excluded(self):
        """Streamlines that don't collide, but should be excluded for other
        reasons."""
        return self._excluded

    def find_intersections(self, min_distance=0):
        """
        Finds intersections within the initialized data of the object.

        Produces and stores:
            invalid : ndarray[bool]
                Bit map identifying streamlines that hit another streamline
                and should be filtered out.
            collisions : ndarray[float32]
                Collision point of each collider, shape (nb_streamlines, 3).
            obstacle : ndarray[bool]
                Streamlines hit by invalid. They should not be filtered and
                are flagged simply for visualization.
            excluded : ndarray[bool]
                Streamlines that don't collide, but should be excluded for
                other reasons. (ex: distance does not respect min_distance)

        Parameters
        ----------
        min_distance: float
            If set, streamlines will be filtered more aggressively so that
            even if they don't collide, being below [min_distance] apart
            (external to their diameter) will be interpreted as a collision.
            This option is the same as filtering with a large diameter but
            only saving a small diameter in out_tractogram. (Value in mm)
        """
        start_time = time.time()
        streamlines = self.streamlines
        nb_streamlines = len(streamlines)

        invalid = np.zeros(nb_streamlines, dtype=np.bool_)
        collisions = np.zeros((nb_streamlines, 3), dtype=np.float32)
        obstacle = np.zeros(nb_streamlines, dtype=np.bool_)
        excluded = np.zeros(nb_streamlines, dtype=np.bool_)

        # Two segments can only be closer than (sum of radii + min_distance)
        # if their centers are within one max segment length plus
        # max_diameter plus min_distance of each other. This radius is the
        # same for every query, so it is computed once.
        search_radius = (self.max_seg_length + self.max_diameter
                         + min_distance)

        # si   : Streamline Index | index of streamline within the tractogram.
        # pi   : Point Index      | index of point coordinate within a
        #                           streamline.
        # segi : Segment Index    | index of streamline segment within the
        #                           entire tractogram.
        for segi, center in tqdm_if_verbose(enumerate(self.seg_centers),
                                            self.verbose,
                                            total=len(self.seg_centers)):
            si = self.seg_indices[segi][0]

            # [Pruning 1] If current streamline has already collided or been
            # excluded, skip.
            if invalid[si] or excluded[si]:
                continue

            # Geometry of the current segment does not depend on the
            # neighbor, so it is fetched once per segment instead of once
            # per neighbor.
            pi = self.seg_indices[segi][1]
            current_streamline = streamlines[si]
            p0 = current_streamline[pi]
            p1 = current_streamline[pi + 1]
            rp = self.diameters[si] / 2

            # NOTE(review): this queries a single point per call; verify
            # whether workers=-1 actually helps here or only adds thread
            # start-up overhead.
            neighbors = self.tree.query_ball_point(
                center, search_radius, workers=-1)

            for neighbor_segi in neighbors:
                neighbor_si = self.seg_indices[neighbor_segi][0]

                # [Pruning 2] Skip if neighbor is our streamline
                if neighbor_si == si:
                    continue

                # [Pruning 3] If neighbor has already collided or been
                # excluded, skip.
                if invalid[neighbor_si] or excluded[neighbor_si]:
                    continue

                neighbor_pi = self.seg_indices[neighbor_segi][1]
                neighbor_streamline = streamlines[neighbor_si]
                q0 = neighbor_streamline[neighbor_pi]
                q1 = neighbor_streamline[neighbor_pi + 1]
                rq = self.diameters[neighbor_si] / 2

                distance, _, p_coll, q_coll = dist_segment_segment(
                    p0, p1, q0, q1)
                # Distance between the tube surfaces (negative = overlap).
                external_distance = distance - rp - rq

                if external_distance < 0:
                    invalid[si] = True
                    # Estimate of collision point
                    collisions[si] = (p_coll + q_coll) / 2
                    obstacle[neighbor_si] = True
                    break

                if min_distance != 0 and external_distance < min_distance:
                    excluded[si] = True
                    break

        logging.debug("Finished finding intersections in %s seconds.",
                      round(time.time() - start_time, 2))

        self._invalid = invalid
        self._collisions = collisions
        self._obstacle = obstacle
        self._excluded = excluded

    def build_tractograms(self, save_colliding):
        """
        Builds and saves the various tractograms obtained from
        find_intersections().

        find_intersections() must have been called first: before that, the
        stored results are empty lists and indexing them fails for a
        non-empty tractogram.

        Parameters
        ----------
        save_colliding: bool
            If set, will return invalid_sft and obstacle_sft in addition to
            out_sft.

        Return
        ------
        out_sft: StatefulTractogram
            Tractogram containing final streamlines void of collision.
        invalid_sft: StatefulTractogram | None
            Tractogram containing the invalid streamlines that have been
            removed.
        obstacle_sft: StatefulTractogram | None
            Tractogram containing the streamlines that the invalid
            streamlines collided with. May or may not have been removed
            afterwards during filtering.
        """
        out_streamlines = []
        out_diameters = []
        out_collisions = []
        out_invalid = []
        out_obstacle = []

        for si, s in tqdm_if_verbose(enumerate(self.streamlines),
                                     self.verbose,
                                     total=len(self.streamlines)):
            if self._invalid[si]:
                out_invalid.append(s)
                out_collisions.append(self._collisions[si])
            elif not self._excluded[si]:
                out_streamlines.append(s)
                # One diameter value per kept streamline.
                out_diameters.append(np.array([self.diameters[si]]))

            # Obstacles are recorded regardless of whether they were kept.
            if self._obstacle[si]:
                out_obstacle.append(s)

        out_sft = StatefulTractogram.from_sft(
            out_streamlines, self.in_sft,
            data_per_streamline={'diameters': out_diameters})

        if save_colliding:
            invalid_sft = StatefulTractogram.from_sft(
                out_invalid, self.in_sft,
                data_per_streamline={'collisions': out_collisions})
            obstacle_sft = StatefulTractogram.from_sft(
                out_obstacle, self.in_sft)
            return out_sft, invalid_sft, obstacle_sft

        return out_sft, None, None