-
Notifications
You must be signed in to change notification settings - Fork 58
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[torchlib] Fix _fft_c2r #1844
base: main
Are you sure you want to change the base?
[torchlib] Fix _fft_c2r #1844
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1844 +/- ##
==========================================
- Coverage 73.77% 73.63% -0.14%
==========================================
Files 225 225
Lines 29333 29343 +10
Branches 3467 3470 +3
==========================================
- Hits 21639 21606 -33
- Misses 6560 6601 +41
- Partials 1134 1136 +2 ☔ View full report in Codecov by Sentry. |
Just specify dft_length |
@@ -1,7 +1,5 @@ | |||
# -------------------------------------------------------------------------- | |||
# Copyright (c) Microsoft Corporation. All rights reserved. | |||
# Copyright (c) Microsoft Corporation. |
Check warning
Code scanning / lintrunner
RUFF/format Warning
@@ -12,7 +10,7 @@ | |||
|
|||
from __future__ import annotations | |||
|
|||
from typing import Optional, Sequence | |||
from typing import Literal, Optional, Sequence |
Check warning
Code scanning / lintrunner
PYLINT/W0611 Warning
See unused-import. To disable, use # pylint: disable=unused-import
@@ -12,7 +10,7 @@ | |||
|
|||
from __future__ import annotations | |||
|
|||
from typing import Optional, Sequence | |||
from typing import Literal, Optional, Sequence |
Check warning
Code scanning / lintrunner
RUFF/F401 Warning
See https://docs.astral.sh/ruff/rules/unused-import
dims: Sequence[int], | ||
forward: bool, |
Check warning
Code scanning / lintrunner
PYLINT/W0613 Warning
See unused-argument. To disable, use # pylint: disable=unused-argument
result = transformed | ||
transformed = self | ||
|
||
signal_size = _compute_signal_size(self, dims, last_dim_size) |
Check warning
Code scanning / lintrunner
PYLINT/W0612 Warning
See unused-variable. To disable, use # pylint: disable=unused-variable
# If normalization is 1/n and we are in backward mode, we use the inverse | ||
# mode in ONNX to get the 1/n normalization. | ||
inverse = normalization == 2 and not forward | ||
ortho = normalization == 1 |
Check warning
Code scanning / lintrunner
RUFF/F841 Warning
See https://docs.astral.sh/ruff/rules/unused-variable
|
||
# Remove the batch dimension | ||
transformed = op.Squeeze(transformed, axes=[0]) | ||
normalized = _fftn_onnx_normalization( |
Check failure
Code scanning / lintrunner
PYLINT/E0602 Error
See undefined-variable. To disable, use # pylint: disable=undefined-variable
|
||
# Remove the batch dimension | ||
transformed = op.Squeeze(transformed, axes=[0]) | ||
normalized = _fftn_onnx_normalization( |
Check failure
Code scanning / lintrunner
RUFF/F821 Error
See https://docs.astral.sh/ruff/rules/undefined-name
self_rank = len(self.shape) | ||
# ONNX DFT input assumes the last dimension is the complex dimension. | ||
# Thus dim=-1 in PyTorch is dim=-2 in ONNX. | ||
dim = [(d - 1) + self_rank if d < 0 else d for d in dim] | ||
transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False) | ||
transformed = _fftn_onnx( |
Check failure
Code scanning / lintrunner
PYLINT/E1123 Error
See unexpected-keyword-arg. To disable, use # pylint: disable=unexpected-keyword-arg
self_rank = len(self.shape) | ||
# ONNX DFT input assumes the last dimension is the complex dimension. | ||
# Thus dim=-1 in PyTorch is dim=-2 in ONNX. | ||
dim = [(d - 1) + self_rank if d < 0 else d for d in dim] | ||
transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False) | ||
transformed = _fftn_onnx( |
Check failure
Code scanning / lintrunner
PYLINT/E1120 Error
See no-value-for-parameter. To disable, use # pylint: disable=no-value-for-parameter
@@ -12,7 +10,7 @@ | |||
|
|||
from __future__ import annotations | |||
|
|||
from typing import Optional, Sequence | |||
from typing import Literal, Optional, Sequence |
Check notice
Code scanning / CodeQL
Unused import Note
result = transformed | ||
transformed = self | ||
|
||
signal_size = _compute_signal_size(self, dims, last_dim_size) |
Check notice
Code scanning / CodeQL
Unused local variable Note
|
||
# Torch computes one-sided FFT on the last dimension only. | ||
if onesided: | ||
transformed = op.DFT(transformed, axis=dims[-1], onesided=True) |
Check notice
Code scanning / CodeQL
Unused local variable Note
transformed = op.DFT(transformed, axis=dims[-1], onesided=True) | ||
# TODO: Update signal_size for one-sided FFT | ||
elif last_dim_size is not None: | ||
transformed = op.DFT( |
Check notice
Code scanning / CodeQL
Unused local variable Note
result = transformed | ||
else: | ||
result = op.Mul(transformed, total_sample_count) | ||
transformed = op.DFT(transformed, axis=dims[-1], onesided=False) |
Check notice
Code scanning / CodeQL
Unused local variable Note
# If normalization is 1/n and we are in backward mode, we use the inverse | ||
# mode in ONNX to get the 1/n normalization. | ||
inverse = normalization == 2 and not forward | ||
ortho = normalization == 1 |
Check notice
Code scanning / CodeQL
Unused local variable Note
transformed = _fftn_onnx( | ||
self, dim, normalization, inverse=True, onesided=False, last_dim_size=last_dim_size | ||
) |
Check failure
Code scanning / CodeQL
Wrong name for an argument in a call Error
Fix #1271 Fix pytorch/pytorch#113444 Fix pytorch/pytorch#119360