# -*- coding: utf-8 -*-
import numpy as np
import logging
from dipy.reconst.shm import sh_to_sf_matrix
from dipy.data import get_sphere
from dipy.core.sphere import Sphere
from scipy.ndimage import correlate
from itertools import product as iterprod
from scilpy.gpuparallel.opencl_utils import have_opencl, CLKernel, CLManager
[docs]
def unified_filtering(sh_data, sh_order, sh_basis, is_legacy, full_basis,
                      sphere_str, sigma_spatial=1.0, sigma_align=0.8,
                      sigma_angle=None, rel_sigma_range=0.2,
                      win_hwidth=None, exclude_center=False,
                      device_type='gpu', use_opencl=True, patch_size=40):
    """
    Unified asymmetric filtering as described in [1].

    Parameters
    ----------
    sh_data: ndarray
        SH coefficients image.
    sh_order: int
        Maximum order of spherical harmonics (SH) basis.
    sh_basis: str
        SH basis definition used for input and output SH image.
        One of 'descoteaux07' or 'tournier07'.
    is_legacy: bool
        Whether the legacy SH basis definition should be used.
    full_basis: bool
        Whether the input SH basis is full or not.
    sphere_str: str
        Name of the DIPY sphere to use for SH to SF projection.
    sigma_spatial: float or None
        Standard deviation of spatial filter. Can be None to replace
        by mean filter, in which case win_hwidth must be given.
    sigma_align: float or None
        Standard deviation of alignment filter. `None` disables
        alignment filtering.
    sigma_angle: float or None
        Standard deviation of the angle filter. `None` disables
        angle filtering.
    rel_sigma_range: float or None
        Standard deviation of the range filter, relative to the
        range of SF amplitudes. `None` disables range filtering.
    win_hwidth: int, optional
        Half-width of the filtering window. When None, the half-width is
        round(3*sigma_spatial), i.e. the window width is
        2*round(3*sigma_spatial) + 1. When given, it takes precedence over
        the value derived from sigma_spatial.
    exclude_center: bool, optional
        Assign a weight of 0 to the center voxel of the filter.
    device_type: string, optional
        Device on which the code should run. Choices are cpu or gpu.
        'gpu' requires use_opencl=True.
    use_opencl: bool, optional
        Use OpenCL for software acceleration.
    patch_size: int, optional
        Patch size for OpenCL execution.

    Returns
    -------
    out_sh: ndarray
        Filtered image as SH coefficients in the full SH basis.

    Raises
    ------
    ValueError
        If the parameters are invalid or inconsistent.

    References
    ----------
    [1] Poirier and Descoteaux, 2024, "A Unified Filtering Method for
        Estimating Asymmetric Orientation Distribution Functions",
        Neuroimage, https://doi.org/10.1016/j.neuroimage.2024.120516
    """
    # Validate every parameter before any expensive work (loading the
    # sphere, building the filters, projecting the data).
    if sigma_spatial is None and win_hwidth is None:
        raise ValueError('sigma_spatial and win_hwidth cannot both be None')
    if device_type not in ['cpu', 'gpu']:
        raise ValueError('Invalid device type {}. Must be cpu or gpu'
                         .format(device_type))
    if use_opencl and not have_opencl:
        raise ValueError('pyopencl is not installed. Please install before'
                         ' using option use_opencl=True.')
    if device_type == 'gpu' and not use_opencl:
        raise ValueError('Option use_opencl must be enabled '
                         'to use device \'gpu\'.')
    if sigma_spatial is not None and sigma_spatial <= 0.0:
        raise ValueError('sigma_spatial cannot be <= 0.')
    if sigma_align is not None and sigma_align <= 0.0:
        raise ValueError('sigma_align cannot be <= 0.')
    if sigma_angle is not None and sigma_angle <= 0.0:
        raise ValueError('sigma_angle cannot be <= 0.')
    if rel_sigma_range is not None and rel_sigma_range <= 0.0:
        raise ValueError('rel_sigma_range cannot be <= 0.')

    # An explicit win_hwidth overrides the half-width derived from
    # sigma_spatial (3 standard deviations).
    if win_hwidth is not None:
        half_width = win_hwidth
    else:
        half_width = int(round(3*sigma_spatial))
    window_width = half_width*2 + 1
    filter_shape = (window_width, window_width, window_width)

    sphere = get_sphere(sphere_str)

    # build filters
    uv_filter = _unified_filter_build_uv(sigma_angle, sphere)
    nx_filter = _unified_filter_build_nx(filter_shape, sigma_spatial,
                                         sigma_align, sphere, exclude_center)

    B = sh_to_sf_matrix(sphere, sh_order, sh_basis, full_basis,
                        legacy=is_legacy, return_inv=False)
    # The inverse projection targets the full SH basis, hence the output
    # is always expressed in the full basis.
    _, B_inv = sh_to_sf_matrix(sphere, sh_order, sh_basis, True,
                               legacy=is_legacy, return_inv=True)

    # Scale the relative sigma_range by the range of SF amplitudes.
    sigma_range = None
    if rel_sigma_range is not None:
        sigma_range = rel_sigma_range * _get_sf_range(sh_data, B)

    if use_opencl:
        cl_manager = _unified_filter_prepare_opencl(sigma_range, sigma_angle,
                                                    filter_shape[0], sphere,
                                                    device_type)
        return _unified_filter_call_opencl(sh_data, nx_filter, uv_filter,
                                           cl_manager, B, B_inv, sphere,
                                           patch_size)
    return _unified_filter_call_python(sh_data, nx_filter, uv_filter,
                                       sigma_range, B, B_inv, sphere)
