Skip to content

Commit

Permalink
Fixed sliding window
Browse files: browse the repository at this point in the history
  • Loading branch information
tgaddair committed Mar 13, 2024
1 parent e33a5aa commit 8525e5f
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions server/lorax_server/models/flash_causal_lm.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -295,20 +295,23 @@ def from_pb(
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
position_ids = torch.cat(position_ids)
slot_indices = torch.cat(slot_indices)
prefill_cache_indices = torch.cat(prefill_cache_indices)
if SLIDING_WINDOW is not None:
prefill_cache_indices = torch.cat(prefill_cache_indices)
else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]
slot_indices = slot_indices[0]
prefill_cache_indices = prefill_cache_indices[0]
if SLIDING_WINDOW is not None:
prefill_cache_indices = prefill_cache_indices[0]

cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32
)

position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device)
prefill_cache_indices = prefill_cache_indices.to(device)
if SLIDING_WINDOW is not None:
prefill_cache_indices = prefill_cache_indices.to(device)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
input_lengths_tensor = torch.tensor(
input_lengths, dtype=torch.int32, device=device
Expand Down Expand Up @@ -364,7 +367,7 @@ def from_pb(
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
),
prefill_cache_indices=prefill_cache_indices,
prefill_cache_indices=prefill_cache_indices if SLIDING_WINDOW is not None else None,
)

@tracer.start_as_current_span("filter")
Expand Down

0 comments on commit 8525e5f

Please sign in to comment.