Source code for scilpy.tractanalysis.afd_along_streamlines

# -*- coding: utf-8 -*-

from dipy.data import get_sphere
from dipy.reconst.shm import sh_to_sf_matrix, sph_harm_ind_list
import numpy as np
from scipy.special import lpn

from scilpy.reconst.utils import find_order_from_nb_coeff
from scilpy.tractanalysis.grid_intersections import grid_intersections


def afd_map_along_streamlines(sft, fodf, fodf_basis, length_weighting,
                              is_legacy=True):
    """
    Compute the mean Apparent Fiber Density (AFD) and mean Radial fODF
    (radfODF) maps along a bundle.

    Parameters
    ----------
    sft : StatefulTractogram
        StatefulTractogram containing the streamlines needed.
    fodf : nibabel.image
        fODF with shape (X, Y, Z, #coeffs). #coeffs depend on the sh_order.
    fodf_basis : string
        Has to be descoteaux07 or tournier07.
    length_weighting : bool
        If set, will weigh the AFD values according to segment lengths.
    is_legacy : bool, optional
        Whether or not the SH basis is in its legacy form.

    Returns
    -------
    afd_sum : np.array
        AFD map (weighted if length_weighting).
    rd_sum : np.array
        radfODF map (weighted if length_weighting).
    """
    afd_sum, rd_sum, weights = \
        afd_and_rd_sums_along_streamlines(sft, fodf, fodf_basis,
                                          length_weighting,
                                          is_legacy=is_legacy)

    # Turn the weighted sums into weighted means, voxel-wise, leaving
    # voxels that no streamline crossed at zero.
    non_zeros = np.nonzero(afd_sum)
    weights_nz = weights[non_zeros]
    afd_sum[non_zeros] /= weights_nz
    rd_sum[non_zeros] /= weights_nz

    return afd_sum, rd_sum
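A minimal usage sketch (not part of the module): the file names bundle.trk and fodf.nii.gz are placeholders, and the StatefulTractogram this function expects is built with dipy's load_tractogram.

import nibabel as nib
from dipy.io.streamline import load_tractogram

from scilpy.tractanalysis.afd_along_streamlines import \
    afd_map_along_streamlines

sft = load_tractogram('bundle.trk', 'fodf.nii.gz',
                      bbox_valid_check=False)
fodf_img = nib.load('fodf.nii.gz')

# Mean AFD and mean radfODF per voxel crossed by the bundle,
# weighted by segment lengths.
afd_map, radfodf_map = afd_map_along_streamlines(
    sft, fodf_img, 'descoteaux07', length_weighting=True)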
def afd_and_rd_sums_along_streamlines(sft, fodf, fodf_basis,
                                      length_weighting, is_legacy=True):
    """
    Compute the weighted sums of Apparent Fiber Density (AFD) and Radial
    fODF (radfODF) along a bundle, together with the per-voxel weights
    needed to turn those sums into means.

    Parameters
    ----------
    sft : StatefulTractogram
        StatefulTractogram containing the streamlines needed.
    fodf : nibabel.image
        fODF with shape (X, Y, Z, #coeffs). #coeffs depend on the sh_order.
    fodf_basis : string
        Has to be descoteaux07 or tournier07.
    length_weighting : bool
        If set, will weigh the AFD values according to segment lengths.
    is_legacy : bool, optional
        Whether or not the SH basis is in its legacy form.

    Returns
    -------
    afd_sum_map : np.array
        AFD map.
    rd_sum_map : np.array
        radfODF map.
    weight_map : np.array
        Segment lengths.
    """
    sft.to_vox()
    sft.to_corner()

    fodf_data = fodf.get_fdata(dtype=np.float32)
    order = find_order_from_nb_coeff(fodf_data)
    sphere = get_sphere('repulsion724')
    b_matrix, _ = sh_to_sf_matrix(sphere, order, fodf_basis, legacy=is_legacy)
    _, n = sph_harm_ind_list(order)
    legendre0_at_n = lpn(order, 0)[0][n]
    sphere_norm = np.linalg.norm(sphere.vertices)

    afd_sum_map = np.zeros(shape=fodf_data.shape[:-1])
    rd_sum_map = np.zeros(shape=fodf_data.shape[:-1])
    weight_map = np.zeros(shape=fodf_data.shape[:-1])

    p_matrix = np.eye(fodf_data.shape[3]) * legendre0_at_n
    all_crossed_indices = grid_intersections(sft.streamlines)
    for crossed_indices in all_crossed_indices:
        segments = crossed_indices[1:] - crossed_indices[:-1]
        seg_lengths = np.linalg.norm(segments, axis=1)

        # Remove points where the segment is zero.
        # This removes numpy warnings of division by zero.
        non_zero_lengths = np.nonzero(seg_lengths)[0]
        segments = segments[non_zero_lengths]
        seg_lengths = seg_lengths[non_zero_lengths]

        # Find the sphere vertex closest to each segment's direction.
        cos_similarities = np.dot(segments, sphere.vertices.T)
        cos_similarities = (cos_similarities.T /
                            (seg_lengths * sphere_norm)).T
        angles = np.arccos(cos_similarities)
        sorted_angles = np.argsort(angles, axis=1)
        closest_vertex_indices = sorted_angles[:, 0]

        # Those starting points are used for the segment vox_idx computations.
        strl_start = crossed_indices[non_zero_lengths]
        vox_indices = (strl_start + (0.5 * segments)).astype(int)

        normalization_weights = np.ones_like(seg_lengths)
        if length_weighting:
            normalization_weights = seg_lengths / \
                np.linalg.norm(fodf.header.get_zooms()[:3])

        for vox_idx, closest_vertex_index, norm_weight in zip(
                vox_indices, closest_vertex_indices, normalization_weights):
            vox_idx = tuple(vox_idx)
            b_at_idx = b_matrix.T[closest_vertex_index]
            fodf_at_index = fodf_data[vox_idx]

            # AFD: fODF amplitude along the segment direction.
            afd_val = np.dot(b_at_idx, fodf_at_index)
            # radfODF: amplitude after scaling each SH coefficient of
            # order n by the Legendre polynomial value P_n(0).
            rd_val = np.dot(np.dot(b_at_idx.T, p_matrix), fodf_at_index)

            afd_sum_map[vox_idx] += afd_val * norm_weight
            rd_sum_map[vox_idx] += rd_val * norm_weight
            weight_map[vox_idx] += norm_weight

    rd_sum_map[rd_sum_map < 0.] = 0.

    return afd_sum_map, rd_sum_map, weight_map
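A quick numeric check (a sketch, independent of the module) of the Legendre weights that build p_matrix: lpn(order, 0)[0] evaluates P_0..P_order at x = 0, and sph_harm_ind_list maps every SH coefficient to its order n, so each coefficient of order n ends up scaled by P_n(0).

import numpy as np
from dipy.reconst.shm import sph_harm_ind_list
from scipy.special import lpn

order = 8
_, n = sph_harm_ind_list(order)       # order n of each of the 45 coeffs
legendre0_at_n = lpn(order, 0)[0][n]  # P_n(0) for each coefficient

# P_n(0) = 1, -1/2, 3/8, -5/16, 35/128 for n = 0, 2, 4, 6, 8; odd
# orders vanish and are absent from a symmetric SH basis anyway.
print(np.unique(legendre0_at_n))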