diff --git a/projects/extension/ai/load_dataset.py b/projects/extension/ai/load_dataset.py index e521dfbc..5aba0cca 100644 --- a/projects/extension/ai/load_dataset.py +++ b/projects/extension/ai/load_dataset.py @@ -33,11 +33,19 @@ def get_default_column_type(dtype: str) -> str: return type_mapping.get(dtype.lower(), "TEXT") +def field_name_to_column_name(field_name: str) -> str: + # replace . and " with underscore. periods are used in nested structs but hard to work with in SQL. + # double quotes are dangerous in SQL because they can be used to escape identifiers. + return field_name.translate(str.maketrans({".": "_", '"': "_"})) + + def get_column_info( dataset: datasets.Dataset, field_types: Optional[Dict[str, str]] ) -> tuple[Dict[str, str], Dict[str, Any], str]: # Extract types from features - column_dtypes = {name: feature.dtype for name, feature in dataset.features.items()} + column_dtypes = { + name: feature.dtype for name, feature in dataset.features.flatten().items() + } # Prepare column types, using field_types if provided, otherwise use inferred types column_pgtypes = {} for name, py_type in column_dtypes.items(): @@ -47,7 +55,9 @@ def get_column_info( if field_types and name in field_types else get_default_column_type(str(py_type)) ) - column_names = ", ".join(f'"{name}"' for name in column_dtypes.keys()) + column_names = ", ".join( + f'"{field_name_to_column_name(name)}"' for name in column_dtypes.keys() + ) return column_pgtypes, column_dtypes, column_names @@ -123,7 +133,8 @@ def create_table( plpy.notice(f"creating table {qualified_table}") column_type_def = ", ".join( - f'"{name}" {col_type}' for name, col_type in column_types.items() + f'"{field_name_to_column_name(name)}" {col_type}' + for name, col_type in column_types.items() ) # Create table @@ -195,6 +206,13 @@ def load_dataset( column_pgtypes, column_dtypes, column_names = get_column_info( first_dataset, field_types ) + + flatten = False + for field_name in column_dtypes.keys(): + if "." in field_name: + flatten = True + break + qualified_table = create_table( plpy, name, config_name, schema, table_name, column_pgtypes, if_table_exists ) @@ -216,23 +234,36 @@ def load_dataset( batch_count = 0 batches_since_commit = 0 for split, dataset in datasetdict.items(): + # use arrow format to allow you to use the flatten method to flatten nested structs + batched_dataset = dataset.with_format("arrow") # Process data in batches using dataset iteration - batched_dataset = dataset.batch(batch_size=batch_size) - for batch in batched_dataset: + for batch in batched_dataset.iter(batch_size=batch_size): if max_batches and batch_count >= max_batches: break + if flatten: + batch = batch.flatten() + batch_arrays = [[] for _ in column_dtypes] - for i, (col_name, py_type) in enumerate(column_dtypes.items()): + for i, (field_name, py_type) in enumerate(column_dtypes.items()): type_str = str(py_type).lower() - array_values = batch[col_name] + array_values = batch[field_name] if type_str in ("dict", "list"): - batch_arrays[i] = [json.dumps(value) for value in array_values] + batch_arrays[i] = [ + json.dumps(value.as_py()) if value is not None else None + for value in array_values + ] elif type_str in ("int64", "int32", "int16", "int8"): - batch_arrays[i] = [int(value) for value in array_values] + batch_arrays[i] = [ + value.as_py() if value is not None else None + for value in array_values + ] elif type_str in ("float64", "float32", "float16"): - batch_arrays[i] = [float(value) for value in array_values] + batch_arrays[i] = [ + value.as_py() if value is not None else None + for value in array_values + ] else: batch_arrays[i] = array_values diff --git a/projects/extension/tests/test_load_dataset.py b/projects/extension/tests/test_load_dataset.py index c7be46fe..728ffc54 100644 --- a/projects/extension/tests/test_load_dataset.py +++ b/projects/extension/tests/test_load_dataset.py @@ -87,6 +87,19 @@ def test_load_dataset_with_field_types(cur): assert cur.fetchall() == [("text", "text"), ("label", "integer")] +def test_load_dataset_with_struct_and_nulls(cur): + cur.execute( + """ + select ai.load_dataset('foursquare/fsq-os-places', schema_name=>'public', table_name=>'fsq_places', batch_size=>100, max_batches=>1) + """, + ) + actual = cur.fetchone()[0] + assert actual == 100 + + cur.execute("select count(*) from public.fsq_places") + assert cur.fetchone()[0] == actual + + def test_load_dataset_with_field_with_max_batches_and_timestamp(cur): cur.execute( """