# Source code for scilpy.segment.tractogram_from_roi

# -*- coding: utf-8 -*-

import itertools
import logging

import nibabel as nib
import numpy as np
import os
from scipy.ndimage import binary_dilation

from dipy.io.streamline import save_tractogram
from dipy.tracking.utils import length as compute_length

from scilpy.image.utils import \
    split_mask_blobs_kmeans
from scilpy.io.image import get_data_as_mask
from scilpy.io.streamlines import load_tractogram_with_reference
from scilpy.segment.streamlines import filter_grid_roi, filter_grid_roi_both_ends
from scilpy.tractograms.streamline_operations import \
    remove_loops_and_sharp_turns
from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map

from scilpy.tractograms.streamline_operations import \
    filter_streamlines_by_total_length_per_dim
from scilpy.utils.filenames import split_name_with_nii


def _extract_prefix(filename):
    """
    Return the file name without its directory and without its extension.

    Parameters
    ----------
    filename: str
        Path to a file (e.g. an endpoint mask).

    Returns
    -------
    str
        The base name with its extension removed by split_name_with_nii.
    """
    base_name, _extension = split_name_with_nii(os.path.basename(filename))
    return base_name


def compute_masks_from_bundles(gt_files, parser, args):
    """
    Compute ground-truth masks.

    If the file is already a mask (NIfTI), load it as a binary mask. If it
    is a bundle (tractogram), compute its binary mask from its streamline
    count map. If the filename is None, append None to the list of masks.

    Compatibility between files should already be verified.

    Parameters
    ----------
    gt_files: list[str or None]
        List of filenames, each either a tractogram or a NIfTI mask, or None.
    parser: ArgumentParser
        Argument parser which handles the script's arguments. Used to print
        parser errors, if any.
    args: Namespace
        List of arguments passed to the script. Used for its 'reference'
        and 'bbox_check' arguments (the latter presumably read by
        load_tractogram_with_reference). args.reference is temporarily
        modified while loading and always restored before returning.

    Returns
    -------
    gt_bundle_masks: list[numpy.ndarray or None]
        The loaded masks, in the same order as gt_files.
    """
    save_ref = args.reference
    gt_bundle_masks = []

    try:
        for gt_bundle in gt_files:
            if gt_bundle is None:
                gt_bundle_masks.append(None)
                continue

            # Support ground truth as streamlines or masks.
            # Will be converted to binary masks immediately.
            _, ext = split_name_with_nii(gt_bundle)
            # '.nii' was previously missing: uncompressed masks were wrongly
            # loaded as tractograms.
            if ext in ['.nii', '.nii.gz', '.gz']:
                gt_img = nib.load(gt_bundle)
                gt_mask = get_data_as_mask(gt_img)
            else:
                # Cheating ref because it may send a lot of warnings if
                # loading many trk with ref (reference was maybe added only
                # for some of these files).
                args.reference = None if ext == '.trk' else save_ref
                gt_sft = load_tractogram_with_reference(
                    parser, args, gt_bundle)
                gt_sft.to_vox()
                gt_sft.to_corner()
                _, dimensions, _, _ = gt_sft.space_attributes
                gt_mask = compute_tract_counts_map(
                    gt_sft.streamlines, dimensions).astype(np.int16)
                gt_mask[gt_mask > 0] = 1
            gt_bundle_masks.append(gt_mask)
    finally:
        # Restore the caller's reference even if loading failed.
        args.reference = save_ref

    return gt_bundle_masks
def _extract_and_save_tails_heads_from_endpoints(gt_endpoints, out_dir):
    """
    Extract two masks from a single mask containing two regions.

    The mask is split into two blobs with split_mask_blobs_kmeans
    (nb_clusters=2). Each blob is saved in out_dir as
    '<basename>_head.nii.gz' and '<basename>_tail.nii.gz', keeping the
    dtype and affine of the input mask.

    Parameters
    ----------
    gt_endpoints: str
        Ground-truth mask filename.
    out_dir: str
        Path where to save the heads and tails.

    Returns
    -------
    tail_filename: str
        Filename of the saved tail mask.
    head_filename: str
        Filename of the saved head mask.
    affine: numpy.ndarray
        Affine of mask image.
    dimensions: tuple of int
        Dimensions of the mask image.
    """
    mask_img = nib.load(gt_endpoints)
    mask = get_data_as_mask(mask_img)
    affine = mask_img.affine
    dimensions = mask.shape

    head, tail = split_mask_blobs_kmeans(mask, nb_clusters=2)
    basename = os.path.basename(split_name_with_nii(gt_endpoints)[0])
    tail_filename = os.path.join(out_dir, '{}_tail.nii.gz'.format(basename))
    head_filename = os.path.join(out_dir, '{}_head.nii.gz'.format(basename))
    nib.save(nib.Nifti1Image(head.astype(mask.dtype), affine), head_filename)
    nib.save(nib.Nifti1Image(tail.astype(mask.dtype), affine), tail_filename)

    return tail_filename, head_filename, affine, dimensions
