From ccc82b93ac56037c033762f8edd598d0b9e7facf Mon Sep 17 00:00:00 2001
From: Ying Sheng
Date: Fri, 15 Sep 2023 05:59:17 +0000
Subject: [PATCH] add dtype and seed

---
 fastchat/llm_judge/gen_model_answer.py |  4 ++++
 fastchat/model/model_adapter.py        |  9 +++++++++
 fastchat/serve/cli.py                  | 10 ++++++++++
 fastchat/serve/inference.py            |  2 ++
 fastchat/serve/model_worker.py         | 28 +++++++++++++++++++++++++-
 fastchat/utils.py                      | 11 +++++++++++
 6 files changed, 63 insertions(+), 1 deletion(-)

diff --git a/fastchat/llm_judge/gen_model_answer.py b/fastchat/llm_judge/gen_model_answer.py
index 3d093ecd5..219c74b4d 100644
--- a/fastchat/llm_judge/gen_model_answer.py
+++ b/fastchat/llm_judge/gen_model_answer.py
@@ -29,6 +29,7 @@ def run_eval(
     num_gpus_per_model,
     num_gpus_total,
     max_gpu_memory,
+    dtype=None,
 ):
     questions = load_questions(question_file, question_begin, question_end)
     # random shuffle the questions to balance the loading
@@ -58,6 +59,7 @@ def run_eval(
                 num_choices,
                 num_gpus_per_model,
                 max_gpu_memory,
+                dtype=dtype,
             )
         )
 
@@ -75,12 +77,14 @@ def get_model_answers(
     num_choices,
     num_gpus_per_model,
     max_gpu_memory,
+    dtype=None,
 ):
     model, tokenizer = load_model(
         model_path,
         device="cuda",
         num_gpus=num_gpus_per_model,
         max_gpu_memory=max_gpu_memory,
+        dtype=dtype,
         load_8bit=False,
         cpu_offloading=False,
         debug=False,
diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index 296b53c8f..87a007953 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -152,6 +152,7 @@ def load_model(
     device: str = "cuda",
     num_gpus: int = 1,
     max_gpu_memory: Optional[str] = None,
+    dtype=None,
     load_8bit: bool = False,
     cpu_offloading: bool = False,
     gptq_config: Optional[GptqConfig] = None,
@@ -275,6 +276,9 @@
             return model, tokenizer
     kwargs["revision"] = revision
 
+    if dtype is not None:
+        kwargs["torch_dtype"] = dtype
+
     # Load model
     model, tokenizer = adapter.load_model(model_path, kwargs)
 
@@ -385,6 +389,11 @@ def add_model_args(parser):
         type=str,
         help="The maximum memory per GPU for storing model weights. Use a string like '13Gib'",
     )
+    parser.add_argument(
+        "--dtype",
+        choices=["fp32", "fp16", "bf16"],
+        default=None,
+    )
     parser.add_argument(
         "--load-8bit", action="store_true", help="Use 8-bit quantization"
     )
diff --git a/fastchat/serve/cli.py b/fastchat/serve/cli.py
index 41161ae35..37a6e6f87 100644
--- a/fastchat/serve/cli.py
+++ b/fastchat/serve/cli.py
@@ -17,6 +17,7 @@
 import os
 import re
 import sys
+import torch
 
 from prompt_toolkit import PromptSession
 from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
@@ -203,6 +204,14 @@
     else:
         raise ValueError(f"Invalid style for console: {args.style}")
     try:
+        dtype = None
+        if args.dtype == "fp32":
+            dtype = torch.float32
+        elif args.dtype == "fp16":
+            dtype = torch.float16
+        elif args.dtype == "bf16":
+            dtype = torch.bfloat16
+
         chat_loop(
             args.model_path,
             args.device,
@@ -231,6 +240,7 @@
             judge_sent_end=args.judge_sent_end,
             debug=args.debug,
             history=not args.no_history,
+            dtype=dtype,
         )
     except KeyboardInterrupt:
         print("exit...")
diff --git a/fastchat/serve/inference.py b/fastchat/serve/inference.py
index 4e5191610..d82c9ea90 100644
--- a/fastchat/serve/inference.py
+++ b/fastchat/serve/inference.py
@@ -302,6 +302,7 @@ def chat_loop(
     judge_sent_end: bool = True,
     debug: bool = True,
     history: bool = True,
+    dtype=None,
 ):
     # Model
     model, tokenizer = load_model(
@@ -309,6 +310,7 @@ def chat_loop(
         device=device,
         num_gpus=num_gpus,
         max_gpu_memory=max_gpu_memory,
+        dtype=dtype,
         load_8bit=load_8bit,
         cpu_offloading=cpu_offloading,
         gptq_config=gptq_config,
diff --git a/fastchat/serve/model_worker.py b/fastchat/serve/model_worker.py
index dac3764d4..edba50752 100644
--- a/fastchat/serve/model_worker.py
+++ b/fastchat/serve/model_worker.py
@@ -46,7 +46,14 @@
 )
 from fastchat.modules.gptq import GptqConfig
 from fastchat.modules.awq import AWQConfig
-from fastchat.utils import build_logger, pretty_print_semaphore, get_context_length
+from fastchat.utils import (
+    build_logger,
+    pretty_print_semaphore,
+    get_context_length,
+    set_random_seed,
+)
+
+from transformers import set_seed
 
 
 worker_id = str(uuid.uuid4())[:8]
@@ -190,6 +197,7 @@ def __init__(
         device: str,
         num_gpus: int,
         max_gpu_memory: str,
+        dtype=None,
         load_8bit: bool = False,
         cpu_offloading: bool = False,
         gptq_config: Optional[GptqConfig] = None,
@@ -197,6 +205,7 @@
         stream_interval: int = 2,
         conv_template: str = None,
         embed_in_truncate: bool = False,
+        seed=None,
         **kwargs,
     ):
         super().__init__(
@@ -215,6 +224,7 @@
             device=device,
             num_gpus=num_gpus,
             max_gpu_memory=max_gpu_memory,
+            dtype=dtype,
             load_8bit=load_8bit,
             cpu_offloading=cpu_offloading,
             gptq_config=gptq_config,
@@ -227,6 +237,9 @@
         self.generate_stream_func = get_generate_stream_function(self.model, model_path)
         self.stream_interval = stream_interval
         self.embed_in_truncate = embed_in_truncate
+        self.seed = seed
+        if seed is not None:
+            set_random_seed(seed)
 
         if not no_register:
             self.init_heart_beat()
@@ -235,6 +248,8 @@ def generate_stream_gate(self, params):
         self.call_ct += 1
 
         try:
+            if self.seed is not None:
+                set_seed(self.seed)
             for output in self.generate_stream_func(
                 self.model,
                 self.tokenizer,
@@ -473,6 +488,7 @@
     )
     parser.add_argument("--stream-interval", type=int, default=2)
     parser.add_argument("--no-register", action="store_true")
+    parser.add_argument("--seed", type=int, default=None)
     args = parser.parse_args()
     logger.info(f"args: {args}")
 
@@ -495,6 +511,14 @@
         groupsize=args.awq_groupsize,
     )
 
+    dtype = None
+    if args.dtype == "fp32":
+        dtype = torch.float32
+    elif args.dtype == "fp16":
+        dtype = torch.float16
+    elif args.dtype == "bf16":
+        dtype = torch.bfloat16
+
     worker = ModelWorker(
         args.controller_address,
         args.worker_address,
@@ -506,6 +530,7 @@
         device=args.device,
         num_gpus=args.num_gpus,
         max_gpu_memory=args.max_gpu_memory,
+        dtype=dtype,
         load_8bit=args.load_8bit,
         cpu_offloading=args.cpu_offloading,
         gptq_config=gptq_config,
@@ -513,6 +538,7 @@
         stream_interval=args.stream_interval,
         conv_template=args.conv_template,
         embed_in_truncate=args.embed_in_truncate,
+        seed=args.seed,
     )
     return args, worker
 
diff --git a/fastchat/utils.py b/fastchat/utils.py
index 25370eb17..3e1b8ec69 100644
--- a/fastchat/utils.py
+++ b/fastchat/utils.py
@@ -5,9 +5,12 @@
 import json
 import logging
 import logging.handlers
+import numpy as np
 import os
 import platform
+import random
 import sys
+import torch
 from typing import AsyncGenerator, Generator
 import warnings
 
@@ -302,3 +305,11 @@
         if val is not None:
             return int(rope_scaling_factor * val)
     return 2048
+
+
+def set_random_seed(seed: int) -> None:
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
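
Example usage (illustrative, not part of the patch): once applied, "--dtype {fp32,fp16,bf16}"
is available wherever add_model_args is used (e.g. fastchat.serve.cli and
fastchat.serve.model_worker), and the model worker additionally accepts "--seed".
From Python, the same knobs can be exercised as sketched below; the dtype string is
mapped to a torch dtype and forwarded to load_model (which sets torch_dtype), and
set_random_seed seeds Python, NumPy, and PyTorch in one call. The model path is only
an example.

    import torch

    from fastchat.model.model_adapter import load_model
    from fastchat.utils import set_random_seed

    # Seed python / numpy / torch (including CUDA) for reproducible generation.
    set_random_seed(42)

    # Equivalent of passing "--dtype bf16" on the command line: the string is
    # translated to a torch dtype and handed to load_model via the new argument.
    model, tokenizer = load_model(
        "lmsys/vicuna-7b-v1.5",  # example model path, not taken from the patch
        device="cuda",
        num_gpus=1,
        dtype=torch.bfloat16,
    )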