diff --git a/server/lorax_server/models/bloom.py b/server/lorax_server/models/bloom.py
index 0d8999e1a..3c3caab44 100644
--- a/server/lorax_server/models/bloom.py
+++ b/server/lorax_server/models/bloom.py
@@ -9,8 +9,7 @@
 )
 
 from lorax_server.adapters import AdapterBatchData
-from lorax_server.models import CausalLM
-from lorax_server.models.causal_lm import CausalLMBatch
+from lorax_server.models.causal_lm import CausalLM, CausalLMBatch
 from lorax_server.models.custom_modeling.bloom_modeling import (
     ATTN_DENSE,
     ATTN_QKV,
diff --git a/server/lorax_server/models/causal_lm.py b/server/lorax_server/models/causal_lm.py
index cda4f8870..33658b7f6 100644
--- a/server/lorax_server/models/causal_lm.py
+++ b/server/lorax_server/models/causal_lm.py
@@ -6,7 +6,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
 
 from lorax_server.adapters import AdapterBatchData, AdapterBatchMetadata
-from lorax_server.models import Model
+from lorax_server.models.model import Model
 from lorax_server.models.types import (
     Batch,
     GeneratedText,
diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py
index 36ffc5e8a..163829c9d 100644
--- a/server/lorax_server/models/flash_causal_lm.py
+++ b/server/lorax_server/models/flash_causal_lm.py
@@ -13,7 +13,7 @@
 from transformers import PreTrainedTokenizerBase
 
 from lorax_server.adapters import AdapterBatchData, AdapterBatchMetadata
-from lorax_server.models import Model
+from lorax_server.models.model import Model
 from lorax_server.models.cache_manager import (
     BLOCK_SIZE,
     get_cache_manager,
diff --git a/server/lorax_server/models/galactica.py b/server/lorax_server/models/galactica.py
index e356ae780..ce3aae953 100644
--- a/server/lorax_server/models/galactica.py
+++ b/server/lorax_server/models/galactica.py
@@ -9,8 +9,7 @@
     PreTrainedTokenizerBase,
 )
 
-from lorax_server.models import CausalLM
-from lorax_server.models.causal_lm import CausalLMBatch
+from lorax_server.models.causal_lm import CausalLM, CausalLMBatch
 from lorax_server.models.custom_modeling.opt_modeling import OPTForCausalLM
 from lorax_server.pb import generate_pb2
 from lorax_server.utils import (
diff --git a/server/lorax_server/models/gpt_neox.py b/server/lorax_server/models/gpt_neox.py
index cd3177baa..fb7b48042 100644
--- a/server/lorax_server/models/gpt_neox.py
+++ b/server/lorax_server/models/gpt_neox.py
@@ -7,7 +7,7 @@
     AutoTokenizer,
 )
 
-from lorax_server.models import CausalLM
+from lorax_server.models.causal_lm import CausalLM
 from lorax_server.models.custom_modeling.neox_modeling import (
     GPTNeoxForCausalLM,
 )
diff --git a/server/lorax_server/models/mpt.py b/server/lorax_server/models/mpt.py
index 334570c42..ce84d1e23 100644
--- a/server/lorax_server/models/mpt.py
+++ b/server/lorax_server/models/mpt.py
@@ -8,8 +8,7 @@
 from opentelemetry import trace
 from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
 
-from lorax_server.models import CausalLM
-from lorax_server.models.causal_lm import CausalLMBatch
+from lorax_server.models.causal_lm import CausalLM, CausalLMBatch
 from lorax_server.models.custom_modeling.mpt_modeling import (
     MPTForCausalLM,
 )
diff --git a/server/lorax_server/models/opt.py b/server/lorax_server/models/opt.py
index 666caa824..461d9a265 100644
--- a/server/lorax_server/models/opt.py
+++ b/server/lorax_server/models/opt.py
@@ -7,7 +7,7 @@
     AutoTokenizer,
 )
 
-from lorax_server.models import CausalLM
+from lorax_server.models.causal_lm import CausalLM
 from lorax_server.models.custom_modeling.opt_modeling import OPTForCausalLM
 from lorax_server.utils import (
     Weights,
diff --git a/server/lorax_server/models/rw.py b/server/lorax_server/models/rw.py
index 97e94bd44..3928d753d 100644
--- a/server/lorax_server/models/rw.py
+++ b/server/lorax_server/models/rw.py
@@ -3,7 +3,7 @@
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from lorax_server.models import CausalLM
+from lorax_server.models.causal_lm import CausalLM
 
 
 class RW(CausalLM):
diff --git a/server/lorax_server/models/santacoder.py b/server/lorax_server/models/santacoder.py
index 22de0a1e9..f11407ce1 100644
--- a/server/lorax_server/models/santacoder.py
+++ b/server/lorax_server/models/santacoder.py
@@ -4,7 +4,7 @@
 import torch.distributed
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from lorax_server.models import CausalLM
+from lorax_server.models.causal_lm import CausalLM
 
 FIM_PREFIX = "<fim-prefix>"
 FIM_MIDDLE = "<fim-middle>"
diff --git a/server/lorax_server/models/seq2seq_lm.py b/server/lorax_server/models/seq2seq_lm.py
index 1abbb6a77..6d20e8097 100644
--- a/server/lorax_server/models/seq2seq_lm.py
+++ b/server/lorax_server/models/seq2seq_lm.py
@@ -5,7 +5,7 @@
 from opentelemetry import trace
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedTokenizerBase
 
-from lorax_server.models import Model
+from lorax_server.models.model import Model
 from lorax_server.models.types import (
     Batch,
     GeneratedText,
diff --git a/server/tests/models/test_model.py b/server/tests/models/test_model.py
index 8055a1bb5..ecd8e52a8 100644
--- a/server/tests/models/test_model.py
+++ b/server/tests/models/test_model.py
@@ -3,7 +3,7 @@
 
 from transformers import AutoTokenizer
 
-from lorax_server.models import Model
+from lorax_server.models.model import Model
 
 
 def get_test_model():
@@ -17,9 +17,7 @@ def generate_token(self, batch):
     model_id = "meta-llama/Llama-2-7b-hf"
     tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-    model = TestModel(
-        model_id, torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu")
-    )
+    model = TestModel(model_id, torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu"))
     return model
 
 
@@ -28,17 +26,13 @@ def test_decode_streaming_english_spaces():
     model = get_test_model()
     truth = "Hello here, this is a simple test"
     all_input_ids = [15043, 1244, 29892, 445, 338, 263, 2560, 1243]
-    assert (
-        all_input_ids == model.tokenizer(truth, add_special_tokens=False)["input_ids"]
-    )
+    assert all_input_ids == model.tokenizer(truth, add_special_tokens=False)["input_ids"]
 
     decoded_text = ""
     offset = 0
     token_offset = 0
     for i in range(len(all_input_ids)):
-        text, offset, token_offset = model.decode_token(
-            all_input_ids[: i + 1], offset, token_offset
-        )
+        text, offset, token_offset = model.decode_token(all_input_ids[: i + 1], offset, token_offset)
         decoded_text += text
 
     assert decoded_text == truth
@@ -71,9 +65,7 @@ def test_decode_streaming_chinese_utf8():
     offset = 0
     token_offset = 0
     for i in range(len(all_input_ids)):
-        text, offset, token_offset = model.decode_token(
-            all_input_ids[: i + 1], offset, token_offset
-        )
+        text, offset, token_offset = model.decode_token(all_input_ids[: i + 1], offset, token_offset)
         decoded_text += text
 
     assert decoded_text == truth