
Does tp_overlap require the tensor-parallel size to equal the world size? #966

Open
kuangdao opened this issue Jun 25, 2024 · 5 comments
@kuangdao commented Jun 25, 2024

I want to set the TP size to 2, with the TP size smaller than the global world size (the script below splits 8 ranks into four TP groups of 2).

The code is:


import os
import sys
import subprocess
import argparse

import torch
import torch.distributed as dist

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling


def parse_args(argv=None, namespace=None):
    parser = argparse.ArgumentParser(
        description="Test a te.LayerNormMLP module with GEMM+comm overlap via Userbuffers."
    )
    parser.add_argument(
        "-i", "--num-iters", type=int, default=5, help="Number of dummy 'training' iterations."
    )
    parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.")
    parser.add_argument("-s", "--seq-length", type=int, default=2048, help="Input sequence length.")
    parser.add_argument(
        "-n", "--num-heads", type=int, default=64, help="Number of attention heads."
    )
    parser.add_argument(
        "-d", "--head-dim", type=int, default=128, help="Dimension of each attention head."
    )
    parser.add_argument(
        "--mlp-expansion-factor",
        type=int,
        default=4,
        help="MLP block intermediate size as a factor of hidden dimension.",
    )
    parser.add_argument("--seed", type=int, default=1234, help="RNG seed.")
    parser.add_argument(
        "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
    )
    parser.add_argument(
        "--no-comm-overlap",
        action="store_true",
        default=False,
        help="Disable the comm+GEMM overlap.",
    )
    parser.add_argument("-v", "--verbose", action="store_true", default=False)
    return parser.parse_args(argv, namespace)


def train(opts):
    WORLD_RANK = int(os.getenv("RANK"))
    WORLD_SIZE = int(os.getenv("WORLD_SIZE"))

    def dist_print(msg, end="\n", all_ranks=False):
        if WORLD_RANK == 0 or all_ranks:
            print(f"[RANK-{WORLD_RANK}] {msg}", end=end)


    torch.cuda.set_device(WORLD_RANK)
    torch.manual_seed(opts.seed + WORLD_RANK)
    torch.cuda.manual_seed(opts.seed + WORLD_RANK)

    dist.init_process_group(
        backend="nccl",
        rank=WORLD_RANK,
        world_size=WORLD_SIZE,
        device_id=torch.device(f"cuda:{WORLD_RANK}"),
    )
    
    

    tp_group_0 = dist.new_group([0, 1], backend="nccl")
    tp_group_1 = dist.new_group([2, 3], backend="nccl")
    tp_group_2 = dist.new_group([4, 5], backend="nccl")
    tp_group_3 = dist.new_group([6, 7], backend="nccl")

    if WORLD_RANK in [0, 1]:
        tp_group = tp_group_0
    elif WORLD_RANK in [2, 3]:
        tp_group = tp_group_1
    elif WORLD_RANK in [4, 5]:
        tp_group = tp_group_2
    elif WORLD_RANK in [6, 7]:
        tp_group = tp_group_3

    tensor = torch.ones([2, 2]).cuda() * WORLD_RANK
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=tp_group)

    print("after allreduce is : {}".format(tensor))


    tp_size = dist.get_world_size(tp_group)


    ag_cfg = {  # ring-exchange overlap config used for the all-gather overlaps (fc1_fprop, fc2_dgrad)
        "method": "ring_exchange",
        "num_splits": 8,
        "num_sm": 1,
        "set_sm_margin": False,
    }
    rs_cfg = {  # ring-exchange overlap config used for the reduce-scatter overlaps (fc1_dgrad, fc2_fprop)
        "method": "ring_exchange",
        "num_splits": 4,
        "num_sm": 1,
        "set_sm_margin": True,
    }
    hidden_size = opts.num_heads * opts.head_dim
    batched_size = opts.seq_length * opts.batch_size

    print("batched_size is : {}".format(batched_size))

    if not opts.no_comm_overlap:
        te.initialize_ub(
            [batched_size, hidden_size],
            tp_group,
            use_fp8=opts.fp8,
            dtype=torch.bfloat16,
            ub_cfgs={
                "fc1_fprop": ag_cfg,
                "fc1_dgrad": rs_cfg,
                "fc2_fprop": rs_cfg,
                "fc2_dgrad": ag_cfg,
            },
        )

    
    model = te.LayerNormMLP(
        hidden_size,
        opts.mlp_expansion_factor * hidden_size,
        params_dtype=torch.bfloat16,
        device="cuda",
        tp_group=tp_group,
        tp_size=tp_size,
        set_parallel_mode=True,
        sequence_parallel=True,  
        seq_length=opts.seq_length,
        micro_batch_size=opts.batch_size,
        ub_overlap_rs_dgrad=not opts.no_comm_overlap,
        ub_overlap_rs=not opts.no_comm_overlap,
        ub_overlap_ag=not opts.no_comm_overlap,
    )

    optim = torch.optim.Adam(model.parameters(), lr=0.0001)

    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")

    for i in range(opts.num_iters):
        dist_print(f"Iter {i+1}", all_ranks=opts.verbose)

        dist_print("|-- Generate random input batch", all_ranks=opts.verbose)
        x = torch.rand(
            (opts.seq_length // tp_size, opts.batch_size, hidden_size),
            dtype=torch.bfloat16,
            device="cuda",
            requires_grad=True,
        )

        dist_print("|-- Forward pass", all_ranks=opts.verbose)
        with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=tp_group):
            y = model(x)
            dist_print("|-- Compute loss", all_ranks=opts.verbose)
            loss = y.flatten().sum()

        dist_print("|-- Backward pass", all_ranks=opts.verbose)
        loss.backward()

        dist_print("|-- Optimizer step", all_ranks=opts.verbose)
        optim.step()

    te.destroy_ub()
    dist.destroy_process_group()


if __name__ == "__main__":
    if "TORCHELASTIC_RUN_ID" in os.environ.keys():
        args = parse_args()
        train(args)
    else:
        subprocess.run(
            ["torchrun", f"--nproc-per-node={torch.cuda.device_count()}", *sys.argv],
            env=os.environ,
            check=True,
        )
    os._exit(0)


I run it with: torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_sub_group.py

The error is:

[WeChat Work screenshot of the error attached; the error text is not recoverable here]

The TransformerEngine commit is 4a4f05d.

The Docker image I use is nvcr.io/nvidia/nemo:24.05.

@timmoon10 (Collaborator)

The tensor parallel group can be a subset of the world group. We frequently split the world group into orthogonal tensor-parallel, data-parallel, and pipeline-parallel groups.
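
For illustration, here is a minimal sketch of splitting the world group into orthogonal tensor-parallel (adjacent ranks) and data-parallel (strided ranks) groups; the layout and helper name are assumptions for this example, not the exact code TE or Megatron-LM uses:

import torch.distributed as dist

def make_tp_dp_groups(tp_size):
    # Split WORLD into orthogonal TP groups (adjacent ranks) and DP groups (strided ranks).
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    assert world_size % tp_size == 0, "world size must be divisible by tp_size"
    dp_size = world_size // tp_size

    tp_group = dp_group = None
    # Every rank must call new_group() for every group, even groups it does not belong to.
    for i in range(dp_size):
        ranks = list(range(i * tp_size, (i + 1) * tp_size))
        group = dist.new_group(ranks, backend="nccl")
        if rank in ranks:
            tp_group = group
    for j in range(tp_size):
        ranks = list(range(j, world_size, tp_size))
        group = dist.new_group(ranks, backend="nccl")
        if rank in ranks:
            dp_group = group
    return tp_group, dp_group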

Based on the error message, it looks like there's an error when NCCL is initializing IPC communicators:

NCCLCHECK(ncclIpcSocketInit(&ipcSock, myrank, (uint64_t)opId, &abortFlag));

To get more information, can you set NCCL_DEBUG=WARN in the environment?
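
As a side note, one way to do that from the repro script itself is sketched below; this assumes setting the variable before dist.init_process_group() is early enough for NCCL to pick it up (exporting it in the shell before torchrun works just as well):

import os

# Must be set before the NCCL communicator is created, i.e. before dist.init_process_group().
os.environ.setdefault("NCCL_DEBUG", "WARN")
# Optionally narrow the output to the relevant subsystems.
os.environ.setdefault("NCCL_DEBUG_SUBSYS", "INIT,P2P")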

@kuangdao (Author)

I have set export NCCL_DEBUG=WARN and there is no additional message.

[WeChat Work screenshot of the output attached; no additional NCCL messages visible]

@denera (Collaborator) commented Jul 1, 2024

@kuangdao TE in general supports TP size < world size, but the comm+GEMM overlap has some unique restrictions. The underlying device-to-device comms code currently assumes TP size == world size. You may be able to get around this limitation by running with UB_SKIPMC=1, but this leverages CUDA IPC Handles instead of CUDA Multicast so it may not be as performant.
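
For anyone trying that workaround, a minimal sketch of passing the flag through the repro script's self-relaunch; UB_SKIPMC is read from the environment, and whether it fully works around the restriction on a given platform is something to verify:

import os
import subprocess
import sys
import torch

# Relaunch under torchrun with the Multicast-free Userbuffers path requested.
env = dict(os.environ, UB_SKIPMC="1")
subprocess.run(
    ["torchrun", f"--nproc-per-node={torch.cuda.device_count()}", *sys.argv],
    env=env,
    check=True,
)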

As a disclaimer, comm+GEMM overlap is currently an experimental and somewhat fragile feature that is not yet fully supported in TE under all circumstances (and intentionally undocumented). That will change in the near future, as we improve the underlying device-to-device comms code and test it more rigorously on different platforms.

@kuangdao (Author) commented Jul 2, 2024

Thanks, I understand. I think comm+GEMM overlap is outstanding work, and I hope more documentation, such as design and implementation notes, will be provided.

@denera (Collaborator) commented Aug 16, 2024

@kuangdao -- we merged some changes to comm+GEMM overlap in the last month specifically to address multi-node mixed DP/TP use-cases. This feature is still restricted to tp_size <= local_size where local_size is the # of GPUs in a single NVLink domain (currently a single physical node of max 8 GPUs), but it now functions correctly with model replication across node boundaries. Could you test again and confirm if this works for your use case?
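
As an illustration of that restriction, a small launcher-side sanity check one could add; LOCAL_WORLD_SIZE is set by torchrun, and treating it as the NVLink-domain size is an assumption that holds for a single node of up to 8 GPUs:

import os

def check_tp_size(tp_size):
    # Check that the requested TP size fits within one NVLink domain (one node here).
    world_size = int(os.environ["WORLD_SIZE"])
    local_size = int(os.environ.get("LOCAL_WORLD_SIZE", str(world_size)))
    if tp_size > local_size:
        raise ValueError(
            f"comm+GEMM overlap requires tp_size ({tp_size}) <= local_size ({local_size})"
        )
    if world_size % tp_size != 0:
        raise ValueError("world size must be divisible by tp_size for DP replication")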
