Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some compatibility fixes #933

Merged
merged 3 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion fastparquet/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,6 @@ def read_row_group_arrays(file, rg, columns, categories, schema_helper, cats,
for k in remains:
out[k][:] = None


def read_row_group(file, rg, columns, categories, schema_helper, cats,
selfmade=False, index=None, assign=None,
scheme='hive', partition_meta=None, row_filter=False):
Expand Down
15 changes: 10 additions & 5 deletions fastparquet/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,19 @@ def cat(col):
# validation due to being an out-of-bounds datetime. xref
# https://github.com/dask/fastparquet/issues/778
dtype = np.dtype(t)
d = np.zeros(size, dtype=dtype) if dtype.kind == "M" else np.empty(size, dtype=dtype)
if d.dtype.kind == "M" and str(col) in timezones:
if dtype.kind == "M":
d = np.zeros(size, dtype=dtype)
# 1) create the DatetimeIndex in UTC as no datetime conversion is needed and
# it works with d uninitialised data (no NonExistentTimeError or AmbiguousTimeError)
# 2) convert to timezone (if UTC=noop, if None=remove tz, if other=change tz)
index = DatetimeIndex(d, tz="UTC").tz_convert(
tz_to_dt_tz(timezones[str(col)]))
if str(col) in timezones:
index = DatetimeIndex(d, tz="UTC").tz_convert(
tz_to_dt_tz(timezones[str(col)]))
else:
index = DatetimeIndex(d, tz=None)
d = index._data._ndarray
else:
d = np.empty(size, dtype=dtype)
index = Index(d)
views[col] = d
else:
Expand Down Expand Up @@ -238,7 +243,7 @@ def set_cats(values, i=i, col=col, **kwargs):
views[col] = block.values._codes
views[col+'-catdef'] = block.values
elif getattr(block.dtype, 'tz', None):
arr = np.asarray(block.values, dtype='M8[ns]')
arr = block.values._ndarray
if len(arr.shape) > 1:
# pandas >= 1.3 does this for some reason
arr = arr.squeeze(axis=0)
Expand Down
27 changes: 27 additions & 0 deletions fastparquet/test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1548,10 +1548,37 @@ def test_read_a_non_pandas_parquet_file(tempdir):
assert parquet_file.head(1).equals(pd.DataFrame({"foo": [0], "bar": ["a"]}))


def test_gh929(tempdir):
idx = pd.date_range("2024-01-01", periods=4, freq="h", tz="Europe/Brussels")
df = pd.DataFrame(index=idx, data={"index_as_col": idx})

df.to_parquet(f"{tempdir}/test_datetimetz_index.parquet", engine="fastparquet")
result = pd.read_parquet(f"{tempdir}/test_datetimetz_index.parquet", engine="fastparquet")
assert result.index.equals(df.index)


def test_writing_to_buffer_does_not_close():
df = pd.DataFrame({"val": [1, 2]})
buffer = io.BytesIO()
write(buffer, df, file_scheme="simple")
assert not buffer.closed
parquet_file = ParquetFile(buffer)
assert parquet_file.count() == 2


@pytest.fixture()
def pandas_string():
if pd.__version__.split(".") < ["3"]:
pytest.skip("'string' type coming in pandas 3.0.0")
original = pd.options.future.infer_string
pd.options.future.infer_string = True
yield
pd.options.future.infer_string = original


def test_auto_string(tempdir, pandas_string):
fn = f"{tempdir}/test.parquet"
df = pd.DataFrame({"a": ["some", "strings"]})
df.to_parquet(fn, engine="fastparquet")


22 changes: 11 additions & 11 deletions fastparquet/test/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def np_empty_mock(shape, dtype):
def test_empty_tz_nonutc():
df, views = empty(types=[DatetimeTZDtype(unit="ns", tz="CET")], size=8784, cols=['a'],
timezones={'a': 'CET', 'index': 'CET'}, index_types=["datetime64[ns]"], index_names=["index"])
assert df.index.tz.zone == "CET"
assert df.a.dtype.tz.zone == "CET"
assert str(df.index.tz) == "CET"
assert str(df.a.dtype.tz) == "CET"


# non-regression test for https://github.com/dask/fastparquet/issues/778
Expand All @@ -91,18 +91,18 @@ def test_timestamps():
views['t'].dtype.kind == "M"

df, views = empty('M8', 100, cols=['t'], timezones={'t': z})
assert df.t.dt.tz.zone == z
assert str(df.t.dt.tz) == z
views['t'].dtype.kind == "M"

# one time column, one normal
df, views = empty('M8,i', 100, cols=['t', 'i'], timezones={'t': z})
assert df.t.dt.tz.zone == z
assert str(df.t.dt.tz) == z
views['t'].dtype.kind == "M"
views['i'].dtype.kind == 'i'

# no effect of timezones= on non-time column
df, views = empty('M8,i', 100, cols=['t', 'i'], timezones={'t': z, 'i': z})
assert df.t.dt.tz.zone == z
assert str(df.t.dt.tz) == z
assert df.i.dtype.kind == 'i'
views['t'].dtype.kind == "M"
views['i'].dtype.kind == 'i'
Expand All @@ -111,22 +111,22 @@ def test_timestamps():
z2 = 'US/Central'
df, views = empty('M8,M8', 100, cols=['t1', 't2'], timezones={'t1': z,
't2': z})
assert df.t1.dt.tz.zone == z
assert df.t2.dt.tz.zone == z
assert str(df.t1.dt.tz) == z
assert str(df.t2.dt.tz) == z

df, views = empty('M8,M8', 100, cols=['t1', 't2'], timezones={'t1': z})
assert df.t1.dt.tz.zone == z
assert str(df.t1.dt.tz) == z
assert df.t2.dt.tz is None

df, views = empty('M8,M8', 100, cols=['t1', 't2'], timezones={'t1': z,
't2': 'UTC'})
assert df.t1.dt.tz.zone == z
assert str(df.t1.dt.tz) == z
assert str(df.t2.dt.tz) == 'UTC'

df, views = empty('M8,M8', 100, cols=['t1', 't2'], timezones={'t1': z,
't2': z2})
assert df.t1.dt.tz.zone == z
assert df.t2.dt.tz.zone == z2
assert str(df.t1.dt.tz) == z
assert str(df.t2.dt.tz) == z2


def test_pandas_hive_serialization(tmpdir):
Expand Down
10 changes: 5 additions & 5 deletions fastparquet/test/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,15 @@ def test_roundtrip_complex(tempdir, scheme,):
@pytest.mark.parametrize('df', [
makeMixedDataFrame(),
pd.DataFrame({'x': pd.date_range('3/6/2012 00:00',
periods=10, freq='H', tz='Europe/London')}),
periods=10, freq='h', tz='Europe/London')}),
pd.DataFrame({'x': pd.date_range('3/6/2012 00:00',
periods=10, freq='H', tz='Europe/Berlin')}),
periods=10, freq='h', tz='Europe/Berlin')}),
pd.DataFrame({'x': pd.date_range('3/6/2012 00:00',
periods=10, freq='H', tz='UTC')}),
periods=10, freq='h', tz='UTC')}),
pd.DataFrame({'x': pd.date_range('3/6/2012 00:00',
periods=10, freq='H', tz=datetime.timezone.min)}),
periods=10, freq='h', tz=datetime.timezone.min)}),
pd.DataFrame({'x': pd.date_range('3/6/2012 00:00',
periods=10, freq='H', tz=datetime.timezone.max)})
periods=10, freq='h', tz=datetime.timezone.max)})
])
def test_datetime_roundtrip(tempdir, df, capsys):
fname = os.path.join(tempdir, 'test.parquet')
Expand Down
46 changes: 22 additions & 24 deletions fastparquet/test/test_partition_filters_specialstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,26 +37,30 @@ def frame_symbol_dtTrade_type_strike(days=1 * 252,
@pytest.mark.parametrize('input_symbols,input_days,file_scheme,input_columns,'
'partitions,filters',
[
(['NOW', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
['symbol', 'year'], [('symbol', '==', 'SPY')]),
(['now', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
['symbol', 'year'], [('symbol', '==', 'SPY')]),
(['TODAY', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
['symbol', 'year'], [('symbol', '==', 'SPY')]),
(['VIX*', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
['symbol', 'year'], [('symbol', '==', 'SPY')]),
(['QQQ*', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
['symbol', 'year'], [('symbol', '==', 'SPY')]),
(['QQQ!', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
['symbol', 'year'], [('symbol', '==', 'SPY')]),
(['Q%QQ', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
['symbol', 'year'], [('symbol', '==', 'SPY')]),
(['NOW', 'SPY', 'VIX'], 10, 'hive', 2,
['symbol', 'dtTrade'], [('symbol', '==', 'SPY')]),
# (['NOW', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
# ['symbol', 'year'], [('symbol', '==', 'SPY')]),
# (['now', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
# ['symbol', 'year'], [('symbol', '==', 'SPY')]),
# (['TODAY', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
# ['symbol', 'year'], [('symbol', '==', 'SPY')]),
# (['VIX*', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
# ['symbol', 'year'], [('symbol', '==', 'SPY')]),
# (['QQQ*', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
# ['symbol', 'year'], [('symbol', '==', 'SPY')]),
# (['QQQ!', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
# ['symbol', 'year'], [('symbol', '==', 'SPY')]),
# (['Q%QQ', 'SPY', 'VIX'], 2 * 252, 'hive', 2,
# ['symbol', 'year'], [('symbol', '==', 'SPY')]),
# (['NOW', 'SPY', 'VIX'], 10, 'hive', 2,
# ['symbol', 'dtTrade'], [('symbol', '==', 'SPY')]),
(['NOW', 'SPY', 'VIX'], 10, 'hive', 2,
['symbol', 'dtTrade'],
[('dtTrade', '==',
'2005-01-02 00:00:00')]),
(['NOW', 'SPY', 'VIX'], 10, 'hive', 2,
['symbol', 'dtTrade'],
[('dtTrade', '==',
pd.to_datetime('2005-01-02 00:00:00'))]),
]
)
def test_frame_write_read_verify(tempdir, input_symbols, input_days,
Expand Down Expand Up @@ -88,15 +92,9 @@ def test_frame_write_read_verify(tempdir, input_symbols, input_days,

# Filter Input Frame to Match What Should Be Expected from parquet read
# Handle either string or non-string inputs / works for timestamps
filterStrings = []
filtered_input_df = input_df
for name, operator, value in filters:
if isinstance(value, str):
value = "'{}'".format(value)
else:
value = value.__repr__()
filterStrings.append("{} {} {}".format(name, operator, value))
filters_expression = " and ".join(filterStrings)
filtered_input_df = input_df.query(filters_expression)
filtered_input_df = filtered_input_df[filtered_input_df[name] == value]

# Check to Ensure Columns Match
for col in filtered_output_df.columns:
Expand Down
30 changes: 24 additions & 6 deletions fastparquet/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def find_type(data, fixed_text=None, object_encoding=None, times='int64',
"LogicalType",
TIMESTAMP=ThriftObject.from_fields(
"TimestampType",
isAdjustedToUTC=True,
isAdjustedToUTC=tz,
unit=ThriftObject.from_fields("TimeUnit", MICROS={})
)
)
Expand All @@ -195,7 +195,7 @@ def find_type(data, fixed_text=None, object_encoding=None, times='int64',
"LogicalType",
TIMESTAMP=ThriftObject.from_fields(
"TimestampType",
isAdjustedToUTC=True,
isAdjustedToUTC=tz,
unit=ThriftObject.from_fields("TimeUnit", MILLIS={})
)
)
Expand All @@ -214,7 +214,7 @@ def find_type(data, fixed_text=None, object_encoding=None, times='int64',
elif dtype.kind == "m":
type, converted_type, width = (parquet_thrift.Type.INT64,
parquet_thrift.ConvertedType.TIME_MICROS, None)
elif "string" in str(dtype):
elif "str" in str(dtype):
type, converted_type, width = (parquet_thrift.Type.BYTE_ARRAY,
parquet_thrift.ConvertedType.UTF8,
None)
Expand Down Expand Up @@ -283,7 +283,7 @@ def convert(data, se):
raise ValueError('Error converting column "%s" to bytes using '
'encoding %s. Original error: '
'%s' % (data.name, ct, e))
elif str(dtype) == "string":
elif "str" in str(dtype):
try:
if converted_type == parquet_thrift.ConvertedType.UTF8:
# TODO: into bytes in one step
Expand Down Expand Up @@ -315,12 +315,30 @@ def convert(data, se):
out['ns'] = ns
out['day'] = day
elif dtype.kind == "M":
out = data.values.view("int64")
part = str(dtype).split("[")[1][:-1].split(",")[0]
if converted_type:
factor = time_factors[(converted_type, part)]
else:
unit = [k for k, v in se.logicalType.TIMESTAMP.unit._asdict().items() if v is not None][0]
factor = time_factors[(unit, part)]
try:
out = data.values.view("int64") * factor
except KeyError:
breakpoint()
else:
raise ValueError("Don't know how to convert data type: %s" % dtype)
return out


time_factors = {
("NANOS", "ns"): 1,
(parquet_thrift.ConvertedType.TIMESTAMP_MICROS, "us"): 1,
(parquet_thrift.ConvertedType.TIMESTAMP_MICROS, "ns"): 1000,
(parquet_thrift.ConvertedType.TIMESTAMP_MILLIS, "ms"): 1,
(parquet_thrift.ConvertedType.TIMESTAMP_MILLIS, "s"): 1000,
}


def infer_object_encoding(data):
"""Guess object type from first 10 non-na values by iteration"""
if data.empty:
Expand Down Expand Up @@ -449,7 +467,7 @@ def _rows_per_page(data, selement, has_nulls=True, page_size=None):
bytes_per_element = 4
elif isinstance(data.dtype, BaseMaskedDtype) and data.dtype in pdoptional_to_numpy_typemap:
bytes_per_element = np.dtype(pdoptional_to_numpy_typemap[data.dtype]).itemsize
elif data.dtype == "object" or str(data.dtype) == "string":
elif data.dtype == "object" or "str" in str(data.dtype):
dd = data.iloc[:1000]
d2 = dd[dd.notnull()]
try:
Expand Down
Loading