From e138cfd8b8f8033a562c1b7f2d340042f57db27e Mon Sep 17 00:00:00 2001 From: chongyouquan <48691403+chongyouquan@users.noreply.github.com> Date: Fri, 17 Jun 2022 14:51:13 -0700 Subject: [PATCH] feat: add support for accepting an Artifact Registry URL in pipeline_job (#1405) * Add support for Artifact Registry in template_path * fix typo * update tests * fix AR path * remove unused project * add code for refreshing credentials * add import for google.auth.transport * fix AR path * fix AR path * fix runtime_config * test removing v1beta1 * try using v1 directly instead * update to use v1beta1 * use select_version * add back template_uri * try adding back v1beta1 * use select_version * differentiate when to use select_version * test removing v1beta1 for pipeline_complete_states * add tests for creating pipelines using v1beta1 * fix merge * fix typo * fix lint using blacken * fix regex * update to use v1 instead of v1beta1 * add test for invalid url * update error type * implement failure_policy * use urllib.request instead of requests * Revert "implement failure_policy" This reverts commit 72cdd9ef60f10192f4c80669f5d2aaa448e9da76. * fix lint Co-authored-by: Anthonios Partheniou --- google/cloud/aiplatform/pipeline_jobs.py | 27 ++++-- google/cloud/aiplatform/utils/yaml_utils.py | 42 ++++++++++ tests/unit/aiplatform/test_pipeline_jobs.py | 92 +++++++++++++++++++++ tests/unit/aiplatform/test_utils.py | 27 +++++- 4 files changed, 177 insertions(+), 11 deletions(-) diff --git a/google/cloud/aiplatform/pipeline_jobs.py b/google/cloud/aiplatform/pipeline_jobs.py index a1ea72e8fd..0e5fcf74b2 100644 --- a/google/cloud/aiplatform/pipeline_jobs.py +++ b/google/cloud/aiplatform/pipeline_jobs.py @@ -56,6 +56,9 @@ # Pattern for valid names used as a Vertex resource name. _VALID_NAME_PATTERN = re.compile("^[a-z][-a-z0-9]{0,127}$") +# Pattern for an Artifact Registry URL. 
+_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*") + def _get_current_time() -> datetime.datetime: """Gets the current timestamp.""" @@ -125,8 +128,9 @@ def __init__( Required. The user-defined name of this Pipeline. template_path (str): Required. The path of PipelineJob or PipelineSpec JSON or YAML file. It - can be a local path or a Google Cloud Storage URI. - Example: "gs://project.name" + can be a local path, a Google Cloud Storage URI (e.g. "gs://project.name"), + or an Artifact Registry URI (e.g. + "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"). job_id (str): Optional. The unique ID of the job run. If not specified, pipeline name + timestamp will be used. @@ -237,15 +241,20 @@ def __init__( if enable_caching is not None: _set_enable_caching_value(pipeline_job["pipelineSpec"], enable_caching) - self._gca_resource = gca_pipeline_job.PipelineJob( - display_name=display_name, - pipeline_spec=pipeline_job["pipelineSpec"], - labels=labels, - runtime_config=runtime_config, - encryption_spec=initializer.global_config.get_encryption_spec( + pipeline_job_args = { + "display_name": display_name, + "pipeline_spec": pipeline_job["pipelineSpec"], + "labels": labels, + "runtime_config": runtime_config, + "encryption_spec": initializer.global_config.get_encryption_spec( encryption_spec_key_name=encryption_spec_key_name ), - ) + } + + if _VALID_AR_URL.match(template_path): + pipeline_job_args["template_uri"] = template_path + + self._gca_resource = gca_pipeline_job.PipelineJob(**pipeline_job_args) @base.optional_sync() def run( diff --git a/google/cloud/aiplatform/utils/yaml_utils.py b/google/cloud/aiplatform/utils/yaml_utils.py index 29d41e56ab..bac33733dc 100644 --- a/google/cloud/aiplatform/utils/yaml_utils.py +++ b/google/cloud/aiplatform/utils/yaml_utils.py @@ -15,11 +15,17 @@ # limitations under the License. 
# +import re from typing import Any, Dict, Optional +from urllib import request from google.auth import credentials as auth_credentials +from google.auth import transport from google.cloud import storage +# Pattern for an Artifact Registry URL. +_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*") + def load_yaml( path: str, @@ -42,6 +48,8 @@ """ if path.startswith("gs://"): return _load_yaml_from_gs_uri(path, project, credentials) + elif _VALID_AR_URL.match(path): + return _load_yaml_from_ar_uri(path, credentials) else: return _load_yaml_from_local_file(path) @@ -95,3 +103,37 @@ def _load_yaml_from_local_file(file_path: str) -> Dict[str, Any]: ) with open(file_path) as f: return yaml.safe_load(f) + + +def _load_yaml_from_ar_uri( + uri: str, + credentials: Optional[auth_credentials.Credentials] = None, +) -> Dict[str, Any]: + """Loads data from a YAML document referenced by an Artifact Registry URI. + + Args: + uri (str): + Required. Artifact Registry URI for YAML document. + credentials (auth_credentials.Credentials): + Optional. Credentials to use with Artifact Registry. + + Returns: + A Dict object representing the YAML document. + """ + try: + import yaml + except ImportError: + raise ImportError( + "pyyaml is not installed and is required to parse PipelineJob or PipelineSpec files. 
" + 'Please install the SDK using "pip install google-cloud-aiplatform[pipelines]"' + ) + req = request.Request(uri) + + if credentials: + if not credentials.valid: + credentials.refresh(transport.requests.Request()) + if credentials.token: + req.add_header("Authorization", "Bearer " + credentials.token) + response = request.urlopen(req) + + return yaml.safe_load(response.read().decode("utf-8")) diff --git a/tests/unit/aiplatform/test_pipeline_jobs.py b/tests/unit/aiplatform/test_pipeline_jobs.py index 9bfda28353..78fa4d926a 100644 --- a/tests/unit/aiplatform/test_pipeline_jobs.py +++ b/tests/unit/aiplatform/test_pipeline_jobs.py @@ -22,6 +22,7 @@ from unittest import mock from importlib import reload from unittest.mock import patch +from urllib import request from datetime import datetime from google.auth import credentials as auth_credentials @@ -50,6 +51,7 @@ _TEST_SERVICE_ACCOUNT = "abcde@my-project.iam.gserviceaccount.com" _TEST_TEMPLATE_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}/job_spec.json" +_TEST_AR_TEMPLATE_PATH = "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest" _TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" _TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_PIPELINE_JOB_ID}" @@ -289,6 +291,17 @@ def mock_load_yaml_and_json(job_spec): yield mock_load_yaml_and_json +@pytest.fixture +def mock_request_urlopen(job_spec): + with patch.object(request, "urlopen") as mock_urlopen: + mock_read_response = mock.MagicMock() + mock_decode_response = mock.MagicMock() + mock_decode_response.return_value = job_spec.encode() + mock_read_response.return_value.decode = mock_decode_response + mock_urlopen.return_value.read = mock_read_response + yield mock_urlopen + + @pytest.mark.usefixtures("google_auth_mock") class TestPipelineJob: def setup_method(self): @@ -376,6 +389,85 @@ def test_run_call_pipeline_service_create( gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED ) + @pytest.mark.parametrize( + "job_spec", + 
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB], + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_artifact_registry( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_request_urlopen, + job_spec, + mock_load_yaml_and_json, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_GCS_BUCKET_NAME, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + + job = pipeline_jobs.PipelineJob( + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, + template_path=_TEST_AR_TEMPLATE_PATH, + job_id=_TEST_PIPELINE_JOB_ID, + parameter_values=_TEST_PIPELINE_PARAMETER_VALUES, + enable_caching=True, + ) + + job.run( + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + sync=sync, + create_request_timeout=None, + ) + + if not sync: + job.wait() + + expected_runtime_config_dict = { + "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME, + "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES, + } + runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb + json_format.ParseDict(expected_runtime_config_dict, runtime_config) + + job_spec = yaml.safe_load(job_spec) + pipeline_spec = job_spec.get("pipelineSpec") or job_spec + + # Construct expected request + expected_gapic_pipeline_job = gca_pipeline_job.PipelineJob( + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, + pipeline_spec={ + "components": {}, + "pipelineInfo": pipeline_spec["pipelineInfo"], + "root": pipeline_spec["root"], + "schemaVersion": "2.1.0", + }, + runtime_config=runtime_config, + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + template_uri=_TEST_AR_TEMPLATE_PATH, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=_TEST_PARENT, + pipeline_job=expected_gapic_pipeline_job, + pipeline_job_id=_TEST_PIPELINE_JOB_ID, + timeout=None, + ) + + mock_pipeline_service_get.assert_called_with( + name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY + ) + + 
assert job._gca_resource == make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ) + @pytest.mark.parametrize( "job_spec", [ diff --git a/tests/unit/aiplatform/test_utils.py b/tests/unit/aiplatform/test_utils.py index c467a28b05..6d59c16dc3 100644 --- a/tests/unit/aiplatform/test_utils.py +++ b/tests/unit/aiplatform/test_utils.py @@ -20,6 +20,8 @@ import json import os from typing import Callable, Dict, Optional +from unittest import mock +from urllib import request import pytest import yaml @@ -564,13 +566,34 @@ def json_file(tmp_path): yield json_file_path +@pytest.fixture(scope="function") +def mock_request_urlopen(): + data = {"key": "val", "list": ["1", 2, 3.0]} + with mock.patch.object(request, "urlopen") as mock_urlopen: + mock_read_response = mock.MagicMock() + mock_decode_response = mock.MagicMock() + mock_decode_response.return_value = json.dumps(data) + mock_read_response.return_value.decode = mock_decode_response + mock_urlopen.return_value.read = mock_read_response + yield "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest" + + class TestYamlUtils: - def test_load_yaml_from_local_file__with_json(self, yaml_file): + def test_load_yaml_from_local_file__with_yaml(self, yaml_file): actual = yaml_utils.load_yaml(yaml_file) expected = {"key": "val", "list": ["1", 2, 3.0]} assert actual == expected - def test_load_yaml_from_local_file__with_yaml(self, json_file): + def test_load_yaml_from_local_file__with_json(self, json_file): actual = yaml_utils.load_yaml(json_file) expected = {"key": "val", "list": ["1", 2, 3.0]} assert actual == expected + + def test_load_yaml_from_ar_uri(self, mock_request_urlopen): + actual = yaml_utils.load_yaml(mock_request_urlopen) + expected = {"key": "val", "list": ["1", 2, 3.0]} + assert actual == expected + + def test_load_yaml_from_invalid_uri(self): + with pytest.raises(FileNotFoundError): + yaml_utils.load_yaml("https://us-docker.pkg.dev/v2/proj/repo/img/tags/list")