From 0890ee89eb13cdc4029e6539d9da5770b70410a3 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 21 Jun 2024 09:20:08 +0000 Subject: [PATCH] add qa demo --- .gitignore | 162 +++++ README.md | 18 + chat/__init__.py | 5 + chat/base_engine.py | 69 ++ chat/chat_model.py | 91 +++ chat/data/__init__.py | 8 + chat/data/formatter.py | 187 ++++++ chat/data/template.py | 807 +++++++++++++++++++++++ chat/data/utils.py | 94 +++ chat/extras/__init__.py | 0 chat/extras/callbacks.py | 166 +++++ chat/extras/constants.py | 983 +++++++++++++++++++++++++++++ chat/extras/logging.py | 48 ++ chat/extras/misc.py | 232 +++++++ chat/extras/packages.py | 61 ++ chat/extras/patches/__init__.py | 0 chat/extras/patches/llama_patch.py | 198 ++++++ chat/extras/ploting.py | 57 ++ chat/hf_engine.py | 264 ++++++++ chat/hparams/__init__.py | 18 + chat/hparams/data_args.py | 100 +++ chat/hparams/evaluation_args.py | 48 ++ chat/hparams/finetuning_args.py | 276 ++++++++ chat/hparams/generating_args.py | 67 ++ chat/hparams/model_args.py | 179 ++++++ chat/hparams/parser.py | 302 +++++++++ chat/model/__init__.py | 10 + chat/model/adapter.py | 166 +++++ chat/model/loader.py | 135 ++++ chat/model/patcher.py | 387 ++++++++++++ chat/model/utils.py | 137 ++++ chat/vllm_engine.py | 149 +++++ demo_QAMaking_from_pdf.py | 117 ++++ requirements.txt | 19 + template/QAtemplat.txt | 21 + 35 files changed, 5581 insertions(+) create mode 100644 .gitignore create mode 100644 chat/__init__.py create mode 100644 chat/base_engine.py create mode 100644 chat/chat_model.py create mode 100644 chat/data/__init__.py create mode 100644 chat/data/formatter.py create mode 100644 chat/data/template.py create mode 100644 chat/data/utils.py create mode 100644 chat/extras/__init__.py create mode 100644 chat/extras/callbacks.py create mode 100644 chat/extras/constants.py create mode 100644 chat/extras/logging.py create mode 100644 chat/extras/misc.py create mode 100644 chat/extras/packages.py create mode 100644 chat/extras/patches/__init__.py create mode 100644 chat/extras/patches/llama_patch.py create mode 100644 chat/extras/ploting.py create mode 100644 chat/hf_engine.py create mode 100644 chat/hparams/__init__.py create mode 100644 chat/hparams/data_args.py create mode 100644 chat/hparams/evaluation_args.py create mode 100644 chat/hparams/finetuning_args.py create mode 100644 chat/hparams/generating_args.py create mode 100644 chat/hparams/model_args.py create mode 100644 chat/hparams/parser.py create mode 100644 chat/model/__init__.py create mode 100644 chat/model/adapter.py create mode 100644 chat/model/loader.py create mode 100644 chat/model/patcher.py create mode 100644 chat/model/utils.py create mode 100644 chat/vllm_engine.py create mode 100644 demo_QAMaking_from_pdf.py create mode 100644 requirements.txt create mode 100644 template/QAtemplat.txt diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..82f9275 --- /dev/null +++ b/.gitignore @@ -0,0 +1,162 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/README.md b/README.md index 3b8d25b..00e9299 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,20 @@ # LM-playground repo contains mutil LLM & VL demo + +# QAMaking Demo + +## 1. 安装依赖 + +```bash +pip install -r requirements.txt +``` + +## 2. 
挂载模型和选择template + + +## 3 运行 + +```bash +python demo_QAMaking_from_pdf.py --model_name_or_path $MODEL_PATH --template qwen --prompt_path "template/QAtemplat.txt" +``` + diff --git a/chat/__init__.py b/chat/__init__.py new file mode 100644 index 0000000..a1a79de --- /dev/null +++ b/chat/__init__.py @@ -0,0 +1,5 @@ +from .base_engine import BaseEngine +from .chat_model import ChatModel + + +__all__ = ["BaseEngine", "ChatModel"] diff --git a/chat/base_engine.py b/chat/base_engine.py new file mode 100644 index 0000000..c5db41d --- /dev/null +++ b/chat/base_engine.py @@ -0,0 +1,69 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, PreTrainedTokenizer + + from ..data import Template + from ..extras.packages import is_vllm_available + from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments + + if is_vllm_available(): + from vllm import AsyncLLMEngine + + +@dataclass +class Response: + response_text: str + response_length: int + prompt_length: int + finish_reason: Literal["stop", "length"] + + +class BaseEngine(ABC): + model: Union["PreTrainedModel", "AsyncLLMEngine"] + tokenizer: "PreTrainedTokenizer" + can_generate: bool + template: "Template" + generating_args: Dict[str, Any] + + @abstractmethod + def __init__( + self, + model_args: "ModelArguments", + data_args: "DataArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + ) -> None: ... + + @abstractmethod + async def start( + self, + ) -> None: ... + + @abstractmethod + async def chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + **input_kwargs, + ) -> List["Response"]: ... + + @abstractmethod + async def stream_chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: ... + + @abstractmethod + async def get_scores( + self, + batch_input: List[str], + **input_kwargs, + ) -> List[float]: ... 
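
For orientation, here is a minimal sketch of how the `BaseEngine` interface above is typically driven through the `ChatModel` wrapper added in the next hunk (`chat/chat_model.py`). It is illustrative only: the model path is a placeholder, the `"qwen"` template mirrors the README command, and a real run may need additional inference arguments accepted by `get_infer_args`.

```python
# Illustrative sketch, not part of the patch.
# Assumes the repo root is on PYTHONPATH and a local model checkpoint exists.
from chat import ChatModel

if __name__ == "__main__":
    chat_model = ChatModel(
        args={
            "model_name_or_path": "/path/to/your/model",  # placeholder path
            "template": "qwen",                            # same template as the README command
        }
    )

    messages = [{"role": "user", "content": "Summarize the attached passage in one sentence."}]

    # Blocking call: returns a list of Response objects (see base_engine.py above).
    responses = chat_model.chat(messages)
    print(responses[0].response_text)

    # Streaming call: yields generated text piece by piece.
    for token in chat_model.stream_chat(messages):
        print(token, end="", flush=True)
```

Both `chat` and `stream_chat` are synchronous wrappers: `ChatModel` starts a background event loop thread and schedules the engine's async `chat`/`stream_chat` coroutines onto it, so the sketch works from ordinary scripts such as `demo_QAMaking_from_pdf.py`.
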
diff --git a/chat/chat_model.py b/chat/chat_model.py new file mode 100644 index 0000000..b33f560 --- /dev/null +++ b/chat/chat_model.py @@ -0,0 +1,91 @@ +import asyncio +from threading import Thread +from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence + +from .hparams import get_infer_args +from .hf_engine import HuggingfaceEngine +from .vllm_engine import VllmEngine + + +if TYPE_CHECKING: + from .base_engine import BaseEngine, Response + + +def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None: + asyncio.set_event_loop(loop) + loop.run_forever() + + +class ChatModel: + def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: + model_args, data_args, finetuning_args, generating_args = get_infer_args(args) + if model_args.infer_backend == "huggingface": + self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args) + elif model_args.infer_backend == "vllm": + self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args) + else: + raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend)) + + self._loop = asyncio.new_event_loop() + self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True) + self._thread.start() + asyncio.run_coroutine_threadsafe(self.engine.start(), self._loop) + + def chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + **input_kwargs, + ) -> List["Response"]: + task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, **input_kwargs), self._loop) + return task.result() + + async def achat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + **input_kwargs, + ) -> List["Response"]: + return await self.engine.chat(messages, system, tools, **input_kwargs) + + def stream_chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + **input_kwargs, + ) -> Generator[str, None, None]: + generator = self.astream_chat(messages, system, tools, **input_kwargs) + while True: + try: + task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop) + yield task.result() + except StopAsyncIteration: + break + + async def astream_chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: + async for new_token in self.engine.stream_chat(messages, system, tools, **input_kwargs): + yield new_token + + def get_scores( + self, + batch_input: List[str], + **input_kwargs, + ) -> List[float]: + task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop) + return task.result() + + async def aget_scores( + self, + batch_input: List[str], + **input_kwargs, + ) -> List[float]: + return await self.engine.get_scores(batch_input, **input_kwargs) diff --git a/chat/data/__init__.py b/chat/data/__init__.py new file mode 100644 index 0000000..43a80d6 --- /dev/null +++ b/chat/data/__init__.py @@ -0,0 +1,8 @@ +from .template import Template, get_template_and_fix_tokenizer, templates + + +__all__ = [ + "Template", + "get_template_and_fix_tokenizer", + "templates" +] diff --git a/chat/data/formatter.py b/chat/data/formatter.py new file mode 100644 index 0000000..0cd3d6c --- /dev/null +++ b/chat/data/formatter.py @@ -0,0 +1,187 @@ +import json +import re +from abc import 
ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union + + +SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]] + + +JSON_FORMAT_PROMPT = ( + """, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)""" +) + + +TOOL_SYSTEM_PROMPT = ( + "You have access to the following tools:\n{tool_text}" + "Use the following format if using a tool:\n" + "```\n" + "Action: tool name (one of [{tool_names}]).\n" + "Action Input: the input to the tool{format_prompt}.\n" + "```\n" +) + + +def default_tool_formatter(tools: List[Dict[str, Any]]) -> str: + tool_text = "" + tool_names = [] + for tool in tools: + param_text = "" + for name, param in tool["parameters"]["properties"].items(): + required = ", required" if name in tool["parameters"].get("required", []) else "" + enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else "" + items = ( + ", where each item should be {}".format(param["items"].get("type", "")) if param.get("items") else "" + ) + param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format( + name=name, + type=param.get("type", ""), + required=required, + desc=param.get("description", ""), + enum=enum, + items=items, + ) + + tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format( + name=tool["name"], desc=tool.get("description", ""), args=param_text + ) + tool_names.append(tool["name"]) + + return TOOL_SYSTEM_PROMPT.format( + tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT + ) + + +def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]: + regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL) + action_match = re.search(regex, content) + if not action_match: + return content + + tool_name = action_match.group(1).strip() + tool_input = action_match.group(2).strip().strip('"').strip("```") + try: + arguments = json.loads(tool_input) + except json.JSONDecodeError: + return content + + return tool_name, json.dumps(arguments, ensure_ascii=False) + + +@dataclass +class Formatter(ABC): + slots: SLOTS = field(default_factory=list) + tool_format: Optional[Literal["default"]] = None + + @abstractmethod + def apply(self, **kwargs) -> SLOTS: ... 
+ + def extract(self, content: str) -> Union[str, Tuple[str, str]]: + raise NotImplementedError + + +@dataclass +class EmptyFormatter(Formatter): + def __post_init__(self): + has_placeholder = False + for slot in filter(lambda s: isinstance(s, str), self.slots): + if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot): + has_placeholder = True + + if has_placeholder: + raise ValueError("Empty formatter should not contain any placeholder.") + + def apply(self, **kwargs) -> SLOTS: + return self.slots + + +@dataclass +class StringFormatter(Formatter): + def __post_init__(self): + has_placeholder = False + for slot in filter(lambda s: isinstance(s, str), self.slots): + if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot): + has_placeholder = True + + if not has_placeholder: + raise ValueError("A placeholder is required in the string formatter.") + + def apply(self, **kwargs) -> SLOTS: + elements = [] + for slot in self.slots: + if isinstance(slot, str): + for name, value in kwargs.items(): + if not isinstance(value, str): + raise RuntimeError("Expected a string, got {}".format(value)) + + slot = slot.replace("{{" + name + "}}", value, 1) + elements.append(slot) + elif isinstance(slot, (dict, set)): + elements.append(slot) + else: + raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot))) + + return elements + + +@dataclass +class FunctionFormatter(Formatter): + def __post_init__(self): + has_name, has_args = False, False + for slot in filter(lambda s: isinstance(s, str), self.slots): + if "{{name}}" in slot: + has_name = True + if "{{arguments}}" in slot: + has_args = True + + if not has_name or not has_args: + raise ValueError("Name and arguments placeholders are required in the function formatter.") + + def apply(self, **kwargs) -> SLOTS: + content = kwargs.pop("content") + try: + function = json.loads(content) + name = function["name"] + arguments = json.dumps(function["arguments"], ensure_ascii=False) + except Exception: + name, arguments = "", "" + + elements = [] + for slot in self.slots: + if isinstance(slot, str): + slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments) + elements.append(slot) + elif isinstance(slot, (dict, set)): + elements.append(slot) + else: + raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot))) + + return elements + + +@dataclass +class ToolFormatter(Formatter): + def __post_init__(self): + if self.tool_format is None: + raise ValueError("Tool format was not found.") + + def apply(self, **kwargs) -> SLOTS: + content = kwargs.pop("content") + try: + tools = json.loads(content) + if not len(tools): + return [""] + + if self.tool_format == "default": + return [default_tool_formatter(tools)] + else: + raise NotImplementedError + except Exception: + return [""] + + def extract(self, content: str) -> Union[str, Tuple[str, str]]: + if self.tool_format == "default": + return default_tool_extractor(content) + else: + raise NotImplementedError diff --git a/chat/data/template.py b/chat/data/template.py new file mode 100644 index 0000000..f285914 --- /dev/null +++ b/chat/data/template.py @@ -0,0 +1,807 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union + +from ..extras.logging import get_logger +from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter +from .utils import Role, infer_max_len + + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + from 
.formatter import SLOTS, Formatter + + +logger = get_logger(__name__) + + +@dataclass +class Template: + format_user: "Formatter" + format_assistant: "Formatter" + format_system: "Formatter" + format_function: "Formatter" + format_observation: "Formatter" + format_tools: "Formatter" + format_separator: "Formatter" + default_system: str + stop_words: List[str] + efficient_eos: bool + replace_eos: bool + force_system: bool + + def encode_oneturn( + self, + tokenizer: "PreTrainedTokenizer", + messages: List[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + cutoff_len: int = 1_000_000, + reserved_label_len: int = 1, + ) -> Tuple[List[int], List[int]]: + r""" + Returns a single pair of token ids representing prompt and response respectively. + """ + encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len) + prompt_ids = [] + for query_ids, resp_ids in encoded_pairs[:-1]: + prompt_ids += query_ids + resp_ids + prompt_ids = prompt_ids + encoded_pairs[-1][0] + answer_ids = encoded_pairs[-1][1] + return prompt_ids, answer_ids + + def encode_multiturn( + self, + tokenizer: "PreTrainedTokenizer", + messages: List[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + cutoff_len: int = 1_000_000, + reserved_label_len: int = 1, + ) -> Sequence[Tuple[List[int], List[int]]]: + r""" + Returns multiple pairs of token ids representing prompts and responses respectively. + """ + return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len) + + def _encode( + self, + tokenizer: "PreTrainedTokenizer", + messages: List[Dict[str, str]], + system: str, + tools: str, + cutoff_len: int, + reserved_label_len: int, + ) -> Sequence[Tuple[List[int], List[int]]]: + r""" + Encodes formatted inputs to pairs of token ids. + Turn 0: system + query resp + Turn t: sep + query resp + """ + system = system or self.default_system + encoded_messages = [] + for i, message in enumerate(messages): + elements = [] + if i == 0 and (system or tools or self.force_system): + tool_text = self.format_tools.apply(content=tools)[0] if tools else "" + elements += self.format_system.apply(content=(system + tool_text)) + elif i > 0 and i % 2 == 0: + elements += self.format_separator.apply() + + if message["role"] == Role.USER.value: + elements += self.format_user.apply(content=message["content"], idx=str(i // 2)) + elif message["role"] == Role.ASSISTANT.value: + elements += self.format_assistant.apply(content=message["content"]) + elif message["role"] == Role.OBSERVATION.value: + elements += self.format_observation.apply(content=message["content"]) + elif message["role"] == Role.FUNCTION.value: + elements += self.format_function.apply(content=message["content"]) + else: + raise NotImplementedError("Unexpected role: {}".format(message["role"])) + + encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) + + return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len) + + def _convert_elements_to_ids( + self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]] + ) -> List[int]: + r""" + Converts elements to token ids. 
+ """ + token_ids = [] + for elem in elements: + if isinstance(elem, str): + if len(elem) != 0: + token_ids += tokenizer.encode(elem, add_special_tokens=False) + elif isinstance(elem, dict): + token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))] + elif isinstance(elem, set): + if "bos_token" in elem and tokenizer.bos_token_id is not None: + token_ids += [tokenizer.bos_token_id] + elif "eos_token" in elem and tokenizer.eos_token_id is not None: + token_ids += [tokenizer.eos_token_id] + else: + raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem))) + + return token_ids + + def _make_pairs( + self, + encoded_messages: Sequence[List[int]], + cutoff_len: int, + reserved_label_len: int, + ) -> Sequence[Tuple[List[int], List[int]]]: + encoded_pairs = [] + total_length = 0 + for i in range(0, len(encoded_messages), 2): + if total_length >= cutoff_len: + break + + max_source_len, max_target_len = infer_max_len( + source_len=len(encoded_messages[i]), + target_len=len(encoded_messages[i + 1]), + max_len=(cutoff_len - total_length), + reserved_label_len=reserved_label_len, + ) + source_ids = encoded_messages[i][:max_source_len] + target_ids = encoded_messages[i + 1][:max_target_len] + total_length += len(source_ids) + len(target_ids) + encoded_pairs.append((source_ids, target_ids)) + + return encoded_pairs + + +@dataclass +class Llama2Template(Template): + def _encode( + self, + tokenizer: "PreTrainedTokenizer", + messages: List[Dict[str, str]], + system: str, + tools: str, + cutoff_len: int, + reserved_label_len: int, + ) -> Sequence[Tuple[List[int], List[int]]]: + r""" + Encodes formatted inputs to pairs of token ids. + Turn 0: system + query resp + Turn t: sep + query resp + """ + system = system or self.default_system + encoded_messages = [] + for i, message in enumerate(messages): + elements = [] + system_text = "" + if i == 0 and (system or tools or self.force_system): + tool_text = self.format_tools.apply(content=tools)[0] if tools else "" + system_text = self.format_system.apply(content=(system + tool_text))[0] + elif i > 0 and i % 2 == 0: + elements += self.format_separator.apply() + + if message["role"] == Role.USER.value: + elements += self.format_user.apply(content=system_text + message["content"]) + elif message["role"] == Role.ASSISTANT.value: + elements += self.format_assistant.apply(content=message["content"]) + elif message["role"] == Role.OBSERVATION.value: + elements += self.format_observation.apply(content=message["content"]) + elif message["role"] == Role.FUNCTION.value: + elements += self.format_function.apply(content=message["content"]) + else: + raise NotImplementedError("Unexpected role: {}".format(message["role"])) + + encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) + + return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len) + + +templates: Dict[str, Template] = {} + + +def _register_template( + name: str, + format_user: Optional["Formatter"] = None, + format_assistant: Optional["Formatter"] = None, + format_system: Optional["Formatter"] = None, + format_function: Optional["Formatter"] = None, + format_observation: Optional["Formatter"] = None, + format_tools: Optional["Formatter"] = None, + format_separator: Optional["Formatter"] = None, + default_system: str = "", + stop_words: List[str] = [], + efficient_eos: bool = False, + replace_eos: bool = False, + force_system: bool = False, +) -> None: + r""" + Registers a chat template. 
+ + To add the following chat template: + ``` + [HUMAN]: + user prompt here + [AI]: + model response here + + [HUMAN]: + user prompt here + [AI]: + model response here + ``` + + The corresponding code should be: + ``` + _register_template( + name="custom", + format_user=StringFormatter(slots=["[HUMAN]:\n{{content}}\n[AI]:\n"]), + format_separator=EmptyFormatter(slots=["\n\n"]), + efficient_eos=True, + ) + ``` + """ + eos_slots = [] if efficient_eos else [{"eos_token"}] + template_class = Llama2Template if name.startswith("llama2") else Template + default_user_formatter = StringFormatter(slots=["{{content}}"]) + default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots) + default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots) + default_tool_formatter = ToolFormatter(tool_format="default") + default_separator_formatter = EmptyFormatter() + templates[name] = template_class( + format_user=format_user or default_user_formatter, + format_assistant=format_assistant or default_assistant_formatter, + format_system=format_system or default_user_formatter, + format_function=format_function or default_function_formatter, + format_observation=format_observation or format_user or default_user_formatter, + format_tools=format_tools or default_tool_formatter, + format_separator=format_separator or default_separator_formatter, + default_system=default_system, + stop_words=stop_words, + efficient_eos=efficient_eos, + replace_eos=replace_eos, + force_system=force_system, + ) + + +def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None: + is_added = tokenizer.eos_token_id is None + num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token}) + + if is_added: + logger.info("Add eos token: {}".format(tokenizer.eos_token)) + else: + logger.info("Replace eos token: {}".format(tokenizer.eos_token)) + + if num_added_tokens > 0: + logger.warning("New tokens have been added, make sure `resize_vocab` is True.") + + +def _jinja_escape(content: str) -> str: + return content.replace("\n", r"\n").replace("'", r"\'") + + +def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str: + slot_items = [] + for slot in slots: + if isinstance(slot, str): + slot_pieces = slot.split("{{content}}") + if slot_pieces[0]: + slot_items.append("'" + _jinja_escape(slot_pieces[0]) + "'") + if len(slot_pieces) > 1: + slot_items.append(placeholder) + if slot_pieces[1]: + slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'") + elif isinstance(slot, set): + if "bos_token" in slot: + slot_items.append("'" + tokenizer.bos_token + "'") + elif "eos_token" in slot: # do not use {{ eos_token }} since it may be replaced + slot_items.append("'" + tokenizer.eos_token + "'") + elif isinstance(slot, dict): + raise ValueError("Dict is not supported.") + + return " + ".join(slot_items) + + +def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str: + jinja_template = "" + + if template.default_system: + jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}" + + jinja_template += ( + "{% if messages[0]['role'] == 'system' %}" "{% set system_message = messages[0]['content'] %}" "{% endif %}" + ) + + system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message") + if isinstance(template, Llama2Template): + pass + elif template.force_system: + 
jinja_template += "{{ " + system_message + " }}" + else: + jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}" + + jinja_template += "{% for message in messages %}" + jinja_template += "{% set content = message['content'] %}" + if isinstance(template, Llama2Template): + jinja_template += "{% if loop.index0 == 0 and system_message is defined %}" + jinja_template += "{% set content = " + system_message + " + message['content'] %}" + jinja_template += "{% endif %}" + jinja_template += "{% if message['role'] == 'user' %}" + user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer) + jinja_template += "{{ " + user_message + " }}" + jinja_template += "{% elif message['role'] == 'assistant' %}" + assistant_message = _convert_slots_to_jinja( + template.format_assistant.apply() + template.format_separator.apply(), tokenizer + ) + jinja_template += "{{ " + assistant_message + " }}" + jinja_template += "{% endif %}" + jinja_template += "{% endfor %}" + return jinja_template + + +def get_template_and_fix_tokenizer( + tokenizer: "PreTrainedTokenizer", + name: Optional[str] = None, +) -> Template: + if name is None: + template = templates["vanilla"] # placeholder + else: + template = templates.get(name, None) + if template is None: + raise ValueError("Template {} does not exist.".format(name)) + + stop_words = template.stop_words + if template.replace_eos: + if not stop_words: + raise ValueError("Stop words are required to replace the EOS token.") + + _add_or_replace_eos_token(tokenizer, eos_token=stop_words[0]) + stop_words = stop_words[1:] + + if tokenizer.eos_token_id is None: + _add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>") + + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + logger.info("Add pad token: {}".format(tokenizer.pad_token)) + + if stop_words: + num_added_tokens = tokenizer.add_special_tokens( + dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False + ) + logger.info("Add {} to stop words.".format(",".join(stop_words))) + if num_added_tokens > 0: + logger.warning("New tokens have been added, make sure `resize_vocab` is True.") + + try: + tokenizer.chat_template = _get_jinja_template(template, tokenizer) + except ValueError: + logger.info("Cannot add this chat template to tokenizer.") + + return template + + +_register_template( + name="alpaca", + format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]), + format_separator=EmptyFormatter(slots=["\n\n"]), + default_system=( + "Below is an instruction that describes a task. " "Write a response that appropriately completes the request." + ), +) + +_register_template( + name="fewshot", + format_separator=EmptyFormatter(slots=["\n\n"]), + efficient_eos=True, +) + + +_register_template( + name="aquila", + format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]), + format_separator=EmptyFormatter(slots=["###"]), + default_system=( + "A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions." 
+ ), + stop_words=[""], + efficient_eos=True, +) + + +_register_template( + name="atom", + format_user=StringFormatter( + slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"] + ), + format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]), +) + + +_register_template( + name="baichuan", + format_user=StringFormatter(slots=[{"token": ""}, "{{content}}", {"token": ""}]), + efficient_eos=True, +) + + +_register_template( + name="baichuan2", + format_user=StringFormatter(slots=["{{content}}"]), + efficient_eos=True, +) + + +_register_template( + name="belle", + format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]), + format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), + format_separator=EmptyFormatter(slots=["\n\n"]), + force_system=True, +) + + +_register_template( + name="bluelm", + format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]), +) + + +_register_template( + name="chatglm2", + format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]), + format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), + format_separator=EmptyFormatter(slots=["\n\n"]), + efficient_eos=True, + force_system=True, +) + + +_register_template( + name="chatglm3", + format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), + format_assistant=StringFormatter(slots=["\n", "{{content}}"]), + format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), + format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), + format_observation=StringFormatter( + slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] + ), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, + force_system=True, +) + +_register_template( + name="llama3", + format_user=StringFormatter( + slots=[ + ( + "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + ] + ), + format_system=StringFormatter( + slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"] + ), + format_observation=StringFormatter( + slots=[ + ( + "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + ] + ), + default_system="You are a helpful assistant.", + stop_words=["<|eot_id|>"], + replace_eos=True, +) + + +_register_template( + name="chatglm3_system", + format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), + format_assistant=StringFormatter(slots=["\n", "{{content}}"]), + format_system=StringFormatter( + slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"] + ), + format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), + format_observation=StringFormatter( + slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] + ), + default_system=( + "You are ChatGLM3, a large language model trained by Zhipu.AI. " + "Follow the user's instructions carefully. Respond using markdown." 
+ ), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, +) + + +_register_template( + name="chatml", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + stop_words=["<|im_end|>", "<|im_start|>"], + replace_eos=True, +) + + +_register_template( + name="chatml_de", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.", + stop_words=["<|im_end|>", "<|im_start|>"], + replace_eos=True, +) + + +_register_template( + name="codegeex2", + format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), + force_system=True, +) + + +_register_template( + name="cpm", + format_user=StringFormatter(slots=["<用户>{{content}}"]), + format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), + force_system=True, +) + + +_register_template( + name="deepseek", + format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]), + format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), + force_system=True, +) + + +_register_template( + name="deepseekcoder", + format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]), + format_assistant=StringFormatter(slots=["\n", "{{content}}"]), + format_separator=EmptyFormatter(slots=["\n<|EOT|>\n"]), + default_system=( + "You are an AI programming assistant, utilizing the Deepseek Coder model, " + "developed by Deepseek Company, and you only answer questions related to computer science. " + "For politically sensitive questions, security and privacy issues, " + "and other non-computer science questions, you will refuse to answer\n" + ), + stop_words=["<|EOT|>"], + efficient_eos=True, +) + + +_register_template( + name="default", + format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]), + format_system=StringFormatter(slots=["{{content}}\n"]), + format_separator=EmptyFormatter(slots=["\n"]), +) + + +_register_template( + name="falcon", + format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]), + format_separator=EmptyFormatter(slots=["\n"]), + efficient_eos=True, +) + + +_register_template( + name="gemma", + format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), + format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), + format_separator=EmptyFormatter(slots=["\n"]), + efficient_eos=True, + force_system=True, +) + + +_register_template( + name="intern", + format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": ""}, "\n<|Bot|>:"]), + format_separator=EmptyFormatter(slots=[{"token": ""}, "\n"]), + stop_words=[""], + efficient_eos=True, +) + + +_register_template( + name="intern2", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_system=StringFormatter(slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + default_system=( + "You are an AI assistant whose name is InternLM (书生·浦语).\n" + "- InternLM (书生·浦语) is a conversational language model that is developed " + "by Shanghai AI Laboratory (上海人工智能实验室). 
It is designed to be helpful, honest, and harmless.\n" + "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen " + "by the user such as English and 中文." + ), + stop_words=["<|im_end|>"], + efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id +) + + +_register_template( + name="llama2", + format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]), + format_system=StringFormatter(slots=["<>\n{{content}}\n<>\n\n"]), + default_system=( + "You are a helpful, respectful and honest assistant. " + "Always answer as helpfully as possible, while being safe. " + "Your answers should not include any harmful, unethical, " + "racist, sexist, toxic, dangerous, or illegal content. " + "Please ensure that your responses are socially unbiased and positive in nature.\n\n" + "If a question does not make any sense, or is not factually coherent, " + "explain why instead of answering something not correct. " + "If you don't know the answer to a question, please don't share false information." + ), +) + + +_register_template( + name="llama2_zh", + format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]), + format_system=StringFormatter(slots=["<>\n{{content}}\n<>\n\n"]), + default_system="You are a helpful assistant. 你是一个乐于助人的助手。", +) + + +_register_template( + name="mistral", + format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]), + format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), + force_system=True, +) + + +_register_template( + name="olmo", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), + format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]), + format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]), + force_system=True, +) + + +_register_template( + name="openchat", + format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]), + format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]), + format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), + force_system=True, +) + + +_register_template( + name="orion", + format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]), + format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), + force_system=True, +) + + +_register_template( + name="qwen", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + default_system="You are a helpful assistant.", + stop_words=["<|im_end|>"], + replace_eos=True, +) + + +_register_template( + name="solar", + format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]), + format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]), + efficient_eos=True, +) + + +_register_template( + name="starchat", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + stop_words=["<|end|>"], + replace_eos=True, + force_system=True, +) + + +_register_template( + name="vanilla", + format_separator=EmptyFormatter(slots=["\n"]), + efficient_eos=True, +) + + +_register_template( + name="vicuna", + format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), + default_system=( + 
"A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ), +) + + +_register_template( + name="xuanyuan", + format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]), + default_system=( + "以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头," + "会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、" + "不安全、有争议、政治敏感等相关的话题、问题和指示。\n" + ), +) + + +_register_template( + name="xverse", + format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]), +) + + +_register_template( + name="yayi", + format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]), + format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]), + format_separator=EmptyFormatter(slots=["\n\n"]), + default_system=( + "You are a helpful, respectful and honest assistant named YaYi " + "developed by Beijing Wenge Technology Co.,Ltd. " + "Always answer as helpfully as possible, while being safe. " + "Your answers should not include any harmful, unethical, " + "racist, sexist, toxic, dangerous, or illegal content. " + "Please ensure that your responses are socially unbiased and positive in nature.\n\n" + "If a question does not make any sense, or is not factually coherent, " + "explain why instead of answering something not correct. " + "If you don't know the answer to a question, please don't share false information." + ), + stop_words=["<|End|>"], +) + + +_register_template( + name="yi", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + stop_words=["<|im_end|>"], + replace_eos=True, +) + + +_register_template( + name="yuan", + format_user=StringFormatter(slots=["{{content}}", {"token": ""}]), + format_separator=EmptyFormatter(slots=["\n"]), + stop_words=[""], + replace_eos=True, +) + + +_register_template( + name="zephyr", + format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]), + default_system="You are a friendly chatbot who always responds in the style of a pirate", +) + + +_register_template( + name="ziya", + format_user=StringFormatter(slots=[":{{content}}\n:"]), + format_separator=EmptyFormatter(slots=["\n"]), +) diff --git a/chat/data/utils.py b/chat/data/utils.py new file mode 100644 index 0000000..83ee061 --- /dev/null +++ b/chat/data/utils.py @@ -0,0 +1,94 @@ +import hashlib +from enum import Enum, unique +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +from datasets import concatenate_datasets, interleave_datasets + +from ..extras.logging import get_logger + + +if TYPE_CHECKING: + from datasets import Dataset, IterableDataset + from transformers import Seq2SeqTrainingArguments + + from llmtuner.hparams import DataArguments + + +logger = get_logger(__name__) + + +@unique +class Role(str, Enum): + USER = "user" + ASSISTANT = "assistant" + SYSTEM = "system" + FUNCTION = "function" + OBSERVATION = "observation" + + +def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None: + if file_sha1 is None: + logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.") + return + + if len(data_files) != 1: + logger.warning("Checksum failed: too many files.") + return + + with 
open(data_files[0], "rb") as f: + sha1 = hashlib.sha1(f.read()).hexdigest() + if sha1 != file_sha1: + logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0])) + + +def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]: + max_target_len = int(max_len * (target_len / (source_len + target_len))) + max_target_len = max(max_target_len, reserved_label_len) + max_source_len = max_len - min(max_target_len, target_len) + return max_source_len, max_target_len + + +def merge_dataset( + all_datasets: List[Union["Dataset", "IterableDataset"]], + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", +) -> Union["Dataset", "IterableDataset"]: + if len(all_datasets) == 1: + return all_datasets[0] + elif data_args.mix_strategy == "concat": + if data_args.streaming: + logger.warning("The samples between different datasets will not be mixed in streaming mode.") + return concatenate_datasets(all_datasets) + elif data_args.mix_strategy.startswith("interleave"): + if not data_args.streaming: + logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.") + return interleave_datasets( + datasets=all_datasets, + probabilities=data_args.interleave_probs, + seed=training_args.seed, + stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted", + ) + else: + raise ValueError("Unknown mixing strategy.") + + +def split_dataset( + dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments" +) -> Dict[str, "Dataset"]: + if training_args.do_train: + if data_args.val_size > 1e-6: # Split the dataset + if data_args.streaming: + val_set = dataset.take(int(data_args.val_size)) + train_set = dataset.skip(int(data_args.val_size)) + dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) + return {"train_dataset": train_set, "eval_dataset": val_set} + else: + val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size + dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed) + return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]} + else: + if data_args.streaming: + dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) + return {"train_dataset": dataset} + else: # do_eval or do_predict + return {"eval_dataset": dataset} diff --git a/chat/extras/__init__.py b/chat/extras/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chat/extras/callbacks.py b/chat/extras/callbacks.py new file mode 100644 index 0000000..6e347c3 --- /dev/null +++ b/chat/extras/callbacks.py @@ -0,0 +1,166 @@ +import json +import os +import time +from datetime import timedelta +from typing import TYPE_CHECKING + +from transformers import TrainerCallback +from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length + +from .constants import LOG_FILE_NAME +from .logging import get_logger +from .misc import fix_valuehead_checkpoint + + +if TYPE_CHECKING: + from transformers import TrainerControl, TrainerState, TrainingArguments + + +logger = get_logger(__name__) + + +class FixValueHeadModelCallback(TrainerCallback): + def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called after a checkpoint save. 
+ """ + if args.should_save: + fix_valuehead_checkpoint( + model=kwargs.pop("model"), + output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)), + safe_serialization=args.save_safetensors, + ) + + +class LogCallback(TrainerCallback): + def __init__(self, runner=None): + self.runner = runner + self.in_training = False + self.start_time = time.time() + self.cur_steps = 0 + self.max_steps = 0 + self.elapsed_time = "" + self.remaining_time = "" + + def timing(self): + cur_time = time.time() + elapsed_time = cur_time - self.start_time + avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0 + remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step + self.elapsed_time = str(timedelta(seconds=int(elapsed_time))) + self.remaining_time = str(timedelta(seconds=int(remaining_time))) + + def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the beginning of training. + """ + if state.is_local_process_zero: + self.in_training = True + self.start_time = time.time() + self.max_steps = state.max_steps + + if args.save_on_each_node: + if not state.is_local_process_zero: + return + else: + if not state.is_world_process_zero: + return + + if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir: + logger.warning("Previous log file in this folder will be deleted.") + os.remove(os.path.join(args.output_dir, LOG_FILE_NAME)) + + def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of training. + """ + if state.is_local_process_zero: + self.in_training = False + self.cur_steps = 0 + self.max_steps = 0 + + def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of an substep during gradient accumulation. + """ + if state.is_local_process_zero and self.runner is not None and self.runner.aborted: + control.should_epoch_stop = True + control.should_training_stop = True + + def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of a training step. + """ + if state.is_local_process_zero: + self.cur_steps = state.global_step + self.timing() + if self.runner is not None and self.runner.aborted: + control.should_epoch_stop = True + control.should_training_stop = True + + def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called after an evaluation phase. + """ + if state.is_local_process_zero and not self.in_training: + self.cur_steps = 0 + self.max_steps = 0 + + def on_predict( + self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs + ): + r""" + Event called after a successful prediction. + """ + if state.is_local_process_zero and not self.in_training: + self.cur_steps = 0 + self.max_steps = 0 + + def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None: + r""" + Event called after logging the last logs. 
+ """ + if args.save_on_each_node: + if not state.is_local_process_zero: + return + else: + if not state.is_world_process_zero: + return + + logs = dict( + current_steps=self.cur_steps, + total_steps=self.max_steps, + loss=state.log_history[-1].get("loss", None), + eval_loss=state.log_history[-1].get("eval_loss", None), + predict_loss=state.log_history[-1].get("predict_loss", None), + reward=state.log_history[-1].get("reward", None), + accuracy=state.log_history[-1].get("rewards/accuracies", None), + learning_rate=state.log_history[-1].get("learning_rate", None), + epoch=state.log_history[-1].get("epoch", None), + percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100, + elapsed_time=self.elapsed_time, + remaining_time=self.remaining_time, + ) + if self.runner is not None: + logger.info( + "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format( + logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0 + ) + ) + + os.makedirs(args.output_dir, exist_ok=True) + with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f: + f.write(json.dumps(logs) + "\n") + + def on_prediction_step( + self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs + ): + r""" + Event called after a prediction step. + """ + eval_dataloader = kwargs.pop("eval_dataloader", None) + if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training: + if self.max_steps == 0: + self.max_steps = len(eval_dataloader) + self.cur_steps += 1 + self.timing() diff --git a/chat/extras/constants.py b/chat/extras/constants.py new file mode 100644 index 0000000..2c7f5e5 --- /dev/null +++ b/chat/extras/constants.py @@ -0,0 +1,983 @@ +from collections import OrderedDict, defaultdict +from enum import Enum +from typing import Dict, Optional + + +CHOICES = ["A", "B", "C", "D"] + +DATA_CONFIG = "dataset_info.json" + +DEFAULT_MODULE = defaultdict(str) + +DEFAULT_TEMPLATE = defaultdict(str) + +FILEEXT2TYPE = { + "arrow": "arrow", + "csv": "csv", + "json": "json", + "jsonl": "json", + "parquet": "parquet", + "txt": "text", +} + +IGNORE_INDEX = -100 + +LAYERNORM_NAMES = {"norm", "ln"} + +LOG_FILE_NAME = "trainer_log.jsonl" + +METHODS = ["full", "freeze", "lora"] + +PEFT_METHODS = ["lora"] + +SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"] + +SUPPORTED_MODELS = OrderedDict() + +TRAINING_STAGES = { + "Supervised Fine-Tuning": "sft", + "Reward Modeling": "rm", + "PPO": "ppo", + "DPO": "dpo", + "ORPO": "orpo", + "Pre-Training": "pt", +} + +STAGES_USE_PAIR_DATA = ["rm", "dpo", "orpo"] + +V_HEAD_WEIGHTS_NAME = "value_head.bin" + +V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors" + + +class DownloadSource(str, Enum): + DEFAULT = "hf" + MODELSCOPE = "ms" + + +def register_model_group( + models: Dict[str, Dict[DownloadSource, str]], + module: Optional[str] = None, + template: Optional[str] = None, +) -> None: + prefix = None + for name, path in models.items(): + if prefix is None: + prefix = name.split("-")[0] + else: + assert prefix == name.split("-")[0], "prefix should be identical." 
+ SUPPORTED_MODELS[name] = path + if module is not None: + DEFAULT_MODULE[prefix] = module + if template is not None: + DEFAULT_TEMPLATE[prefix] = template + + +register_model_group( + models={ + "Baichuan-7B-Base": { + DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B", + DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B", + }, + "Baichuan-13B-Base": { + DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base", + }, + "Baichuan-13B-Chat": { + DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat", + }, + }, + module="W_pack", + template="baichuan", +) + + +register_model_group( + models={ + "Baichuan2-7B-Base": { + DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base", + }, + "Baichuan2-13B-Base": { + DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base", + }, + "Baichuan2-7B-Chat": { + DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat", + }, + "Baichuan2-13B-Chat": { + DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat", + }, + }, + module="W_pack", + template="baichuan2", +) + + +register_model_group( + models={ + "BLOOM-560M": { + DownloadSource.DEFAULT: "bigscience/bloom-560m", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m", + }, + "BLOOM-3B": { + DownloadSource.DEFAULT: "bigscience/bloom-3b", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b", + }, + "BLOOM-7B1": { + DownloadSource.DEFAULT: "bigscience/bloom-7b1", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1", + }, + }, + module="query_key_value", +) + + +register_model_group( + models={ + "BLOOMZ-560M": { + DownloadSource.DEFAULT: "bigscience/bloomz-560m", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m", + }, + "BLOOMZ-3B": { + DownloadSource.DEFAULT: "bigscience/bloomz-3b", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b", + }, + "BLOOMZ-7B1-mt": { + DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt", + }, + }, + module="query_key_value", +) + + +register_model_group( + models={ + "BlueLM-7B-Base": { + DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base", + DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base", + }, + "BlueLM-7B-Chat": { + DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat", + DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat", + }, + }, + template="bluelm", +) + + +register_model_group( + models={ + "ChatGLM2-6B-Chat": { + DownloadSource.DEFAULT: "THUDM/chatglm2-6b", + DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b", + } + }, + module="query_key_value", + template="chatglm2", +) + + +register_model_group( + models={ + "ChatGLM3-6B-Base": { + DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base", + DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base", + }, + "ChatGLM3-6B-Chat": { + DownloadSource.DEFAULT: "THUDM/chatglm3-6b", + DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b", + }, + }, + module="query_key_value", + template="chatglm3", +) + + +register_model_group( + models={ + "ChineseLLaMA2-1.3B": { + DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b", + DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b", + }, + "ChineseLLaMA2-7B": { + DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b", + DownloadSource.MODELSCOPE: 
"AI-ModelScope/chinese-llama-2-7b", + }, + "ChineseLLaMA2-13B": { + DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b", + DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b", + }, + "ChineseLLaMA2-1.3B-Chat": { + DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b", + DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b", + }, + "ChineseLLaMA2-7B-Chat": { + DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b", + DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b", + }, + "ChineseLLaMA2-13B-Chat": { + DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b", + DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b", + }, + }, + template="llama2_zh", +) + + +register_model_group( + models={ + "DeepSeek-LLM-7B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base", + }, + "DeepSeek-LLM-67B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base", + }, + "DeepSeek-LLM-7B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat", + }, + "DeepSeek-LLM-67B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat", + }, + "DeepSeek-Math-7B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-base", + }, + "DeepSeek-Math-7B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-instruct", + }, + "DeepSeek-MoE-16B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base", + }, + "DeepSeek-MoE-16B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat", + }, + }, + template="deepseek", +) + + +register_model_group( + models={ + "DeepSeekCoder-6.7B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base", + }, + "DeepSeekCoder-7B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5", + }, + "DeepSeekCoder-33B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base", + }, + "DeepSeekCoder-6.7B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct", + }, + "DeepSeekCoder-7B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5", + }, + "DeepSeekCoder-33B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct", + }, + }, + template="deepseekcoder", +) + + +register_model_group( + models={ + "Falcon-7B": { + DownloadSource.DEFAULT: "tiiuae/falcon-7b", + DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b", + }, + "Falcon-40B": { + DownloadSource.DEFAULT: "tiiuae/falcon-40b", + DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b", + }, + "Falcon-180B": { + DownloadSource.DEFAULT: "tiiuae/falcon-180b", + DownloadSource.MODELSCOPE: "modelscope/falcon-180B", + }, + "Falcon-7B-Chat": { + DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct", + DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct", + }, + "Falcon-40B-Chat": { + DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct", + 
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct", + }, + "Falcon-180B-Chat": { + DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat", + DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat", + }, + }, + module="query_key_value", + template="falcon", +) + + +register_model_group( + models={ + "Gemma-2B": { + DownloadSource.DEFAULT: "google/gemma-2b", + DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b", + }, + "Gemma-7B": { + DownloadSource.DEFAULT: "google/gemma-7b", + DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b-it", + }, + "Gemma-2B-Chat": { + DownloadSource.DEFAULT: "google/gemma-2b-it", + DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b", + }, + "Gemma-7B-Chat": { + DownloadSource.DEFAULT: "google/gemma-7b-it", + DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b-it", + }, + }, + template="gemma", +) + + +register_model_group( + models={ + "InternLM-7B": { + DownloadSource.DEFAULT: "internlm/internlm-7b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b", + }, + "InternLM-20B": { + DownloadSource.DEFAULT: "internlm/internlm-20b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b", + }, + "InternLM-7B-Chat": { + DownloadSource.DEFAULT: "internlm/internlm-chat-7b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b", + }, + "InternLM-20B-Chat": { + DownloadSource.DEFAULT: "internlm/internlm-chat-20b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b", + }, + }, + template="intern", +) + + +register_model_group( + models={ + "InternLM2-7B": { + DownloadSource.DEFAULT: "internlm/internlm2-7b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-7b", + }, + "InternLM2-20B": { + DownloadSource.DEFAULT: "internlm/internlm2-20b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-20b", + }, + "InternLM2-7B-Chat": { + DownloadSource.DEFAULT: "internlm/internlm2-chat-7b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-7b", + }, + "InternLM2-20B-Chat": { + DownloadSource.DEFAULT: "internlm/internlm2-chat-20b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b", + }, + }, + module="wqkv", + template="intern2", +) + + +register_model_group( + models={ + "LingoWhale-8B": { + DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B", + DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B", + } + }, + module="qkv_proj", +) + + +register_model_group( + models={ + "LLaMA-7B": { + DownloadSource.DEFAULT: "huggyllama/llama-7b", + DownloadSource.MODELSCOPE: "skyline2006/llama-7b", + }, + "LLaMA-13B": { + DownloadSource.DEFAULT: "huggyllama/llama-13b", + DownloadSource.MODELSCOPE: "skyline2006/llama-13b", + }, + "LLaMA-30B": { + DownloadSource.DEFAULT: "huggyllama/llama-30b", + DownloadSource.MODELSCOPE: "skyline2006/llama-30b", + }, + "LLaMA-65B": { + DownloadSource.DEFAULT: "huggyllama/llama-65b", + DownloadSource.MODELSCOPE: "skyline2006/llama-65b", + }, + } +) + + +register_model_group( + models={ + "LLaMA2-7B": { + DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms", + }, + "LLaMA2-13B": { + DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms", + }, + "LLaMA2-70B": { + DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms", + }, + "LLaMA2-7B-Chat": { + DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms", + }, + 
"LLaMA2-13B-Chat": { + DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms", + }, + "LLaMA2-70B-Chat": { + DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms", + }, + }, + template="llama2", +) + + +register_model_group( + models={ + "Mistral-7B-v0.1": { + DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1", + }, + "Mistral-7B-v0.1-Chat": { + DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1", + }, + "Mistral-7B-v0.2": { + DownloadSource.DEFAULT: "alpindale/Mistral-7B-v0.2-hf", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.2-hf", + }, + "Mistral-7B-v0.2-Chat": { + DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2", + }, + }, + template="mistral", +) + + +register_model_group( + models={ + "Mixtral-8x7B": { + DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1", + }, + "Mixtral-8x7B-Chat": { + DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1", + }, + }, + template="mistral", +) + + +register_model_group( + models={ + "OLMo-1B": { + DownloadSource.DEFAULT: "allenai/OLMo-1B", + }, + "OLMo-7B": { + DownloadSource.DEFAULT: "allenai/OLMo-7B", + DownloadSource.MODELSCOPE: "AI-ModelScope/OLMo-7B", + }, + "OLMo-7B-Chat": { + DownloadSource.DEFAULT: "allenai/OLMo-7B-Instruct", + }, + }, + module="att_proj", + template="olmo", +) + + +register_model_group( + models={ + "OpenChat3.5-7B-Chat": { + DownloadSource.DEFAULT: "openchat/openchat-3.5-0106", + DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5", + } + }, + template="openchat", +) + + +register_model_group( + models={ + "Orion-14B-Base": { + DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Base", + DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Base", + }, + "Orion-14B-Chat": { + DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat", + DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat", + }, + "Orion-14B-Long-Chat": { + DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-LongChat", + DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-LongChat", + }, + "Orion-14B-RAG-Chat": { + DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-RAG", + DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-RAG", + }, + "Orion-14B-Plugin-Chat": { + DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-Plugin", + DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-Plugin", + }, + }, + template="orion", +) + + +register_model_group( + models={ + "Phi-1.5-1.3B": { + DownloadSource.DEFAULT: "microsoft/phi-1_5", + DownloadSource.MODELSCOPE: "allspace/PHI_1-5", + }, + "Phi-2-2.7B": { + DownloadSource.DEFAULT: "microsoft/phi-2", + DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2", + }, + } +) + + +register_model_group( + models={ + "Qwen-1.8B": { + DownloadSource.DEFAULT: "Qwen/Qwen-1_8B", + DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B", + }, + "Qwen-7B": { + DownloadSource.DEFAULT: "Qwen/Qwen-7B", + DownloadSource.MODELSCOPE: "qwen/Qwen-7B", + }, + "Qwen-14B": { + DownloadSource.DEFAULT: "Qwen/Qwen-14B", + DownloadSource.MODELSCOPE: "qwen/Qwen-14B", + }, + "Qwen-72B": { + DownloadSource.DEFAULT: "Qwen/Qwen-72B", + DownloadSource.MODELSCOPE: 
"qwen/Qwen-72B", + }, + "Qwen-1.8B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat", + }, + "Qwen-7B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat", + }, + "Qwen-14B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat", + }, + "Qwen-72B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat", + }, + "Qwen-1.8B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8", + }, + "Qwen-1.8B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4", + DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4", + }, + "Qwen-7B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8", + }, + "Qwen-7B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4", + DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4", + }, + "Qwen-14B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8", + }, + "Qwen-14B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4", + DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4", + }, + "Qwen-72B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8", + }, + "Qwen-72B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4", + DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4", + }, + }, + module="c_attn", + template="qwen", +) + + +register_model_group( + models={ + "Qwen1.5-0.5B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B", + }, + "Qwen1.5-1.8B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B", + }, + "Qwen1.5-4B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B", + }, + "Qwen1.5-7B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B", + }, + "Qwen1.5-14B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B", + }, + "Qwen1.5-32B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B", + }, + "Qwen1.5-72B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B", + }, + "Qwen1.5-MoE-A2.7B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B", + }, + "Qwen1.5-0.5B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat", + }, + "Qwen1.5-1.8B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat", + }, + "Qwen1.5-4B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat", + }, + "Qwen1.5-7B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat", + }, + "Qwen1.5-14B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat", + }, + "Qwen1.5-32B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat", + }, + "Qwen1.5-72B-Chat": { + DownloadSource.DEFAULT: 
"Qwen/Qwen1.5-72B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat", + }, + "Qwen1.5-MoE-A2.7B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat", + }, + "Qwen1.5-0.5B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8", + }, + "Qwen1.5-0.5B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-AWQ", + }, + "Qwen1.5-1.8B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8", + }, + "Qwen1.5-1.8B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-AWQ", + }, + "Qwen1.5-4B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8", + }, + "Qwen1.5-4B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-AWQ", + }, + "Qwen1.5-7B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8", + }, + "Qwen1.5-7B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-AWQ", + }, + "Qwen1.5-14B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8", + }, + "Qwen1.5-14B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ", + }, + "Qwen1.5-32B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-GPTQ-Int4", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-GPTQ-Int4", + }, + "Qwen1.5-72B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8", + }, + "Qwen1.5-72B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ", + }, + "Qwen1.5-MoE-A2.7B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4", + }, + }, + template="qwen", +) + + +register_model_group( + models={ + "SOLAR-10.7B": { + DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0", + }, + "SOLAR-10.7B-Chat": { + DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0", + DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0", + }, + }, + template="solar", +) + + +register_model_group( + models={ + "Skywork-13B-Base": { + DownloadSource.DEFAULT: "Skywork/Skywork-13B-base", + DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base", + } + } +) + + +register_model_group( + models={ + "StarCoder2-3B": { + DownloadSource.DEFAULT: "bigcode/starcoder2-3b", + }, + "StarCoder2-7B": { + DownloadSource.DEFAULT: "bigcode/starcoder2-7b", + }, + "StarCoder2-15B": { + DownloadSource.DEFAULT: "bigcode/starcoder2-15b", + }, + } +) + + +register_model_group( + models={ + "Vicuna1.5-7B-Chat": { + DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5", + DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5", + }, + "Vicuna1.5-13B-Chat": { + DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5", + DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5", + }, + }, + template="vicuna", +) + + 
+register_model_group( + models={ + "XuanYuan-70B": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B", + }, + "XuanYuan-70B-Chat": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat", + }, + "XuanYuan-70B-int8-Chat": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit", + }, + "XuanYuan-70B-int4-Chat": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit", + }, + }, + template="xuanyuan", +) + + +register_model_group( + models={ + "XVERSE-7B": { + DownloadSource.DEFAULT: "xverse/XVERSE-7B", + DownloadSource.MODELSCOPE: "xverse/XVERSE-7B", + }, + "XVERSE-13B": { + DownloadSource.DEFAULT: "xverse/XVERSE-13B", + DownloadSource.MODELSCOPE: "xverse/XVERSE-13B", + }, + "XVERSE-65B": { + DownloadSource.DEFAULT: "xverse/XVERSE-65B", + DownloadSource.MODELSCOPE: "xverse/XVERSE-65B", + }, + "XVERSE-65B-2": { + DownloadSource.DEFAULT: "xverse/XVERSE-65B-2", + DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2", + }, + "XVERSE-7B-Chat": { + DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat", + DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat", + }, + "XVERSE-13B-Chat": { + DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat", + DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat", + }, + "XVERSE-65B-Chat": { + DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat", + DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat", + }, + }, + template="xverse", +) + + +register_model_group( + models={ + "Yayi-7B": { + DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2", + DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2", + }, + "Yayi-13B": { + DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2", + DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2", + }, + }, + template="yayi", +) + + +register_model_group( + models={ + "Yi-6B": { + DownloadSource.DEFAULT: "01-ai/Yi-6B", + DownloadSource.MODELSCOPE: "01ai/Yi-6B", + }, + "Yi-9B": { + DownloadSource.DEFAULT: "01-ai/Yi-9B", + DownloadSource.MODELSCOPE: "01ai/Yi-9B", + }, + "Yi-34B": { + DownloadSource.DEFAULT: "01-ai/Yi-34B", + DownloadSource.MODELSCOPE: "01ai/Yi-34B", + }, + "Yi-6B-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat", + DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat", + }, + "Yi-34B-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat", + DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat", + }, + "Yi-6B-int8-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits", + DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits", + }, + "Yi-6B-int4-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits", + DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits", + }, + "Yi-34B-int8-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits", + DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits", + }, + "Yi-34B-int4-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits", + DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits", + }, + }, + template="yi", +) + + +register_model_group( + models={ + "Yuan2-2B-Chat": { + DownloadSource.DEFAULT: "IEITYuan/Yuan2-2B-hf", + DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-2B-hf", + }, + "Yuan2-51B-Chat": { + DownloadSource.DEFAULT: "IEITYuan/Yuan2-51B-hf", + DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-51B-hf", + }, + "Yuan2-102B-Chat": { + DownloadSource.DEFAULT: "IEITYuan/Yuan2-102B-hf", + DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-102B-hf", + }, + }, + template="yuan", +) + + +register_model_group( + models={ + "Zephyr-7B-Alpha-Chat": { + DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha", + DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha", + }, + 
"Zephyr-7B-Beta-Chat": { + DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta", + DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta", + }, + }, + template="zephyr", +) + + +register_model_group( + models={ + "Atom-7B": { + DownloadSource.DEFAULT: "FlagAlpha/Atom-7B", + DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B", + }, + "Atom-7B-Chat": { + DownloadSource.DEFAULT: "FlagAlpha/Atom-7B-Chat", + DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B-Chat", + }, + }, + template="atom", +) diff --git a/chat/extras/logging.py b/chat/extras/logging.py new file mode 100644 index 0000000..bb27077 --- /dev/null +++ b/chat/extras/logging.py @@ -0,0 +1,48 @@ +import logging +import sys + + +class LoggerHandler(logging.Handler): + r""" + Logger handler used in Web UI. + """ + + def __init__(self): + super().__init__() + self.log = "" + + def reset(self): + self.log = "" + + def emit(self, record): + if record.name == "httpx": + return + log_entry = self.format(record) + self.log += log_entry + self.log += "\n\n" + + +def get_logger(name: str) -> logging.Logger: + r""" + Gets a standard logger with a stream hander to stdout. + """ + formatter = logging.Formatter( + fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S" + ) + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(formatter) + + logger = logging.getLogger(name) + logger.setLevel(logging.INFO) + logger.addHandler(handler) + + return logger + + +def reset_logging() -> None: + r""" + Removes basic config of root logger. (unused in script) + """ + root = logging.getLogger() + list(map(root.removeHandler, root.handlers)) + list(map(root.removeFilter, root.filters)) diff --git a/chat/extras/misc.py b/chat/extras/misc.py new file mode 100644 index 0000000..0237c33 --- /dev/null +++ b/chat/extras/misc.py @@ -0,0 +1,232 @@ +import gc +import os +from typing import TYPE_CHECKING, Dict, Tuple + +import torch +from peft import PeftModel +from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel +from transformers.utils import ( + SAFE_WEIGHTS_NAME, + WEIGHTS_NAME, + is_torch_bf16_gpu_available, + is_torch_cuda_available, + is_torch_mps_available, + is_torch_npu_available, + #is_torch_xpu_available, +) +from transformers.utils.versions import require_version + +from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME +from .logging import get_logger + + +_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available() +try: + _is_bf16_available = is_torch_bf16_gpu_available() +except Exception: + _is_bf16_available = False + + +if TYPE_CHECKING: + from trl import AutoModelForCausalLMWithValueHead + + from llmtuner.hparams import ModelArguments + + +logger = get_logger(__name__) + + +class AverageMeter: + r""" + Computes and stores the average and current value. 
+ """ + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def check_dependencies() -> None: + if int(os.environ.get("DISABLE_VERSION_CHECK", "0")): + logger.warning("Version checking has been disabled, may lead to unexpected behaviors.") + else: + require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2") + require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3") + require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2") + require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0") + require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1") + require_version("gradio>=4.0.0,<=4.21.0", "To fix: pip install gradio==4.21.0") + + +def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: + r""" + Returns the number of trainable parameters and number of all parameters in the model. + """ + trainable_params, all_param = 0, 0 + for param in model.parameters(): + num_params = param.numel() + # if using DS Zero 3 and the weights are initialized empty + if num_params == 0 and hasattr(param, "ds_numel"): + num_params = param.ds_numel + + # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2 + if param.__class__.__name__ == "Params4bit": + if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"): + num_bytes = param.quant_storage.itemsize + else: + num_bytes = 1 + + num_params = num_params * 2 * num_bytes + + all_param += num_params + if param.requires_grad: + trainable_params += num_params + + return trainable_params, all_param + + +def fix_valuehead_checkpoint( + model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool +) -> None: + r""" + The model is already unwrapped. + + There are three cases: + 1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...} + 2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...} + 3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...} + + We assume `stage3_gather_16bit_weights_on_model_save=true`. 
+ """ + if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)): + return + + if safe_serialization: + from safetensors import safe_open + from safetensors.torch import save_file + + path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME) + with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f: + state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()} + else: + path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME) + state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu") + + decoder_state_dict = {} + v_head_state_dict = {} + for name, param in state_dict.items(): + if name.startswith("v_head."): + v_head_state_dict[name] = param + else: + decoder_state_dict[name.replace("pretrained_model.", "")] = param + + os.remove(path_to_checkpoint) + model.pretrained_model.save_pretrained( + output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization + ) + + if safe_serialization: + save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"}) + else: + torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME)) + + logger.info("Value head model saved at: {}".format(output_dir)) + + +def get_current_device() -> torch.device: + r""" + Gets the current available device. + """ + #if is_torch_xpu_available(): + # device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0")) + if is_torch_npu_available(): + device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0")) + elif is_torch_mps_available(): + device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0")) + elif is_torch_cuda_available(): + device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0")) + else: + device = "cpu" + + return torch.device(device) + + +def get_device_count() -> int: + r""" + Gets the number of available GPU devices. + """ + if not torch.cuda.is_available(): + return 0 + + return torch.cuda.device_count() + + +def get_logits_processor() -> "LogitsProcessorList": + r""" + Gets logits processor that removes NaN and Inf logits. + """ + logits_processor = LogitsProcessorList() + logits_processor.append(InfNanRemoveLogitsProcessor()) + return logits_processor + + +def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype: + r""" + Infers the optimal dtype according to the model_dtype and device compatibility. + """ + if _is_bf16_available and model_dtype == torch.bfloat16: + return torch.bfloat16 + elif _is_fp16_available: + return torch.float16 + else: + return torch.float32 + + +def is_path_available(path: os.PathLike) -> bool: + r""" + Checks if the path is empty or not exist. + """ + if not os.path.exists(path): + return True + elif os.path.isdir(path) and not os.listdir(path): + return True + else: + return False + + +def torch_gc() -> None: + r""" + Collects GPU memory. 
+ """ + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + +def try_download_model_from_ms(model_args: "ModelArguments") -> str: + if not use_modelscope() or os.path.exists(model_args.model_name_or_path): + return model_args.model_name_or_path + + try: + from modelscope import snapshot_download + + revision = "master" if model_args.model_revision == "main" else model_args.model_revision + return snapshot_download(model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir) + except ImportError: + raise ImportError("Please install modelscope via `pip install modelscope -U`") + + +def use_modelscope() -> bool: + return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0"))) diff --git a/chat/extras/packages.py b/chat/extras/packages.py new file mode 100644 index 0000000..cf10ffd --- /dev/null +++ b/chat/extras/packages.py @@ -0,0 +1,61 @@ +import importlib.metadata +import importlib.util + + +def _is_package_available(name: str) -> bool: + return importlib.util.find_spec(name) is not None + + +def _get_package_version(name: str) -> str: + try: + return importlib.metadata.version(name) + except Exception: + return "0.0.0" + + +def is_fastapi_availble(): + return _is_package_available("fastapi") + + +def is_flash_attn2_available(): + return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2") + + +def is_galore_available(): + return _is_package_available("galore_torch") + + +def is_jieba_available(): + return _is_package_available("jieba") + + +def is_matplotlib_available(): + return _is_package_available("matplotlib") + + +def is_nltk_available(): + return _is_package_available("nltk") + + +def is_requests_available(): + return _is_package_available("requests") + + +def is_rouge_available(): + return _is_package_available("rouge_chinese") + + +def is_starlette_available(): + return _is_package_available("sse_starlette") + + +def is_unsloth_available(): + return _is_package_available("unsloth") + + +def is_uvicorn_available(): + return _is_package_available("uvicorn") + + +def is_vllm_available(): + return _is_package_available("vllm") diff --git a/chat/extras/patches/__init__.py b/chat/extras/patches/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chat/extras/patches/llama_patch.py b/chat/extras/patches/llama_patch.py new file mode 100644 index 0000000..6a90c41 --- /dev/null +++ b/chat/extras/patches/llama_patch.py @@ -0,0 +1,198 @@ +import math +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from transformers.models.llama.modeling_llama import ( + Cache, + LlamaAttention, + LlamaFlashAttention2, + apply_rotary_pos_emb, + repeat_kv, +) +from transformers.utils import logging +from transformers.utils.versions import require_version + + +logger = logging.get_logger(__name__) + + +# Modified from: +# https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/llama/modeling_llama.py +def llama_torch_attn_forward( + self: "LlamaAttention", + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional["Cache"] = None, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + 
value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + past_key_value = getattr(self, "past_key_value", past_key_value) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift + groupsz = int(q_len * getattr(self.config, "group_size_ratio")) + assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) + num_groups = q_len // groupsz + + def shift(state: torch.Tensor) -> torch.Tensor: + state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim) + state = torch.cat( + (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)), + dim=2, + ) + return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2) + + query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states) + if attention_mask is not None: + attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :) + attn_output = attn_output.transpose(1, 2).contiguous() + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift back + attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) + attn_output = torch.cat( + ( + attn_output[:, :, : self.num_heads // 2], + attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1), + ) + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Modified from: +# https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/llama/modeling_llama.py +def llama_flash_attn_forward( + self: "LlamaFlashAttention2", + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional["Cache"] = None, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # LlamaFlashAttention2 attention does not support output_attentions + output_attentions = False + + bsz, q_len, _ = 
hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + past_key_value = getattr(self, "past_key_value", past_key_value) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) + key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) + value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once("The input hidden states seems to be silently casted in float32.") + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift + groupsz = int(q_len * getattr(self.config, "group_size_ratio")) + assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) + num_groups = q_len // groupsz + + def shift(state: torch.Tensor) -> torch.Tensor: + state = torch.cat( + (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)), + dim=2, + ) + return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim) + + query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states) + if attention_mask is not None: + attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1) + + attn_output: torch.Tensor = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift back + attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) + attn_output = torch.cat( + ( + attn_output[:, :, : self.num_heads // 2], + attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1), + ) + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def apply_llama_patch() -> None: + require_version("transformers==4.39.3", "To fix: pip install transformers==4.39.3") + 
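+    # NOTE: the hard "==" pin matters because both forwards assigned below mirror
+    # the private attention internals of transformers 4.39 (Cache, cache_position);
+    # the monkey patch may break silently on other releases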
LlamaAttention.forward = llama_torch_attn_forward + LlamaFlashAttention2.forward = llama_flash_attn_forward diff --git a/chat/extras/ploting.py b/chat/extras/ploting.py new file mode 100644 index 0000000..aa101cb --- /dev/null +++ b/chat/extras/ploting.py @@ -0,0 +1,57 @@ +import json +import math +import os +from typing import List + +from transformers.trainer import TRAINER_STATE_NAME + +from .logging import get_logger +from .packages import is_matplotlib_available + + +if is_matplotlib_available(): + import matplotlib.pyplot as plt + + +logger = get_logger(__name__) + + +def smooth(scalars: List[float]) -> List[float]: + r""" + EMA implementation according to TensorBoard. + """ + last = scalars[0] + smoothed = list() + weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function + for next_val in scalars: + smoothed_val = last * weight + (1 - weight) * next_val + smoothed.append(smoothed_val) + last = smoothed_val + return smoothed + + +def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None: + with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f: + data = json.load(f) + + for key in keys: + steps, metrics = [], [] + for i in range(len(data["log_history"])): + if key in data["log_history"][i]: + steps.append(data["log_history"][i]["step"]) + metrics.append(data["log_history"][i][key]) + + if len(metrics) == 0: + logger.warning(f"No metric {key} to plot.") + continue + + plt.figure() + plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original") + plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed") + plt.title("training {} of {}".format(key, save_dictionary)) + plt.xlabel("step") + plt.ylabel(key) + plt.legend() + figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace(os.path.sep, "_"))) + plt.savefig(figure_path, format="png", dpi=100) + print("Figure saved at:", figure_path) diff --git a/chat/hf_engine.py b/chat/hf_engine.py new file mode 100644 index 0000000..36f831d --- /dev/null +++ b/chat/hf_engine.py @@ -0,0 +1,264 @@ +import asyncio +import concurrent.futures +import os +from threading import Thread +from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple + +import torch +from transformers import GenerationConfig, TextIteratorStreamer + +from .data import get_template_and_fix_tokenizer +from .extras.misc import get_logits_processor +from .model import load_model, load_tokenizer +from .base_engine import BaseEngine, Response + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, PreTrainedTokenizer + from trl import PreTrainedModelWrapper + + from ..data import Template + from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments + + +class HuggingfaceEngine(BaseEngine): + def __init__( + self, + model_args: "ModelArguments", + data_args: "DataArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + ) -> None: + self.can_generate = finetuning_args.stage == "sft" + self.tokenizer = load_tokenizer(model_args) + self.tokenizer.padding_side = "left" if self.can_generate else "right" + self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template) + self.model = load_model( + self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate) + ) # must after fixing tokenizer to resize vocab + self.generating_args = generating_args.to_dict() + + @staticmethod + def 
_process_args( + model: "PreTrainedModel", + tokenizer: "PreTrainedTokenizer", + template: "Template", + generating_args: Dict[str, Any], + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + input_kwargs: Optional[Dict[str, Any]] = {}, + ) -> Tuple[Dict[str, Any], int]: + paired_messages = messages + [{"role": "assistant", "content": ""}] + prompt_ids, _ = template.encode_oneturn( + tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools + ) + prompt_length = len(prompt_ids) + inputs = torch.tensor([prompt_ids], device=model.device) + + do_sample = input_kwargs.pop("do_sample", None) + temperature = input_kwargs.pop("temperature", None) + top_p = input_kwargs.pop("top_p", None) + top_k = input_kwargs.pop("top_k", None) + num_return_sequences = input_kwargs.pop("num_return_sequences", None) + repetition_penalty = input_kwargs.pop("repetition_penalty", None) + max_length = input_kwargs.pop("max_length", None) + max_new_tokens = input_kwargs.pop("max_new_tokens", None) + + generating_args.update( + dict( + do_sample=do_sample if do_sample is not None else generating_args["do_sample"], + temperature=temperature or generating_args["temperature"], + top_p=top_p or generating_args["top_p"], + top_k=top_k or generating_args["top_k"], + num_return_sequences=num_return_sequences or 1, + repetition_penalty=repetition_penalty or generating_args["repetition_penalty"], + eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids, + pad_token_id=tokenizer.pad_token_id, + ) + ) + + if isinstance(num_return_sequences, int) and num_return_sequences > 1: + generating_args["do_sample"] = True + + if max_length: + generating_args.pop("max_new_tokens", None) + generating_args["max_length"] = max_length + + if max_new_tokens: + generating_args.pop("max_length", None) + generating_args["max_new_tokens"] = max_new_tokens + + gen_kwargs = dict( + inputs=inputs, + generation_config=GenerationConfig(**generating_args), + logits_processor=get_logits_processor(), + ) + + return gen_kwargs, prompt_length + + @staticmethod + @torch.inference_mode() + def _chat( + model: "PreTrainedModel", + tokenizer: "PreTrainedTokenizer", + template: "Template", + generating_args: Dict[str, Any], + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + input_kwargs: Optional[Dict[str, Any]] = {}, + ) -> List["Response"]: + gen_kwargs, prompt_length = HuggingfaceEngine._process_args( + model, tokenizer, template, generating_args, messages, system, tools, input_kwargs + ) + generate_output = model.generate(**gen_kwargs) + response_ids = generate_output[:, prompt_length:] + response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) + results = [] + for i in range(len(response)): + eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero() + response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i]) + results.append( + Response( + response_text=response[i], + response_length=response_length, + prompt_length=prompt_length, + finish_reason="stop" if len(eos_index) else "length", + ) + ) + + return results + + @staticmethod + @torch.inference_mode() + def _stream_chat( + model: "PreTrainedModel", + tokenizer: "PreTrainedTokenizer", + template: "Template", + generating_args: Dict[str, Any], + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + input_kwargs: Optional[Dict[str, 
Any]] = {}, + ) -> Callable[[], str]: + gen_kwargs, _ = HuggingfaceEngine._process_args( + model, tokenizer, template, generating_args, messages, system, tools, input_kwargs + ) + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) + gen_kwargs["streamer"] = streamer + thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True) + thread.start() + + def stream(): + try: + return streamer.__next__() + except StopIteration: + raise StopAsyncIteration() + + return stream + + @staticmethod + @torch.inference_mode() + def _get_scores( + model: "PreTrainedModelWrapper", + tokenizer: "PreTrainedTokenizer", + batch_input: List[str], + input_kwargs: Optional[Dict[str, Any]] = {}, + ) -> List[float]: + max_length = input_kwargs.pop("max_length", None) + device = getattr(model.pretrained_model, "device", "cuda") + inputs = tokenizer( + batch_input, + padding=True, + truncation=True, + max_length=max_length or getattr(model.config, "max_position_embeddings", 1024), + return_tensors="pt", + add_special_tokens=True, + ).to(device) + + input_ids: torch.Tensor = inputs["input_ids"] + _, _, values = model(**inputs, output_hidden_states=True, return_dict=True) + + if getattr(model.config, "model_type", None) == "chatglm": + values = torch.transpose(values, 0, 1) + + scores = [] + for i in range(input_ids.size(0)): + end_indexes = (input_ids[i] != tokenizer.pad_token_id).nonzero() + end_index = end_indexes[-1].item() if len(end_indexes) else 0 + scores.append(values[i, end_index].nan_to_num().item()) + + return scores + + async def start(self) -> None: + self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1))) + + async def chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + **input_kwargs, + ) -> List["Response"]: + if not self.can_generate: + raise ValueError("The current model does not support `chat`.") + + loop = asyncio.get_running_loop() + input_args = ( + self.model, + self.tokenizer, + self.template, + self.generating_args, + messages, + system, + tools, + input_kwargs, + ) + async with self._semaphore: + with concurrent.futures.ThreadPoolExecutor() as pool: + return await loop.run_in_executor(pool, self._chat, *input_args) + + async def stream_chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: + if not self.can_generate: + raise ValueError("The current model does not support `stream_chat`.") + + loop = asyncio.get_running_loop() + input_args = ( + self.model, + self.tokenizer, + self.template, + self.generating_args, + messages, + system, + tools, + input_kwargs, + ) + async with self._semaphore: + with concurrent.futures.ThreadPoolExecutor() as pool: + stream = self._stream_chat(*input_args) + while True: + try: + yield await loop.run_in_executor(pool, stream) + except StopAsyncIteration: + break + + async def get_scores( + self, + batch_input: List[str], + **input_kwargs, + ) -> List[float]: + if self.can_generate: + raise ValueError("Cannot get scores using an auto-regressive model.") + + loop = asyncio.get_running_loop() + input_args = (self.model, self.tokenizer, batch_input, input_kwargs) + async with self._semaphore: + with concurrent.futures.ThreadPoolExecutor() as pool: + return await loop.run_in_executor(pool, self._get_scores, *input_args) diff --git a/chat/hparams/__init__.py b/chat/hparams/__init__.py new file mode 100644 index 0000000..d1ee98d 
--- /dev/null +++ b/chat/hparams/__init__.py @@ -0,0 +1,18 @@ +from .data_args import DataArguments +from .evaluation_args import EvaluationArguments +from .finetuning_args import FinetuningArguments +from .generating_args import GeneratingArguments +from .model_args import ModelArguments +from .parser import get_eval_args, get_infer_args, get_train_args + + +__all__ = [ + "DataArguments", + "EvaluationArguments", + "FinetuningArguments", + "GeneratingArguments", + "ModelArguments", + "get_eval_args", + "get_infer_args", + "get_train_args", +] diff --git a/chat/hparams/data_args.py b/chat/hparams/data_args.py new file mode 100644 index 0000000..f5f75c7 --- /dev/null +++ b/chat/hparams/data_args.py @@ -0,0 +1,100 @@ +from dataclasses import dataclass, field +from typing import Literal, Optional + + +@dataclass +class DataArguments: + r""" + Arguments pertaining to what data we are going to input our model for training and evaluation. + """ + + template: Optional[str] = field( + default=None, + metadata={"help": "Which template to use for constructing prompts in training and inference."}, + ) + dataset: Optional[str] = field( + default=None, + metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}, + ) + dataset_dir: str = field( + default="data", + metadata={"help": "Path to the folder containing the datasets."}, + ) + split: str = field( + default="train", + metadata={"help": "Which dataset split to use for training and evaluation."}, + ) + cutoff_len: int = field( + default=1024, + metadata={"help": "The cutoff length of the model inputs after tokenization."}, + ) + reserved_label_len: int = field( + default=1, + metadata={"help": "The minimum cutoff length reserved for label after tokenization."}, + ) + train_on_prompt: bool = field( + default=False, + metadata={"help": "Whether to disable the mask on the prompt or not."}, + ) + streaming: bool = field( + default=False, + metadata={"help": "Enable dataset streaming."}, + ) + buffer_size: int = field( + default=16384, + metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}, + ) + mix_strategy: Literal["concat", "interleave_under", "interleave_over"] = field( + default="concat", + metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}, + ) + interleave_probs: Optional[str] = field( + default=None, + metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}, + ) + overwrite_cache: bool = field( + default=False, + metadata={"help": "Overwrite the cached training and evaluation sets."}, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the pre-processing."}, + ) + max_samples: Optional[int] = field( + default=None, + metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}, + ) + eval_num_beams: Optional[int] = field( + default=None, + metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}, + ) + ignore_pad_token_for_loss: bool = field( + default=True, + metadata={ + "help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation." 
+ }, + ) + val_size: float = field( + default=0.0, + metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}, + ) + packing: Optional[bool] = field( + default=None, + metadata={ + "help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training." + }, + ) + tokenized_path: Optional[str] = field( + default=None, + metadata={"help": "Path to save or load the tokenized datasets."}, + ) + + def __post_init__(self): + if self.reserved_label_len >= self.cutoff_len: + raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.") + + if self.streaming and self.val_size > 1e-6 and self.val_size < 1: + raise ValueError("Streaming mode should have an integer val size.") + + if self.streaming and self.max_samples is not None: + raise ValueError("`max_samples` is incompatible with `streaming`.") diff --git a/chat/hparams/evaluation_args.py b/chat/hparams/evaluation_args.py new file mode 100644 index 0000000..5a05f6f --- /dev/null +++ b/chat/hparams/evaluation_args.py @@ -0,0 +1,48 @@ +import os +from dataclasses import dataclass, field +from typing import Literal, Optional + +from datasets import DownloadMode + + +@dataclass +class EvaluationArguments: + r""" + Arguments pertaining to specify the evaluation parameters. + """ + + task: str = field( + metadata={"help": "Name of the evaluation task."}, + ) + task_dir: str = field( + default="evaluation", + metadata={"help": "Path to the folder containing the evaluation datasets."}, + ) + batch_size: int = field( + default=4, + metadata={"help": "The batch size per GPU for evaluation."}, + ) + seed: int = field( + default=42, + metadata={"help": "Random seed to be used with data loaders."}, + ) + lang: Literal["en", "zh"] = field( + default="en", + metadata={"help": "Language used at evaluation."}, + ) + n_shot: int = field( + default=5, + metadata={"help": "Number of examplars for few-shot learning."}, + ) + save_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to save the evaluation results."}, + ) + download_mode: DownloadMode = field( + default=DownloadMode.REUSE_DATASET_IF_EXISTS, + metadata={"help": "Download mode used for the evaluation datasets."}, + ) + + def __post_init__(self): + if self.save_dir is not None and os.path.exists(self.save_dir): + raise ValueError("`save_dir` already exists, use another one.") diff --git a/chat/hparams/finetuning_args.py b/chat/hparams/finetuning_args.py new file mode 100644 index 0000000..177a9f8 --- /dev/null +++ b/chat/hparams/finetuning_args.py @@ -0,0 +1,276 @@ +import json +from dataclasses import asdict, dataclass, field +from typing import Literal, Optional + + +@dataclass +class FreezeArguments: + r""" + Arguments pertaining to the freeze (partial-parameter) training. + """ + + name_module_trainable: str = field( + default="all", + metadata={ + "help": """Name of trainable modules for partial-parameter (freeze) fine-tuning. \ + Use commas to separate multiple modules. \ + Use "all" to specify all the available modules. 
\ + LLaMA choices: ["mlp", "self_attn"], \ + BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \ + Qwen choices: ["mlp", "attn"], \ + InternLM2 choices: ["feed_forward", "attention"], \ + Others choices: the same as LLaMA.""" + }, + ) + num_layer_trainable: int = field( + default=2, + metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."}, + ) + + +@dataclass +class LoraArguments: + r""" + Arguments pertaining to the LoRA training. + """ + + additional_target: Optional[str] = field( + default=None, + metadata={ + "help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint." + }, + ) + lora_alpha: Optional[int] = field( + default=None, + metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}, + ) + lora_dropout: float = field( + default=0.0, + metadata={"help": "Dropout rate for the LoRA fine-tuning."}, + ) + lora_rank: int = field( + default=8, + metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}, + ) + lora_target: str = field( + default="all", + metadata={ + "help": """Name(s) of target modules to apply LoRA. \ + Use commas to separate multiple modules. \ + Use "all" to specify all the linear modules. \ + LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \ + BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \ + Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \ + Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \ + InternLM2 choices: ["wqkv", "wo", "w1", "w2", "w3"], \ + Others choices: the same as LLaMA.""" + }, + ) + loraplus_lr_ratio: Optional[float] = field( + default=None, + metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."}, + ) + loraplus_lr_embedding: float = field( + default=1e-6, + metadata={"help": "LoRA plus learning rate for lora embedding layers."}, + ) + use_rslora: bool = field( + default=False, + metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."}, + ) + use_dora: bool = field( + default=False, + metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."}, + ) + create_new_adapter: bool = field( + default=False, + metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}, + ) + + +@dataclass +class RLHFArguments: + r""" + Arguments pertaining to the PPO and DPO training. 
+ """ + + dpo_beta: float = field( + default=0.1, + metadata={"help": "The beta parameter for the DPO loss."}, + ) + dpo_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = field( + default="sigmoid", + metadata={"help": "The type of DPO loss to use."}, + ) + dpo_label_smoothing: float = field( + default=0.0, + metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."}, + ) + dpo_ftx: float = field( + default=0.0, + metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}, + ) + orpo_beta: float = field( + default=0.1, + metadata={"help": "The beta (lambda) parameter in ORPO loss representing the weight of the SFT loss."}, + ) + ppo_buffer_size: int = field( + default=1, + metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."}, + ) + ppo_epochs: int = field( + default=4, + metadata={"help": "The number of epochs to perform in a PPO optimization step."}, + ) + ppo_score_norm: bool = field( + default=False, + metadata={"help": "Use score normalization in PPO training."}, + ) + ppo_target: float = field( + default=6.0, + metadata={"help": "Target KL value for adaptive KL control in PPO training."}, + ) + ppo_whiten_rewards: bool = field( + default=False, + metadata={"help": "Whiten the rewards before compute advantages in PPO training."}, + ) + ref_model: Optional[str] = field( + default=None, + metadata={"help": "Path to the reference model used for the PPO or DPO training."}, + ) + ref_model_adapters: Optional[str] = field( + default=None, + metadata={"help": "Path to the adapters of the reference model."}, + ) + ref_model_quantization_bit: Optional[int] = field( + default=None, + metadata={"help": "The number of bits to quantize the reference model."}, + ) + reward_model: Optional[str] = field( + default=None, + metadata={"help": "Path to the reward model used for the PPO training."}, + ) + reward_model_adapters: Optional[str] = field( + default=None, + metadata={"help": "Path to the adapters of the reward model."}, + ) + reward_model_quantization_bit: Optional[int] = field( + default=None, + metadata={"help": "The number of bits to quantize the reward model."}, + ) + reward_model_type: Literal["lora", "full", "api"] = field( + default="lora", + metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."}, + ) + + +@dataclass +class GaloreArguments: + r""" + Arguments pertaining to the GaLore algorithm. + """ + + use_galore: bool = field( + default=False, + metadata={"help": "Whether or not to use gradient low-Rank projection."}, + ) + galore_target: str = field( + default="all", + metadata={ + "help": """Name(s) of modules to apply GaLore. Use commas to separate multiple modules. 
\ + Use "all" to specify all the linear modules.""" + }, + ) + galore_rank: int = field( + default=16, + metadata={"help": "The rank of GaLore gradients."}, + ) + galore_update_interval: int = field( + default=200, + metadata={"help": "Number of steps to update the GaLore projection."}, + ) + galore_scale: float = field( + default=0.25, + metadata={"help": "GaLore scaling coefficient."}, + ) + galore_proj_type: Literal["std", "reverse_std", "right", "left", "full"] = field( + default="std", + metadata={"help": "Type of GaLore projection."}, + ) + galore_layerwise: bool = field( + default=False, + metadata={"help": "Whether or not to enable layer-wise update to further save memory."}, + ) + + +@dataclass +class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments): + r""" + Arguments pertaining to which techniques we are going to fine-tuning with. + """ + + pure_bf16: bool = field( + default=False, + metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."}, + ) + stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo"] = field( + default="sft", + metadata={"help": "Which stage will be performed in training."}, + ) + finetuning_type: Literal["lora", "freeze", "full"] = field( + default="lora", + metadata={"help": "Which fine-tuning method to use."}, + ) + use_llama_pro: bool = field( + default=False, + metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."}, + ) + plot_loss: bool = field( + default=False, + metadata={"help": "Whether or not to save the training loss curves."}, + ) + + def __post_init__(self): + def split_arg(arg): + if isinstance(arg, str): + return [item.strip() for item in arg.split(",")] + return arg + + self.name_module_trainable = split_arg(self.name_module_trainable) + self.lora_alpha = self.lora_alpha or self.lora_rank * 2 + self.lora_target = split_arg(self.lora_target) + self.additional_target = split_arg(self.additional_target) + self.galore_target = split_arg(self.galore_target) + + assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method." + assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." + assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." 
+ + if self.stage == "ppo" and self.reward_model is None: + raise ValueError("`reward_model` is necessary for PPO training.") + + if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora": + raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.") + + if self.stage == "dpo" and self.dpo_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6: + raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.") + + if self.use_llama_pro and self.finetuning_type == "full": + raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.") + + if self.use_galore and self.finetuning_type == "lora": + raise ValueError("Cannot use LoRA with GaLore together.") + + def save_to_json(self, json_path: str): + r"""Saves the content of this instance in JSON format inside `json_path`.""" + json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n" + with open(json_path, "w", encoding="utf-8") as f: + f.write(json_string) + + @classmethod + def load_from_json(cls, json_path: str): + r"""Creates an instance from the content of `json_path`.""" + with open(json_path, "r", encoding="utf-8") as f: + text = f.read() + + return cls(**json.loads(text)) diff --git a/chat/hparams/generating_args.py b/chat/hparams/generating_args.py new file mode 100644 index 0000000..7755d34 --- /dev/null +++ b/chat/hparams/generating_args.py @@ -0,0 +1,67 @@ +from dataclasses import asdict, dataclass, field +from typing import Any, Dict + + +@dataclass +class GeneratingArguments: + r""" + Arguments pertaining to specify the decoding parameters. + """ + + do_sample: bool = field( + default=True, + metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}, + ) + temperature: float = field( + default=0.95, + metadata={"help": "The value used to modulate the next token probabilities."}, + ) + top_p: float = field( + default=0.7, + metadata={ + "help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept." + }, + ) + top_k: int = field( + default=50, + metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}, + ) + num_beams: int = field( + default=1, + metadata={"help": "Number of beams for beam search. 1 means no beam search."}, + ) + max_length: int = field( + default=512, + metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}, + ) + max_new_tokens: int = field( + default=512, + metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}, + ) + repetition_penalty: float = field( + default=1.0, + metadata={"help": "The parameter for repetition penalty. 
1.0 means no penalty."},
+    )
+    length_penalty: float = field(
+        default=1.0,
+        metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
+    )
+
+    # QA generation args
+    filepath: str = field(
+        default="",
+        metadata={"help": "The file path of the raw input lines."},
+    )
+    prompt_path: str = field(
+        default="",
+        metadata={"help": "The file path of the QA generation prompt."},
+    )
+
+
+    def to_dict(self) -> Dict[str, Any]:
+        args = asdict(self)
+        if args.get("max_new_tokens", -1) > 0:
+            args.pop("max_length", None)
+        else:
+            args.pop("max_new_tokens", None)
+        return args
diff --git a/chat/hparams/model_args.py b/chat/hparams/model_args.py
new file mode 100644
index 0000000..be71d32
--- /dev/null
+++ b/chat/hparams/model_args.py
@@ -0,0 +1,179 @@
+from dataclasses import asdict, dataclass, field
+from typing import Any, Dict, Literal, Optional
+
+
+@dataclass
+class ModelArguments:
+    r"""
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
+    """
+
+    model_name_or_path: str = field(
+        metadata={
+            "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
+        },
+    )
+    adapter_name_or_path: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."},
+    )
+    cache_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
+    )
+    use_fast_tokenizer: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use one of the fast tokenizers (backed by the tokenizers library)."},
+    )
+    resize_vocab: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."},
+    )
+    split_special_tokens: bool = field(
+        default=False,
+        metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
+    )
+    model_revision: str = field(
+        default="main",
+        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+    )
+    low_cpu_mem_usage: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to use memory-efficient model loading."},
+    )
+    quantization_bit: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of bits to quantize the model using bitsandbytes."},
+    )
+    quantization_type: Literal["fp4", "nf4"] = field(
+        default="nf4",
+        metadata={"help": "Quantization data type to use in int4 training."},
+    )
+    double_quantization: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to use double quantization in int4 training."},
+    )
+    quantization_device_map: Optional[Literal["auto"]] = field(
+        default=None,
+        metadata={"help": "Device map used for loading the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
+    )
+    rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
+        default=None,
+        metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
+    )
+    flash_attn: bool = field(
+        default=False,
+        metadata={"help": "Enable FlashAttention-2 for faster training."},
+    )
+    shift_attn: bool = field(
+        default=False,
+        metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
+    )
+    use_unsloth: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
+    )
+    moe_aux_loss_coef: Optional[float] = field(
+        default=None,
+        metadata={"help":
"Coefficient of the auxiliary router loss in mixture-of-experts model."}, + ) + disable_gradient_checkpointing: bool = field( + default=False, + metadata={"help": "Whether or not to disable gradient checkpointing."}, + ) + upcast_layernorm: bool = field( + default=False, + metadata={"help": "Whether or not to upcast the layernorm weights in fp32."}, + ) + upcast_lmhead_output: bool = field( + default=False, + metadata={"help": "Whether or not to upcast the output of lm_head in fp32."}, + ) + infer_backend: Literal["huggingface", "vllm"] = field( + default="huggingface", + metadata={"help": "Backend engine used at inference."}, + ) + vllm_maxlen: int = field( + default=2048, + metadata={"help": "Maximum input length of the vLLM engine."}, + ) + vllm_gpu_util: float = field( + default=0.9, + metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."}, + ) + vllm_enforce_eager: bool = field( + default=False, + metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."}, + ) + offload_folder: str = field( + default="offload", + metadata={"help": "Path to offload model weights."}, + ) + use_cache: bool = field( + default=True, + metadata={"help": "Whether or not to use KV cache in generation."}, + ) + hf_hub_token: Optional[str] = field( + default=None, + metadata={"help": "Auth token to log in with Hugging Face Hub."}, + ) + ms_hub_token: Optional[str] = field( + default=None, + metadata={"help": "Auth token to log in with ModelScope Hub."}, + ) + export_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to the directory to save the exported model."}, + ) + export_size: int = field( + default=1, + metadata={"help": "The file shard size (in GB) of the exported model."}, + ) + export_quantization_bit: Optional[int] = field( + default=None, + metadata={"help": "The number of bits to quantize the exported model."}, + ) + export_quantization_dataset: Optional[str] = field( + default=None, + metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}, + ) + export_quantization_nsamples: int = field( + default=128, + metadata={"help": "The number of samples used for quantization."}, + ) + export_quantization_maxlen: int = field( + default=1024, + metadata={"help": "The maximum length of the model inputs used for quantization."}, + ) + export_legacy_format: bool = field( + default=False, + metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."}, + ) + export_hub_model_id: Optional[str] = field( + default=None, + metadata={"help": "The name of the repository if push the model to the Hugging Face hub."}, + ) + print_param_status: bool = field( + default=False, + metadata={"help": "For debugging purposes, print the status of the parameters in the model."}, + ) + + def __post_init__(self): + self.compute_dtype = None + self.device_map = None + self.model_max_length = None + + if self.split_special_tokens and self.use_fast_tokenizer: + raise ValueError("`split_special_tokens` is only supported for slow tokenizers.") + + if self.adapter_name_or_path is not None: # support merging multiple lora weights + self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")] + + assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." + assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization." 
+ + if self.export_quantization_bit is not None and self.export_quantization_dataset is None: + raise ValueError("Quantization dataset is necessary for exporting.") + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) diff --git a/chat/hparams/parser.py b/chat/hparams/parser.py new file mode 100644 index 0000000..9264d1e --- /dev/null +++ b/chat/hparams/parser.py @@ -0,0 +1,302 @@ +import logging +import os +import sys +from typing import Any, Dict, Optional, Tuple + +import torch +import transformers +from transformers import HfArgumentParser, Seq2SeqTrainingArguments +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils import is_torch_bf16_gpu_available + +from ..extras.logging import get_logger +from ..extras.misc import check_dependencies +from ..extras.packages import is_unsloth_available +from .data_args import DataArguments +from .evaluation_args import EvaluationArguments +from .finetuning_args import FinetuningArguments +from .generating_args import GeneratingArguments +from .model_args import ModelArguments + + +logger = get_logger(__name__) + + +check_dependencies() + + +_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments] +_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments] +_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] +_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] +_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] +_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] + + +def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]: + if args is not None: + return parser.parse_dict(args) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): + return parser.parse_yaml_file(os.path.abspath(sys.argv[1])) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + return parser.parse_json_file(os.path.abspath(sys.argv[1])) + + (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True) + + if unknown_args: + print(parser.format_help()) + print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args)) + raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args)) + + return (*parsed_args,) + + +def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None: + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + +def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None: + if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora": + raise ValueError("Adapter is only valid for the LoRA method.") + + if model_args.quantization_bit is not None: + if finetuning_args.finetuning_type != "lora": + raise ValueError("Quantization is only compatible with the LoRA method.") + + if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter: + raise ValueError("Cannot create new adapter upon a quantized model.") + + if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1: + raise ValueError("Quantized model only accepts a single adapter. 
Merge them first.") + + +def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: + parser = HfArgumentParser(_TRAIN_ARGS) + return _parse_args(parser, args) + + +def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: + parser = HfArgumentParser(_INFER_ARGS) + return _parse_args(parser, args) + + +def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS: + parser = HfArgumentParser(_EVAL_ARGS) + return _parse_args(parser, args) + + +def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: + model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args) + + # Setup logging + if training_args.should_log: + _set_transformers_logging() + + # Check arguments + if finetuning_args.stage != "pt" and data_args.template is None: + raise ValueError("Please specify which `template` to use.") + + if finetuning_args.stage != "sft" and training_args.predict_with_generate: + raise ValueError("`predict_with_generate` cannot be set as True except SFT.") + + if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate: + raise ValueError("Please enable `predict_with_generate` to save model predictions.") + + if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end: + raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.") + + if finetuning_args.stage == "ppo" and not training_args.do_train: + raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.") + + if finetuning_args.stage == "ppo" and model_args.shift_attn: + raise ValueError("PPO training is incompatible with S^2-Attn.") + + if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth: + raise ValueError("Unsloth does not support lora reward model.") + + if ( + finetuning_args.stage == "ppo" + and training_args.report_to + and training_args.report_to[0] not in ["wandb", "tensorboard"] + ): + raise ValueError("PPO only accepts wandb or tensorboard logger.") + + if training_args.max_steps == -1 and data_args.streaming: + raise ValueError("Please specify `max_steps` in streaming mode.") + + if training_args.do_train and training_args.predict_with_generate: + raise ValueError("`predict_with_generate` cannot be set as True while training.") + + if training_args.do_train and model_args.use_unsloth and not is_unsloth_available(): + raise ValueError("Unsloth was not installed: https://github.com/unslothai/unsloth") + + if finetuning_args.use_dora and model_args.use_unsloth: + raise ValueError("Unsloth does not support DoRA.") + + if finetuning_args.pure_bf16: + if not is_torch_bf16_gpu_available(): + raise ValueError("This device does not support `pure_bf16`.") + + if training_args.fp16 or training_args.bf16: + raise ValueError("Turn off mixed precision training when using `pure_bf16`.") + + if ( + finetuning_args.use_galore + and finetuning_args.galore_layerwise + and training_args.parallel_mode.value == "distributed" + ): + raise ValueError("Distributed training does not support layer-wise GaLore.") + + if finetuning_args.use_galore and training_args.deepspeed is not None: + raise ValueError("GaLore is incompatible with DeepSpeed.") + + if model_args.infer_backend == "vllm": + raise ValueError("vLLM backend is only available for API, CLI and Web.") + + _verify_model_args(model_args, finetuning_args) + + if ( + training_args.do_train + and finetuning_args.finetuning_type 
== "lora" + and model_args.resize_vocab + and finetuning_args.additional_target is None + ): + logger.warning("Add token embeddings to `additional_target` to make the added tokens trainable.") + + if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm): + logger.warning("We recommend enable `upcast_layernorm` in quantized training.") + + if training_args.do_train and (not training_args.fp16) and (not training_args.bf16): + logger.warning("We recommend enable mixed precision training.") + + if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16: + logger.warning("Using GaLore with mixed precision training may significantly increases GPU memory usage.") + + if (not training_args.do_train) and model_args.quantization_bit is not None: + logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.") + + if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None: + logger.warning("Specify `ref_model` for computing rewards at evaluation.") + + # Post-process training arguments + if ( + training_args.parallel_mode.value == "distributed" + and training_args.ddp_find_unused_parameters is None + and finetuning_args.finetuning_type == "lora" + ): + logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.") + training_args.ddp_find_unused_parameters = False + + if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]: + can_resume_from_checkpoint = False + if training_args.resume_from_checkpoint is not None: + logger.warning("Cannot resume from checkpoint in current stage.") + training_args.resume_from_checkpoint = None + else: + can_resume_from_checkpoint = True + + if ( + training_args.resume_from_checkpoint is None + and training_args.do_train + and os.path.isdir(training_args.output_dir) + and not training_args.overwrite_output_dir + and can_resume_from_checkpoint + ): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.") + + if last_checkpoint is not None: + training_args.resume_from_checkpoint = last_checkpoint + logger.info( + "Resuming training from {}. 
Change `output_dir` or use `overwrite_output_dir` to avoid.".format( + training_args.resume_from_checkpoint + ) + ) + + if ( + finetuning_args.stage in ["rm", "ppo"] + and finetuning_args.finetuning_type == "lora" + and training_args.resume_from_checkpoint is not None + ): + logger.warning( + "Add {} to `adapter_name_or_path` to resume training from checkpoint.".format( + training_args.resume_from_checkpoint + ) + ) + + # Post-process model arguments + if training_args.bf16 or finetuning_args.pure_bf16: + model_args.compute_dtype = torch.bfloat16 + elif training_args.fp16: + model_args.compute_dtype = torch.float16 + + model_args.model_max_length = data_args.cutoff_len + data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt" + + # Log on each process the small summary: + logger.info( + "Process rank: {}, device: {}, n_gpu: {}, distributed training: {}, compute dtype: {}".format( + training_args.local_rank, + training_args.device, + training_args.n_gpu, + training_args.parallel_mode.value == "distributed", + str(model_args.compute_dtype), + ) + ) + + transformers.set_seed(training_args.seed) + + return model_args, data_args, training_args, finetuning_args, generating_args + + +def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: + model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args) + + _set_transformers_logging() + + if data_args.template is None: + raise ValueError("Please specify which `template` to use.") + + if model_args.infer_backend == "vllm": + if finetuning_args.stage != "sft": + raise ValueError("vLLM engine only supports auto-regressive models.") + + if model_args.adapter_name_or_path is not None: + raise ValueError("vLLM engine does not support LoRA adapters. 
Merge them first.") + + if model_args.quantization_bit is not None: + raise ValueError("vLLM engine does not support quantization.") + + if model_args.rope_scaling is not None: + raise ValueError("vLLM engine does not support RoPE scaling.") + + _verify_model_args(model_args, finetuning_args) + + model_args.device_map = "auto" + + return model_args, data_args, finetuning_args, generating_args + + +def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS: + model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args) + + _set_transformers_logging() + + if data_args.template is None: + raise ValueError("Please specify which `template` to use.") + + if model_args.infer_backend == "vllm": + raise ValueError("vLLM backend is only available for API, CLI and Web.") + + _verify_model_args(model_args, finetuning_args) + + model_args.device_map = "auto" + + transformers.set_seed(eval_args.seed) + + return model_args, data_args, eval_args, finetuning_args diff --git a/chat/model/__init__.py b/chat/model/__init__.py new file mode 100644 index 0000000..1eaf427 --- /dev/null +++ b/chat/model/__init__.py @@ -0,0 +1,10 @@ +from .loader import load_model, load_tokenizer +from .utils import find_all_linear_modules, load_valuehead_params + + +__all__ = [ + "load_model", + "load_tokenizer", + "load_valuehead_params", + "find_all_linear_modules", +] diff --git a/chat/model/adapter.py b/chat/model/adapter.py new file mode 100644 index 0000000..eb6d387 --- /dev/null +++ b/chat/model/adapter.py @@ -0,0 +1,166 @@ +from typing import TYPE_CHECKING + +import torch +from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model +from transformers.integrations import is_deepspeed_zero3_enabled + +from ..extras.logging import get_logger +from .utils import QuantizationMethod, find_all_linear_modules, find_expanded_modules + + +if TYPE_CHECKING: + from transformers.modeling_utils import PreTrainedModel + + from ..hparams import FinetuningArguments, ModelArguments + + +logger = get_logger(__name__) + + +def init_adapter( + model: "PreTrainedModel", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", is_trainable: bool +) -> "PreTrainedModel": + r""" + Initializes the adapters. + + Support full-parameter, freeze and LoRA training. + + Note that the trainable parameters must be cast to float32. 
+ """ + + if (not is_trainable) and model_args.adapter_name_or_path is None: + logger.info("Adapter is not found at evaluation, load the base model.") + return model + + if finetuning_args.finetuning_type == "full" and is_trainable: + logger.info("Fine-tuning method: Full") + if not finetuning_args.pure_bf16: + model = model.float() + + if finetuning_args.finetuning_type == "freeze" and is_trainable: + logger.info("Fine-tuning method: Freeze") + num_layers = ( + getattr(model.config, "num_hidden_layers", None) + or getattr(model.config, "num_layers", None) + or getattr(model.config, "n_layer", None) + ) + if not num_layers: + raise ValueError("Current model does not support freeze tuning.") + + if finetuning_args.use_llama_pro: + if num_layers % finetuning_args.num_layer_trainable != 0: + raise ValueError( + "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format( + num_layers, finetuning_args.num_layer_trainable + ) + ) + + stride = num_layers // finetuning_args.num_layer_trainable + trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride) + elif finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0 + trainable_layer_ids = range(num_layers - finetuning_args.num_layer_trainable, num_layers) + else: # fine-tuning the first n layers if num_layer_trainable < 0 + trainable_layer_ids = range(-finetuning_args.num_layer_trainable) + + freeze_modules = {"all"} + for name, _ in model.named_modules(): + if ".0." in name: + freeze_modules.add(name.split(".0.")[-1].split(".")[0]) + + trainable_layers = [] + for module_name in finetuning_args.name_module_trainable: + if module_name not in freeze_modules: + raise ValueError( + "Module {} is not found, please choose from {}".format(module_name, ", ".join(freeze_modules)) + ) + + for idx in trainable_layer_ids: + trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else "")) + + for name, param in model.named_parameters(): + if any(trainable_layer in name for trainable_layer in trainable_layers): + if not finetuning_args.pure_bf16: + param.data = param.data.to(torch.float32) + else: + param.requires_grad_(False) + + logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids)))) + + if finetuning_args.finetuning_type == "lora": + logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA")) + adapter_to_resume = None + + if model_args.adapter_name_or_path is not None: + is_mergeable = True + if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable + assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter." + is_mergeable = False + + if is_deepspeed_zero3_enabled(): + assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3." 
+ is_mergeable = False + + if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable): + adapter_to_merge = model_args.adapter_name_or_path[:-1] + adapter_to_resume = model_args.adapter_name_or_path[-1] + else: + adapter_to_merge = model_args.adapter_name_or_path + + for adapter in adapter_to_merge: + model: "LoraModel" = PeftModel.from_pretrained( + model, adapter, offload_folder=model_args.offload_folder + ) + model = model.merge_and_unload() + + if len(adapter_to_merge) > 0: + logger.info("Merged {} adapter(s).".format(len(adapter_to_merge))) + + if adapter_to_resume is not None: # resume lora training + model = PeftModel.from_pretrained( + model, adapter_to_resume, is_trainable=is_trainable, offload_folder=model_args.offload_folder + ) + + if is_trainable and adapter_to_resume is None: # create new lora weights while training + if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all": + target_modules = find_all_linear_modules(model) + else: + target_modules = finetuning_args.lora_target + + if finetuning_args.use_llama_pro: + target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable) + + if finetuning_args.use_dora and getattr(model, "quantization_method", None) is not None: + if getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES: + raise ValueError("DoRA is not compatible with PTQ-quantized models.") + + peft_kwargs = { + "r": finetuning_args.lora_rank, + "target_modules": target_modules, + "lora_alpha": finetuning_args.lora_alpha, + "lora_dropout": finetuning_args.lora_dropout, + "use_rslora": finetuning_args.use_rslora, + } + + if model_args.use_unsloth: + from unsloth import FastLanguageModel # type: ignore + + unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length} + model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) + else: + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + modules_to_save=finetuning_args.additional_target, + use_dora=finetuning_args.use_dora, + **peft_kwargs, + ) + model = get_peft_model(model, lora_config) + + if not finetuning_args.pure_bf16: + for param in filter(lambda p: p.requires_grad, model.parameters()): + param.data = param.data.to(torch.float32) + + if model_args.adapter_name_or_path is not None: + logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path))) + + return model diff --git a/chat/model/loader.py b/chat/model/loader.py new file mode 100644 index 0000000..e91a7b6 --- /dev/null +++ b/chat/model/loader.py @@ -0,0 +1,135 @@ +from typing import TYPE_CHECKING, Any, Dict + +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from trl import AutoModelForCausalLMWithValueHead + +from ..extras.logging import get_logger +from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms +from .adapter import init_adapter +from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model +from .utils import load_valuehead_params, register_autoclass + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, PreTrainedTokenizer + + from ..hparams import FinetuningArguments, ModelArguments + + +logger = get_logger(__name__) + + +def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]: + model_args.model_name_or_path = try_download_model_from_ms(model_args) + return { + "trust_remote_code": True, + "cache_dir": 
model_args.cache_dir, + "revision": model_args.model_revision, + "token": model_args.hf_hub_token, + } + + +def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer": + r""" + Loads pretrained tokenizer. Must before load_model. + + Note: including inplace operation of model_args. + """ + init_kwargs = _get_init_kwargs(model_args) + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + use_fast=model_args.use_fast_tokenizer, + split_special_tokens=model_args.split_special_tokens, + padding_side="right", + **init_kwargs, + ) + patch_tokenizer(tokenizer) + return tokenizer + + +def load_model( + tokenizer: "PreTrainedTokenizer", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: bool = False, + add_valuehead: bool = False, +) -> "PreTrainedModel": + r""" + Loads pretrained model. Must after load_tokenizer. + """ + init_kwargs = _get_init_kwargs(model_args) + config = AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs) + patch_config(config, tokenizer, model_args, init_kwargs, is_trainable) + + model = None + if is_trainable and model_args.use_unsloth: + from unsloth import FastLanguageModel # type: ignore + + unsloth_kwargs = { + "model_name": model_args.model_name_or_path, + "max_seq_length": model_args.model_max_length, + "dtype": model_args.compute_dtype, + "load_in_4bit": model_args.quantization_bit == 4, + "token": model_args.hf_hub_token, + "device_map": {"": get_current_device()}, + "rope_scaling": getattr(config, "rope_scaling", None), + } + try: + model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) + except NotImplementedError: + logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) + model_args.use_unsloth = False + + if model_args.adapter_name_or_path: + model_args.adapter_name_or_path = None + logger.warning("Unsloth does not support loading adapters.") + + if model is None: + model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs) + + patch_model(model, tokenizer, model_args, is_trainable) + register_autoclass(config, model, tokenizer) + + model = init_adapter(model, model_args, finetuning_args, is_trainable) + + if add_valuehead: + model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model) + patch_valuehead_model(model) + + if model_args.adapter_name_or_path is not None: + vhead_path = model_args.adapter_name_or_path[-1] + else: + vhead_path = model_args.model_name_or_path + + vhead_params = load_valuehead_params(vhead_path, model_args) + if vhead_params is not None: + model.load_state_dict(vhead_params, strict=False) + logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path)) + + if not is_trainable: + model.requires_grad_(False) + model.eval() + for param in model.parameters(): + if param.device.type == "cuda": + param.data = param.data.to(model_args.compute_dtype) + else: + model.train() + + trainable_params, all_param = count_parameters(model) + if is_trainable: + param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format( + trainable_params, all_param, 100 * trainable_params / all_param + ) + else: + param_stats = "all params: {:d}".format(all_param) + logger.info(param_stats) + + if model_args.print_param_status: + for name, param in model.named_parameters(): + print( + "name: {}, dtype: {}, device: {}, trainable: {}".format( + name, param.dtype, param.device, param.requires_grad + ) 
+ ) + + return model diff --git a/chat/model/patcher.py b/chat/model/patcher.py new file mode 100644 index 0000000..434a3a8 --- /dev/null +++ b/chat/model/patcher.py @@ -0,0 +1,387 @@ +import math +import os +import random +from contextlib import nullcontext +from types import MethodType +from typing import TYPE_CHECKING, Any, Dict, List, Tuple + +import torch +from datasets import load_dataset +from peft import PeftModel +from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.utils.versions import require_version + +from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES +from ..extras.logging import get_logger +from ..extras.misc import get_current_device, infer_optim_dtype +from ..extras.packages import is_flash_attn2_available +from ..extras.patches.llama_patch import apply_llama_patch +from .utils import QuantizationMethod, add_z3_leaf_module + + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedTokenizer + from trl import AutoModelForCausalLMWithValueHead + + from ..hparams import ModelArguments + + +logger = get_logger(__name__) +SUPPORTED_CLASS_FOR_S2ATTN = ["llama"] + + +def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]: + r""" + Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133 + TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600 + """ + if os.path.isfile(model_args.export_quantization_dataset): + data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None) + data_files = model_args.export_quantization_dataset + else: + data_path = model_args.export_quantization_dataset + data_files = None + + dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir) + maxlen = model_args.export_quantization_maxlen + + samples = [] + for _ in range(model_args.export_quantization_nsamples): + while True: + sample_idx = random.randint(0, len(dataset) - 1) + sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt") + if sample["input_ids"].size(1) >= maxlen: + break # TODO: fix large maxlen + + word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1) + input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen] + samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True)) + + return samples + + +def _configure_attn_implementation( + config: "PretrainedConfig", model_args: "ModelArguments", init_kwargs: Dict[str, Any] +) -> None: + if model_args.flash_attn: + if not is_flash_attn2_available(): + logger.warning("FlashAttention2 is not installed.") + return + + logger.info("Using FlashAttention-2 for faster training and inference.") + if getattr(config, "model_type", None) == "internlm2": # special case for custom models + setattr(config, "attn_implementation", "flash_attention_2") + else: + init_kwargs["attn_implementation"] = "flash_attention_2" + else: + init_kwargs["attn_implementation"] = "eager" + + +def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: + if model_args.rope_scaling is None: + return + + if not hasattr(config, "rope_scaling"): + logger.warning("Current model does not support RoPE scaling.") + return + + if is_trainable: + if model_args.rope_scaling == "dynamic": + logger.warning( + "Dynamic NTK 
scaling may not work well with fine-tuning. " + "See: https://github.com/huggingface/transformers/pull/24653" + ) + + current_max_length = getattr(config, "max_position_embeddings", None) + if current_max_length and model_args.model_max_length > current_max_length: + scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length)) + else: + logger.warning("Input length is smaller than max length. Consider increase input length.") + scaling_factor = 1.0 + else: + scaling_factor = 2.0 + + setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor}) + logger.info( + "Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor) + ) + + +def _configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: + if not is_trainable or not model_args.shift_attn: + return + + if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN: + setattr(config, "group_size_ratio", 0.25) + apply_llama_patch() + logger.info("Using shift short attention with group_size_ratio=1/4.") + else: + logger.warning("Current model does not support shift short attention.") + + +def _configure_quantization( + config: "PretrainedConfig", + tokenizer: "PreTrainedTokenizer", + model_args: "ModelArguments", + init_kwargs: Dict[str, Any], +) -> None: + r""" + Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training) + """ + if getattr(config, "quantization_config", None): # ptq + if is_deepspeed_zero3_enabled(): + raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.") + + init_kwargs["device_map"] = {"": get_current_device()} + quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None) + quant_method = quantization_config.get("quant_method", "") + + if quant_method == QuantizationMethod.GPTQ: + require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") + quantization_config["use_exllama"] = False # disable exllama + + if quant_method == QuantizationMethod.AWQ: + require_version("autoawq", "To fix: pip install autoawq") + + if quant_method == QuantizationMethod.AQLM: + require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0") + require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0") + quantization_config["bits"] = 2 + + quant_bits = quantization_config.get("bits", "?") + logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper())) + + elif model_args.export_quantization_bit is not None: # auto-gptq + require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0") + require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") + from accelerate.utils import get_max_memory + + if getattr(config, "model_type", None) == "chatglm": + raise ValueError("ChatGLM model is not supported.") + + init_kwargs["quantization_config"] = GPTQConfig( + bits=model_args.export_quantization_bit, + tokenizer=tokenizer, + dataset=_get_quantization_dataset(tokenizer, model_args), + ) + init_kwargs["device_map"] = "auto" + init_kwargs["max_memory"] = get_max_memory() + logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit)) + + elif model_args.quantization_bit is not None: # bnb + if model_args.quantization_bit == 8: + require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") + init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) + + elif 
model_args.quantization_bit == 4: + require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") + init_kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=model_args.compute_dtype, + bnb_4bit_use_double_quant=model_args.double_quantization, + bnb_4bit_quant_type=model_args.quantization_type, + bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp qlora + ) + + if is_deepspeed_zero3_enabled() or model_args.quantization_device_map == "auto": + if model_args.quantization_bit != 4: + raise ValueError("Only 4-bit quantized model can use auto device map.") + + require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0") + require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0") + require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0") + else: + init_kwargs["device_map"] = {"": get_current_device()} + + logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) + + +def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int): + embedding_dim = embed_weight.size(1) + avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True) + noise_weight = torch.empty_like(embed_weight[-num_new_tokens:]) + noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim))) + embed_weight[-num_new_tokens:] = avg_weight + noise_weight + + +def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None: + r""" + Resize token embeddings. + """ + if is_deepspeed_zero3_enabled(): + import deepspeed # type: ignore + + params = [model.get_input_embeddings().weight] + if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings: + params.append(model.get_output_embeddings().weight) + + context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0) + else: + context_maybe_zero3 = nullcontext() + + with context_maybe_zero3: + current_embedding_size = model.get_input_embeddings().weight.size(0) + + if len(tokenizer) > current_embedding_size: + if not isinstance(model.get_output_embeddings(), torch.nn.Linear): + logger.warning("Current model does not support resizing token embeddings.") + return + + model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64) + with context_maybe_zero3: + new_embedding_size = model.get_input_embeddings().weight.size(0) + num_new_tokens = new_embedding_size - current_embedding_size + _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens) + _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens) + + logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size)) + + +def _fp32_forward_post_hook( + module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor" +) -> "torch.Tensor": + return output.to(torch.float32) + + +def _prepare_model_for_training( + model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head" +) -> None: + r""" + Includes: + (1) cast the layernorm in fp32 + (2) make output embedding layer require grads + (3) add the upcasting of the lm_head in fp32 + Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72 + """ + if model_args.upcast_layernorm: + logger.info("Upcasting layernorm weights in float32.") + for name, param in model.named_parameters(): + if param.ndim == 1 and any(ln_name in name for ln_name in 
LAYERNORM_NAMES): + param.data = param.data.to(torch.float32) + + if not model_args.disable_gradient_checkpointing: + if not getattr(model, "supports_gradient_checkpointing", False): + logger.warning("Current model does not support gradient checkpointing.") + else: + # use_reentrant=False might increase VRAM usage (have not been empirically verified yet) + # According to: https://github.com/huggingface/transformers/issues/28339 + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True}) + model.enable_input_require_grads() + setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled + logger.info("Gradient checkpointing enabled.") + + if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output: + logger.info("Upcasting lm_head outputs in float32.") + output_layer = getattr(model, output_layer_name) + if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32: + output_layer.register_forward_hook(_fp32_forward_post_hook) + + +def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None: + if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__): + tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer) + + +def patch_config( + config: "PretrainedConfig", + tokenizer: "PreTrainedTokenizer", + model_args: "ModelArguments", + init_kwargs: Dict[str, Any], + is_trainable: bool, +) -> None: + if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32 + model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) + + _configure_attn_implementation(config, model_args, init_kwargs) + _configure_rope(config, model_args, is_trainable) + _configure_longlora(config, model_args, is_trainable) + _configure_quantization(config, tokenizer, model_args, init_kwargs) + + if model_args.use_cache and not is_trainable: + setattr(config, "use_cache", True) + logger.info("Using KV cache for faster generation.") + + if model_args.moe_aux_loss_coef is not None: + if getattr(config, "model_type", None) in ["mixtral", "qwen2_moe"]: + setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef) + elif getattr(config, "model_type", None) == "deepseek": + setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef) + + if getattr(config, "model_type", None) == "qwen": + setattr(config, "use_flash_attn", model_args.flash_attn) + for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]: + setattr(config, dtype_name, model_args.compute_dtype == dtype) + + if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn: + setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flashattn + + init_kwargs["torch_dtype"] = model_args.compute_dtype + if not is_deepspeed_zero3_enabled(): + init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage + if init_kwargs["low_cpu_mem_usage"]: + if "device_map" not in init_kwargs: + init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()} + + if init_kwargs["device_map"] == "auto": + init_kwargs["offload_folder"] = model_args.offload_folder + + +def patch_model( + model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool +) -> None: + gen_config = model.generation_config # check and fix generation config + if not gen_config.do_sample and ( + (gen_config.temperature is not None and gen_config.temperature != 1.0) + or 
(gen_config.top_p is not None and gen_config.top_p != 1.0) + or (gen_config.typical_p is not None and gen_config.typical_p != 1.0) + ): + gen_config.do_sample = True + + if "GenerationMixin" not in str(model.generate.__func__): + model.generate = MethodType(PreTrainedModel.generate, model) + + if is_trainable and getattr(model.config, "model_type", None) == "chatglm": + setattr(model, "lm_head", model.transformer.output_layer) + setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"]) + + if model_args.resize_vocab: + _resize_embedding_layer(model, tokenizer) + + if is_trainable: + _prepare_model_for_training(model, model_args) + + if getattr(model.config, "model_type", None) == "mixtral": + from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + + add_z3_leaf_module(model, MixtralSparseMoeBlock) + + if getattr(model.config, "model_type", None) == "qwen2moe": + from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock + + add_z3_leaf_module(model, Qwen2MoeSparseMoeBlock) + + try: + model.add_model_tags(["llama-factory"]) + except Exception: + logger.warning("Cannot properly tag the model.") + + +def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None: + def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None: + if isinstance(self.pretrained_model, PreTrainedModel): + self.pretrained_model.tie_weights() + + def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module: + if isinstance(self.pretrained_model, PreTrainedModel): + return self.pretrained_model.get_input_embeddings() + + def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None: + if isinstance(self.pretrained_model, PeftModel): + self.pretrained_model.create_or_update_model_card(output_dir) + + ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name] + setattr(model, "_keys_to_ignore_on_save", ignore_modules) + setattr(model, "tie_weights", MethodType(tie_weights, model)) + setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model)) + setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model)) diff --git a/chat/model/utils.py b/chat/model/utils.py new file mode 100644 index 0000000..771e611 --- /dev/null +++ b/chat/model/utils.py @@ -0,0 +1,137 @@ +from enum import Enum, unique +from typing import TYPE_CHECKING, Dict, List + +import torch +from transformers import PreTrainedModel +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.utils import cached_file +from transformers.utils.versions import require_version + +from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME +from ..extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedTokenizer + + from ..hparams import ModelArguments + + +logger = get_logger(__name__) + + +@unique +class QuantizationMethod(str, Enum): + r""" + Borrowed from `transformers.utils.quantization_config.QuantizationMethod`. + """ + + BITS_AND_BYTES = "bitsandbytes" + GPTQ = "gptq" + AWQ = "awq" + AQLM = "aqlm" + QUANTO = "quanto" + + +def add_z3_leaf_module(model: "PreTrainedModel", module: "torch.nn.Module") -> None: + r""" + Sets module as a leaf module to skip partitioning in deepspeed zero3. 
+ """ + if is_deepspeed_zero3_enabled(): + require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0") + from deepspeed.utils import set_z3_leaf_modules # type: ignore + + set_z3_leaf_modules(model, [module]) + + +def find_all_linear_modules(model: "PreTrainedModel") -> List[str]: + r""" + Finds all available modules to apply lora or galore. + """ + quantization_method = getattr(model, "quantization_method", None) + if quantization_method is None: + linear_cls = torch.nn.Linear + elif quantization_method == QuantizationMethod.BITS_AND_BYTES: + import bitsandbytes as bnb + + linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt + else: + raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method)) + + output_layer_names = ["lm_head"] + if model.config.model_type == "chatglm": + output_layer_names.append("output_layer") + elif model.config.model_type == "internlm2": + output_layer_names.append("output") + + module_names = set() + for name, module in model.named_modules(): + if isinstance(module, linear_cls) and not any(output_layer in name for output_layer in output_layer_names): + module_names.add(name.split(".")[-1]) + + logger.info("Found linear modules: {}".format(",".join(module_names))) + return list(module_names) + + +def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]: + r""" + Finds the modules in the expanded blocks to apply lora. + """ + num_layers = getattr(model.config, "num_hidden_layers", None) + if not num_layers: + raise ValueError("Model was not supported.") + + if num_layers % num_layer_trainable != 0: + raise ValueError( + "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable) + ) + + stride = num_layers // num_layer_trainable + trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride) + trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids] + module_names = [] + for name, _ in model.named_modules(): + if any(target_module in name for target_module in target_modules) and any( + trainable_layer in name for trainable_layer in trainable_layers + ): + module_names.append(name) + + logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids)))) + return module_names + + +def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]: + r""" + Loads value head parameters from Hugging Face Hub or local disk. + + Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`. 
+ """ + kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token} + + try: + from safetensors import safe_open + + vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs) + with safe_open(vhead_file, framework="pt", device="cpu") as f: + return {key: f.get_tensor(key) for key in f.keys()} + except Exception as err: + logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err))) + + try: + vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs) + return torch.load(vhead_file, map_location="cpu") + except Exception as err: + logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err))) + + logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id)) + logger.info("Ignore these messages if you are not resuming the training of a value head model.") + return None + + +def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"): + if "AutoConfig" in getattr(config, "auto_map", {}): + config.__class__.register_for_auto_class() + if "AutoModelForCausalLM" in getattr(config, "auto_map", {}): + model.__class__.register_for_auto_class() + if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}): + tokenizer.__class__.register_for_auto_class() diff --git a/chat/vllm_engine.py b/chat/vllm_engine.py new file mode 100644 index 0000000..fb5c3c3 --- /dev/null +++ b/chat/vllm_engine.py @@ -0,0 +1,149 @@ +import uuid +from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence + +from transformers.utils.versions import require_version + +from .data import get_template_and_fix_tokenizer +from .extras.misc import get_device_count +from .extras.packages import is_vllm_available +from .model import load_tokenizer +from .base_engine import BaseEngine, Response + + +if is_vllm_available(): + from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams + +if TYPE_CHECKING: + from .hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments + + +class VllmEngine(BaseEngine): + def __init__( + self, + model_args: "ModelArguments", + data_args: "DataArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + ) -> None: + require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3") + self.can_generate = finetuning_args.stage == "sft" + engine_args = AsyncEngineArgs( + model=model_args.model_name_or_path, + trust_remote_code=True, + max_model_len=model_args.vllm_maxlen, + tensor_parallel_size=get_device_count() or 1, + gpu_memory_utilization=model_args.vllm_gpu_util, + disable_log_stats=True, + disable_log_requests=True, + enforce_eager=model_args.vllm_enforce_eager, + ) + self.model = AsyncLLMEngine.from_engine_args(engine_args) + self.tokenizer = load_tokenizer(model_args) + self.tokenizer.padding_side = "left" + self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template) + self.generating_args = generating_args.to_dict() + + async def _generate( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + **input_kwargs, + ) -> AsyncIterator["RequestOutput"]: + request_id = "chatcmpl-{}".format(uuid.uuid4().hex) + paired_messages = messages + [{"role": "assistant", "content": ""}] + prompt_ids, _ = self.template.encode_oneturn( + tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools + ) + 
prompt_length = len(prompt_ids) + + temperature = input_kwargs.pop("temperature", None) + top_p = input_kwargs.pop("top_p", None) + top_k = input_kwargs.pop("top_k", None) + num_return_sequences = input_kwargs.pop("num_return_sequences", None) + repetition_penalty = input_kwargs.pop("repetition_penalty", None) + max_length = input_kwargs.pop("max_length", None) + max_new_tokens = input_kwargs.pop("max_new_tokens", None) + + generating_args = self.generating_args.copy() + generating_args.update( + dict( + temperature=temperature or generating_args["temperature"], + top_p=top_p or generating_args["top_p"], + top_k=top_k or generating_args["top_k"], + num_return_sequences=num_return_sequences or 1, + repetition_penalty=repetition_penalty or generating_args["repetition_penalty"], + ) + ) + + if max_length: + generating_args["max_new_tokens"] = max_length - prompt_length + + if max_new_tokens: + generating_args["max_new_tokens"] = max_new_tokens + + sampling_params = SamplingParams( + n=generating_args["num_return_sequences"], + repetition_penalty=generating_args["repetition_penalty"], + temperature=generating_args["temperature"], + top_p=generating_args["top_p"], + top_k=generating_args["top_k"], + use_beam_search=generating_args["num_beams"] > 1, + length_penalty=generating_args["length_penalty"], + stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids, + max_tokens=generating_args["max_new_tokens"], + skip_special_tokens=True, + ) + result_generator = self.model.generate( + prompt=None, sampling_params=sampling_params, request_id=request_id, prompt_token_ids=prompt_ids + ) + return result_generator + + async def start(self) -> None: + pass + + async def chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + **input_kwargs, + ) -> List["Response"]: + final_output = None + generator = await self._generate(messages, system, tools, **input_kwargs) + async for request_output in generator: + final_output = request_output + + results = [] + for output in final_output.outputs: + results.append( + Response( + response_text=output.text, + response_length=len(output.token_ids), + prompt_length=len(final_output.prompt_token_ids), + finish_reason=output.finish_reason, + ) + ) + + return results + + async def stream_chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: + generated_text = "" + generator = await self._generate(messages, system, tools, **input_kwargs) + async for result in generator: + delta_text = result.outputs[0].text[len(generated_text) :] + generated_text = result.outputs[0].text + yield delta_text + + async def get_scores( + self, + batch_input: List[str], + **input_kwargs, + ) -> List[float]: + raise NotImplementedError("vLLM engine does not support get_scores.") diff --git a/demo_QAMaking_from_pdf.py b/demo_QAMaking_from_pdf.py new file mode 100644 index 0000000..5ae6335 --- /dev/null +++ b/demo_QAMaking_from_pdf.py @@ -0,0 +1,117 @@ +import gradio as gr +from gradio_pdf import PDF +import fitz +from chat import ChatModel + + +def _launch_demo(model, template): + + def loadpdf(path): + print(path) + doc = fitz.open(path) + txt = "" + for page in doc: + text = page.get_text() + txt += text + "\n" + return txt + + def makequestion(content, evt: gr.SelectData): + # print(content) + # print(evt.value) + return evt.value + + def makeselection(content, evt: gr.SelectData): + # 
print(content)
+        # print(evt.value)
+        return evt.value
+
+
+
+    def add_text(history, task_history, query):
+        def formatcontent(prompt: str, query: str) -> str:
+            qa_num = 3
+            p = prompt.format(count=qa_num, text=query)
+            return p
+        history += [(formatcontent(template, query), None)]
+        task_history += [(query, None)]
+        #history = history + [(_parse_text(text), None)]
+        #task_history = task_history + [(task_text, None)]
+        print(history)
+        return history, task_history, ""
+
+
+    def makeQA(_chatbot, task_history):
+        #import pdb
+        #pdb.set_trace()
+        chat_query = _chatbot[-1][0]
+        query = task_history[-1][0]
+        print("User: " + chat_query)
+
+        message = {"role": "user", "content": chat_query}
+        response = ""
+        for new_text in model.stream_chat([message]):
+            response += new_text
+            print(response)
+            _chatbot[-1] = (chat_query, response)
+            #print(_chatbot)
+            #_chatbot += (p, new_text)
+            yield _chatbot
+        return _chatbot
+
+
+    def reset_user_input():
+        return gr.update(value="")
+
+
+    def reset_state(task_history):
+        task_history.clear()
+        return []
+
+
+
+    with gr.Blocks() as demo:
+        with gr.Row():
+            pdf = PDF(label="Upload a PDF", interactive=True)
+            with gr.Column():
+                name = gr.Textbox(label="info")
+                content = gr.Textbox("出题的内容")
+                with gr.Row():
+                    run = gr.Button("Make Question")
+                    empty_bin = gr.Button("Clear")
+            name.select(makeselection, name, content)
+            pdf.upload(loadpdf, pdf, name)
+        chatbot = gr.Chatbot(label="QA")
+        task_history = gr.State([])
+        run.click(add_text, [chatbot, task_history, content], [chatbot, task_history]).then(
+            makeQA, [chatbot, task_history], [chatbot], show_progress=True
+        )
+        run.click(reset_user_input, [], [content])
+        #run.click(reset_user_input, [], [chatbot])
+        empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
+        #run.click(makeQA, [chatbot, content], [chatbot], show_progress=True)
+
+    demo.launch(server_name="0.0.0.0", server_port=8501)
+
+
+if __name__ == '__main__':
+
+    import argparse
+
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--model_name_or_path', type=str, default="/workspace/mnt/storage/zhaozhijian/pt-fs/Qwen1___5-7B-Chat/")
+    parser.add_argument('--template', type=str, default="qwen")
+    parser.add_argument('--prompt_path', type=str, default="template/QAtemplat.txt")
+
+    args = parser.parse_args()
+
+    with open(args.prompt_path, 'r') as f:
+        template = ''.join(f.readlines())
+
+    #print(vars(args).keys())
+
+    model = ChatModel(vars(args))
+
+    _launch_demo(model, template)
+
+
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..b997182
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,19 @@
+torch>=1.13.1
+transformers>=4.37.2
+datasets>=2.14.3
+accelerate>=0.27.2
+peft>=0.10.0
+trl>=0.8.1
+gradio>=4.0.0,<=4.21.0
+scipy
+einops
+sentencepiece
+protobuf
+uvicorn
+pydantic
+fastapi
+sse-starlette
+matplotlib
+fire
+gradio_pdf
+PyMuPDF
diff --git a/template/QAtemplat.txt b/template/QAtemplat.txt
new file mode 100644
index 0000000..05ddb25
--- /dev/null
+++ b/template/QAtemplat.txt
@@ -0,0 +1,21 @@
+你是一个聪明的助手，旨在提出有意义的问答对。问题应该切中要点，答案应该尽可能充实详细。
+给定一段文本，不要编造东西，根据文本实际内容从中提炼出可用于评估的问题和答案对。
+问题和答案对格式如下:
+```
+[
+{{
+ "question": "$YOUR_QUESTION_HERE",
+ "A": "$YOUR_CHOICE_HERE",
+ "B": "$YOUR_CHOICE_HERE",
+ "C": "$YOUR_CHOICE_HERE",
+ "D": "$YOUR_CHOICE_HERE",
+ "answer": "A/B/C/D:$THE_ANSWER_HERE"
+}}
+]
+```
+
+在 ``` 中间的文字必须是有效的json格式.
+
+请从如下的内容中提炼出至少{count}个的问答对，格式是json:
+----------------
+{text}
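
Side note, not part of the patch: the demo streams the model's raw reply straight into the Gradio chatbot and never parses it, even though template/QAtemplat.txt asks for a fenced block of valid JSON. Below is a minimal sketch of how a caller might validate that reply before using it. `parse_qa_pairs` is a hypothetical helper, and it assumes the model really wraps its answer in a fenced code block as the prompt instructs, falling back to the whole reply otherwise.

```python
import json
import re
from typing import Dict, List

FENCE = "`" * 3  # the prompt template asks the model to wrap its JSON inside a fenced block


def parse_qa_pairs(reply: str) -> List[Dict[str, str]]:
    """Extract the list of multiple-choice QA items from a model reply."""
    # Take the text inside the first fenced block if there is one, otherwise the whole reply.
    match = re.search(FENCE + r"(?:json)?\s*(.*?)" + FENCE, reply, flags=re.DOTALL)
    payload = match.group(1) if match else reply
    items = json.loads(payload)

    # Keep only items that carry every field the template asks for.
    required_keys = {"question", "A", "B", "C", "D", "answer"}
    return [item for item in items if required_keys.issubset(item)]


if __name__ == "__main__":
    sample = FENCE + '[{"question": "示例?", "A": "甲", "B": "乙", "C": "丙", "D": "丁", "answer": "A"}]' + FENCE
    print(parse_qa_pairs(sample))
```

A stricter variant could also check that `answer` names one of the four choices before handing the pairs to whatever evaluation pipeline consumes them.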