Merge pull request #490 from allenai/RemoveAMDLN
Remove AMD LayerNorm
dirkgr authored Mar 8, 2024
2 parents 752353b + 1810817 commit eb5b2da
Showing 4 changed files with 5 additions and 49 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -28,6 +28,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Changed legacy checkpoint unsharding to use processes and shared memory instead of threads

### Removed

- Removed `AMDLayerNorm`, since the original layer norm bug has been fixed and we don't need this workaround anymore.


## [v0.2.4](https://github.com/allenai/OLMo/releases/tag/v0.2.4) - 2024-02-02

### Fixed
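
For context on the workaround this commit removes: the ROCm issue concerned `F.layer_norm` called without a bias tensor. A minimal sketch of that call (plain `torch`; shapes are arbitrary), which the upstream fix makes safe again:

```python
import torch
import torch.nn.functional as F

# The call pattern AMDLayerNorm existed to avoid: layer norm with a weight but no
# bias. On the affected PyTorch/ROCm builds this could segfault; with the upstream
# fix, the default code path can rely on it directly.
x = torch.randn(16, 1024, 768)
weight = torch.ones(768)
y = F.layer_norm(x, (768,), weight=weight, bias=None, eps=1e-5)
print(y.shape)  # torch.Size([16, 1024, 768])
```
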
5 changes: 0 additions & 5 deletions olmo/config.py
@@ -171,11 +171,6 @@ class LayerNormType(StrEnum):
probably the fastest implementation.
"""

amd_compatible = "amd_compatible"
"""
LayerNorm implemented manually to work around an issue with ROCm.
"""


class ActivationType(StrEnum):
gelu = "gelu"
35 changes: 0 additions & 35 deletions olmo/model.py
@@ -57,7 +57,6 @@
"LayerNormBase",
"LayerNorm",
"RMSLayerNorm",
"AMDLayerNorm",
"RotaryEmbedding",
"Activation",
"GELU",
@@ -152,8 +151,6 @@ def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> Lay
return LayerNorm(config, size=size, low_precision=True, **kwargs)
elif config.layer_norm_type == LayerNormType.rms:
return RMSLayerNorm(config, size=size, **kwargs)
elif config.layer_norm_type == LayerNormType.amd_compatible:
return AMDLayerNorm(config, size=size, **kwargs)
else:
raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")
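
A usage sketch of the dispatch after this change. The config values are hypothetical and assume `ModelConfig`'s remaining fields have defaults (as the tests below suggest):

```python
from olmo import LayerNorm
from olmo.config import LayerNormType, ModelConfig

# Hypothetical example: only the remaining LayerNormType members can be selected.
# "amd_compatible" is no longer a member, so a config requesting it should fail at
# enum validation rather than reach the NotImplementedError branch above.
config = ModelConfig(d_model=768, layer_norm_type=LayerNormType.rms)
ln = LayerNorm.build(config)  # -> RMSLayerNorm
```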

@@ -207,38 +204,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)


class AMDLayerNorm(LayerNormBase):
"""
LayerNorm implemented using PyTorch primitives.
We do this to work around a bug in the PyTorch/ROCm implementation of layer norm that fails with a
segfault when the bias is not present.
"""

def __init__(
self,
config: ModelConfig,
size: Optional[int] = None,
elementwise_affine: Optional[bool] = None,
eps: float = 1e-05,
):
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)

def forward(self, x: torch.Tensor) -> torch.Tensor:
og_dtype = x.dtype
x = self._cast_if_autocast_enabled(x, dtype=torch.float32)
with torch.autocast(enabled=False, device_type=x.device.type):
var, mean = torch.var_mean(x, dim=-1, correction=0, keepdim=True)
var.add_(self.eps)
var.rsqrt_() # rsqrt should be more stable than 1/sqrt
x = var * (x - mean)
if self.weight is not None:
x.mul_(self.weight)
if self.bias is not None:
x.add_(self.bias)
return x.to(og_dtype)


class RMSLayerNorm(LayerNormBase):
"""
RMS layer norm, a simplified :class:`LayerNorm` implementation
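
The deleted `forward` computed layer norm by hand from `torch.var_mean`. A small self-check sketch (plain `torch`; arbitrary shapes) showing that this manual formulation matches `F.layer_norm`, which is why dropping the class should not change numerics:

```python
import torch
import torch.nn.functional as F

def manual_layer_norm(x, weight=None, bias=None, eps=1e-5):
    # Same arithmetic as the removed AMDLayerNorm.forward: normalize with the
    # biased (correction=0) variance, then apply the optional affine parameters.
    var, mean = torch.var_mean(x, dim=-1, correction=0, keepdim=True)
    x = (x - mean) * torch.rsqrt(var + eps)
    if weight is not None:
        x = x * weight
    if bias is not None:
        x = x + bias
    return x

x = torch.randn(4, 8, 32)
weight, bias = torch.randn(32), torch.randn(32)
expected = F.layer_norm(x, (32,), weight=weight, bias=bias, eps=1e-5)
torch.testing.assert_close(manual_layer_norm(x, weight, bias), expected)
```
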
9 changes: 0 additions & 9 deletions tests/model_test.py
@@ -6,7 +6,6 @@
from olmo import BlockType, LayerNorm, Olmo, Tokenizer, TrainConfig
from olmo.config import ModelConfig, PaddingDirection
from olmo.data import DataCollator
from olmo.model import AMDLayerNorm


@pytest.mark.parametrize(
@@ -399,29 +398,24 @@ def test_layer_norm(train_config: TrainConfig, elementwise_affine: bool, include
train_config.model.layer_norm_with_affine = elementwise_affine
train_config.model.include_bias = include_bias
ln = LayerNorm.build(train_config.model)
amd_ln = AMDLayerNorm(train_config.model)

needs_weight = elementwise_affine
needs_bias = elementwise_affine and include_bias
with torch.no_grad():
if needs_weight:
weight = torch.randn(train_config.model.d_model)
ln.weight.copy_(weight)
amd_ln.weight.copy_(weight)
else:
weight = None

if needs_bias:
bias = torch.randn(train_config.model.d_model)
ln.bias.copy_(bias)
amd_ln.bias.copy_(bias)
else:
bias = None

assert ln.bias is None or ln.bias.requires_grad == needs_bias
assert ln.weight is None or ln.weight.requires_grad == needs_weight
assert amd_ln.bias is None or amd_ln.bias.requires_grad == needs_bias
assert amd_ln.weight is None or amd_ln.weight.requires_grad == needs_weight

x = torch.randn(16, 1024, train_config.model.d_model)
x.requires_grad = False
@@ -430,9 +424,6 @@ def test_layer_norm(train_config: TrainConfig, elementwise_affine: bool, include
y_actual = ln(x)
torch.testing.assert_close(y_actual, y_expected)

y_actual = amd_ln(x)
torch.testing.assert_close(y_actual, y_expected)


def test_block_groups():
model_with_block_groups = Olmo(ModelConfig(d_model=128, n_heads=2, n_layers=9, block_group_size=3)).eval()
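
With the AMD-specific assertions gone, the surviving check reduces to comparing the built `LayerNorm` against `F.layer_norm`. A condensed, hypothetical standalone version of that comparison (assumes `ModelConfig`'s other fields default sensibly):

```python
import torch
import torch.nn.functional as F

from olmo import LayerNorm
from olmo.config import ModelConfig

# Condensed sketch of what the remaining test still verifies: the built LayerNorm
# matches F.layer_norm, here with an affine weight but no bias.
config = ModelConfig(d_model=64, layer_norm_with_affine=True, include_bias=False)
ln = LayerNorm.build(config)

x = torch.randn(16, 128, config.d_model)
with torch.no_grad():
    y_expected = F.layer_norm(
        x, (config.d_model,), weight=ln.weight, bias=ln.bias, eps=ln.eps
    )
    torch.testing.assert_close(ln(x), y_expected)
```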
