diff --git a/README.md b/README.md
index 2847a76..70cda9f 100644
--- a/README.md
+++ b/README.md
@@ -56,6 +56,7 @@ We create 32 threads to submit chat tasks to the server, and the following figur
 ![benchmark_chat](assets/benchmark_chat.jpg)
 
 ## News
+- [2023/07/13] Support generation logprobs parameter.
 - [2023/06/18] Add ggml (llama.cpp gpt.cpp starcoder.cpp etc.) worker support.
 - [2023/06/09] Add LLama.cpp worker support.
 - [2023/06/01] Add HuggingFace Bert embedding worker support.
diff --git a/docs/openai_api.md b/docs/openai_api.md
index a94d0ce..ab3e419 100644
--- a/docs/openai_api.md
+++ b/docs/openai_api.md
@@ -73,7 +73,7 @@ Here we list the parameter compatibility of completions API.
 | `top_p` | ● | ● | `1.0` | - |
 | `n` | ● | ● | `1` | `COMPLETION_MAX_N` |
 | `stream` | ● | ● | `false` | - |
-| `logprobs` | ○ | ● | `0` | `COMPLETION_MAX_LOGPROBS` |
+| `logprobs` | ● | ● | `0` | `COMPLETION_MAX_LOGPROBS` |
 | `echo` | ● | ● | `false` | - |
 | `stop` | ● | ● | - | - |
 | `presence_penalty` | ○ | ● | - | - |
diff --git a/langport/data/conversation/settings/ningyu.py b/langport/data/conversation/settings/ningyu.py
new file mode 100644
index 0000000..d5a0b22
--- /dev/null
+++ b/langport/data/conversation/settings/ningyu.py
@@ -0,0 +1,13 @@
+from langport.data.conversation import (
+    ConversationSettings,
+    SeparatorStyle,
+)
+
+# Ningyu default template
+ningyu = ConversationSettings(
+    name="ningyu",
+    roles=("user", "assistant"),
+    sep_style=SeparatorStyle.ADD_COLON_SINGLE,
+    sep="\n### ",
+    stop_str="###",
+)
diff --git a/langport/model/adapters/ningyu.py b/langport/model/adapters/ningyu.py
new file mode 100644
index 0000000..245c711
--- /dev/null
+++ b/langport/model/adapters/ningyu.py
@@ -0,0 +1,22 @@
+from langport.data.conversation import ConversationHistory, SeparatorStyle
+from langport.data.conversation.conversation_settings import get_conv_settings
+from langport.model.model_adapter import BaseAdapter
+
+
+class NingYuAdapter(BaseAdapter):
+    """The model adapter for ningyu"""
+
+    def match(self, model_path: str):
+        return "ningyu" in model_path
+
+    def get_default_conv_template(self, model_path: str) -> ConversationHistory:
+        settings = get_conv_settings("ningyu")
+        return ConversationHistory(
+            system="""A chat between a curious user and an artificial intelligence assistant.
+The name of the assistant is NingYu (凝语).
+The assistant gives helpful, detailed, and polite answers to the user's questions.""",
+            messages=[],
+            offset=0,
+            settings=settings,
+        )
+
diff --git a/langport/model/compression.py b/langport/model/compression.py
index e06c2b2..8721e34 100644
--- a/langport/model/compression.py
+++ b/langport/model/compression.py
@@ -44,7 +44,11 @@ def __init__(self, weight=None, bias=None, device=None):
 
     def forward(self, input: Tensor) -> Tensor:
         weight = decompress(self.weight, default_compression_config)
-        return F.linear(input.to(weight.dtype), weight, self.bias)
+        if self.bias is not None:
+            bias = self.bias.to(weight.dtype)
+        else:
+            bias = self.bias
+        return F.linear(input.to(weight.dtype), weight, bias)
 
 
 def compress_module(module, target_device):
@@ -138,7 +142,6 @@ def load_compress_model(model_path, device, torch_dtype):
 
     return model, tokenizer
 
-
 def compress(tensor, config):
     """Simulate group-wise quantization."""
     if not config.enabled:
@@ -191,7 +194,6 @@ def compress(tensor, config):
     data = data.clamp_(0, B).round_().to(torch.uint8)
     return data, mn, scale, original_shape
 
