Skip to content

Commit

Permalink
add basic arrow tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sh-rp committed Jan 18, 2024
1 parent bd46a0f commit 9c14c69
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 2 deletions.
85 changes: 84 additions & 1 deletion tests/libs/pyarrow/test_pyarrow.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
from copy import deepcopy
from typing import List, Any

import pytest
import pyarrow as pa

from dlt.common.libs.pyarrow import py_arrow_to_table_schema_columns, get_py_arrow_datatype
from dlt.common.libs.pyarrow import (
py_arrow_to_table_schema_columns,
get_py_arrow_datatype,
remove_null_columns,
remove_columns,
append_column,
rename_columns,
is_arrow_item,
)
from dlt.common.destination import DestinationCapabilitiesContext
from tests.cases import TABLE_UPDATE_COLUMNS_SCHEMA

Expand Down Expand Up @@ -49,3 +59,76 @@ def test_py_arrow_to_table_schema_columns():

# Resulting schema should match the original
assert result == dlt_schema


def _row_at_index(table: pa.Table, index: int) -> List[Any]:
return [table.column(column_name)[index].as_py() for column_name in table.column_names]


@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch])
def test_remove_null_columns(pa_type: Any) -> None:
table = pa_type.from_pylist(
[
{"a": 1, "b": 2, "c": None},
{"a": 1, "b": None, "c": None},
]
)
result = remove_null_columns(table)
assert result.column_names == ["a", "b"]
assert _row_at_index(result, 0) == [1, 2]
assert _row_at_index(result, 1) == [1, None]


@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch])
def test_remove_columns(pa_type: Any) -> None:
table = pa_type.from_pylist(
[
{"a": 1, "b": 2, "c": 5},
{"a": 1, "b": 3, "c": 4},
]
)
result = remove_columns(table, ["b"])
assert result.column_names == ["a", "c"]
assert _row_at_index(result, 0) == [1, 5]
assert _row_at_index(result, 1) == [1, 4]


@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch])
def test_append_column(pa_type: Any) -> None:
table = pa_type.from_pylist(
[
{"a": 1, "b": 2},
{"a": 1, "b": 3},
]
)
result = append_column(table, "c", pa.array([5, 6]))
assert result.column_names == ["a", "b", "c"]
assert _row_at_index(result, 0) == [1, 2, 5]
assert _row_at_index(result, 1) == [1, 3, 6]


@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch])
def test_rename_column(pa_type: Any) -> None:
table = pa_type.from_pylist(
[
{"a": 1, "b": 2, "c": 5},
{"a": 1, "b": 3, "c": 4},
]
)
result = rename_columns(table, ["one", "two", "three"])
assert result.column_names == ["one", "two", "three"]
assert _row_at_index(result, 0) == [1, 2, 5]
assert _row_at_index(result, 1) == [1, 3, 4]


@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch])
def test_is_arrow_item(pa_type: Any) -> None:
table = pa_type.from_pylist(
[
{"a": 1, "b": 2, "c": 5},
{"a": 1, "b": 3, "c": 4},
]
)
assert is_arrow_item(table)
assert not is_arrow_item(table.to_pydict())
assert not is_arrow_item("hello")
2 changes: 1 addition & 1 deletion tests/libs/pyarrow/test_pyarrow_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def _normalize(table: pa.Table, columns: List[TColumnSchema]) -> pa.Table:


def _row_at_index(table: pa.Table, index: int) -> List[Any]:
return [table.column(column_name)[0].as_py() for column_name in table.column_names]
return [table.column(column_name)[index].as_py() for column_name in table.column_names]


def test_quick_return_if_nothing_to_do() -> None:
Expand Down

0 comments on commit 9c14c69

Please sign in to comment.