Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support to multiple transit agencies #535

Open
wants to merge 24 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
--@ The *agencies* table holds information about the Public Transport
--@ agencies within the GTFS data. This table information comes from
--@ GTFS file *agency.txt*.
--@ You can check out more information `here <https://developers.google.com/transit/gtfs/reference#agencytxt>`_.
--@ You can check out more information `here <https://gtfs.org/schedule/reference/#agencytxt>`_.
--@
--@ **agency_id** identifies the agency for the specified route
--@
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
--@ The *fare_attributes* table holds information about the fare values.
--@ This table information comes from the GTFS file *fare_attributes.txt*.
--@ Given that this file is optional in GTFS, it can be empty.
--@ You can check out more information `here <https://developers.google.com/transit/gtfs/reference#fare_attributestxt>`_.
--@ You can check out more information `here <https://gtfs.org/schedule/reference/#fare_attributestxt>`_.
--@
--@ **fare_id** identifies a fare class
--@
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,8 @@
--@ **fare_zone_id** identifies the fare zone for a stop
--@
--@ **transit_zone** identifies the TAZ for a fare zone
--@
--@ **agency_id** identifies the agency fot the specified route

CREATE TABLE IF NOT EXISTS fare_zones (
fare_zone_id INTEGER PRIMARY KEY,
transit_zone TEXT NOT NULL,
agency_id INTEGER NOT NULL,
r-akemii marked this conversation as resolved.
Show resolved Hide resolved
FOREIGN KEY(agency_id) REFERENCES agencies(agency_id) deferrable initially deferred
fare_zone_id INTEGER NOT NULL,
transit_zone TEXT NOT NULL
);
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
--@ The *routes* table holds information on the available transit routes for a
--@ specific day. This table information comes from the GTFS file *routes.txt*.
--@ You can find more information about it `here <https://developers.google.com/transit/gtfs/reference#routestxt>`_.
--@ You can find more information about it `here <https://gtfs.org/schedule/reference/#routestxt>`_.
--@
--@ **pattern_id** is an unique pattern for the route
--@
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
--@ The *stops* table holds information on the stops where vehicles
--@ pick up or drop off riders. This table information comes from
--@ the GTFS file *stops.txt*. You can find more information about
--@ it `here <https://developers.google.com/transit/gtfs/reference#stopstxt>`_.
--@ it `here <https://gtfs.org/schedule/reference/#stopstxt>`_.
--@
--@ **stop_id** is an unique identifier for a stop
--@
--@ **stop** idenfifies a stop, statio, or station entrance
--@
--@ **agency_id** identifies the agency fot the specified route
--@
--@ **link** identifies the *link_id* in the links table that corresponds to the
--@ pattern matching
--@
Expand All @@ -21,8 +19,6 @@
--@
--@ **description** provides useful description of the stop location
--@
--@ **street** identifies the address of a stop
--@
--@ **fare_zone_id** identifies the fare zone for a stop
--@
--@ **transit_zone** identifies the TAZ for a fare zone
Expand All @@ -32,17 +28,14 @@
CREATE TABLE IF NOT EXISTS stops (
stop_id TEXT PRIMARY KEY,
stop TEXT NOT NULL ,
agency_id INTEGER NOT NULL,
r-akemii marked this conversation as resolved.
Show resolved Hide resolved
link INTEGER,
dir INTEGER,
name TEXT,
parent_station TEXT,
description TEXT,
street TEXT,
fare_zone_id INTEGER,
transit_zone TEXT,
route_type INTEGER NOT NULL DEFAULT -1,
FOREIGN KEY(agency_id) REFERENCES agencies(agency_id),
FOREIGN KEY("fare_zone_id") REFERENCES fare_zones("fare_zone_id")
);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
--@ The *trips* table holds information on trips for each route.
--@ This table comes from the GTFS file *trips.txt*.
--@ You can find more information about it `here <https://developers.google.com/transit/gtfs/reference#tripstxt>`_.
--@ You can find more information about it `here <https://gtfs.org/schedule/reference/#tripstxt>`_.
--@
--@ **trip_id** identifies a trip
--@
Expand Down
15 changes: 8 additions & 7 deletions aequilibrae/transit/column_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
("agency_name", str),
("agency_url", str),
("agency_timezone", str),
("agency_lang", str),
("agency_phone", str),
("agency_fare_url", str),
("agency_email", str),
# ("agency_lang", str),
# ("agency_phone", str),
# ("agency_fare_url", str),
# ("agency_email", str),
]
),
"routes.txt": OrderedDict(
Expand All @@ -24,7 +24,7 @@
# ("route_color", str),
# ("route_text_color", str),
# ("route_sort_order", int),
# ("agency_id", str),
("agency_id", str),
]
),
"trips.txt": OrderedDict(
Expand Down Expand Up @@ -76,7 +76,7 @@
("currency_type", str),
("payment_method", int),
("transfers", int),
# ("agency_id", str),
("agency_id", str),
("transfer_duration", float),
]
),
Expand Down Expand Up @@ -113,13 +113,14 @@
("stop_desc", str),
("stop_lat", float),
("stop_lon", float),
("stop_street", str),
("zone_id", str),
# ("stop_url", str),
# ("location_type", int),
("parent_station", str),
# ("stop_timezone", str),
# ("wheelchair_boarding", int),
# ("level_id", int),
# ("platform_code", str)
]
),
"shapes.txt": OrderedDict(
Expand Down
3 changes: 2 additions & 1 deletion aequilibrae/transit/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
WALK_LINK_RANGE = 30000000
TRANSIT_LINK_RANGE = 20000000
WALK_AGENCY_ID = 1
STOP_ID = 1

# 1 for right, -1 for wrong (left)
DRIVING_SIDE = 1
Expand All @@ -21,7 +22,7 @@ class Constants:
trips: Dict[int, int] = {}
patterns: Dict[int, int] = {}
pattern_lookup: Dict[int, int] = {}
stops: Dict[int, int] = {}
stops: Dict[int, Any] = {}
fares: Dict[int, int] = {}
links: Dict[int, int] = {}
transit_links: Dict[int, int] = {}
84 changes: 56 additions & 28 deletions aequilibrae/transit/gtfs_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self):
self.__capacities__ = {}
self.__max_speeds__ = {}
self.feed_date = ""
self.agency = Agency()
self.agency: Dict[int, Agency] = {}
self.services = {}
self.routes: Dict[int, Route] = {}
self.trips: Dict[int, Dict[Route]] = {}
Expand All @@ -61,6 +61,7 @@ def __init__(self):
self.srid = get_srid()
self.transformer = Transformer.from_crs("epsg:4326", f"epsg:{self.srid}", always_xy=False)
self.__mt = ""
self.agency_correspondence = {}
self.logger = get_logger()

