【Complex OP】No.17 complex pow op #57645

Closed
wants to merge 11 commits
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/eager_math_op_patch.cc
@@ -1514,7 +1514,7 @@ static PyObject* tensor__pow__method(TensorObject* self,
eager_gil_scoped_release guard;
self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32);
}
} else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) {
} else if (IsNumpyType(other_obj)) {
other = CastPyArg2Double(other_obj, "__pow__", 0);
}
{
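For context on what the new complex path has to compute: a base raised to a complex exponent is exp(y * log(x)). A quick NumPy check of that identity (illustration only, not part of the PR):

```python
import numpy as np

# x ** y for a complex exponent y is exp(y * log(x)); the complex kernels added
# below are expected to follow this definition elementwise.
x = np.array([1.5, 2.0, 3.0], dtype=np.complex128)
y = 2.0 + 0.5j

via_power = np.power(x, y)
via_exp_log = np.exp(y * np.log(x))

assert np.allclose(via_power, via_exp_log)
print(via_power)
```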
17 changes: 17 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -4351,6 +4351,23 @@ void TileInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}

void PowInferMeta(const MetaTensor& x, const Scalar& y, MetaTensor* out) {
if (y.dtype() == DataType::COMPLEX128 &&
!(x.dtype() == DataType::COMPLEX64 ||
x.dtype() == DataType::COMPLEX128)) {
if (x.dtype() == DataType::FLOAT64) {
out->set_dtype(phi::DataType::COMPLEX128);
} else {
out->set_dtype(phi::DataType::COMPLEX64);
}
} else if (y.dtype() == DataType::FLOAT64 &&
(x.dtype() == DataType::INT32 || x.dtype() == DataType::INT64)) {
out->set_dtype(phi::DataType::FLOAT32);
} else {
out->set_dtype(x.dtype());
}
}

void TopKInferMeta(const MetaTensor& x,
const Scalar& k_scalar,
int axis,
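A rough Python restatement of the promotion rules PowInferMeta encodes above (plain strings instead of phi::DataType, illustration only):

```python
# Mirrors the branches in PowInferMeta: a complex exponent promotes a non-complex
# base to a complex output, and a float exponent promotes an integer base to float32.
def infer_pow_dtype(x_dtype: str, y_dtype: str) -> str:
    complex_types = {"complex64", "complex128"}
    if y_dtype == "complex128" and x_dtype not in complex_types:
        return "complex128" if x_dtype == "float64" else "complex64"
    if y_dtype == "float64" and x_dtype in {"int32", "int64"}:
        return "float32"
    return x_dtype


assert infer_pow_dtype("float32", "complex128") == "complex64"
assert infer_pow_dtype("float64", "complex128") == "complex128"
assert infer_pow_dtype("int64", "float64") == "float32"
assert infer_pow_dtype("complex64", "complex128") == "complex64"
```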
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -567,6 +567,8 @@ int GetSplitAxisValue(const MetaTensor& x,
const Scalar& axis,
MetaConfig config);

void PowInferMeta(const MetaTensor& x, const Scalar& y, MetaTensor* out);

void FillSplitOutDims(const MetaTensor& x,
const int axis_value,
const std::vector<int64_t>& sections_vec,
2 changes: 2 additions & 0 deletions paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -452,6 +452,8 @@ PD_REGISTER_KERNEL(pow_grad,
CPU,
ALL_LAYOUT,
phi::PowGradKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
float,
double,
int,
12 changes: 10 additions & 2 deletions paddle/phi/kernels/cpu/activation_kernel.cc
@@ -292,5 +292,13 @@ PD_REGISTER_KERNEL(negative,
int,
int64_t) {}
PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel)
PD_REGISTER_KERNEL(
pow, CPU, ALL_LAYOUT, phi::PowKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(pow,
CPU,
ALL_LAYOUT,
phi::PowKernel,
float,
double,
int,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
int64_t) {}
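With complex64/complex128 added to the CPU registration, a call like the following should work on a build that includes this PR (hedged sketch; paddle.pow and paddle.to_tensor are existing API, the complex support is what this diff adds):

```python
import numpy as np
import paddle  # assumes a build that includes this PR

x = paddle.to_tensor(np.array([1 + 2j, 3 - 1j], dtype=np.complex64))
out = paddle.pow(x, 3)  # complex base, integer exponent

ref = np.power(x.numpy(), 3)
print(np.allclose(out.numpy(), ref, rtol=1e-5))  # expected: True
```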
37 changes: 29 additions & 8 deletions paddle/phi/kernels/funcs/activation_functor.h
@@ -2740,33 +2740,54 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
template <typename T>
struct PowFunctor : public BaseActivationFunctor<T> {
float factor;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
T factor;
std::vector<std::pair<const char*, T*>> GetAttrs() {
return {{"factor", &factor}};
}

template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.pow(static_cast<T>(factor)); // NOLINT
out.device(d) = x.template cast<T>().pow(factor); // NOLINT
}
};

template <typename T>
struct PowGradFunctor : public BaseActivationFunctor<T> {
float factor;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
T factor;
std::vector<std::pair<const char*, T*>> GetAttrs() {
return {{"factor", &factor}};
}

template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(factor) *
x.pow(static_cast<T>(factor) - static_cast<T>(1));
dx.device(d) = dout * factor * x.pow(factor - static_cast<T>(1));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct PowGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
ComplexType<T> factor;
std::vector<std::pair<const char*, ComplexType<T>*>> GetAttrs() {
return {{"factor", &factor}};
}
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
dx.device(d) =
dout * (factor * x.pow(factor - static_cast<ComplexType<T>>(1)))
.unaryExpr(Conj<T>());
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
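The complex specialization above applies the conjugated (Wirtinger-style) backward rule for a holomorphic function: for out = x**n, dx = dout * conj(n * x**(n-1)), which is what the `unaryExpr(Conj<T>())` line computes. A NumPy sketch of that rule with a finite-difference sanity check (illustration only):

```python
import numpy as np

n = 3.0 + 0.0j
x = np.array([1.0 + 2.0j, 0.5 - 1.5j], dtype=np.complex128)
dout = np.array([1.0 + 0.0j, 2.0 - 1.0j], dtype=np.complex128)

# Backward rule implemented by the complex PowGradFunctor: conjugate the
# holomorphic derivative n * x**(n-1) before multiplying by the upstream grad.
dx = dout * np.conj(n * x ** (n - 1))

# For a holomorphic function the derivative is direction-independent, so a real
# finite-difference step is enough to sanity-check n * x**(n-1).
h = 1e-7
numeric = ((x + h) ** n - x ** n) / h
assert np.allclose(numeric, n * x ** (n - 1), rtol=1e-4)
print(dx)
```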
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -534,6 +534,8 @@ PD_REGISTER_KERNEL(pow_grad,
double,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(pow_double_grad,
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/activation_kernel.cu
@@ -352,6 +352,8 @@ PD_REGISTER_KERNEL(pow,
double,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(selu,
2 changes: 1 addition & 1 deletion paddle/phi/kernels/impl/activation_grad_impl.h
@@ -337,7 +337,7 @@ void PowGradKernel(const Context& dev_ctx,
auto* place = dev_ctx.eigen_device();
phi::funcs::PowGradFunctor<T> functor;
auto attrs = functor.GetAttrs();
*(attrs[0].second) = factor.to<float>();
*(attrs[0].second) = factor.to<T>();
functor(*place, x_flatten, nullptr, dout_flatten, dx_flatten);
}

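Reading the attribute as `factor.to<T>()` instead of `factor.to<float>()` matters once T can be complex: a complex exponent cannot be narrowed to a float without dropping its imaginary part. A one-line Python analogy (illustration only):

```python
c = 2.0 + 0.5j
try:
    float(c)  # narrowing a complex exponent to float is not even well defined
except TypeError as err:
    print(err)  # e.g. "can't convert complex to float"
print(complex(c))  # keeping the original type preserves the full value
```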
45 changes: 32 additions & 13 deletions paddle/phi/kernels/impl/activation_impl.h
@@ -16,9 +16,9 @@

#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"

namespace phi {

#define ToString(x) #x
@@ -62,23 +62,42 @@ void LogitKernel(const Context& dev_ctx,
functor(place, eigen_in, eigen_out, eigen_p, eps);
}

template <typename T, typename Context>
void PowKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& factor,
DenseTensor* out) {
template <typename InT, typename OutT, typename Context>
void PowImpl(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& factor,
DenseTensor* out) {
PADDLE_ENFORCE_NOT_NULL(out,
errors::NotFound("Output Out should not be nullptr"));
dev_ctx.template Alloc<T>(out);
auto x_flatten = phi::EigenVector<T>::Flatten(
GET_DATA_SAFELY(&x, "Input", "X", "Activation"));
auto out_flatten = phi::EigenVector<T>::Flatten(
GET_DATA_SAFELY(out, "Output", "Out", "Activation"));
dev_ctx.template Alloc<OutT>(out, out->numel() * sizeof(OutT));
auto x_flatten = phi::EigenVector<InT>::Flatten(x);
auto out_flatten = phi::EigenVector<OutT>::Flatten(*out);
auto* place = dev_ctx.eigen_device();
phi::funcs::PowFunctor<T> functor;
phi::funcs::PowFunctor<OutT> functor;
auto attrs = functor.GetAttrs();
*(attrs[0].second) = factor.to<float>();
*(attrs[0].second) = factor.to<OutT>();
functor(*place, x_flatten, out_flatten);
}

template <typename T, typename Context>
void PowKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& factor,
DenseTensor* out) {
if (factor.dtype() == DataType::COMPLEX128 &&
!(x.dtype() == DataType::COMPLEX64 ||
x.dtype() == DataType::COMPLEX128)) {
if (x.dtype() == DataType::FLOAT64) {
PowImpl<T, phi::dtype::complex<double>, Context>(dev_ctx, x, factor, out);
} else {
PowImpl<T, phi::dtype::complex<float>, Context>(dev_ctx, x, factor, out);
}
} else if (factor.dtype() == DataType::FLOAT64 &&
(x.dtype() == DataType::INT32 || x.dtype() == DataType::INT64)) {
PowImpl<T, float, Context>(dev_ctx, x, factor, out);
} else {
PowImpl<T, T, Context>(dev_ctx, x, factor, out);
}
}

} // namespace phi
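End-to-end, the dispatch above means a real tensor raised to a complex scalar should come back complex. A hedged usage sketch (assumes a build with this PR; that a bare Python complex scalar arrives as a COMPLEX128 Scalar attribute is an assumption about the binding):

```python
import numpy as np
import paddle  # assumes a build that includes this PR

x = paddle.to_tensor(np.array([1.5, 2.0, 3.0], dtype=np.float32))
c = 2.0 + 0.5j  # assumed to arrive as a COMPLEX128 Scalar attribute

out = paddle.pow(x, c)  # float32 base + complex factor -> complex64 per the dispatch above
ref = np.power(x.numpy().astype(np.complex64), np.complex64(c))

print(out.dtype)                        # expected: paddle.complex64
print(np.allclose(out.numpy(), ref, rtol=1e-5))
```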
4 changes: 3 additions & 1 deletion test/legacy_test/op_test.py
@@ -1731,8 +1731,10 @@ def _dfs_grad_op(op_desc, fwd_op_desc=None):
visited_ops.append(op_desc.type())
has_infer_inplace = base.core.has_infer_inplace(op_desc.type())
has_grad_op_maker = base.core.has_grad_op_maker(op_desc.type())
# the OP test doesn't support higher order grad
is_grad_op_desc = op_desc.type().endswith('_grad')
has_infer_inplace_in_grad_descendants = False
if not has_grad_op_maker:
if not has_grad_op_maker or is_grad_op_desc:
has_infer_inplace_in_descendants = False
else:
# get grad_op_desc
45 changes: 36 additions & 9 deletions test/legacy_test/test_activation_op.py
@@ -3800,7 +3800,13 @@ def setUp(self):
self.if_enable_cinn()

np.random.seed(1024)
x = np.random.uniform(1, 2, self.shape).astype(self.dtype)
if self.dtype is np.complex64 or self.dtype is np.complex128:
x = (
np.random.uniform(1, 2, self.shape)
+ 1j * np.random.uniform(1, 2, self.shape)
).astype(self.dtype)
else:
x = np.random.uniform(1, 2, self.shape).astype(self.dtype)
out = np.power(x, 3)

self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
@@ -3812,25 +3818,46 @@ def if_enable_cinn(self):
pass

def test_check_output(self):
self.check_output(check_prim=True, check_prim_pir=True, check_pir=True)
if self.dtype not in [np.complex64, np.complex128]:
self.check_output(
check_prim=True, check_prim_pir=True, check_pir=True
)
else:
self.check_output()

def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(
['X'],
'Out',
check_prim=True,
check_prim_pir=True,
check_pir=True,
)
if self.dtype not in [np.complex64, np.complex128]:
self.check_grad(
['X'],
'Out',
check_prim=True,
check_prim_pir=True,
check_pir=True,
)
else:
self.check_grad(
['X'],
'Out',
)


class TestPow_ZeroDim(TestPow):
def init_shape(self):
self.shape = []


class TestPowComplex64(TestPow):
def init_dtype(self):
self.dtype = np.complex64


class TestPowComplex128(TestPow):
def init_dtype(self):
self.dtype = np.complex128


class TestPow_API(TestActivation):
def test_api(self):
with static_guard():
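A standalone NumPy mirror of what the new TestPowComplex64 case exercises in the forward pass (shape and seed here are illustrative, not copied from the test base class):

```python
import numpy as np

np.random.seed(1024)
shape = (4, 5)
x = (
    np.random.uniform(1, 2, shape) + 1j * np.random.uniform(1, 2, shape)
).astype(np.complex64)

out = np.power(x, 3)

# The op's forward result for factor=3 should agree with plain elementwise cubing.
assert np.allclose(out, x * x * x, rtol=1e-5)
print(out.dtype, out.shape)
```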