From 9c14c69a204db6631cae015a674eca4a3983a855 Mon Sep 17 00:00:00 2001 From: Dave Date: Thu, 18 Jan 2024 17:22:19 +0100 Subject: [PATCH] add basic arrow tests --- tests/libs/pyarrow/test_pyarrow.py | 85 ++++++++++++++++++- tests/libs/pyarrow/test_pyarrow_normalizer.py | 2 +- 2 files changed, 85 insertions(+), 2 deletions(-) diff --git a/tests/libs/pyarrow/test_pyarrow.py b/tests/libs/pyarrow/test_pyarrow.py index dffda35005..9857755385 100644 --- a/tests/libs/pyarrow/test_pyarrow.py +++ b/tests/libs/pyarrow/test_pyarrow.py @@ -1,8 +1,18 @@ from copy import deepcopy +from typing import List, Any +import pytest import pyarrow as pa -from dlt.common.libs.pyarrow import py_arrow_to_table_schema_columns, get_py_arrow_datatype +from dlt.common.libs.pyarrow import ( + py_arrow_to_table_schema_columns, + get_py_arrow_datatype, + remove_null_columns, + remove_columns, + append_column, + rename_columns, + is_arrow_item, +) from dlt.common.destination import DestinationCapabilitiesContext from tests.cases import TABLE_UPDATE_COLUMNS_SCHEMA @@ -49,3 +59,76 @@ def test_py_arrow_to_table_schema_columns(): # Resulting schema should match the original assert result == dlt_schema + + +def _row_at_index(table: pa.Table, index: int) -> List[Any]: + return [table.column(column_name)[index].as_py() for column_name in table.column_names] + + +@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch]) +def test_remove_null_columns(pa_type: Any) -> None: + table = pa_type.from_pylist( + [ + {"a": 1, "b": 2, "c": None}, + {"a": 1, "b": None, "c": None}, + ] + ) + result = remove_null_columns(table) + assert result.column_names == ["a", "b"] + assert _row_at_index(result, 0) == [1, 2] + assert _row_at_index(result, 1) == [1, None] + + +@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch]) +def test_remove_columns(pa_type: Any) -> None: + table = pa_type.from_pylist( + [ + {"a": 1, "b": 2, "c": 5}, + {"a": 1, "b": 3, "c": 4}, + ] + ) + result = remove_columns(table, ["b"]) + assert result.column_names == ["a", "c"] + assert _row_at_index(result, 0) == [1, 5] + assert _row_at_index(result, 1) == [1, 4] + + +@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch]) +def test_append_column(pa_type: Any) -> None: + table = pa_type.from_pylist( + [ + {"a": 1, "b": 2}, + {"a": 1, "b": 3}, + ] + ) + result = append_column(table, "c", pa.array([5, 6])) + assert result.column_names == ["a", "b", "c"] + assert _row_at_index(result, 0) == [1, 2, 5] + assert _row_at_index(result, 1) == [1, 3, 6] + + +@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch]) +def test_rename_column(pa_type: Any) -> None: + table = pa_type.from_pylist( + [ + {"a": 1, "b": 2, "c": 5}, + {"a": 1, "b": 3, "c": 4}, + ] + ) + result = rename_columns(table, ["one", "two", "three"]) + assert result.column_names == ["one", "two", "three"] + assert _row_at_index(result, 0) == [1, 2, 5] + assert _row_at_index(result, 1) == [1, 3, 4] + + +@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch]) +def test_is_arrow_item(pa_type: Any) -> None: + table = pa_type.from_pylist( + [ + {"a": 1, "b": 2, "c": 5}, + {"a": 1, "b": 3, "c": 4}, + ] + ) + assert is_arrow_item(table) + assert not is_arrow_item(table.to_pydict()) + assert not is_arrow_item("hello") diff --git a/tests/libs/pyarrow/test_pyarrow_normalizer.py b/tests/libs/pyarrow/test_pyarrow_normalizer.py index 97a3c21d23..6622c5bb29 100644 --- a/tests/libs/pyarrow/test_pyarrow_normalizer.py +++ b/tests/libs/pyarrow/test_pyarrow_normalizer.py @@ -17,7 +17,7 @@ def _normalize(table: pa.Table, columns: List[TColumnSchema]) -> pa.Table: def _row_at_index(table: pa.Table, index: int) -> List[Any]: - return [table.column(column_name)[0].as_py() for column_name in table.column_names] + return [table.column(column_name)[index].as_py() for column_name in table.column_names] def test_quick_return_if_nothing_to_do() -> None: