From a90d44392384c6e1b2dbf86f2eea2f6830f88f8f Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 10 Jan 2024 10:35:46 -0800 Subject: [PATCH] OpenAI v1 Chat Completions API (#171) --- docs/reference/openapi.json | 5 ++ proto/generate.proto | 2 + router/client/src/client.rs | 1 + router/src/health.rs | 1 + router/src/lib.rs | 40 ++++++++++- router/src/scheduler.rs | 1 + router/src/server.rs | 71 +++++++++++++++++-- router/src/validation.rs | 3 + server/lorax_server/models/bloom.py | 4 +- server/lorax_server/models/causal_lm.py | 6 +- server/lorax_server/models/flash_causal_lm.py | 14 ++-- server/lorax_server/models/flash_mistral.py | 6 +- server/lorax_server/models/flash_mixtral.py | 6 +- server/lorax_server/models/galactica.py | 6 +- server/lorax_server/models/model.py | 2 + server/lorax_server/models/mpt.py | 4 +- server/lorax_server/models/seq2seq_lm.py | 6 +- server/lorax_server/server.py | 12 +++- server/lorax_server/utils/adapter.py | 10 ++- server/lorax_server/utils/tokenizer.py | 28 ++++++++ server/tests/models/test_bloom.py | 5 +- server/tests/models/test_causal_lm.py | 5 +- server/tests/models/test_santacoder.py | 3 + server/tests/models/test_seq2seq_lm.py | 5 +- 24 files changed, 219 insertions(+), 27 deletions(-) create mode 100644 server/lorax_server/utils/tokenizer.py diff --git a/docs/reference/openapi.json b/docs/reference/openapi.json index 18b1d8562..2dc12d229 100644 --- a/docs/reference/openapi.json +++ b/docs/reference/openapi.json @@ -607,6 +607,11 @@ "api_token": { "type": "string", "nullable": true + }, + "apply_chat_template": { + "type": "boolean", + "default": "false", + "example": true } } }, diff --git a/proto/generate.proto b/proto/generate.proto index c5b52d1dd..e7afe4f26 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -102,6 +102,8 @@ message Request { bool prefill_logprobs = 6; /// Adapter index uint32 adapter_index = 7; + /// Apply chat template to inputs + bool apply_chat_template = 8; } message Batch { diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 6442c97c6..68b000f09 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -134,6 +134,7 @@ impl Client { }), adapter_index: 0, prefill_logprobs: true, + apply_chat_template: false, }); n_tokens += max_input_length; } diff --git a/router/src/health.rs b/router/src/health.rs index 237b3596a..578ac5716 100644 --- a/router/src/health.rs +++ b/router/src/health.rs @@ -52,6 +52,7 @@ impl Health { ignore_eos_token: false, }), adapter_index: 0, + apply_chat_template: false, }; let batch = Batch { id: BATCH_ID, diff --git a/router/src/lib.rs b/router/src/lib.rs index ce5003bb6..f011462cc 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -145,6 +145,9 @@ pub(crate) struct GenerateParameters { #[schema(default = "true")] pub decoder_input_details: bool, #[serde(default)] + #[schema(default = "false")] + pub apply_chat_template: bool, + #[serde(default)] #[schema( exclusive_minimum = 0, nullable = true, @@ -177,6 +180,7 @@ fn default_parameters() -> GenerateParameters { watermark: false, details: false, decoder_input_details: false, + apply_chat_template: false, seed: None, } } @@ -320,7 +324,7 @@ struct UsageInfo { #[derive(Clone, Debug, Deserialize, ToSchema)] struct ChatCompletionRequest { model: String, - messages: Vec, + messages: Vec>, temperature: Option, top_p: Option, n: Option, @@ -451,6 +455,40 @@ impl From for CompatGenerateRequest { watermark: false, details: true, decoder_input_details: req.logprobs.is_some(), + 
apply_chat_template: false, + seed: None, + }, + stream: req.stream.unwrap_or(false), + } + } +} + +impl From<ChatCompletionRequest> for CompatGenerateRequest { + fn from(req: ChatCompletionRequest) -> Self { + CompatGenerateRequest { + inputs: serde_json::to_string(&req.messages).unwrap(), + parameters: GenerateParameters { + adapter_id: req.model.parse().ok(), + adapter_source: None, + api_token: None, + best_of: req.n.map(|x| x as usize), + temperature: req.temperature, + repetition_penalty: None, + top_k: None, + top_p: req.top_p, + typical_p: None, + do_sample: !req.n.is_none(), + max_new_tokens: req + .max_tokens + .map(|x| x as u32) + .unwrap_or(default_max_new_tokens()), + return_full_text: None, + stop: req.stop, + truncate: None, + watermark: false, + details: true, + decoder_input_details: false, + apply_chat_template: true, seed: None, }, stream: req.stream.unwrap_or(false), diff --git a/router/src/scheduler.rs index 09577c4fb..18b7d8663 100644 --- a/router/src/scheduler.rs +++ b/router/src/scheduler.rs @@ -334,6 +334,7 @@ impl AdapterSchedulerState { parameters: Some(entry.request.parameters.clone()), stopping_parameters: Some(entry.request.stopping_parameters.clone()), adapter_index: adapter.index(), + apply_chat_template: entry.request.apply_chat_template, }); // Set batch_time entry.batch_time = Some(Instant::now()); diff --git a/router/src/server.rs index 41f54d4c9..c4d051e8e 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -3,10 +3,10 @@ use crate::health::Health; use crate::infer::{InferError, InferResponse, InferStreamResponse}; use crate::validation::ValidationError; use crate::{ - BestOfSequence, CompatGenerateRequest, CompletionRequest, CompletionResponse, - CompletionStreamResponse, Details, ErrorResponse, FinishReason, GenerateParameters, - GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken, StreamDetails, - StreamResponse, Token, Validation, + BestOfSequence, ChatCompletionRequest, CompatGenerateRequest, CompletionRequest, + CompletionResponse, CompletionStreamResponse, Details, ErrorResponse, FinishReason, + GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken, + StreamDetails, StreamResponse, Token, Validation, }; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; @@ -78,7 +78,7 @@ async fn compat_generate( } } -/// Generate tokens if `stream == false` or a stream of token if `stream == true` +/// OpenAI compatible completions endpoint #[utoipa::path( post, tag = "LoRAX", @@ -138,6 +138,66 @@ async fn completions_v1( } } +/// OpenAI compatible chat completions endpoint +#[utoipa::path( +post, +tag = "LoRAX", +path = "/v1/chat/completions", +request_body = ChatCompletionRequest, +responses( +(status = 200, description = "Generated Text", +content( +("application/json" = ChatCompletionResponse), +("text/event-stream" = ChatCompletionStreamResponse), +)), +(status = 424, description = "Generation Error", body = ErrorResponse, +example = json ! ({"error": "Request failed during generation"})), +(status = 429, description = "Model is overloaded", body = ErrorResponse, +example = json ! ({"error": "Model is overloaded"})), +(status = 422, description = "Input validation error", body = ErrorResponse, +example = json ! ({"error": "Input validation error"})), +(status = 500, description = "Incomplete generation", body = ErrorResponse, +example = json !
({"error": "Incomplete generation"})), +) +)] +#[instrument(skip(infer, req))] +async fn chat_completions_v1( + default_return_full_text: Extension, + infer: Extension, + req: Json, +) -> Result)> { + let req = req.0; + let mut gen_req = CompatGenerateRequest::from(req); + + // default return_full_text given the pipeline_tag + if gen_req.parameters.return_full_text.is_none() { + gen_req.parameters.return_full_text = Some(default_return_full_text.0) + } + + // switch on stream + if gen_req.stream { + let callback = move |resp: StreamResponse| { + Event::default() + .json_data(CompletionStreamResponse::from(resp)) + .map_or_else( + |err| { + tracing::error!("Failed to serialize CompletionStreamResponse: {err}"); + Event::default() + }, + |data| data, + ) + }; + + let (headers, stream) = + generate_stream_with_callback(infer, Json(gen_req.into()), callback).await; + Ok((headers, Sse::new(stream).keep_alive(KeepAlive::default())).into_response()) + } else { + let (headers, generation) = generate(infer, Json(gen_req.into())).await?; + // wrap generation inside a Vec to match api-inference + Ok((headers, Json(vec![CompletionResponse::from(generation.0)])).into_response()) + } +} + /// LoRAX endpoint info #[utoipa::path( get, @@ -771,6 +831,7 @@ pub async fn run( .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) .route("/v1/completions", post(completions_v1)) + .route("/v1/chat/completions", post(chat_completions_v1)) // AWS Sagemaker route .route("/invocations", post(compat_generate)) // Base Health route diff --git a/router/src/validation.rs b/router/src/validation.rs index 9d211e685..8f985cdcc 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -145,6 +145,7 @@ impl Validation { watermark, adapter_id, decoder_input_details, + apply_chat_template, .. 
} = request.parameters; @@ -270,6 +271,7 @@ impl Validation { parameters, stopping_parameters, adapter, + apply_chat_template, }) } @@ -344,6 +346,7 @@ pub(crate) struct ValidGenerateRequest { pub parameters: NextTokenChooserParameters, pub stopping_parameters: StoppingCriteriaParameters, pub adapter: Adapter, + pub apply_chat_template: bool, } #[derive(Error, Debug)] diff --git a/server/lorax_server/models/bloom.py b/server/lorax_server/models/bloom.py index 1448f9aab..09b6c829a 100644 --- a/server/lorax_server/models/bloom.py +++ b/server/lorax_server/models/bloom.py @@ -20,6 +20,7 @@ weight_files, Weights, ) +from lorax_server.utils.tokenizer import TokenizerManager class BloomCausalLMBatch(CausalLMBatch): @@ -28,10 +29,11 @@ def from_pb( cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, + tokenizers: TokenizerManager, dtype: torch.dtype, device: torch.device, ) -> "CausalLMBatch": - batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device) + batch = super().from_pb(pb=pb, tokenizer=tokenizer, tokenizers=tokenizers, dtype=dtype, device=device) batch.keys_head_dim_last = False return batch diff --git a/server/lorax_server/models/causal_lm.py b/server/lorax_server/models/causal_lm.py index f215f3a2c..6658281c3 100644 --- a/server/lorax_server/models/causal_lm.py +++ b/server/lorax_server/models/causal_lm.py @@ -1,3 +1,4 @@ +import json import torch import inspect @@ -15,6 +16,7 @@ ) from lorax_server.pb import generate_pb2 from lorax_server.utils import NextTokenChooser, StoppingCriteria, Sampling +from lorax_server.utils.tokenizer import TokenizerManager tracer = trace.get_tracer(__name__) @@ -69,6 +71,7 @@ def from_pb( cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, + tokenizers: TokenizerManager, dtype: torch.dtype, device: torch.device, ) -> "CausalLMBatch": @@ -86,7 +89,8 @@ def from_pb( adapter_indices_list = [] for i, r in enumerate(pb.requests): requests_idx_mapping[r.id] = i - inputs.append(r.inputs) + req_inputs = tokenizers.get_inputs(r, tokenizer) + inputs.append(req_inputs) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index e55d0ddd9..40f3119ac 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1,4 +1,5 @@ from collections import defaultdict +import json import math import itertools from loguru import logger @@ -29,11 +30,11 @@ from lorax_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, load_module_map from lorax_server.utils.dist import MEMORY_FRACTION -from lorax_server.utils.lora import LM_HEAD, AdapterBatchData, AdapterBatchMetadata, BatchedLoraWeights, MergedLoraWeights +from lorax_server.utils.lora import AdapterBatchData, AdapterBatchMetadata, BatchedLoraWeights, MergedLoraWeights from lorax_server.utils.segments import SegmentConcatBuilder, find_segments from lorax_server.utils.weights import shard_on_dim from lorax_server.utils.graph import GraphCache -from lorax_server.utils.sgmv import get_tmp_tensor +from lorax_server.utils.tokenizer import TokenizerManager tracer = trace.get_tracer(__name__) @@ -114,13 +115,15 @@ def from_pb( cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, + tokenizers: TokenizerManager, dtype: torch.dtype, device: 
torch.device, ) -> "FlashCausalLMBatch": batch_inputs = [] max_truncation = 0 for r in pb.requests: - batch_inputs.append(r.inputs) + inputs = tokenizers.get_inputs(r, tokenizer) + batch_inputs.append(inputs) max_truncation = max(max_truncation, r.truncate) batch_tokenized_inputs = tokenizer( @@ -746,7 +749,7 @@ def load_adapter(self, adapter_id, adapter_source, adapter_index): elif adapter_id != BASE_MODEL_ADAPTER_ID: logger.info(f"Loading adapter weights into model: {adapter_id}") weight_names = tuple([v[0] for v in self.target_to_layer.values()]) - module_map, adapter_config, adapter_weight_names = load_module_map( + module_map, adapter_config, adapter_weight_names, adapter_tokenizer = load_module_map( self.model_id, adapter_id, adapter_source, weight_names ) @@ -758,6 +761,9 @@ def load_adapter(self, adapter_id, adapter_source, adapter_index): if len(unused_weight_names) > 0: logger.warning(f"{adapter_id} unused adapter weights: {unused_weight_names}") + + if adapter_tokenizer is not None: + self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer) self.adapter_id = adapter_id diff --git a/server/lorax_server/models/flash_mistral.py b/server/lorax_server/models/flash_mistral.py index 1b4beae1a..bdf0a33a9 100644 --- a/server/lorax_server/models/flash_mistral.py +++ b/server/lorax_server/models/flash_mistral.py @@ -1,3 +1,4 @@ +import json import math import torch import torch.distributed @@ -32,6 +33,7 @@ from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ, AdapterBatchData, AdapterBatchMetadata from lorax_server.utils.segments import find_segments +from lorax_server.utils.tokenizer import TokenizerManager tracer = trace.get_tracer(__name__) @@ -55,6 +57,7 @@ def from_pb( cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, + tokenizers: TokenizerManager, dtype: torch.dtype, device: torch.device, ) -> "FlashCausalLMBatch": @@ -64,7 +67,8 @@ def from_pb( batch_inputs = [] max_truncation = 0 for r in pb.requests: - batch_inputs.append(r.inputs) + inputs = tokenizers.get_inputs(r, tokenizer) + batch_inputs.append(inputs) max_truncation = max(max_truncation, r.truncate) batch_tokenized_inputs = tokenizer( diff --git a/server/lorax_server/models/flash_mixtral.py b/server/lorax_server/models/flash_mixtral.py index 18269da47..f1f507f52 100644 --- a/server/lorax_server/models/flash_mixtral.py +++ b/server/lorax_server/models/flash_mixtral.py @@ -1,3 +1,4 @@ +import json import math import torch import torch.distributed @@ -39,6 +40,7 @@ from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID from lorax_server.utils.lora import LM_HEAD, AdapterBatchData, AdapterBatchMetadata from lorax_server.utils.segments import find_segments +from lorax_server.utils.tokenizer import TokenizerManager tracer = trace.get_tracer(__name__) @@ -62,6 +64,7 @@ def from_pb( cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, + tokenizers: TokenizerManager, dtype: torch.dtype, device: torch.device, ) -> "FlashCausalLMBatch": @@ -71,7 +74,8 @@ def from_pb( batch_inputs = [] max_truncation = 0 for r in pb.requests: - batch_inputs.append(r.inputs) + inputs = tokenizers.get_inputs(r, tokenizer) + batch_inputs.append(inputs) max_truncation = max(max_truncation, r.truncate) batch_tokenized_inputs = tokenizer( diff --git a/server/lorax_server/models/galactica.py b/server/lorax_server/models/galactica.py index e95340afa..6d8fb8933 100644 --- 
a/server/lorax_server/models/galactica.py +++ b/server/lorax_server/models/galactica.py @@ -1,3 +1,4 @@ +import json import re import torch import torch.distributed @@ -20,6 +21,7 @@ weight_files, Weights, ) +from lorax_server.utils.tokenizer import TokenizerManager # CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py @@ -73,6 +75,7 @@ def from_pb( cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, + tokenizers: TokenizerManager, dtype: torch.dtype, device: torch.device, ) -> "GalacticaCausalLMBatch": @@ -90,7 +93,8 @@ def from_pb( for i, r in enumerate(pb.requests): requests_idx_mapping[r.id] = i # Add escape_custom_split_sequence to the CausalLMBatch logic - inputs.append(escape_custom_split_sequence(r.inputs)) + req_inputs = tokenizers.get_inputs(r, tokenizer) + inputs.append(escape_custom_split_sequence(req_inputs)) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index 1b3f1e543..f782bf294 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -8,6 +8,7 @@ from lorax_server.models.types import Batch, GeneratedText from lorax_server.pb.generate_pb2 import InfoResponse from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID +from lorax_server.utils.tokenizer import TokenizerManager B = TypeVar("B", bound=Batch) @@ -26,6 +27,7 @@ def __init__( ): self.model = model.eval() self.tokenizer = tokenizer + self.tokenizers = TokenizerManager() self.all_special_ids = set(tokenizer.all_special_ids) self.requires_padding = requires_padding self.dtype = dtype diff --git a/server/lorax_server/models/mpt.py b/server/lorax_server/models/mpt.py index 52a06718b..a2388eb68 100644 --- a/server/lorax_server/models/mpt.py +++ b/server/lorax_server/models/mpt.py @@ -19,6 +19,7 @@ weight_files, Weights, ) +from lorax_server.utils.tokenizer import TokenizerManager tracer = trace.get_tracer(__name__) @@ -29,10 +30,11 @@ def from_pb( cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, + tokenizers: TokenizerManager, dtype: torch.dtype, device: torch.device, ) -> "CausalLMBatch": - batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device) + batch = super().from_pb(pb=pb, tokenizer=tokenizer, tokenizers=tokenizers, dtype=dtype, device=device) batch.keys_head_dim_last = False return batch diff --git a/server/lorax_server/models/seq2seq_lm.py b/server/lorax_server/models/seq2seq_lm.py index 2004a2c2d..250b8d628 100644 --- a/server/lorax_server/models/seq2seq_lm.py +++ b/server/lorax_server/models/seq2seq_lm.py @@ -1,3 +1,4 @@ +import json import torch from dataclasses import dataclass @@ -14,6 +15,7 @@ ) from lorax_server.pb import generate_pb2 from lorax_server.utils import NextTokenChooser, StoppingCriteria, Sampling +from lorax_server.utils.tokenizer import TokenizerManager tracer = trace.get_tracer(__name__) @@ -71,6 +73,7 @@ def from_pb( cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, + tokenizers: TokenizerManager, dtype: torch.dtype, device: torch.device, ) -> "Seq2SeqLMBatch": @@ -89,7 +92,8 @@ def from_pb( padding_right_offset = 0 max_decode_tokens = 0 for i, r in enumerate(pb.requests): - inputs.append(r.inputs) + req_inputs = tokenizers.get_inputs(r, tokenizer) + inputs.append(req_inputs) requests_idx_mapping[r.id] = i decoder_input_lengths.append(1) 
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index db2a4dfbd..12a80d144 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -69,7 +69,11 @@ async def FilterBatch(self, request, context): async def Warmup(self, request, context): batch = self.model.batch_type.from_pb( - request.batch, self.model.tokenizer, self.model.dtype, self.model.device + request.batch, + self.model.tokenizer, + self.model.tokenizers, + self.model.dtype, + self.model.device, ) max_supported_total_tokens = self.model.warmup(batch) @@ -79,7 +83,11 @@ async def Warmup(self, request, context): async def Prefill(self, request, context): batch = self.model.batch_type.from_pb( - request.batch, self.model.tokenizer, self.model.dtype, self.model.device + request.batch, + self.model.tokenizer, + self.model.tokenizers, + self.model.dtype, + self.model.device, ) generations, next_batch = self.model.generate_token(batch) diff --git a/server/lorax_server/utils/adapter.py b/server/lorax_server/utils/adapter.py index a3ed66c12..d568c74c4 100644 --- a/server/lorax_server/utils/adapter.py +++ b/server/lorax_server/utils/adapter.py @@ -11,7 +11,7 @@ from peft import LoraConfig from peft.utils import transpose from safetensors.torch import load_file, save_file -from transformers import AutoConfig +from transformers import AutoConfig, AutoTokenizer from tqdm import tqdm from filelock import FileLock @@ -42,6 +42,12 @@ def load_module_map(model_id, adapter_id, adapter_source, weight_names): f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. " f"Use --model-id '{adapter_config.base_model_name_or_path}' instead.") + try: + adapter_tokenizer = AutoTokenizer.from_pretrained(config_path) + except Exception: + # Adapter does not have a tokenizer, so fallback to base model tokenizer + adapter_tokenizer = None + # load adapter weights from all shards (should have relatively small memory footprint) adapter_filenames = source.weight_files() adapter_weights = {} @@ -63,7 +69,7 @@ def load_module_map(model_id, adapter_id, adapter_source, weight_names): } adapter_weight_names.add(lora_a_name) adapter_weight_names.add(lora_b_name) - return module_map, adapter_config, adapter_weight_names + return module_map, adapter_config, adapter_weight_names, adapter_tokenizer def compute_delta_weight( diff --git a/server/lorax_server/utils/tokenizer.py b/server/lorax_server/utils/tokenizer.py new file mode 100644 index 000000000..4a051861d --- /dev/null +++ b/server/lorax_server/utils/tokenizer.py @@ -0,0 +1,28 @@ +import json +from typing import Optional +from transformers import PreTrainedTokenizerBase + +from lorax_server.pb import generate_pb2 + + +class TokenizerManager: + def __init__(self): + self.tokenizers = {} + + def add_tokenizer(self, adapter_idx: int, tokenizer: PreTrainedTokenizerBase): + self.tokenizers[adapter_idx] = tokenizer + + def get_tokenizer(self, adapter_idx: int, default: PreTrainedTokenizerBase) -> Optional[PreTrainedTokenizerBase]: + return self.tokenizers.get(adapter_idx, default) + + def get_inputs( + self, + r: generate_pb2.Request, + base_tokenizer: PreTrainedTokenizerBase, + ) -> str: + inputs = r.inputs + if r.apply_chat_template: + inputs = json.loads(inputs) + tokenizer = self.get_tokenizer(r.adapter_index, base_tokenizer) + inputs = tokenizer.apply_chat_template(inputs, add_generation_prompt=True, tokenize=False) + return inputs diff --git 
a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index ac73b6f1c..88e4ed5eb 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -8,6 +8,7 @@ from lorax_server.models.causal_lm import CausalLMBatch from lorax_server.utils import weight_hub_files, download_weights from lorax_server.models.bloom import BloomCausalLMBatch, BLOOMSharded +from lorax_server.utils.tokenizer import TokenizerManager @pytest.fixture(scope="session") @@ -44,7 +45,7 @@ def default_pb_batch(default_pb_request): @pytest.fixture def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer): return BloomCausalLMBatch.from_pb( - default_pb_batch, bloom_560m_tokenizer, torch.float32, torch.device("cpu") + default_pb_batch, bloom_560m_tokenizer, TokenizerManager(), torch.float32, torch.device("cpu") ) @@ -58,7 +59,7 @@ def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer) batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2) return BloomCausalLMBatch.from_pb( - batch_pb, bloom_560m_tokenizer, torch.float32, torch.device("cpu") + batch_pb, bloom_560m_tokenizer, TokenizerManager(), torch.float32, torch.device("cpu") ) diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 8b177f991..9ce4bd813 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -6,6 +6,7 @@ from lorax_server.pb import generate_pb2 from lorax_server.models.causal_lm import CausalLM, CausalLMBatch +from lorax_server.utils.tokenizer import TokenizerManager @pytest.fixture(scope="session") @@ -40,7 +41,7 @@ def default_pb_batch(default_pb_request): @pytest.fixture def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer): return CausalLMBatch.from_pb( - default_pb_batch, gpt2_tokenizer, torch.float32, torch.device("cpu") + default_pb_batch, gpt2_tokenizer, TokenizerManager(), torch.float32, torch.device("cpu") ) @@ -54,7 +55,7 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer): batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2) return CausalLMBatch.from_pb( - batch_pb, gpt2_tokenizer, torch.float32, torch.device("cpu") + batch_pb, gpt2_tokenizer, TokenizerManager(), torch.float32, torch.device("cpu") ) diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py index ef9854e56..d6efa4af6 100644 --- a/server/tests/models/test_santacoder.py +++ b/server/tests/models/test_santacoder.py @@ -3,6 +3,7 @@ from lorax_server.pb import generate_pb2 from lorax_server.models.causal_lm import CausalLMBatch from lorax_server.models.santacoder import SantaCoder +from lorax_server.utils.tokenizer import TokenizerManager @pytest.fixture(scope="session") @@ -49,6 +50,7 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat batch = CausalLMBatch.from_pb( default_pb_batch, default_santacoder.tokenizer, + TokenizerManager(), default_santacoder.dtype, default_santacoder.device, ) @@ -77,6 +79,7 @@ def test_fim_santacoder_generate_token_completion( batch = CausalLMBatch.from_pb( default_fim_pb_batch, default_santacoder.tokenizer, + TokenizerManager(), default_santacoder.dtype, default_santacoder.device, ) diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 44a0520fd..4c5f15b25 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -7,6 +7,7 @@ from lorax_server.pb import generate_pb2 from 
lorax_server.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch +from lorax_server.utils.tokenizer import TokenizerManager @pytest.fixture(scope="session") @@ -43,7 +44,7 @@ def default_pb_batch(default_pb_request): @pytest.fixture def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer): return Seq2SeqLMBatch.from_pb( - default_pb_batch, mt0_small_tokenizer, torch.float32, torch.device("cpu") + default_pb_batch, mt0_small_tokenizer, TokenizerManager(), torch.float32, torch.device("cpu") ) @@ -57,7 +58,7 @@ def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokeni batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2) return Seq2SeqLMBatch.from_pb( - batch_pb, mt0_small_tokenizer, torch.float32, torch.device("cpu") + batch_pb, mt0_small_tokenizer, TokenizerManager(), torch.float32, torch.device("cpu") )
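
Usage note (not part of the patch): the sketch below exercises the new /v1/chat/completions route registered in router/src/server.rs. The host and port (localhost:8080) and the adapter name are placeholder assumptions; the request fields (model, messages, max_tokens, temperature, stream) come from ChatCompletionRequest, and "model" is mapped to adapter_id by the ChatCompletionRequest -> CompatGenerateRequest conversion above.

# Usage sketch only -- not part of the patch. Assumes a running LoRAX server
# on localhost:8080; "my-org/my-chat-adapter" is a hypothetical adapter ID.
import requests

resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        # Routed to adapter_id by the ChatCompletionRequest -> CompatGenerateRequest conversion.
        "model": "my-org/my-chat-adapter",
        # Messages are forwarded as JSON and rendered server-side with the
        # tokenizer's chat template (apply_chat_template=true); see
        # TokenizerManager.get_inputs in server/lorax_server/utils/tokenizer.py.
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What does this endpoint do?"},
        ],
        "max_tokens": 64,
        "temperature": 0.7,
        "stream": False,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json())  # non-streaming responses are wrapped in a JSON list

The same behavior is available on the existing /generate route by setting apply_chat_template: true in parameters and passing the JSON-encoded messages as inputs, which is exactly what the CompatGenerateRequest conversion does.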