
Commit f8f469e

cennn authored and youkaichao committed
[misc] remove python function call for custom activation op (vllm-project#11885)
Co-authored-by: youkaichao <[email protected]>
1 parent 84b3a16 commit f8f469e
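
In short: every activation CustomOp now resolves its platform-specific kernel once in __init__ and stores it as self.op, so the hot path calls the bound op directly instead of importing vllm._custom_ops and going through a thin Python wrapper on every forward call. A distilled sketch of the pattern, taken from the hunks below (FastGELU shown; the other classes follow the same shape, and CustomOp / current_platform are vLLM internals that are only referenced here, not defined):

class FastGELU(CustomOp):

    def __init__(self):
        super().__init__()
        # New: pick the kernel once, at construction time.
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_fast

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        # Old: `from vllm import _custom_ops as ops; ops.gelu_fast(out, x)`
        self.op(out, x)
        return out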

2 files changed: +46 -60 lines


vllm/_custom_ops.py (-27 lines)
@@ -34,33 +34,6 @@ def register_fake(fn):
     from torch.library import impl_abstract as register_fake


-# activation ops
-def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.gelu_and_mul(out, x)
-
-
-def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.gelu_tanh_and_mul(out, x)
-
-
-def fatrelu_and_mul(out: torch.Tensor,
-                    x: torch.Tensor,
-                    threshold: float = 0.0) -> None:
-    torch.ops._C.fatrelu_and_mul(out, x, threshold)
-
-
-def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.gelu_fast(out, x)
-
-
-def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.gelu_new(out, x)
-
-
-def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.gelu_quick(out, x)
-
-
 # page attention ops
 def paged_attention_v1(
     out: torch.Tensor,
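
With the Python wrappers gone, the kernels themselves are untouched; they remain registered on the _C extension, so any remaining caller invokes them through torch.ops._C directly, which is exactly what the deleted wrappers did. A minimal hedged sketch, assuming a CUDA build of vLLM (importing vllm is assumed to load the compiled _C extension that registers these ops) and illustrative shapes:

import torch
import vllm  # assumed to load the compiled _C extension and its op registrations

x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
out = torch.empty_like(x)
# The deleted gelu_fast() wrapper was a one-line forward to this call.
torch.ops._C.gelu_fast(out, x)  # out-variant op: writes the result into `out`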

vllm/model_executor/layers/activation.py (+46 -33 lines)
@@ -30,6 +30,8 @@ class FatreluAndMul(CustomOp):
     def __init__(self, threshold: float = 0.):
         super().__init__()
         self.threshold = threshold
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.fatrelu_and_mul

     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -39,12 +41,10 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         return x1 * x2

     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        ops.fatrelu_and_mul(out, x, self.threshold)
+        self.op(out, x, self.threshold)
         return out

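For reference, the fused fatrelu_and_mul kernel consumes a tensor whose last dimension is 2*d and writes a result of last dimension d, gating the second half by a thresholded ReLU of the first half. A plain-PyTorch sketch of that computation, based on my reading of the partially shown forward_native (the exact threshold semantics are an assumption, not taken verbatim from this diff):

import torch
import torch.nn.functional as F

def fatrelu_and_mul_ref(x: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
    # Split the last dimension in half, zero out entries of the first half
    # that do not exceed `threshold`, and use the result to gate the second half.
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    return F.threshold(x1, threshold, 0.0) * x2

y = fatrelu_and_mul_ref(torch.randn(4, 2048), threshold=0.5)
print(y.shape)  # torch.Size([4, 1024])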

@@ -103,34 +103,35 @@ def __init__(self, approximate: str = "none"):
         self.approximate = approximate
         if approximate not in ("none", "tanh"):
             raise ValueError(f"Unknown approximate mode: {approximate}")
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            if approximate == "none":
+                self.op = torch.ops._C.gelu_and_mul
+            elif approximate == "tanh":
+                self.op = torch.ops._C.gelu_tanh_and_mul
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            if approximate == "none":
+                self.op = ipex_ops.gelu_and_mul
+            else:
+                self.op = ipex_ops.gelu_tanh_and_mul

     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        if self.approximate == "none":
-            ops.gelu_and_mul(out, x)
-        elif self.approximate == "tanh":
-            ops.gelu_tanh_and_mul(out, x)
+        self.op(out, x)
         return out

     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        if self.approximate == "none":
-            ops.gelu_and_mul(out, x)
-        elif self.approximate == "tanh":
-            ops.gelu_tanh_and_mul(out, x)
+        self.op(out, x)
         return out

     def extra_repr(self) -> str:
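
The approximate-mode branch now runs once, in __init__, instead of on every forward call; which kernel gets bound depends on both the mode and the platform. A small hedged illustration, assuming a CUDA or CPU build of vLLM (on XPU the ipex_ops equivalents are bound instead):

from vllm.model_executor.layers.activation import GeluAndMul

gelu = GeluAndMul(approximate="none")       # binds torch.ops._C.gelu_and_mul
gelu_tanh = GeluAndMul(approximate="tanh")  # binds torch.ops._C.gelu_tanh_and_mul
print(gelu.op, gelu_tanh.op)
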
@@ -140,65 +141,77 @@ def extra_repr(self) -> str:
 @CustomOp.register("gelu_new")
 class NewGELU(CustomOp):

+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.gelu_new
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.gelu_new
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         c = math.sqrt(2.0 / math.pi)
         return 0.5 * x * (1.0 + torch.tanh(c *
                                            (x + 0.044715 * torch.pow(x, 3.0))))

     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_new(out, x)
+        self.op(out, x)
         return out

     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
-        return ops.gelu_new(x)
+        return self.op(x)


 @CustomOp.register("gelu_fast")
 class FastGELU(CustomOp):

+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.gelu_fast
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.gelu_fast
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                            (1.0 + 0.044715 * x * x)))

     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_fast(out, x)
+        self.op(out, x)
         return out

     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
-        return ops.gelu_fast(x)
+        return self.op(x)


 @CustomOp.register("quick_gelu")
 class QuickGELU(CustomOp):
     # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.gelu_quick
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.gelu_quick
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         return x * torch.sigmoid(1.702 * x)

     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_quick(out, x)
+        self.op(out, x)
         return out

     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_quick(out, x)
+        self.op(out, x)
         return out

     # TODO implement forward_xpu for QuickGELU
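
A quick hedged parity check for one of the elementwise ops: QuickGELU's native path computes x * sigmoid(1.702 * x), while forward_cuda now fills a preallocated tensor through the bound torch.ops._C.gelu_quick. This assumes a CUDA build of vLLM and an available CUDA device; the tolerances are illustrative, not taken from the PR:

import torch
from vllm.model_executor.layers.activation import QuickGELU

act = QuickGELU()
x = torch.randn(4, 4096, device="cuda", dtype=torch.float16)
torch.testing.assert_close(act.forward_cuda(x),    # bound torch.ops._C.gelu_quick
                           act.forward_native(x),  # x * torch.sigmoid(1.702 * x)
                           atol=1e-2, rtol=1e-2)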
