diff --git a/conda/meta.yaml b/conda/meta.yaml
index 689aeb5c7..5304194bb 100644
--- a/conda/meta.yaml
+++ b/conda/meta.yaml
@@ -17,8 +17,8 @@ requirements:
- python >=3.9,<3.13
- typing_extensions >=4.8
- orjson >=3.9,<4
- - pydantic >=2.7,<2.12
- - pydantic-settings >=2.3,<2.11
+ - pydantic >=2.7,<2.13
+ - pydantic-settings >=2.3,<2.12
- jsonschema >=4.3.0
- fastavro >=1.8,<2.0
- jsonlines >=4,<5
diff --git a/docs/api-reference/dataframe.md b/docs/api-reference/dataframe.md
index 09e2f5790..d6a961573 100644
--- a/docs/api-reference/dataframe.md
+++ b/docs/api-reference/dataframe.md
@@ -10,7 +10,7 @@
class StreamingDataFrame()
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L90)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L94)
`StreamingDataFrame` is the main object you will use for ETL work.
@@ -73,7 +73,7 @@ sdf = sdf.to_topic(topic_obj)
def stream_id() -> str
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L175)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L179)
An identifier of the data stream this StreamingDataFrame
manipulates in the application.
@@ -107,7 +107,7 @@ def apply(func: Union[
metadata: bool = False) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L234)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L238)
Apply a function to transform the value and return a new value.
@@ -165,7 +165,7 @@ def update(func: Union[
metadata: bool = False) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L338)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L342)
Apply a function to mutate value in-place or to perform a side effect
@@ -233,7 +233,7 @@ def filter(func: Union[
metadata: bool = False) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L441)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L445)
Filter value using provided function.
@@ -285,7 +285,7 @@ def group_by(key: Union[str, Callable[[Any], Any]],
key_serializer: SerializerType = "json") -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L526)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L530)
"Groups" messages by re-keying them via the provided group_by operation
@@ -350,7 +350,7 @@ a clone with this operation added (assign to keep its effect).
def contains(keys: Union[str, list[str]]) -> StreamingSeries
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L640)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L644)
Check if keys are present in the Row value.
@@ -392,7 +392,7 @@ def to_topic(
key: Optional[Callable[[Any], Any]] = None) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L684)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L688)
Produce current value to a topic. You can optionally specify a new key.
@@ -463,7 +463,7 @@ def set_timestamp(
func: Callable[[Any, Any, int, Any], int]) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L753)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L757)
Set a new timestamp based on the current message value and its metadata.
@@ -516,7 +516,7 @@ def set_headers(
) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L796)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L800)
Set new message headers based on the current message value and metadata.
@@ -565,7 +565,7 @@ a new StreamingDataFrame instance
def print(pretty: bool = True, metadata: bool = False) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L847)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L851)
Print out the current message value (and optionally, the message metadata) to
@@ -628,7 +628,7 @@ def print_table(
int]] = None) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L893)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L897)
Print a table with the most recent records.
@@ -721,7 +721,7 @@ sdf.print_table(size=5, title="Live Records", slowdown=1)
def compose(sink: Optional[VoidExecutor] = None) -> dict[str, VoidExecutor]
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1009)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1013)
Compose all functions of this StreamingDataFrame into one big closure.
@@ -775,7 +775,7 @@ def test(value: Any,
topic: Optional[Topic] = None) -> List[Any]
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1043)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1047)
A shorthand to test `StreamingDataFrame` with provided value
@@ -811,11 +811,13 @@ def tumbling_window(
duration_ms: Union[int, timedelta],
grace_ms: Union[int, timedelta] = 0,
name: Optional[str] = None,
- on_late: Optional[WindowOnLateCallback] = None
+ on_late: Optional[WindowOnLateCallback] = None,
+ before_update: Optional[WindowBeforeUpdateCallback] = None,
+ after_update: Optional[WindowAfterUpdateCallback] = None
) -> TumblingTimeWindowDefinition
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1082)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1086)
Create a time-based tumbling window transformation on this StreamingDataFrame.
@@ -885,6 +887,18 @@ sdf = (
If the callback returns `True`, the message about a late record will be logged
(default behavior).
Otherwise, no message will be logged.
+ - `before_update`: an optional callback to trigger early window expiration
+ before the window is updated.
+ The callback receives `aggregated` (current aggregated value or default/None),
+ `value`, `key`, `timestamp`, and `headers`.
+ If it returns `True`, the window will be expired immediately.
+ Default - `None`.
+ - `after_update`: an optional callback to trigger early window expiration
+ after the window is updated.
+ The callback receives `aggregated` (updated aggregated value), `value`, `key`,
+ `timestamp`, and `headers`.
+ If it returns `True`, the window will be expired immediately.
+ Default - `None`.
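
For illustration only (not part of the generated reference): a minimal sketch of how the new `before_update` callback might be wired into a tumbling window, assuming an existing `sdf` carrying dict-valued messages and the callback signature described above (`aggregated`, `value`, `key`, `timestamp`, `headers`).

```python
from datetime import timedelta

# Hypothetical callback: expire the window as soon as a record arrives that is
# flagged as the end of a batch, instead of waiting for the window duration.
def expire_on_flag(aggregated, value, key, timestamp, headers) -> bool:
    return bool(value.get("end_of_batch"))  # assumes dict-valued messages

sdf = (
    sdf.tumbling_window(
        duration_ms=timedelta(minutes=1),
        before_update=expire_on_flag,
    )
    .sum()
    .final()
)
```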
@@ -907,7 +921,7 @@ def tumbling_count_window(
name: Optional[str] = None) -> TumblingCountWindowDefinition
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1171)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1193)
Create a count-based tumbling window transformation on this StreamingDataFrame.
@@ -976,11 +990,13 @@ def hopping_window(
step_ms: Union[int, timedelta],
grace_ms: Union[int, timedelta] = 0,
name: Optional[str] = None,
- on_late: Optional[WindowOnLateCallback] = None
+ on_late: Optional[WindowOnLateCallback] = None,
+ before_update: Optional[WindowBeforeUpdateCallback] = None,
+ after_update: Optional[WindowAfterUpdateCallback] = None
) -> HoppingTimeWindowDefinition
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1221)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1243)
Create a time-based hopping window transformation on this StreamingDataFrame.
@@ -1060,6 +1076,18 @@ sdf = (
If the callback returns `True`, the message about a late record will be logged
(default behavior).
Otherwise, no message will be logged.
+ - `before_update`: an optional callback to trigger early window expiration
+ before the window is updated.
+ The callback receives `aggregated` (current aggregated value or default/None),
+ `value`, `key`, `timestamp`, and `headers`.
+ If it returns `True`, the window will be expired immediately.
+ Default - `None`.
+ - `after_update`: an optional callback to trigger early window expiration
+ after the window is updated.
+ The callback receives `aggregated` (updated aggregated value), `value`, `key`,
+ `timestamp`, and `headers`.
+ If it returns `True`, the window will be expired immediately.
+ Default - `None`.
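
Likewise, a hypothetical sketch for hopping windows, using `after_update` to emit a window early once the running `sum()` aggregate passes an arbitrary threshold (the threshold is illustrative, not from the source).

```python
from datetime import timedelta

# Hypothetical callback: close each hopping window as soon as its running
# sum exceeds 100, rather than waiting for the window to close by time.
def emit_when_large(aggregated, value, key, timestamp, headers) -> bool:
    return aggregated is not None and aggregated > 100

sdf = (
    sdf.hopping_window(
        duration_ms=timedelta(minutes=10),
        step_ms=timedelta(minutes=1),
        after_update=emit_when_large,
    )
    .sum()
    .final()
)
```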
@@ -1083,7 +1111,7 @@ def hopping_count_window(
name: Optional[str] = None) -> HoppingCountWindowDefinition
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1324)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1364)
Create a count-based hopping window transformation on this StreamingDataFrame.
@@ -1161,7 +1189,7 @@ def sliding_window(
) -> SlidingTimeWindowDefinition
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1381)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1421)
Create a time-based sliding window transformation on this StreamingDataFrame.
@@ -1259,7 +1287,7 @@ def sliding_count_window(
name: Optional[str] = None) -> SlidingCountWindowDefinition
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1476)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1516)
Create a count-based sliding window transformation on this StreamingDataFrame.
@@ -1329,7 +1357,7 @@ sdf = (
def fill(*columns: str, **mapping: Any) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1529)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1569)
Fill missing values in the message value with a constant value.
@@ -1386,7 +1414,7 @@ def drop(columns: Union[str, List[str]],
errors: Literal["ignore", "raise"] = "raise") -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1581)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1621)
Drop column(s) from the message value (value must support `del`, like a dict).
@@ -1430,7 +1458,7 @@ a new StreamingDataFrame instance
def sink(sink: BaseSink)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1625)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1665)
Sink the processed data to the specified destination.
@@ -1458,7 +1486,7 @@ operations, but branches can still be generated from its originating SDF.
def concat(other: "StreamingDataFrame") -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1663)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1703)
Concatenate two StreamingDataFrames together and return a new one.
@@ -1499,7 +1527,7 @@ def join_asof(right: "StreamingDataFrame",
name: Optional[str] = None) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1699)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1739)
Join the left dataframe with the records of the right dataframe with
@@ -1582,7 +1610,7 @@ def join_interval(
forward_ms: Union[int, timedelta] = 0) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1775)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1815)
Join the left dataframe with records from the right dataframe that fall within
@@ -1685,7 +1713,7 @@ def join_lookup(
) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1880)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1920)
Note: This is an experimental feature, and its API is likely to change in the future.
@@ -1746,7 +1774,7 @@ sdf = sdf.join_lookup(lookup, fields)
def register_store(store_type: Optional[StoreTypes] = None) -> None
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1969)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L2009)
Register the default store for the current stream_id in StateStoreManager.
diff --git a/docs/api-reference/quixstreams.md b/docs/api-reference/quixstreams.md
index 086a017fa..e4702d524 100644
--- a/docs/api-reference/quixstreams.md
+++ b/docs/api-reference/quixstreams.md
@@ -55,7 +55,7 @@ True if logging config has been updated, otherwise False.
def strip_workspace_id_prefix(workspace_id: str, s: str) -> str
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L46)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L44)
Remove the workspace ID from a given string if it starts with it.
@@ -78,7 +78,7 @@ the string with workspace_id prefix removed
def prepend_workspace_id(workspace_id: str, s: str) -> str
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L59)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L57)
Add the workspace ID as a prefix to a given string if it does not have it.
@@ -102,7 +102,7 @@ the string with workspace_id prepended
class QuixApplicationConfig()
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L73)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L71)
A convenience container class for Quix Application configs.
@@ -114,7 +114,7 @@ A convenience container class for Quix Application configs.
class QuixKafkaConfigsBuilder()
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L83)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L81)
Retrieves all the necessary information from the Quix API and builds all the
objects required to connect a confluent-kafka client to the Quix Platform.
@@ -139,7 +139,7 @@ def __init__(quix_portal_api_service: QuixPortalApiService,
topic_create_timeout: float = 60)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L98)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L96)
**Arguments**:
@@ -161,7 +161,7 @@ def from_credentials(
topic_create_timeout: float = 60) -> "QuixKafkaConfigsBuilder"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L129)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L127)
Initialize class using the quix_sdk_token and quix_portal_api params.
@@ -174,7 +174,7 @@ Initialize class using the quix_sdk_token and quix_portal_api params.
def convert_topic_response(cls, api_response: dict) -> Topic
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L191)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L189)
Converts a GET or POST ("create") topic API response to a Topic object
@@ -194,7 +194,7 @@ a corresponding Topic object
def strip_workspace_id_prefix(s: str) -> str
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L224)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L222)
Remove the workspace ID from a given string if it starts with it.
@@ -216,7 +216,7 @@ the string with workspace_id prefix removed
def prepend_workspace_id(s: str) -> str
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L235)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L233)
Add the workspace ID as a prefix to a given string if it does not have it.
@@ -239,7 +239,7 @@ def search_for_workspace(workspace_name_or_id: Optional[str] = None,
timeout: Optional[float] = None) -> Optional[dict]
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L246)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L244)
Search for a workspace given an expected workspace name or id.
@@ -261,7 +261,7 @@ def get_workspace_info(known_workspace_topic: Optional[str] = None,
timeout: Optional[float] = None) -> dict
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L289)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L287)
Queries for workspace data from the Quix API, regardless of instance cache,
@@ -283,7 +283,7 @@ def search_workspace_for_topic(
timeout: Optional[float] = None) -> Optional[str]
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L318)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L316)
Search through all the topics in the given workspace id to see if there is a
@@ -309,7 +309,7 @@ def search_for_topic_workspace(
timeout: Optional[float] = None) -> Optional[dict]
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L341)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L339)
Find what workspace a topic belongs to.
@@ -333,7 +333,7 @@ workspace data dict if topic search success, else None
def create_topic(topic: Topic, timeout: Optional[float] = None) -> dict
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L372)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L370)
The actual API call to create the topic.
@@ -350,7 +350,7 @@ The actual API call to create the topic.
def get_or_create_topic(topic: Topic, timeout: Optional[float] = None) -> dict
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L408)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L406)
Get or create topics in a Quix cluster as part of initializing the Topic
@@ -372,7 +372,7 @@ def wait_for_topic_ready_statuses(topics: List[Topic],
finalize_timeout: Optional[float] = None)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L436)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L434)
After the broker acknowledges topics for creation, they will be in a
@@ -395,7 +395,7 @@ marked as "Ready" (and thus ready to produce to/consume from).
def get_topic(topic: Topic, timeout: Optional[float] = None) -> dict
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L479)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L477)
return the topic ID (the actual cluster topic name) if it exists, else raise
@@ -420,7 +420,7 @@ response dict of the topic info if topic found, else None
def get_application_config(consumer_group_id: str) -> QuixApplicationConfig
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L511)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/config.py#L509)
Get all the necessary attributes for an Application to run on Quix Cloud.
@@ -634,6 +634,23 @@ If it doesn't match, the warning will be logged.
## quixstreams.platforms.quix.api
+
+
+#### retry\_on\_connection\_error
+
+```python
+def retry_on_connection_error(max_retries: int = 5, base_delay: float = 1.0)
+```
+
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/api.py#L23)
+
+Retry decorator for httpx connection errors with exponential backoff.
+
+**Arguments**:
+
+- `max_retries`: Maximum number of retry attempts (default: 5)
+- `base_delay`: Base delay in seconds for exponential backoff (default: 1.0)
+
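
As an aside, a minimal sketch of how a decorator like this is typically applied; the decorated function below is hypothetical, and only `retry_on_connection_error` and its parameters come from the entry above (assuming it is importable from the module path shown).

```python
from quixstreams.platforms.quix.api import retry_on_connection_error

@retry_on_connection_error(max_retries=3, base_delay=0.5)
def fetch_portal_data():
    # Hypothetical httpx call that may raise transient connection errors;
    # it would be retried up to 3 times with exponential backoff from 0.5s.
    ...
```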
### QuixPortalApiService
@@ -642,7 +659,7 @@ If it doesn't match, the warning will be logged.
class QuixPortalApiService()
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/api.py#L18)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/api.py#L55)
A light wrapper around the Quix Portal Api. If used in the Quix Platform, it will
use that workspace's auth token and portal endpoint, else you must provide it.
@@ -660,11 +677,12 @@ See the swagger documentation for more info about the endpoints.
#### QuixPortalApiService.get\_workspace\_certificate
```python
+@retry_on_connection_error()
def get_workspace_certificate(workspace_id: Optional[str] = None,
timeout: float = 30) -> Optional[bytes]
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/api.py#L76)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/platforms/quix/api.py#L128)
Get a workspace TLS certificate if available.
@@ -938,7 +956,7 @@ messages in the timestamp-aligned way for the correct processing.
class StreamingDataFrame()
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L90)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L94)
`StreamingDataFrame` is the main object you will use for ETL work.
@@ -993,7 +1011,7 @@ sdf = sdf.to_topic(topic_obj)
def stream_id() -> str
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L175)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L179)
An identifier of the data stream this StreamingDataFrame
manipulates in the application.
@@ -1025,7 +1043,7 @@ def apply(func: Union[
metadata: bool = False) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L234)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L238)
Apply a function to transform the value and return a new value.
@@ -1077,7 +1095,7 @@ def update(func: Union[
metadata: bool = False) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L338)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L342)
Apply a function to mutate value in-place or to perform a side effect
@@ -1137,7 +1155,7 @@ def filter(func: Union[
metadata: bool = False) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L441)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L445)
Filter value using provided function.
@@ -1183,7 +1201,7 @@ def group_by(key: Union[str, Callable[[Any], Any]],
key_serializer: SerializerType = "json") -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L526)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L530)
"Groups" messages by re-keying them via the provided group_by operation
@@ -1240,7 +1258,7 @@ a clone with this operation added (assign to keep its effect).
def contains(keys: Union[str, list[str]]) -> StreamingSeries
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L640)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L644)
Check if keys are present in the Row value.
@@ -1274,7 +1292,7 @@ def to_topic(
key: Optional[Callable[[Any], Any]] = None) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L684)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L688)
Produce current value to a topic. You can optionally specify a new key.
@@ -1337,7 +1355,7 @@ def set_timestamp(
func: Callable[[Any, Any, int, Any], int]) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L753)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L757)
Set a new timestamp based on the current message value and its metadata.
@@ -1382,7 +1400,7 @@ def set_headers(
) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L796)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L800)
Set new message headers based on the current message value and metadata.
@@ -1423,7 +1441,7 @@ a new StreamingDataFrame instance
def print(pretty: bool = True, metadata: bool = False) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L847)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L851)
Print out the current message value (and optionally, the message metadata) to
@@ -1478,7 +1496,7 @@ def print_table(
int]] = None) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L893)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L897)
Print a table with the most recent records.
@@ -1565,7 +1583,7 @@ automatically based on content. Example: {"name": 20, "id": 10}
def compose(sink: Optional[VoidExecutor] = None) -> dict[str, VoidExecutor]
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1009)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1013)
Compose all functions of this StreamingDataFrame into one big closure.
@@ -1611,7 +1629,7 @@ def test(value: Any,
topic: Optional[Topic] = None) -> List[Any]
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1043)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1047)
A shorthand to test `StreamingDataFrame` with provided value
@@ -1641,11 +1659,13 @@ def tumbling_window(
duration_ms: Union[int, timedelta],
grace_ms: Union[int, timedelta] = 0,
name: Optional[str] = None,
- on_late: Optional[WindowOnLateCallback] = None
+ on_late: Optional[WindowOnLateCallback] = None,
+ before_update: Optional[WindowBeforeUpdateCallback] = None,
+ after_update: Optional[WindowAfterUpdateCallback] = None
) -> TumblingTimeWindowDefinition
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1082)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1086)
Create a time-based tumbling window transformation on this StreamingDataFrame.
@@ -1710,6 +1730,18 @@ to configure the logging of such events.
If the callback returns `True`, the message about a late record will be logged
(default behavior).
Otherwise, no message will be logged.
+- `before_update`: an optional callback to trigger early window expiration
+before the window is updated.
+The callback receives `aggregated` (current aggregated value or default/None),
+`value`, `key`, `timestamp`, and `headers`.
+If it returns `True`, the window will be expired immediately.
+Default - `None`.
+- `after_update`: an optional callback to trigger early window expiration
+after the window is updated.
+The callback receives `aggregated` (updated aggregated value), `value`, `key`,
+`timestamp`, and `headers`.
+If it returns `True`, the window will be expired immediately.
+Default - `None`.
**Returns**:
@@ -1728,7 +1760,7 @@ def tumbling_count_window(
name: Optional[str] = None) -> TumblingCountWindowDefinition
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1171)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1193)
Create a count-based tumbling window transformation on this StreamingDataFrame.
@@ -1787,11 +1819,13 @@ def hopping_window(
step_ms: Union[int, timedelta],
grace_ms: Union[int, timedelta] = 0,
name: Optional[str] = None,
- on_late: Optional[WindowOnLateCallback] = None
+ on_late: Optional[WindowOnLateCallback] = None,
+ before_update: Optional[WindowBeforeUpdateCallback] = None,
+ after_update: Optional[WindowAfterUpdateCallback] = None
) -> HoppingTimeWindowDefinition
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1221)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1243)
Create a time-based hopping window transformation on this StreamingDataFrame.
@@ -1866,6 +1900,18 @@ to configure the logging of such events.
If the callback returns `True`, the message about a late record will be logged
(default behavior).
Otherwise, no message will be logged.
+- `before_update`: an optional callback to trigger early window expiration
+before the window is updated.
+The callback receives `aggregated` (current aggregated value or default/None),
+`value`, `key`, `timestamp`, and `headers`.
+If it returns `True`, the window will be expired immediately.
+Default - `None`.
+- `after_update`: an optional callback to trigger early window expiration
+after the window is updated.
+The callback receives `aggregated` (updated aggregated value), `value`, `key`,
+`timestamp`, and `headers`.
+If it returns `True`, the window will be expired immediately.
+Default - `None`.
**Returns**:
@@ -1885,7 +1931,7 @@ def hopping_count_window(
name: Optional[str] = None) -> HoppingCountWindowDefinition
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1324)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1364)
Create a count-based hopping window transformation on this StreamingDataFrame.
@@ -1953,7 +1999,7 @@ def sliding_window(
) -> SlidingTimeWindowDefinition
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1381)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1421)
Create a time-based sliding window transformation on this StreamingDataFrame.
@@ -2042,7 +2088,7 @@ def sliding_count_window(
name: Optional[str] = None) -> SlidingCountWindowDefinition
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1476)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1516)
Create a count-based sliding window transformation on this StreamingDataFrame.
@@ -2102,7 +2148,7 @@ like `sum`, `count`, etc. applied to the StreamingDataFrame.
def fill(*columns: str, **mapping: Any) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1529)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1569)
Fill missing values in the message value with a constant value.
@@ -2153,7 +2199,7 @@ def drop(columns: Union[str, List[str]],
errors: Literal["ignore", "raise"] = "raise") -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1581)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1621)
Drop column(s) from the message value (value must support `del`, like a dict).
@@ -2189,7 +2235,7 @@ a new StreamingDataFrame instance
def sink(sink: BaseSink)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1625)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1665)
Sink the processed data to the specified destination.
@@ -2215,7 +2261,7 @@ operations, but branches can still be generated from its originating SDF.
def concat(other: "StreamingDataFrame") -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1663)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1703)
Concatenate two StreamingDataFrames together and return a new one.
@@ -2250,7 +2296,7 @@ def join_asof(right: "StreamingDataFrame",
name: Optional[str] = None) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1699)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1739)
Join the left dataframe with the records of the right dataframe with
@@ -2328,7 +2374,7 @@ def join_interval(
forward_ms: Union[int, timedelta] = 0) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1775)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1815)
Join the left dataframe with records from the right dataframe that fall within
@@ -2424,7 +2470,7 @@ def join_lookup(
) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1880)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1920)
Note: This is an experimental feature, and its API is likely to change in the future.
@@ -2477,7 +2523,7 @@ sdf = sdf.join_lookup(lookup, fields)
def register_store(store_type: Optional[StoreTypes] = None) -> None
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L1969)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/dataframe.py#L2009)
Register the default store for the current stream_id in StateStoreManager.
@@ -3965,7 +4011,7 @@ Optional[ConfigurationVersion]: The valid version, or None if not found.
class Lookup(BaseLookup[BaseField])
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/joins/lookups/quix_configuration_service/lookup.py#L39)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/joins/lookups/quix_configuration_service/lookup.py#L40)
Lookup join implementation for enriching streaming data with configuration data from a Kafka topic.
@@ -4001,7 +4047,7 @@ def json_field(jsonpath: str,
default: Any = RAISE_ON_MISSING) -> JSONField
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/joins/lookups/quix_configuration_service/lookup.py#L136)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/joins/lookups/quix_configuration_service/lookup.py#L135)
Create a JSON field for extracting values from configuration content using JSONPath.
@@ -4024,7 +4070,7 @@ A JSONField instance.
def bytes_field(type: str, default: Any = RAISE_ON_MISSING) -> BytesField
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/joins/lookups/quix_configuration_service/lookup.py#L160)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/joins/lookups/quix_configuration_service/lookup.py#L159)
Create a bytes field for extracting binary content from configuration.
@@ -4045,7 +4091,7 @@ A BytesField instance.
def cache_info() -> CacheInfo
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/joins/lookups/quix_configuration_service/lookup.py#L380)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/joins/lookups/quix_configuration_service/lookup.py#L393)
Get information about the cache.
@@ -4065,7 +4111,7 @@ def join(fields: Mapping[str, BaseField], on: str, value: dict[str, Any],
key: Any, timestamp: int, headers: HeadersMapping) -> None
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/joins/lookups/quix_configuration_service/lookup.py#L392)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/joins/lookups/quix_configuration_service/lookup.py#L405)
Enrich the message with configuration data from the Quix Configuration Service.
@@ -4330,7 +4376,7 @@ class SlidingWindow(TimeWindow)
```python
def process_window(
- value: Any, key: Any, timestamp_ms: int,
+ value: Any, key: Any, timestamp_ms: int, headers: Any,
transaction: WindowedPartitionTransaction
) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]
```
@@ -4390,7 +4436,7 @@ aggregation and combine it with the incoming message.
class WindowDefinition(abc.ABC, Generic[WindowT])
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/definitions.py#L51)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/definitions.py#L53)
@@ -4400,7 +4446,7 @@ class WindowDefinition(abc.ABC, Generic[WindowT])
def sum() -> WindowT
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/definitions.py#L72)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/definitions.py#L78)
Configure the window to aggregate data by summing up values within
@@ -4418,7 +4464,7 @@ an instance of `FixedTimeWindow` configured to perform sum aggregation.
def count() -> WindowT
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/definitions.py#L85)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/definitions.py#L91)
Configure the window to aggregate data by counting the number of values
@@ -4436,7 +4482,7 @@ an instance of `FixedTimeWindow` configured to perform record count.
def mean() -> WindowT
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/definitions.py#L98)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/definitions.py#L104)
Configure the window to aggregate data by calculating the mean of the values
@@ -4456,7 +4502,7 @@ def reduce(reducer: Callable[[Any, Any], Any],
initializer: Callable[[Any], Any]) -> WindowT
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/definitions.py#L112)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/definitions.py#L118)
Configure the window to perform a custom aggregation using `reducer`
@@ -4505,7 +4551,7 @@ A window configured to perform custom reduce aggregation on the data.
def max() -> WindowT
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/definitions.py#L156)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/definitions.py#L162)
Configure a window to aggregate the maximum value within each window period.
@@ -4522,7 +4568,7 @@ value within each window period.
def min() -> WindowT
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/definitions.py#L169)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/definitions.py#L175)
Configure a window to aggregate the minimum value within each window period.
@@ -4539,7 +4585,7 @@ value within each window period.
def collect() -> WindowT
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/definitions.py#L182)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/definitions.py#L188)
Configure the window to collect all values within each window period into a
@@ -4577,7 +4623,7 @@ within each window period.
class TimeWindow(Window)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/time_based.py#L40)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/time_based.py#L43)
@@ -4589,7 +4635,7 @@ def final(
) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/time_based.py#L62)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/time_based.py#L69)
Apply the window aggregation and return results only when the windows are
@@ -4629,7 +4675,7 @@ def current(
) -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/time_based.py#L94)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/time_based.py#L101)
Apply the window transformation to the StreamingDataFrame to return results
@@ -4677,7 +4723,7 @@ class CountWindow(Window)
```python
def process_window(
- value: Any, key: Any, timestamp_ms: int,
+ value: Any, key: Any, timestamp_ms: int, headers: Any,
transaction: WindowedPartitionTransaction[str, CountWindowsData]
) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]
```
@@ -4715,7 +4761,7 @@ optimisation. Instead the msg id reset to 0 on every new window.
class Window(abc.ABC)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/base.py#L46)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/base.py#L48)
@@ -4725,7 +4771,7 @@ class Window(abc.ABC)
def final() -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/base.py#L102)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/base.py#L105)
Apply the window aggregation and return results only when the windows are
closed.
@@ -4756,7 +4802,7 @@ can remain unprocessed until the message the same key is received.
def current() -> "StreamingDataFrame"
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/base.py#L148)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/base.py#L152)
Apply the window transformation to the StreamingDataFrame to return results
for each updated window.
@@ -4781,7 +4827,7 @@ regardless of whether the window is closed or not.
class SingleAggregationWindowMixin()
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/base.py#L223)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/base.py#L231)
DEPRECATED: Use MultiAggregationWindowMixin instead.
@@ -4798,7 +4844,7 @@ def get_window_ranges(timestamp_ms: int,
step_ms: Optional[int] = None) -> Deque[tuple[int, int]]
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/base.py#L443)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/dataframe/windows/base.py#L451)
Get a list of window ranges for the given timestamp.
@@ -5934,7 +5980,7 @@ list_sink[0] # 1
class InfluxDB3Sink(BatchingSink)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/core/influxdb3.py#L54)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/core/influxdb3.py#L53)
@@ -5963,7 +6009,7 @@ def __init__(token: str,
ClientConnectFailureCallback] = None)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/core/influxdb3.py#L62)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/core/influxdb3.py#L61)
A connector to sink processed data to InfluxDB v3.
@@ -6243,7 +6289,7 @@ Callback must resolve (or propagate/re-raise) the Exception.
class ParquetFormat(Format)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/formats/parquet.py#L16)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/formats/parquet.py#L13)
Serializes batches of messages into Parquet format.
@@ -6262,7 +6308,7 @@ def __init__(file_extension: str = ".parquet",
compression: Compression = "snappy") -> None
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/formats/parquet.py#L29)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/formats/parquet.py#L26)
Initializes the ParquetFormat.
@@ -6283,7 +6329,7 @@ or "zstd". Defaults to "snappy".
def file_extension() -> str
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/formats/parquet.py#L47)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/formats/parquet.py#L63)
Returns the file extension used for output files.
@@ -6299,7 +6345,7 @@ The file extension as a string.
def serialize(batch: SinkBatch) -> bytes
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/formats/parquet.py#L55)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/formats/parquet.py#L71)
Serializes a `SinkBatch` into bytes in Parquet format.
@@ -6525,170 +6571,71 @@ Serializes a batch of messages into bytes.
The serialized batch as bytes.
-
+
-## quixstreams.sinks.community.file.sink
+## quixstreams.sinks.community.file.local
-
+
-### FileSink
+### AppendNotSupported
```python
-class FileSink(BatchingSink)
+class AppendNotSupported(Exception)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/sink.py#L17)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/local.py#L15)
-A sink that writes data batches to files using configurable formats and
-destinations.
+Raised when append=True but the specified format does not support it.
-The sink groups messages by their topic and partition, ensuring data from the
-same source is stored together. Each batch is serialized using the specified
-format (e.g., JSON, Parquet) before being written to the configured
-destination.
+
-The destination determines the storage location and write behavior. By default,
-it uses LocalDestination for writing to the local filesystem, but can be
-configured to use other storage backends (e.g., cloud storage).
-
-
-
-#### FileSink.\_\_init\_\_
+### LocalFileSink
```python
-def __init__(
- directory: str = "",
- format: Union[FormatName, Format] = "json",
- destination: Optional[Destination] = None,
- on_client_connect_success: Optional[ClientConnectSuccessCallback] = None,
- on_client_connect_failure: Optional[ClientConnectFailureCallback] = None
-) -> None
+class LocalFileSink(FileSink)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/sink.py#L31)
-
-Initialize the FileSink with the specified configuration.
-
-**Arguments**:
-
-- `directory`: Base directory path for storing files. Defaults to
-current directory.
-- `format`: Data serialization format, either as a string
-("json", "parquet") or a Format instance.
-- `destination`: Storage destination handler. Defaults to
-LocalDestination if not specified.
-- `on_client_connect_success`: An optional callback made after successful
-client authentication, primarily for additional logging.
-- `on_client_connect_failure`: An optional callback made after failed
-client authentication (which should raise an Exception).
-Callback should accept the raised Exception as an argument.
-Callback must resolve (or propagate/re-raise) the Exception.
-
-
-
-#### FileSink.write
-
-```python
-def write(batch: SinkBatch) -> None
-```
-
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/sink.py#L67)
-
-Write a batch of data using the configured format and destination.
-
-The method performs the following steps:
-1. Serializes the batch data using the configured format
-2. Writes the serialized data to the destination
-3. Handles any write failures by raising a backpressure error
-
-**Arguments**:
-
-- `batch`: The batch of data to write.
-
-**Raises**:
-
-- `SinkBackpressureError`: If the write operation fails, indicating
-that the sink needs backpressure with a 5-second retry delay.
-
-
-
-## quixstreams.sinks.community.file.destinations.local
-
-
-
-### LocalDestination
-
-```python
-class LocalDestination(Destination)
-```
-
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/local.py#L15)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/local.py#L19)
A destination that writes data to the local filesystem.
Handles writing data to local files with support for both creating new files
and appending to existing ones.
-
+
-#### LocalDestination.\_\_init\_\_
+#### LocalFileSink.\_\_init\_\_
```python
-def __init__(append: bool = False) -> None
+def __init__(append: bool = False,
+ directory: str = "",
+ format: Union[FormatName, Format] = "json") -> None
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/local.py#L22)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/local.py#L26)
Initialize the local destination.
**Arguments**:
- `append`: If True, append to existing files instead of creating new
-ones. Defaults to False.
-
-
-
-#### LocalDestination.set\_extension
-
-```python
-def set_extension(format: Format) -> None
-```
-
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/local.py#L35)
-
-Set the file extension and validate append mode compatibility.
-
-**Arguments**:
-
-- `format`: The Format instance that defines the file extension.
+ones by selecting the lexicographically last file in the given directory
+(or creates one).
+Defaults to False.
+- `directory`: Base directory path for storing files. Defaults to
+current directory.
+- `format`: Data serialization format, either as a string
+("json", "parquet") or a Format instance.
**Raises**:
-- `ValueError`: If append mode is enabled but the format doesn't
-support appending.
-
-
-
-#### LocalDestination.write
-
-```python
-def write(data: bytes, batch: SinkBatch) -> None
-```
-
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/local.py#L46)
-
-Write data to a local file.
-
-**Arguments**:
+- `AppendNotSupported`: If append=True but the given format does not support it.
-- `data`: The serialized data to write.
-- `batch`: The batch information containing topic and partition details.
+
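
For orientation, a hypothetical usage sketch of the renamed sink, assuming it is importable from the `quixstreams.sinks.community.file.local` module documented above and attached like any other sink via `StreamingDataFrame.sink()`.

```python
from quixstreams.sinks.community.file.local import LocalFileSink

# Write each batch as JSON files under ./output, grouped by topic and partition.
# With append=True the sink would instead reuse the lexicographically last file
# in the directory (raising AppendNotSupported if the format cannot append).
file_sink = LocalFileSink(directory="output", format="json", append=False)

sdf.sink(file_sink)
```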
-
+## quixstreams.sinks.community.file.azure
-## quixstreams.sinks.community.file.destinations.azure
-
-
+
### AzureContainerNotFoundError
@@ -6696,11 +6643,11 @@ Write data to a local file.
class AzureContainerNotFoundError(Exception)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/azure.py#L24)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/azure.py#L29)
Raised when the specified Azure File container does not exist.
-
+
### AzureContainerAccessDeniedError
@@ -6708,73 +6655,73 @@ Raised when the specified Azure File container does not exist.
class AzureContainerAccessDeniedError(Exception)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/azure.py#L28)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/azure.py#L33)
Raised when the specified Azure File container access is denied.
-
+
-### AzureFileDestination
+### AzureFileSink
```python
-class AzureFileDestination(Destination)
+class AzureFileSink(FileSink)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/azure.py#L32)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/azure.py#L37)
A destination that writes data to Microsoft Azure File.
Handles writing data to Azure containers using the Azure Blob SDK. Credentials can
be provided directly or via environment variables.
-
+
-#### AzureFileDestination.\_\_init\_\_
+#### AzureFileSink.\_\_init\_\_
```python
-def __init__(connection_string: str, container: str) -> None
+def __init__(
+ azure_connection_string: str,
+ azure_container: str,
+ directory: str = "",
+ format: Union[FormatName, Format] = "json",
+ on_client_connect_success: Optional[ClientConnectSuccessCallback] = None,
+ on_client_connect_failure: Optional[ClientConnectFailureCallback] = None
+) -> None
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/azure.py#L40)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/azure.py#L45)
Initialize the Azure File destination.
**Arguments**:
-- `connection_string`: Azure client authentication string.
-- `container`: Azure container name.
+- `azure_connection_string`: Azure client authentication string.
+- `azure_container`: Azure container name.
+- `directory`: Base directory path for storing files. Defaults to
+current directory.
+- `format`: Data serialization format, either as a string
+("json", "parquet") or a Format instance.
+- `on_client_connect_success`: An optional callback made after successful
+client authentication, primarily for additional logging.
+- `on_client_connect_failure`: An optional callback made after failed
+client authentication (which should raise an Exception).
+Callback should accept the raised Exception as an argument.
+Callback must resolve (or propagate/re-raise) the Exception.
**Raises**:
- `AzureContainerNotFoundError`: If the specified container doesn't exist.
- `AzureContainerAccessDeniedError`: If access to the container is denied.
-
-
-#### AzureFileDestination.write
-
-```python
-def write(data: bytes, batch: SinkBatch) -> None
-```
-
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/azure.py#L94)
-
-Write data to Azure.
-
-**Arguments**:
-
-- `data`: The serialized data to write.
-- `batch`: The batch information containing topic and partition details.
-
-
+
-## quixstreams.sinks.community.file.destinations
+## quixstreams.sinks.community.file
-
+
-## quixstreams.sinks.community.file.destinations.s3
+## quixstreams.sinks.community.file.s3
-
+
### S3BucketNotFoundError
@@ -6782,11 +6729,11 @@ Write data to Azure.
class S3BucketNotFoundError(Exception)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/s3.py#L14)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/s3.py#L29)
Raised when the specified S3 bucket does not exist.
-
+
### S3BucketAccessDeniedError
@@ -6794,28 +6741,35 @@ Raised when the specified S3 bucket does not exist.
class S3BucketAccessDeniedError(Exception)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/s3.py#L18)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/s3.py#L33)
Raised when the specified S3 bucket access is denied.
-
+
-### S3Destination
+### S3FileSink
```python
-class S3Destination(Destination)
+class S3FileSink(FileSink)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/s3.py#L22)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/s3.py#L37)
-A destination that writes data to Amazon S3.
+A sink that writes data batches to files in Amazon S3 using configurable
+formats.
-Handles writing data to S3 buckets using the AWS SDK. Credentials can be
-provided directly or via environment variables.
+The sink groups messages by their topic and partition, ensuring data from the
+same source is stored together. Each batch is serialized using the specified
+format (e.g., JSON, Parquet) before being uploaded to the configured S3 bucket.
+
+Credentials can be provided directly or via the standard AWS environment
+variables.
-
+
-#### S3Destination.\_\_init\_\_
+#### S3FileSink.\_\_init\_\_
```python
def __init__(bucket: str,
@@ -6825,10 +6779,16 @@ def __init__(bucket: str,
region_name: Optional[str] = getenv("AWS_REGION",
getenv("AWS_DEFAULT_REGION")),
endpoint_url: Optional[str] = getenv("AWS_ENDPOINT_URL_S3"),
+ directory: str = "",
+ format: Union[FormatName, Format] = "json",
+ on_client_connect_success: Optional[
+ ClientConnectSuccessCallback] = None,
+ on_client_connect_failure: Optional[
+ ClientConnectFailureCallback] = None,
**kwargs) -> None
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/s3.py#L29)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/s3.py#L51)
Initialize the S3 destination.
@@ -6851,117 +6811,90 @@ NOTE: can alternatively set the AWS_ENDPOINT_URL_S3 environment variable
- `S3BucketNotFoundError`: If the specified bucket doesn't exist.
- `S3BucketAccessDeniedError`: If access to the bucket is denied.
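+
+A minimal usage sketch (bucket and topic names are placeholders; credentials are
+assumed to come from the standard AWS environment variables):
+
+```python
+from quixstreams import Application
+from quixstreams.sinks.community.file.s3 import S3FileSink
+
+app = Application(broker_address="localhost:9092", consumer_group="s3-file-sink")
+topic = app.topic("input-topic")
+
+# Credentials, region, and endpoint fall back to AWS_* environment variables
+s3_sink = S3FileSink(
+    bucket="my-bucket",
+    directory="events",
+    format="parquet",
+)
+
+sdf = app.dataframe(topic=topic)
+sdf.sink(s3_sink)
+
+if __name__ == "__main__":
+    app.run()
+```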
-
-
-#### S3Destination.write
-
-```python
-def write(data: bytes, batch: SinkBatch) -> None
-```
-
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/s3.py#L89)
-
-Write data to S3.
-
-**Arguments**:
+
-- `data`: The serialized data to write.
-- `batch`: The batch information containing topic and partition details.
+## quixstreams.sinks.community.file.base
-
+
-## quixstreams.sinks.community.file.destinations.base
-
-
-
-### Destination
+### FileSink
```python
-class Destination(ABC)
+class FileSink(BatchingSink)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/base.py#L16)
-
-Abstract base class for defining where and how data should be stored.
-
-Destinations handle the storage of serialized data, whether that's to local
-disk, cloud storage, or other locations. They manage the physical writing of
-data while maintaining a consistent directory/path structure based on topics
-and partitions.
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/base.py#L24)
-
-
-#### Destination.setup
-
-```python
-@abstractmethod
-def setup()
-```
+A sink that writes data batches to files using configurable formats and
+destinations.
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/base.py#L29)
+The sink groups messages by their topic and partition, ensuring data from the
+same source is stored together. Each batch is serialized using the specified
+format (e.g., JSON, Parquet) before being written to the configured
+destination.
-Authenticate and validate connection here
+The destination determines the storage location and write behavior. By default,
+it uses LocalDestination for writing to the local filesystem, but can be
+configured to use other storage backends (e.g., cloud storage).
-
+
-#### Destination.write
+#### FileSink.\_\_init\_\_
```python
-@abstractmethod
-def write(data: bytes, batch: SinkBatch) -> None
+def __init__(
+ directory: str = "",
+ format: Union[FormatName, Format] = "json",
+ on_client_connect_success: Optional[ClientConnectSuccessCallback] = None,
+ on_client_connect_failure: Optional[ClientConnectFailureCallback] = None
+) -> None
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/base.py#L34)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/base.py#L38)
-Write the serialized data to storage.
+Initialize the FileSink with the specified configuration.
**Arguments**:
-- `data`: The serialized data to write.
-- `batch`: The batch information containing topic, partition and offset
-details.
+- `directory`: Base directory path for storing files. Defaults to
+current directory.
+- `format`: Data serialization format, either as a string
+("json", "parquet") or a Format instance.
+- `on_client_connect_success`: An optional callback made after successful
+client authentication, primarily for additional logging.
+- `on_client_connect_failure`: An optional callback made after failed
+client authentication (which should raise an Exception).
+Callback should accept the raised Exception as an argument.
+Callback must resolve (or propagate/re-raise) the Exception.
-
+
-#### Destination.set\_directory
+#### FileSink.setup
```python
-def set_directory(directory: str) -> None
+@abstractmethod
+def setup()
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/base.py#L43)
-
-Configure the base directory for storing files.
-
-**Arguments**:
-
-- `directory`: The base directory path where files will be stored.
-
-**Raises**:
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/base.py#L76)
-- `ValueError`: If the directory path contains invalid characters.
-Only alphanumeric characters (a-zA-Z0-9), spaces, dots, slashes, and
-underscores are allowed.
+Authenticate and validate connection here
-
+
-#### Destination.set\_extension
+#### FileSink.write
```python
-def set_extension(format: Format) -> None
+def write(batch: SinkBatch) -> None
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/base.py#L64)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/base.py#L90)
-Set the file extension based on the format.
+Write a batch of data using the configured format.
**Arguments**:
-- `format`: The Format instance that defines the file extension.
-
-
-
-## quixstreams.sinks.community.file
+- `batch`: The batch of data to write.
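+
+
+As a brief, hedged illustration of the shared `format` parameter: a `Format`
+instance (such as the `ParquetFormat` documented later in this reference) can be
+passed instead of the `"json"`/`"parquet"` string shortcuts. `LocalFileSink`
+stands in here for any concrete `FileSink` subclass:
+
+```python
+from quixstreams.sinks.community.file.local import LocalFileSink
+from quixstreams.sinks.community.file.formats.parquet import ParquetFormat
+
+# Equivalent to format="parquet", but with an explicit compression setting
+sink = LocalFileSink(
+    directory="output",
+    format=ParquetFormat(compression="zstd"),
+)
+```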
@@ -7774,6 +7707,77 @@ def set_str_rep(cls, rep_function)
Set the string representation for all Points.
+
+
+## quixstreams.sinks.community.mqtt
+
+
+
+### MQTTSink
+
+```python
+class MQTTSink(BaseSink)
+```
+
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/mqtt.py#L35)
+
+A sink that publishes messages to an MQTT broker.
+
+
+
+#### MQTTSink.\_\_init\_\_
+
+```python
+def __init__(client_id: str,
+ server: str,
+ port: int,
+ topic_root: str,
+ username: str = None,
+ password: str = None,
+ version: ProtocolVersion = "3.1.1",
+ tls_enabled: bool = True,
+ key_serializer: Callable[[Any], str] = bytes.decode,
+ value_serializer: Callable[[Any], str] = json.dumps,
+ qos: Literal[0, 1] = 1,
+ mqtt_flush_timeout_seconds: int = 10,
+ retain: Union[bool, Callable[[Any], bool]] = False,
+ properties: Optional[MqttPropertiesHandler] = None,
+ on_client_connect_success: Optional[
+ ClientConnectSuccessCallback] = None,
+ on_client_connect_failure: Optional[
+ ClientConnectFailureCallback] = None)
+```
+
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/mqtt.py#L40)
+
+Initialize the MQTTSink.
+
+**Arguments**:
+
+- `client_id`: MQTT client identifier.
+- `server`: MQTT broker server address.
+- `port`: MQTT broker server port.
+- `topic_root`: Root topic to publish messages to.
+- `username`: Username for MQTT broker authentication. Default = None
+- `password`: Password for MQTT broker authentication. Default = None
+- `version`: MQTT protocol version ("3.1", "3.1.1", or "5"). Defaults to 3.1.1
+- `tls_enabled`: Whether to use TLS encryption. Default = True
+- `key_serializer`: How to serialize the MQTT message key for producing.
+- `value_serializer`: How to serialize the MQTT message value for producing.
+- `qos`: Quality of Service level (0 or 1; 2 not yet supported) Default = 1.
+- `mqtt_flush_timeout_seconds`: how long to wait for publish acknowledgment
+of MQTT messages before failing. Default = 10.
+- `retain`: Retain last message for new subscribers. Default = False.
+Also accepts a callable that uses the current message value as input.
+- `properties`: An optional Properties instance for messages. Default = None.
+Also accepts a callable that uses the current message value as input.
+- `on_client_connect_success`: An optional callback made after successful
+client authentication, primarily for additional logging.
+- `on_client_connect_failure`: An optional callback made after failed
+client authentication (which should raise an Exception).
+Callback should accept the raised Exception as an argument.
+Callback must resolve (or propagate/re-raise) the Exception.
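+
+A minimal usage sketch (broker host, port, and topics are placeholders):
+
+```python
+from quixstreams import Application
+from quixstreams.sinks.community.mqtt import MQTTSink
+
+app = Application(broker_address="localhost:9092", consumer_group="mqtt-sink")
+topic = app.topic("input-topic")
+
+mqtt_sink = MQTTSink(
+    client_id="quix-mqtt-publisher",
+    server="mqtt.example.com",
+    port=8883,
+    topic_root="sensors/processed",
+    qos=1,
+)
+
+sdf = app.dataframe(topic=topic)
+sdf.sink(mqtt_sink)
+
+if __name__ == "__main__":
+    app.run()
+```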
+
## quixstreams.sinks.community.redis
@@ -7970,6 +7974,151 @@ Implements retry logic to handle concurrent write conflicts.
- `batch`: The batch of data to write.
+
+
+## quixstreams.sinks.community.kafka
+
+
+
+### KafkaReplicatorSink
+
+```python
+class KafkaReplicatorSink(BaseSink)
+```
+
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/kafka.py#L22)
+
+A sink that produces data to an external Kafka cluster.
+
+This sink uses the same serialization approach as the Quix Application.
+
+Example Snippet:
+
+```python
+from quixstreams import Application
+from quixstreams.sinks.community.kafka import KafkaReplicatorSink
+
+app = Application(
+ consumer_group="group",
+)
+
+topic = app.topic("input-topic")
+
+# Define the external Kafka cluster configuration
+kafka_sink = KafkaReplicatorSink(
+ broker_address="external-kafka:9092",
+ topic_name="output-topic",
+ value_serializer="json",
+ key_serializer="bytes",
+)
+
+sdf = app.dataframe(topic=topic)
+sdf.sink(kafka_sink)
+
+app.run()
+```
+
+
+
+#### KafkaReplicatorSink.\_\_init\_\_
+
+```python
+def __init__(
+ broker_address: Union[str, ConnectionConfig],
+ topic_name: str,
+ value_serializer: SerializerType = "json",
+ key_serializer: SerializerType = "bytes",
+ producer_extra_config: Optional[dict] = None,
+ flush_timeout: float = 10.0,
+ origin_topic: Optional[Topic] = None,
+ auto_create_sink_topic: bool = True,
+ on_client_connect_success: Optional[ClientConnectSuccessCallback] = None,
+ on_client_connect_failure: Optional[ClientConnectFailureCallback] = None
+) -> None
+```
+
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/kafka.py#L55)
+
+**Arguments**:
+
+- `broker_address`: The connection settings for the external Kafka cluster.
+Accepts string with Kafka broker host and port formatted as `<host>:<port>`,
+or a ConnectionConfig object if authentication is required.
+- `topic_name`: The topic name to produce to on the external Kafka cluster.
+- `value_serializer`: The serializer type for values.
+Default - `json`.
+- `key_serializer`: The serializer type for keys.
+Default - `bytes`.
+- `producer_extra_config`: A dictionary with additional options that
+will be passed to `confluent_kafka.Producer` as is.
+Default - `None`.
+- `flush_timeout`: The time in seconds the producer waits for all messages
+to be delivered during flush.
+Default - 10.0.
+- `origin_topic`: When auto-creating the sink topic, optionally pass the
+source topic to reuse its configuration.
+- `auto_create_sink_topic`: Whether to try to create the sink topic upon startup.
+Default - `True`.
+- `on_client_connect_success`: An optional callback made after successful
+client authentication, primarily for additional logging.
+- `on_client_connect_failure`: An optional callback made after failed
+client authentication (which should raise an Exception).
+Callback should accept the raised Exception as an argument.
+Callback must resolve (or propagate/re-raise) the Exception.
+
+
+
+#### KafkaReplicatorSink.setup
+
+```python
+def setup()
+```
+
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/kafka.py#L111)
+
+Initialize the InternalProducer and Topic for serialization.
+
+
+
+#### KafkaReplicatorSink.add
+
+```python
+def add(value: Any, key: Any, timestamp: int, headers: HeadersTuples,
+ topic: str, partition: int, offset: int) -> None
+```
+
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/kafka.py#L146)
+
+Add a message to be produced to the external Kafka cluster.
+
+This method converts the provided data into a Row object and uses
+the InternalProducer to serialize and produce it.
+
+**Arguments**:
+
+- `value`: The message value.
+- `key`: The message key.
+- `timestamp`: The message timestamp in milliseconds.
+- `headers`: The message headers.
+- `topic`: The source topic name.
+- `partition`: The source partition.
+- `offset`: The source offset.
+
+
+
+#### KafkaReplicatorSink.flush
+
+```python
+def flush() -> None
+```
+
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/kafka.py#L190)
+
+Flush the producer to ensure all messages are delivered.
+
+This method is triggered by the Checkpoint class when it commits.
+If flush fails, the checkpoint will be aborted.
+
## quixstreams.sinks.community.pubsub
@@ -10076,7 +10225,7 @@ Multiple topics are expected for merged and joins streams.
def stream_id_from_topics(topics: Sequence[Topic]) -> str
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/models/topics/manager.py#L350)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/models/topics/manager.py#L352)
Generate a stream_id by combining names of the provided topics.
@@ -12296,6 +12445,24 @@ so consecutive calls may yield different results for the same "latest timestamp"
- `delete`: If True, expired windows will be deleted.
- `collect`: If True, values will be collected into windows.
+
+
+#### WindowedPartitionTransaction.delete\_window
+
+```python
+def delete_window(start_ms: int, end_ms: int, prefix: bytes) -> None
+```
+
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/state/types.py#L394)
+
+Delete a single window defined by start and end timestamps.
+
+**Arguments**:
+
+- `start_ms`: start of the window in milliseconds
+- `end_ms`: end of the window in milliseconds
+- `prefix`: a key prefix
+
#### WindowedPartitionTransaction.delete\_windows
@@ -12305,7 +12472,7 @@ def delete_windows(max_start_time: int, delete_values: bool,
prefix: bytes) -> None
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/state/types.py#L394)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/state/types.py#L404)
Delete windows from RocksDB up to the specified `max_start_time` timestamp.
@@ -12331,7 +12498,7 @@ def get_windows(start_from_ms: int,
backwards: bool = False) -> list[WindowDetail[V]]
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/state/types.py#L411)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/state/types.py#L421)
Get all windows that start between "start_from_ms" and "start_to_ms"
@@ -12356,7 +12523,7 @@ A sorted list of tuples in the format `((start, end), value)`.
def keys(cf_name: str = "default") -> Iterable[bytes]
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/state/types.py#L430)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/state/types.py#L440)
Iterate over all keys in the store.
@@ -12379,7 +12546,7 @@ def flush(processed_offset: Optional[int] = None,
changelog_offset: Optional[int] = None)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/state/types.py#L441)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/state/types.py#L451)
Flush the recent updates to the storage.
@@ -12398,7 +12565,7 @@ optional.
def changelog_topic_partition() -> Optional[Tuple[str, int]]
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/state/types.py#L455)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/state/types.py#L465)
Return the changelog topic-partition for the StorePartition of this transaction.
@@ -12416,7 +12583,7 @@ Returns `None` if changelog_producer is not provided.
class PartitionRecoveryTransaction(Protocol)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/state/types.py#L469)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/state/types.py#L479)
A class for managing recovery for a StorePartition from a changelog message
@@ -12428,7 +12595,7 @@ A class for managing recovery for a StorePartition from a changelog message
def flush()
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/state/types.py#L476)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/state/types.py#L486)
Flush the recovery update to the storage.
@@ -16563,6 +16730,72 @@ the default topic with optionally altered partition count
This module contains Sources developed and maintained by the members of Quix Streams community.
+
+
+## quixstreams.sources.community.mqtt
+
+
+
+### MQTTSource
+
+```python
+class MQTTSource(Source)
+```
+
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sources/community/mqtt.py#L56)
+
+A source that reads messages from an MQTT broker.
+
+
+
+#### MQTTSource.\_\_init\_\_
+
+```python
+def __init__(
+ topic: str,
+ client_id: str,
+ server: str,
+ port: int,
+ username: str = None,
+ password: str = None,
+ version: ProtocolVersion = "3.1.1",
+ tls_enabled: bool = True,
+ key_setter: MqttKeyValueSetter = _default_key_setter,
+ value_setter: MqttKeyValueSetter = _default_value_setter,
+ timestamp_setter: MqttTimestampSetter = _default_timestamp_setter,
+ payload_deserializer: Optional[Callable[[Any],
+ Any]] = _default_deserializer,
+ qos: Literal[0, 1] = 1,
+ on_client_connect_success: Optional[
+ ClientConnectSuccessCallback] = None,
+ on_client_connect_failure: Optional[
+ ClientConnectFailureCallback] = None)
+```
+
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sources/community/mqtt.py#L61)
+
+**Arguments**:
+
+- `topic`: MQTT source topic.
+To consume from a base/prefix, use '#' as a wildcard, e.g. `my-topic-base/#`.
+- `client_id`: MQTT client identifier.
+- `server`: MQTT broker server address.
+- `port`: MQTT broker server port.
+- `username`: Username for MQTT broker authentication. Default = None
+- `password`: Password for MQTT broker authentication. Default = None
+- `version`: MQTT protocol version ("3.1", "3.1.1", or "5"). Defaults to 3.1.1
+- `tls_enabled`: Whether to use TLS encryption. Default = True
+- `payload_deserializer`: An optional payload deserializer.
+Useful when payloads are used by key, value, or timestamp setters.
+Used with default configuration, but can be set to None if not needed.
+- `qos`: Quality of Service level (0 or 1; 2 not yet supported) Default = 1.
+- `on_client_connect_success`: An optional callback made after successful
+client authentication, primarily for additional logging.
+- `on_client_connect_failure`: An optional callback made after failed
+client authentication (which should raise an Exception).
+Callback should accept the raised Exception as an argument.
+Callback must resolve (or propagate/re-raise) the Exception.
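+
+A minimal usage sketch (broker host, port, and topic are placeholders):
+
+```python
+from quixstreams import Application
+from quixstreams.sources.community.mqtt import MQTTSource
+
+app = Application(broker_address="localhost:9092", consumer_group="mqtt-source")
+
+mqtt_source = MQTTSource(
+    topic="sensors/#",  # '#' wildcard consumes everything under "sensors/"
+    client_id="quix-mqtt-consumer",
+    server="mqtt.example.com",
+    port=8883,
+)
+
+sdf = app.dataframe(source=mqtt_source)
+sdf.print()
+
+if __name__ == "__main__":
+    app.run()
+```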
+
## quixstreams.sources.community.pubsub
diff --git a/docs/api-reference/sinks.md b/docs/api-reference/sinks.md
index 0827d05d5..e4fa845cf 100644
--- a/docs/api-reference/sinks.md
+++ b/docs/api-reference/sinks.md
@@ -322,7 +322,7 @@ a timeout specified in `retry_after`, and resume them when it's elapsed.
class InfluxDB3Sink(BatchingSink)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/core/influxdb3.py#L54)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/core/influxdb3.py#L53)
@@ -353,7 +353,7 @@ def __init__(token: str,
ClientConnectFailureCallback] = None)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/core/influxdb3.py#L62)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/core/influxdb3.py#L61)
A connector to sink processed data to InfluxDB v3.
@@ -479,11 +479,11 @@ Default - `str`.
- `value_serializer`: a callable to convert values to strings.
Default - `json.dumps`.
-
+
-## quixstreams.sinks.community.file.sink
+## quixstreams.sinks.community.file.base
-
+
### FileSink
@@ -491,7 +491,7 @@ Default - `json.dumps`.
class FileSink(BatchingSink)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/sink.py#L17)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/base.py#L24)
A sink that writes data batches to files using configurable formats and
destinations.
@@ -505,7 +505,7 @@ The destination determines the storage location and write behavior. By default,
it uses LocalDestination for writing to the local filesystem, but can be
configured to use other storage backends (e.g., cloud storage).
-
+
@@ -515,13 +515,12 @@ configured to use other storage backends (e.g., cloud storage).
def __init__(
directory: str = "",
format: Union[FormatName, Format] = "json",
- destination: Optional[Destination] = None,
on_client_connect_success: Optional[ClientConnectSuccessCallback] = None,
on_client_connect_failure: Optional[ClientConnectFailureCallback] = None
) -> None
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/sink.py#L31)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/base.py#L38)
Initialize the FileSink with the specified configuration.
@@ -533,8 +532,6 @@ Initialize the FileSink with the specified configuration.
current directory.
- `format`: Data serialization format, either as a string
("json", "parquet") or a Format instance.
-- `destination`: Storage destination handler. Defaults to
-LocalDestination if not specified.
- `on_client_connect_success`: An optional callback made after successful
client authentication, primarily for additional logging.
- `on_client_connect_failure`: An optional callback made after failed
@@ -542,7 +539,22 @@ client authentication (which should raise an Exception).
Callback should accept the raised Exception as an argument.
Callback must resolve (or propagate/re-raise) the Exception.
-
+
+
+
+
+#### FileSink.setup
+
+```python
+@abstractmethod
+def setup()
+```
+
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/base.py#L76)
+
+Authenticate and validate connection here
+
+
@@ -552,14 +564,9 @@ Callback must resolve (or propagate/re-raise) the Exception.
def write(batch: SinkBatch) -> None
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/sink.py#L67)
-
-Write a batch of data using the configured format and destination.
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/base.py#L90)
-The method performs the following steps:
-1. Serializes the batch data using the configured format
-2. Writes the serialized data to the destination
-3. Handles any write failures by raising a backpressure error
+Write a batch of data using the configured format.
@@ -567,16 +574,11 @@ The method performs the following steps:
- `batch`: The batch of data to write.
-**Raises**:
-
-- `SinkBackpressureError`: If the write operation fails, indicating
-that the sink needs backpressure with a 5-second retry delay.
+
-
+## quixstreams.sinks.community.file.azure
-## quixstreams.sinks.community.file.destinations.azure
-
-
+
### AzureContainerNotFoundError
@@ -584,11 +586,11 @@ that the sink needs backpressure with a 5-second retry delay.
class AzureContainerNotFoundError(Exception)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/azure.py#L24)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/azure.py#L29)
Raised when the specified Azure File container does not exist.
-
+
### AzureContainerAccessDeniedError
@@ -596,36 +598,43 @@ Raised when the specified Azure File container does not exist.
class AzureContainerAccessDeniedError(Exception)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/azure.py#L28)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/azure.py#L33)
Raised when the specified Azure File container access is denied.
-
+
-### AzureFileDestination
+### AzureFileSink
```python
-class AzureFileDestination(Destination)
+class AzureFileSink(FileSink)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/azure.py#L32)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/azure.py#L37)
A destination that writes data to Microsoft Azure File.
Handles writing data to Azure containers using the Azure Blob SDK. Credentials can
be provided directly or via environment variables.
-
+
-#### AzureFileDestination.\_\_init\_\_
+#### AzureFileSink.\_\_init\_\_
```python
-def __init__(connection_string: str, container: str) -> None
+def __init__(
+ azure_connection_string: str,
+ azure_container: str,
+ directory: str = "",
+ format: Union[FormatName, Format] = "json",
+ on_client_connect_success: Optional[ClientConnectSuccessCallback] = None,
+ on_client_connect_failure: Optional[ClientConnectFailureCallback] = None
+) -> None
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/azure.py#L40)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/azure.py#L45)
Initialize the Azure File destination.
@@ -633,170 +642,68 @@ Initialize the Azure File destination.
***Arguments:***
-- `connection_string`: Azure client authentication string.
-- `container`: Azure container name.
+- `azure_connection_string`: Azure client authentication string.
+- `azure_container`: Azure container name.
+- `directory`: Base directory path for storing files. Defaults to
+current directory.
+- `format`: Data serialization format, either as a string
+("json", "parquet") or a Format instance.
+- `on_client_connect_success`: An optional callback made after successful
+client authentication, primarily for additional logging.
+- `on_client_connect_failure`: An optional callback made after failed
+client authentication (which should raise an Exception).
+Callback should accept the raised Exception as an argument.
+Callback must resolve (or propagate/re-raise) the Exception.
**Raises**:
- `AzureContainerNotFoundError`: If the specified container doesn't exist.
- `AzureContainerAccessDeniedError`: If access to the container is denied.
-
+
-
+## quixstreams.sinks.community.file.local
-#### AzureFileDestination.write
+
-```python
-def write(data: bytes, batch: SinkBatch) -> None
-```
-
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/azure.py#L94)
-
-Write data to Azure.
-
-
-
-***Arguments:***
-
-- `data`: The serialized data to write.
-- `batch`: The batch information containing topic and partition details.
-
-
-
-## quixstreams.sinks.community.file.destinations.base
-
-
-
-### Destination
+### AppendNotSupported
```python
-class Destination(ABC)
+class AppendNotSupported(Exception)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/base.py#L16)
-
-Abstract base class for defining where and how data should be stored.
-
-Destinations handle the storage of serialized data, whether that's to local
-disk, cloud storage, or other locations. They manage the physical writing of
-data while maintaining a consistent directory/path structure based on topics
-and partitions.
-
-
-
-
-
-#### Destination.setup
-
-```python
-@abstractmethod
-def setup()
-```
-
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/base.py#L29)
-
-Authenticate and validate connection here
-
-
-
-
-
-#### Destination.write
-
-```python
-@abstractmethod
-def write(data: bytes, batch: SinkBatch) -> None
-```
-
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/base.py#L34)
-
-Write the serialized data to storage.
-
-
-
-***Arguments:***
-
-- `data`: The serialized data to write.
-- `batch`: The batch information containing topic, partition and offset
-details.
-
-
-
-
-
-#### Destination.set\_directory
-
-```python
-def set_directory(directory: str) -> None
-```
-
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/base.py#L43)
-
-Configure the base directory for storing files.
-
-
-
-***Arguments:***
-
-- `directory`: The base directory path where files will be stored.
-
-**Raises**:
-
-- `ValueError`: If the directory path contains invalid characters.
-Only alphanumeric characters (a-zA-Z0-9), spaces, dots, slashes, and
-underscores are allowed.
-
-
-
-
-
-#### Destination.set\_extension
-
-```python
-def set_extension(format: Format) -> None
-```
-
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/base.py#L64)
-
-Set the file extension based on the format.
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/local.py#L15)
+Raised when `append=True` but the specified format does not support it.
-
-***Arguments:***
-
-- `format`: The Format instance that defines the file extension.
-
-
+
-## quixstreams.sinks.community.file.destinations.local
-
-
-
-### LocalDestination
+### LocalFileSink
```python
-class LocalDestination(Destination)
+class LocalFileSink(FileSink)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/local.py#L15)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/local.py#L19)
A destination that writes data to the local filesystem.
Handles writing data to local files with support for both creating new files
and appending to existing ones.
-
+
-#### LocalDestination.\_\_init\_\_
+#### LocalFileSink.\_\_init\_\_
```python
-def __init__(append: bool = False) -> None
+def __init__(append: bool = False,
+ directory: str = "",
+ format: Union[FormatName, Format] = "json") -> None
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/local.py#L22)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/local.py#L26)
Initialize the local destination.
@@ -805,59 +712,23 @@ Initialize the local destination.
***Arguments:***
- `append`: If True, append to existing files instead of creating new
-ones. Defaults to False.
-
-
-
-
-
-#### LocalDestination.set\_extension
-
-```python
-def set_extension(format: Format) -> None
-```
-
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/local.py#L35)
-
-Set the file extension and validate append mode compatibility.
-
-
-
-***Arguments:***
-
-- `format`: The Format instance that defines the file extension.
+ones, selecting the lexicographically last file in the given directory
+(or creating one if none exists).
+Defaults to False.
+- `directory`: Base directory path for storing files. Defaults to
+current directory.
+- `format`: Data serialization format, either as a string
+("json", "parquet") or a Format instance.
**Raises**:
-- `ValueError`: If append mode is enabled but the format doesn't
-support appending.
+- `AppendNotSupported`: If `append=True` but the given format does not support it.
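+
+A minimal usage sketch (broker address and topic name are placeholders; the JSON
+format is assumed here to support appending, unlike Parquet):
+
+```python
+from quixstreams import Application
+from quixstreams.sinks.community.file.local import LocalFileSink
+
+app = Application(broker_address="localhost:9092", consumer_group="local-file-sink")
+topic = app.topic("input-topic")
+
+# Append batches to the latest file under "output/" instead of creating new files
+local_sink = LocalFileSink(append=True, directory="output", format="json")
+
+sdf = app.dataframe(topic=topic)
+sdf.sink(local_sink)
+
+if __name__ == "__main__":
+    app.run()
+```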
-
+
-
+## quixstreams.sinks.community.file.s3
-#### LocalDestination.write
-
-```python
-def write(data: bytes, batch: SinkBatch) -> None
-```
-
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/local.py#L46)
-
-Write data to a local file.
-
-
-
-***Arguments:***
-
-- `data`: The serialized data to write.
-- `batch`: The batch information containing topic and partition details.
-
-
-
-## quixstreams.sinks.community.file.destinations.s3
-
-
+
### S3BucketNotFoundError
@@ -865,11 +736,11 @@ Write data to a local file.
class S3BucketNotFoundError(Exception)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/s3.py#L14)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/s3.py#L29)
Raised when the specified S3 bucket does not exist.
-
+
### S3BucketAccessDeniedError
@@ -877,30 +748,37 @@ Raised when the specified S3 bucket does not exist.
class S3BucketAccessDeniedError(Exception)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/s3.py#L18)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/s3.py#L33)
Raised when the specified S3 bucket access is denied.
-
+
-### S3Destination
+### S3FileSink
```python
-class S3Destination(Destination)
+class S3FileSink(FileSink)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/s3.py#L22)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/s3.py#L37)
-A destination that writes data to Amazon S3.
+A sink that writes data batches to files in Amazon S3 using configurable
+formats.
+
+The sink groups messages by their topic and partition, ensuring data from the
+same source is stored together. Each batch is serialized using the specified
+format (e.g., JSON, Parquet) before being uploaded to the configured S3 bucket.
-Handles writing data to S3 buckets using the AWS SDK. Credentials can be
-provided directly or via environment variables.
+Credentials can be provided directly or via the standard AWS environment
+variables.
-
+
-#### S3Destination.\_\_init\_\_
+#### S3FileSink.\_\_init\_\_
```python
def __init__(bucket: str,
@@ -910,10 +788,16 @@ def __init__(bucket: str,
region_name: Optional[str] = getenv("AWS_REGION",
getenv("AWS_DEFAULT_REGION")),
endpoint_url: Optional[str] = getenv("AWS_ENDPOINT_URL_S3"),
+ directory: str = "",
+ format: Union[FormatName, Format] = "json",
+ on_client_connect_success: Optional[
+ ClientConnectSuccessCallback] = None,
+ on_client_connect_failure: Optional[
+ ClientConnectFailureCallback] = None,
**kwargs) -> None
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/s3.py#L29)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/s3.py#L51)
Initialize the S3 destination.
@@ -938,27 +822,6 @@ NOTE: can alternatively set the AWS_ENDPOINT_URL_S3 environment variable
- `S3BucketNotFoundError`: If the specified bucket doesn't exist.
- `S3BucketAccessDeniedError`: If access to the bucket is denied.
-
-
-
-
-#### S3Destination.write
-
-```python
-def write(data: bytes, batch: SinkBatch) -> None
-```
-
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/destinations/s3.py#L89)
-
-Write data to S3.
-
-
-
-***Arguments:***
-
-- `data`: The serialized data to write.
-- `batch`: The batch information containing topic and partition details.
-
## quixstreams.sinks.community.file.formats.base
@@ -1165,7 +1028,7 @@ compressed with gzip.
class ParquetFormat(Format)
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/formats/parquet.py#L16)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/formats/parquet.py#L13)
Serializes batches of messages into Parquet format.
@@ -1186,7 +1049,7 @@ def __init__(file_extension: str = ".parquet",
compression: Compression = "snappy") -> None
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/formats/parquet.py#L29)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/formats/parquet.py#L26)
Initializes the ParquetFormat.
@@ -1211,7 +1074,7 @@ or "zstd". Defaults to "snappy".
def file_extension() -> str
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/formats/parquet.py#L47)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/formats/parquet.py#L63)
Returns the file extension used for output files.
@@ -1231,7 +1094,7 @@ The file extension as a string.
def serialize(batch: SinkBatch) -> bytes
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/formats/parquet.py#L55)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/file/formats/parquet.py#L71)
Serializes a `SinkBatch` into bytes in Parquet format.
@@ -1758,6 +1621,82 @@ Note: Transactions could be an option here, but then each record requires a
network call, and the transaction has size limits...so `bulk_write` is used
instead, with the downside that duplicate writes may occur if errors arise.
+
+
+## quixstreams.sinks.community.mqtt
+
+
+
+### MQTTSink
+
+```python
+class MQTTSink(BaseSink)
+```
+
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/mqtt.py#L35)
+
+A sink that publishes messages to an MQTT broker.
+
+
+
+
+
+#### MQTTSink.\_\_init\_\_
+
+```python
+def __init__(client_id: str,
+ server: str,
+ port: int,
+ topic_root: str,
+ username: str = None,
+ password: str = None,
+ version: ProtocolVersion = "3.1.1",
+ tls_enabled: bool = True,
+ key_serializer: Callable[[Any], str] = bytes.decode,
+ value_serializer: Callable[[Any], str] = json.dumps,
+ qos: Literal[0, 1] = 1,
+ mqtt_flush_timeout_seconds: int = 10,
+ retain: Union[bool, Callable[[Any], bool]] = False,
+ properties: Optional[MqttPropertiesHandler] = None,
+ on_client_connect_success: Optional[
+ ClientConnectSuccessCallback] = None,
+ on_client_connect_failure: Optional[
+ ClientConnectFailureCallback] = None)
+```
+
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sinks/community/mqtt.py#L40)
+
+Initialize the MQTTSink.
+
+
+
+***Arguments:***
+
+- `client_id`: MQTT client identifier.
+- `server`: MQTT broker server address.
+- `port`: MQTT broker server port.
+- `topic_root`: Root topic to publish messages to.
+- `username`: Username for MQTT broker authentication. Default = None
+- `password`: Password for MQTT broker authentication. Default = None
+- `version`: MQTT protocol version ("3.1", "3.1.1", or "5"). Defaults to 3.1.1
+- `tls_enabled`: Whether to use TLS encryption. Default = True
+- `key_serializer`: How to serialize the MQTT message key for producing.
+- `value_serializer`: How to serialize the MQTT message value for producing.
+- `qos`: Quality of Service level (0 or 1; 2 not yet supported) Default = 1.
+- `mqtt_flush_timeout_seconds`: how long to wait for publish acknowledgment
+of MQTT messages before failing. Default = 10.
+- `retain`: Retain last message for new subscribers. Default = False.
+Also accepts a callable that uses the current message value as input.
+- `properties`: An optional Properties instance for messages. Default = None.
+Also accepts a callable that uses the current message value as input.
+- `on_client_connect_success`: An optional callback made after successful
+client authentication, primarily for additional logging.
+- `on_client_connect_failure`: An optional callback made after failed
+client authentication (which should raise an Exception).
+Callback should accept the raised Exception as an argument.
+Callback must resolve (or propagate/re-raise) the Exception.
+
+
## quixstreams.sinks.community.neo4j
diff --git a/docs/api-reference/sources.md b/docs/api-reference/sources.md
index 77df0d13b..44960c1bf 100644
--- a/docs/api-reference/sources.md
+++ b/docs/api-reference/sources.md
@@ -1476,6 +1476,76 @@ client authentication (which should raise an Exception).
Callback should accept the raised Exception as an argument.
Callback must resolve (or propagate/re-raise) the Exception.
+
+
+## quixstreams.sources.community.mqtt
+
+
+
+### MQTTSource
+
+```python
+class MQTTSource(Source)
+```
+
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sources/community/mqtt.py#L56)
+
+A source that reads messages from an MQTT broker.
+
+
+
+
+
+#### MQTTSource.\_\_init\_\_
+
+```python
+def __init__(
+ topic: str,
+ client_id: str,
+ server: str,
+ port: int,
+ username: str = None,
+ password: str = None,
+ version: ProtocolVersion = "3.1.1",
+ tls_enabled: bool = True,
+ key_setter: MqttKeyValueSetter = _default_key_setter,
+ value_setter: MqttKeyValueSetter = _default_value_setter,
+ timestamp_setter: MqttTimestampSetter = _default_timestamp_setter,
+ payload_deserializer: Optional[Callable[[Any],
+ Any]] = _default_deserializer,
+ qos: Literal[0, 1] = 1,
+ on_client_connect_success: Optional[
+ ClientConnectSuccessCallback] = None,
+ on_client_connect_failure: Optional[
+ ClientConnectFailureCallback] = None)
+```
+
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/sources/community/mqtt.py#L61)
+
+
+
+***Arguments:***
+
+- `topic`: MQTT source topic.
+To consume from a base/prefix, use '#' as a wildcard, e.g. `my-topic-base/#`.
+- `client_id`: MQTT client identifier.
+- `server`: MQTT broker server address.
+- `port`: MQTT broker server port.
+- `username`: Username for MQTT broker authentication. Default = None
+- `password`: Password for MQTT broker authentication. Default = None
+- `version`: MQTT protocol version ("3.1", "3.1.1", or "5"). Defaults to 3.1.1
+- `tls_enabled`: Whether to use TLS encryption. Default = True
+- `payload_deserializer`: An optional payload deserializer.
+Useful when payloads are used by key, value, or timestamp setters.
+Used with default configuration, but can be set to None if not needed.
+- `qos`: Quality of Service level (0 or 1; 2 not yet supported) Default = 1.
+- `on_client_connect_success`: An optional callback made after successful
+client authentication, primarily for additional logging.
+- `on_client_connect_failure`: An optional callback made after failed
+client authentication (which should raise an Exception).
+Callback should accept the raised Exception as an argument.
+Callback must resolve (or propagate/re-raise) the Exception.
+
## quixstreams.sources.community.pubsub.pubsub
diff --git a/docs/api-reference/topics.md b/docs/api-reference/topics.md
index 9367ee1ab..479ffd9bc 100644
--- a/docs/api-reference/topics.md
+++ b/docs/api-reference/topics.md
@@ -600,7 +600,7 @@ Multiple topics are expected for merged and joins streams.
def stream_id_from_topics(topics: Sequence[Topic]) -> str
```
-[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/models/topics/manager.py#L350)
+[[VIEW SOURCE]](https://github.com/quixio/quix-streams/blob/main/quixstreams/models/topics/manager.py#L352)
Generate a stream_id by combining names of the provided topics.
diff --git a/docs/build/build.py b/docs/build/build.py
index 55990f3a6..e50b1bd5a 100644
--- a/docs/build/build.py
+++ b/docs/build/build.py
@@ -124,11 +124,10 @@
"quixstreams.sinks.base.exceptions",
"quixstreams.sinks.core.influxdb3",
"quixstreams.sinks.core.csv",
- "quixstreams.sinks.community.file.sink",
- "quixstreams.sinks.community.file.destinations.azure",
- "quixstreams.sinks.community.file.destinations.base",
- "quixstreams.sinks.community.file.destinations.local",
- "quixstreams.sinks.community.file.destinations.s3",
+ "quixstreams.sinks.community.file.base",
+ "quixstreams.sinks.community.file.azure",
+ "quixstreams.sinks.community.file.local",
+ "quixstreams.sinks.community.file.s3",
"quixstreams.sinks.community.file.formats.base",
"quixstreams.sinks.community.file.formats.json",
"quixstreams.sinks.community.file.formats.parquet",
diff --git a/docs/windowing.md b/docs/windowing.md
index ea5a9a72f..de0a3e704 100644
--- a/docs/windowing.md
+++ b/docs/windowing.md
@@ -593,6 +593,76 @@ if __name__ == '__main__':
```
+### Early window expiration with triggers
+!!! info "New in v3.24.0"
+
+To expire windows before their natural expiration time based on custom conditions, you can pass `before_update` or `after_update` callbacks to the `.tumbling_window()` and `.hopping_window()` methods.
+
+This is useful when you want to emit results as soon as certain conditions are met, rather than waiting for the window to close naturally.
+
+**How it works**:
+
+- The `before_update` callback is invoked before the window aggregation is updated with a new value.
+- The `after_update` callback is invoked after the window aggregation has been updated with a new value.
+- Both callbacks receive: `aggregated` (current or updated aggregated value), `value` (incoming value), `key`, `timestamp`, and `headers`.
+- For `collect()` operations without aggregation, `aggregated` contains the list of collected values.
+- If either callback returns `True`, the window is immediately expired and emitted downstream.
+- The window metadata is deleted from state, but collected values (if using `.collect()`) remain until natural expiration.
+- This means a triggered window can be "resurrected" if new data arrives within its time range: a new window will be created with the previously collected values still present.
+
+**Example with after_update**:
+
+```python
+from typing import Any
+
+from datetime import timedelta
+from quixstreams import Application
+
+app = Application(...)
+sdf = app.dataframe(...)
+
+
+def trigger_on_threshold(
+ aggregated: int, value: Any, key: Any, timestamp: int, headers: Any
+) -> bool:
+ """
+ Expire the window early when the sum exceeds 1000.
+ """
+ return aggregated > 1000
+
+
+# Define a 1-hour tumbling window with early expiration trigger
+sdf = (
+ sdf.tumbling_window(timedelta(hours=1), after_update=trigger_on_threshold)
+ .sum()
+ .final()
+)
+
+# Start the application
+if __name__ == '__main__':
+ app.run()
+
+```
+
+**Example with before_update**:
+
+```python
+def trigger_before_large_value(
+ aggregated: int, value: Any, key: Any, timestamp: int, headers: Any
+) -> bool:
+ """
+ Expire the window before adding a value if it would make the sum too large.
+ """
+ return (aggregated + value) > 1000
+
+
+sdf = (
+ sdf.tumbling_window(timedelta(hours=1), before_update=trigger_before_large_value)
+ .sum()
+ .final()
+)
+```
+
## Emitting results
@@ -660,73 +730,6 @@ Also, specifying a grace period using `grace_ms` will increase the latency, beca
You can use `final()` mode when some latency is allowed, but the emitted results must be complete and unique.
-## Closing strategies
-
-By default, windows use the **key** closing strategy.
-In this strategy, messages advance time and close only windows with the **same** message key.
-
-If some message keys appear irregularly in the stream, the latest windows can remain unprocessed until the message with the same key is received.
-
-```python
-from datetime import timedelta
-from quixstreams import Application
-from quixstreams.dataframe.windows import Sum
-
-app = Application(...)
-sdf = app.dataframe(...)
-
-# Calculate a sum of values over a window of 10 seconds
-# and use .final() to emit results only when the window is complete
-sdf = sdf.tumbling_window(timedelta(seconds=10)).agg(value=Sum()).final(closing_strategy="key")
-
-# Details:
-# -> Timestamp=100, Key="A", value=1 -> emit nothing (the window is not closed yet)
-# -> Timestamp=101, Key="B", value=2 -> emit nothing (the window is not closed yet)
-# -> Timestamp=105, Key="C", value=3 -> emit nothing (the window is not closed yet)
-# -> Timestamp=10100, Key="B", value=2 -> emit one message with key "B" and value {"start": 0, "end": 10000, "value": 2}, the time has progressed beyond the window end for the "B" key only.
-# -> Timestamp=8000, Key="A", value=1 -> emit nothing (the window is not closed yet)
-# -> Timestamp=10001, Key="A", value=1 -> emit one message with key "A" and value {"start": 0, "end": 10000, "value": 2}, the time has progressed beyond the window end for the "A" key.
-
-# Results:
-# (key="B", value={"start": 0, "end": 10000, "value": 2})
-# (key="A", value={"start": 0, "end": 10000, "value": 2})
-# No message for key "C" as the window is never closed since no messages with key "C" and a timestamp later than 10000 was received
-```
-
-An alternative is to use the **partition** closing strategy.
-In this strategy, messages advance time and close windows for the whole partition to which this key belongs.
-
-If messages aren't ordered accross keys some message can be skipped if the windows are already closed.
-
-```python
-from datetime import timedelta
-from quixstreams import Application
-from quixstreams.dataframe.windows import Sum
-
-app = Application(...)
-sdf = app.dataframe(...)
-
-# Calculate a sum of values over a window of 10 seconds
-# and use .final() to emit results only when the window is complete
-sdf = sdf.tumbling_window(timedelta(seconds=10)).agg(value=Sum()).final(closing_strategy="partition")
-
-# Details:
-# -> Timestamp=100, Key="A", value=1 -> emit nothing (the window is not closed yet)
-# -> Timestamp=101, Key="B", value=2 -> emit nothing (the window is not closed yet)
-# -> Timestamp=105, Key="C", value=3 -> emit nothing (the window is not closed yet)
-# -> Timestamp=10100, Key="B", value=1 -> emit three messages, the time has progressed beyond the window end for all the keys in the partition
-# 1. first one with key "A" and value {"start": 0, "end": 10000, "value": 1}
-# 2. second one with key "B" and value {"start": 0, "end": 10000, "value": 2}
-# 3. third one with key "C" and value {"start": 0, "end": 10000, "value": 3}
-# -> Timestamp=8000, Key="A", value=1 -> emit nothing and value isn't part of the sum (the window is already closed)
-# -> Timestamp=10001, Key="A", value=1 -> emit nothing (the window is not closed yet)
-
-# Results:
-# (key="A", value={"start": 0, "end": 10000, "value": 1})
-# (key="B", value={"start": 0, "end": 10000, "value": 2})
-# (key="C", value={"start": 0, "end": 10000, "value": 3})
-```
-
## Transforming the result of a windowed aggregation
Windowed aggregations return aggregated results in the following format/schema:
diff --git a/quixstreams/app.py b/quixstreams/app.py
index 30d71ab4e..0a6f14706 100644
--- a/quixstreams/app.py
+++ b/quixstreams/app.py
@@ -6,7 +6,6 @@
import time
import uuid
import warnings
-from collections import defaultdict
from pathlib import Path
from typing import Callable, List, Literal, Optional, Protocol, Tuple, Type, Union, cast
@@ -15,7 +14,7 @@
from pydantic_settings import BaseSettings as PydanticBaseSettings
from pydantic_settings import PydanticBaseSettingsSource, SettingsConfigDict
-from .context import copy_context, set_message_context
+from .context import MessageContext, copy_context, set_message_context
from .core.stream.functions.types import VoidExecutor
from .dataframe import DataFrameRegistry, StreamingDataFrame
from .error_callbacks import (
@@ -46,12 +45,14 @@
)
from .platforms.quix.env import QUIX_ENVIRONMENT
from .processing import ProcessingContext
+from .processing.watermarking import WatermarkManager, WatermarkMessage
from .runtracker import RunTracker
from .sinks import SinkManager
from .sources import BaseSource, SourceException, SourceManager
from .state import StateStoreManager
from .state.recovery import RecoveryManager
from .state.rocksdb import RocksDBOptionsType
+from .utils.format import format_timestamp
from .utils.settings import BaseSettings
__all__ = ("Application", "ApplicationConfig")
@@ -152,6 +153,8 @@ def __init__(
topic_create_timeout: float = 60,
processing_guarantee: ProcessingGuarantee = "at-least-once",
max_partition_buffer_size: int = 10000,
+ watermarking_default_assignor_enabled: bool = True,
+ watermarking_interval: float = 1.0,
):
"""
:param broker_address: Connection settings for Kafka.
@@ -220,6 +223,14 @@ def __init__(
It is a soft limit, and the actual number of buffered messages can be up to x2 higher.
Lower value decreases the memory use, but increases the latency.
Default - `10000`.
+ :param watermarking_default_assignor_enabled: when True, the application extracts watermarks
+ from incoming messages by default (respecting the `Topic(timestamp_extractor)` if configured).
+ When disabled, no watermarks will be emitted unless the `StreamingDataFrame.set_timestamp()`
+ is called for each main StreamingDataFrame.
+ Default - `True`.
+
+ :param watermarking_interval: how often to emit watermark updates for assigned partitions (in seconds).
+ Default - `1.0`s.
***Error Handlers***
To handle errors, `Application` accepts callbacks triggered when
@@ -339,6 +350,7 @@ def __init__(
rocksdb_options=rocksdb_options,
use_changelog_topics=use_changelog_topics,
max_partition_buffer_size=max_partition_buffer_size,
+ watermarking_default_assignor_enabled=watermarking_default_assignor_enabled,
)
self._on_message_processed = on_message_processed
@@ -374,6 +386,11 @@ def __init__(
self._source_manager = SourceManager()
self._sink_manager = SinkManager()
self._dataframe_registry = DataFrameRegistry()
+ self._watermark_manager = WatermarkManager(
+ producer=self._producer,
+ topic_manager=self._topic_manager,
+ interval=watermarking_interval,
+ )
self._processing_context = ProcessingContext(
commit_interval=self._config.commit_interval,
commit_every=self._config.commit_every,
@@ -383,6 +400,7 @@ def __init__(
exactly_once=self._config.exactly_once,
sink_manager=self._sink_manager,
dataframe_registry=self._dataframe_registry,
+ watermark_manager=self._watermark_manager,
)
self._run_tracker = RunTracker()
@@ -903,9 +921,19 @@ def _run_dataframe(self, sink: Optional[VoidExecutor] = None):
printer = self._processing_context.printer
run_tracker = self._run_tracker
consumer = self._consumer
+ producer = self._producer
+ producer_poll_timeout = self._config.producer_poll_timeout
+ watermark_manager = self._watermark_manager
+
+ # Set the topics to be tracked by the Watermark manager
+ watermark_manager.set_topics(topics=self._dataframe_registry.consumer_topics)
consumer.subscribe(
- topics=self._dataframe_registry.consumer_topics + changelog_topics,
+ topics=self._dataframe_registry.consumer_topics
+ + changelog_topics
+ + [
+ self._watermark_manager.watermarks_topic
+ ], # TODO: We subscribe here because otherwise it can't deserialize a message. Maybe it's time to split poll() and deserialization
on_assign=self._on_assign,
on_revoke=self._on_revoke,
on_lost=self._on_lost,
@@ -922,11 +950,14 @@ def _run_dataframe(self, sink: Optional[VoidExecutor] = None):
state_manager.do_recovery()
run_tracker.timeout_refresh()
else:
+ # Serve producer callbacks
+ producer.poll(producer_poll_timeout)
process_message(dataframes_composed)
processing_context.commit_checkpoint()
consumer.resume_backpressured()
source_manager.raise_for_error()
printer.print()
+ watermark_manager.produce()
run_tracker.update_status()
logger.info("Stopping the application")
@@ -954,9 +985,7 @@ def _quix_runtime_init(self):
if self._state_manager.stores:
check_state_management_enabled()
- def _process_message(self, dataframe_composed):
- # Serve producer callbacks
- self._producer.poll(self._config.producer_poll_timeout)
+ def _process_message(self, dataframe_composed: dict[str, VoidExecutor]):
rows = self._consumer.poll_row(
timeout=self._config.consumer_poll_timeout,
buffered=self._dataframe_registry.requires_time_alignment,
@@ -978,7 +1007,54 @@ def _process_message(self, dataframe_composed):
first_row.offset,
)
+ if topic_name == self._watermark_manager.watermarks_topic.name:
+ watermark = self._watermark_manager.receive(
+ message=cast(WatermarkMessage, first_row.value)
+ )
+ if watermark is None:
+ return
+
+ data_topics = self._topic_manager.non_changelog_topics
+ data_tps = [
+ tp for tp in self._consumer.assignment() if tp.topic in data_topics
+ ]
+ for tp in data_tps:
+ logger.info(
+ f"Process watermark {format_timestamp(watermark)}. "
+ f"topic={tp.topic} partition={tp.partition} timestamp={watermark}"
+ )
+ # Create a MessageContext to process a watermark update
+ # for each assigned TP
+ watermark_ctx = MessageContext(
+ topic=tp.topic,
+ partition=tp.partition,
+ offset=None,
+ size=0,
+ )
+ context = copy_context()
+ context.run(set_message_context, watermark_ctx)
+ # Execute StreamingDataFrame in a context
+ context.run(
+ dataframe_composed[tp.topic],
+ value=None,
+ key=None,
+ timestamp=watermark,
+ headers=[],
+ is_watermark=True,
+ )
+ return
+
for row in rows:
+ if self._config.watermarking_default_assignor_enabled:
+ # Update the watermark with the current row's timestamp
+ # if the default watermark assignor is enabled (True by default).
+ self._processing_context.watermark_manager.store(
+ topic=row.topic,
+ partition=row.partition,
+ timestamp=row.timestamp,
+ default=True,
+ )
+
context = copy_context()
context.run(set_message_context, row.context)
try:
@@ -999,12 +1075,12 @@ def _process_message(self, dataframe_composed):
# Store the message offset after it's successfully processed
self._processing_context.store_offset(
- topic=topic_name, partition=partition, offset=offset
+ topic=topic_name, partition=partition, offset=offset or 0
)
self._run_tracker.set_message_consumed(True)
if self._on_message_processed is not None:
- self._on_message_processed(topic_name, partition, offset)
+ self._on_message_processed(topic_name, partition, offset or 0)
def _on_assign(self, _, topic_partitions: List[TopicPartition]):
"""
@@ -1024,42 +1100,34 @@ def _on_assign(self, _, topic_partitions: List[TopicPartition]):
self._source_manager.start_sources()
# Assign partitions manually to pause the changelog topics
- self._consumer.assign(topic_partitions)
- # Pause changelog topic+partitions immediately after assignment
- non_changelog_topics = self._topic_manager.non_changelog_topics
- changelog_tps = [
- tp for tp in topic_partitions if tp.topic not in non_changelog_topics
+ watermarks_partitions = [
+ TopicPartition(
+ topic=self._watermark_manager.watermarks_topic.name, partition=i
+ )
+ for i in range(
+ self._watermark_manager.watermarks_topic.broker_config.num_partitions
+ or 1
+ )
]
+ # TODO: The set is used because the watermark tp can already be present in the "topic_partitions"
+ # because we use `subscribe()` earlier. Fix the mess later.
+ # TODO: Also, how to avoid reading the whole WM topic on each restart?
+ # We really need only the most recent data
+ # Is it fine to read it from the end? The active partitions must still publish something.
+ # Or should we commit it?
+ self._consumer.assign(list(set(topic_partitions + watermarks_partitions)))
+
+ # Pause changelog topic+partitions immediately after assignment
+ changelog_topics = {t.name for t in self._topic_manager.changelog_topics_list}
+ changelog_tps = [tp for tp in topic_partitions if tp.topic in changelog_topics]
self._consumer.pause(changelog_tps)
- if self._state_manager.stores:
- non_changelog_tps = [
- tp for tp in topic_partitions if tp.topic in non_changelog_topics
- ]
- committed_tps = self._consumer.committed(
- partitions=non_changelog_tps, timeout=30
- )
- committed_offsets: dict[int, dict[str, int]] = defaultdict(dict)
- for tp in committed_tps:
- if tp.error:
- raise RuntimeError(
- f"Failed to get committed offsets for "
- f'"{tp.topic}[{tp.partition}]" from the broker: {tp.error}'
- )
- committed_offsets[tp.partition][tp.topic] = tp.offset
+ data_topics = self._topic_manager.non_changelog_topics
+ data_tps = [tp for tp in topic_partitions if tp.topic in data_topics]
+
+ for tp in data_tps:
+ self._assign_state_partitions(topic=tp.topic, partition=tp.partition)
- # Match the assigned TP with a stream ID via DataFrameRegistry
- for tp in non_changelog_tps:
- stream_ids = self._dataframe_registry.get_stream_ids(
- topic_name=tp.topic
- )
- # Assign store partitions for the given stream ids
- for stream_id in stream_ids:
- self._state_manager.on_partition_assign(
- stream_id=stream_id,
- partition=tp.partition,
- committed_offsets=committed_offsets[tp.partition],
- )
self._run_tracker.timeout_refresh()
def _on_revoke(self, _, topic_partitions: List[TopicPartition]):
@@ -1079,7 +1147,12 @@ def _on_revoke(self, _, topic_partitions: List[TopicPartition]):
else:
self._processing_context.commit_checkpoint(force=True)
- self._revoke_state_partitions(topic_partitions=topic_partitions)
+ data_topics = self._topic_manager.non_changelog_topics
+ data_tps = [tp for tp in topic_partitions if tp.topic in data_topics]
+ for tp in data_tps:
+ self._watermark_manager.on_revoke(topic=tp.topic, partition=tp.partition)
+ self._revoke_state_partitions(topic=tp.topic, partition=tp.partition)
+
self._consumer.reset_backpressure()
def _on_lost(self, _, topic_partitions: List[TopicPartition]):
@@ -1088,23 +1161,34 @@ def _on_lost(self, _, topic_partitions: List[TopicPartition]):
"""
logger.debug("Rebalancing: dropping lost partitions")
- self._revoke_state_partitions(topic_partitions=topic_partitions)
+ data_tps = [
+ tp
+ for tp in topic_partitions
+ if tp.topic in self._topic_manager.non_changelog_topics
+ ]
+ for tp in data_tps:
+ self._watermark_manager.on_revoke(topic=tp.topic, partition=tp.partition)
+ self._revoke_state_partitions(topic=tp.topic, partition=tp.partition)
+
self._consumer.reset_backpressure()
- def _revoke_state_partitions(self, topic_partitions: List[TopicPartition]):
- non_changelog_topics = self._topic_manager.non_changelog_topics
- non_changelog_tps = [
- tp for tp in topic_partitions if tp.topic in non_changelog_topics
- ]
- for tp in non_changelog_tps:
- if self._state_manager.stores:
- stream_ids = self._dataframe_registry.get_stream_ids(
- topic_name=tp.topic
+ def _assign_state_partitions(self, topic: str, partition: int):
+ if self._state_manager.stores:
+ # Match the assigned TP with a stream ID via DataFrameRegistry
+ stream_ids = self._dataframe_registry.get_stream_ids(topic_name=topic)
+ # Assign store partitions for the given stream ids
+ for stream_id in stream_ids:
+ self._state_manager.on_partition_assign(
+ stream_id=stream_id, partition=partition
+ )
+
+ def _revoke_state_partitions(self, topic: str, partition: int):
+ if self._state_manager.stores:
+ stream_ids = self._dataframe_registry.get_stream_ids(topic_name=topic)
+ for stream_id in stream_ids:
+ self._state_manager.on_partition_revoke(
+ stream_id=stream_id, partition=partition
)
- for stream_id in stream_ids:
- self._state_manager.on_partition_revoke(
- stream_id=stream_id, partition=tp.partition
- )
def _setup_signal_handlers(self):
signal.signal(signal.SIGINT, self._on_sigint)
@@ -1156,6 +1240,7 @@ class ApplicationConfig(BaseSettings):
rocksdb_options: Optional[RocksDBOptionsType] = None
use_changelog_topics: bool = True
max_partition_buffer_size: int = 10000
+ watermarking_default_assignor_enabled: bool = True
@classmethod
def settings_customise_sources(
diff --git a/quixstreams/checkpointing/checkpoint.py b/quixstreams/checkpointing/checkpoint.py
index 7bdb09044..430d72d2a 100644
--- a/quixstreams/checkpointing/checkpoint.py
+++ b/quixstreams/checkpointing/checkpoint.py
@@ -1,3 +1,4 @@
+import abc
import logging
import time
from abc import abstractmethod
@@ -26,7 +27,7 @@
logger = logging.getLogger(__name__)
-class BaseCheckpoint:
+class BaseCheckpoint(abc.ABC):
"""
Base class to keep track of state updates and consumer offsets and to checkpoint these
updates on schedule.
@@ -70,7 +71,7 @@ def empty(self) -> bool:
Returns `True` if checkpoint doesn't have any offsets stored yet.
:return:
"""
- return not bool(self._tp_offsets)
+ return not bool(self._tp_offsets) and not bool(self._store_transactions)
def store_offset(self, topic: str, partition: int, offset: int):
"""
@@ -228,20 +229,12 @@ def commit(self):
partition,
store_name,
), transaction in self._store_transactions.items():
- topics = self._dataframe_registry.get_topics_for_stream_id(
- stream_id=stream_id
- )
- processed_offsets = {
- topic: offset
- for (topic, partition_), offset in self._tp_offsets.items()
- if topic in topics and partition_ == partition
- }
if transaction.failed:
raise StoreTransactionFailed(
f'Detected a failed transaction for store "{store_name}", '
f"the checkpoint is aborted"
)
- transaction.prepare(processed_offsets=processed_offsets)
+ transaction.prepare()
# Step 3. Flush producer to trigger all delivery callbacks and ensure that
# all messages are produced
@@ -263,7 +256,9 @@ def commit(self):
self._producer.commit_transaction(
offsets, self._consumer.consumer_group_metadata()
)
- else:
+ elif offsets:
+ # The checkpoint may have no processed offsets when only watermarks were processed.
+ # In this case, there is nothing to commit to Kafka.
logger.debug("Checkpoint: committing consumer")
try:
partitions = self._consumer.commit(offsets=offsets, asynchronous=False)
diff --git a/quixstreams/core/stream/functions/apply.py b/quixstreams/core/stream/functions/apply.py
index bdf493953..d34bdc4df 100644
--- a/quixstreams/core/stream/functions/apply.py
+++ b/quixstreams/core/stream/functions/apply.py
@@ -47,12 +47,22 @@ def wrapper(
key: Any,
timestamp: int,
headers: Any,
+ is_watermark: bool = False,
) -> None:
# Execute a function on a single value and wrap results into a list
# to expand them downstream
- result = func(value)
- for item in result:
- child_executor(item, key, timestamp, headers)
+ if is_watermark:
+ child_executor(
+ value,
+ key,
+ timestamp,
+ headers,
+ True,
+ )
+ else:
+ result = func(value)
+ for item in result:
+ child_executor(item, key, timestamp, headers)
else:
@@ -61,10 +71,20 @@ def wrapper(
key: Any,
timestamp: int,
headers: Any,
+ is_watermark: bool = False,
) -> None:
- # Execute a function on a single value and return its result
- result = func(value)
- child_executor(result, key, timestamp, headers)
+ if is_watermark:
+ child_executor(
+ value,
+ key,
+ timestamp,
+ headers,
+ True,
+ )
+ else:
+ # Execute a function on a single value and return its result
+ result = func(value)
+ child_executor(result, key, timestamp, headers)
return wrapper
@@ -109,12 +129,22 @@ def wrapper(
key: Any,
timestamp: int,
headers: Any,
+ is_watermark: bool = False,
):
# Execute a function on a single value and wrap results into a list
# to expand them downstream
- result = func(value, key, timestamp, headers)
- for item in result:
- child_executor(item, key, timestamp, headers)
+ if is_watermark:
+ child_executor(
+ value,
+ key,
+ timestamp,
+ headers,
+ True,
+ )
+ else:
+ result = func(value, key, timestamp, headers)
+ for item in result:
+ child_executor(item, key, timestamp, headers)
else:
@@ -123,9 +153,19 @@ def wrapper(
key: Any,
timestamp: int,
headers: Any,
+ is_watermark: bool = False,
):
- # Execute a function on a single value and return its result
- result = func(value, key, timestamp, headers)
- child_executor(result, key, timestamp, headers)
+ if is_watermark:
+ child_executor(
+ value,
+ key,
+ timestamp,
+ headers,
+ True,
+ )
+ else:
+ # Execute a function on a single value and return its result
+ result = func(value, key, timestamp, headers)
+ child_executor(result, key, timestamp, headers)
return wrapper
diff --git a/quixstreams/core/stream/functions/base.py b/quixstreams/core/stream/functions/base.py
index 08037fef0..f0b8c6fca 100644
--- a/quixstreams/core/stream/functions/base.py
+++ b/quixstreams/core/stream/functions/base.py
@@ -1,5 +1,5 @@
import abc
-from typing import Any
+from typing import Any, Optional
from quixstreams.utils.pickle import pickle_copier
@@ -18,8 +18,11 @@ class StreamFunction(abc.ABC):
expand: bool = False
- def __init__(self, func: StreamCallback):
+ def __init__(
+ self, func: StreamCallback, on_watermark: Optional[StreamCallback] = None
+ ):
self.func = func
+ self.on_watermark = on_watermark
@abc.abstractmethod
def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor:
@@ -49,7 +52,9 @@ def wrapper(
key: Any,
timestamp: int,
headers: Any,
+ is_watermark: bool = False,
):
+ # TODO: Handle a watermark in branched operations
first_branch_executor, *branch_executors = child_executors
copier = pickle_copier(value)
diff --git a/quixstreams/core/stream/functions/filter.py b/quixstreams/core/stream/functions/filter.py
index e291880c7..94cbf30ee 100644
--- a/quixstreams/core/stream/functions/filter.py
+++ b/quixstreams/core/stream/functions/filter.py
@@ -28,9 +28,18 @@ def wrapper(
key: Any,
timestamp: int,
headers: Any,
+ is_watermark: bool = False,
):
# Filter a single value
- if func(value):
+ if is_watermark:
+ child_executor(
+ value,
+ key,
+ timestamp,
+ headers,
+ True,
+ )
+ elif func(value):
child_executor(value, key, timestamp, headers)
return wrapper
@@ -60,9 +69,18 @@ def wrapper(
key: Any,
timestamp: int,
headers: Any,
+ is_watermark: bool = False,
):
+ if is_watermark:
+ child_executor(
+ value,
+ key,
+ timestamp,
+ headers,
+ True,
+ )
# Filter a single value
- if func(value, key, timestamp, headers):
+ elif func(value, key, timestamp, headers):
child_executor(value, key, timestamp, headers)
return wrapper
diff --git a/quixstreams/core/stream/functions/transform.py b/quixstreams/core/stream/functions/transform.py
index 219662b6b..86c614e37 100644
--- a/quixstreams/core/stream/functions/transform.py
+++ b/quixstreams/core/stream/functions/transform.py
@@ -1,7 +1,11 @@
from typing import Any, Literal, Union, cast, overload
from .base import StreamFunction
-from .types import TransformCallback, TransformExpandedCallback, VoidExecutor
+from .types import (
+ TransformCallback,
+ TransformExpandedCallback,
+ VoidExecutor,
+)
__all__ = ("TransformFunction",)
@@ -21,24 +25,32 @@ class TransformFunction(StreamFunction):
The result of the callback will always be passed downstream.
"""
+ func: Union[TransformCallback, TransformExpandedCallback]
+
@overload
def __init__(
- self, func: TransformCallback, expand: Literal[False] = False
+ self,
+ func: TransformCallback,
+ expand: Literal[False] = False,
+ on_watermark: Union[TransformCallback, None] = None,
) -> None: ...
@overload
def __init__(
- self, func: TransformExpandedCallback, expand: Literal[True]
+ self,
+ func: TransformExpandedCallback,
+ expand: Literal[True],
+ on_watermark: Union[TransformExpandedCallback, None] = None,
) -> None: ...
def __init__(
self,
func: Union[TransformCallback, TransformExpandedCallback],
expand: bool = False,
+ on_watermark: Union[TransformCallback, TransformExpandedCallback, None] = None,
):
- super().__init__(func)
+ super().__init__(func=func, on_watermark=on_watermark)
- self.func: Union[TransformCallback, TransformExpandedCallback]
self.expand = expand
def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor:
@@ -52,10 +64,36 @@ def wrapper(
key: Any,
timestamp: int,
headers: Any,
+ is_watermark: bool = False,
):
- result = expanded_func(value, key, timestamp, headers)
- for new_value, new_key, new_timestamp, new_headers in result:
- child_executor(new_value, new_key, new_timestamp, new_headers)
+ if is_watermark:
+ if self.on_watermark is not None:
+ # React to the new watermark if "on_watermark" is defined
+ watermark_func = cast(
+ TransformExpandedCallback, self.on_watermark
+ )
+ result = watermark_func(None, None, timestamp, ())
+ for new_value, new_key, new_timestamp, new_headers in result:
+ child_executor(
+ new_value,
+ new_key,
+ new_timestamp,
+ new_headers,
+ False,
+ )
+ # Always pass the watermark downstream so other operators can react
+ # to it as well.
+ child_executor(
+ value,
+ key,
+ timestamp,
+ headers,
+ True,
+ )
+ else:
+ result = expanded_func(value, key, timestamp, headers)
+ for new_value, new_key, new_timestamp, new_headers in result:
+ child_executor(new_value, new_key, new_timestamp, new_headers)
else:
func = cast(TransformCallback, self.func)
@@ -65,11 +103,36 @@ def wrapper(
key: Any,
timestamp: int,
headers: Any,
+ is_watermark: bool = False,
):
- # Execute a function on a single value and return its result
- new_value, new_key, new_timestamp, new_headers = func(
- value, key, timestamp, headers
- )
- child_executor(new_value, new_key, new_timestamp, new_headers)
+ if is_watermark:
+ if self.on_watermark is not None:
+ # React to the new watermark if "on_watermark" is defined
+ watermark_func = cast(TransformCallback, self.on_watermark)
+ new_value, new_key, new_timestamp, new_headers = watermark_func(
+ None, None, timestamp, ()
+ )
+ child_executor(
+ new_value,
+ new_key,
+ new_timestamp,
+ new_headers,
+ False,
+ )
+ # Always pass the watermark downstream so other operators can react
+ # to it as well.
+ child_executor(
+ value,
+ key,
+ timestamp,
+ headers,
+ True,
+ )
+ else:
+ # Execute a function on a single value and return its result
+ new_value, new_key, new_timestamp, new_headers = func(
+ value, key, timestamp, headers
+ )
+ child_executor(new_value, new_key, new_timestamp, new_headers)
return wrapper
diff --git a/quixstreams/core/stream/functions/types.py b/quixstreams/core/stream/functions/types.py
index 504299b53..18a3b2023 100644
--- a/quixstreams/core/stream/functions/types.py
+++ b/quixstreams/core/stream/functions/types.py
@@ -14,6 +14,7 @@
"FilterWithMetadataCallback",
"TransformCallback",
"TransformExpandedCallback",
+ "StreamSink",
)
@@ -57,6 +58,7 @@ def __call__(
key: Any,
timestamp: int,
headers: Any,
+ is_watermark: bool = False,
) -> None: ...
@@ -67,4 +69,15 @@ def __call__(
key: Any,
timestamp: int,
headers: Any,
+ is_watermark: bool = False,
) -> Tuple[Any, Any, int, Any]: ...
+
+
+class StreamSink(Protocol):
+ def __call__(
+ self,
+ value: Any,
+ key: Any,
+ timestamp: int,
+ headers: Any,
+ ) -> None: ...
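As a sketch of the new `StreamSink` protocol: any callable with this four-argument shape qualifies, for example a simple buffer passed to `Stream.compose(sink=...)`. Note that, with the `_sink_wrapper` added below, sinks never receive watermark records.

```python
from typing import Any

results: list[tuple[Any, Any, int, Any]] = []

def buffering_sink(value: Any, key: Any, timestamp: int, headers: Any) -> None:
    # Collect every executed record; watermarks are filtered out before
    # the sink is called, so only real records land here.
    results.append((value, key, timestamp, headers))
```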
diff --git a/quixstreams/core/stream/functions/update.py b/quixstreams/core/stream/functions/update.py
index b2d9a19bc..157d6be5b 100644
--- a/quixstreams/core/stream/functions/update.py
+++ b/quixstreams/core/stream/functions/update.py
@@ -26,10 +26,25 @@ def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor:
child_executor = self._resolve_branching(*child_executors)
func = self.func
- def wrapper(value: Any, key: Any, timestamp: int, headers: Any):
- # Update a single value and forward it
- func(value)
- child_executor(value, key, timestamp, headers)
+ def wrapper(
+ value: Any,
+ key: Any,
+ timestamp: int,
+ headers: Any,
+ is_watermark: bool = False,
+ ):
+ if is_watermark:
+ child_executor(
+ value,
+ key,
+ timestamp,
+ headers,
+ True,
+ )
+ else:
+ # Update a single value and forward it
+ func(value)
+ child_executor(value, key, timestamp, headers)
return wrapper
@@ -54,9 +69,24 @@ def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor:
child_executor = self._resolve_branching(*child_executors)
func = self.func
- def wrapper(value: Any, key: Any, timestamp: int, headers: Any):
- # Update a single value and forward it
- func(value, key, timestamp, headers)
- child_executor(value, key, timestamp, headers)
+ def wrapper(
+ value: Any,
+ key: Any,
+ timestamp: int,
+ headers: Any,
+ is_watermark: bool = False,
+ ):
+ if is_watermark:
+ child_executor(
+ value,
+ key,
+ timestamp,
+ headers,
+ True,
+ )
+ else:
+ # Update a single value and forward it
+ func(value, key, timestamp, headers)
+ child_executor(value, key, timestamp, headers)
return wrapper
diff --git a/quixstreams/core/stream/stream.py b/quixstreams/core/stream/stream.py
index f538f5307..258bf4025 100644
--- a/quixstreams/core/stream/stream.py
+++ b/quixstreams/core/stream/stream.py
@@ -27,6 +27,7 @@
FilterWithMetadataFunction,
ReturningExecutor,
StreamFunction,
+ StreamSink,
TransformCallback,
TransformExpandedCallback,
TransformFunction,
@@ -249,17 +250,30 @@ def add_update(
return self._add(update_func)
@overload
- def add_transform(self, func: TransformCallback, *, expand: Literal[False] = False):
+ def add_transform(
+ self,
+ func: TransformCallback,
+ *,
+ expand: Literal[False] = False,
+ on_watermark: Union[TransformCallback, None] = None,
+ ):
pass
@overload
- def add_transform(self, func: TransformExpandedCallback, *, expand: Literal[True]):
+ def add_transform(
+ self,
+ func: TransformExpandedCallback,
+ *,
+ expand: Literal[True],
+ on_watermark: Union[TransformExpandedCallback, None] = None,
+ ):
pass
def add_transform(
self,
func: Union[TransformCallback, TransformExpandedCallback],
*,
+ on_watermark: Union[TransformCallback, TransformExpandedCallback, None] = None,
expand: bool = False,
) -> "Stream":
"""
@@ -276,9 +290,13 @@ def add_transform(
:param expand: if True, expand the returned iterable into individual items
downstream. If returned value is not iterable, `TypeError` will be raised.
Default - `False`.
+ :param on_watermark: a callback to process watermark messages.
+ It can be used to expire and emit window results.
:return: a new Stream derived from the current one
"""
- return self._add(TransformFunction(func, expand=expand)) # type: ignore[call-overload]
+ return self._add(
+ TransformFunction(func, expand=expand, on_watermark=on_watermark) # type: ignore[call-overload]
+ )
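A rough sketch of attaching a transform with the new `on_watermark` callback to a `Stream`; the callback bodies are illustrative assumptions, not part of this change:

```python
from typing import Any

from quixstreams.core.stream import Stream


def passthrough(value: Any, key: Any, timestamp: int, headers: Any):
    # Regular records are passed through unchanged.
    return value, key, timestamp, headers


def emit_marker(value: Any, key: Any, timestamp: int, headers: Any):
    # Called when a watermark arrives; emit a marker record downstream.
    return {"watermark_ms": timestamp}, key, timestamp, headers


stream = Stream().add_transform(passthrough, on_watermark=emit_marker)
```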
def merge(self, other: "Stream") -> "Stream":
"""
@@ -407,7 +425,7 @@ def compose(
allow_expands=True,
allow_updates=True,
allow_transforms=True,
- sink: Optional[VoidExecutor] = None,
+ sink: Optional[StreamSink] = None,
) -> dict["Stream", VoidExecutor]:
"""
Generate an "executor" closure by mapping all relatives of this `Stream` and
@@ -430,7 +448,7 @@ def compose(
:param sink: callable to accumulate the results of the execution, optional.
"""
- sink = sink or self._default_sink
+ sink = self._sink_wrapper(sink or self._default_sink)
executors: dict["Stream", VoidExecutor] = {}
for stream in reversed(self.full_tree()):
@@ -487,10 +505,16 @@ def compose_returning(self) -> ReturningExecutor:
),
)
- def wrapper(value: Any, key: Any, timestamp: int, headers: Any) -> Any:
+ def wrapper(
+ value: Any,
+ key: Any,
+ timestamp: int,
+ headers: Any,
+ is_watermark: bool = False,
+ ) -> Any:
try:
# Execute the stream and return the result from the queue
- executor(value, key, timestamp, headers)
+ executor(value, key, timestamp, headers, is_watermark)
return buffer.popleft()
finally:
# Always clean the queue after the Stream is executed
@@ -504,7 +528,7 @@ def compose_single(
allow_expands=True,
allow_updates=True,
allow_transforms=True,
- sink: Optional[VoidExecutor] = None,
+ sink: Optional[StreamSink] = None,
) -> VoidExecutor:
"""
A helper function to compose a Stream with a single root.
@@ -557,6 +581,23 @@ def _add(self, func: StreamFunction) -> "Stream":
self.children.append(new_node)
return new_node
+ def _sink_wrapper(self, sink_func: StreamSink) -> VoidExecutor:
+ def wrapper(
+ value: Any,
+ key: Any,
+ timestamp: int,
+ headers: Any,
+ is_watermark: bool = False,
+ ):
+ if not is_watermark:
+ sink_func(value, key, timestamp, headers)
+
+ return wrapper
+
def _default_sink(
- self, value: Any, key: Any, timestamp: int, headers: Any
+ self,
+ value: Any,
+ key: Any,
+ timestamp: int,
+ headers: Any,
) -> None: ...
diff --git a/quixstreams/dataframe/dataframe.py b/quixstreams/dataframe/dataframe.py
index 53e90c767..4efec03c5 100644
--- a/quixstreams/dataframe/dataframe.py
+++ b/quixstreams/dataframe/dataframe.py
@@ -35,6 +35,7 @@
FilterCallback,
FilterWithMetadataCallback,
Stream,
+ StreamSink,
UpdateCallback,
UpdateWithMetadataCallback,
VoidExecutor,
@@ -72,7 +73,11 @@
TumblingCountWindowDefinition,
TumblingTimeWindowDefinition,
)
-from .windows.base import WindowOnLateCallback
+from .windows.base import (
+ WindowAfterUpdateCallback,
+ WindowBeforeUpdateCallback,
+ WindowOnLateCallback,
+)
if typing.TYPE_CHECKING:
from quixstreams.processing import ProcessingContext
@@ -754,6 +759,8 @@ def set_timestamp(
self, func: Callable[[Any, Any, int, Any], int]
) -> "StreamingDataFrame":
"""
+ # TODO: Document that it overwrites the default watermark.
+
Set a new timestamp based on the current message value and its metadata.
The new timestamp will be used in windowed aggregations and when producing
@@ -788,6 +795,14 @@ def _set_timestamp_callback(
headers: Any,
) -> Tuple[Any, Any, int, Any]:
new_timestamp = func(value, key, timestamp, headers)
+
+ ctx = message_context()
+ self._processing_context.watermark_manager.store(
+ topic=ctx.topic,
+ partition=ctx.partition,
+ timestamp=new_timestamp,
+ default=False,
+ )
return value, key, new_timestamp, headers
stream = self.stream.add_transform(_set_timestamp_callback, expand=False)
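A hedged usage sketch: after this change, a timestamp set via `set_timestamp()` is also stored as a non-default watermark for the message's topic partition. The `event_ts_ms` field is a made-up example.

```python
sdf = app.dataframe(topic)

# The extracted timestamp drives windowing and, with this change, also
# advances the watermark for the current topic partition.
sdf = sdf.set_timestamp(lambda value, key, timestamp, headers: value["event_ts_ms"])
```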
@@ -1008,7 +1023,7 @@ def _add_row(value: Any, *_metadata: tuple[Any, int, HeadersTuples]) -> None:
def compose(
self,
- sink: Optional[VoidExecutor] = None,
+ sink: Optional[StreamSink] = None,
) -> dict[str, VoidExecutor]:
"""
@@ -1048,6 +1063,7 @@ def test(
headers: Optional[Any] = None,
ctx: Optional[MessageContext] = None,
topic: Optional[Topic] = None,
+ is_watermark: bool = False,
) -> List[Any]:
"""
A shorthand to test `StreamingDataFrame` with provided value
@@ -1061,6 +1077,8 @@ def test(
has stateful functions or windows.
Default - `None`.
:param topic: optionally, a topic branch to test with
+ :param is_watermark: whether the value is a watermark.
+ Default - `False`.
:return: result of `StreamingDataFrame`
"""
@@ -1076,7 +1094,7 @@ def test(
(value_, key_, timestamp_, headers_)
)
)
- context.run(composed[topic.name], value, key, timestamp, headers)
+ context.run(composed[topic.name], value, key, timestamp, headers, is_watermark)
return result
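A minimal sketch of the new flag in `test()`; a stateless `StreamingDataFrame` with a single topic is assumed, so no `ctx` is needed:

```python
# Feed a regular record first, then advance the stream time with a watermark.
sdf.test(value={"temperature": 21}, key=b"sensor-1", timestamp=1_000)
sdf.test(value=None, key=None, timestamp=5_000, is_watermark=True)
```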
def tumbling_window(
@@ -1085,6 +1103,8 @@ def tumbling_window(
grace_ms: Union[int, timedelta] = 0,
name: Optional[str] = None,
on_late: Optional[WindowOnLateCallback] = None,
+ before_update: Optional[WindowBeforeUpdateCallback] = None,
+ after_update: Optional[WindowAfterUpdateCallback] = None,
) -> TumblingTimeWindowDefinition:
"""
Create a time-based tumbling window transformation on this StreamingDataFrame.
@@ -1151,6 +1171,20 @@ def tumbling_window(
(default behavior).
Otherwise, no message will be logged.
+ :param before_update: an optional callback to trigger early window expiration
+ before the window is updated.
+ The callback receives `aggregated` (current aggregated value or default/None),
+ `value`, `key`, `timestamp`, and `headers`.
+ If it returns `True`, the window will be expired immediately.
+ Default - `None`.
+
+ :param after_update: an optional callback to trigger early window expiration
+ after the window is updated.
+ The callback receives `aggregated` (updated aggregated value), `value`, `key`,
+ `timestamp`, and `headers`.
+ If it returns `True`, the window will be expired immediately.
+ Default - `None`.
+
:return: `TumblingTimeWindowDefinition` instance representing the tumbling window
configuration.
This object can be further configured with aggregation functions
@@ -1166,6 +1200,8 @@ def tumbling_window(
dataframe=self,
name=name,
on_late=on_late,
+ before_update=before_update,
+ after_update=after_update,
)
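For illustration, a sketch of an early-expiration trigger on a tumbling window; the `flush` field and the trigger logic are hypothetical:

```python
from datetime import timedelta


def close_on_flush(aggregated, value, key, timestamp, headers) -> bool:
    # Hypothetical trigger: a record carrying {"flush": True} expires the
    # window early, before the record itself is added to it.
    return isinstance(value, dict) and value.get("flush", False)


sdf = (
    sdf.tumbling_window(timedelta(seconds=10), before_update=close_on_flush)
    .count()
    .final()
)
```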
def tumbling_count_window(
@@ -1225,6 +1261,8 @@ def hopping_window(
grace_ms: Union[int, timedelta] = 0,
name: Optional[str] = None,
on_late: Optional[WindowOnLateCallback] = None,
+ before_update: Optional[WindowBeforeUpdateCallback] = None,
+ after_update: Optional[WindowAfterUpdateCallback] = None,
) -> HoppingTimeWindowDefinition:
"""
Create a time-based hopping window transformation on this StreamingDataFrame.
@@ -1302,6 +1340,20 @@ def hopping_window(
(default behavior).
Otherwise, no message will be logged.
+ :param before_update: an optional callback to trigger early window expiration
+ before the window is updated.
+ The callback receives `aggregated` (current aggregated value or default/None),
+ `value`, `key`, `timestamp`, and `headers`.
+ If it returns `True`, the window will be expired immediately.
+ Default - `None`.
+
+ :param after_update: an optional callback to trigger early window expiration
+ after the window is updated.
+ The callback receives `aggregated` (updated aggregated value), `value`, `key`,
+ `timestamp`, and `headers`.
+ If it returns `True`, the window will be expired immediately.
+ Default - `None`.
+
:return: `HoppingTimeWindowDefinition` instance representing the hopping
window configuration.
This object can be further configured with aggregation functions
@@ -1319,6 +1371,8 @@ def hopping_window(
dataframe=self,
name=name,
on_late=on_late,
+ before_update=before_update,
+ after_update=after_update,
)
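Similarly, a sketch of `after_update` on a hopping `collect()` window; the 1000-item threshold is an assumption:

```python
from datetime import timedelta


def close_when_full(collected, value, key, timestamp, headers) -> bool:
    # For a collect() window the first argument is the collected list
    # (including the value just added); expire once it reaches 1000 items.
    return len(collected) >= 1000


sdf = (
    sdf.hopping_window(
        duration_ms=timedelta(minutes=1),
        step_ms=timedelta(seconds=10),
        after_update=close_when_full,
    )
    .collect()
    .final()
)
```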
def hopping_count_window(
@@ -1654,7 +1708,7 @@ def _sink_callback(
headers=headers,
partition=ctx.partition,
topic=ctx.topic,
- offset=ctx.offset,
+ offset=ctx.offset or 0,
)
# uses apply without returning to make this operation terminal
diff --git a/quixstreams/dataframe/registry.py b/quixstreams/dataframe/registry.py
index dd7138e0b..87f9441bc 100644
--- a/quixstreams/dataframe/registry.py
+++ b/quixstreams/dataframe/registry.py
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, Optional
-from quixstreams.core.stream import Stream, VoidExecutor
+from quixstreams.core.stream import Stream, StreamSink, VoidExecutor
from quixstreams.models import Topic
from .exceptions import (
@@ -105,9 +105,7 @@ def register_groupby(
"adjust by setting a unique name with `SDF.group_by(name=)` "
)
- def compose_all(
- self, sink: Optional[VoidExecutor] = None
- ) -> dict[str, VoidExecutor]:
+ def compose_all(self, sink: Optional[StreamSink] = None) -> dict[str, VoidExecutor]:
"""
Composes all the Streams and returns a dict of format {&lt;topic&gt;: &lt;composed executor&gt;}
:param sink: callable to accumulate the results of the execution, optional.
diff --git a/quixstreams/dataframe/windows/base.py b/quixstreams/dataframe/windows/base.py
index 8040b2774..a4e888eb5 100644
--- a/quixstreams/dataframe/windows/base.py
+++ b/quixstreams/dataframe/windows/base.py
@@ -18,7 +18,6 @@
from quixstreams.context import message_context
from quixstreams.core.stream import TransformExpandedCallback
-from quixstreams.core.stream.exceptions import InvalidOperation
from quixstreams.models.topics.manager import TopicManager
from quixstreams.state import WindowedPartitionTransaction
@@ -34,6 +33,8 @@
WindowResult: TypeAlias = dict[str, Any]
WindowKeyResult: TypeAlias = tuple[Any, WindowResult]
Message: TypeAlias = tuple[WindowResult, Any, int, Any]
+WindowBeforeUpdateCallback: TypeAlias = Callable[[Any, Any, Any, int, Any], bool]
+WindowAfterUpdateCallback: TypeAlias = Callable[[Any, Any, Any, int, Any], bool]
WindowAggregateFunc = Callable[[Any, Any], Any]
@@ -65,8 +66,17 @@ def process_window(
value: Any,
key: Any,
timestamp_ms: int,
+ headers: Any,
transaction: WindowedPartitionTransaction,
) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]:
+ """
+ Process a window update for the given value and key.
+
+ Returns:
+ A tuple of (updated_windows, triggered_windows) where:
+ - updated_windows: Windows that were updated but not expired
+ - triggered_windows: Windows that were expired early due to before_update/after_update callbacks
+ """
pass
def register_store(self) -> None:
@@ -81,24 +91,39 @@ def register_store(self) -> None:
def _apply_window(
self,
- func: TransformRecordCallbackExpandedWindowed,
+ on_update: TransformRecordCallbackExpandedWindowed,
name: str,
+ on_watermark: Optional[TransformRecordCallbackExpandedWindowed] = None,
) -> "StreamingDataFrame":
self.register_store()
windowed_func = _as_windowed(
- func=func,
+ func=on_update,
stream_id=self._dataframe.stream_id,
processing_context=self._dataframe.processing_context,
store_name=name,
)
+ if on_watermark:
+ watermark_func = _as_windowed(
+ func=on_watermark,
+ stream_id=self._dataframe.stream_id,
+ processing_context=self._dataframe.processing_context,
+ store_name=name,
+ allow_null_key=True,
+ )
+ else:
+ watermark_func = None
+
# Manually modify the Stream and clone the source StreamingDataFrame
# to avoid adding "transform" API to it.
# Transform callbacks can modify record key and timestamp,
# and it's prone to misuse.
- stream = self._dataframe.stream.add_transform(func=windowed_func, expand=True)
+ stream = self._dataframe.stream.add_transform(
+ func=windowed_func, expand=True, on_watermark=watermark_func
+ )
return self._dataframe.__dataframe_clone__(stream=stream)
+ @abstractmethod
def final(self) -> "StreamingDataFrame":
"""
Apply the window aggregation and return results only when the windows are
@@ -122,29 +147,9 @@ def final(self) -> "StreamingDataFrame":
If some message keys appear irregularly in the stream, the latest windows
can remain unprocessed until the message the same key is received.
"""
+ ...
- def window_callback(
- value: Any,
- key: Any,
- timestamp_ms: int,
- _headers: Any,
- transaction: WindowedPartitionTransaction,
- ) -> Iterable[Message]:
- _, expired_windows = self.process_window(
- value=value,
- key=key,
- timestamp_ms=timestamp_ms,
- transaction=transaction,
- )
- # Use window start timestamp as a new record timestamp
- for key, window in expired_windows:
- yield (window, key, window["start"], None)
-
- return self._apply_window(
- func=window_callback,
- name=self._name,
- )
-
+ @abstractmethod
def current(self) -> "StreamingDataFrame":
"""
Apply the window transformation to the StreamingDataFrame to return results
@@ -162,33 +167,7 @@ def current(self) -> "StreamingDataFrame":
This method processes streaming data and returns results as they come,
regardless of whether the window is closed or not.
"""
-
- if self.collect:
- raise InvalidOperation(
- "BaseCollectors are not supported by `current` windows"
- )
-
- def window_callback(
- value: Any,
- key: Any,
- timestamp_ms: int,
- _headers: Any,
- transaction: WindowedPartitionTransaction,
- ) -> Iterable[Message]:
- updated_windows, expired_windows = self.process_window(
- value=value, key=key, timestamp_ms=timestamp_ms, transaction=transaction
- )
-
- # loop over the expired_windows generator to ensure the windows
- # are expired
- for key, window in expired_windows:
- pass
-
- # Use window start timestamp as a new record timestamp
- for key, window in updated_windows:
- yield (window, key, window["start"], None)
-
- return self._apply_window(func=window_callback, name=self._name)
+ ...
# Implemented by SingleAggregationWindowMixin and MultiAggregationWindowMixin
# Single aggregation and multi aggregation windows store aggregations and collections
@@ -401,6 +380,7 @@ def _as_windowed(
processing_context: "ProcessingContext",
store_name: str,
stream_id: str,
+ allow_null_key: bool = False,
) -> TransformExpandedCallback:
@functools.wraps(func)
def wrapper(
@@ -413,7 +393,7 @@ def wrapper(
stream_id=stream_id, partition=ctx.partition, store_name=store_name
),
)
- if key is None:
+ if key is None and not allow_null_key:
logger.warning(
f"Skipping window processing for a message because the key is None, "
f"partition='{ctx.topic}[{ctx.partition}]' offset='{ctx.offset}'."
@@ -436,7 +416,7 @@ def __call__(
store_name: str,
topic: str,
partition: int,
- offset: int,
+ offset: Optional[int],
) -> bool: ...
diff --git a/quixstreams/dataframe/windows/count_based.py b/quixstreams/dataframe/windows/count_based.py
index 57c6b36e5..6afeeb0d9 100644
--- a/quixstreams/dataframe/windows/count_based.py
+++ b/quixstreams/dataframe/windows/count_based.py
@@ -1,9 +1,11 @@
import logging
from typing import TYPE_CHECKING, Any, Iterable, Optional, TypedDict, Union, cast
+from quixstreams.core.stream import InvalidOperation
from quixstreams.state import WindowedPartitionTransaction
from .base import (
+ Message,
MultiAggregationWindowMixin,
SingleAggregationWindowMixin,
Window,
@@ -53,11 +55,108 @@ def __init__(
self._max_count = count
self._step = step
+ def final(self) -> "StreamingDataFrame":
+ """
+ Apply the window aggregation and return results only when the windows are
+ closed.
+
+ The format of returned windows:
+ ```python
+ {
+ "start": ,
+ "end": ,
+ "value: ,
+ }
+ ```
+
+ The individual window is closed when the event time
+ (the maximum observed timestamp across the partition) passes
+ its end timestamp + grace period.
+ The closed windows cannot receive updates anymore and are considered final.
+
+ >***NOTE:*** Windows can be closed only within the same message key.
+ If some message keys appear irregularly in the stream, the latest windows
+ can remain unprocessed until a message with the same key is received.
+ """
+
+ def window_callback(
+ value: Any,
+ key: Any,
+ timestamp_ms: int,
+ _headers: Any,
+ transaction: WindowedPartitionTransaction,
+ ) -> Iterable[Message]:
+ _, expired_windows = self.process_window(
+ value=value,
+ key=key,
+ timestamp_ms=timestamp_ms,
+ headers=_headers,
+ transaction=transaction,
+ )
+ # Use window start timestamp as a new record timestamp
+ for key, window in expired_windows:
+ yield window, key, window["start"], None
+
+ return self._apply_window(
+ on_update=window_callback,
+ name=self._name,
+ )
+
+ def current(self) -> "StreamingDataFrame":
+ """
+ Apply the window transformation to the StreamingDataFrame to return results
+ for each updated window.
+
+ The format of returned windows:
+ ```python
+ {
+ "start": ,
+ "end": ,
+ "value: ,
+ }
+ ```
+
+ This method processes streaming data and returns results as they come,
+ regardless of whether the window is closed or not.
+ """
+
+ if self.collect:
+ raise InvalidOperation(
+ "BaseCollectors are not supported by `current` windows"
+ )
+
+ def window_callback(
+ value: Any,
+ key: Any,
+ timestamp_ms: int,
+ _headers: Any,
+ transaction: WindowedPartitionTransaction,
+ ) -> Iterable[Message]:
+ updated_windows, expired_windows = self.process_window(
+ value=value,
+ key=key,
+ timestamp_ms=timestamp_ms,
+ headers=_headers,
+ transaction=transaction,
+ )
+
+ # loop over the expired_windows generator to ensure the windows
+ # are expired
+ for key, window in expired_windows:
+ pass
+
+ # Use window start timestamp as a new record timestamp
+ for key, window in updated_windows:
+ yield window, key, window["start"], None
+
+ return self._apply_window(on_update=window_callback, name=self._name)
+
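A brief usage sketch of the relocated methods; numeric message values are assumed for `sum()`:

```python
# Emit a result only when each 3-message count window is complete...
closed = sdf.tumbling_count_window(count=3).sum().final()

# ...or emit the running aggregate on every update.
running = sdf.tumbling_count_window(count=3).sum().current()
```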
def process_window(
self,
value: Any,
key: Any,
timestamp_ms: int,
+ headers: Any,
transaction: WindowedPartitionTransaction[str, CountWindowsData],
) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]:
"""
@@ -78,7 +177,7 @@ def process_window(
next free msg id is 35 (32 + 3).
For tumbling windows there is no window overlap so we can't rely on that
- optimisation. Instead the msg id reset to 0 on every new window.
+ optimisation. Instead, the msg id is reset to 0 on every new window.
"""
state = transaction.as_state(prefix=key)
data = state.get(key=self.STATE_KEY, default=CountWindowsData(windows=[]))
diff --git a/quixstreams/dataframe/windows/definitions.py b/quixstreams/dataframe/windows/definitions.py
index 90d4d815b..20e2ce944 100644
--- a/quixstreams/dataframe/windows/definitions.py
+++ b/quixstreams/dataframe/windows/definitions.py
@@ -15,6 +15,8 @@
)
from .base import (
Window,
+ WindowAfterUpdateCallback,
+ WindowBeforeUpdateCallback,
WindowOnLateCallback,
)
from .count_based import (
@@ -54,11 +56,15 @@ def __init__(
name: Optional[str],
dataframe: "StreamingDataFrame",
on_late: Optional[WindowOnLateCallback] = None,
+ before_update: Optional[WindowBeforeUpdateCallback] = None,
+ after_update: Optional[WindowAfterUpdateCallback] = None,
) -> None:
super().__init__()
self._name = name
self._on_late = on_late
+ self._before_update = before_update
+ self._after_update = after_update
self._dataframe = dataframe
@abstractmethod
@@ -239,6 +245,8 @@ def __init__(
name: Optional[str] = None,
step_ms: Optional[int] = None,
on_late: Optional[WindowOnLateCallback] = None,
+ before_update: Optional[WindowBeforeUpdateCallback] = None,
+ after_update: Optional[WindowAfterUpdateCallback] = None,
):
if not isinstance(duration_ms, int):
raise TypeError("Window size must be an integer")
@@ -253,7 +261,7 @@ def __init__(
f"got {step_ms}ms"
)
- super().__init__(name, dataframe, on_late)
+ super().__init__(name, dataframe, on_late, before_update, after_update)
self._duration_ms = duration_ms
self._grace_ms = grace_ms
@@ -281,6 +289,8 @@ def __init__(
dataframe: "StreamingDataFrame",
name: Optional[str] = None,
on_late: Optional[WindowOnLateCallback] = None,
+ before_update: Optional[WindowBeforeUpdateCallback] = None,
+ after_update: Optional[WindowAfterUpdateCallback] = None,
):
super().__init__(
duration_ms=duration_ms,
@@ -289,6 +299,8 @@ def __init__(
name=name,
step_ms=step_ms,
on_late=on_late,
+ before_update=before_update,
+ after_update=after_update,
)
def _get_name(self, func_name: Optional[str]) -> str:
@@ -320,6 +332,8 @@ def _create_window(
aggregators=aggregators or {},
collectors=collectors or {},
on_late=self._on_late,
+ before_update=self._before_update,
+ after_update=self._after_update,
)
@@ -331,6 +345,8 @@ def __init__(
dataframe: "StreamingDataFrame",
name: Optional[str] = None,
on_late: Optional[WindowOnLateCallback] = None,
+ before_update: Optional[WindowBeforeUpdateCallback] = None,
+ after_update: Optional[WindowAfterUpdateCallback] = None,
):
super().__init__(
duration_ms=duration_ms,
@@ -338,6 +354,8 @@ def __init__(
dataframe=dataframe,
name=name,
on_late=on_late,
+ before_update=before_update,
+ after_update=after_update,
)
def _get_name(self, func_name: Optional[str]) -> str:
@@ -368,6 +386,8 @@ def _create_window(
aggregators=aggregators or {},
collectors=collectors or {},
on_late=self._on_late,
+ before_update=self._before_update,
+ after_update=self._after_update,
)
@@ -379,13 +399,22 @@ def __init__(
dataframe: "StreamingDataFrame",
name: Optional[str] = None,
on_late: Optional[WindowOnLateCallback] = None,
+ before_update: Optional[WindowBeforeUpdateCallback] = None,
+ after_update: Optional[WindowAfterUpdateCallback] = None,
):
+ if before_update is not None or after_update is not None:
+ raise ValueError(
+ "Sliding windows do not support trigger callbacks (before_update/after_update). "
+ "Use tumbling or hopping windows instead."
+ )
super().__init__(
duration_ms=duration_ms,
grace_ms=grace_ms,
dataframe=dataframe,
name=name,
on_late=on_late,
+ before_update=before_update,
+ after_update=after_update,
)
def _get_name(self, func_name: Optional[str]) -> str:
@@ -417,6 +446,8 @@ def _create_window(
aggregators=aggregators or {},
collectors=collectors or {},
on_late=self._on_late,
+ before_update=self._before_update,
+ after_update=self._after_update,
)
diff --git a/quixstreams/dataframe/windows/sliding.py b/quixstreams/dataframe/windows/sliding.py
index d3dfdbb39..3a4e6f692 100644
--- a/quixstreams/dataframe/windows/sliding.py
+++ b/quixstreams/dataframe/windows/sliding.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Iterable
+from typing import Any, Iterable
from quixstreams.state import WindowedPartitionTransaction, WindowedState
@@ -7,34 +7,16 @@
SingleAggregationWindowMixin,
WindowKeyResult,
)
-from .time_based import ClosingStrategyValues, TimeWindow
-
-if TYPE_CHECKING:
- from quixstreams.dataframe.dataframe import StreamingDataFrame
+from .time_based import TimeWindow
class SlidingWindow(TimeWindow):
- def final(
- self, closing_strategy: ClosingStrategyValues = "key"
- ) -> "StreamingDataFrame":
- if closing_strategy != "key":
- raise TypeError("Sliding window only support the 'key' closing strategy")
-
- return super().final(closing_strategy=closing_strategy)
-
- def current(
- self, closing_strategy: ClosingStrategyValues = "key"
- ) -> "StreamingDataFrame":
- if closing_strategy != "key":
- raise TypeError("Sliding window only support the 'key' closing strategy")
-
- return super().current(closing_strategy=closing_strategy)
-
def process_window(
self,
value: Any,
key: Any,
timestamp_ms: int,
+ headers: Any,
transaction: WindowedPartitionTransaction,
) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]:
"""
@@ -88,11 +70,10 @@ def process_window(
# Sliding windows are inclusive on both ends, so values with
# timestamps equal to latest_timestamp - duration - grace
# are still eligible for processing.
- state_ts = state.get_latest_timestamp() or 0
- latest_timestamp = max(timestamp_ms, state_ts)
- max_expired_window_end = latest_timestamp - grace - 1
+ max_expired_window_end = max(
+ timestamp_ms - grace - 1, transaction.get_latest_expired(prefix=b"")
+ )
max_expired_window_start = max_expired_window_end - duration
- max_deleted_window_start = max_expired_window_start - duration
left_start = max(0, timestamp_ms - duration)
left_end = timestamp_ms
@@ -104,7 +85,7 @@ def process_window(
start=left_start,
end=left_end,
timestamp_ms=timestamp_ms,
- late_by_ms=max_expired_window_end + 1 - timestamp_ms,
+ late_by_ms=max_expired_window_end - timestamp_ms,
)
return [], []
@@ -112,7 +93,7 @@ def process_window(
right_end = right_start + duration
right_exists = False
- starts = set([left_start])
+ starts = {left_start}
updated_windows: list[WindowKeyResult] = []
iterated_windows = state.get_windows(
# start_from_ms is exclusive, hence -1
@@ -252,7 +233,6 @@ def process_window(
# At this point, this is the last window that will ever be considered
# for existing aggregations. Windows lower than this and lower than
# the expiration watermark may be deleted.
- max_deleted_window_start = min(start - 1, max_expired_window_start)
break
else:
@@ -276,30 +256,39 @@ def process_window(
if collect:
state.add_to_collection(value=self._collect_value(value), id=timestamp_ms)
- # build a complete list otherwise expired windows could be deleted
- # in state.delete_windows() and never be fetched.
- expired_windows = list(
- self._expired_windows(key, state, max_expired_window_start, collect)
- )
+ # Sliding windows don't support before_update/after_update callbacks yet,
+ # so triggered_windows is always empty
+ return reversed(updated_windows), []
- state.delete_windows(
- max_start_time=max_deleted_window_start,
- delete_values=collect,
- )
-
- return reversed(updated_windows), expired_windows
-
- def _expired_windows(self, key, state, max_expired_window_start, collect):
- for window in state.expire_windows(
- max_start_time=max_expired_window_start,
+ def expire_by_partition(
+ self,
+ transaction: WindowedPartitionTransaction,
+ timestamp_ms: int,
+ ) -> Iterable[WindowKeyResult]:
+ latest_expired_window_end = transaction.get_latest_expired(prefix=b"")
+ latest_timestamp = max(timestamp_ms, latest_expired_window_end)
+ # Subtract 1 because sliding windows are inclusive on the end
+ max_expired_window_end = latest_timestamp - self._grace_ms - 1
+
+ # First, expire and return windows without deleting them.
+ # Sliding windows use previous updates to calculate the new state.
+ for window in transaction.expire_all_windows(
+ max_end_time=max_expired_window_end,
+ step_ms=1, # step is 1ms because sliding windows don't have fixed boundaries
+ collect=self.collect,
delete=False,
- collect=collect,
end_inclusive=True,
):
- (start, end), (max_timestamp, aggregated), collected, _ = window
+ (start, end), (max_timestamp, aggregated), collected, key = window
if end == max_timestamp:
yield key, self._results(aggregated, collected, start, end)
+ # Second, delete all windows that can't be used by the sliding windows anymore.
+ transaction.delete_all_windows(
+ max_end_time=max_expired_window_end - self._duration_ms,
+ collect=self.collect,
+ )
+
def _update_window(
self,
key: bytes,
@@ -316,7 +305,7 @@ def _update_window(
value=[max_timestamp, value],
timestamp_ms=timestamp,
)
- return (key, self._results(value, [], start, end))
+ return key, self._results(value, [], start, end)
class SlidingWindowSingleAggregation(SingleAggregationWindowMixin, SlidingWindow):
diff --git a/quixstreams/dataframe/windows/time_based.py b/quixstreams/dataframe/windows/time_based.py
index c403cfdfa..21ee68c4b 100644
--- a/quixstreams/dataframe/windows/time_based.py
+++ b/quixstreams/dataframe/windows/time_based.py
@@ -1,14 +1,17 @@
import logging
-from enum import Enum
-from typing import TYPE_CHECKING, Any, Iterable, Literal, Optional
+from typing import TYPE_CHECKING, Any, Iterable, Optional
from quixstreams.context import message_context
-from quixstreams.state import WindowedPartitionTransaction, WindowedState
+from quixstreams.state import WindowedPartitionTransaction
+from quixstreams.utils.format import format_timestamp
from .base import (
+ Message,
MultiAggregationWindowMixin,
SingleAggregationWindowMixin,
Window,
+ WindowAfterUpdateCallback,
+ WindowBeforeUpdateCallback,
WindowKeyResult,
WindowOnLateCallback,
get_window_ranges,
@@ -20,23 +23,6 @@
logger = logging.getLogger(__name__)
-class ClosingStrategy(Enum):
- KEY = "key"
- PARTITION = "partition"
-
- @classmethod
- def new(cls, value: str) -> "ClosingStrategy":
- try:
- return ClosingStrategy[value.upper()]
- except KeyError:
- raise TypeError(
- 'closing strategy must be one of "key" or "partition'
- ) from None
-
-
-ClosingStrategyValues = Literal["key", "partition"]
-
-
class TimeWindow(Window):
def __init__(
self,
@@ -46,6 +32,8 @@ def __init__(
dataframe: "StreamingDataFrame",
step_ms: Optional[int] = None,
on_late: Optional[WindowOnLateCallback] = None,
+ before_update: Optional[WindowBeforeUpdateCallback] = None,
+ after_update: Optional[WindowAfterUpdateCallback] = None,
):
super().__init__(
name=name,
@@ -56,12 +44,10 @@ def __init__(
self._grace_ms = grace_ms
self._step_ms = step_ms
self._on_late = on_late
+ self._before_update = before_update
+ self._after_update = after_update
- self._closing_strategy = ClosingStrategy.KEY
-
- def final(
- self, closing_strategy: ClosingStrategyValues = "key"
- ) -> "StreamingDataFrame":
+ def final(self) -> "StreamingDataFrame":
"""
Apply the window aggregation and return results only when the windows are
closed.
@@ -80,20 +66,59 @@ def final(
its end timestamp + grace period.
The closed windows cannot receive updates anymore and are considered final.
- :param closing_strategy: the strategy to use when closing windows.
- Possible values:
- - `"key"` - messages advance time and close windows with the same key.
- If some message keys appear irregularly in the stream, the latest windows can remain unprocessed until a message with the same key is received.
- - `"partition"` - messages advance time and close windows for the whole partition to which this message key belongs.
- If timestamps between keys are not ordered, it may increase the number of discarded late messages.
- Default - `"key"`.
"""
- self._closing_strategy = ClosingStrategy.new(closing_strategy)
- return super().final()
- def current(
- self, closing_strategy: ClosingStrategyValues = "key"
- ) -> "StreamingDataFrame":
+ def on_update(
+ value: Any,
+ key: Any,
+ timestamp_ms: int,
+ _headers: Any,
+ transaction: WindowedPartitionTransaction,
+ ):
+ # Process the window and get windows triggered from callbacks
+ _, triggered_windows = self.process_window(
+ value=value,
+ key=key,
+ timestamp_ms=timestamp_ms,
+ headers=_headers,
+ transaction=transaction,
+ )
+ # Yield triggered windows (from before_update/after_update callbacks)
+ for key, window in triggered_windows:
+ yield window, key, window["start"], None
+
+ def on_watermark(
+ _value: Any,
+ _key: Any,
+ timestamp_ms: int,
+ _headers: Any,
+ transaction: WindowedPartitionTransaction,
+ ) -> Iterable[Message]:
+ expired_windows = self.expire_by_partition(
+ transaction=transaction, timestamp_ms=timestamp_ms
+ )
+
+ total_expired = 0
+ # Use window start timestamp as a new record timestamp
+ for key, window in expired_windows:
+ total_expired += 1
+ yield window, key, window["start"], None
+
+ ctx = message_context()
+ logger.info(
+ f"Expired {total_expired} windows after processing "
+ f"the watermark at {format_timestamp(timestamp_ms)}. "
+ f"window_name={self._name} topic={ctx.topic} "
+ f"partition={ctx.partition} timestamp={timestamp_ms}"
+ )
+
+ return self._apply_window(
+ on_update=on_update,
+ on_watermark=on_watermark,
+ name=self._name,
+ )
+
+ def current(self) -> "StreamingDataFrame":
"""
Apply the window transformation to the StreamingDataFrame to return results
for each updated window.
@@ -109,29 +134,72 @@ def current(
This method processes streaming data and returns results as they come,
regardless of whether the window is closed or not.
-
- :param closing_strategy: the strategy to use when closing windows.
- Possible values:
- - `"key"` - messages advance time and close windows with the same key.
- If some message keys appear irregularly in the stream, the latest windows can remain unprocessed until a message with the same key is received.
- - `"partition"` - messages advance time and close windows for the whole partition to which this message key belongs.
- If timestamps between keys are not ordered, it may increase the number of discarded late messages.
- Default - `"key"`.
"""
- self._closing_strategy = ClosingStrategy.new(closing_strategy)
- return super().current()
+ def on_update(
+ value: Any,
+ key: Any,
+ timestamp_ms: int,
+ _headers: Any,
+ transaction: WindowedPartitionTransaction,
+ ):
+ # Process the window and get both updated and triggered windows
+ updated_windows, triggered_windows = self.process_window(
+ value=value,
+ key=key,
+ timestamp_ms=timestamp_ms,
+ headers=_headers,
+ transaction=transaction,
+ )
+ # Use window start timestamp as a new record timestamp
+ # Yield both updated and triggered windows
+ for key, window in updated_windows:
+ yield window, key, window["start"], None
+ for key, window in triggered_windows:
+ yield window, key, window["start"], None
+
+ def on_watermark(
+ _value: Any,
+ _key: Any,
+ timestamp_ms: int,
+ _headers: Any,
+ transaction: WindowedPartitionTransaction,
+ ) -> Iterable[Message]:
+ expired_windows = self.expire_by_partition(
+ transaction=transaction, timestamp_ms=timestamp_ms
+ )
+ # Just exhaust the iterator here
+ for _ in expired_windows:
+ pass
+ return []
+
+ return self._apply_window(
+ on_update=on_update,
+ on_watermark=on_watermark,
+ name=self._name,
+ )
def process_window(
self,
value: Any,
key: Any,
timestamp_ms: int,
+ headers: Any,
transaction: WindowedPartitionTransaction,
) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]:
+ """
+ Process a window update for the given value and key.
+
+ Returns:
+ A tuple of (updated_windows, triggered_windows) where:
+ - updated_windows: Windows that were updated but not expired
+ - triggered_windows: Windows that were expired early due to before_update/after_update callbacks
+ """
state = transaction.as_state(prefix=key)
duration_ms = self._duration_ms
grace_ms = self._grace_ms
+ before_update = self._before_update
+ after_update = self._after_update
collect = self.collect
aggregate = self.aggregate
@@ -142,16 +210,13 @@ def process_window(
step_ms=self._step_ms,
)
- if self._closing_strategy == ClosingStrategy.PARTITION:
- latest_expired_window_end = transaction.get_latest_expired(prefix=b"")
- latest_timestamp = max(timestamp_ms, latest_expired_window_end)
- else:
- state_ts = state.get_latest_timestamp() or 0
- latest_timestamp = max(timestamp_ms, state_ts)
+ latest_expired_window_end = transaction.get_latest_expired(prefix=b"")
+ latest_timestamp = max(timestamp_ms, latest_expired_window_end)
max_expired_window_end = latest_timestamp - grace_ms
max_expired_window_start = max_expired_window_end - duration_ms
updated_windows: list[WindowKeyResult] = []
+ triggered_windows: list[WindowKeyResult] = []
for start, end in ranges:
if start <= max_expired_window_start:
late_by_ms = max_expired_window_end - timestamp_ms
@@ -169,18 +234,78 @@ def process_window(
# since actual values are stored separately and combined into an array
# during window expiration.
aggregated = None
+
if aggregate:
current_value = state.get_window(start, end)
if current_value is None:
current_value = self._initialize_value()
+ # Check before_update trigger
+ if before_update and before_update(
+ current_value, value, key, timestamp_ms, headers
+ ):
+ # Get collected values for the result
+ # Do NOT include the current value - before_update means
+ # we expire BEFORE adding the current value
+ collected = state.get_from_collection(start, end) if collect else []
+
+ result = self._results(current_value, collected, start, end)
+ triggered_windows.append((key, result))
+ transaction.delete_window(start, end, prefix=key)
+ # Note: We don't delete from collection here - normal expiration
+ # will handle cleanup for both tumbling and hopping windows
+ continue
+
aggregated = self._aggregate_value(current_value, value, timestamp_ms)
- updated_windows.append(
- (
- key,
- self._results(aggregated, [], start, end),
- )
- )
+
+ # Check after_update trigger
+ if after_update and after_update(
+ aggregated, value, key, timestamp_ms, headers
+ ):
+ # Get collected values for the result
+ collected = []
+ if collect:
+ collected = state.get_from_collection(start, end)
+ # Add the current value that's being collected
+ collected.append(self._collect_value(value))
+
+ result = self._results(aggregated, collected, start, end)
+ triggered_windows.append((key, result))
+ transaction.delete_window(start, end, prefix=key)
+ # Note: We don't delete from collection here - normal expiration
+ # will handle cleanup for both tumbling and hopping windows
+ continue
+
+ result = self._results(aggregated, [], start, end)
+ updated_windows.append((key, result))
+ elif collect and (before_update or after_update):
+ # For collect-only windows, get the old collected values
+ old_collected = state.get_from_collection(start, end)
+
+ # Check before_update trigger (before adding new value)
+ if before_update and before_update(
+ old_collected, value, key, timestamp_ms, headers
+ ):
+ # Expire with the current collection (WITHOUT the new value)
+ result = self._results(None, old_collected, start, end)
+ triggered_windows.append((key, result))
+ transaction.delete_window(start, end, prefix=key)
+ # Note: We don't delete from collection here - normal expiration
+ # will handle cleanup for both tumbling and hopping windows
+ continue
+
+ # Check after_update trigger (conceptually after adding new value)
+ # For collect, "after update" means after the value would be added
+ if after_update:
+ new_collected = [*old_collected, self._collect_value(value)]
+ if after_update(new_collected, value, key, timestamp_ms, headers):
+ result = self._results(None, new_collected, start, end)
+ triggered_windows.append((key, result))
+ transaction.delete_window(start, end, prefix=key)
+ # Note: We don't delete from collection here - normal expiration
+ # will handle cleanup for both tumbling and hopping windows
+ continue
+
state.update_window(start, end, value=aggregated, timestamp_ms=timestamp_ms)
if collect:
@@ -189,50 +314,34 @@ def process_window(
id=timestamp_ms,
)
- if self._closing_strategy == ClosingStrategy.PARTITION:
- expired_windows = self.expire_by_partition(
- transaction, max_expired_window_end, collect
- )
- else:
- expired_windows = self.expire_by_key(
- key, state, max_expired_window_start, collect
- )
-
- return updated_windows, expired_windows
+ return updated_windows, triggered_windows
def expire_by_partition(
self,
transaction: WindowedPartitionTransaction,
- max_expired_end: int,
- collect: bool,
+ timestamp_ms: int,
) -> Iterable[WindowKeyResult]:
+ """
+ Expire windows for the whole partition at the given timestamp.
+
+ :param transaction: state transaction object.
+ :param timestamp_ms: the current timestamp (inclusive).
+ """
+ latest_expired_window_end = transaction.get_latest_expired(prefix=b"")
+ latest_timestamp = max(timestamp_ms, latest_expired_window_end)
+ max_expired_window_end = max(latest_timestamp - self._grace_ms, 0)
+
for (
window_start,
window_end,
), aggregated, collected, key in transaction.expire_all_windows(
- max_end_time=max_expired_end,
+ max_end_time=max_expired_window_end,
step_ms=self._step_ms if self._step_ms else self._duration_ms,
- collect=collect,
+ collect=self.collect,
delete=True,
):
yield key, self._results(aggregated, collected, window_start, window_end)
- def expire_by_key(
- self,
- key: Any,
- state: WindowedState,
- max_expired_start: int,
- collect: bool,
- ) -> Iterable[WindowKeyResult]:
- for (
- window_start,
- window_end,
- ), aggregated, collected, _ in state.expire_windows(
- max_start_time=max_expired_start,
- collect=collect,
- ):
- yield (key, self._results(aggregated, collected, window_start, window_end))
-
def _on_expired_window(
self,
value: Any,
@@ -261,13 +370,12 @@ def _on_expired_window(
)
if to_log:
logger.warning(
- "Skipping window processing for the closed window "
- f"timestamp_ms={timestamp_ms} "
- f"window={(start, end)} "
- f"late_by_ms={late_by_ms} "
+ "Skipping record processing for the closed window. "
+ f"timestamp_ms={format_timestamp(timestamp_ms)} ({timestamp_ms}ms) "
+ f"window=[{format_timestamp(start)}, {format_timestamp(end)}) ([{start}ms, {end}ms)) "
+ f"late_by={late_by_ms}ms "
f"store_name={self._name} "
- f"partition={ctx.topic}[{ctx.partition}] "
- f"offset={ctx.offset}"
+ f"partition={ctx.topic}[{ctx.partition}]"
)
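
The `before_update`/`after_update` branches above all follow the same pattern: decide whether the incoming record closes the window before or after it is folded in, and route the result to `triggered_windows` instead of `updated_windows`. A minimal, self-contained sketch of that routing (plain illustrative Python; the real logic is the `process_window` code in the hunks above, whose callbacks also receive key, timestamp, and headers):

```python
# Illustrative sketch of the trigger routing in process_window (not library code).
def route_window(old_agg, new_value, aggregate_fn, before_update=None, after_update=None):
    # before_update fires BEFORE the incoming value is applied,
    # so the emitted result excludes it.
    if before_update and before_update(old_agg, new_value):
        return "triggered", old_agg

    new_agg = aggregate_fn(old_agg, new_value)

    # after_update fires AFTER the incoming value is applied,
    # so the emitted result includes it.
    if after_update and after_update(new_agg, new_value):
        return "triggered", new_agg

    return "updated", new_agg


# Example: close a sum window early once the aggregate reaches 10.
print(route_window(7, 5, lambda agg, v: agg + v, after_update=lambda agg, _: agg >= 10))
# -> ('triggered', 12)
```
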
diff --git a/quixstreams/internal_producer.py b/quixstreams/internal_producer.py
index 42a0d461b..c322bbdeb 100644
--- a/quixstreams/internal_producer.py
+++ b/quixstreams/internal_producer.py
@@ -315,7 +315,8 @@ def commit_transaction(
group_metadata: GroupMetadata,
timeout: Optional[float] = None,
):
- self._send_offsets_to_transaction(positions, group_metadata, timeout)
+ if positions:
+ self._send_offsets_to_transaction(positions, group_metadata, timeout)
self._commit_transaction(timeout)
def __enter__(self):
diff --git a/quixstreams/models/messagecontext.py b/quixstreams/models/messagecontext.py
index 351fe9157..c672d0ce8 100644
--- a/quixstreams/models/messagecontext.py
+++ b/quixstreams/models/messagecontext.py
@@ -22,8 +22,8 @@ def __init__(
self,
topic: str,
partition: int,
- offset: int,
size: int,
+ offset: Optional[int] = None,
leader_epoch: Optional[int] = None,
):
self._topic = topic
@@ -41,7 +41,7 @@ def partition(self) -> int:
return self._partition
@property
- def offset(self) -> int:
+ def offset(self) -> Optional[int]:
return self._offset
@property
diff --git a/quixstreams/models/rows.py b/quixstreams/models/rows.py
index a618ee7d0..9e769da3a 100644
--- a/quixstreams/models/rows.py
+++ b/quixstreams/models/rows.py
@@ -36,7 +36,7 @@ def partition(self) -> int:
return self.context.partition
@property
- def offset(self) -> int:
+ def offset(self) -> Optional[int]:
return self.context.offset
@property
diff --git a/quixstreams/models/topics/manager.py b/quixstreams/models/topics/manager.py
index 56780352f..881528c72 100644
--- a/quixstreams/models/topics/manager.py
+++ b/quixstreams/models/topics/manager.py
@@ -60,6 +60,7 @@ def __init__(
self._consumer_group = consumer_group
self._regular_topics: Dict[str, Topic] = {}
self._repartition_topics: Dict[str, Topic] = {}
+ self._watermarks_topics: Dict[str, Topic] = {}
self._changelog_topics: Dict[Optional[str], Dict[str, Topic]] = {}
self._timeout = timeout
self._create_timeout = create_timeout
@@ -284,6 +285,30 @@ def changelog_topic(
self._changelog_topics.setdefault(stream_id, {})[store_name] = topic
return topic
+ def watermarks_topic(self):
+ """
+ The topic to be used to share watermarks across the application instances.
+ It is always prefixed with the consumer group name,
+ and it has only a single partition.
+ """
+ topic = Topic(
+ name=self._internal_name("watermarks", None, "watermarks"),
+ value_deserializer="json",
+ key_deserializer="str",
+ value_serializer="json",
+ key_serializer="str",
+ create_config=TopicConfig(
+                num_partitions=1,  # The watermarks topic always has a single partition
+ replication_factor=self.default_replication_factor,
+ extra_config={"cleanup.policy": "compact,delete"},
+ ),
+ topic_type=TopicType.WATERMARKS,
+ )
+ broker_topic = self._get_or_create_broker_topic(topic)
+ topic = self._configure_topic(topic, broker_topic)
+ self._watermarks_topics[topic.name] = topic
+ return topic
+
@classmethod
def derive_topic_config(cls, topics: Iterable[Topic]) -> TopicConfig:
"""
@@ -437,7 +462,7 @@ def _format_nested_name(self, topic_name: str) -> str:
def _internal_name(
self,
- topic_type: Literal["changelog", "repartition"],
+ topic_type: Literal["changelog", "repartition", "watermarks"],
topic_name: Optional[str],
suffix: str,
) -> str:
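
The key properties of the new watermarks topic are easiest to see in isolation: a single partition so every application instance observes every watermark update, and a compact-plus-delete cleanup policy so the topic stays small while the latest watermark per key survives. A standalone sketch of that configuration (assumed to mirror the `TopicConfig` built inside `watermarks_topic()` above; the replication factor here is a placeholder):

```python
from quixstreams.models import TopicConfig

# Assumed shape of the watermarks topic configuration (see watermarks_topic() above).
watermarks_config = TopicConfig(
    num_partitions=1,      # every instance must see all watermark updates
    replication_factor=1,  # placeholder; the real value is the TopicManager default
    extra_config={"cleanup.policy": "compact,delete"},
)
print(watermarks_config)
```
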
diff --git a/quixstreams/models/topics/topic.py b/quixstreams/models/topics/topic.py
index b50e9245a..e16cbcaf1 100644
--- a/quixstreams/models/topics/topic.py
+++ b/quixstreams/models/topics/topic.py
@@ -93,6 +93,7 @@ class TopicType(enum.Enum):
REGULAR = 1
REPARTITION = 2
CHANGELOG = 3
+ WATERMARKS = 4
class Topic:
diff --git a/quixstreams/platforms/quix/topic_manager.py b/quixstreams/platforms/quix/topic_manager.py
index 18a0e87b5..d43906650 100644
--- a/quixstreams/platforms/quix/topic_manager.py
+++ b/quixstreams/platforms/quix/topic_manager.py
@@ -128,7 +128,7 @@ def _create_topic(self, topic: Topic, timeout: float, create_timeout: float):
def _internal_name(
self,
- topic_type: Literal["changelog", "repartition"],
+ topic_type: Literal["changelog", "repartition", "watermarks"],
topic_name: Optional[str],
suffix: str,
):
diff --git a/quixstreams/processing/context.py b/quixstreams/processing/context.py
index fa0c55320..8c348941d 100644
--- a/quixstreams/processing/context.py
+++ b/quixstreams/processing/context.py
@@ -8,6 +8,7 @@
from quixstreams.exceptions import QuixException
from quixstreams.internal_consumer import InternalConsumer
from quixstreams.internal_producer import InternalProducer
+from quixstreams.processing.watermarking import WatermarkManager
from quixstreams.sinks import SinkManager
from quixstreams.state import StateStoreManager
from quixstreams.utils.printing import Printer
@@ -33,6 +34,7 @@ class ProcessingContext:
state_manager: StateStoreManager
sink_manager: SinkManager
dataframe_registry: DataFrameRegistry
+ watermark_manager: WatermarkManager
commit_every: int = 0
exactly_once: bool = False
printer: Printer = Printer()
diff --git a/quixstreams/processing/watermarking.py b/quixstreams/processing/watermarking.py
new file mode 100644
index 000000000..b777af432
--- /dev/null
+++ b/quixstreams/processing/watermarking.py
@@ -0,0 +1,157 @@
+import logging
+from time import monotonic
+from typing import Optional, TypedDict
+
+from quixstreams.internal_producer import InternalProducer
+from quixstreams.models import Topic
+from quixstreams.models.topics.manager import TopicManager
+from quixstreams.utils.format import format_timestamp
+from quixstreams.utils.json import dumps
+
+logger = logging.getLogger(__name__)
+
+__all__ = ("WatermarkManager", "WatermarkMessage")
+
+
+class WatermarkMessage(TypedDict):
+ topic: str
+ partition: int
+ timestamp: int
+
+
+class WatermarkManager:
+ def __init__(
+ self,
+ producer: InternalProducer,
+ topic_manager: TopicManager,
+ interval: float = 1.0,
+ ):
+ self._interval = interval
+ self._last_produced = 0
+ self._watermarks: dict[tuple[str, int], int] = {}
+ self._producer = producer
+ self._topic_manager = topic_manager
+ self._watermarks_topic: Optional[Topic] = None
+ self._to_produce: dict[tuple[str, int], tuple[int, bool]] = {}
+
+ def set_topics(self, topics: list[Topic]):
+ """
+ Set topics to be used as sources of watermarks
+ (normally, topics consumed by the application).
+
+ This method must be called before processing the watermarks.
+ It will clear the existing TP watermarks and prime the internal
+ state to know which partitions the app is expected to consume.
+ """
+ # Prime the watermarks with -1 for each expected topic partition
+        # to make sure we have all TP watermarks before calculating the main watermark.
+
+ self._watermarks = {
+ (topic.name, partition): -1
+ for topic in topics
+ for partition in range(topic.broker_config.num_partitions or 1)
+ }
+
+ @property
+ def watermarks_topic(self) -> Topic:
+ """
+        A topic with watermark updates.
+ """
+ if self._watermarks_topic is None:
+ self._watermarks_topic = self._topic_manager.watermarks_topic()
+ return self._watermarks_topic
+
+ def on_revoke(self, topic: str, partition: int):
+ """
+ Remove the TP from tracking (e.g. when partition is revoked).
+ """
+ tp = (topic, partition)
+ self._to_produce.pop(tp, None)
+
+ def store(self, topic: str, partition: int, timestamp: int, default: bool):
+ """
+ Store the new watermark.
+
+ :param topic: topic name.
+ :param partition: partition number.
+ :param timestamp: watermark timestamp.
+ :param default: whether the watermark is set by the default mechanism
+ (i.e. extracted from the Kafka message timestamp or via Topic `timestamp_extractor`).
+ Non-default watermarks always override the defaults.
+ Default watermarks never override the non-default ones.
+ """
+ if timestamp < 0:
+ raise ValueError("Watermark cannot be negative.")
+ tp = (topic, partition)
+ stored_watermark, stored_default = self._to_produce.get(tp, (-1, True))
+ new_watermark = max(stored_watermark, timestamp)
+
+ if default and not stored_default:
+ # Skip watermark update if the non-default watermark is set.
+ return
+ elif not default and stored_default:
+ # Always override the default watermark
+ self._to_produce[tp] = (new_watermark, default)
+ elif new_watermark > stored_watermark:
+ # Schedule the updated watermark to be produced on the next cycle
+ # if it's tracked and larger than the previous one.
+ self._to_produce[tp] = (new_watermark, default)
+
+ def produce(self):
+ """
+ Produce updated watermarks to the watermarks topic.
+ """
+ if monotonic() >= self._last_produced + self._interval:
+ # Produce watermarks only for those partitions that are tracked by this application
+ # to avoid re-publishing the same watermarks.
+ for (topic, partition), (timestamp, _) in self._to_produce.items():
+ msg: WatermarkMessage = {
+ "topic": topic,
+ "partition": partition,
+ "timestamp": timestamp,
+ }
+                logger.debug(
+                    f"Produce watermark {format_timestamp(timestamp)}. "
+                    f"topic={topic} partition={partition} timestamp={timestamp}"
+                )
+                key = f"{topic}[{partition}]"
+                self._producer.produce(
+                    topic=self._watermarks_topic.name, value=dumps(msg), key=key
+                )
+ self._last_produced = monotonic()
+ self._to_produce.clear()
+
+ def receive(self, message: WatermarkMessage) -> Optional[int]:
+ """
+ Receive and store the consumed watermark message.
+        Returns the new overall watermark if it is larger than the existing one,
+        otherwise returns None.
+ """
+ topic, partition, timestamp = (
+ message["topic"],
+ message["partition"],
+ message["timestamp"],
+ )
+ logger.debug(
+ f"Received watermark {format_timestamp(timestamp)}. topic={topic} partition={partition} timestamp={timestamp}"
+ )
+ current_watermark = self._get_watermark()
+ if current_watermark is None:
+ current_watermark = -1
+
+ # Store the updated TP watermark
+ tp = (topic, partition)
+ current_tp_watermark = self._watermarks.get(tp, -1)
+ self._watermarks[tp] = max(current_tp_watermark, timestamp)
+
+        # Check if the new TP watermark updates the overall watermark, and return it
+ # if it does.
+ new_watermark = self._get_watermark()
+ if new_watermark > current_watermark:
+ return new_watermark
+ return None
+
+ def _get_watermark(self) -> int:
+ watermark = -1
+ if watermarks := self._watermarks.values():
+ watermark = min(watermarks)
+ return watermark
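
The essential behaviour of `WatermarkManager` is the fold at the end of `receive()`: each topic partition keeps the maximum timestamp it has reported, and the application-wide watermark is the minimum across those maxima, so it only advances once the slowest tracked partition has caught up. A Kafka-free sketch of that aggregation (the dict and helper below are illustrative, not part of the library):

```python
# Conceptual sketch of the watermark aggregation in WatermarkManager (toy code).
# Partitions are primed with -1, exactly like set_topics() does above.
watermarks = {("input", 0): -1, ("input", 1): -1}

def receive(topic: str, partition: int, timestamp: int) -> int:
    # Each partition tracks the max timestamp it has seen...
    watermarks[(topic, partition)] = max(watermarks[(topic, partition)], timestamp)
    # ...and the overall watermark is the min across partitions.
    return min(watermarks.values())

print(receive("input", 0, 1_000))  # -1    (partition 1 has reported nothing yet)
print(receive("input", 1, 2_000))  # 1000  (the slowest partition defines the watermark)
print(receive("input", 0, 3_000))  # 2000
```
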
diff --git a/quixstreams/runtracker.py b/quixstreams/runtracker.py
index 413398335..552da1581 100644
--- a/quixstreams/runtracker.py
+++ b/quixstreams/runtracker.py
@@ -101,14 +101,17 @@ def collect_values_and_metadata(
key: Any,
timestamp: int,
headers: Any,
+ is_watermark: bool = False,
):
+ if is_watermark:
+ return
ctx = message_context()
self._collector.add_value_and_metadata(
key=key,
value=value,
timestamp_ms=timestamp,
headers=headers,
- offset=ctx.offset,
+ offset=ctx.offset or 0,
partition=ctx.partition,
topic=ctx.topic,
)
@@ -119,7 +122,10 @@ def collect_values(
key: Any,
timestamp: int,
headers: Any,
+ is_watermark: bool = False,
):
+ if is_watermark:
+ return
self._collector.add_value(value=value)
def increment_count(
@@ -128,7 +134,10 @@ def increment_count(
key: Any,
timestamp: int,
headers: Any,
+ is_watermark: bool = False,
):
+ if is_watermark:
+ return
self._collector.increment_count()
def stop(self):
diff --git a/quixstreams/sinks/community/kafka.py b/quixstreams/sinks/community/kafka.py
new file mode 100644
index 000000000..ebb92b21c
--- /dev/null
+++ b/quixstreams/sinks/community/kafka.py
@@ -0,0 +1,211 @@
+import logging
+from typing import Any, Optional, Union
+
+from quixstreams.internal_producer import InternalProducer
+from quixstreams.kafka.configuration import ConnectionConfig
+from quixstreams.models import Row, Topic, TopicAdmin
+from quixstreams.models.messagecontext import MessageContext
+from quixstreams.models.serializers import SerializerType
+from quixstreams.models.types import HeadersTuples
+from quixstreams.sinks import (
+ BaseSink,
+ ClientConnectFailureCallback,
+ ClientConnectSuccessCallback,
+ SinkBackpressureError,
+)
+
+__all__ = ("KafkaReplicatorSink",)
+
+logger = logging.getLogger(__name__)
+
+
+class KafkaReplicatorSink(BaseSink):
+ """
+ A sink that produces data to an external Kafka cluster.
+
+ This sink uses the same serialization approach as the Quix Application.
+
+ Example Snippet:
+
+ ```python
+ from quixstreams import Application
+ from quixstreams.sinks.community.kafka import KafkaReplicatorSink
+
+ app = Application(
+ consumer_group="group",
+ )
+
+ topic = app.topic("input-topic")
+
+ # Define the external Kafka cluster configuration
+ kafka_sink = KafkaReplicatorSink(
+ broker_address="external-kafka:9092",
+ topic_name="output-topic",
+ value_serializer="json",
+ key_serializer="bytes",
+ )
+
+ sdf = app.dataframe(topic=topic)
+ sdf.sink(kafka_sink)
+
+ app.run()
+ ```
+ """
+
+ def __init__(
+ self,
+ broker_address: Union[str, ConnectionConfig],
+ topic_name: str,
+ value_serializer: SerializerType = "json",
+ key_serializer: SerializerType = "bytes",
+ producer_extra_config: Optional[dict] = None,
+ flush_timeout: float = 10.0,
+ origin_topic: Optional[Topic] = None,
+ auto_create_sink_topic: bool = True,
+ on_client_connect_success: Optional[ClientConnectSuccessCallback] = None,
+ on_client_connect_failure: Optional[ClientConnectFailureCallback] = None,
+ ) -> None:
+ """
+ :param broker_address: The connection settings for the external Kafka cluster.
+            Accepts string with Kafka broker host and port formatted as `<host>:<port>`,
+ or a ConnectionConfig object if authentication is required.
+ :param topic_name: The topic name to produce to on the external Kafka cluster.
+ :param value_serializer: The serializer type for values.
+ Default - `json`.
+ :param key_serializer: The serializer type for keys.
+ Default - `bytes`.
+ :param producer_extra_config: A dictionary with additional options that
+ will be passed to `confluent_kafka.Producer` as is.
+ Default - `None`.
+ :param flush_timeout: The time in seconds the producer waits for all messages
+ to be delivered during flush.
+ Default - 10.0.
+ :param origin_topic: If auto-creating the sink topic, can optionally pass the
+ source topic to use its configuration.
+        :param auto_create_sink_topic: Whether to try to create the sink topic upon startup.
+            Default - `True`.
+ :param on_client_connect_success: An optional callback made after successful
+ client authentication, primarily for additional logging.
+ :param on_client_connect_failure: An optional callback made after failed
+ client authentication (which should raise an Exception).
+ Callback should accept the raised Exception as an argument.
+ Callback must resolve (or propagate/re-raise) the Exception.
+ """
+ super().__init__(
+ on_client_connect_success=on_client_connect_success,
+ on_client_connect_failure=on_client_connect_failure,
+ )
+
+ self._broker_address = broker_address
+ self._topic_name = topic_name
+ self._value_serializer = value_serializer
+ self._key_serializer = key_serializer
+ self._producer_extra_config = producer_extra_config or {}
+ self._flush_timeout = flush_timeout
+ self._auto_create_sink_topic = auto_create_sink_topic
+ self._origin_topic = origin_topic
+
+ self._producer: Optional[InternalProducer] = None
+ self._topic: Optional[Topic] = None
+
+ def setup(self):
+ """
+ Initialize the InternalProducer and Topic for serialization.
+ """
+ logger.info(
+ f"Setting up KafkaReplicatorSink: "
+ f'broker_address="{self._broker_address}" '
+ f'topic="{self._topic_name}" '
+ f'value_serializer="{self._value_serializer}" '
+ f'key_serializer="{self._key_serializer}"'
+ )
+
+ self._producer = InternalProducer(
+ broker_address=self._broker_address,
+ extra_config=self._producer_extra_config,
+ flush_timeout=self._flush_timeout,
+ transactional=False,
+ )
+
+ self._topic = Topic(
+ name=self._topic_name,
+ value_serializer=self._value_serializer,
+ key_serializer=self._key_serializer,
+ create_config=self._origin_topic.broker_config
+ if self._origin_topic
+ else None,
+ )
+
+ if self._auto_create_sink_topic:
+ admin = TopicAdmin(
+ broker_address=self._broker_address,
+ extra_config=self._producer_extra_config,
+ )
+ admin.create_topics(topics=[self._topic])
+
+ def add(
+ self,
+ value: Any,
+ key: Any,
+ timestamp: int,
+ headers: HeadersTuples,
+ topic: str,
+ partition: int,
+ offset: int,
+ ) -> None:
+ """
+ Add a message to be produced to the external Kafka cluster.
+
+ This method converts the provided data into a Row object and uses
+ the InternalProducer to serialize and produce it.
+
+ :param value: The message value.
+ :param key: The message key.
+ :param timestamp: The message timestamp in milliseconds.
+ :param headers: The message headers.
+ :param topic: The source topic name.
+ :param partition: The source partition.
+ :param offset: The source offset.
+ """
+ context = MessageContext(
+ topic=topic,
+ partition=partition,
+ offset=offset,
+ size=0,
+ leader_epoch=None,
+ )
+ row = Row(
+ value=value,
+ key=key,
+ timestamp=timestamp,
+ context=context,
+ headers=headers,
+ )
+ self._producer.produce_row(
+ row=row,
+ topic=self._topic,
+ timestamp=timestamp,
+ )
+
+ def flush(self) -> None:
+ """
+ Flush the producer to ensure all messages are delivered.
+
+ This method is triggered by the Checkpoint class when it commits.
+ If flush fails, the checkpoint will be aborted.
+ """
+ logger.debug(f'Flushing KafkaReplicatorSink for topic "{self._topic_name}"')
+
+ # Flush all pending messages
+ result = self._producer.flush(timeout=self._flush_timeout)
+
+ if result > 0:
+ logger.warning(
+ f"{result} messages were not delivered to Kafka topic "
+ f'"{self._topic_name}" within the flush timeout of {self._flush_timeout}s'
+ )
+ raise SinkBackpressureError(retry_after=10.0)
+
+ logger.debug(
+ f'Successfully flushed KafkaReplicatorSink for topic "{self._topic_name}"'
+ )
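
Beyond the docstring example, the `origin_topic` parameter is worth a quick illustration: with `auto_create_sink_topic=True`, passing the source `Topic` lets the sink topic be created from the source's broker configuration rather than defaults. A hedged variation of the snippet above (broker addresses and topic names are placeholders):

```python
from quixstreams import Application
from quixstreams.sinks.community.kafka import KafkaReplicatorSink

app = Application(broker_address="localhost:9092", consumer_group="replicator")
source_topic = app.topic("input-topic")

# Pass the source topic so the auto-created sink topic mirrors its
# partition count and replication factor (per the origin_topic docstring above).
sink = KafkaReplicatorSink(
    broker_address="external-kafka:9092",
    topic_name="input-topic-replica",
    origin_topic=source_topic,
    auto_create_sink_topic=True,
)

sdf = app.dataframe(topic=source_topic)
sdf.sink(sink)
app.run()
```
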
diff --git a/quixstreams/sources/base/manager.py b/quixstreams/sources/base/manager.py
index 517232077..99e6f75ea 100644
--- a/quixstreams/sources/base/manager.py
+++ b/quixstreams/sources/base/manager.py
@@ -156,9 +156,7 @@ def _recover_state(self, source: StatefulSource) -> StorePartition:
self._consumer.assign([changelog_tp])
store_partitions = state_manager.on_partition_assign(
- stream_id=None,
- partition=source.assigned_store_partition,
- committed_offsets={},
+ stream_id=None, partition=source.assigned_store_partition
)
if state_manager.recovery_required:
diff --git a/quixstreams/state/base/transaction.py b/quixstreams/state/base/transaction.py
index 432b3922a..0b0d69c8b 100644
--- a/quixstreams/state/base/transaction.py
+++ b/quixstreams/state/base/transaction.py
@@ -25,13 +25,11 @@
)
from quixstreams.state.metadata import (
CHANGELOG_CF_MESSAGE_HEADER,
- CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER,
DEFAULT_PREFIX,
SEPARATOR,
Marker,
)
from quixstreams.state.serialization import DumpsFunc, LoadsFunc, deserialize, serialize
-from quixstreams.utils.json import dumps as json_dumps
from .state import State, TransactionState
@@ -477,7 +475,7 @@ def exists(self, key: K, prefix: bytes, cf_name: str = "default") -> bool:
return self._partition.exists(key_serialized, cf_name=cf_name)
@validate_transaction_status(PartitionTransactionStatus.STARTED)
- def prepare(self, processed_offsets: Optional[dict[str, int]] = None) -> None:
+ def prepare(self) -> None:
"""
Produce changelog messages to the changelog topic for all changes accumulated
in this transaction and prepare transaction to flush its state to the state
@@ -488,18 +486,16 @@ def prepare(self, processed_offsets: Optional[dict[str, int]] = None) -> None:
If changelog is disabled for this application, no updates will be produced
to the changelog topic.
-
- :param processed_offsets: the dict with of the latest processed message
"""
try:
- self._prepare(processed_offsets=processed_offsets)
+ self._prepare()
self._status = PartitionTransactionStatus.PREPARED
except Exception:
self._status = PartitionTransactionStatus.FAILED
raise
- def _prepare(self, processed_offsets: Optional[dict[str, int]]):
+ def _prepare(self):
if self._changelog_producer is None:
return
@@ -508,13 +504,11 @@ def _prepare(self, processed_offsets: Optional[dict[str, int]]):
f'topic_name="{self._changelog_producer.changelog_name}" '
f"partition={self._changelog_producer.partition}"
)
- source_tp_offset_header = json_dumps(processed_offsets)
column_families = self._update_cache.get_column_families()
for cf_name in column_families:
headers: Headers = {
CHANGELOG_CF_MESSAGE_HEADER: cf_name,
- CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER: source_tp_offset_header,
}
updates = self._update_cache.get_updates(cf_name=cf_name)
diff --git a/quixstreams/state/manager.py b/quixstreams/state/manager.py
index 378de42ee..31d8863fd 100644
--- a/quixstreams/state/manager.py
+++ b/quixstreams/state/manager.py
@@ -295,7 +295,6 @@ def on_partition_assign(
self,
stream_id: Optional[str],
partition: int,
- committed_offsets: dict[str, int],
) -> Dict[str, StorePartition]:
"""
Assign store partitions for each registered store for the given stream_id
@@ -303,8 +302,6 @@ def on_partition_assign(
:param stream_id: stream id
:param partition: Kafka topic partition number
- :param committed_offsets: a dict with latest committed offsets
- of all assigned topics for this partition number.
:return: list of assigned `StorePartition`
"""
store_partitions = {}
@@ -315,7 +312,6 @@ def on_partition_assign(
self._recovery_manager.assign_partition(
topic=stream_id,
partition=partition,
- committed_offsets=committed_offsets,
store_partitions=store_partitions,
)
return store_partitions
diff --git a/quixstreams/state/metadata.py b/quixstreams/state/metadata.py
index 09dd70e72..3ec1d25fe 100644
--- a/quixstreams/state/metadata.py
+++ b/quixstreams/state/metadata.py
@@ -4,7 +4,6 @@
SEPARATOR_LENGTH = len(SEPARATOR)
CHANGELOG_CF_MESSAGE_HEADER = "__column_family__"
-CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER = "__processed_tp_offsets__"
METADATA_CF_NAME = "__metadata__"
DEFAULT_PREFIX = b""
diff --git a/quixstreams/state/recovery.py b/quixstreams/state/recovery.py
index b79c30188..74270d2cf 100644
--- a/quixstreams/state/recovery.py
+++ b/quixstreams/state/recovery.py
@@ -13,17 +13,13 @@
from quixstreams.models.types import Headers
from quixstreams.state.base import StorePartition
from quixstreams.utils.dicts import dict_values
-from quixstreams.utils.json import loads as json_loads
from .exceptions import (
ChangelogTopicPartitionNotAssigned,
ColumnFamilyHeaderMissing,
InvalidStoreChangelogOffset,
)
-from .metadata import (
- CHANGELOG_CF_MESSAGE_HEADER,
- CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER,
-)
+from .metadata import CHANGELOG_CF_MESSAGE_HEADER
logger = logging.getLogger(__name__)
@@ -50,7 +46,6 @@ def __init__(
changelog_name: str,
partition_num: int,
store_partition: StorePartition,
- committed_offsets: dict[str, int],
lowwater: int,
highwater: int,
):
@@ -59,7 +54,6 @@ def __init__(
self._store_partition = store_partition
self._changelog_lowwater = lowwater
self._changelog_highwater = highwater
- self._committed_offsets = committed_offsets
self._recovery_consume_position: Optional[int] = None
self._initial_offset: Optional[int] = None
@@ -154,40 +148,23 @@ def recover_from_changelog_message(
f"Header '{CHANGELOG_CF_MESSAGE_HEADER}' missing from changelog message"
)
- # Parse the processed topic-partition-offset info from the changelog message
- # headers to determine whether the update should be applied or skipped.
- # It can be empty if the message was produced by the older version of the lib.
- processed_offsets = json_loads(
- headers.get(CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER, b"null")
- )
- if processed_offsets is None or self._should_apply_changelog(
- processed_offsets=processed_offsets
- ):
- key = changelog_message.key()
- if not isinstance(key, bytes):
- raise TypeError(
- f'Invalid changelog key type {type(key)}, expected "bytes"'
- )
-
- value = changelog_message.value()
- if not isinstance(value, (bytes, _NoneType)):
- raise TypeError(
- f'Invalid changelog value type {type(value)}, expected "bytes"'
- )
+ key = changelog_message.key()
+ if not isinstance(key, bytes):
+ raise TypeError(f'Invalid changelog key type {type(key)}, expected "bytes"')
- self._store_partition.recover_from_changelog_message(
- cf_name=cf_name,
- key=key,
- value=value,
- offset=changelog_message.offset(),
- )
- else:
- # Even if the changelog update is skipped, roll the changelog offset
- # to move forward within the changelog topic
- self._store_partition.write_changelog_offset(
- offset=changelog_message.offset(),
+ value = changelog_message.value()
+ if not isinstance(value, (bytes, _NoneType)):
+ raise TypeError(
+ f'Invalid changelog value type {type(value)}, expected "bytes"'
)
+ self._store_partition.recover_from_changelog_message(
+ cf_name=cf_name,
+ key=key,
+ value=value,
+ offset=changelog_message.offset(),
+ )
+
def set_recovery_consume_position(self, offset: int):
"""
Update the recovery partition with the consumer's position (whenever
@@ -199,26 +176,6 @@ def set_recovery_consume_position(self, offset: int):
"""
self._recovery_consume_position = offset
- def _should_apply_changelog(self, processed_offsets: dict[str, int]) -> bool:
- """
- Determine whether the changelog update should be skipped.
-
- :param processed_offsets: a dict with processed offsets
- from the changelog message header processed offset.
-
- :return: True if update should be applied, else False.
- """
- committed_offsets = self._committed_offsets
- for topic, processed_offset in processed_offsets.items():
- # Skip recovering from the message if its processed offset is ahead of the
- # current committed offset.
- # This is a best-effort to recover to a consistent state
- # if the checkpointing code produced the changelog messages
- # but failed to commit the source topic offset.
- if processed_offset >= committed_offsets[topic]:
- return False
- return True
-
class ChangelogProducerFactory:
"""
@@ -411,7 +368,6 @@ def _generate_recovery_partitions(
topic_name: Optional[str],
partition_num: int,
store_partitions: Dict[str, StorePartition],
- committed_offsets: dict[str, int],
) -> List[RecoveryPartition]:
partitions = []
for store_name, store_partition in store_partitions.items():
@@ -432,7 +388,6 @@ def _generate_recovery_partitions(
changelog_name=changelog_topic.name,
partition_num=partition_num,
store_partition=store_partition,
- committed_offsets=committed_offsets,
lowwater=lowwater,
highwater=highwater,
)
@@ -443,7 +398,6 @@ def assign_partition(
self,
topic: Optional[str],
partition: int,
- committed_offsets: dict[str, int],
store_partitions: Dict[str, StorePartition],
):
"""
@@ -455,7 +409,6 @@ def assign_partition(
topic_name=topic,
partition_num=partition,
store_partitions=store_partitions,
- committed_offsets=committed_offsets,
)
assigned_tps = set(
diff --git a/quixstreams/state/rocksdb/timestamped.py b/quixstreams/state/rocksdb/timestamped.py
index 4c80f9dd3..419480a0f 100644
--- a/quixstreams/state/rocksdb/timestamped.py
+++ b/quixstreams/state/rocksdb/timestamped.py
@@ -171,16 +171,14 @@ def set_for_timestamp(self, timestamp: int, value: Any, prefix: Any) -> None:
self._set_min_eligible_timestamp(prefix, min_eligible_timestamp)
@validate_transaction_status(PartitionTransactionStatus.STARTED)
- def prepare(self, processed_offsets: Optional[dict[str, int]] = None) -> None:
+ def prepare(self) -> None:
"""
This method first calls `_expire()` to remove outdated entries based on
their timestamps and grace periods, then calls the parent class's
`prepare()` to prepare the transaction for flush.
-
- :param processed_offsets: the dict with of the latest processed message
"""
self._expire()
- super().prepare(processed_offsets=processed_offsets)
+ super().prepare()
def _expire(self) -> None:
"""
diff --git a/quixstreams/state/rocksdb/transaction.py b/quixstreams/state/rocksdb/transaction.py
index 5624499be..cc76288ef 100644
--- a/quixstreams/state/rocksdb/transaction.py
+++ b/quixstreams/state/rocksdb/transaction.py
@@ -95,15 +95,13 @@ def _get_items(
return sorted(merged_items.items(), key=lambda kv: kv[0], reverse=backwards)
@validate_transaction_status(PartitionTransactionStatus.STARTED)
- def prepare(self, processed_offsets: Optional[dict[str, int]] = None) -> None:
+ def prepare(self) -> None:
"""
This method first persists the counter and then calls the parent class's
`prepare()` to prepare the transaction for flush.
-
- :param processed_offsets: the dict with of the latest processed message
"""
self._persist_counter()
- super().prepare(processed_offsets=processed_offsets)
+ super().prepare()
def _increment_counter(self) -> int:
"""
diff --git a/quixstreams/state/rocksdb/windowed/metadata.py b/quixstreams/state/rocksdb/windowed/metadata.py
index a41838f10..9c54317c9 100644
--- a/quixstreams/state/rocksdb/windowed/metadata.py
+++ b/quixstreams/state/rocksdb/windowed/metadata.py
@@ -7,7 +7,4 @@
LATEST_DELETED_VALUE_CF_NAME = "__value-deletion-index__"
LATEST_DELETED_VALUE_TIMESTAMP_KEY = b"__value_deleted_start_gt__"
-LATEST_TIMESTAMPS_CF_NAME = "__latest-timestamps__"
-LATEST_TIMESTAMP_KEY = b"__latest_timestamp__"
-
VALUES_CF_NAME = "__values__"
diff --git a/quixstreams/state/rocksdb/windowed/state.py b/quixstreams/state/rocksdb/windowed/state.py
index 3e3021b20..740517b13 100644
--- a/quixstreams/state/rocksdb/windowed/state.py
+++ b/quixstreams/state/rocksdb/windowed/state.py
@@ -1,7 +1,7 @@
-from typing import TYPE_CHECKING, Any, Iterable, Optional
+from typing import TYPE_CHECKING, Any, Optional
from quixstreams.state.base import TransactionState
-from quixstreams.state.types import ExpiredWindowDetail, WindowDetail, WindowedState
+from quixstreams.state.types import WindowDetail, WindowedState
if TYPE_CHECKING:
from .transaction import WindowedRocksDBPartitionTransaction
@@ -107,46 +107,6 @@ def delete_from_collection(self, end: int, *, start: Optional[int] = None) -> No
end=end, start=start, prefix=self._prefix
)
- def get_latest_timestamp(self) -> Optional[int]:
- """
- Get the latest observed timestamp for the current message key.
-
- Use this timestamp to determine if the arriving event is late and should be
- discarded from the processing.
-
- :return: latest observed event timestamp in milliseconds
- """
-
- return self._transaction.get_latest_timestamp(prefix=self._prefix)
-
- def expire_windows(
- self,
- max_start_time: int,
- delete: bool = True,
- collect: bool = False,
- end_inclusive: bool = False,
- ) -> Iterable[ExpiredWindowDetail]:
- """
- Get all expired windows from RocksDB up to the specified `max_start_time` timestamp.
-
- This method marks the latest found window as expired in the expiration index,
- so consecutive calls may yield different results for the same "latest timestamp".
-
- :param max_start_time: The timestamp up to which windows are considered expired, inclusive.
- :param delete: If True, expired windows will be deleted.
- :param collect: If True, values will be collected into windows.
- :param end_inclusive: If True, the end of the window will be inclusive.
- Relevant only together with `collect=True`.
- :return: A sorted list of tuples in the format `((start, end), value)`.
- """
- return self._transaction.expire_windows(
- max_start_time=max_start_time,
- prefix=self._prefix,
- delete=delete,
- collect=collect,
- end_inclusive=end_inclusive,
- )
-
def get_windows(
self, start_from_ms: int, start_to_ms: int, backwards: bool = False
) -> list[WindowDetail]:
@@ -164,21 +124,3 @@ def get_windows(
prefix=self._prefix,
backwards=backwards,
)
-
- def delete_windows(self, max_start_time: int, delete_values: bool) -> None:
- """
- Delete windows from RocksDB up to the specified `max_start_time` timestamp.
-
- This method removes all window entries that have a start time less than or equal
- to the given `max_start_time`. It ensures that expired data is cleaned up
- efficiently without affecting unexpired windows.
-
- :param max_start_time: The timestamp up to which windows should be deleted, inclusive.
- :param delete_values: If True, values with timestamps less than max_start_time
- will be deleted, as they can no longer belong to any active window.
- """
- return self._transaction.delete_windows(
- max_start_time=max_start_time,
- delete_values=delete_values,
- prefix=self._prefix,
- )
diff --git a/quixstreams/state/rocksdb/windowed/transaction.py b/quixstreams/state/rocksdb/windowed/transaction.py
index 3779b3e29..18a697102 100644
--- a/quixstreams/state/rocksdb/windowed/transaction.py
+++ b/quixstreams/state/rocksdb/windowed/transaction.py
@@ -1,3 +1,4 @@
+import heapq
from typing import TYPE_CHECKING, Any, Iterable, Optional, cast
from quixstreams.state.base.transaction import (
@@ -24,8 +25,6 @@
LATEST_DELETED_WINDOW_TIMESTAMP_KEY,
LATEST_EXPIRED_WINDOW_CF_NAME,
LATEST_EXPIRED_WINDOW_TIMESTAMP_KEY,
- LATEST_TIMESTAMP_KEY,
- LATEST_TIMESTAMPS_CF_NAME,
VALUES_CF_NAME,
)
from .serialization import parse_window_key
@@ -55,10 +54,6 @@ def __init__(
# Cache the metadata separately to avoid serdes on each access
# (we are 100% sure that the underlying types are immutable, while windows'
# values are not)
- self._latest_timestamps: Cache = Cache(
- key=LATEST_TIMESTAMP_KEY,
- cf_name=LATEST_TIMESTAMPS_CF_NAME,
- )
self._last_expired_timestamps: Cache = Cache(
key=LATEST_EXPIRED_WINDOW_TIMESTAMP_KEY,
cf_name=LATEST_EXPIRED_WINDOW_CF_NAME,
@@ -84,25 +79,37 @@ def as_state(self, prefix: Any = DEFAULT_PREFIX) -> WindowedTransactionState: #
@validate_transaction_status(PartitionTransactionStatus.STARTED)
def keys(self, cf_name: str = "default") -> Iterable[Any]:
- db_skip_keys: set[bytes] = set()
+ """
+ Return all keys in the store partition for the given column family.
+        It merges keys from the transaction update cache and the DB,
+        and returns them in sorted order.
- cache = self._update_cache.get_updates(cf_name=cf_name)
- for prefix_update_cache in cache.values():
- # when iterating over the DB, skip keys already returned by the cache
- db_skip_keys.update(prefix_update_cache.keys())
- yield from prefix_update_cache.keys()
+ :param cf_name: column family name.
+ """
+ delete_cache_keys: set[bytes] = self._update_cache.get_deletes()
+ update_cache_keys: set[bytes] = set()
- # skip keys that were deleted from the cache
- db_skip_keys.update(self._update_cache.get_deletes())
+ for prefix_update_cache in self._update_cache.get_updates(
+ cf_name=cf_name
+ ).values():
+ # when iterating over the DB, skip keys already returned by the cache
+ update_cache_keys.update(prefix_update_cache.keys())
+
+ # Get the keys stored in the DB excluding the keys updated/deleted
+ # in the current transaction
+ db_skip_keys = delete_cache_keys | update_cache_keys
+ stored_keys = (
+ key
+ for key in self._partition.iter_keys(cf_name=cf_name)
+ if key not in db_skip_keys
+ )
- for key in self._partition.iter_keys(cf_name=cf_name):
- if key in db_skip_keys:
- continue
+ # Sort the keys updated in the cache to iterate over both generators
+ # in the sorted way
+ update_cache_keys_sorted = sorted(update_cache_keys)
+ for key in heapq.merge(stored_keys, update_cache_keys_sorted):
yield key
- def get_latest_timestamp(self, prefix: bytes) -> int:
- return self._get_timestamp(prefix=prefix, cache=self._latest_timestamps) or 0
-
def get_latest_expired(self, prefix: bytes) -> int:
return (
self._get_timestamp(prefix=prefix, cache=self._last_expired_timestamps) or 0
@@ -133,18 +140,6 @@ def update_window(
key = encode_integer_pair(start_ms, end_ms)
self.set(key=key, value=value, prefix=prefix)
- latest_timestamp_ms = self.get_latest_timestamp(prefix=prefix)
- updated_timestamp_ms = (
- max(latest_timestamp_ms, timestamp_ms)
- if latest_timestamp_ms is not None
- else timestamp_ms
- )
-
- self._set_timestamp(
- cache=self._latest_timestamps,
- prefix=prefix,
- timestamp_ms=updated_timestamp_ms,
- )
def add_to_collection(
self,
@@ -199,119 +194,36 @@ def delete_window(self, start_ms: int, end_ms: int, prefix: bytes):
key = encode_integer_pair(start_ms, end_ms)
self.delete(key=key, prefix=prefix)
- def expire_windows(
- self,
- max_start_time: int,
- prefix: bytes,
- delete: bool = True,
- collect: bool = False,
- end_inclusive: bool = False,
- ) -> Iterable[ExpiredWindowDetail]:
- """
- Get all expired windows with a set prefix from RocksDB up to the specified `max_start_time` timestamp.
-
- This method marks the latest found window as expired in the expiration index,
- so consecutive calls may yield different results for the same "latest timestamp".
-
- How it works:
- - First, it checks the expiration cache for the start time of the last expired
- window for the current prefix. If found, this value helps reduce the search
- space and prevents returning previously expired windows.
- - Next, it iterates over window segments and identifies the windows that should
- be marked as expired.
- - Finally, it updates the expiration cache with the start time of the latest
- windows found.
-
- Collection behavior (when collect=True):
- - For tumbling and hopping windows (created using .collect()), the window
- value is None and is replaced with the list of collected values.
- - For sliding windows, the window value is [max_timestamp, None] where
- None is replaced with the list of collected values.
- - Values are collected from a separate column family and obsolete values
- are deleted if delete=True.
-
- :param max_start_time: The timestamp up to which windows are considered expired, inclusive.
- :param prefix: The key prefix for filtering windows.
- :param delete: If True, expired windows will be deleted.
- :param collect: If True, values will be collected into windows.
- :param end_inclusive: If True, the end of the window will be inclusive.
- Relevant only together with `collect=True`.
- :return: A sorted list of tuples in the format `((start, end), value)`.
- """
- start_from = -1
-
- # Find the latest start timestamp of the expired windows for the given key
- last_expired = self._get_timestamp(
- cache=self._last_expired_timestamps, prefix=prefix
- )
- if last_expired is not None:
- start_from = max(start_from, last_expired)
-
- # Use the latest expired timestamp to limit the iteration over
- # only those windows that have not been expired before
- windows = self.get_windows(
- start_from_ms=start_from,
- start_to_ms=max_start_time,
- prefix=prefix,
- )
- if not windows:
- return
-
- # Save the start of the latest expired window to the expiration index
- latest_window = windows[-1]
- last_expired__gt = latest_window[0][0]
-
- self._set_timestamp(
- cache=self._last_expired_timestamps,
- prefix=prefix,
- timestamp_ms=last_expired__gt,
- )
-
- # Collect values into windows
- if collect:
- for (start, end), aggregated, key in windows:
- collected = self.get_from_collection(
- start=start,
- # Sliding windows are inclusive on both ends
- # (including timestamps of messages equal to `end`).
- # Since RocksDB range queries are exclusive on the
- # `end` boundary, we add +1 to include it.
- end=end + 1 if end_inclusive else end,
- prefix=prefix,
- )
- yield ((start, end), aggregated, collected, key)
-
- else:
- for window, aggregated, key in windows:
- yield (window, aggregated, [], key)
-
- # Delete expired windows from the state
- if delete:
- for (start, end), _, _ in windows:
- self.delete_window(start, end, prefix=prefix)
- if collect:
- self.delete_from_collection(end=start, prefix=prefix)
-
def expire_all_windows(
self,
max_end_time: int,
- step_ms: int,
+ step_ms: int = 1,
delete: bool = True,
collect: bool = False,
+ end_inclusive: bool = False,
) -> Iterable[ExpiredWindowDetail]:
"""
Get all expired windows for all prefixes from RocksDB up to the specified `max_end_time` timestamp.
:param max_end_time: The timestamp up to which windows are considered expired, inclusive.
+        :param step_ms: the step between the windows, if known.
+ For example, tumbling windows of size 100ms have 100ms step between them.
+ This value is used to optimize the DB lookups.
+ Default - 1ms.
:param delete: If True, expired windows will be deleted.
:param collect: If True, values will be collected into windows.
+ :param end_inclusive: If True, the end of the window will be inclusive.
+ Relevant only together with `collect=True`.
"""
+
+ max_end_time = max(max_end_time, 0)
last_expired = self.get_latest_expired(prefix=b"")
to_delete: set[tuple[bytes, int, int]] = set()
collected = []
-
if last_expired:
+ # TODO: Probably optimize that. It works only for tumbling/hopping windows
+ # with fixed boundaries
windows = windows_to_expire(last_expired, max_end_time, step_ms)
if not windows:
return
@@ -327,7 +239,11 @@ def expire_all_windows(
if collect:
collected = self.get_from_collection(
start=start,
- end=end,
+ # Sliding windows are inclusive on both ends
+ # (including timestamps of messages equal to `end`).
+ # Since RocksDB range queries are exclusive on the
+ # `end` boundary, we add +1 to include it.
+ end=end + 1 if end_inclusive else end,
prefix=prefix,
)
yield (start, end), aggregated, collected, prefix
@@ -348,7 +264,11 @@ def expire_all_windows(
if collect:
collected = self.get_from_collection(
start=start,
- end=end,
+ # Sliding windows are inclusive on both ends
+ # (including timestamps of messages equal to `end`).
+ # Since RocksDB range queries are exclusive on the
+ # `end` boundary, we add +1 to include it.
+ end=end + 1 if end_inclusive else end,
prefix=prefix,
)
@@ -364,60 +284,20 @@ def expire_all_windows(
prefix=b"", cache=self._last_expired_timestamps, timestamp_ms=last_expired
)
- def delete_windows(
- self, max_start_time: int, delete_values: bool, prefix: bytes
- ) -> None:
+ def delete_all_windows(self, max_end_time: int, collect: bool) -> None:
"""
Delete windows from RocksDB up to the specified `max_start_time` timestamp.
- This method removes all window entries that have a start time less than or equal to the given
- `max_start_time`. It ensures that expired data is cleaned up efficiently without affecting
- unexpired windows.
-
- How it works:
- - It retrieves the start time of the last deleted window for the given prefix from the
- deletion index. This minimizes redundant scans over already deleted windows.
- - It iterates over the windows starting from the last deleted timestamp up to the `max_start_time`.
- - Each window within this range is deleted from the database.
- - After deletion, it updates the deletion index with the start time of the latest window
- that was deleted to keep track of progress.
- - Values with timestamps less than max_start_time are considered obsolete and are
- deleted if delete_values=True, as they can no longer belong to any active window.
-
- :param max_start_time: The timestamp up to which windows should be deleted, inclusive.
- :param delete_values: If True, obsolete values will be deleted.
- :param prefix: The key prefix used to identify and filter relevant windows.
+ :param max_end_time: The timestamp up to which windows should be deleted, inclusive.
+ :param collect: If True, the values from collections will be deleted too.
"""
- start_from = -1
-
- # Find the latest start timestamp of the deleted windows for the given key
- last_deleted = self._get_timestamp(
- cache=self._last_deleted_window_timestamps, prefix=prefix
- )
- if last_deleted is not None:
- start_from = max(start_from, last_deleted)
-
- windows = self.get_windows(
- start_from_ms=start_from,
- start_to_ms=max_start_time,
- prefix=prefix,
- )
-
- last_deleted__gt = None
- for (start, end), _, _ in windows:
- last_deleted__gt = start
- self.delete_window(start, end, prefix=prefix)
-
- # Save the start of the latest deleted window to the deletion index
- if last_deleted__gt:
- self._set_timestamp(
- cache=self._last_deleted_window_timestamps,
- prefix=prefix,
- timestamp_ms=last_deleted__gt,
- )
-
- if delete_values:
- self.delete_from_collection(end=max_start_time, prefix=prefix)
+ max_end_time = max(max_end_time, 0)
+ for key in self.keys():
+ prefix, start, end = parse_window_key(key)
+ if end <= max_end_time:
+ self.delete_window(start, end, prefix)
+ if collect:
+ self.delete_from_collection(end=start, prefix=prefix)
def get_windows(
self,
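
The rewritten `keys()` above relies on two already-sorted inputs: the RocksDB iterator yields persisted keys in order, and the transaction's updated keys are sorted explicitly, so `heapq.merge` can interleave them into one ordered stream while updated and deleted keys are filtered out of the DB side. A small standalone illustration of the pattern (toy byte keys, not the store's real layout):

```python
import heapq

# Keys already persisted in the DB, yielded in sorted order by the iterator.
db_keys = [b"a", b"c", b"e", b"g"]
# Keys updated or deleted in the current (uncommitted) transaction.
cache_updates = {b"e", b"b"}
cache_deletes = {b"g"}

# Skip DB keys that this transaction has touched, then merge the two sorted streams.
skip = cache_updates | cache_deletes
stored = (k for k in db_keys if k not in skip)
print(list(heapq.merge(stored, sorted(cache_updates))))
# [b'a', b'b', b'c', b'e']
```
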
diff --git a/quixstreams/state/types.py b/quixstreams/state/types.py
index c80c9e2ad..60317328b 100644
--- a/quixstreams/state/types.py
+++ b/quixstreams/state/types.py
@@ -140,53 +140,6 @@ def delete_from_collection(self, end: int, *, start: Optional[int] = None) -> No
"""
...
- def get_latest_timestamp(self) -> Optional[int]:
- """
- Get the latest observed timestamp for the current state partition.
-
- Use this timestamp to determine if the arriving event is late and should be
- discarded from the processing.
-
- :return: latest observed event timestamp in milliseconds
- """
- ...
-
- def expire_windows(
- self,
- max_start_time: int,
- delete: bool = True,
- collect: bool = False,
- end_inclusive: bool = False,
- ) -> Iterable[ExpiredWindowDetail[V]]:
- """
- Get all expired windows from RocksDB up to the specified `max_start_time` timestamp.
-
- This method marks the latest found window as expired in the expiration index,
- so consecutive calls may yield different results for the same "latest timestamp".
-
- :param max_start_time: The timestamp up to which windows are considered expired, inclusive.
- :param delete: If True, expired windows will be deleted.
- :param collect: If True, values will be collected into windows.
- :param end_inclusive: If True, the end of the window will be inclusive.
- Relevant only together with `collect=True`.
- :return: A sorted list of tuples in the format `((start, end), value)`.
- """
- ...
-
- def delete_windows(self, max_start_time: int, delete_values: bool) -> None:
- """
- Delete windows from RocksDB up to the specified `max_start_time` timestamp.
-
- This method removes all window entries that have a start time less than or equal
- to the given `max_start_time`. It ensures that expired data is cleaned up
- efficiently without affecting unexpired windows.
-
- :param max_start_time: The timestamp up to which windows should be deleted, inclusive.
- :param delete_values: If True, values with timestamps less than max_start_time
- will be deleted, as they can no longer belong to any active window.
- """
- ...
-
def get_windows(
self, start_from_ms: int, start_to_ms: int, backwards: bool = False
) -> list[WindowDetail[V]]:
@@ -232,7 +185,7 @@ def prepared(self) -> bool:
"""
...
- def prepare(self, processed_offsets: Optional[dict[str, int]]):
+ def prepare(self):
"""
Produce changelog messages to the changelog topic for all changes accumulated
in this transaction and prepare transaction to flush its state to the state
@@ -243,9 +196,6 @@ def prepare(self, processed_offsets: Optional[dict[str, int]]):
If changelog is disabled for this application, no updates will be produced
to the changelog topic.
-
- :param processed_offsets: the dict with of
- the latest processed message in the current partition
"""
def as_state(self, prefix: Any) -> WindowedState[K, V]: ...
@@ -324,18 +274,6 @@ def delete_from_collection(self, end: int) -> None:
"""
...
- def get_latest_timestamp(self, prefix: bytes) -> int:
- """
- Get the latest observed timestamp for the current state prefix
- (same as message key).
-
- Use this timestamp to determine if the arriving event is late and should be
- discarded from the processing.
-
- :return: latest observed event timestamp in milliseconds
- """
- ...
-
def get_latest_expired(self, prefix: bytes) -> int:
"""
Get the latest expired timestamp for the current state prefix
@@ -348,46 +286,35 @@ def get_latest_expired(self, prefix: bytes) -> int:
"""
...
- def expire_windows(
+ def expire_all_windows(
self,
- max_start_time: int,
- prefix: bytes,
+ max_end_time: int,
+ step_ms: int,
delete: bool = True,
collect: bool = False,
end_inclusive: bool = False,
) -> Iterable[ExpiredWindowDetail[V]]:
"""
- Get all expired windows with a set prefix from RocksDB up to the specified `max_start_time` timestamp.
+        Get all expired windows for all prefixes from RocksDB up to the specified `max_end_time` timestamp.
This method marks the latest found window as expired in the expiration index,
so consecutive calls may yield different results for the same "latest timestamp".
- :param max_start_time: The timestamp up to which windows are considered expired, inclusive.
- :param prefix: The key prefix for filtering windows.
+ :param max_end_time: The timestamp up to which windows are considered expired, inclusive.
:param delete: If True, expired windows will be deleted.
:param collect: If True, values will be collected into windows.
:param end_inclusive: If True, the end of the window will be inclusive.
Relevant only together with `collect=True`.
- :return: A sorted list of tuples in the format `((start, end), value)`.
"""
...
- def expire_all_windows(
- self,
- max_end_time: int,
- step_ms: int,
- delete: bool = True,
- collect: bool = False,
- ) -> Iterable[ExpiredWindowDetail[V]]:
+ def delete_window(self, start_ms: int, end_ms: int, prefix: bytes) -> None:
"""
- Get all expired windows for all prefix from RocksDB up to the specified `max_start_time` timestamp.
-
- This method marks the latest found window as expired in the expiration index,
- so consecutive calls may yield different results for the same "latest timestamp".
+ Delete a single window defined by start and end timestamps.
- :param max_end_time: The timestamp up to which windows are considered expired, inclusive.
- :param delete: If True, expired windows will be deleted.
- :param collect: If True, values will be collected into windows.
+ :param start_ms: start of the window in milliseconds
+ :param end_ms: end of the window in milliseconds
+ :param prefix: a key prefix
"""
...
@@ -397,14 +324,18 @@ def delete_windows(
"""
Delete windows from RocksDB up to the specified `max_start_time` timestamp.
- This method removes all window entries that have a start time less than or equal
- to the given `max_start_time`. It ensures that expired data is cleaned up
- efficiently without affecting unexpired windows.
-
:param max_start_time: The timestamp up to which windows should be deleted, inclusive.
- :param delete_values: If True, values with timestamps less than max_start_time
- will be deleted, as they can no longer belong to any active window.
- :param prefix: The key prefix used to identify and filter relevant windows.
+ :param delete_values: If True, the values from collections will be deleted too.
+ :param prefix: a key prefix
+ """
+ ...
+
+ def delete_all_windows(self, max_end_time: int, collect: bool) -> None:
+ """
+ Delete windows from RocksDB up to the specified `max_end_time` timestamp.
+
+ :param max_end_time: The timestamp up to which windows should be deleted, inclusive.
+ :param collect: If True, the values from collections will be deleted too.
"""
...
diff --git a/quixstreams/utils/format.py b/quixstreams/utils/format.py
new file mode 100644
index 000000000..486a4017d
--- /dev/null
+++ b/quixstreams/utils/format.py
@@ -0,0 +1,9 @@
+from datetime import datetime, timezone
+
+__all__ = ("format_timestamp",)
+
+
+def format_timestamp(timestamp_ms: int) -> str:
+ return datetime.fromtimestamp(timestamp_ms / 1000, timezone.utc).strftime(
+ "%Y-%m-%d %H:%M:%S.%f"
+ )[:-3]
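
For reference, the new helper renders a millisecond epoch timestamp as UTC wall-clock time with millisecond precision:

```python
from quixstreams.utils.format import format_timestamp

# 1_700_000_000_000 ms is 2023-11-14 22:13:20 UTC
print(format_timestamp(1_700_000_000_000))  # "2023-11-14 22:13:20.000"
```
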
diff --git a/requirements-mypy.txt b/requirements-mypy.txt
index 3d8b6b294..71f3e76cf 100644
--- a/requirements-mypy.txt
+++ b/requirements-mypy.txt
@@ -1,5 +1,5 @@
-mypy==1.18.1
+mypy==1.18.2
mypy-extensions==1.1.0
-types-jsonschema==4.25.1.20250822
-types-protobuf==6.30.2.20250914
+types-jsonschema==4.25.1.20251009
+types-protobuf==6.32.1.20250918
types-psycopg2>=2.9,<3
diff --git a/requirements.txt b/requirements.txt
index aef64438d..1670dd7c6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,8 +2,8 @@ confluent-kafka[avro,json,protobuf,schemaregistry]>=2.8.2,<2.12
rocksdict>=0.3,<0.4
typing_extensions>=4.8
orjson>=3.9,<4
-pydantic>=2.7,<2.12
-pydantic-settings>=2.3,<2.11
+pydantic>=2.7,<2.13
+pydantic-settings>=2.3,<2.12
jsonschema>=4.3.0
jsonlines>=4,<5
rich>=13,<15
diff --git a/tests/requirements.txt b/tests/requirements.txt
index e7939ef50..032dbd9fb 100644
--- a/tests/requirements.txt
+++ b/tests/requirements.txt
@@ -1,4 +1,4 @@
-testcontainers[postgres]==4.13.0
+testcontainers[postgres]==4.13.2
pytest
docker>=7.1.0 # Required to use requests>=2.32
fastavro>=1.8,<2.0
diff --git a/tests/test_quixstreams/test_app.py b/tests/test_quixstreams/test_app.py
index fa39be29d..297c298be 100644
--- a/tests/test_quixstreams/test_app.py
+++ b/tests/test_quixstreams/test_app.py
@@ -12,7 +12,6 @@
from quixstreams.app import Application
from quixstreams.dataframe import StreamingDataFrame
-from quixstreams.dataframe.windows.base import get_window_ranges
from quixstreams.exceptions import PartitionAssignmentError
from quixstreams.internal_consumer import InternalConsumer
from quixstreams.internal_producer import InternalProducer
@@ -1205,9 +1204,7 @@ def _validate_state(
)
state_manager.register_store(stream_id, "default")
state_manager.on_partition_assign(
- stream_id=stream_id,
- partition=partition_index,
- committed_offsets={stream_id: -1001},
+ stream_id=stream_id, partition=partition_index
)
store = state_manager.get_store(stream_id=stream_id, store_name="default")
with store.start_partition_transaction(partition=partition_index) as tx:
@@ -1336,11 +1333,7 @@ def count_and_fail(_, state: State):
group_id=consumer_group, state_dir=state_dir
)
state_manager.register_store(sdf.stream_id, "default")
- state_manager.on_partition_assign(
- stream_id=sdf.stream_id,
- partition=0,
- committed_offsets={},
- )
+ state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0)
store = state_manager.get_store(stream_id=sdf.stream_id, store_name="default")
with store.start_partition_transaction(partition=0) as tx:
assert tx.get("total", prefix=key) is None
@@ -1453,11 +1446,7 @@ def test_clear_state(
# Add data to the state store
with state_manager:
state_manager.register_store(topic_in_name, "default")
- state_manager.on_partition_assign(
- stream_id=topic_in_name,
- partition=0,
- committed_offsets={topic_in_name: -1001},
- )
+ state_manager.on_partition_assign(stream_id=topic_in_name, partition=0)
store = state_manager.get_store(
stream_id=topic_in_name, store_name="default"
)
@@ -1471,11 +1460,7 @@ def test_clear_state(
# Check that the date is cleared from the state store
with state_manager:
state_manager.register_store(topic_in_name, "default")
- state_manager.on_partition_assign(
- stream_id=topic_in_name,
- partition=0,
- committed_offsets={topic_in_name: -1001},
- )
+ state_manager.on_partition_assign(stream_id=topic_in_name, partition=0)
store = state_manager.get_store(
stream_id=topic_in_name, store_name="default"
)
@@ -1563,9 +1548,7 @@ def validate_state(stores):
state_manager.register_store(sdf.stream_id, store_name)
for p_num, count in partition_msg_count.items():
state_manager.on_partition_assign(
- stream_id=sdf.stream_id,
- partition=p_num,
- committed_offsets={topic.name: -1001},
+ stream_id=sdf.stream_id, partition=p_num
)
store = state_manager.get_store(
stream_id=sdf.stream_id, store_name=store_name
@@ -1621,186 +1604,7 @@ def revoke_partition(store, partition):
# State should be the same as before deletion
validate_state(stores)
- @pytest.mark.parametrize("processing_guarantee", ["at-least-once", "exactly-once"])
- def test_changelog_recovery_window_store(
- self,
- app_factory,
- executor,
- tmp_path,
- state_manager_factory,
- processing_guarantee,
- ):
- consumer_group = str(uuid.uuid4())
- state_dir = (tmp_path / "state").absolute()
- topic_name = str(uuid.uuid4())
- store_name = "window"
- window_duration_ms = 5000
- window_step_ms = 2000
-
- msg_tick_ms = 1000
- msg_int_value = 10
-
- partition_timestamps = {
- 0: list(range(10000, 14000, msg_tick_ms)),
- 1: list(range(10000, 12000, msg_tick_ms)),
- }
- partition_windows = {
- p: [
- w
- for ts in ts_list
- for w in get_window_ranges(ts, window_duration_ms, window_step_ms)
- ]
- for p, ts_list in partition_timestamps.items()
- }
-
- # how many times window updates should occur (1:1 with changelog updates)
- expected_window_updates = {0: {}, 1: {}}
- # expired windows should have no values (changelog updates per tx == num_exp_windows + 1)
- expected_expired_windows = {0: set(), 1: set()}
-
- for p, windows in partition_windows.items():
- latest_timestamp = partition_timestamps[p][-1]
- for w in windows:
- if latest_timestamp >= w[1]:
- expected_expired_windows[p].add(w)
- expected_window_updates[p][w] = (
- expected_window_updates[p].setdefault(w, 0) + 1
- )
-
- processed_count = {0: 0, 1: 0}
- partition_msg_count = {
- p: len(partition_timestamps[p]) for p in partition_timestamps
- }
-
- def on_message_processed(topic_, partition, offset):
- # Set the callback to track total messages processed
- # The callback is not triggered if processing fails
- processed_count[partition] += 1
- if processed_count == partition_msg_count:
- done.set_result(True)
-
- def get_app():
- app = app_factory(
- commit_interval=0, # Commit every processed message
- auto_offset_reset="earliest",
- use_changelog_topics=True,
- consumer_group=consumer_group,
- on_message_processed=on_message_processed,
- state_dir=state_dir,
- processing_guarantee=processing_guarantee,
- )
- topic = app.topic(
- topic_name,
- config=TopicConfig(
- num_partitions=len(partition_msg_count), replication_factor=1
- ),
- )
- # Create a streaming dataframe with a hopping window
- sdf = (
- app.dataframe(topic)
- .apply(lambda row: row["my_value"])
- .hopping_window(
- duration_ms=window_duration_ms,
- step_ms=window_step_ms,
- name=store_name,
- )
- .sum()
- .final()
- )
- return app, sdf, topic
-
- def validate_state():
- actual_store_name = (
- f"{store_name}_hopping_window_{window_duration_ms}_{window_step_ms}_sum"
- )
- with state_manager_factory(
- group_id=consumer_group, state_dir=state_dir
- ) as state_manager:
- state_manager.register_windowed_store(sdf.stream_id, actual_store_name)
- for p_num, windows in expected_window_updates.items():
- state_manager.on_partition_assign(
- stream_id=sdf.stream_id,
- partition=p_num,
- committed_offsets={topic.name: -1001},
- )
- store = state_manager.get_store(
- stream_id=sdf.stream_id,
- store_name=actual_store_name,
- )
-
- # Calculate how many messages should be send to the changelog topic
- expected_offset = (
- # A number of total window updates
- sum(expected_window_updates[p_num].values())
- # A number of expired windows
- + 2 * len(expected_expired_windows[p_num])
- # A number of total timestamps
- # (each timestamp updates the )
- + len(partition_timestamps[p_num])
- # Correction for zero-based index
- - 1
- )
- if processing_guarantee == "exactly-once":
- # In this test, we commit after each message is processed, so
- # must add PMC-1 to our offset calculation since each kafka
- # to account for transaction commit markers (except last one)
- expected_offset += partition_msg_count[p_num] - 1
- assert (
- expected_offset
- == store.partitions[p_num].get_changelog_offset()
- )
-
- partition = store.partitions[p_num]
-
- with partition.begin() as tx:
- prefix = b"key"
- for window, count in windows.items():
- expected = count
- if window in expected_expired_windows[p_num]:
- expected = None
- else:
- # each message value was 10
- expected *= msg_int_value
- assert tx.get_window(*window, prefix=prefix) == expected
-
- app, sdf, topic = get_app()
- # Produce messages to the topic and flush
- with app.get_producer() as producer:
- for p_num, timestamps in partition_timestamps.items():
- serialized = topic.serialize(
- key=b"key", value={"my_value": msg_int_value}
- )
- for ts in timestamps:
- producer.produce(
- topic=topic.name,
- key=serialized.key,
- value=serialized.value,
- partition=p_num,
- timestamp=ts,
- )
-
- # run app to populate state
- done = Future()
- executor.submit(_stop_app_on_future, app, done, 10.0)
- app.run()
- # validate and then delete the state
- assert processed_count == partition_msg_count
- validate_state()
-
- # run the app again and validate the recovered state
- processed_count = {0: 0, 1: 0}
- app, sdf, topic = get_app()
- app.clear_state()
- done = Future()
- executor.submit(_stop_app_on_future, app, done, 10.0)
- app.run()
- # no messages should have been processed outside of recovery loop
- assert processed_count == {0: 0, 1: 0}
- # State should be the same as before deletion
- validate_state()
-
- @pytest.mark.parametrize("processing_guarantee", ["at-least-once", "exactly-once"])
- def test_changelog_recovery_consistent_after_failed_commit(
+ def test_changelog_recovery_consistent_after_failed_commit_exactly_once(
self,
store_type,
app_factory,
@@ -1808,7 +1612,6 @@ def test_changelog_recovery_consistent_after_failed_commit(
tmp_path,
state_manager_factory,
internal_consumer_factory,
- processing_guarantee,
):
"""
Scenario: application processes messages and successfully produces changelog
@@ -1822,14 +1625,9 @@ def test_changelog_recovery_consistent_after_failed_commit(
topic_name = str(uuid.uuid4())
store_name = "default"
- if processing_guarantee == "exactly-once":
- commit_patch = patch.object(
- InternalProducer, "commit_transaction", side_effect=ValueError("Fail")
- )
- else:
- commit_patch = patch.object(
- InternalConsumer, "commit", side_effect=ValueError("Fail")
- )
+ commit_patch = patch.object(
+ InternalProducer, "commit_transaction", side_effect=ValueError("Fail")
+ )
# Messages to be processed successfully
succeeded_messages = [
@@ -1864,7 +1662,7 @@ def get_app():
on_message_processed=on_message_processed,
consumer_group=consumer_group,
state_dir=state_dir,
- processing_guarantee=processing_guarantee,
+ processing_guarantee="exactly-once",
)
topic = app.topic(topic_name)
sdf = app.dataframe(topic)
@@ -1892,18 +1690,11 @@ def validate_state(stores):
group_id=consumer_group,
state_dir=state_dir,
) as state_manager,
- internal_consumer_factory(
- consumer_group=consumer_group
- ) as consumer,
+ internal_consumer_factory(consumer_group=consumer_group),
):
- committed_offset = consumer.committed(
- [TopicPartition(topic=topic_name, partition=0)]
- )[0].offset
state_manager.register_store(sdf.stream_id, store_name)
partition = state_manager.on_partition_assign(
- stream_id=sdf.stream_id,
- partition=0,
- committed_offsets={topic_name: committed_offset},
+ stream_id=sdf.stream_id, partition=0
)["default"]
with partition.begin() as tx:
_validate_transaction_state(tx)
@@ -2617,9 +2408,7 @@ def _validate_state(
)
state_manager.register_store(stream_id, "default")
state_manager.on_partition_assign(
- stream_id=stream_id,
- partition=partition_num,
- committed_offsets={},
+ stream_id=stream_id, partition=partition_num
)
store = state_manager.get_store(
stream_id=stream_id, store_name="default"
diff --git a/tests/test_quixstreams/test_dataframe/fixtures.py b/tests/test_quixstreams/test_dataframe/fixtures.py
index 1955c17a3..dbd592d3a 100644
--- a/tests/test_quixstreams/test_dataframe/fixtures.py
+++ b/tests/test_quixstreams/test_dataframe/fixtures.py
@@ -9,6 +9,7 @@
from quixstreams.internal_producer import InternalProducer
from quixstreams.models.topics import Topic, TopicManager
from quixstreams.processing import ProcessingContext
+from quixstreams.processing.watermarking import WatermarkManager
from quixstreams.sinks import SinkManager
from quixstreams.state import StateStoreManager
@@ -37,6 +38,9 @@ def factory(
consumer = MagicMock(spec_set=InternalConsumer)
sink_manager = SinkManager()
registry = registry or default_registry
+ watermark_manager = WatermarkManager(
+ topic_manager=topic_manager, producer=producer
+ )
processing_ctx = ProcessingContext(
producer=producer,
@@ -45,6 +49,7 @@ def factory(
state_manager=state_manager,
sink_manager=sink_manager,
dataframe_registry=registry,
+ watermark_manager=watermark_manager,
)
processing_ctx.init_checkpoint()
diff --git a/tests/test_quixstreams/test_dataframe/test_dataframe.py b/tests/test_quixstreams/test_dataframe/test_dataframe.py
index e94ab42d8..004b827f3 100644
--- a/tests/test_quixstreams/test_dataframe/test_dataframe.py
+++ b/tests/test_quixstreams/test_dataframe/test_dataframe.py
@@ -3,9 +3,8 @@
import re
import uuid
import warnings
-from collections import namedtuple
from datetime import timedelta
-from typing import Any
+from typing import Any, NamedTuple
from unittest import mock
import pytest
@@ -21,7 +20,12 @@
from quixstreams.utils.stream_id import stream_id_from_strings
from tests.utils import DummySink
-RecordStub = namedtuple("RecordStub", ("value", "key", "timestamp"))
+
+class RecordStub(NamedTuple):
+ value: Any
+ key: Any
+ timestamp: int
+ is_watermark: bool = False
class TestStreamingDataFrame:
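
Switching from `namedtuple` to a typed `NamedTuple` adds an `is_watermark` flag with a default, so existing three-argument stubs keep working while the watermark-aware tests can unpack four fields. A small illustration (not part of the diff):

```python
# Regular records default to is_watermark=False, so old call sites still work.
record = RecordStub(1, "test", 1)
assert record.is_watermark is False

# Watermark stubs carry only a timestamp plus the flag.
value, key, timestamp, is_watermark = RecordStub(None, None, 20, is_watermark=True)
assert is_watermark is True
```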
@@ -368,16 +372,27 @@ def test_cannot_use_logical_or(self, dataframe_factory):
with pytest.raises(InvalidOperation):
sdf["truth"] = sdf[sdf.apply(lambda x: x["a"] > 0)] or sdf[["b"]]
- def test_set_timestamp(self, dataframe_factory):
+ def test_set_timestamp(
+ self, dataframe_factory, topic_manager_factory, message_context_factory
+ ):
value, key, timestamp, headers = 1, "key", 0, None
expected = (1, "key", 100, headers)
- sdf = dataframe_factory()
+
+ topic_manager = topic_manager_factory()
+ topic = topic_manager.topic(name=str(uuid.uuid4()))
+ sdf = dataframe_factory(topic)
sdf = sdf.set_timestamp(
lambda value_, key_, timestamp_, headers_: timestamp_ + 100
)
- result = sdf.test(value=value, key=key, timestamp=timestamp, headers=headers)[0]
+ result = sdf.test(
+ value=value,
+ key=key,
+ timestamp=timestamp,
+ headers=headers,
+ ctx=message_context_factory(topic=topic.name),
+ )[0]
assert result == expected
@pytest.mark.parametrize(
@@ -845,9 +860,7 @@ def stateful_func(value_: dict, state: State) -> int:
sdf = dataframe_factory(topic, state_manager=state_manager)
sdf = sdf.apply(stateful_func, stateful=True)
- state_manager.on_partition_assign(
- stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001}
- )
+ state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0)
values = [
{"number": 1},
{"number": 10},
@@ -884,9 +897,7 @@ def stateful_func(value_: dict, state: State):
sdf = dataframe_factory(topic, state_manager=state_manager)
sdf = sdf.update(stateful_func, stateful=True)
- state_manager.on_partition_assign(
- stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001}
- )
+ state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0)
result = None
values = [
{"number": 1},
@@ -924,9 +935,7 @@ def stateful_func(value_: dict, state: State):
sdf = sdf.update(stateful_func, stateful=True)
sdf = sdf.filter(lambda v, state: state.get("max") >= 3, stateful=True)
- state_manager.on_partition_assign(
- stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001}
- )
+ state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0)
values = [
{"number": 1},
{"number": 1},
@@ -965,9 +974,7 @@ def stateful_func(value_: dict, state: State):
sdf = sdf.update(stateful_func, stateful=True)
sdf = sdf[sdf.apply(lambda v, state: state.get("max") >= 3, stateful=True)]
- state_manager.on_partition_assign(
- stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001}
- )
+ state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0)
values = [
{"number": 1},
{"number": 1},
@@ -1036,9 +1043,7 @@ def test_tumbling_window_current(
.current()
)
- state_manager.on_partition_assign(
- stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001}
- )
+ state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0)
records = [
# Message early in the window
@@ -1051,7 +1056,7 @@ def test_tumbling_window_current(
headers = [("key", b"value")]
results = []
- for value, key, timestamp in records:
+ for value, key, timestamp, _ in records:
ctx = message_context_factory(topic=topic.name)
results += sdf.test(
value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx
@@ -1112,14 +1117,14 @@ def on_late(
.current()
)
- state_manager.on_partition_assign(
- stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001}
- )
+ state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0)
records = [
# Create window [0, 10)
RecordStub(1, "test", 1),
# Create window [20,30)
RecordStub(2, "test", 20),
+ # Send watermark at 20
+ RecordStub(None, None, 20, is_watermark=True),
# Late message - it belongs to window [0,10) but this window
# is already closed. This message should be skipped from processing
RecordStub(3, "test", 19),
@@ -1128,10 +1133,15 @@ def on_late(
results = []
with caplog.at_level(logging.WARNING, logger="quixstreams"):
- for value, key, timestamp in records:
+ for value, key, timestamp, is_watermark in records:
ctx = message_context_factory(topic=topic.name)
result = sdf.test(
- value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx
+ value=value,
+ key=key,
+ timestamp=timestamp,
+ headers=headers,
+ is_watermark=is_watermark,
+ ctx=ctx,
)
results += result
@@ -1140,7 +1150,7 @@ def on_late(
r
for r in caplog.records
if r.levelname == "WARNING"
- and "Skipping window processing for the closed window" in r.message
+ and "Skipping record processing for the closed window" in r.message
]
assert warning_logs if should_log else not warning_logs
@@ -1165,56 +1175,43 @@ def test_tumbling_window_final(
sdf = dataframe_factory(topic, state_manager=state_manager)
sdf = sdf.tumbling_window(duration_ms=10, grace_ms=0).sum().final()
- state_manager.on_partition_assign(
- stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001}
- )
+ state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0)
records = [
# Create window [0, 10)
RecordStub(1, "test", 1),
# Update window [0, 10)
RecordStub(1, "test", 2),
- # Create window [20,30). Window [0, 10) is expired now.
+ # Create window [20,30).
RecordStub(2, "test", 20),
- # Create window [30, 40). Window [20, 30) is expired now.
+ # Send watermark at 20. Window [0, 10) is expired now.
+ RecordStub(None, None, 20, is_watermark=True),
+ # Create window [30, 40).
RecordStub(3, "test", 39),
+ # Send watermark at 39. Window [20, 30) is expired now.
+ RecordStub(3, "test", 39, is_watermark=True),
# Update window [30, 40). Nothing should be returned.
RecordStub(4, "test", 38),
]
headers = [("key", b"value")]
results = []
- for value, key, timestamp in records:
+ for value, key, timestamp, is_watermark in records:
ctx = message_context_factory(topic=topic.name)
results += sdf.test(
- value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx
+ value=value,
+ key=key,
+ timestamp=timestamp,
+ headers=headers,
+ is_watermark=is_watermark,
+ ctx=ctx,
)
assert len(results) == 2
assert results == [
- (WindowResult(value=2, start=0, end=10), records[2].key, 0, None),
- (WindowResult(value=2, start=20, end=30), records[3].key, 20, None),
+ (WindowResult(value=2, start=0, end=10), b'"test"', 0, None),
+ (WindowResult(value=2, start=20, end=30), b'"test"', 20, None),
]
- def test_tumbling_window_final_invalid_strategy(
- self,
- dataframe_factory,
- state_manager,
- message_context_factory,
- topic_manager_topic_factory,
- ):
- topic = topic_manager_topic_factory(
- name="test",
- )
-
- sdf = dataframe_factory(topic, state_manager=state_manager)
-
- with pytest.raises(TypeError):
- sdf = (
- sdf.tumbling_window(duration_ms=10, grace_ms=0)
- .sum()
- .final(closing_strategy="foo")
- )
-
def test_tumbling_window_none_key_messages(
self,
dataframe_factory,
@@ -1227,9 +1224,7 @@ def test_tumbling_window_none_key_messages(
sdf = dataframe_factory(topic, state_manager=state_manager)
sdf = sdf.tumbling_window(duration_ms=10).sum().current()
- state_manager.on_partition_assign(
- stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001}
- )
+ state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0)
records = [
# Create window [0,10)
RecordStub(1, "test", 1),
@@ -1241,7 +1236,7 @@ def test_tumbling_window_none_key_messages(
headers = [("key", b"value")]
results = []
- for value, key, timestamp in records:
+ for value, key, timestamp, _ in records:
ctx = message_context_factory(topic=topic.name)
results += sdf.test(
value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx
@@ -1277,9 +1272,7 @@ def test_tumbling_window_two_windows(
.current()
)
- state_manager.on_partition_assign(
- stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001}
- )
+ state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0)
records = [
# Message early in the window
@@ -1292,7 +1285,7 @@ def test_tumbling_window_two_windows(
headers = [("key", b"value")]
results = []
- for value, key, timestamp in records:
+ for value, key, timestamp, _ in records:
ctx = message_context_factory(topic=topic.name)
results += sdf.test(
value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx
@@ -1390,9 +1383,7 @@ def test_hopping_window_current(
sdf = dataframe_factory(topic, state_manager=state_manager)
sdf = sdf.hopping_window(duration_ms=10, step_ms=5).sum().current()
- state_manager.on_partition_assign(
- stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001}
- )
+ state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0)
records = [
# Create window [0,10)
RecordStub(1, "test", 1),
@@ -1408,7 +1399,7 @@ def test_hopping_window_current(
headers = [("key", b"value")]
results = []
- for value, key, timestamp in records:
+ for value, key, timestamp, _ in records:
ctx = message_context_factory(topic=topic.name)
results += sdf.test(
value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx
@@ -1441,36 +1432,41 @@ def test_hopping_window_current_out_of_order_late(
sdf = dataframe_factory(topic, state_manager=state_manager)
sdf = sdf.hopping_window(duration_ms=10, step_ms=5).sum().current()
- state_manager.on_partition_assign(
- stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001}
- )
+ state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0)
records = [
# Create window [0,10)
- RecordStub(1, "test", 1),
+ RecordStub(1, b"test", 1),
# Update window [0,10) and create window [5,15)
- RecordStub(2, "test", 7),
+ RecordStub(2, b"test", 7),
# Create windows [30, 40) and [35, 45)
- RecordStub(4, "test", 35),
+ RecordStub(4, b"test", 35),
+ # Send watermark at 35
+ RecordStub(None, None, 35, is_watermark=True),
# Timestamp "10" is late and should not be processed
- RecordStub(3, "test", 26),
+ RecordStub(3, b"test", 26),
]
headers = [("key", b"value")]
results = []
- for value, key, timestamp in records:
+ for value, key, timestamp, is_watermark in records:
ctx = message_context_factory(topic=topic.name)
results += sdf.test(
- value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx
+ value=value,
+ key=key,
+ timestamp=timestamp,
+ headers=headers,
+ is_watermark=is_watermark,
+ ctx=ctx,
)
assert len(results) == 5
# Ensure that the windows are returned with correct values and order
assert results == [
- (WindowResult(value=1, start=0, end=10), records[0].key, 0, None),
- (WindowResult(value=3, start=0, end=10), records[1].key, 0, None),
- (WindowResult(value=2, start=5, end=15), records[1].key, 5, None),
- (WindowResult(value=4, start=30, end=40), records[2].key, 30, None),
- (WindowResult(value=4, start=35, end=45), records[2].key, 35, None),
+ (WindowResult(value=1, start=0, end=10), b"test", 0, None),
+ (WindowResult(value=3, start=0, end=10), b"test", 0, None),
+ (WindowResult(value=2, start=5, end=15), b"test", 5, None),
+ (WindowResult(value=4, start=30, end=40), b"test", 30, None),
+ (WindowResult(value=4, start=35, end=45), b"test", 35, None),
]
def test_hopping_window_final(
@@ -1485,61 +1481,46 @@ def test_hopping_window_final(
sdf = dataframe_factory(topic, state_manager=state_manager)
sdf = sdf.hopping_window(duration_ms=10, step_ms=5).sum().final()
- state_manager.on_partition_assign(
- stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001}
- )
+ state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0)
records = [
# Create window [0,10)
- RecordStub(1, "test", 1),
+ RecordStub(1, b"test", 1),
# Update window [0,10) and create window [5,15)
- RecordStub(2, "test", 7),
+ RecordStub(2, b"test", 7),
# Update window [5,15) and create window [10,20)
- RecordStub(3, "test", 10),
+ RecordStub(3, b"test", 10),
# Create windows [30, 40) and [35, 45).
+ RecordStub(4, b"test", 35),
+ # Send watermark at 35 to expire windows
# Windows [0,10), [5,15) and [10,20) should be expired
- RecordStub(4, "test", 35),
+ RecordStub(None, None, 35, is_watermark=True),
# Update windows [30, 40) and [35, 45)
- RecordStub(5, "test", 35),
+ RecordStub(5, b"test", 35),
]
headers = [("key", b"value")]
results = []
- for value, key, timestamp in records:
+ for value, key, timestamp, is_watermark in records:
ctx = message_context_factory(topic=topic.name)
results += sdf.test(
- value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx
+ value=value,
+ key=key,
+ timestamp=timestamp,
+ headers=headers,
+ is_watermark=is_watermark,
+ ctx=ctx,
)
assert len(results) == 3
# Ensure that the windows are returned with correct values and order
assert results == [
- (WindowResult(value=3, start=0, end=10), records[2].key, 0, None),
- (WindowResult(value=5, start=5, end=15), records[3].key, 5, None),
- (WindowResult(value=3, start=10, end=20), records[3].key, 10, None),
+ (WindowResult(value=3, start=0, end=10), b"test", 0, None),
+ (WindowResult(value=5, start=5, end=15), b"test", 5, None),
+ (WindowResult(value=3, start=10, end=20), b"test", 10, None),
]
- def test_hopping_window_final_invalid_strategy(
- self,
- dataframe_factory,
- state_manager,
- message_context_factory,
- topic_manager_topic_factory,
- ):
- topic = topic_manager_topic_factory(
- name="test",
- )
-
- sdf = dataframe_factory(topic, state_manager=state_manager)
-
- with pytest.raises(TypeError):
- sdf = (
- sdf.hopping_window(duration_ms=10, step_ms=5)
- .sum()
- .final(closing_strategy="foo")
- )
-
def test_hopping_window_none_key_messages(
self,
dataframe_factory,
@@ -1552,9 +1533,7 @@ def test_hopping_window_none_key_messages(
sdf = dataframe_factory(topic, state_manager=state_manager)
sdf = sdf.hopping_window(duration_ms=10, step_ms=5).sum().current()
- state_manager.on_partition_assign(
- stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001}
- )
+ state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0)
records = [
# Create window [0,10)
RecordStub(1, "test", 1),
@@ -1566,7 +1545,7 @@ def test_hopping_window_none_key_messages(
headers = [("key", b"value")]
results = []
- for value, key, timestamp in records:
+ for value, key, timestamp, _ in records:
ctx = message_context_factory(topic=topic.name)
results += sdf.test(
value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx
@@ -1599,9 +1578,7 @@ def test_sliding_window_current(
.current()
)
- state_manager.on_partition_assign(
- stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001}
- )
+ state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0)
records = [
RecordStub(1, "key", 1000),
@@ -1611,7 +1588,7 @@ def test_sliding_window_current(
headers = [("key", b"value")]
results = []
- for value, key, timestamp in records:
+ for value, key, timestamp, _ in records:
ctx = message_context_factory(topic=topic.name)
results += sdf.test(
value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx
@@ -1672,14 +1649,14 @@ def on_late(
.current()
)
- state_manager.on_partition_assign(
- stream_id=sdf.stream_id, partition=0, committed_offsets={topic.name: -1001}
- )
+ state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0)
records = [
# Create window [0, 1]
RecordStub(1, "test", 1),
# Create window [10,20]
RecordStub(2, "test", 20),
+ # Watermark to expire windows ending before 20
+ RecordStub(None, None, 20, is_watermark=True),
# Late message - it belongs to window [0,5] but this window
# is already closed. This message should be skipped from processing
RecordStub(3, "test", 5),
@@ -1688,10 +1665,15 @@ def on_late(
results = []
with caplog.at_level(logging.WARNING, logger="quixstreams"):
- for value, key, timestamp in records:
+ for value, key, timestamp, is_watermark in records:
ctx = message_context_factory(topic=topic.name)
result = sdf.test(
- value=value, key=key, timestamp=timestamp, headers=headers, ctx=ctx
+ value=value,
+ key=key,
+ timestamp=timestamp,
+ headers=headers,
+ ctx=ctx,
+ is_watermark=is_watermark,
)
results += result
@@ -1700,7 +1682,7 @@ def on_late(
r
for r in caplog.records
if r.levelname == "WARNING"
- and "Skipping window processing for the closed window" in r.message
+ and "Skipping record processing for the closed window" in r.message
]
assert warning_logs if should_log else not warning_logs
@@ -2456,7 +2438,9 @@ def wrapper(value):
assert results == expected
- def test_set_timestamp(self, dataframe_factory):
+ def test_set_timestamp(
+ self, dataframe_factory, topic_manager_factory, message_context_factory
+ ):
"""
"Transform" functions work with split behavior.
"""
@@ -2464,7 +2448,10 @@ def test_set_timestamp(self, dataframe_factory):
def set_ts(n):
return lambda value, key, timestamp, headers: timestamp + n
- sdf = dataframe_factory().apply(add_n(1))
+ topic_manager = topic_manager_factory()
+ topic = topic_manager.topic(str(uuid.uuid4()))
+
+ sdf = dataframe_factory(topic).apply(add_n(1))
sdf2 = sdf.apply(add_n(2)).set_timestamp(set_ts(3)).set_timestamp(set_ts(5)) # noqa: F841
sdf3 = sdf.apply(add_n(3)) # noqa: F841
sdf = sdf.set_timestamp(set_ts(4)).apply(add_n(7))
@@ -2472,7 +2459,10 @@ def set_ts(n):
_extras = {"key": b"key", "timestamp": 0, "headers": []}
extras = list(_extras.values())
expected = [(3, b"key", 8, []), (4, *extras), (8, b"key", 4, [])]
- results = sdf.test(value=0, **_extras)
+
+ results = sdf.test(
+ value=0, ctx=message_context_factory(topic=topic.name), **_extras
+ )
assert results == expected
@@ -2699,9 +2689,7 @@ def accumulate(value: dict, state: State):
sdf_concatenated = sdf1.concat(sdf2).apply(accumulate, stateful=True)
state_manager.on_partition_assign(
- stream_id=sdf_concatenated.stream_id,
- partition=0,
- committed_offsets={},
+ stream_id=sdf_concatenated.stream_id, partition=0
)
key, timestamp, headers = b"key", 0, None
diff --git a/tests/test_quixstreams/test_dataframe/test_joins/fixtures.py b/tests/test_quixstreams/test_dataframe/test_joins/fixtures.py
index ddadaf24f..5445fa4a0 100644
--- a/tests/test_quixstreams/test_dataframe/test_joins/fixtures.py
+++ b/tests/test_quixstreams/test_dataframe/test_joins/fixtures.py
@@ -12,11 +12,7 @@ def _create_sdf(topic):
@pytest.fixture
def assign_partition(state_manager):
def _assign_partition(sdf):
- state_manager.on_partition_assign(
- stream_id=sdf.stream_id,
- partition=0,
- committed_offsets={},
- )
+ state_manager.on_partition_assign(stream_id=sdf.stream_id, partition=0)
return _assign_partition
diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_countwindow.py b/tests/test_quixstreams/test_dataframe/test_windows/test_countwindow.py
new file mode 100644
index 000000000..aeda7933f
--- /dev/null
+++ b/tests/test_quixstreams/test_dataframe/test_windows/test_countwindow.py
@@ -0,0 +1,1417 @@
+from typing import Any
+
+import pytest
+
+import quixstreams.dataframe.windows.aggregations as agg
+from quixstreams.dataframe import DataFrameRegistry
+from quixstreams.dataframe.windows import (
+ HoppingCountWindowDefinition,
+ TumblingCountWindowDefinition,
+)
+from quixstreams.dataframe.windows.count_based import CountWindow
+from quixstreams.state import WindowedPartitionTransaction
+
+
+def process(
+ window: CountWindow,
+ value: Any,
+ key: Any,
+ transaction: WindowedPartitionTransaction,
+ timestamp_ms: int,
+):
+ updated, expired = window.process_window(
+ value=value, key=key, transaction=transaction, timestamp_ms=timestamp_ms
+ )
+
+ return list(updated), list(expired)
+
+
+@pytest.fixture()
+def count_tumbling_window_definition_factory(state_manager, dataframe_factory):
+ def factory(count: int) -> TumblingCountWindowDefinition:
+ sdf = dataframe_factory(
+ state_manager=state_manager, registry=DataFrameRegistry()
+ )
+ window_def = TumblingCountWindowDefinition(dataframe=sdf, count=count)
+ return window_def
+
+ return factory
+
+
+class TestCountTumblingWindow:
+ @pytest.mark.parametrize(
+ "count, name",
+ [
+ (-10, "test"),
+ (0, "test"),
+ (1, "test"),
+ ],
+ )
+ def test_init_invalid(self, count, name, dataframe_factory):
+ with pytest.raises(ValueError):
+ TumblingCountWindowDefinition(
+ count=count,
+ name=name,
+ dataframe=dataframe_factory(),
+ )
+
+ def test_multiaggregation(
+ self,
+ count_tumbling_window_definition_factory,
+ state_manager,
+ ):
+ window = count_tumbling_window_definition_factory(count=2).agg(
+ count=agg.Count(),
+ sum=agg.Sum(),
+ mean=agg.Mean(),
+ max=agg.Max(),
+ min=agg.Min(),
+ collect=agg.Collect(),
+ )
+ window.final()
+ assert window.name == "tumbling_count_window"
+
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ key = b"key"
+ with store.start_partition_transaction(0) as tx:
+ updated, expired = process(
+ window, value=1, key=key, transaction=tx, timestamp_ms=2
+ )
+ assert not expired
+ assert updated == [
+ (
+ key,
+ {
+ "start": 2,
+ "end": 2,
+ "count": 1,
+ "sum": 1,
+ "mean": 1.0,
+ "max": 1,
+ "min": 1,
+ "collect": [],
+ },
+ )
+ ]
+
+ updated, expired = process(
+ window, value=4, key=key, transaction=tx, timestamp_ms=4
+ )
+ assert expired == [
+ (
+ key,
+ {
+ "start": 2,
+ "end": 4,
+ "count": 2,
+ "sum": 5,
+ "mean": 2.5,
+ "max": 4,
+ "min": 1,
+ "collect": [1, 4],
+ },
+ )
+ ]
+ assert updated == [
+ (
+ key,
+ {
+ "start": 2,
+ "end": 4,
+ "count": 2,
+ "sum": 5,
+ "mean": 2.5,
+ "max": 4,
+ "min": 1,
+ "collect": [],
+ },
+ )
+ ]
+
+ updated, expired = process(
+ window, value=2, key=key, transaction=tx, timestamp_ms=12
+ )
+ assert not expired
+ assert updated == [
+ (
+ key,
+ {
+ "start": 12,
+ "end": 12,
+ "count": 1,
+ "sum": 2,
+ "mean": 2.0,
+ "max": 2,
+ "min": 2,
+ "collect": [],
+ },
+ )
+ ]
+
+ # Update window definition
+ # * delete an aggregation (min)
+ # * change an aggregation while keeping its name (mean -> max)
+ # * add new aggregations (sum2, collect2)
+ window = count_tumbling_window_definition_factory(count=2).agg(
+ count=agg.Count(),
+ sum=agg.Sum(),
+ mean=agg.Max(),
+ max=agg.Max(),
+ collect=agg.Collect(),
+ sum2=agg.Sum(),
+ collect2=agg.Collect(),
+ )
+ assert window.name == "tumbling_count_window" # still the same window and store
+ with store.start_partition_transaction(0) as tx:
+ updated, expired = process(
+ window, value=1, key=key, transaction=tx, timestamp_ms=13
+ )
+ assert (
+ expired
+ == [
+ (
+ key,
+ {
+ "start": 12,
+ "end": 13,
+ "count": 2,
+ "sum": 3,
+ "sum2": 1, # sum2 only aggregates the values after the update
+ "mean": 1, # mean was replace by max. The aggregation restarts with the new values.
+ "max": 2,
+ "collect": [2, 1],
+ "collect2": [
+ 2,
+ 1,
+ ], # Collect2 has all the values as they were fully collected before the update
+ },
+ )
+ ]
+ )
+ assert (
+ updated
+ == [
+ (
+ key,
+ {
+ "start": 12,
+ "end": 13,
+ "count": 2,
+ "sum": 3,
+ "sum2": 1, # sum2 only aggregates the values after the update
+ "mean": 1, # mean was replace by max. The aggregation restarts with the new values.
+ "max": 2,
+ "collect": [],
+ "collect2": [],
+ },
+ )
+ ]
+ )
+
+ updated, expired = process(
+ window, value=5, key=key, transaction=tx, timestamp_ms=15
+ )
+ assert not expired
+ assert updated == [
+ (
+ key,
+ {
+ "start": 15,
+ "end": 15,
+ "count": 1,
+ "sum": 5,
+ "sum2": 5,
+ "mean": 5,
+ "max": 5,
+ "collect": [],
+ "collect2": [],
+ },
+ )
+ ]
+
+ def test_count(self, count_tumbling_window_definition_factory, state_manager):
+ window_def = count_tumbling_window_definition_factory(count=10)
+ window = window_def.count()
+ assert window.name == "tumbling_count_window_count"
+
+ window.final()
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ with store.start_partition_transaction(0) as tx:
+ process(window, key="", value=0, transaction=tx, timestamp_ms=100)
+ updated, expired = process(
+ window, key="", value=0, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 2
+ assert not expired
+
+ def test_sum(self, count_tumbling_window_definition_factory, state_manager):
+ window_def = count_tumbling_window_definition_factory(count=10)
+ window = window_def.sum()
+ assert window.name == "tumbling_count_window_sum"
+
+ window.final()
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ with store.start_partition_transaction(0) as tx:
+ process(window, key="", value=2, transaction=tx, timestamp_ms=100)
+ updated, expired = process(
+ window, key="", value=1, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 3
+ assert not expired
+
+ def test_mean(self, count_tumbling_window_definition_factory, state_manager):
+ window_def = count_tumbling_window_definition_factory(count=10)
+ window = window_def.mean()
+ assert window.name == "tumbling_count_window_mean"
+
+ window.final()
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ with store.start_partition_transaction(0) as tx:
+ process(window, key="", value=2, transaction=tx, timestamp_ms=100)
+ updated, expired = process(
+ window, key="", value=1, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 1.5
+ assert not expired
+
+ def test_reduce(self, count_tumbling_window_definition_factory, state_manager):
+ window_def = count_tumbling_window_definition_factory(count=10)
+ window = window_def.reduce(
+ reducer=lambda agg, current: agg + [current],
+ initializer=lambda value: [value],
+ )
+ assert window.name == "tumbling_count_window_reduce"
+
+ window.final()
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ with store.start_partition_transaction(0) as tx:
+ process(window, key="", value=2, transaction=tx, timestamp_ms=100)
+ updated, expired = process(
+ window, key="", value=1, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == [2, 1]
+ assert not expired
+
+ def test_max(self, count_tumbling_window_definition_factory, state_manager):
+ window_def = count_tumbling_window_definition_factory(count=10)
+ window = window_def.max()
+ assert window.name == "tumbling_count_window_max"
+
+ window.final()
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ with store.start_partition_transaction(0) as tx:
+ process(window, key="", value=2, transaction=tx, timestamp_ms=100)
+ updated, expired = process(
+ window, key="", value=1, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 2
+ assert not expired
+
+ def test_min(self, count_tumbling_window_definition_factory, state_manager):
+ window_def = count_tumbling_window_definition_factory(count=10)
+ window = window_def.min()
+ assert window.name == "tumbling_count_window_min"
+
+ window.final()
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ with store.start_partition_transaction(0) as tx:
+ process(window, key="", value=2, transaction=tx, timestamp_ms=100)
+ updated, expired = process(
+ window, key="", value=1, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 1
+ assert not expired
+
+ def test_collect(self, count_tumbling_window_definition_factory, state_manager):
+ window_def = count_tumbling_window_definition_factory(count=3)
+ window = window_def.collect()
+ assert window.name == "tumbling_count_window_collect"
+
+ window.final()
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ with store.start_partition_transaction(0) as tx:
+ process(window, key="", value=1, transaction=tx, timestamp_ms=100)
+ process(window, key="", value=2, transaction=tx, timestamp_ms=100)
+ updated, expired = process(
+ window, key="", value=3, transaction=tx, timestamp_ms=101
+ )
+
+ assert not updated
+ assert expired == [("", {"start": 100, "end": 101, "value": [1, 2, 3]})]
+
+ with store.start_partition_transaction(0) as tx:
+ state = tx.as_state(prefix=b"")
+ remaining_items = state.get_from_collection(start=0, end=1000)
+ assert remaining_items == []
+
+ def test_window_expired(
+ self,
+ count_tumbling_window_definition_factory,
+ state_manager,
+ ):
+ window_def = count_tumbling_window_definition_factory(count=2)
+ window = window_def.sum()
+ window.register_store()
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ with store.start_partition_transaction(0) as tx:
+ # Add first item to the window
+ updated, expired = process(
+ window, key="", value=1, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 1
+ assert updated[0][1]["start"] == 100
+ assert updated[0][1]["end"] == 100
+ assert not expired
+
+ # Now add second item to the window
+ # The window is now expired and should be returned
+ updated, expired = process(
+ window, key="", value=2, transaction=tx, timestamp_ms=110
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 3
+ assert updated[0][1]["start"] == 100
+ assert updated[0][1]["end"] == 110
+
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == 3
+ assert expired[0][1]["start"] == 100
+ assert expired[0][1]["end"] == 110
+
+ def test_multiple_keys_sum(
+ self, count_tumbling_window_definition_factory, state_manager
+ ):
+ window_def = count_tumbling_window_definition_factory(count=3)
+ window = window_def.sum()
+ window.register_store()
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+
+ with store.start_partition_transaction(0) as tx:
+ updated, expired = process(
+ window, key="key1", value=1, transaction=tx, timestamp_ms=100
+ )
+ assert len(expired) == 0
+ assert updated[0][1]["value"] == 1
+ updated, expired = process(
+ window, key="key2", value=5, transaction=tx, timestamp_ms=100
+ )
+ assert len(expired) == 0
+ assert updated[0][1]["value"] == 5
+
+ updated, expired = process(
+ window, key="key1", value=2, transaction=tx, timestamp_ms=110
+ )
+ assert len(expired) == 0
+ assert updated[0][1]["value"] == 3
+ updated, expired = process(
+ window, key="key2", value=4, transaction=tx, timestamp_ms=110
+ )
+ assert len(expired) == 0
+ assert updated[0][1]["value"] == 9
+
+ updated, expired = process(
+ window, key="key1", value=3, transaction=tx, timestamp_ms=120
+ )
+ assert expired[0][1]["value"] == 6
+ assert updated[0][1]["value"] == 6
+
+ updated, expired = process(
+ window, key="key1", value=4, transaction=tx, timestamp_ms=130
+ )
+ assert len(expired) == 0
+ assert updated[0][1]["value"] == 4
+
+ updated, expired = process(
+ window, key="key2", value=3, transaction=tx, timestamp_ms=120
+ )
+ assert expired[0][1]["value"] == 12
+ assert updated[0][1]["value"] == 12
+
+ updated, expired = process(
+ window, key="key2", value=2, transaction=tx, timestamp_ms=130
+ )
+ assert len(expired) == 0
+ assert updated[0][1]["value"] == 2
+ updated, expired = process(
+ window, key="key1", value=5, transaction=tx, timestamp_ms=140
+ )
+ assert len(expired) == 0
+ assert updated[0][1]["value"] == 9
+
+ updated, expired = process(
+ window, key="key2", value=1, transaction=tx, timestamp_ms=140
+ )
+ assert len(expired) == 0
+ assert updated[0][1]["value"] == 3
+
+ def test_multiple_keys_collect(
+ self, count_tumbling_window_definition_factory, state_manager
+ ):
+ window_def = count_tumbling_window_definition_factory(count=3)
+ window = window_def.collect()
+ window.register_store()
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+
+ with store.start_partition_transaction(0) as tx:
+ updated, expired = process(
+ window, key="key1", value=1, transaction=tx, timestamp_ms=100
+ )
+ assert len(expired) == 0
+ assert len(updated) == 0
+ updated, expired = process(
+ window, key="key2", value=5, transaction=tx, timestamp_ms=100
+ )
+ assert len(expired) == 0
+ assert len(updated) == 0
+
+ updated, expired = process(
+ window, key="key1", value=2, transaction=tx, timestamp_ms=110
+ )
+ assert len(expired) == 0
+ assert len(updated) == 0
+ updated, expired = process(
+ window, key="key2", value=4, transaction=tx, timestamp_ms=110
+ )
+ assert len(expired) == 0
+ assert len(updated) == 0
+
+ updated, expired = process(
+ window, key="key1", value=3, transaction=tx, timestamp_ms=120
+ )
+ assert expired[0][1]["value"] == [1, 2, 3]
+ assert len(updated) == 0
+
+ updated, expired = process(
+ window, key="key1", value=4, transaction=tx, timestamp_ms=130
+ )
+ assert len(expired) == 0
+ assert len(updated) == 0
+
+ updated, expired = process(
+ window, key="key2", value=3, transaction=tx, timestamp_ms=120
+ )
+ assert expired[0][1]["value"] == [5, 4, 3]
+ assert len(updated) == 0
+
+ updated, expired = process(
+ window, key="key2", value=2, transaction=tx, timestamp_ms=130
+ )
+ assert len(expired) == 0
+ assert len(updated) == 0
+ updated, expired = process(
+ window, key="key1", value=5, transaction=tx, timestamp_ms=140
+ )
+ assert len(expired) == 0
+ assert len(updated) == 0
+
+ updated, expired = process(
+ window, key="key2", value=1, transaction=tx, timestamp_ms=140
+ )
+ assert len(expired) == 0
+ assert len(updated) == 0
+
+ updated, expired = process(
+ window, key="key2", value=0, transaction=tx, timestamp_ms=130
+ )
+ assert expired[0][1]["value"] == [2, 1, 0]
+ assert len(updated) == 0
+ updated, expired = process(
+ window, key="key1", value=6, transaction=tx, timestamp_ms=140
+ )
+ assert expired[0][1]["value"] == [4, 5, 6]
+ assert len(updated) == 0
+
+
+@pytest.fixture()
+def count_hopping_window_definition_factory(state_manager, dataframe_factory):
+ def factory(count: int, step: int) -> HoppingCountWindowDefinition:
+ sdf = dataframe_factory(
+ state_manager=state_manager, registry=DataFrameRegistry()
+ )
+ window_def = HoppingCountWindowDefinition(dataframe=sdf, count=count, step=step)
+ return window_def
+
+ return factory
+
+
+class TestCountHoppingWindow:
+ @pytest.mark.parametrize(
+ "count, step, name",
+ [
+ (-10, 1, "test"),
+ (0, 1, "test"),
+ (1, 1, "test"),
+ (2, 0, "test"),
+ (2, -1, "test"),
+ ],
+ )
+ def test_init_invalid(self, count, step, name, dataframe_factory):
+ with pytest.raises(ValueError):
+ HoppingCountWindowDefinition(
+ count=count,
+ step=step,
+ name=name,
+ dataframe=dataframe_factory(),
+ )
+
+ def test_multiaggregation(
+ self,
+ count_hopping_window_definition_factory,
+ state_manager,
+ ):
+ window = count_hopping_window_definition_factory(count=3, step=2).agg(
+ count=agg.Count(),
+ sum=agg.Sum(),
+ mean=agg.Mean(),
+ max=agg.Max(),
+ min=agg.Min(),
+ collect=agg.Collect(),
+ )
+ window.final()
+ assert window.name == "hopping_count_window"
+
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ key = b"key"
+ with store.start_partition_transaction(0) as tx:
+ updated, expired = process(
+ window, value=1, key=key, transaction=tx, timestamp_ms=2
+ )
+ assert not expired
+ assert updated == [
+ (
+ key,
+ {
+ "start": 2,
+ "end": 2,
+ "count": 1,
+ "sum": 1,
+ "mean": 1.0,
+ "max": 1,
+ "min": 1,
+ "collect": [],
+ },
+ ),
+ ]
+
+ updated, expired = process(
+ window, value=5, key=key, transaction=tx, timestamp_ms=6
+ )
+ assert not expired
+ assert updated == [
+ (
+ key,
+ {
+ "start": 2,
+ "end": 6,
+ "count": 2,
+ "sum": 6,
+ "mean": 3.0,
+ "max": 5,
+ "min": 1,
+ "collect": [],
+ },
+ ),
+ ]
+
+ updated, expired = process(
+ window, value=3, key=key, transaction=tx, timestamp_ms=12
+ )
+ assert expired == [
+ (
+ key,
+ {
+ "start": 2,
+ "end": 12,
+ "count": 3,
+ "sum": 9,
+ "mean": 3.0,
+ "max": 5,
+ "min": 1,
+ "collect": [1, 5, 3],
+ },
+ ),
+ ]
+ assert updated == [
+ (
+ key,
+ {
+ "start": 2,
+ "end": 12,
+ "count": 3,
+ "sum": 9,
+ "mean": 3,
+ "max": 5,
+ "min": 1,
+ "collect": [],
+ },
+ ),
+ (
+ key,
+ {
+ "start": 12,
+ "end": 12,
+ "count": 1,
+ "sum": 3,
+ "mean": 3,
+ "max": 3,
+ "min": 3,
+ "collect": [],
+ },
+ ),
+ ]
+
+ # Update window definition
+ # * delete an aggregation (min)
+ # * change an aggregation while keeping its name (mean -> max)
+ # * add new aggregations (sum2, collect2)
+ window = count_hopping_window_definition_factory(count=3, step=2).agg(
+ count=agg.Count(),
+ sum=agg.Sum(),
+ mean=agg.Max(),
+ max=agg.Max(),
+ collect=agg.Collect(),
+ sum2=agg.Sum(),
+ collect2=agg.Collect(),
+ )
+ assert window.name == "hopping_count_window" # still the same window and store
+ with store.start_partition_transaction(0) as tx:
+ updated, expired = process(
+ window, value=1, key=key, transaction=tx, timestamp_ms=16
+ )
+ assert not expired
+ assert (
+ updated
+ == [
+ (
+ key,
+ {
+ "start": 12,
+ "end": 16,
+ "count": 2,
+ "sum": 4,
+ "sum2": 1, # sum2 only aggregates the values after the update
+ "mean": 1, # mean was replace by max. The aggregation restarts with the new values.
+ "max": 3,
+ "collect": [],
+ "collect2": [],
+ },
+ ),
+ ]
+ )
+
+ updated, expired = process(
+ window, value=4, key=key, transaction=tx, timestamp_ms=22
+ )
+ assert (
+ expired
+ == [
+ (
+ key,
+ {
+ "start": 12,
+ "end": 22,
+ "count": 3,
+ "sum": 8,
+ "sum2": 5, # sum2 only aggregates the values after the update
+ "mean": 4, # mean was replace by max. The aggregation restarts with the new values.
+ "max": 4,
+ "collect": [3, 1, 4],
+ "collect2": [3, 1, 4],
+ },
+ ),
+ ]
+ )
+ assert (
+ updated
+ == [
+ (
+ key,
+ {
+ "start": 12,
+ "end": 22,
+ "count": 3,
+ "sum": 8,
+ "sum2": 5, # sum2 only aggregates the values after the update
+ "mean": 4, # mean was replace by max. The aggregation restarts with the new values.
+ "max": 4,
+ "collect": [],
+ "collect2": [],
+ },
+ ),
+ (
+ key,
+ {
+ "start": 22,
+ "end": 22,
+ "count": 1,
+ "sum": 4,
+ "sum2": 4,
+ "mean": 4,
+ "max": 4,
+ "collect": [],
+ "collect2": [],
+ },
+ ),
+ ]
+ )
+
+ def test_count(self, count_hopping_window_definition_factory, state_manager):
+ window_def = count_hopping_window_definition_factory(count=4, step=2)
+ window = window_def.count()
+ assert window.name == "hopping_count_window_count"
+
+ window.final()
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ with store.start_partition_transaction(0) as tx:
+ updated, expired = process(
+ window, key="", value=0, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 1
+ assert expired == []
+
+ updated, expired = process(
+ window, key="", value=0, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 2
+ assert expired == []
+
+ updated, expired = process(
+ window, key="", value=0, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 3
+ assert updated[1][1]["value"] == 1
+ assert expired == []
+
+ updated, expired = process(
+ window, key="", value=0, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 4
+ assert updated[1][1]["value"] == 2
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == 4
+
+ updated, expired = process(
+ window, key="", value=0, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 3
+ assert updated[1][1]["value"] == 1
+ assert len(expired) == 0
+
+ updated, expired = process(
+ window, key="", value=0, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 4
+ assert updated[1][1]["value"] == 2
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == 4
+
+ def test_sum(self, count_hopping_window_definition_factory, state_manager):
+ window_def = count_hopping_window_definition_factory(count=4, step=2)
+ window = window_def.sum()
+ assert window.name == "hopping_count_window_sum"
+
+ window.final()
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ with store.start_partition_transaction(0) as tx:
+ updated, expired = process(
+ window, key="", value=1, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 1
+ assert expired == []
+
+ updated, expired = process(
+ window, key="", value=2, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 3 # 1 + 2
+ assert expired == []
+
+ updated, expired = process(
+ window, key="", value=3, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 6 # 1 + 2 + 3
+ assert updated[1][1]["value"] == 3
+ assert expired == []
+
+ updated, expired = process(
+ window, key="", value=4, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 10 # 1 + 2 + 3 + 4
+ assert updated[1][1]["value"] == 7 # 3 + 4
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == 10
+
+ updated, expired = process(
+ window, key="", value=5, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 12 # 3 + 4 + 5
+ assert updated[1][1]["value"] == 5
+ assert len(expired) == 0
+
+ updated, expired = process(
+ window, key="", value=6, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 18 # 3 + 4 + 5 + 6
+ assert updated[1][1]["value"] == 11 # 5 + 6
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == 18
+
+ def test_mean(self, count_hopping_window_definition_factory, state_manager):
+ window_def = count_hopping_window_definition_factory(count=4, step=2)
+ window = window_def.mean()
+ assert window.name == "hopping_count_window_mean"
+
+ window.final()
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ with store.start_partition_transaction(0) as tx:
+ updated, expired = process(
+ window, key="", value=1, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 1
+ assert expired == []
+
+ updated, expired = process(
+ window, key="", value=2, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 1.5 # (1 + 2) / 2
+ assert expired == []
+
+ updated, expired = process(
+ window, key="", value=3, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 2 # (1 + 2 + 3) / 3
+ assert updated[1][1]["value"] == 3
+ assert expired == []
+
+ updated, expired = process(
+ window, key="", value=4, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 2.5 # (1 + 2 + 3 + 4) / 4
+ assert updated[1][1]["value"] == 3.5 # 3 + 4
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == 2.5
+
+ updated, expired = process(
+ window, key="", value=5, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 4 # (3 + 4 + 5) / 3
+ assert updated[1][1]["value"] == 5
+ assert len(expired) == 0
+
+ updated, expired = process(
+ window, key="", value=6, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert (
+ updated[0][1]["value"] == 4.5
+ ) # (3 + 4 + 5 + 6) / 4
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == 4.5
+
+ def test_reduce(self, count_hopping_window_definition_factory, state_manager):
+ window_def = count_hopping_window_definition_factory(count=4, step=2)
+ window = window_def.reduce(
+ reducer=lambda agg, current: agg + [current],
+ initializer=lambda value: [value],
+ )
+ assert window.name == "hopping_count_window_reduce"
+
+ window.final()
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
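+        # The reducer appends each raw value, so the expected lists spell out
+        # exactly which messages belong to each overlapping window.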
+ with store.start_partition_transaction(0) as tx:
+ updated, expired = process(
+ window, key="", value=1, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == [1]
+ assert expired == []
+
+ updated, expired = process(
+ window, key="", value=2, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == [1, 2]
+ assert expired == []
+
+ updated, expired = process(
+ window, key="", value=3, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == [1, 2, 3]
+ assert updated[1][1]["value"] == [3]
+ assert expired == []
+
+ updated, expired = process(
+ window, key="", value=4, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == [1, 2, 3, 4]
+ assert updated[1][1]["value"] == [3, 4]
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == [1, 2, 3, 4]
+
+ updated, expired = process(
+ window, key="", value=5, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == [3, 4, 5]
+ assert updated[1][1]["value"] == [5]
+ assert len(expired) == 0
+
+ updated, expired = process(
+ window, key="", value=6, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == [3, 4, 5, 6]
+ assert updated[1][1]["value"] == [5, 6]
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == [3, 4, 5, 6]
+
+ def test_max(self, count_hopping_window_definition_factory, state_manager):
+ window_def = count_hopping_window_definition_factory(count=4, step=2)
+ window = window_def.max()
+ assert window.name == "hopping_count_window_max"
+
+ window.final()
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ with store.start_partition_transaction(0) as tx:
+ updated, expired = process(
+ window, key="", value=1, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 1
+ assert expired == []
+
+ updated, expired = process(
+ window, key="", value=2, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 2
+ assert expired == []
+
+ updated, expired = process(
+ window, key="", value=4, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 4
+ assert updated[1][1]["value"] == 4
+ assert expired == []
+
+ updated, expired = process(
+ window, key="", value=3, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 4
+ assert updated[1][1]["value"] == 4
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == 4
+
+ updated, expired = process(
+ window, key="", value=5, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 5
+ assert updated[1][1]["value"] == 5
+ assert len(expired) == 0
+
+ updated, expired = process(
+ window, key="", value=6, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 6
+ assert updated[1][1]["value"] == 6
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == 6
+
+ def test_min(self, count_hopping_window_definition_factory, state_manager):
+ window_def = count_hopping_window_definition_factory(count=4, step=2)
+ window = window_def.min()
+ assert window.name == "hopping_count_window_min"
+
+ window.final()
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
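+        # Values arrive out of numeric order so each overlapping window has to
+        # track its own minimum independently.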
+ with store.start_partition_transaction(0) as tx:
+ updated, expired = process(
+ window, key="", value=4, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 4
+ assert expired == []
+
+ updated, expired = process(
+ window, key="", value=2, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 2
+ assert expired == []
+
+ updated, expired = process(
+ window, key="", value=3, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 2
+ assert updated[1][1]["value"] == 3
+ assert expired == []
+
+ updated, expired = process(
+ window, key="", value=5, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 2
+ assert updated[1][1]["value"] == 3
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == 2
+
+ updated, expired = process(
+ window, key="", value=6, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 3
+ assert updated[1][1]["value"] == 6
+ assert len(expired) == 0
+
+ updated, expired = process(
+ window, key="", value=5, transaction=tx, timestamp_ms=100
+ )
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 3
+ assert updated[1][1]["value"] == 5
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == 3
+
+ def test_collect(self, count_hopping_window_definition_factory, state_manager):
+ window_def = count_hopping_window_definition_factory(count=4, step=2)
+ window = window_def.collect()
+ assert window.name == "hopping_count_window_collect"
+
+ window.final()
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
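+        # collect() emits nothing on updates; each window surfaces its full
+        # list only when it expires, and not-yet-expired values stay in the
+        # collection store (checked via get_from_collection below).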
+ with store.start_partition_transaction(0) as tx:
+ updated, expired = process(
+ window, key="", value=1, transaction=tx, timestamp_ms=100
+ )
+ assert updated == expired == []
+
+ updated, expired = process(
+ window, key="", value=2, transaction=tx, timestamp_ms=100
+ )
+ assert updated == expired == []
+
+ updated, expired = process(
+ window, key="", value=3, transaction=tx, timestamp_ms=100
+ )
+ assert updated == expired == []
+
+ updated, expired = process(
+ window, key="", value=4, transaction=tx, timestamp_ms=100
+ )
+ assert updated == []
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == [1, 2, 3, 4]
+
+ updated, expired = process(
+ window, key="", value=5, transaction=tx, timestamp_ms=100
+ )
+ assert updated == expired == []
+
+ updated, expired = process(
+ window, key="", value=6, transaction=tx, timestamp_ms=100
+ )
+ assert updated == []
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == [3, 4, 5, 6]
+
+ with store.start_partition_transaction(0) as tx:
+ state = tx.as_state(prefix="")
+ remaining_items = state.get_from_collection(start=0, end=1000)
+ assert remaining_items == [5, 6]
+
+ def test_unaligned_steps(
+ self, count_hopping_window_definition_factory, state_manager
+ ):
+ window_def = count_hopping_window_definition_factory(count=5, step=2)
+ window = window_def.collect()
+ window.register_store()
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
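+        # count=5, step=2: a window closes on every 2nd message once it has
+        # collected 5 values, so consecutive expirations overlap by 3 values.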
+ with store.start_partition_transaction(0) as tx:
+ updated, expired = process(
+ window, key="", value=1, transaction=tx, timestamp_ms=100
+ )
+ assert updated == expired == []
+
+ updated, expired = process(
+ window, key="", value=2, transaction=tx, timestamp_ms=100
+ )
+ assert updated == expired == []
+
+ updated, expired = process(
+ window, key="", value=3, transaction=tx, timestamp_ms=100
+ )
+ assert updated == expired == []
+
+ updated, expired = process(
+ window, key="", value=4, transaction=tx, timestamp_ms=100
+ )
+ assert updated == expired == []
+
+ updated, expired = process(
+ window, key="", value=5, transaction=tx, timestamp_ms=100
+ )
+ assert updated == []
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == [1, 2, 3, 4, 5]
+
+ updated, expired = process(
+ window, key="", value=6, transaction=tx, timestamp_ms=100
+ )
+ assert updated == expired == []
+
+ updated, expired = process(
+ window, key="", value=7, transaction=tx, timestamp_ms=100
+ )
+ assert updated == []
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == [3, 4, 5, 6, 7]
+
+ updated, expired = process(
+ window, key="", value=8, transaction=tx, timestamp_ms=100
+ )
+ assert updated == expired == []
+
+ updated, expired = process(
+ window, key="", value=9, transaction=tx, timestamp_ms=100
+ )
+ assert updated == []
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == [5, 6, 7, 8, 9]
+
+ updated, expired = process(
+ window, key="", value=10, transaction=tx, timestamp_ms=100
+ )
+ assert updated == expired == []
+
+ updated, expired = process(
+ window, key="", value=11, transaction=tx, timestamp_ms=100
+ )
+ assert updated == []
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == [7, 8, 9, 10, 11]
+
+ with store.start_partition_transaction(0) as tx:
+ state = tx.as_state(prefix="")
+ remaining_items = state.get_from_collection(start=0, end=1000)
+ assert remaining_items == [9, 10, 11]
+
+ def test_multiple_keys_sum(
+ self, count_hopping_window_definition_factory, state_manager
+ ):
+ window_def = count_hopping_window_definition_factory(count=3, step=1)
+ window = window_def.sum()
+ window.register_store()
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+
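+        # count=3, step=1 windows are tracked per message key, so key1 and
+        # key2 accumulate into independent sets of overlapping windows.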
+ with store.start_partition_transaction(0) as tx:
+ updated, expired = process(
+ window, key="key1", value=1, transaction=tx, timestamp_ms=100
+ )
+ assert len(expired) == 0
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 1
+ updated, expired = process(
+ window, key="key2", value=5, transaction=tx, timestamp_ms=100
+ )
+ assert len(expired) == 0
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 5
+
+ updated, expired = process(
+ window, key="key1", value=2, transaction=tx, timestamp_ms=110
+ )
+ assert len(expired) == 0
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 3
+ assert updated[1][1]["value"] == 2
+
+ updated, expired = process(
+ window, key="key2", value=4, transaction=tx, timestamp_ms=110
+ )
+ assert len(expired) == 0
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 9
+ assert updated[1][1]["value"] == 4
+
+ updated, expired = process(
+ window, key="key1", value=3, transaction=tx, timestamp_ms=120
+ )
+ assert expired[0][1]["value"] == 6
+ assert len(updated) == 3
+ assert updated[0][1]["value"] == 6
+ assert updated[1][1]["value"] == 5
+ assert updated[2][1]["value"] == 3
+
+ updated, expired = process(
+ window, key="key1", value=4, transaction=tx, timestamp_ms=130
+ )
+ assert expired[0][1]["value"] == 9
+ assert len(updated) == 3
+ assert updated[0][1]["value"] == 9
+ assert updated[1][1]["value"] == 7
+ assert updated[2][1]["value"] == 4
+
+ updated, expired = process(
+ window, key="key2", value=3, transaction=tx, timestamp_ms=120
+ )
+ assert expired[0][1]["value"] == 12
+ assert len(updated) == 3
+ assert updated[0][1]["value"] == 12
+ assert updated[1][1]["value"] == 7
+ assert updated[2][1]["value"] == 3
+
+ updated, expired = process(
+ window, key="key2", value=2, transaction=tx, timestamp_ms=130
+ )
+ assert expired[0][1]["value"] == 9
+ assert len(updated) == 3
+ assert updated[0][1]["value"] == 9
+ assert updated[1][1]["value"] == 5
+ assert updated[2][1]["value"] == 2
+
+ updated, expired = process(
+ window, key="key1", value=5, transaction=tx, timestamp_ms=140
+ )
+ assert expired[0][1]["value"] == 12
+ assert len(updated) == 3
+ assert updated[0][1]["value"] == 12
+ assert updated[1][1]["value"] == 9
+ assert updated[2][1]["value"] == 5
+
+ updated, expired = process(
+ window, key="key2", value=1, transaction=tx, timestamp_ms=140
+ )
+ assert expired[0][1]["value"] == 6
+ assert len(updated) == 3
+ assert updated[0][1]["value"] == 6
+ assert updated[1][1]["value"] == 3
+ assert updated[2][1]["value"] == 1
+
+ def test_multiple_keys_collect(
+ self, count_hopping_window_definition_factory, state_manager
+ ):
+ window_def = count_hopping_window_definition_factory(count=3, step=1)
+ window = window_def.collect()
+ window.register_store()
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+
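+        # With count=3 and step=1, once a key has seen 3 values every new
+        # message for that key expires one window holding its last 3 values.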
+ with store.start_partition_transaction(0) as tx:
+ updated, expired = process(
+ window, key="key1", value=1, transaction=tx, timestamp_ms=100
+ )
+ assert len(expired) == 0
+ assert len(updated) == 0
+ updated, expired = process(
+ window, key="key2", value=5, transaction=tx, timestamp_ms=100
+ )
+ assert len(expired) == 0
+ assert len(updated) == 0
+
+ updated, expired = process(
+ window, key="key1", value=2, transaction=tx, timestamp_ms=110
+ )
+ assert len(expired) == 0
+ assert len(updated) == 0
+ updated, expired = process(
+ window, key="key2", value=4, transaction=tx, timestamp_ms=110
+ )
+ assert len(expired) == 0
+ assert len(updated) == 0
+
+ updated, expired = process(
+ window, key="key1", value=3, transaction=tx, timestamp_ms=120
+ )
+ assert expired[0][1]["value"] == [1, 2, 3]
+ assert len(updated) == 0
+
+ updated, expired = process(
+ window, key="key1", value=4, transaction=tx, timestamp_ms=130
+ )
+ assert expired[0][1]["value"] == [2, 3, 4]
+ assert len(updated) == 0
+
+ updated, expired = process(
+ window, key="key2", value=3, transaction=tx, timestamp_ms=120
+ )
+ assert expired[0][1]["value"] == [5, 4, 3]
+ assert len(updated) == 0
+
+ updated, expired = process(
+ window, key="key2", value=2, transaction=tx, timestamp_ms=130
+ )
+ assert expired[0][1]["value"] == [4, 3, 2]
+ assert len(updated) == 0
+ updated, expired = process(
+ window, key="key1", value=5, transaction=tx, timestamp_ms=140
+ )
+ assert expired[0][1]["value"] == [3, 4, 5]
+ assert len(updated) == 0
+
+ updated, expired = process(
+ window, key="key2", value=1, transaction=tx, timestamp_ms=140
+ )
+ assert expired[0][1]["value"] == [3, 2, 1]
+ assert len(updated) == 0
+
+ updated, expired = process(
+ window, key="key2", value=0, transaction=tx, timestamp_ms=130
+ )
+ assert expired[0][1]["value"] == [2, 1, 0]
+ assert len(updated) == 0
+ updated, expired = process(
+ window, key="key1", value=6, transaction=tx, timestamp_ms=140
+ )
+ assert expired[0][1]["value"] == [4, 5, 6]
+ assert len(updated) == 0
diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py b/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py
index 6a0b1fd5f..27d630d70 100644
--- a/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py
+++ b/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py
@@ -1,38 +1,447 @@
+import functools
+
import pytest
import quixstreams.dataframe.windows.aggregations as agg
from quixstreams.dataframe import DataFrameRegistry
from quixstreams.dataframe.windows import (
- HoppingCountWindowDefinition,
HoppingTimeWindowDefinition,
)
-from quixstreams.dataframe.windows.time_based import ClosingStrategy
@pytest.fixture()
def hopping_window_definition_factory(state_manager, dataframe_factory):
def factory(
- duration_ms: int, step_ms: int, grace_ms: int = 0
+ duration_ms: int,
+ step_ms: int,
+ grace_ms: int = 0,
+ before_update=None,
+ after_update=None,
) -> HoppingTimeWindowDefinition:
sdf = dataframe_factory(
state_manager=state_manager, registry=DataFrameRegistry()
)
window_def = HoppingTimeWindowDefinition(
- duration_ms=duration_ms, step_ms=step_ms, grace_ms=grace_ms, dataframe=sdf
+ duration_ms=duration_ms,
+ step_ms=step_ms,
+ grace_ms=grace_ms,
+ dataframe=sdf,
+ before_update=before_update,
+ after_update=after_update,
)
return window_def
return factory
-def process(window, value, key, transaction, timestamp_ms):
- updated, expired = window.process_window(
- value=value, key=key, transaction=transaction, timestamp_ms=timestamp_ms
+def process(window, value, key, transaction, timestamp_ms, headers=None):
+ updated, triggered = window.process_window(
+ value=value,
+ key=key,
+ timestamp_ms=timestamp_ms,
+ headers=headers,
+ transaction=transaction,
+ )
+ expired = window.expire_by_partition(
+ transaction=transaction, timestamp_ms=timestamp_ms
)
- return list(updated), list(expired)
+ # Combine triggered windows (from callbacks) with time-expired windows
+ all_expired = list(triggered) + list(expired)
+ return list(updated), all_expired
class TestHoppingWindow:
+ def test_hopping_window_with_after_update_trigger(
+ self, hopping_window_definition_factory, state_manager
+ ):
+ # Define a trigger that expires windows when the sum reaches 100 or more
+ def trigger_on_sum_100(aggregated, value, key, timestamp, headers) -> bool:
+ return aggregated >= 100
+
+ window_def = hopping_window_definition_factory(
+ duration_ms=100, step_ms=50, grace_ms=100, after_update=trigger_on_sum_100
+ )
+ window = window_def.sum()
+ window.final()
+
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ key = b"key"
+
+ with store.start_partition_transaction(0) as tx:
+ _process = functools.partial(
+ process, window=window, key=key, transaction=tx
+ )
+
+ # Step 1: Add value=90 at timestamp 50ms
+ # Creates windows [0, 100) and [50, 150) with sum 90 each
+ updated, expired = _process(value=90, timestamp_ms=50)
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 90
+ assert updated[0][1]["start"] == 0
+ assert updated[0][1]["end"] == 100
+ assert updated[1][1]["value"] == 90
+ assert updated[1][1]["start"] == 50
+ assert updated[1][1]["end"] == 150
+ assert not expired
+
+ # Step 2: Add value=5 at timestamp 110ms
+ # With grace_ms=100, [0, 100) does NOT expire naturally yet
+ # [0, 100): stays 90 (timestamp 110 is outside [0, 100), not updated)
+ # [50, 150): 90 -> 95 (< 100, NOT TRIGGERED)
+ # [100, 200): newly created with sum 5
+ updated, expired = _process(value=5, timestamp_ms=110)
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 95
+ assert updated[0][1]["start"] == 50
+ assert updated[0][1]["end"] == 150
+ assert updated[1][1]["value"] == 5
+ assert updated[1][1]["start"] == 100
+ assert updated[1][1]["end"] == 200
+ # No windows expired (grace period keeps [0, 100) alive)
+ assert not expired
+
+ # Step 3: Add value=5 at timestamp 90ms (late message)
+ # Timestamp 90 belongs to BOTH [0, 100) and [50, 150)
+ # [0, 100): 90 -> 95 (< 100, NOT TRIGGERED)
+ # [50, 150): 95 -> 100 (>= 100, TRIGGERED!)
+ updated, expired = _process(value=5, timestamp_ms=90)
+ # Only [0, 100) remains in updated (not triggered, 95 < 100)
+ # Only [50, 150) was triggered (100 >= 100)
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 95
+ assert updated[0][1]["start"] == 0
+ assert updated[0][1]["end"] == 100
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == 100
+ assert expired[0][1]["start"] == 50
+ assert expired[0][1]["end"] == 150
+
+ def test_hopping_window_with_before_update_trigger(
+ self, hopping_window_definition_factory, state_manager
+ ):
+ """Test that before_update callback works for hopping windows."""
+
+ # Define a trigger that expires windows before adding a value
+ # if the sum would exceed 50
+ def trigger_before_exceeding_50(
+ aggregated, value, key, timestamp, headers
+ ) -> bool:
+ return (aggregated + value) > 50
+
+ window_def = hopping_window_definition_factory(
+ duration_ms=100,
+ step_ms=50,
+ grace_ms=100,
+ before_update=trigger_before_exceeding_50,
+ )
+ window = window_def.sum()
+ window.final()
+
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ key = b"key"
+
+ with store.start_partition_transaction(0) as tx:
+ # Helper to process and return results
+ def _process(value, timestamp_ms):
+ return process(
+ window,
+ value=value,
+ key=key,
+ transaction=tx,
+ timestamp_ms=timestamp_ms,
+ )
+
+ # Step 1: Add value=10 at timestamp 50ms
+ # Belongs to windows [0, 100) and [50, 150) (hopping windows overlap)
+ # Both windows: Sum=10, doesn't exceed 50, no trigger
+ updated, expired = _process(value=10, timestamp_ms=50)
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 10
+ assert updated[0][1]["start"] == 0
+ assert updated[1][1]["value"] == 10
+ assert updated[1][1]["start"] == 50
+ assert not expired
+
+ # Step 2: Add value=20 at timestamp 60ms
+ # Belongs to windows [0, 100) and [50, 150)
+ # Both windows: Sum=30, doesn't exceed 50, no trigger
+ updated, expired = _process(value=20, timestamp_ms=60)
+ assert len(updated) == 2
+ assert updated[0][1]["value"] == 30 # [0, 100)
+ assert updated[1][1]["value"] == 30 # [50, 150)
+ assert not expired
+
+ # Step 3: Add value=25 at timestamp 70ms
+ # Belongs to windows [0, 100) and [50, 150)
+ # Both windows: Sum would be 55 which exceeds 50, should trigger BEFORE adding
+ # Both expired windows should have value=30 (not 55)
+ updated, expired = _process(value=25, timestamp_ms=70)
+ assert not updated
+ assert len(expired) == 2
+ assert expired[0][1]["value"] == 30 # [0, 100) before the update
+ assert expired[0][1]["start"] == 0
+ assert expired[0][1]["end"] == 100
+ assert expired[1][1]["value"] == 30 # [50, 150) before the update
+ assert expired[1][1]["start"] == 50
+ assert expired[1][1]["end"] == 150
+
+ # Step 4: Add value=5 at timestamp 100ms
+ # Belongs to windows [50, 150) and [100, 200)
+ # Window [50, 150) sum=5, doesn't trigger
+ # Window [100, 200) sum=5, doesn't trigger
+ updated, expired = _process(value=5, timestamp_ms=100)
+ assert len(updated) == 2
+ # Results should be for both windows
+ assert not expired
+
+ def test_hopping_window_collect_with_after_update_trigger(
+ self, hopping_window_definition_factory, state_manager
+ ):
+ """Test that after_update callback works with collect for hopping windows."""
+
+ # Define a trigger that expires windows when we collect 3 or more items
+ def trigger_on_count_3(aggregated, value, key, timestamp, headers) -> bool:
+ return len(aggregated) >= 3
+
+ window_def = hopping_window_definition_factory(
+ duration_ms=100, step_ms=50, grace_ms=100, after_update=trigger_on_count_3
+ )
+ window = window_def.collect()
+ window.final()
+
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ key = b"key"
+
+ with store.start_partition_transaction(0) as tx:
+ _process = functools.partial(
+ process, window=window, key=key, transaction=tx
+ )
+
+ # Step 1: Add first value at timestamp 50ms
+ # Creates windows [0, 100) and [50, 150) with 1 item each
+ updated, expired = _process(value=1, timestamp_ms=50)
+ assert not updated # collect doesn't emit on updates
+ assert not expired
+
+ # Step 2: Add second value at timestamp 60ms
+ # Both windows now have 2 items
+ updated, expired = _process(value=2, timestamp_ms=60)
+ assert not updated
+ assert not expired
+
+ # Step 3: Add third value at timestamp 70ms
+ # Both windows now have 3 items - BOTH SHOULD TRIGGER
+ updated, expired = _process(value=3, timestamp_ms=70)
+ assert not updated
+ assert len(expired) == 2
+ # Window [0, 100) triggered
+ assert expired[0][1]["value"] == [1, 2, 3]
+ assert expired[0][1]["start"] == 0
+ assert expired[0][1]["end"] == 100
+ # Window [50, 150) triggered
+ assert expired[1][1]["value"] == [1, 2, 3]
+ assert expired[1][1]["start"] == 50
+ assert expired[1][1]["end"] == 150
+
+ # Step 4: Add fourth value at timestamp 110ms
+ # Timestamp 110 belongs to windows [50, 150) and [100, 200)
+ # Window [50, 150) is "resurrected" because collection values weren't deleted
+ # (for hopping windows, we don't delete collection on trigger to preserve
+ # values for overlapping windows)
+ # Window [50, 150) now has [1, 2, 3, 4] = 4 items - TRIGGERS AGAIN!
+ # Window [100, 200) has [4] = 1 item - doesn't trigger
+ updated, expired = _process(value=4, timestamp_ms=110)
+ assert not updated
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == [1, 2, 3, 4]
+ assert expired[0][1]["start"] == 50
+ assert expired[0][1]["end"] == 150
+
+ def test_hopping_window_collect_with_before_update_trigger(
+ self, hopping_window_definition_factory, state_manager
+ ):
+ """Test that before_update callback works with collect for hopping windows."""
+
+ # Define a trigger that expires windows before adding a value
+ # if the collection would reach 3 or more items
+ def trigger_before_count_3(aggregated, value, key, timestamp, headers) -> bool:
+ # For collect, aggregated is the list of collected values BEFORE adding
+ return len(aggregated) + 1 >= 3
+
+ window_def = hopping_window_definition_factory(
+ duration_ms=100,
+ step_ms=50,
+ grace_ms=100,
+ before_update=trigger_before_count_3,
+ )
+ window = window_def.collect()
+ window.final()
+
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ key = b"key"
+
+ with store.start_partition_transaction(0) as tx:
+ # Helper to process and return results
+ def _process(value, timestamp_ms):
+ return process(
+ window,
+ value=value,
+ key=key,
+ transaction=tx,
+ timestamp_ms=timestamp_ms,
+ )
+
+ # Step 1: Add value=1 at timestamp 50ms
+ # Belongs to windows [0, 100) and [50, 150)
+ # Both windows would have 1 item, no trigger
+ updated, expired = _process(value=1, timestamp_ms=50)
+ assert not updated # collect doesn't emit on updates
+ assert not expired
+
+ # Step 2: Add value=2 at timestamp 60ms
+ # Belongs to windows [0, 100) and [50, 150)
+ # Both windows would have 2 items, no trigger
+ updated, expired = _process(value=2, timestamp_ms=60)
+ assert not updated
+ assert not expired
+
+ # Step 3: Add value=3 at timestamp 70ms
+ # Belongs to windows [0, 100) and [50, 150)
+ # Both windows would have 3 items, triggers BEFORE adding
+ # Both windows should have [1, 2] (not [1, 2, 3])
+ updated, expired = _process(value=3, timestamp_ms=70)
+ assert not updated
+ assert len(expired) == 2
+ # Window [0, 100)
+ assert expired[0][1]["value"] == [1, 2]
+ assert expired[0][1]["start"] == 0
+ assert expired[0][1]["end"] == 100
+ # Window [50, 150)
+ assert expired[1][1]["value"] == [1, 2]
+ assert expired[1][1]["start"] == 50
+ assert expired[1][1]["end"] == 150
+
+ # Step 4: Add value=4 at timestamp 110ms
+ # Belongs to windows [50, 150) and [100, 200)
+ # Window [50, 150) resurrected with [1, 2, 3] - would be 4 items, triggers
+ # Window [100, 200) would have 1 item, no trigger
+ updated, expired = _process(value=4, timestamp_ms=110)
+ assert not updated
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == [1, 2, 3] # Before adding 4
+ assert expired[0][1]["start"] == 50
+ assert expired[0][1]["end"] == 150
+
+ def test_hopping_window_agg_and_collect_with_before_update_trigger(
+ self, hopping_window_definition_factory, state_manager
+ ):
+ """Test before_update with BOTH aggregation and collect for hopping windows.
+
+ This verifies that:
+ 1. The triggered window does NOT include the triggering value in collect
+ 2. The triggering value IS still added to collection storage for future windows
+ 3. The aggregated value is BEFORE the triggering value
+ 4. For hopping windows, overlapping windows share the collection storage
+ """
+
+ # Trigger when count would reach 3
+ def trigger_before_count_3(agg_dict, value, key, timestamp, headers) -> bool:
+ # In multi-aggregation, keys are like 'count/Count', 'sum/Sum'
+ # Find the count aggregation value
+ for k, v in agg_dict.items():
+ if k.startswith("count"):
+ return v + 1 >= 3
+ return False
+
+ window_def = hopping_window_definition_factory(
+ duration_ms=100,
+ step_ms=50,
+ grace_ms=100,
+ before_update=trigger_before_count_3,
+ )
+ window = window_def.agg(count=agg.Count(), sum=agg.Sum(), collect=agg.Collect())
+ window.final()
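+        # agg() emits one dict per window with each aggregation under its
+        # keyword name ("count", "sum", "collect") plus "start" and "end".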
+
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ key = b"key"
+
+ with store.start_partition_transaction(0) as tx:
+ _process = functools.partial(
+ process, window=window, key=key, transaction=tx
+ )
+
+ # Step 1: Add value=1 at timestamp 50ms
+ # Windows [0, 100) and [50, 150) both get count=1
+ updated, expired = _process(value=1, timestamp_ms=50)
+ assert len(updated) == 2
+ assert not expired
+
+ # Step 2: Add value=2 at timestamp 60ms
+ # Both windows get count=2
+ updated, expired = _process(value=2, timestamp_ms=60)
+ assert len(updated) == 2
+ assert not expired
+
+ # Step 3: Add value=3 at timestamp 70ms
+ # Both windows: count would be 3, triggers BEFORE adding
+ updated, expired = _process(value=3, timestamp_ms=70)
+ assert not updated
+ assert len(expired) == 2
+
+ # Window [0, 100)
+ assert expired[0][1]["count"] == 2 # Before the update (not 3)
+ assert expired[0][1]["sum"] == 3 # Before the update (1+2, not 1+2+3)
+ # CRITICAL: collect should NOT include the triggering value (3)
+ assert expired[0][1]["collect"] == [1, 2]
+ assert expired[0][1]["start"] == 0
+ assert expired[0][1]["end"] == 100
+
+ # Window [50, 150)
+ assert expired[1][1]["count"] == 2 # Before the update (not 3)
+ assert expired[1][1]["sum"] == 3 # Before the update (1+2, not 1+2+3)
+ # CRITICAL: collect should NOT include the triggering value (3)
+ assert expired[1][1]["collect"] == [1, 2]
+ assert expired[1][1]["start"] == 50
+ assert expired[1][1]["end"] == 150
+
+ # Step 4: Add value=4 at timestamp 100ms
+ # This belongs to windows [50, 150) and [100, 200)
+ # The triggering value (3) should still be in collection storage
+ updated, expired = _process(value=4, timestamp_ms=100)
+ assert len(updated) == 2
+ assert not expired
+
+            # Step 5: Keep adding values ahead of natural expiration to verify
+            # that the triggering value (3) stayed in the collection storage
+            # and still shows up in later overlapping windows when they expire
+ updated, expired = _process(value=5, timestamp_ms=120)
+ assert len(updated) == 2 # Windows [50,150) resurrected and [100,200)
+ assert not expired
+
+ # Force expiration at timestamp 260 (well past grace period)
+ updated, expired = _process(value=6, timestamp_ms=260)
+ # This should expire windows that existed
+ assert len(expired) >= 1
+
+ # The key point: the triggering value (3) WAS added to collection storage
+ # So any window that overlaps with that timestamp includes it
+ # Verify at least one expired window contains the triggering value
+ found_triggering_value = False
+ for _, window_result in expired:
+ if 3 in window_result["collect"]:
+ found_triggering_value = True
+ break
+ assert (
+ found_triggering_value
+ ), "Triggering value (3) should be in collection storage"
+
@pytest.mark.parametrize(
"duration, grace, step, provided_name, func_name, expected_name",
[
@@ -255,15 +664,14 @@ def test_multiaggregation(
]
)
- @pytest.mark.parametrize("expiration", ("key", "partition"))
def test_hoppingwindow_count(
- self, expiration, hopping_window_definition_factory, state_manager
+ self, hopping_window_definition_factory, state_manager
):
window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5)
window = window_def.count()
assert window.name == "hopping_window_10_5_count"
- window.final(closing_strategy=expiration)
+ window.final()
store = state_manager.get_store(stream_id="test", store_name=window.name)
store.assign_partition(0)
with store.start_partition_transaction(0) as tx:
@@ -282,15 +690,12 @@ def test_hoppingwindow_count(
assert updated[1][1]["end"] == 110
assert not expired
- @pytest.mark.parametrize("expiration", ("key", "partition"))
- def test_hoppingwindow_sum(
- self, expiration, hopping_window_definition_factory, state_manager
- ):
+ def test_hoppingwindow_sum(self, hopping_window_definition_factory, state_manager):
window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5)
window = window_def.sum()
assert window.name == "hopping_window_10_5_sum"
- window.final(closing_strategy=expiration)
+ window.final()
store = state_manager.get_store(stream_id="test", store_name=window.name)
store.assign_partition(0)
with store.start_partition_transaction(0) as tx:
@@ -309,15 +714,12 @@ def test_hoppingwindow_sum(
assert updated[1][1]["end"] == 110
assert not expired
- @pytest.mark.parametrize("expiration", ("key", "partition"))
- def test_hoppingwindow_mean(
- self, expiration, hopping_window_definition_factory, state_manager
- ):
+ def test_hoppingwindow_mean(self, hopping_window_definition_factory, state_manager):
window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5)
window = window_def.mean()
assert window.name == "hopping_window_10_5_mean"
- window.final(closing_strategy=expiration)
+ window.final()
store = state_manager.get_store(stream_id="test", store_name=window.name)
store.assign_partition(0)
with store.start_partition_transaction(0) as tx:
@@ -336,9 +738,8 @@ def test_hoppingwindow_mean(
assert updated[1][1]["end"] == 110
assert not expired
- @pytest.mark.parametrize("expiration", ("key", "partition"))
def test_hoppingwindow_reduce(
- self, expiration, hopping_window_definition_factory, state_manager
+ self, hopping_window_definition_factory, state_manager
):
window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5)
window = window_def.reduce(
@@ -347,7 +748,7 @@ def test_hoppingwindow_reduce(
)
assert window.name == "hopping_window_10_5_reduce"
- window.final(closing_strategy=expiration)
+ window.final()
store = state_manager.get_store(stream_id="test", store_name=window.name)
store.assign_partition(0)
with store.start_partition_transaction(0) as tx:
@@ -365,15 +766,12 @@ def test_hoppingwindow_reduce(
assert updated[1][1]["end"] == 110
assert not expired
- @pytest.mark.parametrize("expiration", ("key", "partition"))
- def test_hoppingwindow_max(
- self, expiration, hopping_window_definition_factory, state_manager
- ):
+ def test_hoppingwindow_max(self, hopping_window_definition_factory, state_manager):
window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5)
window = window_def.max()
assert window.name == "hopping_window_10_5_max"
- window.final(closing_strategy=expiration)
+ window.final()
store = state_manager.get_store(stream_id="test", store_name=window.name)
store.assign_partition(0)
with store.start_partition_transaction(0) as tx:
@@ -391,15 +789,12 @@ def test_hoppingwindow_max(
assert updated[1][1]["end"] == 110
assert not expired
- @pytest.mark.parametrize("expiration", ("key", "partition"))
- def test_hoppingwindow_min(
- self, expiration, hopping_window_definition_factory, state_manager
- ):
+ def test_hoppingwindow_min(self, hopping_window_definition_factory, state_manager):
window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5)
window = window_def.min()
assert window.name == "hopping_window_10_5_min"
- window.final(closing_strategy=expiration)
+ window.final()
store = state_manager.get_store(stream_id="test", store_name=window.name)
store.assign_partition(0)
with store.start_partition_transaction(0) as tx:
@@ -417,15 +812,14 @@ def test_hoppingwindow_min(
assert updated[1][1]["end"] == 110
assert not expired
- @pytest.mark.parametrize("expiration", ("key", "partition"))
def test_hoppingwindow_collect(
- self, expiration, hopping_window_definition_factory, state_manager
+ self, hopping_window_definition_factory, state_manager
):
window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5)
window = window_def.collect()
assert window.name == "hopping_window_10_5_collect"
- window.final(closing_strategy=expiration)
+ window.final()
store = state_manager.get_store(stream_id="test", store_name=window.name)
store.assign_partition(0)
with store.start_partition_transaction(0) as tx:
@@ -466,10 +860,8 @@ def test_hopping_window_def_init_invalid(
dataframe=dataframe_factory(),
)
- @pytest.mark.parametrize("expiration", ("key", "partition"))
def test_hopping_window_process_window_expired(
self,
- expiration,
hopping_window_definition_factory,
state_manager,
):
@@ -477,7 +869,7 @@ def test_hopping_window_process_window_expired(
duration_ms=10, grace_ms=0, step_ms=5
)
window = window_def.sum()
- window.final(closing_strategy=expiration)
+ window.final()
store = state_manager.get_store(stream_id="test", store_name=window.name)
store.assign_partition(0)
key = b"key"
@@ -518,7 +910,7 @@ def test_hopping_partition_expiration(
duration_ms=10, grace_ms=2, step_ms=5
)
window = window_def.sum()
- window.final(closing_strategy="partition")
+ window.final()
store = state_manager.get_store(stream_id="test", store_name=window.name)
store.assign_partition(0)
with store.start_partition_transaction(0) as tx:
@@ -566,965 +958,3 @@ def test_hopping_partition_expiration(
(key1, {"start": 100, "end": 110, "value": 4}),
(key2, {"start": 100, "end": 110, "value": 14}),
]
-
- def test_hopping_key_expiration_to_partition(
- self, hopping_window_definition_factory, state_manager
- ):
- window_def = hopping_window_definition_factory(
- duration_ms=10, grace_ms=0, step_ms=5
- )
- window = window_def.sum()
- window.final(closing_strategy="key")
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
- with store.start_partition_transaction(0) as tx:
- key1 = b"key1"
- key2 = b"key2"
-
- process(window, value=1, key=key1, transaction=tx, timestamp_ms=100)
- process(window, value=1, key=key2, transaction=tx, timestamp_ms=102)
- process(window, value=1, key=key2, transaction=tx, timestamp_ms=105)
- process(window, value=1, key=key1, transaction=tx, timestamp_ms=106)
-
- window._closing_strategy = ClosingStrategy.PARTITION
- with store.start_partition_transaction(0) as tx:
- key3 = b"key3"
-
- process(window, value=1, key=key1, transaction=tx, timestamp_ms=107)
- process(window, value=1, key=key2, transaction=tx, timestamp_ms=108)
- updated, expired = process(
- window, value=1, key=key3, transaction=tx, timestamp_ms=114
- )
-
- assert updated == [
- (key3, {"start": 105, "end": 115, "value": 1}),
- (key3, {"start": 110, "end": 120, "value": 1}),
- ]
- assert expired == [
- (key1, {"start": 100, "end": 110, "value": 3}),
- (key2, {"start": 100, "end": 110, "value": 3}),
- ]
-
- def test_hopping_partition_expiration_to_key(
- self, hopping_window_definition_factory, state_manager
- ):
- window_def = hopping_window_definition_factory(
- duration_ms=10, grace_ms=0, step_ms=5
- )
- window = window_def.sum()
- window.final(closing_strategy="partition")
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
- with store.start_partition_transaction(0) as tx:
- key1 = b"key1"
- key2 = b"key2"
-
- process(window, value=1, key=key1, transaction=tx, timestamp_ms=100)
- process(window, value=1, key=key2, transaction=tx, timestamp_ms=102)
- process(window, value=1, key=key2, transaction=tx, timestamp_ms=105)
- process(window, value=1, key=key1, transaction=tx, timestamp_ms=106)
-
- window._closing_strategy = ClosingStrategy.KEY
- with store.start_partition_transaction(0) as tx:
- key3 = b"key3"
-
- process(window, value=1, key=key1, transaction=tx, timestamp_ms=107)
- process(window, value=1, key=key2, transaction=tx, timestamp_ms=108)
- updated, expired = process(
- window, value=1, key=key3, transaction=tx, timestamp_ms=114
- )
-
- assert updated == [
- (key3, {"start": 105, "end": 115, "value": 1}),
- (key3, {"start": 110, "end": 120, "value": 1}),
- ]
- assert expired == []
-
- updated, expired = process(
- window, value=1, key=key1, transaction=tx, timestamp_ms=116
- )
- assert updated == [
- (key1, {"start": 110, "end": 120, "value": 1}),
- (key1, {"start": 115, "end": 125, "value": 1}),
- ]
- assert expired == [
- (key1, {"start": 100, "end": 110, "value": 3}),
- (key1, {"start": 105, "end": 115, "value": 2}),
- ]
-
-
-@pytest.fixture()
-def count_hopping_window_definition_factory(state_manager, dataframe_factory):
- def factory(count: int, step: int) -> HoppingCountWindowDefinition:
- sdf = dataframe_factory(
- state_manager=state_manager, registry=DataFrameRegistry()
- )
- window_def = HoppingCountWindowDefinition(dataframe=sdf, count=count, step=step)
- return window_def
-
- return factory
-
-
-class TestCountHoppingWindow:
- @pytest.mark.parametrize(
- "count, step, name",
- [
- (-10, 1, "test"),
- (0, 1, "test"),
- (1, 1, "test"),
- (2, 0, "test"),
- (2, -1, "test"),
- ],
- )
- def test_init_invalid(self, count, step, name, dataframe_factory):
- with pytest.raises(ValueError):
- HoppingCountWindowDefinition(
- count=count,
- step=step,
- name=name,
- dataframe=dataframe_factory(),
- )
-
- def test_multiaggregation(
- self,
- count_hopping_window_definition_factory,
- state_manager,
- ):
- window = count_hopping_window_definition_factory(count=3, step=2).agg(
- count=agg.Count(),
- sum=agg.Sum(),
- mean=agg.Mean(),
- max=agg.Max(),
- min=agg.Min(),
- collect=agg.Collect(),
- )
- window.final()
- assert window.name == "hopping_count_window"
-
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
- key = b"key"
- with store.start_partition_transaction(0) as tx:
- updated, expired = process(
- window, value=1, key=key, transaction=tx, timestamp_ms=2
- )
- assert not expired
- assert updated == [
- (
- key,
- {
- "start": 2,
- "end": 2,
- "count": 1,
- "sum": 1,
- "mean": 1.0,
- "max": 1,
- "min": 1,
- "collect": [],
- },
- ),
- ]
-
- updated, expired = process(
- window, value=5, key=key, transaction=tx, timestamp_ms=6
- )
- assert not expired
- assert updated == [
- (
- key,
- {
- "start": 2,
- "end": 6,
- "count": 2,
- "sum": 6,
- "mean": 3.0,
- "max": 5,
- "min": 1,
- "collect": [],
- },
- ),
- ]
-
- updated, expired = process(
- window, value=3, key=key, transaction=tx, timestamp_ms=12
- )
- assert expired == [
- (
- key,
- {
- "start": 2,
- "end": 12,
- "count": 3,
- "sum": 9,
- "mean": 3.0,
- "max": 5,
- "min": 1,
- "collect": [1, 5, 3],
- },
- ),
- ]
- assert updated == [
- (
- key,
- {
- "start": 2,
- "end": 12,
- "count": 3,
- "sum": 9,
- "mean": 3,
- "max": 5,
- "min": 1,
- "collect": [],
- },
- ),
- (
- key,
- {
- "start": 12,
- "end": 12,
- "count": 1,
- "sum": 3,
- "mean": 3,
- "max": 3,
- "min": 3,
- "collect": [],
- },
- ),
- ]
-
- # Update window definition
- # * delete an aggregation (min)
- # * change aggregation but keep the name with new aggregation (mean -> max)
- # * add new aggregations (sum2, collect2)
- window = count_hopping_window_definition_factory(count=3, step=2).agg(
- count=agg.Count(),
- sum=agg.Sum(),
- mean=agg.Max(),
- max=agg.Max(),
- collect=agg.Collect(),
- sum2=agg.Sum(),
- collect2=agg.Collect(),
- )
- assert window.name == "hopping_count_window" # still the same window and store
- with store.start_partition_transaction(0) as tx:
- updated, expired = process(
- window, value=1, key=key, transaction=tx, timestamp_ms=16
- )
- assert not expired
- assert (
- updated
- == [
- (
- key,
- {
- "start": 12,
- "end": 16,
- "count": 2,
- "sum": 4,
- "sum2": 1, # sum2 only aggregates the values after the update
- "mean": 1, # mean was replace by max. The aggregation restarts with the new values.
- "max": 3,
- "collect": [],
- "collect2": [],
- },
- ),
- ]
- )
-
- updated, expired = process(
- window, value=4, key=key, transaction=tx, timestamp_ms=22
- )
- assert (
- expired
- == [
- (
- key,
- {
- "start": 12,
- "end": 22,
- "count": 3,
- "sum": 8,
- "sum2": 5, # sum2 only aggregates the values after the update
- "mean": 4, # mean was replace by max. The aggregation restarts with the new values.
- "max": 4,
- "collect": [3, 1, 4],
- "collect2": [3, 1, 4],
- },
- ),
- ]
- )
- assert (
- updated
- == [
- (
- key,
- {
- "start": 12,
- "end": 22,
- "count": 3,
- "sum": 8,
- "sum2": 5, # sum2 only aggregates the values after the update
- "mean": 4, # mean was replace by max. The aggregation restarts with the new values.
- "max": 4,
- "collect": [],
- "collect2": [],
- },
- ),
- (
- key,
- {
- "start": 22,
- "end": 22,
- "count": 1,
- "sum": 4,
- "sum2": 4,
- "mean": 4,
- "max": 4,
- "collect": [],
- "collect2": [],
- },
- ),
- ]
- )
-
- def test_count(self, count_hopping_window_definition_factory, state_manager):
- window_def = count_hopping_window_definition_factory(count=4, step=2)
- window = window_def.count()
- assert window.name == "hopping_count_window_count"
-
- window.final()
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
- with store.start_partition_transaction(0) as tx:
- updated, expired = process(
- window, key="", value=0, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 1
- assert updated[0][1]["value"] == 1
- assert expired == []
-
- updated, expired = process(
- window, key="", value=0, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 1
- assert updated[0][1]["value"] == 2
- assert expired == []
-
- updated, expired = process(
- window, key="", value=0, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert updated[0][1]["value"] == 3
- assert updated[1][1]["value"] == 1
- assert expired == []
-
- updated, expired = process(
- window, key="", value=0, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert updated[0][1]["value"] == 4
- assert updated[1][1]["value"] == 2
- assert len(expired) == 1
- assert expired[0][1]["value"] == 4
-
- updated, expired = process(
- window, key="", value=0, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert updated[0][1]["value"] == 3
- assert updated[1][1]["value"] == 1
- assert len(expired) == 0
-
- updated, expired = process(
- window, key="", value=0, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert updated[0][1]["value"] == 4
- assert updated[1][1]["value"] == 2
- assert len(expired) == 1
- assert expired[0][1]["value"] == 4
-
- def test_sum(self, count_hopping_window_definition_factory, state_manager):
- window_def = count_hopping_window_definition_factory(count=4, step=2)
- window = window_def.sum()
- assert window.name == "hopping_count_window_sum"
-
- window.final()
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
- with store.start_partition_transaction(0) as tx:
- updated, expired = process(
- window, key="", value=1, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 1
- assert updated[0][1]["value"] == 1
- assert expired == []
-
- updated, expired = process(
- window, key="", value=2, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 1
- assert updated[0][1]["value"] == 3 # 1 + 2
- assert expired == []
-
- updated, expired = process(
- window, key="", value=3, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert updated[0][1]["value"] == 6 # 1 + 2 + 3
- assert updated[1][1]["value"] == 3
- assert expired == []
-
- updated, expired = process(
- window, key="", value=4, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert updated[0][1]["value"] == 10 # 1 + 2 + 3 + 4
- assert updated[1][1]["value"] == 7 # 3 + 4
- assert len(expired) == 1
- assert expired[0][1]["value"] == 10
-
- updated, expired = process(
- window, key="", value=5, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert updated[0][1]["value"] == 12 # 3 + 4 + 5
- assert updated[1][1]["value"] == 5
- assert len(expired) == 0
-
- updated, expired = process(
- window, key="", value=6, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert updated[0][1]["value"] == 18 # 3 + 4 + 5 + 6
- assert updated[1][1]["value"] == 11 # 5 + 6
- assert len(expired) == 1
- assert expired[0][1]["value"] == 18
-
- def test_mean(self, count_hopping_window_definition_factory, state_manager):
- window_def = count_hopping_window_definition_factory(count=4, step=2)
- window = window_def.mean()
- assert window.name == "hopping_count_window_mean"
-
- window.final()
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
- with store.start_partition_transaction(0) as tx:
- updated, expired = process(
- window, key="", value=1, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 1
- assert updated[0][1]["value"] == 1
- assert expired == []
-
- updated, expired = process(
- window, key="", value=2, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 1
- assert updated[0][1]["value"] == 1.5 # (1 + 2) / 2
- assert expired == []
-
- updated, expired = process(
- window, key="", value=3, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert updated[0][1]["value"] == 2 # (1 + 2 + 3) / 3
- assert updated[1][1]["value"] == 3
- assert expired == []
-
- updated, expired = process(
- window, key="", value=4, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert updated[0][1]["value"] == 2.5 # (1 + 2 + 3 + 4) / 4
- assert updated[1][1]["value"] == 3.5 # 3 + 4
- assert len(expired) == 1
- assert expired[0][1]["value"] == 2.5
-
- updated, expired = process(
- window, key="", value=5, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert updated[0][1]["value"] == 4 # (3 + 4 + 5) / 3
- assert updated[1][1]["value"] == 5
- assert len(expired) == 0
-
- updated, expired = process(
- window, key="", value=6, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert (
- updated[0][1]["value"] == 4.5
- ) # (3 # sum2 only aggregates the values after the update + 6) / 2
- assert len(expired) == 1
- assert expired[0][1]["value"] == 4.5
-
- def test_reduce(self, count_hopping_window_definition_factory, state_manager):
- window_def = count_hopping_window_definition_factory(count=4, step=2)
- window = window_def.reduce(
- reducer=lambda agg, current: agg + [current],
- initializer=lambda value: [value],
- )
- assert window.name == "hopping_count_window_reduce"
-
- window.final()
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
- with store.start_partition_transaction(0) as tx:
- updated, expired = process(
- window, key="", value=1, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 1
- assert updated[0][1]["value"] == [1]
- assert expired == []
-
- updated, expired = process(
- window, key="", value=2, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 1
- assert updated[0][1]["value"] == [1, 2]
- assert expired == []
-
- updated, expired = process(
- window, key="", value=3, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert updated[0][1]["value"] == [1, 2, 3]
- assert updated[1][1]["value"] == [3]
- assert expired == []
-
- updated, expired = process(
- window, key="", value=4, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert updated[0][1]["value"] == [1, 2, 3, 4]
- assert updated[1][1]["value"] == [3, 4]
- assert len(expired) == 1
- assert expired[0][1]["value"] == [1, 2, 3, 4]
-
- updated, expired = process(
- window, key="", value=5, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert updated[0][1]["value"] == [3, 4, 5]
- assert updated[1][1]["value"] == [5]
- assert len(expired) == 0
-
- updated, expired = process(
- window, key="", value=6, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert updated[0][1]["value"] == [3, 4, 5, 6]
- assert updated[1][1]["value"] == [5, 6]
- assert len(expired) == 1
- assert expired[0][1]["value"] == [3, 4, 5, 6]
-
- def test_max(self, count_hopping_window_definition_factory, state_manager):
- window_def = count_hopping_window_definition_factory(count=4, step=2)
- window = window_def.max()
- assert window.name == "hopping_count_window_max"
-
- window.final()
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
- with store.start_partition_transaction(0) as tx:
- updated, expired = process(
- window, key="", value=1, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 1
- assert updated[0][1]["value"] == 1
- assert expired == []
-
- updated, expired = process(
- window, key="", value=2, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 1
- assert updated[0][1]["value"] == 2
- assert expired == []
-
- updated, expired = process(
- window, key="", value=4, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert updated[0][1]["value"] == 4
- assert updated[1][1]["value"] == 4
- assert expired == []
-
- updated, expired = process(
- window, key="", value=3, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert updated[0][1]["value"] == 4
- assert updated[1][1]["value"] == 4
- assert len(expired) == 1
- assert expired[0][1]["value"] == 4
-
- updated, expired = process(
- window, key="", value=5, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert updated[0][1]["value"] == 5
- assert updated[1][1]["value"] == 5
- assert len(expired) == 0
-
- updated, expired = process(
- window, key="", value=6, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert updated[0][1]["value"] == 6
- assert updated[1][1]["value"] == 6
- assert len(expired) == 1
- assert expired[0][1]["value"] == 6
-
- def test_min(self, count_hopping_window_definition_factory, state_manager):
- window_def = count_hopping_window_definition_factory(count=4, step=2)
- window = window_def.min()
- assert window.name == "hopping_count_window_min"
-
- window.final()
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
- with store.start_partition_transaction(0) as tx:
- updated, expired = process(
- window, key="", value=4, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 1
- assert updated[0][1]["value"] == 4
- assert expired == []
-
- updated, expired = process(
- window, key="", value=2, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 1
- assert updated[0][1]["value"] == 2
- assert expired == []
-
- updated, expired = process(
- window, key="", value=3, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert updated[0][1]["value"] == 2
- assert updated[1][1]["value"] == 3
- assert expired == []
-
- updated, expired = process(
- window, key="", value=5, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert updated[0][1]["value"] == 2
- assert updated[1][1]["value"] == 3
- assert len(expired) == 1
- assert expired[0][1]["value"] == 2
-
- updated, expired = process(
- window, key="", value=6, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert updated[0][1]["value"] == 3
- assert updated[1][1]["value"] == 6
- assert len(expired) == 0
-
- updated, expired = process(
- window, key="", value=5, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 2
- assert updated[0][1]["value"] == 3
- assert updated[1][1]["value"] == 5
- assert len(expired) == 1
- assert expired[0][1]["value"] == 3
-
- def test_collect(self, count_hopping_window_definition_factory, state_manager):
- window_def = count_hopping_window_definition_factory(count=4, step=2)
- window = window_def.collect()
- assert window.name == "hopping_count_window_collect"
-
- window.final()
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
- with store.start_partition_transaction(0) as tx:
- updated, expired = process(
- window, key="", value=1, transaction=tx, timestamp_ms=100
- )
- assert updated == expired == []
-
- updated, expired = process(
- window, key="", value=2, transaction=tx, timestamp_ms=100
- )
- assert updated == expired == []
-
- updated, expired = process(
- window, key="", value=3, transaction=tx, timestamp_ms=100
- )
- assert updated == expired == []
-
- updated, expired = process(
- window, key="", value=4, transaction=tx, timestamp_ms=100
- )
- assert updated == []
- assert len(expired) == 1
- assert expired[0][1]["value"] == [1, 2, 3, 4]
-
- updated, expired = process(
- window, key="", value=5, transaction=tx, timestamp_ms=100
- )
- assert updated == expired == []
-
- updated, expired = process(
- window, key="", value=6, transaction=tx, timestamp_ms=100
- )
- assert updated == []
- assert len(expired) == 1
- assert expired[0][1]["value"] == [3, 4, 5, 6]
-
- with store.start_partition_transaction(0) as tx:
- state = tx.as_state(prefix="")
- remaining_items = state.get_from_collection(start=0, end=1000)
- assert remaining_items == [5, 6]
-
- def test_unaligned_steps(
- self, count_hopping_window_definition_factory, state_manager
- ):
- window_def = count_hopping_window_definition_factory(count=5, step=2)
- window = window_def.collect()
- window.register_store()
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
- with store.start_partition_transaction(0) as tx:
- updated, expired = process(
- window, key="", value=1, transaction=tx, timestamp_ms=100
- )
- assert updated == expired == []
-
- updated, expired = process(
- window, key="", value=2, transaction=tx, timestamp_ms=100
- )
- assert updated == expired == []
-
- updated, expired = process(
- window, key="", value=3, transaction=tx, timestamp_ms=100
- )
- assert updated == expired == []
-
- updated, expired = process(
- window, key="", value=4, transaction=tx, timestamp_ms=100
- )
- assert updated == expired == []
-
- updated, expired = process(
- window, key="", value=5, transaction=tx, timestamp_ms=100
- )
- assert updated == []
- assert len(expired) == 1
- assert expired[0][1]["value"] == [1, 2, 3, 4, 5]
-
- updated, expired = process(
- window, key="", value=6, transaction=tx, timestamp_ms=100
- )
- assert updated == expired == []
-
- updated, expired = process(
- window, key="", value=7, transaction=tx, timestamp_ms=100
- )
- assert updated == []
- assert len(expired) == 1
- assert expired[0][1]["value"] == [3, 4, 5, 6, 7]
-
- updated, expired = process(
- window, key="", value=8, transaction=tx, timestamp_ms=100
- )
- assert updated == expired == []
-
- updated, expired = process(
- window, key="", value=9, transaction=tx, timestamp_ms=100
- )
- assert updated == []
- assert len(expired) == 1
- assert expired[0][1]["value"] == [5, 6, 7, 8, 9]
-
- updated, expired = process(
- window, key="", value=10, transaction=tx, timestamp_ms=100
- )
- assert updated == expired == []
-
- updated, expired = process(
- window, key="", value=11, transaction=tx, timestamp_ms=100
- )
- assert updated == []
- assert len(expired) == 1
- assert expired[0][1]["value"] == [7, 8, 9, 10, 11]
-
- with store.start_partition_transaction(0) as tx:
- state = tx.as_state(prefix="")
- remaining_items = state.get_from_collection(start=0, end=1000)
- assert remaining_items == [9, 10, 11]
-
- def test_multiple_keys_sum(
- self, count_hopping_window_definition_factory, state_manager
- ):
- window_def = count_hopping_window_definition_factory(count=3, step=1)
- window = window_def.sum()
- window.register_store()
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
-
- with store.start_partition_transaction(0) as tx:
- updated, expired = process(
- window, key="key1", value=1, transaction=tx, timestamp_ms=100
- )
- assert len(expired) == 0
- assert len(updated) == 1
- assert updated[0][1]["value"] == 1
- updated, expired = process(
- window, key="key2", value=5, transaction=tx, timestamp_ms=100
- )
- assert len(expired) == 0
- assert len(updated) == 1
- assert updated[0][1]["value"] == 5
-
- updated, expired = process(
- window, key="key1", value=2, transaction=tx, timestamp_ms=110
- )
- assert len(expired) == 0
- assert len(updated) == 2
- assert updated[0][1]["value"] == 3
- assert updated[1][1]["value"] == 2
-
- updated, expired = process(
- window, key="key2", value=4, transaction=tx, timestamp_ms=110
- )
- assert len(expired) == 0
- assert len(updated) == 2
- assert updated[0][1]["value"] == 9
- assert updated[1][1]["value"] == 4
-
- updated, expired = process(
- window, key="key1", value=3, transaction=tx, timestamp_ms=120
- )
- assert expired[0][1]["value"] == 6
- assert len(updated) == 3
- assert updated[0][1]["value"] == 6
- assert updated[1][1]["value"] == 5
- assert updated[2][1]["value"] == 3
-
- updated, expired = process(
- window, key="key1", value=4, transaction=tx, timestamp_ms=130
- )
- assert expired[0][1]["value"] == 9
- assert len(updated) == 3
- assert updated[0][1]["value"] == 9
- assert updated[1][1]["value"] == 7
- assert updated[2][1]["value"] == 4
-
- updated, expired = process(
- window, key="key2", value=3, transaction=tx, timestamp_ms=120
- )
- assert expired[0][1]["value"] == 12
- assert len(updated) == 3
- assert updated[0][1]["value"] == 12
- assert updated[1][1]["value"] == 7
- assert updated[2][1]["value"] == 3
-
- updated, expired = process(
- window, key="key2", value=2, transaction=tx, timestamp_ms=130
- )
- assert expired[0][1]["value"] == 9
- assert len(updated) == 3
- assert updated[0][1]["value"] == 9
- assert updated[1][1]["value"] == 5
- assert updated[2][1]["value"] == 2
-
- updated, expired = process(
- window, key="key1", value=5, transaction=tx, timestamp_ms=140
- )
- assert expired[0][1]["value"] == 12
- assert len(updated) == 3
- assert updated[0][1]["value"] == 12
- assert updated[1][1]["value"] == 9
- assert updated[2][1]["value"] == 5
-
- updated, expired = process(
- window, key="key2", value=1, transaction=tx, timestamp_ms=140
- )
- assert expired[0][1]["value"] == 6
- assert len(updated) == 3
- assert updated[0][1]["value"] == 6
- assert updated[1][1]["value"] == 3
- assert updated[2][1]["value"] == 1
-
- def test_multiple_keys_collect(
- self, count_hopping_window_definition_factory, state_manager
- ):
- window_def = count_hopping_window_definition_factory(count=3, step=1)
- window = window_def.collect()
- window.register_store()
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
-
- with store.start_partition_transaction(0) as tx:
- updated, expired = process(
- window, key="key1", value=1, transaction=tx, timestamp_ms=100
- )
- assert len(expired) == 0
- assert len(updated) == 0
- updated, expired = process(
- window, key="key2", value=5, transaction=tx, timestamp_ms=100
- )
- assert len(expired) == 0
- assert len(updated) == 0
-
- updated, expired = process(
- window, key="key1", value=2, transaction=tx, timestamp_ms=110
- )
- assert len(expired) == 0
- assert len(updated) == 0
- updated, expired = process(
- window, key="key2", value=4, transaction=tx, timestamp_ms=110
- )
- assert len(expired) == 0
- assert len(updated) == 0
-
- updated, expired = process(
- window, key="key1", value=3, transaction=tx, timestamp_ms=120
- )
- assert expired[0][1]["value"] == [1, 2, 3]
- assert len(updated) == 0
-
- updated, expired = process(
- window, key="key1", value=4, transaction=tx, timestamp_ms=130
- )
- assert expired[0][1]["value"] == [2, 3, 4]
- assert len(updated) == 0
-
- updated, expired = process(
- window, key="key2", value=3, transaction=tx, timestamp_ms=120
- )
- assert expired[0][1]["value"] == [5, 4, 3]
- assert len(updated) == 0
-
- updated, expired = process(
- window, key="key2", value=2, transaction=tx, timestamp_ms=130
- )
- assert expired[0][1]["value"] == [4, 3, 2]
- assert len(updated) == 0
- updated, expired = process(
- window, key="key1", value=5, transaction=tx, timestamp_ms=140
- )
- assert expired[0][1]["value"] == [3, 4, 5]
- assert len(updated) == 0
-
- updated, expired = process(
- window, key="key2", value=1, transaction=tx, timestamp_ms=140
- )
- assert expired[0][1]["value"] == [3, 2, 1]
- assert len(updated) == 0
-
- updated, expired = process(
- window, key="key2", value=0, transaction=tx, timestamp_ms=130
- )
- assert expired[0][1]["value"] == [2, 1, 0]
- assert len(updated) == 0
- updated, expired = process(
- window, key="key1", value=6, transaction=tx, timestamp_ms=140
- )
- assert expired[0][1]["value"] == [4, 5, 6]
- assert len(updated) == 0
diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py b/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py
index fc5ab8eba..22a3d6885 100644
--- a/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py
+++ b/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py
@@ -21,11 +21,20 @@
}
-def process(window, value, key, transaction, timestamp_ms):
- updated, expired = window.process_window(
- value=value, key=key, transaction=transaction, timestamp_ms=timestamp_ms
+def process(window, value, key, transaction, timestamp_ms, headers=None):
+ updated, triggered = window.process_window(
+ value=value,
+ key=key,
+ transaction=transaction,
+ timestamp_ms=timestamp_ms,
+ headers=headers,
)
- return list(updated), list(expired)
+ expired = window.expire_by_partition(
+ transaction=transaction, timestamp_ms=timestamp_ms
+ )
+ # Combine triggered windows (from callbacks) with time-expired windows
+ all_expired = list(triggered) + list(expired)
+ return list(updated), all_expired
@dataclass
@@ -350,8 +359,8 @@ def expected_windows_in_state(self) -> set[tuple[int, int]]:
value=C,
updated=[{"start": 15, "end": 25, "value": [A, B, C]}], # left C
expired=[{"start": 14, "end": 24, "value": [A, B]}], # left B
- deleted=[{"start": 6, "end": 16, "value": [A]}], # left A
present=[
+ {"start": 6, "end": 16, "value": [16, [A]]}, # right A
{"start": 17, "end": 27, "value": [25, [B, C]]}, # right A
{"start": 25, "end": 35, "value": [25, [C]]}, # right B
],
@@ -405,6 +414,7 @@ def expected_windows_in_state(self) -> set[tuple[int, int]]:
value=C,
updated=[{"start": 15, "end": 25, "value": [A, B, C]}], # left C
present=[
+ {"start": 6, "end": 16, "value": [16, [A]]}, # left A
{"start": 14, "end": 24, "value": [24, [A, B]]}, # left B
{"start": 17, "end": 27, "value": [25, [B, C]]}, # right A
{"start": 25, "end": 35, "value": [25, [C]]}, # right B
@@ -652,8 +662,8 @@ def expected_windows_in_state(self) -> set[tuple[int, int]]:
value=D,
updated=[{"start": 16, "end": 26, "value": [A, B, C, D]}], # left D
expired=[{"start": 12, "end": 22, "value": [A, C]}], # left A
- deleted=[{"start": 12, "end": 22, "value": [A, C]}], # left A
present=[
+ {"start": 12, "end": 22, "value": [22, [A, C]]},
{"start": 13, "end": 23, "value": [23, [A, B, C]]}, # left B
{"start": 18, "end": 28, "value": [26, [A, B, D]]}, # right C
{"start": 23, "end": 33, "value": [26, [B, D]]}, # right A
@@ -673,7 +683,7 @@ def expected_windows_in_state(self) -> set[tuple[int, int]]:
# ______________________________________________________________________
# B 20 |---------|
# 20 30
# ^ 9 expiration watermark = 20 - 10 - 0 - 1
# ______________________________________________________________________
# C 5 C
# ^ 9 expiration watermark = 20 - 10 - 0 - 1
@@ -690,12 +700,12 @@ def expected_windows_in_state(self) -> set[tuple[int, int]]:
value=B,
updated=[{"start": 10, "end": 20, "value": [B]}], # left B
expired=[{"start": 0, "end": 1, "value": [A]}], # left A
+ deleted=[{"start": 0, "end": 1, "value": [A]}], # left A
),
Message(
timestamp=5,
value=C,
present=[
- {"start": 0, "end": 1, "value": [1, [A]]},
{"start": 10, "end": 20, "value": [20, [B]]},
],
),
@@ -729,12 +739,12 @@ def expected_windows_in_state(self) -> set[tuple[int, int]]:
value=B,
updated=[{"start": 10, "end": 20, "value": [B]}], # left B
expired=[{"start": 0, "end": 1, "value": [A]}], # left A
+ deleted=[{"start": 0, "end": 1, "value": [A]}],
),
Message(
timestamp=9,
value=C,
present=[
- {"start": 0, "end": 1, "value": [1, [A]]},
{"start": 10, "end": 20, "value": [20, [B]]},
],
),
@@ -949,10 +959,8 @@ def test_sliding_window_reduce(
{"start": 1, "end": 11, "value": [A]},
{"start": 2, "end": 12, "value": [A, B]},
],
- deleted=[
- {"start": 1, "end": 11},
- ],
present=[
+ {"start": 1, "end": 11, "value": [11, None]},
{"start": 2, "end": 12, "value": [12, None]},
{"start": 11, "end": 21, "value": [21, None]},
{"start": 12, "end": 22, "value": [21, None]},
@@ -975,7 +983,7 @@ def test_sliding_window_reduce(
present=[
{"start": 50, "end": 60, "value": [60, None]},
],
- expected_values_in_state=[D],
+ expected_values_in_state=[C, D],
),
]
@@ -1072,7 +1080,21 @@ def test_sliding_window_multiaggregation(
updated, expired = process(
window, value=3, key=key, transaction=tx, timestamp_ms=3
)
- assert not expired
+ assert expired == [
+ (
+ key,
+ {
+ "start": 0,
+ "end": 2,
+ "count": 1,
+ "sum": 1,
+ "mean": 1.0,
+ "max": 1,
+ "min": 1,
+ "collect": [1],
+ },
+ ),
+ ]
assert updated == [
(
key,
@@ -1093,21 +1115,6 @@ def test_sliding_window_multiaggregation(
window, value=5, key=key, transaction=tx, timestamp_ms=11
)
assert expired == [
- (
- key,
- {
- "start": 0,
- "end": 2,
- "count": 1,
- "sum": 1,
- "mean": 1.0,
- "max": 1,
- "min": 1,
- "collect": [
- 1,
- ],
- },
- ),
(
key,
{
diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py b/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py
index 98d9f56c1..959357a51 100644
--- a/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py
+++ b/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py
@@ -3,34 +3,356 @@
import quixstreams.dataframe.windows.aggregations as agg
from quixstreams.dataframe import DataFrameRegistry
from quixstreams.dataframe.windows import (
- TumblingCountWindowDefinition,
TumblingTimeWindowDefinition,
)
-from quixstreams.dataframe.windows.time_based import ClosingStrategy
@pytest.fixture()
def tumbling_window_definition_factory(state_manager, dataframe_factory):
- def factory(duration_ms: int, grace_ms: int = 0) -> TumblingTimeWindowDefinition:
+ def factory(
+ duration_ms: int,
+ grace_ms: int = 0,
+ before_update=None,
+ after_update=None,
+ ) -> TumblingTimeWindowDefinition:
sdf = dataframe_factory(
state_manager=state_manager, registry=DataFrameRegistry()
)
window_def = TumblingTimeWindowDefinition(
- duration_ms=duration_ms, grace_ms=grace_ms, dataframe=sdf
+ duration_ms=duration_ms,
+ grace_ms=grace_ms,
+ dataframe=sdf,
+ before_update=before_update,
+ after_update=after_update,
)
return window_def
return factory
-def process(window, value, key, transaction, timestamp_ms):
- updated, expired = window.process_window(
- value=value, key=key, transaction=transaction, timestamp_ms=timestamp_ms
+def process(window, value, key, transaction, timestamp_ms, headers=None):
+ updated, triggered = window.process_window(
+ value=value,
+ key=key,
+ timestamp_ms=timestamp_ms,
+ headers=headers,
+ transaction=transaction,
)
- return list(updated), list(expired)
+ expired = window.expire_by_partition(
+ transaction=transaction, timestamp_ms=timestamp_ms
+ )
+ # Combine triggered windows (from callbacks) with time-expired windows
+ all_expired = list(triggered) + list(expired)
+ return list(updated), all_expired
class TestTumblingWindow:
+ def test_tumbling_window_with_after_update_trigger(
+ self, tumbling_window_definition_factory, state_manager
+ ):
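+ """Test that after_update callback works and triggers after aggregation."""
+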
+ # Define a trigger that expires the window when the sum reaches 9 or more
+ def trigger_on_sum_9(aggregated, value, key, timestamp, headers) -> bool:
+ return aggregated >= 9
+
+ window_def = tumbling_window_definition_factory(
+ duration_ms=100, grace_ms=0, after_update=trigger_on_sum_9
+ )
+ window = window_def.sum()
+ window.final()
+
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ key = b"key"
+
+ with store.start_partition_transaction(0) as tx:
+ # Add value=2, sum becomes 2, still below 9, should not trigger
+ updated, expired = process(
+ window, value=2, key=key, transaction=tx, timestamp_ms=50
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 2
+ assert not expired
+
+ # Add value=2, sum becomes 4, still below 9, should not trigger
+ updated, expired = process(
+ window, value=2, key=key, transaction=tx, timestamp_ms=60
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 4
+ assert not expired
+
+ # Add value=5, sum becomes 9, should trigger (sum >= 9)
+ updated, expired = process(
+ window, value=5, key=key, transaction=tx, timestamp_ms=70
+ )
+ assert not updated # Window was triggered
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == 9
+ assert expired[0][1]["start"] == 0
+ assert expired[0][1]["end"] == 100
+
+ # Next value should start a new window
+ updated, expired = process(
+ window, value=3, key=key, transaction=tx, timestamp_ms=80
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 3
+ assert not expired
+
+ def test_tumbling_window_with_before_update_trigger(
+ self, tumbling_window_definition_factory, state_manager
+ ):
+ """Test that before_update callback works and triggers before aggregation."""
+
+ # Define a trigger that expires the window before adding a value
+ # if the sum would exceed 10
+ def trigger_before_exceeding_10(
+ aggregated, value, key, timestamp, headers
+ ) -> bool:
+ return (aggregated + value) > 10
+
+ window_def = tumbling_window_definition_factory(
+ duration_ms=100, grace_ms=0, before_update=trigger_before_exceeding_10
+ )
+ window = window_def.sum()
+ window.final()
+
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ key = b"key"
+
+ with store.start_partition_transaction(0) as tx:
+ # Add value=3, sum becomes 3, would not exceed 10, should not trigger
+ updated, expired = process(
+ window, value=3, key=key, transaction=tx, timestamp_ms=50
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 3
+ assert not expired
+
+ # Add value=5, sum becomes 8, would not exceed 10, should not trigger
+ updated, expired = process(
+ window, value=5, key=key, transaction=tx, timestamp_ms=60
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 8
+ assert not expired
+
+ # Add value=3, would make sum 11 which exceeds 10, should trigger BEFORE adding
+ # So the expired window should have value=8 (not 11)
+ updated, expired = process(
+ window, value=3, key=key, transaction=tx, timestamp_ms=70
+ )
+ assert not updated # Window was triggered
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == 8 # Before the update (not 11)
+ assert expired[0][1]["start"] == 0
+ assert expired[0][1]["end"] == 100
+
+ # Next value should start a new window
+ updated, expired = process(
+ window, value=2, key=key, transaction=tx, timestamp_ms=80
+ )
+ assert len(updated) == 1
+ assert updated[0][1]["value"] == 2
+ assert not expired
+
+ def test_tumbling_window_collect_with_after_update_trigger(
+ self, tumbling_window_definition_factory, state_manager
+ ):
+ """Test that after_update callback works with collect."""
+
+ # Define a trigger that expires the window when we collect 3 or more items
+ def trigger_on_count_3(aggregated, value, key, timestamp, headers) -> bool:
+ # For collect, aggregated is the list of collected values
+ return len(aggregated) >= 3
+
+ window_def = tumbling_window_definition_factory(
+ duration_ms=100, grace_ms=0, after_update=trigger_on_count_3
+ )
+ window = window_def.collect()
+ window.final()
+
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ key = b"key"
+
+ with store.start_partition_transaction(0) as tx:
+ # Add first value - should not trigger (count=1)
+ updated, expired = process(
+ window, value=1, key=key, transaction=tx, timestamp_ms=50
+ )
+ assert not updated # collect doesn't emit on updates
+ assert not expired
+
+ # Add second value - should not trigger (count=2)
+ updated, expired = process(
+ window, value=2, key=key, transaction=tx, timestamp_ms=60
+ )
+ assert not updated
+ assert not expired
+
+ # Add third value - should trigger (count=3)
+ updated, expired = process(
+ window, value=3, key=key, transaction=tx, timestamp_ms=70
+ )
+ assert not updated
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == [1, 2, 3]
+ assert expired[0][1]["start"] == 0
+ assert expired[0][1]["end"] == 100
+
+ # Next value at t=80 still belongs to window [0, 100)
+ # Window is "resurrected" because collection values weren't deleted
+ # (we let normal expiration handle cleanup for simplicity)
+ # Window [0, 100) now has [1, 2, 3, 4] = 4 items - TRIGGERS AGAIN
+ updated, expired = process(
+ window, value=4, key=key, transaction=tx, timestamp_ms=80
+ )
+ assert not updated
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == [1, 2, 3, 4]
+ assert expired[0][1]["start"] == 0
+ assert expired[0][1]["end"] == 100
+
+ def test_tumbling_window_collect_with_before_update_trigger(
+ self, tumbling_window_definition_factory, state_manager
+ ):
+ """Test that before_update callback works with collect."""
+
+ # Define a trigger that expires the window before adding a value
+ # if the collection would reach 3 or more items
+ def trigger_before_count_3(aggregated, value, key, timestamp, headers) -> bool:
+ # For collect, aggregated is the list of collected values BEFORE adding the new value
+ return len(aggregated) + 1 >= 3
+
+ window_def = tumbling_window_definition_factory(
+ duration_ms=100, grace_ms=0, before_update=trigger_before_count_3
+ )
+ window = window_def.collect()
+ window.final()
+
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ key = b"key"
+
+ with store.start_partition_transaction(0) as tx:
+ # Add first value - should not trigger (count would be 1)
+ updated, expired = process(
+ window, value=1, key=key, transaction=tx, timestamp_ms=50
+ )
+ assert not updated # collect doesn't emit on updates
+ assert not expired
+
+ # Add second value - should not trigger (count would be 2)
+ updated, expired = process(
+ window, value=2, key=key, transaction=tx, timestamp_ms=60
+ )
+ assert not updated
+ assert not expired
+
+ # Add third value - should trigger BEFORE adding (count would be 3)
+ # Expired window should have [1, 2] (not [1, 2, 3])
+ updated, expired = process(
+ window, value=3, key=key, transaction=tx, timestamp_ms=70
+ )
+ assert not updated
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == [1, 2] # Before adding the third value
+ assert expired[0][1]["start"] == 0
+ assert expired[0][1]["end"] == 100
+
+ # Next value should start accumulating in the same window again
+ # (window was deleted but collection values remain until natural expiration)
+ updated, expired = process(
+ window, value=4, key=key, transaction=tx, timestamp_ms=80
+ )
+ assert not updated
+ # Window [0, 100) is "resurrected" with [1, 2, 3]
+ # Adding value 4 would make it 4 items, triggers again
+ assert len(expired) == 1
+ assert expired[0][1]["value"] == [1, 2, 3] # Before adding 4
+ assert expired[0][1]["start"] == 0
+ assert expired[0][1]["end"] == 100
+
+ def test_tumbling_window_agg_and_collect_with_before_update_trigger(
+ self, tumbling_window_definition_factory, state_manager
+ ):
+ """Test before_update with BOTH aggregation and collect.
+
+ This verifies that:
+ 1. The triggered window does NOT include the triggering value in collect
+ 2. The triggering value IS still added to collection storage for future expirations
+ 3. The aggregated values reflect the state BEFORE the triggering value was added
+ """
+
+ # Trigger when count would reach 3
+ def trigger_before_count_3(agg_dict, value, key, timestamp, headers) -> bool:
+ # In multi-aggregation, keys are like 'count/Count', 'sum/Sum'
+ # Find the count aggregation value
+ for k, v in agg_dict.items():
+ if k.startswith("count"):
+ return v + 1 >= 3
+ return False
+
+ window_def = tumbling_window_definition_factory(
+ duration_ms=100, grace_ms=0, before_update=trigger_before_count_3
+ )
+ window = window_def.agg(count=agg.Count(), sum=agg.Sum(), collect=agg.Collect())
+ window.final()
+
+ store = state_manager.get_store(stream_id="test", store_name=window.name)
+ store.assign_partition(0)
+ key = b"key"
+
+ with store.start_partition_transaction(0) as tx:
+ # Add value=1, count becomes 1
+ updated, expired = process(
+ window, value=1, key=key, transaction=tx, timestamp_ms=50
+ )
+ assert len(updated) == 1
+ assert not expired
+
+ # Add value=2, count becomes 2
+ updated, expired = process(
+ window, value=2, key=key, transaction=tx, timestamp_ms=60
+ )
+ assert len(updated) == 1
+ assert not expired
+
+ # Add value=3, would make count 3
+ # Should trigger BEFORE adding
+ updated, expired = process(
+ window, value=3, key=key, transaction=tx, timestamp_ms=70
+ )
+ assert not updated # Window was triggered
+ assert len(expired) == 1
+
+ assert expired[0][1]["count"] == 2 # Before the update (not 3)
+ assert expired[0][1]["sum"] == 3 # Before the update (1+2, not 1+2+3)
+ # CRITICAL: collect should NOT include the triggering value (3)
+ assert expired[0][1]["collect"] == [1, 2]
+ assert expired[0][1]["start"] == 0
+ assert expired[0][1]["end"] == 100
+
+ # Next value should start a new window
+ # But the triggering value (3) should still be in storage
+ updated, expired = process(
+ window, value=4, key=key, transaction=tx, timestamp_ms=80
+ )
+ assert len(updated) == 1
+ assert not expired
+
+ # Force window expiration to see what was collected
+ updated, expired = process(
+ window, value=5, key=key, transaction=tx, timestamp_ms=110
+ )
+ assert len(expired) == 1
+ # The collection should include the triggering value (3) that was added to storage
+ # even though it wasn't in the triggered window result
+ assert expired[0][1]["collect"] == [1, 2, 3, 4] # All values before t=110
+
@pytest.mark.parametrize(
"duration, grace, provided_name, func_name, expected_name",
[
@@ -70,7 +392,7 @@ def test_multiaggregation(
min=agg.Min(),
collect=agg.Collect(),
)
- window.final(closing_strategy="key")
+ window.final()
assert window.name == "tumbling_window_10"
store = state_manager.get_store(stream_id="test", store_name=window.name)
@@ -232,15 +554,14 @@ def test_multiaggregation(
)
]
- @pytest.mark.parametrize("expiration", ("key", "partition"))
def test_tumblingwindow_count(
- self, expiration, tumbling_window_definition_factory, state_manager
+ self, tumbling_window_definition_factory, state_manager
):
window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5)
window = window_def.count()
assert window.name == "tumbling_window_10_count"
- window.final(closing_strategy=expiration)
+ window.final()
store = state_manager.get_store(stream_id="test", store_name=window.name)
store.assign_partition(0)
with store.start_partition_transaction(0) as tx:
@@ -253,15 +574,14 @@ def test_tumblingwindow_count(
assert updated[0][1]["value"] == 2
assert not expired
- @pytest.mark.parametrize("expiration", ("key", "partition"))
def test_tumblingwindow_sum(
- self, expiration, tumbling_window_definition_factory, state_manager
+ self, tumbling_window_definition_factory, state_manager
):
window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5)
window = window_def.sum()
assert window.name == "tumbling_window_10_sum"
- window.final(closing_strategy=expiration)
+ window.final()
store = state_manager.get_store(stream_id="test", store_name=window.name)
store.assign_partition(0)
with store.start_partition_transaction(0) as tx:
@@ -274,15 +594,14 @@ def test_tumblingwindow_sum(
assert updated[0][1]["value"] == 3
assert not expired
- @pytest.mark.parametrize("expiration", ("key", "partition"))
def test_tumblingwindow_mean(
- self, expiration, tumbling_window_definition_factory, state_manager
+ self, tumbling_window_definition_factory, state_manager
):
window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5)
window = window_def.mean()
assert window.name == "tumbling_window_10_mean"
- window.final(closing_strategy=expiration)
+ window.final()
store = state_manager.get_store(stream_id="test", store_name=window.name)
store.assign_partition(0)
with store.start_partition_transaction(0) as tx:
@@ -295,9 +614,8 @@ def test_tumblingwindow_mean(
assert updated[0][1]["value"] == 1.5
assert not expired
- @pytest.mark.parametrize("expiration", ("key", "partition"))
def test_tumblingwindow_reduce(
- self, expiration, tumbling_window_definition_factory, state_manager
+ self, tumbling_window_definition_factory, state_manager
):
window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5)
window = window_def.reduce(
@@ -306,7 +624,7 @@ def test_tumblingwindow_reduce(
)
assert window.name == "tumbling_window_10_reduce"
- window.final(closing_strategy=expiration)
+ window.final()
store = state_manager.get_store(stream_id="test", store_name=window.name)
store.assign_partition(0)
with store.start_partition_transaction(0) as tx:
@@ -319,15 +637,14 @@ def test_tumblingwindow_reduce(
assert updated[0][1]["value"] == [2, 1]
assert not expired
- @pytest.mark.parametrize("expiration", ("key", "partition"))
def test_tumblingwindow_max(
- self, expiration, tumbling_window_definition_factory, state_manager
+ self, tumbling_window_definition_factory, state_manager
):
window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5)
window = window_def.max()
assert window.name == "tumbling_window_10_max"
- window.final(closing_strategy=expiration)
+ window.final()
store = state_manager.get_store(stream_id="test", store_name=window.name)
store.assign_partition(0)
with store.start_partition_transaction(0) as tx:
@@ -340,15 +657,14 @@ def test_tumblingwindow_max(
assert updated[0][1]["value"] == 2
assert not expired
- @pytest.mark.parametrize("expiration", ("key", "partition"))
def test_tumblingwindow_min(
- self, expiration, tumbling_window_definition_factory, state_manager
+ self, tumbling_window_definition_factory, state_manager
):
window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5)
window = window_def.min()
assert window.name == "tumbling_window_10_min"
- window.final(closing_strategy=expiration)
+ window.final()
store = state_manager.get_store(stream_id="test", store_name=window.name)
store.assign_partition(0)
with store.start_partition_transaction(0) as tx:
@@ -361,15 +677,14 @@ def test_tumblingwindow_min(
assert updated[0][1]["value"] == 1
assert not expired
- @pytest.mark.parametrize("expiration", ("key", "partition"))
def test_tumblingwindow_collect(
- self, expiration, tumbling_window_definition_factory, state_manager
+ self, tumbling_window_definition_factory, state_manager
):
window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5)
window = window_def.collect()
assert window.name == "tumbling_window_10_collect"
- window.final(closing_strategy=expiration)
+ window.final()
store = state_manager.get_store(stream_id="test", store_name=window.name)
store.assign_partition(0)
with store.start_partition_transaction(0) as tx:
@@ -401,16 +716,14 @@ def test_tumbling_window_def_init_invalid(
dataframe=dataframe_factory(),
)
- @pytest.mark.parametrize("expiration", ("key", "partition"))
def test_tumbling_window_process_window_expired(
self,
- expiration,
tumbling_window_definition_factory,
state_manager,
):
window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=0)
window = window_def.sum()
- window.final(closing_strategy=expiration)
+ window.final()
store = state_manager.get_store(stream_id="test", store_name=window.name)
store.assign_partition(0)
with store.start_partition_transaction(0) as tx:
@@ -445,7 +758,7 @@ def test_tumbling_partition_expiration(
):
window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=2)
window = window_def.sum()
- window.final(closing_strategy="partition")
+ window.final()
store = state_manager.get_store(stream_id="test", store_name=window.name)
store.assign_partition(0)
with store.start_partition_transaction(0) as tx:
@@ -491,588 +804,3 @@ def test_tumbling_partition_expiration(
(key1, {"start": 100, "end": 110, "value": 4}),
(key2, {"start": 100, "end": 110, "value": 14}),
]
-
- def test_tumbling_key_expiration_to_partition(
- self, tumbling_window_definition_factory, state_manager
- ):
- window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=0)
- window = window_def.sum()
- window.final(closing_strategy="key")
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
- with store.start_partition_transaction(0) as tx:
- key1 = b"key1"
- key2 = b"key2"
-
- process(window, value=1, key=key1, transaction=tx, timestamp_ms=100)
- process(window, value=1, key=key2, transaction=tx, timestamp_ms=102)
- process(window, value=1, key=key2, transaction=tx, timestamp_ms=105)
- process(window, value=1, key=key1, transaction=tx, timestamp_ms=106)
-
- window._closing_strategy = ClosingStrategy.PARTITION
- with store.start_partition_transaction(0) as tx:
- key3 = b"key3"
-
- process(window, value=1, key=key1, transaction=tx, timestamp_ms=107)
- process(window, value=1, key=key2, transaction=tx, timestamp_ms=108)
- updated, expired = process(
- window, value=1, key=key3, transaction=tx, timestamp_ms=115
- )
-
- assert updated == [
- (key3, {"start": 110, "end": 120, "value": 1}),
- ]
- assert expired == [
- (key1, {"start": 100, "end": 110, "value": 3}),
- (key2, {"start": 100, "end": 110, "value": 3}),
- ]
-
- def test_tumbling_partition_expiration_to_key(
- self, tumbling_window_definition_factory, state_manager
- ):
- window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=0)
- window = window_def.sum()
- window.final(closing_strategy="partition")
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
- with store.start_partition_transaction(0) as tx:
- key1 = b"key1"
- key2 = b"key2"
-
- process(window, value=1, key=key1, transaction=tx, timestamp_ms=100)
- process(window, value=1, key=key2, transaction=tx, timestamp_ms=102)
- process(window, value=1, key=key2, transaction=tx, timestamp_ms=105)
- process(window, value=1, key=key1, transaction=tx, timestamp_ms=106)
-
- window._closing_strategy = ClosingStrategy.KEY
- with store.start_partition_transaction(0) as tx:
- key3 = b"key3"
-
- process(window, value=1, key=key1, transaction=tx, timestamp_ms=107)
- process(window, value=1, key=key2, transaction=tx, timestamp_ms=108)
- updated, expired = process(
- window, value=1, key=key3, transaction=tx, timestamp_ms=115
- )
-
- assert updated == [(key3, {"start": 110, "end": 120, "value": 1})]
- assert expired == []
-
- updated, expired = process(
- window, value=1, key=key1, transaction=tx, timestamp_ms=116
- )
- assert updated == [(key1, {"start": 110, "end": 120, "value": 1})]
- assert expired == [(key1, {"start": 100, "end": 110, "value": 3})]
-
-
-@pytest.fixture()
-def count_tumbling_window_definition_factory(state_manager, dataframe_factory):
- def factory(count: int) -> TumblingCountWindowDefinition:
- sdf = dataframe_factory(
- state_manager=state_manager, registry=DataFrameRegistry()
- )
- window_def = TumblingCountWindowDefinition(dataframe=sdf, count=count)
- return window_def
-
- return factory
-
-
-class TestCountTumblingWindow:
- @pytest.mark.parametrize(
- "count, name",
- [
- (-10, "test"),
- (0, "test"),
- (1, "test"),
- ],
- )
- def test_init_invalid(self, count, name, dataframe_factory):
- with pytest.raises(ValueError):
- TumblingCountWindowDefinition(
- count=count,
- name=name,
- dataframe=dataframe_factory(),
- )
-
- def test_multiaggregation(
- self,
- count_tumbling_window_definition_factory,
- state_manager,
- ):
- window = count_tumbling_window_definition_factory(count=2).agg(
- count=agg.Count(),
- sum=agg.Sum(),
- mean=agg.Mean(),
- max=agg.Max(),
- min=agg.Min(),
- collect=agg.Collect(),
- )
- window.final()
- assert window.name == "tumbling_count_window"
-
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
- key = b"key"
- with store.start_partition_transaction(0) as tx:
- updated, expired = process(
- window, value=1, key=key, transaction=tx, timestamp_ms=2
- )
- assert not expired
- assert updated == [
- (
- key,
- {
- "start": 2,
- "end": 2,
- "count": 1,
- "sum": 1,
- "mean": 1.0,
- "max": 1,
- "min": 1,
- "collect": [],
- },
- )
- ]
-
- updated, expired = process(
- window, value=4, key=key, transaction=tx, timestamp_ms=4
- )
- assert expired == [
- (
- key,
- {
- "start": 2,
- "end": 4,
- "count": 2,
- "sum": 5,
- "mean": 2.5,
- "max": 4,
- "min": 1,
- "collect": [1, 4],
- },
- )
- ]
- assert updated == [
- (
- key,
- {
- "start": 2,
- "end": 4,
- "count": 2,
- "sum": 5,
- "mean": 2.5,
- "max": 4,
- "min": 1,
- "collect": [],
- },
- )
- ]
-
- updated, expired = process(
- window, value=2, key=key, transaction=tx, timestamp_ms=12
- )
- assert not expired
- assert updated == [
- (
- key,
- {
- "start": 12,
- "end": 12,
- "count": 1,
- "sum": 2,
- "mean": 2.0,
- "max": 2,
- "min": 2,
- "collect": [],
- },
- )
- ]
-
- # Update window definition
- # * delete an aggregation (min)
- # * change aggregation but keep the name with new aggregation (mean -> max)
- # * add new aggregations (sum2, collect2)
- window = count_tumbling_window_definition_factory(count=2).agg(
- count=agg.Count(),
- sum=agg.Sum(),
- mean=agg.Max(),
- max=agg.Max(),
- collect=agg.Collect(),
- sum2=agg.Sum(),
- collect2=agg.Collect(),
- )
- assert window.name == "tumbling_count_window" # still the same window and store
- with store.start_partition_transaction(0) as tx:
- updated, expired = process(
- window, value=1, key=key, transaction=tx, timestamp_ms=13
- )
- assert (
- expired
- == [
- (
- key,
- {
- "start": 12,
- "end": 13,
- "count": 2,
- "sum": 3,
- "sum2": 1, # sum2 only aggregates the values after the update
- "mean": 1, # mean was replace by max. The aggregation restarts with the new values.
- "max": 2,
- "collect": [2, 1],
- "collect2": [
- 2,
- 1,
- ], # Collect2 has all the values as they were fully collected before the update
- },
- )
- ]
- )
- assert (
- updated
- == [
- (
- key,
- {
- "start": 12,
- "end": 13,
- "count": 2,
- "sum": 3,
- "sum2": 1, # sum2 only aggregates the values after the update
- "mean": 1, # mean was replace by max. The aggregation restarts with the new values.
- "max": 2,
- "collect": [],
- "collect2": [],
- },
- )
- ]
- )
-
- updated, expired = process(
- window, value=5, key=key, transaction=tx, timestamp_ms=15
- )
- assert not expired
- assert updated == [
- (
- key,
- {
- "start": 15,
- "end": 15,
- "count": 1,
- "sum": 5,
- "sum2": 5,
- "mean": 5,
- "max": 5,
- "collect": [],
- "collect2": [],
- },
- )
- ]
-
- def test_count(self, count_tumbling_window_definition_factory, state_manager):
- window_def = count_tumbling_window_definition_factory(count=10)
- window = window_def.count()
- assert window.name == "tumbling_count_window_count"
-
- window.final()
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
- with store.start_partition_transaction(0) as tx:
- process(window, key="", value=0, transaction=tx, timestamp_ms=100)
- updated, expired = process(
- window, key="", value=0, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 1
- assert updated[0][1]["value"] == 2
- assert not expired
-
- def test_sum(self, count_tumbling_window_definition_factory, state_manager):
- window_def = count_tumbling_window_definition_factory(count=10)
- window = window_def.sum()
- assert window.name == "tumbling_count_window_sum"
-
- window.final()
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
- with store.start_partition_transaction(0) as tx:
- process(window, key="", value=2, transaction=tx, timestamp_ms=100)
- updated, expired = process(
- window, key="", value=1, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 1
- assert updated[0][1]["value"] == 3
- assert not expired
-
- def test_mean(self, count_tumbling_window_definition_factory, state_manager):
- window_def = count_tumbling_window_definition_factory(count=10)
- window = window_def.mean()
- assert window.name == "tumbling_count_window_mean"
-
- window.final()
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
- with store.start_partition_transaction(0) as tx:
- process(window, key="", value=2, transaction=tx, timestamp_ms=100)
- updated, expired = process(
- window, key="", value=1, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 1
- assert updated[0][1]["value"] == 1.5
- assert not expired
-
- def test_reduce(self, count_tumbling_window_definition_factory, state_manager):
- window_def = count_tumbling_window_definition_factory(count=10)
- window = window_def.reduce(
- reducer=lambda agg, current: agg + [current],
- initializer=lambda value: [value],
- )
- assert window.name == "tumbling_count_window_reduce"
-
- window.final()
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
- with store.start_partition_transaction(0) as tx:
- process(window, key="", value=2, transaction=tx, timestamp_ms=100)
- updated, expired = process(
- window, key="", value=1, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 1
- assert updated[0][1]["value"] == [2, 1]
- assert not expired
-
- def test_max(self, count_tumbling_window_definition_factory, state_manager):
- window_def = count_tumbling_window_definition_factory(count=10)
- window = window_def.max()
- assert window.name == "tumbling_count_window_max"
-
- window.final()
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
- with store.start_partition_transaction(0) as tx:
- process(window, key="", value=2, transaction=tx, timestamp_ms=100)
- updated, expired = process(
- window, key="", value=1, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 1
- assert updated[0][1]["value"] == 2
- assert not expired
-
- def test_min(self, count_tumbling_window_definition_factory, state_manager):
- window_def = count_tumbling_window_definition_factory(count=10)
- window = window_def.min()
- assert window.name == "tumbling_count_window_min"
-
- window.final()
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
- with store.start_partition_transaction(0) as tx:
- process(window, key="", value=2, transaction=tx, timestamp_ms=100)
- updated, expired = process(
- window, key="", value=1, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 1
- assert updated[0][1]["value"] == 1
- assert not expired
-
- def test_collect(self, count_tumbling_window_definition_factory, state_manager):
- window_def = count_tumbling_window_definition_factory(count=3)
- window = window_def.collect()
- assert window.name == "tumbling_count_window_collect"
-
- window.final()
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
- with store.start_partition_transaction(0) as tx:
- process(window, key="", value=1, transaction=tx, timestamp_ms=100)
- process(window, key="", value=2, transaction=tx, timestamp_ms=100)
- updated, expired = process(
- window, key="", value=3, transaction=tx, timestamp_ms=101
- )
-
- assert not updated
- assert expired == [("", {"start": 100, "end": 101, "value": [1, 2, 3]})]
-
- with store.start_partition_transaction(0) as tx:
- state = tx.as_state(prefix=b"")
- remaining_items = state.get_from_collection(start=0, end=1000)
- assert remaining_items == []
-
- def test_window_expired(
- self,
- count_tumbling_window_definition_factory,
- state_manager,
- ):
- window_def = count_tumbling_window_definition_factory(count=2)
- window = window_def.sum()
- window.register_store()
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
- with store.start_partition_transaction(0) as tx:
- # Add first item to the window
- updated, expired = process(
- window, key="", value=1, transaction=tx, timestamp_ms=100
- )
- assert len(updated) == 1
- assert updated[0][1]["value"] == 1
- assert updated[0][1]["start"] == 100
- assert updated[0][1]["end"] == 100
- assert not expired
-
- # Now add second item to the window
- # The window is now expired and should be returned
- updated, expired = process(
- window, key="", value=2, transaction=tx, timestamp_ms=110
- )
- assert len(updated) == 1
- assert updated[0][1]["value"] == 3
- assert updated[0][1]["start"] == 100
- assert updated[0][1]["end"] == 110
-
- assert len(expired) == 1
- assert expired[0][1]["value"] == 3
- assert expired[0][1]["start"] == 100
- assert expired[0][1]["end"] == 110
-
- def test_multiple_keys_sum(
- self, count_tumbling_window_definition_factory, state_manager
- ):
- window_def = count_tumbling_window_definition_factory(count=3)
- window = window_def.sum()
- window.register_store()
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
-
- with store.start_partition_transaction(0) as tx:
- updated, expired = process(
- window, key="key1", value=1, transaction=tx, timestamp_ms=100
- )
- assert len(expired) == 0
- assert updated[0][1]["value"] == 1
- updated, expired = process(
- window, key="key2", value=5, transaction=tx, timestamp_ms=100
- )
- assert len(expired) == 0
- assert updated[0][1]["value"] == 5
-
- updated, expired = process(
- window, key="key1", value=2, transaction=tx, timestamp_ms=110
- )
- assert len(expired) == 0
- assert updated[0][1]["value"] == 3
- updated, expired = process(
- window, key="key2", value=4, transaction=tx, timestamp_ms=110
- )
- assert len(expired) == 0
- assert updated[0][1]["value"] == 9
-
- updated, expired = process(
- window, key="key1", value=3, transaction=tx, timestamp_ms=120
- )
- assert expired[0][1]["value"] == 6
- assert updated[0][1]["value"] == 6
-
- updated, expired = process(
- window, key="key1", value=4, transaction=tx, timestamp_ms=130
- )
- assert len(expired) == 0
- assert updated[0][1]["value"] == 4
-
- updated, expired = process(
- window, key="key2", value=3, transaction=tx, timestamp_ms=120
- )
- assert expired[0][1]["value"] == 12
- assert updated[0][1]["value"] == 12
-
- updated, expired = process(
- window, key="key2", value=2, transaction=tx, timestamp_ms=130
- )
- assert len(expired) == 0
- assert updated[0][1]["value"] == 2
- updated, expired = process(
- window, key="key1", value=5, transaction=tx, timestamp_ms=140
- )
- assert len(expired) == 0
- assert updated[0][1]["value"] == 9
-
- updated, expired = process(
- window, key="key2", value=1, transaction=tx, timestamp_ms=140
- )
- assert len(expired) == 0
- assert updated[0][1]["value"] == 3
-
- def test_multiple_keys_collect(
- self, count_tumbling_window_definition_factory, state_manager
- ):
- window_def = count_tumbling_window_definition_factory(count=3)
- window = window_def.collect()
- window.register_store()
- store = state_manager.get_store(stream_id="test", store_name=window.name)
- store.assign_partition(0)
-
- with store.start_partition_transaction(0) as tx:
- updated, expired = process(
- window, key="key1", value=1, transaction=tx, timestamp_ms=100
- )
- assert len(expired) == 0
- assert len(updated) == 0
- updated, expired = process(
- window, key="key2", value=5, transaction=tx, timestamp_ms=100
- )
- assert len(expired) == 0
- assert len(updated) == 0
-
- updated, expired = process(
- window, key="key1", value=2, transaction=tx, timestamp_ms=110
- )
- assert len(expired) == 0
- assert len(updated) == 0
- updated, expired = process(
- window, key="key2", value=4, transaction=tx, timestamp_ms=110
- )
- assert len(expired) == 0
- assert len(updated) == 0
-
- updated, expired = process(
- window, key="key1", value=3, transaction=tx, timestamp_ms=120
- )
- assert expired[0][1]["value"] == [1, 2, 3]
- assert len(updated) == 0
-
- updated, expired = process(
- window, key="key1", value=4, transaction=tx, timestamp_ms=130
- )
- assert len(expired) == 0
- assert len(updated) == 0
-
- updated, expired = process(
- window, key="key2", value=3, transaction=tx, timestamp_ms=120
- )
- assert expired[0][1]["value"] == [5, 4, 3]
- assert len(updated) == 0
-
- updated, expired = process(
- window, key="key2", value=2, transaction=tx, timestamp_ms=130
- )
- assert len(expired) == 0
- assert len(updated) == 0
- updated, expired = process(
- window, key="key1", value=5, transaction=tx, timestamp_ms=140
- )
- assert len(expired) == 0
- assert len(updated) == 0
-
- updated, expired = process(
- window, key="key2", value=1, transaction=tx, timestamp_ms=140
- )
- assert len(expired) == 0
- assert len(updated) == 0
-
- updated, expired = process(
- window, key="key2", value=0, transaction=tx, timestamp_ms=130
- )
- assert expired[0][1]["value"] == [2, 1, 0]
- assert len(updated) == 0
- updated, expired = process(
- window, key="key1", value=6, transaction=tx, timestamp_ms=140
- )
- assert expired[0][1]["value"] == [4, 5, 6]
- assert len(updated) == 0
diff --git a/tests/test_quixstreams/test_state/fixtures.py b/tests/test_quixstreams/test_state/fixtures.py
index a24d1dd47..ca76a67a0 100644
--- a/tests/test_quixstreams/test_state/fixtures.py
+++ b/tests/test_quixstreams/test_state/fixtures.py
@@ -45,19 +45,16 @@ def factory(
changelog_name: str = "",
partition_num: int = 0,
store_partition: Optional[StorePartition] = None,
- committed_offsets: Optional[dict[str, int]] = None,
lowwater: int = 0,
highwater: int = 0,
):
changelog_name = changelog_name or f"changelog__{str(uuid.uuid4())}"
if not store_partition:
store_partition = MagicMock(spec_set=StorePartition)
- committed_offsets = committed_offsets or {}
recovery_partition = RecoveryPartition(
changelog_name=changelog_name,
partition_num=partition_num,
store_partition=store_partition,
- committed_offsets=committed_offsets,
lowwater=lowwater,
highwater=highwater,
)
diff --git a/tests/test_quixstreams/test_state/test_manager.py b/tests/test_quixstreams/test_state/test_manager.py
index c4f8de62d..ed991570c 100644
--- a/tests/test_quixstreams/test_state/test_manager.py
+++ b/tests/test_quixstreams/test_state/test_manager.py
@@ -47,9 +47,7 @@ def test_init_state_dir_exists_not_a_dir_fails(
def test_rebalance_partitions_stores_not_registered(self, state_manager):
# It's ok to rebalance partitions when there are no stores registered
- state_manager.on_partition_assign(
- stream_id="topic", partition=0, committed_offsets={"topic": -1001}
- )
+ state_manager.on_partition_assign(stream_id="topic", partition=0)
state_manager.on_partition_revoke(stream_id="topic", partition=0)
def test_register_store(self, state_manager):
@@ -71,13 +69,10 @@ def test_assign_revoke_partitions_stores_registered(self, state_manager):
]
store_partitions = []
- committed_offsets = {"topic1": -1001, "topic2": -1001}
for tp in partitions:
store_partitions.extend(
state_manager.on_partition_assign(
- stream_id=tp.topic,
- partition=tp.partition,
- committed_offsets=committed_offsets,
+ stream_id=tp.topic, partition=tp.partition
)
)
assert len(store_partitions) == 3
@@ -141,7 +136,6 @@ def test_clear_stores(self, state_manager):
state_manager.on_partition_assign(
stream_id=tp.topic,
partition=tp.partition,
- committed_offsets={"topic1": -1001, "topic2": -1001},
)
# Collect paths of stores to be deleted
@@ -170,9 +164,7 @@ def test_clear_stores_fails(self, state_manager):
state_manager.register_store("topic1", store_name="store1")
# Assign the partition
- state_manager.on_partition_assign(
- stream_id="topic1", partition=0, committed_offsets={"topic1": -1001}
- )
+ state_manager.on_partition_assign(stream_id="topic1", partition=0)
# Act - Delete stores
with pytest.raises(PartitionStoreIsUsed):
@@ -202,9 +194,7 @@ def test_rebalance_partitions_stores_not_registered(
producer=producer,
)
# It's ok to rebalance partitions when there are no stores registered
- state_manager.on_partition_assign(
- stream_id="topic", partition=0, committed_offsets={"topic": -1001}
- )
+ state_manager.on_partition_assign(stream_id="topic", partition=0)
state_manager.on_partition_revoke(stream_id="topic", partition=0)
def test_register_store(
@@ -270,11 +260,7 @@ def test_assign_revoke_partitions_stores_registered(
consumer.assignment.return_value = [changelog_tp]
# Assign a topic partition
- state_manager.on_partition_assign(
- stream_id=topic_name,
- partition=partition,
- committed_offsets={"topic1": -1001},
- )
+ state_manager.on_partition_assign(stream_id=topic_name, partition=partition)
# Check that RecoveryManager has a partition assigned
assert recovery_manager.partitions
diff --git a/tests/test_quixstreams/test_state/test_recovery/test_recovery_manager.py b/tests/test_quixstreams/test_state/test_recovery/test_recovery_manager.py
index 70a3ecead..5a02acdb2 100644
--- a/tests/test_quixstreams/test_state/test_recovery/test_recovery_manager.py
+++ b/tests/test_quixstreams/test_state/test_recovery/test_recovery_manager.py
@@ -80,7 +80,6 @@ def test_assign_partition_invalid_offset(
topic=topic_name,
partition=partition_num,
store_partitions={store_name: store_partition},
- committed_offsets={topic_name: -1001},
)
# No pause or assignments should happen
@@ -131,7 +130,6 @@ def test_single_changelog_message_recovery(
recovery_manager.assign_partition(
topic=topic_name,
partition=0,
- committed_offsets={topic_name: -1001},
store_partitions={store_name: store_partition},
)
@@ -184,7 +182,6 @@ def test_assign_partitions_during_recovery(
recovery_manager.assign_partition(
topic=topic_name,
partition=0,
- committed_offsets={topic_name: -1001},
store_partitions={store_name: store_partition},
)
assert recovery_manager.partitions
@@ -200,7 +197,6 @@ def test_assign_partitions_during_recovery(
recovery_manager.assign_partition(
topic=topic_name,
partition=1,
- committed_offsets={topic_name: -1001},
store_partitions={store_name: store_partition},
)
assert recovery_manager.partitions
@@ -262,7 +258,6 @@ def test_assign_partition_changelog_tp_is_missing(
recovery_manager.assign_partition(
topic=topic_name,
partition=1,
- committed_offsets={topic_name: -1001},
store_partitions={store_name: store_partition},
)
@@ -302,13 +297,11 @@ def test_revoke_partition(self, recovery_manager_factory, topic_manager_factory)
recovery_manager.assign_partition(
topic=topic_name,
partition=0,
- committed_offsets={topic_name: -1001},
store_partitions={store_name: store_partition},
)
recovery_manager.assign_partition(
topic=topic_name,
partition=1,
- committed_offsets={topic_name: -1001},
store_partitions={store_name: store_partition},
)
assert len(recovery_manager.partitions) == 2
@@ -408,7 +401,6 @@ def test_assign_partition(
topic=topic_name,
partition=partition_num,
store_partitions=store_partitions,
- committed_offsets={topic_name: -1001},
)
# Check that RecoveryPartition is assigned to RecoveryManager
@@ -482,7 +474,6 @@ def test_do_recovery(
recovery_manager.assign_partition(
topic=topic_name,
partition=0,
- committed_offsets={topic_name: -1001},
store_partitions={store_name: store_partition},
)
diff --git a/tests/test_quixstreams/test_state/test_recovery/test_recovery_partition.py b/tests/test_quixstreams/test_state/test_recovery/test_recovery_partition.py
index f6ffc31b1..0e2fedeba 100644
--- a/tests/test_quixstreams/test_state/test_recovery/test_recovery_partition.py
+++ b/tests/test_quixstreams/test_state/test_recovery/test_recovery_partition.py
@@ -4,11 +4,7 @@
from confluent_kafka import OFFSET_BEGINNING
from quixstreams.state.exceptions import ColumnFamilyHeaderMissing
-from quixstreams.state.metadata import (
- CHANGELOG_CF_MESSAGE_HEADER,
- CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER,
- SEPARATOR,
-)
+from quixstreams.state.metadata import CHANGELOG_CF_MESSAGE_HEADER, SEPARATOR
from quixstreams.state.rocksdb import RocksDBStorePartition
from quixstreams.utils.json import dumps
from tests.utils import ConfluentKafkaMessageStub
@@ -75,7 +71,7 @@ def test_initial_offset(
class TestRecoverFromChangelogMessage:
@pytest.mark.parametrize("store_value", [10, None])
- def test_recover_from_changelog_message_no_processed_offset(
+ def test_recover_from_changelog_message_success(
self, store_partition, store_value, recovery_partition_factory
):
"""
@@ -147,104 +143,3 @@ def test_recover_from_changelog_message_invalid_value_type(
recovery_partition.recover_from_changelog_message(
changelog_message=changelog_msg
)
-
- def test_recover_from_changelog_message_with_processed_offset_behind_committed(
- self, store_partition, recovery_partition_factory
- ):
- """
- Test that changes from the changelog topic are applied if the
- source topic offset header is present and is smaller than the latest committed
- offset.
- """
- kafka_key = b"my_key"
- user_store_key = "count"
-
- # Processed offset is behind the committed offset - the changelog belongs
- # to an already committed message and should be applied
- processed_offsets = {"topic": 1}
- committed_offsets = {"topic": 2}
-
- recovery_partition = recovery_partition_factory(
- store_partition=store_partition, committed_offsets=committed_offsets
- )
-
- processed_offset_header = (
- CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER,
- dumps(processed_offsets),
- )
-
- changelog_msg = ConfluentKafkaMessageStub(
- key=kafka_key + SEPARATOR + dumps(user_store_key),
- value=dumps(10),
- headers=[
- (CHANGELOG_CF_MESSAGE_HEADER, b"default"),
- processed_offset_header,
- ],
- )
-
- recovery_partition.recover_from_changelog_message(changelog_msg)
-
- with store_partition.begin() as tx:
- assert tx.get(user_store_key, prefix=kafka_key) == 10
- assert store_partition.get_changelog_offset() == changelog_msg.offset()
-
- @pytest.mark.parametrize(
- "processed_offsets, committed_offsets",
- # Processed offsets should be strictly lower than committed offsets for
- # the change to be applied
- [
- ({"topic1": 2}, {"topic1": 1}),
- ({"topic1": 2}, {"topic1": 2}),
- ({"topic1": 2, "topic2": 2}, {"topic1": 3, "topic2": 2}),
- ({"topic1": 2, "topic2": 2}, {"topic1": 1, "topic2": 3}),
- ({"topic1": 2, "topic2": 2}, {"topic1": 1, "topic2": 1}),
- ],
- )
- def test_recover_from_changelog_message_with_processed_offset_ahead_committed(
- self,
- store_partition,
- recovery_partition_factory,
- processed_offsets,
- committed_offsets,
- ):
- """
- Test that changes from the changelog topic are NOT applied if the
- source topic offset header is present but larger than the latest committed
- offset.
- It means that the changelog messages were produced during the checkpoint,
- but the topic offset was not committed.
- Possible reasons:
- - Producer couldn't verify the delivery of every changelog message
- - Consumer failed to commit the source topic offsets
- """
- kafka_key = b"my_key"
- user_store_key = "count"
-
- recovery_partition = recovery_partition_factory(
- store_partition=store_partition, committed_offsets=committed_offsets
- )
-
- # Generate the changelog message with processed offset ahead of the committed
- # one
- processed_offset_header = (
- CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER,
- dumps(processed_offsets),
- )
- changelog_msg = ConfluentKafkaMessageStub(
- key=kafka_key + SEPARATOR + dumps(user_store_key),
- value=dumps(10),
- headers=[
- (CHANGELOG_CF_MESSAGE_HEADER, b"default"),
- processed_offset_header,
- ],
- )
-
- # Recover from the message
- recovery_partition.recover_from_changelog_message(changelog_msg)
-
- # Check that the changes have not been applied, but the changelog offset
- # increased
- with store_partition.begin() as tx:
- assert tx.get(user_store_key, prefix=kafka_key) is None
-
- assert store_partition.get_changelog_offset() == changelog_msg.offset()
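
For orientation: the tests deleted above exercised the old recovery gate in which a changelog record was applied only when the processed source-topic offsets embedded in its CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER were strictly below the committed offsets handed to the recovery partition. A minimal sketch of that comparison, inferred from the deleted docstrings and parametrized cases; the `should_apply` helper is hypothetical and is not part of the library API (this patch removes the header and the check entirely):

```python
# Hypothetical illustration of the gate exercised by the removed tests:
# apply a changelog record only if every processed offset in its header is
# strictly lower than the corresponding committed source-topic offset.
def should_apply(processed_offsets: dict[str, int], committed_offsets: dict[str, int]) -> bool:
    return all(
        processed_offsets[topic] < committed_offsets[topic]
        for topic in processed_offsets
    )

assert should_apply({"topic": 1}, {"topic": 2})                # behind committed -> applied
assert not should_apply({"topic1": 2}, {"topic1": 2})          # equal -> skipped
assert not should_apply(                                       # any topic not behind -> skipped
    {"topic1": 2, "topic2": 2}, {"topic1": 3, "topic2": 2}
)
```

With the header gone, recovery no longer consults committed offsets at all, which is why `committed_offsets` also disappears from the recovery fixtures earlier in this patch.
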
diff --git a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_state.py b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_state.py
index 5217b5961..33bb92694 100644
--- a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_state.py
+++ b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_state.py
@@ -49,107 +49,6 @@ def test_update_window(transaction_state, value):
assert state.get_window(start_ms=0, end_ms=10) == value
-@pytest.mark.parametrize("delete", [True, False])
-def test_expire_windows(transaction_state, delete):
- duration_ms = 10
-
- with transaction_state() as state:
- state.update_window(start_ms=0, end_ms=10, value=1, timestamp_ms=2)
- state.update_window(start_ms=10, end_ms=20, value=2, timestamp_ms=10)
-
- with transaction_state() as state:
- state.update_window(start_ms=20, end_ms=30, value=3, timestamp_ms=20)
- max_start_time = state.get_latest_timestamp() - duration_ms
- expired = list(
- state.expire_windows(max_start_time=max_start_time, delete=delete)
- )
- # "expire_windows" must update the expiration index so that the same
- # windows are not expired twice
- assert not list(
- state.expire_windows(max_start_time=max_start_time, delete=delete)
- )
-
- assert len(expired) == 2
- assert expired == [
- ((0, 10), 1, [], b"__key__"),
- ((10, 20), 2, [], b"__key__"),
- ]
-
- with transaction_state() as state:
- assert state.get_window(start_ms=0, end_ms=10) == None if delete else 1
- assert state.get_window(start_ms=10, end_ms=20) == None if delete else 2
- assert state.get_window(start_ms=20, end_ms=30) == 3
-
-
-@pytest.mark.parametrize("end_inclusive", [True, False])
-def test_expire_windows_with_collect(transaction_state, end_inclusive):
- duration_ms = 10
-
- with transaction_state() as state:
- # Different window types store values differently:
- # - Tumbling/hopping windows use None as placeholder values
- # - Sliding windows use [int, None] format where int is the max timestamp
- # Note: In production, these different value types would not be mixed
- # within the same state.
- state.update_window(start_ms=0, end_ms=10, value=None, timestamp_ms=2)
- state.update_window(start_ms=10, end_ms=20, value=[777, None], timestamp_ms=10)
-
- state.add_to_collection(value="a", id=0)
- state.add_to_collection(value="b", id=10)
- state.add_to_collection(value="c", id=20)
-
- with transaction_state() as state:
- state.update_window(start_ms=20, end_ms=30, value=None, timestamp_ms=20)
- max_start_time = state.get_latest_timestamp() - duration_ms
- expired = list(
- state.expire_windows(
- max_start_time=max_start_time,
- collect=True,
- end_inclusive=end_inclusive,
- )
- )
-
- window_1_value = ["a", "b"] if end_inclusive else ["a"]
- window_2_value = ["b", "c"] if end_inclusive else ["b"]
- assert expired == [
- ((0, 10), None, window_1_value, b"__key__"),
- ((10, 20), [777, None], window_2_value, b"__key__"),
- ]
-
-
-def test_same_keys_in_db_and_update_cache(transaction_state):
- duration_ms = 10
-
- with transaction_state() as state:
- state.update_window(start_ms=0, end_ms=10, value=1, timestamp_ms=2)
-
- with transaction_state() as state:
- # The same window already exists in the db
- state.update_window(start_ms=0, end_ms=10, value=3, timestamp_ms=8)
-
- state.update_window(start_ms=10, end_ms=20, value=2, timestamp_ms=10)
- max_start_time = state.get_latest_timestamp() - duration_ms
- expired = list(state.expire_windows(max_start_time=max_start_time))
-
- # Value from the cache takes precedence over the value in the db
- assert expired == [((0, 10), 3, [], b"__key__")]
-
-
-def test_get_latest_timestamp(windowed_rocksdb_store_factory):
- store = windowed_rocksdb_store_factory()
- partition = store.assign_partition(0)
- timestamp = 123
- prefix = b"__key__"
- with partition.begin() as tx:
- state = tx.as_state(prefix)
- state.update_window(0, 10, value=1, timestamp_ms=timestamp)
- store.revoke_partition(0)
-
- partition = store.assign_partition(0)
- with partition.begin() as tx:
- assert tx.get_latest_timestamp(prefix=prefix) == timestamp
-
-
@pytest.mark.parametrize(
"db_windows, cached_windows, deleted_windows, get_windows_args, expected_windows",
[
@@ -351,43 +250,6 @@ def test_get_windows(
assert list(windows) == expected_windows
-def test_delete_windows(transaction_state):
- with transaction_state() as state:
- state.update_window(start_ms=1, end_ms=2, value=1, timestamp_ms=1)
- state.update_window(start_ms=2, end_ms=3, value=2, timestamp_ms=2)
- state.update_window(start_ms=3, end_ms=4, value=3, timestamp_ms=3)
-
- with transaction_state() as state:
- assert state.get_window(start_ms=1, end_ms=2)
- assert state.get_window(start_ms=2, end_ms=3)
- assert state.get_window(start_ms=3, end_ms=4)
-
- state.delete_windows(max_start_time=2, delete_values=False)
-
- assert not state.get_window(start_ms=1, end_ms=2)
- assert not state.get_window(start_ms=2, end_ms=3)
- assert state.get_window(start_ms=3, end_ms=4)
-
-
-def test_delete_windows_with_values(transaction_state, get_value):
- with transaction_state() as state:
- state.update_window(start_ms=2, end_ms=3, value=1, timestamp_ms=2)
- state.add_to_collection(value="a", id=1)
- state.add_to_collection(value="b", id=2)
-
- with transaction_state() as state:
- assert state.get_window(start_ms=2, end_ms=3)
- assert get_value(timestamp_ms=1, counter=0) == "a"
- assert get_value(timestamp_ms=2, counter=1) == "b"
-
- state.delete_windows(max_start_time=2, delete_values=True)
-
- with transaction_state() as state:
- assert not state.get_window(start_ms=2, end_ms=3)
- assert not get_value(timestamp_ms=1, counter=0)
- assert get_value(timestamp_ms=2, counter=1) == "b"
-
-
@pytest.mark.parametrize("value", [1, "string", None, ["list"], {"dict": "dict"}])
def test_add_to_collection(transaction_state, get_value, value):
with transaction_state() as state:
diff --git a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py
index 6808b0fef..723837fd7 100644
--- a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py
+++ b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py
@@ -1,11 +1,7 @@
import pytest
-from quixstreams.state.metadata import (
- CHANGELOG_CF_MESSAGE_HEADER,
- CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER,
-)
+from quixstreams.state.metadata import CHANGELOG_CF_MESSAGE_HEADER
from quixstreams.state.serialization import encode_integer_pair
-from quixstreams.utils.json import dumps
class TestWindowedRocksDBPartitionTransaction:
@@ -44,59 +40,48 @@ def test_delete_window(self, windowed_rocksdb_store_factory):
assert tx.get_window(start_ms=0, end_ms=10, prefix=prefix) is None
@pytest.mark.parametrize("delete", [True, False])
- def test_expire_windows_expired(self, windowed_rocksdb_store_factory, delete):
+ def test_expire_all_windows_expired(self, windowed_rocksdb_store_factory, delete):
store = windowed_rocksdb_store_factory()
store.assign_partition(0)
- prefix = b"__key__"
- duration_ms = 10
+ prefix1 = b"__key__1"
+ prefix2 = b"__key__2"
with store.start_partition_transaction(0) as tx:
tx.update_window(
- start_ms=0, end_ms=10, value=1, timestamp_ms=2, prefix=prefix
+ start_ms=0, end_ms=10, value=1, timestamp_ms=2, prefix=prefix1
)
tx.update_window(
- start_ms=10, end_ms=20, value=2, timestamp_ms=10, prefix=prefix
+ start_ms=10, end_ms=20, value=2, timestamp_ms=10, prefix=prefix2
)
with store.start_partition_transaction(0) as tx:
tx.update_window(
- start_ms=20, end_ms=30, value=3, timestamp_ms=20, prefix=prefix
- )
- max_start_time = tx.get_latest_timestamp(prefix=prefix) - duration_ms
- expired = list(
- tx.expire_windows(
- max_start_time=max_start_time, prefix=prefix, delete=delete
- )
- )
- # "expire_windows" must update the expiration index so that the same
- # windows are not expired twice
- assert not list(
- tx.expire_windows(
- max_start_time=max_start_time, prefix=prefix, delete=delete
- )
+ start_ms=20, end_ms=30, value=3, timestamp_ms=20, prefix=prefix1
)
+ expired = list(tx.expire_all_windows(max_end_time=20, delete=delete))
+ assert not list(tx.expire_all_windows(max_end_time=20, delete=delete))
assert len(expired) == 2
assert expired == [
- ((0, 10), 1, [], prefix),
- ((10, 20), 2, [], prefix),
+ ((0, 10), 1, [], prefix1),
+ ((10, 20), 2, [], prefix2),
]
with store.start_partition_transaction(0) as tx:
assert (
- tx.get_window(start_ms=0, end_ms=10, prefix=prefix) == None
+ tx.get_window(start_ms=0, end_ms=10, prefix=prefix1) is None
if delete
else 1
)
assert (
- tx.get_window(start_ms=10, end_ms=20, prefix=prefix) == None
+ tx.get_window(start_ms=10, end_ms=20, prefix=prefix2) is None
if delete
else 2
)
- assert tx.get_window(start_ms=20, end_ms=30, prefix=prefix) == 3
+ assert tx.get_window(start_ms=20, end_ms=30, prefix=prefix1) == 3
@pytest.mark.parametrize("delete", [True, False])
- def test_expire_windows_cached(self, windowed_rocksdb_store_factory, delete):
+ def test_expire_all_windows_cached(self, windowed_rocksdb_store_factory, delete):
"""
Check that windows expire correctly even if they're not committed to the DB
yet.
@@ -104,7 +89,6 @@ def test_expire_windows_cached(self, windowed_rocksdb_store_factory, delete):
store = windowed_rocksdb_store_factory()
store.assign_partition(0)
prefix = b"__key__"
- duration_ms = 10
with store.start_partition_transaction(0) as tx:
tx.update_window(
@@ -116,41 +100,31 @@ def test_expire_windows_cached(self, windowed_rocksdb_store_factory, delete):
tx.update_window(
start_ms=20, end_ms=30, value=3, timestamp_ms=20, prefix=prefix
)
- max_start_time = tx.get_latest_timestamp(prefix=prefix) - duration_ms
- expired = list(
- tx.expire_windows(
- max_start_time=max_start_time, prefix=prefix, delete=delete
- )
- )
+ expired = list(tx.expire_all_windows(max_end_time=20, delete=delete))
# "expire_windows" must update the expiration index so that the same
# windows are not expired twice
- assert not list(
- tx.expire_windows(
- max_start_time=max_start_time, prefix=prefix, delete=delete
- )
- )
+ assert not list(tx.expire_all_windows(max_end_time=20, delete=delete))
assert len(expired) == 2
assert expired == [
((0, 10), 1, [], prefix),
((10, 20), 2, [], prefix),
]
assert (
- tx.get_window(start_ms=0, end_ms=10, prefix=prefix) == None
+ tx.get_window(start_ms=0, end_ms=10, prefix=prefix) is None
if delete
else 1
)
assert (
- tx.get_window(start_ms=10, end_ms=20, prefix=prefix) == None
+ tx.get_window(start_ms=10, end_ms=20, prefix=prefix) is None
if delete
else 2
)
assert tx.get_window(start_ms=20, end_ms=30, prefix=prefix) == 3
- def test_expire_windows_empty(self, windowed_rocksdb_store_factory):
+ def test_expire_all_windows_empty(self, windowed_rocksdb_store_factory):
store = windowed_rocksdb_store_factory()
store.assign_partition(0)
prefix = b"__key__"
- duration_ms = 10
with store.start_partition_transaction(0) as tx:
tx.update_window(
@@ -164,43 +138,62 @@ def test_expire_windows_empty(self, windowed_rocksdb_store_factory):
tx.update_window(
start_ms=3, end_ms=13, value=1, timestamp_ms=3, prefix=prefix
)
- max_start_time = tx.get_latest_timestamp(prefix=prefix) - duration_ms
- assert not list(
- tx.expire_windows(max_start_time=max_start_time, prefix=prefix)
- )
+ assert not list(tx.expire_all_windows(max_end_time=3))
- def test_expire_windows_with_grace_expired(self, windowed_rocksdb_store_factory):
+ @pytest.mark.parametrize("end_inclusive", [True, False])
+ def test_expire_all_windows_with_collect(
+ self, windowed_rocksdb_store_factory, end_inclusive
+ ):
store = windowed_rocksdb_store_factory()
store.assign_partition(0)
prefix = b"__key__"
- duration_ms = 10
- grace_ms = 5
with store.start_partition_transaction(0) as tx:
+ # Different window types store values differently:
+ # - Tumbling/hopping windows use None as placeholder values
+ # - Sliding windows use [int, None] format where int is the max timestamp
+ # Note: In production, these different value types would not be mixed
+ # within the same state.
tx.update_window(
- start_ms=0, end_ms=10, value=1, timestamp_ms=2, prefix=prefix
+ start_ms=0, end_ms=10, value=None, timestamp_ms=2, prefix=prefix
+ )
+ tx.update_window(
+ start_ms=10,
+ end_ms=20,
+ value=[777, None],
+ timestamp_ms=10,
+ prefix=prefix,
)
+ tx.add_to_collection(value="a", id=0, prefix=prefix)
+ tx.add_to_collection(value="b", id=10, prefix=prefix)
+ tx.add_to_collection(value="c", id=20, prefix=prefix)
+
with store.start_partition_transaction(0) as tx:
tx.update_window(
- start_ms=15, end_ms=25, value=1, timestamp_ms=15, prefix=prefix
- )
- max_start_time = (
- tx.get_latest_timestamp(prefix=prefix) - duration_ms - grace_ms
+ start_ms=20, end_ms=30, value=None, timestamp_ms=20, prefix=prefix
)
expired = list(
- tx.expire_windows(max_start_time=max_start_time, prefix=prefix)
+ tx.expire_all_windows(
+ max_end_time=20,
+ collect=True,
+ end_inclusive=end_inclusive,
+ )
)
- assert len(expired) == 1
- assert expired == [((0, 10), 1, [], prefix)]
+ window_1_value = ["a", "b"] if end_inclusive else ["a"]
+ window_2_value = ["b", "c"] if end_inclusive else ["b"]
+ assert expired == [
+ ((0, 10), None, window_1_value, b"__key__"),
+ ((10, 20), [777, None], window_2_value, b"__key__"),
+ ]
- def test_expire_windows_with_grace_empty(self, windowed_rocksdb_store_factory):
+ def test_expire_all_windows_same_keys_in_db_and_update_cache(
+ self, windowed_rocksdb_store_factory
+ ):
store = windowed_rocksdb_store_factory()
store.assign_partition(0)
prefix = b"__key__"
- duration_ms = 10
- grace_ms = 5
with store.start_partition_transaction(0) as tx:
tx.update_window(
@@ -208,17 +201,17 @@ def test_expire_windows_with_grace_empty(self, windowed_rocksdb_store_factory):
)
with store.start_partition_transaction(0) as tx:
+ # The same window already exists in the db
tx.update_window(
- start_ms=13, end_ms=23, value=1, timestamp_ms=13, prefix=prefix
+ start_ms=0, end_ms=10, value=3, timestamp_ms=8, prefix=prefix
)
- max_start_time = (
- tx.get_latest_timestamp(prefix=prefix) - duration_ms - grace_ms
- )
- expired = list(
- tx.expire_windows(max_start_time=max_start_time, prefix=prefix)
+ tx.update_window(
+ start_ms=10, end_ms=20, value=2, timestamp_ms=10, prefix=prefix
)
+ expired = list(tx.expire_all_windows(max_end_time=10))
- assert not expired
+ # Value from the cache takes precedence over the value in the db
+ assert expired == [((0, 10), 3, [], b"__key__")]
@pytest.mark.parametrize(
"start_ms, end_ms",
@@ -277,87 +270,6 @@ def test_delete_window_invalid_duration(
with pytest.raises(ValueError, match="Invalid window duration"):
tx.delete_window(start_ms=start_ms, end_ms=end_ms, prefix=prefix)
- def test_expire_windows_no_expired(self, windowed_rocksdb_store_factory):
- store = windowed_rocksdb_store_factory()
- store.assign_partition(0)
- prefix = b"__key__"
- duration_ms = 10
-
- with store.start_partition_transaction(0) as tx:
- tx.update_window(
- start_ms=0, end_ms=10, value=1, timestamp_ms=2, prefix=prefix
- )
-
- with store.start_partition_transaction(0) as tx:
- tx.update_window(
- start_ms=1, end_ms=11, value=1, timestamp_ms=9, prefix=prefix
- )
- # "expire_windows" must update the expiration index so that the same
- # windows are not expired twice
- max_start_time = tx.get_latest_timestamp(prefix=prefix) - duration_ms
- assert not list(
- tx.expire_windows(max_start_time=max_start_time, prefix=prefix)
- )
-
- def test_expire_windows_multiple_windows(self, windowed_rocksdb_store_factory):
- store = windowed_rocksdb_store_factory()
- store.assign_partition(0)
- prefix = b"__key__"
- duration_ms = 10
-
- with store.start_partition_transaction(0) as tx:
- tx.update_window(
- start_ms=0, end_ms=10, value=1, timestamp_ms=2, prefix=prefix
- )
- tx.update_window(
- start_ms=10, end_ms=20, value=1, timestamp_ms=11, prefix=prefix
- )
- tx.update_window(
- start_ms=20, end_ms=30, value=1, timestamp_ms=21, prefix=prefix
- )
-
- with store.start_partition_transaction(0) as tx:
- tx.update_window(
- start_ms=30, end_ms=40, value=1, timestamp_ms=31, prefix=prefix
- )
- # "expire_windows" must update the expiration index so that the same
- # windows are not expired twice
- max_start_time = tx.get_latest_timestamp(prefix=prefix) - duration_ms
- expired = list(
- tx.expire_windows(max_start_time=max_start_time, prefix=prefix)
- )
-
- assert len(expired) == 3
- assert expired[0] == ((0, 10), 1, [], prefix)
- assert expired[1] == ((10, 20), 1, [], prefix)
- assert expired[2] == ((20, 30), 1, [], prefix)
-
- def test_get_latest_timestamp_update(self, windowed_rocksdb_store_factory):
- store = windowed_rocksdb_store_factory()
- partition = store.assign_partition(0)
- timestamp = 123
- prefix = b"__key__"
- with partition.begin() as tx:
- tx.update_window(0, 10, value=1, timestamp_ms=timestamp, prefix=prefix)
-
- with partition.begin() as tx:
- assert tx.get_latest_timestamp(prefix=prefix) == timestamp
-
- def test_get_latest_timestamp_cannot_go_backwards(
- self, windowed_rocksdb_store_factory
- ):
- store = windowed_rocksdb_store_factory()
- partition = store.assign_partition(0)
- timestamp = 9
- prefix = b"__key__"
- with partition.begin() as tx:
- tx.update_window(0, 10, value=1, timestamp_ms=timestamp, prefix=prefix)
- tx.update_window(0, 10, value=1, timestamp_ms=timestamp - 1, prefix=prefix)
- assert tx.get_latest_timestamp(prefix=prefix) == timestamp
-
- with partition.begin() as tx:
- assert tx.get_latest_timestamp(prefix=prefix) == timestamp
-
def test_update_window_and_prepare(
self, windowed_rocksdb_partition_factory, changelog_producer_mock
):
@@ -365,7 +277,6 @@ def test_update_window_and_prepare(
start_ms = 0
end_ms = 10
value = 1
- processed_offsets = {"topic": 1}
with windowed_rocksdb_partition_factory(
changelog_producer=changelog_producer_mock
@@ -378,12 +289,10 @@ def test_update_window_and_prepare(
timestamp_ms=2,
prefix=prefix,
)
- tx.prepare(processed_offsets=processed_offsets)
+ tx.prepare()
assert tx.prepared
- # The transaction is expected to produce 2 keys for each updated one:
- # One for the window itself, and another for the latest timestamp
- assert changelog_producer_mock.produce.call_count == 2
+ assert changelog_producer_mock.produce.call_count == 1
expected_produced_key = tx._serialize_key(
encode_integer_pair(start_ms, end_ms), prefix=prefix
)
@@ -391,10 +300,7 @@ def test_update_window_and_prepare(
changelog_producer_mock.produce.assert_any_call(
key=expected_produced_key,
value=expected_produced_value,
- headers={
- CHANGELOG_CF_MESSAGE_HEADER: "default",
- CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER: dumps(processed_offsets),
- },
+ headers={CHANGELOG_CF_MESSAGE_HEADER: "default"},
)
def test_delete_window_and_prepare(
@@ -403,14 +309,13 @@ def test_delete_window_and_prepare(
prefix = b"__key__"
start_ms = 0
end_ms = 10
- processed_offsets = {"topic": 1}
with windowed_rocksdb_partition_factory(
changelog_producer=changelog_producer_mock
) as store_partition:
tx = store_partition.begin()
tx.delete_window(start_ms=start_ms, end_ms=end_ms, prefix=prefix)
- tx.prepare(processed_offsets=processed_offsets)
+ tx.prepare()
assert tx.prepared
assert changelog_producer_mock.produce.call_count == 1
@@ -420,8 +325,5 @@ def test_delete_window_and_prepare(
changelog_producer_mock.produce.assert_called_with(
key=expected_produced_key,
value=None,
- headers={
- CHANGELOG_CF_MESSAGE_HEADER: "default",
- CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER: dumps(processed_offsets),
- },
+ headers={CHANGELOG_CF_MESSAGE_HEADER: "default"},
)
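
To summarize the header change asserted in this file and in the state transaction tests that follow: the tests now call `tx.prepare()` without `processed_offsets`, and each produced changelog record is expected to carry only the column-family header (one record per updated window key, with no extra latest-timestamp record). A minimal sketch of the expected header shape, using only names that appear in this diff; the `"default"` value is the column family these tests use:

```python
from quixstreams.state.metadata import CHANGELOG_CF_MESSAGE_HEADER

# After this patch, the tests expect a prepared transaction to produce
# changelog records whose headers contain only the column-family name:
expected_headers = {CHANGELOG_CF_MESSAGE_HEADER: "default"}

# Previously the same tests also asserted a second header,
# CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER, whose value was the
# JSON-encoded processed source-topic offsets (e.g. dumps({"topic": 1})).
# Both that header and the processed_offsets argument to tx.prepare()
# are removed by this patch.
```
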
diff --git a/tests/test_quixstreams/test_state/test_transaction.py b/tests/test_quixstreams/test_state/test_transaction.py
index 076b07354..b01a1cb8c 100644
--- a/tests/test_quixstreams/test_state/test_transaction.py
+++ b/tests/test_quixstreams/test_state/test_transaction.py
@@ -13,11 +13,7 @@
StateTransactionError,
)
from quixstreams.state.manager import SUPPORTED_STORES
-from quixstreams.state.metadata import (
- CHANGELOG_CF_MESSAGE_HEADER,
- CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER,
- Marker,
-)
+from quixstreams.state.metadata import CHANGELOG_CF_MESSAGE_HEADER, Marker
from quixstreams.state.serialization import serialize
from quixstreams.utils.json import dumps
@@ -345,7 +341,7 @@ def test_update_key_prepared_transaction_fails(self, store_partition):
tx = store_partition.begin()
tx.set(key="key", value="value", prefix=prefix)
- tx.prepare(processed_offsets={"topic": 1})
+ tx.prepare()
assert tx.prepared
with pytest.raises(StateTransactionError):
@@ -445,7 +441,6 @@ def test_set_and_prepare(self, store_partition_factory, changelog_producer_mock)
]
cf = "default"
prefix = b"__key__"
- processed_offsets = {"topic": 1}
with store_partition_factory(
changelog_producer=changelog_producer_mock
@@ -458,7 +453,7 @@ def test_set_and_prepare(self, store_partition_factory, changelog_producer_mock)
cf_name=cf,
prefix=prefix,
)
- tx.prepare(processed_offsets=processed_offsets)
+ tx.prepare()
assert changelog_producer_mock.produce.call_count == len(data)
@@ -467,12 +462,7 @@ def test_set_and_prepare(self, store_partition_factory, changelog_producer_mock)
):
assert call.kwargs["key"] == tx._serialize_key(key=key, prefix=prefix)
assert call.kwargs["value"] == tx._serialize_value(value=value)
- assert call.kwargs["headers"] == {
- CHANGELOG_CF_MESSAGE_HEADER: cf,
- CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER: dumps(
- processed_offsets
- ),
- }
+ assert call.kwargs["headers"] == {CHANGELOG_CF_MESSAGE_HEADER: cf}
assert tx.prepared
@@ -480,7 +470,6 @@ def test_delete_and_prepare(self, store_partition_factory, changelog_producer_mo
key = "key"
cf = "default"
prefix = b"__key__"
- processed_offsets = {"topic": 1}
with store_partition_factory(
changelog_producer=changelog_producer_mock
@@ -488,7 +477,7 @@ def test_delete_and_prepare(self, store_partition_factory, changelog_producer_mo
tx = partition.begin()
tx.delete(key=key, cf_name=cf, prefix=prefix)
- tx.prepare(processed_offsets=processed_offsets)
+ tx.prepare()
assert tx.prepared
assert changelog_producer_mock.produce.call_count == 1
@@ -498,10 +487,7 @@ def test_delete_and_prepare(self, store_partition_factory, changelog_producer_mo
key=key, prefix=prefix
)
assert delete_changelog.kwargs["value"] is None
- assert delete_changelog.kwargs["headers"] == {
- CHANGELOG_CF_MESSAGE_HEADER: cf,
- CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER: dumps(processed_offsets),
- }
+ assert delete_changelog.kwargs["headers"] == {CHANGELOG_CF_MESSAGE_HEADER: cf}
def test_set_delete_and_prepare(
self, store_partition_factory, changelog_producer_mock
@@ -513,7 +499,6 @@ def test_set_delete_and_prepare(
key, value = "key", "value"
cf = "default"
prefix = b"__key__"
- processed_offsets = {"topic": 1}
with store_partition_factory(
changelog_producer=changelog_producer_mock
@@ -522,7 +507,7 @@ def test_set_delete_and_prepare(
tx.set(key=key, value=value, cf_name=cf, prefix=prefix)
tx.delete(key=key, cf_name=cf, prefix=prefix)
- tx.prepare(processed_offsets=processed_offsets)
+ tx.prepare()
assert tx.prepared
assert changelog_producer_mock.produce.call_count == 1
@@ -532,8 +517,7 @@ def test_set_delete_and_prepare(
)
assert delete_changelog.kwargs["value"] is None
assert delete_changelog.kwargs["headers"] == {
- CHANGELOG_CF_MESSAGE_HEADER: cf,
- CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER: dumps(processed_offsets),
+ CHANGELOG_CF_MESSAGE_HEADER: cf
}