1313import  durabletask .internal .orchestrator_service_pb2  as  pb 
1414import  durabletask .internal .orchestrator_service_pb2_grpc  as  stubs 
1515import  durabletask .internal .shared  as  shared 
16- from  durabletask .aio .internal .shared  import  get_grpc_aio_channel , ClientInterceptor 
1716from  durabletask  import  task 
18- from  durabletask .client  import  OrchestrationState , OrchestrationStatus , new_orchestration_state , TInput , TOutput 
1917from  durabletask .aio .internal .grpc_interceptor  import  DefaultClientInterceptorImpl 
18+ from  durabletask .aio .internal .shared  import  ClientInterceptor , get_grpc_aio_channel 
19+ from  durabletask .client  import  (
20+     OrchestrationState ,
21+     OrchestrationStatus ,
22+     TInput ,
23+     TOutput ,
24+     new_orchestration_state ,
25+ )
2026
2127
2228class  AsyncTaskHubGrpcClient :
23- 
24-     def  __init__ (self , * ,
25-                  host_address : Optional [str ] =  None ,
26-                  metadata : Optional [list [tuple [str , str ]]] =  None ,
27-                  log_handler : Optional [logging .Handler ] =  None ,
28-                  log_formatter : Optional [logging .Formatter ] =  None ,
29-                  secure_channel : bool  =  False ,
30-                  interceptors : Optional [Sequence [ClientInterceptor ]] =  None ):
31- 
29+     def  __init__ (
30+         self ,
31+         * ,
32+         host_address : Optional [str ] =  None ,
33+         metadata : Optional [list [tuple [str , str ]]] =  None ,
34+         log_handler : Optional [logging .Handler ] =  None ,
35+         log_formatter : Optional [logging .Formatter ] =  None ,
36+         secure_channel : bool  =  False ,
37+         interceptors : Optional [Sequence [ClientInterceptor ]] =  None ,
38+     ):
3239        if  interceptors  is  not   None :
3340            interceptors  =  list (interceptors )
3441            if  metadata  is  not   None :
@@ -39,9 +46,7 @@ def __init__(self, *,
3946            interceptors  =  None 
4047
4148        channel  =  get_grpc_aio_channel (
42-             host_address = host_address ,
43-             secure_channel = secure_channel ,
44-             interceptors = interceptors 
49+             host_address = host_address , secure_channel = secure_channel , interceptors = interceptors 
4550        )
4651        self ._channel  =  channel 
4752        self ._stub  =  stubs .TaskHubSidecarServiceStub (channel )
@@ -57,18 +62,23 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
5762        await  self .aclose ()
5863        return  False 
5964
60-     async  def  schedule_new_orchestration (self , orchestrator : Union [task .Orchestrator [TInput , TOutput ], str ], * ,
61-                                          input : Optional [TInput ] =  None ,
62-                                          instance_id : Optional [str ] =  None ,
63-                                          start_at : Optional [datetime ] =  None ,
64-                                          reuse_id_policy : Optional [pb .OrchestrationIdReusePolicy ] =  None ) ->  str :
65- 
65+     async  def  schedule_new_orchestration (
66+         self ,
67+         orchestrator : Union [task .Orchestrator [TInput , TOutput ], str ],
68+         * ,
69+         input : Optional [TInput ] =  None ,
70+         instance_id : Optional [str ] =  None ,
71+         start_at : Optional [datetime ] =  None ,
72+         reuse_id_policy : Optional [pb .OrchestrationIdReusePolicy ] =  None ,
73+     ) ->  str :
6674        name  =  orchestrator  if  isinstance (orchestrator , str ) else  task .get_name (orchestrator )
6775
6876        req  =  pb .CreateInstanceRequest (
6977            name = name ,
7078            instanceId = instance_id  if  instance_id  else  uuid .uuid4 ().hex ,
71-             input = wrappers_pb2 .StringValue (value = shared .to_json (input )) if  input  is  not   None  else  None ,
79+             input = wrappers_pb2 .StringValue (value = shared .to_json (input ))
80+             if  input  is  not   None 
81+             else  None ,
7282            scheduledStartTimestamp = helpers .new_timestamp (start_at ) if  start_at  else  None ,
7383            version = helpers .get_string_value (None ),
7484            orchestrationIdReusePolicy = reuse_id_policy ,
@@ -78,20 +88,25 @@ async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator
7888        res : pb .CreateInstanceResponse  =  await  self ._stub .StartInstance (req )
7989        return  res .instanceId 
8090
81-     async  def  get_orchestration_state (self , instance_id : str , * , fetch_payloads : bool  =  True ) ->  Optional [OrchestrationState ]:
91+     async  def  get_orchestration_state (
92+         self , instance_id : str , * , fetch_payloads : bool  =  True 
93+     ) ->  Optional [OrchestrationState ]:
8294        req  =  pb .GetInstanceRequest (instanceId = instance_id , getInputsAndOutputs = fetch_payloads )
8395        res : pb .GetInstanceResponse  =  await  self ._stub .GetInstance (req )
8496        return  new_orchestration_state (req .instanceId , res )
8597
86-     async  def  wait_for_orchestration_start (self ,  instance_id :  str ,  * , 
87-                                             fetch_payloads : bool  =  False ,
88-                                             timeout :  int   =   0 ) ->  Optional [OrchestrationState ]:
98+     async  def  wait_for_orchestration_start (
99+         self ,  instance_id :  str ,  * ,  fetch_payloads : bool  =  False ,  timeout :  int   =   0 
100+     ) ->  Optional [OrchestrationState ]:
89101        req  =  pb .GetInstanceRequest (instanceId = instance_id , getInputsAndOutputs = fetch_payloads )
90102        try :
91103            grpc_timeout  =  None  if  timeout  ==  0  else  timeout 
92104            self ._logger .info (
93-                 f"Waiting { 'indefinitely'  if  timeout  ==  0  else  f'up to { timeout }  s' }   for instance '{ instance_id }  ' to start." )
94-             res : pb .GetInstanceResponse  =  await  self ._stub .WaitForInstanceStart (req , timeout = grpc_timeout )
105+                 f"Waiting { 'indefinitely'  if  timeout  ==  0  else  f'up to { timeout }  s' }   for instance '{ instance_id }  ' to start." 
106+             )
107+             res : pb .GetInstanceResponse  =  await  self ._stub .WaitForInstanceStart (
108+                 req , timeout = grpc_timeout 
109+             )
95110            return  new_orchestration_state (req .instanceId , res )
96111        except  grpc .RpcError  as  rpc_error :
97112            if  rpc_error .code () ==  grpc .StatusCode .DEADLINE_EXCEEDED :  # type: ignore 
@@ -100,22 +115,30 @@ async def wait_for_orchestration_start(self, instance_id: str, *,
100115            else :
101116                raise 
102117
103-     async  def  wait_for_orchestration_completion (self ,  instance_id :  str ,  * , 
104-                                                  fetch_payloads : bool  =  True ,
105-                                                  timeout :  int   =   0 ) ->  Optional [OrchestrationState ]:
118+     async  def  wait_for_orchestration_completion (
119+         self ,  instance_id :  str ,  * ,  fetch_payloads : bool  =  True ,  timeout :  int   =   0 
120+     ) ->  Optional [OrchestrationState ]:
106121        req  =  pb .GetInstanceRequest (instanceId = instance_id , getInputsAndOutputs = fetch_payloads )
107122        try :
108123            grpc_timeout  =  None  if  timeout  ==  0  else  timeout 
109124            self ._logger .info (
110-                 f"Waiting { 'indefinitely'  if  timeout  ==  0  else  f'up to { timeout }  s' }   for instance '{ instance_id }  ' to complete." )
111-             res : pb .GetInstanceResponse  =  await  self ._stub .WaitForInstanceCompletion (req , timeout = grpc_timeout )
125+                 f"Waiting { 'indefinitely'  if  timeout  ==  0  else  f'up to { timeout }  s' }   for instance '{ instance_id }  ' to complete." 
126+             )
127+             res : pb .GetInstanceResponse  =  await  self ._stub .WaitForInstanceCompletion (
128+                 req , timeout = grpc_timeout 
129+             )
112130            state  =  new_orchestration_state (req .instanceId , res )
113131            if  not  state :
114132                return  None 
115133
116-             if  state .runtime_status  ==  OrchestrationStatus .FAILED  and  state .failure_details  is  not   None :
134+             if  (
135+                 state .runtime_status  ==  OrchestrationStatus .FAILED 
136+                 and  state .failure_details  is  not   None 
137+             ):
117138                details  =  state .failure_details 
118-                 self ._logger .info (f"Instance '{ instance_id }  ' failed: [{ details .error_type }  ] { details .message }  " )
139+                 self ._logger .info (
140+                     f"Instance '{ instance_id }  ' failed: [{ details .error_type }  ] { details .message }  " 
141+                 )
119142            elif  state .runtime_status  ==  OrchestrationStatus .TERMINATED :
120143                self ._logger .info (f"Instance '{ instance_id }  ' was terminated." )
121144            elif  state .runtime_status  ==  OrchestrationStatus .COMPLETED :
@@ -130,26 +153,25 @@ async def wait_for_orchestration_completion(self, instance_id: str, *,
130153                raise 
131154
132155    async  def  raise_orchestration_event (
133-             self ,
134-             instance_id : str ,
135-             event_name : str ,
136-             * ,
137-             data : Optional [Any ] =  None ):
156+         self , instance_id : str , event_name : str , * , data : Optional [Any ] =  None 
157+     ):
138158        req  =  pb .RaiseEventRequest (
139159            instanceId = instance_id ,
140160            name = event_name ,
141-             input = wrappers_pb2 .StringValue (value = shared .to_json (data )) if  data  else  None )
161+             input = wrappers_pb2 .StringValue (value = shared .to_json (data )) if  data  else  None ,
162+         )
142163
143164        self ._logger .info (f"Raising event '{ event_name }  ' for instance '{ instance_id }  '." )
144165        await  self ._stub .RaiseEvent (req )
145166
146-     async  def  terminate_orchestration (self ,  instance_id :  str ,  * , 
147-                                        output : Optional [Any ] =  None ,
148-                                        recursive :  bool   =   True ):
167+     async  def  terminate_orchestration (
168+         self ,  instance_id :  str ,  * ,  output : Optional [Any ] =  None ,  recursive :  bool   =   True 
169+     ):
149170        req  =  pb .TerminateRequest (
150171            instanceId = instance_id ,
151172            output = wrappers_pb2 .StringValue (value = shared .to_json (output )) if  output  else  None ,
152-             recursive = recursive )
173+             recursive = recursive ,
174+         )
153175
154176        self ._logger .info (f"Terminating instance '{ instance_id }  '." )
155177        await  self ._stub .TerminateInstance (req )
0 commit comments