def _unified_filter_prepare_opencl(sigma_range, sigma_angle, window_width,
                                   sphere, device_type):
    """
    Compile the unified-filtering OpenCL kernel and wrap it in a manager.

    The kernel source 'aodf_filter.cl' is configured via preprocessor
    defines derived from the arguments below.

    Parameters
    ----------
    sigma_range: float or None
        Value for sigma_range. None disables range filtering; a placeholder
        of 0.0 is then passed as SIGMA_RANGE.
    sigma_angle: float or None
        Value for sigma_angle. Only used to decide whether angle filtering
        is disabled (None) or not.
    window_width: int
        Width of filtering window.
    sphere: DIPY sphere
        Sphere used for SH to SF projection.
    device_type: string
        Device to be used by OpenCL. Either 'cpu' or 'gpu'.

    Returns
    -------
    cl_manager: CLManager
        OpenCL manager object.
    """
    kernel = CLKernel('filter', 'denoise', 'aodf_filter.cl')
    range_value = 0.0 if sigma_range is None else sigma_range
    # Insertion order of this dict is the order the defines are set.
    kernel_defines = {
        'WIN_WIDTH': window_width,
        'SIGMA_RANGE': '{}f'.format(range_value),
        'N_DIRS': len(sphere.vertices),
        'DISABLE_ANGLE': 'false' if sigma_angle is not None else 'true',
        'DISABLE_RANGE': 'false' if sigma_range is not None else 'true',
    }
    for define_name, define_value in kernel_defines.items():
        kernel.set_define(define_name, define_value)
    return CLManager(kernel, device_type)
def _unified_filter_build_uv(sigma_angle, sphere):
    """
    Build the angle filter, weighting each neighbour direction v by its
    angle to the current direction u.

    Parameters
    ----------
    sigma_angle: float or None
        Standard deviation of filter, in radians (it is compared with
        angles obtained from `arccos`). Weights at angles greater than
        3*sigma_angle are clipped to 0 to reduce computation time.
        None disables angle filtering and an identity filter is returned.
    sphere: DIPY sphere
        Sphere used for sampling the SF.

    Returns
    -------
    weights: ndarray
        Angle filter of shape (N_dirs, N_dirs). Row u contains the weights
        of every direction v for direction u; when sigma_angle is given,
        each row sums to 1.
    """
    directions = sphere.vertices
    if sigma_angle is None:
        return np.eye(len(directions))

    cos_angles = np.clip(directions.dot(directions.T), -1.0, 1.0)
    angles = np.arccos(cos_angles)
    weights = _evaluate_gaussian_distribution(angles, sigma_angle)
    # Truncate the Gaussian at 3 standard deviations.
    weights[angles > (3.0*sigma_angle)] = 0.0
    # Normalize each row (direction u). keepdims is required: without it the
    # (N,) row sums broadcast along the last axis and column v would be
    # divided by the sum of row v instead of row u.
    weights /= np.sum(weights, axis=-1, keepdims=True)
    return weights
