Adding a function to properly handle gobble date ranges (#951)
* Adding a function to properly handle gobble date ranges

* Fixing import path

* Fixing logic

* Renaming sdate and edate to start_date and end_date
devinmatte committed Feb 25, 2024
1 parent 117cb15 commit 7996e73
Showing 9 changed files with 91 additions and 67 deletions.
24 changes: 12 additions & 12 deletions server/app.py
@@ -117,43 +117,43 @@ def alerts_route(user_date):
 
 @app.route("/api/aggregate/traveltimes", cors=cors_config)
 def traveltime_aggregate_route():
-    sdate = parse_user_date(app.current_request.query_params["start_date"])
-    edate = parse_user_date(app.current_request.query_params["end_date"])
+    start_date = parse_user_date(app.current_request.query_params["start_date"])
+    end_date = parse_user_date(app.current_request.query_params["end_date"])
     from_stops = app.current_request.query_params.getlist("from_stop")
     to_stops = app.current_request.query_params.getlist("to_stop")
 
-    response = aggregation.travel_times_over_time(sdate, edate, from_stops, to_stops)
+    response = aggregation.travel_times_over_time(start_date, end_date, from_stops, to_stops)
     return json.dumps(response, indent=4, sort_keys=True, default=str)
 
 
 @app.route("/api/aggregate/traveltimes2", cors=cors_config)
 def traveltime_aggregate_route_2():
-    sdate = parse_user_date(app.current_request.query_params["start_date"])
-    edate = parse_user_date(app.current_request.query_params["end_date"])
+    start_date = parse_user_date(app.current_request.query_params["start_date"])
+    end_date = parse_user_date(app.current_request.query_params["end_date"])
     from_stop = app.current_request.query_params.getlist("from_stop")
     to_stop = app.current_request.query_params.getlist("to_stop")
 
-    response = aggregation.travel_times_all(sdate, edate, from_stop, to_stop)
+    response = aggregation.travel_times_all(start_date, end_date, from_stop, to_stop)
     return json.dumps(response, indent=4, sort_keys=True, default=str)
 
 
 @app.route("/api/aggregate/headways", cors=cors_config)
 def headways_aggregate_route():
-    sdate = parse_user_date(app.current_request.query_params["start_date"])
-    edate = parse_user_date(app.current_request.query_params["end_date"])
+    start_date = parse_user_date(app.current_request.query_params["start_date"])
+    end_date = parse_user_date(app.current_request.query_params["end_date"])
     stops = app.current_request.query_params.getlist("stop")
 
-    response = aggregation.headways_over_time(sdate, edate, stops)
+    response = aggregation.headways_over_time(start_date, end_date, stops)
     return json.dumps(response, indent=4, sort_keys=True, default=str)
 
 
 @app.route("/api/aggregate/dwells", cors=cors_config)
 def dwells_aggregate_route():
-    sdate = parse_user_date(app.current_request.query_params["start_date"])
-    edate = parse_user_date(app.current_request.query_params["end_date"])
+    start_date = parse_user_date(app.current_request.query_params["start_date"])
+    end_date = parse_user_date(app.current_request.query_params["end_date"])
     stops = app.current_request.query_params.getlist("stop")
 
-    response = aggregation.dwells_over_time(sdate, edate, stops)
+    response = aggregation.dwells_over_time(start_date, end_date, stops)
     return json.dumps(response, indent=4, sort_keys=True, default=str)
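For illustration, all four routes above read the same `start_date`/`end_date` query parameters and repeated stop parameters. A minimal client-side sketch — the host and stop IDs here are hypothetical, not part of this diff:

```python
import requests  # assumes the standard `requests` package is available

# Hypothetical host and stop IDs, for illustration only.
resp = requests.get(
    "https://example.com/api/aggregate/traveltimes",
    params={
        "start_date": "2023-11-15",
        "end_date": "2024-01-03",
        "from_stop": ["70061"],  # list values become repeated params, matching getlist() on the server
        "to_stop": ["70063"],
    },
)
travel_times = resp.json()  # per-date aggregate records, serialized by json.dumps above
```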
20 changes: 10 additions & 10 deletions server/chalicelib/aggregation.py
@@ -54,8 +54,8 @@ def faster_describe(grouped):
 # `travel_times_over_time` is legacy and returns just the by_date aggregation w/ peak == all
 
 
-def aggregate_traveltime_data(sdate, edate, from_stops, to_stops):
-    all_data = data_funcs.travel_times(sdate, from_stops, to_stops, edate)
+def aggregate_traveltime_data(start_date: str | datetime.date, end_date: str | datetime.date, from_stops, to_stops):
+    all_data = data_funcs.travel_times(start_date, from_stops, to_stops, end_date)
     if not all_data:
         return None
 
@@ -99,8 +99,8 @@ def calc_travel_times_by_date(df):
     return summary_stats_final
 
 
-def travel_times_all(sdate, edate, from_stops, to_stops):
-    df = aggregate_traveltime_data(sdate, edate, from_stops, to_stops)
+def travel_times_all(start_date: str | datetime.date, end_date: str, from_stops, to_stops):
+    df = aggregate_traveltime_data(start_date, end_date, from_stops, to_stops)
     if df is None:
         return {"by_date": [], "by_time": []}
     by_date = calc_travel_times_by_date(df)
@@ -112,8 +112,8 @@ def travel_times_all(sdate, edate, from_stops, to_stops):
     }
 
 
-def travel_times_over_time(sdate, edate, from_stops, to_stops):
-    df = aggregate_traveltime_data(sdate, edate, from_stops, to_stops)
+def travel_times_over_time(start_date: str | datetime.date, end_date: str | datetime.date, from_stops, to_stops):
+    df = aggregate_traveltime_data(start_date, end_date, from_stops, to_stops)
     if df is None:
         return []
     stats = calc_travel_times_by_date(df)
@@ … @@
 ####################
 # HEADWAYS
 ####################
-def headways_over_time(sdate, edate, stops):
-    all_data = data_funcs.headways(sdate, stops, edate)
+def headways_over_time(start_date: str | datetime.date, end_date: str | datetime.date, stops):
+    all_data = data_funcs.headways(start_date, stops, end_date)
     if not all_data:
         return []
 
@@ -155,8 +155,8 @@ def headways_over_time(sdate, edate, stops):
     return results.to_dict("records")
 
 
-def dwells_over_time(sdate, edate, stops):
-    all_data = data_funcs.dwells(sdate, stops, edate)
+def dwells_over_time(start_date: str | datetime.date, end_date: str | datetime.date, stops):
+    all_data = data_funcs.dwells(start_date, stops, end_date)
     if not all_data:
         return []
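As the empty-data branches above show, `travel_times_all` returns a dict with `by_date` and `by_time` lists, while the legacy `travel_times_over_time` returns only the per-date list. A sketch of that contract — stop IDs and dates are hypothetical, and running it assumes the repo's data backends are configured:

```python
from chalicelib import aggregation

# Hypothetical stop IDs and dates, for illustration only.
full = aggregation.travel_times_all("2024-01-01", "2024-01-31", ["70061"], ["70063"])
assert set(full.keys()) == {"by_date", "by_time"}  # empty lists when no data exists

# The legacy entry point returns just the by_date aggregation (peak == all):
legacy = aggregation.travel_times_over_time("2024-01-01", "2024-01-31", ["70061"], ["70063"])
assert isinstance(legacy, list)
```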
50 changes: 25 additions & 25 deletions server/chalicelib/data_funcs.py
@@ -57,7 +57,7 @@ def use_S3(date, bus=False):
     return archival or bus
 
 
-def partition_S3_dates(start_date, end_date, bus=False):
+def partition_S3_dates(start_date: str | date, end_date: str | date, bus=False):
     """
     Partitions dates by what data source they should be fetched from.
     S3 is used for archival data and for bus data. API is used for recent (within 90 days) subway data.
@@ -79,14 +79,14 @@ def partition_S3_dates(start_date, end_date, bus=False):
     return (s3_dates, api_dates)
 
 
-def headways(sdate, stops, edate=None):
-    if edate is None:
-        if use_S3(sdate, is_bus(stops)):
-            return s3_historical.headways(stops, sdate, sdate)
+def headways(start_date: str | date, stops, end_date: str | date | None = None):
+    if end_date is None:
+        if use_S3(start_date, is_bus(stops)):
+            return s3_historical.headways(stops, start_date, start_date)
         else:
-            return process_mbta_headways(stops, sdate)
+            return process_mbta_headways(stops, start_date)
 
-    s3_interval, api_interval = partition_S3_dates(sdate, edate, is_bus(stops))
+    s3_interval, api_interval = partition_S3_dates(start_date, end_date, is_bus(stops))
     all_data = []
     if s3_interval:
         start, end = s3_interval
@@ -109,9 +109,9 @@ def current_transit_day():
     return today
 
 
-def process_mbta_headways(stops, sdate, edate=None):
+def process_mbta_headways(stops, start_date: str | date, end_date: str | date | None = None):
     # get data
-    api_data = MbtaPerformanceAPI.get_api_data("headways", {"stop": stops}, sdate, edate)
+    api_data = MbtaPerformanceAPI.get_api_data("headways", {"stop": stops}, start_date, end_date)
     # combine all headways data
     headways = []
     for dict_data in api_data:
@@ -130,14 +130,14 @@ def process_mbta_headways(stops, sdate, edate=None):
     return sorted(headways, key=lambda x: x["current_dep_dt"])
 
 
-def travel_times(sdate, from_stops, to_stops, edate=None):
-    if edate is None:
-        if use_S3(sdate, is_bus(from_stops)):
-            return s3_historical.travel_times(from_stops, to_stops, sdate, sdate)
+def travel_times(start_date, from_stops, to_stops, end_date: str | date | None = None):
+    if end_date is None:
+        if use_S3(start_date, is_bus(from_stops)):
+            return s3_historical.travel_times(from_stops, to_stops, start_date, start_date)
         else:
-            return process_mbta_travel_times(from_stops, to_stops, sdate)
+            return process_mbta_travel_times(from_stops, to_stops, start_date)
 
-    s3_interval, api_interval = partition_S3_dates(sdate, edate, is_bus(from_stops))
+    s3_interval, api_interval = partition_S3_dates(start_date, end_date, is_bus(from_stops))
     all_data = []
     if s3_interval:
         start, end = s3_interval
@@ -149,10 +149,10 @@ def travel_times(sdate, from_stops, to_stops, edate=None):
     return all_data
 
 
-def process_mbta_travel_times(from_stops, to_stops, sdate, edate=None):
+def process_mbta_travel_times(from_stops, to_stops, start_date: str | date, end_date: str | date | None = None):
     # get data
     api_data = MbtaPerformanceAPI.get_api_data(
-        "traveltimes", {"from_stop": from_stops, "to_stop": to_stops}, sdate, edate
+        "traveltimes", {"from_stop": from_stops, "to_stop": to_stops}, start_date, end_date
     )
     # combine all travel times data, remove threshold flags from performance API, and dedupe on `dep_dt`
     trips = {}
@@ -174,14 +174,14 @@ def process_mbta_travel_times(from_stops, to_stops, sdate, edate=None):
     return sorted(trips_list, key=lambda x: x["dep_dt"])
 
 
-def dwells(sdate, stops, edate=None):
-    if edate is None:
-        if use_S3(sdate, is_bus(stops)):
-            return s3_historical.dwells(stops, sdate, sdate)
+def dwells(start_date, stops, end_date: str | date | None = None):
+    if end_date is None:
+        if use_S3(start_date, is_bus(stops)):
+            return s3_historical.dwells(stops, start_date, start_date)
         else:
-            return process_mbta_dwells(stops, sdate)
+            return process_mbta_dwells(stops, start_date)
 
-    s3_interval, api_interval = partition_S3_dates(sdate, edate, is_bus(stops))
+    s3_interval, api_interval = partition_S3_dates(start_date, end_date, is_bus(stops))
     all_data = []
     if s3_interval:
         start, end = s3_interval
@@ -194,9 +194,9 @@ def dwells(sdate, stops, edate=None):
     return all_data
 
 
-def process_mbta_dwells(stops, sdate, edate=None):
+def process_mbta_dwells(stops, start_date: str | date, end_date: str | date | None = None):
     # get data
-    api_data = MbtaPerformanceAPI.get_api_data("dwells", {"stop": stops}, sdate, edate)
+    api_data = MbtaPerformanceAPI.get_api_data("dwells", {"stop": stops}, start_date, end_date)
 
     # combine all travel times data
     dwells = []
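The entry points above split a single request across both backends: `partition_S3_dates` returns an S3 interval and an API interval (either may be `None`), and the results are concatenated. A hypothetical call, assuming `parse_user_date` in app.py has already produced `datetime.date` values and the stop ID is invented for the example:

```python
from datetime import date
from chalicelib import data_funcs

# A range that straddles the ~90-day API window: the older part is served
# from S3 archives, the recent part from the MBTA performance API, and the
# two result lists are concatenated before being returned.
rows = data_funcs.headways(date(2023, 6, 1), ["70061"], end_date=date(2024, 2, 1))
```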
3 changes: 3 additions & 0 deletions server/chalicelib/date_utils.py
@@ -7,6 +7,9 @@
 DATE_FORMAT_OUT = "%Y-%m-%dT%H:%M:%S"
 EASTERN_TIME = ZoneInfo("US/Eastern")
 
+# The most recent date for which we have monthly data
+MAX_MONTH_DATA_DATE = "2023-12-31"
+
 
 def parse_event_date(date_str: str):
     if len(date_str) == 19:
2 changes: 1 addition & 1 deletion server/chalicelib/dynamo.py
@@ -54,7 +54,7 @@ def query_ridership(start_date: date, end_date: date, line_id: str = None):
     return ddb_json.loads(response["Items"])
 
 
-def query_agg_trip_metrics(start_date: str, end_date: str, table_name: str, line: str = None):
+def query_agg_trip_metrics(start_date: str | date, end_date: str | date, table_name: str, line: str = None):
     table = dynamodb.Table(table_name)
     line_condition = Key("line").eq(line)
     date_condition = Key("date").between(start_date, end_date)
27 changes: 23 additions & 4 deletions server/chalicelib/parallel.py
@@ -1,7 +1,8 @@
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
 import pandas as pd
 
+from chalicelib.date_utils import MAX_MONTH_DATA_DATE
 
 
 def make_parallel(single_func, THREAD_COUNT=5):
     # This function will wrap another function
@@ -21,15 +22,33 @@ def parallel_func(iterable, *args, **kwargs):
     return parallel_func
 
 
-def date_range(start, end):
+def date_range(start: str, end: str):
     return pd.date_range(start, end)
 
 
-def month_range(start, end):
+def s3_date_range(start: str, end: str):
+    """
+    Generates a date range, meant for S3 data.
+    For all dates that we have monthly datasets for, return one date of the month;
+    for all dates that we have daily datasets for, return all dates.
+    """
+    month_end = end
+    if pd.to_datetime(MAX_MONTH_DATA_DATE) < pd.to_datetime(end):
+        month_end = MAX_MONTH_DATA_DATE
+
     # This is kinda funky, but is still simpler than other approaches:
     # pandas won't generate a monthly date_range that includes both Jan and Feb for, e.g., Jan 31-Feb 1,
     # so we generate a daily date_range and then resample it down (summing 0s as a no-op in the process) so it aligns.
-    dates = pd.date_range(start, end, freq="1D", inclusive="both")
+    dates = pd.date_range(start, month_end, freq="1D", inclusive="both")
     series = pd.Series(0, index=dates)
     months = series.resample("1M").sum().index
 
+    # all dates between month_end and end, if month_end is less than end
+    if pd.to_datetime(month_end) < pd.to_datetime(end):
+        daily_dates = pd.date_range(month_end, end, freq="1D", inclusive="both")
+
+        # combine the two date ranges of months and dates
+        if daily_dates is not None and len(daily_dates) > 0:
+            months = months.union(daily_dates)
+
     return months
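The boundary behavior of the new `s3_date_range` can be checked directly. With `MAX_MONTH_DATA_DATE = "2023-12-31"` from date_utils.py, a range that crosses the cutoff yields month-end dates first, then individual days (example dates chosen for illustration):

```python
from chalicelib.parallel import s3_date_range

dates = s3_date_range("2023-11-15", "2024-01-03")
print([d.strftime("%Y-%m-%d") for d in dates])
# ['2023-11-30', '2023-12-31', '2024-01-01', '2024-01-02', '2024-01-03']
# One date per month while monthly datasets exist, then every date once the
# daily (gobble) datasets begin; .union() dedupes the shared 2023-12-31.
```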
11 changes: 6 additions & 5 deletions server/chalicelib/s3.py
@@ -1,3 +1,4 @@
+from datetime import date
 import boto3
 import botocore
 from botocore.exceptions import ClientError
@@ -80,12 +81,12 @@ def parallel_download_events(datestop):
     return download_one_event_file(date, stop)
 
 
-def download_events(sdate, edate, stops: list):
-    # This needs to be month_range for performance and memory,
-    # however, for data from gobble we'll need specific dates, not just first of the month
-    datestops = itertools.product(parallel.month_range(sdate, edate), stops)
+def download_events(start_date: str | date, end_date: str | date, stops: list):
+    datestops = itertools.product(parallel.s3_date_range(start_date, end_date), stops)
     result = parallel_download_events(datestops)
-    result = filter(lambda row: sdate.strftime("%Y-%m-%d") <= row["service_date"] <= edate.strftime("%Y-%m-%d"), result)
+    result = filter(
+        lambda row: start_date.strftime("%Y-%m-%d") <= row["service_date"] <= end_date.strftime("%Y-%m-%d"), result
+    )
     return sorted(result, key=lambda row: row["event_time"])
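`download_events` now fans out over the cartesian product of the `s3_date_range` dates and the requested stops, then trims by `service_date`, since a monthly file covers days outside the requested range. A small sketch — the stop IDs are hypothetical, and timestamps are shown abbreviated:

```python
import itertools

from chalicelib import parallel

datestops = list(itertools.product(parallel.s3_date_range("2023-12-15", "2024-01-02"), ["70061", "70063"]))
# [(2023-12-31, '70061'), (2023-12-31, '70063'),
#  (2024-01-01, '70061'), (2024-01-01, '70063'), ...]
# The monthly file fetched for 2023-12-31 contains all of December, so the
# service_date filter above drops rows earlier than 2023-12-15.
```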
15 changes: 8 additions & 7 deletions server/chalicelib/s3_historical.py
@@ -1,3 +1,4 @@
+from datetime import date
 from chalicelib import s3
 from chalicelib.constants import EVENT_ARRIVAL, EVENT_DEPARTURE
 
@@ -32,8 +33,8 @@ def unique_everseen(iterable, key=None):
         yield element
 
 
-def dwells(stop_ids: list, sdate, edate):
-    rows_by_time = s3.download_events(sdate, edate, stop_ids)
+def dwells(stop_ids: list, start_date: str | date, end_date: str | date):
+    rows_by_time = s3.download_events(start_date, end_date, stop_ids)
 
     dwells = []
     for maybe_an_arrival, maybe_a_departure in pairwise(rows_by_time):
@@ -59,8 +60,8 @@ def dwells(stop_ids: list, sdate, edate):
     return dwells
 
 
-def headways(stop_ids: list, sdate, edate):
-    rows_by_time = s3.download_events(sdate, edate, stop_ids)
+def headways(stop_ids: list, start_date: str | date, end_date: str | date):
+    rows_by_time = s3.download_events(start_date, end_date, stop_ids)
 
     only_departures = filter(lambda row: row["event_type"] in EVENT_DEPARTURE, rows_by_time)
 
@@ -98,9 +99,9 @@ def headways(stop_ids: list, sdate, edate):
     return headways
 
 
-def travel_times(stops_a: list, stops_b: list, sdate, edate):
-    rows_by_time_a = s3.download_events(sdate, edate, stops_a)
-    rows_by_time_b = s3.download_events(sdate, edate, stops_b)
+def travel_times(stops_a: list, stops_b: list, start_date: str | date, end_date: str | date):
+    rows_by_time_a = s3.download_events(start_date, end_date, stops_a)
+    rows_by_time_b = s3.download_events(start_date, end_date, stops_b)
 
     departures = filter(lambda event: event["event_type"] in EVENT_DEPARTURE, rows_by_time_a)
     # we reverse arrivals so that if the same train arrives twice (this can happen),
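`dwells` above pairs consecutive events at a stop via `pairwise`, which is not shown in this diff; assuming it matches the standard itertools recipe, it works like this:

```python
import itertools

def pairwise(iterable):
    # (a, b), (b, c), (c, d), ... -- the classic itertools recipe
    a, b = itertools.tee(iterable)
    next(b, None)
    return zip(a, b)

# Each (maybe_an_arrival, maybe_a_departure) pair from the sorted event stream
# is a candidate dwell: an arrival immediately followed by a departure at the
# same stop.
```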
6 changes: 3 additions & 3 deletions server/chalicelib/speed.py
@@ -1,14 +1,14 @@
 from typing import TypedDict
 from chalice import BadRequestError, ForbiddenError
 from chalicelib import dynamo
-from datetime import datetime, timedelta
+from datetime import date, datetime, timedelta
 import pandas as pd
 import numpy as np
 
 
 class TripMetricsByLineParams(TypedDict):
-    start_date: str
-    end_date: str
+    start_date: str | date
+    end_date: str | date
     agg: str
     line: str
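With the widened `TripMetricsByLineParams`, callers may pass either strings or `datetime.date` values for the date fields. A hypothetical instance — the `agg` and `line` values are invented for the example:

```python
from datetime import date

from chalicelib.speed import TripMetricsByLineParams

params: TripMetricsByLineParams = {
    "start_date": date(2024, 1, 1),  # now valid as datetime.date
    "end_date": "2024-02-01",        # still valid as a string
    "agg": "daily",                  # hypothetical aggregation key
    "line": "line-red",              # hypothetical line ID
}
```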
