# -*- coding: utf-8 -*-
import logging
import numpy as np
class RAP:
    """Base class for Region-Adaptive Propagation (RAP) strategies.

    A subclass decides how a streamline is propagated while it is inside the
    RAP region by implementing ``rap_multistep_propagate``.
    """

    def __init__(self, rap_volume, propagator, max_nbr_pts):
        """
        Parameters
        ----------
        rap_volume : DataVolume
            Region-Adaptive Propagation tractography volume. A position whose
            value is strictly positive is considered inside the RAP region.
        propagator : object
            Propagator instance, stored as ``self.propagator`` for use by
            subclasses.
        max_nbr_pts : int
            Maximum number of points per streamline.
        """
        # Bookkeeping shared with subclasses: the label currently being
        # tracked through and the number of RAP steps taken so far.
        self._total_steps = 0
        self._current_label = None
        self.max_nbr_pts = max_nbr_pts
        self.propagator = propagator
        self.rap_volume = rap_volume

    def is_in_rap_region(self, curr_pos, space, origin):
        """
        Tell whether ``curr_pos`` lies inside the RAP region.

        Parameters
        ----------
        curr_pos : sequence of 3 floats
            Position to test, unpacked as (x, y, z).
        space, origin :
            Coordinate convention of ``curr_pos``, forwarded to
            ``rap_volume.get_value_at_coordinate``.

        Returns
        -------
        bool
            True when the RAP volume value at ``curr_pos`` is > 0.
        """
        value = self.rap_volume.get_value_at_coordinate(
            *curr_pos, space=space, origin=origin)
        return value > 0

    def rap_multistep_propagate(self, line, prev_direction):
        """
        Propagate ``line`` inside the RAP region (abstract).

        Every subclass must override this method with this exact contract.

        Parameters
        ----------
        line : list
            The beginning of the streamline.
        prev_direction : tuple or np.ndarray
            The last tracking direction (x, y, z).

        Returns
        -------
        line : list
            The streamline extended with RAP in the RAP neighborhood.
        prev_direction : tuple or np.ndarray
            The last direction (x, y, z).
        is_line_valid : bool
            Whether the line generated with RAP is valid.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError
class RAPContinue(RAP):
    """Dummy RAP class for tests: keeps going straight inside the RAP mask."""

    def __init__(self, rap_volume, propagator, max_nbr_pts, step_size):
        """
        Parameters
        ----------
        rap_volume, propagator, max_nbr_pts :
            Forwarded unchanged to ``RAP.__init__``.
        step_size : float
            The step size inside the RAP mask, in voxel world. It may differ
            from the step size used elsewhere.
        """
        super().__init__(rap_volume, propagator, max_nbr_pts)
        self.step_size = step_size

    def rap_multistep_propagate(self, line, prev_direction):
        """
        Straighten the end of ``line`` along ``prev_direction``.

        When the streamline has more than 3 points, its last point is
        replaced in place by ``line[-2] + step_size * prev_direction``. The
        line is not lengthened. The direction is returned unchanged and the
        line is always reported as valid.

        NOTE(review): the reason for the ``> 3`` threshold (rather than the
        ``>= 2`` needed to read ``line[-2]``) is not visible here; confirm.
        """
        if len(line) <= 3:
            return line, prev_direction, True
        direction = np.array(prev_direction)
        line[-1] = line[-2] + self.step_size * direction
        return line, prev_direction, True
class RAPSwitch(RAP):
    """RAP class that switches tracking parameters when inside the RAP mask/label."""

    def __init__(self, rap_volume, propagators: dict,
                 max_nbr_pts):
        """
        Parameters
        ----------
        rap_volume : DataVolume
            Region-Adaptive Propagation label volume. The value at a position
            is converted to an int label that selects the propagator. Labels
            <= 0 mean "outside every RAP region".
        propagators : dict
            Dictionary of ODFPropagator instances keyed by label (str).
            If --in_odf is provided, contains {odf_path: propagator}
            as default. Additional propagators are keyed by their label,
            loaded from the 'filename' key in rap_policies.json.
            The first entry (insertion order) is the base/default propagator.
            It is also used for labels that have no entry of their own.
        max_nbr_pts : int
            Maximum number of points per streamline.
        """
        if propagators:
            default_label = next(iter(propagators))
            base_propagator = propagators[default_label]
        else:
            default_label = None
            base_propagator = None
        super().__init__(rap_volume, base_propagator, max_nbr_pts)
        self._propagators = propagators
        # self.propagator changes as labels are crossed, so the default
        # propagator (and its key, for logging) is remembered separately.
        self._default_label = default_label
        self._default_propagator = base_propagator
        if self.propagator is not None:
            self._base = {
                'step_size': self.propagator.step_size,
                'theta': self.propagator.theta,
                'algo': getattr(self.propagator, 'algo', None),
                'tracking_neighbours': getattr(self.propagator,
                                               'tracking_neighbours', None)
            }
        else:
            self._base = {}

        # Check if all labels in the volume are covered by the configuration.
        # Background (<= 0) is ignored. Float voxel values that map to the
        # same int label are reported once.
        present_values = np.unique(rap_volume.data)
        unique_labels = sorted({int(label) for label in present_values
                                if label > 0})
        configured = propagators or {}
        missing_labels = [label for label in unique_labels
                          if str(label) not in configured]
        if missing_labels:
            logging.warning(
                "Labels %s found in RAP volume but not in methods config. "
                "Base params will be used for these labels.",
                missing_labels)

    def rap_multistep_propagate(self, line, prev_direction):
        """
        Perform one propagation step with the propagator of the current label.

        The label at the last point of ``line`` selects the propagator. A
        label with no configured propagator uses the default one. When the
        label differs from the one seen on the previous call, the active
        propagator is switched before stepping.

        Parameters
        ----------
        line : list
            The current streamline. Modified in place: at most one point is
            appended.
        prev_direction : np.ndarray
            The previous tracking direction.

        Returns
        -------
        line : list
            The streamline, extended by one point if the step was valid.
        prev_direction : np.ndarray
            The new direction if the step was valid, otherwise the input
            ``prev_direction``.
        is_line_valid : bool
            False if the last point has a label <= 0 or if the propagator
            returned an invalid direction.
        """
        label = self._get_label(line[-1],
                                self.propagator.space,
                                self.propagator.origin)
        if label <= 0:
            return line, prev_direction, False

        if label != self._current_label:
            if self._current_label is not None:
                # Lazy %-formatting: only rendered when DEBUG is enabled.
                logging.debug(
                    "STEP[%s] label=%s, algo=%s, theta (rad)=%s, "
                    "vox step size=%s -> switching label to label %s",
                    self._total_steps, self._current_label,
                    getattr(self.propagator, 'algo', None),
                    self.propagator.theta, self.propagator.step_size, label)
            self._current_label = label
            self._switch_propagator(label)

        # Perform propagation with the (possibly new) parameters.
        new_pos, new_dir, is_direction_valid = self.propagator.propagate(
            line, prev_direction)
        if not is_direction_valid:
            return line, prev_direction, False
        line.append(new_pos)
        self._total_steps += 1
        return line, new_dir, True

    def _switch_propagator(self, label):
        """Make the propagator for ``label`` (or the default one) active."""
        key = str(label)
        if key in self._propagators:
            new_propagator = self._propagators[key]
            message, shown_label = "RAP propagator switched to label %s", label
        else:
            new_propagator = self._default_propagator
            message = "RAP propagator switched to default label %s"
            shown_label = self._default_label
        if new_propagator is self.propagator:
            return
        # Carry the current line_rng_generator over to the new propagator so
        # it keeps drawing from the same generator.
        new_propagator.line_rng_generator = self.propagator.line_rng_generator
        self.propagator = new_propagator
        logging.debug(message, shown_label)

    def _get_label(self, curr_pos, space, origin):
        """
        Receive the integer label at the current position in the RAP volume.

        Parameters
        ----------
        curr_pos: np.ndarray
            The current 3D position of the streamline.
        space: Space
            Coordinate space of ``curr_pos`` (the propagator's space).
        origin: Origin
            Origin convention of ``curr_pos`` (the propagator's origin).

        Returns
        -------
        int
            The label at the current position, rounded to the nearest
            integer.
        """
        value = self.rap_volume.get_value_at_coordinate(
            *curr_pos, space=space, origin=origin)
        # Round rather than truncate: a float-stored label map may hold
        # values like 2.9999 for label 3, which int() would turn into 2.
        return int(np.round(value))
class RAPGraph(RAP):
    """Placeholder RAP strategy: ``rap_multistep_propagate`` is not implemented."""

    def __init__(self, mask_rap, propagator, max_nbr_pts, neighboorhood_size):
        """
        Parameters
        ----------
        mask_rap : DataVolume
            RAP volume, forwarded to ``RAP`` as ``rap_volume``.
        propagator : object
            Propagator forwarded to ``RAP``.
        max_nbr_pts : int
            Maximum number of points per streamline.
        neighboorhood_size :
            Stored as ``self.neighboorhood_size``. Not used by any method of
            this class yet.
        """
        # Attribute name (including its spelling) is part of the public API.
        self.neighboorhood_size = neighboorhood_size
        super().__init__(rap_volume=mask_rap, propagator=propagator,
                         max_nbr_pts=max_nbr_pts)

    def rap_multistep_propagate(self, line, prev_direction):
        """Not implemented yet for RAPGraph; always raises NotImplementedError."""
        raise NotImplementedError