diff --git a/src/ionbeam/aggregators/by_time_2.py b/src/ionbeam/aggregators/by_time_2.py index 7ff120d..6fd6291 100644 --- a/src/ionbeam/aggregators/by_time_2.py +++ b/src/ionbeam/aggregators/by_time_2.py @@ -12,7 +12,7 @@ import logging from collections import defaultdict from dataclasses import field -from datetime import datetime, time, timedelta, timezone +from datetime import datetime, time, timedelta from typing import Iterable, Literal import pandas as pd @@ -29,6 +29,7 @@ @dataclasses.dataclass class NewTimeAggregator(Aggregator): "How much time a data granule represents." + granularity_hours: int = 1 "Is data arriving chronologically or the reverse?" @@ -67,32 +68,44 @@ def init(self, globals): def bucket_to_message(self, bucket): data = [m.data for m in bucket] + print(data) return TabularMessage( metadata=self.metadata, data=pd.concat(data), ) def update_time_frontier(self, start_time, end_time): - assert start_time.tzinfo is not None - assert end_time.tzinfo is not None + if start_time.tzinfo is not None: + start_time = start_time.tz_convert("UTC").tz_localize(None) + if end_time.tzinfo is not None: + end_time = end_time.tz_convert("UTC").tz_localize(None) if self.time_frontier is None: - self.time_frontier = start_time if self.time_direction == "forwards" else end_time + self.time_frontier = ( + start_time if self.time_direction == "forwards" else end_time + ) return if self.time_direction == "forwards": self.time_frontier = max(self.time_frontier, start_time) elif self.time_direction == "backwards": self.time_frontier = min(self.time_frontier, end_time) - def process(self, message: TabularMessage | FinishMessage) -> Iterable[TabularMessage]: + def process( + self, message: TabularMessage | FinishMessage + ) -> Iterable[TabularMessage]: if isinstance(message, FinishMessage): for _, bucket in self.time_chunks.items(): yield self.bucket_to_message(bucket) return message_timedelta = message.data.time.max() - message.data.time.min() - self.min_emit_after_hours = max(self.min_emit_after_hours, message_timedelta.total_seconds() / 60**2) - self.update_time_frontier(start_time=message.data.time.min(), end_time=message.data.time.max()) + self.min_emit_after_hours = max( + self.min_emit_after_hours, message_timedelta.total_seconds() / 60**2 + ) + self.update_time_frontier( + start_time=message.data.time.min(), + end_time=message.data.time.max(), + ) print( f""" @@ -107,11 +120,13 @@ def process(self, message: TabularMessage | FinishMessage) -> Iterable[TabularMe # convert timezone to utc and convert to time naive message.data.time = message.data.time.dt.tz_convert("UTC").dt.tz_localize(None) - for (date, hour), data_chunk in message.data.groupby([message.data.time.dt.date, message.data.time.dt.hour]): + for (date, hour), data_chunk in message.data.groupby( + [message.data.time.dt.date, message.data.time.dt.hour] + ): if data_chunk.empty: continue chunked_message = dataclasses.replace(message, data=data_chunk) - start_time = datetime.combine(date, time(hour=hour, tzinfo=timezone.utc)) + start_time = datetime.combine(date, time(hour=hour, tzinfo=None)) self.time_chunks[start_time].append(chunked_message) assert self.time_frontier is not None