From c9665b9bc33a291ce6edc83111bd165344e66861 Mon Sep 17 00:00:00 2001 From: HH Date: Wed, 30 Aug 2023 10:26:27 +0800 Subject: [PATCH 01/16] Add encoder/decoder in structureDataset for snowflake. Signed-off-by: HH --- flytekit/core/type_engine.py | 3 + flytekit/types/structured/__init__.py | 13 +++ flytekit/types/structured/snowflake.py | 116 +++++++++++++++++++++++++ 3 files changed, 132 insertions(+) create mode 100644 flytekit/types/structured/snowflake.py diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 9c48908f98..00c0b82e00 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -854,6 +854,7 @@ def lazy_import_transformers(cls): register_arrow_handlers, register_bigquery_handlers, register_pandas_handlers, + register_snowflake_handlers, ) if is_imported("tensorflow"): @@ -872,6 +873,8 @@ def lazy_import_transformers(cls): register_arrow_handlers() if is_imported("google.cloud.bigquery"): register_bigquery_handlers() + if is_imported("snowflake.connector"): + register_snowflake_handlers() if is_imported("numpy"): from flytekit.types import numpy # noqa: F401 diff --git a/flytekit/types/structured/__init__.py b/flytekit/types/structured/__init__.py index 543117c865..407961717f 100644 --- a/flytekit/types/structured/__init__.py +++ b/flytekit/types/structured/__init__.py @@ -70,3 +70,16 @@ def register_bigquery_handlers(): "We won't register bigquery handler for structured dataset because " "we can't find the packages google-cloud-bigquery-storage and google-cloud-bigquery" ) + + +def register_snowflake_handlers(): + try: + from .snowflake import PandasToSnowflakeEncodingHandlers, SnowflakeToPandasDecodingHandler + + StructuredDatasetTransformerEngine.register(SnowflakeToPandasDecodingHandler()) + StructuredDatasetTransformerEngine.register(PandasToSnowflakeEncodingHandlers()) + + except ImportError: + logger.info( + "We won't register snowflake handler for structured dataset because " "we can't find package snowflake" + ) diff --git a/flytekit/types/structured/snowflake.py b/flytekit/types/structured/snowflake.py new file mode 100644 index 0000000000..b359ce424d --- /dev/null +++ b/flytekit/types/structured/snowflake.py @@ -0,0 +1,116 @@ +import re +import typing + +import pandas as pd +import pyarrow as pa +import snowflake.connector +from snowflake.connector.pandas_tools import write_pandas + +from flytekit import FlyteContext +from flytekit.models import literals +from flytekit.models.types import StructuredDatasetType +from flytekit.types.structured.structured_dataset import ( + StructuredDataset, + StructuredDatasetDecoder, + StructuredDatasetEncoder, + StructuredDatasetMetadata, +) + +SNOWFLAKE = "snowflake" + + +def get_private_key(): + import os + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.hazmat.primitives.asymmetric import dsa + from cryptography.hazmat.primitives import serialization + import flytekit + + pk_path = flytekit.current_context().secrets.get_secrets_file(SNOWFLAKE, "rsa_key.p8") + + with open(pk_path, "rb") as key: + p_key= serialization.load_pem_private_key( + key.read(), + password=None, + backend=default_backend() + ) + + return p_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption()) + + +def _write_to_sf(structured_dataset: StructuredDataset): + if structured_dataset.uri is None: + raise ValueError("structured_dataset.uri cannot be None.") + + uri = structured_dataset.uri + _, user, account, database, schema, warehouse, table = re.split("\\/|://|:", uri) + df = structured_dataset.dataframe + + conn = snowflake.connector.connect( + user=user, + account=account, + private_key=get_private_key(), + database=database, + schema=schema, + warehouse=warehouse + ) + + cs = conn.cursor() + write_pandas(conn, df, table) + + +def _read_from_sf( + flyte_value: literals.StructuredDataset, current_task_metadata: StructuredDatasetMetadata +) -> pd.DataFrame: + if flyte_value.uri is None: + raise ValueError("structured_dataset.uri cannot be None.") + + uri = flyte_value.uri + _, user, account, database, schema, warehouse, table = re.split("\\/|://|:", uri) + + conn = snowflake.connector.connect( + user=user, + account=account, + private_key=get_private_key(), + database=database, + schema=schema, + warehouse=warehouse + ) + + cs = conn.cursor() + cs.execute(f"select * from {table}") + + return cs.fetch_pandas_all() + + +class PandasToSnowflakeEncodingHandlers(StructuredDatasetEncoder): + def __init__(self): + super().__init__(pd.DataFrame, SNOWFLAKE, supported_format="") + + def encode( + self, + ctx: FlyteContext, + structured_dataset: StructuredDataset, + structured_dataset_type: StructuredDatasetType, + ) -> literals.StructuredDataset: + _write_to_sf(structured_dataset) + return literals.StructuredDataset( + uri=typing.cast(str, structured_dataset.uri), metadata=StructuredDatasetMetadata(structured_dataset_type) + ) + + +class SnowflakeToPandasDecodingHandler(StructuredDatasetDecoder): + def __init__(self): + super().__init__(pd.DataFrame, SNOWFLAKE, supported_format="") + + def decode( + self, + ctx: FlyteContext, + flyte_value: literals.StructuredDataset, + current_task_metadata: StructuredDatasetMetadata, + ) -> pd.DataFrame: + return _read_from_sf(flyte_value, current_task_metadata) From 3789dfa12408efedea3dc72df9b4822c37c0b240 Mon Sep 17 00:00:00 2001 From: HH Date: Thu, 31 Aug 2023 12:21:37 +0800 Subject: [PATCH 02/16] add unit-test for snowflake structure dataset encoder/decoder Signed-off-by: HH --- .../structured_dataset/test_snowflake.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 tests/flytekit/unit/types/structured_dataset/test_snowflake.py diff --git a/tests/flytekit/unit/types/structured_dataset/test_snowflake.py b/tests/flytekit/unit/types/structured_dataset/test_snowflake.py new file mode 100644 index 0000000000..8ea85e9e17 --- /dev/null +++ b/tests/flytekit/unit/types/structured_dataset/test_snowflake.py @@ -0,0 +1,48 @@ +import mock +import pytest +import pandas as pd +from typing_extensions import Annotated + +from flytekit import StructuredDataset, kwtypes, task, workflow + +pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) +my_cols = kwtypes(Name=str, Age=int) + + +@task +def gen_df() -> Annotated[pd.DataFrame, my_cols, "parquet"]: + return pd_df + + +@task +def t1(df: pd.DataFrame) -> Annotated[StructuredDataset, my_cols]: + return StructuredDataset(dataframe=df, uri="snowflake://dummy_user:dummy_account/dummy_database/dummy_schema/dummy_warehouse/dummy_table") + + +@task +def t2(sd: Annotated[StructuredDataset, my_cols]) -> pd.DataFrame: + return sd.open(pd.DataFrame).all() + + +@workflow +def wf() -> pd.DataFrame: + df = gen_df() + sd = t1(df=df) + return t2(sd=sd) + + +@mock.patch("snowflake.connector.connect") +@pytest.mark.asyncio +async def test_sf_wf(mock_connect): + class mock_pages: + def to_dataframe(self): + return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + + class mock_rows: + pages = [mock_pages()] + + mock_connect_instance = mock_connect.return_value + mock_coursor_instance = mock_connect.cursor.return_value + mock_coursor_instance.fetch_pandas_all.return_value = mock_rows + + assert wf().equals(pd_df) From 153a6b4ea4a62060d8bb52fd8735b53635bb4d87 Mon Sep 17 00:00:00 2001 From: HH Date: Thu, 31 Aug 2023 17:31:57 +0800 Subject: [PATCH 03/16] add unit-test for snowflake structure dataset encoder/decoder Signed-off-by: HH --- flytekit/types/structured/snowflake.py | 34 ++++++------------- .../structured_dataset/test_snowflake.py | 15 ++++---- 2 files changed, 17 insertions(+), 32 deletions(-) diff --git a/flytekit/types/structured/snowflake.py b/flytekit/types/structured/snowflake.py index b359ce424d..9f28734f43 100644 --- a/flytekit/types/structured/snowflake.py +++ b/flytekit/types/structured/snowflake.py @@ -2,7 +2,6 @@ import typing import pandas as pd -import pyarrow as pa import snowflake.connector from snowflake.connector.pandas_tools import write_pandas @@ -20,26 +19,21 @@ def get_private_key(): - import os from cryptography.hazmat.backends import default_backend - from cryptography.hazmat.primitives.asymmetric import rsa - from cryptography.hazmat.primitives.asymmetric import dsa from cryptography.hazmat.primitives import serialization + import flytekit pk_path = flytekit.current_context().secrets.get_secrets_file(SNOWFLAKE, "rsa_key.p8") with open(pk_path, "rb") as key: - p_key= serialization.load_pem_private_key( - key.read(), - password=None, - backend=default_backend() - ) + p_key = serialization.load_pem_private_key(key.read(), password=None, backend=default_backend()) return p_key.private_bytes( encoding=serialization.Encoding.DER, format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption()) + encryption_algorithm=serialization.NoEncryption(), + ) def _write_to_sf(structured_dataset: StructuredDataset): @@ -51,15 +45,9 @@ def _write_to_sf(structured_dataset: StructuredDataset): df = structured_dataset.dataframe conn = snowflake.connector.connect( - user=user, - account=account, - private_key=get_private_key(), - database=database, - schema=schema, - warehouse=warehouse + user=user, account=account, private_key=get_private_key(), database=database, schema=schema, warehouse=warehouse ) - cs = conn.cursor() write_pandas(conn, df, table) @@ -73,18 +61,16 @@ def _read_from_sf( _, user, account, database, schema, warehouse, table = re.split("\\/|://|:", uri) conn = snowflake.connector.connect( - user=user, - account=account, - private_key=get_private_key(), - database=database, - schema=schema, - warehouse=warehouse + user=user, account=account, private_key=get_private_key(), database=database, schema=schema, warehouse=warehouse ) cs = conn.cursor() cs.execute(f"select * from {table}") - return cs.fetch_pandas_all() + dff = cs.fetch_pandas_all() + print("cs", cs) + print("dff", dff) + return dff class PandasToSnowflakeEncodingHandlers(StructuredDatasetEncoder): diff --git a/tests/flytekit/unit/types/structured_dataset/test_snowflake.py b/tests/flytekit/unit/types/structured_dataset/test_snowflake.py index 8ea85e9e17..0c88be40d5 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_snowflake.py +++ b/tests/flytekit/unit/types/structured_dataset/test_snowflake.py @@ -1,6 +1,6 @@ import mock -import pytest import pandas as pd +import pytest from typing_extensions import Annotated from flytekit import StructuredDataset, kwtypes, task, workflow @@ -16,7 +16,9 @@ def gen_df() -> Annotated[pd.DataFrame, my_cols, "parquet"]: @task def t1(df: pd.DataFrame) -> Annotated[StructuredDataset, my_cols]: - return StructuredDataset(dataframe=df, uri="snowflake://dummy_user:dummy_account/dummy_database/dummy_schema/dummy_warehouse/dummy_table") + return StructuredDataset( + dataframe=df, uri="snowflake://dummy_user:dummy_account/dummy_database/dummy_schema/dummy_warehouse/dummy_table" + ) @task @@ -34,15 +36,12 @@ def wf() -> pd.DataFrame: @mock.patch("snowflake.connector.connect") @pytest.mark.asyncio async def test_sf_wf(mock_connect): - class mock_pages: + class mock_dataframe: def to_dataframe(self): return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) - class mock_rows: - pages = [mock_pages()] - mock_connect_instance = mock_connect.return_value - mock_coursor_instance = mock_connect.cursor.return_value - mock_coursor_instance.fetch_pandas_all.return_value = mock_rows + mock_coursor_instance = mock_connect_instance.cursor.return_value + mock_coursor_instance.fetch_pandas_all.return_value = mock_dataframe().to_dataframe() assert wf().equals(pd_df) From 953c5a3b4d4a7e642a5c2d0f294d49fd5d1ae150 Mon Sep 17 00:00:00 2001 From: HH Date: Fri, 1 Sep 2023 12:27:18 +0800 Subject: [PATCH 04/16] let lazy_import_transformers force load the snowflake-connector Signed-off-by: HH --- flytekit/core/type_engine.py | 7 +++++-- flytekit/types/structured/__init__.py | 3 ++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 00c0b82e00..82c11f3bcd 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -873,11 +873,14 @@ def lazy_import_transformers(cls): register_arrow_handlers() if is_imported("google.cloud.bigquery"): register_bigquery_handlers() - if is_imported("snowflake.connector"): - register_snowflake_handlers() if is_imported("numpy"): from flytekit.types import numpy # noqa: F401 + try: + register_snowflake_handlers() + except ValueError as e: + logger.debug(f"Attempted to register the Snowflake handler but failed due to: {str(e)}") + @classmethod def to_literal_type(cls, python_type: Type) -> LiteralType: """ diff --git a/flytekit/types/structured/__init__.py b/flytekit/types/structured/__init__.py index 407961717f..617e4bcafa 100644 --- a/flytekit/types/structured/__init__.py +++ b/flytekit/types/structured/__init__.py @@ -81,5 +81,6 @@ def register_snowflake_handlers(): except ImportError: logger.info( - "We won't register snowflake handler for structured dataset because " "we can't find package snowflake" + "We won't register snowflake handler for structured dataset because " + "we can't find package snowflakee-connector-python" ) From d41f7dffde3e4409ddf3db3f59f44b1553835b32 Mon Sep 17 00:00:00 2001 From: HH Date: Fri, 1 Sep 2023 12:33:50 +0800 Subject: [PATCH 05/16] add mock get_private_key for unit-test Signed-off-by: HH --- tests/flytekit/unit/types/structured_dataset/test_snowflake.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/flytekit/unit/types/structured_dataset/test_snowflake.py b/tests/flytekit/unit/types/structured_dataset/test_snowflake.py index 0c88be40d5..c957c0bbce 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_snowflake.py +++ b/tests/flytekit/unit/types/structured_dataset/test_snowflake.py @@ -33,9 +33,10 @@ def wf() -> pd.DataFrame: return t2(sd=sd) +@mock.patch("flytekit.types.structured.snowflake.get_private_key", return_value="pb") @mock.patch("snowflake.connector.connect") @pytest.mark.asyncio -async def test_sf_wf(mock_connect): +async def test_sf_wf(mock_connect, mock_get_private_key): class mock_dataframe: def to_dataframe(self): return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) From b8cba8783bbe22f9f1b2d2136e8446df904c74e9 Mon Sep 17 00:00:00 2001 From: HH Date: Fri, 1 Sep 2023 12:36:33 +0800 Subject: [PATCH 06/16] add snowflake-connector-python in dev-requirements.in Signed-off-by: HH --- dev-requirements.in | 1 + 1 file changed, 1 insertion(+) diff --git a/dev-requirements.in b/dev-requirements.in index 7159812e26..5285290cfa 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -11,6 +11,7 @@ pre-commit codespell google-cloud-bigquery google-cloud-bigquery-storage +snowflake-connector-python IPython keyrings.alt From 8346fb986b7f32bab1a4ea46069cbcdb61ce77a1 Mon Sep 17 00:00:00 2001 From: HH Date: Wed, 13 Sep 2023 11:33:45 +0800 Subject: [PATCH 07/16] replace get_secrets_file with get Signed-off-by: HH --- flytekit/types/structured/snowflake.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/flytekit/types/structured/snowflake.py b/flytekit/types/structured/snowflake.py index 9f28734f43..70a2de6b6f 100644 --- a/flytekit/types/structured/snowflake.py +++ b/flytekit/types/structured/snowflake.py @@ -24,16 +24,9 @@ def get_private_key(): import flytekit - pk_path = flytekit.current_context().secrets.get_secrets_file(SNOWFLAKE, "rsa_key.p8") + pk_string = flytekit.current_context().secrets.get(SNOWFLAKE, "private_key") - with open(pk_path, "rb") as key: - p_key = serialization.load_pem_private_key(key.read(), password=None, backend=default_backend()) - - return p_key.private_bytes( - encoding=serialization.Encoding.DER, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ) + return serialization.load_pem_private_key(pk_string, password=None, backend=default_backend()) def _write_to_sf(structured_dataset: StructuredDataset): From 62f917093982e8b9a05e7b1963ae57626249e29f Mon Sep 17 00:00:00 2001 From: HH Date: Wed, 13 Sep 2023 11:54:22 +0800 Subject: [PATCH 08/16] fix small bugs in get private key Signed-off-by: HH --- flytekit/types/structured/snowflake.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/flytekit/types/structured/snowflake.py b/flytekit/types/structured/snowflake.py index 70a2de6b6f..4842a1c28e 100644 --- a/flytekit/types/structured/snowflake.py +++ b/flytekit/types/structured/snowflake.py @@ -24,9 +24,16 @@ def get_private_key(): import flytekit - pk_string = flytekit.current_context().secrets.get(SNOWFLAKE, "private_key") + pk_string = flytekit.current_context().secrets.get(TASK_TYPE, "private_key", encode_mode="rb") + p_key = serialization.load_pem_private_key(pk_string, password=None, backend=default_backend()) - return serialization.load_pem_private_key(pk_string, password=None, backend=default_backend()) + pkb = p_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + return pkb def _write_to_sf(structured_dataset: StructuredDataset): From a6858c0bcb29df20e372f4145cc2ccfa9adbd5f2 Mon Sep 17 00:00:00 2001 From: HH Date: Fri, 15 Sep 2023 08:49:34 +0800 Subject: [PATCH 09/16] fix the suggestions Signed-off-by: HH --- flytekit/types/structured/__init__.py | 2 +- flytekit/types/structured/snowflake.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/flytekit/types/structured/__init__.py b/flytekit/types/structured/__init__.py index 617e4bcafa..0783a99077 100644 --- a/flytekit/types/structured/__init__.py +++ b/flytekit/types/structured/__init__.py @@ -82,5 +82,5 @@ def register_snowflake_handlers(): except ImportError: logger.info( "We won't register snowflake handler for structured dataset because " - "we can't find package snowflakee-connector-python" + "we can't find package snowflake-connector-python" ) diff --git a/flytekit/types/structured/snowflake.py b/flytekit/types/structured/snowflake.py index 4842a1c28e..5af2a93430 100644 --- a/flytekit/types/structured/snowflake.py +++ b/flytekit/types/structured/snowflake.py @@ -67,10 +67,7 @@ def _read_from_sf( cs = conn.cursor() cs.execute(f"select * from {table}") - dff = cs.fetch_pandas_all() - print("cs", cs) - print("dff", dff) - return dff + return cs.fetch_pandas_all() class PandasToSnowflakeEncodingHandlers(StructuredDatasetEncoder): From ba3d2fcc33613d8f6a35644108b5216c4396c611 Mon Sep 17 00:00:00 2001 From: HH Date: Mon, 18 Sep 2023 00:32:03 +0800 Subject: [PATCH 10/16] fix the typo Signed-off-by: HH --- flytekit/types/structured/snowflake.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/types/structured/snowflake.py b/flytekit/types/structured/snowflake.py index 5af2a93430..050a77bfb8 100644 --- a/flytekit/types/structured/snowflake.py +++ b/flytekit/types/structured/snowflake.py @@ -24,7 +24,7 @@ def get_private_key(): import flytekit - pk_string = flytekit.current_context().secrets.get(TASK_TYPE, "private_key", encode_mode="rb") + pk_string = flytekit.current_context().secrets.get(SNOWFLAKE, "private_key", encode_mode="rb") p_key = serialization.load_pem_private_key(pk_string, password=None, backend=default_backend()) pkb = p_key.private_bytes( From 1b99e9015411828c6ed3819340915bb1a8287c44 Mon Sep 17 00:00:00 2001 From: HH Date: Thu, 28 Sep 2023 01:18:02 +0800 Subject: [PATCH 11/16] add extentsion Signed-off-by: HH --- flytekit/types/structured/snowflake.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/types/structured/snowflake.py b/flytekit/types/structured/snowflake.py index 050a77bfb8..152dc51b90 100644 --- a/flytekit/types/structured/snowflake.py +++ b/flytekit/types/structured/snowflake.py @@ -24,7 +24,7 @@ def get_private_key(): import flytekit - pk_string = flytekit.current_context().secrets.get(SNOWFLAKE, "private_key", encode_mode="rb") + pk_string = flytekit.current_context().secrets.get(SNOWFLAKE, "private_key.pem", encode_mode="rb") p_key = serialization.load_pem_private_key(pk_string, password=None, backend=default_backend()) pkb = p_key.private_bytes( From 3af4488128bf03b5b0feec4aba067dc938d00888 Mon Sep 17 00:00:00 2001 From: HH Date: Fri, 29 Sep 2023 20:20:03 +0800 Subject: [PATCH 12/16] add additional protocol to support multiple protocol Signed-off-by: HH --- flytekit/types/structured/snowflake.py | 10 ++++---- .../types/structured/structured_dataset.py | 24 ++++++++++++++++--- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/flytekit/types/structured/snowflake.py b/flytekit/types/structured/snowflake.py index 152dc51b90..dbdd7e0681 100644 --- a/flytekit/types/structured/snowflake.py +++ b/flytekit/types/structured/snowflake.py @@ -24,7 +24,7 @@ def get_private_key(): import flytekit - pk_string = flytekit.current_context().secrets.get(SNOWFLAKE, "private_key.pem", encode_mode="rb") + pk_string = flytekit.current_context().secrets.get("snowflake", "private_key.pem", encode_mode="rb") p_key = serialization.load_pem_private_key(pk_string, password=None, backend=default_backend()) pkb = p_key.private_bytes( @@ -41,7 +41,7 @@ def _write_to_sf(structured_dataset: StructuredDataset): raise ValueError("structured_dataset.uri cannot be None.") uri = structured_dataset.uri - _, user, account, database, schema, warehouse, table = re.split("\\/|://|:", uri) + _, user, account, warehouse, database, schema, table = re.split("\\/|://|:", uri) df = structured_dataset.dataframe conn = snowflake.connector.connect( @@ -58,7 +58,7 @@ def _read_from_sf( raise ValueError("structured_dataset.uri cannot be None.") uri = flyte_value.uri - _, user, account, database, schema, warehouse, table = re.split("\\/|://|:", uri) + _, user, account, warehouse, database, schema, table = re.split("\\/|://|:", uri) conn = snowflake.connector.connect( user=user, account=account, private_key=get_private_key(), database=database, schema=schema, warehouse=warehouse @@ -72,7 +72,7 @@ def _read_from_sf( class PandasToSnowflakeEncodingHandlers(StructuredDatasetEncoder): def __init__(self): - super().__init__(pd.DataFrame, SNOWFLAKE, supported_format="") + super().__init__(pd.DataFrame, SNOWFLAKE, supported_format="", additional_protocols=["sf"]) def encode( self, @@ -88,7 +88,7 @@ def encode( class SnowflakeToPandasDecodingHandler(StructuredDatasetDecoder): def __init__(self): - super().__init__(pd.DataFrame, SNOWFLAKE, supported_format="") + super().__init__(pd.DataFrame, SNOWFLAKE, supported_format="", additional_protocols=["sf"]) def decode( self, diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 99a0e0832b..03b5a6fcfc 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -5,7 +5,7 @@ import typing from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Dict, Generator, Optional, Type, Union +from typing import Dict, Generator, Optional, Type, Union, List import _datetime from dataclasses_json import config @@ -162,7 +162,7 @@ def extract_cols_and_format( class StructuredDatasetEncoder(ABC): - def __init__(self, python_type: Type[T], protocol: Optional[str] = None, supported_format: Optional[str] = None): + def __init__(self, python_type: Type[T], protocol: Optional[str] = None, supported_format: Optional[str] = None, additional_protocols: Optional[List[str]] = None): """ Extend this abstract class, implement the encode function, and register your concrete class with the StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle @@ -178,9 +178,11 @@ def __init__(self, python_type: Type[T], protocol: Optional[str] = None, support :param supported_format: Arbitrary string representing the format. If not supplied then an empty string will be used. An empty string implies that the encoder works with any format. If the format being asked for does not exist, the transformer enginer will look for the "" endcoder instead and write a warning. + :param additional_protocols: Support many protocols to let user is able to connect to the service with various options. """ self._python_type = python_type self._protocol = protocol.replace("://", "") if protocol else None + self._additional_protocols = [additional_protocol.replace("://", "") for additional_protocol in additional_protocols] if additional_protocols else None self._supported_format = supported_format or "" @property @@ -191,6 +193,10 @@ def python_type(self) -> Type[T]: def protocol(self) -> Optional[str]: return self._protocol + @property + def additional_protocols(self) -> Optional[List[str]]: + return self._additional_protocols + @property def supported_format(self) -> str: return self._supported_format @@ -224,7 +230,7 @@ def encode( class StructuredDatasetDecoder(ABC): - def __init__(self, python_type: Type[DF], protocol: Optional[str] = None, supported_format: Optional[str] = None): + def __init__(self, python_type: Type[DF], protocol: Optional[str] = None, supported_format: Optional[str] = None, additional_protocols: Optional[List[str]] = None): """ Extend this abstract class, implement the decode function, and register your concrete class with the StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle @@ -239,9 +245,11 @@ def __init__(self, python_type: Type[DF], protocol: Optional[str] = None, suppor :param supported_format: Arbitrary string representing the format. If not supplied then an empty string will be used. An empty string implies that the decoder works with any format. If the format being asked for does not exist, the transformer enginer will look for the "" decoder instead and write a warning. + :param additional_protocols: Support many protocols to let user is able to connect to the service with various options. """ self._python_type = python_type self._protocol = protocol.replace("://", "") if protocol else None + self._additional_protocols = [additional_protocol.replace("://", "") for additional_protocol in additional_protocols] if additional_protocols else None self._supported_format = supported_format or "" @property @@ -252,6 +260,10 @@ def python_type(self) -> Type[DF]: def protocol(self) -> Optional[str]: return self._protocol + @property + def additional_protocols(self) -> Optional[List[str]]: + return self._additional_protocols + @property def supported_format(self) -> str: return self._supported_format @@ -470,6 +482,12 @@ def register( h, h.protocol, default_for_type, override, default_format_for_type, default_storage_for_type ) + if h.additional_protocols is not None: + for additional_protocol in h.additional_protocols: + cls.register_for_protocol( + h, additional_protocol, default_for_type, override, default_format_for_type, default_storage_for_type + ) + @classmethod def register_for_protocol( cls, From f4a3c3b0ed3c02bda3bc6a83240cdb70174be282 Mon Sep 17 00:00:00 2001 From: HH Date: Fri, 29 Sep 2023 22:42:15 +0800 Subject: [PATCH 13/16] fix lint Signed-off-by: HH --- .../types/structured/structured_dataset.py | 37 ++++++++++++++++--- 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 03b5a6fcfc..25371463ea 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -5,7 +5,7 @@ import typing from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Dict, Generator, Optional, Type, Union, List +from typing import Dict, Generator, List, Optional, Type, Union import _datetime from dataclasses_json import config @@ -162,7 +162,13 @@ def extract_cols_and_format( class StructuredDatasetEncoder(ABC): - def __init__(self, python_type: Type[T], protocol: Optional[str] = None, supported_format: Optional[str] = None, additional_protocols: Optional[List[str]] = None): + def __init__( + self, + python_type: Type[T], + protocol: Optional[str] = None, + supported_format: Optional[str] = None, + additional_protocols: Optional[List[str]] = None, + ): """ Extend this abstract class, implement the encode function, and register your concrete class with the StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle @@ -182,7 +188,11 @@ def __init__(self, python_type: Type[T], protocol: Optional[str] = None, support """ self._python_type = python_type self._protocol = protocol.replace("://", "") if protocol else None - self._additional_protocols = [additional_protocol.replace("://", "") for additional_protocol in additional_protocols] if additional_protocols else None + self._additional_protocols = ( + [additional_protocol.replace("://", "") for additional_protocol in additional_protocols] + if additional_protocols + else None + ) self._supported_format = supported_format or "" @property @@ -230,7 +240,13 @@ def encode( class StructuredDatasetDecoder(ABC): - def __init__(self, python_type: Type[DF], protocol: Optional[str] = None, supported_format: Optional[str] = None, additional_protocols: Optional[List[str]] = None): + def __init__( + self, + python_type: Type[DF], + protocol: Optional[str] = None, + supported_format: Optional[str] = None, + additional_protocols: Optional[List[str]] = None, + ): """ Extend this abstract class, implement the decode function, and register your concrete class with the StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle @@ -249,7 +265,11 @@ def __init__(self, python_type: Type[DF], protocol: Optional[str] = None, suppor """ self._python_type = python_type self._protocol = protocol.replace("://", "") if protocol else None - self._additional_protocols = [additional_protocol.replace("://", "") for additional_protocol in additional_protocols] if additional_protocols else None + self._additional_protocols = ( + [additional_protocol.replace("://", "") for additional_protocol in additional_protocols] + if additional_protocols + else None + ) self._supported_format = supported_format or "" @property @@ -485,7 +505,12 @@ def register( if h.additional_protocols is not None: for additional_protocol in h.additional_protocols: cls.register_for_protocol( - h, additional_protocol, default_for_type, override, default_format_for_type, default_storage_for_type + h, + additional_protocol, + default_for_type, + override, + default_format_for_type, + default_storage_for_type, ) @classmethod From d64cd7a98093f4f314dddfc51a4e68b7cac5426e Mon Sep 17 00:00:00 2001 From: HH Date: Fri, 29 Sep 2023 22:46:59 +0800 Subject: [PATCH 14/16] fix uri in test Signed-off-by: HH --- tests/flytekit/unit/types/structured_dataset/test_snowflake.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/flytekit/unit/types/structured_dataset/test_snowflake.py b/tests/flytekit/unit/types/structured_dataset/test_snowflake.py index c957c0bbce..38078021ea 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_snowflake.py +++ b/tests/flytekit/unit/types/structured_dataset/test_snowflake.py @@ -17,7 +17,7 @@ def gen_df() -> Annotated[pd.DataFrame, my_cols, "parquet"]: @task def t1(df: pd.DataFrame) -> Annotated[StructuredDataset, my_cols]: return StructuredDataset( - dataframe=df, uri="snowflake://dummy_user:dummy_account/dummy_database/dummy_schema/dummy_warehouse/dummy_table" + dataframe=df, uri="snowflake://dummy_user:dummy_account/dummy_warehouse/dummy_database/dummy_schema/dummy_table" ) From 0cdf9f639c106698428d83882a20b0ac97002553 Mon Sep 17 00:00:00 2001 From: HH Date: Sat, 30 Sep 2023 23:07:15 +0800 Subject: [PATCH 15/16] test .pem error Signed-off-by: HH --- plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py index c4176228ea..e172ebe8ac 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py @@ -81,7 +81,7 @@ async def async_create( name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() } native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs) - logger.info(f"Create Snowflake params with inputs: {native_inputs}") + logger.info(f"Create Snowflake Agent params with inputs: {native_inputs}") params = native_inputs config = task_template.config From 0b4787217481f63e31268721f0d05e28c6e4bf25 Mon Sep 17 00:00:00 2001 From: HH Date: Sat, 30 Sep 2023 23:21:14 +0800 Subject: [PATCH 16/16] test .pem error Signed-off-by: HH --- flytekit/types/structured/snowflake.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/types/structured/snowflake.py b/flytekit/types/structured/snowflake.py index dbdd7e0681..7dd04ad9f9 100644 --- a/flytekit/types/structured/snowflake.py +++ b/flytekit/types/structured/snowflake.py @@ -24,7 +24,7 @@ def get_private_key(): import flytekit - pk_string = flytekit.current_context().secrets.get("snowflake", "private_key.pem", encode_mode="rb") + pk_string = flytekit.current_context().secrets.get("snowflake", "private_key", encode_mode="rb") p_key = serialization.load_pem_private_key(pk_string, password=None, backend=default_backend()) pkb = p_key.private_bytes(