From 720e6cb184fbea07971a57a968c03188ad069d59 Mon Sep 17 00:00:00 2001 From: pargaw Date: Fri, 3 May 2024 12:57:26 -0400 Subject: [PATCH 1/7] Adjusted join in flat reps to account for different timestamps with task df --- EventStream/baseline/FT_task_baseline.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/EventStream/baseline/FT_task_baseline.py b/EventStream/baseline/FT_task_baseline.py index b75ffd10..2c3f42cd 100644 --- a/EventStream/baseline/FT_task_baseline.py +++ b/EventStream/baseline/FT_task_baseline.py @@ -46,6 +46,7 @@ def load_flat_rep( do_update_if_missing: bool = True, task_df_name: str | None = None, do_cache_filtered_task: bool = True, + overwrite_cache_filtered_task: bool = False, subjects_included: dict[str, set[int]] | None = None, ) -> dict[str, pl.LazyFrame]: """Loads a set of flat representations from a passed dataset that satisfy the given constraints. @@ -67,7 +68,7 @@ def load_flat_rep( do_update_if_missing: If `True`, then if any window sizes or features are missing, the function will try to update the stored flat representations to reflect these. If `False`, if information is missing, it will raise a `FileNotFoundError` instead. - task_df_name: If specified, the flat representations loaded will be (inner) joined against the task + task_df_name: If specified, the flat representations loaded will be joined against the task dataframe of this name on the columns ``"subject_id"`` and ``"end_time"`` (which will be renamed to ``"timestamp"``). This is to avoid needing to load the full dataset in flattened form into memory. This is also used as a cache key; if a pre-filtered dataset is written to disk at a @@ -75,6 +76,8 @@ def load_flat_rep( dataset. do_cache_filtered_task: If `True`, the flat representations will, after being filtered to just the relevant rows for the task, be cached to disk for faster re-use. + overwrite_cache_filtered_task: If `True`, the flat representations will be regenerated. If `False`, + the cached file will be loaded if exists. subjects_included: A dictionary by split of the subjects to include in the task. Omitted splits are used wholesale. @@ -170,7 +173,7 @@ def load_flat_rep( if task_df_name is not None: fn = fp.parts[-1] cached_fp = task_window_dir / fn - if cached_fp.is_file(): + if cached_fp.is_file() and not overwrite_cache_filtered_task: df = pl.scan_parquet(cached_fp).select("subject_id", "timestamp", *window_features) if subjects_included.get(sp, None) is not None: subjects = list(set(subjects).intersection(subjects_included[sp])) @@ -182,7 +185,8 @@ def load_flat_rep( if task_df_name is not None: filter_join_df = sp_join_df.select(join_keys).filter(pl.col("subject_id").is_in(subjects)) - df = df.join(filter_join_df, on=join_keys, how="inner") + df = filter_join_df.join_asof(df, by='subject_id', on='timestamp', + strategy='forward' if '-' in window_size else 'backward') if do_cache_filtered_task: cached_fp.parent.mkdir(exist_ok=True, parents=True) From b00619573cd98f19320c06da2b0f7641a9d8b4b1 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 4 May 2024 15:09:48 -0400 Subject: [PATCH 2/7] Fixed code style. --- EventStream/baseline/FT_task_baseline.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/EventStream/baseline/FT_task_baseline.py b/EventStream/baseline/FT_task_baseline.py index 2c3f42cd..ddb55fde 100644 --- a/EventStream/baseline/FT_task_baseline.py +++ b/EventStream/baseline/FT_task_baseline.py @@ -76,7 +76,7 @@ def load_flat_rep( dataset. 
do_cache_filtered_task: If `True`, the flat representations will, after being filtered to just the relevant rows for the task, be cached to disk for faster re-use. - overwrite_cache_filtered_task: If `True`, the flat representations will be regenerated. If `False`, + overwrite_cache_filtered_task: If `True`, the flat representations will be regenerated. If `False`, the cached file will be loaded if exists. subjects_included: A dictionary by split of the subjects to include in the task. Omitted splits are used wholesale. @@ -185,8 +185,12 @@ def load_flat_rep( if task_df_name is not None: filter_join_df = sp_join_df.select(join_keys).filter(pl.col("subject_id").is_in(subjects)) - df = filter_join_df.join_asof(df, by='subject_id', on='timestamp', - strategy='forward' if '-' in window_size else 'backward') + df = filter_join_df.join_asof( + df, + by="subject_id", + on="timestamp", + strategy="forward" if "-" in window_size else "backward", + ) if do_cache_filtered_task: cached_fp.parent.mkdir(exist_ok=True, parents=True) From 9e0acf74e4f201c734d512927dc42d36799981d4 Mon Sep 17 00:00:00 2001 From: pargaw Date: Tue, 11 Jun 2024 13:58:41 -0400 Subject: [PATCH 3/7] Added script to build flat reps and save, and adjusted dataset_polars to cast subject_id to str (pl.Utf8) to account for errors in Inovalon --- EventStream/data/dataset_polars.py | 9 ++++-- scripts/build_flat_reps.py | 44 ++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 3 deletions(-) create mode 100755 scripts/build_flat_reps.py diff --git a/EventStream/data/dataset_polars.py b/EventStream/data/dataset_polars.py index 9be90299..d6b731c6 100644 --- a/EventStream/data/dataset_polars.py +++ b/EventStream/data/dataset_polars.py @@ -1417,7 +1417,8 @@ def _summarize_static_measurements( if include_only_subjects is None: df = self.subjects_df else: - df = self.subjects_df.filter(pl.col("subject_id").is_in(list(include_only_subjects))) + self.subjects_df = self.subjects_df.with_columns(pl.col("subject_id").cast(pl.Utf8)) + df = self.subjects_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects])) valid_measures = {} for feat_col in feature_columns: @@ -1477,7 +1478,8 @@ def _summarize_time_dependent_measurements( if include_only_subjects is None: df = self.events_df else: - df = self.events_df.filter(pl.col("subject_id").is_in(list(include_only_subjects))) + self.events_df = self.events_df.with_columns(pl.col("subject_id").cast(pl.Utf8)) + df = self.events_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects])) valid_measures = {} for feat_col in feature_columns: @@ -1540,8 +1542,9 @@ def _summarize_dynamic_measurements( if include_only_subjects is None: df = self.dynamic_measurements_df else: + self.events_df = self.events_df.with_columns(pl.col("subject_id").cast(pl.Utf8)) df = self.dynamic_measurements_df.join( - self.events_df.filter(pl.col("subject_id").is_in(list(include_only_subjects))).select( + self.events_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects])).select( "event_id" ), on="event_id", diff --git a/scripts/build_flat_reps.py b/scripts/build_flat_reps.py new file mode 100755 index 00000000..9fc54e84 --- /dev/null +++ b/scripts/build_flat_reps.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python +"""Builds a flat representation dataset given a hydra config file.""" + +try: + import stackprinter + + stackprinter.set_excepthook(style="darkbg2") +except ImportError: + pass # no need to fail because of missing dev dependency + +from pathlib import 
Path +import hydra +from omegaconf import DictConfig, OmegaConf +from loguru import logger + +from EventStream.data.dataset_polars import Dataset + +@hydra.main(version_base=None, config_path="../configs", config_name="dataset_base") +def main(cfg: DictConfig): + cfg = hydra.utils.instantiate(cfg, _convert_="all") + save_dir = Path(cfg.pop("save_dir")) + window_sizes = cfg.pop("window_sizes") + subjects_per_output_file = cfg.pop("subjects_per_output_file") if 'subjects_per_output_file' in cfg else None + + # Build flat reps for specified task and window sizes + logger.debug('Loading ESD..') + ESD = Dataset.load(save_dir) + feature_inclusion_frequency, include_only_measurements = ESD._resolve_flat_rep_cache_params( + feature_inclusion_frequency=None, include_only_measurements=None + ) + cache_kwargs = dict( + subjects_per_output_file=subjects_per_output_file, + feature_inclusion_frequency=feature_inclusion_frequency, #0.1 + window_sizes=window_sizes, + include_only_measurements=include_only_measurements, + do_overwrite=False, + do_update=True, + ) + logger.debug('Caching flat representation..') + ESD.cache_flat_representation(**cache_kwargs) + logger.debug('Done') + +if __name__ == "__main__": + main() From 435d968c5022ce8e5cdddef0ef6a90bcbba68bb3 Mon Sep 17 00:00:00 2001 From: pargaw Date: Fri, 14 Jun 2024 12:57:54 -0400 Subject: [PATCH 4/7] Updated for loops to print progress, fixed minor bugs --- EventStream/data/dataset_base.py | 10 +++++----- EventStream/data/dataset_polars.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/EventStream/data/dataset_base.py b/EventStream/data/dataset_base.py index d324e499..6453d819 100644 --- a/EventStream/data/dataset_base.py +++ b/EventStream/data/dataset_base.py @@ -223,17 +223,17 @@ def build_event_and_measurement_dfs( all_events_and_measurements = [] event_types = [] - for df, schemas in schemas_by_df.items(): + for df_name, schemas in schemas_by_df.items(): all_columns = [] all_columns.extend(itertools.chain.from_iterable(s.columns_to_load for s in schemas)) try: - df = cls._load_input_df(df, all_columns, subject_id_col, subject_ids_map, subject_id_dtype) + df = cls._load_input_df(df_name, all_columns, subject_id_col, subject_ids_map, subject_id_dtype) except Exception as e: - raise ValueError(f"Errored while loading {df}") from e + raise ValueError(f"Errored while loading {df_name}") from e - for schema in schemas: + for schema in tqdm(schemas, desc=f"Processing events and measurements df for {df_name.split('/')[-1]}"): if schema.filter_on: df = cls._filter_col_inclusion(schema.filter_on) match schema.type: @@ -266,7 +266,7 @@ def build_event_and_measurement_dfs( all_events, all_measurements = [], [] running_event_id_max = 0 - for event_type, (events, measurements) in zip(event_types, all_events_and_measurements): + for event_type, (events, measurements) in tqdm(zip(event_types, all_events_and_measurements), desc="Incrementing and combining events and measurements"): try: new_events = cls._inc_df_col(events, "event_id", running_event_id_max) except Exception as e: diff --git a/EventStream/data/dataset_polars.py b/EventStream/data/dataset_polars.py index d6b731c6..afb46f42 100644 --- a/EventStream/data/dataset_polars.py +++ b/EventStream/data/dataset_polars.py @@ -705,7 +705,7 @@ def _update_subject_event_properties(self): ) n_events_pd = self.events_df.get_column("subject_id").value_counts(sort=False).to_pandas() - self.n_events_per_subject = n_events_pd.set_index("subject_id")["counts"].to_dict() + 
self.n_events_per_subject = n_events_pd.set_index("subject_id")["count"].to_dict() self.subject_ids = set(self.n_events_per_subject.keys()) if self.subjects_df is not None: @@ -853,7 +853,7 @@ def _add_inferred_val_types( .alias("is_int") ) int_keys = for_val_type_inference.groupby(vocab_keys_col).agg(is_int_expr) - + measurement_metadata = measurement_metadata.join(int_keys, on=vocab_keys_col, how="outer") key_is_int = pl.col(vocab_keys_col).is_in(int_keys.filter("is_int")[vocab_keys_col]) @@ -1105,7 +1105,7 @@ def _fit_vocabulary(self, measure: str, config: MeasurementConfig, source_df: DF try: value_counts = observations.value_counts() vocab_elements = value_counts.get_column(measure).to_list() - el_counts = value_counts.get_column("counts") + el_counts = value_counts.get_column("count") return Vocabulary(vocabulary=vocab_elements, obs_frequencies=el_counts) except AssertionError as e: raise AssertionError(f"Failed to build vocabulary for {measure}") from e From 5c6cb4bb555623deb66686a5f7d414f015287b69 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 20 Jun 2024 10:10:40 -0400 Subject: [PATCH 5/7] Fixed lint issues --- EventStream/data/dataset_base.py | 13 ++++++++++--- EventStream/data/dataset_polars.py | 12 +++++++----- scripts/build_flat_reps.py | 15 ++++++++------- 3 files changed, 25 insertions(+), 15 deletions(-) diff --git a/EventStream/data/dataset_base.py b/EventStream/data/dataset_base.py index 6453d819..d5405e00 100644 --- a/EventStream/data/dataset_base.py +++ b/EventStream/data/dataset_base.py @@ -229,11 +229,15 @@ def build_event_and_measurement_dfs( all_columns.extend(itertools.chain.from_iterable(s.columns_to_load for s in schemas)) try: - df = cls._load_input_df(df_name, all_columns, subject_id_col, subject_ids_map, subject_id_dtype) + df = cls._load_input_df( + df_name, all_columns, subject_id_col, subject_ids_map, subject_id_dtype + ) except Exception as e: raise ValueError(f"Errored while loading {df_name}") from e - for schema in tqdm(schemas, desc=f"Processing events and measurements df for {df_name.split('/')[-1]}"): + for schema in tqdm( + schemas, desc=f"Processing events and measurements df for {df_name.split('/')[-1]}" + ): if schema.filter_on: df = cls._filter_col_inclusion(schema.filter_on) match schema.type: @@ -266,7 +270,10 @@ def build_event_and_measurement_dfs( all_events, all_measurements = [], [] running_event_id_max = 0 - for event_type, (events, measurements) in tqdm(zip(event_types, all_events_and_measurements), desc="Incrementing and combining events and measurements"): + for event_type, (events, measurements) in tqdm( + zip(event_types, all_events_and_measurements), + desc="Incrementing and combining events and measurements", + ): try: new_events = cls._inc_df_col(events, "event_id", running_event_id_max) except Exception as e: diff --git a/EventStream/data/dataset_polars.py b/EventStream/data/dataset_polars.py index afb46f42..ad1e63b2 100644 --- a/EventStream/data/dataset_polars.py +++ b/EventStream/data/dataset_polars.py @@ -853,7 +853,7 @@ def _add_inferred_val_types( .alias("is_int") ) int_keys = for_val_type_inference.groupby(vocab_keys_col).agg(is_int_expr) - + measurement_metadata = measurement_metadata.join(int_keys, on=vocab_keys_col, how="outer") key_is_int = pl.col(vocab_keys_col).is_in(int_keys.filter("is_int")[vocab_keys_col]) @@ -1418,7 +1418,9 @@ def _summarize_static_measurements( df = self.subjects_df else: self.subjects_df = self.subjects_df.with_columns(pl.col("subject_id").cast(pl.Utf8)) - df = 
self.subjects_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects])) + df = self.subjects_df.filter( + pl.col("subject_id").is_in([str(id) for id in include_only_subjects]) + ) valid_measures = {} for feat_col in feature_columns: @@ -1544,9 +1546,9 @@ def _summarize_dynamic_measurements( else: self.events_df = self.events_df.with_columns(pl.col("subject_id").cast(pl.Utf8)) df = self.dynamic_measurements_df.join( - self.events_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects])).select( - "event_id" - ), + self.events_df.filter( + pl.col("subject_id").is_in([str(id) for id in include_only_subjects]) + ).select("event_id"), on="event_id", how="inner", ) diff --git a/scripts/build_flat_reps.py b/scripts/build_flat_reps.py index 9fc54e84..536c8def 100755 --- a/scripts/build_flat_reps.py +++ b/scripts/build_flat_reps.py @@ -9,36 +9,37 @@ pass # no need to fail because of missing dev dependency from pathlib import Path + import hydra -from omegaconf import DictConfig, OmegaConf -from loguru import logger +from omegaconf import DictConfig from EventStream.data.dataset_polars import Dataset + @hydra.main(version_base=None, config_path="../configs", config_name="dataset_base") def main(cfg: DictConfig): cfg = hydra.utils.instantiate(cfg, _convert_="all") save_dir = Path(cfg.pop("save_dir")) window_sizes = cfg.pop("window_sizes") - subjects_per_output_file = cfg.pop("subjects_per_output_file") if 'subjects_per_output_file' in cfg else None + subjects_per_output_file = ( + cfg.pop("subjects_per_output_file") if "subjects_per_output_file" in cfg else None + ) # Build flat reps for specified task and window sizes - logger.debug('Loading ESD..') ESD = Dataset.load(save_dir) feature_inclusion_frequency, include_only_measurements = ESD._resolve_flat_rep_cache_params( feature_inclusion_frequency=None, include_only_measurements=None ) cache_kwargs = dict( subjects_per_output_file=subjects_per_output_file, - feature_inclusion_frequency=feature_inclusion_frequency, #0.1 + feature_inclusion_frequency=feature_inclusion_frequency, # 0.1 window_sizes=window_sizes, include_only_measurements=include_only_measurements, do_overwrite=False, do_update=True, ) - logger.debug('Caching flat representation..') ESD.cache_flat_representation(**cache_kwargs) - logger.debug('Done') + if __name__ == "__main__": main() From 69b99ce94751abe8277120ac478e09bc3ba60a6b Mon Sep 17 00:00:00 2001 From: pargaw Date: Mon, 8 Jul 2024 18:29:47 -0400 Subject: [PATCH 6/7] Adjusted loading flat reps code to account for differences in subject_id dtype --- EventStream/baseline/FT_task_baseline.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/EventStream/baseline/FT_task_baseline.py b/EventStream/baseline/FT_task_baseline.py index ddb55fde..c5e252e2 100644 --- a/EventStream/baseline/FT_task_baseline.py +++ b/EventStream/baseline/FT_task_baseline.py @@ -155,6 +155,7 @@ def load_flat_rep( static_df = pl.scan_parquet(flat_dir / "static" / sp / "*.parquet") if task_df_name is not None: + static_df = static_df.cast({"subject_id": sp_join_df.select('subject_id').dtypes[0]}) static_df = static_df.join(sp_join_df.select("subject_id").unique(), on="subject_id", how="inner") dfs = [] @@ -184,7 +185,7 @@ def load_flat_rep( df = pl.scan_parquet(fp) if task_df_name is not None: filter_join_df = sp_join_df.select(join_keys).filter(pl.col("subject_id").is_in(subjects)) - + df = df.cast({"subject_id": filter_join_df.select('subject_id').dtypes[0]}) df = filter_join_df.join_asof( df, 
by="subject_id", @@ -203,7 +204,7 @@ def load_flat_rep( window_dfs.append(df) - dfs.append(pl.concat(window_dfs, how="vertical")) + dfs.append(pl.concat(window_dfs, how="vertical_relaxed")) joined_df = dfs[0] for jdf in dfs[1:]: From bf453e16f45e2d97f8a1eff0dfdb266ff946893c Mon Sep 17 00:00:00 2001 From: pargaw Date: Thu, 11 Jul 2024 00:23:47 -0400 Subject: [PATCH 7/7] Added tuning subset size and seed to data config --- EventStream/data/config.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/EventStream/data/config.py b/EventStream/data/config.py index 017eb56f..5d246fc9 100644 --- a/EventStream/data/config.py +++ b/EventStream/data/config.py @@ -803,6 +803,10 @@ class PytorchDatasetConfig(JSONableMixin): training subset. If `None` or "FULL", then the full training data is used. train_subset_seed: If the training data should be subsampled randomly, this specifies the seed for that random subsampling. + tuning_subset_size: If the tuning data should be subsampled randomly, this specifies the size of the + tuning subset. If `None` or "FULL", then the full tuning data is used. + tuning_subset_seed: If the tuning data should be subsampled randomly, this specifies the seed for + that random subsampling. task_df_name: If the raw dataset should be limited to a task dataframe view, this specifies the name of the task dataframe, and indirectly the path on disk from where that task dataframe will be read (save_dir / "task_dfs" / f"{task_df_name}.parquet"). @@ -873,6 +877,8 @@ class PytorchDatasetConfig(JSONableMixin): train_subset_size: int | float | str = "FULL" train_subset_seed: int | None = None + tuning_subset_size: int | float | str = "FULL" + tuning_subset_seed: int | None = None task_df_name: str | None = None @@ -907,6 +913,22 @@ def __post_init__(self): pass case _: raise TypeError(f"train_subset_size is of unrecognized type {type(self.train_subset_size)}.") + + match self.tuning_subset_size: + case int() as n if n < 0: + raise ValueError(f"If integral, tuning_subset_size must be positive! Got {n}") + case float() as frac if frac <= 0 or frac >= 1: + raise ValueError(f"If float, tuning_subset_size must be in (0, 1)! Got {frac}") + case int() | float() if (self.tuning_subset_seed is None): + seed = int(random.randint(1, int(1e6))) + print(f"WARNING! tuning_subset_size is set, but tuning_subset_seed is not. Setting to {seed}") + self.tuning_subset_seed = seed + case None | "FULL" | int() | float(): + pass + case _: + raise TypeError( + f"tuning_subset_size is of unrecognized type {type(self.tuning_subset_size)}." + ) def to_dict(self) -> dict: """Represents this configuration object as a plain dictionary."""