Skip to content

Commit

Permalink
Merge branch 'typed-search-attributes' of https://github.com/cretz/te…
Browse files Browse the repository at this point in the history
…mporal-sdk-python into typed-search-attributes
  • Loading branch information
cretz committed Oct 30, 2023
2 parents 66187ca + 1b5ce24 commit dd6240f
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 78 deletions.
106 changes: 35 additions & 71 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4103,7 +4103,6 @@ def workflow_run_id(self) -> Optional[str]:
async def result(
self,
*,
timeout: Optional[timedelta] = None,
rpc_metadata: Mapping[str, str] = {},
rpc_timeout: Optional[timedelta] = None,
) -> LocalReturnType:
Expand All @@ -4112,7 +4111,6 @@ async def result(
specified.
Args:
timeout: Optional timeout specifying maximum wait time for the result.
rpc_metadata: Headers used on the RPC call. Keys here override client-level RPC metadata keys.
rpc_timeout: Optional RPC deadline to set for the RPC call. If this elapses, the poll is retried until the
overall timeout has been reached.
Expand All @@ -4131,18 +4129,43 @@ async def result(
self._result_type,
)

return await self._client._impl.poll_workflow_update(
PollWorkflowUpdateInput(
self.workflow_id,
self.workflow_run_id,
self.id,
timeout,
self._result_type,
rpc_metadata,
rpc_timeout,
)
req = temporalio.api.workflowservice.v1.PollWorkflowExecutionUpdateRequest(
namespace=self._client.namespace,
update_ref=temporalio.api.update.v1.UpdateRef(
workflow_execution=temporalio.api.common.v1.WorkflowExecution(
workflow_id=self.workflow_id,
run_id=self.workflow_run_id or "",
),
update_id=self.id,
),
identity=self._client.identity,
wait_policy=temporalio.api.update.v1.WaitPolicy(
lifecycle_stage=temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_COMPLETED
),
)

# Continue polling as long as we have either an empty response, or an *rpc* timeout
while True:
try:
res = (
await self._client.workflow_service.poll_workflow_execution_update(
req,
retry=True,
metadata=rpc_metadata,
timeout=rpc_timeout,
)
)
if res.HasField("outcome"):
return await _update_outcome_to_result(
res.outcome,
self.id,
self._client.data_converter,
self._result_type,
)
except RPCError as err:
if err.status != RPCStatusCode.DEADLINE_EXCEEDED:
raise


class WorkflowFailureError(temporalio.exceptions.TemporalError):
"""Error that occurs when a workflow is unsuccessful."""
Expand Down Expand Up @@ -4369,19 +4392,6 @@ class StartWorkflowUpdateInput:
rpc_timeout: Optional[timedelta]


@dataclass
class PollWorkflowUpdateInput:
"""Input for :py:meth:`OutboundInterceptor.poll_workflow_update`."""

workflow_id: str
run_id: Optional[str]
update_id: str
timeout: Optional[timedelta]
ret_type: Optional[Type]
rpc_metadata: Mapping[str, str]
rpc_timeout: Optional[timedelta]


@dataclass
class HeartbeatAsyncActivityInput:
"""Input for :py:meth:`OutboundInterceptor.heartbeat_async_activity`."""
Expand Down Expand Up @@ -4636,10 +4646,6 @@ async def start_workflow_update(
"""Called for every :py:meth:`WorkflowHandle.update` and :py:meth:`WorkflowHandle.start_update` call."""
return await self.next.start_workflow_update(input)

async def poll_workflow_update(self, input: PollWorkflowUpdateInput) -> Any:
"""May be called when calling :py:meth:`WorkflowUpdateHandle.result`."""
return await self.next.poll_workflow_update(input)

### Async activity calls

async def heartbeat_async_activity(
Expand Down Expand Up @@ -5017,48 +5023,6 @@ async def start_workflow_update(

return update_handle

async def poll_workflow_update(self, input: PollWorkflowUpdateInput) -> Any:
req = temporalio.api.workflowservice.v1.PollWorkflowExecutionUpdateRequest(
namespace=self._client.namespace,
update_ref=temporalio.api.update.v1.UpdateRef(
workflow_execution=temporalio.api.common.v1.WorkflowExecution(
workflow_id=input.workflow_id,
run_id=input.run_id or "",
),
update_id=input.update_id,
),
identity=self._client.identity,
wait_policy=temporalio.api.update.v1.WaitPolicy(
lifecycle_stage=temporalio.api.enums.v1.UpdateWorkflowExecutionLifecycleStage.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_COMPLETED
),
)

async def poll_loop():
# Continue polling as long as we have either an empty response, or an *rpc* timeout
while True:
try:
res = await self._client.workflow_service.poll_workflow_execution_update(
req,
retry=True,
metadata=input.rpc_metadata,
timeout=input.rpc_timeout,
)
if res.HasField("outcome"):
return await _update_outcome_to_result(
res.outcome,
input.update_id,
self._client.data_converter,
input.ret_type,
)
except RPCError as err:
if err.status != RPCStatusCode.DEADLINE_EXCEEDED:
raise

# Wait for at most the *overall* timeout
return await asyncio.wait_for(
poll_loop(), input.timeout.total_seconds() if input.timeout else None
)

### Async activity calls

async def heartbeat_async_activity(
Expand Down
7 changes: 0 additions & 7 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
Client,
Interceptor,
OutboundInterceptor,
PollWorkflowUpdateInput,
QueryWorkflowInput,
RPCError,
RPCStatusCode,
Expand Down Expand Up @@ -467,12 +466,6 @@ async def start_workflow_update(
self._parent.traces.append(("start_workflow_update", input))
return await super().start_workflow_update(input)

async def poll_workflow_update(
self, input: PollWorkflowUpdateInput
) -> WorkflowUpdateHandle[Any]:
self._parent.traces.append(("poll_workflow_update", input))
return await super().poll_workflow_update(input)


async def test_interceptor(client: Client, worker: ExternalWorker):
# Create new client from existing client but with a tracing interceptor
Expand Down

0 comments on commit dd6240f

Please sign in to comment.