Ignore duplicate handler errors when lazy loading (flyteorg#2316)
If the user registers a custom structured dataset encoder/decoder before the lazy import runs for the first time, registration of the default handlers will fail because they are not registered with override=True. flytekit should swallow those errors.

Signed-off-by: Yee Hing Tong <[email protected]>
wild-endeavor authored Apr 2, 2024
1 parent 66a6018 commit 2cbdc99
Showing 3 changed files with 51 additions and 4 deletions.
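
The scenario the commit message describes can be sketched as follows. This is a minimal illustration rather than part of the change; MyPandasToBQEncoder is a hypothetical user encoder that mirrors the PrivatePandasToBQEncodingHandlers class added in the new test further down.

import typing

import pandas as pd

from flytekit.core.type_engine import TypeEngine
from flytekit.models import literals
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.types.structured.structured_dataset import (
    StructuredDatasetEncoder,
    StructuredDatasetTransformerEngine,
)


class MyPandasToBQEncoder(StructuredDatasetEncoder):  # hypothetical user encoder
    def __init__(self):
        super().__init__(pd.DataFrame, "bq", supported_format="")

    def encode(self, ctx, structured_dataset, structured_dataset_type):
        return literals.StructuredDataset(
            uri=typing.cast(str, structured_dataset.uri),
            metadata=StructuredDatasetMetadata(structured_dataset_type),
        )


# The user registers their own handler first (override=True replaces the slot)...
StructuredDatasetTransformerEngine.register(MyPandasToBQEncoder(), override=True)
# ...and only afterwards does flytekit's lazy import register the defaults.
# Before this commit that second registration raised DuplicateHandlerError;
# with this change the error is caught and logged at debug level.
TypeEngine.lazy_import_transformers()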
16 changes: 13 additions & 3 deletions flytekit/core/type_engine.py
@@ -1044,6 +1044,7 @@ def lazy_import_transformers(cls):
             register_bigquery_handlers,
             register_pandas_handlers,
         )
+        from flytekit.types.structured.structured_dataset import DuplicateHandlerError
 
         if is_imported("tensorflow"):
             from flytekit.extras import tensorflow  # noqa: F401
@@ -1056,11 +1057,20 @@ def lazy_import_transformers(cls):
                 from flytekit.types.schema.types_pandas import PandasSchemaReader, PandasSchemaWriter  # noqa: F401
             except ValueError:
                 logger.debug("Transformer for pandas is already registered.")
-            register_pandas_handlers()
+            try:
+                register_pandas_handlers()
+            except DuplicateHandlerError:
+                logger.debug("Transformer for pandas is already registered.")
         if is_imported("pyarrow"):
-            register_arrow_handlers()
+            try:
+                register_arrow_handlers()
+            except DuplicateHandlerError:
+                logger.debug("Transformer for arrow is already registered.")
         if is_imported("google.cloud.bigquery"):
-            register_bigquery_handlers()
+            try:
+                register_bigquery_handlers()
+            except DuplicateHandlerError:
+                logger.debug("Transformer for bigquery is already registered.")
         if is_imported("numpy"):
             from flytekit.types import numpy  # noqa: F401
         if is_imported("PIL"):
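
For reference, DuplicateHandlerError (imported above) is what StructuredDatasetTransformerEngine.register raises when a handler for an already-claimed python type / protocol / format combination is registered without override=True. A minimal, hypothetical reproduction follows; DummyEncoder is illustration only and not part of the codebase.

import pandas as pd

from flytekit.types.structured.structured_dataset import (
    DuplicateHandlerError,
    StructuredDatasetEncoder,
    StructuredDatasetTransformerEngine,
)


class DummyEncoder(StructuredDatasetEncoder):  # hypothetical, illustration only
    def __init__(self):
        super().__init__(pd.DataFrame, "bq", supported_format="")

    def encode(self, ctx, structured_dataset, structured_dataset_type):
        raise NotImplementedError


StructuredDatasetTransformerEngine.register(DummyEncoder(), override=True)
try:
    # Same python type, protocol, and format as above, and no override=True:
    # this is the collision the lazy-import path now tolerates.
    StructuredDatasetTransformerEngine.register(DummyEncoder())
except DuplicateHandlerError:
    pass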
2 changes: 1 addition & 1 deletion flytekit/types/structured/structured_dataset.py
@@ -177,7 +177,7 @@ def __init__(self, python_type: Type[T], protocol: Optional[str] = None, support
           is capable of handling.
         :param supported_format: Arbitrary string representing the format. If not supplied then an empty string
           will be used. An empty string implies that the encoder works with any format. If the format being asked
-          for does not exist, the transformer enginer will look for the "" encoder instead and write a warning.
+          for does not exist, the transformer engine will look for the "" encoder instead and write a warning.
         """
         self._python_type = python_type
         self._protocol = protocol.replace("://", "") if protocol else None
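
The fallback the corrected docstring describes, where an encoder registered with an empty supported_format acts as a catch-all for its type and protocol, can be sketched like this. AnyFormatS3Encoder is hypothetical, and the fallback-with-warning behavior is taken from the docstring above rather than verified here.

import pandas as pd

from flytekit.types.structured.structured_dataset import (
    StructuredDatasetEncoder,
    StructuredDatasetTransformerEngine,
)


class AnyFormatS3Encoder(StructuredDatasetEncoder):  # hypothetical, illustration only
    def __init__(self):
        # supported_format="" means this encoder claims every format for
        # pd.DataFrame written to the "s3" protocol.
        super().__init__(pd.DataFrame, "s3", supported_format="")

    def encode(self, ctx, structured_dataset, structured_dataset_type):
        raise NotImplementedError


StructuredDatasetTransformerEngine.register(AnyFormatS3Encoder(), override=True)
# If a caller later asks for a format that has no dedicated encoder, the
# transformer engine falls back to this "" encoder and writes a warning.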
37 changes: 37 additions & 0 deletions (structured dataset tests)
@@ -2,6 +2,7 @@
 import tempfile
 import typing
 
+import google.cloud.bigquery
 import pyarrow as pa
 import pytest
 from fsspec.utils import get_protocol
@@ -15,6 +16,7 @@
 from flytekit.core.task import task
 from flytekit.core.type_engine import TypeEngine
 from flytekit.core.workflow import workflow
+from flytekit.lazy_import.lazy_module import is_imported
 from flytekit.models import literals
 from flytekit.models.literals import StructuredDatasetMetadata
 from flytekit.models.types import SchemaType, SimpleType, StructuredDatasetType
@@ -508,3 +510,38 @@ def test_list_of_annotated():
     @task
     def no_op(data: WineDataset) -> typing.List[WineDataset]:
         return [data]
+
+
+class PrivatePandasToBQEncodingHandlers(StructuredDatasetEncoder):
+    def __init__(self):
+        super().__init__(pd.DataFrame, "bq", supported_format="")
+
+    def encode(
+        self,
+        ctx: FlyteContext,
+        structured_dataset: StructuredDataset,
+        structured_dataset_type: StructuredDatasetType,
+    ) -> literals.StructuredDataset:
+        return literals.StructuredDataset(
+            uri=typing.cast(str, structured_dataset.uri), metadata=StructuredDatasetMetadata(structured_dataset_type)
+        )
+
+
+def test_reregister_encoder():
+    # Test that lazy import can run after a user has already registered a custom handler.
+    # The default handlers don't have override=True (and should not) but the call should not fail.
+    dir(google.cloud.bigquery)
+    assert is_imported("google.cloud.bigquery")
+
+    StructuredDatasetTransformerEngine.register(
+        PrivatePandasToBQEncodingHandlers(), default_format_for_type=False, override=True
+    )
+    TypeEngine.lazy_import_transformers()
+
+    sd = StructuredDataset(dataframe=pd.DataFrame({"a": [1, 2], "b": [3, 4]}), uri="bq://blah", file_format="bq")
+
+    ctx = FlyteContextManager.current_context()
+
+    df_literal_type = TypeEngine.to_literal_type(pd.DataFrame)
+
+    TypeEngine.to_literal(ctx, sd, python_type=pd.DataFrame, expected=df_literal_type)
