Source code for scilpy.reconst.fodf

# -*- coding: utf-8 -*-
import itertools
import logging
import multiprocessing
import numpy as np

from dipy.data import get_sphere
from dipy.reconst.mcsd import MSDeconvFit
from dipy.reconst.multi_voxel import MultiVoxelFit
from dipy.reconst.shm import sh_to_sf_matrix

from scilpy.reconst.utils import find_order_from_nb_coeff

from dipy.utils.optpkg import optional_package
cvx, have_cvxpy, _ = optional_package("cvxpy")


def get_ventricles_max_fodf(data, fa, md, zoom, sh_basis, fa_threshold,
                            md_threshold, small_dims=False, is_legacy=True):
    """
    Compute the mean maximal fODF value in the ventricles. Given heuristic
    thresholds on FA and MD values, finds the voxels of the ventricles or CSF
    and computes a mean fODF value. This is described in Dell'Acqua et al.,
    HBM 2013.

    Ventricles are searched in a window in the middle of the data to increase
    speed. No need to scan the whole image.

    Parameters
    ----------
    data: ndarray (x, y, z, ncoeffs)
        Input fODF file in spherical harmonics coefficients. Uses sphere
        'repulsion100' to convert to SF values.
    fa: ndarray (x, y, z)
        FA (Fractional Anisotropy) volume from DTI.
    md: ndarray (x, y, z)
        MD (Mean Diffusivity) volume from DTI.
    zoom: list of length 3
        The resolution. A total number of voxels of 1000 works well at
        2x2x2 = 8 mm3.
    sh_basis: str
        Either 'tournier07' or 'descoteaux07'.
    fa_threshold: float
        Maximal threshold of FA (voxels under that threshold are considered
        for evaluation). Suggested value: 0.1.
    md_threshold: float
        Minimal threshold of MD in mm2/s (voxels above that threshold are
        considered for evaluation). Suggested value: 0.003.
    small_dims: bool
        If set, takes the full range of data to search the max fODF amplitude
        in ventricles, rather than a window centered in the data. Useful when
        the data has small dimensions.
    is_legacy : bool, optional
        Whether the SH basis is in its legacy form.

    Returns
    -------
    mean, mask: float, ndarray (x, y, z)
        Mean maximum fODF value and mask of voxels used.
    """
    order = find_order_from_nb_coeff(data)
    sphere = get_sphere('repulsion100')
    b_matrix, _ = sh_to_sf_matrix(sphere, order, sh_basis, legacy=is_legacy)
    mask = np.zeros(data.shape[:-1])

    # 1000 voxels works well at 2x2x2 = 8 mm3.
    # Hence, we multiply by the volume of a voxel.
    vol = (zoom[0] * zoom[1] * zoom[2])
    if vol != 0:
        max_number_of_voxels = 1000 * 8 // vol
    else:
        max_number_of_voxels = 1000

    logging.debug("Searching for ventricle voxels, up to a maximum of {} "
                  "voxels.".format(max_number_of_voxels))

    # In the case of 2D-like data (3D data with one dimension of size 1), or
    # a small 3D dataset, the full range of data is scanned.
    if small_dims:
        all_i = list(range(0, data.shape[0]))
        all_j = list(range(0, data.shape[1]))
        all_k = list(range(0, data.shape[2]))
    # In the case of a normal 3D dataset, a window is created in the middle of
    # the image to capture the ventricles. No need to scan the whole image.
    # (Automatic definition of the window's radius based on the shape of the
    # data.)
    else:
        if np.min(data.shape[:-1]) > 40:
            radius = 20
        elif np.min(data.shape[:-1]) > 20:
            radius = 10
        else:
            radius = 5

        all_i = list(range(int(data.shape[0] / 2) - radius,
                           int(data.shape[0] / 2) + radius))
        all_j = list(range(int(data.shape[1] / 2) - radius,
                           int(data.shape[1] / 2) + radius))
        all_k = list(range(int(data.shape[2] / 2) - radius,
                           int(data.shape[2] / 2) + radius))

    # Ok. Now find ventricle voxels.
    sum_of_max = 0
    count = 0
    for i in all_i:
        for j in all_j:
            for k in all_k:
                if count > max_number_of_voxels - 1:
                    continue
                if fa[i, j, k] < fa_threshold \
                        and md[i, j, k] > md_threshold:
                    sf = np.dot(data[i, j, k], b_matrix)
                    sum_of_max += sf.max()
                    count += 1
                    mask[i, j, k] = 1

    logging.info('Number of voxels detected: {}'.format(count))
    if count == 0:
        logging.warning('No voxels found for evaluation! Change your fa '
                        'and/or md thresholds')
        return 0, mask

    logging.info('Average max fodf value: {}'.format(sum_of_max / count))
    return sum_of_max / count, mask
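

# Example usage sketch for get_ventricles_max_fodf (not part of the original
# module): the file names below and the choice of SH basis are assumptions,
# shown only to illustrate how the inputs fit together.
def _example_get_ventricles_max_fodf():
    import nibabel as nib

    fodf_img = nib.load('fodf.nii.gz')            # hypothetical path
    fa = nib.load('fa.nii.gz').get_fdata()        # hypothetical path
    md = nib.load('md.nii.gz').get_fdata()        # hypothetical path

    # Voxel size (zoom) is taken from the fODF header; the thresholds follow
    # the suggested values from the docstring above.
    zoom = fodf_img.header.get_zooms()[:3]
    mean_max, mask = get_ventricles_max_fodf(
        fodf_img.get_fdata(), fa, md, zoom, 'descoteaux07',
        fa_threshold=0.1, md_threshold=0.003)
    return mean_max, mask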


def _fit_from_model_parallel(args):
    # Fits the model on one chunk of voxel data (2D array of voxels x values)
    # and returns (chunk_id, array of per-voxel fits).
    model = args[0]
    data = args[1]
    chunk_id = args[2]

    sub_fit_array = np.zeros((data.shape[0],), dtype='object')
    for i in range(data.shape[0]):
        if data[i].any():
            try:
                sub_fit_array[i] = model.fit(data[i])
            except cvx.error.SolverError:
                # If the convex solver fails for this voxel, store a fit
                # filled with NaN coefficients; these are cleaned up later by
                # verify_failed_voxels_shm_coeff.
                coeff = np.full((len(model.n)), np.nan)
                sub_fit_array[i] = MSDeconvFit(model, coeff, None)

    return chunk_id, sub_fit_array


def fit_from_model(model, data, mask=None, nbr_processes=None):
    """Fit the model to data. Can use parallel processing.

    Parameters
    ----------
    model : a model instance
        It will be used to fit the data.
        e.g. an instance of dipy.reconst.shm.SphHarmModel.
    data : np.ndarray (4d)
        Diffusion data.
    mask : np.ndarray, optional
        If `mask` is provided, only the data inside the mask will be used for
        computations.
    nbr_processes : int, optional
        The number of subprocesses to use.
        Default: multiprocessing.cpu_count()

    Returns
    -------
    fit_array : MultiVoxelFit
        Dipy's MultiVoxelFit, containing the fit. It contains an array of
        fits. Any attribute of its individual fits (of the class returned by
        'model.fit') can be accessed through the MultiVoxelFit to get all fits
        at once.
    """
    data_shape = data.shape
    if mask is None:
        mask = np.sum(data, axis=3).astype(bool)
    else:
        mask_any = np.sum(data, axis=3).astype(bool)
        mask *= mask_any

    nbr_processes = multiprocessing.cpu_count() \
        if nbr_processes is None or nbr_processes <= 0 \
        else nbr_processes

    # Ravel the first 3 dimensions while keeping the 4th intact, like a list
    # of 1D time series voxels. Then split it into nbr_processes chunks.
    data = data[mask].reshape((np.count_nonzero(mask), data_shape[3]))
    chunks = np.array_split(data, nbr_processes)
    chunk_len = np.cumsum([0] + [len(c) for c in chunks])

    pool = multiprocessing.Pool(nbr_processes)
    results = pool.map(_fit_from_model_parallel,
                       zip(itertools.repeat(model),
                           chunks,
                           np.arange(len(chunks))))
    pool.close()
    pool.join()

    # Re-assemble the chunks together in the original shape.
    fit_array = np.zeros(data_shape[0:3], dtype='object')
    tmp_fit_array = np.zeros((np.count_nonzero(mask)), dtype='object')
    for i, fit in results:
        tmp_fit_array[chunk_len[i]:chunk_len[i + 1]] = fit

    fit_array[mask] = tmp_fit_array
    fit_array = MultiVoxelFit(model, fit_array, mask)

    return fit_array
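

# Example usage sketch for fit_from_model (not part of the original module):
# 'dwi.nii.gz', 'dwi.bval' and 'dwi.bvec' are hypothetical inputs and the
# response values are placeholders; any dipy model exposing .fit() on a single
# voxel should work the same way.
def _example_fit_from_model():
    import nibabel as nib
    from dipy.core.gradients import gradient_table
    from dipy.io.gradients import read_bvals_bvecs
    from dipy.reconst.csdeconv import ConstrainedSphericalDeconvModel

    data = nib.load('dwi.nii.gz').get_fdata()                  # hypothetical
    bvals, bvecs = read_bvals_bvecs('dwi.bval', 'dwi.bvec')    # hypothetical
    gtab = gradient_table(bvals, bvecs)

    # Placeholder single-fiber response: (eigenvalues in mm2/s, mean b0).
    response = (np.array([1.7e-3, 0.4e-3, 0.4e-3]), 180.)
    model = ConstrainedSphericalDeconvModel(gtab, response)

    # The returned MultiVoxelFit exposes the per-voxel fits' attributes, e.g.
    # fit.shm_coeff gives the SH coefficients for every voxel at once.
    fit = fit_from_model(model, data, mask=None, nbr_processes=1)
    return fit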


def verify_failed_voxels_shm_coeff(shm_coeff):
    """
    Verifies if there are any NaN in the final coefficients, and if so raises
    warnings.

    Parameters
    ----------
    shm_coeff: np.ndarray
        The shm_coeff given by dipy's fit classes. Of shape (x, y, z, n) with
        the coefficients on the last dimension.

    Returns
    -------
    shm_coeff: np.ndarray
        The coefficients with 0 instead of NaNs.
    """
    nan_count = len(np.argwhere(np.isnan(shm_coeff[..., 0])))
    voxel_count = np.prod(shm_coeff.shape[:-1])

    if nan_count / voxel_count >= 0.05:
        logging.warning(
            "There are {} voxels out of {} that could not be solved by the "
            "solver, reaching a critical amount of voxels. Make sure to tune "
            "the response functions properly, as the solving process is very "
            "sensitive to it. Proceeding to fill the problematic voxels by 0."
            .format(nan_count, voxel_count))
    elif nan_count > 0:
        logging.warning(
            "There are {} voxels out of {} that could not be solved by the "
            "solver. Make sure to tune the response functions properly, as "
            "the solving process is very sensitive to it. Proceeding to fill "
            "the problematic voxels by 0.".format(nan_count, voxel_count))

    shm_coeff = np.where(np.isnan(shm_coeff), 0, shm_coeff)
    return shm_coeff
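

# Example usage sketch for verify_failed_voxels_shm_coeff (not part of the
# original module): 'fit' is assumed to be the MultiVoxelFit returned by
# fit_from_model above, whose failed voxels were filled with NaN by
# _fit_from_model_parallel.
def _example_verify_failed_voxels(fit):
    # Collect the SH coefficients of every voxel, warn if some voxels could
    # not be solved, and replace their NaN coefficients by 0.
    shm_coeff = verify_failed_voxels_shm_coeff(fit.shm_coeff)
    return shm_coeff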


def verify_frf_files(wm_frf, gm_frf, csf_frf):
    """
    Verifies that all three frf files contain four columns, else raises
    ValueErrors.

    Parameters
    ----------
    wm_frf: np.ndarray
        The frf directly as loaded from its text file.
    gm_frf: np.ndarray
        Idem
    csf_frf: np.ndarray
        Idem

    Returns
    -------
    wm_frf: np.ndarray
        The file. In the case where there was only one line in the file, and
        it has been loaded as a vector, we return the array formatted as 2D,
        with the 4 frf values as columns.
    gm_frf: np.ndarray
        Idem
    csf_frf: np.ndarray
        Idem
    """
    if len(wm_frf.shape) == 1:
        wm_frf = wm_frf[None, :]
    if len(gm_frf.shape) == 1:
        gm_frf = gm_frf[None, :]
    if len(csf_frf.shape) == 1:
        csf_frf = csf_frf[None, :]

    if not wm_frf.shape[1] == 4:
        raise ValueError('WM frf file did not contain 4 elements. '
                         'Invalid or deprecated FRF format')
    if not gm_frf.shape[1] == 4:
        raise ValueError('GM frf file did not contain 4 elements. '
                         'Invalid or deprecated FRF format')
    if not csf_frf.shape[1] == 4:
        raise ValueError('CSF frf file did not contain 4 elements. '
                         'Invalid or deprecated FRF format')

    return wm_frf, gm_frf, csf_frf
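

# Example usage sketch for verify_frf_files (not part of the original module):
# the frf text file names are hypothetical; each file is expected to hold one
# or more rows of 4 values (typically the 3 response eigenvalues and the mean
# b0 signal).
def _example_verify_frf_files():
    wm_frf = np.loadtxt('wm_frf.txt')      # hypothetical path
    gm_frf = np.loadtxt('gm_frf.txt')      # hypothetical path
    csf_frf = np.loadtxt('csf_frf.txt')    # hypothetical path

    # Single-line files load as 1D vectors; verify_frf_files promotes them to
    # 2D arrays and checks that each row has exactly 4 columns.
    wm_frf, gm_frf, csf_frf = verify_frf_files(wm_frf, gm_frf, csf_frf)
    return wm_frf, gm_frf, csf_frf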