diff --git a/paddle/phi/kernels/cpu/fold_grad_kernel.cc b/paddle/phi/kernels/cpu/fold_grad_kernel.cc
index 0c3f1dda03e5e..a56b0aa054571 100644
--- a/paddle/phi/kernels/cpu/fold_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/fold_grad_kernel.cc
@@ -18,5 +18,11 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/fold_grad_kernel_impl.h"
 
-PD_REGISTER_KERNEL(
-    fold_grad, CPU, ALL_LAYOUT, phi::FoldGradKernel, float, double) {}
+PD_REGISTER_KERNEL(fold_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::FoldGradKernel,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/cpu/fold_kernel.cc b/paddle/phi/kernels/cpu/fold_kernel.cc
index e22ac4c771ed9..df6cf5652c992 100644
--- a/paddle/phi/kernels/cpu/fold_kernel.cc
+++ b/paddle/phi/kernels/cpu/fold_kernel.cc
@@ -18,4 +18,11 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/fold_kernel_impl.h"
 
-PD_REGISTER_KERNEL(fold, CPU, ALL_LAYOUT, phi::FoldKernel, float, double) {}
+PD_REGISTER_KERNEL(fold,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::FoldKernel,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/funcs/im2col.cc b/paddle/phi/kernels/funcs/im2col.cc
index 0b5901367488a..e4c470e1a7064 100644
--- a/paddle/phi/kernels/funcs/im2col.cc
+++ b/paddle/phi/kernels/funcs/im2col.cc
@@ -160,12 +160,24 @@ template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
 template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
                              phi::CPUContext,
                              double>;
+template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
+                             phi::CPUContext,
+                             phi::dtype::complex<float>>;
+template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
+                             phi::CPUContext,
+                             phi::dtype::complex<double>>;
 template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
                              phi::CPUContext,
                              float>;
 template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
                              phi::CPUContext,
                              double>;
+template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
+                             phi::CPUContext,
+                             phi::dtype::complex<float>>;
+template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
+                             phi::CPUContext,
+                             phi::dtype::complex<double>>;
 
 /*
  * im = [input_channels, input_height, input_width]
@@ -331,11 +343,23 @@ template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
 template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
                              phi::CPUContext,
                              double>;
+template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
+                             phi::CPUContext,
+                             phi::dtype::complex<float>>;
+template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
+                             phi::CPUContext,
+                             phi::dtype::complex<double>>;
 template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
                              phi::CPUContext,
                              float>;
 template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
                              phi::CPUContext,
                              double>;
+template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
+                             phi::CPUContext,
+                             phi::dtype::complex<float>>;
+template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
+                             phi::CPUContext,
+                             phi::dtype::complex<double>>;
 }  // namespace funcs
 }  // namespace phi
diff --git a/paddle/phi/kernels/funcs/im2col.cu b/paddle/phi/kernels/funcs/im2col.cu
index 87c82adbb7fbe..b633241810f9b 100644
--- a/paddle/phi/kernels/funcs/im2col.cu
+++ b/paddle/phi/kernels/funcs/im2col.cu
@@ -310,6 +310,12 @@ template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
 template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
                              phi::GPUContext,
                              double>;
+template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
+                             phi::GPUContext,
+                             phi::dtype::complex<float>>;
+template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
+                             phi::GPUContext,
+                             phi::dtype::complex<double>>;
 template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
                              phi::GPUContext,
                              float16>;
@@ -322,6 +328,12 @@ template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
 template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
                              phi::GPUContext,
                              double>;
+template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
+                             phi::GPUContext,
+                             phi::dtype::complex<float>>;
+template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
+                             phi::GPUContext,
+                             phi::dtype::complex<double>>;
 template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
                              phi::GPUContext,
                              float16>;
@@ -573,6 +585,12 @@ template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
 template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
                              phi::GPUContext,
                              double>;
+template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
+                             phi::GPUContext,
+                             phi::dtype::complex<float>>;
+template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
+                             phi::GPUContext,
+                             phi::dtype::complex<double>>;
 template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
                              phi::GPUContext,
                              float16>;
@@ -585,6 +603,12 @@ template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
 template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
                              phi::GPUContext,
                              double>;
+template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
+                             phi::GPUContext,
+                             phi::dtype::complex<float>>;
+template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
+                             phi::GPUContext,
+                             phi::dtype::complex<double>>;
 template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
                              phi::GPUContext,
                              float16>;
diff --git a/paddle/phi/kernels/gpu/fold_grad_kernel.cu b/paddle/phi/kernels/gpu/fold_grad_kernel.cu
index ad469dd7981de..1e3cceb04dd0d 100644
--- a/paddle/phi/kernels/gpu/fold_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/fold_grad_kernel.cu
@@ -18,5 +18,11 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/fold_grad_kernel_impl.h"
 
-PD_REGISTER_KERNEL(
-    fold_grad, GPU, ALL_LAYOUT, phi::FoldGradKernel, float, double) {}
+PD_REGISTER_KERNEL(fold_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::FoldGradKernel,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/gpu/fold_kernel.cu b/paddle/phi/kernels/gpu/fold_kernel.cu
index b53ef402150c2..2e21a121a0cc6 100644
--- a/paddle/phi/kernels/gpu/fold_kernel.cu
+++ b/paddle/phi/kernels/gpu/fold_kernel.cu
@@ -18,4 +18,11 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/fold_kernel_impl.h"
 
-PD_REGISTER_KERNEL(fold, GPU, ALL_LAYOUT, phi::FoldKernel, float, double) {}
+PD_REGISTER_KERNEL(fold,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::FoldKernel,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py
index bc43e73c4163d..1a23e072bd82b 100644
--- a/python/paddle/nn/functional/common.py
+++ b/python/paddle/nn/functional/common.py
@@ -2280,7 +2280,7 @@ def fold(
 
     Parameters:
         x(Tensor): 3-D Tensor, input tensor of format [N, C, L],
-            data type can be float32 or float64
+            data type can be float32, float64, complex64 or complex128
         output_sizes(int|list|tuple): The size of output size, should be [output_size_h, output_size_w]
             or an interger o treated as [o, o].
         kernel_sizes(int|list|tuple): The size of convolution kernel, should be [k_h, k_w]
@@ -2325,7 +2325,9 @@ def fold(
 
     helper = LayerHelper("fold", **locals())
 
-    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'fold')
+    check_variable_and_dtype(
+        x, 'x', ['float32', 'float64', 'complex64', 'complex128'], 'fold'
+    )
 
     assert len(x.shape) == 3, "input should be the format of [N, C, L]"
 
diff --git a/test/legacy_test/test_fold_op.py b/test/legacy_test/test_fold_op.py
index 8fdb37deadf21..2a8649b8a3a63 100644
--- a/test/legacy_test/test_fold_op.py
+++ b/test/legacy_test/test_fold_op.py
@@ -39,7 +39,15 @@ def init_data(self):
         self.dilations = [1, 1]
         self.output_sizes = [4, 5]
         input_shape = [self.batch_size, self.input_channels, self.length]
-        self.x = np.random.rand(*input_shape).astype(np.float64)
+        self.x = np.random.rand(*input_shape).astype(self.dtype)
+        if self.dtype == np.complex64 or self.dtype == np.complex128:
+            self.x = (
+                np.random.uniform(-1, 1, input_shape)
+                + 1j * np.random.uniform(-1, 1, input_shape)
+            ).astype(self.dtype)
+
+    def init_dtype(self):
+        self.dtype = np.float64
 
     def calc_fold(self):
         output_shape = [0] * 4
@@ -75,7 +83,7 @@ def calc_fold(self):
                     )
                     + 1
                 )
-        output = np.zeros(output_shape).astype(np.float64)
+        output = np.zeros(output_shape).astype(self.dtype)
         # ------------- calculate output ------------- #
         for b in range(output_shape[0]):
             for c in range(self.input_channels):
@@ -106,6 +114,7 @@ def calc_fold(self):
         self.outputs = output
 
     def set_data(self):
+        self.init_dtype()
         self.init_data()
         self.calc_fold()
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(self.x)}
@@ -130,6 +139,16 @@ def test_check_grad(self):
         self.check_grad(['X'], 'Y')
 
 
+class TestFold_Complex64(TestFoldOp):
+    def init_dtype(self):
+        self.dtype = np.complex64
+
+
+class TestFold_Complex128(TestFoldOp):
+    def init_dtype(self):
+        self.dtype = np.complex128
+
+
 class TestFoldshape(TestFoldOp):
     def init_data(self):
         self.batch_size = 8
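For reference, a minimal usage sketch of the path this patch enables (not part of the diff; shapes and values are chosen purely for illustration, and it assumes a Paddle build that includes these kernel registrations):

```python
import numpy as np
import paddle
import paddle.nn.functional as F

# Input to fold is [N, C * k_h * k_w, L]. With a 4x5 output and a 2x2
# kernel (stride 1, no padding, no dilation), there are 3 * 4 = 12
# sliding-block positions (L = 12), and C = 3 image channels give
# 3 * 2 * 2 = 12 input channels.
x_np = (
    np.random.uniform(-1, 1, (2, 12, 12))
    + 1j * np.random.uniform(-1, 1, (2, 12, 12))
).astype(np.complex64)
x = paddle.to_tensor(x_np)

# Before this patch the dtype check rejected anything but
# float32/float64; with it, complex64/complex128 pass through.
y = F.fold(x, output_sizes=[4, 5], kernel_sizes=[2, 2])
print(y.shape)  # [2, 3, 4, 5]
print(y.dtype)  # paddle.complex64
```

The newly registered `fold_grad` kernels make the same dtypes differentiable, which is what `TestFold_Complex64` and `TestFold_Complex128` exercise via `check_grad` in the test diff above.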