# Source code for scilpy.reconst.sh

# -*- coding: utf-8 -*-
import itertools
import logging
import multiprocessing
import numpy as np

from dipy.core.sphere import Sphere
from dipy.direction.peaks import peak_directions
from dipy.reconst.odf import gfa
from dipy.reconst.shm import (sh_to_sf_matrix, order_from_ncoef, sf_to_sh,
                              sph_harm_ind_list)

from scilpy.gradients.bvec_bval_tools import (identify_shells,
                                              is_normalized_bvecs,
                                              normalize_bvecs,
                                              DEFAULT_B0_THRESHOLD)
from scilpy.dwi.operations import compute_dwi_attenuation


def verify_data_vs_sh_order(data, sh_order):
    """
    Emit a warning when the number of DWI volumes is too small for a stable
    SH fit of the requested order.

    Parameters
    ----------
    data: np.ndarray
        Diffusion signal as weighted images (4D).
    sh_order: int
        SH order to fit, by default 4.
    """
    # Number of coefficients of a symmetric SH basis of this order.
    required = (sh_order + 1) * (sh_order + 2) / 2
    if data.shape[-1] < required:
        logging.warning(
            'We recommend having at least {} unique DWIs volumes, but you '
            'currently have {} volumes. Try lowering the parameter --sh_order '
            'in case of non convergence.'.format(required, data.shape[-1]))
def compute_sh_coefficients(dwi, gradient_table,
                            b0_threshold=DEFAULT_B0_THRESHOLD, sh_order=4,
                            basis_type='descoteaux07', smooth=0.006,
                            use_attenuation=False, mask=None, sphere=None,
                            is_legacy=True):
    """Fit a diffusion signal with spherical harmonics coefficients.
    Data must come from a single shell acquisition.

    Parameters
    ----------
    dwi : nib.Nifti1Image object
        Diffusion signal as weighted images (4D).
        NOTE(review): the body indexes `dwi` directly like a np.ndarray;
        confirm with callers whether an image object or an array is expected.
    gradient_table : GradientTable
        Dipy object that contains all bvals and bvecs.
    b0_threshold: float
        Threshold for the b0 values. Used to validate that the data contains
        single shell signal.
    sh_order : int, optional
        SH order to fit, by default 4.
    basis_type: str
        Either 'tournier07' or 'descoteaux07'
    smooth : float, optional
        Lambda-regularization coefficient in the SH fit, by default 0.006.
    use_attenuation: bool, optional
        If true, we will use DWI attenuation. [False]
    mask: nib.Nifti1Image object, optional
        Binary mask. Only data inside the mask will be used for computations
        and reconstruction.
    sphere: Sphere
        Dipy object. If not provided, will use Sphere(xyz=bvecs).
    is_legacy : bool, optional
        Whether or not the SH basis is in its legacy form.

    Returns
    -------
    sh_coeffs : np.ndarray with shape (X, Y, Z, #coeffs)
        Spherical harmonics coefficients at every voxel. The actual number
        of coefficients depends on `sh_order`.
    """
    bvals = gradient_table.bvals
    bvecs = gradient_table.bvecs
    b0_mask = gradient_table.b0s_mask

    # Sanity check on the b-vectors.
    if not is_normalized_bvecs(bvecs):
        logging.warning("Your b-vectors do not seem normalized...")
        bvecs = normalize_bvecs(bvecs)

    # Single-shell check: exactly one b0 shell plus one DWI shell.
    shell_values, _ = identify_shells(bvals)
    shell_values.sort()
    if shell_values.shape[0] != 2 or shell_values[0] > b0_threshold:
        raise ValueError("Can only work on single shell signals.")

    # Keep only the non-b0 volumes and their directions.
    dwi_volumes = np.logical_not(b0_mask)
    bvecs = bvecs[dwi_volumes]
    weights = dwi[..., dwi_volumes]

    # Optionally normalize the signal by the mean b0.
    if use_attenuation:
        b0 = dwi[..., b0_mask].mean(axis=3)
        weights = compute_dwi_attenuation(weights, b0)

    # Default sphere: the acquisition directions themselves.
    if sphere is None:
        sphere = Sphere(xyz=bvecs)

    sh = sf_to_sh(weights, sphere, sh_order, basis_type, smooth=smooth,
                  legacy=is_legacy)

    if mask is not None:
        sh *= mask[..., None]

    return sh