-
 def decompress(packed_data, config):
     """Simulate group-wise dequantization."""
     if not config.enabled:
diff --git a/langport/model/executor/generation/ggml.py b/langport/model/executor/generation/ggml.py
index 80e4a2a..7a2cb20 100644
--- a/langport/model/executor/generation/ggml.py
+++ b/langport/model/executor/generation/ggml.py
@@ -28,14 +28,16 @@ def stream_generation(
     output_ids = []
 
     # Compatible with some models
-    top_k = 40 if task.top_k <= 1 else task.top_k
-    repetition_penalty = 1.17647 if task.repetition_penalty == 0.0 else task.repetition_penalty
+    top_k = 10 if task.top_k <= 1 else task.top_k
+    repetition_penalty = 1.01 if task.repetition_penalty == 0.0 else task.repetition_penalty
+    model.config.max_new_tokens = task.max_tokens
 
     finish_reason = "stop"
     n_tokens = 0
     for token in model.generate(
-        tokens, top_k=top_k, top_p=task.top_p, batch_size=512,
-        temperature=task.temperature, repetition_penalty=repetition_penalty):
+        tokens, top_k=top_k, top_p=task.top_p, batch_size=model.config.batch_size,
+        threads=model.config.threads, temperature=task.temperature,
+        last_n_tokens=256, repetition_penalty=repetition_penalty, reset=True):
         n_tokens += 1
         output_ids.append(token)
         if n_tokens == task.max_tokens:
@@ -94,6 +96,8 @@ def __init__(
         model_path: str,
        context_length: int,
         gpu_layers: int,
+        chunk_size: int,
+        threads: int,
         model_type: str = "llama",
         lib: Optional[str] = None,
     ) -> None:
@@ -105,6 +109,8 @@ def __init__(
             num_gpus=n_gpu,
             max_gpu_memory=None,
             gpu_layers=gpu_layers,
+            chunk_size=chunk_size,
+            threads=threads,
             lib=lib,
             model_type=model_type,
         )
diff --git a/langport/model/executor/generation/huggingface.py b/langport/model/executor/generation/huggingface.py
index 41ef4c3..644576e 100644
--- a/langport/model/executor/generation/huggingface.py
+++ b/langport/model/executor/generation/huggingface.py
@@ -1,4 +1,4 @@
-from typing import Any, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Union
 
 from langport.model.executor.generation import BaseStreamer
 from langport.model.executor.huggingface import HuggingfaceExecutor
@@ -6,13 +6,14 @@
 from langport.protocol.worker_protocol import (
     BaseWorkerResult,
     GenerationTask,
+    GenerationWorkerLogprobs,
     GenerationWorkerResult,
     UsageInfo,
 )
 
 import torch
 
