Skip to content

Commit

Permalink
Adjusted loading flat reps code to account for differences in subject…
Browse files Browse the repository at this point in the history
…_id dtype
  • Loading branch information
pargaw committed Jul 8, 2024
1 parent 5c6cb4b commit 69b99ce
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions EventStream/baseline/FT_task_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def load_flat_rep(

static_df = pl.scan_parquet(flat_dir / "static" / sp / "*.parquet")
if task_df_name is not None:
static_df = static_df.cast({"subject_id": sp_join_df.select('subject_id').dtypes[0]})
static_df = static_df.join(sp_join_df.select("subject_id").unique(), on="subject_id", how="inner")

dfs = []
Expand Down Expand Up @@ -184,7 +185,7 @@ def load_flat_rep(
df = pl.scan_parquet(fp)
if task_df_name is not None:
filter_join_df = sp_join_df.select(join_keys).filter(pl.col("subject_id").is_in(subjects))

df = df.cast({"subject_id": filter_join_df.select('subject_id').dtypes[0]})
df = filter_join_df.join_asof(
df,
by="subject_id",
Expand All @@ -203,7 +204,7 @@ def load_flat_rep(

window_dfs.append(df)

dfs.append(pl.concat(window_dfs, how="vertical"))
dfs.append(pl.concat(window_dfs, how="vertical_relaxed"))

joined_df = dfs[0]
for jdf in dfs[1:]:
Expand Down

0 comments on commit 69b99ce

Please sign in to comment.