Source code for scilpy.tracking.propagator

# -*- coding: utf-8 -*-
from enum import Enum
import logging

import numpy as np

import dipy
from dipy.io.stateful_tractogram import Space, Origin
from dipy.reconst.shm import sh_to_sf_matrix

from scilpy.reconst.utils import (get_sphere_neighbours,
                                  get_sh_order_and_fullness)
from scilpy.tracking.utils import sample_distribution, TrackingDirection


[docs] class PropagationStatus(Enum): ERROR = 1
[docs] class AbstractPropagator(object): """ Abstract class for propagator object. "Propagation" means continuing the streamline a step further. The propagator is thus responsible for sampling the next direction at current step through Runge-Kutta integration (whereas the tracker using this propagator will be responsible for the processing parameters, number of streamlines, stopping criteria, etc.). Propagation depends on the type of data (ex, DTI, fODF) and the way to get a direction from it (ex, det, prob). """ def __init__(self, datavolume, step_size, rk_order, space, origin): """ Parameters ---------- datavolume: scilpy.image.volume_space_management.DataVolume Trackable Dataset object. step_size: float The step size for tracking. Important: step size should be in the same units as the space of the tracking! rk_order: int Order for the Runge Kutta integration. space: dipy Space Space of the streamlines during tracking. value. origin: dipy Origin Origin of the streamlines during tracking. All coordinates received in the propagator's methods will be expected to respect that origin. A note on space and origin: All coordinates received in the propagator's methods will be expected to respect those values. Tracker will verify that the propagator has the same internal values as itself. """ self.datavolume = datavolume self.origin = origin self.space = space # Propagation options self.step_size = step_size if not (rk_order == 1 or rk_order == 2 or rk_order == 4): raise ValueError("Invalid runge-kutta order. Is " + str(rk_order) + ". Choices : 1, 2, 4") self.rk_order = rk_order # By default, normalizing directions. Adding option for child classes. self.normalize_directions = True # Will be reset at each new streamline. self.line_rng_generator = None
[docs] def reset_data(self, new_data=None): """ Reset data before starting a new process. In current implementation, we reset the internal data to None before starting a multiprocess, then load it back when process has started. Parameters ---------- new_data: Any Will replace self.datavolume.data. """ self.datavolume.data = new_data
[docs] def prepare_forward(self, seeding_pos, random_generator): """ Prepare information necessary at the first point of the streamline for forward propagation: v_in and any other information necessary for the self.propagate method. Parameters ---------- seeding_pos: tuple(x,y,z) The seeding position. Important, position must be in the same space and origin as self.space, self.origin! random_generator: numpy Generator. Returns ------- tracking_info: Any Any tracking information necessary for the propagation. Return PropagationStatus.ERROR if no good tracking direction can be set at current seeding position. """ # To be defined by child classes. # Should set self.line_rng_generator = random_generator raise NotImplementedError
[docs] def prepare_backward(self, line, forward_dir): """ Called at the beginning of backward tracking, in case we need to reset some parameters Parameters ---------- line: List Result from the forward tracking, reversed. forward_dir: ndarray (3,) v_in chosen at the forward step. Returns ------- v_in: ndarray (3,) Last direction of the streamline. If the streamline contains only the seeding point (forward tracking failed), simply inverse the forward direction. """ if len(line) > 1: v = line[-1] - line[-2] if self.normalize_directions: return v / np.linalg.norm(v) else: return v elif forward_dir is not None: return [-dir_i for dir_i in forward_dir] else: return None
[docs] def finalize_streamline(self, last_pos, v_in): """ Return the last position of the streamline. Parameters ---------- last_pos: ndarray (3,) Last propagated position. Important, position must be in the same space and origin as self.space, self.origin! v_in: TrackingDirection Last propagated direction. Returns ------- final_pos: ndarray (3,) Position of the final point of the streamline. Return None, or last_pos, if no last step is wished. """ # Make a last step straight in the last direction (no sampling or # interpolation of a new direction). Ex of use: if stopped because it # exited the (WM) tracking mask, reaching GM a little more. final_pos = last_pos + self.step_size * np.array(v_in) return final_pos
def _sample_next_direction_or_go_straight(self, pos, v_in): """ Same as _sample_next_direction but if no valid direction has been found, return v_in as v_out. """ is_direction_valid = True v_out = self._sample_next_direction(pos, v_in) if v_out is None: is_direction_valid = False v_out = v_in return is_direction_valid, v_out
[docs] def propagate(self, line, v_in): """ Given the current position and direction, computes the next position and direction using Runge-Kutta integration method. If no valid tracking direction is available, v_in is chosen. Parameters ---------- line: list[ndarrray (3,)] Current position. v_in: ndarray (3,) or TrackingDirection Previous tracking direction. Return ------ new_pos: ndarray (3,) The new segment position, expressed in propagator's space and origin. new_dir: ndarray (3,) or TrackingDirection The new segment direction. is_direction_valid: bool True if new_dir is valid. """ # Finding last coordinate pos = line[-1] if self.rk_order == 1: is_direction_valid, new_dir = \ self._sample_next_direction_or_go_straight(pos, v_in) elif self.rk_order == 2: is_direction_valid, dir1 = \ self._sample_next_direction_or_go_straight(pos, v_in) _, new_dir = self._sample_next_direction_or_go_straight( pos + 0.5 * self.step_size * np.array(dir1), dir1) else: # case self.rk_order == 4 is_direction_valid, dir1 = \ self._sample_next_direction_or_go_straight(pos, v_in) v1 = np.array(dir1) _, dir2 = self._sample_next_direction_or_go_straight( pos + 0.5 * self.step_size * v1, dir1) v2 = np.array(dir2) _, dir3 = self._sample_next_direction_or_go_straight( pos + 0.5 * self.step_size * v2, dir2) v3 = np.array(dir3) _, dir4 = self._sample_next_direction_or_go_straight( pos + self.step_size * v3, dir3) v4 = np.array(dir4) new_v = (v1 + 2 * v2 + 2 * v3 + v4) / 6 new_dir = TrackingDirection(new_v, dir1.index) new_pos = pos + self.step_size * np.array(new_dir) return new_pos, new_dir, is_direction_valid
def _sample_next_direction(self, pos, v_in): """ Chooses a next tracking direction from all possible directions offered by the tracking field. Parameters ---------- pos: ndarray (3,) Current tracking position. Important, position must be in the same space and origin as self.space, self.origin! v_in: ndarray (3,) Previous tracking direction. Return ------- direction: ndarray (3,) A valid tracking direction. None if no valid direction is found. Direction should be normalized. """ raise NotImplementedError
[docs] class PropagatorOnSphere(AbstractPropagator): def __init__(self, datavolume, step_size, rk_order, dipy_sphere, sub_sphere, space, origin): """ Parameters ---------- datavolume: scilpy.image.volume_space_management.DataVolume Trackable DataVolume object. step_size: float The step size for tracking. rk_order: int Order for the Runge Kutta integration. dipy_sphere: string, optional If necessary, name of the DIPY sphere object to use to evaluate directions. space: dipy Space Space of the streamlines during tracking. origin: dipy Origin Origin of the streamlines during tracking. """ super().__init__(datavolume, step_size, rk_order, space, origin) self.sphere = dipy.data.get_sphere(dipy_sphere).subdivide(sub_sphere) self.dirs = np.zeros(len(self.sphere.vertices), dtype=np.ndarray) for i in range(len(self.sphere.vertices)): self.dirs[i] = TrackingDirection(self.sphere.vertices[i], i)
[docs] def prepare_backward(self, line, forward_dir): """ Called at the beginning of backward tracking, in case we need to reset some parameters Parameters ---------- line: List Result from the forward tracking, reversed. forward_dir: ndarray (3,) v_in chosen at the forward step. Returns ------- v_in: ndarray (3,) Last direction of the streamline, of if it contains only the seeding point (forward tracking failed), simply inverse the forward direction. """ if len(line) > 1: last_dir = line[-1] - line[-2] ind = self.sphere.find_closest(last_dir) else: backward_dir = -np.asarray(forward_dir) ind = self.sphere.find_closest(backward_dir) # toDo. Is using a TrackingDirection necessary compared to a direction # x,y, z or rho, phi? self.sphere.vertices[ind] might not be # exactly equal to last_dir or to backward_dir. return TrackingDirection(self.sphere.vertices[ind], ind)
[docs] class ODFPropagator(PropagatorOnSphere): """ Propagator on ODFs/fODFs. Algo can be det or prob. """ def __init__(self, datavolume, step_size, rk_order, algo, basis, sf_threshold, sf_threshold_init, theta, dipy_sphere='symmetric724', sub_sphere=0, min_separation_angle=np.pi / 16., space=Space('vox'), origin=Origin('center'), is_legacy=True): """ Parameters ---------- datavolume: scilpy.image.volume_space_management.DataVolume Trackable DataVolume object. step_size: float The step size for tracking. rk_order: int Order for the Runge Kutta integration. algo: string Type of algorithm. Choices are 'det' or 'prob' basis: string SH basis name. One of 'tournier07' or 'descoteaux07' sf_threshold: float Threshold on spherical function (SF). sf_threshold_init: float Threshold on spherical function when initializing a new streamline. theta: float Maximum angle (radians) between two steps. dipy_sphere: string, optional Name of the DIPY sphere object to use for evaluating SH. Can't be None. sub_sphere: int Number of subdivisions to use for the sphere. min_separation_angle: float, optional Minimum separation angle (in radians) for peaks extraction. Used for deterministic tracking. A candidate direction is a maximum if its SF value is greater than all other SF values in its neighbourhood, where the neighbourhood includes all the sphere directions located at most `min_separation_angle` from the candidate direction. space: dipy Space Space of the streamlines during tracking. Default: VOX, like in dipy. Interpolation of the ODF is done in VOX space (see DataVolume.vox_to_value) so this choice implies the less data modification. origin: dipy Origin Origin of the streamlines during tracking. Default: center, like in dipy. Interpolation of the ODF is done in center origin so this choice implies the less data modification. is_legacy : bool, optional Whether or not the SH basis is in its legacy form. """ super().__init__(datavolume, step_size, rk_order, dipy_sphere, sub_sphere, space, origin) if self.space == Space.RASMM: raise NotImplementedError( "This version of the propagator on ODF is not ready to work " "in RASMM space.") # Warn user if the rk order does not match the algo if rk_order != 1 and algo == 'prob': logging.warning('Probabilistic tracking with RK order != 1 is ' 'not recommended! Use deterministic tracking ' 'or set rk_order to 1 instead.') # Propagation params self.theta = theta if algo not in ['det', 'prob']: raise ValueError("ODFPropagator algo should be 'det' or 'prob'.") self.algo = algo self.tracking_neighbours = get_sphere_neighbours(self.sphere, self.theta) # For deterministic tracking: self.maxima_neighbours = get_sphere_neighbours(self.sphere, min_separation_angle) # ODF params self.sf_threshold = sf_threshold self.sf_threshold_init = sf_threshold_init sh_order, full_basis =\ get_sh_order_and_fullness(self.datavolume.data.shape[-1]) self.basis = basis self.is_legacy = is_legacy self.B = sh_to_sf_matrix(self.sphere, sh_order, self.basis, smooth=0.006, return_inv=False, full_basis=full_basis, legacy=self.is_legacy) def _get_sf(self, pos): """ Get the spherical function at position pos. Parameters ---------- pos: ndarray (3,) Position in the trackable dataset. Important, position should be in the same space and origin as self.space, self.origin! Return ------ sf: ndarray (len(self.sphere.vertices),) Spherical function evaluated at pos, normalized by its maximum amplitude. """ # Interpolation: sh = self.datavolume.get_value_at_coordinate( *pos, space=self.space, origin=self.origin) sf = np.dot(self.B.T, sh).reshape((-1, 1)) sf_max = np.max(sf) if sf_max > 0: sf /= sf_max return sf
[docs] def prepare_forward(self, seeding_pos, random_generator): """ Prepare information necessary at the first point of the streamline for forward propagation: v_in and any other information necessary for the self.propagate method. About **v_in**, it is used for two things: - To sample the next direction based on _sample_next_direction method. Ex, with fODF, it defines a cone theta of accepable directions. - If no valid next dir are found, continue straight. Parameters ---------- seeding_pos: tuple(x,y,z) The seeding position. Important, position must be in the same space and origin as self.space, self.origin! random_generator: numpy Generator Returns ------- v_in: TrackingDirection The "fake" previous direction at first step. Could be None if your propagator can propagate without knowledge of previous direction. Return PropagationStatus.Error if no good tracking direction can be set at current seeding position. """ # Sampling on the SF values (no matter if general algo is det or prob) # with a different threshold than usual (sf_threshold_init). # So the initial step's propagation will be in a cone theta around a # "more probable" peak. sf = self._get_sf(seeding_pos) sf[sf < self.sf_threshold_init] = 0 self.line_rng_generator = random_generator if np.sum(sf) > 0: ind = sample_distribution(sf, self.line_rng_generator) return TrackingDirection(self.dirs[ind], ind) # Else: sf at current position is smaller than acceptable threshold in # all directions. return PropagationStatus.ERROR
def _sample_next_direction(self, pos, v_in): """ Chooses a next tracking direction from all possible directions offered by the tracking field. Parameters ---------- pos: ndarray (3,) Current tracking position. Important, position must be in the same space and origin as self.space, self.origin! v_in: ndarray (3,) Previous tracking direction. Return ------ direction: ndarray (3,) A valid tracking direction. None if no valid direction is found. """ if self.algo == 'prob': # Tracking field returns the sf and directions sf, directions = self._get_possible_next_dirs_prob(pos, v_in) # Sampling one. if np.sum(sf) > 0: v_out = directions[sample_distribution(sf, self.line_rng_generator)] else: return None elif self.algo == 'det': # Tracking field returns the list of possible maxima. possible_maxima = self._get_possible_next_dirs_det(pos, v_in) # Choosing one. cosinus = 0 v_out = None for d in possible_maxima: new_cosinus = np.dot(v_in, d) if new_cosinus > cosinus: cosinus = new_cosinus v_out = d else: raise ValueError("Tracking choice must be one of 'det' or 'prob'.") # Not normalizing: direction comes from dipy's (unit) sphere so # supposing that it's ok. return v_out def _get_possible_next_dirs_prob(self, pos, v_in): """ Get the spherical functions thresholded at position pos, for a given direction. Parameters ---------- pos: ndarray (3,) Position in trackable dataset. Important, position must be in the same space and origin as self.space, self.origin! v_in: TrackingDirection Incoming direction. Outcoming direction won't be further than an angle theta. Return ------ value: tuple The neighbours SF evaluated at pos in given direction and corresponding tracking directions. """ sf = self._get_sf(pos) sf[sf < self.sf_threshold] = 0 inds = np.nonzero( self.tracking_neighbours[v_in.index])[0] return sf[inds], self.dirs[inds] def _get_possible_next_dirs_det(self, pos, previous_direction): """ Get the set of maxima directions from the thresholded SF at position pos, for a direction. Parameters ---------- pos: ndarray (3,) Position in trackable dataset. Important, position must be in the same space and origin as self.space, self.origin! previous_direction: TrackingDirection Incoming direction. Outcoming direction won't be further than an angle theta. Return ------ maxima: list List of directions of maxima around the input direction at pos. """ sf = self._get_sf(pos) sf[sf < self.sf_threshold] = 0 maxima = [] for i in np.nonzero(self.tracking_neighbours[ previous_direction.index])[0]: if 0 < sf[i] == np.max(sf[self.maxima_neighbours[i]]): maxima.append(self.dirs[i]) return maxima