feat(varlen): support varlen training for huggingface models (#283)
zigzagcai authored Aug 6, 2024
1 parent f6922c1 commit 3d84b85
Showing 11 changed files with 78 additions and 98 deletions.
12 changes: 11 additions & 1 deletion doc/en/usage.md
@@ -6,6 +6,16 @@ To start a demo model training, you need to prepare three things: **installation

Please refer to the [installation guide](./install.md) for instructions on how to install the necessary dependencies.

### Dataset Preparation (HuggingFace Datasets)
If you are using HuggingFace datasets for on-the-fly streaming loading and tokenization, then, taking the `roneneldan/TinyStories` dataset as an example, the data preparation stage only requires the following adjustments in the configuration file:
```python
TRAIN_FOLDER = "roneneldan/TinyStories"
data = dict(
type="hf",
tokenizer_path="internlm/internlm-7b",
)
```
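
For context, the `type="hf"` setting above means samples are streamed and tokenized on the fly rather than pre-tokenized into `bin`/`meta` files. The snippet below is a minimal, self-contained sketch of what that pipeline amounts to using the `datasets` and `transformers` libraries directly; the `split`, `truncation`, and `max_length` arguments are illustrative assumptions, and InternEvo's own `HuggingFaceStreamingDataset` may handle these details differently.

```python
# Minimal sketch of on-the-fly streaming load + tokenization (illustration only).
from datasets import load_dataset
from transformers import AutoTokenizer

dataset = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-7b", trust_remote_code=True)

for sample in dataset.take(2):  # take() is available on streaming (iterable) datasets
    tokens = tokenizer(sample["text"], truncation=True, max_length=2048)
    print(len(tokens["input_ids"]))
```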

### Dataset Preparation (Pre-training)

The dataset for the InternEvo training task includes a series of `bin` and `meta` files. A `tokenizer` is used to generate the training dataset from the original text files. The tokenizer model is imported by specifying the model parameter path in `tools/tokenizer.py`. Currently, `tokenizer_internlm.model` is provided to generate tokens. If you want to use a different model, you can directly modify the model parameter path in `tokenizer.py`.
@@ -417,4 +427,4 @@ model = dict(
Regarding the principle of Dynamic NTK, please refer to

1. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases
2. https://kexue.fm/archives/9675
2. https://kexue.fm/archives/9675
13 changes: 12 additions & 1 deletion doc/usage.md
@@ -5,6 +5,17 @@
### Installation
Please refer to the [installation guide](./install.md) for installation.

### Dataset Preparation (Using HuggingFace Datasets)

If you use a HuggingFace dataset with on-the-fly loading and on-the-fly tokenization, then, taking the `roneneldan/TinyStories` dataset as an example, the data preparation stage only requires the following changes to the configuration file:
```python
TRAIN_FOLDER = "roneneldan/TinyStories"
data = dict(
type="hf",
tokenizer_path="internlm/internlm-7b",
)
```

### Dataset Preparation (Pre-training)

The dataset for an InternEvo training task consists of a series of `bin` and `meta` files. A `tokenizer` is used to generate the training dataset from the original text files. The tokenizer model is imported by specifying the model parameter path in `tools/tokenizer.py`. Currently, `V7_sft.model` is provided to generate tokens. To use a different model, you can directly modify the model parameter path in `tokenizer.py`.
@@ -499,4 +510,4 @@ generation = dict(

For details on the principle of Dynamic NTK, please refer to
1. [dynamically_scaled_rope_further_increases](https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases)
2. [https://kexue.fm/archives/9675](https://kexue.fm/archives/9675)
2. [https://kexue.fm/archives/9675](https://kexue.fm/archives/9675)
39 changes: 15 additions & 24 deletions internlm/data/build_dataloader.py
@@ -7,7 +7,7 @@
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.data.streaming.batch_sampler import StreamingStaticBatchSampler
from internlm.data.streaming.collaters import nopack_collate_fn, pack_collate_fn
from internlm.data.streaming.collaters import pack_collate_fn
from internlm.data.streaming.dataset import (
HuggingFacePackedDataset,
HuggingFaceStreamingDataset,
@@ -32,7 +32,7 @@
)
from internlm.data.utils import get_dataset_type_ids_map
from internlm.utils.logger import get_logger
from internlm.utils.utils import DataType, ModelType
from internlm.utils.utils import DataType

# global llm logger
logger = get_logger(__file__)
@@ -120,34 +120,25 @@ def get_tokenized_valid_loader_items(data_cfg):


def get_hf_train_loader_items(data_cfg):
assert not data_cfg.pack_sample_into_one, "hf dataloader curently only supports pack_sample_into_one=False"
train_ds = HuggingFaceStreamingDataset(
dataset_name=data_cfg.train_folder,
tokenizer_name=data_cfg.tokenizer_path,
model_max_length=data_cfg.seq_len,
subset_name=data_cfg.get("subset_name", None),
)
pad_token_id = gpc.config.model.get("pad_token_id", 0)
if gpc.config.model_type == ModelType.HF.name and not data_cfg.use_packed_dataset:
train_sampler = StreamingStaticBatchSampler(
batch_size=data_cfg.micro_num * data_cfg.micro_bsz, rampup_batch_size=data_cfg.rampup_batch_size
)
train_collate_fn = partial(
nopack_collate_fn,
micro_num=data_cfg.micro_num,
micro_bsz=data_cfg.micro_bsz,
seq_len=data_cfg.seq_len,
pad_token_id=pad_token_id,
)
else:
train_ds = HuggingFacePackedDataset(
dataset=train_ds, seq_len=data_cfg.seq_len, micro_bsz=data_cfg.micro_bsz, pad_token_id=pad_token_id
)
train_sampler = StreamingStaticBatchSampler(
batch_size=data_cfg.micro_num, rampup_batch_size=data_cfg.rampup_batch_size
)
train_collate_fn = partial(
pack_collate_fn, micro_num=data_cfg.micro_num, micro_bsz=data_cfg.micro_bsz, seq_len=data_cfg.seq_len
)
train_ds = HuggingFacePackedDataset(
dataset=train_ds,
seq_len=data_cfg.seq_len,
micro_bsz=data_cfg.micro_bsz,
pad_token_id=gpc.config.model.get("pad_token_id", 0),
)
train_sampler = StreamingStaticBatchSampler(
batch_size=data_cfg.micro_num, rampup_batch_size=data_cfg.rampup_batch_size
)
train_collate_fn = partial(
pack_collate_fn, micro_num=data_cfg.micro_num, micro_bsz=data_cfg.micro_bsz, seq_len=data_cfg.seq_len
)
return train_ds, train_sampler, train_collate_fn


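With this change the HuggingFace dataloader always goes through `HuggingFacePackedDataset` and `pack_collate_fn`, i.e. variable-length samples are packed back-to-back into fixed-length micro batches instead of being padded one per row. The sketch below illustrates the packing idea under the assumption that each packed batch is a flat buffer of `micro_bsz * seq_len` tokens with `cu_seqlens` marking sample boundaries; the real dataset class may differ in how it splits, pads, and resumes.

```python
# Simplified sketch of sample packing for varlen training (not the actual implementation).
def pack_samples(samples, micro_bsz, seq_len, pad_token_id=0):
    packed_length = micro_bsz * seq_len
    buffer, cu_seqlens = [], [0]
    for input_ids in samples:
        if len(buffer) + len(input_ids) > packed_length:
            break  # a real implementation would start the next packed batch here
        buffer.extend(input_ids)
        cu_seqlens.append(len(buffer))
    # pad the remainder so every packed batch holds exactly packed_length tokens
    buffer.extend([pad_token_id] * (packed_length - len(buffer)))
    if cu_seqlens[-1] != packed_length:
        cu_seqlens.append(packed_length)
    return buffer, cu_seqlens

tokens, cu_seqlens = pack_samples([[5, 6, 7], [8, 9], [10, 11, 12, 13]], micro_bsz=1, seq_len=16)
print(cu_seqlens)  # [0, 3, 5, 9, 16] -> three real samples plus padding as the final segment
```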
3 changes: 1 addition & 2 deletions internlm/data/streaming/__init__.py
@@ -1,11 +1,10 @@
from .batch_sampler import StreamingStaticBatchSampler
from .collaters import nopack_collate_fn, pack_collate_fn
from .collaters import pack_collate_fn
from .dataset import HuggingFacePackedDataset, HuggingFaceStreamingDataset
from .utils import hf_simple_resume

__all__ = [
"StreamingStaticBatchSampler",
"nopack_collate_fn",
"pack_collate_fn",
"HuggingFaceStreamingDataset",
"HuggingFacePackedDataset",
35 changes: 0 additions & 35 deletions internlm/data/streaming/collaters.py
@@ -1,41 +1,6 @@
import torch


def nopack_collate_fn(batch, micro_num, micro_bsz, seq_len, pad_token_id=0):
input_ids_list = []
attention_mask_list = []
labels_list = []

for b in batch:
assert len(b["input_ids"]) > 0

if "attention_mask" in b:
assert len(b["input_ids"]) == len(
b["attention_mask"]
), "input_ids and attention_mask should be equal length"
else:
b["attention_mask"] = [True] * len(b["input_ids"])

input_ids = b["input_ids"] + [pad_token_id] * (seq_len - len(b["input_ids"]))
attention_mask = b["attention_mask"] + [False] * (seq_len - len(b["attention_mask"]))
labels = [w if w > 0 else -100 for w in b["input_ids"]][1:] + [-100]
labels = labels + [-100] * (seq_len - len(b["input_ids"]))

input_ids_list.append(torch.LongTensor(input_ids))
attention_mask_list.append(torch.BoolTensor(attention_mask))
labels_list.append(torch.LongTensor(labels))

input_ids = torch.stack(input_ids_list)
attention_mask = torch.stack(attention_mask_list)
labels = torch.stack(labels_list)

return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"type_ids": torch.zeros(micro_num, micro_bsz, seq_len, dtype=torch.int64),
}, labels


def pack_collate_fn(batch, micro_num, micro_bsz, seq_len):
packed_length = micro_bsz * seq_len

11 changes: 3 additions & 8 deletions internlm/data/utils.py
@@ -6,6 +6,7 @@
import torch

from internlm.core.context import global_context as gpc
from internlm.core.context.process_group_initializer import ParallelMode
from internlm.utils.utils import ModelType


@@ -51,9 +52,6 @@ def unpack_type_ids(type_ids, cu_seqlens):

def unpack_data(data, label):

if gpc.config.model_type == ModelType.HF.name:
return data, label

data["input_ids"] = _unpack_data(data["input_ids"], data["cu_seqlens"], padding_v=0).squeeze(0)
label = _unpack_data(label, data["cu_seqlens"], padding_v=-100).squeeze(0)

@@ -73,11 +71,8 @@ def packed_data_normalizer(data, label):
data["max_seqlen"] = (data["cu_seqlens"][1:] - data["cu_seqlens"][:-1]).max().item()

if gpc.config.model_type == ModelType.HF.name:
data.pop("cu_seqlens")
data.pop("max_seqlen")
gpc.config.data[f"cu_seqlens_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"] = data.pop("cu_seqlens")
gpc.config.data[f"max_seqlen_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"] = data.pop("max_seqlen")
data["position_ids"] = data.pop("indexes")
data["attention_mask"] = torch.ones(
(data["input_ids"].shape), dtype=torch.bool, device=data["input_ids"].device
)

return data, label
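
The key point of this hunk is that for HuggingFace models the packed batch's `cu_seqlens` and `max_seqlen` are no longer discarded: they are stashed into `gpc.config.data` keyed by the data-parallel rank (presumably for attention code elsewhere to pick up for a varlen kernel), while the model itself receives dense `position_ids` and an all-ones `attention_mask`. Below is a toy, self-contained sketch of that normalization with a plain dict standing in for `gpc.config.data` and made-up tensor values; it illustrates the data flow, not the actual function.

```python
# Toy sketch of the HF branch of packed_data_normalizer (illustration only).
import torch

def normalize_packed_batch_for_hf(data, label, stash, data_rank=0):
    data["max_seqlen"] = (data["cu_seqlens"][1:] - data["cu_seqlens"][:-1]).max().item()
    stash[f"cu_seqlens_data_rank{data_rank}"] = data.pop("cu_seqlens")
    stash[f"max_seqlen_data_rank{data_rank}"] = data.pop("max_seqlen")
    data["position_ids"] = data.pop("indexes")
    data["attention_mask"] = torch.ones(data["input_ids"].shape, dtype=torch.bool)
    return data, label

stash = {}
batch = {
    "input_ids": torch.arange(16).unsqueeze(0),    # (1, packed_length)
    "cu_seqlens": torch.tensor([0, 3, 5, 9, 16]),  # sample boundaries within the pack
    "indexes": torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6]),  # per-sample positions
}
data, _ = normalize_packed_batch_for_hf(batch, None, stash)
print(stash["max_seqlen_data_rank0"])  # 7: length of the longest packed segment
print(data["attention_mask"].shape)    # torch.Size([1, 16])
```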
19 changes: 13 additions & 6 deletions internlm/model/builder.py
@@ -6,11 +6,20 @@
from internlm.core.context import global_context as gpc
from internlm.core.parallel.shard import pipeline_parallel_sharding_wrapper
from internlm.model.registry import hf_config_initializer, model_initializer
from internlm.model.utils import convert_hf_config
from internlm.utils.common import get_current_device
from internlm.utils.utils import ModelType


def create_model(model_type, *args, **kwargs) -> Union[nn.Module, List[nn.Module]]:
def create_model(model_type) -> Union[nn.Module, List[nn.Module]]:

if model_type == ModelType.HF.name:
extra_kwargs = {"return_dict": False, "attn_implementation": "flash_attention_2"}
config = hf_config_initializer.get_module(module_name=model_type)(**extra_kwargs)
convert_hf_config(config)

kwargs = dict(gpc.config.model)

num_layers = kwargs.pop("num_layers")
num_chunks = kwargs.pop("num_chunks", 1)

@@ -26,17 +35,15 @@ def create_model(model_type, *args, **kwargs) -> Union[nn.Module, List[nn.Module

if not gpc.is_using_parallel_mode(ParallelMode.PIPELINE):
if model_type == ModelType.HF.name:
hf_config_builder = hf_config_initializer.get_module(module_name=model_type)
config = hf_config_builder(return_dict=False)
model = model_buidler(*args, config).to(kwargs["device"])
model = model_buidler(config).to(kwargs["device"])
else:
kwargs["first"] = kwargs["last"] = True
kwargs["start_layer_idx"] = 0
kwargs["num_layers"] = num_layers
model = model_buidler(*args, **kwargs).to(kwargs["device"])
model = model_buidler(**kwargs).to(kwargs["device"])
setattr(model, "first_layer", 0)
setattr(model, "last_layer", num_layers)
else:
model = pipeline_parallel_sharding_wrapper(num_layers, num_chunks, model_buidler, *args, **kwargs)
model = pipeline_parallel_sharding_wrapper(num_layers, num_chunks, model_buidler, **kwargs)

return model
13 changes: 13 additions & 0 deletions internlm/model/utils.py
@@ -1,5 +1,6 @@
from typing import Any, Dict, List

from internlm.core.context import global_context as gpc
from internlm.model.modules.mha import MHA


@@ -51,3 +52,15 @@ def convert_attn_args_to_kwargs(args, kwargs) -> Dict[str, Any]:
kwargs["max_seqlen"] = args[3]

return kwargs


def convert_hf_config(config):
gpc.config.model.vocab_size = gpc.config.VOCAB_SIZE = config.vocab_size
gpc.config.model.hidden_size = gpc.config.HIDDEN_SIZE = config.hidden_size
gpc.config.model.num_layers = gpc.config.NUM_LAYER = config.num_hidden_layers
gpc.config.model.num_attention_heads = gpc.config.NUM_ATTENTION_HEAD = config.num_attention_heads
gpc.config.model.mlp_ratio = gpc.config.MLP_RATIO = config.intermediate_size / config.hidden_size

# For models that use GQA
if hasattr(config, "num_key_value_heads"):
gpc.config.model.num_kv_attention_heads = gpc.config.NUM_KV_ATTENTION_HEAD = config.num_key_value_heads
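
`convert_hf_config` mirrors the relevant HuggingFace config fields into InternEvo's global config so that the rest of the trainer (and `create_model`, which now builds its kwargs from `gpc.config.model`) can read `hidden_size`, `num_layers`, and so on without caring that the model is an HF one. The standalone sketch below reproduces the same mapping against a `transformers.LlamaConfig`, with a `SimpleNamespace` in place of `gpc.config.model`; the config values are made up for the example.

```python
# Standalone illustration of the convert_hf_config field mapping (values are illustrative).
from types import SimpleNamespace
from transformers import LlamaConfig

hf_config = LlamaConfig(
    vocab_size=32000, hidden_size=4096, num_hidden_layers=32,
    num_attention_heads=32, num_key_value_heads=8, intermediate_size=11008,
)

model_cfg = SimpleNamespace()
model_cfg.vocab_size = hf_config.vocab_size
model_cfg.hidden_size = hf_config.hidden_size
model_cfg.num_layers = hf_config.num_hidden_layers
model_cfg.num_attention_heads = hf_config.num_attention_heads
model_cfg.mlp_ratio = hf_config.intermediate_size / hf_config.hidden_size
if hasattr(hf_config, "num_key_value_heads"):  # GQA models expose this field
    model_cfg.num_kv_attention_heads = hf_config.num_key_value_heads

print(model_cfg.mlp_ratio)  # 2.6875 for these Llama-7B-like values
```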
26 changes: 7 additions & 19 deletions internlm/train/pipeline.py
@@ -86,7 +86,7 @@
sync_model_replica_param_group,
)
from internlm.utils.timeout import llm_timeout
from internlm.utils.utils import DataType, ModelType, TensorParallelMode
from internlm.utils.utils import TensorParallelMode

try:
import torch_npu
@@ -173,7 +173,7 @@ def initialize_model(pre_process_func: Optional[Callable] = None, post_process_f

register_model_initializer()

model = create_model(model_type=gpc.config.model_type, **(gpc.config.model))
model = create_model(model_type=gpc.config.model_type)

if post_process_func:
post_process_func(pre_process_output)
@@ -464,8 +464,7 @@ def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: Trai
if batch[0].get("type_ids", None) is not None:
# if use_packed_dataset is False, we need to unpack type_ids
if not gpc.config.data.use_packed_dataset:
if gpc.config.data.type != DataType.hf.name or gpc.config.model_type != ModelType.HF.name:
batch[0]["type_ids"] = unpack_type_ids(batch[0]["type_ids"], batch[0]["cu_seqlens"])
batch[0]["type_ids"] = unpack_type_ids(batch[0]["type_ids"], batch[0]["cu_seqlens"])

return batch, train_iter

@@ -555,21 +554,10 @@ def record_current_batch_training_metrics(

num_tokens_in_batch = batch[1].nelement()
real_num_tokens = math.ceil(acc_perplex.pop("real_token_num") / gpc.get_world_size(ParallelMode.GLOBAL))
# TODO: check logic
if (
gpc.config.data.type == DataType.hf.name
and gpc.config.model_type == ModelType.HF.name
and not gpc.config.data.use_packed_dataset
):
num_samples_in_batch = gpc.config.data.micro_bsz * gpc.config.data.micro_num
max_length_in_batch = batch[0]["attention_mask"].sum(dim=1).max().item()
max_samples_in_batch = gpc.config.data.micro_bsz
min_samples_in_batch = gpc.config.data.micro_bsz
else:
num_samples_in_batch = sum([len(b) - 1 for b in batch[0]["cu_seqlens"]])
max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]])
max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]])
min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]])
num_samples_in_batch = sum([len(b) - 1 for b in batch[0]["cu_seqlens"]])
max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]])
max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]])
min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]])
time_cost = time.time() - start_time
tk_per_gpu = round(
num_tokens_in_batch * gpc.get_world_size(ParallelMode.DATA) / gpc.get_world_size(ParallelMode.GLOBAL),
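With the HF-specific branch removed, the per-batch training metrics are always derived from `cu_seqlens`, regardless of model type. A small worked example of how those statistics fall out of the boundary tensors (the values here are made up):

```python
# Worked example: batch statistics from cu_seqlens (each tensor is one packed micro batch).
import torch

cu_seqlens = [torch.tensor([0, 3, 5, 9, 16]), torch.tensor([0, 8, 16])]

num_samples_in_batch = sum(len(b) - 1 for b in cu_seqlens)                    # 4 + 2 = 6
max_length_in_batch = max((b[1:] - b[:-1]).max().item() for b in cu_seqlens)  # max(7, 8) = 8
max_samples_in_batch = max(len(b) - 1 for b in cu_seqlens)                    # 4
min_samples_in_batch = min(len(b) - 1 for b in cu_seqlens)                    # 2

print(num_samples_in_batch, max_length_in_batch, max_samples_in_batch, min_samples_in_batch)
```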
3 changes: 2 additions & 1 deletion requirements/runtime.txt
@@ -1,14 +1,15 @@
transformers
sentencepiece
datasets
numpy
tqdm
einops
psutil
packaging
pre-commit
ninja
gputil
pytest
packaging
boto3
botocore
torch-scatter
2 changes: 1 addition & 1 deletion tests/test_utils/common_fixture.py
@@ -91,7 +91,7 @@

def init_naive_model():
register_model_initializer()
model = create_model(model_type=gpc.config.model_type, **(init_config.model))
model = create_model(model_type=gpc.config.model_type)
model = NaiveAMPModel(
model=model,
output_to_fp32=False,
