[Model] RowParallelLinear: pass bias to quant_method.apply (vllm-project#6327)

Signed-off-by: Thomas Parnell <[email protected]>
tdoublep authored and phil committed Aug 6, 2024
1 parent 15a84a9 commit 6ba6963
Showing 2 changed files with 14 additions and 9 deletions.
3 changes: 3 additions & 0 deletions tests/spec_decode/e2e/test_integration_dist_tp2.py
@@ -83,6 +83,9 @@ def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
         # cleaned up properly, and its server host thread leaks, causing the
         # second run of the test to fail with internal NCCL error.
         "use_async": True,
+
+        # precision
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize(
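For context, the parametrized dicts in this test are merged by the test fixtures and forwarded as engine keyword arguments, so the new entry pins the target model to float32 precision. A rough, hypothetical sketch of the effective construction (the real fixtures combine several kwargs dicts and build the engine indirectly; the model name here is only a placeholder):

from vllm import LLM

# Hypothetical illustration of where "dtype" ends up; the actual test builds
# its engines through fixtures rather than calling LLM directly like this.
llm = LLM(
    model="placeholder/model",   # placeholder; the real test supplies its own model
    tensor_parallel_size=2,      # this test exercises the target model with TP=2
    dtype="float32",             # the precision added in this commit
)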
20 changes: 11 additions & 9 deletions vllm/model_executor/layers/linear.py
@@ -715,6 +715,7 @@ def __init__(self,
         self.reduce_results = reduce_results
 
         # Divide the weight matrix along the last dimension.
+        self.tp_rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.input_size_per_partition = divide(input_size, self.tp_size)
         assert self.quant_method is not None
@@ -770,18 +771,19 @@ def forward(self, input_):

         # Matrix multiply.
         assert self.quant_method is not None
-        output_parallel = self.quant_method.apply(self, input_parallel)
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in TP>1 case)
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        output_parallel = self.quant_method.apply(self,
+                                                  input_parallel,
+                                                  bias=bias_)
         if self.reduce_results and self.tp_size > 1:
-            output_ = tensor_model_parallel_all_reduce(output_parallel)
+            output = tensor_model_parallel_all_reduce(output_parallel)
         else:
-            output_ = output_parallel
+            output = output_parallel
 
-        if not self.skip_bias_add:
-            output = output_ + self.bias if self.bias is not None else output_
-            output_bias = None
-        else:
-            output = output_
-            output_bias = self.bias
+        output_bias = self.bias if self.skip_bias_add else None
+
         return output, output_bias
 
     def extra_repr(self) -> str:
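For intuition, here is a small, self-contained sketch (plain PyTorch, no distributed setup; not vLLM code) of why the bias may only be fused into the GEMM on a single rank: in a row-parallel linear every rank produces a partial product that the all-reduce sums, so a bias added on every rank would be counted tp_size times.

import torch

# Standalone illustration: the all-reduce is simulated by summing the
# per-rank partial outputs of a row-parallel linear.
torch.manual_seed(0)
tp_size = 4
x = torch.randn(2, 8)          # (batch, input_size)
weight = torch.randn(3, 8)     # (output_size, input_size)
bias = torch.randn(3)

reference = x @ weight.t() + bias   # ordinary, non-parallel linear

# Row parallelism shards the input dimension across ranks.
x_shards = x.chunk(tp_size, dim=-1)
w_shards = weight.chunk(tp_size, dim=-1)

# Mirror of bias_ in the diff: only rank 0 fuses the bias into its GEMM.
partials = [xs @ ws.t() + (bias if rank == 0 else 0)
            for rank, (xs, ws) in enumerate(zip(x_shards, w_shards))]
output = torch.stack(partials).sum(dim=0)   # stands in for the all-reduce
assert torch.allclose(output, reference, atol=1e-5)

# If every rank fused the bias, it would be over-counted (tp_size - 1) extra times.
wrong = torch.stack([xs @ ws.t() + bias
                     for xs, ws in zip(x_shards, w_shards)]).sum(dim=0)
assert torch.allclose(wrong, reference + (tp_size - 1) * bias, atol=1e-5)

The same reasoning carries over to the skip_bias_add path: when the bias is returned separately as output_bias, no rank fuses it into the GEMM.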
