Merge pull request #25 from USC-NSL/pipeline-placement
Pipeline placement with zigzag
shawlleyw authored Dec 17, 2024
2 parents f46cf93 + 2e53dac commit f724fcc
Showing 14 changed files with 353 additions and 111 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -173,4 +173,6 @@ third_party/

# reports
*.nsys-rep
reports/
reports/
steps_trace.json
*.png
30 changes: 24 additions & 6 deletions benchmark/benchmark_serving.py
@@ -47,19 +47,21 @@ def __repr__(self):
f"itl_latency_p99: {self.itl_latency_p99_ms:.2f}ms\n"

def launch(args):
cluster_config = ClusterConfig(n_node=1, n_gpu=3,
cluster_config = ClusterConfig(n_node=args.num_nodes, n_gpu=args.num_gpus,
id_tokenizer=tokenizer,
id_sampler=sampler)

model_config = duo_expert_mixtral
model_config.num_layers = 32
model_config.ep_size = 2
model_config.num_experts = 8
model_config.tp_size = 1
model_config.num_layers = args.num_layers
model_config.ep_size = args.ep_size
model_config.tp_size = args.tp_size
model_config.tp_enable_inter_group = False
model_config.enable_cuda_graph = args.cuda_graph

mp = get_model_placement(model_config, cluster_config, "interleave")
mp = get_model_placement(model_config, cluster_config, args.placement,
step_attn=args.step_attn, step_expert=args.step_expert,
zigzag_attn=args.zigzag_attn)
# mp = get_model_placement(model_config, cluster_config, "interleave")

global master

@@ -202,8 +204,24 @@ def get_args():
parser.add_argument("-c", "--cuda-graph", action="store_true", default=False, help="enable cuda graph")
parser.add_argument("--nsys", action="store_true", help="enable nsys profiling")

# model config
parser.add_argument("-g", "--num-gpus", type=int, default=4, help="number of gpus per node")
parser.add_argument("-N", "--num-nodes", type=int, default=1, help="number of nodes")
parser.add_argument("--tp-size", type=int, default=1, help="tensor parallel size")
parser.add_argument("--ep-size", type=int, default=2, help="expert parallel size")
parser.add_argument("--num-layers", type=int, default=32, help="number of layers")

# placement config
parser.add_argument("--placement", type=str, default="pipeline", help="placement strategy")
parser.add_argument("--zigzag-attn", action="store_true", default=True, help="enable zigzag attention placment")
parser.add_argument("--step-attn", type=int, default=2, help="number of steps in attention placement")
parser.add_argument("--step-expert", type=int, default=1, help="number of steps in expert placement")

args = parser.parse_args()

if args.num_gpus % (args.tp_size * args.step_attn + args.ep_size * args.step_expert) != 0:
print("Warning: number of gpus is not divisible by the number of placement steps")

if args.nsys:
assert args.profile_dir is None, "cannot enable both nsys and torch profiler"

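For reference, a standalone sketch of the GPU-count check added in get_args() above: each placement "round" is assumed to need tp_size GPUs per attention step and ep_size GPUs per expert step, so num_gpus should tile that total evenly. The helper name and this interpretation are illustrative, not part of the benchmark itself.

```python
# Minimal sketch of the divisibility check in get_args(); names are illustrative.
def gpus_tile_placement(num_gpus: int, tp_size: int, step_attn: int,
                        ep_size: int, step_expert: int) -> bool:
    # tp_size GPUs per attention step + ep_size GPUs per expert step
    gpus_per_round = tp_size * step_attn + ep_size * step_expert
    return num_gpus % gpus_per_round == 0

# Matches the defaults above: 1*2 + 2*1 = 4 GPUs per node, so no warning is printed.
assert gpus_tile_placement(num_gpus=4, tp_size=1, step_attn=2, ep_size=2, step_expert=1)
```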
10 changes: 8 additions & 2 deletions csrc/bindings.cpp
@@ -1,4 +1,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/functional.h>
#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/stl.h>

#include "tests.h"
@@ -13,6 +16,8 @@
#define REGISTER_STRUCT(name, ...) py::class_<name>(m, #name).def(py::init<__VA_ARGS__>())
#define REGISTER_FUNC(name) m.def(#name, &name)

PYBIND11_MAKE_OPAQUE(std::map<std::pair<int, int>, int>);

namespace py = pybind11;

PYBIND11_MODULE(disagmoe_c, m) {
@@ -67,10 +72,11 @@ PYBIND11_MODULE(disagmoe_c, m) {
REGISTER_STRUCT(TokenMetadata);

py::class_<ParallelConfig>(m, "ParallelConfig")
.def(py::init<int, int, int>())
.def(py::init<>())
.def_readwrite("tp", &ParallelConfig::tp)
.def_readwrite("ep", &ParallelConfig::ep)
.def_readwrite("n_exp_per_rank", &ParallelConfig::n_exp_per_rank);
.def_readwrite("n_exp_per_rank", &ParallelConfig::n_exp_per_rank)
.def_readwrite("expert_ranks", &ParallelConfig::expert_ranks);

REGISTER_STRUCT(AttentionBatch)
.def_readwrite("data", &AttentionBatch::data)
7 changes: 5 additions & 2 deletions csrc/include/datatypes.hpp
@@ -652,6 +652,9 @@ struct ParallelConfig {
int ep = 1;
int n_exp_per_rank = 1;

ParallelConfig(int tp, int ep, int n_exp_per_rank):
tp(tp), ep(ep), n_exp_per_rank(n_exp_per_rank) {}
// (layer_id, expert_id, expert_rank)
std::vector<std::tuple<int, int, int>> expert_ranks = {};

ParallelConfig(int tp = 1, int ep = 1, int n_exp_per_rank = 1, const std::vector<std::tuple<int, int, int>> &expert_ranks = {}):
tp(tp), ep(ep), n_exp_per_rank(n_exp_per_rank), expert_ranks(expert_ranks) {}
};
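Since the pybind11 binding above now exposes a default constructor plus read/write fields, the struct would presumably be populated attribute-wise from Python. A hedged sketch, assuming the disagmoe_c extension is built and importable; all values below are made up:

```python
from disagmoe_c import ParallelConfig  # pybind11 module defined in csrc/bindings.cpp

cfg = ParallelConfig()      # default-constructed, per .def(py::init<>())
cfg.tp = 1
cfg.ep = 2
cfg.n_exp_per_rank = 4
# (layer_id, expert_id, expert_rank) triples, as documented in datatypes.hpp
cfg.expert_ranks = [(0, 0, 0), (0, 4, 1), (1, 0, 1), (1, 4, 0)]
```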
4 changes: 3 additions & 1 deletion csrc/include/muhelper.h
@@ -82,11 +82,13 @@ class MuAttnDispatcher: public MuDispatcher {
std::vector<int> exp_channels;
int max_exp_id;

std::vector<std::vector<int>> _inner_expert_ranks;

void _send_once(TensorBatch batch) override;

int _encode(int exp_layer_id, int exp_id) const;

int _get_rank(int exp_id) const;
int _get_rank(int exp_layer_id, int exp_id) const;

public:
MuAttnDispatcher(std::vector<int> layer_ids,
4 changes: 2 additions & 2 deletions csrc/muhelper/embedding.cpp
@@ -17,7 +17,7 @@ Sampler::Sampler(int device_id,
MuExpertDispatcher(
/*layer_ids=*/ {0},
device_id,
ParallelConfig(1, 1, 1),
ParallelConfig(1, 1, 1, {}),
out_channels,
out_channel_infos
) {
@@ -176,7 +176,7 @@ void Sampler::start() {
Tokenizer::Tokenizer(int device_id,
std::vector<Channel_t> channels,
std::vector<ChannelInfo> out_channel_infos):
MuExpertDispatcher({}, device_id, ParallelConfig(1, 1, 1), channels, out_channel_infos) {
MuExpertDispatcher({}, device_id, ParallelConfig(1, 1, 1, {}), channels, out_channel_infos) {
}

void Tokenizer::put_request(int req_id, torch::Tensor tensor) {
26 changes: 20 additions & 6 deletions csrc/muhelper/muhelper.cpp
@@ -144,8 +144,21 @@ MuAttnDispatcher::MuAttnDispatcher(
}
max_exp_id ++;
DMOE_LOG(INFO) << "max_layer_id " << max_layer_id << ", max_exp_id " << max_exp_id << LEND;
exp_channels.resize(_encode(max_layer_id + 1, 0), -1);
exp_channels.resize((max_layer_id + 1) * max_exp_id, -1);

// get expert ranks
_inner_expert_ranks.resize(max_layer_id + 1);
for (int i = 0; i <= max_layer_id; i ++)
_inner_expert_ranks[i].resize(max_exp_id + 1, -1);
for (auto &tuple: cfg.expert_ranks) {
int layer_id = std::get<0>(tuple);
int exp_id = std::get<1>(tuple);
int rank = std::get<2>(tuple);
_inner_expert_ranks[layer_id][exp_id] = rank;
ASSERT(rank < max_exp_id);
}

// get expert channels
for (int i = 0; i < channels.size(); i ++) {
for (auto exp_id: out_channel_infos[i].expert_ids) {
int id = _encode(exp_id.first, exp_id.second);
@@ -154,12 +167,13 @@
}
}

inline int MuAttnDispatcher::_get_rank(int exp_id) const {
return exp_id / cfg.n_exp_per_rank;
inline int MuAttnDispatcher::_get_rank(int exp_layer_id, int exp_id) const {
ASSERT(_inner_expert_ranks[exp_layer_id][exp_id] >= 0);
return _inner_expert_ranks[exp_layer_id][exp_id];
}

inline int MuAttnDispatcher::_encode(int exp_layer_id, int exp_id) const {
return exp_layer_id * this->max_exp_id + _get_rank(exp_id);
return exp_layer_id * this->max_exp_id + _get_rank(exp_layer_id, exp_id);
}

void MuAttnDispatcher::_send_once(TensorBatch batch) {
@@ -172,8 +186,8 @@ void MuAttnDispatcher::_send_once(TensorBatch batch) {
int lid = batch.metadata->layer_id;
for (int i = 0; i < n;) {
int j = i + 1;
int ep_rank = _get_rank(batch.metadata->exp_ids[i]);
while (j < n && _get_rank(batch.metadata->exp_ids[j]) == ep_rank)
int ep_rank = _get_rank(lid, batch.metadata->exp_ids[i]);
while (j < n && _get_rank(lid, batch.metadata->exp_ids[j]) == ep_rank)
j ++;
ASSERT(ep_rank >= 0);
int cid = _encode(lid, batch.metadata->exp_ids[i]);
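The substantive change above is that MuAttnDispatcher no longer derives an expert's rank as exp_id / n_exp_per_rank but looks it up in a per-layer table filled from cfg.expert_ranks, which is what lets a zigzag placement put the same expert id on different ranks at different layers. A pure-Python mirror of that lookup with illustrative sizes (a sketch, not the C++ itself):

```python
def build_rank_table(expert_ranks, num_layers, num_experts):
    """(layer_id, expert_id) -> expert-parallel rank; -1 means unplaced."""
    table = [[-1] * num_experts for _ in range(num_layers)]
    for layer_id, exp_id, rank in expert_ranks:
        table[layer_id][exp_id] = rank
    return table

def encode(table, layer_id, exp_id, num_expert_ranks):
    """Flat channel index: one outgoing channel per (layer, expert rank)."""
    rank = table[layer_id][exp_id]
    assert rank >= 0, "expert must be placed before dispatch"
    return layer_id * num_expert_ranks + rank

# Two layers, four experts, two expert ranks; layer 1 swaps the expert groups
# relative to layer 0 (a zigzag-style placement).
placement = [(0, 0, 0), (0, 1, 0), (0, 2, 1), (0, 3, 1),
             (1, 0, 1), (1, 1, 1), (1, 2, 0), (1, 3, 0)]
table = build_rank_table(placement, num_layers=2, num_experts=4)
assert encode(table, layer_id=1, exp_id=3, num_expert_ranks=2) == 2  # layer 1, rank 0
```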
6 changes: 2 additions & 4 deletions disagmoe/config.py
@@ -16,7 +16,9 @@ class ModelConfig:
dtype: torch.dtype
ep_size: int
tp_size: int = 1
dp_size: int = 1
rank: int = 0
layer_ids: Optional[List[int]] = None

tp_enable_inter_group: bool = True
enable_cuda_graph: bool = False
@@ -25,10 +27,6 @@ def num_experts_per_rank(self):
def num_experts_per_rank(self):
return self.num_experts // self.ep_size

@property
def layer_ids(self):
return list(range(self.num_layers))

@dataclass
class CacheConfig(vllm.config.CacheConfig):

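With layer_ids now an explicit optional field rather than a property derived from num_layers, each engine can be handed only the layers its pipeline stage owns. A rough illustration of the intent; the even split below is an assumption, not the placement algorithm in this PR:

```python
def split_layers(num_layers: int, num_stages: int) -> list:
    """Evenly partition layer ids across pipeline stages (illustrative only)."""
    per_stage = num_layers // num_stages
    return [list(range(s * per_stage, (s + 1) * per_stage)) for s in range(num_stages)]

assert split_layers(num_layers=32, num_stages=2) == [list(range(16)), list(range(16, 32))]
```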
2 changes: 1 addition & 1 deletion disagmoe/executor/executor.py
@@ -24,7 +24,7 @@ class Executor:

def __init__(self, model_config: ModelConfig):
self.model_config = model_config
self.num_layers = model_config.num_layers
self.num_layers = len(model_config.layer_ids)
self.layer_mappings = {
id: i for i, id in enumerate(model_config.layer_ids)
}
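Correspondingly, the executor now sizes itself by len(layer_ids) and keeps a mapping from global layer id to its local slot. A tiny sketch with made-up layer ids:

```python
layer_ids = [3, 11, 19, 27]   # e.g. the slice of layers this executor owns
layer_mappings = {layer_id: i for i, layer_id in enumerate(layer_ids)}

assert len(layer_ids) == 4 and layer_mappings[19] == 2  # global layer 19 -> local slot 2
```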
1 change: 1 addition & 0 deletions disagmoe/frontend/controller.py
@@ -244,6 +244,7 @@ def init_engine(self,
device_group_ids=model_place.device_groups.get(device_id, []),
group_nccl_ids=group_nccl_ids.get(
tuple(model_place.device_groups.get(device_id, [])), ("", "", "")),
expert_ranks=model_place.out_expert_ranks_at(device_id),
)
for worker, device_id in zip(
self.workers + [self.sampler_worker, self.tokenizer_worker],
69 changes: 40 additions & 29 deletions disagmoe/frontend/engine.py
@@ -112,15 +112,37 @@ def init_core(
out_device_group_ids: Dict[int, List[int]],
out_nccl_ids: Dict[int, int],
device_group_ids: List[int] = None,
group_nccl_ids: Tuple[str, str, str] = ("", "", "")
group_nccl_ids: Tuple[str, str, str] = ("", "", ""),
expert_ranks: List[Tuple[int, int, int]] = [],
):
"""
NOTE(hogura|20241003): When using ray, all the device_id called to CUDA should become 0
"""
self.model_config.layer_ids = layer_ids
self.device_group_ids = device_group_ids

if self.engine_type == EngineType.ATTENTION:
self.executor = AttnExecutor.build(self.model_config, self.cache_config)
self._process_batch = self.process_batch_attn
self.block_mgr = BlockManager_C(
self.cache_config.block_size,
self.cache_config.num_gpu_blocks,
self.cache_config.num_reserved_blocks
)
if self._intra_group_tp_enabled:
self._create_broadcast_buffers()
elif self.engine_type == EngineType.EXPERT:
self.executor = ExpertsExecutor(self.model_config)
self._process_batch = self.process_batch_expert
# prepare inner exp rank, [n_exp_per_rank * rank, (rank + 1) * n_exp_per_rank) -> [0, n_exp_per_rank)
self.inner_exp_rank = [0 for _ in range(self.model_config.num_experts_per_rank)]
for i in range(self.model_config.num_experts_per_rank):
self.inner_exp_rank[i] = self.model_config.num_experts_per_rank * self.rank_in_group + i

if not self.model_config.tp_enable_inter_group:
device_group_ids = None
out_device_group_ids = {}

self._logger.info(f"launching core: {layer_ids, in_device_ids, \
out_device_ids, out_channel_infos, \
in_nccl_ids, out_nccl_ids, out_device_group_ids, \
@@ -137,10 +159,11 @@ def init_core(
[ChannelInfo_C(info.expert_ids, info.attn_layer_ids)
for info in out_channel_infos],
# Parallel config
ParallelConfig_C(
ParallelConfig.from_c(
self.model_config.tp_size if self.model_config.tp_enable_inter_group else 1, # control the init of attn_scheduler
self.model_config.ep_size,
self.model_config.num_experts_per_rank,
expert_ranks,
),
# Group Channels
in_nccl_ids,
@@ -207,24 +230,6 @@ def setup_engine(self,
self.model_config = model_config
self.cache_config = cache_config

if engine_type == EngineType.ATTENTION:
self.executor = AttnExecutor.build(model_config, cache_config)
self._process_batch = self.process_batch_attn
self.block_mgr = BlockManager_C(
cache_config.block_size,
cache_config.num_gpu_blocks,
cache_config.num_reserved_blocks
)
if self._intra_group_tp_enabled:
self._create_broadcast_buffers()
elif engine_type == EngineType.EXPERT:
self.executor = ExpertsExecutor(model_config)
self._process_batch = self.process_batch_expert
# prepare inner exp rank, [n_exp_per_rank * rank, (rank + 1) * n_exp_per_rank) -> [0, n_exp_per_rank)
self.inner_exp_rank = [0 for _ in range(model_config.num_experts_per_rank)]
for i in range(model_config.num_experts_per_rank):
self.inner_exp_rank[i] = model_config.num_experts_per_rank * rank + i

self._logger.info(f"engine setup. {self.engine_type, model_config}")

def _create_cuda_graph_contexts(self):
@@ -244,7 +249,7 @@ def _create_cuda_graph_contexts(self):
(MAX_BATCH_SIZE + MAX_BATCH_SIZE + (MAX_BATCH_SIZE + 1) + (MAX_BATCH_SIZE + 1)),
dtype=torch.int32, device="cuda")
else:
self.static_batch_sizes = torch.zeros((self.model_config.num_experts_per_rank, ), dtype=torch.long, device="cuda")
self.static_batch_sizes = torch.zeros((self.model_config.num_experts_per_rank, ), dtype=torch.long, device="cpu")

for i in range(self.model_config.num_layers):
self.graphs.append([
@@ -261,13 +266,14 @@ def _warmup_attn(self):
meta_py = make_dummy_meta(MAX_BATCH_SIZE, 0)
meta = self._pack_flash_attn_metadata(meta_py.to_c(), meta_py, [], mocking=True)
# warmup for CUBLAS
self.static_output, self.static_expert_ids = self.executor.execute(
0, self.static_positions, self.static_input, meta)
for _ in range(5):
self.static_output, self.static_expert_ids = self.executor.execute(
self.model_config.layer_ids[0], self.static_positions, self.static_input, meta)
if self.model_config.tp_size > 1:
group_sync()

meta.use_cuda_graph = True
for layer_id in range(self.model_config.num_layers):
for layer_id in self.model_config.layer_ids:
for graph, graph_batch_size in zip(self.graphs[layer_id], GRAPH_BATCH_SIZES):
meta_py = make_dummy_meta(graph_batch_size, 0)
meta = self._pack_flash_attn_metadata(meta_py.to_c(), meta_py, [], mocking=True)
@@ -281,7 +287,8 @@
graph.replay()

def _warmup_expert(self):
self.static_output = self.executor.execute(0, self.static_input, self.static_batch_sizes)
for _ in range(5):
self.static_output = self.executor.execute(0, self.static_input, self.static_batch_sizes)
self._logger.warning("Expert CUDA Graph is not implemented yet")
return

@@ -615,7 +622,7 @@ def process_batch_expert(self,
batch_sizes = torch.tensor(
[batch_sizes[i] for i in self.inner_exp_rank],
dtype=torch.int64,
device="cuda", # NOTE(hogura|20241014): grouped_gemm requires batch_sizes to be on cpu
device="cpu", # NOTE(hogura|20241014): grouped_gemm requires batch_sizes to be on cpu
)

# self._logger.info(f"executing expert {meta_c.req_ids}")
@@ -693,7 +700,7 @@ def loop(self):
for i, size in enumerate(pool_snapshot):
if size <= 0:
continue
layer = self.executor.layer_mappings[i]
layer = self.model_config.layer_ids[i]
pool_snapshot_dict[layer] = size

self._step_stats.append(
@@ -772,7 +779,9 @@ def init_core(
out_device_group_ids: Dict[int, List[int]],
out_nccl_ids: Dict[int, int],
device_group_ids: List[int] = None,
group_nccl_ids: str = ""):
group_nccl_ids: str = "",
expert_ranks: List[Tuple[int, int, int]] = [],
):
self.sampler = init_sampler(
self.device_id,
self.max_output_len,
@@ -843,7 +852,9 @@ def init_core(
out_device_group_ids: Dict[int, List[int]],
out_nccl_ids: Dict[int, int],
device_group_ids: List[int] = None,
group_nccl_ids: str = ""):
group_nccl_ids: str = "",
expert_ranks: List[Tuple[int, int, int]] = [],
):
self.tokenizer = init_tokenizer(
self.device_id,
out_device_ids,
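For the expert engine, init_core above builds inner_exp_rank so that the global experts owned by this rank, [n_exp_per_rank * rank, n_exp_per_rank * (rank + 1)), map onto local slots [0, n_exp_per_rank); the per-expert batch sizes are then kept on the CPU for grouped_gemm, per the comment in the diff. A small worked example (values illustrative):

```python
n_exp_per_rank, rank_in_group = 4, 1
inner_exp_rank = [n_exp_per_rank * rank_in_group + i for i in range(n_exp_per_rank)]

assert inner_exp_rank == [4, 5, 6, 7]  # this rank serves global experts 4..7
```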
2 changes: 1 addition & 1 deletion disagmoe/utils/constants.py
@@ -9,7 +9,7 @@
INTERMEDIATE_SIZE = 4 * HIDDEN_SIZE

BLOCK_SIZE = 32
NUM_BLOCKS = 4096
NUM_BLOCKS = 8000
RESERVED_BLOCKS = 256

CPS = 1e6