Skip to content

Commit

Permalink
fix: Mixtral adapter loading wraps lm_head (#131)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair authored Dec 14, 2023
1 parent 9febb95 commit 3ad7aae
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
TensorParallelHead,
get_linear,
)
- from lorax_server.utils.lora import AdapterBatchData
+ from lorax_server.utils.lora import AdapterBatchData, LM_HEAD

if not HAS_FLASH_ATTN_V2:
raise ImportError("Mixtral model requires flash attn v2")
Expand Down Expand Up @@ -943,11 +943,11 @@ def __init__(self, config, weights):
super().__init__()

self.model = MixtralModel(config, weights)
- self.lm_head = TensorParallelHead.load(
+ self.lm_head = TensorParallelAdapterRowLinear.load(TensorParallelHead.load(
config,
prefix="lm_head",
weights=weights,
- )
+ ), 0, LM_HEAD, process_group=weights.process_group)
self.max_past = config.sliding_window
if self.max_past is None:
raise ValueError("max_past cannot be None")
Expand Down Expand Up @@ -989,5 +989,5 @@ def forward(
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
- logits = self.lm_head(hidden_states)
+ logits = self.lm_head(hidden_states, adapter_data)
return logits
48 changes: 24 additions & 24 deletions server/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
tokenizers = "0.15.0"
huggingface-hub = "^0.19.4"
- transformers = "4.35.2"
+ transformers = "4.36.0"
einops = "^0.6.1"
tiktoken = "^0.5.2"
texttable = { version = "^1.6.7", optional = true }
Expand Down
2 changes: 1 addition & 1 deletion server/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ tiktoken==0.5.2 ; python_version >= "3.9" and python_version < "4.0"
tokenizers==0.15.0 ; python_version >= "3.9" and python_version < "4.0"
torch==2.1.1+cu118 ; python_version >= "3.9" and python_version < "4.0"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "4.0"
- transformers==4.35.2 ; python_version >= "3.9" and python_version < "4.0"
+ transformers==4.36.0 ; python_version >= "3.9" and python_version < "4.0"
triton==2.1.0 ; python_version >= "3.9" and python_version < "4.0"
typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "4.0"
Expand Down

0 comments on commit 3ad7aae

Please sign in to comment.