diff --git a/paddle/fluid/operators/assign_value_op.h b/paddle/fluid/operators/assign_value_op.h
index 2522fa580f758..f5b74c5441174 100644
--- a/paddle/fluid/operators/assign_value_op.h
+++ b/paddle/fluid/operators/assign_value_op.h
@@ -110,6 +110,9 @@ class AssignValueKernel : public framework::OpKernel<T> {
       case framework::proto::VarType::FP32:
         value_name = "fp32_values";
         break;
+      case framework::proto::VarType::FP64:
+        value_name = "fp64_values";
+        break;
       case framework::proto::VarType::INT64:
         value_name = "int64_values";
       case framework::proto::VarType::INT8:
@@ -118,7 +121,7 @@ class AssignValueKernel : public framework::OpKernel<T> {
       default:
         PADDLE_THROW(platform::errors::Unimplemented(
            "Unsupported data type(code %d) for AssignValue operator, only "
-           "supports bool, int32, float32, int8 and int64.",
+           "supports bool, int32, float32, float64, int8 and int64.",
            dtype));
        break;
     }
diff --git a/paddle/phi/api/yaml/static_ops.yaml b/paddle/phi/api/yaml/static_ops.yaml
index fc118c8a401de..9f8def740385b 100755
--- a/paddle/phi/api/yaml/static_ops.yaml
+++ b/paddle/phi/api/yaml/static_ops.yaml
@@ -90,7 +90,7 @@
   backward : assign_grad
 
 - op : assign_value
-  args : (int[] shape, DataType dtype, int[] bool_values = {}, float[] fp32_values = {}, int[] int32_values = {}, int64_t[] int64_values = {})
+  args : (int[] shape, DataType dtype, int[] bool_values = {}, float[] fp32_values = {}, double[] fp64_values = {}, int[] int32_values = {}, int64_t[] int64_values = {})
   output : Tensor(out)
   infer_meta :
     func : AssignValueInferMeta
diff --git a/paddle/phi/kernels/assign_kernel.cc b/paddle/phi/kernels/assign_kernel.cc
index db30ec7389619..b828aefa012a7 100644
--- a/paddle/phi/kernels/assign_kernel.cc
+++ b/paddle/phi/kernels/assign_kernel.cc
@@ -132,6 +132,7 @@ PD_REGISTER_KERNEL(assign_value,
                    bool,
                    int,
                    float,
+                   double,
                    int8_t,
                    int64_t) {}
 
@@ -159,6 +160,7 @@ PD_REGISTER_KERNEL(assign_value,
                    bool,
                    int,
                    float,
+                   double,
                    int8_t,
                    int64_t) {}
 #endif
diff --git a/paddle/phi/kernels/cpu/truncated_gaussian_random_kernel.cc b/paddle/phi/kernels/cpu/truncated_gaussian_random_kernel.cc
index 53346d92e78bf..d7efb76b4bf0e 100644
--- a/paddle/phi/kernels/cpu/truncated_gaussian_random_kernel.cc
+++ b/paddle/phi/kernels/cpu/truncated_gaussian_random_kernel.cc
@@ -59,4 +59,5 @@ PD_REGISTER_KERNEL(truncated_gaussian_random,
                    CPU,
                    ALL_LAYOUT,
                    phi::TruncatedGaussianRandomKernel,
-                   float) {}
+                   float,
+                   double) {}
diff --git a/paddle/phi/kernels/gpu/truncated_gaussian_random_kernel.cu b/paddle/phi/kernels/gpu/truncated_gaussian_random_kernel.cu
index 698dcc20ad3fe..a7278302cf4e0 100644
--- a/paddle/phi/kernels/gpu/truncated_gaussian_random_kernel.cu
+++ b/paddle/phi/kernels/gpu/truncated_gaussian_random_kernel.cu
@@ -124,4 +124,5 @@ PD_REGISTER_KERNEL(truncated_gaussian_random,
                    GPU,
                    ALL_LAYOUT,
                    phi::TruncatedGaussianRandomKernel,
-                   float) {}
+                   float,
+                   double) {}
diff --git a/paddle/phi/ops/compat/assign_value_sig.cc b/paddle/phi/ops/compat/assign_value_sig.cc
index 0fa1889ccde34..977c2260e59b9 100644
--- a/paddle/phi/ops/compat/assign_value_sig.cc
+++ b/paddle/phi/ops/compat/assign_value_sig.cc
@@ -33,6 +33,9 @@ KernelSignature AssignValueOpArgumentMapping(
   } else if (dtype == /*FP32*/ 5) {
     return KernelSignature(
         "assign_value", {}, {"shape", "dtype", "fp32_values"}, {"Out"});
+  } else if (dtype == /*FP64*/ 6) {
+    return KernelSignature(
+        "assign_value", {}, {"shape", "dtype", "fp64_values"}, {"Out"});
   } else if (dtype == /*INT64*/ 3) {
     return KernelSignature(
         "assign_value", {}, {"shape", "dtype", "int64_values"}, {"Out"});
diff --git a/python/paddle/nn/initializer/assign.py b/python/paddle/nn/initializer/assign.py
index a1cd06cab59b4..9f9947ccc6ed8 100644
--- a/python/paddle/nn/initializer/assign.py
+++ b/python/paddle/nn/initializer/assign.py
@@ -78,6 +78,9 @@ def forward(self, var, block=None):
         if out_dtype == core.VarDesc.VarType.FP32:
             value_name = "fp32_values"
             values = [float(v) for v in np_value.flat]
+        elif out_dtype == core.VarDesc.VarType.FP64:
+            value_name = "fp64_values"
+            values = [float(v) for v in np_value.flat]
         elif out_dtype == core.VarDesc.VarType.INT32:
             value_name = "int32_values"
             values = [int(v) for v in np_value.flat]
diff --git a/test/legacy_test/test_initializer.py b/test/legacy_test/test_initializer.py
index 52b2e4d5024dd..903f47671549e 100644
--- a/test/legacy_test/test_initializer.py
+++ b/test/legacy_test/test_initializer.py
@@ -16,6 +16,8 @@
 import unittest
 
 import numpy as np
+from scipy import special
+from utils import dygraph_guard, static_guard
 
 import paddle
 from paddle import base
@@ -796,7 +798,7 @@ def test_order(self):
         paddle.set_device('cpu')
         SEED = 123
         weight_attr = paddle.framework.ParamAttr(
-            name="linear_weight",
+            name="linear_weight2",
             learning_rate=1.0,
             trainable=False,
             regularizer=None,
@@ -805,7 +807,7 @@
             ),
         )
         bias_attr = paddle.framework.ParamAttr(
-            name="linear_bias",
+            name="linear_bias2",
             learning_rate=1.0,
             trainable=False,
             regularizer=None,
@@ -815,16 +817,32 @@
         )
 
         def run_dynamic_graph():
-            paddle.disable_static()
             paddle.seed(SEED)
             linear = paddle.nn.Linear(
-                1, 1, weight_attr=weight_attr, bias_attr=bias_attr
+                1,
+                1,
+                weight_attr=paddle.framework.ParamAttr(
+                    name="linear_weight1",
+                    learning_rate=1.0,
+                    trainable=False,
+                    regularizer=None,
+                    initializer=paddle.nn.initializer.TruncatedNormal(
+                        mean=0.0, std=2.0
+                    ),
+                ),
+                bias_attr=paddle.framework.ParamAttr(
+                    name="linear_bias1",
+                    learning_rate=1.0,
+                    trainable=False,
+                    regularizer=None,
+                    initializer=paddle.nn.initializer.TruncatedNormal(
+                        mean=0.0, std=2.0
+                    ),
+                ),
             )
             return linear.weight.numpy(), linear.bias.numpy()
-            paddle.enable_static()
 
         def run_static_graph():
-            paddle.enable_static()
             exe = paddle.static.Executor(paddle.CPUPlace())
             paddle.seed(SEED)
             linear = paddle.nn.Linear(
@@ -832,16 +850,93 @@ def run_static_graph():
             )
             res = exe.run(
                 paddle.static.default_startup_program(),
-                fetch_list=['linear_weight', 'linear_bias'],
+                fetch_list=['linear_weight2', 'linear_bias2'],
             )
             return res[0], res[1]
 
-        dynamic_res = run_dynamic_graph()
-        static_res = run_static_graph()
+        with dygraph_guard():
+            dynamic_res = run_dynamic_graph()
+        with static_guard():
+            static_res = run_static_graph()
 
         np.testing.assert_array_equal(dynamic_res[0], static_res[0])
         np.testing.assert_array_equal(dynamic_res[1], static_res[1])
 
+    def test_assign_static_fp32(self):
+        random_value = np.random.randn(128, 128).astype("float32")
+
+        def run_dynamic_graph(dtype):
+            with dygraph_guard():
+                w = paddle.create_parameter(
+                    random_value.shape,
+                    dtype,
+                    default_initializer=paddle.nn.initializer.Assign(
+                        random_value
+                    ),
+                )
+            return w
+
+        def run_static_graph(dtype):
+            with static_guard():
+                exe = paddle.static.Executor(paddle.CPUPlace())
+                w = paddle.create_parameter(
+                    random_value.shape,
+                    dtype,
+                    "w",
+                    default_initializer=paddle.nn.initializer.Assign(
+                        random_value
+                    ),
+                )
+                res = exe.run(
+                    paddle.static.default_startup_program(),
+                    fetch_list=['w'],
+                )
+            return res[0]
+
+        dynamic_res = run_dynamic_graph("float32")
+        static_res = run_static_graph("float32")
+
+        np.testing.assert_array_equal(dynamic_res.numpy(), static_res)
+        np.testing.assert_array_equal(dynamic_res.numpy(), static_res)
+
+    def test_assign_static_fp64(self):
+        random_value = np.random.randn(128, 128).astype("float64")
+
+        def run_dynamic_graph(dtype):
+            with dygraph_guard():
+                w = paddle.create_parameter(
+                    random_value.shape,
+                    dtype,
+                    "www",
+                    default_initializer=paddle.nn.initializer.Assign(
+                        random_value
+                    ),
+                )
+            return w
+
+        def run_static_graph(dtype):
+            with static_guard():
+                exe = paddle.static.Executor(paddle.CPUPlace())
+                w = paddle.create_parameter(
+                    random_value.shape,
+                    dtype,
+                    "ww",
+                    default_initializer=paddle.nn.initializer.Assign(
+                        random_value
+                    ),
+                )
+                res = exe.run(
+                    paddle.static.default_startup_program(),
+                    fetch_list=['ww'],
+                )
+            return res[0]
+
+        dynamic_res = run_dynamic_graph("float64")
+        static_res = run_static_graph("float64")
+
+        np.testing.assert_array_equal(dynamic_res.numpy(), static_res)
+        np.testing.assert_array_equal(dynamic_res.numpy(), static_res)
+
 
 # 2-D Parameter with shape: [10, 15]
 class TestOrthogonalInitializer1(unittest.TestCase):
@@ -1197,6 +1292,133 @@ def test_type_error(self):
         )
 
 
+class TestTruncatedNormalInitializerDygraph(unittest.TestCase):
+    def _trunc_normal_numpy(self, tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
+        # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+        def norm_cdf(x):
+            # Computes standard normal cumulative distribution function
+            return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+        # Values are generated by using a truncated uniform distribution and
+        # then using the inverse CDF for the normal distribution.
+        # Get upper and lower cdf values
+        l = norm_cdf((a - mean) / std)
+        u = norm_cdf((b - mean) / std)
+
+        # Uniformly fill tensor with values from [l, u], then translate to
+        # [2l-1, 2u-1].
+        _tensor = np.random.uniform(
+            low=2 * l - 1, high=2 * u - 1, size=tensor.shape
+        ).astype(paddle.get_default_dtype())
+        print(2 * l - 1, 2 * u - 1)
+
+        # Use inverse cdf transform for normal distribution to get truncated
+        # standard normal
+        _tensor = special.erfinv(_tensor)
+
+        # Transform to proper mean, std
+        _tensor = np.multiply(_tensor, std * math.sqrt(2.0))
+        _tensor = np.add(_tensor, mean)
+
+        # Clamp to ensure it's in the proper range
+        _tensor = np.clip(_tensor, a_min=a, a_max=b)
+        return _tensor
+
+    def test_truncated_normal_initializer_fp32(self):
+        """
+        In dygraph mode, we can use initializer directly to initialize a tensor.
+        """
+        with dygraph_guard():
+            paddle.seed(42)
+            pre_dtype = paddle.get_default_dtype()
+            paddle.set_default_dtype("float32")
+
+            tensor = paddle.zeros([1024, 1024, 8])
+            tensor.stop_gradient = False
+
+            truncated_normal_ = paddle.nn.initializer.TruncatedNormal()
+            truncated_normal_(tensor)
+
+            array = self._trunc_normal_numpy(tensor)
+            np.testing.assert_allclose(
+                array.mean(), tensor.mean().item(), rtol=0.01, atol=0.01
+            )
+            np.testing.assert_allclose(
+                array.std(), tensor.std().item(), rtol=0.01, atol=0.01
+            )
+            paddle.set_default_dtype(pre_dtype)
+
+    def test_truncated_normal_initializer_fp64(self):
+        """
+        In dygraph mode, we can use initializer directly to initialize a tensor.
+ """ + with dygraph_guard(): + paddle.seed(42) + pre_dtype = paddle.get_default_dtype() + paddle.set_default_dtype("float64") + + tensor = paddle.zeros([1024, 1024, 8]) + tensor.stop_gradient = False + + truncated_normal_ = paddle.nn.initializer.TruncatedNormal() + truncated_normal_(tensor) + + array = self._trunc_normal_numpy(tensor) + np.testing.assert_allclose( + array.mean(), tensor.mean().item(), rtol=0.01, atol=0.01 + ) + np.testing.assert_allclose( + array.std(), tensor.std().item(), rtol=0.01, atol=0.01 + ) + paddle.set_default_dtype(pre_dtype) + + +class TestAssignInitializerDygraph(unittest.TestCase): + def test_assign_initializer_fp32(self): + """ + In dygraph mode, we can use initializer directly to initialize a tensor. + """ + with dygraph_guard(): + pre_dtype = paddle.get_default_dtype() + paddle.set_default_dtype("float32") + + tensor = paddle.zeros( + [1024, 1024, 8], dtype=paddle.get_default_dtype() + ) + tensor.stop_gradient = False + array = np.random.randn(*tensor.shape).astype( + paddle.get_default_dtype() + ) + + assign_ = paddle.nn.initializer.Assign(array) + assign_(tensor) + + np.testing.assert_allclose(array, tensor, rtol=1e-6, atol=1e-6) + paddle.set_default_dtype(pre_dtype) + + def test_assign_initializer_fp64(self): + """ + In dygraph mode, we can use initializer directly to initialize a tensor. + """ + with dygraph_guard(): + pre_dtype = paddle.get_default_dtype() + paddle.set_default_dtype("float64") + + tensor = paddle.zeros( + [1024, 1024, 8], dtype=paddle.get_default_dtype() + ) + tensor.stop_gradient = False + array = np.random.randn(*tensor.shape).astype( + paddle.get_default_dtype() + ) + + assign_ = paddle.nn.initializer.Assign(array) + assign_(tensor) + + np.testing.assert_allclose(array, tensor, rtol=1e-6, atol=1e-6) + paddle.set_default_dtype(pre_dtype) + + if __name__ == '__main__': paddle.enable_static() unittest.main() diff --git a/test/legacy_test/test_initializer_nn.py b/test/legacy_test/test_initializer_nn.py index b0b0e0bef268d..95c64ac648290 100644 --- a/test/legacy_test/test_initializer_nn.py +++ b/test/legacy_test/test_initializer_nn.py @@ -15,6 +15,7 @@ import unittest import numpy as np +from utils import static_guard import paddle from paddle import base, nn @@ -490,6 +491,12 @@ def test_truncated_normal_initializer_bf16(self): block = self.test_truncated_normal_initializer("uint16") # bfloat16 self.assertTrue(check_cast_op(block.ops[1])) + def test_truncated_normal_initializer_fp64(self): + """Test truncated normal initializer with float64""" + with static_guard(): + # Only test whether float64 data can be generated without error + _ = self.test_truncated_normal_initializer("float64") # float64 + def test_truncated_normal_initializer_dygraph(self): """Test truncated normal initializer in dygraph model.""" paddle.disable_static() diff --git a/test/legacy_test/test_truncated_gaussian_random_op.py b/test/legacy_test/test_truncated_gaussian_random_op.py index 0f56c5f9ef15e..eb8b502b082d4 100644 --- a/test/legacy_test/test_truncated_gaussian_random_op.py +++ b/test/legacy_test/test_truncated_gaussian_random_op.py @@ -35,20 +35,42 @@ def setUp(self): self.outputs = ["Out"] def test_cpu(self): - self.gaussian_random_test(place=base.CPUPlace()) - self.gaussian_random_test_eager(place=base.CPUPlace()) + self._gaussian_random_test( + place=base.CPUPlace(), dtype=core.VarDesc.VarType.FP32 + ) + self._gaussian_random_test( + place=base.CPUPlace(), dtype=core.VarDesc.VarType.FP64 + ) + self._gaussian_random_test_eager( + place=base.CPUPlace(), 
dtype=core.VarDesc.VarType.FP32 + ) + self._gaussian_random_test_eager( + place=base.CPUPlace(), dtype=core.VarDesc.VarType.FP64 + ) def test_gpu(self): if core.is_compiled_with_cuda(): - self.gaussian_random_test(place=base.CUDAPlace(0)) - self.gaussian_random_test_eager(place=base.CUDAPlace(0)) + self._gaussian_random_test( + place=base.CUDAPlace(0), dtype=core.VarDesc.VarType.FP32 + ) + self._gaussian_random_test( + place=base.CUDAPlace(0), dtype=core.VarDesc.VarType.FP64 + ) + self._gaussian_random_test_eager( + place=base.CUDAPlace(0), dtype=core.VarDesc.VarType.FP32 + ) + self._gaussian_random_test_eager( + place=base.CUDAPlace(0), dtype=core.VarDesc.VarType.FP64 + ) - def gaussian_random_test(self, place): + def _gaussian_random_test(self, place, dtype): program = base.Program() block = program.global_block() vout = block.create_var(name="Out") op = block.append_op( - type=self.op_type, outputs={"Out": vout}, attrs=self.attrs + type=self.op_type, + outputs={"Out": vout}, + attrs={**self.attrs, "dtype": dtype}, ) op.desc.infer_var_type(block.desc) @@ -66,14 +88,14 @@ def gaussian_random_test(self, place): # TruncatedNormal.__call__ has no return value, so here call _C_ops api # directly - def gaussian_random_test_eager(self, place): + def _gaussian_random_test_eager(self, place, dtype): with base.dygraph.guard(place): out = paddle._C_ops.truncated_gaussian_random( self.attrs["shape"], self.attrs["mean"], self.attrs["std"], self.attrs["seed"], - core.VarDesc.VarType.FP32, + dtype, place, ) self.assertAlmostEqual(numpy.mean(out.numpy()), 0.0, delta=0.1)
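
Review note (illustration only, not part of the patch): the snippet below is a minimal sketch of what these changes enable in dygraph mode, written against the same public APIs the added tests already use (paddle.create_parameter, paddle.nn.initializer.Assign, paddle.nn.initializer.TruncatedNormal). It assumes a Paddle build that includes this patch.

    import numpy as np
    import paddle

    # assign_value now registers a float64 kernel, so the Assign initializer
    # can fill a float64 parameter without falling back to float32.
    data = np.random.randn(4, 4).astype("float64")
    w = paddle.create_parameter(
        data.shape,
        "float64",
        default_initializer=paddle.nn.initializer.Assign(data),
    )
    np.testing.assert_allclose(w.numpy(), data, rtol=1e-6, atol=1e-6)

    # truncated_gaussian_random now has CPU/GPU float64 kernels, so
    # TruncatedNormal can initialize a float64 tensor directly.
    t = paddle.zeros([64, 64], dtype="float64")
    paddle.nn.initializer.TruncatedNormal(mean=0.0, std=1.0)(t)
    assert t.dtype == paddle.float64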