def set_feed_path(self, file_path):
Expand All @@ -80,24 +81,25 @@ def set_feed_path(self, file_path):
self.zip_archive.close()

self.feed_date = splitext(basename(file_path))[0]
self.__mt = f"Reading GTFS for {self.agency.agency}"
self.__mt = "Reading GTFS"

def _set_capacities(self, capacities: dict):
self.__capacities__ = capacities

def _set_maximum_speeds(self, max_speeds: dict):
self.__max_speeds__ = max_speeds

def load_data(self, service_date: str):
def load_data(self, service_date: str, description: str):
"""Loads the data for a respective service date.

:Arguments:
**service_date** (:obj:`str`): service date. e.g. "2020-04-01".
"""
ag_id = self.agency.agency
self.logger.info(f"Loading data for {service_date} from the {ag_id} GTFS feed. This may take some time")
self.service_date = service_date
self.description = description
self.logger.info(f"Loading data for {self.service_date} from the GTFS feed. This may take some time")

self.__mt = f"Reading GTFS for {ag_id}"
self.__mt = "Reading GTFS"
self.signal.emit(["start", "master", 6, self.__mt, self.__mt])

self.__load_date()
Expand All @@ -110,6 +112,8 @@ def finished(self):
def __load_date(self):
self.logger.debug("Starting __load_date")
self.zip_archive = zipfile.ZipFile(self.archive_dir)

