# -*- coding: utf-8 -*-
"""
Tractometry
-----------
Global connectivity metrics:
- Computed by default:
- VS: valid streamlines, belonging to a bundle (i.e. respecting all the
criteria for that bundle: endpoints, limit_mask, gt_mask).
- IS: invalid streamlines. All other streamlines. IS = IC + NC.
- Optional:
- WPC: wrong path connections, streamlines connecting correct ROIs but not
respecting the other criteria for that bundle. Such streamlines always
exist but they are only saved separately if specified in the options.
Else, they are merged back with the IS.
By definition, WPC are only computed if "limits masks" are provided.
- IC: invalid connections, streamlines joining an incorrect combination of
ROIs. Use carefully, quality depends on the quality of your ROIs and no
analysis is done on the shape of the streamlines.
- NC: no connections. Invalid streamlines minus invalid connections.
- Fidelity metrics:
- OL: Overlap. Percentage of ground truth voxels containing streamline(s)
for a given bundle.
- OR: Overreach. Amount of voxels containing streamline(s) when they
shouldn't, for a given bundle. We compute two versions :
OR_pct_vs = divided by the total number of voxel covered by the bundle.
(percentage of the voxels touched by VS).
Values range between 0 and 100%. Values are not defined when we
recovered no streamline for a bundle, but we set the OR_pct_vs to 0
in that case.
OR_pct_gt = divided by the total size of the ground truth bundle mask.
Values could be higher than 100%.
- f1 score: which is the same as the Dice score.
"""
import logging
from multiprocessing import Pool, Lock
import numpy as np
from tqdm import tqdm
from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map
from scilpy.tractograms.streamline_and_mask_operations import \
get_endpoints_density_map
def _compute_ae(args):
    """
    Worker used for multiprocessing in compute_ae below.

    Computes, for each segment of the chunk, the smallest angle between the
    segment direction and the non-zero peaks of the voxel it is associated
    with.

    Parameters
    ----------
    args: tuple
        (process_id, dirs, coords, peaks), packed in a single argument so the
        function can be mapped by a multiprocessing pool.
        - process_id: int. Only used to position and label the progress bar.
        - dirs: array with one direction (3 values) per segment.
        - coords: array with one voxel index triplet per segment.
        - peaks: array indexed as [x, y, z, peak, 3].
        dirs and peaks are expected to be unit vectors (compute_ae normalizes
        them before calling); the cosine is clipped to [0, 1] regardless.

    Returns
    -------
    ae_chunk: np.ndarray of shape (nb_segments,)
        Angular error of each segment, in degrees. 0 where the voxel contained
        no non-zero peak.
    nb_nan_chunk: int
        Number of segments that fell in a voxel with no non-zero peak.
    """
    process_id, dirs, coords, peaks = args
    nb_segments = len(dirs)
    ae_chunk = np.zeros(nb_segments)
    nb_nan_chunk = 0

    # Hiding the segment number in the tqdm bar, not meaningful.
    # (Same as tqdm's default right-hand part, minus the counter; the closing
    # bracket was previously missing.)
    _format = '{desc}:{percentage:3.0f}%|{bar}| '
    _format += '[{elapsed}<{remaining}, {rate_fmt}{postfix}]'

    for i in tqdm(range(nb_segments), ncols=120,
                  bar_format=_format, position=process_id + 1,
                  desc="Process {}.".format(process_id + 1)):
        voxel = coords[i]
        current_peaks = peaks[voxel[0], voxel[1], voxel[2], :, :]

        # Using only non-zero peaks. Dealing with buggy voxels: setting AE to 0
        current_peaks = current_peaks[np.any(current_peaks != 0, axis=-1)]
        if current_peaks.size == 0:
            nb_nan_chunk += 1
            ae_chunk[i] = 0
            continue

        # Using the abs value because vectors are undirected.
        cos_theta = np.abs(np.dot(current_peaks, dirs[i]))
        cos_theta = np.clip(cos_theta, 0, 1.0)  # numerical safety
        theta = np.rad2deg(np.arccos(cos_theta))

        # Keeping the error with the closest peak.
        ae_chunk[i] = np.min(theta)

    return ae_chunk, nb_nan_chunk
