Guard model has lm_head
tgaddair committed Nov 1, 2024
1 parent ad460c0 commit 43c129b
Showing 1 changed file with 1 addition and 1 deletion.
server/lorax_server/models/flash_causal_lm.py (1 addition, 1 deletion)
```diff
@@ -1627,7 +1627,7 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) ->
             lm_head_indices=batch.prefill_head_indices,
         )

-        if skip_lm_head:
+        if skip_lm_head and hasattr(self.model, "lm_head"):
            # re-run through the LM head as the graph did not capture it
            out = self.model.lm_head(out[0], adapter_data)
```
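
The change adds a `hasattr` guard so the eager LM-head re-run (needed because graph capture excluded it) only happens for models that actually expose an `lm_head`; per the commit title, models such as classifier-style guard models lack one, and calling `self.model.lm_head` on them would raise an `AttributeError`. Below is a minimal, self-contained sketch of the pattern. The `CausalModel` and `GuardModel` classes and the simplified `forward` are hypothetical stand-ins for illustration, not LoRAX's actual classes.

```python
# Hypothetical stand-ins, not LoRAX's real model classes.

class CausalModel:
    """Stand-in for a causal LM whose lm_head was left out of graph capture."""

    def lm_head(self, hidden_states, adapter_data):
        # Project hidden states to logits (represented here as a string).
        return f"logits({hidden_states})"


class GuardModel:
    """Stand-in for a classifier-style guard model with no lm_head attribute."""


def forward(model, out, adapter_data, skip_lm_head=True):
    # Before the fix, `if skip_lm_head:` alone would hit AttributeError on
    # GuardModel. The hasattr check lets models without an lm_head return
    # their graph output unchanged.
    if skip_lm_head and hasattr(model, "lm_head"):
        # re-run through the LM head as the graph did not capture it
        out = model.lm_head(out[0], adapter_data)
    return out


print(forward(CausalModel(), ("hidden",), adapter_data=None))  # logits(hidden)
print(forward(GuardModel(), ("hidden",), adapter_data=None))   # ('hidden',)
```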
