From a7390a3ec57848ee0c843284d3c88fc53037e82f Mon Sep 17 00:00:00 2001 From: celprov Date: Tue, 19 Dec 2023 15:16:18 -0800 Subject: [PATCH] enh: add interface to get corresponding file with given extension --- niworkflows/utils/bids.py | 58 +++++++++++++++++++++++++++++++++------ 1 file changed, 50 insertions(+), 8 deletions(-) diff --git a/niworkflows/utils/bids.py b/niworkflows/utils/bids.py index fdcfbd586f3..68ec30a0598 100644 --- a/niworkflows/utils/bids.py +++ b/niworkflows/utils/bids.py @@ -43,6 +43,12 @@ "asl": {"datatype": "perf", "suffix": "asl"}, } +OUTPUT_NAME_PATTERN = [ + "sub-{subject}[_ses-{session}]_task-{task}[_acq-{acquisition}][_ce-{ceagent}]" + "[_rec-{reconstruction}][_dir-{direction}][_run-{run}][_echo-{echo}][_part-{part}]" + "[_space-{space}][_cohort-{cohort}][_desc-{desc}]_{suffix}" +] + class BIDSError(ValueError): def __init__(self, message, bids_root): @@ -64,9 +70,7 @@ class BIDSWarning(RuntimeWarning): pass -def collect_participants( - bids_dir, participant_label=None, strict=False, bids_validate=True -): +def collect_participants(bids_dir, participant_label=None, strict=False, bids_validate=True): """ List the participants under the BIDS root and checks that participants designated with the participant_label argument exist in that folder. @@ -135,9 +139,7 @@ def collect_participants( participant_label = [participant_label] # Drop sub- prefixes - participant_label = [ - sub[4:] if sub.startswith("sub-") else sub for sub in participant_label - ] + participant_label = [sub[4:] if sub.startswith("sub-") else sub for sub in participant_label] # Remove duplicates participant_label = sorted(set(participant_label)) # Remove labels not found @@ -258,8 +260,7 @@ def collect_data( queries["bold"]["echo"] = echo subj_data = { - dtype: sorted(layout.get(**layout_get_kwargs, **query)) - for dtype, query in queries.items() + dtype: sorted(layout.get(**layout_get_kwargs, **query)) for dtype, query in queries.items() } # Special case: multi-echo BOLD, grouping echos @@ -488,3 +489,44 @@ def check_pipeline_version(cvers, data_desc): dvers = desc.get("PipelineDescription", {}).get("Version", "0+unknown") if Version(cvers).public != Version(dvers).public: return "Previous output generated by version {} found.".format(dvers) + + +def corresponding_file(in_file, bids_dir, extension, bids_validate=True): + """ + Returns the corresponding file with the given extension + + Parameters + ---------- + in_file : :obj:`str` + Filename + bids_dir : :obj:`str` or :obj:`bids.layout.BIDSLayout` + The BIDS directory + extension : :obj: `str` + An extension string + bids_validate : :obj:`bool` + Whether the `bids_dir` is validated upon initialization + + Returns + ------- + filename : :obj:`str` + The path to the corresponding file + """ + + if isinstance(bids_dir, BIDSLayout): + layout = bids_dir + else: + layout = BIDSLayout(str(bids_dir), validate=bids_validate) + + entities = layout.parse_file_entities(in_file) + entities.pop("extension", None) + entities.pop("echo", None) + entities.pop("part", None) + entities["suffix"] = extension + + corresponding_filename = layout.build_path(entities, OUTPUT_NAME_PATTERN, validate=False) + + corresponding_filename = Path(corresponding_filename) + if not corresponding_filename.is_absolute(): + corresponding_filename = in_file.parent / corresponding_filename + + return corresponding_filename