Skip to content

Commit

Permalink
Merge pull request #43 from Doist/proxi/add-types-2
Browse files Browse the repository at this point in the history
chore: Partial typing.
  • Loading branch information
proxi authored Oct 24, 2023
2 parents 1621442 + 42e6092 commit 5e661ae
Show file tree
Hide file tree
Showing 11 changed files with 63 additions and 59 deletions.
4 changes: 2 additions & 2 deletions sqs_workers/async_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ def __init__(self, async_task, args, kwargs):
self.args = args
self.kwargs = kwargs

def __call__(self):
def __call__(self) -> None:
self.async_task(*self.args, **self.kwargs)

def delay(self):
def delay(self) -> None:
self.async_task.delay(*self.args, **self.kwargs)

def __repr__(self) -> str:
Expand Down
15 changes: 10 additions & 5 deletions sqs_workers/backoff_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ class ConstantBackoff(BackoffPolicy):
immediately on failure
"""

def __init__(self, backoff_value=0):
def __init__(self, backoff_value: float = 0):
self.backoff_value = backoff_value

def get_visibility_timeout(self, message):
def get_visibility_timeout(self, message) -> float:
return self.backoff_value


Expand All @@ -25,14 +25,19 @@ class ExponentialBackoff(BackoffPolicy):
with an exponential backoff
"""

def __init__(self, base=2, min_visibility_timeout=0, max_visbility_timeout=30 * 60):
def __init__(
self,
base: float = 2,
min_visibility_timeout: float = 0,
max_visbility_timeout: float = 30 * 60,
) -> None:
self.base = base # in seconds
self.min_visibility_timeout = min_visibility_timeout
self.max_visibility_timeout = max_visbility_timeout

def get_visibility_timeout(self, message):
def get_visibility_timeout(self, message) -> int:
prev_receive_count = int(message.attributes["ApproximateReceiveCount"]) - 1
mu = self.min_visibility_timeout + (self.base ** prev_receive_count)
mu = self.min_visibility_timeout + (self.base**prev_receive_count)
sigma = float(mu) / 10
visibility_timeout = random.normalvariate(mu, sigma)
visibility_timeout = max(self.min_visibility_timeout, visibility_timeout)
Expand Down
2 changes: 1 addition & 1 deletion sqs_workers/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def batch_size(self) -> int:
class BatchMessages(BatchingConfiguration):
"""Configures the processor to send a list of messages to the call handler"""

def __init__(self, batch_size):
def __init__(self, batch_size: int):
self.number_of_messages = batch_size

@property
Expand Down
2 changes: 1 addition & 1 deletion sqs_workers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Config(object):
options: Dict[str, Any] = attr.ib(factory=dict)
maker_key = attr.ib(default="maker")

def __setitem__(self, key, value):
def __setitem__(self, key: str, value):
self.options.__setitem__(key, value)

def __getitem__(self, item):
Expand Down
16 changes: 8 additions & 8 deletions sqs_workers/core.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import json
import logging
from typing import Any
from typing import Any, Optional

logger = logging.getLogger(__name__)


class BatchProcessingResult(object):
def __init__(self, queue_name, succeeded=None, failed=None):
def __init__(self, queue_name: str, succeeded=None, failed=None):
self.queue_name = queue_name
self.succeeded = succeeded or []
self.failed = failed or []
Expand All @@ -20,16 +20,16 @@ def update_with_message(self, message: Any, success: bool):
else:
self.failed.append(message)

def succeeded_count(self):
def succeeded_count(self) -> int:
return len(self.succeeded)

def failed_count(self):
def failed_count(self) -> int:
return len(self.failed)

def total_count(self):
def total_count(self) -> int:
return self.succeeded_count() + self.failed_count()

def __repr__(self):
def __repr__(self) -> str:
return "<BatchProcessingResult/%s/%s/%s>" % (
self.queue_name,
self.succeeded_count(),
Expand All @@ -50,7 +50,7 @@ def __init__(self, sqs_env, dead_letter_queue_name, max_receive_count):
self.dead_letter_queue_name = dead_letter_queue_name
self.max_receive_count = max_receive_count

def __json__(self):
def __json__(self) -> str:
queue = self.sqs_env.sqs_resource.get_queue_by_name(
QueueName=self.sqs_env.get_sqs_queue_name(self.dead_letter_queue_name)
)
Expand All @@ -64,6 +64,6 @@ def __json__(self):
)


def get_job_name(message):
def get_job_name(message) -> Optional[str]:
attrs = message.message_attributes or {}
return (attrs.get("JobName") or {}).get("StringValue")
18 changes: 8 additions & 10 deletions sqs_workers/memory_sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ class MemoryAWS(object):
)
queues: List["MemoryQueue"] = attr.ib(factory=list)

def create_queue(self, QueueName, Attributes):
def create_queue(self, QueueName: str, Attributes) -> "MemoryQueue":
queue = MemoryQueue(self, QueueName, Attributes)
self.queues.append(queue)
return queue

def delete_queue(self, QueueUrl):
def delete_queue(self, QueueUrl: str) -> None:
self.queues = [queue for queue in self.queues if queue.url != QueueUrl]


Expand All @@ -55,24 +55,23 @@ class MemorySession(object):

aws = attr.ib(repr=False, factory=MemoryAWS)

def client(self, service_name, **kwargs):
def client(self, service_name: str, **kwargs):
assert service_name == "sqs"
return self.aws.client

def resource(self, service_name, **kwargs):
def resource(self, service_name: str, **kwargs):
assert service_name == "sqs"
return self.aws.resource


@attr.s
class MemoryClient(object):

aws = attr.ib(repr=False)

def create_queue(self, QueueName, Attributes):
def create_queue(self, QueueName: str, Attributes):
return self.aws.create_queue(QueueName, Attributes)

def delete_queue(self, QueueUrl):
def delete_queue(self, QueueUrl: str):
return self.aws.delete_queue(QueueUrl)

def list_queues(self, QueueNamePrefix=""):
Expand All @@ -87,13 +86,12 @@ def list_queues(self, QueueNamePrefix=""):

@attr.s
class ServiceResource(object):

aws: MemoryAWS = attr.ib(repr=False)

def create_queue(self, QueueName, Attributes):
def create_queue(self, QueueName: str, Attributes):
return self.aws.create_queue(QueueName, Attributes)

def get_queue_by_name(self, QueueName):
def get_queue_by_name(self, QueueName: str):
for queue in self.aws.queues:
if queue.name == QueueName:
return queue
Expand Down
17 changes: 9 additions & 8 deletions sqs_workers/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Dict,
Generator,
List,
Literal,
Optional,
TypeVar,
)
Expand Down Expand Up @@ -197,7 +198,7 @@ def drain_queue(self, wait_seconds=0):
deleted_count += len(messages)
return deleted_count

def get_sqs_queue_name(self):
def get_sqs_queue_name(self) -> str:
"""
Take "high-level" (user-visible) queue name and return SQS
("low level") name by simply prefixing it. Used to create namespaces
Expand Down Expand Up @@ -423,19 +424,19 @@ def connect_processor(
)
return AsyncTask(self, job_name, processor)

def get_processor(self, job_name):
def get_processor(self, job_name: str) -> Optional[Processor]:
"""
Helper function to return a processor for the queue
"""
return self.processors.get(job_name)

def open_add_batch(self):
def open_add_batch(self) -> None:
"""
Open an add batch.
"""
self._batch_level += 1

def close_add_batch(self):
def close_add_batch(self) -> None:
"""
Close an add batch.
"""
Expand Down Expand Up @@ -498,8 +499,8 @@ def _flush_batch_if_needed(self) -> None:
def add_job(
self,
job_name: str,
_content_type=None,
_delay_seconds=None,
_content_type: Optional[str] = None,
_delay_seconds: Optional[int] = None,
_deduplication_id=None,
_group_id: Optional[str] = None,
**job_kwargs
Expand Down Expand Up @@ -603,7 +604,7 @@ def process_messages(self, messages: List[Any]) -> bool:

return all(results)

def process_message_fallback(self, job_name):
def process_message_fallback(self, job_name: str) -> Literal[False]:
# Note: we don't set exc_info=True, since source of the error is almost
# certainly in another code base.
logger.error(
Expand All @@ -614,7 +615,7 @@ def process_message_fallback(self, job_name):
)
return False

def copy_processors(self, dst_queue: "JobQueue"):
def copy_processors(self, dst_queue: "JobQueue") -> None:
"""
Copy processors from self to dst_queue. Can be helpful to process d
ead-letter queue with processors from the main queue.
Expand Down
38 changes: 19 additions & 19 deletions sqs_workers/shutdown_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ class NeverShutdown(object):
Never shutdown the worker
"""

def update_state(self, batch_processing_result):
def update_state(self, batch_processing_result) -> None:
pass

def need_shutdown(self):
def need_shutdown(self) -> bool:
return False

def __repr__(self):
def __repr__(self) -> str:
return "NeverShutdown()"


Expand All @@ -23,13 +23,13 @@ class IdleShutdown(object):
Shutdown worker if it's idle for certain time (set in seconds)
"""

