diff --git a/mds/api/__init__.py b/mds/api/__init__.py index 1013089..f0c4777 100644 --- a/mds/api/__init__.py +++ b/mds/api/__init__.py @@ -2,5 +2,4 @@ Module implementing the MDS Provider API. """ -from mds.api.client import ProviderClient - +from mds.api.client import MultipleProviderClient diff --git a/mds/api/client.py b/mds/api/client.py index bb27cdc..cc7c7cf 100644 --- a/mds/api/client.py +++ b/mds/api/client.py @@ -1,30 +1,15 @@ """ -MDS Provider API client implementation. +MDS Provider API client implementation. """ from datetime import datetime import json +import requests import mds from mds.api.auth import OAuthClientCredentialsAuth from mds.providers import get_registry, Provider - -class ProviderClient(OAuthClientCredentialsAuth): - """ - Client for MDS Provider APIs - """ - def __init__(self, providers=None, ref=None): - """ - Initialize a new ProviderClient object. - - :providers: is a list of Providers this client tracks by default. If None is given, downloads and uses the official Provider registry. - - When using the official Providers registry, :ref: could be any of: - - git branch name - - commit hash (long or short) - - git tag - """ - self.providers = providers if providers is not None else get_registry(ref) +class ProviderClientBase(OAuthClientCredentialsAuth): def _auth_session(self, provider): """ @@ -50,23 +35,83 @@ def _build_url(self, provider, endpoint): return url - def _request(self, providers, endpoint, params, paging): + def _date_format(self, dt): """ - Internal helper for sending requests. - - Returns a dict of provider => payload(s). + Internal helper to format datetimes for querystrings. """ - def __describe(res): - """ - Prints details about the given response. - """ - print(f"Requested {res.url}, Response Code: {res.status_code}") - print("Response Headers:") - for k,v in res.headers.items(): - print(f"{k}: {v}") + return int(dt.timestamp()) if isinstance(dt, datetime) else int(dt) - if r.status_code is not 200: - print(r.text) + def _prepare_status_changes_params( + self, + start_time=None, + end_time=None, + bbox=None, + **kwargs): + + # convert datetimes to querystring friendly format + if start_time is not None: + start_time = self._date_format(start_time) + if end_time is not None: + end_time = self._date_format(end_time) + + # gather all the params together + return { + **dict(start_time=start_time, end_time=end_time, bbox=bbox), + **kwargs + } + + def _prepare_trips_params( + self, + device_id=None, + vehicle_id=None, + start_time=None, + end_time=None, + bbox=None, + **kwargs): + + # convert datetimes to querystring friendly format + if start_time is not None: + start_time = self._date_format(start_time) + if end_time is not None: + end_time = self._date_format(end_time) + + # gather all the params togethers + return { + **dict(device_id=device_id, vehicle_id=vehicle_id, start_time=start_time, end_time=end_time, bbox=bbox), + **kwargs + } + + +class ProviderClient(ProviderClientBase): + def __init__(self, provider): + self.provider = provider + + def get_trips(self, **kwargs): + return list(self.iterate_trips_pages(**kwargs)) + + def get_status_changes(self, **kwargs): + return list(self.iterate_status_change_pages(**kwargs)) + + def iterate_trips_pages(self, paging=True, **kwargs): + params = self._prepare_trips_params(**kwargs) + return self.request(mds.TRIPS, params, paging) + + def iterate_status_change_pages(self, paging=True, **kwargs): + params = self._prepare_status_changes_params(**kwargs) + return self.request(mds.STATUS_CHANGES, params, paging) + + def request(self, endpoint, params, paging): + url = self._build_url(self.provider, endpoint) + session = self._auth_session(self.provider) + for page in self._iterate_pages_from_session(session, endpoint, url, params): + yield page + if not paging: + break + + def _iterate_pages_from_session(self, session, endpoint, url, params): + """ + Request items from endpoint, following pages + """ def __has_data(page): """ @@ -83,61 +128,73 @@ def __next_url(page): """ return page["links"].get("next") if "links" in page else None - # create a request url for each provider - urls = [self._build_url(p, endpoint) for p in providers] + response = session.get(url, params=params) + response.raise_for_status() - # keyed by provider - results = {} - - for i in range(len(providers)): - provider, url = providers[i], urls[i] + this_page = response.json() + if __has_data(this_page): + yield this_page - # establish an authenticated session - session = self._auth_session(provider) + next_url = __next_url(this_page) + while next_url is not None: + response = session.get(next_url) + response.raise_for_status() + this_page = response.json() + if __has_data(this_page): + yield this_page + next_url = __next_url(this_page) + else: + break - # get the initial page of data - r = session.get(url, params=params) - if r.status_code is not 200: - __describe(r) - continue +class MultipleProviderClient(ProviderClientBase): + """ + Client for MDS Provider APIs + """ + def __init__(self, providers=None, ref=None): + """ + Initialize a new MultipleProviderClient object. - this_page = r.json() + :providers: is a list of Providers this client tracks by default. If None is given, downloads and uses the official Provider registry. - # track the list of pages per provider - results[provider] = [this_page] if __has_data(this_page) else [] + When using the official Providers registry, :ref: could be any of: + - git branch name + - commit hash (long or short) + - git tag + """ + self.providers = providers if providers is not None else get_registry(ref) - # get subsequent pages of data - next_url = __next_url(this_page) - while paging and next_url: - r = session.get(next_url) + def _request_from_providers(self, providers, endpoint, params, paging): + """ + Internal helper for sending requests. - if r.status_code is not 200: - __describe(r) - break + Returns a dict of provider => payload(s). + """ + def __describe(res): + """ + Prints details about the given response. + """ + print(f"Requested {res.url}, Response Code: {res.status_code}") + print("Response Headers:") + for k,v in res.headers.items(): + print(f"{k}: {v}") - this_page = r.json() + if r.status_code is not 200: + print(r.text) - if __has_data(this_page): - results[provider].append(this_page) - next_url = __next_url(this_page) - else: - break + results = {} + for provider in providers: + client = ProviderClient(provider) + try: + results[provider] = list(client.request(endpoint, params, paging)) + except requests.RequestException as exc: + __describe(exc.response) return results - def _date_format(self, dt): - """ - Internal helper to format datetimes for querystrings. - """ - return int(dt.timestamp()) if isinstance(dt, datetime) else int(dt) - def get_status_changes( self, providers=None, - start_time=None, - end_time=None, - bbox=None, paging=True, **kwargs): """ @@ -155,7 +212,7 @@ def get_status_changes( Should be a datetime object or numeric representation of UNIX seconds - `bbox`: Filters for status changes where `event_location` is within defined bounding-box. - The order is defined as: southwest longitude, southwest latitude, + The order is defined as: southwest longitude, southwest latitude, northeast longitude, northeast latitude (separated by commas). e.g. @@ -168,31 +225,16 @@ def get_status_changes( if providers is None: providers = self.providers - # convert datetimes to querystring friendly format - if start_time is not None: - start_time = self._date_format(start_time) - if end_time is not None: - end_time = self._date_format(end_time) - - # gather all the params together - params = { - **dict(start_time=start_time, end_time=end_time, bbox=bbox), - **kwargs - } + params = self._prepare_status_changes_params(**kwargs) # make the request(s) - status_changes = self._request(providers, mds.STATUS_CHANGES, params, paging) + status_changes = self._request_from_providers(providers, mds.STATUS_CHANGES, params, paging) return status_changes def get_trips( self, providers=None, - device_id=None, - vehicle_id=None, - start_time=None, - end_time=None, - bbox=None, paging=True, **kwargs): """ @@ -214,7 +256,7 @@ def get_trips( Should be a datetime object or numeric representation of UNIX seconds - `bbox`: Filters for trips where and point within `route` is within defined bounding-box. - The order is defined as: southwest longitude, southwest latitude, + The order is defined as: southwest longitude, southwest latitude, northeast longitude, northeast latitude (separated by commas). e.g. @@ -227,19 +269,9 @@ def get_trips( if providers is None: providers = self.providers - # convert datetimes to querystring friendly format - if start_time is not None: - start_time = self._date_format(start_time) - if end_time is not None: - end_time = self._date_format(end_time) - - # gather all the params togethers - params = { - **dict(device_id=device_id, vehicle_id=vehicle_id, start_time=start_time, end_time=end_time, bbox=bbox), - **kwargs - } + params = self._prepare_trips_params(**kwargs) # make the request(s) - trips = self._request(providers, mds.TRIPS, params, paging) + trips = self._request_from_providers(providers, mds.TRIPS, params, paging) return trips diff --git a/mds/fake/server.py b/mds/fake/server.py new file mode 100644 index 0000000..e4274f7 --- /dev/null +++ b/mds/fake/server.py @@ -0,0 +1,94 @@ + +from flask import Flask, request, jsonify, url_for +import json, base64 + +class PaginationCursor(object): + def __init__(self, serialized_cursor=None, offset=None): + if offset is not None and serialized_cursor is not None: + raise RuntimeError('Cannot initialize with non-None offset AND non-None cursor') + + if serialized_cursor is not None: + data = json.loads(base64.b64decode(serialized_cursor).decode('utf-8')) + self.offset = data['o'] + else: + self.offset = 0 if offset is None else offset + + def serialize(self): + return base64.b64encode(json.dumps({ 'o': self.offset }).encode('utf-8')) + +class InMemoryPaginator(object): + def __init__(self, all_items, serialized_cursor=None, page_size=20): + self.items = all_items + self.cursor = PaginationCursor(serialized_cursor) + self.page_size = page_size + + def next_cursor_serialized(self): + offset = self.cursor.offset + return PaginationCursor(offset=offset+self.page_size).serialize() + + def get_page(self): + offset = self.cursor.offset + return self.items[offset:offset+self.page_size] + +def make_mds_response_data(version, resource_name, paginator, **params): + return { + 'version': version, + 'links': { + 'next': url_for(resource_name, + cursor=paginator.next_cursor_serialized(), + _external=True, + **params), + }, + 'data': { + resource_name: paginator.get_page(), + } + } + +def params_match_trip(params, trip): + vehicle_id = params.get('vehicle_id') + if vehicle_id and trip['vehicle_id'] != vehicle_id: + return False + + return True + +def params_match_status_change(params, sc): + return True + +def make_static_server_app(trips=[], + status_changes=[], + version='0.2.0', + page_size=20): + app = Flask('mds_static') + store = { + 'trips': trips, + 'status_changes': status_changes, + } + + @app.route('/trips') + def trips(): + params = { + 'vehicle_id': request.args.get('vehicle_id'), + # TODO: support other params + } + selected_trips = [t for t in store['trips'] if params_match_trip(params, t)] + paginator = InMemoryPaginator(selected_trips, + serialized_cursor=request.args.get('cursor'), + page_size=page_size) + return jsonify(make_mds_response_data( + version, 'trips', paginator, **params + )) + + @app.route('/status_changes') + def status_changes(): + params = { + # TODO + } + selected_items = [sc for sc in store['status_changes'] if params_match_status_change(params, sc)] + paginator = InMemoryPaginator(selected_items, + serialized_cursor=request.args.get('cursor'), + page_size=page_size) + return jsonify(make_mds_response_data( + version, 'status_changes', paginator, **params + )) + + return app diff --git a/mds/tests/__init__.py b/mds/tests/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/mds/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/mds/tests/test_api.py b/mds/tests/test_api.py new file mode 100644 index 0000000..99ca4a7 --- /dev/null +++ b/mds/tests/test_api.py @@ -0,0 +1,79 @@ +import unittest, re, uuid +from contextlib import contextmanager +import requests_mock +from urllib3.util import parse_url + +from mds.fake.server import make_static_server_app +from mds.providers import Provider +from mds.api import MultipleProviderClient + + +def requests_mock_with_app(app, netloc='testserver'): + client = app.test_client() + def get_app_response(request, response_context): + url_object = parse_url(request.url) + app_response = client.get(url_object.request_uri, base_url='https://testserver/') + response_context.status_code = app_response.status_code + response_context.headers = app_response.headers + return app_response.data + + mock = requests_mock.Mocker() + matcher = re.compile(f'^https://{netloc}/') + mock.register_uri('GET', matcher, content=get_app_response) + return mock + +@contextmanager +def mock_provider(app): + with requests_mock_with_app(app, netloc='testserver') as mock: + provider = Provider( + 'test', + uuid.uuid4(), + url='', + auth_type='Bearer', + token='', # enable simple token auth + mds_api_url='https://testserver') + yield provider + + +class APITest(unittest.TestCase): + def setUp(self): + self.empty_app = make_static_server_app( + trips=[], + status_changes=[], + version='0.2.0', + page_size=20, + ) + + self.bogus_data_app = make_static_server_app( + trips=list(range(100)), + status_changes=list(range(100)), + version='0.2.0', + page_size=20, + ) + + def _all_items_from_app(self, app, endpoint='trips', get_method_kwargs={}): + with mock_provider(app) as provider: + client = MultipleProviderClient(providers=[provider]) + method = getattr(client, f'get_{endpoint}') + pages_by_provider = method(**get_method_kwargs) + + items = [] + for page in pages_by_provider[provider]: + for item in page['data'][endpoint]: + items.append(item) + return items + + def test_single_provider_paging_enabled(self): + # empty provider should return zero trips + trips = self._all_items_from_app(self.empty_app, 'trips') + self.assertEqual(len(trips), 0) + + # 100-trip provider should return all trips + trips = self._all_items_from_app(self.bogus_data_app, 'trips') + self.assertEqual(len(trips), 100) + + def test_single_provider_disable_paging(self): + # Turn off paging; should get just first 20 trips + trips = self._all_items_from_app(self.bogus_data_app, 'trips', + get_method_kwargs=dict(paging=False)) + self.assertEqual(len(trips), 20) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..8e6a319 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ + +Fiona +geopandas +jsonschema >= 3.0.0a2 +numpy +pandas +psycopg2-binary +requests +scipy +Shapely +sqlalchemy + +flask +urllib3 +requests +requests_mock +ipdb