Skip to content

Commit

Permalink
try this
Browse files Browse the repository at this point in the history
  • Loading branch information
tsalo committed Oct 29, 2024
1 parent 2a6649d commit 3f72d7e
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 62 deletions.
66 changes: 63 additions & 3 deletions xcp_d/interfaces/connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,16 @@ class _TSVConnectInputSpec(BaseInterfaceInputSpec):
mandatory=False,
desc="Temporal mask, after dummy scan removal.",
)
flatten = traits.Bool(
False,
usedefault=True,
desc="Flatten the correlation matrix to a TSV file.",
)


class _TSVConnectOutputSpec(TraitedSpec):
correlations = File(exists=True, desc="Correlation matrix file.")
correlations_square = File(exists=True, desc="Square correlation matrix file.")
correlations_exact = traits.Either(
None,
traits.List(File(exists=True)),
Expand Down Expand Up @@ -256,6 +262,16 @@ def correlate_timeseries(timeseries, temporal_mask):
return correlations_df, correlations_exact


def flatten_conmat(df):
    """Flatten a square correlation matrix into a single-row DataFrame.

    Parameters
    ----------
    df : pandas.DataFrame
        Square correlation matrix with node names as both index and columns.

    Returns
    -------
    pandas.DataFrame
        A one-row DataFrame (row label "Value") whose columns are
        "Row-Column" edge names covering the upper triangle of the matrix,
        including the diagonal.
    """
    # Mask out the lower triangle (redundant in a symmetric correlation
    # matrix); np.triu keeps the diagonal as well.
    df = df.where(np.triu(np.ones(df.shape[0])).astype(bool))
    # Long format: one row per retained (Row, Column) pair.
    # stack() drops the NaN-masked lower-triangle entries.
    df = df.stack().reset_index()
    df.columns = ["Row", "Column", "Value"]
    df["Edge"] = df["Row"] + "-" + df["Column"]
    df = df.set_index("Edge")
    # BUG FIX: set_index() removes "Edge" as a column (drop=True by default),
    # so the previous df[["Edge", "Value"]] raised a KeyError. Select only
    # "Value" and transpose so edges become the columns of a single row.
    df = df[["Value"]].T
    return df


class TSVConnect(SimpleInterface):
"""Extract timeseries and compute connectivity matrices.
Expand All @@ -273,6 +289,23 @@ def _run_interface(self, runtime):
temporal_mask=self.inputs.temporal_mask,
)

self._results["correlations_square"] = fname_presuffix(
"correlations_square.tsv",
newpath=runtime.cwd,
use_ext=True,
)
correlations_df.to_csv(
self._results["correlations"],
sep="\t",
na_rep="n/a",
index_label="Node",
)
if self.inputs.flatten:
correlations_df = flatten_conmat(correlations_df)
kwargs = {"index": False}
else:
kwargs = {"index_label": "Node"}

self._results["correlations"] = fname_presuffix(
"correlations.tsv",
newpath=runtime.cwd,
Expand All @@ -282,7 +315,7 @@ def _run_interface(self, runtime):
self._results["correlations"],
sep="\t",
na_rep="n/a",
index_label="Node",
**kwargs,
)
del correlations_df
gc.collect()
Expand All @@ -298,11 +331,14 @@ def _run_interface(self, runtime):
newpath=runtime.cwd,
use_ext=True,
)
if self.inputs.flatten:
exact_correlations_df = flatten_conmat(exact_correlations_df)

exact_correlations_df.to_csv(
exact_correlations_file,
sep="\t",
na_rep="n/a",
index_label="Node",
**kwargs,
)
self._results["correlations_exact"].append(exact_correlations_file)

Expand Down Expand Up @@ -539,10 +575,16 @@ class _CiftiToTSVInputSpec(BaseInterfaceInputSpec):
desc="Parcellated CIFTI file to extract into a TSV.",
)
atlas_labels = File(exists=True, mandatory=True, desc="atlas labels file")
flatten = traits.Bool(
False,
usedefault=True,
desc="Flatten the correlation matrix to a TSV file.",
)


class _CiftiToTSVOutputSpec(TraitedSpec):
    """Output specification for the CiftiToTSV interface."""

    # Parcellated data written out as a tab-separated file.
    out_file = File(exists=True, desc="Parcellated data TSV file.")
    # Square (node-by-node) correlation matrix TSV. Existence is not
    # enforced because it is presumably only produced for .pconn.nii
    # inputs — confirm against the interface's _run_interface.
    correlations_square = File(desc="Square correlation matrix TSV file.")


class CiftiToTSV(SimpleInterface):
Expand Down Expand Up @@ -657,7 +699,25 @@ def _run_interface(self, runtime):
)

if in_file.endswith(".pconn.nii"):
df.to_csv(self._results["out_file"], sep="\t", na_rep="n/a", index_label="Node")
self._results["correlations_square"] = fname_presuffix(
"correlations_square.tsv",
newpath=runtime.cwd,
use_ext=True,
)
df.to_csv(
self._results["correlations_square"],
sep="\t",
na_rep="n/a",
index_label="Node",
)

if self.inputs.flatten:
df = flatten_conmat(df)
kwargs = {"index": False}
else:
kwargs = {"index_label": "Node"}

df.to_csv(self._results["out_file"], sep="\t", na_rep="n/a", **kwargs)
else:
df.to_csv(self._results["out_file"], sep="\t", na_rep="n/a", index=False)

Expand Down
72 changes: 13 additions & 59 deletions xcp_d/workflows/bold/connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from xcp_d import config
from xcp_d.interfaces.bids import DerivativesDataSink
from xcp_d.interfaces.connectivity import FlattenTSV
from xcp_d.utils.atlas import select_atlases
from xcp_d.utils.doc import fill_doc
from xcp_d.workflows.parcellation import init_parcellate_cifti_wf
Expand Down Expand Up @@ -133,14 +132,18 @@ def init_functional_connectivity_nifti_wf(mem_gb, name="connectivity_wf"):

if config.workflow.output_correlations:
functional_connectivity = pe.MapNode(
TSVConnect(),
TSVConnect(flatten=config.workflow.flatten_conmats),
name="functional_connectivity",
iterfield=["timeseries"],
mem_gb=mem_gb["timeseries"],
)
workflow.connect([
(inputnode, functional_connectivity, [("temporal_mask", "temporal_mask")]),
(parcellate_data, functional_connectivity, [("timeseries", "timeseries")]),
(functional_connectivity, outputnode, [
("correlations", "correlations"),
("correlations_exact", "correlations_exact"),
]),
]) # fmt:skip

connectivity_plot = pe.Node(
Expand All @@ -153,7 +156,9 @@ def init_functional_connectivity_nifti_wf(mem_gb, name="connectivity_wf"):
("atlases", "atlases"),
("atlas_labels_files", "atlas_tsvs"),
]),
(functional_connectivity, connectivity_plot, [("correlations", "correlations_tsv")]),
(functional_connectivity, connectivity_plot, [
("correlations_square", "correlations_tsv"),
]),
]) # fmt:skip

ds_report_connectivity_plot = pe.Node(
Expand All @@ -168,27 +173,6 @@ def init_functional_connectivity_nifti_wf(mem_gb, name="connectivity_wf"):
(connectivity_plot, ds_report_connectivity_plot, [("connectplot", "in_file")]),
]) # fmt:skip

if config.workflow.flatten_conmats:
flatten_conmats = pe.MapNode(
FlattenTSV(kind="conmat"),
name="flatten_conmats",
iterfield=["in_file", "in_file_exact"],
)
workflow.connect([
(functional_connectivity, flatten_conmats, [
("correlations", "in_file"),
("correlations_exact", "in_file_exact"),
]),
(flatten_conmats, outputnode, [("out_file", "correlations")]),
]) # fmt:skip
else:
workflow.connect([
(functional_connectivity, outputnode, [
("correlations", "correlations"),
("correlations_exact", "correlations_exact"),
]),
]) # fmt:skip

parcellate_reho = pe.MapNode(
NiftiParcellate(min_coverage=min_coverage),
name="parcellate_reho",
Expand Down Expand Up @@ -438,13 +422,14 @@ def init_functional_connectivity_cifti_wf(mem_gb, exact_scans, name="connectivit

# Convert correlation pconn file to TSV
dconn_to_tsv = pe.MapNode(
CiftiToTSV(),
CiftiToTSV(flatten=config.workflow.flatten_conmats),
name="dconn_to_tsv",
iterfield=["in_file", "atlas_labels"],
)
workflow.connect([
(inputnode, dconn_to_tsv, [("atlas_labels_files", "atlas_labels")]),
(correlate_bold, dconn_to_tsv, [("out_file", "in_file")]),
(dconn_to_tsv, outputnode, [("out_file", "correlations")]),
]) # fmt:skip

# Plot up to four connectivity matrices
Expand All @@ -458,7 +443,7 @@ def init_functional_connectivity_cifti_wf(mem_gb, exact_scans, name="connectivit
("atlases", "atlases"),
("atlas_labels_files", "atlas_tsvs"),
]),
(dconn_to_tsv, connectivity_plot, [("out_file", "correlations_tsv")]),
(dconn_to_tsv, connectivity_plot, [("correlations_square", "correlations_tsv")]),
]) # fmt:skip

ds_report_connectivity = pe.Node(
Expand All @@ -474,19 +459,6 @@ def init_functional_connectivity_cifti_wf(mem_gb, exact_scans, name="connectivit
(connectivity_plot, ds_report_connectivity, [("connectplot", "in_file")]),
]) # fmt:skip

if config.workflow.flatten_conmats:
flatten_conmats = pe.MapNode(
FlattenTSV(kind="conmat"),
name="flatten_conmats",
iterfield=["in_file"],
)
workflow.connect([
(dconn_to_tsv, flatten_conmats, [("out_file", "in_file")]),
(flatten_conmats, outputnode, [("out_file", "correlations")]),
]) # fmt:skip
else:
workflow.connect([(dconn_to_tsv, outputnode, [("out_file", "correlations")])])

# Perform exact-time correlations
if exact_scans:
collect_exact_ciftis = pe.Node(
Expand Down Expand Up @@ -528,34 +500,16 @@ def init_functional_connectivity_cifti_wf(mem_gb, exact_scans, name="connectivit

# Convert correlation pconn file to TSV
exact_dconn_to_tsv = pe.MapNode(
CiftiToTSV(),
CiftiToTSV(flatten=config.workflow.flatten_conmats),
name=f"dconn_to_tsv_{exact_scan}volumes",
iterfield=["in_file", "atlas_labels"],
)
workflow.connect([
(inputnode, exact_dconn_to_tsv, [("atlas_labels_files", "atlas_labels")]),
(correlate_exact_bold, exact_dconn_to_tsv, [("out_file", "in_file")]),
(exact_dconn_to_tsv, collect_exact_tsvs, [("out_file", f"in{i_exact_scan + 1}")]),
]) # fmt:skip

if config.workflow.flatten_conmats:
flatten_conmats_exact = pe.MapNode(
FlattenTSV(kind="conmat"),
name=f"flatten_conmats_{exact_scan}volumes",
iterfield=["in_file"],
)
workflow.connect([
(exact_dconn_to_tsv, flatten_conmats_exact, [("out_file", "in_file")]),
(flatten_conmats_exact, collect_exact_tsvs, [
("out_file", f"in{i_exact_scan + 1}"),
]),
]) # fmt:skip
else:
workflow.connect([
(exact_dconn_to_tsv, collect_exact_tsvs, [
("out_file", f"in{i_exact_scan + 1}"),
]),
]) # fmt:skip

parcellate_reho_wf = init_parcellate_cifti_wf(
mem_gb=mem_gb,
compute_mask=False,
Expand Down

0 comments on commit 3f72d7e

Please sign in to comment.