@@ -189,47 +189,37 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
189189 if index == 0 : # condition
190190 tensor_constraints = [
191191 cp .Dtype .In (lambda deps : [torch .bool ]),
192- cp .Value .Ge (lambda deps , dtype , struct : - ( 2 ** 4 ) ),
193- cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
192+ cp .Value .Ge (lambda deps , dtype , struct : 0 ),
193+ cp .Value .Le (lambda deps , dtype , struct : 1 ),
194194 cp .Rank .Ge (lambda deps : 1 ),
195195 cp .Size .Ge (lambda deps , r , d : 1 ),
196196 max_size_constraint ,
197197 ]
198198 elif index == 1 : # input tensor(a)
199199 tensor_constraints = [
200- cp .Dtype .In (
201- lambda deps : [
202- torch .int8 ,
203- torch .int16 ,
204- torch .uint8 ,
205- torch .uint16 ,
206- torch .int32 ,
207- torch .float32 ,
208- ]
209- ),
200+ cp .Dtype .In (lambda deps : [torch .float32 ]),
210201 cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
211202 cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
212203 cp .Rank .Ge (lambda deps : 1 ),
213204 cp .Size .Ge (lambda deps , r , d : 1 ),
205+ cp .Size .In (
206+ lambda deps , r , d : fn .broadcast_with (deps [0 ].shape , r , d )
207+ ),
214208 max_size_constraint ,
215209 ]
216210 else : # input tensor(b)
217211 tensor_constraints = [
218- cp .Dtype .In (
219- lambda deps : [
220- torch .int8 ,
221- torch .int16 ,
222- torch .uint8 ,
223- torch .uint16 ,
224- torch .int32 ,
225- torch .float32 ,
226- ]
227- ),
212+ cp .Dtype .In (lambda deps : [torch .float32 ]),
228213 cp .Dtype .Eq (lambda deps : deps [1 ].dtype ),
229214 cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
230215 cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
231216 cp .Rank .Ge (lambda deps : 1 ),
232217 cp .Size .Ge (lambda deps , r , d : 1 ),
218+ cp .Size .In (
219+ lambda deps , r , d : fn .broadcast_with (
220+ fn .broadcasted_shape (deps [0 ].shape , deps [1 ].shape ), r , d
221+ )
222+ ),
233223 max_size_constraint ,
234224 ]
235225 case "embedding.default" :
0 commit comments