Source code for scilpy.tractanalysis.scoring

# -*- coding: utf-8 -*-

"""
Tractometry
-----------
Global connectivity metrics:

- Computed by default:
    - VS: valid streamlines, belonging to a bundle (i.e. respecting all the
        criteria for that bundle; endpoints, limit_mask, gt_mask.).
    - IS: invalid streamlines. All other streamlines. IS = IC + NC.

- Optional:
    - WPC: wrong path connections, streamlines connecting correct ROIs but not
        respecting the other criteria for that bundle. Such streamlines always
        exist but they are only saved separately if specified in the options.
        Else, they are merged back with the IS.
        By definition, WPC are only computed if "limits masks" are provided.
    - IC: invalid connections, streamlines joining an incorrect combination of
        ROIs. Use carefully, quality depends on the quality of your ROIs and no
        analysis is done on the shape of the streamlines.
    - NC: no connections. Invalid streamlines minus invalid connections.

- Fidelity metrics:
    - OL: Overlap. Percentage of ground truth voxels containing streamline(s)
        for a given bundle.
    - OR: Overreach. Amount of voxels containing streamline(s) when they
        shouldn't, for a given bundle. We compute two versions :
        OR_pct_vs = divided by the total number of voxels covered by the
        bundle (percentage of the voxels touched by VS).
        Values range between 0 and 100%. Values are not defined when we
        recovered no streamline for a bundle, but we set the OR_pct_vs to 0
        in that case.
        OR_pct_gt = divided by the total size of the ground truth bundle mask.
        Values could be higher than 100%.
    - f1 score: which is the same as the Dice score.
"""

import logging

from multiprocessing import Pool, Lock
import numpy as np
from tqdm import tqdm

from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map
from scilpy.tractograms.streamline_and_mask_operations import \
    get_endpoints_density_map


def _compute_ae(args):
    """
    Multiprocessing worker for compute_ae: angular error of one chunk of
    segments.

    Parameters
    ----------
    args: tuple
        (process_id, dirs, coords, peaks), where dirs[i] is the direction of
        segment i, coords[i] holds the three voxel indices used to look up
        the peaks of segment i, and peaks is indexed as
        peaks[x, y, z, peak, component].

    Returns
    -------
    ae_chunk: np.array
        For each segment, the smallest angle (in degrees) between the segment
        direction and the non-zero peaks of its voxel. 0 where the voxel has
        no non-zero peak.
    nb_nan_chunk: int
        Number of segments falling in a voxel with no non-zero peak.
    """
    process_id, dirs, coords, peaks = args

    nb_segments = len(dirs)
    nb_nan_chunk = 0
    ae_chunk = np.zeros(nb_segments)

    # The segment counter is not meaningful to the user: it is left out of
    # the progress bar layout.
    bar_layout = ('{desc}:{percentage:3.0f}%|{bar}|  '
                  '[{elapsed}<{remaining}, {rate_fmt}{postfix}')
    progress = tqdm(range(nb_segments), ncols=120, bar_format=bar_layout,
                    position=process_id + 1,
                    desc="Process {}.".format(process_id + 1))

    for seg in progress:
        voxel = coords[seg]
        voxel_peaks = peaks[voxel[0], voxel[1], voxel[2], :, :]

        # Keep only the peaks having at least one non-zero component.
        is_valid_peak = np.any(voxel_peaks != 0, axis=-1)
        voxel_peaks = voxel_peaks[is_valid_peak]
        if voxel_peaks.size == 0:
            # Buggy voxel: the AE stays at its initial value of 0 and the
            # segment is only counted.
            nb_nan_chunk += 1
            continue

        # Peaks are undirected, hence the absolute value. Clipping protects
        # arccos against values slightly above 1 due to rounding.
        alignment = np.abs(np.dot(voxel_peaks, dirs[seg]))
        alignment = np.clip(alignment, 0, 1.0)
        angles = np.rad2deg(np.arccos(alignment))
        ae_chunk[seg] = np.min(angles)

    return ae_chunk, nb_nan_chunk


