-
Notifications
You must be signed in to change notification settings - Fork 59
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[torchlib] Fix _fft_c2r #1844
base: main
Are you sure you want to change the base?
[torchlib] Fix _fft_c2r #1844
Changes from 12 commits
615795c
5724cb0
9600b32
f6447a6
4d9ff42
469396f
631246d
6932a88
388deff
b729c41
ed2d5ea
59df223
f29d912
a7d95b1
7832d6f
09c0372
6961474
af7a47c
26779c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,5 @@ | ||
# -------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Copyright (c) Microsoft Corporation. | ||
Check warning Code scanning / lintrunner RUFF/format Warning
Run lintrunner -a to apply this patch.
|
||
# Licensed under the MIT License. | ||
# -------------------------------------------------------------------------- | ||
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value" | ||
"""torch.ops.aten operators under the `fft` module. | ||
|
||
|
@@ -21,22 +19,21 @@ | |
from onnxscript.onnx_types import TensorType | ||
|
||
|
||
@torch_op( | ||
("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"), | ||
private=True, | ||
complex=True, | ||
traceable=True, | ||
) | ||
def _fftn_onnx_normalization( | ||
self, | ||
self: TFloat, | ||
transformed: TFloat, | ||
normalization: int, | ||
forward: bool, | ||
Check warning Code scanning / lintrunner PYLINT/W0613 Warning
Unused argument 'forward' (unused-argument)
See unused-argument. To disable, use # pylint: disable=unused-argument |
||
dims: Sequence[int], | ||
last_dim_size: Optional[INT64] = None, | ||
) -> TFloat: | ||
# Obtain the total_sample_count (n) for normalization | ||
self_shape = op.Shape(self) | ||
total_sample_count = op.ReduceProd(op.Gather(self_shape, dims), keepdims=0) | ||
if last_dim_size is None: | ||
total_sample_count = op.ReduceProd(op.Gather(self_shape, dims), keepdims=0) | ||
else: | ||
total_sample_count = op.ReduceProd(op.Gather(self_shape, dims[:-1]), keepdims=0) | ||
total_sample_count = op.Mul(total_sample_count, last_dim_size) | ||
total_sample_count = op.CastLike(total_sample_count, transformed) | ||
|
||
# Normalize the result | ||
|
@@ -64,14 +61,13 @@ | |
return result | ||
|
||
|
||
@torch_op( | ||
("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"), | ||
trace_only=True, | ||
private=True, | ||
complex=True, | ||
) | ||
def _fftn_onnx( | ||
self: TFloat, dims: Sequence[int], normalization: int, inverse: bool, onesided: bool | ||
self: TFloat, | ||
dims: Sequence[int], | ||
normalization: int, | ||
inverse: bool, | ||
onesided: bool, | ||
last_dim_size: Optional[INT64] = None, | ||
) -> TFloat: | ||
"""Standard complex to complex or real to complex FFT (forward or backward). | ||
|
||
|
@@ -88,31 +84,37 @@ | |
Returns: | ||
The transformed tensor. | ||
""" | ||
|
||
# NOTE: trace_only because we need to process each dimension in a loop | ||
# NOTE: SymInt dim is not support because DFT-17 needs a static axis | ||
# TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support | ||
|
||
# The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new | ||
# dimension at the beginning to represent the batch dimension. | ||
transformed = op.Unsqueeze(self, axes=[0]) | ||
|
||
# Add 1 to account for the batch dimension when counting axes from the left | ||
new_dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims] | ||
# If taking FFT along the 0-th dimension: Since | ||
# the 0-th dimension in ONNX DFT-17 is the batch dimension (cannot take DFT over), | ||
# we need to add a new dimension at the beginning to represent the batch dimension. | ||
unsqueeze_first_dim = 0 in dims | ||
if unsqueeze_first_dim: | ||
transformed = op.Unsqueeze(self, axes=[0]) | ||
# Add 1 to account for the batch dimension when counting axes from the left | ||
dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims] | ||
else: | ||
transformed = self | ||
|
||
for dim in new_dims[:-1]: | ||
for dim in dims[:-1]: | ||
transformed = op.DFT(transformed, axis=dim, inverse=inverse, onesided=False) | ||
|
||
# Torch computers one-sided FFT on the last dimension only. | ||
if onesided: | ||
transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=True) | ||
transformed = op.DFT(transformed, axis=dims[-1], inverse=inverse, onesided=True) | ||
elif last_dim_size is not None: | ||
transformed = op.DFT( | ||
transformed, last_dim_size, axis=dims[-1], inverse=inverse, onesided=False | ||
) | ||
else: | ||
transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=False) | ||
transformed = op.DFT(transformed, axis=dims[-1], inverse=inverse, onesided=False) | ||
|
||
# Remove the batch dimension | ||
transformed = op.Squeeze(transformed, axes=[0]) | ||
if unsqueeze_first_dim: | ||
# Remove the batch dimension | ||
transformed = op.Squeeze(transformed, axes=[0]) | ||
|
||
return _fftn_onnx_normalization(self, transformed, normalization, not inverse, dims) | ||
return _fftn_onnx_normalization(self, transformed, normalization, not inverse, dims, last_dim_size=last_dim_size) | ||
|
||
|
||
@torch_op("aten::_fft_c2c", trace_only=True, complex=True) | ||
|
@@ -124,9 +126,7 @@ | |
Standard complex to complex FFT (forward or backward). | ||
""" | ||
|
||
# NOTE: trace_only because we need to negate forward | ||
# NOTE: SymInt dim is not support because DFT-17 needs a static axis | ||
# TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support | ||
# NOTE: SymInt dim is not supported because DFT-17 needs a static axis | ||
|
||
# ONNX DFT input assumes the last dimension is the complex dimension. | ||
# Thus dim=-1 in PyTorch is dim=-2 in ONNX. | ||
|
@@ -139,23 +139,21 @@ | |
self: TFloat, | ||
dim: Sequence[int], | ||
normalization: int, | ||
last_dim_size: INT64, # pylint: disable=unused-argument | ||
last_dim_size: INT64, | ||
) -> TFloat: | ||
"""_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor | ||
|
||
Complex to real inverse FFT. | ||
""" | ||
|
||
# TODO(justinchuby): Figure out what last_dim_size does | ||
|
||
self_rank = len(self.shape) | ||
# ONNX DFT input assumes the last dimension is the complex dimension. | ||
# Thus dim=-1 in PyTorch is dim=-2 in ONNX. | ||
dim = [(d - 1) + self_rank if d < 0 else d for d in dim] | ||
transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False) | ||
transformed = _fftn_onnx( | ||
Check failure Code scanning / lintrunner PYLINT/E1123 Error
Unexpected keyword argument 'inverse' in function call (unexpected-keyword-arg)
See unexpected-keyword-arg. To disable, use # pylint: disable=unexpected-keyword-arg Check failure Code scanning / lintrunner PYLINT/E1120 Error
No value for argument 'forward' in function call (no-value-for-parameter)
See no-value-for-parameter. To disable, use # pylint: disable=no-value-for-parameter |
||
self, dim, normalization, inverse=True, onesided=False, last_dim_size=last_dim_size | ||
) | ||
Comment on lines
+172
to
+174
Check failure Code scanning / CodeQL Wrong name for an argument in a call Error
Keyword argument 'inverse' is not a supported parameter name of
function _fftn_onnx Error loading related location Loading |
||
# Take only the real part | ||
real_part = op.Slice(transformed, axes=[-1], starts=[0], ends=[1]) | ||
|
||
return op.Squeeze(real_part, axes=[-1]) | ||
|
||
|
||
|
Check warning
Code scanning / lintrunner
RUFF-FORMAT/format Warning