Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torchlib] Fix _fft_c2r #1844

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
166 changes: 92 additions & 74 deletions onnxscript/function_libs/torch_lib/ops/fft.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Copyright (c) Microsoft Corporation.

Check warning

Code scanning / lintrunner

RUFF-FORMAT/format Warning

Run lintrunner -a to apply this patch.

Check warning

Code scanning / lintrunner

RUFF/format Warning

Run lintrunner -a to apply this patch.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
"""torch.ops.aten operators under the `fft` module.

Expand All @@ -12,7 +10,7 @@

from __future__ import annotations

from typing import Optional, Sequence
from typing import Literal, Optional, Sequence

Check warning

Code scanning / lintrunner

PYLINT/W0611 Warning

Unused Literal imported from typing (unused-import)
See unused-import. To disable, use # pylint: disable=unused-import

Check warning

Code scanning / lintrunner

RUFF/F401 Warning

typing.Literal imported but unused.
See https://docs.astral.sh/ruff/rules/unused-import

from onnxscript import INT64
from onnxscript.function_libs.torch_lib.registration import torch_op
Expand All @@ -21,57 +19,53 @@
from onnxscript.onnx_types import TensorType


@torch_op(
("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"),
private=True,
complex=True,
traceable=True,
)
def _fftn_onnx_normalization(
self,
transformed: TFloat,
normalization: int,
forward: bool,
def _compute_signal_size(signal: TFloat, dims: Sequence[int], last_dim_size: Optional[INT64] = None) -> INT64:
if last_dim_size is not None:
all_other_dims = dims[:-1]

Check warning on line 24 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L24

Added line #L24 was not covered by tests
if all_other_dims:
signal_size = op.ReduceProd(signal, axes=all_other_dims, keepdims=False)
signal_size = op.Mul(signal_size, last_dim_size)

Check warning on line 27 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L26-L27

Added lines #L26 - L27 were not covered by tests
else:
signal_size = last_dim_size

Check warning on line 29 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L29

Added line #L29 was not covered by tests
else:
signal_size = op.ReduceProd(signal, axes=dims, keepdims=False)
return signal_size

Check warning on line 32 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L31-L32

Added lines #L31 - L32 were not covered by tests


def _fftn_ortho_normalization(
self: TFloat,
dims: Sequence[int],
forward: bool,

Check warning

Code scanning / lintrunner

PYLINT/W0613 Warning

Unused argument 'forward' (unused-argument)
See unused-argument. To disable, use # pylint: disable=unused-argument
onesided: bool,
last_dim_size: Optional[INT64] = None,
) -> TFloat:
# Obtain the total_sample_count (n) for normalization
self_shape = op.Shape(self)
total_sample_count = op.ReduceProd(op.Gather(self_shape, dims), keepdims=0)
total_sample_count = op.CastLike(total_sample_count, transformed)

# Normalize the result
# Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn
# Reference https://github.com/pytorch/pytorch/blob/d090c18fcaaba6e1b5cb474a89058cf6081c8275/torch/_refs/fft.py#L42
if normalization == 1:
# "forward" - normalize by 1/n
if forward:
result = op.Div(transformed, op.Sqrt(total_sample_count))
else:
result = op.Mul(transformed, op.Sqrt(total_sample_count))
elif normalization == 2:
# "ortho" - normalize by 1/sqrt(n)
if forward:
result = op.Div(transformed, total_sample_count)
else:
result = transformed
transformed = self

Check warning on line 42 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L42

Added line #L42 was not covered by tests

signal_size = _compute_signal_size(self, dims, last_dim_size)

Check warning on line 44 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L44

Added line #L44 was not covered by tests

Check warning

Code scanning / lintrunner

PYLINT/W0612 Warning

Unused variable 'signal_size' (unused-variable)
See unused-variable. To disable, use # pylint: disable=unused-variable

Check warning

Code scanning / lintrunner

RUFF/F841 Warning

Local variable signal\_size is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable

for dim in dims[:-1]:
transformed = op.DFT(transformed, axis=dim, onesided=False)

Check warning on line 47 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L47

Added line #L47 was not covered by tests

# Torch computes one-sided FFT on the last dimension only.
if onesided:
transformed = op.DFT(transformed, axis=dims[-1], onesided=True)

Check warning on line 51 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L51

Added line #L51 was not covered by tests
# TODO: Update signal_size for one-sided FFT
elif last_dim_size is not None:
transformed = op.DFT(

Check warning on line 54 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L54

Added line #L54 was not covered by tests
transformed, last_dim_size, axis=dims[-1], onesided=True
)
else:
# "backward" - no normalization
if forward:
result = transformed
else:
result = op.Mul(transformed, total_sample_count)
transformed = op.DFT(transformed, axis=dims[-1], onesided=False)

Check warning on line 58 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L58

Added line #L58 was not covered by tests

return result


@torch_op(
("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"),
trace_only=True,
private=True,
complex=True,
)
def _fftn_onnx(
self: TFloat, dims: Sequence[int], normalization: int, inverse: bool, onesided: bool
self: TFloat,
dims: Sequence[int],
normalization: int,
forward: bool,
onesided: bool,
last_dim_size: Optional[INT64] = None,
) -> TFloat:
"""Standard complex to complex or real to complex FFT (forward or backward).