def compute_endpoint_masks(roi_options, out_dir):
    """
    If endpoints without heads/tails are loaded, split them and continue
    normally after. Q/C of the output is important. Compatibility between
    files should be already verified.

    Parameters
    ----------
    roi_options: list[dict] or dict
        One entry per bundle. Each entry is a dictionary with either the key
        'gt_endpoints' (the name of the file containing the bundle's
        endpoints), or both keys 'gt_tail' and 'gt_head' (the names of the
        respective files). A dict mapping bundle names to such entries is
        also accepted; its values are then used, in insertion order.
    out_dir: str
        Where to save the heads and tails.

    Returns
    -------
    tails, heads: lists of filenames with length the number of bundles.
    """
    # Iterating a dict directly would yield the bundle names (strings),
    # not their options, so use its values instead.
    if isinstance(roi_options, dict):
        bundles_options = roi_options.values()
    else:
        bundles_options = roi_options

    tails = []
    heads = []
    for bundle_options in bundles_options:
        if 'gt_endpoints' in bundle_options:
            # Single mask holding both regions: split it and save the halves.
            tail, head, _, _ = _extract_and_save_tails_heads_from_endpoints(
                bundle_options['gt_endpoints'], out_dir)
        else:
            tail = bundle_options['gt_tail']
            head = bundle_options['gt_head']
        tails.append(tail)
        heads.append(head)

    return tails, heads
