diff --git a/.gitignore b/.gitignore
index 686e6938..339b1175 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,11 +2,13 @@
**/*.pbf
**/*.mapdb
**/*.mapdb.p
-# moved zip blanket rule above specific exception for test fixture below
+# moved blanket rules above specific exceptions for test fixtures
*.zip
+*.pkl
# except test fixtures
!tests/data/newport-2023-06-13.osm.pbf
!tests/data/newport-20230613_gtfs.zip
+!tests/data/gtfs/route_lookup.pkl
### Project structure ###
data/*
@@ -36,7 +38,6 @@ outputs/*
*.html
*.pdf
*.csv
-*.pkl
*.rds
*.rda
*.parquet
diff --git a/conftest.py b/conftest.py
index 7e6a77f4..3e7bdd1e 100644
--- a/conftest.py
+++ b/conftest.py
@@ -16,14 +16,32 @@ def pytest_addoption(parser):
default=False,
help="run set-up tests",
)
+ parser.addoption(
+ "--runinteg",
+ action="store_true",
+ default=False,
+ help="run integration tests",
+ )
+ parser.addoption(
+ "--runexpensive",
+ action="store_true",
+ default=False,
+ help="run expensive tests",
+ )
def pytest_configure(config):
"""Add ini value line."""
config.addinivalue_line("markers", "setup: mark test to run during setup")
+ config.addinivalue_line(
+ "markers", "runinteg: mark test to run for integration tests"
+ )
+ config.addinivalue_line(
+ "markers", "runexpensive: mark test to run expensive tests"
+ )
-def pytest_collection_modifyitems(config, items):
+def pytest_collection_modifyitems(config, items):  # noqa: C901
"""Handle switching based on cli args."""
if config.getoption("--runsetup"):
# --runsetup given in cli: do not skip slow tests
@@ -32,3 +50,19 @@ def pytest_collection_modifyitems(config, items):
for item in items:
if "setup" in item.keywords:
item.add_marker(skip_setup)
+
+ if config.getoption("--runinteg"):
+ return
+ skip_runinteg = pytest.mark.skip(reason="need --runinteg option to run")
+ for item in items:
+ if "runinteg" in item.keywords:
+ item.add_marker(skip_runinteg)
+
+ if config.getoption("--runexpensive"):
+ return
+ skip_runexpensive = pytest.mark.skip(
+ reason="need --runexpensive option to run"
+ )
+ for item in items:
+ if "runexpensive" in item.keywords:
+ item.add_marker(skip_runexpensive)
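+
+
+# Example invocations (illustrative):
+#   pytest                  # setup, integration & expensive tests skipped
+#   pytest --runinteg       # also run tests marked `runinteg`
+#   pytest --runexpensive   # also run tests marked `runexpensive`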
diff --git a/notebooks/gtfs/check_unmatched_id_warnings.py b/notebooks/gtfs/check_unmatched_id_warnings.py
new file mode 100644
index 00000000..21c064c4
--- /dev/null
+++ b/notebooks/gtfs/check_unmatched_id_warnings.py
@@ -0,0 +1,93 @@
+"""Validation of invalid IDs whilst joining GTFS sub-tables."""
+
+# %%
+# imports
+import gtfs_kit as gk
+from pyprojroot import here
+import pandas as pd
+import numpy as np
+
+# %%
+# initialise my feed from GTFS test data
+feed = gk.read_feed(
+ here("tests/data/newport-20230613_gtfs.zip"), dist_units="m"
+)
+feed.validate()
+
+# %%
+# calendar test
+feed.calendar = pd.concat(
+ [
+ feed.calendar,
+ pd.DataFrame(
+ {
+ "service_id": [101],
+ "monday": [0],
+ "tuesday": [0],
+ "wednesday": [0],
+ "thursday": [0],
+ "friday": [0],
+ "saturday": [0],
+ "sunday": [0],
+ "start_date": ["20200104"],
+ "end_date": ["20230301"],
+ }
+ ),
+ ],
+ axis=0,
+)
+
+feed.validate()
+
+# %%
+# trips test
+feed.trips = pd.concat(
+ [
+ feed.trips,
+ pd.DataFrame(
+ {
+ "service_id": [101],
+ "route_id": [20304],
+ "trip_id": ["VJbedb4cfd0673348e017d42435abbdff3ddacbf89"],
+ "trip_headsign": ["Newport"],
+ "block_id": [np.nan],
+ "shape_id": ["RPSPc4c99ac6aff7e4648cbbef785f88427a48efa80f"],
+ "wheelchair_accessible": [0],
+ "trip_direction_name": [np.nan],
+ "vehicle_journey_code": ["VJ109"],
+ }
+ ),
+ ],
+ axis=0,
+)
+
+feed.validate()
+
+# %%
+# routes test
+feed.routes = pd.concat(
+ [
+ feed.routes,
+ pd.DataFrame(
+ {
+ "service_id": [101],
+ "route_id": [20304],
+ "agency_id": ["OL5060"],
+ "route_short_name": ["X145"],
+ "route_long_name": [np.nan],
+ "route_type": [200],
+ }
+ ),
+ ],
+ axis=0,
+)
+
+feed.validate()
+
+# OUTCOME
+# 'Errors' are recognised when attempting to validate the GTFS data using
+# the pre-built gtfs_kit functions. This suggests that flawed GTFS data
+# will be identified within the pipeline, making the user aware. Each case
+# is also flagged as an error, which means that 'the GTFS is violated'
+# (https://mrcagney.github.io/gtfs_kit_docs/).
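+
+# %%
+# a minimal sketch: tally alert types to confirm the errors are surfaced
+# (validate() returns a table with 'type', 'message', 'table' & 'rows')
+alerts = feed.validate()
+print(alerts["type"].value_counts())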
diff --git a/pipeline/gtfs/01-validate-gtfs.py b/pipeline/gtfs/01-validate-gtfs.py
new file mode 100644
index 00000000..71821446
--- /dev/null
+++ b/pipeline/gtfs/01-validate-gtfs.py
@@ -0,0 +1,112 @@
+"""Run the GTFS validation checks for the toml-specified GTFS file.
+
+1. read feed
+2. describe feed
+3. validate feed
+4. clean feed
+5. new - print errors / warnings in full
+6. new - visualise convex hull of stops and area
+7. visualise stop locations
+8. new - modalities available (including extended spec)
+9. new - feed stats by is-weekend
+"""
+import toml
+from pyprojroot import here
+import time
+import subprocess
+
+from transport_performance.gtfs.validation import GtfsInstance
+from transport_performance.utils.defence import _is_gtfs_pth
+
+CONFIG = toml.load(here("pipeline/gtfs/config/01-validate-gtfs.toml"))
+GTFS_PTH = here(CONFIG["GTFS"]["PATH"])
+UNITS = CONFIG["GTFS"]["UNITS"]
+GEOM_CRS = CONFIG["GTFS"]["GEOMETRIC_CRS"]
+POINT_MAP_PTH = CONFIG["MAPS"]["STOP_COORD_PTH"]
+HULL_MAP_PATH = CONFIG["MAPS"]["STOP_HULL_PTH"]
+PROFILING = CONFIG["UTILS"]["PROFILING"]
+# check GTFS Path exists
+_is_gtfs_pth(pth=GTFS_PTH, param_nm="GTFS_PTH", check_existing=True)
+# Get the disk usage of the GTFS file.
+gtfs_du = (
+ subprocess.check_output(["du", "-sh", GTFS_PTH]).split()[0].decode("utf-8")
+)
+if PROFILING:
+ print(f"GTFS at {GTFS_PTH} disk usage: {gtfs_du}")
+
+pre_init = time.perf_counter()
+feed = GtfsInstance(gtfs_pth=GTFS_PTH, units=UNITS)
+post_init = time.perf_counter()
+if PROFILING:
+ print(f"Init in {post_init - pre_init:0.4f} seconds")
+
+available_dates = feed.feed.get_dates()
+post_dates = time.perf_counter()
+if PROFILING:
+ print(f"get_dates in {post_dates - post_init:0.4f} seconds")
+s = available_dates[0]
+f = available_dates[-1]
+print(f"{len(available_dates)} dates available between {s} & {f}.")
+
+try:
+    # If agency_id is missing, an AttributeError is raised. The GTFS spec
+    # states this is conditionally required, depending on whether more than
+    # one agency operates within the feed.
+    # https://gtfs.org/schedule/reference/#agencytxt
+    # Cleaning the feed doesn't resolve this. Raise an issue to investigate.
+ print(feed.is_valid())
+ post_isvalid = time.perf_counter()
+ if PROFILING:
+ print(f"is_valid in {post_isvalid - post_dates:0.4f} seconds")
+ print(feed.validity_df["type"].value_counts())
+ feed.print_alerts()
+ post_errors = time.perf_counter()
+ feed.print_alerts(alert_type="warning")
+ post_warn = time.perf_counter()
+ if PROFILING:
+ print(f"print_alerts errors: {post_errors - post_isvalid:0.4f} secs")
+ print(f"print_alerts warn: {post_warn - post_errors:0.4f} secs")
+except AttributeError:
+ print("AttributeError. Unable to validate feed.")
+
+pre_clean = time.perf_counter()
+feed.clean_feed()
+post_clean = time.perf_counter()
+if PROFILING:
+ print(f"clean_feed in {post_clean - pre_clean:0.4f} seconds")
+
+try:
+ print(feed.is_valid())
+ print(feed.validity_df["type"].value_counts())
+ feed.print_alerts()
+ feed.print_alerts(alert_type="warning")
+except AttributeError:
+ print("AttributeError. Unable to validate feed.")
+
+# visualise gtfs
+pre_viz_points = time.perf_counter()
+feed.viz_stops(out_pth=POINT_MAP_PTH)
+post_viz_points = time.perf_counter()
+if PROFILING:
+ print(f"viz_points in {post_viz_points - pre_viz_points:0.4f} seconds")
+print(f"Map written to {POINT_MAP_PTH}")
+
+pre_viz_hull = time.perf_counter()
+feed.viz_stops(out_pth=HULL_MAP_PATH, geoms="hull", geom_crs=GEOM_CRS)
+post_viz_hull = time.perf_counter()
+if PROFILING:
+ print(f"viz_hull in {post_viz_hull - pre_viz_hull:0.4f} seconds")
+print(f"Map written to {HULL_MAP_PATH}")
+
+pre_route_modes = time.perf_counter()
+print(feed.get_route_modes())
+post_route_modes = time.perf_counter()
+if PROFILING:
+ print(f"route_modes in {post_route_modes - pre_route_modes:0.4f} seconds")
+
+pre_summ_weekday = time.perf_counter()
+print(feed.summarise_trips())
+print(feed.summarise_routes())
+post_summ_weekday = time.perf_counter()
+if PROFILING:
+ print(f"summ_weekday in {post_summ_weekday - pre_summ_weekday:0.4f} secs")
+ print(f"Pipeline execution in {post_summ_weekday - pre_init:0.4f}")
diff --git a/pipeline/gtfs/config/01-validate-gtfs.toml b/pipeline/gtfs/config/01-validate-gtfs.toml
new file mode 100644
index 00000000..aa5b0299
--- /dev/null
+++ b/pipeline/gtfs/config/01-validate-gtfs.toml
@@ -0,0 +1,13 @@
+title = "Config for GTFS Validation Pipeline"
+
+[GTFS]
+PATH = "data/external/croppednewport-bus-07-07-2022_gtfs.zip"
+UNITS = "m"
+GEOMETRIC_CRS = 27700 # used for area calculations only
+
+[MAPS]
+STOP_COORD_PTH = "outputs/gtfs/validation/gtfs-stops-locations.html"
+STOP_HULL_PTH = "outputs/gtfs/validation/gtfs-stops-convex-hull.html"
+
+[UTILS]
+PROFILING = true
diff --git a/requirements.txt b/requirements.txt
index c080881f..ea88f854 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,12 @@ r5py>=0.0.4
gtfs_kit==5.2.7
pytest
coverage
+ipykernel==6.23.1
+pandas
+beautifulsoup4
+requests
+pytest-mock
+toml
rasterio
pyprojroot
matplotlib
@@ -15,5 +21,6 @@ geocube
mapclassify
pytest-lazy-fixture
seaborn
+numpy>=1.25.0 # test suite will fail if user installed lower than this
rioxarray
-e .
diff --git a/src/transport_performance/gtfs/__init__.py b/src/transport_performance/gtfs/__init__.py
new file mode 100644
index 00000000..5317b442
--- /dev/null
+++ b/src/transport_performance/gtfs/__init__.py
@@ -0,0 +1 @@
+"""Helpers for working with & validating GTFS."""
diff --git a/src/transport_performance/gtfs/routes.py b/src/transport_performance/gtfs/routes.py
new file mode 100644
index 00000000..1488f109
--- /dev/null
+++ b/src/transport_performance/gtfs/routes.py
@@ -0,0 +1,133 @@
+"""Helpers for working with routes.txt."""
+import pandas as pd
+from bs4 import BeautifulSoup
+import requests
+import warnings
+
+from transport_performance.utils.defence import _url_defence, _bool_defence
+
+warnings.filterwarnings(
+ action="ignore", category=DeprecationWarning, module=".*pkg_resources"
+)
+# see https://github.com/datasciencecampus/transport-network-performance/
+# issues/19
+
+
+def _construct_extended_schema_table(some_soup, cd_list, desc_list):
+ """Create the extended table from a soup object. Not exported.
+
+ Parameters
+ ----------
+ some_soup : bs4.BeautifulSoup
+ A bs4 soup representation of `ext_spec_url`.
+ cd_list : list
+        A list of schema codes scraped so far. Will append additional codes to
+ this list.
+ desc_list : list
+ A list of schema descriptions found so far. Will append additional
+ descriptions to this list.
+
+ Returns
+ -------
+ tuple[0]: Proposed extension to route_type codes
+ tuple[1]: Proposed extension to route_type descriptions
+
+ """
+ for i in some_soup.findAll("table"):
+        # target table has 'nice-table' class
+ if i.get("class")[0] == "nice-table":
+ target = i
+
+ for row in target.tbody.findAll("tr"):
+ # Get the table headers
+ found = row.findAll("th")
+ if found:
+ cols = [f.text for f in found]
+ else:
+ # otherwise get the table data
+ dat = [i.text for i in row.findAll("td")]
+ # subset to the required column
+ cd_list.append(dat[cols.index("Code")])
+ desc_list.append(dat[cols.index("Description")])
+
+ return (cd_list, desc_list)
+
+
+def _get_response_text(url):
+ """Return the response & extract the text. Not exported."""
+ r = requests.get(url)
+ t = r.text
+ return t
+
+
+def scrape_route_type_lookup(
+ gtfs_url="https://gtfs.org/schedule/reference/",
+ ext_spec_url=(
+ "https://developers.google.com/transit/gtfs/reference/"
+ "extended-route-types"
+ ),
+ extended_schema=True,
+):
+ """Scrape a lookup of GTFS route_type codes to descriptions.
+
+ Scrapes HTML tables from `gtfs_url` to provide a lookup of `route_type`
+ codes to human readable descriptions. Useful for confirming available
+ modes of transport within a GTFS. If `extended_schema` is True, then also
+ include the proposed extension of route_type to the GTFS.
+
+ Parameters
+ ----------
+ gtfs_url : str
+ The url containing the GTFS accepted route_type codes. Defaults to
+ "https://gtfs.org/schedule/reference/".
+ ext_spec_url : str
+ The url containing the table of the proposed extension to the GTFS
+ schema for route_type codes. Defaults to
+ ( "https://developers.google.com/transit/gtfs/reference/"
+ "extended-route-types" ).
+ extended_schema : bool
+ Should the extended schema table be scraped and included in the output?
+ Defaults to True.
+
+ Returns
+ -------
+ pd.core.frame.DataFrame: A lookup of route_type codes to descriptions.
+
+ """
+ # a little defence
+ for url in [gtfs_url, ext_spec_url]:
+ _url_defence(url)
+
+ _bool_defence(extended_schema)
+ # Get the basic scheme lookup
+ resp_txt = _get_response_text(gtfs_url)
+ soup = BeautifulSoup(resp_txt, "html.parser")
+ for dat in soup.findAll("td"):
+ # Look for a pattern to target, going with Tram, could go more specific
+ # with regex if table format unstable.
+ if "Tram" in dat.text:
+ target_node = dat
+
+ cds = list()
+ txts = list()
+    # the required data is in an awkward little inline 'table' that's really
+    # a table row, but helpfully the data sits either side of some break
+    # tags
+ for x in target_node.findAll("br"):
+ cds.append(x.nextSibling.text)
+ txts.append(x.previousSibling.text)
+ # strip out rubbish
+ cds = [cd for cd in cds if len(cd) > 0]
+ txts = [t.strip(" - ") for t in txts if t.startswith(" - ")]
+ # catch the final description which is not succeeded by a break
+ txts.append(target_node.text.split(" - ")[-1])
+ # if interested in the extended schema, get that too. Perhaps not
+ # relevant to all territories
+ if extended_schema:
+ resp_txt = _get_response_text(ext_spec_url)
+ soup = BeautifulSoup(resp_txt, "html.parser")
+ cds, txts = _construct_extended_schema_table(soup, cds, txts)
+
+ route_lookup = pd.DataFrame(zip(cds, txts), columns=["route_type", "desc"])
+
+ return route_lookup
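+
+
+# Example usage (a sketch; requires network access to both URLs):
+# lookup = scrape_route_type_lookup(extended_schema=False)
+# print(lookup.head())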
diff --git a/src/transport_performance/gtfs/utils.py b/src/transport_performance/gtfs/utils.py
new file mode 100644
index 00000000..7e9d9a67
--- /dev/null
+++ b/src/transport_performance/gtfs/utils.py
@@ -0,0 +1,57 @@
+"""Utility functions for GTFS archives."""
+import gtfs_kit as gk
+import geopandas as gpd
+from shapely.geometry import box
+from pyprojroot import here
+
+from transport_performance.utils.defence import _is_gtfs_pth, _check_list
+
+
+def bbox_filter_gtfs(
+ in_pth=here("tests/data/newport-20230613_gtfs.zip"),
+ out_pth=here("data/external/filtered_gtfs.zip"),
+ bbox_list=[-3.077081, 51.52222, -2.925075, 51.593596],
+ units="m",
+ crs="epsg:4326",
+):
+ """Filter a GTFS feed to any routes intersecting with a bounding box.
+
+ Parameters
+ ----------
+ in_pth : (str, pathlib.PosixPath)
+ Path to the unfiltered GTFS feed. Defaults to
+ here("tests/data/newport-20230613_gtfs.zip").
+ out_pth : (str, pathlib.PosixPath)
+ Path to write the filtered feed to. Defaults to
+ here("data/external/filtered_gtfs.zip").
+ bbox_list : list(float)
+ A list of x and y values in the order of minx, miny, maxx, maxy.
+ Defaults to [-3.077081, 51.52222, -2.925075, 51.593596].
+ units : str
+ Distance units of the original GTFS. Defaults to "m".
+ crs : str
+        The projection that the `bbox_list` coordinates are given in.
+        Defaults to "epsg:4326" for lat long.
+
+ Returns
+ -------
+ None
+
+ """
+ _is_gtfs_pth(pth=in_pth, param_nm="in_pth")
+ _is_gtfs_pth(pth=out_pth, param_nm="out_pth", check_existing=False)
+ _check_list(ls=bbox_list, param_nm="bbox_list", exp_type=float)
+ for param in [units, crs]:
+ if not isinstance(param, str):
+ raise TypeError(f"Expected string. Found {type(param)} : {param}")
+
+ # create box polygon around provided coords, need to splat
+ box_poly = box(*bbox_list)
+ # gtfs_kit expects gdf
+ gdf = gpd.GeoDataFrame(index=[0], crs=crs, geometry=[box_poly])
+ feed = gk.read_feed(in_pth, dist_units=units)
+    filtered_feed = gk.miscellany.restrict_to_area(feed=feed, area=gdf)
+    filtered_feed.write(out_pth)
+ print(f"Filtered feed written to {out_pth}.")
+
+ return None
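+
+
+# Example usage (a sketch; out_pth is illustrative, the bbox is a tiny box
+# over Newport train station, as used in the tests):
+# bbox_filter_gtfs(
+#     out_pth=here("data/external/newport_train_station_gtfs.zip"),
+#     bbox_list=[-3.0018, 51.5875, -2.9965, 51.5907],
+# )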
diff --git a/src/transport_performance/gtfs/validation.py b/src/transport_performance/gtfs/validation.py
new file mode 100644
index 00000000..13cdb917
--- /dev/null
+++ b/src/transport_performance/gtfs/validation.py
@@ -0,0 +1,595 @@
+"""Validating GTFS data."""
+import gtfs_kit as gk
+from pyprojroot import here
+import pandas as pd
+import geopandas as gpd
+import folium
+import datetime
+import numpy as np
+import os
+import inspect
+
+from transport_performance.gtfs.routes import scrape_route_type_lookup
+from transport_performance.utils.defence import (
+ _is_gtfs_pth,
+ _check_namespace_export,
+ _check_parent_dir_exists,
+)
+
+
+def _get_intermediate_dates(
+ start: pd.Timestamp, end: pd.Timestamp
+) -> list[pd.Timestamp]:
+ """Return a list of daily timestamps between two dates.
+
+ Parameters
+ ----------
+ start : pd.Timestamp
+ The start date of the given time period in %Y%m%d format.
+ end : pd.Timestamp
+ The end date of the given time period in %Y%m%d format.
+
+ Returns
+ -------
+ list[pd.Timestamp]
+ A list of daily timestamps for each day in the time period
+
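+    Examples
+    --------
+    A sketch of the inclusive daily range returned:
+
+    >>> _get_intermediate_dates(
+    ...     pd.Timestamp("2023-05-01"), pd.Timestamp("2023-05-03")
+    ... )  # doctest: +SKIP
+    [Timestamp('2023-05-01 00:00:00'), Timestamp('2023-05-02 00:00:00'),
+     Timestamp('2023-05-03 00:00:00')]
+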
+ """
+ # checks for start and end
+ if not isinstance(start, pd.Timestamp):
+ raise TypeError(
+ "'start' expected type pd.Timestamp."
+ f" Recieved type {type(start)}"
+ )
+ if not isinstance(end, pd.Timestamp):
+ raise TypeError(
+ "'end' expected type pd.Timestamp." f" Recieved type {type(end)}"
+ )
+ result = []
+ while start <= end:
+ result.append(start)
+ start = start + datetime.timedelta(days=1)
+ return result
+
+
+def _create_map_title_text(gdf, units, geom_crs):
+ """Generate the map title text when plotting convex hull.
+
+ Parameters
+ ----------
+ gdf : gpd.GeoDataFrame
+ GeoDataFrame containing the spatial features.
+ units : str
+ Distance units of the GTFS feed from which `gdf` originated.
+    geom_crs : (str, int)
+ The geometric crs to use in reprojecting the data in order to
+ calculate the area of the hull polygon.
+
+ Returns
+ -------
+ str : The formatted text string for presentation in the map title.
+
+ """
+ if units in ["m", "km"]:
+ hull_km2 = gdf.to_crs(geom_crs).area
+ if units == "m":
+ hull_km2 = hull_km2 / 1000000
+ pre = "GTFS Stops Convex Hull Area: "
+ post = " nearest km2."
+ txt = f"{pre}{int(round(hull_km2[0], 0)):,}{post}"
+ else:
+ txt = (
+ "GTFS Stops Convex Hull. Area Calculation for Metric "
+ f"Units Only. Units Found are in {units}."
+ )
+ return txt
+
+
+class GtfsInstance:
+ """Create a feed instance for validation, cleaning & visualisation."""
+
+ def __init__(
+ self, gtfs_pth=here("tests/data/newport-20230613_gtfs.zip"), units="m"
+ ):
+ _is_gtfs_pth(pth=gtfs_pth, param_nm="gtfs_pth")
+
+ # validate units param
+ if not isinstance(units, str):
+ raise TypeError(f"`units` expected a string. Found {type(units)}")
+
+ units = units.lower().strip()
+ if units in ["metres", "meters"]:
+ units = "m"
+ elif units in ["kilometers", "kilometres"]:
+ units = "km"
+ accepted_units = ["m", "km"]
+
+ if units not in accepted_units:
+ raise ValueError(f"`units` accepts metric only. Found: {units}")
+
+ self.feed = gk.read_feed(gtfs_pth, dist_units=units)
+
+ def is_valid(self):
+ """Check a feed is valid with `gtfs_kit`.
+
+ Returns
+ -------
+ pd.core.frame.DataFrame: Table of errors, warnings & their
+ descriptions.
+
+ """
+ self.validity_df = self.feed.validate()
+ return self.validity_df
+
+ def print_alerts(self, alert_type="error"):
+ """Print validity errors & warnins messages in full.
+
+ Parameters
+ ----------
+ alert_type : str, optional
+ The alert type to print messages. Defaults to "error". Also
+ accepts "warning".
+
+ Returns
+ -------
+ None
+
+ """
+ if not hasattr(self, "validity_df"):
+ raise AttributeError(
+ "`self.validity_df` is None, did you forget to use "
+ "`self.is_valid()`?"
+ )
+
+ try:
+ # In cases where no alerts of alert_type are found, KeyError raised
+ msgs = (
+ self.validity_df.set_index("type")
+ .sort_index()
+ .loc[alert_type]["message"]
+ )
+ # multiple errors
+ if isinstance(msgs, pd.core.series.Series):
+ for m in msgs:
+ print(m)
+ # case where single error
+ elif isinstance(msgs, str):
+ print(msgs)
+ except KeyError:
+ print(f"No alerts of type {alert_type} were found.")
+
+ return None
+
+ def clean_feed(self):
+ """Attempt to clean feed using `gtfs_kit`."""
+ try:
+ # In cases where shape_id is missing, keyerror is raised.
+ # https://developers.google.com/transit/gtfs/reference#shapestxt
+ # shows that shapes.txt is optional file.
+ self.feed = self.feed.clean()
+ except KeyError:
+ print("KeyError. Feed was not cleaned.")
+
+ def viz_stops(
+ self, out_pth, geoms="point", geom_crs=27700, create_out_parent=False
+ ):
+ """Visualise the stops on a map as points or convex hull. Writes file.
+
+ Parameters
+ ----------
+ out_pth : str
+ Path to write the map file html document to, including the file
+ name. Must end with '.html' file extension.
+
+ geoms : str
+ Type of map to plot. If `geoms=point` (the default) uses `gtfs_kit`
+ to map point locations of available stops. If `geoms=hull`,
+ calculates the convex hull & its area. Defaults to "point".
+
+ geom_crs : (str, int)
+ Geometric CRS to use for the calculation of the convex hull area
+            only. Defaults to 27700 (OSGB36, British National Grid).
+
+ create_out_parent : bool
+ Should the parent directory of `out_pth` be created if not found.
+
+ Returns
+ -------
+ None
+
+ """
+ # out_pth defence
+ _check_parent_dir_exists(
+ pth=out_pth, param_nm="out_pth", create=create_out_parent
+ )
+
+ pre, ext = os.path.splitext(out_pth)
+ if ext != ".html":
+ print(f"{ext} format not implemented. Writing to .html")
+ out_pth = os.path.normpath(pre + ".html")
+
+ # geoms defence
+ if not isinstance(geoms, str):
+ raise TypeError(f"`geoms` expects a string. Found {type(geoms)}")
+ geoms = geoms.lower().strip()
+ accept_vals = ["point", "hull"]
+ if geoms not in accept_vals:
+ raise ValueError("`geoms` must be either 'point' or 'hull.'")
+
+ # geom_crs defence
+ if not isinstance(geom_crs, (str, int)):
+ raise TypeError(
+ f"`geom_crs` expects string or integer. Found {type(geom_crs)}"
+ )
+
+ try:
+ # map_stops will fail if stop_code not present. According to :
+ # https://developers.google.com/transit/gtfs/reference#stopstxt
+ # This should be an optional column
+ if geoms == "point":
+ # viz stop locations
+ m = self.feed.map_stops(self.feed.stops["stop_id"])
+ elif geoms == "hull":
+ # visualise feed, output to file with area est, based on stops
+ gtfs_hull = self.feed.compute_convex_hull()
+ gdf = gpd.GeoDataFrame(
+ {"geometry": gtfs_hull}, index=[0], crs="epsg:4326"
+ )
+ units = self.feed.dist_units
+ # prepare the map title
+ txt = _create_map_title_text(gdf, units, geom_crs)
+
+ title_pre = "
"
+ title_html = f"{title_pre}{txt}
"
+
+ gtfs_centroid = self.feed.compute_centroid()
+ m = folium.Map(
+ location=[gtfs_centroid.y, gtfs_centroid.x], zoom_start=5
+ )
+ geo_j = gdf.to_json()
+ geo_j = folium.GeoJson(
+ data=geo_j, style_function=lambda x: {"fillColor": "red"}
+ )
+ geo_j.add_to(m)
+ m.get_root().html.add_child(folium.Element(title_html))
+ m.save(out_pth)
+ except KeyError:
+ print("Key Error. Map was not written.")
+
+ def _order_dataframe_by_day(
+ self, df: pd.DataFrame, day_column_name: str = "day"
+ ) -> pd.DataFrame:
+ """Order a dataframe by days of the week in real-world order.
+
+ Parameters
+ ----------
+ df : pd.DataFrame
+ Input dataframe containing a column with the name
+ of the day of a record
+ day_column_name : str, optional
+ The name of the columns in the pandas dataframe
+ that contains the name of the day of a record,
+ by default "day"
+
+ Returns
+ -------
+ pd.DataFrame
+ The inputted dataframe ordered by the day column
+ (by real world order).
+
+ """
+ # defences for parameters
+ if not isinstance(df, pd.DataFrame):
+ raise TypeError(f"'df' expected type pd.DataFrame, got {type(df)}")
+ if not isinstance(day_column_name, str):
+ raise TypeError(
+ "'day_column_name' expected type str, "
+ f"got {type(day_column_name)}"
+ )
+
+ # hard coded day order
+ day_order = {
+ "monday": 0,
+ "tuesday": 1,
+ "wednesday": 2,
+ "thursday": 3,
+ "friday": 4,
+ "saturday": 5,
+ "sunday": 6,
+ }
+
+ # apply the day order and sort the df
+ df["day_order"] = (
+ df[day_column_name].str.lower().apply(lambda x: day_order[x])
+ )
+ df.sort_values("day_order", ascending=True, inplace=True)
+ df.sort_index(axis=1, inplace=True)
+ df.drop("day_order", inplace=True, axis=1)
+ return df
+
+ def _preprocess_trips_and_routes(self) -> pd.DataFrame:
+ """Create a trips table containing a record for each trip on each date.
+
+ Returns
+ -------
+ pd.DataFrame
+ A dataframe containing a record of every trip on every day
+ from the gtfs feed.
+
+ """
+ # create a calendar lookup (one row = one date, rather than date range)
+ calendar = self.feed.calendar.copy()
+ # convert dates to dt and create a list of dates between them
+ calendar["start_date"] = pd.to_datetime(
+ calendar["start_date"], format="%Y%m%d"
+ )
+ calendar["end_date"] = pd.to_datetime(
+ calendar["end_date"], format="%Y%m%d"
+ )
+ calendar["period"] = [
+ [start, end]
+ for start, end in zip(calendar["start_date"], calendar["end_date"])
+ ]
+
+ calendar["date"] = calendar["period"].apply(
+ lambda x: _get_intermediate_dates(x[0], x[1])
+ )
+ # explode the dataframe into daily rows for each service
+ # (between the given time period)
+ full_calendar = calendar.explode("date")
+ calendar.drop(
+ ["start_date", "end_date", "period"], axis=1, inplace=True
+ )
+
+ # obtain the day of a given date
+ full_calendar["day"] = full_calendar["date"].dt.day_name()
+ full_calendar["day"] = full_calendar["day"].apply(
+ lambda day: day.lower()
+ )
+
+ # reformat the data into a long format and only keep dates
+ # where the service is active.
+ # this ensures that dates are only kept when the service
+ # is running (e.g., Friday)
+ melted_calendar = full_calendar.melt(
+ id_vars=["date", "service_id", "day"], var_name="valid_day"
+ )
+ melted_calendar = melted_calendar[melted_calendar["value"] == 1]
+ melted_calendar = melted_calendar[
+ melted_calendar["day"] == melted_calendar["valid_day"]
+ ][["service_id", "day", "date"]]
+
+ # join the dates to the trip information to then join on the
+ # route_type from the route table
+ trips = self.feed.get_trips().copy()
+ routes = self.feed.get_routes().copy()
+ dated_trips = trips.merge(melted_calendar, on="service_id", how="left")
+ dated_trips_routes = dated_trips.merge(
+ routes, on="route_id", how="left"
+ )
+ return dated_trips_routes
+
+ def _get_pre_processed_trips(self):
+ """Obtain pre-processed trip data."""
+ try:
+ return self.pre_processed_trips.copy()
+ except AttributeError:
+ self.pre_processed_trips = self._preprocess_trips_and_routes()
+ return self.pre_processed_trips.copy()
+
+ def _summary_defence(
+ self,
+ summ_ops: list = [np.min, np.max, np.mean, np.median],
+ return_summary: bool = True,
+ ) -> None:
+ """Check for any invalid parameters in a summarising function.
+
+ Parameters
+ ----------
+ summ_ops : list, optional
+ A list of operators used to get a summary of a given day,
+ by default [np.min, np.max, np.mean, np.median]
+ return_summary : bool, optional
+ When True, a summary is returned. When False, route data
+ for each date is returned,
+ by default True
+
+ Returns
+ -------
+ None
+
+ """
+ if not isinstance(return_summary, bool):
+ raise TypeError(
+ "'return_summary' must be of type boolean."
+ f" Found {type(return_summary)} : {return_summary}"
+ )
+ # summ_ops defence
+
+ if isinstance(summ_ops, list):
+ for i in summ_ops:
+ # updated for numpy >= 1.25.0, this check rules out cases
+ # that are not functions
+ if inspect.isfunction(i) or type(i).__module__ == "numpy":
+ if not _check_namespace_export(pkg=np, func=i):
+ raise TypeError(
+ "Each item in `summ_ops` must be a numpy function."
+ f" Found {type(i)} : {i.__name__}"
+ )
+ else:
+ raise TypeError(
+ (
+ "Each item in `summ_ops` must be a function."
+ f" Found {type(i)} : {i}"
+ )
+ )
+ elif inspect.isfunction(summ_ops):
+ if not _check_namespace_export(pkg=np, func=summ_ops):
+ raise NotImplementedError(
+ "`summ_ops` expects numpy functions only."
+ )
+ else:
+ raise TypeError(
+ "`summ_ops` expects a numpy function or list of numpy"
+ f" functions. Found {type(summ_ops)}"
+ )
+
+ def summarise_trips(
+ self,
+ summ_ops: list = [np.min, np.max, np.mean, np.median],
+ return_summary: bool = True,
+ ) -> pd.DataFrame:
+ """Produce a summarised table of trip statistics by day of week.
+
+ For trip count summaries, func counts distinct trip_id only. These
+ are then summarised into average/median/min/max (default) number
+ of trips per day. Raw data for each date can also be obtained by
+ setting the 'return_summary' parameter to False (bool).
+
+ Parameters
+ ----------
+ summ_ops : list, optional
+ A list of operators used to get a summary of a given day,
+ by default [np.min, np.max, np.mean, np.median]
+ return_summary : bool, optional
+ When True, a summary is returned. When False, trip data
+ for each date is returned,
+ by default True
+
+ Returns
+ -------
+ pd.DataFrame: A dataframe containing either summarized
+ results or dated route data.
+
+ """
+ self._summary_defence(summ_ops=summ_ops, return_summary=return_summary)
+ pre_processed_trips = self._get_pre_processed_trips()
+
+ # clean the trips to ensure that there are no duplicates
+ cleaned_trips = pre_processed_trips[
+ ["date", "day", "trip_id", "route_type"]
+ ].drop_duplicates()
+ trip_counts = cleaned_trips.groupby(["date", "route_type"]).agg(
+ {"trip_id": "count", "day": "first"}
+ )
+ trip_counts.reset_index(inplace=True)
+ trip_counts.rename(
+ mapper={"trip_id": "trip_count"}, axis=1, inplace=True
+ )
+ self.dated_trip_counts = trip_counts.copy()
+ if not return_summary:
+ return self.dated_trip_counts
+
+ # aggregate to mean/median/min/max (default) trips on each day
+ # of the week
+ day_trip_counts = trip_counts.groupby(["day", "route_type"]).agg(
+ {"trip_count": summ_ops}
+ )
+ day_trip_counts.reset_index(inplace=True)
+ day_trip_counts = day_trip_counts.round(0)
+
+ # order the days (for plotting future purposes)
+        # order the days (for plotting future purposes)
+ day_trip_counts.reset_index(drop=True, inplace=True)
+ self.daily_trip_summary = day_trip_counts.copy()
+ return self.daily_trip_summary
+
+ def summarise_routes(
+ self,
+ summ_ops: list = [np.min, np.max, np.mean, np.median],
+ return_summary: bool = True,
+ ) -> pd.DataFrame:
+ """Produce a summarised table of route statistics by day of week.
+
+ For route count summaries, func counts route_id only, irrespective of
+ which service_id the routes map to. If the services run on different
+ calendar days, they will be counted separately. In cases where more
+ than one service runs the same route on the same day, these will not be
+ counted as distinct routes.
+
+ Parameters
+ ----------
+ summ_ops : list, optional
+ A list of operators used to get a summary of a given day,
+ by default [np.min, np.max, np.mean, np.median]
+ return_summary : bool, optional
+ When True, a summary is returned. When False, route data
+ for each date is returned,
+ by default True
+
+ Returns
+ -------
+ pd.DataFrame: A dataframe containing either summarized
+ results or dated route data.
+
+ """
+ self._summary_defence(summ_ops=summ_ops, return_summary=return_summary)
+ pre_processed_trips = self._get_pre_processed_trips()
+ cleaned_routes = pre_processed_trips[
+ ["route_id", "day", "date", "route_type"]
+ ].drop_duplicates()
+ # group data into route counts per day
+ route_count = (
+ cleaned_routes.groupby(["date", "route_type", "day"])
+ .agg(
+ {
+ "route_id": "count",
+ }
+ )
+ .reset_index()
+ )
+ route_count.rename(
+ mapper={"route_id": "route_count"}, axis=1, inplace=True
+ )
+ self.dated_route_counts = route_count.copy()
+
+ if not return_summary:
+ return self.dated_route_counts
+
+        # aggregate to the average number of routes
+ # on a given day (e.g., Monday)
+ day_route_count = (
+ route_count.groupby(["day", "route_type"])
+ .agg({"route_count": summ_ops})
+ .reset_index()
+ )
+
+ # order the days (for plotting future purposes)
+ day_route_count = self._order_dataframe_by_day(df=day_route_count)
+ day_route_count = day_route_count.round(0)
+ day_route_count.reset_index(drop=True, inplace=True)
+ self.daily_route_summary = day_route_count.copy()
+
+ return self.daily_route_summary
+
+ def get_route_modes(self):
+ """Summarise the available routes by their associated `route_type`.
+
+ Returns
+ -------
+ pd.core.frame.DataFrame: Summary table of route counts by transport
+ mode.
+
+ """
+ # Get the available modalities
+ lookup = scrape_route_type_lookup()
+ gtfs_route_types = [
+ str(x) for x in self.feed.routes["route_type"].unique()
+ ]
+ # Get readable route_type descriptions
+ out_tab = lookup[
+ lookup["route_type"].isin(gtfs_route_types)
+ ].reset_index(drop=True)
+ out_tab["n_routes"] = (
+ self.feed.routes["route_type"]
+ .value_counts()
+ .reset_index(drop=True)
+ )
+ out_tab["prop_routes"] = (
+ self.feed.routes["route_type"]
+ .value_counts(normalize=True)
+ .reset_index(drop=True)
+ )
+ self.route_mode_summary_df = out_tab
+ return self.route_mode_summary_df
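+
+
+# Example usage (a sketch, assuming the bundled Newport test fixture):
+# gtfs = GtfsInstance()
+# gtfs.is_valid()
+# gtfs.print_alerts(alert_type="warning")
+# gtfs.clean_feed()
+# print(gtfs.summarise_trips())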
diff --git a/src/transport_performance/utils/defence.py b/src/transport_performance/utils/defence.py
new file mode 100644
index 00000000..bf2322fc
--- /dev/null
+++ b/src/transport_performance/utils/defence.py
@@ -0,0 +1,160 @@
+"""Defensive check utility funcs. Internals only."""
+import pathlib
+import numpy as np
+import os
+
+
+def _is_path_like(pth, param_nm):
+ """Handle path-like parameter values.
+
+ Parameters
+ ----------
+ pth : (str, pathlib.PosixPath)
+ The path to check.
+
+ param_nm : str
+ The name of the parameter being tested.
+
+ Raises
+ ------
+ TypeError: `pth` is not either of string or pathlib.PosixPath.
+
+ Returns
+ -------
+ None
+
+ """
+ if not isinstance(pth, (str, pathlib.Path)):
+ raise TypeError(f"`{param_nm}` expected path-like, found {type(pth)}.")
+
+
+def _check_parent_dir_exists(pth, param_nm, create=False):
+    """Check the parent directory of `pth` exists. Not exported.
+
+    Raises FileNotFoundError if the parent is missing and `create` is False;
+    creates the parent directory when `create` is True.
+    """
+    _is_path_like(pth, param_nm)
+ parent = os.path.dirname(pth)
+ if not os.path.exists(parent):
+ if create:
+ os.mkdir(parent)
+ print(f"Creating parent directory: {parent}")
+ else:
+ raise FileNotFoundError(
+ f"Parent directory {parent} not found on disk."
+ )
+
+ return None
+
+
+def _is_gtfs_pth(pth, param_nm, check_existing=True):
+ """Handle file paths that should be existing GTFS feeds.
+
+ Parameters
+ ----------
+ pth : (str, pathlib.PosixPath)
+ The path to check.
+ param_nm : str
+ The name of the parameter being tested. Helps with debugging.
+ check_existing : bool
+ Whether to check if the GTFS file already exists. Defaults to True.
+
+ Raises
+ ------
+ TypeError: `pth` is not either of string or pathlib.PosixPath.
+ FileExistsError: `pth` does not exist on disk.
+ ValueError: `pth` does not have a `.zip` file extension.
+
+ Returns
+ -------
+ None
+
+ """
+ _is_path_like(pth=pth, param_nm=param_nm)
+
+ _, ext = os.path.splitext(pth)
+ if check_existing and not os.path.exists(pth):
+ raise FileExistsError(f"{pth} not found on file.")
+ if ext != ".zip":
+ raise ValueError(
+ f"`gtfs_pth` expected a zip file extension. Found {ext}"
+ )
+
+ return None
+
+
+def _check_namespace_export(pkg=np, func=np.min):
+ """Check that a function is exported from the specified namespace.
+
+ Parameters
+ ----------
+ pkg : module
+ The package to check. If imported as alias, must use alias. Defaults to
+ np.
+
+ func : function
+        The function to check is exported from pkg. Defaults to np.min.
+
+ Returns
+ -------
+ bool: True if func is exported from pkg namespace.
+
+ """
+ return hasattr(pkg, func.__name__)
+
+
+def _url_defence(url):
+ """Defence checking. Not exported."""
+ if not isinstance(url, str):
+ raise TypeError(f"url {url} expected string, instead got {type(url)}")
+ elif not url.startswith((r"http://", r"https://")):
+ raise ValueError(f"url string expected protocol, instead found {url}")
+
+ return None
+
+
+def _bool_defence(some_bool):
+ """Defence checking. Not exported."""
+ if not isinstance(some_bool, bool):
+ raise TypeError(
+ f"`extended_schema` expected boolean. Got {type(some_bool)}"
+ )
+
+ return None
+
+
+def _check_list(ls, param_nm, check_elements=True, exp_type=str):
+ """Check a list and its elements for type.
+
+ Parameters
+ ----------
+ ls : list
+ List to check.
+ param_nm : str
+ Name of the parameter being checked.
+ check_elements : (bool, optional)
+ Whether to check the list element types. Defaults to True.
+    exp_type : (type, optional)
+ The expected type of the elements. Defaults to str.
+
+ Raises
+ ------
+ TypeError: `ls` is not a list.
+ TypeError: Elements of `ls` are not of the expected type.
+
+ Returns
+ -------
+ None
+
+ """
+ if not isinstance(ls, list):
+ raise TypeError(
+ f"`{param_nm}` should be a list. Instead found {type(ls)}"
+ )
+ if check_elements:
+ for i in ls:
+ if not isinstance(i, exp_type):
+ raise TypeError(
+ (
+ f"`{param_nm}` must contain {str(exp_type)} only."
+ f" Found {type(i)} : {i}"
+ )
+ )
+
+ return None
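+
+
+# Example usage of these internals (a sketch):
+# _check_list([1.0, 2.0], param_nm="bbox_list", exp_type=float)  # passes
+# _check_list([1, "2"], param_nm="ls", exp_type=int)  # raises TypeError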
diff --git a/tests/data/gtfs/route_lookup.pkl b/tests/data/gtfs/route_lookup.pkl
new file mode 100644
index 00000000..504238a1
Binary files /dev/null and b/tests/data/gtfs/route_lookup.pkl differ
diff --git a/tests/gtfs/test_routes.py b/tests/gtfs/test_routes.py
new file mode 100644
index 00000000..38eaa23a
--- /dev/null
+++ b/tests/gtfs/test_routes.py
@@ -0,0 +1,127 @@
+"""Testing routes module."""
+import pytest
+import pandas as pd
+from pyprojroot import here
+from unittest.mock import call
+
+from transport_performance.gtfs.routes import scrape_route_type_lookup
+
+
+def mocked__get_response_text(*args):
+ """Mock _get_response_text.
+
+ Returns
+ -------
+ str: Minimal text representation of url tables.
+
+ """
+ k1 = "https://gtfs.org/schedule/reference/"
+ v1 = "
0 - Tram."
+ k2 = (
+ "https://developers.google.com/transit/gtfs/reference/"
+ "extended-route-types"
+ )
+ v2 = """
+
+
+ Code |
+ Description |
+ Supported |
+ Examples |
+
+
+ 100 |
+ Railway Service |
+ Yes |
+ Not applicable (N/A) |
+ """
+
+ return_vals = {k1: v1, k2: v2}
+ return return_vals[args[0]]
+
+
+class TestScrapeRouteTypeLookup(object):
+ """Test scrape_route_type_lookup."""
+
+ def test_defensive_exceptions(self):
+ """Test the defensive checks raise as expected."""
+ with pytest.raises(
+ TypeError,
+ match=r"url 1 expected string, instead got ",
+ ):
+ scrape_route_type_lookup(gtfs_url=1)
+ with pytest.raises(
+ TypeError,
+ match=r"url False expected string, instead got ",
+ ):
+ scrape_route_type_lookup(ext_spec_url=False)
+ with pytest.raises(
+ ValueError,
+ match="url string expected protocol, instead found foobar",
+ ):
+ scrape_route_type_lookup(gtfs_url="foobar")
+ with pytest.raises(
+ TypeError,
+ match=r"`extended_schema` expected boolean. Got ",
+ ):
+ scrape_route_type_lookup(extended_schema="True")
+
+ def test_table_without_extended_schema(self, mocker):
+ """Check the return object when extended_schema = False."""
+ patch_resp_txt = mocker.patch(
+ "transport_performance.gtfs.routes._get_response_text",
+ side_effect=mocked__get_response_text,
+ )
+ result = scrape_route_type_lookup(extended_schema=False)
+ # did the mocker get used
+ found = patch_resp_txt.call_args_list
+ assert found == [
+ call("https://gtfs.org/schedule/reference/")
+ ], f"Expected mocker was called with specific url but found: {found}"
+ assert isinstance(
+ result, pd.core.frame.DataFrame
+ ), f"Expected DF but found: {type(result)}"
+ pd.testing.assert_frame_equal(
+ result,
+ pd.DataFrame({"route_type": "0", "desc": "Tram."}, index=[0]),
+ )
+
+ def test_table_with_extended_schema(self, mocker):
+ """Check return table when extended schema = True."""
+ patch_resp_txt = mocker.patch(
+ "transport_performance.gtfs.routes._get_response_text",
+ side_effect=mocked__get_response_text,
+ )
+ result = scrape_route_type_lookup()
+ found = patch_resp_txt.call_args_list
+ assert found == [
+ call("https://gtfs.org/schedule/reference/"),
+ call(
+ (
+ "https://developers.google.com/transit/gtfs/reference/"
+ "extended-route-types"
+ )
+ ),
+ ], f"Expected mocker to be called with specific urls. Found: {found}"
+
+ assert isinstance(
+ result, pd.core.frame.DataFrame
+ ), f"Expected DF. Found: {type(result)}"
+ pd.testing.assert_frame_equal(
+ result,
+ pd.DataFrame(
+ {
+ "route_type": ["0", "100"],
+ "desc": ["Tram.", "Railway Service"],
+ },
+ index=[0, 1],
+ ),
+ )
+
+ @pytest.mark.runinteg
+ def test_lookup_is_stable(self):
+ """Check if the tables at the urls have changed content."""
+ # import the expected fixtures
+ lookup_fix = pd.read_pickle(here("tests/data/gtfs/route_lookup.pkl"))
+ lookup = scrape_route_type_lookup()
+ pd.testing.assert_frame_equal(lookup, lookup_fix)
diff --git a/tests/gtfs/test_utils.py b/tests/gtfs/test_utils.py
new file mode 100644
index 00000000..c5595ed1
--- /dev/null
+++ b/tests/gtfs/test_utils.py
@@ -0,0 +1,41 @@
+"""Test GTFS utility functions."""
+
+from pyprojroot import here
+import os
+import pytest
+
+from transport_performance.gtfs.utils import bbox_filter_gtfs
+from transport_performance.gtfs.validation import GtfsInstance
+
+
+class TestBboxFilterGtfs(object):
+ """Test bbox_filter_gtfs."""
+
+ def test_bbox_filter_gtfs_defence(self):
+ """Check defensive behaviour for bbox_filter_gtfs."""
+ with pytest.raises(
+ TypeError, match="Expected string. Found : False"
+ ):
+ bbox_filter_gtfs(units=False)
+
+ def test_bbox_filter_gtfs_writes_as_expected(self, tmpdir):
+ """Test bbox_filter_gtfs writes out a filtered GTFS archive."""
+ tmp_out = os.path.join(tmpdir, "newport-train-station_gtfs.zip")
+ bbox_filter_gtfs(
+ in_pth=here("tests/data/newport-20230613_gtfs.zip"),
+ out_pth=tmp_out,
+ bbox_list=[
+ -3.0017783334,
+ 51.5874718209,
+ -2.9964692194,
+ 51.5907034241,
+ ], # tiny bounding box over newport train station
+ )
+ assert os.path.exists(
+ tmp_out
+ ), f"Expected {tmp_out} to exist but it did not."
+ # check the output gtfs can be read
+ feed = GtfsInstance(gtfs_pth=tmp_out)
+ assert isinstance(
+ feed, GtfsInstance
+ ), f"Expected class `Gtfs_Instance but found: {type(feed)}`"
diff --git a/tests/gtfs/test_validation.py b/tests/gtfs/test_validation.py
new file mode 100644
index 00000000..79412ebc
--- /dev/null
+++ b/tests/gtfs/test_validation.py
@@ -0,0 +1,571 @@
+"""Tests for validation module."""
+import pytest
+from pyprojroot import here
+import gtfs_kit as gk
+import pandas as pd
+from unittest.mock import patch, call
+import os
+from geopandas import GeoDataFrame
+import numpy as np
+import re
+
+from transport_performance.gtfs.validation import (
+ GtfsInstance,
+ _get_intermediate_dates,
+ _create_map_title_text,
+)
+
+
+@pytest.fixture(scope="function") # some funcs expect cleaned feed others dont
+def gtfs_fixture():
+ """Fixture for test funcs expecting a valid feed object."""
+ gtfs = GtfsInstance()
+ return gtfs
+
+
+class TestGtfsInstance(object):
+ """Tests related to the GtfsInstance class."""
+
+ def test_init_defensive_behaviours(self):
+ """Testing parameter validation on class initialisation."""
+ with pytest.raises(
+ TypeError,
+ match=r"`gtfs_pth` expected path-like, found .",
+ ):
+ GtfsInstance(gtfs_pth=1)
+ with pytest.raises(
+ FileExistsError, match=r"doesnt/exist not found on file."
+ ):
+ GtfsInstance(gtfs_pth="doesnt/exist")
+ # a case where file is found but not a zip directory
+ with pytest.raises(
+ ValueError,
+ match=r"`gtfs_pth` expected a zip file extension. Found .pbf",
+ ):
+ GtfsInstance(
+ gtfs_pth=here("tests/data/newport-2023-06-13.osm.pbf")
+ )
+ # handling units
+ with pytest.raises(
+ TypeError, match=r"`units` expected a string. Found "
+ ):
+ GtfsInstance(units=False)
+ # non metric units
+ with pytest.raises(
+ ValueError, match=r"`units` accepts metric only. Found: miles"
+ ):
+ GtfsInstance(units="Miles") # imperial units not implemented
+
+ def test_init_on_pass(self):
+ """Assertions about the feed attribute."""
+ gtfs = GtfsInstance()
+ assert isinstance(
+ gtfs.feed, gk.feed.Feed
+ ), f"GExpected gtfs_kit feed attribute. Found: {type(gtfs.feed)}"
+ assert (
+ gtfs.feed.dist_units == "m"
+ ), f"Expected 'm', found: {gtfs.feed.dist_units}"
+ # can coerce to correct distance unit?
+ gtfs1 = GtfsInstance(units="kilometers")
+ assert (
+ gtfs1.feed.dist_units == "km"
+ ), f"Expected 'km', found: {gtfs1.feed.dist_units}"
+ gtfs2 = GtfsInstance(units="metres")
+ assert (
+ gtfs2.feed.dist_units == "m"
+ ), f"Expected 'm', found: {gtfs2.feed.dist_units}"
+
+ def test_is_valid(self, gtfs_fixture):
+ """Assertions about validity_df table."""
+ gtfs_fixture.is_valid()
+ assert isinstance(
+ gtfs_fixture.validity_df, pd.core.frame.DataFrame
+ ), f"Expected DataFrame. Found: {type(gtfs_fixture.validity_df)}"
+ shp = gtfs_fixture.validity_df.shape
+ assert shp == (
+ 7,
+ 4,
+ ), f"Attribute `validity_df` expected a shape of (7,4). Found: {shp}"
+ exp_cols = pd.Index(["type", "message", "table", "rows"])
+ found_cols = gtfs_fixture.validity_df.columns
+ assert (
+ found_cols == exp_cols
+ ).all(), f"Expected columns {exp_cols}. Found: {found_cols}"
+
+ @patch("builtins.print")
+ def test_print_alerts_defence(self, mocked_print, gtfs_fixture):
+ """Check defensive behaviour of print_alerts()."""
+ with pytest.raises(
+ AttributeError,
+ match=r"is None, did you forget to use `self.is_valid()`?",
+ ):
+ gtfs_fixture.print_alerts()
+
+ gtfs_fixture.is_valid()
+ gtfs_fixture.print_alerts(alert_type="doesnt_exist")
+ fun_out = mocked_print.mock_calls
+ assert fun_out == [
+ call("No alerts of type doesnt_exist were found.")
+ ], f"Expected a print about alert_type but found: {fun_out}"
+
+ @patch("builtins.print") # testing print statements
+ def test_print_alerts_single_case(self, mocked_print, gtfs_fixture):
+ """Check alerts print as expected without truncation."""
+ gtfs_fixture.is_valid()
+ gtfs_fixture.print_alerts()
+ # fixture contains single error
+ fun_out = mocked_print.mock_calls
+ assert fun_out == [
+ call("Invalid route_type; maybe has extra space characters")
+ ], f"Expected a print about invalid route type. Found {fun_out}"
+
+ @patch("builtins.print")
+ def test_print_alerts_multi_case(self, mocked_print, gtfs_fixture):
+ """Check multiple alerts are printed as expected."""
+ gtfs_fixture.is_valid()
+ # fixture contains several warnings
+ gtfs_fixture.print_alerts(alert_type="warning")
+ fun_out = mocked_print.mock_calls
+ assert fun_out == [
+ call("Unrecognized column agency_noc"),
+ call("Repeated pair (route_short_name, route_long_name)"),
+ call("Unrecognized column stop_direction_name"),
+ call("Unrecognized column platform_code"),
+ call("Unrecognized column trip_direction_name"),
+ call("Unrecognized column vehicle_journey_code"),
+ ], f"Expected print statements about GTFS warnings. Found: {fun_out}"
+
+ @patch("builtins.print")
+ def test_viz_stops_defence(self, mocked_print, gtfs_fixture):
+ """Check defensive behaviours of viz_stops()."""
+ with pytest.raises(
+ TypeError,
+ match="`out_pth` expected path-like, found ",
+ ):
+ gtfs_fixture.viz_stops(out_pth=True)
+ with pytest.raises(
+ TypeError, match="`geoms` expects a string. Found "
+ ):
+ gtfs_fixture.viz_stops(out_pth="outputs/somefile.html", geoms=38)
+ with pytest.raises(
+ ValueError, match="`geoms` must be either 'point' or 'hull."
+ ):
+ gtfs_fixture.viz_stops(
+ out_pth="outputs/somefile.html", geoms="foobar"
+ )
+ with pytest.raises(
+ TypeError,
+ match="`geom_crs`.*string or integer. Found ",
+ ):
+ gtfs_fixture.viz_stops(
+ out_pth="outputs/somefile.html", geom_crs=1.1
+ )
+ # check missing stop_id results in print instead of exception
+ gtfs_fixture.feed.stops.drop("stop_id", axis=1, inplace=True)
+ gtfs_fixture.viz_stops(out_pth="outputs/out.html")
+ fun_out = mocked_print.mock_calls
+ assert fun_out == [
+ call("Key Error. Map was not written.")
+ ], f"Expected confirmation that map was not written. Found: {fun_out}"
+
+ @patch("builtins.print")
+ def test_viz_stops_point(self, mock_print, tmpdir, gtfs_fixture):
+ """Check behaviour of viz_stops when plotting point geom."""
+ tmp = os.path.join(tmpdir, "points.html")
+ gtfs_fixture.viz_stops(out_pth=tmp)
+ assert os.path.exists(
+ tmp
+ ), f"{tmp} was expected to exist but it was not found."
+ # check behaviour when parent directory doesn't exist
+ no_parent_pth = os.path.join(tmpdir, "notfound", "points1.html")
+ gtfs_fixture.viz_stops(out_pth=no_parent_pth, create_out_parent=True)
+ assert os.path.exists(
+ no_parent_pth
+ ), f"{no_parent_pth} was expected to exist but it was not found."
+ # check behaviour when not implemented fileext used
+ tmp1 = os.path.join(tmpdir, "points2.svg")
+ gtfs_fixture.viz_stops(out_pth=tmp1)
+ # need to use regex for the first print statement, as tmpdir will
+ # change.
+ start_pat = re.compile(r"Creating parent directory:.*")
+ out = mock_print.mock_calls[0].__str__()
+ assert bool(
+ start_pat.search(out)
+ ), f"Print statement about directory creation expected. Found: {out}"
+ out_last = mock_print.mock_calls[-1]
+ assert out_last == call(
+ ".svg format not implemented. Writing to .html"
+ ), f"Expected print statement about .svg. Found: {out_last}"
+ write_pth = os.path.join(tmpdir, "points2.html")
+ assert os.path.exists(
+ write_pth
+ ), f"Map should have been written to {write_pth} but was not found."
+
+ def test_viz_stops_hull(self, tmpdir, gtfs_fixture):
+ """Check viz_stops behaviour when plotting hull geom."""
+ tmp = os.path.join(tmpdir, "hull.html")
+ gtfs_fixture.viz_stops(out_pth=tmp, geoms="hull")
+ assert os.path.exists(
+ tmp
+ ), f"Map should have been written to {tmp} but was not found."
+
+ def test__create_map_title_text(self):
+ """Check helper can cope with non-metric cases."""
+ gdf = GeoDataFrame()
+ txt = _create_map_title_text(gdf=gdf, units="miles", geom_crs=27700)
+ assert txt == (
+ "GTFS Stops Convex Hull. Area Calculation for Metric Units Only. "
+ "Units Found are in miles."
+ ), f"Unexpected text output: {txt}"
+
+ def test__get_intermediate_dates(self):
+ """Check function can handle valid and invalid arguments."""
+ # invalid arguments
+ with pytest.raises(
+ TypeError,
+ match="'start' expected type pd.Timestamp."
+ " Recieved type ",
+ ):
+ _get_intermediate_dates(
+ start="2023-05-02", end=pd.Timestamp("2023-05-08")
+ )
+ with pytest.raises(
+ TypeError,
+ match="'end' expected type pd.Timestamp."
+ " Recieved type ",
+ ):
+ _get_intermediate_dates(
+ start=pd.Timestamp("2023-05-02"), end="2023-05-08"
+ )
+
+ # valid arguments
+ dates = _get_intermediate_dates(
+ pd.Timestamp("2023-05-01"), pd.Timestamp("2023-05-08")
+ )
+ assert dates == [
+ pd.Timestamp("2023-05-01"),
+ pd.Timestamp("2023-05-02"),
+ pd.Timestamp("2023-05-03"),
+ pd.Timestamp("2023-05-04"),
+ pd.Timestamp("2023-05-05"),
+ pd.Timestamp("2023-05-06"),
+ pd.Timestamp("2023-05-07"),
+ pd.Timestamp("2023-05-08"),
+ ]
+
+ def test__order_dataframe_by_day_defence(self, gtfs_fixture):
+ """Test __order_dataframe_by_day defences."""
+ with pytest.raises(
+ TypeError,
+ match="'df' expected type pd.DataFrame, got ",
+ ):
+ (gtfs_fixture._order_dataframe_by_day(df="test"))
+ with pytest.raises(
+ TypeError,
+ match="'day_column_name' expected type str, got ",
+ ):
+ (
+ gtfs_fixture._order_dataframe_by_day(
+ df=pd.DataFrame(), day_column_name=5
+ )
+ )
+
+ def test_get_route_modes(self, gtfs_fixture, mocker):
+ """Assertions about the table returned by get_route_modes()."""
+ patch_scrape_lookup = mocker.patch(
+ "transport_performance.gtfs.validation.scrape_route_type_lookup",
+ # be sure to patch the func wherever it's being called
+ return_value=pd.DataFrame(
+ {"route_type": ["3"], "desc": ["Mocked bus"]}
+ ),
+ )
+ gtfs_fixture.get_route_modes()
+ # check mocker was called
+ assert (
+ patch_scrape_lookup.called
+ ), "mocker.patch `patch_scrape_lookup` was not called."
+ found = gtfs_fixture.route_mode_summary_df["desc"][0]
+ assert found == "Mocked bus", f"Expected 'Mocked bus', found: {found}"
+ assert isinstance(
+ gtfs_fixture.route_mode_summary_df, pd.core.frame.DataFrame
+ ), f"Expected pd df. Found: {type(gtfs_fixture.route_mode_summary_df)}"
+ exp_cols = pd.Index(["route_type", "desc", "n_routes", "prop_routes"])
+ found_cols = gtfs_fixture.route_mode_summary_df.columns
+ assert (
+ found_cols == exp_cols
+ ).all(), f"Expected columns are different. Found: {found_cols}"
+
+ def test__preprocess_trips_and_routes(self, gtfs_fixture):
+ """Check the outputs of _pre_process_trips_and_route() (test data)."""
+ returned_df = gtfs_fixture._preprocess_trips_and_routes()
+ assert isinstance(returned_df, pd.core.frame.DataFrame), (
+ "Expected DF for _preprocess_trips_and_routes() return,"
+ f"found {type(returned_df)}"
+ )
+ expected_columns = pd.Index(
+ [
+ "route_id",
+ "service_id",
+ "trip_id",
+ "trip_headsign",
+ "block_id",
+ "shape_id",
+ "wheelchair_accessible",
+ "trip_direction_name",
+ "vehicle_journey_code",
+ "day",
+ "date",
+ "agency_id",
+ "route_short_name",
+ "route_long_name",
+ "route_type",
+ ]
+ )
+        assert (returned_df.columns == expected_columns).all(), (
+            f"Columns not as expected. Expected {expected_columns},"
+            f" Found {returned_df.columns}"
+        )
+ expected_shape = (281627, 15)
+        assert returned_df.shape == expected_shape, (
+            f"Shape not as expected. Expected {expected_shape},"
+            f" Found {returned_df.shape}"
+        )
+
+ def test_summarise_trips_defence(self, gtfs_fixture):
+ """Defensive checks for summarise_trips()."""
+ with pytest.raises(
+ TypeError,
+ match="Each item in `summ_ops`.*. Found : np.mean",
+ ):
+ gtfs_fixture.summarise_trips(summ_ops=[np.mean, "np.mean"])
+ # case where is function but not exported from numpy
+
+ def dummy_func():
+ """Test case func."""
+ return None
+
+ with pytest.raises(
+ TypeError,
+ match=(
+ "Each item in `summ_ops` must be a numpy function. Found"
+ " : dummy_func"
+ ),
+ ):
+ gtfs_fixture.summarise_trips(summ_ops=[np.min, dummy_func])
+ # case where a single non-numpy func is being passed
+ with pytest.raises(
+ NotImplementedError,
+ match="`summ_ops` expects numpy functions only.",
+ ):
+ gtfs_fixture.summarise_trips(summ_ops=dummy_func)
+ with pytest.raises(
+ TypeError,
+ match="`summ_ops` expects a numpy function.*. Found ",
+ ):
+ gtfs_fixture.summarise_trips(summ_ops=38)
+ # cases where return_summary are not of type boolean
+ with pytest.raises(
+ TypeError,
+ match="'return_summary' must be of type boolean."
+ " Found : 5",
+ ):
+ gtfs_fixture.summarise_trips(return_summary=5)
+ with pytest.raises(
+ TypeError,
+ match="'return_summary' must be of type boolean."
+ " Found : true",
+ ):
+ gtfs_fixture.summarise_trips(return_summary="true")
+
+ def test_summarise_routes_defence(self, gtfs_fixture):
+ """Defensive checks for summarise_routes()."""
+ with pytest.raises(
+ TypeError,
+ match="Each item in `summ_ops`.*. Found : np.mean",
+ ):
+ gtfs_fixture.summarise_trips(summ_ops=[np.mean, "np.mean"])
+ # case where is function but not exported from numpy
+
+ def dummy_func():
+ """Test case func."""
+ return None
+
+ with pytest.raises(
+ TypeError,
+ match=(
+ "Each item in `summ_ops` must be a numpy function. Found"
+ " : dummy_func"
+ ),
+ ):
+ gtfs_fixture.summarise_routes(summ_ops=[np.min, dummy_func])
+ # case where a single non-numpy func is being passed
+ with pytest.raises(
+ NotImplementedError,
+ match="`summ_ops` expects numpy functions only.",
+ ):
+ gtfs_fixture.summarise_routes(summ_ops=dummy_func)
+ with pytest.raises(
+ TypeError,
+ match="`summ_ops` expects a numpy function.*. Found ",
+ ):
+ gtfs_fixture.summarise_routes(summ_ops=38)
+ # cases where return_summary are not of type boolean
+ with pytest.raises(
+ TypeError,
+ match="'return_summary' must be of type boolean."
+ " Found : 5",
+ ):
+ gtfs_fixture.summarise_routes(return_summary=5)
+ with pytest.raises(
+ TypeError,
+ match="'return_summary' must be of type boolean."
+ " Found : true",
+ ):
+ gtfs_fixture.summarise_routes(return_summary="true")
+
+ @patch("builtins.print")
+ def test_clean_feed_defence(self, mock_print, gtfs_fixture):
+ """Check defensive behaviours of clean_feed()."""
+ # Simulate condition where shapes.txt has no shape_id
+ gtfs_fixture.feed.shapes.drop("shape_id", axis=1, inplace=True)
+ gtfs_fixture.clean_feed()
+ fun_out = mock_print.mock_calls
+ assert fun_out == [
+ call("KeyError. Feed was not cleaned.")
+ ], f"Expected print statement about KeyError. Found: {fun_out}."
+
+ def test_summarise_trips_on_pass(self, gtfs_fixture):
+ """Assertions about the outputs from summarise_trips()."""
+ gtfs_fixture.summarise_trips()
+ # tests the daily_routes_summary return schema
+ assert isinstance(
+ gtfs_fixture.daily_trip_summary, pd.core.frame.DataFrame
+ ), (
+ "Expected DF for daily_summary,"
+ f"found {type(gtfs_fixture.daily_trip_summary)}"
+ )
+
+ found_ds = gtfs_fixture.daily_trip_summary.columns
+ exp_cols_ds = pd.MultiIndex.from_tuples(
+ [
+ ("day", ""),
+ ("route_type", ""),
+ ("trip_count", "max"),
+ ("trip_count", "mean"),
+ ("trip_count", "median"),
+ ("trip_count", "min"),
+ ]
+ )
+
+ assert (
+ found_ds == exp_cols_ds
+ ).all(), f"Columns were not as expected. Found {found_ds}"
+
+ # tests the self.dated_route_counts return schema
+ assert isinstance(
+ gtfs_fixture.dated_trip_counts, pd.core.frame.DataFrame
+ ), (
+ "Expected DF for dated_route_counts,"
+ f"found {type(gtfs_fixture.dated_trip_counts)}"
+ )
+
+ found_drc = gtfs_fixture.dated_trip_counts.columns
+ exp_cols_drc = pd.Index(["date", "route_type", "trip_count", "day"])
+
+ assert (
+ found_drc == exp_cols_drc
+ ).all(), f"Columns were not as expected. Found {found_drc}"
+
+ # tests the output of the daily_route_summary table
+ # using tests/data/newport-20230613_gtfs.zip
+ expected_df = {
+ ("day", ""): {8: "friday", 9: "friday"},
+ ("route_type", ""): {8: 3, 9: 200},
+ ("trip_count", "max"): {8: 1211, 9: 90},
+ ("trip_count", "mean"): {8: 1211.0, 9: 88.0},
+ ("trip_count", "median"): {8: 1211.0, 9: 88.0},
+ ("trip_count", "min"): {8: 1211, 9: 88},
+ }
+
+ found_df = gtfs_fixture.daily_trip_summary[
+ gtfs_fixture.daily_trip_summary["day"] == "friday"
+ ].to_dict()
+ assert (
+ found_df == expected_df
+ ), f"Daily summary not as expected. Found {found_df}"
+
+ # test that the dated_trip_counts can be returned
+ expected_size = (542, 4)
+ found_size = gtfs_fixture.summarise_trips(return_summary=False).shape
+        assert expected_size == found_size, (
+            "Size of dated_trip_counts not as expected. "
+            f"Expected {expected_size}"
+        )
+
+ def test_summarise_routes_on_pass(self, gtfs_fixture):
+ """Assertions about the outputs from summarise_routes()."""
+ gtfs_fixture.summarise_routes()
+ # tests the daily_routes_summary return schema
+ assert isinstance(
+ gtfs_fixture.daily_route_summary, pd.core.frame.DataFrame
+ ), (
+ "Expected DF for daily_summary,"
+ f"found {type(gtfs_fixture.daily_route_summary)}"
+ )
+
+ found_ds = gtfs_fixture.daily_route_summary.columns
+ exp_cols_ds = pd.MultiIndex.from_tuples(
+ [
+ ("day", ""),
+ ("route_count", "max"),
+ ("route_count", "mean"),
+ ("route_count", "median"),
+ ("route_count", "min"),
+ ("route_type", ""),
+ ]
+ )
+
+ assert (
+ found_ds == exp_cols_ds
+ ).all(), f"Columns were not as expected. Found {found_ds}"
+
+ # tests the self.dated_route_counts return schema
+ assert isinstance(
+ gtfs_fixture.dated_route_counts, pd.core.frame.DataFrame
+ ), (
+ "Expected DF for dated_route_counts,"
+ f"found {type(gtfs_fixture.dated_route_counts)}"
+ )
+
+ found_drc = gtfs_fixture.dated_route_counts.columns
+ exp_cols_drc = pd.Index(["date", "route_type", "day", "route_count"])
+
+ assert (
+ found_drc == exp_cols_drc
+ ).all(), f"Columns were not as expected. Found {found_drc}"
+
+ # tests the output of the daily_route_summary table
+ # using tests/data/newport-20230613_gtfs.zip
+ expected_df = {
+ ("day", ""): {8: "friday", 9: "friday"},
+ ("route_count", "max"): {8: 74, 9: 10},
+ ("route_count", "mean"): {8: 74.0, 9: 9.0},
+ ("route_count", "median"): {8: 74.0, 9: 9.0},
+ ("route_count", "min"): {8: 74, 9: 9},
+ ("route_type", ""): {8: 3, 9: 200},
+ }
+
+ found_df = gtfs_fixture.daily_route_summary[
+ gtfs_fixture.daily_route_summary["day"] == "friday"
+ ].to_dict()
+ assert (
+ found_df == expected_df
+ ), f"Daily summary not as expected. Found {found_df}"
+
+ # test that the dated_route_counts can be returned
+ expected_size = (542, 4)
+ found_size = gtfs_fixture.summarise_routes(return_summary=False).shape
+        assert expected_size == found_size, (
+            "Size of dated_route_counts not as expected. "
+            f"Expected {expected_size}"
+        )
diff --git a/tests/test_setup.py b/tests/test_setup.py
index 63280a26..26bb6880 100644
--- a/tests/test_setup.py
+++ b/tests/test_setup.py
@@ -1,9 +1,9 @@
"""test_setup.py.
Unit tests for testing initial setup. The intention is these tests won't be
-part of the main test suite, and will only by run as needed.
+part of the main test suite, and will only be run as needed.
-TODO: make this run 'on request' only.
+This test module can be run with the pytest flag --runsetup.
"""
import pytest
diff --git a/tests/utils/test_defence.py b/tests/utils/test_defence.py
new file mode 100644
index 00000000..2c2f8fdf
--- /dev/null
+++ b/tests/utils/test_defence.py
@@ -0,0 +1,62 @@
+"""Tests for defence.py. These internals may be covered elsewhere."""
+import pytest
+
+from transport_performance.utils.defence import (
+ _check_list,
+ _check_parent_dir_exists,
+)
+
+
+class Test_CheckList(object):
+ """Test internal _check_list."""
+
+ def test__check_list_only(self):
+ """Func raises as expected when not checking list elements."""
+ with pytest.raises(
+ TypeError,
+ match="`some_bool` should be a list. Instead found ",
+ ):
+ _check_list(ls=True, param_nm="some_bool", check_elements=False)
+
+ def test__check_list_elements(self):
+ """Func raises as expected when checking list elements."""
+ with pytest.raises(
+ TypeError,
+ match=(
+ "`mixed_list` must contain only. Found "
+ " : 2"
+ ),
+ ):
+ _check_list(
+ ls=[1, "2", 3],
+ param_nm="mixed_list",
+ check_elements=True,
+ exp_type=int,
+ )
+
+ def test__check_list_passes(self):
+ """Test returns None when pass conditions met."""
+ assert (
+ _check_list(ls=[1, 2, 3], param_nm="int_list", exp_type=int)
+ is None
+ )
+ assert (
+ _check_list(
+ ls=[False, True], param_nm="bool_list", check_elements=False
+ )
+ is None
+ )
+
+
+class Test_CheckParentDirExists(object):
+ """Assertions for check_parent_dir_exists."""
+
+ def test_check_parent_dir_exists_defence(self):
+ """Check defence for _check_parent_dir_exists()."""
+ with pytest.raises(
+ FileNotFoundError,
+ match="Parent directory missing not found on disk.",
+ ):
+ _check_parent_dir_exists(
+ pth="missing/file.someext", param_nm="not_found", create=False
+ )
|