def compute_rish(sh, mask=None, full_basis=False):
    """Compute the RISH (Rotationally Invariant Spherical Harmonics) features
    of the SH signal [1]. Each RISH feature map is the total energy of its
    associated order: the sum of the squared SH coefficients of that order.

    Parameters
    ----------
    sh : np.ndarray object
        Array of the SH coefficients
    mask: np.ndarray object, optional
        Binary mask. Only data inside the mask will be used for computation.
    full_basis: bool, optional
        True when coefficients are for a full SH basis.

    Returns
    -------
    rish : np.ndarray with shape (x,y,z,n_orders)
        The RISH features of the input SH, with one channel per SH order.
    orders : list(int)
        The SH order of each RISH feature in the last channel of `rish`.

    References
    ----------
    [1] Mirzaalian, Hengameh, et al. "Harmonizing diffusion MRI data across
        multiple sites and scanners." MICCAI 2015.
        https://scholar.harvard.edu/files/hengameh/files/miccai2015.pdf
    """
    # Guess SH order from the number of coefficients.
    sh_order = order_from_ncoef(sh.shape[-1], full_basis=full_basis)

    # Degree / order of every coefficient index.
    degree_ids, order_ids = sph_harm_ind_list(sh_order, full_basis=full_basis)

    if mask is not None:
        sh = sh * mask[..., None]

    # Number of coefficients per order (e.g. for order 6, sym.: [1,5,9,13]),
    # then the [start, end) slice boundaries of each order's coefficients
    # (e.g. for order 6: starts [0,1,6,15], ends [1,6,15,28]).
    step = 1 if full_basis else 2
    counts = np.bincount(order_ids)[::step]
    starts = np.concatenate([[0], np.cumsum(counts)])[:-1]
    ends = np.concatenate([starts[1:], [sh.shape[-1]]])

    # Energy per order: sum of squared coefficients over each slice.
    squared = np.square(sh)
    rish = np.stack([squared[..., lo:hi].sum(axis=-1)
                     for lo, hi in zip(starts, ends)], axis=-1)

    if mask is not None:
        rish *= mask[..., None]

    orders = sorted(np.unique(order_ids))
    return rish, orders
def _peaks_from_sh_parallel(args):
    """Worker: extract ODF peaks for one chunk of flattened SH data.

    `args` is the tuple built by peaks_from_sh; returns the chunk id with
    per-voxel peak directions, values and sphere indices (index -1 = unused
    peak slot).
    """
    (shm_coeff, B, sphere, relative_peak_threshold, absolute_threshold,
     min_separation_angle, npeaks, normalize_peaks, chunk_id,
     is_symmetric) = args

    nb_voxels = shm_coeff.shape[0]
    peak_dirs = np.zeros((nb_voxels, npeaks, 3))
    peak_values = np.zeros((nb_voxels, npeaks))
    peak_indices = np.full((nb_voxels, npeaks), -1, dtype='int')

    for vox in range(nb_voxels):
        if not shm_coeff[vox].any():
            continue

        # Project SH onto the sphere and clip small amplitudes.
        odf = np.dot(shm_coeff[vox], B)
        odf[odf < absolute_threshold] = 0.

        dirs, peaks, ind = peak_directions(odf, sphere,
                                           relative_peak_threshold,
                                           min_separation_angle,
                                           is_symmetric)
        if peaks.shape[0] != 0:
            n = min(npeaks, peaks.shape[0])
            peak_dirs[vox][:n] = dirs[:n]
            peak_indices[vox][:n] = ind[:n]
            peak_values[vox][:n] = peaks[:n]

            if normalize_peaks:
                # Scale values by the largest peak, and directions by the
                # (normalized) peak values.
                peak_values[vox][:n] /= peaks[0]
                peak_dirs[vox] *= peak_values[vox][:, None]

    return chunk_id, peak_dirs, peak_values, peak_indices
