Skip to content

Commit

Permalink
update topic clustering
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Sep 19, 2023
1 parent e4758da commit 3bbc68d
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 15 deletions.
33 changes: 31 additions & 2 deletions fastchat/llm_judge/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from fastchat.model.model_adapter import get_conversation_template

# API setting constants
API_MAX_RETRY = 16
API_RETRY_SLEEP = 10
API_MAX_RETRY = 2
API_RETRY_SLEEP = 20
API_ERROR_OUTPUT = "$ERROR$"

TIE_DELTA = 0.1
Expand Down Expand Up @@ -418,6 +418,35 @@ def chat_compeletion_openai(model, conv, temperature, max_tokens):
return output


def chat_compeletion_openai_azure(model, conv, temperature, max_tokens):
openai.api_type = "azure"
openai.api_base = os.environ["AZURE_OPENAI_ENDPOINT"]
openai.api_key = os.environ["AZURE_OPENAI_KEY"]
openai.api_version = "2023-05-15"

if "azure-" in model:
model = model[6:]

output = API_ERROR_OUTPUT
for _ in range(API_MAX_RETRY):
try:
messages = conv.to_openai_api_messages()
response = openai.ChatCompletion.create(
engine=model,
messages=messages,
n=1,
temperature=temperature,
max_tokens=max_tokens,
)
output = response["choices"][0]["message"]["content"]
break
except openai.error.OpenAIError as e:
print(type(e), e)
time.sleep(API_RETRY_SLEEP)

return output


def chat_compeletion_anthropic(model, conv, temperature, max_tokens):
output = API_ERROR_OUTPUT
for _ in range(API_MAX_RETRY):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""
Count the unique users in a battle log file.
Usage:
python3 -input in.json --number 1000
"""

import argparse
import json
import random

K = 1000

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str)
parser.add_argument("--number", type=int, nargs="+")
args = parser.parse_args()

convs = json.load(open(args.input))
random.seed(42)
random.shuffle(convs)

for number in args.number:
new_convs = convs[:number]

output = args.input.replace(".json", f"_{number//K}k.json")
with open(output, "w") as fout:
json.dump(new_convs, fout, indent=2, ensure_ascii=False)

print(f"#in: {len(convs)}, #out: {len(new_convs)}")
print(f"Write to file: {output}")
15 changes: 11 additions & 4 deletions fastchat/serve/monitor/summarize_cluster.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""
Usage:
python3 summarize_cluster.py --in results_c20_kmeans_cluster.pkl --model gpt-4
python3 summarize_cluster.py --in results_c20_kmeans_cluster.pkl --model gpt-4 --num-prompts 100
python3 summarize_cluster.py --in results_c20_kmeans_cluster.pkl --model azure-gpt-4-32k --num-prompts 200
"""
import argparse
import pickle

from fastchat.llm_judge.common import (
chat_compeletion_openai,
chat_compeletion_openai_azure,
chat_compeletion_anthropic,
)
from fastchat.conversation import get_conv_template
Expand All @@ -32,18 +34,23 @@ def truncate_string(s, l):
topics = []
percentages = []
for i, info in enumerate(cluster_infos):
num_samples, prompts = info
num_samples, topk_prompts, random_prompts = info
percentage = num_samples / num_total_prompts
print(
f"cluster {i}, #prompts {num_samples}, percentage: {percentage * 100:.2f}%"
)
instruct = "Given a list of user messages, use less than 8 words to summarize a central topic for all messages in English. Your output should only include a single line. Try to be specific."
split = int(args.num_prompts * 0.8)
prompt = "\n".join(
[truncate_string(x, l=200) for x in prompts[: args.num_prompts]]
[truncate_string(x, l=200) for x in topk_prompts[: split]] +
[truncate_string(x, l=200) for x in random_prompts[: args.num_prompts - split]]
)
prompt = "BEGIN OF THE MESSAGE LIST\n" + prompt + "\nEND OF THE MESSAGE LIST."

if "gpt" in model:
if "azure-" in model:
template_name = "chatgpt"
completion_func = chat_compeletion_openai_azure
elif "gpt" in model:
template_name = "chatgpt"
completion_func = chat_compeletion_openai
elif "claude" in model:
Expand Down
22 changes: 13 additions & 9 deletions fastchat/serve/monitor/topic_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Usage:
python3 topic_clustering.py --in arena.json --english-only --min-length 32
python3 topic_clustering.py --in clean_conv_20230809_100k.json --english-only --min-length 32 --max-length 1024
python3 topic_clustering.py --in clean_conv_20230809_100k.json --english-only --min-length 32 --max-length 1536
"""
import argparse
import json
Expand Down Expand Up @@ -90,7 +90,7 @@ def get_embeddings(texts, model_name, batch_size):


def run_k_means(embeddings, num_clusters):
np.random.seed(0)
np.random.seed(42)
clustering_model = KMeans(n_clusters=num_clusters, n_init="auto")
clustering_model.fit(embeddings.numpy())
centers = torch.from_numpy(clustering_model.cluster_centers_)
Expand All @@ -109,7 +109,7 @@ def run_k_means(embeddings, num_clusters):


def run_agg_cluster(embeddings, num_clusters):
np.random.seed(0)
np.random.seed(42)
clustering_model = AgglomerativeClustering(n_clusters=num_clusters)
clustering_model.fit(embeddings)
labels = torch.from_numpy(clustering_model.labels_)
Expand All @@ -133,7 +133,7 @@ def run_agg_cluster(embeddings, num_clusters):
def run_hdbscan_cluster(embeddings):
import hdbscan

np.random.seed(0)
np.random.seed(42)
clusterer = hdbscan.HDBSCAN(min_cluster_size=10)
labels = torch.from_numpy(clusterer.fit_predict(embeddings))

Expand Down Expand Up @@ -183,13 +183,18 @@ def print_topk(texts, labels, topk_indices, show_cut_off):


def get_cluster_info(texts, labels, topk_indices):
np.random.seed(42)

cluster_info = []
for k in range(len(topk_indices)):
num_samples = torch.sum(labels == k).item()
prompts = []
topk_prompts = []
for idx in topk_indices[k]:
prompts.append(texts[idx])
cluster_info.append((num_samples, prompts))
topk_prompts.append(texts[idx])
random_prompts = []
for idx in range(len(topk_indices)):
random_prompts.append(np.random.choice(texts))
cluster_info.append((num_samples, topk_prompts, random_prompts))

return cluster_info

Expand Down Expand Up @@ -238,8 +243,6 @@ def get_cluster_info(texts, labels, topk_indices):
topk_str = print_topk(texts, labels, topk_indices, args.show_cut_off)
num_clusters = len(centers)

cluster_info = get_cluster_info(texts, labels, topk_indices)

# Dump results
filename_prefix = f"results_c{num_clusters}_{args.cluster_alg}"
print(topk_str)
Expand All @@ -259,5 +262,6 @@ def get_cluster_info(texts, labels, topk_indices):
obj = {"cluster": i, "text": text, "sim": score.item()}
fout.write(json.dumps(obj, ensure_ascii=False) + "\n")

cluster_info = get_cluster_info(texts, labels, topk_indices)
with open(filename_prefix + "_cluster.pkl", "wb") as fout:
pickle.dump(cluster_info, fout)

0 comments on commit 3bbc68d

Please sign in to comment.