Formatting
tgaddair committed Nov 14, 2024
1 parent ee24cf4 commit cbbe95a
Showing 6 changed files with 47 additions and 42 deletions.
4 changes: 3 additions & 1 deletion server/lorax_server/models/causal_lm.py
@@ -679,7 +679,9 @@ def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Option
                 else:
                     seed = None
 
-                generated_text = GeneratedText(output_text, stopping_criteria.current_tokens, reason, seed)
+                generated_text = GeneratedText(
+                    output_text, stopping_criteria.current_tokens, stopping_criteria.current_skipped, reason, seed
+                )
             else:
                 generated_text = None
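For context, the functional part of this hunk is that GeneratedText now receives stopping_criteria.current_skipped as an extra positional argument between the token count and the finish reason. Below is a minimal sketch of a container with that shape; the class body and field names (generated_tokens, skipped_tokens, ...) are assumptions for illustration, not the actual lorax_server definitions.

```python
# Hedged sketch: a GeneratedText-like container whose third positional field is
# the skipped-token count. Field names here are assumptions for illustration,
# not the actual lorax_server definitions.
from dataclasses import dataclass
from typing import Optional


@dataclass
class GeneratedText:
    text: str
    generated_tokens: int
    skipped_tokens: int  # filled from stopping_criteria.current_skipped in the diff above
    finish_reason: str
    seed: Optional[int]


# Example construction mirroring the new call shape (values are placeholders):
generated_text = GeneratedText("some output", 12, 1, "length", None)
print(generated_text.skipped_tokens)  # 1
```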
16 changes: 6 additions & 10 deletions server/lorax_server/models/custom_modeling/mllama.py
@@ -225,11 +225,7 @@ def __init__(self, *, prefix, config, weights, layer_id, model_type):
 
         out_size = fc1.linear.weight.shape[-1] * weights.process_group.size()
         self.fc1 = TensorParallelMultiAdapterLinear.load(
-            fc1,
-            layer_id,
-            [f'{model_type}_{FC1}'],
-            sizes=[out_size],
-            process_group=weights.process_group
+            fc1, layer_id, [f"{model_type}_{FC1}"], sizes=[out_size], process_group=weights.process_group
         )
         self.fc2 = TensorParallelAdapterRowLinear.load(
             TensorParallelRowLinear.load(
@@ -239,7 +235,7 @@ def __init__(self, *, prefix, config, weights, layer_id, model_type):
                 bias=True,
             ),
             layer_id,
-            f'{model_type}_{FC2}',
+            f"{model_type}_{FC2}",
             process_group=weights.process_group,
         )
 
@@ -261,7 +257,7 @@ def load_attention(config, prefix, weights, layer_id, model_type, head_dim, n_he
     return TensorParallelMultiAdapterLinear.load(
         base_layer,
         layer_id,
-        [f'{model_type}_{Q_PROJ}', f'{model_type}_{K_PROJ}', f'{model_type}_{V_PROJ}'],
+        [f"{model_type}_{Q_PROJ}", f"{model_type}_{K_PROJ}", f"{model_type}_{V_PROJ}"],
         sizes=[
             head_dim * n_head,
             head_dim * n_head_kv,
@@ -306,7 +302,7 @@ def __init__(self, *, prefix, config, weights, layer_id, model_type):
                 bias=False,
             ),
             layer_id,
-            f'{model_type}_{O_PROJ}',
+            f"{model_type}_{O_PROJ}",
             process_group=weights.process_group,
         )
 
@@ -557,15 +553,15 @@ def __init__(self, *, prefix, config, weights):
             weights=weights,
             is_gated=False,
             num_layers=config.num_hidden_layers,
-            model_type='VISION_TRANSFORMER',
+            model_type="VISION_TRANSFORMER",
         )
         self.global_transformer = MllamaVisionEncoder(
             prefix=f"{prefix}.global_transformer",
             config=config,
             weights=weights,
             is_gated=True,
             num_layers=config.num_global_layers,
-            model_type='VISION_GLOBAL_TRANSFORMER',
+            model_type="VISION_GLOBAL_TRANSFORMER",
         )
 
     def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
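Throughout this file the adapter layer names are built as f"{model_type}_{LAYER}", where the layer constants come from lorax_server.utils.lora (see the lora.py diff later in this commit). The sketch below shows how those prefixed names come out for the vision encoder; the q/k/v/o values are assumed ("q_proj", ...) for illustration, since only fc1/fc2 and a few other constants appear in this commit.

```python
# Sketch of how the f"{model_type}_{...}" pattern used above turns into adapter
# weight names. FC1/FC2 match the constants in lorax_server/utils/lora.py shown
# later in this commit; the q/k/v/o values are assumed ("q_proj", ...) for
# illustration only.
FC1, FC2 = "fc1", "fc2"
Q_PROJ, K_PROJ, V_PROJ, O_PROJ = "q_proj", "k_proj", "v_proj", "o_proj"  # assumed values


def vision_adapter_names(model_type: str) -> list:
    # Mirrors the names passed to TensorParallelMultiAdapterLinear.load and
    # TensorParallelAdapterRowLinear.load in the diff above.
    return [f"{model_type}_{name}" for name in (Q_PROJ, K_PROJ, V_PROJ, O_PROJ, FC1, FC2)]


print(vision_adapter_names("VISION_TRANSFORMER"))
# ['VISION_TRANSFORMER_q_proj', 'VISION_TRANSFORMER_k_proj', 'VISION_TRANSFORMER_v_proj',
#  'VISION_TRANSFORMER_o_proj', 'VISION_TRANSFORMER_fc1', 'VISION_TRANSFORMER_fc2']
```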
57 changes: 31 additions & 26 deletions server/lorax_server/models/mllama.py
@@ -24,6 +24,7 @@
 TEXT_ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ, LM_HEAD]
 VISION_ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, FC1, FC2]
 
+
 @dataclass
 class MllamaCausalLMBatch(VlmCausalLMBatch):
     image_indices: List[int] = 42
@@ -179,33 +180,34 @@ def from_pb(
 
 
 class MllamaCausalLM(VlmCausalLM):
-
     @property
     def supports_adapter_loading(self) -> bool:
         return True
 
     @property
     def adapter_layers(self) -> List[str]:
-        return [f'TEXT_{layer_type}' for layer_type in TEXT_ADAPTER_LAYERS] \
-            + [f'VISION_GLOBAL_TRANSFORMER_{layer_type}' for layer_type in VISION_ADAPTER_LAYERS] \
-            + [f'VISION_TRANSFORMER_{layer_type}' for layer_type in VISION_ADAPTER_LAYERS]
+        return (
+            [f"TEXT_{layer_type}" for layer_type in TEXT_ADAPTER_LAYERS]
+            + [f"VISION_GLOBAL_TRANSFORMER_{layer_type}" for layer_type in VISION_ADAPTER_LAYERS]
+            + [f"VISION_TRANSFORMER_{layer_type}" for layer_type in VISION_ADAPTER_LAYERS]
+        )
 
     @property
     def default_traced_adapter_layers(self) -> List[str]:
         return [Q_PROJ, V_PROJ]
 
     def get_num_layers_for_type(self, layer_type: str) -> int:
-        if 'LM_HEAD' in layer_type:
+        if "LM_HEAD" in layer_type:
             return 1
-        if 'TEXT_' in layer_type:
+        if "TEXT_" in layer_type:
             return [
                 layer_id
                 for layer_id, layer in enumerate(self.model.text_model.model.layers)
-                if not isinstance(layer, FlashLlamaCrossLayer)
+                if not isinstance(layer, FlashLlamaCrossLayer)
             ]
-        if 'VISION_GLOBAL_TRANSFORMER_' in layer_type:
+        if "VISION_GLOBAL_TRANSFORMER_" in layer_type:
             return len(self.model.vision_model.global_transformer.layers)
-        if 'VISION_TRANSFORMER_' in layer_type:
+        if "VISION_TRANSFORMER_" in layer_type:
             return len(self.model.vision_model.transformer.layers)
 
     def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
@@ -215,51 +217,54 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
         for i, layer in enumerate(self.model.text_model.model.layers):
             if isinstance(layer, FlashLlamaCrossLayer):
                 continue
-            layer_weights[(i, f'TEXT_{Q_PROJ}')] = (
+            layer_weights[(i, f"TEXT_{Q_PROJ}")] = (
                 f"{prefix}.{i}.self_attn.q_proj",
                 layer.self_attn.query_key_value,
             )
-            layer_weights[(i, f'TEXT_{K_PROJ}')] = (
+            layer_weights[(i, f"TEXT_{K_PROJ}")] = (
                 f"{prefix}.{i}.self_attn.k_proj",
                 layer.self_attn.query_key_value,
             )
-            layer_weights[(i, f'TEXT_{V_PROJ}')] = (
+            layer_weights[(i, f"TEXT_{V_PROJ}")] = (
                 f"{prefix}.{i}.self_attn.v_proj",
                 layer.self_attn.query_key_value,
            )
-            layer_weights[(i, f'TEXT_{O_PROJ}')] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj)
-
-            layer_weights[(i, f'TEXT_{GATE_PROJ}')] = (f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj)
-            layer_weights[(i, f'TEXT_{UP_PROJ}')] = (f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj)
-            layer_weights[(i, f'TEXT_{DOWN_PROJ}')] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj)
-        layer_weights[(0, f'TEXT_{LM_HEAD}')] = ("base_model.model.language_model.lm_head", self.model.text_model.lm_head)
+            layer_weights[(i, f"TEXT_{O_PROJ}")] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj)
+
+            layer_weights[(i, f"TEXT_{GATE_PROJ}")] = (f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj)
+            layer_weights[(i, f"TEXT_{UP_PROJ}")] = (f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj)
+            layer_weights[(i, f"TEXT_{DOWN_PROJ}")] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj)
+        layer_weights[(0, f"TEXT_{LM_HEAD}")] = (
+            "base_model.model.language_model.lm_head",
+            self.model.text_model.lm_head,
+        )
 
         vision_layer_mappings = [
             ("vision_model.global_transformer.layers", self.model.vision_model.global_transformer.layers),
             ("vision_model.transformer.layers", self.model.vision_model.transformer.layers),
         ]
         for prefix, layer_list in vision_layer_mappings:
-            layer_type_prefix = 'VISION_GLOBAL_TRANSFORMER' if 'global_transformer' in prefix else 'VISION_TRANSFORMER'
+            layer_type_prefix = "VISION_GLOBAL_TRANSFORMER" if "global_transformer" in prefix else "VISION_TRANSFORMER"
             for i, layer in enumerate(layer_list):
-                layer_weights[(i, f'{layer_type_prefix}_{Q_PROJ}')] = (
+                layer_weights[(i, f"{layer_type_prefix}_{Q_PROJ}")] = (
                     f"{prefix}.{i}.self_attn.q_proj",
                     layer.self_attn.qkv_proj,
                 )
-                layer_weights[(i, f'{layer_type_prefix}_{K_PROJ}')] = (
+                layer_weights[(i, f"{layer_type_prefix}_{K_PROJ}")] = (
                     f"{prefix}.{i}.self_attn.k_proj",
                     layer.self_attn.qkv_proj,
                 )
-                layer_weights[(i, f'{layer_type_prefix}_{V_PROJ}')] = (
+                layer_weights[(i, f"{layer_type_prefix}_{V_PROJ}")] = (
                     f"{prefix}.{i}.self_attn.v_proj",
                     layer.self_attn.qkv_proj,
                 )
-                layer_weights[(i, f'{layer_type_prefix}_{O_PROJ}')] = (
+                layer_weights[(i, f"{layer_type_prefix}_{O_PROJ}")] = (
                     f"{prefix}.{i}.self_attn.o_proj",
-                    layer.self_attn.o_proj
+                    layer.self_attn.o_proj,
                 )
 
-                layer_weights[(i, f'{layer_type_prefix}_{FC1}')] = (f"{prefix}.{i}.mlp.fc1", layer.mlp.fc1)
-                layer_weights[(i, f'{layer_type_prefix}_{FC2}')] = (f"{prefix}.{i}.mlp.fc2", layer.mlp.fc2)
+                layer_weights[(i, f"{layer_type_prefix}_{FC1}")] = (f"{prefix}.{i}.mlp.fc1", layer.mlp.fc1)
+                layer_weights[(i, f"{layer_type_prefix}_{FC2}")] = (f"{prefix}.{i}.mlp.fc2", layer.mlp.fc2)
 
         return layer_weights

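The adapter_layers property above enumerates every base layer type once per tower prefix, and adapter_target_to_layer then keys each (layer index, prefixed name) pair to a weight path and module. The following is a standalone sketch of just that naming arithmetic; the concrete string values of the constants are assumptions for illustration.

```python
# Standalone sketch of the adapter naming scheme from MllamaCausalLM.adapter_layers
# above: every base layer type is repeated under the TEXT_, VISION_GLOBAL_TRANSFORMER_,
# and VISION_TRANSFORMER_ prefixes. The concrete string values of the constants below
# are assumptions for illustration.
TEXT_ADAPTER_LAYERS = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "lm_head"]
VISION_ADAPTER_LAYERS = ["q_proj", "k_proj", "v_proj", "o_proj", "fc1", "fc2"]

adapter_layers = (
    [f"TEXT_{layer_type}" for layer_type in TEXT_ADAPTER_LAYERS]
    + [f"VISION_GLOBAL_TRANSFORMER_{layer_type}" for layer_type in VISION_ADAPTER_LAYERS]
    + [f"VISION_TRANSFORMER_{layer_type}" for layer_type in VISION_ADAPTER_LAYERS]
)

# adapter_target_to_layer then keys (layer index, prefixed name) pairs to
# (weight path, module) tuples, e.g. (0, "TEXT_q_proj") -> ("...layers.0.self_attn.q_proj", module).
print(len(adapter_layers))  # 8 text layer types + 2 * 6 vision layer types = 20
```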
4 changes: 3 additions & 1 deletion server/lorax_server/models/seq2seq_lm.py
@@ -654,7 +654,9 @@ def generate_token(self, batch: Seq2SeqLMBatch) -> Tuple[List[Generation], Optio
                 else:
                     seed = None
 
-                generated_text = GeneratedText(output_text, stopping_criteria.current_tokens, reason, seed)
+                generated_text = GeneratedText(
+                    output_text, stopping_criteria.current_tokens, stopping_criteria.current_skipped, reason, seed
+                )
             else:
                 generated_text = None
4 changes: 2 additions & 2 deletions server/lorax_server/utils/lora.py
@@ -8,7 +8,7 @@
 UP_PROJ = "up_proj"
 DOWN_PROJ = "down_proj"
 
-FC1 = 'fc1'
-FC2 = 'fc2'
+FC1 = "fc1"
+FC2 = "fc2"
 
 LM_HEAD = "lm_head"
4 changes: 2 additions & 2 deletions server/lorax_server/utils/tokens.py
@@ -187,11 +187,11 @@ def __init__(
         self.current_output = ""
         self.current_skipped = 0
         self.ignore_eos_token = ignore_eos_token
 
     def __call__(self, last_token: int, last_output: str, skipped: bool = False) -> Tuple[bool, Optional[str]]:
         if skipped:
             self.current_skipped += 1
 
         self.current_tokens += 1
         if self.current_tokens >= self.max_new_tokens:
             return True, FinishReason.FINISH_REASON_LENGTH
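For context on the current_skipped counter that the GeneratedText calls above now consume: __call__ takes a skipped flag and tallies those tokens separately while still advancing current_tokens. Below is a simplified, runnable sketch of that bookkeeping; the real StoppingCriteria also checks EOS tokens and stop sequences, which are omitted here.

```python
# Simplified, runnable sketch of the skipped-token bookkeeping shown above. The
# real StoppingCriteria also checks EOS tokens and stop sequences, omitted here;
# FINISH_REASON_LENGTH stands in for FinishReason.FINISH_REASON_LENGTH.
from typing import Optional, Tuple

FINISH_REASON_LENGTH = "length"


class StoppingCriteria:
    def __init__(self, max_new_tokens: int):
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0
        self.current_skipped = 0

    def __call__(self, last_token: int, last_output: str, skipped: bool = False) -> Tuple[bool, Optional[str]]:
        if skipped:
            # Skipped tokens are counted separately and later surfaced through
            # GeneratedText (see the causal_lm.py / seq2seq_lm.py hunks above).
            self.current_skipped += 1
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, FINISH_REASON_LENGTH
        return False, None


criteria = StoppingCriteria(max_new_tokens=3)
criteria(101, "a")
criteria(102, "", skipped=True)
stop, reason = criteria(103, "b")
print(stop, reason, criteria.current_skipped)  # True length 1
```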
