Skip to content

Commit 6047187

Browse files
[ Misc ] Remove separate bias add (vllm-project#6353)
1 parent b6c16cf commit 6047187

File tree

1 file changed

+3
-15
lines changed

1 file changed

+3
-15
lines changed

vllm/model_executor/layers/linear.py

+3-15
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,7 @@ def apply(self,
9999

100100

101101
class UnquantizedLinearMethod(LinearMethodBase):
102-
"""Linear method without quantization.
103-
104-
Args:
105-
separate_bias_add: If true, add bias separately after matrix
106-
multiplication.
107-
"""
108-
109-
def __init__(self, separate_bias_add: bool = False):
110-
self.separate_bias_add = separate_bias_add
102+
"""Linear method without quantization."""
111103

112104
def create_weights(self, layer: torch.nn.Module,
113105
input_size_per_partition: int,
@@ -126,12 +118,8 @@ def apply(self,
126118
layer: torch.nn.Module,
127119
x: torch.Tensor,
128120
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
129-
weight = layer.weight
130-
if self.separate_bias_add:
131-
if bias is not None:
132-
return F.linear(x, weight) + bias
133-
return F.linear(x, weight)
134-
return F.linear(x, weight, bias)
121+
122+
return F.linear(x, layer.weight, bias)
135123

136124

137125
class LinearBase(torch.nn.Module):

0 commit comments

Comments
 (0)