def _extract_vb_and_wpc_all_bundles(
        gt_tails, gt_heads, sft, bundle_names, lengths, angles,
        orientation_lengths, abs_orientation_lengths, all_masks, any_masks,
        out_dir, unique, dilate_endpoints, save_wpc_separately,
        remove_wpc_belonging_to_another_bundle):
    """
    Loop on every ground truth bundles and extract VS and WPC.

    VS:
       1) Connect the head and tail
       2) Are completely included in the all_mask (if any)
       3) Have acceptable angle, length and length per orientation.
       4) Reach the any_mask (if any)
     +
    WPC connections:
       1) connect the head and tail but criteria 2, 3 or 4 are not respected

    Parameters
    ----------
    gt_tails, gt_heads: list[str]
        Endpoint mask filenames of each bundle.
    sft: StatefulTractogram
        Tractogram to segment.
    bundle_names: list[str]
        Bundle names.
    lengths, angles, orientation_lengths, abs_orientation_lengths,
    all_masks, any_masks: list
        Per-bundle criteria; see _extract_vb_one_bundle.
    out_dir: str
        Output directory (used to save duplicates when not unique).
    unique: bool
        If True, a streamline is assigned to the first bundle it fits in.
    dilate_endpoints: int or None
        If set, dilate the endpoint masks for n iterations.
    save_wpc_separately: bool
        If True, WPC are kept separately; else they are discarded here.
    remove_wpc_belonging_to_another_bundle: bool
        If True (or if unique), WPC that are VS of any bundle are removed.

    Returns
    -------
    vb_sft_list: list
        List of StatefulTractograms of VS
    wpc_sft_list: list or None
        If save_wpc_separately, list of StatefulTractograms of WPC (None for
        bundles without WPC). Else, None.
    all_vs_wpc_ids: numpy.ndarray
        Integer ids of all VS (sorted, unique) followed by all WPC ids
        (sorted, unique; empty unless save_wpc_separately).
    bundle_stats_dict: dict
        Dictionary of the processing information for each bundle.

    Saves
    -----
      - Each duplicate in segmented_conflicts/duplicates_*_*.trk
    """
    nb_bundles = len(bundle_names)

    vb_sft_list = []
    vs_ids_list = []
    wpc_ids_list = []
    bundles_stats = []

    remaining_ids = np.arange(len(sft))  # For unique management.

    # 1. Extract VB and WPC.
    for i in range(nb_bundles):
        vs_ids, wpc_ids, bundle_stats = _extract_vb_one_bundle(
            sft[remaining_ids], gt_heads[i], gt_tails[i], lengths[i],
            angles[i], orientation_lengths[i], abs_orientation_lengths[i],
            all_masks[i], any_masks[i], dilate_endpoints)

        if unique:
            # Assign actual ids, not from subset
            vs_ids = remaining_ids[vs_ids]
            wpc_ids = remaining_ids[wpc_ids]
            # Update remaining_ids based on valid streamlines only
            remaining_ids = np.setdiff1d(remaining_ids, vs_ids,
                                         assume_unique=True)

        vb_sft_list.append(sft[vs_ids])
        vs_ids_list.append(vs_ids)
        wpc_ids_list.append(wpc_ids)
        bundles_stats.append(bundle_stats)

        logging.info("Bundle {}: nb VS = {}"
                     .format(bundle_names[i], bundle_stats["VS"]))

    all_gt_ids = list(itertools.chain(*vs_ids_list))

    # 2. Remove duplicate WPC and then save.
    if save_wpc_separately:
        if remove_wpc_belonging_to_another_bundle or unique:
            for i in range(nb_bundles):
                new_wpc_ids = np.setdiff1d(wpc_ids_list[i], all_gt_ids)
                nb_rejected = len(wpc_ids_list[i]) - len(new_wpc_ids)
                bundles_stats[i].update(
                    {"Belonging to another bundle": nb_rejected})
                wpc_ids_list[i] = new_wpc_ids
                bundles_stats[i].update({"Cleaned WPC": len(new_wpc_ids)})

        wpc_sft_list = []
        for i in range(nb_bundles):
            logging.info("Bundle {}: nb WPC = {}"
                         .format(bundle_names[i], len(wpc_ids_list[i])))
            wpc_ids = wpc_ids_list[i]
            wpc_sft_list.append(sft[wpc_ids] if len(wpc_ids) > 0 else None)
    else:
        # Remove WPCs to be included as IS in the future
        wpc_ids_list = [[] for _ in range(nb_bundles)]
        wpc_sft_list = None

    # 3. If not unique, tell users if there were duplicates. Save
    # duplicates separately in segmented_conflicts/duplicates_*_*.trk.
    if not unique:
        for i in range(nb_bundles):
            for j in range(i + 1, nb_bundles):
                duplicate_ids = np.intersect1d(vs_ids_list[i],
                                               vs_ids_list[j])
                if len(duplicate_ids) > 0:
                    logging.warning(
                        "{} streamlines belong to true connections of both "
                        "bundles {} and {}.\n"
                        "Please verify your criteria!"
                        .format(len(duplicate_ids), bundle_names[i],
                                bundle_names[j]))

                    # Duplicates directory only created if at least one
                    # duplicate is found.
                    path_duplicates = os.path.join(out_dir,
                                                   'segmented_conflicts')
                    os.makedirs(path_duplicates, exist_ok=True)

                    save_tractogram(sft[duplicate_ids], os.path.join(
                        path_duplicates, 'duplicates_' + bundle_names[i] +
                        '_' + bundle_names[j] + '.trk'))

    # 4. Save bundle stats.
    bundle_stats_dict = dict(zip(bundle_names, bundles_stats))

    # Explicit int dtype: np.unique([]) would otherwise give a float array.
    all_vs_ids = np.unique(
        np.array(list(itertools.chain(*vs_ids_list)), dtype=int))
    all_wpc_ids = np.unique(
        np.array(list(itertools.chain(*wpc_ids_list)), dtype=int))
    all_vs_wpc_ids = np.concatenate((all_vs_ids, all_wpc_ids))

    return vb_sft_list, wpc_sft_list, all_vs_wpc_ids, bundle_stats_dict


def _load_endpoint_masks(mask_1_filename, mask_2_filename, dilate_endpoints):
    """Load two endpoint masks, dilated dilate_endpoints times if set."""
    mask_1 = get_data_as_mask(nib.load(mask_1_filename))
    mask_2 = get_data_as_mask(nib.load(mask_2_filename))
    if dilate_endpoints:
        mask_1 = binary_dilation(mask_1, iterations=dilate_endpoints)
        mask_2 = binary_dilation(mask_2, iterations=dilate_endpoints)
    return mask_1, mask_2


