Source code for scilpy.tractanalysis.bingham_metric_along_streamlines

# -*- coding: utf-8 -*-

import numpy as np
from scilpy.reconst.bingham import bingham_to_peak_direction
from scilpy.tractanalysis.grid_intersections import grid_intersections


def bingham_metric_map_along_streamlines(sft, bingham_coeffs, metric,
                                          max_theta, length_weighting):
    """
    Compute the mean map for a given Bingham metric along streamlines.

    Parameters
    ----------
    sft : StatefulTractogram
        StatefulTractogram containing the streamlines needed.
    bingham_coeffs : ndarray
        Array of shape (X, Y, Z, N_LOBES, NB_PARAMS) containing
        the Bingham distributions parameters.
    metric : ndarray
        Array of shape (X, Y, Z, N_LOBES) containing the per-lobe
        Bingham metric of interest.
    max_theta : float
        Maximum angle in degrees between the fiber direction and the
        Bingham peak direction.
    length_weighting : bool
        If True, will weight the metric values by segment lengths.

    Returns
    -------
    mean_map : np.ndarray
        Per-voxel mean of the metric along the streamlines, zero where
        no streamline segment matched a Bingham lobe.
    """
    fd_sum, weights = \
        bingham_metric_sum_along_streamlines(sft, bingham_coeffs,
                                             metric, max_theta,
                                             length_weighting)

    non_zeros = np.nonzero(fd_sum)
    weights_nz = weights[non_zeros]
    fd_sum[non_zeros] /= weights_nz

    return fd_sum
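

# A minimal sketch (not part of scilpy) of the weighted-mean normalization
# performed above: the same result can be obtained from the sum and weight
# maps returned by bingham_metric_sum_along_streamlines() below. The helper
# name is hypothetical and kept only for illustration.
def _weighted_mean_map_sketch(metric_sum_map, weight_map):
    """Divide the sum map by the weight map wherever streamlines passed."""
    mean_map = np.zeros_like(metric_sum_map)
    traversed = weight_map > 0
    mean_map[traversed] = metric_sum_map[traversed] / weight_map[traversed]
    return mean_map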


def bingham_metric_sum_along_streamlines(sft, bingham_coeffs, metric,
                                         max_theta, length_weighting):
    """
    Compute a sum map along a bundle for a given Bingham metric.

    Parameters
    ----------
    sft : StatefulTractogram
        StatefulTractogram containing the streamlines needed.
    bingham_coeffs : ndarray (X, Y, Z, N_LOBES, NB_PARAMS)
        Bingham distributions parameters volume.
    metric : ndarray (X, Y, Z, N_LOBES)
        The per-lobe Bingham metric of interest.
    max_theta : float
        Maximum angle in degrees between the fiber direction and the
        Bingham peak direction.
    length_weighting : bool
        If True, will weight the metric values according to segment lengths.

    Returns
    -------
    metric_sum_map : np.array
        Bingham metric sum map.
    weight_map : np.array
        Sum of the per-voxel segment weights (lengths or counts).
    """
    sft.to_vox()
    sft.to_corner()

    metric_sum_map = np.zeros(metric.shape[:-1])
    weight_map = np.zeros(metric.shape[:-1])
    min_cos_theta = np.cos(np.radians(max_theta))

    all_crossed_indices = grid_intersections(sft.streamlines)
    for crossed_indices in all_crossed_indices:
        segments = crossed_indices[1:] - crossed_indices[:-1]
        seg_lengths = np.linalg.norm(segments, axis=1)

        # Remove points where the segment is zero.
        # This removes numpy warnings of division by zero.
        non_zero_lengths = np.nonzero(seg_lengths)[0]
        segments = segments[non_zero_lengths]
        seg_lengths = seg_lengths[non_zero_lengths]

        # Those starting points are used for the segment vox_idx computations.
        seg_start = crossed_indices[non_zero_lengths]
        vox_indices = (seg_start + (0.5 * segments)).astype(int)

        normalization_weights = np.ones_like(seg_lengths)
        if length_weighting:
            normalization_weights = seg_lengths

        normalized_seg = np.reshape(segments / seg_lengths[..., None], (-1, 3))

        for vox_idx, seg_dir, norm_weight in zip(vox_indices,
                                                 normalized_seg,
                                                 normalization_weights):
            vox_idx = tuple(vox_idx)

            bingham_at_idx = bingham_coeffs[vox_idx]  # (N_LOBES, NB_PARAMS)
            bingham_peak_dir = bingham_to_peak_direction(bingham_at_idx)

            # Cosine of the angle between the segment and each lobe peak.
            cos_theta = np.abs(np.dot(seg_dir.reshape((-1, 3)),
                                      bingham_peak_dir.T))

            metric_val = 0.0
            if (cos_theta > min_cos_theta).any():
                # Index of the best-aligned lobe within the angular threshold.
                lobe_idx = np.argmax(np.squeeze(cos_theta), axis=0)
                metric_val = metric[vox_idx][lobe_idx]

            metric_sum_map[vox_idx] += metric_val * norm_weight
            weight_map[vox_idx] += norm_weight

    return metric_sum_map, weight_map
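

# A minimal usage sketch (not part of scilpy) showing how the functions above
# could be called on a bundle and on Bingham volumes fitted beforehand. The
# file names, loading steps and parameter values are assumptions for
# illustration only.
if __name__ == '__main__':
    import nibabel as nib
    from dipy.io.streamline import load_tractogram

    # Hypothetical inputs: a bundle, the fitted Bingham coefficients
    # (X, Y, Z, N_LOBES, NB_PARAMS) and a per-lobe metric (X, Y, Z, N_LOBES).
    sft = load_tractogram('bundle.trk', 'reference.nii.gz')
    bingham_img = nib.load('bingham_coeffs.nii.gz')
    metric_img = nib.load('fiber_density.nii.gz')

    mean_map = bingham_metric_map_along_streamlines(
        sft, bingham_img.get_fdata(), metric_img.get_fdata(),
        max_theta=45., length_weighting=True)

    # Save the per-voxel mean of the metric along the bundle.
    nib.save(nib.Nifti1Image(mean_map.astype(np.float32), metric_img.affine),
             'mean_fiber_density_along_bundle.nii.gz')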