Skip to content

Commit

Permalink
[SNOW-1758029] Removed necessity of requirements.txt for git execute (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pjob authored and sfc-gh-turbaszek committed Oct 23, 2024
1 parent 14cdf3d commit cb61b02
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 13 deletions.
2 changes: 2 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
* `snow spcs image-repository list-images` now displays image tag and digest.
* Fix `snow stage list-files` for paths with directories.
* `snow --info` callback returns information about `SNOWFLAKE_HOME` variable.
* Removed requirement of existence of any `requirements.txt` file for Python code execution via `snow git execute` command.
Before the fix the file (even empty) was required to make the execution working.

# v3.0.2

Expand Down
10 changes: 4 additions & 6 deletions src/snowflake/cli/_plugins/stage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,8 @@ def execute(
stage_path = self.build_path(stage_path_str)

all_files_list = self._get_files_list_from_stage(stage_path.root_path())
if not all_files_list:
raise ClickException(f"No files found on stage '{stage_path}'")

all_files_with_stage_name_prefix = [
stage_path_parts.get_directory(file) for file in all_files_list
Expand Down Expand Up @@ -480,10 +482,6 @@ def _get_files_list_from_stage(
self, stage_path: StagePath, pattern: str | None = None
) -> List[str]:
files_list_result = self.list_files(stage_path, pattern=pattern).fetchall()

if not files_list_result:
raise ClickException(f"No files found on stage '{stage_path}'")

return [f["name"] for f in files_list_result]

def _filter_files_list(
Expand Down Expand Up @@ -631,7 +629,7 @@ def _bootstrap_snowpark_execution_environment(self, stage_path: StagePath):
requirements = self._check_for_requirements_file(stage_path)
self.snowpark_session.add_packages(*requirements)

@sproc(is_permanent=False)
@sproc(is_permanent=False, session=self.snowpark_session)
def _python_execution_procedure(
_: Session, file_path: str, variables: Dict | None = None
) -> None:
Expand Down Expand Up @@ -668,7 +666,7 @@ def _execute_python(
from snowflake.snowpark.exceptions import SnowparkSQLException

try:
self._python_exe_procedure(self.get_standard_stage_prefix(file_stage_path), variables) # type: ignore
self._python_exe_procedure(self.get_standard_stage_prefix(file_stage_path), variables, session=self.snowpark_session) # type: ignore
return StageManager._success_result(file=original_file)
except SnowparkSQLException as e:
StageManager._handle_execution_exception(on_error=on_error, exception=e)
Expand Down
27 changes: 20 additions & 7 deletions tests/stage/test_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,8 +867,11 @@ def test_execute_from_user_stage(

@mock.patch(f"{STAGE_MANAGER}._execute_query")
@mock.patch(f"{STAGE_MANAGER}._bootstrap_snowpark_execution_environment")
@mock.patch(f"{STAGE_MANAGER}.snowpark_session")
@skip_python_3_12
def test_execute_with_variables(mock_bootstrap, mock_execute, mock_cursor, runner):
def test_execute_with_variables(
mock_snowpark_session, mock_bootstrap, mock_execute, mock_cursor, runner
):
mock_execute.return_value = mock_cursor(
[{"name": "exe/s1.sql"}, {"name": "exe/s2.py"}], []
)
Expand Down Expand Up @@ -907,6 +910,7 @@ def test_execute_with_variables(mock_bootstrap, mock_execute, mock_cursor, runne
"key4": "NULL",
"key5": "var=value",
},
session=mock_snowpark_session,
)


Expand Down Expand Up @@ -1000,8 +1004,11 @@ def test_execute_no_files_for_stage_path(

@mock.patch(f"{STAGE_MANAGER}._execute_query")
@mock.patch(f"{STAGE_MANAGER}._bootstrap_snowpark_execution_environment")
@mock.patch(f"{STAGE_MANAGER}.snowpark_session")
@skip_python_3_12
def test_execute_stop_on_error(mock_bootstrap, mock_execute, mock_cursor, runner):
def test_execute_stop_on_error(
mock_snowpark_session, mock_bootstrap, mock_execute, mock_cursor, runner
):
error_message = "Error"
mock_execute.side_effect = [
mock_cursor(
Expand All @@ -1027,17 +1034,23 @@ def test_execute_stop_on_error(mock_bootstrap, mock_execute, mock_cursor, runner
mock.call(f"execute immediate from @exe/s2.sql"),
]
assert mock_bootstrap.return_value.mock_calls == [
mock.call("@exe/p1.py", {}),
mock.call("@exe/p2.py", {}),
mock.call("@exe/p1.py", {}, session=mock_snowpark_session),
mock.call("@exe/p2.py", {}, session=mock_snowpark_session),
]
assert error_message in result.output


@mock.patch(f"{STAGE_MANAGER}._execute_query")
@mock.patch(f"{STAGE_MANAGER}._bootstrap_snowpark_execution_environment")
@mock.patch(f"{STAGE_MANAGER}.snowpark_session")
@skip_python_3_12
def test_execute_continue_on_error(
mock_bootstrap, mock_execute, mock_cursor, runner, os_agnostic_snapshot
mock_snowpark_session,
mock_bootstrap,
mock_execute,
mock_cursor,
runner,
os_agnostic_snapshot,
):
from snowflake.snowpark.exceptions import SnowparkSQLException

Expand Down Expand Up @@ -1071,8 +1084,8 @@ def test_execute_continue_on_error(
]

assert mock_bootstrap.return_value.mock_calls == [
mock.call("@exe/p1.py", {}),
mock.call("@exe/p2.py", {}),
mock.call("@exe/p1.py", {}, session=mock_snowpark_session),
mock.call("@exe/p2.py", {}, session=mock_snowpark_session),
]


Expand Down
9 changes: 9 additions & 0 deletions tests_integration/__snapshots__/test_git.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@
}),
])
# ---
# name: test_git_execute_python_without_requirements
list([
dict({
'Error': None,
'File': '@snowcli_testing_repo/branches/main/tests_integration/test_data/projects/stage_execute_without_requirements/script_template.py',
'Status': 'SUCCESS',
}),
])
# ---
# name: test_execute_with_name_in_pascal_case
list([
dict({
Expand Down
9 changes: 9 additions & 0 deletions tests_integration/__snapshots__/test_stage.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,15 @@
}),
])
# ---
# name: test_stage_execute_python_without_requirements
list([
dict({
'Error': None,
'File': '@test_stage_execute_without_requirements/script_template.py',
'Status': 'SUCCESS',
}),
])
# ---
# name: test_user_stage_execute
list([
dict({
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import os
from snowflake.core import Root
from snowflake.core.database import DatabaseResource
from snowflake.core.schema import Schema
from snowflake.snowpark.session import Session

session = Session.builder.getOrCreate()
database: DatabaseResource = Root(session).databases[os.environ["test_database_name"]]

assert database.name.upper() == os.environ["test_database_name"].upper()

# Make a side effect that we can check in tests
database.schemas.create(Schema(name=os.environ["TEST_ID"]))
41 changes: 41 additions & 0 deletions tests_integration/test_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time

import pytest
from snowflake.connector.errors import ProgrammingError
from pathlib import Path
import tempfile

from tests_integration.test_utils import contains_row_with

FILE_IN_REPO = "RELEASE-NOTES.md"


Expand Down Expand Up @@ -302,6 +306,43 @@ def test_execute_python(runner, test_database, sf_git_repository, snapshot):
assert result.json == snapshot


@pytest.mark.integration
@pytest.mark.skipif(
sys.version_info >= (3, 12), reason="Snowpark is not supported in Python >= 3.12"
)
@pytest.mark.skip(
"Requires merging changes to the main branch"
) # TODO: remove after merging to the main branch
def test_git_execute_python_without_requirements(
snowflake_session,
runner,
test_database,
test_root_path,
snapshot,
sf_git_repository,
):
test_id = f"FOO{time.time_ns()}"
result = runner.invoke_with_connection_json(
[
"git",
"execute",
f"@{sf_git_repository.lower()}/branches/main/tests_integration/test_data/projects/stage_execute_without_requirements",
"-D",
f"test_database_name={test_database}",
"-D",
f"TEST_ID={test_id}",
]
)
assert result.exit_code == 0
assert result.json == snapshot

# Assert side effect created by executed script
*_, schemas = snowflake_session.execute_string(
f"show schemas like '{test_id}' in database {test_database};"
)
assert len(list(schemas)) == 1


@pytest.mark.integration
def test_execute_fqn_repo(runner, test_database, sf_git_repository):
result_fqn = runner.invoke_with_connection_json(
Expand Down
51 changes: 51 additions & 0 deletions tests_integration/test_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,57 @@ def test_stage_execute_python(
assert len(list(schemas)) == 1


@pytest.mark.integration
@pytest.mark.skipif(
sys.version_info >= (3, 12), reason="Snowpark is not supported in Python >= 3.12"
)
def test_stage_execute_python_without_requirements(
snowflake_session, runner, test_database, test_root_path, snapshot
):
project_path = (
test_root_path / "test_data/projects/stage_execute_without_requirements"
)
stage_name = "test_stage_execute_without_requirements"

result = runner.invoke_with_connection_json(["stage", "create", stage_name])
assert contains_row_with(
result.json,
{"status": f"Stage area {stage_name.upper()} successfully created."},
)

result = runner.invoke_with_connection_json(
[
"stage",
"copy",
str(Path(project_path) / "script_template.py"),
f"@{stage_name}",
]
)
assert result.exit_code == 0, result.output
assert contains_row_with(result.json, {"status": "UPLOADED"})

test_id = f"FOO{time.time_ns()}"
result = runner.invoke_with_connection_json(
[
"stage",
"execute",
f"{stage_name}/",
"-D",
f"test_database_name={test_database}",
"-D",
f"TEST_ID={test_id}",
]
)
assert result.exit_code == 0
assert result.json == snapshot

# Assert side effect created by executed script
*_, schemas = snowflake_session.execute_string(
f"show schemas like '{test_id}' in database {test_database};"
)
assert len(list(schemas)) == 1


@pytest.mark.integration
def test_stage_diff(runner, snowflake_session, test_database, tmp_path, snapshot):
stage_name = "test_stage"
Expand Down

0 comments on commit cb61b02

Please sign in to comment.