From 3f72d7e00b32bfc79d6338890ea4cecd382ced25 Mon Sep 17 00:00:00 2001 From: Taylor Salo Date: Tue, 29 Oct 2024 14:12:16 -0400 Subject: [PATCH] try this --- xcp_d/interfaces/connectivity.py | 66 +++++++++++++++++++++++-- xcp_d/workflows/bold/connectivity.py | 72 +++++----------------------- 2 files changed, 76 insertions(+), 62 deletions(-) diff --git a/xcp_d/interfaces/connectivity.py b/xcp_d/interfaces/connectivity.py index 15f0a3bdd..3596b1b17 100644 --- a/xcp_d/interfaces/connectivity.py +++ b/xcp_d/interfaces/connectivity.py @@ -217,10 +217,16 @@ class _TSVConnectInputSpec(BaseInterfaceInputSpec): mandatory=False, desc="Temporal mask, after dummy scan removal.", ) + flatten = traits.Bool( + False, + usedefault=True, + desc="Flatten the correlation matrix to a TSV file.", + ) class _TSVConnectOutputSpec(TraitedSpec): correlations = File(exists=True, desc="Correlation matrix file.") + correlations_square = File(exists=True, desc="Square correlation matrix file.") correlations_exact = traits.Either( None, traits.List(File(exists=True)), @@ -256,6 +262,16 @@ def correlate_timeseries(timeseries, temporal_mask): return correlations_df, correlations_exact +def flatten_conmat(df): + df = df.where(np.triu(np.ones(df.shape[0])).astype(bool)) + df = df.stack().reset_index() + df.columns = ["Row", "Column", "Value"] + df["Edge"] = df["Row"] + "-" + df["Column"] + df = df.set_index("Edge") + df = df[["Value"]].T + return df + + class TSVConnect(SimpleInterface): """Extract timeseries and compute connectivity matrices. 
@@ -273,6 +289,23 @@ def _run_interface(self, runtime): temporal_mask=self.inputs.temporal_mask, ) + self._results["correlations_square"] = fname_presuffix( + "correlations_square.tsv", + newpath=runtime.cwd, + use_ext=True, + ) + correlations_df.to_csv( + self._results["correlations_square"], + sep="\t", + na_rep="n/a", + index_label="Node", + ) + if self.inputs.flatten: + correlations_df = flatten_conmat(correlations_df) + kwargs = {"index": False} + else: + kwargs = {"index_label": "Node"} + self._results["correlations"] = fname_presuffix( "correlations.tsv", newpath=runtime.cwd, @@ -282,7 +315,7 @@ def _run_interface(self, runtime): self._results["correlations"], sep="\t", na_rep="n/a", - index_label="Node", + **kwargs, ) del correlations_df gc.collect() @@ -298,11 +331,14 @@ def _run_interface(self, runtime): newpath=runtime.cwd, use_ext=True, ) + if self.inputs.flatten: + exact_correlations_df = flatten_conmat(exact_correlations_df) + exact_correlations_df.to_csv( exact_correlations_file, sep="\t", na_rep="n/a", - index_label="Node", + **kwargs, ) self._results["correlations_exact"].append(exact_correlations_file) @@ -539,10 +575,16 @@ class _CiftiToTSVInputSpec(BaseInterfaceInputSpec): desc="Parcellated CIFTI file to extract into a TSV.", ) atlas_labels = File(exists=True, mandatory=True, desc="atlas labels file") + flatten = traits.Bool( + False, + usedefault=True, + desc="Flatten the correlation matrix to a TSV file.", + ) class _CiftiToTSVOutputSpec(TraitedSpec): out_file = File(exists=True, desc="Parcellated data TSV file.") + correlations_square = File(desc="Square correlation matrix TSV file.") class CiftiToTSV(SimpleInterface): @@ -657,7 +699,25 @@ def _run_interface(self, runtime): ) if in_file.endswith(".pconn.nii"): - df.to_csv(self._results["out_file"], sep="\t", na_rep="n/a", index_label="Node") + self._results["correlations_square"] = fname_presuffix( + "correlations_square.tsv", + newpath=runtime.cwd, + use_ext=True, + ) + df.to_csv( + 
self._results["correlations_square"], + sep="\t", + na_rep="n/a", + index_label="Node", + ) + + if self.inputs.flatten: + df = flatten_conmat(df) + kwargs = {"index": False} + else: + kwargs = {"index_label": "Node"} + + df.to_csv(self._results["out_file"], sep="\t", na_rep="n/a", **kwargs) else: df.to_csv(self._results["out_file"], sep="\t", na_rep="n/a", index=False) diff --git a/xcp_d/workflows/bold/connectivity.py b/xcp_d/workflows/bold/connectivity.py index e783a3837..b267208fa 100644 --- a/xcp_d/workflows/bold/connectivity.py +++ b/xcp_d/workflows/bold/connectivity.py @@ -9,7 +9,6 @@ from xcp_d import config from xcp_d.interfaces.bids import DerivativesDataSink -from xcp_d.interfaces.connectivity import FlattenTSV from xcp_d.utils.atlas import select_atlases from xcp_d.utils.doc import fill_doc from xcp_d.workflows.parcellation import init_parcellate_cifti_wf @@ -133,7 +132,7 @@ def init_functional_connectivity_nifti_wf(mem_gb, name="connectivity_wf"): if config.workflow.output_correlations: functional_connectivity = pe.MapNode( - TSVConnect(), + TSVConnect(flatten=config.workflow.flatten_conmats), name="functional_connectivity", iterfield=["timeseries"], mem_gb=mem_gb["timeseries"], @@ -141,6 +140,10 @@ def init_functional_connectivity_nifti_wf(mem_gb, name="connectivity_wf"): workflow.connect([ (inputnode, functional_connectivity, [("temporal_mask", "temporal_mask")]), (parcellate_data, functional_connectivity, [("timeseries", "timeseries")]), + (functional_connectivity, outputnode, [ + ("correlations", "correlations"), + ("correlations_exact", "correlations_exact"), + ]), ]) # fmt:skip connectivity_plot = pe.Node( @@ -153,7 +156,9 @@ def init_functional_connectivity_nifti_wf(mem_gb, name="connectivity_wf"): ("atlases", "atlases"), ("atlas_labels_files", "atlas_tsvs"), ]), - (functional_connectivity, connectivity_plot, [("correlations", "correlations_tsv")]), + (functional_connectivity, connectivity_plot, [ + ("correlations_square", "correlations_tsv"), + 
]), ]) # fmt:skip ds_report_connectivity_plot = pe.Node( @@ -168,27 +173,6 @@ def init_functional_connectivity_nifti_wf(mem_gb, name="connectivity_wf"): (connectivity_plot, ds_report_connectivity_plot, [("connectplot", "in_file")]), ]) # fmt:skip - if config.workflow.flatten_conmats: - flatten_conmats = pe.MapNode( - FlattenTSV(kind="conmat"), - name="flatten_conmats", - iterfield=["in_file", "in_file_exact"], - ) - workflow.connect([ - (functional_connectivity, flatten_conmats, [ - ("correlations", "in_file"), - ("correlations_exact", "in_file_exact"), - ]), - (flatten_conmats, outputnode, [("out_file", "correlations")]), - ]) # fmt:skip - else: - workflow.connect([ - (functional_connectivity, outputnode, [ - ("correlations", "correlations"), - ("correlations_exact", "correlations_exact"), - ]), - ]) # fmt:skip - parcellate_reho = pe.MapNode( NiftiParcellate(min_coverage=min_coverage), name="parcellate_reho", @@ -438,13 +422,14 @@ def init_functional_connectivity_cifti_wf(mem_gb, exact_scans, name="connectivit # Convert correlation pconn file to TSV dconn_to_tsv = pe.MapNode( - CiftiToTSV(), + CiftiToTSV(flatten=config.workflow.flatten_conmats), name="dconn_to_tsv", iterfield=["in_file", "atlas_labels"], ) workflow.connect([ (inputnode, dconn_to_tsv, [("atlas_labels_files", "atlas_labels")]), (correlate_bold, dconn_to_tsv, [("out_file", "in_file")]), + (dconn_to_tsv, outputnode, [("out_file", "correlations")]), ]) # fmt:skip # Plot up to four connectivity matrices @@ -458,7 +443,7 @@ def init_functional_connectivity_cifti_wf(mem_gb, exact_scans, name="connectivit ("atlases", "atlases"), ("atlas_labels_files", "atlas_tsvs"), ]), - (dconn_to_tsv, connectivity_plot, [("out_file", "correlations_tsv")]), + (dconn_to_tsv, connectivity_plot, [("correlations_square", "correlations_tsv")]), ]) # fmt:skip ds_report_connectivity = pe.Node( @@ -474,19 +459,6 @@ def init_functional_connectivity_cifti_wf(mem_gb, exact_scans, name="connectivit (connectivity_plot, 
ds_report_connectivity, [("connectplot", "in_file")]), ]) # fmt:skip - if config.workflow.flatten_conmats: - flatten_conmats = pe.MapNode( - FlattenTSV(kind="conmat"), - name="flatten_conmats", - iterfield=["in_file"], - ) - workflow.connect([ - (dconn_to_tsv, flatten_conmats, [("out_file", "in_file")]), - (flatten_conmats, outputnode, [("out_file", "correlations")]), - ]) # fmt:skip - else: - workflow.connect([(dconn_to_tsv, outputnode, [("out_file", "correlations")])]) - # Perform exact-time correlations if exact_scans: collect_exact_ciftis = pe.Node( @@ -528,34 +500,16 @@ def init_functional_connectivity_cifti_wf(mem_gb, exact_scans, name="connectivit # Convert correlation pconn file to TSV exact_dconn_to_tsv = pe.MapNode( - CiftiToTSV(), + CiftiToTSV(flatten=config.workflow.flatten_conmats), name=f"dconn_to_tsv_{exact_scan}volumes", iterfield=["in_file", "atlas_labels"], ) workflow.connect([ (inputnode, exact_dconn_to_tsv, [("atlas_labels_files", "atlas_labels")]), (correlate_exact_bold, exact_dconn_to_tsv, [("out_file", "in_file")]), + (exact_dconn_to_tsv, collect_exact_tsvs, [("out_file", f"in{i_exact_scan + 1}")]), ]) # fmt:skip - if config.workflow.flatten_conmats: - flatten_conmats_exact = pe.MapNode( - FlattenTSV(kind="conmat"), - name=f"flatten_conmats_{exact_scan}volumes", - iterfield=["in_file"], - ) - workflow.connect([ - (exact_dconn_to_tsv, flatten_conmats_exact, [("out_file", "in_file")]), - (flatten_conmats_exact, collect_exact_tsvs, [ - ("out_file", f"in{i_exact_scan + 1}"), - ]), - ]) # fmt:skip - else: - workflow.connect([ - (exact_dconn_to_tsv, collect_exact_tsvs, [ - ("out_file", f"in{i_exact_scan + 1}"), - ]), - ]) # fmt:skip - parcellate_reho_wf = init_parcellate_cifti_wf( mem_gb=mem_gb, compute_mask=False,