Skip to content

Commit

Permalink
reformatted code
Browse files Browse the repository at this point in the history
  • Loading branch information
derixu committed Nov 30, 2024
1 parent 1e103f3 commit e0f56f0
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 24 deletions.
9 changes: 5 additions & 4 deletions fastchat/serve/monitor/classify/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
import ast
import re

from utils import (
HuggingFaceRefusalClassifier
)
from utils import HuggingFaceRefusalClassifier


class Category:
Expand Down Expand Up @@ -192,6 +190,9 @@ def __init__(self):
def pre_process(self, conversation):
conv = []
for i in range(0, len(conversation), 2):
args = {"QUERY": conversation[i]["content"], "RESPONSE": conversation[i+1]["content"]}
args = {
"QUERY": conversation[i]["content"],
"RESPONSE": conversation[i + 1]["content"],
}
conv.append(self.prompt_template.format(**args))
return conv
10 changes: 3 additions & 7 deletions fastchat/serve/monitor/classify/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@

from category import Category

from utils import (
api_config,
chat_completion_openai
)
from utils import api_config, chat_completion_openai

LOCK = threading.RLock()

Expand Down Expand Up @@ -58,7 +55,6 @@ def get_answer(
output_log = {}

for category in categories:

if category.name_tag == "refusal_v0.2":
refusal_classifier = category.classifier

Expand All @@ -69,10 +65,10 @@ def get_answer(
batch_size = 16
refusal_results = []
for i in range(0, len(refusal_prompts), batch_size):
batch_prompts = refusal_prompts[i:i + batch_size]
batch_prompts = refusal_prompts[i : i + batch_size]
batch_results = refusal_classifier.classify_batch(batch_prompts)
refusal_results.extend(batch_results)

# If any query/resp classified as refusal, entire conversation is refusal
output = any(refusal_results)

Expand Down
26 changes: 13 additions & 13 deletions fastchat/serve/monitor/classify/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,24 +66,24 @@ def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=No
class HuggingFaceRefusalClassifier:
def __init__(self):
print("Loading model and tokenizer...")
self.tokenizer = AutoTokenizer.from_pretrained("derixu/refusal_classifier-mlm_then_classifier_v3") #TODO: Migrate to LMSYS account and change path
self.model = AutoModelForSequenceClassification.from_pretrained("derixu/refusal_classifier-mlm_then_classifier_v3")
self.model.eval()
self.tokenizer = AutoTokenizer.from_pretrained(
"derixu/refusal_classifier-mlm_then_classifier_v3"
) # TODO: Migrate to LMSYS account and change path
self.model = AutoModelForSequenceClassification.from_pretrained(
"derixu/refusal_classifier-mlm_then_classifier_v3"
)
self.model.eval()

def classify_batch(self, input_texts):
inputs = self.tokenizer(
input_texts,
truncation=True,
max_length=512,
return_tensors="pt",
padding=True
input_texts,
truncation=True,
max_length=512,
return_tensors="pt",
padding=True,
)
with torch.no_grad():
with torch.no_grad():
outputs = self.model(**inputs)
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
pred_classes = torch.argmax(probabilities, dim=-1).tolist()
return [bool(pred) for pred in pred_classes]




0 comments on commit e0f56f0

Please sign in to comment.