diff --git a/EventStream/baseline/FT_task_baseline.py b/EventStream/baseline/FT_task_baseline.py index ddf4fdb1..928b0dd3 100644 --- a/EventStream/baseline/FT_task_baseline.py +++ b/EventStream/baseline/FT_task_baseline.py @@ -148,8 +148,9 @@ def load_flat_rep( by_split = {} for sp, all_sp_subjects in ESD.split_subjects.items(): + all_sp_subjects = pl.Series(list(all_sp_subjects)).cast(ESD.subject_id_dtype) if task_df_name is not None: - sp_join_df = join_df.filter(pl.col("subject_id").is_in(list(all_sp_subjects))) + sp_join_df = join_df.filter(pl.col("subject_id").is_in(all_sp_subjects)) static_df = pl.scan_parquet(flat_dir / "static" / sp / "*.parquet") if task_df_name is not None: @@ -175,13 +176,17 @@ def load_flat_rep( df = pl.scan_parquet(cached_fp).select("subject_id", "timestamp", *window_features) if subjects_included.get(sp, None) is not None: subjects = list(set(subjects).intersection(subjects_included[sp])) - df = df.filter(pl.col("subject_id").is_in(subjects)) + df = df.filter( + pl.col("subject_id").is_in(pl.Series(subjects).cast(ESD.subject_id_dtype)) + ) window_dfs.append(df) continue df = pl.scan_parquet(fp) if task_df_name is not None: - filter_join_df = sp_join_df.select(join_keys).filter(pl.col("subject_id").is_in(subjects)) + filter_join_df = sp_join_df.select(join_keys).filter( + pl.col("subject_id").is_in(pl.Series(subjects).cast(ESD.subject_id_dtype)) + ) df = df.join(filter_join_df, on=join_keys, how="inner") @@ -193,7 +198,7 @@ def load_flat_rep( df = df.select("subject_id", "timestamp", *window_features) if subjects_included.get(sp, None) is not None: subjects = list(set(subjects).intersection(subjects_included[sp])) - df = df.filter(pl.col("subject_id").is_in(subjects)) + df = df.filter(pl.col("subject_id").is_in(pl.Series(subjects).cast(ESD.subject_id_dtype))) window_dfs.append(df) diff --git a/tests/test_e2e_runs.py b/tests/test_e2e_runs.py index 4a32e985..a4280ac0 100644 --- a/tests/test_e2e_runs.py +++ b/tests/test_e2e_runs.py @@ -81,10 +81,6 @@ def _test_dataset_output(self, raw_data_root: Path, dataset_save_dir: Path): tuning_DL_subjects = set(tuning_DL_reps["subject_id"].unique().to_list()) held_out_DL_subjects = set(held_out_DL_reps["subject_id"].unique().to_list()) - all_DL_subjects = train_DL_subjects | tuning_DL_subjects | held_out_DL_subjects - - self.assertEqual(all_DL_subjects, all_subjects) - self.assertEqual(len(train_DL_subjects & tuning_DL_subjects), 0) self.assertEqual(len(train_DL_subjects & held_out_DL_subjects), 0) self.assertEqual(len(tuning_DL_subjects & held_out_DL_subjects), 0) @@ -92,6 +88,10 @@ def _test_dataset_output(self, raw_data_root: Path, dataset_save_dir: Path): self.assertTrue(len(train_DL_subjects) > len(tuning_DL_subjects)) self.assertTrue(len(train_DL_subjects) > len(held_out_DL_subjects)) + all_DL_subjects = train_DL_subjects | tuning_DL_subjects | held_out_DL_subjects + + self.assertEqual(all_DL_subjects, all_subjects) + def _test_command(self, command_parts: list[str], case_name: str, use_subtest: bool = True): if use_subtest: with self.subTest(case_name):