Source code for scilpy.segment.recobundlesx

# -*- coding: utf-8 -*-

import logging
from time import time
import warnings

from dipy.align.streamlinear import (BundleMinDistanceMetric,
                                     StreamlineLinearRegistration)
from dipy.segment.fss import FastStreamlineSearch
from dipy.segment.clustering import qbx_and_merge
from dipy.tracking.distances import bundles_distances_mdf
from dipy.tracking.streamline import (select_random_set_of_streamlines,
                                      transform_streamlines)
from nibabel.streamlines.array_sequence import ArraySequence
import numpy as np

from scilpy.io.streamlines import reconstruct_streamlines_from_memmap


class RecobundlesX(object):
    """
    This class is a 'remastered' version of the Dipy Recobundles class.
    Made to be used in sync with the voting_scheme.
    Adapted in many ways for HPC processing and increased control over
    parameters and logging.
    However, in essence they do the same thing:
    https://github.com/nipy/dipy/blob/master/dipy/segment/bundles.py

    References
    ----------
    .. [Garyfallidis17] Garyfallidis et al. Recognition of white matter
        bundles using local and global streamline-based registration
        and clustering, Neuroimage, 2017.
    """

    def __init__(self, memmap_filenames, clusters_indices, wb_centroids,
                 rng=None):
        """
        Parameters
        ----------
        memmap_filenames : tuple
            Tuple of filenames for the data, offsets and lengths.
        clusters_indices : ArraySequence
            ArraySequence containing the indices of the streamlines per
            cluster.
        wb_centroids : list of numpy.ndarray
            List containing the average streamline per cluster as obtained
            from qbx.
        rng : RandomState
            If None then RandomState is initialized internally.
        """
        self.memmap_filenames = memmap_filenames
        self.wb_clusters_indices = clusters_indices
        self.centroids = wb_centroids
        self.rng = rng

        # Attributes set later by recognize() and its helpers; declared here
        # so the full set of instance attributes is visible from __init__.
        # NOTE: the original declared 'models_streamlines' (typo) while the
        # rest of the class uses 'model_streamlines'; fixed here so cleanup()
        # and recognize() operate on a pre-declared attribute.
        self.neighb_centroids = None
        self.neighb_indices = None
        self.model_streamlines = None
        self.model_centroids = None
        self.final_pruned_indices = None
