# Source code for scilpy.segment.voting_scheme

# -*- coding: utf-8 -*-

import gc
from itertools import product, repeat
import json
import logging
from multiprocessing import Manager
import multiprocessing
import os
from time import time
import warnings

from dipy.io.streamline import save_tractogram, load_tractogram
from dipy.segment.clustering import qbx_and_merge
from dipy.tracking.streamline import transform_streamlines
import nibabel as nib
from nibabel.streamlines.array_sequence import ArraySequence
import numpy as np

from scilpy.io.streamlines import streamlines_to_memmap
from scilpy.segment.bundleseg import BundleSeg
from scilpy.utils import get_duration

logger = logging.getLogger('BundleSeg')

# Clustering thresholds inherited from Recobundles. With BundleSeg they do
# not need to be modified.
MCT = 4  # Model Clustering Threshold (mm)
TCT = 12  # Tractogram Clustering Threshold (mm)


class VotingScheme(object):
    """
    Run BundleSeg for every model of every configured bundle, accumulate
    one vote per model and streamline, and save the streamlines whose vote
    count reaches the per-bundle minimum.
    """

    def __init__(self, config, atlas_directory, transformation,
                 output_directory, minimal_vote_ratio=0.5,
                 save_empty=False, ignore_metadata=False):
        """
        Parameters
        ----------
        config : dict
            Dictionary containing information relative to bundle recognition.
        atlas_directory : list
            List of all directories to be used as atlas by BundleSeg.
            Must contain all bundles as declared in the config file.
            A single directory (not wrapped in a list) is also accepted.
        transformation : numpy.ndarray
            Transformation (4x4) bringing the models into subject space.
        output_directory : str
            Directory name where all files will be saved.
        minimal_vote_ratio : float
            Value for the vote ratio for a streamline to be considered.
            (0 < minimal_vote_ratio < 1)
        save_empty : bool
            If True, save empty files for bundles that were not recognized.
        ignore_metadata : bool
            If True, ignore metadata in the tractogram if present. This will
            only consider the geometry of the streamlines for saving.
        """
        self.config = config
        self.minimal_vote_ratio = minimal_vote_ratio
        self.ignore_metadata = ignore_metadata
        self.save_empty = save_empty

        # Scripts parameters
        if isinstance(atlas_directory, list):
            self.atlas_dir = atlas_directory
        else:
            self.atlas_dir = [atlas_directory]

        self.transformation = transformation
        self.output_directory = output_directory

    def _load_bundles_dictionary(self):
        """
        Using all bundles in the configuration file and the input models
        folders, generate all model filepaths.

        A bundle is kept as soon as at least one atlas directory contains
        its model file; only the files that exist are counted for it.

        Returns
        -------
        model_bundles_dict : dict
            Maps each existing model filepath to the configuration value of
            its bundle.
        bundle_names : list
            Configuration keys for which at least one model file exists.
        bundle_counts : list
            Number of model files found for each entry of bundle_names
            (same order).

        Raises
        ------
        IOError
            If no model file was found at all.
        """
        bundle_names = []
        bundles_filepath = []

        # Generate all input files based on the config file and model
        # directory. The lists are built from scratch (never mutated while
        # being iterated) so that bundle_names[i] always matches
        # bundles_filepath[i].
        for key in self.config.keys():
            candidates = [os.path.join(atlas, key) for atlas in self.atlas_dir]
            existing = [path for path in candidates if os.path.isfile(path)]
            if not existing:
                logger.warning(f'No model file found for {key}, skipping it.')
                continue
            bundle_names.append(key)
            bundles_filepath.append(existing)

        logger.info(f'{len(self.atlas_dir)} sub-model directories were found. '
                    f'with {len(bundle_names)} model bundles total')

        model_bundles_dict = {}
        bundle_counts = []
        for basename, filepaths in zip(bundle_names, bundles_filepath):
            for filename in filepaths:
                model_bundles_dict[filename] = self.config[basename]
            bundle_counts.append(len(filepaths))

        if sum(bundle_counts) == 0:
            raise IOError("No model bundles found, check input directory.")

        return model_bundles_dict, bundle_names, bundle_counts

    def _find_max_in_sparse_matrix(self, bundle_id, min_vote,
                                   bundles_wise_vote):
        """
        Return the indices of the streamlines whose vote count, in the row
        of a given bundle, reaches the min_vote threshold.

        Parameters
        ----------
        bundle_id : int
            Row index of the bundle in bundles_wise_vote.
        min_vote : int
            Minimum value for considering (voting). A value of 0 disables
            the bundle: an empty selection is returned.
        bundles_wise_vote : array-like
            Bundles-wise vote matrix (nbr_bundles x nbr_streamlines).

        Returns
        -------
        streamlines_ids : numpy.ndarray
            1D array (uint32) of the indices of the streamlines with a vote
            count greater than or equal to min_vote.
        """
        if min_vote == 0:
            streamlines_ids = np.asarray([], dtype=np.uint32)
            return streamlines_ids

        streamlines_ids = np.argwhere(bundles_wise_vote[bundle_id] >= min_vote)
        streamlines_ids = np.asarray(streamlines_ids, dtype=np.uint32)

        return streamlines_ids.reshape((-1,))

    def _save_recognized_bundles(self, input_tractograms_path, reference,
                                 bundle_names, bundles_wise_vote,
                                 bundles_wise_score, minimum_vote, extension):
        """
        Will save multiple TRK/TCK file and results.json (contains indices)
        To preserve DPS and DPP but preserve memory, will reload each input
        tractogram.

        Parameters
        ----------
        input_tractograms_path : list
            List of input tractograms path used to reload the streamlines.
        reference : str
            Reference file for the header.
        bundle_names : list
            Bundle names as defined in the configuration file.
            Will save the bundle using that filename and the extension.
        bundles_wise_vote : array-like
            Array of vote of shape (nbr_bundles x nbr_streamlines).
        bundles_wise_score : array-like
            Array of score of shape (nbr_bundles x nbr_streamlines).
        minimum_vote : np.ndarray
            Minimum number of votes, one value per bundle, for a streamline
            to be kept.
        extension : str
            Extension for file saving (TRK or TCK).

        Raises
        ------
        ValueError
            If the pieces of a bundle coming from different input
            tractograms cannot be merged (mismatching metadata).
        """
        results_dict = {}
        results_sft = {}

        # The selected indices are expressed in the concatenated
        # (whole-input) index space and do not depend on which tractogram is
        # being reloaded: resolve them, and the json content, once per
        # bundle. Filtering them per tractogram must never overwrite them.
        selected_ids = {}
        for bundle_id, bundle_name in enumerate(bundle_names):
            basename = os.path.splitext(bundle_name)[0]
            if basename in selected_ids:
                # Same output name already handled, first one wins.
                continue

            streamlines_id = self._find_max_in_sparse_matrix(
                bundle_id, minimum_vote[bundle_id], bundles_wise_vote)
            logger.info(f'{bundle_name} final recognition got '
                        f'{len(streamlines_id)} streamlines')

            if len(streamlines_id) > 0:
                scores = bundles_wise_score[bundle_id,
                                            streamlines_id].flatten()
            else:
                scores = np.array([], dtype=np.float16)

            selected_ids[basename] = streamlines_id
            # Populate the results dictionary (will be saved as json)
            results_dict[basename] = {'indices': streamlines_id.tolist(),
                                      'scores': scores.tolist()}

        # To avoid load all inputs tractograms to slice them, reload each
        # one after the other and keep track of the indices/streamlines only
        tot_sft_len = 0
        for in_tractogram in input_tractograms_path:
            sft = load_tractogram(in_tractogram, reference)
            curr_sft_len = len(sft)

            for basename, streamlines_id in selected_ids.items():
                # NOTE(review): a bundle without any selected streamline
                # never gets an entry in results_sft, so save_empty has no
                # effect for it. Kept as in the original implementation.
                if curr_sft_len == 0 or len(streamlines_id) == 0:
                    continue

                # Keep the indices falling in this tractogram and convert
                # them back to local indices (for this sft)
                in_chunk = np.logical_and(
                    streamlines_id >= tot_sft_len,
                    streamlines_id < tot_sft_len + curr_sft_len)
                new_sft = sft[streamlines_id[in_chunk] - tot_sft_len]

                # If the user asked to ignore metadata, remove it (simpler)
                if self.ignore_metadata:
                    new_sft.data_per_point = {}
                    new_sft.data_per_streamline = {}

                if basename in results_sft:
                    try:
                        results_sft[basename] += new_sft
                    except ValueError:
                        # This error message will be raised if the
                        # DPP and DPS are not the same across (+= operator)
                        raise ValueError(f"Could not merge SFT for {basename}, "
                                         f"try --ignore_metadata.")
                else:
                    results_sft[basename] = new_sft

            tot_sft_len += curr_sft_len

        # Once everything is done, save all bundles, at the moment only
        # the bundles are held in memory (typically 1/10th of the tractogram)
        for basename in results_sft:
            sft = results_sft[basename]
            if len(sft) > 0 or self.save_empty:
                sft.remove_invalid_streamlines()
                save_tractogram(sft, os.path.join(self.output_directory,
                                                  basename + extension))

        out_logfile = os.path.join(self.output_directory, 'results.json')
        with open(out_logfile, 'w') as outfile:
            json.dump(results_dict, outfile)

    def __call__(self, input_tractograms_path, nbr_processes=1, seed=None,
                 reference=None):
        """
        Entry point function that generate the 'stack' of commands for
        dispatching and launch them using multiprocessing.

        Parameters
        ----------
        input_tractograms_path : list
            Filepaths of the whole brain tractogram(s) to segment. They are
            concatenated, in order, into a single set of streamlines.
        nbr_processes : int
            Number of processes used for the parallel bundle recognition.
        seed : int
            Seed for the RandomState.
        reference : str
            Reference file used when reloading the tractograms for saving.
            Defaults to the first input tractogram.
        """
        # Load the subject tractogram
        load_timer = time()
        if reference is None:
            reference = input_tractograms_path[0]

        wb_streamlines = ArraySequence()
        for in_tractogram in input_tractograms_path:
            wb_streamlines.extend(
                nib.streamlines.load(in_tractogram).streamlines)
        len_wb_streamlines = len(wb_streamlines)

        logger.debug(f'Tractogram {input_tractograms_path} with '
                     f'{len_wb_streamlines} streamlines '
                     f'is loaded in {get_duration(load_timer)} seconds')

        total_timer = time()
        # Each type of bundle is processed separately
        model_bundles_dict, bundle_names, bundle_count = \
            self._load_bundles_dictionary()

        thresholds = [45, 35, 25, TCT]
        rng = np.random.RandomState(seed)
        cluster_timer = time()
        with warnings.catch_warnings(record=True) as _:
            cluster_map = qbx_and_merge(wb_streamlines, thresholds,
                                        nb_pts=12, rng=rng, verbose=False)

        clusters_indices = []
        for cluster in cluster_map.clusters:
            clusters_indices.append(cluster.indices)
        centroids = ArraySequence(cluster_map.centroids)
        clusters_indices = ArraySequence(clusters_indices)
        clusters_indices._data = clusters_indices._data.astype(np.uint32)

        logger.info(f'QBx with seed {seed} at {TCT}mm took '
                    f'{get_duration(cluster_timer)}sec. gave '
                    f'{len(cluster_map.centroids)} centroids')

        tmp_dir, tmp_memmap_filenames = streamlines_to_memmap(wb_streamlines,
                                                              'float16')

        # From here on the temporary directory exists: make sure it is
        # removed even if the recognition or the saving fails.
        try:
            # Memory cleanup (before multiprocessing)
            cluster_map.refdata = None
            for ref in gc.get_referrers(cluster_map) + \
                    gc.get_referrers(wb_streamlines):
                if isinstance(ref, ArraySequence):
                    del ref._data
                    del ref
            del wb_streamlines, cluster_map
            gc.collect()
            # End of memory cleanup

            bsg = BundleSeg(tmp_memmap_filenames,
                            self.transformation,
                            clusters_indices, centroids,
                            rng=rng)

            # Update all BundleSeg initialisation into a single dictionnary
            with Manager() as manager:
                model_bundles_dict = manager.dict(model_bundles_dict)

                pool = multiprocessing.Pool(nbr_processes)
                all_recognized_dict = pool.imap_unordered(
                    single_recognize, zip(repeat(bsg),
                                          model_bundles_dict.keys(),
                                          model_bundles_dict.values(),
                                          repeat(bundle_names)))
                pool.close()
                pool.join()

            logger.info(f'BundleSeg took {get_duration(total_timer)} sec. for '
                        f'{len(bundle_names)} bundles from '
                        f'{len(self.atlas_dir)} atlas')

            bundles_wise_vote = np.zeros((len(bundle_names),
                                          len_wb_streamlines),
                                         dtype=np.uint8)
            bundles_wise_score = np.zeros((len(bundle_names),
                                           len_wb_streamlines),
                                          dtype=np.float16)

            # One vote per model that recognized the streamline; scores are
            # summed here and averaged over the votes right after.
            for bundle_id, recognized_indices, recognized_scores \
                    in all_recognized_dict:
                if recognized_indices is not None:
                    if len(recognized_indices) == 0:
                        continue
                    bundles_wise_vote[bundle_id, recognized_indices.T] += 1
                    bundles_wise_score[bundle_id,
                                       recognized_indices.T] += recognized_scores

            voted = bundles_wise_vote != 0
            bundles_wise_score[voted] /= bundles_wise_vote[voted]

            # Once everything was run, save the results using a voting system
            minimum_vote = np.array(bundle_count) * self.minimal_vote_ratio
            # A non-zero requirement below one vote is rounded up to one,
            # otherwise the integer cast would turn it into 0.
            minimum_vote[np.logical_and(minimum_vote > 0,
                                        minimum_vote < 1)] = 1
            minimum_vote = minimum_vote.astype(np.uint8)

            _, ext = os.path.splitext(input_tractograms_path[0])

            save_timer = time()
            self._save_recognized_bundles(input_tractograms_path, reference,
                                          bundle_names,
                                          bundles_wise_vote,
                                          bundles_wise_score,
                                          minimum_vote, ext)

            saved_bundles = [f for f in os.listdir(self.output_directory)
                             if os.path.splitext(f)[1] in ['.trk', '.tck']]
            logger.info(f'Saving of {len(saved_bundles)} files in '
                        f'{self.output_directory} took '
                        f'{get_duration(save_timer)} sec.')
        finally:
            tmp_dir.cleanup()
