Skip to content

Commit

Permalink
Implement OAuth refresh token flow (#520)
Browse files Browse the repository at this point in the history
* Implement OAuth refresh token flow

* Apply suggestions from code review

Formatting changes

Co-authored-by: antalszava <[email protected]>

* Add correct token refresh path

* update path, update getting the access token, remove url wrapping in _request

* Formatting with black

* move dict init into func

* Updates

* refresh access token unit tests

* no print

* User request.post directly, update tests

* Wrapped request test, updates

* Updates

* Formatting

* Remove access token refreshing from init

* Update tests/api/test_connection.py

* Update tests/api/test_connection.py

* Update tests/api/test_connection.py

* Update tests/api/test_connection.py

* Update strawberryfields/api/connection.py

Co-authored-by: Jeremy Swinarton <[email protected]>

* update msg in test

* changelog

Co-authored-by: antalszava <[email protected]>
Co-authored-by: Antal Szava <[email protected]>
Co-authored-by: Jeremy Swinarton <[email protected]>
  • Loading branch information
4 people authored Jan 15, 2021
1 parent fd1b1aa commit 7b0a45d
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 29 deletions.
10 changes: 7 additions & 3 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,11 @@
* `TDMProgram` objects can now be serialized into Blackbird scripts, and vice versa.
[(#476)](https://github.com/XanaduAI/strawberryfields/pull/476)

<h3>Breaking changes</h3>
<h3>Breaking Changes</h3>

* Jobs are submitted to the Xanadu Quantum Cloud through a new OAuth based
authentication flow using offline refresh tokens and access tokens.
[(#520)](https://github.com/XanaduAI/strawberryfields/pull/520)

<h3>Bug fixes</h3>

Expand Down Expand Up @@ -138,8 +142,8 @@

This release contains contributions from (in alphabetical order):

Tom Bromley, Jack Brown, Theodor Isacsson, Josh Izaac, Fabian Laudenbach, Nicolas Quesada,
Antal Száva.
Tom Bromley, Jack Brown, Theodor Isacsson, Josh Izaac, Fabian Laudenbach, Tim Leisti,
Nicolas Quesada, Antal Száva.

# Release 0.16.0 (current release)

Expand Down
70 changes: 60 additions & 10 deletions strawberryfields/api/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""
import io
from datetime import datetime
from typing import List
from typing import List, Dict

import numpy as np
import requests
Expand Down Expand Up @@ -96,7 +96,8 @@ def __init__(
self._verbose = verbose

self._base_url = "http{}://{}:{}".format("s" if self.use_ssl else "", self.host, self.port)
self._headers = {"Authorization": self.token, "Accept-Version": self.api_version}

self._headers = {"Accept-Version": self.api_version}

self.log = create_logger(__name__)

Expand Down Expand Up @@ -156,7 +157,7 @@ def get_device_spec(self, target: str) -> DeviceSpec:
def _get_device_dict(self, target: str) -> dict:
"""Returns the device specifications as a dictionary"""
path = f"/devices/{target}/specifications"
response = requests.get(self._url(path), headers=self._headers)
response = self._request("GET", self._url(path), headers=self._headers)

if response.status_code == 200:
self.log.info("The device spec %s has been successfully retrieved.", target)
Expand Down Expand Up @@ -185,7 +186,9 @@ def create_job(self, target: str, program: Program, run_options: dict = None) ->
circuit = bb.serialize()

path = "/jobs"
response = requests.post(self._url(path), headers=self._headers, json={"circuit": circuit})
response = self._request(
"POST", self._url(path), headers=self._headers, json={"circuit": circuit}
)
if response.status_code == 201:
job_id = response.json()["id"]
if self._verbose:
Expand Down Expand Up @@ -221,7 +224,7 @@ def get_job(self, job_id: str) -> Job:
strawberryfields.api.Job: the job
"""
path = "/jobs/{}".format(job_id)
response = requests.get(self._url(path), headers=self._headers)
response = self._request("GET", self._url(path), headers=self._headers)
if response.status_code == 200:
return Job(
id_=response.json()["id"],
Expand Down Expand Up @@ -254,8 +257,8 @@ def get_job_result(self, job_id: str) -> Result:
strawberryfields.api.Result: the job result
"""
path = "/jobs/{}/result".format(job_id)
response = requests.get(
self._url(path), headers={"Accept": "application/x-numpy", **self._headers}
response = self._request(
"GET", self._url(path), headers={"Accept": "application/x-numpy", **self._headers}
)
if response.status_code == 200:
# Read the numpy binary data in the payload into memory
Expand Down Expand Up @@ -283,8 +286,11 @@ def cancel_job(self, job_id: str):
job_id (str): the job ID
"""
path = "/jobs/{}".format(job_id)
response = requests.patch(
self._url(path), headers=self._headers, json={"status": JobStatus.CANCELLED.value}
response = self._request(
"PATCH",
self._url(path),
headers=self._headers,
json={"status": JobStatus.CANCELLED.value},
)
if response.status_code == 204:
if self._verbose:
Expand All @@ -301,12 +307,56 @@ def ping(self) -> bool:
bool: ``True`` if the connection is successful, and ``False`` otherwise
"""
path = "/healthz"
response = requests.get(self._url(path), headers=self._headers)
response = self._request("GET", self._url(path), headers=self._headers)
return response.status_code == 200

def _url(self, path: str) -> str:
return self._base_url + path

def _refresh_access_token(self):
"""Use the offline token to request a new access token."""
self._headers.pop("Authorization", None)
path = "/auth/realms/platform/protocol/openid-connect/token"
headers = {**self._headers}
response = requests.post(
self._url(path),
headers=headers,
data={
"grant_type": "refresh_token",
"refresh_token": self._token,
"client_id": "public",
},
)
if response.status_code == 200:
access_token = response.json().get("access_token")
self._headers["Authorization"] = f"Bearer {access_token}"
else:
raise RequestFailedError(
"Could not retrieve access token. Please check that your API key is correct."
)

def _request(self, method: str, path: str, headers: Dict = None, **kwargs):
"""Wrap all API requests with an authentication token refresh if a 401 status
is received from the initial request.
Args:
method (str): the HTTP request method to use
path (str): path of the endpoint to use
headers (dict): dictionary containing the headers of the request
Returns:
requests.Response: the response received for the sent request
"""
headers = headers or {}
request_headers = {**headers, **self._headers}
response = requests.request(method, path, headers=request_headers, **kwargs)
if response.status_code == 401:
# Refresh the access_token and retry the request
self._refresh_access_token()
request_headers = {**headers, **self._headers}
response = requests.request(method, path, headers=request_headers, **kwargs)
return response

@staticmethod
def _format_error_message(response: requests.Response) -> str:
body = response.json()
Expand Down
76 changes: 60 additions & 16 deletions tests/api/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
port = 443
"""

test_host = "SomeHost"
test_token = "SomeToken"

class MockResponse:
"""A mock response with a JSON or binary body."""
Expand Down Expand Up @@ -95,7 +97,7 @@ def test_get_device_spec(self, prog, connection, monkeypatch):

monkeypatch.setattr(
requests,
"get",
"request",
mock_return(MockResponse(
200,
{"layout": "", "modes": 42, "compiler": [], "gate_parameters": {"param": [[0, 1]]}}
Expand All @@ -114,7 +116,7 @@ def test_get_device_spec(self, prog, connection, monkeypatch):

def test_get_device_spec_error(self, connection, monkeypatch):
"""Tests a failed device spec request."""
monkeypatch.setattr(requests, "get", mock_return(MockResponse(404, {})))
monkeypatch.setattr(requests, "request", mock_return(MockResponse(404, {})))

with pytest.raises(RequestFailedError, match="Failed to get device specifications"):
connection.get_device_spec("123")
Expand All @@ -124,7 +126,7 @@ def test_create_job(self, prog, connection, monkeypatch):
id_, status = "123", JobStatus.QUEUED

monkeypatch.setattr(
requests, "post", mock_return(MockResponse(201, {"id": id_, "status": status})),
requests, "request", mock_return(MockResponse(201, {"id": id_, "status": status})),
)

job = connection.create_job("X8_01", prog, {"shots": 1})
Expand All @@ -134,7 +136,7 @@ def test_create_job(self, prog, connection, monkeypatch):

def test_create_job_error(self, prog, connection, monkeypatch):
"""Tests a failed job creation flow."""
monkeypatch.setattr(requests, "post", mock_return(MockResponse(400, {})))
monkeypatch.setattr(requests, "request", mock_return(MockResponse(400, {})))

with pytest.raises(RequestFailedError, match="Failed to create job"):
connection.create_job("X8_01", prog, {"shots": 1})
Expand All @@ -151,7 +153,7 @@ def test_get_all_jobs(self, connection, monkeypatch):
for i in range(1, 10)
]
monkeypatch.setattr(
requests, "get", mock_return(MockResponse(200, {"data": jobs})),
requests, "request", mock_return(MockResponse(200, {"data": jobs})),
)

jobs = connection.get_all_jobs(after=datetime(2020, 1, 5))
Expand All @@ -161,7 +163,7 @@ def test_get_all_jobs(self, connection, monkeypatch):
@pytest.mark.xfail(reason="method not yet implemented")
def test_get_all_jobs_error(self, connection, monkeypatch):
"""Tests a failed job list request."""
monkeypatch.setattr(requests, "get", mock_return(MockResponse(404, {})))
monkeypatch.setattr(requests, "request", mock_return(MockResponse(404, {})))

with pytest.raises(RequestFailedError, match="Failed to get all jobs"):
connection.get_all_jobs()
Expand All @@ -172,7 +174,7 @@ def test_get_job(self, connection, monkeypatch):

monkeypatch.setattr(
requests,
"get",
"request",
mock_return(MockResponse(200, {"id": id_, "status": status.value, "meta": meta})),
)

Expand All @@ -184,7 +186,7 @@ def test_get_job(self, connection, monkeypatch):

def test_get_job_error(self, connection, monkeypatch):
"""Tests a failed job request."""
monkeypatch.setattr(requests, "get", mock_return(MockResponse(404, {})))
monkeypatch.setattr(requests, "request", mock_return(MockResponse(404, {})))

with pytest.raises(RequestFailedError, match="Failed to get job"):
connection.get_job("123")
Expand All @@ -195,15 +197,15 @@ def test_get_job_status(self, connection, monkeypatch):

monkeypatch.setattr(
requests,
"get",
"request",
mock_return(MockResponse(200, {"id": id_, "status": status.value, "meta": {}})),
)

assert connection.get_job_status(id_) == status.value

def test_get_job_status_error(self, connection, monkeypatch):
"""Tests a failed job status request."""
monkeypatch.setattr(requests, "get", mock_return(MockResponse(404, {})))
monkeypatch.setattr(requests, "request", mock_return(MockResponse(404, {})))

with pytest.raises(RequestFailedError, match="Failed to get job"):
connection.get_job_status("123")
Expand Down Expand Up @@ -231,7 +233,7 @@ def test_get_job_result(self, connection, result_dtype, monkeypatch):
np.save(buf, result_samples)
buf.seek(0)
monkeypatch.setattr(
requests, "get", mock_return(MockResponse(200, binary_body=buf.getvalue())),
requests, "request", mock_return(MockResponse(200, binary_body=buf.getvalue())),
)

result = connection.get_job_result("123")
Expand All @@ -240,7 +242,7 @@ def test_get_job_result(self, connection, result_dtype, monkeypatch):

def test_get_job_result_error(self, connection, monkeypatch):
"""Tests a failed job result request."""
monkeypatch.setattr(requests, "get", mock_return(MockResponse(404, {})))
monkeypatch.setattr(requests, "request", mock_return(MockResponse(404, {})))

with pytest.raises(RequestFailedError, match="Failed to get job result"):
connection.get_job_result("123")
Expand All @@ -255,30 +257,72 @@ def function(*args, **kwargs):

return function

monkeypatch.setattr(requests, "patch", _mock_return(MockResponse(204, {})))
monkeypatch.setattr(requests, "request", _mock_return(MockResponse(204, {})))

# A successful cancellation does not raise an exception
connection.cancel_job("123")

def test_cancel_job_error(self, connection, monkeypatch):
"""Tests a failed job cancellation request."""
monkeypatch.setattr(requests, "patch", mock_return(MockResponse(404, {})))
monkeypatch.setattr(requests, "request", mock_return(MockResponse(404, {})))

with pytest.raises(RequestFailedError, match="Failed to cancel job"):
connection.cancel_job("123")

def test_ping_success(self, connection, monkeypatch):
"""Tests a successful ping to the remote host."""
monkeypatch.setattr(requests, "get", mock_return(MockResponse(200, {})))
monkeypatch.setattr(requests, "request", mock_return(MockResponse(200, {})))

assert connection.ping()

def test_ping_failure(self, connection, monkeypatch):
"""Tests a failed ping to the remote host."""
monkeypatch.setattr(requests, "get", mock_return(MockResponse(500, {})))
monkeypatch.setattr(requests, "request", mock_return(MockResponse(500, {})))

assert not connection.ping()

def test_refresh_access_token(self, mocker, monkeypatch):
"""Test that the access token is created by passing the expected headers."""
path = "/auth/realms/platform/protocol/openid-connect/token"

data={
"grant_type": "refresh_token",
"refresh_token": test_token,
"client_id": "public",
}

monkeypatch.setattr(requests, "post", mock_return(MockResponse(200, {})))
spy = mocker.spy(requests, "post")

conn = Connection(token=test_token, host=test_host)
conn._refresh_access_token()
expected_headers = {'Accept-Version': conn.api_version}
expected_url = f"https://{test_host}:443{path}"
spy.assert_called_once_with(expected_url, headers=expected_headers, data=data)

def test_refresh_access_token_raises(self, monkeypatch):
"""Test that an error is raised when the access token could not be
generated while creating the Connection object."""
monkeypatch.setattr(requests, "post", mock_return(MockResponse(500, {})))
conn = Connection(token=test_token, host=test_host)
with pytest.raises(RequestFailedError, match="Could not retrieve access token"):
conn._refresh_access_token()

def test_wrapped_request_refreshes(self, mocker, monkeypatch):
"""Test that the _request method refreshes the access token when
getting a 401 response."""
# Mock post function used while refreshing
monkeypatch.setattr(requests, "post", mock_return(MockResponse(200, {})))

# Mock request function used for general requests
monkeypatch.setattr(requests, "request", mock_return(MockResponse(401, {})))

conn = Connection(token=test_token, host=test_host)

spy = mocker.spy(conn, "_refresh_access_token")
conn._request("SomeRequestMethod", "SomePath")
spy.assert_called_once_with()


class TestConnectionIntegration:
"""Integration tests for using instances of the Connection."""
Expand Down

0 comments on commit 7b0a45d

Please sign in to comment.