Source code for scilpy.ml.bundleparc.predict

import logging
import numpy as np

from tqdm import tqdm

from scipy.ndimage import label

from scilpy.ml.utils import get_device, to_numpy, IMPORT_ERROR_MSG
from scilpy.ml.bundleparc.utils import (DEFAULT_BUNDLES, get_data)

from dipy.utils.optpkg import optional_package
torch, have_torch, _ = optional_package('torch', trip_msg=IMPORT_ERROR_MSG)


def post_process_mask(
    mask, bundle_name, min_blob_size=100, keep_biggest_blob=False
):
    """
    Binarize a predicted mask and clean up its connected components.

    The mask is thresholded at 0.5 (strictly greater), then split into
    blobs with `scipy.ndimage.label` using its default structuring element.
    If the binarized mask has at most one blob, it is returned as is.
    Otherwise, either only the largest blob is kept, or every blob smaller
    than `min_blob_size` is removed.

    Parameters
    ----------
    mask : np.ndarray
        Predicted mask for the bundle.
    bundle_name : str
        Name of the bundle. Only used in the debug log messages.
    min_blob_size : int, optional
        Minimum number of voxels a blob must contain to be kept. Ignored
        when `keep_biggest_blob` is True or when the mask has at most one
        blob. Default is 100.
    keep_biggest_blob : bool, optional
        If True, keep only the largest blob and discard all others,
        whatever their size. Default is False.

    Returns
    -------
    new_mask : np.ndarray
        Binary mask (np.uint8) with the same shape as `mask`.
    """
    bundle_mask = (mask > 0.5)

    # Get the blobs in the image. Ideally, a mask only has one blob.
    # More than one either indicates a broken segmentation, or extraneous
    # voxels.
    blobs, nb = label(bundle_mask)

    # No need to process, return the mask
    if nb <= 1:
        logging.debug(f"Only one blob in {bundle_name}.")
        return bundle_mask.astype(np.uint8)

    # Size of each blob, indexed by its label. Index 0 counts the
    # background voxels and must never be treated as a blob.
    blob_sizes = np.bincount(blobs.ravel())

    if keep_biggest_blob:
        logging.debug(f"More than one blob in {bundle_name}, keeping largest")
        # argmax runs on the sizes without the background entry, hence +1
        # to get back to the blob's label.
        biggest_blob = np.argmax(blob_sizes[1:]) + 1
        return (blobs == biggest_blob).astype(np.uint8)

    # Remove blobs under a certain size (min_blob_size). `keep` is a
    # per-label lookup table; indexing it with the label image gives the
    # filtered mask directly, and covers every label up to and including
    # the last one.
    keep = blob_sizes >= min_blob_size
    keep[0] = False
    new_nb_blobs = int(np.count_nonzero(keep))
    logging.debug(f'Kept {new_nb_blobs} blob out of {nb} in {bundle_name}.')
    return keep[blobs].astype(np.uint8)
@torch.no_grad()
def predict(
    model, fodf, n_coefs, label_function, bundles, min_blob_size,
    keep_biggest_blob, half_precision=False, verbose=False
):
    """
    Generate the predicted label map of each requested bundle, one at a
    time.

    The fODF volume is encoded once by the model; the resulting features
    are then decoded once per bundle, each time with that bundle's one-hot
    prompt. Gradients are disabled for the whole prediction.

    Parameters
    ----------
    model : LabelSegNet
        Model to use for the prediction. Must expose `encode` and `decode`.
    fodf : nib.Nifti1Image
        fODF image, resampled to the model's input size.
    n_coefs : int
        Number of SH coefficients to use.
    label_function : function
        Function to use for label post-processing. Called as
        `label_function(bundle_label, bundle_mask, bundle_name)`. Can be
        one of: post_process_labels_discrete, post_process_labels_mm,
        post_process_labels_continuous.
    bundles : list of str
        Names of the bundles to predict. Each name must be present in
        DEFAULT_BUNDLES.
    min_blob_size : int
        Minimum blob size, forwarded to `post_process_mask`.
    keep_biggest_blob : bool
        Whether to keep only the biggest blob of the mask, forwarded to
        `post_process_mask`.
    half_precision : bool, optional
        Whether to enable autocast (half precision). Will reduce GPU
        memory usage but may reduce the accuracy of the label maps.
        Default is False.
    verbose : bool, optional
        Whether to display a progress bar. Default is False.

    Yields
    ------
    bundle_label : np.ndarray
        Output of `label_function` for the bundle.
    bundle_name : str
        Name of the bundle.
    """
    # Positions of the requested bundles in the model's bundle list.
    selected = np.array([DEFAULT_BUNDLES.index(name) for name in bundles])
    device = get_device()
    sh_volume = get_data(fodf, n_coefs)
    progress = tqdm(selected, disable=not verbose)

    with torch.amp.autocast(device.type, enabled=half_precision):
        # Move the fODF data to the device as a float tensor.
        net_input = torch.tensor(sh_volume, dtype=torch.float).to(device)

        # Row k of the identity matrix is the one-hot prompt of bundle k.
        one_hot = torch.eye(len(DEFAULT_BUNDLES), device=device)

        # The encoder only depends on the data: run it a single time and
        # share its output between all bundles.
        z, encoder_features = model.encode(net_input[None, ...])

        for idx in progress:
            bundle_name = DEFAULT_BUNDLES[idx]
            progress.set_description(bundle_name)

            # Decode with this bundle's prompt; only the last decoder
            # output is used.
            decoded = model.decode(
                z, encoder_features, one_hot[None, idx, ...])
            activations = torch.nn.functional.sigmoid(decoded[-1]).squeeze()

            # First entry is the mask, second entry is the label map.
            prediction = to_numpy(activations)
            raw_mask, raw_label = prediction[0], prediction[1]

            # Binarize the mask and clean up its blobs.
            clean_mask = post_process_mask(
                raw_mask, bundle_name, min_blob_size=min_blob_size,
                keep_biggest_blob=keep_biggest_blob)

            # Let the caller-provided function turn the raw label map
            # into the final labels, using the cleaned mask.
            final_label = label_function(raw_label, clean_mask, bundle_name)

            yield final_label, bundle_name