diff --git a/checks/checks.py b/checks/checks.py
index 64c7eb0dfb..6c66ccaa54 100644
--- a/checks/checks.py
+++ b/checks/checks.py
@@ -1,219 +1,111 @@
+import abc
 import logging
+import typing
 from collections import defaultdict
 from typing import Optional
 from typing import Set
 from typing import Tuple
 
-from django.conf import settings
-
+from checks.models import BusinessRuleModel
+from checks.models import BusinessRuleResult
 from checks.models import TrackedModelCheck
 from common.business_rules import BusinessRule
-from common.business_rules import BusinessRuleViolation
 from common.models import TrackedModel
 from common.models import Transaction
 from common.models.utils import get_current_transaction
-from common.models.utils import override_current_transaction
 
 logger = logging.getLogger(__name__)
 
 CheckResult = Tuple[bool, Optional[str]]
 
 
-class Checker:
+class Checker(abc.ABC):
     @classmethod
-    def run_rule(
-        cls,
-        rule: BusinessRule,
-        transaction: Transaction,
+    @abc.abstractmethod
+    def get_model_rule_mapping(
+        cls,
         model: TrackedModel,
-    ) -> CheckResult:
-        """
-        Run a single business rule on a single model.
-
-        :return CheckResult, a Tuple(rule_passed: str, violation_reason: Optional[str]).
-        """
-        logger.debug(f"run_rule %s %s %s", model, rule, transaction.pk)
-        try:
-            rule(transaction).validate(model)
-            logger.debug(f"%s [tx:%s] %s [passed]", model, rule, transaction.pk)
-            return True, None
-        except BusinessRuleViolation as violation:
-            reason = violation.args[0]
-            logger.debug(f"%s [tx:%s] %s [failed]", model, rule, transaction.pk, reason)
-            return False, reason
+        rules: Optional[Set[str]] = None,
+    ) -> typing.Dict[TrackedModel, Set[str]]:
+        """Implementing classes should return a dict mapping models to the sets
+        of business rules that apply to them."""
+        return {}
 
     @classmethod
-    def apply_rule(
+    def apply_rules(
         cls,
-        rule: BusinessRule,
+        rules: typing.Sequence[BusinessRule],
         transaction: Transaction,
         model: TrackedModel,
     ):
         """
-        Applies the rule to the model and records success in a
-        TrackedModelCheck.
+        TODO: determine rules_to_run - the set of rules that have not yet been run.
+        """
+        # model.content_hash().digest()
 
-        If a TrackedModelCheck already exists with a matching content checksum it
-        will be updated, otherwise a new one will be created.
+        # rule_models = {
+        #     rule_model.name: rule_model for rule_model in type(model).get_business_rule_models()
+        # }
+        # A TrackedModelCheck represents an ongoing check.
 
-        :return: TrackedModelCheck instance containing the result of the check.
+        # To minimise the number of queries, data is fetched up front and results are batched where possible.
+        rule_models = [*BusinessRuleModel.from_rules(rules)]
 
-        During debugging the developer can set settings.RAISE_BUSINESS_RULE_FAILURES
-        to True to raise business rule violations as exceptions.
-        """
-        success, message = False, None
-        try:
-            with override_current_transaction(transaction):
-                success, message = cls.run_rule(rule, transaction, model)
-        except Exception as e:
-            success, message = False, str(e)
-            if settings.RAISE_BUSINESS_RULE_FAILURES:
-                # RAISE_BUSINESS_RULE_FAILURES can be set by the developer to raise
-                # Exceptions.
-                raise
-        finally:
-            check, created = TrackedModelCheck.objects.get_or_create(
-                {
-                    "successful": success,
-                    "message": message,
-                    "content_hash": model.content_hash().digest(),
-                },
-                model=model,
-                check_name=rule.__name__,
-            )
-            if not created:
-                check.successful = success
-                check.message = message
-                check.content_hash = model.content_hash().digest()
-                check.save()
-            return check
+        head_transaction = Transaction.objects.approved().last()
+        check, created = TrackedModelCheck.objects.get_or_create(
+            model=model,
+            head_transaction=head_transaction,
+            # content_hash=model.content_hash().digest(),
+        )
 
-    @classmethod
-    def apply_rule_cached(
-        cls,
-        rule: BusinessRule,
-        transaction: Transaction,
-        model: TrackedModel,
-    ):
-        """
-        If a matching TrackedModelCheck instance exists, returns it, otherwise
-        check rule, and return the result as a TrackedModelCheck instance.
+        # TODO: exclude rules that already have results.
+        results = [
+            rule_model.get_result(model.transaction, model)
+            for rule_model in rule_models
+        ]
 
-        :return: TrackedModelCheck instance containing the result of the check.
-        """
-        try:
-            check = TrackedModelCheck.objects.get(
-                model=model,
-                check_name=rule.__name__,
-            )
-        except TrackedModelCheck.DoesNotExist:
-            logger.debug(
-                "apply_rule_cached (no existing check) %s, %s apply rule",
-                rule.__name__,
-                transaction,
-            )
-            return cls.apply_rule(rule, transaction, model)
-
-        # Re-run the rule if the content checksum no longer matches that of the previous test.
-        check_hash = bytes(check.content_hash)
-        model_hash = model.content_hash().digest()
-        if check_hash == model_hash:
-            logger.debug(
-                "apply_rule_cached (matching content hash) %s, tx: %s, using cached result %s",
-                rule.__name__,
-                transaction.pk,
-                check,
-            )
-            return check
-
-        logger.debug(
-            "apply_rule_cached (check.content_hash != model.content_hash()) %s != %s %s, %s apply rule",
-            check_hash,
-            model_hash,
-            rule.__name__,
-            transaction,
-        )
-        check.delete()
-        return cls.apply_rule(rule, transaction, model)
+        results = BusinessRuleResult.objects.bulk_create(results)
+
+        check.results.add(*results)
+        logger.debug("apply_rules results: %s", results)
+        return check
 
 
 class BusinessRuleChecker(Checker):
-    """Apply BusinessRules specified in a TrackedModels business_rules
-    attribute."""
+    """A ``Checker`` that runs a ``BusinessRule`` against a model."""
 
     @classmethod
-    def apply_rule(
+    def get_model_rule_mapping(
         cls,
-        rule: BusinessRule,
-        transaction: Transaction,
         model: TrackedModel,
+        rules: Optional[Set[str]] = None,
     ):
         """
-        Run the current business rule on the model.
-
-        :return: TrackedModelCheck instance containing the result of the check.
-        :raises: ValueError if the rule is not in the model's business_rules attribute
-
-        To get a list of applicable rules, get_model_rules can be used.
-        """
-        if rule not in model.business_rules:
-            raise ValueError(
-                f"{model} does not have {rule} in its business_rules attribute.",
-            )
-
-        return super().apply_rule(rule, transaction, model)
+        Return a dict mapping the passed-in model to the business rules that
+        apply to it.
 
-    @classmethod
-    def get_model_rules(cls, model: TrackedModel, rules: Optional[Set[str]] = None):
-        """
+        The passed-in model is used as the key (this allows
+        LinkedModelsBusinessRuleChecker to map models other than the passed-in
+        model to rules).
 
         :param model: TrackedModel instance
         :param rules: Optional list of rule names to filter by.
-        :return: Dict mapping models to a set of the BusinessRules that apply to them.
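+        For example (hypothetical rule names): if ``model.business_rules`` is
+        ``{ME1, ME2}``, then ``get_model_rule_mapping(model, {"ME1"})`` would
+        return ``{model: {ME1}}``.
+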
+        :return: Dict with one entry for the passed-in model; the values are
+        the rule instances to apply.
         """
-        model_rules = defaultdict(set)
+        if rules is None:
+            return {model: set(model.business_rules)}
 
-        for rule in model.business_rules:
-            if rules is not None and rule.__name__ not in rules:
-                continue
-
-            model_rules[model].add(rule)
-
-        # Downcast to a dict - this API (and unit testing) a little more sane.
-        return {**model_rules}
+        # The user passed in a specific set of rule names to run; filter the
+        # business rules by those names.
+        filtered_rules = {
+            rule for rule in model.business_rules if rule.__name__ in rules
+        }
+        return {model: filtered_rules}
 
 
 class LinkedModelsBusinessRuleChecker(Checker):
-    """Apply BusinessRules specified in a TrackedModels indirect_business_rules
-    attribute to models returned by get_linked_models on those rules."""
-
-    @classmethod
-    def apply_rule(
-        cls,
-        rule: BusinessRule,
-        transaction: Transaction,
-        model: TrackedModel,
-    ):
-        """
-        LinkedModelsBusinessRuleChecker assumes that the linked models are
-        still.
-
-        the current versions (TODO - ensure a business rule checks this),
-
-        :return: TrackedModelCheck instance containing the result of the check.
-        :raises: ValueError if the rule is not in the model's indirect_business_rules attribute
-
-        get_model_rules should be called to get a list of applicable rules and them models they apply to.
-        """
-        if rule not in model.indirect_business_rules:
-            raise ValueError(
-                f"{model} does not have {rule} in its indirect_business_rules attribute.",
-            )
-
-        return super().apply_rule(rule, model.transaction, model)
+    """A ``Checker`` that runs a ``BusinessRule`` against a model that is linked
+    to the model being checked, and for which a change in the checked model
+    could result in a business rule failure against the linked model."""
 
     @classmethod
-    def get_model_rules(cls, model: TrackedModel, rules: Optional[Set] = None):
+    def get_model_rule_mapping(cls, model: TrackedModel, rules: Optional[Set] = None):
         """
         :param model: Initial TrackedModel instance
         :param rules: Optional list of rule names to filter by.
diff --git a/checks/management/__init__.py b/checks/management/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/checks/management/commands/__init__.py b/checks/management/commands/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/checks/management/commands/list_business_rules.py b/checks/management/commands/list_business_rules.py
new file mode 100644
index 0000000000..b54a718210
--- /dev/null
+++ b/checks/management/commands/list_business_rules.py
@@ -0,0 +1,40 @@
+from collections import defaultdict
+
+from django.core.management import BaseCommand
+
+from checks.models import BusinessRuleModel
+from common.business_rules import ALL_RULES
+
+
+class Command(BaseCommand):
+    """Display the business rules in the system and database."""
+
+    def handle(self, *app_labels, **options):
+        self.stdout.write("Rule Name, In System, In Database, Status")
+
+        # Create a dictionary of rule names, each holding a couple of flags
+        # used to determine status.
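+        # Illustrative shape of rule_info after the two loops below
+        # (hypothetical rule names):
+        #   {"ME32": {"in_database": True, "in_system": True},
+        #    "OLD01": {"in_database": True}}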
+        rule_info = defaultdict(dict)
+        for rule_name in BusinessRuleModel.objects.values("name"):
+            rule_info[rule_name["name"]]["in_database"] = True
+
+        for rule_name in ALL_RULES.keys():
+            rule_info[rule_name]["in_system"] = True
+
+        for rule_name, info in rule_info.items():
+            in_database = info.get("in_database", False)
+            in_system = info.get("in_system", False)
+
+            if in_database and in_system:
+                status = "In Sync"
+            elif in_database:
+                status = "Pending Removal"
+            elif in_system:
+                status = "Pending Addition"
+
+            self.stdout.write(
+                f"{rule_name},"
+                f" {'Y' if in_system else 'N'},"
+                f" {'Y' if in_database else 'N'},"
+                f" {status}",
+            )
diff --git a/checks/management/commands/sync_business_rules.py b/checks/management/commands/sync_business_rules.py
new file mode 100644
index 0000000000..ee58d0efe0
--- /dev/null
+++ b/checks/management/commands/sync_business_rules.py
@@ -0,0 +1,182 @@
+# https://stackoverflow.com/a/53761784/62709
+import sys
+from pathlib import Path
+from textwrap import dedent
+from typing import Sequence
+
+import black
+from django.conf import settings
+from django.core.management import BaseCommand
+from django.core.management import CommandParser
+from django.db import DEFAULT_DB_ALIAS
+from django.db.migrations import Migration
+from django.db.migrations.autodetector import MigrationAutodetector
+from django.db.migrations.recorder import MigrationRecorder
+from django.db.migrations.writer import MigrationWriter
+
+from checks.models import BusinessRuleModel
+from common.models.utils import ansi_hyperlink
+from common.models.utils import is_database_synchronized
+
+
+def _modify_migration_code(
+    app_name: str,
+    original_code: str,
+    added_rules: Sequence[str],
+    removed_rules: Sequence[str],
+) -> str:
+    """
+    Modify the migration code to add the business rules.
+
+    While another approach may be desirable, this is the one taken by other
+    libraries that generate migrations, and allows explicitly embedding the
+    business rules in the migration code.
+    """
+    header = "# Data Migration, written by sync_business_rules."
+
+    extra_code = dedent(
+        f"""\
+        added_rules = {sorted(added_rules)}
+        removed_rules = {sorted(removed_rules)}
+
+        def add_rules(apps, schema_editor, rule_names):
+            BusinessRuleModel = apps.get_model("{app_name}", "BusinessRuleModel")
+
+            rules = [BusinessRuleModel(name=name) for name in rule_names]
+            BusinessRuleModel.objects.bulk_create(rules)
+
+        def mark_removed_rules(apps, schema_editor, rule_names):
+            BusinessRuleModel = apps.get_model("{app_name}", "BusinessRuleModel")
+            BusinessRuleModel.objects.filter(name__in=rule_names).update(current=False)
+
+        def forward_sync_rules(apps, schema_editor):
+            add_rules(apps, schema_editor, added_rules)
+            mark_removed_rules(apps, schema_editor, removed_rules)
+
+        def reverse_sync_rules(apps, schema_editor):
+            '''
+            Attempt to delete rules - the developer may need to edit this to
+            remove related models.
+            '''
+            BusinessRuleModel = apps.get_model("{app_name}", "BusinessRuleModel")
+            BusinessRuleModel.objects.filter(name__in=added_rules).delete()
+        """,
+    )
+
+    operations_code = "migrations.RunPython(forward_sync_rules, reverse_sync_rules)"
+
+    # Now modify the code:
+    code = header + "\n" + original_code
+    code = code.replace("class Migration", f"{extra_code}\n\nclass Migration")
+
+    code = code.replace(
+        "operations = [",
+        "operations = [\n" f"        {operations_code},",
+    )
+
+    return code
+
+
+class Command(BaseCommand):
+    """Generate a data migration that syncs the business rules in the database
+    with those defined in the codebase."""
+
+    def add_arguments(self, parser: CommandParser) -> None:
+        parser.add_argument(
+            "--check",
+            action="store_true",
+            help="Check if migrations are pending.",
+            default=False,
+        )
+
+    def write_sync_rules_migration(self, added_rules, removed_rules):
+        """
+        Generate a migration file that will sync the business rules in the
+        database with those in the app.
+
+        Works by generating an empty migration and adjusting its code; other
+        libraries seem to take the same approach. A nicer API for doing this
+        on Django's side would be good.
+        """
+        app_name = "checks"
+
+        # Get next migration number:
+        prev_migration = (
+            MigrationRecorder.Migration.objects.filter(app=app_name).last().name
+        )
+        prev_migration_number = MigrationAutodetector.parse_number(prev_migration)
+        migration_number = f"{prev_migration_number + 1:04}"
+        migration_name = f"{migration_number}_sync_business_rules"
+
+        migration = Migration(migration_name, app_name)
+
+        migration.dependencies = [(app_name, prev_migration)]
+        writer = MigrationWriter(migration, include_header=True)
+
+        writer_path = Path(writer.path)
+        if writer_path.exists():
+            sys.exit(f"Migration {migration_name} already exists.")
+
+        writer_path.parent.mkdir(exist_ok=True)
+
+        code = _modify_migration_code(
+            app_name,
+            writer.as_string(),
+            added_rules,
+            removed_rules,
+        )
+        formatted_code = black.format_str(code, mode=black.Mode())
+
+        with open(writer.path, "w") as fp:
+            fp.write(formatted_code)
+
+        migration_path = Path(writer.path).relative_to(Path(settings.BASE_DIR))
+        if self.stdout.isatty():
+            migration_path = ansi_hyperlink(f"file://{writer.path}", migration_path)
+
+        self.stdout.write(
+            f"Wrote Business Rules updates "
+            f"[Added: {len(added_rules)}, Removed: {len(removed_rules)}]:"
+            f" {migration_path}",
+        )
+        self.stdout.write("")
+        self.stdout.write(
+            "Please review the generated migration file to modify reverse migrations as necessary.",
+        )
+        self.stdout.write("")
+        self.stdout.write("To apply the business rule updates, run the migration:")
+        self.stdout.write(f"python manage.py migrate {app_name} {migration_number}")
+
+    def handle(self, *app_labels, **options):
+        added_rules, removed_rules = BusinessRuleModel.objects.get_updated_rules()
+        business_rules_pending = added_rules or removed_rules
+
+        migrations_pending = not is_database_synchronized(DEFAULT_DB_ALIAS)
+
+        if options["check"]:
+            self.stdout.write(
+                "Migrations pending: " + ("Yes." if migrations_pending else "No."),
+            )
+            self.stdout.write(
+                "Business rules in sync: "
+                + ("Yes." if not business_rules_pending else "No:"),
+            )
+            if added_rules or removed_rules:
+                self.stdout.write(f" - Add: {len(added_rules)}")
+                self.stdout.write(f" - Remove: {len(removed_rules)}")
+            return
+
+        if migrations_pending:
+            # Not only ensure schema migrations are up-to-date, but also ensure
+            # only one business rule migration is written per update.
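+            # Illustrative flow: run `python manage.py migrate` first, then
+            # re-run `python manage.py sync_business_rules` to generate the
+            # business rule data migration.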
+            sys.exit(
+                "Run pending migrations before generating business rule migrations.",
+            )
+
+        if business_rules_pending:
+            self.write_sync_rules_migration(added_rules, removed_rules)
+        else:
+            self.stdout.write(
+                "Business rules are already synced, no migrations were created.",
+            )
diff --git a/checks/models.py b/checks/models.py
index 971e7648e8..21977aca65 100644
--- a/checks/models.py
+++ b/checks/models.py
@@ -1,18 +1,190 @@
+from __future__ import annotations
+
 import logging
+from typing import Dict
+from typing import Iterable
+from typing import Type
+from typing import TypeVar
 
+from django.conf import settings
 from django.db import models
-from django.db.models import fields
+from django.db.models import Manager
+from django.db.models import Model
 from polymorphic.managers import PolymorphicManager
 
+from checks.querysets import BusinessRuleModelQuerySet
+from checks.querysets import BusinessRuleResultQuerySet
+from checks.querysets import BusinessRuleResultStatus
 from checks.querysets import TrackedModelCheckQueryset
-from common.models import TimestampedMixin
+from common.business_rules import ALL_RULES
+from common.business_rules import BusinessRule
 from common.models.celerytask import TaskModel
 from common.models.trackedmodel import TrackedModel
+from common.models.transactions import Transaction
 
 logger = logging.getLogger(__name__)
 
+Self = TypeVar("Self", bound="BusinessRuleModel")
+
+
+class BusinessRuleResult(models.Model):
+    """
+    Result of running a business rule.
+
+    Links to the rule itself; in the case of FAILED or ERROR, a message is
+    appended.
+
+    See `BusinessRuleResultStatus` for information on the possible values.
+    """
+
+    objects = Manager.from_queryset(BusinessRuleResultQuerySet)()
+
+    rule = models.ForeignKey("BusinessRuleModel", on_delete=models.SET_NULL, null=True)
+    status = models.PositiveSmallIntegerField(choices=BusinessRuleResultStatus.choices)
+    message = models.TextField(null=True, blank=True)
+
+    @classmethod
+    def from_pass(cls, rule_model: BusinessRuleModel) -> BusinessRuleResult:
+        """Return a BusinessRuleResult with a PASSED status."""
+        return cls(
+            rule=rule_model,
+            status=BusinessRuleResultStatus.PASSED,
+            message=None,
+        )
+
+    @classmethod
+    def from_error(cls, rule, error):
+        """
+        Return a BusinessRuleResult from a BusinessRuleViolation or ordinary
+        Exception.
+
+        Users can optionally raise BusinessRuleViolations and errors as
+        exceptions by setting RAISE_BUSINESS_RULE_FAILURES or
+        RAISE_BUSINESS_RULE_ERRORS.
+        """
+        from common.business_rules import BusinessRuleViolation
+
+        if isinstance(error, BusinessRuleViolation):
+            if settings.RAISE_BUSINESS_RULE_FAILURES:
+                # During development it can be useful to raise business rule failures as exceptions.
+                raise error
+
+            return cls(
+                rule=rule,
+                status=BusinessRuleResultStatus.FAILED,
+                message=error.args[0],
+            )
+
+        if settings.RAISE_BUSINESS_RULE_ERRORS:
+            # During debugging it can be useful to raise business rule errors as exceptions.
+            raise error
+        return cls(rule=rule, status=BusinessRuleResultStatus.ERROR, message=str(error))
+
+    def __str__(self):
+        status = BusinessRuleResultStatus(self.status).name
+        if self.status == BusinessRuleResultStatus.PASSED:
+            return f"{self.rule} [{status}]"
+        return f"{self.rule} [{status}] \"{self.message or ''}\""
+
+    def __repr__(self):
+        return f"<BusinessRuleResult: {self}>"
+
+
+class BusinessRuleModel(Model):
+    """
+    Database representation of Business Rules.
+
+    This table is maintained by the sync_business_rules management command to
+    match the ALL_RULES dict; it is inadvisable to edit this table directly
+    outside of unit tests.
+
+    Since BusinessRule is already a widely used class in the system, this
+    model is named with the Model suffix.
+
+    Note: If a BusinessRule's implementation class is renamed, a data
+    migration may be required to keep existing data associated with the
+    business rule.
+    """
+
+    CACHED_RULE_MODELS: Dict[str, Type[Self]] = {}
+
+    objects = Manager.from_queryset(BusinessRuleModelQuerySet)()
+
+    name = models.CharField(max_length=255, unique=True)
+    """The name of the business rule"""
 
-class TrackedModelCheck(TimestampedMixin, TaskModel):
+    current = models.BooleanField(default=True)
+
+    def get_implementation(self):
+        return ALL_RULES[self.name]
+
+    @classmethod
+    def from_rule(cls, rule: Type[BusinessRule]) -> Self:
+        """
+        Given a BusinessRule, fetch its corresponding model from the cache or
+        database.
+
+        If the model is not found in the cache it will be fetched from the
+        database and cached before returning.
+        """
+        instance = cls.CACHED_RULE_MODELS.get(rule.__name__)
+        if instance is not None:
+            return instance
+
+        new_instance = BusinessRuleModel.objects.get(name=rule.__name__)
+        cls.CACHED_RULE_MODELS[rule.__name__] = new_instance
+        return new_instance
+
+    @classmethod
+    def from_rules(cls, rules: Iterable[Type[BusinessRule]]):
+        """
+        Fetch rule models: those already in the cache are yielded first; the
+        others are fetched from the database and added to the cache before
+        being yielded.
+
+        If any rules are not found in the cache or database a ValueError will
+        be raised.
+        """
+        rule_names = [rule.__name__ for rule in rules]
+        new_names = set(rule_names)
+
+        # Yield the cached rules first, removing each one from the new_names set
+        for rule_name in rule_names:
+            if rule_name in cls.CACHED_RULE_MODELS:
+                new_names.remove(rule_name)
+                yield cls.CACHED_RULE_MODELS[rule_name]
+
+        # Remaining rules must not be cached.
+        instances = BusinessRuleModel.objects.filter(name__in=new_names)
+        for instance in instances:
+            cls.CACHED_RULE_MODELS[instance.name] = instance
+            new_names.remove(instance.name)
+            yield instance
+
+        if new_names:
+            raise ValueError(f"{[*new_names]} not found in cache or database.")
+
+    def get_result(self, transaction, model):
+        """
+        Run a business rule on a model and return the result as a
+        BusinessRuleResult.
+
+        The result is not yet saved; this enables it to be used in bulk
+        operations.
+        """
+        logger.debug("run_rule %s %s %s", model, self.name, transaction.pk)
+        rule = self.get_implementation()
+        try:
+            rule(transaction).validate(model)
+            return BusinessRuleResult.from_pass(self)
+        except Exception as ex:
+            return BusinessRuleResult.from_error(self, ex)
+
+    def __str__(self):
+        return self.name
+
+    def __repr__(self):
+        if not self.current:
+            return f"<BusinessRuleModel: {self.name} [removed]>"
+        return f"<BusinessRuleModel: {self.name}>"
+
+
+class TrackedModelCheck(TaskModel):
     """
     Represents the result of running a single check against a single model.
 
@@ -21,32 +193,44 @@ class TrackedModelCheck(TimestampedMixin, TaskModel):
 
     class Meta:
-        unique_together = ("model", "check_name")
+        unique_together = (("model", "head_transaction"),)
 
     objects = PolymorphicManager.from_queryset(TrackedModelCheckQueryset)()
-    model = models.ForeignKey(
+
+    results = models.ManyToManyField(BusinessRuleResult)
+    """Results of running the requested business rules; the check can be
+    considered complete when there is a result for each business rule that was
+    requested to be checked.
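+    For example (illustrative): if results for ME1 and ME2 were requested,
+    the check is complete once a result row exists for both rules.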
""" + + model = models.OneToOneField( TrackedModel, - related_name="checks", + related_name="trackedmodel_check", on_delete=models.SET_NULL, null=True, ) - - check_name = fields.CharField(max_length=255) - """A string identifying the type of check carried out.""" - - successful = fields.BooleanField() - """True if the check was successful.""" - - message = fields.TextField(null=True) - """The text content returned by the check, if any.""" - content_hash = models.BinaryField(max_length=32, null=True) """ Hash of the content ('copyable_fields') at the time the data was checked. """ - def __str__(self): - if self.successful: - return f"{self.model} {self.check_name} [Passed at {self.updated_at}]" + head_transaction = models.ForeignKey( + Transaction, + on_delete=models.CASCADE, + ) + """ + The latest transaction in the stream of approved transactions (i.e. in the + REVISION partition) at the moment this check was carried out. + + Once new transactions are commited and the head transaction is no longer the + latest, this check will no longer be an accurate signal of correctness + because the new transactions could include new data which would invalidate + the checks. (Unless the checked transaction < head transaction, in which + case it will always be correct.) + """ + # TODO ^ update this with info on caching strategies. + + def delete(self): + self.content_hash = None + super().delete() - return f"{self.model} {self.check_name} [Failed at {self.updated_at}, Message: {self.message}]" + def report(self, requested_rules): + pass diff --git a/checks/querysets.py b/checks/querysets.py index a429e4c1e3..2ed6554f99 100644 --- a/checks/querysets.py +++ b/checks/querysets.py @@ -1,6 +1,58 @@ +from django.db import models +from django.db.models import QuerySet from django.db.transaction import atomic from polymorphic.query import PolymorphicQuerySet +from common.business_rules import ALL_RULES + + +class BusinessRuleResultStatus(models.IntegerChoices): + """ + The outcome of running a business rule. + + PASSED: The business rule passed. + FAILED: The business rule failed, message is populated. + ERROR: An exception occurred while running the business rule, the name is added to the message field. + """ + + PASSED = 1 + FAILED = 2 + ERROR = 3 + + +class BusinessRuleResultQuerySet(QuerySet): + def errored(self): + return self.filter(status=BusinessRuleResultStatus.ERROR) + + def failed(self): + return self.filter(status=BusinessRuleResultStatus.FAILED) + + def not_passed(self): + return self.filter(status__ne=BusinessRuleResultStatus.PASSED) + + def passed(self): + return self.filter(status=BusinessRuleResultStatus.PASSED) + + +class BusinessRuleModelQuerySet(QuerySet): + def current(self): + """Return business rules that have not been removed.""" + return self.filter(current=True) + + def get_updated_rules(self): + """ + :return (added, removed): Lists of rules that were added and removed since sync_business_rules was last run. 
+ """ + all_rules = set(self.model.objects.current().values_list("name", flat=True)) + added_rules = set() + + for rule_name in ALL_RULES.keys(): + if rule_name not in all_rules: + added_rules.add(rule_name) + all_rules.discard(rule_name) + + return list(added_rules), list(all_rules) + class TrackedModelCheckQueryset(PolymorphicQuerySet): def delete(self): diff --git a/checks/tasks.py b/checks/tasks.py index 200a02b902..f815b8eead 100644 --- a/checks/tasks.py +++ b/checks/tasks.py @@ -20,7 +20,6 @@ from common.business_rules import BusinessRule from common.celery import app from common.models.celerytask import ModelCeleryTask -from common.models.celerytask import bind_model_task from common.models.trackedmodel import TrackedModel from common.models.transactions import Transaction from common.models.utils import get_current_transaction @@ -40,6 +39,72 @@ """TaskInfo is a tuple of (task_id, task_name) which can be used to create a ModelCeleryTask.""" +# @app.task(trail=True) +# def check_model( +# transaction_pk: int, +# model_pk: int, +# rule_names: Optional[Sequence[str]] = None, +# bind_to_task_kwargs: Optional[Dict] = None, +# ): +# """ +# Task to check one model against one business rule and record the result. +# +# As this is a celery task, parameters are in base formats that can be serialised, such as int and str. +# +# Run one business rule against one model, this is called as part of the check_models workflow. +# +# By setting bind_to_task_uuid, the task will be bound to the celery task with the given UUID, +# this is useful for tracking the progress of the parent task, and cancelling it if needed. +# """ +# # XXXX - TODO, re-add note on timings, from Simons original code. +# +# if rule_names is None: +# rule_names = set(ALL_RULES.keys()) +# +# assert set(ALL_RULES.keys()).issuperset(rule_names) +# +# transaction = Transaction.objects.get(pk=transaction_pk) +# model = TrackedModel.objects.get(pk=model_pk) +# successful = True +# +# for checker in ALL_CHECKERS.values(): +# for checker_model, model_rules in checker.get_model_rule_mapping( +# model, +# rule_names, +# ).items(): +# """get_model_rules will return a different model in the case of +# LinkedModelChecker, so the model to check use checker_model.""" +# for rule in model_rules: +# logger.debug( +# "%s rule: %s, tx: %s, model: %s", +# checker.__name__, +# rule, +# transaction, +# model, +# ) +# check_result = checker.apply_rule_cached( +# rule, +# transaction, +# checker_model, +# ) +# if bind_to_task_kwargs: +# logger.debug( +# "Binding result %s to task. 
+#                         "Binding result %s to task. bind_to_task_kwargs: %s",
+#                         check_result.pk,
+#                         bind_to_task_kwargs,
+#                     )
+#                     bind_model_task(check_result, **bind_to_task_kwargs)
+#
+#                 logger.info(
+#                     f"Ran check %s %s",
+#                     check_result,
+#                     "✅" if check_result.successful else "❌",
+#                 )
+#                 successful &= check_result.successful
+#
+#     return successful
+
+
 @app.task(trail=True)
 def check_model(
     transaction_pk: int,
@@ -65,43 +130,16 @@ def check_model(
     assert set(ALL_RULES.keys()).issuperset(rule_names)
 
     transaction = Transaction.objects.get(pk=transaction_pk)
-    model = TrackedModel.objects.get(pk=model_pk)
+    initial_model = TrackedModel.objects.get(pk=model_pk)
+
     successful = True
 
     for checker in ALL_CHECKERS.values():
-        for checker_model, model_rules in checker.get_model_rules(
-            model,
+        for model, rules in checker.get_model_rule_mapping(
+            initial_model,
             rule_names,
         ).items():
-            """get_model_rules will return a different model in the case of
-            LinkedModelChecker, so the model to check use checker_model."""
-            for rule in model_rules:
-                logger.debug(
-                    "%s rule: %s, tx: %s, model: %s",
-                    checker.__name__,
-                    rule,
-                    transaction,
-                    model,
-                )
-                check_result = checker.apply_rule_cached(
-                    rule,
-                    transaction,
-                    checker_model,
-                )
-                if bind_to_task_kwargs:
-                    logger.debug(
-                        "Binding result %s to task. bind_to_task_kwargs: %s",
-                        check_result.pk,
-                        bind_to_task_kwargs,
-                    )
-                    bind_model_task(check_result, **bind_to_task_kwargs)
-
-                logger.info(
-                    f"Ran check %s %s",
-                    check_result,
-                    "✅" if check_result.successful else "❌",
-                )
-                successful &= check_result.successful
+            Checker.apply_rules(rules, transaction, model)
 
     return successful
diff --git a/checks/tests/test_checkers.py b/checks/tests/test_checkers.py
index 426c5074b9..751f1ae2fd 100644
--- a/checks/tests/test_checkers.py
+++ b/checks/tests/test_checkers.py
@@ -29,7 +29,7 @@ def test_business_rules_validation(applicable_rules, rule_filter, expected_rules
     model = factories.TestModel1Factory.create()
 
     with add_business_rules(type(model), *applicable_rules):
-        model_rules = BusinessRuleChecker.get_model_rules(model)
+        model_rules = BusinessRuleChecker.get_model_rule_mapping(model)
 
     assert isinstance(model_rules, dict)
     if not expected_rules:
@@ -73,7 +73,7 @@ def test_business_rules_validation_raises_exception_for_unknown_rule(
         TestRule1,
         indirect=True,
     ):
-        model_rules = checker.get_model_rules(model)
+        model_rules = checker.get_model_rule_mapping(model)
 
     assert isinstance(model_rules, dict)
     with pytest.raises(ValueError, match=expected_error_message):
diff --git a/common/app_config.py b/common/app_config.py
index 86169b8aae..c64fab24c0 100644
--- a/common/app_config.py
+++ b/common/app_config.py
@@ -1,7 +1,10 @@
 import logging
+import os
+import sys
 from importlib import import_module
 
 from django.apps import AppConfig
+from django.db import DEFAULT_DB_ALIAS
 
 logger = logging.getLogger(__file__)
 
@@ -22,7 +25,7 @@ class CommonConfig(AppConfig):
     AppConfig.ready method.
     """
 
-    def ready(self):
+    def load_importer_modules(self):
         """Load importer parser and handler modules, if they exist."""
         modules_to_import = [
             f"{self.name}.{IMPORT_PARSER_NAME}",
@@ -33,3 +36,46 @@ def ready(self):
             import_module(module)
         except ModuleNotFoundError:
             logger.debug(f"Failed to import {module}")
+
+    def warn_if_business_rules_changed(self):
+        """
+        Output a message if the business rules in the app don't match those in
+        the database.
+
+        A data migration to sync the rules may be created using the
+        sync_business_rules management command.
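+
+        Example output (illustrative):
+
+            Business rules are not synced to the database. (Added: 2, Removed: 1)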
+
+        :return: True if the rules need syncing, or there are unapplied migrations to this app.
+        """
+        from checks.models import BusinessRuleModel
+        from common.models.utils import is_database_synchronized
+
+        if not is_database_synchronized(DEFAULT_DB_ALIAS):
+            logger.debug(
+                "Database has pending migrations, run them before checking if business rule sync is required.",
+            )
+            # There are unapplied migrations (some of which may be ones needed to run the business rules.)
+            return True
+
+        added, removed = BusinessRuleModel.objects.get_updated_rules()
+        sync_required = bool(added or removed)
+        if sync_required:
+            print(
+                f"Business rules are not synced to the database. (Added: {len(added)}, Removed: {len(removed)})",
+                file=sys.stderr,
+            )
+            print(
+                "Create a data migration to sync the rules using the management command:\n  "
+                "sync_business_rules",
+                file=sys.stderr,
+            )
+
+        return sync_required
+
+    def ready(self):
+        in_runserver = bool(os.environ.get("RUN_MAIN"))
+
+        self.load_importer_modules()
+        if in_runserver:
+            if self.warn_if_business_rules_changed():
+                sys.exit(1)
diff --git a/common/models/tracked_qs.py b/common/models/tracked_qs.py
index 5b7cc337c3..8eabaa0841 100644
--- a/common/models/tracked_qs.py
+++ b/common/models/tracked_qs.py
@@ -1,10 +1,8 @@
 from __future__ import annotations
 
 from hashlib import sha256
-from itertools import chain
 from typing import List
 
-from django.contrib.contenttypes.models import ContentType
 from django.db.models import Case
 from django.db.models import CharField
 from django.db.models import F
@@ -444,55 +442,56 @@ def content_hash(self):
         """
         return self._content_hash(self.order_by("pk").iterator())
 
-    def group_by_type(self):
-        """Yield a sequence of query sets, where each queryset contains only one
-        Polymorphic ctype, enabling the use of prefetch and select_related on
-        them."""
-        pks = self.values_list("pk", flat=True)
-        polymorphic_ctypes = (
-            self.non_polymorphic()
-            .distinct("polymorphic_ctype_id")
-            .values_list("polymorphic_ctype", flat=True)
-        )
-
-        for polymorphic_ctype in polymorphic_ctypes:
-            # Query contenttypes to get the concrete class instance
-            klass = ContentType.objects.get_for_id(polymorphic_ctype).model_class()
-            yield klass.objects.filter(pk__in=pks)
-
-    def select_related_copyable_fields(self):
-        """Split models into separate querysets, using group_by_type and call
-        select_related on any related fields found in the `copyable_fields`
-        attribute."""
-        pks = self.values_list("pk", flat=True)
-        for qs in self.group_by_type():
-            # Work out which fields from copyable_fields may be use in select_related
-            related_fields = [
-                field.name
-                for field in qs.model.copyable_fields
-                if hasattr(field, "related_query_name")
-            ]
-            yield qs.select_related(*related_fields).filter(pk__in=pks)
-
-    def content_hash_fast(self):
-        """
-        Use `select_related_copyable_fields` to call select_related on fields
-        that will be hashed.
-
-        This increases the speed a little more than 2x, at the expense of keeping the data in memory.
-        On this developers' laptop 2.3 seconds vs 6.5 for the naive implementation in `content_hash`,
-        for larger amounts of data the difference got bigger, 23 seconds vs 90, though this may
-        because more types of data were represented.
-
-        For larger workbaskets batching should be used to keep memory usage withing reasonable bounds.
-
-        The hash value returned here should be the same as that from `content_hash`.
-        """
-        # Fetch data using select_related, at this point the ordering
-        # will have been lost.
-        all_models = chain(*self.select_related_copyable_fields())
-
-        # Sort the data using trackedmodel_ptr_id, since the previous step outputs
-        # an iterable, sorted is used, instead of order_by on a queryset.
-        sorted_models = sorted(all_models, key=lambda o: o.trackedmodel_ptr_id)
-        return self._content_hash(sorted_models)
+    # def group_by_type(self):
+    #     """Yield a sequence of query sets, where each queryset contains only one
+    #     Polymorphic ctype, enabling the use of prefetch and select_related on
+    #     them."""
+    #     pks = self.values_list("pk", flat=True)
+    #     polymorphic_ctypes = (
+    #         self.non_polymorphic()
+    #         .only("polymorphic_ctype_id")
+    #         .distinct("polymorphic_ctype_id")
+    #         .values_list("polymorphic_ctype_id", flat=True)
+    #     )
+    #
+    #     for polymorphic_ctype in polymorphic_ctypes:
+    #         # Query contenttypes to get the concrete class instance
+    #         klass = ContentType.objects.get_for_id(polymorphic_ctype).model_class()
+    #         yield klass.objects.filter(pk__in=pks)
+    #
+    # def select_related_copyable_fields(self):
+    #     """Split models into separate querysets, using group_by_type and call
+    #     select_related on any related fields found in the `copyable_fields`
+    #     attribute."""
+    #     pks = self.values_list("pk", flat=True)
+    #     for qs in self.group_by_type():
+    #         # Work out which fields from copyable_fields may be used in select_related
+    #         related_fields = [
+    #             field.name
+    #             for field in qs.model.copyable_fields
+    #             if hasattr(field, "related_query_name")
+    #         ]
+    #         yield qs.select_related(*related_fields).filter(pk__in=pks)
+    #
+    # def content_hash_fast(self):
+    #     """
+    #     Use `select_related_copyable_fields` to call select_related on fields
+    #     that will be hashed.
+    #
+    #     This increases the speed a little more than 2x, at the expense of keeping the data in memory.
+    #     On this developer's laptop 2.3 seconds vs 6.5 for the naive implementation in `content_hash`;
+    #     for larger amounts of data the difference got bigger, 23 seconds vs 90, though this may
+    #     be because more types of data were represented.
+    #
+    #     For larger workbaskets batching should be used to keep memory usage within reasonable bounds.
+    #
+    #     The hash value returned here should be the same as that from `content_hash`.
+    #     """
+    #     # Fetch data using select_related, at this point the ordering
+    #     # will have been lost.
+    #     all_models = chain(*self.select_related_copyable_fields())
+    #
+    #     # Sort the data using trackedmodel_ptr_id, since the previous step outputs
+    #     # an iterable, sorted is used, instead of order_by on a queryset.
+    #     sorted_models = sorted(all_models, key=lambda o: o.trackedmodel_ptr_id)
+    #     return self._content_hash(sorted_models)
diff --git a/common/models/tracked_utils.py b/common/models/tracked_utils.py
index 7e1627cc0d..ae408194fe 100644
--- a/common/models/tracked_utils.py
+++ b/common/models/tracked_utils.py
@@ -102,6 +102,11 @@ def get_field_hashable_string(value):
     return f"{value_type.__module__}:{value_type.__name__}={value}"
 
 
+class NotPresent:
+    # Sentinel value for fields that are not present.
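+    # Using a sentinel rather than None distinguishes "field not present on
+    # the instance" from "field present with value None" when building the
+    # hashable string.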
+    pass
+
+
 def get_field_hashable_strings(instance, fields):
     """
     Given a model instance, return a dict of {field names: hashable string},
@@ -113,6 +118,6 @@
     :return: Dictionary of {field_name: hash}
     """
     return {
-        field.name: get_field_hashable_string(getattr(instance, field.name))
+        field.name: get_field_hashable_string(getattr(instance, field.name, NotPresent))
         for field in fields
     }
diff --git a/common/models/utils.py b/common/models/utils.py
index d2c841ef5a..3889070e26 100644
--- a/common/models/utils.py
+++ b/common/models/utils.py
@@ -3,6 +3,8 @@
 from typing import FrozenSet
 
 import wrapt
+from django.db import connections
+from django.db.migrations.executor import MigrationExecutor
 from django.db.models import Value
 
 _thread_locals = threading.local()
@@ -111,3 +113,37 @@ def __call__(self, request):
         response = self.get_response(request)
         # No post-view processing required.
         return response
+
+
+def is_database_synchronized(database):
+    # https://stackoverflow.com/a/31847406/62709
+    connection = connections[database]
+    connection.prepare_database()
+    executor = MigrationExecutor(connection)
+    targets = executor.loader.graph.leaf_nodes()
+    return not executor.migration_plan(targets)
+
+
+def ansi_hyperlink(uri, label=None, parameters=None):
+    """
+    Return an ANSI escape sequence that will hyperlink to the specified URI.
+
+    TODO: Linking to the spec for ANSI escape sequences, or to this example,
+    causes one of our post-commit hooks to output 'entropy check failed'.
+
+    For more info, see the GitHub gist "Hyperlinks (a.k.a. HTML-like anchors)
+    in terminal emulators" and the many comments on it by egmontkob:
+
+    https://gist.github.com/egmontkob
+    """
+    if label is None:
+        label = uri
+    if parameters is None:
+        parameters = ""
+
+    # OSC 8 ; params ; URI ST <label> OSC 8 ;; ST
+    escape_mask = "\033]8;{};{}\033\\{}\033]8;;\033\\"
+
+    return escape_mask.format(parameters, uri, label)
diff --git a/pii-ner-exclude.txt b/pii-ner-exclude.txt
index 1e3fd37f4f..e515bdc49a 100644
--- a/pii-ner-exclude.txt
+++ b/pii-ner-exclude.txt
@@ -1215,3 +1215,27 @@ checker.apply_rule
 TestRule
 LinkedModelsBusinessRuleChecker.of(TestRule
 assert len(checkers
+run_rule
+rule.__name
+A``Checker
+mark_removed_rules(apps
+f"python
+ANSI
+# Query
+klass.objects.filter(pk__in
+content_hash_fast(self
+Sort
+self._content_hash(sorted_models
+Business Rules
+Create a TrackedModelBusinessRuleViolation
+Sync
+Pending Addition"
+CELERY_TASK_ALWAYS_EAGER
+assert set(ALL_RULES.keys()).issuperset(rule_names
+check_result
+successful &= check_result.successful
+successful &=
+ALL_RULES
+\s|
+github https://gist.github.com/egmontkob <-
+Hyperlinks (a.k.a
diff --git a/quotas/models.py b/quotas/models.py
index 5448155e6c..7b3af670c6 100644
--- a/quotas/models.py
+++ b/quotas/models.py
@@ -126,7 +126,7 @@ class QuotaOrderNumberOrigin(TrackedModel, ValidityMixin):
     )
 
     def __str__(self):
-        return self.sid
+        return f"{self.sid}"
 
     def order_number_in_use(self, transaction):
         return self.order_number.in_use(transaction)
diff --git a/settings/common.py b/settings/common.py
index 5d0caedab2..00698e1413 100644
--- a/settings/common.py
+++ b/settings/common.py
@@ -370,6 +370,10 @@
     },
 }
 
+RAISE_BUSINESS_RULE_ERRORS = is_truthy(
+    os.environ.get("RAISE_BUSINESS_RULE_ERRORS", "N"),
+)
+
 RAISE_BUSINESS_RULE_FAILURES = is_truthy(
     os.environ.get("RAISE_BUSINESS_RULE_FAILURES", "N"),
 )
diff --git a/workbaskets/management/commands/list_workbaskets.py b/workbaskets/management/commands/list_workbaskets.py
index 631ecbee2a..7eecbe8722 100644
--- a/workbaskets/management/commands/list_workbaskets.py
+++ b/workbaskets/management/commands/list_workbaskets.py
@@ -56,8 +56,7 @@ def add_arguments(self, parser: CommandParser) -> None:
 
         parser.add_argument(
             "workbasket_ids",
-            nargs="*",
-            help="Comma-separated list of workbasket ids to filter to",
+            help=("Comma-separated list of workbasket ids to filter to"),
             type=ast.literal_eval,
         )
diff --git a/workbaskets/management/commands/run_checks.py b/workbaskets/management/commands/run_checks.py
index 854a27e94e..a34b8e08a1 100644
--- a/workbaskets/management/commands/run_checks.py
+++ b/workbaskets/management/commands/run_checks.py
@@ -1,4 +1,5 @@
 import logging
+import re
 import signal
 import sys
 from typing import Any
@@ -6,10 +7,12 @@
 from typing import Optional
 
 from celery.result import AsyncResult
+from celery.result import EagerResult
 from celery.result import GroupResult
 from django.core.management import BaseCommand
 from django.core.management.base import CommandParser
 
+from checks.models import BusinessRuleModel
 from checks.models import TrackedModelCheck
 from workbaskets.models import WorkBasket
 
@@ -35,12 +38,16 @@ def revoke_task_and_children(task, depth=0):
 
 
 class TaskControlMixin:
-    # Implementing classes can populate this with
+    # Implementing classes can populate this with task-name prefixes (e.g.
+    # module names) to strip, making the display less verbose.
     IGNORE_TASK_PREFIXES = []
 
     def get_readable_task_name(self, node):
         """Optionally remove a prefix from the task name (used to remove the
         module which is often repeated)"""
+        if isinstance(node, EagerResult) and not node.name:
+            # Name isn't available when CELERY_TASK_ALWAYS_EAGER is set :(
+            return f"Eager Task: {node.id}"
+
         task_name = getattr(node, "name") or ""
         for prefix in self.IGNORE_TASK_PREFIXES:
             unprefixed = task_name.replace(prefix, "")
@@ -79,21 +86,21 @@ def sigint_handler(sig, frame):
 
     def display_task(self, node, value, depth):
         """Default task display."""
-        # For now this only shows args, to avoid some noise - if this is more widely
-        # used that may be something that could be configured by the caller.
+        # Only shows task.args, to avoid some noise.
         readable_task_name = self.get_readable_task_name(
             node,
        )
+        # node.args can be None when CELERY_TASK_ALWAYS_EAGER is set -
+        # when running eagerly the full API isn't available.
         self.stdout.write(
-            " " * depth * 2 + f"{readable_task_name} " + f"{tuple(node.args)}",
+            " " * depth * 2 + f"{readable_task_name} " + f"{tuple(node.args or ())}",
         )
 
     def iterate_ongoing_tasks(self, result, ignore_groupresult=True):
         """
-        Iterate over the ongoing tasks as they are received, track their.
-
-        depth - which is useful for visual formatting.
+        Iterate over the ongoing tasks as they are received; "depth" is tracked
+        to enable visual formatting.
 
         Yields: (node, value, depth)
         """
@@ -114,6 +121,7 @@ def iterate_ongoing_tasks(self, result, ignore_groupresult=True):
             yield node, value, depth
 
     def display_ongoing_tasks(self, result):
+        """Iterate the task tree and display info as results are received."""
         for node, value, depth in self.iterate_ongoing_tasks(result):
             self.display_task(node, value, depth)
@@ -127,6 +135,8 @@ class Command(TaskControlMixin, BaseCommand):
         "checks.tasks.",
     ]
 
+    rule_names = []
+    rule_models = None
     passed = 0
     failed = 0
@@ -144,16 +154,30 @@ def add_arguments(self, parser: CommandParser) -> None:
             default=False,
         )
 
+        parser.add_argument(
+            "rules",
+            type=str,
+            help="Check only these rules (comma-separated list): 'rule_name1,rule_name2'",
+        )
+
     def display_check_model_task(self, node, value, depth):
         model_pk = node.args[1]
         check_passed = value
         readable_task_name = self.get_readable_task_name(node)
         style = self.style.SUCCESS if check_passed else self.style.ERROR
+
+        check = TrackedModelCheck.objects.filter(model=model_pk).last()
+
+        if check is None:
+            check_msg = "[No results]"
+        else:
+            check_msg = check.report(self.rule_names)
+
         self.stdout.write(
             " " * depth * 2
             + f"{readable_task_name} "
             + style(
-                f"[{model_pk}] {TrackedModelCheck.objects.filter(model=model_pk).last()}",
+                f"[{model_pk}] {check_msg}",
             ),
         )
@@ -175,6 +199,35 @@ def iterate_ongoing_tasks(self, result, ignore_groupresult=True):
                     self.failed += 1
 
             yield node, value, depth
 
+    def parse_rule_names_option(self, rule_names_option: str):
+        """
+        Given a comma-separated list of rule names, return a list of rule
+        names and their corresponding models.
+
+        Also handles the case where the user includes spaces.
+        """
+        # Split by comma, but be kind and eat spaces too.
+        rule_names = [name for name in re.split(r"[\s,]+", rule_names_option) if name]
+
+        # The user may limit the check to particular rules.
+        if rule_names:
+            rule_models = BusinessRuleModel.objects.current().filter(
+                name__in=rule_names,
+            )
+            if rule_models.count() != len(rule_names):
+                # TODO - be nice to the user and show which rules are missing.
+                self.stderr.write(
+                    "One or more rules not found: " + ", ".join(rule_names),
+                )
+                sys.exit(2)
+        else:
+            # Default to all rules being checked.
+            rule_models = BusinessRuleModel.objects.current().all()
+            rule_names = [*rule_models.values_list("name", flat=True)]
+
+        return rule_names, rule_models
+
     def handle(self, *args: Any, **options: Any) -> Optional[str]:
         from checks.tasks import check_workbasket
 
@@ -183,11 +236,14 @@
             pk=int(options["WORKBASKET_PK"]),
         )
         clear_cache = options["clear_cache"]
-        rule_names = None
         throw = options["throw"]
 
-        # Temporarily display a message while waiting for celery, this will only have time to show up
-        # if celery isn't working (easy enough on a dev machine), or is busy.
+        self.rule_names, self.rule_models = self.parse_rule_names_option(
+            options["rules"],
+        )
+
+        # Temporarily display a message while waiting for celery to acknowledge the task;
+        # if this stays on the screen it's a sign celery is either busy or not running.
         self.stdout.write("Connecting to celery... ⌛", ending="")
         self.stdout._out.flush()  # self.stdout.flush() doesn't result in any output - TODO: report as a bug to django.
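+        # Sketch of the dispatch below: "rules" carries the parsed rule names;
+        # clear_cache is forwarded from the --clear-cache option (assumption:
+        # it drops previously cached check results before re-running).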
         result = check_workbasket.apply_async(
@@ -197,7 +253,7 @@
             ),
             kwargs={
                 "clear_cache": clear_cache,
-                "rules": rule_names,
+                "rules": self.rule_names,
             },
             throw=throw,
         )
@@ -207,6 +263,7 @@
 
         # Attach a handler to revoke the task and its subtasks if the user presses Ctrl+C
         self.revoke_task_on_sigint(result)
 
+        # Display tasks as they complete
         self.display_ongoing_tasks(result)
 
         self.stdout.write()
diff --git a/workbaskets/management/util.py b/workbaskets/management/util.py
index 59c89f64f8..3c6be99c34 100644
--- a/workbaskets/management/util.py
+++ b/workbaskets/management/util.py
@@ -22,14 +22,8 @@ def _output_workbasket_readable(
         self.stdout.write(f"{spaces}reason: {first_line_of(workbasket.reason)}")
         self.stdout.write(f"{spaces}status: {workbasket.status}")
         if show_transaction_info:
-            transactions = workbasket.transactions
-            first_pk = (
-                workbasket.transactions.first().pk if transactions.count() else "-"
-            )
-            last_pk = workbasket.transactions.last().pk if transactions.count() else "-"
-
             self.stdout.write(
-                f"{spaces}transactions: {first_pk} - {last_pk} [{transactions.count()}]",
+                f"{spaces}transactions: {workbasket.transactions.first().pk} - {workbasket.transactions.last().pk}",
             )
 
     def _output_workbasket_compact(self, workbasket, show_transaction_info, **kwargs):
@@ -38,13 +32,8 @@
             ending="" if show_transaction_info else "\n",
         )
         if show_transaction_info:
-            transactions = workbasket.transactions
-            first_pk = (
-                workbasket.transactions.first().pk if transactions.count() else "-"
-            )
-            last_pk = workbasket.transactions.last().pk if transactions.count() else "-"
             self.stdout.write(
-                f", {first_pk} - {last_pk} [{transactions.count()}]",
+                f", {workbasket.transactions.first().pk} - {workbasket.transactions.last().pk}",
             )
 
     def output_workbasket(