GH-40720: [Python] Simplify and improve perf of creation of the column names in Table.to_pandas (#40721)

### Rationale for this change

Over the years, `pandas_compat.py` has grown quite complex and accumulated a lot of pandas compatibility code, much of which can probably be simplified now that old pandas and Python versions are no longer supported.

One part of the code where this is the case is the reconstruction of the `.columns` Index object of the resulting DataFrame. Right now that always goes through a MultiIndex (even for simple column names), which adds considerable overhead in the simple case, and it also contains some old Python/pandas compat code that can be removed.
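
For illustration, here is a minimal sketch (plain pandas, hypothetical variable names, not code from this PR) of the two construction paths: wrapping simple names in one-tuples to build a single-level MultiIndex that is flattened again later, versus building the flat `pd.Index` directly:

```python
import pandas as pd

names = ["a", "b", "c"]

# Old-style path (sketch): wrap each simple name in a 1-tuple, build a
# single-level MultiIndex, then flatten it back to a plain Index.
mi = pd.MultiIndex.from_tuples([(name,) for name in names])
flattened = pd.Index(mi.get_level_values(0))

# Simplified path (sketch): build the flat Index directly.
idx = pd.Index(names)

assert flattened.equals(idx)
```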

### What changes are included in this PR?

The simplification of not going through a MultiIndex for the simple case gives a nice speed-up as well:

```python
In [1]: import pyarrow as pa

In [2]: table = pa.table({'a': [1, 2, 3], 'b': [0.1, 0.2, 0.3], 'c': [3, 4, 5]})

In [3]: %timeit table.to_pandas()
251 µs ± 1.26 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)   # <-- main
68.1 µs ± 894 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)  # <-- PR
```
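
As a quick sanity check (a hypothetical snippet, not taken from the PR), multi-level column names still take the MultiIndex path and should round-trip unchanged:

```python
import pandas as pd
import pyarrow as pa

# Hypothetical round-trip: MultiIndex columns are stored as stringified
# tuples in the Arrow schema and reconstructed in to_pandas().
df = pd.DataFrame(
    [[1, 2]], columns=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")])
)
assert pa.table(df).to_pandas().columns.equals(df.columns)
```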

### Are these changes tested?

We should have extensive existing tests covering this.

### Are there any user-facing changes?

That should not be the case.
* GitHub Issue: #40720

Authored-by: Joris Van den Bossche <[email protected]>
Signed-off-by: Joris Van den Bossche <[email protected]>
jorisvandenbossche authored Mar 26, 2024
1 parent 5e1a4fd commit 7d4d744
Showing 1 changed file with 17 additions and 50 deletions.
67 changes: 17 additions & 50 deletions python/pyarrow/pandas_compat.py
```diff
@@ -864,42 +864,35 @@ def _check_data_column_metadata_consistency(all_columns):


 def _deserialize_column_index(block_table, all_columns, column_indexes):
-    column_strings = [frombytes(x) if isinstance(x, bytes) else x
-                      for x in block_table.column_names]
     if all_columns:
         columns_name_dict = {
             c.get('field_name', _column_name_to_strings(c['name'])): c['name']
             for c in all_columns
         }
         columns_values = [
-            columns_name_dict.get(name, name) for name in column_strings
+            columns_name_dict.get(name, name) for name in block_table.column_names
         ]
     else:
-        columns_values = column_strings
-
-    # If we're passed multiple column indexes then evaluate with
-    # ast.literal_eval, since the column index values show up as a list of
-    # tuples
-    to_pair = ast.literal_eval if len(column_indexes) > 1 else lambda x: (x,)
-
-    # Create the column index
-    if not columns_values:
-        columns = _pandas_api.pd.Index(columns_values)
-    else:
+        columns_values = block_table.column_names
+
+    # Construct the base index
+    if len(column_indexes) > 1:
+        # If we're passed multiple column indexes then evaluate with
+        # ast.literal_eval, since the column index values show up as a list of
+        # tuples
         columns = _pandas_api.pd.MultiIndex.from_tuples(
-            list(map(to_pair, columns_values)),
-            names=[col_index['name'] for col_index in column_indexes] or None,
+            list(map(ast.literal_eval, columns_values)),
+            names=[col_index['name'] for col_index in column_indexes],
         )
+    else:
+        columns = _pandas_api.pd.Index(
+            columns_values, name=column_indexes[0]["name"] if column_indexes else None
+        )

     # if we're reconstructing the index
     if len(column_indexes) > 0:
         columns = _reconstruct_columns_from_metadata(columns, column_indexes)

-    # ARROW-1751: flatten a single level column MultiIndex for pandas 0.21.0
-    columns = _flatten_single_level_multiindex(columns)
-
     return columns
```
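
For context, a small sketch (illustrative, not part of the diff) of why `ast.literal_eval` appears above: when there are multiple column indexes, each Arrow column name is the string form of a tuple, and `ast.literal_eval` recovers the tuple:

```python
import ast

# Hypothetical stored names for a two-level column index.
stored = ["('a', 'x')", "('a', 'y')"]
tuples = [ast.literal_eval(s) for s in stored]
assert tuples == [("a", "x"), ("a", "y")]
```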


```diff
@@ -1040,13 +1033,6 @@ def _pandas_type_to_numpy_type(pandas_type):
     return np.dtype(pandas_type)


-def _get_multiindex_codes(mi):
-    if isinstance(mi, _pandas_api.pd.MultiIndex):
-        return mi.codes
-    else:
-        return None
-
-
 def _reconstruct_columns_from_metadata(columns, column_indexes):
     """Construct a pandas MultiIndex from `columns` and column index metadata
     in `column_indexes`.
```
```diff
@@ -1073,9 +1059,7 @@ def _reconstruct_columns_from_metadata(columns, column_indexes):
     # Get levels and labels, and provide sane defaults if the index has a
     # single level to avoid if/else spaghetti.
     levels = getattr(columns, 'levels', None) or [columns]
-    labels = _get_multiindex_codes(columns) or [
-        pd.RangeIndex(len(level)) for level in levels
-    ]
+    labels = getattr(columns, 'codes', None) or [None]

     # Convert each level to the dtype provided in the metadata
     levels_dtypes = [
```
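
A brief sketch (illustrative only) of the `getattr` defaults in this hunk: a flat `pd.Index` has neither `.levels` nor `.codes`, so both lookups fall back to the single-level defaults, while a real MultiIndex provides both:

```python
import pandas as pd

idx = pd.Index(["a", "b"])
levels = getattr(idx, 'levels', None) or [idx]   # plain Index -> [idx]
labels = getattr(idx, 'codes', None) or [None]   # plain Index -> [None]
assert levels[0] is idx and labels == [None]

mi = pd.MultiIndex.from_tuples([("a", "x"), ("b", "y")])
assert getattr(mi, 'codes', None) is not None    # MultiIndex does have codes
```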
```diff
@@ -1109,7 +1093,10 @@ def _reconstruct_columns_from_metadata(columns, column_indexes):

         new_levels.append(level)

-    return pd.MultiIndex(new_levels, labels, names=columns.names)
+    if len(new_levels) > 1:
+        return pd.MultiIndex(new_levels, labels, names=columns.names)
+    else:
+        return pd.Index(new_levels[0], dtype=new_levels[0].dtype, name=columns.name)
```


def _table_to_blocks(options, block_table, categories, extension_columns):
Expand All @@ -1123,26 +1110,6 @@ def _table_to_blocks(options, block_table, categories, extension_columns):
for item in result]


def _flatten_single_level_multiindex(index):
pd = _pandas_api.pd
if isinstance(index, pd.MultiIndex) and index.nlevels == 1:
levels, = index.levels
labels, = _get_multiindex_codes(index)
# ARROW-9096: use levels.dtype to match cast with original DataFrame
dtype = levels.dtype

# Cheaply check that we do not somehow have duplicate column names
if not index.is_unique:
raise ValueError('Found non-unique column index')

return pd.Index(
[levels[_label] if _label != -1 else None for _label in labels],
dtype=dtype,
name=index.names[0]
)
return index


def _add_any_metadata(table, pandas_metadata):
modified_columns = {}
modified_fields = {}
Expand Down
