
Wrong output fp16/bf16 dtype in ParallelEmbedding when sharding across vocab #35

Open
@dacorvo

Description

In the ParallelEmbedding layer, when sharding across vocab, the output is masked at the very end of the operation.

It seems that the masking is done by multiplying by a hard-coded float mask, which causes the float16/bfloat16 output to be upcast to float32.

A correct implementation would multiply by a mask of the same dtype as the intended output.
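For illustration, here is a minimal sketch of the suggested fix, assuming a PyTorch-style vocab-parallel forward pass. The names `apply_vocab_mask`, `output_parallel`, and `input_mask` are hypothetical and not taken from the repository:

```python
import torch


def apply_vocab_mask(output_parallel: torch.Tensor, input_mask: torch.Tensor) -> torch.Tensor:
    """Zero out embeddings of tokens that fall outside this shard's vocab range."""
    # Problematic pattern: a hard-coded float mask upcasts fp16/bf16 outputs to fp32.
    #   output_parallel = output_parallel * (~input_mask).float().unsqueeze(-1)

    # Suggested fix: build the mask in the same dtype as the output so the
    # result keeps its float16/bfloat16 dtype.
    mask = (~input_mask).to(output_parallel.dtype).unsqueeze(-1)
    return output_parallel * mask
```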
