
Mamba2 inference under torch.cuda.amp.autocast(dtype=bfloat16) with params in float32 causes an error in the step function #645

realwenlongwang opened this issue Dec 13, 2024 · 0 comments


Issue with Mamba2's generate Function and Torch's Autocast

When using the Mamba2 model with the provided generate function inside the following autocast context:

with torch.autocast(device_type='cuda', dtype=torch.bfloat16):

an error occurs:

RuntimeError: Expected conv_state.scalar_type() == input_type to be true, but got false.
(Could this error message be improved? If so, please report an enhancement request to PyTorch.)
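Roughly, a minimal setup that triggers this looks like the following (the checkpoint name and generate arguments here are only illustrative, not the exact configuration from my run):

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Illustrative checkpoint; any Mamba2-based model whose params stay in float32
# should hit the same code path.
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba2-130m", device="cuda")
input_ids = torch.randint(0, model.config.vocab_size, (1, 16), device="cuda")

# Params stay float32; autocast produces bfloat16 activations, but the decoding
# conv_state cache is allocated in float32, so causal_conv1d_update raises the
# RuntimeError above inside the step function.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = model.generate(input_ids, max_length=32)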

This happens at the following line in the step function within the Mamba2 class:

xBC = causal_conv1d_update(
    xBC,
    conv_state,
    rearrange(self.conv1d.weight, "d 1 w -> d w"),
    self.conv1d.bias,
    self.activation,
)
  • The issue is that xBC is in bfloat16 while conv_state is stored in float32 (see the sketch after this list for how autocast produces that mismatch).
  • Attempting to cast xBC to float32, or conv_state to bfloat16, results in incorrect inference outputs.

Does anyone have a clue about how to fix this, or is there something I did wrong?
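For reference, the dtype mismatch itself is ordinary autocast behavior and can be reproduced with plain PyTorch (the names below are made up for illustration):

import torch

in_proj = torch.nn.Linear(64, 64, device="cuda")     # float32 weights, like the model params
conv_state = torch.zeros(1, 64, 4, device="cuda")    # float32 cache buffer

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    xBC = in_proj(torch.randn(1, 64, device="cuda"))  # autocast runs the matmul in bfloat16
    print(xBC.dtype, conv_state.dtype)                 # torch.bfloat16 torch.float32
    # A fused kernel such as causal_conv1d_update checks that these dtypes match,
    # which is exactly the assertion that fails here.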