def peaks_from_sh(shm_coeff, sphere, mask=None, relative_peak_threshold=0.5,
                  absolute_threshold=0, min_separation_angle=25,
                  normalize_peaks=False, npeaks=5,
                  sh_basis_type='descoteaux07', is_legacy=True,
                  nbr_processes=None, full_basis=False, is_symmetric=True):
    """Computes peaks from given spherical harmonic coefficients

    Parameters
    ----------
    shm_coeff : np.ndarray
        Spherical harmonic coefficients
    sphere : Sphere
        The Sphere providing discrete directions for evaluation.
    mask : np.ndarray, optional
        If `mask` is provided, only the data inside the mask will be used
        for computations.
    relative_peak_threshold : float, optional
        Only return peaks greater than ``relative_peak_threshold * m`` where
        m is the largest peak. Default: 0.5
    absolute_threshold : float, optional
        Absolute threshold on fODF amplitude. This value should be set to
        approximately 1.5 to 2 times the maximum fODF amplitude in isotropic
        voxels (ex. ventricles). `scil_fodf_max_in_ventricles.py` can be used
        to find the maximal value. Default: 0
    min_separation_angle : float in [0, 90], optional
        The minimum distance between directions. If two peaks are too close
        only the larger of the two is returned. Default: 25
    normalize_peaks : bool, optional
        If true, all peak values are calculated relative to `max(odf)`.
    npeaks : int, optional
        Maximum number of peaks found (default 5 peaks).
    sh_basis_type : str, optional
        Type of spherical harmonic basis used for `shm_coeff`. Either
        `descoteaux07` or `tournier07`. Default: `descoteaux07`
    is_legacy: bool, optional
        If true, this means that the input SH used a legacy basis definition
        for backward compatibility with previous ``tournier07`` and
        ``descoteaux07`` implementations. Default: True
    nbr_processes: int, optional
        The number of subprocesses to use.
        Default: multiprocessing.cpu_count()
    full_basis: bool, optional
        If True, SH coefficients are expressed using a full basis.
        Default: False
    is_symmetric: bool, optional
        If False, antipodal sphere directions are considered distinct.
        Default: True

    Returns
    -------
    tuple of np.ndarray
        peak_dirs, peak_values, peak_indices
    """
    sh_order = order_from_ncoef(shm_coeff.shape[-1], full_basis)
    B, _ = sh_to_sf_matrix(sphere, sh_order, sh_basis_type, full_basis,
                           legacy=is_legacy)

    data_shape = shm_coeff.shape
    if mask is None:
        mask = np.sum(shm_coeff, axis=3).astype(bool)

    if nbr_processes is None or nbr_processes < 0:
        nbr_processes = multiprocessing.cpu_count()

    # Flatten the masked voxels into a (n_voxels, n_coeffs) array, then split
    # it into one chunk per process.
    nb_voxels = np.count_nonzero(mask)
    shm_coeff = shm_coeff[mask].reshape((nb_voxels, data_shape[3]))
    chunks = np.array_split(shm_coeff, nbr_processes)
    chunk_len = np.cumsum([0] + [len(c) for c in chunks])

    pool = multiprocessing.Pool(nbr_processes)
    results = pool.map(_peaks_from_sh_parallel,
                       zip(chunks,
                           itertools.repeat(B),
                           itertools.repeat(sphere),
                           itertools.repeat(relative_peak_threshold),
                           itertools.repeat(absolute_threshold),
                           itertools.repeat(min_separation_angle),
                           itertools.repeat(npeaks),
                           itertools.repeat(normalize_peaks),
                           np.arange(len(chunks)),
                           itertools.repeat(is_symmetric)))
    pool.close()
    pool.join()

    # Re-assemble the chunks into flat per-voxel arrays first (writing
    # through a fancy-indexed view would silently go to a copy), then scatter
    # them back into the full volume shape through the mask.
    tmp_dirs = np.zeros((nb_voxels, npeaks, 3))
    tmp_values = np.zeros((nb_voxels, npeaks))
    tmp_indices = np.zeros((nb_voxels, npeaks))
    for i, peak_dirs, peak_values, peak_indices in results:
        tmp_dirs[chunk_len[i]:chunk_len[i + 1], :, :] = peak_dirs
        tmp_values[chunk_len[i]:chunk_len[i + 1], :] = peak_values
        tmp_indices[chunk_len[i]:chunk_len[i + 1], :] = peak_indices

    peak_dirs_array = np.zeros(data_shape[0:3] + (npeaks, 3))
    peak_values_array = np.zeros(data_shape[0:3] + (npeaks,))
    peak_indices_array = np.zeros(data_shape[0:3] + (npeaks,))
    peak_dirs_array[mask] = tmp_dirs
    peak_values_array[mask] = tmp_values
    peak_indices_array[mask] = tmp_indices

    return peak_dirs_array, peak_values_array, peak_indices_array
