diff --git a/server/lorax_server/models/__init__.py b/server/lorax_server/models/__init__.py index 0121b9314..d26e3d187 100644 --- a/server/lorax_server/models/__init__.py +++ b/server/lorax_server/models/__init__.py @@ -155,7 +155,7 @@ def get_model( compile=compile, trust_remote_code=trust_remote_code, ) - + if model_type == "gpt_neox": from lorax_server.models.flash_neox import FlashNeoXSharded @@ -167,7 +167,6 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - if model_type == "llama": from lorax_server.models.flash_llama import FlashLlama diff --git a/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py b/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py index dcb41e16a..8f3d98f7c 100644 --- a/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_cohere_modeling.py @@ -28,6 +28,7 @@ from transformers.activations import ACT2FN from lorax_server.adapters.weights import AdapterBatchData +from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.layers import ( FastLayerNorm, @@ -361,9 +362,9 @@ def forward(self, hidden_states, adapter_data): class FlashCohereLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"model.layers.{layer_id}" + prefix = prepend(prefix, f"model.layers.{layer_id}") self.self_attn = FlashCohereAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights, layer_id=layer_id ) @@ -416,16 +417,17 @@ def forward( class FlashCohereModel(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.embed_tokens = TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights) + self.embed_tokens = TensorParallelEmbedding(prefix=prepend(prefix, "model.embed_tokens"), weights=weights) self.layers = nn.ModuleList( [ FlashCohereLayer( + prefix, layer_id, config, weights, @@ -433,7 +435,9 @@ def __init__(self, config, weights): for layer_id in range(config.num_hidden_layers) ] ) - self.norm = FastLayerNorm.load_no_bias(prefix="model.norm", weights=weights, eps=config.layer_norm_eps) + self.norm = FastLayerNorm.load_no_bias( + prefix=prepend(prefix, "model.norm"), weights=weights, eps=config.layer_norm_eps + ) self.gradient_checkpointing = False @@ -481,21 +485,21 @@ def forward( class FlashCohereForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.config = config - self.model = FlashCohereModel(config, weights) + self.model = FlashCohereModel(prefix, config, weights) try: lm_head = TensorParallelHead.load( config, - prefix="lm_head", + prefix=prepend(prefix, "lm_head"), weights=weights, ) except RuntimeError: lm_head = TensorParallelHead.load( config, - prefix="model.embed_tokens", + prefix=prepend(prefix, "model.embed_tokens"), weights=weights, ) self.lm_head = MultiAdapterHead.load( diff --git a/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py b/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py index b6ac0c78c..db79bcdf4 100644 --- a/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py +++ 
b/server/lorax_server/models/custom_modeling/flash_dbrx_modeling.py @@ -24,6 +24,7 @@ from transformers.configuration_utils import PretrainedConfig from lorax_server.adapters.weights import AdapterBatchData +from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.layers import ( FastLayerNorm, @@ -870,9 +871,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class DbrxLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"transformer.blocks.{layer_id}" + prefix = prepend(prefix, f"transformer.blocks.{layer_id}") self.attn = DbrxNormAttentionNorm( prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights, layer_id=layer_id @@ -916,14 +917,15 @@ def forward( class DbrxModel(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() - self.embed_tokens = TensorParallelEmbedding(prefix="transformer.wte", weights=weights) + self.embed_tokens = TensorParallelEmbedding(prefix=prepend(prefix, "transformer.wte"), weights=weights) self.layers = nn.ModuleList( [ DbrxLayer( + prefix, layer_id, config, weights, @@ -931,7 +933,7 @@ def __init__(self, config, weights): for layer_id in range(config.n_layers) ] ) - self.norm = FastLayerNorm.load_no_bias(prefix="transformer.norm_f", weights=weights, eps=1e-5) + self.norm = FastLayerNorm.load_no_bias(prefix=prepend(prefix, "transformer.norm_f"), weights=weights, eps=1e-5) self.head_size = self.layers[0].attn.self_attn.head_size self.num_heads = self.layers[0].attn.self_attn.num_heads @@ -977,15 +979,15 @@ def forward( class FlashDbrxForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.config = config - self.model = DbrxModel(config, weights) + self.model = DbrxModel(prefix, config, weights) self.lm_head = MultiAdapterHead.load( TensorParallelHead.load( config, - prefix="lm_head", + prefix=prepend(prefix, "lm_head"), weights=weights, ), 0, diff --git a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py index 5f2eb6729..ef3328459 100644 --- a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py @@ -23,6 +23,7 @@ # Flash attention imports from lorax_server.adapters import AdapterBatchData +from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.layers import ( PositionRotaryEmbedding, @@ -361,10 +362,10 @@ def forward(self, hidden_states, adapter_data): class GemmaDecoderLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"model.layers.{layer_id}" + prefix = prepend(prefix, f"model.layers.{layer_id}") self.self_attn = GemmaAttention( prefix=f"{prefix}.self_attn", config=config, @@ -424,16 +425,17 @@ def forward( class GemmaModel(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.embed_tokens = 
TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights) + self.embed_tokens = TensorParallelEmbedding(prefix=prepend(prefix, "model.embed_tokens"), weights=weights) self.layers = nn.ModuleList( [ GemmaDecoderLayer( + prefix, layer_id, config, weights, @@ -441,7 +443,7 @@ def __init__(self, config, weights): for layer_id in range(config.num_hidden_layers) ] ) - self.norm = GemmaRMSNorm(prefix="model.norm", weights=weights, eps=config.rms_norm_eps) + self.norm = GemmaRMSNorm(prefix=prepend(prefix, "model.norm"), weights=weights, eps=config.rms_norm_eps) self.hidden_size = config.hidden_size self.head_size = self.layers[0].self_attn.head_size @@ -491,11 +493,11 @@ def forward( class GemmaForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.config = config - self.model = GemmaModel(config, weights) + self.model = GemmaModel(prefix, config, weights) self.embed_t = self.model.embed_tokens.weight.T.contiguous() self.vocab_size = config.vocab_size diff --git a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py index e5a4eb09e..1845b3458 100644 --- a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py @@ -28,6 +28,7 @@ from transformers.models.gpt2 import GPT2Config from lorax_server.adapters import AdapterBatchData +from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.layers import ( FastLayerNorm, @@ -291,18 +292,18 @@ class FlashGPT2PreTrainedModel(PreTrainedModel): class FlashGPT2Model(FlashGPT2PreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) self.config = config self.embed_dim = config.hidden_size - self.wte = TensorParallelEmbedding(prefix="wte", weights=weights) - self.wpe = TensorParallelEmbedding(prefix="wpe", weights=weights) + self.wte = TensorParallelEmbedding(prefix=prepend(prefix, "wte"), weights=weights) + self.wpe = TensorParallelEmbedding(prefix=prepend(prefix, "wpe"), weights=weights) self.h = nn.ModuleList([GPT2Block(layer_id, config, weights) for layer_id in range(config.num_hidden_layers)]) self.ln_f = FastLayerNorm.load( - prefix="ln_f", + prefix=prepend(prefix, "ln_f"), weights=weights, eps=config.layer_norm_epsilon, ) @@ -346,10 +347,10 @@ def forward( class FlashGPT2ForCausalLM(FlashGPT2PreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) self.config = config - self.transformer = FlashGPT2Model(config, weights) + self.transformer = FlashGPT2Model(prefix, config, weights) self.wte_t = self.transformer.wte.weight.T.contiguous() def forward( diff --git a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py index 5faf75588..0455788a0 100644 --- a/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py @@ -31,6 +31,7 @@ from transformers.configuration_utils import PretrainedConfig from lorax_server.adapters import AdapterBatchData +from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.flash_attn 
import HAS_FLASH_ATTN_V2_CUDA from lorax_server.utils.layers import ( @@ -795,9 +796,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MixtralLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"model.layers.{layer_id}" + prefix = prepend(prefix, f"model.layers.{layer_id}") self.self_attn = MixtralAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights, layer_id=layer_id @@ -855,14 +856,15 @@ def forward( class MixtralModel(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() - self.embed_tokens = TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights) + self.embed_tokens = TensorParallelEmbedding(prefix=prepend(prefix, "model.embed_tokens"), weights=weights) self.layers = nn.ModuleList( [ MixtralLayer( + prefix, layer_id, config, weights, @@ -870,7 +872,7 @@ def __init__(self, config, weights): for layer_id in range(config.num_hidden_layers) ] ) - self.norm = MixtralRMSNorm(prefix="model.norm", weights=weights, eps=config.rms_norm_eps) + self.norm = MixtralRMSNorm(prefix=prepend(prefix, "model.norm"), weights=weights, eps=config.rms_norm_eps) self.head_size = self.layers[0].self_attn.head_size self.num_heads = self.layers[0].self_attn.num_heads @@ -918,15 +920,15 @@ def forward( class FlashMixtralForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.config = config - self.model = MixtralModel(config, weights) + self.model = MixtralModel(prefix, config, weights) self.lm_head = MultiAdapterHead.load( TensorParallelHead.load( config, - prefix="lm_head", + prefix=prepend(prefix, "lm_head"), weights=weights, ), 0, diff --git a/server/lorax_server/models/custom_modeling/flash_neox_modeling.py b/server/lorax_server/models/custom_modeling/flash_neox_modeling.py index eb832e12f..0b0ff025f 100644 --- a/server/lorax_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_neox_modeling.py @@ -27,6 +27,7 @@ from transformers.modeling_utils import PreTrainedModel from transformers.models.gpt_neox import GPTNeoXConfig +from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.layers import ( FastLayerNorm, @@ -190,12 +191,12 @@ def forward(self, hidden_states): class FlashNeoXLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() layer_norm_eps = config.layer_norm_eps - prefix = f"gpt_neox.layers.{layer_id}" + prefix = prepend(prefix, f"gpt_neox.layers.{layer_id}") self.use_parallel_residual = config.use_parallel_residual self.input_layernorm = FastLayerNorm.load( @@ -278,17 +279,17 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel): class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) self.config = config - self.embed_in = TensorParallelEmbedding(prefix="gpt_neox.embed_in", weights=weights) + self.embed_in = TensorParallelEmbedding(prefix=prepend(prefix, "gpt_neox.embed_in"), weights=weights) self.layers = nn.ModuleList( - [FlashNeoXLayer(layer_id, config, weights) for layer_id in range(config.num_hidden_layers)] + [FlashNeoXLayer(prefix, 
layer_id, config, weights) for layer_id in range(config.num_hidden_layers)] ) self.final_layer_norm = FastLayerNorm.load( - prefix="gpt_neox.final_layer_norm", + prefix=prepend(prefix, "gpt_neox.final_layer_norm"), weights=weights, eps=config.layer_norm_eps, ) @@ -336,12 +337,12 @@ def forward( class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) self.config = config - self.gpt_neox = FlashGPTNeoXModel(config, weights) + self.gpt_neox = FlashGPTNeoXModel(prefix, config, weights) - self.embed_out = TensorParallelHead.load(config, prefix="embed_out", weights=weights) + self.embed_out = TensorParallelHead.load(config, prefix=prepend(prefix, "embed_out"), weights=weights) def forward( self, diff --git a/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py b/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py index 1defda005..b0c10b9bb 100644 --- a/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_phi3_modeling.py @@ -29,6 +29,7 @@ from transformers.configuration_utils import PretrainedConfig from lorax_server.adapters import AdapterBatchData +from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.layers import ( MultiAdapterHead, @@ -352,9 +353,9 @@ def forward(self, hidden_states, adapter_data): class FlashPhi3Layer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"model.layers.{layer_id}" + prefix = prepend(prefix, f"model.layers.{layer_id}") self.self_attn = FlashPhi3Attention( prefix=f"{prefix}.self_attn", config=config, @@ -409,16 +410,17 @@ def forward( class FlashPhi3Model(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.embed_tokens = TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights) + self.embed_tokens = TensorParallelEmbedding(prefix=prepend(prefix, "model.embed_tokens"), weights=weights) self.layers = nn.ModuleList( [ FlashPhi3Layer( + prefix, layer_id, config, weights, @@ -426,7 +428,7 @@ def __init__(self, config, weights): for layer_id in range(config.num_hidden_layers) ] ) - self.norm = Phi3RMSNorm(prefix="model.norm", weights=weights, eps=config.rms_norm_eps) + self.norm = Phi3RMSNorm(prefix=prepend(prefix, "model.norm"), weights=weights, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -474,15 +476,15 @@ def forward( class FlashPhi3ForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.config = config - self.model = FlashPhi3Model(config, weights) + self.model = FlashPhi3Model(prefix, config, weights) self.lm_head = MultiAdapterHead.load( TensorParallelHead.load( config, - prefix="lm_head", + prefix=prepend(prefix, "lm_head"), weights=weights, ), 0, diff --git a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py index 108b7ce5a..16610de75 100644 --- a/server/lorax_server/models/custom_modeling/flash_phi_modeling.py +++ 
b/server/lorax_server/models/custom_modeling/flash_phi_modeling.py @@ -17,6 +17,7 @@ from transformers.activations import ACT2FN from lorax_server.adapters import AdapterBatchData +from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.layers import ( FastLayerNorm, @@ -235,9 +236,9 @@ def forward(self, hidden_states, adapter_data): class FlashPhiLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"model.layers.{layer_id}" + prefix = prepend(prefix, f"model.layers.{layer_id}") self.input_layernorm = FastLayerNorm.load( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.layer_norm_eps @@ -290,16 +291,17 @@ def forward( class FlashPhiModel(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.embed_tokens = TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights) + self.embed_tokens = TensorParallelEmbedding(prefix=prepend(prefix, "model.embed_tokens"), weights=weights) self.layers = nn.ModuleList( [ FlashPhiLayer( + prefix, layer_id, config, weights, @@ -308,7 +310,7 @@ def __init__(self, config, weights): ] ) self.final_layernorm = FastLayerNorm.load( - prefix="model.final_layernorm", weights=weights, eps=config.layer_norm_eps + prefix=prepend(prefix, "model.final_layernorm"), weights=weights, eps=config.layer_norm_eps ) self.gradient_checkpointing = False @@ -356,15 +358,15 @@ def forward( class FlashPhiForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.config = config - self.model = FlashPhiModel(config, weights) + self.model = FlashPhiModel(prefix, config, weights) self.lm_head = MultiAdapterHead.load( TensorParallelHead.load( config, - prefix="lm_head", + prefix=prepend(prefix, "lm_head"), weights=weights, ), 0, diff --git a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py index 37d763769..f66404e6d 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py @@ -14,6 +14,7 @@ from transformers.activations import ACT2FN from lorax_server.adapters import AdapterBatchData +from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.layers import ( MultiAdapterHead, @@ -315,9 +316,9 @@ def forward(self, hidden_states, adapter_data): class FlashQwen2Layer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"model.layers.{layer_id}" + prefix = prepend(prefix, f"model.layers.{layer_id}") self.self_attn = FlashQwen2Attention( prefix=f"{prefix}.self_attn", config=config, @@ -376,16 +377,17 @@ def forward( class FlashQwen2Model(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.embed_tokens = 
TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights) + self.embed_tokens = TensorParallelEmbedding(prefix=prepend(prefix, "model.embed_tokens"), weights=weights) self.layers = nn.ModuleList( [ FlashQwen2Layer( + prefix, layer_id, config, weights, @@ -393,7 +395,7 @@ def __init__(self, config, weights): for layer_id in range(config.num_hidden_layers) ] ) - self.norm = Qwen2RMSNorm(prefix="model.norm", weights=weights, eps=config.rms_norm_eps) + self.norm = Qwen2RMSNorm(prefix=prepend(prefix, "model.norm"), weights=weights, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -442,15 +444,15 @@ def forward( class FlashQwen2ForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.config = config - self.model = FlashQwen2Model(config, weights) + self.model = FlashQwen2Model(prefix, config, weights) self.lm_head = MultiAdapterHead.load( TensorParallelHead.load( config, - prefix="lm_head", + prefix=prepend(prefix, "lm_head"), weights=weights, ), 0, @@ -503,14 +505,14 @@ def forward( class FlashQwen2ForEmbeddings(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.config = config - self.model = FlashQwen2Model(config, weights) + self.model = FlashQwen2Model(prefix, config, weights) self.max_past = config.sliding_window - self.output_weight = weights.get_tensor("linear.weight") - self.output_bias = weights.get_tensor("linear.bias") + self.output_weight = weights.get_tensor(prepend(prefix, "linear.weight")) + self.output_bias = weights.get_tensor(prepend(prefix, "linear.bias")) # To satisfy the parent class interface # TODO: fix self.lm_head = None diff --git a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py index 13922e984..57f3cce21 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py @@ -15,6 +15,7 @@ from transformers.configuration_utils import PretrainedConfig from lorax_server.adapters import AdapterBatchData +from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.layers import ( MultiAdapterHead, @@ -328,9 +329,9 @@ def forward(self, hidden_states, adapter_data): class FlashQwenLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"transformer.h.{layer_id}" + prefix = prepend(prefix, f"transformer.h.{layer_id}") self.attn = FlashQwenAttention( prefix=f"{prefix}.attn", config=config, @@ -385,16 +386,17 @@ def forward( class FlashQwenModel(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.wte = TensorParallelEmbedding(prefix="transformer.wte", weights=weights) + self.wte = TensorParallelEmbedding(prefix=prepend(prefix, "transformer.wte"), weights=weights) self.h = nn.ModuleList( [ FlashQwenLayer( + prefix, layer_id, config, weights, @@ -402,7 +404,9 @@ def __init__(self, config, weights): for layer_id in range(config.num_hidden_layers) ] ) - self.ln_f = QwenRMSNorm(prefix="transformer.ln_f", weights=weights, 
eps=config.layer_norm_epsilon) + self.ln_f = QwenRMSNorm( + prefix=prepend(prefix, "transformer.ln_f"), weights=weights, eps=config.layer_norm_epsilon + ) self.gradient_checkpointing = False @@ -450,15 +454,15 @@ def forward( class FlashQwenForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.config = config - self.transformer = FlashQwenModel(config, weights) + self.transformer = FlashQwenModel(prefix, config, weights) self.lm_head = MultiAdapterHead.load( TensorParallelHead.load( config, - prefix="lm_head", + prefix=prepend(prefix, "lm_head"), weights=weights, ), 0, diff --git a/server/lorax_server/models/custom_modeling/flash_rw_modeling.py b/server/lorax_server/models/custom_modeling/flash_rw_modeling.py index 6a8f2dc24..f1fff63ea 100644 --- a/server/lorax_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_rw_modeling.py @@ -6,6 +6,7 @@ from transformers.configuration_utils import PretrainedConfig from transformers.modeling_utils import PreTrainedModel +from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.layers import ( FastLayerNorm, @@ -336,6 +337,7 @@ def forward(self, hidden_states): class FlashRWLayer(nn.Module): def __init__( self, + prefix: str, layer_id, config, weights, @@ -345,7 +347,7 @@ def __init__( parallel_attn = config.parallel_attn self.parallel_attn = parallel_attn - prefix = f"transformer.h.{layer_id}" + prefix = prepend(prefix, f"transformer.h.{layer_id}") self.input_layernorm = FastLayerNorm.load( prefix=f"{prefix}.input_layernorm", @@ -433,9 +435,9 @@ def forward( class FlashRWLargeLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"transformer.h.{layer_id}" + prefix = prepend(prefix, f"transformer.h.{layer_id}") self.ln_attn = FastLayerNorm.load( prefix=f"{prefix}.ln_attn", weights=weights, @@ -503,25 +505,27 @@ class FlashRWPreTrainedModel(PreTrainedModel): class FlashRWModel(FlashRWPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) self.config = config - self.word_embeddings = TensorParallelEmbedding(prefix="transformer.word_embeddings", weights=weights) + self.word_embeddings = TensorParallelEmbedding( + prefix=prepend(prefix, "transformer.word_embeddings"), weights=weights + ) if config.new_decoder_architecture: self.h = nn.ModuleList( - [FlashRWLargeLayer(layer_id, config, weights) for layer_id in range(config.num_hidden_layers)] + [FlashRWLargeLayer(prefix, layer_id, config, weights) for layer_id in range(config.num_hidden_layers)] ) self.cache_size = self.h[0].self_attention.num_groups else: self.h = nn.ModuleList( - [FlashRWLayer(layer_id, config, weights) for layer_id in range(config.num_hidden_layers)] + [FlashRWLayer(prefix, layer_id, config, weights) for layer_id in range(config.num_hidden_layers)] ) self.cache_size = self.h[0].self_attention.num_heads_kv self.ln_f = FastLayerNorm.load( - prefix="transformer.ln_f", + prefix=prepend(prefix, "transformer.ln_f"), weights=weights, eps=config.layer_norm_epsilon, ) @@ -567,13 +571,13 @@ def forward( class FlashRWForCausalLM(FlashRWPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) self.config = config 
- self.transformer = FlashRWModel(config, weights) + self.transformer = FlashRWModel(prefix, config, weights) - self.lm_head = TensorParallelHead.load(config, prefix="lm_head", weights=weights) + self.lm_head = TensorParallelHead.load(config, prefix=prepend(prefix, "lm_head"), weights=weights) def forward( self, diff --git a/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py b/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py index 538d42402..7efa57536 100644 --- a/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_santacoder_modeling.py @@ -5,6 +5,7 @@ from torch import nn from transformers.activations import ACT2FN +from lorax_server.models.custom_modeling.utils import prepend from lorax_server.utils import flash_attn, paged_attention from lorax_server.utils.layers import ( FastLayerNorm, @@ -288,9 +289,9 @@ def forward(self, hidden_states): class Block(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"transformer.h.{layer_id}" + prefix = prepend(prefix, f"transformer.h.{layer_id}") self.ln_1 = FastLayerNorm.load(prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon) self.ln_2 = FastLayerNorm.load(prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon) self.attn = FlashMQAttention( @@ -334,18 +335,18 @@ def forward( class FlashSantacoderModel(nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.config = config self.process_group = weights.process_group self.wte = TensorParallelEmbedding( - prefix="transformer.wte", + prefix=prepend(prefix, "transformer.wte"), weights=weights, reduce=False, ) self.wpe = TensorParallelEmbedding( - prefix="transformer.wpe", + prefix=prepend(prefix, "transformer.wpe"), weights=weights, reduce=False, ) @@ -353,6 +354,7 @@ def __init__(self, config, weights): self.h = nn.ModuleList( [ Block( + prefix, layer_id, config, weights, @@ -360,7 +362,9 @@ def __init__(self, config, weights): for layer_id in range(config.num_hidden_layers) ] ) - self.ln_f = FastLayerNorm.load(prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon) + self.ln_f = FastLayerNorm.load( + prefix=prepend(prefix, "transformer.ln_f"), weights=weights, eps=config.layer_norm_epsilon + ) self.head_size = self.h[0].attn.head_size self.num_heads = self.h[0].attn.num_heads @@ -400,11 +404,11 @@ def forward( class FlashSantacoderForCausalLM(nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.config = config - self.transformer = FlashSantacoderModel(config, weights) - self.lm_head = TensorParallelHead.load(config, prefix="transformer.wte", weights=weights) + self.transformer = FlashSantacoderModel(prefix, config, weights) + self.lm_head = TensorParallelHead.load(config, prefix=prepend(prefix, "transformer.wte"), weights=weights) def forward( self, diff --git a/server/lorax_server/models/custom_modeling/utils.py b/server/lorax_server/models/custom_modeling/utils.py new file mode 100644 index 000000000..7f24143d7 --- /dev/null +++ b/server/lorax_server/models/custom_modeling/utils.py @@ -0,0 +1,2 @@ +def prepend(prefix: str, path: str) -> str: + return f"{prefix}.{path}" if prefix else path diff --git a/server/lorax_server/models/flash_causal_lm.py 
b/server/lorax_server/models/flash_causal_lm.py index b532d3519..73fcfea1b 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -871,11 +871,7 @@ def __init__( num_kv_heads = getattr(config, "n_head", None) if num_kv_heads is None: raise ValueError("Cannot get the number of key/value heads") - self.num_kv_heads = ( - num_kv_heads // self.process_group.size() - if num_kv_heads > 1 - else num_kv_heads - ) + self.num_kv_heads = num_kv_heads // self.process_group.size() if num_kv_heads > 1 else num_kv_heads assert self.num_kv_heads > 0 if head_size is None: diff --git a/server/lorax_server/models/vlm_causal_lm.py b/server/lorax_server/models/vlm_causal_lm.py index 6394cde8b..8b03e1b8d 100644 --- a/server/lorax_server/models/vlm_causal_lm.py +++ b/server/lorax_server/models/vlm_causal_lm.py @@ -255,7 +255,7 @@ def __init__( ): if PREFIX_CACHING: raise NotImplementedError("Vlm do not work with prefix caching yet") - + if processor_kwargs is None: processor_kwargs = {} diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index 0bafc1559..169ca543c 100644 --- a/server/lorax_server/utils/graph.py +++ b/server/lorax_server/utils/graph.py @@ -550,7 +550,9 @@ def forward( key = (batch_size, max_rank) graph = self.cache.get(key) if graph is None or not graph.input_state.traced_adapter_layer_names.issuperset(adapter_data.layer_names()): - current_traced_adapter_layer_names = graph.input_state.traced_adapter_layer_names if graph is not None else set() + current_traced_adapter_layer_names = ( + graph.input_state.traced_adapter_layer_names if graph is not None else set() + ) logger.info( "Retrace graph with new adapter layers: {} -> {}", current_traced_adapter_layer_names, diff --git a/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h b/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h index ecc42bcb9..ac2a66d92 100644 --- a/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h +++ b/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h @@ -10,7 +10,6 @@ void bgmv_kernel(T *__restrict__ Y, const T *__restrict__ X, // clang-format off #define FOR_BGMV_WIDE(f, T, narrow) \ - f(T, narrow, 128) \ f(T, narrow, 256) \ f(T, narrow, 512) \ f(T, narrow, 640) \ @@ -42,6 +41,7 @@ void bgmv_kernel(T *__restrict__ Y, const T *__restrict__ X, f(T, narrow, 8192) \ f(T, narrow, 8960) \ f(T, narrow, 9216) \ + f(T, narrow, 9472) \ f(T, narrow, 10240) \ f(T, narrow, 11008) \ f(T, narrow, 12288) \
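
Notes on the patch, with small illustrative sketches. The `prepend` body below is quoted from the new `server/lorax_server/models/custom_modeling/utils.py`; everything else in the sketches is illustrative and not taken verbatim from the repository.

The core mechanism is the new `prepend` helper, which the touched modeling files now use to build weight names. Each `__init__` gains a `prefix: str` argument: when the model is loaded standalone the prefix is empty and the original tensor names are produced unchanged, and when it is embedded under a wrapper checkpoint the caller's prefix is joined in front. The per-architecture suffixes (`model.layers.{layer_id}`, `transformer.h.{layer_id}`, `lm_head`, and so on) stay inside each modeling file, so only the outer namespace is threaded through.

```python
def prepend(prefix: str, path: str) -> str:
    # Join a caller-supplied prefix onto a weight path; an empty prefix
    # leaves the path untouched, so standalone checkpoints keep working.
    return f"{prefix}.{path}" if prefix else path


# Standalone model: empty prefix, tensor names unchanged.
assert prepend("", "model.embed_tokens") == "model.embed_tokens"

# Embedded model (hypothetical "language_model" wrapper prefix): the same
# layer weights are looked up under the wrapper's namespace.
assert prepend("language_model", "model.layers.0.self_attn") == "language_model.model.layers.0.self_attn"
```

The `flash_causal_lm.py` hunk only reflows an existing expression, but the guard it preserves is easy to miss: key/value heads are divided across the tensor-parallel group except in the multi-query case, where the single KV head is replicated on every rank rather than split. A sketch of that rule in isolation, with `world_size` standing in for `self.process_group.size()`:

```python
def shard_kv_heads(num_kv_heads: int, world_size: int) -> int:
    # Split KV heads across tensor-parallel ranks, but never shard a single
    # (multi-query) KV head -- it is replicated on every rank instead.
    sharded = num_kv_heads // world_size if num_kv_heads > 1 else num_kv_heads
    assert sharded > 0
    return sharded


assert shard_kv_heads(8, 4) == 2  # grouped-query attention: 2 heads per rank
assert shard_kv_heads(1, 4) == 1  # multi-query attention: replicated
```

Similarly, the `graph.py` hunk only rewraps a long expression, but the condition it feeds decides when a cached CUDA graph must be retraced: a cached graph is reused only if the adapter layer names it was traced with form a superset of the layers referenced by the incoming batch. A reduced sketch of that decision (the objects are simplified stand-ins, not the real classes):

```python
def needs_retrace(graph, batch_layer_names: set) -> bool:
    # Retrace when there is no cached graph for this (batch_size, max_rank)
    # key, or when the cached trace does not already cover every adapter
    # layer referenced by the incoming batch.
    if graph is None:
        return True
    return not graph.input_state.traced_adapter_layer_names.issuperset(batch_layer_names)
```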