Skip to content

Commit

Permalink
allow cred passthrough in sync
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanSoley committed Jul 12, 2024
1 parent 3735436 commit b066e57
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 20 deletions.
37 changes: 29 additions & 8 deletions rubicon_ml/client/rubicon.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import subprocess
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union
Expand Down Expand Up @@ -298,7 +299,7 @@ def projects(self):
raise return_err

@failsafe
def sync(self, project_name, s3_root_dir):
def sync(self, project_name, s3_root_dir, aws_profile=None, aws_shared_credentials_file=None):
"""Sync a local project to S3.
Parameters
Expand All @@ -308,23 +309,43 @@ def sync(self, project_name, s3_root_dir):
s3_root_dir : str
The S3 path where the project's data
will be synced.
aws_profile : str
Specifies the name of the AWS CLI profile with the credentials and options to use.
Defaults to None, using the AWS default name 'default'.
aws_shared_credentials_file : str
Specifies the location of the file that the AWS CLI uses to store access keys.
Defaults to None, using the AWS default path '~/.aws/credentials'.
Notes
-----
Use to backup your local project data to S3, as an alternative to direct S3 logging.
Relies on AWS CLI's sync. Ensure that your credentials are set and that your Proxy
is on.
Use sync to backup your local project data to S3 as an alternative to direct S3 logging.
Leverages the AWS CLI's `aws s3 sync`. Ensure that any credentials are set and that any
proxies are enabled.
"""
if self.config.persistence != "filesystem":
raise RubiconException(
"You can't sync projects written to memory. Sync from either local filesystem or S3."
)
raise RubiconException("Projects can only be synced from local or S3 filesystems.")

cmd_root = "aws s3 sync"

if aws_profile:
cmd_root += f" --profile {aws_profile}"

original_aws_shared_credentials_file = os.environ.get("AWS_SHARED_CREDENTIALS_FILE")

if aws_shared_credentials_file:
os.environ["AWS_SHARED_CREDENTIALS_FILE"] = aws_shared_credentials_file

project = self.get_project(project_name)
local_path = f"{self.config.root_dir}/{slugify(project.name)}"
cmd = f"aws s3 sync {local_path} {s3_root_dir}/{slugify(project.name)}"
cmd = f"{cmd_root} {local_path} {s3_root_dir}/{slugify(project.name)}"

try:
subprocess.run(cmd, shell=True, check=True, capture_output=True)
except subprocess.CalledProcessError as e:
raise RubiconException(e.stderr)
finally:
if aws_shared_credentials_file:
if original_aws_shared_credentials_file:
os.environ["AWS_SHARED_CREDENTIALS_FILE"] = original_aws_shared_credentials_file
else:
del os.environ["AWS_SHARED_CREDENTIALS_FILE"]
55 changes: 43 additions & 12 deletions tests/unit/client/test_rubicon_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import subprocess
from unittest import mock

Expand Down Expand Up @@ -198,43 +199,73 @@ def test_get_or_create_project(rubicon_client):
assert created_project.id == fetched_project.id


def test_sync_from_memory(rubicon_and_project_client):
rubicon, project = rubicon_and_project_client
@mock.patch("subprocess.run")
@mock.patch("rubicon_ml.client.Rubicon.get_project")
def test_sync(mock_get_project, mock_run):
rubicon = Rubicon(persistence="filesystem", root_dir="./local/path")
project_name = "Sync Test Project"
mock_get_project.return_value = client.Project(domain.Project(project_name))

with pytest.raises(RubiconException) as e:
rubicon.sync("Test Project", "s3://test/path")
rubicon.sync(project_name, "s3://test/path")

assert "can't sync projects written to memory" in str(e)
assert "aws s3 sync ./local/path/sync-test-project s3://test/path" in str(
mock_run._mock_call_args_list
)


@mock.patch("subprocess.run")
@mock.patch("rubicon_ml.client.Rubicon.get_project")
def test_sync_from_local(mock_get_project, mock_run):
@pytest.mark.parametrize("default_cred_path", [None, "./default-creds"])
def test_sync_aws_inputs(mock_get_project, mock_run, default_cred_path):
rubicon = Rubicon(persistence="filesystem", root_dir="./local/path")
project_name = "Sync Test Project"
mock_get_project.return_value = client.Project(domain.Project(project_name))

rubicon.sync(project_name, "s3://test/path")
if default_cred_path:
os.environ["AWS_SHARED_CREDENTIALS_FILE"] = default_cred_path

assert "aws s3 sync ./local/path/sync-test-project s3://test/path" in str(
mock_run._mock_call_args_list
rubicon.sync(
project_name,
"s3://test/path",
aws_profile="my-profile",
aws_shared_credentials_file="./my-creds",
)

assert (
"aws s3 sync --profile my-profile " "./local/path/sync-test-project s3://test/path"
) in str(mock_run._mock_call_args_list)

if default_cred_path:
assert os.environ.get("AWS_SHARED_CREDENTIALS_FILE") == default_cred_path
else:
assert os.environ.get("AWS_SHARED_CREDENTIALS_FILE") is None


@mock.patch("subprocess.run")
@mock.patch("rubicon_ml.client.Rubicon.get_project")
def test_sync_from_local_error(mock_get_project, mock_run):
def test_sync_cli_error(mock_get_project, mock_run):
rubicon = Rubicon(persistence="filesystem", root_dir="./local/path")
project_name = "Sync Test Project"
mock_get_project.return_value = client.Project(domain.Project(project_name))
mock_run.side_effect = subprocess.CalledProcessError(
cmd="aws cli sync", stderr="Some error. I bet it was proxy tho.", returncode=1
cmd="aws cli sync",
stderr="ERROR",
returncode=1,
)

with pytest.raises(RubiconException) as e:
rubicon.sync(project_name, "s3://test/path")

assert "Some error. I bet it was proxy tho." in str(e)
assert "ERROR" in str(e)


def test_sync_from_memory_error(rubicon_and_project_client):
rubicon, project = rubicon_and_project_client

with pytest.raises(RubiconException) as e:
rubicon.sync("Test Project", "s3://test/path")

assert "can only be synced from local or S3" in str(e)


def test_get_project_as_dask_df(rubicon_and_project_client_with_experiments):
Expand Down

0 comments on commit b066e57

Please sign in to comment.