Source code for scilpy.tracking.utils

# -*- coding: utf-8 -*-
import logging
from typing import Iterable

import nibabel as nib
import numpy as np
from nibabel.streamlines import TrkFile
from nibabel.streamlines.tractogram import LazyTractogram, TractogramItem
from tqdm import tqdm

from dipy.core.sphere import HemiSphere
from dipy.data import get_sphere
from dipy.direction import (DeterministicMaximumDirectionGetter,
                            ProbabilisticDirectionGetter, PTTDirectionGetter)
from dipy.direction.peaks import PeaksAndMetrics
from dipy.io.utils import create_tractogram_header, get_reference_info
from dipy.reconst.shm import sh_to_sf_matrix
from dipy.tracking.streamlinespeed import compress_streamlines, length
from scilpy.io.utils import (add_compression_arg, add_overwrite_arg,
                             add_sh_basis_args)
from scilpy.reconst.utils import find_order_from_nb_coeff, get_maximas


[docs] class TrackingDirection(list): """ Tracking direction use as 3D cartesian direction (list(x,y,z)) and has an index to work with discrete sphere. """ def __init__(self, cartesian, index=None): super(TrackingDirection, self).__init__(cartesian) self.index = index
[docs] def add_mandatory_options_tracking(p): """ Args that are required in both scil_tracking_local and scil_tracking_local_dev scripts. """ p.add_argument('in_odf', help='File containing the orientation diffusion function \n' 'as spherical harmonics file (.nii.gz). Ex: ODF or ' 'fODF.') p.add_argument('in_seed', help='Seeding mask (.nii.gz).') p.add_argument('in_mask', help='Tracking mask (.nii.gz).\n' 'Tracking will stop outside this mask. The last point ' 'of each \nstreamline (triggering the stopping ' 'criteria) IS added to the streamline.') p.add_argument('out_tractogram', help='Tractogram output file (must be .trk or .tck).')
[docs] def add_tracking_options(p): """ Options that are available in both scil_tracking_local and scil_tracking_local_dev scripts. """ track_g = p.add_argument_group('Tracking options') track_g.add_argument('--step', dest='step_size', type=float, default=0.5, help='Step size in mm. [%(default)s]') track_g.add_argument('--min_length', type=float, default=10., metavar='m', help='Minimum length of a streamline in mm. ' '[%(default)s]') track_g.add_argument('--max_length', type=float, default=300., metavar='M', help='Maximum length of a streamline in mm. ' '[%(default)s]') track_g.add_argument('--theta', type=float, help='Maximum angle between 2 steps. If the angle is ' 'too big, streamline is \nstopped and the ' 'following point is NOT included.\n' '["eudx"=60, "det"=45, "prob"=20, "ptt"=20]') track_g.add_argument('--sfthres', dest='sf_threshold', metavar='sf_th', type=float, default=0.1, help='Spherical function relative threshold. ' '[%(default)s]') add_sh_basis_args(track_g) return track_g
[docs] def add_tracking_ptt_options(p): track_g = p.add_argument_group('PTT options') track_g.add_argument('--probe_length', dest='probe_length', type=float, default=1.0, help='The length of the probes. Smaller value\n' 'yields more dispersed fibers. [%(default)s]') track_g.add_argument('--probe_radius', dest='probe_radius', type=float, default=0, help='The radius of the probe. A large probe_radius\n' 'helps mitigate noise in the pmf but it might\n' 'make it harder to sample thin and intricate\n' 'connections, also the boundary of fiber\n' 'bundles might be eroded. [%(default)s]') track_g.add_argument('--probe_quality', dest='probe_quality', type=int, default=3, help='The quality of the probe. This parameter sets\n' 'the number of segments to split the cylinder\n' 'along the length of the probe (minimum=2) ' '[%(default)s]') track_g.add_argument('--probe_count', dest='probe_count', type=int, default=1, help='The number of probes. This parameter sets the\n' 'number of parallel lines used to model the\n' 'cylinder (minimum=1). [%(default)s]') track_g.add_argument('--support_exponent', type=float, default=3, help='Data support exponent, used for rejection\n' 'sampling. [%(default)s]') return track_g
[docs] def add_seeding_options(p): """ Options that are available in both scil_tracking_local and scil_tracking_local_dev scripts. """ seed_group = p.add_argument_group( 'Seeding options', 'When no option is provided, uses --npv 1.') seed_sub_exclusive = seed_group.add_mutually_exclusive_group() seed_sub_exclusive.add_argument('--npv', type=int, help='Number of seeds per voxel.') seed_sub_exclusive.add_argument('--nt', type=int, help='Total number of seeds to use.')
[docs] def add_out_options(p): """ Options that are available in both scil_tracking_local and scil_tracking_local_dev scripts. """ out_g = p.add_argument_group('Output options') msg = ("\nA rule of thumb is to set it to 0.1mm for deterministic \n" "streamlines and to 0.2mm for probabilitic streamlines.") add_compression_arg(out_g, additional_msg=msg) add_overwrite_arg(out_g) out_g.add_argument('--save_seeds', action='store_true', help='If set, save the seeds used for the tracking \n ' 'in the data_per_streamline property.\n' 'Hint: you can then use ' 'scilpy_compute_seed_density_map.') return out_g
[docs] def verify_streamline_length_options(parser, args): if not args.min_length >= 0: parser.error('min_length must be >= 0, but {}mm was provided.' .format(args.min_length)) if args.max_length < args.min_length: parser.error('max_length must be > than min_length, but ' 'min_length={}mm and max_length={}mm.' .format(args.min_length, args.max_length))
[docs] def verify_seed_options(parser, args): if args.npv and args.npv <= 0: parser.error('Number of seeds per voxel must be > 0.') if args.nt and args.nt <= 0: parser.error('Total number of seeds must be > 0.')
[docs] def tqdm_if_verbose(generator: Iterable, verbose: bool, *args, **kwargs): if verbose: return tqdm(generator, *args, **kwargs) return generator
[docs] def save_tractogram( streamlines_generator, tracts_format, ref_img, total_nb_seeds, out_tractogram, min_length, max_length, compress, save_seeds, verbose ): """ Save the streamlines on-the-fly using a generator. Tracts are filtered according to their length and compressed if requested. Seeds are saved if requested. The tractogram is shifted and scaled according to the file format. Parameters ---------- streamlines_generator : generator Streamlines generator. tracts_format : TrkFile or TckFile Tractogram format. ref_img : nibabel.Nifti1Image Image used as reference. total_nb_seeds : int Total number of seeds. out_tractogram : str Output tractogram filename. min_length : float Minimum length of a streamline in mm. max_length : float Maximum length of a streamline in mm. compress : float Distance threshold for compressing streamlines in mm. save_seeds : bool If True, save the seeds used for the tracking in the data_per_streamline property. verbose : bool If True, display progression bar. """ voxel_size = ref_img.header.get_zooms()[0] scaled_min_length = min_length / voxel_size scaled_max_length = max_length / voxel_size # Tracking is expected to be returned in voxel space, origin `center`. def tracks_generator_wrapper(): for strl, seed in tqdm_if_verbose(streamlines_generator, verbose=verbose, total=total_nb_seeds, miniters=int(total_nb_seeds / 100), leave=False): if (scaled_min_length <= length(strl) <= scaled_max_length): # Seeds are saved with origin `center` by our own convention. # Other scripts (e.g. scil_compute_seed_density_map) expect so. dps = {} if save_seeds: dps['seeds'] = seed if compress: # compression threshold is given in mm, but we # are in voxel space strl = compress_streamlines( strl, compress / voxel_size) # TODO: Use nibabel utilities for dealing with spaces if tracts_format is TrkFile: # Streamlines are dumped in mm space with # origin `corner`. This is what is expected by # LazyTractogram for .trk files (although this is not # specified anywhere in the doc) strl += 0.5 strl *= voxel_size # in mm. else: # Streamlines are dumped in true world space with # origin center as expected by .tck files. strl = np.dot(strl, ref_img.affine[:3, :3]) + \ ref_img.affine[:3, 3] yield TractogramItem(strl, dps, {}) tractogram = LazyTractogram.from_data_func(tracks_generator_wrapper) tractogram.affine_to_rasmm = ref_img.affine filetype = nib.streamlines.detect_format(out_tractogram) reference = get_reference_info(ref_img) header = create_tractogram_header(filetype, *reference) # Use generator to save the streamlines on-the-fly nib.streamlines.save(tractogram, out_tractogram, header=header)
[docs] def get_direction_getter(in_img, algo, sphere, sub_sphere, theta, sh_basis, voxel_size, sf_threshold, sh_to_pmf, probe_length, probe_radius, probe_quality, probe_count, support_exponent, is_legacy=True): """ Return the direction getter object. Parameters ---------- in_img: str Path to the input odf file. algo: str Algorithm to use for tracking. Can be 'det', 'prob', 'ptt' or 'eudx'. sphere: str Name of the sphere to use for tracking. sub_sphere: int Number of subdivisions to use for the sphere. theta: float Angle threshold for tracking. sh_basis: str Name of the sh basis to use for tracking. voxel_size: float Voxel size of the input data. sf_threshold: float Spherical function-amplitude threshold for tracking. sh_to_pmf: bool Map sherical harmonics to spherical function (pmf) before tracking (faster, requires more memory). probe_length : float The length of the probes. Shorter probe_length yields more dispersed fibers. probe_radius : float The radius of the probe. A large probe_radius helps mitigate noise in the pmf but it might make it harder to sample thin and intricate connections, also the boundary of fiber bundles might be eroded. probe_quality : int The quality of the probe. This parameter sets the number of segments to split the cylinder along the length of the probe (minimum=2). probe_count : int The number of probes. This parameter sets the number of parallel lines used to model the cylinder (minimum=1). support_exponent : float Data support exponent, used for rejection sampling. is_legacy : bool, optional Whether or not the SH basis is in its legacy form. Return ------ dg: dipy.direction.DirectionGetter The direction getter object. """ img_data = nib.load(in_img).get_fdata(dtype=np.float32) sphere = HemiSphere.from_sphere( get_sphere(sphere)).subdivide(sub_sphere) # Theta depends on user choice and algorithm theta = get_theta(theta, algo) # Heuristic to find out if the input are peaks or fodf # fodf are always around 0.15 and peaks around 0.75 # Peaks have more zero values than fodf. The first value of fodf is # usually the highest. non_zeros_count = np.count_nonzero(np.sum(img_data, axis=-1)) non_first_val_count = np.count_nonzero(np.argmax(img_data, axis=-1)) is_peaks = non_first_val_count / non_zeros_count > 0.5 if algo in ['det', 'prob', 'ptt']: if is_peaks: logging.warning( 'Input detected as peaks. Input should be fodf for ' 'det/prob/ptt, verify input just in case.') kwargs = {} if algo == 'ptt': dg_class = PTTDirectionGetter # Considering the step size usually used, the probe length # can be set as the voxel size. kwargs = {'probe_length': probe_length, 'probe_radius': probe_radius, 'probe_quality': probe_quality, 'probe_count': probe_count, 'data_support_exponent': support_exponent} elif algo == 'det': dg_class = DeterministicMaximumDirectionGetter else: dg_class = ProbabilisticDirectionGetter return dg_class.from_shcoeff( shcoeff=img_data, max_angle=theta, sphere=sphere, basis_type=sh_basis, legacy=is_legacy, sh_to_pmf=sh_to_pmf, relative_peak_threshold=sf_threshold, **kwargs) elif algo == 'eudx': # Code for algo EUDX. We don't use peaks_from_model # because we want the peaks from the provided sh. img_shape_3d = img_data.shape[:-1] dg = PeaksAndMetrics() dg.sphere = sphere dg.ang_thr = theta dg.qa_thr = sf_threshold if is_peaks: # If the input is peaks, we compute their amplitude and # find the closest direction on the sphere. logging.info('Input detected as peaks.') nb_peaks = img_data.shape[-1] // 3 slices = np.arange(0, 15 + 1, 3) peak_values = np.zeros(img_shape_3d + (nb_peaks,)) peak_indices = np.zeros(img_shape_3d + (nb_peaks,)) for idx in np.argwhere(np.sum(img_data, axis=-1)): idx = tuple(idx) for i in range(nb_peaks): peak_values[idx][i] = np.linalg.norm( img_data[idx][slices[i]:slices[i + 1]], axis=-1) peak_indices[idx][i] = sphere.find_closest( img_data[idx][slices[i]:slices[i + 1]]) dg.peak_dirs = img_data else: # If the input is not peaks, we assume it is fodf # and we compute the peaks from the fodf. logging.info('Input detected as fodf.') npeaks = 5 peak_dirs = np.zeros((img_shape_3d + (npeaks, 3))) peak_values = np.zeros((img_shape_3d + (npeaks,))) peak_indices = np.full((img_shape_3d + (npeaks,)), -1, dtype='int') b_matrix, _ = sh_to_sf_matrix(sphere, find_order_from_nb_coeff(img_data), sh_basis, legacy=is_legacy) for idx in np.argwhere(np.sum(img_data, axis=-1)): idx = tuple(idx) directions, values, indices = get_maximas(img_data[idx], sphere, b_matrix.T, sf_threshold, 0) if values.shape[0] != 0: n = min(npeaks, values.shape[0]) peak_dirs[idx][:n] = directions[:n] peak_values[idx][:n] = values[:n] peak_indices[idx][:n] = indices[:n] dg.peak_dirs = peak_dirs dg.peak_values = peak_values dg.peak_indices = peak_indices return dg
[docs] def get_theta(requested_theta, tracking_type): if requested_theta is not None: theta = requested_theta elif tracking_type == 'ptt': theta = 20 elif tracking_type == 'prob': theta = 20 elif tracking_type == 'eudx': theta = 60 else: theta = 45 return theta
[docs] def sample_distribution(dist, random_generator: np.random.Generator): """ Parameters ---------- dist: numpy.array The empirical distribution to sample from. random_generator: numpy Generator Return ------ ind: int The index of the sampled element. """ cdf = dist.cumsum() if cdf[-1] == 0: return None return cdf.searchsorted(random_generator.random() * cdf[-1])