diff --git a/EventStream/data/dataset_polars.py b/EventStream/data/dataset_polars.py index ed93790c..a7f03e42 100644 --- a/EventStream/data/dataset_polars.py +++ b/EventStream/data/dataset_polars.py @@ -690,6 +690,7 @@ def _agg_by_time(self): ) def _update_subject_event_properties(self): + self.subject_id_dtype = self.events_df.schema["subject_id"] if self.events_df is not None: logger.debug("Collecting event types") self.event_types = ( @@ -700,6 +701,9 @@ def _update_subject_event_properties(self): ) n_events = self.events_df.group_by("subject_id").agg(pl.len().alias("count")) + # here we cast to str to avoid issues with the subject_id column being various other types as we + # will eventually JSON serialize it. + n_events = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8)) self.n_events_per_subject = { subject_id: count for subject_id, count in zip(n_events["subject_id"], n_events["count"]) } @@ -707,9 +711,8 @@ def _update_subject_event_properties(self): if self.subjects_df is not None: logger.debug("Collecting subject event counts") - subjects_with_no_events = ( - set(self.subjects_df.get_column("subject_id").to_list()) - self.subject_ids - ) + subjects_df_subjects = self.subjects_df.select(pl.col("subject_id").cast(pl.Utf8)) + subjects_with_no_events = set(subjects_df_subjects["subject_id"].to_list()) - self.subject_ids for sid in subjects_with_no_events: self.n_events_per_subject[sid] = 0 self.subject_ids.update(subjects_with_no_events) @@ -725,7 +728,18 @@ def _filter_col_inclusion(cls, df: DF_T, col_inclusion_targets: dict[str, bool | filter_exprs.append(pl.col(col).is_null()) case _: try: - incl_list = pl.Series(list(incl_targets), dtype=df.schema[col]) + logger.debug( + f"Converting inclusion targets of type {type(list(incl_targets)[0])} for " + f"{col} to {df.schema[col]}" + ) + if isinstance(list(incl_targets)[0], str): + incl_list = pl.Series(list(incl_targets), dtype=pl.Utf8) + else: + incl_list = pl.Series(list(incl_targets), dtype=df.schema[col]) + + incl_list = incl_list.cast(df.schema[col]) + + logger.debug(f"Converted to Series of type {incl_list.dtype}") except TypeError as e: incl_targets_by_type = defaultdict(list) for t in incl_targets: