Skip to content

Commit

Permalink
Things are partially improved, but other tests are still failing. Investigating
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Jun 22, 2024
1 parent 9acff54 commit 1b4f0d8
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions EventStream/data/dataset_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,7 @@ def _agg_by_time(self):
)

def _update_subject_event_properties(self):
self.subject_id_dtype = self.events_df.schema["subject_id"]
if self.events_df is not None:
logger.debug("Collecting event types")
self.event_types = (
Expand All @@ -700,16 +701,18 @@ def _update_subject_event_properties(self):
)

n_events = self.events_df.group_by("subject_id").agg(pl.len().alias("count"))
# here we cast to str to avoid issues with the subject_id column being various other types as we
# will eventually JSON serialize it.
n_events = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8))
self.n_events_per_subject = {
subject_id: count for subject_id, count in zip(n_events["subject_id"], n_events["count"])
}
self.subject_ids = set(self.n_events_per_subject.keys())

if self.subjects_df is not None:
logger.debug("Collecting subject event counts")
subjects_with_no_events = (
set(self.subjects_df.get_column("subject_id").to_list()) - self.subject_ids
)
subjects_df_subjects = self.subjects_df.select(pl.col("subject_id").cast(pl.Utf8))
subjects_with_no_events = set(subjects_df_subjects["subject_id"].to_list()) - self.subject_ids
for sid in subjects_with_no_events:
self.n_events_per_subject[sid] = 0
self.subject_ids.update(subjects_with_no_events)
Expand All @@ -725,7 +728,18 @@ def _filter_col_inclusion(cls, df: DF_T, col_inclusion_targets: dict[str, bool |
filter_exprs.append(pl.col(col).is_null())
case _:
try:
incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])
logger.debug(
f"Converting inclusion targets of type {type(list(incl_targets)[0])} for "
f"{col} to {df.schema[col]}"
)
if isinstance(list(incl_targets)[0], str):
incl_list = pl.Series(list(incl_targets), dtype=pl.Utf8)
else:
incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])

incl_list = incl_list.cast(df.schema[col])

logger.debug(f"Converted to Series of type {incl_list.dtype}")
except TypeError as e:
incl_targets_by_type = defaultdict(list)
for t in incl_targets:
Expand Down

0 comments on commit 1b4f0d8

Please sign in to comment.