Skip to content

Commit

Permalink
fix(agent): use context.abort() instead of returning
Browse files Browse the repository at this point in the history
We stopped always yielding stuff unconditionally, so now if you get
an error and use our own `abort_with_msg` we'll get an error
iterating responses, as we never yield anything and just return `None`.
So instead just use proper grpc `ServicerContext.abort()` method to
raise an exception instead.
  • Loading branch information
efiop committed May 23, 2024
1 parent 2051e12 commit df73fb8
Showing 1 changed file with 6 additions and 18 deletions.
24 changes: 6 additions & 18 deletions src/isolate/connections/grpc/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from dataclasses import dataclass
from typing import (
Any,
Generator,
Iterable,
Iterator,
)
Expand Down Expand Up @@ -66,15 +65,15 @@ def Run(
self.log(
"The setup function has thrown an error. Aborting the run."
)
yield from self.send_object(
yield self.send_object(
request.setup_func.method,
result,
was_it_raised,
stringized_tb,
)
raise AbortException("The setup function has thrown an error.")
except AbortException as exc:
return self.abort_with_msg(context, exc.message)
context.abort(StatusCode.INVALID_ARGUMENT, exc.message)
else:
assert not was_it_raised
self._run_cache[cache_key] = result
Expand All @@ -87,14 +86,14 @@ def Run(
"function",
extra_args=extra_args,
)
yield from self.send_object(
yield self.send_object(
request.function.method,
result,
was_it_raised,
stringized_tb,
)
except AbortException as exc:
return self.abort_with_msg(context, exc.message)
context.abort(StatusCode.INVALID_ARGUMENT, exc.message)

def execute_function(
self,
Expand Down Expand Up @@ -143,7 +142,7 @@ def send_object(
result: object,
was_it_raised: bool,
stringized_tb: str | None,
) -> Generator[definitions.PartialRunResult, None, Any]:
) -> definitions.PartialRunResult:
try:
definition = serialize_object(serialization_method, result)
except SerializationError:
Expand All @@ -166,7 +165,7 @@ def send_object(
was_it_raised=was_it_raised,
stringized_traceback=stringized_tb,
)
yield definitions.PartialRunResult(
return definitions.PartialRunResult(
result=serialized_obj,
is_complete=True,
logs=[],
Expand All @@ -176,17 +175,6 @@ def log(self, message: str) -> None:
self._log.write(message)
self._log.flush()

def abort_with_msg(
self,
context: ServicerContext,
message: str,
*,
code: StatusCode = StatusCode.INVALID_ARGUMENT,
) -> None:
context.set_code(code)
context.set_details(message)
return None


def create_server(address: str) -> grpc.Server:
"""Create a new (temporary) gRPC server listening on the given
Expand Down

0 comments on commit df73fb8

Please sign in to comment.