[docs]
def compute_ae(sft, peaks, nb_processes=1):
    """
    Computing the angular error for each segment. The direction is compared
    with the underlying peak (for single peak files like DTI) or with the
    closest peak (ex, with fODF peaks). Currently, interpolation is not
    supported: peaks of the closest voxel are used (nearest neighbor). AE is
    the angle, in degrees, between the (undirected) segment direction and the
    peak: arccos(|cos(theta)|), hence in [0, 90].

    The tractogram is temporarily sent to vox space / corner origin; its
    previous space and origin are restored before returning, even on error.

    Parameters
    ----------
    sft: StatefulTractogram
        The tractogram
    peaks: np.array of shape [x, y, z, nb_peaks, 3].
        The peaks. Any array whose first 3 dimensions are spatial and that can
        be reshaped to [x, y, z, -1, 3] is accepted.
    nb_processes: int
        To use multiprocessing. Values below 1 (or None) are treated as 1.

    Returns
    -------
    ae: list[np.array]
        The angular error for each streamline, in degrees. One value per
        point: value k belongs to the segment from point k to point k+1 and
        uses the peaks of the voxel containing point k+1. The last point of
        each streamline has an AE of zero.
    """
    if nb_processes is None or nb_processes < 1:
        nb_processes = 1

    # If there is only one peak, make sure we still have a 4th dimension = 1.
    peaks = peaks.reshape(peaks.shape[:3] + (-1, 3))
    multi_peaks = peaks.shape[3] != 1
    if multi_peaks:
        logging.info("Peaks seem to be multi-peaks (maybe coming from ODF, "
                     "fODF, etc). We will verify alignment with the closest "
                     "peak in each voxel.")
    else:
        logging.info("Peaks seem to be single-peaks (DTI, probably). Simple "
                     "alignment measure.")

    # Sending sft to vox space, corner origin. Then nearest neighbor
    # interpolation is just the floor.
    previous_space = sft.space
    previous_origin = sft.origin
    sft.to_vox()
    sft.to_corner()

    try:
        # Fixing peaks shape and normalizing
        logging.info("Normalizing peaks")
        _norm = np.linalg.norm(peaks, axis=-1, keepdims=True)
        _norm[_norm == 0] = 1  # Making sure we don't divide by 0
        peaks = peaks / _norm
        del _norm

        # Getting segments and normalizing.
        # Concatenating all segments because the process is independent on
        # each. Then we can launch multiprocessing by dividing the segments
        # into the number of processes. Makes multiprocessing more equal.
        # However, it duplicates the tractogram 3 times in memory: sft, dirs,
        # coords. Fails with very large tractograms (ex, 10 millions).
        logging.info("Preparing each streamline segment and normalizing")
        dirs = [np.diff(s, axis=0) for s in sft.streamlines]
        nb_segments = sum(len(d) for d in dirs)

        if nb_segments == 0:
            # Empty tractogram or only single-point streamlines: nothing to
            # compare. Same output convention as below (0 on the last point).
            return [np.zeros(len(s)) for s in sft.streamlines]

        dirs = np.vstack(dirs)
        # NOTE(review): a zero-length segment (two identical consecutive
        # points) has a null norm and yields NaN directions here, hence a NaN
        # AE for that segment.
        dirs = dirs / np.linalg.norm(dirs, axis=-1, keepdims=True)

        # Concatenating streamlines for faster processing + nearest neighbor
        logging.info(
            "Neareast-neighbor interpolation of streamline coordinates.")
        coords = np.floor(
            np.vstack([s[1:] for s in sft.streamlines])).astype(int)

        # Preparing multiprocessing
        chunk_size = (nb_segments + nb_processes - 1) // nb_processes
        logging.info("Preparing multiprocessing. Nb processes = {}. "
                     "Nb segments: {}. {} in each process."
                     .format(nb_processes, nb_segments, chunk_size))
        split_indices = [(i, min(i + chunk_size, nb_segments))
                         for i in range(0, nb_segments, chunk_size)]
        chunks = [(i, dirs[start:end], coords[start:end], peaks)
                  for i, (start, end) in enumerate(split_indices)]

        # Finding the angular difference with the closest peak:
        # multiprocessing
        lock = Lock()
        tqdm.set_lock(lock)
        ae = []
        nb_nan = 0
        with Pool(processes=nb_processes) as pool:
            # imap (ordered), NOT imap_unordered: chunks are concatenated
            # below and split back per streamline, so they must come back in
            # the order they were sent.
            for ae_chunk, nb_nan_chunk in pool.imap(_compute_ae, chunks):
                ae.append(ae_chunk)
                nb_nan += nb_nan_chunk
            pool.close()
            pool.join()

        # Freeing memory
        del chunks
        del dirs
        del coords

        print(' ')  # Required because finishing sub-processes' tqdm is flaky

        ae = np.hstack(ae)  # Stacking multiprocess results

        if nb_nan > 0:
            msg = "AE in these voxels was set to 0. Total number of segments " + \
                  "traversing these voxels: {} /{} .".format(nb_nan, nb_segments)
            if multi_peaks:
                logging.warning("Some voxels had 0 valid peaks out of the {} "
                                "possible peaks (they were all [0,0,0]). "
                                .format(peaks.shape[3]) + msg)
            else:
                logging.warning("Invalid peaks ([0,0,0]) were found in some "
                                "voxels. " + msg)

        # Split back streamlines
        lengths = [len(s) - 1 for s in sft.streamlines]
        ae = np.split(ae, np.cumsum(lengths)[:-1])

        # Add value 0 as the last value of each streamline
        ae = [np.append(line_ae, 0) for line_ae in ae]
    finally:
        # Sending back to previous space, whatever happened above.
        sft.to_space(previous_space)
        sft.to_origin(previous_origin)

    return ae
