Added mamba model support and test CI script #1573

Open · wants to merge 25 commits into main from zhzhang/mamba_with_test_119

Changes from all commits (25 commits)
3606b69
Added mamba model support and test CI script
zzhang37 Dec 6, 2024
a63a784
Added mamba model support and test CI script
zzhang37 Dec 6, 2024
68da585
Added mamba model support and test CI script
zzhang37 Dec 6, 2024
ee6cb84
Added mamba model support and test CI script
zzhang37 Dec 6, 2024
fba7f6c
Generation utils update (minor) (#1468)
yafshar Dec 8, 2024
84f4651
style: removed tabs (#1577)
mgonchar Dec 9, 2024
f92097d
Add chatglm (#1478)
mengker33 Dec 9, 2024
6979ebd
Enable num_return_sequences in beam search (#1536)
mengker33 Dec 9, 2024
c3cc9e3
gpt_bigcode: added internal bucketing fix (#1526)
mgonchar Dec 9, 2024
cbaa02b
Update the Gaudi trainer with transformers 4.45.2 (#1398)
yafshar Dec 9, 2024
b883184
Merge branch 'main' into zhzhang/mamba_with_test_119
skaulintel Dec 9, 2024
313d238
Fixed spelling (#1576)
mgonchar Dec 10, 2024
14e473e
Update docs for baichuan2 training (#1586)
xhaihao Dec 10, 2024
33a718f
Update the Gaudi trainer with transformers 4.45.2 (#1398)
yafshar Dec 9, 2024
43d92c9
Fix Accuracy Calculation Issue in GPT-NeoX (#1591)
yafshar Dec 10, 2024
aa59027
Add WA flag for falcon-180b to resolve text-gen critical reset error …
hchauhan123 Dec 10, 2024
27a44c9
Update transformers tests generation util v4.45.2 (#1441)
malkomes Dec 11, 2024
ceace58
Update the Gaudi trainer with transformers 4.45.2 (#1398)
yafshar Dec 9, 2024
05ada67
Limit position embeddings in inference (#1598)
bhargaveede Dec 12, 2024
dd7e0bd
Verify model output is provided when check_output is enabled (#1597)
vidyasiv Dec 12, 2024
aa76b9f
Update README.md (#1595)
skaulintel Dec 12, 2024
4bd36c4
Fix scikit-learn to 1.5.2 to fix f1 evaluation crash in 1.6.0 (#1596)
sywangyi Dec 12, 2024
d6681ec
Update the Gaudi trainer with transformers 4.45.2 (#1398)
yafshar Dec 9, 2024
2c8690d
Merge remote-tracking branch 'optimum-habana/main' into zhzhang/mamba…
regisss Dec 12, 2024
a87ed29
Use relative import
regisss Dec 12, 2024
1 change: 1 addition & 0 deletions README.md
@@ -241,6 +241,7 @@ The following model architectures, tasks and device distributions have been validated
| MiniCPM3 | | <div style="text-align:left"><li>Single card</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Baichuan2 | <div style="text-align:left"><li>DeepSpeed</li></div> | <div style="text-align:left"><li>Single card</li></div> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| DeepSeek-V2 | | :heavy_check_mark: | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Mamba | | <div style="text-align:left"><li>Single card</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| ChatGLM | <div style="text-align:left"><li>DeepSpeed</li></div> | <div style="text-align:left"><li>Single card</li></div> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
</div>

1 change: 1 addition & 0 deletions docs/source/index.mdx
@@ -108,6 +108,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all been validated
| MiniCPM3 | | <div style="text-align:left"><li>Single card</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Baichuan2 | <div style="text-align:left"><li>DeepSpeed</li></div> | <div style="text-align:left"><li>Single card</li></div> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| DeepSeek-V2 | | ✅ | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Mamba | | <div style="text-align:left"><li>Single card</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| ChatGLM | <div style="text-align:left"><li>DeepSpeed</li></div> | <div style="text-align:left"><li>Single card</li></div> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |

- Diffusers
12 changes: 12 additions & 0 deletions examples/text-generation/README.md
@@ -198,6 +198,12 @@ python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \
> --bf16
> ```

To run Mamba-130m inference on a single Gaudi2 card, use the following command. It assumes the custom kernel is at the default path `/root/.cache/huggingface/hub/models--Habana--mamba/blobs/libcustom_tpc_perf_lib.so`; if `libcustom_tpc_perf_lib.so` is in a different folder, adjust `GC_KERNEL_PATH` accordingly:
```bash
GC_KERNEL_PATH=/root/.cache/huggingface/hub/models--Habana--mamba/blobs/libcustom_tpc_perf_lib.so:$GC_KERNEL_PATH python run_generation.py \
--model_name_or_path state-spaces/mamba-130m-hf \
--max_input_tokens 128 \
--max_new_tokens 128 \
--bf16 \
--use_hpu_graphs \
--use_kv_cache \
--batch_size 1024
```
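
If you are not sure where the kernel file ended up in your Hugging Face cache, the following one-liner (an illustrative helper, not part of this PR, and assuming the default non-1.19 kernel filename) resolves its cached path via `huggingface_hub`, which is also how this PR fetches the custom op and kernel:
```bash
# Print the cached path of the Habana/mamba custom TPC kernel (downloads it if missing)
python -c "from huggingface_hub import hf_hub_download; print(hf_hub_download(repo_id='Habana/mamba', filename='libcustom_tpc_perf_lib.so'))"
```
Prepend the printed path to `GC_KERNEL_PATH` as in the command above.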

### Use any dataset from the Hugging Face Hub

You can also provide the name of a dataset from the Hugging Face Hub to perform generation on it with the argument `--dataset_name`.
4 changes: 4 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
@@ -210,8 +210,10 @@
gaudi_gpt_neox_model_forward,
gaudi_invert_attention_mask,
gaudi_llama_rmsnorm_forward,
gaudi_MambaCache_update_conv_state,
gaudi_MambaForCausalLM_prepare_inputs_for_generation,
gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
gaudi_MambaMixer,
gaudi_mistral_rmsnorm_forward,
gaudi_mixtral_block_dynamic_moe_forward,
gaudi_mixtral_block_sparse_moe_forward,
@@ -675,6 +677,8 @@ def adapt_transformers_to_gaudi():
)
transformers.models.falcon_mamba.modeling_falcon_mamba.FalconMambaModel.forward = gaudi_FalconMambaModel_forward
transformers.models.falcon_mamba.modeling_falcon_mamba.FalconMambaRMSNorm.forward = gaudi_llama_rmsnorm_forward
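# Swap in the Gaudi-optimized Mamba mixer and the Gaudi MambaCache conv-state update defined in models/mamba.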
transformers.models.mamba.modeling_mamba.MambaMixer = gaudi_MambaMixer
transformers.cache_utils.MambaCache.update_conv_state = gaudi_MambaCache_update_conv_state

# Optimization for Whisper on Gaudi
transformers.models.whisper.modeling_whisper.WhisperSdpaAttention = GaudiWhisperSdpaAttention
2 changes: 2 additions & 0 deletions optimum/habana/transformers/models/__init__.py
@@ -158,8 +158,10 @@
from .llava import GaudiLlavaForConditionalGeneration
from .llava_next import GaudiLlavaNextForConditionalGeneration
from .mamba import (
gaudi_MambaCache_update_conv_state,
gaudi_MambaForCausalLM_prepare_inputs_for_generation,
gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
gaudi_MambaMixer,
)
from .minicpm import MiniCPM3Config, MiniCPM3ForCausalLM
from .mistral import (
2 changes: 2 additions & 0 deletions optimum/habana/transformers/models/mamba/__init__.py
@@ -1,4 +1,6 @@
from .modeling_mamba import (
gaudi_MambaCache_update_conv_state,
gaudi_MambaForCausalLM_prepare_inputs_for_generation,
gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
gaudi_MambaMixer,
)
234 changes: 231 additions & 3 deletions optimum/habana/transformers/models/mamba/modeling_mamba.py
@@ -1,17 +1,88 @@
import os
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from transformers.models.mamba.modeling_mamba import (
MambaCache,
)
from torch import nn
from transformers.activations import ACT2FN
from transformers.cache_utils import MambaCache
from transformers.models.mamba.configuration_mamba import MambaConfig
from transformers.utils import (
ModelOutput,
logging,
)

from .util_mamba import set_mamba_lib


env_variables = os.environ.copy()

new_file_op, new_file_kernel = set_mamba_lib()
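# Rename the downloaded kernel so its basename is libcustom_tpc_perf_lib.so, the name prepended to GC_KERNEL_PATH below.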
realpath_kfn = os.path.realpath(new_file_kernel)
kfn = os.path.basename(realpath_kfn)
new_kfn = os.path.join(os.path.dirname(realpath_kfn), "libcustom_tpc_perf_lib.so")
os.rename(realpath_kfn, new_kfn)


env_variables["HABANA_CUSTOM_OP_DIR"] = os.path.dirname(new_file_op)
default_path = env_variables["GC_KERNEL_PATH"]
env_variables["GC_KERNEL_PATH"] = new_kfn + os.pathsep + default_path

base_dir = env_variables["HABANA_CUSTOM_OP_DIR"]
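# Load the compiled pscan custom op below so the torch.ops.custom_op.* kernels used in Run_Mamba_Forward_Gaudi become available.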

custom_op_lib_path = str(next(Path(base_dir).glob("hpu_custom_pscan_all.cpython-*-x86_64-linux-gnu.so")))
torch.ops.load_library(custom_op_lib_path)

logger = logging.get_logger(__name__)

is_fast_path_available = False

use_pscan_kernel = False
if os.path.exists(custom_op_lib_path):
use_pscan_kernel = True


def Run_Mamba_Forward_Gaudi(in_state, in_x, in_dt, in_A, in_B, in_C, in_D, in_z):
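# Rearrange the inputs (unsqueeze/transpose) into the 4-D layout expected by the custom pscan kernels.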
in_state_h = in_state.unsqueeze(1).transpose(2, 3)
in_x_h = in_x.transpose(1, 2).unsqueeze(2)
in_dt_h = in_dt.unsqueeze(2)
in_A_h = in_A.unsqueeze(0).unsqueeze(1).transpose(2, 3)
in_B_h = in_B.unsqueeze(3)
in_C_h = in_C.unsqueeze(3)
in_D_h = in_D.unsqueeze(0).unsqueeze(1).unsqueeze(2)
in_z_h = in_z.transpose(1, 2).unsqueeze(2)

if in_state.dtype == torch.float:
state_out_h = torch.ops.custom_op.custom_pscan(in_state_h, in_x_h, in_dt_h, in_A_h, in_B_h)
output_h = torch.ops.custom_op.custom_pscan_update(state_out_h, in_x_h, in_C_h, in_D_h, in_z_h)

else:
in_A_h = in_A_h.to(torch.bfloat16)
state_out_h = torch.ops.custom_op.custom_pscan_bf16(in_state_h, in_x_h, in_dt_h, in_A_h, in_B_h)
output_h = torch.ops.custom_op.custom_pscan_update_bf16(state_out_h, in_x_h, in_C_h, in_D_h, in_z_h)

output_hpu = output_h.squeeze(2).transpose(1, 2)
state_hpu = state_out_h.transpose(2, 3)
state_out = torch.select(state_hpu, 1, output_hpu.shape[2] - 1)

return output_hpu, state_out


def gaudi_MambaCache_update_conv_state(
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor:
conv_state = self.conv_states[layer_idx]
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)

conv_state = conv_state.roll(shifts=-1, dims=-1)
# conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
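# Write each cached position individually instead of using the single tensor-indexed assignment above.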
for c, i in enumerate(cache_position):
conv_state[:, :, i] = new_conv_state[:, :, c].to(conv_state.device)

self.conv_states[layer_idx].zero_()
self.conv_states[layer_idx] += conv_state
return self.conv_states[layer_idx]


def gaudi_MambaForCausalLM_update_model_kwargs_for_generation(
self, outputs: ModelOutput, model_kwargs: Dict[str, Any], num_new_tokens: int = 1, **kwargs
@@ -94,3 +165,160 @@ def gaudi_MambaForCausalLM_prepare_inputs_for_generation(
}
)
return model_inputs


class gaudi_MambaMixer(nn.Module):
"""
Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
and is why Mamba is called **selective** state spaces)
We only replace the slow path with a custom op.
"""

def __init__(self, config: MambaConfig, layer_idx: int):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.intermediate_size = config.intermediate_size
self.time_step_rank = int(config.time_step_rank)
self.layer_idx = layer_idx
self.use_conv_bias = config.use_conv_bias
self.conv1d = nn.Conv1d(
in_channels=self.intermediate_size,
out_channels=self.intermediate_size,
bias=config.use_conv_bias,
kernel_size=config.conv_kernel,
groups=self.intermediate_size,
padding=config.conv_kernel - 1,
)

self.activation = config.hidden_act
self.act = ACT2FN[config.hidden_act]

self.use_mambapy = config.use_mambapy

# projection of the input hidden states
self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
# selective projection used to make dt, B and C input dependant
self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
# time step projection (discretization)
self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)

# S4D real initialization. These are not discretized!
# The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
A = A.expand(self.intermediate_size, -1).contiguous()

self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(self.intermediate_size))
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
self.use_bias = config.use_bias

if not is_fast_path_available:
logger.warning_once(
"The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
" is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and"
" https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
)

# fmt: off
def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.LongTensor] = None):
"""
We replace the 3.c and 3.d parts with the custom op "Run_Mamba_Forward_Gaudi", which removes the sequence-length loop and improves performance.
"""
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# 1. Gated MLP's linear projection
projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
hidden_states, gate = projected_states.chunk(2, dim=1)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 2. Convolution sequence transformation
if cache_params is not None:
ssm_state = cache_params.ssm_states[self.layer_idx].clone()
ssm_state = ssm_state.to(hidden_states.device)
# use `cache_position.shape[0]` to check whether we are in prefill
# stage, it's equivalent to check `cache_position[0] == 0`, which
# breaks dynamo fullgraph constraints
if cache_position.shape[0] == self.conv_kernel_size:
conv_state = nn.functional.pad(
hidden_states,
(self.conv_kernel_size - hidden_states.shape[-1], 0)
)

cache_params.update_conv_state(self.layer_idx, conv_state, cache_position)
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
else:
conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position)
hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
if self.use_conv_bias:
hidden_states += self.conv1d.bias
hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding
else:
ssm_state = torch.zeros(
(batch_size, self.intermediate_size, self.ssm_state_size),
device=hidden_states.device, dtype=dtype
)
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 3. State Space Model sequence transformation
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
time_step, B, C = torch.split(
ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
)
discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size]

# 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size]
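# Use the fused HPU pscan custom op when its library was found; otherwise fall back to the reference per-timestep scan below.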
if use_pscan_kernel:
scan_output, ssm_state = Run_Mamba_Forward_Gaudi(
ssm_state,
hidden_states,
discrete_time_step,
A,
B,
C,
self.D,
gate
)
else:
discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len]
discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # [batch, intermediate_size, seq_len, ssm_state_size]
discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() # [batch, intermediate_size, seq_len, ssm_state_size]
deltaB_u = discrete_B * hidden_states[:, :, :, None].float()

# 3.c perform the recurrence y ← SSM(A, B, C)(x)
scan_outputs = []
for i in range(seq_len):
ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediate_size, ssm_state]
scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediate_size, 1]
scan_outputs.append(scan_output[:, :, 0])
scan_output = torch.stack(scan_outputs, dim=-1) # [batch, seq_len, intermediate_size]
scan_output = scan_output + (hidden_states * self.D[None, :, None])
scan_output = (scan_output * self.act(gate))

if cache_params is not None:
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)

# 4. Final linear projection
contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
return contextualized_states
# fmt: on

def forward(
self,
hidden_states,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
29 changes: 29 additions & 0 deletions optimum/habana/transformers/models/mamba/util_mamba.py
@@ -0,0 +1,29 @@
import os

from huggingface_hub import hf_hub_download

from ....utils import get_habana_frameworks_version


def set_mamba_lib():
version_no = get_habana_frameworks_version()

name_op = "hpu_custom_pscan_all.cpython-310-x86_64-linux-gnu.so"
name_kernel = "libcustom_tpc_perf_lib.so"
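# Habana frameworks 1.19 ships dedicated binaries with a _119 suffix; they are downloaded under that name and renamed below.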
if version_no.minor == 19:
name_op = "hpu_custom_pscan_all.cpython-310-x86_64-linux-gnu_119.so"
name_kernel = "libcustom_tpc_perf_lib_119.so"

file_op = hf_hub_download(repo_id="Habana/mamba", filename=name_op)
file_kernel = hf_hub_download(repo_id="Habana/mamba", filename=name_kernel)

new_file_op = file_op
new_file_kernel = file_kernel

if version_no.minor == 19:
new_file_op = file_op[:-7] + ".so"
new_file_kernel = file_kernel[:-7] + ".so"
os.rename(file_op, new_file_op)
os.rename(file_kernel, new_file_kernel)

return new_file_op, new_file_kernel