Code and Demo for Redirection #250

Merged 32 commits on Dec 30, 2024
Commits (32)
90bce18
Fix broken hyperconvo demo link
vianxnguyen Jul 27, 2024
70cb3a5
Merge branch 'master' of github.com:vianxnguyen/ConvoKit
vianxnguyen Nov 13, 2024
b936e6f
Initial commit
vianxnguyen Nov 19, 2024
74267b2
Merge branch 'CornellNLP:master' into master
vianxnguyen Nov 19, 2024
015fa4e
Update setup
vianxnguyen Nov 19, 2024
dfa4ffe
Fix imports
vianxnguyen Nov 19, 2024
1b2fe76
Minor import fix
vianxnguyen Nov 19, 2024
937169c
Update utterances
vianxnguyen Nov 20, 2024
04120d1
Update notebook
vianxnguyen Nov 20, 2024
65956f7
Fix typo
vianxnguyen Nov 20, 2024
39b1e3f
Fix typo
vianxnguyen Nov 20, 2024
16d5df2
Change context length
vianxnguyen Nov 20, 2024
a61c945
More fixes
vianxnguyen Nov 20, 2024
eb39946
Update notebook
vianxnguyen Nov 20, 2024
3e17432
Update
vianxnguyen Nov 20, 2024
02a2d3f
Changes
vianxnguyen Dec 2, 2024
f9dfee1
Fix
vianxnguyen Dec 2, 2024
5fd450f
Merge remote-tracking branch 'upstream/master'
vianxnguyen Dec 2, 2024
a99bee7
update dependencies
vianxnguyen Dec 23, 2024
ca6ce0a
format
vianxnguyen Dec 23, 2024
1b6f632
docs and dependencies
vianxnguyen Dec 27, 2024
a19250c
update
vianxnguyen Dec 27, 2024
dbcfdd0
link
vianxnguyen Dec 27, 2024
d6ba19f
install
vianxnguyen Dec 27, 2024
bc48478
.
vianxnguyen Dec 27, 2024
91612bc
.
vianxnguyen Dec 27, 2024
f9c7a21
.
vianxnguyen Dec 29, 2024
f7774be
format
vianxnguyen Dec 29, 2024
9ad7dba
.
vianxnguyen Dec 29, 2024
59b7dca
reorg
vianxnguyen Dec 30, 2024
a100266
Merge branch 'CornellNLP:master' into master
vianxnguyen Dec 30, 2024
bdb6400
update version number
seanzhangkx8 Dec 30, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -10,7 +10,7 @@
[![Discord Community](https://img.shields.io/static/v1?logo=discord&style=flat&color=red&label=discord&message=community)](https://discord.gg/WMFqMWgz6P)


- This toolkit contains tools to extract conversational features and analyze social phenomena in conversations, using a [single unified interface](https://convokit.cornell.edu/documentation/architecture.html) inspired by (and compatible with) scikit-learn. Several large [conversational datasets](https://github.com/CornellNLP/ConvoKit#datasets) are included together with scripts exemplifying the use of the toolkit on these datasets. The latest version is [3.0.2](https://github.com/CornellNLP/ConvoKit/releases/tag/v3.0.2) (released December 27, 2024); follow the [project on GitHub](https://github.com/CornellNLP/ConvoKit) to keep track of updates.
+ This toolkit contains tools to extract conversational features and analyze social phenomena in conversations, using a [single unified interface](https://convokit.cornell.edu/documentation/architecture.html) inspired by (and compatible with) scikit-learn. Several large [conversational datasets](https://github.com/CornellNLP/ConvoKit#datasets) are included together with scripts exemplifying the use of the toolkit on these datasets. The latest version is [3.1.0](https://github.com/CornellNLP/ConvoKit/releases/tag/v3.1.0) (released December 30, 2024); follow the [project on GitHub](https://github.com/CornellNLP/ConvoKit) to keep track of updates.

Join our [Discord community](https://discord.gg/WMFqMWgz6P) to stay informed, connect with fellow developers, and be part of an engaging space where we share progress, discuss features, and tackle issues together.

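To make the "single unified interface" mentioned above concrete, here is a minimal usage sketch (not part of this diff; it uses an existing ConvoKit corpus and transformer purely for illustration, and assumes the spaCy English model is installed):

from convokit import Corpus, TextParser, download

# Download and load a small ConvoKit corpus.
corpus = Corpus(filename=download("subreddit-Cornell"))

# Transformers follow the scikit-learn-style fit/transform pattern.
parser = TextParser()
corpus = parser.transform(corpus)  # annotates each utterance with a dependency parse
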
2 changes: 2 additions & 0 deletions convokit/redirection/__init__.py
@@ -0,0 +1,2 @@
from .likelihoodModel import *
from .redirection import *
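
A small usage note (not part of this diff): the wildcard imports re-export the public names of these two modules at the package level, so for example the abstract base class is importable directly:

from convokit.redirection import LikelihoodModel  # re-exported via the wildcard import
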
33 changes: 33 additions & 0 deletions convokit/redirection/config.py
@@ -0,0 +1,33 @@
from peft import LoraConfig
from transformers import BitsAndBytesConfig
import torch

DEFAULT_BNB_CONFIG = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

DEFAULT_LORA_CONFIG = LoraConfig(
    r=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

DEFAULT_TRAIN_CONFIG = {
    "output_dir": "checkpoints",
    "logging_dir": "logging",
    "logging_steps": 25,
    "eval_steps": 50,
    "num_train_epochs": 2,
    "per_device_train_batch_size": 1,
    "per_device_eval_batch_size": 1,
    "evaluation_strategy": "steps",
    "save_strategy": "steps",
    "save_steps": 50,
    "optim": "paged_adamw_8bit",
    "learning_rate": 2e-4,
    "max_seq_length": 512,
    "load_best_model_at_end": True,
}
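
These defaults are meant to be passed into (or overridden for) the likelihood model defined later in this PR. A minimal sketch of overriding one setting (not part of this diff; the token is a placeholder, and constructing the model eagerly downloads the Gemma weights):

from convokit.redirection.config import DEFAULT_TRAIN_CONFIG
from convokit.redirection.gemmaLikelihoodModel import GemmaLikelihoodModel

# Copy the module-level dict before changing it, so the shared default stays intact.
train_config = dict(DEFAULT_TRAIN_CONFIG)
train_config["num_train_epochs"] = 1  # e.g. a shorter fine-tuning run

likelihood_model = GemmaLikelihoodModel(
    hf_token="hf_...",  # placeholder Huggingface token
    train_config=train_config,
)
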
100 changes: 100 additions & 0 deletions convokit/redirection/contextSelector.py
@@ -0,0 +1,100 @@
from .preprocessing import default_speaker_prefixes


def default_previous_context_selector(convo):
    """
    Default function to compute previous contexts for Redirection. For
    actual contexts, uses the current utterance and the immediately
    preceding utterance from the speaker with the other role. For
    reference contexts, uses the previous utterance from the same-role
    speaker in place of the current utterance, as a point of reference.

    :param convo: ConvoKit Conversation object to compute contexts over

    :return: Tuple of actual contexts and reference contexts
    """
    actual_contexts = {}
    reference_contexts = {}
    utts = list(convo.iter_utterances())
    roles = list({utt.meta["role"] for utt in utts})
    assert len(roles) == 2
    spk_prefixes = default_speaker_prefixes(roles)
    role_to_prefix = {roles[i]: spk_prefixes[i] for i in range(len(roles))}
    role_1 = roles[0]
    role_2 = roles[1]
    prev_spk = None
    prev_1, prev_2, cur_1, cur_2 = None, None, None, None
    for i, utt in enumerate(utts):
        utt_text = utt.text
        cur_spk = utt.meta["role"]
        if prev_spk is not None and cur_spk != prev_spk:
            if role_2 in cur_spk:
                prev_1 = cur_1
            else:
                prev_2 = cur_2

        if prev_1 is not None and prev_2 is not None:
            if role_2 in cur_spk:
                prev = prev_1
                prev_prev = prev_2
            else:
                prev = prev_2
                prev_prev = prev_1

            prev_prev_text, prev_prev_role = prev_prev
            prev_text, prev_role = prev

            prev_prev_data = role_to_prefix[prev_prev_role] + prev_prev_text
            prev_data = role_to_prefix[prev_role] + prev_text
            cur_data = role_to_prefix[cur_spk] + utt_text

            actual_contexts[utt.id] = [prev_data, cur_data]
            reference_contexts[utt.id] = [prev_data, prev_prev_data]

        if role_1 in cur_spk:
            cur_1 = (utt_text, cur_spk)
        if role_2 in cur_spk:
            cur_2 = (utt_text, cur_spk)

        prev_spk = cur_spk

    return actual_contexts, reference_contexts


def default_future_context_selector(convo):
    """
    Default function to compute future contexts for Redirection. Uses the
    immediately following utterance from the speaker with the other role.

    :param convo: ConvoKit Conversation object to compute contexts over

    :return: Dictionary of Utterance id to future contexts
    """
    future_contexts = {}
    cur_1 = None
    cur_2 = None
    utts = list(convo.iter_utterances())
    roles = list({utt.meta["role"] for utt in utts})
    assert len(roles) == 2
    spk_prefixes = default_speaker_prefixes(roles)
    role_to_prefix = {roles[i]: spk_prefixes[i] for i in range(len(roles))}
    role_1 = roles[0]
    role_2 = roles[1]
    n = len(utts)
    for i in range(n - 1, -1, -1):
        utt = utts[i]
        utt_text = utt.text
        cur_spk = utt.meta["role"]
        if role_2 in cur_spk:
            cur_2 = (utt_text, cur_spk)
            if cur_1 is not None:
                future_text, future_role = cur_1
                future_data = role_to_prefix[future_role] + future_text
                future_contexts[utt.id] = [future_data]
        else:
            cur_1 = (utt_text, cur_spk)
            if cur_2 is not None:
                future_text, future_role = cur_2
                future_data = role_to_prefix[future_role] + future_text
                future_contexts[utt.id] = [future_data]
    return future_contexts
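
To make the contract of these selectors concrete, here is a minimal sketch (not part of this diff) on a hand-built three-utterance conversation; as the selectors require, every utterance carries a "role" metadata field and exactly two roles appear:

from convokit import Corpus, Speaker, Utterance
from convokit.redirection.contextSelector import (
    default_previous_context_selector,
    default_future_context_selector,
)

judge, lawyer = Speaker(id="judge"), Speaker(id="lawyer")
utts = [
    Utterance(id="0", speaker=judge, conversation_id="c0", reply_to=None,
              text="Counsel, address the standing question.", meta={"role": "judge"}),
    Utterance(id="1", speaker=lawyer, conversation_id="c0", reply_to="0",
              text="Yes, Your Honor. On standing, our position is...", meta={"role": "lawyer"}),
    Utterance(id="2", speaker=judge, conversation_id="c0", reply_to="1",
              text="But what about mootness?", meta={"role": "judge"}),
]
convo = Corpus(utterances=utts).get_conversation("c0")

actual, reference = default_previous_context_selector(convo)
future = default_future_context_selector(convo)
# actual["2"] pairs the lawyer's reply with the judge's follow-up, while
# reference["2"] swaps in the judge's own earlier utterance instead;
# future["1"] holds the judge's follow-up as the lawyer's future context.
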
152 changes: 152 additions & 0 deletions convokit/redirection/gemmaLikelihoodModel.py
@@ -0,0 +1,152 @@
from .likelihoodModel import LikelihoodModel
import torch
from peft import LoraConfig, get_peft_model, AutoPeftModelForCausalLM, PeftModel
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    TrainingArguments,
)
from trl import SFTTrainer
from .config import DEFAULT_TRAIN_CONFIG, DEFAULT_BNB_CONFIG, DEFAULT_LORA_CONFIG


class GemmaLikelihoodModel(LikelihoodModel):
    """
    Likelihood model backed by Gemma, used to compute utterance likelihoods.

    :param hf_token: Huggingface authentication token
    :param model_id: Gemma model id version
    :param device: Device to use
    :param train_config: Training config for fine-tuning
    :param bnb_config: bitsandbytes config for quantization
    :param lora_config: LoRA config for fine-tuning
    """

    def __init__(
        self,
        hf_token,
        model_id="google/gemma-2b",
        device="cuda" if torch.cuda.is_available() else "cpu",
        train_config=DEFAULT_TRAIN_CONFIG,
        bnb_config=DEFAULT_BNB_CONFIG,
        lora_config=DEFAULT_LORA_CONFIG,
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id, token=hf_token, padding_side="right"
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, quantization_config=bnb_config, device_map="auto", token=hf_token
        )
        self.data_collator = DataCollatorForLanguageModeling(tokenizer=self.tokenizer, mlm=False)
        self.hf_token = hf_token
        self.device = device
        self.train_config = train_config
        self.lora_config = lora_config
        self.bnb_config = bnb_config
        self.max_length = self.train_config["max_seq_length"]

    def name(self):
        return self.__class__.__name__

    def fit(self, train_data, val_data):
        """
        Fine-tunes the Gemma model on the provided `train_data` and validates
        on `val_data`.

        :param train_data: Data to fine-tune model
        :param val_data: Data to validate model
        """
        training_args = TrainingArguments(
            output_dir=self.train_config["output_dir"],
            logging_dir=self.train_config["logging_dir"],
            logging_steps=self.train_config["logging_steps"],
            eval_steps=self.train_config["eval_steps"],
            num_train_epochs=self.train_config["num_train_epochs"],
            per_device_train_batch_size=self.train_config["per_device_train_batch_size"],
            per_device_eval_batch_size=self.train_config["per_device_eval_batch_size"],
            evaluation_strategy=self.train_config["evaluation_strategy"],
            save_strategy=self.train_config["save_strategy"],
            save_steps=self.train_config["save_steps"],
            optim=self.train_config["optim"],
            learning_rate=self.train_config["learning_rate"],
            load_best_model_at_end=self.train_config["load_best_model_at_end"],
        )

        trainer = SFTTrainer(
            model=self.model,
            train_dataset=train_data,
            eval_dataset=val_data,
            args=training_args,
            peft_config=self.lora_config,
            max_seq_length=self.train_config["max_seq_length"],
        )
        trainer.train()

    def _calculate_likelihood_prob(self, past_context, future_context):
        """
        Computes the utterance likelihood given the previous context to
        condition on and the future context to predict.

        :param past_context: Context to condition on
        :param future_context: Context to predict the likelihood of

        :return: Summed conditional log-probability of the future context
        """
        past_context = "\n\n".join(past_context)
        future_context = "\n\n".join(future_context)

        context_ids = self.tokenizer.encode(
            past_context,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        future_ids = self.tokenizer.encode(
            future_context,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = torch.cat([context_ids, future_ids], dim=1)
        if input_ids.shape[1] > self.max_length:
            input_ids = input_ids[:, -self.max_length :]
        input_ids = input_ids.to(self.device)
        with torch.no_grad():
            probs = torch.nn.functional.softmax(self.model(input_ids)[0], dim=-1)
        cond_log_probs = []
        for i, future_id in enumerate(future_ids[0]):
            # Position of the logits that predict this future token.
            index = i + (input_ids.shape[1] - future_ids.shape[1]) - 1
            logprob = torch.log(probs[0, index, future_id])
            cond_log_probs.append(logprob.item())
        result = sum(cond_log_probs)
        return result

    def transform(self, test_data, verbosity=5):
        """
        Computes the utterance likelihoods for the provided `test_data`.

        :param test_data: Tuple of (previous contexts, future contexts) to compute likelihoods over
        :param verbosity: Number of conversations between progress updates

        :return: Likelihoods of the `test_data`
        """
        prev_contexts, future_contexts = test_data
        likelihoods = []
        for i in range(len(prev_contexts)):
            if i % verbosity == 0 and i > 0:
                print(i, "/", len(prev_contexts))
            convo_likelihoods = {}
            convo_prev_contexts = prev_contexts[i]
            convo_future_contexts = future_contexts[i]
            for utt_id in convo_prev_contexts:
                if utt_id not in convo_future_contexts:
                    continue
                utt_prev_context = convo_prev_contexts[utt_id]
                utt_future_context = convo_future_contexts[utt_id]
                convo_likelihoods[utt_id] = self._calculate_likelihood_prob(
                    past_context=utt_prev_context, future_context=utt_future_context
                )
            likelihoods.append(convo_likelihoods)
        return likelihoods
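
A minimal sketch of the intended end-to-end flow (not part of this diff): `train_ds` and `val_ds` stand in for Hugging Face Datasets of formatted conversation text, as SFTTrainer expects, and the context dicts are the per-conversation outputs of the selectors above.

model = GemmaLikelihoodModel(hf_token="hf_...")  # placeholder Huggingface token
model.fit(train_ds, val_ds)  # LoRA fine-tuning with the default configs

# transform takes a (prev_contexts, future_contexts) pair: parallel lists with
# one dict per conversation, each mapping utterance id -> list of context strings.
likelihoods = model.transform((prev_contexts_list, future_contexts_list))
# -> one {utterance id: summed conditional log-probability} dict per conversation
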
51 changes: 51 additions & 0 deletions convokit/redirection/likelihoodModel.py
@@ -0,0 +1,51 @@
from abc import ABC, abstractmethod


class LikelihoodModel(ABC):
    """
    Abstract class representing a model to compute utterance likelihoods
    based on provided context. Different models (Gemma, Llama, Mistral, etc.)
    can be supported by inheriting from this base class.
    """

    def __init__(self):
        self._name = None

    @property
    def name(self):
        """
        Name of the likelihood model.
        """
        return self._name

    @name.setter
    def name(self, name):
        """
        Sets the name of the likelihood model.

        :param name: Name of model
        """
        self._name = name

    @abstractmethod
    def fit(self, train_data, val_data):
        """
        Fine-tunes the likelihood model on the provided `train_data` and
        validates on `val_data`.

        :param train_data: Data to fine-tune model
        :param val_data: Data to validate model
        """
        pass

    @abstractmethod
    def transform(self, test_data):
        """
        Computes the utterance likelihoods for the provided `test_data`.

        :param test_data: Data to compute likelihoods over

        :return: Likelihoods of the `test_data`
        """
        pass
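
Since the class is abstract, supporting another model family only requires subclassing it. A minimal sketch with a hypothetical trivial stand-in (not part of this diff):

from convokit.redirection.likelihoodModel import LikelihoodModel


class ConstantLikelihoodModel(LikelihoodModel):
    """Toy model assigning every utterance the same log-likelihood."""

    def __init__(self, value=-1.0):
        super().__init__()
        self.name = "constant"  # uses the base class's name setter
        self.value = value

    def fit(self, train_data, val_data):
        pass  # nothing to fine-tune

    def transform(self, test_data):
        prev_contexts, future_contexts = test_data
        return [
            {utt_id: self.value for utt_id in convo_prev if utt_id in convo_future}
            for convo_prev, convo_future in zip(prev_contexts, future_contexts)
        ]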