self.__load_agencies()
self.__load_routes_table()
self.__load_stops_table()
self.__load_stop_times()
Expand All @@ -123,7 +127,7 @@ def __load_date(self):
def __deconflict_stop_times(self) -> None:
self.logger.info("Starting deconflict_stop_times")

msg_txt = f"Interpolating stop times for {self.agency.agency}"
msg_txt = "Interpolating stop times for feed agencies"
self.signal.emit(["start", "secondary", len(self.trips), msg_txt, self.__mt])
total_fast = 0
for prog_counter, route in enumerate(self.trips):
Expand Down Expand Up @@ -215,13 +219,19 @@ def __load_fare_data(self):
fareatt = parse_csv(file, column_order[fareatttxt])
self.data_arrays[fareatttxt] = fareatt

existing_agencies = np.unique(fareatt["agency_id"])
if existing_agencies.shape[0] != len(self.agency):
self.logger.debug("agency_id exists on fare_attributes.txt but not in agency.txt")
elif existing_agencies.shape[0] == 1 and existing_agencies[0] == "":
fareatt["agency_id"] = list(self.agency.keys())[0]

for line in range(fareatt.shape[0]):
data = tuple(fareatt[line][list(column_order[fareatttxt].keys())])
headers = ["fare_id", "price", "currency", "payment_method", "transfer", "transfer_duration"]
f = Fare(self.agency.agency_id)
f = Fare(fareatt[line]["agency_id"])
f.populate(data, headers)
if f.fare in self.fare_attributes:
self.__fail(f"Fare ID {f.fare} for {self.agency.agency} is duplicated")
self.__fail(f"Fare ID {f.fare} for {fareatt[line]['agency_id']} is duplicated")
self.fare_attributes[f.fare] = f

farerltxt = "fare_rules.txt"
Expand All @@ -235,23 +245,17 @@ def __load_fare_data(self):
farerl = parse_csv(file, column_order[farerltxt])
self.data_arrays[farerltxt] = farerl

corresp = {}
zone_id = self.agency.agency_id * AGENCY_MULTIPLIER + 1
# corresp = {}
for line in range(farerl.shape[0]):
data = tuple(farerl[line][list(column_order[farerltxt].keys())])
fr = FareRule()
fr.populate(data, ["fare", "route", "origin", "destination", "contains"])
fr.fare_id = self.fare_attributes[fr.fare].fare_id
if fr.route in self.routes:
fr.route_id = self.routes[fr.route].route_id
fr.agency_id = self.agency.agency_id
for x in [fr.origin, fr.destination]:
if x not in corresp:
corresp[x] = zone_id
zone_id += 1
fr.origin_id = corresp[fr.origin]
fr.destination_id = corresp[fr.destination] if fr.destination == "" else fr.destination_id
self.fare_rules.append(fr) if fr.origin == "" else fr.origin_id
fr.origin_id = None if fr.origin == "" else int(fr.origin)
fr.destination_id = None if fr.destination == "" else int(fr.destination)
self.fare_rules.append(fr)

def __load_shapes_table(self):
self.logger.debug("Starting __load_shapes_table")
Expand All @@ -266,7 +270,7 @@ def __load_shapes_table(self):
shapes = parse_csv(file, column_order[shapestxt])

all_shape_ids = np.unique(shapes["shape_id"]).tolist()
msg_txt = f"Load shapes - {self.agency.agency}"
msg_txt = "Load shapes"
self.signal.emit(["start", "secondary", len(all_shape_ids), msg_txt, self.__mt])

self.data_arrays[shapestxt] = shapes
Expand All @@ -291,7 +295,7 @@ def __load_trips_table(self):
trips_array = parse_csv(file, column_order[tripstxt])
self.data_arrays[tripstxt] = trips_array