Expand All @@ -81,38 +75,66 @@
self: The input tensor.
dims: The dimensions to apply FFT.
normalization: The normalization mode.
inverse: Whether to compute the inverse FFT.
forward: Whether to compute forward FFT or backward FFT.
onesided: Whether to compute the one-sided FFT, which retains only the
positive frequencies.
last_dim_size: The size of the last specified dimension.

Returns:
The transformed tensor.
"""

# NOTE: trace_only because we need to process each dimension in a loop
# NOTE: SymInt dim is not support because DFT-17 needs a static axis
# TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support

# The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new
# dimension at the beginning to represent the batch dimension.
transformed = op.Unsqueeze(self, axes=[0])

# Add 1 to account for the batch dimension when counting axes from the left
new_dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims]

for dim in new_dims[:-1]:
# If taking FFT along the 0-th dimension: Since
# the 0-th dimension in ONNX DFT-17 is the batch dimension (cannot take DFT over),
# we need to add a new dimension at the beginning to represent the batch dimension.
unsqueeze_first_dim = 0 in dims

Check warning on line 91 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L91

Added line #L91 was not covered by tests
if unsqueeze_first_dim:
transformed = op.Unsqueeze(self, axes=[0])

Check warning on line 93 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L93

Added line #L93 was not covered by tests
# Add 1 to account for the batch dimension when counting axes from the left
dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims]
else:
transformed = self

Check warning on line 97 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L97

Added line #L97 was not covered by tests

# Norm values defined in https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOps.cpp#L117-L131
# Norm modes: https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOpsUtils.h#L15-L19
# Modes:
# 0: no normalization
# 1: "ortho" - divide by 1/sqrt(signal_size)
# 2: divide by signal_size

# Select inverse mode for ONNX based on the norm mode and forward/backward mode.
# In ONNX the only difference between inverse=True/False is the 1/n normalization applied.
#
# If normalization is 1/n and we are in backward mode, we use the inverse
# mode in ONNX to get the 1/n normalization.
inverse = normalization == 2 and not forward
ortho = normalization == 1

Check warning on line 112 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L111-L112

Added lines #L111 - L112 were not covered by tests

Check warning

Code scanning / lintrunner

PYLINT/W0612 Warning

Unused variable 'ortho' (unused-variable)
See unused-variable. To disable, use # pylint: disable=unused-variable

Check warning

Code scanning / lintrunner

RUFF/F841 Warning

Local variable ortho is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable

for dim in dims[:-1]:
transformed = op.DFT(transformed, axis=dim, inverse=inverse, onesided=False)

# Torch computers one-sided FFT on the last dimension only.
if onesided:
transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=True)
transformed = op.DFT(transformed, axis=dims[-1], inverse=inverse, onesided=True)

Check warning on line 119 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L119

Added line #L119 was not covered by tests
elif last_dim_size is not None:
transformed = op.DFT(

Check warning on line 121 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L121

Added line #L121 was not covered by tests
transformed, last_dim_size, axis=dims[-1], inverse=inverse, onesided=False
)
else:
transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=False)
transformed = op.DFT(transformed, axis=dims[-1], inverse=inverse, onesided=False)

Check warning on line 125 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L125

Added line #L125 was not covered by tests

# Remove the batch dimension
transformed = op.Squeeze(transformed, axes=[0])
normalized = _fftn_onnx_normalization(

Check warning on line 127 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L127

Added line #L127 was not covered by tests

Check failure

Code scanning / lintrunner

PYLINT/E0602 Error

Undefined variable '_fftn_onnx_normalization' (undefined-variable)
See undefined-variable. To disable, use # pylint: disable=undefined-variable

Check failure

Code scanning / lintrunner

RUFF/F821 Error

Undefined name \_fftn\_onnx\_normalization.
See https://docs.astral.sh/ruff/rules/undefined-name
self, transformed, normalization, not inverse, dims, last_dim_size=last_dim_size
)
# TODO: Merge to normalization mode and ONNX inverse mode
# Be sure to normalize before squeezing the batch dimension, because dims would
# have been shifted by 1 if the batch dimension was added.
if unsqueeze_first_dim:
# Remove the batch dimension
normalized = op.Squeeze(normalized, axes=[0])

Check warning on line 135 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L135

Added line #L135 was not covered by tests

return _fftn_onnx_normalization(self, transformed, normalization, not inverse, dims)
return normalized

Check warning on line 137 in onnxscript/function_libs/torch_lib/ops/fft.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/fft.py#L137

Added line #L137 was not covered by tests


@torch_op("aten::_fft_c2c", trace_only=True, complex=True)
Expand All @@ -124,9 +146,7 @@
Standard complex to complex FFT (forward or backward).
"""

