[Bugfix] Fix fully sharded LoRAs with Mixtral (vllm-project#11390)
Signed-off-by: Jason Greene <[email protected]>
n1hility authored Dec 22, 2024
Parent: 72d9c31 · Commit: f1d1bf6
Showing 2 changed files with 5 additions and 2 deletions.
tests/lora/test_mixtral.py (3 additions, 1 deletion)
@@ -62,8 +62,9 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
 
 
 @pytest.mark.parametrize("tp_size", [4])
+@pytest.mark.parametrize("fully_shard", [True, False])
 def test_mixtral_lora_all_target_modules(mixtral_lora_files_all_target_modules,
-                                         tp_size):
+                                         tp_size, fully_shard):
     """This LoRA model has all supported Mixtral target modules"""
 
     if torch.cuda.device_count() < tp_size:
@@ -82,6 +83,7 @@ def test_mixtral_lora_all_target_modules(mixtral_lora_files_all_target_modules,
         max_loras=4,
         distributed_executor_backend="ray",
         tensor_parallel_size=tp_size,
+        fully_sharded_loras=fully_shard,
         max_lora_rank=32,
     )
 
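For context, here is a rough usage sketch of the configuration that the new fully_shard=True case exercises: a Mixtral checkpoint served with LoRA enabled and fully_sharded_loras turned on. The model id, adapter name, and LoRA path below are placeholders, and the request-level API shown is ordinary vLLM LoRA usage rather than something introduced by this commit. The test itself runs with tensor_parallel_size=4 and is skipped when fewer GPUs are available.

```python
# Sketch only: placeholder model id, adapter name, and LoRA path.
# Mirrors the test configuration above (LoRA enabled, fully sharded
# LoRA kernels on, tensor parallelism across 4 GPUs).
import vllm
from vllm.lora.request import LoRARequest

llm = vllm.LLM(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",  # placeholder Mixtral checkpoint
    enable_lora=True,
    max_loras=4,
    max_lora_rank=32,
    fully_sharded_loras=True,   # the setting this commit fixes for Mixtral
    tensor_parallel_size=4,
)

outputs = llm.generate(
    ["Explain tensor parallelism in one sentence."],
    vllm.SamplingParams(max_tokens=64),
    lora_request=LoRARequest("mixtral-adapter", 1, "/path/to/lora"),  # placeholder
)
print(outputs[0].outputs[0].text)
```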
vllm/lora/layers.py (2 additions, 1 deletion)
@@ -425,8 +425,9 @@ def forward(self, input_):
                        if self.base_layer.skip_bias_add else None)
         return output, output_bias
 
+    # ReplicatedLinear should always be replaced, regardless of the fully
+    # sharded LoRAs setting, because it is, by definition, copied per GPU.
     @classmethod
-    @_not_fully_sharded_can_replace
     def can_replace_layer(
         cls,
         source_layer: nn.Module,
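The functional change is the removal of the _not_fully_sharded_can_replace guard from ReplicatedLinearWithLoRA.can_replace_layer. As the new comment states, ReplicatedLinear is copied whole on every GPU, so it has no fully sharded variant to defer to; leaving the guard in place meant the layer was never wrapped with LoRA when fully_sharded_loras was enabled. The sketch below illustrates the gating pattern in simplified form; it is an assumption-laden illustration of the mechanism, not the actual vLLM implementation.

```python
# Simplified illustration of the gating pattern (not the real vLLM code):
# the decorator makes can_replace_layer() answer False whenever fully sharded
# LoRAs are enabled, so a fully sharded variant of the layer is selected
# instead. ReplicatedLinear has no such variant, so its LoRA wrapper must
# stay undecorated and always remain eligible for replacement.
from dataclasses import dataclass


@dataclass
class LoRAConfig:
    fully_sharded_loras: bool = False


def _not_fully_sharded_can_replace(can_replace):
    """Make can_replace_layer() bow out when fully sharded LoRAs apply."""
    def dec(cls, source_layer, lora_config, packed_modules_list, model_config):
        if lora_config.fully_sharded_loras:
            return False  # a fully sharded layer class will claim this layer
        return can_replace(cls, source_layer, lora_config,
                           packed_modules_list, model_config)
    return dec


class ColumnParallelLinear:  # stand-in base layers for the demo
    pass


class ReplicatedLinear:
    pass


class ColumnParallelLinearWithLoRA:
    @classmethod
    @_not_fully_sharded_can_replace  # has a fully sharded counterpart
    def can_replace_layer(cls, source_layer, lora_config,
                          packed_modules_list, model_config):
        return isinstance(source_layer, ColumnParallelLinear)


class ReplicatedLinearWithLoRA:
    @classmethod  # no guard: replaced regardless of the fully sharded setting
    def can_replace_layer(cls, source_layer, lora_config,
                          packed_modules_list, model_config):
        return isinstance(source_layer, ReplicatedLinear)


cfg = LoRAConfig(fully_sharded_loras=True)
# The guarded layer defers to its fully sharded variant...
assert not ColumnParallelLinearWithLoRA.can_replace_layer(
    ColumnParallelLinear(), cfg, [], None)
# ...while the replicated layer is still replaced, which is the behavior
# this commit restores for Mixtral.
assert ReplicatedLinearWithLoRA.can_replace_layer(
    ReplicatedLinear(), cfg, [], None)
```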
