def _validate_datestring(date_text: str, form: str = "%Y%m%d") -> None:
    """Check that a date string strictly matches an expected format.

    Parameters
    ----------
    date_text : str
        The date string to validate.
    form : str, optional
        A `datetime.strptime` compatible format that `date_text` must
        follow. Defaults to "%Y%m%d".

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If `date_text` cannot be parsed with `form`, or if it parses but is
        not written exactly as `form` renders that date (for example
        "2023613" against "%Y%m%d", where the month is not zero-padded).
    TypeError
        Propagated from `datetime.strptime` when `date_text` or `form` is
        not a string.

    """
    msg = f"Incorrect date format, {date_text} should be {form}"
    try:
        parsed = datetime.strptime(date_text, form)
    except ValueError as err:
        # keep the parser's own explanation attached as the cause
        raise ValueError(msg) from err
    # strptime tolerates missing zero-padding ("2023613" parses as
    # 2023-06-13), so a successful parse alone does not prove the string
    # matches the format. Rendering the parsed date back must reproduce the
    # input exactly.
    # NOTE(review): zero-padding of years below 1000 by strftime may be
    # platform dependent - confirm before reusing for historical dates.
    if parsed.strftime(form) != date_text:
        raise ValueError(msg)
Returns ------- @@ -78,6 +93,7 @@ def bbox_filter_gtfs( "crs": [crs, str], "out_pth": [out_pth, (str, pathlib.Path)], "in_pth": [in_pth, (str, pathlib.Path)], + "filter_dates": [filter_dates, (type(None), list)], } for k, v in typing_dict.items(): _type_defence(v[0], k, v[-1]) @@ -102,6 +118,18 @@ def bbox_filter_gtfs( feed = gk.read_feed(in_pth, dist_units=units) restricted_feed = gk.miscellany.restrict_to_area(feed=feed, area=bbox) + # optionally retrict to a date + if filter_dates is not None: + _check_iterable(filter_dates, "filter_dates", list, exp_type=str) + # check date format is acceptable + [_validate_datestring(x) for x in filter_dates] + feed_dates = restricted_feed.get_dates() + diff = set(filter_dates).difference(feed_dates) + if diff: + raise ValueError(f"{diff} not present in feed dates.") + restricted_feed = gk.miscellany.restrict_to_dates( + restricted_feed, filter_dates + ) restricted_feed.write(out_pth) print(f"Filtered feed written to {out_pth}.")