
Commit efe556a

run tox -e ruff
Signed-off-by: Albert Callarisa <[email protected]>
1 parent 2096e11 commit efe556a

24 files changed: +850 −622 lines changed
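This commit applies the project's ruff formatting environment ("tox -e ruff", per the commit message) across the SDK. The pattern repeated in the diffs below: long signatures and call sites are exploded one argument per line with a trailing comma, single-quoted strings become double-quoted, and imports are sorted and wrapped. A minimal before/after sketch of that style on a hypothetical function (not part of the SDK):

# Before: hanging indentation aligned under the opening parenthesis, single quotes.
# def schedule(name, *,
#              instance_id=None):
#     return (name, 'pending', instance_id)

# After ruff: one parameter per line, trailing comma, double quotes.
def schedule(
    name,
    *,
    instance_id=None,
):
    return (name, "pending", instance_id)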

durabletask/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -3,5 +3,4 @@
 
 """Durable Task SDK for Python"""
 
-
 PACKAGE_NAME = "durabletask"

durabletask/aio/client.py

Lines changed: 66 additions & 44 deletions
@@ -13,22 +13,29 @@
 import durabletask.internal.orchestrator_service_pb2 as pb
 import durabletask.internal.orchestrator_service_pb2_grpc as stubs
 import durabletask.internal.shared as shared
-from durabletask.aio.internal.shared import get_grpc_aio_channel, ClientInterceptor
 from durabletask import task
-from durabletask.client import OrchestrationState, OrchestrationStatus, new_orchestration_state, TInput, TOutput
 from durabletask.aio.internal.grpc_interceptor import DefaultClientInterceptorImpl
+from durabletask.aio.internal.shared import ClientInterceptor, get_grpc_aio_channel
+from durabletask.client import (
+    OrchestrationState,
+    OrchestrationStatus,
+    TInput,
+    TOutput,
+    new_orchestration_state,
+)
 
 
 class AsyncTaskHubGrpcClient:
-
-    def __init__(self, *,
-                 host_address: Optional[str] = None,
-                 metadata: Optional[list[tuple[str, str]]] = None,
-                 log_handler: Optional[logging.Handler] = None,
-                 log_formatter: Optional[logging.Formatter] = None,
-                 secure_channel: bool = False,
-                 interceptors: Optional[Sequence[ClientInterceptor]] = None):
-
+    def __init__(
+        self,
+        *,
+        host_address: Optional[str] = None,
+        metadata: Optional[list[tuple[str, str]]] = None,
+        log_handler: Optional[logging.Handler] = None,
+        log_formatter: Optional[logging.Formatter] = None,
+        secure_channel: bool = False,
+        interceptors: Optional[Sequence[ClientInterceptor]] = None,
+    ):
         if interceptors is not None:
             interceptors = list(interceptors)
         if metadata is not None:
@@ -39,9 +46,7 @@ def __init__(self, *,
             interceptors = None
 
         channel = get_grpc_aio_channel(
-            host_address=host_address,
-            secure_channel=secure_channel,
-            interceptors=interceptors
+            host_address=host_address, secure_channel=secure_channel, interceptors=interceptors
         )
         self._channel = channel
         self._stub = stubs.TaskHubSidecarServiceStub(channel)
@@ -57,18 +62,23 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
         await self.aclose()
         return False
 
-    async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
-                                         input: Optional[TInput] = None,
-                                         instance_id: Optional[str] = None,
-                                         start_at: Optional[datetime] = None,
-                                         reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None) -> str:
-
+    async def schedule_new_orchestration(
+        self,
+        orchestrator: Union[task.Orchestrator[TInput, TOutput], str],
+        *,
+        input: Optional[TInput] = None,
+        instance_id: Optional[str] = None,
+        start_at: Optional[datetime] = None,
+        reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None,
+    ) -> str:
         name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator)
 
         req = pb.CreateInstanceRequest(
             name=name,
             instanceId=instance_id if instance_id else uuid.uuid4().hex,
-            input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None,
+            input=wrappers_pb2.StringValue(value=shared.to_json(input))
+            if input is not None
+            else None,
             scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
             version=helpers.get_string_value(None),
             orchestrationIdReusePolicy=reuse_id_policy,
@@ -78,20 +88,25 @@ async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator
         res: pb.CreateInstanceResponse = await self._stub.StartInstance(req)
         return res.instanceId
 
-    async def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]:
+    async def get_orchestration_state(
+        self, instance_id: str, *, fetch_payloads: bool = True
+    ) -> Optional[OrchestrationState]:
         req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
         res: pb.GetInstanceResponse = await self._stub.GetInstance(req)
         return new_orchestration_state(req.instanceId, res)
 
-    async def wait_for_orchestration_start(self, instance_id: str, *,
-                                           fetch_payloads: bool = False,
-                                           timeout: int = 0) -> Optional[OrchestrationState]:
+    async def wait_for_orchestration_start(
+        self, instance_id: str, *, fetch_payloads: bool = False, timeout: int = 0
+    ) -> Optional[OrchestrationState]:
         req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
         try:
             grpc_timeout = None if timeout == 0 else timeout
             self._logger.info(
-                f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to start.")
-            res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart(req, timeout=grpc_timeout)
+                f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to start."
+            )
+            res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart(
+                req, timeout=grpc_timeout
+            )
             return new_orchestration_state(req.instanceId, res)
         except grpc.RpcError as rpc_error:
             if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:  # type: ignore
@@ -100,22 +115,30 @@ async def wait_for_orchestration_start(self, instance_id: str, *,
             else:
                 raise
 
-    async def wait_for_orchestration_completion(self, instance_id: str, *,
-                                                fetch_payloads: bool = True,
-                                                timeout: int = 0) -> Optional[OrchestrationState]:
+    async def wait_for_orchestration_completion(
+        self, instance_id: str, *, fetch_payloads: bool = True, timeout: int = 0
+    ) -> Optional[OrchestrationState]:
         req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
         try:
             grpc_timeout = None if timeout == 0 else timeout
             self._logger.info(
-                f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete.")
-            res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion(req, timeout=grpc_timeout)
+                f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete."
+            )
+            res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion(
+                req, timeout=grpc_timeout
+            )
             state = new_orchestration_state(req.instanceId, res)
             if not state:
                 return None
 
-            if state.runtime_status == OrchestrationStatus.FAILED and state.failure_details is not None:
+            if (
+                state.runtime_status == OrchestrationStatus.FAILED
+                and state.failure_details is not None
+            ):
                 details = state.failure_details
-                self._logger.info(f"Instance '{instance_id}' failed: [{details.error_type}] {details.message}")
+                self._logger.info(
+                    f"Instance '{instance_id}' failed: [{details.error_type}] {details.message}"
+                )
             elif state.runtime_status == OrchestrationStatus.TERMINATED:
                 self._logger.info(f"Instance '{instance_id}' was terminated.")
             elif state.runtime_status == OrchestrationStatus.COMPLETED:
@@ -130,26 +153,25 @@ async def wait_for_orchestration_completion(self, instance_id: str, *,
                 raise
 
     async def raise_orchestration_event(
-        self,
-        instance_id: str,
-        event_name: str,
-        *,
-        data: Optional[Any] = None):
+        self, instance_id: str, event_name: str, *, data: Optional[Any] = None
+    ):
         req = pb.RaiseEventRequest(
             instanceId=instance_id,
             name=event_name,
-            input=wrappers_pb2.StringValue(value=shared.to_json(data)) if data else None)
+            input=wrappers_pb2.StringValue(value=shared.to_json(data)) if data else None,
+        )
 
         self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
         await self._stub.RaiseEvent(req)
 
-    async def terminate_orchestration(self, instance_id: str, *,
-                                      output: Optional[Any] = None,
-                                      recursive: bool = True):
+    async def terminate_orchestration(
+        self, instance_id: str, *, output: Optional[Any] = None, recursive: bool = True
+    ):
         req = pb.TerminateRequest(
             instanceId=instance_id,
             output=wrappers_pb2.StringValue(value=shared.to_json(output)) if output else None,
-            recursive=recursive)
+            recursive=recursive,
+        )
 
         self._logger.info(f"Terminating instance '{instance_id}'.")
         await self._stub.TerminateInstance(req)
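For orientation, here is a minimal usage sketch of the reformatted async client, based only on the signatures visible in this diff; the host address, orchestrator name, and input are hypothetical and not part of the commit:

import asyncio

from durabletask.aio.client import AsyncTaskHubGrpcClient


async def main() -> None:
    # Assumes a Durable Task sidecar/backend is reachable at this hypothetical address.
    async with AsyncTaskHubGrpcClient(host_address="localhost:4001") as client:
        # "my_orchestrator" is a placeholder; normally this is the orchestrator
        # function registered with a worker, or its registered name.
        instance_id = await client.schedule_new_orchestration("my_orchestrator", input={"x": 1})
        state = await client.wait_for_orchestration_completion(instance_id, timeout=60)
        if state is not None:
            print(state.runtime_status)


if __name__ == "__main__":
    asyncio.run(main())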

durabletask/aio/internal/grpc_interceptor.py

Lines changed: 16 additions & 8 deletions
@@ -7,23 +7,30 @@
 
 
 class _ClientCallDetails(
-        namedtuple(
-            '_ClientCallDetails',
-            ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression']),
-        grpc_aio.ClientCallDetails):
+    namedtuple(
+        "_ClientCallDetails",
+        ["method", "timeout", "metadata", "credentials", "wait_for_ready", "compression"],
+    ),
+    grpc_aio.ClientCallDetails,
+):
     pass
 
 
 class DefaultClientInterceptorImpl(
-        grpc_aio.UnaryUnaryClientInterceptor, grpc_aio.UnaryStreamClientInterceptor,
-        grpc_aio.StreamUnaryClientInterceptor, grpc_aio.StreamStreamClientInterceptor):
+    grpc_aio.UnaryUnaryClientInterceptor,
+    grpc_aio.UnaryStreamClientInterceptor,
+    grpc_aio.StreamUnaryClientInterceptor,
+    grpc_aio.StreamStreamClientInterceptor,
+):
     """Async gRPC client interceptor to add metadata to all calls."""
 
     def __init__(self, metadata: list[tuple[str, str]]):
         super().__init__()
         self._metadata = metadata
 
-    def _intercept_call(self, client_call_details: _ClientCallDetails) -> grpc_aio.ClientCallDetails:
+    def _intercept_call(
+        self, client_call_details: _ClientCallDetails
+    ) -> grpc_aio.ClientCallDetails:
         if self._metadata is None:
             return client_call_details
 
@@ -39,7 +46,8 @@ def _intercept_call(self, client_call_details: _ClientCallDetails) -> grpc_aio.C
             metadata,
             client_call_details.credentials,
             client_call_details.wait_for_ready,
-            client_call_details.compression)
+            client_call_details.compression,
+        )
 
     async def intercept_unary_unary(self, continuation, client_call_details, request):
         new_client_call_details = self._intercept_call(client_call_details)
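The interceptor reformatted above injects static metadata into every async gRPC call; the client diff imports DefaultClientInterceptorImpl, presumably to build one when metadata= is passed (that wiring is only partly visible in this commit). A hedged sketch of constructing it by hand, with a hypothetical header value:

from durabletask.aio.client import AsyncTaskHubGrpcClient
from durabletask.aio.internal.grpc_interceptor import DefaultClientInterceptorImpl


def build_client() -> AsyncTaskHubGrpcClient:
    # Hypothetical auth header; any list of (key, value) tuples works.
    interceptor = DefaultClientInterceptorImpl([("authorization", "Bearer <token>")])
    return AsyncTaskHubGrpcClient(
        host_address="localhost:4001",  # hypothetical sidecar address
        interceptors=[interceptor],
    )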

durabletask/aio/internal/shared.py

Lines changed: 12 additions & 11 deletions
@@ -7,42 +7,43 @@
 from grpc import aio as grpc_aio
 
 from durabletask.internal.shared import (
-    get_default_host_address,
-    SECURE_PROTOCOLS,
     INSECURE_PROTOCOLS,
+    SECURE_PROTOCOLS,
+    get_default_host_address,
 )
 
-
 ClientInterceptor = Union[
     grpc_aio.UnaryUnaryClientInterceptor,
     grpc_aio.UnaryStreamClientInterceptor,
     grpc_aio.StreamUnaryClientInterceptor,
-    grpc_aio.StreamStreamClientInterceptor
+    grpc_aio.StreamStreamClientInterceptor,
 ]
 
 
 def get_grpc_aio_channel(
-        host_address: Optional[str],
-        secure_channel: bool = False,
-        interceptors: Optional[Sequence[ClientInterceptor]] = None) -> grpc_aio.Channel:
-
+    host_address: Optional[str],
+    secure_channel: bool = False,
+    interceptors: Optional[Sequence[ClientInterceptor]] = None,
+) -> grpc_aio.Channel:
     if host_address is None:
         host_address = get_default_host_address()
 
     for protocol in SECURE_PROTOCOLS:
         if host_address.lower().startswith(protocol):
             secure_channel = True
-            host_address = host_address[len(protocol):]
+            host_address = host_address[len(protocol) :]
             break
 
     for protocol in INSECURE_PROTOCOLS:
         if host_address.lower().startswith(protocol):
             secure_channel = False
-            host_address = host_address[len(protocol):]
+            host_address = host_address[len(protocol) :]
             break
 
     if secure_channel:
-        channel = grpc_aio.secure_channel(host_address, grpc.ssl_channel_credentials(), interceptors=interceptors)
+        channel = grpc_aio.secure_channel(
+            host_address, grpc.ssl_channel_credentials(), interceptors=interceptors
+        )
     else:
         channel = grpc_aio.insecure_channel(host_address, interceptors=interceptors)
 