diff --git a/src/snowflake/cli/_plugins/notebook/commands.py b/src/snowflake/cli/_plugins/notebook/commands.py index be9e4a7a33..f3f69d6cf7 100644 --- a/src/snowflake/cli/_plugins/notebook/commands.py +++ b/src/snowflake/cli/_plugins/notebook/commands.py @@ -15,6 +15,7 @@ import logging import typer +from click import UsageError from snowflake.cli._plugins.notebook.manager import NotebookManager from snowflake.cli._plugins.notebook.types import NotebookStagePath from snowflake.cli.api.commands.flags import identifier_argument @@ -76,11 +77,21 @@ def open_cmd( def create( identifier: Annotated[FQN, NOTEBOOK_IDENTIFIER], notebook_file: Annotated[NotebookStagePath, NotebookFile], + query_warehouse: Annotated[str, typer.Option("--query-warehouse")], + runtime_name: Annotated[str, typer.Option("--runtime-name")], + compute_pool: Annotated[str, typer.Option("--compute-pool")], **options, ): """Creates notebook from stage.""" + if runtime_name and not compute_pool: + raise UsageError("Runtime name requires compute pool.") + if compute_pool and not runtime_name: + raise UsageError("Compute pool requires runtime name.") notebook_url = NotebookManager().create( notebook_name=identifier, notebook_file=notebook_file, + query_warehouse=query_warehouse, + runtime_name=runtime_name, + compute_pool=compute_pool, ) return MessageResult(message=notebook_url) diff --git a/src/snowflake/cli/_plugins/notebook/manager.py b/src/snowflake/cli/_plugins/notebook/manager.py index 2de6d7f3a3..a266bd1a1c 100644 --- a/src/snowflake/cli/_plugins/notebook/manager.py +++ b/src/snowflake/cli/_plugins/notebook/manager.py @@ -50,16 +50,25 @@ def create( self, notebook_name: FQN, notebook_file: NotebookStagePath, + query_warehouse: str, + runtime_name: str, + compute_pool: str, ) -> str: notebook_fqn = notebook_name.using_connection(self._conn) stage_path = self.parse_stage_as_path(notebook_file) + runtime_query = f"RUNTIME_NAME = '{runtime_name}'\n" if runtime_name else "" + compute_pool_query = ( + f"COMPUTE_POOL = '{compute_pool}'\n" if compute_pool else "" + ) queries = dedent( f""" CREATE OR REPLACE NOTEBOOK {notebook_fqn.sql_identifier} FROM '{stage_path.parent}' - QUERY_WAREHOUSE = '{get_cli_context().connection.warehouse}' + QUERY_WAREHOUSE = '{query_warehouse or get_cli_context().connection.warehouse}' MAIN_FILE = '{stage_path.name}'; + {runtime_query} + {compute_pool_query} // Cannot use IDENTIFIER(...) ALTER NOTEBOOK {notebook_fqn.identifier} ADD LIVE VERSION FROM LAST; """