From e575bb53530295fc9b97f3fe843e174fe7ab5167 Mon Sep 17 00:00:00 2001 From: Derek Wan Date: Wed, 11 Sep 2024 20:33:17 +0900 Subject: [PATCH] Unify redis cm (#727) --- src/tests/test_hypothesis.py | 59 +- src/tests/test_redis.py | 1379 ++++++++++++---------------------- src/utilities/__init__.py | 2 +- src/utilities/hypothesis.py | 122 +-- src/utilities/redis.py | 29 +- 5 files changed, 587 insertions(+), 1004 deletions(-) diff --git a/src/tests/test_hypothesis.py b/src/tests/test_hypothesis.py index 430d9d878..1eb8b1950 100644 --- a/src/tests/test_hypothesis.py +++ b/src/tests/test_hypothesis.py @@ -39,7 +39,6 @@ _ZONED_DATETIMES_LEFT_MOST, _ZONED_DATETIMES_RIGHT_MOST, Shape, - YieldRedisContainer, aiosqlite_engines, assume_does_not_raise, bool_arrays, @@ -57,7 +56,6 @@ months, random_states, redis_cms, - redis_cms_async, sets_fixed_length, settings_with_reduced_examples, setup_hypothesis_profiles, @@ -490,37 +488,38 @@ def test_main(self, *, data: DataObject) -> None: @SKIPIF_CI_AND_NOT_LINUX class TestRedisCMs: - @given(yield_redis=redis_cms(), value=integers()) - def test_sync_core(self, *, yield_redis: YieldRedisContainer, value: int) -> None: - with yield_redis() as redis: - assert not redis.client.exists(redis.key) - _ = redis.client.set(redis.key, value) - result = int(cast(str, redis.client.get(redis.key))) - assert result == value - - @given(yield_redis=redis_cms(), value=int32s()) - def test_sync_ts(self, *, yield_redis: YieldRedisContainer, value: int) -> None: - with yield_redis() as redis: - assert not redis.client.exists(redis.key) - _ = redis.ts.add(redis.key, "*", value) - res_timestamp, res_value = redis.ts.get(redis.key) - assert isinstance(res_timestamp, int) - assert int(res_value) == value - - @given(data=data(), value=integers()) - async def test_async_core(self, *, data: DataObject, value: int) -> None: - async with redis_cms_async(data) as redis: - assert not await redis.client.exists(redis.key) - _ = await redis.client.set(redis.key, value) - result = int(cast(str, await redis.client.get(redis.key))) + @given(data=data(), value=int32s()) + async def test_core(self, *, data: DataObject, value: int) -> None: + import redis + import redis.asyncio + + async with redis_cms(data) as container: + match container.client: + case redis.Redis() as client: + assert not client.exists(container.key) + _ = client.set(container.key, value) + result = int(cast(str, client.get(container.key))) + case redis.asyncio.Redis() as client: + assert not await client.exists(container.key) + _ = await client.set(container.key, value) + result = int(cast(str, await client.get(container.key))) assert result == value @given(data=data(), value=int32s()) - async def test_async_ts(self, *, data: DataObject, value: int) -> None: - async with redis_cms_async(data) as redis: - assert not await redis.client.exists(redis.key) - _ = await redis.ts.add(redis.key, "*", value) - res_timestamp, res_value = await redis.ts.get(redis.key) + async def test_ts(self, *, data: DataObject, value: int) -> None: + import redis + import redis.asyncio + + async with redis_cms(data) as container: + match container.client: + case redis.Redis() as client: + assert not client.exists(container.key) + _ = container.ts.add(container.key, "*", value) + res_timestamp, res_value = container.ts.get(container.key) + case redis.asyncio.Redis() as client: + assert not await client.exists(container.key) + _ = await container.ts.add(container.key, "*", value) + res_timestamp, res_value = await container.ts.get(container.key) assert isinstance(res_timestamp, int) assert int(res_value) == value diff --git a/src/tests/test_redis.py b/src/tests/test_redis.py index 0a71206f8..a862e503e 100644 --- a/src/tests/test_redis.py +++ b/src/tests/test_redis.py @@ -1,8 +1,7 @@ from __future__ import annotations -from dataclasses import dataclass from math import inf, nan -from typing import TYPE_CHECKING, ClassVar, Literal, cast +from typing import TYPE_CHECKING, ClassVar, cast import redis import redis.asyncio @@ -17,27 +16,41 @@ ) from polars import Boolean, DataFrame, DataType, Float64, Int64, Utf8 from polars.testing import assert_frame_equal -from pytest import mark, param, raises +from pytest import raises from redis.commands.timeseries import TimeSeries from tests.conftest import SKIPIF_CI_AND_NOT_LINUX from utilities.datetime import EPOCH_UTC, drop_microseconds from utilities.hypothesis import ( - YieldRedisContainer, int32s, lists_fixed_length, redis_cms, - redis_cms_async, text_ascii, zoned_datetimes, ) from utilities.polars import DatetimeUTC, check_polars_dataframe, zoned_datetime from utilities.redis import ( - TimeSeriesAddDataFrameError, - TimeSeriesAddError, - TimeSeriesMAddError, - TimeSeriesRangeError, - TimeSeriesReadDataFrameError, + _TimeSeriesAddDataFrameKeyIsNotUtf8Error, + _TimeSeriesAddDataFrameKeyMissingError, + _TimeSeriesAddDataFrameTimestampIsNotAZonedDatetimeError, + _TimeSeriesAddDataFrameTimestampMissingError, + _TimeSeriesAddErrorAtUpsertError, + _TimeSeriesAddInvalidTimestampError, + _TimeSeriesAddInvalidValueError, + _TimeSeriesMAddInvalidKeyError, + _TimeSeriesMAddInvalidTimestampError, + _TimeSeriesMAddInvalidValueError, + _TimeSeriesMAddKeyIsNotUtf8Error, + _TimeSeriesMAddKeyMissingError, + _TimeSeriesMAddTimestampIsNotAZonedDatetimeError, + _TimeSeriesMAddTimestampMissingError, + _TimeSeriesMAddValueIsNotNumericError, + _TimeSeriesMAddValueMissingError, + _TimeSeriesRangeInvalidKeyError, + _TimeSeriesRangeKeyWithInt64AndFloat64Error, + _TimeSeriesRangeNoKeysRequestedError, + _TimeSeriesReadDataFrameNoColumnsRequestedError, + _TimeSeriesReadDataFrameNoKeysRequestedError, ensure_time_series_created, ensure_time_series_created_async, time_series_add, @@ -63,10 +76,11 @@ import datetime as dt from zoneinfo import ZoneInfo - from polars._typing import PolarsDataType, SchemaDict + from polars._typing import SchemaDict from utilities.types import Number + valid_zoned_datetimes = zoned_datetimes( min_value=EPOCH_UTC, time_zone=sampled_from([HongKong, UTC]), valid=True ).map(drop_microseconds) @@ -81,96 +95,54 @@ @SKIPIF_CI_AND_NOT_LINUX class TestEnsureTimeSeriesCreated: - @given(yield_redis=redis_cms()) - def test_sync(self, *, yield_redis: YieldRedisContainer) -> None: - with yield_redis() as redis: - assert redis.client.exists(redis.key) == 0 - for _ in range(2): - ensure_time_series_created(redis.client.ts(), redis.key) - assert redis.client.exists(redis.key) == 1 - @given(data=data()) - async def test_async(self, *, data: DataObject) -> None: - async with redis_cms_async(data) as redis: - assert await redis.client.exists(redis.key) == 0 - for _ in range(2): - await ensure_time_series_created_async(redis.client.ts(), redis.key) - assert await redis.client.exists(redis.key) == 1 + async def test_main(self, *, data: DataObject) -> None: + async with redis_cms(data) as container: + match container.client: + case redis.Redis() as client: + assert client.exists(container.key) == 0 + for _ in range(2): + ensure_time_series_created(client.ts(), container.key) + assert client.exists(container.key) == 1 + case redis.asyncio.Redis() as client: + assert await client.exists(container.key) == 0 + for _ in range(2): + await ensure_time_series_created_async( + client.ts(), container.key + ) + assert await client.exists(container.key) == 1 @SKIPIF_CI_AND_NOT_LINUX class TestTimeSeriesAddAndGet: - @given( - yield_redis=redis_cms(), - timestamp=valid_zoned_datetimes, - value=int32s() | floats(allow_nan=False, allow_infinity=False), - ) - def test_sync( - self, *, yield_redis: YieldRedisContainer, timestamp: dt.datetime, value: float - ) -> None: - with yield_redis() as redis: - result = time_series_add(redis.ts, redis.key, timestamp, value) - assert isinstance(result, int) - res_timestamp, res_value = time_series_get(redis.ts, redis.key) - assert res_timestamp == timestamp.astimezone(UTC) - assert res_value == value - - @given( - yield_redis=redis_cms(), - timestamp=valid_zoned_datetimes, - value=int32s() | floats(allow_nan=False, allow_infinity=False), - ) - def test_sync_error_at_upsert( - self, *, yield_redis: YieldRedisContainer, timestamp: dt.datetime, value: float - ) -> None: - with yield_redis() as redis: - _ = time_series_add(redis.ts, redis.key, timestamp, value) - with raises( - TimeSeriesAddError, - match="Error at upsert under DUPLICATE_POLICY == 'BLOCK'; got .*", - ): - _ = time_series_add(redis.ts, redis.key, timestamp, value) - - @given( - yield_redis=redis_cms(), - timestamp=invalid_zoned_datetimes, - value=int32s() | floats(allow_nan=False, allow_infinity=False), - ) - def test_sync_invalid_timestamp( - self, *, yield_redis: YieldRedisContainer, timestamp: dt.datetime, value: float - ) -> None: - _ = assume(timestamp < EPOCH_UTC) - with ( - yield_redis() as redis, - raises( - TimeSeriesAddError, match="Timestamp must be at least the Epoch; got .*" - ), - ): - _ = time_series_add(redis.ts, redis.key, timestamp, value) - - @given(yield_redis=redis_cms(), timestamp=valid_zoned_datetimes) - @mark.parametrize("value", [param(inf), param(-inf), param(nan)]) - def test_sync_invalid_value( - self, *, yield_redis: YieldRedisContainer, timestamp: dt.datetime, value: float - ) -> None: - with ( - yield_redis() as redis, - raises(TimeSeriesAddError, match="Invalid value; got .*"), - ): - _ = time_series_add(redis.ts, redis.key, timestamp, value) - @given( data=data(), timestamp=valid_zoned_datetimes, value=int32s() | floats(allow_nan=False, allow_infinity=False), ) - async def test_async( + async def test_main( self, *, data: DataObject, timestamp: dt.datetime, value: float ) -> None: - async with redis_cms_async(data) as redis: - result = await time_series_add_async(redis.ts, redis.key, timestamp, value) + async with redis_cms(data) as container: + match container.client: + case redis.Redis(): + result = time_series_add( + container.ts, container.key, timestamp, value + ) + case redis.asyncio.Redis(): + result = await time_series_add_async( + container.ts, container.key, timestamp, value + ) assert isinstance(result, int) - res_timestamp, res_value = await time_series_get_async(redis.ts, redis.key) + match container.client: + case redis.Redis(): + res_timestamp, res_value = time_series_get( + container.ts, container.key + ) + case redis.asyncio.Redis(): + res_timestamp, res_value = await time_series_get_async( + container.ts, container.key + ) assert res_timestamp == timestamp.astimezone(UTC) assert res_value == value @@ -179,383 +151,253 @@ async def test_async( timestamp=valid_zoned_datetimes, value=int32s() | floats(allow_nan=False, allow_infinity=False), ) - async def test_async_error_at_upsert( + async def test_error_at_upsert( self, *, data: DataObject, timestamp: dt.datetime, value: float ) -> None: - async with redis_cms_async(data) as redis: - _ = await time_series_add_async(redis.ts, redis.key, timestamp, value) - with raises( - TimeSeriesAddError, - match="Error at upsert under DUPLICATE_POLICY == 'BLOCK'; got .*", - ): - _ = await time_series_add_async(redis.ts, redis.key, timestamp, value) + match = "Error at upsert under DUPLICATE_POLICY == 'BLOCK'; got .*" + async with redis_cms(data) as container: + match container.client: + case redis.Redis(): + _ = time_series_add(container.ts, container.key, timestamp, value) + with raises(_TimeSeriesAddErrorAtUpsertError, match=match): + _ = time_series_add( + container.ts, container.key, timestamp, value + ) + case redis.asyncio.Redis(): + _ = await time_series_add_async( + container.ts, container.key, timestamp, value + ) + with raises(_TimeSeriesAddErrorAtUpsertError, match=match): + _ = await time_series_add_async( + container.ts, container.key, timestamp, value + ) @given( data=data(), timestamp=invalid_zoned_datetimes, value=int32s() | floats(allow_nan=False, allow_infinity=False), ) - async def test_async_invalid_timestamp( + async def test_error_invalid_timestamp( self, *, data: DataObject, timestamp: dt.datetime, value: float ) -> None: _ = assume(timestamp < EPOCH_UTC) - async with redis_cms_async(data) as redis: - with raises( - TimeSeriesAddError, match="Timestamp must be at least the Epoch; got .*" - ): - _ = await time_series_add_async(redis.ts, redis.key, timestamp, value) + match = "Timestamp must be at least the Epoch; got .*" + async with redis_cms(data) as container: + match container.client: + case redis.Redis(): + with raises(_TimeSeriesAddInvalidTimestampError, match=match): + _ = time_series_add( + container.ts, container.key, timestamp, value + ) + case redis.asyncio.Redis(): + with raises(_TimeSeriesAddInvalidTimestampError, match=match): + _ = await time_series_add_async( + container.ts, container.key, timestamp, value + ) - @given(data=data(), timestamp=valid_zoned_datetimes) - @mark.parametrize("value", [param(inf), param(-inf), param(nan)]) - async def test_async_invalid_value( + @given( + data=data(), + timestamp=valid_zoned_datetimes, + value=sampled_from([inf, -inf, nan]), + ) + async def test_error_invalid_value( self, *, data: DataObject, timestamp: dt.datetime, value: float ) -> None: - async with redis_cms_async(data) as redis: - with raises(TimeSeriesAddError, match="Invalid value; got .*"): - _ = await time_series_add_async(redis.ts, redis.key, timestamp, value) - - -@dataclass(frozen=True, kw_only=True) -class _TestTimeSeriesAddAndReadDataFramePrepare: - df: DataFrame - key: str - timestamp: str - keys: tuple[str, str] - columns: tuple[str, str] - time_zone: ZoneInfo - schema: SchemaDict + match = "Invalid value; got .*" + async with redis_cms(data) as container: + match container.client: + case redis.Redis(): + with raises(_TimeSeriesAddInvalidValueError, match=match): + _ = time_series_add( + container.ts, container.key, timestamp, value + ) + case redis.asyncio.Redis(): + with raises(_TimeSeriesAddInvalidValueError, match=match): + _ = await time_series_add_async( + container.ts, container.key, timestamp, value + ) @SKIPIF_CI_AND_NOT_LINUX class TestTimeSeriesAddAndReadDataFrame: - schema: ClassVar[SchemaDict] = { - "key": Utf8, - "timestamp": DatetimeUTC, - "value": Float64, - } - @given( data=data(), - yield_redis=redis_cms(), - series_names=lists_fixed_length(text_ascii(), 2, unique=True).map(tuple), key_timestamp_values=lists_fixed_length(text_ascii(), 4, unique=True).map( tuple ), + strategy_dtype1=sampled_from([ + (int32s(), Int64), + (floats(allow_nan=False, allow_infinity=False), Float64), + ]), + strategy_dtype2=sampled_from([ + (int32s(), Int64), + (floats(allow_nan=False, allow_infinity=False), Float64), + ]), time_zone=sampled_from([HongKong, UTC]), - ) - @mark.parametrize( - ("strategy1", "dtype1"), - [ - param(int32s(), Int64), - param(floats(allow_nan=False, allow_infinity=False), Float64), - ], - ) - @mark.parametrize( - ("strategy2", "dtype2"), - [ - param(int32s(), Int64), - param(floats(allow_nan=False, allow_infinity=False), Float64), - ], - ) - def test_sync( - self, - *, - data: DataObject, - yield_redis: YieldRedisContainer, - series_names: tuple[str, str], - strategy1: SearchStrategy[Number], - strategy2: SearchStrategy[Number], - key_timestamp_values: tuple[str, str, str, str], - time_zone: ZoneInfo, - dtype1: DataType, - dtype2: DataType, - ) -> None: - with yield_redis() as redis: - prepared = self._prepare_main_test( - data, - redis.key, - series_names, - strategy1, - strategy2, - key_timestamp_values, - time_zone, - dtype1, - dtype2, - ) - time_series_add_dataframe( - redis.ts, prepared.df, key=prepared.key, timestamp=prepared.timestamp - ) - result = time_series_read_dataframe( - redis.ts, - prepared.keys, - prepared.columns, - output_key=prepared.key, - output_timestamp=prepared.timestamp, - output_time_zone=prepared.time_zone, - ) - check_polars_dataframe(result, height=2, schema_list=prepared.schema) - assert_frame_equal(result, prepared.df) - - @given(yield_redis=redis_cms()) - def test_sync_error_add_key_missing( - self, *, yield_redis: YieldRedisContainer - ) -> None: - df = DataFrame() - with ( - yield_redis() as redis, - raises( - TimeSeriesAddDataFrameError, - match="DataFrame must have a 'key' column; got .*", - ), - ): - _ = time_series_add_dataframe(redis.ts, df) - - @given(yield_redis=redis_cms()) - def test_sync_error_add_timestamp_missing( - self, *, yield_redis: YieldRedisContainer - ) -> None: - df = DataFrame(schema={"key": Utf8}) - with ( - yield_redis() as redis, - raises( - TimeSeriesAddDataFrameError, - match="DataFrame must have a 'timestamp' column; got .*", - ), - ): - _ = time_series_add_dataframe(redis.ts, df) - - @given(yield_redis=redis_cms()) - def test_sync_error_add_key_is_not_utf8( - self, *, yield_redis: YieldRedisContainer - ) -> None: - df = DataFrame(schema={"key": Boolean, "timestamp": DatetimeUTC}) - with ( - yield_redis() as redis, - raises( - TimeSeriesAddDataFrameError, - match="The 'key' column must be Utf8; got Boolean", - ), - ): - _ = time_series_add_dataframe(redis.ts, df) - - @given(yield_redis=redis_cms()) - def test_sync_error_madd_timestamp_is_not_a_zoned_datetime( - self, *, yield_redis: YieldRedisContainer - ) -> None: - df = DataFrame(schema={"key": Utf8, "timestamp": Boolean}) - with ( - yield_redis() as redis, - raises( - TimeSeriesAddDataFrameError, - match="The 'timestamp' column must be a zoned Datetime; got Boolean", - ), - ): - _ = time_series_add_dataframe(redis.ts, df) - - @given(yield_redis=redis_cms()) - def test_sync_error_read_no_keys_requested( - self, *, yield_redis: YieldRedisContainer - ) -> None: - with ( - yield_redis() as redis, - raises( - TimeSeriesReadDataFrameError, match="At least 1 key must be requested" - ), - ): - _ = time_series_read_dataframe(redis.ts, [], []) - - @given(yield_redis=redis_cms()) - def test_sync_error_read_no_columns_requested( - self, *, yield_redis: YieldRedisContainer - ) -> None: - with ( - yield_redis() as redis, - raises( - TimeSeriesReadDataFrameError, - match="At least 1 column must be requested", - ), - ): - _ = time_series_read_dataframe(redis.ts, redis.key, []) - - @given( - data=data(), series_names=lists_fixed_length(text_ascii(), 2, unique=True).map(tuple), - key_timestamp_values=lists_fixed_length(text_ascii(), 4, unique=True).map( - tuple - ), - time_zone=sampled_from([HongKong, UTC]), ) - @mark.parametrize( - ("strategy1", "dtype1"), - [ - param(int32s(), Int64), - param(floats(allow_nan=False, allow_infinity=False), Float64), - ], - ) - @mark.parametrize( - ("strategy2", "dtype2"), - [ - param(int32s(), Int64), - param(floats(allow_nan=False, allow_infinity=False), Float64), - ], - ) - async def test_async( + async def test_main( self, *, data: DataObject, - series_names: tuple[str, str], - strategy1: SearchStrategy[Number], - strategy2: SearchStrategy[Number], + strategy_dtype1: tuple[SearchStrategy[Number], DataType], + strategy_dtype2: tuple[SearchStrategy[Number], DataType], key_timestamp_values: tuple[str, str, str, str], + series_names: tuple[str, str], time_zone: ZoneInfo, - dtype1: DataType, - dtype2: DataType, ) -> None: - async with redis_cms_async(data) as redis: - prepared = self._prepare_main_test( - data, - redis.key, - series_names, - strategy1, - strategy2, - key_timestamp_values, - time_zone, - dtype1, - dtype2, - ) - await time_series_add_dataframe_async( - redis.ts, prepared.df, key=prepared.key, timestamp=prepared.timestamp + timestamp1, timestamp2 = data.draw( + tuples(valid_zoned_datetimes, valid_zoned_datetimes) + ) + strategy1, dtype1 = strategy_dtype1 + strategy2, dtype2 = strategy_dtype2 + value11, value21 = data.draw(tuples(strategy1, strategy1)) + value12, value22 = data.draw(tuples(strategy2, strategy2)) + key, timestamp, column1, column2 = key_timestamp_values + columns = column1, column2 + schema = { + key: Utf8, + timestamp: zoned_datetime(time_zone=time_zone), + column1: dtype1, + column2: dtype2, + } + async with redis_cms(data) as container: + key1, key2 = keys = cast( + tuple[str, str], tuple(f"{container.key}_{id_}" for id_ in series_names) ) - result = await time_series_read_dataframe_async( - redis.ts, - prepared.keys, - prepared.columns, - output_key=prepared.key, - output_timestamp=prepared.timestamp, - output_time_zone=prepared.time_zone, + df = DataFrame( + [ + (key1, timestamp1, value11, value12), + (key2, timestamp2, value21, value22), + ], + schema=schema, + orient="row", ) - check_polars_dataframe(result, height=2, schema_list=prepared.schema) - assert_frame_equal(result, prepared.df) + match container.client: + case redis.Redis(): + time_series_add_dataframe( + container.ts, df, key=key, timestamp=timestamp + ) + result = time_series_read_dataframe( + container.ts, + keys, + columns, + output_key=key, + output_timestamp=timestamp, + output_time_zone=time_zone, + ) + case redis.asyncio.Redis(): + await time_series_add_dataframe_async( + container.ts, df, key=key, timestamp=timestamp + ) + result = await time_series_read_dataframe_async( + container.ts, + keys, + columns, + output_key=key, + output_timestamp=timestamp, + output_time_zone=time_zone, + ) + check_polars_dataframe(result, height=2, schema_list=schema) + assert_frame_equal(result, df) @given(data=data()) - async def test_async_error_add_key_missing(self, *, data: DataObject) -> None: + async def test_error_add_key_missing(self, *, data: DataObject) -> None: df = DataFrame() - async with redis_cms_async(data) as redis: - with raises( - TimeSeriesAddDataFrameError, - match="DataFrame must have a 'key' column; got .*", - ): - _ = await time_series_add_dataframe_async(redis.ts, df) + match = "DataFrame must have a 'key' column; got .*" + async with redis_cms(data) as container: + match container.client: + case redis.Redis(): + with raises(_TimeSeriesAddDataFrameKeyMissingError, match=match): + _ = time_series_add_dataframe(container.ts, df) + case redis.asyncio.Redis(): + with raises(_TimeSeriesAddDataFrameKeyMissingError, match=match): + _ = await time_series_add_dataframe_async(container.ts, df) @given(data=data()) - async def test_async_error_add_timestamp_missing(self, *, data: DataObject) -> None: + async def test_error_add_timestamp_missing(self, *, data: DataObject) -> None: df = DataFrame(schema={"key": Utf8}) - async with redis_cms_async(data) as redis: - with raises( - TimeSeriesAddDataFrameError, - match="DataFrame must have a 'timestamp' column; got .*", - ): - _ = await time_series_add_dataframe_async(redis.ts, df) + match = "DataFrame must have a 'timestamp' column; got .*" + async with redis_cms(data) as container: + match container.client: + case redis.Redis(): + with raises( + _TimeSeriesAddDataFrameTimestampMissingError, match=match + ): + _ = time_series_add_dataframe(container.ts, df) + case redis.asyncio.Redis(): + with raises( + _TimeSeriesAddDataFrameTimestampMissingError, match=match + ): + _ = await time_series_add_dataframe_async(container.ts, df) @given(data=data()) - async def test_async_error_add_key_is_not_utf8(self, *, data: DataObject) -> None: + async def test_error_add_key_is_not_utf8(self, *, data: DataObject) -> None: df = DataFrame(schema={"key": Boolean, "timestamp": DatetimeUTC}) - async with redis_cms_async(data) as redis: - with raises( - TimeSeriesAddDataFrameError, - match="The 'key' column must be Utf8; got Boolean", - ): - _ = await time_series_add_dataframe_async(redis.ts, df) + match = "The 'key' column must be Utf8; got Boolean" + async with redis_cms(data) as container: + match container.client: + case redis.Redis(): + with raises(_TimeSeriesAddDataFrameKeyIsNotUtf8Error, match=match): + _ = time_series_add_dataframe(container.ts, df) + case redis.asyncio.Redis(): + with raises(_TimeSeriesAddDataFrameKeyIsNotUtf8Error, match=match): + _ = await time_series_add_dataframe_async(container.ts, df) @given(data=data()) - async def test_async_error_madd_timestamp_is_not_a_zoned_datetime( + async def test_error_add_timestamp_is_not_a_zoned_datetime( self, *, data: DataObject ) -> None: df = DataFrame(schema={"key": Utf8, "timestamp": Boolean}) - async with redis_cms_async(data) as redis: - with raises( - TimeSeriesAddDataFrameError, - match="The 'timestamp' column must be a zoned Datetime; got Boolean", - ): - _ = await time_series_add_dataframe_async(redis.ts, df) + match = "The 'timestamp' column must be a zoned Datetime; got Boolean" + async with redis_cms(data) as container: + match container.client: + case redis.Redis(): + with raises( + _TimeSeriesAddDataFrameTimestampIsNotAZonedDatetimeError, + match=match, + ): + _ = time_series_add_dataframe(container.ts, df) + case redis.asyncio.Redis(): + with raises( + _TimeSeriesAddDataFrameTimestampIsNotAZonedDatetimeError, + match=match, + ): + _ = await time_series_add_dataframe_async(container.ts, df) @given(data=data()) - async def test_async_error_read_no_keys_requested( - self, *, data: DataObject - ) -> None: - async with redis_cms_async(data) as redis: - with raises( - TimeSeriesReadDataFrameError, match="At least 1 key must be requested" - ): - _ = await time_series_read_dataframe_async(redis.ts, [], []) + async def test_error_read_no_keys_requested(self, *, data: DataObject) -> None: + match = "At least 1 key must be requested" + async with redis_cms(data) as container: + match container.client: + case redis.Redis(): + with raises( + _TimeSeriesReadDataFrameNoKeysRequestedError, match=match + ): + _ = time_series_read_dataframe(container.ts, [], []) + case redis.asyncio.Redis(): + with raises( + _TimeSeriesReadDataFrameNoKeysRequestedError, match=match + ): + _ = await time_series_read_dataframe_async(container.ts, [], []) @given(data=data()) - async def test_async_error_read_no_columns_requested( - self, *, data: DataObject - ) -> None: - async with redis_cms_async(data) as redis: - with raises( - TimeSeriesReadDataFrameError, - match="At least 1 column must be requested", - ): - _ = await time_series_read_dataframe_async(redis.ts, redis.key, []) - - def _prepare_main_test( - self, - data: DataObject, - redis_key: str, - series_names: tuple[str, str], - strategy1: SearchStrategy[Number], - strategy2: SearchStrategy[Number], - key_timestamp_values: tuple[str, str, str, str], - time_zone: ZoneInfo, - dtype1: DataType, - dtype2: DataType, - /, - ) -> _TestTimeSeriesAddAndReadDataFramePrepare: - key1, key2 = keys = cast( - tuple[str, str], tuple(f"{redis_key}_{id_}" for id_ in series_names) - ) - timestamp1, timestamp2 = data.draw( - tuples(valid_zoned_datetimes, valid_zoned_datetimes) - ) - value11, value21 = data.draw(tuples(strategy1, strategy1)) - value12, value22 = data.draw(tuples(strategy2, strategy2)) - key, timestamp, column1, column2 = key_timestamp_values - schema = { - key: Utf8, - timestamp: zoned_datetime(time_zone=time_zone), - column1: dtype1, - column2: dtype2, - } - df = DataFrame( - [ - (key1, timestamp1, value11, value12), - (key2, timestamp2, value21, value22), - ], - schema=schema, - orient="row", - ) - return _TestTimeSeriesAddAndReadDataFramePrepare( - df=df, - key=key, - timestamp=timestamp, - keys=keys, - columns=(column1, column2), - time_zone=time_zone, - schema=schema, - ) - - -@dataclass(frozen=True, kw_only=True) -class _TestTimeSeriesMAddAndRangePrepare: - keys: tuple[str, str] - triples: list[tuple[str, dt.datetime, Number]] - key: str - timestamp: str - value: str - values_or_df: list[tuple[str, dt.datetime, Number]] | DataFrame - schema: SchemaDict + async def test_error_read_no_columns_requested(self, *, data: DataObject) -> None: + match = "At least 1 column must be requested" + async with redis_cms(data) as container: + match container.client: + case redis.Redis(): + with raises( + _TimeSeriesReadDataFrameNoColumnsRequestedError, match=match + ): + _ = time_series_read_dataframe(container.ts, container.key, []) + case redis.asyncio.Redis(): + with raises( + _TimeSeriesReadDataFrameNoColumnsRequestedError, match=match + ): + _ = await time_series_read_dataframe_async( + container.ts, container.key, [] + ) @SKIPIF_CI_AND_NOT_LINUX @@ -565,534 +407,291 @@ class TestTimeSeriesMAddAndRange: "timestamp": DatetimeUTC, "value": Int64, } - float_schema: ClassVar[SchemaDict] = { - "key": Utf8, - "timestamp": DatetimeUTC, - "value": Float64, - } @given( data=data(), - yield_redis=redis_cms(), series_names=lists_fixed_length(text_ascii(), 2, unique=True).map(tuple), time_zone=sampled_from([HongKong, UTC]), key_timestamp_value=lists_fixed_length(text_ascii(), 3, unique=True).map(tuple), + strategy_dtype=sampled_from([ + (int32s(), Int64), + (floats(allow_nan=False, allow_infinity=False), Float64), + ]), ) - @mark.parametrize("case", [param("values"), param("DataFrame")]) - @mark.parametrize( - ("strategy", "dtype"), - [ - param(int32s(), Int64), - param(floats(allow_nan=False, allow_infinity=False), Float64), - ], - ) - def test_sync( + async def test_main( self, *, data: DataObject, - yield_redis: YieldRedisContainer, series_names: tuple[str, str], time_zone: ZoneInfo, key_timestamp_value: tuple[str, str, str], - case: Literal["values", "DataFrame"], - strategy: SearchStrategy[Number], - dtype: PolarsDataType, - ) -> None: - with yield_redis() as redis: - prepared = self._prepare_main_test( - data, - redis.key, - series_names, - time_zone, - key_timestamp_value, - case, - strategy, - dtype, - ) - res_madd = time_series_madd( - redis.ts, - prepared.values_or_df, - key=prepared.key, - timestamp=prepared.timestamp, - value=prepared.value, - ) - assert isinstance(res_madd, list) - for i in res_madd: - assert isinstance(i, int) - res_range = time_series_range( - redis.ts, - prepared.keys, - output_key=prepared.key, - output_timestamp=prepared.timestamp, - output_time_zone=time_zone, - output_value=prepared.value, - ) - check_polars_dataframe(res_range, height=2, schema_list=prepared.schema) - assert res_range.rows() == prepared.triples - - @given(yield_redis=redis_cms()) - def test_sync_error_madd_key_missing( - self, *, yield_redis: YieldRedisContainer - ) -> None: - df = DataFrame() - with ( - yield_redis() as redis, - raises( - TimeSeriesMAddError, match="DataFrame must have a 'key' column; got .*" - ), - ): - _ = time_series_madd(redis.ts, df) - - @given(yield_redis=redis_cms()) - def test_sync_error_madd_timestamp_missing( - self, *, yield_redis: YieldRedisContainer - ) -> None: - df = DataFrame(schema={"key": Utf8}) - with ( - yield_redis() as redis, - raises( - TimeSeriesMAddError, - match="DataFrame must have a 'timestamp' column; got .*", - ), - ): - _ = time_series_madd(redis.ts, df) - - @given(yield_redis=redis_cms()) - def test_sync_error_madd_value_missing( - self, *, yield_redis: YieldRedisContainer - ) -> None: - df = DataFrame(schema={"key": Utf8, "timestamp": DatetimeUTC}) - with ( - yield_redis() as redis, - raises( - TimeSeriesMAddError, - match="DataFrame must have a 'value' column; got .*", - ), - ): - _ = time_series_madd(redis.ts, df) - - @given(yield_redis=redis_cms()) - def test_sync_error_madd_key_is_not_utf8( - self, *, yield_redis: YieldRedisContainer - ) -> None: - df = DataFrame( - schema={"key": Boolean, "timestamp": DatetimeUTC, "value": Float64} - ) - with ( - yield_redis() as redis, - raises( - TimeSeriesMAddError, match="The 'key' column must be Utf8; got Boolean" - ), - ): - _ = time_series_madd(redis.ts, df) - - @given(yield_redis=redis_cms()) - def test_sync_error_madd_timestamp_is_not_a_zoned_datetime( - self, *, yield_redis: YieldRedisContainer - ) -> None: - df = DataFrame(schema={"key": Utf8, "timestamp": Boolean, "value": Float64}) - with ( - yield_redis() as redis, - raises( - TimeSeriesMAddError, - match="The 'timestamp' column must be a zoned Datetime; got Boolean", - ), - ): - _ = time_series_madd(redis.ts, df) - - @given(yield_redis=redis_cms()) - def test_sync_error_madd_value_is_not_numeric( - self, *, yield_redis: YieldRedisContainer - ) -> None: - df = DataFrame(schema={"key": Utf8, "timestamp": DatetimeUTC, "value": Boolean}) - with ( - yield_redis() as redis, - raises( - TimeSeriesMAddError, - match="The 'value' column must be numeric; got Boolean", - ), - ): - _ = time_series_madd(redis.ts, df) - - @given(data=data(), yield_redis=redis_cms()) - @mark.parametrize("case", [param("values"), param("DataFrame")]) - def test_sync_error_madd_invalid_key( - self, - *, - data: DataObject, - yield_redis: YieldRedisContainer, - case: Literal["values", "DataFrame"], + strategy_dtype: tuple[SearchStrategy[Number], DataType], ) -> None: - with yield_redis() as redis: - values_or_df = self._prepare_test_error_madd_invalid_key( - data, redis.key, case - ) - with raises(TimeSeriesMAddError, match="The key '.*' must exist"): - _ = time_series_madd( - redis.ts, values_or_df, assume_time_series_exist=True - ) - - @given(data=data(), yield_redis=redis_cms()) - @mark.parametrize("case", [param("values"), param("DataFrame")]) - def test_sync_error_madd_invalid_timestamp( - self, - *, - data: DataObject, - yield_redis: YieldRedisContainer, - case: Literal["values", "DataFrame"], - ) -> None: - with yield_redis() as redis: - values_or_df = self._prepare_test_error_madd_invalid_timestamp( - data, redis.key, case - ) - with raises( - TimeSeriesMAddError, - match="Timestamps must be at least the Epoch; got .*", - ): - _ = time_series_madd(redis.ts, values_or_df) - - @given(data=data(), yield_redis=redis_cms()) - @mark.parametrize("case", [param("values"), param("DataFrame")]) - @mark.parametrize("value", [param(inf), param(-inf), param(nan)]) - def test_sync_error_madd_invalid_value( - self, - *, - data: DataObject, - yield_redis: YieldRedisContainer, - case: Literal["values", "DataFrame"], - value: float, - ) -> None: - with yield_redis() as redis: - values_or_df = self._prepare_test_error_madd_invalid_value( - data, redis.key, case, value - ) - with raises(TimeSeriesMAddError, match="The value .* is invalid"): - _ = time_series_madd(redis.ts, values_or_df) - - @given(yield_redis=redis_cms()) - def test_sync_error_range_no_keys_requested( - self, *, yield_redis: YieldRedisContainer - ) -> None: - with ( - yield_redis() as redis, - raises( - TimeSeriesRangeError, match="At least 1 key must be requested; got .*" - ), - ): - _ = time_series_range(redis.ts, []) - - @given(yield_redis=redis_cms()) - def test_sync_error_range_invalid_key( - self, *, yield_redis: YieldRedisContainer - ) -> None: - with ( - yield_redis() as redis, - raises(TimeSeriesRangeError, match="The key '.*' must exist"), - ): - _ = time_series_range(redis.ts, redis.key) - - @given(data=data(), yield_redis=redis_cms()) - def test_sync_error_range_key_with_int64_and_float64( - self, *, data: DataObject, yield_redis: YieldRedisContainer - ) -> None: - with yield_redis() as redis: - values = self._prepare_test_error_range_key_with_int64_and_float64( - data, redis.key - ) - for vals in values: - _ = time_series_madd(redis.ts, vals) - with raises( - TimeSeriesRangeError, - match="The key '.*' contains both Int64 and Float64 data", - ): - _ = time_series_range(redis.ts, redis.key) - - @given( - data=data(), - series_names=lists_fixed_length(text_ascii(), 2, unique=True).map(tuple), - time_zone=sampled_from([HongKong, UTC]), - key_timestamp_value=lists_fixed_length(text_ascii(), 3, unique=True).map(tuple), - ) - @mark.parametrize("case", [param("values"), param("DataFrame")]) - @mark.parametrize( - ("strategy", "dtype"), - [ - param(int32s(), Int64), - param(floats(allow_nan=False, allow_infinity=False), Float64), - ], - ) - async def test_async( - self, - *, - data: DataObject, - series_names: tuple[str, str], - time_zone: ZoneInfo, - key_timestamp_value: tuple[str, str, str], - case: Literal["values", "DataFrame"], - strategy: SearchStrategy[Number], - dtype: PolarsDataType, - ) -> None: - async with redis_cms_async(data) as redis: - prepared = self._prepare_main_test( - data, - redis.key, - series_names, - time_zone, - key_timestamp_value, - case, - strategy, - dtype, + timestamps = data.draw(tuples(valid_zoned_datetimes, valid_zoned_datetimes)) + strategy, dtype = strategy_dtype + values = data.draw(tuples(strategy, strategy)) + key, timestamp, value = key_timestamp_value + async with redis_cms(data) as container: + keys = cast( + tuple[str, str], + tuple(f"{container.key}_{name}" for name in series_names), ) - res_madd = await time_series_madd_async( - redis.ts, - prepared.values_or_df, - key=prepared.key, - timestamp=prepared.timestamp, - value=prepared.value, + triples = list(zip(keys, timestamps, values, strict=True)) + schema = { + key: Utf8, + timestamp: zoned_datetime(time_zone=time_zone), + value: dtype, + } + values_or_df = data.draw( + sampled_from([triples, DataFrame(triples, schema=schema, orient="row")]) ) - assert isinstance(res_madd, list) + match container.client: + case redis.Redis(): + res_madd = time_series_madd( + container.ts, + values_or_df, + key=key, + timestamp=timestamp, + value=value, + ) + case redis.asyncio.Redis(): + res_madd = await time_series_madd_async( + container.ts, + values_or_df, + key=key, + timestamp=timestamp, + value=value, + ) for i in res_madd: assert isinstance(i, int) - res_range = await time_series_range_async( - redis.ts, - prepared.keys, - output_key=prepared.key, - output_timestamp=prepared.timestamp, - output_time_zone=time_zone, - output_value=prepared.value, - ) - check_polars_dataframe(res_range, height=2, schema_list=prepared.schema) - assert res_range.rows() == prepared.triples + match container.client: + case redis.Redis(): + res_range = time_series_range( + container.ts, + keys, + output_key=key, + output_timestamp=timestamp, + output_time_zone=time_zone, + output_value=value, + ) + case redis.asyncio.Redis(): + res_range = await time_series_range_async( + container.ts, + keys, + output_key=key, + output_timestamp=timestamp, + output_time_zone=time_zone, + output_value=value, + ) + check_polars_dataframe(res_range, height=2, schema_list=schema) + assert res_range.rows() == triples @given(data=data()) - async def test_async_error_madd_key_missing(self, *, data: DataObject) -> None: + async def test_error_madd_key_missing(self, *, data: DataObject) -> None: df = DataFrame() - async with redis_cms_async(data) as redis: - with raises( - TimeSeriesMAddError, match="DataFrame must have a 'key' column; got .*" - ): - _ = await time_series_madd_async(redis.ts, df) + match = "DataFrame must have a 'key' column; got .*" + async with redis_cms(data) as container: + match container.client: + case redis.Redis(): + with raises(_TimeSeriesMAddKeyMissingError, match=match): + _ = time_series_madd(container.ts, df) + case redis.asyncio.Redis(): + with raises(_TimeSeriesMAddKeyMissingError, match=match): + _ = await time_series_madd_async(container.ts, df) @given(data=data()) - async def test_async_error_madd_timestamp_missing( - self, *, data: DataObject - ) -> None: + async def test_error_madd_timestamp_missing(self, *, data: DataObject) -> None: df = DataFrame(schema={"key": Utf8}) - async with redis_cms_async(data) as redis: - with raises( - TimeSeriesMAddError, - match="DataFrame must have a 'timestamp' column; got .*", - ): - _ = await time_series_madd_async(redis.ts, df) + match = "DataFrame must have a 'timestamp' column; got .*" + async with redis_cms(data) as container: + match container.client: + case redis.Redis(): + with raises(_TimeSeriesMAddTimestampMissingError, match=match): + _ = time_series_madd(container.ts, df) + case redis.asyncio.Redis(): + with raises(_TimeSeriesMAddTimestampMissingError, match=match): + _ = await time_series_madd_async(container.ts, df) @given(data=data()) - async def test_async_error_madd_value_missing(self, *, data: DataObject) -> None: + async def test_error_madd_value_missing(self, *, data: DataObject) -> None: df = DataFrame(schema={"key": Utf8, "timestamp": DatetimeUTC}) - async with redis_cms_async(data) as redis: - with raises( - TimeSeriesMAddError, - match="DataFrame must have a 'value' column; got .*", - ): - _ = await time_series_madd_async(redis.ts, df) + match = "DataFrame must have a 'value' column; got .*" + async with redis_cms(data) as container: + match container.client: + case redis.Redis(): + with raises(_TimeSeriesMAddValueMissingError, match=match): + _ = time_series_madd(container.ts, df) + case redis.asyncio.Redis(): + with raises(_TimeSeriesMAddValueMissingError, match=match): + _ = await time_series_madd_async(container.ts, df) @given(data=data()) - async def test_async_error_madd_key_is_not_utf8(self, *, data: DataObject) -> None: + async def test_error_madd_key_is_not_utf8(self, *, data: DataObject) -> None: df = DataFrame( schema={"key": Boolean, "timestamp": DatetimeUTC, "value": Float64} ) - async with redis_cms_async(data) as redis: - with raises( - TimeSeriesMAddError, match="The 'key' column must be Utf8; got Boolean" - ): - _ = await time_series_madd_async(redis.ts, df) + match = "The 'key' column must be Utf8; got Boolean" + async with redis_cms(data) as container: + match container.client: + case redis.Redis(): + with raises(_TimeSeriesMAddKeyIsNotUtf8Error, match=match): + _ = time_series_madd(container.ts, df) + case redis.asyncio.Redis(): + with raises(_TimeSeriesMAddKeyIsNotUtf8Error, match=match): + _ = await time_series_madd_async(container.ts, df) @given(data=data()) - async def test_async_error_madd_timestamp_is_not_a_zoned_datetime( + async def test_error_madd_timestamp_is_not_a_zoned_datetime( self, *, data: DataObject ) -> None: df = DataFrame(schema={"key": Utf8, "timestamp": Boolean, "value": Float64}) - async with redis_cms_async(data) as redis: - with raises( - TimeSeriesMAddError, - match="The 'timestamp' column must be a zoned Datetime; got Boolean", - ): - _ = await time_series_madd_async(redis.ts, df) + match = "The 'timestamp' column must be a zoned Datetime; got Boolean" + async with redis_cms(data) as container: + match container.client: + case redis.Redis(): + with raises( + _TimeSeriesMAddTimestampIsNotAZonedDatetimeError, match=match + ): + _ = time_series_madd(container.ts, df) + case redis.asyncio.Redis(): + with raises( + _TimeSeriesMAddTimestampIsNotAZonedDatetimeError, match=match + ): + _ = await time_series_madd_async(container.ts, df) @given(data=data()) - async def test_async_error_madd_value_is_not_numeric( - self, *, data: DataObject - ) -> None: + async def test_error_madd_value_is_not_numeric(self, *, data: DataObject) -> None: df = DataFrame(schema={"key": Utf8, "timestamp": DatetimeUTC, "value": Boolean}) - async with redis_cms_async(data) as redis: - with raises( - TimeSeriesMAddError, - match="The 'value' column must be numeric; got Boolean", - ): - _ = await time_series_madd_async(redis.ts, df) - - @given(data=data()) - @mark.parametrize("case", [param("values"), param("DataFrame")]) - async def test_async_error_madd_invalid_key( - self, *, data: DataObject, case: Literal["values", "DataFrame"] - ) -> None: - async with redis_cms_async(data) as redis: - values_or_df = self._prepare_test_error_madd_invalid_key( - data, redis.key, case + match = "The 'value' column must be numeric; got Boolean" + async with redis_cms(data) as container: + match container.client: + case redis.Redis(): + with raises(_TimeSeriesMAddValueIsNotNumericError, match=match): + _ = time_series_madd(container.ts, df) + case redis.asyncio.Redis(): + with raises(_TimeSeriesMAddValueIsNotNumericError, match=match): + _ = await time_series_madd_async(container.ts, df) + + @given(data=data(), timestamp=valid_zoned_datetimes, value=int32s()) + async def test_error_madd_invalid_key( + self, *, data: DataObject, timestamp: dt.datetime, value: int + ) -> None: + match = "The key '.*' must exist" + async with redis_cms(data) as container: + values = [(container.key, timestamp, value)] + values_or_df = data.draw( + sampled_from([ + values, + DataFrame(values, schema=self.int_schema, orient="row"), + ]) ) - with raises(TimeSeriesMAddError, match="The key '.*' must exist"): - _ = await time_series_madd_async( - redis.ts, values_or_df, assume_time_series_exist=True - ) - - @given(data=data()) - @mark.parametrize("case", [param("values"), param("DataFrame")]) - async def test_async_error_madd_invalid_timestamp( - self, *, data: DataObject, case: Literal["values", "DataFrame"] + match container.client: + case redis.Redis(): + with raises(_TimeSeriesMAddInvalidKeyError, match=match): + _ = time_series_madd( + container.ts, values_or_df, assume_time_series_exist=True + ) + case redis.asyncio.Redis(): + with raises(_TimeSeriesMAddInvalidKeyError, match=match): + _ = await time_series_madd_async( + container.ts, values_or_df, assume_time_series_exist=True + ) + + @given(data=data(), timestamp=invalid_zoned_datetimes) + async def test_error_madd_invalid_timestamp( + self, *, data: DataObject, timestamp: dt.datetime ) -> None: - async with redis_cms_async(data) as redis: - values_or_df = self._prepare_test_error_madd_invalid_timestamp( - data, redis.key, case + value = data.draw(int32s()) + match = "Timestamps must be at least the Epoch; got .*" + async with redis_cms(data) as container: + values = [(container.key, timestamp, value)] + values_or_df = data.draw( + sampled_from([ + values, + DataFrame(values, schema=self.int_schema, orient="row"), + ]) ) - with raises( - TimeSeriesMAddError, - match="Timestamps must be at least the Epoch; got .*", - ): - _ = await time_series_madd_async(redis.ts, values_or_df) + match container.client: + case redis.Redis(): + with raises(_TimeSeriesMAddInvalidTimestampError, match=match): + _ = time_series_madd(container.ts, values_or_df) + case redis.asyncio.Redis(): + with raises(_TimeSeriesMAddInvalidTimestampError, match=match): + _ = await time_series_madd_async(container.ts, values_or_df) - @given(data=data()) - @mark.parametrize("case", [param("values"), param("DataFrame")]) - @mark.parametrize("value", [param(inf), param(-inf), param(nan)]) - async def test_async_error_madd_invalid_value( - self, *, data: DataObject, case: Literal["values", "DataFrame"], value: float + @given( + data=data(), + timestamp=valid_zoned_datetimes, + value=sampled_from([inf, -inf, nan]), + ) + async def test_error_madd_invalid_value( + self, *, data: DataObject, timestamp: dt.datetime, value: float ) -> None: - async with redis_cms_async(data) as redis: - values_or_df = self._prepare_test_error_madd_invalid_value( - data, redis.key, case, value + timestamp = data.draw(valid_zoned_datetimes) + match = "The value .* is invalid" + schema = {"key": Utf8, "timestamp": DatetimeUTC, "value": Float64} + async with redis_cms(data) as container: + values = [(f"{container.key}", timestamp, value)] + values_or_df = data.draw( + sampled_from([values, DataFrame(values, schema=schema, orient="row")]) ) - with raises(TimeSeriesMAddError, match="The value .* is invalid"): - _ = await time_series_madd_async(redis.ts, values_or_df) - - @given(data=data()) - async def test_async_error_range_no_keys_requested( - self, *, data: DataObject - ) -> None: - async with redis_cms_async(data) as redis: - with raises( - TimeSeriesRangeError, match="At least 1 key must be requested; got .*" - ): - _ = await time_series_range_async(redis.ts, []) + match container.client: + case redis.Redis(): + with raises(_TimeSeriesMAddInvalidValueError, match=match): + _ = time_series_madd(container.ts, values_or_df) + case redis.asyncio.Redis(): + with raises(_TimeSeriesMAddInvalidValueError, match=match): + _ = await time_series_madd_async(container.ts, values_or_df) @given(data=data()) - async def test_async_error_range_invalid_key(self, *, data: DataObject) -> None: - async with redis_cms_async(data) as redis: - with raises(TimeSeriesRangeError, match="The key '.*' must exist"): - _ = await time_series_range_async(redis.ts, redis.key) + async def test_error_range_no_keys_requested(self, *, data: DataObject) -> None: + match = "At least 1 key must be requested; got .*" + async with redis_cms(data) as container: + match container.client: + case redis.Redis(): + with raises(_TimeSeriesRangeNoKeysRequestedError, match=match): + _ = time_series_range(container.ts, []) + case redis.asyncio.Redis(): + with raises(_TimeSeriesRangeNoKeysRequestedError, match=match): + _ = await time_series_range_async(container.ts, []) @given(data=data()) - async def test_async_error_range_key_with_int64_and_float64( - self, *, data: DataObject - ) -> None: - async with redis_cms_async(data) as redis: - values = self._prepare_test_error_range_key_with_int64_and_float64( - data, redis.key + async def test_error_range_invalid_key(self, *, data: DataObject) -> None: + match = "The key '.*' must exist" + async with redis_cms(data) as container: + match container.client: + case redis.Redis(): + with raises(_TimeSeriesRangeInvalidKeyError, match=match): + _ = time_series_range(container.ts, container.key) + case redis.asyncio.Redis(): + with raises(_TimeSeriesRangeInvalidKeyError, match=match): + _ = await time_series_range_async(container.ts, container.key) + + @given(data=data(), timestamp=valid_zoned_datetimes, value=int32s()) + async def test_error_range_key_with_int64_and_float64( + self, *, data: DataObject, timestamp: dt.datetime, value: int + ) -> None: + match = "The key '.*' contains both Int64 and Float64 data" + async with redis_cms(data) as container: + values = ( + [(container.key, timestamp, value)], + [(container.key, timestamp, float(value))], ) - for vals in values: - _ = await time_series_madd_async(redis.ts, vals) - with raises( - TimeSeriesRangeError, - match="The key '.*' contains both Int64 and Float64 data", - ): - _ = await time_series_range_async(redis.ts, redis.key) - - def _prepare_main_test( - self, - data: DataObject, - redis_key: str, - series_names: tuple[str, str], - time_zone: ZoneInfo, - key_timestamp_value: tuple[str, str, str], - case: Literal["values", "DataFrame"], - strategy: SearchStrategy[Number], - dtype: PolarsDataType, - /, - ) -> _TestTimeSeriesMAddAndRangePrepare: - keys = cast( - tuple[str, str], - tuple(f"{redis_key}_{case}_{name}" for name in series_names), - ) - timestamps = data.draw(tuples(valid_zoned_datetimes, valid_zoned_datetimes)) - values = data.draw(tuples(strategy, strategy)) - triples = list(zip(keys, timestamps, values, strict=True)) - key, timestamp, value = key_timestamp_value - schema = { - key: Utf8, - timestamp: zoned_datetime(time_zone=time_zone), - value: dtype, - } - match case: - case "values": - values_or_df = triples - case "DataFrame": - values_or_df = DataFrame(triples, schema=schema, orient="row") - return _TestTimeSeriesMAddAndRangePrepare( - keys=keys, - triples=triples, - key=key, - timestamp=timestamp, - value=value, - values_or_df=values_or_df, - schema=schema, - ) - - def _prepare_test_error_madd_invalid_key( - self, data: DataObject, key: str, case: Literal["values", "DataFrame"], / - ) -> list[tuple[str, dt.datetime, int]] | DataFrame: - timestamp = data.draw(valid_zoned_datetimes) - value = data.draw(int32s()) - values = [(f"{key}_{case}", timestamp, value)] - match case: - case "values": - return values - case "DataFrame": - return DataFrame(values, schema=self.int_schema, orient="row") - - def _prepare_test_error_madd_invalid_timestamp( - self, data: DataObject, key: str, case: Literal["values", "DataFrame"], / - ) -> list[tuple[str, dt.datetime, int]] | DataFrame: - timestamp = data.draw(invalid_zoned_datetimes) - _ = assume(timestamp < EPOCH_UTC) - value = data.draw(int32s()) - values = [(f"{key}_{case}", timestamp, value)] - match case: - case "values": - return values - case "DataFrame": - return DataFrame(values, schema=self.int_schema, orient="row") - - def _prepare_test_error_madd_invalid_value( - self, - data: DataObject, - key: str, - case: Literal["values", "DataFrame"], - value: float, - /, - ) -> list[tuple[str, dt.datetime, float]] | DataFrame: - timestamp = data.draw(valid_zoned_datetimes) - values = [(f"{key}_{case}", timestamp, value)] - match case: - case "values": - return values - case "DataFrame": - return DataFrame(values, schema=self.float_schema, orient="row") - - def _prepare_test_error_range_key_with_int64_and_float64( - self, data: DataObject, key: str, / - ) -> tuple[ - list[tuple[str, dt.datetime, int]], list[tuple[str, dt.datetime, float]] - ]: - timestamp = data.draw(valid_zoned_datetimes) - value = data.draw(int32s()) - return [(key, timestamp, value)], [(key, timestamp, float(value))] + match container.client: + case redis.Redis(): + for vals in values: + _ = time_series_madd(container.ts, vals) + with raises( + _TimeSeriesRangeKeyWithInt64AndFloat64Error, match=match + ): + _ = time_series_range(container.ts, container.key) + case redis.asyncio.Redis(): + for vals in values: + _ = await time_series_madd_async(container.ts, vals) + with raises( + _TimeSeriesRangeKeyWithInt64AndFloat64Error, match=match + ): + _ = await time_series_range_async(container.ts, container.key) class TestYieldClient: diff --git a/src/utilities/__init__.py b/src/utilities/__init__.py index c4eff51ab..1bc06655a 100644 --- a/src/utilities/__init__.py +++ b/src/utilities/__init__.py @@ -1,3 +1,3 @@ from __future__ import annotations -__version__ = "0.53.1" +__version__ = "0.53.2" diff --git a/src/utilities/hypothesis.py b/src/utilities/hypothesis.py index d984098c7..cb1247341 100644 --- a/src/utilities/hypothesis.py +++ b/src/utilities/hypothesis.py @@ -2,22 +2,13 @@ import builtins import datetime as dt -from collections.abc import ( - AsyncIterator, - Callable, - Collection, - Hashable, - Iterable, - Iterator, -) +from collections.abc import AsyncIterator, Collection, Hashable, Iterable, Iterator from contextlib import ( AbstractAsyncContextManager, - AbstractContextManager, asynccontextmanager, contextmanager, suppress, ) -from dataclasses import dataclass from datetime import timezone from enum import Enum, auto from math import ceil, floor, inf, isfinite, nan @@ -26,20 +17,9 @@ from re import search from string import ascii_letters, printable from subprocess import run -from typing import ( - TYPE_CHECKING, - Any, - Generic, - Protocol, - TypeVar, - assert_never, - cast, - overload, -) +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, assert_never, cast, overload from zoneinfo import ZoneInfo -import redis -import redis.asyncio from hypothesis import HealthCheck, Phase, Verbosity, assume, settings from hypothesis.errors import InvalidArgument from hypothesis.strategies import ( @@ -63,8 +43,6 @@ uuids, ) from hypothesis.utils.conventions import not_set -from redis.exceptions import ResponseError -from redis.typing import KeyT from utilities.datetime import MAX_MONTH, MIN_MONTH, Month, date_to_month, get_now from utilities.math import MAX_INT32, MAX_INT64, MAX_UINT32, MIN_INT32, MIN_INT64 @@ -76,15 +54,15 @@ from utilities.zoneinfo import UTC if TYPE_CHECKING: - from uuid import UUID - + import redis + import redis.asyncio from hypothesis.database import ExampleDatabase from numpy.random import RandomState - from redis.commands.timeseries import TimeSeries from sqlalchemy import Engine, MetaData from sqlalchemy.ext.asyncio import AsyncEngine from utilities.numpy import NDArrayB, NDArrayF, NDArrayI, NDArrayO + from utilities.redis import _RedisContainer from utilities.types import Duration, Number @@ -433,76 +411,58 @@ def random_states( return RandomState(seed=seed_use) -_TRedis = TypeVar("_TRedis", redis.Redis, redis.asyncio.Redis) - - -@dataclass(repr=False, frozen=True, kw_only=True) -class RedisContainer(Generic[_TRedis]): - """A container for `redis.Client`.""" - - client: _TRedis - timestamp: dt.datetime - uuid: UUID - key: str - - @property - def ts(self) -> TimeSeries: - return self.client.ts() # skipif-ci-and-not-linux +_ASYNCS = booleans() -YieldRedisContainer = Callable[[], AbstractContextManager[RedisContainer[redis.Redis]]] - - -@composite -def redis_cms(draw: DrawFn, /) -> YieldRedisContainer: +def redis_cms( + data: DataObject, /, *, async_: MaybeSearchStrategy[bool] = _ASYNCS +) -> AbstractAsyncContextManager[ + _RedisContainer[redis.Redis] | _RedisContainer[redis.asyncio.Redis] +]: """Strategy for generating redis clients (with cleanup).""" - from redis import Redis # skipif-ci-and-not-linux + import redis # skipif-ci-and-not-linux + import redis.asyncio # skipif-ci-and-not-linux from redis.exceptions import ResponseError # skipif-ci-and-not-linux + from redis.typing import KeyT # skipif-ci-and-not-linux + + from utilities.redis import _RedisContainer # skipif-ci-and-not-linux + draw = lift_data(data) # skipif-ci-and-not-linux now = get_now(time_zone="local") # skipif-ci-and-not-linux uuid = draw(uuids()) # skipif-ci-and-not-linux key = f"{now}_{uuid}" # skipif-ci-and-not-linux - @contextmanager - def yield_redis() -> ( # skipif-ci-and-not-linux - Iterator[RedisContainer[redis.Redis]] - ): - with Redis(db=15) as client: # skipif-ci-and-not-linux - keys = cast(list[KeyT], client.keys(pattern=f"{key}_*")) - with suppress(ResponseError): - _ = client.delete(*keys) - yield RedisContainer(client=client, timestamp=now, uuid=uuid, key=key) - keys = cast(list[KeyT], client.keys(pattern=f"{key}_*")) - with suppress(ResponseError): - _ = client.delete(*keys) - - return yield_redis # skipif-ci-and-not-linux + if draw(async_): # skipif-ci-and-not-linux + @asynccontextmanager + async def yield_redis_async() -> ( # skipif-ci-and-not-linux + AsyncIterator[_RedisContainer[redis.asyncio.Redis]] + ): + async with redis.asyncio.Redis(db=15) as client: # skipif-ci-and-not-linux + keys = cast(list[KeyT], await client.keys(pattern=f"{key}_*")) + with suppress(ResponseError): + _ = await client.delete(*keys) + yield _RedisContainer(client=client, timestamp=now, uuid=uuid, key=key) + keys = cast(list[KeyT], await client.keys(pattern=f"{key}_*")) + with suppress(ResponseError): + _ = await client.delete(*keys) -def redis_cms_async( - data: DataObject, / -) -> AbstractAsyncContextManager[RedisContainer[redis.asyncio.Redis]]: - """Strategy for generating asynchronous redis clients.""" - from redis.asyncio import Redis # skipif-ci-and-not-linux - - now = get_now(time_zone="local") # skipif-ci-and-not-linux - uuid = data.draw(uuids()) # skipif-ci-and-not-linux - key = f"{now}_{uuid}" # skipif-ci-and-not-linux + return yield_redis_async() # skipif-ci-and-not-linux @asynccontextmanager - async def yield_redis_async() -> ( # skipif-ci-and-not-linux - AsyncIterator[RedisContainer[redis.asyncio.Redis]] + async def yield_redis_sync() -> ( # skipif-ci-and-not-linux + AsyncIterator[_RedisContainer[redis.Redis]] ): - async with Redis(db=15) as client: # skipif-ci-and-not-linux - keys = cast(list[KeyT], await client.keys(pattern=f"{key}_*")) + with redis.Redis(db=15) as client: # skipif-ci-and-not-linux + keys = cast(list[KeyT], client.keys(pattern=f"{key}_*")) with suppress(ResponseError): - _ = await client.delete(*keys) - yield RedisContainer(client=client, timestamp=now, uuid=uuid, key=key) - keys = cast(list[KeyT], await client.keys(pattern=f"{key}_*")) + _ = client.delete(*keys) + yield _RedisContainer(client=client, timestamp=now, uuid=uuid, key=key) + keys = cast(list[KeyT], client.keys(pattern=f"{key}_*")) with suppress(ResponseError): - _ = await client.delete(*keys) + _ = client.delete(*keys) - return yield_redis_async() # skipif-ci-and-not-linux + return yield_redis_sync() # skipif-ci-and-not-linux @composite @@ -823,7 +783,6 @@ def _draw_text( __all__ = [ "MaybeSearchStrategy", - "RedisContainer", "Shape", "aiosqlite_engines", "assume_does_not_raise", @@ -842,7 +801,6 @@ def _draw_text( "months", "random_states", "redis_cms", - "redis_cms_async", "sets_fixed_length", "setup_hypothesis_profiles", "slices", diff --git a/src/utilities/redis.py b/src/utilities/redis.py index 6925b0e60..3bdb4f5f2 100644 --- a/src/utilities/redis.py +++ b/src/utilities/redis.py @@ -5,7 +5,16 @@ from functools import partial from itertools import product from re import search -from typing import TYPE_CHECKING, Any, Literal, NoReturn, assert_never, cast +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Literal, + NoReturn, + TypeVar, + assert_never, + cast, +) import redis import redis.asyncio @@ -25,6 +34,7 @@ if TYPE_CHECKING: import datetime as dt from collections.abc import AsyncIterator, Iterable, Iterator, Sequence + from uuid import UUID from zoneinfo import ZoneInfo from polars import DataFrame @@ -43,6 +53,23 @@ _SPLIT = "|" +_TRedis = TypeVar("_TRedis", redis.Redis, redis.asyncio.Redis) + + +@dataclass(repr=False, frozen=True, kw_only=True) +class _RedisContainer(Generic[_TRedis]): + """A container for a client; for testing purposes only.""" + + client: _TRedis + timestamp: dt.datetime + uuid: UUID + key: str + + @property + def ts(self) -> TimeSeries: + return self.client.ts() # skipif-ci-and-not-linux + + def ensure_time_series_created( ts: TimeSeries, /,