diff --git a/src/transport_performance/gtfs/multi_validation.py b/src/transport_performance/gtfs/multi_validation.py index 0feb475d..aa22e1e4 100644 --- a/src/transport_performance/gtfs/multi_validation.py +++ b/src/transport_performance/gtfs/multi_validation.py @@ -5,6 +5,8 @@ import glob import os +from geopandas import GeoDataFrame + from transport_performance.gtfs.validation import GtfsInstance from transport_performance.utils.defence import ( _type_defence, @@ -124,3 +126,56 @@ def is_valid(self, validation_kwargs: Union[dict, None] = None) -> None: progress.set_description(f"Cleaning GTFS from path {path}") inst.is_valid(**validation_kwargs) return None + + def filter_to_date(self, dates: Union[str, list]) -> None: + """Filter each GTFS to date(s). + + Parameters + ---------- + dates : Union[str, list] + The date(s) to filter the GTFS to + + Returns + ------- + None + + """ + # defences + _type_defence(dates, "dates", (str, list)) + # convert to normalsed format + if isinstance(dates, str): + dates = [dates] + # filter gtfs + progress = tqdm(zip(self.paths, self.instances), total=len(self.paths)) + for path, inst in progress: + progress.set_description(f"Filtering GTFS from path {path}") + inst.filter_to_date(dates=dates) + return None + + def filter_to_bbox( + self, bbox: Union[list, GeoDataFrame], crs: str = "epsg:4326" + ) -> None: + """Filter GTFS to a bbox. + + Parameters + ---------- + bbox : Union[list, GeoDataFrame] + The bbox to filter the GTFS to. Leave as none if the GTFS does not + need to be cropped. Format - [xmin, ymin, xmax, ymax] + crs : str, optional + The CRS of the given bbox, by default "epsg:4326" + + Returns + ------- + None + + """ + # defences + _type_defence(bbox, "bbox", [list, GeoDataFrame]) + _type_defence(crs, "crs", str) + # filter gtfs + progress = tqdm(zip(self.paths, self.instances), total=len(self.paths)) + for path, inst in progress: + progress.set_description(f"Filtering GTFS from path {path}") + inst.filter_to_bbox(bbox=bbox, crs=crs) + return None diff --git a/src/transport_performance/gtfs/validation.py b/src/transport_performance/gtfs/validation.py index 829432ba..94b9c789 100644 --- a/src/transport_performance/gtfs/validation.py +++ b/src/transport_performance/gtfs/validation.py @@ -1676,7 +1676,7 @@ def filter_to_date(self, dates: Union[str, list]) -> None: Parameters ---------- dates : Union[str, list] - The date(s) to filter to + The date(s) to filter the GTFS to Returns -------