msg_txt = f"Load trips - {self.agency.agency}"
msg_txt = "Load trips"
self.signal.emit(["start", "secondary", trips_array.shape[0], msg_txt, self.__mt])
if np.unique(trips_array["trip_id"]).shape[0] < trips_array.shape[0]:
self.__fail("There are repeated trip IDs in trips.txt")
Expand Down Expand Up @@ -345,7 +349,6 @@ def __load_trips_table(self):
trip.source_time = list(stop_times.source_time.values)
self.logger.debug(f"{trip.trip} has {len(trip.stops)} stops")
trip._stop_based_shape = LineString([self.stops[x].geo for x in trip.stops])
# trip.shape = self.shapes.get(trip.shape)
trip.seated_capacity = self.routes[trip.route].seated_capacity
trip.total_capacity = self.routes[trip.route].total_capacity
self.trips[trip.route] = self.trips.get(trip.route, {})
Expand Down Expand Up @@ -398,7 +401,7 @@ def __load_stop_times(self):
with self.zip_archive.open(stoptimestxt, "r") as file:
stoptimes = parse_csv(file, column_order[stoptimestxt])
self.data_arrays[stoptimestxt] = stoptimes
msg_txt = f"Load stop times - {self.agency.agency}"
msg_txt = "Load stop times"

df = pd.DataFrame(stoptimes)
for col in ["arrival_time", "departure_time"]:
Expand Down Expand Up @@ -459,12 +462,11 @@ def __load_stops_table(self):
stops[:]["stop_lat"][:] = lats[:]
stops[:]["stop_lon"][:] = lons[:]

msg_txt = f"Load stops - {self.agency.agency}"
msg_txt = "Load stops"
self.signal.emit(["start", "secondary", stops.shape[0], msg_txt, self.__mt])
for i, line in enumerate(stops):
self.signal.emit(["update", "secondary", i + 1, msg_txt, self.__mt])
s = Stop(self.agency.agency_id, line, stops.dtype.names)
s.agency = self.agency.agency
s = Stop(line, stops.dtype.names)
s.srid = self.srid
s.get_node_id()
self.stops[s.stop_id] = s
Expand All @@ -482,7 +484,7 @@ def __load_routes_table(self):
if np.unique(routes["route_id"]).shape[0] < routes.shape[0]:
self.__fail("There are repeated route IDs in routes.txt")

msg_txt = f"Load Routes - {self.agency.agency}"
msg_txt = "Load Routes"
self.signal.emit(["start", "secondary", len(routes), msg_txt, self.__mt])

cap = self.__capacities__.get("other", [None, None, None])
Expand All @@ -491,9 +493,12 @@ def __load_routes_table(self):
for route_type, cap in self.__capacities__.items():
routes.loc[routes.route_type == route_type, ["seated_capacity", "total_capacity"]] = cap

agency_finder = routes["agency_id"].values.tolist()
routes.drop(columns="agency_id", inplace=True)

for i, line in routes.iterrows():
self.signal.emit(["update", "secondary", i + 1, msg_txt, self.__mt])
r = Route(self.agency.agency_id)
r = Route(self.agency_correspondence[agency_finder[i]])
r.populate(line.values, routes.columns)
self.routes[r.route] = r

Expand Down Expand Up @@ -597,6 +602,29 @@ def __load_feed_calendar(self):
if exception_inconsistencies:
self.logger.info(" Minor inconsistencies found between calendar.txt and calendar_dates.txt")

def __load_agencies(self):
self.logger.debug("Starting __load_agencies")
agencytxt = "agency.txt"

self.logger.debug(' Loading "agency" table')
self.agency = {}
with self.zip_archive.open(agencytxt, "r") as file:
agencies = parse_csv(file, column_order[agencytxt])
self.data_arrays[agencytxt] = agencies

msg_txt = "Load Agencies"
self.signal.emit(["start", "secondary", len(agencies), msg_txt, self.__mt])

for i, line in enumerate(agencies):
self.signal.emit(["update", "secondary", i + 1, msg_txt, self.__mt])
a = Agency()
a.agency = line["agency_name"]
a.feed_date = self.feed_date
a.service_date = self.service_date
a.description = self.description
self.agency[a.agency_id] = a
self.agency_correspondence[line["agency_id"]] = a.agency_id

def __fail(self, msg: str) -> None:
self.logger.error(msg)
raise Exception(msg)
Loading
Loading