[BUG] Support creation and reading of StructuredDataset with local or remote uri #2914

25 changes: 25 additions & 0 deletions flytekit/types/structured/structured_dataset.py
@@ -30,6 +30,7 @@
from flytekit.models import types as type_models
from flytekit.models.literals import Binary, Literal, Scalar, StructuredDatasetMetadata
from flytekit.models.types import LiteralType, SchemaType, StructuredDatasetType
from flytekit.utils.asyn import loop_manager

if typing.TYPE_CHECKING:
import pandas as pd
@@ -176,8 +177,32 @@ def all(self) -> DF:  # type: ignore
if self._dataframe_type is None:
raise ValueError("No dataframe type set. Use open() to set the local dataframe type you want to use.")
ctx = FlyteContextManager.current_context()

if self.uri is not None and self.dataframe is None:
expected = TypeEngine.to_literal_type(StructuredDataset)
self._set_literal(ctx, expected)

return flyte_dataset_transformer.open_as(ctx, self.literal, self._dataframe_type, self.metadata)

def _set_literal(self, ctx: FlyteContext, expected: LiteralType) -> None:
"""
Explicitly set the StructuredDataset Literal to handle the following cases:

1. Read a dataframe from a StructuredDataset with a uri, for example:

@task
def return_sd() -> StructuredDataset:
sd = StructuredDataset(uri="s3://my-s3-bucket/s3_flyte_dir/df.parquet", file_format="parquet")
df = sd.open(pd.DataFrame).all()
return df

For details, please refer to this issue: https://github.com/flyteorg/flyte/issues/5954.
"""
Comment on lines +187 to +200

Member: Very nice comment.

Contributor Author: It means a lot to hear it from you!

# Bridge the async transformer into this synchronous helper.
to_literal = loop_manager.synced(flyte_dataset_transformer.async_to_literal)
self._literal_sd = to_literal(ctx, self, StructuredDataset, expected).scalar.structured_dataset
Comment on lines +201 to +202

Member: Not sure if this is the best way to write it.
cc @wild-endeavor @thomasjpfan

Contributor Author: Thanks for pointing this out. I'm also pondering if this is a good practice...

Will be glad to learn more from you guys!
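For context, loop_manager.synced wraps an async function into a blocking callable run on a background event loop, which is what the new _set_literal relies on. A minimal sketch of the pattern, using an illustrative coroutine that is not part of this PR:

from flytekit.utils.asyn import loop_manager

async def double(x: int) -> int:
    # Illustrative stand-in for an async transformer method
    # such as async_to_literal.
    return x * 2

double_sync = loop_manager.synced(double)  # returns a blocking wrapper
assert double_sync(21) == 42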

if self.metadata is None:
self._metadata = self._literal_sd.metadata

def iter(self) -> Generator[DF, None, None]:
if self._dataframe_type is None:
raise ValueError("No dataframe type set. Use open() to set the local dataframe type you want to use.")
@@ -2,6 +2,7 @@
import tempfile
import typing
from collections import OrderedDict
from pathlib import Path

import google.cloud.bigquery
import pytest
@@ -21,6 +22,7 @@
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.models.types import LiteralType, SchemaType, SimpleType, StructuredDatasetType
from flytekit.tools.translator import get_serializable
from flytekit.types.file import FlyteFile
from flytekit.types.structured.structured_dataset import (
PARQUET,
StructuredDataset,
@@ -59,6 +61,21 @@ def generate_pandas() -> pd.DataFrame:
return pd.DataFrame({"name": ["Tom", "Joseph"], "age": [20, 22]})


@pytest.fixture
def local_tmp_pqt_file():
df = generate_pandas()

# Create a temporary parquet file
with tempfile.NamedTemporaryFile(delete=False, mode="w+b", suffix=".parquet") as pqt_file:
pqt_path = pqt_file.name
df.to_parquet(pqt_path)

yield pqt_path

# Cleanup
Path(pqt_path).unlink(missing_ok=True)


def test_formats_make_sense():
@task
def t1(a: pd.DataFrame) -> pd.DataFrame:
@@ -643,3 +660,70 @@ def wf_with_input() -> pd.DataFrame:

pd.testing.assert_frame_equal(wf_no_input(), default_val)
pd.testing.assert_frame_equal(wf_with_input(), input_val)



def test_read_sd_from_uri(local_tmp_pqt_file):
import os

# Point flytekit at the local Flyte sandbox's minio object store.
os.environ["FLYTE_AWS_ENDPOINT"] = "http://localhost:30002/"
os.environ["FLYTE_AWS_ACCESS_KEY_ID"] = "minio"
os.environ["FLYTE_AWS_SECRET_ACCESS_KEY"] = "miniostorage"

@task
def upload_pqt_to_s3(local_path: str, remote_path: str) -> None:
"""Upload local temp parquet file to s3 object storage"""

with tempfile.TemporaryDirectory() as tmp_dir:
fs = FileAccessProvider(
local_sandbox_dir=tmp_dir,
raw_output_prefix="s3://my-s3-bucket"
)
fs.upload(local_path, remote_path)

@task
def read_sd_from_uri(uri: str) -> pd.DataFrame:
sd = StructuredDataset(uri=uri, file_format="parquet")
df = sd.open(pd.DataFrame).all()

return df

@workflow
def read_sd_from_local_uri(uri: str) -> pd.DataFrame:
df = read_sd_from_uri(uri=uri)

return df

@workflow
def read_sd_from_remote_uri(
local_path: str,
remote_path: str,
read_uri: str
) -> typing.Tuple[pd.DataFrame, pd.DataFrame]:
# Upload the parquet file to s3
upload_pqt_to_s3(local_path=local_path, remote_path=remote_path)
ff = FlyteFile(path=REMOTE_PATH)
with ff.open(mode="rb") as f:
df_s3 = pd.read_parquet(f)
Member: We shouldn't do this in a workflow; a workflow is like a bridge to build your DAG. A task is the place to write Python code.

https://docs.flyte.org/en/latest/user_guide/concepts/main_concepts/workflows.html#divedeep-workflows
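A hedged sketch of how that read could move out of the workflow body and into a task; the task name read_parquet_from_remote is illustrative, not part of the PR:

@task
def read_parquet_from_remote(remote_path: str) -> pd.DataFrame:
    # Wrap the remote object in a FlyteFile and read it inside a task,
    # where arbitrary Python code is allowed to run.
    ff = FlyteFile(path=remote_path)
    with ff.open(mode="rb") as f:
        return pd.read_parquet(f)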


# Read sd from remote uri
df_remote = read_sd_from_uri(uri=read_uri)

return df_s3, df_remote


REMOTE_PATH = "s3://my-s3-bucket/my-test/df.parquet"
df = generate_pandas()

Member: Why REMOTE_PATH = "s3://my-s3-bucket/my-test/df.parquet" and not REMOTE_PATH = "s3://my-s3-bucket/df.parquet"?

Contributor Author (@JiangJiaWei1103, Nov 14, 2024): I forgot that I had manually created the my-test directory and uploaded df.parquet last time. Hence, the test seemed to pass locally without an error.

You are right. We can just use REMOTE_PATH = "s3://my-s3-bucket/df.parquet". The uploading logic in the previous comment should be modified, too.
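A sketch of the adjusted constant and upload, assuming the bucket root on the sandbox's minio is directly writable (the FileAccessProvider/upload calls mirror the upload_pqt_to_s3 task above):

REMOTE_PATH = "s3://my-s3-bucket/df.parquet"  # no manually created subdirectory

with tempfile.TemporaryDirectory() as tmp_dir:
    fs = FileAccessProvider(
        local_sandbox_dir=tmp_dir,
        raw_output_prefix="s3://my-s3-bucket",
    )
    fs.upload(local_path, REMOTE_PATH)  # local_path: the local parquet file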

Member: Yes, the best way is to test this on your computer:

(dev) future@outlier ~ % pwd
/Users/future-outlier/code/dev/flytekit
(dev) future@outlier ~ % pytest -s test_file.py

This is how flytekit runs the unit tests, or you can just run

make unit_test

under the repo folder.

Member: This is the expected result. [screenshot of the passing test omitted]

# Read sd from local uri
df_local = read_sd_from_local_uri(uri=local_tmp_pqt_file)
pd.testing.assert_frame_equal(df, df_local)

# Read sd from remote uri
df_s3, df_remote = read_sd_from_remote_uri(
local_path=local_tmp_pqt_file,
remote_path=REMOTE_PATH,
read_uri=REMOTE_PATH
)
pd.testing.assert_frame_equal(df, df_s3)
pd.testing.assert_frame_equal(df, df_remote)