diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md index 464ea3c3c..a0ed19242 100644 --- a/.github/CHANGELOG.md +++ b/.github/CHANGELOG.md @@ -31,6 +31,9 @@

Improvements

+* Add details to the error message for failed remote jobs. + [(#370)](https://github.com/XanaduAI/strawberryfields/pull/370) +

Bug fixes

Contributors

diff --git a/strawberryfields/api/connection.py b/strawberryfields/api/connection.py index d1b5cd5eb..e2a3ded07 100644 --- a/strawberryfields/api/connection.py +++ b/strawberryfields/api/connection.py @@ -193,6 +193,7 @@ def get_job(self, job_id: str) -> Job: id_=response.json()["id"], status=JobStatus(response.json()["status"]), connection=self, + meta=response.json()["meta"], ) raise RequestFailedError( "Failed to get job: {}".format(self._format_error_message(response)) diff --git a/strawberryfields/api/job.py b/strawberryfields/api/job.py index ef1a063e9..672130697 100644 --- a/strawberryfields/api/job.py +++ b/strawberryfields/api/job.py @@ -73,13 +73,15 @@ class Job: status (strawberryfields.api.JobStatus): the job status connection (strawberryfields.api.Connection): the connection over which the job is managed + meta (dict[str, str]): metadata related to job execution """ - def __init__(self, id_: str, status: JobStatus, connection): + def __init__(self, id_: str, status: JobStatus, connection, meta: dict = None): self._id = id_ self._status = status self._connection = connection self._result = None + self._meta = meta or {} self.log = create_logger(__name__) @@ -119,14 +121,27 @@ def result(self) -> Result: ) return self._result + @property + def meta(self) -> dict: + """Returns a dictionary of metadata on job execution, such as error + details. + + Returns: + dict[str, str] + """ + return self._meta + def refresh(self): - """Refreshes the status of an open or queued job, + """Refreshes the status and metadata of an open or queued job, along with the job result if the job is newly completed. """ if self._status.is_final: self.log.warning("A %s job cannot be refreshed", self._status.value) return - self._status = JobStatus(self._connection.get_job_status(self.id)) + job_info = self._connection.get_job(self.id) + self._status = JobStatus(job_info.status) + self._meta = job_info.meta + self.log.debug("Job %s metadata: %s", self.id, job_info.meta) if self._status == JobStatus.COMPLETED: self._result = self._connection.get_job_result(self.id) diff --git a/strawberryfields/engine.py b/strawberryfields/engine.py index 70f95d32f..226afaf11 100644 --- a/strawberryfields/engine.py +++ b/strawberryfields/engine.py @@ -547,8 +547,8 @@ def run(self, program: Program, *, compile_options=None, **kwargs) -> Optional[R if job.status == "failed": message = ( - "The remote job %s failed due to an internal " - "server error. Please try again." % job.id + "The remote job {} failed due to an internal " + "server error. Please try again. {}".format(job.id, job.meta) ) self.log.error(message) diff --git a/tests/api/test_connection.py b/tests/api/test_connection.py index bfc531a07..3be3b2d6b 100644 --- a/tests/api/test_connection.py +++ b/tests/api/test_connection.py @@ -132,16 +132,19 @@ def test_get_all_jobs_error(self, connection, monkeypatch): def test_get_job(self, connection, monkeypatch): """Tests a successful job request.""" - id_, status = "123", JobStatus.COMPLETED + id_, status, meta = "123", JobStatus.COMPLETED, {"abc": "def"} monkeypatch.setattr( - requests, "get", mock_return(MockResponse(200, {"id": id_, "status": status.value})), + requests, + "get", + mock_return(MockResponse(200, {"id": id_, "status": status.value, "meta": meta})), ) job = connection.get_job(id_) assert job.id == id_ assert job.status == status.value + assert job.meta == meta def test_get_job_error(self, connection, monkeypatch): """Tests a failed job request.""" @@ -155,7 +158,9 @@ def test_get_job_status(self, connection, monkeypatch): id_, status = "123", JobStatus.COMPLETED monkeypatch.setattr( - requests, "get", mock_return(MockResponse(200, {"id": id_, "status": status.value})), + requests, + "get", + mock_return(MockResponse(200, {"id": id_, "status": status.value, "meta": {}})), ) assert connection.get_job_status(id_) == status.value @@ -238,6 +243,7 @@ def test_ping_failure(self, connection, monkeypatch): assert not connection.ping() + class TestConnectionIntegration: """Integration tests for using instances of the Connection.""" diff --git a/tests/api/test_remote_engine.py b/tests/api/test_remote_engine.py index 396cbef8a..e7068b060 100644 --- a/tests/api/test_remote_engine.py +++ b/tests/api/test_remote_engine.py @@ -35,16 +35,17 @@ class MockServer: def __init__(self): self.request_count = 0 - def get_job_status(self, _id): + def get_job(self, _id): """Returns a 'queued' job status until the number of requests exceeds a defined threshold, beyond which a 'complete' job status is returned. """ self.request_count += 1 - return ( + status = ( JobStatus.COMPLETED if self.request_count >= self.REQUESTS_BEFORE_COMPLETED else JobStatus.QUEUED ) + return Job(id_="123", status=status, connection=None, meta={"foo": "bar"}) @pytest.fixture @@ -56,7 +57,7 @@ def job_to_complete(connection, monkeypatch): mock_return(Job(id_="123", status=JobStatus.OPEN, connection=connection)), ) server = MockServer() - monkeypatch.setattr(Connection, "get_job_status", server.get_job_status) + monkeypatch.setattr(Connection, "get_job", server.get_job) monkeypatch.setattr( Connection, "get_job_result", @@ -90,6 +91,7 @@ def test_run_async(self, connection, prog, job_to_complete): job.refresh() assert job.status == JobStatus.COMPLETED.value + assert job.meta == {"foo": "bar"} assert np.array_equal(job.result.samples, np.array([[1, 2], [3, 4]])) with pytest.raises(