add delete for overwrite
bindipankhudi committed May 15, 2024
1 parent fe80b17 commit af0e8e6
Showing 3 changed files with 38 additions and 11 deletions.
@@ -172,21 +172,20 @@ def index(self, document_chunks: Iterable[Any], namespace: str, stream: str):
cortex_processor.process_airbyte_messages(airbyte_messages, self.get_write_strategy(stream))

def delete(self, delete_ids: list[str], namespace: str, stream: str):
# delete is generally used when we use full refresh/overwrite strategy.
# PyAirbyte's sync will take care of overwriting the records. Hence, we don't need to do anything here.
# This delete is specific to vector stores, so it is not implemented here
pass

def pre_sync(self, catalog: ConfiguredAirbyteCatalog) -> None:
"""
Run before the sync starts. This method should be used to make sure all records in the destination that belong to streams with a destination mode of overwrite are deleted.
Each record has a metadata field with the name airbyte_cdk.destinations.vector_db_based.document_processor.METADATA_STREAM_FIELD which can be used to filter documents for deletion.
Use the airbyte_cdk.destinations.vector_db_based.utils.create_stream_identifier method to create the stream identifier based on the stream definition to use for filtering.
Run before the sync starts. This method makes sure that all records in the destination that belong to streams with a destination mode of overwrite are deleted.
"""
table_list = self.default_processor._get_tables_list()
for stream in catalog.streams:
# remove all records for streams with overwrite mode
if stream.destination_sync_mode == DestinationSyncMode.overwrite:
# TODO: remove all records for the stream
stream_name = stream.stream.name
if stream_name.lower() in [table.lower() for table in table_list]:
self.default_processor._execute_sql(f"DELETE FROM {stream_name}")
pass

def check(self) -> Optional[str]:
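For context on the change above: pre_sync now clears the destination table for every stream synced in overwrite mode, matching stream names against the processor's table list case-insensitively before issuing a DELETE. Below is a minimal, hedged sketch of just that matching rule; the function name tables_to_clear and its arguments are hypothetical, used only for illustration, and are not part of the connector.

from typing import Iterable, List

def tables_to_clear(overwrite_streams: Iterable[str], existing_tables: Iterable[str]) -> List[str]:
    # Return the stream names whose destination tables should be emptied
    # before an overwrite sync. Matching is case-insensitive, mirroring
    # the check added in pre_sync above.
    existing = {table.lower() for table in existing_tables}
    return [name for name in overwrite_streams if name.lower() in existing]

# Example: tables_to_clear(["MyStream"], ["mystream", "other"]) returns ["MyStream"];
# pre_sync would then run `DELETE FROM MyStream` via the default processor.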
@@ -124,16 +124,22 @@ def test_write(self):
assert(len(result) == 1)
assert(result[0] == "str_col: Cats are nice")

# Fix: Trying to write records > batch size in overwrite mode does not work currently
def _test_write_record_count_200(self):

def test_overwrite_mode_deletes_records(self):
self._delete_table("mystream")
catalog = self._get_configured_catalog(DestinationSyncMode.overwrite)
first_state_message = self._state({"state": "1"})
first_record_chunk = [self._record("mystream", f"Dogs are number {i}", i) for i in range(200)]
first_record_chunk = [self._record("mystream", f"Dogs are number {i}", i) for i in range(4)]

# initial sync with replace
destination = DestinationSnowflakeCortex()
list(destination.write(self.config, catalog, [*first_record_chunk, first_state_message]))
assert(self._get_record_count("mystream") == 200)
assert(self._get_record_count("mystream") == 4)

# following should replace existing records
append_catalog = self._get_configured_catalog(DestinationSyncMode.overwrite)
list(destination.write(self.config, append_catalog, [self._record("mystream", "Cats are nice", 6), first_state_message]))
assert(self._get_record_count("mystream") == 1)

"""
Following tests are not code specific, but are useful to confirm that the Cortex functions are available and behaving as expected
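The record-count assertions in the test above rely on a _get_record_count helper from the integration test harness, which is not shown in this diff. A hypothetical sketch of what such a helper could look like, assuming a live Snowflake connection obtained through the snowflake-connector-python package (the helper name, signature, and connection handling here are assumptions, not the actual test code):

import snowflake.connector

def get_record_count(conn, table_name: str) -> int:
    # Hypothetical stand-in for the test suite's _get_record_count helper:
    # count the rows currently stored in the given destination table.
    # `conn` is assumed to be an open snowflake.connector connection.
    with conn.cursor() as cur:
        cur.execute(f"SELECT COUNT(*) FROM {table_name}")
        return cur.fetchone()[0]

# Usage sketch (credentials are placeholders):
# conn = snowflake.connector.connect(account="...", user="...", password="...", database="...", schema="...")
# assert get_record_count(conn, "mystream") == 4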
@@ -184,6 +184,28 @@ def test_check():
assert result == None


def test_pre_sync_table_does_not_exist():
indexer = _create_snowflake_cortex_indexer(generate_catalog())
mock_processor = MagicMock()
indexer.default_processor = mock_processor

mock_processor._get_tables_list.return_value = ["table1", "table2"]
mock_processor._execute_sql.return_value = None
indexer.pre_sync(generate_catalog())
mock_processor._get_tables_list.assert_called_once()
mock_processor._execute_sql.assert_not_called()

def test_pre_sync_table_exists():
indexer = _create_snowflake_cortex_indexer(generate_catalog())
mock_processor = MagicMock()
indexer.default_processor = mock_processor

mock_processor._get_tables_list.return_value = ["example_stream2", "table2"]
mock_processor._execute_sql.return_value = None
indexer.pre_sync(generate_catalog())
mock_processor._get_tables_list.assert_called_once()
mock_processor._execute_sql.assert_called_once()

def generate_catalog():
return ConfiguredAirbyteCatalog.parse_obj(
{
