feat: support baichuan2
jimpang committed Nov 21, 2023
1 parent 338bce6 commit 23f06a9
Showing 3 changed files with 113 additions and 51 deletions.
10 changes: 10 additions & 0 deletions vllm/model_executor/model_loader.py
@@ -17,6 +17,8 @@
"AquilaForCausalLM": AquilaForCausalLM, # AquilaChat2
"BaiChuanForCausalLM": BaiChuanForCausalLM, # baichuan-7b
"BaichuanForCausalLM": BaichuanForCausalLM, # baichuan-13b
"BaiChuan2ForCausalLM": BaiChuan2ForCausalLM, # baichuan2-rope
"Baichuan2ForCausalLM": Baichuan2ForCausalLM, # baichuan2-alibi
"BloomForCausalLM": BloomForCausalLM,
"ChatGLMModel": ChatGLMForCausalLM,
"FalconForCausalLM": FalconForCausalLM,
@@ -52,6 +54,14 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", [])
for arch in architectures:
if arch in _MODEL_REGISTRY:
            # Baichuan 2 reuses the Baichuan architecture names but has a
            # larger vocabulary (125696 entries).
            if ("baichuan" in arch.lower()
                    and getattr(config, "vocab_size", None) == 125696):
                # The 7B and 13B variants differ in intermediate size.
                if getattr(config, "intermediate_size", None) == 11008:
                    return BaiChuan2ForCausalLM  # 7B (RoPE)
                elif getattr(config, "intermediate_size", None) == 13696:
                    return Baichuan2ForCausalLM  # 13B (ALiBi)
return _MODEL_REGISTRY[arch]
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
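Because a Baichuan 2 checkpoint can report the same `architectures` string as Baichuan 1, the loader falls back to config heuristics. A minimal sketch of that dispatch; `FakeConfig` and `pick_baichuan_class` are illustrative stand-ins, and only the 125696/11008/13696 thresholds come from the diff above:

```python
class FakeConfig:
    """Hypothetical stand-in for a Hugging Face PretrainedConfig."""

    def __init__(self, architectures, vocab_size, intermediate_size):
        self.architectures = architectures
        self.vocab_size = vocab_size
        self.intermediate_size = intermediate_size


def pick_baichuan_class(config) -> str:
    for arch in getattr(config, "architectures", []):
        if ("baichuan" in arch.lower()
                and getattr(config, "vocab_size", None) == 125696):
            # Baichuan 2: the 7B and 13B variants differ in intermediate size.
            if config.intermediate_size == 11008:
                return "BaiChuan2ForCausalLM"  # 7B, RoPE
            if config.intermediate_size == 13696:
                return "Baichuan2ForCausalLM"  # 13B, ALiBi
        return arch  # Baichuan 1 (registry lookup would happen here)
    raise ValueError("no supported architecture found")


# A Baichuan2-13B-style config resolves to the ALiBi variant:
print(pick_baichuan_class(
    FakeConfig(["BaichuanForCausalLM"], 125696, 13696)))
# -> Baichuan2ForCausalLM
```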
6 changes: 5 additions & 1 deletion vllm/model_executor/models/__init__.py
@@ -1,6 +1,8 @@
from vllm.model_executor.models.aquila import AquilaForCausalLM
from vllm.model_executor.models.baichuan import (BaiChuanForCausalLM,
BaichuanForCausalLM)
BaichuanForCausalLM,
BaiChuan2ForCausalLM,
Baichuan2ForCausalLM)
from vllm.model_executor.models.bloom import BloomForCausalLM
from vllm.model_executor.models.falcon import FalconForCausalLM
from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
@@ -21,6 +23,8 @@
"AquilaForCausalLM",
"BaiChuanForCausalLM",
"BaichuanForCausalLM",
"BaiChuan2ForCausalLM",
"Baichuan2ForCausalLM",
"BloomForCausalLM",
"ChatGLMForCausalLM",
"FalconForCausalLM",
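Both new classes are re-exported, so downstream code can import either spelling directly; note that the casing is significant:

```python
# BaiChuan2ForCausalLM is the RoPE/7B class, Baichuan2ForCausalLM the ALiBi/13B one.
from vllm.model_executor.models import (BaiChuan2ForCausalLM,
                                        Baichuan2ForCausalLM)
```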
148 changes: 98 additions & 50 deletions vllm/model_executor/models/baichuan.py
@@ -28,40 +28,41 @@
import torch
from torch import nn

from vllm.logger import init_logger
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
PagedAttentionWithALiBi)
from vllm.model_executor.layers.attention import (PagedAttentionWithALiBi, PagedAttentionWithRoPE)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
QKVParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.layers.vocab_parallel_embedding import (ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig

logger = init_logger(__name__)

KVCache = Tuple[torch.Tensor, torch.Tensor]


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads))
base = torch.tensor(
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32,
)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)

if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32,
)
num_remaining_heads = min(closest_power_of_2,
@@ -78,11 +79,11 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
class BaiChuanMLP(nn.Module):

def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
@@ -109,13 +110,13 @@ class BaiChuanAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(
self,
hidden_size: int,
num_heads: int,
position_embedding: str,
rope_theta: float = 10000,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
self,
hidden_size: int,
num_heads: int,
position_embedding: str,
rope_theta: float = 10000,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.hidden_size = hidden_size
@@ -153,11 +154,11 @@ def __init__(
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()

scaling = self.head_dim**-0.5
scaling = self.head_dim ** -0.5
self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
scaling, alibi_slopes)
else:
self.scaling = self.head_dim**-0.5
self.scaling = self.head_dim ** -0.5
self.attn = PagedAttentionWithRoPE(
self.num_heads,
self.head_dim,
@@ -167,12 +168,12 @@ def __init__(
max_position=self.max_position_embeddings)

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
@@ -219,13 +220,13 @@ def __init__(self,
eps=config.rms_norm_eps)

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
residual: Optional[torch.Tensor],
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
@@ -271,12 +272,12 @@ def __init__(self,
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
@@ -295,29 +296,62 @@ def forward(
return hidden_states


class NormHead(ColumnParallelLinear):
    """LM head whose weight rows are lazily L2-normalized (Baichuan 2)."""

    def __init__(self, hidden_size, vocab_size, bias=False):
        super().__init__(hidden_size,
                         vocab_size,
                         bias=False,
                         gather_output=False)
        # Normalize lazily, after the checkpoint weights have been loaded.
        self.first_flag = True

def get_weight(self):
if self.first_flag:
self.first_flag = False
self.weight = nn.Parameter(nn.functional.normalize(self.weight))
return self.weight

    def forward(self, hidden_states):
        # Reuse get_weight() so the one-time normalization happens in a
        # single place.
        self.get_weight()
        return ColumnParallelLinear.forward(self, hidden_states)


class BaiChuanBaseForCausalLM(nn.Module):

def __init__(self,
config,
position_embedding: str,
linear_method: Optional[LinearMethodBase] = None):
                 linear_method: Optional[LinearMethodBase] = None,
                 version: str = "1"):
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = BaiChuanModel(config, position_embedding, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.version = version
        if self.version == "1":
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.hidden_size)
        elif self.version == "2":
            self.lm_head = NormHead(config.hidden_size,
                                    config.vocab_size,
                                    bias=False)
        else:
            raise ValueError(
                f"Unsupported Baichuan version {version!r}; "
                "only '1' and '2' are supported.")

self.sampler = Sampler(config.vocab_size)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)

        lm_head_weight = self.lm_head.weight
        if self.version == "2":
            # NormHead: sample with the lazily normalized weight.
            lm_head_weight = self.lm_head.get_weight()
        next_tokens = self.sampler(lm_head_weight, hidden_states,
                                   input_metadata)
return next_tokens
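For reference, `NormHead` mirrors Baichuan 2's trick of L2-normalizing the output-projection rows once, lazily, so that logits become dot products against unit-norm vocabulary vectors. A single-GPU sketch of the idea (no tensor parallelism; `TinyNormHead` is illustrative, not the class above):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNormHead(nn.Module):
    """Single-GPU sketch of NormHead: rows are normalized once, lazily."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(vocab_size, hidden_size))
        self._normalized = False

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if not self._normalized:
            # One-time L2 normalization of each vocabulary row.
            with torch.no_grad():
                self.weight.copy_(F.normalize(self.weight, dim=1))
            self._normalized = True
        return hidden_states @ self.weight.t()


head = TinyNormHead(hidden_size=8, vocab_size=16)
logits = head(torch.randn(2, 8))
print(logits.shape)                            # torch.Size([2, 16])
print(torch.linalg.norm(head.weight, dim=1))   # all ~1.0
```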
@@ -365,3 +399,17 @@ def __init__(self,
config,
linear_method: Optional[LinearMethodBase] = None):
super().__init__(config, "ROPE", linear_method)


class Baichuan2ForCausalLM(BaiChuanBaseForCausalLM):  # baichuan2-13b

    def __init__(self,
                 config,
                 linear_method: Optional[LinearMethodBase] = None):
        logger.info("Initializing Baichuan2ForCausalLM (13B, ALiBi)")
        super().__init__(config, "ALIBI", linear_method, "2")


class BaiChuan2ForCausalLM(BaiChuanBaseForCausalLM):  # baichuan2-7b

    def __init__(self,
                 config,
                 linear_method: Optional[LinearMethodBase] = None):
        logger.info("Initializing BaiChuan2ForCausalLM (7B, RoPE)")
        super().__init__(config, "ROPE", linear_method, "2")

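End to end, a Baichuan 2 checkpoint should now load through the normal vLLM entry point. A usage sketch, assuming the Hugging Face repo id `baichuan-inc/Baichuan2-7B-Chat` (the 13B checkpoint would dispatch to the ALiBi class instead) and that the checkpoint's custom config/tokenizer code requires `trust_remote_code=True`:

```python
from vllm import LLM, SamplingParams

# Assumption: repo id for the 7B (RoPE) variant; resolves to
# BaiChuan2ForCausalLM via the vocab/intermediate-size heuristic.
llm = LLM(model="baichuan-inc/Baichuan2-7B-Chat", trust_remote_code=True)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
```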