
Commit

runs tests
xmachak committed Dec 15, 2023
1 parent a7e30e0 commit 5a424cb
Showing 14 changed files with 934 additions and 64 deletions.
30 changes: 30 additions & 0 deletions .github/workflows/python-app.yml
@@ -0,0 +1,30 @@
name: Python application

on:
  push:
    branches: '*'

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.8, 3.9, 3.10, 3.11]

    steps:
    - uses: actions/checkout@v2
    - name: Set up ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install poetry
      run: |
        python -m pip install --upgrade pip
        pip install poetry
    - name: Install dependencies
      run: |
        poetry install --without torch
    - name: Test
      run: |
        pytest
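
A side observation, not part of this commit: unquoted version numbers in a YAML sequence are parsed as floats, so 3.10 can be read as 3.1 unless the entries are quoted. A quick check with PyYAML (assumed to be available for illustration; it is not a dependency added here) shows the effect:

# Illustration only: YAML 1.1 parsers read unquoted version numbers as floats,
# so 3.10 collapses to 3.1. Quoting the entries preserves them as strings.
import yaml

print(yaml.safe_load("python-version: [3.8, 3.9, 3.10, 3.11]"))
# {'python-version': [3.8, 3.9, 3.1, 3.11]}

print(yaml.safe_load("python-version: ['3.8', '3.9', '3.10', '3.11']"))
# {'python-version': ['3.8', '3.9', '3.10', '3.11']}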
2 changes: 1 addition & 1 deletion cascade/remote.py
Original file line number Diff line number Diff line change
@@ -156,7 +156,7 @@ def remote_func(*args, **kwargs):
"prefect-io_flow-name": flow_name,
"prefect-io_task-name": task_name,
"prefect-io_task-id": task_id,
"block_cascade-version": version("sq_cascade"),
"block_cascade-version": version("block_cascade"),
}
resource.environment = resource.environment or GcpEnvironmentConfig()
if resource.environment.is_complete:
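
This one-line fix matters because the string is a distribution name passed to a version lookup. Assuming `version` here is `importlib.metadata.version` (an inference from the call shape; the import is not shown in this hunk), an unknown distribution name raises instead of returning a version string:

# Minimal sketch (assumption: `version` in remote.py is importlib.metadata.version).
# version() takes the installed distribution name; an unknown name raises
# PackageNotFoundError, which is why "sq_cascade" had to become "block_cascade".
from importlib.metadata import PackageNotFoundError, version

try:
    print(version("block_cascade"))  # the installed version string, if present
except PackageNotFoundError:
    print("no distribution installed under that name")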
153 changes: 90 additions & 63 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -8,6 +8,11 @@ description = "Library for model training in multi-cloud environment."
readme = "README.md"
authors = ["Block"]

[[source]]
name = "pypi"
url = "https://pypi.org/simple"
verify_ssl = true

[tool.poetry.dependencies]
python = ">=3.8,<3.12"
cloudml-hypertune = "==0.1.0.dev6"
@@ -32,6 +37,7 @@ pytest = ">=7.3.1"
pytest-env = "^0.8.1"
pytest-mock = "^3.11.1"
dask = {extras = ["distributed"], version = ">=2022"}
pyfakefs = "<5.3"

[build-system]
requires = ["poetry-core>=1.0.0"]
Empty file added tests/__init__.py
Empty file.
22 changes: 22 additions & 0 deletions tests/resource_fixtures.py
@@ -0,0 +1,22 @@
from cascade.executors.vertex.resource import (
    GcpEnvironmentConfig,
    GcpMachineConfig,
    GcpResource,
)

GCP_PROJECT = "test-project"
TEST_BUCKET = GCP_PROJECT
REGION = "us-west1"

chief_machine = GcpMachineConfig("n1-standard-4", 1)
gcp_environment = GcpEnvironmentConfig(
    project=GCP_PROJECT,
    storage_location="gs://bucket/path/to/file",
    region="us-west1",
    service_account=f"{GCP_PROJECT}@{GCP_PROJECT}.iam.gserviceaccount.com",
    image="cascade",
)
gcp_resource = GcpResource(
    chief=chief_machine,
    environment=gcp_environment,
)
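
These module-level objects describe a single chief machine in us-west1. As a purely hypothetical sketch of how such a resource might be attached to a function — the `remote` decorator and its `resource=` parameter are assumptions suggested by cascade/remote.py above, not something this diff shows:

# Hypothetical usage sketch; the decorator name and signature are assumptions.
from cascade import remote

from tests.resource_fixtures import gcp_resource


@remote(resource=gcp_resource)
def train(epochs: int) -> str:
    # Intended to run on the GcpResource defined above when submitted via cascade.
    return f"trained for {epochs} epochs"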
143 changes: 143 additions & 0 deletions tests/test_config.py
@@ -0,0 +1,143 @@
from pyfakefs.fake_filesystem import FakeFilesystem
import pytest

from cascade.config import find_default_configuration
from cascade.executors.databricks.resource import (
    DatabricksAutoscaleConfig,
    DatabricksResource,
)
from cascade.executors.vertex.resource import (
    GcpEnvironmentConfig,
    GcpMachineConfig,
    GcpResource,
)


@pytest.fixture(params=["cascade.yaml", "cascade.yml"])
def configuration_filename(request):
    return request.param


@pytest.fixture()
def gcp_project():
    return "test-project"


@pytest.fixture()
def gcp_location():
    return "us-central1"


@pytest.fixture()
def gcp_service_account():
    return "[email protected]"


@pytest.fixture()
def gcp_machine_config():
    return GcpMachineConfig(type="n1-standard-4", count=2)


@pytest.fixture
def gcp_environment(gcp_project, gcp_location, gcp_service_account):
    return GcpEnvironmentConfig(
        project=gcp_project, service_account=gcp_service_account, region=gcp_location
    )


@pytest.fixture()
def gcp_resource(gcp_environment, gcp_machine_config):
    return GcpResource(chief=gcp_machine_config, environment=gcp_environment)


@pytest.fixture()
def databricks_resource():
    return DatabricksResource(
        worker_count=DatabricksAutoscaleConfig(min_workers=5, max_workers=10),
        cloud_pickle_by_value=["a", "b"],
    )


@pytest.fixture()
def test_job_name():
    return "hello-world"


def test_no_configuration():
    assert find_default_configuration() is None


def test_invalid_type_specified(fs: FakeFilesystem, configuration_filename: str):
    configuration = """
    addition:
      type: AwsResource
    """
    fs.create_file(configuration_filename, contents=configuration)
    with pytest.raises(ValueError):
        find_default_configuration()


def test_gcp_resource(
    fs: FakeFilesystem,
    configuration_filename: str,
    gcp_resource: GcpResource,
    test_job_name: str,
):
    configuration = f"""
    {test_job_name}:
      type: GcpResource
      chief:
        type: {gcp_resource.chief.type}
        count: {gcp_resource.chief.count}
      environment:
        project: {gcp_resource.environment.project}
        service_account: {gcp_resource.environment.service_account}
        region: {gcp_resource.environment.region}
    """
    fs.create_file(configuration_filename, contents=configuration)
    assert gcp_resource == find_default_configuration()[test_job_name]


def test_databricks_resource(
    fs: FakeFilesystem,
    configuration_filename: str,
    databricks_resource: DatabricksResource,
    test_job_name: str,
):
    configuration = f"""
    {test_job_name}:
      type: DatabricksResource
      worker_count:
        min_workers: {databricks_resource.worker_count.min_workers}
        max_workers: {databricks_resource.worker_count.max_workers}
      cloud_pickle_by_value:
        - a
        - b
    """
    fs.create_file(configuration_filename, contents=configuration)
    assert databricks_resource == find_default_configuration()[test_job_name]


def test_merged_resources(
    fs: FakeFilesystem,
    configuration_filename: str,
    test_job_name: str,
    gcp_resource: GcpResource,
):
    configuration = f"""
    default:
      GcpResource:
        environment:
          project: "ds-cash-dev"
          service_account: {gcp_resource.environment.service_account}
          region: {gcp_resource.environment.region}
    {test_job_name}:
      type: GcpResource
      environment:
        project: {gcp_resource.environment.project}
      chief:
        type: {gcp_resource.chief.type}
        count: {gcp_resource.chief.count}
    """
    fs.create_file(configuration_filename, contents=configuration)
    assert gcp_resource == find_default_configuration()[test_job_name]
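
Taken together, these tests imply the on-disk format that find_default_configuration() reads: a cascade.yaml (or cascade.yml) in the working directory, keyed by job name, with an optional default section merged into each job's resource. A reconstructed example with placeholder values, written as a plain string for illustration (the job name, project, and service account below are invented):

# Illustration only: a cascade.yaml reconstructed from the tests above.
example_cascade_yaml = """
default:
  GcpResource:
    environment:
      project: my-project
      service_account: my-sa@my-project.iam.gserviceaccount.com
      region: us-central1

train-model:
  type: GcpResource
  chief:
    type: n1-standard-4
    count: 2
"""

# With this saved as cascade.yaml in the working directory, the tests suggest:
#   from cascade.config import find_default_configuration
#   resources = find_default_configuration()
#   resources["train-model"]  # -> a GcpResource with the defaults merged in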
72 changes: 72 additions & 0 deletions tests/test_databricks_executor.py
@@ -0,0 +1,72 @@
from functools import partial
from unittest.mock import patch

from cascade import DatabricksResource
from cascade.executors import DatabricksExecutor
from cascade.executors.databricks.job import DatabricksJob
from cascade.utils import wrapped_partial

databricks_resource = DatabricksResource()

# Mock paths
MOCK_CLUSTER_POLICY = (
    "cascade.executors.DatabricksExecutor.get_cluster_policy_id_from_policy_name"
)
MOCK__RUN = "cascade.executors.DatabricksExecutor._run"
MOCK_FILESYSTEM = "cascade.executors.DatabricksExecutor.fs"
MOCK_STORAGE_PATH = "cascade.executors.DatabricksExecutor.storage_path"

DATABRICKS_GROUP = "cascade"

databricks_resource = DatabricksResource(
    group_name=DATABRICKS_GROUP,
)


def addition(a: int, b: int) -> int:
    return a + b


addition_packed = wrapped_partial(addition, 1, 2)


def test_create_executor():
    """Test that a DatabricksExecutor can be created."""
    _ = DatabricksExecutor(
        func=addition_packed,
        resource=databricks_resource,
    )


@patch(MOCK_CLUSTER_POLICY, return_value="12345")
def test_create_job(mock_cluster_policy):
    """Test that the create_job method returns a valid DatabricksJob object."""
    executor = DatabricksExecutor(
        func=addition_packed,
        resource=databricks_resource,
    )
    databricks_job = executor.create_job()
    assert isinstance(databricks_job, DatabricksJob)


@patch(MOCK_CLUSTER_POLICY, return_value="12345")
def test_infer_name(mock_cluster_policy):
    """Test that if no name is provided, the name is inferred correctly."""
    executor = DatabricksExecutor(
        func=addition_packed,
        resource=databricks_resource,
    )
    assert executor.name is None
    _ = executor.create_job()
    assert executor.name == "addition"

    partial_func = partial(addition, 1, 2)

    executor_partialfunc = DatabricksExecutor(
        func=partial_func,
        resource=databricks_resource,
    )

    assert executor_partialfunc.name is None
    _ = executor_partialfunc.create_job()
    assert executor_partialfunc.name == "unnamed"
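
The name-inference assertions rely on the callable exposing a __name__, which wrapped_partial preserves and a bare functools.partial does not. A minimal sketch of that pattern — an assumption about how cascade.utils.wrapped_partial behaves, not its actual source:

# Sketch of the wrapped-partial pattern the assertions above rely on.
from functools import partial, update_wrapper


def wrapped_partial_sketch(func, *args, **kwargs):
    # Bind the arguments, then copy __name__/__doc__ from the original function
    # so callers can still read a meaningful name off the bound callable.
    bound = partial(func, *args, **kwargs)
    return update_wrapper(bound, func)


def addition(a: int, b: int) -> int:
    return a + b


print(wrapped_partial_sketch(addition, 1, 2).__name__)            # "addition"
print(getattr(partial(addition, 1, 2), "__name__", "unnamed"))    # "unnamed"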
81 changes: 81 additions & 0 deletions tests/test_local_executor.py
@@ -0,0 +1,81 @@
import os
from unittest.mock import MagicMock, patch

import cloudpickle
from fsspec.implementations.local import LocalFileSystem

from cascade.executors import LocalExecutor
from cascade.utils import wrapped_partial


def addition(a: int, b: int) -> int:
    """Adds two numbers together."""
    return a + b


# create a mock of addition and set its name to "addition"
mocked_addition = MagicMock(return_value=3)
mocked_addition.__name__ = "addition"

prepared_addition = wrapped_partial(addition, 1, 2)


def test_local_executor():
    """Test the local executor run method."""

    executor = LocalExecutor(func=prepared_addition)
    result = executor.run()
    assert result == 3


def test_run_twice():
    """Tests that if the executor is run twice
    the second run executes the function again and stores it in a unique file.
    """

    executor = LocalExecutor(func=mocked_addition)

    result1 = executor.run()
    result2 = executor.run()

    assert mocked_addition.call_count == 2
    assert result1 == result2


def test_new_executor():
    """
    Tests generating a new executor from an existing one.
    """
    mocked_addition.call_count = 0

    executor1 = LocalExecutor(func=mocked_addition)
    result1 = executor1.run()

    executor2 = executor1.with_()
    result2 = executor2.run()

    assert mocked_addition.call_count == 2
    assert executor1 != executor2
    assert result1 == result2


@patch("cascade.executors.executor.uuid4", return_value="12345")
def test_result(mock_uuid4):
"""
Tests that a file containing a pickled function can be opened, the function run
and the results written to a local filepath.
"""
fs = LocalFileSystem(auto_mkdir=True)
executor = LocalExecutor(func=mocked_addition)

path_root = os.path.expanduser("~")

# test that the staged_filepath was created correctly
assert executor.staged_filepath == f"{path_root}/cascade-storage/12345/function.pkl"

# stage the pickled function to the staged_filedpath
with fs.open(executor.staged_filepath, "wb") as f:
cloudpickle.dump(wrapped_partial, f)

result = executor._result()
assert result == 3
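
test_result stages a pickled callable and expects _result() to load and run it. The round trip itself is plain cloudpickle; a standalone illustration, not cascade code:

# Standalone illustration of the cloudpickle round trip the staging test relies on:
# dump a bound callable to a file, load it back, call it.
import tempfile

import cloudpickle


def addition(a: int, b: int) -> int:
    return a + b


with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as f:
    cloudpickle.dump(lambda: addition(1, 2), f)
    path = f.name

with open(path, "rb") as f:
    restored = cloudpickle.load(f)

assert restored() == 3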