def compute_ae(sft, peaks, nb_processes=1):
    """
    Computing the angular error for each segment. The direction is compared
    with the underlying peak (for single peak files like DTI) or with the
    closest peak (ex, with fODF peaks). Currently, interpolation is not
    supported: peaks of the closest voxel are used (nearest neighbor).

    The tractogram is temporarily sent to vox space, corner origin. It is
    sent back to its initial space and origin before returning, including
    when the computation raises.

    Parameters
    ----------
    sft: StatefulTractogram
        The tractogram
    peaks: np.array of shape [x, y, z, nb_peaks, 3].
        The peaks. An array whose trailing dimensions can be reshaped to
        (nb_peaks, 3) is also accepted.
    nb_processes: int
        To use multiprocessing

    Returns
    -------
    ae: list[np.array]
        The angular error for each streamline, in degrees. The last point of
        each streamline has an AE of zero. An empty list is returned for an
        empty tractogram.
    """
    # If there is only one peak, make sure we still have a 4th dimension = 1.
    peaks = peaks.reshape(peaks.shape[:3] + (-1, 3))
    if peaks.shape[3] == 1:
        multi_peaks = False
        logging.info("Peaks seem to be single-peaks (DTI, probably). Simple "
                     "alignment measure.")
    else:
        multi_peaks = True
        logging.info("Peaks seem to be multi-peaks (maybe coming from ODF, "
                     "fODF, etc). We will verify alignment with the closest "
                     "peak in each voxel.")

    # Nothing to score: np.vstack below would fail on an empty list.
    if len(sft.streamlines) == 0:
        return []

    # Sending sft to vox space, corner origin. Then nearest neighbor
    # interpolation is just the floor.
    previous_space = sft.space
    previous_origin = sft.origin
    sft.to_vox()
    sft.to_corner()

    try:
        # Normalizing peaks
        logging.info("Normalizing peaks")
        _norm = np.linalg.norm(peaks, axis=-1, keepdims=True)
        _norm[_norm == 0] = 1  # Making sure we don't divide by 0
        peaks = peaks / _norm
        del _norm

        # Getting segments and normalizing.
        # Concatenating all segments because the process is independent on
        # each. Then we can launch multiprocessing by dividing the segments
        # into the number of processes. Makes multiprocessing more equal.
        # However, it duplicates the tractogram 3 times in memory: sft, dirs,
        # coords. Fails with very large tractograms (ex, 10 millions).
        # Note: a zero-length segment (two identical consecutive points) has
        # a null norm, so its direction is NaN.
        logging.info("Preparing each streamline segment and normalizing")
        dirs = np.vstack([np.diff(s, axis=0) for s in sft.streamlines])
        dirs = dirs / np.linalg.norm(dirs, axis=-1, keepdims=True)

        # Concatenating streamlines for faster processing + nearest neighbor
        logging.info("Neareast-neighbor interpolation of streamline "
                     "coordinates.")
        coords = np.floor(
            np.vstack([s[1:] for s in sft.streamlines])).astype(int)
        nb_segments = len(coords)

        ae = []
        nb_nan = 0
        # With no segment at all (only single-point streamlines), the chunk
        # size would be 0, which is not a valid range step.
        if nb_segments > 0:
            # Preparing multiprocessing
            chunk_size = (nb_segments + nb_processes - 1) // nb_processes
            logging.info("Preparing multiprocessing. Nb processes = {}. "
                         "Nb segments: {}. {} in each process."
                         .format(nb_processes, nb_segments, chunk_size))
            split_indices = [(i, min(i + chunk_size, nb_segments))
                             for i in range(0, nb_segments, chunk_size)]

            # Finding the angular difference with the closest peak:
            # multiprocessing
            lock = Lock()
            tqdm.set_lock(lock)
            with Pool(processes=nb_processes) as pool:
                # imap (and not imap_unordered): chunks are stacked in the
                # order they are received, so they must come back in
                # submission order to stay aligned with the segments.
                for ae_chunk, nb_nan_chunk in pool.imap(
                        _compute_ae,
                        [(i, dirs[start:end], coords[start:end], peaks)
                         for i, (start, end) in enumerate(split_indices)]):
                    ae.append(ae_chunk)
                    nb_nan += nb_nan_chunk
                pool.close()
                pool.join()

            print(' ')  # Required because finishing sub-processes' tqdm is flaky

        # Freeing memory
        del dirs
        del coords

        # Stacking multiprocess results
        ae = np.hstack(ae) if ae else np.zeros(0)

        if nb_nan > 0:
            msg = "AE in these voxels was set to 0. Total number of segments " + \
                  "traversing these voxels: {} /{} .".format(nb_nan, nb_segments)
            if multi_peaks:
                logging.warning("Some voxels had 0 valid peaks out of the {} "
                                "possible peaks (they were all [0,0,0]). "
                                .format(peaks.shape[3]) + msg)
            else:
                logging.warning("Invalid peaks ([0,0,0]) were found in some "
                                "voxels. " + msg)

        # Split back streamlines
        lengths = [len(s) - 1 for s in sft.streamlines]
        ae = np.split(ae, np.cumsum(lengths)[:-1])

        # Add value 0 as the last value of each streamline
        ae = [np.append(line_ae, 0) for line_ae in ae]
    finally:
        # Sending back to previous space, even if something failed above.
        sft.to_space(previous_space)
        sft.to_origin(previous_origin)

    return ae
