diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 4e191564ff..29d0fec704 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -34,6 +34,7 @@ * Add ability to add/remove versions to/from release channels through `snow app release-channel add-version` and `snow app release-channel remove-version` commands. * Add publish command to make it easier to manage publishing versions to release channels and updating release directives: `snow app publish` * Add support for restricting Snowflake user authentication policy to Snowflake CLI-only. +* Added a new command: `snow helpers import-snowsql-connections` allowing to import configuration of connections from SnowSQL. ## Fixes and improvements * Fixed inability to add patches to lowercase quoted versions diff --git a/src/snowflake/cli/_plugins/helpers/commands.py b/src/snowflake/cli/_plugins/helpers/commands.py index 5034943297..5666c063ba 100644 --- a/src/snowflake/cli/_plugins/helpers/commands.py +++ b/src/snowflake/cli/_plugins/helpers/commands.py @@ -14,17 +14,30 @@ from __future__ import annotations +import logging +from pathlib import Path +from typing import Any, List, Optional + import typer import yaml from click import ClickException from snowflake.cli.api.commands.snow_typer import SnowTyperFactory -from snowflake.cli.api.output.types import MessageResult +from snowflake.cli.api.config import ( + ConnectionConfig, + add_connection_to_proper_file, + get_all_connections, + set_config_value, +) +from snowflake.cli.api.console import cli_console +from snowflake.cli.api.output.types import CommandResult, MessageResult from snowflake.cli.api.project.definition_conversion import ( convert_project_definition_to_v2, ) from snowflake.cli.api.project.definition_manager import DefinitionManager from snowflake.cli.api.secure_path import SecurePath +log = logging.getLogger(__name__) + app = SnowTyperFactory( name="helpers", help="Helper commands.", @@ -88,3 +101,196 @@ def v1_to_v2( width=float("inf"), # Don't break lines ) return MessageResult("Project definition migrated to version 2.") + + +@app.command(name="import-snowsql-connections", requires_connection=False) +def import_snowsql_connections( + custom_snowsql_config_files: Optional[List[Path]] = typer.Option( + None, + "--snowsql-config-file", + help="Specifies file paths to custom SnowSQL configuration. The option can be used multiple times to specify more than 1 file.", + dir_okay=False, + exists=True, + ), + default_cli_connection_name: str = typer.Option( + "default", + "--default-connection-name", + help="Specifies the name which will be given in Snowflake CLI to the default connection imported from SnowSQL.", + ), + **options, +) -> CommandResult: + """Import your existing connections from your SnowSQL configuration.""" + + snowsql_config_files: list[Path] = custom_snowsql_config_files or [ + Path("/etc/snowsql.cnf"), + Path("/etc/snowflake/snowsql.cnf"), + Path("/usr/local/etc/snowsql.cnf"), + Path.home() / Path(".snowsql.cnf"), + Path.home() / Path(".snowsql/config"), + ] + snowsql_config_secure_paths: list[SecurePath] = [ + SecurePath(p) for p in snowsql_config_files + ] + + all_imported_connections = _read_all_connections_from_snowsql( + default_cli_connection_name, snowsql_config_secure_paths + ) + _validate_and_save_connections_imported_from_snowsql( + default_cli_connection_name, all_imported_connections + ) + return MessageResult( + "Connections successfully imported from SnowSQL to Snowflake CLI." + ) + + +def _read_all_connections_from_snowsql( + default_cli_connection_name: str, snowsql_config_files: List[SecurePath] +) -> dict[str, dict]: + import configparser + + imported_default_connection: dict[str, Any] = {} + imported_named_connections: dict[str, dict] = {} + + for file in snowsql_config_files: + if not file.exists(): + cli_console.step( + f"SnowSQL config file [{str(file.path)}] does not exist. Skipping." + ) + continue + + cli_console.step(f"Trying to read connections from [{str(file.path)}].") + snowsql_config = configparser.ConfigParser() + snowsql_config.read(file.path) + + if "connections" in snowsql_config and snowsql_config.items("connections"): + cli_console.step( + f"Reading SnowSQL's default connection configuration from [{str(file.path)}]" + ) + snowsql_default_connection = snowsql_config.items("connections") + imported_default_connection.update( + _convert_connection_from_snowsql_config_section( + snowsql_default_connection + ) + ) + + other_snowsql_connection_section_names = [ + section_name + for section_name in snowsql_config.sections() + if section_name.startswith("connections.") + ] + for snowsql_connection_section_name in other_snowsql_connection_section_names: + cli_console.step( + f"Reading SnowSQL's connection configuration [{snowsql_connection_section_name}] from [{str(file.path)}]" + ) + snowsql_named_connection = snowsql_config.items( + snowsql_connection_section_name + ) + if not snowsql_named_connection: + cli_console.step( + f"Empty connection configuration [{snowsql_connection_section_name}] in [{str(file.path)}]. Skipping." + ) + continue + + connection_name = snowsql_connection_section_name.removeprefix( + "connections." + ) + imported_named_conenction = _convert_connection_from_snowsql_config_section( + snowsql_named_connection + ) + if connection_name in imported_named_connections: + imported_named_connections[connection_name].update( + imported_named_conenction + ) + else: + imported_named_connections[connection_name] = imported_named_conenction + + def imported_default_connection_as_named_connection(): + name = _validate_imported_default_connection_name( + default_cli_connection_name, imported_named_connections + ) + return {name: imported_default_connection} + + named_default_connection = ( + imported_default_connection_as_named_connection() + if imported_default_connection + else {} + ) + + return imported_named_connections | named_default_connection + + +def _validate_imported_default_connection_name( + name_candidate: str, other_snowsql_connections: dict[str, dict] +) -> str: + if name_candidate in other_snowsql_connections: + new_name_candidate = typer.prompt( + f"Chosen default connection name '{name_candidate}' is already taken by other connection being imported from SnowSQL. Please choose a different name for your default connection" + ) + return _validate_imported_default_connection_name( + new_name_candidate, other_snowsql_connections + ) + else: + return name_candidate + + +def _convert_connection_from_snowsql_config_section( + snowsql_connection: list[tuple[str, Any]] +) -> dict[str, Any]: + from ast import literal_eval + + key_names_replacements = { + "accountname": "account", + "username": "user", + "databasename": "database", + "dbname": "database", + "schemaname": "schema", + "warehousename": "warehouse", + "rolename": "role", + "private_key_path": "private_key_file", + } + + def parse_value(value: Any): + try: + parsed_value = literal_eval(value) + except Exception: + parsed_value = value + return parsed_value + + cli_connection: dict[str, Any] = {} + for key, value in snowsql_connection: + cli_key = key_names_replacements.get(key, key) + cli_value = parse_value(value) + cli_connection[cli_key] = cli_value + return cli_connection + + +def _validate_and_save_connections_imported_from_snowsql( + default_cli_connection_name: str, all_imported_connections: dict[str, Any] +): + existing_cli_connection_names: set[str] = set(get_all_connections().keys()) + imported_connections_to_save: dict[str, Any] = {} + for ( + imported_connection_name, + imported_connection, + ) in all_imported_connections.items(): + if imported_connection_name in existing_cli_connection_names: + override_cli_connection = typer.confirm( + f"Connection '{imported_connection_name}' already exists in Snowflake CLI, do you want to use SnowSQL definition and override existing connection in Snowflake CLI?" + ) + if not override_cli_connection: + continue + imported_connections_to_save[imported_connection_name] = imported_connection + + for name, connection in imported_connections_to_save.items(): + cli_console.step(f"Saving [{name}] connection in Snowflake CLI's config.") + add_connection_to_proper_file(name, ConnectionConfig.from_dict(connection)) + + if default_cli_connection_name in imported_connections_to_save: + cli_console.step( + f"Setting [{default_cli_connection_name}] connection as Snowflake CLI's default connection." + ) + set_config_value( + section=None, + key="default_connection_name", + value=default_cli_connection_name, + ) diff --git a/tests_e2e/__snapshots__/test_import_snowsql_connections.ambr b/tests_e2e/__snapshots__/test_import_snowsql_connections.ambr new file mode 100644 index 0000000000..bf4d0ea7a7 --- /dev/null +++ b/tests_e2e/__snapshots__/test_import_snowsql_connections.ambr @@ -0,0 +1,25 @@ +# serializer version: 1 +# name: test_import_confirm_on_conflict_with_existing_cli_connection + '[{"connection_name": "example", "parameters": {"user": "u1", "schema": "public", "authenticator": "SNOWFLAKE_JWT"}, "is_default": false}]' +# --- +# name: test_import_confirm_on_conflict_with_existing_cli_connection.1 + '[{"connection_name": "example", "parameters": {"account": "accountname", "user": "username"}, "is_default": false}, {"connection_name": "snowsql1", "parameters": {"account": "a1", "user": "u1", "host": "h1_override", "database": "d1", "schema": "public", "warehouse": "w1", "role": "r1"}, "is_default": false}, {"connection_name": "snowsql2", "parameters": {"account": "a2", "user": "u2", "host": "h2", "port": 1234, "database": "d2", "schema": "public", "warehouse": "w2", "role": "r2"}, "is_default": false}, {"connection_name": "snowsql3", "parameters": {"account": "a3", "user": "u3", "password": "****", "host": "h3", "database": "d3", "schema": "public", "warehouse": "w3", "role": "r3"}, "is_default": false}, {"connection_name": "default", "parameters": {"account": "default_connection_account", "user": "default_connection_user", "host": "localhost", "database": "default_connection_database_override", "schema": "public", "warehouse": "default_connection_warehouse", "role": "accountadmin"}, "is_default": true}]' +# --- +# name: test_import_of_snowsql_connections + '[]' +# --- +# name: test_import_of_snowsql_connections.1 + '[{"connection_name": "snowsql1", "parameters": {"account": "a1", "user": "u1", "host": "h1_override", "database": "d1", "schema": "public", "warehouse": "w1", "role": "r1"}, "is_default": false}, {"connection_name": "snowsql2", "parameters": {"account": "a2", "user": "u2", "host": "h2", "port": 1234, "database": "d2", "schema": "public", "warehouse": "w2", "role": "r2"}, "is_default": false}, {"connection_name": "example", "parameters": {"account": "accountname", "user": "username"}, "is_default": false}, {"connection_name": "snowsql3", "parameters": {"account": "a3", "user": "u3", "password": "****", "host": "h3", "database": "d3", "schema": "public", "warehouse": "w3", "role": "r3"}, "is_default": false}, {"connection_name": "default", "parameters": {"account": "default_connection_account", "user": "default_connection_user", "host": "localhost", "database": "default_connection_database_override", "schema": "public", "warehouse": "default_connection_warehouse", "role": "accountadmin"}, "is_default": true}]' +# --- +# name: test_import_prompt_for_different_default_connection_name_on_conflict + '[]' +# --- +# name: test_import_prompt_for_different_default_connection_name_on_conflict.1 + '[{"connection_name": "snowsql1", "parameters": {"account": "a1", "user": "u1", "host": "h1_override", "database": "d1", "schema": "public", "warehouse": "w1", "role": "r1"}, "is_default": false}, {"connection_name": "snowsql2", "parameters": {"account": "a2", "user": "u2", "host": "h2", "port": 1234, "database": "d2", "schema": "public", "warehouse": "w2", "role": "r2"}, "is_default": true}, {"connection_name": "example", "parameters": {"account": "accountname", "user": "username"}, "is_default": false}, {"connection_name": "snowsql3", "parameters": {"account": "a3", "user": "u3", "password": "****", "host": "h3", "database": "d3", "schema": "public", "warehouse": "w3", "role": "r3"}, "is_default": false}, {"connection_name": "default", "parameters": {"account": "default_connection_account", "user": "default_connection_user", "host": "localhost", "database": "default_connection_database_override", "schema": "public", "warehouse": "default_connection_warehouse", "role": "accountadmin"}, "is_default": false}]' +# --- +# name: test_import_reject_on_conflict_with_existing_cli_connection + '[{"connection_name": "example", "parameters": {"user": "u1", "schema": "public", "authenticator": "SNOWFLAKE_JWT"}, "is_default": false}]' +# --- +# name: test_import_reject_on_conflict_with_existing_cli_connection.1 + '[{"connection_name": "example", "parameters": {"user": "u1", "schema": "public", "authenticator": "SNOWFLAKE_JWT"}, "is_default": false}, {"connection_name": "snowsql1", "parameters": {"account": "a1", "user": "u1", "host": "h1_override", "database": "d1", "schema": "public", "warehouse": "w1", "role": "r1"}, "is_default": false}, {"connection_name": "snowsql2", "parameters": {"account": "a2", "user": "u2", "host": "h2", "port": 1234, "database": "d2", "schema": "public", "warehouse": "w2", "role": "r2"}, "is_default": false}, {"connection_name": "snowsql3", "parameters": {"account": "a3", "user": "u3", "password": "****", "host": "h3", "database": "d3", "schema": "public", "warehouse": "w3", "role": "r3"}, "is_default": false}, {"connection_name": "default", "parameters": {"account": "default_connection_account", "user": "default_connection_user", "host": "localhost", "database": "default_connection_database_override", "schema": "public", "warehouse": "default_connection_warehouse", "role": "accountadmin"}, "is_default": true}]' +# --- diff --git a/tests_e2e/config/empty.toml b/tests_e2e/config/empty.toml new file mode 100644 index 0000000000..ada0a4e13d --- /dev/null +++ b/tests_e2e/config/empty.toml @@ -0,0 +1,13 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests_e2e/config/example_connection.toml b/tests_e2e/config/example_connection.toml new file mode 100644 index 0000000000..6113dc473e --- /dev/null +++ b/tests_e2e/config/example_connection.toml @@ -0,0 +1,27 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +[connections.example] +schema = "public" +authenticator = "SNOWFLAKE_JWT" +user = "u1" + + +[cli.plugins.snowpark-hello] +enabled = true +[cli.plugins.snowpark-hello.config] +greeting = "Hello" + +[cli.plugins.multilingual-hello] +enabled = true diff --git a/tests_e2e/config/snowsql/config b/tests_e2e/config/snowsql/config new file mode 100644 index 0000000000..e260dacff1 --- /dev/null +++ b/tests_e2e/config/snowsql/config @@ -0,0 +1,47 @@ +[connections] + +accountname = "default_connection_account" +username = "default_connection_user" +host = "localhost" +databasename = "default_connection_database" +schemaname = "public" +warehousename = "default_connection_warehouse" +rolename = "accountadmin" + + +[connections.snowsql1] +accountname = "a1" +username = "u1" +host = "h1" +databasename = "d1" +schemaname = "public" +warehousename = "w1" +rolename = "r1" + +[connections.snowsql2] +accountname = "a2" +username = "u2" +host = "h2" +databasename = "d2" +schemaname = "public" +warehousename = "w2" +rolename = "r2" +port = 1234 + +[connections.example] +accountname = accountname +username = username + + +[variables] +example_variable=27 + + +[options] +auto_completion = True +log_file = /tmp/snowsql.log +log_level = DEBUG +timing = True +output_format = psql +key_bindings = emacs +repository_base_url = https://sfc-repo.snowflakecomputing.com/snowsql diff --git a/tests_e2e/config/snowsql/integration_config b/tests_e2e/config/snowsql/integration_config new file mode 100644 index 0000000000..a81e2dd829 --- /dev/null +++ b/tests_e2e/config/snowsql/integration_config @@ -0,0 +1,3 @@ +[connections.integration] +authenticator = "SNOWFLAKE_JWT" +schemaname = "public" diff --git a/tests_e2e/config/snowsql/overriding_config b/tests_e2e/config/snowsql/overriding_config new file mode 100644 index 0000000000..db2b8b414e --- /dev/null +++ b/tests_e2e/config/snowsql/overriding_config @@ -0,0 +1,30 @@ +[connections] +databasename = "default_connection_database_override" + + +[connections.snowsql1] +host = "h1_override" + +[connections.snowsql3] +accountname = "a3" +username = "u3" +password = "p3" +host = "h3" +databasename = "d3" +schemaname = "public" +warehousename = "w3" +rolename = "r3" + + +[variables] +example_variable=28 + + +[options] +auto_completion = True +log_file = /tmp/snowsql.log +log_level = DEBUG +timing = True +output_format = psql +key_bindings = emacs +repository_base_url = https://sfc-repo.snowflakecomputing.com/snowsql diff --git a/tests_e2e/conftest.py b/tests_e2e/conftest.py index 38bbb8327c..74a4bc4b1d 100644 --- a/tests_e2e/conftest.py +++ b/tests_e2e/conftest.py @@ -18,6 +18,7 @@ from contextlib import contextmanager from pathlib import Path from tempfile import TemporaryDirectory +from typing import Optional import pytest from snowflake.cli import __about__ @@ -50,10 +51,10 @@ def _clean_output(text: str): ) -def subprocess_check_output(cmd): +def subprocess_check_output(cmd, stdin: Optional[str] = None): try: output = subprocess.check_output( - cmd, shell=IS_WINDOWS, stderr=sys.stdout, encoding="utf-8" + cmd, input=stdin, shell=IS_WINDOWS, stderr=sys.stdout, encoding="utf-8" ) return _clean_output(output) except subprocess.CalledProcessError as err: @@ -61,9 +62,10 @@ def subprocess_check_output(cmd): raise -def subprocess_run(cmd): +def subprocess_run(cmd, stdin: Optional[str] = None): p = subprocess.run( cmd, + input=stdin, shell=IS_WINDOWS, capture_output=True, text=True, @@ -163,8 +165,29 @@ def _temporary_project_directory(project_name): @pytest.fixture() -def config_file(test_root_path, temp_dir): - config_file_path = SecurePath(test_root_path) / "config" / "config.toml" - target_file_path = Path(temp_dir) / "config.toml" - config_file_path.copy(target_file_path) - yield target_file_path +def prepare_test_config_file(temp_dir): + def f(config_file_path: SecurePath): + target_file_path = Path(temp_dir) / "config.toml" + config_file_path.copy(target_file_path) + return target_file_path + + return f + + +@pytest.fixture() +def config_file(test_root_path, prepare_test_config_file): + yield prepare_test_config_file( + SecurePath(test_root_path) / "config" / "config.toml" + ) + + +@pytest.fixture() +def empty_config_file(test_root_path, prepare_test_config_file): + yield prepare_test_config_file(SecurePath(test_root_path) / "config" / "empty.toml") + + +@pytest.fixture() +def example_connection_config_file(test_root_path, prepare_test_config_file): + yield prepare_test_config_file( + SecurePath(test_root_path) / "config" / "example_connection.toml" + ) diff --git a/tests_e2e/test_import_snowsql_connections.py b/tests_e2e/test_import_snowsql_connections.py new file mode 100644 index 0000000000..49ebc30554 --- /dev/null +++ b/tests_e2e/test_import_snowsql_connections.py @@ -0,0 +1,234 @@ +import json +from typing import Optional + +import pytest + +from tests_e2e.conftest import subprocess_check_output, subprocess_run + + +@pytest.fixture() +def _assert_json_output_matches_snapshot(snapshot): + def f(cmd, stdin: Optional[str] = None): + output = subprocess_check_output(cmd, stdin) + parsed_json = json.loads(output) + snapshot.assert_match(json.dumps(parsed_json)) + + return f + + +@pytest.mark.e2e +def test_import_of_snowsql_connections( + snowcli, test_root_path, empty_config_file, _assert_json_output_matches_snapshot +): + _assert_json_output_matches_snapshot( + [ + snowcli, + "--config-file", + empty_config_file, + "connection", + "list", + "--format", + "json", + ], + ) + + result = subprocess_run( + [ + snowcli, + "--config-file", + empty_config_file, + "helpers", + "import-snowsql-connections", + "--snowsql-config-file", + test_root_path / "config" / "snowsql" / "config", + "--snowsql-config-file", + test_root_path / "config" / "snowsql" / "overriding_config", + ], + ) + assert result.returncode == 0 + + _assert_json_output_matches_snapshot( + [ + snowcli, + "--config-file", + empty_config_file, + "connection", + "list", + "--format", + "json", + ] + ) + + +@pytest.mark.e2e +def test_import_prompt_for_different_default_connection_name_on_conflict( + snowcli, test_root_path, empty_config_file, _assert_json_output_matches_snapshot +): + _assert_json_output_matches_snapshot( + [ + snowcli, + "--config-file", + empty_config_file, + "connection", + "list", + "--format", + "json", + ], + ) + + result = subprocess_run( + [ + snowcli, + "--config-file", + empty_config_file, + "helpers", + "import-snowsql-connections", + "--snowsql-config-file", + test_root_path / "config" / "snowsql" / "config", + "--snowsql-config-file", + test_root_path / "config" / "snowsql" / "overriding_config", + "--default-connection-name", + "snowsql2", + ], + stdin="default\n", + ) + assert result.returncode == 0 + + _assert_json_output_matches_snapshot( + [ + snowcli, + "--config-file", + empty_config_file, + "connection", + "list", + "--format", + "json", + ] + ) + + +@pytest.mark.e2e +def test_import_confirm_on_conflict_with_existing_cli_connection( + snowcli, + test_root_path, + example_connection_config_file, + _assert_json_output_matches_snapshot, +): + _assert_json_output_matches_snapshot( + [ + snowcli, + "--config-file", + example_connection_config_file, + "connection", + "list", + "--format", + "json", + ], + ) + + result = subprocess_run( + [ + snowcli, + "--config-file", + example_connection_config_file, + "helpers", + "import-snowsql-connections", + "--snowsql-config-file", + test_root_path / "config" / "snowsql" / "config", + "--snowsql-config-file", + test_root_path / "config" / "snowsql" / "overriding_config", + ], + stdin="y\n", + ) + assert result.returncode == 0 + + _assert_json_output_matches_snapshot( + [ + snowcli, + "--config-file", + example_connection_config_file, + "connection", + "list", + "--format", + "json", + ], + ) + + +@pytest.mark.e2e +def test_import_reject_on_conflict_with_existing_cli_connection( + snowcli, + test_root_path, + example_connection_config_file, + _assert_json_output_matches_snapshot, +): + _assert_json_output_matches_snapshot( + [ + snowcli, + "--config-file", + example_connection_config_file, + "connection", + "list", + "--format", + "json", + ], + ) + + result = subprocess_run( + [ + snowcli, + "--config-file", + example_connection_config_file, + "helpers", + "import-snowsql-connections", + "--snowsql-config-file", + test_root_path / "config" / "snowsql" / "config", + "--snowsql-config-file", + test_root_path / "config" / "snowsql" / "overriding_config", + ], + stdin="n\n", + ) + assert result.returncode == 0 + + _assert_json_output_matches_snapshot( + [ + snowcli, + "--config-file", + example_connection_config_file, + "connection", + "list", + "--format", + "json", + ], + ) + + +@pytest.mark.e2e +def test_connection_imported_from_snowsql(snowcli, test_root_path, empty_config_file): + result = subprocess_run( + [ + snowcli, + "--config-file", + empty_config_file, + "helpers", + "import-snowsql-connections", + "--snowsql-config-file", + test_root_path / "config" / "snowsql" / "integration_config", + ], + ) + assert result.returncode == 0 + + result = subprocess_run( + [ + snowcli, + "--config-file", + empty_config_file, + "connection", + "test", + "-c", + "integration", + "--format", + "json", + ], + ) + assert result.returncode == 0