Skip to content

Commit

Permalink
This may have fixed it.
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Jun 22, 2024
1 parent 1b4f0d8 commit 39ba674
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
13 changes: 9 additions & 4 deletions EventStream/baseline/FT_task_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,9 @@ def load_flat_rep(

by_split = {}
for sp, all_sp_subjects in ESD.split_subjects.items():
all_sp_subjects = pl.Series(list(all_sp_subjects)).cast(ESD.subject_id_dtype)
if task_df_name is not None:
sp_join_df = join_df.filter(pl.col("subject_id").is_in(list(all_sp_subjects)))
sp_join_df = join_df.filter(pl.col("subject_id").is_in(all_sp_subjects))

static_df = pl.scan_parquet(flat_dir / "static" / sp / "*.parquet")
if task_df_name is not None:
Expand All @@ -175,13 +176,17 @@ def load_flat_rep(
df = pl.scan_parquet(cached_fp).select("subject_id", "timestamp", *window_features)
if subjects_included.get(sp, None) is not None:
subjects = list(set(subjects).intersection(subjects_included[sp]))
df = df.filter(pl.col("subject_id").is_in(subjects))
df = df.filter(
pl.col("subject_id").is_in(pl.Series(subjects).cast(ESD.subject_id_dtype))
)
window_dfs.append(df)
continue

df = pl.scan_parquet(fp)
if task_df_name is not None:
filter_join_df = sp_join_df.select(join_keys).filter(pl.col("subject_id").is_in(subjects))
filter_join_df = sp_join_df.select(join_keys).filter(
pl.col("subject_id").is_in(pl.Series(subjects).cast(ESD.subject_id_dtype))
)

df = df.join(filter_join_df, on=join_keys, how="inner")

Expand All @@ -193,7 +198,7 @@ def load_flat_rep(
df = df.select("subject_id", "timestamp", *window_features)
if subjects_included.get(sp, None) is not None:
subjects = list(set(subjects).intersection(subjects_included[sp]))
df = df.filter(pl.col("subject_id").is_in(subjects))
df = df.filter(pl.col("subject_id").is_in(pl.Series(subjects).cast(ESD.subject_id_dtype)))

window_dfs.append(df)

Expand Down
8 changes: 4 additions & 4 deletions tests/test_e2e_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,17 @@ def _test_dataset_output(self, raw_data_root: Path, dataset_save_dir: Path):
tuning_DL_subjects = set(tuning_DL_reps["subject_id"].unique().to_list())
held_out_DL_subjects = set(held_out_DL_reps["subject_id"].unique().to_list())

all_DL_subjects = train_DL_subjects | tuning_DL_subjects | held_out_DL_subjects

self.assertEqual(all_DL_subjects, all_subjects)

self.assertEqual(len(train_DL_subjects & tuning_DL_subjects), 0)
self.assertEqual(len(train_DL_subjects & held_out_DL_subjects), 0)
self.assertEqual(len(tuning_DL_subjects & held_out_DL_subjects), 0)

self.assertTrue(len(train_DL_subjects) > len(tuning_DL_subjects))
self.assertTrue(len(train_DL_subjects) > len(held_out_DL_subjects))

all_DL_subjects = train_DL_subjects | tuning_DL_subjects | held_out_DL_subjects

self.assertEqual(all_DL_subjects, all_subjects)

def _test_command(self, command_parts: list[str], case_name: str, use_subtest: bool = True):
if use_subtest:
with self.subTest(case_name):
Expand Down

0 comments on commit 39ba674

Please sign in to comment.