# -*- coding: utf-8 -*-
"""
Tractometry
-----------
Global connectivity metrics:
- Computed by default:
- VS: valid streamlines, belonging to a bundle (i.e. respecting all the
criteria for that bundle: endpoints, limit_mask, gt_mask).
- IS: invalid streamlines. All other streamlines. IS = IC + NC.
- Optional:
- WPC: wrong path connections, streamlines connecting correct ROIs but not
respecting the other criteria for that bundle. Such streamlines always
exist but they are only saved separately if specified in the options.
Else, they are merged back with the IS.
By definition, WPC are only computed if "limits masks" are provided.
- IC: invalid connections, streamlines joining an incorrect combination of
ROIs. Use carefully, quality depends on the quality of your ROIs and no
analysis is done on the shape of the streamlines.
- NC: no connections. Invalid streamlines minus invalid connections.
- Fidelity metrics:
- OL: Overlap. Percentage of ground truth voxels containing streamline(s)
for a given bundle.
- OR: Overreach. Amount of voxels containing streamline(s) where they
shouldn't be, for a given bundle. We compute two versions:
OR_pct_vs = divided by the total number of voxels covered by the bundle.
(percentage of the voxels touched by VS).
Values range between 0 and 100%. Values are not defined when we
recovered no streamline for a bundle, but we set the OR_pct_vs to 0
in that case.
OR_pct_gt = divided by the total size of the ground truth bundle mask.
Values could be higher than 100%.
- f1 score: equivalent to the Dice score.
"""
import logging
import numpy as np
from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map
from scilpy.tractograms.streamline_and_mask_operations import \
get_endpoints_density_map
[docs]
def compute_f1_score(overlap, overreach):
    """
    Compute the F1 score from an overlap and an overreach value.

    Both values are expected as fractions in [0, 1] (e.g. 0.25 for 25%),
    as returned by compute_f1_overlap_overreach.

    Parameters
    ----------
    overlap: float
        The overlap value (acts as the recall).
    overreach: float
        The overreach value, in its version normalized over the recovered
        bundle area (OR_pct_vs), not the version normalized over the ground
        truth.

    Returns
    -------
    f1_score: float
        The f1 score. Defined as 0 whenever precision + recall == 0, e.g.
        overlap = 0 and overreach = 1 (bundle found, but entirely outside
        the ground truth mask).

    Ref: https://en.wikipedia.org/wiki/F1_score
    """
    # Recall = True positive / (True positive + False negative)
    #        = |B inter A| / |A|
    #        = overlap
    recall = overlap

    # Precision = True positive / (True positive + False positive)
    #           = |B inter A| / |B|
    #           = 1 - |B except A| / |B|
    #           = 1 - overreach
    precision = 1.0 - overreach

    denominator = precision + recall
    # Guard on the denominator itself rather than on one specific
    # (overlap, overreach) pair, so that every degenerate input yields 0
    # instead of a division by zero.
    if denominator == 0:
        return 0.
    return 2.0 * (precision * recall) / denominator
