Cleanup comments
tgaddair committed Nov 1, 2024
1 parent 3ebcbea commit 922c5d6
Showing 1 changed file with 9 additions and 32 deletions.
41 changes: 9 additions & 32 deletions server/lorax_server/models/flash_causal_lm.py
@@ -553,9 +553,6 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
segment_indices=adapter_segment_indices,
)

# logger.info("!!! FILTER slots {} -> {}", self.slots, slots)
# logger.info("!!! FILTER slots_indices {} -> {}", self.slot_indices, slot_indices)

return type(self)(
batch_id=self.batch_id,
requests=requests,
@@ -767,19 +764,14 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
sequence_processors=sequence_processors,
)

-# Discard speculative IDs if they are not present in all batches
+# We skip computing the speculative_ids when the batch size is too large, so
+# we must check that all batches have them, otherwise they must be discarded
+speculative_ids = None
if get_speculative_tokens() > 0:
-    keep_speculative_ids = all(b.speculative_ids is not None for b in batches)
-    if not keep_speculative_ids:
+    if all(b.speculative_ids is not None for b in batches):
+        speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
+    else:
        logger.info("Discarding speculative IDs, not every batch has them")

-    speculative_ids = (
-        torch.cat(
-            [b.speculative_ids for b in batches], dim=0)
-        if keep_speculative_ids else None
-    )
-else:
-    speculative_ids = None

if adapter_segment_builder is not None:
adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
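
The rewritten block above keeps speculative IDs through a batch concatenation only when speculative decoding is enabled and every batch still carries them; otherwise they are discarded. A minimal standalone sketch of that guard, using a stripped-down stand-in rather than the real FlashCausalLMBatch (the Batch dataclass and merge_speculative_ids helper below are illustrative, not code from this repository):

from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class Batch:
    # Stand-in for FlashCausalLMBatch: only the field this sketch needs.
    speculative_ids: Optional[torch.Tensor]


def merge_speculative_ids(batches: List[Batch], speculative_tokens: int) -> Optional[torch.Tensor]:
    # Mirrors the guard above: speculative IDs survive the merge only when
    # speculative decoding is on and every batch still has them.
    if speculative_tokens <= 0:
        return None
    if all(b.speculative_ids is not None for b in batches):
        return torch.cat([b.speculative_ids for b in batches], dim=0)
    # At least one batch skipped computing them (e.g. it was too large), so discard.
    return None


# The second batch never computed speculative IDs, so the merged batch drops them.
merged = merge_speculative_ids(
    [Batch(torch.zeros((2, 3), dtype=torch.long)), Batch(None)],
    speculative_tokens=3,
)
assert merged is None
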
@@ -791,9 +783,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
segment_indices=adapter_segment_indices,
)

# logger.info("!!! CONCATENATE slots {} -> {}", [b.slots for b in batches], slots)
# logger.info("!!! CONCATENATE slots_indices {} -> {}", [b.slot_indices for b in batches], slot_indices)

return cls(
batch_id=batches[0].batch_id,
requests=requests,
@@ -1064,9 +1053,6 @@ def prepare_for_prefill(self):
segment_indices=adapter_segment_indices,
)

# logger.info("!!! PREPARE_FOR_PREFILL slots {}", self.slots)
# logger.info("!!! PREPARE_FOR_PREFILL slots_indices {}", self.slot_indices)

def __len__(self):
return len(self.requests)

@@ -1525,11 +1511,6 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) ->
cache_lengths_tensor = batch.cache_lengths_tensor
max_s = batch.max_current_length

# logger.info("!!! BLOCKS={} {}\n SLOTS={} {}\n SLOT_INDICES={} {}",
# block_tables.tolist(), block_tables.shape,
# batch.slots.tolist(), batch.slots.shape,
# batch.slot_indices.tolist(), batch.slot_indices.shape)

if batch.speculative_ids is not None:
speculative_ids = batch.speculative_ids

@@ -1540,18 +1521,14 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) ->
arange_int = arange.to(dtype=torch.int32)
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)

+# Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices,
+# then update the slots with the additional indices to ensure we're grabbing the ones that have been
+# allocated
slot_indices = (batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
-# logger.info("!!! SLOT INDICES {} -> {}", batch.slot_indices.tolist(), slot_indices.tolist())

slots = batch.slots[slot_indices]

-# slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
-# logger.info("!!! NEW SLOTS {}", slots.tolist(), slots.shape)

-# logger.info("!!! BEFORE {} {}", input_lengths, batch.cache_lengths_tensor)
input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
cache_lengths_tensor = (batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)).reshape(-1)
-# logger.info("!!! AFTER {} {}", input_lengths, cache_lengths_tensor)

block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous()
max_s = max_s + speculative_length
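
The comment introduced above captures the reasoning behind this block: with prefix caching, the KV-cache slots assigned to a sequence need not be contiguous, so the per-sequence slot_indices are expanded across the speculative positions and then used to gather the actual slot numbers from batch.slots. A small worked example of that expansion with made-up values (B, new_length, slots, and slot_indices below are illustrative, not taken from a real batch):

import torch

B = 2                   # sequences in the batch
speculative_length = 2  # speculative tokens per sequence
new_length = speculative_length + 1

# Possibly discontiguous slot numbers owned by the batch, and the index of the
# next slot for each sequence within that tensor.
slots = torch.tensor([10, 11, 12, 13, 40, 41, 42, 43])
slot_indices = torch.tensor([1, 5])

arange_int = torch.arange(0, new_length, dtype=torch.int32)

# Expand each sequence's slot index across the extra speculative positions...
expanded_indices = (slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
# ...then gather the actual slot numbers, which may jump between regions.
expanded_slots = slots[expanded_indices]

print(expanded_indices.tolist())  # [1, 2, 3, 5, 6, 7]
print(expanded_slots.tolist())    # [11, 12, 13, 41, 42, 43]
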