Commit 5d14906

Merge pull request #7 from asusevski/dev

merge dev changes

asusevski authored Dec 9, 2024
2 parents d3b79db + 3cc5225 commit 5d14906

Showing 5 changed files with 276 additions and 98 deletions.
7 changes: 4 additions & 3 deletions pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "topic-autolabel"
-version = "0.1.2"
+version = "0.1.3"
 description = "Automatic topic labeling using LLMs"
 authors = [
     {name = "Anthony Susevski", email = "[email protected]"}
@@ -22,7 +22,8 @@ dependencies = [
 
     "instructor",
     "torch",
-    "transformers"
+    "transformers",
+    "sentence-transformers"
 ]
 urls = {Repository = "https://github.com/asusevski/topic-autolabel"}
 
@@ -54,4 +55,4 @@ exclude = [".venv*", "**/__pycache__", "*.ipynb"]
 [tool.pyright]
 include = ["src/*"]
 exclude = [".venv*", "**/__pycache__", "**/__init__.py", "data/*"]
-typeCheckingMode = "standard"
\ No newline at end of file
+typeCheckingMode = "standard"
14 changes: 10 additions & 4 deletions src/topic_autolabel/__init__.py

@@ -7,12 +7,13 @@
 
 
 def process_file(
-    filepath: str,
+    filepath: Optional[str],
     text_column: str,
-    model_name: str = "meta-llama/Llama-3.1-8B-Instruct",
-    num_labels: int = 5,
+    df: Optional[pd.DataFrame] = None,
+    model_name: str = "meta-llama/Llama-3.1-8B-Instruct",
+    num_labels: Optional[int] = 5,
     candidate_labels: Optional[List[str]] = None,
+    batch_size: Optional[int] = 8,
 ) -> pd.DataFrame:
     """
     Process a file and add topic labels to it.
@@ -27,12 +28,17 @@ def process_file(
     Returns:
         DataFrame with a new 'label' column containing the generated labels
     """
+    try:
+        assert filepath is not None or df is not None
+    except AssertionError:
+        raise ValueError("One of filepath or df must be passed to the function.")
+
     # Load the data
     if df is None:
        df = load_data(filepath, text_column)
 
     # Initialize the labeler
-    labeler = TopicLabeler(model_name=model_name)
+    labeler = TopicLabeler(model_name=model_name, batch_size=batch_size)
 
     # Generate labels
     labels = labeler.generate_labels(
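For reference, here is a minimal usage sketch of the updated `process_file` signature; the file path, column name, and example labels are hypothetical placeholders, not part of this commit:

import pandas as pd
from topic_autolabel import process_file

# Label texts straight from a file on disk (hypothetical path/column)
labeled = process_file(
    filepath="reviews.csv",
    text_column="review_text",
    num_labels=5,   # open-ended labeling: distill five topics
    batch_size=8,   # new in 0.1.3: batched generation
)

# New in this PR: pass an in-memory DataFrame instead of a filepath
df = pd.DataFrame({"review_text": ["great phone, fast shipping", "battery died in a week"]})
labeled = process_file(
    filepath=None,
    text_column="review_text",
    df=df,
    candidate_labels=["electronics", "shipping", "battery"],
)
print(labeled["label"])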
289 changes: 202 additions & 87 deletions src/topic_autolabel/core/labeler.py

@@ -3,48 +3,163 @@
 from typing import List, Optional, Union
 
 import torch
+from sentence_transformers import SentenceTransformer
+from torch.utils.data import DataLoader, Dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
+class TextDataset(Dataset):
+    def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
+        self.texts = texts
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+
+    def __len__(self):
+        return len(self.texts)
+
+    def __getitem__(self, idx):
+        return self.texts[idx]
+
+
 class TopicLabeler:
     def __init__(
         self,
         model_name: str = "meta-llama/Llama-3.1-8B-Instruct",
         device: str = "cuda" if torch.cuda.is_available() else "cpu",
+        batch_size: int = 8,
     ):
         """
         Initialize the topic labeler with a specified LLM.
         """
         self.device = device
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name,
             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
         ).to(device)
+        self.batch_size = batch_size
+        self.similarity_model = SentenceTransformer(
+            "sentence-transformers/all-MiniLM-L6-v2"
+        )
+        if torch.cuda.is_available():
+            self.similarity_model.to(device)
 
-    def _generate_prompt(
-        self,
-        text: str,
-        candidate_labels: Optional[List[str]] = None,
-    ) -> str:
-        """
-        Generate appropriate prompt based on whether we're doing open-ended
-        labeling or classification with candidate labels.
-        """
-        if candidate_labels:
-            prompt = f"""Given the following text, classify it into one of these categories: {', '.join(candidate_labels)}
-Text: {text}
-
-The category that best describes this text is:"""
-        else:
-            prompt = f""""Use three words total (comma separated)\
-to describe general topics in above texts. Under no circumstances use enumeration. \
-Example format: Tree, Cat, Fireman
-
-Text: {text}
-Three comma separated words:"""
-        return prompt
+    def _create_prompt(
+        self, text: str, candidate_labels: Optional[List[str]] = None
+    ) -> str:
+        """Generate appropriate prompt based on labeling mode."""
+        if candidate_labels:
+            return f"Given the following text, classify it into one of these categories: {', '.join(candidate_labels)}\n\nText: {text}\n\nThe category that best describes this text is:"
+        return f"Use three words total (comma separated) to describe general topics in above texts. Under no circumstances use enumeration. Example format: Tree, Cat, Fireman\n\nText: {text}\nThree comma separated words:"
+
+    @torch.no_grad()
+    def _batch_generate(
+        self,
+        prompts: List[str],
+        max_new_tokens: int,
+    ) -> List[str]:
+        """Generate responses for a batch of prompts."""
+        # Tokenize all prompts at once
+        inputs = self.tokenizer(
+            prompts,
+            padding=True,
+            truncation=True,
+            return_tensors="pt",
+        ).to(self.device)
+        outputs = self.model.generate(
+            input_ids=inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
+            max_new_tokens=max_new_tokens,
+            temperature=0.3,
+            num_return_sequences=1,
+            pad_token_id=self.tokenizer.pad_token_id,
+        )
+        # Extract generated text for each sequence
+        responses = []
+        for i, output in enumerate(outputs):
+            prompt_length = inputs["attention_mask"][i].sum()
+            response = self.tokenizer.decode(
+                output[prompt_length:], skip_special_tokens=True
+            )
+            responses.append(response.lower().strip())
+        return responses
+
+    def _filter_labels_semantic(
+        self, label_counts: Counter, num_labels: int, similarity_threshold: float = 0.50
+    ):
+        """
+        Filter labels semantically using SentenceTransformers by removing similar labels
+        and keeping the most frequent ones.
+        Args:
+            label_counts: Counter object containing labels and their counts
+            num_labels: Number of labels to return
+            similarity_threshold: Threshold for cosine similarity (default: 0.50)
+        Returns:
+            List of filtered labels
+        """
+        # Sort labels by frequency
+        sorted_labels = sorted(label_counts.items(), key=lambda x: x[1], reverse=True)
+        labels = [label for label, _ in sorted_labels]
+        # Get embeddings for all labels
+        embeddings = self.similarity_model.encode(
+            labels, convert_to_tensor=True, device=self.device
+        )
+        # Calculate cosine similarity matrix
+        similarity_matrix = torch.nn.functional.cosine_similarity(
+            embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2
+        )
+        # Filter out similar labels
+        filtered_indices = []
+        for i in range(len(labels)):
+            # Skip if this label is too similar to any already selected label
+            too_similar = False
+            for j in filtered_indices:
+                if i != j and similarity_matrix[i, j] > similarity_threshold:
+                    too_similar = True
+                    break
+            if not too_similar:
+                filtered_indices.append(i)
+                # Break if we have enough labels
+                if len(filtered_indices) == num_labels:
+                    break
+        # If we don't have enough labels after filtering, add the next most frequent ones
+        if len(filtered_indices) < num_labels:
+            remaining_indices = [
+                i for i in range(len(labels)) if i not in filtered_indices
+            ]
+            filtered_indices.extend(
+                remaining_indices[: num_labels - len(filtered_indices)]
+            )
+        return [labels[i] for i in filtered_indices[:num_labels]]
+
+    def _process_open_ended_responses(
+        self, responses: List[str], num_labels: int
+    ) -> List[str]:
+        """Process responses for open-ended labeling."""
+        pattern = r"^\w+,\s*\w+,\s*\w+"
+        word_lists = []
+        for response in responses:
+            words = re.findall(pattern, response)
+            if words:
+                word_lists.append(words[0].split(", "))
+            else:
+                word_lists.append([])
+        # Get most common terms
+        counts = Counter(word for sublist in word_lists for word in sublist)
+        if len(counts) < num_labels:
+            raise ValueError(
+                f"Could not generate {num_labels} unique labels from the texts"
+            )
+        labels = self._filter_labels_semantic(counts, num_labels)
+        return labels
 
     def generate_labels(
         self,
@@ -53,82 +168,82 @@ def generate_labels(
         candidate_labels: Optional[List[str]] = None,
     ) -> List[str]:
         """
-        Generate labels for the given texts.
+        Generate labels for the given texts in batches.
         Args:
             texts: Single text or list of texts to label
             num_labels: Number of labels to generate for open-ended labeling
             candidate_labels: Optional list of predefined labels
         Returns:
             List of generated labels
         """
         if isinstance(texts, str):
             texts = [texts]
 
-        if candidate_labels:
-            max_tokens = max(
-                [len(self.tokenizer(x)["input_ids"]) for x in candidate_labels]
-            )
-            pattern = r""
-        else:
-            max_tokens = 50
-            pattern = r"^\w+,\s*\w+,\s*\w+"
-
-        labels = []
-        for text in texts:
-            prompt = self._generate_prompt(text, candidate_labels)
-            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
-            with torch.no_grad():
-                outputs = self.model.generate(
-                    **inputs,
-                    max_new_tokens=max_tokens,
-                    temperature=0.7,
-                    num_return_sequences=1,
-                    pad_token_id=self.tokenizer.eos_token_id,
-                )
-            prompt_length = inputs["input_ids"].shape[1]
-            response = self.tokenizer.decode(outputs[0][prompt_length:])
-            response = response.lower().strip()
-            if candidate_labels:
-                if response not in candidate_labels:
-                    response = "<err>"
-                labels.append(response)
-            else:
-                words = re.findall(pattern, response)
-                if words:
-                    words = words[0].split(", ")
-                    labels.append(words)
-                else:
-                    labels.append([])
-
-        if candidate_labels:
-            return labels
-        else:
-            ## Re-label with most common terms
-            counts = Counter(word for sublist in labels for word in sublist)
-            try:
-                assert (num_labels) <= len(counts)
-            except AssertionError:
-                raise Exception
-            top_labels = [x[0] for x in counts.most_common(num_labels)]
-            max_tokens = max([len(self.tokenizer(x)["input_ids"]) for x in top_labels])
-            # TODO: filter top labels w/ embeddings and maybe remove generic labels via clustering?
-            # re-label with top labels
-            final_labels = []
-            for text in texts:
-                prompt = self._generate_prompt(text, top_labels)
-                inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
-                with torch.no_grad():
-                    outputs = self.model.generate(
-                        **inputs,
-                        max_new_tokens=max_tokens,
-                        temperature=0.7,
-                        num_return_sequences=1,
-                        pad_token_id=self.tokenizer.eos_token_id,
-                    )
-                prompt_length = inputs["input_ids"].shape[1]
-                response = self.tokenizer.decode(outputs[0][prompt_length:])
-                response = response.lower().strip()
-                found = False
-                for label in top_labels:
-                    if label in response:
-                        final_labels.append(label)
-                        found = True
-                        break
-                if not found:
-                    final_labels.append("<err>")
-            return final_labels
+        # Create dataset and dataloader for batch processing
+        dataset = TextDataset(texts, self.tokenizer)
+        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
+        # Calculate max tokens based on labeling mode
+        # NOTE: +2 is janky, but sometimes model outputs something like "a) answer" so meh
+        # for that matter 25 is also janky -- solely for unsupervised labels
+        max_tokens = (
+            max(
+                len(self.tokenizer(x)["input_ids"]) + 2
+                for x in (candidate_labels or [])
+            )
+            if candidate_labels
+            else 25
+        )
+        all_responses = []
+        # Process texts in batches
+        for batch_texts in dataloader:
+            prompts = [
+                self._create_prompt(text, candidate_labels) for text in batch_texts
+            ]
+            responses = self._batch_generate(prompts, max_tokens)
+            all_responses.extend(responses)
+        if not candidate_labels:
+            # Handle open-ended labeling
+            top_labels = self._process_open_ended_responses(all_responses, num_labels)
+            # Re-label texts with top labels
+            final_labels = []
+            for batch_texts in dataloader:
+                prompts = [
+                    self._create_prompt(text, top_labels) for text in batch_texts
+                ]
+                max_tokens = max(
+                    len(self.tokenizer(x)["input_ids"]) + 2 for x in (top_labels or [])
+                )
+                batch_responses = self._batch_generate(prompts, max_tokens)
+                for response in batch_responses:
+                    label_found = False
+                    for label in top_labels:
+                        if label in response:
+                            final_labels.append(label)
+                            label_found = True
+                            break
+                    if not label_found:
+                        final_labels.append("<err>")
+            return [
+                response if response in top_labels else "<err>"
+                for response in final_labels
+            ]
+
+        # Handle classification with candidate labels
+        all_responses = [response.strip().strip(".") for response in all_responses]
+        final_labels = []
+        for response in all_responses:
+            label_found = False
+            for candidate_label in candidate_labels:
+                if candidate_label in response:
+                    final_labels.append(candidate_label)
+                    label_found = True
+                    break
+            if not label_found:
+                final_labels.append("<err>")
+        return [
+            response if response in candidate_labels else "<err>"
+            for response in final_labels
+        ]
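As a side note, the new semantic deduplication in _filter_labels_semantic can be sketched in isolation. This is a minimal illustration of the greedy frequency-then-similarity filter with made-up labels and counts; the 0.50 threshold matches the default above, and the exact labels kept depend on the embedding model:

from collections import Counter

from sentence_transformers import SentenceTransformer, util

# Hypothetical label frequencies, as produced by the open-ended pass
counts = Counter({"food": 10, "cuisine": 9, "travel": 7, "flights": 6, "music": 5})

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
labels = [label for label, _ in counts.most_common()]
embeddings = model.encode(labels, convert_to_tensor=True)

kept = []
for i in range(len(labels)):
    # Greedy pass: keep a label only if its cosine similarity to every
    # already-kept label is at or below the 0.50 threshold
    if all(util.cos_sim(embeddings[i], embeddings[j]).item() <= 0.50 for j in kept):
        kept.append(i)
    if len(kept) == 3:
        break

print([labels[i] for i in kept])  # e.g. ["food", "travel", "music"] if "cuisine"/"flights" are filtered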
