Pipeline placement with zigzag #25

Merged · 8 commits · Dec 17, 2024
4 changes: 3 additions & 1 deletion .gitignore
@@ -173,4 +173,6 @@ third_party/
 
 # reports
 *.nsys-rep
-reports/
+reports/
+steps_trace.json
+*.png
30 changes: 24 additions & 6 deletions benchmark/benchmark_serving.py
@@ -47,19 +47,21 @@ def __repr__(self):
             f"itl_latency_p99: {self.itl_latency_p99_ms:.2f}ms\n"
 
 def launch(args):
-    cluster_config = ClusterConfig(n_node=1, n_gpu=3,
+    cluster_config = ClusterConfig(n_node=args.num_nodes, n_gpu=args.num_gpus,
                                    id_tokenizer=tokenizer,
                                    id_sampler=sampler)
 
    model_config = duo_expert_mixtral
-    model_config.num_layers = 32
-    model_config.ep_size = 2
-    model_config.num_experts = 8
-    model_config.tp_size = 1
+    model_config.num_layers = args.num_layers
+    model_config.ep_size = args.ep_size
+    model_config.tp_size = args.tp_size
    model_config.tp_enable_inter_group = False
    model_config.enable_cuda_graph = args.cuda_graph
 
-    mp = get_model_placement(model_config, cluster_config, "interleave")
+    mp = get_model_placement(model_config, cluster_config, args.placement,
+                             step_attn=args.step_attn, step_expert=args.step_expert,
+                             zigzag_attn=args.zigzag_attn)
+    # mp = get_model_placement(model_config, cluster_config, "interleave")
 
    global master
@@ -202,8 +204,24 @@ def get_args():
     parser.add_argument("-c", "--cuda-graph", action="store_true", default=False, help="enable cuda graph")
     parser.add_argument("--nsys", action="store_true", help="enable nsys profiling")
 
+    # model config
+    parser.add_argument("-g", "--num-gpus", type=int, default=4, help="number of gpus per node")
+    parser.add_argument("-N", "--num-nodes", type=int, default=1, help="number of nodes")
+    parser.add_argument("--tp-size", type=int, default=1, help="tensor parallel size")
+    parser.add_argument("--ep-size", type=int, default=2, help="expert parallel size")
+    parser.add_argument("--num-layers", type=int, default=32, help="number of layers")
+
+    # placement config
+    parser.add_argument("--placement", type=str, default="pipeline", help="placement strategy")
+    parser.add_argument("--zigzag-attn", action="store_true", default=True, help="enable zigzag attention placement")
+    parser.add_argument("--step-attn", type=int, default=2, help="number of steps in attention placement")
+    parser.add_argument("--step-expert", type=int, default=1, help="number of steps in expert placement")
+
     args = parser.parse_args()
 
+    if args.num_gpus % (args.tp_size * args.step_attn + args.ep_size * args.step_expert) != 0:
+        print("Warning: number of gpus is not divisible by the number of placement steps")
+
     if args.nsys:
         assert args.profile_dir is None, "cannot enable both nsys and torch profiler"
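Note on the new divisibility check: under the pipeline placement, one model replica occupies `tp_size` GPUs per attention step and `ep_size` GPUs per expert step, so a full replica needs `tp_size * step_attn + ep_size * step_expert` GPUs. A minimal sketch of that accounting (the helper name is illustrative, not part of this PR):

```python
def gpus_per_replica(tp_size: int, ep_size: int, step_attn: int, step_expert: int) -> int:
    # Each attention pipeline step runs on a TP group of `tp_size` GPUs;
    # each expert pipeline step runs on an EP group of `ep_size` GPUs.
    return tp_size * step_attn + ep_size * step_expert

# With the defaults above: 1 * 2 + 2 * 1 = 4 GPUs per replica, matching --num-gpus=4.
assert gpus_per_replica(tp_size=1, ep_size=2, step_attn=2, step_expert=1) == 4
```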
10 changes: 8 additions & 2 deletions csrc/bindings.cpp
@@ -1,4 +1,7 @@
 #include <pybind11/pybind11.h>
+#include <pybind11/functional.h>
+#include <pybind11/chrono.h>
+#include <pybind11/complex.h>
 #include <pybind11/stl.h>
 
 #include "tests.h"
@@ -13,6 +16,8 @@
 #define REGISTER_STRUCT(name, ...) py::class_<name>(m, #name).def(py::init<__VA_ARGS__>())
 #define REGISTER_FUNC(name) m.def(#name, &name)
 
+PYBIND11_MAKE_OPAQUE(std::map<std::pair<int, int>, int>);
+
 namespace py = pybind11;
 
 PYBIND11_MODULE(disagmoe_c, m) {
@@ -67,10 +72,11 @@ PYBIND11_MODULE(disagmoe_c, m) {
     REGISTER_STRUCT(TokenMetadata);
 
     py::class_<ParallelConfig>(m, "ParallelConfig")
-        .def(py::init<int, int, int>())
+        .def(py::init<>())
         .def_readwrite("tp", &ParallelConfig::tp)
         .def_readwrite("ep", &ParallelConfig::ep)
-        .def_readwrite("n_exp_per_rank", &ParallelConfig::n_exp_per_rank);
+        .def_readwrite("n_exp_per_rank", &ParallelConfig::n_exp_per_rank)
+        .def_readwrite("expert_ranks", &ParallelConfig::expert_ranks);
 
     REGISTER_STRUCT(AttentionBatch)
         .def_readwrite("data", &AttentionBatch::data)
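With the three-argument `py::init<int, int, int>` replaced by a default constructor plus `def_readwrite` fields, Python callers now construct the config field by field. A sketch of what that looks like from the Python side (the wrapper function is illustrative; the repo routes this through `ParallelConfig.from_c`):

```python
from disagmoe_c import ParallelConfig

def make_parallel_config(tp, ep, n_exp_per_rank, expert_ranks):
    # expert_ranks: list of (layer_id, expert_id, expert_rank) tuples,
    # mirroring the new ParallelConfig::expert_ranks vector on the C++ side.
    cfg = ParallelConfig()
    cfg.tp = tp
    cfg.ep = ep
    cfg.n_exp_per_rank = n_exp_per_rank
    cfg.expert_ranks = expert_ranks
    return cfg
```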
7 changes: 5 additions & 2 deletions csrc/include/datatypes.hpp
@@ -652,6 +652,9 @@ struct ParallelConfig {
     int ep = 1;
     int n_exp_per_rank = 1;
 
-    ParallelConfig(int tp, int ep, int n_exp_per_rank):
-        tp(tp), ep(ep), n_exp_per_rank(n_exp_per_rank) {}
+    // (layer_id, expert_id, expert_rank)
+    std::vector<std::tuple<int, int, int>> expert_ranks = {};
+
+    ParallelConfig(int tp = 1, int ep = 1, int n_exp_per_rank = 1, const std::vector<std::tuple<int, int, int>> &expert_ranks = {}):
+        tp(tp), ep(ep), n_exp_per_rank(n_exp_per_rank), expert_ranks(expert_ranks) {}
 };
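Each `expert_ranks` entry pins one (layer, expert) pair to an expert-parallel rank, so different layers may place the same expert id on different GPUs. Purely for illustration (the real table comes from `get_model_placement`, and the actual zigzag pattern may differ), an alternating layout for 2 layers, 4 experts, and 2 EP ranks could be built like this:

```python
# Hypothetical alternating layout: even layers map experts [0,1] -> rank 0 and
# [2,3] -> rank 1; odd layers flip the halves.
expert_ranks = []
for layer_id in range(2):
    for exp_id in range(4):
        rank = exp_id // 2 if layer_id % 2 == 0 else 1 - exp_id // 2
        expert_ranks.append((layer_id, exp_id, rank))
```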
4 changes: 3 additions & 1 deletion csrc/include/muhelper.h
@@ -82,11 +82,13 @@ class MuAttnDispatcher: public MuDispatcher {
     std::vector<int> exp_channels;
     int max_exp_id;
 
+    std::vector<std::vector<int>> _inner_expert_ranks;
+
     void _send_once(TensorBatch batch) override;
 
     int _encode(int exp_layer_id, int exp_id) const;
 
-    int _get_rank(int exp_id) const;
+    int _get_rank(int exp_layer_id, int exp_id) const;
 
 public:
     MuAttnDispatcher(std::vector<int> layer_ids,
4 changes: 2 additions & 2 deletions csrc/muhelper/embedding.cpp
@@ -17,7 +17,7 @@ Sampler::Sampler(int device_id,
     MuExpertDispatcher(
         /*layer_ids=*/ {0},
         device_id,
-        ParallelConfig(1, 1, 1),
+        ParallelConfig(1, 1, 1, {}),
         out_channels,
         out_channel_infos
     ) {
@@ -176,7 +176,7 @@ void Sampler::start() {
 Tokenizer::Tokenizer(int device_id,
                      std::vector<Channel_t> channels,
                      std::vector<ChannelInfo> out_channel_infos):
-    MuExpertDispatcher({}, device_id, ParallelConfig(1, 1, 1), channels, out_channel_infos) {
+    MuExpertDispatcher({}, device_id, ParallelConfig(1, 1, 1, {}), channels, out_channel_infos) {
 }
 
 void Tokenizer::put_request(int req_id, torch::Tensor tensor) {
26 changes: 20 additions & 6 deletions csrc/muhelper/muhelper.cpp
@@ -144,8 +144,21 @@ MuAttnDispatcher::MuAttnDispatcher(
     }
     max_exp_id ++;
     DMOE_LOG(INFO) << "max_layer_id " << max_layer_id << ", max_exp_id " << max_exp_id << LEND;
-    exp_channels.resize(_encode(max_layer_id + 1, 0), -1);
+    exp_channels.resize((max_layer_id + 1) * max_exp_id, -1);
 
+    // get expert ranks
+    _inner_expert_ranks.resize(max_layer_id + 1);
+    for (int i = 0; i <= max_layer_id; i ++)
+        _inner_expert_ranks[i].resize(max_exp_id + 1, -1);
+    for (auto &tuple: cfg.expert_ranks) {
+        int layer_id = std::get<0>(tuple);
+        int exp_id = std::get<1>(tuple);
+        int rank = std::get<2>(tuple);
+        _inner_expert_ranks[layer_id][exp_id] = rank;
+        ASSERT(rank < max_exp_id);
+    }
+
+    // get expert channels
     for (int i = 0; i < channels.size(); i ++) {
         for (auto exp_id: out_channel_infos[i].expert_ids) {
             int id = _encode(exp_id.first, exp_id.second);
@@ -154,12 +167,13 @@
         }
     }
 
-inline int MuAttnDispatcher::_get_rank(int exp_id) const {
-    return exp_id / cfg.n_exp_per_rank;
+inline int MuAttnDispatcher::_get_rank(int exp_layer_id, int exp_id) const {
+    ASSERT(_inner_expert_ranks[exp_layer_id][exp_id] >= 0);
+    return _inner_expert_ranks[exp_layer_id][exp_id];
 }
 
 inline int MuAttnDispatcher::_encode(int exp_layer_id, int exp_id) const {
-    return exp_layer_id * this->max_exp_id + _get_rank(exp_id);
+    return exp_layer_id * this->max_exp_id + _get_rank(exp_layer_id, exp_id);
 }
 
 void MuAttnDispatcher::_send_once(TensorBatch batch) {
@@ -172,8 +186,8 @@ void MuAttnDispatcher::_send_once(TensorBatch batch) {
     int lid = batch.metadata->layer_id;
     for (int i = 0; i < n;) {
         int j = i + 1;
-        int ep_rank = _get_rank(batch.metadata->exp_ids[i]);
-        while (j < n && _get_rank(batch.metadata->exp_ids[j]) == ep_rank)
+        int ep_rank = _get_rank(lid, batch.metadata->exp_ids[i]);
+        while (j < n && _get_rank(lid, batch.metadata->exp_ids[j]) == ep_rank)
             j ++;
         ASSERT(ep_rank >= 0);
         int cid = _encode(lid, batch.metadata->exp_ids[i]);
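The behavioral core of this file's change: `_get_rank` used to derive the expert's rank arithmetically as `exp_id / n_exp_per_rank`, which forces the same expert-to-GPU mapping on every layer; it now consults the per-layer `_inner_expert_ranks` table populated from `cfg.expert_ranks`. A Python sketch of the two schemes side by side (names are ad hoc):

```python
def get_rank_contiguous(exp_id: int, n_exp_per_rank: int) -> int:
    # Old behavior: every layer uses the same contiguous mapping.
    return exp_id // n_exp_per_rank

def build_rank_table(expert_ranks, num_layers, num_experts):
    # New behavior: a (layer, expert) -> rank table, -1 meaning "not placed".
    table = [[-1] * num_experts for _ in range(num_layers)]
    for layer_id, exp_id, rank in expert_ranks:
        table[layer_id][exp_id] = rank
    return table

def get_rank(table, layer_id: int, exp_id: int) -> int:
    rank = table[layer_id][exp_id]
    assert rank >= 0  # mirrors the ASSERT in _get_rank
    return rank
```

Since `_encode` folds `_get_rank` into the channel index, tokens routed to the same expert id on different layers can now travel to different channels, which is what the zigzag placement needs.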
6 changes: 2 additions & 4 deletions disagmoe/config.py
@@ -16,7 +16,9 @@ class ModelConfig:
     dtype: torch.dtype
     ep_size: int
     tp_size: int = 1
+    dp_size: int = 1
     rank: int = 0
+    layer_ids: Optional[List[int]] = None
 
     tp_enable_inter_group: bool = True
     enable_cuda_graph: bool = False
@@ -25,10 +27,6 @@
     def num_experts_per_rank(self):
         return self.num_experts // self.ep_size
 
-    @property
-    def layer_ids(self):
-        return list(range(self.num_layers))
-
 @dataclass
 class CacheConfig(vllm.config.CacheConfig):
 
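`layer_ids` was previously a property that always returned `range(num_layers)`; making it an assignable field lets each engine hold only the layers its pipeline stage serves (it is filled in by `Engine.init_core` from the placement). An illustrative two-stage split, assuming a contiguous assignment; the actual layout depends on the placement strategy and flags such as `--zigzag-attn`:

```python
num_layers, step_attn = 32, 2
stride = num_layers // step_attn
stage_layer_ids = [list(range(s * stride, (s + 1) * stride)) for s in range(step_attn)]
# stage_layer_ids == [[0, ..., 15], [16, ..., 31]]; each stage's engine
# receives its own slice as ModelConfig.layer_ids.
```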
2 changes: 1 addition & 1 deletion disagmoe/executor/executor.py
@@ -24,7 +24,7 @@ class Executor:
 
     def __init__(self, model_config: ModelConfig):
         self.model_config = model_config
-        self.num_layers = model_config.num_layers
+        self.num_layers = len(model_config.layer_ids)
         self.layer_mappings = {
             id: i for i, id in enumerate(model_config.layer_ids)
         }
1 change: 1 addition & 0 deletions disagmoe/frontend/controller.py
@@ -244,6 +244,7 @@ def init_engine(self,
             device_group_ids=model_place.device_groups.get(device_id, []),
             group_nccl_ids=group_nccl_ids.get(
                 tuple(model_place.device_groups.get(device_id, [])), ("", "", "")),
+            expert_ranks=model_place.out_expert_ranks_at(device_id),
         )
         for worker, device_id in zip(
             self.workers + [self.sampler_worker, self.tokenizer_worker],
69 changes: 40 additions & 29 deletions disagmoe/frontend/engine.py
@@ -112,15 +112,37 @@ def init_core(
         out_device_group_ids: Dict[int, List[int]],
         out_nccl_ids: Dict[int, int],
         device_group_ids: List[int] = None,
-        group_nccl_ids: Tuple[str, str, str] = ("", "", "")
+        group_nccl_ids: Tuple[str, str, str] = ("", "", ""),
+        expert_ranks: List[Tuple[int, int, int]] = [],
     ):
         """
         NOTE(hogura|20241003): When using ray, all the device_id called to CUDA should become 0
         """
+        self.model_config.layer_ids = layer_ids
         self.device_group_ids = device_group_ids
 
+        if self.engine_type == EngineType.ATTENTION:
+            self.executor = AttnExecutor.build(self.model_config, self.cache_config)
+            self._process_batch = self.process_batch_attn
+            self.block_mgr = BlockManager_C(
+                self.cache_config.block_size,
+                self.cache_config.num_gpu_blocks,
+                self.cache_config.num_reserved_blocks
+            )
+            if self._intra_group_tp_enabled:
+                self._create_broadcast_buffers()
+        elif self.engine_type == EngineType.EXPERT:
+            self.executor = ExpertsExecutor(self.model_config)
+            self._process_batch = self.process_batch_expert
+            # prepare inner exp rank, [n_exp_per_rank * rank, (rank + 1) * n_exp_per_rank) -> [0, n_exp_per_rank)
+            self.inner_exp_rank = [0 for _ in range(self.model_config.num_experts_per_rank)]
+            for i in range(self.model_config.num_experts_per_rank):
+                self.inner_exp_rank[i] = self.model_config.num_experts_per_rank * self.rank_in_group + i
+
         if not self.model_config.tp_enable_inter_group:
             device_group_ids = None
             out_device_group_ids = {}
 
         self._logger.info(f"launching core: {layer_ids, in_device_ids, \
             out_device_ids, out_channel_infos, \
             in_nccl_ids, out_nccl_ids, out_device_group_ids, \
@@ -137,10 +159,11 @@
             [ChannelInfo_C(info.expert_ids, info.attn_layer_ids)
                 for info in out_channel_infos],
             # Parallel config
-            ParallelConfig_C(
+            ParallelConfig.from_c(
                 self.model_config.tp_size if self.model_config.tp_enable_inter_group else 1, # control the init of attn_scheduler
                 self.model_config.ep_size,
                 self.model_config.num_experts_per_rank,
+                expert_ranks,
             ),
             # Group Channels
             in_nccl_ids,
@@ -207,24 +230,6 @@ def setup_engine(self,
         self.model_config = model_config
         self.cache_config = cache_config
 
-        if engine_type == EngineType.ATTENTION:
-            self.executor = AttnExecutor.build(model_config, cache_config)
-            self._process_batch = self.process_batch_attn
-            self.block_mgr = BlockManager_C(
-                cache_config.block_size,
-                cache_config.num_gpu_blocks,
-                cache_config.num_reserved_blocks
-            )
-            if self._intra_group_tp_enabled:
-                self._create_broadcast_buffers()
-        elif engine_type == EngineType.EXPERT:
-            self.executor = ExpertsExecutor(model_config)
-            self._process_batch = self.process_batch_expert
-            # prepare inner exp rank, [n_exp_per_rank * rank, (rank + 1) * n_exp_per_rank) -> [0, n_exp_per_rank)
-            self.inner_exp_rank = [0 for _ in range(model_config.num_experts_per_rank)]
-            for i in range(model_config.num_experts_per_rank):
-                self.inner_exp_rank[i] = model_config.num_experts_per_rank * rank + i
-
         self._logger.info(f"engine setup. {self.engine_type, model_config}")
 
     def _create_cuda_graph_contexts(self):
@@ -244,7 +249,7 @@ def _create_cuda_graph_contexts(self):
                 (MAX_BATCH_SIZE + MAX_BATCH_SIZE + (MAX_BATCH_SIZE + 1) + (MAX_BATCH_SIZE + 1)),
                 dtype=torch.int32, device="cuda")
         else:
-            self.static_batch_sizes = torch.zeros((self.model_config.num_experts_per_rank, ), dtype=torch.long, device="cuda")
+            self.static_batch_sizes = torch.zeros((self.model_config.num_experts_per_rank, ), dtype=torch.long, device="cpu")
 
         for i in range(self.model_config.num_layers):
             self.graphs.append([
@@ -261,13 +266,14 @@ def _warmup_attn(self):
         meta_py = make_dummy_meta(MAX_BATCH_SIZE, 0)
         meta = self._pack_flash_attn_metadata(meta_py.to_c(), meta_py, [], mocking=True)
         # warmup for CUBLAS
-        self.static_output, self.static_expert_ids = self.executor.execute(
-            0, self.static_positions, self.static_input, meta)
+        for _ in range(5):
+            self.static_output, self.static_expert_ids = self.executor.execute(
+                self.model_config.layer_ids[0], self.static_positions, self.static_input, meta)
         if self.model_config.tp_size > 1:
             group_sync()
 
         meta.use_cuda_graph = True
-        for layer_id in range(self.model_config.num_layers):
+        for layer_id in self.model_config.layer_ids:
             for graph, graph_batch_size in zip(self.graphs[layer_id], GRAPH_BATCH_SIZES):
                 meta_py = make_dummy_meta(graph_batch_size, 0)
                 meta = self._pack_flash_attn_metadata(meta_py.to_c(), meta_py, [], mocking=True)
@@ -281,7 +287,8 @@
                 graph.replay()
 
     def _warmup_expert(self):
-        self.static_output = self.executor.execute(0, self.static_input, self.static_batch_sizes)
+        for _ in range(5):
+            self.static_output = self.executor.execute(0, self.static_input, self.static_batch_sizes)
         self._logger.warning("Expert CUDA Graph is not implemented yet")
         return
@@ -615,7 +622,7 @@ def process_batch_expert(self,
         batch_sizes = torch.tensor(
             [batch_sizes[i] for i in self.inner_exp_rank],
             dtype=torch.int64,
-            device="cuda", # NOTE(hogura|20241014): grouped_gemm requires batch_sizes to be on cpu
+            device="cpu", # NOTE(hogura|20241014): grouped_gemm requires batch_sizes to be on cpu
         )
 
         # self._logger.info(f"executing expert {meta_c.req_ids}")
@@ -693,7 +700,7 @@ def loop(self):
             for i, size in enumerate(pool_snapshot):
                 if size <= 0:
                     continue
-                layer = self.executor.layer_mappings[i]
+                layer = self.model_config.layer_ids[i]
                 pool_snapshot_dict[layer] = size
 
             self._step_stats.append(
@@ -772,7 +779,9 @@ def init_core(
         out_device_group_ids: Dict[int, List[int]],
         out_nccl_ids: Dict[int, int],
         device_group_ids: List[int] = None,
-        group_nccl_ids: str = ""):
+        group_nccl_ids: str = "",
+        expert_ranks: List[Tuple[int, int, int]] = [],
+    ):
         self.sampler = init_sampler(
             self.device_id,
             self.max_output_len,
@@ -843,7 +852,9 @@ def init_core(
         out_device_group_ids: Dict[int, List[int]],
         out_nccl_ids: Dict[int, int],
         device_group_ids: List[int] = None,
-        group_nccl_ids: str = ""):
+        group_nccl_ids: str = "",
+        expert_ranks: List[Tuple[int, int, int]] = [],
+    ):
         self.tokenizer = init_tokenizer(
             self.device_id,
             out_device_ids,
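Worth spelling out: `inner_exp_rank`, now built in `init_core`, is the list of global expert ids owned by this expert-parallel rank, and `process_batch_expert` uses it to pick this rank's entries out of the global per-expert batch-size vector. A standalone sketch with hypothetical numbers:

```python
def owned_expert_ids(rank_in_group: int, n_exp_per_rank: int) -> list:
    # Rank r owns the contiguous global ids [r * n, (r + 1) * n).
    return [n_exp_per_rank * rank_in_group + i for i in range(n_exp_per_rank)]

# 8 experts across 4 EP ranks, 2 experts per rank (illustrative values):
global_batch_sizes = [7, 0, 3, 5, 2, 1, 0, 4]
local = [global_batch_sizes[i] for i in owned_expert_ids(rank_in_group=1, n_exp_per_rank=2)]
# rank 1 owns experts [2, 3], so local == [3, 5]
```

The companion fixes in this file move `batch_sizes` (and the static CUDA-graph buffer) to the CPU, matching the note that grouped_gemm expects its batch-size tensor on the host, and repeat the warmup `execute` call five times, presumably so cuBLAS initialization settles before graph capture.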
2 changes: 1 addition & 1 deletion disagmoe/utils/constants.py
@@ -9,7 +9,7 @@
 INTERMEDIATE_SIZE = 4 * HIDDEN_SIZE
 
 BLOCK_SIZE = 32
-NUM_BLOCKS = 4096
+NUM_BLOCKS = 8000
 RESERVED_BLOCKS = 256
 
 CPS = 1e6
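For scale, assuming these constants feed the KV-cache block pool as in the engine's `BlockManager_C` setup: with 32-token blocks, raising `NUM_BLOCKS` from 4096 to 8000 roughly doubles the cache capacity.

```python
BLOCK_SIZE = 32
for num_blocks in (4096, 8000):
    # Total cacheable token slots per attention worker.
    print(num_blocks * BLOCK_SIZE)  # 131072 -> 256000
```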