Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use -i for stdin in snow sql #592

Merged
merged 5 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/snowcli/cli/common/cli_global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __setattr__(self, key, value):
We invalidate connection cache every time connection attributes change.
"""
super.__setattr__(self, key, value)
if key is not "_cached_connection":
if key != "_cached_connection":
self._cached_connection = None

@property
Expand Down
11 changes: 9 additions & 2 deletions src/snowcli/cli/sql/commands.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import sys
from pathlib import Path
from typing import Optional

import typer
from click import UsageError
from snowcli.cli.common.decorators import global_options_with_connection
from snowcli.cli.sql.manager import SqlManager
from snowcli.output.decorators import with_output
Expand Down Expand Up @@ -31,15 +33,20 @@ def execute_sql(
readable=True,
help="File to execute.",
),
std_in: Optional[bool] = typer.Option(
False,
"-i",
help="Read the query from standard input. Use it when piping input to this command.",
),
**options
) -> CommandResult:
"""
Executes Snowflake query.

Query to execute can be specified using query option, filename option (all queries from file will be executed)
or via stdin by piping output from other command. For example `cat my.sql | snow sql`.
or via stdin by piping output from other command. For example `cat my.sql | snow sql -i`.
"""
cursors = SqlManager().execute(query, file)
cursors = SqlManager().execute(query, file, std_in)
if len(cursors) > 1:
result = MultipleResults()
for curr in cursors:
Expand Down
31 changes: 14 additions & 17 deletions src/snowcli/cli/sql/manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
from itertools import combinations, starmap
from pathlib import Path
from typing import List, Optional

Expand All @@ -9,25 +10,21 @@

class SqlManager(SqlExecutionMixin):
def execute(
self, query: Optional[str], file: Optional[Path]
self, query: Optional[str], file: Optional[Path], std_in: bool
) -> List[SnowflakeCursor]:
sys_input = None
inputs = [query, file, std_in]
if not any(inputs):
raise UsageError("Use either query, filename or input option.")

if query and file:
raise UsageError("Both query and file provided, please specify only one.")

if not sys.stdin.isatty():
sys_input = sys.stdin.read()

if sys_input and (query or file):
# Check if any two inputs were provided simultaneously
if any(starmap(lambda *t: all(t), combinations(inputs, 2))):
sfc-gh-turbaszek marked this conversation as resolved.
Show resolved Hide resolved
raise UsageError(
"Can't use stdin input together with query or filename option."
"Multiple input sources specified. Please specify only one."
)

if not query and not file and not sys_input:
raise UsageError("Provide either query or filename argument")
elif sys_input:
sql = sys_input
else:
sql = query if query else file.read_text() # type: ignore
return self._execute_queries(sql)
if std_in:
query = sys.stdin.read()
elif file:
query = file.read_text()

return self._execute_queries(query)
26 changes: 14 additions & 12 deletions tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from tempfile import NamedTemporaryFile
from unittest import mock

import pytest

from tests.testing_utils.fixtures import *
from tests.testing_utils.result_assertions import assert_that_result_is_usage_error

Expand Down Expand Up @@ -37,7 +39,7 @@ def test_sql_execute_from_stdin(mock_connector, runner, mock_ctx, mock_cursor):
mock_connector.return_value = ctx
query = "query from input"

result = runner.invoke(["sql"], input=query)
result = runner.invoke(["sql", "-i"], input=query)

assert result.exit_code == 0
assert ctx.get_query() == query
Expand All @@ -47,26 +49,26 @@ def test_sql_fails_if_no_query_file_or_stdin(runner):
result = runner.invoke(["sql"])

assert_that_result_is_usage_error(
result, "Provide either query or filename argument"
result, "Use either query, filename or input option."
)


def test_sql_fails_for_both_stdin_and_other_query_source(runner):
@pytest.mark.parametrize("inputs", [("-i", "-q", "foo"), ("-i",), ("-q", "foo")])
def test_sql_fails_if_other_inputs_and_file_provided(runner, inputs):
with NamedTemporaryFile("r") as tmp_file:
result = runner.invoke(["sql", "-f", tmp_file.name], input="query from input")
result = runner.invoke(["sql", *inputs, "-f", tmp_file.name])
assert_that_result_is_usage_error(
result, "Multiple input sources specified. Please specify only one. "
)


def test_sql_fails_if_query_and_stdin_provided(runner):
result = runner.invoke(["sql", "-q", "fooo", "-i"])
assert_that_result_is_usage_error(
result, "Can't use stdin input together with query or filename"
result, "Multiple input sources specified. Please specify only one. "
)


def test_sql_fails_for_both_query_and_file(runner):
with NamedTemporaryFile("r") as tmp_file:
result = runner.invoke(["sql", "-f", tmp_file.name, "-q", "query"])

assert_that_result_is_usage_error(result, "Both query and file provided")


@mock.patch("snowcli.cli.common.cli_global_context.connect_to_snowflake")
def test_sql_overrides_connection_configuration(mock_conn, runner, mock_cursor):
mock_conn.return_value.execute_string.return_value = [mock_cursor(["row"], [])]
Expand Down
4 changes: 2 additions & 2 deletions tests/testing_utils/result_assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
def assert_that_result_is_usage_error(
result: Result, expected_error_message: str
) -> None:
assert result.exit_code == 2
assert expected_error_message in result.output
assert result.exit_code == 2, result.exit_code
assert expected_error_message in result.output, result.output
assert isinstance(result.exception, SystemExit)
assert "traceback" not in result.output.lower()
13 changes: 13 additions & 0 deletions tests_integration/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ def test_multi_queries_from_file(runner, snowflake_session, test_root_path):
]


@pytest.mark.integration
def test_multi_input_from_stdin(runner, snowflake_session, test_root_path):
result = runner.invoke_with_connection_json(
[
"sql",
"-i",
],
input="select 42;",
)
assert result.exit_code == 0
assert result.json == [{"42": 42}]


def _round_values(results):
for result in results:
for k, v in result.items():
Expand Down