Source code for scilpy.tractograms.streamline_and_mask_operations

# -*- coding: utf-8 -*-
from enum import Enum
from multiprocessing import Pool

import numpy as np
from dipy.io.stateful_tractogram import StatefulTractogram

from nibabel.streamlines import ArraySequence

from scipy.ndimage import map_coordinates

from scilpy.tractograms.uncompress import streamlines_to_voxel_coordinates
from scilpy.tractograms.streamline_operations import \
    (_get_point_on_line, _get_streamline_pt_index,
     _get_next_real_point, _get_previous_real_point,
     filter_streamlines_by_length,
     resample_streamlines_step_size)


class CuttingStyle(Enum):
    """
    How ``cut_streamlines_with_mask`` trims the streamlines.

    DEFAULT: cut the streamlines at the mask; a streamline going in and
        out of the mask may produce several new streamlines.
    KEEP_LONGEST: keep only the longest segment of the streamline that is
        in the mask.
    TRIM_ENDPOINTS: only cut the endpoints outside of the mask; the middle
        part of the streamline may go outside the mask.
    """
    # Plain integers: a trailing comma (``DEFAULT = 0,``) would make the
    # value the tuple ``(0,)`` and break ``CuttingStyle(0)``.
    DEFAULT = 0
    KEEP_LONGEST = 1
    TRIM_ENDPOINTS = 2
def get_endpoints_density_map(sft, point_to_select=1, to_millimeters=False,
                              binary=False):
    """
    Compute an endpoints density map, supports selecting more than one points
    at each end.

    The space and origin of ``sft`` are restored before returning.

    Parameters
    ----------
    sft: StatefulTractogram
        The streamlines to compute endpoints density from.
    point_to_select: int
        Instead of computing the density based on the first and last points,
        select more than one at each end.
    to_millimeters: bool
        Resample the streamlines to have a step size of 1 mm. This allows the
        user to compute endpoints with mms instead of points. Especially
        useful with compressed streamlines.
    binary: bool
        Return a binary mask.

    Returns
    -------
    np.ndarray:
        A np.ndarray where voxel values represent the density of endpoints.
    """
    orig_space = sft.space
    orig_origin = sft.origin

    if point_to_select == 1 and to_millimeters is False:
        # Very basic: selection of endpoints directly.
        # Uses nearest neighbor interpolation. In vox space with a corner
        # origin, that's simply using the floor, or, even faster, casting
        # (non-negative coordinates) to int.
        sft.to_vox()
        sft.to_corner()
        mask = np.zeros(sft.dimensions, dtype=int)
        for streamline in sft.streamlines:
            mask[tuple(streamline[0].astype(np.int16))] += 1
            mask[tuple(streamline[-1].astype(np.int16))] += 1
    else:
        # For more complex options, using head + tail
        endpoints_map_head, endpoints_map_tail = \
            get_head_tail_density_maps(sft, point_to_select, to_millimeters)
        mask = endpoints_map_head + endpoints_map_tail

    # Leave the caller's tractogram in the space/origin it was given in,
    # as the cutting functions of this module do.
    sft.to_space(orig_space)
    sft.to_origin(orig_origin)

    if binary:
        mask = mask.astype(bool)
    return mask
