Make tp.Layernorm 1:1 with torch #119

Merged 11 commits on Aug 23, 2024
61 changes: 61 additions & 0 deletions tripy/tests/integration/test_groupnorm.py
@@ -0,0 +1,61 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import math

import pytest
import torch

import tripy as tp

DTYPES = [
    (torch.float16, tp.float16),
    (torch.float32, tp.float32),
]

class TestGroupNorm:
    @pytest.mark.parametrize("torch_dtype, tp_dtype", DTYPES)
    @pytest.mark.parametrize("input_shape", [(1, 10, 2)])
    @pytest.mark.parametrize("num_groups", [2, 5])
    @pytest.mark.parametrize("num_channels", [10])
    @pytest.mark.parametrize("eps", [1e-5, 1e-3])
    def test_groupnorm_accuracy(self, torch_dtype, tp_dtype, input_shape, num_groups, num_channels, eps):
        groupnorm = torch.nn.GroupNorm(
            num_groups=num_groups,
            num_channels=num_channels,
            eps=eps,
            dtype=torch_dtype,
        )
        tp_groupnorm = tp.GroupNorm(
            num_groups=num_groups,
            num_channels=num_channels,
            eps=eps,
            dtype=tp_dtype,
        )

        # Copy torch's affine parameters so both modules compute the same transform.
        tp_groupnorm.weight = tp.Parameter(groupnorm.weight)
        tp_groupnorm.bias = tp.Parameter(groupnorm.bias)

        input = torch.arange(math.prod(input_shape)).reshape(input_shape).to(torch_dtype)
        tp_input = tp.Tensor(input, dtype=tp_dtype)

        output = tp_groupnorm(tp_input)
        expected = tp.Tensor(groupnorm(input), device=tp.device("cpu"))

        rtol_ = 2e-6 if tp_dtype == tp.float32 else 1e-3
        assert output.shape == expected.shape
        assert tp.allclose(output, expected, rtol=rtol_)
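
For context, the accuracy test above boils down to checking that `tp.GroupNorm` and `torch.nn.GroupNorm` agree once their parameters match. A minimal standalone sketch, assuming the tripy API exactly as exercised by this test:

```python
import torch
import tripy as tp

# Build matching modules; eps is configurable on both after this PR.
torch_gn = torch.nn.GroupNorm(num_groups=2, num_channels=10, eps=1e-3)
tp_gn = tp.GroupNorm(num_groups=2, num_channels=10, eps=1e-3)

# Copy torch's affine parameters so both compute the same transform.
tp_gn.weight = tp.Parameter(torch_gn.weight)
tp_gn.bias = tp.Parameter(torch_gn.bias)

x = torch.randn(1, 10, 2)
with torch.no_grad():
    expected = tp.Tensor(torch_gn(x), device=tp.device("cpu"))
assert tp.allclose(tp_gn(tp.Tensor(x)), expected, rtol=1e-3)
```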
68 changes: 68 additions & 0 deletions tripy/tests/integration/test_layernorm.py
@@ -0,0 +1,68 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import math
import re

import pytest
import torch

import tripy as tp
from tripy.common.exception import TripyException

DTYPES = [
    (torch.float16, tp.float16),
    (torch.float32, tp.float32),
]

class TestLayerNorm:
    @pytest.mark.parametrize("torch_dtype, tp_dtype", DTYPES)
    @pytest.mark.parametrize("input_shape", [(2, 2, 2)])
    @pytest.mark.parametrize("normalized_shape", [(2, 2), (2,)])
    @pytest.mark.parametrize("eps", [1e-5, 1e-3])
    def test_layernorm_accuracy(self, torch_dtype, tp_dtype, input_shape, normalized_shape, eps):
        layernorm = torch.nn.LayerNorm(
            normalized_shape=normalized_shape,
            eps=eps,
            dtype=torch_dtype,
        )
        tp_layernorm = tp.LayerNorm(
            normalized_shape=normalized_shape,
            eps=eps,
            dtype=tp_dtype,
        )

        # Copy torch's parameters into the Tripy module so outputs are comparable.
        tp_layernorm.weight = tp.Parameter(layernorm.weight)
        tp_layernorm.bias = tp.Parameter(layernorm.bias)

        input = torch.arange(math.prod(input_shape)).reshape(input_shape).to(torch_dtype)
        tp_input = tp.Tensor(input, dtype=tp_dtype)

        output = tp_layernorm(tp_input)
        expected = tp.Tensor(layernorm(input), device=tp.device("cpu"))

        rtol_ = 2e-7 if tp_dtype == tp.float32 else 1e-3
        assert output.shape == expected.shape
        assert tp.allclose(output, expected, rtol=rtol_)

    def test_layernorm_improper_dimensions(self):
        tp_layernorm = tp.LayerNorm(
            normalized_shape=[2, 2],
        )
        x = tp.ones((5, 5, 5))
        with pytest.raises(
            TripyException,
            match=re.escape("The input's last 2 dimensions must have a shape of [2, 2] and received [5, 5]"),
        ):
            tp_layernorm(x)
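
As a usage-level sketch of what these tests cover (same API as above): `normalized_shape` may now be an int or a tuple, and inputs whose trailing dimensions don't match it are rejected:

```python
import tripy as tp

# The tuple form normalizes over the last len(normalized_shape) dimensions.
ln = tp.LayerNorm(normalized_shape=(2, 2))
out = ln(tp.ones((5, 2, 2)))  # OK: trailing dims are (2, 2)

# Mismatched trailing dims now raise a TripyException, e.g.:
# ln(tp.ones((5, 5, 5)))
# -> "The input's last 2 dimensions must have a shape of [2, 2] and received [5, 5]"
```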
4 changes: 2 additions & 2 deletions tripy/tripy/frontend/module/groupnorm.py
@@ -55,7 +55,7 @@ class GroupNorm(Module):
    eps: float
    """A value added to the denominator to prevent division by zero. Defaults to 1e-5."""

-    def __init__(self, num_groups: int, num_channels: int, dtype: datatype.dtype = datatype.float32) -> None:
+    def __init__(self, num_groups: int, num_channels: int, dtype: datatype.dtype = datatype.float32, eps: float = 1e-5) -> None:
        """
        Args:
            num_groups: The number of groups to split the channels into.
@@ -97,7 +97,7 @@ def __init__(self, num_groups: int, num_channels: int, dtype: datatype.dtype = d
        # Replace with random weights when #74 is completed.
        self.weight = DefaultParameter((num_channels,), dtype=dtype)
        self.bias = DefaultParameter((num_channels,), dtype=dtype)
-        self.eps = 1e-5
+        self.eps = eps

    def __call__(self, x: "tripy.Tensor") -> "tripy.Tensor":
        r"""
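
The net effect on the GroupNorm API: `eps` is now forwarded from the constructor instead of being hard-coded to 1e-5. A one-line sketch of the new surface:

```python
import tripy as tp

# eps still defaults to 1e-5 but can now be overridden per module.
gn = tp.GroupNorm(num_groups=5, num_channels=10, eps=1e-3)
assert gn.eps == 1e-3
```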
35 changes: 29 additions & 6 deletions tripy/tripy/frontend/module/layernorm.py
@@ -16,6 +16,7 @@
#

from dataclasses import dataclass
+from typing import Union, Tuple

from tripy import export, utils
from tripy.common import datatype
@@ -25,19 +26,25 @@

@export.public_api(document_under="operations/modules")
@dataclass
-@utils.constant_fields(["dtype"])
+@utils.constant_fields(["dtype", "normalized_shape"])
class LayerNorm(Module):
    r"""
    Applies layer normalization over the input tensor:

    :math:`\text{LayerNorm}(x) = \Large \frac{x - \bar{x}}{ \sqrt{\sigma^2 + \epsilon}} \normalsize * \gamma + \beta`

    where :math:`\bar{x}` is the mean and :math:`\sigma^2` is the variance.

+    The mean and standard deviation are calculated over the last :math:`D`
+    dimensions, where :math:`D` is the dimension of `normalized_shape`.
    """

    dtype: datatype.dtype
    r"""The data type used to perform the operation."""

+    normalized_shape: Tuple[int, ...]
+    r"""The shape of the trailing dimensions of the input over which normalization is performed."""

    weight: Parameter
    r"""The :math:`\gamma` parameter of shape :math:`[\text{normalized_shape}]`."""

@@ -47,7 +54,7 @@ class LayerNorm(Module):
    eps: float
    """A value added to the denominator to prevent division by zero."""

-    def __init__(self, normalized_shape: int, dtype: datatype.dtype = datatype.float32, eps: float = 1e-5) -> None:
+    def __init__(self, normalized_shape: Union[int, Tuple[int, ...]], dtype: datatype.dtype = datatype.float32, eps: float = 1e-5) -> None:
        """
        Args:
            normalized_shape: The size of the feature dimension of the input over which normalization is performed.
@@ -77,9 +84,14 @@ def __init__(self, normalized_shape: int, dtype: datatype.dtype = datatype.float
        self.dtype = dtype

        # Replace with random weights when #74 is completed.
-        self.weight = DefaultParameter((normalized_shape,), dtype=dtype)
-        self.bias = DefaultParameter((normalized_shape,), dtype=dtype)
+        if isinstance(normalized_shape, int):
+            normalized_shape = (normalized_shape,)
+
+        self.normalized_shape = normalized_shape
+
+        self.weight = DefaultParameter(normalized_shape, dtype=dtype)
+        self.bias = DefaultParameter(normalized_shape, dtype=dtype)

        self.eps = eps

@@ -92,9 +104,20 @@ def __call__(self, x: "tripy.Tensor") -> "tripy.Tensor":
            A tensor of the same shape as the input.
        """
        from tripy.frontend.trace.ops.reduce import mean, var
+        from tripy.frontend.shape import Shape
        from tripy.frontend.trace.ops.unary_elementwise import rsqrt
+        from tripy.common.exception import raise_error
+
+        # The mean and the variance are computed over the last D dimensions.
+        D = len(self.normalized_shape)
+
+        if x.shape[-D:] != self.normalized_shape:
+            raise_error(
+                "Unexpected input shape",
+                [f"The input's last {D} dimensions must have a shape of {self.normalized_shape} and received {x.shape[-D:].data()}"],
+            )
+
-        mean_val = mean(x, dim=-1, keepdim=True)
-        var_val = var(x, dim=-1, keepdim=True, correction=0) + self.eps
+        # Reduce over the trailing D dimensions, e.g. D=2 -> dims (-2, -1).
+        reduce_dims = tuple(-i for i in range(D, 0, -1))
+        mean_val = mean(x, dim=reduce_dims, keepdim=True)
+        var_val = var(x, dim=reduce_dims, keepdim=True, correction=0) + self.eps
        x = (x - mean_val) * rsqrt(var_val)
        return self.weight * x + self.bias
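
For readers checking the reduction logic: the `reduce_dims` expression generalizes the old last-dim-only behavior to the trailing `D` dimensions. A plain-PyTorch reference sketch (the helper `layernorm_ref` is hypothetical, shown only to illustrate the math) that should agree with `torch.nn.LayerNorm`:

```python
import torch

def layernorm_ref(x, normalized_shape, weight, bias, eps=1e-5):
    # Normalize over the last D dims, mirroring the tripy implementation above.
    D = len(normalized_shape)
    reduce_dims = tuple(-i for i in range(D, 0, -1))  # e.g. D=2 -> (-2, -1)
    mean = x.mean(dim=reduce_dims, keepdim=True)
    var = x.var(dim=reduce_dims, keepdim=True, unbiased=False)  # correction=0
    return (x - mean) * torch.rsqrt(var + eps) * weight + bias

x = torch.randn(2, 2, 2)
ln = torch.nn.LayerNorm((2, 2))
with torch.no_grad():
    assert torch.allclose(layernorm_ref(x, (2, 2), ln.weight, ln.bias, ln.eps), ln(x), atol=1e-6)
```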