1313import durabletask .internal .orchestrator_service_pb2 as pb
1414import durabletask .internal .orchestrator_service_pb2_grpc as stubs
1515import durabletask .internal .shared as shared
16- from durabletask .aio .internal .shared import get_grpc_aio_channel , ClientInterceptor
1716from durabletask import task
18- from durabletask .client import OrchestrationState , OrchestrationStatus , new_orchestration_state , TInput , TOutput
1917from durabletask .aio .internal .grpc_interceptor import DefaultClientInterceptorImpl
18+ from durabletask .aio .internal .shared import ClientInterceptor , get_grpc_aio_channel
19+ from durabletask .client import (
20+ OrchestrationState ,
21+ OrchestrationStatus ,
22+ TInput ,
23+ TOutput ,
24+ new_orchestration_state ,
25+ )
2026
2127
2228class AsyncTaskHubGrpcClient :
23-
24- def __init__ (self , * ,
25- host_address : Optional [str ] = None ,
26- metadata : Optional [list [tuple [str , str ]]] = None ,
27- log_handler : Optional [logging .Handler ] = None ,
28- log_formatter : Optional [logging .Formatter ] = None ,
29- secure_channel : bool = False ,
30- interceptors : Optional [Sequence [ClientInterceptor ]] = None ):
31-
29+ def __init__ (
30+ self ,
31+ * ,
32+ host_address : Optional [str ] = None ,
33+ metadata : Optional [list [tuple [str , str ]]] = None ,
34+ log_handler : Optional [logging .Handler ] = None ,
35+ log_formatter : Optional [logging .Formatter ] = None ,
36+ secure_channel : bool = False ,
37+ interceptors : Optional [Sequence [ClientInterceptor ]] = None ,
38+ ):
3239 if interceptors is not None :
3340 interceptors = list (interceptors )
3441 if metadata is not None :
@@ -39,9 +46,7 @@ def __init__(self, *,
3946 interceptors = None
4047
4148 channel = get_grpc_aio_channel (
42- host_address = host_address ,
43- secure_channel = secure_channel ,
44- interceptors = interceptors
49+ host_address = host_address , secure_channel = secure_channel , interceptors = interceptors
4550 )
4651 self ._channel = channel
4752 self ._stub = stubs .TaskHubSidecarServiceStub (channel )
@@ -57,18 +62,23 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
5762 await self .aclose ()
5863 return False
5964
60- async def schedule_new_orchestration (self , orchestrator : Union [task .Orchestrator [TInput , TOutput ], str ], * ,
61- input : Optional [TInput ] = None ,
62- instance_id : Optional [str ] = None ,
63- start_at : Optional [datetime ] = None ,
64- reuse_id_policy : Optional [pb .OrchestrationIdReusePolicy ] = None ) -> str :
65-
65+ async def schedule_new_orchestration (
66+ self ,
67+ orchestrator : Union [task .Orchestrator [TInput , TOutput ], str ],
68+ * ,
69+ input : Optional [TInput ] = None ,
70+ instance_id : Optional [str ] = None ,
71+ start_at : Optional [datetime ] = None ,
72+ reuse_id_policy : Optional [pb .OrchestrationIdReusePolicy ] = None ,
73+ ) -> str :
6674 name = orchestrator if isinstance (orchestrator , str ) else task .get_name (orchestrator )
6775
6876 req = pb .CreateInstanceRequest (
6977 name = name ,
7078 instanceId = instance_id if instance_id else uuid .uuid4 ().hex ,
71- input = wrappers_pb2 .StringValue (value = shared .to_json (input )) if input is not None else None ,
79+ input = wrappers_pb2 .StringValue (value = shared .to_json (input ))
80+ if input is not None
81+ else None ,
7282 scheduledStartTimestamp = helpers .new_timestamp (start_at ) if start_at else None ,
7383 version = helpers .get_string_value (None ),
7484 orchestrationIdReusePolicy = reuse_id_policy ,
@@ -78,20 +88,25 @@ async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator
7888 res : pb .CreateInstanceResponse = await self ._stub .StartInstance (req )
7989 return res .instanceId
8090
81- async def get_orchestration_state (self , instance_id : str , * , fetch_payloads : bool = True ) -> Optional [OrchestrationState ]:
91+ async def get_orchestration_state (
92+ self , instance_id : str , * , fetch_payloads : bool = True
93+ ) -> Optional [OrchestrationState ]:
8294 req = pb .GetInstanceRequest (instanceId = instance_id , getInputsAndOutputs = fetch_payloads )
8395 res : pb .GetInstanceResponse = await self ._stub .GetInstance (req )
8496 return new_orchestration_state (req .instanceId , res )
8597
86- async def wait_for_orchestration_start (self , instance_id : str , * ,
87- fetch_payloads : bool = False ,
88- timeout : int = 0 ) -> Optional [OrchestrationState ]:
98+ async def wait_for_orchestration_start (
99+ self , instance_id : str , * , fetch_payloads : bool = False , timeout : int = 0
100+ ) -> Optional [OrchestrationState ]:
89101 req = pb .GetInstanceRequest (instanceId = instance_id , getInputsAndOutputs = fetch_payloads )
90102 try :
91103 grpc_timeout = None if timeout == 0 else timeout
92104 self ._logger .info (
93- f"Waiting { 'indefinitely' if timeout == 0 else f'up to { timeout } s' } for instance '{ instance_id } ' to start." )
94- res : pb .GetInstanceResponse = await self ._stub .WaitForInstanceStart (req , timeout = grpc_timeout )
105+ f"Waiting { 'indefinitely' if timeout == 0 else f'up to { timeout } s' } for instance '{ instance_id } ' to start."
106+ )
107+ res : pb .GetInstanceResponse = await self ._stub .WaitForInstanceStart (
108+ req , timeout = grpc_timeout
109+ )
95110 return new_orchestration_state (req .instanceId , res )
96111 except grpc .RpcError as rpc_error :
97112 if rpc_error .code () == grpc .StatusCode .DEADLINE_EXCEEDED : # type: ignore
@@ -100,22 +115,30 @@ async def wait_for_orchestration_start(self, instance_id: str, *,
100115 else :
101116 raise
102117
103- async def wait_for_orchestration_completion (self , instance_id : str , * ,
104- fetch_payloads : bool = True ,
105- timeout : int = 0 ) -> Optional [OrchestrationState ]:
118+ async def wait_for_orchestration_completion (
119+ self , instance_id : str , * , fetch_payloads : bool = True , timeout : int = 0
120+ ) -> Optional [OrchestrationState ]:
106121 req = pb .GetInstanceRequest (instanceId = instance_id , getInputsAndOutputs = fetch_payloads )
107122 try :
108123 grpc_timeout = None if timeout == 0 else timeout
109124 self ._logger .info (
110- f"Waiting { 'indefinitely' if timeout == 0 else f'up to { timeout } s' } for instance '{ instance_id } ' to complete." )
111- res : pb .GetInstanceResponse = await self ._stub .WaitForInstanceCompletion (req , timeout = grpc_timeout )
125+ f"Waiting { 'indefinitely' if timeout == 0 else f'up to { timeout } s' } for instance '{ instance_id } ' to complete."
126+ )
127+ res : pb .GetInstanceResponse = await self ._stub .WaitForInstanceCompletion (
128+ req , timeout = grpc_timeout
129+ )
112130 state = new_orchestration_state (req .instanceId , res )
113131 if not state :
114132 return None
115133
116- if state .runtime_status == OrchestrationStatus .FAILED and state .failure_details is not None :
134+ if (
135+ state .runtime_status == OrchestrationStatus .FAILED
136+ and state .failure_details is not None
137+ ):
117138 details = state .failure_details
118- self ._logger .info (f"Instance '{ instance_id } ' failed: [{ details .error_type } ] { details .message } " )
139+ self ._logger .info (
140+ f"Instance '{ instance_id } ' failed: [{ details .error_type } ] { details .message } "
141+ )
119142 elif state .runtime_status == OrchestrationStatus .TERMINATED :
120143 self ._logger .info (f"Instance '{ instance_id } ' was terminated." )
121144 elif state .runtime_status == OrchestrationStatus .COMPLETED :
@@ -130,26 +153,25 @@ async def wait_for_orchestration_completion(self, instance_id: str, *,
130153 raise
131154
132155 async def raise_orchestration_event (
133- self ,
134- instance_id : str ,
135- event_name : str ,
136- * ,
137- data : Optional [Any ] = None ):
156+ self , instance_id : str , event_name : str , * , data : Optional [Any ] = None
157+ ):
138158 req = pb .RaiseEventRequest (
139159 instanceId = instance_id ,
140160 name = event_name ,
141- input = wrappers_pb2 .StringValue (value = shared .to_json (data )) if data else None )
161+ input = wrappers_pb2 .StringValue (value = shared .to_json (data )) if data else None ,
162+ )
142163
143164 self ._logger .info (f"Raising event '{ event_name } ' for instance '{ instance_id } '." )
144165 await self ._stub .RaiseEvent (req )
145166
146- async def terminate_orchestration (self , instance_id : str , * ,
147- output : Optional [Any ] = None ,
148- recursive : bool = True ):
167+ async def terminate_orchestration (
168+ self , instance_id : str , * , output : Optional [Any ] = None , recursive : bool = True
169+ ):
149170 req = pb .TerminateRequest (
150171 instanceId = instance_id ,
151172 output = wrappers_pb2 .StringValue (value = shared .to_json (output )) if output else None ,
152- recursive = recursive )
173+ recursive = recursive ,
174+ )
153175
154176 self ._logger .info (f"Terminating instance '{ instance_id } '." )
155177 await self ._stub .TerminateInstance (req )
0 commit comments