
Commit

add dtype and seed
Ying1123 committed Sep 15, 2023
1 parent 3149253 commit ccc82b9
Showing 6 changed files with 63 additions and 1 deletion.
4 changes: 4 additions & 0 deletions fastchat/llm_judge/gen_model_answer.py
@@ -29,6 +29,7 @@ def run_eval(
num_gpus_per_model,
num_gpus_total,
max_gpu_memory,
dtype=None,
):
questions = load_questions(question_file, question_begin, question_end)
# random shuffle the questions to balance the loading
@@ -58,6 +59,7 @@ def run_eval(
num_choices,
num_gpus_per_model,
max_gpu_memory,
dtype=dtype,
)
)

@@ -75,12 +77,14 @@ def get_model_answers(
num_choices,
num_gpus_per_model,
max_gpu_memory,
dtype=None,
):
model, tokenizer = load_model(
model_path,
device="cuda",
num_gpus=num_gpus_per_model,
max_gpu_memory=max_gpu_memory,
dtype=dtype,
load_8bit=False,
cpu_offloading=False,
debug=False,
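get_model_answers() now forwards dtype straight to load_model(). A minimal sketch of the equivalent direct call, with a hypothetical model path (not part of this commit):

import torch
from fastchat.model.model_adapter import load_model

# Sketch only: load the evaluation model in half precision instead of the
# adapter's default. All argument values below are hypothetical examples.
model, tokenizer = load_model(
    "lmsys/vicuna-7b-v1.5",
    device="cuda",
    num_gpus=1,
    max_gpu_memory=None,
    dtype=torch.float16,
)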
9 changes: 9 additions & 0 deletions fastchat/model/model_adapter.py
@@ -152,6 +152,7 @@ def load_model(
device: str = "cuda",
num_gpus: int = 1,
max_gpu_memory: Optional[str] = None,
dtype=None,
load_8bit: bool = False,
cpu_offloading: bool = False,
gptq_config: Optional[GptqConfig] = None,
@@ -275,6 +276,9 @@ def load_model(
return model, tokenizer
kwargs["revision"] = revision

if dtype is not None:
kwargs["torch_dtype"] = dtype

# Load model
model, tokenizer = adapter.load_model(model_path, kwargs)

@@ -385,6 +389,11 @@ def add_model_args(parser):
type=str,
help="The maximum memory per GPU for storing model weights. Use a string like '13Gib'",
)
parser.add_argument(
"--dtype",
choices=["fp32", "fp16", "bf16"],
default=None,
)
parser.add_argument(
"--load-8bit", action="store_true", help="Use 8-bit quantization"
)
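Inside load_model(), a non-None dtype simply becomes the torch_dtype keyword in the kwargs handed to the model adapter, and the new --dtype flag in add_model_args() is how that value arrives from the command line; leaving the flag unset keeps the previous, adapter-chosen precision. Roughly the effect for a plain Hugging Face causal-LM adapter, as a standalone sketch (hypothetical model path, not the literal adapter code):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Approximately what dtype=torch.bfloat16 translates to once the kwargs
# reach from_pretrained (sketch only).
model = AutoModelForCausalLM.from_pretrained(
    "lmsys/vicuna-7b-v1.5",  # hypothetical
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")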
10 changes: 10 additions & 0 deletions fastchat/serve/cli.py
@@ -17,6 +17,7 @@
import os
import re
import sys
import torch

from prompt_toolkit import PromptSession
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
@@ -203,6 +204,14 @@ def main(args):
else:
raise ValueError(f"Invalid style for console: {args.style}")
try:
dtype = None
if args.dtype == "fp32":
dtype = torch.float32
elif args.dtype == "fp16":
dtype = torch.float16
elif args.dtype == "bf16":
dtype = torch.bfloat16

chat_loop(
args.model_path,
args.device,
@@ -231,6 +240,7 @@ def main(args):
judge_sent_end=args.judge_sent_end,
debug=args.debug,
history=not args.no_history,
dtype=dtype,
)
except KeyboardInterrupt:
print("exit...")
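cli.py maps the --dtype string onto a torch dtype before calling chat_loop(). The same three-way mapping is repeated in model_worker.py below; a small helper like the following could centralize it (hypothetical refactor, not part of this commit):

from typing import Optional

import torch


def str_to_torch_dtype(name: Optional[str]) -> Optional[torch.dtype]:
    # None (flag not given) keeps the loader's default precision.
    mapping = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    if name is None:
        return None
    return mapping[name]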
2 changes: 2 additions & 0 deletions fastchat/serve/inference.py
@@ -302,13 +302,15 @@ def chat_loop(
judge_sent_end: bool = True,
debug: bool = True,
history: bool = True,
dtype=None,
):
# Model
model, tokenizer = load_model(
model_path,
device=device,
num_gpus=num_gpus,
max_gpu_memory=max_gpu_memory,
dtype=dtype,
load_8bit=load_8bit,
cpu_offloading=cpu_offloading,
gptq_config=gptq_config,
28 changes: 27 additions & 1 deletion fastchat/serve/model_worker.py
@@ -46,7 +46,14 @@
)
from fastchat.modules.gptq import GptqConfig
from fastchat.modules.awq import AWQConfig
from fastchat.utils import build_logger, pretty_print_semaphore, get_context_length
from fastchat.utils import (
build_logger,
pretty_print_semaphore,
get_context_length,
set_random_seed,
)

from transformers import set_seed


worker_id = str(uuid.uuid4())[:8]
@@ -190,13 +197,15 @@ def __init__(
device: str,
num_gpus: int,
max_gpu_memory: str,
dtype=None,
load_8bit: bool = False,
cpu_offloading: bool = False,
gptq_config: Optional[GptqConfig] = None,
awq_config: Optional[AWQConfig] = None,
stream_interval: int = 2,
conv_template: str = None,
embed_in_truncate: bool = False,
seed=None,
**kwargs,
):
super().__init__(
@@ -215,6 +224,7 @@ def __init__(
device=device,
num_gpus=num_gpus,
max_gpu_memory=max_gpu_memory,
dtype=dtype,
load_8bit=load_8bit,
cpu_offloading=cpu_offloading,
gptq_config=gptq_config,
@@ -227,6 +237,9 @@ def __init__(
self.generate_stream_func = get_generate_stream_function(self.model, model_path)
self.stream_interval = stream_interval
self.embed_in_truncate = embed_in_truncate
self.seed = seed
if seed is not None:
set_random_seed(seed)

if not no_register:
self.init_heart_beat()
@@ -235,6 +248,8 @@ def generate_stream_gate(self, params):
self.call_ct += 1

try:
if self.seed is not None:
set_seed(self.seed)
for output in self.generate_stream_func(
self.model,
self.tokenizer,
@@ -473,6 +488,7 @@ def create_model_worker():
)
parser.add_argument("--stream-interval", type=int, default=2)
parser.add_argument("--no-register", action="store_true")
parser.add_argument("--seed", type=int, default=None)
args = parser.parse_args()
logger.info(f"args: {args}")

@@ -495,6 +511,14 @@ def create_model_worker():
groupsize=args.awq_groupsize,
)

dtype = None
if args.dtype == "fp32":
dtype = torch.float32
elif args.dtype == "fp16":
dtype = torch.float16
elif args.dtype == "bf16":
dtype = torch.bfloat16

worker = ModelWorker(
args.controller_address,
args.worker_address,
@@ -506,13 +530,15 @@ def create_model_worker():
device=args.device,
num_gpus=args.num_gpus,
max_gpu_memory=args.max_gpu_memory,
dtype=dtype,
load_8bit=args.load_8bit,
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
stream_interval=args.stream_interval,
conv_template=args.conv_template,
embed_in_truncate=args.embed_in_truncate,
seed=args.seed,
)
return args, worker

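The worker now takes --seed (added here) alongside --dtype (assuming the worker parser goes through add_model_args(), which is where that flag is defined), e.g. launching with --dtype fp16 --seed 42. The seed is applied once at load time via set_random_seed() and re-applied with transformers' set_seed() before every generation request, which makes identical sampling requests reproducible. A self-contained sketch of why the per-request re-seed matters:

import torch
from transformers import set_seed

# With the same seed, the same stochastic draw is repeated exactly;
# this mirrors what generate_stream_gate() does before each request.
set_seed(42)
a = torch.multinomial(torch.ones(10), num_samples=3, replacement=True)
set_seed(42)
b = torch.multinomial(torch.ones(10), num_samples=3, replacement=True)
assert torch.equal(a, b)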
11 changes: 11 additions & 0 deletions fastchat/utils.py
@@ -5,9 +5,12 @@
import json
import logging
import logging.handlers
import numpy as np
import os
import platform
import random
import sys
import torch
from typing import AsyncGenerator, Generator
import warnings

@@ -302,3 +305,11 @@ def get_context_length(config):
if val is not None:
return int(rope_scaling_factor * val)
return 2048


def set_random_seed(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)

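set_random_seed() seeds Python's random module, NumPy, and PyTorch (including all visible CUDA devices) in one call. A short usage sketch:

from fastchat.utils import set_random_seed

# Seed every RNG FastChat touches before loading a model or generating.
set_random_seed(42)
# Note: seeding alone does not force bit-wise deterministic GPU kernels;
# torch.use_deterministic_algorithms(True) would be needed for that, and
# this helper does not enable it.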