Source code for scilpy.viz.plot


import matplotlib
import matplotlib.pyplot as plt
import numpy as np


def plot_metrics_stats(means, stds, title=None, xlabel=None, ylabel=None,
                       figlabel=None, fill_color=None, display_means=False):
    """
    Plots the mean of a metric along n points with the standard deviation.

    Parameters
    ----------
    means: Numpy 1D (or 2D) array of size n
        Mean of the metric along n points. Any array-like is accepted.
        When 2D, the values are averaged over axis 1 to get the mean line.
    stds: Numpy 1D (or 2D) array of size n
        Standard deviation of the metric along n points. Same layout as
        means; when means is 2D, stds is averaged over axis 1 as well.
    title: string
        Title of the figure.
    xlabel: string
        Label of the X axis.
    ylabel: string
        Label of the Y axis (suggestion: the metric name).
    figlabel: string
        Label of the figure (only metadata in the figure object returned).
    fill_color: string
        Hexadecimal RGB color filling the region between mean ± std. The
        hexadecimal RGB color should be formatted as #RRGGBB
    display_means: bool
        Display the subjects means as semi-transparent line (only used
        when means is 2D; one line per column).

    Return
    ------
    The figure object. It is closed in pyplot before being returned.
    """
    matplotlib.style.use('ggplot')

    # Convert once so that lists / tuples work the same as ndarrays
    # (`.ndim`, `.shape` and 2D slicing are all used below).
    means = np.asarray(means)
    stds = np.asarray(stds)

    fig, ax = plt.subplots()

    # Set optional information to the figure, if required.
    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if figlabel is not None:
        fig.set_label(figlabel)

    if means.ndim > 1:
        mean = np.average(means, axis=1)
        std = np.average(stds, axis=1)
        alpha = 0.5
    else:
        mean = means.ravel()
        std = stds.ravel()
        alpha = 0.9

    # Points are numbered from 1 on the X axis.
    dim = np.arange(1, len(mean) + 1, 1)

    # Only force one tick per point when there are few enough to be legible.
    if len(mean) <= 20:
        ax.xaxis.set_ticks(dim)

    ax.set_xlim(0, len(mean) + 1)

    if means.ndim > 1 and display_means:
        for i in range(means.shape[-1]):
            ax.plot(dim, means[:, i], color="k", linewidth=1,
                    solid_capstyle='round', alpha=0.1)

    # Plot the mean line.
    ax.plot(dim, mean, color="k", linewidth=5, solid_capstyle='round')

    # Plot the std. Drawn on `ax` explicitly rather than through pyplot's
    # implicit "current axes".
    ax.fill_between(dim, mean - std, mean + std,
                    facecolor=fill_color, alpha=alpha)

    # Unregister the figure from pyplot so repeated calls do not pile up
    # open figures; the returned object can still be saved by the caller.
    plt.close(fig)
    return fig
def plot_residuals(data_diff, mask, R_k, q1, q3, iqr, residual_basename):
    """
    Plots residual statistics for DWI.

    Parameters
    ----------
    data_diff: np.ndarray
        The 4D residuals between the DWI and predicted data. The last axis
        indexes the DWI volumes.
    mask : Numpy 3D array or None
        Mask array indicating the region of interest for computing
        residuals. If None, residuals are computed for the entire dataset
        and the outlier panel is not drawn.
    R_k : Numpy 1D array
        Mean residual values for each DWI volume.
    q1 : Numpy 1D array
        First quartile values for each DWI volume.
    q3 : Numpy 1D array
        Third quartile values for each DWI volume.
    iqr : Numpy 1D array
        Interquartile range (Q3 - Q1) for each DWI volume.
    residual_basename : string
        Basename for saving the output plot file. The file will be saved as
        '<residual_basename>_residuals_stats.png'.

    Returns
    -------
    None
        The function generates a plot, saves it as a PNG file and closes
        the figure.
    """
    nb_volumes = data_diff.shape[-1]

    # Stats are computed manually and plotted using bxp (they could also be
    # computed with cbook.boxplot_stats or pyplot.boxplot). The median is
    # approximated by the midpoint of the two quartiles. Each volume gets
    # its own dict (and its own empty 'fliers' list).
    stats = [{'label': k,
              'mean': R_k[k],
              'med': (q1[k] + q3[k]) / 2,
              'q1': q1[k],
              'q3': q3[k],
              'iqr': iqr[k],
              'whislo': q1[k] - 1.5 * iqr[k],
              'whishi': q3[k] + 1.5 * iqr[k],
              'fliers': []}
             for k in range(nb_volumes)]

    # Outliers are observations that fall below Q1 - 1.5(IQR) or above
    # Q3 + 1.5(IQR). We check if a voxel is an outlier only if we have a
    # mask, else we are biased.
    # NOTE(review): the mask is only used for the voxel count; outliers are
    # counted over the whole volume of data_diff - presumably data_diff is
    # already restricted to the mask by the caller; confirm.
    percent_outliers = np.zeros(nb_volumes, dtype=np.float32)
    if mask is not None:
        nb_voxels = np.count_nonzero(mask)
        # An empty mask would lead to a division by zero; leave 0% then.
        if nb_voxels > 0:
            for k, vol_stats in enumerate(stats):
                x = data_diff[..., k]
                outliers = ((x < vol_stats['whislo']) |
                            (x > vol_stats['whishi']))
                percent_outliers[k] = np.sum(outliers) / nb_voxels * 100
        # Open question: what is "too many outliers"? Maybe
        # mean(all_means) +- 3SD? For now no warning is emitted and users
        # judge from the figure.

    if mask is None:
        fig, axe = plt.subplots(nrows=1, ncols=1, squeeze=False)
    else:
        # Default is [6.4, 4.8]. Increasing width to see better.
        fig, axe = plt.subplots(nrows=1, ncols=2, squeeze=False,
                                figsize=[10, 4.8])

    xticks = list(range(0, q1.shape[0], 5))

    medianprops = dict(linestyle='-', linewidth=2.5, color='firebrick')
    meanprops = dict(linestyle='-', linewidth=2.5, color='green')

    # bxp places boxes at 1..N by default. Use 0-based positions so that
    # each box sits on its own volume index, matching the tick labels set
    # below and the X axis of the outlier panel.
    axe[0, 0].bxp(stats, positions=np.arange(nb_volumes), showmeans=True,
                  meanline=True, showfliers=False,
                  medianprops=medianprops, meanprops=meanprops)
    axe[0, 0].set_xlabel('DW image')
    axe[0, 0].set_ylabel('Residuals per DWI volume. Red is median,\n'
                         'green is mean. Whiskers are 1.5*interquartile')
    axe[0, 0].set_title('Residuals')
    axe[0, 0].set_xticks(xticks)
    axe[0, 0].set_xticklabels(xticks)

    if mask is not None:
        axe[0, 1].plot(range(nb_volumes), percent_outliers)
        axe[0, 1].set_xticks(xticks)
        axe[0, 1].set_xticklabels(xticks)
        axe[0, 1].set_xlabel('DW image')
        axe[0, 1].set_ylabel('Percentage of outlier voxels')
        axe[0, 1].set_title('Outliers')

    # Save through the figure object (not pyplot's current figure), then
    # close it so repeated calls do not accumulate open figures.
    fig.savefig(residual_basename + '_residuals_stats.png')
    plt.close(fig)