Reasons for upcasting the logits dtype outside the kernel #241
Thanks! We did the casting to stay consistent with Hugging Face behavior. But yes, I think we can do it inside. There is a PR doing this: #238.
Curious how you measured the loss? Also, very impressive background. Contributions are welcome :-)
Very simple testing code:

import torch
import torch.nn as nn
import torch.nn.functional as F
# FusedLinearCrossEntropyLoss is assumed to be imported from Liger-Kernel.

torch.manual_seed(42)
batch_size, seq_len, hidden_size, vocab_size = 8, 4096, 2048, 128000
x = torch.randn(batch_size * seq_len, hidden_size).cuda().bfloat16().requires_grad_()
target = torch.randint(0, vocab_size, (batch_size * seq_len,)).cuda()
weight = torch.randn(vocab_size, hidden_size).cuda().bfloat16().requires_grad_()
bias = torch.randn(vocab_size).cuda().bfloat16().requires_grad_()

# Reference: materialize fp32 logits, then standard cross entropy.
logits = F.linear(x, weight, bias).float()
output1 = nn.CrossEntropyLoss()(logits, target)
do = torch.randn_like(output1).cuda().bfloat16()
output1.backward(do)
ref_dx, x.grad = x.grad.clone(), None
ref_dw, weight.grad = weight.grad.clone(), None
ref_db, bias.grad = bias.grad.clone(), None

# Fused linear cross entropy.
output2 = FusedLinearCrossEntropyLoss()(x, target, weight, bias)
output2.backward(do)
tri_dx, x.grad = x.grad.clone(), None
tri_dw, weight.grad = weight.grad.clone(), None
tri_db, bias.grad = bias.grad.clone(), None

# print('\n\n', output1)
# print(output2, '\n\n')
print(" o", torch.abs(output1 - output2).max())
print("dx", torch.abs(ref_dx - tri_dx).max())
print("dw", torch.abs(ref_dw - tri_dw).max())
print("db", torch.abs(ref_db - tri_db).max())
The o (loss output) difference looks a bit high?
@ByronHsu Those are the diffs of the current implementations. I think this precision loss makes sense given that the vocab is large and the first output is computed in fp32 while the second is in bf16.
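For intuition, here is a standalone sketch (not from the repo) that isolates exactly that effect: the same logits evaluated once in fp32 and once after being rounded through bf16, for a 128K vocab. Any gap in the printed value comes purely from bf16 rounding of the logits.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 128000
logits_fp32 = torch.randn(64, vocab_size)                          # reference logits in fp32
target = torch.randint(0, vocab_size, (64,))

loss_fp32 = F.cross_entropy(logits_fp32, target)                   # loss from fp32 logits
loss_bf16 = F.cross_entropy(logits_fp32.bfloat16().float(), target)  # same logits rounded to bf16 first
print((loss_fp32 - loss_bf16).abs())                                # difference due to bf16 rounding only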
@ByronHsu Hi, have you compared the final loss of FLCE with the naive counterpart? It might be better to limit the maximum number of chunks. For instance, we could set
I've also found that for a 128K vocab, 8 chunks can be faster, at the cost of less than about 1 GB of additional memory.
UPDATE: I just trained three 370M models on 10B tokens of FineWeb-Edu with 8K context length and a 32K vocab. Regarding throughput, 8 chunks is faster.
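For reference, a minimal pure-PyTorch sketch of the chunking trade-off being discussed, with a hypothetical num_chunks knob (this is not Liger's Triton implementation): only one chunk of logits is materialized at a time, so more chunks means lower peak memory but more kernel launches.

import torch
import torch.nn.functional as F

def chunked_linear_ce(x, target, weight, bias=None, num_chunks=8):
    # Materialize [chunk_rows, vocab_size] logits one chunk at a time instead of all at once.
    total = torch.zeros((), device=x.device, dtype=torch.float32)
    for x_c, t_c in zip(x.chunk(num_chunks, dim=0), target.chunk(num_chunks, dim=0)):
        logits_c = F.linear(x_c, weight, bias).float()              # per-chunk fp32 upcast
        total = total + F.cross_entropy(logits_c, t_c, reduction="sum")
    return total / target.numel()                                   # mean over all tokens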
Hello, thank you for this great work.
Liger-Kernel/src/liger_kernel/ops/fused_linear_cross_entropy.py, line 69 (commit acd8272)
Liger-Kernel/src/liger_kernel/ops/fused_linear_cross_entropy.py, lines 91 to 96 (commit acd8272)
I'm wondering if there are any reasons for upcasting/downcasting the logits dtype outside the kernel. If I understand correctly, we already do the fp32 upcast inside the kernel, so this op is redundant? I just compared the outputs of the two versions, i.e., with and without the upcast, and found there is no precision loss if the above code is removed.
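For context, the referenced lines convert the logits to fp32 inside the Triton kernel when they are loaded. The snippet below is only a schematic per-row cross-entropy kernel to illustrate what "upcast inside" means; ce_fwd_kernel and per_row_ce are hypothetical names, it assumes the whole vocab fits in one block, and it is not the actual Liger kernel.

import torch
import triton
import triton.language as tl

@triton.jit
def ce_fwd_kernel(logits_ptr, target_ptr, loss_ptr, n_cols, stride, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    mask = offs < n_cols
    # Load bf16 logits and upcast to fp32 *inside* the kernel, so an outside .float() is not needed.
    logits = tl.load(logits_ptr + row * stride + offs, mask=mask, other=-float("inf")).to(tl.float32)
    m = tl.max(logits, 0)                                  # numerically stable log-sum-exp
    lse = tl.log(tl.sum(tl.exp(logits - m), 0)) + m
    t = tl.load(target_ptr + row)
    target_logit = tl.load(logits_ptr + row * stride + t).to(tl.float32)
    tl.store(loss_ptr + row, lse - target_logit)

def per_row_ce(logits_bf16, target):
    n_rows, n_cols = logits_bf16.shape
    loss = torch.empty(n_rows, device=logits_bf16.device, dtype=torch.float32)
    ce_fwd_kernel[(n_rows,)](logits_bf16, target, loss, n_cols, logits_bf16.stride(0),
                             BLOCK=triton.next_power_of_2(n_cols))
    return loss.mean()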