[docs]
def compute_f1_score(overlap, overreach):
    """
    Compute the F1 score from the overlap and the overreach of a bundle.

    Both values are used as fractions (1 means 100%): the precision is taken
    as 1 - overreach and the degenerate case is detected with overreach == 1.

    Parameters
    ----------
    overlap: float
        The overlap value (used as the recall).
    overreach: float
        The overreach value.
        (Version normalized over bundle area, not version normalized over gt).

    Returns
    -------
    f1_score: float
        The f1 score.

    Ref: https://en.wikipedia.org/wiki/F1_score
    """
    # Bundle found but lying entirely outside of the mask (overlap = 0,
    # overreach = 100%): precision + recall would be 0. f1 is defined as 0
    # rather than dividing by 0.
    bundle_fully_outside = (overlap == 0 and overreach == 1)
    if bundle_fully_outside:
        return 0.

    # Recall    = TP / (TP + FN) = |B inter A| / |A| = overlap
    # Precision = TP / (TP + FP) = |B inter A| / |B|
    #           = 1 - |B except A| / |B| = 1 - overreach
    recall, precision = overlap, 1.0 - overreach

    # Harmonic mean of precision and recall.
    return 2.0 * (precision * recall) / (precision + recall)
[docs]
def compute_f1_overlap_overreach(current_vb_voxels, gt_mask, dimensions):
    """
    Compute f1, OL and OR/ORn based on a ground truth mask.

    Any non-zero voxel is considered as belonging to the mask / to the
    bundle, for both inputs, so that TP + FN is always the size of the ground
    truth and TP + FP is always the size of the recovered bundle.

    Parameters
    ----------
    current_vb_voxels: 3D array
        The voxels touched by at least one streamlines for a given bundle.
    gt_mask: 3D array
        The ground truth mask.
    dimensions: np.array
        The nibabel dimensions of the data (3D). Kept for backward
        compatibility; the counts are computed directly from the two arrays.

    Returns
    -------
    f1: float
        The f1 score.
    tp_nb_voxels: int
        The TP (true positive) count in number of voxels.
    fp_nb_voxels: int
        The FP (false positive) count in number of voxels.
        Hint: Divide it by the ground truth count to get the overreach, or
        by the recovered bundle count to get the ORn (scores used in the
        ismrm2015 tractography challenge).
    fn_nb_voxels: int
        The number of voxels from the gt_mask that have not been recovered;
        corresponds to the FN count (false negative).
    overlap: float
        TP divided by the ground truth count (i.e. TP + FN), as a fraction.
        (Or 0 if the ground truth mask is empty).
    overreach_pct_gt: float
        The overreach, normalized by the ground truth area. (Or 0 if the
        ground truth mask is empty).
    overreach_pct_vs: float
        The overreach, normalized by the recovered bundle's area. (Or 0 if
        no streamline have been recovered for this bundle).
    """
    # Boolean views: A = ground truth, B = recovered bundle.
    in_gt = np.asarray(gt_mask) != 0
    in_bundle = np.asarray(current_vb_voxels) != 0

    # Counting directly on boolean expressions; no need to allocate
    # intermediate volumes.
    # True positive = |B inter A|
    tp_nb_voxels = int(np.count_nonzero(in_gt & in_bundle))
    # False positive = |B except A|
    fp_nb_voxels = int(np.count_nonzero(~in_gt & in_bundle))
    # False negative = |A except B|
    fn_nb_voxels = int(np.count_nonzero(in_gt & ~in_bundle))

    # Same as np.count_nonzero(gt_mask)
    gt_total_nb_voxels = tp_nb_voxels + fn_nb_voxels
    # Same as np.count_nonzero(current_vb_voxels)
    nb_voxels_total = tp_nb_voxels + fp_nb_voxels

    # Overlap = |B inter A| / |A|
    # Overreach: two versions are sometimes used.
    # |B except A| / |A|  or |B except A| / |B|
    if gt_total_nb_voxels == 0:
        # Values are not defined for an empty ground truth; using 0, as is
        # done for overreach_pct_vs when no streamline was recovered.
        logging.warning("Ground truth mask is empty: overlap and OR_pct_gt "
                        "are not defined and were set to 0.")
        overlap = 0
        overreach_pct_gt = 0
    else:
        overlap = tp_nb_voxels / gt_total_nb_voxels
        overreach_pct_gt = fp_nb_voxels / gt_total_nb_voxels

    if nb_voxels_total == 0:
        overreach_pct_vs = 0
    else:
        overreach_pct_vs = fp_nb_voxels / nb_voxels_total

    # f1 score (=dice)
    f1 = compute_f1_score(overlap, overreach_pct_vs)

    return (f1, tp_nb_voxels, fp_nb_voxels, fn_nb_voxels,
            overlap, overreach_pct_gt, overreach_pct_vs)
