Commit 2731478

bug: fix the type-checking errors thrown by the new ruff version (#533)
ajtejankar authored Jul 1, 2024
Parent: f3a67bb · Commit: 2731478
Showing 3 changed files with 8 additions and 8 deletions.
server/lorax_server/models/causal_lm.py (4 changes: 2 additions & 2 deletions)
@@ -220,7 +220,7 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
         ]
 
         # Ensure that past_key_values tensors can be updated in-place
-        if type(self.past_key_values[0]) == tuple:
+        if isinstance(self.past_key_values[0], tuple):
             self.past_key_values = [list(layer) for layer in self.past_key_values]
 
         # Update tensors in-place to allow incremental garbage collection
@@ -375,7 +375,7 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
             # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
             # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
             # And ensure that we can update tensors in-place
-            if type(batch.past_key_values[0]) == tuple:
+            if isinstance(batch.past_key_values[0], tuple):
                 batch.past_key_values = [
                     [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values
                 ]
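These `type(...) == tuple` comparisons are what tripped the new ruff version: they match pycodestyle's E721 rule ("do not compare types; for instance checks use isinstance()"), which ruff enforces. The rule number is inferred from the diff, not stated in the commit. Beyond silencing the linter, the two checks genuinely differ for subclasses of tuple, as this small standalone sketch (not code from the lorax repo) shows:

    # Minimal sketch: isinstance() accepts subclasses, while an exact
    # type(...) == tuple comparison does not. `Layer` is a made-up
    # stand-in, not a type from the lorax codebase.
    from collections import namedtuple

    Layer = namedtuple("Layer", ["key", "value"])  # namedtuple subclasses tuple

    pkv_layer = Layer(key="k", value="v")

    print(type(pkv_layer) == tuple)      # False: exact type match only
    print(isinstance(pkv_layer, tuple))  # True: subclass instances count

For the plain tuples these batches actually receive, the two spellings agree, so the change is a behavior-preserving lint fix.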
server/lorax_server/models/seq2seq_lm.py (4 changes: 2 additions & 2 deletions)
@@ -217,7 +217,7 @@ def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]:
         self.encoder_last_hidden_state = self.encoder_last_hidden_state[keep_indices, -max_input_length:]
 
         # Ensure that past_key_values tensors can be updated in-place
-        if type(self.past_key_values[0]) == tuple:
+        if isinstance(self.past_key_values[0], tuple):
             self.past_key_values = [[t for t in layer] for layer in self.past_key_values]
 
         decoder_past_seq_len = max_decoder_input_length - 1
@@ -376,7 +376,7 @@ def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
                 batch.encoder_last_hidden_state = None
 
             # Ensure that we can update tensors in-place
-            if type(batch.past_key_values[0]) == tuple:
+            if isinstance(batch.past_key_values[0], tuple):
                 batch.past_key_values = [[t for t in layer] for layer in batch.past_key_values]
 
             # Add eventual padding tokens that were added while concatenating
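A side note on the comments these hunks preserve: `past_key_values` arrives as nested tuples, and both `filter` and `concatenate` rebuild each layer as a list precisely because tuples cannot be updated in place, which is why the tuple check exists at all. A minimal illustration, with plain strings standing in for tensors:

    # Why the batches convert each past_key_values layer to a list:
    # tuple entries cannot be reassigned, list entries can.
    layer = ("key_tensor", "value_tensor")  # stand-ins for real tensors

    try:
        layer[0] = "sliced_key_tensor"
    except TypeError as err:
        print(err)  # 'tuple' object does not support item assignment

    layer = list(layer)
    layer[0] = "sliced_key_tensor"  # succeeds once the layer is a list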
server/tests/utils/test_lora.py (8 changes: 4 additions & 4 deletions)
@@ -160,10 +160,10 @@ def test_batched_lora_weights_decode(
     assert rd.lora_a_ptr.shape == (expected[lora_rank][0],)
     assert rd.lora_b_ptr.shape == (expected[lora_rank][0],)
     assert all(rd.indices == expected_indices)
-    assert rd.segment_starts == None
-    assert rd.segment_ends == None
-    assert rd.tmp_shrink == None
-    assert rd.tmp_expand == None
+    assert rd.segment_starts is None
+    assert rd.segment_ends is None
+    assert rd.tmp_shrink is None
+    assert rd.tmp_expand is None
 
 
 def test_batched_lora_weights_no_segments():
     batched_weights = LayerAdapterWeights()
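The test fixes target the other half of the commit: `== None` comparisons, which match pycodestyle's E711 rule ("comparison to None should be `cond is None`"), again enforced by ruff and again inferred rather than named in the commit. `is None` is an identity check that always yields a plain bool, whereas `==` dispatches to `__eq__`, which types like NumPy arrays override elementwise (the test above already relies on that in `rd.indices == expected_indices`). A short sketch:

    # Sketch: `is None` vs `== None` on an object that overrides __eq__.
    import numpy as np

    arr = np.array([1, 2, 3])

    print(arr is None)  # False: identity check, always a plain bool
    print(arr == None)  # [False False False]: elementwise comparison;
                        # using it directly in an `assert` raises ValueError
                        # ("truth value of an array ... is ambiguous")

So `== None` can turn into a crash or a vacuous check the moment the attribute becomes array-like; `is None` cannot.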
