Fixed imports
tgaddair committed Apr 4, 2024
1 parent b4761c2 commit f9b1fc5
Showing 11 changed files with 15 additions and 26 deletions.
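Nearly every hunk follows one pattern: imports that previously resolved through the lorax_server.models package __init__ now target the module that actually defines the class (the test file additionally collapses a few wrapped calls onto single lines). The commit message says only "Fixed imports", so the motivation is an assumption, but this is the usual fix for a circular import: the package __init__ imports its submodules to re-export their classes, so a submodule that imports back through the package can see a partially initialized __init__. A minimal sketch of the change's shape, taken directly from the hunks below:

# Before: resolved via lorax_server/models/__init__.py, which itself imports
# these submodules -- a circular-import hazard (assumed motivation; the
# commit message does not state one).
from lorax_server.models import CausalLM
from lorax_server.models.causal_lm import CausalLMBatch

# After: imported straight from the defining module, bypassing the package
# __init__ entirely.
from lorax_server.models.causal_lm import CausalLM, CausalLMBatch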
3 changes: 1 addition & 2 deletions server/lorax_server/models/bloom.py
@@ -9,8 +9,7 @@
 )
 
 from lorax_server.adapters import AdapterBatchData
-from lorax_server.models import CausalLM
-from lorax_server.models.causal_lm import CausalLMBatch
+from lorax_server.models.causal_lm import CausalLM, CausalLMBatch
 from lorax_server.models.custom_modeling.bloom_modeling import (
     ATTN_DENSE,
     ATTN_QKV,
2 changes: 1 addition & 1 deletion server/lorax_server/models/causal_lm.py
@@ -6,7 +6,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
 
 from lorax_server.adapters import AdapterBatchData, AdapterBatchMetadata
-from lorax_server.models import Model
+from lorax_server.models.model import Model
 from lorax_server.models.types import (
     Batch,
     GeneratedText,
2 changes: 1 addition & 1 deletion server/lorax_server/models/flash_causal_lm.py
@@ -13,7 +13,7 @@
 from transformers import PreTrainedTokenizerBase
 
 from lorax_server.adapters import AdapterBatchData, AdapterBatchMetadata
-from lorax_server.models import Model
+from lorax_server.models.model import Model
 from lorax_server.models.cache_manager import (
     BLOCK_SIZE,
     get_cache_manager,
3 changes: 1 addition & 2 deletions server/lorax_server/models/galactica.py
@@ -9,8 +9,7 @@
     PreTrainedTokenizerBase,
 )
 
-from lorax_server.models import CausalLM
-from lorax_server.models.causal_lm import CausalLMBatch
+from lorax_server.models.causal_lm import CausalLM, CausalLMBatch
 from lorax_server.models.custom_modeling.opt_modeling import OPTForCausalLM
 from lorax_server.pb import generate_pb2
 from lorax_server.utils import (
2 changes: 1 addition & 1 deletion server/lorax_server/models/gpt_neox.py
@@ -7,7 +7,7 @@
     AutoTokenizer,
 )
 
-from lorax_server.models import CausalLM
+from lorax_server.models.causal_lm import CausalLM
 from lorax_server.models.custom_modeling.neox_modeling import (
     GPTNeoxForCausalLM,
 )
3 changes: 1 addition & 2 deletions server/lorax_server/models/mpt.py
@@ -8,8 +8,7 @@
 from opentelemetry import trace
 from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
 
-from lorax_server.models import CausalLM
-from lorax_server.models.causal_lm import CausalLMBatch
+from lorax_server.models.causal_lm import CausalLM, CausalLMBatch
 from lorax_server.models.custom_modeling.mpt_modeling import (
     MPTForCausalLM,
 )
2 changes: 1 addition & 1 deletion server/lorax_server/models/opt.py
@@ -7,7 +7,7 @@
     AutoTokenizer,
 )
 
-from lorax_server.models import CausalLM
+from lorax_server.models.causal_lm import CausalLM
 from lorax_server.models.custom_modeling.opt_modeling import OPTForCausalLM
 from lorax_server.utils import (
     Weights,
2 changes: 1 addition & 1 deletion server/lorax_server/models/rw.py
@@ -3,7 +3,7 @@
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from lorax_server.models import CausalLM
+from lorax_server.models.causal_lm import CausalLM
 
 
 class RW(CausalLM):
2 changes: 1 addition & 1 deletion server/lorax_server/models/santacoder.py
@@ -4,7 +4,7 @@
 import torch.distributed
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from lorax_server.models import CausalLM
+from lorax_server.models.causal_lm import CausalLM
 
 FIM_PREFIX = "<fim-prefix>"
 FIM_MIDDLE = "<fim-middle>"
2 changes: 1 addition & 1 deletion server/lorax_server/models/seq2seq_lm.py
@@ -5,7 +5,7 @@
 from opentelemetry import trace
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedTokenizerBase
 
-from lorax_server.models import Model
+from lorax_server.models.model import Model
 from lorax_server.models.types import (
     Batch,
     GeneratedText,
18 changes: 5 additions & 13 deletions server/tests/models/test_model.py
@@ -3,7 +3,7 @@
 
 from transformers import AutoTokenizer
 
-from lorax_server.models import Model
+from lorax_server.models.model import Model
 
 
 def get_test_model():
@@ -17,9 +17,7 @@ def generate_token(self, batch):
     model_id = "meta-llama/Llama-2-7b-hf"
     tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-    model = TestModel(
-        model_id, torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu")
-    )
+    model = TestModel(model_id, torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu"))
     return model
 
 
@@ -28,17 +26,13 @@ def test_decode_streaming_english_spaces():
     model = get_test_model()
     truth = "Hello here, this is a simple test"
     all_input_ids = [15043, 1244, 29892, 445, 338, 263, 2560, 1243]
-    assert (
-        all_input_ids == model.tokenizer(truth, add_special_tokens=False)["input_ids"]
-    )
+    assert all_input_ids == model.tokenizer(truth, add_special_tokens=False)["input_ids"]
 
     decoded_text = ""
     offset = 0
     token_offset = 0
     for i in range(len(all_input_ids)):
-        text, offset, token_offset = model.decode_token(
-            all_input_ids[: i + 1], offset, token_offset
-        )
+        text, offset, token_offset = model.decode_token(all_input_ids[: i + 1], offset, token_offset)
         decoded_text += text
 
     assert decoded_text == truth
@@ -71,9 +65,7 @@ def test_decode_streaming_chinese_utf8():
     offset = 0
     token_offset = 0
     for i in range(len(all_input_ids)):
-        text, offset, token_offset = model.decode_token(
-            all_input_ids[: i + 1], offset, token_offset
-        )
+        text, offset, token_offset = model.decode_token(all_input_ids[: i + 1], offset, token_offset)
         decoded_text += text
 
     assert decoded_text == truth