def compute_f1_score(overlap, overreach):
    """
    Compute the F1 score (harmonic mean of precision and recall) from the
    overlap and the overreach.

    Both values are used as fractions: the overreach is compared against 1
    and the precision is obtained as 1 - overreach.

    Parameters
    ----------
    overlap: float
        The overlap value. Used as the recall.
    overreach: float
        The overreach value (version normalized over the bundle area, not the
        version normalized over the ground truth).

    Returns
    -------
    f1_score: float
        The f1 score.

    Ref: https://en.wikipedia.org/wiki/F1_score
    """
    # Bundle found but lying entirely outside the mask (overlap = 0 and
    # overreach = 100%): precision + recall would be 0, so f1 is defined as 0
    # rather than dividing by 0.
    bundle_fully_outside = (overlap == 0) and (overreach == 1)
    if bundle_fully_outside:
        return 0.

    # Recall = TP / (TP + FN) = |B inter A| / |A| = overlap
    # Precision = TP / (TP + FP) = |B inter A| / |B|
    #           = 1 - |B except A| / |B| = 1 - overreach
    recall, precision = overlap, 1.0 - overreach

    numerator = 2.0 * (precision * recall)
    return numerator / (precision + recall)
def compute_f1_overlap_overreach(current_vb_voxels, gt_mask, dimensions):
    """
    Compute f1, OL and OR/ORn based on a ground truth mask.

    Parameters
    ----------
    current_vb_voxels: 3D array
        The voxels touched by at least one streamlines for a given bundle.
    gt_mask: 3D array
        The ground truth mask.
    dimensions: np.array
        The nibabel dimensions of the data (3D). Kept for backward
        compatibility: the counts are taken directly on the two arrays above,
        so this value is not used anymore.

    Returns
    -------
    f1: float
        The f1 score.
    tp_nb_voxels: int
        The TP (true positive) count in number of voxels.
    fp_nb_voxels: int
        The FP (false positive) count in number of voxels.
        Hint: Divide it by the ground truth count to get the overreach, or
        by the recovered bundle count to get the ORn (scores used in the
        ismrm2015 tractography challenge).
    fn_nb_voxels: int
        The number of voxels from the gt_mask that have not been
        recovered; corresponds to the FN count (false negative).
    overlap: float
        TP divided by the ground truth count (i.e. TP + FN), as a fraction.
        (Or 0 if the ground truth mask is empty).
    overreach_pct_gt: float
        The overreach, normalized by the ground truth area.
        (Or 0 if the ground truth mask is empty).
    overreach_pct_vs: float
        The overreach, normalized by the recovered bundle's area.
        (Or 0 if no streamline have been recovered for this bundle).
    """
    # Counting straight on the boolean expressions: no need to allocate
    # intermediate full-size volumes.
    # True positive = |B inter A|
    tp_nb_voxels = np.count_nonzero(gt_mask * current_vb_voxels)

    # False positive = |B except A|
    fp_nb_voxels = np.count_nonzero(
        (gt_mask == 0) & (current_vb_voxels >= 1))

    # False negative = |A except B|
    fn_nb_voxels = np.count_nonzero(
        (gt_mask == 1) & (current_vb_voxels == 0))

    # Same as np.count_nonzero(gt_mask), for a binary mask
    gt_total_nb_voxels = tp_nb_voxels + fn_nb_voxels
    # Same as np.count_nonzero(current_vb_voxels), for a binary map
    nb_voxels_total = tp_nb_voxels + fp_nb_voxels

    # Overlap = |B inter A| / |A|
    # Overreach: two versions are sometimes used.
    # |B except A| / |A|  or  |B except A| / |B|
    # Each ratio is undefined when its denominator is 0; it is then set to 0
    # instead of raising a ZeroDivisionError.
    if gt_total_nb_voxels == 0:
        overlap = 0
        overreach_pct_gt = 0
    else:
        overlap = tp_nb_voxels / gt_total_nb_voxels
        overreach_pct_gt = fp_nb_voxels / gt_total_nb_voxels

    if nb_voxels_total == 0:
        overreach_pct_vs = 0
    else:
        overreach_pct_vs = fp_nb_voxels / nb_voxels_total

    # f1 score (=dice)
    f1 = compute_f1_score(overlap, overreach_pct_vs)

    return (f1, tp_nb_voxels, fp_nb_voxels, fn_nb_voxels,
            overlap, overreach_pct_gt, overreach_pct_vs)
