State-based dFC for continuously varying measures
The previous tutorials showed how to (1) estimate continuously varying dFC measures and (2) work with inherently state-based measures. Comet also lets users derive states from continuous measures by clustering them.
We can simply use the same example as before to derive state measures for a single subject:
[7]:
from matplotlib import pyplot as plt
from nilearn import datasets
from comet import connectivity, utils
# Preprocessed time series data from the ABIDE dataset
subject = 50010
data = datasets.fetch_abide_pcp(SUB_ID=subject, pipeline='cpac', band_pass_filtering=True, derivatives="rois_dosenbach160")
ts = data.rois_dosenbach160[0]
sw = connectivity.SlidingWindow(ts)
dfc_sw = sw.estimate()
[fetch_abide_pcp] Dataset found in /home/mibur/nilearn_data/ABIDE_pcp
Next, we can perform clustering analysis and extract some summary measures:
[2]:
state_tc, states, inertia = utils.kmeans_cluster(dfc_sw, num_states=5, random_state=0)
summary = utils.summarise_state_tc(state_tc)
fig1, ax1 = utils.state_plots(states=states, figsize=(8,2))
fig2, ax2 = utils.state_plots(state_tc=state_tc, figsize=(8,2))
fig3, ax3 = utils.state_plots(summary=summary, figsize=(8,3.5))
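The clustering step can be thought of as treating each window's connectivity matrix as a feature vector and grouping similar windows with k-means. The sketch below uses only NumPy and is an illustration under stated assumptions, not Comet's implementation. It assumes a dFC array of shape (n_regions, n_regions, n_windows) and keeps only the upper triangle, since the matrices are symmetric. It also replaces whatever initialisation kmeans_cluster uses with a plain Lloyd's k-means that starts from a deterministic farthest-first initialisation. All function names (vectorize_dfc, nearest, kmeans, centroids_to_matrices) are hypothetical.

```python
import numpy as np

def vectorize_dfc(dfc):
    """Turn an assumed (n_regions, n_regions, n_windows) dFC array into a
    (n_windows, n_edges) feature matrix of unique upper-triangle edges."""
    rows, cols = np.triu_indices(dfc.shape[0], k=1)
    return dfc[rows, cols, :].T

def nearest(features, centroids):
    """Index of the closest centroid (squared Euclidean distance) for each row."""
    d2 = ((features[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return d2.argmin(axis=1)

def kmeans(features, n_states, n_iter=100):
    """Lloyd's k-means with a deterministic farthest-first initialisation."""
    centroids = [features[0]]
    for _ in range(1, n_states):
        d2 = np.min([((features - c) ** 2).sum(axis=1) for c in centroids], axis=0)
        centroids.append(features[d2.argmax()])
    centroids = np.array(centroids)
    for _ in range(n_iter):
        labels = nearest(features, centroids)
        updated = np.array([features[labels == k].mean(axis=0) if np.any(labels == k)
                            else centroids[k] for k in range(n_states)])
        if np.allclose(updated, centroids):
            break
        centroids = updated
    labels = nearest(features, centroids)
    inertia = ((features - centroids[labels]) ** 2).sum()  # within-cluster sum of squares
    return labels, centroids, inertia

def centroids_to_matrices(centroids, n_regions):
    """Rebuild symmetric (n_states, n_regions, n_regions) state matrices."""
    rows, cols = np.triu_indices(n_regions, k=1)
    mats = np.zeros((len(centroids), n_regions, n_regions))
    mats[:, rows, cols] = centroids
    mats[:, cols, rows] = centroids
    return mats

# Synthetic demo: two connectivity patterns alternating in blocks of 10 windows
rng = np.random.default_rng(0)
n_regions = 6
pattern = rng.uniform(-1, 1, (n_regions, n_regions))
pattern = (pattern + pattern.T) / 2
true_tc = np.repeat([0, 1, 0, 1], 10)
dfc = np.stack([(pattern if s == 0 else -pattern)
                + 0.05 * rng.standard_normal((n_regions, n_regions)) for s in true_tc], axis=-1)

state_tc, centroids, inertia = kmeans(vectorize_dfc(dfc), n_states=2)
states = centroids_to_matrices(centroids, n_regions)
```

On this synthetic input, state_tc recovers the block structure. states holds the centroids reshaped back into connectivity matrices, which is analogous to the states returned by kmeans_cluster above.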
More commonly, state dynamics are estimated across multiple subjects. To do this, we estimate dFC for each subject and store the estimates in a list before performing the clustering analysis:
[3]:
# Get data from 5 subjects
subjects = ["50008", "50010", "50012", "50014", "50020"]
data = datasets.fetch_abide_pcp(SUB_ID=subjects, pipeline='cpac', band_pass_filtering=True, derivatives="rois_dosenbach160")
ts = data.rois_dosenbach160 # list of 2D time series data
print("Number of subjects:",len(ts))
print("Time series shape:", ts[0].shape)
# Estimate dFC for all subjects and store as a list
dfc_list = []
for sub_ts in ts:
    dfc = connectivity.LeiDA(sub_ts).estimate()
    dfc_list.append(dfc)
[fetch_abide_pcp] Dataset found in /home/mibur/nilearn_data/ABIDE_pcp
Number of subjects: 5
Time series shape: (196, 161)
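One common way to cluster several subjects together is to pool every subject's windows into a single feature matrix and cluster once, so that all subjects share the same set of states. The labels are then split back into one state time course per subject. We assume, without having checked, that this is roughly what strategy="pooled" refers to in the next cell. The sketch below shows only the bookkeeping. A hypothetical sign_of_mean function stands in for k-means, and each subject's dFC is assumed to be already flattened into an (n_windows, n_features) array.

```python
import numpy as np

def pooled_clustering(feature_list, cluster_fn):
    """Cluster all subjects' windows together, then split the labels per subject.

    feature_list: one (n_windows_i, n_features) array per subject.
    cluster_fn: any function mapping a 2D feature array to one label per row.
    """
    lengths = [features.shape[0] for features in feature_list]
    pooled = np.concatenate(feature_list, axis=0)  # (sum of n_windows_i, n_features)
    labels = np.asarray(cluster_fn(pooled))
    # Split points are the cumulative window counts, excluding the final total
    return np.split(labels, np.cumsum(lengths)[:-1])

def sign_of_mean(features):
    """Hypothetical stand-in for k-means: state 1 if the mean feature is positive."""
    return (features.mean(axis=1) > 0).astype(int)

# Three synthetic subjects with different numbers of windows
rng = np.random.default_rng(1)
feature_list = [rng.standard_normal((n, 10)) + offset
                for n, offset in [(30, 2.0), (25, -2.0), (40, 2.0)]]
state_tcs = pooled_clustering(feature_list, sign_of_mean)
print([tc.shape[0] for tc in state_tcs])  # [30, 25, 40]
```

Subjects may contribute different numbers of windows. Because clustering happens on the pooled matrix, subjects with more windows carry more weight in the resulting states.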
We can then estimate the state dynamics. The summarise_state_tc function computes popular summary metrics, and state_plots visualises the states, state time courses, and summary metrics:
[4]:
state_tc, states, inertia = utils.kmeans_cluster(dfc_list, strategy="pooled")
summary = utils.summarise_state_tc(state_tc)
print(f"Available summary metrics: {summary.keys()}\n")
print("Average transition matrix:")
print(summary["transitions"].mean(axis=0))
# Plot results
fig1, ax1 = utils.state_plots(states=states, figsize=(8,2))
fig2, ax2 = utils.state_plots(state_tc=state_tc, figsize=(8,8))
fig3, ax3 = utils.state_plots(summary=summary, figsize=(8,3.5))
Available summary metrics: dict_keys(['dwell_times', 'fractional_occupancy', 'transitions', 'transition_counts', 'transitions_sum', 'switch_rate'])
Average transition matrix:
[[0.81064516 0. 0. 0.15645161 0.03290323]
[0. 0.78292848 0.01295547 0.15167341 0.05244265]
[0. 0.02877196 0.70434969 0.01107716 0.05580118]
[0.01536338 0.04860802 0.0039604 0.83489767 0.09717054]
[0.01222989 0.00956322 0.08330327 0.13062644 0.76427719]]
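The metric names above can be made concrete with a small sketch that computes common definitions from a single subject's integer state time course. The definitions here are assumptions chosen for illustration, and Comet's summarise_state_tc may differ in details such as units (windows vs. seconds) or normalisation. The helper summarise_states is hypothetical, and transitions_sum is not reproduced.

```python
import numpy as np

def summarise_states(state_tc, n_states):
    """Common state metrics for one subject's integer state time course."""
    state_tc = np.asarray(state_tc)
    n = len(state_tc)

    # Fractional occupancy: share of windows spent in each state
    occupancy = np.bincount(state_tc, minlength=n_states) / n

    # Dwell time: mean length (in windows) of uninterrupted runs of each state
    run_starts = np.concatenate(([0], np.flatnonzero(np.diff(state_tc)) + 1))
    run_lengths = np.diff(np.append(run_starts, n))
    run_states = state_tc[run_starts]
    dwell = np.array([run_lengths[run_states == k].mean() if np.any(run_states == k)
                      else 0.0 for k in range(n_states)])

    # Transition counts between consecutive windows, row-normalised to probabilities;
    # a state with no outgoing transitions (never visited, or only in the last
    # window) keeps an all-zero row
    counts = np.zeros((n_states, n_states), dtype=int)
    np.add.at(counts, (state_tc[:-1], state_tc[1:]), 1)
    row_sums = counts.sum(axis=1, keepdims=True)
    probs = np.divide(counts, row_sums, out=np.zeros((n_states, n_states)),
                      where=row_sums > 0)

    # Switch rate: fraction of consecutive window pairs in which the state changes
    switch_rate = np.count_nonzero(np.diff(state_tc)) / (n - 1)

    return {"fractional_occupancy": occupancy, "dwell_times": dwell,
            "transition_counts": counts, "transitions": probs,
            "switch_rate": switch_rate}

metrics = summarise_states([0, 0, 1, 1, 1, 0, 2, 2], n_states=3)
print(metrics["dwell_times"])  # [1.5 3.  2. ]
```

Under this convention, a subject who never leaves a given state contributes an all-zero row for it. Averaging over subjects can therefore give rows that sum to less than one. In the average printed above, the third row sums to about 0.8, which would be consistent with one of the five subjects contributing a zero row.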
The attentive reader might have noticed that, besides the regular SlidingWindow class, Comet also provides a SlidingWindowClustering class for the state-based method. In practice, both yield equivalent results when SlidingWindow is combined with the two-level clustering strategy (strategy="two_level") implemented in kmeans_cluster:
[5]:
# SlidingWindow + kmeans_cluster
dfc_list = []
for ts_i in ts:
    dfc_sw = connectivity.SlidingWindow(ts_i, windowsize=29, stepsize=1, shape="gaussian", diagonal=1).estimate()
    dfc_list.append(dfc_sw)
state_tc, states, _ = utils.kmeans_cluster(dfc_list, num_states=5, subject_clusters=5, strategy="two_level", random_state=42)
# SlidingWindowClustering
state_tc_swc, _ = connectivity.SlidingWindowClustering(ts, n_states=5, subject_clusters=5, windowsize=29, stepsize=1, random_state=42).estimate()
Sliding Window Clustering: 100%|██████████| 5/5 [00:31<00:00, 6.21s/it]
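The two-level strategy mentioned above can be sketched in three steps. First, each subject's windows are clustered into subject_clusters subject-level centroids. Next, all subjects' centroids are pooled and clustered into n_states group states. Finally, every window is assigned to its nearest group state. This NumPy-only sketch illustrates the idea under several assumptions and is not Comet's code. It clusters every window at the subject level, uses plain k-means with a deterministic farthest-first initialisation, and assumes each subject's dFC has already been flattened to (n_windows, n_edges). Comet's implementation may differ, for example in how windows are selected or how clusters are initialised. The function names are hypothetical.

```python
import numpy as np

def nearest(features, centroids):
    """Index of the closest centroid (squared Euclidean distance) for each row."""
    return ((features[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2).argmin(axis=1)

def kmeans(features, k, n_iter=100):
    """Lloyd's k-means with deterministic farthest-first initialisation; returns centroids."""
    centroids = [features[0]]
    for _ in range(1, k):
        d2 = np.min([((features - c) ** 2).sum(axis=1) for c in centroids], axis=0)
        centroids.append(features[d2.argmax()])
    centroids = np.array(centroids)
    for _ in range(n_iter):
        labels = nearest(features, centroids)
        updated = np.array([features[labels == j].mean(axis=0) if np.any(labels == j)
                            else centroids[j] for j in range(k)])
        if np.allclose(updated, centroids):
            break
        centroids = updated
    return centroids

def two_level_clustering(feature_list, n_states, subject_clusters):
    # Level 1: cluster each subject's windows separately
    subject_centroids = [kmeans(f, subject_clusters) for f in feature_list]
    # Level 2: cluster the pooled subject-level centroids into group states
    group_centroids = kmeans(np.concatenate(subject_centroids, axis=0), n_states)
    # Back-projection: label every window with its nearest group state
    return [nearest(f, group_centroids) for f in feature_list], group_centroids

# Synthetic demo: 3 subjects switching between two edge patterns (+a and -a)
rng = np.random.default_rng(2)
a = rng.uniform(-1, 1, 15)
true_tcs = [np.repeat([0, 1, 0], 10), np.repeat([1, 0], 15), np.repeat([0, 1, 1, 0], 8)]
feature_list = [np.array([(a if s == 0 else -a) + 0.05 * rng.standard_normal(15) for s in tc])
                for tc in true_tcs]
state_tcs, group_states = two_level_clustering(feature_list, n_states=2, subject_clusters=2)
```

In this sketch, each subject contributes exactly subject_clusters centroids to the group-level clustering, whatever its number of windows. Subjects with long recordings therefore carry less weight than they would under pooled clustering.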
[6]:
fig1, ax1 = utils.state_plots(states=states, figsize=(8,2))
sub_idx = 2
fig, ax = plt.subplots()
ax.plot(state_tc[sub_idx], label="SlidingWindow()", lw=2)
ax.plot(state_tc_swc[sub_idx], label="SlidingWindowClustering()", ls="--", lw=2)
ax.set(xlabel="Time windows", ylabel="State")
plt.legend(loc="lower left");
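k-means labels are arbitrary, so two methods can find the same partition under different state numbers, and an overlay plot can then look different even when the results agree. A simple way to quantify agreement is to try every relabelling of one time course and report the best match. best_match_agreement below is a hypothetical helper, not part of Comet. Brute force is cheap for five states (120 relabellings). For many states, scipy.optimize.linear_sum_assignment applied to the label-overlap matrix finds the best mapping more efficiently.

```python
from itertools import permutations

import numpy as np

def best_match_agreement(tc_a, tc_b, n_states):
    """Highest fraction of windows on which tc_a and a relabelled tc_b agree.

    mapping[old_label] is the tc_a label that tc_b's old_label is mapped to.
    """
    tc_a, tc_b = np.asarray(tc_a), np.asarray(tc_b)
    best_score, best_mapping = -1.0, None
    for mapping in permutations(range(n_states)):
        score = np.mean(np.array(mapping)[tc_b] == tc_a)
        if score > best_score:
            best_score, best_mapping = score, mapping
    return best_score, best_mapping

score, mapping = best_match_agreement([0, 0, 1, 1, 2, 2], [1, 1, 2, 2, 0, 0], n_states=3)
print(score, mapping)  # 1.0 (2, 0, 1)
```

With the notebook's variables, this could be called as `best_match_agreement(state_tc[sub_idx], state_tc_swc[sub_idx], n_states=5)`, assuming both hold integer label arrays of equal length for each subject.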