def single_recognize_parallel(args):
    """
    Wrapper function to multiprocess recobundles execution.

    Unpacks its arguments and forwards them to single_recognize, which
    accepts one packed sequence (not separate positional arguments).

    Parameters
    ----------
    args : sequence
        Packed arguments, in order:
            args[0] : object forwarded as the BundleSeg instance.
            args[1] : filepath of the model bundle.
            args[2] : pair (bundle_pruning_thr, model_bundle). The second
                element is not forwarded: single_recognize loads the model
                from its filepath.
            args[3] : list of bundle names.

    Returns
    -------
    The value returned by single_recognize.
    """
    rbx = args[0]
    model_filepath = args[1]
    bundle_pruning_thr, _ = args[2]
    bundle_names = args[3]
    return single_recognize((rbx, model_filepath, bundle_pruning_thr,
                             bundle_names))
def single_recognize(args):
    """
    Recobundle for a single bundle.

    Parameters
    ----------
    args : sequence
        Packed arguments (a single item, as produced for the
        multiprocessing pool), in order:
            args[0] : bsg, initialized BundleSeg object.
            args[1] : model_filepath (str), path to the model bundle file.
            args[2] : bundle_pruning_thr, threshold for pruning the model
                bundle.
            args[3] : bundle_names (list), bundle names for models (to get
                bundle_id).

    Returns
    -------
    transf_neighbor : tuple
        bundle_id : (int)
            Unique value to each bundle to identify them.
        recognized_indices : (numpy.ndarray)
            Streamlines indices from the original tractogram.
        recognized_scores : (numpy.ndarray)
            Scores of the recognized streamlines.
    """
    bsg, model_filepath = args[0], args[1]
    bundle_pruning_thr, bundle_names = args[2], args[3]

    # The model is loaded here and moved with the transformation held by
    # the BundleSeg object.
    model_bundle = transform_streamlines(
        nib.streamlines.load(model_filepath).streamlines,
        bsg.transformation)

    # The filename identifies the bundle; without its extension it is the
    # tag used for logging.
    model_filename = os.path.basename(model_filepath)
    shorter_tag = os.path.splitext(model_filename)[0]

    recognize_timer = time()
    # The SLR transform type is now hardcoded (not useful with FSS from
    # Etienne)
    recognized_indices, recognized_scores = bsg.recognize(
        model_bundle,
        model_clust_thr=MCT,
        pruning_thr=bundle_pruning_thr,
        slr_transform_type='similarity',
        identifier=shorter_tag)
    del model_bundle._data, model_bundle

    logger.info(f'Model {shorter_tag} recognized {len(recognized_indices)} '
                'streamlines')
    logger.debug(f'Model {model_filepath} with parameters tct={TCT}, mct={MCT}, '
                 f'bpt={bundle_pruning_thr} '
                 f'took {get_duration(recognize_timer)} sec.')

    return (bundle_names.index(model_filename), recognized_indices,
            recognized_scores)