From b066e579e9402e733542917ca3eb56e5a9d518b5 Mon Sep 17 00:00:00 2001 From: Ryan Soley Date: Fri, 12 Jul 2024 12:40:08 -0400 Subject: [PATCH] allow cred passthrough in sync --- rubicon_ml/client/rubicon.py | 37 ++++++++++++---- tests/unit/client/test_rubicon_client.py | 55 ++++++++++++++++++------ 2 files changed, 72 insertions(+), 20 deletions(-) diff --git a/rubicon_ml/client/rubicon.py b/rubicon_ml/client/rubicon.py index ced8b171..8018bb7e 100644 --- a/rubicon_ml/client/rubicon.py +++ b/rubicon_ml/client/rubicon.py @@ -1,3 +1,4 @@ +import os import subprocess import warnings from typing import Any, Dict, List, Optional, Tuple, Union @@ -298,7 +299,7 @@ def projects(self): raise return_err @failsafe - def sync(self, project_name, s3_root_dir): + def sync(self, project_name, s3_root_dir, aws_profile=None, aws_shared_credentials_file=None): """Sync a local project to S3. Parameters @@ -308,23 +309,43 @@ def sync(self, project_name, s3_root_dir): s3_root_dir : str The S3 path where the project's data will be synced. + aws_profile : str + Specifies the name of the AWS CLI profile with the credentials and options to use. + Defaults to None, using the AWS default name 'default'. + aws_shared_credentials_file : str + Specifies the location of the file that the AWS CLI uses to store access keys. + Defaults to None, using the AWS default path '~/.aws/credentials'. Notes ----- - Use to backup your local project data to S3, as an alternative to direct S3 logging. - Relies on AWS CLI's sync. Ensure that your credentials are set and that your Proxy - is on. + Use sync to backup your local project data to S3 as an alternative to direct S3 logging. + Leverages the AWS CLI's `aws s3 sync`. Ensure that any credentials are set and that any + proxies are enabled. """ if self.config.persistence != "filesystem": - raise RubiconException( - "You can't sync projects written to memory. Sync from either local filesystem or S3." - ) + raise RubiconException("Projects can only be synced from local or S3 filesystems.") + + cmd_root = "aws s3 sync" + + if aws_profile: + cmd_root += f" --profile {aws_profile}" + + original_aws_shared_credentials_file = os.environ.get("AWS_SHARED_CREDENTIALS_FILE") + + if aws_shared_credentials_file: + os.environ["AWS_SHARED_CREDENTIALS_FILE"] = aws_shared_credentials_file project = self.get_project(project_name) local_path = f"{self.config.root_dir}/{slugify(project.name)}" - cmd = f"aws s3 sync {local_path} {s3_root_dir}/{slugify(project.name)}" + cmd = f"{cmd_root} {local_path} {s3_root_dir}/{slugify(project.name)}" try: subprocess.run(cmd, shell=True, check=True, capture_output=True) except subprocess.CalledProcessError as e: raise RubiconException(e.stderr) + finally: + if aws_shared_credentials_file: + if original_aws_shared_credentials_file: + os.environ["AWS_SHARED_CREDENTIALS_FILE"] = original_aws_shared_credentials_file + else: + del os.environ["AWS_SHARED_CREDENTIALS_FILE"] diff --git a/tests/unit/client/test_rubicon_client.py b/tests/unit/client/test_rubicon_client.py index c24807de..643b4dcf 100644 --- a/tests/unit/client/test_rubicon_client.py +++ b/tests/unit/client/test_rubicon_client.py @@ -1,3 +1,4 @@ +import os import subprocess from unittest import mock @@ -198,43 +199,73 @@ def test_get_or_create_project(rubicon_client): assert created_project.id == fetched_project.id -def test_sync_from_memory(rubicon_and_project_client): - rubicon, project = rubicon_and_project_client +@mock.patch("subprocess.run") +@mock.patch("rubicon_ml.client.Rubicon.get_project") +def test_sync(mock_get_project, mock_run): + rubicon = Rubicon(persistence="filesystem", root_dir="./local/path") + project_name = "Sync Test Project" + mock_get_project.return_value = client.Project(domain.Project(project_name)) - with pytest.raises(RubiconException) as e: - rubicon.sync("Test Project", "s3://test/path") + rubicon.sync(project_name, "s3://test/path") - assert "can't sync projects written to memory" in str(e) + assert "aws s3 sync ./local/path/sync-test-project s3://test/path" in str( + mock_run._mock_call_args_list + ) @mock.patch("subprocess.run") @mock.patch("rubicon_ml.client.Rubicon.get_project") -def test_sync_from_local(mock_get_project, mock_run): +@pytest.mark.parametrize("default_cred_path", [None, "./default-creds"]) +def test_sync_aws_inputs(mock_get_project, mock_run, default_cred_path): rubicon = Rubicon(persistence="filesystem", root_dir="./local/path") project_name = "Sync Test Project" mock_get_project.return_value = client.Project(domain.Project(project_name)) - rubicon.sync(project_name, "s3://test/path") + if default_cred_path: + os.environ["AWS_SHARED_CREDENTIALS_FILE"] = default_cred_path - assert "aws s3 sync ./local/path/sync-test-project s3://test/path" in str( - mock_run._mock_call_args_list + rubicon.sync( + project_name, + "s3://test/path", + aws_profile="my-profile", + aws_shared_credentials_file="./my-creds", ) + assert ( + "aws s3 sync --profile my-profile " "./local/path/sync-test-project s3://test/path" + ) in str(mock_run._mock_call_args_list) + + if default_cred_path: + assert os.environ.get("AWS_SHARED_CREDENTIALS_FILE") == default_cred_path + else: + assert os.environ.get("AWS_SHARED_CREDENTIALS_FILE") is None + @mock.patch("subprocess.run") @mock.patch("rubicon_ml.client.Rubicon.get_project") -def test_sync_from_local_error(mock_get_project, mock_run): +def test_sync_cli_error(mock_get_project, mock_run): rubicon = Rubicon(persistence="filesystem", root_dir="./local/path") project_name = "Sync Test Project" mock_get_project.return_value = client.Project(domain.Project(project_name)) mock_run.side_effect = subprocess.CalledProcessError( - cmd="aws cli sync", stderr="Some error. I bet it was proxy tho.", returncode=1 + cmd="aws cli sync", + stderr="ERROR", + returncode=1, ) with pytest.raises(RubiconException) as e: rubicon.sync(project_name, "s3://test/path") - assert "Some error. I bet it was proxy tho." in str(e) + assert "ERROR" in str(e) + + +def test_sync_from_memory_error(rubicon_and_project_client): + rubicon, project = rubicon_and_project_client + + with pytest.raises(RubiconException) as e: + rubicon.sync("Test Project", "s3://test/path") + + assert "can only be synced from local or S3" in str(e) def test_get_project_as_dask_df(rubicon_and_project_client_with_experiments):