From ec0a85adb2e429978225991db6b1e25bb4cfa889 Mon Sep 17 00:00:00 2001 From: "Terence D. Honles" Date: Tue, 3 Sep 2024 17:00:59 +0200 Subject: [PATCH] test type annotations with mypy --- .github/workflows/workflow.yml | 23 ++++- pyproject.toml | 17 ++++ rq/cli/cli.py | 4 +- rq/command.py | 2 +- rq/group.py | 2 +- rq/job.py | 178 +++++++++++++++++---------------- rq/local.py | 171 ++++++++++++++++++++++--------- rq/queue.py | 36 ++++--- rq/registry.py | 25 ++--- rq/serializers.py | 18 +++- rq/types.py | 2 +- rq/utils.py | 15 +-- rq/worker.py | 10 +- 13 files changed, 321 insertions(+), 182 deletions(-) diff --git a/.github/workflows/workflow.yml b/.github/workflows/workflow.yml index c30dd099d..e69b6dde4 100644 --- a/.github/workflows/workflow.yml +++ b/.github/workflows/workflow.yml @@ -66,4 +66,25 @@ jobs: uses: codecov/codecov-action@v4 with: file: ./coverage.xml - fail_ci_if_error: false \ No newline at end of file + fail_ci_if_error: false + + mypy: + runs-on: ubuntu-latest + name: Type check + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python 3.7 + uses: actions/setup-python@v5.1.1 + with: + python-version: "3.7" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install hatch + + - name: Run Test + run: | + hatch run test:typing diff --git a/pyproject.toml b/pyproject.toml index aee524fcf..c0e14b732 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,21 +80,38 @@ include = [ dependencies = [ "black", "coverage", + "mypy", "packaging", "psutil", "pytest", "pytest-cov", "ruff", "tox", + "types-greenlet", + "types-redis", ] [tool.hatch.envs.test.scripts] cov = "pytest --cov=rq --cov-config=.coveragerc --cov-report=xml {args:tests}" +typing = "mypy --enable-incomplete-feature=Unpack rq" [tool.black] line-length = 120 target-version = ["py38"] skip-string-normalization = true +[tool.mypy] +allow_redefinition = true +pretty = true +show_error_codes = true +show_error_context = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_unreachable = true + +[[tool.mypy.overrides]] +module = "setproctitle.*" +ignore_missing_imports = true + [tool.ruff] # Set what ruff should check for. # See https://beta.ruff.rs/docs/rules/ for a list of rules. diff --git a/rq/cli/cli.py b/rq/cli/cli.py index d9fcf56e1..29aca0868 100755 --- a/rq/cli/cli.py +++ b/rq/cli/cli.py @@ -460,7 +460,7 @@ def worker_pool( setup_loghandlers_from_args(verbose, quiet, date_format, log_format) if serializer: - serializer_class: Type[DefaultSerializer] = import_attribute(serializer) + serializer_class: Type[DefaultSerializer] = import_attribute(serializer) # type: ignore[assignment] else: serializer_class = DefaultSerializer @@ -479,7 +479,7 @@ def worker_pool( logging_level = None pool = WorkerPool( - queue_names, + queue_names, # type: ignore[arg-type] connection=cli_config.connection, num_workers=num_workers, serializer=serializer_class, diff --git a/rq/command.py b/rq/command.py index 0488d687e..5d1135922 100644 --- a/rq/command.py +++ b/rq/command.py @@ -42,7 +42,7 @@ def parse_payload(payload: Dict[Any, Any]) -> Dict[Any, Any]: Args: payload (dict): Parses the payload dict. 
""" - return json.loads(payload.get('data').decode()) + return json.loads(payload['data'].decode()) def send_shutdown_command(connection: 'Redis', worker_name: str): diff --git a/rq/group.py b/rq/group.py index 56d87ecc1..94baf0dcf 100644 --- a/rq/group.py +++ b/rq/group.py @@ -17,7 +17,7 @@ class Group: REDIS_GROUP_NAME_PREFIX = 'rq:group:' REDIS_GROUP_KEY = 'rq:groups' - def __init__(self, connection: Redis, name: str = None): + def __init__(self, connection: Redis, name: Optional[str] = None): self.name = name if name else str(uuid4().hex) self.connection = connection self.key = '{0}{1}'.format(self.REDIS_GROUP_NAME_PREFIX, self.name) diff --git a/rq/job.py b/rq/job.py index 91867ca72..cf6213bd2 100644 --- a/rq/job.py +++ b/rq/job.py @@ -6,7 +6,7 @@ import zlib from datetime import datetime, timedelta, timezone from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Type, Union from uuid import uuid4 from redis import WatchError @@ -17,11 +17,17 @@ if TYPE_CHECKING: from redis import Redis from redis.client import Pipeline + from _typeshed import ExcInfo + from typing_extensions import Unpack from .executions import Execution, ExecutionRegistry from .queue import Queue from .results import Result + class UnevaluatedType: + pass + + from .exceptions import DeserializationError, InvalidJobOperation, NoSuchJobError from .local import LocalStack from .serializers import resolve_serializer @@ -89,7 +95,7 @@ def __init__(self, jobs: List[Union['Job', str]], allow_failure: bool = False, e self.enqueue_at_front = enqueue_at_front -UNEVALUATED = object() +UNEVALUATED: 'UnevaluatedType' = object() # type: ignore[assignment] """Sentinel value to mark that some of our lazily evaluated properties have not yet been evaluated. """ @@ -144,13 +150,66 @@ def requeue_job(job_id: str, connection: 'Redis', serializer=None) -> 'Job': class Job: """A Job is just a convenient datastructure to pass around job (meta) data.""" + _dependency: Optional['Job'] redis_job_namespace_prefix = 'rq:job:' + def __init__(self, id: Optional[str] = None, connection: Optional['Redis'] = None, serializer=None): + # Manually check for the presence of the connection argument to preserve + # backwards compatibility during the transition to RQ v2.0.0. 
+ if not connection: + raise TypeError("Job.__init__() missing 1 required argument: 'connection'") + self.connection = connection + self._id = id + self.created_at = now() + self._data = UNEVALUATED + self._func_name: Union[str, 'UnevaluatedType'] = UNEVALUATED + self._instance: Optional[Union[object, 'UnevaluatedType']] = UNEVALUATED + self._args: Union[tuple, list, 'UnevaluatedType'] = UNEVALUATED + self._kwargs: Union[Dict[str, Any], 'UnevaluatedType'] = UNEVALUATED + self._success_callback_name: Optional[str] = None + self._success_callback: Union[Callable[['Job', 'Redis', Any], Any], 'UnevaluatedType'] = UNEVALUATED + self._failure_callback_name: Optional[str] = None + self._failure_callback: Union[Callable[['Job', 'Redis', Unpack[Tuple['ExcInfo']]], Any], 'UnevaluatedType'] = ( + UNEVALUATED + ) + self._stopped_callback_name: Optional[str] = None + self._stopped_callback: Union[Callable[['Job', 'Redis'], Any], 'UnevaluatedType'] = UNEVALUATED + self.description: Optional[str] = None + self.origin: str = '' + self.enqueued_at: Optional[datetime] = None + self.started_at: Optional[datetime] = None + self.ended_at: Optional[datetime] = None + self._result: Optional[Any] = None + self._exc_info: Optional[str] = None + self.timeout: Optional[float] = None + self._success_callback_timeout: Optional[int] = None + self._failure_callback_timeout: Optional[int] = None + self._stopped_callback_timeout: Optional[int] = None + self.result_ttl: Optional[int] = None + self.failure_ttl: Optional[int] = None + self.ttl: Optional[int] = None + self.worker_name: Optional[str] = None + self._status: Optional[JobStatus] = None + self._dependency_ids: List[str] = [] + self.meta: Dict[str, Any] = {} + self.serializer = resolve_serializer(serializer) + self.retries_left: Optional[int] = None + self.retry_intervals: Optional[List[int]] = None + self.redis_server_version: Optional[Tuple[int, int, int]] = None + self.last_heartbeat: Optional[datetime] = None + self.allow_dependency_failures: Optional[bool] = None + self.enqueue_at_front: Optional[bool] = None + self.group_id: Optional[str] = None + + from .results import Result + + self._cached_result: Optional[Result] = None + @classmethod def create( cls, func: FunctionReferenceType, - args: Union[List[Any], Optional[Tuple]] = None, + args: Optional[Union[list, tuple]] = None, kwargs: Optional[Dict[str, Any]] = None, connection: Optional['Redis'] = None, result_ttl: Optional[int] = None, @@ -296,9 +355,8 @@ def create( # dependency could be job instance or id, or iterable thereof if depends_on is not None: - depends_on = ensure_list(depends_on) depends_on_list = [] - for depends_on_item in depends_on: + for depends_on_item in ensure_list(depends_on): if isinstance(depends_on_item, Dependency): # If a Dependency has enqueue_at_front or allow_failure set to True, these behaviors are used for # all dependencies. 
@@ -321,10 +379,10 @@ def get_position(self) -> Optional[int]: if self.origin: q = Queue(name=self.origin, connection=self.connection) - return q.get_job_position(self._id) + return q.get_job_position(self.id) return None - def get_status(self, refresh: bool = True) -> JobStatus: + def get_status(self, refresh: bool = True) -> Optional[JobStatus]: """Gets the Job Status Args: @@ -365,7 +423,7 @@ def get_meta(self, refresh: bool = True) -> Dict: """ if refresh: meta = self.connection.hget(self.key, 'meta') - self.meta = self.serializer.loads(meta) if meta else {} + self.meta = self.serializer.loads(meta) if meta else {} # type: ignore[assignment] return self.meta @@ -552,10 +610,10 @@ def instance(self, value): self._data = UNEVALUATED @property - def args(self) -> tuple: + def args(self) -> Union[list, tuple]: if self._args is UNEVALUATED: self._deserialize_data() - return self._args + return self._args # type: ignore[return-value] @args.setter def args(self, value): @@ -563,10 +621,10 @@ def args(self, value): self._data = UNEVALUATED @property - def kwargs(self): + def kwargs(self) -> Dict[str, Any]: if self._kwargs is UNEVALUATED: self._deserialize_data() - return self._kwargs + return self._kwargs # type: ignore[return-value] @kwargs.setter def kwargs(self, value): @@ -639,56 +697,6 @@ def fetch_many(cls, job_ids: Iterable[str], connection: 'Redis', serializer=None return jobs - def __init__(self, id: Optional[str] = None, connection: 'Redis' = None, serializer=None): - # Manually check for the presence of the connection argument to preserve - # backwards compatibility during the transition to RQ v2.0.0. - if not connection: - raise TypeError("Job.__init__() missing 1 required argument: 'connection'") - self.connection = connection - self._id = id - self.created_at = now() - self._data = UNEVALUATED - self._func_name = UNEVALUATED - self._instance = UNEVALUATED - self._args = UNEVALUATED - self._kwargs = UNEVALUATED - self._success_callback_name = None - self._success_callback = UNEVALUATED - self._failure_callback_name = None - self._failure_callback = UNEVALUATED - self._stopped_callback_name = None - self._stopped_callback = UNEVALUATED - self.description: Optional[str] = None - self.origin: str = '' - self.enqueued_at: Optional[datetime] = None - self.started_at: Optional[datetime] = None - self.ended_at: Optional[datetime] = None - self._result = None - self._exc_info = None - self.timeout: Optional[float] = None - self._success_callback_timeout: Optional[int] = None - self._failure_callback_timeout: Optional[int] = None - self._stopped_callback_timeout: Optional[int] = None - self.result_ttl: Optional[int] = None - self.failure_ttl: Optional[int] = None - self.ttl: Optional[int] = None - self.worker_name: Optional[str] = None - self._status = None - self._dependency_ids: List[str] = [] - self.meta: Dict = {} - self.serializer = resolve_serializer(serializer) - self.retries_left: Optional[int] = None - self.retry_intervals: Optional[List[int]] = None - self.redis_server_version: Optional[Tuple[int, int, int]] = None - self.last_heartbeat: Optional[datetime] = None - self.allow_dependency_failures: Optional[bool] = None - self.enqueue_at_front: Optional[bool] = None - self.group_id: Optional[str] = None - - from .results import Result - - self._cached_result: Optional[Result] = None - def __repr__(self): # noqa # pragma: no cover return '{0}({1!r}, enqueued_at={2!r})'.format(self.__class__.__name__, self._id, self.enqueued_at) @@ -937,15 +945,15 @@ def restore(self, raw_data) -> 
Any: # Fallback to uncompressed string self.data = raw_data - self.created_at = str_to_date(obj.get('created_at')) - self.origin = as_text(obj.get('origin')) if obj.get('origin') else '' - self.worker_name = obj.get('worker_name').decode() if obj.get('worker_name') else None - self.description = as_text(obj.get('description')) if obj.get('description') else None + self.created_at = str_to_date(obj.get('created_at')) # type: ignore[assignment] + self.origin = as_text(obj['origin']) if obj.get('origin') else '' + self.worker_name = obj['worker_name'].decode() if obj.get('worker_name') else None + self.description = as_text(obj['description']) if obj.get('description') else None self.enqueued_at = str_to_date(obj.get('enqueued_at')) self.started_at = str_to_date(obj.get('started_at')) self.ended_at = str_to_date(obj.get('ended_at')) self.last_heartbeat = str_to_date(obj.get('last_heartbeat')) - self.group_id = as_text(obj.get('group_id')) if obj.get('group_id') else None + self.group_id = as_text(obj['group_id']) if obj.get('group_id') else None result = obj.get('result') if result: try: @@ -953,27 +961,27 @@ def restore(self, raw_data) -> Any: except Exception: self._result = UNSERIALIZABLE_RETURN_VALUE_PAYLOAD self.timeout = parse_timeout(obj.get('timeout')) if obj.get('timeout') else None - self.result_ttl = int(obj.get('result_ttl')) if obj.get('result_ttl') else None - self.failure_ttl = int(obj.get('failure_ttl')) if obj.get('failure_ttl') else None - self._status = JobStatus(as_text(obj.get('status'))) if obj.get('status') else None + self.result_ttl = int(obj['result_ttl']) if obj.get('result_ttl') else None + self.failure_ttl = int(obj['failure_ttl']) if obj.get('failure_ttl') else None + self._status = JobStatus(as_text(obj['status'])) if obj.get('status') else None if obj.get('success_callback_name'): - self._success_callback_name = obj.get('success_callback_name').decode() + self._success_callback_name = obj['success_callback_name'].decode() if 'success_callback_timeout' in obj: - self._success_callback_timeout = int(obj.get('success_callback_timeout')) + self._success_callback_timeout = int(obj['success_callback_timeout']) if obj.get('failure_callback_name'): - self._failure_callback_name = obj.get('failure_callback_name').decode() + self._failure_callback_name = obj['failure_callback_name'].decode() if 'failure_callback_timeout' in obj: - self._failure_callback_timeout = int(obj.get('failure_callback_timeout')) + self._failure_callback_timeout = int(obj['failure_callback_timeout']) if obj.get('stopped_callback_name'): - self._stopped_callback_name = obj.get('stopped_callback_name').decode() + self._stopped_callback_name = obj['stopped_callback_name'].decode() if 'stopped_callback_timeout' in obj: - self._stopped_callback_timeout = int(obj.get('stopped_callback_timeout')) + self._stopped_callback_timeout = int(obj['stopped_callback_timeout']) dep_ids = obj.get('dependency_ids') dep_id = obj.get('dependency_id') # for backwards compatibility @@ -981,15 +989,15 @@ def restore(self, raw_data) -> Any: allow_failures = obj.get('allow_dependency_failures') self.allow_dependency_failures = bool(int(allow_failures)) if allow_failures else None self.enqueue_at_front = bool(int(obj['enqueue_at_front'])) if 'enqueue_at_front' in obj else None - self.ttl = int(obj.get('ttl')) if obj.get('ttl') else None + self.ttl = int(obj['ttl']) if obj.get('ttl') else None try: - self.meta = self.serializer.loads(obj.get('meta')) if obj.get('meta') else {} + self.meta = self.serializer.loads(obj['meta']) 
if obj.get('meta') else {} # type: ignore[assignment] except Exception: # depends on the serializer self.meta = {'unserialized': obj.get('meta', {})} - self.retries_left = int(obj.get('retries_left')) if obj.get('retries_left') else None + self.retries_left = int(obj['retries_left']) if obj.get('retries_left') else None if obj.get('retry_intervals'): - self.retry_intervals = json.loads(obj.get('retry_intervals').decode()) + self.retry_intervals = json.loads(obj['retry_intervals'].decode()) raw_exc_info = obj.get('exc_info') if raw_exc_info: @@ -1024,7 +1032,7 @@ def to_dict(self, include_meta: bool = True, include_result: bool = True) -> dic Returns: dict: The Job serialized as a dictionary """ - obj = { + obj: Dict[str, Any] = { 'created_at': utcformat(self.created_at or now()), 'data': zlib.compress(self.data), 'success_callback_name': self._success_callback_name if self._success_callback_name else '', @@ -1328,10 +1336,10 @@ def prepare_for_execution(self, worker_name: str, pipeline: 'Pipeline'): self.last_heartbeat = now() self.started_at = self.last_heartbeat self._status = JobStatus.STARTED - mapping = { + mapping: Mapping = { 'last_heartbeat': utcformat(self.last_heartbeat), 'status': self._status, - 'started_at': utcformat(self.started_at), # type: ignore + 'started_at': utcformat(self.started_at), 'worker_name': worker_name, } pipeline.hset(self.key, mapping=mapping) @@ -1523,6 +1531,7 @@ def get_retry_interval(self) -> int: if self.retry_intervals is None: return 0 number_of_intervals = len(self.retry_intervals) + assert self.retries_left index = max(number_of_intervals - self.retries_left, 0) return self.retry_intervals[index] @@ -1537,6 +1546,7 @@ def retry(self, queue: 'Queue', pipeline: 'Pipeline'): pipeline (Pipeline): The Redis' pipeline to use """ retry_interval = self.get_retry_interval() + assert self.retries_left self.retries_left = self.retries_left - 1 if retry_interval: scheduled_datetime = datetime.now(timezone.utc) + timedelta(seconds=retry_interval) @@ -1673,7 +1683,7 @@ def __init__(self, max: int, interval: Union[int, Iterable[int]] = 0): for i in interval: if i < 0: raise ValueError('interval: negative numbers are not allowed') - intervals = interval + intervals = list(interval) self.max = max self.intervals = intervals diff --git a/rq/local.py b/rq/local.py index 2fe22c901..58756a12d 100644 --- a/rq/local.py +++ b/rq/local.py @@ -14,13 +14,7 @@ try: from greenlet import getcurrent as get_ident except ImportError: - try: - from threading import get_ident - except ImportError: - try: - from _thread import get_ident - except ImportError: - from dummy_thread import get_ident + from threading import get_ident # type: ignore[assignment] def release_local(local): @@ -313,44 +307,125 @@ def __setitem__(self, key, value): def __delitem__(self, key): del self._get_current_object()[key] - __setattr__ = lambda x, n, v: setattr(x._get_current_object(), n, v) - __delattr__ = lambda x, n: delattr(x._get_current_object(), n) - __str__ = lambda x: str(x._get_current_object()) - __lt__ = lambda x, o: x._get_current_object() < o - __le__ = lambda x, o: x._get_current_object() <= o - __eq__ = lambda x, o: x._get_current_object() == o - __ne__ = lambda x, o: x._get_current_object() != o - __gt__ = lambda x, o: x._get_current_object() > o - __ge__ = lambda x, o: x._get_current_object() >= o - __hash__ = lambda x: hash(x._get_current_object()) - __call__ = lambda x, *a, **kw: x._get_current_object()(*a, **kw) - __len__ = lambda x: len(x._get_current_object()) - __getitem__ = lambda 
x, i: x._get_current_object()[i]
-    __iter__ = lambda x: iter(x._get_current_object())
-    __contains__ = lambda x, i: i in x._get_current_object()
-    __add__ = lambda x, o: x._get_current_object() + o
-    __sub__ = lambda x, o: x._get_current_object() - o
-    __mul__ = lambda x, o: x._get_current_object() * o
-    __floordiv__ = lambda x, o: x._get_current_object() // o
-    __mod__ = lambda x, o: x._get_current_object() % o
-    __divmod__ = lambda x, o: x._get_current_object().__divmod__(o)
-    __pow__ = lambda x, o: x._get_current_object() ** o
-    __lshift__ = lambda x, o: x._get_current_object() << o
-    __rshift__ = lambda x, o: x._get_current_object() >> o
-    __and__ = lambda x, o: x._get_current_object() & o
-    __xor__ = lambda x, o: x._get_current_object() ^ o
-    __or__ = lambda x, o: x._get_current_object() | o
-    __div__ = lambda x, o: x._get_current_object().__div__(o)
-    __truediv__ = lambda x, o: x._get_current_object().__truediv__(o)
-    __neg__ = lambda x: -(x._get_current_object())
-    __pos__ = lambda x: +(x._get_current_object())
-    __abs__ = lambda x: abs(x._get_current_object())
-    __invert__ = lambda x: ~(x._get_current_object())
-    __complex__ = lambda x: complex(x._get_current_object())
-    __int__ = lambda x: int(x._get_current_object())
-    __float__ = lambda x: float(x._get_current_object())
-    __oct__ = lambda x: oct(x._get_current_object())
-    __hex__ = lambda x: hex(x._get_current_object())
-    __index__ = lambda x: x._get_current_object().__index__()
-    __enter__ = lambda x: x._get_current_object().__enter__()
-    __exit__ = lambda x, *a, **kw: x._get_current_object().__exit__(*a, **kw)
+    def __setattr__(self, name, value):
+        setattr(self._get_current_object(), name, value)
+
+    def __delattr__(self, name):
+        return delattr(self._get_current_object(), name)
+
+    def __str__(self):
+        return str(self._get_current_object())
+
+    def __lt__(self, other):
+        return self._get_current_object() < other
+
+    def __le__(self, other):
+        return self._get_current_object() <= other
+
+    def __eq__(self, other):
+        return self._get_current_object() == other
+
+    def __ne__(self, other):
+        return self._get_current_object() != other
+
+    def __gt__(self, other):
+        return self._get_current_object() > other
+
+    def __ge__(self, other):
+        return self._get_current_object() >= other
+
+    def __hash__(self):
+        return hash(self._get_current_object())
+
+    def __call__(self, *args, **kwargs):
+        return self._get_current_object()(*args, **kwargs)
+
+    def __len__(self):
+        return len(self._get_current_object())
+
+    def __getitem__(self, i):
+        return self._get_current_object()[i]
+
+    def __iter__(self):
+        return iter(self._get_current_object())
+
+    def __contains__(self, obj):
+        return obj in self._get_current_object()
+
+    def __add__(self, other):
+        return self._get_current_object() + other
+
+    def __sub__(self, other):
+        return self._get_current_object() - other
+
+    def __mul__(self, other):
+        return self._get_current_object() * other
+
+    def __floordiv__(self, other):
+        return self._get_current_object() // other
+
+    def __mod__(self, other):
+        return self._get_current_object() % other
+
+    def __divmod__(self, other):
+        return self._get_current_object().__divmod__(other)
+
+    def __pow__(self, other):
+        return self._get_current_object() ** other
+
+    def __lshift__(self, other):
+        return self._get_current_object() << other
+
+    def __rshift__(self, other):
+        return self._get_current_object() >> other
+
+    def __and__(self, other):
+        return self._get_current_object() & other
+
+    def __xor__(self, other):
+        return self._get_current_object() ^ other
+
+    
def __or__(self, other): + return self._get_current_object() | other + + def __div__(self, other): + return self._get_current_object().__div__(other) + + def __truediv__(self, other): + return self._get_current_object().__truediv__(other) + + def __neg__(self): + return -(self._get_current_object()) + + def __pos__(self): + return +(self._get_current_object()) + + def __abs__(self): + return abs(self._get_current_object()) + + def __invert__(self): + return ~(self._get_current_object()) + + def __complex__(self): + return complex(self._get_current_object()) + + def __int__(self): + return int(self._get_current_object()) + + def __float__(self): + return float(self._get_current_object()) + + def __oct__(self): + return oct(self._get_current_object()) + + def __hex__(self): + return hex(self._get_current_object()) + + def __index__(self): + return self._get_current_object().__index__() + + def __enter__(self): + return self._get_current_object().__enter__() + + def __exit__(self, *args, **kwargs): + return self._get_current_object().__exit__(*args, **kwargs) diff --git a/rq/queue.py b/rq/queue.py index 96851d031..b1caa98eb 100644 --- a/rq/queue.py +++ b/rq/queue.py @@ -6,7 +6,7 @@ from collections import namedtuple from datetime import datetime, timedelta, timezone from functools import total_ordering -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union from redis import WatchError @@ -143,7 +143,7 @@ def from_queue_key( def __init__( self, name: str = 'default', - connection: 'Redis' = None, + connection: Optional['Redis'] = None, default_timeout: Optional[int] = None, is_async: bool = True, job_class: Optional[Union[str, Type['Job']]] = None, @@ -182,9 +182,10 @@ def __init__( # override class attribute job_class if one was passed if job_class is not None: if isinstance(job_class, str): - job_class = import_attribute(job_class) - self.job_class = job_class - self.death_penalty_class = death_penalty_class # type: ignore + self.job_class = import_attribute(job_class) # type: ignore[assignment] + else: + self.job_class = job_class + self.death_penalty_class = death_penalty_class # type: ignore[assignment] self.serializer = resolve_serializer(serializer) self.redis_server_version: Optional[Tuple[int, int, int]] = None @@ -229,7 +230,7 @@ def registry_cleaning_key(self): return 'rq:clean_registries:%s' % self.name @property - def scheduler_pid(self) -> int: + def scheduler_pid(self) -> Optional[int]: from rq.scheduler import RQScheduler pid = self.connection.get(RQScheduler.get_locking_key(self.name)) @@ -331,6 +332,8 @@ def fetch_job(self, job_id: str) -> Optional['Job']: if job.origin == self.name: return job + return None + def get_job_position(self, job_or_id: Union['Job', str]) -> Optional[int]: """Returns the position of a job within the queue @@ -787,7 +790,7 @@ def prepare_data( ) def enqueue_many( - self, job_datas: List['EnqueueData'], pipeline: Optional['Pipeline'] = None, group_id: str = None + self, job_datas: List['EnqueueData'], pipeline: Optional['Pipeline'] = None, group_id: Optional[str] = None ) -> List[Job]: """Creates multiple jobs (created via `Queue.prepare_data` calls) to represent the delayed function calls and enqueues them. 
@@ -1165,10 +1168,10 @@ def run_sync(self, job: 'Job') -> 'Job': pipeline.execute() if job.failure_callback: - job.failure_callback(job, self.connection, *sys.exc_info()) # type: ignore + job.failure_callback(job, self.connection, *sys.exc_info()) else: if job.success_callback: - job.success_callback(job, self.connection, job.return_value()) # type: ignore + job.success_callback(job, self.connection, job.return_value()) return job @@ -1198,7 +1201,7 @@ def enqueue_dependents( if pipeline is None: pipe.watch(dependents_key) - dependent_job_ids = {as_text(_id) for _id in pipe.smembers(dependents_key)} + dependent_job_ids = {as_text(_id) for _id in pipe.smembers(dependents_key)} # type: ignore[attr-defined] # There's no dependents if not dependent_job_ids: @@ -1295,6 +1298,7 @@ def lpop(cls, queue_keys: List[str], timeout: Optional[int], connection: Optiona raise ValueError('RQ does not support indefinite timeouts. Please pick a timeout value > 0') colored_queues = ', '.join(map(str, [green(str(queue)) for queue in queue_keys])) logger.debug(f"Starting BLPOP operation for queues {colored_queues} with timeout of {timeout}") + assert connection result = connection.blpop(queue_keys, timeout) if result is None: logger.debug(f"BLPOP timeout, no jobs found on queues {colored_queues}") @@ -1303,6 +1307,7 @@ def lpop(cls, queue_keys: List[str], timeout: Optional[int], connection: Optiona return queue_key, job_id else: # non-blocking variant for queue_key in queue_keys: + assert connection blob = connection.lpop(queue_key) if blob is not None: return queue_key, blob @@ -1319,13 +1324,13 @@ def lmove(cls, connection: 'Redis', queue_key: str, timeout: Optional[int]): raise ValueError('RQ does not support indefinite timeouts. Please pick a timeout value > 0') colored_queue = green(queue_key) logger.debug(f"Starting BLMOVE operation for {colored_queue} with timeout of {timeout}") - result = connection.blmove(queue_key, intermediate_queue.key, timeout) + result: Optional[Any] = connection.blmove(queue_key, intermediate_queue.key, timeout) if result is None: logger.debug(f"BLMOVE timeout, no jobs found on {colored_queue}") raise DequeueTimeout(timeout, queue_key) return queue_key, result else: # non-blocking variant - result = connection.lmove(queue_key, intermediate_queue.key) + result = cast(Optional[Any], connection.lmove(queue_key, intermediate_queue.key)) if result is not None: return queue_key, result return None @@ -1364,7 +1369,7 @@ def dequeue_any( Returns: job, queue (Tuple[Job, Queue]): A tuple of Job, Queue """ - job_cls: Type[Job] = backend_class(cls, 'job_class', override=job_class) # type: ignore + job_cls: Type[Job] = backend_class(cls, 'job_class', override=job_class) while True: queue_keys = [q.key for q in queues] @@ -1391,11 +1396,10 @@ def dequeue_any( except Exception as e: # Attach queue information on the exception for improved error # reporting - e.job_id = job_id - e.queue = queue + e.job_id = job_id # type: ignore[attr-defined] + e.queue = queue # type: ignore[attr-defined] raise e return job, queue - return None, None # Total ordering definition (the rest of the required Python methods are # auto-generated by the @total_ordering decorator) diff --git a/rq/registry.py b/rq/registry.py index 2c6680427..894d5b377 100644 --- a/rq/registry.py +++ b/rq/registry.py @@ -3,7 +3,7 @@ import time import traceback from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, List, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Callable, cast, List, 
Optional, Type, Union from rq.serializers import resolve_serializer @@ -34,6 +34,7 @@ class BaseRegistry: job_class = Job death_penalty_class = UnixSignalDeathPenalty key_template = 'rq:registry:{0}' + cleanup: Callable[..., Any] def __init__( self, @@ -50,7 +51,7 @@ def __init__( self.serializer = queue.serializer else: self.name = name - self.connection = connection + self.connection = connection # type: ignore[assignment] self.serializer = resolve_serializer(serializer) self.key = self.key_template.format(self.name) @@ -78,7 +79,7 @@ def __contains__(self, item: Union[str, 'Job']) -> bool: job_id = item if isinstance(item, self.job_class): job_id = item.id - return self.connection.zscore(self.key, job_id) is not None + return self.connection.zscore(self.key, cast(str, job_id)) is not None @property def count(self) -> int: @@ -106,11 +107,11 @@ def add(self, job: 'Job', ttl=0, pipeline: Optional['Pipeline'] = None, xx: bool if score == -1: score = '+inf' if pipeline is not None: - return pipeline.zadd(self.key, {job.id: score}, xx=xx) + return cast(int, pipeline.zadd(self.key, {job.id: score}, xx=xx)) return self.connection.zadd(self.key, {job.id: score}, xx=xx) - def remove(self, job: 'Job', pipeline: Optional['Pipeline'] = None, delete_job: bool = False): + def remove(self, job: Union['Job', str], pipeline: Optional['Pipeline'] = None, delete_job: bool = False): """Removes job from registry and deletes it if `delete_job == True` Args: @@ -176,7 +177,7 @@ def get_expiration_time(self, job: 'Job') -> datetime: job (Job): The Job to get the expiration """ score = self.connection.zscore(self.key, job.id) - return datetime.utcfromtimestamp(score) + return datetime.utcfromtimestamp(score) # type: ignore[arg-type] def requeue(self, job_or_id: Union['Job', str], at_front: bool = False) -> 'Job': """Requeues the job with the given job ID. @@ -225,7 +226,7 @@ class StartedJobRegistry(BaseRegistry): key_template = 'rq:wip:{0}' - def cleanup(self, timestamp: Optional[float] = None, exception_handlers: List = None): + def cleanup(self, timestamp: Optional[float] = None, exception_handlers: Optional[list] = None): """Remove abandoned jobs from registry and add them to FailedJobRegistry. Removes jobs with an expiry time earlier than timestamp, specified as @@ -361,7 +362,7 @@ def cleanup(self, timestamp: Optional[float] = None): score = timestamp if timestamp is not None else current_timestamp() self.connection.zremrangebyscore(self.key, 0, score) - def add( + def add( # type: ignore[override] self, job: 'Job', ttl=None, @@ -478,7 +479,7 @@ def remove_jobs(self, timestamp: Optional[datetime] = None, pipeline: Optional[' pipeline (Optional[Pipeline], optional): The Redis pipeline. Defaults to None. 
""" connection = pipeline if pipeline is not None else self.connection - score = timestamp if timestamp is not None else current_timestamp() + score: Any = timestamp if timestamp is not None else current_timestamp() return connection.zremrangebyscore(self.key, 0, score) def get_jobs_to_schedule(self, timestamp: Optional[datetime] = None, chunk_size: int = 1000) -> List[str]: @@ -491,7 +492,7 @@ def get_jobs_to_schedule(self, timestamp: Optional[datetime] = None, chunk_size: Returns: jobs (List[str]): A list of Job ids """ - score = timestamp if timestamp is not None else current_timestamp() + score: Any = timestamp if timestamp is not None else current_timestamp() jobs_to_schedule = self.connection.zrangebyscore(self.key, 0, score, start=0, num=chunk_size) return [as_text(job_id) for job_id in jobs_to_schedule] @@ -522,7 +523,7 @@ def get_scheduled_time(self, job_or_id: Union['Job', str]) -> datetime: class CanceledJobRegistry(BaseRegistry): key_template = 'rq:canceled:{0}' - def get_expired_job_ids(self, timestamp: Optional[datetime] = None): + def get_expired_job_ids(self, timestamp: Optional[float] = None): raise NotImplementedError def cleanup(self): @@ -532,7 +533,7 @@ def cleanup(self): pass -def clean_registries(queue: 'Queue', exception_handlers: list = None): +def clean_registries(queue: 'Queue', exception_handlers: Optional[list] = None): """Cleans StartedJobRegistry, FinishedJobRegistry and FailedJobRegistry, and DeferredJobRegistry of a queue. Args: diff --git a/rq/serializers.py b/rq/serializers.py index 94eddbf26..8d0a2ad44 100644 --- a/rq/serializers.py +++ b/rq/serializers.py @@ -1,14 +1,21 @@ import json import pickle from functools import partial -from typing import Optional, Type, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Type, Union from .utils import import_attribute +if TYPE_CHECKING: + from typing_extensions import Protocol + + class Serializer(Protocol): + dumps: Callable[..., bytes] + loads: Callable[..., object] + class DefaultSerializer: - dumps = partial(pickle.dumps, protocol=pickle.HIGHEST_PROTOCOL) - loads = pickle.loads + dumps: Callable[..., bytes] = partial(pickle.dumps, protocol=pickle.HIGHEST_PROTOCOL) + loads: Callable[..., object] = pickle.loads class JSONSerializer: @@ -21,7 +28,7 @@ def loads(s, *args, **kwargs): return json.loads(s.decode('utf-8'), *args, **kwargs) -def resolve_serializer(serializer: Optional[Union[Type[DefaultSerializer], str]] = None) -> Type[DefaultSerializer]: +def resolve_serializer(serializer: Optional[Union[Type['Serializer'], str]] = None) -> Type['Serializer']: """This function checks the user defined serializer for ('dumps', 'loads') methods It returns a default pickle serializer if not found else it returns a MySerializer The returned serializer objects implement ('dumps', 'loads') methods @@ -37,8 +44,9 @@ def resolve_serializer(serializer: Optional[Union[Type[DefaultSerializer], str]] return DefaultSerializer if isinstance(serializer, str): - serializer = import_attribute(serializer) + serializer = import_attribute(serializer) # type: ignore[assignment] + assert not isinstance(serializer, str) default_serializer_methods = ('dumps', 'loads') for instance_method in default_serializer_methods: diff --git a/rq/types.py b/rq/types.py index fe8e002d2..4826b2d91 100644 --- a/rq/types.py +++ b/rq/types.py @@ -11,7 +11,7 @@ """ -JobDependencyType = TypeVar('JobDependencyType', 'Dependency', 'Job', str, List[Union['Dependency', 'Job']]) +JobDependencyType = TypeVar('JobDependencyType', 'Dependency', 'Job', 
str, List[Union['Dependency', 'Job', str]]) """Custom type definition for a job dependencies. A simple helper definition for the `depends_on` parameter when creating a job. """ diff --git a/rq/utils.py b/rq/utils.py index 522c91b68..792bf1d3b 100644 --- a/rq/utils.py +++ b/rq/utils.py @@ -12,7 +12,7 @@ import logging import numbers from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union if TYPE_CHECKING: from redis import Redis @@ -104,7 +104,7 @@ def import_attribute(name: str) -> Callable[..., Any]: if module is None: # maybe it's a builtin try: - return __builtins__[name] + return __builtins__[name] # type: ignore[index] except KeyError: raise ValueError('Invalid attribute name: %s' % name) @@ -124,7 +124,7 @@ def import_attribute(name: str) -> Callable[..., Any]: return getattr(attribute_owner, attribute_name) -def now(): +def now() -> datetime.datetime: """Return now in UTC""" return datetime.datetime.now(datetime.timezone.utc) @@ -226,7 +226,7 @@ def current_timestamp() -> int: return calendar.timegm(datetime.datetime.now(datetime.timezone.utc).utctimetuple()) -def backend_class(holder, default_name, override=None) -> TypeVar('T'): +def backend_class(holder, default_name, override=None) -> Type: """Get a backend class using its default attribute name or an override Args: @@ -240,14 +240,14 @@ def backend_class(holder, default_name, override=None) -> TypeVar('T'): if override is None: return getattr(holder, default_name) elif isinstance(override, str): - return import_attribute(override) + return import_attribute(override) # type: ignore[return-value] else: return override -def str_to_date(date_str: Optional[str]) -> Union[dt.datetime, Any]: +def str_to_date(date_str: Optional[bytes]) -> Optional[dt.datetime]: if not date_str: - return + return None else: return utcparse(date_str.decode()) @@ -258,6 +258,7 @@ def parse_timeout(timeout: Optional[Union[int, float, str]]) -> Optional[int]: try: timeout = int(timeout) except ValueError: + assert isinstance(timeout, str) digit, unit = timeout[:-1], (timeout[-1:]).lower() unit_second = {'d': 86400, 'h': 3600, 'm': 60, 's': 1} try: diff --git a/rq/worker.py b/rq/worker.py index f5a41c174..8992598db 100644 --- a/rq/worker.py +++ b/rq/worker.py @@ -158,6 +158,7 @@ def __init__( if not connection: connection = get_connection_from_queues(queues) + assert connection connection = self._set_connection(connection) self.connection = connection self.redis_server_version = None @@ -209,7 +210,7 @@ def __init__( self.scheduler: Optional[RQScheduler] = None self.pubsub: Optional['PubSub'] = None self.pubsub_thread = None - self._dequeue_strategy: DequeueStrategy = DequeueStrategy.DEFAULT + self._dequeue_strategy: Optional[DequeueStrategy] = DequeueStrategy.DEFAULT self.disable_default_exception_handler = disable_default_exception_handler @@ -306,6 +307,7 @@ def all( if queue: connection = queue.connection + assert connection worker_keys = worker_registration.get_keys(queue=queue, connection=connection) workers = [ cls.find_by_key( @@ -981,7 +983,7 @@ def unsubscribe(self): def dequeue_job_and_maintain_ttl( self, timeout: Optional[int], max_idle_time: Optional[int] = None - ) -> Tuple['Job', 'Queue']: + ) -> Optional[Tuple['Job', 'Queue']]: """Dequeues a job while maintaining the TTL. 
Returns: @@ -1039,8 +1041,6 @@ def dequeue_job_and_maintain_ttl( time.sleep(connection_wait_time) connection_wait_time *= self.exponential_backoff_factor connection_wait_time = min(connection_wait_time, self.max_connection_wait_time) - else: - connection_wait_time = 1.0 self.heartbeat() return result @@ -1499,6 +1499,8 @@ def handle_job_success(self, job: 'Job', queue: 'Queue', started_job_registry: S pipeline.execute() + assert job.started_at + assert job.ended_at time_taken = job.ended_at - job.started_at self.log.info( 'Successfully completed %s job in %ss on worker %s', job.description, time_taken, self.name
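
A note for reviewers on the `Serializer` protocol added in `rq/serializers.py` above: any class exposing `dumps`/`loads` should satisfy both the runtime attribute check in `resolve_serializer` and mypy's structural check against the protocol. A minimal sketch under that assumption — the `CompressedSerializer` name is hypothetical and not part of this patch:

    import pickle
    import zlib

    from redis import Redis

    from rq import Queue
    from rq.serializers import resolve_serializer


    class CompressedSerializer:
        # Hypothetical example: zlib-compressed pickle payloads.
        # dumps must return bytes to match Serializer.dumps (Callable[..., bytes]);
        # loads may return any object, matching Serializer.loads (Callable[..., object]).

        @staticmethod
        def dumps(obj, *args, **kwargs) -> bytes:
            return zlib.compress(pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))

        @staticmethod
        def loads(data: bytes, *args, **kwargs) -> object:
            return pickle.loads(zlib.decompress(data))


    # resolve_serializer accepts either a class with dumps/loads or a dotted-path string
    serializer_class = resolve_serializer(CompressedSerializer)
    queue = Queue('default', connection=Redis(), serializer=CompressedSerializer)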