Skip to content

Commit

Permalink
Improve script to convert CARDINAL dataset to MMCL format to configure how to name tabular data files
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanpainchaud committed Jan 28, 2025
1 parent 296610e commit 9b490c5
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions didactic/scripts/cardinal2mmcl.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def save_images_and_tabular_data(
label_tag: TabularAttribute,
output_dir: Path,
label_as_a_feature: bool = True,
tabular_tag: str = "tabular",
subsets: Dict[str, Sequence[int]] = None,
) -> None:
"""Serializes the images to disk using `torch.save`, and saves tabular data to a CSV file.
Expand All @@ -72,6 +73,7 @@ def save_images_and_tabular_data(
label_tag: Tabular variable to use as the label.
output_dir: Directory to save the data to.
label_as_a_feature: Whether to keep the label as a feature (LaaF) in the tabular data.
tabular_tag: Tag to use to describe the selection of tabular features in the tabular data files.
subsets: Optional dictionary mapping indices of images/rows in specific subsets. If provided, the subsets will
be saved to separate 'pt' and CSV files.
"""
Expand Down Expand Up @@ -103,12 +105,12 @@ def save_images_and_tabular_data(
# Group data by subset if necessary
if subsets:
imgs_by_subset = {f"{name}_images": images[idxs] for name, idxs in subsets.items()}
tab_by_subset = {f"{name}_tabular": tabular_df.iloc[idxs] for name, idxs in subsets.items()}
tab_by_subset = {f"{name}_{tabular_tag}": tabular_df.iloc[idxs] for name, idxs in subsets.items()}
labels_by_subset = {f"{name}_{label_tag}": labels[idxs] for name, idxs in subsets.items()}
else:
imgs_by_subset = {"images": images}
tab_by_subset = {"tabular": tabular_df}
labels_by_subset = {"labels": labels}
tab_by_subset = {tabular_tag: tabular_df}
labels_by_subset = {label_tag: labels}

# Save the images as a tensor
for tag, subset_imgs in imgs_by_subset.items():
Expand Down Expand Up @@ -139,6 +141,12 @@ def main():
parser.add_argument(
"--tabular_attrs", type=TabularAttribute, nargs="*", help="Tabular attributes to collect and save"
)
parser.add_argument(
"--tabular_tag",
type=str,
default="tabular",
help="Tag to use to describe the selection of tabular features in the tabular data files",
)
parser.add_argument(
"--label_tag",
type=TabularAttribute,
Expand All @@ -162,10 +170,11 @@ def main():
args = parser.parse_args()
kwargs = vars(args)

img_size, norm_bounds, tabular_attrs, label_tag, laaf, imp_rand_state, output_dir, subsets = (
img_size, norm_bounds, tabular_attrs, tabular_tag, label_tag, laaf, imp_rand_state, output_dir, subsets = (
kwargs.pop("img_size"),
kwargs.pop("norm_bounds"),
kwargs.pop("tabular_attrs"),
kwargs.pop("tabular_tag"),
kwargs.pop("label_tag"),
kwargs.pop("laaf"),
kwargs.pop("imputer_random_state"),
Expand Down Expand Up @@ -198,7 +207,9 @@ def main():
}

# Save the data
save_images_and_tabular_data(images, tabular_df, label_tag, output_dir, label_as_a_feature=laaf, subsets=subsets)
save_images_and_tabular_data(
images, tabular_df, label_tag, output_dir, label_as_a_feature=laaf, tabular_tag=tabular_tag, subsets=subsets
)


if __name__ == "__main__":
Expand Down

0 comments on commit 9b490c5

Please sign in to comment.