Skip to content

Commit

Permalink
Add args to cli
Browse files Browse the repository at this point in the history
  • Loading branch information
jsun-m committed Feb 19, 2025
1 parent e7c7b4f commit 2dc4498
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 82 deletions.
19 changes: 10 additions & 9 deletions sdk/src/beta9/abstractions/base/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,10 @@ def __init__(
self.stub_id: str = ""
self.handler: str = ""
self.on_start: str = ""
self.on_deploy: "AbstractCallableWrapper" = self._validate_on_deploy(on_deploy)
self.on_deploy: "AbstractCallableWrapper" = self.parse_on_deploy(on_deploy)
self.callback_url = callback_url or ""
self.cpu = cpu
self.memory = self.parse_memory(memory) if isinstance(memory, str) else memory
self.cpu = self.parse_cpu(cpu)
self.memory = self.parse_memory(memory)
self.gpu = gpu
self.gpu_count = gpu_count
self.volumes = volumes or []
Expand Down Expand Up @@ -202,6 +202,9 @@ def print_invocation_snippet(self, url_type: str = "") -> GetUrlResponse:
return res

def parse_memory(self, memory_str: str) -> int:
if not isinstance(memory_str, str):
return memory_str

"""Parse memory str (with units) to megabytes."""

if memory_str.lower().endswith("mi"):
Expand Down Expand Up @@ -233,7 +236,7 @@ def shell_stub(self) -> ShellServiceStub:
def shell_stub(self, value) -> None:
self._shell_stub = value

def _parse_cpu_to_millicores(self, cpu: Union[float, str]) -> int:
def parse_cpu(self, cpu: Union[float, str]) -> int:
"""
Parse the cpu argument to an integer value in millicores.
Expand Down Expand Up @@ -349,7 +352,7 @@ def _sync_content(
except Exception as e:
terminal.warn(str(e))

def _parse_gpu(self, gpu: Union[GpuTypeAlias, List[GpuTypeAlias]]) -> str:
def parse_gpu(self, gpu: Union[GpuTypeAlias, List[GpuTypeAlias]]) -> str:
if not isinstance(gpu, str) and not isinstance(gpu, list):
raise ValueError("Invalid GPU type")

Expand All @@ -358,7 +361,7 @@ def _parse_gpu(self, gpu: Union[GpuTypeAlias, List[GpuTypeAlias]]) -> str:
else:
return GpuType(gpu).value

def _validate_on_deploy(self, func: Callable):
def parse_on_deploy(self, func: Callable):
if func is None:
return None

Expand Down Expand Up @@ -403,8 +406,6 @@ def prepare_runtime(
if self.runtime_ready:
return True

self.cpu = self._parse_cpu_to_millicores(self.cpu)

if not self.image_available:
image_build_result: ImageBuildResult = self.image.build()

Expand Down Expand Up @@ -433,7 +434,7 @@ def prepare_runtime(
return False

try:
self.gpu = self._parse_gpu(self.gpu)
self.gpu = self.parse_gpu(self.gpu)
except ValueError:
terminal.error(f"Invalid GPU type: {self.gpu}", exit=False)
return False
Expand Down
10 changes: 6 additions & 4 deletions sdk/src/beta9/abstractions/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ def stub(self) -> PodServiceStub:
def stub(self, value: PodServiceStub) -> None:
self._pod_stub = value

def create(self) -> str:
def create(self) -> bool:
if not self.prepare_runtime(stub_type=POD_RUN_STUB_TYPE, force_create_stub=True):
return ""
return

terminal.header("Creating")
create_response: CreatePodResponse = self.stub.create_pod(
Expand All @@ -120,7 +120,7 @@ def create(self) -> str:
terminal.header(f"Container created successfully ===> {create_response.container_id}")
self.print_invocation_snippet()

return create_response.container_id
return create_response.ok

def deploy(
self,
Expand Down Expand Up @@ -148,6 +148,8 @@ def deploy(
self.deployment_id = deploy_response.deployment_id
if deploy_response.ok:
terminal.header("Deployed 🎉")
self.print_invocation_snippet(**invocation_details_options)

if len(self.ports) > 0:
self.print_invocation_snippet(url_type="path")

return deploy_response.ok
105 changes: 75 additions & 30 deletions sdk/src/beta9/cli/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@
StopDeploymentResponse,
StringList,
)
from .extraclick import ClickCommonGroup, ClickManagementGroup
from .extraclick import (
ClickCommonGroup,
ClickManagementGroup,
handle_config_override,
override_config_options,
)


@click.group(cls=ClickCommonGroup)
Expand Down Expand Up @@ -66,9 +71,22 @@ def common(**_):
help="The type of URL to get back. [default is determined by the server] ",
type=click.Choice(["host", "path"]),
)
@override_config_options
@click.pass_context
def deploy(ctx: click.Context, name: str, entrypoint: str, url_type: str):
ctx.invoke(create_deployment, name=name, entrypoint=entrypoint, url_type=url_type)
def deploy(
ctx: click.Context,
name: str,
entrypoint: str,
url_type: str,
**kwargs,
):
ctx.invoke(
create_deployment,
name=name,
entrypoint=entrypoint,
url_type=url_type,
**kwargs,
)


@click.group(
Expand All @@ -80,6 +98,17 @@ def management():
pass


def generate_pod_module(name: str, entrypoint: str):
from beta9.abstractions.pod import Pod

pod = Pod(
name=name,
entrypoint=entrypoint,
)

return pod


@management.command(
name="create",
help="Create a new deployment.",
Expand All @@ -99,41 +128,57 @@ def management():
@click.option(
"--entrypoint",
"-e",
help='The name the entrypoint e.g. "file:function".',
help='The name the entrypoint e.g. "file:function" or script to run.',
required=True,
)
@click.option(
"--url-type",
help="The type of URL to get back. [default is determined by the server] ",
type=click.Choice(["host", "path"]),
)
@override_config_options
@extraclick.pass_service_client
def create_deployment(service: ServiceClient, name: str, entrypoint: str, url_type: str):
current_dir = os.getcwd()
if current_dir not in sys.path:
sys.path.insert(0, current_dir)

module_path, obj_name, *_ = entrypoint.split(":") if ":" in entrypoint else (entrypoint, "")
module_name = module_path.replace(".py", "").replace(os.path.sep, ".")

if not Path(module_path).exists():
terminal.error(f"Unable to find file: '{module_path}'")

if not obj_name:
terminal.error(
"Invalid handler function specified. Expected format: beam deploy [file.py]:[function]"
)

module = importlib.import_module(module_name)

user_obj: Optional[DeployableMixin] = getattr(module, obj_name, None)
if user_obj is None:
terminal.error(
f"Invalid handler function specified. Make sure '{module_path}' contains the function: '{obj_name}'"
)

if hasattr(user_obj, "set_handler"):
user_obj.set_handler(f"{module_name}:{obj_name}")
def create_deployment(
service: ServiceClient,
name: str,
entrypoint: str,
url_type: str,
**kwargs,
):
try:
current_dir = os.getcwd()
if current_dir not in sys.path:
sys.path.insert(0, current_dir)

module_path, obj_name, *_ = entrypoint.split(":") if ":" in entrypoint else (entrypoint, "")
module_name = module_path.replace(".py", "").replace(os.path.sep, ".")

if not Path(module_path).exists():
terminal.error(f"Unable to find file: '{module_path}'")

if not obj_name:
terminal.error(
"Invalid handler function specified. Expected format: beam deploy [file.py]:[function]"
)

module = importlib.import_module(module_name)

user_obj: Optional[DeployableMixin] = getattr(module, obj_name, None)
if user_obj is None:
terminal.error(
f"Invalid handler function specified. Make sure '{module_path}' contains the function: '{obj_name}'"
)

if hasattr(user_obj, "set_handler"):
user_obj.set_handler(f"{module_name}:{obj_name}")
except BaseException as e:
terminal.error(f"Error importing module with entrypoint: {e}", exit=False)
# There is no entrypoint, so we generate a pod appfile
user_obj = generate_pod_module(name, entrypoint)

if not handle_config_override(user_obj, kwargs):
terminal.error("Failed to override config")
return

if not user_obj.deploy(name=name, context=service._config, url_type=url_type): # type: ignore
terminal.error("Deployment failed ☠️")
Expand Down
77 changes: 76 additions & 1 deletion sdk/src/beta9/cli/extraclick.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import click

from .. import terminal
from ..abstractions import base as base_abstraction
from ..channel import ServiceClient, with_grpc_error_handling
from ..clients.gateway import (
Expand All @@ -19,7 +20,6 @@
show_default=True,
)


config_context_param = click.Option(
param_decls=["-c", "--context"],
default=DEFAULT_CONTEXT_NAME,
Expand Down Expand Up @@ -235,3 +235,78 @@ def filter_values_callback(
filters[key] = StringList(values=value_list)

return filters


# Get all kwargs from __init__
def get_init_kwargs(cls):
sig = inspect.signature(cls.__init__)
kwargs = {
k: v.default
for k, v in sig.parameters.items()
if v.default is not inspect.Parameter.empty
and v.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
}
return kwargs


def override_config_options(func: click.Command):
f = click.option(
"--cpu", type=click.INT, help="The number of CPU units to allocate.", required=False
)(func)
f = click.option(
"--memory",
type=click.STRING,
help="The amount of memory to allocate in MB.",
required=False,
)(f)
f = click.option(
"--gpu", type=click.STRING, help="The type of GPU to allocate.", required=False
)(f)
f = click.option(
"--gpu-count", type=click.INT, help="The number of GPUs to allocate.", required=False
)(f)
f = click.option(
"--image", type=click.STRING, help="The image to use for the deployment.", required=False
)(f)
f = click.option(
"--secrets",
type=click.STRING,
multiple=True,
help="The secrets to inject into the deployment.",
)(f)
f = click.option(
"--ports",
type=click.INT,
multiple=True,
help="The ports to expose the deployment on.",
)(f)
return f


def handle_config_override(func, kwargs: Dict[str, str]) -> bool:
try:
config_class_instance = None
if hasattr(func, "parent"):
config_class_instance = func.parent
else:
config_class_instance = func

# We only want to override the config if the config class has an __init__ method
# For example, ports is only available on a Pod
init_kwargs = get_init_kwargs(config_class_instance)

for key, value in kwargs.items():
if value is not None and key in init_kwargs:
if isinstance(value, tuple):
value = list(value)

if len(value) == 0:
continue

if hasattr(config_class_instance, f"parse_{key}"):
value = config_class_instance.__getattribute__(f"parse_{key}")(value)
setattr(config_class_instance, key, value)
return True
except BaseException as e:
terminal.error(f"Error overriding config: {e}", exit=False)
return False
Loading

0 comments on commit 2dc4498

Please sign in to comment.