# NOTE: trace_only because we need to negate forward
# NOTE: SymInt dim is not support because DFT-17 needs a static axis
# TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support
# NOTE: SymInt dim is not supported because DFT-17 needs a static axis

# ONNX DFT input assumes the last dimension is the complex dimension.
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
Expand All @@ -139,23 +159,21 @@
self: TFloat,
dim: Sequence[int],
normalization: int,
last_dim_size: INT64, # pylint: disable=unused-argument
last_dim_size: INT64,
) -> TFloat:
"""_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor

Complex to real inverse FFT.
"""

# TODO(justinchuby): Figure out what last_dim_size does

self_rank = len(self.shape)
# ONNX DFT input assumes the last dimension is the complex dimension.
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
dim = [(d - 1) + self_rank if d < 0 else d for d in dim]
transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False)
transformed = _fftn_onnx(

Check failure

Code scanning / lintrunner

PYLINT/E1123 Error

Unexpected keyword argument 'inverse' in function call (unexpected-keyword-arg)
See unexpected-keyword-arg. To disable, use # pylint: disable=unexpected-keyword-arg

Check failure

Code scanning / lintrunner

PYLINT/E1120 Error

No value for argument 'forward' in function call (no-value-for-parameter)
See no-value-for-parameter. To disable, use # pylint: disable=no-value-for-parameter
self, dim, normalization, inverse=True, onesided=False, last_dim_size=last_dim_size
)
# Take only the real part
real_part = op.Slice(transformed, axes=[-1], starts=[0], ends=[1])

return op.Squeeze(real_part, axes=[-1])


Expand Down
5 changes: 5 additions & 0 deletions tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,9 +689,14 @@ def sample_inputs__fft_c2r(self, device, dtype, requires_grad=False, **_):
(0, 1),
(0, 1, 2),
]:
# Slice
yield opinfo_core.SampleInput(
nd_tensor(), dim=dim, normalization=normalization, last_dim_size=6
)
# Pad
yield opinfo_core.SampleInput(
nd_tensor(), dim=dim, normalization=normalization, last_dim_size=64
)


def _index_variable_bool(shape, max_indices, device):
Expand Down
3 changes: 0 additions & 3 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,9 +472,6 @@ def _where_input_wrangler(
fft_ops.aten__fft_c2r,
tolerance={torch.complex64: (3e-3, 1.8e-4)},
complex=True,
).xfail(
dtypes=(torch.complex64,),
reason="fixme: the result is wrong: https://github.com/microsoft/onnxscript/pull/926",
),
TorchLibOpInfo(
"ops.aten._fft_r2c", # Custom from extra_opinfo
Expand Down
Loading