Commit

Merge commit '4ddb0a7bea787294282d0fe0715adf5ea4a39779' into dev/zhangrb/fused_multi_pad_cast_transpose
BeingGod committed Aug 28, 2024
2 parents 8336b13 + 4ddb0a7 commit 63ef882
Showing 79 changed files with 5,191 additions and 1,506 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 118 files
2 changes: 1 addition & 1 deletion benchmarks/attention/benchmark_attention.py
@@ -156,7 +156,7 @@ def parse_results(per_cudnn, per_flash, model):
df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] = t_cudnn_avg.sum() / 1e6

if per_flash > 0:
t_flash_all = df[df["Name"].str.contains("void flash")]["Duration (ns)"].to_numpy()
t_flash_all = df[df["Name"].str.contains("flash")]["Duration (ns)"].to_numpy()
t_flash_all = t_flash_all.reshape(-1, per_flash)
t_flash_avg = np.average(t_flash_all, axis=0)
df_times.loc[row, "FlashAttention Kernels (fwd)"] = t_flash_avg[0] / 1e6
2 changes: 1 addition & 1 deletion build_tools/VERSION.txt
@@ -1 +1 @@
1.10.0.dev0
1.11.0.dev0
19 changes: 12 additions & 7 deletions build_tools/pytorch.py
@@ -10,8 +10,9 @@

from .utils import (
all_files_in_dir,
cuda_version,
cuda_archs,
cuda_path,
cuda_version,
)


@@ -48,8 +49,6 @@ def setup_pytorch_extension(
]
nvcc_flags = [
"-O3",
"-gencode",
"arch=compute_70,code=sm_70",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
@@ -61,6 +60,11 @@
"--use_fast_math",
]

cuda_architectures = cuda_archs()

if "70" in cuda_architectures:
nvcc_flags.extend(["-gencode", "arch=compute_70,code=sm_70"])

# Version-dependent CUDA options
try:
version = cuda_version()
@@ -73,13 +77,14 @@
(
"--threads",
os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"),
"-gencode",
"arch=compute_80,code=sm_80",
"-gencode",
"arch=compute_90,code=sm_90",
)
)

if "80" in cuda_architectures:
nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"])
if "90" in cuda_architectures:
nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"])

# Libraries
library_dirs = []
libraries = []
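For context, here is a minimal, self-contained sketch of how the new `NVTE_CUDA_ARCHS` override and the conditional `-gencode` selection above fit together. The `cuda_archs` default is taken from the `build_tools/utils.py` change below; the standalone `select_gencode_flags` helper is purely illustrative and not part of the commit.

```python
import os
from typing import List


def cuda_archs() -> str:
    # Same default as build_tools/utils.py; override with e.g. NVTE_CUDA_ARCHS="80;90"
    return os.getenv("NVTE_CUDA_ARCHS", "70;80;89;90")


def select_gencode_flags() -> List[str]:
    # Illustrative helper mirroring the conditional logic in build_tools/pytorch.py:
    # a -gencode pair is emitted only for architectures listed in NVTE_CUDA_ARCHS.
    flags: List[str] = []
    archs = cuda_archs()
    for sm in ("70", "80", "90"):
        if sm in archs:
            flags.extend(["-gencode", f"arch=compute_{sm},code=sm_{sm}"])
    return flags


if __name__ == "__main__":
    # Default archs yield sm_70, sm_80 and sm_90 pairs; NVTE_CUDA_ARCHS="80;90" drops sm_70.
    print(select_gencode_flags())
```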
48 changes: 40 additions & 8 deletions build_tools/utils.py
@@ -6,15 +6,15 @@

import functools
import glob
import importlib
import os
import re
import shutil
import subprocess
import sys
import importlib
from pathlib import Path
from subprocess import CalledProcessError
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union


@functools.lru_cache(maxsize=None)
@@ -188,6 +188,11 @@ def cuda_path() -> Tuple[str, str]:
return cuda_home, nvcc_bin


@functools.lru_cache(maxsize=None)
def cuda_archs() -> str:
return os.getenv("NVTE_CUDA_ARCHS", "70;80;89;90")


def cuda_version() -> Tuple[int, ...]:
"""CUDA Toolkit version as a (major, minor) tuple."""
# Query NVCC for version info
@@ -254,12 +259,39 @@ def get_frameworks() -> List[str]:
return _frameworks


def copy_common_headers(te_src, dst):
headers = te_src / "common"
for file_path in glob.glob(os.path.join(str(headers), "**", "*.h"), recursive=True):
new_path = os.path.join(dst, file_path[len(str(te_src)) + 1 :])
Path(new_path).parent.mkdir(exist_ok=True, parents=True)
shutil.copy(file_path, new_path)
def copy_common_headers(
src_dir: Union[Path, str],
dst_dir: Union[Path, str],
) -> None:
"""Copy headers from core library
src_dir should be the transformer_engine directory within the root
Transformer Engine repository. All .h and .cuh files within
transformer_engine/common are copied into dst_dir. Relative paths
are preserved.
"""

# Find common header files in src dir
headers = glob.glob(
os.path.join(str(src_dir), "common", "**", "*.h"),
recursive=True,
)
headers.extend(
glob.glob(
os.path.join(str(src_dir), "common", "**", "*.cuh"),
recursive=True,
)
)
headers = [Path(path) for path in headers]

# Copy common header files to dst dir
src_dir = Path(src_dir)
dst_dir = Path(dst_dir)
for path in headers:
new_path = dst_dir / path.relative_to(src_dir)
new_path.parent.mkdir(exist_ok=True, parents=True)
shutil.copy(path, new_path)


def install_and_import(package):
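A short usage sketch for the reworked `copy_common_headers`; the staging paths below are hypothetical, and the call assumes it is run from the repository root so that `build_tools.utils` is importable.

```python
from pathlib import Path

from build_tools.utils import copy_common_headers

repo_root = Path.cwd()  # assumption: current directory is the Transformer Engine repo root

# Copies every *.h and *.cuh under transformer_engine/common into the staging
# directory, preserving paths relative to transformer_engine/.
copy_common_headers(
    src_dir=repo_root / "transformer_engine",
    dst_dir=repo_root / "build_staging" / "include",
)
```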
5 changes: 4 additions & 1 deletion docs/conf.py
@@ -47,7 +47,10 @@

git_sha = git_sha[:7] if len(git_sha) > 7 else git_sha

version = str(te_version + "-" + git_sha)
if "dev" in te_version:
version = str(te_version + "-" + git_sha)
else:
version = str(te_version)
release = te_version

# hack: version is used for html creation, so put the version picker
10 changes: 5 additions & 5 deletions examples/jax/encoder/README.md
@@ -1,6 +1,6 @@
# Basic Transformer Encoder Example with Optional FP8 #

This example uses Transformer Encoder to demonstrate the Transformer Engine usage. And more focus on scaling up training on multiple GPUs. Highly recommend studying the [MNIST example of the Transformer Engine](/examples/jax/mnist) before reading this example. The Transformer Engine is built on top of [Flax](https://github.com/google/flax). Thus, examples use `pjit` to set up multiple GPU training. The basic pjit usage can be referred to [Scale up Flax Modules on multiple devices with pjit](https://flax.readthedocs.io/en/latest/guides/flax_on_pjit.html).
This example uses a Transformer Encoder to demonstrate Transformer Engine usage, with a focus on scaling training up to multiple GPUs. We highly recommend studying the [MNIST example of the Transformer Engine](/examples/jax/mnist) before reading this one. The Transformer Engine is built on top of [Flax](https://github.com/google/flax); thus, the examples use `jit` with the `in_shardings` and `out_shardings` parameters to set up multi-GPU training. Basic parallel `jit` usage is covered in [Scale up Flax Modules on multiple devices](https://flax.readthedocs.io/en/latest/guides/flax_on_pjit.html).
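For readers unfamiliar with the `jit` sharding parameters mentioned above, the following minimal sketch (not part of this commit) shards a toy function's batch dimension across a one-axis device mesh; the `scale` function and mesh shape are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def scale(x):
    # Toy stand-in for an encoder train/eval step.
    return 2.0 * x


# One-dimensional mesh over all visible devices, named 'data' for data parallelism.
devices = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(devices, axis_names=("data",))

# Shard the leading (batch) dimension across the 'data' axis; replicate the rest.
batch_sharding = NamedSharding(mesh, PartitionSpec("data", None))

jit_scale = jax.jit(scale, in_shardings=(batch_sharding,), out_shardings=batch_sharding)

x = jnp.ones((4 * jax.device_count(), 16))
print(jit_scale(x).sharding)  # NamedSharding with spec PartitionSpec('data', None)
```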

## Single GPU ##

@@ -31,11 +31,11 @@ python test_single_gpu_encoder.py --use-fp8

4. The next is to create sharding rules, mapping the device axis to the logical axis. The `te.extend_logical_axis_rules` under fp8_autocast will return a list of pairs of the mapping, such as `(('batch', 'data'), ...)`. The first is the logical axis and second is the device axis.

5. Refer structure of `abs_var_collect['params']` and `abs_var_collect['params_axes']` to set up `PartitionSpec` for pjit. All logical axes should be replaced by device axes. If the value of PartitionSpec is None, that means no sharding, broadcasting the data to every device. Note that the `params_axes` attribute is provided by Transformer Engine. The Flax's module doesn't have it, such as `nn.Embed`. For nn.Embed, assigning an empty PartitionSpec is fine because each device has its own embedding layer in DP mode. The `get_params_pspec` routine is used for this purpose. Because each device has a complete model in DP mode, all values of PartitionSpec in params_pspec should be None. This will be different in the model parallelism example.
5. Refer to the structure of `abs_var_collect['params']` and `abs_var_collect['params_axes']` to set up the `PartitionSpec` for parallel jit (see the sketch below). All logical axes should be replaced by device axes. If the value of a PartitionSpec is None, that means no sharding: the data is broadcast to every device. Note that the `params_axes` attribute is provided by Transformer Engine; Flax's own modules, such as `nn.Embed`, don't have it. For `nn.Embed`, assigning an empty PartitionSpec is fine because each device has its own embedding layer in DP mode. The `get_params_pspec` routine is used for this purpose. Because each device has a complete model in DP mode, all values of PartitionSpec in params_pspec should be None. This will be different in the model parallelism example.

6. Fill in `params_pspec` and `encoder.init` to pjit to get a compiled function, `pjit_encoder_init `, and use it to initialize the model, so JAX now can know how to do the sharding.
6. Pass `params_sharding` and `encoder.init` to jit to get a compiled function, `jit_encoder_init`, and use it to initialize the model, so that JAX knows how to do the sharding.

7. The `train_step` and `eval_step` also need to be compiled by pjit. Thus, every input and output argument has to be set up `PartitionSpec` if the argument contains a tensor. For instance, the `input_pspec` is `PartitionSpec('data', None)` because the input shape is (batch size, sequence length). Then, the rest of the workflow is similar to the previous example.
7. The `train_step` and `eval_step` also need to be compiled by jit. Thus, every input and output argument that contains a tensor has to have a `PartitionSpec` set up. For instance, the `input_pspec` is `PartitionSpec('data', None)` because the input shape is (batch size, sequence length). The rest of the workflow is similar to the previous example.

8. Use `CUDA_VISIBLE_DEVICES` to control the number of GPUs used. For example, if the system has 8 GPUs but only 4 GPUs need to be used, then:
```sh
@@ -84,7 +84,7 @@ python test_model_parallel_encoder.py --use-fp8
1. This example inherits previous model parallelism example, but uses multiprocessing instead of single-program multiple-data (SPMD). It uses 1 GPU per process.
2. The benefit of multiprocessing is to setup hardware affinity for GPUs, such as NUMA binding. It may help improve performance and stability. Please refer to [Best Practices When Benchmarking CUDA Applications](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2019-s9956/) for more details.
2. There are two main benefits of multiprocessing: support for multi-node training and the ability to set up hardware affinity for GPUs, such as NUMA binding. Affinity may help improve performance and stability. Please refer to [Best Practices When Benchmarking CUDA Applications](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2019-s9956/) for more details.
3. The quick way to check system topology is to use `nvidia-smi`, for example:
```sh
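As referenced in step 5 above, here is a minimal sketch (not part of the commit) of turning logical-axis annotations into device shardings via sharding rules. The rule pairs and mesh shape below are illustrative stand-ins for what `te_flax.extend_logical_axis_rules` and the example's mesh provide.

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Illustrative (logical_axis, device_axis) rules, in the same shape as the pairs
# returned by te_flax.extend_logical_axis_rules plus the example's customized rules.
rules = dict((("batch", "data"), ("hidden", None), ("heads", "model")))

# Data-parallel-only mesh: all devices on the 'data' axis, a single 'model' slot.
devices = mesh_utils.create_device_mesh((jax.device_count(), 1))
mesh = Mesh(devices, axis_names=("data", "model"))


def to_device_axis(logical_axes):
    """Map a parameter's logical axis names onto mesh axes and wrap them in a NamedSharding."""
    return NamedSharding(mesh, PartitionSpec(*(rules[name] for name in logical_axes)))


# A weight annotated with ('hidden', 'heads') is replicated along 'hidden' and
# split across the 'model' mesh axis along 'heads'.
print(to_device_axis(("hidden", "heads")).spec)  # PartitionSpec(None, 'model')
```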
85 changes: 48 additions & 37 deletions examples/jax/encoder/test_model_parallel_encoder.py
@@ -17,7 +17,7 @@
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
from jax.experimental import mesh_utils
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec, NamedSharding

import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
@@ -223,32 +223,36 @@ def check_fp8(state, var_collect, inputs, masks, labels):
)


def get_params_pspec(sharding_rules, abs_var_collect):
"""Refer params to create params partition spec"""
rules_dict = {}
for key, value in sharding_rules:
rules_dict[key] = value
def get_params_sharding(sharding_rules, abs_var_collect, mesh):
"""Refer params to create params sharding"""
rules_dict = dict(sharding_rules)

def to_device_axis(logical_axis):
partitions = [rules_dict[key] for key in logical_axis]
return jax.sharding.PartitionSpec(*partitions)
return NamedSharding(mesh, PartitionSpec(*partitions))

params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
params_axes_pspec = jax.tree_map(to_device_axis, nn_partitioning.get_axis_names(params_axes))
params_axes_pspec = flax.core.unfreeze(params_axes_pspec)
params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY])
params_pspec = {**params_pspec, **params_axes_pspec}
return params_pspec
params_axes_sharding = jax.tree_util.tree_map(
to_device_axis, nn_partitioning.get_axis_names(params_axes)
)
params_axes_sharding = flax.core.unfreeze(params_axes_sharding)
params_sharding = jax.tree_util.tree_map(
lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY]
)
params_sharding = {**params_sharding, **params_axes_sharding}
return params_sharding


def get_state_pspec(state, params_pspec):
"""Refer params_pspec to create state partition spec"""
def get_state_sharding(state, params_sharding):
"""Refer params_sharding to create state sharding"""

def replace_params(x):
return params_pspec if isinstance(x, dict) else None
return params_sharding if isinstance(x, dict) else None

state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, dict))
return state_pspec
state_sharding = jax.tree_util.tree_map(
replace_params, state, is_leaf=lambda x: isinstance(x, dict)
)
return state_sharding


def train_and_evaluate(args):
@@ -270,7 +274,9 @@ def train_and_evaluate(args):
), f"Test batch size needs to be multiple of {num_gpu_dp}"

device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)):
with jax.sharding.Mesh(
devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
) as mesh:

rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
@@ -291,34 +297,39 @@

customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules
params_pspec = get_params_pspec(sharding_rules, abs_var_collect)
inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None))
masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None))

in_shardings = (None, inputs_pspec, masks_pspec)
in_shardings = (None, inputs_sharding, masks_sharding)
out_shardings = {
key: params_pspec if key is PARAMS_KEY else None for key in abs_var_collect
key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect
}
pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings)
var_collect = pjit_encoder_init(init_rngs, inputs, masks)
jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings)
var_collect = jit_encoder_init(init_rngs, inputs, masks)

optimizer = optax.adamw(args.lr)
var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
state = train_state.TrainState.create(
apply_fn=encoder.apply, params=params, tx=optimizer
)
state_pspec = get_state_pspec(state, params_pspec)
labels_pspec = jax.sharding.PartitionSpec(
DEVICE_DP_AXIS,
state_sharding = get_state_sharding(state, params_sharding)
labels_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS))

in_shardings = (
state_sharding,
inputs_sharding,
masks_sharding,
labels_sharding,
None,
None,
)
out_shardings = (state_sharding, None, None, None)
jit_train_step = jax.jit(train_step, in_shardings, out_shardings)

in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None)
out_shardings = (state_pspec, None, None, None)
pjit_train_step = pjit(train_step, in_shardings, out_shardings)

in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None)
in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
out_shardings = (None, None)
pjit_eval_step = pjit(eval_step, in_shardings, out_shardings)
jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings)

if args.use_fp8:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
@@ -327,7 +338,7 @@ def train_and_evaluate(args):
if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
pjit_train_step(state, inputs, masks, labels, var_collect, rngs)
jit_train_step(state, inputs, masks, labels, var_collect, rngs)
print("PASSED")
return None

@@ -337,11 +348,11 @@ def train_and_evaluate(args):
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}

state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step
state, train_ds, args.batch_size, rngs, var_collect, jit_train_step
)

test_loss, test_accuracy = eval_model(
state, test_ds, args.test_batch_size, var_collect, pjit_eval_step
state, test_ds, args.test_batch_size, var_collect, jit_eval_step
)

print(
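A minimal sketch of the `jax.tree_util.tree_map(..., is_leaf=...)` pattern that the new `get_state_sharding` relies on: dict subtrees are treated as single leaves and swapped wholesale for the params sharding tree. The `ToyTrainState` and the string sharding placeholders are hypothetical stand-ins for `train_state.TrainState` and the real `NamedSharding` objects.

```python
from typing import Any, NamedTuple

import jax


class ToyTrainState(NamedTuple):
    # Stand-in for flax.training.train_state.TrainState: a non-dict pytree node
    # whose 'params' field is a plain dict of parameters.
    step: int
    params: Any


params_sharding = {"dense": {"kernel": "kernel_sharding", "bias": "bias_sharding"}}
state = ToyTrainState(step=0, params={"dense": {"kernel": 1.0, "bias": 2.0}})


def replace_params(x):
    # Dict leaves (the params subtree) are replaced by the sharding tree;
    # every other leaf (e.g. the step counter) gets no sharding (None).
    return params_sharding if isinstance(x, dict) else None


state_sharding = jax.tree_util.tree_map(
    replace_params, state, is_leaf=lambda x: isinstance(x, dict)
)
print(state_sharding)
# ToyTrainState(step=None, params={'dense': {'kernel': 'kernel_sharding', 'bias': 'bias_sharding'}})
```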