diff --git a/tripy/tests/integration/test_groupnorm.py b/tripy/tests/integration/test_groupnorm.py
new file mode 100644
index 000000000..11ac2967e
--- /dev/null
+++ b/tripy/tests/integration/test_groupnorm.py
@@ -0,0 +1,61 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy as np
+import torch
+import pytest
+
+import tripy as tp
+from tripy.common.exception import TripyException
+
+DTYPES = [
+    (torch.float16, tp.float16),
+    (torch.float32, tp.float32)
+]
+
+class TestGroupNorm:
+    @pytest.mark.parametrize("torch_dtype, tp_dtype", DTYPES)
+    @pytest.mark.parametrize("input_shape", [(1, 10, 2)])
+    @pytest.mark.parametrize("num_groups", [2, 5])
+    @pytest.mark.parametrize("num_channels", [10])
+    @pytest.mark.parametrize("eps", [1e-5, 1e-3])
+    def test_groupnorm_accuracy(self, torch_dtype, tp_dtype, input_shape, num_groups, num_channels, eps):
+        groupnorm = torch.nn.GroupNorm(
+            num_groups=num_groups,
+            num_channels=num_channels,
+            eps=eps,
+            dtype=torch_dtype,
+        )
+        tp_groupnorm = tp.GroupNorm(
+            num_groups=num_groups,
+            num_channels=num_channels,
+            eps=eps,
+            dtype=tp_dtype,
+        )
+
+        tp_groupnorm.weight = tp.Parameter(groupnorm.weight)
+        tp_groupnorm.bias = tp.Parameter(groupnorm.bias)
+
+        input = torch.arange(torch.prod(torch.Tensor(input_shape))).reshape(input_shape).to(torch_dtype)
+        tp_input = tp.Tensor(input, dtype=tp_dtype)
+
+        output = tp_groupnorm(tp_input)
+        expected = tp.Tensor(groupnorm(input), device=tp.device("cpu"))
+
+        rtol_ = 2e-6 if tp_dtype == tp.float32 else 1e-3
+        assert output.shape == expected.shape
+        assert tp.allclose(output, expected, rtol=rtol_)
\ No newline at end of file
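Reviewer note: a minimal sketch of the group-normalization semantics the test above relies on, written against plain torch. The values and shapes are illustrative and are not part of the PR.

    import torch

    # GroupNorm splits the C channels into num_groups groups and normalizes each
    # group over its channels and spatial elements with the biased variance
    # (correction=0), then applies a per-channel affine transform.
    x = torch.arange(20, dtype=torch.float32).reshape(1, 10, 2)  # (N, C, *)
    num_groups, num_channels, eps = 2, 10, 1e-5

    g = x.reshape(1, num_groups, -1)  # (N, groups, elements per group)
    mean = g.mean(dim=-1, keepdim=True)
    var = g.var(dim=-1, keepdim=True, correction=0)
    manual = ((g - mean) / torch.sqrt(var + eps)).reshape(x.shape)

    # With the default affine parameters (weight of ones, bias of zeros),
    # torch.nn.GroupNorm should agree with the manual computation.
    ref = torch.nn.GroupNorm(num_groups, num_channels, eps=eps)(x)
    assert torch.allclose(manual, ref, atol=1e-6)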
diff --git a/tripy/tests/integration/test_layernorm.py b/tripy/tests/integration/test_layernorm.py
new file mode 100644
index 000000000..7f297fb49
--- /dev/null
+++ b/tripy/tests/integration/test_layernorm.py
@@ -0,0 +1,68 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy as np
+import re
+import torch
+import pytest
+
+import tripy as tp
+from tripy.common.exception import TripyException
+
+DTYPES = [
+    (torch.float16, tp.float16),
+    (torch.float32, tp.float32)
+]
+
+class TestLayerNorm:
+    @pytest.mark.parametrize("torch_dtype, tp_dtype", DTYPES)
+    @pytest.mark.parametrize("input_shape", [(2, 2, 2)])
+    @pytest.mark.parametrize("normalized_shape", [(2, 2), (2,)])
+    @pytest.mark.parametrize("eps", [1e-5, 1e-3])
+    def test_layernorm_accuracy(self, torch_dtype, tp_dtype, input_shape, normalized_shape, eps):
+        layernorm = torch.nn.LayerNorm(
+            normalized_shape=normalized_shape,
+            eps=eps,
+            dtype=torch_dtype,
+        )
+        tp_layernorm = tp.LayerNorm(
+            normalized_shape=normalized_shape,
+            eps=eps,
+            dtype=tp_dtype,
+        )
+
+        # Use the same weights and biases as the Torch module.
+        tp_layernorm.weight = tp.Parameter(layernorm.weight)
+        tp_layernorm.bias = tp.Parameter(layernorm.bias)
+
+        input = torch.arange(torch.prod(torch.Tensor(input_shape))).reshape(input_shape).to(torch_dtype)
+        tp_input = tp.Tensor(input, dtype=tp_dtype)
+
+        output = tp_layernorm(tp_input)
+        expected = tp.Tensor(layernorm(input), device=tp.device("cpu"))
+
+        rtol_ = 2e-7 if tp_dtype == tp.float32 else 1e-3
+        assert output.shape == expected.shape
+        assert tp.allclose(output, expected, rtol=rtol_)
+
+    def test_layernorm_improper_dimensions(self):
+        tp_layernorm = tp.LayerNorm(
+            normalized_shape=[2, 2],
+        )
+        x = tp.ones((5, 5, 5))
+        with pytest.raises(TripyException, match=re.escape("The input's last 2 dimensions must have a shape of [2, 2] but received [5, 5]")):
+            tp_layernorm(x)
\ No newline at end of file
diff --git a/tripy/tripy/frontend/module/groupnorm.py b/tripy/tripy/frontend/module/groupnorm.py
index 80150963a..e766be8fa 100644
--- a/tripy/tripy/frontend/module/groupnorm.py
+++ b/tripy/tripy/frontend/module/groupnorm.py
@@ -55,7 +55,7 @@ class GroupNorm(Module):
     eps: float
     """A value added to the denominator to prevent division by zero. Defaults to 1e-5."""
 
-    def __init__(self, num_groups: int, num_channels: int, dtype: datatype.dtype = datatype.float32) -> None:
+    def __init__(self, num_groups: int, num_channels: int, dtype: datatype.dtype = datatype.float32, eps: float = 1e-5) -> None:
         """
         Args:
             num_groups: The number of groups to split the channels into.
@@ -97,7 +97,7 @@ def __init__(self, num_groups: int, num_channels: int, dtype: datatype.dtype = d
         # Replace with random weights when #74 is completed.
         self.weight = DefaultParameter((num_channels,), dtype=dtype)
         self.bias = DefaultParameter((num_channels,), dtype=dtype)
-        self.eps = 1e-5
+        self.eps = eps
 
     def __call__(self, x: "tripy.Tensor") -> "tripy.Tensor":
         r"""
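Reviewer note: the GroupNorm change above only threads the previously hard-coded eps through the constructor. A short usage sketch, assuming the tripy API exercised by the tests in this PR:

    import tripy as tp

    # eps is now configurable at construction time rather than fixed at 1e-5.
    norm = tp.GroupNorm(num_groups=2, num_channels=10, eps=1e-3)
    assert norm.eps == 1e-3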
diff --git a/tripy/tripy/frontend/module/layernorm.py b/tripy/tripy/frontend/module/layernorm.py
index 1ba892c86..d62c4bb33 100644
--- a/tripy/tripy/frontend/module/layernorm.py
+++ b/tripy/tripy/frontend/module/layernorm.py
@@ -16,6 +16,7 @@
 #
 
 from dataclasses import dataclass
+from typing import Union, Tuple
 
 from tripy import export, utils
 from tripy.common import datatype
@@ -25,7 +26,7 @@
 
 @export.public_api(document_under="operations/modules")
 @dataclass
-@utils.constant_fields(["dtype"])
+@utils.constant_fields(["dtype", "normalized_shape"])
 class LayerNorm(Module):
     r"""
     Applies layer normalization over the input tensor:
@@ -33,11 +34,17 @@ class LayerNorm(Module):
 
     :math:`\text{LayerNorm}(x) = \Large \frac{x - \bar{x}}{ \sqrt{\sigma^2 + \epsilon}} \normalsize * \gamma + \beta`
 
     where :math:`\bar{x}` is the mean and :math:`\sigma^2` is the variance.
+
+    The mean and standard deviation are calculated over the last :math:`D`
+    dimensions, where :math:`D` is the dimension of ``normalized_shape``.
     """
 
     dtype: datatype.dtype
     r"""The data type used to perform the operation."""
 
+    normalized_shape: Tuple[int, ...]
+    r"""The shape of the trailing dimensions of the input over which normalization is performed."""
+
     weight: Parameter
     r"""The :math:`\gamma` parameter of shape :math:`[\text{normalized_shape}]`."""
@@ -47,7 +54,7 @@ class LayerNorm(Module):
     eps: float
     """A value added to the denominator to prevent division by zero."""
 
-    def __init__(self, normalized_shape: int, dtype: datatype.dtype = datatype.float32, eps: float = 1e-5) -> None:
+    def __init__(self, normalized_shape: Union[int, Tuple[int, ...]], dtype: datatype.dtype = datatype.float32, eps: float = 1e-5) -> None:
         """
         Args:
-            normalized_shape: The size of the feature dimension of the input over which normalization is performed.
+            normalized_shape: The shape of the trailing dimension(s) of the input over which normalization is performed.
@@ -77,9 +84,14 @@ def __init__(self, normalized_shape: int, dtype: datatype.dtype = datatype.float
         self.dtype = dtype
 
         # Replace with random weights when #74 is completed.
-        self.weight = DefaultParameter((normalized_shape,), dtype=dtype)
+        if isinstance(normalized_shape, int):
+            normalized_shape = (normalized_shape,)
+
+        self.normalized_shape = normalized_shape
 
-        self.bias = DefaultParameter((normalized_shape,), dtype=dtype)
+        self.weight = DefaultParameter(normalized_shape, dtype=dtype)
+
+        self.bias = DefaultParameter(normalized_shape, dtype=dtype)
 
         self.eps = eps
 
@@ -92,9 +104,20 @@ def __call__(self, x: "tripy.Tensor") -> "tripy.Tensor":
             A tensor of the same shape as the input.
         """
         from tripy.frontend.trace.ops.reduce import mean, var
+        from tripy.frontend.shape import Shape
         from tripy.frontend.trace.ops.unary_elementwise import rsqrt
+        from tripy.common.exception import raise_error
+
+        # The mean and the variance are computed over the last D dimensions.
+        D = len(self.normalized_shape)
+
+        if x.shape[-D:] != self.normalized_shape:
+            raise_error(
+                "Unexpected input shape",
+                [f"The input's last {D} dimensions must have a shape of {self.normalized_shape} but received {x.shape[-D:].data()}"],
+            )
 
-        mean_val = mean(x, dim=-1, keepdim=True)
-        var_val = var(x, dim=-1, keepdim=True, correction=0) + self.eps
+        reduce_dims = tuple(-i for i in range(D, 0, -1))
+        mean_val = mean(x, dim=reduce_dims, keepdim=True)
+        var_val = var(x, dim=reduce_dims, keepdim=True, correction=0) + self.eps
         x = (x - mean_val) * rsqrt(var_val)
         return self.weight * x + self.bias
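Reviewer note: a minimal sketch of how the reduce_dims logic added above generalizes the old dim=-1 reduction, using torch.nn.LayerNorm as the reference. The values are illustrative and are not part of the PR.

    import torch

    normalized_shape = (2, 2)
    D = len(normalized_shape)
    # Mirrors the expression added in __call__: for D == 2 this yields (-2, -1),
    # so mean and variance are reduced over the last D dimensions of the input.
    reduce_dims = tuple(-i for i in range(D, 0, -1))

    x = torch.arange(8, dtype=torch.float32).reshape(2, 2, 2)
    mean = x.mean(dim=reduce_dims, keepdim=True)
    var = x.var(dim=reduce_dims, keepdim=True, correction=0)
    manual = (x - mean) * torch.rsqrt(var + 1e-5)

    # With the default affine parameters, torch.nn.LayerNorm over the same
    # trailing dimensions should agree with the manual computation.
    ref = torch.nn.LayerNorm(normalized_shape, eps=1e-5)(x)
    assert torch.allclose(manual, ref, atol=1e-6)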