Skip to content

Commit

Permalink
**Update Dependencies and Enhance Slack Plugin Functionality**
Browse files Browse the repository at this point in the history
**Changes:**
1. **Dependency Updates:**
   - Updated the `solace_ai_connector` dependency in `pyproject.toml` from `>=0.0.1` to `>=0.1.3` to incorporate the latest features and fixes.

2. **Enhancements to `slack_input.py`:**
   - Added the `input_type` field to `user_properties` to specify the source as Slack.
   - Improved message acknowledgment logic to ensure the acknowledgment message is only sent for direct messages (channel type is `im`).
   - Enhanced the `process_text_for_mentions` method to improve performance by skipping unnecessary processing when no mentions exist in the text.

3. **Enhancements to `slack_output.py`:**
   - Introduced `streaming_state` management to track and handle message streaming states and ensure accurate Slack message updates.
   - Added mechanisms to handle the first and last chunks of streamed content, ensuring the state is updated correctly.
   - Incorporated logic to manage the lifetime of streaming states, automatically aging out old states to maintain performance.
   - Provided better handling for message indexing and bulk updates to ensure efficient Slack message posting without redundancy.
   - Added error handling to clean up acknowledgment messages after message streaming is complete.

**Overall Impact:**
These enhancements improve the reliability and efficiency of the Slack input and output components within the Solace AI Connector, ensuring smoother operation and better handling of streamed messages, acknowledgments, and user mentions.
  • Loading branch information
efunneko committed Jul 29, 2024
1 parent f6c1b62 commit 7aefd09
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ classifiers = [
dependencies = [
"PyYAML>=6.0.1",
"slack_bolt>=1.18.1",
"solace_ai_connector>=0.1.1",
"solace_ai_connector>=0.1.3",
]

[project.urls]
Expand Down
5 changes: 4 additions & 1 deletion src/solace_ai_connector_slack/components/slack_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,10 @@ def handle_event(self, event):
"event_ts": event.get("event_ts"),
"channel_type": event.get("channel_type"),
"user_id": event.get("user"),
"input_type": "slack",
}

if self.acknowledgement_message:
if self.acknowledgement_message and event.get("channel_type") == "im":
ack_msg_ts = self.app.client.chat_postMessage(
channel=event["channel"],
text=self.acknowledgement_message,
Expand All @@ -323,6 +324,8 @@ def get_user_email(self, user_id):

def process_text_for_mentions(self, text):
mention_emails = []
if "<@" not in text:
return text, mention_emails
for mention in text.split("<@"):
if mention.startswith("!"):
mention = mention[1:]
Expand Down
92 changes: 85 additions & 7 deletions src/solace_ai_connector_slack/components/slack_output.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
import re
from datetime import datetime


from solace_ai_connector.common.log import log
Expand Down Expand Up @@ -114,23 +115,35 @@ class SlackOutput(SlackBase):
def __init__(self, **kwargs):
super().__init__(info, **kwargs)
self.fix_formatting = self.get_config("correct_markdown_formatting", True)
self.streaming_state = {}

def invoke(self, message, data):
message_info = data.get("message_info")
content = data.get("content")
text = content.get("text")
stream = content.get("stream")
first_streamed_chunk = content.get("first_streamed_chunk")
last_streamed_chunk = content.get("last_streamed_chunk")
uuid = content.get("uuid")
channel = message_info.get("channel")
thread_ts = message_info.get("ts")
ack_msg_ts = message_info.get("ack_msg_ts")

if not channel:
log.error("slack_output: No channel specified in message")
self.discard_current_message()
return None

return {
"channel": channel,
"text": text,
"files": content.get("files"),
"thread_ts": thread_ts,
"ack_msg_ts": ack_msg_ts,
"stream": stream,
"first_streamed_chunk": first_streamed_chunk,
"last_streamed_chunk": last_streamed_chunk,
"uuid": uuid,
}

def send_message(self, message):
Expand All @@ -141,30 +154,69 @@ def send_message(self, message):
files = message.get_data("previous:files") or []
thread_ts = message.get_data("previous:ts")
ack_msg_ts = message.get_data("previous:ack_msg_ts")
first_streamed_chunk = message.get_data("previous:first_streamed_chunk")
last_streamed_chunk = message.get_data("previous:last_streamed_chunk")
uuid = message.get_data("previous:uuid")

if not isinstance(messages, list):
if messages is not None:
messages = [messages]
else:
messages = []

for text in messages:
for index, text in enumerate(messages):
if not text or not isinstance(text, str):
continue

if self.fix_formatting:
text = self.fix_markdown(text)

if index != 0:
text = "\n" + text

if first_streamed_chunk:
streaming_state = self.add_streaming_state(uuid)
else:
streaming_state = self.get_streaming_state(uuid)
if not streaming_state:
streaming_state = self.add_streaming_state(uuid)

if stream:
if ack_msg_ts:
if streaming_state.get("completed"):
# We can sometimes get a message after the stream has completed
continue

streaming_state["completed"] = last_streamed_chunk
ts = streaming_state.get("ts")
if ts:
try:
self.app.client.chat_update(
channel=channel, ts=ack_msg_ts, text=text
channel=channel, ts=ts, text=text
)
except Exception:
# It is normal to possibly get an update after the final
# message has already arrived and deleted the ack message
pass
else:
response = self.app.client.chat_postMessage(
channel=channel, text=text, thread_ts=thread_ts
)
streaming_state["ts"] = response["ts"]

else:
self.app.client.chat_postMessage(
channel=channel, text=text, thread_ts=thread_ts
)
# Not streaming
ts = streaming_state.get("ts")
streaming_state["completed"] = True
if not ts:
self.app.client.chat_postMessage(
channel=channel, text=text, thread_ts=thread_ts
)
# if ts:
# self.app.client.chat_update(channel=channel, ts=ts, text=text)
# else:
# self.app.client.chat_postMessage(
# channel=channel, text=text, thread_ts=thread_ts
# )

for file in files:
file_content = base64.b64decode(file["content"])
Expand All @@ -180,7 +232,7 @@ def send_message(self, message):
super().send_message(message)

try:
if ack_msg_ts and not stream:
if ack_msg_ts:
self.app.client.chat_delete(channel=channel, ts=ack_msg_ts)
except Exception:
pass
Expand All @@ -194,3 +246,29 @@ def fix_markdown(self, message):
# Fix bold
message = re.sub(r"\*\*(.*?)\*\*", r"*\1*", message)
return message

def get_streaming_state(self, uuid):
return self.streaming_state.get(uuid)

def add_streaming_state(self, uuid):
state = {
"create_time": datetime.now(),
}
self.streaming_state[uuid] = state
self.age_out_streaming_state()
return state

def delete_streaming_state(self, uuid):
try:
del self.streaming_state[uuid]
except KeyError:
pass

def age_out_streaming_state(self, age=60):
# Note that we can later optimize this by using an array of streaming_state that
# is ordered by create_time and then we can just remove the first element until
# we find one that is not expired.
now = datetime.now()
for uuid, state in list(self.streaming_state.items()):
if (now - state["create_time"]).total_seconds() > age:
del self.streaming_state[uuid]

0 comments on commit 7aefd09

Please sign in to comment.