From b49d789417eba974a6cfd3855f4293bfeeeeb49f Mon Sep 17 00:00:00 2001
From: "Jeff (Zhen) Wang"
Date: Tue, 12 Sep 2023 14:04:46 +1000
Subject: [PATCH] Added google/flan models and fixed AutoModelForSeq2SeqLM when
 loading T5 compression model (#2402)

---
 fastchat/model/compression.py   | 16 ++++++++++++++--
 fastchat/model/model_adapter.py |  8 ++++++++
 pyproject.toml                  |  2 +-
 3 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/fastchat/model/compression.py b/fastchat/model/compression.py
index 4a1d2adb7..c928db154 100644
--- a/fastchat/model/compression.py
+++ b/fastchat/model/compression.py
@@ -11,7 +11,13 @@ from torch.nn import functional as F
 import torch.nn as nn
 from tqdm import tqdm
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    AutoModel,
+    AutoModelForSeq2SeqLM,
+)
 
 
 @dataclasses.dataclass
 class CompressionConfig:
@@ -123,7 +129,13 @@ def load_compress_model(model_path, device, torch_dtype, use_fast, revision="mai
     # some models are loaded by AutoModel but not AutoModelForCausalLM,
     # such as chatglm, chatglm2
     try:
-        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
+        # google/flan-* models are based on an AutoModelForSeq2SeqLM.
+        if "T5Config" in str(type(config)):
+            model = AutoModelForSeq2SeqLM.from_config(
+                config, trust_remote_code=True
+            )
+        else:
+            model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
     except NameError:
         model = AutoModel.from_config(config, trust_remote_code=True)
     linear_weights = get_compressed_list(model)
diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index f018c212e..423308455 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -649,6 +649,13 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         return model, tokenizer
 
 
+class FlanAdapter(T5Adapter):
+    """The model adapter for flan-t5-*, flan-ul2"""
+
+    def match(self, model_path: str):
+        return "flan" in model_path.lower()
+
+
 class KoalaAdapter(BaseModelAdapter):
     """The model adapter for koala"""
 
@@ -1592,6 +1599,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 register_model_adapter(LongChatAdapter)
 register_model_adapter(CodeT5pAdapter)
 register_model_adapter(T5Adapter)
+register_model_adapter(FlanAdapter)
 register_model_adapter(KoalaAdapter)
 register_model_adapter(AlpacaAdapter)
 register_model_adapter(ChatGLMAdapter)
diff --git a/pyproject.toml b/pyproject.toml
index 1b30b8881..c3ce59364 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,7 +19,7 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0"]
+model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf"]
 webui = ["gradio"]
 train = ["einops", "flash-attn>=2.0", "wandb"]
 llm_judge = ["openai", "anthropic>=0.3", "ray"]
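
Note on the compression fix: it dispatches on the config class name rather
than on the model path. A minimal standalone sketch of that dispatch, assuming
only stock transformers and the public google/flan-t5-small checkpoint (any
T5-family model would take the same branch):

    from transformers import (
        AutoConfig,
        AutoModelForCausalLM,
        AutoModelForSeq2SeqLM,
    )

    # Encoder-decoder checkpoints such as google/flan-t5-* carry a T5Config,
    # so they must be instantiated via AutoModelForSeq2SeqLM; decoder-only
    # checkpoints keep the original AutoModelForCausalLM path.
    config = AutoConfig.from_pretrained("google/flan-t5-small")
    if "T5Config" in str(type(config)):
        model = AutoModelForSeq2SeqLM.from_config(config, trust_remote_code=True)
    else:
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)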
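
End to end, the new branch should be reachable through FastChat's 8-bit
compression loader; assuming the serve CLI and its --load-8bit flag (which
routes single-GPU loads through load_compress_model), something like:

    python3 -m fastchat.serve.cli --model-path google/flan-t5-xl --load-8bit

FlanAdapter matches any model path containing "flan" and, by subclassing
T5Adapter, reuses its AutoModelForSeq2SeqLM loading path, which is why the
adapter body only needs a match() override.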