
torch.autocast errors #17

Open
JCBrouwer opened this issue Oct 17, 2024 · 3 comments
JCBrouwer commented Oct 17, 2024

I'm getting a couple of dtype-related errors when using the MLP module in a torch.autocast block. Here's my simple wrapper of the MLP module:

import torch
from torch import LongTensor, Tensor, nn

from scattermoe.mlp import MLP as MoE


class MyMLP(nn.Module):
    def __init__(self, n_experts: int, d_model: int, mlp_ratio: int = 4, d_out: int | None = None) -> None:
        super().__init__()
        self.moe = MoE(
            # ReLUSquare is a user-defined activation (squared ReLU)
            input_size=d_model, hidden_size=d_model * mlp_ratio, num_experts=n_experts, top_k=1, activation=ReLUSquare()
        )
        if d_out is not None:
            self.out = nn.Linear(d_model, d_out)
        else:
            self.out = nn.Identity()

    def forward(self, x: Tensor, e: LongTensor) -> Tensor:
        v = self.moe.forward(
            x, expert_p=torch.ones_like(e, dtype=x.dtype).unsqueeze(1), expert_idxs=e.unsqueeze(1)
        )
        v = self.out(v)
        return v

If I add @torch.autocast(device_type='cuda', dtype=torch.bfloat16) to the forward method, I get the following type mismatch in the linear layer directly after MyMLP:

Traceback (most recent call last):
...
  File "/home/jcbgb/anaconda3/envs/hans/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jcbgb/anaconda3/envs/hans/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jcbgb/anaconda3/envs/hans/lib/python3.10/site-packages/torch/nn/modules/container.py", line 219, in forward
    input = module(input)
  File "/home/jcbgb/anaconda3/envs/hans/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jcbgb/anaconda3/envs/hans/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jcbgb/anaconda3/envs/hans/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 117, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 must have the same dtype, but got BFloat16 and Float
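The first error can be reproduced without scattermoe at all: when autocast covers only one submodule's forward, that module emits bfloat16 activations, and the next float32 layer outside the autocast region fails on the mixed-dtype matmul. A minimal sketch (using CPU autocast so it runs anywhere; the CUDA behavior is analogous):

```python
import torch
from torch import nn

torch.manual_seed(0)
lin_inside = nn.Linear(4, 4)   # covered by autocast, like MyMLP.forward
lin_outside = nn.Linear(4, 4)  # the float32 layer directly after it
x = torch.randn(2, 4)

# Autocast only covers the first layer, mirroring a decorator on one forward().
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    h = lin_inside(x)  # autocast runs this matmul in bfloat16

print(h.dtype)  # torch.bfloat16

# Outside the autocast region nothing casts h back to float32, so the next
# layer's float32 weights meet a bfloat16 input and the matmul dtypes disagree.
try:
    lin_outside(h)
    raised = False
except RuntimeError as err:
    raised = True
    print(err)
```
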

If I instead put my whole loss function in an autocast block, I get this issue later in the backward pass:

Traceback (most recent call last):
...
  File "/home/jcbgb/anaconda3/envs/hans/lib/python3.10/site-packages/torch/_tensor.py", line 521, in backward
    torch.autograd.backward(
  File "/home/jcbgb/anaconda3/envs/hans/lib/python3.10/site-packages/torch/autograd/__init__.py", line 289, in backward
    _engine_run_backward(
  File "/home/jcbgb/anaconda3/envs/hans/lib/python3.10/site-packages/torch/autograd/graph.py", line 769, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/jcbgb/anaconda3/envs/hans/lib/python3.10/site-packages/torch/autograd/function.py", line 306, in apply
    return user_fn(self, *args)
  File "/home/jcbgb/anaconda3/envs/hans/lib/python3.10/site-packages/scattermoe/parallel_experts.py", line 55, in backward
    d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1)
RuntimeError: expected scalar type BFloat16 but found Float
@JCBrouwer (Author)

I think it's just a question of adding @custom_fwd and @custom_bwd to the ParallelExperts autograd Function, as explained here: https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops

@uldyssian2008

Did you solve the problem? I am facing similar issues.

JCBrouwer commented Oct 29, 2024

I added the custom_fwd/custom_bwd decorators to the ParallelExperts class like this:

...
from torch.amp import custom_fwd, custom_bwd

class ParallelLinear(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(
...
    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_out):
...

Not sure if this is a generic solution, but it works on my end.