def get_binary_maps(sft):
    """
    Build the binary volume mask and the binary endpoints mask of a bundle.

    The tractogram is sent to vox space, corner origin, in place. It is not
    sent back to its previous space afterwards.

    Parameters
    ----------
    sft: StatefulTractogram
        Bundle.

    Returns
    -------
    bundles_voxels: numpy.ndarray
        Mask representing the bundle volume.
    endpoints_voxels: numpy.ndarray
        Mask representing the bundle's endpoints.
        For an empty bundle, both outputs are arrays of zeros of the volume
        dimensions.
    """
    sft.to_vox()
    sft.to_corner()
    _affine, dimensions, _voxel_sizes, _voxel_order = sft.space_attributes

    # Empty bundle: nothing to rasterize.
    if len(sft) == 0:
        return np.zeros(dimensions), np.zeros(dimensions)

    def _as_binary(density_map):
        # Converting to int16 first, then setting every positive value to 1.
        binary_map = density_map.astype(np.int16)
        binary_map[binary_map > 0] = 1
        return binary_map

    streamline_density = compute_tract_counts_map(sft.streamlines,
                                                  dimensions)
    endpoints_density = get_endpoints_density_map(sft)

    return _as_binary(streamline_density), _as_binary(endpoints_density)
def compute_tractometry(
        vb_sft_list, wpc_sft_list, ib_sft_list, nc_sft, args, bundles_names,
        gt_masks, dimensions, ib_names):
    """
    Tractometry stats: First in terms of connections (NC, IC, VS, WPC), then
    in terms of volume (OL, OR, Dice score)

    Parameters
    ----------
    vb_sft_list: list
        The valid streamlines of each bundle, in the same order as
        bundles_names. An entry may be None (bundle not found). Non-empty
        entries are passed to get_binary_maps (presumably
        StatefulTractograms).
    wpc_sft_list: list or None
        The wrong path connections of each bundle (entries may be None), or
        None if there are none.
    ib_sft_list: list
        The invalid connections of each invalid bundle, in the same order as
        ib_names.
    nc_sft: tractogram or None
        The no-connection streamlines. Only its length is used.
    args: namespace
        Must have the attributes compute_ic and save_wpc_separately.
    bundles_names: list[str]
        Name of each bundle.
    gt_masks: list
        The ground truth mask of each bundle, or None to skip the volume
        scoring of that bundle.
    dimensions: np.array
        The dimensions of the volume.
    ib_names: list[str]
        Name of each invalid bundle.

    Returns
    -------
    final_results: dict
        Global counts and ratios, the per-bundle results under the key
        "bundle_wise", and the mean OL / OR / f1 over the scored bundles
        (only if at least one bundle was scored).
    """
    logging.info("Computing tractometry")

    # ---- Stats over connections
    vs_per_bundle = [len(x) if x is not None else 0 for x in vb_sft_list]
    vb_count = np.count_nonzero(vs_per_bundle)
    vs_count = np.sum(vs_per_bundle)

    if wpc_sft_list is not None:
        wpc_per_bundle = [len(x) if x is not None else 0
                          for x in wpc_sft_list]
        wpb_count = np.count_nonzero(wpc_per_bundle)
        wpc_count = np.sum(wpc_per_bundle)
    else:
        wpb_count = 0
        wpc_count = 0

    ic_per_ib_bundle = [len(x) for x in ib_sft_list]
    ib_count = np.count_nonzero(ic_per_ib_bundle)
    ic_count = np.sum(ic_per_ib_bundle)

    nc_count = len(nc_sft) if nc_sft is not None else 0
    total_count = vs_count + wpc_count + ic_count + nc_count

    final_results = {
        "total_streamlines": int(total_count),
        "VB": int(vb_count),
        "VS": int(vs_count),
        "VS_ratio": vs_count / total_count,
        "IS": int(ic_count + nc_count),  # ic_count = 0 if not args.compute_ic
        "IS_ratio": (ic_count + nc_count) / total_count,
    }

    if args.compute_ic:
        final_results.update({
            "IB": int(ib_count),
            "IC": int(ic_count),
            "IC_ratio": ic_count / total_count,
            "NC": int(nc_count),
            "NC_ratio": nc_count / total_count})

    if args.save_wpc_separately:
        final_results.update({
            "WPC": int(wpc_count),
            "WPC_bundle": int(wpb_count),
            "WPC_ratio": wpc_count / total_count})

    # ---- Tractometry stats over volume: OL, OR, Dice score
    mean_overlap = 0.0
    mean_overreach_gt = 0.0
    mean_overreach_vs = 0.0
    mean_f1 = 0.0
    nb_bundles_in_stats = 0

    bundle_wise_dict = {}
    for i, bundle_name in enumerate(bundles_names):
        logging.debug("Scoring bundle {}".format(bundle_name))
        current_vb = vb_sft_list[i]
        gt_mask = gt_masks[i]

        # A bundle that was not found is None: it counts 0 streamlines.
        is_empty = current_vb is None or len(current_vb) == 0
        bundle_results = {"VS": 0 if current_vb is None else len(current_vb)}

        if gt_mask is None:
            bundle_results.update({"Scoring skipped": "No gt_mask provided"})
            bundle_wise_dict.update({bundle_name: bundle_results})
            continue

        if is_empty:
            # Still part of the means, with a score of 0 everywhere.
            logging.debug("   Empty bundle or bundle not found.")
            bundle_results.update({
                "TP": 0, "FP": 0, "FN": 0,
                "OL": 0, "OR_pct_gt": 0, "OR_pct_vs": 0, "f1": 0,
                "endpoints_OL": 0, "endpoints_OR": 0
            })
            nb_bundles_in_stats += 1
            bundle_wise_dict.update({bundle_name: bundle_results})
            continue

        # Getting the recovered mask
        current_vb_voxels, current_vb_endpoints_voxels = get_binary_maps(
            current_vb)

        (f1, tp_nb_voxels, fp_nb_voxels, fn_nb_voxels,
         overlap_count, overreach_pct_gt, overreach_pct_vs) = \
            compute_f1_overlap_overreach(
                current_vb_voxels, gt_mask, dimensions)

        # Endpoints coverage
        # todo. What is this? Useful?
        endpoints_overlap = gt_mask * current_vb_endpoints_voxels
        endpoints_overreach = np.zeros(dimensions)
        endpoints_overreach[np.where(
            (gt_mask == 0) & (current_vb_endpoints_voxels >= 1))] = 1

        bundle_results.update({
            "TP": tp_nb_voxels,
            "FP": fp_nb_voxels,
            "FN": fn_nb_voxels,
            "OL": overlap_count,
            "OR_pct_vs": overreach_pct_vs,
            "OR_pct_gt": overreach_pct_gt,
            "f1": f1,
            "endpoints_OL": np.count_nonzero(endpoints_overlap),
            "endpoints_OR": np.count_nonzero(endpoints_overreach)
        })

        # WPC
        if args.save_wpc_separately:
            wpc_sft = wpc_sft_list[i] if wpc_sft_list is not None else None
            if wpc_sft is not None and len(wpc_sft) > 0:
                current_wpc_voxels, _ = get_binary_maps(wpc_sft)

                # We could add an option to include wpc streamlines to the
                # overreach count. But it seems more natural to exclude wpc
                # streamlines from any count. Separating into a different
                # statistic dict.  Else, user may simply not include a "ALL"
                # mask, there won't be any wpc.
                # The scores below are computed on the WPC voxels (and not on
                # the valid bundle's voxels).
                (_, wpc_tp, wpc_fp, _,
                 wpc_overlap, wpc_or_pct_gt, wpc_or_pct_vs) = \
                    compute_f1_overlap_overreach(
                        current_wpc_voxels, gt_mask, dimensions)

                wpc_results = {
                    "Count": len(wpc_sft),
                    "TP": wpc_tp,
                    "FP": wpc_fp,
                    "OL": wpc_overlap,
                    "OR_pct_vs": wpc_or_pct_vs,
                    "OR_pct_gt": wpc_or_pct_gt,
                }
                bundle_results.update({"WPC": wpc_results})
            else:
                bundle_results.update({"WPC": None})

        mean_overlap += bundle_results["OL"]
        mean_overreach_gt += bundle_results["OR_pct_gt"]
        mean_overreach_vs += bundle_results["OR_pct_vs"]
        mean_f1 += bundle_results["f1"]
        nb_bundles_in_stats += 1

        bundle_wise_dict.update({bundle_name: bundle_results})

    if args.compute_ic:
        # -----------
        # False connections stats: number of voxels
        # -----------
        ic_results = {}
        for i, ib_name in enumerate(ib_names):
            current_ib = ib_sft_list[i]
            if len(current_ib) > 0:
                current_ib_voxels, _ = get_binary_maps(current_ib)

                ib_results = {
                    "IC": len(current_ib),
                    "nb_voxels": np.count_nonzero(current_ib_voxels)
                }
                ic_results.update({ib_name: ib_results})

        bundle_wise_dict.update({"IB": ic_results})

    final_results.update({"bundle_wise": bundle_wise_dict})

    if nb_bundles_in_stats > 0:
        final_results.update({
            "mean_OL": mean_overlap / nb_bundles_in_stats,
            "mean_OR_gt": mean_overreach_gt / nb_bundles_in_stats,
            "mean_OR_vs": mean_overreach_vs / nb_bundles_in_stats,
            "mean_f1": mean_f1 / nb_bundles_in_stats
        })

    return final_results