[docs]
def get_binary_maps(sft):
    """
    Extract a mask from a bundle.

    Side effect: the tractogram is sent to vox space, corner origin, and is
    left in that state.

    Parameters
    ----------
    sft: StatefulTractogram
        Bundle.

    Returns
    -------
    bundles_voxels: numpy.ndarray
        Mask representing the bundle volume (int16, values 0 or 1).
    endpoints_voxels: numpy.ndarray
        Mask representing the bundle's endpoints (int16, values 0 or 1).
    """
    sft.to_vox()
    sft.to_corner()
    _, dimensions, _, _ = sft.space_attributes

    if len(sft) == 0:
        # Same dtype as the non-empty case below.
        return (np.zeros(dimensions, dtype=np.int16),
                np.zeros(dimensions, dtype=np.int16))

    # Binarizing BEFORE casting to int16: casting raw counts first would wrap
    # around for voxels crossed by more than 32767 streamlines, and such
    # voxels could end up as 0 or negative instead of 1.
    tract_counts = compute_tract_counts_map(sft.streamlines, dimensions)
    bundles_voxels = (np.asarray(tract_counts) > 0).astype(np.int16)

    endpoints_counts = get_endpoints_density_map(sft)
    endpoints_voxels = (np.asarray(endpoints_counts) > 0).astype(np.int16)

    return bundles_voxels, endpoints_voxels
