Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: load_datasets to handle structs #348

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 41 additions & 10 deletions projects/extension/ai/load_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,19 @@ def get_default_column_type(dtype: str) -> str:
return type_mapping.get(dtype.lower(), "TEXT")


def field_name_to_column_name(field_name: str) -> str:
# replace . and " with underscore. periods are used in nested structs but hard to work with in SQL.
# double quotes are dangerous in SQL because they can be used to escape identifiers.
return field_name.translate(str.maketrans({".": "_", '"': "_"}))


def get_column_info(
dataset: datasets.Dataset, field_types: Optional[Dict[str, str]]
) -> tuple[Dict[str, str], Dict[str, Any], str]:
# Extract types from features
column_dtypes = {name: feature.dtype for name, feature in dataset.features.items()}
column_dtypes = {
name: feature.dtype for name, feature in dataset.features.flatten().items()
}
# Prepare column types, using field_types if provided, otherwise use inferred types
column_pgtypes = {}
for name, py_type in column_dtypes.items():
Expand All @@ -47,7 +55,9 @@ def get_column_info(
if field_types and name in field_types
else get_default_column_type(str(py_type))
)
column_names = ", ".join(f'"{name}"' for name in column_dtypes.keys())
column_names = ", ".join(
f'"{field_name_to_column_name(name)}"' for name in column_dtypes.keys()
)
return column_pgtypes, column_dtypes, column_names


Expand Down Expand Up @@ -123,7 +133,8 @@ def create_table(
plpy.notice(f"creating table {qualified_table}")

column_type_def = ", ".join(
f'"{name}" {col_type}' for name, col_type in column_types.items()
f'"{field_name_to_column_name(name)}" {col_type}'
for name, col_type in column_types.items()
)

# Create table
Expand Down Expand Up @@ -195,6 +206,13 @@ def load_dataset(
column_pgtypes, column_dtypes, column_names = get_column_info(
first_dataset, field_types
)

flatten = False
for field_name in column_dtypes.keys():
if "." in field_name:
flatten = True
break

qualified_table = create_table(
plpy, name, config_name, schema, table_name, column_pgtypes, if_table_exists
)
Expand All @@ -216,23 +234,36 @@ def load_dataset(
batch_count = 0
batches_since_commit = 0
for split, dataset in datasetdict.items():
# use arrow format to allow you to use the flatten method to flatten nested structs
batched_dataset = dataset.with_format("arrow")
# Process data in batches using dataset iteration
batched_dataset = dataset.batch(batch_size=batch_size)
for batch in batched_dataset:
for batch in batched_dataset.iter(batch_size=batch_size):
if max_batches and batch_count >= max_batches:
break

if flatten:
batch = batch.flatten()

batch_arrays = [[] for _ in column_dtypes]
for i, (col_name, py_type) in enumerate(column_dtypes.items()):
for i, (field_name, py_type) in enumerate(column_dtypes.items()):
type_str = str(py_type).lower()
array_values = batch[col_name]
array_values = batch[field_name]

if type_str in ("dict", "list"):
batch_arrays[i] = [json.dumps(value) for value in array_values]
batch_arrays[i] = [
json.dumps(value.as_py()) if value is not None else None
for value in array_values
]
elif type_str in ("int64", "int32", "int16", "int8"):
batch_arrays[i] = [int(value) for value in array_values]
batch_arrays[i] = [
value.as_py() if value is not None else None
for value in array_values
]
elif type_str in ("float64", "float32", "float16"):
batch_arrays[i] = [float(value) for value in array_values]
batch_arrays[i] = [
value.as_py() if value is not None else None
for value in array_values
]
else:
batch_arrays[i] = array_values

Expand Down
13 changes: 13 additions & 0 deletions projects/extension/tests/test_load_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,19 @@ def test_load_dataset_with_field_types(cur):
assert cur.fetchall() == [("text", "text"), ("label", "integer")]


def test_load_dataset_with_struct_and_nulls(cur):
cur.execute(
"""
select ai.load_dataset('foursquare/fsq-os-places', schema_name=>'public', table_name=>'fsq_places', batch_size=>100, max_batches=>1)
""",
)
actual = cur.fetchone()[0]
assert actual == 100

cur.execute("select count(*) from public.fsq_places")
assert cur.fetchone()[0] == actual


def test_load_dataset_with_field_with_max_batches_and_timestamp(cur):
cur.execute(
"""
Expand Down
Loading