[docs]
def compute_f1_overlap_overreach(current_vb_voxels, gt_mask, dimensions):
    """
    Compute f1, OL and OR/ORn based on a ground truth mask.

    Any non-zero voxel is considered "inside", both in the recovered bundle
    and in the ground truth mask. TP, FP and FN therefore partition the
    voxels consistently, whatever the masks' dtype or foreground value.

    Parameters
    ----------
    current_vb_voxels: 3D array
        The voxels touched by at least one streamline for a given bundle.
    gt_mask: 3D array
        The ground truth mask. Must have the same shape as
        current_vb_voxels.
    dimensions: np.array
        The nibabel dimensions of the data (3D). Kept for backward
        compatibility; the counts are computed directly from the masks.

    Returns
    -------
    f1: float
        The f1 score.
    tp_nb_voxels: int
        The TP (true positive) count in number of voxels.
    fp_nb_voxels: int
        The FP (false positive) count in number of voxels.
        Hint: Divide it by the ground truth count to get the overreach, or
        by the recovered bundle count to get the ORn (scores used in the
        ismrm2015 tractography challenge).
    fn_nb_voxels: int
        The number of voxels from the gt_mask that have not been recovered;
        corresponds to the FN count (false negative).
    overlap: float
        TP divided by the ground truth count (i.e. TP + FN), as a fraction.
    overreach_pct_gt: float
        The overreach, normalized by the ground truth area.
    overreach_pct_vs: float
        The overreach, normalized by the recovered bundle's area. (Or 0 if
        no streamline has been recovered for this bundle).

    Raises
    ------
    ZeroDivisionError
        If gt_mask contains no non-zero voxel (overlap is then undefined).
    """
    # Binarize both volumes the same way (non-zero == inside).
    recovered = np.asarray(current_vb_voxels) != 0
    ground_truth = np.asarray(gt_mask) != 0

    # True positive = |B inter A|
    tp_nb_voxels = np.count_nonzero(recovered & ground_truth)
    # False positive = |B except A|
    fp_nb_voxels = np.count_nonzero(recovered & ~ground_truth)
    # False negative = |A except B|
    fn_nb_voxels = np.count_nonzero(~recovered & ground_truth)

    # Same as np.count_nonzero(gt_mask)
    gt_total_nb_voxels = tp_nb_voxels + fn_nb_voxels
    # Same as np.count_nonzero(current_vb_voxels)
    nb_voxels_total = tp_nb_voxels + fp_nb_voxels

    # Overlap = |B inter A| / |A|
    overlap = tp_nb_voxels / gt_total_nb_voxels

    # Overreach: two versions are sometimes used.
    # |B except A| / |A| or |B except A| / |B|
    if nb_voxels_total == 0:
        # No recovered voxel: OR_pct_vs is undefined; report 0 by convention.
        overreach_pct_vs = 0
    else:
        overreach_pct_vs = fp_nb_voxels / nb_voxels_total
    overreach_pct_gt = fp_nb_voxels / gt_total_nb_voxels

    # f1 score (=dice)
    f1 = compute_f1_score(overlap, overreach_pct_vs)

    return (f1, tp_nb_voxels, fp_nb_voxels, fn_nb_voxels,
            overlap, overreach_pct_gt, overreach_pct_vs)
[docs]
def get_binary_maps(sft):
    """
    Extract binary masks (volume and endpoints) from a bundle.

    Note: this calls sft.to_vox() and sft.to_corner(), so the caller's
    object is left in voxel space with corner origin.

    Parameters
    ----------
    sft: StatefulTractogram
        Bundle.

    Returns
    -------
    bundles_voxels: numpy.ndarray (int16)
        Binary mask (0/1) of the voxels traversed by the bundle.
    endpoints_voxels: numpy.ndarray (int16)
        Binary mask (0/1) of the voxels containing the bundle's endpoints.
        Both masks are all zeros when the bundle is empty.
    """
    sft.to_vox()
    sft.to_corner()
    _, dimensions, _, _ = sft.space_attributes

    if len(sft) == 0:
        # Same dtype as the non-empty case, so callers get a consistent type.
        return (np.zeros(dimensions, dtype=np.int16),
                np.zeros(dimensions, dtype=np.int16))

    counts_map = compute_tract_counts_map(sft.streamlines, dimensions)
    endpoints_map = get_endpoints_density_map(sft)

    # Binarize BEFORE casting: casting raw counts to int16 first would wrap
    # any count above 32767 to a negative value, which would then never be
    # marked as 1.
    bundles_voxels = (np.asarray(counts_map) > 0).astype(np.int16)
    endpoints_voxels = (np.asarray(endpoints_map) > 0).astype(np.int16)

    return bundles_voxels, endpoints_voxels
