ADLR/megatron-lm!2075 - Triton cache fix
duncanriach authored and ericharper committed Oct 19, 2024
1 parent 0d89fc4 commit 839dff2
Showing 5 changed files with 111 additions and 33 deletions.
13 changes: 10 additions & 3 deletions LICENSE
@@ -30,8 +30,9 @@ The following applies to all files unless otherwise noted:

This repository also contains code from Hugging Face Inc., Google Research,
Facebook (from their Fairseq, Dino, and ParlAI projects), Microsoft (from their
-Swin-Transformer project), Philip Popien, and the Mamba project (Tri Dao and
-Albert Gu). Files from these organizations have notices at the top of each file.
Swin-Transformer project), Philip Popien, the Mamba project (Tri Dao and
Albert Gu), and the Triton language and compiler project (Philippe Tillet and
OpenAI). Files from these organizations have notices at the top of each file.
Below are licenses used in those files, as indicated.


@@ -241,7 +242,13 @@ Below are licenses used in those files, as indicated.
See the License for the specific language governing permissions and
limitations under the License.

-------------- LICENSE FOR Facebook, Inc. and its affiliates, Meta Platforms, Inc. and its affiliates, Microsoft Corporation, and OpenGVLab/InternVL --------------
--------------------------------------------------------------------------------
LICENSE FOR
Facebook, Inc. and its affiliates,
Meta Platforms, Inc. and its affiliates,
Microsoft Corporation,
OpenGVLab/InternVL, and
Triton language and compiler.

MIT License

1 change: 0 additions & 1 deletion megatron/core/ssm/mamba_block.py
@@ -1,6 +1,5 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, Tri Dao, Albert Gu.
-# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

# Some of this code was adopted from https://github.com/state-spaces/mamba/
# This source code is licensed under the Apache license found in the
43 changes: 39 additions & 4 deletions megatron/core/ssm/mamba_layer.py
@@ -1,6 +1,5 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, Tri Dao, Albert Gu.
-# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

# Some of this code was adopted from https://github.com/state-spaces/mamba/
# This source code is licensed under the Apache license found in the
@@ -20,12 +19,33 @@

@dataclass
class MambaLayerSubmodules:
"""
Configuration class for specifying the submodules of a Mamba layer.
This class defines the structure and default implementations for various
components of a Mamba layer, allowing for flexible customization of the
layer's architecture.
Args:
norm (Union[ModuleSpec, type]): Specification for the input layer normalization.
mixer (Union[ModuleSpec, type]): Specification for the along-sequence mixing mechanism.
mamba_bda (Union[ModuleSpec, type]): Specification for the bias-dropout-add operation
after the mixer.
"""

norm: Union[ModuleSpec, type] = IdentityOp
mixer: Union[ModuleSpec, type] = IdentityOp
mamba_bda: Union[ModuleSpec, type] = IdentityOp


class MambaLayer(MegatronModule):
"""
A single Mamba layer.
Mamba layer takes input with size [s, b, h] and returns an
output of the same size.
"""

def __init__(
self,
config: TransformerConfig,
@@ -34,9 +54,7 @@ def __init__(
layer_number: int = 1,
residual_in_fp32=False,
):
"""
Top level Mamba Layer
"""
"""Initialize Mamba Layer."""
super().__init__(config)
self.config = config
self.layer_number = layer_number
@@ -60,6 +78,22 @@ def forward(
inference_params=None,
rotary_pos_emb: Tensor = None, # Not used in MambaLayer
):
"""
Perform a forward pass through the Mamba layer.
This method implements the core computation of a Mamba layer, including
the convolution and the selective SSM/SSD.
Args:
hidden_states (Tensor): Input tensor of shape [s, b, h] where s is sequence length,
b is batch size, and h is hidden size.
attention_mask (Tensor): Mask tensor for self-attention. Not used by this layer.
inference_params (object, optional): Parameters for inference-time optimizations.
rotary_pos_emb (Tensor, optional): Rotary positional embeddings.
Returns:
output (Tensor): Transformed hidden states of shape [s, b, h].
"""

residual = hidden_states
if self.residual_in_fp32:
@@ -78,4 +112,5 @@ def forward(
return hidden_states

def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
"""Allocate the inference cache."""
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
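The MambaLayerSubmodules and MambaLayer docstrings above describe a spec-driven layer: each submodule slot is filled in through a ModuleSpec. As a rough illustration only (not part of this commit), the sketch below shows how such a spec might be assembled; the specific choices of MambaMixer for the mixer slot and get_bias_dropout_add for mamba_bda, as well as the import paths, are assumptions about the surrounding code base rather than the repository's canonical layer spec.

# Illustrative sketch, not from this commit: assembling a MambaLayer spec.
# The submodule choices and import paths below are assumptions.
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules
from megatron.core.ssm.mamba_mixer import MambaMixer
from megatron.core.transformer.spec_utils import ModuleSpec

mamba_layer_spec = ModuleSpec(
    module=MambaLayer,
    submodules=MambaLayerSubmodules(
        # The mixer performs the along-sequence mixing (conv + selective SSM/SSD);
        # norm is left at its IdentityOp default here.
        mixer=ModuleSpec(module=MambaMixer),
        # Bias-dropout-add fused with the residual connection after the mixer.
        mamba_bda=get_bias_dropout_add,
    ),
)

In a real model this spec would be passed to the module builder along with a TransformerConfig; the forward contract stays as documented above: hidden states of shape [s, b, h] in, [s, b, h] out.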
2 changes: 1 addition & 1 deletion megatron/core/ssm/mamba_mixer.py
@@ -1,6 +1,5 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, Tri Dao, Albert Gu.
-# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

# Some of this code was adopted from https://github.com/state-spaces/mamba/
# This source code is licensed under the Apache license found in the
@@ -580,6 +579,7 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states
return conv_state, ssm_state

def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""Provide a sharded state dictionary for distributed checkpointing."""
sharded_state_dict = {}
# Parameters
self._save_to_state_dict(sharded_state_dict, '', keep_vars=True)
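The docstring added to sharded_state_dict above ties MambaMixer into Megatron's distributed checkpointing. As a hedged sketch of how such a dictionary is typically consumed (the helper function, module argument, and checkpoint directory below are illustrative and not part of this commit):

# Hedged sketch, not from this commit: round-tripping a module that exposes
# sharded_state_dict() through megatron.core's dist_checkpointing API.
from megatron.core import dist_checkpointing


def save_and_reload(module, checkpoint_dir):
    # Each rank contributes its shards; the library resolves the global layout.
    dist_checkpointing.save(module.sharded_state_dict(), checkpoint_dir)
    # Rebuild or reuse the same module structure, then load the shards back in place.
    state_dict = dist_checkpointing.load(module.sharded_state_dict(), checkpoint_dir)
    module.load_state_dict(state_dict)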
85 changes: 61 additions & 24 deletions megatron/core/ssm/triton_cache_manager.py
@@ -1,44 +1,81 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright 2018-2020 Philippe Tillet
# Copyright 2020-2022 OpenAI

# Some of this code was adopted from https://github.com/triton-lang/triton
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
-import socket
import uuid
from pathlib import Path

-import torch

try:
from triton import __version__ as triton_version
from triton.runtime.cache import FileCacheManager
except ImportError:
raise ImportError("triton is required by the Mamba model but cannot be imported")


-def get_rank():
-    return torch.distributed.get_rank()
def _version_no_greater_than(version, version_limit):
major, minor, _ = map(int, version.split('.'))
limit_major, limit_minor = map(int, version_limit.split('.'))
return major < limit_major or (major == limit_major and minor <= limit_minor)


def default_cache_dir():
"""Provides a default path for the Triton cache directory."""
return os.path.join(Path.home(), ".triton", "cache")


class ParallelFileCacheManager(FileCacheManager):
"""
This patched version of ParallelFileCacheManager prevents errors related
    to the building of the Triton compiler cache when the number of model
parallel ranks is greater than one, including when certain types of file
system are used (such as Lustre).
Usage:
export TRITON_CACHE_DIR=<chosen-cache-location>
export TRITON_CACHE_MANAGER=megatron.core.ssm.triton_cache_manager:ParallelFileCacheManager
-    # See https://github.com/triton-lang/triton/blob/main/python/triton/runtime/cache.py

-    # When running Triton with multiple ranks, they each create their own cache manager. Their input
-    # keys to that class are mostly (but not entirely) the same across ranks, which leads many ranks
-    # to write to the same 'key' directories in the cache dir at the same time during compilation,
-    # leading to conflicts. This works around that by making each cache dir be rank specific by
-    # adding "rank_<host>_<pid>" to the cache directory.

-    def __init__(self, key):
-        self.key = key
-        self.lock_path = None
-        # create cache directory if it doesn't exist
-        self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
-        self.cache_dir = os.path.join(
-            self.cache_dir, "rank_{}_{}".format(socket.gethostname(), os.getpid())
This patch implements the changes in the following two Triton project pull
requests:
1. https://github.com/triton-lang/triton/pull/3544
2. https://github.com/triton-lang/triton/pull/4295
The above changes will probably be included in Triton release version 3.1,
making this patch no longer necessary.
"""

def put(self, data, filename, binary=True) -> str:
"""A patched version of put, implementing PR 3544 and PR 4295."""
patch_limit = '3.0'
assert _version_no_greater_than(triton_version, patch_limit), (
"Assertion failed: ParallelFileCacheManager patch should not be "
f"used beyond Triton version {patch_limit}."
)
-        if self.cache_dir:
-            self.cache_dir = os.path.join(self.cache_dir, self.key)
-            self.lock_path = os.path.join(self.cache_dir, "lock")
-            os.makedirs(self.cache_dir, exist_ok=True)
-        if not self.cache_dir:
-            raise RuntimeError("Could not create or locate cache dir")
binary = isinstance(data, bytes)
if not binary:
data = str(data)
assert self.lock_path is not None
filepath = self._make_path(filename)
# Random ID to avoid any collisions
rnd_id = str(uuid.uuid4())
        # we use the PID in case a bunch of these are around, so we can see which PID made it
pid = os.getpid()
# use temp dir to be robust against program interruptions
temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}")
os.makedirs(temp_dir, exist_ok=True)
temp_path = os.path.join(temp_dir, filename)

mode = "wb" if binary else "w"
with open(temp_path, mode) as f:
f.write(data)
# Replace is guaranteed to be atomic on POSIX systems if it succeeds
# so filepath cannot see a partial write
os.replace(temp_path, filepath)
os.removedirs(temp_dir)
return filepath
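The patched put above follows a standard atomic-publish pattern: each writer stages the file in a private temporary directory (keyed by PID plus a random ID) and then os.replace()s it into the final path, so concurrent ranks never observe a partially written cache entry. A minimal standalone sketch of that pattern, assuming only the Python standard library (the function name and example path are illustrative, not part of the commit):

# Minimal sketch of the atomic-publish pattern used by the patched put():
# stage the data in a private temp directory, then atomically rename it into place.
# atomic_write() is an illustrative helper, not code from this commit.
import os
import uuid


def atomic_write(data, filepath):
    binary = isinstance(data, bytes)
    cache_dir = os.path.dirname(filepath)
    # A per-writer temp dir (PID + random ID) avoids collisions between
    # concurrent processes targeting the same cache key.
    temp_dir = os.path.join(cache_dir, f"tmp.pid_{os.getpid()}_{uuid.uuid4()}")
    os.makedirs(temp_dir, exist_ok=True)
    temp_path = os.path.join(temp_dir, os.path.basename(filepath))
    with open(temp_path, "wb" if binary else "w") as f:
        f.write(data if binary else str(data))
    # os.replace is atomic on POSIX when source and destination are on the
    # same filesystem, so readers see either the old file or the complete new one.
    os.replace(temp_path, filepath)
    os.removedirs(temp_dir)  # the temp dir is now empty; prune it
    return filepath

For example, atomic_write(b"compiled kernel bytes", "/tmp/triton_cache/kernel.cubin") publishes the blob in one step. The _version_no_greater_than guard at the top of put() exists so the workaround fails loudly once the installed Triton already ships the upstream fixes (expected around release 3.1).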