def _maps_from_sh_parallel(args): shm_coeff = args[0] _ = args[1] peak_values = args[2] peak_indices = args[3] B = args[4] sphere = args[5] gfa_thr = args[6] chunk_id = args[7] data_shape = shm_coeff.shape[0] nufo_map = np.zeros(data_shape) afd_max = np.zeros(data_shape) afd_sum = np.zeros(data_shape) rgb_map = np.zeros((data_shape, 3)) gfa_map = np.zeros(data_shape) qa_map = np.zeros((data_shape, peak_values.shape[1])) max_odf = 0 global_max = -np.inf for idx in range(len(shm_coeff)): if shm_coeff[idx].any(): odf = np.dot(shm_coeff[idx], B) odf = odf.clip(min=0) sum_odf = np.sum(odf) max_odf = np.maximum(max_odf, sum_odf) if sum_odf > 0: rgb_map[idx] = np.dot(np.abs(sphere.vertices).T, odf) rgb_map[idx] /= np.linalg.norm(rgb_map[idx]) rgb_map[idx] *= sum_odf gfa_map[idx] = gfa(odf) if gfa_map[idx] < gfa_thr: global_max = max(global_max, odf.max()) elif np.sum(peak_indices[idx] > -1): nufo_map[idx] = np.sum(peak_indices[idx] > -1) afd_max[idx] = peak_values[idx].max() afd_sum[idx] = np.sqrt(np.dot(shm_coeff[idx], shm_coeff[idx])) qa_map = peak_values[idx] - odf.min() global_max = max(global_max, peak_values[idx][0]) return chunk_id, nufo_map, afd_max, afd_sum, rgb_map, \ gfa_map, qa_map, max_odf, global_max
def maps_from_sh(shm_coeff, peak_dirs, peak_values, peak_indices, sphere,
                 mask=None, gfa_thr=0, sh_basis_type='descoteaux07',
                 nbr_processes=None):
    """Computes maps from given SH coefficients and peaks

    Parameters
    ----------
    shm_coeff : np.ndarray
        Spherical harmonic coefficients
    peak_dirs : np.ndarray
        Peak directions
    peak_values : np.ndarray
        Peak values
    peak_indices : np.ndarray
        Peak indices
    sphere : Sphere
        The Sphere providing discrete directions for evaluation.
    mask : np.ndarray, optional
        If `mask` is provided, only the data inside the mask will be used
        for computations.
    gfa_thr : float, optional
        Voxels with gfa less than `gfa_thr` are skipped for all metrics,
        except `rgb_map`. Default: 0
    sh_basis_type : str, optional
        Type of spherical harmonic basis used for `shm_coeff`. Either
        `descoteaux07` or `tournier07`. Default: `descoteaux07`
    nbr_processes: int, optional
        The number of subprocesses to use.
        Default: multiprocessing.cpu_count()

    Returns
    -------
    tuple of np.ndarray
        nufo_map, afd_max, afd_sum, rgb_map, gfa, qa
    """
    sh_order = order_from_ncoef(shm_coeff.shape[-1])
    B, _ = sh_to_sf_matrix(sphere, sh_order, sh_basis_type)

    data_shape = shm_coeff.shape
    if mask is None:
        mask = np.sum(shm_coeff, axis=3).astype(bool)

    nbr_processes = multiprocessing.cpu_count() \
        if nbr_processes is None or nbr_processes < 0 \
        else nbr_processes

    npeaks = peak_values.shape[3]
    # Ravel the first 3 dimensions while keeping the 4th intact, like a list
    # of 1D time series voxels. Then separate it in chunks of
    # len(nbr_processes).
    shm_coeff = shm_coeff[mask].reshape(
        (np.count_nonzero(mask), data_shape[3]))
    peak_dirs = peak_dirs[mask].reshape((np.count_nonzero(mask), npeaks, 3))
    peak_values = peak_values[mask].reshape((np.count_nonzero(mask), npeaks))
    peak_indices = peak_indices[mask].reshape((np.count_nonzero(mask),
                                               npeaks))
    shm_coeff_chunks = np.array_split(shm_coeff, nbr_processes)
    peak_dirs_chunks = np.array_split(peak_dirs, nbr_processes)
    peak_values_chunks = np.array_split(peak_values, nbr_processes)
    peak_indices_chunks = np.array_split(peak_indices, nbr_processes)
    chunk_len = np.cumsum([0] + [len(c) for c in shm_coeff_chunks])

    pool = multiprocessing.Pool(nbr_processes)
    results = pool.map(_maps_from_sh_parallel,
                       zip(shm_coeff_chunks,
                           peak_dirs_chunks,
                           peak_values_chunks,
                           peak_indices_chunks,
                           itertools.repeat(B),
                           itertools.repeat(sphere),
                           itertools.repeat(gfa_thr),
                           np.arange(len(shm_coeff_chunks))))
    pool.close()
    pool.join()

    # Re-assemble the chunk together in the original shape.
    nufo_map_array = np.zeros(data_shape[0:3])
    afd_max_array = np.zeros(data_shape[0:3])
    afd_sum_array = np.zeros(data_shape[0:3])
    rgb_map_array = np.zeros(data_shape[0:3] + (3,))
    gfa_map_array = np.zeros(data_shape[0:3])
    qa_map_array = np.zeros(data_shape[0:3] + (npeaks,))

    # tmp arrays are necessary to avoid inserting data in the returned
    # variable rather than the original array
    tmp_nufo_map_array = np.zeros((np.count_nonzero(mask)))
    tmp_afd_max_array = np.zeros((np.count_nonzero(mask)))
    tmp_afd_sum_array = np.zeros((np.count_nonzero(mask)))
    tmp_rgb_map_array = np.zeros((np.count_nonzero(mask), 3))
    tmp_gfa_map_array = np.zeros((np.count_nonzero(mask)))
    tmp_qa_map_array = np.zeros((np.count_nonzero(mask), npeaks))

    all_time_max_odf = -np.inf
    all_time_global_max = -np.inf
    for (i, nufo_map, afd_max, afd_sum, rgb_map, gfa_map, qa_map,
         max_odf, global_max) in results:
        # BUG FIX: was `max(all_time_global_max, max_odf)`, which accumulated
        # the ODF maximum into the wrong running variable and corrupted the
        # rgb normalization below.
        all_time_max_odf = max(all_time_max_odf, max_odf)
        all_time_global_max = max(all_time_global_max, global_max)

        tmp_nufo_map_array[chunk_len[i]:chunk_len[i+1]] = nufo_map
        tmp_afd_max_array[chunk_len[i]:chunk_len[i+1]] = afd_max
        tmp_afd_sum_array[chunk_len[i]:chunk_len[i+1]] = afd_sum
        tmp_rgb_map_array[chunk_len[i]:chunk_len[i+1], :] = rgb_map
        tmp_gfa_map_array[chunk_len[i]:chunk_len[i+1]] = gfa_map
        tmp_qa_map_array[chunk_len[i]:chunk_len[i+1], :] = qa_map

    nufo_map_array[mask] = tmp_nufo_map_array
    afd_max_array[mask] = tmp_afd_max_array
    afd_sum_array[mask] = tmp_afd_sum_array
    rgb_map_array[mask] = tmp_rgb_map_array
    gfa_map_array[mask] = tmp_gfa_map_array
    qa_map_array[mask] = tmp_qa_map_array

    rgb_map_array /= all_time_max_odf
    rgb_map_array *= 255
    qa_map_array /= all_time_global_max

    afd_unique = np.unique(afd_max_array)
    if np.array_equal(np.array([0, 1]), afd_unique) \
            or np.array_equal(np.array([1]), afd_unique):
        logging.warning('All AFD_max values are 1. The peaks seem normalized.')

    return (nufo_map_array, afd_max_array, afd_sum_array,
            rgb_map_array, gfa_map_array, qa_map_array)
