diff --git a/controlpanel/api/cluster.py b/controlpanel/api/cluster.py index 8e6a9cbd..80260bf9 100644 --- a/controlpanel/api/cluster.py +++ b/controlpanel/api/cluster.py @@ -1007,7 +1007,9 @@ def _set_values(self, **kwargs): return set_values def install(self, **kwargs): - self._delete_legacy_release() + # TODO remove as should no longer be necessary as we uninstall the previous release before + # installing the new one + # self._delete_legacy_release() try: set_values = self._set_values(**kwargs) @@ -1026,9 +1028,12 @@ def install(self, **kwargs): except helm.HelmError as error: raise ToolDeploymentError(error) - def uninstall(self, id_token): - deployment = self.get_deployment(id_token) - helm.delete(self.k8s_namespace, deployment.metadata.name) + def uninstall(self): + try: + return helm.delete(self.k8s_namespace, self.release_name) + except helm.HelmError as error: + # TODO make this less generic + raise ToolDeploymentError(error) def restart(self, id_token): k8s = KubernetesClient(id_token=id_token) @@ -1119,6 +1124,8 @@ def get_status(self, id_token, deployment=None): if "Available" in conditions: if conditions["Available"].status == "True": + # TODO to save us having to call the KubeAPI to get deployments we could use the + # ToolDeployment created/modified timestamp to determine if the tool is idle if deployment.spec.replicas == 0: return TOOL_IDLED return TOOL_READY diff --git a/controlpanel/api/migrations/0056_tooldeployment_tool_users_deployed.py b/controlpanel/api/migrations/0056_tooldeployment_tool_users_deployed.py new file mode 100644 index 00000000..42ff2f51 --- /dev/null +++ b/controlpanel/api/migrations/0056_tooldeployment_tool_users_deployed.py @@ -0,0 +1,80 @@ +# Generated by Django 5.1.2 on 2025-01-27 12:40 + +import django.db.models.deletion +import django_extensions.db.fields +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("api", "0055_alter_user_options"), + ] + + operations = [ + migrations.CreateModel( + name="ToolDeployment", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, primary_key=True, serialize=False, verbose_name="ID" + ), + ), + ( + "created", + django_extensions.db.fields.CreationDateTimeField( + auto_now_add=True, verbose_name="created" + ), + ), + ( + "modified", + django_extensions.db.fields.ModificationDateTimeField( + auto_now=True, verbose_name="modified" + ), + ), + ( + "tool_type", + models.CharField( + choices=[ + ("jupyter", "JupyterLab"), + ("rstudio", "RStudio"), + ("vscode", "Visual Studio Code"), + ], + max_length=100, + ), + ), + ("is_active", models.BooleanField(default=False)), + ( + "tool", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="tool_deployments", + to="api.tool", + ), + ), + ( + "user", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="tool_deployments", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "db_table": "control_panel_api_tool_deployment", + "ordering": ["-created"], + }, + ), + migrations.AddField( + model_name="tool", + name="users_deployed", + field=models.ManyToManyField( + related_name="deployed_tools", + through="api.ToolDeployment", + to=settings.AUTH_USER_MODEL, + ), + ), + ] diff --git a/controlpanel/api/models/tool.py b/controlpanel/api/models/tool.py index 6eddfe88..1a446bc0 100644 --- a/controlpanel/api/models/tool.py +++ b/controlpanel/api/models/tool.py @@ -19,14 +19,6 @@ class Tool(TimeStampedModel): instance of a tool. """ - # Defines how a matching chart name is put into a named tool bucket. - # E.g. jupyter-* charts all end up in the jupyter-lab bucket. - # chart name match: tool bucket - TOOL_BOX_CHART_LOOKUP = { - "jupyter": "jupyter-lab", - "rstudio": "rstudio", - "vscode": "vscode", - } DEFAULT_DEPRECATED_MESSAGE = "The selected release has been deprecated and will be retired soon. Please update to a more recent version." # noqa JUPYTER_DATASCIENCE_CHART_NAME = "jupyter-lab-datascience-notebook" JUPYTER_ALL_SPARK_CHART_NAME = "jupyter-lab-all-spark" @@ -67,6 +59,9 @@ class Tool(TimeStampedModel): ) is_retired = models.BooleanField(default=False) image_tag = models.CharField(max_length=100) + users_deployed = models.ManyToManyField( + "User", through="ToolDeployment", related_name="deployed_tools" + ) class Meta(TimeStampedModel.Meta): db_table = "control_panel_api_tool" @@ -75,9 +70,8 @@ class Meta(TimeStampedModel.Meta): def __repr__(self): return f"" - def url(self, user): - tool = self.tool_domain or self.chart_name - return f"https://{user.slug}-{tool}.{settings.TOOLS_DOMAIN}/" + def __str__(self): + return f"[{self.chart_name} {self.image_tag}] {self.description}" def save(self, *args, **kwargs): helm.update_helm_repository(force=True) @@ -131,57 +125,79 @@ def status_colour(self): } return mapping[self.status.lower()] + @property + def tool_type(self): + return self.chart_name.split("-")[0] -class ToolDeploymentManager: - """ - Emulates a Django model manager - """ + @property + def tool_type_name(self): + mapping = { + "jupyter": "JupyterLab", + "rstudio": "RStudio", + "vscode": "Visual Studio Code", + } + return mapping[self.tool_type] + + +class ToolDeploymentQuerySet(models.QuerySet): + def active(self): + return self.filter(is_active=True) - def create(self, *args, **kwargs): - tool_deployment = ToolDeployment(*args, **kwargs) - tool_deployment.save() - return tool_deployment + def inactive(self): + return self.filter(is_active=False) -class ToolDeployment: +class ToolDeployment(TimeStampedModel): """ Represents a deployed Tool in the cluster """ - DoesNotExist = django.core.exceptions.ObjectDoesNotExist + class ToolType(models.TextChoices): + JUPYTER = "jupyter", "JupyterLab" + RSTUDIO = "rstudio", "RStudio" + VSCODE = "vscode", "Visual Studio Code" + + user = models.ForeignKey(to="User", on_delete=models.CASCADE, related_name="tool_deployments") + tool = models.ForeignKey(to="Tool", on_delete=models.CASCADE, related_name="tool_deployments") + tool_type = models.CharField(max_length=100, choices=ToolType.choices) + is_active = models.BooleanField(default=False) + Error = cluster.ToolDeploymentError - MultipleObjectsReturned = django.core.exceptions.MultipleObjectsReturned - objects = ToolDeploymentManager() + objects = ToolDeploymentQuerySet.as_manager() + + class Meta: + ordering = ["-created"] + db_table = "control_panel_api_tool_deployment" - def __init__(self, tool, user, old_chart_name=None): + def __init__(self, *args, **kwargs): + # TODO these may not be necessary but leaving for now self._subprocess = None - self.tool = tool - self.user = user - self.old_chart_name = old_chart_name + super().__init__(*args, **kwargs) def __repr__(self): return f"" - def delete(self, id_token): + def uninstall(self): """ Remove the release from the cluster """ - cluster.ToolDeployment(self.user, self.tool).uninstall(id_token) + return cluster.ToolDeployment(tool=self.tool, user=self.user).uninstall() - @property - def host(self): - return f"{self.user.slug}-{self.tool.chart_name}.{settings.TOOLS_DOMAIN}" + def delete(self, *args, **kwargs): + """ + Remove the release from the cluster + """ + self.uninstall() + super().delete(*args, **kwargs) - def save(self, *args, **kwargs): + def deploy(self): """ Deploy the tool to the cluster (asynchronous) """ - self._subprocess = cluster.ToolDeployment( - self.user, self.tool, self.old_chart_name - ).install() + self._subprocess = cluster.ToolDeployment(self.user, self.tool).install() - def get_status(self, id_token, deployment=None): + def get_status(self, id_token=None, deployment=None): """ Get the current status of the deployment. Polls the subprocess if running, otherwise returns idled status. @@ -194,9 +210,17 @@ def get_status(self, id_token, deployment=None): log.info(status) return status return cluster.ToolDeployment(self.user, self.tool).get_status( - id_token, deployment=deployment + id_token or self.user.get_id_token(), deployment=deployment ) + @property + def url(self): + tool = self.tool.tool_domain or self.tool.chart_name + url = f"https://{self.user.slug}-{tool}.{settings.TOOLS_DOMAIN}/" + if self.tool_type == self.ToolType.VSCODE: + url = f"{url}?folder=/home/analyticalplatform/workspace" + return url + def _poll(self): """ Poll the deployment subprocess for status @@ -212,10 +236,6 @@ def _poll(self): log.info(self._subprocess.stdout.read().strip()) self._subprocess = None - @property - def url(self): - return f"https://{self.host}/" - def restart(self, id_token): """ Restart the tool deployment diff --git a/controlpanel/api/serializers.py b/controlpanel/api/serializers.py index 4f101302..7d82ad1d 100644 --- a/controlpanel/api/serializers.py +++ b/controlpanel/api/serializers.py @@ -16,10 +16,12 @@ AppS3Bucket, IPAllowlist, S3Bucket, + ToolDeployment, User, UserApp, UserS3Bucket, ) +from controlpanel.utils import start_background_task class AppS3BucketSerializer(serializers.ModelSerializer): @@ -337,17 +339,42 @@ class DeleteAppCustomerSerializer(serializers.Serializer): env_name = serializers.CharField(max_length=64, required=True) -class ToolDeploymentSerializer(serializers.Serializer): - old_chart_name = serializers.CharField(max_length=64, required=False) - version = serializers.CharField(max_length=64, required=True) +class ToolDeploymentSerializer(serializers.ModelSerializer): + class Meta: + model = ToolDeployment + fields = ("tool",) - def validate_version(self, value): - try: - _, _, _ = value.split("__") - except ValueError: - raise serializers.ValidationError( - "This field include chart name, version and tool.id," ' they are joined by "__".' - ) + def __init__(self, *args, **kwargs): + self.request = kwargs.pop("request") + super().__init__(*args, **kwargs) + + def create(self, validated_data): + tool = validated_data["tool"] + # get the currently active deployment + previous_deployment = ToolDeployment.objects.filter( + user=self.request.user, tool_type=tool.tool_type, is_active=True + ).first() + # mark all previous deployments for this tool type as inactive + ToolDeployment.objects.filter(user=self.request.user, tool_type=tool.tool_type).update( + is_active=False + ) + # create the new active deployment record + new_deployment = ToolDeployment.objects.create( + tool=tool, + tool_type=tool.tool_type, + user=self.request.user, + is_active=True, + ) + # use these details to start a background process to uninstall the deploy the new tool + # TODO we may want to refactor this to be handled by celery + start_background_task( + "tool.deploy", + { + "new_deployment_id": new_deployment.id, + "previous_deployment_id": previous_deployment.id if previous_deployment else None, + }, + ) + return new_deployment class ESBucketHitsSerializer(serializers.BaseSerializer): diff --git a/controlpanel/api/views/tool_deployments.py b/controlpanel/api/views/tool_deployments.py index 6d81e198..9567fb27 100644 --- a/controlpanel/api/views/tool_deployments.py +++ b/controlpanel/api/views/tool_deployments.py @@ -6,7 +6,6 @@ # First-party/Local from controlpanel.api import serializers -from controlpanel.utils import start_background_task class ToolDeploymentAPIView(GenericAPIView): @@ -15,49 +14,13 @@ class ToolDeploymentAPIView(GenericAPIView): serializer_class = serializers.ToolDeploymentSerializer permission_classes = (IsAuthenticated,) - def _deploy(self, chart_name, data): - """ - This is the most backwards thing you'll see for a while. The helm - task to deploy the tool apparently must happen when the view class - attempts to redirect to the target url. I'm sure there's a good - reason why. - """ - # If there's already a tool deployed, we need to get this from a - # hidden field posted back in the form. This is used by helm to delete - # the currently installed chart for the tool before installing the - # new chart. - old_chart_name = data.get("deployed_chart_name", None) - # The selected option from the "version" select control contains the - # data we need. - chart_info = data.get("version") - # The tool name and version are stored in the selected option's value - # attribute and then split on "__" to extract them. Why? Because we - # need both pieces of information to kick off the background helm - # deploy. - tool_name, tool_version, tool_id = chart_info.split("__") - - # Kick off the helm chart as a background task. - start_background_task( - "tool.deploy", - { - "tool_name": chart_name, - "version": tool_version, - "tool_id": tool_id, - "user_id": self.request.user.id, - "id_token": self.request.user.get_id_token(), - "old_chart_name": old_chart_name, - }, - ) - def post(self, request, *args, **kwargs): - serializer = self.get_serializer(data=request.data) - serializer.is_valid(raise_exception=True) - - chart_name = self.kwargs["tool_name"] - tool_action = self.kwargs["action"] - tool_action_function = getattr(self, f"_{tool_action}", None) - if tool_action_function and callable(tool_action_function): - tool_action_function(chart_name, request.data) - return Response(status=status.HTTP_200_OK) - else: + # TODO this is kept for legacy reasons, where the action is passed as a URL parameter. We + # may want to remove to either pass the action in the POST data, or remove the action + # entirely as currently it is only used for deploying a tool anyway. + if self.kwargs["action"] != "deploy": return Response(status=status.HTTP_400_BAD_REQUEST) + self.serializer = self.get_serializer(data=request.data, request=request) + self.serializer.is_valid(raise_exception=True) + self.serializer.save() + return Response(status=status.HTTP_200_OK) diff --git a/controlpanel/cli/management/commands/create_tool_deployments.py b/controlpanel/cli/management/commands/create_tool_deployments.py new file mode 100644 index 00000000..16d6bbda --- /dev/null +++ b/controlpanel/cli/management/commands/create_tool_deployments.py @@ -0,0 +1,103 @@ +# Third-party +from django.core.management.base import BaseCommand +from django.db.models import Q + +# First-party/Local +from controlpanel.api.kubernetes import KubernetesClient +from controlpanel.api.models import Tool, ToolDeployment, User + + +class Command(BaseCommand): + + MEMORY_DEFAULT = "12Gi" + CPU_DEFAULT = "1" + + def handle(self, *args, **options): + ToolDeployment.objects.all().delete() + client = KubernetesClient(use_cpanel_creds=True) + deployments = client.AppsV1Api.list_deployment_for_all_namespaces() + for deployment in deployments.items: + name = deployment.metadata.name + # we only care about tool deployments + if not name.startswith(("vscode", "jupyter", "rstudio")): + continue + + # we don't care about the scheduler + if "scheduler" in name: + continue + + username = deployment.metadata.namespace.strip("user-") + try: + user = User.objects.get(username=username) + except User.DoesNotExist: + self.stderr.write( + f"ERROR: failed to find user with username {username}, skipping\n---------------" # noqa + ) + continue + # get the chart name, version, and use that to get the tool type + chart_name, chart_version = deployment.metadata.labels.get("chart", "").rsplit("-", 1) + tool_type = chart_name.split("-")[0] + + tool_container = None + for container in deployment.spec.template.spec.containers: + if "auth" not in container.name: + tool_container = container + + # get specific details about the deployed tool + image_tag = tool_container.image.split(":")[-1] + requests_memory = tool_container.resources.requests["memory"] + requests_cpu = tool_container.resources.requests["cpu"] + limits_memory = tool_container.resources.limits["memory"] + gpu = deployment.spec.template.spec.node_selector == {"gpu-compute": "true"} + + self.stdout.write(f"Username: {username}") + self.stdout.write(f"Name: {name}") + self.stdout.write(f"Tool Type: {tool_type}") + self.stdout.write(f"Chart Name: {chart_name}") + self.stdout.write(f"Chart Version: {chart_version}") + self.stdout.write(f"Image tag: {image_tag}") + self.stdout.write(f"Requests Memory: {requests_memory}") + self.stdout.write(f"Requests CPU: {requests_cpu}") + self.stdout.write(f"Limits Memory: {limits_memory}") + self.stdout.write(f"GPU: {gpu}") + + tool_queryset = Tool.objects.filter( + Q(is_restricted=False) | Q(target_users=user), is_retired=False + ) + tool_queryset = tool_queryset.filter( + image_tag=image_tag, + version=chart_version, + chart_name=chart_name, + ) + + # build up values to filter for specific releases if they dont match the defaults + values = {} + if requests_memory != self.MEMORY_DEFAULT: + values[f"{tool_type}.resources.requests.memory"] = requests_memory + if requests_cpu != self.CPU_DEFAULT: + values[f"{tool_type}.resources.requests.cpu"] = requests_cpu + if limits_memory != self.MEMORY_DEFAULT: + values[f"{tool_type}.resources.limits.memory"] = limits_memory + + if values: + tool_queryset = tool_queryset.filter(values__contains=values) + + # filter or exclude GPU releases + if gpu: + tool_queryset = tool_queryset.filter(values__contains={"gpu.enabled": "true"}) + else: + tool_queryset = tool_queryset.exclude(values__contains={"gpu.enabled": "true"}) + + # we should be down to a single release at this point but just in case use first + tool = tool_queryset.first() + if not tool: + self.stderr.write("ERROR failed to find tool for these details\n") + continue + + # if we have a tool, create a ToolDeployment record + tool_deployment = ToolDeployment.objects.create( + tool=tool, user=user, tool_type=tool_type, is_active=True + ) + self.stdout.write( + f"Created tool deployment for {username} with tool {tool_deployment.tool}\n---------------" # noqa + ) diff --git a/controlpanel/frontend/consumers.py b/controlpanel/frontend/consumers.py index b51fdbdb..ac223399 100644 --- a/controlpanel/frontend/consumers.py +++ b/controlpanel/frontend/consumers.py @@ -19,6 +19,7 @@ HOME_RESETTING, TOOL_DEPLOY_FAILED, TOOL_DEPLOYING, + TOOL_READY, TOOL_RESTARTING, ) from controlpanel.api.models import App, HomeDirectory, IPAllowlist, Tool, ToolDeployment, User @@ -144,30 +145,37 @@ def app_ip_ranges_delete(self, message): def tool_deploy(self, message): """ - Deploy the named tool for the specified user - Expects a message with `tool_name`, `version` and `user_id` values + Uninstall the previous tool deployment, and deploy the new one. + Expects a message with `previous_deployment_id`, and 'new_deployment_id' values in order + to identify the user and the tool to deploy. """ - - tool, user = self.get_tool_and_user(message) - id_token = message["id_token"] - old_chart_name = message.get("old_chart_name", None) - tool_deployment = ToolDeployment(tool, user, old_chart_name) - - update_tool_status(tool_deployment, id_token, TOOL_DEPLOYING) + # if we have a previous deployment, uninstall it + previous_deployment = ToolDeployment.objects.filter( + pk=message["previous_deployment_id"] + ).first() + if previous_deployment: + try: + previous_deployment.uninstall() + except ToolDeployment.Error as err: + # if something went wrong, log the error but continue to try to deploy the new tool + log.error(err) + pass + + new_deployment = ToolDeployment.objects.get(pk=message["new_deployment_id"]) + update_tool_status(tool_deployment=new_deployment, status=TOOL_DEPLOYING) try: - tool_deployment.save() + new_deployment.deploy() + update_tool_status(tool_deployment=new_deployment, status=TOOL_READY) + log.debug(f"Deployed {new_deployment.tool.name} for {new_deployment.user}") except ToolDeployment.Error as err: + # if something went wrong, log the error and unmark the deployment object as active to + # allow the user to retry deploying the tool + new_deployment.is_active = False + new_deployment.save() + update_tool_status(tool_deployment=new_deployment, status=TOOL_DEPLOY_FAILED) self._send_to_sentry(err) - update_tool_status(tool_deployment, id_token, TOOL_DEPLOY_FAILED) log.error(err) - return - - status = wait_for_deployment(tool_deployment, id_token) - - if status == TOOL_DEPLOY_FAILED: - log.warning(f"Failed deploying {tool.name} for {user}") - else: - log.debug(f"Deployed {tool.name} for {user}") + log.warning(f"Failed deploying {new_deployment.tool.name} for {new_deployment.user}") def _send_to_sentry(self, error): if os.environ.get("SENTRY_DSN"): @@ -180,27 +188,19 @@ def tool_restart(self, message): """ Restart the named tool for the specified user """ - tool, user = self.get_tool_and_user(message) - id_token = message["id_token"] - - tool_deployment = ToolDeployment(tool, user) - update_tool_status(tool_deployment, id_token, TOOL_RESTARTING) - - tool_deployment.restart(id_token=id_token) + tool_deployment = ToolDeployment.objects.active().get( + id=message["tool_deployment_id"], + user=message["user_id"], + ) - status = wait_for_deployment(tool_deployment, id_token) + update_tool_status(tool_deployment, TOOL_RESTARTING) + tool_deployment.restart(id_token=message["id_token"]) + status = wait_for_deployment(tool_deployment, message["id_token"]) if status == TOOL_DEPLOY_FAILED: - log.warning(f"Failed restarting {tool.name} for {user}") + log.warning(f"Failed restarting {tool_deployment.tool.name} for {tool_deployment.user}") else: - log.debug(f"Restarted {tool.name} for {user}") - - def get_tool_and_user(self, message): - tool = Tool.objects.get(is_retired=False, pk=message["tool_id"]) - if not tool: - raise Exception(f"no Tool record found for query {message['tool_id']}") - user = User.objects.get(auth0_id=message["user_id"]) - return tool, user + log.debug(f"Restarted {tool_deployment.tool.name} for {tool_deployment.user}") def home_reset(self, message): """ @@ -225,7 +225,7 @@ def workers_health(self, message): log.debug("Worker health ping task executed") -def update_tool_status(tool_deployment, id_token, status): +def update_tool_status(tool_deployment, status): user = tool_deployment.user tool = tool_deployment.tool @@ -262,7 +262,7 @@ def wait_for_deployment(tool_deployment, id_token): status = TOOL_DEPLOYING while status == TOOL_DEPLOYING: status = tool_deployment.get_status(id_token) - update_tool_status(tool_deployment, id_token, status) + update_tool_status(tool_deployment, status) sleep(1) return status diff --git a/controlpanel/frontend/forms.py b/controlpanel/frontend/forms.py index 883fd829..d8934fee 100644 --- a/controlpanel/frontend/forms.py +++ b/controlpanel/frontend/forms.py @@ -8,6 +8,7 @@ from django.contrib.postgres.forms import SimpleArrayField from django.core.exceptions import ValidationError from django.core.validators import RegexValidator, validate_email +from django.db.models import Q # First-party/Local from controlpanel.api import validators @@ -23,11 +24,11 @@ S3Bucket, Tool, User, - UserS3Bucket, ) from controlpanel.api.models.access_to_s3bucket import S3BUCKET_PATH_REGEX from controlpanel.api.models.iam_managed_policy import POLICY_NAME_REGEX from controlpanel.api.models.ip_allowlist import IPAllowlist +from controlpanel.api.models.tool import ToolDeployment APP_CUSTOMERS_DELIMITERS = re.compile(r"[,; ]+") @@ -702,3 +703,76 @@ class Meta: "satisfaction_rating", "suggestions", ] + + +class ToolChoice(forms.Select): + + def create_option(self, name, value, label, selected, index, subindex=None, attrs=None): + + option = super().create_option(name, value, label, selected, index, subindex, attrs) + if value: + option["attrs"]["data-is-deprecated"] = f"{value.instance.is_deprecated}" + option["attrs"]["data-deprecated-message"] = value.instance.get_deprecated_message + + if value and selected: + option["attrs"]["label"] = f"{label} (installed)" + option["attrs"]["class"] = "installed" + + return option + + +class ToolDeploymentForm(forms.Form): + + tool = forms.ModelChoiceField( + queryset=Tool.objects.none(), + empty_label='Select a tool from this list and click "Deploy" to start', + widget=ToolChoice(attrs={"class": "govuk-select govuk-!-width-full govuk-!-font-size-16"}), + ) + + def __init__(self, *args, **kwargs): + self.user = kwargs.pop("user") + self.tool_type = kwargs.pop("tool_type") + self.deployment = kwargs.pop("deployment", None) + super().__init__(*args, **kwargs) + self.fields["tool"].queryset = self.get_tool_release_choices(tool_type=self.tool_type) + self.fields["tool"].widget.attrs.update( + {"data-action-target": self.tool_type, "id": f"tools-{self.tool_type}"} + ) + if self.deployment: + self.fields["tool"].initial = self.deployment.tool + + def get_tool_release_choices(self, tool_type: str): + """ + Return a queryset for Tool objects where: + + * The tool is not retired + + AND EITHER: + + * The tool is not restricted + + OR + + * The current user has access to the restricted tool + """ + return ( + Tool.objects.filter( + Q(is_restricted=False) | Q(target_users=self.user), + chart_name__startswith=tool_type, + ) + .exclude(is_retired=True) + .order_by("-chart_name", "-image_tag", "-version", "-created") + ) + + @property + def tool_type_label(self): + return ToolDeployment.ToolType(self.tool_type).label + + +class ToolDeploymentRestartForm(forms.Form): + tool_deployment = forms.ModelChoiceField(queryset=ToolDeployment.objects.none()) + + def __init__(self, *args, **kwargs): + self.user = kwargs.pop("user") + super().__init__(*args, **kwargs) + self.fields["tool_deployment"].queryset = self.user.tool_deployments.active() diff --git a/controlpanel/frontend/jinja2/release-list.html b/controlpanel/frontend/jinja2/release-list.html index ef5742ce..c79f92d5 100644 --- a/controlpanel/frontend/jinja2/release-list.html +++ b/controlpanel/frontend/jinja2/release-list.html @@ -46,6 +46,7 @@

