@@ -30,6 +30,8 @@ class FatreluAndMul(CustomOp):
     def __init__(self, threshold: float = 0.):
         super().__init__()
         self.threshold = threshold
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.fatrelu_and_mul

     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -39,12 +41,10 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         return x1 * x2

     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        ops.fatrelu_and_mul(out, x, self.threshold)
+        self.op(out, x, self.threshold)
         return out


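For reference, the cached torch.ops._C.fatrelu_and_mul kernel computes the same result as forward_native above, and on platforms where self.op is never set the native path remains the fallback. A minimal pure-PyTorch sketch of the math, for illustration only (the helper name and example shapes are assumptions, not part of the diff):

    import torch

    def fatrelu_and_mul_ref(x: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
        # Split the last dimension in half: the first half is gated by FATReLU
        # (zeroed at or below the threshold), then multiplied by the second half.
        d = x.shape[-1] // 2
        x1, x2 = x[..., :d], x[..., d:]
        x1 = torch.where(x1 > threshold, x1, torch.zeros_like(x1))
        return x1 * x2

    x = torch.randn(2, 8)
    print(fatrelu_and_mul_ref(x, threshold=0.1).shape)  # torch.Size([2, 4])
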
@@ -103,34 +103,35 @@ def __init__(self, approximate: str = "none"):
         self.approximate = approximate
         if approximate not in ("none", "tanh"):
             raise ValueError(f"Unknown approximate mode: {approximate}")
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            if approximate == "none":
+                self.op = torch.ops._C.gelu_and_mul
+            elif approximate == "tanh":
+                self.op = torch.ops._C.gelu_tanh_and_mul
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            if approximate == "none":
+                self.op = ipex_ops.gelu_and_mul
+            else:
+                self.op = ipex_ops.gelu_tanh_and_mul

     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        if self.approximate == "none":
-            ops.gelu_and_mul(out, x)
-        elif self.approximate == "tanh":
-            ops.gelu_tanh_and_mul(out, x)
+        self.op(out, x)
         return out

     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        if self.approximate == "none":
-            ops.gelu_and_mul(out, x)
-        elif self.approximate == "tanh":
-            ops.gelu_tanh_and_mul(out, x)
+        self.op(out, x)
         return out

     def extra_repr(self) -> str:
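With the op resolved once in __init__, forward_cuda and forward_xpu no longer branch on approximate per call. A hedged consistency check of the cached kernel against the native path (this assumes a CUDA-alike build, and that importing vllm._custom_ops registers torch.ops._C.*, as the removed in-function import relied on; shapes and tolerances are illustrative):

    import torch
    import torch.nn.functional as F
    from vllm import _custom_ops  # noqa: F401  (loads the _C extension)

    x = torch.randn(4, 2 * 128, dtype=torch.float16, device="cuda")
    d = x.shape[-1] // 2
    out = torch.empty(x.shape[:-1] + (d, ), dtype=x.dtype, device=x.device)
    torch.ops._C.gelu_tanh_and_mul(out, x)  # the op GeluAndMul caches when approximate == "tanh"
    ref = F.gelu(x[..., :d], approximate="tanh") * x[..., d:]  # forward_native equivalent
    torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
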
@@ -140,65 +141,77 @@ def extra_repr(self) -> str:
 @CustomOp.register("gelu_new")
 class NewGELU(CustomOp):

+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.gelu_new
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.gelu_new
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         c = math.sqrt(2.0 / math.pi)
         return 0.5 * x * (1.0 + torch.tanh(c *
                                            (x + 0.044715 * torch.pow(x, 3.0))))

     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_new(out, x)
+        self.op(out, x)
         return out

     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
-        return ops.gelu_new(x)
+        return self.op(x)


 @CustomOp.register("gelu_fast")
 class FastGELU(CustomOp):

+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.gelu_fast
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.gelu_fast
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                            (1.0 + 0.044715 * x * x)))

     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_fast(out, x)
+        self.op(out, x)
         return out

     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
-        return ops.gelu_fast(x)
+        return self.op(x)


 @CustomOp.register("quick_gelu")
 class QuickGELU(CustomOp):
     # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.gelu_quick
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.gelu_quick
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         return x * torch.sigmoid(1.702 * x)

     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_quick(out, x)
+        self.op(out, x)
         return out

     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_quick(out, x)
+        self.op(out, x)
         return out

     # TODO implement forward_xpu for QuickGELU
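The same construction-time pattern repeats in NewGELU, FastGELU, and QuickGELU: the platform check and op lookup happen once in __init__, and every forward_cuda/forward_xpu just invokes the bound self.op. A generic sketch of the idiom outside vLLM (hypothetical class, with an eager lambda standing in for a real kernel):

    import torch

    class CachedOpActivation(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # In vLLM this slot would hold torch.ops._C.gelu_quick (CUDA/CPU) or
            # ipex_ops.gelu_quick (XPU); here an eager fallback stands in.
            self.op = lambda out, x: out.copy_(x * torch.sigmoid(1.702 * x))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            out = torch.empty_like(x)
            self.op(out, x)  # out-parameter convention, like the _C kernels
            return out

    print(CachedOpActivation()(torch.randn(3, 5)).shape)  # torch.Size([3, 5])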