def _convert_sh_basis_parallel(args): sh = args[0] B_in = args[1] invB_out = args[2] chunk_id = args[3] for idx in range(sh.shape[0]): if sh[idx].any(): sf = np.dot(sh[idx], B_in) sh[idx] = np.dot(sf, invB_out) return chunk_id, sh
def convert_sh_basis(shm_coeff, sphere, mask=None,
                     input_basis='descoteaux07', output_basis='tournier07',
                     is_input_legacy=True, is_output_legacy=False,
                     nbr_processes=None):
    """Converts spherical harmonic coefficients between two bases

    Parameters
    ----------
    shm_coeff : np.ndarray
        Spherical harmonic coefficients
    sphere : Sphere
        The Sphere providing discrete directions for evaluation.
    mask : np.ndarray, optional
        If `mask` is provided, only the data inside the mask will be used
        for computations.
    input_basis : str, optional
        Type of spherical harmonic basis used for `shm_coeff`. Either
        `descoteaux07` or `tournier07`. Default: `descoteaux07`
    output_basis : str, optional
        Type of spherical harmonic basis wanted as output. Either
        `descoteaux07` or `tournier07`. Default: `tournier07`
    is_input_legacy: bool, optional
        If true, this means that the input SH used a legacy basis definition
        for backward compatibility with previous ``tournier07`` and
        ``descoteaux07`` implementations. Default: True
    is_output_legacy: bool, optional
        If true, this means that the output SH will use a legacy basis
        definition for backward compatibility with previous ``tournier07``
        and ``descoteaux07`` implementations. Default: False
    nbr_processes: int, optional
        The number of subprocesses to use.
        Default: multiprocessing.cpu_count()

    Returns
    -------
    shm_coeff_array : np.ndarray
        Spherical harmonic coefficients in the desired basis.
    """
    # Nothing to do when both bases (and legacy conventions) already match.
    if input_basis == output_basis and is_input_legacy == is_output_legacy:
        logging.info('Input and output SH basis are equal, no SH basis '
                     'convertion needed.')
        return shm_coeff

    sh_order = order_from_ncoef(shm_coeff.shape[-1])
    B_in, _ = sh_to_sf_matrix(sphere, sh_order, input_basis,
                              legacy=is_input_legacy)
    _, invB_out = sh_to_sf_matrix(sphere, sh_order, output_basis,
                                  legacy=is_output_legacy)

    data_shape = shm_coeff.shape
    if mask is None:
        mask = np.sum(shm_coeff, axis=3).astype(bool)

    if nbr_processes is None or nbr_processes < 0:
        nbr_processes = multiprocessing.cpu_count()

    # Flatten the masked voxels into a (n_voxels, n_coeffs) array, then split
    # it into one chunk per process.
    nb_voxels = np.count_nonzero(mask)
    shm_coeff = shm_coeff[mask].reshape((nb_voxels, data_shape[3]))
    shm_coeff_chunks = np.array_split(shm_coeff, nbr_processes)
    chunk_len = np.cumsum([0] + [len(c) for c in shm_coeff_chunks])

    pool = multiprocessing.Pool(nbr_processes)
    results = pool.map(_convert_sh_basis_parallel,
                       zip(shm_coeff_chunks,
                           itertools.repeat(B_in),
                           itertools.repeat(invB_out),
                           np.arange(len(shm_coeff_chunks))))
    pool.close()
    pool.join()

    # Re-assemble the chunks into a flat array first (fancy-indexed writes
    # go to a copy), then scatter back into the volume through the mask.
    tmp_shm_coeff_array = np.zeros((nb_voxels, data_shape[3]))
    for i, new_shm_coeff in results:
        tmp_shm_coeff_array[chunk_len[i]:chunk_len[i + 1], :] = new_shm_coeff

    shm_coeff_array = np.zeros(data_shape)
    shm_coeff_array[mask] = tmp_shm_coeff_array

    return shm_coeff_array
