Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fake server & basic tests against it #45

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions mds/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,4 @@
Module implementing the MDS Provider API.
"""

from mds.api.client import ProviderClient

from mds.api.client import ProviderClient, MultipleProviderClient
243 changes: 141 additions & 102 deletions mds/api/client.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,15 @@
"""
MDS Provider API client implementation.
MDS Provider API client implementation.
"""

from datetime import datetime
import json
import requests
import mds
from mds.api.auth import OAuthClientCredentialsAuth
from mds.providers import get_registry, Provider


class ProviderClient(OAuthClientCredentialsAuth):
"""
Client for MDS Provider APIs
"""
def __init__(self, providers=None, ref=None):
"""
Initialize a new ProviderClient object.

:providers: is a list of Providers this client tracks by default. If None is given, downloads and uses the official Provider registry.

When using the official Providers registry, :ref: could be any of:
- git branch name
- commit hash (long or short)
- git tag
"""
self.providers = providers if providers is not None else get_registry(ref)
class ProviderClientBase(OAuthClientCredentialsAuth):

def _auth_session(self, provider):
"""
Expand All @@ -50,23 +35,90 @@ def _build_url(self, provider, endpoint):

return url

def _request(self, providers, endpoint, params, paging):
def _date_format(self, dt):
"""
Internal helper for sending requests.

Returns a dict of provider => payload(s).
Internal helper to format datetimes for querystrings.
"""
def __describe(res):
"""
Prints details about the given response.
"""
print(f"Requested {res.url}, Response Code: {res.status_code}")
print("Response Headers:")
for k,v in res.headers.items():
print(f"{k}: {v}")
return int(dt.timestamp()) if isinstance(dt, datetime) else int(dt)

if r.status_code is not 200:
print(r.text)
def _prepare_status_changes_params(
self,
start_time=None,
end_time=None,
bbox=None,
**kwargs):

# convert datetimes to querystring friendly format
if start_time is not None:
start_time = self._date_format(start_time)
if end_time is not None:
end_time = self._date_format(end_time)

# gather all the params together
return {
**dict(start_time=start_time, end_time=end_time, bbox=bbox),
**kwargs
}

def _prepare_trips_params(
self,
device_id=None,
vehicle_id=None,
start_time=None,
end_time=None,
bbox=None,
**kwargs):

# convert datetimes to querystring friendly format
if start_time is not None:
start_time = self._date_format(start_time)
if end_time is not None:
end_time = self._date_format(end_time)

# gather all the params togethers
return {
**dict(device_id=device_id, vehicle_id=vehicle_id, start_time=start_time, end_time=end_time, bbox=bbox),
**kwargs
}


class ProviderClient(ProviderClientBase):
def __init__(self, provider):
self.provider = provider

def iterate_trips(self, **kwargs):
return self.iterate_items(mds.TRIPS, **kwargs)

def iterate_status_changes(self, **kwargs):
return self.iterate_items(mds.STATUS_CHANGES, **kwargs)

def iterate_items(self, endpoint, **kwargs):
for page in self.iterate_pages(endpoint, **kwargs):
for item in page['data'][endpoint]:
yield item

def iterate_pages_of_trips(self, paging=True, **kwargs):
return self.iterate_pages(mds.TRIPS, paging=paging, **kwargs)

def iterate_pages_of_status_changes(self, paging=True, **kwargs):
return self.iterate_pages(mds.STATUS_CHANGES, paging=paging, **kwargs)

def iterate_pages(self, endpoint, paging=True, **kwargs):
params = getattr(self, f'_prepare_{endpoint}_params')(**kwargs)
return self._request(endpoint, params, paging)

def _request(self, endpoint, params, paging):
url = self._build_url(self.provider, endpoint)
session = self._auth_session(self.provider)
for page in self._iterate_pages_from_session(session, endpoint, url, params):
yield page
if not paging:
break

def _iterate_pages_from_session(self, session, endpoint, url, params):
"""
Request items from endpoint, following pages
"""

def __has_data(page):
"""
Expand All @@ -83,61 +135,73 @@ def __next_url(page):
"""
return page["links"].get("next") if "links" in page else None

# create a request url for each provider
urls = [self._build_url(p, endpoint) for p in providers]

# keyed by provider
results = {}
response = session.get(url, params=params)
response.raise_for_status()

for i in range(len(providers)):
provider, url = providers[i], urls[i]
this_page = response.json()
if __has_data(this_page):
yield this_page

