diff --git a/opensafely/pull.py b/opensafely/pull.py index 2e80b52..d3c28e4 100644 --- a/opensafely/pull.py +++ b/opensafely/pull.py @@ -1,6 +1,6 @@ +import argparse import subprocess import sys -from collections import defaultdict from http.cookiejar import split_header_words from pathlib import Path from urllib.parse import urlparse @@ -18,18 +18,29 @@ # The deprecated `databuilder` name is still supported by job-runner, but we don't want # it showing up here IMAGES = list(config.ALLOWED_IMAGES - {"databuilder"}) -FULL_IMAGES = {f"{REGISTRY}/{image}" for image in IMAGES} DEPRECATED_REGISTRIES = ["docker.opensafely.org", "ghcr.io/opensafely"] IMAGES.sort() # this is just for consistency for testing +def valid_image(image_string): + if image_string == "all": + return image_string + + name, _, _ = image_string.partition(":") + if name not in IMAGES: + raise argparse.ArgumentTypeError( + f"{image_string} is not a valid OpenSAFELY image: {','.join(IMAGES)}" + ) + + return image_string + + def add_arguments(parser): - choices = ["all"] + IMAGES parser.add_argument( "image", nargs="?", - choices=choices, help="OpenSAFELY docker image to update (default: all)", + type=valid_image, default="all", ) parser.add_argument( @@ -47,26 +58,36 @@ def main(image="all", force=False, project=None): if not docker_preflight_check(): return False + local_images = get_local_images() + if project: force = True images = get_actions_from_project_file(project) elif image == "all": - images = IMAGES + if force: + images = { + f"{name}:{get_default_version_for_image(name)}": None for name in IMAGES + } + else: + images = list(local_images) else: # if user has requested a specific image, pull it regardless force = True images = [image] - local_images = get_local_images() try: updated = False for image in images: - tag = f"{REGISTRY}/{image}" - if force or tag in local_images: + if force or image in local_images: + name, _, tag = image.partition(":") + if not tag: + tag = get_default_version_for_image(name) updated = True - print(f"Updating OpenSAFELY {image} image") - version = get_default_version_for_image(image) - subprocess.run(["docker", "pull", tag + f":{version}"], check=True) + + print(f"Updating OpenSAFELY {name}:{tag} image") + subprocess.run( + ["docker", "pull", f"{REGISTRY}/{name}:{tag}"], check=True + ) if updated: remove_deprecated_images(local_images) @@ -88,7 +109,7 @@ def get_actions_from_project_file(project_yaml): for name, action in project.actions.items(): if action.run.name in IMAGES and action.run.name not in images: - images.append(action.run.name) + images.append(f"{action.run.name}:{action.run.version}") if not images: raise RuntimeError(f"No actions found in {project_yaml}") @@ -106,19 +127,21 @@ def get_local_images(): "--filter", "label=org.opensafely.action", "--no-trunc", - "--format={{.Repository}}={{.ID}}", + "--format={{.Repository}}:{{.Tag}}={{.ID}}", ], check=True, text=True, capture_output=True, ) - images = defaultdict(list) + images = dict() for line in ps.stdout.splitlines(): if not line.strip(): continue - name, sha = line.split("=", 1) - images[name].append(sha) + line = line.replace("ghcr.io/opensafely-core/", "") + + image, sha = line.split("=", 1) + images[image] = sha return images @@ -135,6 +158,8 @@ def remove_deprecated_images(local_images): def get_default_version_for_image(name): if name in ["ehrql"]: return "v1" + elif name == "python": + return "v2" else: return "latest" @@ -143,22 +168,34 @@ def get_default_version_for_image(name): token = None -def get_remote_sha(full_name, tag): +def dockerhub_api(path): """Get the current sha for a tag from a docker registry.""" global token - parsed = urlparse("https://" + full_name) - manifest_url = f"https://ghcr.io/v2/{parsed.path}/manifests/{tag}" + url = f"https://ghcr.io/{path}" + # Docker API requires auth token, even for public resources. + # However, we can reuse a public token. if token is None: - # Docker API requires auth token, even for public resources. - # However, we can reuse a public token. - response = session.get(manifest_url) + response = session.get(url) + token = get_auth_token(response.headers["www-authenticate"]) + else: + response = session.get(url, headers={"Authorization": f"Bearer {token}"}) + + # refresh token if needed + if response.status_code == 401: token = get_auth_token(response.headers["www-authenticate"]) + response = session.get(url, headers={"Authorization": f"Bearer {token}"}) - response = session.get(manifest_url, headers={"Authorization": f"Bearer {token}"}) response.raise_for_status() - return response.json()["config"]["digest"] + return response.json() + + +def get_remote_sha(full_name, tag): + """Get the current sha for a tag from a docker registry.""" + parsed = urlparse("https://" + full_name) + response = dockerhub_api(f"/v2/{parsed.path}/manifests/{tag}") + return response["config"]["digest"] def get_auth_token(header): @@ -179,14 +216,11 @@ def check_version(): need_update = [] local_images = get_local_images() - for image in IMAGES: - full_name = f"{REGISTRY}/{image}" - local_shas = local_images.get(full_name, []) - if local_shas: - version = get_default_version_for_image(image) - latest_sha = get_remote_sha(full_name, version) - if latest_sha not in local_shas: - need_update.append(image) + for image, local_sha in local_images.items(): + name, _, tag = image.partition(":") + latest_sha = get_remote_sha(f"{REGISTRY}/{name}", tag) + if latest_sha != local_sha: + need_update.append(image) if need_update: print( diff --git a/tests/conftest.py b/tests/conftest.py index 0776733..ddea745 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -143,7 +143,7 @@ def run(monkeypatch): monkeypatch.setattr(subprocess, "run", fixture.run) yield fixture if len(fixture) != 0: - remaining = "\n".join(str(cmd) for cmd, _ in fixture) + remaining = "\n".join(str(f[0]) for f in fixture if f) raise AssertionError( f"run fixture had unused remaining expected cmds:\n{remaining}" ) diff --git a/tests/test_pull.py b/tests/test_pull.py index b17625b..3447766 100644 --- a/tests/test_pull.py +++ b/tests/test_pull.py @@ -22,7 +22,7 @@ def expect_local_images(run, stdout="", **kwargs): "--filter", "label=org.opensafely.action", "--no-trunc", - "--format={{.Repository}}={{.ID}}", + "--format={{.Repository}}:{{.Tag}}={{.ID}}", ], stdout=stdout, **kwargs, @@ -45,7 +45,7 @@ def test_default_no_local_images_force(run, capsys): run.expect(["docker", "pull", tag("cohortextractor")]) run.expect(["docker", "pull", tag("ehrql", version="v1")]) run.expect(["docker", "pull", tag("jupyter")]) - run.expect(["docker", "pull", tag("python")]) + run.expect(["docker", "pull", tag("python", version="v2")]) run.expect(["docker", "pull", tag("r")]) run.expect(["docker", "pull", tag("sqlrunner")]) run.expect(["docker", "pull", tag("stata-mp")]) @@ -64,20 +64,20 @@ def test_default_no_local_images_force(run, capsys): out, err = capsys.readouterr() assert err == "" assert out.splitlines() == [ - "Updating OpenSAFELY cohortextractor image", - "Updating OpenSAFELY ehrql image", - "Updating OpenSAFELY jupyter image", - "Updating OpenSAFELY python image", - "Updating OpenSAFELY r image", - "Updating OpenSAFELY sqlrunner image", - "Updating OpenSAFELY stata-mp image", + "Updating OpenSAFELY cohortextractor:latest image", + "Updating OpenSAFELY ehrql:v1 image", + "Updating OpenSAFELY jupyter:latest image", + "Updating OpenSAFELY python:v2 image", + "Updating OpenSAFELY r:latest image", + "Updating OpenSAFELY sqlrunner:latest image", + "Updating OpenSAFELY stata-mp:latest image", "Pruning old OpenSAFELY docker images...", ] def test_default_with_local_images(run, capsys): run.expect(["docker", "info"]) - expect_local_images(run, stdout="ghcr.io/opensafely-core/r=sha") + expect_local_images(run, stdout="ghcr.io/opensafely-core/r:latest=sha") run.expect(["docker", "pull", tag("r")]) run.expect( [ @@ -94,7 +94,7 @@ def test_default_with_local_images(run, capsys): out, err = capsys.readouterr() assert err == "" assert out.splitlines() == [ - "Updating OpenSAFELY r image", + "Updating OpenSAFELY r:latest image", "Pruning old OpenSAFELY docker images...", ] @@ -118,7 +118,7 @@ def test_specific_image(run, capsys): out, err = capsys.readouterr() assert err == "" assert out.splitlines() == [ - "Updating OpenSAFELY r image", + "Updating OpenSAFELY r:latest image", "Pruning old OpenSAFELY docker images...", ] @@ -144,9 +144,9 @@ def test_project(run, capsys): out, err = capsys.readouterr() assert err == "" assert out.splitlines() == [ - "Updating OpenSAFELY cohortextractor image", - "Updating OpenSAFELY python image", - "Updating OpenSAFELY jupyter image", + "Updating OpenSAFELY cohortextractor:latest image", + "Updating OpenSAFELY python:latest image", + "Updating OpenSAFELY jupyter:latest image", "Pruning old OpenSAFELY docker images...", ] @@ -169,14 +169,14 @@ def test_remove_deprecated_images(run): def test_check_version_out_of_date(run, capsys): expect_local_images( run, - stdout="ghcr.io/opensafely-core/python=sha256:oldsha", + stdout="ghcr.io/opensafely-core/python:latest=sha256:oldsha", ) assert len(pull.check_version()) == 1 out, err = capsys.readouterr() assert out == "" assert err.splitlines() == [ - "Warning: the OpenSAFELY docker images for python actions are out of date - please update by running:", + "Warning: the OpenSAFELY docker images for python:latest actions are out of date - please update by running:", " opensafely pull", "", ] @@ -185,28 +185,9 @@ def test_check_version_out_of_date(run, capsys): def test_check_version_up_to_date(run, capsys): current_sha = pull.get_remote_sha("ghcr.io/opensafely-core/python", "latest") pull.token = None - - expect_local_images( - run, - stdout=f"ghcr.io/opensafely-core/python={current_sha}", - ) - - assert len(pull.check_version()) == 0 - out, err = capsys.readouterr() - assert err == "" - assert out.splitlines() == [] - - -def test_check_version_up_to_date_old_sha(run, capsys): - current_sha = pull.get_remote_sha("ghcr.io/opensafely-core/python", "latest") - pull.token = None - expect_local_images( run, - stdout=( - f"ghcr.io/opensafely-core/python={current_sha}\n" - f"ghcr.io/opensafely-core/python=oldsha" - ), + stdout=f"ghcr.io/opensafely-core/python:latest={current_sha}", ) assert len(pull.check_version()) == 0