From 3be8fc482bc445b6eee4a83205bee5a73279bf1a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 8 Oct 2024 11:49:08 -0700 Subject: [PATCH 1/2] [torchlib] Include bfloat16 as part of the float types (#1894) Since onnx in opset 20 or so enabled bfloat16 for most relevant ops, we are just going to include allow them in torchlib (even though it is opset18 for now) to unblock bfloat16 model export. --- .../function_libs/torch_lib/ops/core.py | 53 +++++++++---------- onnxscript/function_libs/torch_lib/ops/nn.py | 11 ++-- .../function_libs/torch_lib/ops/special.py | 16 +++--- .../function_libs/torch_lib/tensor_typing.py | 3 +- 4 files changed, 38 insertions(+), 45 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index f41ff1c3e..1fc73a220 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -39,7 +39,6 @@ RealType, TFloat, TFloatHighPrecision, - TFloatOrBFloat16, TInt, TReal, TRealOrUInt8, @@ -3564,14 +3563,14 @@ def aten_flipud(self: TensorType) -> TensorType: @torch_op("aten::floor", traceable=True) -def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_floor(self: TFloat) -> TFloat: """floor(Tensor self) -> Tensor""" return op.Floor(self) @torch_op("math::floor", traceable=True) -def python_math_floor(self: TFloatOrBFloat16) -> TInt: +def python_math_floor(self: TFloat) -> TInt: """floor(Tensor self) -> Tensor""" floor = op.Floor(self) return op.Cast(floor, to=INT64.dtype) @@ -4533,7 +4532,7 @@ def aten_isfinite(self: TFloatHighPrecision) -> BOOL: @torch_op("aten::isinf") -def aten_isinf(self: TFloatOrBFloat16) -> BOOL: +def aten_isinf(self: TFloat) -> BOOL: """isinf(Tensor self) -> Tensor""" # Added Cast inside the function so it can support all real dtypes naturally @@ -4542,14 +4541,14 @@ def aten_isinf(self: TFloatOrBFloat16) -> BOOL: @torch_op("aten::isnan") -def aten_isnan(self: TFloatOrBFloat16) -> BOOL: +def aten_isnan(self: TFloat) -> BOOL: """isnan(Tensor self) -> Tensor""" return op.IsNaN(self) @torch_op("aten::isneginf") -def aten_isneginf(self: TFloatOrBFloat16) -> BOOL: +def aten_isneginf(self: TFloat) -> BOOL: """isneginf(Tensor self) -> Tensor""" # Added Cast inside the function so it can support all real dtypes naturally @@ -4558,7 +4557,7 @@ def aten_isneginf(self: TFloatOrBFloat16) -> BOOL: @torch_op("aten::isposinf") -def aten_isposinf(self: TFloatOrBFloat16) -> BOOL: +def aten_isposinf(self: TFloat) -> BOOL: """isposinf(Tensor self) -> Tensor""" # Added Cast inside the function so it can support all real dtypes naturally @@ -4778,42 +4777,42 @@ def aten_linspace( @torch_op("aten::log", traceable=True) -def aten_log(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_log(self: TFloat) -> TFloat: """log(Tensor self) -> Tensor""" return op.Log(self) @torch_op("aten::log10", traceable=True) -def aten_log10(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_log10(self: TFloat) -> TFloat: """log10(Tensor self) -> Tensor""" return op.Div(op.Log(self), op.CastLike(op.Log(10.0), self)) @torch_op("aten::log1p") -def aten_log1p(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_log1p(self: TFloat) -> TFloat: """log1p(Tensor self) -> Tensor""" return op.Log(op.Add(self, 1.0)) @torch_op("aten::log2", traceable=True) -def aten_log2(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_log2(self: TFloat) -> TFloat: """log2(Tensor self) -> Tensor""" return op.Div(op.Log(self), op.CastLike(op.Log(2.0), self)) @torch_op("aten::logaddexp", traceable=True) -def aten_logaddexp(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_logaddexp(self: TFloat, other: TFloat) -> TFloat: """logaddexp(Tensor self, Tensor other) -> Tensor""" return op.Log(op.Add(op.Exp(self), op.Exp(other))) @torch_op("aten::logaddexp2", traceable=True) -def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_logaddexp2(self: TFloat, other: TFloat) -> TFloat: """logaddexp2(Tensor self, Tensor other) -> Tensor""" two = op.CastLike(2.0, self) summation = op.Add(op.Pow(two, self), op.Pow(two, other)) @@ -4822,7 +4821,7 @@ def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOr @torch_op("aten::logcumsumexp", traceable=True) -def aten_logcumsumexp(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16: +def aten_logcumsumexp(self: TFloat, dim: int) -> TFloat: """logcumsumexp(Tensor self, int dim) -> Tensor""" if IsScalar(self): @@ -4908,12 +4907,12 @@ def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL: @torch_op("aten::logit", private=True) -def _aten_logit_onnx(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def _aten_logit_onnx(self: TFloat) -> TFloat: return op.Log(op.Div(self, op.Sub(1.0, self))) @torch_op("aten::logit", private=True) -def _aten_logit_clamp_onnx(self: TFloatOrBFloat16, eps: float) -> TFloatOrBFloat16: +def _aten_logit_clamp_onnx(self: TFloat, eps: float) -> TFloat: eps = op.CastLike(eps, self) one = op.CastLike(1.0, self) temporary_self = op.Where(self <= one - eps, self, one - eps) @@ -4923,7 +4922,7 @@ def _aten_logit_clamp_onnx(self: TFloatOrBFloat16, eps: float) -> TFloatOrBFloat @torch_op("aten::logit", trace_only=True) -def aten_logit(self: TFloatOrBFloat16, eps: Optional[float] = None) -> TFloatOrBFloat16: +def aten_logit(self: TFloat, eps: Optional[float] = None) -> TFloat: """logit(Tensor self, float? eps=None) -> Tensor""" if eps is None: return _aten_logit_onnx(self) @@ -6041,9 +6040,7 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType: @torch_op("aten::native_dropout", trace_only=True) -def aten_native_dropout( - input: TFloatOrBFloat16, p: float, train: bool = True -) -> Tuple[TFloatOrBFloat16, BOOL]: +def aten_native_dropout(input: TFloat, p: float, train: bool = True) -> Tuple[TFloat, BOOL]: """native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)""" result, mask = op.Dropout(input, p, train) @@ -7055,7 +7052,7 @@ def aten_real(self: TensorType) -> TensorType: @torch_op("aten::reciprocal") -def aten_reciprocal(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_reciprocal(self: TFloat) -> TFloat: """reciprocal(Tensor self) -> Tensor""" return op.Reciprocal(self) @@ -7074,7 +7071,7 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType: @torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar")) -def aten_remainder(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_remainder(self: TFloat, other: TFloat) -> TFloat: """remainder.Tensor(Tensor self, Tensor other) -> Tensor""" # TODO(justinchuby): Improve fp16 precision by following the logic in @@ -7355,7 +7352,7 @@ def aten_rrelu( @torch_op("aten::rsqrt", traceable=True) -def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_rsqrt(self: TFloat) -> TFloat: """rsqrt(Tensor self) -> Tensor""" return op.Reciprocal(op.Sqrt(self)) @@ -7562,7 +7559,7 @@ def aten_sgn(self: TensorType) -> TensorType: @torch_op("aten::sigmoid", traceable=True) -def aten_sigmoid(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_sigmoid(self: TFloat) -> TFloat: """sigmoid(Tensor self) -> Tensor""" return op.Sigmoid(self) @@ -7724,7 +7721,7 @@ def aten_smm(self: TensorType, mat2: TensorType) -> TensorType: @torch_op(("aten::softmax.int", "aten::special_softmax"), trace_only=True) -def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrBFloat16: +def aten_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat: """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor""" self_is_scalar = IsScalar(self) @@ -7741,7 +7738,7 @@ def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrB @torch_op(("aten::softmax.int", "aten::special_softmax"), traceable=True) -def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16: +def aten_softmax_no_dtype(self: TFloat, dim: int) -> TFloat: """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor""" self_is_scalar = IsScalar(self) @@ -7812,7 +7809,7 @@ def aten_split_with_sizes_copy( @torch_op("aten::sqrt", traceable=True) -def aten_sqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_sqrt(self: TFloat) -> TFloat: """sqrt(Tensor self) -> Tensor""" return op.Sqrt(self) @@ -8402,7 +8399,7 @@ def aten_triu_indices(row: int, col: int, offset: int = 0) -> TensorType: @torch_op("aten::trunc") -def aten_trunc(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_trunc(self: TFloat) -> TFloat: """trunc(Tensor self) -> Tensor""" # Reference https://github.com/onnx/onnx/issues/4588#issuecomment-1463970126 diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 4687e260a..e963050f5 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -25,7 +25,6 @@ from onnxscript.function_libs.torch_lib.tensor_typing import ( IntType, TFloat, - TFloatOrBFloat16, TFloatOrUInt8, TInt, TReal, @@ -364,13 +363,13 @@ def aten_conv_depthwise3d( @torch_op("aten::cross_entropy_loss", traceable=True) def aten_cross_entropy_loss( - self: TFloatOrBFloat16, + self: TFloat, target: IntType, - weight: Optional[TFloatOrBFloat16] = None, + weight: Optional[TFloat] = None, reduction: int = 1, # default is 'mean' ignore_index: int = -100, label_smoothing: float = 0.0, # this was ignored due to ONNX not support -) -> TFloatOrBFloat16: +) -> TFloat: """cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor""" if reduction == 0: # "none" @@ -812,7 +811,7 @@ def aten_l1_loss(self: TensorType, target: TensorType, reduction: int = 1) -> Te @torch_op("aten::leaky_relu") -def aten_leaky_relu(self: TFloatOrBFloat16, negative_slope: float = 0.01) -> TFloatOrBFloat16: +def aten_leaky_relu(self: TFloat, negative_slope: float = 0.01) -> TFloat: """leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor""" return op.LeakyRelu(self, alpha=negative_slope) @@ -850,7 +849,7 @@ def aten_linear_bias(input: TFloat, weight: TFloat, bias: TFloat) -> TFloat: @torch_op("aten::log_sigmoid") -def aten_log_sigmoid(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_log_sigmoid(self: TFloat) -> TFloat: """log_sigmoid(Tensor self) -> Tensor""" return op.Log(op.Sigmoid(self)) diff --git a/onnxscript/function_libs/torch_lib/ops/special.py b/onnxscript/function_libs/torch_lib/ops/special.py index 6dd9edcd3..c791937b1 100644 --- a/onnxscript/function_libs/torch_lib/ops/special.py +++ b/onnxscript/function_libs/torch_lib/ops/special.py @@ -17,7 +17,7 @@ from onnxscript.function_libs.torch_lib.ops import common as common_ops from onnxscript.function_libs.torch_lib.registration import torch_op -from onnxscript.function_libs.torch_lib.tensor_typing import TFloat, TFloatOrBFloat16 +from onnxscript.function_libs.torch_lib.tensor_typing import TFloat from onnxscript.onnx_opset import opset18 as op from onnxscript.onnx_types import TensorType @@ -92,21 +92,21 @@ def aten_special_entr(self: TensorType) -> TensorType: @torch_op(("aten::erf", "aten::special_erf")) -def aten_special_erf(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_special_erf(self: TFloat) -> TFloat: """erf(Tensor self) -> Tensor""" return op.Erf(self) @torch_op(("aten::erfc", "aten::special_erfc")) -def aten_special_erfc(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_special_erfc(self: TFloat) -> TFloat: """erfc(Tensor self) -> Tensor""" return op.Sub(1, op.Erf(self)) @torch_op("aten::special_erfcx") -def aten_special_erfcx(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_special_erfcx(self: TFloat) -> TFloat: """special_erfcx(Tensor self) -> Tensor""" return op.Mul(op.Exp(op.Pow(self, 2)), op.Sub(1, op.Erf(self))) @@ -131,7 +131,7 @@ def aten_special_expit(self: TensorType) -> TensorType: @torch_op(("aten::expm1", "aten::special_expm1")) -def aten_special_expm1(self: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_special_expm1(self: TFloat) -> TFloat: """special_expm1(Tensor self) -> Tensor""" return op.Sub(op.Exp(self), 1) @@ -216,9 +216,7 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType: @torch_op(("aten::log_softmax.int", "aten::special_log_softmax"), trace_only=True) -def aten_special_log_softmax( - self: TFloatOrBFloat16, dim: int, dtype: int = -1 -) -> TFloatOrBFloat16: +def aten_special_log_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat: """special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor""" self_is_scalar = IsScalar(self) @@ -366,7 +364,7 @@ def aten_special_xlog1py(self: TensorType, other: TensorType) -> TensorType: @torch_op(("aten::xlogy.Tensor", "aten::xlogy.Scalar_Self", "aten::xlogy.Scalar_Other")) -def aten_special_xlogy(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16: +def aten_special_xlogy(self: TFloat, other: TFloat) -> TFloat: """special_xlogy(Tensor self, Tensor other) -> Tensor""" # https://pytorch.org/docs/stable/special.html#torch.special.xlogy diff --git a/onnxscript/function_libs/torch_lib/tensor_typing.py b/onnxscript/function_libs/torch_lib/tensor_typing.py index 7b5287f41..1f27c0cff 100644 --- a/onnxscript/function_libs/torch_lib/tensor_typing.py +++ b/onnxscript/function_libs/torch_lib/tensor_typing.py @@ -42,7 +42,7 @@ INT64, UINT8, ] -_FloatType = Union[FLOAT16, FLOAT, DOUBLE] +_FloatType = Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16] IntType = Union[INT8, INT16, INT32, INT64] RealType = Union[ BFLOAT16, @@ -61,7 +61,6 @@ TTensor2 = TypeVar("TTensor2", bound=_TensorType) TTensorOrString = TypeVar("TTensorOrString", bound=Union[_TensorType, STRING]) TFloat = TypeVar("TFloat", bound=_FloatType) -TFloatOrBFloat16 = TypeVar("TFloatOrBFloat16", bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16]) TFloatOrUInt8 = TypeVar("TFloatOrUInt8", bound=Union[FLOAT, FLOAT16, DOUBLE, INT8, UINT8]) TInt = TypeVar("TInt", bound=IntType) TReal = TypeVar("TReal", bound=RealType) From 1426e9f11b7bbbd9cf165d96c4c6ed9205f740d6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 8 Oct 2024 11:56:19 -0700 Subject: [PATCH 2/2] Bump onnx-weekly in CI (#1895) To 1.18.0.dev20240930 because the previous weekly was cleaned up in pypi --- requirements/ci/requirements-onnx-weekly.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/ci/requirements-onnx-weekly.txt b/requirements/ci/requirements-onnx-weekly.txt index 2ebee9809..ea926d355 100644 --- a/requirements/ci/requirements-onnx-weekly.txt +++ b/requirements/ci/requirements-onnx-weekly.txt @@ -1 +1 @@ -onnx-weekly==1.17.0.dev20240715 +onnx-weekly==1.18.0.dev20240930