def get_head_tail_density_maps(sft, point_to_select=1, to_millimeters=False,
                               binary=False, swap=False):
    """
    Compute two separate endpoints density maps for the head and tail of
    a list of streamlines.

    The space and origin of ``sft`` are restored before returning.

    Parameters
    ----------
    sft: StatefulTractogram
        The streamlines to compute endpoints density from.
    point_to_select: int
        Instead of computing the density based on the first and last points,
        select more than one at each end. It is clamped, per streamline, to
        the number of points of that streamline minus one.
    to_millimeters: bool
        Resample the streamlines to have a step size of 1 mm. This allows the
        user to compute endpoints with mms instead of points. Especially
        useful with compressed streamlines.
    binary: bool
        Return binary maps
    swap: bool
        Swap head and tail conventions

    Returns
    -------
    endpoints_map_head: np.ndarray
        A volume where voxel values represent the density of head endpoints.
    endpoints_map_tail: np.ndarray
        A volume where voxel values represent the density of tail endpoints.
    """
    orig_space = sft.space
    orig_origin = sft.origin
    sft.to_vox()
    sft.to_corner()

    if to_millimeters:
        # Resample the streamlines to have a step size of 1 mm
        streamlines = resample_streamlines_step_size(sft, 1.0).streamlines
    else:
        streamlines = sft.streamlines

    dimensions = sft.dimensions

    # Get the indices of the voxels intersected
    streamlines._data = streamlines._data.astype(np.float32)
    list_indices, points_to_indices = streamlines_to_voxel_coordinates(
        streamlines, return_mapping=True)

    # Initialize the endpoints maps
    endpoints_map_head = np.zeros(dimensions)
    endpoints_map_tail = np.zeros(dimensions)

    # A possible optimization would be to compute all coordinates first
    # and then do the np.add.at only once.
    for indices, points in zip(list_indices, points_to_indices):
        # Clamp into a local: overwriting point_to_select here would shrink
        # the selection for every following streamline as soon as one
        # short streamline is met.
        n_select = min(point_to_select, len(points) - 1)

        # Get the head and tail coordinates
        # +1 to include the voxels of the n_select-th point
        # NOTE(review): the head covers points 0..n_select (n_select + 1
        # points) while the tail covers the last n_select points only;
        # confirm this asymmetry is intended.
        head_indices = indices[:points[n_select] + 1, :]
        tail_indices = indices[points[-n_select]:, :]

        # Add the points to the endpoints map
        # Note: np.add.at is used to support duplicate points
        np.add.at(endpoints_map_head, tuple(head_indices.T), 1)
        np.add.at(endpoints_map_tail, tuple(tail_indices.T), 1)

    if binary:
        endpoints_map_head = (endpoints_map_head > 0).astype(np.int16)
        endpoints_map_tail = (endpoints_map_tail > 0).astype(np.int16)

    if swap:
        endpoints_map_head, endpoints_map_tail = \
            endpoints_map_tail, endpoints_map_head

    # Leave the caller's tractogram in the space/origin it was given in.
    sft.to_space(orig_space)
    sft.to_origin(orig_origin)

    return endpoints_map_head, endpoints_map_tail