# establish an authenticated session
session = self._auth_session(provider)
next_url = __next_url(this_page)
while next_url is not None:
response = session.get(next_url)
response.raise_for_status()
this_page = response.json()
if __has_data(this_page):
yield this_page
next_url = __next_url(this_page)
else:
break

# get the initial page of data
r = session.get(url, params=params)

if r.status_code is not 200:
__describe(r)
continue
class MultipleProviderClient(ProviderClientBase):
"""
Client for MDS Provider APIs
"""
def __init__(self, providers=None, ref=None):
"""
Initialize a new MultipleProviderClient object.

this_page = r.json()
:providers: is a list of Providers this client tracks by default. If None is given, downloads and uses the official Provider registry.

# track the list of pages per provider
results[provider] = [this_page] if __has_data(this_page) else []
When using the official Providers registry, :ref: could be any of:
- git branch name
- commit hash (long or short)
- git tag
"""
self.providers = providers if providers is not None else get_registry(ref)

# get subsequent pages of data
next_url = __next_url(this_page)
while paging and next_url:
r = session.get(next_url)
def _request_from_providers(self, providers, endpoint, params, paging):
"""
Internal helper for sending requests.

if r.status_code is not 200:
__describe(r)
break
Returns a dict of provider => payload(s).
"""
def __describe(res):
"""
Prints details about the given response.
"""
print(f"Requested {res.url}, Response Code: {res.status_code}")
print("Response Headers:")
for k,v in res.headers.items():
print(f"{k}: {v}")

this_page = r.json()
if r.status_code is not 200:
print(r.text)

if __has_data(this_page):
results[provider].append(this_page)
next_url = __next_url(this_page)
else:
break
results = {}
for provider in providers:
client = ProviderClient(provider)
try:
results[provider] = list(client.request(endpoint, params, paging))
except requests.RequestException as exc:
__describe(exc.response)

return results

def _date_format(self, dt):
"""
Internal helper to format datetimes for querystrings.
"""
return int(dt.timestamp()) if isinstance(dt, datetime) else int(dt)

def get_status_changes(
self,
providers=None,
start_time=None,
end_time=None,
bbox=None,
paging=True,
**kwargs):
"""
Expand All @@ -155,7 +219,7 @@ def get_status_changes(
Should be a datetime object or numeric representation of UNIX seconds

- `bbox`: Filters for status changes where `event_location` is within defined bounding-box.
The order is defined as: southwest longitude, southwest latitude,
The order is defined as: southwest longitude, southwest latitude,
northeast longitude, northeast latitude (separated by commas).

e.g.
Expand All @@ -168,31 +232,16 @@ def get_status_changes(
if providers is None:
providers = self.providers

# convert datetimes to querystring friendly format
if start_time is not None:
start_time = self._date_format(start_time)
if end_time is not None:
end_time = self._date_format(end_time)

# gather all the params together
params = {
**dict(start_time=start_time, end_time=end_time, bbox=bbox),
**kwargs
}
params = self._prepare_status_changes_params(**kwargs)

# make the request(s)
status_changes = self._request(providers, mds.STATUS_CHANGES, params, paging)
status_changes = self._request_from_providers(providers, mds.STATUS_CHANGES, params, paging)

return status_changes

def get_trips(
self,
providers=None,
device_id=None,
vehicle_id=None,
start_time=None,
end_time=None,
bbox=None,
paging=True,
**kwargs):
"""
Expand All @@ -214,7 +263,7 @@ def get_trips(
Should be a datetime object or numeric representation of UNIX seconds

- `bbox`: Filters for trips where and point within `route` is within defined bounding-box.
The order is defined as: southwest longitude, southwest latitude,
The order is defined as: southwest longitude, southwest latitude,
northeast longitude, northeast latitude (separated by commas).

e.g.
Expand All @@ -227,19 +276,9 @@ def get_trips(
if providers is None:
providers = self.providers

# convert datetimes to querystring friendly format
if start_time is not None:
start_time = self._date_format(start_time)
if end_time is not None:
end_time = self._date_format(end_time)

# gather all the params togethers
params = {
**dict(device_id=device_id, vehicle_id=vehicle_id, start_time=start_time, end_time=end_time, bbox=bbox),
**kwargs
}
params = self._prepare_trips_params(**kwargs)

# make the request(s)
trips = self._request(providers, mds.TRIPS, params, paging)
trips = self._request_from_providers(providers, mds.TRIPS, params, paging)

return trips
Loading