163 changes: 110 additions & 53 deletions quixstreams/app.py
@@ -6,7 +6,6 @@
import time
import uuid
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Callable, List, Literal, Optional, Protocol, Tuple, Type, Union, cast

@@ -15,7 +14,7 @@
from pydantic_settings import BaseSettings as PydanticBaseSettings
from pydantic_settings import PydanticBaseSettingsSource, SettingsConfigDict

from .context import copy_context, set_message_context
from .context import MessageContext, copy_context, set_message_context
from .core.stream.functions.types import VoidExecutor
from .dataframe import DataFrameRegistry, StreamingDataFrame
from .error_callbacks import (
@@ -46,6 +45,7 @@
)
from .platforms.quix.env import QUIX_ENVIRONMENT
from .processing import ProcessingContext
from .processing.watermarking import WatermarkManager
from .runtracker import RunTracker
from .sinks import SinkManager
from .sources import BaseSource, SourceException, SourceManager
@@ -384,6 +384,9 @@ def __init__(
sink_manager=self._sink_manager,
dataframe_registry=self._dataframe_registry,
)
self._watermark_manager = WatermarkManager(
producer=self._producer, topic_manager=self._topic_manager
)
self._run_tracker = RunTracker()

@property
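
The `WatermarkManager` imported from `quixstreams.processing.watermarking` is wired into the application here, but its implementation is not part of this diff. Going only by how it is used below (`set_topics()`, `watermarks_topic`, `store()`, `produce()`, `receive()`, `untrack()`), a rough sketch of its surface might look like the following; the internal fields and aggregation logic are assumptions, not the actual code.

```python
# Hypothetical sketch of the WatermarkManager surface used in this diff.
# The real class lives in quixstreams/processing/watermarking and may differ.
from typing import Optional


class WatermarkManagerSketch:
    def __init__(self, producer, topic_manager):
        self._producer = producer
        self._topic_manager = topic_manager
        # Highest event timestamp seen per (topic, partition) -- an assumption.
        self._tp_timestamps: dict[tuple[str, int], int] = {}
        self.watermarks_topic = None  # created via topic_manager in the real code

    def set_topics(self, topics: list) -> None:
        """Register the data topics whose partitions participate in the watermark."""

    def store(self, topic: str, partition: int, timestamp: int) -> None:
        """Track the latest timestamp processed for a topic partition."""
        key = (topic, partition)
        self._tp_timestamps[key] = max(self._tp_timestamps.get(key, 0), timestamp)

    def produce(self) -> None:
        """Publish the current watermark to the internal watermarks topic."""

    def receive(self, message) -> Optional[int]:
        """Consume a watermark message; return a timestamp only when it advances."""

    def untrack(self, topic: str, partition: int) -> None:
        """Stop tracking a partition after it is revoked or lost."""
        self._tp_timestamps.pop((topic, partition), None)
```
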
@@ -903,9 +906,19 @@ def _run_dataframe(self, sink: Optional[VoidExecutor] = None):
printer = self._processing_context.printer
run_tracker = self._run_tracker
consumer = self._consumer
producer = self._producer
producer_poll_timeout = self._config.producer_poll_timeout
watermark_manager = self._watermark_manager

# Set the topics to be tracked by the Watermark manager
watermark_manager.set_topics(topics=self._dataframe_registry.consumer_topics)

consumer.subscribe(
topics=self._dataframe_registry.consumer_topics + changelog_topics,
topics=self._dataframe_registry.consumer_topics
+ changelog_topics
+ [
self._watermark_manager.watermarks_topic
], # TODO: We subscribe here because otherwise it can't deserialize a message. Maybe it's time to split poll() and deserialization
on_assign=self._on_assign,
on_revoke=self._on_revoke,
on_lost=self._on_lost,
@@ -922,11 +935,14 @@ def _run_dataframe(self, sink: Optional[VoidExecutor] = None):
state_manager.do_recovery()
run_tracker.timeout_refresh()
else:
# Serve producer callbacks
producer.poll(producer_poll_timeout)
process_message(dataframes_composed)
processing_context.commit_checkpoint()
consumer.resume_backpressured()
source_manager.raise_for_error()
printer.print()
watermark_manager.produce()
run_tracker.update_status()

logger.info("Stopping the application")
@@ -954,9 +970,7 @@ def _quix_runtime_init(self):
if self._state_manager.stores:
check_state_management_enabled()

def _process_message(self, dataframe_composed):
# Serve producer callbacks
self._producer.poll(self._config.producer_poll_timeout)
def _process_message(self, dataframe_composed: dict[str, VoidExecutor]):
rows = self._consumer.poll_row(
timeout=self._config.consumer_poll_timeout,
buffered=self._dataframe_registry.requires_time_alignment,
@@ -978,6 +992,38 @@ def _process_message(self, dataframe_composed):
first_row.offset,
)

# TODO: Maybe store the topic name into some variable so we don't need to access it like that
if topic_name == self._watermark_manager.watermarks_topic.name:
watermark = self._watermark_manager.receive(message=first_row.value)
if watermark is None:
return

logger.info(f"Process watermark {watermark}")

data_topics = self._topic_manager.non_changelog_topics
# TODO: Expose data TPs assignment on App level?
data_tps = [
tp for tp in self._consumer.assignment() if tp.topic in data_topics
]
for tp in data_tps:
watermark_ctx = MessageContext(
topic=tp.topic,
partition=tp.partition,
size=0,
)
context = copy_context()
context.run(set_message_context, watermark_ctx)
# Execute StreamingDataFrame in a context
context.run(
dataframe_composed[tp.topic],
value=None,
key=None,
timestamp=watermark,
headers=[],
is_watermark=True,
)
return

for row in rows:
context = copy_context()
context.run(set_message_context, row.context)
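
The watermark fan-out above reuses the same pattern as the regular row path: each executor call runs inside a copied `contextvars` context in which the message context is set first, so per-partition contexts cannot leak into each other. A minimal stdlib-only illustration of that mechanism (quixstreams' `set_message_context` appears to wrap the same idea around a `ContextVar`):

```python
# Minimal illustration of the copy_context()/set-then-run pattern used above.
# Standard library only; MessageContext is replaced by a plain dict here.
from contextvars import ContextVar, copy_context

message_context: ContextVar[dict] = ContextVar("message_context")


def set_message_context(ctx: dict) -> None:
    message_context.set(ctx)


def executor(value, key, timestamp) -> None:
    ctx = message_context.get()
    print(f"processing ts={timestamp} for {ctx['topic']}[{ctx['partition']}]")


for partition in (0, 1):
    ctx = copy_context()
    # Set the per-partition context inside the copy, then run the executor in it;
    # the outer context never sees the variable.
    ctx.run(set_message_context, {"topic": "input", "partition": partition})
    ctx.run(executor, None, None, 1700000000000)
```
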
@@ -997,6 +1043,10 @@ def _process_message(self, dataframe_composed):
if not to_suppress:
raise

self._watermark_manager.store(
topic=row.topic, partition=row.partition, timestamp=row.timestamp
)

# Store the message offset after it's successfully processed
self._processing_context.store_offset(
topic=topic_name, partition=partition, offset=offset
@@ -1024,42 +1074,33 @@ def _on_assign(self, _, topic_partitions: List[TopicPartition]):
self._source_manager.start_sources()

# Assign partitions manually to pause the changelog topics
self._consumer.assign(topic_partitions)
# Pause changelog topic+partitions immediately after assignment
non_changelog_topics = self._topic_manager.non_changelog_topics
changelog_tps = [
tp for tp in topic_partitions if tp.topic not in non_changelog_topics
watermarks_partitions = [
TopicPartition(
topic=self._watermark_manager.watermarks_topic.name, partition=i
)
for i in range(
self._watermark_manager.watermarks_topic.broker_config.num_partitions
)
]
# TODO: The set is used because the watermark tp can already be present in the "topic_partitions"
# because we use `subscribe()` earlier. Fix the mess later.
# TODO: Also, how to avoid reading the whole WM topic on each restart?
# We really need only the most recent data
# Is it fine to read it from the end? The active partitions must still publish something.
# Or should we commit it?
self._consumer.assign(list(set(topic_partitions + watermarks_partitions)))

# Pause changelog topic+partitions immediately after assignment
changelog_topics = {t.name for t in self._topic_manager.changelog_topics_list}
changelog_tps = [tp for tp in topic_partitions if tp.topic in changelog_topics]
self._consumer.pause(changelog_tps)

if self._state_manager.stores:
non_changelog_tps = [
tp for tp in topic_partitions if tp.topic in non_changelog_topics
]
committed_tps = self._consumer.committed(
partitions=non_changelog_tps, timeout=30
)
committed_offsets: dict[int, dict[str, int]] = defaultdict(dict)
for tp in committed_tps:
if tp.error:
raise RuntimeError(
f"Failed to get committed offsets for "
f'"{tp.topic}[{tp.partition}]" from the broker: {tp.error}'
)
committed_offsets[tp.partition][tp.topic] = tp.offset
data_topics = self._topic_manager.non_changelog_topics
data_tps = [tp for tp in topic_partitions if tp.topic in data_topics]

for tp in data_tps:
self._assign_state_partitions(topic=tp.topic, partition=tp.partition)

# Match the assigned TP with a stream ID via DataFrameRegistry
for tp in non_changelog_tps:
stream_ids = self._dataframe_registry.get_stream_ids(
topic_name=tp.topic
)
# Assign store partitions for the given stream ids
for stream_id in stream_ids:
self._state_manager.on_partition_assign(
stream_id=stream_id,
partition=tp.partition,
committed_offsets=committed_offsets[tp.partition],
)
self._run_tracker.timeout_refresh()

def _on_revoke(self, _, topic_partitions: List[TopicPartition]):
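
The TODOs above ask whether the watermarks topic really has to be replayed from the beginning on every restart. One possible direction, sketched here purely as an idea and not as what this PR implements, would be to assign those partitions at the log end so only freshly produced watermarks are consumed (active partitions keep publishing, so the state would converge):

```python
# Illustrative only: assign the watermarks partitions at the end of the log
# so a restart does not replay the whole topic. Not what this PR implements.
from confluent_kafka import OFFSET_END, TopicPartition


def watermark_tail_assignment(topic_name: str, num_partitions: int) -> list[TopicPartition]:
    # Start each watermark partition from the log end; data partitions keep
    # their normal committed offsets and are assigned separately.
    return [
        TopicPartition(topic_name, partition, OFFSET_END)
        for partition in range(num_partitions)
    ]
```
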
@@ -1079,7 +1120,12 @@ def _on_revoke(self, _, topic_partitions: List[TopicPartition]):
else:
self._processing_context.commit_checkpoint(force=True)

self._revoke_state_partitions(topic_partitions=topic_partitions)
data_topics = self._topic_manager.non_changelog_topics
data_tps = [tp for tp in topic_partitions if tp.topic in data_topics]
for tp in data_tps:
self._watermark_manager.untrack(topic=tp.topic, partition=tp.partition)
self._revoke_state_partitions(topic=tp.topic, partition=tp.partition)

self._consumer.reset_backpressure()

def _on_lost(self, _, topic_partitions: List[TopicPartition]):
@@ -1088,23 +1134,34 @@ def _on_lost(self, _, topic_partitions: List[TopicPartition]):
"""
logger.debug("Rebalancing: dropping lost partitions")

self._revoke_state_partitions(topic_partitions=topic_partitions)
data_tps = [
tp
for tp in topic_partitions
if tp.topic in self._topic_manager.non_changelog_topics
]
for tp in data_tps:
self._watermark_manager.untrack(topic=tp.topic, partition=tp.partition)
self._revoke_state_partitions(topic=tp.topic, partition=tp.partition)

self._consumer.reset_backpressure()

def _revoke_state_partitions(self, topic_partitions: List[TopicPartition]):
non_changelog_topics = self._topic_manager.non_changelog_topics
non_changelog_tps = [
tp for tp in topic_partitions if tp.topic in non_changelog_topics
]
for tp in non_changelog_tps:
if self._state_manager.stores:
stream_ids = self._dataframe_registry.get_stream_ids(
topic_name=tp.topic
def _assign_state_partitions(self, topic: str, partition: int):
if self._state_manager.stores:
# Match the assigned TP with a stream ID via DataFrameRegistry
stream_ids = self._dataframe_registry.get_stream_ids(topic_name=topic)
# Assign store partitions for the given stream ids
for stream_id in stream_ids:
self._state_manager.on_partition_assign(
stream_id=stream_id, partition=partition
)

def _revoke_state_partitions(self, topic: str, partition: int):
if self._state_manager.stores:
stream_ids = self._dataframe_registry.get_stream_ids(topic_name=topic)
for stream_id in stream_ids:
self._state_manager.on_partition_revoke(
stream_id=stream_id, partition=partition
)
for stream_id in stream_ids:
self._state_manager.on_partition_revoke(
stream_id=stream_id, partition=tp.partition
)

def _setup_signal_handlers(self):
signal.signal(signal.SIGINT, self._on_sigint)
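
Compared with the old `_on_assign`, the new helpers above drop the committed-offsets plumbing: the removed code fetched `consumer.committed()` for all data partitions into a `defaultdict` (hence the deleted import at the top of this file) and passed it to the state manager, whose `on_partition_assign` is now called with only `stream_id` and `partition`. A stand-in sketch of the simplified, symmetric call shape:

```python
# Stand-in illustrating the simplified calls into the state manager.
# Before: on_partition_assign(stream_id=..., partition=..., committed_offsets=...)
# After:  on_partition_assign(stream_id=..., partition=...)
class _StateManagerStub:
    def on_partition_assign(self, stream_id: str, partition: int) -> None:
        print(f"assigned store partition for {stream_id}[{partition}]")

    def on_partition_revoke(self, stream_id: str, partition: int) -> None:
        print(f"revoked store partition for {stream_id}[{partition}]")


def assign_state_partitions(state_manager, stream_ids, partition: int) -> None:
    for stream_id in stream_ids:
        state_manager.on_partition_assign(stream_id=stream_id, partition=partition)


def revoke_state_partitions(state_manager, stream_ids, partition: int) -> None:
    for stream_id in stream_ids:
        state_manager.on_partition_revoke(stream_id=stream_id, partition=partition)


manager = _StateManagerStub()
assign_state_partitions(manager, ["stream-a"], 0)
revoke_state_partitions(manager, ["stream-a"], 0)
```
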
49 changes: 23 additions & 26 deletions quixstreams/checkpointing/checkpoint.py
@@ -1,3 +1,4 @@
import abc
import logging
import time
from abc import abstractmethod
@@ -26,7 +27,7 @@
logger = logging.getLogger(__name__)


class BaseCheckpoint:
class BaseCheckpoint(abc.ABC):
"""
Base class to keep track of state updates and consumer offsets and to checkpoint these
updates on schedule.
@@ -70,7 +71,7 @@ def empty(self) -> bool:
Returns `True` if checkpoint doesn't have any offsets stored yet.
:return:
"""
return not bool(self._tp_offsets)
return not bool(self._tp_offsets) and not bool(self._store_transactions)

def store_offset(self, topic: str, partition: int, offset: int):
"""
@@ -228,20 +229,12 @@ def commit(self):
partition,
store_name,
), transaction in self._store_transactions.items():
topics = self._dataframe_registry.get_topics_for_stream_id(
stream_id=stream_id
)
processed_offsets = {
topic: offset
for (topic, partition_), offset in self._tp_offsets.items()
if topic in topics and partition_ == partition
}
if transaction.failed:
raise StoreTransactionFailed(
f'Detected a failed transaction for store "{store_name}", '
f"the checkpoint is aborted"
)
transaction.prepare(processed_offsets=processed_offsets)
transaction.prepare()

# Step 3. Flush producer to trigger all delivery callbacks and ensure that
# all messages are produced
@@ -258,21 +251,25 @@ def commit(self):
TopicPartition(topic=topic, partition=partition, offset=offset + 1)
for (topic, partition), offset in self._tp_offsets.items()
]

if self._exactly_once:
self._producer.commit_transaction(
offsets, self._consumer.consumer_group_metadata()
)
else:
logger.debug("Checkpoint: committing consumer")
try:
partitions = self._consumer.commit(offsets=offsets, asynchronous=False)
except KafkaException as e:
raise CheckpointConsumerCommitError(e.args[0]) from None

for partition in partitions:
if partition.error:
raise CheckpointConsumerCommitError(partition.error)
if offsets:
# TODO: Test, update the exactly-once branch to work without offsets
# Checkpoint may have no offsets processed when watermarks are processed
if self._exactly_once:
self._producer.commit_transaction(
offsets, self._consumer.consumer_group_metadata()
)
else:
logger.debug("Checkpoint: committing consumer")
try:
partitions = self._consumer.commit(
offsets=offsets, asynchronous=False
)
except KafkaException as e:
raise CheckpointConsumerCommitError(e.args[0]) from None

for partition in partitions:
if partition.error:
raise CheckpointConsumerCommitError(partition.error)

# Step 5. Flush state store partitions to the disk together with changelog
# offsets.
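
Because a watermark-only checkpoint can have an empty offset map, the consumer-commit step is now skipped entirely when there is nothing to commit, while the store flush in step 5 still runs; per the TODO, the exactly-once branch is also gated on having offsets. A condensed sketch of the guarded commit path, assuming the confluent-kafka API and simplified error handling:

```python
# Condensed sketch of the guarded commit path above. The consumer/producer are
# assumed to follow the confluent-kafka API; error handling is simplified.
from confluent_kafka import KafkaException, TopicPartition


def commit_offsets(consumer, producer, tp_offsets: dict, exactly_once: bool,
                   group_metadata=None) -> None:
    offsets = [
        # Kafka expects the *next* offset to be consumed, hence offset + 1.
        TopicPartition(topic=topic, partition=partition, offset=offset + 1)
        for (topic, partition), offset in tp_offsets.items()
    ]
    if not offsets:
        # Watermark-only checkpoint: no messages consumed, nothing to commit.
        return
    if exactly_once:
        producer.commit_transaction(offsets, group_metadata)
    else:
        try:
            committed = consumer.commit(offsets=offsets, asynchronous=False)
        except KafkaException as exc:
            raise RuntimeError(exc.args[0]) from None
        for partition in committed:
            if partition.error:
                raise RuntimeError(partition.error)
```
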