Skip to content

Commit

Permalink
Added google/flan models and fixed AutoModelForSeq2SeqLM when loading…
Browse files Browse the repository at this point in the history
… T5 compression model (#2402)
  • Loading branch information
wangzhen263 authored Sep 12, 2023
1 parent a8088ba commit b49d789
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
16 changes: 14 additions & 2 deletions fastchat/model/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@
from torch.nn import functional as F
import torch.nn as nn
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
AutoModel,
AutoModelForSeq2SeqLM,
)


@dataclasses.dataclass
Expand Down Expand Up @@ -123,7 +129,13 @@ def load_compress_model(model_path, device, torch_dtype, use_fast, revision="mai
# some models are loaded by AutoModel but not AutoModelForCausalLM,
# such as chatglm, chatglm2
try:
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
# google/flan-* models are based on an AutoModelForSeq2SeqLM.
if "T5Config" in str(type(config)):
model = AutoModelForSeq2SeqLM.from_config(
config, trust_remote_code=True
)
else:
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
except NameError:
model = AutoModel.from_config(config, trust_remote_code=True)
linear_weights = get_compressed_list(model)
Expand Down
8 changes: 8 additions & 0 deletions fastchat/model/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,13 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
return model, tokenizer


class FlanAdapter(T5Adapter):
"""The model adapter for flan-t5-*, flan-ul2"""

def match(self, model_path: str):
return "flan" in model_path.lower()


class KoalaAdapter(BaseModelAdapter):
"""The model adapter for koala"""

Expand Down Expand Up @@ -1592,6 +1599,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
register_model_adapter(LongChatAdapter)
register_model_adapter(CodeT5pAdapter)
register_model_adapter(T5Adapter)
register_model_adapter(FlanAdapter)
register_model_adapter(KoalaAdapter)
register_model_adapter(AlpacaAdapter)
register_model_adapter(ChatGLMAdapter)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ dependencies = [
]

[project.optional-dependencies]
model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0"]
model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf"]
webui = ["gradio"]
train = ["einops", "flash-attn>=2.0", "wandb"]
llm_judge = ["openai", "anthropic>=0.3", "ray"]
Expand Down

0 comments on commit b49d789

Please sign in to comment.