
[Refactor] Avoid padding in transformer wrapper #2881

Merged: 1 commit, Apr 2, 2025
138 changes: 113 additions & 25 deletions test/test_actors.py
@@ -10,7 +10,13 @@

import pytest
import torch
from tensordict import LazyStackedTensorDict, NonTensorStack, TensorDict
from tensordict import (
lazy_stack,
LazyStackedTensorDict,
NonTensorStack,
set_list_to_stack,
TensorDict,
)
from tensordict.nn import CompositeDistribution, TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor

@@ -937,6 +943,38 @@ def vllm_instance(self):
tokenizer.pad_token = tokenizer.eos_token
return llm_model

@pytest.fixture(scope="module")
def transformers_instance(self):
from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel(GPT2Config()).eval()
# tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
# model = OPTModel(OPTConfig("facebook/opt-125m"))
# tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
# model = OPTForCausalLM(OPTConfig())

tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

return model, tokenizer

@pytest.fixture(scope="module")
def transformers_instance_pretrained(self):
from transformers import AutoTokenizer, OPTForCausalLM

# tokenizer = AutoTokenizer.from_pretrained("gpt2")
# model = GPT2LMHeadModel(GPT2Config())
# tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
# model = OPTModel(OPTConfig("facebook/opt-125m"))
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = OPTForCausalLM.from_pretrained("facebook/opt-125m")

tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

return model, tokenizer

@pytest.mark.parametrize(
"from_text, generate, return_log_probs, tokens, attention_mask",
[
@@ -961,22 +999,18 @@ def vllm_instance(self):
(False, True, False, torch.randint(1024, (1, 10)), None),
],
)
def test_TransformersWrapper(
self, from_text, generate, return_log_probs, tokens, attention_mask
def test_transformers_wrapper(
self,
from_text,
generate,
return_log_probs,
tokens,
attention_mask,
transformers_instance,
):
torch.manual_seed(0)
from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel

# model_name = "distilbert-base-uncased" # or "minilm" or "albert-tiny"
# Load the model and tokenizer
# model = AutoModel.from_pretrained(model_name)
# tokenizer = AutoTokenizer.from_pretrained(model_name)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel(GPT2Config())

tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
model, tokenizer = transformers_instance

m = TransformersWrapper(
model,
@@ -1019,7 +1053,7 @@ def test_TransformersWrapper(
(False, True, False, torch.randint(1024, (1, 10)), None),
],
)
def test_from_vllm(
def test_vllm_wrapper(
self,
from_text,
generate,
@@ -1163,15 +1197,11 @@ def _run_check(
(True, None, None),
],
)
def test_from_hf_logprobs(self, from_text, tokens, attention_mask):
def test_transformers_logprobs(
self, from_text, tokens, attention_mask, transformers_instance
):
torch.manual_seed(0)
from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel(GPT2Config()).eval()

tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
model, tokenizer = transformers_instance

m_generate = TransformersWrapper(
model,
@@ -1201,7 +1231,7 @@ def test_from_hf_logprobs(self, from_text, tokens, attention_mask):
(True, False, torch.randint(1024, (1, 10)), None),
],
)
def test_from_vllm_logprobs(
def test_vllm_logprobs(
self, from_text, tokens, attention_mask, pad_output, vllm_instance
):
torch.manual_seed(0)
@@ -1254,6 +1284,7 @@ def _check_lps(
)
td_logprobs = model_logprobs(tdin_logprobs)
assert td_generate.log_probs.shape == td_generate.tokens_response.shape
assert td_logprobs.log_probs.shape == td_logprobs.tokens_response.shape
assert td_logprobs.log_probs.shape == td_generate.tokens_response.shape
torch.testing.assert_close(
td_generate.log_probs, td_logprobs.log_probs, rtol=tol, atol=tol
@@ -1374,7 +1405,7 @@ def _run_check_collector(self, policy):
assert "tokens" in data
# assert ("next", "tokens") in data

def test_generate_multiple_trajs_vllm(self, vllm_instance):
def test_vllm_generate_multiple_trajs(self, vllm_instance):
policy = vLLMWrapper(
vllm_instance,
return_log_probs=True,
@@ -1386,6 +1417,63 @@ def test_generate_multiple_trajs_vllm(self, vllm_instance):
)
data = policy(data)

@set_list_to_stack(True)
@pytest.mark.parametrize("from_text", [True, False])
@pytest.mark.parametrize("generate", [True, False])
def test_transformers_long_sequences(
self, from_text, generate, transformers_instance_pretrained
):
torch.manual_seed(42)
model, tokenizer = transformers_instance_pretrained
prompts = [
"The quick brown fox jumps over the lazy dog.", # Likely to finish soon
"Once upon a time in a land far, far away, there was a", # Likely to continue longer
"In the beginning, the universe was created. This has made a lot of people very angry and been widely regarded as a bad move.",
]
data = lazy_stack([TensorDict() for _ in range(len(prompts))])
data["text"] = prompts
eos_token_id = tokenizer.convert_tokens_to_ids(",")
if not from_text:
data["tokens"] = tokenizer(data["text"])["input_ids"]
data["attention_mask"] = (
0 * data.get("tokens", as_nested_tensor=True, layout=torch.strided) + 1
)
if not generate:
# we need responses
responses = prompts[1:] + [" et dolore magna aliqua."]
data["text_response"] = responses
if not from_text:
data["tokens_response"] = tokenizer(data["text_response"])["input_ids"]
# make sure dimensions are ragged for tokens entries
if "tokens" in data:
assert data.get_item_shape("tokens")[-1] == -1
if "tokens_response" in data:
assert data.get_item_shape("tokens_response")[-1] == -1
generate_kwargs = {}
if generate:
generate_kwargs = {
"max_new_tokens": 128, # Set a reasonable number of new tokens to generate
"min_length": 20, # Ensure a minimum length for the generated sequence
"pad_token_id": tokenizer.pad_token_id, # Use the tokenizer's pad token
"forced_eos_token_id": eos_token_id, # Use comma as an EOS token
}
policy = TransformersWrapper(
model,
tokenizer=tokenizer,
from_text=from_text,
generate=generate,
return_log_probs=True,
# TODO: use n trajs
generate_kwargs=generate_kwargs,
)
data_policy = policy(data)
if "tokens" in data_policy:
assert data_policy.get_item_shape("tokens")[-1] == -1
if "tokens_response" in data_policy:
assert (
data_policy.get_item_shape("tokens_response")[-1] == -1
) # TODO: this fails


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
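For context, the new `test_transformers_long_sequences` test keeps token sequences ragged (un-padded) by building a lazy stack of per-prompt TensorDicts; the `set_list_to_stack(True)` decorator then lets plain Python lists (such as tokenizer output) be dispatched across that stack. A minimal sketch of the underlying structure, assuming a recent tensordict release that ships `lazy_stack`:

```python
import torch
from tensordict import TensorDict, lazy_stack

# two token sequences of different lengths, stacked lazily so no padding is needed
data = lazy_stack(
    [
        TensorDict({"tokens": torch.arange(5)}, batch_size=[]),
        TensorDict({"tokens": torch.arange(9)}, batch_size=[]),
    ]
)

# the ragged dimension is reported as -1 instead of a padded length
assert data.get_item_shape("tokens")[-1] == -1

# a strided nested-tensor view is still available when a single object is required
nested = data.get("tokens", as_nested_tensor=True, layout=torch.strided)
```

This is the property the new assertions check (`get_item_shape(...)[-1] == -1`) both before and after the wrapper call, so the wrapper never has to pad its inputs or outputs.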
22 changes: 19 additions & 3 deletions torchrl/modules/llm/common.py
@@ -34,10 +34,20 @@ def get_dist(
forward = TensorDictSequential.forward

@property
def log_prob_keys(self):
return ["log_probs"]
def log_prob_keys(self) -> list[NestedKey]:
return getattr(self, "_log_prob_keys", ["log_probs"])

log_prob_key = ProbabilisticTensorDictModule.log_prob_key
@log_prob_keys.setter
def log_prob_keys(self, value: list[NestedKey]):
self._log_prob_keys = value

@property
def log_prob_key(self) -> NestedKey:
return self.log_prob_keys[0]

@log_prob_key.setter
def log_prob_key(self, value: NestedKey) -> None:
self.log_prob_keys[0] = value

@property
def dist_params_keys(self) -> list[NestedKey]:
@@ -46,3 +56,9 @@ def dist_params_keys(self) -> list[NestedKey]:
@property
def dist_sample_keys(self) -> list[NestedKey]:
return ["tokens_response"]

def log_prob(self, data: TensorDictBase, **get_kwargs) -> TensorDictBase:
if not self.generate:
data = self(data)
return data.get(self.log_prob_key, **get_kwargs)
raise RuntimeError("log_prob not callable when generate=True.")
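A hedged usage sketch of the new `log_prob()` helper added above: it only applies when the wrapper was built with `generate=False`, in which case it runs the module and returns the entry stored under `log_prob_key` (`"log_probs"` by default). The constructor arguments mirror the tests above; the `torchrl.modules.llm` import path is an assumption here.

```python
import torch
from tensordict import TensorDict
from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel

from torchrl.modules.llm import TransformersWrapper  # assumed import path

# same small model/tokenizer setup as the test fixture above
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
model = GPT2LMHeadModel(GPT2Config()).eval()

# wrapper configured for log-prob evaluation rather than generation
wrapper = TransformersWrapper(
    model,
    tokenizer=tokenizer,
    from_text=False,
    generate=False,  # log_prob() raises a RuntimeError when generate=True
    return_log_probs=True,
)

data = TensorDict(
    {
        "tokens": torch.randint(1024, (1, 10)),
        "tokens_response": torch.randint(1024, (1, 10)),
    },
    batch_size=[1],
)
log_probs = wrapper.log_prob(data)  # shorthand for wrapper(data).get("log_probs")
```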