[docs] def recognize(self, model_bundle, model_clust_thr=8, pruning_thr=8, slr_transform_type='similarity', identifier=None): """ Parameters ---------- model_bundle : list or ArraySequence Model bundle as loaded by the nibabel API. model_clust_thr : obj Distance threshold (mm) for model clustering (QBx) pruning_thr : int Distance threshold (mm) for the final pruning. slr_transform_type : str Define the transformation for the local SLR. [translation, rigid, similarity, scaling] identifier : str Identify the current bundle being recognized for the logging. Returns ------- clusters : list Streamlines that were recognized by Recobundles and these parameters. """ self.model_streamlines = model_bundle self._cluster_model_bundle(model_clust_thr, identifier=identifier) if not self._reduce_search_space(neighbors_reduction_thr=16): if identifier: logging.error('{0} did not find any neighbors in ' 'the tractogram'.format(identifier)) return np.array([], dtype=np.uint32) self._register_model_to_neighb(slr_transform_type=slr_transform_type) # self._reduce_search_space(neighbors_reduction_thr=12) if not self._reduce_search_space(neighbors_reduction_thr=12): if identifier: logging.error('{0} did not find any neighbors in ' 'the tractogram'.format(identifier)) return np.array([], dtype=np.uint32) self.prune_far_from_model(pruning_thr=pruning_thr) self.cleanup() return self.get_final_pruned_indices()
    def _cluster_model_bundle(self, model_clust_thr, identifier):
        """
        Wrapper function to compute QBx for the model and logging information.

        Parameters
        ----------
        model_clust_thr : float
            Distance in mm for clustering (final, finest QBx threshold).
        identifier : str
            Name of the bundle for logging.
        """
        # Multi-scale thresholds: only the last (finest) one is user-provided.
        thresholds = [30, 20, 15, model_clust_thr]
        model_cluster_map = qbx_and_merge(self.model_streamlines,
                                          thresholds,
                                          nb_pts=12, rng=self.rng,
                                          verbose=False)
        self.model_centroids = ArraySequence(model_cluster_map.centroids)
        len_centroids = len(self.model_centroids)
        # A very large centroid count suggests the clustering threshold was
        # too small for this model; warn so the user can adjust it.
        if len_centroids > 1000:
            logging.warning('Model {0} simplified at threshold '
                            '{1}mm with {2} centroids'.format(
                                identifier,
                                str(model_clust_thr),
                                str(len_centroids)))

    def _reduce_search_space(self, neighbors_reduction_thr=18):
        """
        Wrapper function to discard clusters from the tractogram too far from
        the model and logging informations.

        Parameters
        ----------
        neighbors_reduction_thr : float
            Distance in mm for thresholding to discard distant streamlines.

        Returns
        -------
        int
            Number of streamline indices kept in the neighborhood
            (0, i.e. falsy, when no cluster is close enough).
        """
        # float16 keeps the (model x tractogram) distance matrix small.
        centroid_matrix = bundles_distances_mdf(
            self.model_centroids, self.centroids).astype(np.float16)
        # Mark distances above the threshold as unreachable; a tractogram
        # cluster is kept only if at least one model centroid stays finite.
        centroid_matrix[centroid_matrix > neighbors_reduction_thr] = np.inf

        mins = np.min(centroid_matrix, axis=0)
        close_clusters_indices = np.array(np.where(mins != np.inf)[0],
                                          dtype=np.uint32)

        # Flatten the per-cluster streamline indices of the kept clusters.
        self.neighb_indices = []
        for i in close_clusters_indices:
            self.neighb_indices.extend(self.wb_clusters_indices[i])
        self.neighb_indices = np.array(self.neighb_indices, dtype=np.uint32)

        self.neighb_centroids = [self.centroids[i]
                                 for i in close_clusters_indices]

        return self.neighb_indices.size

    def _register_model_to_neighb(self, select_model=1000, select_target=1000,
                                  slr_transform_type='similarity'):
        """
        Local SLR: registers the model centroids onto the neighborhood
        centroids, refining the transformation stage by stage (translation,
        then rigid, then similarity/scaling), each stage seeded by the
        previous stage's solution.

        Parameters
        ----------
        select_model : int
            Maximum number of clusters to select from the model.
        select_target : int
            Maximum number of clusters to select from the neighborhood.
        slr_transform_type : str
            Define the transformation for the local SLR.
            [translation, rigid, similarity, scaling].

        Returns
        -------
        transf_neighbor : list
            The neighborhood clusters transformed into model space.
        """
        possible_slr_transform_type = {'translation': 0, 'rigid': 1,
                                       'similarity': 2, 'scaling': 3}
        static = select_random_set_of_streamlines(self.model_centroids,
                                                  select_model, self.rng)
        moving = select_random_set_of_streamlines(self.neighb_centroids,
                                                  select_target, self.rng)

        # Tuple 0,1,2 are the min & max bound in x,y,z for translation
        # Tuple 3,4,5 are the min & max bound in x,y,z for rotation
        # Tuple 6,7,8 are the min & max bound in x,y,z for scaling
        # For uniform scaling (similarity), tuple #6 is enough
        bounds_dof = [(-20, 20), (-20, 20), (-20, 20),
                      (-10, 10), (-10, 10), (-10, 10),
                      (0.8, 1.2), (0.8, 1.2), (0.8, 1.2)]
        metric = BundleMinDistanceMetric(num_threads=1)

        slr_transform_type_id = possible_slr_transform_type[slr_transform_type]
        # Stage 1: translation only (3 dof). Always runs; its solution
        # initializes the next stage.
        if slr_transform_type_id >= 0:
            init_transfo_dof = np.zeros(3)
            slr = StreamlineLinearRegistration(metric=metric,
                                               method="L-BFGS-B",
                                               x0=init_transfo_dof,
                                               bounds=bounds_dof[:3],
                                               num_threads=1)
            slm = slr.optimize(static, moving)

        # Stage 2: rigid (6 dof), seeded with the translation solution.
        if slr_transform_type_id >= 1:
            init_transfo_dof = np.zeros(6)
            init_transfo_dof[:3] = slm.xopt
            slr = StreamlineLinearRegistration(metric=metric,
                                               x0=init_transfo_dof,
                                               bounds=bounds_dof[:6],
                                               num_threads=1)
            slm = slr.optimize(static, moving)

        # Stage 3: similarity (7 dof, uniform scale) or scaling (9 dof,
        # per-axis scale initialized from the uniform-scale solution).
        if slr_transform_type_id >= 2:
            if slr_transform_type_id == 2:
                init_transfo_dof = np.zeros(7)
                init_transfo_dof[:6] = slm.xopt
                init_transfo_dof[6] = 1.

                slr = StreamlineLinearRegistration(metric=metric,
                                                   x0=init_transfo_dof,
                                                   bounds=bounds_dof[:7],
                                                   num_threads=1)
                slm = slr.optimize(static, moving)
            else:
                init_transfo_dof = np.zeros(9)
                init_transfo_dof[:6] = slm.xopt[:6]
                init_transfo_dof[6:] = np.array((slm.xopt[6],) * 3)

                slr = StreamlineLinearRegistration(metric=metric,
                                                   x0=init_transfo_dof,
                                                   bounds=bounds_dof[:9],
                                                   num_threads=1)
                slm = slr.optimize(static, moving)

        # Bring the MODEL into the neighborhood's space (inverse transform),
        # rather than moving the whole neighborhood.
        self.model_centroids = transform_streamlines(self.model_centroids,
                                                     np.linalg.inv(slm.matrix))
    def prune_far_from_model(self, pruning_thr=10):
        """
        Wrapper function to prune clusters from the tractogram too far from
        the model.

        Parameters
        ----------
        pruning_thr : float
            Distance in mm used as the radius for the final pruning search.

        Returns
        -------
        numpy.ndarray
            Indices (uint32, into the original tractogram) of the
            streamlines kept by the pruning; also stored in
            self.final_pruned_indices.
        """
        # Neighbors can be refined since the search space is smaller
        t0 = time()
        neighb_streamlines = reconstruct_streamlines_from_memmap(
            self.memmap_filenames, self.neighb_indices,
            strs_dtype=np.float16)

        # record=True silences warnings emitted by the fast search.
        with warnings.catch_warnings(record=True) as _:
            fss = FastStreamlineSearch(neighb_streamlines, pruning_thr,
                                       resampling=12)
            # Sparse matrix of distances between model streamlines and the
            # neighborhood streamlines within the pruning radius.
            dist_mat = fss.radius_search(self.model_streamlines, pruning_thr)
        logging.debug("Fast search took of dimensions {0}: {1} sec.".format(
            dist_mat.shape, np.round(time() - t0, 2)))

        sparse_dist_mat = np.abs(dist_mat.tocsr())
        sparse_dist_vec = np.squeeze(np.max(sparse_dist_mat,
                                            axis=0).toarray())
        # Keep any neighbor streamline with at least one non-zero entry,
        # i.e. within pruning_thr of at least one model streamline.
        pruned_indices = np.where(sparse_dist_vec > 1e-6)[0]

        # Since the neighbors were clustered, a mapping of indices is
        # necessary to go back to the original tractogram numbering.
        self.final_pruned_indices = self.neighb_indices[pruned_indices].astype(
            np.uint32)
        return self.final_pruned_indices
[docs] def get_final_pruned_indices(self): """ Public getter for the final indices recognize by the algorithm. """ return self.final_pruned_indices
[docs] def cleanup(self): del self.neighb_centroids del self.neighb_indices del self.model_streamlines del self.model_centroids