import math
import numpy as np
import nibabel as nib
import pyvista as pv
import urllib.request
import importlib_resources
from pathlib import Path
from typing import Any, cast
from scipy.io import loadmat
from matplotlib import cm as mpl_cm
from matplotlib.colors import ListedColormap
nib.imageglobals.logger.setLevel(40)
_DISCRETE_CMAP_REF_MAX: dict[str, int] = {}
# Parcellation
[docs]
def parcellate(dtseries :str|np.ndarray|nib.cifti2.cifti2.Cifti2Image|None=None,
atlas : str="schaefer",
resolution : int=100,
subcortical : None|str=None,
networks : int=7,
kong : bool=False,
standardize : bool=True,
method = np.mean,
return_labels : bool=False,
debug : bool=False
) -> np.ndarray | tuple[np.ndarray|None, list[str], np.ndarray]:
"""
Parcellate cifti data (.dtseries.nii) using a given atlas.
Atlases for many different parameter combinations are available and will be downloaded on demand.
If the atlas for the parameter combination is not available, a ValueError is raised.
References
----------
- Schaefer, Glasser, Gordon (+ Tian subcortical): https://github.com/yetianmed/subcortex
- Schaefer + Yan (cortical only): https://github.com/ThomasYeoLab/CBIG
Note: Any cortical atlas can be used on its own or combined with any Tian scale. Combinations
without an available file are assembled at runtime from the cortical atlas plus the Tian
subcortical block reused from the Gordon+Tian atlas.
Parameters
----------
dtseries : str, np.ndarray nibabel.cifti2.cifti2.Cifti2Image
string containing a path, array containing vertex data, or nibabel cifti image object
atlas : string
Name of the atlas to use for parcellation. Available options are:
- "schaefer": Schaefer et al. (2018) atlas
- "yan": Yan et al. (2023) homotopic atlas
- "glasser": Glasser et al. (2016) atlas
- "gordon": Gordon et al. (2016) atlas
resolution : int
Number of parcels in the atlas. Only used with the Schaefer and Yan atlases.
Available options are: 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000.
subcortical : None or string
If a string containing the scale is provided, the Tian subcortical parcels are included.
Available options are: None, 'S1' (16 ROIs), 'S2' (32 ROIs), 'S3' (50 ROIs), 'S4' (54 ROIs).
Works with every atlas at any scale: None gives a cortical-only parcellation, a scale
appends the Tian structures (keyed first). Cortical parcel counts: schaefer/yan = resolution,
glasser = 360, gordon = 333.
networks : int
Number of networks in the atlas. Only used with Schaefer atlas.
Available options are: 7, 17
kong : bool
Use the Kong 2022 version of the Schaefer atlas (only for Schaefer cortical atlas with 17 networks).
Reference: https://doi.org/10.1093/cercor/bhab101
standardize : bool
Standardize the time series to zero (temporal) mean and unit variance before parcellation.
method : function
Aggregation function to use for parcellation. Default (and the only tested function) is np.mean.
debug : bool
Flag to provide additional debugging information. Default is False.
Returns
-------
ts_parc : np.ndarray or tuple
If ``return_labels`` is False (default):
- ts_parc : (T, P) np.ndarray
Parcellated time series data.
If ``return_labels`` is True:
- ts_parc : (T, P) np.ndarray
Parcellated time series data. None if no input data was provided.
- node_labels : list of str
Label name for each parcel.
- vertex_labels : np.ndarray
ROI index for each vertex in the CIFTI file.
"""
if isinstance(dtseries, nib.cifti2.cifti2.Cifti2Image):
ts = dtseries.get_fdata()
elif isinstance(dtseries, np.ndarray) or isinstance(dtseries, np.memmap):
ts = dtseries
elif isinstance(dtseries, str):
data = nib.load(dtseries)
ts = data.get_fdata()
elif dtseries is None:
pass
else:
print("Error: Input must be either a string to a CIFTI file, a nibabel CIFTI object, " \
"a numpy array containing vertex data, or None (to return only atlas labels).")
return
# Check provided parameters
if atlas not in ["schaefer", "glasser", "gordon", "yan"]:
raise ValueError(f"Atlas '{atlas}' not available. Please choose from ['schaefer', 'glasser', 'gordon', 'yan'].")
if resolution not in [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]:
raise ValueError(f"Resolution '{resolution}' not available. Please choose from [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000].")
if networks not in [7, 17]:
raise ValueError(f"Networks '{networks}' not available. Please choose from [7, 17].")
if subcortical not in [None, 'S1', 'S2', 'S3', 'S4']:
raise ValueError(f"Subcortical scale '{subcortical}' not available. Please choose from [None, 'S1', 'S2', 'S3', 'S4'].")
if kong not in [True, False]:
raise ValueError(f"Kong flag must be a boolean value (True or False).")
# Combinations which automatically adjust parameters with a warning instead of raising an error.
if atlas == "schaefer" and networks == 7 and kong is True:
print(f"[WARN] Schaefer Kong version is only available with 17 networks. Networks were set to 17.")
networks = 17
# Get the atlas
vertex_labels, keys, node_labels, _ = _get_atlas(atlas=atlas, resolution=resolution, networks=networks, subcortical=subcortical, kong=kong, debug=debug)
# If we have no input data, return the labels now
if dtseries is None:
return (None, node_labels, vertex_labels)
# Cortical-only atlases are 64984-vertex surface maps that include the medial wall,
# so the medial-wall columns have to be inserted into the (59412-vertex) cortical data.
if subcortical is None:
with importlib_resources.path("comet.data.atlas", "fs_LR_32k_medial_mask.mat") as maskdir:
medial_mask = loadmat(maskdir)['medial_mask'].squeeze().astype(bool)
idx = np.where(medial_mask == 0)[0]
# prepare idices and insert them into the HCP data
for i, value in enumerate(idx):
idx[i] = value - i
cortical_vertices = 59412 # HCP data has 59412 cortical vertices
ts = ts[:,:cortical_vertices]
ts = np.insert(ts, idx, np.nan, axis=1)
# Standardize before parcellation
if standardize:
ts = _stdize(ts)
# Parcellation
ts_parc = np.zeros((len(ts), len(node_labels)), dtype=ts.dtype)
i = 0
for k, lab in zip(keys, node_labels):
if k == 0:
continue
mask = (vertex_labels == k)
if not np.any(mask):
ts_parc[:, i] = 0
i += 1
print(f"[WARN] ROI {lab} is empty and was set to zero.")
continue
ts_parc[:, i] = method(ts[:, mask], axis=1)
i += 1
return (ts_parc, node_labels, vertex_labels) if return_labels else ts_parc
[docs]
def get_networks(labels: list[str]) -> tuple[list[str], np.ndarray, list[str], dict[str, int]]:
"""
Extract network information for Schaefer-Yeo parcellations.
Parameters
----------
labels : list of str
Atlas parcel labels obtained from ``cifti.parcellate()``.
Returns
-------
networks : list of str
Network label per parcel (length N).
ids : np.ndarray
Integer network ids per parcel (length N).
hemisphere : list of str
Hemisphere label per parcel ('LH' or 'RH'; length N).
network_map : dict[str, int]
Mapping from network name to integer id.
Raises
------
ValueError
If network labels cannot be inferred from the atlas labels.
"""
if len(labels) == 0:
raise ValueError("Empty label list.")
networks: list[str] = []
hemisphere: list[str] = []
for lab in labels:
# Cortical Schaefer Yeo-style labels
if ("networks_" in lab) or ("Networks_" in lab):
parts = lab.split("_")
if len(parts) < 3:
raise ValueError(f"Unexpected Schaefer label format: {lab}")
hemisphere.append(parts[1])
networks.append(parts[2])
# Simple subcortical extension
elif lab.endswith("-lh") or lab.endswith("-rh"):
hemisphere.append("LH" if lab.endswith("-lh") else "RH")
networks.append("Subcortical")
# Other atlases without canonical network labels will raise errors
elif lab.startswith("Parcel_"):
raise ValueError("Error: Gordon atlas detected. No canonical network partition is available.")
elif lab.endswith("_ROI"):
raise ValueError("Error: Glasser/HCP-MMP atlas detected. No canonical network partition is available.")
else:
raise ValueError(f"Unknown atlas label format; cannot infer network assignments: {lab}")
# Map network names to integers (stable alphabetical order)
uniq = sorted(set(networks))
if "Subcortical" in uniq:
uniq.remove("Subcortical")
network_map = {"Subcortical": 0}
for i, name in enumerate(uniq, start=1):
network_map[name] = i
else:
network_map = {name: i + 1 for i, name in enumerate(uniq)}
ids = np.array([network_map[n] for n in networks], dtype=int)
return networks, ids, hemisphere, network_map
def _get_atlas(atlas, resolution, networks, subcortical, kong, debug) -> tuple:
"""
Helper function: Get and prepare a CIFTI-2 atlas for parcellation.
Parameters
----------
**See parcellate() for details.**
Returns
-------
tuple
A tuple containing:
- rois : np.ndarray
ROI indices for each vertex.
- keys : np.ndarray
Keys of the atlas.
- labels : list
Labels of the atlas.
- rgba : list
RGBA values of each label.
"""
# Cortical sources. Schaefer/Yan ship standalone cortical dlabels; Glasser/Gordon are only
# published as Tian-combined files, so their cortex is extracted from those. The Tian
# subcortical block is always taken from the Gordon+Tian file and appended for any scale,
# so every atlas works both cortical-only and combined with any Tian scale.
base_urls = {
"schaefer_c": "https://github.com/ThomasYeoLab/CBIG/raw/master/stable_projects/brain_parcellation/Schaefer2018_LocalGlobal/Parcellations/HCP/fslr32k/cifti/Schaefer2018_{parcels}Parcels_{kong}{networks}Networks_order.dlabel.nii",
"yan": "https://github.com/ThomasYeoLab/CBIG/raw/master/stable_projects/brain_parcellation/Yan2023_homotopic/parcellations/HCP/fsLR32k/yeo{networks}/{parcels}Parcels_Yeo2011_{networks}Networks.dlabel.nii",
"gordon": "https://github.com/yetianmed/subcortex/raw/master/Group-Parcellation/3T/Cortex-Subcortex/Gordon333.32k_fs_LR_Tian_Subcortex_{subcortical}.dlabel.nii",
"glasser": "https://github.com/yetianmed/subcortex/raw/master/Group-Parcellation/3T/Cortex-Subcortex/Q1-Q6_RelatedValidation210.CorticalAreas_dil_Final_Final_Areas_Group_Colors.32k_fs_LR_Tian_Subcortex_{subcortical}.dlabel.nii",
}
def _download_and_load(filename, url):
"""Download (if needed) and load a CIFTI-2 dlabel atlas."""
with importlib_resources.path("comet.data.atlas", filename) as atlas_path:
if not atlas_path.exists():
urllib.request.urlretrieve(url, atlas_path)
print(f"Atlas not available. Downloading to: {atlas_path}")
return nib.load(str(atlas_path))
def _extract_labels(img):
"""Return (keys, labels, rgba) from the dlabel label table, skipping background (key 0)."""
# Usually for dlabel.nii files we have the following header structure
# axis 0: LabelAxis
# axis 1: BrainModelAxis
named_map = list(img.header.get_index_map(0).named_maps)[0]
keys, labels, rgba = [], [], []
for key, label in named_map.label_table.items():
if key == 0:
continue # skip background
keys.append(key)
labels.append(label.label)
rgba.append(label.rgba)
return np.asarray(keys), labels, rgba
def _print_brainmodels(img):
for idx, (name, _slice, _bm) in enumerate(img.header.get_axis(1).iter_structures()):
print(idx, str(name), _slice, _bm)
def _load_medial_mask():
with importlib_resources.path("comet.data.atlas", "fs_LR_32k_medial_mask.mat") as maskdir:
return loadmat(maskdir)["medial_mask"].squeeze().astype(bool)
def _cortical():
"""Cortical parcellation as a 64984-vertex surface map with contiguous keys 1..N.
Glasser and Gordon are downloaded as Tian-combined files, so we fetch the S1 file, drop the
leading Tian structures, renumber the cortical parcels to 1..N, and re-insert the medial wall.
"""
if atlas == "schaefer":
url = base_urls["schaefer_c"].format(parcels=resolution, networks=networks, kong="Kong2022_" if kong else "")
elif atlas == "yan":
url = base_urls["yan"].format(parcels=resolution, networks=networks)
else: # glasser, gordon
url = base_urls[atlas].format(subcortical="S1")
img = _download_and_load(url.split("/")[-1], url)
rois_full = img.dataobj[0].astype(int).squeeze()
keys, labels, rgba = _extract_labels(img)
if atlas in ["schaefer", "yan"]:
return rois_full, keys, labels, rgba, img # already a 64984-vertex cortical map
# Combined file: subcortex is keyed first (1..M), cortex after. Keep the cortex only.
n_sub = len(set(np.unique(rois_full[59412:]).tolist()) - {0})
cort_keys_sorted = sorted(int(k) for k in keys if int(k) > n_sub)
remap = {old: i + 1 for i, old in enumerate(cort_keys_sorted)} # -> 1..N
lut = np.zeros(int(rois_full.max()) + 1, dtype=int)
for old, new in remap.items():
lut[old] = new
medial_mask = _load_medial_mask()
rois = np.zeros(medial_mask.size, dtype=int) # 64984
rois[medial_mask] = lut[rois_full[:59412]] # re-insert the medial wall
key_to_idx = {int(k): i for i, k in enumerate(keys)}
return (rois,
np.asarray([remap[k] for k in cort_keys_sorted]),
[labels[key_to_idx[k]] for k in cort_keys_sorted],
[rgba[key_to_idx[k]] for k in cort_keys_sorted],
img)
cort_rois, cort_keys, cort_labels, cort_rgba, cort_img = _cortical()
# Cortical-only: return the 64984-vertex surface map.
if subcortical is None:
if debug:
_print_brainmodels(cort_img)
return (cort_rois, cort_keys, cort_labels, cort_rgba)
# Combined: append the Tian subcortical block taken from the Gordon+Tian file
sub_url = base_urls["gordon"].format(subcortical=subcortical)
sub_img = _download_and_load(sub_url.split("/")[-1], sub_url)
sub_rois = sub_img.dataobj[0].astype(int).squeeze()[59412:] # 31870 subcortical grayordinates
sub_keys, sub_labels, sub_rgba = _extract_labels(sub_img)
sub_present = sorted(int(k) for k in sub_keys if np.any(sub_rois == k))
n_sub = len(sub_present)
sub_remap = {old: i + 1 for i, old in enumerate(sub_present)}
sub_rois_new = np.zeros_like(sub_rois)
for old, new in sub_remap.items():
sub_rois_new[sub_rois == old] = new
cort_rois59412 = cort_rois[_load_medial_mask()] # drop the medial wall -> 59412 grayordinates
cort_rois_new = np.where(cort_rois59412 > 0, cort_rois59412 + n_sub, 0)
rois = np.concatenate([cort_rois_new, sub_rois_new]) # 91282 grayordinates (cortex-first)
key_to_idx = {int(k): i for i, k in enumerate(sub_keys)}
keys = np.concatenate([np.arange(1, n_sub + 1, dtype=cort_keys.dtype), cort_keys + n_sub])
labels = [sub_labels[key_to_idx[old]] for old in sub_present] + list(cort_labels)
rgba = [sub_rgba[key_to_idx[old]] for old in sub_present] + list(cort_rgba)
if debug:
print(f"[{atlas} cortical source]"); _print_brainmodels(cort_img)
print("[Gordon+Tian subcortical source]"); _print_brainmodels(sub_img)
return (rois, keys, labels, rgba)
def _stdize(ts) -> np.ndarray:
"""
Helper function: Standardize time series to zero (temporal) mean and unit standard deviation.
Parameters
----------
ts : np.ndarray
Time series data
Returns
-------
ts : np.ndarray
Standardized time series data
"""
mean = np.mean(ts, axis=0, keepdims=True)
std = np.std(ts, axis=0, keepdims=True)
std[std == 0] = 1.0
return (ts - mean) / std
# Plotting
[docs]
def surface_plot(node_values : np.ndarray|None=None,
vertex_labels: np.ndarray|None=None,
hemi: str="both",
surface: str="inflated",
view_names: tuple[str, str]=("medial", "lateral"),
ncols: int|None=None,
colwise: bool=True,
cmap: str="viridis",
border_color: None|str=None,
border_width: int=5,
distance: float=400.0,
size : list[int]|None=None,
labelsize : int=18,
colorbar: None|str="bottom",
colorbar_label : str|None=None,
interactive : bool=True,
fname : str|None=None):
"""
Plot cortical hemispheres with optional parcel border overlays.
Parameters
----------
node_values : ndarray or None
Parcel-level values (1D). If None, only surfaces are shown.
vertex_labels : ndarray or None
Vertex-to-parcel labels for both hemispheres (length 64984).
hemi : {"left", "right", "both"}
Hemisphere(s) to render.
surface : str
Surface type. Valid options are:
- "midthickness_orig"
- "midthickness_mni"
- "inflated"
- "very_inflated"
- "super_inflated"
- "sphere"
view_names : tuple containing one or multiple strings
Views to render per hemisphere. Options are:
- "lateral"
- "medial"
- "anterior"
- "posterior"
- "superior"
- "inferior"
ncols : int or None
Number of subplot columns.
colwise : bool
If True (default), fill subplots column-wise, else fill row-wise.
cmap : str
Colormap for node values.
border_color : str or None
Border color. If None, no border overlay is added.
border_mode : {"lines", "mask"}
Border rendering mode. "lines" draws smooth edge lines (default).
"mask" uses the legacy vertex-mask overlay method.
border_line_smoothing : int
Number of Chaikin smoothing iterations for line borders.
Only used when ``border_mode="lines"``.
border_line_decimate : int
Keep every Nth point along each border polyline before smoothing.
Use 1 to keep all points. Only used when ``border_mode="lines"``.
distance : float
Camera distance.
size : tuple[int, int] or None
Plotter window size.
colorbar : {"bottom", "right", None}
Shared colorbar placement outside data panels. If None, no colorbar is shown.
colorbar_label : str or None
Label for the colorbar.
interactive : bool
Show the plot in an interactive window (default is True).
fname : string or None
Save the plot (will consider manipulations done in the interactive window).
The name should contain the desired file type with the options being:
- Raster: ".png", ".jpeg", ".jpg", ".bmp", ".tif", ".tiff"
- Vectorised: ".svg", ".eps", ".ps", ".pdf", ".tex"
"""
# Input validation / normalization
if node_values is None:
print("Warning: node_values are required for data plotting. Proceeding with blank surfaces.")
else:
node_values = np.asarray(node_values, dtype=float)
vertex_labels = np.asarray(vertex_labels, dtype=np.int64)
if vertex_labels.ndim != 1:
raise ValueError("vertex_labels must be a 1D array of parcel IDs per vertex.")
if node_values.ndim != 1:
raise ValueError("node_values must be a 1D array of node/parcel values.")
if vertex_labels.size != 64984:
raise ValueError(f"vertex_labels must have length 64984. Got {vertex_labels.size}.")
# Get surface meshes and desired views
meshes = _get_surface(surface=surface) # get the surface mesh(es)
hemi_order = [hemi] if hemi in ("left", "right") else ["left", "right"] # which hemispheres to plot
panels = [(h, v) for h in hemi_order for v in view_names] # list of (hemisphere, view) pairs
# Define default camera positions
base_cams = {
"lateral": ("x", +1, (0, 0, 1)),
"medial": ("x", -1, (0, 0, 1)),
"anterior": ("y", -1, (0, 0, 1)),
"posterior": ("y", +1, (0, 0, 1)),
"superior": ("z", -1, (0, 1, 0)),
"inferior": ("z", +1, (0, 1, 0))
}
# Check validity of views
unknown_views = sorted({v for _, v in panels if v not in base_cams})
if unknown_views:
raise ValueError(f"Unknown view(s): {unknown_views}. Available: {list(base_cams.keys())}")
# Set up plot layout
n = len(panels)
if n == 0:
raise ValueError("No panels to plot. Check 'hemi' and available surfaces.")
if ncols is None:
ncols = min(3, n) if n <= 6 else int(math.ceil(math.sqrt(n)))
nrows = int(math.ceil(n / ncols))
if colwise:
ncols = int(math.ceil(n / nrows)) # guarantee enough columns
axis_idx = {"x": 0, "y": 1, "z": 2}
# Build per-vertex arrays for each hemisphere
scalars = {}
if node_values is not None and vertex_labels is not None:
def _map_to_vertices(parc):
out = np.full(32492, np.nan, dtype=float)
mask = (parc > 0) & (parc <= node_values.size)
out[mask] = node_values[parc[mask] - 1] # labels are 1-based
return out
lh_parc = vertex_labels[:32492]
rh_parc = vertex_labels[32492:]
scalars["left"] = _map_to_vertices(lh_parc)
scalars["right"] = _map_to_vertices(rh_parc)
# Use one shared colour scale across all plotted hemispheres
vals_list = [scalars[h][~np.isnan(scalars[h])] for h in scalars if h in meshes]
vals_list = [v for v in vals_list if v.size > 0]
clim = None
discrete_values = None
discrete_nlabels = 4
discrete_ref_max = None
if vals_list:
vals = np.concatenate(vals_list)
clim = (float(np.nanmin(vals)), float(np.nanmax(vals)))
if clim[0] == clim[1]:
eps = 1e-12 if clim[0] == 0.0 else abs(clim[0]) * 1e-12
clim = (clim[0] - eps, clim[1] + eps)
# Auto-detect categorical integer-valued maps (e.g., network IDs)
uniq_vals = np.unique(vals[np.isfinite(vals)])
# Background values are often encoded as 0 and rendered as NaN later.
# If nonzero categories exist, exclude 0 from discrete labeling.
if np.any(uniq_vals != 0):
uniq_vals = uniq_vals[uniq_vals != 0]
if uniq_vals.size > 0 and np.allclose(uniq_vals, np.round(uniq_vals)) and uniq_vals.size <= 32:
discrete_values = uniq_vals.astype(float)
discrete_nlabels = int(discrete_values.size)
# Use exact integer limits so scalar-bar ticks can be labeled 1..K.
clim = (float(np.min(discrete_values)), float(np.max(discrete_values)))
# Update/lookup reference max ID for this cmap across calls.
local_max = int(np.max(discrete_values))
prev_max = _DISCRETE_CMAP_REF_MAX.get(cmap, 0)
discrete_ref_max = max(prev_max, local_max)
_DISCRETE_CMAP_REF_MAX[cmap] = discrete_ref_max
# Colorbar needs an extra grid slot
show_shared_colorbar = (node_values is not None) and (clim is not None) and (colorbar is not None)
panel_nrows, panel_ncols = nrows, ncols
row_weights = None
col_weights = None
groups = None
if show_shared_colorbar and colorbar == "bottom":
plot_shape = (panel_nrows + 1, panel_ncols)
# Merge the full bottom row into one renderer and keep it narrow.
groups = [([panel_nrows], list(range(panel_ncols)))]
row_weights = [1.0] * panel_nrows + [0.2]
elif show_shared_colorbar and colorbar == "right":
plot_shape = (panel_nrows, panel_ncols + 1)
# Merge the full right column into one renderer and keep it narrow.
groups = [(list(range(panel_nrows)), [panel_ncols])]
col_weights = [1.0] * panel_ncols + [0.2]
else:
plot_shape = (panel_nrows, panel_ncols)
# Plotting
pv.global_theme.font.family = "times"
pl = pv.Plotter(shape=plot_shape, window_size=size, title="Comet Toolbox Surface Viewer", border=False, line_smoothing=True,
notebook=_in_notebook() and not interactive, off_screen=not interactive, row_weights=row_weights, col_weights=col_weights, groups=groups)
pl.enable_anti_aliasing("msaa")
# Loop through panels and plot each view
for i, (h, v) in enumerate(panels):
mesh = meshes[h]
center = mesh.center
axis, sign, up = base_cams[v]
row, col = (i % panel_nrows, i // panel_nrows) if colwise else (i // panel_ncols, i % panel_ncols)
# Swap lateral/medial for right hemisphere
if h == "right" and v in ("lateral", "medial"):
sign *= -1
# Plot the mesh
pl.subplot(row, col)
pl.add_text(f"{h} {v}", font_size=int(labelsize*0.7))
if node_values is None or vertex_labels is None:
pl.add_mesh(mesh, color="lightgray", smooth_shading=True)
colorbar_mapper = None
else:
values = scalars[h].copy()
# Make 0 values white by masking them
zero_mask = values == 0
values = values.astype(float)
values[zero_mask] = np.nan # treat zeros as NaN
# Plot data to the surface
mesh_kwargs = dict(scalars=values, clim=clim, nan_color="white", nan_opacity=1.0,
show_scalar_bar=False, interpolate_before_map=False, smooth_shading=True)
if discrete_values is not None:
# Keep ID->color mapping stable for single-ID subsets.
if int(discrete_values.size) == 1:
v = int(round(float(discrete_values[0])))
n_ref = max(int(discrete_ref_max) if discrete_ref_max is not None else v, v, 2)
cmap_ref = mpl_cm.get_cmap(cmap, n_ref)
mesh_kwargs["cmap"] = ListedColormap([cmap_ref(v - 1)])
mesh_kwargs["clim"] = (float(v) - 0.5, float(v) + 0.5)
else:
# Exact number of used categories/colors in the scalar bar.
mesh_kwargs["cmap"] = mpl_cm.get_cmap(cmap, int(discrete_values.size))
else:
mesh_kwargs["cmap"] = cmap
actor = pl.add_mesh(mesh, **mesh_kwargs)
colorbar_mapper = actor.mapper
# Draw parcel outlines
if border_color is not None:
outline_scalars = vertex_labels[:32492] if h == "left" else vertex_labels[32492:]
border_lines = _parcel_border_lines(mesh, outline_scalars)
pl.add_mesh(border_lines, color=border_color, line_width=border_width,
render_lines_as_tubes=True, lighting=True, show_scalar_bar=False)
# Camera position
cam_pos = list(center)
cam_pos[axis_idx[axis]] += sign * distance
pl.camera_position = [tuple(cam_pos), center, up]
# Plot the colorbar
if show_shared_colorbar and colorbar_mapper is not None:
add_scalar_bar = cast(Any, pl.add_scalar_bar)
if colorbar == "bottom":
cb_row, cb_col = panel_nrows, 0
pl.subplot(cb_row, cb_col)
n_labels = discrete_nlabels if discrete_values is not None else 4
fmt = "%.0f" if discrete_values is not None else "%.3g"
add_scalar_bar(title=colorbar_label, mapper=colorbar_mapper, vertical=False, width=0.33, height=0.7,
position_x=0.33, position_y=0.1, n_labels=n_labels, fmt=fmt, label_font_size=int(labelsize), title_font_size=int(labelsize*1.2))
else:
cb_row, cb_col = 0, panel_ncols
pl.subplot(cb_row, cb_col)
n_labels = discrete_nlabels if discrete_values is not None else 4
fmt = "%.0f" if discrete_values is not None else "%.3g"
add_scalar_bar(title=colorbar_label, mapper=colorbar_mapper, vertical=True, width=0.7, height=0.25,
position_x=0.1, position_y=0.33, n_labels=n_labels, fmt=fmt, label_font_size=int(labelsize), title_font_size=int(labelsize*1.2))
# Show static/interactive figure
if interactive:
pl.show(auto_close=False)
else:
if _in_notebook():
pl.show(jupyter_backend="static")
else:
pl.show()
# Save the figure
if fname is not None:
if fname.endswith(("svg", "pdf", "eps", "ps")):
pl.save_graphic(fname, raster=False)
elif fname.endswith(("png", "jpeg", "jpg", "bmp", "tif", "tiff")):
pl.screenshot(fname)
else:
pass
return pl.close()
[docs]
def subcortical_plot(node_values: np.ndarray | list[float] | None = None,
scale: str = "S1",
surface: str | None = "inflated",
view_names: tuple[str, ...] = ("lateral", "medial"),
ncols: int | None = None,
colwise: bool = True,
cmap: str = "viridis",
nan_color: str = "lightgray",
nan_alpha: float = 0.0,
surface_color: str = "lightgray",
surface_alpha: float = 0.10,
smooth_iter: int = 20,
smooth_relaxation: float = 0.2,
distance: float = 600.0,
size: list[int] | None = None,
labelsize: int = 18,
colorbar: None | str = "bottom",
colorbar_label: str | None = None,
interactive: bool = True,
fname: str | None = None):
"""
Plot Tian subcortical structures for scale ``S1`` to ``S4``.
If ``node_values`` contains cortex + subcortex values from ``parcellate(..., subcortical=scale)``,
the first N values are used automatically, where N is the number of Tian meshes for that scale.
The combined atlases place the subcortical parcels in the leading columns, and both the parcels
and the meshes are ordered by Tian region id, so the i-th value maps to the i-th mesh.
"""
meshes = _get_subcortical(scale=scale, smooth_iter=smooth_iter, smooth_relaxation=smooth_relaxation)
mesh_names = sorted(meshes)
n_sub = len(mesh_names)
values = None
if node_values is not None:
values = np.asarray(node_values, dtype=float)
if values.ndim != 1:
raise ValueError("node_values must be a 1D array-like or None.")
if values.size < n_sub:
raise ValueError(f"node_values must contain at least {n_sub} values for Tian {scale}.")
values = values[:n_sub]
panels = list(view_names)
base_cams = {
"lateral": ("x", +1, (0, 0, 1)),
"medial": ("x", -1, (0, 0, 1)),
"anterior": ("y", -1, (0, 0, 1)),
"posterior": ("y", +1, (0, 0, 1)),
"superior": ("z", -1, (0, 1, 0)),
"inferior": ("z", +1, (0, 1, 0))
}
unknown_views = sorted({v for v in panels if v not in base_cams})
if unknown_views:
raise ValueError(f"Unknown view(s): {unknown_views}. Available: {list(base_cams.keys())}")
n = len(panels)
if ncols is None:
ncols = min(3, n) if n <= 6 else int(math.ceil(math.sqrt(n)))
nrows = int(math.ceil(n / ncols))
if colwise:
ncols = int(math.ceil(n / nrows))
clim = None
if values is not None:
finite = values[np.isfinite(values)]
if finite.size > 0:
clim = (float(np.nanmin(finite)), float(np.nanmax(finite)))
if clim[0] == clim[1]:
eps = 1e-12 if clim[0] == 0.0 else abs(clim[0]) * 1e-12
clim = (clim[0] - eps, clim[1] + eps)
show_shared_colorbar = (values is not None) and (clim is not None) and (colorbar is not None)
panel_nrows, panel_ncols = nrows, ncols
row_weights = None
col_weights = None
groups = None
if show_shared_colorbar and colorbar == "bottom":
plot_shape = (panel_nrows + 1, panel_ncols)
groups = [([panel_nrows], list(range(panel_ncols)))]
row_weights = [1.0] * panel_nrows + [0.2]
elif show_shared_colorbar and colorbar == "right":
plot_shape = (panel_nrows, panel_ncols + 1)
groups = [(list(range(panel_nrows)), [panel_ncols])]
col_weights = [1.0] * panel_ncols + [0.2]
else:
plot_shape = (panel_nrows, panel_ncols)
context_meshes = _get_surface(surface=surface) if surface is not None else {}
pv.global_theme.font.family = "times"
pl = pv.Plotter(shape=plot_shape, window_size=size, title=f"Comet Toolbox Tian {scale} Viewer", border=False,
notebook=_in_notebook() and not interactive, off_screen=not interactive,
row_weights=row_weights, col_weights=col_weights, groups=groups)
pl.enable_anti_aliasing("msaa")
colorbar_mapper = None
axis_idx = {"x": 0, "y": 1, "z": 2}
all_centers = [np.asarray(mesh.center) for mesh in meshes.values()]
for mesh in context_meshes.values():
all_centers.append(np.asarray(mesh.center))
center = np.mean(np.vstack(all_centers), axis=0) if all_centers else np.zeros(3, dtype=float)
for i, view_name in enumerate(panels):
row, col = (i % panel_nrows, i // panel_nrows) if colwise else (i // panel_ncols, i % panel_ncols)
axis, sign, up = base_cams[view_name]
pl.subplot(row, col)
pl.add_text(view_name, font_size=int(labelsize * 0.7))
for mesh in context_meshes.values():
pl.add_mesh(mesh, color=surface_color, opacity=surface_alpha, show_scalar_bar=False,
smooth_shading=True)
for j, name in enumerate(mesh_names):
mesh = meshes[name]
if values is None:
actor = pl.add_mesh(mesh, color=surface_color, opacity=1.0, show_scalar_bar=False,
smooth_shading=True)
else:
val = float(values[j])
mesh["Data"] = np.full(mesh.n_points, val, dtype=float)
actor = pl.add_mesh(mesh, scalars="Data", cmap=cmap, clim=clim, nan_color=nan_color,
opacity=1.0 if np.isfinite(val) else nan_alpha, show_scalar_bar=False,
smooth_shading=True)
colorbar_mapper = actor.mapper
cam_pos = list(center)
cam_pos[axis_idx[axis]] += sign * distance
pl.camera_position = [tuple(cam_pos), tuple(center), up]
pl.hide_axes()
if show_shared_colorbar and colorbar_mapper is not None:
add_scalar_bar = cast(Any, pl.add_scalar_bar)
if colorbar == "bottom":
pl.subplot(panel_nrows, 0)
add_scalar_bar(title=colorbar_label, mapper=colorbar_mapper, vertical=False, width=0.33, height=0.7,
position_x=0.33, position_y=0.1, n_labels=4, fmt="%.3g",
label_font_size=int(labelsize), title_font_size=int(labelsize * 1.2))
else:
pl.subplot(0, panel_ncols)
add_scalar_bar(title=colorbar_label, mapper=colorbar_mapper, vertical=True, width=0.7, height=0.25,
position_x=0.1, position_y=0.33, n_labels=4, fmt="%.3g",
label_font_size=int(labelsize), title_font_size=int(labelsize * 1.2))
if interactive:
pl.show(auto_close=False)
else:
if _in_notebook():
pl.show(jupyter_backend="static")
else:
pl.show()
if fname is not None:
if fname.endswith(("svg", "pdf", "eps", "ps")):
pl.save_graphic(fname, raster=False)
elif fname.endswith(("png", "jpeg", "jpg", "bmp", "tif", "tiff")):
pl.screenshot(fname)
return pl.close()
def _get_surface(surface: str = "very_inflated") -> dict[str, pv.PolyData]:
"""
Download (if needed) and load fs_LR 32k cortical surfaces
from the CBIG template repository and return them as PyVista meshes.
Parameters
----------
surface : str, default="very_inflated"
Surface type. Options:
- "midthickness_orig"
- "midthickness_mni"
- "inflated"
- "very_inflated"
- "super_inflated"
- "sphere"
Returns
-------
dict
Dictionary containing the loaded hemispheres:
{"left": pv.PolyData,
"right": pv.PolyData}
Notes
-----
Surface files are downloaded from:
https://github.com/ThomasYeoLab/CBIG
Files are cached in the `comet.data.surf` resource directory.
"""
base_url = ("https://github.com/ThomasYeoLab/CBIG/raw/master/data/templates/surface/fs_LR_32k")
# CBIG filenames in that folder
lh_name = f"fsaverage.L.{surface}.32k_fs_LR.surf.gii"
rh_name = f"fsaverage.R.{surface}.32k_fs_LR.surf.gii"
lh_url = f"{base_url}/{lh_name}"
rh_url = f"{base_url}/{rh_name}"
# Store in resources
with importlib_resources.path("comet.data.surf", lh_name) as lh_path:
if not lh_path.exists():
urllib.request.urlretrieve(lh_url, lh_path)
lh_path = str(lh_path)
with importlib_resources.path("comet.data.surf", rh_name) as rh_path:
if not rh_path.exists():
urllib.request.urlretrieve(rh_url, rh_path)
rh_path = str(rh_path)
meshes = {}
# VTK requires faces stored in a single array structured as:
# [n_points, v0, v1, v2, n_points, v0, v1, v2, ...]
vertices, triangles = nib.load(lh_path).agg_data()
meshes["left"] = pv.make_tri_mesh(vertices, triangles)
vertices, triangles = nib.load(rh_path).agg_data()
meshes["right"] = pv.make_tri_mesh(vertices, triangles)
return meshes
def _get_subcortical(scale: str, smooth_iter: int = 20, smooth_relaxation: float = 0.2) -> dict[str, pv.PolyData]:
"""Download/build cached Tian subcortical meshes for one scale."""
scale = scale.upper()
if scale not in {"S1", "S2", "S3", "S4"}:
raise ValueError("scale must be one of 'S1', 'S2', 'S3', 'S4'.")
with importlib_resources.path("comet.data.surf", ".") as surf_root:
cache_root = Path(surf_root) / "subcortex"
cache_root.mkdir(parents=True, exist_ok=True)
nii_path = cache_root / f"Tian_Subcortex_{scale}_3T_2009cAsym.nii.gz"
txt_path = cache_root / f"Tian_Subcortex_{scale}_3T_label.txt"
mesh_dir = cache_root / "meshes" / scale.lower()
mesh_dir.mkdir(parents=True, exist_ok=True)
if not nii_path.exists():
url = f"https://raw.githubusercontent.com/yetianmed/subcortex/master/Group-Parcellation/3T/Subcortex-Only/{nii_path.name}"
nii_path.parent.mkdir(parents=True, exist_ok=True)
urllib.request.urlretrieve(url, nii_path)
if not txt_path.exists():
url = f"https://raw.githubusercontent.com/yetianmed/subcortex/master/Group-Parcellation/3T/Subcortex-Only/{txt_path.name}"
try:
urllib.request.urlretrieve(url, txt_path)
except Exception:
txt_path = None
mesh_files = sorted(mesh_dir.glob("*.vtk"))
if not mesh_files:
labels: dict[int, str] = {}
if txt_path is not None and txt_path.exists():
with open(txt_path, "r", encoding="utf-8", errors="ignore") as f:
for line in f:
parts = line.replace("\t", " ").split()
if len(parts) < 2:
continue
try:
labels[int(parts[0])] = parts[1]
except ValueError:
continue
img = nib.load(str(nii_path))
data = np.asarray(img.get_fdata(), dtype=np.int32)
ids = np.unique(data)
ids = ids[ids > 0]
for label_id in ids:
mask = data == label_id
if not np.any(mask):
continue
label = labels.get(int(label_id), f"Region_{int(label_id):03d}")
fname = f"{int(label_id):03d}_{label}.vtk"
grid = pv.ImageData(dimensions=np.array(mask.shape))
grid.point_data["values"] = np.ascontiguousarray(mask.astype(np.uint8)).ravel(order="F")
mesh = grid.contour(isosurfaces=[0.5], scalars="values")
mesh = mesh.triangulate().clean()
mesh.points = nib.affines.apply_affine(img.affine, mesh.points)
if smooth_iter > 0:
mesh = mesh.smooth(n_iter=smooth_iter, relaxation_factor=smooth_relaxation)
mesh.compute_normals(inplace=True)
mesh.save(mesh_dir / fname)
mesh_files = sorted(mesh_dir.glob("*.vtk"))
return {f.stem: pv.read(f) for f in mesh_files}
def _parcel_border_lines(mesh, parc_labels, offset: float = 0.10,
smooth_iters: int = 2, decimate_step: int = 2):
"""Build a line mesh representing parcel boundaries."""
parc_labels = np.asarray(parc_labels)
faces = mesh.faces.reshape(-1, 4)[:, 1:]
e0 = faces[:, [0, 1]]
e1 = faces[:, [1, 2]]
e2 = faces[:, [2, 0]]
edges = np.vstack([e0, e1, e2])
edges = np.sort(edges, axis=1)
edges = np.unique(edges, axis=0)
edge_labels = parc_labels[edges]
border_edges = edges[edge_labels[:, 0] != edge_labels[:, 1]]
if border_edges.size == 0:
return pv.PolyData()
# Prune dangling branches by repeatedly removing degree-1 endpoints.
while border_edges.size > 0:
vertex_degree = np.bincount(border_edges.ravel(), minlength=mesh.n_points)
keep = (vertex_degree[border_edges[:, 0]] > 1) & (vertex_degree[border_edges[:, 1]] > 1)
new_border_edges = border_edges[keep]
if new_border_edges.shape[0] == border_edges.shape[0]:
break
border_edges = new_border_edges
if border_edges.size == 0:
return pv.PolyData()
points = np.asarray(mesh.points).copy()
if offset != 0.0:
normals_mesh = mesh.compute_normals(point_normals=True, cell_normals=False, inplace=False)
normals = np.asarray(normals_mesh.point_data["Normals"])
points += offset * normals
polylines = _edges_to_polylines(border_edges)
out_points: list[np.ndarray] = []
out_lines: list[np.ndarray] = []
next_id = 0
for chain in polylines:
chain_idx = np.asarray(chain, dtype=np.int64)
coords = points[chain_idx]
is_closed = bool(chain[0] == chain[-1])
coords = _decimate_polyline(coords, step=decimate_step, closed=is_closed)
coords = _chaikin_smooth(coords, iterations=smooth_iters, closed=is_closed)
out_points.append(coords)
ids = np.arange(next_id, next_id + coords.shape[0], dtype=np.int64)
out_lines.append(np.concatenate(([coords.shape[0]], ids)))
next_id += coords.shape[0]
line_mesh = pv.PolyData()
line_mesh.points = np.vstack(out_points)
line_mesh.lines = np.concatenate(out_lines)
return line_mesh
def _edges_to_polylines(edges: np.ndarray) -> list[list[int]]:
"""Convert undirected edges into connected polylines/cycles."""
adj: dict[int, list[int]] = {}
for u, v in edges:
ui, vi = int(u), int(v)
adj.setdefault(ui, []).append(vi)
adj.setdefault(vi, []).append(ui)
visited: set[tuple[int, int]] = set()
polylines: list[list[int]] = []
def _edge(a: int, b: int) -> tuple[int, int]:
return (a, b) if a < b else (b, a)
starts = [node for node, neigh in adj.items() if len(neigh) != 2]
for start in starts:
for nxt in adj[start]:
e = _edge(start, nxt)
if e in visited:
continue
visited.add(e)
chain = [start, nxt]
prev, cur = start, nxt
while True:
neigh = adj[cur]
if len(neigh) != 2:
break
cand = neigh[0] if neigh[1] == prev else neigh[1]
e2 = _edge(cur, cand)
if e2 in visited:
break
visited.add(e2)
chain.append(cand)
prev, cur = cur, cand
polylines.append(chain)
for u, v in edges:
start, nxt = int(u), int(v)
e = _edge(start, nxt)
if e in visited:
continue
visited.add(e)
chain = [start, nxt]
prev, cur = start, nxt
while True:
neigh = adj[cur]
cand = neigh[0] if neigh[1] == prev else neigh[1]
e2 = _edge(cur, cand)
if e2 in visited:
if cand == start:
chain.append(cand)
break
visited.add(e2)
chain.append(cand)
prev, cur = cur, cand
polylines.append(chain)
return polylines
def _decimate_polyline(points: np.ndarray, step: int, closed: bool) -> np.ndarray:
"""Keep every Nth point in a polyline."""
n = points.shape[0]
if step <= 1 or n <= 2:
return points
if closed:
core = points[:-1] if np.allclose(points[0], points[-1]) else points
if core.shape[0] <= 3:
out = core
else:
idx = np.arange(0, core.shape[0], step, dtype=int)
if idx[-1] != core.shape[0] - 1:
idx = np.append(idx, core.shape[0] - 1)
out = core[idx]
return np.vstack([out, out[0]])
idx = np.arange(0, n, step, dtype=int)
if idx[-1] != n - 1:
idx = np.append(idx, n - 1)
return points[idx]
def _chaikin_smooth(points: np.ndarray, iterations: int, closed: bool) -> np.ndarray:
"""Smooth the polylines."""
pts = points
for _ in range(iterations):
if pts.shape[0] < 3:
break
if closed:
core = pts[:-1] if np.allclose(pts[0], pts[-1]) else pts
nxt = np.roll(core, -1, axis=0)
q = 0.75 * core + 0.25 * nxt
r = 0.25 * core + 0.75 * nxt
out = np.empty((2 * core.shape[0], 3), dtype=core.dtype)
out[0::2] = q
out[1::2] = r
pts = np.vstack([out, out[0]])
else:
out = np.empty((2 * (pts.shape[0] - 1) + 2, 3), dtype=pts.dtype)
out[0] = pts[0]
j = 1
for i in range(pts.shape[0] - 1):
p0, p1 = pts[i], pts[i + 1]
out[j] = 0.75 * p0 + 0.25 * p1
out[j + 1] = 0.25 * p0 + 0.75 * p1
j += 2
out[-1] = pts[-1]
pts = out
return pts
def _in_notebook():
"""Check if the code is running in a Jupyter notebook."""
try:
from IPython import get_ipython
if 'IPKernelApp' not in get_ipython().config:
return False
except Exception:
return False
return True