def _convert_sh_to_sf_parallel(args): sh = args[0] B_in = args[1] new_output_dim = args[2] chunk_id = args[3] sf = np.zeros((sh.shape[0], new_output_dim), dtype=np.float32) for idx in range(sh.shape[0]): if sh[idx].any(): sf[idx] = np.dot(sh[idx], B_in) return chunk_id, sf
def convert_sh_to_sf(shm_coeff, sphere, mask=None, dtype="float32",
                     input_basis='descoteaux07', input_full_basis=False,
                     is_input_legacy=True, nbr_processes=None):
    """Converts spherical harmonic coefficients to an SF sphere

    Parameters
    ----------
    shm_coeff : np.ndarray
        Spherical harmonic coefficients
    sphere : Sphere
        The Sphere providing discrete directions for evaluation.
    mask : np.ndarray, optional
        If `mask` is provided, only the data inside the mask will be used
        for computations.
    dtype : str
        Datatype to use for computation and output array. Either `float32`
        or `float64`. Default: `float32`
    input_basis : str, optional
        Type of spherical harmonic basis used for `shm_coeff`. Either
        `descoteaux07` or `tournier07`. Default: `descoteaux07`
    input_full_basis : bool, optional
        If True, use a full SH basis (even and odd orders) for the input SH
        coefficients.
    is_input_legacy : bool, optional
        Whether the input basis is in its legacy form.
    nbr_processes: int, optional
        The number of subprocesses to use.
        Default: multiprocessing.cpu_count()

    Returns
    -------
    sf_array : np.ndarray
        Spherical function values on the sphere's directions, one value per
        sphere vertex in the last axis.
    """
    assert dtype in ["float32", "float64"], "Only `float32` and `float64` " \
        "should be used."

    sh_order = order_from_ncoef(shm_coeff.shape[-1],
                                full_basis=input_full_basis)
    B_in, _ = sh_to_sf_matrix(sphere, sh_order, basis_type=input_basis,
                              full_basis=input_full_basis,
                              legacy=is_input_legacy)
    B_in = B_in.astype(dtype)

    data_shape = shm_coeff.shape
    if mask is None:
        mask = np.sum(shm_coeff, axis=3).astype(bool)

    # FIX: the default used to be `nbr_processes=multiprocessing.cpu_count()`
    # in the signature, evaluated once at import time, and an explicit
    # None/negative value was not handled (np.array_split would fail).
    # Resolve it at call time, consistently with the other functions here.
    if nbr_processes is None or nbr_processes < 0:
        nbr_processes = multiprocessing.cpu_count()

    # Ravel the first 3 dimensions while keeping the 4th intact, like a list
    # of 1D time series voxels. Then separate it in chunks of
    # len(nbr_processes).
    shm_coeff = shm_coeff[mask].reshape(
        (np.count_nonzero(mask), data_shape[3]))
    shm_coeff_chunks = np.array_split(shm_coeff, nbr_processes)
    chunk_len = np.cumsum([0] + [len(c) for c in shm_coeff_chunks])

    pool = multiprocessing.Pool(nbr_processes)
    results = pool.map(_convert_sh_to_sf_parallel,
                       zip(shm_coeff_chunks,
                           itertools.repeat(B_in),
                           itertools.repeat(len(sphere.vertices)),
                           np.arange(len(shm_coeff_chunks))))
    pool.close()
    pool.join()

    # Re-assemble the chunk together in the original shape.
    new_shape = data_shape[:3] + (len(sphere.vertices),)
    sf_array = np.zeros(new_shape, dtype=dtype)
    # tmp array is necessary to avoid inserting data in the returned
    # variable rather than the original array
    tmp_sf_array = np.zeros((np.count_nonzero(mask), new_shape[3]),
                            dtype=dtype)
    for i, new_sf in results:
        tmp_sf_array[chunk_len[i]:chunk_len[i + 1], :] = new_sf

    sf_array[mask] = tmp_sf_array

    return sf_array