Issue with Mamba2's generate function and torch's autocast

When using the Mamba2 model with the provided generate function under a torch.autocast context, the following error occurs:
RuntimeError: Expected conv_state.scalar_type() == input_type to be true, but got false.
(Could this error message be improved? If so, please report an enhancement request to PyTorch.)
This happens at the following line in the step function within the Mamba2 class:
xBC = causal_conv1d_update(
    xBC,
    conv_state,
    rearrange(self.conv1d.weight, "d 1 w -> d w"),
    self.conv1d.bias,
    self.activation,
)
The issue is that xBC is in bfloat16 while conv_state is stored in float32.
Attempting to cast xBC to float32 or conv_state to bfloat16 results in incorrect inference outputs.
Does anyone have a clue about how to fix this, or is there something I did wrong?
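For reference, the mismatch can be reproduced in miniature without mamba_ssm: under an autocast region, conv outputs come back in bfloat16, while a state buffer allocated outside the region stays float32. The shapes below are arbitrary placeholders, not Mamba2's actual dimensions:

```python
import torch

# A depthwise conv standing in for self.conv1d, and a cache buffer
# standing in for conv_state, allocated in float32 outside autocast.
conv = torch.nn.Conv1d(4, 4, kernel_size=3, groups=4)
conv_state = torch.zeros(1, 4, 3, dtype=torch.float32)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    x = torch.randn(1, 4, 8)
    xBC = conv(x)  # autocast runs the conv in bfloat16

# xBC.dtype is torch.bfloat16, conv_state.dtype is torch.float32:
# exactly the input_type vs. conv_state dtype mismatch the kernel rejects.
print(xBC.dtype, conv_state.dtype)
```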