From 2d286960b4197cc66b34950a2e7e448681954ce8 Mon Sep 17 00:00:00 2001 From: Dan Herrera Date: Fri, 4 Oct 2024 11:03:34 -0700 Subject: [PATCH] Update Bytewax version to 0.21.0 --- dataflow.py | 10 +- requirements.txt | 7 +- step1.py | 33 ++--- step2.py | 71 ++++++----- step3.py | 117 +++++++++--------- step4.py | 41 ++++--- step5.py | 173 +++++++++++++------------- step6.py | 189 ++++++++++++++--------------- utils/connectors/slack/__init__.py | 3 - utils/connectors/slack/message.py | 3 +- utils/connectors/slack/sink.py | 4 +- utils/connectors/slack/source.py | 8 +- utils/proxy.py | 1 + 13 files changed, 319 insertions(+), 341 deletions(-) diff --git a/dataflow.py b/dataflow.py index fcc804f..4f9f4ca 100644 --- a/dataflow.py +++ b/dataflow.py @@ -6,12 +6,7 @@ import dotenv - -def _build_dataflow() -> Dataflow: - flow = Dataflow("supercharged-slackbot") - - return flow - +flow = Dataflow("supercharged-slackbot") # Load environment variables from .env dotenv.load_dotenv() @@ -21,6 +16,3 @@ def _build_dataflow() -> Dataflow: format="%(asctime)s %(levelname)-7s %(message)s", handlers=[logging.StreamHandler()], ) - -# Dataflow needs to be assigned to a global variable called "flow" -flow = _build_dataflow() diff --git a/requirements.txt b/requirements.txt index 8decade..68cfa11 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ -bytewax==0.18.0 +bytewax==0.21.0 annotated-types==0.6.0 anyio==4.2.0 certifi==2024.2.2 charset-normalizer==3.3.2 colorama==0.4.6 coloredlogs==15.0.1 -fastembed==0.1.1 +fastembed==0.3.6 flatbuffers==23.5.26 grpcio==1.60.1 grpcio-tools==1.60.1 @@ -19,7 +19,7 @@ hyperframe==6.0.1 idna==3.6 mpmath==1.3.0 numpy==1.26.4 -onnx==1.15.0 +onnx==1.17.0 onnxruntime==1.17.0 packaging==23.2 portalocker==2.8.2 @@ -32,7 +32,6 @@ qdrant-client[fastembed]==1.7.3 requests==2.31.0 sniffio==1.3.0 sympy==1.12 -tokenizers==0.13.3 tqdm==4.66.2 typing_extensions==4.9.0 urllib3==2.2.0 diff --git a/step1.py b/step1.py index 7f72e67..55af546 100644 --- a/step1.py +++ b/step1.py @@ -18,6 +18,9 @@ log = logging.getLogger(__name__) +# Load environment variables from .env +dotenv.load_dotenv() + def channel_is(channel: str) -> Callable[[SlackMessage], bool]: """Predicate function to check if the message was posted on the given channel.""" @@ -28,33 +31,23 @@ def _func(msg: SlackMessage) -> bool: return _func -def _build_dataflow() -> Dataflow: - # Create a bytewax stream object. - flow = Dataflow("supercharged-slackbot") - - # Data will be flowing in from the Slack stream. - stream = op.input("input", flow, SlackSource(url=os.environ["SLACK_PROXY_URL"])) - - # Inspect will show what entries are in the stream. - op.inspect_debug("debug", stream) +# Create a bytewax stream object. +flow = Dataflow("supercharged-slackbot") - # Filter the messages based on which Slack channel they were posted on. - stream = op.filter("filter_channel", stream, channel_is(os.environ["SLACK_CHANNEL_ID"])) +# Data will be flowing in from the Slack stream. +stream = op.input("input", flow, SlackSource(url=os.environ["SLACK_PROXY_URL"])) - # Output the messages into the console - op.output("output", stream, StdOutSink()) +# Inspect will show what entries are in the stream. +op.inspect_debug("debug", stream) - return flow +# Filter the messages based on which Slack channel they were posted on. +stream = op.filter("filter_channel", stream, channel_is(os.environ["SLACK_CHANNEL_ID"])) - -# Load environment variables from .env -dotenv.load_dotenv() +# Output the messages into the console +op.output("output", stream, StdOutSink()) logging.basicConfig( level=logging.DEBUG, format="%(asctime)s %(levelname)-7s %(message)s", handlers=[logging.StreamHandler()], ) - -# Dataflow needs to be assigned to a global variable called "flow" -flow = _build_dataflow() diff --git a/step2.py b/step2.py index d02bd22..ec6c780 100644 --- a/step2.py +++ b/step2.py @@ -9,12 +9,13 @@ from datetime import timedelta from datetime import timezone -import bytewax.operators as op import dotenv + +import bytewax.operators as op +import bytewax.operators.windowing as win from bytewax.connectors.stdio import StdOutSink from bytewax.dataflow import Dataflow -from bytewax.operators.window import EventClockConfig -from bytewax.operators.window import TumblingWindow +from bytewax.operators.windowing import EventClock, TumblingWindower from utils.connectors.slack import SlackMessage from utils.connectors.slack import SlackSource @@ -47,54 +48,50 @@ def get_message_channel(msg: SlackMessage) -> str: return msg.channel -def _build_dataflow() -> Dataflow: - # Create a bytewax stream object. - flow = Dataflow("supercharged-slackbot") +def get_timestamp(msg) -> datetime: + return msg.timestamp - # Data will be flowing in from the Slack stream. - stream = op.input("input", flow, SlackSource(url=os.environ["SLACK_PROXY_URL"])) - keyed_stream = op.key_on("key_on_channel", stream, get_message_channel) +# Load environment variables from .env +dotenv.load_dotenv() - # Filter the messages based on which Slack channel they were posted on. - filtered_stream = op.filter( - "filter_channel", keyed_stream, channel_is(os.environ["SLACK_CHANNEL_ID"]) - ) - # Branch the stream into two: one for bot mentions, one for the rest - b_out = op.branch("is_mention", filtered_stream, is_mention) +# Create a bytewax stream object. +flow = Dataflow("supercharged-slackbot") - messages = b_out.falses - mentions = b_out.trues +# Data will be flowing in from the Slack stream. +stream = op.input("input", flow, SlackSource(url=os.environ["SLACK_PROXY_URL"])) - # Inspect what messages got to which stream - op.inspect_debug("message", messages) - op.inspect_debug("mention", mentions) +keyed_stream = op.key_on("key_on_channel", stream, get_message_channel) - # We use windowing to throttle the amount of requests we are making to the - # LLM API. - clock = EventClockConfig( - lambda msg: msg.timestamp, wait_for_system_duration=timedelta(seconds=0) - ) - windower = TumblingWindow( - length=timedelta(seconds=10), align_to=datetime(2024, 2, 1, tzinfo=timezone.utc) - ) - windowed_messages = op.window.collect_window("window", messages, clock, windower) +# Filter the messages based on which Slack channel they were posted on. +filtered_stream = op.filter( + "filter_channel", keyed_stream, channel_is(os.environ["SLACK_CHANNEL_ID"]) +) - # Output the message windows into the console - op.output("output", windowed_messages, StdOutSink()) +# Branch the stream into two: one for bot mentions, one for the rest +b_out = op.branch("is_mention", filtered_stream, is_mention) - return flow +messages = b_out.falses +mentions = b_out.trues +# Inspect what messages got to which stream +op.inspect_debug("message", messages) +op.inspect_debug("mention", mentions) -# Load environment variables from .env -dotenv.load_dotenv() +# We use windowing to throttle the amount of requests we are making to the +# LLM API. +clock = EventClock(get_timestamp, wait_for_system_duration=timedelta(seconds=0)) +windower = TumblingWindower( + length=timedelta(seconds=10), align_to=datetime(2024, 2, 1, tzinfo=timezone.utc) +) +windowed_messages = win.collect_window("window", messages, clock, windower) + +# Output the message windows into the console +op.output("output", windowed_messages.down, StdOutSink()) logging.basicConfig( level=logging.DEBUG, format="%(asctime)s %(levelname)-7s %(message)s", handlers=[logging.StreamHandler()], ) - -# Dataflow needs to be assigned to a global variable called "flow" -flow = _build_dataflow() diff --git a/step3.py b/step3.py index a81aa34..ce44092 100644 --- a/step3.py +++ b/step3.py @@ -4,19 +4,17 @@ import logging import os -from typing import Callable -from typing import NewType +from typing import Callable, Optional, NewType from datetime import datetime from datetime import timedelta from datetime import timezone import dotenv import bytewax.operators as op +import bytewax.operators.windowing as win from bytewax.connectors.stdio import StdOutSink from bytewax.dataflow import Dataflow -from bytewax.operators.window import EventClockConfig -from bytewax.operators.window import TumblingWindow -from bytewax.operators.window import WindowMetadata +from bytewax.operators.windowing import EventClock, TumblingWindower, WindowMetadata import openai @@ -25,8 +23,23 @@ log = logging.getLogger(__name__) +# Load environment variables from .env +dotenv.load_dotenv() + +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s %(levelname)-7s %(message)s", + handlers=[logging.StreamHandler()], +) + + +def get_timestamp(msg) -> datetime: + return msg.timestamp + + Summary = NewType("Summary", str) + def get_message_channel(msg: SlackMessage) -> str: """Extract the channel identifier from a message.""" return msg.channel @@ -77,19 +90,23 @@ def __init__(self): {summary} """ - def create_initial_state(self) -> Summary: + @classmethod + def create_initial_state(cls) -> Summary: """Get initial state for the stateful stream step.""" return Summary("No-one has said anything yet.") - def __call__( - self, previous_state: str, item: tuple[WindowMetadata, list[SlackMessage]] + def new_message( + self, previous_state: Optional[Summary], item: tuple[int, list[SlackMessage]] ) -> tuple[Summary, Summary]: """This is called whenewer a new window of messages arrive. It gets the previous state as the first argument, and returns the new state and an object to be passed downstream. """ - _, messages = item # we don't need the window metadata here + if previous_state is None: + previous_state = Summarizer.create_initial_state() + + _, messages = item # we don't need the window id system_prompt = self._prompt.format(summary=previous_state) @@ -111,63 +128,45 @@ def __call__( return new_state, summary -def _build_dataflow() -> Dataflow: - # Create a bytewax stream object. - flow = Dataflow("supercharged-slackbot") +# Create a bytewax stream object. +flow = Dataflow("supercharged-slackbot") - # Data will be flowing in from the Slack stream. - stream = op.input("input", flow, SlackSource(url=os.environ["SLACK_PROXY_URL"])) +# Data will be flowing in from the Slack stream. +stream = op.input("input", flow, SlackSource(url=os.environ["SLACK_PROXY_URL"])) - # Key the stream elements based on the channel id. In here we are not processing - # any channels separately, but this approach very much allows it. The windowing - # step requires a keyed stream, so that's why we are adding it here. - keyed_stream = op.key_on("key_on_channel", stream, get_message_channel) - - # Filter the messages based on which Slack channel they were posted on. - filtered_stream = op.filter( - "filter_channel", keyed_stream, channel_is(os.environ["SLACK_CHANNEL_ID"]) - ) +# Key the stream elements based on the channel id. In here we are not processing +# any channels separately, but this approach very much allows it. The windowing +# step requires a keyed stream, so that's why we are adding it here. +keyed_stream = op.key_on("key_on_channel", stream, get_message_channel) - # Branch the stream into two: one for bot mentions, one for the rest - b_out = op.branch("is_mention", filtered_stream, is_mention) - - messages = b_out.falses - mentions = b_out.trues - - # Inspect what messages got to which stream - op.inspect_debug("message", messages) - op.inspect_debug("mention", mentions) - - # We use windowing to throttle the amount of requests we are making to the - # LLM API. - clock = EventClockConfig( - lambda msg: msg.timestamp, wait_for_system_duration=timedelta(seconds=0) - ) - windower = TumblingWindow( - length=timedelta(seconds=10), align_to=datetime(2024, 1, 1, tzinfo=timezone.utc) - ) - windowed_messages = op.window.collect_window("window", messages, clock, windower) - - # Create a stateful step which keeps track of the current discussion summary - summarizer = Summarizer() - summary_stream = op.stateful_map( - "summarize", windowed_messages, summarizer.create_initial_state, summarizer - ) +# Filter the messages based on which Slack channel they were posted on. +filtered_stream = op.filter( + "filter_channel", keyed_stream, channel_is(os.environ["SLACK_CHANNEL_ID"]) +) - # Output the message windows into the console - op.output("output", summary_stream, StdOutSink()) +# Branch the stream into two: one for bot mentions, one for the rest +b_out = op.branch("is_mention", filtered_stream, is_mention) - return flow +messages = b_out.falses +mentions = b_out.trues +# Inspect what messages got to which stream +op.inspect_debug("message", messages) +op.inspect_debug("mention", mentions) -# Load environment variables from .env -dotenv.load_dotenv() +# We use windowing to throttle the amount of requests we are making to the +# LLM API. +clock = EventClock(get_timestamp, wait_for_system_duration=timedelta(seconds=0)) +windower = TumblingWindower( + length=timedelta(seconds=10), align_to=datetime(2024, 1, 1, tzinfo=timezone.utc) +) +windowed_messages = win.collect_window("window", messages, clock, windower) -logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s %(levelname)-7s %(message)s", - handlers=[logging.StreamHandler()], +# Create a stateful step which keeps track of the current discussion summary +summarizer = Summarizer() +summary_stream = op.stateful_map( + "summarize", windowed_messages.down, summarizer.new_message ) -# Dataflow needs to be assigned to a global variable called "flow" -flow = _build_dataflow() +# Output the message windows into the console +op.output("output", summary_stream, StdOutSink()) diff --git a/step4.py b/step4.py index d56d89a..08b8fb7 100644 --- a/step4.py +++ b/step4.py @@ -4,19 +4,17 @@ import logging import os -from typing import Callable -from typing import NewType +from typing import Callable, Optional, NewType from datetime import datetime from datetime import timedelta from datetime import timezone import dotenv import bytewax.operators as op +import bytewax.operators.windowing as win from bytewax.connectors.stdio import StdOutSink from bytewax.dataflow import Dataflow -from bytewax.operators.window import EventClockConfig -from bytewax.operators.window import TumblingWindow -from bytewax.operators.window import WindowMetadata +from bytewax.operators.windowing import EventClock, TumblingWindower, WindowMetadata import openai @@ -26,10 +24,24 @@ log = logging.getLogger(__name__) +# Load environment variables from .env +dotenv.load_dotenv() + +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s %(levelname)-7s %(message)s", + handlers=[logging.StreamHandler()], +) + + Summary = NewType("Summary", str) Context = NewType("Context", list[str]) +def get_timestamp(msg) -> datetime: + return msg.timestamp + + def get_message_channel(msg: SlackMessage) -> str: """Extract the channel identifier from a message.""" return msg.channel @@ -84,15 +96,18 @@ def create_initial_state(self) -> Summary: """Get initial state for the stateful stream step.""" return Summary("No-one has said anything yet.") - def __call__( - self, previous_state: str, item: tuple[WindowMetadata, list[SlackMessage]] + def new_message( + self, previous_state: Optional[Summary], item: tuple[int, list[SlackMessage]] ) -> tuple[Summary, Summary]: """This is called whenewer a new window of messages arrive. It gets the previous state as the first argument, and returns the new state and an object to be passed downstream. """ - _, messages = item # we don't need the window metadata here + if previous_state is None: + previous_state = Summary("No one has said anything yet.") + + _, messages = item # we don't need the window id here system_prompt = self._prompt.format(summary=previous_state) @@ -168,18 +183,16 @@ def _build_dataflow() -> Dataflow: # We use windowing to throttle the amount of requests we are making to the # LLM API. - clock = EventClockConfig( - lambda msg: msg.timestamp, wait_for_system_duration=timedelta(seconds=0) - ) - windower = TumblingWindow( + clock = EventClock(get_timestamp, wait_for_system_duration=timedelta(seconds=0)) + windower = TumblingWindower( length=timedelta(seconds=10), align_to=datetime(2024, 1, 1, tzinfo=timezone.utc) ) - windowed_messages = op.window.collect_window("window", messages, clock, windower) + windowed_messages = win.collect_window("window", messages, clock, windower) # Create a stateful step which keeps track of the current discussion summary summarizer = Summarizer() summary_stream = op.stateful_map( - "summarize", windowed_messages, summarizer.create_initial_state, summarizer + "summarize", windowed_messages.down, summarizer.new_message ) mentions_with_context = op.map( diff --git a/step5.py b/step5.py index bc75e82..9b096c9 100644 --- a/step5.py +++ b/step5.py @@ -4,8 +4,7 @@ import logging import os -from typing import Callable -from typing import NewType +from typing import Callable, Optional, NewType from datetime import datetime from datetime import timedelta from datetime import timezone @@ -13,11 +12,10 @@ import dataclasses import bytewax.operators as op +import bytewax.operators.windowing as win from bytewax.connectors.stdio import StdOutSink from bytewax.dataflow import Dataflow -from bytewax.operators.window import EventClockConfig -from bytewax.operators.window import TumblingWindow -from bytewax.operators.window import WindowMetadata +from bytewax.operators.windowing import EventClock, TumblingWindower, WindowMetadata import openai @@ -27,10 +25,24 @@ log = logging.getLogger(__name__) +# Load environment variables from .env +dotenv.load_dotenv() + +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s %(levelname)-7s %(message)s", + handlers=[logging.StreamHandler()], +) + + Summary = NewType("Summary", str) Context = NewType("Context", list[str]) +def get_timestamp(msg) -> datetime: + return msg.timestamp + + @dataclasses.dataclass class AugmentedMessage: """Extension of the SlackMessage, with fields for summary and context.""" @@ -107,15 +119,18 @@ def create_initial_state(cls) -> Summary: """Get initial state for the stateful stream step.""" return Summary("No-one has said anything yet.") - def __call__( - self, previous_state: str, item: tuple[WindowMetadata, list[SlackMessage]] + def new_message( + self, previous_state: Optional[Summary], item: tuple[int, list[SlackMessage]] ) -> tuple[Summary, Summary]: """This is called whenewer a new window of messages arrive. It gets the previous state as the first argument, and returns the new state and an object to be passed downstream. """ - _, messages = item # we don't need the window metadata here + if previous_state is None: + previous_state = Summarizer.create_initial_state() + + _, messages = item # we don't need the window id here system_prompt = self._prompt.format(summary=previous_state) @@ -155,8 +170,9 @@ def _func( def join_summary_to_question( - previous_question_id: str | None, item: tuple[AugmentedMessage, Summary] -) -> tuple[str, tuple[AugmentedMessage, bool]]: + previous_question_id: str | None, + item: tuple[AugmentedMessage | None, Summary | None], +) -> tuple[str | None, tuple[AugmentedMessage | None, bool]]: """Join the summary data with the question message.""" message, summary = item @@ -171,104 +187,81 @@ def join_summary_to_question( def has_unique_flag_set( - item: tuple[str, tuple[AugmentedMessage, bool]], + item: tuple[str, tuple[AugmentedMessage | None, bool]], ) -> bool: _, (_, is_unique) = item return is_unique -def _build_dataflow() -> Dataflow: - # Initialize a vector database in-memory - document_storage = DocumentDatabase(model="BAAI/bge-small-en-v1.5") +# Initialize a vector database in-memory +document_storage = DocumentDatabase(model="BAAI/bge-small-en-v1.5") - # Load the preloaded documents - # This step will calculate the embeddings for all of the chapters in the - # document. - log.info("Loading documents to document database...") - document_storage.upload_text_chapterwise("data/dataset.txt") - log.info("Document loading finished") +# Load the preloaded documents +# This step will calculate the embeddings for all of the chapters in the +# document. +log.info("Loading documents to document database...") +document_storage.upload_text_chapterwise("data/dataset.txt") +log.info("Document loading finished") - # Create a bytewax stream object. - flow = Dataflow("supercharged-slackbot") +# Create a bytewax stream object. +flow = Dataflow("supercharged-slackbot") - # Data will be flowing in from the Slack stream. - stream = op.input("input", flow, SlackSource(url=os.environ["SLACK_PROXY_URL"])) +# Data will be flowing in from the Slack stream. +stream = op.input("input", flow, SlackSource(url=os.environ["SLACK_PROXY_URL"])) - # Key the stream elements based on the channel id. In here we are not processing - # any channels separately, but this approach very much allows it. The windowing - # step requires a keyed stream, so that's why we are adding it here. - keyed_stream = op.key_on("key_on_channel", stream, get_message_channel) - - # Filter the messages based on which Slack channel they were posted on. - filtered_stream = op.filter( - "filter_channel", keyed_stream, channel_is(os.environ["SLACK_CHANNEL_ID"]) - ) +# Key the stream elements based on the channel id. In here we are not processing +# any channels separately, but this approach very much allows it. The windowing +# step requires a keyed stream, so that's why we are adding it here. +keyed_stream = op.key_on("key_on_channel", stream, get_message_channel) - # Branch the stream into two: one for bot mentions, one for the rest - b_out = op.branch("is_mention", filtered_stream, is_mention) - - messages = b_out.falses - mentions = b_out.trues - - # Inspect what messages got to which stream - op.inspect_debug("message", messages) - op.inspect_debug("mention", mentions) - - # We use windowing to throttle the amount of requests we are making to the - # LLM API. - clock = EventClockConfig( - lambda msg: msg.timestamp, wait_for_system_duration=timedelta(seconds=0) - ) - windower = TumblingWindow( - length=timedelta(seconds=10), align_to=datetime(2024, 1, 1, tzinfo=timezone.utc) - ) - windowed_messages = op.window.collect_window("window", messages, clock, windower) +# Filter the messages based on which Slack channel they were posted on. +filtered_stream = op.filter( + "filter_channel", keyed_stream, channel_is(os.environ["SLACK_CHANNEL_ID"]) +) - # Create a stateful step which keeps track of the current discussion summary - summarizer = Summarizer() - summary_stream = op.stateful_map( - "summarize", windowed_messages, summarizer.create_initial_state, summarizer - ) +# Branch the stream into two: one for bot mentions, one for the rest +b_out = op.branch("is_mention", filtered_stream, is_mention) - # Augment the message with the context from document database - mentions_with_context = op.map( - "augment_with_context", mentions, context_retriever(document_storage) - ) +messages = b_out.falses +mentions = b_out.trues - # Join the two streams back together - joined = op.join( - "join_streams", - mentions_with_context, - summary_stream, - running=True, - ) +# Inspect what messages got to which stream +op.inspect_debug("message", messages) +op.inspect_debug("mention", mentions) - # Our running join will emit a new item each time either of the upstreams changes. - # In our case we are only caring about the change in the mentions-stream. - # Thus, we only let each mention/question go through this step once by first - # flagging them with a `stateful_map` and filtering the flagged items with `filter`. - flagged = op.stateful_map( - "augment_with_summary", joined, lambda: None, join_summary_to_question - ) - - unique_questions = op.filter("filter_flagged", flagged, has_unique_flag_set) +# We use windowing to throttle the amount of requests we are making to the +# LLM API. +clock = EventClock(get_timestamp, wait_for_system_duration=timedelta(seconds=0)) +windower = TumblingWindower( + length=timedelta(seconds=10), align_to=datetime(2024, 1, 1, tzinfo=timezone.utc) +) +windowed_messages = win.collect_window("window", messages, clock, windower) - questions = op.map("remove_flag_and_key", unique_questions, lambda x: x[1][0]) +# Create a stateful step which keeps track of the current discussion summary +summarizer = Summarizer() +summary_stream = op.stateful_map( + "summarize", windowed_messages.down, summarizer.new_message +) - # Output the augmented messages into the console - op.output("output", questions, StdOutSink()) +# Augment the message with the context from document database +mentions_with_context = op.map( + "augment_with_context", mentions, context_retriever(document_storage) +) - return flow +# Join the two streams back together +joined = op.join( + "join_streams", mentions_with_context, summary_stream, emit_mode="running" +) +# Our running join will emit a new item each time either of the upstreams changes. +# In our case we are only caring about the change in the mentions-stream. +# Thus, we only let each mention/question go through this step once by first +# flagging them with a `stateful_map` and filtering the flagged items with `filter`. +flagged = op.stateful_map("augment_with_summary", joined, join_summary_to_question) -# Load environment variables from .env -dotenv.load_dotenv() +unique_questions = op.filter("filter_flagged", flagged, has_unique_flag_set) -logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s %(levelname)-7s %(message)s", - handlers=[logging.StreamHandler()], -) +questions = op.map("remove_flag_and_key", unique_questions, lambda x: x[1][0]) -# Dataflow needs to be assigned to a global variable called "flow" -flow = _build_dataflow() +# Output the augmented messages into the console +op.output("output", questions, StdOutSink()) diff --git a/step6.py b/step6.py index cca450f..6488930 100644 --- a/step6.py +++ b/step6.py @@ -4,8 +4,7 @@ import logging import os -from typing import Callable -from typing import NewType +from typing import Callable, Optional, NewType, Iterable from datetime import datetime from datetime import timedelta from datetime import timezone @@ -13,10 +12,9 @@ import dataclasses import bytewax.operators as op +import bytewax.operators.windowing as win from bytewax.dataflow import Dataflow -from bytewax.operators.window import EventClockConfig -from bytewax.operators.window import TumblingWindow -from bytewax.operators.window import WindowMetadata +from bytewax.operators.windowing import EventClock, TumblingWindower, WindowMetadata import openai @@ -27,10 +25,23 @@ log = logging.getLogger(__name__) +# Load environment variables from .env +dotenv.load_dotenv() + +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s %(levelname)-7s %(message)s", + handlers=[logging.StreamHandler()], +) + Summary = NewType("Summary", str) Context = NewType("Context", list[str]) +def get_timestamp(msg) -> datetime: + return msg.timestamp + + @dataclasses.dataclass class AugmentedMessage: """Extension of the SlackMessage, with fields for summary and context.""" @@ -107,15 +118,18 @@ def create_initial_state(cls) -> Summary: """Get initial state for the stateful stream step.""" return Summary("No-one has said anything yet.") - def __call__( - self, previous_state: str, item: tuple[WindowMetadata, list[SlackMessage]] + def new_message( + self, previous_state: Optional[Summary], item: tuple[int, list[SlackMessage]] ) -> tuple[Summary, Summary]: """This is called whenewer a new window of messages arrive. It gets the previous state as the first argument, and returns the new state and an object to be passed downstream. """ - _, messages = item # we don't need the window metadata here + if previous_state is None: + previous_state = Summarizer.create_initial_state() + + _, messages = item # we don't need the window id here system_prompt = self._prompt.format(summary=previous_state) @@ -155,19 +169,20 @@ def _func( def join_summary_to_question( - previous_question_id: str | None, item: tuple[AugmentedMessage, Summary] -) -> tuple[str, tuple[AugmentedMessage, bool]]: + previous_question_id: str | None, + item: tuple[AugmentedMessage | None, Summary | None], +) -> tuple[str | None, Iterable[tuple[AugmentedMessage, bool]]]: """Join the summary data with the question message.""" message, summary = item if message is None: - return None, (None, False) + return None, [] if summary is None: summary = Summarizer.create_initial_state() message.related_summary = summary - return message.message.id, (message, message.message.id != previous_question_id) + return message.message.id, [(message, message.message.id != previous_question_id)] def has_unique_flag_set( @@ -205,9 +220,9 @@ def __init__(self): """ def __call__(self, message: AugmentedMessage) -> SlackMessage: - system_prompt = self._prompt.format( - summary=message.related_summary, documents="\n".join([f" * {s}" for s in message.related_context]) + summary=message.related_summary, + documents="\n".join([f" * {s}" for s in message.related_context]), ) user_prompt = message.message.text @@ -231,104 +246,82 @@ def __call__(self, message: AugmentedMessage) -> SlackMessage: ) -def _build_dataflow() -> Dataflow: - # Initialize a vector database in-memory - document_storage = DocumentDatabase(model="BAAI/bge-small-en-v1.5") - - # Load the preloaded documents - # This step will calculate the embeddings for all of the chapters in the - # document. - log.info("Loading documents to document database...") - document_storage.upload_text_chapterwise("data/dataset.txt") - log.info("Document loading finished") - - # Create a bytewax stream object. - flow = Dataflow("supercharged-slackbot") - - # Data will be flowing in from the Slack stream. - stream = op.input("input", flow, SlackSource(url=os.environ["SLACK_PROXY_URL"])) - - # Key the stream elements based on the channel id. In here we are not processing - # any channels separately, but this approach very much allows it. The windowing - # step requires a keyed stream, so that's why we are adding it here. - keyed_stream = op.key_on("key_on_channel", stream, get_message_channel) - - # Filter the messages based on which Slack channel they were posted on. - filtered_stream = op.filter( - "filter_channel", keyed_stream, channel_is(os.environ["SLACK_CHANNEL_ID"]) - ) - - # Branch the stream into two: one for bot mentions, one for the rest - b_out = op.branch("is_mention", filtered_stream, is_mention) +# Initialize a vector database in-memory +document_storage = DocumentDatabase(model="BAAI/bge-small-en-v1.5") - messages = b_out.falses - mentions = b_out.trues +# Load the preloaded documents +# This step will calculate the embeddings for all of the chapters in the +# document. +log.info("Loading documents to document database...") +document_storage.upload_text_chapterwise("data/dataset.txt") +log.info("Document loading finished") - # Inspect what messages got to which stream - op.inspect_debug("message", messages) - op.inspect_debug("mention", mentions) +# Create a bytewax stream object. +flow = Dataflow("supercharged-slackbot") - # We use windowing to throttle the amount of requests we are making to the - # LLM API. - clock = EventClockConfig( - lambda msg: msg.timestamp, wait_for_system_duration=timedelta(seconds=0) - ) - windower = TumblingWindow( - length=timedelta(seconds=10), align_to=datetime(2024, 1, 1, tzinfo=timezone.utc) - ) - windowed_messages = op.window.collect_window("window", messages, clock, windower) +# Data will be flowing in from the Slack stream. +stream = op.input("input", flow, SlackSource(url=os.environ["SLACK_PROXY_URL"])) - # Create a stateful step which keeps track of the current discussion summary - summarizer = Summarizer() - summary_stream = op.stateful_map( - "summarize", windowed_messages, summarizer.create_initial_state, summarizer - ) +# Key the stream elements based on the channel id. In here we are not processing +# any channels separately, but this approach very much allows it. The windowing +# step requires a keyed stream, so that's why we are adding it here. +keyed_stream = op.key_on("key_on_channel", stream, get_message_channel) - # Augment the message with the context from document database - mentions_with_context = op.map( - "augment_with_context", mentions, context_retriever(document_storage) - ) +# Filter the messages based on which Slack channel they were posted on. +filtered_stream = op.filter( + "filter_channel", keyed_stream, channel_is(os.environ["SLACK_CHANNEL_ID"]) +) - # Join the two streams back together - joined = op.join( - "join_streams", - mentions_with_context, - summary_stream, - running=True, - ) +# Branch the stream into two: one for bot mentions, one for the rest +b_out = op.branch("is_mention", filtered_stream, is_mention) - # Our running join will emit a new item each time either of the upstreams changes. - # In our case we are only caring about the change in the mentions-stream. - # Thus, we only let each mention/question go through this step once by first - # flagging them with a `stateful_map` and filtering the flagged items with `filter`. - flagged = op.stateful_map( - "augment_with_summary", joined, lambda: None, join_summary_to_question - ) +messages = b_out.falses +mentions = b_out.trues - unique_questions = op.filter("filter_flagged", flagged, has_unique_flag_set) +# Inspect what messages got to which stream +op.inspect_debug("message", messages) +op.inspect_debug("mention", mentions) - questions = op.map("remove_flag_and_key", unique_questions, lambda x: x[1][0]) +# We use windowing to throttle the amount of requests we are making to the +# LLM API. +# We use windowing to throttle the amount of requests we are making to the +# LLM API. +clock = EventClock(get_timestamp, wait_for_system_duration=timedelta(seconds=0)) +windower = TumblingWindower( + length=timedelta(seconds=10), align_to=datetime(2024, 1, 1, tzinfo=timezone.utc) +) +windowed_messages = win.collect_window("window", messages, clock, windower) - # NOTE: Here one could do an additional lookup to document database based on - # the current summary, and extend the context of the message. +# Create a stateful step which keeps track of the current discussion summary +summarizer = Summarizer() +summary_stream = op.stateful_map( + "summarize", windowed_messages.down, summarizer.new_message +) - # Finally, generate a response - responses = op.map("generate", questions, Generator()) +# Augment the message with the context from document database +mentions_with_context = op.map( + "augment_with_context", mentions, context_retriever(document_storage) +) - # Finally, finally, send the reply back to the source of the question! - op.output("output", responses, SlackSink(url=os.environ["SLACK_PROXY_URL"])) +# Join the two streams back together +joined = op.join( + "join_streams", mentions_with_context, summary_stream, emit_mode="running" +) - return flow +# Our running join will emit a new item each time either of the upstreams changes. +# In our case we are only caring about the change in the mentions-stream. +# Thus, we only let each mention/question go through this step once by first +# flagging them with a `stateful_map` and filtering the flagged items with `filter`. +flagged = op.stateful_flat_map("augment_with_summary", joined, join_summary_to_question) +unique_questions = op.filter("filter_flagged", flagged, has_unique_flag_set) -# Load environment variables from .env -dotenv.load_dotenv() +questions = op.map("remove_flag_and_key", unique_questions, lambda x: x[1][0]) +# NOTE: Here one could do an additional lookup to document database based on +# the current summary, and extend the context of the message. -logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s %(levelname)-7s %(message)s", - handlers=[logging.StreamHandler()], -) +# Finally, generate a response +responses = op.map("generate", questions, Generator()) -# Dataflow needs to be assigned to a global variable called "flow" -flow = _build_dataflow() +# Finally, finally, send the reply back to the source of the question! +op.output("output", responses, SlackSink(url=os.environ["SLACK_PROXY_URL"])) diff --git a/utils/connectors/slack/__init__.py b/utils/connectors/slack/__init__.py index b4bf616..f612c35 100644 --- a/utils/connectors/slack/__init__.py +++ b/utils/connectors/slack/__init__.py @@ -1,6 +1,3 @@ - - from .message import SlackMessage from .source import SlackSource from .sink import SlackSink - diff --git a/utils/connectors/slack/message.py b/utils/connectors/slack/message.py index 9a39320..45db2f3 100644 --- a/utils/connectors/slack/message.py +++ b/utils/connectors/slack/message.py @@ -1,4 +1,5 @@ """A data structure representing a slack message.""" + import dataclasses from datetime import datetime @@ -25,4 +26,4 @@ class SlackMessage: def __str__(self) -> str: """String-representation of the message, used by StdOutSink.""" - return f"Channel {self.channel}: User {self.user} says \"{self.text}\"" + return f'Channel {self.channel}: User {self.user} says "{self.text}"' diff --git a/utils/connectors/slack/sink.py b/utils/connectors/slack/sink.py index e1f8b51..bba2a17 100644 --- a/utils/connectors/slack/sink.py +++ b/utils/connectors/slack/sink.py @@ -66,5 +66,7 @@ def _send_messages(self): log.error("Send failed, reconnecting...") break - def build(self, worker_index: int, worker_count: int) -> _SlackSinkPartition: + def build( + self, step_id: str, worker_index: int, worker_count: int + ) -> _SlackSinkPartition: return _SlackSinkPartition(queue=self._queue) diff --git a/utils/connectors/slack/source.py b/utils/connectors/slack/source.py index 9e907d0..008ba3d 100644 --- a/utils/connectors/slack/source.py +++ b/utils/connectors/slack/source.py @@ -6,9 +6,7 @@ import random import threading import time -from datetime import datetime -from datetime import timedelta -from datetime import timezone +from datetime import datetime, timedelta, timezone from typing import Iterable from typing import Optional @@ -27,7 +25,7 @@ def __init__(self, queue: queue.Queue, *args, max_batch_size: int = 10, **kwargs self._queue = queue self._max_batch_size = max_batch_size - def next_batch(self, sched: datetime) -> Iterable[SlackMessage]: + def next_batch(self) -> Iterable[SlackMessage]: batch = [] for _ in range(self._max_batch_size): try: @@ -88,7 +86,7 @@ def _receive_messages(self): def build( self, - now: datetime, + step_id: str, worker_index: int, worker_count: int, ) -> _SlackSourcePartition: diff --git a/utils/proxy.py b/utils/proxy.py index 344c5ed..7678945 100644 --- a/utils/proxy.py +++ b/utils/proxy.py @@ -1,4 +1,5 @@ """A simple websocket proxy for combating the rate limits of the Slack API.""" + from __future__ import annotations import asyncio