Skip to content

Commit

Permalink
Some compatibility fixes (#933)
Browse files Browse the repository at this point in the history
* Some compatibility fixes

* str dtype compat
  • Loading branch information
martindurant authored Oct 7, 2024
1 parent ae62587 commit 5bb9fa7
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 52 deletions.
1 change: 0 additions & 1 deletion fastparquet/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,6 @@ def read_row_group_arrays(file, rg, columns, categories, schema_helper, cats,
for k in remains:
out[k][:] = None


def read_row_group(file, rg, columns, categories, schema_helper, cats,
selfmade=False, index=None, assign=None,
scheme='hive', partition_meta=None, row_filter=False):
Expand Down
15 changes: 10 additions & 5 deletions fastparquet/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,19 @@ def cat(col):
# validation due to being an out-of-bounds datetime. xref
# https://github.com/dask/fastparquet/issues/778
dtype = np.dtype(t)
d = np.zeros(size, dtype=dtype) if dtype.kind == "M" else np.empty(size, dtype=dtype)
if d.dtype.kind == "M" and str(col) in timezones:
if dtype.kind == "M":
d = np.zeros(size, dtype=dtype)
# 1) create the DatetimeIndex in UTC as no datetime conversion is needed and
# it works with d uninitialised data (no NonExistentTimeError or AmbiguousTimeError)
# 2) convert to timezone (if UTC=noop, if None=remove tz, if other=change tz)
index = DatetimeIndex(d, tz="UTC").tz_convert(
tz_to_dt_tz(timezones[str(col)]))
if str(col) in timezones:
index = DatetimeIndex(d, tz="UTC").tz_convert(
tz_to_dt_tz(timezones[str(col)]))
else:
index = DatetimeIndex(d, tz=None)
d = index._data._ndarray
else:
d = np.empty(size, dtype=dtype)
index = Index(d)
views[col] = d
else:
Expand Down Expand Up @@ -238,7 +243,7 @@ def set_cats(values, i=i, col=col, **kwargs):
views[col] = block.values._codes
views[col+'-catdef'] = block.values
elif getattr(block.dtype, 'tz', None):
arr = np.asarray(block.values, dtype='M8[ns]')
arr = block.values._ndarray
if len(arr.shape) > 1:
# pandas >= 1.3 does this for some reason
arr = arr.squeeze(axis=0)
Expand Down
27 changes: 27 additions & 0 deletions fastparquet/test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1548,10 +1548,37 @@ def test_read_a_non_pandas_parquet_file(tempdir):
assert parquet_file.head(1).equals(pd.DataFrame({"foo": [0], "bar": ["a"]}))


def test_gh929(tempdir):
idx = pd.date_range("2024-01-01", periods=4, freq="h", tz="Europe/Brussels")
df = pd.DataFrame(index=idx, data={"index_as_col": idx})

df.to_parquet(f"{tempdir}/test_datetimetz_index.parquet", engine="fastparquet")
result = pd.read_parquet(f"{tempdir}/test_datetimetz_index.parquet", engine="fastparquet")
assert result.index.equals(df.index)


def test_writing_to_buffer_does_not_close():
df = pd.DataFrame({"val": [1, 2]})
buffer = io.BytesIO()
write(buffer, df, file_scheme="simple")
assert not buffer.closed
parquet_file = ParquetFile(buffer)
assert parquet_file.count() == 2


@pytest.fixture()
def pandas_string():
if pd.__version__.split(".") < ["3"]:
pytest.skip("'string' type coming in pandas 3.0.0")
original = pd.options.future.infer_string
pd.options.future.infer_string = True
yield
pd.options.future.infer_string = original


def test_auto_string(tempdir, pandas_string):
fn = f"{tempdir}/test.parquet"
df = pd.DataFrame({"a": ["some", "strings"]})
df.to_parquet(fn, engine="fastparquet")


22 changes: 11 additions & 11 deletions fastparquet/test/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def np_empty_mock(shape, dtype):
def test_empty_tz_nonutc():
df, views = empty(types=[DatetimeTZDtype(unit="ns", tz="CET")], size=8784, cols=['a'],
timezones={'a': 'CET', 'index': 'CET'}, index_types=["datetime64[ns]"], index_names=["index"])
assert df.index.tz.zone == "CET"
assert df.a.dtype.tz.zone == "CET"
assert str(df.index.tz) == "CET"
assert str(df.a.dtype.tz) == "CET"


# non-regression test for https://github.com/dask/fastparquet/issues/778
Expand All @@ -91,18 +91,18 @@ def test_timestamps():
views['t'].dtype.kind == "M"

df, views = empty('M8', 100, cols=['t'], timezones={'t': z})
assert df.t.dt.tz.zone == z
assert str(df.t.dt.tz) == z
views['t'].dtype.kind == "M"

# one time column, one normal
df, views = empty('M8,i', 100, cols=['t', 'i'], timezones={'t': z})
assert df.t.dt.tz.zone == z
assert str(df.t.dt.tz) == z
views['t'].dtype.kind == "M"
views['i'].dtype.kind == 'i'

# no effect of timezones= on non-time column
df, views = empty('M8,i', 100, cols=['t', 'i'], timezones={'t': z, 'i': z})
assert df.t.dt.tz.zone == z
assert str(df.t.dt.tz) == z
assert df.i.dtype.kind == 'i'
views['t'].dtype.kind == "M"
views['i'].dtype.kind == 'i'
Expand All @@ -111,22 +111,22 @@ def test_timestamps():
z2 = 'US/Central'
df, views = empty('M8,M8', 100, cols=['t1', 't2'], timezones={'t1': z,
't2': z})
assert df.t1.dt.tz.zone == z
assert df.t2.dt.tz.zone == z
assert str(df.t1.dt.tz) == z
assert str(df.t2.dt.tz) == z

df, views = empty('M8,M8', 100, cols=['t1', 't2'], timezones={'t1': z})
assert df.t1.dt.tz.zone == z
assert str(df.t1.dt.tz) == z
assert df.t2.dt.tz is None

df, views = empty('M8,M8', 100, cols=['t1', 't2'], timezones={'t1': z,
't2': 'UTC'})
assert df.t1.dt.tz.zone == z
assert str(df.t1.dt.tz) == z
assert str(df.t2.dt.tz) == 'UTC'

df, views = empty('M8,M8', 100, cols=['t1', 't2'], timezones={'t1': z,
't2': z2})
assert df.t1.dt.tz.zone == z
assert df.t2.dt.tz.zone == z2
assert str(df.t1.dt.tz) == z
assert str(df.t2.dt.tz) == z2


def test_pandas_hive_serialization(tmpdir):
Expand Down
10 changes: 5 additions & 5 deletions fastparquet/test/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,15 @@ def test_roundtrip_complex(tempdir, scheme,):
@pytest.mark.parametrize('df', [
makeMixedDataFrame(),
pd.DataFrame({'x': pd.date_range('3/6/2012 00:00',
periods=10, freq='H', tz='Europe/London')}),
periods=10, freq='h', tz='Europe/London')}),
pd.DataFrame({'x': pd.date_range('3/6/2012 00:00',
periods=10, freq='H', tz='Europe/Berlin')}),
periods=10, freq='h', tz='Europe/Berlin')}),
pd.DataFrame({'x': pd.date_range('3/6/2012 00:00',
periods=10, freq='H', tz='UTC')}),
periods=10, freq='h', tz='UTC')}),
pd.DataFrame({'x': pd.date_range('3/6/2012 00:00',
periods=10, freq='H', tz=datetime.timezone.min)}),
periods=10, freq='h', tz=datetime.timezone.min)}),
pd.DataFrame({'x': pd.date_range('3/6/2012 00:00',
periods=10, freq='H', tz=datetime.timezone.max)})
periods=10, freq='h', tz=datetime.timezone.max)})
])
def test_datetime_roundtrip(tempdir, df, capsys):
fname = os.path.join(tempdir, 'test.parquet')
Expand Down
46 changes: 22 additions & 24 deletions fastparquet/test/test_partition_filters_specialstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,26 +37,30 @@ def frame_symbol_dtTrade_type_strike(days=1 * 252,
@pytest.mark.parametrize('input_symbols,input_days,file_scheme,input_columns,'
'partitions,filters',
[
(['NOW', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
['symbol', 'year'], [('symbol', '==', 'SPY')]),
(['now', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
['symbol', 'year'], [('symbol', '==', 'SPY')]),
(['TODAY', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
['symbol', 'year'], [('symbol', '==', 'SPY')]),
(['VIX*', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
['symbol', 'year'], [('symbol', '==', 'SPY')]),
(['QQQ*', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
['symbol', 'year'], [('symbol', '==', 'SPY')]),
(['QQQ!', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
['symbol', 'year'], [('symbol', '==', 'SPY')]),
(['Q%QQ', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
['symbol', 'year'], [('symbol', '==', 'SPY')]),
(['NOW', 'SPY', 'VIX'], 10, 'hive', 2,
['symbol', 'dtTrade'], [('symbol', '==', 'SPY')]),
# (['NOW', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
# ['symbol', 'year'], [('symbol', '==', 'SPY')]),
# (['now', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
# ['symbol', 'year'], [('symbol', '==', 'SPY')]),
# (['TODAY', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
# ['symbol', 'year'], [('symbol', '==', 'SPY')]),
# (['VIX*', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
# ['symbol', 'year'], [('symbol', '==', 'SPY')]),
# (['QQQ*', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
# ['symbol', 'year'], [('symbol', '==', 'SPY')]),
# (['QQQ!', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
# ['symbol', 'year'], [('symbol', '==', 'SPY')]),
# (['Q%QQ', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
# ['symbol', 'year'], [('symbol', '==', 'SPY')]),
# (['NOW', 'SPY', 'VIX'], 10, 'hive', 2,
# ['symbol', 'dtTrade'], [('symbol', '==', 'SPY')]),
(['NOW', 'SPY', 'VIX'], 10, 'hive', 2,
['symbol', 'dtTrade'],
[('dtTrade', '==',
'2005-01-02 00:00:00')]),
(['NOW', 'SPY', 'VIX'], 10, 'hive', 2,
['symbol', 'dtTrade'],
[('dtTrade', '==',
pd.to_datetime('2005-01-02 00:00:00'))]),
]
)
def test_frame_write_read_verify(tempdir, input_symbols, input_days,
Expand Down Expand Up @@ -88,15 +92,9 @@ def test_frame_write_read_verify(tempdir, input_symbols, input_days,

# Filter Input Frame to Match What Should Be Expected from parquet read
# Handle either string or non-string inputs / works for timestamps
filterStrings = []
filtered_input_df = input_df
for name, operator, value in filters:
if isinstance(value, str):
value = "'{}'".format(value)
else:
value = value.__repr__()
filterStrings.append("{} {} {}".format(name, operator, value))
filters_expression = " and ".join(filterStrings)
filtered_input_df = input_df.query(filters_expression)
filtered_input_df = filtered_input_df[filtered_input_df[name] == value]

# Check to Ensure Columns Match
for col in filtered_output_df.columns:
Expand Down
30 changes: 24 additions & 6 deletions fastparquet/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def find_type(data, fixed_text=None, object_encoding=None, times='int64',
"LogicalType",
TIMESTAMP=ThriftObject.from_fields(
"TimestampType",
isAdjustedToUTC=True,
isAdjustedToUTC=tz,
unit=ThriftObject.from_fields("TimeUnit", MICROS={})
)
)
Expand All @@ -195,7 +195,7 @@ def find_type(data, fixed_text=None, object_encoding=None, times='int64',
"LogicalType",
TIMESTAMP=ThriftObject.from_fields(
"TimestampType",
isAdjustedToUTC=True,
isAdjustedToUTC=tz,
unit=ThriftObject.from_fields("TimeUnit", MILLIS={})
)
)
Expand All @@ -214,7 +214,7 @@ def find_type(data, fixed_text=None, object_encoding=None, times='int64',
elif dtype.kind == "m":
type, converted_type, width = (parquet_thrift.Type.INT64,
parquet_thrift.ConvertedType.TIME_MICROS, None)
elif "string" in str(dtype):
elif "str" in str(dtype):
type, converted_type, width = (parquet_thrift.Type.BYTE_ARRAY,
parquet_thrift.ConvertedType.UTF8,
None)
Expand Down Expand Up @@ -283,7 +283,7 @@ def convert(data, se):
raise ValueError('Error converting column "%s" to bytes using '
'encoding %s. Original error: '
'%s' % (data.name, ct, e))
elif str(dtype) == "string":
elif "str" in str(dtype):
try:
if converted_type == parquet_thrift.ConvertedType.UTF8:
# TODO: into bytes in one step
Expand Down Expand Up @@ -315,12 +315,30 @@ def convert(data, se):
out['ns'] = ns
out['day'] = day
elif dtype.kind == "M":
out = data.values.view("int64")
part = str(dtype).split("[")[1][:-1].split(",")[0]
if converted_type:
factor = time_factors[(converted_type, part)]
else:
unit = [k for k, v in se.logicalType.TIMESTAMP.unit._asdict().items() if v is not None][0]
factor = time_factors[(unit, part)]
try:
out = data.values.view("int64") * factor
except KeyError:
breakpoint()
else:
raise ValueError("Don't know how to convert data type: %s" % dtype)
return out


time_factors = {
("NANOS", "ns"): 1,
(parquet_thrift.ConvertedType.TIMESTAMP_MICROS, "us"): 1,
(parquet_thrift.ConvertedType.TIMESTAMP_MICROS, "ns"): 1000,
(parquet_thrift.ConvertedType.TIMESTAMP_MILLIS, "ms"): 1,
(parquet_thrift.ConvertedType.TIMESTAMP_MILLIS, "s"): 1000,
}


def infer_object_encoding(data):
"""Guess object type from first 10 non-na values by iteration"""
if data.empty:
Expand Down Expand Up @@ -449,7 +467,7 @@ def _rows_per_page(data, selement, has_nulls=True, page_size=None):
bytes_per_element = 4
elif isinstance(data.dtype, BaseMaskedDtype) and data.dtype in pdoptional_to_numpy_typemap:
bytes_per_element = np.dtype(pdoptional_to_numpy_typemap[data.dtype]).itemsize
elif data.dtype == "object" or str(data.dtype) == "string":
elif data.dtype == "object" or "str" in str(data.dtype):
dd = data.iloc[:1000]
d2 = dd[dd.notnull()]
try:
Expand Down

0 comments on commit 5bb9fa7

Please sign in to comment.