enable TritonFusedRMSNorm with local_map annotation (#404)
Summary
This PR enables the use of TritonFusedRMSNorm with Tensor Parallel, yielding a 7%-8% performance gain compared to RMSNorm with TP. #364
XilunWu authored Jun 14, 2024
1 parent 093ba15 commit d761994
Showing 5 changed files with 100 additions and 8 deletions.
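For context before the per-file diffs: local_map lets a function written against plain torch.Tensor run on DTensor inputs by declaring how each argument and each output is placed on the device mesh, so the Triton kernel itself stays DTensor-unaware. Below is a minimal sketch of that pattern (ours, not part of the commit; scale_rows and the 4-GPU mesh are made up, and the script assumes a torchrun launch).

# A minimal sketch of the local_map pattern (ours, not from this commit).
# Assumes launch via `torchrun --nproc_per_node=4 <script>.py` on a 4-GPU host;
# `scale_rows` is a made-up example function.
from functools import partial

import torch
from torch.distributed._tensor import (
    distribute_tensor,
    init_device_mesh,
    Replicate,
    Shard,
)
from torch.distributed._tensor.experimental import local_map


@partial(
    local_map,
    out_placements=[Shard(0)],                  # output stays row-sharded
    in_placements=([Shard(0)], [Replicate()]),  # x row-sharded, scale replicated
)
def scale_rows(x, scale):
    # The body only ever sees ordinary local torch.Tensor shards.
    return x * scale


if __name__ == "__main__":
    mesh = init_device_mesh("cuda", (4,))
    torch.manual_seed(0)  # same global tensor on every rank before sharding
    x = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])
    s = distribute_tensor(torch.ones(8), mesh, [Replicate()])
    y = scale_rows(x, s)  # runs on local shards; no communication needed
    print(type(y), y.placements)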
3 changes: 2 additions & 1 deletion .ci/docker/dev-requirements.txt
@@ -1,3 +1,4 @@
-pytest
+expecttest==0.1.6
+pytest==7.3.2
 pytest-cov
 pre-commit
72 changes: 72 additions & 0 deletions test/test_fused_rms_norm.py
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch.distributed._tensor import (
    distribute_tensor,
    init_device_mesh,
    Replicate,
    Shard,
)
from torch.distributed._tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)

from torchtitan.models.norms import fused_rms_norm_fn


class TestFusedRMSNorm(DTensorTestBase):
    @property
    def world_size(self):
        return 4

    @skip_if_lt_x_gpu(4)
    @with_comms
    def test_fused_rms_norm(self):
        mesh = init_device_mesh(
            device_type=self.device_type, mesh_shape=(self.world_size,)
        )
        x = torch.randn(4, 4, 4, device=self.device_type)  # Shard(1)
        w = torch.randn(4, device=self.device_type, requires_grad=True)  # Replicate

        dist_x = distribute_tensor(x, mesh, [Shard(1)])
        dist_w = distribute_tensor(w, mesh, [Replicate()])

        x = x.clone().detach()
        w = w.clone().detach().requires_grad_()

        self.assertEqual(dist_x.full_tensor(), x)
        self.assertEqual(dist_w.full_tensor(), w)

        # fused rmsnorm on DTensor
        comm_mode = CommDebugMode()
        # fused rmsnorm
        with comm_mode:
            dist_out = fused_rms_norm_fn(dist_x, dist_w)

        self.assertEqual(comm_mode.get_total_counts(), 0)

        with comm_mode:
            dist_grad_out = torch.ones_like(dist_out)
            dist_out.backward(dist_grad_out)

        self.assertEqual(comm_mode.get_total_counts(), 0)

        # fused rmsnorm on Tensor
        out = fused_rms_norm_fn(x, w)
        grad_out = torch.ones_like(out)
        out.backward(grad_out)

        self.assertEqual(dist_out.full_tensor(), out)
        self.assertEqual(dist_grad_out.full_tensor(), grad_out)


if __name__ == "__main__":
    run_tests()
13 changes: 11 additions & 2 deletions test_runner.py
@@ -163,8 +163,17 @@ def build_test_list():
                     "--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
                 ],
             ],
-            "Eager mode 2DParallel",
-            "eager_2d",
+            "Eager mode 2DParallel with rmsnorm",
+            "eager_2d_rmsnorm",
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.tensor_parallel_degree 2 --model.norm_type=fused_rmsnorm",
+                ],
+            ],
+            "Eager mode 2DParallel with fused_rmsnorm",
+            "eager_2d_fused_rmsnorm",
+        ),
         OverrideDefinitions(
             [
15 changes: 15 additions & 0 deletions torchtitan/models/norms.py
@@ -6,12 +6,17 @@

 import math
 
+from functools import partial
+
 import torch
 import torch.nn as nn
 
 import triton
 import triton.language as tl
 
+from torch.distributed._tensor import Partial, Replicate, Shard
+from torch.distributed._tensor.experimental import local_map
+
 
 def create_norm(norm_type: str, dim: int, eps: float = 1e-6):
     """
@@ -214,6 +219,11 @@ def _rms_norm_bwd_kernel_sm(


 class TritonFusedRMSNorm(torch.autograd.Function):
+    @partial(
+        local_map,
+        out_placements=[Shard(1)],
+        in_placements=(None, [Shard(1)], [Replicate()], None),
+    )
     @staticmethod
     def forward(ctx, x, weight, eps):
         x_shape_start = x.shape
@@ -256,6 +266,11 @@ def forward(ctx, x, weight, eps):
         y = y.reshape(x_shape_start)
         return y
 
+    @partial(
+        local_map,
+        out_placements=([Shard(1)], [Partial()], None),
+        in_placements=(None, [Shard(1)]),
+    )
     @staticmethod
     def backward(ctx, dy):
         x, weight, rstd = ctx.saved_tensors
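Reading the annotations above (our interpretation; the diff itself does not spell this out): in forward, x arrives sharded on dim 1 (the sequence dimension under torchtitan's sequence-parallel layout), weight is replicated, the non-tensor ctx and eps get None, and the output keeps x's Shard(1) layout. In backward, dx likewise stays Shard(1), while dw is Partial() because each rank sees only its own sequence shard, so the full weight gradient is the sum of the per-rank contributions. A minimal single-process check of that last point (our sketch, not part of the commit; rms_norm below is a plain-PyTorch stand-in for the fused Triton kernel):

import torch


def rms_norm(x, w, eps=1e-6):
    # Plain-PyTorch RMSNorm over the last dim, standing in for the fused Triton kernel.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w


x = torch.randn(2, 8, 16, dtype=torch.float64)  # (batch, seq, dim)
w = torch.randn(16, dtype=torch.float64, requires_grad=True)

# Weight gradient computed from the full sequence.
rms_norm(x, w).sum().backward()
dw_full = w.grad.clone()

# Weight gradients as 4 "ranks" would compute them from Shard(1) inputs;
# Partial() means DTensor sums these contributions when the full value is needed.
dw_summed = torch.zeros_like(dw_full)
for shard in x.chunk(4, dim=1):
    w.grad = None
    rms_norm(shard, w).sum().backward()
    dw_summed += w.grad

torch.testing.assert_close(dw_summed, dw_full)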
5 changes: 0 additions & 5 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -304,11 +304,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
 
     if parallel_dims.tp_enabled:
-        if job_config.model.norm_type == "fused_rmsnorm":
-            raise NotImplementedError(
-                "fused_rmsnorm not yet compatible with TP. Please use layernorm or rmsnorm."
-            )
-
         tp_mesh = world_mesh["tp"]
         (
             row_parallel_strategy,