Skip to content

Commit

Permalink
GH-38034: [Python] DataFrame Interchange Protocol - correct dtype inf…
Browse files Browse the repository at this point in the history
…ormation for categorical columns (#38065)

### Rationale for this change
See: #38034 (comment)

### What changes are included in this PR?

The `f_string` for the columns with categorical dtype is now corrected to reflect the type of the indices from the dictionary data type. Bit width has been correct before. From the spec:

> For categoricals, the format string describes the type of the
              categorical in the data buffer. In case of a separate encoding of
              the categorical (e.g. an integer to string mapping), this can
              be derived from ``self.describe_categorical``.

### Are these changes tested?

Yes.

### Are there any user-facing changes?

No.
* Closes: #38034

Authored-by: AlenkaF <[email protected]>
Signed-off-by: AlenkaF <[email protected]>
  • Loading branch information
AlenkaF authored Oct 10, 2023
1 parent 4f06beb commit db420c9
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
4 changes: 3 additions & 1 deletion python/pyarrow/interchange/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,9 @@ def _dtype_from_arrowdtype(
return kind, bit_width, f_string, Endianness.NATIVE
elif pa.types.is_dictionary(dtype):
kind = DtypeKind.CATEGORICAL
f_string = "L"
arr = self._col
indices_dtype = arr.indices.type
_, f_string = _PYARROW_KINDS.get(indices_dtype)
return kind, bit_width, f_string, Endianness.NATIVE
else:
kind, f_string = _PYARROW_KINDS.get(dtype, (None, None))
Expand Down
20 changes: 20 additions & 0 deletions python/pyarrow/tests/interchange/test_interchange_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,23 @@ def test_buffer(int, use_batch):
for idx, truth in enumerate(arr):
val = ctype.from_address(dataBuf.ptr + idx * (bitwidth // 8)).value
assert val == truth, f"Buffer at index {idx} mismatch"


@pytest.mark.parametrize(
"indices_type, bitwidth, f_string", [
(pa.int8(), 8, "c"),
(pa.int16(), 16, "s"),
(pa.int32(), 32, "i"),
(pa.int64(), 64, "l")
]
)
def test_categorical_dtype(indices_type, bitwidth, f_string):
type = pa.dictionary(indices_type, pa.string())
arr = pa.array(["a", "b", None, "d"], type)
table = pa.table({'a': arr})

df = table.__dataframe__()
col = df.get_column(0)
assert col.dtype[0] == 23 # <DtypeKind.CATEGORICAL: 23>
assert col.dtype[1] == bitwidth
assert col.dtype[2] == f_string

0 comments on commit db420c9

Please sign in to comment.