Description
In the `ParallelEmbedding` layer, when sharding across the vocab dimension, the output is masked at the very end of the operation. The masking is done by multiplying by a hard-coded float mask, which causes the actual `float16`/`bfloat16` output to be upcast to `float32`.
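For reference, a minimal PyTorch sketch (not the layer's actual code) illustrating the promotion behaviour when a `float32` mask multiplies a half-precision tensor:

```python
import torch

emb = torch.randn(4, 8, dtype=torch.bfloat16)   # stand-in for the sharded embedding output
mask = (torch.arange(4) % 2 == 0).float()       # hard-coded float mask -> float32

out = emb * mask[:, None]   # bfloat16 * float32 is promoted to float32
print(out.dtype)            # torch.float32
```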
A correct implementation would multiply by a mask of the same dtype as the intended output.
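For example, casting the mask to the embedding's dtype keeps the result in half precision (again a sketch of the idea, not a proposed patch):

```python
out = emb * mask[:, None].to(emb.dtype)   # mask cast to bfloat16 before the multiply
print(out.dtype)                          # torch.bfloat16
```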