def _trim_streamline_in_mask( idx, streamline, pts_to_idx, mask ): """ Trim streamlines to the bounding box or a binary mask. More streamlines may be generated if the original streamline goes in and out of the mask. Parameters ---------- idx: np.ndarray Indices of the voxels intersected by the streamline. streamline: np.ndarray The streamlines to cut. pts_to_idx: np.ndarray Mapping from streamline points to indices. mask: np.ndarray Boolean array representing the region. Returns ------- new_strmls : list of np.ndarray New streamlines trimmed within the mask. """ # Find all the points of the streamline that are in the ROIs roi_data_1_intersect = map_coordinates( mask, idx.T, order=0, mode='constant', cval=0) # Select the points that are not in the mask split_idx = np.arange(len(roi_data_1_intersect))[ roi_data_1_intersect == 0] # Split the streamline into segments that are in the mask split_strmls = np.array_split(np.arange(len(roi_data_1_intersect)), split_idx) new_strmls = [] for strml in split_strmls: if len(strml) <= 3: continue # Get the entry and exit points for each segment # Skip the first point as it caused the split in_strl_idx, out_strl_idx = strml[1], strml[-1] cut_strl = compute_streamline_segment(streamline, idx, in_strl_idx, out_strl_idx, pts_to_idx) new_strmls.append(cut_strl) return new_strmls def _trim_streamline_endpoints_in_mask( idx, streamline, pts_to_idx, mask ): """ Trim a streamline to remove its endpoints if they are outside of a mask. This function does not generate new streamlines. Parameters ---------- idx: np.ndarray Indices of the voxels intersected by the streamline. streamline: np.ndarray The streamlines to cut. pts_to_idx: np.ndarray Mapping from streamline points to indices. mask: np.ndarray Boolean array representing the region. Returns ------- streamline: np.ndarray The trimmed streamline within the mask. 
""" # Find all the points of the streamline that are in the ROIs roi_data_1_intersect = map_coordinates( mask, idx.T, order=0, mode='constant', cval=0) # Select the points that are in the mask mask_idx = np.arange(len(roi_data_1_intersect))[ roi_data_1_intersect == 1] if len(mask_idx) == 0: return [] # Get the entry and exit points for each segment in_strl_idx = np.amin(mask_idx) out_strl_idx = np.amax(mask_idx) cut_strl = compute_streamline_segment(streamline, idx, in_strl_idx, out_strl_idx, pts_to_idx) return [cut_strl] def _trim_streamline_in_mask_keep_longest( idx, streamline, pts_to_idx, mask ): """ Trim a streamline to keep the longest segment within a mask. This function does not generate new streamlines. Parameters ---------- idx: np.ndarray Indices of the voxels intersected by the streamline. streamline: np.ndarray The streamlines to cut. pts_to_idx: np.ndarray Mapping from streamline points to indices. mask: np.ndarray Boolean array representing the region. Returns ------- streamline: np.ndarray The trimmed streamline within the mask. """ # Find all the points of the streamline that are in the ROIs roi_data_1_intersect = map_coordinates( mask, idx.T, order=0, mode='constant', cval=0) # Select the points that are not in the mask split_idx = np.arange(len(roi_data_1_intersect))[ roi_data_1_intersect == 0] # Split the streamline into segments that are in the mask split_strmls = np.array_split(np.arange(len(roi_data_1_intersect)), split_idx) # Find the longest segment of the streamline that is in the mask longest_strml = max(split_strmls, key=len) if len(longest_strml) <= 1: return [] # Get the entry and exit points for the longest segment # Skip the first point as it caused the split id_to_pick = 0 if np.count_nonzero(roi_data_1_intersect) == len(idx) else 1 in_strl_idx, out_strl_idx = longest_strml[id_to_pick], longest_strml[-1] cut_strl = compute_streamline_segment(streamline, idx, in_strl_idx, out_strl_idx, pts_to_idx) return [cut_strl]
def cut_streamlines_with_mask(sft, mask, cutting_style=CuttingStyle.DEFAULT,
                              min_len=0, processes=1):
    """
    Cut streamlines according to a binary mask. This function erases the
    data_per_point. If keep_longest is set, the longest segment of the
    streamline that crosses the mask will be kept. Otherwise, the streamline
    will be cut at the mask.

    Parameters
    ----------
    sft: StatefulTractogram
        The sft to cut streamlines (using a single mask with 1 entity) from.
        Its space and origin are restored before returning, even on error.
    mask: np.ndarray
        Boolean array representing the region (must contain 1 entity)
    cutting_style: CuttingStyle
        How to cut the streamlines. Default is to cut the streamlines at the
        mask. If keep_longest is set, the longest segment of the streamline
        that crosses the mask will be kept. If trim_endpoints is set, the
        endpoints of the streamlines will be cut but the middle part of the
        streamline may go outside the mask.
    min_len: float
        Minimum length from the resulting streamlines.
    processes: int
        Number of processes to use.

    Returns
    -------
    new_sft : StatefulTractogram
        New object with the streamlines trimmed within the mask.

    Raises
    ------
    ValueError
        If the voxel mapping of the first streamline does not match its
        number of points.
    """
    orig_space = sft.space
    orig_origin = sft.origin
    sft.to_vox()
    sft.to_corner()

    try:
        # Get the indices of the voxels intersected by the streamlines and
        # the mapping from points to indices
        sft.streamlines._data = sft.streamlines._data.astype(np.float32)
        indices, points_to_idx = streamlines_to_voxel_coordinates(
            sft.streamlines, return_mapping=True)

        if len(sft.streamlines) > 0 and \
                len(sft.streamlines[0]) != len(points_to_idx[0]):
            raise ValueError("Error in the streamlines_to_voxel_coordinates "
                             "function. Try running the "
                             "scil_tractogram_remove_invalid.py script with "
                             "the \n"
                             "--remove_single_point and "
                             "--remove_overlapping_points options.")

        # Select the trimming function. If trim_endpoints is set, only the
        # endpoints are cut. If keep_longest is set, the longest segment of
        # the streamline that crosses the mask will be kept. Otherwise, the
        # streamline will be cut at the mask.
        if cutting_style == CuttingStyle.TRIM_ENDPOINTS:
            trim_func = _trim_streamline_endpoints_in_mask
        elif cutting_style == CuttingStyle.KEEP_LONGEST:
            trim_func = _trim_streamline_in_mask_keep_longest
        else:
            trim_func = _trim_streamline_in_mask

        # Trim streamlines with the mask. The context manager guarantees
        # the workers are cleaned up even if a worker raises.
        with Pool(processes) as pool:
            lists_of_new_strmls = pool.starmap(
                trim_func, [(i, s, pt, mask) for (i, s, pt) in zip(
                    indices, sft.streamlines, points_to_idx)])

        # Each worker returns a list of streamlines; flatten them
        new_strmls = ArraySequence([strml
                                    for list_of_strml in lists_of_new_strmls
                                    for strml in list_of_strml])
        new_sft = StatefulTractogram.from_sft(new_strmls, sft)
    finally:
        # Put back the original space and origin of the input, even on error
        sft.to_space(orig_space)
        sft.to_origin(orig_origin)

    new_sft.to_space(orig_space)
    new_sft.to_origin(orig_origin)

    new_sft, *_ = filter_streamlines_by_length(new_sft, min_length=min_len)

    return new_sft
def cut_streamlines_between_labels(
        sft, label_data, label_ids=None, min_len=0, one_point_in_roi=False,
        no_point_in_roi=False, processes=1
):
    """
    Cut streamlines so their segment are going from blob #1 to blob #2 in a
    binary mask. This function presumes strictly two blobs are present in the
    mask. This function erases the data_per_point.

    Parameters
    ----------
    sft: StatefulTractogram
        The sft to cut streamlines (using a single mask with 2 entities) from.
        Its space and origin are restored before returning, even on error.
    label_data: np.ndarray
        Label map representing the two regions.
    label_ids: list of int, optional
        The two labels to cut between. If not provided, the two unique labels
        in the label map will be used.
    min_len: float
        Minimum length from the resulting streamlines.
    one_point_in_roi: bool
        If True, one point in each ROI will be kept.
    no_point_in_roi: bool
        If True, no point in the ROIs will be kept.
    processes: int
        Number of processes to use.

    Returns
    -------
    new_sft : StatefulTractogram
        New object with the streamlines trimmed within the masks.

    Raises
    ------
    ValueError
        If label_ids is None and the label map does not contain exactly two
        non-zero values, or if the voxel mapping of the first streamline does
        not match its number of points.
    """
    orig_space = sft.space
    orig_origin = sft.origin

    # Validate the labels before touching the tractogram
    if label_ids is None:
        unique_vals = np.unique(label_data[label_data != 0])
        if len(unique_vals) != 2:
            raise ValueError('Expected exactly two non-zero values in the '
                             'label file, please select specific label ids.')
    else:
        unique_vals = label_ids

    # One map per label: voxels of the other values are set to 0
    label_data_1 = np.copy(label_data)
    label_data_1[label_data_1 != unique_vals[0]] = 0
    label_data_2 = np.copy(label_data)
    label_data_2[label_data_2 != unique_vals[1]] = 0

    sft.to_vox()
    sft.to_corner()

    try:
        sft.streamlines._data = sft.streamlines._data.astype(np.float32)
        indices, points_to_idx = streamlines_to_voxel_coordinates(
            sft.streamlines, return_mapping=True)

        if len(sft.streamlines) > 0 and \
                len(sft.streamlines[0]) != len(points_to_idx[0]):
            raise ValueError("Error in the streamlines_to_voxel_coordinates "
                             "function. Try running the "
                             "scil_tractogram_remove_invalid.py script with "
                             "the \n"
                             "--remove_single_point and "
                             "--remove_overlapping_points options.")

        # Cut streamlines between the labels. The context manager guarantees
        # the workers are cleaned up even if a worker raises.
        with Pool(processes) as pool:
            cut_strmls = pool.starmap(
                _cut_streamline_with_labels,
                [(i, s, pt, label_data_1, label_data_2,
                  one_point_in_roi, no_point_in_roi)
                 for (i, s, pt) in zip(indices, sft.streamlines,
                                       points_to_idx)])

        # Each worker returns one streamline, or None when the streamline
        # does not connect both labels
        new_strmls = ArraySequence([strml for strml in cut_strmls
                                    if strml is not None])
        new_sft = StatefulTractogram.from_sft(new_strmls, sft)
    finally:
        # Put back the original space and origin of the input, even on error
        sft.to_space(orig_space)
        sft.to_origin(orig_origin)

    new_sft.to_space(orig_space)
    new_sft.to_origin(orig_origin)

    new_sft, *_ = filter_streamlines_by_length(new_sft, min_length=min_len)

    return new_sft
def _cut_streamline_with_labels( idx, streamline, pts_to_idx, roi_data_1, roi_data_2, one_point_in_roi=False, no_point_in_roi=False ): """ Cut streamlines so their segment are going from label mask #1 to label mask #2. New endpoints may be generated to maximize the streamline length within the masks. Parameters ---------- idx: np.ndarray Indices of the voxels intersected by the streamlines. streamline: np.ndarray The streamlines to cut. pts_to_idx: np.ndarray Mapping from points to indices. roi_data_1: np.ndarray Boolean array representing the region #1. roi_data_2: np.ndarray Boolean array representing the region #2. one_point_in_roi: bool If True, one point in each ROI will be kept. no_point_in_roi: bool If True, no point in the ROIs will be kept. Returns ------- new_strmls : list of np.ndarray New streamlines trimmed within the masks. """ # Find the first and last "voxels" of the streamline that are in the # ROIs in_strl_idx, out_strl_idx = _intersects_two_rois( roi_data_1, roi_data_2, idx, one_point_in_roi=one_point_in_roi, no_point_in_roi=no_point_in_roi) cut_strl = None # If the streamline intersects both ROIs if in_strl_idx is not None and out_strl_idx is not None: # Compute the new streamline by keeping only the segment between # the two ROIs cut_strl = compute_streamline_segment(streamline, idx, in_strl_idx, out_strl_idx, pts_to_idx) return cut_strl def _get_all_streamline_segments_in_roi(all_strl_indices): """ Get the longest segment of a streamline that is in a ROI using the indices of the voxels intersected by the streamline. 
Parameters ---------- strl_indices: list of streamline indices (N) Returns ------- in_strl_idx : int Consectutive indices of the streamline that are in the ROI """ # If there is only one index, its likely invalid if len(all_strl_indices) == 1: return [None] # Find the gradient of the indices of the voxels intersecting with # the ROIs strl_indices_grad = np.gradient(all_strl_indices) split_pos = np.where(strl_indices_grad != 1)[0] + 1 # Covers weird cases where there is only non consecutive indices if len(strl_indices_grad) == len(split_pos) + 1: return [None] # Split the indices of the voxels intersecting with the ROIs into # segments where the gradient is 1 (i.e a chunk of consecutive indices) strl_indices_split = np.split(all_strl_indices, split_pos) return [sublist for sublist in strl_indices_split if sublist.size > 0] def _get_in_and_out_strl_indices(in_strl_indices_split, out_strl_indices_split, one_point_in_roi=False, no_point_in_roi=False): """ Get the first and last "voxels" of the streamline Parameters ---------- in_strl_indices_split: list List of np.array of streamline segment indices (N) out_strl_indices_split: list List of np.array of streamline segment indices (N) one_point_in_roi: bool If True, one point in each ROI will be kept. no_point_in_roi: bool If True, no point in the ROIs will be kept. 
Returns ------- in_strl_idx : int index of the first point of the streamline out_strl_idx : int index of the last point of the streamline """ # One of them is None takes the first segment of the other if in_strl_indices_split[0] is None: return None, out_strl_indices_split[0][-1] elif out_strl_indices_split[0] is None: return in_strl_indices_split[-1][0], None else: # Check the order of the first segments if min(in_strl_indices_split[0]) > min(out_strl_indices_split[0]): in_strl_indices_split, out_strl_indices_split = \ out_strl_indices_split, in_strl_indices_split # Get the last segment in the first ROI # Get the first segment in the second ROI in_strl_indices = in_strl_indices_split[-1] out_strl_indices = out_strl_indices_split[0] # If no options are set, start the streamline with the first # and last point of each segment if not one_point_in_roi and not no_point_in_roi: in_strl_idx = in_strl_indices[0] out_strl_idx = out_strl_indices[-1] else: if one_point_in_roi: add_indice = 0 elif no_point_in_roi: add_indice = 1 in_strl_idx = in_strl_indices[-1] + add_indice out_strl_idx = out_strl_indices[0] - add_indice return in_strl_idx, out_strl_idx def _intersects_two_rois(roi_data_1, roi_data_2, strl_indices, one_point_in_roi=False, no_point_in_roi=False): """ Find the first and last "voxels" of the streamline that are in the ROIs. Parameters ---------- roi_data_1: np.ndarray Boolean array representing the region #1 roi_data_2: np.ndarray Boolean array representing the region #2 strl_indices: list of tuple (N, 3) 3D indices of the voxels intersected by the streamline one_point_in_roi: bool If True, one point in each ROI will be kept. no_point_in_roi: bool If True, no point in the ROIs will be kept. 
Returns ------- in_strl_idx : int index of the first point (of the streamline) to be in the masks out_strl_idx : int index of the last point (of the streamline) to be in the masks """ # Find all the points of the streamline that are in the ROIs roi_data_1_intersect = map_coordinates( roi_data_1, strl_indices.T, order=0, mode='nearest') roi_data_2_intersect = map_coordinates( roi_data_2, strl_indices.T, order=0, mode='nearest') # Get the indices of the voxels intersecting with the ROIs in_strl_indices = np.argwhere(roi_data_1_intersect).squeeze(-1) out_strl_indices = np.argwhere(roi_data_2_intersect).squeeze(-1) # If there are no points in the ROIs, return None if len(in_strl_indices) == 0: in_strl_indices = [None] else: # Get the longest segment of the streamline that is in the ROI in_strl_indices = _get_all_streamline_segments_in_roi( in_strl_indices) if len(out_strl_indices) == 0: out_strl_indices = [None] else: out_strl_indices = _get_all_streamline_segments_in_roi( out_strl_indices) if in_strl_indices[0] is None and out_strl_indices[0] is None: return None, None else: in_strl_idx, out_strl_idx = _get_in_and_out_strl_indices( in_strl_indices, out_strl_indices, one_point_in_roi, no_point_in_roi) return in_strl_idx, out_strl_idx
def compute_streamline_segment(orig_strl, inter_vox, in_vox_idx, out_vox_idx,
                               points_to_indices):
    """
    Extract the part of a streamline lying in a ROI, or between two ROIs.

    When the entry (or exit) voxel holds no real streamline point, an
    artificial point is interpolated on the streamline, at that voxel, and
    added to the segment, unless the nearest real point already is the
    first (or last) point of the streamline.

    Parameters
    ----------
    orig_strl: np.ndarray
        The original streamline.
    inter_vox: np.ndarray
        The intersection points of the streamline with the voxel grid.
    in_vox_idx: int
        The index of the voxel where the streamline enters.
    out_vox_idx: int
        The index of the voxel where the streamline exits.
    points_to_indices: np.ndarray
        The indices of the voxels in the voxel grid.

    Returns
    -------
    segment: np.ndarray
        The segment of the streamline that is in the voxel.
    """
    last_point = len(points_to_indices) - 1

    # --- Entry side ---
    # First real point inside the entry voxel, if there is one.
    first_kept = _get_streamline_pt_index(points_to_indices, in_vox_idx)
    start_extra = None
    if first_kept is None:
        first_kept = _get_next_real_point(points_to_indices, in_vox_idx)
        # No artificial point is needed before the first streamline point.
        if first_kept != 0:
            start_extra = _get_point_on_line(orig_strl[first_kept - 1],
                                             orig_strl[first_kept],
                                             inter_vox[in_vox_idx])

    # --- Exit side ---
    # Last real point inside the exit voxel, if there is one.
    last_kept = _get_streamline_pt_index(points_to_indices, out_vox_idx,
                                         from_start=False)
    end_extra = None
    if last_kept is None:
        last_kept = _get_previous_real_point(points_to_indices, out_vox_idx)
        # No artificial point is needed after the last streamline point.
        if last_kept != last_point:
            end_extra = _get_point_on_line(orig_strl[last_kept],
                                           orig_strl[last_kept + 1],
                                           inter_vox[out_vox_idx])

    # Real points between the entry and the exit, both included.
    segment = orig_strl[first_kept:last_kept + 1]

    # Extra points are inserted/appended rather than handled via offsets.
    if start_extra is not None:
        segment = np.insert(segment, 0, [start_extra], axis=0)
    if end_extra is not None:
        segment = np.append(segment, [end_extra], axis=0)

    return segment