From 615795cd7f755ec565857b811ea326e9fbb98575 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 2 Sep 2024 20:35:37 -0700 Subject: [PATCH 01/18] [torchlib] Fix _fft_c2r --- onnxscript/function_libs/torch_lib/ops/fft.py | 24 ++++--------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index f35b4f611..89804d22e 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -21,12 +21,6 @@ from onnxscript.onnx_types import TensorType -@torch_op( - ("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"), - private=True, - complex=True, - traceable=True, -) def _fftn_onnx_normalization( self, transformed: TFloat, @@ -64,12 +58,6 @@ def _fftn_onnx_normalization( return result -@torch_op( - ("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"), - trace_only=True, - private=True, - complex=True, -) def _fftn_onnx( self: TFloat, dims: Sequence[int], normalization: int, inverse: bool, onesided: bool ) -> TFloat: @@ -91,7 +79,6 @@ def _fftn_onnx( # NOTE: trace_only because we need to process each dimension in a loop # NOTE: SymInt dim is not support because DFT-17 needs a static axis - # TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support # The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new # dimension at the beginning to represent the batch dimension. @@ -124,9 +111,7 @@ def aten__fft_c2c( Standard complex to complex FFT (forward or backward). """ - # NOTE: trace_only because we need to negate forward - # NOTE: SymInt dim is not support because DFT-17 needs a static axis - # TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support + # NOTE: SymInt dim is not supported because DFT-17 needs a static axis # ONNX DFT input assumes the last dimension is the complex dimension. # Thus dim=-1 in PyTorch is dim=-2 in ONNX. @@ -139,7 +124,7 @@ def aten__fft_c2r( self: TFloat, dim: Sequence[int], normalization: int, - last_dim_size: INT64, # pylint: disable=unused-argument + last_dim_size: INT64, ) -> TFloat: """_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor @@ -154,9 +139,10 @@ def aten__fft_c2r( dim = [(d - 1) + self_rank if d < 0 else d for d in dim] transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False) # Take only the real part - real_part = op.Slice(transformed, axes=[-1], starts=[0], ends=[1]) + real_part = op.Squeeze(op.Slice(transformed, axes=[-1], starts=[0], ends=[1]), axes=[-1]) + last_dim_size = op.Reshape(last_dim_size, shape=[1]) - return op.Squeeze(real_part, axes=[-1]) + return op.Slice(real_part, axes=[-1], starts=[0], ends=last_dim_size) @torch_op("aten::_fft_r2c", trace_only=True) From 5724cb006d0244ba2cc608d018b045e4783cb809 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 2 Sep 2024 20:36:15 -0700 Subject: [PATCH 02/18] note --- onnxscript/function_libs/torch_lib/ops/fft.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index 89804d22e..df2922087 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -130,9 +130,6 @@ def aten__fft_c2r( Complex to real inverse FFT. """ - - # TODO(justinchuby): Figure out what last_dim_size does - self_rank = len(self.shape) # ONNX DFT input assumes the last dimension is the complex dimension. # Thus dim=-1 in PyTorch is dim=-2 in ONNX. From 9600b3239e8d8e21aeb94ba3320220ced63e3998 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 2 Sep 2024 21:08:41 -0700 Subject: [PATCH 03/18] Batch dim --- onnxscript/function_libs/torch_lib/ops/fft.py | 30 +++++++++++-------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index df2922087..dea7ab921 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -76,28 +76,34 @@ def _fftn_onnx( Returns: The transformed tensor. """ - - # NOTE: trace_only because we need to process each dimension in a loop # NOTE: SymInt dim is not support because DFT-17 needs a static axis - # The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new - # dimension at the beginning to represent the batch dimension. - transformed = op.Unsqueeze(self, axes=[0]) + if 0 in dims: + # Taking FFT along the first dimension + # The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new + # dimension at the beginning to represent the batch dimension. + unsqueeze_first_dim = True + else: + unsqueeze_first_dim = False + - # Add 1 to account for the batch dimension when counting axes from the left - new_dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims] + if unsqueeze_first_dim: + transformed = op.Unsqueeze(self, axes=[0]) + # Add 1 to account for the batch dimension when counting axes from the left + dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims] - for dim in new_dims[:-1]: + for dim in dims[:-1]: transformed = op.DFT(transformed, axis=dim, inverse=inverse, onesided=False) # Torch computers one-sided FFT on the last dimension only. if onesided: - transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=True) + transformed = op.DFT(transformed, axis=dims[-1], inverse=inverse, onesided=True) else: - transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=False) + transformed = op.DFT(transformed, axis=dims[-1], inverse=inverse, onesided=False) - # Remove the batch dimension - transformed = op.Squeeze(transformed, axes=[0]) + if unsqueeze_first_dim: + # Remove the batch dimension + transformed = op.Squeeze(transformed, axes=[0]) return _fftn_onnx_normalization(self, transformed, normalization, not inverse, dims) From f6447a658f09ee3581b5ebfa5db493fe4ecb0331 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 2 Sep 2024 21:08:53 -0700 Subject: [PATCH 04/18] format --- onnxscript/function_libs/torch_lib/ops/fft.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index dea7ab921..762fd4cde 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -86,7 +86,6 @@ def _fftn_onnx( else: unsqueeze_first_dim = False - if unsqueeze_first_dim: transformed = op.Unsqueeze(self, axes=[0]) # Add 1 to account for the batch dimension when counting axes from the left From 4d9ff42f00edd6688f787f8651880d9621beb236 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 2 Sep 2024 21:13:30 -0700 Subject: [PATCH 05/18] dim --- onnxscript/function_libs/torch_lib/ops/fft.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index 762fd4cde..e35c6600b 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -141,10 +141,12 @@ def aten__fft_c2r( dim = [(d - 1) + self_rank if d < 0 else d for d in dim] transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False) # Take only the real part - real_part = op.Squeeze(op.Slice(transformed, axes=[-1], starts=[0], ends=[1]), axes=[-1]) + real_part = op.Slice(transformed, axes=[-1], starts=[0], ends=[1]) last_dim_size = op.Reshape(last_dim_size, shape=[1]) - - return op.Slice(real_part, axes=[-1], starts=[0], ends=last_dim_size) + # The last dim is -2 because the real last dim is the complex dim, which we + # remove in the last step with Squeeze. + result = op.Slice(real_part, axes=[-2], starts=[0], ends=last_dim_size) + return op.Squeeze(result, axes=[-1]) @torch_op("aten::_fft_r2c", trace_only=True) From 469396f6d56e6ce9a17f4b2c219ae9325bca82b1 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 2 Sep 2024 21:15:56 -0700 Subject: [PATCH 06/18] transformed --- onnxscript/function_libs/torch_lib/ops/fft.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index e35c6600b..8c46861c1 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -90,6 +90,8 @@ def _fftn_onnx( transformed = op.Unsqueeze(self, axes=[0]) # Add 1 to account for the batch dimension when counting axes from the left dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims] + else: + transformed = self for dim in dims[:-1]: transformed = op.DFT(transformed, axis=dim, inverse=inverse, onesided=False) From 631246dd080ed5e7c620c0396dfb0b331e7541b7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 2 Sep 2024 21:19:59 -0700 Subject: [PATCH 07/18] last_dim_size --- onnxscript/function_libs/torch_lib/ops/fft.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index 8c46861c1..104964e36 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -148,6 +148,7 @@ def aten__fft_c2r( # The last dim is -2 because the real last dim is the complex dim, which we # remove in the last step with Squeeze. result = op.Slice(real_part, axes=[-2], starts=[0], ends=last_dim_size) + # TODO(justinchuby): We may need to pad if last_dim_size is bigger return op.Squeeze(result, axes=[-1]) From 6932a884dd4f26090530928b5d0d52e452cdaea1 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 3 Sep 2024 09:56:43 -0700 Subject: [PATCH 08/18] Update fft.py --- onnxscript/function_libs/torch_lib/ops/fft.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index 104964e36..83e086b02 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -59,7 +59,7 @@ def _fftn_onnx_normalization( def _fftn_onnx( - self: TFloat, dims: Sequence[int], normalization: int, inverse: bool, onesided: bool + self: TFloat, dims: Sequence[int], normalization: int, inverse: bool, onesided: bool, last_dim_size: Optional[INT64] = None ) -> TFloat: """Standard complex to complex or real to complex FFT (forward or backward). @@ -99,6 +99,8 @@ def _fftn_onnx( # Torch computers one-sided FFT on the last dimension only. if onesided: transformed = op.DFT(transformed, axis=dims[-1], inverse=inverse, onesided=True) + elif last_dim_size is not None: + transformed = op.DFT(transformed, last_dim_size, axis=dims[-1], inverse=inverse, onesided=False) else: transformed = op.DFT(transformed, axis=dims[-1], inverse=inverse, onesided=False) @@ -141,14 +143,9 @@ def aten__fft_c2r( # ONNX DFT input assumes the last dimension is the complex dimension. # Thus dim=-1 in PyTorch is dim=-2 in ONNX. dim = [(d - 1) + self_rank if d < 0 else d for d in dim] - transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False) + transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False, last_dim_size=last_dim_size) # Take only the real part real_part = op.Slice(transformed, axes=[-1], starts=[0], ends=[1]) - last_dim_size = op.Reshape(last_dim_size, shape=[1]) - # The last dim is -2 because the real last dim is the complex dim, which we - # remove in the last step with Squeeze. - result = op.Slice(real_part, axes=[-2], starts=[0], ends=last_dim_size) - # TODO(justinchuby): We may need to pad if last_dim_size is bigger return op.Squeeze(result, axes=[-1]) From 388deffa585b83a4ee9ede71024c9c45be8a02da Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 3 Sep 2024 10:01:26 -0700 Subject: [PATCH 09/18] Update fft.py --- onnxscript/function_libs/torch_lib/ops/fft.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index 83e086b02..c5c31be39 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -1,7 +1,5 @@ -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# -------------------------------------------------------------------------- # mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value" """torch.ops.aten operators under the `fft` module. @@ -78,14 +76,10 @@ def _fftn_onnx( """ # NOTE: SymInt dim is not support because DFT-17 needs a static axis - if 0 in dims: - # Taking FFT along the first dimension - # The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new - # dimension at the beginning to represent the batch dimension. - unsqueeze_first_dim = True - else: - unsqueeze_first_dim = False - + # If taking FFT along the 0-th dimension: Since + # the 0-th dimension in ONNX DFT-17 is the batch dimension (cannot take DFT over), + # we need to add a new dimension at the beginning to represent the batch dimension. + unsqueeze_first_dim = 0 in dims if unsqueeze_first_dim: transformed = op.Unsqueeze(self, axes=[0]) # Add 1 to account for the batch dimension when counting axes from the left From b729c411269fc37200aa7a5b0ce27298b399f844 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 3 Sep 2024 10:04:22 -0700 Subject: [PATCH 10/18] format --- onnxscript/function_libs/torch_lib/ops/fft.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index c5c31be39..7b9f5905e 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -57,7 +57,12 @@ def _fftn_onnx_normalization( def _fftn_onnx( - self: TFloat, dims: Sequence[int], normalization: int, inverse: bool, onesided: bool, last_dim_size: Optional[INT64] = None + self: TFloat, + dims: Sequence[int], + normalization: int, + inverse: bool, + onesided: bool, + last_dim_size: Optional[INT64] = None, ) -> TFloat: """Standard complex to complex or real to complex FFT (forward or backward). @@ -94,7 +99,9 @@ def _fftn_onnx( if onesided: transformed = op.DFT(transformed, axis=dims[-1], inverse=inverse, onesided=True) elif last_dim_size is not None: - transformed = op.DFT(transformed, last_dim_size, axis=dims[-1], inverse=inverse, onesided=False) + transformed = op.DFT( + transformed, last_dim_size, axis=dims[-1], inverse=inverse, onesided=False + ) else: transformed = op.DFT(transformed, axis=dims[-1], inverse=inverse, onesided=False) @@ -137,10 +144,12 @@ def aten__fft_c2r( # ONNX DFT input assumes the last dimension is the complex dimension. # Thus dim=-1 in PyTorch is dim=-2 in ONNX. dim = [(d - 1) + self_rank if d < 0 else d for d in dim] - transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False, last_dim_size=last_dim_size) + transformed = _fftn_onnx( + self, dim, normalization, inverse=True, onesided=False, last_dim_size=last_dim_size + ) # Take only the real part real_part = op.Slice(transformed, axes=[-1], starts=[0], ends=[1]) - return op.Squeeze(result, axes=[-1]) + return op.Squeeze(real_part, axes=[-1]) @torch_op("aten::_fft_r2c", trace_only=True) From ed2d5ea3cd3482439cc8ef4030ffbc876c4fa765 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 3 Sep 2024 10:10:18 -0700 Subject: [PATCH 11/18] Test --- tests/function_libs/torch_lib/extra_opinfo.py | 5 +++++ tests/function_libs/torch_lib/ops_test_data.py | 3 --- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 0abced612..36e449ad7 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -689,9 +689,14 @@ def sample_inputs__fft_c2r(self, device, dtype, requires_grad=False, **_): (0, 1), (0, 1, 2), ]: + # Slice yield opinfo_core.SampleInput( nd_tensor(), dim=dim, normalization=normalization, last_dim_size=6 ) + # Pad + yield opinfo_core.SampleInput( + nd_tensor(), dim=dim, normalization=normalization, last_dim_size=64 + ) def _index_variable_bool(shape, max_indices, device): diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 3f9576745..b869092ca 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -472,9 +472,6 @@ def _where_input_wrangler( fft_ops.aten__fft_c2r, tolerance={torch.complex64: (3e-3, 1.8e-4)}, complex=True, - ).xfail( - dtypes=(torch.complex64,), - reason="fixme: the result is wrong: https://github.com/microsoft/onnxscript/pull/926", ), TorchLibOpInfo( "ops.aten._fft_r2c", # Custom from extra_opinfo From 59df223d1c2559bdffe18e9c7ef5d712409edfa7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 3 Sep 2024 10:27:39 -0700 Subject: [PATCH 12/18] ? --- onnxscript/function_libs/torch_lib/ops/fft.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index 7b9f5905e..e5be39cf3 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -20,15 +20,20 @@ def _fftn_onnx_normalization( - self, + self: TFloat, transformed: TFloat, normalization: int, forward: bool, dims: Sequence[int], + last_dim_size: Optional[INT64] = None, ) -> TFloat: # Obtain the total_sample_count (n) for normalization self_shape = op.Shape(self) - total_sample_count = op.ReduceProd(op.Gather(self_shape, dims), keepdims=0) + if last_dim_size is None: + total_sample_count = op.ReduceProd(op.Gather(self_shape, dims), keepdims=0) + else: + total_sample_count = op.ReduceProd(op.Gather(self_shape, dims[:-1]), keepdims=0) + total_sample_count = op.Mul(total_sample_count, last_dim_size) total_sample_count = op.CastLike(total_sample_count, transformed) # Normalize the result @@ -109,7 +114,7 @@ def _fftn_onnx( # Remove the batch dimension transformed = op.Squeeze(transformed, axes=[0]) - return _fftn_onnx_normalization(self, transformed, normalization, not inverse, dims) + return _fftn_onnx_normalization(self, transformed, normalization, not inverse, dims, last_dim_size=last_dim_size) @torch_op("aten::_fft_c2c", trace_only=True, complex=True) From f29d912b88039117fe34187d0bee998e232422a0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 3 Sep 2024 10:44:45 -0700 Subject: [PATCH 13/18] this --- onnxscript/function_libs/torch_lib/ops/fft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index e5be39cf3..c37a9226c 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -40,13 +40,13 @@ def _fftn_onnx_normalization( # Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn # Reference https://github.com/pytorch/pytorch/blob/d090c18fcaaba6e1b5cb474a89058cf6081c8275/torch/_refs/fft.py#L42 if normalization == 1: - # "forward" - normalize by 1/n + # "ortho" - normalize by 1/sqrt(n) if forward: result = op.Div(transformed, op.Sqrt(total_sample_count)) else: result = op.Mul(transformed, op.Sqrt(total_sample_count)) elif normalization == 2: - # "ortho" - normalize by 1/sqrt(n) + # "forward" - normalize by 1/n if forward: result = op.Div(transformed, total_sample_count) else: From a7d95b1dbaed7f24281ccc3ac8ce240ca6516c56 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 3 Sep 2024 11:09:04 -0700 Subject: [PATCH 14/18] try this --- onnxscript/function_libs/torch_lib/ops/fft.py | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index c37a9226c..02da9119e 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -39,24 +39,32 @@ def _fftn_onnx_normalization( # Normalize the result # Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn # Reference https://github.com/pytorch/pytorch/blob/d090c18fcaaba6e1b5cb474a89058cf6081c8275/torch/_refs/fft.py#L42 - if normalization == 1: - # "ortho" - normalize by 1/sqrt(n) + # Norm values defined in https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOps.cpp#L117-L131 + # Norm modes: https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOpsUtils.h#L15-L19 + if normalization == 0: + # "none" - no normalization + if forward: + result = transformed + else: + # Revert the 1/n normalization done by ONNX + result = op.Mul(transformed, total_sample_count) + elif normalization == 1: + # "ortho" - divide by 1/sqrt(signal_size) if forward: result = op.Div(transformed, op.Sqrt(total_sample_count)) else: + # ifft of DFT in ONNX is already normalized with 1/n, so we should + # multiply by n before dividing by sqrt(n) to get the correct result, + # Which is equivalent to `*sqrt(n)` in the end. result = op.Mul(transformed, op.Sqrt(total_sample_count)) - elif normalization == 2: + else: + # normalization == 2, divide by signal_size # "forward" - normalize by 1/n if forward: result = op.Div(transformed, total_sample_count) else: + # Keep the 1/n normalization done by ONNX result = transformed - else: - # "backward" - no normalization - if forward: - result = transformed - else: - result = op.Mul(transformed, total_sample_count) return result @@ -114,7 +122,9 @@ def _fftn_onnx( # Remove the batch dimension transformed = op.Squeeze(transformed, axes=[0]) - return _fftn_onnx_normalization(self, transformed, normalization, not inverse, dims, last_dim_size=last_dim_size) + return _fftn_onnx_normalization( + self, transformed, normalization, not inverse, dims, last_dim_size=last_dim_size + ) @torch_op("aten::_fft_c2c", trace_only=True, complex=True) From 7832d6fc985d3e47d55f3b0051fdee42e917d531 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 3 Sep 2024 11:11:05 -0700 Subject: [PATCH 15/18] this --- onnxscript/function_libs/torch_lib/ops/fft.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index 02da9119e..a4760c03d 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -32,8 +32,12 @@ def _fftn_onnx_normalization( if last_dim_size is None: total_sample_count = op.ReduceProd(op.Gather(self_shape, dims), keepdims=0) else: - total_sample_count = op.ReduceProd(op.Gather(self_shape, dims[:-1]), keepdims=0) - total_sample_count = op.Mul(total_sample_count, last_dim_size) + all_other_dims = dims[:-1] + if all_other_dims: + total_sample_count = op.ReduceProd(op.Gather(self_shape, dims[:-1]), keepdims=0) + total_sample_count = op.Mul(total_sample_count, last_dim_size) + else: + total_sample_count = last_dim_size total_sample_count = op.CastLike(total_sample_count, transformed) # Normalize the result From 09c03727420aabc96e050e7d6a103442ef142e0c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 3 Sep 2024 12:29:05 -0700 Subject: [PATCH 16/18] normalization --- onnxscript/function_libs/torch_lib/ops/fft.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index a4760c03d..6f3506909 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -92,6 +92,7 @@ def _fftn_onnx( inverse: Whether to compute the inverse FFT. onesided: Whether to compute the one-sided FFT, which retains only the positive frequencies. + last_dim_size: The size of the last specified dimension. Returns: The transformed tensor. @@ -122,13 +123,16 @@ def _fftn_onnx( else: transformed = op.DFT(transformed, axis=dims[-1], inverse=inverse, onesided=False) + normalized = _fftn_onnx_normalization( + self, transformed, normalization, not inverse, dims, last_dim_size=last_dim_size + ) + # Be sure to normalize before squeezing the batch dimension, because dims would + # have been shifted by 1 if the batch dimension was added. if unsqueeze_first_dim: # Remove the batch dimension - transformed = op.Squeeze(transformed, axes=[0]) + normalized = op.Squeeze(normalized, axes=[0]) - return _fftn_onnx_normalization( - self, transformed, normalization, not inverse, dims, last_dim_size=last_dim_size - ) + return normalized @torch_op("aten::_fft_c2c", trace_only=True, complex=True) From 696147429405b6c2de25cf1e65122b3338b24ba6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 3 Sep 2024 12:42:18 -0700 Subject: [PATCH 17/18] todo --- onnxscript/function_libs/torch_lib/ops/fft.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index 6f3506909..51b0af489 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -126,6 +126,7 @@ def _fftn_onnx( normalized = _fftn_onnx_normalization( self, transformed, normalization, not inverse, dims, last_dim_size=last_dim_size ) + # TODO: Merge to normalization mode and ONNX inverse mode # Be sure to normalize before squeezing the batch dimension, because dims would # have been shifted by 1 if the batch dimension was added. if unsqueeze_first_dim: From af7a47c8c454c0da1658db54725dcfcd1642e8b0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 3 Sep 2024 14:52:59 -0700 Subject: [PATCH 18/18] ahh --- onnxscript/function_libs/torch_lib/ops/fft.py | 97 ++++++++++--------- 1 file changed, 49 insertions(+), 48 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index 51b0af489..02ce6166f 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -10,7 +10,7 @@ from __future__ import annotations -from typing import Optional, Sequence +from typing import Literal, Optional, Sequence from onnxscript import INT64 from onnxscript.function_libs.torch_lib.registration import torch_op @@ -19,65 +19,51 @@ from onnxscript.onnx_types import TensorType -def _fftn_onnx_normalization( +def _compute_signal_size(signal: TFloat, dims: Sequence[int], last_dim_size: Optional[INT64] = None) -> INT64: + if last_dim_size is not None: + all_other_dims = dims[:-1] + if all_other_dims: + signal_size = op.ReduceProd(signal, axes=all_other_dims, keepdims=False) + signal_size = op.Mul(signal_size, last_dim_size) + else: + signal_size = last_dim_size + else: + signal_size = op.ReduceProd(signal, axes=dims, keepdims=False) + return signal_size + + +def _fftn_ortho_normalization( self: TFloat, - transformed: TFloat, - normalization: int, - forward: bool, dims: Sequence[int], + forward: bool, + onesided: bool, last_dim_size: Optional[INT64] = None, ) -> TFloat: - # Obtain the total_sample_count (n) for normalization - self_shape = op.Shape(self) - if last_dim_size is None: - total_sample_count = op.ReduceProd(op.Gather(self_shape, dims), keepdims=0) - else: - all_other_dims = dims[:-1] - if all_other_dims: - total_sample_count = op.ReduceProd(op.Gather(self_shape, dims[:-1]), keepdims=0) - total_sample_count = op.Mul(total_sample_count, last_dim_size) - else: - total_sample_count = last_dim_size - total_sample_count = op.CastLike(total_sample_count, transformed) + transformed = self - # Normalize the result - # Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn - # Reference https://github.com/pytorch/pytorch/blob/d090c18fcaaba6e1b5cb474a89058cf6081c8275/torch/_refs/fft.py#L42 - # Norm values defined in https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOps.cpp#L117-L131 - # Norm modes: https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOpsUtils.h#L15-L19 - if normalization == 0: - # "none" - no normalization - if forward: - result = transformed - else: - # Revert the 1/n normalization done by ONNX - result = op.Mul(transformed, total_sample_count) - elif normalization == 1: - # "ortho" - divide by 1/sqrt(signal_size) - if forward: - result = op.Div(transformed, op.Sqrt(total_sample_count)) - else: - # ifft of DFT in ONNX is already normalized with 1/n, so we should - # multiply by n before dividing by sqrt(n) to get the correct result, - # Which is equivalent to `*sqrt(n)` in the end. - result = op.Mul(transformed, op.Sqrt(total_sample_count)) + signal_size = _compute_signal_size(self, dims, last_dim_size) + + for dim in dims[:-1]: + transformed = op.DFT(transformed, axis=dim, onesided=False) + + # Torch computes one-sided FFT on the last dimension only. + if onesided: + transformed = op.DFT(transformed, axis=dims[-1], onesided=True) + # TODO: Update signal_size for one-sided FFT + elif last_dim_size is not None: + transformed = op.DFT( + transformed, last_dim_size, axis=dims[-1], onesided=True + ) else: - # normalization == 2, divide by signal_size - # "forward" - normalize by 1/n - if forward: - result = op.Div(transformed, total_sample_count) - else: - # Keep the 1/n normalization done by ONNX - result = transformed + transformed = op.DFT(transformed, axis=dims[-1], onesided=False) - return result def _fftn_onnx( self: TFloat, dims: Sequence[int], normalization: int, - inverse: bool, + forward: bool, onesided: bool, last_dim_size: Optional[INT64] = None, ) -> TFloat: @@ -89,7 +75,7 @@ def _fftn_onnx( self: The input tensor. dims: The dimensions to apply FFT. normalization: The normalization mode. - inverse: Whether to compute the inverse FFT. + forward: Whether to compute forward FFT or backward FFT. onesided: Whether to compute the one-sided FFT, which retains only the positive frequencies. last_dim_size: The size of the last specified dimension. @@ -110,6 +96,21 @@ def _fftn_onnx( else: transformed = self + # Norm values defined in https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOps.cpp#L117-L131 + # Norm modes: https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOpsUtils.h#L15-L19 + # Modes: + # 0: no normalization + # 1: "ortho" - divide by 1/sqrt(signal_size) + # 2: divide by signal_size + + # Select inverse mode for ONNX based on the norm mode and forward/backward mode. + # In ONNX the only difference between inverse=True/False is the 1/n normalization applied. + # + # If normalization is 1/n and we are in backward mode, we use the inverse + # mode in ONNX to get the 1/n normalization. + inverse = normalization == 2 and not forward + ortho = normalization == 1 + for dim in dims[:-1]: transformed = op.DFT(transformed, axis=dim, inverse=inverse, onesided=False)