diff --git a/chatlearn/models/base_module.py b/chatlearn/models/base_module.py index c56c6cb9..69dc1b65 100644 --- a/chatlearn/models/base_module.py +++ b/chatlearn/models/base_module.py @@ -33,6 +33,7 @@ from chatlearn.checkpoint.checkpoint_manager import CheckpointManager from chatlearn.utils import future from chatlearn.utils.dist_utils import bucket_tensors, coalesced_comm_dense +from chatlearn.utils.dist_utils import bucket_tensors_two_stage, coalesced_comm_dense_two_stage from chatlearn.utils.global_vars import get_args from chatlearn.utils.global_vars import set_global_variables from chatlearn.utils.logger import log_rank_0, debug_rank_0, setup_logger @@ -106,6 +107,9 @@ def __init__(self, name, args=None, replica_id=0): self._data_ckpt_manager = None self._peak_memory = 0 self._parameters_to_sync = defaultdict(list) + self._parameters_to_send = defaultdict(list) + self._parameters_to_recv = defaultdict(list) + self._parameters_shape = [] self._concat_params_dict = None self._to_fix_act_ordering_dict = None self._to_fix_qkv_ordering_dict = None @@ -124,6 +128,32 @@ def __init__(self, name, args=None, replica_id=0): # parameter sync from src_model self._src_parameter_model = None self.profiler = None + self._buffer_num = {} + self._tp_division = {} + self._num_mapping = 1 + self._sync_buffer = defaultdict(list) + + def get_sync_buffer(self): + return self._sync_buffer + + def set_num_mapping(self, _num_mapping): + self._num_mapping = _num_mapping + + @property + def num_mapping(self): + return self._num_mapping + + def set_buffer_num(self, buffer_num): + self._buffer_num.update(buffer_num) + + def get_buffer_num(self, param_names): + return [self._buffer_num[name] for name in param_names] + + def set_tp_division(self, tp_division): + self._tp_division.update(tp_division) + + def get_tp_division(self, param_names): + return [self._tp_division[name] for name in param_names] @property def is_colocate(self): @@ -621,11 +651,13 @@ def get_to_fix_qkv_ordering_func(self): def set_to_fix_qkv_ordering_func(self, _to_fix_qkv_ordering_func): self._to_fix_qkv_ordering_func = _to_fix_qkv_ordering_func - def set_sync_parameters(self, trainable_param_names, pipe_stage=0): + def set_sync_parameters(self, trainable_param_names, pipe_stage=0, parameters_to_sync=None): """ :meta private: """ - if pipe_stage not in self._parameters_to_sync or len(self._parameters_to_sync[pipe_stage]) == 0: # pylint: disable=too-many-nested-blocks + if parameters_to_sync is None: + parameters_to_sync = self._parameters_to_sync + if pipe_stage not in parameters_to_sync or len(parameters_to_sync[pipe_stage]) == 0: # pylint: disable=too-many-nested-blocks concat = [] set_sync_param_flag = False @@ -719,8 +751,21 @@ def set_sync_parameters(self, trainable_param_names, pipe_stage=0): _params_to_sync = _params_to_sync.contiguous() concat = [] set_sync_param_flag = False - self._parameters_to_sync[pipe_stage].append(_params_to_sync) + parameters_to_sync[pipe_stage].append((name, _params_to_sync)) + def set_send_parameters(self, trainable_param_names, pipe_stage=0): + """ + :meta private: + """ + return self.set_sync_parameters(trainable_param_names, pipe_stage, self._parameters_to_send) + + def set_recv_parameters(self, rank, trainable_param_names, pipe_stage=0): + """ + :meta private: + """ + parameters_to_recv = defaultdict(list) + self._parameters_to_recv[rank] = parameters_to_recv + return self.set_sync_parameters(trainable_param_names, pipe_stage, parameters_to_recv) def get_parameter_names(self, requires_grad=True): """ 
@@ -732,6 +777,15 @@ def get_parameter_names(self, requires_grad=True): else: return [param_to_name[param] for param in self.parameters] + def get_parameter_shape(self, param_names): + """ + :meta private: + """ + parameters_shape = [] + for name in param_names: + parameters_shape.append((name, self.named_parameters[name].shape)) + return parameters_shape + def get_parameter(self, name): """ :meta private: @@ -774,7 +828,7 @@ def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0): """ :meta private: """ - tensors = [param.data for param in self._parameters_to_sync[pipe_stage]] + tensors = [param.data for _, param in self._parameters_to_sync[pipe_stage]] assert len(tensors) > 0 dense_buckets, sparse_bucket = bucket_tensors(tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb) debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, spase_bucket {len(sparse_bucket)}", self._logger) @@ -786,6 +840,102 @@ def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0): for param in sparse_bucket: col.broadcast(param, src_rank, group_name) + def broadcast_parameter_two_stage(self, to_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False): + """ + :meta private: + """ + tensor_changed = rank != src_rank + + if stage2: + if tensor_changed: + parameters_to_sync = self._parameters_to_recv[rank] + else: + parameters_to_sync = self._parameters_to_send + else: + del self._sync_buffer + self._sync_buffer = defaultdict(list) + parameters_to_sync = self._parameters_to_sync + + tensors = [] + buffer_num = [] + if stage2 and not tensor_changed and self._sync_buffer: + idx = 0 + for name, param in parameters_to_sync[pipe_stage]: + tensors.append(self._sync_buffer[(to_rank + 1) % self.num_mapping][idx]) + buffer_num.append(1) + idx += 1 + del self._sync_buffer[(to_rank + 1) % self.num_mapping] + else: + for name, param in parameters_to_sync[pipe_stage]: + param_data = param.data + param_data_shape = param_data.shape + if rank and self._buffer_num and not stage2: + assert name in self._buffer_num, f"{name} in self._buffer_num for rank {rank}" + buffer_num.append(self._buffer_num[name]) + elif stage2: + buffer_num.append(1) + else: + if "attention.query_key_value" in name or "self_attention.query_key_value" in name: + tp_size = self.module_args.args_dict["tensor_model_parallel_size"] + heads = self.module_args.args_dict["num_attention_heads"] // tp_size + hidden_size_per_head = self.module_args.args_dict["hidden_size"] // self.module_args.args_dict["num_attention_heads"] + param_shape = (3, heads, hidden_size_per_head) + param_data_shape[1:] + param_data = param_data.view(param_shape) + param_data_list = [] + head_offset = heads // self._tp_division[name] + for idx in range(self._tp_division[name]): + start = idx * head_offset + end = start + head_offset + param_data_list.append(param_data[:,start:end]) + param_data = torch.concat(param_data_list, dim=0).view(param_data_shape) + del param_data_list + + if "self_attention.dense" in name or "mlp.dense_4h_to_h" in name: + param_data_list = [] + col_offset = param_data_shape[1] // self._tp_division[name] + for idx in range(self._tp_division[name]): + start = idx * col_offset + end = start + col_offset + param_data_list.append(param_data[:,start:end]) + param_data = torch.concat(param_data_list, dim=0).view(param_data_shape) + del param_data_list + if "mlp.dense_h_to_4h" in name: + param_data_list = [] + row_offset = param_data_shape[0] // self._tp_division[name] // 2 + for idx in range(self._tp_division[name]): + 
w1_start = idx * row_offset + w1_end = w1_start + row_offset + w2_start = (idx + self._tp_division[name]) * row_offset + w2_end = w2_start + row_offset + param_data_list.append( + torch.concat([param_data[w1_start:w1_end,:], param_data[w2_start:w2_end,:]], dim=0)) + param_data = torch.concat(param_data_list, dim=0).view(param_data_shape) + del param_data_list + buffer_num.append(1) + tensors.append(param_data) + + assert len(tensors) > 0 + dense_buckets, sparse_bucket = bucket_tensors_two_stage( + tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb, + buffer_num=None if stage2 else buffer_num, tensor_changed=tensor_changed and not stage2) + debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, spase_bucket {len(sparse_bucket)}", self._logger) + + for bucket in dense_buckets: + index = 0 if stage2 else (to_rank % self.num_mapping) + all_buffers = coalesced_comm_dense_two_stage( + bucket, col.broadcast, rank, + extra_args=(src_rank, group_name), tensor_changed=tensor_changed, + stage2=stage2, index=index) + if tensor_changed and not stage2: + for key, value in all_buffers.items(): + self._sync_buffer[key] += value + + for param in sparse_bucket: + col.broadcast(param, src_rank, group_name) + + self.empty_cache() + return self._sync_buffer + def send_parameter(self, name, dst_rank, group_name, pipe_stage=0): """ diff --git a/chatlearn/models/vllm_module.py b/chatlearn/models/vllm_module.py index 816033c4..8f15e8d9 100644 --- a/chatlearn/models/vllm_module.py +++ b/chatlearn/models/vllm_module.py @@ -326,9 +326,10 @@ def empty_cache(self): self.worker.cache_engine.cpu_cache = None self.worker.cache_engine.gpu_cache = None elif CURRENT_VLLM_VERSION == VLLMVersion.v_0_5_1.value: - for ele in self.worker.gpu_cache: # pylint: disable=unused-variable - ele = None - self.worker.gpu_cache = None # pylint: disable=access-member-before-definition + if self.worker.gpu_cache is not None: + for ele in self.worker.gpu_cache: # pylint: disable=unused-variable + ele = None + self.worker.gpu_cache = None # pylint: disable=access-member-before-definition for c_e in self.worker.cache_engine: c_e.cpu_cache = None @@ -573,14 +574,14 @@ def data_parallel_size(self): """ :meta private: """ - return None + return 1 @property def data_parallel_rank(self): """ :meta private: """ - return None + return 0 def tensor_parallel_rank(self): """ diff --git a/chatlearn/runtime/parameter_sync.py b/chatlearn/runtime/parameter_sync.py index e274178b..44bb7a08 100644 --- a/chatlearn/runtime/parameter_sync.py +++ b/chatlearn/runtime/parameter_sync.py @@ -50,6 +50,8 @@ def __init__(self, src_model, dst_model, group_name, frequency, error_signal): self.error_signal = error_signal self.send_recv_actor_mappings = defaultdict(list) self.recv_send_actor_mappings = defaultdict(list) + self.send_recv_actor_mappings_stage2 = defaultdict(list) + self.recv_send_actor_mappings_stage2 = defaultdict(list) self.actor2rank = {} self._debug = get_args().runtime_args.debug self._num_src_pipeline_stage = None @@ -69,7 +71,12 @@ def __init__(self, src_model, dst_model, group_name, frequency, error_signal): logger.warning("Only support PARAM_SYNC_COMM_TYPE.BROADCAST when TP SIZE is even number, use P2P instead") self._comm_type = PARAM_SYNC_COMM_TYPE.P2P self.setup_collective_group() - self.build_rank_mapping() + self.num_mapping = self.num_dst_tensor_parallel // self.num_src_tensor_parallel + if self.num_mapping == 1: + self.build_rank_mapping() + else: + self.build_rank_mapping_two_stage() + self.enable_coalesce_param = 
get_args().runtime_args.coalesce_param self.concurrent_comm = get_args().runtime_args.concurrent_comm self._enable_lora = self.src_model.module_args.lora.enable_lora @@ -81,6 +88,7 @@ def __init__(self, src_model, dst_model, group_name, frequency, error_signal): self.collective_groups = [] self.src_dp_size = future.get(self.src_model.replicas[0].all_actors[0].get_data_parallel_size.remote()) self.sorted_send_actors = None + self.sorted_send_actors_stage2 = None def get_group_name(self, actors): return f"{self.group_name}_" + "_".join(str(self.actor2rank[actor]) for actor in actors) @@ -174,10 +182,28 @@ def add_recv_actor(self, src_rank, dst_rank): logger.debug(f"build rank mapping from {src_rank} to {dst_rank}, from gpu {src_gpu} to {dst_gpu}, " + \ f"from pipe_stage {src_pp_rank} to {dst_pp_rank}, " + \ f"from tp rank {src_tp_rank} to {dst_tp_rank}") - assert src_tp_rank == dst_tp_rank, f"src_tp_rank {src_tp_rank} should be same as dst_tp_rank {dst_tp_rank}" self.send_recv_actor_mappings[src_actor].append(dst_actor) self.recv_send_actor_mappings[dst_actor].append(src_actor) + def add_recv_actor_stage2(self, src_rank, dst_rank): + src_actor = self.dst_model.get_actor(src_rank) + self.actor2rank[src_actor] = src_rank + dst_actor = self.dst_model.get_actor(dst_rank) + self.actor2rank[dst_actor] = dst_rank + + src_gpu = future.get(src_actor.get_visible_gpus.remote()) + dst_gpu = future.get(dst_actor.get_visible_gpus.remote()) + # TODO(jiangle.jl): support ep/cp. + src_tp_rank = self.get_actor_tp_rank(src_actor) + dst_tp_rank = self.get_actor_tp_rank(dst_actor) + src_pp_rank = self.get_actor_pipe_rank(src_actor) + dst_pp_rank = self.get_actor_pipe_rank(dst_actor) + logger.debug(f"build rank mapping from {src_rank} to {dst_rank}, from gpu {src_gpu} to {dst_gpu}, " + \ + f"from pipe_stage {src_pp_rank} to {dst_pp_rank}, " + \ + f"from tp rank {src_tp_rank} to {dst_tp_rank}") + self.send_recv_actor_mappings_stage2[src_actor].append(dst_actor) + self.recv_send_actor_mappings_stage2[dst_actor].append(src_actor) + def build_rank_mapping(self): # setup rank mapping for src parameter and dst parameter # get rank for one src_model, without model replicas @@ -222,6 +248,93 @@ def split_ranks_by_tp_size(ranks, tp_size): for src_rank, dst_rank in zip(src_tp_group, dst_replica_ranks_group[j]): self.add_recv_actor(src_rank, dst_rank) + def build_rank_mapping_two_stage(self): + # setup rank mapping for src parameter and dst parameter + # get rank for one src_model, without model replicas + dst_ranks = self.dst_model.all_ranks + local_src_ranks = future.get(self.src_model.replicas[0].get_local_param_ranks()) + if local_src_ranks[0] is None or dst_ranks is None: + if self._debug: + logger.warning( + f"DEBUG MODE! 
src_ranks {local_src_ranks} or dst_ranks: {dst_ranks} is None, make sure they have values in a real application.") + return + else: + raise Exception(f"src_ranks {local_src_ranks} or dst_ranks {dst_ranks} should not be None") + dp_rank_to_ranks = defaultdict(list) + for local_ranks, dp_rank in local_src_ranks: + dp_rank_to_ranks[dp_rank].append(local_ranks[dp_rank]) + src_ranks = [i[1] for i in sorted(dp_rank_to_ranks.items())] + + assert len(src_ranks[0]) % len(dst_ranks[0]) == 0, \ + f"the number of src training model ranks should be a multiple of the number of dst ranks, but got {len(src_ranks[0])} and {len(dst_ranks[0])}" + + if self.src_model.colocate_with(self.dst_model) and self.num_src_tensor_parallel % 2 == 1: + replica_rank_iter = cycle(reversed(src_ranks)) + else: + replica_rank_iter = cycle(iter(src_ranks)) + + logger.debug(f"src_ranks: {src_ranks}") + logger.debug(f"dst_ranks: {dst_ranks}") + assert self.num_dst_tensor_parallel % self.num_src_tensor_parallel == 0, \ + "currently the tensor_model_parallel_size of dst_model must be divisible by that of src_model, but got " + \ + f"src model {self.src_model.name}(TP={self.num_src_tensor_parallel}) and " + \ + f"dst model {self.dst_model.name}(TP={self.num_dst_tensor_parallel})" + assert self.num_src_pipeline_stage % self.num_dst_pipeline_stage == 0 + + def split_ranks_by_tp_size(ranks, tp_size): + return [ranks[i:i + tp_size] for i in range(0, len(ranks), tp_size)] + + pair_list = [] + p2p_list = [] + for dst_replica_ranks in dst_ranks: + src_replica_ranks = next(replica_rank_iter) + src_replica_ranks_group = split_ranks_by_tp_size(src_replica_ranks, self.num_src_tensor_parallel) + dst_replica_ranks_group = split_ranks_by_tp_size(dst_replica_ranks, self.num_dst_tensor_parallel) + logger.debug(f"src_replica_ranks_group: {src_replica_ranks_group}") + logger.debug(f"dst_replica_ranks_group: {dst_replica_ranks_group}") + pipe_map_interval = self.num_src_pipeline_stage // self.num_dst_pipeline_stage + + # stage 1: comm pairs that broadcast params from the trainer to the inference model + # Each rank in the trainer holds the weights for num_mapping ranks in the inference model. + # For example: trainer_tp = 2, inference_tp = 4 => num_mapping = inference_tp // trainer_tp = 2 + # Weight mapping from training to inference: + # [0] -> [0', 1'] + # [1] -> [2', 3'] + # To avoid p2p communication on the same gpu, we only broadcast params to the first rank in each weight_mapping_group. 
+ # Comm mapping from training to inference: + # [0] -> [0'] + # [1] -> [2'] + for i, src_tp_group in enumerate(src_replica_ranks_group): + j = i // pipe_map_interval + if self.num_mapping == 1: + start = 0 + else: + mod_i = i % self.num_mapping + start = mod_i if i < self.num_mapping else (self.num_mapping - mod_i - 1) % self.num_mapping + for s_idx, src_rank in enumerate(src_tp_group): + offset = s_idx * self.num_mapping + start + dst_rank = dst_replica_ranks_group[j][offset] + self.add_recv_actor(src_rank, dst_rank) + pair_list.append((src_rank, dst_rank)) + # stage 2: comm pairs that broadcast params from first rank to the other ranks for each weight_mapping_group + # Comm mapping in each weight_mapping_group of inference: + # [0'] -> [1'] + # [2'] -> [3'] + def p2p_pair_grouping(tuples): + for s_idx, src_rank in enumerate(tuples): + for d_idx, dst_rank in enumerate(tuples): + if s_idx == d_idx: + continue + self.add_recv_actor_stage2(src_rank, dst_rank) + p2p_list.append((src_rank, dst_rank)) + for dst_tp_group in dst_replica_ranks_group: + dst_tp_group = split_ranks_by_tp_size(dst_tp_group, self.num_mapping) + for tuples in dst_tp_group: + p2p_pair_grouping(tuples) + + logger.debug(f"comm pair_list : {pair_list}") + logger.debug(f"comm p2p_list : {p2p_list}") + def _get_dst_name(self, src_name): if self._src_prefix: dst_name = src_name[len(self._src_prefix):] @@ -247,6 +360,48 @@ def validate(): utils.get_or_cache(self._validate_params, (send_actor, recv_actor), validate) logger.info("Validation passed!") + def set_sync_param_names_stage2(self, send_actor, recv_actor, rank, requires_grad): + send_names = self.set_sync_param_names(send_actor, send_actor, requires_grad) + refs = [] + refs.append(send_actor.set_send_parameters.remote(send_names, self.get_actor_pipe_rank(send_actor))) + refs.append(recv_actor.set_recv_parameters.remote(rank, send_names, self.get_actor_pipe_rank(recv_actor))) + future.get(refs) + return send_names, send_names + + def sync_broadcast_two_stage(self, actors, group_name, requires_grad=None, stage2=False): + send_actor = actors[0] + for rank, recv_actor in enumerate(actors[1:]): + if stage2: + src_names, dst_names = self.set_sync_param_names_stage2(send_actor, recv_actor, rank + 1, requires_grad) + else: + src_names, dst_names = self.set_sync_param_names(send_actor, recv_actor, requires_grad) + shape_refs = [] + shape_refs.append(send_actor.get_parameter_shape.remote(src_names)) + shape_refs.append(recv_actor.get_parameter_shape.remote(dst_names)) + send_shape_list, recv_shape_list = future.get(shape_refs) + + buffer_num = {} + tp_division = {} + for send_name_and_shape, recv_name_and_shape in zip(send_shape_list, recv_shape_list): + buffer_num[recv_name_and_shape[0]] = send_name_and_shape[1].numel() // recv_name_and_shape[1].numel() + tp_division[send_name_and_shape[0]] = buffer_num[recv_name_and_shape[0]] + refs = [] + refs.append(recv_actor.set_num_mapping.remote(self.num_mapping)) + refs.append(recv_actor.set_buffer_num.remote(buffer_num)) + refs.append(send_actor.set_num_mapping.remote(self.num_mapping)) + refs.append(send_actor.set_tp_division.remote(tp_division)) + future.get(refs) + + assert self.enable_coalesce_param + refs = [] + pipe_stage = self.get_actor_pipe_rank(send_actor) + send_rank = 0 + for rank, actor in enumerate(actors): + ref = actor.broadcast_parameter_two_stage.remote(self.actor2rank[actor], rank, send_rank, group_name, pipe_stage, stage2) + refs.append(ref) + rets = future.wait(refs, return_output=True) + return rets + def 
sync_broadcast(self, actors, group_name, requires_grad=None): send_actor = actors[0] for recv_actor in actors[1:]: @@ -259,7 +414,6 @@ def sync_broadcast(self, actors, group_name, requires_grad=None): refs.append(ref) future.wait(refs, return_output=True) - def _sync_send_recv(self, send_actor, recv_actor, requires_grad=None): src_names, dst_names = self.set_sync_param_names(send_actor, recv_actor, requires_grad) pipe_stage = self.get_actor_pipe_rank(send_actor) @@ -372,6 +526,12 @@ def _set_sync_param_names(self, send_actor, recv_actor, requires_grad=None): dst_names = [self._get_dst_name(name) for name in dst_names] self.check_param_names(send_actor, recv_actor, src_names, dst_names) + if self.num_mapping > 1: + key = (recv_actor, recv_actor) + if key not in self._send_recv_param_names: + self._send_recv_param_names[key] = dst_names + else: + self._send_recv_param_names[key] += dst_names pipe_stage = self.get_actor_pipe_rank(send_actor) refs = [] refs.append(send_actor.set_sync_parameters.remote(src_names, pipe_stage)) @@ -383,13 +543,14 @@ def set_sync_param_names(self, send_actor, recv_actor, requires_grad=None): return utils.get_or_cache(self._send_recv_param_names, (send_actor, recv_actor), \ lambda: self._set_sync_param_names(send_actor, recv_actor, requires_grad)) - def create_broadcast_group(self, send_actor, recv_actors): + def create_broadcast_group(self, send_actor, recv_actors, group_name=None): actor_groups = [send_actor] actor_groups.extend(recv_actors) dp = self.get_actor_dp_rank(send_actor) pp = self.get_actor_pipe_rank(send_actor) tp = self.get_actor_tp_rank(send_actor) - group_name = f"{self.group_name}_dp{dp}_pp{pp}_tp{tp}" + group_name = self.group_name if group_name is None else group_name + group_name = f"{group_name}_dp{dp}_pp{pp}_tp{tp}" if group_name not in self.collective_groups: refs = [] for rank, actor in enumerate(actor_groups): @@ -399,27 +560,44 @@ def create_broadcast_group(self, send_actor, recv_actors): self.collective_groups.append(group_name) return actor_groups, group_name - def sort_send_actors(self): - if self.sorted_send_actors is not None: - return self.sorted_send_actors + def sort_send_actors(self, send_recv_actor_mappings, sorted_send_actors): + if sorted_send_actors is not None: + return sorted_send_actors dp2send_actors = defaultdict(list) - for send_actor in self.send_recv_actor_mappings: + for send_actor in send_recv_actor_mappings: dp2send_actors[self.get_actor_dp_rank(send_actor)].append(send_actor) for dp_rank in dp2send_actors: send_actors = dp2send_actors[dp_rank] dp2send_actors[dp_rank] = sorted(send_actors, key=lambda x: self.actor2rank[x]) sorted_send_actors = [] dp_rank = 0 - while len(sorted_send_actors) < len(self.send_recv_actor_mappings): + while len(sorted_send_actors) < len(send_recv_actor_mappings): sorted_send_actors.append(dp2send_actors[dp_rank].pop(0)) dp_rank += 1 # dp_rank not in dp2send_actors happens when inference replica number less than training replica number if dp_rank == self.src_dp_size or dp_rank not in dp2send_actors: dp_rank = 0 - assert len(self.send_recv_actor_mappings) == len(sorted_send_actors) - self.sorted_send_actors = sorted_send_actors + assert len(send_recv_actor_mappings) == len(sorted_send_actors) return sorted_send_actors + def sync_broadcast_multi_threads(self, sorted_send_actors, send_recv_actor_mappings, max_workers, requires_grad, group_name=None, stage2=False): + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [] + for send_actor in sorted_send_actors: + 
recv_actors = send_recv_actor_mappings[send_actor] + if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST: + actor_groups, group_name = self.create_broadcast_group(send_actor, recv_actors, group_name=group_name) + futures.append(executor.submit(self.sync_broadcast_two_stage, actor_groups, group_name, requires_grad, stage2)) + else: + raise RuntimeError("support p2p only for scenes that trainer_tp not equal to inference_tp.") + for _future in concurrent.futures.as_completed(futures): + try: + _future.result() + except Exception as e: + raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from + concurrent.futures.wait(futures) + + def sync(self, requires_grad=None): if not self._is_collective_group_created: # Re-create collective group only when it is destroyed before. @@ -433,7 +611,7 @@ def sync(self, requires_grad=None): assert state, "Check fuse lora layer fail." if self.concurrent_comm: - sorted_send_actors = self.sort_send_actors() + sorted_send_actors = self.sort_send_actors(self.send_recv_actor_mappings, self.sorted_send_actors) max_workers = get_args().runtime_args.param_sync_max_workers if max_workers is None: max_workers = max(self.src_model.total_gpu // 8, 1) @@ -442,22 +620,39 @@ def sync(self, requires_grad=None): max_workers = len(send_actors) else: max_workers = len(send_actors) * len(self.send_recv_actor_mappings[send_actors[0]]) - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for send_actor in sorted_send_actors: - recv_actors = self.send_recv_actor_mappings[send_actor] + + if self.num_mapping > 1: + # stage 1 + self.sync_broadcast_multi_threads(sorted_send_actors, self.send_recv_actor_mappings, max_workers, requires_grad, stage2=False) + # stage 2 + sorted_send_actors = self.sort_send_actors(self.send_recv_actor_mappings_stage2, self.sorted_send_actors_stage2) + max_workers = get_args().runtime_args.param_sync_max_workers + if max_workers is None: + max_workers = max(self.dst_model.total_gpu // 8, 1) + if max_workers == -1: if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST: - actor_groups, group_name = self.create_broadcast_group(send_actor, recv_actors) - futures.append(executor.submit(self.sync_broadcast, actor_groups, group_name, requires_grad)) + max_workers = len(sorted_send_actors) else: - for recv_actor in recv_actors: - futures.append(executor.submit(self.sync_send_recv, send_actor, recv_actor, requires_grad)) - for _future in concurrent.futures.as_completed(futures): - try: - _future.result() - except Exception as e: - raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from - concurrent.futures.wait(futures) + max_workers = len(sorted_send_actors) * len(self.send_recv_actor_mappings_stage2[sorted_send_actors[0]]) + self.sync_broadcast_multi_threads( + sorted_send_actors, self.send_recv_actor_mappings_stage2, max_workers, requires_grad, group_name="intra_comm", stage2=True) + else: + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [] + for send_actor in sorted_send_actors: + recv_actors = self.send_recv_actor_mappings[send_actor] + if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST: + actor_groups, group_name = self.create_broadcast_group(send_actor, recv_actors) + futures.append(executor.submit(self.sync_broadcast, actor_groups, group_name, requires_grad)) + else: + for recv_actor in recv_actors: + futures.append(executor.submit(self.sync_send_recv, send_actor, recv_actor, requires_grad)) + for 
_future in concurrent.futures.as_completed(futures): + try: + _future.result() + except Exception as e: + raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from + concurrent.futures.wait(futures) else: for send_actor, recv_actors in self.send_recv_actor_mappings.items(): if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST: diff --git a/chatlearn/utils/dist_utils.py b/chatlearn/utils/dist_utils.py index 899ea2bd..23c74f49 100644 --- a/chatlearn/utils/dist_utils.py +++ b/chatlearn/utils/dist_utils.py @@ -53,6 +53,72 @@ def bucket_tensors(tensors, bucket_size_mb): return dense_buckets, sparse_bucket +def bucket_tensors_two_stage(tensors, bucket_size_mb, buffer_num=None, tensor_changed=False): + """Group tensors into chunks. We separate sparse and dense tensors; + each dense chunk contains tensors of the same type up to a certain byte limit in total size. + + Args: + tensors (Sequence): A sequence of tensors to be separated into chunks. + bucket_size_mb (int): The limit of each chunk in megabytes. + buffer_num (Sequence, optional): Per-tensor multipliers that expand the buffer size on dst ranks which receive tensors from the trainer. + tensor_changed (bool): Whether the tensors will be overwritten by the communication; if True and the multiplier is greater than 1, an empty receive buffer is allocated instead of reusing the tensor. + + Return: + dense_buckets: Blocks of tensors of the same type and within the size limit. + sparse_bucket: A list of sparse tensors + """ + size_limit = bucket_size_mb * 1024 * 1024 + buf_dict = defaultdict(lambda: [[], 0]) + dense_buckets = [] + sparse_bucket = [] + for idx, tensor in enumerate(tensors): + buffer_multiple = 1 if buffer_num is None else buffer_num[idx] + if tensor.is_sparse: + sparse_bucket.append(tensor) + continue + t = tensor.type() + # expand the buffer size for dst ranks which receive tensors from the trainer. + size = tensor.numel() * tensor.element_size() * buffer_multiple + buf_and_size = buf_dict[t] + if size_limit > 0 and buf_and_size[1] + size > size_limit and buf_and_size[1] > 0: # pylint: disable=chained-comparison + dense_buckets.append(buf_and_size[0]) + buf_and_size = buf_dict[t] = [[], 0] + buf_and_size[0].append((torch.empty(size=[tensor.numel() * buffer_multiple], + dtype=tensor.dtype, + device=tensor.device) if (tensor_changed and buffer_multiple > 1) else tensor, + [size // tensor.element_size(), buffer_multiple, tensor])) + buf_and_size[1] += size + for buf, size in buf_dict.values(): + if len(buf) > 0: + dense_buckets.append(buf) + return dense_buckets, sparse_bucket + + +def unflatten_dense_tensors(flat_tensors, tensors, sizes, num_ranks): + """Split a coalesced flat buffer into per-rank views with the original tensor shapes.""" + all_buffers = defaultdict(list) + + offset = 0 + for size_multiple, tensor in zip(sizes, tensors): + size, multiple, orig_tensor = size_multiple + assert offset <= flat_tensors.numel() + assert len(flat_tensors.shape) == 1 + flat_tensor = flat_tensors[offset:offset+size] + per_size = size // multiple + for rank in range(num_ranks): + if multiple > 1: + assert (flat_tensor.numel() // multiple) == tensor.numel(), \ + f"flat_tensor: {flat_tensor.shape} should be {multiple} times the size of tensor {orig_tensor.shape}, \ + per_size: {per_size} total_size: {size} num_ranks: {num_ranks} offset: {offset}" + all_buffers[rank].append(flat_tensor[rank * per_size:(rank + 1) * per_size].view(orig_tensor.shape)) + else: + assert flat_tensor.numel() == orig_tensor.numel(), \ + f"flat_tensor: {flat_tensor.shape} orig_tensor: {orig_tensor.shape}" + all_buffers[rank].append(flat_tensor.view(orig_tensor.shape)) + del flat_tensor + offset += size + del flat_tensors + return all_buffers + + + def coalesced_comm_dense(bucket, comm_call, extra_args, tensor_changed=True): """ coalesced communication for dense parameters @@ -64,6 +130,36 @@ def coalesced_comm_dense(bucket, 
_unflatten_dense_tensors(flat_tensors, bucket)): tensor.copy_(synced) + +def coalesced_comm_dense_two_stage(bucket, comm_call, rank, extra_args, tensor_changed=True, stage2=False, index=0): + """ + coalesced communication for dense parameters + """ + all_tensors = [] + all_sizes = [] + num_ranks = 1 + orig_tensor_ele = 0 + orig_tensors = [] + for tensor, size in bucket: + all_tensors.append(tensor) + all_sizes.append(size) + orig_tensors.append(size[2]) + orig_tensor_ele += size[2].numel() + num_ranks = max(num_ranks, size[1]) + flat_tensors = _flatten_dense_tensors(all_tensors) + comm_call(flat_tensors, *extra_args) + if tensor_changed: + index = 0 if stage2 else index + all_buffers = unflatten_dense_tensors(flat_tensors, orig_tensors, all_sizes, num_ranks) + for tensor, synced in zip(orig_tensors, all_buffers[index]): + assert tensor.numel() == synced.numel(), \ + f"rank {rank} tensor {tensor.shape} should be equal to synced.shape {synced.shape}, for all_sizes {all_sizes}" + tensor.copy_(synced) + del all_buffers[index] + return all_buffers + return None + + def broadcast_var_object_dict(obj_dict, src_rank): if torch.distributed.get_rank() == src_rank: dict_as_list = list(obj_dict.items()) diff --git a/examples/megatron/configs/llama2/old_policy_inference.yaml b/examples/megatron/configs/llama2/old_policy_inference.yaml index d17eb4ab..9ede984a 100644 --- a/examples/megatron/configs/llama2/old_policy_inference.yaml +++ b/examples/megatron/configs/llama2/old_policy_inference.yaml @@ -10,3 +10,6 @@ temperature: ${policy_temperature:1.0} eval_temperature: 0.01 eval_top_k: 1 eval_top_p: 0 + +tensor_model_parallel_size: ${policy_tp} +pipeline_model_parallel_size: ${policy_pp} diff --git a/examples/megatron/configs/llama2/policy_shared.yaml b/examples/megatron/configs/llama2/policy_shared.yaml index e5cb6df3..4df69739 100644 --- a/examples/megatron/configs/llama2/policy_shared.yaml +++ b/examples/megatron/configs/llama2/policy_shared.yaml @@ -4,7 +4,6 @@ num_layers: ${policy_num_layers} hidden_size: ${policy_hidden_size} num_attention_heads: ${policy_num_attention_heads} ffn_hidden_size: ${policy_ffn_hidden_size} -tensor_model_parallel_size: ${policy_tp:8} num_query_groups: ${policy_num_query_groups} use_distributed_optimizer: True diff --git a/examples/megatron/configs/llama2/ppo_policy.yaml b/examples/megatron/configs/llama2/ppo_policy.yaml index 32a72538..270bbdd0 100644 --- a/examples/megatron/configs/llama2/ppo_policy.yaml +++ b/examples/megatron/configs/llama2/ppo_policy.yaml @@ -32,6 +32,7 @@ min_lr: ${policy_min_lr:1e-9} lr_decay_style: ${policy_lr_decay_style:linear} weight_decay: 0.01 pipeline_model_parallel_size: ${ppo_policy_pp:1} +tensor_model_parallel_size: ${ppo_policy_tp:8} sequence_parallel: ${sequence_parallel:True} recompute_activations: ${policy_recompute_activations:False} diff --git a/examples/megatron/configs/llama2/reference.yaml b/examples/megatron/configs/llama2/reference.yaml index 96cb77a2..32e06016 100644 --- a/examples/megatron/configs/llama2/reference.yaml +++ b/examples/megatron/configs/llama2/reference.yaml @@ -3,4 +3,5 @@ includes: - policy_shared.yaml parallel_output: True +tensor_model_parallel_size: ${policy_tp} pipeline_model_parallel_size: ${ref_pp:1} diff --git a/examples/megatron/configs/llama2/vllm_policy_inference.yaml b/examples/megatron/configs/llama2/vllm_policy_inference.yaml index 84244b04..d09282f4 100644 --- a/examples/megatron/configs/llama2/vllm_policy_inference.yaml +++ b/examples/megatron/configs/llama2/vllm_policy_inference.yaml @@ -36,4 
+36,7 @@ sliding_window: ${sliding_window:None} tokenizer: ${tokenizer_load} vllm_input_ids_key: input_ids -vllm_prompt_key: prompt \ No newline at end of file +vllm_prompt_key: prompt + +tensor_model_parallel_size: ${policy_tp} +pipeline_model_parallel_size: ${policy_pp} diff --git a/examples/megatron/scripts/train_dpo_llama.sh b/examples/megatron/scripts/train_dpo_llama.sh index 05ed3cff..0242e1e3 100644 --- a/examples/megatron/scripts/train_dpo_llama.sh +++ b/examples/megatron/scripts/train_dpo_llama.sh @@ -20,6 +20,8 @@ export retro_encoder_hidden_dropout=0.0 export retro_encoder_attention_dropout=0.0 export policy_tp=8 +export policy_pp=1 +export ppo_policy_tp=8 export ppo_policy_pp=1 export train_global_batch_size=128 export ref_generation_batch_size=16 diff --git a/examples/megatron/scripts/train_online_dpo_llama.sh b/examples/megatron/scripts/train_online_dpo_llama.sh index 0e99a1ee..1e695f0e 100644 --- a/examples/megatron/scripts/train_online_dpo_llama.sh +++ b/examples/megatron/scripts/train_online_dpo_llama.sh @@ -26,6 +26,8 @@ export num_inference_per_prompt=8 if [[ "$model_size" == "llama2-7B" ]]; then export policy_tp=8 + export policy_pp=1 + export ppo_policy_tp=8 export ppo_policy_pp=1 export reward_tp=8 export ppo_value_pp=1 @@ -36,6 +38,8 @@ if [[ "$model_size" == "llama2-7B" ]]; then export gpu_memory_utilization=0.9 elif [[ "$model_size" == "llama2-13B" ]]; then export policy_tp=8 + export policy_pp=1 + export ppo_policy_tp=8 export ppo_policy_pp=2 export reward_tp=8 export ppo_value_pp=2 diff --git a/examples/megatron/scripts/train_rlhf_llama.sh b/examples/megatron/scripts/train_rlhf_llama.sh index 30bdaec8..eedd2fe8 100644 --- a/examples/megatron/scripts/train_rlhf_llama.sh +++ b/examples/megatron/scripts/train_rlhf_llama.sh @@ -43,6 +43,8 @@ export data_checkpoint_path=${output_dir}/data_checkpoint if [[ "$model_size" == "llama2-7B" ]]; then export policy_tp=4 + export policy_pp=1 + export ppo_policy_tp=4 export ppo_policy_pp=1 export reward_tp=4 export ppo_value_pp=1 @@ -67,6 +69,8 @@ if [[ "$model_size" == "llama2-7B" ]]; then export free_memory_ppo_value=True elif [[ "$model_size" == "llama2-13B" ]]; then export policy_tp=8 + export policy_pp=1 + export ppo_policy_tp=8 export ppo_policy_pp=2 export reward_tp=8 export ppo_value_pp=2 @@ -75,6 +79,8 @@ elif [[ "$model_size" == "llama2-13B" ]]; then export ref_generation_batch_size=16 elif [[ "$model_size" == "llama2-70B" ]]; then export policy_tp=8 + export policy_pp=1 + export ppo_policy_tp=8 export ppo_policy_pp=4 export reward_tp=8 export ppo_value_pp=4
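Not part of the patch: the pairing that build_rank_mapping_two_stage sets up is easier to see in isolation. The sketch below reproduces the mapping described in the stage-1/stage-2 comments for a single pipeline stage and a single replica, assuming trainer_tp=2 and inference_tp=4. The helper two_stage_pairs is illustrative only and does not exist in ChatLearn; the real code additionally handles pipeline stages, replicas, the colocation offset heuristic, and registers stage-2 pairs in both directions.

def two_stage_pairs(trainer_tp_ranks, inference_tp_ranks):
    """Return (stage1, stage2) broadcast pairs for one trainer/inference TP group."""
    assert len(inference_tp_ranks) % len(trainer_tp_ranks) == 0
    num_mapping = len(inference_tp_ranks) // len(trainer_tp_ranks)
    # Each trainer rank owns the weights of one weight_mapping_group of
    # num_mapping consecutive inference ranks.
    groups = [inference_tp_ranks[i:i + num_mapping]
              for i in range(0, len(inference_tp_ranks), num_mapping)]
    # Stage 1: trainer rank -> first rank of its weight_mapping_group.
    stage1 = [(src, group[0]) for src, group in zip(trainer_tp_ranks, groups)]
    # Stage 2: first rank of each group -> the remaining ranks of that group.
    stage2 = [(group[0], dst) for group in groups for dst in group[1:]]
    return stage1, stage2


if __name__ == "__main__":
    s1, s2 = two_stage_pairs(trainer_tp_ranks=[0, 1], inference_tp_ranks=[4, 5, 6, 7])
    print("stage 1 pairs:", s1)  # [(0, 4), (1, 6)]
    print("stage 2 pairs:", s2)  # [(4, 5), (6, 7)]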
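A second standalone sketch, also not part of the patch, of the buffer bookkeeping behind bucket_tensors_two_stage and unflatten_dense_tensors: the trainer broadcasts one flat buffer holding num_mapping inference-side shards back to back, and each inference rank of the weight_mapping_group then views its own contiguous slice. The real code also pre-permutes query_key_value, dense and MLP weights on the sender so that contiguous slices line up with TP shards; the shapes below are made up for illustration.

import torch

num_mapping = 2                                   # inference_tp // trainer_tp
trainer_param = torch.arange(16.).reshape(4, 4)   # one trainer-side parameter shard

# What the coalesced broadcast carries: the flattened trainer shard.
flat = trainer_param.contiguous().view(-1)

# What each inference rank keeps: a contiguous 1/num_mapping slice of the flat
# buffer, viewed back to the inference-side parameter shape (here 2 x 4).
per_rank = flat.numel() // num_mapping
recv_shape = (trainer_param.shape[0] // num_mapping, trainer_param.shape[1])
shards = [flat[r * per_rank:(r + 1) * per_rank].view(recv_shape)
          for r in range(num_mapping)]

for r, shard in enumerate(shards):
    print(f"inference rank {r} slice:\n{shard}")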