diff --git a/configs/common/train.py b/configs/common/train.py
index 2201b615c..fa9909589 100644
--- a/configs/common/train.py
+++ b/configs/common/train.py
@@ -44,7 +44,7 @@
     # Enable automatic mixed precision for training which does not
     # change model's inference behavior.
     amp=dict(enabled=False),
-
+    train_with_fp16=False,
     # Enable activation checkpointing to allow for training
     # with larger models, sequences, and batch sizes.
     # If enabled, checkpoint the input activations of each transformer layers by default.
diff --git a/configs/loadder_mapping.py b/configs/loadder_mapping.py
new file mode 100644
index 000000000..359f7a1a8
--- /dev/null
+++ b/configs/loadder_mapping.py
@@ -0,0 +1,23 @@
+loader_mapping_models = dict(
+
+    llama=dict(
+        loader_prefix="projects.Llama.utils.llama_loader",
+        huggingface_loader="LlamaLoaderHuggerFace",
+    ),
+
+    chatglm=dict(
+        loader_prefix="projects.ChatGLM.utils.chatglm_loader",
+        huggingface_loader="ChatGLMLoaderHuggerFace",
+    ),
+
+    qwen2=dict(
+        loader_prefix="projects.Qwen2.utils.qwen_loader",
+        huggingface_loader="Qwen2LoaderHuggerFace",
+    ),
+
+    aquila=dict(
+        loader_prefix="projects.Aquila.utils.aquila_loader",
+        huggingface_loader="AquilaLoaderHuggerFace",
+    )
+
+)
diff --git a/libai/data/build.py b/libai/data/build.py
index 656929f8b..865240a9c 100644
--- a/libai/data/build.py
+++ b/libai/data/build.py
@@ -153,7 +153,7 @@ def build_nlp_train_loader(
     train_batch_size,
     test_batch_size=None,
     sampler=LazyCall(CyclicSampler)(shuffle=True),
-    num_workers=4,
+    num_workers=0,
     consumed_samples=0,
     seed=0,
     collate_fn=None,
@@ -223,7 +223,7 @@ def build_nlp_test_loader(
     dataset,
     test_batch_size,
     sampler=LazyCall(SingleRoundSampler)(shuffle=False, drop_last=False),
-    num_workers=4,
+    num_workers=0,
     seed=0,
     collate_fn=None,
 ):
diff --git a/libai/data/structures.py b/libai/data/structures.py
index 380a8a1ce..753faa33f 100644
--- a/libai/data/structures.py
+++ b/libai/data/structures.py
@@ -99,7 +99,7 @@ def stack(distTensor_lists: List["DistTensorData"]) -> "DistTensorData":
             assert (
                 data.placement_idx == placement_idx
             ), f"placement_idx is not equal, {data.placement_idx} != {placement_idx}"
-            tensors.append(data.tensor)
+            tensors.append(data.tensor.to(flow.int64))
         tensors = flow.stack(tensors, dim=0)
         ret = DistTensorData(tensors, sbp_list=sbp_list, placement_idx=placement_idx)
         return ret
diff --git a/libai/engine/hooks.py b/libai/engine/hooks.py
index a66cf4155..79912170c 100644
--- a/libai/engine/hooks.py
+++ b/libai/engine/hooks.py
@@ -341,9 +341,10 @@ def _do_eval(self):
     def after_step(self):
         next_iter = self.trainer.iter + 1
         if self._period > 0 and next_iter % self._period == 0:
-            # do the last eval in after_train
-            if next_iter != self.trainer.max_iter:
-                self._do_eval()
+            # # do the last eval in after_train
+            # if next_iter != self.trainer.max_iter:
+            #     self._do_eval()
+            pass

     def after_train(self):
         # This condition is to prevent the eval from running after a failed training
diff --git a/libai/engine/trainer.py b/libai/engine/trainer.py
index ffd64ebaa..3ed54ae74 100644
--- a/libai/engine/trainer.py
+++ b/libai/engine/trainer.py
@@ -132,6 +132,7 @@ def train(self, start_iter: int, max_iter: int):
         Args:
             start_iter, max_iter (int): See docs above
         """
+        # start_iter = 9980 # for profiling
         logger = logging.getLogger(__name__)
         logger.info("Starting training from iteration {}".format(start_iter))
@@ -283,7 +284,7 @@ def run_step(self, get_batch: Callable, input_placement_device: str = "cuda"):
         if (self.iter + 1) % self.grad_acc_steps == 0:
            self.optimizer.clip_grad()
            self.optimizer.step()
-           self.optimizer.zero_grad()
+           self.optimizer.zero_grad(set_to_none=True)

 class GraphTrainer(TrainerBase):
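
`zero_grad(set_to_none=True)` releases the gradient buffers instead of refilling them with zeros, which saves a kernel launch and some memory traffic between accumulation windows. A minimal sketch of the accumulation pattern this hunk sits in; `model`, `optimizer`, `loader`, and `grad_acc_steps` are placeholders rather than LiBai's real attributes:

    # Sketch only: assumes the model returns a dict with a "losses" entry.
    def train_steps(model, optimizer, loader, grad_acc_steps=8):
        for step, batch in enumerate(loader):
            loss = model(**batch)["losses"] / grad_acc_steps  # scale so accumulated grads average
            loss.backward()
            if (step + 1) % grad_acc_steps == 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)  # drop grad buffers instead of zero-filling
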
diff --git a/libai/evaluation/evaluator.py b/libai/evaluation/evaluator.py
index 1414cdaa0..07cae8c6d 100644
--- a/libai/evaluation/evaluator.py
+++ b/libai/evaluation/evaluator.py
@@ -203,12 +203,12 @@ def inference_on_dataset(
             # get valid sample
             valid_data = {
-                key: dist.tensor_to_rank0(value, to_local=True)[:valid_sample]
+                key: dist.tensor_to_rank0(value, to_local=True, device=input_placement_device)[:valid_sample]
                 for key, value in data.items()
             }
             valid_outputs = {}
             for key, value in outputs.items():
-                value = dist.tensor_to_rank0(value, to_local=True)
+                value = dist.tensor_to_rank0(value, to_local=True, device=input_placement_device)
                 if value.ndim > 1:
                     valid_outputs[key] = value[:valid_sample]  # Slice if it's batched output
                 else:
diff --git a/libai/inference/basic.py b/libai/inference/basic.py
index 94d3f1781..1f0092eae 100644
--- a/libai/inference/basic.py
+++ b/libai/inference/basic.py
@@ -41,6 +41,7 @@ def __init__(
         pipeline_parallel=None,
         pipeline_stage_id=None,
         pipeline_num_layers=None,
+        device_type="npu",
         model_path=None,
         mode="libai",
         **kwargs,
@@ -59,6 +60,7 @@ def __init__(
             pipeline_parallel,
             pipeline_stage_id,
             pipeline_num_layers,
+            device_type,
         )
         dist.setup_dist_util(self.cfg.train.dist)
         logger.info(self.cfg.train.dist)
@@ -90,11 +92,13 @@ def update_cfg(
         pipeline_parallel=1,
         pipeline_stage_id=None,
         pipeline_num_layers=None,
+        device_type="npu",
     ):
         self.cfg.train.dist.data_parallel_size = data_parallel
         self.cfg.train.dist.tensor_parallel_size = tensor_parallel
         self.cfg.train.dist.pipeline_parallel_size = pipeline_parallel
         self.cfg.train.dist.custom_pipeline_stage_id = pipeline_stage_id
+        self.cfg.train.dist.device_type = device_type
         if pipeline_num_layers is not None:
             self.cfg.train.dist.pipeline_num_layers = pipeline_num_layers
diff --git a/libai/layers/cross_entropy.py b/libai/layers/cross_entropy.py
index cde6b1632..6821dffe4 100644
--- a/libai/layers/cross_entropy.py
+++ b/libai/layers/cross_entropy.py
@@ -36,13 +36,24 @@ def forward(self, logits: flow.Tensor, target: flow.Tensor):
         assert target.ndim == 2
         assert logits.shape[0:2] == target.shape
-        target = target.to_global(placement=logits.placement)
-
-        # Change -1 in target to 0 because sparse_softmax_cross_entropy don't accept -1
-        target = target * (target >= 0)
-
-        lm_loss = flow._C.sparse_softmax_cross_entropy(
+        target = target.to(flow.int32)  # NOTE:npu nll target only support int32 for now
+        target = target.to_global(placement=logits.placement)
+        lm_loss = flow._C.cross_entropy(
             logits.view(-1, logits.shape[-1]),
             target.view(-1),
+            None,
+            -100,
+            "none",
+            0.0
         )
+
+        # target = target.to_global(placement=logits.placement)
+
+        # # Change -1 in target to 0 because sparse_softmax_cross_entropy don't accept -1
+        # target = target * (target >= 0)
+
+        # lm_loss = flow._C.sparse_softmax_cross_entropy(
+        #     logits.view(-1, logits.shape[-1]),
+        #     target.view(-1),
+        # )
         return lm_loss
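
The positional arguments handed to `flow._C.cross_entropy` above appear to correspond to weight, ignore_index, reduction, and label smoothing, replacing the old sparse_softmax path that could not represent ignored targets. A sketch of the same computation through the public module API, assuming OneFlow's `nn.CrossEntropyLoss` mirrors the PyTorch signature:

    import oneflow as flow
    from oneflow import nn

    # Assumption: ignored positions in `target` are marked with -100,
    # matching the ignore_index passed positionally above.
    loss_fn = nn.CrossEntropyLoss(ignore_index=-100, reduction="none")

    logits = flow.randn(2, 8, 1000)          # (batch, seq_len, vocab_size)
    target = flow.randint(0, 1000, (2, 8))   # the diff additionally casts targets to int32 for NPU
    lm_loss = loss_fn(logits.view(-1, logits.shape[-1]), target.view(-1))
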
diff --git a/libai/models/gpt_model.py b/libai/models/gpt_model.py
index 27f6bc8e9..e9d890c93 100644
--- a/libai/models/gpt_model.py
+++ b/libai/models/gpt_model.py
@@ -244,7 +244,10 @@ def forward(self, input_ids, past_length=0):
         bsz, seq_length = input_ids.size()
         position_ids = self.position_ids[:, past_length : past_length + seq_length]
-        position_ids = position_ids.expand_as(input_ids).to_global(sbp=input_ids.sbp)
+        # position_ids = position_ids.expand_as(input_ids).to_global(sbp=input_ids.sbp)
+        position_ids = position_ids.expand_as(input_ids).to_global(
+            sbp=input_ids.sbp, placement=input_ids.placement
+        )
         token_embeds = self.token_embeddings(input_ids)
         position_embeds = self.position_embeddings(position_ids)
diff --git a/libai/models/utils/graph_base.py b/libai/models/utils/graph_base.py
index 651209ccd..ad7e75118 100644
--- a/libai/models/utils/graph_base.py
+++ b/libai/models/utils/graph_base.py
@@ -102,8 +102,12 @@ def __init__(
     def build(self, **kwargs):
         if self.is_train:
             placement_sbp_dict = (
+                # dict(
+                #     placement=flow.env.all_device_placement("cuda"),
+                #     sbp=flow.sbp.split(0),
+                # )
                 dict(
-                    placement=flow.env.all_device_placement("cuda"),
+                    placement=flow.env.all_device_placement("npu"),
                     sbp=flow.sbp.split(0),
                 )
                 if self.global_mode.enabled
diff --git a/libai/models/utils/model_loader/base_loader.py b/libai/models/utils/model_loader/base_loader.py
index e5a58a22a..1f06a06b3 100644
--- a/libai/models/utils/model_loader/base_loader.py
+++ b/libai/models/utils/model_loader/base_loader.py
@@ -384,6 +384,10 @@ def _convert_tensor(self, tensor):
         Returns:
             flow.Tensor: The target tensor.
         """
+        import torch
+        if tensor.dtype == torch.bfloat16:
+            data = tensor.detach().half().cpu().numpy()
+            return flow.Tensor(data)
         return flow.Tensor(tensor.detach().cpu().numpy())

     def _convert_tensors(self, torch_state_dict):
@@ -490,6 +494,9 @@ def _load_torch_state_dict(self, state_dict_file, use_safetensors=False):
             merged_state_dict = {}
             for file in state_dict_file:
                 state_dict = torch.load(file, map_location="cpu")
+                # NOTE: align to libai oneflow_xpu
+                for k in state_dict.keys():
+                    state_dict[k] = state_dict[k].to(torch.float)
                 merged_state_dict.update(state_dict)
             return merged_state_dict
diff --git a/libai/tokenizer/tokenization_base.py b/libai/tokenizer/tokenization_base.py
index 026902fdf..5c309b592 100644
--- a/libai/tokenizer/tokenization_base.py
+++ b/libai/tokenizer/tokenization_base.py
@@ -782,9 +782,9 @@ def convert_to_tensors(self, token_ids, return_tensors=None, is_global=False, **
             return_token_ids = flow.tensor(token_ids, dtype=flow.long)
         elif is_global:
             sbp = kwargs.get("sbp", dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]))
-            placement = kwargs.get(
-                "placement", flow.placement("cuda", list(range(dist.get_world_size())))
-            )
+            placement = kwargs.get("placement")
+            if placement is None:
+                placement = flow.placement("npu", list(range(dist.get_world_size())))
             return_token_ids = flow.tensor(
                 token_ids, sbp=sbp, placement=placement, dtype=flow.long
             )
diff --git a/libai/utils/distributed.py b/libai/utils/distributed.py
index e7914a0ad..90bb15580 100644
--- a/libai/utils/distributed.py
+++ b/libai/utils/distributed.py
@@ -228,7 +228,7 @@ def device_type(self):
         return self._device_type

     def set_device_type(self, device_type):
-        assert device_type in ["cpu", "cuda"], f"not supported for {device_type}"
+        assert device_type in ["cpu", "cuda", "npu"], f"not supported for {device_type}"
         self._device_type = device_type

     def get_layer_ranks(self, layer_idx):
@@ -431,6 +431,20 @@ def convert_to_distributed_default_setting(t):
             )
         else:
             dist_util = get_dist_util()
+            if dist_util.device_type != "npu":
+                from omegaconf import DictConfig
+
+                setup_dist_util(
+                    DictConfig(
+                        dict(
+                            data_parallel_size=1,
+                            tensor_parallel_size=1,
+                            pipeline_parallel_size=1,
+                            device_type="npu",
+                        )
+                    )
+                )
+                dist_util = get_dist_util()
             device_type = dist_util.device_type
             return t.to_global(placement=flow.placement(device_type, ranks=t.placement.ranks))
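
The fallback added here forces the default distributed setting onto NPU whenever the configured device type is something else, and then re-places the tensor accordingly. A minimal sketch of that re-placement call shape, with the device string as an assumption ("npu" needs the oneflow_npu extension; "cuda" stands in as the portable case):

    import oneflow as flow

    # Build a global tensor on CPU, then move it to another device type
    # while keeping its SBP, mirroring the return statement above.
    t = flow.ones(2, 2, placement=flow.placement("cpu", ranks=[0]), sbp=flow.sbp.broadcast)
    t = t.to_global(placement=flow.placement("cuda", ranks=t.placement.ranks))
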
@@ -438,7 +452,7 @@ def convert_to_distributed_default_setting(t):
 def ttol(tensor, pure_local=False, ranks=None):
     """Global tensor to local tensor."""
     if tensor.is_global:
-        placement = tensor.placement if not ranks else flow.placement("cuda", ranks)
+        placement = tensor.placement if not ranks else flow.placement(tensor.placement.type, ranks)
         if pure_local:
             tensor = tensor.to_global(placement=placement).to_local()
         else:
@@ -457,9 +471,9 @@ def tton(tensor, local_only=False, ranks=None):
     return tensor.numpy()

-def tensor_to_rank0(tensor, device="cuda", to_local=False):
+def tensor_to_rank0(tensor, device="npu", to_local=False):
     """Global tensor to rank0."""
-    assert device in ["cpu", "cuda"], f"not supported for device:{device}"
+    assert device in ["cpu", "cuda", "npu"], f"not supported for device:{device}"
     if tensor.is_global:
         # Consider if it's 2d mesh, ranks should be [[0]] instead of [0]
         placement = flow.placement(device, ranks=[0] if tensor.placement.ranks.ndim == 1 else [[0]])
diff --git a/libai/version.py b/libai/version.py
new file mode 100644
index 000000000..d769f3af3
--- /dev/null
+++ b/libai/version.py
@@ -0,0 +1,2 @@
+__version__ = '0.2.0'
+git_version = '229c4d9ee2bf6f881a9883176f1ea067254b3583'
diff --git a/projects/ChatGLM/chatglm.py b/projects/ChatGLM/chatglm.py
index cea239878..a0417363f 100644
--- a/projects/ChatGLM/chatglm.py
+++ b/projects/ChatGLM/chatglm.py
@@ -22,8 +22,10 @@ def apply_rotary_pos_emb(x: flow.Tensor, rope_cache: flow.Tensor) -> flow.Tensor
     x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
     # truncate to support variable sizes
     rope_cache = rope_cache[:sq]
-    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
-    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
+    # xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
+    # rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
+    xshaped = dist.convert_to_distributed_default_setting(x.reshape(sq, -1, np, rot_dim // 2, 2))
+    rope_cache = dist.convert_to_distributed_default_setting(rope_cache.view(sq, -1, 1, xshaped.size(3), 2))
     x_out2 = flow.cat(
         [
             (xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1]).unsqueeze(
diff --git a/projects/ChatGLM/configs/chatglm_config.py b/projects/ChatGLM/configs/chatglm_config.py
index 1192a3629..386f5aaf1 100644
--- a/projects/ChatGLM/configs/chatglm_config.py
+++ b/projects/ChatGLM/configs/chatglm_config.py
@@ -5,10 +5,11 @@
 from projects.ChatGLM.chatglm import ChatGLMForConditionalGeneration
 from projects.ChatGLM.tokenizer import ChatGLMTokenizer
 from configs.common.train import train
-
+# from configs.train import train
 cfg = dict(
     # Model
+    model_type='chatglm',
     add_bias_linear=False,
     add_qkv_bias=True,
     apply_query_key_layer_scaling=True,
@@ -61,7 +62,8 @@
     output_scores=False,
     output_hidden_states=False,
     # train
-    pretrained_model_path=os.environ["CHATGLM_HF_DIR"],
+    # pretrained_model_path=os.environ["CHATGLM_HF_DIR"],
+    pretrained_model_path='/data0/hf_models/chatglm/chatglm2-6b',
     # lora_cfg
     lora_enable=False,
     lora_cfg=dict(
@@ -87,5 +89,6 @@
 tokenization = OmegaConf.create()
 tokenization.make_vocab_size_divisible_by = 1
 tokenization.tokenizer = LazyCall(ChatGLMTokenizer)(
-    vocab_file=f"{os.environ['CHATGLM_HF_DIR']}/tokenizer.model"
+    # vocab_file=f"{os.environ['CHATGLM_HF_DIR']}/tokenizer.model"
+    vocab_file=cfg.pretrained_model_path+"/tokenizer.model"
 )
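
The ChatGLM configs replace the `CHATGLM_HF_DIR` environment lookup with a hardcoded local path. If the environment-driven form is worth keeping, a default fallback gives both behaviors; the path below is simply the one used in this diff, not a canonical location:

    import os

    # Assumption: prefer the env var, fall back to the local checkout used in this diff.
    pretrained_model_path = os.environ.get("CHATGLM_HF_DIR", "/data0/hf_models/chatglm/chatglm2-6b")
    vocab_file = os.path.join(pretrained_model_path, "tokenizer.model")
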
diff --git a/projects/ChatGLM/configs/chatglm_sft.py b/projects/ChatGLM/configs/chatglm_sft.py
index ea60d2fa3..e015ab06e 100644
--- a/projects/ChatGLM/configs/chatglm_sft.py
+++ b/projects/ChatGLM/configs/chatglm_sft.py
@@ -21,11 +21,14 @@
 max_source_len = 128
 max_target_len = 128
 max_length = 256
-dataset_path = os.environ["DATA_DIR"]
-pretrained_model_path = os.environ["CHATGLM_HF_DIR"]
+# dataset_path = os.environ["DATA_DIR"]
+# pretrained_model_path = os.environ["CHATGLM_HF_DIR"]
+dataset_path = './data/libai_xpu_alpaca'
+pretrained_model_path = '/data0/hf_models/chatglm/chatglm2-6b'

 # graph & optim
-graph["enabled"] = True
+# graph["enabled"] = True
+graph["enabled"] = False

 optim.update(
     dict(
@@ -76,12 +79,17 @@
     test_micro_batch_size=1,
     train_epoch=3,
     train_iter=1,
-    log_period=10,
+    # log_period=10,
+    log_period=1,
     warmup_ratio=2 / 5,
     num_accumulation_steps=8,
     rdma_enabled=True,
-    amp=dict(enabled=True),
+    # amp=dict(enabled=True),
+    amp=dict(enabled=False),
+    # train_with_fp16=True,
+    train_with_fp16=False,
     activation_checkpoint=dict(enabled=True),
+    input_placement_device='npu',
     checkpointer=dict(
         period=5000,
         max_to_keep=1,
@@ -89,8 +97,10 @@
     dist=dict(
         data_parallel_size=1,
         tensor_parallel_size=1,
-        pipeline_parallel_size=4,
+        # pipeline_parallel_size=4,
+        pipeline_parallel_size=1,
         pipeline_num_layers=cfg.num_layers,
+        device_type='npu',
     ),
     evaluation=dict(
         enabled=False,
diff --git a/projects/ChatGLM/dataset.py b/projects/ChatGLM/dataset.py
index f09b87ac9..9f10f0af4 100644
--- a/projects/ChatGLM/dataset.py
+++ b/projects/ChatGLM/dataset.py
@@ -24,7 +24,8 @@
 from libai.utils.logger import setup_logger

 IGNORE_INDEX = -100
-logger = setup_logger()
+# logger = setup_logger()
+logger = setup_logger(name=__name__)

 class ChatGLMTrainDataset(Dataset):
diff --git a/projects/ChatGLM/lora/layers.py b/projects/ChatGLM/lora/layers.py
index 7fd54feb9..0746213b1 100644
--- a/projects/ChatGLM/lora/layers.py
+++ b/projects/ChatGLM/lora/layers.py
@@ -41,18 +41,22 @@ class BaseTunerLayer(ABC):
     active_adapter = None

     # All names of layers that may contain adapter (trainable) weights
-    adapter_layer_names: tuple[str] = ()
+    # adapter_layer_names: tuple[str] = ()
+    adapter_layer_names: tuple = ()
     # All names of other parameters that may contain adapter-related parameters
-    other_param_names: tuple[str] = ()
+    # other_param_names: tuple[str] = ()
+    other_param_names: tuple = ()

     # indicates whether all adapters should be disabled
     _disable_adapters: bool = False

     # the currently active adapter(s)
-    _active_adapter: str | list[str] = "default"
+    # _active_adapter: str | list[str] = "default"
+    _active_adapter: str = "default"

     # List all merged adapters
-    merged_adapters: list[str] = []
+    # merged_adapters: list[str] = []
+    merged_adapters: list = []

     def get_base_layer(self) -> nn.Module:
         """
@@ -72,7 +76,8 @@ def weight(self) -> flow.Tensor:
             weight = base_layer.weight
         return weight

-    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
+    # def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
+    def merge(self, safe_merge: bool = False, adapter_names = None) -> None:
         raise NotImplementedError

     def unmerge(self) -> None:
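
These edits drop PEP 604/585 annotations (`tuple[str]`, `str | list[str]`), which only parse on Python 3.10+. If the type information is worth keeping for documentation, the `typing` module spells the same thing in a 3.8-compatible way; this is a sketch of that alternative, not what the diff does:

    from typing import List, Optional, Tuple, Union

    class BaseTunerLayer:
        # Variadic tuple of names; the upstream tuple[str] annotation likely intended this.
        adapter_layer_names: Tuple[str, ...] = ()
        other_param_names: Tuple[str, ...] = ()
        _active_adapter: Union[str, List[str]] = "default"
        merged_adapters: List[str] = []

        def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
            raise NotImplementedError
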
@@ -119,7 +124,8 @@ def enable_adapters(self, enabled: bool) -> None:
                 layer.requires_grad_(False)
             self._disable_adapters = True

-    def set_adapter(self, adapter_names: str | list[str]) -> None:
+    # def set_adapter(self, adapter_names: str | list[str]) -> None:
+    def set_adapter(self, adapter_names) -> None:
         """Set the active adapter(s).

         Args:
@@ -142,7 +148,8 @@ def set_adapter(self, adapter_names: str | list[str]) -> None:
         self._active_adapter = adapter_names

-    def _all_available_adapter_names(self) -> list[str]:
+    # def _all_available_adapter_names(self) -> list[str]:
+    def _all_available_adapter_names(self) -> list:
         """Return a sorted list of all available adapter names"""
         adapter_names = set()
         for name in self.adapter_layer_names + self.other_param_names:
diff --git a/projects/ChatGLM/lora/lora_model.py b/projects/ChatGLM/lora/lora_model.py
index 2a19c6675..e06a39b6f 100644
--- a/projects/ChatGLM/lora/lora_model.py
+++ b/projects/ChatGLM/lora/lora_model.py
@@ -50,7 +50,8 @@ def __init__(self, model, peft_config, adapter_name: str) -> None:
         self.inject_adapter(self.model, adapter_name)

     @property
-    def active_adapters(self) -> list[str]:
+    # def active_adapters(self) -> list[str]:
+    def active_adapters(self) -> list:
         if isinstance(self.active_adapter, str):
             return [self.active_adapter]
         # is already a list of str
@@ -192,7 +193,8 @@ def inject_adapter(self, model: nn.Module, adapter_name: str):
             if adapter_name in n:
                 p.requires_grad = False

-    def merge_adapter(self, safe_merge=False, adapter_names: Optional[list[str]] = None) -> None:
+    # def merge_adapter(self, safe_merge=False, adapter_names: Optional[list[str]] = None) -> None:
+    def merge_adapter(self, safe_merge=False, adapter_names = None) -> None:
         """
         This method merges the adapter layers into the base model.
@@ -404,7 +406,8 @@ def disable_adapter_layers(self) -> None:
             warnings.warn(msg)
         self._set_adapter_layers(enabled=False)

-    def set_adapter(self, adapter_name: str | list[str]) -> None:
+    # def set_adapter(self, adapter_name: str | list[str]) -> None:
+    def set_adapter(self, adapter_name) -> None:
         """Set the active adapter(s).

         Args:
diff --git a/projects/ChatGLM/lora/utils.py b/projects/ChatGLM/lora/utils.py
index a1c195547..5990bb2a5 100644
--- a/projects/ChatGLM/lora/utils.py
+++ b/projects/ChatGLM/lora/utils.py
@@ -22,7 +22,8 @@
 COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks", "layer"]

-def check_target_module_exists(config, key: str) -> bool | re.Match[str] | None:
+# def check_target_module_exists(config, key: str) -> bool | re.Match[str] | None:
+def check_target_module_exists(config, key: str):
     """A helper method to check if the passed module's key name matches any of the target
     modules in the adapter_config.
diff --git a/projects/ChatGLM/pipeline.py b/projects/ChatGLM/pipeline.py
index 0505cc855..adc321411 100644
--- a/projects/ChatGLM/pipeline.py
+++ b/projects/ChatGLM/pipeline.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-
+import oneflow_npu
 from libai.inference.basic import BasePipeline
 from libai.utils import distributed as dist

@@ -94,7 +94,8 @@ def _parse_parameters(self, **pipeline_parameters):
         return preprocess_params, forward_params, postprocess_params

-    def preprocess(self, sentence: str | list, **kwargs) -> dict:
+    # def preprocess(self, sentence: str | list, **kwargs) -> dict:
+    def preprocess(self, sentence, **kwargs) -> dict:
         #
         if type(sentence) is str:
             inputs = {
@@ -162,7 +163,9 @@ def reset_conversation(self):
         tensor_parallel=1,
         pipeline_parallel=1,
         pipeline_num_layers=28,
-        model_path=os.environ["CHATGLM_HF_DIR"],
+        # model_path=os.environ["CHATGLM_HF_DIR"],
+        device_type='npu',
+        model_path='/data0/hf_models/chatglm/chatglm2-6b',
         mode="huggingface",
     )
     pipeline.model = pipeline.model.half()
diff --git a/run_chatglm_npu.sh b/run_chatglm_npu.sh
new file mode 100644
index 000000000..32036686d
--- /dev/null
+++ b/run_chatglm_npu.sh
@@ -0,0 +1,15 @@
+# set visible devices
+export ASCEND_RT_VISIBLE_DEVICES=1
+
+# debug
+export ONEFLOW_DEBUG=0
+export ASCEND_SLOG_PRINT_TO_STDOUT=0
+
+
+# infer
+python projects/ChatGLM/pipeline.py
+
+
+# # train
+# python projects/ChatGLM/utils/prepare_alpaca.py
+# bash tools/train.sh tools/train_net.py projects/ChatGLM/configs/ChatGLM_sft.py 1
\ No newline at end of file
diff --git a/tools/train.sh b/tools/train.sh
index 714ac9953..f2430c2c9 100755
--- a/tools/train.sh
+++ b/tools/train.sh
@@ -1,4 +1,11 @@
 #!/usr/bin/env bash
+# set visible devices
+export ASCEND_RT_VISIBLE_DEVICES=1
+
+# debug
+export ONEFLOW_DEBUG=0
+export ASCEND_SLOG_PRINT_TO_STDOUT=0
+
 FILE=$1
 CONFIG=$2
diff --git a/tools/train_net.py b/tools/train_net.py
index 458849c75..7cad038cc 100644
--- a/tools/train_net.py
+++ b/tools/train_net.py
@@ -14,32 +14,89 @@
 # limitations under the License.

 import logging
-import os
 import random
-import sys
+import importlib

 import numpy as np
 import oneflow as flow
+import oneflow_npu

-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
+import libai.utils.distributed as dist
 from libai.config import LazyConfig, default_argument_parser, try_get_key
 from libai.engine import DefaultTrainer, default_setup
 from libai.utils.checkpoint import Checkpointer
+# from configs.loader_mapping import loader_mapping_models as mapping

-logger = logging.getLogger("libai." + __name__)
+mapping = dict(
+    llama=dict(
+        loader_prefix="projects.Llama.utils.llama_loader",
+        huggingface_loader="LlamaLoaderHuggerFace",
+    ),
+
+    chatglm=dict(
+        loader_prefix="projects.ChatGLM.utils.chatglm_loader",
+        huggingface_loader="ChatGLMLoaderHuggerFace",
+    ),
+
+    qwen2=dict(
+        loader_prefix="projects.Qwen2.utils.qwen_loader",
+        huggingface_loader="Qwen2LoaderHuggerFace",
+    ),
+
+    aquila=dict(
+        loader_prefix="projects.Aquila.utils.aquila_loader",
+        huggingface_loader="AquilaLoaderHuggerFace",
+    )
+
+)
+
+
+
+
+def build_model(cfg):
+    model_arguments=mapping[cfg.cfg.model_type]
+    Loader = getattr(
+        importlib.import_module(model_arguments['loader_prefix']),
+        model_arguments['huggingface_loader'],
+    )
+    model_loader = Loader(
+        cfg,
+        cfg.cfg,
+        cfg.cfg.pretrained_model_path,
+    )
+    model = model_loader.load()
+    return model
+
+class Trainer(DefaultTrainer):
+    @classmethod
+    def build_model(cls, cfg):
+        assert try_get_key(cfg, "model") is not None, "cfg must contain `model` namespace"
+        # Set model fp16 option because of embedding layer `white_identity` manual
+        # insert for amp training if provided.
+        if try_get_key(cfg.model, "cfg.amp_enabled") is not None:
+            cfg.model.cfg.amp_enabled = cfg.train.amp.enabled and cfg.graph.enabled
+        # In case some model define without cfg keyword.
+        elif try_get_key(cfg.model, "amp_enabled") is not None:
+            cfg.model.amp_enabled = cfg.train.amp.enabled and cfg.graph.enabled
+        model = build_model(cfg.model)
+        logger = logging.getLogger(__name__)
+        logger.info("Model:\n{}".format(model))
+        model._apply(dist.convert_to_distributed_default_setting)
+
+        if cfg.train.train_with_fp16:
+            model = model.to(flow.float16)
+            flow.cuda.empty_cache()
+        '''for param in model.named_parameters():
+            print(param[1].dtype)'''
+
+        return model

 def main(args):
     cfg = LazyConfig.load(args.config_file)
     cfg = LazyConfig.apply_overrides(cfg, args.opts)
     default_setup(cfg, args)

-    seed_for_rank = cfg.train.seed + flow.env.get_rank()
-    flow.manual_seed(seed_for_rank)
-    flow.cuda.manual_seed(seed_for_rank)
-    np.random.seed(seed_for_rank)
-    random.seed(seed_for_rank)
-
     if args.fast_dev_run:
         cfg.train.train_epoch = 0
         cfg.train.train_iter = 20
@@ -58,11 +115,12 @@ def main(args):
         model = DefaultTrainer.build_graph(cfg, model, is_train=False)
         test_loader = DefaultTrainer.build_test_loader(cfg, tokenizer)
         if len(test_loader) == 0:
+            logger = logging.getLogger(__name__)
             logger.info("No dataset in dataloader.test, please set dataset for dataloader.test")
         _ = DefaultTrainer.test(cfg, test_loader, model)
         return

-    trainer = DefaultTrainer(cfg)
+    trainer = Trainer(cfg)
     return trainer.train()
diff --git a/train_chatglm.sh b/train_chatglm.sh
new file mode 100644
index 000000000..af19cbb49
--- /dev/null
+++ b/train_chatglm.sh
@@ -0,0 +1 @@
+bash tools/train.sh tools/train_net.py projects/ChatGLM/configs/chatglm_sft.py 1
\ No newline at end of file
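
The new `build_model` in tools/train_net.py indexes the inline loader mapping with `cfg.cfg.model_type`, so a config that omits that key, or uses a value outside llama/chatglm/qwen2/aquila, fails with a bare KeyError. A small hypothetical helper (`resolve_loader` is not part of the diff) showing the same importlib resolution with a clearer failure mode:

    import importlib

    def resolve_loader(model_type: str, mapping: dict):
        # Hypothetical: surface a readable error instead of a bare KeyError.
        try:
            entry = mapping[model_type]
        except KeyError:
            raise ValueError(f"unknown model_type {model_type!r}; expected one of {sorted(mapping)}")
        module = importlib.import_module(entry["loader_prefix"])
        return getattr(module, entry["huggingface_loader"])
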