def _split_ids_by_length_per_dim(sft, vs_ids, limits, use_abs):
    """Split vs_ids into (valid, invalid) ids by their length per axis."""
    limits_x, limits_y, limits_z = limits
    _, valid_ids_from_vs, _ = filter_streamlines_by_total_length_per_dim(
        sft[vs_ids], limits_x, limits_y, limits_z,
        use_abs=use_abs, save_rejected=False)
    valid_ids = vs_ids[valid_ids_from_vs]
    invalid_ids = np.setdiff1d(vs_ids, valid_ids)
    return valid_ids, invalid_ids


def _extract_vb_one_bundle(
        sft, head_filename, tail_filename, limits_length, angle,
        orientation_length, abs_orientation_length, all_mask, any_mask,
        dilate_endpoints):
    """
    Extract valid bundle (and valid streamline ids) from a tractogram, based
    on two regions of interest for the endpoints, one region of interest for
    the inclusion of streamlines, and maximum length, maximum angle,
    maximum length per orientation.

    Criteria are applied in this order, each one only on the streamlines
    still valid: ALL mask, ANY mask, length, length per orientation,
    absolute length per orientation, angle. Rejected streamlines become WPC.

    Parameters
    ----------
    sft: StatefulTractogram
        Tractogram containing the streamlines to be extracted.
    head_filename: str
        Filename of the "head" of the bundle.
    tail_filename: str
        Filename of the "tail" of the bundle.
    limits_length: list or None
        Bundle's length parameters [min max] (exclusive bounds).
    angle: int or None
        Bundle's max angle.
    orientation_length: list or None
        Bundle's length parameters in each direction:
        [[min_x, max_x], [min_y, max_y], [min_z, max_z]]
    abs_orientation_length: idem, computed in absolute values.
    all_mask: np.ndarray or None
        The "ALL" mask for this bundle: no point must be outside the mask.
    any_mask: np.ndarray or None
        ANY mask for this bundle. Streamlines must pass through this mask
        (touch it) to be included in the bundle.
    dilate_endpoints: int or None
        If set, dilate the masks for n iterations.

    Returns
    -------
    vs_ids: list
        List of ids of valid streamlines
    wpc_ids: list
        List of ids of wrong-path connections
    bundle_stats: dict
        Dictionary of recognized streamlines statistics
    """
    if len(sft) > 0:
        mask_1, mask_2 = _load_endpoint_masks(head_filename, tail_filename,
                                              dilate_endpoints)
        _, vs_ids = filter_grid_roi_both_ends(sft, mask_1, mask_2)
    else:
        vs_ids = np.array([], dtype=int)
    wpc_ids = []
    bundle_stats = {"Initial count head to tail": len(vs_ids)}

    # Remove out of inclusion mask (limits_mask)
    if len(vs_ids) > 0 and all_mask is not None:
        tmp_sft = sft[vs_ids]
        # ALL points inside = NO points outside = NOT ANY point outside
        # Inversing the mask.
        inv_all_mask = ~all_mask.astype(bool)
        out_of_mask_ids_from_vs = filter_grid_roi(
            tmp_sft, inv_all_mask, 'any', is_exclude=False)
        out_of_mask_ids = vs_ids[out_of_mask_ids_from_vs]

        bundle_stats.update({"WPC_out_of_mask": len(out_of_mask_ids)})

        wpc_ids.extend(out_of_mask_ids)
        vs_ids = np.setdiff1d(vs_ids, wpc_ids)

    # Remove streamlines not passing through any_mask
    if len(vs_ids) > 0 and any_mask is not None:
        tmp_sft = sft[vs_ids]
        in_mask_ids_from_vs = filter_grid_roi(
            tmp_sft, any_mask, 'any', is_exclude=False)
        in_mask_ids = vs_ids[in_mask_ids_from_vs]

        out_of_mask_ids = np.setdiff1d(vs_ids, in_mask_ids)
        bundle_stats.update({"WPC_not_reaching_the_ANY_mask":
                             len(out_of_mask_ids)})

        wpc_ids.extend(out_of_mask_ids)
        vs_ids = in_mask_ids

    # Remove invalid lengths
    if len(vs_ids) > 0 and limits_length is not None:
        min_len, max_len = limits_length

        # Bring streamlines to world coordinates so proper length
        # is calculated
        sft.to_rasmm()
        lengths = np.array(list(compute_length(sft.streamlines[vs_ids])))
        sft.to_vox()

        valid_length_mask = np.logical_and(lengths > min_len,
                                           lengths < max_len)

        bundle_stats.update({
            "WPC_invalid_length": int(np.count_nonzero(~valid_length_mask))})

        wpc_ids.extend(vs_ids[~valid_length_mask])
        vs_ids = vs_ids[valid_length_mask]

    # Remove invalid lengths per orientation
    if len(vs_ids) > 0 and orientation_length is not None:
        vs_ids, invalid_orientation_ids = _split_ids_by_length_per_dim(
            sft, vs_ids, orientation_length, use_abs=False)
        bundle_stats.update({
            "WPC_invalid_orientation": len(invalid_orientation_ids)})
        wpc_ids.extend(invalid_orientation_ids)

    # Idem in abs
    if len(vs_ids) > 0 and abs_orientation_length is not None:
        vs_ids, invalid_orientation_ids = _split_ids_by_length_per_dim(
            sft, vs_ids, abs_orientation_length, use_abs=True)
        bundle_stats.update({
            "WPC_invalid_orientation_abs": len(invalid_orientation_ids)})
        wpc_ids.extend(invalid_orientation_ids)

    # Remove loops from tc
    if len(vs_ids) > 0 and angle is not None:
        valid_angle_ids_from_vs = remove_loops_and_sharp_turns(
            sft.streamlines[vs_ids], angle)

        valid_angle_ids = vs_ids[valid_angle_ids_from_vs]
        invalid_angle_ids = np.setdiff1d(vs_ids, valid_angle_ids)

        # Own key, so that the length statistic above is not overwritten.
        bundle_stats.update({"WPC_invalid_angle": len(invalid_angle_ids)})

        wpc_ids.extend(invalid_angle_ids)
        vs_ids = valid_angle_ids

    bundle_stats.update({"VS": len(vs_ids)})

    return list(vs_ids), list(wpc_ids), bundle_stats


