diff --git a/greenplumpython/experimental/embedding.py b/greenplumpython/experimental/embedding.py index c30ed399..47a980a2 100644 --- a/greenplumpython/experimental/embedding.py +++ b/greenplumpython/experimental/embedding.py @@ -155,6 +155,9 @@ def create_index( query_col_names = _serialize_to_expr( list(self._dataframe.unique_key) + [column], self._dataframe._db ) + unique_key_col_names = _serialize_to_expr( + list(self._dataframe.unique_key), self._dataframe._db + ) sql_add_relationship = f""" DO $$ BEGIN @@ -168,13 +171,21 @@ def create_index( SELECT FROM unnest({query_col_names}) AS query WHERE attname = query ) + ),emb_attnum_map AS ( + SELECT attname, attnum FROM pg_attribute + WHERE + attrelid = '{embedding_df._qualified_table_name}'::regclass::oid AND + EXISTS ( + SELECT FROM unnest({unique_key_col_names}) AS query + WHERE attname = query + ) ), embedding_info AS ( SELECT '{embedding_df._qualified_table_name}'::regclass::oid AS embedding_relid, attnum AS content_attnum, {len(self._dataframe._unique_key) + 1} AS embedding_attnum, '{model_name}' AS model, - ARRAY(SELECT attnum FROM attnum_map WHERE attname != '{column}') AS unique_key + ARRAY(SELECT attnum FROM emb_attnum_map WHERE attname != '{column}') AS unique_key FROM attnum_map WHERE attname = '{column}' )