[docs]
def compute_tractometry(
        vb_sft_list, wpc_sft_list, ib_sft_list, nc_sft, args,
        bundles_names, gt_masks, dimensions, ib_names):
    """
    Tractometry stats: First in terms of connections (NC, IC, VS, WPC), then
    in terms of volume (OL, OR, Dice score)

    Parameters
    ----------
    vb_sft_list: list
        One entry per bundle in bundles_names: the valid streamlines of that
        bundle, or None if the bundle was not found.
    wpc_sft_list: list or None
        One entry per bundle: the wrong path connections of that bundle (or
        None). May be None altogether.
    ib_sft_list: list or None
        One entry per name in ib_names: the invalid connections.
    nc_sft: tractogram or None
        The no-connection streamlines.
    args: object
        Must have the attributes compute_ic and save_wpc_separately (bool).
    bundles_names: list of str
        Name of each bundle.
    gt_masks: list
        One entry per bundle: the ground truth mask (3D array), or None to
        skip the volume scoring of that bundle.
    dimensions: np.array
        The dimensions of the data (3D), forwarded to
        compute_f1_overlap_overreach.
    ib_names: list of str
        Name of each invalid bundle.

    Returns
    -------
    final_results: dict
        Global counts and ratios ("total_streamlines", "VB", "VS",
        "VS_ratio", "IS", "IS_ratio", plus IC/NC and WPC entries depending on
        args), a "bundle_wise" dict of per-bundle results and, when at least
        one bundle had a gt_mask, the "mean_OL", "mean_OR_gt", "mean_OR_vs"
        and "mean_f1" averages. Ratios are 0 when there is no streamline at
        all.
    """
    logging.info("Computing tractometry")

    def _nb_streamlines(sft):
        # Bundles that were not found are stored as None.
        return len(sft) if sft is not None else 0

    # ------------
    # Stats in terms of connections
    # ------------
    vs_per_bundle = [_nb_streamlines(x) for x in vb_sft_list]
    vb_count = sum(1 for nb in vs_per_bundle if nb > 0)
    vs_count = sum(vs_per_bundle)

    wpc_per_bundle = [_nb_streamlines(x) for x in wpc_sft_list] \
        if wpc_sft_list is not None else []
    wpb_count = sum(1 for nb in wpc_per_bundle if nb > 0)
    wpc_count = sum(wpc_per_bundle)

    if ib_sft_list is None:
        ib_sft_list = []
    ic_per_ib_bundle = [_nb_streamlines(x) for x in ib_sft_list]
    ib_count = sum(1 for nb in ic_per_ib_bundle if nb > 0)
    ic_count = sum(ic_per_ib_bundle)

    nc_count = _nb_streamlines(nc_sft)
    total_count = vs_count + wpc_count + ic_count + nc_count

    def _ratio(count):
        # Avoiding a division by 0 (NaN in the results) on empty inputs.
        return count / total_count if total_count > 0 else 0.0

    final_results = {
        "total_streamlines": int(total_count),
        "VB": int(vb_count),
        "VS": int(vs_count),
        "VS_ratio": _ratio(vs_count),
        "IS": int(ic_count + nc_count),  # ic_count = 0 if not args.compute_ic
        "IS_ratio": _ratio(ic_count + nc_count),
    }

    if args.compute_ic:
        final_results.update({
            "IB": int(ib_count),
            "IC": int(ic_count),
            "IC_ratio": _ratio(ic_count),
            "NC": int(nc_count),
            "NC_ratio": _ratio(nc_count)})

    if args.save_wpc_separately:
        final_results.update({
            "WPC": int(wpc_count),
            "WPC_bundle": int(wpb_count),
            "WPC_ratio": _ratio(wpc_count)})

    # ------------
    # Tractometry stats over volume: OL, OR, Dice score
    # ------------
    mean_overlap = 0.0
    mean_overreach_gt = 0.0
    mean_overreach_vs = 0.0
    mean_f1 = 0.0
    nb_bundles_in_stats = 0

    bundle_wise_dict = {}
    for i, bundle_name in enumerate(bundles_names):
        logging.debug("Scoring bundle {}".format(bundle_name))
        current_vb = vb_sft_list[i]
        gt_mask = gt_masks[i]

        # Using the count computed above: current_vb may be None.
        bundle_results = {"VS": vs_per_bundle[i]}
        bundle_wise_dict[bundle_name] = bundle_results

        if gt_mask is None:
            bundle_results.update({"Scoring skipped": "No gt_mask provided"})
            continue

        # From here, the bundle is included in the means (with zeros if it
        # is empty).
        nb_bundles_in_stats += 1

        if vs_per_bundle[i] == 0:
            logging.debug("   Empty bundle or bundle not found.")
            bundle_results.update({
                "TP": 0, "FP": 0, "FN": 0,
                "OL": 0, "OR_pct_gt": 0, "OR_pct_vs": 0, "f1": 0,
                "endpoints_OL": 0, "endpoints_OR": 0
            })
            continue

        # Getting the recovered mask
        current_vb_voxels, current_vb_endpoints_voxels = get_binary_maps(
            current_vb)

        (f1, tp_nb_voxels, fp_nb_voxels, fn_nb_voxels,
         overlap, overreach_pct_gt, overreach_pct_vs) = \
            compute_f1_overlap_overreach(
                current_vb_voxels, gt_mask, dimensions)

        # Endpoints coverage: number of endpoint voxels inside / outside of
        # the ground truth mask.
        # todo. What is this? Useful?
        endpoints_overlap = np.count_nonzero(
            gt_mask * current_vb_endpoints_voxels)
        endpoints_overreach = np.count_nonzero(
            (gt_mask == 0) & (current_vb_endpoints_voxels >= 1))

        bundle_results.update({
            "TP": tp_nb_voxels,
            "FP": fp_nb_voxels,
            "FN": fn_nb_voxels,
            "OL": overlap,
            "OR_pct_vs": overreach_pct_vs,
            "OR_pct_gt": overreach_pct_gt,
            "f1": f1,
            "endpoints_OL": endpoints_overlap,
            "endpoints_OR": endpoints_overreach
        })

        # WPC
        if args.save_wpc_separately:
            wpc_sft = wpc_sft_list[i] if wpc_sft_list is not None else None
            if wpc_sft is not None and len(wpc_sft) > 0:
                current_wpc_voxels, _ = get_binary_maps(wpc_sft)

                # We could add an option to include wpc streamlines to the
                # overreach count. But it seems more natural to exclude wpc
                # streamlines from any count. Separating into a different
                # statistic dict. Else, user may simply not include a "ALL"
                # mask, there won't be any wpc.
                # Scored on the WPC voxels (not on the valid bundle's).
                (_, wpc_tp, wpc_fp, _, wpc_overlap,
                 wpc_overreach_gt, wpc_overreach_vs) = \
                    compute_f1_overlap_overreach(
                        current_wpc_voxels, gt_mask, dimensions)
                wpc_results = {
                    "Count": len(wpc_sft),
                    "TP": wpc_tp,
                    "FP": wpc_fp,
                    "OL": wpc_overlap,
                    "OR_pct_vs": wpc_overreach_vs,
                    "OR_pct_gt": wpc_overreach_gt,
                }
                bundle_results.update({"WPC": wpc_results})
            else:
                bundle_results.update({"WPC": None})

        mean_overlap += bundle_results["OL"]
        mean_overreach_gt += bundle_results["OR_pct_gt"]
        mean_overreach_vs += bundle_results["OR_pct_vs"]
        mean_f1 += bundle_results["f1"]

    if args.compute_ic:
        # -----------
        # False connections stats: number of voxels
        # -----------
        ic_results = {}
        for ib_name, current_ib in zip(ib_names, ib_sft_list):
            if current_ib is not None and len(current_ib) > 0:
                current_ib_voxels, _ = get_binary_maps(current_ib)

                ic_results.update({ib_name: {
                    "IC": len(current_ib),
                    "nb_voxels": np.count_nonzero(current_ib_voxels)
                }})

        bundle_wise_dict.update({"IB": ic_results})

    final_results.update({"bundle_wise": bundle_wise_dict})

    if nb_bundles_in_stats > 0:
        final_results.update({
            "mean_OL": mean_overlap / nb_bundles_in_stats,
            "mean_OR_gt": mean_overreach_gt / nb_bundles_in_stats,
            "mean_OR_vs": mean_overreach_vs / nb_bundles_in_stats,
            "mean_f1": mean_f1 / nb_bundles_in_stats
        })

    return final_results