diff --git a/.github/workflows/api-deployer.yml b/.github/workflows/api-deployer.yml index 2d631bece..e19249ffe 100644 --- a/.github/workflows/api-deployer.yml +++ b/.github/workflows/api-deployer.yml @@ -56,6 +56,10 @@ on: description: Validator endpoint required: true type: string + OPERATIONS_OAUTH2_CLIENT_ID_1PASSWORD: + description: Oauth client id part of the authorization for the operations API + required: true + type: string env: python_version: '3.11' @@ -255,6 +259,11 @@ jobs: name: feeds_gen path: api/src/feeds_gen/ + - uses: actions/download-artifact@v4 + with: + name: feeds_operations_gen + path: functions-python/operations_api/src/feeds_operations_gen/ + - name: Build python functions run: | scripts/function-python-build.sh --all @@ -290,11 +299,12 @@ jobs: env: OP_SERVICE_ACCOUNT_TOKEN: ${{ secrets.OP_SERVICE_ACCOUNT_TOKEN }} TRANSITLAND_API_KEY: "op://rbiv7rvkkrsdlpcrz3bmv7nmcu/TansitLand API Key/credential" + OPERATIONS_OAUTH2_CLIENT_ID: ${{ inputs.OPERATIONS_OAUTH2_CLIENT_ID_1PASSWORD }} - name: Populate Variables run: | scripts/replace-variables.sh -in_file infra/backend.conf.rename_me -out_file infra/backend.conf -variables BUCKET_NAME,OBJECT_PREFIX - scripts/replace-variables.sh -in_file infra/vars.tfvars.rename_me -out_file infra/vars.tfvars -variables PROJECT_ID,REGION,ENVIRONMENT,DEPLOYER_SERVICE_ACCOUNT,FEED_API_IMAGE_VERSION,OAUTH2_CLIENT_ID,OAUTH2_CLIENT_SECRET,GLOBAL_RATE_LIMIT_REQ_PER_MINUTE,ARTIFACT_REPO_NAME,VALIDATOR_ENDPOINT,TRANSITLAND_API_KEY + scripts/replace-variables.sh -in_file infra/vars.tfvars.rename_me -out_file infra/vars.tfvars -variables PROJECT_ID,REGION,ENVIRONMENT,DEPLOYER_SERVICE_ACCOUNT,FEED_API_IMAGE_VERSION,OAUTH2_CLIENT_ID,OAUTH2_CLIENT_SECRET,GLOBAL_RATE_LIMIT_REQ_PER_MINUTE,ARTIFACT_REPO_NAME,VALIDATOR_ENDPOINT,TRANSITLAND_API_KEY,OPERATIONS_OAUTH2_CLIENT_ID - uses: hashicorp/setup-terraform@v3 with: diff --git a/.github/workflows/api-dev.yml b/.github/workflows/api-dev.yml index f3738b9ec..7606f3c80 100644 --- a/.github/workflows/api-dev.yml +++ b/.github/workflows/api-dev.yml @@ -22,6 +22,7 @@ jobs: GLOBAL_RATE_LIMIT_REQ_PER_MINUTE: ${{ vars.GLOBAL_RATE_LIMIT_REQ_PER_MINUTE }} TF_APPLY: true VALIDATOR_ENDPOINT: https://stg-gtfs-validator-web-mbzoxaljzq-ue.a.run.app + OPERATIONS_OAUTH2_CLIENT_ID_1PASSWORD: "op://rbiv7rvkkrsdlpcrz3bmv7nmcu/GCP_RETOOL_OAUTH2_CREDS/username" secrets: GCP_MOBILITY_FEEDS_SA_KEY: ${{ secrets.DEV_GCP_MOBILITY_FEEDS_SA_KEY }} OAUTH2_CLIENT_ID: ${{ secrets.DEV_MOBILITY_FEEDS_OAUTH2_CLIENT_ID}} diff --git a/.github/workflows/api-prod.yml b/.github/workflows/api-prod.yml index 3938583e9..4934898f5 100644 --- a/.github/workflows/api-prod.yml +++ b/.github/workflows/api-prod.yml @@ -18,6 +18,7 @@ jobs: GLOBAL_RATE_LIMIT_REQ_PER_MINUTE: ${{ vars.GLOBAL_RATE_LIMIT_REQ_PER_MINUTE }} TF_APPLY: true VALIDATOR_ENDPOINT: https://gtfs-validator-web-mbzoxaljzq-ue.a.run.app + OPERATIONS_OAUTH2_CLIENT_ID_1PASSWORD: "op://rbiv7rvkkrsdlpcrz3bmv7nmcu/GCP_RETOOL_OAUTH2_CREDS/username" secrets: GCP_MOBILITY_FEEDS_SA_KEY: ${{ secrets.PROD_GCP_MOBILITY_FEEDS_SA_KEY }} OAUTH2_CLIENT_ID: ${{ secrets.PROD_MOBILITY_FEEDS_OAUTH2_CLIENT_ID}} diff --git a/.github/workflows/api-qa.yml b/.github/workflows/api-qa.yml index 2f527f4ec..04a61b3bd 100644 --- a/.github/workflows/api-qa.yml +++ b/.github/workflows/api-qa.yml @@ -18,6 +18,7 @@ jobs: TF_APPLY: true GLOBAL_RATE_LIMIT_REQ_PER_MINUTE: ${{ vars.GLOBAL_RATE_LIMIT_REQ_PER_MINUTE }} VALIDATOR_ENDPOINT: https://stg-gtfs-validator-web-mbzoxaljzq-ue.a.run.app + OPERATIONS_OAUTH2_CLIENT_ID_1PASSWORD: "op://rbiv7rvkkrsdlpcrz3bmv7nmcu/GCP_RETOOL_OAUTH2_CREDS/username" secrets: GCP_MOBILITY_FEEDS_SA_KEY: ${{ secrets.QA_GCP_MOBILITY_FEEDS_SA_KEY }} OAUTH2_CLIENT_ID: ${{ secrets.DEV_MOBILITY_FEEDS_OAUTH2_CLIENT_ID}} diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 6c682153a..96b990940 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -80,6 +80,10 @@ jobs: scripts/setup-openapi-generator.sh scripts/api-gen.sh + - name: Generate Operations API code + run: | + scripts/api-operations-gen.sh + - name: Unit tests - API shell: bash run: | @@ -104,9 +108,16 @@ jobs: path: api/src/database_gen/ overwrite: true - - name: API generated code + - name: Upload API generated code uses: actions/upload-artifact@v4 with: name: feeds_gen path: api/src/feeds_gen/ + overwrite: true + + - name: Upload Operations API generated code + uses: actions/upload-artifact@v4 + with: + name: feeds_operations_gen + path: functions-python/operations_api/src/feeds_operations_gen/ overwrite: true \ No newline at end of file diff --git a/api/src/feeds/impl/feeds_api_impl.py b/api/src/feeds/impl/feeds_api_impl.py index 998090152..f538540cb 100644 --- a/api/src/feeds/impl/feeds_api_impl.py +++ b/api/src/feeds/impl/feeds_api_impl.py @@ -46,6 +46,7 @@ LocationTranslation, get_feeds_location_translations, ) +from utils.logger import Logger T = TypeVar("T", bound="BasicFeed") @@ -59,11 +60,17 @@ class FeedsApiImpl(BaseFeedsApi): APIFeedType = Union[BasicFeed, GtfsFeed, GtfsRTFeed] + def __init__(self) -> None: + self.logger = Logger("FeedsApiImpl").get_logger() + def get_feed( self, id: str, ) -> BasicFeed: """Get the specified feed from the Mobility Database.""" + is_email_restricted = is_user_email_restricted() + self.logger.info(f"User email is restricted: {is_email_restricted}") + feed = ( FeedFilter(stable_id=id, provider__ilike=None, producer_url__ilike=None, status=None) .filter(Database().get_query_model(Feed)) @@ -72,7 +79,7 @@ def get_feed( or_( Feed.operational_status == None, # noqa: E711 Feed.operational_status != "wip", - not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted + not is_email_restricted, # Allow all feeds to be returned if the user is not restricted ) ) .first() @@ -91,6 +98,8 @@ def get_feeds( producer_url: str, ) -> List[BasicFeed]: """Get some (or all) feeds from the Mobility Database.""" + is_email_restricted = is_user_email_restricted() + self.logger.info(f"User email is restricted: {is_email_restricted}") feed_filter = FeedFilter( status=status, provider__ilike=provider, producer_url__ilike=producer_url, stable_id=None ) @@ -100,7 +109,7 @@ def get_feeds( or_( Feed.operational_status == None, # noqa: E711 Feed.operational_status != "wip", - not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted + not is_email_restricted, # Allow all feeds to be returned if the user is not restricted ) ) # Results are sorted by provider @@ -239,6 +248,8 @@ def get_gtfs_feeds( subquery, dataset_latitudes, dataset_longitudes, bounding_filter_method ).subquery() + is_email_restricted = is_user_email_restricted() + self.logger.info(f"User email is restricted: {is_email_restricted}") feed_query = ( Database() .get_session() @@ -248,7 +259,7 @@ def get_gtfs_feeds( or_( Gtfsfeed.operational_status == None, # noqa: E711 Gtfsfeed.operational_status != "wip", - not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted + not is_email_restricted, # Allow all feeds to be returned if the user is not restricted ) ) .options( diff --git a/api/src/feeds/impl/search_api_impl.py b/api/src/feeds/impl/search_api_impl.py index e8906b13d..1ab21693d 100644 --- a/api/src/feeds/impl/search_api_impl.py +++ b/api/src/feeds/impl/search_api_impl.py @@ -42,7 +42,7 @@ def add_search_query_filters(query, search_query, data_type, feed_id, status) -> or_( t_feedsearch.c.operational_status == None, # noqa: E711 t_feedsearch.c.operational_status != "wip", - is_user_email_restricted(), + not is_user_email_restricted(), ) ) if feed_id: diff --git a/api/src/middleware/request_context.py b/api/src/middleware/request_context.py index e019bc633..842120785 100644 --- a/api/src/middleware/request_context.py +++ b/api/src/middleware/request_context.py @@ -94,7 +94,10 @@ def _extract_from_headers(self, headers: dict, scope: Scope) -> None: def __repr__(self) -> str: # Omitting sensitive data like email and jwt assertion safe_properties = dict( - user_id=self.user_id, client_user_agent=self.client_user_agent, client_host=self.client_host + user_id=self.user_id, + client_user_agent=self.client_user_agent, + client_host=self.client_host, + email=self.user_email, ) return f"request-context={safe_properties})" @@ -108,8 +111,8 @@ def is_user_email_restricted() -> bool: Check if an email's domain is restricted (e.g., for WIP visibility). """ request_context = get_request_context() - if not isinstance(request_context, RequestContext): - return True # Default to restricted - email = get_request_context().user_email - unrestricted_domains = ["@mobilitydata.org"] + if not request_context: + return True + email = request_context["user_email"] + unrestricted_domains = ["mobilitydata.org"] return not email or not any(email.endswith(f"@{domain}") for domain in unrestricted_domains) diff --git a/api/tests/unittest/middleware/test_request_context.py b/api/tests/unittest/middleware/test_request_context.py index 23ef7c120..3cb32057d 100644 --- a/api/tests/unittest/middleware/test_request_context.py +++ b/api/tests/unittest/middleware/test_request_context.py @@ -3,7 +3,7 @@ from starlette.datastructures import Headers -from middleware.request_context import RequestContext, get_request_context, _request_context, is_user_email_restricted +from middleware.request_context import RequestContext, get_request_context, _request_context class TestRequestContext(unittest.TestCase): @@ -54,45 +54,3 @@ def test_get_request_context(self): request_context = RequestContext(MagicMock()) _request_context.set(request_context) self.assertEqual(request_context, get_request_context()) - - def test_is_user_email_restricted(self): - self.assertTrue(is_user_email_restricted()) - scope_instance = { - "type": "http", - "asgi": {"version": "3.0"}, - "http_version": "1.1", - "method": "GET", - "headers": [ - (b"host", b"localhost"), - (b"x-forwarded-proto", b"https"), - (b"x-forwarded-for", b"client, proxy1"), - (b"server", b"server"), - (b"user-agent", b"user-agent"), - (b"x-goog-iap-jwt-assertion", b"jwt"), - (b"x-cloud-trace-context", b"TRACE_ID/SPAN_ID;o=1"), - (b"x-goog-authenticated-user-id", b"user_id"), - (b"x-goog-authenticated-user-email", b"email"), - ], - "path": "/", - "raw_path": b"/", - "query_string": b"", - "client": ("127.0.0.1", 32767), - "server": ("127.0.0.1", 80), - } - request_context = RequestContext(scope=scope_instance) - _request_context.set(request_context) - self.assertTrue(is_user_email_restricted()) - scope_instance["headers"] = [ - (b"host", b"localhost"), - (b"x-forwarded-proto", b"https"), - (b"x-forwarded-for", b"client, proxy1"), - (b"server", b"server"), - (b"user-agent", b"user-agent"), - (b"x-goog-iap-jwt-assertion", b"jwt"), - (b"x-cloud-trace-context", b"TRACE_ID/SPAN_ID;o=1"), - (b"x-goog-authenticated-user-id", b"user_id"), - (b"x-goog-authenticated-user-email", b"test@mobilitydata.org"), - ] - request_context = RequestContext(scope=scope_instance) - _request_context.set(request_context) - self.assertTrue(is_user_email_restricted()) diff --git a/docs/OperationsAPI.yaml b/docs/OperationsAPI.yaml new file mode 100644 index 000000000..db67a3ccf --- /dev/null +++ b/docs/OperationsAPI.yaml @@ -0,0 +1,323 @@ +openapi: 3.0.0 +info: + version: 1.0.0 + title: Mobility Database Catalog Operations + description: | + API for the Mobility Database Catalog Operations. See [https://mobilitydatabase.org/](https://mobilitydatabase.org/). + This API was designed for internal use and is not intended to be used by the general public. + The Mobility Database Operation API uses Auth2.0 authentication. + termsOfService: https://mobilitydatabase.org/terms-and-conditions + contact: + name: MobilityData + url: https://mobilitydata.org/ + email: api@mobilitydata.org + license: + name: MobilityData License + url: https://www.apache.org/licenses/LICENSE-2.0 + +tags: + - name: "operations" + description: "Mobility Database Operations" + +paths: + /v1/operations/feeds/gtfs: + put: + description: Update the specified GTFS feed in the Mobility Database. + tags: + - "operations" + operationId: updateGtfsFeed + security: + - ApiKeyAuth: [] + requestBody: + description: Payload to update the specified GTFS feed. + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/UpdateRequestGtfsFeed" + responses: + 200: + description: > + The feed was successfully updated. No content is returned. + 204: + description: > + The feed update request was successfully received, but the update process was skipped as the request matches with the source feed. + 400: + description: > + The request was invalid. + 401: + description: > + The request was not authenticated or has invalid authentication credentials. + 500: + description: > + An internal server error occurred. + + /v1/operations/feeds/gtfs_rt: + put: + description: Update the specified GTFS-RT feed in the Mobility Database. + tags: + - "operations" + operationId: updateGtfsRtFeed + security: + - ApiKeyAuth: [] + requestBody: + description: Payload to update the specified GTFS-RT feed. + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/UpdateRequestGtfsRtFeed" + responses: + 200: + description: > + The feed was successfully updated. No content is returned. + 204: + description: > + The feed update request was successfully received, but the update process was skipped as the request matches with the source feed. + 400: + description: > + The request was invalid. + 401: + description: > + The request was not authenticated or has invalid authentication credentials. + 500: + description: > + An internal server error occurred. + +components: + schemas: + Redirect: + type: object + properties: + target_id: + description: The feed ID that should be used in replacement of the current one. + type: string + example: mdb-10 + comment: + description: A comment explaining the redirect. + type: string + example: Redirected because of a change of URL. + + UpdateRequestGtfsRtFeed: + type: object + properties: + id: + description: Unique identifier used as a key for the feeds table. + type: string + example: mdb-1210 + status: + $ref: "#/components/schemas/FeedStatus" + external_ids: + $ref: "#/components/schemas/ExternalIds" + provider: + description: A commonly used name for the transit provider included in the feed. + type: string + example: Los Angeles Department of Transportation (LADOT, DASH, Commuter Express) + feed_name: + description: > + An optional description of the data feed, e.g to specify if the data feed is an aggregate of + multiple providers, or which network is represented by the feed. + type: string + example: Bus + note: + description: A note to clarify complex use cases for consumers. + type: string + feed_contact_email: + description: Use to contact the feed producer. + type: string + example: someEmail@ladotbus.com + source_info: + $ref: "#/components/schemas/SourceInfo" + redirects: + type: array + items: + $ref: "#/components/schemas/Redirect" + entity_types: + type: array + items: + $ref: "#/components/schemas/EntityType" + feed_references: + description: + A list of the GTFS feeds that the real time source is associated with, represented by their MDB source IDs. + type: array + items: + type: string + example: "mdb-20" + # This is a temporary fix as the operational status is not visible yet. + operational_status_action: + type: string + enum: + - no_change + - wip + - published + required: + - id + - status + - entity_types + + UpdateRequestGtfsFeed: + type: object + properties: + id: + description: Unique identifier used as a key for the feeds table. + type: string + example: mdb-1210 + status: + $ref: "#/components/schemas/FeedStatus" + external_ids: + $ref: "#/components/schemas/ExternalIds" + provider: + description: A commonly used name for the transit provider included in the feed. + type: string + example: Los Angeles Department of Transportation (LADOT, DASH, Commuter Express) + feed_name: + description: > + An optional description of the data feed, e.g to specify if the data feed is an aggregate of + multiple providers, or which network is represented by the feed. + type: string + example: Bus + note: + description: A note to clarify complex use cases for consumers. + type: string + feed_contact_email: + description: Use to contact the feed producer. + type: string + example: someEmail@ladotbus.com + source_info: + $ref: "#/components/schemas/SourceInfo" + redirects: + type: array + items: + $ref: "#/components/schemas/Redirect" + # This is a temporary fix as the operational status is not visible yet. + operational_status_action: + type: string + enum: + - no_change + - wip + - published + required: + - id + - status + + ExternalIds: + type: array + items: + $ref: "#/components/schemas/ExternalId" + + ExternalId: + type: object + properties: + external_id: + description: The ID that can be use to find the feed data in an external or legacy database. + type: string + example: 1210 + source: + description: The source of the external ID, e.g. the name of the database where the external ID can be used. + type: string + example: mdb + + SourceInfo: + type: object + properties: + producer_url: + description: > + URL where the producer is providing the dataset. + Refer to the authentication information to know how to access this URL. + type: string + format: url + example: https://ladotbus.com/gtfs + authentication_type: + $ref: "#/components/schemas/Authentication_type" + authentication_info_url: + description: > + Contains a URL to a human-readable page describing how the authentication should be performed and how credentials can be created. + This field is required for `authentication_type=1` and `authentication_type=2`. + type: string + format: url + example: https://apidevelopers.ladottransit.com + api_key_parameter_name: + type: string + description: > + Defines the name of the parameter to pass in the URL to provide the API key. + This field is required for `authentication_type=1` and `authentication_type=2`. + example: Ocp-Apim-Subscription-Key + license_url: + description: A URL where to find the license for the feed. + type: string + format: url + example: https://www.ladottransit.com/dla.html + + EntityType: + type: string + enum: + - vp + - tu + - sa + example: vp + description: > + The type of realtime entry: + * vp - vehicle positions + * tu - trip updates + * sa - service alerts + + FeedStatus: + description: > + Describes status of the Feed. Should be one of + * `active` Feed should be used in public trip planners. + * `deprecated` Feed is explicitly deprecated and should not be used in public trip planners. + * `inactive` Feed hasn't been recently updated and should be used at risk of providing outdated information. + * `development` Feed is being used for development purposes and should not be used in public trip planners. + type: string + enum: + - active + - deprecated + - inactive + - development + example: active + + DataType: + description: > + Describes data type of a fee. Should be one of + * `gtfs` GTFS feed. + * `gtfs_rt` GTFS-RT feed. + * `gbfs` GBFS feed. + type: string + enum: + - gtfs + - gtfs_rt + - gbfs + example: gtfs + + Authentication_type: + description: > + Defines the type of authentication required to access the `producer_url`. Valid values for this field are: + * 0 or (empty) - No authentication required. + * 1 - The authentication requires an API key, which should be passed as value of the parameter api_key_parameter_name in the URL. Please visit URL in authentication_info_url for more information. + * 2 - The authentication requires an HTTP header, which should be passed as the value of the header api_key_parameter_name in the HTTP request. + When not provided, the authentication type is assumed to be 0. + type: integer + enum: + - 0 + - 1 + - 2 + example: 2 + + parameters: + feed_id_path_param: + name: id + in: path + description: The feed ID of the requested feed. + required: True + schema: + type: string + example: mdb-1210 + + securitySchemes: + ApiKeyAuth: + type: apiKey + name: X-API-KEY + in: header + +security: + - ApiKeyAuth: [] diff --git a/functions-python/.flake8 b/functions-python/.flake8 index b7330efac..ee7633a8d 100644 --- a/functions-python/.flake8 +++ b/functions-python/.flake8 @@ -1,5 +1,5 @@ [flake8] max-line-length = 120 -exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache,venv,build,.*,database_gen +exclude = .git,__pycache__,__init__.py,.mypy_cache,.pytest_cache,venv,build,.*,database_gen,feeds_operations_gen # Ignored because conflict with black extend-ignore = E203 \ No newline at end of file diff --git a/functions-python/.gcloudignore b/functions-python/.gcloudignore new file mode 100644 index 000000000..5a616cba3 --- /dev/null +++ b/functions-python/.gcloudignore @@ -0,0 +1,17 @@ +# This file specifies files that are *not* uploaded to Google Cloud +# using gcloud. It follows the same syntax as .gitignore, with the addition of +# "#!include" directives (which insert the entries of the given .gitignore-style +# file at that point). +# +# For more information, run: +# $ gcloud topic gcloudignore +# +.gcloudignore +# If you would like to upload your .git directory, .gitignore file or files +# from your .gitignore file, remove the corresponding line +# below: +.git +.gitignore + +node_modules +#!include:.gitignore diff --git a/functions-python/helpers/database.py b/functions-python/helpers/database.py index 3904ab4f6..92a31e7db 100644 --- a/functions-python/helpers/database.py +++ b/functions-python/helpers/database.py @@ -18,15 +18,69 @@ import threading from typing import Final -from sqlalchemy import create_engine, text -from sqlalchemy.orm import sessionmaker +from sqlalchemy import create_engine, text, event +from sqlalchemy.orm import sessionmaker, mapper, class_mapper import logging +from database_gen.sqlacodegen_models import Feed, Gtfsfeed, Gtfsrealtimefeed, Gbfsfeed + DB_REUSE_SESSION: Final[str] = "DB_REUSE_SESSION" lock = threading.Lock() global_session = None +def configure_polymorphic_mappers(): + """ + Configure the polymorphic mappers allowing polymorphic values on relationships. + """ + feed_mapper = class_mapper(Feed) + # Configure the polymorphic mapper using date_type as discriminator for the Feed class + feed_mapper.polymorphic_on = Feed.data_type + feed_mapper.polymorphic_identity = Feed.__tablename__.lower() + + gtfsfeed_mapper = class_mapper(Gtfsfeed) + gtfsfeed_mapper.inherits = feed_mapper + gtfsfeed_mapper.polymorphic_identity = Gtfsfeed.__tablename__.lower() + + gtfsrealtimefeed_mapper = class_mapper(Gtfsrealtimefeed) + gtfsrealtimefeed_mapper.inherits = feed_mapper + gtfsrealtimefeed_mapper.polymorphic_identity = ( + Gtfsrealtimefeed.__tablename__.lower() + ) + + gbfsfeed_mapper = class_mapper(Gbfsfeed) + gbfsfeed_mapper.inherits = feed_mapper + gbfsfeed_mapper.polymorphic_identity = Gbfsfeed.__tablename__.lower() + + +def set_cascade(mapper, class_): + """ + Set cascade for relationships in Gtfsfeed. + This allows to delete/add the relationships when their respective relation array changes. + """ + if class_.__name__ == "Gtfsfeed": + for rel in class_.__mapper__.relationships: + if rel.key in [ + "redirectingids", + "redirectingids_", + "externalids", + "externalids_", + ]: + rel.cascade = "all, delete-orphan" + + +def mapper_configure_listener(mapper, class_): + """ + Mapper configure listener + """ + set_cascade(mapper, class_) + configure_polymorphic_mappers() + + +# Add the mapper_configure_listener to the mapper_configured event +event.listen(mapper, "mapper_configured", mapper_configure_listener) + + def get_db_engine(database_url: str = None, echo: bool = True): """ :return: Database engine diff --git a/functions-python/helpers/query_helper.py b/functions-python/helpers/query_helper.py new file mode 100644 index 000000000..75e0fd7e3 --- /dev/null +++ b/functions-python/helpers/query_helper.py @@ -0,0 +1,22 @@ +from typing import Type + +from database_gen.sqlacodegen_models import Feed, Gtfsrealtimefeed, Gtfsfeed, Gbfsfeed + +feed_mapping = {"gtfs_rt": Gtfsrealtimefeed, "gtfs": Gtfsfeed, "gbfs": Gbfsfeed} + + +def get_model(data_type: str | None) -> Type[Feed]: + """ + Get the model based on the data type + """ + return feed_mapping.get(data_type, Feed) + + +def query_feed_by_stable_id( + session, stable_id: str, data_type: str | None +) -> Gtfsrealtimefeed | Gtfsfeed | Gbfsfeed: + """ + Query the feed by stable id + """ + model = get_model(data_type) + return session.query(model).filter(model.stable_id == stable_id).first() diff --git a/functions-python/helpers/tests/test_transform.py b/functions-python/helpers/tests/test_transform.py new file mode 100644 index 000000000..5ce65d4d8 --- /dev/null +++ b/functions-python/helpers/tests/test_transform.py @@ -0,0 +1,21 @@ +from helpers.transform import to_boolean + + +def test_to_boolean(): + assert to_boolean(True) is True + assert to_boolean(False) is False + assert to_boolean("true") is True + assert to_boolean("True") is True + assert to_boolean("1") is True + assert to_boolean("yes") is True + assert to_boolean("y") is True + assert to_boolean("false") is False + assert to_boolean("False") is False + assert to_boolean("0") is False + assert to_boolean("no") is False + assert to_boolean("n") is False + assert to_boolean(1) is False + assert to_boolean(0) is False + assert to_boolean(None) is False + assert to_boolean([]) is False + assert to_boolean({}) is False diff --git a/functions-python/helpers/transform.py b/functions-python/helpers/transform.py new file mode 100644 index 000000000..e7d5264f9 --- /dev/null +++ b/functions-python/helpers/transform.py @@ -0,0 +1,26 @@ +# +# MobilityData 2024 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +def to_boolean(value): + """ + Convert a value to a boolean. + """ + if isinstance(value, bool): + return value + if isinstance(value, str): + return value.lower() in ["true", "1", "yes", "y"] + return False diff --git a/functions-python/operations_api/.coveragerc b/functions-python/operations_api/.coveragerc new file mode 100644 index 000000000..b664793c1 --- /dev/null +++ b/functions-python/operations_api/.coveragerc @@ -0,0 +1,11 @@ +[run] +omit = + */test*/* + */helpers/* + */database_gen/* + */dataset_service/* + */feeds_operations_gen/* + +[report] +exclude_lines = + if __name__ == .__main__.: \ No newline at end of file diff --git a/functions-python/operations_api/.gitignore b/functions-python/operations_api/.gitignore new file mode 100644 index 000000000..e139d59ad --- /dev/null +++ b/functions-python/operations_api/.gitignore @@ -0,0 +1,2 @@ +# Generated files +src/feeds_operations_gen \ No newline at end of file diff --git a/functions-python/operations_api/.openapi-generator-ignore b/functions-python/operations_api/.openapi-generator-ignore new file mode 100644 index 000000000..664a6f5b5 --- /dev/null +++ b/functions-python/operations_api/.openapi-generator-ignore @@ -0,0 +1,36 @@ +# OpenAPI Generator Ignore +# Generated by openapi-generator https://github.com/openapitools/openapi-generator + +# Use this file to prevent files from being overwritten by the generator. +# The patterns follow closely to .gitignore or .dockerignore. + +# As an example, the C# client generator defines ApiClient.cs. +# You can make changes and tell OpenAPI Generator to ignore just this file by uncommenting the following line: +#ApiClient.cs + +# You can match any string of characters against a directory, file or extension with a single asterisk (*): +#foo/*/qux +# The above matches foo/bar/qux and foo/baz/qux, but not foo/bar/baz/qux + +# You can recursively match patterns against a directory, file or extension with a double asterisk (**): +#foo/**/qux +# This matches foo/bar/qux, foo/baz/qux, and foo/bar/baz/qux + +# You can also negate patterns with an exclamation (!). +# For example, you can ignore all files in a docs folder with the file extension .md: +#docs/*.md +# Then explicitly reverse the ignore rule for a single file: +#!docs/README.md + +.gitignore +openapi.yaml +README.md +Dockerfile +docker-compose.yaml +myproject.toml +pyproject.toml +setup.cfg +requirements_dev.txt +requirements.txt +.flake8 +tests/conftest.py \ No newline at end of file diff --git a/functions-python/operations_api/.openapi-generator/FILES b/functions-python/operations_api/.openapi-generator/FILES new file mode 100644 index 000000000..1304a5b92 --- /dev/null +++ b/functions-python/operations_api/.openapi-generator/FILES @@ -0,0 +1,17 @@ +src/feeds_operations/impl/__init__.py +src/feeds_operations_gen/apis/__init__.py +src/feeds_operations_gen/apis/operations_api.py +src/feeds_operations_gen/apis/operations_api_base.py +src/feeds_operations_gen/main.py +src/feeds_operations_gen/models/__init__.py +src/feeds_operations_gen/models/authentication_type.py +src/feeds_operations_gen/models/data_type.py +src/feeds_operations_gen/models/entity_type.py +src/feeds_operations_gen/models/external_id.py +src/feeds_operations_gen/models/extra_models.py +src/feeds_operations_gen/models/feed_status.py +src/feeds_operations_gen/models/redirect.py +src/feeds_operations_gen/models/source_info.py +src/feeds_operations_gen/models/update_request_gtfs_feed.py +src/feeds_operations_gen/models/update_request_gtfs_rt_feed.py +src/feeds_operations_gen/security_api.py diff --git a/functions-python/operations_api/.openapi-generator/VERSION b/functions-python/operations_api/.openapi-generator/VERSION new file mode 100644 index 000000000..758bb9c82 --- /dev/null +++ b/functions-python/operations_api/.openapi-generator/VERSION @@ -0,0 +1 @@ +7.10.0 diff --git a/functions-python/operations_api/README.md b/functions-python/operations_api/README.md new file mode 100644 index 000000000..382a865ef --- /dev/null +++ b/functions-python/operations_api/README.md @@ -0,0 +1,25 @@ +# Operations API +The Operations API is a function that exposes the operations API. +The operations API schema is located at ../../docs/OperationsAPI.yml. + +# Function configuration +The function is configured using the following environment variables: +- `FEEDS_DATABASE_URL`: The URL of the feeds database. +- `GOOGLE_CLIENT_ID`: The Google client ID used for authentication. + +# Useful scripts +- To locally execute a function use the following command: +``` +./scripts/function-python-run.sh --function_name operations_api +``` +- To locally create a distribution zip use the following command: +``` +./scripts/function-python-build.sh --function_name operations_api +``` +- Start local and test database +``` +docker compose --env-file ./config/.env.local up -d liquibase-test + + +# Local development +The local development of this function follows the same steps as the other functions. Please refer to the [README.md](../README.md) file for more information. \ No newline at end of file diff --git a/functions-python/operations_api/function_config.json b/functions-python/operations_api/function_config.json new file mode 100644 index 000000000..1af6512ed --- /dev/null +++ b/functions-python/operations_api/function_config.json @@ -0,0 +1,27 @@ +{ + "name": "operations-api", + "description": "API containing the back-office operations", + "entry_point": "main", + "timeout": 540, + "memory": "1Gi", + "trigger_http": true, + "include_folders": ["database_gen", "helpers"], + "environment_variables": [ + { + "key": "GOOGLE_CLIENT_ID" + } + ], + "secret_environment_variables": [ + { + "key": "FEEDS_DATABASE_URL" + } + ], + "ingress_settings": "ALLOW_ALL", + "max_instance_request_concurrency": 1, + "max_instance_count": 5, + "min_instance_count": 0, + "available_cpu": 1, + "build_settings": { + "pre_build_script": "../../scripts/api-operations-gen.sh" + } +} diff --git a/functions-python/operations_api/requirements.txt b/functions-python/operations_api/requirements.txt new file mode 100644 index 000000000..0c8815452 --- /dev/null +++ b/functions-python/operations_api/requirements.txt @@ -0,0 +1,32 @@ +aiohttp~=3.10.5 +asgiref~=3.8.1 +asyncio~=3.4.3 +attrs~=23.1.0 +certifi==2024.7.4 +email-validator==2.0.0 +fastapi==0.115.2 +httpx +mangum +pluggy~=1.5.0 +promise==2.3 +pydantic>=2 +python-dotenv==0.17.1 +python-multipart==0.0.7 +PyYAML>=5.4.1,<6.1.0 +requests==2.32.3 +Rx==1.6.1 +starlette==0.40.0 +typing-extensions==4.10.0 +ujson==4.0.2 +urllib3~=2.2.2 +uvicorn +uvloop==0.19.0 + +# Additional packages +google-cloud-logging==3.10.0 +functions-framework==3.* +SQLAlchemy==2.0.23 +geoalchemy2==0.14.7 +psycopg2-binary==2.9.6 +cachetools +deepdiff \ No newline at end of file diff --git a/functions-python/operations_api/requirements_dev.txt b/functions-python/operations_api/requirements_dev.txt new file mode 100644 index 000000000..47b61b12f --- /dev/null +++ b/functions-python/operations_api/requirements_dev.txt @@ -0,0 +1,5 @@ +pytest +pytest-asyncio +urllib3-mock +requests-mock +python-dotenv~=1.0.0 \ No newline at end of file diff --git a/functions-python/operations_api/src/__init__.py b/functions-python/operations_api/src/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/functions-python/operations_api/src/feeds_operations/__init__.py b/functions-python/operations_api/src/feeds_operations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/functions-python/operations_api/src/feeds_operations/impl/__init__.py b/functions-python/operations_api/src/feeds_operations/impl/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/functions-python/operations_api/src/feeds_operations/impl/feeds_operations_impl.py b/functions-python/operations_api/src/feeds_operations/impl/feeds_operations_impl.py new file mode 100644 index 000000000..0cf88530b --- /dev/null +++ b/functions-python/operations_api/src/feeds_operations/impl/feeds_operations_impl.py @@ -0,0 +1,191 @@ +# +# MobilityData 2024 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +import os +from typing import Annotated + +from deepdiff import DeepDiff +from fastapi import HTTPException +from pydantic import Field +from starlette.responses import Response + +from database_gen.sqlacodegen_models import Gtfsfeed, t_feedsearch +from feeds_operations.impl.models.update_request_gtfs_feed_impl import ( + UpdateRequestGtfsFeedImpl, +) +from feeds_operations_gen.apis.operations_api_base import BaseOperationsApi +from feeds_operations_gen.models.data_type import DataType +from feeds_operations_gen.models.update_request_gtfs_feed import UpdateRequestGtfsFeed +from feeds_operations_gen.models.update_request_gtfs_rt_feed import ( + UpdateRequestGtfsRtFeed, +) +from helpers.database import start_db_session, refresh_materialized_view +from helpers.query_helper import query_feed_by_stable_id +from .models.update_request_gtfs_rt_feed_impl import UpdateRequestGtfsRtFeedImpl +from .request_validator import validate_request + +logging.basicConfig(level=logging.INFO) + + +class OperationsApiImpl(BaseOperationsApi): + """ + Implementation of the operations API + """ + + @staticmethod + def detect_changes( + feed: Gtfsfeed, + update_request_feed: UpdateRequestGtfsFeed | UpdateRequestGtfsRtFeed, + impl_class: UpdateRequestGtfsFeedImpl | UpdateRequestGtfsRtFeedImpl, + ) -> DeepDiff: + """ + Detect changes between the feed and the update request. + """ + # Normalize the feed and the update request and compare them + copy_feed = impl_class.from_orm(feed) + # Temporary solution to update the operational status + copy_feed.operational_status_action = ( + update_request_feed.operational_status_action + ) + diff = DeepDiff( + copy_feed.model_dump(), + update_request_feed.model_dump(), + ignore_order=True, + ) + if diff.affected_paths: + logging.info( + f"Detect update changes: affected paths: {diff.affected_paths}" + ) + else: + logging.info("Detect update changes: no changes detected") + return diff + + @validate_request(UpdateRequestGtfsFeed, "update_request_gtfs_feed") + async def update_gtfs_feed( + self, + update_request_gtfs_feed: Annotated[ + UpdateRequestGtfsFeed, + Field(description="Payload to update the specified feed."), + ], + ) -> Response: + """Update the specified feed in the Mobility Database. + returns: + - 200: Feed updated successfully. + - 204: No changes detected. + - 400: Feed ID not found. + - 500: Internal server error. + """ + ... + return await self._update_feed(update_request_gtfs_feed, DataType.GTFS) + + @validate_request(UpdateRequestGtfsRtFeed, "update_request_gtfs_rt_feed") + async def update_gtfs_rt_feed( + self, + update_request_gtfs_rt_feed: Annotated[ + UpdateRequestGtfsRtFeed, + Field(description="Payload to update the specified GTFS-RT feed."), + ], + ) -> Response: + """Update the specified GTFS-RT feed in the Mobility Database. + returns: + - 200: Feed updated successfully. + - 204: No changes detected. + - 400: Feed ID not found. + - 500: Internal server error. + """ + return await self._update_feed(update_request_gtfs_rt_feed, DataType.GTFS_RT) + + async def _update_feed( + self, + update_request_feed: UpdateRequestGtfsFeed | UpdateRequestGtfsRtFeed, + data_type: DataType, + ) -> Response: + """ + Update the specified feed in the Mobility Database + """ + session = None + try: + session = start_db_session(os.getenv("FEEDS_DATABASE_URL")) + feed = await OperationsApiImpl.fetch_feed( + data_type, session, update_request_feed + ) + + logging.info( + f"Feed ID: {update_request_feed.id} attempting to update with the following request: " + f"{update_request_feed}" + ) + impl_class = ( + UpdateRequestGtfsFeedImpl + if data_type == DataType.GTFS + else UpdateRequestGtfsRtFeedImpl + ) + diff = self.detect_changes(feed, update_request_feed, impl_class) + if len(diff.affected_paths) > 0 or ( + update_request_feed.operational_status_action is not None + and update_request_feed.operational_status_action != "no_change" + ): + await OperationsApiImpl._populate_feed_values( + feed, impl_class, session, update_request_feed + ) + session.flush() + refreshed = refresh_materialized_view(session, t_feedsearch.name) + logging.info( + f"Materialized view {t_feedsearch.name} refreshed: {refreshed}" + ) + session.commit() + logging.info( + f"Feed ID: {update_request_feed.id} updated successfully with the following changes: " + f"{diff.values()}" + ) + return Response(status_code=200) + else: + logging.info( + f"No changes detected for feed ID: {update_request_feed.id}" + ) + return Response(status_code=204) + except Exception as e: + logging.error( + f"Failed to update feed ID: {update_request_feed.id}. Error: {e}" + ) + session.rollback() + if isinstance(e, HTTPException): + raise e + raise HTTPException(status_code=500, detail=f"Internal server error: {e}") + finally: + if session: + session.close() + + @staticmethod + async def _populate_feed_values(feed, impl_class, session, update_request_feed): + impl_class.to_orm(update_request_feed, feed, session) + action = update_request_feed.operational_status_action + # This is a temporary solution as the operational_status is not visible in the diff + if action is not None and not action.lower() == "no_change": + feed.operational_status = "wip" if action.lower() == "wip" else None + session.add(feed) + + @staticmethod + async def fetch_feed(data_type, session, update_request_feed): + feed: Gtfsfeed = query_feed_by_stable_id( + session, update_request_feed.id, data_type.value + ) + if feed is None: + raise HTTPException( + status_code=400, + detail=f"Feed ID not found: {update_request_feed.id}", + ) + return feed diff --git a/functions-python/operations_api/src/feeds_operations/impl/models/__init__.py b/functions-python/operations_api/src/feeds_operations/impl/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/functions-python/operations_api/src/feeds_operations/impl/models/entity_type_impl.py b/functions-python/operations_api/src/feeds_operations/impl/models/entity_type_impl.py new file mode 100644 index 000000000..8995f0ab4 --- /dev/null +++ b/functions-python/operations_api/src/feeds_operations/impl/models/entity_type_impl.py @@ -0,0 +1,42 @@ +from pydantic import BaseModel + +from feeds_operations_gen.models.entity_type import EntityType +from database_gen.sqlacodegen_models import Entitytype as EntityTypeOrm + + +class EntityTypeImpl(BaseModel): + """Implementation of the EntityType model. + This class converts a SQLAlchemy row DB object with the gtfs feed fields to a Pydantic model. + """ + + class Config: + """Pydantic configuration. + Enabling `from_attributes` method to create a model instance from a SQLAlchemy row object. + """ + + from_attributes = True + + @classmethod + def from_orm(cls, obj: EntityTypeOrm | None) -> EntityType | None: + """ + Convert a SQLAlchemy row object to a Pydantic model. + """ + if obj is None: + return None + return EntityType(obj.name.lower()) + + @classmethod + def to_orm(cls, entity_type: EntityType, session) -> EntityTypeOrm: + """ + Convert a Pydantic model to a SQLAlchemy row object. + """ + result = ( + session.query(EntityTypeOrm) + .filter(EntityTypeOrm.name.ilike(entity_type.name)) + .first() + ) + return ( + result + if result is not None + else EntityTypeOrm(name=entity_type.name.lower()) + ) diff --git a/functions-python/operations_api/src/feeds_operations/impl/models/external_id_impl.py b/functions-python/operations_api/src/feeds_operations/impl/models/external_id_impl.py new file mode 100644 index 000000000..c67eb52b3 --- /dev/null +++ b/functions-python/operations_api/src/feeds_operations/impl/models/external_id_impl.py @@ -0,0 +1,61 @@ +# +# MobilityData 2024 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from database_gen.sqlacodegen_models import ( + Externalid, + Gtfsfeed, + Gtfsrealtimefeed, + Gbfsfeed, +) +from feeds_operations_gen.models.external_id import ExternalId + + +class ExternalIdImpl(ExternalId): + """Implementation of the `ExternalId` model. + This class converts a SQLAlchemy row DB object to a Pydantic model. + """ + + class Config: + """Pydantic configuration. + Enabling `from_attributes` method to create a model instance from a SQLAlchemy row object. + """ + + from_attributes = True + + @classmethod + def from_orm(cls, external_id: Externalid | None) -> ExternalId | None: + """ + Convert a SQLAlchemy row object to a Pydantic model + """ + if not external_id: + return None + return cls( + external_id=external_id.associated_id, + source=external_id.source, + ) + + @classmethod + def to_orm( + cls, external_id: ExternalId, feed: Gtfsfeed | Gtfsrealtimefeed | Gbfsfeed + ) -> Externalid: + """ + Convert a Pydantic model to a SQLAlchemy row object + """ + return Externalid( + feed_id=feed.id, + associated_id=external_id.external_id, + source=external_id.source, + ) diff --git a/functions-python/operations_api/src/feeds_operations/impl/models/redirect_impl.py b/functions-python/operations_api/src/feeds_operations/impl/models/redirect_impl.py new file mode 100644 index 000000000..839acd6f4 --- /dev/null +++ b/functions-python/operations_api/src/feeds_operations/impl/models/redirect_impl.py @@ -0,0 +1,71 @@ +# +# MobilityData 2024 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from database_gen.sqlacodegen_models import ( + Redirectingid, + Gtfsfeed, + Gbfsfeed, + Gtfsrealtimefeed, +) +from feeds_operations_gen.models.redirect import Redirect +from helpers.query_helper import query_feed_by_stable_id + + +class RedirectImpl(Redirect): + """Implementation of the `Redirect` model. + This class converts a SQLAlchemy row DB object to a Pydantic model. + """ + + class Config: + """Pydantic configuration. + Enabling `from_attributes` method to create a model instance from a SQLAlchemy row object. + """ + + from_attributes = True + + @classmethod + def from_orm(cls, redirect: Redirectingid | None) -> Redirect | None: + """ + Convert a SQLAlchemy row object to a Pydantic model. + """ + if not redirect: + return None + return cls( + target_id=redirect.target.stable_id, + comment=redirect.redirect_comment, + ) + + @classmethod + def to_orm( + cls, redirect: Redirect, source: Gtfsfeed | Gtfsrealtimefeed | Gbfsfeed, session + ) -> Redirectingid: + """ + Convert a Pydantic model to a SQLAlchemy row object. + """ + if not source or not source.id: + raise ValueError("Invalid source object or source.id is not set") + target_feed = query_feed_by_stable_id( + session, redirect.target_id, source.data_type + ) + + if not target_feed or not target_feed.id: + raise ValueError("Invalid target_feed object or target_feed.id is not set") + + return Redirectingid( + source_id=source.id, + target_id=target_feed.id, + redirect_comment=redirect.comment, + ) diff --git a/functions-python/operations_api/src/feeds_operations/impl/models/update_request_gtfs_feed_impl.py b/functions-python/operations_api/src/feeds_operations/impl/models/update_request_gtfs_feed_impl.py new file mode 100644 index 000000000..237b568be --- /dev/null +++ b/functions-python/operations_api/src/feeds_operations/impl/models/update_request_gtfs_feed_impl.py @@ -0,0 +1,141 @@ +# +# MobilityData 2024 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from database_gen.sqlacodegen_models import Gtfsfeed +from feeds_operations.impl.models.external_id_impl import ExternalIdImpl +from feeds_operations.impl.models.redirect_impl import RedirectImpl +from feeds_operations_gen.models.source_info import SourceInfo +from feeds_operations_gen.models.update_request_gtfs_feed import UpdateRequestGtfsFeed + + +class UpdateRequestGtfsFeedImpl(UpdateRequestGtfsFeed): + """Implementation of the UpdateRequestGtfsFeed model. + This class converts a SQLAlchemy row DB object with the gtfs feed fields to a Pydantic model. + """ + + class Config: + """Pydantic configuration. + Enabling `from_attributes` method to create a model instance from a SQLAlchemy row object. + """ + + from_attributes = True + + @classmethod + def from_orm(cls, obj: Gtfsfeed | None) -> UpdateRequestGtfsFeed | None: + """ + Convert a SQLAlchemy row object to a Pydantic model. + """ + if obj is None: + return None + return cls( + id=obj.stable_id, + status=obj.status, + provider=obj.provider, + feed_name=obj.feed_name, + note=obj.note, + feed_contact_email=obj.feed_contact_email, + source_info=SourceInfo( + producer_url=obj.producer_url, + authentication_type=None + if obj.authentication_type is None + else int(obj.authentication_type), + authentication_info_url=obj.authentication_info_url, + api_key_parameter_name=obj.api_key_parameter_name, + license_url=obj.license_url, + ), + redirects=sorted( + [RedirectImpl.from_orm(item) for item in obj.redirectingids], + key=lambda x: x.target_id, + ), + external_ids=sorted( + [ExternalIdImpl.from_orm(item) for item in obj.externalids], + key=lambda x: x.external_id, + ), + ) + + @classmethod + def to_orm( + cls, update_request: UpdateRequestGtfsFeed, entity: Gtfsfeed, session + ) -> Gtfsfeed: + """ + Convert a Pydantic model to a SQLAlchemy row object. + """ + entity.status = update_request.status + entity.provider = update_request.provider + entity.feed_name = update_request.feed_name + entity.note = update_request.note + entity.feed_contact_email = update_request.feed_contact_email + entity.producer_url = ( + None + if ( + update_request.source_info is None + or update_request.source_info.producer_url is None + ) + else update_request.source_info.producer_url + ) + entity.authentication_type = ( + None + if ( + update_request.source_info is None + or update_request.source_info.authentication_type is None + ) + else str(update_request.source_info.authentication_type.value) + ) + entity.authentication_info_url = ( + None + if ( + update_request.source_info is None + or update_request.source_info.authentication_info_url is None + ) + else update_request.source_info.authentication_info_url + ) + entity.api_key_parameter_name = ( + None + if ( + update_request.source_info is None + or update_request.source_info.api_key_parameter_name is None + ) + else update_request.source_info.api_key_parameter_name + ) + entity.license_url = ( + None + if ( + update_request.source_info is None + or update_request.source_info.license_url is None + ) + else update_request.source_info.license_url + ) + + redirecting_ids = ( + [] + if update_request.redirects is None + else [ + RedirectImpl.to_orm(item, entity, session) + for item in update_request.redirects + ] + ) + entity.redirectingids.clear() + entity.redirectingids.extend(redirecting_ids) + + entity.externalids = ( + [] + if update_request.external_ids is None + else [ + ExternalIdImpl.to_orm(item, entity) + for item in update_request.external_ids + ] + ) + return entity diff --git a/functions-python/operations_api/src/feeds_operations/impl/models/update_request_gtfs_rt_feed_impl.py b/functions-python/operations_api/src/feeds_operations/impl/models/update_request_gtfs_rt_feed_impl.py new file mode 100644 index 000000000..173dc3d5b --- /dev/null +++ b/functions-python/operations_api/src/feeds_operations/impl/models/update_request_gtfs_rt_feed_impl.py @@ -0,0 +1,164 @@ +# +# MobilityData 2024 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from database_gen.sqlacodegen_models import Gtfsfeed, Gtfsrealtimefeed +from feeds_operations.impl.models.entity_type_impl import EntityTypeImpl +from feeds_operations.impl.models.external_id_impl import ExternalIdImpl +from feeds_operations.impl.models.redirect_impl import RedirectImpl +from feeds_operations_gen.models.source_info import SourceInfo +from feeds_operations_gen.models.update_request_gtfs_rt_feed import ( + UpdateRequestGtfsRtFeed, +) + + +class UpdateRequestGtfsRtFeedImpl(UpdateRequestGtfsRtFeed): + """Implementation of the UpdateRequestGtfsRtFeed model. + This class converts a SQLAlchemy row DB object with the gtfs feed fields to a Pydantic model. + """ + + class Config: + """Pydantic configuration. + Enabling `from_attributes` method to create a model instance from a SQLAlchemy row object. + """ + + from_attributes = True + + @classmethod + def from_orm(cls, obj: Gtfsrealtimefeed | None) -> UpdateRequestGtfsRtFeed | None: + """ + Convert a SQLAlchemy row object to a Pydantic model. + """ + if obj is None: + return None + return cls( + id=obj.stable_id, + status=obj.status, + provider=obj.provider, + feed_name=obj.feed_name, + note=obj.note, + feed_contact_email=obj.feed_contact_email, + source_info=SourceInfo( + producer_url=obj.producer_url, + authentication_type=None + if obj.authentication_type is None + else int(obj.authentication_type), + authentication_info_url=obj.authentication_info_url, + api_key_parameter_name=obj.api_key_parameter_name, + license_url=obj.license_url, + ), + redirects=sorted( + [RedirectImpl.from_orm(item) for item in obj.redirectingids], + key=lambda x: x.target_id, + ), + external_ids=sorted( + [ExternalIdImpl.from_orm(item) for item in obj.externalids], + key=lambda x: x.external_id, + ), + entity_types=sorted( + [EntityTypeImpl.from_orm(item) for item in obj.entitytypes] + ), + feed_references=sorted([item.stable_id for item in obj.gtfs_feeds]), + ) + + @classmethod + def to_orm( + cls, update_request: UpdateRequestGtfsRtFeed, entity: Gtfsrealtimefeed, session + ) -> Gtfsrealtimefeed: + """ + Convert a Pydantic model to a SQLAlchemy row object. + """ + entity.status = update_request.status + entity.provider = update_request.provider + entity.feed_name = update_request.feed_name + entity.note = update_request.note + entity.feed_contact_email = update_request.feed_contact_email + entity.producer_url = ( + None + if ( + update_request.source_info is None + or update_request.source_info.producer_url is None + ) + else update_request.source_info.producer_url + ) + entity.authentication_type = ( + None + if ( + update_request.source_info is None + or update_request.source_info.authentication_type is None + ) + else str(update_request.source_info.authentication_type.value) + ) + entity.authentication_info_url = ( + None + if ( + update_request.source_info is None + or update_request.source_info.authentication_info_url is None + ) + else update_request.source_info.authentication_info_url + ) + entity.api_key_parameter_name = ( + None + if ( + update_request.source_info is None + or update_request.source_info.api_key_parameter_name is None + ) + else update_request.source_info.api_key_parameter_name + ) + entity.license_url = ( + None + if ( + update_request.source_info is None + or update_request.source_info.license_url is None + ) + else update_request.source_info.license_url + ) + + redirecting_ids = ( + [] + if update_request.redirects is None + else [ + RedirectImpl.to_orm(item, entity, session) + for item in update_request.redirects + ] + ) + entity.redirectingids.clear() + entity.redirectingids.extend(redirecting_ids) + + entity.externalids = ( + [] + if update_request.external_ids is None + else [ + ExternalIdImpl.to_orm(item, entity) + for item in update_request.external_ids + ] + ) + entity.entitytypes = ( + [] + if update_request.entity_types is None + else [ + EntityTypeImpl.to_orm(item, session) + for item in update_request.entity_types + ] + ) + entity.gtfs_feeds = ( + [] + if update_request.feed_references is None + else [ + session.query(Gtfsfeed).filter(Gtfsfeed.stable_id == item).one() + for item in update_request.feed_references + ] + ) + return entity diff --git a/functions-python/operations_api/src/feeds_operations/impl/request_validator.py b/functions-python/operations_api/src/feeds_operations/impl/request_validator.py new file mode 100644 index 000000000..95dd10a11 --- /dev/null +++ b/functions-python/operations_api/src/feeds_operations/impl/request_validator.py @@ -0,0 +1,49 @@ +# +# MobilityData 2024 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import inspect +from functools import wraps +from pydantic import BaseModel, ValidationError +from fastapi import HTTPException + + +def validate_request(model: BaseModel, parameter_name: str, validate_none: bool = True): + """ + Decorator to validate request parameters using Pydantic models. + raises: + HTTPException: 400, If the parameter is missing or invalid. + """ + + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + func_args = inspect.getfullargspec(func).args + print(func_args) + index = func_args.index(parameter_name) + value = args[index] + if value: + try: + model.model_validate(value) + except ValidationError as e: + raise HTTPException(status_code=400, detail=str(e)) + else: + if validate_none: + raise HTTPException(status_code=400, detail="Missing parameter") + return await func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/functions-python/operations_api/src/main.py b/functions-python/operations_api/src/main.py new file mode 100644 index 000000000..7f44c00ee --- /dev/null +++ b/functions-python/operations_api/src/main.py @@ -0,0 +1,108 @@ +# +# MobilityData 2024 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from flask import Request, Response +from fastapi import FastAPI +from feeds_operations_gen.apis.operations_api import router as FeedsApiRouter +import functions_framework +import asyncio + +from middleware.request_context_middleware import RequestContextMiddleware +from helpers.logger import Logger + +Logger.init_logger() + +app = FastAPI( + title="Mobility Database Catalog Operations", + description="API for the Mobility Database Catalog Operations.", + version="1.0.0", +) + +# Add here middlewares that should be applied to all routes. +app.add_middleware(RequestContextMiddleware) +app.include_router(FeedsApiRouter) + + +def build_scope_from_wsgi(request: Request) -> dict: + """ + Build the ASGI scope dynamically from a Flask (WSGI) request. + """ + environ = request.environ + + connection_type = "http" + if environ.get("HTTP_UPGRADE", "").lower() == "websocket": + connection_type = "websocket" + + client = (environ.get("REMOTE_ADDR", ""), int(environ.get("REMOTE_PORT", 0))) + server = (environ.get("SERVER_NAME", ""), int(environ.get("SERVER_PORT", 0))) + + headers = [ + (key.lower().encode("latin-1"), value.encode("latin-1")) + for key, value in request.headers.items() + ] + + return { + "type": connection_type, + "http_version": environ.get("SERVER_PROTOCOL", "HTTP/1.1").split("/")[1], + "method": request.method, + "headers": headers, + "path": environ.get("PATH_INFO", "/"), + "raw_path": environ.get("RAW_URI", "").encode("latin-1"), + "query_string": environ.get("QUERY_STRING", "").encode("latin-1"), + "server": server, + "client": client, + "scheme": environ.get("wsgi.url_scheme", "http"), + } + + +@functions_framework.http +def main(request: Request): + """ + Entry point for Google Cloud Function. + """ + scope = build_scope_from_wsgi(request) + + async def receive(): + body = request.get_data() + return {"type": "http.request", "body": body, "more_body": False} + + send_response = {} + + async def send(message): + if message["type"] == "http.response.start": + send_response["status"] = message["status"] + send_response["headers"] = { + key.decode("latin-1"): value.decode("latin-1") + for key, value in message["headers"] + } + elif message["type"] == "http.response.body": + send_response["body"] = message.get("body", b"") + + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(app(scope, receive, send)) + + return Response( + response=send_response.get("body", b""), + status=send_response.get("status", 200), + headers=send_response.get("headers", {}), + ) + except Exception as e: + return Response( + response=str(e), + status=500, + ) diff --git a/functions-python/operations_api/src/middleware/request_context_middleware.py b/functions-python/operations_api/src/middleware/request_context_middleware.py new file mode 100644 index 000000000..dc4d676e2 --- /dev/null +++ b/functions-python/operations_api/src/middleware/request_context_middleware.py @@ -0,0 +1,48 @@ +# +# MobilityData 2024 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +from starlette.types import ASGIApp, Receive, Scope, Send + +from middleware.request_context_oauth2 import ( + RequestContext, + _request_context, +) + + +class RequestContextMiddleware: + """ + Middleware to set the request context and authorize requests. + """ + + def __init__(self, app: ASGIApp) -> None: + self.logger = logging.getLogger() + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ + Middleware to set the request context and authorize requests. + """ + if scope["type"] == "http": + request_context = RequestContext(scope=scope) + _request_context.set(request_context.__dict__) + + async def http_send(message): + await send(message) + + await self.app(scope, receive, http_send) + else: + await self.app(scope, receive, send) diff --git a/functions-python/operations_api/src/middleware/request_context_oauth2.py b/functions-python/operations_api/src/middleware/request_context_oauth2.py new file mode 100644 index 000000000..7f7d2100d --- /dev/null +++ b/functions-python/operations_api/src/middleware/request_context_oauth2.py @@ -0,0 +1,225 @@ +# +# MobilityData 2024 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +import os +from contextvars import ContextVar +from time import time + +import requests +from cachetools import TTLCache +from fastapi import HTTPException +from starlette.datastructures import Headers +from starlette.types import Scope + +from helpers.transform import to_boolean + +REQUEST_CTX_KEY = "request_context_key" +_request_context: ContextVar[dict] = ContextVar(REQUEST_CTX_KEY, default={}) +cache = TTLCache(maxsize=1000, ttl=3600) + + +def validate_token_with_google(token: str, google_client_id: str) -> dict: + """ + Validate the token with Google's tokeninfo endpoint and return the token info. + returns: + dict: Token info + raises: + HTTPException: 401, If the token is invalid or the audience is not the expected client. + HTTPException: 500, If the token validation fails. + """ + try: + response = get_tokeninfo_response(token) + except Exception as e: + logging.error(f"Token validation failed: {e}") + raise HTTPException(status_code=500, detail="Token validation failed") + + if response.status_code != 200: + raise HTTPException(status_code=401, detail="Invalid access token") + + token_info = response.json() + # Ensure the token is for the expected client + if token_info.get("audience") != google_client_id: + raise HTTPException(status_code=401, detail="Invalid token audience") + + return token_info + + +def get_tokeninfo_response(token): + """ + Get the token info response from Google's tokeninfo endpoint. + """ + response = requests.get( + f"https://www.googleapis.com/oauth2/v1/tokeninfo?access_token={token}" + ) + return response + + +def get_token_info(token: str, google_client_id: str) -> dict: + """ + Resolve the token info, using cache when possible. If expired, clear from cache. + returns: + dict: Token info + """ + current_time = time() + if token in cache: + logging.info("Token found in cache") + token_info, expiry_time = cache[token] + + # Check if the token has expired + if current_time >= expiry_time: + logging.info("Cached token has expired, removing from cache") + del cache[token] # Remove expired token + else: + return token_info + + token_info = validate_token_with_google(token, google_client_id) + expires_in = int( + token_info.get("expires_in", 3600) + ) # Default to 1 hour if not provided + expiry_time = current_time + expires_in + cache[token] = (token_info, expiry_time) + + return token_info + + +def extract_authorization_oauth(headers: dict, google_client_id: str) -> str: + """ + Extract and validate the OAuth token, returning the associated email. + returns: + str: Email + raises: + HTTPException: 401, If the Authorization header is missing or invalid. + HTTPException: 400, If the email is not found in the token. + """ + auth_header = headers.get("Authorization") + + if not auth_header or not auth_header.startswith("Bearer "): + raise HTTPException( + status_code=401, detail="Missing or invalid Authorization header" + ) + + token = auth_header.split(" ")[1] + + token_info = get_token_info(token, google_client_id) + + email = token_info.get("email") + if not email: + raise HTTPException(status_code=400, detail="Email not found in token") + + return email + + +class RequestContext: + """ + Request context class to store request metadata. + """ + + def __init__(self, scope: Scope) -> None: + headers = Headers(scope=scope) + self.headers = headers + self.scope = scope + self._extract_from_headers(headers, scope) + + def _extract_from_headers(self, headers, scope: Scope) -> None: + """ + Extract request context from headers. + - For local development, the user email is extracted from the Authorization header + (x-goog-authenticated-user-email). Otherwise, the Authorization header is required. + Local development can be enabled by setting the LOCAL_ENV environment variable to True. + - For production, the GOOGLE_CLIENT_ID environment variable must be set. + """ + self.host = headers.get("host") + self.protocol = ( + headers.get("x-forwarded-proto") + if headers.get("x-forwarded-proto") + else scope.get("scheme") + ) + self.client_host = headers.get("x-forwarded-for") + self.server_ip = ( + scope.get("server")[0] + if scope.get("server") and len(scope.get("server")) > 0 + else "" + ) + if not self.client_host: + self.client_host = ( + scope.get("client")[0] + if scope.get("client") and len(scope.get("client")) > 0 + else "" + ) + else: + # X-Forwarded-For: client, proxy1, proxy2 + forwarded_ips = self.client_host.split(",") + self.client_host = ( + str(forwarded_ips[0]).strip() + if len(forwarded_ips) > 0 + else str(self.client_host).strip() + ) + # merge all forwarded ips but the first one + self.server_ip = ( + ",".join(forwarded_ips[1:]).strip() + if len(forwarded_ips) > 1 + else self.server_ip + ) + self.client_user_agent = headers.get("user-agent") + self.iap_jwt_assertion = headers.get("x-goog-iap-jwt-assertion") + self.span_id = None + self.trace_id = None + self.trace_sampled = False + trace_context = headers.get("x-cloud-trace-context") + self.trace = trace_context + # x-cloud-trace-context: TRACE_ID/SPAN_ID;o=TRACE_TRUE + if trace_context and len(trace_context) > 0: + parts = trace_context.split("/") + self.trace_id = parts[0] + if len(parts) > 1: + self.span_id = parts[1].split(";")[0] + self.trace_sampled = ( + parts[1].split(";")[1] == "o=1" + if len(parts[1].split(";")) > 1 + else False + ) + # auth header is used for local development + self.user_email = headers.get("x-goog-authenticated-user-email") + + if headers.get("authorization") is not None: + google_client_id = os.getenv("GOOGLE_CLIENT_ID") + self.user_email = extract_authorization_oauth(headers, google_client_id) + else: + local_environment = os.getenv("LOCAL_ENV", False) + if not to_boolean(local_environment): + raise HTTPException( + status_code=401, detail="Authorization header not found" + ) + logging.info(self) + + def __repr__(self) -> str: + safe_properties = dict( + user_email=self.user_email, + client_user_agent=self.client_user_agent, + client_host=self.client_host, + client_protocol=self.protocol, + span_id=self.span_id, + trace_id=self.trace_id, + ) + return f"request-context={safe_properties})" + + +def get_request_context(): + """ + Get the request context. + """ + return _request_context.get() diff --git a/functions-python/operations_api/tests/__init__.py b/functions-python/operations_api/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/functions-python/operations_api/tests/conftest.py b/functions-python/operations_api/tests/conftest.py new file mode 100644 index 000000000..4c5b67894 --- /dev/null +++ b/functions-python/operations_api/tests/conftest.py @@ -0,0 +1,115 @@ +# +# MobilityData 2023 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from database_gen.sqlacodegen_models import Gtfsfeed, Gtfsrealtimefeed, Entitytype +from test_utils.database_utils import clean_testing_db, get_testing_session + +feed_mdb_41 = Gtfsrealtimefeed( + id="mdb-41", + data_type="gtfs_rt", + feed_name="London Transit Commission(RT", + note="note", + producer_url="producer_url", + authentication_type="1", + authentication_info_url="authentication_info_url", + api_key_parameter_name="api_key_parameter_name", + license_url="license_url", + stable_id="mdb-41", + status="active", + feed_contact_email="feed_contact_email", + provider="provider", + entitytypes=[Entitytype(name="vp")], +) + +feed_mdb_40 = Gtfsfeed( + id="mdb-40", + data_type="gtfs", + feed_name="London Transit Commission", + note="note", + producer_url="producer_url", + authentication_type="1", + authentication_info_url="authentication_info_url", + api_key_parameter_name="api_key_parameter_name", + license_url="license_url", + stable_id="mdb-40", + status="active", + feed_contact_email="feed_contact_email", + provider="provider", + gtfs_rt_feeds=[feed_mdb_41], + operational_status="wip", +) + +feed_mdb_400 = Gtfsfeed( + id="mdb-400", + data_type="gtfs", + feed_name="London Transit Commission", + note="note", + producer_url="producer_url", + authentication_type="1", + authentication_info_url="authentication_info_url", + api_key_parameter_name="api_key_parameter_name", + license_url="license_url", + stable_id="mdb-400", + status="active", + feed_contact_email="feed_contact_email", + provider="provider", + gtfs_rt_feeds=[], +) + + +def populate_database(): + """ + Populates the database with fake data with the following distribution: + - 1 GTFS feeds + - 1 GTFS Realtime feeds + """ + session = get_testing_session() + + session.add(feed_mdb_41) + # session.flush() + session.add(feed_mdb_40) + session.add(feed_mdb_400) + session.commit() + + +def pytest_configure(config): + """ + Allows plugins and conftest files to perform initial configuration. + This hook is called for every plugin and initial conftest + file after command line options have been parsed. + """ + + +def pytest_sessionstart(session): + """ + Called after the Session object has been created and + before performing collection and entering the run test loop. + """ + clean_testing_db() + populate_database() + + +def pytest_sessionfinish(session, exitstatus): + """ + Called after whole test run finished, right before + returning the exit status to the system. + """ + clean_testing_db() + + +def pytest_unconfigure(config): + """ + called before test process is exited. + """ diff --git a/functions-python/operations_api/tests/feeds_operations/impl/models/test_entity_type_impl.py b/functions-python/operations_api/tests/feeds_operations/impl/models/test_entity_type_impl.py new file mode 100644 index 000000000..3d41aab15 --- /dev/null +++ b/functions-python/operations_api/tests/feeds_operations/impl/models/test_entity_type_impl.py @@ -0,0 +1,27 @@ +from unittest.mock import Mock + +from database_gen.sqlacodegen_models import Entitytype +from feeds_operations.impl.models.entity_type_impl import EntityTypeImpl +from feeds_operations_gen.models.entity_type import EntityType + + +def test_from_orm(): + entity_type = Entitytype(name="VP") + result = EntityTypeImpl.from_orm(entity_type) + assert result.name == "VP" + + +def test_from_orm_none(): + result = EntityTypeImpl.from_orm(None) + assert result is None + + +def test_to_orm(): + entity_type = EntityType("vp") + session = Mock() + mock_query = Mock() + resulting_entity = Mock() + mock_query.filter.return_value.first.return_value = resulting_entity + session.query.return_value = mock_query + result = EntityTypeImpl.to_orm(entity_type, session) + assert result == resulting_entity diff --git a/functions-python/operations_api/tests/feeds_operations/impl/models/test_external_id_impl.py b/functions-python/operations_api/tests/feeds_operations/impl/models/test_external_id_impl.py new file mode 100644 index 000000000..4faf1a356 --- /dev/null +++ b/functions-python/operations_api/tests/feeds_operations/impl/models/test_external_id_impl.py @@ -0,0 +1,26 @@ +from database_gen.sqlacodegen_models import Externalid, Gtfsfeed +from feeds_operations_gen.models.external_id import ExternalId +from feeds_operations.impl.models.external_id_impl import ( + ExternalIdImpl, +) + + +def test_from_orm(): + external_id = Externalid(associated_id="12345", source="test_source") + result = ExternalIdImpl.from_orm(external_id) + assert result.external_id == "12345" + assert result.source == "test_source" + + +def test_from_orm_none(): + result = ExternalIdImpl.from_orm(None) + assert result is None + + +def test_to_orm(): + external_id = ExternalId(external_id="12345", source="test_source") + feed = Gtfsfeed(id=1) + result = ExternalIdImpl.to_orm(external_id, feed) + assert result.feed_id == 1 + assert result.associated_id == "12345" + assert result.source == "test_source" diff --git a/functions-python/operations_api/tests/feeds_operations/impl/models/test_redirect_impl.py b/functions-python/operations_api/tests/feeds_operations/impl/models/test_redirect_impl.py new file mode 100644 index 000000000..b50695205 --- /dev/null +++ b/functions-python/operations_api/tests/feeds_operations/impl/models/test_redirect_impl.py @@ -0,0 +1,53 @@ +import pytest +from unittest.mock import MagicMock +from database_gen.sqlacodegen_models import Redirectingid, Gtfsfeed +from feeds_operations_gen.models.redirect import Redirect +from feeds_operations.impl.models.redirect_impl import RedirectImpl + + +def test_from_orm(): + redirecting_id = Redirectingid( + target=MagicMock(stable_id="target_stable_id"), redirect_comment="Test comment" + ) + result = RedirectImpl.from_orm(redirecting_id) + assert result.target_id == "target_stable_id" + assert result.comment == "Test comment" + + +def test_from_orm_none(): + result = RedirectImpl.from_orm(None) + assert result is None + + +def test_to_orm(): + redirect = Redirect(target_id="target_stable_id", comment="Test comment") + source_feed = Gtfsfeed(id=1, data_type="gtfs") + target_feed = Gtfsfeed(id=2, stable_id="target_stable_id") + session = MagicMock() + session.query.return_value.filter.return_value.first.return_value = target_feed + result = RedirectImpl.to_orm(redirect, source_feed, session) + assert result.source_id == 1 + assert result.target_id == 2 + assert result.redirect_comment == "Test comment" + + +def test_to_orm_invalid_source(): + redirect = Redirect(target_id="target_stable_id", comment="Test comment") + session = MagicMock() + + with pytest.raises( + ValueError, match="Invalid source object or source.id is not set" + ): + RedirectImpl.to_orm(redirect, None, session) + + +def test_to_orm_invalid_target(): + redirect = Redirect(target_id="target_stable_id", comment="Test comment") + source_feed = Gtfsfeed(id=1, data_type="gtfs") + session = MagicMock() + session.query.return_value.filter.return_value.first.return_value = None + + with pytest.raises( + ValueError, match="Invalid target_feed object or target_feed.id is not set" + ): + RedirectImpl.to_orm(redirect, source_feed, session) diff --git a/functions-python/operations_api/tests/feeds_operations/impl/models/test_update_request_gtfs_feed_impl.py b/functions-python/operations_api/tests/feeds_operations/impl/models/test_update_request_gtfs_feed_impl.py new file mode 100644 index 000000000..e2a47241d --- /dev/null +++ b/functions-python/operations_api/tests/feeds_operations/impl/models/test_update_request_gtfs_feed_impl.py @@ -0,0 +1,118 @@ +from unittest.mock import Mock, MagicMock +from database_gen.sqlacodegen_models import Gtfsfeed, Redirectingid, Externalid +from feeds_operations_gen.models.authentication_type import AuthenticationType +from feeds_operations_gen.models.feed_status import FeedStatus +from feeds_operations_gen.models.source_info import SourceInfo +from feeds_operations_gen.models.update_request_gtfs_feed import UpdateRequestGtfsFeed +from operations_api.src.feeds_operations.impl.models.update_request_gtfs_feed_impl import ( + UpdateRequestGtfsFeedImpl, +) +from operations_api.src.feeds_operations.impl.models.redirect_impl import RedirectImpl +from operations_api.src.feeds_operations.impl.models.external_id_impl import ( + ExternalIdImpl, +) + + +def test_from_orm(): + redirecting_id = Redirectingid(target=MagicMock(stable_id="target_stable_id")) + external_id = Externalid(associated_id="external_id") + gtfs_feed = Gtfsfeed( + stable_id="stable_id", + status="active", + provider="provider", + feed_name="feed_name", + note="note", + feed_contact_email="email@example.com", + producer_url="http://producer.url", + authentication_type=1, + authentication_info_url="http://auth.info.url", + api_key_parameter_name="api_key", + license_url="http://license.url", + redirectingids=[redirecting_id], + externalids=[external_id], + ) + + result = UpdateRequestGtfsFeedImpl.from_orm(gtfs_feed) + assert result.id == "stable_id" + assert result.status == "active" + assert result.provider == "provider" + assert result.feed_name == "feed_name" + assert result.note == "note" + assert result.feed_contact_email == "email@example.com" + assert result.source_info.producer_url == "http://producer.url" + assert result.source_info.authentication_type == 1 + assert result.source_info.authentication_info_url == "http://auth.info.url" + assert result.source_info.api_key_parameter_name == "api_key" + assert result.source_info.license_url == "http://license.url" + assert len(result.redirects) == 1 + assert result.redirects[0].target_id == "target_stable_id" + assert len(result.external_ids) == 1 + assert result.external_ids[0].external_id == "external_id" + + +def test_from_orm_none(): + result = UpdateRequestGtfsFeedImpl.from_orm(None) + assert result is None + + +def test_to_orm(): + update_request = UpdateRequestGtfsFeed( + id="stable_id", + status=FeedStatus.ACTIVE, + provider="provider", + feed_name="feed_name", + note="note", + feed_contact_email="email@example.com", + source_info=SourceInfo( + producer_url="http://producer.url", + authentication_type=AuthenticationType.NUMBER_1, + authentication_info_url="http://auth.info.url", + api_key_parameter_name="api_key", + license_url="http://license.url", + ), + redirects=[RedirectImpl(target_id="target_stable_id", comment="Test comment")], + external_ids=[ExternalIdImpl(external_id="external_id")], + ) + entity = Gtfsfeed(id="1", stable_id="stable_id", data_type="gtfs") + target_feed = Gtfsfeed(id=2, stable_id="target_stable_id") + session = MagicMock() + session.query.return_value.filter.return_value.first.return_value = target_feed + + result = UpdateRequestGtfsFeedImpl.to_orm(update_request, entity, session) + assert result.status == "active" + assert result.provider == "provider" + assert result.feed_name == "feed_name" + assert result.note == "note" + assert result.feed_contact_email == "email@example.com" + assert result.producer_url == "http://producer.url" + assert result.authentication_type == "1" + assert result.authentication_info_url == "http://auth.info.url" + assert result.api_key_parameter_name == "api_key" + assert result.license_url == "http://license.url" + assert len(result.redirectingids) == 1 + assert result.redirectingids[0].target_id == target_feed.id + assert len(result.externalids) == 1 + assert result.externalids[0].associated_id == "external_id" + + +def test_to_orm_invalid_source_info(): + update_request = UpdateRequestGtfsFeed( + id="stable_id", + status=FeedStatus.ACTIVE, + provider="provider", + feed_name="feed_name", + note="note", + feed_contact_email="email@example.com", + source_info=None, + redirects=[RedirectImpl(target_id="target_stable_id", comment="Test comment")], + external_ids=[ExternalIdImpl(external_id="external_id")], + ) + entity = Gtfsfeed(id="id") + session = Mock() + + result = UpdateRequestGtfsFeedImpl.to_orm(update_request, entity, session) + assert result.producer_url is None + assert result.authentication_type is None + assert result.authentication_info_url is None + assert result.api_key_parameter_name is None + assert result.license_url is None diff --git a/functions-python/operations_api/tests/feeds_operations/impl/models/test_update_request_gtfs_rt_feed_impl.py b/functions-python/operations_api/tests/feeds_operations/impl/models/test_update_request_gtfs_rt_feed_impl.py new file mode 100644 index 000000000..6d80f3743 --- /dev/null +++ b/functions-python/operations_api/tests/feeds_operations/impl/models/test_update_request_gtfs_rt_feed_impl.py @@ -0,0 +1,144 @@ +from unittest.mock import MagicMock +from database_gen.sqlacodegen_models import ( + Gtfsfeed, + Redirectingid, + Externalid, + Gtfsrealtimefeed, + Entitytype, +) +from feeds_operations.impl.models.update_request_gtfs_rt_feed_impl import ( + UpdateRequestGtfsRtFeedImpl, +) +from feeds_operations_gen.models.authentication_type import AuthenticationType +from feeds_operations_gen.models.entity_type import EntityType +from feeds_operations_gen.models.feed_status import FeedStatus +from feeds_operations_gen.models.source_info import SourceInfo +from feeds_operations_gen.models.update_request_gtfs_rt_feed import ( + UpdateRequestGtfsRtFeed, +) +from operations_api.src.feeds_operations.impl.models.redirect_impl import RedirectImpl +from operations_api.src.feeds_operations.impl.models.external_id_impl import ( + ExternalIdImpl, +) + + +def test_from_orm(): + redirecting_id = Redirectingid(target=MagicMock(stable_id="target_stable_id")) + external_id = Externalid(associated_id="external_id") + gtfs_feed = Gtfsrealtimefeed( + stable_id="stable_id", + status="active", + provider="provider", + feed_name="feed_name", + note="note", + feed_contact_email="email@example.com", + producer_url="http://producer.url", + authentication_type=1, + authentication_info_url="http://auth.info.url", + api_key_parameter_name="api_key", + license_url="http://license.url", + redirectingids=[redirecting_id], + externalids=[external_id], + ) + + result = UpdateRequestGtfsRtFeedImpl.from_orm(gtfs_feed) + assert result.id == "stable_id" + assert result.status == "active" + assert result.provider == "provider" + assert result.feed_name == "feed_name" + assert result.note == "note" + assert result.feed_contact_email == "email@example.com" + assert result.source_info.producer_url == "http://producer.url" + assert result.source_info.authentication_type == 1 + assert result.source_info.authentication_info_url == "http://auth.info.url" + assert result.source_info.api_key_parameter_name == "api_key" + assert result.source_info.license_url == "http://license.url" + assert len(result.redirects) == 1 + assert result.redirects[0].target_id == "target_stable_id" + assert len(result.external_ids) == 1 + assert result.external_ids[0].external_id == "external_id" + + +def test_from_orm_none(): + result = UpdateRequestGtfsRtFeedImpl.from_orm(None) + assert result is None + + +def test_to_orm(): + update_request = UpdateRequestGtfsRtFeed( + id="stable_id", + status=FeedStatus.ACTIVE, + provider="provider", + feed_name="feed_name", + note="note", + feed_contact_email="email@example.com", + source_info=SourceInfo( + producer_url="http://producer.url", + authentication_type=AuthenticationType.NUMBER_1, + authentication_info_url="http://auth.info.url", + api_key_parameter_name="api_key", + license_url="http://license.url", + ), + redirects=[RedirectImpl(target_id="target_stable_id", comment="Test comment")], + external_ids=[ExternalIdImpl(external_id="external_id")], + entity_types=[EntityType.VP], + feed_references=["feed_reference"], + ) + entity = Gtfsrealtimefeed(id="1", stable_id="stable_id", data_type="gtfs") + target_feed = Gtfsfeed(id=2, stable_id="target_stable_id") + resulting_entity = Entitytype(name="VP") + + session = MagicMock() + session.query.return_value.filter.return_value.first.side_effect = [ + target_feed, + resulting_entity, + ] + + result = UpdateRequestGtfsRtFeedImpl.to_orm(update_request, entity, session) + assert result is not None + assert result.status == "active" + assert result.provider == "provider" + assert result.feed_name == "feed_name" + assert result.note == "note" + assert result.feed_contact_email == "email@example.com" + assert result.producer_url == "http://producer.url" + assert result.authentication_type == "1" + assert result.authentication_info_url == "http://auth.info.url" + assert result.api_key_parameter_name == "api_key" + assert result.license_url == "http://license.url" + assert len(result.redirectingids) == 1 + assert result.redirectingids[0].target_id == target_feed.id + assert len(result.externalids) == 1 + assert result.externalids[0].associated_id == "external_id" + + +def test_to_orm_invalid_source_info(): + update_request = UpdateRequestGtfsRtFeed( + id="stable_id", + status=FeedStatus.ACTIVE, + provider="provider", + feed_name="feed_name", + note="note", + feed_contact_email="email@example.com", + source_info=None, + redirects=[RedirectImpl(target_id="target_stable_id", comment="Test comment")], + external_ids=[ExternalIdImpl(external_id="external_id")], + entity_types=[EntityType.VP], + feed_references=["feed_reference"], + ) + entity = Gtfsrealtimefeed(id="1", stable_id="stable_id", data_type="gtfs") + target_feed = Gtfsfeed(id=2, stable_id="target_stable_id") + + session = MagicMock() + session.query.return_value.filter.return_value.first.side_effect = [ + target_feed, + None, + ] + + result = UpdateRequestGtfsRtFeedImpl.to_orm(update_request, entity, session) + + assert result.producer_url is None + assert result.authentication_type is None + assert result.authentication_info_url is None + assert result.api_key_parameter_name is None + assert result.license_url is None diff --git a/functions-python/operations_api/tests/feeds_operations/impl/test_feeds_operations_impl_gtfs.py b/functions-python/operations_api/tests/feeds_operations/impl/test_feeds_operations_impl_gtfs.py new file mode 100644 index 000000000..75fe0b2ed --- /dev/null +++ b/functions-python/operations_api/tests/feeds_operations/impl/test_feeds_operations_impl_gtfs.py @@ -0,0 +1,170 @@ +import os +from unittest import mock +from unittest.mock import patch + +import pytest +from fastapi import HTTPException +from starlette.responses import Response + +from database_gen.sqlacodegen_models import Gtfsfeed +from feeds_operations.impl.feeds_operations_impl import OperationsApiImpl +from feeds_operations_gen.models.authentication_type import AuthenticationType +from feeds_operations_gen.models.external_id import ExternalId +from feeds_operations_gen.models.feed_status import FeedStatus +from feeds_operations_gen.models.source_info import SourceInfo +from feeds_operations_gen.models.update_request_gtfs_feed import UpdateRequestGtfsFeed +from operations_api.tests.conftest import feed_mdb_40 +from test_utils.database_utils import get_testing_session, default_db_url + + +@pytest.fixture +def update_request_gtfs_feed(): + return UpdateRequestGtfsFeed( + id=feed_mdb_40.id, + status=FeedStatus(feed_mdb_40.status.lower()), + external_ids=[], + provider=feed_mdb_40.provider, + feed_name=feed_mdb_40.feed_name, + note=feed_mdb_40.note, + feed_contact_email=feed_mdb_40.feed_contact_email, + source_info=SourceInfo( + producer_url=feed_mdb_40.producer_url, + authentication_type=AuthenticationType( + int(feed_mdb_40.authentication_type) + ), + authentication_info_url=feed_mdb_40.authentication_info_url, + api_key_parameter_name=feed_mdb_40.api_key_parameter_name, + license_url=feed_mdb_40.license_url, + ), + redirects=[], + operational_status_action="no_change", + ) + + +@patch("helpers.logger.Logger") +@mock.patch.dict( + os.environ, + { + "FEEDS_DATABASE_URL": default_db_url, + }, +) +@pytest.mark.asyncio +async def test_update_gtfs_feed_no_changes(_, update_request_gtfs_feed): + api = OperationsApiImpl() + response: Response = await api.update_gtfs_feed(update_request_gtfs_feed) + assert response.status_code == 204 + + +@patch("helpers.logger.Logger") +@mock.patch.dict( + os.environ, + { + "FEEDS_DATABASE_URL": default_db_url, + }, +) +@pytest.mark.asyncio +async def test_update_gtfs_feed_field_change(_, update_request_gtfs_feed): + update_request_gtfs_feed.feed_name = "New feed name" + update_request_gtfs_feed.external_ids = [ + ExternalId( + external_id="new_external_id", + source="new_source", + ) + ] + with get_testing_session() as session: + api = OperationsApiImpl() + response: Response = await api.update_gtfs_feed(update_request_gtfs_feed) + assert response.status_code == 200 + + db_feed = ( + session.query(Gtfsfeed) + .filter(Gtfsfeed.stable_id == feed_mdb_40.stable_id) + .one() + ) + assert db_feed.feed_name == "New feed name" + + +@patch("helpers.logger.Logger") +@mock.patch.dict( + os.environ, + { + "FEEDS_DATABASE_URL": default_db_url, + }, +) +@pytest.mark.asyncio +async def test_update_gtfs_feed_set_wip(_, update_request_gtfs_feed): + update_request_gtfs_feed.operational_status_action = "wip" + with get_testing_session() as session: + api = OperationsApiImpl() + response: Response = await api.update_gtfs_feed(update_request_gtfs_feed) + assert response.status_code == 200 + + db_feed = ( + session.query(Gtfsfeed) + .filter(Gtfsfeed.stable_id == feed_mdb_40.stable_id) + .one() + ) + assert db_feed.operational_status == "wip" + + +@patch("helpers.logger.Logger") +@mock.patch.dict( + os.environ, + { + "FEEDS_DATABASE_URL": default_db_url, + }, +) +@pytest.mark.asyncio +async def test_update_gtfs_feed_set_wip_publish(_, update_request_gtfs_feed): + update_request_gtfs_feed.operational_status_action = "published" + with get_testing_session() as session: + api = OperationsApiImpl() + response: Response = await api.update_gtfs_feed(update_request_gtfs_feed) + assert response.status_code == 200 + + db_feed = ( + session.query(Gtfsfeed) + .filter(Gtfsfeed.stable_id == feed_mdb_40.stable_id) + .one() + ) + assert db_feed.operational_status is None + + +@patch("helpers.logger.Logger") +@mock.patch.dict( + os.environ, + { + "FEEDS_DATABASE_URL": default_db_url, + }, +) +@pytest.mark.asyncio +async def test_update_gtfs_feed_set_wip_nochange(_, update_request_gtfs_feed): + update_request_gtfs_feed.operational_status_action = "no_change" + with get_testing_session() as session: + api = OperationsApiImpl() + response: Response = await api.update_gtfs_feed(update_request_gtfs_feed) + assert response.status_code == 204 + + db_feed = ( + session.query(Gtfsfeed) + .filter(Gtfsfeed.stable_id == feed_mdb_40.stable_id) + .one() + ) + assert db_feed.operational_status is None + + +@patch("helpers.logger.Logger") +@mock.patch.dict( + os.environ, + { + "FEEDS_DATABASE_URL": default_db_url, + }, +) +@pytest.mark.asyncio +async def test_update_gtfs_feed_invalid_feed(_, update_request_gtfs_feed): + update_request_gtfs_feed.id = "invalid" + api = OperationsApiImpl() + with pytest.raises(HTTPException) as exc_info: + await api.update_gtfs_feed(update_request_gtfs_feed) + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Feed ID not found: invalid" diff --git a/functions-python/operations_api/tests/feeds_operations/impl/test_feeds_operations_impl_gtfs_rt.py b/functions-python/operations_api/tests/feeds_operations/impl/test_feeds_operations_impl_gtfs_rt.py new file mode 100644 index 000000000..d07fd9987 --- /dev/null +++ b/functions-python/operations_api/tests/feeds_operations/impl/test_feeds_operations_impl_gtfs_rt.py @@ -0,0 +1,93 @@ +import os +from unittest import mock +from unittest.mock import patch + +import pytest +from starlette.responses import Response + +from database_gen.sqlacodegen_models import Gtfsrealtimefeed +from feeds_operations.impl.feeds_operations_impl import OperationsApiImpl +from feeds_operations_gen.models.authentication_type import AuthenticationType +from feeds_operations_gen.models.entity_type import EntityType +from feeds_operations_gen.models.feed_status import FeedStatus +from feeds_operations_gen.models.source_info import SourceInfo +from feeds_operations_gen.models.update_request_gtfs_rt_feed import ( + UpdateRequestGtfsRtFeed, +) +from operations_api.tests.conftest import feed_mdb_41 +from test_utils.database_utils import get_testing_session, default_db_url + + +@pytest.fixture +def update_request_gtfs_rt_feed(): + return UpdateRequestGtfsRtFeed( + id=feed_mdb_41.stable_id, + status=FeedStatus(feed_mdb_41.status.lower()), + external_ids=[], + provider=feed_mdb_41.provider, + feed_name=feed_mdb_41.feed_name, + note=feed_mdb_41.note, + feed_contact_email=feed_mdb_41.feed_contact_email, + source_info=SourceInfo( + producer_url=feed_mdb_41.producer_url, + authentication_type=AuthenticationType( + int(feed_mdb_41.authentication_type) + ), + authentication_info_url=feed_mdb_41.authentication_info_url, + api_key_parameter_name=feed_mdb_41.api_key_parameter_name, + license_url=feed_mdb_41.license_url, + ), + redirects=[], + operational_status_action="no_change", + entity_types=[EntityType.VP], + ) + + +@patch("helpers.logger.Logger") +@mock.patch.dict( + os.environ, + { + "FEEDS_DATABASE_URL": default_db_url, + }, +) +@pytest.mark.asyncio +async def test_update_gtfs_feed_field_change(_, update_request_gtfs_rt_feed): + update_request_gtfs_rt_feed.feed_name = "New feed name" + with get_testing_session() as session: + api = OperationsApiImpl() + response: Response = await api.update_gtfs_rt_feed(update_request_gtfs_rt_feed) + assert response.status_code == 200 + + db_feed = ( + session.query(Gtfsrealtimefeed) + .filter(Gtfsrealtimefeed.stable_id == feed_mdb_41.stable_id) + .one() + ) + assert db_feed.feed_name == "New feed name" + + +@patch("helpers.logger.Logger") +@mock.patch.dict( + os.environ, + { + "FEEDS_DATABASE_URL": default_db_url, + }, +) +@pytest.mark.asyncio +async def test_update_gtfs_feed_static_change(_, update_request_gtfs_rt_feed): + update_request_gtfs_rt_feed.feed_references = ["mdb-400"] + with get_testing_session() as session: + api = OperationsApiImpl() + response: Response = await api.update_gtfs_rt_feed(update_request_gtfs_rt_feed) + assert response.status_code == 200 + + db_feed = ( + session.query(Gtfsrealtimefeed) + .filter(Gtfsrealtimefeed.stable_id == feed_mdb_41.stable_id) + .one() + ) + assert len(db_feed.gtfs_feeds) == 1 + feed = next( + (feed for feed in db_feed.gtfs_feeds if feed.stable_id == "mdb-400"), None + ) + assert feed is not None, "Feed with stable ID 'mdb-400' not found" diff --git a/functions-python/operations_api/tests/feeds_operations/impl/test_request_validator.py b/functions-python/operations_api/tests/feeds_operations/impl/test_request_validator.py new file mode 100644 index 000000000..f0b86e903 --- /dev/null +++ b/functions-python/operations_api/tests/feeds_operations/impl/test_request_validator.py @@ -0,0 +1,54 @@ +# +# MobilityData 2024 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +from pydantic import BaseModel +from fastapi import HTTPException +from feeds_operations.impl.request_validator import validate_request + + +class MockImplModel(BaseModel): + name: str + age: int + + +@validate_request(MockImplModel, "data") +async def sample_function(data: MockImplModel): + return data + + +@pytest.mark.asyncio +async def test_valid_request(): + data = MockImplModel(name="John Doe", age=30) + result = await sample_function(data) + assert result == data + + +@pytest.mark.asyncio +async def test_invalid_request(): + data = {"name": "John Doe", "age": "invalid_age"} + with pytest.raises(HTTPException) as exc_info: + await sample_function(data) + assert exc_info.value.status_code == 400 + assert "Input should be a valid integer" in str(exc_info.value.detail) + + +@pytest.mark.asyncio +async def test_missing_parameter(): + with pytest.raises(HTTPException) as exc_info: + await sample_function(None) + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Missing parameter" diff --git a/functions-python/operations_api/tests/middleware/test_request_context_middleware.py b/functions-python/operations_api/tests/middleware/test_request_context_middleware.py new file mode 100644 index 000000000..bf7920ee5 --- /dev/null +++ b/functions-python/operations_api/tests/middleware/test_request_context_middleware.py @@ -0,0 +1,77 @@ +# +# MobilityData 2024 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +from unittest.mock import patch, MagicMock +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import Receive, Scope, Send +import asyncio + +from middleware.request_context_middleware import ( + RequestContextMiddleware, +) + + +@pytest.fixture +def scope(): + def _scope(token): + return { + "type": "http", + "headers": [ + (b"host", b"example.com"), + (b"x-forwarded-proto", b"https"), + (b"x-forwarded-for", b"192.168.1.1"), + (b"user-agent", b"test-agent"), + (b"x-goog-iap-jwt-assertion", b"test-assertion"), + (b"x-cloud-trace-context", b"trace-id/span-id;o=1"), + (b"authorization", f"Bearer {token}".encode("utf-8")), + ], + "client": ("192.168.1.1", 12345), + "server": ("127.0.0.1", 8000), + "scheme": "https", + } + + return _scope + + +@pytest.mark.asyncio +@patch("middleware.request_context_middleware.RequestContext") +async def test_request_context_middleware(mock_request_context, scope, monkeypatch): + token = "test-token" + monkeypatch.setenv("GOOGLE_CLIENT_ID", "test-client-id") + monkeypatch.setenv("LOCAL_ENV", "true") + + mock_request_context.return_value = MagicMock() + + async def mock_call_next(scope: Scope, receive: Receive, send: Send) -> None: + response = Response("Test response") + await response(scope, receive, send) + + middleware = RequestContextMiddleware(mock_call_next) + request = Request(scope=scope(token)) + + async def mock_send(message): + pass + + try: + await asyncio.wait_for( + middleware(request.scope, request.receive, mock_send), timeout=5.0 + ) + except asyncio.TimeoutError: + pytest.fail("The test timed out") + + mock_request_context.assert_called_once_with(scope=request.scope) diff --git a/functions-python/operations_api/tests/middleware/test_request_context_oauth2.py b/functions-python/operations_api/tests/middleware/test_request_context_oauth2.py new file mode 100644 index 000000000..53bee9617 --- /dev/null +++ b/functions-python/operations_api/tests/middleware/test_request_context_oauth2.py @@ -0,0 +1,179 @@ +# +# MobilityData 2024 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +from fastapi import HTTPException +from unittest.mock import patch +from starlette.datastructures import Headers +from middleware.request_context_oauth2 import RequestContext + + +@pytest.fixture +def scope(): + def _scope(token): + result = { + "type": "http", + "headers": [ + (b"host", b"example.com"), + (b"x-forwarded-proto", b"https"), + (b"x-forwarded-for", b"192.168.1.1"), + (b"user-agent", b"test-agent"), + (b"x-goog-iap-jwt-assertion", b"test-assertion"), + (b"x-cloud-trace-context", b"trace-id/span-id;o=1"), + ], + "client": ("192.168.1.1", 12345), + "server": ("127.0.0.1", 8000), + "scheme": "https", + } + if token is not None: + if token is not None: + result["headers"].append((b"authorization", f"Bearer {token}".encode())) + return result + + return _scope + + +@patch("middleware.request_context_oauth2.get_tokeninfo_response") +def test_request_context_initialization( + mock_get_tokeninfo_response, scope, monkeypatch +): + monkeypatch.setenv("GOOGLE_CLIENT_ID", "test-client-id") + monkeypatch.setenv("LOCAL_ENV", "true") + + mock_get_tokeninfo_response.return_value.status_code = 200 + mock_get_tokeninfo_response.return_value.json.return_value = { + "email": "test-email@example.com", + "audience": "test-client-id", + "email_verified": True, + "expires_in": 3600, + } + + mocked_scope = scope("test_token_test_request_context_initialization") + request_context = RequestContext(mocked_scope) + + assert request_context.host == "example.com" + assert request_context.protocol == "https" + assert request_context.client_host == "192.168.1.1" + assert request_context.server_ip == "127.0.0.1" + assert request_context.client_user_agent == "test-agent" + assert request_context.iap_jwt_assertion == "test-assertion" + assert request_context.trace_id == "trace-id" + assert request_context.span_id == "span-id" + assert request_context.trace_sampled is True + assert request_context.user_email == "test-email@example.com" + + +@patch("middleware.request_context_oauth2.get_tokeninfo_response") +def test_request_context_invalid_audience( + mock_get_tokeninfo_response, scope, monkeypatch +): + monkeypatch.setenv("GOOGLE_CLIENT_ID", "test-client-id_audience") + monkeypatch.setenv("LOCAL_ENV", "true") + + mock_get_tokeninfo_response.return_value.status_code = 200 + mock_get_tokeninfo_response.return_value.json.return_value = { + "email": "test-email@example.com", + "audience": "not-test-client-id", + "email_verified": True, + "expires_in": 3600, + } + + mocked_scope = scope("test_request_context_invalid_audience") + + with pytest.raises(HTTPException) as exc_info: + RequestContext(mocked_scope) + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "Invalid token audience" + + +@patch("middleware.request_context_oauth2.get_tokeninfo_response") +def test_request_context_email_not_found( + mock_get_tokeninfo_response, scope, monkeypatch +): + monkeypatch.setenv("GOOGLE_CLIENT_ID", "test-client-id") + monkeypatch.setenv("LOCAL_ENV", "true") + + mock_get_tokeninfo_response.return_value.status_code = 200 + mock_get_tokeninfo_response.return_value.json.return_value = { + "audience": "test-client-id", + "email_verified": True, + "expires_in": 3600, + } + + mocked_scope = scope("test_request_context_email_not_found") + + with pytest.raises(HTTPException) as exc_info: + RequestContext(mocked_scope) + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Email not found in token" + + +@patch("middleware.request_context_oauth2.get_tokeninfo_response") +def test_request_context_invalid_tokeninfo_exception( + mock_get_tokeninfo_response, scope, monkeypatch +): + monkeypatch.setenv("GOOGLE_CLIENT_ID", "test-client-id") + monkeypatch.setenv("LOCAL_ENV", "true") + + mock_get_tokeninfo_response.side_effect = Exception("Test exception") + + mocked_scope = scope("test_request_context_invalid_tokeninfo_exception") + + with pytest.raises(HTTPException) as exc_info: + RequestContext(mocked_scope) + assert exc_info.value.status_code == 500 + assert exc_info.value.detail == "Token validation failed" + + +def test_request_context_missing_authorization(scope, monkeypatch): + monkeypatch.setenv("GOOGLE_CLIENT_ID", "test-client-id") + monkeypatch.setenv("LOCAL_ENV", "False") + + mocked_scope = scope("test_token_test_request_context_missing_authorization") + headers = Headers(scope=mocked_scope) + headers._list = [(k, v) for k, v in headers._list if k != b"authorization"] + mocked_scope["headers"] = headers.raw + + with pytest.raises(HTTPException) as exc_info: + RequestContext(mocked_scope) + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "Authorization header not found" + + +@patch("middleware.request_context_oauth2.get_tokeninfo_response") +def test_request_context_invalid_token(mock_get_tokeninfo_response, scope, monkeypatch): + monkeypatch.setenv("GOOGLE_CLIENT_ID", "test-client-id") + monkeypatch.setenv("LOCAL_ENV", "False") + + mock_get_tokeninfo_response.return_value.status_code = 400 + + with pytest.raises(HTTPException) as exc_info: + RequestContext(scope("test_token_test_request_context_invalid_token")) + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "Invalid access token" + + +@patch("middleware.request_context_oauth2.get_tokeninfo_response") +def test_request_context_no_token(mock_get_tokeninfo_response, scope, monkeypatch): + monkeypatch.setenv("GOOGLE_CLIENT_ID", "test-client-id") + monkeypatch.setenv("LOCAL_ENV", "False") + + mock_get_tokeninfo_response.return_value.status_code = 400 + + with pytest.raises(HTTPException) as exc_info: + RequestContext(scope(token=None)) + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "Authorization header not found" diff --git a/infra/feed-api/main.tf b/infra/feed-api/main.tf index 67ebf382c..05e71786d 100644 --- a/infra/feed-api/main.tf +++ b/infra/feed-api/main.tf @@ -77,6 +77,12 @@ resource "google_cloud_run_v2_service" "mobility-feed-api" { name = "PROJECT_ID" value = data.google_project.project.project_id } + resources { + limits = { + cpu = "1" + memory = "1Gi" + } + } } } } diff --git a/infra/functions-python/main.tf b/infra/functions-python/main.tf index 25370b9a1..c0eb7645c 100644 --- a/infra/functions-python/main.tf +++ b/infra/functions-python/main.tf @@ -36,6 +36,9 @@ locals { function_feed_sync_dispatcher_transitland_config = jsondecode(file("${path.module}/../../functions-python/feed_sync_dispatcher_transitland/function_config.json")) function_feed_sync_dispatcher_transitland_zip = "${path.module}/../../functions-python/feed_sync_dispatcher_transitland/.dist/feed_sync_dispatcher_transitland.zip" + + function_operations_api_config = jsondecode(file("${path.module}/../../functions-python/operations_api/function_config.json")) + function_operations_api_zip = "${path.module}/../../functions-python/operations_api/.dist/operations_api.zip" } locals { @@ -116,6 +119,13 @@ resource "google_storage_bucket_object" "feed_sync_dispatcher_transitland_zip" { source = local.function_feed_sync_dispatcher_transitland_zip } +# 7. Operations API +resource "google_storage_bucket_object" "operations_api_zip" { + bucket = google_storage_bucket.functions_bucket.name + name = "operations-api-${substr(filebase64sha256(local.function_operations_api_zip), 0, 10)}.zip" + source = local.function_operations_api_zip +} + # Secrets access resource "google_secret_manager_secret_iam_member" "secret_iam_member" { for_each = local.unique_secret_keys @@ -582,6 +592,49 @@ resource "google_cloudfunctions2_function" "feed_sync_dispatcher_transitland" { } } +resource "google_cloudfunctions2_function" "operations_api" { + name = "${local.function_operations_api_config.name}" + description = local.function_operations_api_config.description + location = var.gcp_region + depends_on = [google_secret_manager_secret_iam_member.secret_iam_member] + + build_config { + runtime = var.python_runtime + entry_point = local.function_operations_api_config.entry_point + source { + storage_source { + bucket = google_storage_bucket.functions_bucket.name + object = google_storage_bucket_object.operations_api_zip.name + } + } + } + service_config { + environment_variables = { + PROJECT_ID = var.project_id + PYTHONNODEBUGRANGES = 0 + GOOGLE_CLIENT_ID = var.operations_oauth2_client_id + } + available_memory = local.function_operations_api_config.memory + timeout_seconds = local.function_operations_api_config.timeout + available_cpu = local.function_operations_api_config.available_cpu + max_instance_request_concurrency = local.function_operations_api_config.max_instance_request_concurrency + max_instance_count = local.function_operations_api_config.max_instance_count + min_instance_count = local.function_operations_api_config.min_instance_count + service_account_email = google_service_account.functions_service_account.email + ingress_settings = local.function_operations_api_config.ingress_settings + vpc_connector = data.google_vpc_access_connector.vpc_connector.id + vpc_connector_egress_settings = "PRIVATE_RANGES_ONLY" + dynamic "secret_environment_variables" { + for_each = local.function_operations_api_config.secret_environment_variables + content { + key = secret_environment_variables.value["key"] + project_id = var.project_id + secret = "${upper(var.environment)}_${secret_environment_variables.value["key"]}" + version = "latest" + } + } + } +} # IAM entry for all users to invoke the function resource "google_cloudfunctions2_function_iam_member" "tokens_invoker" { @@ -600,6 +653,23 @@ resource "google_cloud_run_service_iam_member" "tokens_cloud_run_invoker" { member = "allUsers" } +# Allow Operations API function to be called by all users +resource "google_cloudfunctions2_function_iam_member" "operations_api_invoker" { + project = var.project_id + location = var.gcp_region + cloud_function = google_cloudfunctions2_function.operations_api.name + role = "roles/cloudfunctions.invoker" + member = "allUsers" +} + +resource "google_cloud_run_service_iam_member" "operastions_cloud_run_invoker" { + project = var.project_id + location = var.gcp_region + service = google_cloudfunctions2_function.operations_api.name + role = "roles/run.invoker" + member = "allUsers" +} + # Permissions on the service account used by the function and Eventarc trigger resource "google_project_iam_member" "invoking" { project = var.project_id diff --git a/infra/functions-python/vars.tf b/infra/functions-python/vars.tf index c5029bdf3..8c68c2a3d 100644 --- a/infra/functions-python/vars.tf +++ b/infra/functions-python/vars.tf @@ -69,3 +69,8 @@ variable "transitland_api_key" { type = string description = "Transitland API key" } + +variable "operations_oauth2_client_id" { + type = string + description = "value of the OAuth2 client id for the Operations API" +} \ No newline at end of file diff --git a/infra/main.tf b/infra/main.tf index d238259dc..92e5f8970 100644 --- a/infra/main.tf +++ b/infra/main.tf @@ -105,8 +105,9 @@ module "functions-python" { project_id = var.project_id gcp_region = var.gcp_region environment = var.environment + transitland_api_key = var.transitland_api_key - validator_endpoint = var.validator_endpoint + operations_oauth2_client_id = var.operations_oauth2_client_id } module "workflows" { diff --git a/infra/vars.tf b/infra/vars.tf index 6dc0ebee1..ea21efa3d 100644 --- a/infra/vars.tf +++ b/infra/vars.tf @@ -66,4 +66,9 @@ variable "validator_endpoint" { variable "transitland_api_key" { type = string +} + +variable "operations_oauth2_client_id" { + type = string + description = "value of the OAuth2 client id for the Operations API" } \ No newline at end of file diff --git a/infra/vars.tfvars.rename_me b/infra/vars.tfvars.rename_me index 6dc3bd0b5..ef4120349 100644 --- a/infra/vars.tfvars.rename_me +++ b/infra/vars.tfvars.rename_me @@ -17,4 +17,6 @@ oauth2_client_secret = {{OAUTH2_CLIENT_SECRET}} global_rate_limit_req_per_minute = {{GLOBAL_RATE_LIMIT_REQ_PER_MINUTE}} validator_endpoint = {{VALIDATOR_ENDPOINT}} -transitland_api_key = {{TRANSITLAND_API_KEY}} \ No newline at end of file +transitland_api_key = {{TRANSITLAND_API_KEY}} + +operations_oauth2_client_id = {{OPERATIONS_OAUTH2_CLIENT_ID}} \ No newline at end of file diff --git a/scripts/api-operations-gen.sh b/scripts/api-operations-gen.sh new file mode 100755 index 000000000..cc5aa8417 --- /dev/null +++ b/scripts/api-operations-gen.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +# +# This script generates the fastapi server stubs. It uses the gen-config.yaml file for additional properties. +# For information regarding ignored generated files check .openapi-generator-ignore file. +# As a requirement, you need to execute one time setup-openapi-generator.sh. +# Usage: +# api-gen.sh +# + +GENERATOR_VERSION=7.10.0 + +# relative path +SCRIPT_PATH="$(dirname -- "${BASH_SOURCE[0]}")" +OPERATIONS_PATH=functions-python/operations_api +OPENAPI_SCHEMA=$SCRIPT_PATH/../docs/OperationsAPI.yaml +OUTPUT_PATH=$SCRIPT_PATH/../$OPERATIONS_PATH +CONFIG_FILE=$SCRIPT_PATH/gen-operations-config.yaml + +echo "Generating FastAPI server stubs for Operations API from $OPENAPI_SCHEMA to $OUTPUT_PATH" +# Keep the "--global-property apiTests=false" at the end, otherwise it will generate test files that we already have +OPENAPI_GENERATOR_VERSION=$GENERATOR_VERSION $SCRIPT_PATH/bin/openapitools/openapi-generator-cli generate -g python-fastapi \ +-i $OPENAPI_SCHEMA -o $OUTPUT_PATH -c $CONFIG_FILE --global-property apiTests=false + diff --git a/scripts/api-tests.sh b/scripts/api-tests.sh index b172d133d..d7b73321d 100755 --- a/scripts/api-tests.sh +++ b/scripts/api-tests.sh @@ -87,6 +87,7 @@ while [[ $# -gt 0 ]]; do done cat $ABS_SCRIPTPATH/../config/.env.local > $ABS_SCRIPTPATH/../.env +PYTHONPATH_ORIGINAL=$PYTHONPATH execute_tests() { printf "\nExecuting tests in $1\n" @@ -98,6 +99,10 @@ execute_tests() { venv/bin/python -m pip install --disable-pip-version-check -r requirements_dev.txt >/dev/null venv/bin/python -m pip install --disable-pip-version-check coverage >/dev/null + # Setting PYTHONPATH to functions-python and the current function source directory + export PYTHONPATH="$ABS_SCRIPTPATH/../functions-python:$ABS_SCRIPTPATH/$1/src:$PYTHONPATH_ORIGINAL" + printf "PYTHONPATH=$PYTHONPATH\n" + # Run tests with coverage venv/bin/coverage run --branch -m pytest -W 'ignore::DeprecationWarning' tests @@ -140,8 +145,6 @@ fi execute_python_tests() { printf "\nExecuting python tests in $1\n" cd $ABS_SCRIPTPATH/../$1 - export PYTHONPATH="$ABS_SCRIPTPATH/../functions-python:$PYTHONPATH" - printf "PYTHONPATH=$PYTHONPATH\n" # Function to determine if a directory is valid for test execution should_directory_contain_tests() { diff --git a/scripts/function-python-build.sh b/scripts/function-python-build.sh index 7ea36ea15..ceb4ebdc8 100755 --- a/scripts/function-python-build.sh +++ b/scripts/function-python-build.sh @@ -92,6 +92,14 @@ build_function() { rm -rf "$FX_DIST_PATH" mkdir "$FX_DIST_PATH" + # Run pre_build script if specified + pre_build_script=$(jq -r '.build_settings.pre_build_script // empty' "$FX_PATH/function_config.json") + if [ -n "$pre_build_script" ]; then + printf "\nRunning pre_build script: $pre_build_script\n" + (cd "$FX_PATH" && eval "$pre_build_script") + printf "\nCompleted running pre_build script\n" + fi + cp -R "$FX_SOURCE_PATH" "$FX_DIST_BUILD" cp "$FX_PATH/requirements.txt" "$FX_DIST_BUILD" diff --git a/scripts/function-python-deploy.sh b/scripts/function-python-deploy.sh new file mode 100755 index 000000000..0efec4c65 --- /dev/null +++ b/scripts/function-python-deploy.sh @@ -0,0 +1,146 @@ +#!/bin/bash + +# Ensure the script exits if any command fails +set -e + +# relative path +SCRIPT_PATH="$(dirname -- "${BASH_SOURCE[0]}")" +FUNCTIONS_PATH="$SCRIPT_PATH/../functions-python" + +# Function to display usage +usage() { + echo "Usage: $0 [--build] [--help]" + echo " Name of the function to deploy" + echo " --build Optional flag to build the function before deploying" + echo " --help Display this help message" + exit 1 +} + +# defaults +FUNCTION_NAME="" +BUILD_FUNCTION=false +LOCAL_ENV_FILE=".env.local" + +# Parse parameters +while [[ $# -gt 0 ]]; do + case $1 in + --help) + usage + ;; + --build) + BUILD_FUNCTION=true + shift + ;; + *) + if [ -z "$FUNCTION_NAME" ]; then + FUNCTION_NAME=$1 + else + echo "Unknown parameter: $1" + usage + fi + shift + ;; + esac +done + +# Check if function name is provided +if [ -z "$FUNCTION_NAME" ]; then + usage +fi + +# Read configuration from function_config.json +CONFIG_FILE="$FUNCTIONS_PATH/$FUNCTION_NAME/function_config.json" +if [ ! -f "$CONFIG_FILE" ]; then + echo "Configuration file $CONFIG_FILE not found!" + exit 1 +fi + +RUNTIME=python311 +SOURCE=$FUNCTIONS_PATH/$FUNCTION_NAME/.dist/build +ENVIRONMENT=dev +PROJECT=mobility-feeds-$ENVIRONMENT +SERVICE_ACCOUNT=functions-service-account@mobility-feeds-$ENVIRONMENT.iam.gserviceaccount.com +ENVIRONMENT_UPPER=$(echo "$ENVIRONMENT" | tr '[:lower:]' '[:upper:]') + +if [ ! -d "$SOURCE" ]; then + echo "Function distribution folder found in $SOURCE. Building function..." + BUILD_FUNCTION=true +fi + +ENTRY_POINT=$(jq -r '.entry_point // empty' "$CONFIG_FILE") +TIMEOUT=$(jq -r '.timeout // empty' "$CONFIG_FILE") +MEMORY=$(jq -r '.memory // empty' "$CONFIG_FILE") +TRIGGER_HTTP=$(jq -r '.trigger_http // empty' "$CONFIG_FILE") +SECRET_ENV_VARS=$(jq -r '.secret_environment_variables | map("--set-secrets \(.key)=projects/'$PROJECT'/secrets/'$ENVIRONMENT_UPPER'_\(.key)/versions/latest") | join(" ") // empty' "$CONFIG_FILE") +INGRESS_SETTINGS=$(jq -r '.ingress_settings // empty' "$CONFIG_FILE") +MAX_CONCURRENCY=$(jq -r '.max_instance_request_concurrency // empty' "$CONFIG_FILE") +MAX_INSTANCES=$(jq -r '.max_instance_count // empty' "$CONFIG_FILE") +MIN_INSTANCES=$(jq -r '.min_instance_count // empty' "$CONFIG_FILE") +AVAILABLE_CPU=$(jq -r '.available_cpu // empty' "$CONFIG_FILE") + +if [ -z "$RUNTIME" ] || [ -z "$ENTRY_POINT" ]; then + echo "Invalid configuration in $CONFIG_FILE" + exit 1 +fi + +# Function to read environment variables from LOCAL_ENV_FILE +read_env_vars() { + if [ -f "$FUNCTIONS_PATH/$FUNCTION_NAME/$LOCAL_ENV_FILE" ]; then + export $(grep -v '^#' "$FUNCTIONS_PATH/$FUNCTION_NAME/$LOCAL_ENV_FILE" | xargs) + fi +} + +# Read environment variables from LOCAL_ENV_FILE +read_env_vars + +# Grant the Cloud Function's service account access to each secret +SECRETS=$(jq -r '.secret_environment_variables[].key' "$CONFIG_FILE") + +for SECRET_NAME in $SECRETS; do + gcloud secrets add-iam-policy-binding ${ENVIRONMENT_UPPER}_$SECRET_NAME \ + --project 978785769226 \ + --member "serviceAccount:$SERVICE_ACCOUNT" \ + --role "roles/secretmanager.secretAccessor" +done + +# Run the build script if the --build flag is provided +if [ "$BUILD_FUNCTION" = true ]; then + $SCRIPT_PATH/function-python-build.sh --function_name $FUNCTION_NAME +fi + +# Prepare environment variables from function_config.json +ENV_VARS="" +printf "Environment variables\n" +while IFS= read -r line; do + KEY=$(echo $line | jq -r '.key') + ENV_VAR_NAME=$(echo "$line" | jq -r '.[keys[0]]') + ENV_VAR_VALUE=$(printenv "$ENV_VAR_NAME") + printf " $KEY=$ENV_VAR_VALUE\n" + if [ -n "$ENV_VAR_VALUE" ]; then + ENV_VARS="$ENV_VARS --set-env-vars $KEY=$ENV_VAR_VALUE" + fi +done < <(jq -c '.environment_variables[]' "$CONFIG_FILE") + +# Deploy the function +DEPLOY_CMD="gcloud functions deploy $FUNCTION_NAME --gen2 --project $PROJECT --region northamerica-northeast1 --runtime $RUNTIME --entry-point $ENTRY_POINT --source $SOURCE --service-account $SERVICE_ACCOUNT" + +[ -n "$TIMEOUT" ] && DEPLOY_CMD="$DEPLOY_CMD --timeout $TIMEOUT" +[ -n "$MEMORY" ] && DEPLOY_CMD="$DEPLOY_CMD --memory $MEMORY" +[ -n "$INGRESS_SETTINGS" ] && DEPLOY_CMD="$DEPLOY_CMD --ingress-settings $INGRESS_SETTINGS" +[ -n "$MAX_INSTANCES" ] && DEPLOY_CMD="$DEPLOY_CMD --max-instances $MAX_INSTANCES" +[ -n "$MIN_INSTANCES" ] && DEPLOY_CMD="$DEPLOY_CMD --min-instances $MIN_INSTANCES" +[ -n "$MAX_CONCURRENCY" ] && DEPLOY_CMD="$DEPLOY_CMD --concurrency $MAX_CONCURRENCY" +[ -n "$AVAILABLE_CPU" ] && DEPLOY_CMD="$DEPLOY_CMD --cpu $AVAILABLE_CPU" +[ -n "$SECRET_ENV_VARS" ] && DEPLOY_CMD="$DEPLOY_CMD $SECRET_ENV_VARS" +[ -n "$ENV_VARS" ] && DEPLOY_CMD="$DEPLOY_CMD $ENV_VARS" + +if [ "$TRIGGER_HTTP" = true ]; then + DEPLOY_CMD="$DEPLOY_CMD --trigger-http --allow-unauthenticated" +else + echo "HTTP trigger not supported for $FUNCTION_NAME" +fi + +# Execute the deploy command +eval $DEPLOY_CMD + +echo "Deployment of $FUNCTION_NAME complete." \ No newline at end of file diff --git a/scripts/gen-operations-config.yaml b/scripts/gen-operations-config.yaml new file mode 100644 index 000000000..5d4a76bff --- /dev/null +++ b/scripts/gen-operations-config.yaml @@ -0,0 +1,9 @@ +# Documentation, https://openapi-generator.tech/docs/generators/python-fastapi/ +additionalProperties: + packageName: feeds_operations_gen + # modelNameSuffix: Api + removeOperationIdPrefix: true + fastapiImplementationPackage: feeds_operations.impl + useTags: false + # Adding this commented line for future reference as it is not currently supported by the fastApi generator + # legacyDiscriminatorBehavior: true