Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torchlib] Fix _fft_c2r #1844

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
82 changes: 40 additions & 42 deletions onnxscript/function_libs/torch_lib/ops/fft.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Copyright (c) Microsoft Corporation.

Check warning

Code scanning / lintrunner

RUFF-FORMAT/format Warning

Run lintrunner -a to apply this patch.

Check warning

Code scanning / lintrunner

RUFF/format Warning

Run lintrunner -a to apply this patch.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
"""torch.ops.aten operators under the `fft` module.

Expand All @@ -21,22 +19,21 @@
from onnxscript.onnx_types import TensorType


@torch_op(
("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"),
private=True,
complex=True,
traceable=True,
)
def _fftn_onnx_normalization(
self,
self: TFloat,
transformed: TFloat,
normalization: int,
forward: bool,

Check warning

Code scanning / lintrunner

PYLINT/W0613 Warning

Unused argument 'forward' (unused-argument)
See unused-argument. To disable, use # pylint: disable=unused-argument
dims: Sequence[int],
last_dim_size: Optional[INT64] = None,
) -> TFloat:
# Obtain the total_sample_count (n) for normalization
self_shape = op.Shape(self)
total_sample_count = op.ReduceProd(op.Gather(self_shape, dims), keepdims=0)
if last_dim_size is None:
total_sample_count = op.ReduceProd(op.Gather(self_shape, dims), keepdims=0)

Check warning on line 33 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L33

Added line #L33 was not covered by tests
else:
total_sample_count = op.ReduceProd(op.Gather(self_shape, dims[:-1]), keepdims=0)
total_sample_count = op.Mul(total_sample_count, last_dim_size)

Check warning on line 36 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L35-L36

Added lines #L35 - L36 were not covered by tests
total_sample_count = op.CastLike(total_sample_count, transformed)

# Normalize the result
Expand Down Expand Up @@ -64,14 +61,13 @@
return result


@torch_op(
("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"),
trace_only=True,
private=True,
complex=True,
)
def _fftn_onnx(
self: TFloat, dims: Sequence[int], normalization: int, inverse: bool, onesided: bool
self: TFloat,
dims: Sequence[int],
normalization: int,
inverse: bool,
onesided: bool,
last_dim_size: Optional[INT64] = None,
) -> TFloat:
"""Standard complex to complex or real to complex FFT (forward or backward).

Expand All @@ -88,31 +84,37 @@
Returns:
The transformed tensor.
"""

# NOTE: trace_only because we need to process each dimension in a loop
# NOTE: SymInt dim is not supported because DFT-17 needs a static axis
# TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support

# The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new
# dimension at the beginning to represent the batch dimension.
transformed = op.Unsqueeze(self, axes=[0])

# Add 1 to account for the batch dimension when counting axes from the left
new_dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims]
# If taking FFT along the 0-th dimension: Since
# the 0-th dimension in ONNX DFT-17 is the batch dimension (cannot take DFT over),
# we need to add a new dimension at the beginning to represent the batch dimension.
unsqueeze_first_dim = 0 in dims

Check warning on line 92 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L92

Added line #L92 was not covered by tests
if unsqueeze_first_dim:
transformed = op.Unsqueeze(self, axes=[0])

Check warning on line 94 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L94

Added line #L94 was not covered by tests
# Add 1 to account for the batch dimension when counting axes from the left
dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims]
else:
transformed = self

Check warning on line 98 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L98

Added line #L98 was not covered by tests

for dim in new_dims[:-1]:
for dim in dims[:-1]:
transformed = op.DFT(transformed, axis=dim, inverse=inverse, onesided=False)

# Torch computes one-sided FFT on the last dimension only.
if onesided:
transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=True)
transformed = op.DFT(transformed, axis=dims[-1], inverse=inverse, onesided=True)

Check warning on line 105 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L105

Added line #L105 was not covered by tests
elif last_dim_size is not None:
transformed = op.DFT(

Check warning on line 107 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L107

Added line #L107 was not covered by tests
transformed, last_dim_size, axis=dims[-1], inverse=inverse, onesided=False
)
else:
transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=False)
transformed = op.DFT(transformed, axis=dims[-1], inverse=inverse, onesided=False)

Check warning on line 111 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L111

Added line #L111 was not covered by tests

# Remove the batch dimension
transformed = op.Squeeze(transformed, axes=[0])
if unsqueeze_first_dim:
# Remove the batch dimension
transformed = op.Squeeze(transformed, axes=[0])

Check warning on line 115 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L115

Added line #L115 was not covered by tests

return _fftn_onnx_normalization(self, transformed, normalization, not inverse, dims)
return _fftn_onnx_normalization(self, transformed, normalization, not inverse, dims, last_dim_size=last_dim_size)

Check warning on line 117 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L117

Added line #L117 was not covered by tests


@torch_op("aten::_fft_c2c", trace_only=True, complex=True)
Expand All @@ -124,9 +126,7 @@
Standard complex to complex FFT (forward or backward).
"""

