Adjusted join in flat reps to account for different timestamps with t… #107

Open
wants to merge 8 commits into base: main
14 changes: 11 additions & 3 deletions EventStream/baseline/FT_task_baseline.py
@@ -46,6 +46,7 @@ def load_flat_rep(
do_update_if_missing: bool = True,
task_df_name: str | None = None,
do_cache_filtered_task: bool = True,
overwrite_cache_filtered_task: bool = False,

Ensure proper documentation for the new parameter.

The new parameter overwrite_cache_filtered_task should be included in the function's docstring to maintain comprehensive documentation.

+        overwrite_cache_filtered_task: If `True`, the flat representations will be regenerated. If `False`, the cached file will be loaded if it exists.
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
overwrite_cache_filtered_task: bool = False,

subjects_included: dict[str, set[int]] | None = None,
) -> dict[str, pl.LazyFrame]:
"""Loads a set of flat representations from a passed dataset that satisfy the given constraints.
@@ -67,14 +68,16 @@ def load_flat_rep(
do_update_if_missing: If `True`, then if any window sizes or features are missing, the function will
try to update the stored flat representations to reflect these. If `False`, if information is
missing, it will raise a `FileNotFoundError` instead.
task_df_name: If specified, the flat representations loaded will be (inner) joined against the task
task_df_name: If specified, the flat representations loaded will be joined against the task
dataframe of this name on the columns ``"subject_id"`` and ``"end_time"`` (which will be renamed
to ``"timestamp"``). This is to avoid needing to load the full dataset in flattened form into
memory. This is also used as a cache key; if a pre-filtered dataset is written to disk at a
specified path for this task, then the data will be loaded from there, rather than from the base
dataset.
do_cache_filtered_task: If `True`, the flat representations will, after being filtered to just the
relevant rows for the task, be cached to disk for faster re-use.
overwrite_cache_filtered_task: If `True`, the flat representations will be regenerated. If `False`,
the cached file will be loaded if it exists.
Comment on lines +79 to +80

Tip

Codebase Verification

Ensure all calls to load_flat_rep include the new parameter overwrite_cache_filtered_task.

The following calls to load_flat_rep in sample_data/examine_synthetic_data.ipynb do not include the new parameter and should be updated:

  • flat_reps = load_flat_rep(ESD, window_sizes=['7d'])
  • flat_reps = load_flat_rep(ESD, window_sizes=['FULL'], feature_inclusion_frequency=0.001)
  • flat_reps = load_flat_rep(ESD, window_sizes=['1d', '7d', 'FULL'])
  • flat_reps = load_flat_rep(ESD, window_sizes=['2d'], do_update_if_missing=False)
  • flat_reps = load_flat_rep(ESD, window_sizes=['1d', '7d', 'FULL'], subjects_included={'train': {0, 1, 2}})
  • flat_reps = load_flat_rep(ESD, window_sizes=['7d'], task_df_name='single_label_binary_classification')
  • flat_reps = load_flat_rep(ESD, window_sizes=['FULL', '1d'], task_df_name='multi_class_classification', do_cache_filtered_task=False)
  • flat_reps = load_flat_rep(ESD, window_sizes=['FULL', '1d'], task_df_name='single_label_binary_classification', subjects_included={'train': {0, 1, 2}})

Please update these calls to include the overwrite_cache_filtered_task parameter as needed.
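For illustration only, here is what an opted-in call could look like. `load_flat_rep` and `ESD` are replaced by mocks below purely so the sketch runs standalone; in the notebook they come from `EventStream.baseline.FT_task_baseline` and the loaded dataset, respectively.

```python
from unittest import mock

# Stand-ins so this sketch is runnable outside the repo; in the notebook,
# `load_flat_rep` is the real function and `ESD` the loaded dataset.
load_flat_rep = mock.MagicMock(name="load_flat_rep")
ESD = mock.MagicMock(name="ESD")

# Existing calls keep working unchanged, since the new flag defaults to
# False; callers that want to rebuild a stale cache opt in explicitly:
flat_reps = load_flat_rep(
    ESD,
    window_sizes=["7d"],
    task_df_name="single_label_binary_classification",
    overwrite_cache_filtered_task=True,
)
```

Because the parameter is keyword-only with a default of `False`, the existing notebook calls are not strictly broken; updating them is only needed where regeneration is actually wanted.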

Analysis chain

LGTM! But verify the new parameter usage in the codebase.

The code changes are approved.

However, ensure that all function calls to load_flat_rep are updated to use the new parameter if necessary.

Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify all function calls to `load_flat_rep` match the new signature.

# Test: Search for the function usage. Expect: Only occurrences of the new signature.
rg --type python $'load_flat_rep'

Length of output: 63 (likely just an error message: ripgrep's built-in type name for Python files is `py`, so `--type python` is not recognized; the follow-up script below drops the flag)


Script:

#!/bin/bash
# Description: Verify all function calls to `load_flat_rep` match the new signature.

# Test: Search for the function usage. Expect: Only occurrences of the new signature.
rg 'load_flat_rep'

Length of output: 1215


Script:

#!/bin/bash
# Description: Extract lines around each call to `load_flat_rep` to verify the usage of the new parameter.

# Extract lines around the calls in the Jupyter notebook
rg -A 3 -B 3 'load_flat_rep' sample_data/examine_synthetic_data.ipynb

# Extract lines around the calls in the Python file
rg -A 3 -B 3 'load_flat_rep' EventStream/baseline/FT_task_baseline.py

Length of output: 3273

subjects_included: A dictionary by split of the subjects to include in the task. Omitted splits are
used wholesale.

@@ -170,7 +173,7 @@ def load_flat_rep(
if task_df_name is not None:
fn = fp.parts[-1]
cached_fp = task_window_dir / fn
if cached_fp.is_file():
if cached_fp.is_file() and not overwrite_cache_filtered_task:
df = pl.scan_parquet(cached_fp).select("subject_id", "timestamp", *window_features)
if subjects_included.get(sp, None) is not None:
subjects = list(set(subjects).intersection(subjects_included[sp]))
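The semantics of the guard added in this hunk can be restated as a tiny predicate (the helper name is mine, not the PR's): reuse the cached task-filtered file only when it exists and the caller has not requested regeneration.

```python
from pathlib import Path

def should_use_cache(cached_fp: Path, overwrite_cache_filtered_task: bool) -> bool:
    # Mirrors the diff's condition:
    #   if cached_fp.is_file() and not overwrite_cache_filtered_task:
    # i.e. load from cache only when the file exists AND regeneration
    # was not explicitly requested.
    return cached_fp.is_file() and not overwrite_cache_filtered_task
```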
@@ -182,7 +185,12 @@ def load_flat_rep(
if task_df_name is not None:
filter_join_df = sp_join_df.select(join_keys).filter(pl.col("subject_id").is_in(subjects))

df = df.join(filter_join_df, on=join_keys, how="inner")
df = filter_join_df.join_asof(
df,
by="subject_id",
on="timestamp",
strategy="forward" if "-" in window_size else "backward",
)

if do_cache_filtered_task:
cached_fp.parent.mkdir(exist_ok=True, parents=True)