diff --git a/EventStream/baseline/FT_task_baseline.py b/EventStream/baseline/FT_task_baseline.py
index b75ffd10..c5e252e2 100644
--- a/EventStream/baseline/FT_task_baseline.py
+++ b/EventStream/baseline/FT_task_baseline.py
@@ -46,6 +46,7 @@ def load_flat_rep(
     do_update_if_missing: bool = True,
     task_df_name: str | None = None,
     do_cache_filtered_task: bool = True,
+    overwrite_cache_filtered_task: bool = False,
     subjects_included: dict[str, set[int]] | None = None,
 ) -> dict[str, pl.LazyFrame]:
     """Loads a set of flat representations from a passed dataset that satisfy the given constraints.
@@ -67,7 +68,7 @@ def load_flat_rep(
         do_update_if_missing: If `True`, then if any window sizes or features are missing, the function
             will try to update the stored flat representations to reflect these. If `False`, if information
             is missing, it will raise a `FileNotFoundError` instead.
-        task_df_name: If specified, the flat representations loaded will be (inner) joined against the task
+        task_df_name: If specified, the flat representations loaded will be joined against the task
             dataframe of this name on the columns ``"subject_id"`` and ``"end_time"`` (which will be renamed
             to ``"timestamp"``). This is to avoid needing to load the full dataset in flattened form into
             memory. This is also used as a cache key; if a pre-filtered dataset is written to disk at a
@@ -75,6 +76,8 @@ def load_flat_rep(
            dataset.
         do_cache_filtered_task: If `True`, the flat representations will, after being filtered to just the
             relevant rows for the task, be cached to disk for faster re-use.
+        overwrite_cache_filtered_task: If `True`, the flat representations will be regenerated. If `False`,
+            the cached file will be loaded if it exists.
         subjects_included: A dictionary by split of the subjects to include in the task. Omitted splits are
             used wholesale.
 
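As a usage note for the new flag documented above: a minimal sketch of calling `load_flat_rep` with task-filtered caching. The leading dataset argument `ESD`, the `window_sizes` keyword, the window-size values, and the task name are assumptions for illustration; only `task_df_name`, `do_cache_filtered_task`, and `overwrite_cache_filtered_task` come from the signature above.

    # Rebuild the task-filtered cache rather than silently reusing stale files.
    flat_reps = load_flat_rep(
        ESD,                                # the loaded Dataset (assumed positional argument)
        window_sizes=["1d", "7d", "FULL"],  # placeholder window sizes
        task_df_name="readmission_30d",     # hypothetical task dataframe name
        do_cache_filtered_task=True,
        overwrite_cache_filtered_task=True,
    )
    train_lf = flat_reps["train"]  # result is presumably keyed by split, per the dict[str, pl.LazyFrame] return type
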
@@ -152,6 +155,7 @@ def load_flat_rep(
         static_df = pl.scan_parquet(flat_dir / "static" / sp / "*.parquet")
         if task_df_name is not None:
+            static_df = static_df.cast({"subject_id": sp_join_df.select('subject_id').dtypes[0]})
             static_df = static_df.join(sp_join_df.select("subject_id").unique(), on="subject_id", how="inner")
 
         dfs = []
@@ -170,7 +174,7 @@ def load_flat_rep(
             if task_df_name is not None:
                 fn = fp.parts[-1]
                 cached_fp = task_window_dir / fn
-                if cached_fp.is_file():
+                if cached_fp.is_file() and not overwrite_cache_filtered_task:
                     df = pl.scan_parquet(cached_fp).select("subject_id", "timestamp", *window_features)
                     if subjects_included.get(sp, None) is not None:
                         subjects = list(set(subjects).intersection(subjects_included[sp]))
@@ -181,8 +185,13 @@ def load_flat_rep(
             df = pl.scan_parquet(fp)
             if task_df_name is not None:
                 filter_join_df = sp_join_df.select(join_keys).filter(pl.col("subject_id").is_in(subjects))
-
-                df = df.join(filter_join_df, on=join_keys, how="inner")
+                df = df.cast({"subject_id": filter_join_df.select('subject_id').dtypes[0]})
+                df = filter_join_df.join_asof(
+                    df,
+                    by="subject_id",
+                    on="timestamp",
+                    strategy="forward" if "-" in window_size else "backward",
+                )
 
                 if do_cache_filtered_task:
                     cached_fp.parent.mkdir(exist_ok=True, parents=True)
@@ -195,7 +204,7 @@ def load_flat_rep(
 
             window_dfs.append(df)
 
-        dfs.append(pl.concat(window_dfs, how="vertical"))
+        dfs.append(pl.concat(window_dfs, how="vertical_relaxed"))
 
     joined_df = dfs[0]
     for jdf in dfs[1:]:
diff --git a/EventStream/data/config.py b/EventStream/data/config.py
index 017eb56f..5d246fc9 100644
--- a/EventStream/data/config.py
+++ b/EventStream/data/config.py
@@ -803,6 +803,10 @@ class PytorchDatasetConfig(JSONableMixin):
             training subset. If `None` or "FULL", then the full training data is used.
         train_subset_seed: If the training data should be subsampled randomly, this specifies the seed for
             that random subsampling.
+        tuning_subset_size: If the tuning data should be subsampled randomly, this specifies the size of the
+            tuning subset. If `None` or "FULL", then the full tuning data is used.
+        tuning_subset_seed: If the tuning data should be subsampled randomly, this specifies the seed for
+            that random subsampling.
         task_df_name: If the raw dataset should be limited to a task dataframe view, this specifies the name
             of the task dataframe, and indirectly the path on disk from where that task dataframe will be
             read (save_dir / "task_dfs" / f"{task_df_name}.parquet").
@@ -873,6 +877,8 @@ class PytorchDatasetConfig(JSONableMixin):
     train_subset_size: int | float | str = "FULL"
     train_subset_seed: int | None = None
+    tuning_subset_size: int | float | str = "FULL"
+    tuning_subset_seed: int | None = None
     task_df_name: str | None = None
 
@@ -907,6 +913,22 @@ def __post_init__(self):
                 pass
             case _:
                 raise TypeError(f"train_subset_size is of unrecognized type {type(self.train_subset_size)}.")
+
+        match self.tuning_subset_size:
+            case int() as n if n < 0:
+                raise ValueError(f"If integral, tuning_subset_size must be positive! Got {n}")
+            case float() as frac if frac <= 0 or frac >= 1:
+                raise ValueError(f"If float, tuning_subset_size must be in (0, 1)! Got {frac}")
+            case int() | float() if (self.tuning_subset_seed is None):
+                seed = int(random.randint(1, int(1e6)))
+                print(f"WARNING! tuning_subset_size is set, but tuning_subset_seed is not. Setting to {seed}")
+                self.tuning_subset_seed = seed
+            case None | "FULL" | int() | float():
+                pass
+            case _:
+                raise TypeError(
+                    f"tuning_subset_size is of unrecognized type {type(self.tuning_subset_size)}."
+                )
 
     def to_dict(self) -> dict:
         """Represents this configuration object as a plain dictionary."""
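To illustrate the new tuning-subset options, a sketch only: `save_dir` is assumed to be an existing field of `PytorchDatasetConfig`, and all values are placeholders; the new fields mirror the existing `train_subset_*` options.

    from pathlib import Path

    from EventStream.data.config import PytorchDatasetConfig

    config = PytorchDatasetConfig(
        save_dir=Path("/path/to/dataset"),  # assumed existing field
        train_subset_size=0.5,
        train_subset_seed=1,
        tuning_subset_size=0.1,  # float in (0, 1), an integer count, or "FULL"
        tuning_subset_seed=42,   # if omitted while a subset size is set, __post_init__ picks one and warns
    )
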
Setting to {seed}") + self.tuning_subset_seed = seed + case None | "FULL" | int() | float(): + pass + case _: + raise TypeError( + f"tuning_subset_size is of unrecognized type {type(self.tuning_subset_size)}." + ) def to_dict(self) -> dict: """Represents this configuration object as a plain dictionary.""" diff --git a/EventStream/data/dataset_base.py b/EventStream/data/dataset_base.py index d324e499..d5405e00 100644 --- a/EventStream/data/dataset_base.py +++ b/EventStream/data/dataset_base.py @@ -223,17 +223,21 @@ def build_event_and_measurement_dfs( all_events_and_measurements = [] event_types = [] - for df, schemas in schemas_by_df.items(): + for df_name, schemas in schemas_by_df.items(): all_columns = [] all_columns.extend(itertools.chain.from_iterable(s.columns_to_load for s in schemas)) try: - df = cls._load_input_df(df, all_columns, subject_id_col, subject_ids_map, subject_id_dtype) + df = cls._load_input_df( + df_name, all_columns, subject_id_col, subject_ids_map, subject_id_dtype + ) except Exception as e: - raise ValueError(f"Errored while loading {df}") from e + raise ValueError(f"Errored while loading {df_name}") from e - for schema in schemas: + for schema in tqdm( + schemas, desc=f"Processing events and measurements df for {df_name.split('/')[-1]}" + ): if schema.filter_on: df = cls._filter_col_inclusion(schema.filter_on) match schema.type: @@ -266,7 +270,10 @@ def build_event_and_measurement_dfs( all_events, all_measurements = [], [] running_event_id_max = 0 - for event_type, (events, measurements) in zip(event_types, all_events_and_measurements): + for event_type, (events, measurements) in tqdm( + zip(event_types, all_events_and_measurements), + desc="Incrementing and combining events and measurements", + ): try: new_events = cls._inc_df_col(events, "event_id", running_event_id_max) except Exception as e: diff --git a/EventStream/data/dataset_polars.py b/EventStream/data/dataset_polars.py index 9be90299..ad1e63b2 100644 --- a/EventStream/data/dataset_polars.py +++ b/EventStream/data/dataset_polars.py @@ -705,7 +705,7 @@ def _update_subject_event_properties(self): ) n_events_pd = self.events_df.get_column("subject_id").value_counts(sort=False).to_pandas() - self.n_events_per_subject = n_events_pd.set_index("subject_id")["counts"].to_dict() + self.n_events_per_subject = n_events_pd.set_index("subject_id")["count"].to_dict() self.subject_ids = set(self.n_events_per_subject.keys()) if self.subjects_df is not None: @@ -1105,7 +1105,7 @@ def _fit_vocabulary(self, measure: str, config: MeasurementConfig, source_df: DF try: value_counts = observations.value_counts() vocab_elements = value_counts.get_column(measure).to_list() - el_counts = value_counts.get_column("counts") + el_counts = value_counts.get_column("count") return Vocabulary(vocabulary=vocab_elements, obs_frequencies=el_counts) except AssertionError as e: raise AssertionError(f"Failed to build vocabulary for {measure}") from e @@ -1417,7 +1417,10 @@ def _summarize_static_measurements( if include_only_subjects is None: df = self.subjects_df else: - df = self.subjects_df.filter(pl.col("subject_id").is_in(list(include_only_subjects))) + self.subjects_df = self.subjects_df.with_columns(pl.col("subject_id").cast(pl.Utf8)) + df = self.subjects_df.filter( + pl.col("subject_id").is_in([str(id) for id in include_only_subjects]) + ) valid_measures = {} for feat_col in feature_columns: @@ -1477,7 +1480,8 @@ def _summarize_time_dependent_measurements( if include_only_subjects is None: df = self.events_df else: - df = 
self.events_df.filter(pl.col("subject_id").is_in(list(include_only_subjects))) + self.events_df = self.events_df.with_columns(pl.col("subject_id").cast(pl.Utf8)) + df = self.events_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects])) valid_measures = {} for feat_col in feature_columns: @@ -1540,10 +1544,11 @@ def _summarize_dynamic_measurements( if include_only_subjects is None: df = self.dynamic_measurements_df else: + self.events_df = self.events_df.with_columns(pl.col("subject_id").cast(pl.Utf8)) df = self.dynamic_measurements_df.join( - self.events_df.filter(pl.col("subject_id").is_in(list(include_only_subjects))).select( - "event_id" - ), + self.events_df.filter( + pl.col("subject_id").is_in([str(id) for id in include_only_subjects]) + ).select("event_id"), on="event_id", how="inner", ) diff --git a/scripts/build_flat_reps.py b/scripts/build_flat_reps.py new file mode 100755 index 00000000..536c8def --- /dev/null +++ b/scripts/build_flat_reps.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python +"""Builds a flat representation dataset given a hydra config file.""" + +try: + import stackprinter + + stackprinter.set_excepthook(style="darkbg2") +except ImportError: + pass # no need to fail because of missing dev dependency + +from pathlib import Path + +import hydra +from omegaconf import DictConfig + +from EventStream.data.dataset_polars import Dataset + + +@hydra.main(version_base=None, config_path="../configs", config_name="dataset_base") +def main(cfg: DictConfig): + cfg = hydra.utils.instantiate(cfg, _convert_="all") + save_dir = Path(cfg.pop("save_dir")) + window_sizes = cfg.pop("window_sizes") + subjects_per_output_file = ( + cfg.pop("subjects_per_output_file") if "subjects_per_output_file" in cfg else None + ) + + # Build flat reps for specified task and window sizes + ESD = Dataset.load(save_dir) + feature_inclusion_frequency, include_only_measurements = ESD._resolve_flat_rep_cache_params( + feature_inclusion_frequency=None, include_only_measurements=None + ) + cache_kwargs = dict( + subjects_per_output_file=subjects_per_output_file, + feature_inclusion_frequency=feature_inclusion_frequency, # 0.1 + window_sizes=window_sizes, + include_only_measurements=include_only_measurements, + do_overwrite=False, + do_update=True, + ) + ESD.cache_flat_representation(**cache_kwargs) + + +if __name__ == "__main__": + main()