[docs]
def compute_tractometry(
        vb_sft_list, wpc_sft_list, ib_sft_list, nc_sft, args,
        bundles_names, gt_masks, dimensions, ib_names):
    """
    Tractometry stats: first in terms of connections (NC, IC, VS, WPC), then
    in terms of volume (OL, OR, Dice score).

    Parameters
    ----------
    vb_sft_list: list
        One bundle of valid streamlines (or None) per entry of bundles_names.
    wpc_sft_list: list or None
        One bundle of wrong-path connections (or None) per entry of
        bundles_names. May itself be None.
    ib_sft_list: list
        One bundle of invalid connections (or None) per entry of ib_names.
    nc_sft: StatefulTractogram or None
        The streamlines with no connection (NC).
    args: object
        Parsed options (presumably an argparse.Namespace); the attributes
        `compute_ic` and `save_wpc_separately` are read.
    bundles_names: list of str
        Names of the valid bundles, aligned with vb_sft_list and gt_masks.
    gt_masks: list
        One ground truth mask (3D array) or None per bundle. Bundles without
        a mask are not scored on volume.
    dimensions: np.array
        The dimensions of the data (3D).
    ib_names: list of str
        Names of the invalid bundles, aligned with ib_sft_list.

    Returns
    -------
    final_results: dict
        Global counts and ratios, optional IC/NC/WPC stats, per-bundle stats
        under "bundle_wise" and, when at least one bundle was scored, the
        mean OL / OR / f1 over the scored bundles.
    """
    logging.info("Computing tractometry")

    # --- Connection counts ---
    vs_per_bundle = [_count_streamlines(sft) for sft in vb_sft_list]
    vb_count = np.count_nonzero(vs_per_bundle)
    vs_count = np.sum(vs_per_bundle)

    if wpc_sft_list is not None:
        wpc_per_bundle = [_count_streamlines(sft) for sft in wpc_sft_list]
        wpb_count = np.count_nonzero(wpc_per_bundle)
        wpc_count = np.sum(wpc_per_bundle)
    else:
        wpb_count = 0
        wpc_count = 0

    ic_per_ib_bundle = [_count_streamlines(sft) for sft in ib_sft_list]
    ib_count = np.count_nonzero(ic_per_ib_bundle)
    ic_count = np.sum(ic_per_ib_bundle)

    nc_count = _count_streamlines(nc_sft)

    total_count = vs_count + wpc_count + ic_count + nc_count

    final_results = {
        "total_streamlines": int(total_count),
        "VB": int(vb_count),
        "VS": int(vs_count),
        "VS_ratio": vs_count / total_count,
        "IS": int(ic_count + nc_count),  # ic_count = 0 if not args.compute_ic
        "IS_ratio": (ic_count + nc_count) / total_count,
    }

    if args.compute_ic:
        final_results.update({
            "IB": int(ib_count),
            "IC": int(ic_count),
            "IC_ratio": ic_count / total_count,
            "NC": int(nc_count),
            "NC_ratio": nc_count / total_count})

    if args.save_wpc_separately:
        final_results.update({
            "WPC": int(wpc_count),
            "WPC_bundle": wpb_count,
            "WPC_ratio": wpc_count / total_count})

    # --- Tractometry stats over volume: OL, OR, Dice score ---
    mean_overlap = 0.0
    mean_overreach_gt = 0.0
    mean_overreach_vs = 0.0
    mean_f1 = 0.0
    nb_bundles_in_stats = 0

    bundle_wise_dict = {}
    for i, bundle_name in enumerate(bundles_names):
        logging.debug("Scoring bundle %s", bundle_name)
        current_vb = vb_sft_list[i]
        gt_mask = gt_masks[i]
        # vs_per_bundle already maps a missing (None) bundle to 0.
        bundle_results = {"VS": vs_per_bundle[i]}

        if gt_mask is None:
            bundle_results.update({"Scoring skipped": "No gt_mask provided"})
            bundle_wise_dict.update({bundle_name: bundle_results})
            continue

        if current_vb is None or len(current_vb) == 0:
            logging.debug("   Empty bundle or bundle not found.")
            bundle_results.update({
                "TP": 0, "FP": 0, "FN": 0,
                "OL": 0, "OR_pct_gt": 0, "OR_pct_vs": 0, "f1": 0,
                "endpoints_OL": 0, "endpoints_OR": 0
            })
            nb_bundles_in_stats += 1
            bundle_wise_dict.update({bundle_name: bundle_results})
            continue

        # Getting the recovered mask
        current_vb_voxels, current_vb_endpoints_voxels = get_binary_maps(
            current_vb)

        (f1, tp_nb_voxels, fp_nb_voxels, fn_nb_voxels,
         overlap, overreach_pct_gt, overreach_pct_vs) = \
            compute_f1_overlap_overreach(
                current_vb_voxels, gt_mask, dimensions)

        # Endpoints coverage: number of endpoint voxels inside / outside
        # the ground truth mask.
        endpoints_overlap = gt_mask * current_vb_endpoints_voxels
        endpoints_overreach_nb = np.count_nonzero(
            (gt_mask == 0) & (current_vb_endpoints_voxels >= 1))

        bundle_results.update({
            "TP": tp_nb_voxels,
            "FP": fp_nb_voxels,
            "FN": fn_nb_voxels,
            "OL": overlap,
            "OR_pct_vs": overreach_pct_vs,
            "OR_pct_gt": overreach_pct_gt,
            "f1": f1,
            "endpoints_OL": np.count_nonzero(endpoints_overlap),
            "endpoints_OR": endpoints_overreach_nb
        })

        # WPC
        if args.save_wpc_separately:
            wpc_sft = wpc_sft_list[i] if wpc_sft_list is not None else None
            if wpc_sft is not None and len(wpc_sft) > 0:
                current_wpc_voxels, _ = get_binary_maps(wpc_sft)

                # WPC streamlines are excluded from the VB overreach count
                # and reported in their own dict. (Alternatively, a user may
                # simply not include an "ALL" mask: there won't be any WPC.)
                # The WPC stats must be computed from the WPC voxels, not
                # from the VB voxels.
                (_, wpc_tp_nb_voxels, wpc_fp_nb_voxels, _, wpc_overlap,
                 wpc_overreach_pct_gt, wpc_overreach_pct_vs) = \
                    compute_f1_overlap_overreach(
                        current_wpc_voxels, gt_mask, dimensions)

                wpc_results = {
                    "Count": len(wpc_sft),
                    "TP": wpc_tp_nb_voxels,
                    "FP": wpc_fp_nb_voxels,
                    "OL": wpc_overlap,
                    "OR_pct_vs": wpc_overreach_pct_vs,
                    "OR_pct_gt": wpc_overreach_pct_gt,
                }
                bundle_results.update({"WPC": wpc_results})
            else:
                bundle_results.update({"WPC": None})

        mean_overlap += overlap
        mean_overreach_gt += overreach_pct_gt
        mean_overreach_vs += overreach_pct_vs
        mean_f1 += f1
        nb_bundles_in_stats += 1

        bundle_wise_dict.update({bundle_name: bundle_results})

    if args.compute_ic:
        # -----------
        # False connections stats: number of voxels
        # -----------
        ic_results = {}
        for i, ib_name in enumerate(ib_names):
            current_ib = ib_sft_list[i]
            if _count_streamlines(current_ib) > 0:
                current_ib_voxels, _ = get_binary_maps(current_ib)
                ic_results.update({ib_name: {
                    "IC": len(current_ib),
                    "nb_voxels": np.count_nonzero(current_ib_voxels)
                }})
        bundle_wise_dict.update({"IB": ic_results})

    final_results.update({"bundle_wise": bundle_wise_dict})

    if nb_bundles_in_stats > 0:
        final_results.update({
            "mean_OL": mean_overlap / nb_bundles_in_stats,
            "mean_OR_gt": mean_overreach_gt / nb_bundles_in_stats,
            "mean_OR_vs": mean_overreach_vs / nb_bundles_in_stats,
            "mean_f1": mean_f1 / nb_bundles_in_stats
        })

    return final_results


def _count_streamlines(sft):
    """Return len(sft), or 0 when sft is None (bundle not found)."""
    return len(sft) if sft is not None else 0