Skip to content

Commit

Permalink
Transformers backend TP fix (#2945)
Browse files Browse the repository at this point in the history
* init dispatch

* cohere fix
  • Loading branch information
Cyrilvallez authored Jan 23, 2025
1 parent 29a0893 commit 18c4607
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 48 deletions.
180 changes: 132 additions & 48 deletions server/text_generation_server/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,21 +391,6 @@ def get_model(
)
model_type = config_dict.get("model_type", None)

transformers_causal_lm_class = CausalLM

# Fast transformers path
transformers_model_class = getattr(
transformers,
modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
None,
)
if (
FLASH_TRANSFORMERS_BACKEND
and transformers_model_class is not None
and transformers_model_class._supports_flex_attn
):
transformers_causal_lm_class = TransformersFlashCausalLM

quantization_config = config_dict.get("quantization_config", None)
if quantization_config is None:
quantization_config = config_dict.get("compression_config", None)
Expand Down Expand Up @@ -649,7 +634,7 @@ def get_model(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
)
else:
return transformers_causal_lm_class.fallback(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
Expand Down Expand Up @@ -756,7 +741,7 @@ def get_model(
except RuntimeError as e:
# Lots of legacy models with various weight names.
log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
return transformers_causal_lm_class.fallback(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -767,7 +752,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
else:
return transformers_causal_lm_class.fallback(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -792,7 +777,7 @@ def get_model(
except RuntimeError as e:
# Lots of legacy models with various weight names.
log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
return transformers_causal_lm_class.fallback(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -803,7 +788,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
else:
return transformers_causal_lm_class.fallback(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
Expand Down Expand Up @@ -840,7 +825,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
else:
return transformers_causal_lm_class.fallback(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -863,7 +848,7 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
)
else:
return transformers_causal_lm_class.fallback(
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -887,7 +872,7 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
)
else:
return transformers_causal_lm_class.fallback(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -913,12 +898,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)

elif (
model_type == LLAMA
or model_type == BAICHUAN
or model_type == PHI3
or model_type == GRANITE
):
elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,
Expand All @@ -931,19 +911,56 @@ def get_model(
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
# elif sharded:
# raise NotImplementedError(
# FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
# )
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
)
else:
return transformers_causal_lm_class.fallback(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)

elif model_type == BAICHUAN:
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,
model_class=FlashLlamaForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif sharded:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
)
else:
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)

if model_type == GEMMA:
if FLASH_ATTENTION:
return FlashCausalLM(
Expand All @@ -959,10 +976,19 @@ def get_model(
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
else:
return transformers_causal_lm_class.fallback(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -988,7 +1014,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
else:
return transformers_causal_lm_class.fallback(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -1010,10 +1036,19 @@ def get_model(
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
else:
return transformers_causal_lm_class.fallback(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
Expand Down Expand Up @@ -1041,7 +1076,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
else:
return transformers_causal_lm_class.fallback(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
Expand Down Expand Up @@ -1091,7 +1126,7 @@ def get_model(
config_class=RWConfig,
)
else:
return transformers_causal_lm_class.fallback(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -1113,10 +1148,19 @@ def get_model(
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
else:
return transformers_causal_lm_class.fallback(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -1138,10 +1182,19 @@ def get_model(
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
else:
return transformers_causal_lm_class.fallback(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -1163,12 +1216,21 @@ def get_model(
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
)
else:
return transformers_causal_lm_class.fallback(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
Expand All @@ -1190,10 +1252,19 @@ def get_model(
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
else:
return transformers_causal_lm_class.fallback(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
Expand Down Expand Up @@ -1339,8 +1410,6 @@ def get_model(
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

if sharded:
raise NotImplementedError("sharded is not supported for AutoModel")
if quantize == "gptq":
raise NotImplementedError(
"gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
Expand All @@ -1353,15 +1422,30 @@ def get_model(
raise NotImplementedError("Eetq quantization is not supported for AutoModel")
elif quantize == "exl2":
raise NotImplementedError("exl2 quantization is not supported for AutoModel")
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return transformers_causal_lm_class.fallback(

# Fast transformers if available
transformers_model_class = getattr(
transformers,
modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
None,
)
if (
FLASH_TRANSFORMERS_BACKEND
and transformers_model_class is not None
and transformers_model_class._supports_flex_attn
):
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)

if sharded:
raise NotImplementedError("sharded is not supported for AutoModel")

if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
return Seq2SeqLM.fallback(
model_id,
Expand All @@ -1375,7 +1459,7 @@ def get_model(
auto_map = config_dict.get("auto_map", None)
if trust_remote_code and auto_map is not None:
if "AutoModelForCausalLM" in auto_map.keys():
return transformers_causal_lm_class.fallback(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,4 +260,11 @@ def _model_forward(
hidden_states = hidden_states[lm_head_indices]
logits = self.model.lm_head(hidden_states)

# For Granite while next transformers version is released and we can use `lm_head_indices` natively
if hasattr(self.model.config, "logits_scaling"):
logits = logits / self.model.config.logits_scaling
# For Cohere for similar reasons
elif hasattr(self.model, "logit_scale"):
logits = logits * self.model.logit_scale

return logits, None

0 comments on commit 18c4607

Please sign in to comment.