From d76199434204aab12d40e9d13b4345aacf3778a1 Mon Sep 17 00:00:00 2001
From: Xilun Wu <12968408+XilunWu@users.noreply.github.com>
Date: Thu, 13 Jun 2024 23:08:45 -0700
Subject: [PATCH] enable TritonFusedRMSNorm with local_map annotation (#404)

Summary
This PR enables the use of TritonFusedRMSNorm with Tensor Parallel, with a
7%-8% performance gain compared to RMSNorm with TP.

#364
---
 .ci/docker/dev-requirements.txt              |  3 +-
 test/test_fused_rms_norm.py                  | 72 ++++++++++++++++++++
 test_runner.py                               | 13 +++-
 torchtitan/models/norms.py                   | 15 ++++
 torchtitan/parallelisms/parallelize_llama.py |  5 --
 5 files changed, 100 insertions(+), 8 deletions(-)
 create mode 100644 test/test_fused_rms_norm.py

diff --git a/.ci/docker/dev-requirements.txt b/.ci/docker/dev-requirements.txt
index 1f960b0b..770301a0 100644
--- a/.ci/docker/dev-requirements.txt
+++ b/.ci/docker/dev-requirements.txt
@@ -1,3 +1,4 @@
-pytest
+expecttest==0.1.6
+pytest==7.3.2
 pytest-cov
 pre-commit
diff --git a/test/test_fused_rms_norm.py b/test/test_fused_rms_norm.py
new file mode 100644
index 00000000..9bd7e373
--- /dev/null
+++ b/test/test_fused_rms_norm.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch.distributed._tensor import (
+    distribute_tensor,
+    init_device_mesh,
+    Replicate,
+    Shard,
+)
+from torch.distributed._tensor.debug import CommDebugMode
+from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    skip_if_lt_x_gpu,
+    with_comms,
+)
+
+from torchtitan.models.norms import fused_rms_norm_fn
+
+
+class TestFusedRMSNorm(DTensorTestBase):
+    @property
+    def world_size(self):
+        return 4
+
+    @skip_if_lt_x_gpu(4)
+    @with_comms
+    def test_fused_rms_norm(self):
+        mesh = init_device_mesh(
+            device_type=self.device_type, mesh_shape=(self.world_size,)
+        )
+        x = torch.randn(4, 4, 4, device=self.device_type)  # Shard(1)
+        w = torch.randn(4, device=self.device_type, requires_grad=True)  # Replicate
+
+        dist_x = distribute_tensor(x, mesh, [Shard(1)])
+        dist_w = distribute_tensor(w, mesh, [Replicate()])
+
+        x = x.clone().detach()
+        w = w.clone().detach().requires_grad_()
+
+        self.assertEqual(dist_x.full_tensor(), x)
+        self.assertEqual(dist_w.full_tensor(), w)
+
+        # fused rmsnorm on DTensor
+        comm_mode = CommDebugMode()
+        # fused rmsnorm
+        with comm_mode:
+            dist_out = fused_rms_norm_fn(dist_x, dist_w)
+
+        self.assertEqual(comm_mode.get_total_counts(), 0)
+
+        with comm_mode:
+            dist_grad_out = torch.ones_like(dist_out)
+            dist_out.backward(dist_grad_out)
+
+        self.assertEqual(comm_mode.get_total_counts(), 0)
+
+        # fused rmsnorm on Tensor
+        out = fused_rms_norm_fn(x, w)
+        grad_out = torch.ones_like(out)
+        out.backward(grad_out)
+
+        self.assertEqual(dist_out.full_tensor(), out)
+        self.assertEqual(dist_grad_out.full_tensor(), grad_out)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test_runner.py b/test_runner.py
index ad5b4131..d688bc3b 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -163,8 +163,17 @@ def build_test_list():
                     "--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
                 ],
             ],
-            "Eager mode 2DParallel",
-            "eager_2d",
+            "Eager mode 2DParallel with rmsnorm",
+            "eager_2d_rmsnorm",
+        ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.tensor_parallel_degree 2 --model.norm_type=fused_rmsnorm",
+                ],
+            ],
+            "Eager mode 2DParallel with fused_rmsnorm",
+            "eager_2d_fused_rmsnorm",
         ),
         OverrideDefinitions(
             [
diff --git a/torchtitan/models/norms.py b/torchtitan/models/norms.py
index e29338d9..4245fe41 100644
--- a/torchtitan/models/norms.py
+++ b/torchtitan/models/norms.py
@@ -6,12 +6,17 @@
 
 import math
 
+from functools import partial
+
 import torch
 import torch.nn as nn
 
 import triton
 import triton.language as tl
 
+from torch.distributed._tensor import Partial, Replicate, Shard
+from torch.distributed._tensor.experimental import local_map
+
 
 def create_norm(norm_type: str, dim: int, eps: float = 1e-6):
     """
@@ -214,6 +219,11 @@ def _rms_norm_bwd_kernel_sm(
 
 
 class TritonFusedRMSNorm(torch.autograd.Function):
+    @partial(
+        local_map,
+        out_placements=[Shard(1)],
+        in_placements=(None, [Shard(1)], [Replicate()], None),
+    )
     @staticmethod
     def forward(ctx, x, weight, eps):
         x_shape_start = x.shape
@@ -256,6 +266,11 @@ def forward(ctx, x, weight, eps):
         y = y.reshape(x_shape_start)
         return y
 
+    @partial(
+        local_map,
+        out_placements=([Shard(1)], [Partial()], None),
+        in_placements=(None, [Shard(1)]),
+    )
     @staticmethod
     def backward(ctx, dy):
         x, weight, rstd = ctx.saved_tensors
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 7f555e88..0fbdfe3e 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -304,11 +304,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
 
     if parallel_dims.tp_enabled:
-        if job_config.model.norm_type == "fused_rmsnorm":
-            raise NotImplementedError(
-                "fused_rmsnorm not yet compatible with TP. Please use layernorm or rmsnorm."
-            )
-
         tp_mesh = world_mesh["tp"]
         (
             row_parallel_strategy,
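
Usage note (illustrative, not part of the patch): with the local_map annotations above, fused_rms_norm_fn accepts DTensor inputs directly, so sequence-sharded activations under TP are normalized on their local shards without any collective, which is what CommDebugMode asserts in the new test. A minimal sketch under assumed conditions (4 GPUs launched via torchrun; demo.py and the tensor shapes are placeholders):

    # illustrative sketch, not part of the patch; run with: torchrun --nproc_per_node=4 demo.py
    import torch
    from torch.distributed._tensor import distribute_tensor, init_device_mesh, Replicate, Shard

    from torchtitan.models.norms import fused_rms_norm_fn

    mesh = init_device_mesh("cuda", (4,))
    # sequence-sharded activations and a replicated norm weight, mirroring the test above
    x = distribute_tensor(torch.randn(8, 1024, 256, device="cuda"), mesh, [Shard(1)])
    w = distribute_tensor(torch.ones(256, device="cuda", requires_grad=True), mesh, [Replicate()])
    y = fused_rms_norm_fn(x, w)  # Triton kernel runs per shard; no communication is issued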