diff --git a/src/transport_performance/gtfs/cleaners.py b/src/transport_performance/gtfs/cleaners.py
index cd031473..4c2a6e18 100644
--- a/src/transport_performance/gtfs/cleaners.py
+++ b/src/transport_performance/gtfs/cleaners.py
@@ -160,3 +160,6 @@ def clean_multiple_stop_fast_travel_warnings(
         ~gtfs.multiple_stops_invalid["trip_id"].isin(trip_ids)
     ]
     return None
+
+
+# TODO: add core_cleaner
diff --git a/src/transport_performance/gtfs/validation.py b/src/transport_performance/gtfs/validation.py
index 5c519066..e48b34b2 100644
--- a/src/transport_performance/gtfs/validation.py
+++ b/src/transport_performance/gtfs/validation.py
@@ -16,15 +16,13 @@
 from typing import Union, Callable
 from plotly.graph_objects import Figure as PlotlyFigure
 
-from transport_performance.gtfs.validators import (
-    validate_travel_over_multiple_stops,
-    validate_travel_between_consecutive_stops,
-    validate_route_type_warnings,
-)
 from transport_performance.gtfs.cleaners import (
     clean_consecutive_stop_fast_travel_warnings,
     clean_multiple_stop_fast_travel_warnings,
 )
+import transport_performance.gtfs.cleaners as cleaners
+import transport_performance.gtfs.validators as gtfs_validators
+
 from transport_performance.gtfs.routes import (
     scrape_route_type_lookup,
     get_saved_route_type_lookup,
@@ -43,7 +41,31 @@
     TemplateHTML,
     _set_up_report_dir,
 )
-from transport_performance.utils.constants import PKG_PATH
+from transport_performance.utils.constants import (
+    PKG_PATH,
+)
+
+CLEAN_FEED_FUNCTION_MAP = {
+    "clean_consecutive_stop_fast_travel_warnings": (
+        cleaners.clean_consecutive_stop_fast_travel_warnings
+    ),
+    "clean_multiple_stop_fast_travel_warnings": (
+        cleaners.clean_multiple_stop_fast_travel_warnings
+    ),
+}
+
+VALIDATE_FEED_FUNC_MAP = {
+    "core_validation": gtfs_validators.core_validation,
+    "validate_travel_between_consecutive_stops": (
+        gtfs_validators.validate_travel_between_consecutive_stops
+    ),
+    "validate_travel_over_multiple_stops": (
+        gtfs_validators.validate_travel_over_multiple_stops
+    ),
+    "validate_route_type_warnings": (
+        gtfs_validators.validate_route_type_warnings
+    ),
+}
 
 
 def _get_intermediate_dates(
@@ -338,19 +360,17 @@ def get_gtfs_files(self) -> list:
         self.file_list = file_list
         return self.file_list
 
-    def is_valid(
-        self, route_types: bool = True, far_stops: bool = True
-    ) -> pd.DataFrame:
+    def is_valid(self, validators: dict = None) -> pd.DataFrame:
         """Check a feed is valid with `gtfs_kit`.
 
         Parameters
         ----------
-        route_types : bool, optional
-            Whether or not to validate that the 'invalid route_type...'
-            warnings are valid (if they route type is actually invalid)
-        far_stops : bool, optional
-            Whether or not to perform validation for far stops (both
-            between consecutive stops and over multiple stops)
+        validators : dict, optional
+            Mapping of validator name (a key of `VALIDATE_FEED_FUNC_MAP`)
+            to a dict of keyword arguments for that validator, or to
+            None for no keyword arguments. Unknown names raise a
+            KeyError. If None (the default), every known validator is
+            run without extra arguments.
 
         Returns
         -------
@@ -358,14 +378,29 @@ def is_valid(
             Table of errors, warnings & their descriptions.
 
         """
-        _type_defence(route_types, "route_types", bool)
-        _type_defence(far_stops, "far_stops", bool)
-        self.validity_df = self.feed.validate()
-        if route_types:
-            validate_route_type_warnings(self)
-        if far_stops:
-            validate_travel_between_consecutive_stops(self)
-            validate_travel_over_multiple_stops(self)
+        _type_defence(validators, "validators", (dict, type(None)))
+        # start from an empty table for the validators to add rows to
+        self.validity_df = pd.DataFrame(
+            columns=["type", "message", "table", "rows"]
+        )
+        # no selection passed: run every known validator without kwargs
+        if validators is None:
+            validators = dict.fromkeys(VALIDATE_FEED_FUNC_MAP)
+        # vet the whole request before running anything so that a bad
+        # entry cannot leave validity_df partially populated
+        for name, kwargs in validators.items():
+            # known names are all str, so this also rejects non-str keys
+            if name not in VALIDATE_FEED_FUNC_MAP:
+                raise KeyError(
+                    "Function name passed to 'validators' is not a known "
+                    "validator. Known validators include: "
+                    f"{VALIDATE_FEED_FUNC_MAP.keys()}"
+                )
+            # each value must be a dict of kwargs, or None for no kwargs
+            _type_defence(kwargs, f"validators[{name}]", (dict, type(None)))
+        for name, kwargs in validators.items():
+            # `kwargs or {}` leaves the caller's dict unmodified
+            VALIDATE_FEED_FUNC_MAP[name](gtfs=self, **(kwargs or {}))
         return self.validity_df
 
     def print_alerts(self, alert_type: str = "error") -> None:
diff --git a/src/transport_performance/gtfs/validators.py b/src/transport_performance/gtfs/validators.py
index 7cc204b5..61ca75a6 100644
--- a/src/transport_performance/gtfs/validators.py
+++ b/src/transport_performance/gtfs/validators.py
@@ -271,7 +271,7 @@ def validate_travel_over_multiple_stops(gtfs: "GtfsInstance") -> None:
     return far_stops_df
 
 
-def validate_route_type_warnings(gtfs) -> None:
+def validate_route_type_warnings(gtfs: "GtfsInstance") -> None:
     """Valiidate that the route type warnings are reasonable and just.
 
     Parameters
@@ -307,3 +307,28 @@ def validate_route_type_warnings(gtfs) -> None:
         rows=list(route_rows.index),
     )
     return None
+
+
+def core_validation(gtfs: "GtfsInstance") -> None:
+    """Carry out the main validators of gtfs-kit.
+
+    The result of `gtfs.feed.validate()` is placed ahead of any rows
+    already held in `gtfs.validity_df`, and the index is reset.
+
+    Parameters
+    ----------
+    gtfs : GtfsInstance
+        The GtfsInstance to validate. Its `validity_df` attribute is
+        replaced with the combined table.
+
+    Returns
+    -------
+    None
+
+    """
+    _gtfs_defence(gtfs, "gtfs")
+    validation_df = gtfs.feed.validate()
+    gtfs.validity_df = pd.concat(
+        [validation_df, gtfs.validity_df], axis=0
+    ).reset_index(drop=True)
+    return None
diff --git a/src/transport_performance/utils/constants.py b/src/transport_performance/utils/constants.py
index 97b98598..a1521f2b 100644
--- a/src/transport_performance/utils/constants.py
+++ b/src/transport_performance/utils/constants.py
@@ -1,5 +1,6 @@
 """Constants to be used throughout the transport-performance package."""
+from importlib import resources as pkg_resources  #
+
 import transport_performance
-from importlib import resources as pkg_resources
 
 PKG_PATH = pkg_resources.files(transport_performance)