diff --git a/EventStream/baseline/FT_task_baseline.py b/EventStream/baseline/FT_task_baseline.py index ddf4fdb1..a3281867 100644 --- a/EventStream/baseline/FT_task_baseline.py +++ b/EventStream/baseline/FT_task_baseline.py @@ -47,6 +47,7 @@ def load_flat_rep( do_update_if_missing: bool = True, task_df_name: str | None = None, do_cache_filtered_task: bool = True, + overwrite_cache_filtered_task: bool = False, subjects_included: dict[str, set[int]] | None = None, ) -> dict[str, pl.LazyFrame]: """Loads a set of flat representations from a passed dataset that satisfy the given constraints. @@ -68,7 +69,7 @@ def load_flat_rep( do_update_if_missing: If `True`, then if any window sizes or features are missing, the function will try to update the stored flat representations to reflect these. If `False`, if information is missing, it will raise a `FileNotFoundError` instead. - task_df_name: If specified, the flat representations loaded will be (inner) joined against the task + task_df_name: If specified, the flat representations loaded will be joined against the task dataframe of this name on the columns ``"subject_id"`` and ``"end_time"`` (which will be renamed to ``"timestamp"``). This is to avoid needing to load the full dataset in flattened form into memory. This is also used as a cache key; if a pre-filtered dataset is written to disk at a @@ -76,6 +77,8 @@ def load_flat_rep( dataset. do_cache_filtered_task: If `True`, the flat representations will, after being filtered to just the relevant rows for the task, be cached to disk for faster re-use. + overwrite_cache_filtered_task: If `True`, the flat representations will be regenerated. If `False`, + the cached file will be loaded if exists. subjects_included: A dictionary by split of the subjects to include in the task. Omitted splits are used wholesale. @@ -148,8 +151,9 @@ def load_flat_rep( by_split = {} for sp, all_sp_subjects in ESD.split_subjects.items(): + all_sp_subjects = pl.Series(list(all_sp_subjects)).cast(ESD.subject_id_dtype) if task_df_name is not None: - sp_join_df = join_df.filter(pl.col("subject_id").is_in(list(all_sp_subjects))) + sp_join_df = join_df.filter(pl.col("subject_id").is_in(all_sp_subjects)) static_df = pl.scan_parquet(flat_dir / "static" / sp / "*.parquet") if task_df_name is not None: @@ -171,19 +175,28 @@ def load_flat_rep( if task_df_name is not None: fn = fp.parts[-1] cached_fp = task_window_dir / fn - if cached_fp.is_file(): + if cached_fp.is_file() and not overwrite_cache_filtered_task: df = pl.scan_parquet(cached_fp).select("subject_id", "timestamp", *window_features) if subjects_included.get(sp, None) is not None: subjects = list(set(subjects).intersection(subjects_included[sp])) - df = df.filter(pl.col("subject_id").is_in(subjects)) + df = df.filter( + pl.col("subject_id").is_in(pl.Series(subjects).cast(ESD.subject_id_dtype)) + ) window_dfs.append(df) continue df = pl.scan_parquet(fp) if task_df_name is not None: - filter_join_df = sp_join_df.select(join_keys).filter(pl.col("subject_id").is_in(subjects)) + filter_join_df = sp_join_df.select(join_keys).filter( + pl.col("subject_id").is_in(pl.Series(subjects).cast(ESD.subject_id_dtype)) + ) - df = df.join(filter_join_df, on=join_keys, how="inner") + df = filter_join_df.join_asof( + df, + by="subject_id", + on="timestamp", + strategy="forward" if "-" in window_size else "backward", + ) if do_cache_filtered_task: cached_fp.parent.mkdir(exist_ok=True, parents=True) @@ -193,7 +206,7 @@ def load_flat_rep( df = df.select("subject_id", "timestamp", *window_features) if subjects_included.get(sp, None) is not None: subjects = list(set(subjects).intersection(subjects_included[sp])) - df = df.filter(pl.col("subject_id").is_in(subjects)) + df = df.filter(pl.col("subject_id").is_in(pl.Series(subjects).cast(ESD.subject_id_dtype))) window_dfs.append(df) diff --git a/EventStream/data/dataset_base.py b/EventStream/data/dataset_base.py index 109996d5..0220e2b8 100644 --- a/EventStream/data/dataset_base.py +++ b/EventStream/data/dataset_base.py @@ -1239,9 +1239,12 @@ def cache_flat_representation( raise ValueError("\n".join(err_strings)) elif not do_overwrite: raise FileExistsError(f"do_overwrite is {do_overwrite} and {params_fp} exists!") - - with open(params_fp, mode="w") as f: - json.dump(params, f) + else: + with open(params_fp) as f: + params = json.load(f) + else: + with open(params_fp, mode="w") as f: + json.dump(params, f) # 0. Identify Output Columns # We set window_sizes to None here because we want to get the feature column names for the raw flat @@ -1265,10 +1268,11 @@ def cache_flat_representation( static_dfs[sp].append(fp) if fp.exists(): if do_update: + logger.debug(f'Skipping static representation for split: {sp}, {i}.parquet') continue elif not do_overwrite: raise FileExistsError(f"do_overwrite is {do_overwrite} and {fp} exists!") - + logger.debug('Creating static representation') df = self._get_flat_static_rep( feature_columns=feature_columns, include_only_subjects=subjects_list, @@ -1289,10 +1293,11 @@ def cache_flat_representation( ts_dfs[sp].append(fp) if fp.exists(): if do_update: + logger.debug(f'Skipping raw representation for split: {sp}, {i}.parquet') continue elif not do_overwrite: raise FileExistsError(f"do_overwrite is {do_overwrite} and {fp} exists!") - + logger.debug('Creating raw representation') df = self._get_flat_ts_rep( feature_columns=feature_columns, include_only_subjects=subjects_list, @@ -1312,10 +1317,11 @@ def cache_flat_representation( fp = history_subdir / sp / window_size / f"{i}.parquet" if fp.exists(): if do_update: + logger.debug(f'Skipping summarized history representation for split: {sp}, window: {window_size}, {i}.parquet') continue elif not do_overwrite: raise FileExistsError(f"do_overwrite is {do_overwrite} and {fp} exists!") - + logger.debug('Creating summarized history representation') df = self._summarize_over_window(df_fp, window_size) self._write_df(df, fp) @@ -1367,7 +1373,7 @@ def cache_deep_learning_representation( NRT_dir = self.config.save_dir / "NRT_reps" shards_fp = self.config.save_dir / "DL_shards.json" - if shards_fp.exists(): + if shards_fp.exists() and not do_overwrite: shards = json.loads(shards_fp.read_text()) else: shards = {} @@ -1375,19 +1381,16 @@ def cache_deep_learning_representation( if subjects_per_output_file is None: subject_chunks = [self.subject_ids] else: - subjects = np.random.permutation(list(self.subject_ids)) + subjects = np.random.permutation(list(set(self.subject_ids))) subject_chunks = np.array_split( subjects, np.arange(subjects_per_output_file, len(subjects), subjects_per_output_file), ) - subject_chunks = [[int(x) for x in c] for c in subject_chunks] - for chunk_idx, subjects_list in enumerate(subject_chunks): for split, subjects in self.split_subjects.items(): shard_key = f"{split}/{chunk_idx}" - included_subjects = set(subjects_list).intersection({int(x) for x in subjects}) - shards[shard_key] = list(included_subjects) + shards[shard_key] = list(set(subjects_list).intersection(subjects)) shards_fp.write_text(json.dumps(shards)) diff --git a/EventStream/data/dataset_polars.py b/EventStream/data/dataset_polars.py index a9cf4456..6cdfdc99 100644 --- a/EventStream/data/dataset_polars.py +++ b/EventStream/data/dataset_polars.py @@ -690,6 +690,7 @@ def _agg_by_time(self): ) def _update_subject_event_properties(self): + self.subject_id_dtype = self.events_df.schema["subject_id"] if self.events_df is not None: logger.debug("Collecting event types") self.event_types = ( @@ -699,15 +700,28 @@ def _update_subject_event_properties(self): .to_list() ) - n_events_pd = self.events_df.get_column("subject_id").value_counts(sort=False).to_pandas() - self.n_events_per_subject = n_events_pd.set_index("subject_id")["count"].to_dict() + logger.debug("Collecting subject event counts") + n_events = self.events_df.group_by("subject_id").agg(pl.len().alias("count")) + n_events = n_events.drop_nulls("subject_id") + # here we cast to str to avoid issues with the subject_id column being various other types as we + # will eventually JSON serialize it. + n_events = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8)) + self.n_events_per_subject = { + subject_id: count for subject_id, count in zip(n_events["subject_id"], n_events["count"]) + } self.subject_ids = set(self.n_events_per_subject.keys()) if self.subjects_df is not None: - logger.debug("Collecting subject event counts") - subjects_with_no_events = ( - set(self.subjects_df.get_column("subject_id").to_list()) - self.subject_ids + subjects_df_subjects = ( + self.subjects_df + .drop_nulls("subject_id") + .select(pl.col("subject_id").cast(pl.Utf8)) ) + subjects_df_subj_ids = set(subjects_df_subjects["subject_id"].to_list()) + subj_no_in_df = self.subject_ids - subjects_df_subj_ids + if len(subj_no_in_df) > 0: + logger.warning(f"Found {len(subj_no_in_df)} subjects not in subject df!") + subjects_with_no_events = subjects_df_subj_ids - self.subject_ids for sid in subjects_with_no_events: self.n_events_per_subject[sid] = 0 self.subject_ids.update(subjects_with_no_events) @@ -723,7 +737,20 @@ def _filter_col_inclusion(cls, df: DF_T, col_inclusion_targets: dict[str, bool | filter_exprs.append(pl.col(col).is_null()) case _: try: - incl_list = pl.Series(list(incl_targets), dtype=df.schema[col]) + logger.debug( + f"Converting inclusion targets of type {type(list(incl_targets)[0])} for " + f"{col} to {df.schema[col]}" + ) + if isinstance(list(incl_targets)[0], str): + incl_list = pl.Series(list(incl_targets), dtype=pl.Utf8) + else: + incl_list = pl.Series(list(incl_targets), dtype=df.schema[col]) + + incl_list = incl_list.cast(df.schema[col]) + + logger.debug( + f"Converted to Series of type {incl_list.dtype} with size {len(incl_list)}" + ) except TypeError as e: incl_targets_by_type = defaultdict(list) for t in incl_targets: @@ -1358,6 +1385,8 @@ def build_DL_cached_representation( # 1. Process subject data into the right format. if subject_ids: subjects_df = self._filter_col_inclusion(self.subjects_df, {"subject_id": subject_ids}) + logger.warning( f"Size of given subject_ids are {len(subject_ids)}, but after _filter_col_inclusion " + f"the size of subjects_df are {len(subjects_df)}") else: subjects_df = self.subjects_df @@ -1369,6 +1398,7 @@ def build_DL_cached_representation( pl.col("index").alias("static_indices"), ) ) + logger.debug(f"Size of static_data: {static_data.shape[0]}") # 2. Process event data into the right format. if subject_ids: @@ -1378,6 +1408,7 @@ def build_DL_cached_representation( events_df = self.events_df event_ids = None event_data = self._melt_df(events_df, ["subject_id", "timestamp", "event_id"], event_measures) + logger.debug(f"Size of event_data: {event_data.shape[0]}") # 3. Process measurement data into the right base format: if event_ids: @@ -1392,6 +1423,7 @@ def build_DL_cached_representation( if do_sort_outputs: dynamic_data = dynamic_data.sort("event_id", "measurement_id") + logger.debug(f"Size of dynamic_data: {dynamic_data.shape[0]}") # 4. Join dynamic and event data. @@ -1444,7 +1476,8 @@ def _summarize_static_measurements( if include_only_subjects is None: df = self.subjects_df else: - df = self.subjects_df.filter(pl.col("subject_id").is_in(list(include_only_subjects))) + self.subjects_df = self.subjects_df.with_columns(pl.col("subject_id").cast(pl.Utf8)) + df = self.subjects_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects])) valid_measures = {} for feat_col in feature_columns: @@ -1490,9 +1523,9 @@ def _summarize_static_measurements( ) remap_cols = [c for c in pivoted_df.columns if c not in ID_cols] - out_dfs[m] = pivoted_df.lazy().select( + out_dfs[m] = pivoted_df.select( *ID_cols, *[pl.col(c).alias(f"static/{m}/{c}/present").cast(pl.Boolean) for c in remap_cols] - ) + ).lazy() return pl.concat(list(out_dfs.values()), how="align") @@ -1504,7 +1537,8 @@ def _summarize_time_dependent_measurements( if include_only_subjects is None: df = self.events_df else: - df = self.events_df.filter(pl.col("subject_id").is_in(list(include_only_subjects))) + self.events_df = self.events_df.with_columns(pl.col("subject_id").cast(pl.Utf8)) + df = self.events_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects])) valid_measures = {} for feat_col in feature_columns: @@ -1549,13 +1583,13 @@ def _summarize_time_dependent_measurements( ) remap_cols = [c for c in pivoted_df.columns if c not in ID_cols] - out_dfs[m] = pivoted_df.lazy().select( + out_dfs[m] = pivoted_df.select( *ID_cols, *[ pl.col(c).cast(pl.Boolean).alias(f"functional_time_dependent/{m}/{c}/present") for c in remap_cols ], - ) + ).lazy() return pl.concat(list(out_dfs.values()), how="align") @@ -1567,8 +1601,9 @@ def _summarize_dynamic_measurements( if include_only_subjects is None: df = self.dynamic_measurements_df else: + self.events_df = self.events_df.with_columns(pl.col("subject_id").cast(pl.Utf8)) df = self.dynamic_measurements_df.join( - self.events_df.filter(pl.col("subject_id").is_in(list(include_only_subjects))).select( + self.events_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects])).select( "event_id" ), on="event_id", @@ -1676,10 +1711,10 @@ def _summarize_dynamic_measurements( values=values_cols, aggregate_function=None, ) - .lazy() .drop("measurement_id") .group_by("event_id") .agg(*aggs) + .lazy() ) return pl.concat(list(out_dfs.values()), how="align") @@ -1764,11 +1799,9 @@ def _get_flat_ts_rep( ) .drop("event_id") .sort(by=["subject_id", "timestamp"]) - .collect() .lazy(), [c for c in feature_columns if not c.startswith("static/")], ) - # The above .collect().lazy() shouldn't be necessary but it appears to be for some reason... def _normalize_flat_rep_df_cols( self, flat_df: DF_T, feature_columns: list[str] | None = None, set_count_0_to_null: bool = False diff --git a/EventStream/data/pytorch_dataset.py b/EventStream/data/pytorch_dataset.py index 9089b60a..f7c3f11a 100644 --- a/EventStream/data/pytorch_dataset.py +++ b/EventStream/data/pytorch_dataset.py @@ -461,7 +461,16 @@ def _seeded_getitem(self, idx: int) -> dict[str, list[float]]: subject_id, st, end = self.index[idx] - shard = self.subj_map[subject_id] + if str(subject_id) not in self.subj_map: + err_str = [f"Subject {subject_id} ({type(subject_id)} -- as str) not found in the shard map!"] + + if len(self.subj_map) < 10: + err_str.append("Subject IDs in map:") + err_str.extend(f" * {k} ({type(k)}): {v}" for k, v in self.subj_map.items()) + + raise ValueError("\n".join(err_str)) + + shard = self.subj_map[str(subject_id)] subject_idx = self.subj_indices[subject_id] static_row = self.static_dfs[shard][subject_idx].to_dict() @@ -471,7 +480,7 @@ def _seeded_getitem(self, idx: int) -> dict[str, list[float]]: } if self.config.do_include_subject_id: - out["subject_id"] = subject_id + out["subject_id"] = static_row["subject_id"].item() seq_len = end - st if seq_len > self.max_seq_len: diff --git a/configs/dataset_base.yaml b/configs/dataset_base.yaml index 6d58bf37..4ee1137b 100644 --- a/configs/dataset_base.yaml +++ b/configs/dataset_base.yaml @@ -8,6 +8,7 @@ subject_id_col: ??? seed: 1 split: [0.8, 0.1] do_overwrite: False +do_update: False DL_chunk_size: 20000 min_valid_vocab_element_observations: 25 min_valid_column_observations: 50 @@ -19,7 +20,7 @@ center_and_scale: True hydra: job: - name: build_${cohort_name} + name: build_dataset run: dir: ${save_dir}/.logs sweep: diff --git a/sample_data/dataset.yaml b/sample_data/dataset.yaml index 800d8544..e2e06053 100644 --- a/sample_data/dataset.yaml +++ b/sample_data/dataset.yaml @@ -10,7 +10,7 @@ subject_id_col: "MRN" raw_data_dir: "./sample_data/raw/" save_dir: "./sample_data/processed/${cohort_name}" -DL_chunk_size: null +DL_chunk_size: 25 inputs: subjects: diff --git a/sample_data/dataset_parquet.yaml b/sample_data/dataset_parquet.yaml new file mode 100644 index 00000000..75954cfd --- /dev/null +++ b/sample_data/dataset_parquet.yaml @@ -0,0 +1,69 @@ +defaults: + - dataset_base + - _self_ + +# So that it can be run multiple times without issue. +do_overwrite: True + +cohort_name: "sample" +subject_id_col: "MRN" +raw_data_dir: "./sample_data/raw_parquet" +save_dir: "./sample_data/processed/${cohort_name}" + +DL_chunk_size: 25 + +inputs: + subjects: + input_df: "${raw_data_dir}/subjects.parquet" + admissions: + input_df: "${raw_data_dir}/admit_vitals.parquet" + start_ts_col: "admit_date" + end_ts_col: "disch_date" + ts_format: "%m/%d/%Y, %H:%M:%S" + event_type: ["OUTPATIENT_VISIT", "ADMISSION", "DISCHARGE"] + vitals: + input_df: "${raw_data_dir}/admit_vitals.parquet" + ts_col: "vitals_date" + ts_format: "%m/%d/%Y, %H:%M:%S" + labs: + input_df: "${raw_data_dir}/labs.parquet" + ts_col: "timestamp" + ts_format: "%H:%M:%S-%Y-%m-%d" + medications: + input_df: "${raw_data_dir}/medications.parquet" + ts_col: "timestamp" + ts_format: "%H:%M:%S-%Y-%m-%d" + columns: {"name": "medication"} + +measurements: + static: + single_label_classification: + subjects: ["eye_color"] + functional_time_dependent: + age: + functor: AgeFunctor + necessary_static_measurements: { "dob": ["timestamp", "%m/%d/%Y"] } + kwargs: { dob_col: "dob" } + dynamic: + multi_label_classification: + admissions: ["department"] + medications: + - name: medication + modifiers: + - [dose, "float"] + - [frequency, "categorical"] + - [duration, "categorical"] + - [generic_name, "categorical"] + univariate_regression: + vitals: ["HR", "temp"] + multivariate_regression: + labs: [["lab_name", "lab_value"]] + +outlier_detector_config: + stddev_cutoff: 1.5 +min_valid_vocab_element_observations: 5 +min_valid_column_observations: 5 +min_true_float_frequency: 0.1 +min_unique_numerical_observations: 20 +min_events_per_subject: 3 +agg_by_time_scale: "1h" diff --git a/sample_data/raw_parquet/admit_vitals.parquet b/sample_data/raw_parquet/admit_vitals.parquet new file mode 100644 index 00000000..d11594bd Binary files /dev/null and b/sample_data/raw_parquet/admit_vitals.parquet differ diff --git a/sample_data/raw_parquet/labs.parquet b/sample_data/raw_parquet/labs.parquet new file mode 100644 index 00000000..7a662aec Binary files /dev/null and b/sample_data/raw_parquet/labs.parquet differ diff --git a/sample_data/raw_parquet/medications.parquet b/sample_data/raw_parquet/medications.parquet new file mode 100644 index 00000000..e12e4562 Binary files /dev/null and b/sample_data/raw_parquet/medications.parquet differ diff --git a/sample_data/raw_parquet/subjects.parquet b/sample_data/raw_parquet/subjects.parquet new file mode 100644 index 00000000..1d5bbb94 Binary files /dev/null and b/sample_data/raw_parquet/subjects.parquet differ diff --git a/scripts/build_flat_reps.py b/scripts/build_flat_reps.py new file mode 100644 index 00000000..f9cf9c0c --- /dev/null +++ b/scripts/build_flat_reps.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python +"""Builds a flat representation dataset given a hydra config file.""" + +try: + import stackprinter + + stackprinter.set_excepthook(style="darkbg2") +except ImportError: + pass # no need to fail because of missing dev dependency + +from pathlib import Path + +import hydra +from omegaconf import DictConfig + +from EventStream.data.dataset_polars import Dataset + + +@hydra.main(version_base=None, config_path="../configs", config_name="dataset_base") +def main(cfg: DictConfig): + cfg = hydra.utils.instantiate(cfg, _convert_="all") + save_dir = Path(cfg.pop("save_dir")) + window_sizes = cfg.pop("window_sizes") + subjects_per_output_file = ( + cfg.pop("subjects_per_output_file") if "subjects_per_output_file" in cfg else None + ) + + # Build flat reps for specified task and window sizes + ESD = Dataset.load(save_dir) + feature_inclusion_frequency, include_only_measurements = ESD._resolve_flat_rep_cache_params( + feature_inclusion_frequency=None, include_only_measurements=None + ) + cache_kwargs = dict( + subjects_per_output_file=subjects_per_output_file, + feature_inclusion_frequency=feature_inclusion_frequency, # 0.1 + window_sizes=window_sizes, + include_only_measurements=include_only_measurements, + do_overwrite=cfg.pop("do_overwrite"), + do_update=cfg.pop("do_update"), + ) + ESD.cache_flat_representation(**cache_kwargs) + + +if __name__ == "__main__": + main() diff --git a/tests/data/test_pytorch_dataset.py b/tests/data/test_pytorch_dataset.py index defa2eff..bf964572 100644 --- a/tests/data/test_pytorch_dataset.py +++ b/tests/data/test_pytorch_dataset.py @@ -320,7 +320,7 @@ def setUp(self): shards_fp = self.path / "DL_shards.json" shards = { - f"{self.split}/0": list(set(DL_REP_DF["subject_id"].to_list())), + f"{self.split}/0": [str(x) for x in set(DL_REP_DF["subject_id"].to_list())], } shards_fp.write_text(json.dumps(shards)) diff --git a/tests/test_e2e_runs.py b/tests/test_e2e_runs.py index b56a6970..be67e987 100644 --- a/tests/test_e2e_runs.py +++ b/tests/test_e2e_runs.py @@ -2,6 +2,7 @@ root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) +import json import os import subprocess import unittest @@ -10,6 +11,8 @@ from tempfile import TemporaryDirectory from typing import Any +import polars as pl + from tests.utils import MLTypeEqualityCheckableMixin @@ -32,6 +35,7 @@ def setUp(self): self.paths = {} for n in ( "dataset", + "dataset_from_parquet", "esds", "pretraining/CI", "pretraining/NA", @@ -45,6 +49,67 @@ def tearDown(self): for o in self.dir_objs.values(): o.cleanup() + def _test_dataset_output(self, raw_data_root: Path, dataset_save_dir: Path): + DL_save_dir = dataset_save_dir / "DL_reps" + + train_files = list((DL_save_dir / "train").glob("*.parquet")) + tuning_files = list((DL_save_dir / "tuning").glob("*.parquet")) + held_out_files = list((DL_save_dir / "held_out").glob("*.parquet")) + + assert len(set(train_files) & set(tuning_files)) == 0 + assert len(set(train_files) & set(held_out_files)) == 0 + assert len(set(tuning_files) & set(held_out_files)) == 0 + + self.assertTrue(len(train_files) > 0) + self.assertTrue(len(tuning_files) > 0) + self.assertTrue(len(held_out_files) > 0) + + train_DL_reps = pl.concat([pl.read_parquet(f, use_pyarrow=False) for f in train_files]) + tuning_DL_reps = pl.concat([pl.read_parquet(f, use_pyarrow=False) for f in tuning_files]) + held_out_DL_reps = pl.concat([pl.read_parquet(f, use_pyarrow=False) for f in held_out_files]) + + DL_shards = json.loads((dataset_save_dir / "DL_shards.json").read_text()) + + ESD_subjects = pl.read_parquet(dataset_save_dir / "subjects_df.parquet", use_pyarrow=False) + + # Check that the DL shards are correctly partitioned. + all_subjects = set(ESD_subjects["subject_id"].unique().to_list()) + + self.assertEqual(len(all_subjects), len(ESD_subjects)) + + all_subj_in_DL_shards = set().union(*DL_shards.values()) + + all_subj_in_DL_shards = set( + pl.Series(list(all_subj_in_DL_shards)).cast(ESD_subjects["subject_id"].dtype).to_list() + ) + + self.assertEqual(all_subjects, all_subj_in_DL_shards) + + all_train_DL_shard_subj = set().union(*(v for k, v in DL_shards.items() if k.startswith("train"))) + all_tuning_DL_shard_subj = set().union(*(v for k, v in DL_shards.items() if k.startswith("tuning"))) + all_held_out_DL_shard_subj = set().union( + *(v for k, v in DL_shards.items() if k.startswith("held_out")) + ) + + self.assertEqual(len(all_train_DL_shard_subj & all_tuning_DL_shard_subj), 0) + self.assertEqual(len(all_train_DL_shard_subj & all_held_out_DL_shard_subj), 0) + self.assertEqual(len(all_tuning_DL_shard_subj & all_held_out_DL_shard_subj), 0) + + train_DL_subjects = set(train_DL_reps["subject_id"].to_list()) + tuning_DL_subjects = set(tuning_DL_reps["subject_id"].to_list()) + held_out_DL_subjects = set(held_out_DL_reps["subject_id"].to_list()) + + self.assertEqual(all_train_DL_shard_subj, {str(x) for x in train_DL_subjects}) + self.assertEqual(all_tuning_DL_shard_subj, {str(x) for x in tuning_DL_subjects}) + self.assertEqual(all_held_out_DL_shard_subj, {str(x) for x in held_out_DL_subjects}) + + self.assertTrue(len(train_DL_subjects) > len(tuning_DL_subjects)) + self.assertTrue(len(train_DL_subjects) > len(held_out_DL_subjects)) + + all_DL_subjects = train_DL_subjects | tuning_DL_subjects | held_out_DL_subjects + + self.assertEqual(all_DL_subjects, all_subjects) + def _test_command(self, command_parts: list[str], case_name: str, use_subtest: bool = True): if use_subtest: with self.subTest(case_name): @@ -71,6 +136,17 @@ def build_dataset(self): f"save_dir={self.paths['dataset']}", ] self._test_command(command_parts, "Build Dataset", use_subtest=False) + self._test_dataset_output((root / "sample_data" / "raw"), self.paths["dataset"]) + + command_parts = [ + "./scripts/build_dataset.py", + f"--config-path='{(root / 'sample_data').resolve()}'", + "--config-name=dataset_parquet", + '"hydra.searchpath=[./configs]"', + f"save_dir={self.paths['dataset_from_parquet']}", + ] + self._test_command(command_parts, "Build Dataset from Parquet", use_subtest=False) + self._test_dataset_output((root / "sample_data" / "raw"), self.paths["dataset_from_parquet"]) def build_ESDS_dataset(self): command_parts = [