Source code for scilpy.segment.bundleseg

# -*- coding: utf-8 -*-

import gc
import logging
from time import time
import warnings

from dipy.align.streamlinear import (BundleMinDistanceMetric,
                                     StreamlineLinearRegistration)
from dipy.segment.fss import FastStreamlineSearch, nearest_from_matrix_col
from dipy.segment.clustering import qbx_and_merge
from dipy.tracking.distances import bundles_distances_mdf
from dipy.tracking.streamline import (select_random_set_of_streamlines,
                                      transform_streamlines)
from nibabel.streamlines.array_sequence import ArraySequence
import numpy as np
from scipy.sparse import vstack

from scilpy.io.streamlines import reconstruct_streamlines_from_memmap
from scilpy.utils import get_duration

logger = logging.getLogger('BundleSeg')


[docs] class BundleSeg(object): """ This class is a 'remastered' version of the Dipy Recobundles class. Made to be used in sync with the voting_scheme. Adapted in many way for HPC processing and increase control over parameters and logger. However, in essence they do the same thing. https://github.com/nipy/dipy/blob/master/dipy/segment/bundles.py References ---------- [1] Garyfallidis et al. Recognition of white matter bundles using local and global streamline-based registration and clustering, Neuroimage, 2017. """ def __init__(self, memmap_filenames, transformation, clusters_indices, wb_centroids, rng=None): """ Parameters ---------- memmap_filenames : tuple tuple of filenames for the data, offsets and lengths. transformation : numpy.ndarray Transformation matrix to apply to the model streamlines. clusters_indices: ArraySequence ArraySequence containing the indices of the streamlines per cluster. wb_centroids : list of numpy.ndarray List contaning the average streamline per cluster as obtained from qbx. rng : RandomState If None then RandomState is initialized internally. """ self.memmap_filenames = memmap_filenames self.transformation = transformation self.wb_clusters_indices = clusters_indices self.wb_centroids = wb_centroids self.rng = rng # For memory management self.wb_centroids._data = self.wb_centroids._data.astype(np.float16) # For declaration outside of init self.neighb_centroids = None self.neighb_indices = None self.model_streamlines = None self.model_centroids = None
[docs] def recognize(self, model_bundle, model_clust_thr=8, pruning_thr=8, slr_transform_type='similarity', identifier=None): """ Parameters ---------- model_bundle : list or ArraySequence Model bundle as loaded by the nibabel API. model_clust_thr : obj Distance threshold (mm) for model clustering (QBx) pruning_thr : int Distance threshold (mm) for the final pruning. slr_transform_type : str Define the transformation for the local SLR. [translation, rigid, similarity, scaling] identifier : str Identify the current bundle being recognized for the logger. Returns ------- clusters : list Streamlines that were recognized by Recobundles and these parameters. """ self.model_streamlines = model_bundle self._cluster_model_bundle(model_clust_thr, identifier=identifier) if self._reduce_search_space(neighbors_reduction_thr=16) == 0: if identifier: logger.error(f'{identifier} did not find any neighbors in ' 'the tractogram') return [], [] self._register_model_to_neighb(slr_transform_type=slr_transform_type) if self._reduce_search_space(neighbors_reduction_thr=14) == 0: if identifier: logger.error(f'{identifier} did not find any neighbors in ' 'the tractogram') return [], [] del self.neighb_centroids, self.model_centroids gc.collect() pruned_indices, pruned_scores = self.prune_far_from_model( pruning_thr=pruning_thr) return pruned_indices, pruned_scores
def _cluster_model_bundle(self, model_clust_thr, identifier): """ Wrapper function to compute QBx for the model and logging information. Parameters ---------- model_clust_thr, float, distance in mm for clustering. identifier, str, name of the bundle for logger. """ thresholds = [30, 20, 15, model_clust_thr] model_cluster_map = qbx_and_merge(self.model_streamlines, thresholds, nb_pts=12, rng=self.rng, verbose=False) self.model_centroids = ArraySequence(model_cluster_map.centroids) self.model_centroids._data = self.model_centroids._data.astype( np.float16) len_centroids = len(self.model_centroids) if len_centroids > 1000: logger.warning(f'Model {identifier} simplified at threshold ' f'{model_clust_thr}mm with {len_centroids} centroids') def _reduce_search_space(self, neighbors_reduction_thr=18): """ Wrapper function to discard clusters from the tractogram too far from the model and logging informations. :param neighbors_reduction_thr, float, distance in mm for thresholding to discard distant streamlines. """ centroid_matrix = bundles_distances_mdf(self.model_centroids, self.wb_centroids).astype(np.float16) centroid_matrix[centroid_matrix > neighbors_reduction_thr] = np.inf mins = np.min(centroid_matrix, axis=0) close_clusters_indices = np.array(np.where(mins != np.inf)[0], dtype=np.uint32) self.neighb_indices = [] for i in close_clusters_indices: self.neighb_indices.extend(self.wb_clusters_indices[i]) self.neighb_indices = np.array(self.neighb_indices, dtype=np.uint32) self.neighb_centroids = ArraySequence([self.wb_centroids[i] for i in close_clusters_indices]) self.neighb_centroids._data = self.neighb_centroids._data.astype( np.float16) return self.neighb_indices.size def _register_model_to_neighb(self, select_model=250, select_target=250, slr_transform_type='similarity'): """ Parameters ---------- select_model : int Maximum number of clusters to select from the model. select_target : int Maximum number of clusters to select from the neighborhood. slr_transform_type : str Define the transformation for the local SLR. [translation, rigid, similarity, scaling]. Returns ------- transf_neighbor : list The neighborhood clusters transformed into model space. """ possible_slr_transform_type = {'translation': 0, 'rigid': 1, 'similarity': 2, 'scaling': 3} static = select_random_set_of_streamlines(self.model_centroids, select_model, rng=self.rng) moving = select_random_set_of_streamlines(self.neighb_centroids, select_target, rng=self.rng) # Tuple 0,1,2 are the min & max bound in x,y,z for translation # Tuple 3,4,5 are the min & max bound in x,y,z for rotation # Tuple 6,7,8 are the min & max bound in x,y,z for scaling # For uniform scaling (similarity), tuple #6 is enough bounds_dof = [(-20, 20), (-20, 20), (-20, 20), (-10, 10), (-10, 10), (-10, 10), (0.8, 1.2), (0.8, 1.2), (0.8, 1.2)] metric = BundleMinDistanceMetric(num_threads=1) slr_transform_type_id = possible_slr_transform_type[slr_transform_type] if slr_transform_type_id >= 0: init_transfo_dof = np.zeros(3) slr = StreamlineLinearRegistration(metric=metric, method="L-BFGS-B", x0=init_transfo_dof, bounds=bounds_dof[:3], num_threads=1) slm = slr.optimize(static, moving) if slr_transform_type_id >= 1: init_transfo_dof = np.zeros(6) init_transfo_dof[:3] = slm.xopt slr = StreamlineLinearRegistration(metric=metric, x0=init_transfo_dof, bounds=bounds_dof[:6], num_threads=1) slm = slr.optimize(static, moving) if slr_transform_type_id >= 2: if slr_transform_type_id == 2: init_transfo_dof = np.zeros(7) init_transfo_dof[:6] = slm.xopt init_transfo_dof[6] = 1. slr = StreamlineLinearRegistration(metric=metric, x0=init_transfo_dof, bounds=bounds_dof[:7], num_threads=1) slm = slr.optimize(static, moving) else: init_transfo_dof = np.zeros(9) init_transfo_dof[:6] = slm.xopt[:6] init_transfo_dof[6:] = np.array((slm.xopt[6],) * 3) slr = StreamlineLinearRegistration(metric=metric, x0=init_transfo_dof, bounds=bounds_dof[:9], num_threads=1) slm = slr.optimize(static, moving) # Apply the transformation to the model streamlines self.model_streamlines = transform_streamlines(self.model_streamlines, np.linalg.inv(slm.matrix)) self.model_streamlines._data = self.model_streamlines._data.astype( np.float16)
[docs] def prune_far_from_model(self, pruning_thr=10): """ Wrapper function to prune clusters from the tractogram too far from the model. Parameters ---------- pruning_thr: float, distance in thresholds = [32, 16, 24, neighbors_cluster_thr] """ # Neighbors can be refined since the search space is smaller t0 = time() neighb_streamlines = reconstruct_streamlines_from_memmap( self.memmap_filenames, self.neighb_indices, strs_dtype=np.float16) # Typically the neighbors is bigger than the model, so we flip the # FSS to be more memory efficient with warnings.catch_warnings(record=True) as _: fss = FastStreamlineSearch(self.model_streamlines, pruning_thr, resampling=12) CHUNK_SIZE = 1000 dist_mat_list = [] for chuck_id in range(0, len(neighb_streamlines), CHUNK_SIZE): tmp_dist_mat = fss.radius_search( neighb_streamlines[chuck_id:chuck_id+CHUNK_SIZE], pruning_thr) tmp_dist_mat.data = tmp_dist_mat.data.astype(np.float32) dist_mat_list.append(tmp_dist_mat.copy()) del tmp_dist_mat gc.collect() dist_mat = vstack(dist_mat_list) for tmp_dist_mat in dist_mat_list: del tmp_dist_mat gc.collect() dist_mat.data = dist_mat.data.astype(np.float32) dist_mat = dist_mat.T logger.debug(f'Fast search took of dimensions {dist_mat.shape}: ' f'{get_duration(t0)} sec.') if dist_mat.size == 0 or dist_mat.shape[1] <= 1: return [], [] # Identify the closest neighbors (remove the zeros, not matched) np.absolute(dist_mat.data, out=dist_mat.data) non_zero_ids, _, scores = nearest_from_matrix_col(dist_mat) del dist_mat, neighb_streamlines del fss.ref_slines, fss.bin_dict, fss gc.collect() # If no streamlines were recognized, return an empty array if len(non_zero_ids) != 0: # Since the neighbors were clustered, a mapping of indices is neccesary final_pruned_indices = self.neighb_indices[non_zero_ids].astype( np.uint32) final_pruned_scores = scores.astype(np.float16) return final_pruned_indices, final_pruned_scores else: return [], []
[docs] def cleanup(self): for indices in [self.neighb_indices, self.wb_clusters_indices]: if indices is not None: del indices del self.model_streamlines._data, self.model_streamlines