Skip to content

Commit

Permalink
temporary_active_stack
Browse files Browse the repository at this point in the history
  • Loading branch information
kabinja committed Mar 24, 2024
1 parent 7e72e00 commit 951ea5b
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 42 deletions.
5 changes: 3 additions & 2 deletions src/zenml/cli/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
ScheduleFilter,
)
from zenml.new.pipelines.pipeline import Pipeline
from zenml.stack.utils import temporary_active_stack
from zenml.utils import source_utils, uuid_utils
from zenml.utils.yaml_utils import write_yaml

Expand Down Expand Up @@ -184,7 +185,7 @@ def build_pipeline(
name_id_or_prefix=pipeline_name_or_id, version=version
)

with cli_utils.temporary_active_stack(stack_name_or_id=stack_name_or_id):
with temporary_active_stack(stack_name_or_id=stack_name_or_id):
pipeline_instance = Pipeline.from_model(pipeline_model)
build = pipeline_instance.build(config_path=config_path)

Expand Down Expand Up @@ -286,7 +287,7 @@ def run_pipeline(
"or file path."
)

with cli_utils.temporary_active_stack(stack_name_or_id=stack_name_or_id):
with temporary_active_stack(stack_name_or_id=stack_name_or_id):
pipeline_instance = Pipeline.from_model(pipeline_model)
pipeline_instance = pipeline_instance.with_options(
config_path=config_path,
Expand Down
31 changes: 0 additions & 31 deletions src/zenml/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# permissions and limitations under the License.
"""Utility functions for the CLI."""

import contextlib
import datetime
import json
import os
Expand All @@ -26,7 +25,6 @@
Any,
Callable,
Dict,
Iterator,
List,
NoReturn,
Optional,
Expand Down Expand Up @@ -78,8 +76,6 @@
from zenml.zen_server.deploy import ServerDeployment

if TYPE_CHECKING:
from uuid import UUID

from rich.text import Text

from zenml.enums import ExecutionStatus
Expand Down Expand Up @@ -2481,33 +2477,6 @@ def wrapper(function: F) -> F:
return inner_decorator


@contextlib.contextmanager
def temporary_active_stack(
stack_name_or_id: Union["UUID", str, None] = None,
) -> Iterator["Stack"]:
"""Contextmanager to temporarily activate a stack.
Args:
stack_name_or_id: The name or ID of the stack to activate. If not given,
this contextmanager will not do anything.
Yields:
The active stack.
"""
from zenml.client import Client

try:
if stack_name_or_id:
old_stack_id = Client().active_stack_model.id
Client().activate_stack(stack_name_or_id)
else:
old_stack_id = None
yield Client().active_stack
finally:
if old_stack_id:
Client().activate_stack(old_stack_id)


def get_package_information(
package_names: Optional[List[str]] = None,
) -> Dict[str, str]:
Expand Down
5 changes: 1 addition & 4 deletions src/zenml/new/pipelines/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
prepare_model_versions,
)
from zenml.stack import Stack
from zenml.stack.utils import temporary_active_stack
from zenml.steps import BaseStep
from zenml.steps.entrypoint_function_utils import (
StepArtifact,
Expand Down Expand Up @@ -537,8 +538,6 @@ def build(
Returns:
The build output.
"""
from zenml.cli.utils import temporary_active_stack

with track_handler(
event=AnalyticsEvent.BUILD_PIPELINE
), temporary_active_stack():
Expand Down Expand Up @@ -609,8 +608,6 @@ def _run(
Model of the pipeline run if running without a schedule, `None` if
running with a schedule.
"""
from zenml.cli.utils import temporary_active_stack

if constants.SHOULD_PREVENT_PIPELINE_EXECUTION:
# An environment variable was set to stop the execution of
# pipelines. This is done to prevent execution of module-level
Expand Down
30 changes: 29 additions & 1 deletion src/zenml/stack/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
# permissions and limitations under the License.
"""Util functions for handling stacks, components, and flavors."""

from typing import Any, Dict, Optional
import contextlib
from typing import Any, Dict, Generator, Optional, Union
from uuid import UUID

from zenml.client import Client
from zenml.enums import StackComponentType, StoreType
from zenml.logger import get_logger
from zenml.models import FlavorFilter, FlavorResponse
from zenml.stack.flavor import Flavor
from zenml.stack.stack import Stack
from zenml.stack.stack_component import StackComponentConfig
from zenml.zen_stores.base_zen_store import BaseZenStore

Expand Down Expand Up @@ -139,3 +142,28 @@ def get_flavor_by_name_and_type_from_zen_store(
f"'{component_type}' exists."
)
return flavors[0]


@contextlib.contextmanager
def temporary_active_stack(
stack_name_or_id: Union[UUID, str, None] = None,
) -> Generator[Stack, Any, Any]:
"""Contextmanager to temporarily activate a stack.
Args:
stack_name_or_id: The name or ID of the stack to activate. If not given,
this contextmanager will not do anything.
Yields:
The active stack.
"""
try:
if stack_name_or_id:
old_stack_id = Client().active_stack_model.id
Client().activate_stack(stack_name_or_id)
else:
old_stack_id = None
yield Client().active_stack
finally:
if old_stack_id:
Client().activate_stack(old_stack_id)
6 changes: 2 additions & 4 deletions tests/integration/functional/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,15 @@

from tests.harness.harness import TestHarness
from zenml.cli import cli
from zenml.cli.utils import (
parse_name_and_extra_arguments,
temporary_active_stack,
)
from zenml.cli.utils import parse_name_and_extra_arguments
from zenml.client import Client
from zenml.models import (
TagFilter,
TagRequest,
UserResponse,
WorkspaceResponse,
)
from zenml.stack.utils import temporary_active_stack
from zenml.utils.string_utils import random_str

SAMPLE_CUSTOM_ARGUMENTS = [
Expand Down

0 comments on commit 951ea5b

Please sign in to comment.