Filter

Description Created Status + Num users Actions @@ -76,6 +77,9 @@

Filter

{{ release.status }} + + {{ release.num_users }} + diff --git a/controlpanel/frontend/jinja2/tool-list.html b/controlpanel/frontend/jinja2/tool-list.html index 1c38eb5b..3a113743 100644 --- a/controlpanel/frontend/jinja2/tool-list.html +++ b/controlpanel/frontend/jinja2/tool-list.html @@ -12,125 +12,95 @@

Your tools

If your tools get into an unusable state, try resetting your home directory.

-{% for chart_name, tool_info in tools_info.items() %} -{% set deployment = tool_info["deployment"] %} -

{{ tool_info.name }}

-
-
-
{{ tool_form.tool_type_label }} +
+
+ {{ csrf_input }} - {% if deployment %} - - {% endif%} -
- - -
- -
-
-

Status: - - {% if deployment and not deployment.is_retired %} - {{ deployment.status | default("") }} - {% else %} - Not deployed - {% endif %} - -

- -
- -
+
+ + {{ tool_form.tool }} +
+ +
+
+

Status: + + {% if tool_form.deployment and not tool_form.deployment.tool.is_retired %} + {{ tool_form.deployment.get_status() }} + {% else %} + Not deployed + {% endif %} + +