-from transformers import PreTrainedModel, PreTrainedTokenizerBase
+from transformers import PreTrainedModel, PreTrainedTokenizer
 from transformers.generation.logits_process import (
     LogitsProcessor,
     LogitsProcessorList,
@@ -30,6 +31,12 @@
 
 import torch
 
+def token_to_unicode(token: str) -> str:
+    utf8_bytes = token.encode("utf-8")
+    # Convert the bytes to a string with \\x escape sequences
+    escaped_bytes = "".join([f"\\x{b:02x}" for b in utf8_bytes])
+    return escaped_bytes
+
 @cached(LRUCache(maxsize=64))
 def prepare_logits_processor(
     temperature: float, repetition_penalty: float, top_p: float, top_k: int
@@ -48,7 +55,7 @@ def prepare_logits_processor(
 
 
 class BatchingTask:
-    def __init__(self, tasks: List[GenerationTask], tokenizer: PreTrainedTokenizerBase, device: str, is_encoder_decoder: bool) -> None:
+    def __init__(self, tasks: List[GenerationTask], tokenizer: PreTrainedTokenizer, device: str, is_encoder_decoder: bool) -> None:
         self.batch_size = len(tasks)
         if self.batch_size == 0:
             return
@@ -79,6 +86,8 @@ def __init__(self, tasks: List[GenerationTask], tokenizer: PreTrainedTokenizerBa
 
         # variables used in the streaming process
         self.batch_tokens_cache: List[List[int]] = [[] for i in range(self.batch_size)]
+        self.batch_tokens_probs_cache: List[List[float]] = [[] for i in range(self.batch_size)]
+        self.batch_top_logprobs_cache: List[List[Dict[str, float]]] = [[] for i in range(self.batch_size)]
         self.stop = [False for i in range(self.batch_size)]
 
     def __len__(self):
@@ -132,15 +141,31 @@ def get_logits_processor_list(self, idx:int) -> LogitsProcessorList:
         self._check_idx(idx)
         return self.logits_processor_list[idx]
 
-    def get_generated_ids(self, idx:int) -> List[int]:
+    def get_generated_ids(self, idx: int) -> List[int]:
         self._check_idx(idx)
         return self.batch_tokens_cache[idx]
 
-    def get_generated_length(self, idx:int) -> int:
+    def get_generated_length(self, idx: int) -> int:
         return len(self.get_generated_ids(idx))
 
-    def update_new_token(self, batch_token: List[int]):
+    def get_generated_token_probs(self, idx: int) -> List[float]:
+        self._check_idx(idx)
+        return self.batch_tokens_probs_cache[idx]
+
+    def get_generated_top_logprobs(self, idx: int) -> List[Dict[int, float]]:
+        self._check_idx(idx)
+        return self.batch_top_logprobs_cache[idx]
+
+    def update_new_token(self, batch_token: List[int],
+        token_probs: Optional[List[Optional[float]]]=None,
+        top_logprobs: Optional[List[Optional[Dict[int, float]]]]=None
+    ):
         self._check_batch_size(batch_token)
+        if token_probs is not None:
+            self._check_batch_size(token_probs)
+        if top_logprobs is not None:
+            self._check_batch_size(top_logprobs)
+
         for i, token in enumerate(batch_token):
             if self.is_stop(i):
                 continue
@@ -153,6 +178,11 @@ def update_new_token(self, batch_token: List[int]):
                 self.set_stop(i)
             if self.get_generated_length(i) == self.max_tokens[i]:
                 self.set_stop(i)
+
+            if token_probs is not None and token_probs[i] is not None:
+                self.batch_tokens_probs_cache[i].append(token_probs[i])
+            if top_logprobs is not None and top_logprobs[i] is not None:
+                self.batch_top_logprobs_cache[i].append(top_logprobs[i])
 
     def set_stop(self, idx:int):
         self._check_idx(idx)
@@ -227,6 +257,9 @@ def generate(self, inputs: BatchingTask,
         decoder_input_ids_list = []
 
         new_ids = []
+        # logprobs
+        token_probs = [None] * inputs.batch_size
+        top_logprobs = [None] * inputs.batch_size
 
         for i, task in enumerate(inputs.tasks):
             if inputs.is_stop(i):
@@ -253,10 +286,18 @@ def generate(self, inputs: BatchingTask,
             else:
                 probs = torch.softmax(last_token_logits, dim=-1)
                 token = int(torch.multinomial(probs, num_samples=1))
+
+            if task.logprobs is not None:
+                token_probs[i] = each_logits[0, token].item()
+                top_values, top_indices = torch.topk(each_logits[0, :], task.logprobs, dim=-1, largest=True, sorted=True)
+                item = {}
+                for top_i in range(len(top_values)):
+                    item[top_indices[top_i].item()] = top_values[top_i].item()
+                top_logprobs[i] = item
             new_ids.append(token)
 
         # update state
-        inputs.update_new_token(new_ids)
+        inputs.update_new_token(new_ids, token_probs=token_probs, top_logprobs=top_logprobs)
         if streamer:
             streamer.put(new_ids)
 
@@ -284,7 +325,7 @@ def generate(self, inputs: BatchingTask,
 
 class GenerationWorkerStreamer(BaseStreamer):
     def __init__(self, task_batch: BatchingTask,
-                 tokenizer: PreTrainedTokenizerBase,
+                 tokenizer: PreTrainedTokenizer,
                  worker: "GenerationModelWorker") -> None:
         self.task_batch = task_batch
         self.tokenizer = tokenizer
@@ -299,13 +340,74 @@ def put(self, value):
             if (self.done[i] or generated_len % self.stream_interval != 0) and self.done[i]==self.task_batch.is_stop(i):
                 continue
             task = self.task_batch.tasks[i]
-            text = self.tokenizer.decode(self.task_batch.get_generated_ids(i), skip_special_tokens=True)
+
+            token_ids = self.task_batch.get_generated_ids(i)
+
+            # text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
+            tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
+            text = self.tokenizer.convert_tokens_to_string(tokens)
+
+            # get offset mapping from token to text
+            text_offset = []
+            for token_i in range(0, len(tokens)):
+                if token_i == 0:
+                    text_offset.append(-1)
+                    continue
+                prefix_text = self.tokenizer.convert_tokens_to_string(tokens[:token_i])
+                if text.startswith(prefix_text):
+                    text_offset.append(len(prefix_text))
+                else:
+                    text_offset.append(-1)
+
+            last_id = len(text)
+            for token_i in reversed(range(0, len(tokens))):
+                if text_offset[token_i] == -1:
+                    text_offset[token_i] = last_id
+                else:
+                    last_id = text_offset[token_i]
+
+            token_logprobs = self.task_batch.get_generated_token_probs(i)
+            top_logprobs = self.task_batch.get_generated_top_logprobs(i)
+            if top_logprobs is not None:
+                top_logprobs_new = []
+                for prob in top_logprobs:
+                    top_logprobs_new.append({self.tokenizer.convert_ids_to_tokens(k): v for k, v in prob.items()})
+                top_logprobs = top_logprobs_new
+
+            # remove stop words
             stop_pos = stop_by_stopwords(text, 0, task.stop)
             if stop_pos != -1:
+                token_stop_pos = len(tokens)
+                for token_i in reversed(range(0, len(text_offset))):
+                    if text_offset[token_i] < stop_pos:
+                        token_stop_pos = token_i + 1
+                        break
+
                 self.task_batch.set_stop(i)
+
+                # remove tokens after stop pos
                 text = text[:stop_pos]
+                tokens = tokens[:token_stop_pos]
+                if token_logprobs is not None:
+                    token_logprobs = token_logprobs[:token_stop_pos]
+                if top_logprobs is not None:
+                    top_logprobs = top_logprobs[:token_stop_pos]
+                text_offset = text_offset[:token_stop_pos]
+
             prompt_len = self.task_batch.get_prompt_length(i)
-
+
+            # logprobs
+            if self.task_batch.tasks[i].logprobs is not None:
+                logprobs = GenerationWorkerLogprobs(
+                    tokens=tokens,
+                    token_logprobs=token_logprobs,
+                    top_logprobs=top_logprobs,
+                    text_offset=text_offset,
+                )
+            else:
+                logprobs = None
+
+            # push task to queue
             if self.task_batch.is_stop(i):
                 if generated_len == self.task_batch.max_tokens[i]:
                     finish_reason = "length"
@@ -321,6 +423,7 @@ def put(self, value):
                             total_tokens=prompt_len + generated_len,
                             completion_tokens=generated_len,
                         ),
+                        logprobs=logprobs,
                        finish_reason=finish_reason,
                     )
                 )
@@ -339,6 +442,7 @@ def put(self, value):
                             total_tokens=prompt_len + generated_len,
                             completion_tokens=generated_len,
                         ),
+                        logprobs=logprobs,
                        finish_reason=None,
                     )
                 )
diff --git a/langport/model/executor/ggml.py b/langport/model/executor/ggml.py
index 8af765d..0ea1f10 100644
--- a/langport/model/executor/ggml.py
+++ b/langport/model/executor/ggml.py
@@ -28,6 +28,8 @@ def __init__(
         lib: Optional[str] = None,
         gpu_layers: int = 0,
         model_type: str = 'llama',
+        chunk_size: int = 1024,
+        threads: int = -1,
         load_8bit: bool = False,
         cpu_offloading: bool = False,
     ) -> None:
@@ -44,6 +46,8 @@ def __init__(
         # ctransformers has a bug
         self.lib = lib
         self.model_type = model_type
+        self.chunk_size = chunk_size
+        self.threads = threads
 
     def load_model(self, model_path: str, from_pretrained_kwargs: dict):
 
@@ -51,6 +55,8 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         config = Config()
         setattr(config, 'stream', True)
         setattr(config, 'gpu_layers', self.gpu_layers)
+        setattr(config, 'batch_size', self.chunk_size)
+        setattr(config, 'threads', self.threads)
         auto_config = AutoConfig(config=config, model_type=self.model_type)
 
         model = AutoModelForCausalLM.from_pretrained(model_path, config=auto_config,
diff --git a/langport/protocol/openai_api_protocol.py b/langport/protocol/openai_api_protocol.py
index b9c4cc8..2c4b402 100644
--- a/langport/protocol/openai_api_protocol.py
+++ b/langport/protocol/openai_api_protocol.py
@@ -130,10 +130,16 @@ class CompletionRequest(BaseModel):
     user: Optional[str] = None
 
 
+class CompletionLogprobs(BaseModel):
+    tokens: List[str]
+    token_logprobs: List[float]
+    top_logprobs: List[Dict[str, float]]
+    text_offset: List[int]
+
 class CompletionResponseChoice(BaseModel):
     index: int
     text: str
-    logprobs: Optional[int] = None
+    logprobs: Optional[CompletionLogprobs] = None
     finish_reason: Optional[Literal["stop", "length"]]
 
 
@@ -149,7 +155,7 @@ class CompletionResponse(BaseModel):
 class CompletionResponseStreamChoice(BaseModel):
     index: int
     text: str
-    logprobs: Optional[float] = None
+    logprobs: Optional[CompletionLogprobs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
 
 
diff --git a/langport/protocol/worker_protocol.py b/langport/protocol/worker_protocol.py
index 300761a..d262ffb 100644
--- a/langport/protocol/worker_protocol.py
+++ b/langport/protocol/worker_protocol.py
@@ -89,6 +89,7 @@ class GenerationTask(BaseWorkerTask):
     stop: Optional[Union[List[str], str]] = None
     echo: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = None
+    logprobs: Optional[int] = None
 
 class UsageInfo(BaseModel):
     prompt_tokens: int = 0
@@ -106,8 +107,14 @@ class EmbeddingWorkerResult(BaseWorkerResult):
     embedding: List[float]
     usage: UsageInfo = None
 
+class GenerationWorkerLogprobs(BaseModel):
+    tokens: List[str]
+    token_logprobs: List[float]
+    top_logprobs: List[Dict[str, float]]
+    text_offset: List[int]
+
 class GenerationWorkerResult(BaseWorkerResult):
     text: str
-    logprobs: Optional[int] = None
+    logprobs: Optional[GenerationWorkerLogprobs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
     usage: UsageInfo = None
diff --git a/langport/routers/gateway/openai_compatible.py b/langport/routers/gateway/openai_compatible.py
index 37c54ee..e435816 100644
--- a/langport/routers/gateway/openai_compatible.py
+++ b/langport/routers/gateway/openai_compatible.py
@@ -32,6 +32,7 @@
     ChatCompletionStreamResponse,
     ChatMessage,
     ChatCompletionResponseChoice,
+    CompletionLogprobs,
     CompletionRequest,
     CompletionResponse,
     CompletionResponseChoice,
@@ -68,6 +69,7 @@ def get_gen_params(
     echo: Optional[bool],
     stream: Optional[bool],
     stop: Optional[Union[str, List[str]]],
+    logprobs: Optional[int]=None,
 ) -> Dict[str, Any]:
     # is_chatglm = "chatglm" in model_name.lower()
     conv = get_conversation_template(model_name)
@@ -101,6 +103,7 @@ def get_gen_params(
         "max_tokens": max_tokens,
         "echo": echo,
         "stream": stream,
+        "logprobs": logprobs,
     }
 
     if stop is None:
@@ -138,10 +141,14 @@ async def generate_completion_stream_generator(app_settings: AppSettings, payloa
                 delta_text = decoded_unicode[len(previous_text) :]
                 previous_text = decoded_unicode
 
+                if content.logprobs is None:
+                    logprobs = None
+                else:
+                    logprobs = CompletionLogprobs.parse_obj(content.logprobs.dict())
                 choice_data = CompletionResponseStreamChoice(
                     index=i,
                     text=delta_text,
-                    logprobs=content.logprobs,
+                    logprobs=logprobs,
                     finish_reason=content.finish_reason,
                 )
                 chunk = CompletionStreamResponse(
@@ -305,11 +312,15 @@ async def completions_non_stream(app_settings: AppSettings, payload: Dict[str, A
         content = await content_task
         if content.error_code != ErrorCode.OK:
             return create_error_response(content.error_code, content.message)
+        if content.logprobs is None:
+            logprobs = None
+        else:
+            logprobs = CompletionLogprobs.parse_obj(content.logprobs.dict())
         choices.append(
             CompletionResponseChoice(
                 index=i,
                 text=content.text,
-                logprobs=content.logprobs,
+                logprobs=logprobs,
                 finish_reason=content.finish_reason,
             )
         )
@@ -375,6 +386,7 @@ async def api_completions(app_settings: AppSettings, request: CompletionRequest)
         echo=request.echo,
         stream=request.stream,
         stop=request.stop,
+        logprobs=request.logprobs,
     )
 
     if request.stream:
diff --git a/langport/routers/server/generation_node.py b/langport/routers/server/generation_node.py
index 8cd138d..f0e370d 100644
--- a/langport/routers/server/generation_node.py
+++ b/langport/routers/server/generation_node.py
@@ -60,6 +60,7 @@ async def api_completion_stream(request: Request):
         stop=params.get("stop", None),
         echo=params.get("echo", False),
         stop_token_ids=params.get("stop_token_ids", None),
+        logprobs=params.get("logprobs", None),
     ))
     background_tasks = create_background_tasks(app.node)
     return StreamingResponse(generator, background=background_tasks)
@@ -79,6 +80,7 @@ async def api_completion(request: Request):
         stop=params.get("stop", None),
         echo=params.get("echo", False),
         stop_token_ids=params.get("stop_token_ids", None),
+        logprobs=params.get("logprobs", None),
     ))
     completion = None
     for chunk in generator:
diff --git a/langport/service/server/ggml_generation_worker.py b/langport/service/server/ggml_generation_worker.py
index a71994a..ad6724d 100644
--- a/langport/service/server/ggml_generation_worker.py
+++ b/langport/service/server/ggml_generation_worker.py
@@ -21,9 +21,10 @@
     add_model_args(parser)
     parser.add_argument("--model-name", type=str, help="Optional display name")
     parser.add_argument("--limit-model-concurrency", type=int, default=8)
-    parser.add_argument("--batch", type=int, default=1)
     parser.add_argument("--stream-interval", type=int, default=2)
+    parser.add_argument("--chunk-size", type=int, default=512)
+    parser.add_argument("--threads", type=int, default=-1)
     parser.add_argument("--context-length", type=int, default=2048)
     parser.add_argument("--gpu-layers", type=int, default=0)
     parser.add_argument("--lib", type=str, default=None, choices=["avx2", "avx", "basic"], help="The path to a shared library or one of avx2, avx, basic.")
@@ -58,6 +59,8 @@
         gpu_layers=args.gpu_layers,
         lib=args.lib,
         model_type=args.model_type,
+        chunk_size=args.chunk_size,
+        threads=args.threads,
     )
 
     app.node = GenerationModelWorker(
@@ -66,7 +69,7 @@
         init_neighborhoods_addr=args.neighbors,
         executor=executor,
         limit_model_concurrency=args.limit_model_concurrency,
-        max_batch=args.batch,
+        max_batch=1,
         stream_interval=args.stream_interval,
         logger=logger,
     )
diff --git a/langport/version.py b/langport/version.py
index 9a6f28d..0225648 100644
--- a/langport/version.py
+++ b/langport/version.py
@@ -1 +1 @@
-LANGPORT_VERSION = "0.3.0"
\ No newline at end of file
+LANGPORT_VERSION = "0.3.1"
\ No newline at end of file
diff --git a/langport/workers/generation_worker.py b/langport/workers/generation_worker.py
index 52faaa9..f7b0319 100644
--- a/langport/workers/generation_worker.py
+++ b/langport/workers/generation_worker.py
@@ -61,7 +61,6 @@ async def set_features(self):
 
     async def set_model_name(self):
         await self.set_local_state("model_name", self.executor.model_name, ttl=360)
-
     async def generation_stream(self, task: GenerationTask):
         prompt_tokens = len(self.executor.tokenize(task.prompt))
         max_tokens = task.max_tokens
diff --git a/pyproject.toml b/pyproject.toml
index 4eb3fbd..839b3d5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "langport"
-version = "0.3.0"
+version = "0.3.1"
 description = "A large language model serving platform."
 readme = "README.md"
 requires-python = ">=3.8"