Support float64 for TruncatedNormal and Assign #57507

Merged · 19 commits · Sep 22, 2023
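This PR adds float64 support to the TruncatedNormal and Assign initializers: the truncated_gaussian_random CPU/GPU kernels and the assign_value kernels are registered for double, the assign_value op gains an fp64_values attribute, and the Python-side Assign initializer emits it for FP64 variables. A minimal dygraph sketch of what this enables, assuming the public paddle.nn.initializer API; shapes and values are illustrative and mirror the new tests below:

import paddle

# A sketch, not part of this PR: fill a float64 tensor in place with truncated-normal
# samples in dygraph mode (mirrors the new TestTruncatedNormalInitializerDygraph test).
paddle.set_default_dtype("float64")
tensor = paddle.zeros([16, 16])  # created with the default dtype, i.e. float64
init = paddle.nn.initializer.TruncatedNormal(mean=0.0, std=2.0)
init(tensor)  # now draws float64 samples, since the kernel is registered for double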
5 changes: 4 additions & 1 deletion paddle/fluid/operators/assign_value_op.h
@@ -110,6 +110,9 @@ class AssignValueKernel : public framework::OpKernel<T> {
case framework::proto::VarType::FP32:
value_name = "fp32_values";
break;
case framework::proto::VarType::FP64:
value_name = "fp64_values";
break;
case framework::proto::VarType::INT64:
value_name = "int64_values";
break;
case framework::proto::VarType::INT8:
@@ -118,7 +121,7 @@ class AssignValueKernel : public framework::OpKernel<T> {
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported data type(code %d) for AssignValue operator, only "
"supports bool, int32, float32, int8 and int64.",
"supports bool, int32, float32, float64, int8 and int64.",
dtype));
break;
}
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/static_ops.yaml
@@ -90,7 +90,7 @@
backward : assign_grad

- op : assign_value
args : (int[] shape, DataType dtype, int[] bool_values = {}, float[] fp32_values = {}, int[] int32_values = {}, int64_t[] int64_values = {})
args : (int[] shape, DataType dtype, int[] bool_values = {}, float[] fp32_values = {}, double[] fp64_values = {}, int[] int32_values = {}, int64_t[] int64_values = {})
output : Tensor(out)
infer_meta :
func : AssignValueInferMeta
2 changes: 2 additions & 0 deletions paddle/phi/kernels/assign_kernel.cc
@@ -132,6 +132,7 @@ PD_REGISTER_KERNEL(assign_value,
bool,
int,
float,
double,
int8_t,
int64_t) {}

@@ -159,6 +160,7 @@ PD_REGISTER_KERNEL(assign_value,
bool,
int,
float,
double,
int8_t,
int64_t) {}
#endif
3 changes: 2 additions & 1 deletion paddle/phi/kernels/cpu/truncated_gaussian_random_kernel.cc
@@ -59,4 +59,5 @@ PD_REGISTER_KERNEL(truncated_gaussian_random,
CPU,
ALL_LAYOUT,
phi::TruncatedGaussianRandomKernel,
float) {}
float,
double) {}
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/truncated_gaussian_random_kernel.cu
@@ -124,4 +124,5 @@ PD_REGISTER_KERNEL(truncated_gaussian_random,
GPU,
ALL_LAYOUT,
phi::TruncatedGaussianRandomKernel,
float) {}
float,
double) {}
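With double now registered for both the CPU and GPU truncated_gaussian_random kernels, a float64 parameter can also be initialized in static mode. A hedged sketch along the lines of the updated tests; the parameter name and shape are illustrative:

import paddle

paddle.enable_static()
w = paddle.create_parameter(
    [8, 8],
    "float64",
    "w_fp64",  # illustrative parameter name
    default_initializer=paddle.nn.initializer.TruncatedNormal(mean=0.0, std=1.0),
)
exe = paddle.static.Executor(paddle.CPUPlace())
(out,) = exe.run(paddle.static.default_startup_program(), fetch_list=["w_fp64"])
# out is a float64 numpy array drawn from the truncated normal distribution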
3 changes: 3 additions & 0 deletions paddle/phi/ops/compat/assign_value_sig.cc
@@ -33,6 +33,9 @@ KernelSignature AssignValueOpArgumentMapping(
} else if (dtype == /*FP32*/ 5) {
return KernelSignature(
"assign_value", {}, {"shape", "dtype", "fp32_values"}, {"Out"});
} else if (dtype == /*FP64*/ 6) {
return KernelSignature(
"assign_value", {}, {"shape", "dtype", "fp64_values"}, {"Out"});
} else if (dtype == /*INT64*/ 3) {
return KernelSignature(
"assign_value", {}, {"shape", "dtype", "int64_values"}, {"Out"});
3 changes: 3 additions & 0 deletions python/paddle/nn/initializer/assign.py
@@ -78,6 +78,9 @@ def forward(self, var, block=None):
if out_dtype == core.VarDesc.VarType.FP32:
value_name = "fp32_values"
values = [float(v) for v in np_value.flat]
elif out_dtype == core.VarDesc.VarType.FP64:
value_name = "fp64_values"
values = [float(v) for v in np_value.flat]
elif out_dtype == core.VarDesc.VarType.INT32:
value_name = "int32_values"
values = [int(v) for v in np_value.flat]
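On the Python side, the Assign initializer now selects fp64_values when the target variable is FP64, so a float64 numpy array can be copied into a tensor or parameter without a dtype mismatch. A minimal dygraph sketch mirroring the new TestAssignInitializerDygraph test below:

import numpy as np
import paddle

array = np.random.randn(4, 4).astype("float64")
tensor = paddle.zeros([4, 4], dtype="float64")
paddle.nn.initializer.Assign(array)(tensor)  # copies the float64 values into the tensor
np.testing.assert_allclose(tensor.numpy(), array)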
240 changes: 231 additions & 9 deletions test/legacy_test/test_initializer.py
@@ -16,6 +16,8 @@
import unittest

import numpy as np
from scipy import special
from utils import dygraph_guard, static_guard

import paddle
from paddle import base
@@ -796,7 +798,7 @@ def test_order(self):
paddle.set_device('cpu')
SEED = 123
weight_attr = paddle.framework.ParamAttr(
name="linear_weight",
name="linear_weight2",
learning_rate=1.0,
trainable=False,
regularizer=None,
@@ -805,7 +807,7 @@ def test_order(self):
),
)
bias_attr = paddle.framework.ParamAttr(
name="linear_bias",
name="linear_bias2",
learning_rate=1.0,
trainable=False,
regularizer=None,
@@ -815,33 +817,126 @@
)

def run_dynamic_graph():
paddle.disable_static()
paddle.seed(SEED)
linear = paddle.nn.Linear(
1, 1, weight_attr=weight_attr, bias_attr=bias_attr
1,
1,
weight_attr=paddle.framework.ParamAttr(
name="linear_weight1",
learning_rate=1.0,
trainable=False,
regularizer=None,
initializer=paddle.nn.initializer.TruncatedNormal(
mean=0.0, std=2.0
),
),
bias_attr=paddle.framework.ParamAttr(
name="linear_bias1",
learning_rate=1.0,
trainable=False,
regularizer=None,
initializer=paddle.nn.initializer.TruncatedNormal(
mean=0.0, std=2.0
),
),
)
return linear.weight.numpy(), linear.bias.numpy()
paddle.enable_static()

def run_static_graph():
paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
paddle.seed(SEED)
linear = paddle.nn.Linear(
1, 1, weight_attr=weight_attr, bias_attr=bias_attr
)
res = exe.run(
paddle.static.default_startup_program(),
fetch_list=['linear_weight', 'linear_bias'],
fetch_list=['linear_weight2', 'linear_bias2'],
)
return res[0], res[1]

dynamic_res = run_dynamic_graph()
static_res = run_static_graph()
with dygraph_guard():
dynamic_res = run_dynamic_graph()
with static_guard():
static_res = run_static_graph()

np.testing.assert_array_equal(dynamic_res[0], static_res[0])
np.testing.assert_array_equal(dynamic_res[1], static_res[1])

def test_assign_static_fp32(self):
random_value = np.random.randn(128, 128).astype("float32")

def run_dynamic_graph(dtype):
with dygraph_guard():
w = paddle.create_parameter(
random_value.shape,
dtype,
default_initializer=paddle.nn.initializer.Assign(
random_value
),
)
return w

def run_static_graph(dtype):
with static_guard():
exe = paddle.static.Executor(paddle.CPUPlace())
w = paddle.create_parameter(
random_value.shape,
dtype,
"w",
default_initializer=paddle.nn.initializer.Assign(
random_value
),
)
res = exe.run(
paddle.static.default_startup_program(),
fetch_list=['w'],
)
return res[0]

dynamic_res = run_dynamic_graph("float32")
static_res = run_static_graph("float32")

np.testing.assert_array_equal(dynamic_res.numpy(), static_res)

def test_assign_static_fp64(self):
random_value = np.random.randn(128, 128).astype("float64")

def run_dynamic_graph(dtype):
with dygraph_guard():
w = paddle.create_parameter(
random_value.shape,
dtype,
"www",
default_initializer=paddle.nn.initializer.Assign(
random_value
),
)
return w

def run_static_graph(dtype):
with static_guard():
exe = paddle.static.Executor(paddle.CPUPlace())
w = paddle.create_parameter(
random_value.shape,
dtype,
"ww",
default_initializer=paddle.nn.initializer.Assign(
random_value
),
)
res = exe.run(
paddle.static.default_startup_program(),
fetch_list=['ww'],
)
return res[0]

dynamic_res = run_dynamic_graph("float64")
static_res = run_static_graph("float64")

np.testing.assert_array_equal(dynamic_res.numpy(), static_res)


# 2-D Parameter with shape: [10, 15]
class TestOrthogonalInitializer1(unittest.TestCase):
@@ -1197,6 +1292,133 @@ def test_type_error(self):
)


class TestTruncatedNormalInitializerDygraph(unittest.TestCase):
def _trunc_normal_numpy(self, tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)

# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
_tensor = np.random.uniform(
low=2 * l - 1, high=2 * u - 1, size=tensor.shape
).astype(paddle.get_default_dtype())

# Use inverse cdf transform for normal distribution to get truncated
# standard normal
_tensor = special.erfinv(_tensor)

# Transform to proper mean, std
_tensor = np.multiply(_tensor, std * math.sqrt(2.0))
_tensor = np.add(_tensor, mean)

# Clamp to ensure it's in the proper range
_tensor = np.clip(_tensor, a_min=a, a_max=b)
return _tensor

def test_truncated_normal_initializer_fp32(self):
"""
In dygraph mode, we can use initializer directly to initialize a tensor.
"""
with dygraph_guard():
paddle.seed(42)
pre_dtype = paddle.get_default_dtype()
paddle.set_default_dtype("float32")

tensor = paddle.zeros([1024, 1024, 8])
tensor.stop_gradient = False

truncated_normal_ = paddle.nn.initializer.TruncatedNormal()
truncated_normal_(tensor)

array = self._trunc_normal_numpy(tensor)
np.testing.assert_allclose(
array.mean(), tensor.mean().item(), rtol=0.01, atol=0.01
)
np.testing.assert_allclose(
array.std(), tensor.std().item(), rtol=0.01, atol=0.01
)
paddle.set_default_dtype(pre_dtype)

def test_truncated_normal_initializer_fp64(self):
"""
In dygraph mode, we can use initializer directly to initialize a tensor.
"""
with dygraph_guard():
paddle.seed(42)
pre_dtype = paddle.get_default_dtype()
paddle.set_default_dtype("float64")

tensor = paddle.zeros([1024, 1024, 8])
tensor.stop_gradient = False

truncated_normal_ = paddle.nn.initializer.TruncatedNormal()
truncated_normal_(tensor)

array = self._trunc_normal_numpy(tensor)
np.testing.assert_allclose(
array.mean(), tensor.mean().item(), rtol=0.01, atol=0.01
)
np.testing.assert_allclose(
array.std(), tensor.std().item(), rtol=0.01, atol=0.01
)
paddle.set_default_dtype(pre_dtype)


class TestAssignInitializerDygraph(unittest.TestCase):
def test_assign_initializer_fp32(self):
"""
In dygraph mode, we can use initializer directly to initialize a tensor.
"""
with dygraph_guard():
pre_dtype = paddle.get_default_dtype()
paddle.set_default_dtype("float32")

tensor = paddle.zeros(
[1024, 1024, 8], dtype=paddle.get_default_dtype()
)
tensor.stop_gradient = False
array = np.random.randn(*tensor.shape).astype(
paddle.get_default_dtype()
)

assign_ = paddle.nn.initializer.Assign(array)
assign_(tensor)

np.testing.assert_allclose(array, tensor, rtol=1e-6, atol=1e-6)
paddle.set_default_dtype(pre_dtype)

def test_assign_initializer_fp64(self):
"""
In dygraph mode, we can use initializer directly to initialize a tensor.
"""
with dygraph_guard():
pre_dtype = paddle.get_default_dtype()
paddle.set_default_dtype("float64")

tensor = paddle.zeros(
[1024, 1024, 8], dtype=paddle.get_default_dtype()
)
tensor.stop_gradient = False
array = np.random.randn(*tensor.shape).astype(
paddle.get_default_dtype()
)

assign_ = paddle.nn.initializer.Assign(array)
assign_(tensor)

np.testing.assert_allclose(array, tensor, rtol=1e-6, atol=1e-6)
paddle.set_default_dtype(pre_dtype)


if __name__ == '__main__':
paddle.enable_static()
unittest.main()
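For reference, the _trunc_normal_numpy helper added in the test above generates its reference samples with the standard inverse-CDF construction of a truncated normal. With mean \(\mu\), standard deviation \(\sigma\), and bounds \(a\), \(b\):

\[
\ell = \Phi\!\left(\frac{a - \mu}{\sigma}\right), \qquad
u = \Phi\!\left(\frac{b - \mu}{\sigma}\right), \qquad
v \sim \mathcal{U}(2\ell - 1,\ 2u - 1), \qquad
x = \mu + \sigma \sqrt{2}\, \operatorname{erf}^{-1}(v),
\]

where \(\Phi(z) = \tfrac{1}{2}\bigl(1 + \operatorname{erf}(z/\sqrt{2})\bigr)\). Since \(v = 2p - 1\) with \(p \sim \mathcal{U}(\ell, u)\), this is exactly \(x = \mu + \sigma\, \Phi^{-1}(p)\); the final np.clip only guards against round-off just outside \([a, b]\).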
7 changes: 7 additions & 0 deletions test/legacy_test/test_initializer_nn.py
@@ -15,6 +15,7 @@
import unittest

import numpy as np
from utils import static_guard

import paddle
from paddle import base, nn
@@ -490,6 +491,12 @@ def test_truncated_normal_initializer_bf16(self):
block = self.test_truncated_normal_initializer("uint16") # bfloat16
self.assertTrue(check_cast_op(block.ops[1]))

def test_truncated_normal_initializer_fp64(self):
"""Test truncated normal initializer with float64"""
with static_guard():
# Only test whether float64 data can be generated without error
_ = self.test_truncated_normal_initializer("float64") # float64

def test_truncated_normal_initializer_dygraph(self):
"""Test truncated normal initializer in dygraph model."""
paddle.disable_static()