
Commit

fix
yangguohao committed Nov 10, 2023
1 parent 4db8899 commit bd2b6de
Showing 3 changed files with 16 additions and 15 deletions.
16 changes: 9 additions & 7 deletions paddle/phi/kernels/funcs/activation_functor.h
@@ -2743,9 +2743,10 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
 template <typename T>
 struct PowFunctor : public BaseActivationFunctor<T> {
   T factor;
-  using AttrPair = std::vector<std::pair<const char*, ELEMENT_TYPE*>>;
+  std::vector<std::pair<const char*, T*>> GetAttrs() {
+    return {{"factor", &factor}};
+  }
 
-  typename AttrPair GetAttrs() { return {{"factor", &factor}}; }
   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) const {
     out.device(d) = x.template cast<T>().pow(factor);  // NOLINT
@@ -2755,9 +2756,10 @@ struct PowFunctor : public BaseActivationFunctor<T> {
 template <typename T>
 struct PowGradFunctor : public BaseActivationFunctor<T> {
   T factor;
-  using AttrPair = std::vector<std::pair<const char*, ELEMENT_TYPE*>>;
+  std::vector<std::pair<const char*, T*>> GetAttrs() {
+    return {{"factor", &factor}};
+  }
 
-  typename AttrPair GetAttrs() { return {{"factor", &factor}}; }
   template <typename Device,
             typename X,
             typename Out,
@@ -2774,9 +2776,9 @@ template <typename T>
 struct PowGradFunctor<ComplexType<T>>
     : public BaseActivationFunctor<ComplexType<T>> {
   ComplexType<T> factor;
-  using AttrPair = std::vector<std::pair<const char*, ComplexType<T>*>>;
-
-  typename AttrPair GetAttrs() { return {{"factor", &factor}}; }
+  std::vector<std::pair<const char*, ComplexType<T>*>> GetAttrs() {
+    return {{"factor", &factor}};
+  }
   template <typename Device,
             typename X,
             typename Out,
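The hunks above all revolve around GetAttrs(). Below is a minimal, self-contained sketch of that attribute-pair pattern (illustrative only, not Paddle's actual headers): GetAttrs() exposes name-to-pointer pairs so the caller can write the "factor" attribute straight into the functor before invoking it, which is why the ComplexType<T> specialization needs pairs of ComplexType<T>* rather than the pointer type used by the primary template. The names PowFunctorSketch and the by-hand attribute loop are hypothetical.

#include <complex>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Stand-in for the real functor; the GetAttrs() signature mirrors the diff above.
template <typename T>
struct PowFunctorSketch {
  T factor;
  std::vector<std::pair<const char*, T*>> GetAttrs() {
    return {{"factor", &factor}};
  }
  // Element-wise body; the real kernel applies this over an Eigen expression.
  T operator()(T x) const { return std::pow(x, factor); }
};

int main() {
  PowFunctorSketch<std::complex<float>> functor;
  // A kernel launcher would look up each attribute by name and assign through
  // the returned pointer; here we set "factor" by hand.
  for (auto& attr : functor.GetAttrs()) {
    if (std::string(attr.first) == "factor") {
      *attr.second = std::complex<float>(2.0f, 0.0f);
    }
  }
  std::cout << functor(std::complex<float>(3.0f, 4.0f)) << std::endl;  // (3+4i)^2 = -7+24i
  return 0;
}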
1 change: 0 additions & 1 deletion paddle/phi/kernels/impl/activation_grad_impl.h
@@ -334,7 +334,6 @@ void PowGradKernel(const Context& dev_ctx,
       GET_DATA_SAFELY(dx, "Output", "X@GRAD", "PowGrad"));
   auto x_flatten =
       EigenVector<T>::Flatten(GET_DATA_SAFELY(&x, "Input", "X", "PowGrad"));
-  std::cout << dout.dtype() << dx->dtype() << std::endl;
   auto* place = dev_ctx.eigen_device();
   phi::funcs::PowGradFunctor<T> functor;
   auto attrs = functor.GetAttrs();
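For context, PowGradKernel wires x, dout, and dx into PowGradFunctor, whose real-valued gradient follows the ordinary power rule dx = dout * factor * x^(factor - 1). The snippet below is a quick stand-alone sanity check of that formula against a central finite difference; it is an illustration only, not code from this commit, and the complex branch is handled by the specialization shown above.

#include <cmath>
#include <cstdio>

int main() {
  const double x = 1.7, factor = 3.0, dout = 1.0;
  // Analytic gradient from the power rule.
  const double analytic = dout * factor * std::pow(x, factor - 1.0);
  // Central finite difference of x^factor.
  const double eps = 1e-6;
  const double numeric =
      dout * (std::pow(x + eps, factor) - std::pow(x - eps, factor)) / (2 * eps);
  std::printf("analytic=%.8f numeric=%.8f\n", analytic, numeric);
  return 0;
}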
14 changes: 7 additions & 7 deletions test/legacy_test/test_activation_op.py
@@ -3829,13 +3829,13 @@ def test_check_grad(self):
         if self.dtype == np.float16:
             return
         if self.dtype not in [np.complex64, np.complex128]:
-        self.check_grad(
-            ['X'],
-            'Out',
-            check_prim=True,
-            check_prim_pir=True,
-            check_pir=True,
-        )
+            self.check_grad(
+                ['X'],
+                'Out',
+                check_prim=True,
+                check_prim_pir=True,
+                check_pir=True,
+            )
         else:
             self.check_grad(
                 ['X'],
