
Commit

Fix sck source to create empty chunks when there's no data
TomHodson committed Jan 23, 2025
1 parent eceb466 commit 179ccf1
Showing 3 changed files with 95 additions and 45 deletions.
27 changes: 27 additions & 0 deletions src/ionbeam/core/time.py
@@ -24,6 +24,29 @@ def __post_init__(self):
def union(self, other: Self) -> Self:
return type(self)(min(self.start, other.start), max(self.end, other.end))

def overlaps(self, other: Self) -> bool:
"""Return True if the time spans overlap"""
return self.start < other.end and self.end > other.start

@classmethod
def all_time(cls) -> Self:
"""Return a time span that covers all possible times"""
return cls(datetime.min.replace(tzinfo=UTC), datetime.max.replace(tzinfo=UTC))

@classmethod
def from_set(cls, times: Iterable[datetime | None]) -> Self | None:
"""Compute the union of a set of (possibly None) datetimes.
Return None if the set is empty or contains only None"""
times = [t for t in times if t is not None]
if not times:
return None
return cls(min(times), max(times))

@classmethod
def max(cls) -> Self:
"""Return a time span that covers all possible times"""
return cls(datetime.min.replace(tzinfo=UTC), datetime.max.replace(tzinfo=UTC))

@classmethod
def parse(cls, value: dict[str, str]) -> Self:
"""
@@ -118,6 +141,10 @@ def __contains__(self, other: datetime) -> bool:
if not other.tzinfo:
raise ValueError('Tried to do "o in TimeSpan()" where o was a naive datetime object')
return self.start <= other < self.end

def expand(self, dt: datetime) -> Self:
"""Expand the time span to include the given datetime"""
return type(self)(min(self.start, dt), max(self.end, dt))

def round_datetime(dt: datetime, round_to: timedelta, method: str = "floor") -> datetime:
if round_to.total_seconds() <= 0:
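
A minimal sketch (not part of this commit) of how the new TimeSpan helpers above might be used together; the concrete dates are invented and only the `TimeSpan(start, end)` constructor plus the `overlaps`, `union`, `from_set` and `expand` methods from this file are assumed:

```python
from datetime import UTC, datetime

a = TimeSpan(datetime(2025, 1, 1, tzinfo=UTC), datetime(2025, 1, 2, tzinfo=UTC))
b = TimeSpan(datetime(2025, 1, 1, 12, tzinfo=UTC), datetime(2025, 1, 3, tzinfo=UTC))

assert a.overlaps(b)               # the spans share twelve hours
merged = a.union(b)                # 2025-01-01 00:00 .. 2025-01-03 00:00

# from_set ignores Nones and returns None when nothing usable is left
assert TimeSpan.from_set([None, None]) is None
span = TimeSpan.from_set([a.start, None, b.end])   # same bounds as `merged`

# expand grows a span just enough to cover one extra datetime
wider = merged.expand(datetime(2025, 1, 5, tzinfo=UTC))   # end moves to Jan 5
```
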
47 changes: 33 additions & 14 deletions src/ionbeam/sources/API_sources_base.py
@@ -3,19 +3,19 @@
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from io import BytesIO
from time import time
from typing import Any, Iterable, Literal, Self
from urllib.parse import urljoin

import pandas as pd
import requests
from requests.adapters import HTTPAdapter
from sqlalchemy import Index, UniqueConstraint
from sqlalchemy.dialects.postgresql import JSONB, insert
from sqlalchemy.orm import Mapped, Session, load_only, mapped_column
from sqlalchemy_utils import UUIDType
from urllib3.util import Retry

from ..core.bases import TabularMessage
from ..core.source import Source
@@ -24,7 +24,6 @@
)
from ..metadata.db import Base
from ..singleprocess_pipeline import fmt_time
from time import time

logger = logging.getLogger(__name__)

@@ -58,7 +57,7 @@ def __init__(self, retry_at: datetime | None = None,
logger.debug("No Retry-After header, delaying for 5 minutes.")
self.retry_at = datetime.now(UTC) + timedelta(minutes=5)

@dataclasses.dataclass(eq=True, frozen=True)
@dataclass(eq=True, frozen=True)
class DataChunk:
"""
Represents a chunk of a stream of data that has been downloaded from an API and stored in the db
@@ -86,6 +85,18 @@ class DataChunk:
def __repr__(self) -> str:
return f"DataChunk({self.source}, {self.key}, {self.version}, {self.time_span}, {self.ingestion_time}, {self.success}, {self.empty})"

@classmethod
def make_empty_chunk(cls, data_stream: "DataStream", time_span) -> Self:
return cls(
source=data_stream.source,
key = data_stream.key,
version = data_stream.version,
empty = True,
time_span = time_span,
json = {},
data = None,
)

@classmethod
def make_error_chunk(cls, data_stream : "DataStream", time_span : TimeSpan, error : Exception, json = {}) -> Self:
json = json.copy()
@@ -201,7 +212,7 @@ def to_data_chunk(self, only_metadata = False) -> DataChunk:
data=pd.read_parquet(BytesIO(self.data)) if not only_metadata and self.data is not None else None
)

@dataclasses.dataclass(frozen=True)
@dataclass(frozen=True)
class DataStream:
"""
Represents a logical stream of data from an API with no time component.
@@ -227,7 +238,12 @@ def get_chunks(self, db_session: Session, time_span : TimeSpan,
success: bool | None = None,
ingested_after : datetime | None = None,
ingested_before : datetime | None = None,
empty : bool | None = False) -> Iterable[DataChunk]:

# Empty chunks serve as a sentinel that we have queried for data from this timespan but there wasn't any
# By default do not return empty chunks, but it's very important to set this to None when deciding
# what data to query for next
empty : bool | None = False
) -> Iterable[DataChunk]:

query = db_session.query(DBDataChunk).filter_by(
source=self.source,
@@ -375,7 +391,7 @@ def emit_messages(self, relevant_chunks : Iterable[DataChunk], time_span_group:
"""
pass

@dataclasses.dataclass
@dataclass
class APISource(Source, AbstractDataSourceMixin):
"""
The generic logic related to sources that periodically query API endpoints.
@@ -394,9 +410,9 @@ class APISource(Source, AbstractDataSourceMixin):
cache_version: int = 3 # increment this to invalidate the cache
use_cache: bool = True
source: str = "should be set by derived class"
maximum_request_size: timedelta = timedelta(days=1)
minimum_request_size: timedelta = timedelta(minutes=5)
max_time_downloading: timedelta = timedelta(seconds = 30)
maximum_request_size: timedelta = field(kw_only=True)
minimum_request_size: timedelta = field(kw_only=True)
max_time_downloading: timedelta = field(kw_only=True)

def init(self, globals, **kwargs):
super().init(globals, **kwargs)
@@ -428,7 +444,7 @@ def gaps_in_database(self, db_session: Session, data_stream: DataStream, time_span
"""Given a DataStream and a time span, return any gaps in the time span that need to be downloaded"""

# Get the time spans of all chunks that have been ingested successfully
chunks_time_spans = list(c.time_span for c in data_stream.get_chunks(db_session, time_span, success=True, empty=False, mode="metadata"))
chunks_time_spans = list(c.time_span for c in data_stream.get_chunks(db_session, time_span, success=True, mode="metadata", empty=None))
chunks_to_ingest = time_span.remove(chunks_time_spans)
split_chunks = [
c
@@ -457,6 +473,8 @@ def download_data(self, data_streams: Iterable[DataStream],
fail = False) -> Iterable[DataChunk]:

start_time = datetime.now(UTC)
logger.info(f"Starting download for source {self.source}")
logger.debug(f"{self.max_time_downloading = }, {self.maximum_request_size = }, {self.minimum_request_size = }")

with self.globals.sql_session.begin() as db_session:
# Check if the source has been rate limited recently
@@ -493,7 +511,7 @@
t0 = time()
data_chunk = self.download_chunk(data_stream, gap)
data_chunk.write_to_db(db_session)
logger.info(f"Downloaded data and wrote to db for stream {data_stream.key} in {fmt_time(time() - t0)}")
logger.info(f"Downloaded data and wrote to db for stream {data_stream.key} in {fmt_time(time() - t0)} {data_chunk.empty = }")
yield data_chunk

except ExistingDataException:
Expand All @@ -513,8 +531,9 @@ def download_data(self, data_streams: Iterable[DataStream],
# Todo: write a source level flag to indicate that we are rate limited
# And when we should next attempt this source
return
# return

def get_all_data_streams(self, db_session: Session) -> Iterable[DataStream]:
def get_all_data_streams(self, db_session: Session, timespan : TimeSpan | None = None) -> Iterable[DataStream]:
for ds in db_session.query(DBDataStream).filter_by(source=self.source).all():
yield ds.to_data_stream()

@@ -661,7 +680,7 @@ def generate(self) -> Iterable[TabularMessage]:



@dataclasses.dataclass
@dataclass
class RESTSource(APISource):
endpoint = "scheme://example.com/api/v1" # Override this in derived classes

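
A rough sketch (not from the repository) of the interaction the `empty` handling above is meant to support: empty chunks are sentinels recording that a span was already queried and contained no data, so they must count as covered time when deciding what to download next, while staying hidden from the default `get_chunks` call used for emitting data. Only `get_chunks` and `TimeSpan.remove` from this file are assumed; `spans_still_needed` is an invented name:

```python
def spans_still_needed(db_session, data_stream, wanted):
    # empty=None returns both real and empty chunks, so a span that was queried
    # before and found empty is treated as covered and not re-downloaded.
    covered = [
        chunk.time_span
        for chunk in data_stream.get_chunks(
            db_session, wanted, success=True, mode="metadata", empty=None
        )
    ]
    return wanted.remove(covered)

# The default empty=False hides the sentinel chunks when actually reading data:
# data_chunks = data_stream.get_chunks(db_session, wanted, success=True)
```
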
66 changes: 35 additions & 31 deletions src/ionbeam/sources/smart_citizen_kit/source.py
@@ -4,19 +4,18 @@
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from pathlib import Path
from time import sleep, time
from typing import Iterable
from unicodedata import normalize
from time import time

import numpy as np
import pandas as pd
from cachetools import TTLCache, cachedmethod
from cachetools.keys import hashkey

from ...core.bases import Mappings, RawVariable, TabularMessage, TimeSpan
from ..API_sources_base import DataChunk, DataStream, RESTSource
from ...singleprocess_pipeline import fmt_time
from time import sleep
from ..API_sources_base import DataChunk, DataStream, RESTSource

logger = logging.getLogger(__name__)

@@ -34,9 +33,10 @@ class SmartCitizenKitSource(RESTSource):
"""
mappings: Mappings = field(kw_only=True)

maximum_request_size = timedelta(days=10)
minimum_request_size = timedelta(minutes=5)
max_time_downloading = timedelta(seconds=60)
maximum_request_size: timedelta = timedelta(days=10)
minimum_request_size: timedelta = timedelta(minutes=5)
max_time_downloading: timedelta = timedelta(minutes=1)

cache_directory: Path = Path("inputs/smart_citizen_kit")
endpoint = "https://api.smartcitizen.me/v0"
cache = TTLCache(maxsize=1e5, ttl=20 * 60) # Cache API responses for 20 minutes
@@ -73,6 +73,7 @@ def get_readings(self, device_id : int, sensor_id : int, time_span: TimeSpan):
"function": "avg",
"from": time_span.start.isoformat(),
"to": time_span.end.isoformat(),
"all_intervals": "false",
}

return self.get(
@@ -101,21 +102,22 @@ def get_devices_in_date_range(self, time_span: TimeSpan) -> list[dict]:
devices = self.get_ICHANGE_devices()

def filter_by_dates(device):
if device["last_reading_at"] is None or device["created_at"] is None:
return False
earliest_reading = datetime.max.replace(tzinfo=UTC)
latest_reading = datetine.min.replace

for sensor in
last_reading_at = max(
datetime.fromisoformat(sensor["created_at"]) if datetime.min.replace(tzinfo=UTC)
)

device_start_date = datetime.fromisoformat(device["created_at"])
device_end_date = datetime.fromisoformat(device["last_reading_at"])
# see https://stackoverflow.com/questions/325933/determine-whether-two-date-ranges-overlap
return (device_start_date <= time_span.end) and (device_end_date >= time_span.start)
device_timespan = TimeSpan.from_set((
dt
for sensor in device["data"]["sensors"]
for dt in [
datetime.fromisoformat(sensor["created_at"]) if sensor["created_at"] is not None else None,
datetime.fromisoformat(sensor["last_reading_at"]) if sensor["last_reading_at"] is not None else None
]))
if device_timespan is None:
return False

device["timespan"] = device_timespan.as_json()

return device_timespan.overlaps(time_span)

devices_in_date_range = [d for d in devices if filter_by_dates(d)]

return devices_in_date_range
@@ -207,6 +209,13 @@ def get_data_streams(self, time_span: TimeSpan) -> Iterable[DataStream]:
def download_chunk(self, data_stream: DataStream, time_span: TimeSpan) -> DataChunk:
device = data_stream.data
device_id = device["id"]
device_timespan = time_span.from_json(device["timespan"])

# Quick exit if the time spans don't overlap
if not device_timespan.overlaps(time_span):
return DataChunk.make_empty_chunk(data_stream, time_span)


# logger.debug(f"Downloading data for device {device_id} in {time_span} with sensors {[s['name'] for s in device['data']['sensors']]}")
logger.debug(f"Downloading data for device {device_id} in {time_span}")

@@ -224,9 +233,9 @@ def download_chunk(self, data_stream: DataStream, time_span: TimeSpan) -> DataChunk:
readings = self.get_readings(device_id, sensor["id"], time_span)

# Try to reduce how often we get rate limited by SCK
sleep(0.5)
sleep(0.1)

# Skip if there are now readings
# Skip if there are no readings
if not readings["readings"]:
logger.debug(f"No readings returned for {sensor['name']}, even though the date metadata suggested there should be.")
continue
@@ -243,21 +252,12 @@ def download_chunk(self, data_stream: DataStream, time_span: TimeSpan) -> DataChunk:
readings = readings["readings"]
))

logger.debug(f"Got data for SCK device {device_id} in {fmt_time(time() - t0)}")
if not sensor_data or min_time is None or max_time is None:
# raise ValueError(f"No data for {device_id = } in {time_span = }")
logger.debug(f"No data for {device_id = } in {time_span = }")
return DataChunk(
source=self.source,
key = data_stream.key,
version = self.version,
empty = True,
time_span = time_span,
json = {},
data = None,
)

return DataChunk.make_empty_chunk(data_stream, time_span)

logger.debug(f"Got data for SCK device {device_id} in {fmt_time(time() - t0)}")
return DataChunk(
source=self.source,
key = data_stream.key,
@@ -356,6 +356,10 @@ def emit_messages(self, relevant_chunks : Iterable[DataChunk], time_spans: Iterable

all_dfs.append(df_wide)

if not all_dfs:
return


combined_df = pd.concat(
[
df.reset_index() # moves the current DatetimeIndex into a column named 'datetime'
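
A condensed sketch of the control flow this commit gives `download_chunk` in the Smart Citizen Kit source: exit early with an empty sentinel chunk when the device's own time span never overlaps the requested one, and likewise when the API returns no readings. `DataChunk.make_empty_chunk`, `TimeSpan.overlaps`/`from_json` and the `timespan` metadata stored by `get_devices_in_date_range` are taken from the diff above; the helper methods are invented for illustration:

```python
def download_chunk_sketch(self, data_stream, time_span):
    device = data_stream.data
    device_timespan = TimeSpan.from_json(device["timespan"])  # written by get_devices_in_date_range

    # The device never reported inside the requested window: record an empty
    # sentinel chunk so this (stream, span) pair is not queried again next run.
    if not device_timespan.overlaps(time_span):
        return DataChunk.make_empty_chunk(data_stream, time_span)

    sensor_data = self._fetch_all_sensor_readings(device, time_span)  # invented helper

    # The window overlapped on paper, but the API returned nothing: same sentinel.
    if not sensor_data:
        return DataChunk.make_empty_chunk(data_stream, time_span)

    return self._build_data_chunk(data_stream, time_span, sensor_data)  # invented helper
```
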

