diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index ef4c7c96c4c38..e6eecbd6d99ac 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -494,8 +494,11 @@
 )
 from .tensor.random import (
     bernoulli,
+    bernoulli_,
     binomial,
     check_shape,
+    log_normal,
+    log_normal_,
     multinomial,
     normal,
     normal_,
@@ -943,8 +946,11 @@
     'hypot',
     'hypot_',
     'index_fill',
-    "index_fill_",
+    'index_fill_',
     'diagonal_scatter',
     'combinations',
     'signbit',
+    'bernoulli_',
+    'log_normal',
+    'log_normal_',
 ]
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 43658f10a9324..4c1ca640d20e4 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -380,8 +380,11 @@
     vander,
 )
 from .random import (  # noqa: F401
+    bernoulli_,
     binomial,
     exponential_,
+    log_normal,
+    log_normal_,
     multinomial,
     normal,
     normal_,
@@ -784,6 +787,9 @@
     'masked_scatter_',
     "combinations",
     'signbit',
+    'bernoulli_',
+    'log_normal',
+    'log_normal_',
 ]

 # this list used in math_op_patch.py for magic_method bind
diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py
index 6829bc84b045c..e55003a19a786 100644
--- a/python/paddle/tensor/random.py
+++ b/python/paddle/tensor/random.py
@@ -1463,7 +1463,7 @@ def exponential_(x, lam=1.0, name=None):
         f(x) = \lambda e^{-\lambda x}

     Args:
-        x(Tensor): Input tensor. The data type should be float32, float64.
+        x (Tensor): Input tensor. The data type should be float32, float64.
         lam(float, optional): :math:`\lambda` parameter of Exponential Distribution. Default, 1.0.
         name(str, optional): The default value is None. Normally there is no
             need for user to set this property. For more information, please
@@ -1502,3 +1502,233 @@ def exponential_(x, lam=1.0, name=None):
         attrs={"lambda": lam},
     )
     return x
+
+
+@dygraph_only
+def bernoulli_(x, p=0.5, name=None):
+    r"""
+    This inplace OP fills the input Tensor ``x`` with random values drawn from a Bernoulli distribution with probability ``p``.
+
+    Args:
+        x (Tensor): Input tensor. The data type should be float32, float64.
+        p (float, optional): Probability :math:`p` of the Bernoulli distribution. Default: 0.5.
+        name (str, optional): The default value is None. Normally there is no
+            need for user to set this property. For more information, please
+            refer to :ref:`api_guide_Name`.
+    Returns:
+        - x (Tensor): The input Tensor ``x``, filled in place.
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> x = paddle.empty((3, 4)).uniform_(0, 1)
+            >>> x.bernoulli_()
+            >>> # doctest: +SKIP('random check')
+            >>> print(x)
+            Tensor(shape=[3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
+                   [[1., 1., 1., 1.],
+                    [1., 0., 0., 1.],
+                    [0., 1., 1., 0.]])
+            >>> # doctest: -SKIP
+
+    """
+    if not 0 <= p <= 1:
+        raise ValueError(f"bernoulli_ expects p to be in [0, 1], but got p={p}")
+
+    check_variable_and_dtype(x, "x", ["float32", "float64"], "bernoulli_")
+
+    # Draw u ~ U(0, 1) in place, then threshold at p: P(u < p) = p.
+    uniform_(x, min=0.0, max=1.0)
+    return x.set_value((x < p).astype(x.dtype))
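+
+# NOTE: a minimal NumPy sketch of the sampling scheme used by ``bernoulli_``
+# above (uniform draw, then threshold); illustrative only, and the helper
+# name ``bernoulli_fill`` is hypothetical, not part of this change:
+#
+#     import numpy as np
+#
+#     def bernoulli_fill(x, p=0.5):
+#         u = np.random.uniform(0.0, 1.0, size=x.shape)
+#         x[...] = (u < p).astype(x.dtype)  # each element is 1 with probability p
+#         return x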
+
+
+def log_normal(mean=1.0, std=1.0, shape=None, dtype=None, name=None):
+    r"""
+    Returns a Tensor filled with random values sampled from a Log Normal
+    Distribution, with ``mean``, ``std``, ``shape`` and ``dtype``.
+    The Log Normal Distribution is defined as follows:
+
+    Equation:
+        .. math::
+
+            f(x) = \frac{1}{x\sigma\sqrt{2\pi}}e^{-\frac{(\ln{x}-\mu)^2}{2\sigma^2}}
+
+    Args:
+        mean (float|Tensor, optional): The mean of the output Tensor's log normal distribution.
+            If ``mean`` is float, all elements of the output Tensor share the same mean.
+            If ``mean`` is a Tensor (data type supports float32, float64), it has per-element means.
+            Default is 1.0.
+        std (float|Tensor, optional): The standard deviation of the output Tensor's log normal distribution.
+            If ``std`` is float, all elements of the output Tensor share the same standard deviation.
+            If ``std`` is a Tensor (data type supports float32, float64), it has per-element standard deviations.
+            Default is 1.0.
+        shape (tuple|list|Tensor, optional): Shape of the Tensor to be created. The data type is ``int32`` or ``int64``.
+            If ``shape`` is a list or tuple, each element of it should be an integer or a 0-D Tensor with shape [].
+            If ``shape`` is a Tensor, it should be a 1-D Tensor which represents a list. If ``mean`` or ``std``
+            is a Tensor, the shape of the output Tensor is the same as ``mean`` or ``std``, and attr ``shape`` is ignored.
+            Default is None.
+        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+    Returns:
+        - out (Tensor): A Tensor filled with random values sampled from a log normal distribution with ``mean`` and ``std``.
+    Examples:
+        .. code-block:: python
+            :name: log_normal-example-1
+
+            >>> import paddle
+            >>> out1 = paddle.log_normal(shape=[2, 3])
+            >>> # doctest: +SKIP("Random output")
+            >>> print(out1)
+            Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
+                   [[0.85107994, 0.85490644, 1.35941815],
+                    [0.55500370, 0.20964541, 2.24193954]])
+            >>> # doctest: -SKIP
+
+        .. code-block:: python
+            :name: log_normal-example-2
+
+            >>> import paddle
+            >>> mean_tensor = paddle.to_tensor([1.0, 2.0, 3.0])
+            >>> out2 = paddle.log_normal(mean=mean_tensor)
+            >>> # doctest: +SKIP("Random output")
+            >>> print(out2)
+            Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
+                   [1.05411839, 3.71514320, 3.42665267])
+            >>> # doctest: -SKIP
+
+        .. code-block:: python
+            :name: log_normal-example-3
+
+            >>> import paddle
+            >>> mean_tensor = paddle.to_tensor([1.0, 2.0, 3.0])
+            >>> std_tensor = paddle.to_tensor([1.0, 2.0, 3.0])
+            >>> out3 = paddle.log_normal(mean=mean_tensor, std=std_tensor)
+            >>> # doctest: +SKIP("Random output")
+            >>> print(out3)
+            Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
+                   [0.48646951, 0.00815189, 3.74022293])
+            >>> # doctest: -SKIP
+
+    """
+    if not in_dynamic_or_pir_mode():
+        check_type(mean, 'mean', (int, float, Variable), 'log_normal')
+        check_type(std, 'std', (int, float, Variable), 'log_normal')
+        if isinstance(mean, Variable):
+            check_dtype(
+                mean.dtype,
+                'mean',
+                ['float32', 'float64'],
+                'log_normal',
+                "If mean is a Tensor, its data type only supports float32 and float64",
+            )
+        if isinstance(std, Variable):
+            check_dtype(
+                std.dtype,
+                'std',
+                ['float32', 'float64'],
+                'log_normal',
+                "If std is a Tensor, its data type only supports float32 and float64",
+            )
+        if shape is not None:
+            check_shape(shape, 'log_normal')
+
+    def normalize_mean_std(mean, std):
+        # Moment matching: convert the requested log-normal mean/std into the
+        # location (n_mean) and scale (n_std) of the underlying Gaussian, so
+        # that exp(N(n_mean, n_std**2)) has the requested moments.
+        n_mean = paddle.log(mean**2 / paddle.sqrt(mean**2 + std**2))
+        n_std = paddle.sqrt(paddle.log(1 + (std**2 / mean**2)))
+        return n_mean, n_std
+
+    if isinstance(mean, Variable):
+        check_dtype(
+            mean.dtype,
+            'mean',
+            ['float32', 'float64'],
+            'log_normal',
+            "If mean is a Tensor, its data type only supports float32 and float64",
+        )
+        if isinstance(std, Variable):
+            check_dtype(
+                std.dtype,
+                'std',
+                ['float32', 'float64'],
+                'log_normal',
+                "If std is a Tensor, its data type only supports float32 and float64",
+            )
+            if std.dtype != mean.dtype:
+                std = paddle.cast(std, mean.dtype)
+            mean_shape = paddle.shape(mean)
+            std = paddle.reshape(std, mean_shape)
+        else:
+            std = paddle.to_tensor(std, dtype=mean.dtype)
+        n_mean, n_std = normalize_mean_std(mean, std)
+        distribution = gaussian(
+            shape=paddle.shape(mean),
+            mean=n_mean,
+            std=n_std,
+            dtype=dtype,
+            name=name,
+        )
+    elif isinstance(std, Variable):
+        mean = paddle.to_tensor(mean, dtype=std.dtype)
+        n_mean, n_std = normalize_mean_std(mean, std)
+        distribution = gaussian(
+            shape=paddle.shape(std),
+            mean=n_mean,
+            std=n_std,
+            dtype=dtype,
+            name=name,
+        )
+    else:
+        mean = paddle.to_tensor(mean)
+        std = paddle.to_tensor(std)
+        n_mean, n_std = normalize_mean_std(mean, std)
+        distribution = gaussian(
+            mean=n_mean, std=n_std, shape=shape, dtype=dtype, name=name
+        )
+
+    return paddle.exp(distribution)
+
+
+@dygraph_only
+def log_normal_(x, mean=1.0, std=1.0, name=None):
+    r"""
+    This inplace OP fills the input Tensor ``x`` with random values drawn from a Log Normal Distribution
+    with ``mean`` and ``std``. The Log Normal Distribution is defined as follows:
+
+    Equation:
+        .. math::
+
+            f(x) = \frac{1}{x\sigma\sqrt{2\pi}}e^{-\frac{(\ln{x}-\mu)^2}{2\sigma^2}}
+
+    Args:
+        x (Tensor): The input tensor to be filled with random values.
+        mean (float|Tensor, optional): The mean of the output Tensor's log normal distribution.
+            If ``mean`` is float, all elements of the output Tensor share the same mean.
+            If ``mean`` is a Tensor (data type supports float32, float64), it has per-element means.
+            Default is 1.0.
+        std (float|Tensor, optional): The standard deviation of the output Tensor's log normal distribution.
+            If ``std`` is float, all elements of the output Tensor share the same standard deviation.
+            If ``std`` is a Tensor (data type supports float32, float64), it has per-element standard deviations.
+            Default is 1.0.
+        name (str, optional): The default value is None. Normally there is no
+            need for user to set this property. For more information, please
+            refer to :ref:`api_guide_Name`.
+    Returns:
+        - x (Tensor): The input Tensor ``x``, filled with random values sampled from a log normal distribution with ``mean`` and ``std``.
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> x = paddle.randn([3, 4])
+            >>> x.log_normal_()
+            >>> # doctest: +SKIP('random check')
+            >>> print(x)
+            Tensor(shape=[3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
+                   [[0.06132207, 1.11349595, 0.41906244, 0.24858207],
+                    [1.85169315, 1.50370061, 1.73954511, 0.13331604],
+                    [1.66359663, 0.55764782, 0.59911072, 0.57773495]])
+            >>> # doctest: -SKIP
+
+    """
+    # Convert python scalars to Tensors so the normalization below can use
+    # paddle ops; Tensor inputs are left untouched.
+    if not isinstance(mean, Variable):
+        mean = paddle.to_tensor(mean, dtype=paddle.float64)
+    if not isinstance(std, Variable):
+        std = paddle.to_tensor(std, dtype=paddle.float64)
+
+    # Same moment matching as in log_normal above.
+    n_mean = paddle.log(mean**2 / paddle.sqrt(mean**2 + std**2))
+    n_std = paddle.sqrt(paddle.log(1 + (std**2 / mean**2)))
+
+    return normal_(x, mean=n_mean, std=n_std).exp_()
diff --git a/test/legacy_test/test_inplace.py b/test/legacy_test/test_inplace.py
index f06edfd83206c..7b3aeb008a00b 100644
--- a/test/legacy_test/test_inplace.py
+++ b/test/legacy_test/test_inplace.py
@@ -1680,5 +1680,38 @@ def test_backward_error(self):
         loss.backward()


+class TestDygraphInplaceBernoulli(TestDygraphInplace):
+    def init_data(self):
+        self.shape = (20, 40)
+        self.input_var_numpy = np.random.random(self.shape)
+        self.dtype = "float32"
+        self.mean = 0
+        self.std = 1
+        self.seed = 100
+        self.p = 0.5
+
+    def inplace_api_processing(self, var):
+        return paddle.bernoulli_(var, p=self.p)
+
+    def non_inplace_api_processing(self, var):
+        return paddle.bernoulli(paddle.zeros(self.shape) + self.p)
+
+
+class TestDygraphInplaceLogNormal(TestDygraphInplace):
+    def init_data(self):
+        self.shape = (20, 40)
+        self.input_var_numpy = np.random.random(self.shape)
+        self.dtype = "float32"
+        self.mean = 1
+        self.std = 1
+        self.seed = 100
+
+    def inplace_api_processing(self, var):
+        return paddle.log_normal_(var, self.mean, self.std)
+
+    def non_inplace_api_processing(self, var):
+        return paddle.log_normal(self.mean, self.std, self.shape)
+
+
 if __name__ == '__main__':
     unittest.main()
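For reference, the ``normalize_mean_std`` helper in ``log_normal`` above converts the
requested log-normal moments into the location and scale of the underlying Gaussian,
using E[X] = exp(mu + sigma^2 / 2) and Var[X] = (exp(sigma^2) - 1) * exp(2 * mu + sigma^2)
for X = exp(N(mu, sigma^2)). A quick NumPy sanity check of that identity (a sketch for
review purposes, not code from this patch):

    import numpy as np

    mean, std = 2.0, 0.5  # requested log-normal mean and std
    n_mean = np.log(mean**2 / np.sqrt(mean**2 + std**2))  # Gaussian location mu
    n_std = np.sqrt(np.log(1.0 + std**2 / mean**2))       # Gaussian scale sigma

    samples = np.exp(np.random.normal(n_mean, n_std, size=1_000_000))
    print(samples.mean(), samples.std())  # close to 2.0 and 0.5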
diff --git a/test/legacy_test/test_log_normal_op.py b/test/legacy_test/test_log_normal_op.py
new file mode 100644
index 0000000000000..dc5eace984f37
--- /dev/null
+++ b/test/legacy_test/test_log_normal_op.py
@@ -0,0 +1,234 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from op_test import paddle_static_guard
+
+import paddle
+from paddle import base
+from paddle.base import core
+from paddle.tensor import random
+
+SEED = 100
+np.random.seed(SEED)
+paddle.seed(SEED)
+
+
+def ref_log_normal(shape, mean, std):
+    # Reference sampler mirroring the moment matching in paddle.log_normal:
+    # mean/std are the log-normal moments, not the parameters of the
+    # underlying Gaussian.
+    n_mean = np.log(mean**2 / np.sqrt(mean**2 + std**2))
+    n_std = np.sqrt(np.log(1 + std**2 / mean**2))
+    return np.exp(np.random.normal(n_mean, n_std, shape))
+
+
+class TestLogNormalAPI(unittest.TestCase):
+    def test_static_api(self):
+        with paddle_static_guard():
+            positive_2_int32 = paddle.tensor.fill_constant([1], "int32", 2000)
+
+            positive_2_int64 = paddle.tensor.fill_constant([1], "int64", 500)
+            shape_tensor_int32 = paddle.static.data(
+                name="shape_tensor_int32", shape=[2], dtype="int32"
+            )
+
+            shape_tensor_int64 = paddle.static.data(
+                name="shape_tensor_int64", shape=[2], dtype="int64"
+            )
+
+            out_1 = random.log_normal(
+                shape=[2000, 500], dtype="float32", mean=1.0, std=1.0
+            )
+
+            out_2 = random.log_normal(
+                shape=[2000, positive_2_int32],
+                dtype="float32",
+                mean=1.0,
+                std=1.0,
+            )
+
+            out_3 = random.log_normal(
+                shape=[2000, positive_2_int64],
+                dtype="float32",
+                mean=1.0,
+                std=1.0,
+            )
+
+            out_4 = random.log_normal(
+                shape=shape_tensor_int32, dtype="float32", mean=1.0, std=1.0
+            )
+
+            out_5 = random.log_normal(
+                shape=shape_tensor_int64, dtype="float32", mean=1.0, std=1.0
+            )
+
+            out_6 = random.log_normal(
+                shape=shape_tensor_int64, dtype=np.float32, mean=1.0, std=1.0
+            )
+
+            exe = base.Executor(place=base.CPUPlace())
+            res_1, res_2, res_3, res_4, res_5, res_6 = exe.run(
+                base.default_main_program(),
+                feed={
+                    "shape_tensor_int32": np.array([2000, 500]).astype("int32"),
+                    "shape_tensor_int64": np.array([2000, 500]).astype("int64"),
+                },
+                fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6],
+            )
+
+            # The samples are log-normal, so their mean and std should match
+            # the requested log-normal moments (mean=1.0, std=1.0).
+            self.assertAlmostEqual(np.mean(res_1), 1.0, delta=0.1)
+            self.assertAlmostEqual(np.std(res_1), 1.0, delta=0.1)
+            self.assertAlmostEqual(np.mean(res_2), 1.0, delta=0.1)
+            self.assertAlmostEqual(np.std(res_2), 1.0, delta=0.1)
+            self.assertAlmostEqual(np.mean(res_3), 1.0, delta=0.1)
+            self.assertAlmostEqual(np.std(res_3), 1.0, delta=0.1)
+            self.assertAlmostEqual(np.mean(res_4), 1.0, delta=0.1)
+            self.assertAlmostEqual(np.std(res_4), 1.0, delta=0.1)
+            self.assertAlmostEqual(np.mean(res_5), 1.0, delta=0.1)
+            self.assertAlmostEqual(np.std(res_5), 1.0, delta=0.1)
+            self.assertAlmostEqual(np.mean(res_6), 1.0, delta=0.1)
+            self.assertAlmostEqual(np.std(res_6), 1.0, delta=0.1)
+
+    def test_default_dtype(self):
+        def test_default_fp16():
+            paddle.framework.set_default_dtype('float16')
+            out = paddle.tensor.random.log_normal(shape=[2, 3])
+            self.assertEqual(out.dtype, base.core.VarDesc.VarType.FP16)
+
+        def test_default_fp32():
+            paddle.framework.set_default_dtype('float32')
+            out = paddle.tensor.random.log_normal(shape=[2, 3])
+            self.assertEqual(out.dtype, base.core.VarDesc.VarType.FP32)
+
+        def test_default_fp64():
+            paddle.framework.set_default_dtype('float64')
+            out = paddle.tensor.random.log_normal(shape=[2, 3])
+            self.assertEqual(out.dtype, base.core.VarDesc.VarType.FP64)
+
+        if paddle.is_compiled_with_cuda():
+            paddle.set_device('gpu')
+            test_default_fp16()
+            test_default_fp64()
+            test_default_fp32()
+
+
+class TestStandardNormalDtype(unittest.TestCase):
+    def test_default_dtype(self):
+        def test_default_fp16():
+            paddle.framework.set_default_dtype('float16')
+            out = paddle.tensor.random.standard_normal([2, 3])
+            self.assertEqual(out.dtype, base.core.VarDesc.VarType.FP16)
+
+        def test_default_fp32():
+            paddle.framework.set_default_dtype('float32')
+            out = paddle.tensor.random.standard_normal([2, 3])
+            self.assertEqual(out.dtype, base.core.VarDesc.VarType.FP32)
+
+        def test_default_fp64():
+            paddle.framework.set_default_dtype('float64')
+            out = paddle.tensor.random.standard_normal([2, 3])
+            self.assertEqual(out.dtype, base.core.VarDesc.VarType.FP64)
+
+        if paddle.is_compiled_with_cuda():
+            paddle.set_device('gpu')
+            test_default_fp16()
+            test_default_fp64()
+            test_default_fp32()
+
+
+class TestRandomValue(unittest.TestCase):
+    def test_fixed_random_number(self):
+        # Test GPU fixed random numbers, which are generated by
+        # 'curandStatePhilox4_32_10_t'.
+        if not paddle.is_compiled_with_cuda():
+            return
+
+        # Different GPUs generate different random values. Only test V100 here.
+        if "V100" not in paddle.device.cuda.get_device_name():
+            return
+
+        def _check_random_value(dtype, expect, expect_mean, expect_std):
+            x = paddle.randn([32, 3, 1024, 1024], dtype=dtype)
+            actual = x.numpy()
+            np.testing.assert_allclose(
+                actual[2, 1, 512, 1000:1010], expect, rtol=1e-05
+            )
+            np.testing.assert_allclose(np.mean(actual), expect_mean, rtol=1e-05)
+            np.testing.assert_allclose(np.std(actual), expect_std, rtol=1e-05)
+
+        print("Test Fixed Random number on V100 GPU------>")
+        paddle.disable_static()
+        paddle.set_device('gpu')
+        paddle.seed(100)
+        expect = [
+            -0.79037829,
+            -0.54411126,
+            -0.32266671,
+            0.35791815,
+            1.44169267,
+            -0.87785644,
+            -1.23909874,
+            -2.18194139,
+            0.49489656,
+            0.40703062,
+        ]
+        expect_mean = (
+            -0.0000053026194133403266873214888799115129813799285329878330230713
+        )
+        expect_std = 0.99999191058126390974081232343451119959354400634765625
+        _check_random_value(
+            core.VarDesc.VarType.FP64, expect, expect_mean, expect_std
+        )
+
+        expect = [
+            -0.7988942,
+            1.8644791,
+            0.02782744,
+            1.3692524,
+            0.6419724,
+            0.12436751,
+            0.12058455,
+            -1.9984808,
+            1.5635862,
+            0.18506318,
+        ]
+        expect_mean = -0.00004762359094456769526004791259765625
+        expect_std = 0.999975681304931640625
+        _check_random_value(
+            core.VarDesc.VarType.FP32, expect, expect_mean, expect_std
+        )
+
+
+class TestLogNormalErrors(unittest.TestCase):
+    def test_errors(self):
+        with paddle.static.program_guard(paddle.static.Program()):
+            mean = [1, 2, 3]
+            self.assertRaises(TypeError, paddle.log_normal, mean)
+
+            std = [1, 2, 3]
+            self.assertRaises(TypeError, paddle.log_normal, std=std)
+
+            mean = paddle.static.data('Mean', [100], 'int32')
+            self.assertRaises(TypeError, paddle.log_normal, mean)
+
+            std = paddle.static.data('Std', [100], 'int32')
+            self.assertRaises(TypeError, paddle.log_normal, mean=1.0, std=std)
+
+            self.assertRaises(TypeError, paddle.log_normal, shape=1)
+
+            self.assertRaises(TypeError, paddle.log_normal, shape=[1.0])
+
+            shape = paddle.static.data('Shape', [100], 'float32')
+            self.assertRaises(TypeError, paddle.log_normal, shape=shape)
+
+
+if __name__ == "__main__":
+    unittest.main()
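Taken together, the patch exposes three new sampling APIs: ``paddle.bernoulli_``,
``paddle.log_normal`` and ``paddle.log_normal_``. A minimal usage sketch, assuming a
Paddle build that includes these changes:

    import paddle

    x = paddle.empty((2, 3)).uniform_(0, 1)
    x.bernoulli_(p=0.3)  # in-place Bernoulli(0.3) fill

    out = paddle.log_normal(mean=2.0, std=0.5, shape=[2, 3])  # fresh log-normal tensor

    y = paddle.randn([2, 3])
    y.log_normal_(mean=2.0, std=0.5)  # in-place log-normal fill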