Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add encoder/decoder in structureDataset for snowflake. #1811

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dev-requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pre-commit
codespell
google-cloud-bigquery
google-cloud-bigquery-storage
snowflake-connector-python
IPython
keyrings.alt

Expand Down
6 changes: 6 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,7 @@ def lazy_import_transformers(cls):
register_arrow_handlers,
register_bigquery_handlers,
register_pandas_handlers,
register_snowflake_handlers,
)

if is_imported("tensorflow"):
Expand All @@ -875,6 +876,11 @@ def lazy_import_transformers(cls):
if is_imported("numpy"):
from flytekit.types import numpy # noqa: F401

try:
register_snowflake_handlers()
except ValueError as e:
logger.debug(f"Attempted to register the Snowflake handler but failed due to: {str(e)}")

@classmethod
def to_literal_type(cls, python_type: Type) -> LiteralType:
"""
Expand Down
14 changes: 14 additions & 0 deletions flytekit/types/structured/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,17 @@ def register_bigquery_handlers():
"We won't register bigquery handler for structured dataset because "
"we can't find the packages google-cloud-bigquery-storage and google-cloud-bigquery"
)


def register_snowflake_handlers():
    """Register the Snowflake encoder/decoder pair with the structured-dataset engine.

    Registration is skipped (with an info log) when the optional
    ``snowflake-connector-python`` dependency is not installed.
    """
    try:
        from .snowflake import PandasToSnowflakeEncodingHandlers, SnowflakeToPandasDecodingHandler
    except ImportError:
        logger.info(
            "We won't register snowflake handler for structured dataset because "
            "we can't find package snowflake-connector-python"
        )
        return

    # Keep the register() calls outside the try so an ImportError raised during
    # registration (e.g. some other missing package) is not misreported as a
    # missing snowflake connector.
    StructuredDatasetTransformerEngine.register(SnowflakeToPandasDecodingHandler())
    StructuredDatasetTransformerEngine.register(PandasToSnowflakeEncodingHandlers())
99 changes: 99 additions & 0 deletions flytekit/types/structured/snowflake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import re
import typing

import pandas as pd
import snowflake.connector
from snowflake.connector.pandas_tools import write_pandas

from flytekit import FlyteContext
from flytekit.models import literals
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.structured_dataset import (
StructuredDataset,
StructuredDatasetDecoder,
StructuredDatasetEncoder,
StructuredDatasetMetadata,
)

SNOWFLAKE = "snowflake"


def get_private_key():
    """Fetch the Snowflake private key from the Flyte secrets store.

    Reads the PEM-encoded key from the ``snowflake``/``private_key`` secret and
    returns it re-serialized as DER-encoded, unencrypted PKCS#8 bytes — the form
    expected by ``snowflake.connector.connect``.
    """
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization

    import flytekit

    pem_bytes = flytekit.current_context().secrets.get("snowflake", "private_key", encode_mode="rb")
    private_key = serialization.load_pem_private_key(pem_bytes, password=None, backend=default_backend())

    return private_key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )


def _write_to_sf(structured_dataset: StructuredDataset):
    """Write the dataset's dataframe into the Snowflake table named by its URI.

    The URI is expected to look like
    ``snowflake://user:account/warehouse/database/schema/table``.

    :param structured_dataset: dataset whose ``dataframe`` is uploaded to the
        table addressed by ``uri``.
    :raises ValueError: if ``structured_dataset.uri`` is ``None``.
    """
    if structured_dataset.uri is None:
        raise ValueError("structured_dataset.uri cannot be None.")

    uri = structured_dataset.uri
    # Split on "://", ":" and "/" to unpack the URI components.
    _, user, account, warehouse, database, schema, table = re.split("\\/|://|:", uri)
    df = structured_dataset.dataframe

    conn = snowflake.connector.connect(
        user=user, account=account, private_key=get_private_key(), database=database, schema=schema, warehouse=warehouse
    )
    # Close the connection even if the upload fails (the original leaked it).
    try:
        write_pandas(conn, df, table)
    finally:
        conn.close()


def _read_from_sf(
    flyte_value: literals.StructuredDataset, current_task_metadata: StructuredDatasetMetadata
) -> pd.DataFrame:
    """Read the Snowflake table named by the literal's URI into a pandas DataFrame.

    The URI is expected to look like
    ``snowflake://user:account/warehouse/database/schema/table``.

    :param flyte_value: structured-dataset literal carrying the Snowflake URI.
    :param current_task_metadata: decoder-protocol argument; not used here.
    :raises ValueError: if ``flyte_value.uri`` is ``None``.
    """
    if flyte_value.uri is None:
        raise ValueError("structured_dataset.uri cannot be None.")

    uri = flyte_value.uri
    _, user, account, warehouse, database, schema, table = re.split("\\/|://|:", uri)

    conn = snowflake.connector.connect(
        user=user, account=account, private_key=get_private_key(), database=database, schema=schema, warehouse=warehouse
    )
    # Close the cursor and connection even on failure (the original leaked both).
    try:
        cs = conn.cursor()
        try:
            # NOTE(review): `table` comes straight from the URI and is interpolated
            # into SQL — consider validating/quoting the identifier upstream.
            cs.execute(f"select * from {table}")
            return cs.fetch_pandas_all()
        finally:
            cs.close()
    finally:
        conn.close()


class PandasToSnowflakeEncodingHandlers(StructuredDatasetEncoder):
    """Encoder that persists a pandas DataFrame into a Snowflake table."""

    def __init__(self):
        # Any format ("" wildcard); also registered under the short "sf" scheme.
        super().__init__(pd.DataFrame, SNOWFLAKE, supported_format="", additional_protocols=["sf"])

    def encode(
        self,
        ctx: FlyteContext,
        structured_dataset: StructuredDataset,
        structured_dataset_type: StructuredDatasetType,
    ) -> literals.StructuredDataset:
        """Upload the dataframe to the table named by ``structured_dataset.uri``
        and return the literal pointing at it."""
        _write_to_sf(structured_dataset)
        uri = typing.cast(str, structured_dataset.uri)
        metadata = StructuredDatasetMetadata(structured_dataset_type)
        return literals.StructuredDataset(uri=uri, metadata=metadata)


class SnowflakeToPandasDecodingHandler(StructuredDatasetDecoder):
    """Decoder that materializes a Snowflake table as a pandas DataFrame."""

    def __init__(self):
        # Any format ("" wildcard); also registered under the short "sf" scheme.
        super().__init__(pd.DataFrame, SNOWFLAKE, supported_format="", additional_protocols=["sf"])

    def decode(
        self,
        ctx: FlyteContext,
        flyte_value: literals.StructuredDataset,
        current_task_metadata: StructuredDatasetMetadata,
    ) -> pd.DataFrame:
        """Fetch the table addressed by ``flyte_value.uri`` as a dataframe."""
        frame = _read_from_sf(flyte_value, current_task_metadata)
        return frame
49 changes: 46 additions & 3 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import typing
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict, Generator, Optional, Type, Union
from typing import Dict, Generator, List, Optional, Type, Union

import _datetime
from dataclasses_json import config
Expand Down Expand Up @@ -162,7 +162,13 @@ def extract_cols_and_format(


class StructuredDatasetEncoder(ABC):
def __init__(self, python_type: Type[T], protocol: Optional[str] = None, supported_format: Optional[str] = None):
def __init__(
self,
python_type: Type[T],
protocol: Optional[str] = None,
supported_format: Optional[str] = None,
additional_protocols: Optional[List[str]] = None,
):
"""
Extend this abstract class, implement the encode function, and register your concrete class with the
StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle
Expand All @@ -178,9 +184,15 @@ def __init__(self, python_type: Type[T], protocol: Optional[str] = None, support
:param supported_format: Arbitrary string representing the format. If not supplied then an empty string
will be used. An empty string implies that the encoder works with any format. If the format being asked
for does not exist, the transformer engine will look for the "" encoder instead and write a warning.
:param additional_protocols: Additional protocols this handler supports, letting users connect to the service in multiple ways.
"""
self._python_type = python_type
self._protocol = protocol.replace("://", "") if protocol else None
self._additional_protocols = (
[additional_protocol.replace("://", "") for additional_protocol in additional_protocols]
if additional_protocols
else None
)
self._supported_format = supported_format or ""

@property
Expand All @@ -191,6 +203,10 @@ def python_type(self) -> Type[T]:
def protocol(self) -> Optional[str]:
return self._protocol

@property
def additional_protocols(self) -> Optional[List[str]]:
return self._additional_protocols

@property
def supported_format(self) -> str:
return self._supported_format
Expand Down Expand Up @@ -224,7 +240,13 @@ def encode(


class StructuredDatasetDecoder(ABC):
def __init__(self, python_type: Type[DF], protocol: Optional[str] = None, supported_format: Optional[str] = None):
def __init__(
self,
python_type: Type[DF],
protocol: Optional[str] = None,
supported_format: Optional[str] = None,
additional_protocols: Optional[List[str]] = None,
):
"""
Extend this abstract class, implement the decode function, and register your concrete class with the
StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle
Expand All @@ -239,9 +261,15 @@ def __init__(self, python_type: Type[DF], protocol: Optional[str] = None, suppor
:param supported_format: Arbitrary string representing the format. If not supplied then an empty string
will be used. An empty string implies that the decoder works with any format. If the format being asked
for does not exist, the transformer engine will look for the "" decoder instead and write a warning.
:param additional_protocols: Additional protocols this handler supports, letting users connect to the service in multiple ways.
"""
self._python_type = python_type
self._protocol = protocol.replace("://", "") if protocol else None
self._additional_protocols = (
[additional_protocol.replace("://", "") for additional_protocol in additional_protocols]
if additional_protocols
else None
)
self._supported_format = supported_format or ""

@property
Expand All @@ -252,6 +280,10 @@ def python_type(self) -> Type[DF]:
def protocol(self) -> Optional[str]:
return self._protocol

@property
def additional_protocols(self) -> Optional[List[str]]:
return self._additional_protocols

@property
def supported_format(self) -> str:
return self._supported_format
Expand Down Expand Up @@ -470,6 +502,17 @@ def register(
h, h.protocol, default_for_type, override, default_format_for_type, default_storage_for_type
)

if h.additional_protocols is not None:
for additional_protocol in h.additional_protocols:
cls.register_for_protocol(
h,
additional_protocol,
default_for_type,
override,
default_format_for_type,
default_storage_for_type,
)

@classmethod
def register_for_protocol(
cls,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async def async_create(
name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items()
}
native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs)
logger.info(f"Create Snowflake params with inputs: {native_inputs}")
logger.info(f"Create Snowflake Agent params with inputs: {native_inputs}")
params = native_inputs

config = task_template.config
Expand Down
48 changes: 48 additions & 0 deletions tests/flytekit/unit/types/structured_dataset/test_snowflake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import mock
import pandas as pd
import pytest
from typing_extensions import Annotated

from flytekit import StructuredDataset, kwtypes, task, workflow

pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
my_cols = kwtypes(Name=str, Age=int)


@task
def gen_df() -> Annotated[pd.DataFrame, my_cols, "parquet"]:
    """Produce the module-level fixture dataframe (two rows, Name/Age columns)."""
    return pd_df


@task
def t1(df: pd.DataFrame) -> Annotated[StructuredDataset, my_cols]:
    """Wrap the dataframe in a StructuredDataset addressed at the dummy Snowflake table."""
    uri = "snowflake://dummy_user:dummy_account/dummy_warehouse/dummy_database/dummy_schema/dummy_table"
    return StructuredDataset(dataframe=df, uri=uri)


@task
def t2(sd: Annotated[StructuredDataset, my_cols]) -> pd.DataFrame:
    """Open the structured dataset and materialize it as a pandas DataFrame."""
    return sd.open(pd.DataFrame).all()


@workflow
def wf() -> pd.DataFrame:
    """Round-trip: generate a dataframe, encode it to Snowflake, decode it back."""
    generated = gen_df()
    stored = t1(df=generated)
    return t2(sd=stored)


@mock.patch("flytekit.types.structured.snowflake.get_private_key", return_value="pb")
@mock.patch("snowflake.connector.connect")
@pytest.mark.asyncio
async def test_sf_wf(mock_connect, mock_get_private_key):
class mock_dataframe:
def to_dataframe(self):
return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})

mock_connect_instance = mock_connect.return_value
mock_coursor_instance = mock_connect_instance.cursor.return_value
mock_coursor_instance.fetch_pandas_all.return_value = mock_dataframe().to_dataframe()

assert wf().equals(pd_df)