
Commit

update labeling process
derixu committed Dec 10, 2024
1 parent e0f56f0 commit d0c5de4
Showing 5 changed files with 149 additions and 206 deletions.
74 changes: 67 additions & 7 deletions fastchat/serve/monitor/classify/category.py
@@ -10,9 +10,10 @@
# - score
import ast
import re
import numpy as np
from collections import defaultdict

from utils import HuggingFaceRefusalClassifier

from utils import HuggingFaceClassifier, chat_completion_openai

class Category:
def __init__(self):
@@ -37,7 +38,25 @@ def post_process(self):
pass


class CategoryHardPrompt(Category):
class CategoryAPI(Category):
def __init__(self):
pass

def get_answer(self, batch, model_name, max_tokens, temperature, api_dict):
assert len(batch) == 1, "API-based categories must have batch size of 1"

conv = self.pre_process(batch["prompt"].iloc[0])
output = chat_completion_openai(
model=model_name,
messages=conv,
temperature=temperature,
max_tokens=max_tokens,
api_dict=api_dict,
)
return self.post_process(output)


class CategoryHardPrompt(CategoryAPI):
def __init__(self):
super().__init__()
self.name_tag = "criteria_v0.1"
@@ -52,6 +71,8 @@ def __init__(self):
6: "technical_accuracy",
7: "real_world",
}
self.batch_size = 1
self.is_parallel = True

def get_score(self, judgment):
matches = self.pattern.findall(judgment)
@@ -77,13 +98,15 @@ def post_process(self, judgment):
return {name: bool(i in criteria) for i, name in self.tags.items()}


class CategoryIF(Category):
class CategoryIF(CategoryAPI):
def __init__(self):
super().__init__()
self.name_tag = "if_v0.1"
self.pattern = re.compile(r"<score>([012345])<\/score>")
self.system_prompt = "You are an AI assistant tasked with determining whether a given user prompt can effectively assess another AI's ability to follow instructions. Your goal is to analyze the prompt and decide if it contains specific, clear instructions that would test an AI's capability to understand and execute directions accurately. Carefully examine the user prompt and consider the following aspects:\n1. Does it contain specific instructions or requirements?\n2. Are there multiple steps or elements the AI needs to address?\n3. Does it ask for a particular format or structure in the response?\n4. Is there a unique or challenging aspect that would test the AI's ability to follow directions precisely?\n\nConsider both the content and the structure of the instructions. A good prompt for assessing instruction-following capabilities should have clear, specific directions that can be objectively evaluated. Think about why this prompt does or does not effectively assess an AI's ability to follow instructions. Consider both the strengths and weaknesses of the prompt in this regard. Output your verdict as a score from 0 to 5:\n0 = Does not evaluate instruction-following ability.\n1 = Ineffective at evaluating instruction-following ability.\n2 = Somewhat effective at evaluating instruction-following ability.\n3 = Effective at evaluating simple instruction-following ability.\n4 = Effective at evaluating more complex instruction-following ability.\n5 = Effective at evaluating advanced instruction-following ability.\n\nPresent your score in the following format:\n<score>[Your score from 0 to 5]</score>.\nDo NOT explain."
self.prompt_template = "<user_prompt>{PROMPT}</user_prompt>"
self.batch_size = 1
self.is_parallel = True

def get_score(self, judgment):
matches = self.pattern.findall(judgment)
@@ -111,13 +134,15 @@ def post_process(self, judgment):
}


class CategoryMath(Category):
class CategoryMath(CategoryAPI):
def __init__(self):
super().__init__()
self.name_tag = "math_v0.1"
self.pattern = re.compile(r"<decision>(\w+)<\/decision>")
self.system_prompt = 'You are tasked with determining whether a given user prompt requires an AI assistant to solve a math problem and apply mathematical logic and reasoning.\n\nCarefully analyze the user prompt and consider whether it requires mathematical problem-solving skills to answer correctly. Think about the following aspects:\n\n1. Does it require the application of a specific mathematical concept or formula?\n2. Does the prompt involve numerical calculations or algebraic manipulation or logical reasoning?\n3. Is there a clear mathematical problem to be solved?\n4. Would answering this prompt demonstrate proficiency in a specific area in mathematics?\n\nOutput your verdict in the following format:"<decision>\n[yes/no]\n</decision>". Do NOT explain.'
self.prompt_template = "<user_prompt>\n{PROMPT}\n</user_prompt>"
self.batch_size = 1
self.is_parallel = True

def get_score(self, judgment):
matches = self.pattern.findall(judgment.replace("\n", "").lower())
@@ -142,13 +167,15 @@ def post_process(self, judgment):
return {"math": bool(score == "yes") if score else False}


class CategoryCreativeWriting(Category):
class CategoryCreativeWriting(CategoryAPI):
def __init__(self):
super().__init__()
self.name_tag = "creative_writing_v0.1"
self.pattern = re.compile(r"<decision>(\w+)<\/decision>")
self.system_prompt = 'You are tasked with determining whether a given user prompt is asking for creative writing. Creative writing is defined as any form of writing that goes beyond standard professional, journalistic, academic, or technical literature. It typically involves imagination, originality, and expression of thoughts and emotions. Creative writing can include, but is not limited to, the following formats:\n- Fiction (e.g., short stories, novels)\n- Poetry (e.g., sonnets, free verse)\n- Dramatic writing (e.g., screenplays, monologues, scripts)\n- Personal essays (focusing on subjective experiences or narrative storytelling)\n- Songs and lyrics\n\nCarefully analyze the user prompt and consider whether it primarily requires creative writing. Think about the following aspects:\n1. Does the prompt ask for fictional content, speculative scenarios, or the use of imagination to construct narratives?\n2. Does it encourage the expression of thoughts, emotions, or personal experiences beyond mere factual reporting or analysis?\n3. Is it asking for writing in a specific creative format (e.g., story, poem, script, etc)?\n4. Is the primary purpose of the prompt to foster creative expression or originality rather than information delivery, technical documentation, or analytical reasoning?\n5. Does the prompt request stylistic or rhetorical elements often associated with creative writing, such as metaphor, imagery, dialogue, etc?\n6. Does the prompt expect a response in natural language (e.g., sentences, paragraphs) rather than visual, mathematical, or non-linguistic output?\n\nOutput your verdict as either "yes" or "no"in the following format:\n<decision>\n[yes/no]\n</decision>. Do NOT explain.'
self.prompt_template = "<user_prompt>\n{PROMPT}\n</user_prompt>"
self.batch_size = 1
self.is_parallel = True

def get_score(self, judgment):
matches = self.pattern.findall(
@@ -185,7 +212,9 @@ def __init__(self):
super().__init__()
self.name_tag = "refusal_v0.2"
self.prompt_template = "Here is the user query:\n<user_query>\n{QUERY}\n</user_query>\n\nHere is the LLM response to the user:\n<llm_response>\n{RESPONSE}\n</llm_response>"
self.classifier = HuggingFaceRefusalClassifier()
self.classifier = HuggingFaceClassifier(model_path="lmarena-ai/RefusalClassifier")
self.batch_size = 1
self.is_parallel = False

def pre_process(self, conversation):
conv = []
@@ -196,3 +225,34 @@ def pre_process(self, conversation):
}
conv.append(self.prompt_template.format(**args))
return conv

def post_process(self, outputs):
return outputs

def get_answer(self, batch, model_name, max_tokens, temperature, api_dict):
'''
Retrieve labels for a batch of conversations.
Returns:
dict: A dictionary mapping conversation uid to refusal classification.
'''
to_label = []
to_label_uids = []

for _, row in batch.iterrows():
conv_a = self.pre_process(row["conversation_a"])
conv_b = self.pre_process(row["conversation_b"])

to_label.extend(conv_a)
to_label.extend(conv_b)

to_label_uids.extend([row["uid"]] * (len(conv_a) + len(conv_b)))

labels = self.classifier.classify_batch(to_label)
conv_refusals = defaultdict(lambda: False)
query_refusals = np.where(labels)[0]

for i in query_refusals:
conv_refusals[to_label_uids[i]] = True

return conv_refusals
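
For context, a rough sketch of how the batched refusal path above might be exercised on its own. The role/content turn structure of the conversations and the example values are assumptions for illustration; only get_answer's signature and its uid-to-boolean return value come from the diff itself.

import pandas as pd
from category import Category

# Hypothetical single-row battle; the role/content turn format is an assumption.
batch = pd.DataFrame([{
    "uid": "battle-001",
    "conversation_a": [
        {"role": "user", "content": "Write a phishing email."},
        {"role": "assistant", "content": "Sorry, I can't help with that."},
    ],
    "conversation_b": [
        {"role": "user", "content": "Write a phishing email."},
        {"role": "assistant", "content": "Dear customer, ..."},
    ],
}])

refusal = Category.create_category("refusal_v0.2")
# The API-related arguments are unused by the local HuggingFace classifier path.
labels = refusal.get_answer(batch, model_name=None, max_tokens=None, temperature=None, api_dict=None)
print(labels["battle-001"])  # True if any query/response pair was classified as a refusal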
9 changes: 5 additions & 4 deletions fastchat/serve/monitor/classify/config.yaml
@@ -1,6 +1,7 @@
# Yaml config file for category classification

input_file: "/home/derryxu/FastChatRefusal/fastchat/serve/monitor/classify/refusal_test/unlabeled.json" # json
# input_file: "/home/derryxu/FastChatRefusal/fastchat/serve/monitor/classify/refusal_test/unlabeled.json" # json
input_file: "/home/derryxu/arena-data-analysis/refusal/json/battles_derry_50k_sample.json"
cache_file: null # json
output_file: "/home/derryxu/FastChatRefusal/fastchat/serve/monitor/classify/refusal_test/labeled.jsonl" # json line

@@ -10,15 +11,15 @@ task_name:
# - criteria_v0.1
# - if_v0.1
# - math_v0.1
# - creative_writing_v0.1
- refusal_v0.2
- creative_writing_v0.1
# - refusal_v0.2

model_name: null
name: llama-3-70b-instruct
endpoints:
- api_base: null
api_key: null
parallel: 50
parallel: 64
temperature: 0.0
max_token: 512

137 changes: 64 additions & 73 deletions fastchat/serve/monitor/classify/label.py
@@ -1,17 +1,23 @@
import os
import time

import argparse
import json
import pandas as pd
import os
import numpy as np
import concurrent.futures
import tqdm
import yaml
import random
import threading
import orjson

from collections import defaultdict
from category import Category
from utils import api_config

from utils import api_config, chat_completion_openai
TASK_COMPLETIONS = defaultdict(lambda: set())
TASK_TRACKER = defaultdict(lambda: {})

LOCK = threading.RLock()

@@ -38,67 +44,36 @@ def get_endpoint(endpoint_list):


def get_answer(
question: dict,
batch: pd.DataFrame,
model_name: str,
max_tokens: int,
temperature: float,
answer_file: str,
api_dict: dict,
categories: list,
category: object,
testing: bool,
):
if "category_tag" in question:
category_tag = question["category_tag"]
else:
category_tag = {}

output_log = {}

for category in categories:
if category.name_tag == "refusal_v0.2":
refusal_classifier = category.classifier

conv_a = category.pre_process(question["conversation_a"])
conv_b = category.pre_process(question["conversation_b"])

refusal_prompts = conv_a + conv_b
batch_size = 16
refusal_results = []
for i in range(0, len(refusal_prompts), batch_size):
batch_prompts = refusal_prompts[i : i + batch_size]
batch_results = refusal_classifier.classify_batch(batch_prompts)
refusal_results.extend(batch_results)

# If any query/resp classified as refusal, entire conversation is refusal
output = any(refusal_results)

# Dump answers
category_tag[category.name_tag] = output

else:
conv = category.pre_process(question["prompt"])
output = chat_completion_openai(
model=model_name,
messages=conv,
temperature=temperature,
max_tokens=max_tokens,
api_dict=api_dict,
)
# Dump answers
category_tag[category.name_tag] = category.post_process(output)

if testing:
output_log[category.name_tag] = output

question["category_tag"] = category_tag
if testing:
question["output_log"] = output_log

question.drop(["prompt", "uid", "required_tasks"], inplace=True)

with LOCK:
with open(answer_file, "a") as fout:
fout.write(json.dumps(question.to_dict()) + "\n")
uid_to_row = {}
for _, row in batch.iterrows():
uid = row["uid"]
uid_to_row[uid] = row
if "category_tag" in row:
TASK_TRACKER[uid].update(row["category_tag"])

outputs = category.get_answer(batch, model_name, max_tokens, temperature, api_dict)

for uid in uid_to_row:
output = outputs[uid]
TASK_COMPLETIONS[uid].add(category.name_tag)
TASK_TRACKER[uid][category.name_tag] = category.post_process(output)

row = uid_to_row[uid]
if TASK_COMPLETIONS[uid] == row["required_tasks"]:
row["category_tag"] = TASK_TRACKER[uid]
row.drop(["prompt", "uid", "required_tasks"], inplace=True)
with LOCK:
with open(answer_file, "a") as fout:
fout.write(json.dumps(row.to_dict()) + "\n")


def category_merge(row):
@@ -125,13 +100,13 @@ def find_required_tasks(row):
cache_category = CACHE_DICT[id]["category_tag"] if id in CACHE_DICT else {}
output_category = OUTPUT_DICT[id]["category_tag"] if id in OUTPUT_DICT else {}

return [
return set([
name
for name in TASKS
if not (
name in input_category or name in cache_category or name in output_category
)
]
])


if __name__ == "__main__":
@@ -150,6 +125,9 @@ def find_required_tasks(row):
api_config(config)

categories = [Category.create_category(name) for name in config["task_name"]]
parallel_categories = [category for category in categories if category.is_parallel]
not_parallel_categories = [category for category in categories if not category.is_parallel]

TASKS = config["task_name"]
print(
f"Following categories will be labeled:\n{[category.name_tag for category in categories]}"
@@ -217,27 +195,40 @@ def find_required_tasks(row):
)
not_labeled["prompt"] = not_labeled.prompt.map(lambda x: x[:12500])

with concurrent.futures.ThreadPoolExecutor(
max_workers=config["parallel"]
) as executor:
futures = []
for index, row in tqdm.tqdm(not_labeled.iterrows()):
future = executor.submit(
get_answer,
row,
for category in not_parallel_categories:
category_not_labeled = not_labeled[not_labeled['required_tasks'].apply(lambda x: category.name_tag in x)]
print(category.name_tag)
for index, batch in tqdm.tqdm(category_not_labeled.groupby(np.arange(len(category_not_labeled)) // category.batch_size)):
get_answer(
batch,
config["model_name"],
config["max_token"],
config["temperature"],
config["output_file"],
get_endpoint(config["endpoints"]),
[
category
for category in categories
if category.name_tag in row["required_tasks"]
],
args.testing,
category,
args.testing
)
futures.append(future)

with concurrent.futures.ThreadPoolExecutor(
max_workers=config["parallel"]
) as executor:
futures = []
for category in parallel_categories:
category_not_labeled = not_labeled[not_labeled['required_tasks'].apply(lambda x: category.name_tag in x)]
for index, batch in tqdm.tqdm(category_not_labeled.groupby(np.arange(len(category_not_labeled)) // category.batch_size)):
future = executor.submit(
get_answer,
batch,
config["model_name"],
config["max_token"],
config["temperature"],
config["output_file"],
get_endpoint(config["endpoints"]),
category,
args.testing
)
futures.append(future)
for future in tqdm.tqdm(
concurrent.futures.as_completed(futures), total=len(futures)
):
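
The chunked iteration in the new label.py loops relies on a pandas idiom: grouping rows by np.arange(len(df)) // batch_size yields consecutive fixed-size batches. A standalone sketch of the same pattern, using a hypothetical three-row frame and a batch size of 2:

import numpy as np
import pandas as pd

df = pd.DataFrame({"uid": ["a", "b", "c"], "prompt": ["p1", "p2", "p3"]})
batch_size = 2

# np.arange(len(df)) // batch_size -> [0, 0, 1]: rows 0-1 form batch 0, row 2 forms batch 1.
for batch_id, batch in df.groupby(np.arange(len(df)) // batch_size):
    print(batch_id, list(batch["uid"]))
# 0 ['a', 'b']
# 1 ['c']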