forked from googleapis/python-aiplatform
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add support for accepting an Artifact Registry URL in pipeline_…
…job (googleapis#1405) * Add support for Artifact Registry in template_path * fix typo * update tests * fix AR path * remove unused project * add code for refreshing credentials * add import for google.auth.transport * fix AR path * fix AR path * fix runtime_config * test removing v1beta1 * try using v1 directly instead * update to use v1beta1 * use select_version * add back template_uri * try adding back v1beta1 * use select_version * differentiate when to use select_version * test removing v1beta1 for pipeline_complete_states * add tests for creating pipelines using v1beta1 * fix merge * fix typo * fix lint using blacken * fix regex * update to use v1 instead of v1beta1 * add test for invalid url * update error type * implement failure_policy * use urllib.request instead of requests * Revert "implement failure_policy" This reverts commit 72cdd9e. * fix lint Co-authored-by: Anthonios Partheniou <[email protected]>
- Loading branch information
1 parent
82f678e
commit e138cfd
Showing
4 changed files
with
177 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,7 @@ | |
from unittest import mock | ||
from importlib import reload | ||
from unittest.mock import patch | ||
from urllib import request | ||
from datetime import datetime | ||
|
||
from google.auth import credentials as auth_credentials | ||
|
@@ -50,6 +51,7 @@ | |
_TEST_SERVICE_ACCOUNT = "[email protected]" | ||
|
||
_TEST_TEMPLATE_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}/job_spec.json" | ||
_TEST_AR_TEMPLATE_PATH = "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest" | ||
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" | ||
_TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_PIPELINE_JOB_ID}" | ||
|
||
|
@@ -289,6 +291,17 @@ def mock_load_yaml_and_json(job_spec): | |
yield mock_load_yaml_and_json | ||
|
||
|
||
@pytest.fixture | ||
def mock_request_urlopen(job_spec): | ||
with patch.object(request, "urlopen") as mock_urlopen: | ||
mock_read_response = mock.MagicMock() | ||
mock_decode_response = mock.MagicMock() | ||
mock_decode_response.return_value = job_spec.encode() | ||
mock_read_response.return_value.decode = mock_decode_response | ||
mock_urlopen.return_value.read = mock_read_response | ||
yield mock_urlopen | ||
|
||
|
||
@pytest.mark.usefixtures("google_auth_mock") | ||
class TestPipelineJob: | ||
def setup_method(self): | ||
|
@@ -376,6 +389,85 @@ def test_run_call_pipeline_service_create( | |
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED | ||
) | ||
|
||
@pytest.mark.parametrize( | ||
"job_spec", | ||
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB], | ||
) | ||
@pytest.mark.parametrize("sync", [True, False]) | ||
def test_run_call_pipeline_service_create_artifact_registry( | ||
self, | ||
mock_pipeline_service_create, | ||
mock_pipeline_service_get, | ||
mock_request_urlopen, | ||
job_spec, | ||
mock_load_yaml_and_json, | ||
sync, | ||
): | ||
aiplatform.init( | ||
project=_TEST_PROJECT, | ||
staging_bucket=_TEST_GCS_BUCKET_NAME, | ||
location=_TEST_LOCATION, | ||
credentials=_TEST_CREDENTIALS, | ||
) | ||
|
||
job = pipeline_jobs.PipelineJob( | ||
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, | ||
template_path=_TEST_AR_TEMPLATE_PATH, | ||
job_id=_TEST_PIPELINE_JOB_ID, | ||
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES, | ||
enable_caching=True, | ||
) | ||
|
||
job.run( | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
network=_TEST_NETWORK, | ||
sync=sync, | ||
create_request_timeout=None, | ||
) | ||
|
||
if not sync: | ||
job.wait() | ||
|
||
expected_runtime_config_dict = { | ||
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME, | ||
"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES, | ||
} | ||
runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb | ||
json_format.ParseDict(expected_runtime_config_dict, runtime_config) | ||
|
||
job_spec = yaml.safe_load(job_spec) | ||
pipeline_spec = job_spec.get("pipelineSpec") or job_spec | ||
|
||
# Construct expected request | ||
expected_gapic_pipeline_job = gca_pipeline_job.PipelineJob( | ||
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, | ||
pipeline_spec={ | ||
"components": {}, | ||
"pipelineInfo": pipeline_spec["pipelineInfo"], | ||
"root": pipeline_spec["root"], | ||
"schemaVersion": "2.1.0", | ||
}, | ||
runtime_config=runtime_config, | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
network=_TEST_NETWORK, | ||
template_uri=_TEST_AR_TEMPLATE_PATH, | ||
) | ||
|
||
mock_pipeline_service_create.assert_called_once_with( | ||
parent=_TEST_PARENT, | ||
pipeline_job=expected_gapic_pipeline_job, | ||
pipeline_job_id=_TEST_PIPELINE_JOB_ID, | ||
timeout=None, | ||
) | ||
|
||
mock_pipeline_service_get.assert_called_with( | ||
name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY | ||
) | ||
|
||
assert job._gca_resource == make_pipeline_job( | ||
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED | ||
) | ||
|
||
@pytest.mark.parametrize( | ||
"job_spec", | ||
[ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters