Skip to content

Commit

Permalink
Merge branch 'main' into justinchu/torch-26
Browse files Browse the repository at this point in the history
  • Loading branch information
titaiwangms authored Oct 8, 2024
2 parents c73ae34 + 1426e9f commit af81cbc
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 46 deletions.
53 changes: 25 additions & 28 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
RealType,
TFloat,
TFloatHighPrecision,
TFloatOrBFloat16,
TInt,
TReal,
TRealOrUInt8,
Expand Down Expand Up @@ -3564,14 +3563,14 @@ def aten_flipud(self: TensorType) -> TensorType:


@torch_op("aten::floor", traceable=True)
def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_floor(self: TFloat) -> TFloat:
"""floor(Tensor self) -> Tensor"""

return op.Floor(self)


@torch_op("math::floor", traceable=True)
def python_math_floor(self: TFloatOrBFloat16) -> TInt:
def python_math_floor(self: TFloat) -> TInt:
"""floor(Tensor self) -> Tensor"""
floor = op.Floor(self)
return op.Cast(floor, to=INT64.dtype)
Expand Down Expand Up @@ -4533,7 +4532,7 @@ def aten_isfinite(self: TFloatHighPrecision) -> BOOL:


@torch_op("aten::isinf")
def aten_isinf(self: TFloatOrBFloat16) -> BOOL:
def aten_isinf(self: TFloat) -> BOOL:
"""isinf(Tensor self) -> Tensor"""

# Added Cast inside the function so it can support all real dtypes naturally
Expand All @@ -4542,14 +4541,14 @@ def aten_isinf(self: TFloatOrBFloat16) -> BOOL:


@torch_op("aten::isnan")
def aten_isnan(self: TFloatOrBFloat16) -> BOOL:
def aten_isnan(self: TFloat) -> BOOL:
"""isnan(Tensor self) -> Tensor"""

return op.IsNaN(self)


@torch_op("aten::isneginf")
def aten_isneginf(self: TFloatOrBFloat16) -> BOOL:
def aten_isneginf(self: TFloat) -> BOOL:
"""isneginf(Tensor self) -> Tensor"""

# Added Cast inside the function so it can support all real dtypes naturally
Expand All @@ -4558,7 +4557,7 @@ def aten_isneginf(self: TFloatOrBFloat16) -> BOOL:


@torch_op("aten::isposinf")
def aten_isposinf(self: TFloatOrBFloat16) -> BOOL:
def aten_isposinf(self: TFloat) -> BOOL:
"""isposinf(Tensor self) -> Tensor"""

# Added Cast inside the function so it can support all real dtypes naturally
Expand Down Expand Up @@ -4778,42 +4777,42 @@ def aten_linspace(


@torch_op("aten::log", traceable=True)
def aten_log(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_log(self: TFloat) -> TFloat:
"""log(Tensor self) -> Tensor"""

return op.Log(self)


@torch_op("aten::log10", traceable=True)
def aten_log10(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_log10(self: TFloat) -> TFloat:
"""log10(Tensor self) -> Tensor"""

return op.Div(op.Log(self), op.CastLike(op.Log(10.0), self))


@torch_op("aten::log1p")
def aten_log1p(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_log1p(self: TFloat) -> TFloat:
"""log1p(Tensor self) -> Tensor"""

return op.Log(op.Add(self, 1.0))


@torch_op("aten::log2", traceable=True)
def aten_log2(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_log2(self: TFloat) -> TFloat:
"""log2(Tensor self) -> Tensor"""

return op.Div(op.Log(self), op.CastLike(op.Log(2.0), self))


@torch_op("aten::logaddexp", traceable=True)
def aten_logaddexp(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_logaddexp(self: TFloat, other: TFloat) -> TFloat:
"""logaddexp(Tensor self, Tensor other) -> Tensor"""

return op.Log(op.Add(op.Exp(self), op.Exp(other)))


@torch_op("aten::logaddexp2", traceable=True)
def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_logaddexp2(self: TFloat, other: TFloat) -> TFloat:
"""logaddexp2(Tensor self, Tensor other) -> Tensor"""
two = op.CastLike(2.0, self)
summation = op.Add(op.Pow(two, self), op.Pow(two, other))
Expand All @@ -4822,7 +4821,7 @@ def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOr


@torch_op("aten::logcumsumexp", traceable=True)
def aten_logcumsumexp(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
def aten_logcumsumexp(self: TFloat, dim: int) -> TFloat:
"""logcumsumexp(Tensor self, int dim) -> Tensor"""

if IsScalar(self):
Expand Down Expand Up @@ -4908,12 +4907,12 @@ def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL:


@torch_op("aten::logit", private=True)
def _aten_logit_onnx(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def _aten_logit_onnx(self: TFloat) -> TFloat:
return op.Log(op.Div(self, op.Sub(1.0, self)))


@torch_op("aten::logit", private=True)
def _aten_logit_clamp_onnx(self: TFloatOrBFloat16, eps: float) -> TFloatOrBFloat16:
def _aten_logit_clamp_onnx(self: TFloat, eps: float) -> TFloat:
eps = op.CastLike(eps, self)
one = op.CastLike(1.0, self)
temporary_self = op.Where(self <= one - eps, self, one - eps)
Expand All @@ -4923,7 +4922,7 @@ def _aten_logit_clamp_onnx(self: TFloatOrBFloat16, eps: float) -> TFloatOrBFloat


@torch_op("aten::logit", trace_only=True)
def aten_logit(self: TFloatOrBFloat16, eps: Optional[float] = None) -> TFloatOrBFloat16:
def aten_logit(self: TFloat, eps: Optional[float] = None) -> TFloat:
"""logit(Tensor self, float? eps=None) -> Tensor"""
if eps is None:
return _aten_logit_onnx(self)
Expand Down Expand Up @@ -6041,9 +6040,7 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType:


@torch_op("aten::native_dropout", trace_only=True)
def aten_native_dropout(
input: TFloatOrBFloat16, p: float, train: bool = True
) -> Tuple[TFloatOrBFloat16, BOOL]:
def aten_native_dropout(input: TFloat, p: float, train: bool = True) -> Tuple[TFloat, BOOL]:
"""native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)"""

result, mask = op.Dropout(input, p, train)
Expand Down Expand Up @@ -7055,7 +7052,7 @@ def aten_real(self: TensorType) -> TensorType:


@torch_op("aten::reciprocal")
def aten_reciprocal(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_reciprocal(self: TFloat) -> TFloat:
"""reciprocal(Tensor self) -> Tensor"""

return op.Reciprocal(self)
Expand All @@ -7074,7 +7071,7 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType:


@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"))
def aten_remainder(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_remainder(self: TFloat, other: TFloat) -> TFloat:
"""remainder.Tensor(Tensor self, Tensor other) -> Tensor"""

# TODO(justinchuby): Improve fp16 precision by following the logic in
Expand Down Expand Up @@ -7355,7 +7352,7 @@ def aten_rrelu(


@torch_op("aten::rsqrt", traceable=True)
def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_rsqrt(self: TFloat) -> TFloat:
"""rsqrt(Tensor self) -> Tensor"""

return op.Reciprocal(op.Sqrt(self))
Expand Down Expand Up @@ -7562,7 +7559,7 @@ def aten_sgn(self: TensorType) -> TensorType:


@torch_op("aten::sigmoid", traceable=True)
def aten_sigmoid(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_sigmoid(self: TFloat) -> TFloat:
"""sigmoid(Tensor self) -> Tensor"""

return op.Sigmoid(self)
Expand Down Expand Up @@ -7724,7 +7721,7 @@ def aten_smm(self: TensorType, mat2: TensorType) -> TensorType:


@torch_op(("aten::softmax.int", "aten::special_softmax"), trace_only=True)
def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrBFloat16:
def aten_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat:
"""softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""

self_is_scalar = IsScalar(self)
Expand All @@ -7741,7 +7738,7 @@ def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrB


@torch_op(("aten::softmax.int", "aten::special_softmax"), traceable=True)
def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
def aten_softmax_no_dtype(self: TFloat, dim: int) -> TFloat:
"""softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""

self_is_scalar = IsScalar(self)
Expand Down Expand Up @@ -7812,7 +7809,7 @@ def aten_split_with_sizes_copy(


@torch_op("aten::sqrt", traceable=True)
def aten_sqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_sqrt(self: TFloat) -> TFloat:
"""sqrt(Tensor self) -> Tensor"""

return op.Sqrt(self)
Expand Down Expand Up @@ -8402,7 +8399,7 @@ def aten_triu_indices(row: int, col: int, offset: int = 0) -> TensorType:


@torch_op("aten::trunc")
def aten_trunc(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_trunc(self: TFloat) -> TFloat:
"""trunc(Tensor self) -> Tensor"""

# Reference https://github.com/onnx/onnx/issues/4588#issuecomment-1463970126
Expand Down
11 changes: 5 additions & 6 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from onnxscript.function_libs.torch_lib.tensor_typing import (
IntType,
TFloat,
TFloatOrBFloat16,
TFloatOrUInt8,
TInt,
TReal,
Expand Down Expand Up @@ -364,13 +363,13 @@ def aten_conv_depthwise3d(

@torch_op("aten::cross_entropy_loss", traceable=True)
def aten_cross_entropy_loss(
self: TFloatOrBFloat16,
self: TFloat,
target: IntType,
weight: Optional[TFloatOrBFloat16] = None,
weight: Optional[TFloat] = None,
reduction: int = 1, # default is 'mean'
ignore_index: int = -100,
label_smoothing: float = 0.0, # this was ignored due to ONNX not support
) -> TFloatOrBFloat16:
) -> TFloat:
"""cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor"""

if reduction == 0: # "none"
Expand Down Expand Up @@ -812,7 +811,7 @@ def aten_l1_loss(self: TensorType, target: TensorType, reduction: int = 1) -> Te


@torch_op("aten::leaky_relu")
def aten_leaky_relu(self: TFloatOrBFloat16, negative_slope: float = 0.01) -> TFloatOrBFloat16:
def aten_leaky_relu(self: TFloat, negative_slope: float = 0.01) -> TFloat:
"""leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor"""

return op.LeakyRelu(self, alpha=negative_slope)
Expand Down Expand Up @@ -850,7 +849,7 @@ def aten_linear_bias(input: TFloat, weight: TFloat, bias: TFloat) -> TFloat:


@torch_op("aten::log_sigmoid")
def aten_log_sigmoid(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_log_sigmoid(self: TFloat) -> TFloat:
"""log_sigmoid(Tensor self) -> Tensor"""

return op.Log(op.Sigmoid(self))
Expand Down
16 changes: 7 additions & 9 deletions onnxscript/function_libs/torch_lib/ops/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from onnxscript.function_libs.torch_lib.ops import common as common_ops
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat, TFloatOrBFloat16
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType

Expand Down Expand Up @@ -92,21 +92,21 @@ def aten_special_entr(self: TensorType) -> TensorType:


@torch_op(("aten::erf", "aten::special_erf"))
def aten_special_erf(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_special_erf(self: TFloat) -> TFloat:
"""erf(Tensor self) -> Tensor"""

return op.Erf(self)


@torch_op(("aten::erfc", "aten::special_erfc"))
def aten_special_erfc(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_special_erfc(self: TFloat) -> TFloat:
"""erfc(Tensor self) -> Tensor"""

return op.Sub(1, op.Erf(self))


@torch_op("aten::special_erfcx")
def aten_special_erfcx(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_special_erfcx(self: TFloat) -> TFloat:
"""special_erfcx(Tensor self) -> Tensor"""

return op.Mul(op.Exp(op.Pow(self, 2)), op.Sub(1, op.Erf(self)))
Expand All @@ -131,7 +131,7 @@ def aten_special_expit(self: TensorType) -> TensorType:


@torch_op(("aten::expm1", "aten::special_expm1"))
def aten_special_expm1(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_special_expm1(self: TFloat) -> TFloat:
"""special_expm1(Tensor self) -> Tensor"""

return op.Sub(op.Exp(self), 1)
Expand Down Expand Up @@ -216,9 +216,7 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType:


@torch_op(("aten::log_softmax.int", "aten::special_log_softmax"), trace_only=True)
def aten_special_log_softmax(
self: TFloatOrBFloat16, dim: int, dtype: int = -1
) -> TFloatOrBFloat16:
def aten_special_log_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat:
"""special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor"""

self_is_scalar = IsScalar(self)
Expand Down Expand Up @@ -366,7 +364,7 @@ def aten_special_xlog1py(self: TensorType, other: TensorType) -> TensorType:


@torch_op(("aten::xlogy.Tensor", "aten::xlogy.Scalar_Self", "aten::xlogy.Scalar_Other"))
def aten_special_xlogy(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_special_xlogy(self: TFloat, other: TFloat) -> TFloat:
"""special_xlogy(Tensor self, Tensor other) -> Tensor"""

# https://pytorch.org/docs/stable/special.html#torch.special.xlogy
Expand Down
3 changes: 1 addition & 2 deletions onnxscript/function_libs/torch_lib/tensor_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
INT64,
UINT8,
]
_FloatType = Union[FLOAT16, FLOAT, DOUBLE]
_FloatType = Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16]
IntType = Union[INT8, INT16, INT32, INT64]
RealType = Union[
BFLOAT16,
Expand All @@ -61,7 +61,6 @@
TTensor2 = TypeVar("TTensor2", bound=_TensorType)
TTensorOrString = TypeVar("TTensorOrString", bound=Union[_TensorType, STRING])
TFloat = TypeVar("TFloat", bound=_FloatType)
TFloatOrBFloat16 = TypeVar("TFloatOrBFloat16", bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16])
TFloatOrUInt8 = TypeVar("TFloatOrUInt8", bound=Union[FLOAT, FLOAT16, DOUBLE, INT8, UINT8])
TInt = TypeVar("TInt", bound=IntType)
TReal = TypeVar("TReal", bound=RealType)
Expand Down
2 changes: 1 addition & 1 deletion requirements/ci/requirements-onnx-weekly.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
onnx-weekly==1.17.0.dev20240715
onnx-weekly==1.18.0.dev20240930

0 comments on commit af81cbc

Please sign in to comment.