From ae3742e8abc30fdd9b3888a79ce1f051b861c7ac Mon Sep 17 00:00:00 2001 From: Renata Imai Date: Fri, 7 Jun 2024 09:29:47 -0300 Subject: [PATCH 01/18] Update transit.py --- aequilibrae/transit/transit.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/aequilibrae/transit/transit.py b/aequilibrae/transit/transit.py index 254594f13..6d07bdf0c 100644 --- a/aequilibrae/transit/transit.py +++ b/aequilibrae/transit/transit.py @@ -48,8 +48,6 @@ def new_gtfs_builder(self, agency, file_path, day="", description="") -> GTFSRou """Returns a ``GTFSRouteSystemBuilder`` object compatible with the project :Arguments: - **agency** (:obj:`str`): Name for the agency this feed refers to (e.g. 'CTA') - **file_path** (:obj:`str`): Full path to the GTFS feed (e.g. 'D:/project/my_gtfs_feed.zip') **day** (:obj:`str`, *Optional*): Service data contained in this field to be imported (e.g. '2019-10-04') From 174f439ab5c739595f60f296be934c58e0c7bf4c Mon Sep 17 00:00:00 2001 From: Renata Imai Date: Thu, 13 Jun 2024 12:37:54 -0300 Subject: [PATCH 02/18] modifies files --- aequilibrae/transit/column_order.py | 12 ++--- aequilibrae/transit/constants.py | 3 +- aequilibrae/transit/gtfs_loader.py | 55 ++++++++++++++------ aequilibrae/transit/lib_gtfs.py | 39 ++++++-------- aequilibrae/transit/map_matching_graph.py | 2 +- aequilibrae/transit/transit.py | 3 +- aequilibrae/transit/transit_elements/stop.py | 25 +++++---- 7 files changed, 79 insertions(+), 60 deletions(-) diff --git a/aequilibrae/transit/column_order.py b/aequilibrae/transit/column_order.py index 2680b7427..45870606a 100644 --- a/aequilibrae/transit/column_order.py +++ b/aequilibrae/transit/column_order.py @@ -7,10 +7,10 @@ ("agency_name", str), ("agency_url", str), ("agency_timezone", str), - ("agency_lang", str), - ("agency_phone", str), - ("agency_fare_url", str), - ("agency_email", str), + # ("agency_lang", str), + # ("agency_phone", str), + # ("agency_fare_url", str), + # ("agency_email", str), ] ), "routes.txt": OrderedDict( @@ -24,7 +24,7 @@ # ("route_color", str), # ("route_text_color", str), # ("route_sort_order", int), - # ("agency_id", str), + ("agency_id", str), ] ), "trips.txt": OrderedDict( @@ -76,7 +76,7 @@ ("currency_type", str), ("payment_method", int), ("transfers", int), - # ("agency_id", str), + ("agency_id", str), ("transfer_duration", float), ] ), diff --git a/aequilibrae/transit/constants.py b/aequilibrae/transit/constants.py index 8ceaecc84..fdfce006e 100644 --- a/aequilibrae/transit/constants.py +++ b/aequilibrae/transit/constants.py @@ -9,6 +9,7 @@ WALK_LINK_RANGE = 30000000 TRANSIT_LINK_RANGE = 20000000 WALK_AGENCY_ID = 1 +STOP_ID = 1 # 1 for right, -1 for wrong (left) DRIVING_SIDE = 1 @@ -21,7 +22,7 @@ class Constants: trips: Dict[int, int] = {} patterns: Dict[int, int] = {} pattern_lookup: Dict[int, int] = {} - stops: Dict[int, int] = {} + stops: Dict[int, Any] = {} fares: Dict[int, int] = {} links: Dict[int, int] = {} transit_links: Dict[int, int] = {} diff --git a/aequilibrae/transit/gtfs_loader.py b/aequilibrae/transit/gtfs_loader.py index eb3c17a52..394337f79 100644 --- a/aequilibrae/transit/gtfs_loader.py +++ b/aequilibrae/transit/gtfs_loader.py @@ -45,7 +45,7 @@ def __init__(self): self.__capacities__ = {} self.__max_speeds__ = {} self.feed_date = "" - self.agency = Agency() + self.agency: Dict[int, Agency] = {} self.services = {} self.routes: Dict[int, Route] = {} self.trips: Dict[int, Dict[Route]] = {} @@ -80,7 +80,7 @@ def set_feed_path(self, file_path): self.zip_archive.close() self.feed_date = splitext(basename(file_path))[0] - self.__mt = f"Reading GTFS for {self.agency.agency}" + self.__mt = "Reading GTFS" def _set_capacities(self, capacities: dict): self.__capacities__ = capacities @@ -94,10 +94,10 @@ def load_data(self, service_date: str): :Arguments: **service_date** (:obj:`str`): service date. e.g. "2020-04-01". """ - ag_id = self.agency.agency - self.logger.info(f"Loading data for {service_date} from the {ag_id} GTFS feed. This may take some time") + self.service_date = service_date + self.logger.info(f"Loading data for {self.service_date} from the GTFS feed. This may take some time") - self.__mt = f"Reading GTFS for {ag_id}" + self.__mt = "Reading GTFS" self.signal.emit(["start", "master", 6, self.__mt, self.__mt]) self.__load_date() @@ -110,6 +110,8 @@ def finished(self): def __load_date(self): self.logger.debug("Starting __load_date") self.zip_archive = zipfile.ZipFile(self.archive_dir) + + self.__load_agencies() self.__load_routes_table() self.__load_stops_table() self.__load_stop_times() @@ -123,7 +125,7 @@ def __load_date(self): def __deconflict_stop_times(self) -> None: self.logger.info("Starting deconflict_stop_times") - msg_txt = f"Interpolating stop times for {self.agency.agency}" + msg_txt = "Interpolating stop times for feed agencies" self.signal.emit(["start", "secondary", len(self.trips), msg_txt, self.__mt]) total_fast = 0 for prog_counter, route in enumerate(self.trips): @@ -218,10 +220,10 @@ def __load_fare_data(self): for line in range(fareatt.shape[0]): data = tuple(fareatt[line][list(column_order[fareatttxt].keys())]) headers = ["fare_id", "price", "currency", "payment_method", "transfer", "transfer_duration"] - f = Fare(self.agency.agency_id) + f = Fare(self.agency[fareatt[line]["agency_id"]].agency_id) f.populate(data, headers) if f.fare in self.fare_attributes: - self.__fail(f"Fare ID {f.fare} for {self.agency.agency} is duplicated") + self.__fail(f"Fare ID {f.fare} for {self.agency[fareatt[line]['agency_id']].agency} is duplicated") self.fare_attributes[f.fare] = f farerltxt = "fare_rules.txt" @@ -266,7 +268,7 @@ def __load_shapes_table(self): shapes = parse_csv(file, column_order[shapestxt]) all_shape_ids = np.unique(shapes["shape_id"]).tolist() - msg_txt = f"Load shapes - {self.agency.agency}" + msg_txt = f"Load shapes" self.signal.emit(["start", "secondary", len(all_shape_ids), msg_txt, self.__mt]) self.data_arrays[shapestxt] = shapes @@ -291,7 +293,7 @@ def __load_trips_table(self): trips_array = parse_csv(file, column_order[tripstxt]) self.data_arrays[tripstxt] = trips_array - msg_txt = f"Load trips - {self.agency.agency}" + msg_txt = f"Load trips" self.signal.emit(["start", "secondary", trips_array.shape[0], msg_txt, self.__mt]) if np.unique(trips_array["trip_id"]).shape[0] < trips_array.shape[0]: self.__fail("There are repeated trip IDs in trips.txt") @@ -345,7 +347,6 @@ def __load_trips_table(self): trip.source_time = list(stop_times.source_time.values) self.logger.debug(f"{trip.trip} has {len(trip.stops)} stops") trip._stop_based_shape = LineString([self.stops[x].geo for x in trip.stops]) - # trip.shape = self.shapes.get(trip.shape) trip.seated_capacity = self.routes[trip.route].seated_capacity trip.total_capacity = self.routes[trip.route].total_capacity self.trips[trip.route] = self.trips.get(trip.route, {}) @@ -398,7 +399,7 @@ def __load_stop_times(self): with self.zip_archive.open(stoptimestxt, "r") as file: stoptimes = parse_csv(file, column_order[stoptimestxt]) self.data_arrays[stoptimestxt] = stoptimes - msg_txt = f"Load stop times - {self.agency.agency}" + msg_txt = "Load stop times" df = pd.DataFrame(stoptimes) for col in ["arrival_time", "departure_time"]: @@ -459,12 +460,11 @@ def __load_stops_table(self): stops[:]["stop_lat"][:] = lats[:] stops[:]["stop_lon"][:] = lons[:] - msg_txt = f"Load stops - {self.agency.agency}" + msg_txt = "Load stops" self.signal.emit(["start", "secondary", stops.shape[0], msg_txt, self.__mt]) for i, line in enumerate(stops): self.signal.emit(["update", "secondary", i + 1, msg_txt, self.__mt]) - s = Stop(self.agency.agency_id, line, stops.dtype.names) - s.agency = self.agency.agency + s = Stop(line, stops.dtype.names) s.srid = self.srid s.get_node_id() self.stops[s.stop_id] = s @@ -482,7 +482,7 @@ def __load_routes_table(self): if np.unique(routes["route_id"]).shape[0] < routes.shape[0]: self.__fail("There are repeated route IDs in routes.txt") - msg_txt = f"Load Routes - {self.agency.agency}" + msg_txt = "Load Routes" self.signal.emit(["start", "secondary", len(routes), msg_txt, self.__mt]) cap = self.__capacities__.get("other", [None, None, None]) @@ -493,7 +493,7 @@ def __load_routes_table(self): for i, line in routes.iterrows(): self.signal.emit(["update", "secondary", i + 1, msg_txt, self.__mt]) - r = Route(self.agency.agency_id) + r = Route(self.agency[line["agency_id"]].agency_id) r.populate(line.values, routes.columns) self.routes[r.route] = r @@ -570,6 +570,27 @@ def __load_feed_calendar(self): if exception_inconsistencies: self.logger.info(" Minor inconsistencies found between calendar.txt and calendar_dates.txt") + def __load_agencies(self): + self.logger.debug("Starting __load_agencies") + agencytxt = "agency.txt" + + self.logger.debug(' Loading "agency" table') + self.agency = {} + with self.zip_archive.open(agencytxt, "r") as file: + agencies = parse_csv(file, column_order[agencytxt]) + self.data_arrays[agencytxt] = agencies + + msg_txt = "Load Agencies" + self.signal.emit(["start", "secondary", len(agencies), msg_txt, self.__mt]) + + for i, line in enumerate(agencies): + self.signal.emit(["update", "secondary", i + 1, msg_txt, self.__mt]) + a = Agency() + a.agency = line["agency_name"] + a.feed_date = self.feed_date + a.service_date = self.service_date + self.agency[line["agency_id"]] = a + def __fail(self, msg: str) -> None: self.logger.error(msg) raise Exception(msg) diff --git a/aequilibrae/transit/lib_gtfs.py b/aequilibrae/transit/lib_gtfs.py index 8521ef2c1..6fd7490c1 100644 --- a/aequilibrae/transit/lib_gtfs.py +++ b/aequilibrae/transit/lib_gtfs.py @@ -20,13 +20,12 @@ class GTFSRouteSystemBuilder(WorkerThread): """Container for GTFS feeds providing data retrieval for the importer""" - def __init__(self, network, agency_identifier, file_path, day="", description="", capacities={}): # noqa: B006 + def __init__(self, network, file_path, day="", description="", capacities={}): # noqa: B006 """Instantiates a transit class for the network :Arguments: **local network** (:obj:`Network`): Supply model to which this GTFS will be imported - **agency_identifier** (:obj:`str`): ID for the agency this feed refers to (e.g. 'CTA') **file_path** (:obj:`str`): Full path to the GTFS feed (e.g. 'D:/project/my_gtfs_feed.zip') **day** (:obj:`str`, *Optional*): Service data contained in this field to be imported (e.g. '2019-10-04') **description** (:obj:`str`, *Optional*): Description for this feed (e.g. 'CTA19 fixed by John after coffee') @@ -48,8 +47,6 @@ def __init__(self, network, agency_identifier, file_path, day="", description="" self.graphs = {} self.transformer = Transformer.from_crs("epsg:4326", f"epsg:{self.srid}", always_xy=False) self.sridproj = pyproj.Proj(f"epsg:{self.srid}") - self.gtfs_data.agency.agency = agency_identifier - self.gtfs_data.agency.description = description self.__default_capacities = capacities self.__do_execute_map_matching = False self.__target_date__ = None @@ -133,13 +130,13 @@ def map_match(self, route_types=[3]) -> None: # noqa: B006 if msg is not None: self.logger.warning(msg) - def set_agency_identifier(self, agency_id: str) -> None: - """Adds agency ID to this GTFS for use on import. + # def set_agency_identifier(self, agency_id: str) -> None: + # """Adds agency ID to this GTFS for use on import. - :Arguments: - **agency_id** (:obj:`str`): ID for the agency this feed refers to (e.g. 'CTA') - """ - self.gtfs_data.agency.agency = agency_id + # :Arguments: + # **agency_id** (:obj:`str`): ID for the agency this feed refers to (e.g. 'CTA') + # """ + # self.gtfs_data.agency.agency = agency_id def set_feed(self, feed_path: str) -> None: """Sets GTFS feed source to be used. @@ -180,7 +177,6 @@ def load_date(self, service_date: str) -> None: self.logger.info(" Building data structures") self.__build_data() - self.gtfs_data.agency.service_date = self.day def doWork(self): """Alias for execute_import""" @@ -192,10 +188,10 @@ def execute_import(self): if self.__target_date__ is not None: self.load_date(self.__target_date__) if not self.select_routes: - self.logger.warning(f"Nothing to import for {self.gtfs_data.agency.agency} on {self.day}") + self.logger.warning(f"Nothing to import on {self.day}") return - self.logger.info(f" Importing feed for agency {self.gtfs_data.agency.agency} on {self.day}") + self.logger.info(f" Importing GTFS feed on {self.day}") self.save_to_disk() @@ -207,7 +203,8 @@ def save_to_disk(self): pattern.save_to_database(conn, commit=False) conn.commit() - self.gtfs_data.agency.save_to_database(conn) + for counter, (_, agency) in enumerate(self.gtfs_data.agency.items()): + agency.save_to_database(conn) for counter, trip in enumerate(self.select_trips): trip.save_to_database(conn, commit=False) @@ -222,11 +219,12 @@ def save_to_disk(self): zone_ids2 = {x.destination: x.destination_id for x in self.gtfs_data.fare_rules if x.destination_id >= 0} zone_ids = {**zone_ids1, **zone_ids2} - zones = [[y, x, self.gtfs_data.agency.agency_id] for x, y in list(zone_ids.items())] - if zones: - sql = "Insert into fare_zones (fare_zone_id, transit_zone, agency_id) values(?, ?, ?);" - conn.executemany(sql, zones) - conn.commit() + # fix + # zones = [[y, x, self.gtfs_data.agency.agency_id] for x, y in list(zone_ids.items())] + # if zones: + # sql = "Insert into fare_zones (fare_zone_id, transit_zone, agency_id) values(?, ?, ?);" + # conn.executemany(sql, zones) + # conn.commit() for fare in self.gtfs_data.fare_attributes.values(): fare.save_to_database(conn) @@ -366,9 +364,6 @@ def __get_routes_by_date(self): if not routes: self.logger.warning("NO ROUTES OPERATING FOR THIS DATE") - for route_id, route in routes.items(): - route.agency = self.gtfs_data.agency.agency - self.select_routes = routes def _get_trips_by_date_and_route(self, route_id: int, service_date: str) -> list: diff --git a/aequilibrae/transit/map_matching_graph.py b/aequilibrae/transit/map_matching_graph.py index da5d64885..b412a10a6 100644 --- a/aequilibrae/transit/map_matching_graph.py +++ b/aequilibrae/transit/map_matching_graph.py @@ -41,7 +41,7 @@ def __init__(self, lib_gtfs): self.mode_id = -1 self.__mode = "" self.__df_file = "" - self.__agency = lib_gtfs.gtfs_data.agency.agency + self.__agency = '-'.join([key for key in lib_gtfs.gtfs_data.agency.keys()]) self.__centroids_file = "" self.__mm_graph_file = "" self.node_corresp = [] diff --git a/aequilibrae/transit/transit.py b/aequilibrae/transit/transit.py index 6d07bdf0c..6e8907d87 100644 --- a/aequilibrae/transit/transit.py +++ b/aequilibrae/transit/transit.py @@ -44,7 +44,7 @@ def __init__(self, project): self.create_transit_database() self.pt_con = database_connection("transit") - def new_gtfs_builder(self, agency, file_path, day="", description="") -> GTFSRouteSystemBuilder: + def new_gtfs_builder(self, file_path, day="", description="") -> GTFSRouteSystemBuilder: """Returns a ``GTFSRouteSystemBuilder`` object compatible with the project :Arguments: @@ -59,7 +59,6 @@ def new_gtfs_builder(self, agency, file_path, day="", description="") -> GTFSRou """ gtfs = GTFSRouteSystemBuilder( network=self.project_base_path, - agency_identifier=agency, file_path=file_path, day=day, description=description, diff --git a/aequilibrae/transit/transit_elements/stop.py b/aequilibrae/transit/transit_elements/stop.py index 48f6a5fca..969e77395 100644 --- a/aequilibrae/transit/transit_elements/stop.py +++ b/aequilibrae/transit/transit_elements/stop.py @@ -4,7 +4,11 @@ from shapely.geometry import Point -from aequilibrae.transit.constants import Constants, AGENCY_MULTIPLIER +from contextlib import closing + +from aequilibrae.project.database_connection import database_connection + +from aequilibrae.transit.constants import Constants, STOP_ID from aequilibrae.transit.transit_elements.basic_element import BasicPTElement @@ -12,7 +16,7 @@ class Stop(BasicPTElement): """Transit stop as read from the GTFS feed""" - def __init__(self, agency_id: int, record: tuple, headers: list): + def __init__(self, record: tuple, headers: list): self.stop_id = -1 self.stop = "" self.stop_code = "" @@ -30,8 +34,6 @@ def __init__(self, agency_id: int, record: tuple, headers: list): # Not part of GTFS self.taz = None - self.agency = "" - self.agency_id = agency_id self.link = None self.dir = None self.srid = -1 @@ -55,9 +57,9 @@ def __init__(self, agency_id: int, record: tuple, headers: list): def save_to_database(self, conn: Connection, commit=True) -> None: """Saves Transit Stop to the database""" - sql = """insert into stops (stop_id, stop, agency_id, link, dir, name, + sql = """insert into stops (stop_id, stop, link, dir, name, parent_station, description, street, fare_zone_id, transit_zone, route_type, geometry) - values (?,?,?,?,?,?,?,?,?,?,?,?, GeomFromWKB(?, ?));""" + values (?,?,?,?,?,?,?,?,?,?,?, GeomFromWKB(?, ?));""" dt = self.data conn.execute(sql, dt) @@ -69,7 +71,6 @@ def data(self) -> list: return [ self.stop_id, self.stop, - self.agency_id, self.link, self.dir, self.stop_name, @@ -84,8 +85,10 @@ def data(self) -> list: ] def get_node_id(self): - c = Constants() + with closing(database_connection("transit")) as conn: + sql = "Select coalesce(max(distinct(stop_id)), 0) from stops;" + max_db = int(conn.execute(sql).fetchone()[0]) - val = 1 + c.stops.get(self.agency_id, AGENCY_MULTIPLIER * self.agency_id) - c.stops[self.agency_id] = val - self.stop_id = c.stops[self.agency_id] + c = Constants() + c.stops["stops"] = max(c.stops.get("stops", 0), max_db) + 1 + self.stop_id = c.stops["stops"] \ No newline at end of file From e7c519bc9f2e792a16e59ca5640d24717a6bce02 Mon Sep 17 00:00:00 2001 From: Renata Imai Date: Thu, 13 Jun 2024 20:39:40 -0300 Subject: [PATCH 03/18] . --- .../database_specification/transit/tables/stops.sql | 4 ---- aequilibrae/transit/gtfs_loader.py | 9 +++++++-- aequilibrae/transit/lib_gtfs.py | 2 +- aequilibrae/transit/map_matching_graph.py | 2 +- aequilibrae/transit/transit_elements/stop.py | 2 +- 5 files changed, 10 insertions(+), 9 deletions(-) diff --git a/aequilibrae/project/database_specification/transit/tables/stops.sql b/aequilibrae/project/database_specification/transit/tables/stops.sql index 3620292ea..cc4fc142e 100644 --- a/aequilibrae/project/database_specification/transit/tables/stops.sql +++ b/aequilibrae/project/database_specification/transit/tables/stops.sql @@ -7,8 +7,6 @@ --@ --@ **stop** idenfifies a stop, statio, or station entrance --@ ---@ **agency_id** identifies the agency fot the specified route ---@ --@ **link** identifies the *link_id* in the links table that corresponds to the --@ pattern matching --@ @@ -32,7 +30,6 @@ CREATE TABLE IF NOT EXISTS stops ( stop_id TEXT PRIMARY KEY, stop TEXT NOT NULL , - agency_id INTEGER NOT NULL, link INTEGER, dir INTEGER, name TEXT, @@ -42,7 +39,6 @@ CREATE TABLE IF NOT EXISTS stops ( fare_zone_id INTEGER, transit_zone TEXT, route_type INTEGER NOT NULL DEFAULT -1, - FOREIGN KEY(agency_id) REFERENCES agencies(agency_id), FOREIGN KEY("fare_zone_id") REFERENCES fare_zones("fare_zone_id") ); diff --git a/aequilibrae/transit/gtfs_loader.py b/aequilibrae/transit/gtfs_loader.py index 394337f79..ac38a625e 100644 --- a/aequilibrae/transit/gtfs_loader.py +++ b/aequilibrae/transit/gtfs_loader.py @@ -61,6 +61,7 @@ def __init__(self): self.srid = get_srid() self.transformer = Transformer.from_crs("epsg:4326", f"epsg:{self.srid}", always_xy=False) self.__mt = "" + self.agency_correspondence = {} self.logger = get_logger() def set_feed_path(self, file_path): @@ -491,9 +492,12 @@ def __load_routes_table(self): for route_type, cap in self.__capacities__.items(): routes.loc[routes.route_type == route_type, ["seated_capacity", "total_capacity"]] = cap + agency_finder = routes["agency_id"].values.tolist() + routes.drop(columns="agency_id", inplace=True) + for i, line in routes.iterrows(): self.signal.emit(["update", "secondary", i + 1, msg_txt, self.__mt]) - r = Route(self.agency[line["agency_id"]].agency_id) + r = Route(self.agency_correspondence[agency_finder[i]]) r.populate(line.values, routes.columns) self.routes[r.route] = r @@ -589,7 +593,8 @@ def __load_agencies(self): a.agency = line["agency_name"] a.feed_date = self.feed_date a.service_date = self.service_date - self.agency[line["agency_id"]] = a + self.agency[a.agency_id] = a + self.agency_correspondence[line["agency_id"]] = a.agency_id def __fail(self, msg: str) -> None: self.logger.error(msg) diff --git a/aequilibrae/transit/lib_gtfs.py b/aequilibrae/transit/lib_gtfs.py index 6fd7490c1..80a03969d 100644 --- a/aequilibrae/transit/lib_gtfs.py +++ b/aequilibrae/transit/lib_gtfs.py @@ -219,7 +219,7 @@ def save_to_disk(self): zone_ids2 = {x.destination: x.destination_id for x in self.gtfs_data.fare_rules if x.destination_id >= 0} zone_ids = {**zone_ids1, **zone_ids2} - # fix + # TODO # zones = [[y, x, self.gtfs_data.agency.agency_id] for x, y in list(zone_ids.items())] # if zones: # sql = "Insert into fare_zones (fare_zone_id, transit_zone, agency_id) values(?, ?, ?);" diff --git a/aequilibrae/transit/map_matching_graph.py b/aequilibrae/transit/map_matching_graph.py index b412a10a6..8424fc061 100644 --- a/aequilibrae/transit/map_matching_graph.py +++ b/aequilibrae/transit/map_matching_graph.py @@ -41,7 +41,7 @@ def __init__(self, lib_gtfs): self.mode_id = -1 self.__mode = "" self.__df_file = "" - self.__agency = '-'.join([key for key in lib_gtfs.gtfs_data.agency.keys()]) + self.__agency = '-'.join([key for key in lib_gtfs.gtfs_data.agency_correspondence.keys()]) self.__centroids_file = "" self.__mm_graph_file = "" self.node_corresp = [] diff --git a/aequilibrae/transit/transit_elements/stop.py b/aequilibrae/transit/transit_elements/stop.py index 969e77395..4198a8b5b 100644 --- a/aequilibrae/transit/transit_elements/stop.py +++ b/aequilibrae/transit/transit_elements/stop.py @@ -86,7 +86,7 @@ def data(self) -> list: def get_node_id(self): with closing(database_connection("transit")) as conn: - sql = "Select coalesce(max(distinct(stop_id)), 0) from stops;" + sql = "Select count(stop_id) from stops;" max_db = int(conn.execute(sql).fetchone()[0]) c = Constants() From 10a86fa9107708aebda28ce111f972f284d6f477 Mon Sep 17 00:00:00 2001 From: Renata Imai Date: Fri, 14 Jun 2024 16:49:17 -0300 Subject: [PATCH 04/18] fixes docs and tests --- .../transit/tables/fare_zones.sql | 8 +--- .../transit/tables/stops.sql | 3 -- aequilibrae/transit/column_order.py | 3 +- aequilibrae/transit/gtfs_loader.py | 38 +++++++++++-------- aequilibrae/transit/lib_gtfs.py | 26 +++++-------- aequilibrae/transit/map_matching_graph.py | 2 +- aequilibrae/transit/transit.py | 5 +-- aequilibrae/transit/transit_elements/stop.py | 8 ++-- .../creating_models/plot_import_gtfs.py | 2 +- .../plot_public_transit_assignment.py | 2 +- .../public_transport.rst | 2 +- .../paths/test_transit_graph_builder.py | 2 +- .../project/test_transit_tables.py | 4 +- tests/aequilibrae/transit/test_gtfs_loader.py | 2 +- tests/aequilibrae/transit/test_gtfs_stop.py | 6 +-- tests/aequilibrae/transit/test_lib_gtfs.py | 9 +---- tests/aequilibrae/transit/test_pattern.py | 2 +- tests/aequilibrae/transit/test_transit.py | 12 ++---- 18 files changed, 55 insertions(+), 81 deletions(-) diff --git a/aequilibrae/project/database_specification/transit/tables/fare_zones.sql b/aequilibrae/project/database_specification/transit/tables/fare_zones.sql index b5f8ded58..fd5f5116d 100644 --- a/aequilibrae/project/database_specification/transit/tables/fare_zones.sql +++ b/aequilibrae/project/database_specification/transit/tables/fare_zones.sql @@ -4,12 +4,8 @@ --@ **fare_zone_id** identifies the fare zone for a stop --@ --@ **transit_zone** identifies the TAZ for a fare zone ---@ ---@ **agency_id** identifies the agency fot the specified route CREATE TABLE IF NOT EXISTS fare_zones ( - fare_zone_id INTEGER PRIMARY KEY, - transit_zone TEXT NOT NULL, - agency_id INTEGER NOT NULL, - FOREIGN KEY(agency_id) REFERENCES agencies(agency_id) deferrable initially deferred + fare_zone_id INTEGER NOT NULL PRIMARY KEY, + transit_zone TEXT NOT NULL ); \ No newline at end of file diff --git a/aequilibrae/project/database_specification/transit/tables/stops.sql b/aequilibrae/project/database_specification/transit/tables/stops.sql index cc4fc142e..90c33d7d7 100644 --- a/aequilibrae/project/database_specification/transit/tables/stops.sql +++ b/aequilibrae/project/database_specification/transit/tables/stops.sql @@ -19,8 +19,6 @@ --@ --@ **description** provides useful description of the stop location --@ ---@ **street** identifies the address of a stop ---@ --@ **fare_zone_id** identifies the fare zone for a stop --@ --@ **transit_zone** identifies the TAZ for a fare zone @@ -35,7 +33,6 @@ CREATE TABLE IF NOT EXISTS stops ( name TEXT, parent_station TEXT, description TEXT, - street TEXT, fare_zone_id INTEGER, transit_zone TEXT, route_type INTEGER NOT NULL DEFAULT -1, diff --git a/aequilibrae/transit/column_order.py b/aequilibrae/transit/column_order.py index 45870606a..60f54cec7 100644 --- a/aequilibrae/transit/column_order.py +++ b/aequilibrae/transit/column_order.py @@ -113,13 +113,14 @@ ("stop_desc", str), ("stop_lat", float), ("stop_lon", float), - ("stop_street", str), ("zone_id", str), # ("stop_url", str), # ("location_type", int), ("parent_station", str), # ("stop_timezone", str), # ("wheelchair_boarding", int), + # ("level_id", int), + # ("platform_code", str) ] ), "shapes.txt": OrderedDict( diff --git a/aequilibrae/transit/gtfs_loader.py b/aequilibrae/transit/gtfs_loader.py index 8711eb16d..c3961dfe0 100644 --- a/aequilibrae/transit/gtfs_loader.py +++ b/aequilibrae/transit/gtfs_loader.py @@ -89,13 +89,14 @@ def _set_capacities(self, capacities: dict): def _set_maximum_speeds(self, max_speeds: dict): self.__max_speeds__ = max_speeds - def load_data(self, service_date: str): + def load_data(self, service_date: str, description: str): """Loads the data for a respective service date. :Arguments: **service_date** (:obj:`str`): service date. e.g. "2020-04-01". """ self.service_date = service_date + self.description = description self.logger.info(f"Loading data for {self.service_date} from the GTFS feed. This may take some time") self.__mt = "Reading GTFS" @@ -218,13 +219,19 @@ def __load_fare_data(self): fareatt = parse_csv(file, column_order[fareatttxt]) self.data_arrays[fareatttxt] = fareatt + existing_agencies = np.unique(fareatt["agency_id"]) + if existing_agencies.shape[0] != len(self.agency): + self.logger.debug("agency_id exists on fare_attributes.txt but not in agency.txt") + elif existing_agencies.shape[0] == 1 and existing_agencies[0] == "": + fareatt["agency_id"] = list(self.agency.keys())[0] + for line in range(fareatt.shape[0]): data = tuple(fareatt[line][list(column_order[fareatttxt].keys())]) headers = ["fare_id", "price", "currency", "payment_method", "transfer", "transfer_duration"] - f = Fare(self.agency[fareatt[line]["agency_id"]].agency_id) + f = Fare(fareatt[line]["agency_id"]) f.populate(data, headers) if f.fare in self.fare_attributes: - self.__fail(f"Fare ID {f.fare} for {self.agency[fareatt[line]['agency_id']].agency} is duplicated") + self.__fail(f"Fare ID {f.fare} for {fareatt[line]['agency_id']} is duplicated") self.fare_attributes[f.fare] = f farerltxt = "fare_rules.txt" @@ -238,8 +245,7 @@ def __load_fare_data(self): farerl = parse_csv(file, column_order[farerltxt]) self.data_arrays[farerltxt] = farerl - corresp = {} - zone_id = self.agency.agency_id * AGENCY_MULTIPLIER + 1 + # corresp = {} for line in range(farerl.shape[0]): data = tuple(farerl[line][list(column_order[farerltxt].keys())]) fr = FareRule() @@ -247,14 +253,15 @@ def __load_fare_data(self): fr.fare_id = self.fare_attributes[fr.fare].fare_id if fr.route in self.routes: fr.route_id = self.routes[fr.route].route_id - fr.agency_id = self.agency.agency_id - for x in [fr.origin, fr.destination]: - if x not in corresp: - corresp[x] = zone_id - zone_id += 1 - fr.origin_id = corresp[fr.origin] - fr.destination_id = corresp[fr.destination] if fr.destination == "" else fr.destination_id - self.fare_rules.append(fr) if fr.origin == "" else fr.origin_id + # fr.agency_id = self.fare_attributes[fr.fare].agency_id + # zone_id = fr.agency_id * AGENCY_MULTIPLIER + 1 + # for x in [fr.origin, fr.destination]: + # if x not in corresp: + # corresp[x] = zone_id + # zone_id += 1 + fr.origin_id = None if fr.origin == "" else int(fr.origin) + fr.destination_id = None if fr.destination == "" else int(fr.destination) + self.fare_rules.append(fr) def __load_shapes_table(self): self.logger.debug("Starting __load_shapes_table") @@ -269,7 +276,7 @@ def __load_shapes_table(self): shapes = parse_csv(file, column_order[shapestxt]) all_shape_ids = np.unique(shapes["shape_id"]).tolist() - msg_txt = f"Load shapes" + msg_txt = "Load shapes" self.signal.emit(["start", "secondary", len(all_shape_ids), msg_txt, self.__mt]) self.data_arrays[shapestxt] = shapes @@ -294,7 +301,7 @@ def __load_trips_table(self): trips_array = parse_csv(file, column_order[tripstxt]) self.data_arrays[tripstxt] = trips_array - msg_txt = f"Load trips" + msg_txt = "Load trips" self.signal.emit(["start", "secondary", trips_array.shape[0], msg_txt, self.__mt]) if np.unique(trips_array["trip_id"]).shape[0] < trips_array.shape[0]: self.__fail("There are repeated trip IDs in trips.txt") @@ -620,6 +627,7 @@ def __load_agencies(self): a.agency = line["agency_name"] a.feed_date = self.feed_date a.service_date = self.service_date + a.description = self.description self.agency[a.agency_id] = a self.agency_correspondence[line["agency_id"]] = a.agency_id diff --git a/aequilibrae/transit/lib_gtfs.py b/aequilibrae/transit/lib_gtfs.py index 80a03969d..9143b1af4 100644 --- a/aequilibrae/transit/lib_gtfs.py +++ b/aequilibrae/transit/lib_gtfs.py @@ -20,14 +20,15 @@ class GTFSRouteSystemBuilder(WorkerThread): """Container for GTFS feeds providing data retrieval for the importer""" - def __init__(self, network, file_path, day="", description="", capacities={}): # noqa: B006 + def __init__(self, network, file_path, description="", capacities={}): # noqa: B006 """Instantiates a transit class for the network :Arguments: **local network** (:obj:`Network`): Supply model to which this GTFS will be imported + **file_path** (:obj:`str`): Full path to the GTFS feed (e.g. 'D:/project/my_gtfs_feed.zip') - **day** (:obj:`str`, *Optional*): Service data contained in this field to be imported (e.g. '2019-10-04') + **description** (:obj:`str`, *Optional*): Description for this feed (e.g. 'CTA19 fixed by John after coffee') """ WorkerThread.__init__(self, None) @@ -35,7 +36,7 @@ def __init__(self, network, file_path, day="", description="", capacities={}): self.__network = network self.project = get_active_project(False) self.archive_dir = None # type: str - self.day = day + self.day = None self.logger = get_logger() self.gtfs_data = GTFSReader() @@ -45,6 +46,7 @@ def __init__(self, network, file_path, day="", description="", capacities={}): self.trip_by_service = {} self.patterns = {} self.graphs = {} + self.description = description self.transformer = Transformer.from_crs("epsg:4326", f"epsg:{self.srid}", always_xy=False) self.sridproj = pyproj.Proj(f"epsg:{self.srid}") self.__default_capacities = capacities @@ -72,8 +74,7 @@ def set_capacities(self, capacities: dict): :Arguments: **capacities** (:obj:`dict`): Dictionary with GTFS types as keys, each with a list - of 3 items for values for capacities: seated and total - i.e. -> "{0: [150, 300],...}" + of 3 items for values for capacities: seated and total i.e. -> "{0: [150, 300],...}" """ self.gtfs_data._set_capacities(capacities) @@ -130,14 +131,6 @@ def map_match(self, route_types=[3]) -> None: # noqa: B006 if msg is not None: self.logger.warning(msg) - # def set_agency_identifier(self, agency_id: str) -> None: - # """Adds agency ID to this GTFS for use on import. - - # :Arguments: - # **agency_id** (:obj:`str`): ID for the agency this feed refers to (e.g. 'CTA') - # """ - # self.gtfs_data.agency.agency = agency_id - def set_feed(self, feed_path: str) -> None: """Sets GTFS feed source to be used. @@ -173,7 +166,7 @@ def load_date(self, service_date: str) -> None: raise ValueError("The date chosen is not available in this GTFS feed") self.day = service_date - self.gtfs_data.load_data(service_date) + self.gtfs_data.load_data(service_date, self.description) self.logger.info(" Building data structures") self.__build_data() @@ -219,10 +212,9 @@ def save_to_disk(self): zone_ids2 = {x.destination: x.destination_id for x in self.gtfs_data.fare_rules if x.destination_id >= 0} zone_ids = {**zone_ids1, **zone_ids2} - # TODO - # zones = [[y, x, self.gtfs_data.agency.agency_id] for x, y in list(zone_ids.items())] + # zones = [[y, x] for x, y in list(zone_ids.items())] # if zones: - # sql = "Insert into fare_zones (fare_zone_id, transit_zone, agency_id) values(?, ?, ?);" + # sql = "Insert into fare_zones (fare_zone_id, transit_zone) values(?, ?);" # conn.executemany(sql, zones) # conn.commit() diff --git a/aequilibrae/transit/map_matching_graph.py b/aequilibrae/transit/map_matching_graph.py index 8424fc061..9b2177c52 100644 --- a/aequilibrae/transit/map_matching_graph.py +++ b/aequilibrae/transit/map_matching_graph.py @@ -41,7 +41,7 @@ def __init__(self, lib_gtfs): self.mode_id = -1 self.__mode = "" self.__df_file = "" - self.__agency = '-'.join([key for key in lib_gtfs.gtfs_data.agency_correspondence.keys()]) + self.__agency = "-".join(list(lib_gtfs.gtfs_data.agency_correspondence.keys())) self.__centroids_file = "" self.__mm_graph_file = "" self.node_corresp = [] diff --git a/aequilibrae/transit/transit.py b/aequilibrae/transit/transit.py index 6e8907d87..645b82c76 100644 --- a/aequilibrae/transit/transit.py +++ b/aequilibrae/transit/transit.py @@ -44,14 +44,12 @@ def __init__(self, project): self.create_transit_database() self.pt_con = database_connection("transit") - def new_gtfs_builder(self, file_path, day="", description="") -> GTFSRouteSystemBuilder: + def new_gtfs_builder(self, file_path, description="") -> GTFSRouteSystemBuilder: """Returns a ``GTFSRouteSystemBuilder`` object compatible with the project :Arguments: **file_path** (:obj:`str`): Full path to the GTFS feed (e.g. 'D:/project/my_gtfs_feed.zip') - **day** (:obj:`str`, *Optional*): Service data contained in this field to be imported (e.g. '2019-10-04') - **description** (:obj:`str`, *Optional*): Description for this feed (e.g. 'CTA2019 fixed by John Doe') :Returns: @@ -60,7 +58,6 @@ def new_gtfs_builder(self, file_path, day="", description="") -> GTFSRouteSystem gtfs = GTFSRouteSystemBuilder( network=self.project_base_path, file_path=file_path, - day=day, description=description, capacities=self.default_capacities, ) diff --git a/aequilibrae/transit/transit_elements/stop.py b/aequilibrae/transit/transit_elements/stop.py index 4198a8b5b..b40ab3666 100644 --- a/aequilibrae/transit/transit_elements/stop.py +++ b/aequilibrae/transit/transit_elements/stop.py @@ -24,7 +24,6 @@ def __init__(self, record: tuple, headers: list): self.stop_desc = "" self.stop_lat: float = None self.stop_lon: float = None - self.stop_street = "" self.zone = "" self.zone_id = None self.stop_url = "" @@ -58,8 +57,8 @@ def save_to_database(self, conn: Connection, commit=True) -> None: """Saves Transit Stop to the database""" sql = """insert into stops (stop_id, stop, link, dir, name, - parent_station, description, street, fare_zone_id, transit_zone, route_type, geometry) - values (?,?,?,?,?,?,?,?,?,?,?, GeomFromWKB(?, ?));""" + parent_station, description, fare_zone_id, transit_zone, route_type, geometry) + values (?,?,?,?,?,?,?,?,?,?, GeomFromWKB(?, ?));""" dt = self.data conn.execute(sql, dt) @@ -76,7 +75,6 @@ def data(self) -> list: self.stop_name, self.parent_station, self.stop_desc, - self.stop_street, self.zone_id, self.taz, int(self.route_type), @@ -91,4 +89,4 @@ def get_node_id(self): c = Constants() c.stops["stops"] = max(c.stops.get("stops", 0), max_db) + 1 - self.stop_id = c.stops["stops"] \ No newline at end of file + self.stop_id = c.stops["stops"] diff --git a/docs/source/examples/creating_models/plot_import_gtfs.py b/docs/source/examples/creating_models/plot_import_gtfs.py index d08bb512d..290c95157 100644 --- a/docs/source/examples/creating_models/plot_import_gtfs.py +++ b/docs/source/examples/creating_models/plot_import_gtfs.py @@ -45,7 +45,7 @@ data = Transit(project) -transit = data.new_gtfs_builder(agency="Lisanco", file_path=dest_path) +transit = data.new_gtfs_builder(file_path=dest_path) # %% # To load the data, we must choose one date. We're going to continue with 2016-04-13 but feel free diff --git a/docs/source/examples/full_workflows/plot_public_transit_assignment.py b/docs/source/examples/full_workflows/plot_public_transit_assignment.py index ddd3a10b7..3e399890d 100644 --- a/docs/source/examples/full_workflows/plot_public_transit_assignment.py +++ b/docs/source/examples/full_workflows/plot_public_transit_assignment.py @@ -47,7 +47,7 @@ # This will automatically create a new public transport database. data = Transit(project) -transit = data.new_gtfs_builder(agency="LISANCO", file_path=dest_path) +transit = data.new_gtfs_builder(file_path=dest_path) # %% # To load the data, we must choose one date. We're going to continue with 2016-04-13 but feel free diff --git a/docs/source/modeling_with_aequilibrae/public_transport.rst b/docs/source/modeling_with_aequilibrae/public_transport.rst index e42b681ba..71b84f764 100644 --- a/docs/source/modeling_with_aequilibrae/public_transport.rst +++ b/docs/source/modeling_with_aequilibrae/public_transport.rst @@ -8,7 +8,7 @@ into its database. The Transit module has been updated in version 0.9.0. More de the **public_transport.sqlite** are discussed on a nearly *per-table* basis below, and we recommend understanding the role of each table before setting an AequilibraE model you intend to use. If you don't know much about GTFS, we strongly encourage you to take -a look at the documentation provided by `Google `_. +a look at the documentation provided by `Mobility Data `_. A more technical view of the database structure, including the SQL queries used to create each table and their indices are also available. diff --git a/tests/aequilibrae/paths/test_transit_graph_builder.py b/tests/aequilibrae/paths/test_transit_graph_builder.py index 230585216..88973b295 100644 --- a/tests/aequilibrae/paths/test_transit_graph_builder.py +++ b/tests/aequilibrae/paths/test_transit_graph_builder.py @@ -30,7 +30,7 @@ def setUp(self) -> None: self.data = Transit(self.project) dest_path = join(self.temp_proj_folder, "gtfs_coquimbo.zip") - self.transit = self.data.new_gtfs_builder(agency="LISANCO", file_path=dest_path) + self.transit = self.data.new_gtfs_builder(file_path=dest_path) self.transit.load_date("2016-04-13") self.transit.save_to_disk() diff --git a/tests/aequilibrae/project/test_transit_tables.py b/tests/aequilibrae/project/test_transit_tables.py index 715b78800..7bac0cf50 100644 --- a/tests/aequilibrae/project/test_transit_tables.py +++ b/tests/aequilibrae/project/test_transit_tables.py @@ -18,7 +18,7 @@ def create_project(project: Project): ["fare_id", "fare", "agency_id", "price", "currency", "payment_method", "transfer", "transfer_duration"], ), ("fare_rules", ["fare_id", "route_id", "origin", "destination", "contains"]), - ("fare_zones", ["fare_zone_id", "transit_zone", "agency_id"]), + ("fare_zones", ["fare_zone_id", "transit_zone"]), ("pattern_mapping", ["pattern_id", "seq", "link", "dir", "geometry"]), ( "routes", @@ -42,13 +42,11 @@ def create_project(project: Project): [ "stop_id", "stop", - "agency_id", "link", "dir", "name", "parent_station", "description", - "street", "fare_zone_id", "transit_zone", "route_type", diff --git a/tests/aequilibrae/transit/test_gtfs_loader.py b/tests/aequilibrae/transit/test_gtfs_loader.py index 0af0ea558..c50d79375 100644 --- a/tests/aequilibrae/transit/test_gtfs_loader.py +++ b/tests/aequilibrae/transit/test_gtfs_loader.py @@ -38,4 +38,4 @@ def test_load_data(gtfs_loader, gtfs_fldr): gtfs._set_maximum_speeds(dict_speeds) gtfs.set_feed_path(gtfs_fldr) - gtfs.load_data("2016-04-13") + gtfs.load_data("2016-04-13", "this is a description") diff --git a/tests/aequilibrae/transit/test_gtfs_stop.py b/tests/aequilibrae/transit/test_gtfs_stop.py index bb8d5bdd4..b88c0af7c 100644 --- a/tests/aequilibrae/transit/test_gtfs_stop.py +++ b/tests/aequilibrae/transit/test_gtfs_stop.py @@ -18,7 +18,6 @@ def data(): "stop_desc": randomword(randint(0, 40)), "stop_lat": uniform(0, 30000), "stop_lon": uniform(0, 30000), - "stop_street": randomword(randint(0, 40)), "zone_id": randomword(randint(0, 40)), "stop_url": randomword(randint(0, 40)), "location_type": choice((0, 1)), @@ -51,7 +50,6 @@ def test_save_to_database(data, transit_conn): s = Stop(1, tuple(data.values()), list(data.keys())) s.link = link = randint(1, 30000) s.dir = direc = choice((0, 1)) - s.agency = randint(5, 100000) s.route_type = randint(0, 13) s.srid = get_srid() s.get_node_id() @@ -61,7 +59,7 @@ def test_save_to_database(data, transit_conn): VALUES(?, ?, ?, ?, ?, ?, GeomFromWKB(?, 4326));""" transit_conn.execute(sql_tl, [tlink_id, randint(1, 1000000000), randint(1, 10), s.stop_id, s.stop_id + 1, 0, line]) - sql = "Select agency_id, link, dir, description, street from stops where stop=?" + sql = "Select link, dir, description, street from stops where stop=?" result = list(transit_conn.execute(sql, [data["stop_id"]]).fetchone()) - expected = [s.agency_id, link, direc, data["stop_desc"], data["stop_street"]] + expected = [link, direc, data["stop_desc"], data["stop_street"]] assert result == expected, "Saving Stop to the database failed" diff --git a/tests/aequilibrae/transit/test_lib_gtfs.py b/tests/aequilibrae/transit/test_lib_gtfs.py index b6a49fc9f..5bc896701 100644 --- a/tests/aequilibrae/transit/test_lib_gtfs.py +++ b/tests/aequilibrae/transit/test_lib_gtfs.py @@ -12,7 +12,7 @@ def gtfs_file(create_path): @pytest.fixture def system_builder(transit_conn, gtfs_file): yield GTFSRouteSystemBuilder( - network=transit_conn, agency_identifier="LISERCO, LISANCO, LINCOSUR", file_path=gtfs_file + network=transit_conn, file_path=gtfs_file ) @@ -51,12 +51,6 @@ def test_map_match(transit_conn, system_builder): assert transit_conn.execute("SELECT * FROM pattern_mapping;").fetchone()[0] > 1 -def test_set_agency_identifier(system_builder): - assert system_builder.gtfs_data.agency.agency != "CTA" - system_builder.set_agency_identifier("CTA") - assert system_builder.gtfs_data.agency.agency == "CTA" - - def test_set_feed(gtfs_file, system_builder): system_builder.set_feed(gtfs_file) assert system_builder.gtfs_data.archive_dir == gtfs_file @@ -74,7 +68,6 @@ def test_set_date(system_builder): def test_load_date(system_builder): system_builder.load_date("2016-04-13") - assert system_builder.gtfs_data.agency.service_date == "2016-04-13" assert "101387" in system_builder.select_routes.keys() diff --git a/tests/aequilibrae/transit/test_pattern.py b/tests/aequilibrae/transit/test_pattern.py index ab0b5c066..6d6a4b6a2 100644 --- a/tests/aequilibrae/transit/test_pattern.py +++ b/tests/aequilibrae/transit/test_pattern.py @@ -6,7 +6,7 @@ def pat(create_path, create_gtfs_project): gtfs_fldr = os.path.join(create_path, "gtfs_coquimbo.zip") - transit = create_gtfs_project.new_gtfs_builder(agency="Lisanco", file_path=gtfs_fldr, description="") + transit = create_gtfs_project.new_gtfs_builder(file_path=gtfs_fldr, description="") transit.load_date("2016-04-13") patterns = transit.select_patterns diff --git a/tests/aequilibrae/transit/test_transit.py b/tests/aequilibrae/transit/test_transit.py index fab7ee94d..850b37792 100644 --- a/tests/aequilibrae/transit/test_transit.py +++ b/tests/aequilibrae/transit/test_transit.py @@ -12,30 +12,26 @@ def test_new_gtfs_builder(create_gtfs_project, create_path): existing = conn.execute("SELECT COALESCE(MAX(DISTINCT(agency_id)), 0) FROM agencies;").fetchone()[0] transit = create_gtfs_project.new_gtfs_builder( - agency="Agency_1", - day="2016-04-13", file_path=join(create_path, "gtfs_coquimbo.zip"), ) - + transit.load_date("2016-04-13") + transit.save_to_disk() assert str(type(transit)) == "" transit2 = create_gtfs_project.new_gtfs_builder( - agency="Agency_2", - day="2016-07-19", file_path=join(create_path, "gtfs_coquimbo.zip"), ) - transit.save_to_disk() + transit2.load_date("2016-07-19") transit2.save_to_disk() assert conn.execute("SELECT MAX(DISTINCT(agency_id)) FROM agencies;").fetchone()[0] == existing + 2 transit3 = create_gtfs_project.new_gtfs_builder( - agency="Agency_3", - day="2016-07-19", file_path=join(create_path, "gtfs_coquimbo.zip"), ) + transit3.load_date("2016-06-04") transit3.save_to_disk() assert conn.execute("SELECT MAX(DISTINCT(agency_id)) FROM agencies;").fetchone()[0] == existing + 3 From 77e2046fc0948081ee6576f76956ed993fd0db7d Mon Sep 17 00:00:00 2001 From: Renata Imai Date: Mon, 17 Jun 2024 08:35:55 -0300 Subject: [PATCH 05/18] fixes test files --- aequilibrae/transit/lib_gtfs.py | 1 - tests/aequilibrae/transit/test_gtfs_stop.py | 10 +++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/aequilibrae/transit/lib_gtfs.py b/aequilibrae/transit/lib_gtfs.py index 9143b1af4..b8b5e8e88 100644 --- a/aequilibrae/transit/lib_gtfs.py +++ b/aequilibrae/transit/lib_gtfs.py @@ -138,7 +138,6 @@ def set_feed(self, feed_path: str) -> None: **file_path** (:obj:`str`): Full path to the GTFS feed (e.g. 'D:/project/my_gtfs_feed.zip') """ self.gtfs_data.set_feed_path(feed_path) - self.gtfs_data.agency.feed_date = self.gtfs_data.feed_date def set_description(self, description: str) -> None: """Adds description to be added to the imported layers metadata diff --git a/tests/aequilibrae/transit/test_gtfs_stop.py b/tests/aequilibrae/transit/test_gtfs_stop.py index b88c0af7c..eacd33fb0 100644 --- a/tests/aequilibrae/transit/test_gtfs_stop.py +++ b/tests/aequilibrae/transit/test_gtfs_stop.py @@ -27,7 +27,7 @@ def data(): def test__populate(data): - s = Stop(1, tuple(data.values()), list(data.keys())) + s = Stop(tuple(data.values()), list(data.keys())) xy = (s.geo.x, s.geo.y) assert xy == (data["stop_lon"], data["stop_lat"]), "Stop built geo wrongly" data["stop"] = data["stop_id"] @@ -41,13 +41,13 @@ def test__populate(data): new_data = deepcopy(data) new_data[randomword(randint(1, 15))] = randomword(randint(1, 20)) with pytest.raises(KeyError): - _ = Stop(1, tuple(new_data.values()), list(new_data.keys())) + _ = Stop(tuple(new_data.values()), list(new_data.keys())) def test_save_to_database(data, transit_conn): line = LineString([[-23.59, -46.64], [-23.43, -46.50]]).wkb tlink_id = randint(10000, 200000044) - s = Stop(1, tuple(data.values()), list(data.keys())) + s = Stop(tuple(data.values()), list(data.keys())) s.link = link = randint(1, 30000) s.dir = direc = choice((0, 1)) s.route_type = randint(0, 13) @@ -59,7 +59,7 @@ def test_save_to_database(data, transit_conn): VALUES(?, ?, ?, ?, ?, ?, GeomFromWKB(?, 4326));""" transit_conn.execute(sql_tl, [tlink_id, randint(1, 1000000000), randint(1, 10), s.stop_id, s.stop_id + 1, 0, line]) - sql = "Select link, dir, description, street from stops where stop=?" + sql = "Select link, dir, description from stops where stop=?" result = list(transit_conn.execute(sql, [data["stop_id"]]).fetchone()) - expected = [link, direc, data["stop_desc"], data["stop_street"]] + expected = [link, direc, data["stop_desc"]] assert result == expected, "Saving Stop to the database failed" From f9dac4b7e9a5ea8df40d072786e79e821beb1e89 Mon Sep 17 00:00:00 2001 From: Renata Imai Date: Mon, 17 Jun 2024 10:30:22 -0300 Subject: [PATCH 06/18] Update test_lib_gtfs.py --- tests/aequilibrae/transit/test_lib_gtfs.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/aequilibrae/transit/test_lib_gtfs.py b/tests/aequilibrae/transit/test_lib_gtfs.py index 5bc896701..2d47f1117 100644 --- a/tests/aequilibrae/transit/test_lib_gtfs.py +++ b/tests/aequilibrae/transit/test_lib_gtfs.py @@ -11,9 +11,7 @@ def gtfs_file(create_path): @pytest.fixture def system_builder(transit_conn, gtfs_file): - yield GTFSRouteSystemBuilder( - network=transit_conn, file_path=gtfs_file - ) + yield GTFSRouteSystemBuilder(network=transit_conn, file_path=gtfs_file) def test_set_capacities(system_builder): From baf25f55077df427a9e3af026158c42f3812ad95 Mon Sep 17 00:00:00 2001 From: Renata Imai Date: Mon, 17 Jun 2024 11:28:10 -0300 Subject: [PATCH 07/18] Update plot_import_gtfs.py --- .../creating_models/plot_import_gtfs.py | 40 +++++++++++-------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/docs/source/examples/creating_models/plot_import_gtfs.py b/docs/source/examples/creating_models/plot_import_gtfs.py index 290c95157..0aa3c37c2 100644 --- a/docs/source/examples/creating_models/plot_import_gtfs.py +++ b/docs/source/examples/creating_models/plot_import_gtfs.py @@ -30,9 +30,9 @@ project = create_example(fldr, "coquimbo") # %% -# As the Coquimbo example already has a complete GTFS model, we shall remove its public transport -# database for the sake of this example. +# Since the Coquimbo example already includes a complete GTFS model, we will remove its public transport +# database for the purposes of this example. remove(join(fldr, "public_transport.sqlite")) # %% @@ -40,35 +40,43 @@ dest_path = join(fldr, "gtfs_coquimbo.zip") # %% -# Now we create our Transit object and import the GTFS feed into our model. -# This will automatically create a new public transport database. - +# Now we create our Transit object. This will automatically create a new public transport database. data = Transit(project) -transit = data.new_gtfs_builder(file_path=dest_path) +# %% +# To initialize the GTFS builder, specify the path to the GTFS file. You no longer need to provide the +# name of the transit agency, but you can add a general description of the GTFS feed. If your GTFS file +# includes multiple transit agencies, all data will be loaded simultaneously. However, any description +# you add will apply to all agencies in the file. +transit = data.new_gtfs_builder(file_path=dest_path, description="Wednesday feed by John Doe") + +#%% +# Case you want information on the available dates before loading the GTFS data to the database, +# it is possible to use the function ``transit.dates_available()`` to check the available feed dates. # %% -# To load the data, we must choose one date. We're going to continue with 2016-04-13 but feel free -# to experiment with any other available dates. Transit class has a function allowing you to check -# dates for the GTFS feed. It should take approximately 2 minutes to load the data. +# To load the data, we must choose one date using the format ``YYYY-MM-DD``. We're going to build +# our database using the day 2016-04-13 but feel free to experiment with any other available dates. +# +# It shouldn't take long to load the data. transit.load_date("2016-04-13") -# Now we execute the map matching to find the real paths. -# Depending on the GTFS size, this process can be really time-consuming. -transit.set_allow_map_match(True) +# %% +# Now we execute the map matching to find the real paths. Depending on the number or different route +# patterns and/or the project area size, this process can be really time-consuming. transit.map_match() -# Finally, we save our GTFS into our model. +# %% +# Finally, we save our GTFS feed into our model. transit.save_to_disk() # %% -# Now we will plot one of the route's patterns we just imported +# Now we will plot the route's patterns we just imported conn = database_connection("transit") links = pd.read_sql("SELECT pattern_id, ST_AsText(geometry) geom FROM routes;", con=conn) - -stops = pd.read_sql("""SELECT stop_id, ST_X(geometry) X, ST_Y(geometry) Y FROM stops""", con=conn) +stops = pd.read_sql("SELECT stop_id, ST_X(geometry) X, ST_Y(geometry) Y FROM stops;", con=conn) # %% gtfs_links = folium.FeatureGroup("links") From a60d6ec58c330b55fa534a5330de712b3a5c24b0 Mon Sep 17 00:00:00 2001 From: Renata Imai Date: Tue, 18 Jun 2024 15:50:07 -0300 Subject: [PATCH 08/18] docs and fare setup --- .../transit/tables/agencies.sql | 2 +- .../transit/tables/fare_attributes.sql | 2 +- .../transit/tables/fare_zones.sql | 2 +- .../transit/tables/routes.sql | 2 +- .../transit/tables/stops.sql | 2 +- .../transit/tables/trips.sql | 2 +- aequilibrae/transit/gtfs_loader.py | 6 ------ aequilibrae/transit/lib_gtfs.py | 14 ++++++++------ 8 files changed, 14 insertions(+), 18 deletions(-) diff --git a/aequilibrae/project/database_specification/transit/tables/agencies.sql b/aequilibrae/project/database_specification/transit/tables/agencies.sql index dcca9e9ea..834d156b6 100644 --- a/aequilibrae/project/database_specification/transit/tables/agencies.sql +++ b/aequilibrae/project/database_specification/transit/tables/agencies.sql @@ -1,7 +1,7 @@ --@ The *agencies* table holds information about the Public Transport --@ agencies within the GTFS data. This table information comes from --@ GTFS file *agency.txt*. ---@ You can check out more information `here `_. +--@ You can check out more information `here `_. --@ --@ **agency_id** identifies the agency for the specified route --@ diff --git a/aequilibrae/project/database_specification/transit/tables/fare_attributes.sql b/aequilibrae/project/database_specification/transit/tables/fare_attributes.sql index bcf9849bb..9770fc5f1 100644 --- a/aequilibrae/project/database_specification/transit/tables/fare_attributes.sql +++ b/aequilibrae/project/database_specification/transit/tables/fare_attributes.sql @@ -1,7 +1,7 @@ --@ The *fare_attributes* table holds information about the fare values. --@ This table information comes from the GTFS file *fare_attributes.txt*. --@ Given that this file is optional in GTFS, it can be empty. ---@ You can check out more information `here `_. +--@ You can check out more information `here `_. --@ --@ **fare_id** identifies a fare class --@ diff --git a/aequilibrae/project/database_specification/transit/tables/fare_zones.sql b/aequilibrae/project/database_specification/transit/tables/fare_zones.sql index fd5f5116d..6f1e0bf7d 100644 --- a/aequilibrae/project/database_specification/transit/tables/fare_zones.sql +++ b/aequilibrae/project/database_specification/transit/tables/fare_zones.sql @@ -6,6 +6,6 @@ --@ **transit_zone** identifies the TAZ for a fare zone CREATE TABLE IF NOT EXISTS fare_zones ( - fare_zone_id INTEGER NOT NULL PRIMARY KEY, + fare_zone_id INTEGER NOT NULL, transit_zone TEXT NOT NULL ); \ No newline at end of file diff --git a/aequilibrae/project/database_specification/transit/tables/routes.sql b/aequilibrae/project/database_specification/transit/tables/routes.sql index f60cf2361..57ca26fc8 100644 --- a/aequilibrae/project/database_specification/transit/tables/routes.sql +++ b/aequilibrae/project/database_specification/transit/tables/routes.sql @@ -1,6 +1,6 @@ --@ The *routes* table holds information on the available transit routes for a --@ specific day. This table information comes from the GTFS file *routes.txt*. ---@ You can find more information about it `here `_. +--@ You can find more information about it `here `_. --@ --@ **pattern_id** is an unique pattern for the route --@ diff --git a/aequilibrae/project/database_specification/transit/tables/stops.sql b/aequilibrae/project/database_specification/transit/tables/stops.sql index 90c33d7d7..881440f8f 100644 --- a/aequilibrae/project/database_specification/transit/tables/stops.sql +++ b/aequilibrae/project/database_specification/transit/tables/stops.sql @@ -1,7 +1,7 @@ --@ The *stops* table holds information on the stops where vehicles --@ pick up or drop off riders. This table information comes from --@ the GTFS file *stops.txt*. You can find more information about ---@ it `here `_. +--@ it `here `_. --@ --@ **stop_id** is an unique identifier for a stop --@ diff --git a/aequilibrae/project/database_specification/transit/tables/trips.sql b/aequilibrae/project/database_specification/transit/tables/trips.sql index 55b64fc38..a560706c4 100644 --- a/aequilibrae/project/database_specification/transit/tables/trips.sql +++ b/aequilibrae/project/database_specification/transit/tables/trips.sql @@ -1,6 +1,6 @@ --@ The *trips* table holds information on trips for each route. --@ This table comes from the GTFS file *trips.txt*. ---@ You can find more information about it `here `_. +--@ You can find more information about it `here `_. --@ --@ **trip_id** identifies a trip --@ diff --git a/aequilibrae/transit/gtfs_loader.py b/aequilibrae/transit/gtfs_loader.py index c3961dfe0..9eb96a20c 100644 --- a/aequilibrae/transit/gtfs_loader.py +++ b/aequilibrae/transit/gtfs_loader.py @@ -253,12 +253,6 @@ def __load_fare_data(self): fr.fare_id = self.fare_attributes[fr.fare].fare_id if fr.route in self.routes: fr.route_id = self.routes[fr.route].route_id - # fr.agency_id = self.fare_attributes[fr.fare].agency_id - # zone_id = fr.agency_id * AGENCY_MULTIPLIER + 1 - # for x in [fr.origin, fr.destination]: - # if x not in corresp: - # corresp[x] = zone_id - # zone_id += 1 fr.origin_id = None if fr.origin == "" else int(fr.origin) fr.destination_id = None if fr.destination == "" else int(fr.destination) self.fare_rules.append(fr) diff --git a/aequilibrae/transit/lib_gtfs.py b/aequilibrae/transit/lib_gtfs.py index b8b5e8e88..77ec71cb1 100644 --- a/aequilibrae/transit/lib_gtfs.py +++ b/aequilibrae/transit/lib_gtfs.py @@ -211,18 +211,13 @@ def save_to_disk(self): zone_ids2 = {x.destination: x.destination_id for x in self.gtfs_data.fare_rules if x.destination_id >= 0} zone_ids = {**zone_ids1, **zone_ids2} - # zones = [[y, x] for x, y in list(zone_ids.items())] - # if zones: - # sql = "Insert into fare_zones (fare_zone_id, transit_zone) values(?, ?);" - # conn.executemany(sql, zones) - # conn.commit() - for fare in self.gtfs_data.fare_attributes.values(): fare.save_to_database(conn) for fare_rule in self.gtfs_data.fare_rules: fare_rule.save_to_database(conn) + zones = [] for counter, (_, stop) in enumerate(self.select_stops.items()): if stop.zone in zone_ids: stop.zone_id = zone_ids[stop.zone] @@ -230,9 +225,16 @@ def save_to_disk(self): closest_zone = self.project.zoning.get_closest_zone(stop.geo) if stop.geo.within(self.project.zoning.get(closest_zone).geometry): stop.taz = closest_zone + if zone_ids: + zones.append([zone_ids[stop.zone], closest_zone]) stop.save_to_database(conn, commit=False) conn.commit() + if zones: + sql = "insert into fare_zones (fare_zone_id, transit_zone) values(?, ?);" + conn.executemany(sql, zones) + conn.commit() + self.__outside_zones = None in [x.taz for x in self.select_stops.values()] if self.__outside_zones: msg = " Some stops are outside the zoning system. Check the result on a map and see the log for info" From 18201b553d6d1295f110a417a6da20183a117539 Mon Sep 17 00:00:00 2001 From: Renata Imai Date: Tue, 18 Jun 2024 15:52:13 -0300 Subject: [PATCH 09/18] linting --- aequilibrae/transit/lib_gtfs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aequilibrae/transit/lib_gtfs.py b/aequilibrae/transit/lib_gtfs.py index 77ec71cb1..3a3b815ae 100644 --- a/aequilibrae/transit/lib_gtfs.py +++ b/aequilibrae/transit/lib_gtfs.py @@ -234,7 +234,7 @@ def save_to_disk(self): sql = "insert into fare_zones (fare_zone_id, transit_zone) values(?, ?);" conn.executemany(sql, zones) conn.commit() - + self.__outside_zones = None in [x.taz for x in self.select_stops.values()] if self.__outside_zones: msg = " Some stops are outside the zoning system. Check the result on a map and see the log for info" From 39d4f419b7b3fefa714af9b8a1a6b642ac488011 Mon Sep 17 00:00:00 2001 From: Renata Imai Date: Tue, 2 Jul 2024 15:30:59 -0300 Subject: [PATCH 10/18] rollback tables --- .../database_specification/transit/tables/fare_zones.sql | 8 ++++++-- .../database_specification/transit/tables/stops.sql | 4 ++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/aequilibrae/project/database_specification/transit/tables/fare_zones.sql b/aequilibrae/project/database_specification/transit/tables/fare_zones.sql index 6f1e0bf7d..b5f8ded58 100644 --- a/aequilibrae/project/database_specification/transit/tables/fare_zones.sql +++ b/aequilibrae/project/database_specification/transit/tables/fare_zones.sql @@ -4,8 +4,12 @@ --@ **fare_zone_id** identifies the fare zone for a stop --@ --@ **transit_zone** identifies the TAZ for a fare zone +--@ +--@ **agency_id** identifies the agency fot the specified route CREATE TABLE IF NOT EXISTS fare_zones ( - fare_zone_id INTEGER NOT NULL, - transit_zone TEXT NOT NULL + fare_zone_id INTEGER PRIMARY KEY, + transit_zone TEXT NOT NULL, + agency_id INTEGER NOT NULL, + FOREIGN KEY(agency_id) REFERENCES agencies(agency_id) deferrable initially deferred ); \ No newline at end of file diff --git a/aequilibrae/project/database_specification/transit/tables/stops.sql b/aequilibrae/project/database_specification/transit/tables/stops.sql index 881440f8f..3a8fd280e 100644 --- a/aequilibrae/project/database_specification/transit/tables/stops.sql +++ b/aequilibrae/project/database_specification/transit/tables/stops.sql @@ -7,6 +7,8 @@ --@ --@ **stop** idenfifies a stop, statio, or station entrance --@ +--@ **agency_id** identifies the agency fot the specified route +--@ --@ **link** identifies the *link_id* in the links table that corresponds to the --@ pattern matching --@ @@ -28,6 +30,7 @@ CREATE TABLE IF NOT EXISTS stops ( stop_id TEXT PRIMARY KEY, stop TEXT NOT NULL , + agency_id INTEGER NOT NULL, link INTEGER, dir INTEGER, name TEXT, @@ -36,6 +39,7 @@ CREATE TABLE IF NOT EXISTS stops ( fare_zone_id INTEGER, transit_zone TEXT, route_type INTEGER NOT NULL DEFAULT -1, + FOREIGN KEY(agency_id) REFERENCES agencies(agency_id), FOREIGN KEY("fare_zone_id") REFERENCES fare_zones("fare_zone_id") ); From dcf9d53e91888aa5e51a79932378303fc05bbf3e Mon Sep 17 00:00:00 2001 From: Renata Imai Date: Tue, 2 Jul 2024 17:47:14 -0300 Subject: [PATCH 11/18] add agency_id --- aequilibrae/transit/gtfs_loader.py | 1 - aequilibrae/transit/transit_elements/stop.py | 5 +++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/aequilibrae/transit/gtfs_loader.py b/aequilibrae/transit/gtfs_loader.py index 9eb96a20c..c8718937e 100644 --- a/aequilibrae/transit/gtfs_loader.py +++ b/aequilibrae/transit/gtfs_loader.py @@ -245,7 +245,6 @@ def __load_fare_data(self): farerl = parse_csv(file, column_order[farerltxt]) self.data_arrays[farerltxt] = farerl - # corresp = {} for line in range(farerl.shape[0]): data = tuple(farerl[line][list(column_order[farerltxt].keys())]) fr = FareRule() diff --git a/aequilibrae/transit/transit_elements/stop.py b/aequilibrae/transit/transit_elements/stop.py index b40ab3666..341159212 100644 --- a/aequilibrae/transit/transit_elements/stop.py +++ b/aequilibrae/transit/transit_elements/stop.py @@ -33,6 +33,7 @@ def __init__(self, record: tuple, headers: list): # Not part of GTFS self.taz = None + self.agency_id = None self.link = None self.dir = None self.srid = -1 @@ -56,9 +57,9 @@ def __init__(self, record: tuple, headers: list): def save_to_database(self, conn: Connection, commit=True) -> None: """Saves Transit Stop to the database""" - sql = """insert into stops (stop_id, stop, link, dir, name, + sql = """insert into stops (stop_id, stop, agency_id, link, dir, name, parent_station, description, fare_zone_id, transit_zone, route_type, geometry) - values (?,?,?,?,?,?,?,?,?,?, GeomFromWKB(?, ?));""" + values (?,?,?,?,?,?,?,?,?,?,?, GeomFromWKB(?, ?));""" dt = self.data conn.execute(sql, dt) From 343788af033a7aa301273f0b25c033599f7dd271 Mon Sep 17 00:00:00 2001 From: Renata Imai Date: Wed, 3 Jul 2024 16:56:21 -0300 Subject: [PATCH 12/18] . --- .../transit/tables/fare_zones.sql | 4 +- aequilibrae/transit/lib_gtfs.py | 39 ++++++++++++++----- aequilibrae/transit/parse_csv.py | 2 +- aequilibrae/transit/transit_elements/stop.py | 5 ++- .../project/test_transit_tables.py | 3 +- tests/aequilibrae/transit/test_gtfs_stop.py | 3 +- 6 files changed, 39 insertions(+), 17 deletions(-) diff --git a/aequilibrae/project/database_specification/transit/tables/fare_zones.sql b/aequilibrae/project/database_specification/transit/tables/fare_zones.sql index b5f8ded58..d07145172 100644 --- a/aequilibrae/project/database_specification/transit/tables/fare_zones.sql +++ b/aequilibrae/project/database_specification/transit/tables/fare_zones.sql @@ -9,7 +9,7 @@ CREATE TABLE IF NOT EXISTS fare_zones ( fare_zone_id INTEGER PRIMARY KEY, - transit_zone TEXT NOT NULL, - agency_id INTEGER NOT NULL, + transit_zone TEXT, + agency_id INTEGER, FOREIGN KEY(agency_id) REFERENCES agencies(agency_id) deferrable initially deferred ); \ No newline at end of file diff --git a/aequilibrae/transit/lib_gtfs.py b/aequilibrae/transit/lib_gtfs.py index 3a3b815ae..8c69e0daa 100644 --- a/aequilibrae/transit/lib_gtfs.py +++ b/aequilibrae/transit/lib_gtfs.py @@ -206,32 +206,51 @@ def save_to_disk(self): link.save_to_database(conn, commit=False) conn.commit() - self.__outside_zones = 0 - zone_ids1 = {x.origin: x.origin_id for x in self.gtfs_data.fare_rules if x.origin_id >= 0} - zone_ids2 = {x.destination: x.destination_id for x in self.gtfs_data.fare_rules if x.destination_id >= 0} - zone_ids = {**zone_ids1, **zone_ids2} - for fare in self.gtfs_data.fare_attributes.values(): fare.save_to_database(conn) for fare_rule in self.gtfs_data.fare_rules: fare_rule.save_to_database(conn) + sql = """WITH t1 AS ( + SELECT from_stop stop_id, pattern_id FROM route_links + UNION ALL + SELECT to_stop stop_id, pattern_id FROM route_links + ), + t2 AS ( + SELECT route_id, pattern_id, agency_id FROM routes + ), + t3 AS ( + SELECT t1.stop_id, t2.agency_id, COUNT(*) as frequency + FROM t1 + JOIN t2 ON t1.pattern_id = t2.pattern_id + GROUP BY t1.stop_id, t2.agency_id + ) + SELECT t3.stop_id, t3.agency_id + FROM t3 + WHERE t3.frequency = ( + SELECT MAX(frequency) + FROM t3 AS sub + WHERE sub.stop_id = t3.stop_id + );""" + frequent_agency = conn.execute(sql).fetchall() + zones = [] for counter, (_, stop) in enumerate(self.select_stops.items()): - if stop.zone in zone_ids: - stop.zone_id = zone_ids[stop.zone] if self.__has_taz: closest_zone = self.project.zoning.get_closest_zone(stop.geo) if stop.geo.within(self.project.zoning.get(closest_zone).geometry): stop.taz = closest_zone - if zone_ids: - zones.append([zone_ids[stop.zone], closest_zone]) + stop.agency_id = frequent_agency[counter][1] + if stop.zone_id: + zones.append((stop.zone_id, "", stop.agency_id)) stop.save_to_database(conn, commit=False) conn.commit() + zones = list(set(zones)) + if zones: - sql = "insert into fare_zones (fare_zone_id, transit_zone) values(?, ?);" + sql = "insert into fare_zones (fare_zone_id, transit_zone, agency_id) values(?,?,?);" conn.executemany(sql, zones) conn.commit() diff --git a/aequilibrae/transit/parse_csv.py b/aequilibrae/transit/parse_csv.py index 408febda2..2709a1c86 100644 --- a/aequilibrae/transit/parse_csv.py +++ b/aequilibrae/transit/parse_csv.py @@ -35,7 +35,7 @@ def parse_csv(file_name: str, column_order=[]): # noqa B006 missing_cols_names = [x for x in column_order.keys() if x not in data.dtype.names] for col in missing_cols_names: - data = append_fields(data, col, np.array([""] * len(tot))) + data = append_fields(data, col, np.array([""] * len(tot)), usemask=False) if column_order: col_names = [x for x in column_order.keys() if x in data.dtype.names] diff --git a/aequilibrae/transit/transit_elements/stop.py b/aequilibrae/transit/transit_elements/stop.py index 341159212..c70fb0ad6 100644 --- a/aequilibrae/transit/transit_elements/stop.py +++ b/aequilibrae/transit/transit_elements/stop.py @@ -51,8 +51,8 @@ def __init__(self, record: tuple, headers: list): if None not in [self.stop_lon, self.stop_lat]: self.geo = Point(self.stop_lon, self.stop_lat) - if len(str(self.zone_id)) == 0: - self.zone_id = None + if len(self.zone) > 0: + self.zone_id = int(self.zone) def save_to_database(self, conn: Connection, commit=True) -> None: """Saves Transit Stop to the database""" @@ -71,6 +71,7 @@ def data(self) -> list: return [ self.stop_id, self.stop, + self.agency_id, self.link, self.dir, self.stop_name, diff --git a/tests/aequilibrae/project/test_transit_tables.py b/tests/aequilibrae/project/test_transit_tables.py index 7bac0cf50..c12f7407c 100644 --- a/tests/aequilibrae/project/test_transit_tables.py +++ b/tests/aequilibrae/project/test_transit_tables.py @@ -18,7 +18,7 @@ def create_project(project: Project): ["fare_id", "fare", "agency_id", "price", "currency", "payment_method", "transfer", "transfer_duration"], ), ("fare_rules", ["fare_id", "route_id", "origin", "destination", "contains"]), - ("fare_zones", ["fare_zone_id", "transit_zone"]), + ("fare_zones", ["fare_zone_id", "transit_zone", "agency_id"]), ("pattern_mapping", ["pattern_id", "seq", "link", "dir", "geometry"]), ( "routes", @@ -42,6 +42,7 @@ def create_project(project: Project): [ "stop_id", "stop", + "agency_id", "link", "dir", "name", diff --git a/tests/aequilibrae/transit/test_gtfs_stop.py b/tests/aequilibrae/transit/test_gtfs_stop.py index eacd33fb0..bc2ea9674 100644 --- a/tests/aequilibrae/transit/test_gtfs_stop.py +++ b/tests/aequilibrae/transit/test_gtfs_stop.py @@ -18,7 +18,7 @@ def data(): "stop_desc": randomword(randint(0, 40)), "stop_lat": uniform(0, 30000), "stop_lon": uniform(0, 30000), - "zone_id": randomword(randint(0, 40)), + "zone_id": str(randint(0, 40)), "stop_url": randomword(randint(0, 40)), "location_type": choice((0, 1)), "parent_station": randomword(randint(0, 40)), @@ -51,6 +51,7 @@ def test_save_to_database(data, transit_conn): s.link = link = randint(1, 30000) s.dir = direc = choice((0, 1)) s.route_type = randint(0, 13) + s.agency_id = randint(1, 10) s.srid = get_srid() s.get_node_id() s.save_to_database(transit_conn, commit=True) From 5013174cf84c60f3d016645c5582e189b151317b Mon Sep 17 00:00:00 2001 From: Renata Imai Date: Thu, 15 Aug 2024 14:15:17 -0300 Subject: [PATCH 13/18] undo conflicts --- aequilibrae/transit/gtfs_loader.py | 41 +++++++----------------------- aequilibrae/transit/lib_gtfs.py | 5 +++- 2 files changed, 13 insertions(+), 33 deletions(-) diff --git a/aequilibrae/transit/gtfs_loader.py b/aequilibrae/transit/gtfs_loader.py index e207233c9..d5004c4eb 100644 --- a/aequilibrae/transit/gtfs_loader.py +++ b/aequilibrae/transit/gtfs_loader.py @@ -46,7 +46,6 @@ def __init__(self): self.wgs84 = pyproj.Proj("epsg:4326") self.srid = get_srid() self.transformer = Transformer.from_crs("epsg:4326", f"epsg:{self.srid}", always_xy=False) - self.__mt = "" self.agency_correspondence = {} self.logger = get_logger() @@ -67,7 +66,6 @@ def set_feed_path(self, file_path): self.zip_archive.close() self.feed_date = splitext(basename(file_path))[0] - self.__mt = "Reading GTFS" def _set_capacities(self, capacities: dict): self.__capacities__ = capacities @@ -87,10 +85,7 @@ def load_data(self, service_date: str, description: str): self.service_date = service_date self.description = description self.logger.info(f"Loading data for {self.service_date} from the GTFS feed. This may take some time") - - self.__mt = "Reading GTFS" - self.signal.emit(["start", "master", 6, self.__mt, self.__mt]) - + self.__load_date() def __load_date(self): @@ -98,7 +93,7 @@ def __load_date(self): self.zip_archive = zipfile.ZipFile(self.archive_dir) self.__load_agencies() - + self.signal.emit(["start", 7, "Loading routes"]) self.__load_routes_table() @@ -127,9 +122,6 @@ def __load_date(self): def __deconflict_stop_times(self) -> None: self.logger.info("Starting deconflict_stop_times") - msg_txt = "Interpolating stop times for feed agencies" - self.signal.emit(["start", "secondary", len(self.trips), msg_txt, self.__mt]) - total_fast = 0 for prog_counter, route in enumerate(self.trips): max_speeds = self.__max_speeds__.get(self.routes[route].route_type, pd.DataFrame([])) @@ -270,9 +262,6 @@ def __load_shapes_table(self): all_shape_ids = np.unique(shapes["shape_id"]).tolist() - msg_txt = "Load shapes" - self.signal.emit(["start", "secondary", len(all_shape_ids), msg_txt, self.__mt]) - self.data_arrays[shapestxt] = shapes lats, lons = self.transformer.transform(shapes[:]["shape_pt_lat"], shapes[:]["shape_pt_lon"]) shapes[:]["shape_pt_lat"][:] = lats[:] @@ -294,9 +283,6 @@ def __load_trips_table(self): trips_array = parse_csv(file, column_order[tripstxt]) self.data_arrays[tripstxt] = trips_array - msg_txt = "Load trips" - self.signal.emit(["start", "secondary", trips_array.shape[0], msg_txt, self.__mt]) - if np.unique(trips_array["trip_id"]).shape[0] < trips_array.shape[0]: self.__fail("There are repeated trip IDs in trips.txt") @@ -348,7 +334,7 @@ def __load_trips_table(self): trip.source_time = list(stop_times.source_time.values) self.logger.debug(f"{trip.trip} has {len(trip.stops)} stops") trip._stop_based_shape = LineString([self.stops[x].geo for x in trip.stops]) - + # trip.shape = self.shapes.get(trip.shape) trip.pce = self.routes[trip.route].pce trip.seated_capacity = self.routes[trip.route].seated_capacity @@ -403,8 +389,6 @@ def __load_stop_times(self): with self.zip_archive.open(stoptimestxt, "r") as file: stoptimes = parse_csv(file, column_order[stoptimestxt]) self.data_arrays[stoptimestxt] = stoptimes - - msg_txt = "Load stop times" df = pd.DataFrame(stoptimes) for col in ["arrival_time", "departure_time"]: @@ -463,10 +447,7 @@ def __load_stops_table(self): stops[:]["stop_lat"][:] = lats[:] stops[:]["stop_lon"][:] = lons[:] - msg_txt = "Load stops" - self.signal.emit(["start", "secondary", stops.shape[0], msg_txt, self.__mt]) for i, line in enumerate(stops): - self.signal.emit(["update", "secondary", i + 1, msg_txt, self.__mt]) s = Stop(line, stops.dtype.names) s.srid = self.srid s.get_node_id() @@ -485,21 +466,21 @@ def __load_routes_table(self): if np.unique(routes["route_id"]).shape[0] < routes.shape[0]: self.__fail("There are repeated route IDs in routes.txt") - msg_txt = "Load Routes" - self.signal.emit(["start", "secondary", len(routes), msg_txt, self.__mt]) - - cap = self.__capacities__.get("other", [None, None, None]) - + seated_cap, total_cap = self.__capacities__.get("other", [None, None]) routes = pd.DataFrame(routes) routes = routes.assign(seated_capacity=seated_cap, total_capacity=total_cap, srid=self.srid) for route_type, cap in self.__capacities__.items(): routes.loc[routes.route_type == route_type, ["seated_capacity", "total_capacity"]] = cap + default_pce = self.__pces__.get("other", 2.0) + routes = routes.assign(pce=default_pce) + for route_type, pce in self.__pces__.items(): + routes.loc[routes.route_type == route_type, ["pce"]] = pce + agency_finder = routes["agency_id"].values.tolist() routes.drop(columns="agency_id", inplace=True) for i, line in routes.iterrows(): - self.signal.emit(["update", "secondary", i + 1, msg_txt, self.__mt]) r = Route(self.agency_correspondence[agency_finder[i]]) r.populate(line.values, routes.columns) self.routes[r.route] = r @@ -607,11 +588,7 @@ def __load_agencies(self): agencies = parse_csv(file, column_order[agencytxt]) self.data_arrays[agencytxt] = agencies - msg_txt = "Load Agencies" - self.signal.emit(["start", "secondary", len(agencies), msg_txt, self.__mt]) - for i, line in enumerate(agencies): - self.signal.emit(["update", "secondary", i + 1, msg_txt, self.__mt]) a = Agency() a.agency = line["agency_name"] a.feed_date = self.feed_date diff --git a/aequilibrae/transit/lib_gtfs.py b/aequilibrae/transit/lib_gtfs.py index 20133b5a2..eb4bb12fb 100644 --- a/aequilibrae/transit/lib_gtfs.py +++ b/aequilibrae/transit/lib_gtfs.py @@ -20,7 +20,9 @@ class GTFSRouteSystemBuilder: """Container for GTFS feeds providing data retrieval for the importer""" - def __init__(self, network, file_path, description="", capacities={}): # noqa: B006 + signal = SIGNAL(object) + + def __init__(self, network, file_path, description="", capacities=None, pces=None): # noqa: B006 """Instantiates a transit class for the network :Arguments: @@ -48,6 +50,7 @@ def __init__(self, network, file_path, description="", capacities={}): # noqa: self.transformer = Transformer.from_crs("epsg:4326", f"epsg:{self.srid}", always_xy=False) self.sridproj = pyproj.Proj(f"epsg:{self.srid}") self.__default_capacities = capacities + self.__default_pces = {} if pces is None else pces self.__do_execute_map_matching = False self.__target_date__ = None self.__outside_zones = 0 From b1b7d2b5e54d1e5cd6fc57af4164f274a6129ace Mon Sep 17 00:00:00 2001 From: Renata Imai Date: Thu, 15 Aug 2024 14:33:11 -0300 Subject: [PATCH 14/18] Update lib_gtfs.py --- aequilibrae/transit/lib_gtfs.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/aequilibrae/transit/lib_gtfs.py b/aequilibrae/transit/lib_gtfs.py index eb4bb12fb..c03b43c23 100644 --- a/aequilibrae/transit/lib_gtfs.py +++ b/aequilibrae/transit/lib_gtfs.py @@ -28,15 +28,12 @@ def __init__(self, network, file_path, description="", capacities=None, pces=Non :Arguments: **local network** (:obj:`Network`): Supply model to which this GTFS will be imported - **file_path** (:obj:`str`): Full path to the GTFS feed (e.g. 'D:/project/my_gtfs_feed.zip') - **description** (:obj:`str`, *Optional*): Description for this feed (e.g. 'CTA19 fixed by John after coffee') """ self.__network = network self.project = get_active_project(False) self.archive_dir = None # type: str - self.day = None self.logger = get_logger() self.gtfs_data = GTFSReader() @@ -49,7 +46,7 @@ def __init__(self, network, file_path, description="", capacities=None, pces=Non self.description = description self.transformer = Transformer.from_crs("epsg:4326", f"epsg:{self.srid}", always_xy=False) self.sridproj = pyproj.Proj(f"epsg:{self.srid}") - self.__default_capacities = capacities + self.__default_capacities = {} if capacities is None else capacities self.__default_pces = {} if pces is None else pces self.__do_execute_map_matching = False self.__target_date__ = None From 90e3e9175b019221d4664017ec1bf3af6f069dfe Mon Sep 17 00:00:00 2001 From: Renata Imai Date: Thu, 15 Aug 2024 14:46:30 -0300 Subject: [PATCH 15/18] . --- aequilibrae/transit/lib_gtfs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aequilibrae/transit/lib_gtfs.py b/aequilibrae/transit/lib_gtfs.py index c03b43c23..62fab0de2 100644 --- a/aequilibrae/transit/lib_gtfs.py +++ b/aequilibrae/transit/lib_gtfs.py @@ -34,6 +34,7 @@ def __init__(self, network, file_path, description="", capacities=None, pces=Non self.__network = network self.project = get_active_project(False) self.archive_dir = None # type: str + self.day = None self.logger = get_logger() self.gtfs_data = GTFSReader() From a5978c30733e99f7a355f49fc0257ff532d99289 Mon Sep 17 00:00:00 2001 From: Renata Imai Date: Thu, 19 Sep 2024 08:27:01 -0300 Subject: [PATCH 16/18] linting --- aequilibrae/transit/lib_gtfs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aequilibrae/transit/lib_gtfs.py b/aequilibrae/transit/lib_gtfs.py index 54c849550..9893bd2de 100644 --- a/aequilibrae/transit/lib_gtfs.py +++ b/aequilibrae/transit/lib_gtfs.py @@ -30,7 +30,7 @@ def __init__(self, network, file_path, description="", capacities=None, pces=Non **local network** (:obj:`Network`): Supply model to which this GTFS will be imported **file_path** (:obj:`str`): Full path to the GTFS feed (e.g. 'D:/project/my_gtfs_feed.zip') - + **description** (:obj:`str`, *Optional*): Description for this feed (e.g. 'CTA19 fixed by John after coffee') """ self.__network = network From e16e09fba73c6087d75124b323b3c4fdff348a25 Mon Sep 17 00:00:00 2001 From: Renata Imai Date: Fri, 11 Oct 2024 11:01:52 -0300 Subject: [PATCH 17/18] fix tests --- .../database_specification/transit/tables/fare_zones.sql | 4 +++- .../project/database_specification/transit/tables/stops.sql | 2 +- tests/aequilibrae/project/test_transit_tables.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/aequilibrae/project/database_specification/transit/tables/fare_zones.sql b/aequilibrae/project/database_specification/transit/tables/fare_zones.sql index f578a366b..25f08abc9 100644 --- a/aequilibrae/project/database_specification/transit/tables/fare_zones.sql +++ b/aequilibrae/project/database_specification/transit/tables/fare_zones.sql @@ -1,7 +1,9 @@ --@ The *fare_zones* table hold information on the transit fare zones and --@ the transit agencies that operate in it. --@ ---@ **transit_fare_zone** identifies the transit fare zones +--@ **fare_zone_id** identifies the fare zone +--@ +--@ **transit_zone** identifies the transit fare zones --@ --@ **agency_id** identifies the agency/agencies for the specified fare zone diff --git a/aequilibrae/project/database_specification/transit/tables/stops.sql b/aequilibrae/project/database_specification/transit/tables/stops.sql index 8b08f8a2a..850d84b86 100644 --- a/aequilibrae/project/database_specification/transit/tables/stops.sql +++ b/aequilibrae/project/database_specification/transit/tables/stops.sql @@ -23,7 +23,7 @@ --@ --@ **fare_zone_id** identifies the fare zone for a stop --@ ---@ **transit_fare_zone** identifies the transit fare zone for a stop +--@ **transit_zone** identifies the transit fare zone for a stop --@ --@ **route_type** indicates the type of transporation used on a route diff --git a/tests/aequilibrae/project/test_transit_tables.py b/tests/aequilibrae/project/test_transit_tables.py index e025ccf21..27249aebd 100644 --- a/tests/aequilibrae/project/test_transit_tables.py +++ b/tests/aequilibrae/project/test_transit_tables.py @@ -18,7 +18,7 @@ def create_project(project: Project): ["fare_id", "fare", "agency_id", "price", "currency", "payment_method", "transfer", "transfer_duration"], ), ("fare_rules", ["fare_id", "route_id", "origin", "destination", "contains"]), - ("fare_zones", ["transit_fare_zone", "agency_id"]), + ("fare_zones", ["fare_zone_id", "transit_zone", "agency_id"]), ("pattern_mapping", ["pattern_id", "seq", "link", "dir", "geometry"]), ( "routes", From 3e9ae67eada73ab76b0febd0dce84bc3588cc514 Mon Sep 17 00:00:00 2001 From: Renata Imai Date: Mon, 14 Oct 2024 17:10:55 -0300 Subject: [PATCH 18/18] Update documentation.yml --- .github/workflows/documentation.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index bad2dacf5..fdc6bb7af 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -57,7 +57,7 @@ jobs: - name: Build documentation run: | jupyter nbconvert --to rst docs/source/useful_information/validation_benchmarking/IPF_benchmark.ipynb - sphinx-build -M latexpdf docs/source docs/source/_static + # sphinx-build -M latexpdf docs/source docs/source/_static sphinx-build -b html docs/source docs/build python3 -m zipfile -c AequilibraE.zip docs/build cp AequilibraE.zip docs/source/_static