diff --git a/checks/checks.py b/checks/checks.py
index bd715132a6..455bd56cab 100644
--- a/checks/checks.py
+++ b/checks/checks.py
@@ -1,226 +1,203 @@
-from functools import cached_property
-from typing import Collection
-from typing import Dict
-from typing import Iterator
+import logging
+from collections import defaultdict
 from typing import Optional
+from typing import Set
 from typing import Tuple
-from typing import Type
-from typing import TypeVar
+
+from django.conf import settings

 from checks.models import TrackedModelCheck
-from checks.models import TransactionCheck
-from common.business_rules import ALL_RULES
 from common.business_rules import BusinessRule
 from common.business_rules import BusinessRuleViolation
-from common.models.trackedmodel import TrackedModel
+from common.models import TrackedModel
+from common.models import Transaction
 from common.models.utils import get_current_transaction
 from common.models.utils import override_current_transaction

-CheckResult = Tuple[bool, Optional[str]]
-
+logger = logging.getLogger(__name__)

-Self = TypeVar("Self")
+CheckResult = Tuple[bool, Optional[str]]


 class Checker:
-    """
-    A ``Checker`` is an object that knows how to perform a certain kind of check
-    against a model.
-
-    Checkers can be applied against a model. The logic of the checker will be
-    run and the result recorded as a ``TrackedModelCheck``.
-    """
-
-    @cached_property
-    def name(self) -> str:
-        """
-        The name string that on a per-model basis uniquely identifies the
-        checker.
-
-        The name should be deterministic (i.e. not rely on the current
-        environment, memory locations or random data) so that the system can
-        record the name in the database and later use it to work out whether
-        this check has been run. The name doesn't need to include any details
-        about the model.
-
-        By default this is the name of the class, but it can include any other
-        non-model data that is unique to the checker. For a more complex
-        example, see ``IndirectBusinessRuleChecker.name``.
-        """
-        return type(self).__name__
-
     @classmethod
-    def checkers_for(cls: Type[Self], model: TrackedModel) -> Collection[Self]:
+    def run_rule(
+        cls,
+        rule: BusinessRule,
+        transaction: Transaction,
+        model: TrackedModel,
+    ) -> CheckResult:
         """
-        Returns instances of this ``Checker`` that should apply to the model.
+        Run a single business rule on a single model.

-        What checks apply to a model is sometimes data-dependent, so it is the
-        responsibility of the ``Checker`` class to tell the system what
-        instances of itself it would expect to run against the model. For an
-        example, see ``IndirectBusinessRuleChecker.checkers_for``.
+        :return: CheckResult, a Tuple(rule_passed: bool, violation_reason: Optional[str]).
         """
-        raise NotImplementedError()
+        logger.debug("run_rule %s %s %s", model, rule, transaction.pk)
+        try:
+            rule(transaction).validate(model)
+            logger.debug("%s [tx:%s] %s [passed]", model, transaction.pk, rule)
+            return True, None
+        except BusinessRuleViolation as violation:
+            reason = violation.args[0]
+            logger.debug("%s [tx:%s] %s [failed]: %s", model, transaction.pk, rule, reason)
+            return False, reason

-    def run(self, model: TrackedModel) -> CheckResult:
-        """Runs Checker-dependent logic and returns an indication of success."""
-        raise NotImplementedError()
+    @classmethod
+    def apply_rule(
+        cls,
+        rule: BusinessRule,
+        transaction: Transaction,
+        model: TrackedModel,
+    ):
+        """
+        Applies the check to the model and records the result.
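+
+        Results are stored as ``TrackedModelCheck`` rows keyed on
+        (model, check_name), so re-running a rule updates the stored row
+        rather than creating a duplicate.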
-    def apply(self, model: TrackedModel, context: TransactionCheck):
-        """Applies the check to the model and records success."""
+        :return: TrackedModelCheck instance containing the result of the check.
+        During debugging, the developer can set settings.RAISE_BUSINESS_RULE_FAILURES
+        to True to re-raise business rule violations as exceptions.
+        """
         success, message = False, None
         try:
-            with override_current_transaction(context.transaction):
-                success, message = self.run(model)
+            with override_current_transaction(transaction):
+                success, message = cls.run_rule(rule, transaction, model)
         except Exception as e:
             success, message = False, str(e)
+            if settings.RAISE_BUSINESS_RULE_FAILURES:
+                raise
         finally:
-            return TrackedModelCheck.objects.create(
+            check, created = TrackedModelCheck.objects.get_or_create(
+                {
+                    "successful": success,
+                    "message": message,
+                    "content_hash": model.content_hash().digest(),
+                },
                 model=model,
-                transaction_check=context,
-                check_name=self.name,
-                successful=success,
-                message=message,
+                check_name=rule.__name__,
             )
-
-
-class BusinessRuleChecker(Checker):
-    """
-    A ``Checker`` that runs a ``BusinessRule`` against a model.
-
-    This class is expected to be sub-typed for a specific rule by a call to
-    ``of()``.
-
-    Attributes:
-        checker_cache (dict): (class attribute) Cache of Business checkers created by ``of()``.
-    """
-
-    rule: Type[BusinessRule]
-
-    _checker_cache: Dict[str, BusinessRule] = {}
+            if not created:
+                check.successful = success
+                check.message = message
+                check.content_hash = model.content_hash().digest()
+                check.save()
+            return check

     @classmethod
-    def of(cls: Type, rule_type: Type[BusinessRule]) -> Type:
+    def apply_rule_cached(
+        cls,
+        rule: BusinessRule,
+        transaction: Transaction,
+        model: TrackedModel,
+    ):
         """
-        Return a subclass of a Checker, e.g. BusinessRuleChecker,
-        IndirectBusinessRuleChecker that runs the passed in business rule.
+        If a matching TrackedModelCheck instance exists and its content hash is
+        still valid, return it; otherwise run the rule and return the result as
+        a new TrackedModelCheck instance.

-        Example, creating a BusinessRuleChecker for ME32:
+        :return: TrackedModelCheck instance containing the result of the check.
+        """
+        try:
+            check = TrackedModelCheck.objects.get(
+                model=model,
+                check_name=rule.__name__,
+            )
+        except TrackedModelCheck.DoesNotExist:
+            logger.debug(
+                "apply_rule_cached (no existing check) %s, %s apply rule",
+                rule.__name__,
+                transaction,
+            )
+            return cls.apply_rule(rule, transaction, model)
+
+        # Re-run the rule if the stored content hash no longer matches the
+        # model's current content.
+        check_hash = bytes(check.content_hash)
+        model_hash = model.content_hash().digest()
+        if check_hash == model_hash:
+            logger.debug(
+                "apply_rule_cached (matching content hash) %s, tx: %s, using cached result %s",
+                rule.__name__,
+                transaction.pk,
+                check,
+            )
+            return check

-        >>> BusinessRuleChecker.of(measures.business_rules.ME32)
-        <class 'checks.checks.BusinessRuleCheckerOf[measures.business_rules.ME32]'>
+        logger.debug(
+            "apply_rule_cached (check.content_hash != model.content_hash()) %s != %s %s, %s apply rule",
+            check_hash,
+            model_hash,
+            rule.__name__,
+            transaction,
+        )
+        check.delete()
+        return cls.apply_rule(rule, transaction, model)

-        This API is usually called by .applicable_to, however this docstring should
-        illustrate what it does.

-        Checkers are created once and then cached in _checker_cache.
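+
+# A minimal usage sketch of the caching flow above (illustrative only; it
+# assumes an existing TrackedModel instance ``model`` and a BusinessRule
+# subclass ``MyRule``, both hypothetical names):
+#
+#     first = Checker.apply_rule_cached(MyRule, model.transaction, model)
+#     second = Checker.apply_rule_cached(MyRule, model.transaction, model)
+#     assert first.pk == second.pk  # unchanged content hash: cached row reused
+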
+class BusinessRuleChecker(Checker):
+    """Apply the BusinessRules specified in a TrackedModel's business_rules
+    attribute."""

-    As well as a small performance improvement, caching aids debugging by ensuring
-    the same checker instance is returned if the same cls is passed to ``of``.
+    @classmethod
+    def get_model_rules(cls, model: TrackedModel, rules: Optional[Set[str]] = None):
         """
-        checker_name = f"{cls.__name__}Of[{rule_type.__module__}.{rule_type.__name__}]"
-
-        # If the checker class was already created, return it.
-        checker_class = cls._checker_cache.get(checker_name)
-        if checker_class is not None:
-            return checker_class
-        # No existing checker was found, so create it:
+        :param model: TrackedModel instance
+        :param rules: Optional set of rule names to filter by.
+        :return: Dict mapping the model to the set of BusinessRules that apply to it.
+        """
+        model_rules = defaultdict(set)

-        class BusinessRuleCheckerOf(cls):
-            # Creating this class explicitly in code is more readable than using type(...)
-            # Once created the name will be mangled to include the rule to be checked.
+        for rule in model.business_rules:
+            if rules is not None and rule.__name__ not in rules:
+                continue

-            f"""Apply the following checks as specified in {rule_type.__name__}"""
-            rule = rule_type
+            model_rules[model].add(rule)

-            def __repr__(self):
-                return f"<{checker_name}>"
+        return model_rules

-        BusinessRuleCheckerOf.__name__ = checker_name
-        cls._checker_cache[checker_name] = BusinessRuleCheckerOf
-        return BusinessRuleCheckerOf

+class LinkedModelsBusinessRuleChecker(Checker):
+    """Apply the BusinessRules specified in a TrackedModel's
+    indirect_business_rules attribute to the models returned by
+    get_linked_models on those rules."""

     @classmethod
-    def checkers_for(cls: Type[Self], model: TrackedModel) -> Collection[Self]:
-        """If the rule attribute on this BusinessRuleChecker matches any in the
-        supplied TrackedModel instance's business_rules, return it in a list,
-        otherwise there are no matches so return an empty list."""
-        if cls.rule in model.business_rules:
-            return [cls()]
-        return []
-
-    def run(self, model: TrackedModel) -> CheckResult:
-        """
-        :return CheckResult, a Tuple(rule_passed: str, violation_reason: Optional[str]).
+    def apply_rule(
+        cls,
+        rule: BusinessRule,
+        transaction: Transaction,
+        model: TrackedModel,
+    ):
         """
-        transaction = get_current_transaction()
-        try:
-            self.rule(transaction).validate(model)
-            return True, None
-        except BusinessRuleViolation as violation:
-            return False, violation.args[0]
-
-
-class IndirectBusinessRuleChecker(BusinessRuleChecker):
-    """
-    A ``Checker`` that runs a ``BusinessRule`` against a model that is linked to
-    the model being checked, and for which a change in the checked model could
-    result in a business rule failure against the linked model.
+        LinkedModelsBusinessRuleChecker assumes that the linked models are
+        still the current

-    This is a base class: subclasses for checking specific rules are created by
-    calling ``of()``.
-    """
+        versions (TODO: ensure a business rule checks this).

-    rule: Type[BusinessRule]
-    linked_model: TrackedModel
-
-    def __init__(self, linked_model: TrackedModel) -> None:
-        self.linked_model = linked_model
-        super().__init__()
-
-    @cached_property
-    def name(self) -> str:
-        # Include the identity of the linked model in the checker name, so that
-        # each linked model needs to be checked for all checks to be complete.
-        return f"{super().name}[{self.linked_model.pk}]"
+
+        The transaction to check is set to that of the model, which enables
+        each linked model to be checked in the context of its own transaction.
+        """
+        return super().apply_rule(rule, model.transaction, model)

     @classmethod
-    def checkers_for(cls: Type[Self], model: TrackedModel) -> Collection[Self]:
-        """Return a set of IndirectBusinessRuleCheckers for every model found on
-        rule.get_linked_models."""
-        rules = set()
-        transaction = get_current_transaction()
-        if cls.rule in model.indirect_business_rules:
-            for linked_model in cls.rule.get_linked_models(model, transaction):
-                rules.add(cls(linked_model))
-        return rules
-
-    def run(self, model: TrackedModel) -> CheckResult:
+    def get_model_rules(cls, model: TrackedModel, rules: Optional[Set] = None):
         """
-        Return the result of running super.run, passing self.linked_model, and.
-
-        return it as a CheckResult - a Tuple(rule_passed: str, violation_reason: Optional[str])
+        :param model: TrackedModel instance
+        :param rules: Optional set of rule names to filter by.
+        :return: Dict mapping each linked model to the set of BusinessRules that apply to it.
         """
-        result, message = super().run(self.linked_model)
-        message = f"{self.linked_model}: " + message if message else None
-        return result, message
+        tx = get_current_transaction()
+
+        model_rules = defaultdict(set)

-def checker_types() -> Iterator[Type[Checker]]:
-    """
-    Return all registered Checker types.
+        for rule in model.indirect_business_rules:
+            if rules is not None and rule.__name__ not in rules:
+                continue

-    See ``checks.checks.BusinessRuleChecker.of``.
-    """
-    for rule in ALL_RULES:
-        yield BusinessRuleChecker.of(rule)
-        yield IndirectBusinessRuleChecker.of(rule)
+            for linked_model in rule.get_linked_models(model, tx):
+                model_rules[linked_model].add(rule)

-def applicable_to(model: TrackedModel) -> Iterator[Checker]:
-    """Return instances of any Checker classes applicable to the supplied
-    TrackedModel instance."""
-    for checker_type in checker_types():
-        yield from checker_type.checkers_for(model)
+        return model_rules
+
+
+# Checkers in priority list order, checkers for linked models come first.
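+#
+# A sketch of how callers consume this mapping (mirroring check_model in
+# checks/tasks.py), assuming a TrackedModel ``model`` and a ``transaction``
+# giving the context to check in:
+#
+#     for checker in ALL_CHECKERS.values():
+#         for target, rules in checker.get_model_rules(model).items():
+#             for rule in rules:
+#                 checker.apply_rule_cached(rule, transaction, target)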
+ALL_CHECKERS = { + "LinkedModelsBusinessRuleChecker": LinkedModelsBusinessRuleChecker, + "BusinessRuleChecker": BusinessRuleChecker, +} diff --git a/checks/migrations/0004_auto_20220718_1653.py b/checks/migrations/0004_auto_20220718_1653.py new file mode 100644 index 0000000000..76e3157e3c --- /dev/null +++ b/checks/migrations/0004_auto_20220718_1653.py @@ -0,0 +1,31 @@ +# Generated by Django 3.1.14 on 2022-07-18 16:53 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ("checks", "0003_auto_20220525_1046"), + ] + + operations = [ + migrations.RemoveField( + model_name="transactioncheck", + name="head_transaction", + ), + migrations.RemoveField( + model_name="transactioncheck", + name="latest_tracked_model", + ), + migrations.RemoveField( + model_name="transactioncheck", + name="transaction", + ), + migrations.DeleteModel( + name="TrackedModelCheck", + ), + migrations.DeleteModel( + name="TransactionCheck", + ), + ] diff --git a/checks/migrations/0005_trackedmodelcheck.py b/checks/migrations/0005_trackedmodelcheck.py new file mode 100644 index 0000000000..f471a2d07a --- /dev/null +++ b/checks/migrations/0005_trackedmodelcheck.py @@ -0,0 +1,53 @@ +# Generated by Django 3.1.14 on 2022-08-02 20:32 + +import django.db.models.deletion +from django.db import migrations +from django.db import models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("common", "0006_modelcelerytask_taskmodel"), + ("checks", "0004_auto_20220718_1653"), + ] + + operations = [ + migrations.CreateModel( + name="TrackedModelCheck", + fields=[ + ( + "taskmodel_ptr", + models.OneToOneField( + auto_created=True, + on_delete=django.db.models.deletion.CASCADE, + parent_link=True, + primary_key=True, + serialize=False, + to="common.taskmodel", + ), + ), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("check_name", models.CharField(max_length=255)), + ("successful", models.BooleanField()), + ("message", models.TextField(null=True)), + ("content_hash", models.BinaryField(max_length=32, null=True)), + ( + "model", + models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="checks", + to="common.trackedmodel", + ), + ), + ], + options={ + "unique_together": {("model", "check_name")}, + }, + bases=("common.taskmodel", models.Model), + ), + ] diff --git a/checks/models.py b/checks/models.py index f9680fa399..971e7648e8 100644 --- a/checks/models.py +++ b/checks/models.py @@ -1,136 +1,34 @@ +import logging + from django.db import models from django.db.models import fields +from polymorphic.managers import PolymorphicManager -from checks.querysets import TransactionCheckQueryset +from checks.querysets import TrackedModelCheckQueryset +from common.models import TimestampedMixin +from common.models.celerytask import TaskModel from common.models.trackedmodel import TrackedModel -from common.models.transactions import Transaction - - -class TransactionCheck(models.Model): - """ - Represents an in-progress or completed check of a transaction for - correctness. - - The ``TransactionCheck`` gets created once the check starts and has a flag - to track completeness. 
- """ - - transaction = models.ForeignKey( - Transaction, - on_delete=models.CASCADE, - related_name="checks", - ) - - completed = fields.BooleanField(default=False) - """True if all of the checks expected to be carried out against the models - in this transaction have recorded any result.""" - - successful = fields.BooleanField(null=True) - """ - True if all of the checks carried out against the models in this - transaction returned a positive result. - - This value will be null until ``completed`` is `True`. - """ - head_transaction_id: int - head_transaction = models.ForeignKey( - Transaction, - on_delete=models.CASCADE, - ) - """ - The latest transaction in the stream of approved transactions (i.e. in the - REVISION partition) at the moment this check was carried out. +logger = logging.getLogger(__name__) - Once new transactions are commited and the head transaction is no longer the - latest, this check will no longer be an accurate signal of correctness - because the new transactions could include new data which would invalidate - the checks. (Unless the checked transaction < head transaction, in which - case it will always be correct.) - """ - tracked_model_count = fields.PositiveSmallIntegerField() +class TrackedModelCheck(TimestampedMixin, TaskModel): """ - The number of tracked models in the transaction at the moment this check was - carried out. + Represents the result of running a single check against a single model. - If something is removed from the transaction later, the number of tracked - models will no longer match. This is used to detect if the check is now - stale. + Stores `content_hash`, a hash of the content for validity checking of the + stored result. """ - latest_tracked_model = models.ForeignKey( - TrackedModel, - on_delete=models.CASCADE, - null=True, - ) - """ - The latest tracked model in the transaction at the moment this check was - carried out. - - If some models are removed and subsequent ones added to the transaction, the - count may be the same but the latest transaction will have a new primary - key. This is used to detect if the check is now stale. - """ - - model_checks: models.QuerySet["TrackedModelCheck"] - - objects: TransactionCheckQueryset = models.Manager.from_queryset( - TransactionCheckQueryset, - )() - - def save(self, *args, **kwargs): - """Computes the metadata we will need later to detect if the check is - current and fresh.""" - if not self.head_transaction_id: - self.head_transaction = Transaction.approved.last() - - self.tracked_model_count = self.transaction.tracked_models.count() - self.latest_tracked_model = self.transaction.tracked_models.order_by( - "pk", - ).last() - - return super().save(*args, **kwargs) - class Meta: - ordering = ( - "transaction__partition", - "transaction__order", - "head_transaction__partition", - "head_transaction__order", - ) - - constraints = ( - models.CheckConstraint( - check=( - models.Q(completed=False, successful__isnull=True) - | models.Q(completed=True, successful__isnull=False) - ), - name="completed_checks_include_successfulness", - ), - ) - - -class TrackedModelCheck(models.Model): - """ - Represents the result of running a single check against a single model. - - The ``TrackedModelCheck`` only gets created once the check is complete, and - hence success should always be known. The reason is that a single model - check is atomic (i.e. there is no smaller structure) and so it's either done - or not, and it can't be "resumed". 
- """ + unique_together = ("model", "check_name") + objects = PolymorphicManager.from_queryset(TrackedModelCheckQueryset)() model = models.ForeignKey( TrackedModel, related_name="checks", - on_delete=models.CASCADE, - ) - - transaction_check = models.ForeignKey( - TransactionCheck, - on_delete=models.CASCADE, - related_name="model_checks", + on_delete=models.SET_NULL, + null=True, ) check_name = fields.CharField(max_length=255) @@ -141,3 +39,14 @@ class TrackedModelCheck(models.Model): message = fields.TextField(null=True) """The text content returned by the check, if any.""" + + content_hash = models.BinaryField(max_length=32, null=True) + """ + Hash of the content ('copyable_fields') at the time the data was checked. + """ + + def __str__(self): + if self.successful: + return f"{self.model} {self.check_name} [Passed at {self.updated_at}]" + + return f"{self.model} {self.check_name} [Failed at {self.updated_at}, Message: {self.message}]" diff --git a/checks/querysets.py b/checks/querysets.py index 63a132e891..a429e4c1e3 100644 --- a/checks/querysets.py +++ b/checks/querysets.py @@ -1,173 +1,32 @@ -from django.contrib.postgres.aggregates import BoolOr -from django.db import models -from django.db.models import expressions -from django.db.models.aggregates import Count -from django.db.models.aggregates import Max -from django_cte import CTEQuerySet -from django_cte import With +from django.db.transaction import atomic +from polymorphic.query import PolymorphicQuerySet -from common.models.transactions import Transaction -from common.models.transactions import TransactionPartition -from common.models.utils import LazyTransaction -latest_transaction = LazyTransaction(get_value=Transaction.approved.last) - - -class TransactionCheckQueryset(CTEQuerySet): - currentness_filter = ( - # If the head transaction is ahead of the latest transaction then no new - # transactions have been committed since the check. In practice we only - # expect the head_transaction == latest_transaction but it doesn't hurt - # to be more defensive with a greater than check. - # - # head_transaction >= latest_transaction - models.Q( - head_transaction__partition__gt=latest_transaction.partition, - ) - | models.Q( - head_transaction__partition=latest_transaction.partition, - head_transaction__order__gte=latest_transaction.order, - ) - # If the head transaction was ahead of the checked transaction when the - # check was carried out, and the checked transaction is approved, we - # don't need to update anything because subsequent changes can't affect - # the older transaction. - # - # head_transaction >= checked_transaction AND checked_transaction is - # approved - | models.Q( - head_transaction__partition__gt=models.F("transaction__partition"), - transaction__partition__in=TransactionPartition.approved_partitions(), - ) - | models.Q( - head_transaction__partition=models.F("transaction__partition"), - head_transaction__order__gte=models.F("transaction__order"), - transaction__partition__in=TransactionPartition.approved_partitions(), - ) - ) - - freshness_fields = { - # See the field descriptions on ``TransactionCheck`` for details on - # how these fields are populated and used to calculate freshness. 
- "tracked_model_count": Count("transaction__tracked_models"), - "latest_tracked_model": Max("transaction__tracked_models__id"), - } - - freshness_annotations = { - f"real_{field}": expr for field, expr in freshness_fields.items() - } - - freshness_filter = ( - # Use the metadata on the transaction check to work out if the check - # still represents the data in the transaction. The "real_" fields are - # expected to be annotated onto the queryset and represent the current - # state of the transaction. - # - # A fresh check is where all the current values match the stored values. - models.Q(**{field: models.F(f"real_{field}") for field in freshness_fields}) - # ...or where there are no models to check, which is valid. - | models.Q( - real_latest_tracked_model__isnull=True, - latest_tracked_model__isnull=True, - ) - ) - - requires_update_filter = (~freshness_filter) | (~currentness_filter) - - requires_update_annotation = expressions.ExpressionWrapper( - expression=requires_update_filter, - output_field=models.fields.BooleanField(), - ) - - def current(self): +class TrackedModelCheckQueryset(PolymorphicQuerySet): + def delete(self): """ - A ``TransactionCheck`` is considered "current" if there hasn't been any - data added after the check that could change the result of the check. + Delete, modified to workaround a python bug that stops delete from + working when some fields are ByteFields. - If the checked transaction is in a draft partition, "current" means no - new transactions have been approved since the check was carried out. If - any have, they will now potentially be in scope of the check. + Details: - If the checked transaction is in an approved partition, "current" means - no transactions were approved between the check happening and the - transaction being committed to the approved partition (but some may have - been added after it, which can't affect its result). - """ - return self.filter(self.currentness_filter) + Using .delete() on a query with ByteFields does not work due to a python bug: + https://github.com/python/cpython/issues/95081 + >>> TrackedModelCheck.objects.filter( + model__transaction__workbasket=workbasket_pk, + ).delete() - def fresh(self): - """ - A ``TransactionCheck`` is considered "fresh" if the transaction that it - checked hasn't been modified since the check was carried out, which - could change the result of the check. + File /usr/local/lib/python3.8/copy.py:161, in deepcopy(x, memo, _nil) + 159 reductor = getattr(x, "__reduce_ex__", None) + 160 if reductor is not None: + --> 161 rv = reductor(4) + 162 else: + 163 reductor = getattr(x, "__reduce__", None) - The ``tracked_model_count`` and ``latest_tracked_model`` of the checked - transaction are cached on the check and used to detect this. - """ - return self.annotate(**self.freshness_annotations).filter(self.freshness_filter) - - def stale(self): - """A ``TransactionCheck`` is considered "stale" if the transaction that - it checked has been modified since the check was carried out, which - could change the result of the check.""" - return self.annotate(**self.freshness_annotations).exclude( - self.freshness_filter, - ) - - def requires_update(self, requirement=True, include_archived=False): - """ - A ``TransactionCheck`` requires an update if it or any check on a - transaction before it in order is stale or no longer current. 
- - If a ``TransactionCheck`` on an earlier transaction is stale, it means - that transaction has been modified since the check was done, which could - also invalidate any checks of any subsequent transactions. + TypeError: cannot pickle 'memoryview' object - By default transactions in ARCHIVED workbaskets are ignored, since these - workbaskets exist outside of the normal workflow. + Work around this by setting the bytefields to None and then calling delete. """ - - if include_archived: - ignore_filter = {} - else: - ignore_filter = {"transaction__workbasket__status": "ARCHIVED"} - - # First filtering out any objects we should ignore, - # work out for each check whether it alone requires an update, by - # seeing whether it is stale or not current. - basic_info = With( - self.model.objects.exclude(**ignore_filter) - .annotate(**self.freshness_annotations) - .annotate( - requires_update=self.requires_update_annotation, - ), - name="basic_info", - ) - - # Now cascade that result down to any subsequent transactions: if a - # transaction in the same workbasket comes later, then it will also - # require an update. TODO: do stale transactions pollute the update - # check for ever? - sequence_info = With( - basic_info.join(self.model.objects.all(), pk=basic_info.col.pk).annotate( - requires_update=expressions.Window( - expression=BoolOr(basic_info.col.requires_update), - partition_by=models.F("transaction__workbasket"), - order_by=[ - models.F("transaction__order").asc(), - models.F("pk").desc(), - ], - ), - ), - name="sequence_info", - ) - - # Now filter for only the type that we want: checks that either do or do - # not require an update. - return ( - sequence_info.join(self, pk=sequence_info.col.pk) - .with_cte(basic_info) - .with_cte(sequence_info) - .annotate(requires_update=sequence_info.col.requires_update) - .filter(requires_update=requirement) - ) + with atomic(): + self.update(content_hash=None) + return super().delete() diff --git a/checks/tasks.py b/checks/tasks.py index dd7d0981c1..f3808dd58c 100644 --- a/checks/tasks.py +++ b/checks/tasks.py @@ -1,172 +1,267 @@ -from itertools import cycle - -from celery import group +""" +Celery tasks and workflow. + +Build a workflow of tasks in one go and to pass to celery. +""" +import logging +from typing import Dict +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple + +from celery import chain +from celery import chord from celery.utils.log import get_task_logger -from checks.checks import applicable_to -from checks.models import TransactionCheck +from checks.checks import ALL_CHECKERS +from checks.checks import Checker +from checks.models import TrackedModelCheck +from common.business_rules import ALL_RULES +from common.business_rules import BusinessRule from common.celery import app +from common.models.celerytask import ModelCeleryTask +from common.models.celerytask import bind_model_task from common.models.trackedmodel import TrackedModel from common.models.transactions import Transaction -from common.models.transactions import TransactionPartition -from common.models.utils import override_current_transaction +from common.models.utils import get_current_transaction +from workbaskets.models import WorkBasket # Celery logger adds the task id and status and outputs via the worker. 
logger = get_task_logger(__name__)

+# Types for passing over celery
+CheckerModelRule = Tuple[Checker, TrackedModel, Sequence[BusinessRule]]
+"""CheckerModelRule stores a checker, model, and a sequence of rules to apply to it."""

-@app.task(time_limit=60)
-def check_model(trackedmodel_id: int, context_id: int):
-    """
-    Runs all of the applicable checkers on the passed model ID, and records the
-    results.
+ModelPKInterval = Tuple[int, int]
+"""ModelPKInterval is a tuple of (first_pk, last_pk) referring to a contiguous range of TrackedModels"""

-    Model checks are expected (from observation) to be short – on the order of
-    6-7 seconds max. So if the check is taking considerably longer than this, it
-    is probably broken and should be killed to free up the worker.
-    """
+TaskInfo = Tuple[int, str]
+"""TaskInfo is a tuple of (task_id, task_name) which can be used to create a ModelCeleryTask."""

-    model: TrackedModel = TrackedModel.objects.get(pk=trackedmodel_id)
-    context: TransactionCheck = TransactionCheck.objects.get(pk=context_id)
-    transaction = context.transaction
-
-    with override_current_transaction(transaction):
-        for check in applicable_to(model):
-            if not context.model_checks.filter(
-                model=model,
-                check_name=check.name,
-            ).exists():
-                # Run the checker on the model and record the result. (This is
-                # not Celery ``apply`` but ``Checker.apply``).
-                check.apply(model, context)
-
-
-@app.task
-def is_transaction_check_complete(check_id: int) -> bool:
-    """Checks and returns whether the given transaction check is complete, and
-    records the success if so."""
-
-    check: TransactionCheck = TransactionCheck.objects.get(pk=check_id)
-    check.completed = True
-
-    with override_current_transaction(check.transaction):
-        for model in check.transaction.tracked_models.all():
-            applicable_checks = set(check.name for check in applicable_to(model))
-            performed_checks = set(
-                check.model_checks.filter(model=model).values_list(
-                    "check_name",
-                    flat=True,
-                ),
-            )
-            if applicable_checks != performed_checks:
-                check.completed = False
-                break
-
-    if check.completed:
-        check.successful = not check.model_checks.filter(successful=False).exists()
-        logger.info("Completed checking %s", check.transaction.summary)
+def get_checker_model_rules(
+    models: Sequence[TrackedModel],
+    rule_names: Optional[Set[str]] = None,
+):
+    """
+    Generator yielding (checker, model, rules) triples.

-    check.save()
-    return check.completed
+    Given a sequence of models, yield a triple for each checker in
+    ALL_CHECKERS describing which rules that checker should apply to
+    which model.
+    """

-def setup_or_resume_transaction_check(transaction: Transaction):
-    """Return a current, fresh transaction check for the passed transaction ID
-    and a list of model IDs that need to be checked."""
+    for model in models:
+        for checker in ALL_CHECKERS.values():
+            yield from (
+                (checker, checker_model, checker_rules)
+                for checker_model, checker_rules in checker.get_model_rules(
+                    model,
+                    rule_names,
+                ).items()
+            )

-    head_transaction = Transaction.approved.last()

-    existing_checks = TransactionCheck.objects.filter(
-        transaction=transaction,
-        head_transaction=head_transaction,
-    )
+@app.task(trail=True)
+def check_model(
+    transaction_pk: int,
+    model_pk: int,
+    rule_names: Optional[Sequence[str]] = None,
+    bind_to_task_kwargs: Optional[Dict] = None,
+):
+    """
+    Task to check one model against its applicable business rules and record the results.
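+
+    Returns True only if every rule that was run against the model passed.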
-    up_to_date_check = existing_checks.requires_update(False).filter(completed=True)
-    if up_to_date_check.exists():
-        return up_to_date_check.get(), []
+    As this is a celery task, parameters are in base formats that can be serialised, such as int and str.

-    context = existing_checks.requires_update(False).filter(completed=False).last()
-    if context is None:
-        context = TransactionCheck(
-            transaction=transaction,
-            head_transaction=head_transaction,
-        )
-        context.save()
+    This task is called as part of the check_models workflow.

-    return (
-        context,
-        transaction.tracked_models.values_list("pk", flat=True),
-    )
+    By setting bind_to_task_kwargs, the results will be bound to the celery
+    task identified by those kwargs; this is useful for tracking the progress
+    of the parent task, and cancelling it if needed.
+    """
+    # TODO: re-add note on timings from Simon's original code.
+
+    if rule_names is None:
+        rule_names = set(ALL_RULES.keys())
+
+    assert set(ALL_RULES.keys()).issuperset(rule_names)
+
+    transaction = Transaction.objects.get(pk=transaction_pk)
+    model = TrackedModel.objects.get(pk=model_pk)
+    successful = True
+
+    for checker in ALL_CHECKERS.values():
+        for checker_model, rules in checker.get_model_rules(model, rule_names).items():
+            # get_model_rules may return a different model to check (e.g. a
+            # linked model for LinkedModelsBusinessRuleChecker), so use
+            # checker_model rather than model.
+            for rule in rules:
+                logger.debug(
+                    "%s rule: %s, tx: %s, model: %s",
+                    checker.__name__,
+                    rule,
+                    transaction,
+                    model,
+                )
+                check_result = checker.apply_rule_cached(
+                    rule,
+                    transaction,
+                    checker_model,
+                )
+                if bind_to_task_kwargs:
+                    logger.debug(
+                        "Binding result %s to task. bind_to_task_kwargs: %s",
+                        check_result.pk,
+                        bind_to_task_kwargs,
+                    )
+                    bind_model_task(check_result, **bind_to_task_kwargs)
+
+                logger.info(
+                    "Ran check %s %s",
+                    check_result,
+                    "✅" if check_result.successful else "❌",
+                )
+                successful &= check_result.successful
+
+    return successful
+
+
+@app.task(trail=True)
+def check_models_workflow(
+    pk_intervals: Sequence[ModelPKInterval],
+    bind_to_task_kwargs: Optional[Dict] = None,
+    rules: Optional[Sequence[str]] = None,
+):
+    """
+    Celery workflow that runs the applicable rules from each checker on the
+    supplied models in parallel, as a chord of 'check_model' tasks.
+
+    If rules is None, all applicable rules are run
+    (see get_model_rules).

-@app.task(bind=True)
-def check_transaction(self, transaction_id: int):
-    """Run and record checks for the passed transaction ID, asynchronously."""
+    Models checked will be the exact model versions passed in;
+    this is useful for caching checks, e.g. those of linked_models
+    where an older model is referenced.

-    transaction = Transaction.objects.get(pk=transaction_id)
-    check, model_ids = setup_or_resume_transaction_check(transaction)
-    if check.completed and not any(model_ids):
-        logger.debug(
-            "Skipping check of %s because an up-to-date check already exists",
-            transaction.summary,
+    Callers should ensure models passed in are the correct version,
+    e.g. by using override_current_transaction.
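+
+    The chord's callback, unbind_model_tasks, removes the ModelCeleryTask
+    rows tracking the checks once every check_model task has finished.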
+ """ + logger.debug("Build check_models_workflow") + + models = TrackedModel.objects.from_pk_intervals(*pk_intervals) + if logger.isEnabledFor(logging.DEBUG): + logger.debug("Got %s models", models.count()) + + return chord( + check_model.si( + model.transaction.pk, + model.pk, + rules, + bind_to_task_kwargs, ) - return - - # Create a workflow: firstly run all of the model checks (in parallel) and - # then once they are all done see if the transaction check is now complete. - logger.info("Beginning check of %s", transaction.summary) - workflow = group( - check_model.si(*args) for args in zip(model_ids, cycle([check.pk])) - ) | is_transaction_check_complete.si(check.pk) - - # Execute the workflow by replacing this task with it. - return self.replace(workflow) + for model in models + )(unbind_model_tasks.si([bind_to_task_kwargs["celery_task_id"]])) + + +@app.task(trail=True) +def cancel_workbasket_checks(workbasket_pk: int): + """Find existing celery tasks and, revoke them ande delete the + ModelCeleryTask objects tracking them.""" + celery_tasks = ( + ModelCeleryTask.objects.filter(celery_task_name="check_workbasket") + .update_task_statuses() + .filter_by_task_kwargs(workbasket_pk=workbasket_pk) + ) + # Terminate the existing tasks, using SIGUSR1 which triggers the soft timeout handler. + celery_tasks.revoke(terminate=True, signal="SIGUSR1") -def check_transaction_sync(transaction: Transaction): +@app.task(trail=True) +def get_workbasket_model_pk_intervals(workbasket_pk: int): """ - Run and record checks for the passed transaction ID, syncronously. + Return a list of all models in the workbasket. - This method will run all of the checks one after the other and won't return - until they are complete. This is useful for testing and debugging. + Ordinarily this step doesn't take very long, though for the seed workbasket + of 9 million items it may take around 6 seconds (measured on a consumer + laptop [Ryzen 2500u, 32gb ram]). """ - check, model_ids = setup_or_resume_transaction_check(transaction) - if check.completed and not any(model_ids): - logger.debug( - "Skipping check of transaction %s " - "because an up-to-date check already exists", - transaction.pk, - ) - else: - logger.info("Beginning synchronous check of %s", transaction.summary) - for model_id in model_ids: - check_model(model_id, check.pk) - is_transaction_check_complete(check.pk) - - -@app.task(bind=True, rate_limit="1/m") -def update_checks(self): + workbasket = WorkBasket.objects.get(pk=workbasket_pk) + pks = [*workbasket.tracked_models.as_pk_intervals()] + return pks + + +@app.task(trail=True) +def unbind_model_tasks(task_ids: Sequence[str]): + """Called at the end of a workflow, as there is no ongoing celery task + associated with this data.""" + logger.debug("Task_ids: [%s]", task_ids) + deleted = ModelCeleryTask.objects.filter(celery_task_id__in=task_ids).delete() + logger.debug("Deleted %s ModelCeleryTask objects", deleted[0]) + + +@app.task(bind=True, trail=True) +def check_workbasket( + self, + workbasket_pk: int, + current_transaction_pk: Optional[int] = None, + rules: Optional[Sequence[str]] = None, + clear_cache=False, +): """ - Triggers checking for any transaction that requires an update. - - A rate limit is specified here to mitigate instances where this - task stacks up and prevents other tasks from running by monopolising - the worker. - - TODO: Ensure this task is *not* stacking up and blocking the worker! + Orchestration task, that kicks off a workflow to check all models in the + workbasket. 
+
+    Cancels any existing checks of this workbasket that are still running.
+    The system caches check results, which helps with overlapping checks,
+    and cancelling existing checks keeps the celery queue clear of stale
+    tasks, making the system easier to manage under load.
+
+    :param workbasket_pk: pk of the workbasket to check
+    :param current_transaction_pk: pk of the current transaction, defaults to the current highest transaction
+    :param rules: specify rule names to check (defaults to ALL_RULES)  [mostly for testing/debugging]
+    :param clear_cache: clear the cache before checking  [mostly for testing/debugging]
     """
-
-    ids_require_update = (
-        Transaction.objects.exclude(
-            pk__in=TransactionCheck.objects.requires_update(False).values(
-                "transaction__pk",
-            ),
-        )
-        .filter(partition=TransactionPartition.DRAFT)
-        .values_list("pk", flat=True)
+    logger.debug(
+        "check_workbasket, workbasket_pk: %s, current_transaction_pk %s, clear_cache %s",
+        workbasket_pk,
+        current_transaction_pk,
+        clear_cache,
     )

-    # Execute a check for each transaction that requires an update by replacing
-    # this task with a parallel workflow.
-    return self.replace(group(check_transaction.si(id) for id in ids_require_update))
+    if clear_cache:
+        # Clearing the cache should not be needed in the usual workflow, but may be useful e.g. if
+        # business rules are updated and need to be re-run.
+        TrackedModelCheck.objects.filter(
+            model__transaction__workbasket__pk=workbasket_pk,
+        ).delete()
+
+    if current_transaction_pk is None:
+        current_transaction_pk = (
+            get_current_transaction() or Transaction.objects.last()
+        ).pk
+
+    # Use bind_to_task_kwargs to associate this task and its subtasks while
+    # the task is running, allowing them to be revoked if the underlying data
+    # changes or another copy of the task is started.
+    #
+    # get_workbasket_model_pk_intervals gets tuples of (first_pk, last_pk), a compact form to
+    # represent the trackedmodels in the workbasket, which is passed on to the subtasks.
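+    #
+    # Note: check_models_workflow is chained with .s() rather than .si(), so
+    # it also receives the pk intervals returned by the previous task,
+    # get_workbasket_model_pk_intervals, as its first positional argument.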
+    return chain(
+        cancel_workbasket_checks.si(workbasket_pk),
+        get_workbasket_model_pk_intervals.si(workbasket_pk),
+        check_models_workflow.s(
+            bind_to_task_kwargs={
+                "celery_task_id": self.request.id,
+                "celery_task_name": "check_workbasket",
+            },
+            rules=rules,
+        ),
+    )()
+
+
+def check_workbasket_sync(workbasket: WorkBasket, clear_cache: bool = False):
+    # Run the celery task and wait for it to complete.
+    tx = get_current_transaction()
+    result = check_workbasket.delay(workbasket.pk, tx.pk, clear_cache=clear_cache)
+    result.wait()
diff --git a/checks/tests/factories.py b/checks/tests/factories.py
index 46c247d949..584413c590 100644
--- a/checks/tests/factories.py
+++ b/checks/tests/factories.py
@@ -1,12 +1,8 @@
 from dataclasses import dataclass
 from typing import Optional
 
-import factory
-
-from checks import models
 from checks.checks import Checker
 from common.models.trackedmodel import TrackedModel
-from common.tests import factories
 
 
 @dataclass(frozen=True)
@@ -19,71 +15,71 @@
     def run(self, model: TrackedModel):
         return self.success, self.message
 
 
-class TransactionCheckFactory(factory.django.DjangoModelFactory):
-    class Meta:
-        model = models.TransactionCheck
-
-    transaction = factory.SubFactory(
-        factories.TransactionFactory,
-        draft=True,
-    )
-    completed = True
-    successful = True
-    head_transaction = factory.SubFactory(factories.ApprovedTransactionFactory)
-    tracked_model_count = factory.LazyAttribute(
-        lambda check: (len(check.transaction.tracked_models.all())),
-    )
-    latest_tracked_model = factory.SubFactory(
-        factories.TestModel1Factory,
-        transaction=factory.SelfAttribute("..transaction"),
-    )
-
-    class Params:
-        incomplete = factory.Trait(
-            completed=False,
-            successful=None,
-        )
-
-        empty = factory.Trait(
-            latest_tracked_model=None,
-            tracked_model_count=0,
-        )
-
-
-class StaleTransactionCheckFactory(TransactionCheckFactory):
-    class Meta:
-        exclude = ("first", "second")
-
-    first = factory.SubFactory(
-        factories.TestModel1Factory,
-        transaction=factory.SelfAttribute("..transaction"),
-    )
-    second = factory.SubFactory(
-        factories.TestModel1Factory,
-        transaction=factory.SelfAttribute("..transaction"),
-    )
-
-    latest_tracked_model = factory.SelfAttribute("second")
-
-    @classmethod
-    def _after_postgeneration(cls, instance: TrackedModel, create, results=None):
-        """Save again the instance if creating and at least one hook ran."""
-        super()._after_postgeneration(instance, create, results)
-
-        if create:
-            assert instance.transaction.tracked_models.count() >= 2
-            instance.transaction.tracked_models.first().delete()
-
-
-class TrackedModelCheckFactory(factory.django.DjangoModelFactory):
-    class Meta:
-        model = models.TrackedModelCheck
-
-    model = factory.SubFactory(
-        factories.TestModel1Factory,
-        transaction=factory.SelfAttribute("..transaction_check.transaction"),
-    )
-    transaction_check = factory.SubFactory(TransactionCheckFactory)
-    check_name = factories.string_sequence()
-    successful = True
-    message = None
+# class TransactionCheckFactory(factory.django.DjangoModelFactory):
+#     class Meta:
+#         model = models.TransactionCheck
+#
+#     transaction = factory.SubFactory(
+#         factories.TransactionFactory,
+#         draft=True,
+#     )
+#     completed = True
+#     successful = True
+#     head_transaction = factory.SubFactory(factories.ApprovedTransactionFactory)
+#     tracked_model_count = factory.LazyAttribute(
+#         lambda check: (len(check.transaction.tracked_models.all())),
+#     )
+#     latest_tracked_model = factory.SubFactory(
+#         factories.TestModel1Factory,
+#         transaction=factory.SelfAttribute("..transaction"),
+#     )
+#
+#     class 
Params: +# incomplete = factory.Trait( +# completed=False, +# successful=None, +# ) +# +# empty = factory.Trait( +# latest_tracked_model=None, +# tracked_model_count=0, +# ) +# +# +# class StaleTransactionCheckFactory(TransactionCheckFactory): +# class Meta: +# exclude = ("first", "second") +# +# first = factory.SubFactory( +# factories.TestModel1Factory, +# transaction=factory.SelfAttribute("..transaction"), +# ) +# second = factory.SubFactory( +# factories.TestModel1Factory, +# transaction=factory.SelfAttribute("..transaction"), +# ) +# +# latest_tracked_model = factory.SelfAttribute("second") +# +# @classmethod +# def _after_postgeneration(cls, instance: TrackedModel, create, results=None): +# """Save again the instance if creating and at least one hook ran.""" +# super()._after_postgeneration(instance, create, results) +# +# if create: +# assert instance.transaction.tracked_models.count() >= 2 +# instance.transaction.tracked_models.first().delete() +# +# +# class TrackedModelCheckFactory(factory.django.DjangoModelFactory): +# class Meta: +# model = models.TrackedModelCheck +# +# model = factory.SubFactory( +# factories.TestModel1Factory, +# transaction=factory.SelfAttribute("..transaction_check.transaction"), +# ) +# transaction_check = factory.SubFactory(TransactionCheckFactory) +# check_name = factories.string_sequence() +# successful = True +# message = None diff --git a/checks/tests/test_checkers.py b/checks/tests/test_checkers.py index b61738c027..fc70af1364 100644 --- a/checks/tests/test_checkers.py +++ b/checks/tests/test_checkers.py @@ -4,8 +4,9 @@ import checks.tests.factories from checks.checks import BusinessRuleChecker -from checks.checks import IndirectBusinessRuleChecker -from checks.checks import checker_types +from checks.checks import LinkedModelsBusinessRuleChecker + +# from checks.checks import checker_types # TODO from common.tests import factories from common.tests.util import TestRule from common.tests.util import add_business_rules @@ -61,10 +62,10 @@ def test_indirect_business_rule_validation(): TestRule, indirect=True, ): - checker_type = IndirectBusinessRuleChecker.of(TestRule) + checker_type = LinkedModelsBusinessRuleChecker.of(TestRule) # Verify the cache returns the same object if .of is called a second time. 
- assert checker_type is IndirectBusinessRuleChecker.of(TestRule) + assert checker_type is LinkedModelsBusinessRuleChecker.of(TestRule) checkers = checker_type.checkers_for(model) diff --git a/checks/tests/test_tasks.py b/checks/tests/test_tasks.py index c359332472..12aaf5b293 100644 --- a/checks/tests/test_tasks.py +++ b/checks/tests/test_tasks.py @@ -6,12 +6,7 @@ from pytest_django.asserts import assertQuerysetEqual # type: ignore from checks import tasks -from checks.models import TransactionCheck from checks.tests import factories -from checks.tests.util import assert_requires_update -from common.models.transactions import TransactionPartition -from common.tests import factories as common_factories -from workbaskets.validators import WorkflowStatus pytestmark = pytest.mark.django_db @@ -91,130 +86,130 @@ def test_model_checking(check): assert check.model_checks.filter(successful=True).count() == num_successful -def test_completion_of_transaction_checks(check): - check, num_checks, num_completed, num_successful = check - expect_completed = num_completed == num_checks - expect_successful = (num_successful == num_checks) if expect_completed else None - - complete = tasks.is_transaction_check_complete(check.id) - assert complete == expect_completed - - check.refresh_from_db() - assert check.completed == expect_completed - assert check.successful == expect_successful - - -@pytest.mark.parametrize("check_already_exists", (True, False)) -def test_checking_of_transaction(check, check_already_exists): - check, num_checks, num_completed, num_successful = check - expect_completed = num_completed == num_checks - expect_successful = (num_successful == num_checks) if expect_completed else None - if expect_completed: - check.completed = True - check.successful = expect_successful - check.save() - - transaction = check.transaction - if not check_already_exists: - check.delete() - - # The task will replace itself with a new workflow. Testing this is hard. - # Instead, we will capture the new workflow and assert it is calling the - # right things. This is brittle but probably better than nothing. - with mock.patch("celery.app.task.Task.replace", new=lambda _, t: t): - workflow = tasks.check_transaction(transaction.id) # type: ignore - - check = TransactionCheck.objects.filter(transaction=transaction).get() - if expect_completed and check_already_exists: - # If the check is already done, it should be skipped. - assert workflow is None - else: - # If checks need to happen, the workflow should have one check task per - # model and finish with a decide task. 
- assert transaction.tracked_models.count() == len(workflow.tasks) - model_ids = set(transaction.tracked_models.values_list("id", flat=True)) - for task in workflow.tasks: - model_id, context_id = task.args - model_ids.remove(model_id) - assert task.task == tasks.check_model.name - assert context_id == check.id - - assert workflow.body.task == tasks.is_transaction_check_complete.name - assert workflow.body.args[0] == check.id - - -def test_detecting_of_transactions_to_update(): - head_transaction = common_factories.ApprovedTransactionFactory.create() - - # Transaction with no check - no_check = common_factories.UnapprovedTransactionFactory.create() - - # Transaction that does not require update - no_update = factories.TransactionCheckFactory.create( - head_transaction=head_transaction, - ) - assert_requires_update(no_update, False) - - # Transaction that requires update in DRAFT - draft_update = factories.StaleTransactionCheckFactory.create( - transaction__partition=TransactionPartition.DRAFT, - head_transaction=head_transaction, - ) - assert_requires_update(draft_update, True) - - # Transaction that requires update in REVISION - revision_update = factories.StaleTransactionCheckFactory.create( - transaction__partition=TransactionPartition.REVISION, - transaction__order=-(head_transaction.order), - head_transaction=head_transaction, - ) - assert_requires_update(revision_update, True) - - expected_transaction_ids = {no_check.id, draft_update.transaction.id} - - # The task will replace itself with a new workflow. Testing this is hard. - # Instead, we will capture the new workflow and assert it is calling the - # right things. This is brittle but probably better than nothing. - with mock.patch("celery.app.task.Task.replace", new=lambda _, t: t): - workflow = tasks.update_checks() # type: ignore - - assert set(t.task for t in workflow.tasks) == {tasks.check_transaction.name} - assert set(t.args[0] for t in workflow.tasks) == expected_transaction_ids - - -@pytest.mark.parametrize("include_archived", [True, False]) -@pytest.mark.parametrize( - "transaction_partition", [TransactionPartition.DRAFT, TransactionPartition.REVISION] -) -def test_archived_workbasket_checks(include_archived, transaction_partition): - """ - Verify transactions in ARCHIVED workbaskets do not require checking unless - include_archived is True. - """ - head_transaction = common_factories.ApprovedTransactionFactory.create() - - # Transaction that requires update in DRAFT or REVISION - transaction_check = factories.StaleTransactionCheckFactory.create( - transaction__partition=transaction_partition, - head_transaction=head_transaction, - ) - - all_checks = TransactionCheck.objects.filter(pk=transaction_check.pk) - initial_require_update = all_checks.requires_update(True, include_archived) - - # Initially the transaction should require update. 
- assert initial_require_update.count() == 1 - assert initial_require_update.get().pk == transaction_check.pk - - # Set workbasket status to ARCHIVED and verify requires_update only returns their transaction checks if - # include_archived is True - transaction_check.transaction.workbasket.status = WorkflowStatus.ARCHIVED - transaction_check.transaction.workbasket.save() - - checks_require_update = all_checks.requires_update(True, include_archived) - - if include_archived: - assert checks_require_update.count() == 1 - assert checks_require_update.get().pk == transaction_check.pk - else: - assert checks_require_update.count() == 0 +# def test_completion_of_transaction_checks(check): +# check, num_checks, num_completed, num_successful = check +# expect_completed = num_completed == num_checks +# expect_successful = (num_successful == num_checks) if expect_completed else None +# +# complete = tasks.is_transaction_check_complete(check.id) +# assert complete == expect_completed +# +# check.refresh_from_db() +# assert check.completed == expect_completed +# assert check.successful == expect_successful +# +# +# @pytest.mark.parametrize("check_already_exists", (True, False)) +# def test_checking_of_transaction(check, check_already_exists): +# check, num_checks, num_completed, num_successful = check +# expect_completed = num_completed == num_checks +# expect_successful = (num_successful == num_checks) if expect_completed else None +# if expect_completed: +# check.completed = True +# check.successful = expect_successful +# check.save() +# +# transaction = check.transaction +# if not check_already_exists: +# check.delete() +# +# # The task will replace itself with a new workflow. Testing this is hard. +# # Instead, we will capture the new workflow and assert it is calling the +# # right things. This is brittle but probably better than nothing. +# with mock.patch("celery.app.task.Task.replace", new=lambda _, t: t): +# workflow = tasks.check_transaction(transaction.id) # type: ignore +# +# check = TransactionCheck.objects.filter(transaction=transaction).get() +# if expect_completed and check_already_exists: +# # If the check is already done, it should be skipped. +# assert workflow is None +# else: +# # If checks need to happen, the workflow should have one check task per +# # model and finish with a decide task. 
+# assert transaction.tracked_models.count() == len(workflow.tasks) +# model_ids = set(transaction.tracked_models.values_list("id", flat=True)) +# for task in workflow.tasks: +# model_id, context_id = task.args +# model_ids.remove(model_id) +# assert task.task == tasks.check_model.name +# assert context_id == check.id +# +# assert workflow.body.task == tasks.is_transaction_check_complete.name +# assert workflow.body.args[0] == check.id +# +# +# def test_detecting_of_transactions_to_update(): +# head_transaction = common_factories.ApprovedTransactionFactory.create() +# +# # Transaction with no check +# no_check = common_factories.UnapprovedTransactionFactory.create() +# +# # Transaction that does not require update +# no_update = factories.TransactionCheckFactory.create( +# head_transaction=head_transaction, +# ) +# assert_requires_update(no_update, False) +# +# # Transaction that requires update in DRAFT +# draft_update = factories.StaleTransactionCheckFactory.create( +# transaction__partition=TransactionPartition.DRAFT, +# head_transaction=head_transaction, +# ) +# assert_requires_update(draft_update, True) +# +# # Transaction that requires update in REVISION +# revision_update = factories.StaleTransactionCheckFactory.create( +# transaction__partition=TransactionPartition.REVISION, +# transaction__order=-(head_transaction.order), +# head_transaction=head_transaction, +# ) +# assert_requires_update(revision_update, True) +# +# expected_transaction_ids = {no_check.id, draft_update.transaction.id} +# +# # The task will replace itself with a new workflow. Testing this is hard. +# # Instead, we will capture the new workflow and assert it is calling the +# # right things. This is brittle but probably better than nothing. +# with mock.patch("celery.app.task.Task.replace", new=lambda _, t: t): +# workflow = tasks.update_checks() # type: ignore +# +# assert set(t.task for t in workflow.tasks) == {tasks.check_transaction.name} +# assert set(t.args[0] for t in workflow.tasks) == expected_transaction_ids +# +# +# @pytest.mark.parametrize("include_archived", [True, False]) +# @pytest.mark.parametrize( +# "transaction_partition", [TransactionPartition.DRAFT, TransactionPartition.REVISION] +# ) +# def test_archived_workbasket_checks(include_archived, transaction_partition): +# """ +# Verify transactions in ARCHIVED workbaskets do not require checking unless +# include_archived is True. +# """ +# head_transaction = common_factories.ApprovedTransactionFactory.create() +# +# # Transaction that requires update in DRAFT or REVISION +# transaction_check = factories.StaleTransactionCheckFactory.create( +# transaction__partition=transaction_partition, +# head_transaction=head_transaction, +# ) +# +# all_checks = TransactionCheck.objects.filter(pk=transaction_check.pk) +# initial_require_update = all_checks.requires_update(True, include_archived) +# +# # Initially the transaction should require update. 
+# assert initial_require_update.count() == 1 +# assert initial_require_update.get().pk == transaction_check.pk +# +# # Set workbasket status to ARCHIVED and verify requires_update only returns their transaction checks if +# # include_archived is True +# transaction_check.transaction.workbasket.status = WorkflowStatus.ARCHIVED +# transaction_check.transaction.workbasket.save() +# +# checks_require_update = all_checks.requires_update(True, include_archived) +# +# if include_archived: +# assert checks_require_update.count() == 1 +# assert checks_require_update.get().pk == transaction_check.pk +# else: +# assert checks_require_update.count() == 0 diff --git a/checks/tests/util.py b/checks/tests/util.py index 4da60b856b..aa2a7a1301 100644 --- a/checks/tests/util.py +++ b/checks/tests/util.py @@ -1,6 +1,6 @@ from pytest_django.asserts import assertQuerysetEqual # type: ignore -from checks.models import TransactionCheck +# from checks.models import TransactionCheck # TODO def assert_queryset(queryset, expected): diff --git a/commodities/tests/test_business_rules.py b/commodities/tests/test_business_rules.py index 19be75982b..fede86d18e 100644 --- a/commodities/tests/test_business_rules.py +++ b/commodities/tests/test_business_rules.py @@ -1,7 +1,7 @@ import pytest from django.db import DataError -from checks.tasks import check_transaction_sync +# from checks.tasks import check_transaction_sync # TODO from commodities import business_rules from common.business_rules import BusinessRuleViolation from common.tests import factories diff --git a/common/migrations/0006_modelcelerytask_taskmodel.py b/common/migrations/0006_modelcelerytask_taskmodel.py new file mode 100644 index 0000000000..395b31a2c8 --- /dev/null +++ b/common/migrations/0006_modelcelerytask_taskmodel.py @@ -0,0 +1,92 @@ +# Generated by Django 3.1.14 on 2022-08-02 20:32 + +import django.db.models.deletion +from django.db import migrations +from django.db import models + + +class Migration(migrations.Migration): + + dependencies = [ + ("contenttypes", "0002_remove_content_type_name"), + ("common", "0005_transaction_index"), + ] + + operations = [ + migrations.CreateModel( + name="TaskModel", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "polymorphic_ctype", + models.ForeignKey( + editable=False, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="polymorphic_common.taskmodel_set+", + to="contenttypes.contenttype", + ), + ), + ], + options={ + "abstract": False, + "base_manager_name": "objects", + }, + ), + migrations.CreateModel( + name="ModelCeleryTask", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "celery_task_name", + models.CharField( + blank=True, + db_index=True, + max_length=64, + null=True, + ), + ), + ("celery_task_id", models.CharField(db_index=True, max_length=64)), + ("last_task_status", models.CharField(max_length=8)), + ( + "object", + models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="common.taskmodel", + ), + ), + ( + "polymorphic_ctype", + models.ForeignKey( + editable=False, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="polymorphic_common.modelcelerytask_set+", + to="contenttypes.contenttype", + ), + ), + ], + options={ + "unique_together": {("celery_task_id", "object")}, + }, + ), + ] diff --git 
a/common/models/__init__.py b/common/models/__init__.py
index a55b86f6fc..f5b7f455ec 100644
--- a/common/models/__init__.py
+++ b/common/models/__init__.py
@@ -3,6 +3,8 @@
 from common.fields import NumericSID
 from common.fields import ShortDescription
 from common.fields import SignedIntSID
+from common.models.celerytask import ModelCeleryTask
+from common.models.celerytask import TaskModel
 from common.models.mixins import TimestampedMixin
 from common.models.mixins.description import DescriptionMixin
 from common.models.mixins.validity import ValidityMixin
@@ -13,6 +15,8 @@
 __all__ = [
     "ApplicabilityCode",
+    "TaskModel",
+    "ModelCeleryTask",
     "NumericSID",
     "ShortDescription",
     "SignedIntSID",
diff --git a/common/models/celerytask.py b/common/models/celerytask.py
new file mode 100644
index 0000000000..103156f81d
--- /dev/null
+++ b/common/models/celerytask.py
@@ -0,0 +1,194 @@
+"""
+Provide a way to link a Celery task (usually referenceable only by a UUID) to
+a Django model.
+
+This enables retrieving the realtime status of tasks while they are running.
+
+Once tasks have completed, these models should be deleted.
+"""
+
+from celery.result import AsyncResult
+from celery.utils.log import get_task_logger
+from django.db import models
+from polymorphic.models import PolymorphicModel
+from polymorphic.query import PolymorphicQuerySet
+
+from common.celery import app as celery_app
+
+logger = get_task_logger(__name__)
+
+
+class TaskModel(PolymorphicModel):
+    """
+    Mixin for models that can be linked to a celery task.
+
+    All celery-specific functionality is at the other end of the relationship,
+    on ModelCeleryTask, leaving an extension point for other, non-celery based
+    implementations.
+    """
+
+
+class ModelCeleryTaskQuerySet(PolymorphicQuerySet):
+    def filter_by_task_status(self, statuses=None):
+        """
+        Note: passing in task ids that are not known to Celery will return
+        tasks with 'PENDING' status, as Celery cannot tell whether such tasks
+        have not yet reached the broker or simply do not exist.
+        """
+        model_task_ids = (
+            model_task.pk
+            for model_task in self
+            if statuses is None or model_task.get_celery_task_status() in statuses
+        )
+
+        return self.filter(pk__in=model_task_ids)
+
+    def filter_by_task_kwargs(self, **kwargs):
+        def task_kwargs_match(task):
+            """
+            :return: True if all the specified kwargs match those on the task.
+            """
+            if task.kwargs is None:
+                return False
+
+            for k, v in kwargs.items():
+                if k not in task.kwargs or task.kwargs[k] != v:
+                    return False
+            return True
+
+        model_task_ids = (
+            model_task.pk
+            for model_task in self
+            if task_kwargs_match(model_task.get_celery_task())
+        )
+
+        return self.filter(pk__in=model_task_ids)
+
+    def filter_by_task_args(self, *args):
+        model_task_ids = (
+            model_task.pk
+            for model_task in self
+            if model_task.get_celery_task().result.args == args
+        )
+
+        return self.filter(pk__in=model_task_ids)
+
+    def update_task_statuses(self):
+        """Update the last_task_status of all model tasks in the queryset from
+        celery."""
+        model_tasks = self
+        for model_task in model_tasks:
+            task_status = model_task.get_celery_task_status()
+            # 'PENDING' can mean the task is not yet known to celery,
+            # or that it has not yet reached the broker; if the status
+            # goes *back* to 'PENDING' from a higher status then don't
+            # overwrite the higher status.
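+            # For example (hypothetical scenario): a task last recorded as
+            # "STARTED" whose result has since expired from the backend
+            # reports "PENDING" again; keeping "STARTED" is more informative.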
+            if not model_task.last_task_status or task_status != "PENDING":
+                model_task.last_task_status = task_status
+
+        self.model.objects.bulk_update(
+            model_tasks,
+            ["last_task_status"],
+            batch_size=2000,
+        )
+        return model_tasks
+
+    def delete_pending_tasks(self):
+        """Delete every model task in the queryset whose celery status is
+        still 'PENDING'."""
+        return self.filter_by_task_status(("PENDING",)).delete()
+
+    def revoke(self, **kwargs):
+        for task_id in self.values_list("celery_task_id", flat=True):
+            task = AsyncResult(task_id)
+            task.revoke(**kwargs)
+
+        self.delete()
+
+
+class ModelCeleryTask(PolymorphicModel):
+    """
+    Provide a way to link a Celery task (usually referenceable only by a UUID)
+    to a Django model.
+
+    ModelCeleryTask instances should be created at the same time as the Celery
+    task they are linked to.
+
+    This is because 'PENDING' in Celery can mean either that a task is queued
+    or that the task is unknown.
+    """
+
+    class Meta:
+        unique_together = ("celery_task_id", "object")
+
+    objects = ModelCeleryTaskQuerySet.as_manager()
+
+    celery_task_name = models.CharField(
+        max_length=64,
+        null=True,
+        blank=True,
+        db_index=True,
+    )
+    celery_task_id = models.CharField(max_length=64, db_index=True)
+    last_task_status = models.CharField(max_length=8)
+
+    object = models.ForeignKey(
+        "common.TaskModel",
+        blank=True,
+        null=True,
+        default=None,
+        on_delete=models.CASCADE,
+    )
+
+    def get_celery_task(self):
+        """Get a reference to the Celery task instance."""
+        return celery_app.AsyncResult(self.celery_task_id)
+
+    def get_celery_task_status(self):
+        """Query celery and return the task status."""
+        return self.get_celery_task().status
+
+    @classmethod
+    def bind_model(cls, object: TaskModel, celery_task_id: str, celery_task_name: str):
+        """Link a Celery task UUID to a Django model."""
+        model_task, created = ModelCeleryTask.objects.get_or_create(
+            {"celery_task_name": celery_task_name},
+            object=object,
+            celery_task_id=celery_task_id,
+        )
+        if not created:
+            # Call save to update the last_task_status from celery
+            # (on creation, save will already have been called by django).
+            model_task.save()
+
+        logger.debug("Bound celery task %s to %s", celery_task_id, object)
+        return model_task
+
+    @classmethod
+    def unbind_model(cls, object: TaskModel):
+        """Unlink a Celery task UUID from a Django model."""
+        return ModelCeleryTask.objects.filter(
+            object=object,
+        ).delete()
+
+    def save(self, *args, **kwargs):
+        """Override save to update the last_task_status from celery."""
+        task_status = self.get_celery_task_status()
+        if not self.last_task_status or task_status != "PENDING":
+            # 'PENDING' can mean the task is not yet known to celery,
+            # or that it has not yet reached the broker; if the status
+            # goes *back* to 'PENDING' from a higher status then don't
+            # overwrite the higher status.
+            self.last_task_status = task_status
+        super().save(*args, **kwargs)
+
+    def __repr__(self):
+        return f"<ModelCeleryTask {self.pk}: {self.celery_task_id} [{self.last_task_status}]>"
+
+
+def bind_model_task(object: TaskModel, celery_task_id: str, celery_task_name: str):
+    """Link a Celery task UUID to a PolymorphicModel instance."""
+    return ModelCeleryTask.bind_model(object, celery_task_id, celery_task_name)
+
+
+def unbind_model_task(object: TaskModel):
+    """Unlink a Celery task UUID from a PolymorphicModel instance."""
+    return ModelCeleryTask.unbind_model(object)
diff --git a/common/models/tracked_qs.py b/common/models/tracked_qs.py
index 2683a11326..5b7cc337c3 100644
--- a/common/models/tracked_qs.py
+++ b/common/models/tracked_qs.py
@@ -1,8 +1,10 @@
 from __future__ import annotations
 
 from hashlib import sha256
+from itertools import chain
 from typing import List
 
+from django.contrib.contenttypes.models import ContentType
 from django.db.models import Case
 from django.db.models import CharField
 from django.db.models import F
@@ -24,6 +26,9 @@
 from common.util import resolve_path
 from common.validators import UpdateType
 
+PK_INTERVAL_CHUNK_SIZE = 1024 * 64
+"""Default chunk size for primary key intervals, see `as_pk_intervals`."""
+
 
 class TrackedModelQuerySet(
     PolymorphicQuerySet,
@@ -352,11 +357,142 @@ def follow_path(self, path: str) -> TrackedModelQuerySet:
         return qs.distinct()
 
+    def as_pk_intervals(self, chunk_size=PK_INTERVAL_CHUNK_SIZE):
+        """
+        Given a queryset ordered by primary key, yield interval tuples of the
+        form ((first_pk, last_pk), ...).
+
+        By default [1] this is a much smaller wire format than, for instance,
+        sending every primary key, making it suitable for use in Celery.
+        In the happy case a single interval is returned; a large number of
+        primary keys (more than chunk_size) or gaps in the original data will
+        generate more intervals. Gaps may appear when users delete items from
+        workbaskets.
+
+        Chunking is provided to make it easy to batch up the data for its
+        consumers (e.g. a celery task on the other end).
+
+        Unscientific testing on a developer's laptop with the seed workbasket
+        (pk=1), which has > 9m models: this takes 9.2 seconds to generate 3426
+        interval pairs; with 128kb chunks it generates 3430 pairs.
+
+        [1] In the pathological case, where consecutive primary keys all
+        differ by more than one, this format would be larger.
+        """
+        qs = self
+        if qs.query.order_by != ("pk",):
+            qs = self.order_by("pk")
+
+        pks = qs.values_list("pk", flat=True)
+
+        model_iterator = iter(pks)
+
+        try:
+            pk = next(model_iterator)
+        except StopIteration:
+            return
+
+        first_pk = pk
+        item_count = 0
+        try:
+            while True:
+                item_count += 1
+                last_pk = pk
+                pk = next(model_iterator)
+                if (item_count > chunk_size) or (pk > last_pk + 1):
+                    # Yield an interval tuple of (first_pk, last_pk) if the
+                    # number of items exceeds the chunk size, or if the pks
+                    # are not contiguous.
+                    yield first_pk, last_pk
+                    first_pk = pk
+                    item_count = 0
+        except StopIteration:
+            pass
+
+        yield first_pk, pk
+
+    def from_pk_intervals(self, *pk_intervals):
+        """
+        Return a queryset of the TrackedModel objects covered by the primary
+        key tuples, (start, end).
+
+        To generate data in this format call `as_pk_intervals` on a queryset.
+        """
+        q = Q()
+        for first_pk, last_pk in pk_intervals:
+            q |= Q(pk__gte=first_pk, pk__lte=last_pk)
+
+        if not q:
+            # An empty filter would match everything, so return an empty
+            # queryset in that case.
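+            # (A round trip on an empty queryset also ends up here:
+            # as_pk_intervals() yields no intervals, so no Q objects are
+            # built.)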
+            return self.none()
+
+        return self.filter(q)
+
+    @classmethod
+    def _content_hash(cls, models):
+        """
+        Implementation of content hashing, shared by `content_hash` and
+        `content_hash_fast`. It should not be called directly; instead call
+        `content_hash` or `content_hash_fast`, which impose an order on the
+        models.
+
+        Code is shared in this private method so the naive and fast
+        implementations return the same hash.
+        """
+        sha = sha256()
+        for o in models:
+            sha.update(o.content_hash().digest())
+        return sha
+
     def content_hash(self):
         """
         :return: Combined sha256 hash for all contained TrackedModels.
+
+        Ordering is by TrackedModel primary key, so the hash will be stable
+        across multiple queries.
         """
-        sha = sha256()
-        for o in self:
-            sha.update(o.content_hash())
-        return sha.digest()
+        return self._content_hash(self.order_by("pk").iterator())
+
+    def group_by_type(self):
+        """Yield a sequence of querysets, each containing only one polymorphic
+        ctype, enabling the use of prefetch and select_related on them."""
+        pks = self.values_list("pk", flat=True)
+        polymorphic_ctypes = (
+            self.non_polymorphic()
+            .distinct("polymorphic_ctype_id")
+            .values_list("polymorphic_ctype", flat=True)
+        )
+
+        for polymorphic_ctype in polymorphic_ctypes:
+            # Query contenttypes to get the concrete class instance.
+            klass = ContentType.objects.get_for_id(polymorphic_ctype).model_class()
+            yield klass.objects.filter(pk__in=pks)
+
+    def select_related_copyable_fields(self):
+        """Split the models into separate querysets using group_by_type, and
+        call select_related on any related fields found in the
+        `copyable_fields` attribute."""
+        pks = self.values_list("pk", flat=True)
+        for qs in self.group_by_type():
+            # Work out which fields from copyable_fields may be used in
+            # select_related.
+            related_fields = [
+                field.name
+                for field in qs.model.copyable_fields
+                if hasattr(field, "related_query_name")
+            ]
+            yield qs.select_related(*related_fields).filter(pk__in=pks)
+
+    def content_hash_fast(self):
+        """
+        Use `select_related_copyable_fields` to call select_related on the
+        fields that will be hashed.
+
+        This is a little more than 2x faster, at the expense of keeping the
+        data in memory: on this developer's laptop, 2.3 seconds vs 6.5 for the
+        naive implementation in `content_hash`. For larger amounts of data the
+        difference grew, 23 seconds vs 90, though this may be because more
+        types of data were represented.
+
+        For larger workbaskets, batching should be used to keep memory usage
+        within reasonable bounds.
+
+        The hash value returned here should be the same as that from
+        `content_hash`.
+        """
+        # Fetch data using select_related; at this point the ordering
+        # will have been lost.
+        all_models = chain(*self.select_related_copyable_fields())
+
+        # Sort the data by trackedmodel_ptr_id. Since the previous step
+        # outputs an iterable, sorted() is used instead of order_by on a
+        # queryset.
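+        # Note: sorted() materialises every model in memory at this point;
+        # this is the memory/speed trade-off described in the docstring above.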
+        sorted_models = sorted(all_models, key=lambda o: o.trackedmodel_ptr_id)
+        return self._content_hash(sorted_models)
diff --git a/common/models/tracked_utils.py b/common/models/tracked_utils.py
index 34dc1de39d..7e1627cc0d 100644
--- a/common/models/tracked_utils.py
+++ b/common/models/tracked_utils.py
@@ -79,3 +79,40 @@ def get_deferred_set_fields(class_: type[Model]) -> Set[Field]:
         and hasattr(field.remote_field, "through")
         and field.remote_field.through._meta.auto_created
     }
+
+
+def get_field_hashable_string(value):
+    """
+    Given a field value, return a hashable string containing the field's type
+    and value, ensuring uniqueness across types.
+
+    For fields that are TrackedModels, delegate to their content_hash method.
+    For non-TrackedModel values return a combination of type and value.
+    """
+    from common.models.trackedmodel import TrackedModel
+
+    value_type = type(value)
+    if isinstance(value, TrackedModel):
+        # For TrackedModel fields use their content_hash; the type is still
+        # included as a debugging aid.
+        value_hash = value.content_hash().hexdigest()
+        return (
+            f"{value_type.__module__}:{value_type.__name__}.content_hash={value_hash}"
+        )
+
+    return f"{value_type.__module__}:{value_type.__name__}={value}"
+
+
+def get_field_hashable_strings(instance, fields):
+    """
+    Given a model instance, return a dict of {field name: hashable string}.
+
+    This calls `get_field_hashable_string` to generate strings unique to the
+    type and value of each field.
+
+    :param instance: The model instance to generate hashes for.
+    :param fields: The fields to use in the hash.
+    :return: Dictionary of {field_name: hash}
+    """
+    return {
+        field.name: get_field_hashable_string(getattr(instance, field.name))
+        for field in fields
+    }
diff --git a/common/models/trackedmodel.py b/common/models/trackedmodel.py
index 52d3f2c8e4..f8ae13ae3c 100644
--- a/common/models/trackedmodel.py
+++ b/common/models/trackedmodel.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from functools import lru_cache
 from hashlib import sha256
 from json import dumps
 from typing import Any
@@ -31,6 +32,7 @@
 from common.models.managers import TrackedModelManager
 from common.models.tracked_qs import TrackedModelQuerySet
 from common.models.tracked_utils import get_deferred_set_fields
+from common.models.tracked_utils import get_field_hashable_strings
 from common.models.tracked_utils import get_models_linked_to
 from common.models.tracked_utils import get_relations
 from common.models.tracked_utils import get_subrecord_relations
@@ -661,6 +663,7 @@ def get_url_pattern_name_prefix(cls):
         prefix = cls._meta.verbose_name.replace(" ", "_")
         return prefix
 
+    @lru_cache(maxsize=None)
     def content_hash(self):
         """
         Hash of the user editable content, used by business rule checks for
@@ -668,14 +671,12 @@
 
         :return: 32 character sha256 'digest', see hashlib.sha256.
         """
-        content = {
-            field.name: str(getattr(self, field.name)) for field in self.copyable_fields
-        }
-
-        # The json encoder ensures a somewhat regular format and everything
-        # passed to it must be hashable.
-        hashable = dumps(content).encode("utf-8")
+        # The json encoder ensures a somewhat regular format and ensures only
+        # simple data types can be passed in; in testing, the speed of json
+        # encoding is around the same as stringifying.
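+        # The payload is a JSON object mapping field names to type-qualified
+        # value strings, e.g. {"sid": "builtins:int=123"} (illustrative field
+        # name and value; see get_field_hashable_strings).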
+ hashable = dumps(get_field_hashable_strings(self, self.copyable_fields)).encode( + "utf-8", + ) sha = sha256() sha.update(hashable) - return sha.digest() + return sha diff --git a/common/models/transactions.py b/common/models/transactions.py index 48f50a0ac0..27abc61b07 100644 --- a/common/models/transactions.py +++ b/common/models/transactions.py @@ -95,7 +95,7 @@ def preorder_negative_transactions(self) -> None: order += 1 tx.order = order - type(self).objects.bulk_update(transactions, ["order"]) + self.model.objects.bulk_update(transactions, ["order"]) @atomic def move_to_end_of_partition(self, partition) -> None: diff --git a/common/models/utils.py b/common/models/utils.py index 3d655d5d2a..d2c841ef5a 100644 --- a/common/models/utils.py +++ b/common/models/utils.py @@ -65,11 +65,12 @@ def __init__(self, func): self.func = func def __str__(self): - return self.func() + return str(self.func()) @wrapt.decorator def lazy_string(wrapped, instance, *args, **kwargs): + """Decorator that will evaluate the wrapped function when stringified.""" return LazyString(wrapped) diff --git a/common/tests/test_business_rules.py b/common/tests/test_business_rules.py index ffeccbf311..446231709d 100644 --- a/common/tests/test_business_rules.py +++ b/common/tests/test_business_rules.py @@ -20,11 +20,6 @@ pytestmark = pytest.mark.django_db -class TestRule(BusinessRule): - __test__ = False - validate = MagicMock() - - def test_business_rule_violation_message(): model = MagicMock() violation = TestRule(model.transaction).violation(model) diff --git a/common/tests/test_models.py b/common/tests/test_models.py index 7fcbfa4f33..4c88b8a8da 100644 --- a/common/tests/test_models.py +++ b/common/tests/test_models.py @@ -28,7 +28,8 @@ from regulations.models import Group from regulations.models import Regulation from taric.models import Envelope -from workbaskets.tasks import check_workbasket_sync + +# from workbaskets.tasks import check_workbasket_sync # TODO pytestmark = pytest.mark.django_db diff --git a/exporter/management/commands/dump_transactions.py b/exporter/management/commands/dump_transactions.py index 9fffcae9ea..b329481577 100644 --- a/exporter/management/commands/dump_transactions.py +++ b/exporter/management/commands/dump_transactions.py @@ -1,3 +1,5 @@ +import ast +import itertools import os import sys @@ -52,7 +54,7 @@ def add_arguments(self, parser): "with a comma-separated list of workbasket ids." 
), nargs="*", - type=int, + type=ast.literal_eval, default=None, action="store", ) @@ -76,7 +78,7 @@ def add_arguments(self, parser): def handle(self, *args, **options): workbasket_ids = options.get("workbasket_ids") if workbasket_ids: - query = dict(id__in=workbasket_ids) + query = dict(id__in=itertools.chain.from_iterable(workbasket_ids)) else: query = dict(status=WorkflowStatus.APPROVED) diff --git a/footnotes/tests/test_views.py b/footnotes/tests/test_views.py index 3f776ee672..ad20d8a47e 100644 --- a/footnotes/tests/test_views.py +++ b/footnotes/tests/test_views.py @@ -13,7 +13,8 @@ from common.views import TrackedModelDetailMixin from footnotes.models import Footnote from footnotes.views import FootnoteList -from workbaskets.tasks import check_workbasket_sync + +# from workbaskets.tasks import check_workbasket_sync # TODO pytestmark = pytest.mark.django_db diff --git a/pii-ner-exclude.txt b/pii-ner-exclude.txt index 10408aeb88..32b9c64512 100644 --- a/pii-ner-exclude.txt +++ b/pii-ner-exclude.txt @@ -1150,3 +1150,64 @@ param kwargs: Enum sha256 hashlib.sha256 +" Generator of model +XXXX - TODO +Celery Workflow +XXXX TODO +is_transaction_check_complete(check_id +check_id +up_to_date_check.get +rate_limit="1 +" Generator of model +XXXX - TODO +Celery Workflow +XXXX TODO +is_transaction_check_complete(check_id +check_id +up_to_date_check.get +rate_limit="1 +mock.patch("celery.app.task +assert transaction.tracked_models.count +assert workflow.body.args[0 +assert initial_require_update.count +assert checks_require_update.get().pk +TrackedModelsCheck +TrackedModelChecks +the Celery Task +GenericRelation +Unlink a Celery Task UUID +SubFactory +TransactionFactory +Trait +assert instance.transaction.tracked_models.count +TrackedModelCheckFactory(factory.django +Sequence +Found existing TrackedModelsCheck % +check_models(model_pks +Split +trackedmodel_ptr_id +BusinessRules +finish_models_check +Remove +check_workbasket_models( +clear_cache +TrackedModelsCheckStatus(enum +TrackedModelsCheck(TimestampedMixin +WorkbasketCheck(TrackedModelsCheck +TrackedModelsCheckChunks +TaskModel +PolymorphicModel +WorkBasketOutputFormat Enum +TaskInfo +ByteFields +get_or_create +models.content_hash +SubTask Waiting +checks.models +Build +SIGUSR1 +AsyncResults +GroupResults +f"{tuple(res.args +GroupResult +self.stdout.flush diff --git a/quotas/models.py b/quotas/models.py index 584cb5bb83..5448155e6c 100644 --- a/quotas/models.py +++ b/quotas/models.py @@ -125,6 +125,9 @@ class QuotaOrderNumberOrigin(TrackedModel, ValidityMixin): UpdateValidity, ) + def __str__(self): + return self.sid + def order_number_in_use(self, transaction): return self.order_number.in_use(transaction) diff --git a/settings/common.py b/settings/common.py index d530cc188b..5d0caedab2 100644 --- a/settings/common.py +++ b/settings/common.py @@ -331,8 +331,8 @@ AWS_S3_SIGNATURE_VERSION = "s3v4" AWS_S3_REGION_NAME = "eu-west-2" +# For info on celery settings see the docs at https://docs.celeryq.dev/en/stable/userguide/configuration.html # Pickle could be used as a serializer here, as this always runs in a DMZ - CELERY_BROKER_URL = os.environ.get("CELERY_BROKER_URL", CACHES["default"]["LOCATION"]) if VCAP_SERVICES.get("redis"): @@ -350,12 +350,29 @@ CELERY_TIMEZONE = TIME_ZONE CELERY_WORKER_POOL_RESTARTS = True # Restart worker if it dies -CELERY_BEAT_SCHEDULE = { - "sqlite_export": { - "task": "exporter.sqlite.tasks.export_and_upload_sqlite", - "schedule": timedelta(minutes=30), - }, -} +CELERY_RESULT_EXTENDED = True # Adds Task name, args, 
kwargs to results.
+
+# The following settings are usually useful for development, but not for production.
+CELERY_TASK_ALWAYS_EAGER = is_truthy(os.environ.get("CELERY_TASK_ALWAYS_EAGER", "N"))
+CELERY_TASK_EAGER_PROPAGATES = is_truthy(
+    os.environ.get("CELERY_TASK_EAGER_PROPAGATES", "N"),
+)
+CELERY_TASK_REMOTE_TRACEBACKS = is_truthy(
+    os.environ.get("CELERY_TASK_REMOTE_TRACEBACKS", "N"),
+)
+
+CELERY_BEAT_SCHEDULE = {}
+if False:
+    CELERY_BEAT_SCHEDULE = {
+        "sqlite_export": {
+            "task": "exporter.sqlite.tasks.export_and_upload_sqlite",
+            "schedule": timedelta(minutes=30),
+        },
+    }
+
+RAISE_BUSINESS_RULE_FAILURES = is_truthy(
+    os.environ.get("RAISE_BUSINESS_RULE_FAILURES", "N"),
+)
 
 SQLITE_EXCLUDED_APPS = [
     "checks",
diff --git a/workbaskets/management/commands/list_workbaskets.py b/workbaskets/management/commands/list_workbaskets.py
index 7eecbe8722..631ecbee2a 100644
--- a/workbaskets/management/commands/list_workbaskets.py
+++ b/workbaskets/management/commands/list_workbaskets.py
@@ -56,7 +56,8 @@ def add_arguments(self, parser: CommandParser) -> None:
 
         parser.add_argument(
             "workbasket_ids",
-            help=("Comma-separated list of workbasket ids to filter to"),
+            nargs="*",
+            help="Comma-separated list of workbasket ids to filter to",
             type=ast.literal_eval,
         )
 
diff --git a/workbaskets/management/commands/run_checks.py b/workbaskets/management/commands/run_checks.py
new file mode 100644
index 0000000000..6c295db971
--- /dev/null
+++ b/workbaskets/management/commands/run_checks.py
@@ -0,0 +1,208 @@
+import logging
+import signal
+from typing import Any
+from typing import Dict
+from typing import Optional
+
+from celery.result import AsyncResult
+from celery.result import GroupResult
+from django.core.management import BaseCommand
+from django.core.management.base import CommandParser
+
+from checks.models import TrackedModelCheck
+from workbaskets.models import WorkBasket
+
+logger = logging.getLogger(__name__)
+
+
+CLEAR_TO_END_OF_LINE = "\x1b[K"
+
+
+def revoke_task_and_children(task, depth=0):
+    """
+    Recursively revoke a task and its children.
+
+    Uses SIGUSR1, which raises the SoftTimeLimitExceeded exception; this is
+    friendlier than a plain terminate, which may kill other tasks in the
+    worker.
+    """
+    if task.children:
+        for subtask in task.children:
+            yield from revoke_task_and_children(subtask, depth + 1)
+
+    task.revoke(terminate=True, signal="SIGUSR1")
+    yield task, depth
+
+
+class TaskControlMixin:
+    # Implementing classes can populate this with task-name prefixes that
+    # should be hidden when task names are displayed.
+    IGNORE_TASK_PREFIXES = []
+
+    def get_readable_task_name(self, node):
+        """Optionally remove a prefix from the task name (used to remove the
+        module, which is often repeated)."""
+        task_name = getattr(node, "name") or ""
+        for prefix in self.IGNORE_TASK_PREFIXES:
+            unprefixed = task_name.replace(prefix, "")
+            if unprefixed != task_name:
+                return unprefixed
+
+        return task_name
+
+    def revoke_task_and_children_and_display_result(self, task):
+        """Call revoke_task_and_children and display information on each
+        revoked task."""
+        for revoked_task, depth in revoke_task_and_children(task):
+            if isinstance(task, AsyncResult):
+                self.stdout.write(
+                    " " * depth
+                    + f"{getattr(revoked_task, 'name', None) or '-'} [{revoked_task.id}] {revoked_task.status}",
+                )
+
+    def revoke_task_on_sigint(self, task):
+        """
+        Connect a signal handler to attempt to revoke a task if the user
+        presses Ctrl+C.
+
+        Due to the way tasks travel through Celery, not all tasks can be
+        revoked.
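+        (Typically a plain revoke only prevents tasks that have not yet
+        started; tasks that are already executing are interrupted here by
+        revoke_task_and_children using terminate=True with SIGUSR1.)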
+ """ + + def sigint_handler(sig, frame): + """Revoke celery task with task_id.""" + self.stdout.write(f"Received SIGINT, revoking task {task.id} and children.") + self.revoke_task_and_children_and_display_result(task) + + raise SystemExit(1) + + signal.signal(signal.SIGINT, sigint_handler) + + def display_task(self, node, value, depth): + """Default task display.""" + # For now this only shows args, to avoid some noise - if this is more widely + # used that may be something that could be configured by the caller. + readable_task_name = self.get_readable_task_name( + node, + ) + self.stdout.write( + " " * depth * 2 + f"{readable_task_name} " + f"{tuple(node.args)}", + ) + + def iterate_ongoing_tasks(self, result, ignore_groupresult=True): + """ + Iterate over the ongoing tasks as they are received, track their. + + depth - which is useful for visual formatting. + + Yields: (node, value, depth) + """ + task_depths: Dict[str, int] = {} + task_depths[result.id] = 0 + + for parent_id, node in result.iterdeps(intermediate=True): + value = node.get() + + depth = task_depths.get(parent_id, -1) + if isinstance(node, GroupResult) and ignore_groupresult: + # GroupResult is ignored: store it so looking up depth works + # but do not increase the indent.. + task_depths[node.id] = depth + continue + else: + task_depths[node.id] = depth + 1 + + yield node, value, depth + + def display_ongoing_tasks(self, result): + for node, value, depth in self.iterate_ongoing_tasks(result): + self.display_task(node, value, depth) + + +class Command(TaskControlMixin, BaseCommand): + IGNORE_TASK_PREFIXES = [ + "checks.tasks.", + ] + + passed = 0 + failed = 0 + + help = ( + "Run all business rule checks against a WorkBasket's TrackedModels in Celery." + ) + + def add_arguments(self, parser: CommandParser) -> None: + parser.add_argument("WORKBASKET_PK", type=int) + parser.add_argument("--clear-cache", action="store_true", default=False) + parser.add_argument( + "--throw", + help="Allow failing celery tasks to throw exceptions [dev setting]", + action="store_true", + default=False, + ) + + def display_check_model_task(self, node, value, depth): + model_pk = node.args[1] + check_passed = value + readable_task_name = self.get_readable_task_name(node) + style = self.style.SUCCESS if check_passed else self.style.ERROR + self.stdout.write( + " " * depth * 2 + + f"{readable_task_name} " + + style( + f"{TrackedModelCheck.objects.filter(model=model_pk).last()}", + ), + ) + + def display_task(self, node, value, depth): + """Custom display for check_model tasks, acculate their passes / + fails.""" + task_name = getattr(node, "name", None) + if task_name == "checks.tasks.check_model": + self.display_check_model_task(node, value, depth) + if value: + self.passed += 1 + else: + self.failed += 1 + else: + super().display_task(node, value, depth) + + def handle(self, *args: Any, **options: Any) -> Optional[str]: + from checks.tasks import check_workbasket + + # Get the workbasket first + workbasket = WorkBasket.objects.get( + pk=int(options["WORKBASKET_PK"]), + ) + clear_cache = options["clear_cache"] + rule_names = None + throw = options["throw"] + + # Temporarily display a message while waiting for celery, this will only have time to show up + # if celery isn't working (easy enough on a dev machine), or is busy. + self.stdout.write("Connecting to celery... ⌛", ending="") + self.stdout._out.flush() # self.stdout.flush() doesn't result in any output - should report as a bug to django. 
+ result = check_workbasket.apply_async( + args=( + workbasket.pk, + None, + ), + kwargs={ + "clear_cache": clear_cache, + "rules": rule_names, + }, + throw=throw, + ) + result.wait() + self.stdout.write(f"\r{CLEAR_TO_END_OF_LINE}") + + # Attach a handler to revoke the task and its subtasks if the user presses Ctrl+C + self.revoke_task_on_sigint(result) + + self.display_ongoing_tasks(result) + self.stdout.write() + + style = self.style.ERROR if self.failed else self.style.SUCCESS + self.stdout.write(style(f"Failed: {self.failed}")) + self.stdout.write(style(f"Passed: {self.passed}")) + self.stdout.write() + return 1 if self.failed else 0 diff --git a/workbaskets/management/commands/sync_run_checks.py b/workbaskets/management/commands/sync_run_checks.py index a925c10971..aacc133f26 100644 --- a/workbaskets/management/commands/sync_run_checks.py +++ b/workbaskets/management/commands/sync_run_checks.py @@ -7,7 +7,8 @@ from workbaskets.management.util import WorkBasketCommandMixin from workbaskets.models import WorkBasket -from workbaskets.tasks import check_workbasket_sync + +# from workbaskets.tasks import check_workbasket_sync # TODO logger = logging.getLogger(__name__) diff --git a/workbaskets/management/util.py b/workbaskets/management/util.py index 3c6be99c34..59c89f64f8 100644 --- a/workbaskets/management/util.py +++ b/workbaskets/management/util.py @@ -22,8 +22,14 @@ def _output_workbasket_readable( self.stdout.write(f"{spaces}reason: {first_line_of(workbasket.reason)}") self.stdout.write(f"{spaces}status: {workbasket.status}") if show_transaction_info: + transactions = workbasket.transactions + first_pk = ( + workbasket.transactions.first().pk if transactions.count() else "-" + ) + last_pk = workbasket.transactions.last().pk if transactions.count() else "-" + self.stdout.write( - f"{spaces}transactions: {workbasket.transactions.first().pk} - {workbasket.transactions.last().pk}", + f"{spaces}transactions: {first_pk} - {last_pk} [{transactions.count()}]", ) def _output_workbasket_compact(self, workbasket, show_transaction_info, **kwargs): @@ -32,8 +38,13 @@ def _output_workbasket_compact(self, workbasket, show_transaction_info, **kwargs ending="" if show_transaction_info else "\n", ) if show_transaction_info: + transactions = workbasket.transactions + first_pk = ( + workbasket.transactions.first().pk if transactions.count() else "-" + ) + last_pk = workbasket.transactions.last().pk if transactions.count() else "-" self.stdout.write( - f", {workbasket.transactions.first().pk} - {workbasket.transactions.last().pk}", + f", {first_pk} - {last_pk} [{transactions.count()}]", ) def output_workbasket( diff --git a/workbaskets/models.py b/workbaskets/models.py index 53f1e6d0ef..21392f489b 100644 --- a/workbaskets/models.py +++ b/workbaskets/models.py @@ -7,7 +7,6 @@ from django.conf import settings from django.core.exceptions import ImproperlyConfigured -from django.core.exceptions import ValidationError from django.db import models from django.db.models import QuerySet from django.db.models import Subquery @@ -15,7 +14,7 @@ from django_fsm import transition from checks.models import TrackedModelCheck -from checks.models import TransactionCheck +from common.models import ModelCeleryTask from common.models.mixins import TimestampedMixin from common.models.tracked_qs import TrackedModelQuerySet from common.models.trackedmodel import TrackedModel @@ -328,10 +327,11 @@ def submit_for_approval(self): if not self.transactions.exists(): return - if self.unchecked_or_errored_transactions.exists(): - 
raise ValidationError( - "Transactions have not yet been fully checked or contain errors", - ) + # TODO + # if self.unchecked_or_errored_transactions.exists(): + # raise ValidationError( + # "Transactions have not yet been fully checked or contain errors", + # ) @transition( field=status, @@ -486,23 +486,8 @@ def tracked_model_check_errors(self): ) def delete_checks(self): - """Delete all TrackedModelCheck and TransactionCheck instances related - to the WorkBasket.""" - TrackedModelCheck.objects.filter( - transaction_check__transaction__workbasket=self, - ).delete() - TransactionCheck.objects.filter( - transaction__workbasket=self, - ).delete() - - @property - def unchecked_or_errored_transactions(self): - return self.transactions.exclude( - pk__in=TransactionCheck.objects.requires_update(False) - .filter( - completed=True, - successful=True, - transaction__workbasket=self, - ) - .values("transaction__pk"), - ) + """Delete all TrackedModelCheck and ModelCeleryTask instances related to + the WorkBasket.""" + checks = TrackedModelCheck.objects.filter(model__transaction__workbasket=self) + ModelCeleryTask.objects.filter(object__in=checks).delete() + checks.delete() diff --git a/workbaskets/tasks.py b/workbaskets/tasks.py index cffd3464bb..f66b8ae67c 100644 --- a/workbaskets/tasks.py +++ b/workbaskets/tasks.py @@ -1,11 +1,9 @@ -from celery import group +"""Also see checks.tasks, which contains check_workbasket task which checks +business rules.""" from celery import shared_task from celery.utils.log import get_task_logger from django.db.transaction import atomic -from checks.tasks import check_transaction -from checks.tasks import check_transaction_sync -from common.celery import app from workbaskets.models import WorkBasket # Celery logger adds the task id and status and outputs via the worker. @@ -26,34 +24,3 @@ def transition(instance_id: int, state: str, *args): getattr(instance, state)(*args) instance.save() logger.info("Transitioned workbasket %s to state %s", instance_id, instance.status) - - -@app.task(bind=True) -def check_workbasket(self, workbasket_id: int): - """Run and record transaction checks for the passed workbasket ID, - asynchronously.""" - - workbasket: WorkBasket = WorkBasket.objects.get(pk=workbasket_id) - transactions = workbasket.transactions.values_list("pk", flat=True) - - logger.debug("Setup task to check workbasket %s", workbasket_id) - return self.replace(group(check_transaction.si(id) for id in transactions)) - - -def check_workbasket_sync(workbasket: WorkBasket): - """ - Run and record transaction checks for the passed workbasket ID, - synchronously. - - This method will run all of the checks one after the other and won't return - until they are complete. This is useful for testing and debugging. - """ - transactions = workbasket.transactions.all() - - logger.debug( - "Start synchronous check of workbasket %s with % transactions", - workbasket.pk, - transactions.count(), - ) - for transaction in transactions: - check_transaction_sync(transaction) diff --git a/workbaskets/tests/util.py b/workbaskets/tests/util.py index 4cf10a1792..7577230116 100644 --- a/workbaskets/tests/util.py +++ b/workbaskets/tests/util.py @@ -1,5 +1,6 @@ from workbaskets.models import WorkBasket -from workbaskets.tasks import check_workbasket_sync + +# from workbaskets.tasks import check_workbasket_sync # TODO def assert_workbasket_valid(workbasket: WorkBasket):