 try:
     # torch >=2.3
     _int_dtypes |= {torch.uint16, torch.uint32, torch.uint64}
+    _HAS_LARGE_UINT = True
 except AttributeError:
-    pass
-
+    _HAS_LARGE_UINT = False
 
 _array_api_dtypes = {
     torch.bool,
     torch.complex128,
 }
 
-_promotion_table = {
-    # bool
-    (torch.bool, torch.bool): torch.bool,
+_promotion_table = {
     # ints
-    (torch.int8, torch.int8): torch.int8,
     (torch.int8, torch.int16): torch.int16,
     (torch.int8, torch.int32): torch.int32,
     (torch.int8, torch.int64): torch.int64,
-    (torch.int16, torch.int8): torch.int16,
-    (torch.int16, torch.int16): torch.int16,
     (torch.int16, torch.int32): torch.int32,
     (torch.int16, torch.int64): torch.int64,
-    (torch.int32, torch.int8): torch.int32,
-    (torch.int32, torch.int16): torch.int32,
-    (torch.int32, torch.int32): torch.int32,
     (torch.int32, torch.int64): torch.int64,
-    (torch.int64, torch.int8): torch.int64,
-    (torch.int64, torch.int16): torch.int64,
-    (torch.int64, torch.int32): torch.int64,
-    (torch.int64, torch.int64): torch.int64,
-    # uints
-    (torch.uint8, torch.uint8): torch.uint8,
     # ints and uints (mixed sign)
-    (torch.int8, torch.uint8): torch.int16,
-    (torch.int16, torch.uint8): torch.int16,
-    (torch.int32, torch.uint8): torch.int32,
-    (torch.int64, torch.uint8): torch.int64,
     (torch.uint8, torch.int8): torch.int16,
     (torch.uint8, torch.int16): torch.int16,
     (torch.uint8, torch.int32): torch.int32,
     (torch.uint8, torch.int64): torch.int64,
     # floats
-    (torch.float32, torch.float32): torch.float32,
     (torch.float32, torch.float64): torch.float64,
-    (torch.float64, torch.float32): torch.float64,
-    (torch.float64, torch.float64): torch.float64,
     # complexes
-    (torch.complex64, torch.complex64): torch.complex64,
     (torch.complex64, torch.complex128): torch.complex128,
-    (torch.complex128, torch.complex64): torch.complex128,
-    (torch.complex128, torch.complex128): torch.complex128,
     # Mixed float and complex
     (torch.float32, torch.complex64): torch.complex64,
     (torch.float32, torch.complex128): torch.complex128,
     (torch.float64, torch.complex64): torch.complex128,
     (torch.float64, torch.complex128): torch.complex128,
 }
 
+if _HAS_LARGE_UINT:  # torch >=2.3
+    _promotion_table.update(
+        {
+            # uints
+            (torch.uint8, torch.uint16): torch.uint16,
+            (torch.uint8, torch.uint32): torch.uint32,
+            (torch.uint8, torch.uint64): torch.uint64,
+            (torch.uint16, torch.uint32): torch.uint32,
+            (torch.uint16, torch.uint64): torch.uint64,
+            (torch.uint32, torch.uint64): torch.uint64,
+            # ints and uints (mixed sign)
+            (torch.uint16, torch.int8): torch.int32,
+            (torch.uint16, torch.int16): torch.int32,
+            (torch.uint16, torch.int32): torch.int32,
+            (torch.uint16, torch.int64): torch.int64,
+            (torch.uint32, torch.int8): torch.int64,
+            (torch.uint32, torch.int16): torch.int64,
+            (torch.uint32, torch.int32): torch.int64,
+            (torch.uint32, torch.int64): torch.int64,
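+            # No (torch.uint64, signed int) entries: the Array API spec
+            # leaves that promotion unspecified, since no dtype is large
+            # enough to hold both ranges.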
+        }
+    )
+
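+# The table above lists each unordered pair once. Mirror it so lookups work
+# in either argument order, then add the trivial (dtype, dtype) entries.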
+_promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()})
+_promotion_table.update({(a, a): a for a in _array_api_dtypes})
+
 
 def _two_arg(f):
     @_wraps(f)
@@ -275,6 +276,31 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
             out = torch.unsqueeze(out, a)
     return out
 
+
+def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
+    """
+    Implements `sum(..., axis=())` and `prod(..., axis=())`.
+
+    Works around https://github.com/pytorch/pytorch/issues/29137
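+
+    E.g. `sum(x, axis=())` reduces over no axes, so the elements are
+    returned unchanged apart from the dtype upcast a reduction would apply.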
285+ """
286+ if dtype is not None :
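+        # x.to(x.dtype) returns x itself, so clone explicitly to make sure
+        # the caller always gets a fresh tensor.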
+        return x.clone() if dtype == x.dtype else x.to(dtype)
+
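+    # Match torch.sum/torch.prod with axis=None, which upcast small signed
+    # ints to the default integer dtype (int64).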
+    if x.dtype in (torch.int8, torch.int16, torch.int32):
+        return x.to(torch.int64)
+
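+    # With torch >= 2.3, uint64 exists, so small unsigned ints can upcast
+    # within the unsigned family as the spec prescribes.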
+    if _HAS_LARGE_UINT and x.dtype in (torch.uint8, torch.uint16, torch.uint32):
+        return x.to(torch.uint64)
+
+    if x.dtype == torch.uint8:
+        # We can't upcast uint8 according to the spec because there is no
+        # torch.uint64, so at least upcast to int64 which is what prod does
+        # when axis=None.
+        return x.to(torch.int64)
+
+    return x.clone()
+
+
 def prod(x: Array,
          /,
          *,
@@ -283,20 +309,9 @@ def prod(x: Array,
          keepdims: bool = False,
          **kwargs) -> Array:
     x = torch.asarray(x)
-    ndim = x.ndim
 
-    # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
-    # below because it still needs to upcast.
     if axis == ():
-        if dtype is None:
-            # We can't upcast uint8 according to the spec because there is no
-            # torch.uint64, so at least upcast to int64 which is what sum does
-            # when axis=None.
-            if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
-                return x.to(torch.int64)
-            return x.clone()
-        return x.to(dtype)
-
+        return _sum_prod_no_axis(x, dtype)
     # torch.prod doesn't support multiple axes
     # (https://github.com/pytorch/pytorch/issues/56586).
     if isinstance(axis, tuple):
@@ -305,7 +320,7 @@ def prod(x: Array,
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.prod(x, dtype=dtype, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res
 
     return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
@@ -321,17 +336,8 @@ def sum(x: Array,
     x = torch.asarray(x)
     ndim = x.ndim
 
-    # https://github.com/pytorch/pytorch/issues/29137.
-    # Make sure it upcasts.
     if axis == ():
-        if dtype is None:
-            # We can't upcast uint8 according to the spec because there is no
-            # torch.uint64, so at least upcast to int64 which is what sum does
-            # when axis=None.
-            if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
-                return x.to(torch.int64)
-            return x.clone()
-        return x.to(dtype)
+        return _sum_prod_no_axis(x, dtype)
 
     if axis is None:
         # torch doesn't support keepdims with axis=None