Update notebook per Eric's prep steps (#614)
Co-authored-by: Eric Pinzur <[email protected]>
kerinin and epinzur authored Aug 8, 2023
1 parent 6624cfa commit 50a27eb
Showing 1 changed file with 118 additions and 114 deletions.
232 changes: 118 additions & 114 deletions examples/slackbot/Notebook.ipynb
@@ -22,17 +22,9 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install openai kaskada"
"%pip install openai kaskada"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bbe4c44d",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
@@ -41,17 +33,14 @@
"outputs": [],
"source": [
"from datetime import datetime, timedelta\n",
"from sklearn.model_selection import train_test_split\n",
"from slack_sdk.socket_mode import SocketModeClient\n",
"import openai\n",
"import sparrow_pi as k\n",
"import sparrow_pi.sources as sources\n",
"from slack_sdk.socket_mode import SocketModeClient, SocketModeResponse\n",
"import sparrow_pi as kt\n",
"import openai\n",
"import getpass\n",
"import pyarrow\n",
"\n",
"# Initialize Kaskada with a local execution context.\n",
"k.init_session()\n",
"kt.init_session()\n",
"\n",
"# Initialize OpenAI\n",
"openai.api_key = getpass.getpass('OpenAI: API Key')\n",
@@ -74,125 +63,94 @@
{
"cell_type": "code",
"execution_count": null,
"id": "e95d4d50",
"metadata": {
"tags": []
},
"id": "7e6fedb9",
"metadata": {},
"outputs": [],
"source": [
"SYSTEM_CONTEXT = \"\"\"\n",
"You are a helpful assistant designed to suggest the Slack usernames of \n",
"people who need to know about a Slack conversation.\n",
"def build_conversation(messages):\n",
" message_time = messages.col(\"ts\")\n",
" last_message_time = message_time.lag(1) # !!!\n",
" is_new_conversation = message_time.seconds_since(last_message_time) > 10 * 60\n",
"\n",
"Only respond as a JSON list containing the Slack usernames of people to notify \n",
"of the conversation, or return an empty list if no should be notified.\n",
"\n",
"The Slack conversation is as follows, formatted as a JSON object:\n",
"\n",
"\"\"\" "
" return messages \\\n",
" .select(\"user\", \"ts\", \"text\", \"reactions\") \\\n",
" .collect(window=kt.windows.Since(is_new_conversation), max=100) \\\n",
" .select(\"user\", \"ts\", \"text\") \\\n",
" .collect(window=kt.windows.Since(is_new_conversation), max=100)"
]
},
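For intuition, the Since(is_new_conversation) window used in build_conversation above groups messages roughly the way the following pandas sketch does: a new conversation starts whenever a message arrives more than ten minutes after the previous one, and at most 100 messages are kept per conversation. This is an illustrative equivalent only, not the Kaskada API; it assumes the ts column has already been parsed to numeric seconds, and it omits the per-channel grouping that Kaskada's entity key provides.

import pandas as pd

df = pd.read_parquet("./messages.parquet").sort_values("ts")  # assumes a numeric `ts` column
gap = df["ts"].diff()                                         # seconds since the previous message
df["conversation_id"] = (gap > 10 * 60).cumsum()              # new conversation after a 10-minute gap
conversations = df.groupby("conversation_id").tail(100)       # keep at most the last 100 messages each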
{
"cell_type": "code",
"execution_count": null,
"id": "cf3ec0e3-c943-4974-922b-9b9c965ba0ff",
"metadata": {
"tags": []
},
"id": "fdb2d959-d371-4026-9f8d-4ab26cfbf317",
"metadata": {},
"outputs": [],
"source": [
"def prompt(messages):\n",
" # A conversation starts when a new messages is more than 5 minutes after the previous message\n",
" #last_message = messages.lag(1)\n",
" #since_last_message = messages.time().seconds_since(last_message.time())\n",
" #conversation_start = since_last_message > k.minutes(5)\n",
"\n",
" k.record({\n",
" # A list of all messages over the past 10 minutes (up to 100)\n",
" \"recent_messages\": messages\n",
" #.select(\"user\", \"type\", \"text\")\n",
" .select(False, \"user\", \"subtype\", \"text\")\n",
" #.collect(window=since(conversation_start), max=100),\n",
" .last(),\n",
"\n",
" # How many messages have been reacted to in the conversation\n",
" \"reaction_count\": messages\n",
" #.filter(messages[\"reactions\"].is_not_null())\n",
" #.count(window=since(conversation_start)),\n",
" [\"reply_count\"].sum(),\n",
" })"
"def build_examples(messages):\n",
" duration = kt.minutes(5) # !!!\n",
"\n",
" coverstation = build_conversation(messages)\n",
" shifted_coversation = coverstation.shift_by(duration) # !!!\n",
"\n",
" reaction_users = coverstation.col(\"reactions\").col(\"name\").collect(kt.windows.Trailing(duration)).flatten() # !!!\n",
" participating_users = coverstation.col(\"user\").collect(kt.windows.Trailing(duration)) # !!!\n",
" engaged_users = kt.union(reaction_users, participating_users) # !!!\n",
"\n",
" return kt.record({ \"prompt\": shifted_coversation, \"completion\": engaged_users}) \\\n",
" .filter(shifted_coversation.is_not_null())"
]
},
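Under build_examples above, the conversation is shifted forward by five minutes, so each training example pairs a conversation snapshot (the prompt) with the users who reacted or posted during the five minutes that follow it (the completion). A single emitted record therefore looks roughly like the following sketch; the user IDs, timestamp, and text are hypothetical.

example = {
    "prompt": [{"user": "U0AAA", "ts": 1691500000.0, "text": "the deploy is failing"}],
    "completion": ["U0BBB", "U0CCC"],  # users who engaged within the next five minutes
}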
{
"cell_type": "code",
"execution_count": null,
"id": "fdb2d959-d371-4026-9f8d-4ab26cfbf317",
"cell_type": "markdown",
"id": "0035f558-23bd-4b4d-95a0-ed5e8fece673",
"metadata": {},
"outputs": [],
"source": [
"def examples(messages):\n",
" # We'll train ChatGPT to generate the user ID who will engage next\n",
" k.record({\n",
" # For each example, use the previous prompt\n",
" \"prompt\": prompt(messages)\n",
" #.lag(1),\n",
" .last(),\n",
"\n",
" # ...and the current user ID\n",
" \"completion\": messages[\"user\"],\n",
" })"
"## Fine-tune the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0be6bde5-9602-43e6-89bf-e0e1470fc7c0",
"id": "af7d2a45-eb89-47ce-b471-a39ad8c7bbc7",
"metadata": {},
"outputs": [],
"source": [
"# Format the data for OpenAI\n",
"def format_prompt(prompt):\n",
" return SYSTEM_CONTEXT + json.dumps(prompt) + \"\\n\\n###\\n\\n\"\n",
"def format_completion(completion):\n",
" return completion + \"###\""
]
},
{
"cell_type": "markdown",
"id": "0035f558-23bd-4b4d-95a0-ed5e8fece673",
"metadata": {},
"source": [
"## Fine-tune the model"
"import pandas\n",
"import sparrow_pi.sources as sources\n",
"\n",
"messages = kt.sources.Parquet(\"./messages.parquet\", time = \"ts\", entity = \"channel\")\n",
"messages = messages.with_key(kt.record({ # !!!\n",
" \"channel\": messages.col(\"channel\"),\n",
" \"thread\": messages.col(\"thread_ts\"),\n",
" }))\n",
"examples = build_examples(messages)\n",
"\n",
"examples_df = examples.run().to_pandas()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af7d2a45-eb89-47ce-b471-a39ad8c7bbc7",
"id": "fa93a8db",
"metadata": {},
"outputs": [],
"source": [
"# Compute examples from historical data\n",
"#tl = examples(messages = k.source.read_parquet(\n",
"# entity_column=\"channel\", \n",
"# time_column=\"ts\", \n",
"# files=[\"./messages.parquet\"]))\n",
"tl = examples(sources.ArrowSource(\"ts\", \"channel\", pandas.read_parquet(\"./messages.parquet\")))\n",
"from sklearn import preprocessing\n",
"\n",
"\n",
"# Limit to the examples we want to use for training\n",
"#tl = tl.filter(tl[\"prompt\"].is_not_null())\n",
"#examples_df = tl.run().to_pandas()\n",
"examples_df = tl.run()\n",
"le = preprocessing.LabelEncoder()\n",
"le.fit(examples_df.completion.explode())\n",
"\n",
"# Format for the OpenAI API\n",
"examples_df[\"prompt\"] = examples_df[\"prompt\"].apply(format_prompt)\n",
"examples_df[\"completion\"] = examples_df[\"completion\"].apply(format_completion)\n",
"def format_prompt(prompt):\n",
" return \"start -> \" + \"\\n\\n\".join([f' {msg.user} --> {msg.text} ' for msg in prompt]) + \"\\n\\n###\\n\\n\"\n",
"examples_df.prompt = examples_df.prompt.apply(format_prompt)\n",
"\n",
"# Split training & validation\n",
"train, valid = train_test_split(examples_df, test_size=0.2, random_state=42)\n",
"train.to_json(\"train.jsonl\", orient='records', lines=True)\n",
"valid.to_json(\"valid.jsonl\", orient='records', lines=True)"
"def format_completion(completion):\n",
" return \" \" + (\" \".join([le.transform(u) for u in completion]) if len(completion) > 0 else \"nil\") + \" end\"\n",
"examples_df.completion = examples_df.completion.apply(format_completion)\n",
"\n",
"# Write examples to file\n",
"examples_df.to_json(\"examples.jsonl\", orient='records', lines=True)"
]
},
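The LabelEncoder step above maps Slack user IDs to small integer labels so that each completion stays a short token for the model. A minimal round trip, with hypothetical user IDs, looks like this:

from sklearn import preprocessing

le = preprocessing.LabelEncoder()
le.fit(["U0AAA", "U0BBB", "U0CCC"])
le.transform(["U0BBB"])      # -> array([1])
le.inverse_transform([1])    # -> array(['U0BBB'], ...)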
{
@@ -204,8 +162,32 @@
},
"outputs": [],
"source": [
"%%bash\n",
"openai api fine_tunes.create -t \"train.jsonl\" -v \"valid.jsonl\""
"from types import SimpleNamespace\n",
"from openai import cli\n",
"\n",
"# verifiy data format, split for training & validation\n",
"args = SimpleNamespace(file='./examples.jsonl', quiet=True)\n",
"cli.FineTune.prepare_data(args)\n",
"training_id = cli.FineTune._get_or_upload('./examples_prepared_train.jsonl', True)"
]
},
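For reference, cli.FineTune.prepare_data above wraps the data-preparation tool that the pre-1.0 OpenAI package also exposes on the command line, so the same step could be run from a notebook shell cell roughly as:

%%bash
openai tools fine_tunes.prepare_data -f examples.jsonl

The private helper _get_or_upload then uploads the prepared training split and returns the file ID used in the fine-tune request.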
{
"cell_type": "code",
"execution_count": null,
"id": "9a60b77c",
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"\n",
"resp = openai.FineTune.create(\n",
" training_file = training_id,\n",
" model = \"davinci\",\n",
" n_epochs = 2,\n",
" learning_rate_multiplier = 0.02,\n",
" suffix = \"coversation_users\"\n",
")\n",
"print(f'Fine-tuning model with job ID: \"{resp[\"id\"]}\"')"
]
},
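Once submitted, the job's progress can be polled with the same pre-1.0 openai client; a small sketch reusing resp from the cell above:

status = openai.FineTune.retrieve(id=resp["id"])["status"]
print(status)  # e.g. "pending", "running", or "succeeded"
# events = openai.FineTune.list_events(id=resp["id"])  # optional: inspect training events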
{
@@ -228,8 +210,12 @@
},
"outputs": [],
"source": [
"import json, math\n",
"\n",
"min_prob_for_response = 0.75\n",
"\n",
"# Receive Slack messages in real-time\n",
"live_messages = k.source.read_stream(entity_column=\"channel\", time_column=\"ts\")\n",
"live_messages = kt.sources.read_stream(entity_column=\"channel\", time_column=\"ts\")\n",
"\n",
"# Receive messages from Slack\n",
"def handle_message(client, req):\n",
@@ -238,36 +224,54 @@
" \n",
" # Deliver the message to Kaskada\n",
" live_messages.add_event(pyarrow.json.read_json(req.payload))\n",
"client.socket_mode_request_listeners.append(handle_message)\n",
"client.connect()\n",
"slack.socket_mode_request_listeners.append(handle_message)\n",
"slack.connect()\n",
"\n",
"# Handle messages in realtime\n",
"for p in prompt(live_messages).run(starting=datetime.now()).to_generator():\n",
"# A \"conversation\" is a list of messages\n",
"for conversation in build_conversation(live_messages).start().to_generator():\n",
" if len(conversation) == 0:\n",
" continue\n",
" \n",
" # Ask the model who should be notified\n",
" completions = openai.Completion.create(\n",
" res = openai.Completion.create(\n",
" model=\"ft-2zaA7qi0rxJduWQpdvOvmGn3\", \n",
" prompt=format_prompt(p),\n",
" max_tokens=10,\n",
" prompt=format_prompt(conversation),\n",
" max_tokens=1,\n",
" temperature=0,\n",
" logprobs=5,\n",
" )\n",
" users = json.loads(completions.choices[0].text)\n",
"\n",
" users = []\n",
" logprobs = res[\"choices\"][0][\"logprobs\"][\"top_logprobs\"][0]\n",
" for user in logprobs:\n",
" if math.exp(logprobs[user]) > min_prob_for_response:\n",
" # if `nil` user is an option, stop processing\n",
" if user == \"nil\":\n",
" users = []\n",
" break\n",
" users.append(user)\n",
"\n",
" # alert on most recent message in conversation\n",
" msg = conversation.pop()\n",
" \n",
" # Send notification to users\n",
" for user in users:\n",
" permalink = slack.web_client.chat_getPermalink(\n",
" channel=prompt[\"_entity\"],\n",
" message_ts=prompt[\"_time\"],\n",
" user_id = le.inverse_transform(user)\n",
"\n",
" link = slack.web_client.chat_getPermalink(\n",
" channel=msg[\"channel\"],\n",
" message_ts=msg[\"ts\"],\n",
" )[\"permalink\"]\n",
" \n",
" app_channel = slack.web_client.users_conversations(\n",
" types=\"im\",\n",
" user=user,\n",
" user=user_id,\n",
" )[\"channels\"][0][\"id\"]\n",
" \n",
" slack.web_client.chat_postMessage(\n",
" channel=app_channel,\n",
" text=f'You put eyes on this message: <{link}|{message_text}>'\n",
" text=f'You may be interested in this converstation: <{link}|{msg[\"text\"]}>'\n",
" )"
]
},