Source code for scilpy.tracking.tracker

# -*- coding: utf-8 -*-
from contextlib import nullcontext
import itertools
import logging
import multiprocessing
import os
import sys
from tempfile import TemporaryDirectory
from time import perf_counter
import traceback
from typing import Union

import numpy as np
from tqdm import tqdm
from dipy.data import get_sphere
from dipy.core.sphere import HemiSphere
from dipy.io.stateful_tractogram import Space
from dipy.reconst.shm import sh_to_sf_matrix
from dipy.tracking.streamlinespeed import compress_streamlines

from scilpy.image.volume_space_management import DataVolume
from scilpy.tracking.propagator import AbstractPropagator, PropagationStatus
from scilpy.reconst.utils import find_order_from_nb_coeff
from scilpy.tracking.seed import SeedGenerator
from scilpy.gpuparallel.opencl_utils import CLKernel, CLManager, have_opencl

# For multiprocessing:
# Dictionary that will contain all parameters necessary for a sub-process's
# initialization.
multiprocess_init_args = {}


class Tracker(object):
    def __init__(self, propagator: AbstractPropagator, mask: DataVolume,
                 seed_generator: SeedGenerator, nbr_seeds, min_nbr_pts,
                 max_nbr_pts, max_invalid_dirs, compression_th=0.1,
                 nbr_processes=1, save_seeds=False,
                 mmap_mode: Union[str, None] = None, rng_seed=1234,
                 track_forward_only=False, skip=0, verbose=False,
                 append_last_point=True):
        """
        Parameters
        ----------
        propagator : AbstractPropagator
            Tracking object. This tracker will use the space and origin
            defined in the propagator.
        mask : DataVolume
            Tracking volume(s).
        seed_generator : SeedGenerator
            Seeding volume.
        nbr_seeds: int
            Number of seeds to create via the seed generator.
        min_nbr_pts: int
            Minimum number of points for streamlines.
        max_nbr_pts: int
            Maximum number of points for streamlines.
        max_invalid_dirs: int
            Number of consecutive invalid directions allowed during
            tracking.
        compression_th : float
            Maximal distance threshold for compression. If None, no
            compression is applied.
        nbr_processes: int
            Number of sub-processes to use.
        save_seeds: bool
            Whether to save the seeds associated with their respective
            streamlines.
        mmap_mode: str
            Memory-mapping mode. One of {None, 'r+', 'c'}. This value is
            passed to np.load() when loading the raw tracking data from a
            subprocess.
        rng_seed: int
            The random "seed" for the random generator.
        track_forward_only: bool
            If true, only the forward direction is computed.
        skip: int
            Skip the first N seeds created (and thus N rng numbers). Useful
            if you want to create new streamlines to add to a previously
            created tractogram with a fixed rng_seed. Ex: if tractogram_1
            was created with nbr_seeds=1,000,000, you can create
            tractogram_2 with skip=1,000,000.
        verbose: bool
            Display tracking progression.
        append_last_point: bool
            Whether to add the last point (once out of the tracking mask)
            to the streamline or not. Note that points obtained after an
            invalid direction (based on the propagator's definition of
            invalid; ex: when the angle is too sharp or sh_threshold is not
            reached) are never added.
        """
        self.propagator = propagator
        self.mask = mask
        self.seed_generator = seed_generator
        self.nbr_seeds = nbr_seeds
        self.min_nbr_pts = min_nbr_pts
        self.max_nbr_pts = max_nbr_pts
        self.max_invalid_dirs = max_invalid_dirs
        self.compression_th = compression_th
        self.save_seeds = save_seeds
        self.mmap_mode = mmap_mode
        self.rng_seed = rng_seed
        self.track_forward_only = track_forward_only
        self.append_last_point = append_last_point
        self.skip = skip

        self.origin = self.propagator.origin
        self.space = self.propagator.space

        if self.space == Space.RASMM:
            raise NotImplementedError(
                "This version of the Tracker is not ready to work in RASMM "
                "space.")
        if (seed_generator.origin != propagator.origin or
                seed_generator.space != propagator.space):
            raise ValueError("Seed generator and propagator must work with "
                             "the same space and origin!")

        if self.min_nbr_pts <= 0:
            logging.warning("Minimum number of points cannot be 0. Changed "
                            "to 1.")
            self.min_nbr_pts = 1

        if self.mmap_mode not in [None, 'r+', 'c']:
            logging.warning("Memory-mapping mode cannot be {}. Changed to "
                            "None.".format(self.mmap_mode))
            self.mmap_mode = None

        self.nbr_processes = self._set_nbr_processes(nbr_processes)
        self.printing_frequency = 1000
        self.verbose = verbose
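
    # A minimal construction sketch (illustrative, not part of scilpy),
    # assuming `prop`, `mask_vol` and `seed_gen` are already-built
    # AbstractPropagator, DataVolume and SeedGenerator instances sharing
    # the same space and origin (hypothetical variable names):
    #
    #     tracker = Tracker(prop, mask_vol, seed_gen, nbr_seeds=1000,
    #                       min_nbr_pts=10, max_nbr_pts=500,
    #                       max_invalid_dirs=1, nbr_processes=4,
    #                       compression_th=0.1, save_seeds=True)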

    def track(self):
        """
        Generate a set of streamlines from the seed, mask and odf files.

        Returns
        -------
        streamlines: list of numpy.array
            List of streamlines, represented as an array of positions.
        seeds: list of numpy.array
            List of seeding positions, one 3-dimensional position per
            streamline.
        """
        if self.nbr_processes < 2:
            chunk_id = 0
            lines, seeds = self._get_streamlines(chunk_id)
        else:
            # Each process will use _get_streamlines_sub
            chunk_ids = np.arange(self.nbr_processes)

            with TemporaryDirectory() as tmpdir:
                # Lock for logging
                lock = multiprocessing.Manager().Lock()
                zipped_chunks = zip(chunk_ids, [lock] * self.nbr_processes)

                pool = self._prepare_multiprocessing_pool(tmpdir)
                lines_per_process, seeds_per_process = zip(*pool.map(
                    self._get_streamlines_sub, zipped_chunks))
                pool.close()

                # Make sure all worker processes have exited before leaving
                # the context manager.
                pool.join()
                lines = [line for line in
                         itertools.chain(*lines_per_process)]
                seeds = [seed for seed in
                         itertools.chain(*seeds_per_process)]

        return lines, seeds
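
    # A hedged sketch (illustrative, not part of scilpy) of wrapping
    # track()'s output into a DIPY tractogram, assuming `ref_img` is a
    # nibabel reference image matching the tracking volumes (hypothetical
    # variable name):
    #
    #     from dipy.io.stateful_tractogram import StatefulTractogram
    #     lines, seeds = tracker.track()
    #     sft = StatefulTractogram(lines, ref_img, tracker.space,
    #                              origin=tracker.origin)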

    def _set_nbr_processes(self, nbr_processes):
        """
        If the user did not define the number of processes, define it
        automatically (or set it to 1 -- no multiprocessing -- if we
        can't).
        """
        if nbr_processes <= 0:
            try:
                nbr_processes = multiprocessing.cpu_count()
            except NotImplementedError:
                logging.warning("Cannot determine number of cpus: "
                                "nbr_processes set to 1.")
                nbr_processes = 1

        if nbr_processes > self.nbr_seeds:
            nbr_processes = self.nbr_seeds
            logging.info("Setting number of processes to {} since there "
                         "were fewer seeds than processes."
                         .format(nbr_processes))
        return nbr_processes

    def _prepare_multiprocessing_pool(self, tmpdir):
        """
        Prepare the multiprocessing pool. Data must be carefully managed to
        avoid corruption with multiprocessing.

        Parameters
        ----------
        tmpdir: str
            Path where to temporarily save the data. This will allow
            clearing the data from memory. We will fetch it back later.

        Returns
        -------
        pool: The multiprocessing pool.
        """
        # Using a pool with a class method will serialize all parameters
        # in the class, which can be heavy, but it is what we would be
        # doing manually with a static class.
        # Be careful however: parameter changes inside the method will
        # not be kept.

        # Saving data. We will reload it in each process.
        data_file_name = os.path.join(tmpdir, 'data.npy')
        np.save(data_file_name, self.propagator.datavolume.data)

        # Clear data from memory
        self.propagator.reset_data(new_data=None)

        pool = multiprocessing.Pool(
            self.nbr_processes,
            initializer=self._send_multiprocess_args_to_global,
            initargs=({
                'data_file_name': data_file_name,
                'mmap_mode': self.mmap_mode
            },))

        return pool

    @staticmethod
    def _send_multiprocess_args_to_global(init_args):
        """
        Sends the subprocess' initialisation arguments to a global variable
        for easier access by the multiprocessing pool.
        """
        global multiprocess_init_args
        multiprocess_init_args = init_args
        return

    def _get_streamlines_sub(self, params):
        """
        multiprocessing.pool.map input function. Calls the main tracking
        method (_get_streamlines) with the correct initialization arguments
        (taken from the global variable multiprocess_init_args).

        Parameters
        ----------
        params: Tuple[chunk_id, Lock]
            chunk_id: int, this process's id.
            Lock: the multiprocessing lock.

        Returns
        -------
        streamlines: list
            List of list of 3D positions (streamlines).
        seeds: list
            The list of seeds for each streamline, if self.save_seeds.
        """
        chunk_id, lock = params
        global multiprocess_init_args
        self._reload_data_for_new_process(multiprocess_init_args)
        try:
            streamlines, seeds = self._get_streamlines(chunk_id, lock)
            return streamlines, seeds
        except Exception as e:
            logging.error("Operation _get_streamlines_sub() failed.")
            traceback.print_exception(*sys.exc_info(), file=sys.stderr)
            raise e

    def _reload_data_for_new_process(self, init_args):
        """
        Once the process is started, load back the data.

        Parameters
        ----------
        init_args: dict
            Args necessary to reset the data. In the current
            implementation, a dict with keys 'data_file_name' (the file
            where the data was saved) and 'mmap_mode'.
        """
        self.propagator.reset_data(np.load(
            init_args['data_file_name'], mmap_mode=init_args['mmap_mode']))

    def _get_streamlines(self, chunk_id, lock=None):
        """
        Tracks the n streamlines associated with the current process
        (identified by chunk_id), where n is the total number of seeds
        divided by the number of processes. If asked by the user, may
        compress the streamlines and save the seeds.

        Parameters
        ----------
        chunk_id: int
            This process's ID.
        lock: Lock
            The multiprocessing lock for verbose printing (optional with
            single processing).

        Returns
        -------
        streamlines: list
            The successful streamlines.
        seeds: list
            The list of seeds for each streamline, if self.save_seeds.
            Else, an empty list.
""" streamlines = [] seeds = [] # Initialize the random number generator to cover multiprocessing, # skip, which voxel to seed and the subvoxel random position chunk_size = int(self.nbr_seeds / self.nbr_processes) first_seed_of_chunk = chunk_id * chunk_size + self.skip random_generator, indices = self.seed_generator.init_generator( self.rng_seed, first_seed_of_chunk) if chunk_id == self.nbr_processes - 1: chunk_size += self.nbr_seeds % self.nbr_processes # Getting streamlines tqdm_text = "#" + "{}".format(chunk_id).zfill(3) if self.verbose: if lock is None: lock = nullcontext() with lock: # Note. Option miniters does not work with manual pbar update. # Will verify manually, lower. # Fixed choice of value rather than a percentage of the chunk # size because our tracker is quite slow. miniters = 100 p = tqdm(total=chunk_size, desc=tqdm_text, position=chunk_id+1, leave=False) for s in range(chunk_size): seed = self.seed_generator.get_next_pos( random_generator, indices, first_seed_of_chunk + s) # Setting the random value. # Previous usage (and usage in Dipy) is to set the random seed # based on the (real) seed position. However, in the case where we # like to have exactly the same seed more than once, this will lead # to exactly the same line, even in probabilistic tracking. # Changing to seed position + seed number. # Then in the case of multiprocessing, adding also a fraction based # on current process ID. eps = s + chunk_id / (self.nbr_processes + 1) line_generator = np.random.default_rng( np.uint32(hash((seed + (eps, eps, eps), self.rng_seed)))) # Forward and backward tracking line = self._get_line_both_directions(seed, line_generator) if line is not None: streamline = np.array(line, dtype='float32') if self.compression_th is not None: # Compressing. Threshold is in mm. Verifying space. if self.space == Space.VOX: # Equivalent of sft.to_voxmm: streamline *= self.seed_generator.voxres compress_streamlines(streamline, self.compression_th) # Equivalent of sft.to_vox: streamline /= self.seed_generator.voxres else: compress_streamlines(streamline, self.compression_th) streamlines.append(streamline) if self.save_seeds: seeds.append(np.asarray(seed, dtype='float32')) if self.verbose and (s + 1) % miniters == 0: with lock: p.update(miniters) if self.verbose: with lock: p.close() return streamlines, seeds def _get_line_both_directions(self, seeding_pos, line_generator): """ Generate a streamline from an initial position following the tracking parameters. Parameters ---------- seeding_pos : tuple 3D position, the seed position. Returns ------- line: list of 3D positions The generated streamline for seeding_pos. """ # Forward line = [np.asarray(seeding_pos)] tracking_info = self.propagator.prepare_forward(seeding_pos, line_generator) if tracking_info == PropagationStatus.ERROR: # No good tracking direction can be found at seeding position. return None line = self._propagate_line(line, tracking_info) # Backward if not self.track_forward_only: if len(line) > 1: line.reverse() tracking_info = self.propagator.prepare_backward(line, tracking_info) line = self._propagate_line(line, tracking_info) # Clean streamline if self.min_nbr_pts <= len(line) <= self.max_nbr_pts: return line return None def _propagate_line(self, line, tracking_info): """ Generate a streamline in forward or backward direction from an initial position following the tracking parameters. 

        Propagation will stop if the current position is out of bounds
        (the mask's bounds and the data's bounds should be the same) or if
        the mask's value at the current position is 0 (the usual use is
        with a binary mask, but this is not mandatory).

        Parameters
        ----------
        line: List[np.ndarrays]
            Beginning of the line to propagate: list of 3D coordinates
            formatted as arrays.
        tracking_info: Any
            Information necessary to know how to propagate. Type: as
            understood by the propagator. Example, with the typical fODF
            propagator: the previous direction of the streamline, v_in,
            used to define a cone theta, of type TrackingDirection.

        Returns
        -------
        line: list of 3D positions
            At minimum, stays as the initial line. Or extended with new
            tracked points.
        """
        invalid_direction_count = 0
        propagation_can_continue = True
        while len(line) < self.max_nbr_pts and propagation_can_continue:
            new_pos, new_tracking_info, is_direction_valid = \
                self.propagator.propagate(line, tracking_info)

            # Verifying if the direction is valid.
            # If invalid: break. Else, verify tracking mask.
            if is_direction_valid:
                invalid_direction_count = 0
            else:
                invalid_direction_count += 1
                if invalid_direction_count > self.max_invalid_dirs:
                    break

            propagation_can_continue = self._verify_stopping_criteria(
                new_pos)
            if propagation_can_continue or self.append_last_point:
                line.append(new_pos)

            tracking_info = new_tracking_info

        return line

    def _verify_stopping_criteria(self, last_pos):
        # Checking if out of bound
        if not self.mask.is_coordinate_in_bound(
                *last_pos, space=self.space, origin=self.origin):
            return False

        # Checking if out of mask
        if self.mask.get_value_at_coordinate(
                *last_pos, space=self.space, origin=self.origin) <= 0:
            return False

        return True
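
# Worked example of the chunking in Tracker._get_streamlines (illustrative,
# not part of scilpy): with nbr_seeds=1003 and nbr_processes=4,
# chunk_size = int(1003 / 4) = 250, so chunks 0-2 track 250 seeds each and
# the last chunk absorbs the remainder (250 + 1003 % 4 = 253 seeds). Seed s
# of chunk c then draws the global seed number c * 250 + skip + s from the
# seed generator.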

class GPUTacker():
    """
    Perform probabilistic tracking on an ODF field inside a binary mask.
    The tracking is executed on the GPU using the OpenCL API.

    Tracking is performed in voxel space with origin `corner`.

    Streamlines are interrupted as soon as they reach maximum length and
    are returned even if they end inside the tracking mask. The ODF image
    is interpolated using nearest-neighbour or trilinear interpolation
    (see `sh_interp`). Backward tracking is performed unless
    `forward_only` is True.

    Parameters
    ----------
    sh : ndarray
        Spherical harmonics volume. Ex: ODF or fODF.
    mask : ndarray
        Tracking mask. Tracking stops outside the mask.
    seeds : ndarray (n_seeds, 3)
        Seed positions in voxel space with origin `center`.
    step_size : float
        Step size in voxel space.
    max_nbr_pts : int
        Maximum number of points per streamline.
    theta : float or list of float, optional
        Maximum angle (degrees) between 2 steps. If a list, a theta is
        randomly drawn from the list for each streamline.
    sf_threshold : float, optional
        Threshold on spherical function amplitudes for direction sampling.
    sh_interp : str, optional
        SH image interpolation mode, 'trilinear' or 'nearest'.
    sh_basis : str, optional
        Spherical harmonics basis.
    is_legacy : bool, optional
        Whether or not the SH basis is in its legacy form.
    batch_size : int, optional
        Approximate size of GPU batches.
    forward_only: bool, optional
        If True, only forward tracking is performed.
    rng_seed : int, optional
        Seed for the random number generator.
    sphere : dipy Sphere, optional
        Sphere to use for the tracking.
    """
    def __init__(self, sh, mask, seeds, step_size, max_nbr_pts,
                 theta=20.0, sf_threshold=0.1, sh_interp='trilinear',
                 sh_basis='descoteaux07', is_legacy=True, batch_size=100000,
                 forward_only=False, rng_seed=None, sphere=None):
        if not have_opencl:
            raise ImportError('pyopencl is not installed. In order to use '
                              'the GPU tracker, you need to install it '
                              'first.')
        self.sh = sh
        if sh_interp not in ['nearest', 'trilinear']:
            raise ValueError('Invalid SH interpolation mode: {}'
                             .format(sh_interp))
        self.sh_interp_nn = sh_interp == 'nearest'
        self.mask = mask

        self.n_seeds = len(seeds)
        self.seed_batches = \
            np.array_split(seeds + 0.5, np.ceil(len(seeds) / batch_size))

        if sphere is None:
            self.sphere = get_sphere("repulsion724")
        else:
            self.sphere = sphere

        # tracking step_size and number of points
        self.step_size = step_size
        self.sf_threshold = sf_threshold
        self.max_strl_points = max_nbr_pts

        # convert theta to array
        self.theta = np.atleast_1d(theta)

        self.sh_basis = sh_basis
        self.is_legacy = is_legacy
        self.forward_only = forward_only

        # Instantiate the random number generator
        self.rng = np.random.default_rng(rng_seed)

    def _get_max_amplitudes(self, B_mat):
        fodf_max = np.zeros(self.mask.shape, dtype=np.float32)
        fodf_max[self.mask > 0] = np.max(self.sh[self.mask > 0].dot(B_mat),
                                         axis=-1)
        return fodf_max

    def __iter__(self):
        return self._track()

    def _track(self):
        """
        GPU streamlines generator yielding streamlines with their
        corresponding seed positions one by one.
""" # Convert theta to cos(theta) max_cos_theta = np.cos(np.deg2rad(self.theta)) cl_kernel = CLKernel('tracker', 'tracking', 'local_tracking.cl') # Set tracking parameters cl_kernel.set_define('IM_X_DIM', self.sh.shape[0]) cl_kernel.set_define('IM_Y_DIM', self.sh.shape[1]) cl_kernel.set_define('IM_Z_DIM', self.sh.shape[2]) cl_kernel.set_define('IM_N_COEFFS', self.sh.shape[3]) cl_kernel.set_define('N_DIRS', len(self.sphere.vertices)) cl_kernel.set_define('N_THETAS', len(self.theta)) cl_kernel.set_define('STEP_SIZE', '{:.8f}f'.format(self.step_size)) cl_kernel.set_define('MAX_LENGTH', self.max_strl_points) cl_kernel.set_define('FORWARD_ONLY', 'true' if self.forward_only else 'false') cl_kernel.set_define('SF_THRESHOLD', '{:.8f}f'.format(self.sf_threshold)) cl_kernel.set_define('SH_INTERP_NN', 'true' if self.sh_interp_nn else 'false') # Create CL program cl_manager = CLManager(cl_kernel) # Input buffers # Constant input buffers cl_manager.add_input_buffer('sh', self.sh) cl_manager.add_input_buffer('vertices', self.sphere.vertices) sh_order = find_order_from_nb_coeff(self.sh) B_mat = sh_to_sf_matrix(self.sphere, sh_order, self.sh_basis, return_inv=False, legacy=self.is_legacy) cl_manager.add_input_buffer('b_matrix', B_mat) fodf_max = self._get_max_amplitudes(B_mat) cl_manager.add_input_buffer('max_amplitudes', fodf_max) cl_manager.add_input_buffer('mask', self.mask.astype(np.float32)) cl_manager.add_input_buffer('max_cos_theta', max_cos_theta) cl_manager.add_input_buffer('seeds') cl_manager.add_input_buffer('randvals') cl_manager.add_output_buffer('out_strl') cl_manager.add_output_buffer('out_lengths') # Generate streamlines in batches for seed_batch in self.seed_batches: # Generate random values for sf sampling # TODO: Implement random number generator directly # on the GPU to generate values on-the-fly. rand_vals = self.rng.uniform(0.0, 1.0, (len(seed_batch), self.max_strl_points)) # Update buffers cl_manager.update_input_buffer('seeds', seed_batch) cl_manager.update_input_buffer('randvals', rand_vals) # output streamlines buffer cl_manager.update_output_buffer('out_strl', (len(seed_batch), self.max_strl_points, 3)) # output streamlines length buffer cl_manager.update_output_buffer('out_lengths', (len(seed_batch), 1)) # Run the kernel tracks, n_points = cl_manager.run((len(seed_batch), 1, 1)) n_points = n_points.flatten().astype(np.int16) for (strl, seed, n_pts) in zip(tracks, seed_batch, n_points): strl = strl[:n_pts] # output is yielded so that we can use LazyTractogram. # seed and strl with origin center (same as DIPY) yield strl - 0.5, seed - 0.5