[Bug Report] Pythia output inconsistent across batch sizes when use_split_qkv_input=True #661
Open
1 task done
Labels
bug
Something isn't working
complexity-high
Very complicated changes for people to address who are quite familiar with the code
implementation-inaccuracy
Any issues related to our implementation being off from the official version
Describe the bug
Pythia output inconsistent across batch sizes when use_split_qkv_input=True
Code example
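The reporter's original snippet is not preserved here. As a stand-in, the following is a minimal sketch of the kind of check the report describes: run each input alone, run the same inputs as one batch, and measure the largest difference. The helper `max_batch_discrepancy` and both toy `forward` functions are hypothetical names made up for illustration. They are not part of transformer_lens and are not the reporter's code.

```python
from typing import Callable, List

Batch = List[List[float]]


def max_batch_discrepancy(forward: Callable[[Batch], Batch], batch: Batch) -> float:
    """Return the largest absolute difference between each row's output when
    it is run alone (batch size 1) and when it is run inside the full batch."""
    batched_out = forward(batch)
    worst = 0.0
    for i, row in enumerate(batch):
        single_out = forward([row])[0]  # same row, batch size 1
        for a, b in zip(single_out, batched_out[i]):
            worst = max(worst, abs(a - b))
    return worst


# Toy stand-ins for a model (hypothetical, for illustration only).
def independent_forward(batch: Batch) -> Batch:
    # Each row is processed on its own, so batch size cannot matter.
    return [[2.0 * x for x in row] for row in batch]


def batch_dependent_forward(batch: Batch) -> Batch:
    # Deliberately buggy: the output depends on how many rows are in the batch.
    n = len(batch)
    return [[x + 0.5 * (n - 1) for x in row] for row in batch]


if __name__ == "__main__":
    inputs = [[1.0, 2.0], [3.0, 4.0]]
    print(max_batch_discrepancy(independent_forward, inputs))
    print(max_batch_discrepancy(batch_dependent_forward, inputs))
```

With transformer_lens, `forward` might wrap something like `lambda b: model(torch.tensor(b))[:, -1].tolist()` after calling `model.set_use_split_qkv_input(True)`, using rows of equal length so that padding does not enter the comparison. That wiring is an assumption and is untested here. A small nonzero value can come from floating-point reordering alone, so the result is best compared against a tolerance rather than against exactly zero.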
System Info
Describe the characteristic of your environment:
Describe how transformer_lens was installed (pip, docker, source, ...)
What OS are you using? (Linux, MacOS, Windows): MacOS (MPS)
Python version (We support 3.7 - 3.10 currently): Python 3.10
Torch 2.0.1
Additional context
This doesn't happen on all inputs - for example, on imdb
possibly related to #385
Checklist