def _extract_ib_one_bundle(sft, mask_1_filename, mask_2_filename,
                           dilate_endpoints):
    """
    Extract false connections based on two regions from a tractogram.

    Parameters
    ----------
    sft: StatefulTractogram
        Tractogram containing the streamlines to be extracted.
    mask_1_filename: str
        Filename of the "head" of the bundle.
    mask_2_filename: str
        Filename of the "tail" of the bundle.
    dilate_endpoints: int or None
        If set, dilate the masks for n iterations.

    Returns
    -------
    fc_sft: StatefulTractogram
        SFT of false connections.
    fc_ids: list or numpy.ndarray
        Ids (relative to sft) of the false connections.
    """
    if len(sft) > 0:
        mask_1, mask_2 = _load_endpoint_masks(mask_1_filename,
                                              mask_2_filename,
                                              dilate_endpoints)
        _, fc_ids = filter_grid_roi_both_ends(sft, mask_1, mask_2)
    else:
        fc_ids = []

    fc_sft = sft[fc_ids]
    return fc_sft, fc_ids


def _extract_ib_all_bundles(comb_filename, sft, unique, dilate_endpoints):
    """
    Loop on every bundle and compute false connections, defined as
    connections between ROIs pairs that do not form gt bundles.

    (Goes through all the possible combinations of endpoints masks)

    Parameters
    ----------
    comb_filename: list[tuple(str, str)]
        Pairs of ROI filenames to test.
    sft: StatefulTractogram
        Tractogram to segment.
    unique: bool
        If True, a streamline is assigned only to the first pair it connects.
    dilate_endpoints: int or None
        If set, dilate the masks for n iterations.

    Returns
    -------
    ib_sft_list: list[StatefulTractogram]
        One SFT per pair connected by at least one streamline.
    ic_ids_list: list
        Ids (relative to sft) of the streamlines of each IB.
    ib_bundle_names: list[str]
        '<prefix1>_<prefix2>' name of each IB.
    """
    ib_sft_list = []
    ic_ids_list = []
    ib_bundle_names = []
    # ROI pair of each kept IB, aligned with ic_ids_list (pairs without any
    # streamline are skipped, so comb_filename indices do not match).
    ib_roi_pairs = []
    all_ids = np.arange(len(sft))
    for roi in comb_filename:
        roi1_filename, roi2_filename = roi
        # Automatically generate filename for Q/C
        prefix_1 = _extract_prefix(roi1_filename)
        prefix_2 = _extract_prefix(roi2_filename)

        ib_sft, ic_ids = _extract_ib_one_bundle(
            sft[all_ids], roi1_filename, roi2_filename, dilate_endpoints)
        if unique:
            ic_ids = all_ids[ic_ids]
            all_ids = np.setdiff1d(all_ids, ic_ids, assume_unique=True)

        if len(ib_sft.streamlines) > 0:
            logging.info("IB: Recognized {} streamlines between {} and {}"
                         .format(len(ib_sft.streamlines), prefix_1, prefix_2))

            ib_sft_list.append(ib_sft)
            ic_ids_list.append(ic_ids)
            ib_bundle_names.append(prefix_1 + '_' + prefix_2)
            ib_roi_pairs.append(roi)

    # Duplicates?
    if not unique:
        nb_pairs = len(ic_ids_list)
        for i in range(nb_pairs):
            for j in range(i + 1, nb_pairs):
                duplicate_ids = np.intersect1d(ic_ids_list[i],
                                               ic_ids_list[j])
                if len(duplicate_ids) > 0:
                    logging.warning(
                        "{} streamlines are scored twice as invalid "
                        "connections\n (between pair {}\n and between pair "
                        "{}). You probably have overlapping ROIs!"
                        .format(len(duplicate_ids), ib_roi_pairs[i],
                                ib_roi_pairs[j]))

    return ib_sft_list, ic_ids_list, ib_bundle_names
