diff --git a/block_cascade/executors/vertex/executor.py b/block_cascade/executors/vertex/executor.py index 00d4bde..9f8c403 100644 --- a/block_cascade/executors/vertex/executor.py +++ b/block_cascade/executors/vertex/executor.py @@ -213,7 +213,7 @@ def create_job(self) -> VertexJob: raise RuntimeError( f"Unable to parse bucket from storage block: {storage}" ) - deployment_path = deployment.path.rstrip("/") + deployment_path = (storage.data.get("bucket_folder") or "").rstrip("/") or deployment.path.rstrip("/") package_path = f"{bucket}/{deployment_path}/{module_name}" self._logger.info( diff --git a/block_cascade/prefect/v2/__init__.py b/block_cascade/prefect/v2/__init__.py index 4c03261..53cd7f4 100644 --- a/block_cascade/prefect/v2/__init__.py +++ b/block_cascade/prefect/v2/__init__.py @@ -38,6 +38,12 @@ async def _fetch_block(block_id: str) -> Optional[BlockDocument]: async with get_client() as client: return await client.read_block_document(block_id) +async def _fetch_block_by_name(block_name: str, block_type_slug: str = "gcs-bucket") -> Optional[BlockDocument]: + async with get_client() as client: + return await client.read_block_document_by_name( + name=block_name, + block_type_slug=block_type_slug, + ) def get_from_prefect_context(attr: str, default: str = "") -> str: flow_context = FlowRunContext.get() @@ -80,8 +86,13 @@ def get_storage_block() -> Optional[BlockDocument]: global _CACHED_STORAGE # noqa: PLW0603 if not _CACHED_STORAGE: - _CACHED_STORAGE = run_async( - _fetch_block(current_deployment.storage_document_id) + if current_deployment.pull_steps: + _CACHED_STORAGE = run_async( + _fetch_block_by_name(block_name=current_deployment.pull_steps[0]["prefect.deployments.steps.pull_with_block"]["block_document_name"]) + ) + else: + _CACHED_STORAGE = run_async( + _fetch_block(block_id=current_deployment.storage_document_id) ) return _CACHED_STORAGE diff --git a/block_cascade/prefect/v2/environment.py b/block_cascade/prefect/v2/environment.py index 56b1833..0e6096b 
100644 --- a/block_cascade/prefect/v2/environment.py +++ b/block_cascade/prefect/v2/environment.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Dict, Optional from prefect.context import FlowRunContext @@ -24,47 +24,71 @@ class PrefectEnvironmentClient(VertexAIEnvironmentInfoProvider): def __init__(self): self._current_deployment = None + self._current_job_variables = None self._current_infrastructure = None def get_container_image(self) -> Optional[str]: - infra = self._get_infrastructure_block() - if not infra: - return + job_variables = self._get_job_variables() + if job_variables and job_variables.get("image"): + return job_variables.get("image") - deployment_details = infra.data - return deployment_details.get("image") + infra = self._get_infrastructure_block() + if infra: + return infra.data.get("image") + return None def get_network(self) -> Optional[str]: + job_variables = self._get_job_variables() + if job_variables and job_variables.get("network"): + return job_variables.get("network") + infra = self._get_infrastructure_block() - if not infra: - return + if infra: + return infra.data.get("network") - deployment_details = infra.data - return deployment_details.get("network") + return None def get_project(self) -> Optional[str]: + job_variables = self._get_job_variables() + if job_variables and job_variables.get("credentials", {}).get("project"): + return job_variables.get("credentials", {}).get("project") + infra = self._get_infrastructure_block() - if not infra: - return + if infra: + return infra.data.get("gcp_credentials", {}).get("project") - deployment_details = infra.data - return deployment_details.get("gcp_credentials", {}).get("project") + return None def get_region(self) -> Optional[str]: + job_variables = self._get_job_variables() + if job_variables and job_variables.get("region"): + return job_variables.get("region") + infra = self._get_infrastructure_block() - if not infra: - return + if infra: + return infra.data.get("region") - deployment_details = infra.data - return deployment_details.get("region") + return None def get_service_account(self) -> Optional[str]: + 
job_variables = self._get_job_variables() + if job_variables and job_variables.get("service_account_name"): + return job_variables.get("service_account_name") + infra = self._get_infrastructure_block() - if not infra: - return + if infra: + return infra.data.get("service_account") + + return None + + def _get_job_variables(self) -> Optional[Dict]: + current_deployment = self._get_current_deployment() + if not current_deployment: + return None - deployment_details = infra.data - return deployment_details.get("service_account") + if not self._current_job_variables: + self._current_job_variables = current_deployment.job_variables + return self._current_job_variables def _get_infrastructure_block(self) -> Optional[BlockDocument]: current_deployment = self._get_current_deployment() diff --git a/pyproject.toml b/pyproject.toml index 79fb74d..f1801fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "block-cascade" packages = [ {include = "block_cascade"} ] -version = "2.6.0" +version = "2.6.1" description = "Library for model training in multi-cloud environment." 
readme = "README.md" authors = ["Block"] diff --git a/tests/test_prefect_environment.py b/tests/test_prefect_environment.py new file mode 100644 index 0000000..e06eebe --- /dev/null +++ b/tests/test_prefect_environment.py @@ -0,0 +1,89 @@ +import pytest +from unittest.mock import Mock, patch + +from prefect.client.schemas.responses import DeploymentResponse +from prefect.context import FlowRunContext + +from block_cascade.prefect.v2.environment import PrefectEnvironmentClient + + +@pytest.fixture(autouse=True) +def mock_infrastructure_block(): + infra_block = Mock() + infra_block.data = { + "image": "infra_image", + "network": "infra_network", + "gcp_credentials": {"project": "infra_project"}, + "region": "infra_region", + "service_account": "infra_service_account" + } + with patch("block_cascade.prefect.v2.environment._fetch_block", return_value=infra_block): + yield infra_block + +@pytest.fixture +def mock_job_variables(): + return { + "image": "job_image", + "network": "job_network", + "credentials": {"project": "job_project"}, + "region": "job_region", + "service_account_name": "job_service_account" + } + +@pytest.fixture +def mock_deployment_response(mock_job_variables): + mock_deployment = Mock(spec=DeploymentResponse) + mock_deployment.job_variables = mock_job_variables + mock_deployment.infrastructure_document_id = "mock_infrastructure_id" + return mock_deployment + +@pytest.fixture(autouse=True) +def mock__fetch_deployment(mock_deployment_response): + with patch("block_cascade.prefect.v2.environment._fetch_deployment", return_value=mock_deployment_response): + yield + +@pytest.fixture(autouse=True) +def mock_flow_run_context(): + mock_flow_run = Mock() + mock_flow_run.deployment_id = "mock_deployment_id" + + mock_context = Mock(spec=FlowRunContext) + mock_context.flow_run = mock_flow_run + + with patch("block_cascade.prefect.v2.environment.FlowRunContext.get", return_value=mock_context): + yield mock_context + +def test_get_container_image(): + client = 
PrefectEnvironmentClient() + + assert client.get_container_image() == "job_image" + +def test_get_network(): + client = PrefectEnvironmentClient() + + assert client.get_network() == "job_network" + +def test_get_project(): + client = PrefectEnvironmentClient() + + assert client.get_project() == "job_project" + +def test_get_region(): + client = PrefectEnvironmentClient() + + assert client.get_region() == "job_region" + +def test_get_service_account(): + client = PrefectEnvironmentClient() + + assert client.get_service_account() == "job_service_account" + +def test_fallback_to_infrastructure(mock_deployment_response): + client = PrefectEnvironmentClient() + mock_deployment_response.job_variables = None + + assert client.get_container_image() == "infra_image" + assert client.get_network() == "infra_network" + assert client.get_project() == "infra_project" + assert client.get_region() == "infra_region" + assert client.get_service_account() == "infra_service_account" \ No newline at end of file