Skip to content

Commit

Permalink
Merge branch 'join_flat_reps_with_taskdf' of https://github.com/mmcde…
Browse files Browse the repository at this point in the history
…rmott/EventStreamGPT into join_flat_reps_with_taskdf
  • Loading branch information
pargaw committed Jun 11, 2024
2 parents 9e0acf7 + b006195 commit 22dca2d
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions EventStream/baseline/FT_task_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def load_flat_rep(
dataset.
do_cache_filtered_task: If `True`, the flat representations will, after being filtered to just the
relevant rows for the task, be cached to disk for faster re-use.
overwrite_cache_filtered_task: If `True`, the flat representations will be regenerated. If `False`,
overwrite_cache_filtered_task: If `True`, the flat representations will be regenerated. If `False`,
the cached file will be loaded if exists.
subjects_included: A dictionary by split of the subjects to include in the task. Omitted splits are
used wholesale.
Expand Down Expand Up @@ -185,8 +185,12 @@ def load_flat_rep(
if task_df_name is not None:
filter_join_df = sp_join_df.select(join_keys).filter(pl.col("subject_id").is_in(subjects))

df = filter_join_df.join_asof(df, by='subject_id', on='timestamp',
strategy='forward' if '-' in window_size else 'backward')
df = filter_join_df.join_asof(
df,
by="subject_id",
on="timestamp",
strategy="forward" if "-" in window_size else "backward",
)

if do_cache_filtered_task:
cached_fp.parent.mkdir(exist_ok=True, parents=True)
Expand Down

0 comments on commit 22dca2d

Please sign in to comment.