
Commit

Fixed turbo lora + compile
tgaddair committed Oct 28, 2024
1 parent f0693e9 commit e62e0f8
Showing 3 changed files with 14 additions and 1 deletion.
@@ -488,6 +488,7 @@ def forward(
adapter_data: AdapterBatchData,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
skip_lm_head: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
@@ -513,6 +514,10 @@ def forward(

if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

if skip_lm_head:
return hidden_states, None

logits, speculative_logits = self.lm_head(hidden_states, adapter_data)
return logits, speculative_logits

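The hunks above thread a new skip_lm_head flag through the model's forward pass so a caller can stop at the hidden states and apply the LM head itself, outside a captured CUDA graph. A minimal sketch of that control flow, using a toy module in place of the real flash-attention model (the class and layer names below are illustrative, not from the repo):

from typing import Optional, Tuple

import torch

class ToyCausalLM(torch.nn.Module):
    def __init__(self, hidden_size: int = 64, vocab_size: int = 128):
        super().__init__()
        self.layers = torch.nn.Linear(hidden_size, hidden_size)   # stand-in for the transformer stack
        self.lm_head = torch.nn.Linear(hidden_size, vocab_size)   # stand-in for the (speculative) LM head

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        lm_head_indices: Optional[torch.Tensor] = None,
        skip_lm_head: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        hidden_states = self.layers(inputs_embeds)
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        if skip_lm_head:
            # Return the raw hidden states; the caller applies the LM head later.
            return hidden_states, None
        logits = self.lm_head(hidden_states)
        return logits, None   # second element would hold speculative logits
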
6 changes: 6 additions & 0 deletions server/lorax_server/models/flash_causal_lm.py
@@ -1419,6 +1419,8 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) ->
lm_head_indices=batch.prefill_head_indices,
)
else:
skip_lm_head = get_speculative_tokens() > 0

# CUDA graph mode
out = model.forward(
input_ids=input_ids,
@@ -1436,6 +1438,10 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) ->
lm_head_indices=batch.prefill_head_indices,
)

if skip_lm_head:
# re-run through the LM head as the graph did not capture it
out = self.model.lm_head(out[0], adapter_data)

if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None

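Read together, the two hunks above gate the new behavior on get_speculative_tokens() and re-apply the LM head after the graph replay. A self-contained toy version of that caller-side pattern (the function and variable names here are made up for illustration; only the skip-and-reapply logic mirrors the change):

from typing import Optional, Tuple

import torch

def run_decode(
    model_forward,                    # callable standing in for the graphed forward
    lm_head: torch.nn.Module,         # eager LM head, applied outside the graph
    inputs: torch.Tensor,
    speculative_tokens: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    skip_lm_head = speculative_tokens > 0
    out = model_forward(inputs, skip_lm_head=skip_lm_head)
    if skip_lm_head:
        # re-run through the LM head as the graph did not capture it
        out = (lm_head(out[0]), None)
    return out

# Toy usage: a linear "trunk" and "head" standing in for the real model.
trunk = torch.nn.Linear(16, 16)
head = torch.nn.Linear(16, 32)

def toy_forward(x: torch.Tensor, skip_lm_head: bool = False):
    h = trunk(x)
    return (h, None) if skip_lm_head else (head(h), None)

logits, _ = run_decode(toy_forward, head, torch.randn(2, 16), speculative_tokens=1)
print(logits.shape)   # torch.Size([2, 32])
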
4 changes: 3 additions & 1 deletion server/lorax_server/utils/graph.py
@@ -19,7 +19,7 @@
from lorax_server.utils.attention.common import Seqlen
from lorax_server.utils.attention.utils import block_tables_to_ragged
from lorax_server.utils.sgmv import BGMV_MAX_RANK, PunicaWrapper
-from lorax_server.utils.state import BLOCK_SIZE, FLASH_INFER
+from lorax_server.utils.state import BLOCK_SIZE, FLASH_INFER, get_speculative_tokens

if TYPE_CHECKING:
from lorax_server.models.flash_causal_lm import FlashCausalLMBatch
@@ -339,6 +339,7 @@ def trace(
adapter_data=input_state.adapter_data,
prefill_cache_indices=None,
lm_head_indices=None,
skip_lm_head=get_speculative_tokens() > 0,
)
torch.cuda.synchronize()

@@ -356,6 +357,7 @@ def trace(
adapter_data=input_state.adapter_data,
prefill_cache_indices=None,
lm_head_indices=None,
skip_lm_head=get_speculative_tokens() > 0,
)

torch.cuda.synchronize(device)
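With these hunks, the traced graphs stop at the hidden states whenever speculative tokens are configured, and the LM head runs eagerly on replay (see flash_causal_lm.py above). A minimal sketch of that capture/replay split with PyTorch CUDA graphs, using toy trunk/lm_head modules in place of the real model and requiring a CUDA device:

import torch

device = torch.device("cuda")
trunk = torch.nn.Linear(64, 64, device=device)      # stand-in for the transformer layers
lm_head = torch.nn.Linear(64, 128, device=device)   # stand-in for the (speculative) LM head

static_input = torch.zeros(8, 64, device=device)

# Warm up on a side stream before capture, as PyTorch recommends.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_hidden = trunk(static_input)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    # Only the trunk is captured -- the equivalent of tracing with skip_lm_head=True.
    static_hidden = trunk(static_input)

def decode_step(x: torch.Tensor) -> torch.Tensor:
    static_input.copy_(x)            # refresh the graph's static input buffer
    graph.replay()                   # replay the captured trunk
    return lm_head(static_hidden)    # LM head runs eagerly, outside the graph

logits = decode_step(torch.randn(8, 64, device=device))
print(logits.shape)   # torch.Size([8, 128])
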
