Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added custom mamba op and fix the mamba cache issue #1521

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@
gaudi_llama_rmsnorm_forward,
gaudi_MambaForCausalLM_prepare_inputs_for_generation,
gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
gaudi_MambaMixer,
gaudi_MambaCache_update_conv_state,
gaudi_mistral_rmsnorm_forward,
gaudi_mixtral_block_sparse_moe_forward,
gaudi_mixtral_rmsnorm_forward,
Expand Down Expand Up @@ -593,6 +595,8 @@ def adapt_transformers_to_gaudi():
transformers.models.mamba.modeling_mamba.MambaForCausalLM._update_model_kwargs_for_generation = (
gaudi_MambaForCausalLM_update_model_kwargs_for_generation
)
transformers.models.mamba.modeling_mamba.MambaMixer = gaudi_MambaMixer
transformers.cache_utils.MambaCache.update_conv_state = gaudi_MambaCache_update_conv_state

# Optimization for Whisper on Gaudi
transformers.models.whisper.modeling_whisper.WhisperSdpaAttention = GaudiWhisperSdpaAttention
Expand Down
2 changes: 2 additions & 0 deletions optimum/habana/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@
from .mamba import (
gaudi_MambaForCausalLM_prepare_inputs_for_generation,
gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
gaudi_MambaMixer,
gaudi_MambaCache_update_conv_state,
)
from .mistral import (
GaudiMistralAttention,
Expand Down
2 changes: 2 additions & 0 deletions optimum/habana/transformers/models/mamba/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .modeling_mamba import (
gaudi_MambaForCausalLM_prepare_inputs_for_generation,
gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
gaudi_MambaMixer,
gaudi_MambaCache_update_conv_state,
)
347 changes: 343 additions & 4 deletions optimum/habana/transformers/models/mamba/modeling_mamba.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,76 @@
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

import torch
from transformers.models.mamba.modeling_mamba import (
MambaCache,
)
from torch import nn
from transformers.cache_utils import MambaCache
from transformers.configuration_utils import PretrainedConfig
from transformers.models.mamba.configuration_mamba import MambaConfig
from transformers.activations import ACT2FN
from transformers.utils import (
ModelOutput,
logging,
)

from pathlib import Path
import os
my_dir = os.path.realpath(__file__)
my_len = my_dir.rfind("/")
base_dir = os.environ.get('HABANA_CUSTOM_OP_DIR', my_dir[:my_len])

custom_op_lib_path = str(
next(
Path(base_dir).glob("hpu_custom_pscan_all.cpython-*-x86_64-linux-gnu.so")
)
)
torch.ops.load_library(custom_op_lib_path)

logger = logging.get_logger(__name__)
is_fast_path_available = False

use_pscan_kernel = True
def Run_Mamba_Forward_Gaudi(in_state, in_x, in_dt, in_A, in_B, in_C, in_D, in_z):

in_state_h = in_state.unsqueeze(1).transpose(2,3)
in_x_h = in_x.transpose(1,2).unsqueeze(2)
in_dt_h = in_dt.unsqueeze(2)
in_A_h = in_A.unsqueeze(0).unsqueeze(1).transpose(2,3)
in_B_h = in_B.unsqueeze(3)
in_C_h = in_C.unsqueeze(3)
in_D_h = in_D.unsqueeze(0).unsqueeze(1).unsqueeze(2)
in_z_h = in_z.transpose(1,2).unsqueeze(2)


if in_state.dtype == torch.float:
state_out_h = torch.ops.custom_op.custom_pscan(in_state_h, in_x_h, in_dt_h, in_A_h, in_B_h)
output_h = torch.ops.custom_op.custom_pscan_update(state_out_h, in_x_h, in_C_h, in_D_h, in_z_h)

else:
in_A_h = in_A_h.to(torch.bfloat16)
state_out_h = torch.ops.custom_op.custom_pscan_bf16(in_state_h, in_x_h, in_dt_h, in_A_h, in_B_h)
output_h = torch.ops.custom_op.custom_pscan_update_bf16(state_out_h, in_x_h, in_C_h, in_D_h, in_z_h)


output_hpu = output_h.squeeze(2).transpose(1,2)
state_hpu = state_out_h.transpose(2,3)
state_out = torch.select(state_hpu,1,output_hpu.shape[2]-1)

return output_hpu, state_out


def gaudi_MambaCache_update_conv_state(
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor:
conv_state = self.conv_states[layer_idx]
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)

conv_state = conv_state.roll(shifts=-1, dims=-1)
#conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
for c, i in enumerate(cache_position):
conv_state[:, :, i] = new_conv_state[:,:,c].to(conv_state.device)

self.conv_states[layer_idx].zero_()
self.conv_states[layer_idx] += conv_state
return self.conv_states[layer_idx]


def gaudi_MambaForCausalLM_update_model_kwargs_for_generation(
Expand Down Expand Up @@ -94,3 +154,282 @@ def gaudi_MambaForCausalLM_prepare_inputs_for_generation(
}
)
return model_inputs

class gaudi_MambaMixer(nn.Module):
"""
Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
and is why Mamba is called **selective** state spaces)

We only replaced the slow path with custom op
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zzhang37 can you plz add a comment in the code about the different between this and original impl?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


def __init__(self, config: MambaConfig, layer_idx: int):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.intermediate_size = config.intermediate_size
self.time_step_rank = int(config.time_step_rank)
self.layer_idx = layer_idx
self.use_conv_bias = config.use_conv_bias
self.conv1d = nn.Conv1d(
in_channels=self.intermediate_size,
out_channels=self.intermediate_size,
bias=config.use_conv_bias,
kernel_size=config.conv_kernel,
groups=self.intermediate_size,
padding=config.conv_kernel - 1,
)

self.activation = config.hidden_act
self.act = ACT2FN[config.hidden_act]

self.use_mambapy = config.use_mambapy

# projection of the input hidden states
self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
# selective projection used to make dt, B and C input dependant
self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
# time step projection (discretization)
self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)

# S4D real initialization. These are not discretized!
# The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
A = A.expand(self.intermediate_size, -1).contiguous()

self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(self.intermediate_size))
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
self.use_bias = config.use_bias

if not is_fast_path_available:
if self.use_mambapy:
if is_mambapy_available():
logger.warning_once(
"The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
" is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and"
" https://github.com/Dao-AILab/causal-conv1d"
)
else:
raise ImportError(
"use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py."
)
else:
logger.warning_once(
"The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
" is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and"
" https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
)

def cuda_kernels_forward(
self,
hidden_states: torch.Tensor,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states).transpose(1, 2)

if self.training and cache_params is None: # Doesn't support outputting the states -> used for training
contextualized_states = mamba_inner_fn(
projected_states,
self.conv1d.weight,
self.conv1d.bias if self.use_conv_bias else None,
self.x_proj.weight,
self.dt_proj.weight,
self.out_proj.weight,
self.out_proj.bias.float() if self.use_bias else None,
-torch.exp(self.A_log.float()),
None, # input-dependent B
None, # input-dependent C
self.D.float(),
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
)

else:
hidden_states, gate = projected_states.chunk(2, dim=1)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
if cache_params is not None and cache_position[0] > 0:
hidden_states = causal_conv1d_update(
hidden_states.squeeze(-1),
cache_params.conv_states[self.layer_idx],
conv_weights,
self.conv1d.bias,
self.activation,
)
hidden_states = hidden_states.unsqueeze(-1)
else:
if cache_params is not None:
conv_states = nn.functional.pad(
hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
)
cache_params.update_conv_state(self.layer_idx, conv_states, cache_position)
hidden_states = causal_conv1d_fn(
hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
time_step, B, C = torch.split(
ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
)
discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)

A = -torch.exp(self.A_log.float())
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
if cache_params is not None and cache_position[0] > 0:
scan_outputs = selective_state_update(
cache_params.ssm_states[self.layer_idx],
hidden_states[..., 0],
discrete_time_step[..., 0],
A,
B[:, 0],
C[:, 0],
self.D,
gate[..., 0],
time_proj_bias,
dt_softplus=True,
).unsqueeze(-1)
else:
scan_outputs, ssm_state = selective_scan_fn(
hidden_states,
discrete_time_step,
A,
B.transpose(1, 2),
C.transpose(1, 2),
self.D.float(),
gate,
time_proj_bias,
delta_softplus=True,
return_last_state=True,
)
if ssm_state is not None and cache_params is not None:
cache_params.update_ssm_state(self.layer_idx, ssm_state)

# 4. Final linear projection
contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
return contextualized_states

# fmt: off
def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.LongTensor] = None):
"""
We replaced the 3c and 3d parts with custom op "Run_Mamba_Forward_Gaudi", which removed the sequence length loop and gain the performance.
"""
batch_size, seq_len, _ = input_states.shape
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zzhang37 , can u plz add a brief code comment about the difference between this and original.

is it only Run_Mamba_Forward_Gaudi ?

dtype = input_states.dtype
# 1. Gated MLP's linear projection
projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
hidden_states, gate = projected_states.chunk(2, dim=1)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 2. Convolution sequence transformation
if cache_params is not None:
ssm_state = cache_params.ssm_states[self.layer_idx].clone()
ssm_state = ssm_state.to(hidden_states.device)
# use `cache_position.shape[0]` to check whether we are in prefill
# stage, it's equivalent to check `cache_position[0] == 0`, which
# breaks dynamo fullgraph constraints
if cache_position.shape[0] == self.conv_kernel_size:
conv_state = nn.functional.pad(
hidden_states,
(self.conv_kernel_size - hidden_states.shape[-1], 0)
)

cache_params.update_conv_state(self.layer_idx, conv_state, cache_position)
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
else:
conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position)
hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
if self.use_conv_bias:
hidden_states += self.conv1d.bias
hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding
else:
ssm_state = torch.zeros(
(batch_size, self.intermediate_size, self.ssm_state_size),
device=hidden_states.device, dtype=dtype
)
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 3. State Space Model sequence transformation
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
time_step, B, C = torch.split(
ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
)
discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size]

# 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size]
if use_pscan_kernel:
scan_output, ssm_state = Run_Mamba_Forward_Gaudi(
ssm_state,
hidden_states,
discrete_time_step,
A,
B,
C,
self.D,
gate
)
else:
discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len]
discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # [batch, intermediate_size, seq_len, ssm_state_size]
discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() # [batch, intermediate_size, seq_len, ssm_state_size]
deltaB_u = discrete_B * hidden_states[:, :, :, None].float()

# 3.c perform the recurrence y ← SSM(A, B, C)(x)
if self.use_mambapy and self.training and cache_params is None:
hs = pscan(discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)) # [batch, seq_len, intermediate_size, ssm_state_size]

scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len]
scan_output = scan_output + hidden_states * self.D[None, :, None]
scan_output = scan_output * self.act(gate)
else:
scan_outputs = []
for i in range(seq_len):
ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediade_size, ssm_state]
scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediade_size, 1]
scan_outputs.append(scan_output[:, :, 0])
scan_output = torch.stack(scan_outputs, dim=-1) # [batch, seq_len, intermediade_size]
scan_output = scan_output + (hidden_states * self.D[None, :, None])
scan_output = (scan_output * self.act(gate))

if cache_params is not None:
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)

# 4. Final linear projection
contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
return contextualized_states
# fmt: on

def forward(
self,
hidden_states,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)