diff --git a/common/forms.py b/common/forms.py index 4189ffe0a..6c5de5117 100644 --- a/common/forms.py +++ b/common/forms.py @@ -425,7 +425,7 @@ def __init__(self, *args, **kwargs): self.fields["end_date"].help_text = ( f"Leave empty if {get_model_indefinite_article(self.instance)} " - f"{self.instance._meta.verbose_name} is needed for an unlimited time." + f"{self.instance._meta.verbose_name} is needed for an unlimited time" ) if self.instance.valid_between: diff --git a/common/jinja2/common/app_info.jinja b/common/jinja2/common/app_info.jinja index 5dfd98afd..c462ec71b 100644 --- a/common/jinja2/common/app_info.jinja +++ b/common/jinja2/common/app_info.jinja @@ -52,9 +52,9 @@ {"text": "Environment variable"}, ], [ - {"text": "APP_UPDATED_TIME"}, - {"text": APP_UPDATED_TIME}, - {"text": "Estimated application deploy time"}, + {"text": "UPTIME"}, + {"text": UPTIME}, + {"text": "Time this instance has been in service"}, ], [ {"text": "LAST_TRANSACTION_TIME"}, diff --git a/common/tests/test_util.py b/common/tests/test_util.py index 66ed96f2a..ba956d454 100644 --- a/common/tests/test_util.py +++ b/common/tests/test_util.py @@ -1,3 +1,4 @@ +import json import os from unittest import mock @@ -14,6 +15,64 @@ pytestmark = pytest.mark.django_db +@pytest.mark.parametrize( + "environment_key, expected_result", + ( + ( + { + "engine": "engine", + "username": "username", + "password": "password", + "host": "host", + "port": 1234, + "dbname": "dbname", + }, + "engine://username:password@host:1234/dbname", + ), + ( + { + "engine": "engine", + "username": "username", + "host": "host", + "dbname": "dbname", + }, + "engine://username@host/dbname", + ), + ( + { + "engine": "engine", + "host": "host", + "dbname": "dbname", + }, + "engine://host/dbname", + ), + ( + { + "engine": "engine", + "password": "password", + "port": 1234, + "dbname": "dbname", + }, + "engine:///dbname", + ), + ( + { + "engine": "engine", + "dbname": "dbname", + }, + "engine:///dbname", + ), + ), +) +def test_database_url_from_env(environment_key, expected_result): + with mock.patch.dict( + os.environ, + {"DATABASE_CREDENTIALS": json.dumps(environment_key)}, + clear=True, + ): + assert util.database_url_from_env("DATABASE_CREDENTIALS") == expected_result + + @pytest.mark.parametrize( "value, expected", [ diff --git a/common/util.py b/common/util.py index 7f06fb9c1..667f96b79 100644 --- a/common/util.py +++ b/common/util.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json import os import re import typing @@ -54,10 +55,53 @@ major, minor, patch = python_version_tuple() +def is_cloud_foundry(): + """Return True if the deployment environment contains a `VCAP_SERVICES` env + var, indicating a CloudFoundry environment, False otherwise.""" + return "VCAP_SERVICES" in os.environ + + def classproperty(fn): return classmethod(property(fn)) +def database_url_from_env(environment_key: str) -> str: + """ + Return a database URL string from the environment variable identified by + `environment_key`. The environment variable should be parsable as a + JSON-like string and may contain the keys: + + "engine" (Required) - database engine id. For instance "postgres" or "sqlite". + "username" (Optional if "password" is not present) - database user name. + "password" (Optional) - database user's password. + "host" (Optional if "port" is not present) - database hostname. + "port" (Optional) - database host port. + "dbname" (Required) - database name. 
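+
+    For example (mirroring the accompanying tests in common/tests/test_util.py),
+    a value of '{"engine": "engine", "username": "username", "password":
+    "password", "host": "host", "port": 1234, "dbname": "dbname"}' yields
+    "engine://username:password@host:1234/dbname".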
+
+    If all keys are present, then the returned result would be a string of the
+    form:
+
+        <engine>://<username>:<password>@<host>:<port>/<dbname>
+
+    This is a plug-in, less naive version of
+    `dbt_copilot_python.database.database_url_from_env()` making `username`,
+    `password`, `host` and `port` optional as described above.
+    """
+    config = json.loads(os.environ[environment_key])
+
+    username = config.get("username", "")
+    password = config.get("password")
+    host = config.get("host", "")
+    port = config.get("port")
+
+    config["username"] = username
+    config["password"] = f":{password}" if username and password else ""
+    config["host"] = f"@{host}" if (username or password) and host else host
+    config["port"] = f":{port}" if host and port else ""
+
+    return "{engine}://{username}{password}{host}{port}/{dbname}".format(**config)
+
+
 def is_truthy(value: Union[str, bool]) -> bool:
     """
     Check whether a string represents a True boolean value.
diff --git a/common/views/pages.py b/common/views/pages.py
index 87f061334..bc2259bdd 100644
--- a/common/views/pages.py
+++ b/common/views/pages.py
@@ -1,8 +1,10 @@
 """Common views."""
 
+import logging
 import os
 import time
 from datetime import datetime
+from datetime import timedelta
 from typing import Dict
 from typing import List
 from typing import Optional
@@ -36,6 +38,7 @@
 from common.celery import app as celery_app
 from common.forms import HomeSearchForm
 from common.models import Transaction
+from common.util import is_cloud_foundry
 from exporter.sqlite.util import sqlite_dumps
 from footnotes.models import Footnote
 from geo_areas.models import GeographicalArea
@@ -47,6 +50,8 @@
 from workbaskets.models import WorkBasket
 from workbaskets.models import WorkflowStatus
 
+logger = logging.getLogger(__name__)
+
 
 class HomeView(LoginRequiredMixin, FormView):
     template_name = "common/homepage.jinja"
@@ -322,6 +327,33 @@ def get(self, request, *args, **kwargs) -> HttpResponse:
         )
 
 
+def get_uptime() -> str:
+    """
+    Return approximate system uptime in a platform-independent way as a string
+    in the following format:
+    "<days> days, <hours> hours, <minutes> minutes"
+    """
+    try:
+        if is_cloud_foundry():
+            # CF recycles Garden containers so time.monotonic() returns a
+            # misleading value. However, file modified time is set on deployment.
+            uptime = timedelta(seconds=(time.time() - os.path.getmtime(__file__)))
+        else:
+            # time.monotonic() doesn't count time spent in hibernation, so may
+            # be inaccurate on systems that hibernate.
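+            # Note: Python leaves time.monotonic()'s reference point undefined;
+            # treating it as "seconds since boot" is an assumption that holds
+            # on typical Linux hosts.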
+ uptime = timedelta(seconds=time.monotonic()) + + formatted_uptime = ( + f"{uptime.days} days, {uptime.seconds // 3600} hours, " + f"{uptime.seconds // 60 % 60} minutes" + ) + except Exception as e: + logger.error(e) + formatted_uptime = "Error getting uptime" + + return formatted_uptime + + class AppInfoView( LoginRequiredMixin, TemplateView, @@ -416,9 +448,7 @@ def get_context_data(self, **kwargs): if self.request.user.is_superuser: data["GIT_BRANCH"] = os.getenv("GIT_BRANCH", "Unavailable") data["GIT_COMMIT"] = os.getenv("GIT_COMMIT", "Unavailable") - data["APP_UPDATED_TIME"] = AppInfoView.timestamp_to_datetime_string( - os.path.getmtime(__file__), - ) + data["UPTIME"] = get_uptime() last_transaction = Transaction.objects.order_by("updated_at").last() data["LAST_TRANSACTION_TIME"] = ( format( diff --git a/exporter/sqlite/__init__.py b/exporter/sqlite/__init__.py index e85ac64c1..0450c3a21 100644 --- a/exporter/sqlite/__init__.py +++ b/exporter/sqlite/__init__.py @@ -73,8 +73,8 @@ def make_export(connection: apsw.Connection): Path(temp_sqlite_db.name), ) plan = make_export_plan(plan_runner) - # make_tamato_database() creates a Connection instance that needs - # closing once an in-memory plan has been created from it. + # Runner.make_tamato_database() (above) creates a Connection instance + # that needs closing once an in-memory plan has been created from it. plan_runner.database.close() export_runner = runner.Runner(connection) diff --git a/exporter/sqlite/runner.py b/exporter/sqlite/runner.py index 070a87cf2..ae31bb75b 100644 --- a/exporter/sqlite/runner.py +++ b/exporter/sqlite/runner.py @@ -1,9 +1,11 @@ import json import logging import os +import shutil +import subprocess import sys from pathlib import Path -from subprocess import run +from tempfile import TemporaryDirectory from typing import Iterable from typing import Iterator from typing import Tuple @@ -16,80 +18,203 @@ logger = logging.getLogger(__name__) -class Runner: - """Runs commands on an SQLite database.""" +def normalise_loglevel(loglevel): + """ + Attempt conversion of `loglevel` from a string integer value (e.g. "20") to + its loglevel name (e.g. "INFO"). - database: apsw.Connection + This function can be used after, for instance, copying log levels from + environment variables, when the incorrect representation (int as string + rather than the log level name) may occur. + """ + try: + return logging._levelToName.get(int(loglevel)) + except: + return loglevel - def __init__(self, database: apsw.Connection) -> None: - self.database = database - @classmethod - def normalise_loglevel(cls, loglevel): - """ - Attempt conversion of `loglevel` from a string integer value (e.g. "20") - to its loglevel name (e.g. "INFO"). +SQLITE_MIGRATIONS_NAME = "sqlite_export" +"""Name passed to `manage.py makemigrations`, via the --name flag, when creating +the SQLite migrations source files.""" - This function can be used after, for instance, copying log levels from - environment variables, when the incorrect representation (int as string - rather than the log level name) may occur. 
- """ - try: - return logging._levelToName.get(int(loglevel)) - except: - return loglevel +SQLITE_MIGRATIONS_GLOB = f"**/migrations/*{SQLITE_MIGRATIONS_NAME}.py" +"""Glob pattern matching all SQLite-specific migration source files generated by +the `manage.py makemigrations --name sqlite_export` command.""" - @classmethod - def manage(cls, sqlite_file: Path, *args: str): + +class SQLiteMigrationCurrentDirectory: + """ + Context manager class that uses the application's current base directory for + managing SQLite migrations. + + Upon exiting the context manager, SQLite-specific migration files are + deleted. + """ + + def __enter__(self): + logger.info(f"Entering context manager {self.__class__.__name__}") + return settings.BASE_DIR + + def __exit__(self, exc_type, exc_value, traceback): + logger.info(f"Exiting context manager {self.__class__.__name__}") + for file in Path(settings.BASE_DIR).rglob(SQLITE_MIGRATIONS_GLOB): + file.unlink() + + +class SQLiteMigrationTemporaryDirectory(TemporaryDirectory): + """ + Context manager class that provides a newly created temporary directory + (under the OS's temporary directory system) for managing SQLite migrations. + + Upon exiting the context manager, the temporary directory is deleted. + """ + + def __enter__(self): + logger.info(f"Entering context manager {self.__class__.__name__}") + + tmp_dir = super().__enter__() + tmp_dir = os.path.join(tmp_dir, "tamato_sqlite_migration") + shutil.copytree(settings.BASE_DIR, tmp_dir) + + # Ensure migrations directories are writable to allow SQLite migrations + # to be created - some deployments make source tree directories + # non-wriable. + for d in [p for p in Path(tmp_dir).rglob("migrations") if p.is_dir()]: + d.chmod(0o777) + + copied_files = [f for f in Path(tmp_dir).rglob("*") if f.is_file()] + logger.info(f"Copied {len(copied_files)} files to {tmp_dir}") + + return tmp_dir + + def __exit__(self, exc_type, exc_value, traceback): + logger.info(f"Exiting context manager {self.__class__.__name__}") + super().__exit__(exc_type, exc_value, traceback) + + +class SQLiteMigrator: + """ + Populates a new and empty SQLite database file with the Tamato database + schema derived from Tamato's models. + + This is required because SQLite uses different fields to PostgreSQL, missing + migrations are first generated to bring in the different style of validity + fields. + + This is done by creating additional, auxiliary migrations that are specific + to the SQLite and then executing them to populate the database with the + schema. 
+ """ + + sqlite_file: Path + + def __init__(self, sqlite_file: Path, migrations_in_tmp_dir=False): + self.sqlite_file = sqlite_file + self.migration_directory_class = ( + SQLiteMigrationTemporaryDirectory + if migrations_in_tmp_dir + else SQLiteMigrationCurrentDirectory + ) + + def migrate(self): + from manage import ENV_INFO_FLAG + + with self.migration_directory_class() as migration_dir: + logger.info(f"Running `makemigrations` in {migration_dir}") + self.manage( + migration_dir, + ENV_INFO_FLAG, + "makemigrations", + "--name", + SQLITE_MIGRATIONS_NAME, + ) + + sqlite_migration_files = [ + f + for f in Path(migration_dir).rglob(SQLITE_MIGRATIONS_GLOB) + if f.is_file() + ] + logger.info( + f"{len(sqlite_migration_files)} SQLite migration files " + f"generated in {migration_dir}", + ) + + logger.info(f"Running `migrate` in {migration_dir}") + self.manage( + migration_dir, + ENV_INFO_FLAG, + "migrate", + ) + + def manage(self, exec_dir: str, *manage_args: str): """ Runs a Django management command on the SQLite database. This management command will be run such that ``settings.SQLITE`` is True, allowing SQLite specific functionality to be switched on and off using the value of this setting. + + `exec_dir` sets the directory in which the management command should be + executed. """ + sqlite_env = os.environ.copy() # Correct log levels that are incorrectly expressed as string ints. if "CELERY_LOG_LEVEL" in sqlite_env: - sqlite_env["CELERY_LOG_LEVEL"] = cls.normalise_loglevel( + sqlite_env["CELERY_LOG_LEVEL"] = normalise_loglevel( sqlite_env["CELERY_LOG_LEVEL"], ) - sqlite_env["DATABASE_URL"] = f"sqlite:///{str(sqlite_file)}" - # Required to make sure the postgres default isn't set as the DB_URL + # Set up environment-specific env var values. if sqlite_env.get("VCAP_SERVICES"): vcap_env = json.loads(sqlite_env["VCAP_SERVICES"]) vcap_env.pop("postgres", None) sqlite_env["VCAP_SERVICES"] = json.dumps(vcap_env) + sqlite_env["DATABASE_URL"] = f"sqlite:///{str(self.sqlite_file)}" + elif sqlite_env.get("COPILOT_ENVIRONMENT_NAME"): + sqlite_env["DATABASE_CREDENTIALS"] = json.dumps( + { + "engine": "sqlite", + "dbname": f"{str(self.sqlite_file)}", + }, + ) + else: + sqlite_env["DATABASE_URL"] = f"sqlite:///{str(self.sqlite_file)}" - run( - [sys.executable, "manage.py", *args], - cwd=settings.BASE_DIR, - capture_output=False, + sqlite_env["PATH"] = exec_dir + ":" + sqlite_env["PATH"] + manage_cmd = os.path.join(exec_dir, "manage.py") + + subprocess.run( + [sys.executable, manage_cmd, *manage_args], + cwd=exec_dir, + check=True, env=sqlite_env, ) + +class Runner: + """Runs commands on an SQLite database.""" + + database: apsw.Connection + + def __init__(self, database: apsw.Connection) -> None: + self.database = database + @classmethod def make_tamato_database(cls, sqlite_file: Path) -> "Runner": """Generate a new and empty SQLite database with the TaMaTo schema derived from Tamato's models - by performing 'makemigrations' followed by 'migrate' on the Sqlite file located at `sqlite_file`.""" - try: - # Because SQLite uses different fields to PostgreSQL, missing - # migrations are first generated to bring in the different style of - # validity fields. However, these should not be applied to Postgres - # and so should be removed (in the `finally` block) after they have - # been applied (when running `migrate`). 
-            cls.manage(sqlite_file, "makemigrations", "--name", "sqlite_export")
-            cls.manage(sqlite_file, "migrate")
-            assert sqlite_file.exists()
-            return cls(apsw.Connection(str(sqlite_file)))
-        finally:
-            for file in Path(settings.BASE_DIR).rglob(
-                "**/migrations/*sqlite_export.py",
-            ):
-                file.unlink()
+
+        sqlite_migrator = SQLiteMigrator(
+            sqlite_file=sqlite_file,
+            migrations_in_tmp_dir=settings.SQLITE_MIGRATIONS_IN_TMP_DIR,
+        )
+        sqlite_migrator.migrate()
+
+        assert sqlite_file.exists()
+        return cls(apsw.Connection(str(sqlite_file)))
 
     def read_schema(self, type: str) -> Iterator[Tuple[str, str]]:
         """
diff --git a/exporter/storages.py b/exporter/storages.py
index a6a86ac7d..6b2c65d0a 100644
--- a/exporter/storages.py
+++ b/exporter/storages.py
@@ -1,4 +1,5 @@
 import logging
+import sqlite3
 from functools import cached_property
 from os import path
 from pathlib import Path
@@ -15,6 +16,40 @@
 logger = logging.getLogger(__name__)
 
 
+class EmptyFileException(Exception):
+    pass
+
+
+def is_valid_sqlite(file_path: str) -> bool:
+    """
+    `file_path` should be a path to a file on the local file system. Validation
+    includes:
+    - test that a file exists at `file_path`,
+    - test that the file at `file_path` has non-zero size,
+    - perform a SQLite PRAGMA quick_check on file at `file_path`.
+
+    If errors are found during validation, then exceptions that this function
+    may raise include:
+    - sqlite3.DatabaseError if the PRAGMA quick_check fails.
+    - FileNotFoundError if no file was found at `file_path`.
+    - exporter.storages.EmptyFileException if the file at `file_path` has
+      zero size.
+
+    Returns True if validation checks all pass.
+    """
+
+    if path.getsize(file_path) == 0:
+        raise EmptyFileException(f"{file_path} has zero size.")
+
+    with sqlite3.connect(file_path) as connection:
+        cursor = connection.cursor()
+        # Executing "PRAGMA quick_check" raises DatabaseError if the SQLite
+        # database file is invalid.
+        cursor.execute("PRAGMA quick_check")
+        return True
+
+
 class HMRCStorage(S3Boto3Storage):
     def get_default_settings(self):
         # Importing settings here makes it possible for tests to override_settings
@@ -113,7 +148,9 @@ def export_database(self, filename: str):
         sqlite.make_export(connection)
         connection.close()
         logger.info(f"Saving {filename} to S3 storage.")
-        self.save(filename, temp_sqlite_db.file)
+        if is_valid_sqlite(temp_sqlite_db.name):
+            # Only save to S3 if the SQLite file is valid.
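+            # Note that is_valid_sqlite() raises (FileNotFoundError,
+            # EmptyFileException or sqlite3.DatabaseError) rather than
+            # returning False, so an invalid file aborts the export loudly
+            # instead of being skipped silently.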
+            self.save(filename, temp_sqlite_db.file)
 
 
 class SQLiteLocalStorage(SQLiteExportMixin, Storage):
diff --git a/exporter/tests/test_files/empty_sqlite.db b/exporter/tests/test_files/empty_sqlite.db
new file mode 100644
index 000000000..e69de29bb
diff --git a/exporter/tests/test_files/invalid_sqlite.db b/exporter/tests/test_files/invalid_sqlite.db
new file mode 100644
index 000000000..abe95ec8e
--- /dev/null
+++ b/exporter/tests/test_files/invalid_sqlite.db
@@ -0,0 +1 @@
+invalid sqlite file content
\ No newline at end of file
diff --git a/exporter/tests/test_files/valid_sqlite.db b/exporter/tests/test_files/valid_sqlite.db
new file mode 100644
index 000000000..8c92662d4
Binary files /dev/null and b/exporter/tests/test_files/valid_sqlite.db differ
diff --git a/exporter/tests/test_sqlite.py b/exporter/tests/test_sqlite.py
index c75d8f839..bb4f7cdf9 100644
--- a/exporter/tests/test_sqlite.py
+++ b/exporter/tests/test_sqlite.py
@@ -1,4 +1,6 @@
+import sqlite3
 import tempfile
+from contextlib import nullcontext
 from io import BytesIO
 from os import path
 from pathlib import Path
@@ -13,6 +15,9 @@
 from exporter.sqlite import plan
 from exporter.sqlite import tasks
 from exporter.sqlite.runner import Runner
+from exporter.sqlite.runner import SQLiteMigrator
+from exporter.storages import EmptyFileException
+from exporter.storages import is_valid_sqlite
 from workbaskets.validators import WorkflowStatus
 
 pytestmark = pytest.mark.django_db
@@ -42,6 +47,58 @@ def sqlite_database(sqlite_template: Runner) -> Iterator[Runner]:
     yield Runner(in_memory_database)
 
 
+def get_test_file_path(filename):
+    return path.join(
+        path.dirname(path.abspath(__file__)),
+        "test_files",
+        filename,
+    )
+
+
+@pytest.mark.parametrize(
+    ("test_file_path, expect_context"),
+    (
+        (
+            get_test_file_path("valid_sqlite.db"),
+            nullcontext(),
+        ),
+        (
+            "/invalid/file/path",
+            pytest.raises(FileNotFoundError),
+        ),
+        (
+            get_test_file_path("empty_sqlite.db"),
+            pytest.raises(EmptyFileException),
+        ),
+        (
+            get_test_file_path("invalid_sqlite.db"),
+            pytest.raises(sqlite3.DatabaseError),
+        ),
+    ),
+)
+def test_is_valid_sqlite(test_file_path, expect_context):
+    """Test that `is_valid_sqlite()` raises correct exceptions for invalid
+    SQLite files and succeeds for valid SQLite files."""
+    with expect_context:
+        is_valid_sqlite(test_file_path)
+
+
+@pytest.mark.parametrize(
+    ("migrations_in_tmp_dir"),
+    (False, True),
+)
+def test_sqlite_migrator(migrations_in_tmp_dir):
+    """Test SQLiteMigrator."""
+    with tempfile.NamedTemporaryFile() as sqlite_file:
+        sqlite_migrator = SQLiteMigrator(
+            sqlite_file=Path(sqlite_file.name),
+            migrations_in_tmp_dir=migrations_in_tmp_dir,
+        )
+        sqlite_migrator.migrate()
+
+        assert is_valid_sqlite(sqlite_file.name)
+
+
 FACTORIES_EXPORTED = [
     factory
     for factory in factories.TrackedModelMixin.__subclasses__()
diff --git a/hmrc_sdes/tests/test_client.py b/hmrc_sdes/tests/test_client.py
index ebd660b11..9533729dc 100644
--- a/hmrc_sdes/tests/test_client.py
+++ b/hmrc_sdes/tests/test_client.py
@@ -63,7 +63,9 @@ def test_api_call(responses, settings):
     responses.add_passthru(settings.HMRC["base_url"])
 
     # reload settings from env, overriding test settings
-    dotenv.read_dotenv(os.path.join(settings.BASE_DIR, ".env"))
+    import dotenv
+
+    dotenv.load_dotenv(dotenv_path=os.path.join(settings.BASE_DIR, ".env"))
     settings.HMRC["client_id"] = os.environ.get("HMRC_API_CLIENT_ID")
     settings.HMRC["client_secret"] = os.environ.get("HMRC_API_CLIENT_SECRET")
     settings.HMRC["service_reference_number"] = os.environ.get(
diff --git
a/manage.py b/manage.py index 6122f2dd9..d101448f0 100755 --- a/manage.py +++ b/manage.py @@ -6,14 +6,52 @@ import dotenv +ENV_INFO_FLAG = "--env-info" + + +def output_env_info(): + """Inspect and output environment diagnostics for help with platform / + environment debugging.""" + + import pwd + from pathlib import Path + + cwd = Path().resolve() + script_path = Path(__file__).resolve() + executable_path = Path(sys.executable).resolve() + path = os.environ.get("PATH") + username = pwd.getpwuid(os.getuid()).pw_name + + print("Environment diagnostics") + print("----") + print(f" Current working directory: {cwd}") + print(f" Current script path: {script_path}") + print(f" Python executable path: {executable_path}") + print(f" PATH: {path}") + print(f" username: {username}") + print("----") + + # Remove the flag to avoid Django unknown command errors. + sys.argv = [arg for arg in sys.argv if arg != ENV_INFO_FLAG] + + +def set_django_settings_module(): + """Set the DJANGO_SETTINGS_MODULE env var with an appropriate value.""" -def main(): in_test = not {"pytest", "test"}.isdisjoint(sys.argv[1:]) in_dev = in_test is False and "DEV" == str(os.environ.get("ENV")).upper() os.environ.setdefault( "DJANGO_SETTINGS_MODULE", "settings.test" if in_test else "settings.dev" if in_dev else "settings", ) + + +def main(): + if ENV_INFO_FLAG in sys.argv: + output_env_info() + + set_django_settings_module() + try: from django.core.management import execute_from_command_line except ImportError as exc: @@ -28,5 +66,5 @@ def main(): if __name__ == "__main__": with warnings.catch_warnings(): warnings.simplefilter("ignore") - dotenv.read_dotenv() + dotenv.load_dotenv() main() diff --git a/publishing/models/crown_dependencies_envelope.py b/publishing/models/crown_dependencies_envelope.py index 043acfcd0..cf85ea584 100644 --- a/publishing/models/crown_dependencies_envelope.py +++ b/publishing/models/crown_dependencies_envelope.py @@ -144,7 +144,10 @@ def publishing_succeeded(self): @save_after @transition( field=publishing_state, - source=ApiPublishingState.CURRENTLY_PUBLISHING, + source=[ + ApiPublishingState.CURRENTLY_PUBLISHING, + ApiPublishingState.FAILED_PUBLISHING, + ], target=ApiPublishingState.FAILED_PUBLISHING, custom={"label": "Publishing failed"}, ) diff --git a/publishing/tests/test_model_crown_dependencies_envelope.py b/publishing/tests/test_model_crown_dependencies_envelope.py index b2ba46820..928998448 100644 --- a/publishing/tests/test_model_crown_dependencies_envelope.py +++ b/publishing/tests/test_model_crown_dependencies_envelope.py @@ -1,5 +1,6 @@ import pytest +from common.tests.factories import CrownDependenciesEnvelopeFactory from notifications.models import Notification from publishing.models import ApiPublishingState from publishing.models import CrownDependenciesEnvelope @@ -70,3 +71,41 @@ def test_notify_processing_failed( Notification.objects.last() mocked_send_emails_apply_async.assert_called() + + +@pytest.mark.parametrize( + "transition_method, source_state, target_state", + [ + ( + "publishing_succeeded", + ApiPublishingState.CURRENTLY_PUBLISHING, + ApiPublishingState.SUCCESSFULLY_PUBLISHED, + ), + ( + "publishing_failed", + ApiPublishingState.CURRENTLY_PUBLISHING, + ApiPublishingState.FAILED_PUBLISHING, + ), + ( + "publishing_failed", + ApiPublishingState.FAILED_PUBLISHING, + ApiPublishingState.FAILED_PUBLISHING, + ), + ], +) +def test_crown_dependencies_envelope_transition_methods( + transition_method, + source_state, + target_state, + settings, +): + """Tests that 
`CrownDependenciesEnvelope` transition methods move + `publishing_state` from source_state to target_state.""" + + settings.ENABLE_PACKAGING_NOTIFICATIONS = False + + envelope = CrownDependenciesEnvelopeFactory.create( + publishing_state=source_state, + ) + getattr(envelope, transition_method)() + assert envelope.publishing_state == target_state diff --git a/pyproject.toml b/pyproject.toml index 127fb30e6..fbd3393b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,6 @@ readme = "README.md" dependencies = [ "dj-database-url", "django", - "django-dotenv", "django-extra-fields", "django-filter", "django-fsm", @@ -28,6 +27,7 @@ dependencies = [ "gunicorn", "jinja2", "psycopg[binary]", + "python-dotenv", "sentry-sdk", "werkzeug", "whitenoise", diff --git a/quotas/business_rules.py b/quotas/business_rules.py index f26fc27ff..0092ecd5e 100644 --- a/quotas/business_rules.py +++ b/quotas/business_rules.py @@ -379,6 +379,18 @@ class QA2(ValidityPeriodContained): contained_field_name = "sub_quota" +def check_QA2_dict(sub_definition_valid_between, main_definition_valid_between): + """Confirms data is compliant with QA2.""" + if main_definition_valid_between.upper: + return ( + sub_definition_valid_between.lower >= main_definition_valid_between.lower + ) and ( + sub_definition_valid_between.upper <= main_definition_valid_between.upper + ) + else: + return sub_definition_valid_between.lower >= main_definition_valid_between.lower + + class QA3(BusinessRule): """ When converted to the measurement unit of the main quota, the volume of a @@ -395,14 +407,34 @@ class QA3(BusinessRule): def validate(self, association): main = association.main_quota sub = association.sub_quota - if not ( - sub.measurement_unit == main.measurement_unit - and sub.volume <= main.volume - and sub.initial_volume <= main.initial_volume + if not check_QA3_dict( + main_definition_unit=main.measurement_unit, + sub_definition_unit=sub.measurement_unit, + main_definition_volume=main.volume, + sub_definition_volume=sub.volume, + main_initial_volume=main.initial_volume, + sub_initial_volume=sub.initial_volume, ): raise self.violation(association) +def check_QA3_dict( + main_definition_unit, + sub_definition_unit, + main_definition_volume, + sub_definition_volume, + sub_initial_volume, + main_initial_volume, +): + """Confirms data is compliant with QA3 See note above about changing the + unit types.""" + return ( + main_definition_unit == sub_definition_unit + and sub_definition_volume <= main_definition_volume + and sub_initial_volume <= main_initial_volume + ) + + class QA4(BusinessRule): """ Whenever a sub-quota receives a coefficient, this has to be a strictly @@ -412,10 +444,14 @@ class QA4(BusinessRule): """ def validate(self, association): - if not association.coefficient > 0: + if not check_QA4_dict(association.coefficient): raise self.violation(association) +def check_QA4_dict(coefficient): + return coefficient > 0 + + class QA5(BusinessRule): """ Whenever a sub-quota is defined with the 'equivalent' type, it must have the @@ -427,7 +463,7 @@ class QA5(BusinessRule): def validate(self, association): if association.sub_quota_relation_type == SubQuotaType.EQUIVALENT: - if association.coefficient == Decimal("1.00000"): + if not check_QA5_equivalent_coefficient(association.coefficient): raise self.violation( model=association, message=( @@ -435,14 +471,7 @@ def validate(self, association): "coefficient not equal to 1" ), ) - - if ( - association.main_quota.sub_quotas.values("volume") - .order_by("volume") - 
.distinct("volume") - .count() - > 1 - ): + if not check_QA5_equivalent_volumes(association.main_quota): raise self.violation( model=association, message=( @@ -452,17 +481,33 @@ def validate(self, association): ), ) - elif ( - association.sub_quota_relation_type == SubQuotaType.NORMAL - and association.coefficient != Decimal("1.00000") - ): - raise self.violation( - model=association, - message=( - "A sub-quota defined with the 'normal' type must have a coefficient " - "equal to 1" - ), - ) + elif association.sub_quota_relation_type == SubQuotaType.NORMAL: + if not check_QA5_normal_coefficient(association.coefficient): + raise self.violation( + model=association, + message=( + "A sub-quota defined with the 'normal' type must have a coefficient " + "equal to 1" + ), + ) + + +def check_QA5_equivalent_coefficient(coefficient): + return coefficient != Decimal("1.000") + + +def check_QA5_equivalent_volumes(original_definition, volume=None): + return ( + original_definition.sub_quotas.values("volume") + .order_by("volume") + .distinct("volume") + .count() + <= 1 + ) + + +def check_QA5_normal_coefficient(coefficient): + return Decimal(coefficient) == Decimal("1.000") class QA6(BusinessRule): @@ -470,21 +515,37 @@ class QA6(BusinessRule): relation type.""" def validate(self, association): - if ( - association.main_quota.sub_quota_associations.approved_up_to_transaction( - association.transaction, - ) - .values( - "sub_quota_relation_type", - ) - .order_by("sub_quota_relation_type") - .distinct() - .count() - > 1 + if not check_QA6_dict( + association.main_quota, + association.sub_quota_relation_type, + association.transaction, ): raise self.violation(association) +def check_QA6_dict(main_quota, new_relation_type, transaction=None): + """ + Confirms the provided data is compliant with the above businsess rule. + + The above test will be re-run so as to separate historic violations, which + will require TAP to fix, from a user trying to introduce a new violation. 
+    Because the business rule should have been checked and there should only be
+    one type, we can check the new type against any one of the old types.
+    """
+    relation_type = (
+        main_quota.sub_quota_associations.approved_up_to_transaction(
+            transaction,
+        )
+        .values("sub_quota_relation_type")
+        .order_by("sub_quota_relation_type")
+        .distinct()
+    )
+    if relation_type.count() > 1:
+        return False
+    elif relation_type.count() == 1:
+        return relation_type[0]["sub_quota_relation_type"] == new_relation_type
+    # No pre-existing associations: nothing for the new type to conflict with.
+    return True
 
 
 class SameMainAndSubQuota(BusinessRule):
     """A quota association may only exist between two distinct quota
     definitions."""
diff --git a/quotas/forms.py b/quotas/forms.py
index 7cdd0e062..74295f484 100644
--- a/quotas/forms.py
+++ b/quotas/forms.py
@@ -15,6 +15,7 @@
 from django.template.loader import render_to_string
 from django.urls import reverse_lazy
 
+from common.fields import AutoCompleteField
 from common.forms import BindNestedFormMixin
 from common.forms import FormSet
 from common.forms import FormSetField
@@ -25,17 +26,23 @@
 from common.forms import delete_form_for
 from common.forms import formset_factory
 from common.forms import unprefix_formset_data
+from common.serializers import deserialize_date
 from common.util import validity_range_contains_range
 from common.validators import SymbolValidator
 from common.validators import UpdateType
 from geo_areas.models import GeographicalArea
 from measures.models import MeasurementUnit
+from quotas import business_rules
 from quotas import models
 from quotas import validators
 from quotas.constants import QUOTA_EXCLUSIONS_FORMSET_PREFIX
 from quotas.constants import QUOTA_ORIGIN_EXCLUSIONS_FORMSET_PREFIX
 from quotas.constants import QUOTA_ORIGINS_FORMSET_PREFIX
+from quotas.serializers import serialize_duplicate_data
+from workbaskets.forms import SelectableObjectsForm
 
+RELATIONSHIP_TYPE_HELP_TEXT = "Select the relationship type for the quota association"
+COEFFICIENT_HELP_TEXT = "Select the coefficient for the quota association"
 CATEGORY_HELP_TEXT = "Categories are required for the TAP database but will not appear as a TARIC3 object in your workbasket"
 SAFEGUARD_HELP_TEXT = (
     "Once the quota category has been set as ‘Safeguard’, this cannot be changed"
@@ -66,38 +73,6 @@ def __init__(self, *args, **kwargs):
 QuotaDeleteForm = delete_form_for(models.QuotaOrderNumber)
 
 
-class QuotaDefinitionFilterForm(forms.Form):
-    quota_type = forms.MultipleChoiceField(
-        label="View by",
-        choices=[
-            ("sub_quotas", "Sub-quotas"),
-            ("blocking_periods", "Blocking periods"),
-            ("suspension_periods", "Suspension periods"),
-        ],
-        widget=forms.RadioSelect(),
-    )
-
-    def __init__(self, *args, **kwargs):
-        quota_type_initial = kwargs.pop("form_initial")
-        object_sid = kwargs.pop("object_sid")
-        super().__init__(*args, **kwargs)
-        self.fields["quota_type"].initial = quota_type_initial
-        self.helper = FormHelper()
-
-        clear_url = reverse_lazy(
-            "quota_definition-ui-list",
-            kwargs={"sid": object_sid},
-        )
-
-        self.helper.layout = Layout(
-            Field.radios("quota_type", label_size=Size.SMALL),
-            Button("submit", "Apply"),
-            HTML(
-                f'<a class="govuk-link" href="{clear_url}">Restore defaults</a>',
-            ),
-        )
-
-
 class QuotaOriginExclusionsForm(forms.Form):
     exclusion = forms.ModelChoiceField(
         label="",
@@ -1005,3 +980,345 @@ def create_blocking_period(self, workbasket):
         update_type=UpdateType.CREATE,
         transaction=workbasket.new_transaction(),
     )
+
+
+class DuplicateQuotaDefinitionPeriodStartForm(forms.Form):
+    pass
+
+
+class QuotaOrderNumbersSelectForm(forms.Form):
+    main_quota_order_number = AutoCompleteField(
label="Main quota order number", + queryset=models.QuotaOrderNumber.objects.all(), + required=True, + ) + sub_quota_order_number = AutoCompleteField( + label="Sub-quota order number", + queryset=models.QuotaOrderNumber.objects.all(), + required=True, + ) + + def __init__(self, *args, **kwargs): + self.request = kwargs.pop("request", None) + super().__init__(*args, **kwargs) + self.init_layout(self.request) + + def init_layout(self, request): + self.helper = FormHelper(self) + self.helper.label_size = Size.SMALL + self.helper.legend_size = Size.SMALL + + self.helper.layout = Layout( + Div( + HTML( + '

<h2 class="govuk-heading">Enter main and sub-quota order numbers</h2>
', + ), + ), + Div( + "main_quota_order_number", + Div( + "sub_quota_order_number", + css_class="govuk-inset-text", + ), + ), + Submit( + "submit", + "Continue", + data_module="govuk-button", + data_prevent_double_click="true", + ), + ) + + +class SelectSubQuotaDefinitionsForm( + SelectableObjectsForm, +): + """Form to select the main quota definitions that are to be duplicated.""" + + def __init__(self, *args, **kwargs): + self.request = kwargs.pop("request", None) + super().__init__(*args, **kwargs) + + def set_staged_definition_data(self, selected_definitions): + if ( + self.prefix in ["select_definition_periods"] + and self.request.path != "/quotas/duplicate_quota_definitions/complete" + ): + staged_definition_data = [] + for definition in selected_definitions: + staged_definition_data.append( + { + "main_definition": definition.pk, + "sub_definition_staged_data": serialize_duplicate_data( + definition, + ), + }, + ) + self.request.session["staged_definition_data"] = staged_definition_data + + def clean(self): + cleaned_data = super().clean() + selected_definitions = { + key: value for key, value in cleaned_data.items() if value + } + definitions_pks = [ + self.object_id_from_field_name(key) for key in selected_definitions + ] + if len(selected_definitions) < 1: + raise ValidationError("At least one quota definition must be selected") + selected_definitions = models.QuotaDefinition.objects.filter( + pk__in=definitions_pks, + ).current() + cleaned_data["selected_definitions"] = selected_definitions + self.set_staged_definition_data(selected_definitions) + return cleaned_data + + +class SelectedDefinitionsForm(forms.Form): + def __init__(self, *args, **kwargs): + self.request = kwargs.pop("request") + super().__init__(*args, **kwargs) + + def clean(self): + cleaned_data = super().clean() + cleaned_data["staged_definitions"] = self.request.session[ + "staged_definition_data" + ] + for definition in cleaned_data["staged_definitions"]: + if not definition["sub_definition_staged_data"]["status"]: + raise ValidationError( + "Each definition period must have a specified relationship and co-efficient value", + ) + return cleaned_data + + +class SubQuotaDefinitionsUpdatesForm( + ValidityPeriodForm, +): + class Meta: + model = models.QuotaDefinition + fields = [ + "coefficient", + "relationship_type", + "volume", + "initial_volume", + "measurement_unit", + "valid_between", + ] + + relationship_type = forms.ChoiceField( + choices=[ + ("EQ", "Equivalent"), + ("NM", "Normal"), + ], + help_text=RELATIONSHIP_TYPE_HELP_TEXT, + error_messages={ + "required": "Choose the category", + }, + ) + + coefficient = forms.DecimalField( + label="Coefficient", + widget=forms.TextInput(), + help_text=COEFFICIENT_HELP_TEXT, + error_messages={ + "invalid": "Coefficient must be a number", + "required": "Enter the coefficient", + }, + ) + + initial_volume = forms.DecimalField( + label="Initial volume", + widget=forms.TextInput(), + help_text="The initial volume is the legal balance applied to the definition period.", + error_messages={ + "invalid": "Initial volume must be a number", + "required": "Enter the initial volume", + }, + ) + volume = forms.DecimalField( + label="Current volume", + widget=forms.TextInput(), + help_text="The current volume is the starting balance for the quota.", + error_messages={ + "invalid": "Volume must be a number", + "required": "Enter the volume", + }, + ) + + measurement_unit = forms.ModelChoiceField( + queryset=MeasurementUnit.objects.current().order_by("code"), + 
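+        # init_fields() below renders each unit's label as "<code> - <description>".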
error_messages={"required": "Select the measurement unit"}, + ) + + def get_duplicate_data(self, original_definition): + staged_definition_data = self.request.session["staged_definition_data"] + duplicate_data = list( + filter( + lambda staged_definition_data: staged_definition_data["main_definition"] + == original_definition.pk, + staged_definition_data, + ), + )[0]["sub_definition_staged_data"] + self.set_initial_data(duplicate_data) + return duplicate_data + + def set_initial_data(self, duplicate_data): + fields = self.fields + fields["relationship_type"].initial = "NM" + fields["coefficient"].initial = 1 + fields["measurement_unit"].initial = MeasurementUnit.objects.get( + code=duplicate_data["measurement_unit_code"], + ) + fields["initial_volume"].initial = duplicate_data["initial_volume"] + fields["volume"].initial = duplicate_data["volume"] + fields["start_date"].initial = deserialize_date(duplicate_data["start_date"]) + fields["end_date"].initial = deserialize_date(duplicate_data["end_date"]) + + def init_fields(self): + self.fields["measurement_unit"].label_from_instance = ( + lambda obj: f"{obj.code} - {obj.description}" + ) + + def __init__(self, *args, **kwargs): + self.request = kwargs.pop("request", None) + main_def_id = kwargs.pop("pk") + super().__init__(*args, **kwargs) + self.original_definition = models.QuotaDefinition.objects.get( + trackedmodel_ptr_id=main_def_id, + ) + self.init_fields() + self.get_duplicate_data(self.original_definition) + self.init_layout(self.request) + + def clean(self): + cleaned_data = super().clean() + """ + Carrying out business rule checks here to prevent erroneous + associations, see: + + https://uktrade.github.io/tariff-data-manual/documentation/data-structures/quota-associations.html#validation-rules + """ + original_definition = self.original_definition + if cleaned_data["valid_between"].upper is None: + raise ValidationError("An end date must be supplied") + + if not business_rules.check_QA2_dict( + sub_definition_valid_between=cleaned_data["valid_between"], + main_definition_valid_between=original_definition.valid_between, + ): + raise ValidationError( + "QA2: Validity period for sub quota must be within the " + "validity period of the main quota", + ) + + if not business_rules.check_QA3_dict( + main_definition_unit=self.original_definition.measurement_unit, + sub_definition_unit=cleaned_data["measurement_unit"], + main_definition_volume=original_definition.volume, + sub_definition_volume=cleaned_data["volume"], + main_initial_volume=original_definition.initial_volume, + sub_initial_volume=cleaned_data["initial_volume"], + ): + raise ValidationError( + "QA3: When converted to the measurement unit of the main " + "quota, the volume of a sub-quota must always be lower than " + "or equal to the volume of the main quota", + ) + + if not business_rules.check_QA4_dict(cleaned_data["coefficient"]): + raise ValidationError( + "QA4: A coefficient must be a positive decimal number", + ) + + if cleaned_data["relationship_type"] == "NM": + if not business_rules.check_QA5_normal_coefficient( + cleaned_data["coefficient"], + ): + raise ValidationError( + "QA5: Where the relationship type is Normal, the " + "coefficient value must be 1", + ) + elif cleaned_data["relationship_type"] == "EQ": + if not business_rules.check_QA5_equivalent_coefficient( + cleaned_data["coefficient"], + ): + raise ValidationError( + "QA5: Where the relationship type is Equivalent, the " + "coefficient value must be something other than 1", + ) + if not 
business_rules.check_QA5_equivalent_volumes(
+                self.original_definition,
+                volume=cleaned_data["volume"],
+            ):
+                raise ValidationError(
+                    "Whenever a sub-quota is defined with the 'equivalent' "
+                    "type, it must have the same volume as the ones associated"
+                    " with the parent quota",
+                )
+
+        if not business_rules.check_QA6_dict(
+            main_quota=original_definition,
+            new_relation_type=cleaned_data["relationship_type"],
+        ):
+            raise ValidationError(
+                "QA6: Sub-quotas associated with the same main quota must "
+                "have the same relation type.",
+            )
+
+        return cleaned_data
+
+    def init_layout(self, request):
+        self.helper = FormHelper(self)
+        self.helper.label_size = Size.SMALL
+        self.helper.legend_size = Size.SMALL
+
+        self.helper.layout = Layout(
+            Div(
+                HTML(
+                    '

<h2 class="govuk-heading">Quota association details</h2>
', + ), + Div( + Div("relationship_type", css_class="govuk-grid-column-one-half"), + Div("coefficient", css_class="govuk-grid-column-one-half"), + css_class="govuk-grid-row", + ), + ), + HTML( + '
<hr class="govuk-section-break govuk-section-break--l govuk-section-break--visible">',
+            ),
+            Div(
+                HTML(
+                    '

<h2 class="govuk-heading">Sub-quota definition details</h2>
', + ), + Div( + Div( + "start_date", + css_class="govuk-grid-column-one-half", + ), + Div( + "end_date", + css_class="govuk-grid-column-one-half", + ), + Div( + "initial_volume", + "measurement_unit", + css_class="govuk-grid-column-one-half", + ), + Div( + "volume", + css_class="govuk-grid-column-one-half", + ), + css_class="govuk-grid-row", + ), + HTML( + '
<hr class="govuk-section-break govuk-section-break--l govuk-section-break--visible">',
+            ),
+            Submit(
+                "submit",
+                "Save and continue",
+                data_module="govuk-button",
+                data_prevent_double_click="true",
+            ),
+        ),
+    )
diff --git a/quotas/jinja2/includes/quotas/actions.jinja b/quotas/jinja2/includes/quotas/actions.jinja
index 0bb0d0612..393fe213c 100644
--- a/quotas/jinja2/includes/quotas/actions.jinja
+++ b/quotas/jinja2/includes/quotas/actions.jinja
@@ -2,9 +2,11 @@