Showing 14 changed files with 934 additions and 64 deletions.
@@ -0,0 +1,30 @@
name: Python application

on:
  push:
    branches: ['*']

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Quoted so YAML does not parse 3.10 as the float 3.1.
        python-version: ["3.8", "3.9", "3.10", "3.11"]

    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install poetry
        run: |
          python -m pip install --upgrade pip
          pip install poetry
      - name: Install dependencies
        run: |
          poetry install --without torch
      - name: Test
        run: |
          # Run pytest inside the poetry-managed virtualenv.
          poetry run pytest
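One detail worth calling out in the workflow above: the Python versions in the matrix must be quoted, because YAML parses the bare token 3.10 as the float 3.1, so an unquoted matrix would silently test Python 3.1 instead of 3.10. A quick check with PyYAML (assuming it is installed) shows the difference:

import yaml

# Unquoted, 3.10 is read as the float 3.1.
print(yaml.safe_load("versions: [3.8, 3.9, 3.10, 3.11]"))
# {'versions': [3.8, 3.9, 3.1, 3.11]}

# Quoted entries survive as strings.
print(yaml.safe_load('versions: ["3.8", "3.9", "3.10", "3.11"]'))
# {'versions': ['3.8', '3.9', '3.10', '3.11']}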
Empty file.
@@ -0,0 +1,22 @@
from cascade.executors.vertex.resource import (
    GcpEnvironmentConfig,
    GcpMachineConfig,
    GcpResource,
)

# Shared GCP test constants.
GCP_PROJECT = "test-project"
TEST_BUCKET = GCP_PROJECT
REGION = "us-west1"

# A minimal GCP resource built from the constants above.
chief_machine = GcpMachineConfig("n1-standard-4", 1)
gcp_environment = GcpEnvironmentConfig(
    project=GCP_PROJECT,
    storage_location="gs://bucket/path/to/file",
    region=REGION,
    service_account=f"{GCP_PROJECT}@{GCP_PROJECT}.iam.gserviceaccount.com",
    image="cascade",
)
gcp_resource = GcpResource(
    chief=chief_machine,
    environment=gcp_environment,
)
@@ -0,0 +1,143 @@
from pyfakefs.fake_filesystem import FakeFilesystem
import pytest

from cascade.config import find_default_configuration
from cascade.executors.databricks.resource import (
    DatabricksAutoscaleConfig,
    DatabricksResource,
)
from cascade.executors.vertex.resource import (
    GcpEnvironmentConfig,
    GcpMachineConfig,
    GcpResource,
)


@pytest.fixture(params=["cascade.yaml", "cascade.yml"])
def configuration_filename(request):
    return request.param


@pytest.fixture()
def gcp_project():
    return "test-project"


@pytest.fixture()
def gcp_location():
    return "us-central1"


@pytest.fixture()
def gcp_service_account():
    return "[email protected]"


@pytest.fixture()
def gcp_machine_config():
    return GcpMachineConfig(type="n1-standard-4", count=2)


@pytest.fixture()
def gcp_environment(gcp_project, gcp_location, gcp_service_account):
    return GcpEnvironmentConfig(
        project=gcp_project, service_account=gcp_service_account, region=gcp_location
    )


@pytest.fixture()
def gcp_resource(gcp_environment, gcp_machine_config):
    return GcpResource(chief=gcp_machine_config, environment=gcp_environment)


@pytest.fixture()
def databricks_resource():
    return DatabricksResource(
        worker_count=DatabricksAutoscaleConfig(min_workers=5, max_workers=10),
        cloud_pickle_by_value=["a", "b"],
    )


@pytest.fixture()
def test_job_name():
    return "hello-world"


def test_no_configuration():
    assert find_default_configuration() is None


def test_invalid_type_specified(fs: FakeFilesystem, configuration_filename: str):
    configuration = """
    addition:
      type: AwsResource
    """
    fs.create_file(configuration_filename, contents=configuration)
    with pytest.raises(ValueError):
        find_default_configuration()


def test_gcp_resource(
    fs: FakeFilesystem,
    configuration_filename: str,
    gcp_resource: GcpResource,
    test_job_name: str,
):
    configuration = f"""
    {test_job_name}:
      type: GcpResource
      chief:
        type: {gcp_resource.chief.type}
        count: {gcp_resource.chief.count}
      environment:
        project: {gcp_resource.environment.project}
        service_account: {gcp_resource.environment.service_account}
        region: {gcp_resource.environment.region}
    """
    fs.create_file(configuration_filename, contents=configuration)
    assert gcp_resource == find_default_configuration()[test_job_name]


def test_databricks_resource(
    fs: FakeFilesystem,
    configuration_filename: str,
    databricks_resource: DatabricksResource,
    test_job_name: str,
):
    configuration = f"""
    {test_job_name}:
      type: DatabricksResource
      worker_count:
        min_workers: {databricks_resource.worker_count.min_workers}
        max_workers: {databricks_resource.worker_count.max_workers}
      cloud_pickle_by_value:
        - a
        - b
    """
    fs.create_file(configuration_filename, contents=configuration)
    assert databricks_resource == find_default_configuration()[test_job_name]


def test_merged_resources(
    fs: FakeFilesystem,
    configuration_filename: str,
    test_job_name: str,
    gcp_resource: GcpResource,
):
    configuration = f"""
    default:
      GcpResource:
        environment:
          project: "ds-cash-dev"
          service_account: {gcp_resource.environment.service_account}
          region: {gcp_resource.environment.region}
    {test_job_name}:
      type: GcpResource
      environment:
        project: {gcp_resource.environment.project}
      chief:
        type: {gcp_resource.chief.type}
        count: {gcp_resource.chief.count}
    """
    fs.create_file(configuration_filename, contents=configuration)
    assert gcp_resource == find_default_configuration()[test_job_name]
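The merged-resource test above pins down the override semantics: values under the default: block apply to every job of that resource type, and job-level values win. A minimal sketch of that behavior, assuming a plain recursive dict merge (the real find_default_configuration may be implemented differently):

def deep_merge(default: dict, override: dict) -> dict:
    """Recursively merge two mappings; values in `override` win."""
    merged = dict(default)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


defaults = {"environment": {"project": "ds-cash-dev", "region": "us-west1"}}
job = {
    "environment": {"project": "test-project"},
    "chief": {"type": "n1-standard-4", "count": 2},
}
print(deep_merge(defaults, job))
# {'environment': {'project': 'test-project', 'region': 'us-west1'},
#  'chief': {'type': 'n1-standard-4', 'count': 2}}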
@@ -0,0 +1,72 @@
from functools import partial
from unittest.mock import patch

from cascade import DatabricksResource
from cascade.executors import DatabricksExecutor
from cascade.executors.databricks.job import DatabricksJob
from cascade.utils import wrapped_partial

# Mock paths
MOCK_CLUSTER_POLICY = (
    "cascade.executors.DatabricksExecutor.get_cluster_policy_id_from_policy_name"
)
MOCK__RUN = "cascade.executors.DatabricksExecutor._run"
MOCK_FILESYSTEM = "cascade.executors.DatabricksExecutor.fs"
MOCK_STORAGE_PATH = "cascade.executors.DatabricksExecutor.storage_path"

DATABRICKS_GROUP = "cascade"

databricks_resource = DatabricksResource(
    group_name=DATABRICKS_GROUP,
)


def addition(a: int, b: int) -> int:
    return a + b


addition_packed = wrapped_partial(addition, 1, 2)


def test_create_executor():
    """Test that a DatabricksExecutor can be created."""
    _ = DatabricksExecutor(
        func=addition_packed,
        resource=databricks_resource,
    )


@patch(MOCK_CLUSTER_POLICY, return_value="12345")
def test_create_job(mock_cluster_policy):
    """Test that the create_job method returns a valid DatabricksJob object."""
    executor = DatabricksExecutor(
        func=addition_packed,
        resource=databricks_resource,
    )
    databricks_job = executor.create_job()
    assert isinstance(databricks_job, DatabricksJob)


@patch(MOCK_CLUSTER_POLICY, return_value="12345")
def test_infer_name(mock_cluster_policy):
    """Test that if no name is provided, the name is inferred correctly."""
    executor = DatabricksExecutor(
        func=addition_packed,
        resource=databricks_resource,
    )
    assert executor.name is None
    _ = executor.create_job()
    assert executor.name == "addition"

    # A plain functools.partial carries no __name__, so the inferred
    # name falls back to "unnamed".
    partial_func = partial(addition, 1, 2)

    executor_partialfunc = DatabricksExecutor(
        func=partial_func,
        resource=databricks_resource,
    )

    assert executor_partialfunc.name is None
    _ = executor_partialfunc.create_job()
    assert executor_partialfunc.name == "unnamed"
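The name-inference test above is the reason wrapped_partial exists: a plain functools.partial has no __name__, so the executor falls back to "unnamed". A helper with that behavior could be as simple as the following sketch (the real cascade.utils.wrapped_partial may differ):

from functools import partial, update_wrapper


def wrapped_partial(func, *args, **kwargs):
    """Bind arguments like functools.partial, then copy the wrapped
    function's metadata (__name__, __doc__, ...) onto the partial."""
    bound = partial(func, *args, **kwargs)
    update_wrapper(bound, func)
    return bound


def addition(a: int, b: int) -> int:
    return a + b


print(wrapped_partial(addition, 1, 2).__name__)      # addition
print(hasattr(partial(addition, 1, 2), "__name__"))  # False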
@@ -0,0 +1,81 @@
import os
from unittest.mock import MagicMock, patch

import cloudpickle
from fsspec.implementations.local import LocalFileSystem

from cascade.executors import LocalExecutor
from cascade.utils import wrapped_partial


def addition(a: int, b: int) -> int:
    """Adds two numbers together."""
    return a + b


# create a mock of addition and set its name to "addition"
mocked_addition = MagicMock(return_value=3)
mocked_addition.__name__ = "addition"

prepared_addition = wrapped_partial(addition, 1, 2)


def test_local_executor():
    """Test the local executor run method."""
    executor = LocalExecutor(func=prepared_addition)
    result = executor.run()
    assert result == 3


def test_run_twice():
    """Tests that if the executor is run twice,
    the second run executes the function again and stores it in a unique file.
    """
    executor = LocalExecutor(func=mocked_addition)

    result1 = executor.run()
    result2 = executor.run()

    assert mocked_addition.call_count == 2
    assert result1 == result2


def test_new_executor():
    """
    Tests generating a new executor from an existing one.
    """
    mocked_addition.call_count = 0

    executor1 = LocalExecutor(func=mocked_addition)
    result1 = executor1.run()

    executor2 = executor1.with_()
    result2 = executor2.run()

    assert mocked_addition.call_count == 2
    assert executor1 != executor2
    assert result1 == result2


@patch("cascade.executors.executor.uuid4", return_value="12345")
def test_result(mock_uuid4):
    """
    Tests that a file containing a pickled function can be opened, the function run,
    and the result written to a local filepath.
    """
    fs = LocalFileSystem(auto_mkdir=True)
    executor = LocalExecutor(func=mocked_addition)

    path_root = os.path.expanduser("~")

    # test that the staged_filepath was created correctly
    assert executor.staged_filepath == f"{path_root}/cascade-storage/12345/function.pkl"

    # stage the pickled function to the staged_filepath; the prepared
    # zero-argument callable is what _result() must load and run to get 3
    with fs.open(executor.staged_filepath, "wb") as f:
        cloudpickle.dump(prepared_addition, f)

    result = executor._result()
    assert result == 3
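For context, the round trip that test_result exercises looks roughly like this sketch (illustrative only, not the real LocalExecutor internals, and the path is hypothetical): a zero-argument callable is pickled to a staged path with fsspec, loaded back, and called to produce the result.

import cloudpickle
from fsspec.implementations.local import LocalFileSystem


def addition(a: int, b: int) -> int:
    return a + b


fs = LocalFileSystem(auto_mkdir=True)
staged_path = "/tmp/cascade-demo/function.pkl"  # hypothetical staging path

# Stage: serialize a zero-argument callable to the filesystem.
with fs.open(staged_path, "wb") as f:
    cloudpickle.dump(lambda: addition(1, 2), f)

# Run: load it back and execute it, as _result() is expected to do.
with fs.open(staged_path, "rb") as f:
    func = cloudpickle.load(f)

assert func() == 3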