[Bug Report] Pythia output inconsistent across batch sizes when use_split_qkv_input=True #661

Open
1 task done
oliveradk opened this issue Jul 8, 2024 · 0 comments
Labels
  • bug: Something isn't working
  • complexity-high: Very complicated changes for people to address who are quite familiar with the code
  • implementation-inaccuracy: Any issues related to our implementation being off from the official version

oliveradk (Contributor) commented Jul 8, 2024

Describe the bug
Pythia output inconsistent across batch sizes when use_split_qkv_input=True

Code example

# minimal repro
import torch
from transformer_lens import HookedTransformer

input = torch.tensor([
    [2, 2, 2, 69, 26, 67, 17, 14, 14836, 3593, 14, 21, 12347, 14, 1257, 24],
    [535, 50270, 338, 1881, 15, 2364, 15, 25950, 2073, 15741, 64, 29786, 3401, 35495, 686, 26]
])
model_name = "pythia-70m"
model = HookedTransformer.from_pretrained(model_name)
model.set_use_split_qkv_input(True)
out_batch = model(input[:2])
out_single = model(input[:1])
assert torch.allclose(out_batch[0], out_single[0], atol=1e-3)  # fails: allclose returns False
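When debugging this, it can help to quantify how far apart the two outputs actually are, rather than only checking `allclose` against one tolerance. A minimal sketch (the `max_abs_diff` helper and the stand-in tensors below are illustrative, not part of the repro; in practice you would pass `out_batch[0]` and `out_single[0]`):

```python
import torch

def max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float:
    # Largest element-wise absolute difference between two tensors.
    return (a - b).abs().max().item()

# Stand-in tensors shaped like per-token logits for one 16-token sequence;
# substitute out_batch[0] and out_single[0] from the snippet above.
a = torch.zeros(16, 8)
b = torch.full((16, 8), 2e-3)
print(max_abs_diff(a, b))  # a gap above atol=1e-3, so allclose would fail
```

This makes it easy to see whether the mismatch is a small numerical drift just over tolerance or a large divergence at specific positions.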

System Info
Describe the characteristics of your environment:

  • How transformer_lens was installed (pip, docker, source, ...)
  • What OS are you using? (Linux, MacOS, Windows): MacOS (MPS)
  • Python version (we support 3.7-3.10 currently): 3.10
  • Torch version: 2.0.1

Additional context
This doesn't happen on all inputs; for example, on imdb the outputs match:

import datasets
import torch
from transformer_lens import HookedTransformer

model_name = "pythia-70m"
model = HookedTransformer.from_pretrained(model_name)
model.set_use_split_qkv_input(True)
dataset = datasets.load_dataset("imdb", split="train[:10]")
tokens = model.tokenizer(dataset["text"], padding=True, truncation=True, return_tensors="pt")["input_ids"]
out_batch = model(tokens[:4])
out_single = model(tokens[:1])
assert torch.allclose(out_batch[0], out_single[0], atol=5e-4)  # passes: allclose returns True

Possibly related to #385.

Checklist

  • I have checked that there is no similar issue in the repo (required)
@bryce13950 added the bug, complexity-high, and implementation-inaccuracy labels Jul 12, 2024