def __init__(self, idle_seconds):
def __init__(self, idle_seconds: int) -> None:
self.idle_seconds = idle_seconds
self._idle_delta = datetime.timedelta(seconds=idle_seconds)
self._is_idle = False
self._last_seen = now()

def update_state(self, batch_processing_result):
def update_state(self, batch_processing_result) -> None:
"""
Update internal state of the shutdown policy
"""
Expand All @@ -39,12 +39,12 @@ def update_state(self, batch_processing_result):
self._is_idle = False
self._last_seen = now()

def need_shutdown(self):
def need_shutdown(self) -> bool:
if not self._is_idle:
return False
return now() - self._last_seen >= self._idle_delta

def __repr__(self):
def __repr__(self) -> str:
return f"IdleShutdown({self.idle_seconds})"


Expand All @@ -54,17 +54,17 @@ class MaxTasksShutdown(object):
successfully or with error)
"""

def __init__(self, max_tasks):
def __init__(self, max_tasks: int) -> None:
self.max_tasks = max_tasks
self._tasks = 0

def update_state(self, batch_processing_result):
def update_state(self, batch_processing_result) -> None:
self._tasks += batch_processing_result.total_count()

def need_shutdown(self):
def need_shutdown(self) -> bool:
return self._tasks >= self.max_tasks

def __repr__(self):
def __repr__(self) -> str:
return f"MaxTasksShutdown({self.max_tasks})"


Expand All @@ -73,20 +73,20 @@ class OrShutdown(object):
Return True if any of conditions of policies is met
"""

def __init__(self, *policies):
def __init__(self, *policies) -> None:
self.policies = policies

def update_state(self, batch_processing_result):
def update_state(self, batch_processing_result) -> None:
for p in self.policies:
p.update_state(batch_processing_result)

def need_shutdown(self):
def need_shutdown(self) -> bool:
for p in self.policies:
if p.need_shutdown():
return True
return False

def __repr__(self):
def __repr__(self) -> str:
repr_policies = ", ".join([repr(p) for p in self.policies])
return f"OrShutdown({repr_policies})"

Expand All @@ -96,20 +96,20 @@ class AndShutdown(object):
Return True if all of conditions of policies are met
"""

def __init__(self, *policies):
def __init__(self, *policies) -> None:
self.policies = policies

def update_state(self, batch_processing_result):
def update_state(self, batch_processing_result) -> None:
for p in self.policies:
p.update_state(batch_processing_result)

def need_shutdown(self):
def need_shutdown(self) -> bool:
for p in self.policies:
if not p.need_shutdown():
return False
return True

def __repr__(self):
def __repr__(self) -> str:
repr_policies = ", ".join([repr(p) for p in self.policies])
return f"AndShutdown({repr_policies})"

Expand Down
2 changes: 1 addition & 1 deletion sqs_workers/sqs_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def get_all_known_queues(self):
ret.append(sqs_name[queue_prefix_len:])
return ret

def get_sqs_queue_name(self, queue_name):
def get_sqs_queue_name(self, queue_name: str) -> str:
"""
Take "high-level" (user-visible) queue name and return SQS
("low level") name by simply prefixing it. Used to create namespaces
Expand Down
Loading

0 comments on commit 5e661ae

Please sign in to comment.