# NOTE: trace_only because we need to negate forward
# NOTE: SymInt dim is not support because DFT-17 needs a static axis
# TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support
# NOTE: SymInt dim is not supported because DFT-17 needs a static axis

# ONNX DFT input assumes the last dimension is the complex dimension.
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
Expand All @@ -139,23 +139,21 @@
self: TFloat,
dim: Sequence[int],
normalization: int,
last_dim_size: INT64, # pylint: disable=unused-argument
last_dim_size: INT64,
) -> TFloat:
"""_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor

Complex to real inverse FFT.
"""

# TODO(justinchuby): Figure out what last_dim_size does

self_rank = len(self.shape)
# ONNX DFT input assumes the last dimension is the complex dimension.
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
dim = [(d - 1) + self_rank if d < 0 else d for d in dim]
transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False)
transformed = _fftn_onnx(

Check warning on line 152 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L152

Added line #L152 was not covered by tests

Check failure

Code scanning / lintrunner

PYLINT/E1123 Error

Unexpected keyword argument 'inverse' in function call (unexpected-keyword-arg)
See unexpected-keyword-arg. To disable, use # pylint: disable=unexpected-keyword-arg

Check failure

Code scanning / lintrunner

PYLINT/E1120 Error

No value for argument 'forward' in function call (no-value-for-parameter)
See no-value-for-parameter. To disable, use # pylint: disable=no-value-for-parameter
self, dim, normalization, inverse=True, onesided=False, last_dim_size=last_dim_size
)
Comment on lines +172 to +174

Check failure

Code scanning / CodeQL

Wrong name for an argument in a call Error

Keyword argument 'inverse' is not a supported parameter name of
function _fftn_onnx
.
# Take only the real part
real_part = op.Slice(transformed, axes=[-1], starts=[0], ends=[1])

return op.Squeeze(real_part, axes=[-1])


Expand Down
5 changes: 5 additions & 0 deletions tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,9 +689,14 @@ def sample_inputs__fft_c2r(self, device, dtype, requires_grad=False, **_):
(0, 1),
(0, 1, 2),
]:
# Slice
yield opinfo_core.SampleInput(
nd_tensor(), dim=dim, normalization=normalization, last_dim_size=6
)
# Pad
yield opinfo_core.SampleInput(
nd_tensor(), dim=dim, normalization=normalization, last_dim_size=64
)


def _index_variable_bool(shape, max_indices, device):
Expand Down
3 changes: 0 additions & 3 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,9 +472,6 @@ def _where_input_wrangler(
fft_ops.aten__fft_c2r,
tolerance={torch.complex64: (3e-3, 1.8e-4)},
complex=True,
).xfail(
dtypes=(torch.complex64,),
reason="fixme: the result is wrong: https://github.com/microsoft/onnxscript/pull/926",
),
TorchLibOpInfo(
"ops.aten._fft_r2c", # Custom from extra_opinfo
Expand Down
Loading