From 4f6e1023a5d6364f002d783882f65191e4ca6d30 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Wed, 20 Sep 2023 15:29:00 +0800 Subject: [PATCH 01/13] copy the model_worker.py --- fastchat/serve/huggingface_api_worker.py | 549 +++++++++++++++++++++++ 1 file changed, 549 insertions(+) create mode 100644 fastchat/serve/huggingface_api_worker.py diff --git a/fastchat/serve/huggingface_api_worker.py b/fastchat/serve/huggingface_api_worker.py new file mode 100644 index 000000000..391a3b4fc --- /dev/null +++ b/fastchat/serve/huggingface_api_worker.py @@ -0,0 +1,549 @@ +""" +A model worker that executes the model. +""" +import argparse +import asyncio +import base64 +import dataclasses +import gc +import logging +import json +import os +import threading +import time +from typing import List, Optional +import uuid + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse +import requests + +try: + from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + LlamaTokenizer, + AutoModel, + ) +except ImportError: + from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + LLaMATokenizer, + AutoModel, + ) +import torch +import torch.nn.functional as F +from transformers import set_seed +import uvicorn + +from fastchat.constants import WORKER_HEART_BEAT_INTERVAL, ErrorCode, SERVER_ERROR_MSG +from fastchat.conversation import get_conv_template +from fastchat.model.model_adapter import ( + load_model, + add_model_args, + get_conversation_template, + get_generate_stream_function, +) +from fastchat.modules.gptq import GptqConfig +from fastchat.modules.awq import AWQConfig +from fastchat.utils import ( + build_logger, + pretty_print_semaphore, + get_context_length, + str_to_torch_dtype, +) + + +worker_id = str(uuid.uuid4())[:8] +logger = build_logger("model_worker", f"model_worker_{worker_id}.log") + +app = FastAPI() + + +def heart_beat_worker(obj): + while True: + time.sleep(WORKER_HEART_BEAT_INTERVAL) + obj.send_heart_beat() + + +class BaseModelWorker: + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + conv_template: str = None, + ): + self.controller_addr = controller_addr + self.worker_addr = worker_addr + self.worker_id = worker_id + if model_path.endswith("/"): + model_path = model_path[:-1] + self.model_names = model_names or [model_path.split("/")[-1]] + self.limit_worker_concurrency = limit_worker_concurrency + if conv_template: + self.conv = get_conv_template(conv_template) + else: + self.conv = get_conversation_template(model_path) + self.conv.sep_style = int(self.conv.sep_style) + self.tokenizer = None + self.context_len = None + self.call_ct = 0 + self.semaphore = None + + self.heart_beat_thread = None + + def init_heart_beat(self): + self.register_to_controller() + self.heart_beat_thread = threading.Thread( + target=heart_beat_worker, + args=(self,), + daemon=True, + ) + self.heart_beat_thread.start() + + def register_to_controller(self): + logger.info("Register to controller") + + url = self.controller_addr + "/register_worker" + data = { + "worker_name": self.worker_addr, + "check_heart_beat": True, + "worker_status": self.get_status(), + } + r = requests.post(url, json=data) + assert r.status_code == 200 + + def send_heart_beat(self): + logger.info( + f"Send heart beat. Models: {self.model_names}. " + f"Semaphore: {pretty_print_semaphore(self.semaphore)}. " + f"call_ct: {self.call_ct}. 
" + f"worker_id: {self.worker_id}. " + ) + + url = self.controller_addr + "/receive_heart_beat" + + while True: + try: + ret = requests.post( + url, + json={ + "worker_name": self.worker_addr, + "queue_length": self.get_queue_length(), + }, + timeout=5, + ) + exist = ret.json()["exist"] + break + except (requests.exceptions.RequestException, KeyError) as e: + logger.error(f"heart beat error: {e}") + time.sleep(5) + + if not exist: + self.register_to_controller() + + def get_queue_length(self): + if ( + self.semaphore is None + or self.semaphore._value is None + or self.semaphore._waiters is None + ): + return 0 + else: + return ( + self.limit_worker_concurrency + - self.semaphore._value + + len(self.semaphore._waiters) + ) + + def get_status(self): + return { + "model_names": self.model_names, + "speed": 1, + "queue_length": self.get_queue_length(), + } + + def count_token(self, params): + prompt = params["prompt"] + input_ids = self.tokenizer(prompt).input_ids + input_echo_len = len(input_ids) + + ret = { + "count": input_echo_len, + "error_code": 0, + } + return ret + + def get_conv_template(self): + return {"conv": self.conv} + + +class ModelWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + device: str, + num_gpus: int, + max_gpu_memory: str, + dtype: Optional[torch.dtype] = None, + load_8bit: bool = False, + cpu_offloading: bool = False, + gptq_config: Optional[GptqConfig] = None, + awq_config: Optional[AWQConfig] = None, + stream_interval: int = 2, + conv_template: Optional[str] = None, + embed_in_truncate: bool = False, + seed: Optional[int] = None, + **kwargs, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template=conv_template, + ) + + logger.info( + f"Loading the model {self.model_names} on worker {worker_id} ...") + self.model, self.tokenizer = load_model( + model_path, + device=device, + num_gpus=num_gpus, + max_gpu_memory=max_gpu_memory, + dtype=dtype, + load_8bit=load_8bit, + cpu_offloading=cpu_offloading, + gptq_config=gptq_config, + awq_config=awq_config, + ) + self.device = device + if self.tokenizer.pad_token == None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.context_len = get_context_length(self.model.config) + self.generate_stream_func = get_generate_stream_function( + self.model, model_path) + self.stream_interval = stream_interval + self.embed_in_truncate = embed_in_truncate + self.seed = seed + + if not no_register: + self.init_heart_beat() + + def generate_stream_gate(self, params): + self.call_ct += 1 + + try: + if self.seed is not None: + set_seed(self.seed) + for output in self.generate_stream_func( + self.model, + self.tokenizer, + params, + self.device, + self.context_len, + self.stream_interval, + ): + ret = { + "text": output["text"], + "error_code": 0, + } + if "usage" in output: + ret["usage"] = output["usage"] + if "finish_reason" in output: + ret["finish_reason"] = output["finish_reason"] + if "logprobs" in output: + ret["logprobs"] = output["logprobs"] + yield json.dumps(ret).encode() + b"\0" + except torch.cuda.OutOfMemoryError as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.CUDA_OUT_OF_MEMORY, + } + yield json.dumps(ret).encode() + b"\0" + except (ValueError, RuntimeError) as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": 
ErrorCode.INTERNAL_ERROR, + } + yield json.dumps(ret).encode() + b"\0" + + def generate_gate(self, params): + for x in self.generate_stream_gate(params): + pass + return json.loads(x[:-1].decode()) + + def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict): + if model_type_dict.get("is_bert"): + model_output = self.model(input_ids) + if model_type_dict.get("is_robert"): + data = model_output.last_hidden_state + else: + data = model_output[0] + elif model_type_dict.get("is_t5"): + model_output = self.model(input_ids, decoder_input_ids=input_ids) + data = model_output.encoder_last_hidden_state + else: + model_output = self.model(input_ids, output_hidden_states=True) + if model_type_dict.get("is_chatglm"): + data = model_output.hidden_states[-1].transpose(0, 1) + else: + data = model_output.hidden_states[-1] + mask = attention_mask.unsqueeze(-1).expand(data.size()).float() + masked_embeddings = data * mask + sum_embeddings = torch.sum(masked_embeddings, dim=1) + token_num = torch.sum(attention_mask).item() + + return sum_embeddings, token_num + + def __encode_base64(self, embeddings: torch.Tensor) -> List[str]: + embeddings = embeddings.cpu() + return [ + base64.b64encode(e.numpy().tobytes()).decode("utf-8") for e in embeddings + ] + + @torch.inference_mode() + def get_embeddings(self, params): + self.call_ct += 1 + + try: + tokenizer = self.tokenizer + ret = {"embedding": [], "token_num": 0} + + model_type_dict = { + "is_llama": "llama" in str(type(self.model)), + "is_t5": "t5" in str(type(self.model)), + "is_chatglm": "chatglm" in str(type(self.model)), + "is_bert": "bert" in str(type(self.model)), + "is_robert": "robert" in str(type(self.model)), + } + + if self.embed_in_truncate: + encoding = tokenizer.batch_encode_plus( + params["input"], + padding=True, + truncation="longest_first", + return_tensors="pt", + max_length=self.context_len, + ) + else: + encoding = tokenizer.batch_encode_plus( + params["input"], padding=True, return_tensors="pt" + ) + input_ids = encoding["input_ids"].to(self.device) + attention_mask = input_ids != tokenizer.pad_token_id + + base64_encode = params.get("encoding_format", None) + + if self.embed_in_truncate: + chunk_embeddings, token_num = self.__process_embed_chunk( + input_ids, attention_mask, **model_type_dict + ) + embedding = chunk_embeddings / token_num + normalized_embeddings = F.normalize(embedding, p=2, dim=1) + ret["token_num"] = token_num + else: + all_embeddings = [] + all_token_num = 0 + for i in range(0, input_ids.size(1), self.context_len): + chunk_input_ids = input_ids[:, i: i + self.context_len] + chunk_attention_mask = attention_mask[:, + i: i + self.context_len] + + chunk_embeddings, token_num = self.__process_embed_chunk( + chunk_input_ids, chunk_attention_mask, **model_type_dict + ) + all_embeddings.append(chunk_embeddings) + all_token_num += token_num + + all_embeddings_tensor = torch.stack(all_embeddings) + embedding = torch.sum( + all_embeddings_tensor, dim=0) / all_token_num + normalized_embeddings = F.normalize(embedding, p=2, dim=1) + + ret["token_num"] = all_token_num + + if base64_encode == "base64": + out_embeddings = self.__encode_base64(normalized_embeddings) + else: + out_embeddings = normalized_embeddings.tolist() + ret["embedding"] = out_embeddings + + gc.collect() + torch.cuda.empty_cache() + if self.device == "xpu": + torch.xpu.empty_cache() + if self.device == "npu": + torch.npu.empty_cache() + except torch.cuda.OutOfMemoryError as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + 
"error_code": ErrorCode.CUDA_OUT_OF_MEMORY, + } + except (ValueError, RuntimeError) as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + return ret + + +def release_worker_semaphore(): + worker.semaphore.release() + + +def acquire_worker_semaphore(): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(): + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + generator = worker.generate_stream_gate(params) + background_tasks = create_background_tasks() + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + output = worker.generate_gate(params) + release_worker_semaphore() + return JSONResponse(output) + + +@app.post("/worker_get_embeddings") +async def api_get_embeddings(request: Request): + params = await request.json() + await acquire_worker_semaphore() + embedding = worker.get_embeddings(params) + release_worker_semaphore() + return JSONResponse(content=embedding) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + return {"context_length": worker.context_len} + + +def create_model_worker(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, + default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + add_model_args(parser) + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", + ) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument("--embed-in-truncate", action="store_true") + parser.add_argument( + "--limit-worker-concurrency", + type=int, + default=5, + help="Limit the model concurrency to prevent OOM.", + ) + parser.add_argument("--stream-interval", type=int, default=2) + parser.add_argument("--no-register", action="store_true") + parser.add_argument( + "--seed", + type=int, + default=None, + help="Overwrite the random seed for each generation.", + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + if args.gpus: + if len(args.gpus.split(",")) < args.num_gpus: + raise ValueError( + f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" 
+ ) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + + gptq_config = GptqConfig( + ckpt=args.gptq_ckpt or args.model_path, + wbits=args.gptq_wbits, + groupsize=args.gptq_groupsize, + act_order=args.gptq_act_order, + ) + awq_config = AWQConfig( + ckpt=args.awq_ckpt or args.model_path, + wbits=args.awq_wbits, + groupsize=args.awq_groupsize, + ) + + worker = ModelWorker( + args.controller_address, + args.worker_address, + worker_id, + args.model_path, + args.model_names, + args.limit_worker_concurrency, + no_register=args.no_register, + device=args.device, + num_gpus=args.num_gpus, + max_gpu_memory=args.max_gpu_memory, + dtype=str_to_torch_dtype(args.dtype), + load_8bit=args.load_8bit, + cpu_offloading=args.cpu_offloading, + gptq_config=gptq_config, + awq_config=awq_config, + stream_interval=args.stream_interval, + conv_template=args.conv_template, + embed_in_truncate=args.embed_in_truncate, + seed=args.seed, + ) + return args, worker + + +if __name__ == "__main__": + args, worker = create_model_worker() + uvicorn.run(app, host=args.host, port=args.port, log_level="info") From 2fd1f7a579322caa6b23e6c554b39a883c7073ee Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Wed, 20 Sep 2023 17:06:19 +0800 Subject: [PATCH 02/13] the basic structure --- fastchat/serve/huggingface_api_worker.py | 381 ++++------------------- 1 file changed, 57 insertions(+), 324 deletions(-) diff --git a/fastchat/serve/huggingface_api_worker.py b/fastchat/serve/huggingface_api_worker.py index 391a3b4fc..776ba85d7 100644 --- a/fastchat/serve/huggingface_api_worker.py +++ b/fastchat/serve/huggingface_api_worker.py @@ -1,5 +1,14 @@ """ -A model worker that executes the model. +A model worker to call huggingface api. +The contents in supported_models.json : +{ + "falcon-180b": { + "model_path": "tiiuae/falcon-180B-chat", + "api_base": "https://api-inference.huggingface.co/models", + "token": "hf_xxx", + "context_length": "2048" + } +} """ import argparse import asyncio @@ -18,6 +27,8 @@ from fastapi.responses import StreamingResponse, JSONResponse import requests +from fastchat.serve.model_worker import BaseModelWorker + try: from transformers import ( AutoTokenizer, @@ -61,146 +72,19 @@ app = FastAPI() -def heart_beat_worker(obj): - while True: - time.sleep(WORKER_HEART_BEAT_INTERVAL) - obj.send_heart_beat() - - -class BaseModelWorker: - def __init__( - self, - controller_addr: str, - worker_addr: str, - worker_id: str, - model_path: str, - model_names: List[str], - limit_worker_concurrency: int, - conv_template: str = None, - ): - self.controller_addr = controller_addr - self.worker_addr = worker_addr - self.worker_id = worker_id - if model_path.endswith("/"): - model_path = model_path[:-1] - self.model_names = model_names or [model_path.split("/")[-1]] - self.limit_worker_concurrency = limit_worker_concurrency - if conv_template: - self.conv = get_conv_template(conv_template) - else: - self.conv = get_conversation_template(model_path) - self.conv.sep_style = int(self.conv.sep_style) - self.tokenizer = None - self.context_len = None - self.call_ct = 0 - self.semaphore = None - - self.heart_beat_thread = None - - def init_heart_beat(self): - self.register_to_controller() - self.heart_beat_thread = threading.Thread( - target=heart_beat_worker, - args=(self,), - daemon=True, - ) - self.heart_beat_thread.start() - - def register_to_controller(self): - logger.info("Register to controller") - - url = self.controller_addr + "/register_worker" - data = { - "worker_name": self.worker_addr, - "check_heart_beat": True, - 
"worker_status": self.get_status(), - } - r = requests.post(url, json=data) - assert r.status_code == 200 - - def send_heart_beat(self): - logger.info( - f"Send heart beat. Models: {self.model_names}. " - f"Semaphore: {pretty_print_semaphore(self.semaphore)}. " - f"call_ct: {self.call_ct}. " - f"worker_id: {self.worker_id}. " - ) - - url = self.controller_addr + "/receive_heart_beat" - - while True: - try: - ret = requests.post( - url, - json={ - "worker_name": self.worker_addr, - "queue_length": self.get_queue_length(), - }, - timeout=5, - ) - exist = ret.json()["exist"] - break - except (requests.exceptions.RequestException, KeyError) as e: - logger.error(f"heart beat error: {e}") - time.sleep(5) - - if not exist: - self.register_to_controller() - - def get_queue_length(self): - if ( - self.semaphore is None - or self.semaphore._value is None - or self.semaphore._waiters is None - ): - return 0 - else: - return ( - self.limit_worker_concurrency - - self.semaphore._value - + len(self.semaphore._waiters) - ) - - def get_status(self): - return { - "model_names": self.model_names, - "speed": 1, - "queue_length": self.get_queue_length(), - } - - def count_token(self, params): - prompt = params["prompt"] - input_ids = self.tokenizer(prompt).input_ids - input_echo_len = len(input_ids) - - ret = { - "count": input_echo_len, - "error_code": 0, - } - return ret - - def get_conv_template(self): - return {"conv": self.conv} - - -class ModelWorker(BaseModelWorker): +class HuggingfaceApiWorker(BaseModelWorker): def __init__( self, controller_addr: str, worker_addr: str, worker_id: str, model_path: str, + api_base: str, + token: str, + context_length: int, model_names: List[str], limit_worker_concurrency: int, no_register: bool, - device: str, - num_gpus: int, - max_gpu_memory: str, - dtype: Optional[torch.dtype] = None, - load_8bit: bool = False, - cpu_offloading: bool = False, - gptq_config: Optional[GptqConfig] = None, - awq_config: Optional[AWQConfig] = None, stream_interval: int = 2, conv_template: Optional[str] = None, embed_in_truncate: bool = False, @@ -217,25 +101,19 @@ def __init__( conv_template=conv_template, ) + self.model_path = model_path + self.api_base = api_base + self.token = token + self.context_len = context_length + logger.info( - f"Loading the model {self.model_names} on worker {worker_id} ...") - self.model, self.tokenizer = load_model( - model_path, - device=device, - num_gpus=num_gpus, - max_gpu_memory=max_gpu_memory, - dtype=dtype, - load_8bit=load_8bit, - cpu_offloading=cpu_offloading, - gptq_config=gptq_config, - awq_config=awq_config, + f"Connecting with huggingface api {self.model_path} as {self.model_names} on worker {worker_id} ..." 
) - self.device = device - if self.tokenizer.pad_token == None: - self.tokenizer.pad_token = self.tokenizer.eos_token - self.context_len = get_context_length(self.model.config) - self.generate_stream_func = get_generate_stream_function( - self.model, model_path) + + self.model = None + self.tokenizer = None + self.device = None + self.generate_stream_func = None self.stream_interval = stream_interval self.embed_in_truncate = embed_in_truncate self.seed = seed @@ -243,162 +121,24 @@ def __init__( if not no_register: self.init_heart_beat() + def count_token(self, params): + # No tokenizer here + return 0 + def generate_stream_gate(self, params): self.call_ct += 1 - - try: - if self.seed is not None: - set_seed(self.seed) - for output in self.generate_stream_func( - self.model, - self.tokenizer, - params, - self.device, - self.context_len, - self.stream_interval, - ): - ret = { - "text": output["text"], - "error_code": 0, - } - if "usage" in output: - ret["usage"] = output["usage"] - if "finish_reason" in output: - ret["finish_reason"] = output["finish_reason"] - if "logprobs" in output: - ret["logprobs"] = output["logprobs"] - yield json.dumps(ret).encode() + b"\0" - except torch.cuda.OutOfMemoryError as e: - ret = { - "text": f"{SERVER_ERROR_MSG}\n\n({e})", - "error_code": ErrorCode.CUDA_OUT_OF_MEMORY, - } - yield json.dumps(ret).encode() + b"\0" - except (ValueError, RuntimeError) as e: - ret = { - "text": f"{SERVER_ERROR_MSG}\n\n({e})", - "error_code": ErrorCode.INTERNAL_ERROR, - } - yield json.dumps(ret).encode() + b"\0" + raise NotImplementedError() def generate_gate(self, params): for x in self.generate_stream_gate(params): pass return json.loads(x[:-1].decode()) - def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict): - if model_type_dict.get("is_bert"): - model_output = self.model(input_ids) - if model_type_dict.get("is_robert"): - data = model_output.last_hidden_state - else: - data = model_output[0] - elif model_type_dict.get("is_t5"): - model_output = self.model(input_ids, decoder_input_ids=input_ids) - data = model_output.encoder_last_hidden_state - else: - model_output = self.model(input_ids, output_hidden_states=True) - if model_type_dict.get("is_chatglm"): - data = model_output.hidden_states[-1].transpose(0, 1) - else: - data = model_output.hidden_states[-1] - mask = attention_mask.unsqueeze(-1).expand(data.size()).float() - masked_embeddings = data * mask - sum_embeddings = torch.sum(masked_embeddings, dim=1) - token_num = torch.sum(attention_mask).item() - - return sum_embeddings, token_num - - def __encode_base64(self, embeddings: torch.Tensor) -> List[str]: - embeddings = embeddings.cpu() - return [ - base64.b64encode(e.numpy().tobytes()).decode("utf-8") for e in embeddings - ] - @torch.inference_mode() def get_embeddings(self, params): self.call_ct += 1 - try: - tokenizer = self.tokenizer - ret = {"embedding": [], "token_num": 0} - - model_type_dict = { - "is_llama": "llama" in str(type(self.model)), - "is_t5": "t5" in str(type(self.model)), - "is_chatglm": "chatglm" in str(type(self.model)), - "is_bert": "bert" in str(type(self.model)), - "is_robert": "robert" in str(type(self.model)), - } - - if self.embed_in_truncate: - encoding = tokenizer.batch_encode_plus( - params["input"], - padding=True, - truncation="longest_first", - return_tensors="pt", - max_length=self.context_len, - ) - else: - encoding = tokenizer.batch_encode_plus( - params["input"], padding=True, return_tensors="pt" - ) - input_ids = encoding["input_ids"].to(self.device) - 
attention_mask = input_ids != tokenizer.pad_token_id - - base64_encode = params.get("encoding_format", None) - - if self.embed_in_truncate: - chunk_embeddings, token_num = self.__process_embed_chunk( - input_ids, attention_mask, **model_type_dict - ) - embedding = chunk_embeddings / token_num - normalized_embeddings = F.normalize(embedding, p=2, dim=1) - ret["token_num"] = token_num - else: - all_embeddings = [] - all_token_num = 0 - for i in range(0, input_ids.size(1), self.context_len): - chunk_input_ids = input_ids[:, i: i + self.context_len] - chunk_attention_mask = attention_mask[:, - i: i + self.context_len] - - chunk_embeddings, token_num = self.__process_embed_chunk( - chunk_input_ids, chunk_attention_mask, **model_type_dict - ) - all_embeddings.append(chunk_embeddings) - all_token_num += token_num - - all_embeddings_tensor = torch.stack(all_embeddings) - embedding = torch.sum( - all_embeddings_tensor, dim=0) / all_token_num - normalized_embeddings = F.normalize(embedding, p=2, dim=1) - - ret["token_num"] = all_token_num - - if base64_encode == "base64": - out_embeddings = self.__encode_base64(normalized_embeddings) - else: - out_embeddings = normalized_embeddings.tolist() - ret["embedding"] = out_embeddings - - gc.collect() - torch.cuda.empty_cache() - if self.device == "xpu": - torch.xpu.empty_cache() - if self.device == "npu": - torch.npu.empty_cache() - except torch.cuda.OutOfMemoryError as e: - ret = { - "text": f"{SERVER_ERROR_MSG}\n\n({e})", - "error_code": ErrorCode.CUDA_OUT_OF_MEMORY, - } - except (ValueError, RuntimeError) as e: - ret = { - "text": f"{SERVER_ERROR_MSG}\n\n({e})", - "error_code": ErrorCode.INTERNAL_ERROR, - } - return ret + raise NotImplementedError() def release_worker_semaphore(): @@ -469,12 +209,16 @@ def create_model_worker(): parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=21002) - parser.add_argument("--worker-address", type=str, - default="http://localhost:21002") + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") parser.add_argument( "--controller-address", type=str, default="http://localhost:21001" ) - add_model_args(parser) + parser.add_argument( + "--supported-models-file", type=str, default="supported_models.json" + ) + parser.add_argument( + "--model", type=str, default="falcon-180b", help="The model name to be called." + ) parser.add_argument( "--model-names", type=lambda s: s.split(","), @@ -501,41 +245,30 @@ def create_model_worker(): args = parser.parse_args() logger.info(f"args: {args}") - if args.gpus: - if len(args.gpus.split(",")) < args.num_gpus: - raise ValueError( - f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" - ) - os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus - - gptq_config = GptqConfig( - ckpt=args.gptq_ckpt or args.model_path, - wbits=args.gptq_wbits, - groupsize=args.gptq_groupsize, - act_order=args.gptq_act_order, - ) - awq_config = AWQConfig( - ckpt=args.awq_ckpt or args.model_path, - wbits=args.awq_wbits, - groupsize=args.awq_groupsize, - ) + with open(args.supported_models_file, "r") as f: + supported_models = json.load(f) + + if args.model not in supported_models: + raise ValueError( + f"Model {args.model} not supported. Please add it to {args.supported_models_file}." 
+ ) + + model_path = supported_models[args.model]["model_path"] + api_base = supported_models[args.model]["api_base"] + token = supported_models[args.model]["token"] + context_length = supported_models[args.model]["context_length"] - worker = ModelWorker( + worker = HuggingfaceApiWorker( args.controller_address, args.worker_address, worker_id, - args.model_path, + model_path, + api_base, + token, + context_length, args.model_names, args.limit_worker_concurrency, no_register=args.no_register, - device=args.device, - num_gpus=args.num_gpus, - max_gpu_memory=args.max_gpu_memory, - dtype=str_to_torch_dtype(args.dtype), - load_8bit=args.load_8bit, - cpu_offloading=args.cpu_offloading, - gptq_config=gptq_config, - awq_config=awq_config, stream_interval=args.stream_interval, conv_template=args.conv_template, embed_in_truncate=args.embed_in_truncate, From 25b1a20cae690bd24d9de216e324866f8b6c1c99 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Wed, 20 Sep 2023 21:51:04 +0800 Subject: [PATCH 03/13] fix some format bugs --- fastchat/serve/huggingface_api_worker.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/fastchat/serve/huggingface_api_worker.py b/fastchat/serve/huggingface_api_worker.py index 776ba85d7..303bf914a 100644 --- a/fastchat/serve/huggingface_api_worker.py +++ b/fastchat/serve/huggingface_api_worker.py @@ -6,7 +6,7 @@ "model_path": "tiiuae/falcon-180B-chat", "api_base": "https://api-inference.huggingface.co/models", "token": "hf_xxx", - "context_length": "2048" + "context_length": 2048 } } """ @@ -25,6 +25,7 @@ from fastapi import FastAPI, Request, BackgroundTasks from fastapi.responses import StreamingResponse, JSONResponse +from huggingface_hub import InferenceClient import requests from fastchat.serve.model_worker import BaseModelWorker @@ -72,6 +73,11 @@ app = FastAPI() +def get_gen_kwargs(): + # TODO + pass + + class HuggingfaceApiWorker(BaseModelWorker): def __init__( self, @@ -123,10 +129,19 @@ def __init__( def count_token(self, params): # No tokenizer here - return 0 + ret = { + "count": 0, + "error_code": 0, + } + return ret def generate_stream_gate(self, params): self.call_ct += 1 + + url = f"{self.api_base}/{self.model_path}" + client = InferenceClient(url, token=self.token) + prompt = params["prompt"] + raise NotImplementedError() def generate_gate(self, params): @@ -222,6 +237,7 @@ def create_model_worker(): parser.add_argument( "--model-names", type=lambda s: s.split(","), + default="falcon-180b", help="Optional display comma separated names", ) parser.add_argument( From 104400b111a2a570a76174fbd725fdadb02e1637 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Thu, 21 Sep 2023 11:43:33 +0800 Subject: [PATCH 04/13] can roughly be used... --- fastchat/conversation.py | 2 + fastchat/serve/huggingface_api_worker.py | 89 ++++++++++++++++++++++-- 2 files changed, 84 insertions(+), 7 deletions(-) diff --git a/fastchat/conversation.py b/fastchat/conversation.py index 763856f85..4016e0ea8 100644 --- a/fastchat/conversation.py +++ b/fastchat/conversation.py @@ -958,6 +958,8 @@ def get_conv_template(name: str) -> Conversation: Conversation( name="falcon-chat", roles=("User", "Falcon"), + system_template="System: {system_message}\n", + system_message="\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. 
Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.", messages=[], sep_style=SeparatorStyle.FALCON_CHAT, sep="\n", diff --git a/fastchat/serve/huggingface_api_worker.py b/fastchat/serve/huggingface_api_worker.py index 303bf914a..2d8e93f02 100644 --- a/fastchat/serve/huggingface_api_worker.py +++ b/fastchat/serve/huggingface_api_worker.py @@ -73,9 +73,38 @@ app = FastAPI() -def get_gen_kwargs(): - # TODO - pass +# reference to +# https://github.com/philschmid/easyllm/blob/cbd908b3b3f44a97a22cb0fc2c93df3660bacdad/easyllm/clients/huggingface.py#L374-L392 +def get_gen_kwargs( + params, + seed: Optional[int] = None, +): + gen_kwargs = { + "do_sample": True, + "return_full_text": bool(params.get("echo", False)), + "max_new_tokens": int(params.get("max_new_tokens", 256)), + "top_p": float(params.get("top_p", 1.0)), + "temperature": float(params.get("temperature", 1.0)), + "stop_sequences": params.get("stop", None), + "repetition_penalty": float(params.get("repetition_penalty", 1.0)), + "top_k": params.get("top_k", None), + "seed": seed, + } + if gen_kwargs["top_p"] == 1: + gen_kwargs["top_p"] = 0.9999999 + if gen_kwargs["top_p"] == 0: + gen_kwargs.pop("top_p") + if gen_kwargs["temperature"] == 0: + gen_kwargs.pop("temperature") + gen_kwargs["do_sample"] = False + return gen_kwargs + + +def could_be_stop(text, stop): + for s in stop: + if any(text.endswith(s[:i]) for i in range(1, len(s) + 1)): + return True + return False class HuggingfaceApiWorker(BaseModelWorker): @@ -138,11 +167,57 @@ def count_token(self, params): def generate_stream_gate(self, params): self.call_ct += 1 - url = f"{self.api_base}/{self.model_path}" - client = InferenceClient(url, token=self.token) prompt = params["prompt"] - - raise NotImplementedError() + gen_kwargs = get_gen_kwargs(params, seed=self.seed) + logger.info(f"prompt: {prompt}") + logger.info(f"gen_kwargs: {gen_kwargs}") + + stop = gen_kwargs["stop_sequences"] + if "falcon" in self.model_path: + stop.extend(["\nUser:", "<|endoftext|>", " User:", "###"]) + stop = list(set(stop)) + gen_kwargs["stop_sequences"] = stop + + try: + url = f"{self.api_base}/{self.model_path}" + client = InferenceClient(url, token=self.token) + res = client.text_generation( + prompt, stream=True, details=True, **gen_kwargs + ) + + reason = None + text = "" + for chunk in res: + if chunk.token.special: + continue + text += chunk.token.text + + s = next((x for x in stop if text.endswith(x)), None) + if s is not None: + text = text[: -len(s)] + reason = "stop" + break + if could_be_stop(text, stop): + continue + if ( + chunk.details is not None + and chunk.details.finish_reason is not None + ): + reason = chunk.details.finish_reason + if reason not in ["stop", "length"]: + reason = None + ret = { + "text": text, + "error_code": 0, + "finish_reason": reason, + } + yield json.dumps(ret).encode() + b"\0" + except Exception as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + yield json.dumps(ret).encode() + b"\0" def generate_gate(self, params): for x in self.generate_stream_gate(params): From 5e2e2dc6d57bd57da8fb3e90e92c6e9ab7100ed2 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Thu, 21 Sep 2023 20:58:00 +0800 Subject: [PATCH 05/13] change the default name for falcon from "falcon-180b" to "falcon-180b-chat". 
To match the corresponding conversation template in model adapter. --- fastchat/serve/huggingface_api_worker.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/fastchat/serve/huggingface_api_worker.py b/fastchat/serve/huggingface_api_worker.py index 2d8e93f02..a0fbb8a0d 100644 --- a/fastchat/serve/huggingface_api_worker.py +++ b/fastchat/serve/huggingface_api_worker.py @@ -2,7 +2,7 @@ A model worker to call huggingface api. The contents in supported_models.json : { - "falcon-180b": { + "falcon-180b-chat": { "model_path": "tiiuae/falcon-180B-chat", "api_base": "https://api-inference.huggingface.co/models", "token": "hf_xxx", @@ -307,12 +307,15 @@ def create_model_worker(): "--supported-models-file", type=str, default="supported_models.json" ) parser.add_argument( - "--model", type=str, default="falcon-180b", help="The model name to be called." + "--model", + type=str, + default="falcon-180b-chat", + help="The model name to be called.", ) parser.add_argument( "--model-names", type=lambda s: s.split(","), - default="falcon-180b", + default="falcon-180b-chat", help="Optional display comma separated names", ) parser.add_argument( From f23c824bc0de87fa79ab5e6dad4e1281d6a37676 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Thu, 21 Sep 2023 21:08:33 +0800 Subject: [PATCH 06/13] fixed the params["stop"]'s type bugs --- fastchat/serve/huggingface_api_worker.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/fastchat/serve/huggingface_api_worker.py b/fastchat/serve/huggingface_api_worker.py index a0fbb8a0d..214782628 100644 --- a/fastchat/serve/huggingface_api_worker.py +++ b/fastchat/serve/huggingface_api_worker.py @@ -79,13 +79,20 @@ def get_gen_kwargs( params, seed: Optional[int] = None, ): + stop = params.get("stop", None) + if isinstance(stop, list): + stop_sequences = stop + elif isinstance(stop, str): + stop_sequences = [stop] + else: + stop_sequences = [] gen_kwargs = { "do_sample": True, "return_full_text": bool(params.get("echo", False)), "max_new_tokens": int(params.get("max_new_tokens", 256)), "top_p": float(params.get("top_p", 1.0)), "temperature": float(params.get("temperature", 1.0)), - "stop_sequences": params.get("stop", None), + "stop_sequences": stop_sequences, "repetition_penalty": float(params.get("repetition_penalty", 1.0)), "top_k": params.get("top_k", None), "seed": seed, From 6f41a868a05e23addaa8290c7a996507f508970b Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Fri, 22 Sep 2023 15:09:26 +0800 Subject: [PATCH 07/13] change the literal condition for falcon-180b-chat --- fastchat/serve/huggingface_api_worker.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fastchat/serve/huggingface_api_worker.py b/fastchat/serve/huggingface_api_worker.py index 214782628..6c647e64e 100644 --- a/fastchat/serve/huggingface_api_worker.py +++ b/fastchat/serve/huggingface_api_worker.py @@ -176,15 +176,15 @@ def generate_stream_gate(self, params): prompt = params["prompt"] gen_kwargs = get_gen_kwargs(params, seed=self.seed) - logger.info(f"prompt: {prompt}") - logger.info(f"gen_kwargs: {gen_kwargs}") - stop = gen_kwargs["stop_sequences"] - if "falcon" in self.model_path: + if "falcon" in self.model_path and "chat" in self.model_path: stop.extend(["\nUser:", "<|endoftext|>", " User:", "###"]) stop = list(set(stop)) gen_kwargs["stop_sequences"] = stop + logger.info(f"prompt: {prompt}") + logger.info(f"gen_kwargs: {gen_kwargs}") + try: url = f"{self.api_base}/{self.model_path}" client = InferenceClient(url, 
token=self.token) From f1c81370e5f0216c2cb8d9d9aa58cfa19053e0a0 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Fri, 22 Sep 2023 16:16:29 +0800 Subject: [PATCH 08/13] removed all unused imports --- fastchat/serve/huggingface_api_worker.py | 47 +----------------------- 1 file changed, 2 insertions(+), 45 deletions(-) diff --git a/fastchat/serve/huggingface_api_worker.py b/fastchat/serve/huggingface_api_worker.py index 6c647e64e..14f5aec6e 100644 --- a/fastchat/serve/huggingface_api_worker.py +++ b/fastchat/serve/huggingface_api_worker.py @@ -12,59 +12,19 @@ """ import argparse import asyncio -import base64 -import dataclasses -import gc -import logging import json -import os -import threading -import time from typing import List, Optional import uuid from fastapi import FastAPI, Request, BackgroundTasks from fastapi.responses import StreamingResponse, JSONResponse from huggingface_hub import InferenceClient -import requests - from fastchat.serve.model_worker import BaseModelWorker -try: - from transformers import ( - AutoTokenizer, - AutoModelForCausalLM, - LlamaTokenizer, - AutoModel, - ) -except ImportError: - from transformers import ( - AutoTokenizer, - AutoModelForCausalLM, - LLaMATokenizer, - AutoModel, - ) -import torch -import torch.nn.functional as F -from transformers import set_seed import uvicorn -from fastchat.constants import WORKER_HEART_BEAT_INTERVAL, ErrorCode, SERVER_ERROR_MSG -from fastchat.conversation import get_conv_template -from fastchat.model.model_adapter import ( - load_model, - add_model_args, - get_conversation_template, - get_generate_stream_function, -) -from fastchat.modules.gptq import GptqConfig -from fastchat.modules.awq import AWQConfig -from fastchat.utils import ( - build_logger, - pretty_print_semaphore, - get_context_length, - str_to_torch_dtype, -) +from fastchat.constants import ErrorCode, SERVER_ERROR_MSG +from fastchat.utils import build_logger worker_id = str(uuid.uuid4())[:8] @@ -231,10 +191,7 @@ def generate_gate(self, params): pass return json.loads(x[:-1].decode()) - @torch.inference_mode() def get_embeddings(self, params): - self.call_ct += 1 - raise NotImplementedError() From cd329a0dd6c52fc4ec24c6426a8bbb3d293c05c2 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Fri, 22 Sep 2023 16:24:08 +0800 Subject: [PATCH 09/13] remove the unused members and args --- fastchat/serve/huggingface_api_worker.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/fastchat/serve/huggingface_api_worker.py b/fastchat/serve/huggingface_api_worker.py index 14f5aec6e..389383e92 100644 --- a/fastchat/serve/huggingface_api_worker.py +++ b/fastchat/serve/huggingface_api_worker.py @@ -87,9 +87,7 @@ def __init__( model_names: List[str], limit_worker_concurrency: int, no_register: bool, - stream_interval: int = 2, conv_template: Optional[str] = None, - embed_in_truncate: bool = False, seed: Optional[int] = None, **kwargs, ): @@ -107,19 +105,12 @@ def __init__( self.api_base = api_base self.token = token self.context_len = context_length + self.seed = seed logger.info( f"Connecting with huggingface api {self.model_path} as {self.model_names} on worker {worker_id} ..." 
) - self.model = None - self.tokenizer = None - self.device = None - self.generate_stream_func = None - self.stream_interval = stream_interval - self.embed_in_truncate = embed_in_truncate - self.seed = seed - if not no_register: self.init_heart_beat() @@ -285,14 +276,12 @@ def create_model_worker(): parser.add_argument( "--conv-template", type=str, default=None, help="Conversation prompt template." ) - parser.add_argument("--embed-in-truncate", action="store_true") parser.add_argument( "--limit-worker-concurrency", type=int, default=5, help="Limit the model concurrency to prevent OOM.", ) - parser.add_argument("--stream-interval", type=int, default=2) parser.add_argument("--no-register", action="store_true") parser.add_argument( "--seed", @@ -327,9 +316,7 @@ def create_model_worker(): args.model_names, args.limit_worker_concurrency, no_register=args.no_register, - stream_interval=args.stream_interval, conv_template=args.conv_template, - embed_in_truncate=args.embed_in_truncate, seed=args.seed, ) return args, worker From 6a7f4ba6b1e76557305efd17b8fb8ea1359fde88 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Fri, 22 Sep 2023 16:44:16 +0800 Subject: [PATCH 10/13] rename the arg's name --- fastchat/serve/huggingface_api_worker.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/fastchat/serve/huggingface_api_worker.py b/fastchat/serve/huggingface_api_worker.py index 389383e92..a4caf8618 100644 --- a/fastchat/serve/huggingface_api_worker.py +++ b/fastchat/serve/huggingface_api_worker.py @@ -1,6 +1,6 @@ """ A model worker to call huggingface api. -The contents in supported_models.json : +JSON file format: { "falcon-180b-chat": { "model_path": "tiiuae/falcon-180B-chat", @@ -259,7 +259,10 @@ def create_model_worker(): "--controller-address", type=str, default="http://localhost:21001" ) parser.add_argument( - "--supported-models-file", type=str, default="supported_models.json" + "--model-info-file", + type=str, + required=True, + help="Huggingface API model's info file path", ) parser.add_argument( "--model", @@ -292,18 +295,18 @@ def create_model_worker(): args = parser.parse_args() logger.info(f"args: {args}") - with open(args.supported_models_file, "r") as f: - supported_models = json.load(f) + with open(args.model_info_file, "r", encoding="UTF-8") as f: + model_info = json.load(f) - if args.model not in supported_models: + if args.model not in model_info: raise ValueError( - f"Model {args.model} not supported. Please add it to {args.supported_models_file}." + f"Model {args.model} not supported. Please add it to {args.model_info_file}." 
) - model_path = supported_models[args.model]["model_path"] - api_base = supported_models[args.model]["api_base"] - token = supported_models[args.model]["token"] - context_length = supported_models[args.model]["context_length"] + model_path = model_info[args.model]["model_path"] + api_base = model_info[args.model]["api_base"] + token = model_info[args.model]["token"] + context_length = model_info[args.model]["context_length"] worker = HuggingfaceApiWorker( args.controller_address, From 3aa01c91f276f0d827a15071d57965275893e346 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Fri, 22 Sep 2023 18:56:20 +0800 Subject: [PATCH 11/13] support multiple huggingface models --- fastchat/serve/huggingface_api_worker.py | 145 ++++++++++++++++------- 1 file changed, 105 insertions(+), 40 deletions(-) diff --git a/fastchat/serve/huggingface_api_worker.py b/fastchat/serve/huggingface_api_worker.py index a4caf8618..aae8a17f2 100644 --- a/fastchat/serve/huggingface_api_worker.py +++ b/fastchat/serve/huggingface_api_worker.py @@ -19,6 +19,7 @@ from fastapi import FastAPI, Request, BackgroundTasks from fastapi.responses import StreamingResponse, JSONResponse from huggingface_hub import InferenceClient +import requests from fastchat.serve.model_worker import BaseModelWorker import uvicorn @@ -30,6 +31,8 @@ worker_id = str(uuid.uuid4())[:8] logger = build_logger("model_worker", f"model_worker_{worker_id}.log") +workers = [] +worker_map = {} app = FastAPI() @@ -111,9 +114,6 @@ def __init__( f"Connecting with huggingface api {self.model_path} as {self.model_names} on worker {worker_id} ..." ) - if not no_register: - self.init_heart_beat() - def count_token(self, params): # No tokenizer here ret = { @@ -186,71 +186,83 @@ def get_embeddings(self, params): raise NotImplementedError() -def release_worker_semaphore(): +def release_worker_semaphore(worker): worker.semaphore.release() -def acquire_worker_semaphore(): +def acquire_worker_semaphore(worker): if worker.semaphore is None: worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) return worker.semaphore.acquire() -def create_background_tasks(): +def create_background_tasks(worker): background_tasks = BackgroundTasks() - background_tasks.add_task(release_worker_semaphore) + background_tasks.add_task(lambda: release_worker_semaphore(worker)) return background_tasks @app.post("/worker_generate_stream") async def api_generate_stream(request: Request): params = await request.json() - await acquire_worker_semaphore() + worker = worker_map[params["model"]] + await acquire_worker_semaphore(worker) generator = worker.generate_stream_gate(params) - background_tasks = create_background_tasks() + background_tasks = create_background_tasks(worker) return StreamingResponse(generator, background=background_tasks) @app.post("/worker_generate") async def api_generate(request: Request): params = await request.json() - await acquire_worker_semaphore() + worker = worker_map[params["model"]] + await acquire_worker_semaphore(worker) output = worker.generate_gate(params) - release_worker_semaphore() + release_worker_semaphore(worker) return JSONResponse(output) @app.post("/worker_get_embeddings") async def api_get_embeddings(request: Request): params = await request.json() - await acquire_worker_semaphore() + worker = worker_map[params["model"]] + await acquire_worker_semaphore(worker) embedding = worker.get_embeddings(params) - release_worker_semaphore() + release_worker_semaphore(worker) return JSONResponse(content=embedding) @app.post("/worker_get_status") async def 
api_get_status(request: Request): - return worker.get_status() + return { + "model_names": [m for w in workers for m in w.model_names], + "speed": 1, + "queue_length": sum([w.get_queue_length() for w in workers]), + } @app.post("/count_token") async def api_count_token(request: Request): params = await request.json() + worker = worker_map[params["model"]] return worker.count_token(params) @app.post("/worker_get_conv_template") async def api_get_conv(request: Request): + params = await request.json() + worker = worker_map[params["model"]] return worker.get_conv_template() @app.post("/model_details") async def api_model_details(request: Request): + params = await request.json() + worker = worker_map[params["model"]] return {"context_length": worker.context_len} -def create_model_worker(): +def create_huggingface_api_worker(): parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=21002) @@ -264,20 +276,27 @@ def create_model_worker(): required=True, help="Huggingface API model's info file path", ) + + # support multi huggingface api models here parser.add_argument( "--model", type=str, - default="falcon-180b-chat", - help="The model name to be called.", + default=[], + action="append", + help="The models' names to be called.", ) parser.add_argument( "--model-names", type=lambda s: s.split(","), - default="falcon-180b-chat", - help="Optional display comma separated names", + action="append", + help="One or more model names. Values must be aligned with `--model` values.", ) parser.add_argument( - "--conv-template", type=str, default=None, help="Conversation prompt template." + "--conv-template", + type=str, + default=None, + action="append", + help="Conversation prompt template. Values must be aligned with `--model` values. If only one value is provided, it will be repeated for all models.", ) parser.add_argument( "--limit-worker-concurrency", @@ -293,38 +312,84 @@ def create_model_worker(): help="Overwrite the random seed for each generation.", ) args = parser.parse_args() - logger.info(f"args: {args}") + + if args.model_names is None: + args.model_names = [[x.split("/")[-1]] for x in args.model] + if args.conv_template is None: + args.conv_template = [None] * len(args.model) + elif len(args.conv_template) == 1: # Repeat the same template + args.conv_template = args.conv_template * len(args.model) with open(args.model_info_file, "r", encoding="UTF-8") as f: model_info = json.load(f) - if args.model not in model_info: - raise ValueError( - f"Model {args.model} not supported. Please add it to {args.model_info_file}." - ) + logger.info(f"args: {args}") - model_path = model_info[args.model]["model_path"] - api_base = model_info[args.model]["api_base"] - token = model_info[args.model]["token"] - context_length = model_info[args.model]["context_length"] + model_path_list = [] + api_base_list = [] + token_list = [] + context_length_list = [] - worker = HuggingfaceApiWorker( - args.controller_address, - args.worker_address, - worker_id, + for m in args.model: + if m not in model_info: + raise ValueError( + f"Model {args.model} not supported. Please add it to {args.model_info_file}." 
+ ) + model_path_list.append(model_info[m]["model_path"]) + api_base_list.append(model_info[m]["api_base"]) + token_list.append(model_info[m]["token"]) + context_length_list.append(model_info[m]["context_length"]) + + for ( + model_names, + conv_template, model_path, api_base, token, context_length, + ) in zip( args.model_names, - args.limit_worker_concurrency, - no_register=args.no_register, - conv_template=args.conv_template, - seed=args.seed, - ) - return args, worker + args.conv_template, + model_path_list, + api_base_list, + token_list, + context_length_list, + ): + m = HuggingfaceApiWorker( + args.controller_address, + args.worker_address, + worker_id, + model_path, + api_base, + token, + context_length, + model_names, + args.limit_worker_concurrency, + no_register=args.no_register, + conv_template=conv_template, + seed=args.seed, + ) + workers.append(m) + for name in model_names: + worker_map[name] = m + + # register all the models + url = args.controller_address + "/register_worker" + data = { + "worker_name": workers[0].worker_addr, + "check_heart_beat": not args.no_register, + "worker_status": { + "model_names": [m for w in workers for m in w.model_names], + "speed": 1, + "queue_length": sum([w.get_queue_length() for w in workers]), + }, + } + r = requests.post(url, json=data) + assert r.status_code == 200 + + return args, workers if __name__ == "__main__": - args, worker = create_model_worker() + args, workers = create_huggingface_api_worker() uvicorn.run(app, host=args.host, port=args.port, log_level="info") From cc6b7ac53ecabc2fb350043319b9e07bce700a1c Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Fri, 29 Sep 2023 16:17:15 +0800 Subject: [PATCH 12/13] move all model-related arguments into --model-info-file --- fastchat/serve/huggingface_api_worker.py | 63 ++++++++++-------------- 1 file changed, 27 insertions(+), 36 deletions(-) diff --git a/fastchat/serve/huggingface_api_worker.py b/fastchat/serve/huggingface_api_worker.py index aae8a17f2..98f32265e 100644 --- a/fastchat/serve/huggingface_api_worker.py +++ b/fastchat/serve/huggingface_api_worker.py @@ -7,8 +7,12 @@ "api_base": "https://api-inference.huggingface.co/models", "token": "hf_xxx", "context_length": 2048 + "model_names": "falcon-180b-chat", + "conv_template": null, } } + +Only "model_path", "api_base", and "token" are necessary, others are optional. """ import argparse import asyncio @@ -270,6 +274,7 @@ def create_huggingface_api_worker(): parser.add_argument( "--controller-address", type=str, default="http://localhost:21001" ) + # all model-related parameters are listed in --model-info-file parser.add_argument( "--model-info-file", type=str, @@ -277,27 +282,6 @@ def create_huggingface_api_worker(): help="Huggingface API model's info file path", ) - # support multi huggingface api models here - parser.add_argument( - "--model", - type=str, - default=[], - action="append", - help="The models' names to be called.", - ) - parser.add_argument( - "--model-names", - type=lambda s: s.split(","), - action="append", - help="One or more model names. Values must be aligned with `--model` values.", - ) - parser.add_argument( - "--conv-template", - type=str, - default=None, - action="append", - help="Conversation prompt template. Values must be aligned with `--model` values. 
If only one value is provided, it will be repeated for all models.", - ) parser.add_argument( "--limit-worker-concurrency", type=int, @@ -313,13 +297,6 @@ def create_huggingface_api_worker(): ) args = parser.parse_args() - if args.model_names is None: - args.model_names = [[x.split("/")[-1]] for x in args.model] - if args.conv_template is None: - args.conv_template = [None] * len(args.model) - elif len(args.conv_template) == 1: # Repeat the same template - args.conv_template = args.conv_template * len(args.model) - with open(args.model_info_file, "r", encoding="UTF-8") as f: model_info = json.load(f) @@ -329,16 +306,30 @@ def create_huggingface_api_worker(): api_base_list = [] token_list = [] context_length_list = [] + model_names_list = [] + conv_template_list = [] - for m in args.model: - if m not in model_info: - raise ValueError( - f"Model {args.model} not supported. Please add it to {args.model_info_file}." - ) + for m in model_info: model_path_list.append(model_info[m]["model_path"]) api_base_list.append(model_info[m]["api_base"]) token_list.append(model_info[m]["token"]) - context_length_list.append(model_info[m]["context_length"]) + + context_length = model_info[m].get("context_length", 1024) + model_names = model_info[m].get("model_names", [m.split("/")[-1]]) + if isinstance(model_names, str): + model_names = [model_names] + conv_template = model_info[m].get("conv_template", None) + + context_length_list.append(context_length) + model_names_list.append(model_names) + conv_template_list.append(conv_template) + + logger.info(f"Model paths: {model_path_list}") + logger.info(f"API bases: {api_base_list}") + logger.info(f"Tokens: {token_list}") + logger.info(f"Context lengths: {context_length_list}") + logger.info(f"Model names: {model_names_list}") + logger.info(f"Conv templates: {conv_template_list}") for ( model_names, @@ -348,8 +339,8 @@ def create_huggingface_api_worker(): token, context_length, ) in zip( - args.model_names, - args.conv_template, + model_names_list, + conv_template_list, model_path_list, api_base_list, token_list, From 331133bca839ee82617bc914c133d424d589ebd5 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Fri, 29 Sep 2023 16:32:13 +0800 Subject: [PATCH 13/13] organize the imports using isort --- fastchat/serve/huggingface_api_worker.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/fastchat/serve/huggingface_api_worker.py b/fastchat/serve/huggingface_api_worker.py index 98f32265e..29ddaa40c 100644 --- a/fastchat/serve/huggingface_api_worker.py +++ b/fastchat/serve/huggingface_api_worker.py @@ -17,21 +17,19 @@ import argparse import asyncio import json -from typing import List, Optional import uuid +from typing import List, Optional -from fastapi import FastAPI, Request, BackgroundTasks -from fastapi.responses import StreamingResponse, JSONResponse -from huggingface_hub import InferenceClient import requests -from fastchat.serve.model_worker import BaseModelWorker - import uvicorn +from fastapi import BackgroundTasks, FastAPI, Request +from fastapi.responses import JSONResponse, StreamingResponse +from huggingface_hub import InferenceClient -from fastchat.constants import ErrorCode, SERVER_ERROR_MSG +from fastchat.constants import SERVER_ERROR_MSG, ErrorCode +from fastchat.serve.model_worker import BaseModelWorker from fastchat.utils import build_logger - worker_id = str(uuid.uuid4())[:8] logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
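
A minimal client sketch for the worker added by this series (not part of any patch above). It assumes the worker is running with the defaults from create_huggingface_api_worker() (http://localhost:21002) and that a model registered under the name "falcon-180b-chat" exists in the --model-info-file JSON; both are placeholders to adjust. It posts to /worker_generate_stream and reads the null-byte-delimited JSON chunks yielded by generate_stream_gate:

import json

import requests

WORKER_ADDR = "http://localhost:21002"  # assumed default --worker-address / --port

params = {
    "model": "falcon-180b-chat",        # key looked up in worker_map on the server side
    "prompt": "User: Hello!\nFalcon:",  # the worker takes an already-formatted prompt string
    "temperature": 0.7,
    "top_p": 0.9,
    "max_new_tokens": 64,
    "stop": ["\nUser:"],
    "echo": False,
}

with requests.post(
    f"{WORKER_ADDR}/worker_generate_stream", json=params, stream=True, timeout=60
) as resp:
    text = ""
    # generate_stream_gate yields json.dumps(ret).encode() + b"\0", so split on NUL bytes.
    for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode())
        if data["error_code"] != 0:
            raise RuntimeError(data["text"])
        text = data["text"]  # each chunk carries the cumulative text generated so far
print(text)

Each streamed chunk repeats the full text generated so far (the worker accumulates token text before yielding), so the last chunk is the complete reply; a caller that wants token-by-token deltas has to diff consecutive chunks itself.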