-
Notifications
You must be signed in to change notification settings - Fork 56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[torchlib] Fix _fft_c2r #1844
base: main
Are you sure you want to change the base?
[torchlib] Fix _fft_c2r #1844
Changes from all commits
615795c
5724cb0
9600b32
f6447a6
4d9ff42
469396f
631246d
6932a88
388deff
b729c41
ed2d5ea
59df223
f29d912
a7d95b1
7832d6f
09c0372
6961474
af7a47c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,5 @@ | ||
# -------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Copyright (c) Microsoft Corporation. | ||
Check warning Code scanning / lintrunner RUFF/format Warning
Run lintrunner -a to apply this patch.
|
||
# Licensed under the MIT License. | ||
# -------------------------------------------------------------------------- | ||
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value" | ||
"""torch.ops.aten operators under the `fft` module. | ||
|
||
|
@@ -12,7 +10,7 @@ | |
|
||
from __future__ import annotations | ||
|
||
from typing import Optional, Sequence | ||
from typing import Literal, Optional, Sequence | ||
Check warning Code scanning / lintrunner PYLINT/W0611 Warning
Unused Literal imported from typing (unused-import)
See unused-import. To disable, use # pylint: disable=unused-import Check warning Code scanning / lintrunner RUFF/F401 Warning
typing.Literal imported but unused.
See https://docs.astral.sh/ruff/rules/unused-import |
||
|
||
from onnxscript import INT64 | ||
from onnxscript.function_libs.torch_lib.registration import torch_op | ||
|
@@ -21,57 +19,53 @@ | |
from onnxscript.onnx_types import TensorType | ||
|
||
|
||
@torch_op( | ||
("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"), | ||
private=True, | ||
complex=True, | ||
traceable=True, | ||
) | ||
def _fftn_onnx_normalization( | ||
self, | ||
transformed: TFloat, | ||
normalization: int, | ||
forward: bool, | ||
def _compute_signal_size(signal: TFloat, dims: Sequence[int], last_dim_size: Optional[INT64] = None) -> INT64: | ||
if last_dim_size is not None: | ||
all_other_dims = dims[:-1] | ||
if all_other_dims: | ||
signal_size = op.ReduceProd(signal, axes=all_other_dims, keepdims=False) | ||
signal_size = op.Mul(signal_size, last_dim_size) | ||
else: | ||
signal_size = last_dim_size | ||
else: | ||
signal_size = op.ReduceProd(signal, axes=dims, keepdims=False) | ||
return signal_size | ||
|
||
|
||
def _fftn_ortho_normalization( | ||
self: TFloat, | ||
dims: Sequence[int], | ||
forward: bool, | ||
Check warning Code scanning / lintrunner PYLINT/W0613 Warning
Unused argument 'forward' (unused-argument)
See unused-argument. To disable, use # pylint: disable=unused-argument |
||
onesided: bool, | ||
last_dim_size: Optional[INT64] = None, | ||
) -> TFloat: | ||
# Obtain the total_sample_count (n) for normalization | ||
self_shape = op.Shape(self) | ||
total_sample_count = op.ReduceProd(op.Gather(self_shape, dims), keepdims=0) | ||
total_sample_count = op.CastLike(total_sample_count, transformed) | ||
|
||
# Normalize the result | ||
# Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn | ||
# Reference https://github.com/pytorch/pytorch/blob/d090c18fcaaba6e1b5cb474a89058cf6081c8275/torch/_refs/fft.py#L42 | ||
if normalization == 1: | ||
# "forward" - normalize by 1/n | ||
if forward: | ||
result = op.Div(transformed, op.Sqrt(total_sample_count)) | ||
else: | ||
result = op.Mul(transformed, op.Sqrt(total_sample_count)) | ||
elif normalization == 2: | ||
# "ortho" - normalize by 1/sqrt(n) | ||
if forward: | ||
result = op.Div(transformed, total_sample_count) | ||
else: | ||
result = transformed | ||
transformed = self | ||
|
||
signal_size = _compute_signal_size(self, dims, last_dim_size) | ||
Check warning Code scanning / lintrunner PYLINT/W0612 Warning
Unused variable 'signal_size' (unused-variable)
See unused-variable. To disable, use # pylint: disable=unused-variable Check warning Code scanning / lintrunner RUFF/F841 Warning
Local variable signal\_size is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable |
||
|
||
for dim in dims[:-1]: | ||
transformed = op.DFT(transformed, axis=dim, onesided=False) | ||
|
||
# Torch computes one-sided FFT on the last dimension only. | ||
if onesided: | ||
transformed = op.DFT(transformed, axis=dims[-1], onesided=True) | ||
# TODO: Update signal_size for one-sided FFT | ||
elif last_dim_size is not None: | ||
transformed = op.DFT( | ||
transformed, last_dim_size, axis=dims[-1], onesided=True | ||
) | ||
else: | ||
# "backward" - no normalization | ||
if forward: | ||
result = transformed | ||
else: | ||
result = op.Mul(transformed, total_sample_count) | ||
transformed = op.DFT(transformed, axis=dims[-1], onesided=False) | ||
|
||
return result | ||
|
||
|
||
@torch_op( | ||
("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"), | ||
trace_only=True, | ||
private=True, | ||
complex=True, | ||
) | ||
def _fftn_onnx( | ||
self: TFloat, dims: Sequence[int], normalization: int, inverse: bool, onesided: bool | ||
self: TFloat, | ||
dims: Sequence[int], | ||
normalization: int, | ||
forward: bool, | ||
onesided: bool, | ||
last_dim_size: Optional[INT64] = None, | ||
) -> TFloat: | ||
"""Standard complex to complex or real to complex FFT (forward or backward). | ||
|
||
|
@@ -81,38 +75,66 @@ | |
self: The input tensor. | ||
dims: The dimensions to apply FFT. | ||
normalization: The normalization mode. | ||
inverse: Whether to compute the inverse FFT. | ||
forward: Whether to compute forward FFT or backward FFT. | ||
onesided: Whether to compute the one-sided FFT, which retains only the | ||
positive frequencies. | ||
last_dim_size: The size of the last specified dimension. | ||
|
||
Returns: | ||
The transformed tensor. | ||
""" | ||
|
||
# NOTE: trace_only because we need to process each dimension in a loop | ||
# NOTE: SymInt dim is not support because DFT-17 needs a static axis | ||
# TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support | ||
|
||
# The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new | ||
# dimension at the beginning to represent the batch dimension. | ||
transformed = op.Unsqueeze(self, axes=[0]) | ||
|
||
# Add 1 to account for the batch dimension when counting axes from the left | ||
new_dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims] | ||
|
||
for dim in new_dims[:-1]: | ||
# If taking FFT along the 0-th dimension: Since | ||
# the 0-th dimension in ONNX DFT-17 is the batch dimension (cannot take DFT over), | ||
# we need to add a new dimension at the beginning to represent the batch dimension. | ||
unsqueeze_first_dim = 0 in dims | ||
if unsqueeze_first_dim: | ||
transformed = op.Unsqueeze(self, axes=[0]) | ||
# Add 1 to account for the batch dimension when counting axes from the left | ||
dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims] | ||
else: | ||
transformed = self | ||
|
||
# Norm values defined in https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOps.cpp#L117-L131 | ||
# Norm modes: https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOpsUtils.h#L15-L19 | ||
# Modes: | ||
# 0: no normalization | ||
# 1: "ortho" - divide by 1/sqrt(signal_size) | ||
# 2: divide by signal_size | ||
|
||
# Select inverse mode for ONNX based on the norm mode and forward/backward mode. | ||
# In ONNX the only difference between inverse=True/False is the 1/n normalization applied. | ||
# | ||
# If normalization is 1/n and we are in backward mode, we use the inverse | ||
# mode in ONNX to get the 1/n normalization. | ||
inverse = normalization == 2 and not forward | ||
ortho = normalization == 1 | ||
Check warning Code scanning / lintrunner PYLINT/W0612 Warning
Unused variable 'ortho' (unused-variable)
See unused-variable. To disable, use # pylint: disable=unused-variable Check warning Code scanning / lintrunner RUFF/F841 Warning
Local variable ortho is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable |
||
|
||
for dim in dims[:-1]: | ||
transformed = op.DFT(transformed, axis=dim, inverse=inverse, onesided=False) | ||
|
||
# Torch computers one-sided FFT on the last dimension only. | ||
if onesided: | ||
transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=True) | ||
transformed = op.DFT(transformed, axis=dims[-1], inverse=inverse, onesided=True) | ||
elif last_dim_size is not None: | ||
transformed = op.DFT( | ||
transformed, last_dim_size, axis=dims[-1], inverse=inverse, onesided=False | ||
) | ||
else: | ||
transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=False) | ||
transformed = op.DFT(transformed, axis=dims[-1], inverse=inverse, onesided=False) | ||
|
||
# Remove the batch dimension | ||
transformed = op.Squeeze(transformed, axes=[0]) | ||
normalized = _fftn_onnx_normalization( | ||
Check failure Code scanning / lintrunner PYLINT/E0602 Error
Undefined variable '_fftn_onnx_normalization' (undefined-variable)
See undefined-variable. To disable, use # pylint: disable=undefined-variable Check failure Code scanning / lintrunner RUFF/F821 Error
Undefined name \_fftn\_onnx\_normalization.
See https://docs.astral.sh/ruff/rules/undefined-name |
||
self, transformed, normalization, not inverse, dims, last_dim_size=last_dim_size | ||
) | ||
# TODO: Merge to normalization mode and ONNX inverse mode | ||
# Be sure to normalize before squeezing the batch dimension, because dims would | ||
# have been shifted by 1 if the batch dimension was added. | ||
if unsqueeze_first_dim: | ||
# Remove the batch dimension | ||
normalized = op.Squeeze(normalized, axes=[0]) | ||
|
||
return _fftn_onnx_normalization(self, transformed, normalization, not inverse, dims) | ||
return normalized | ||
|
||
|
||
@torch_op("aten::_fft_c2c", trace_only=True, complex=True) | ||
|
@@ -124,9 +146,7 @@ | |
Standard complex to complex FFT (forward or backward). | ||
""" | ||
|
||
# NOTE: trace_only because we need to negate forward | ||
# NOTE: SymInt dim is not support because DFT-17 needs a static axis | ||
# TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support | ||
# NOTE: SymInt dim is not supported because DFT-17 needs a static axis | ||
|
||
# ONNX DFT input assumes the last dimension is the complex dimension. | ||
# Thus dim=-1 in PyTorch is dim=-2 in ONNX. | ||
|
@@ -139,23 +159,21 @@ | |
self: TFloat, | ||
dim: Sequence[int], | ||
normalization: int, | ||
last_dim_size: INT64, # pylint: disable=unused-argument | ||
last_dim_size: INT64, | ||
) -> TFloat: | ||
"""_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor | ||
|
||
Complex to real inverse FFT. | ||
""" | ||
|
||
# TODO(justinchuby): Figure out what last_dim_size does | ||
|
||
self_rank = len(self.shape) | ||
# ONNX DFT input assumes the last dimension is the complex dimension. | ||
# Thus dim=-1 in PyTorch is dim=-2 in ONNX. | ||
dim = [(d - 1) + self_rank if d < 0 else d for d in dim] | ||
transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False) | ||
transformed = _fftn_onnx( | ||
Check failure Code scanning / lintrunner PYLINT/E1123 Error
Unexpected keyword argument 'inverse' in function call (unexpected-keyword-arg)
See unexpected-keyword-arg. To disable, use # pylint: disable=unexpected-keyword-arg Check failure Code scanning / lintrunner PYLINT/E1120 Error
No value for argument 'forward' in function call (no-value-for-parameter)
See no-value-for-parameter. To disable, use # pylint: disable=no-value-for-parameter |
||
self, dim, normalization, inverse=True, onesided=False, last_dim_size=last_dim_size | ||
) | ||
# Take only the real part | ||
real_part = op.Slice(transformed, axes=[-1], starts=[0], ends=[1]) | ||
|
||
return op.Squeeze(real_part, axes=[-1]) | ||
|
||
|
||
|
Check warning
Code scanning / lintrunner
RUFF-FORMAT/format Warning