- +
+ +
-
- {{ csrf_input }} -
+ +
+ {{ csrf_input }} + + +
+
-
-
Warning{{ deployment.deprecated_message }}
-{% if deployment and deployment.is_retired %} +
Warning{% if tool_form.deployment %}{{ tool_form.deployment.tool.deprecated_message }}{% endif %}
-
- - Warning - Your previous deployment ({{ deployment.chart_name}}-{{ deployment.chart_version }}: {{ deployment.image_tag }}) - has been retired. You will need to deploy a new version from the dropdown list. - -
-{% endif %} -
+ + {% if tool_form.deployment and tool_form.deployment.tool.is_retired %} + +
+ + Warning + Your previous deployment ({{ tool_form.deployment.tool.chart_name}}-{{ tool_form.deployment.tool.chart_version }}: {{ tool_form.deployment.tool.image_tag }}) + has been retired. You will need to deploy a new version from the dropdown list. + +
+ {% endif %} +
{% endfor %} +

Airflow

diff --git a/controlpanel/frontend/static/javascripts/modules/tool-status.js b/controlpanel/frontend/static/javascripts/modules/tool-status.js index 3f159db2..4177a6c3 100644 --- a/controlpanel/frontend/static/javascripts/modules/tool-status.js +++ b/controlpanel/frontend/static/javascripts/modules/tool-status.js @@ -6,7 +6,7 @@ moj.Modules.toolStatus = { listenerClass: ".tool", statusLabelClass: ".tool-status-label", - versionSelector: "select[name='version']", + versionSelector: "select[name='tool']", versionNotInstalledClass: "not-installed", versionInstalledClass: "installed", installedSuffix: " (installed)", @@ -19,7 +19,7 @@ moj.Modules.toolStatus = { this.bindEvents(toolStatusListeners); } - // Bind version selects' change event listeners + // Bind tool selects' change event listeners const versionSelects = document.querySelectorAll(this.versionSelector); versionSelects.forEach(versionSelect => { versionSelect.addEventListener("change", (event) => this.versionSelectChanged(event.target)); @@ -68,7 +68,7 @@ moj.Modules.toolStatus = { }; }, - // Select the new version from the tool "version" select input + // Select the new tool from the tool select input updateAppVersion(listener, newVersionData) { const selectElement = listener.querySelector(this.versionSelector); @@ -80,19 +80,18 @@ moj.Modules.toolStatus = { notInstalledOption.remove(); } - // 2. remove "(installed)" suffix and class from old version + // 2. remove "(installed)" suffix and class from old tool version let oldVersionOption = selectElement.querySelector("option." + this.versionInstalledClass); if (oldVersionOption) { - oldVersionOption.innerText = oldVersionOption.innerText.replace(this.installedSuffix, ""); + oldVersionOption.label = oldVersionOption.label.replace(this.installedSuffix, ""); oldVersionOption.classList.remove(this.versionInstalledClass); } - // 3. add "(installed)" suffix and class to new version - let newValue = newVersionData.toolName + "__" + newVersionData.version + "__" + newVersionData.tool_id; - let newVersionOption = listener.querySelector(this.versionSelector + " option[value='" + newValue + "']"); + // 3. add "(installed)" suffix and class to new tool version + let newVersionOption = listener.querySelector(this.versionSelector + " option[value='" + newVersionData.tool_id + "']"); if (newVersionOption) { - newVersionOption.innerText = newVersionOption.innerText + this.installedSuffix; + newVersionOption.label = newVersionOption.label + this.installedSuffix; newVersionOption.classList.add(this.versionInstalledClass) // set the new version as the current chosen item @@ -142,7 +141,7 @@ moj.Modules.toolStatus = { }); }, - // version select "change" event handler + // tool version select "change" event handler versionSelectChanged(target) { const selected = target.options[target.options.selectedIndex]; const classes = selected.className.split(" "); diff --git a/controlpanel/frontend/urls.py b/controlpanel/frontend/urls.py index 22bdeb34..2c6bf8ba 100644 --- a/controlpanel/frontend/urls.py +++ b/controlpanel/frontend/urls.py @@ -72,7 +72,7 @@ ), path("tools/", views.ToolList.as_view(), name="list-tools"), path( - "tools//restart/", + "tools/restart/", views.RestartTool.as_view(), name="restart-tool", ), diff --git a/controlpanel/frontend/views/release.py b/controlpanel/frontend/views/release.py index c0f5e6f9..f5942ccb 100644 --- a/controlpanel/frontend/views/release.py +++ b/controlpanel/frontend/views/release.py @@ -1,5 +1,6 @@ # Third-party from django.contrib import messages +from django.db.models import Count, Q from django.http.response import HttpResponseRedirect from django.urls import reverse_lazy from django.views.generic.edit import CreateView, DeleteView, UpdateView @@ -25,6 +26,15 @@ class ReleaseList(OIDCLoginRequiredMixin, PermissionRequiredMixin, ListView): template_name = "release-list.html" ordering = ["name", "-version", "-created"] + def get_queryset(self): + qs = super().get_queryset() + qs = qs.annotate( + num_users=Count( + "tool_deployments", distinct=True, filter=Q(tool_deployments__is_active=True) + ) + ) + return qs + def get_context_data(self, *args, **kwargs): context = super().get_context_data(*args, **kwargs) context["filter"] = ReleaseFilter(self.request.GET, queryset=self.get_queryset()) diff --git a/controlpanel/frontend/views/tool.py b/controlpanel/frontend/views/tool.py index d0add967..1b64532b 100644 --- a/controlpanel/frontend/views/tool.py +++ b/controlpanel/frontend/views/tool.py @@ -5,147 +5,44 @@ import structlog from django.conf import settings from django.contrib import messages -from django.db.models import Q from django.urls import reverse_lazy -from django.views.generic.base import RedirectView -from django.views.generic.list import ListView +from django.views.generic import RedirectView, TemplateView from rules.contrib.views import PermissionRequiredMixin # First-party/Local -from controlpanel.api import cluster from controlpanel.api.models import Tool, ToolDeployment +from controlpanel.frontend.forms import ToolDeploymentForm, ToolDeploymentRestartForm from controlpanel.oidc import OIDCLoginRequiredMixin from controlpanel.utils import start_background_task log = structlog.getLogger(__name__) -class ToolList(OIDCLoginRequiredMixin, PermissionRequiredMixin, ListView): - context_object_name = "tools" - model = Tool +class ToolList(OIDCLoginRequiredMixin, PermissionRequiredMixin, TemplateView): permission_required = "api.list_tool" template_name = "tool-list.html" - def get_queryset(self): - """ - Return a queryset for Tool objects where: - - * The tool is to be run on this version of the infrastructure. - - AND EITHER: - - * The tool is not in beta, - - OR - - * The current user is in the beta tester group for the tool. - """ - return Tool.objects.filter( - Q(is_restricted=False) | Q(target_users=self.request.user.id) - ).exclude(is_retired=True) - - def _locate_tool_box_by_chart_name(self, chart_name): - tool_box = None - for key, bucket_name in Tool.TOOL_BOX_CHART_LOOKUP.items(): - if key in chart_name: - tool_box = bucket_name - break - return tool_box - - def _find_related_tool_record(self, chart_name, chart_version, image_tag): - """ - The current logic is to link the deployment back to the tool-release - record is based - - chart_name - - chart_version - - image_tag - if somehow we make a tool-release with duplicated 3 above fields but - different other parameters e.g. - memory, CPU etc, then the linkage will be confused although it - won't affect people usage. - """ - tools = self.get_queryset().filter(chart_name=chart_name, version=chart_version) - for tool in tools: - if tool.image_tag == image_tag: - return tool - # If we cant find a tool with the same image tag, this must mean that it was retired or - # deleted. So return none, and let the calling function handle it - return None - - def _add_new_item_to_tool_box(self, user, tool_box, tool, tools_info): - if tool_box not in tools_info: - tools_info[tool_box] = { - "name": tool.name, - "url": tool.url(user), - "deployment": None, - "releases": {}, - } - if tool.id not in tools_info[tool_box]["releases"]: - tools_info[tool_box]["releases"][tool.id] = { - "tool_id": tool.id, - "chart_name": tool.chart_name, - "description": tool.description, - "chart_version": tool.version, - "image_tag": tool.image_tag, - "is_deprecated": tool.is_deprecated, - "deprecated_message": tool.get_deprecated_message, - } - - def _get_tool_deployed_image_tag(self, containers): - for container in containers: - if "auth" not in container.name: - return container.image.split(":")[1] - return None - - def _add_deployed_charts_info(self, tools_info, user, id_token): - # Get list of deployed tools - # TODO this sets what tool the user currently has deployed. If we were to refactor to store - # deployed tools in the database, we could remove a lot of this logic - # See https://github.com/ministryofjustice/analytical-platform/issues/6266 - deployments = cluster.ToolDeployment.get_deployments(user, id_token) - for deployment in deployments: - chart_name, chart_version = cluster.ToolDeployment.get_chart_details( - deployment.metadata.labels["chart"] - ) - image_tag = self._get_tool_deployed_image_tag(deployment.spec.template.spec.containers) - tool_box = self._locate_tool_box_by_chart_name(chart_name) - tool_box = tool_box or "Unknown" - tool = self._find_related_tool_record(chart_name, chart_version, image_tag) - if not tool: - log.warn( - "this chart({}-{}) has not available from DB. ".format( - chart_name, chart_version - ) - ) - else: - self._add_new_item_to_tool_box(user, tool_box, tool, tools_info) - if tool_box not in tools_info: - # up to this stage, if the tool_box is still empty, it means - # there is no tool release available in db - tools_info[tool_box] = {"releases": {}} - tools_info[tool_box]["deployment"] = { - "tool_id": tool.id if tool else -1, - "chart_name": chart_name, - "chart_version": chart_version, - "image_tag": image_tag, - "description": tool.description if tool else "Not available", - "status": ToolDeployment(tool, user).get_status(id_token, deployment=deployment), - "is_deprecated": tool.is_deprecated if tool else False, - "deprecated_message": tool.get_deprecated_message if tool else "", - "is_retired": tool is None, - } - - def _retrieve_detail_tool_info(self, user, tools): - # TODO when deployed tools are tracked in the DB this will not be needed - # see https://github.com/ministryofjustice/analytical-platform/issues/6266 # noqa: E501 - tools_info = {} - for tool in tools: - # Work out which bucket the chart should be in - tool_box = self._locate_tool_box_by_chart_name(tool.chart_name) - # No matching tool bucket for the given chart. So ignore. - if tool_box: - self._add_new_item_to_tool_box(user, tool_box, tool, tools_info) - return tools_info + # def _add_deployed_charts_info(self, tools_info, user, id_token): + # # TODO this is left in place simply to determine the status of a tool. Not sure if it is + # # necessary or worth it we could store the status of the tool on the ToolDeployment model + # # instead + # deployments = cluster.ToolDeployment.get_deployments(user, id_token) + # # build an index using the chart name as the key for easy lookup later + # deployments = {deployment.metadata.labels["app"]: deployment for deployment in deployments} # noqa + # for tool_deployment in user.tool_deployments.active(): + # deployment = deployments.get(tool_deployment.tool.chart_name) + # tool = tool_deployment.tool + # tools_info[tool_deployment.tool_type]["deployment"] = { + # "tool_id": tool.id, + # "chart_name": tool.chart_name, + # "chart_version": tool.version, + # "image_tag": tool.image_tag, + # "description": tool.description, + # "status": tool_deployment.get_status(id_token=id_token, deployment=deployment), + # "is_deprecated": tool.is_deprecated, + # "deprecated_message": tool.get_deprecated_message, + # "is_retired": tool.is_retired, + # } def get_context_data(self, *args, **kwargs): """ @@ -187,65 +84,55 @@ def get_context_data(self, *args, **kwargs): } ``` """ - - user = self.request.user - id_token = user.get_id_token() - context = super().get_context_data(*args, **kwargs) context["user_guidance_base_url"] = settings.USER_GUIDANCE_BASE_URL context["aws_service_url"] = settings.AWS_SERVICE_URL + context["managed_airflow_dev_url"] = self.build_airflow_url("dev") + context["managed_airflow_prod_url"] = self.build_airflow_url("prod") + context["tool_forms"] = [ + self.get_tool_release_form(tool_type=tool_type) for tool_type in ToolDeployment.ToolType + ] - args_airflow_dev_url = urlencode( - { - "destination": f"mwaa/home?region={settings.AIRFLOW_REGION}#/environments/dev/sso", # noqa: E501 - } + return context + + def get_tool_release_form(self, tool_type): + deployment = self.request.user.tool_deployments.filter(tool_type=tool_type).active().first() + return ToolDeploymentForm( + user=self.request.user, + tool_type=tool_type, + deployment=deployment, ) - args_airflow_prod_url = urlencode( + + def build_airflow_url(self, environment): + destination = f"mwaa/home?region={settings.AIRFLOW_REGION}#/environments/{environment}/sso" + args = urlencode( { - "destination": f"mwaa/home?region={settings.AIRFLOW_REGION}#/environments/prod/sso", # noqa: E501 + "destination": destination, # noqa: E501 } ) - context["managed_airflow_dev_url"] = f"{settings.AWS_SERVICE_URL}/?{args_airflow_dev_url}" - context["managed_airflow_prod_url"] = f"{settings.AWS_SERVICE_URL}/?{args_airflow_prod_url}" - - tools_info = self._retrieve_detail_tool_info(user, context["tools"]) - - if "vscode" in tools_info: - url = tools_info["vscode"]["url"] - tools_info["vscode"]["url"] = f"{url}?folder=/home/analyticalplatform/workspace" - - self._add_deployed_charts_info(tools_info, user, id_token) - context["tools_info"] = tools_info - return context + return f"{settings.AWS_SERVICE_URL}/?{args}" class RestartTool(OIDCLoginRequiredMixin, RedirectView): - http_method_names = ["post"] url = reverse_lazy("list-tools") - def get_redirect_url(self, *args, **kwargs): - """ - So backwards, it's forwards. - - The "name" of the chart to restart is set in the template for - list-tools, if there's a live deployment. - - That's numberwang. - """ - name = self.kwargs["name"] + def post(self, request, *args, **kwargs): + form = ToolDeploymentRestartForm(data=request.POST, user=request.user) + if not form.is_valid(): + messages.error( + request, + "Something went wrong, please try again. If the issue persists please contact support.", # noqa + ) + return self.get(request, *args, **kwargs) + tool_deployment = form.cleaned_data["tool_deployment"] start_background_task( "tool.restart", { - "tool_name": name, - "tool_id": self.kwargs["tool_id"], - "user_id": self.request.user.id, + "tool_deployment_id": tool_deployment.id, + "user_id": self.request.user.auth0_id, "id_token": self.request.user.get_id_token(), }, ) - - messages.success( - self.request, - f"Restarting {name}...", - ) - return super().get_redirect_url(*args, **kwargs) + messages.success(self.request, f"Restarting {tool_deployment.tool.name}...") + return self.get(request, *args, **kwargs) diff --git a/tests/api/cluster/test_tool_deployment.py b/tests/api/cluster/test_tool_deployment.py index 46854d6d..52fce384 100644 --- a/tests/api/cluster/test_tool_deployment.py +++ b/tests/api/cluster/test_tool_deployment.py @@ -4,7 +4,7 @@ # First-party/Local from controlpanel.api import cluster -from controlpanel.api.models import Tool, User +from controlpanel.api.models import Tool, ToolDeployment, User def test_url(): @@ -18,14 +18,15 @@ def test_url(): chart_name="rstudio", version="1.0.0", ) + tool_deployment = ToolDeployment(user=user, tool=tool) expected = f"https://{user.slug}-rstudio.{settings.TOOLS_DOMAIN}/" # In the absence of a tool_domain, the chart_name (rstudio) is used. - assert tool.url(user) == expected - tool.chart_name = "rstudio-bespoke" - tool.tool_domain = "rstudio" + assert tool_deployment.url == expected + tool_deployment.chart_name = "rstudio-bespoke" + tool_deployment.tool_domain = "rstudio" # Now the chart_name is custom, the tool_domain (rstudio) is used, ensuring # the url remains "valid". - assert tool.url(user) == expected + assert tool_deployment.url == expected @pytest.mark.parametrize( diff --git a/tests/api/models/test_tool.py b/tests/api/models/test_tool.py index ecd38ad9..3c4e5acc 100644 --- a/tests/api/models/test_tool.py +++ b/tests/api/models/test_tool.py @@ -18,17 +18,8 @@ def tool(db): def test_deploy_for_generic(helm, tool, users): user = users["normal_user"] - # simulate release with old naming scheme installed - old_release_name = f"{tool.chart_name}-{user.username}" - helm.list_releases.return_value = [old_release_name[: settings.MAX_RELEASE_NAME_LEN]] - - tool_deployment = ToolDeployment(tool, user) - tool_deployment.save() - - # uninstall tool with old naming scheme - helm.delete.assert_called_with( - user.k8s_namespace, old_release_name[: settings.MAX_RELEASE_NAME_LEN] - ) + tool_deployment = ToolDeployment.objects.create(tool=tool, user=user, is_active=True) + tool_deployment.deploy() # install new release helm.upgrade_release.assert_called_with( diff --git a/tests/api/views/test_tool_deployments.py b/tests/api/views/test_tool_deployments.py index b62b0cff..482ce0a2 100644 --- a/tests/api/views/test_tool_deployments.py +++ b/tests/api/views/test_tool_deployments.py @@ -1,7 +1,13 @@ +# Standard library +from unittest.mock import patch + # Third-party from rest_framework import status from rest_framework.reverse import reverse +# First-party/Local +from tests.api.models.test_tool import tool # noqa: F401 + def test_get(client): response = client.get(reverse("tool-deployments", ("rstudio", "deploy"))) @@ -9,18 +15,20 @@ def test_get(client): def test_post_not_valid_data(client): - data = {"version": "rstudio_v1.0.0"} + data = {"tool": 1000} response = client.post(reverse("tool-deployments", ("rstudio", "deploy")), data) assert response.status_code == status.HTTP_400_BAD_REQUEST -def test_post_not_supported_action(client): - data = {"version": "rstudio_v1.0.0"} +def test_post_not_supported_action(client, tool): # noqa: F811 + data = {"tool": tool.pk} response = client.post(reverse("tool-deployments", ("rstudio", "testing")), data) assert response.status_code == status.HTTP_400_BAD_REQUEST -def test_post(client): - data = {"version": "rstudio__v1.0.0__1"} +@patch("controlpanel.api.serializers.ToolDeploymentSerializer.save") +def test_post(save, client, tool): # noqa: F811 + data = {"tool": tool.pk} response = client.post(reverse("tool-deployments", ("rstudio", "deploy")), data) assert response.status_code == status.HTTP_200_OK + save.assert_called_once() diff --git a/tests/frontend/test_consumers.py b/tests/frontend/test_consumers.py index 3b59011e..ae9c07f4 100644 --- a/tests/frontend/test_consumers.py +++ b/tests/frontend/test_consumers.py @@ -1,6 +1,6 @@ # Standard library import json -from unittest.mock import Mock, patch +from unittest.mock import Mock, call, patch # Third-party import pytest @@ -49,123 +49,79 @@ def wait_for_home_reset(): yield wait_for_home_reset -def test_tool_deploy(users, tools, update_tool_status, wait_for_deployment): +def test_tool_deploy(users, tools, update_tool_status): user = User.objects.first() tool = Tool.objects.first() - id_token = "secret user id_token" - - with patch("controlpanel.frontend.consumers.ToolDeployment") as tool_deploy_fix: - tool_deployment = Mock() - tool_deploy_fix.return_value = tool_deployment + tool_deployment = ToolDeployment.objects.create(tool=tool, user=user, is_active=True) + with patch.object(ToolDeployment, "deploy") as deploy: consumer = consumers.BackgroundTaskConsumer() consumer.tool_deploy( message={ - "user_id": user.auth0_id, - "tool_name": tool.chart_name, - "id_token": id_token, - "tool_id": tool.id, + "new_deployment_id": tool_deployment.id, + "previous_deployment_id": None, } ) - - # 1. Instanciate `ToolDeployment` correctly - tool_deploy_fix.assert_called_with(tool, user, None) - # 2. Send status update - update_tool_status.assert_called_with( - tool_deployment, - id_token, - TOOL_DEPLOYING, + deploy.assert_called_once() + update_tool_status.assert_has_calls( + calls=[ + call(tool_deployment=tool_deployment, status=TOOL_DEPLOYING), + call(tool_deployment=tool_deployment, status=TOOL_READY), + ] ) - # 3. Call save() on ToolDeployment (trigger deployment) - tool_deployment.save.assert_called() - # 4. Wait for deployment to complete - wait_for_deployment.assert_called_with(tool_deployment, id_token) -def test_tool_deploy_with_old_chart_name(users, tools, update_tool_status, wait_for_deployment): +def test_tool_deploy_with_previous_deployment(users, tools, update_tool_status): user = User.objects.first() tool = Tool.objects.first() - id_token = "secret user id_token" - old_chart_name = "old-chart" - - with patch("controlpanel.frontend.consumers.ToolDeployment") as tool_deploy_mock: - tool_deployment = Mock() - tool_deploy_mock.return_value = tool_deployment - # recevier = Mock() - # sender = Mock() + previous_deployment = ToolDeployment.objects.create(tool=tool, user=user, is_active=False) + new_deployment = ToolDeployment.objects.create(tool=tool, user=user, is_active=True) + with ( + patch.object(ToolDeployment, "deploy") as deploy, + patch.object(ToolDeployment, "uninstall") as uninstall, + ): consumer = consumers.BackgroundTaskConsumer() consumer.tool_deploy( message={ - "user_id": user.auth0_id, - "tool_name": tool.chart_name, - "id_token": id_token, - "old_chart_name": old_chart_name, - "tool_id": tool.id, + "new_deployment_id": new_deployment.id, + "previous_deployment_id": previous_deployment.id, } ) - # 1. Instanciate `ToolDeployment` correctly - tool_deploy_mock.assert_called_with(tool, user, old_chart_name) - # 2. Send status update - update_tool_status.assert_called_with( - tool_deployment, - id_token, - TOOL_DEPLOYING, + uninstall.assert_called_once() + deploy.assert_called_once() + update_tool_status.assert_has_calls( + calls=[ + call(tool_deployment=new_deployment, status=TOOL_DEPLOYING), + call(tool_deployment=new_deployment, status=TOOL_READY), + ] ) - # 3. Call save() on ToolDeployment (trigger deployment) - tool_deployment.save.assert_called() - # 4. Wait for deployment to complete - wait_for_deployment.assert_called_with(tool_deployment, id_token) def test_tool_restart(users, tools, update_tool_status, wait_for_deployment): user = User.objects.first() tool = Tool.objects.first() + tool_deployment = ToolDeployment.objects.create(tool=tool, user=user, is_active=True) id_token = "secret user id_token" - with patch("controlpanel.frontend.consumers.ToolDeployment") as tool_deploy_mock: - tool_deployment = Mock() - tool_deploy_mock.return_value = tool_deployment - + with patch.object(ToolDeployment, "restart") as restart_mock: consumer = consumers.BackgroundTaskConsumer() consumer.tool_restart( message={ + "tool_deployment_id": tool_deployment.id, "user_id": user.auth0_id, - "tool_name": tool.chart_name, "id_token": id_token, - "tool_id": tool.id, } ) - # 1. Instanciate `ToolDeployment` correctly - tool_deploy_mock.assert_called_with(tool, user) - # 2. Send status update update_tool_status.assert_called_with( tool_deployment, - id_token, TOOL_RESTARTING, ) - # 3. Call restart() on ToolDeployment (trigger deployment) - tool_deployment.restart.assert_called_with(id_token=id_token) - # 4. Wait for deployment to complete - wait_for_deployment.assert_called_with(tool_deployment, id_token) - + restart_mock.assert_called_with(id_token=id_token) -def test_get_tool_and_user(users, tools): - expected_user = User.objects.first() - expected_tool = Tool.objects.first() - message = { - "user_id": expected_user.auth0_id, - "tool_name": expected_tool.chart_name, - "id_token": "not used by this method", - "tool_id": expected_tool.id, - } - - consumer = consumers.BackgroundTaskConsumer() - tool, user = consumer.get_tool_and_user(message) - assert expected_user == user - assert expected_tool == tool + wait_for_deployment.assert_called_with(tool_deployment, id_token) def test_get_home_reset(users, update_home_status, wait_for_home_reset): @@ -198,7 +154,6 @@ def test_get_home_reset(users, update_home_status, wait_for_home_reset): def test_update_tool_status(): tool = Tool(chart_name="a_tool", version="v1.0.0") user = User(auth0_id="github|123") - id_token = "user id_token" status = TOOL_READY tool_deployment = Mock() @@ -220,7 +175,6 @@ def test_update_tool_status(): with patch("controlpanel.frontend.consumers.send_sse") as send_sse: consumers.update_tool_status( tool_deployment, - id_token, status, ) send_sse.assert_called_with(user.auth0_id, expected_sse_event)