Commit

feat: add support for accepting an Artifact Registry URL in pipeline_job (googleapis#1405)

* Add support for Artifact Registry in template_path

* fix typo

* update tests

* fix AR path

* remove unused project

* add code for refreshing credentials

* add import for google.auth.transport

* fix AR path

* fix AR path

* fix runtime_config

* test removing v1beta1

* try using v1 directly instead

* update to use v1beta1

* use select_version

* add back template_uri

* try adding back v1beta1

* use select_version

* differentiate when to use select_version

* test removing v1beta1 for pipeline_complete_states

* add tests for creating pipelines using v1beta1

* fix merge

* fix typo

* fix lint using blacken

* fix regex

* update to use v1 instead of v1beta1

* add test for invalid url

* update error type

* implement failure_policy

* use urllib.request instead of requests

* Revert "implement failure_policy"

This reverts commit 72cdd9e.

* fix lint

Co-authored-by: Anthonios Partheniou <[email protected]>
chongyouquan and parthea authored Jun 17, 2022
1 parent 82f678e commit e138cfd
Showing 4 changed files with 177 additions and 11 deletions.
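
For context, here is a minimal usage sketch of the new capability (not part of the commit). It assumes a compiled pipeline template has already been uploaded to a Kubeflow Pipelines repository in Artifact Registry; the project, bucket, repository, package, and tag names are placeholders.

from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")

# template_path may now be an Artifact Registry URL (typically of the form
# https://{region}-kfp.pkg.dev/{project}/{repo}/{package}/{tag-or-version}),
# in addition to a local path or a gs:// URI.
job = aiplatform.PipelineJob(
    display_name="my-pipeline",
    template_path="https://us-central1-kfp.pkg.dev/my-project/my-repo/my-pipeline/latest",
    pipeline_root="gs://my-bucket/pipeline-root",
)
job.run()  # the template is fetched from Artifact Registry and submitted to Vertex AI
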
27 changes: 18 additions & 9 deletions google/cloud/aiplatform/pipeline_jobs.py
@@ -56,6 +56,9 @@
# Pattern for valid names used as a Vertex resource name.
_VALID_NAME_PATTERN = re.compile("^[a-z][-a-z0-9]{0,127}$")

# Pattern for an Artifact Registry URL.
_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*")


def _get_current_time() -> datetime.datetime:
"""Gets the current timestamp."""
@@ -125,8 +128,9 @@ def __init__(
Required. The user-defined name of this Pipeline.
template_path (str):
Required. The path of PipelineJob or PipelineSpec JSON or YAML file. It
-                can be a local path or a Google Cloud Storage URI.
-                Example: "gs://project.name"
+                can be a local path, a Google Cloud Storage URI (e.g. "gs://project.name"),
+                or an Artifact Registry URI (e.g.
+                "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest").
job_id (str):
Optional. The unique ID of the job run.
If not specified, pipeline name + timestamp will be used.
Expand Down Expand Up @@ -237,15 +241,20 @@ def __init__(
if enable_caching is not None:
_set_enable_caching_value(pipeline_job["pipelineSpec"], enable_caching)

-        self._gca_resource = gca_pipeline_job.PipelineJob(
-            display_name=display_name,
-            pipeline_spec=pipeline_job["pipelineSpec"],
-            labels=labels,
-            runtime_config=runtime_config,
-            encryption_spec=initializer.global_config.get_encryption_spec(
+        pipeline_job_args = {
+            "display_name": display_name,
+            "pipeline_spec": pipeline_job["pipelineSpec"],
+            "labels": labels,
+            "runtime_config": runtime_config,
+            "encryption_spec": initializer.global_config.get_encryption_spec(
                encryption_spec_key_name=encryption_spec_key_name
            ),
-        )
+        }
+
+        if _VALID_AR_URL.match(template_path):
+            pipeline_job_args["template_uri"] = template_path
+
+        self._gca_resource = gca_pipeline_job.PipelineJob(**pipeline_job_args)

@base.optional_sync()
def run(
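
As an aside, a small sketch (not part of the commit) of which template paths the new _VALID_AR_URL pattern accepts; the URIs below are made-up examples.

import re

_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*")

# A Kubeflow Pipelines repository URL in Artifact Registry matches ...
assert _VALID_AR_URL.match("https://us-central1-kfp.pkg.dev/proj/repo/pack/latest")

# ... while a gs:// URI or a non-KFP Artifact Registry URL does not, so those
# paths keep going through the existing Cloud Storage / local-file handling.
assert not _VALID_AR_URL.match("gs://my-bucket/pipeline_spec.json")
assert not _VALID_AR_URL.match("https://us-docker.pkg.dev/v2/proj/repo/img/tags/list")
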
42 changes: 42 additions & 0 deletions google/cloud/aiplatform/utils/yaml_utils.py
@@ -15,11 +15,17 @@
# limitations under the License.
#

import re
from typing import Any, Dict, Optional
from urllib import request

from google.auth import credentials as auth_credentials
from google.auth import transport
from google.cloud import storage

# Pattern for an Artifact Registry URL.
_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*")


def load_yaml(
path: str,
@@ -42,6 +48,8 @@ def load_yaml(
"""
if path.startswith("gs://"):
return _load_yaml_from_gs_uri(path, project, credentials)
elif _VALID_AR_URL.match(path):
return _load_yaml_from_ar_uri(path, credentials)
else:
return _load_yaml_from_local_file(path)

@@ -95,3 +103,37 @@ def _load_yaml_from_local_file(file_path: str) -> Dict[str, Any]:
)
with open(file_path) as f:
return yaml.safe_load(f)


def _load_yaml_from_ar_uri(
uri: str,
credentials: Optional[auth_credentials.Credentials] = None,
) -> Dict[str, Any]:
"""Loads data from a YAML document referenced by a Artifact Registry URI.
Args:
path (str):
Required. Artifact Registry URI for YAML document.
credentials (auth_credentials.Credentials):
Optional. Credentials to use with Artifact Registry.
Returns:
A Dict object representing the YAML document.
"""
try:
import yaml
except ImportError:
raise ImportError(
"pyyaml is not installed and is required to parse PipelineJob or PipelineSpec files. "
'Please install the SDK using "pip install google-cloud-aiplatform[pipelines]"'
)
req = request.Request(uri)

if credentials:
if not credentials.valid:
credentials.refresh(transport.requests.Request())
if credentials.token:
req.add_header("Authorization", "Bearer " + credentials.token)
response = request.urlopen(req)

return yaml.safe_load(response.read().decode("utf-8"))
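
A short sketch (not part of the commit) of how the new branch in load_yaml would be exercised. yaml_utils is an internal helper module, the URI is a placeholder, and the use of Application Default Credentials is an assumption for illustration; per the code above, invalid credentials are refreshed and the token is sent as a Bearer Authorization header.

import google.auth
from google.cloud.aiplatform.utils import yaml_utils

# Application Default Credentials (assumed; any Credentials object works).
credentials, _ = google.auth.default()

# An Artifact Registry URI is routed to _load_yaml_from_ar_uri, which fetches
# the template over HTTPS and parses it with yaml.safe_load.
spec = yaml_utils.load_yaml(
    "https://us-central1-kfp.pkg.dev/my-project/my-repo/my-pipeline/latest",
    credentials=credentials,
)
print(sorted(spec.keys()))
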
92 changes: 92 additions & 0 deletions tests/unit/aiplatform/test_pipeline_jobs.py
@@ -22,6 +22,7 @@
from unittest import mock
from importlib import reload
from unittest.mock import patch
from urllib import request
from datetime import datetime

from google.auth import credentials as auth_credentials
@@ -50,6 +51,7 @@
_TEST_SERVICE_ACCOUNT = "[email protected]"

_TEST_TEMPLATE_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}/job_spec.json"
_TEST_AR_TEMPLATE_PATH = "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
_TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_PIPELINE_JOB_ID}"

Expand Down Expand Up @@ -289,6 +291,17 @@ def mock_load_yaml_and_json(job_spec):
yield mock_load_yaml_and_json


@pytest.fixture
def mock_request_urlopen(job_spec):
with patch.object(request, "urlopen") as mock_urlopen:
mock_read_response = mock.MagicMock()
mock_decode_response = mock.MagicMock()
mock_decode_response.return_value = job_spec.encode()
mock_read_response.return_value.decode = mock_decode_response
mock_urlopen.return_value.read = mock_read_response
yield mock_urlopen


@pytest.mark.usefixtures("google_auth_mock")
class TestPipelineJob:
def setup_method(self):
@@ -376,6 +389,85 @@ def test_run_call_pipeline_service_create(
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
)

@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
)
@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create_artifact_registry(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
mock_request_urlopen,
job_spec,
mock_load_yaml_and_json,
sync,
):
aiplatform.init(
project=_TEST_PROJECT,
staging_bucket=_TEST_GCS_BUCKET_NAME,
location=_TEST_LOCATION,
credentials=_TEST_CREDENTIALS,
)

job = pipeline_jobs.PipelineJob(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
template_path=_TEST_AR_TEMPLATE_PATH,
job_id=_TEST_PIPELINE_JOB_ID,
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
enable_caching=True,
)

job.run(
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
sync=sync,
create_request_timeout=None,
)

if not sync:
job.wait()

expected_runtime_config_dict = {
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
}
runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(expected_runtime_config_dict, runtime_config)

job_spec = yaml.safe_load(job_spec)
pipeline_spec = job_spec.get("pipelineSpec") or job_spec

# Construct expected request
expected_gapic_pipeline_job = gca_pipeline_job.PipelineJob(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
pipeline_spec={
"components": {},
"pipelineInfo": pipeline_spec["pipelineInfo"],
"root": pipeline_spec["root"],
"schemaVersion": "2.1.0",
},
runtime_config=runtime_config,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
template_uri=_TEST_AR_TEMPLATE_PATH,
)

mock_pipeline_service_create.assert_called_once_with(
parent=_TEST_PARENT,
pipeline_job=expected_gapic_pipeline_job,
pipeline_job_id=_TEST_PIPELINE_JOB_ID,
timeout=None,
)

mock_pipeline_service_get.assert_called_with(
name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
)

assert job._gca_resource == make_pipeline_job(
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
)

@pytest.mark.parametrize(
"job_spec",
[
27 changes: 25 additions & 2 deletions tests/unit/aiplatform/test_utils.py
@@ -20,6 +20,8 @@
import json
import os
from typing import Callable, Dict, Optional
from unittest import mock
from urllib import request

import pytest
import yaml
@@ -564,13 +566,34 @@ def json_file(tmp_path):
yield json_file_path


@pytest.fixture(scope="function")
def mock_request_urlopen():
data = {"key": "val", "list": ["1", 2, 3.0]}
with mock.patch.object(request, "urlopen") as mock_urlopen:
mock_read_response = mock.MagicMock()
mock_decode_response = mock.MagicMock()
mock_decode_response.return_value = json.dumps(data)
mock_read_response.return_value.decode = mock_decode_response
mock_urlopen.return_value.read = mock_read_response
yield "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"


class TestYamlUtils:
-    def test_load_yaml_from_local_file__with_json(self, yaml_file):
+    def test_load_yaml_from_local_file__with_yaml(self, yaml_file):
actual = yaml_utils.load_yaml(yaml_file)
expected = {"key": "val", "list": ["1", 2, 3.0]}
assert actual == expected

-    def test_load_yaml_from_local_file__with_yaml(self, json_file):
+    def test_load_yaml_from_local_file__with_json(self, json_file):
actual = yaml_utils.load_yaml(json_file)
expected = {"key": "val", "list": ["1", 2, 3.0]}
assert actual == expected

def test_load_yaml_from_ar_uri(self, mock_request_urlopen):
actual = yaml_utils.load_yaml(mock_request_urlopen)
expected = {"key": "val", "list": ["1", 2, 3.0]}
assert actual == expected

def test_load_yaml_from_invalid_uri(self):
with pytest.raises(FileNotFoundError):
yaml_utils.load_yaml("https://us-docker.pkg.dev/v2/proj/repo/img/tags/list")