def _unified_filter_build_nx(filter_shape, sigma_spatial, sigma_align,
                             sphere, exclude_center):
    """
    Build the combined spatial and alignment filter.

    For each window position x and sphere direction n, the weight is the
    product of a Gaussian on the distance from the window center (or 1 when
    sigma_spatial is None) and a Gaussian on the angle between n and the
    direction from the center to x (or 1 when sigma_align is None). Each
    direction's 3D filter is then normalized to sum to 1.

    Parameters
    ----------
    filter_shape: tuple
        Dimensions of filtering window.
    sigma_spatial: float or None
        Standard deviation of spatial filter. None disables Gaussian
        weighting for spatial filtering.
    sigma_align: float or None
        Standard deviation of the alignment filter. None disables Gaussian
        weighting for alignment filtering.
    sphere: DIPY sphere
        Sphere for SH to SF projection.
    exclude_center: bool
        Whether the center voxel is excluded from the neighbourhood (its
        weight is set to 0).

    Returns
    -------
    weights: ndarray
        Combined spatial + alignment filter of shape (W, H, D, N) where
        N is the number of sphere directions.
    """
    sphere_dirs = sphere.vertices.astype(np.float32)
    offsets = _get_window_directions(filter_shape).astype(np.float32)
    dist = np.linalg.norm(offsets, axis=-1)

    # Turn the offsets into unit vectors; the (0, 0, 0) center is left as is.
    off_center = dist > 0
    offsets[off_center] = offsets[off_center] / dist[off_center][..., None]
    center = tuple(n // 2 for n in filter_shape)

    if sigma_spatial is not None:
        w_spatial = _evaluate_gaussian_distribution(dist, sigma_spatial)
    else:
        w_spatial = np.ones(filter_shape)

    if sigma_align is not None:
        theta = np.arccos(np.clip(offsets.dot(sphere_dirs.T), -1.0, 1.0))
        # The center has no direction; give it a perfect alignment.
        theta[center] = 0.0
        w_align = _evaluate_gaussian_distribution(theta, sigma_align)
    else:
        w_align = np.ones(np.append(filter_shape, (len(sphere_dirs),)))

    combined = w_spatial[..., None] * w_align
    if exclude_center:
        combined[center] = 0.0

    # One normalization per sphere direction.
    combined /= np.sum(combined, axis=(0, 1, 2), keepdims=True)
    return combined
def _get_sf_range(sh_data, B_mat):
"""
Get the range of SF amplitudes for input `sh_data`.
Parameters
----------
sh_data: ndarray
Spherical harmonics coefficients image.
B_mat: ndarray
SH to SF projection matrix.
Returns
-------
sf_range: float
Range of SF amplitudes.
"""
sf = np.array([np.dot(i, B_mat) for i in sh_data],
dtype=sh_data.dtype)
sf[sf < 0.0] = 0.0
sf_max = np.max(sf)
sf_min = np.min(sf)
return sf_max - sf_min
def _unified_filter_call_opencl(sh_data, nx_filter, uv_filter, cl_manager,
                                B, B_inv, sphere, patch_size=40):
    """
    Run unified filtering for asymmetric ODFs using OpenCL.

    The volume is zero-padded by half the window width, then processed in
    cubic patches: each input patch is enlarged by (window width - 1) so
    that every output voxel of the patch sees its full neighbourhood.

    Parameters
    ----------
    sh_data: ndarray
        Input SH volume.
    nx_filter: ndarray
        Combined spatial and alignment filter.
    uv_filter: ndarray
        Angle filter.
    cl_manager: CLManager
        A CLManager instance.
    B: ndarray
        SH to SF projection matrix.
    B_inv: ndarray
        SF to SH projection matrix.
    sphere: DIPY sphere
        Sphere for SH to SF projection.
    patch_size: int
        Data is processed in patches of
        patch_size x patch_size x patch_size.

    Returns
    -------
    out_sh: ndarray
        Filtered output as SH coefficients, of shape
        sh_data.shape[:-1] + (B_inv.shape[-1],).
    """
    # Sparse, CSR-like layout of the angle filter: the entries of direction
    # u are flat_uv[offsets[u]:offsets[u + 1]] with matching v_indices.
    keep = uv_filter > 0.0
    n_rows, n_cols = uv_filter.shape[0], uv_filter.shape[1]
    nnz_per_row = np.count_nonzero(uv_filter, axis=-1)
    uv_weights_offsets = np.append([0.0], np.cumsum(nnz_per_row))
    v_indices = np.tile(np.arange(n_cols), (n_rows, 1))[keep]
    flat_uv = uv_filter[keep]

    # Prepare GPU buffers (SF input and output are set per patch).
    cl_manager.add_input_buffer("sf_data")
    cl_manager.add_input_buffer("nx_filter", nx_filter)
    cl_manager.add_input_buffer("uv_filter", flat_uv)
    cl_manager.add_input_buffer("uv_weights_offsets", uv_weights_offsets)
    cl_manager.add_input_buffer("v_indices", v_indices)
    cl_manager.add_output_buffer("out_sf")

    win_width = nx_filter.shape[0]
    win_hwidth = win_width // 2
    volume_shape = sh_data.shape[:-1]
    padded_volume_shape = tuple(np.asarray(volume_shape) + win_width - 1)
    out_sh = np.zeros(np.append(sh_data.shape[:-1], B_inv.shape[-1]))

    pad_spec = ((win_hwidth, win_hwidth),) * 3 + ((0, 0),)
    padded_sh = np.pad(sh_data, pad_spec)

    padded_patch_size = patch_size + win_width - 1
    n_splits = np.ceil(np.asarray(volume_shape) / float(patch_size))\
        .astype(int)
    n_patches = np.prod(n_splits)
    patch_grid = iterprod(*[np.arange(n) for n in n_splits])

    for patch_id, patch_index in enumerate(patch_grid):
        logging.info('Patch {}/{}'.format(patch_id + 1, n_patches))
        # Same lower corner in padded input and unpadded output coordinates.
        lo = np.array(patch_index) * patch_size
        in_hi = np.minimum(lo + padded_patch_size, padded_volume_shape)
        out_hi = np.minimum(lo + patch_size, volume_shape)
        out_shape = tuple(np.append(out_hi - lo, len(sphere.vertices)))

        sh_patch = padded_sh[lo[0]:in_hi[0], lo[1]:in_hi[1], lo[2]:in_hi[2]]
        cl_manager.update_input_buffer("sf_data", np.dot(sh_patch, B))
        cl_manager.update_output_buffer("out_sf", out_shape)
        filtered_sf = cl_manager.run(out_shape[:-1])[0]
        out_sh[lo[0]:out_hi[0], lo[1]:out_hi[1], lo[2]:out_hi[2]] = \
            np.dot(filtered_sf, B_inv)
    return out_sh
def _unified_filter_call_python(sh_data, nx_filter, uv_filter, sigma_range,
                                B_mat, B_inv, sphere):
    """
    Run unified filtering with the pure NumPy implementation, one sphere
    direction at a time.

    Parameters
    ----------
    sh_data: ndarray
        Input SH data.
    nx_filter: ndarray
        Combined spatial and alignment filter.
    uv_filter: ndarray
        Angle filter.
    sigma_range: float or None
        Standard deviation of range filter. None disables range filtering.
    B_mat: ndarray
        SH to SF projection matrix.
    B_inv: ndarray
        SF to SH projection matrix.
    sphere: DIPY sphere
        Sphere for SH to SF projection.

    Returns
    -------
    out_sh: ndarray
        Filtered output as SH coefficients, with the dtype of `sh_data`.
    """
    n_dirs = len(sphere.vertices)
    filtered_sf = np.zeros(sh_data.shape[:-1] + (n_dirs,))
    for dir_id in range(n_dirs):
        # Progress report every 20 directions.
        if not dir_id % 20:
            logging.info('Processing direction: {}/{}'
                         .format(dir_id, n_dirs))
        filtered_sf[..., dir_id] = _correlate(sh_data, nx_filter, uv_filter,
                                              sigma_range, dir_id, B_mat)
    # Project back to SH one slice (first axis) at a time.
    sh_slices = [np.dot(sf_slice, B_inv) for sf_slice in filtered_sf]
    return np.array(sh_slices, dtype=sh_data.dtype)
def _correlate(sh_data, nx_filter, uv_filter, sigma_range, u_index, B_mat):
    """
    Apply the filters to the SH image for the sphere direction
    described by `u_index`.

    For each voxel, the output is a normalized weighted sum of the SF
    amplitudes of all voxels in the (zero-padded) window, along every
    direction v with a non-zero angle weight for u. The weight of each
    (neighbour, v) pair is the product of the spatial/alignment weight,
    the angle weight and, when enabled, the range weight.

    Parameters
    ----------
    sh_data: ndarray
        Input SH coefficients.
    nx_filter: ndarray
        Combined spatial and alignment filter, indexed on its last axis by
        sphere direction.
    uv_filter: ndarray
        Angle filter.
    sigma_range: float or None
        Standard deviation of range filter. None disables range filtering.
    u_index: int
        Index of the current sphere direction to process.
    B_mat: ndarray
        SH to SF projection matrix.

    Returns
    -------
    out_sf: ndarray
        Output SF amplitudes along the direction described by `u_index`.
    """
    v_indices = np.flatnonzero(uv_filter[u_index])
    nx_u = nx_filter[..., u_index]
    win_x, win_y, win_z = nx_u.shape[:3]
    half_x, half_y, half_z = win_x // 2, win_y // 2, win_z // 2
    out_sf = np.zeros(sh_data.shape[:3])

    # Zero-pad so that every voxel has a full neighbourhood window.
    padded_sh = np.pad(sh_data, ((half_x, half_x),
                                 (half_y, half_y),
                                 (half_z, half_z),
                                 (0, 0)))
    sf_v = np.dot(padded_sh, B_mat[:, v_indices])

    # The spatial/alignment and angle weights do not depend on the voxel:
    # combine them once, shape (win_x, win_y, win_z, n_v).
    nx_uv = nx_u[..., None] * uv_filter[u_index, v_indices]

    # SF along u is only needed for the range filter.
    sf_u = None
    if sigma_range is not None:
        sf_u = np.dot(padded_sh, B_mat[:, u_index])

    for ii, jj, kk in np.ndindex(*out_sf.shape):
        window = sf_v[ii:ii + win_x, jj:jj + win_y, kk:kk + win_z]
        if sf_u is None:
            weights = nx_uv
        else:
            center_sf = sf_u[ii + half_x, jj + half_y, kk + half_z]
            range_weights = _evaluate_gaussian_distribution(
                window - center_sf, sigma_range)
            weights = range_weights * nx_uv
        out_sf[ii, jj, kk] = np.sum(window * weights) / np.sum(weights)
    return out_sf
[docs]
def cosine_filtering(in_sh, sh_order=8, sh_basis='descoteaux07',
                     in_full_basis=False, is_legacy=True, dot_sharpness=1.0,
                     sphere_str='repulsion724', sigma=1.0):
    """
    Average the SH projected on a sphere using a first-neighbor gaussian
    blur and a dot product weight between sphere directions and the direction
    to neighborhood voxels, forcing to 0 negative values and thus performing
    asymmetric hemisphere-aware filtering.

    For each sphere direction, the center voxel contributes its own SF
    amplitude along that direction, while each of the 26 neighbours
    contributes its amplitude along the opposite direction.

    Parameters
    ----------
    in_sh: ndarray (x, y, z, n_coeffs)
        Input SH coefficients array
    sh_order: int, optional
        Maximum order of the SH series.
    sh_basis: {'descoteaux07', 'tournier07'}, optional
        SH basis of the input signal.
    in_full_basis: bool, optional
        True if the input is in full SH basis.
    is_legacy : bool, optional
        Whether or not the SH basis is in its legacy form.
    dot_sharpness: float, optional
        Exponent applied to the positive dot products between sphere
        directions and neighbour directions. Neighbours with a non-positive
        dot product always get a weight of 0; when set to 0.0, the other
        neighbours are not weighted by the dot product.
    sphere_str: str, optional
        Name of the sphere used to project SH coefficients to SF.
    sigma: float, optional
        Standard deviation of the Gaussian weighting on the distance to
        the neighbours.

    Returns
    -------
    out_sh: ndarray (x, y, z, n_coeffs)
        Filtered signal as SH coefficients in full SH basis.
    """
    sphere = get_sphere(sphere_str)
    # Normalized 3x3x3 filter for each sf direction
    weights = _get_cosine_weights(sphere, dot_sharpness, sigma)
    n_dirs = len(sphere.vertices)
    filtered_sf = np.zeros(np.append(in_sh.shape[:-1], n_dirs))

    projection_args = dict(sh_order_max=sh_order, basis_type=sh_basis,
                           return_inv=False, full_basis=in_full_basis,
                           legacy=is_legacy)
    B = sh_to_sf_matrix(sphere, **projection_args)
    # Projection on the antipodal sphere: column i gives the SF along
    # -vertices[i], i.e. on the opposite hemisphere.
    neg_B = sh_to_sf_matrix(Sphere(xyz=-sphere.vertices), **projection_args)

    for dir_id in range(n_dirs):
        # Work on a copy so the shared weights array is left untouched.
        kernel = weights[..., dir_id].copy()
        center_weight = kernel[1, 1, 1]
        kernel[1, 1, 1] = 0.0
        own_sf = np.dot(in_sh, B[:, dir_id])
        opposite_sf = np.dot(in_sh, neg_B[:, dir_id])
        filtered_sf[..., dir_id] = center_weight * own_sf
        filtered_sf[..., dir_id] += correlate(opposite_sf, kernel,
                                              mode="constant")

    # Convert back to SH coefficients (full basis)
    _, B_inv = sh_to_sf_matrix(sphere, sh_order_max=sh_order,
                               basis_type=sh_basis,
                               full_basis=True,
                               legacy=is_legacy)
    sh_slices = [np.dot(sf_slice, B_inv) for sf_slice in filtered_sf]
    return np.array(sh_slices, dtype=in_sh.dtype)
def _get_cosine_weights(sphere, dot_sharpness, sigma):
    """
    Get the weight of each voxel of a 3x3x3 neighbourhood for each sphere
    direction, based on the direction and distance to that voxel.

    Parameters
    ----------
    sphere: Sphere
        Sphere used for SF reconstruction. Only `sphere.vertices`, an
        (N, 3) array of directions, is used.
    dot_sharpness: float
        Exponent applied to the positive dot products between sphere
        directions and neighbour directions.
    sigma: float
        Standard deviation of the Gaussian used for weighting neighbors
        by their distance to the center voxel.

    Returns
    -------
    weights: ndarray
        Weights of shape (3, 3, 3, N). For each sphere direction, the
        weights over the 3x3x3 window sum to 1. Before normalization the
        center voxel gets a weight of 1 and neighbours whose direction has
        a non-positive dot product with the sphere direction get 0.
    """
    # Offset from the center voxel to every voxel of the 3x3x3 window.
    directions = np.moveaxis(np.indices((3, 3, 3)), 0, -1).astype(float) - 1.0
    dir_norm = np.linalg.norm(directions, axis=-1, keepdims=True)

    # Normalize every offset except the zero-length center one.
    non_zero_dir = np.ones((3, 3, 3), dtype=bool)
    non_zero_dir[1, 1, 1] = False
    directions[non_zero_dir] /= dir_norm[non_zero_dir]

    g_weights = np.exp(-dir_norm**2 / (2 * sigma**2))
    d_weights = np.dot(directions, sphere.vertices.T)
    # Exponentiate only the positive dot products. Substituting 1.0 for the
    # others avoids NaN/inf (and RuntimeWarnings) for fractional or negative
    # exponents in entries that are set to 0 anyway.
    positive = d_weights > 0.0
    safe_base = np.where(positive, d_weights, 1.0)
    d_weights = np.where(positive, safe_base**dot_sharpness, 0.0)

    weights = d_weights * g_weights
    weights[1, 1, 1, :] = 1.0
    # Normalize filters so that all sphere directions weights sum to 1
    weights /= weights.reshape((-1, weights.shape[-1])).sum(axis=0)
    return weights
def _evaluate_gaussian_distribution(x, sigma):
    """
    Evaluate the zero-mean, one-dimensional normal probability density
    with standard deviation `sigma` at `x`.

    Parameters
    ----------
    x: ndarray or float
        Points where the distribution is evaluated.
    sigma: float
        Standard deviation. Must be > 0; this is enforced with an
        `assert`, which is skipped when Python runs with -O.

    Returns
    -------
    out: ndarray or float
        Values at x.
    """
    assert sigma > 0.0, "Sigma must be greater than 0."
    sqrt_two_pi = np.sqrt(2.0*np.pi)
    normalization = 1.0 / sigma / sqrt_two_pi
    exponent = -x**2/2/sigma**2
    return normalization * np.exp(exponent)
def _get_window_directions(shape):
    """
    Compute, for every position of a window of the given shape, the
    integer offset from the window's center voxel (index n // 2 along an
    axis of size n) to that position.

    Parameters
    ----------
    shape: tuple
        Dimensions of the window.

    Returns
    -------
    grid: ndarray
        Array of shape shape + (len(shape),) whose last axis holds the
        offset from the center voxel for each position.
    """
    centered_axes = [np.arange(size) - size // 2 for size in shape]
    coordinates = np.meshgrid(*centered_axes, indexing='ij')
    return np.stack(coordinates, axis=-1)