From 50a27eb0c9580c2d7c5068fea4fd5beb2a72f6c3 Mon Sep 17 00:00:00 2001
From: Ryan Michael
Date: Tue, 8 Aug 2023 11:01:22 -0400
Subject: [PATCH] Update notebook per Eric's prep steps (#614)

Co-authored-by: Eric Pinzur <2641606+epinzur@users.noreply.github.com>
---
 examples/slackbot/Notebook.ipynb | 232 ++++++++++++++++---------------
 1 file changed, 118 insertions(+), 114 deletions(-)

diff --git a/examples/slackbot/Notebook.ipynb b/examples/slackbot/Notebook.ipynb
index a9c05fd73..c19b05e87 100644
--- a/examples/slackbot/Notebook.ipynb
+++ b/examples/slackbot/Notebook.ipynb
@@ -22,17 +22,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "!pip install openai kaskada"
+    "%pip install openai kaskada"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "bbe4c44d",
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -41,17 +33,14 @@
    "outputs": [],
    "source": [
     "from datetime import datetime, timedelta\n",
-    "from sklearn.model_selection import train_test_split\n",
-    "from slack_sdk.socket_mode import SocketModeClient\n",
-    "import openai\n",
-    "import sparrow_pi as k\n",
-    "import sparrow_pi.sources as sources\n",
+    "from slack_sdk.socket_mode import SocketModeClient, SocketModeResponse\n",
+    "import sparrow_pi as kt\n",
     "import openai\n",
     "import getpass\n",
     "import pyarrow\n",
     "\n",
     "# Initialize Kaskada with a local execution context.\n",
-    "k.init_session()\n",
+    "kt.init_session()\n",
     "\n",
     "# Initialize OpenAI\n",
     "openai.api_key = getpass.getpass('OpenAI: API Key')\n",
@@ -74,125 +63,94 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "e95d4d50",
-   "metadata": {
-    "tags": []
-   },
+   "id": "7e6fedb9",
+   "metadata": {},
    "outputs": [],
    "source": [
-    "SYSTEM_CONTEXT = \"\"\"\n",
-    "You are a helpful assistant designed to suggest the Slack usernames of \n",
-    "people who need to know about a Slack conversation.\n",
+    "def build_conversation(messages):\n",
+    "    message_time = messages.col(\"ts\")\n",
+    "    last_message_time = message_time.lag(1) # !!!\n",
+    "    is_new_conversation = message_time.seconds_since(last_message_time) > 10 * 60\n",
     "\n",
-    "Only respond as a JSON list containing the Slack usernames of people to notify \n",
-    "of the conversation, or return an empty list if no should be notified.\n",
-    "\n",
-    "The Slack conversation is as follows, formatted as a JSON object:\n",
-    "\n",
-    "\"\"\" "
+    "    return messages \\\n",
+    "        .select(\"user\", \"ts\", \"text\", \"reactions\") \\\n",
+    "        .collect(window=kt.windows.Since(is_new_conversation), max=100) \\\n",
+    "        .select(\"user\", \"ts\", \"text\") \\\n",
+    "        .collect(window=kt.windows.Since(is_new_conversation), max=100)"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "cf3ec0e3-c943-4974-922b-9b9c965ba0ff",
-   "metadata": {
-    "tags": []
-   },
+   "id": "fdb2d959-d371-4026-9f8d-4ab26cfbf317",
+   "metadata": {},
    "outputs": [],
    "source": [
-    "def prompt(messages):\n",
-    "    # A conversation starts when a new messages is more than 5 minutes after the previous message\n",
-    "    #last_message = messages.lag(1)\n",
-    "    #since_last_message = messages.time().seconds_since(last_message.time())\n",
-    "    #conversation_start = since_last_message > k.minutes(5)\n",
-    "\n",
-    "    k.record({\n",
-    "        # A list of all messages over the past 10 minutes (up to 100)\n",
-    "        \"recent_messages\": messages\n",
-    "            #.select(\"user\", \"type\", \"text\")\n",
-    "            .select(False, \"user\", \"subtype\", \"text\")\n",
-    "            #.collect(window=since(conversation_start), max=100),\n",
-    "            .last(),\n",
-    "\n",
-    "        # How many messages have been reacted to in the conversation\n",
-    "        \"reaction_count\": messages\n",
-    "            #.filter(messages[\"reactions\"].is_not_null())\n",
-    "            #.count(window=since(conversation_start)),\n",
-    "            [\"reply_count\"].sum(),\n",
-    "    })"
+    "def build_examples(messages):\n",
+    "    duration = kt.minutes(5) # !!!\n",
+    "\n",
+    "    conversation = build_conversation(messages)\n",
+    "    shifted_conversation = conversation.shift_by(duration) # !!!\n",
+    "\n",
+    "    reaction_users = conversation.col(\"reactions\").col(\"name\").collect(kt.windows.Trailing(duration)).flatten() # !!!\n",
+    "    participating_users = conversation.col(\"user\").collect(kt.windows.Trailing(duration)) # !!!\n",
+    "    engaged_users = kt.union(reaction_users, participating_users) # !!!\n",
+    "\n",
+    "    return kt.record({ \"prompt\": shifted_conversation, \"completion\": engaged_users}) \\\n",
+    "        .filter(shifted_conversation.is_not_null())"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "fdb2d959-d371-4026-9f8d-4ab26cfbf317",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def examples(messages):\n",
-    "    # We'll train ChatGPT to generate the user ID who will engage next\n",
-    "    k.record({\n",
-    "        # For each example, use the previous prompt\n",
-    "        \"prompt\": prompt(messages)\n",
-    "            #.lag(1),\n",
-    "            .last(),\n",
-    "\n",
-    "        # ...and the current user ID\n",
-    "        \"completion\": messages[\"user\"],\n",
-    "    })"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "0be6bde5-9602-43e6-89bf-e0e1470fc7c0",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Format the data for OpenAI\n",
-    "def format_prompt(prompt):\n",
-    "    return SYSTEM_CONTEXT + json.dumps(prompt) + \"\\n\\n###\\n\\n\"\n",
-    "def format_completion(completion):\n",
-    "    return completion + \"###\""
-   ]
-  },
   {
    "cell_type": "markdown",
    "id": "0035f558-23bd-4b4d-95a0-ed5e8fece673",
    "metadata": {},
    "source": [
     "## Fine-tune the model"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "id": "af7d2a45-eb89-47ce-b471-a39ad8c7bbc7",
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Compute examples from historical data\n",
-    "#tl = examples(messages = k.source.read_parquet(\n",
-    "#    entity_column=\"channel\", \n",
-    "#    time_column=\"ts\", \n",
-    "#    files=[\"./messages.parquet\"]))\n",
-    "tl = examples(sources.ArrowSource(\"ts\", \"channel\", pandas.read_parquet(\"./messages.parquet\")))\n",
-    "\n",
-    "\n",
-    "# Limit to the examples we want to use for training\n",
-    "#tl = tl.filter(tl[\"prompt\"].is_not_null())\n",
-    "#examples_df = tl.run().to_pandas()\n",
-    "examples_df = tl.run()\n",
-    "\n",
-    "# Format for the OpenAI API\n",
-    "examples_df[\"prompt\"] = examples_df[\"prompt\"].apply(format_prompt)\n",
-    "examples_df[\"completion\"] = examples_df[\"completion\"].apply(format_completion)\n",
-    "\n",
-    "# Split training & validation\n",
-    "train, valid = train_test_split(examples_df, test_size=0.2, random_state=42)\n",
-    "train.to_json(\"train.jsonl\", orient='records', lines=True)\n",
-    "valid.to_json(\"valid.jsonl\", orient='records', lines=True)"
+    "import pandas\n",
+    "import sparrow_pi.sources as sources\n",
+    "\n",
+    "messages = kt.sources.Parquet(\"./messages.parquet\", time = \"ts\", entity = \"channel\")\n",
+    "messages = messages.with_key(kt.record({ # !!!\n",
+    "        \"channel\": messages.col(\"channel\"),\n",
+    "        \"thread\": messages.col(\"thread_ts\"),\n",
+    "    }))\n",
+    "examples = build_examples(messages)\n",
+    "\n",
+    "examples_df = examples.run().to_pandas()"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fa93a8db",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sklearn import preprocessing\n",
+    "\n",
+    "le = preprocessing.LabelEncoder()\n",
+    "le.fit(examples_df.completion.explode())\n",
+    "\n",
+    "# Format for the OpenAI API\n",
+    "def format_prompt(prompt):\n",
+    "    return \"start -> \" + \"\\n\\n\".join([f' {msg.user} --> {msg.text} ' for msg in prompt]) + \"\\n\\n###\\n\\n\"\n",
+    "examples_df.prompt = examples_df.prompt.apply(format_prompt)\n",
+    "\n",
+    "def format_completion(completion):\n",
+    "    return \" \" + (\" \".join([le.transform(u) for u in completion]) if len(completion) > 0 else \"nil\") + \" end\"\n",
+    "examples_df.completion = examples_df.completion.apply(format_completion)\n",
+    "\n",
+    "# Write examples to file\n",
+    "examples_df.to_json(\"examples.jsonl\", orient='records', lines=True)"
+   ]
+  },
   {
@@ -204,8 +162,32 @@
    },
    "outputs": [],
    "source": [
-    "%%bash\n",
-    "openai api fine_tunes.create -t \"train.jsonl\" -v \"valid.jsonl\""
+    "from types import SimpleNamespace\n",
+    "from openai import cli\n",
+    "\n",
+    "# verify data format, split for training & validation\n",
+    "args = SimpleNamespace(file='./examples.jsonl', quiet=True)\n",
+    "cli.FineTune.prepare_data(args)\n",
+    "training_id = cli.FineTune._get_or_upload('./examples_prepared_train.jsonl', True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9a60b77c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import openai\n",
+    "\n",
+    "resp = openai.FineTune.create(\n",
+    "    training_file = training_id,\n",
+    "    model = \"davinci\",\n",
+    "    n_epochs = 2,\n",
+    "    learning_rate_multiplier = 0.02,\n",
+    "    suffix = \"conversation_users\"\n",
+    ")\n",
+    "print(f'Fine-tuning model with job ID: \"{resp[\"id\"]}\"')"
    ]
   },
   {
@@ -228,8 +210,12 @@
    },
    "outputs": [],
    "source": [
+    "import json, math\n",
+    "\n",
+    "min_prob_for_response = 0.75\n",
+    "\n",
     "# Receive Slack messages in real-time\n",
-    "live_messages = k.source.read_stream(entity_column=\"channel\", time_column=\"ts\")\n",
+    "live_messages = kt.sources.read_stream(entity_column=\"channel\", time_column=\"ts\")\n",
     "\n",
     "# Receive messages from Slack\n",
     "def handle_message(client, req):\n",
@@ -238,36 +224,54 @@
     "    \n",
     "    # Deliver the message to Kaskada\n",
     "    live_messages.add_event(pyarrow.json.read_json(req.payload))\n",
-    "client.socket_mode_request_listeners.append(handle_message)\n",
-    "client.connect()\n",
+    "slack.socket_mode_request_listeners.append(handle_message)\n",
+    "slack.connect()\n",
     "\n",
     "# Handle messages in realtime\n",
-    "for p in prompt(live_messages).run(starting=datetime.now()).to_generator():\n",
+    "# A \"conversation\" is a list of messages\n",
+    "for conversation in build_conversation(live_messages).start().to_generator():\n",
+    "    if len(conversation) == 0:\n",
+    "        continue\n",
     "    \n",
     "    # Ask the model who should be notified\n",
-    "    completions = openai.Completion.create(\n",
+    "    res = openai.Completion.create(\n",
     "        model=\"ft-2zaA7qi0rxJduWQpdvOvmGn3\", \n",
-    "        prompt=format_prompt(p),\n",
-    "        max_tokens=10,\n",
+    "        prompt=format_prompt(conversation),\n",
+    "        max_tokens=1,\n",
     "        temperature=0,\n",
+    "        logprobs=5,\n",
     "    )\n",
-    "    users = json.loads(completions.choices[0].text)\n",
+    "\n",
+    "    users = []\n",
+    "    logprobs = res[\"choices\"][0][\"logprobs\"][\"top_logprobs\"][0]\n",
+    "    for user in logprobs:\n",
+    "        if math.exp(logprobs[user]) > min_prob_for_response:\n",
+    "            # if `nil` user is an option, stop processing\n",
+    "            if user == \"nil\":\n",
+    "                users = []\n",
+    "                break\n",
+    "            users.append(user)\n",
+    "\n",
+    "    # alert on most recent message in conversation\n",
+    "    msg = conversation.pop()\n",
     "    \n",
     "    # Send notification to users\n",
     "    for user in users:\n",
-    "        permalink = slack.web_client.chat_getPermalink(\n",
-    "            channel=prompt[\"_entity\"],\n",
-    "            message_ts=prompt[\"_time\"],\n",
+    "        user_id = le.inverse_transform(user)\n",
+    "\n",
+    "        link = slack.web_client.chat_getPermalink(\n",
+    "            channel=msg[\"channel\"],\n",
+    "            message_ts=msg[\"ts\"],\n",
     "        )[\"permalink\"]\n",
     "    \n",
     "        app_channel = slack.web_client.users_conversations(\n",
     "            types=\"im\",\n",
-    "            user=user,\n",
+    "            user=user_id,\n",
     "        )[\"channels\"][0][\"id\"]\n",
     "    \n",
     "        slack.web_client.chat_postMessage(\n",
     "            channel=app_channel,\n",
-    "            text=f'You put eyes on this message: <{link}|{message_text}>'\n",
+    "            text=f'You may be interested in this conversation: <{link}|{msg[\"text\"]}>'\n",
     "        )"
    ]
   },