diff --git a/sdk/src/beta9/abstractions/base/runner.py b/sdk/src/beta9/abstractions/base/runner.py index 7cc7c5a04..c03010761 100644 --- a/sdk/src/beta9/abstractions/base/runner.py +++ b/sdk/src/beta9/abstractions/base/runner.py @@ -124,10 +124,10 @@ def __init__( self.stub_id: str = "" self.handler: str = "" self.on_start: str = "" - self.on_deploy: "AbstractCallableWrapper" = self._validate_on_deploy(on_deploy) + self.on_deploy: "AbstractCallableWrapper" = self.parse_on_deploy(on_deploy) self.callback_url = callback_url or "" - self.cpu = cpu - self.memory = self.parse_memory(memory) if isinstance(memory, str) else memory + self.cpu = self.parse_cpu(cpu) + self.memory = self.parse_memory(memory) self.gpu = gpu self.gpu_count = gpu_count self.volumes = volumes or [] @@ -202,6 +202,9 @@ def print_invocation_snippet(self, url_type: str = "") -> GetUrlResponse: return res def parse_memory(self, memory_str: str) -> int: + if not isinstance(memory_str, str): + return memory_str + """Parse memory str (with units) to megabytes.""" if memory_str.lower().endswith("mi"): @@ -233,7 +236,7 @@ def shell_stub(self) -> ShellServiceStub: def shell_stub(self, value) -> None: self._shell_stub = value - def _parse_cpu_to_millicores(self, cpu: Union[float, str]) -> int: + def parse_cpu(self, cpu: Union[float, str]) -> int: """ Parse the cpu argument to an integer value in millicores. @@ -349,7 +352,7 @@ def _sync_content( except Exception as e: terminal.warn(str(e)) - def _parse_gpu(self, gpu: Union[GpuTypeAlias, List[GpuTypeAlias]]) -> str: + def parse_gpu(self, gpu: Union[GpuTypeAlias, List[GpuTypeAlias]]) -> str: if not isinstance(gpu, str) and not isinstance(gpu, list): raise ValueError("Invalid GPU type") @@ -358,7 +361,7 @@ def _parse_gpu(self, gpu: Union[GpuTypeAlias, List[GpuTypeAlias]]) -> str: else: return GpuType(gpu).value - def _validate_on_deploy(self, func: Callable): + def parse_on_deploy(self, func: Callable): if func is None: return None @@ -403,8 +406,6 @@ def prepare_runtime( if self.runtime_ready: return True - self.cpu = self._parse_cpu_to_millicores(self.cpu) - if not self.image_available: image_build_result: ImageBuildResult = self.image.build() @@ -433,7 +434,7 @@ def prepare_runtime( return False try: - self.gpu = self._parse_gpu(self.gpu) + self.gpu = self.parse_gpu(self.gpu) except ValueError: terminal.error(f"Invalid GPU type: {self.gpu}", exit=False) return False diff --git a/sdk/src/beta9/abstractions/pod.py b/sdk/src/beta9/abstractions/pod.py index 204eeb247..7764d4522 100644 --- a/sdk/src/beta9/abstractions/pod.py +++ b/sdk/src/beta9/abstractions/pod.py @@ -105,9 +105,9 @@ def stub(self) -> PodServiceStub: def stub(self, value: PodServiceStub) -> None: self._pod_stub = value - def create(self) -> str: + def create(self) -> bool: if not self.prepare_runtime(stub_type=POD_RUN_STUB_TYPE, force_create_stub=True): - return "" + return terminal.header("Creating") create_response: CreatePodResponse = self.stub.create_pod( @@ -120,7 +120,7 @@ def create(self) -> str: terminal.header(f"Container created successfully ===> {create_response.container_id}") self.print_invocation_snippet() - return create_response.container_id + return create_response.ok def deploy( self, @@ -148,6 +148,8 @@ def deploy( self.deployment_id = deploy_response.deployment_id if deploy_response.ok: terminal.header("Deployed 🎉") - self.print_invocation_snippet(**invocation_details_options) + + if len(self.ports) > 0: + self.print_invocation_snippet(url_type="path") return deploy_response.ok diff --git a/sdk/src/beta9/cli/deployment.py b/sdk/src/beta9/cli/deployment.py index 448e85d8d..37ab460a9 100644 --- a/sdk/src/beta9/cli/deployment.py +++ b/sdk/src/beta9/cli/deployment.py @@ -25,7 +25,12 @@ StopDeploymentResponse, StringList, ) -from .extraclick import ClickCommonGroup, ClickManagementGroup +from .extraclick import ( + ClickCommonGroup, + ClickManagementGroup, + handle_config_override, + override_config_options, +) @click.group(cls=ClickCommonGroup) @@ -66,9 +71,22 @@ def common(**_): help="The type of URL to get back. [default is determined by the server] ", type=click.Choice(["host", "path"]), ) +@override_config_options @click.pass_context -def deploy(ctx: click.Context, name: str, entrypoint: str, url_type: str): - ctx.invoke(create_deployment, name=name, entrypoint=entrypoint, url_type=url_type) +def deploy( + ctx: click.Context, + name: str, + entrypoint: str, + url_type: str, + **kwargs, +): + ctx.invoke( + create_deployment, + name=name, + entrypoint=entrypoint, + url_type=url_type, + **kwargs, + ) @click.group( @@ -80,6 +98,17 @@ def management(): pass +def generate_pod_module(name: str, entrypoint: str): + from beta9.abstractions.pod import Pod + + pod = Pod( + name=name, + entrypoint=entrypoint, + ) + + return pod + + @management.command( name="create", help="Create a new deployment.", @@ -99,7 +128,7 @@ def management(): @click.option( "--entrypoint", "-e", - help='The name the entrypoint e.g. "file:function".', + help='The name the entrypoint e.g. "file:function" or script to run.', required=True, ) @click.option( @@ -107,33 +136,49 @@ def management(): help="The type of URL to get back. [default is determined by the server] ", type=click.Choice(["host", "path"]), ) +@override_config_options @extraclick.pass_service_client -def create_deployment(service: ServiceClient, name: str, entrypoint: str, url_type: str): - current_dir = os.getcwd() - if current_dir not in sys.path: - sys.path.insert(0, current_dir) - - module_path, obj_name, *_ = entrypoint.split(":") if ":" in entrypoint else (entrypoint, "") - module_name = module_path.replace(".py", "").replace(os.path.sep, ".") - - if not Path(module_path).exists(): - terminal.error(f"Unable to find file: '{module_path}'") - - if not obj_name: - terminal.error( - "Invalid handler function specified. Expected format: beam deploy [file.py]:[function]" - ) - - module = importlib.import_module(module_name) - - user_obj: Optional[DeployableMixin] = getattr(module, obj_name, None) - if user_obj is None: - terminal.error( - f"Invalid handler function specified. Make sure '{module_path}' contains the function: '{obj_name}'" - ) - - if hasattr(user_obj, "set_handler"): - user_obj.set_handler(f"{module_name}:{obj_name}") +def create_deployment( + service: ServiceClient, + name: str, + entrypoint: str, + url_type: str, + **kwargs, +): + try: + current_dir = os.getcwd() + if current_dir not in sys.path: + sys.path.insert(0, current_dir) + + module_path, obj_name, *_ = entrypoint.split(":") if ":" in entrypoint else (entrypoint, "") + module_name = module_path.replace(".py", "").replace(os.path.sep, ".") + + if not Path(module_path).exists(): + terminal.error(f"Unable to find file: '{module_path}'") + + if not obj_name: + terminal.error( + "Invalid handler function specified. Expected format: beam deploy [file.py]:[function]" + ) + + module = importlib.import_module(module_name) + + user_obj: Optional[DeployableMixin] = getattr(module, obj_name, None) + if user_obj is None: + terminal.error( + f"Invalid handler function specified. Make sure '{module_path}' contains the function: '{obj_name}'" + ) + + if hasattr(user_obj, "set_handler"): + user_obj.set_handler(f"{module_name}:{obj_name}") + except BaseException as e: + terminal.error(f"Error importing module with entrypoint: {e}", exit=False) + # There is no entrypoint, so we generate a pod appfile + user_obj = generate_pod_module(name, entrypoint) + + if not handle_config_override(user_obj, kwargs): + terminal.error("Failed to override config") + return if not user_obj.deploy(name=name, context=service._config, url_type=url_type): # type: ignore terminal.error("Deployment failed ☠️") diff --git a/sdk/src/beta9/cli/extraclick.py b/sdk/src/beta9/cli/extraclick.py index 18adea52f..3e5a1c667 100644 --- a/sdk/src/beta9/cli/extraclick.py +++ b/sdk/src/beta9/cli/extraclick.py @@ -7,6 +7,7 @@ import click +from .. import terminal from ..abstractions import base as base_abstraction from ..channel import ServiceClient, with_grpc_error_handling from ..clients.gateway import ( @@ -19,7 +20,6 @@ show_default=True, ) - config_context_param = click.Option( param_decls=["-c", "--context"], default=DEFAULT_CONTEXT_NAME, @@ -235,3 +235,78 @@ def filter_values_callback( filters[key] = StringList(values=value_list) return filters + + +# Get all kwargs from __init__ +def get_init_kwargs(cls): + sig = inspect.signature(cls.__init__) + kwargs = { + k: v.default + for k, v in sig.parameters.items() + if v.default is not inspect.Parameter.empty + and v.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + } + return kwargs + + +def override_config_options(func: click.Command): + f = click.option( + "--cpu", type=click.INT, help="The number of CPU units to allocate.", required=False + )(func) + f = click.option( + "--memory", + type=click.STRING, + help="The amount of memory to allocate in MB.", + required=False, + )(f) + f = click.option( + "--gpu", type=click.STRING, help="The type of GPU to allocate.", required=False + )(f) + f = click.option( + "--gpu-count", type=click.INT, help="The number of GPUs to allocate.", required=False + )(f) + f = click.option( + "--image", type=click.STRING, help="The image to use for the deployment.", required=False + )(f) + f = click.option( + "--secrets", + type=click.STRING, + multiple=True, + help="The secrets to inject into the deployment.", + )(f) + f = click.option( + "--ports", + type=click.INT, + multiple=True, + help="The ports to expose the deployment on.", + )(f) + return f + + +def handle_config_override(func, kwargs: Dict[str, str]) -> bool: + try: + config_class_instance = None + if hasattr(func, "parent"): + config_class_instance = func.parent + else: + config_class_instance = func + + # We only want to override the config if the config class has an __init__ method + # For example, ports is only available on a Pod + init_kwargs = get_init_kwargs(config_class_instance) + + for key, value in kwargs.items(): + if value is not None and key in init_kwargs: + if isinstance(value, tuple): + value = list(value) + + if len(value) == 0: + continue + + if hasattr(config_class_instance, f"parse_{key}"): + value = config_class_instance.__getattribute__(f"parse_{key}")(value) + setattr(config_class_instance, key, value) + return True + except BaseException as e: + terminal.error(f"Error overriding config: {e}", exit=False) + return False diff --git a/sdk/src/beta9/cli/run.py b/sdk/src/beta9/cli/run.py index 88f282a83..de455f118 100644 --- a/sdk/src/beta9/cli/run.py +++ b/sdk/src/beta9/cli/run.py @@ -9,7 +9,7 @@ from .. import terminal from ..abstractions.pod import Pod -from .extraclick import ClickCommonGroup +from .extraclick import ClickCommonGroup, handle_config_override, override_config_options @click.group(cls=ClickCommonGroup) @@ -40,38 +40,16 @@ def common(**_): nargs=1, required=False, ) -@click.option( - "--image", - help="The image to use for the pod.", - type=str, -) -@click.option( - "--gpu", - help="The GPU to use for the pod.", - type=str, -) -@click.option( - "--cpu", - help="The CPU to use for the pod.", - type=str, -) -@click.option( - "--memory", - help="The memory to use for the pod.", - type=str, -) @click.option( "--entrypoint", help="The entrypoint to use for the pod.", type=str, ) +@override_config_options def run( specfile: str, - image: str, - gpu: str, - cpu: str, - memory: str, entrypoint: str, + **kwargs, ): current_dir = os.getcwd() if current_dir not in sys.path: @@ -99,7 +77,7 @@ def run( f"Invalid handler function specified. Make sure '{module_path}' contains the function: '{obj_name}'" ) - if not inspect.isclass(type(pod_spec)) and pod_spec.__class__.__name__ != "Pod": + if not inspect.isclass(type(pod_spec)) or pod_spec.__class__.__name__ != "Pod": terminal.error("Invalid handler function specified. Expected a Pod abstraction.") if pod_spec is None: @@ -108,16 +86,10 @@ def run( pod_spec = Pod(entrypoint=shlex.split(entrypoint)) - if image: - pod_spec.image.base = image - - if gpu: - pod_spec.gpu = gpu - - if cpu: - pod_spec.cpu = cpu - - if memory: - pod_spec.memory = pod_spec.parse_memory(memory) + if not handle_config_override(pod_spec, kwargs): + terminal.error("Failed to handle config overrides.") + return - pod_spec.create() + if not pod_spec.create(): + terminal.error("Failed to create pod.") + return