Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Op (_upsample_bilinear2d_aa, _upsample_bicubic2d_aa) | feat(torchlib) #1259

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2204,7 +2204,9 @@ def _get_upsample_align_corners_mode(align_corners: bool) -> str:
@torch_op(
(
"aten::upsample_bicubic2d",
"aten::upsample_bicubic2d_aa",
"aten::upsample_bilinear2d",
"aten::upsample_bilinear2d_aa",
"aten::upsample_nearest1d",
"aten::upsample_nearest2d",
"aten::upsample_nearest3d",
Expand All @@ -2216,6 +2218,7 @@ def _aten_upsample_output_size(
output_size: INT64,
mode: str,
coordinate_transformation_mode: str,
antialias: int = 0,
) -> TReal:
self_shape = op.Shape(self)
starts = op.Constant(value_ints=[0])
Expand All @@ -2230,6 +2233,7 @@ def _aten_upsample_output_size(
mode=mode,
coordinate_transformation_mode=coordinate_transformation_mode,
nearest_mode="floor",
antialias=antialias,
)


Expand Down Expand Up @@ -2273,6 +2277,28 @@ def aten_upsample_bicubic2d(
)


@torch_op("aten::_upsample_bicubic2d_aa", trace_only=True)
def aten__upsample_bicubic2d_aa(
self: TReal,
output_size: INT64,
align_corners: bool,
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
) -> TReal:
"""_upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor"""

# NOTE: Based on experimentation, scales_h and scales_w are always ignored in PyTorch,
# unless when align_corners is True, in which case we do not know what is going on.
coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners)
return _aten_upsample_output_size(
self,
output_size,
mode="cubic",
coordinate_transformation_mode=coordinate_transformation_mode,
antialias=1,
)


@torch_op("aten::upsample_bicubic2d.vec", trace_only=True)
def aten_upsample_bicubic2d_vec(
self: TReal,
Expand Down Expand Up @@ -2335,6 +2361,28 @@ def aten_upsample_bilinear2d(
)


@torch_op("aten::_upsample_bilinear2d_aa", trace_only=True)
def aten__upsample_bilinear2d_aa(
self: TReal,
output_size: INT64,
align_corners: bool,
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
) -> TReal:
"""upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor"""

# NOTE: Based on experimentation, scales_h and scales_w are always ignored in PyTorch,
# unless when align_corners is True, in which case we do not know what is going on.
coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners)
return _aten_upsample_output_size(
self,
output_size,
coordinate_transformation_mode=coordinate_transformation_mode,
mode="linear",
antialias=1,
)


@torch_op("aten::upsample_bilinear2d.vec", trace_only=True)
def aten_upsample_bilinear2d_vec(
self: TReal,
Expand Down
14 changes: 14 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2232,13 +2232,27 @@ def __init__(self):
sample_inputs_func=sample_inputs_upsample_2d,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten._upsample_bicubic2d_aa",
aten_name="_upsample_bicubic2d_aa",
dtypes=common_dtype.floating_types_and(torch.bfloat16),
sample_inputs_func=sample_inputs_upsample_2d,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten.upsample_bicubic2d.vec",
aten_name="upsample_bicubic2d.vec",
dtypes=common_dtype.floating_types_and(torch.bfloat16),
sample_inputs_func=sample_inputs_upsample_2d_vec,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten._upsample_bilinear2d_aa",
aten_name="_upsample_bilinear2d_aa",
dtypes=common_dtype.floating_types_and(torch.bfloat16),
sample_inputs_func=sample_inputs_upsample_2d,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten.upsample_bilinear2d.default",
aten_name="upsample_bilinear2d",
Expand Down
14 changes: 14 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2100,6 +2100,13 @@ def _where_input_wrangler(
and sample.kwargs.get("scales_h") is not None,
reason="fixme: align_corners=False output mismatch when scales are provided",
),
TorchLibOpInfo(
"ops.aten._upsample_bilinear2d_aa",
nn_ops.aten__upsample_bilinear2d_aa,
trace_only=True,
# ONNX use different antialias method than PyTorch, so the result is different
compare_shape_only_for_output=(0,),
),
TorchLibOpInfo(
"ops.aten.upsample_bilinear2d.vec",
nn_ops.aten_upsample_bilinear2d_vec,
Expand All @@ -2119,6 +2126,13 @@ def _where_input_wrangler(
nn_ops.aten_upsample_bicubic2d_vec,
trace_only=True,
),
TorchLibOpInfo(
"ops.aten._upsample_bicubic2d_aa",
nn_ops.aten__upsample_bicubic2d_aa,
trace_only=True,
# ONNX use different antialias method than PyTorch, so the result is different
compare_shape_only_for_output=(0,),
),
TorchLibOpInfo(
"ops.aten.upsample_linear1d",
nn_ops.aten_upsample_linear1d,
Expand Down
Loading