def segment_tractogram_from_roi(
        sft, gt_tails, gt_heads, bundle_names, bundle_lengths, angles,
        orientation_lengths, abs_orientation_lengths, all_masks, any_masks,
        out_dir, compute_ic=False, save_wpc_separately=False, unique=True,
        remove_wpc_belonging_to_another_bundle=True, no_empty=True,
        bbox_check=True, dilate_endpoints=0):
    """
    Segments valid bundles (VB).

    Note: sft is converted to voxel space in place.

    Parameters
    ----------
    sft: StatefulTractogram
        The tractogram to segment.
    gt_tails: list[str]
        List of filenames, each VB endpoint mask (first end)
    gt_heads: list[str]
        List of filenames, each VB endpoint mask (second end), in the same
        order as gt_tails. Ex, VB #2 uses gt_tails[2] and gt_heads[2] as
        endpoints.
    bundle_names: list[str]
        Bundle names.
    bundle_lengths: list[[float, float] or None]
        Length range [min, max] for each bundle, or None for no limit.
    angles: list[float]
        Maximum angle (in loops) for each bundle (in degree).
    orientation_lengths: list[[limitsx, limitsy, limitsz] or None]
        For each bundle, the length parameters in each direction:
        [[min_x, max_x], [min_y, max_y], [min_z, max_z]].
        None for no limit.
    abs_orientation_lengths: list[[limitsx, limitsy, limitsz] or None]
        Idem, computed in absolute values.
    all_masks: list[np.ndarray or None]
        For each bundle, the "ALL" mask for this bundle: no point must be
        outside the mask.
    any_masks: list[np.ndarray or None]
        For each bundle, the "ANY" mask for this bundle: at least one point
        must pass through this mask.
    out_dir: str
        Output directory. NC.trk (or IS.trk) is saved there, as well as
        segmented_conflicts/duplicates_*_*.trk when not unique.
    compute_ic: bool
        Also compute invalid connections (IC).
    save_wpc_separately: bool
        Separate wrong path connections (WPC) from other invalid connections
        (IC). WPC = correct endpoint ROIs but wrong path based on other
        criteria.
    unique: bool
        If True, streamlines are assigned to the first bundle they fit in
        and not to all.
    remove_wpc_belonging_to_another_bundle: bool
        If true, WPC actually belonging to any VB (in the case of
        overlapping ROIs) will be removed from the WPC classification.
    no_empty: bool
        If true, do not save empty bundles.
    bbox_check: bool
        If true, check bounding box validation.
    dilate_endpoints: int
        Dilate endpoint masks n-times. Default: 0.

    Returns
    -------
    vb_sft_list: list[StatefulTractogram]
        The list of valid bundles discovered (not saved by this function).
    wpc_sft_list: list[StatefulTractogram or None] or None
        The list of wrong path connections: streamlines connecting the
        right endpoint regions but failing the ALL mask, ANY mask, length,
        orientation or angle criteria. Same length as vb_sft_list.
        ** This is only computed if save_wpc_separately. Else, this is None.
    ib_sft_list: list[StatefulTractogram]
        The list of invalid bundles: streamlines connecting regions that
        should not be connected. Empty if not compute_ic.
    nc_sft: StatefulTractogram
        The rejected streamlines that were not included in any VB, WPC or
        IB (saved as NC.trk if compute_ic, else IS.trk).
    ib_names: list[str]
        The list of names for invalid bundles (IB). They are created from
        the combinations of ROIs used for IB computations. Empty if not
        compute_ic.
    bundle_stats: dict
        Dictionary of the processing information for each VB bundle.
    """
    sft.to_vox()

    # VS
    logging.info("Extracting valid bundles (and wpc, if any)")
    vb_sft_list, wpc_sft_list, detected_vs_wpc_ids, bundle_stats = \
        _extract_vb_and_wpc_all_bundles(
            gt_tails, gt_heads, sft, bundle_names, bundle_lengths,
            angles, orientation_lengths, abs_orientation_lengths,
            all_masks, any_masks, out_dir, unique, dilate_endpoints,
            save_wpc_separately, remove_wpc_belonging_to_another_bundle)

    remaining_ids = np.arange(0, len(sft))
    if unique:
        remaining_ids = np.setdiff1d(remaining_ids, detected_vs_wpc_ids)

    # IC
    if compute_ic and len(remaining_ids) > 0:
        logging.info("Extracting invalid bundles")

        # Keep all possible combinations of distinct ROIs
        list_rois = sorted(set(gt_tails + gt_heads))
        comb_filename = list(itertools.combinations(list_rois, r=2))

        # Remove the true connections from all combinations, leaving only
        # false connections. A pair may be absent (already removed because
        # several bundles share the same endpoints, or tail == head); it
        # previously raised ValueError.
        for vb_roi_pair in zip(gt_tails, gt_heads):
            vb_roi_pair = tuple(sorted(vb_roi_pair))
            if vb_roi_pair in comb_filename:
                comb_filename.remove(vb_roi_pair)

        ib_sft_list, ic_ids_list, ib_names = _extract_ib_all_bundles(
            comb_filename, sft[remaining_ids], unique, dilate_endpoints)

        if unique and len(ic_ids_list) > 0:
            # Assign actual ids (they are relative to sft[remaining_ids])
            ic_ids_list = [remaining_ids[ids] for ids in ic_ids_list]
            detected_ic_ids = np.concatenate(ic_ids_list)
            remaining_ids = np.setdiff1d(remaining_ids, detected_ic_ids)
    else:
        ic_ids_list = []
        ib_sft_list = []
        ib_names = []

    # NC
    # = ids that are not VS, not wpc (if asked) and not IC (if asked).
    # When unique, remaining_ids already excludes all of them.
    all_nc_ids = remaining_ids
    if not unique:
        all_ic_ids = np.unique(list(itertools.chain(*ic_ids_list)))
        all_nc_ids = np.setdiff1d(all_nc_ids, detected_vs_wpc_ids)
        all_nc_ids = np.setdiff1d(all_nc_ids, all_ic_ids)

    if compute_ic:
        logging.info("The remaining {} / {} streamlines will be scored as NC."
                     .format(len(all_nc_ids), len(sft)))
        filename = "NC.trk"
    else:
        logging.info("The remaining {} / {} streamlines will be scored as IS."
                     .format(len(all_nc_ids), len(sft)))
        filename = "IS.trk"

    nc_sft = sft[all_nc_ids]
    if len(nc_sft) > 0 or not no_empty:
        save_tractogram(nc_sft, os.path.join(out_dir, filename),
                        bbox_valid_check=bbox_check)

    return (vb_sft_list, wpc_sft_list, ib_sft_list, nc_sft, ib_names,
            bundle_stats)