diff --git a/controlpanel/api/cluster.py b/controlpanel/api/cluster.py index 5c1d7b7b1..6961e43d3 100644 --- a/controlpanel/api/cluster.py +++ b/controlpanel/api/cluster.py @@ -20,6 +20,7 @@ TOOL_IDLED = 'Idled' TOOL_NOT_DEPLOYED = 'Not deployed' TOOL_READY = 'Ready' +TOOL_RESTARTING = 'Restarting' TOOL_UPGRADED = 'Upgraded' TOOL_STATUS_UNKNOWN = 'Unknown' @@ -345,6 +346,23 @@ def get_deployment(self, id_token): return deployments[0] + + def get_installed_chart_version(self, id_token): + """ + Returns the installed helm chart version of the tool + + This is extracted from the `chart` label in the corresponding + `Deployment`. + """ + + try: + deployment = self.get_deployment(id_token) + _, chart_version = deployment.metadata.labels["chart"].rsplit("-", 1) + return chart_version + except ObjectDoesNotExist: + return None + + def get_status(self, id_token): try: deployment = self.get_deployment(id_token) diff --git a/controlpanel/api/helm.py b/controlpanel/api/helm.py index 70fe6ac64..a61a06f18 100644 --- a/controlpanel/api/helm.py +++ b/controlpanel/api/helm.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timedelta import logging import os import re @@ -6,6 +6,7 @@ from django.conf import settings from rest_framework.exceptions import APIException +import yaml log = logging.getLogger(__name__) @@ -19,7 +20,8 @@ class HelmError(APIException): class Helm(object): - def _execute(self, *args, check=True, **kwargs): + @classmethod + def execute(cls, *args, check=True, **kwargs): should_wait = False if 'timeout' in kwargs: should_wait = True @@ -67,13 +69,10 @@ def _execute(self, *args, check=True, **kwargs): return proc - def update_repositories(self, *args): - self._execute("repo", "update", timeout=None) - def upgrade_release(self, release, chart, *args): - self.update_repositories() + HelmRepository.update() - return self._execute( + return self.__class__.execute( "upgrade", "--install", "--wait", release, chart, *args, ) @@ -81,11 +80,11 @@ def delete(self, purge=True, *args): default_args = [] if purge: default_args.append("--purge") - self._execute("delete", *default_args, *args) + self.__class__.execute("delete", *default_args, *args) def list_releases(self, *args): # TODO - use --max and --offset to paginate through releases - proc = self._execute("list", "-q", "--max=1024", *args, timeout=None) + proc = self.__class__.execute("list", "-q", "--max=1024", *args, timeout=None) return proc.stdout.read().split() @@ -148,4 +147,104 @@ def parse_upgrade_output(output): } +class Chart(object): + + def __init__(self, name, description, version, app_version): + self.name = name + self.description = description + self.version = version + self.app_version = app_version + + +class HelmRepository(object): + + CACHE_FOR_MINUTES = 30 + + HELM_HOME = Helm.execute("home").stdout.read().strip() + REPO_PATH = os.path.join( + HELM_HOME, + "repository", + "cache", + f"{settings.HELM_REPO}-index.yaml", + ) + + _updated_at = None + _repository = {} + + @classmethod + def update(cls, force=True): + if force or cls._outdated(): + Helm.execute("repo", "update", timeout=None) + cls._load() + cls._updated_at = datetime.utcnow() + + @classmethod + def _load(cls): + # Read and parse helm repository YAML file + try: + with open(cls.REPO_PATH) as f: + cls._repository = yaml.load(f, Loader=yaml.FullLoader) + except Exception as err: + wrapped_err = HelmError(err) + wrapped_err.detail = f"Error while opening/parsing helm repository cache: '{cls.REPO_PATH}'" + raise HelmError(wrapped_err) + + @classmethod + def get_chart_info(cls, name): + """ + Get information about the given chart + + Returns a dictionary with the chart versions as keys and the chart + as value (`Chart` instance) + + Returns an empty dictionary when the chart is not in the helm + repository index. + + ``` + rstudio_info = HelmRepository.get_chart_info("rstudio") + # rstudio_info = { + # "2.2.5": , + # "2.2.4": , + # } + ``` + """ + + cls.update(force=False) + + try: + versions = cls._repository["entries"][name] + except KeyError: + # No such a chart with this name, returning {} + return {} + + # Convert to dictionary + chart_info = {} + for version_info in versions: + chart = Chart( + version_info["name"], + version_info["description"], + version_info["version"], + # appVersion is relatively new and some old helm chart don't + # have it + version_info.get("appVersion", None), + ) + chart_info[chart.version] = chart + return chart_info + + @classmethod + def _outdated(cls): + # helm update never called? + if not cls._updated_at: + return True + + # helm update called more than `CACHE_FOR_MINUTES` ago + now = datetime.utcnow() + elapsed = now - cls._updated_at + if elapsed > timedelta(minutes=cls.CACHE_FOR_MINUTES): + return True + + # helm update called recently + return False + + helm = Helm() diff --git a/controlpanel/api/models/tool.py b/controlpanel/api/models/tool.py index 00f7c7d42..c9a4be989 100644 --- a/controlpanel/api/models/tool.py +++ b/controlpanel/api/models/tool.py @@ -9,6 +9,7 @@ from django_extensions.db.models import TimeStampedModel from controlpanel.api import cluster +from controlpanel.api.helm import HelmRepository log = logging.getLogger(__name__) @@ -49,14 +50,12 @@ def create(self, *args, **kwargs): return tool_deployment def filter(self, **kwargs): - deployed_versions = {} user = kwargs["user"] id_token = kwargs["id_token"] filter = Q(chart_name=None) # Always False deployments = cluster.ToolDeployment.get_deployments(user, id_token) for deployment in deployments: chart_name, version = deployment.metadata.labels["chart"].rsplit("-", 1) - deployed_versions[chart_name] = version filter = filter | ( Q(chart_name=chart_name) # & Q(version=version) @@ -65,8 +64,7 @@ def filter(self, **kwargs): tools = Tool.objects.filter(filter) results = [] for tool in tools: - outdated = tool.version != deployed_versions[tool.chart_name] - tool_deployment = ToolDeployment(tool, user, outdated) + tool_deployment = ToolDeployment(tool, user) results.append(tool_deployment) return results @@ -82,15 +80,64 @@ class ToolDeployment: objects = ToolDeploymentManager() - def __init__(self, tool, user, outdated=False): + def __init__(self, tool, user): self._subprocess = None self.tool = tool self.user = user - self.outdated = outdated def __repr__(self): return f'' + def get_installed_app_version(self, id_token): + """ + Returns the version of the deployed tool + + NOTE: This is the version coming from the helm + chart `appVersion` field, **not** the version + of the chart released in the user namespace. + + e.g. if user has `rstudio-2.2.5` (chart version) + installed in his namespace, this would return + "RStudio: 1.2.1335+conda, R: 3.5.1, Python: 3.7.1, patch: 10" + **not** "2.2.5". + + Also bear in mind that Helm added this `appVersion` + field only "recently" so if a user has an old + version of a tool chart installed this would return + `None` as we can't determine the tool version + as this information is simply not available + in the helm repository index. + """ + + td = cluster.ToolDeployment(self.user, self.tool) + chart_version = td.get_installed_chart_version(id_token) + if chart_version: + chart_info = HelmRepository.get_chart_info(self.tool.chart_name) + + version_info = chart_info.get(chart_version, None) + if version_info: + return version_info.app_version + + return None + + + def outdated(self, id_token): + """ + Returns true if the tool helm chart version is old + + NOTE: This is simple/naive at the moment and it returns true if + the installed chart for the tool has a different version + than the one in the corresponding Tool record. + """ + + td = cluster.ToolDeployment(self.user, self.tool) + chart_version = td.get_installed_chart_version(id_token) + + if chart_version: + return self.tool.version != chart_version + + return False + def delete(self, id_token): """ Remove the release from the cluster diff --git a/controlpanel/frontend/consumers.py b/controlpanel/frontend/consumers.py index e663003ff..7cfb59fb3 100644 --- a/controlpanel/frontend/consumers.py +++ b/controlpanel/frontend/consumers.py @@ -11,12 +11,12 @@ from django.conf import settings from django.urls import reverse -from controlpanel.api import cluster from controlpanel.api.cluster import ( TOOL_DEPLOYING, TOOL_DEPLOY_FAILED, TOOL_IDLED, TOOL_READY, + TOOL_RESTARTING, TOOL_UPGRADED, ) from controlpanel.api.models import Tool, ToolDeployment, User @@ -110,19 +110,17 @@ def tool_deploy(self, message): tool, user = self.get_tool_and_user(message) id_token = message["id_token"] + tool_deployment = ToolDeployment(tool, user) - update_tool_status(user, tool, TOOL_DEPLOYING) - + update_tool_status(tool_deployment, id_token, TOOL_DEPLOYING) try: - deployment = ToolDeployment(tool, user) - deployment.save() - + tool_deployment.save() except ToolDeployment.Error as err: - update_tool_status(user, tool, TOOL_DEPLOY_FAILED) + update_tool_status(tool_deployment, id_token, TOOL_DEPLOY_FAILED) log.error(err) return - status = wait_for_deployment(deployment, id_token) + status = wait_for_deployment(tool_deployment, id_token) if status == TOOL_DEPLOY_FAILED: log.error(f"Failed deploying {tool.name} for {user}") @@ -136,7 +134,10 @@ def tool_upgrade(self, message): self.tool_deploy(message) tool, user = self.get_tool_and_user(message) - update_tool_status(user, tool, TOOL_UPGRADED) + id_token = message["id_token"] + + tool_deployment = ToolDeployment(tool, user) + update_tool_status(tool_deployment, id_token, TOOL_UPGRADED) def tool_restart(self, message): """ @@ -145,12 +146,12 @@ def tool_restart(self, message): tool, user = self.get_tool_and_user(message) id_token = message["id_token"] - update_tool_status(user, tool, "Restarting") + tool_deployment = ToolDeployment(tool, user) + update_tool_status(tool_deployment, id_token, TOOL_RESTARTING) - deployment = ToolDeployment(tool, user) - deployment.restart(id_token=id_token) + tool_deployment.restart(id_token=id_token) - status = wait_for_deployment(deployment, id_token) + status = wait_for_deployment(tool_deployment, id_token) if status == TOOL_DEPLOY_FAILED: log.error(f"Failed restarting {tool.name} for {user}") @@ -179,14 +180,21 @@ def send_sse(user_id, event): ) -def update_tool_status(user, tool, status): +def update_tool_status(tool_deployment, id_token, status): + user = tool_deployment.user + tool = tool_deployment.tool + + app_version = tool_deployment.get_installed_app_version(id_token) + + payload = { + "toolName": tool.chart_name, + "version": tool.version, + "appVersion": app_version, + "status": status, + } send_sse(user.auth0_id, { "event": "toolStatus", - "data": json.dumps({ - 'toolName': tool.chart_name, - 'version': tool.version, - 'status': status, - }), + "data": json.dumps(payload), }) @@ -199,11 +207,10 @@ def start_background_task(task, message): }, ) -def wait_for_deployment(deployment, id_token): +def wait_for_deployment(tool_deployment, id_token): status = TOOL_DEPLOYING while status == TOOL_DEPLOYING: - status = deployment.get_status(id_token) - update_tool_status(deployment.user, deployment.tool, status) + status = tool_deployment.get_status(id_token) + update_tool_status(tool_deployment, id_token, status) sleep(1) return status - diff --git a/controlpanel/frontend/jinja2/tool-list.html b/controlpanel/frontend/jinja2/tool-list.html index 966c08cb4..f24dcbd41 100644 --- a/controlpanel/frontend/jinja2/tool-list.html +++ b/controlpanel/frontend/jinja2/tool-list.html @@ -23,7 +23,12 @@

Your tools

{{ tool.name }} - {#
{{ tool.description }} #} +
+ + {% if deployment %} + {{ deployment.get_installed_app_version(id_token) or "Unknown" }} + {% endif %} +
@@ -54,7 +59,7 @@

Your tools

{# #} - {% if deployment.outdated %} + {% if deployment and deployment.outdated(id_token) %}
{ action.classList.toggle(this.hidden, !action_names.includes(action.dataset.actionName)); diff --git a/controlpanel/frontend/views/tool.py b/controlpanel/frontend/views/tool.py index d3f09ed1c..1887756a0 100644 --- a/controlpanel/frontend/views/tool.py +++ b/controlpanel/frontend/views/tool.py @@ -29,7 +29,7 @@ def get_context_data(self, *args, **kwargs): user = self.request.user id_token = user.get_id_token() - deployments = ToolDeployment.objects.filter( + tool_deployments = ToolDeployment.objects.filter( user=user, id_token=id_token, ) @@ -37,9 +37,10 @@ def get_context_data(self, *args, **kwargs): context = super().get_context_data(*args, **kwargs) context["id_token"] = id_token context["deployed_tools"] = { - deployment.tool: deployment - for deployment in deployments + tool_deployment.tool: tool_deployment + for tool_deployment in tool_deployments } + return context diff --git a/tests/api/cluster/test_tool_deployment.py b/tests/api/cluster/test_tool_deployment.py new file mode 100644 index 000000000..ad565bf56 --- /dev/null +++ b/tests/api/cluster/test_tool_deployment.py @@ -0,0 +1,25 @@ +from unittest.mock import Mock, patch + +from controlpanel.api.cluster import ToolDeployment +from controlpanel.api.models import Tool, User + + +def test_get_installed_chart_version(): + user = User(username="test-user") + tool = Tool(chart_name="test-chart") + id_token = "dummy" + + installed_chart_version = "1.2.3" + + td = ToolDeployment(user, tool) + + deploy_metadata = Mock("k8s Deployment - metadata") + deploy_metadata.labels = { + "chart": f"{tool.chart_name}-{installed_chart_version}" + } + deploy = Mock("k8s Deployment", metadata=deploy_metadata) + + with patch("controlpanel.api.cluster.ToolDeployment.get_deployment") as get_deployment: + get_deployment.return_value = deploy + assert td.get_installed_chart_version(id_token) == installed_chart_version + get_deployment.assert_called_with(id_token) diff --git a/tests/api/fixtures/helm_mojanalytics_index.py b/tests/api/fixtures/helm_mojanalytics_index.py new file mode 100644 index 000000000..94c2b71a4 --- /dev/null +++ b/tests/api/fixtures/helm_mojanalytics_index.py @@ -0,0 +1,37 @@ +# Python dictionary version (excerpt) of what you'd find in the helm +# repository index YAML file at +# $(helm home)/repository/cache/mojanalytics-index.yaml +# +# used for testing the `helm` module +# +# (see `helm home --help`) +HELM_MOJANALYTICS_INDEX = { + "apiVersion": "v1", + "entries": { + "rstudio": [ + { + "apiVersion": "v1", + "appVersion": "RStudio: 1.2.1335+conda, R: 3.5.1, Python: 3.7.1, patch: 10", + "created": "2020-05-18T10:28:14.187538013Z", + "description": "RStudio with Auth0 authentication proxy", + "digest": "283e735476479425a76634840d73024f83e9d0bed7f009cb18b87916a3b84741", + "name": "rstudio", + "urls": [ + "http://moj-analytics-helm-repo.s3-website-eu-west-1.amazonaws.com/rstudio-2.2.5.tgz", + ], + "version": "2.2.5", + }, + { + "apiVersion": "v1", + "created": "2018-05-18T16:05:37.748243984Z", + "description": "A Helm chart for RStudio", + "digest": "a2df2dfe7aa0d04a6d7de175b134cc2e1e3e1b930f8b2acfdbda52fb396a4329", + "name": "rstudio", + "urls": [ + "https://ministryofjustice.github.io/analytics-platform-helm-charts/charts/rstudio-1.0.0.tgz", + ], + "version": "1.0.0", + }, + ], + }, +} diff --git a/tests/api/models/test_tool.py b/tests/api/models/test_tool.py index 476f2eda0..a6f74c426 100644 --- a/tests/api/models/test_tool.py +++ b/tests/api/models/test_tool.py @@ -48,3 +48,65 @@ def test_deploy_for_generic(helm, token_hex, tool, users): '--set', f'aws.iamRole={user.iam_role_name}', '--set', f'toolsDomain={settings.TOOLS_DOMAIN}', ) + + +@pytest.yield_fixture +def cluster(): + with patch("controlpanel.api.models.tool.cluster") as cluster: + yield cluster + + +@pytest.mark.parametrize( + "chart_version, expected_outdated", + [ + (None, False), + ("0.0.1", True), + ("1.0.0", False), + ], + ids=[ + "no-chart-version", + "old-chart-version", + "up-to-date-chart-version", + ], +) +def test_tool_deployment_outdated(cluster, chart_version, expected_outdated): + tool = Tool(chart_name="test-tool", version="1.0.0") + user = User(username="test-user") + td = ToolDeployment(tool, user) + id_token = "dummy" + + cluster_td = cluster.ToolDeployment.return_value + cluster_td.get_installed_chart_version.return_value = chart_version + + assert td.outdated(id_token) == expected_outdated + cluster.ToolDeployment.assert_called_with(user, tool) + cluster_td.get_installed_chart_version.assert_called_with(id_token) + + + + +@pytest.mark.parametrize( + "chart_version, expected_app_version", + [ + (None, None), + ("1.0.0", None), + ("2.2.5", "RStudio: 1.2.1335+conda, R: 3.5.1, Python: 3.7.1, patch: 10"), + ], + ids=[ + "no-chart-installed", + "old-chart-version", + "new-chart-version", + ], +) +def test_tool_deployment_get_installed_app_version(helm_repository_index, cluster, chart_version, expected_app_version): + tool = Tool(chart_name="rstudio") + user = User(username="test-user") + td = ToolDeployment(tool, user) + id_token = "dummy" + + cluster_td = cluster.ToolDeployment.return_value + cluster_td.get_installed_chart_version.return_value = chart_version + + assert td.get_installed_app_version(id_token) == expected_app_version + cluster.ToolDeployment.assert_called_with(user, tool) + cluster_td.get_installed_chart_version.assert_called_with(id_token) diff --git a/tests/api/test_helm.py b/tests/api/test_helm.py new file mode 100644 index 000000000..133078575 --- /dev/null +++ b/tests/api/test_helm.py @@ -0,0 +1,66 @@ +from datetime import datetime, timedelta +import pytest +from unittest.mock import patch + +from controlpanel.api.helm import ( + Chart, + HelmRepository, +) + + +def setup_function(fn): + print("Resetting HelmRepository._updated_at ...") + HelmRepository._updated_at = None + + +def test_chart_app_version(): + app_version = "RStudio: 1.2.1335+conda, R: 3.5.1, Python: 3.7.1, patch: 10" + chart = Chart( + "rstudio", + "RStudio with Auth0 authentication proxy", + "2.2.5", + app_version, + ) + + assert chart.app_version == app_version + + +def test_helm_repository_update_when_recently_updated(helm_repository_index): + HelmRepository._updated_at = datetime.utcnow() + + with patch("controlpanel.api.helm.Helm") as helm: + HelmRepository.update(force=False) + helm.execute.assert_not_called() + + +def test_helm_repository_update_when_cache_old(helm_repository_index): + yesterday = datetime.utcnow() - timedelta(days=1) + HelmRepository._updated_at = yesterday + + with patch("controlpanel.api.helm.Helm") as helm: + HelmRepository.update(force=False) + helm.execute.assert_called_once() + + +def test_helm_repository_chart_info_when_chart_not_found(helm_repository_index): + with patch("controlpanel.api.helm.open", helm_repository_index): + info = HelmRepository.get_chart_info("notfound") + assert info == {} + + +def test_helm_repository_chart_info_when_chart_found(helm_repository_index): + with patch("controlpanel.api.helm.open", helm_repository_index): + # See tests/api/fixtures/helm_mojanalytics_index.py + rstudio_info = HelmRepository.get_chart_info("rstudio") + + rstudio_2_2_5_app_version = "RStudio: 1.2.1335+conda, R: 3.5.1, Python: 3.7.1, patch: 10" + + assert len(rstudio_info) == 2 + assert "2.2.5" in rstudio_info + assert "1.0.0" in rstudio_info + + assert rstudio_info["2.2.5"].app_version == rstudio_2_2_5_app_version + # Helm added `appVersion` field in metadata only + # "recently" so for testing that for old chart + # version this returns `None` + assert rstudio_info["1.0.0"].app_version == None diff --git a/tests/conftest.py b/tests/conftest.py index 8a8df4ace..7b6a71023 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,11 @@ -from unittest.mock import patch +from unittest.mock import mock_open, patch from model_mommy import mommy +import yaml import pytest +from tests.api.fixtures.helm_mojanalytics_index import HELM_MOJANALYTICS_INDEX + @pytest.yield_fixture(autouse=True) def aws(): @@ -49,6 +52,15 @@ def helm(): yield helm +@pytest.fixture +def helm_repository_index(autouse=True): + """ + Mock the helm repository with some data + """ + content = yaml.dump(HELM_MOJANALYTICS_INDEX) + return mock_open(read_data=content) + + @pytest.yield_fixture(autouse=True) def slack_WebClient(): """ diff --git a/tests/frontend/test_consumers.py b/tests/frontend/test_consumers.py new file mode 100644 index 000000000..4456e83bb --- /dev/null +++ b/tests/frontend/test_consumers.py @@ -0,0 +1,180 @@ +import json + +import pytest +from unittest.mock import patch, Mock + +from controlpanel.api.models import Tool, ToolDeployment, User +from controlpanel.api.cluster import ( + TOOL_DEPLOYING, + TOOL_RESTARTING, + TOOL_UPGRADED, +) +from controlpanel.frontend import consumers + + +@pytest.fixture +def users(db): + print("Setting up users...") + User(auth0_id="github|1", username="alice").save() + User(auth0_id="github|2", username="bob").save() + + +@pytest.fixture +def tools(db): + print("Setting up tools...") + Tool(chart_name="a_tool").save() + Tool(chart_name="another_tool").save() + + +@pytest.yield_fixture +def update_tool_status(): + with patch("controlpanel.frontend.consumers.update_tool_status") as update_tool_status: + yield update_tool_status + + +@pytest.yield_fixture +def wait_for_deployment(): + with patch("controlpanel.frontend.consumers.wait_for_deployment") as wait_for_deployment: + yield wait_for_deployment + + +def test_tool_deploy(users, tools, update_tool_status, wait_for_deployment): + user = User.objects.first() + tool = Tool.objects.first() + id_token = "secret user id_token" + + with patch("controlpanel.frontend.consumers.ToolDeployment") as ToolDeployment: + tool_deployment = Mock() + ToolDeployment.return_value = tool_deployment + + consumer = consumers.BackgroundTaskConsumer("test") + consumer.tool_deploy( + message={ + "user_id": user.auth0_id, + "tool_name": tool.chart_name, + "id_token": id_token, + } + ) + + # 1. Instanciate `ToolDeployment` correctly + ToolDeployment.assert_called_with(tool, user) + # 2. Send status update + update_tool_status.assert_called_with( + tool_deployment, + id_token, + TOOL_DEPLOYING, + ) + # 3. Call save() on ToolDeployment (trigger deployment) + tool_deployment.save.assert_called() + # 4. Wait for deployment to complete + wait_for_deployment.assert_called_with(tool_deployment, id_token) + + +def test_tool_upgrade(users, tools, update_tool_status): + user = User.objects.first() + tool = Tool.objects.first() + id_token = "secret user id_token" + + with patch("controlpanel.frontend.consumers.ToolDeployment") as ToolDeployment: + tool_deployment = Mock() + ToolDeployment.return_value = tool_deployment + + message = { + "user_id": user.auth0_id, + "tool_name": tool.chart_name, + "id_token": id_token, + } + + consumer = consumers.BackgroundTaskConsumer("test") + consumer.tool_deploy = Mock() # mock tool_deploy() method + consumer.tool_upgrade(message=message) + + # 1. calls/reuse tool_deploy() + consumer.tool_deploy.assert_called_with(message) + # 2. Instanciate `ToolDeployment` correctly + ToolDeployment.assert_called_with(tool, user) + # 3. Send status update + update_tool_status.assert_called_with( + tool_deployment, + id_token, + TOOL_UPGRADED, + ) + + +def test_tool_restart(users, tools, update_tool_status, wait_for_deployment): + user = User.objects.first() + tool = Tool.objects.first() + id_token = "secret user id_token" + + with patch("controlpanel.frontend.consumers.ToolDeployment") as ToolDeployment: + tool_deployment = Mock() + ToolDeployment.return_value = tool_deployment + + consumer = consumers.BackgroundTaskConsumer("test") + consumer.tool_restart( + message={ + "user_id": user.auth0_id, + "tool_name": tool.chart_name, + "id_token": id_token, + } + ) + + # 1. Instanciate `ToolDeployment` correctly + ToolDeployment.assert_called_with(tool, user) + # 2. Send status update + update_tool_status.assert_called_with( + tool_deployment, + id_token, + TOOL_RESTARTING, + ) + # 3. Call restart() on ToolDeployment (trigger deployment) + tool_deployment.restart.assert_called_with(id_token=id_token) + # 4. Wait for deployment to complete + wait_for_deployment.assert_called_with(tool_deployment, id_token) + + +def test_get_tool_and_user(users, tools): + expected_user = User.objects.first() + expected_tool = Tool.objects.first() + message = { + "user_id": expected_user.auth0_id, + "tool_name": expected_tool.chart_name, + "id_token": "not used by this method", + } + + consumer = consumers.BackgroundTaskConsumer("test") + tool, user = consumer.get_tool_and_user(message) + assert expected_user == user + assert expected_tool == tool + + +def test_update_tool_status(): + tool = Tool(chart_name="a_tool", version="v1.0.0") + user = User(auth0_id="github|123") + id_token = "user id_token" + status = TOOL_UPGRADED + app_version = "R: 42, Python: 2.0.0" + + tool_deployment = Mock() + tool_deployment.tool = tool + tool_deployment.user = user + tool_deployment.get_installed_app_version.return_value = app_version + + expected_sse_event = { + "event": "toolStatus", + "data": json.dumps({ + "toolName": tool.chart_name, + "version": tool.version, + "appVersion": app_version, + "status": status, + }), + } + + with patch("controlpanel.frontend.consumers.send_sse") as send_sse: + consumers.update_tool_status( + tool_deployment, + id_token, + status, + ) + tool_deployment.get_installed_app_version.assert_called_with(id_token) + send_sse.assert_called_with(user.auth0_id, expected_sse_event)