Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Sep 24, 2024
1 parent 731e4ca commit 11aaacf
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 18 deletions.
10 changes: 5 additions & 5 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2550,11 +2550,11 @@ def aten_upsample_linear1d_backward(

@torch_op("aten::upsample_nearest1d", trace_only=True)
def aten_upsample_nearest1d(
self: TReal, size: INT64, scale_factor: Optional[float] = None
self: TReal, output_size: INT64, scale_factor: Optional[float] = None
) -> TReal:
"""upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor"""
if size is not None:
return _aten_upsample_output_size(self, size, "nearest", "asymmetric")
return _aten_upsample_output_size(self, output_size, "nearest", "asymmetric")
else:
return _aten_upsample_scales(
self, op.Constant(value_floats=[scale_factor]), "nearest", "asymmetric"
Expand Down Expand Up @@ -2612,7 +2612,7 @@ def aten_upsample_nearest2d(
"asymmetric",
)
else:
return _aten_upsample_output_size(self, size, "nearest", "asymmetric")
return _aten_upsample_output_size(self, output_size, "nearest", "asymmetric")


def aten_upsample_nearest2d_backward(
Expand All @@ -2630,7 +2630,7 @@ def aten_upsample_nearest2d_backward(
@torch_op("aten::upsample_nearest3d", trace_only=True)
def aten_upsample_nearest3d(
self: TReal,
size: INT64,
output_size: INT64,
scales_d: Optional[float] = None,
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
Expand All @@ -2645,7 +2645,7 @@ def aten_upsample_nearest3d(
"asymmetric",
)
else:
return _aten_upsample_output_size(self, size, "nearest", "asymmetric")
return _aten_upsample_output_size(self, output_size, "nearest", "asymmetric")


def aten_upsample_nearest3d_backward(
Expand Down
20 changes: 8 additions & 12 deletions tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1605,6 +1605,7 @@ def sample_inputs_upsample_nearest1d(op_info, device, dtype, requires_grad, **kw

N, C = 2, 3
D = 4
SS = 3
L = 5

rank = 1
Expand Down Expand Up @@ -1633,12 +1634,12 @@ def shape(size, rank, with_batch_channel=True):
)
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
None, # output_size
shape(S, rank, False), # output_size
[1.7], # scaler
)
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
None, # if this is None, the scalar must be list
shape(S, rank, False), # if this is None, the scalar must be list
[0.6],
)

Expand Down Expand Up @@ -1687,6 +1688,7 @@ def shape(size, rank, with_batch_channel=True):
)



def sample_inputs_upsample_nearest2d(op_info, device, dtype, requires_grad, **kwargs):
del op_info
del kwargs
Expand Down Expand Up @@ -1722,14 +1724,12 @@ def shape(size, rank, with_batch_channel=True):
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
shape(L, rank, False),
1.7,
2.0, # scaler
1.7, 2.0, # scaler
)
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
shape(L, rank, False),
0.6,
0.4,
0.6, 0.4,
)


Expand Down Expand Up @@ -1812,16 +1812,12 @@ def shape(size, rank, with_batch_channel=True):
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
shape(L, rank, False),
1.7,
1.5,
2.0, # scaler
1.7, 1.5, 2.0, # scaler
)
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
shape(L, rank, False),
0.6,
0.3,
0.5,
0.6, 0.3, 0.5,
)


Expand Down
1 change: 0 additions & 1 deletion tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2391,7 +2391,6 @@ def _where_input_wrangler(
"signbit",
"sin",
"sinh",
"slice",
"sqrt",
"squeeze",
"sub",
Expand Down

0 comments on commit 11aaacf

Please sign in to comment.