From b89f03a584ca1bbaffd242b7bb3adbd343daf5ed Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 4 May 2025 14:50:53 +0800 Subject: [PATCH 1/6] fix enable_cache --- ...73\244\350\241\214\345\217\202\346\225\260.md" | 1 - .../Instruction/Command-line-parameters.md | 1 - swift/llm/argument/base_args/data_args.py | 4 ---- swift/llm/dataset/__init__.py | 7 ++++--- swift/llm/dataset/loader.py | 15 ++++++++++++--- swift/llm/dataset/preprocessor/core.py | 9 +++++++-- 6 files changed, 23 insertions(+), 14 deletions(-) diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index edda462034..951be9f45c 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -52,7 +52,6 @@ - interleave_prob: 默认值为 None。在组合多个数据集时,默认使用 `concatenate_datasets` 函数;如果设置了该参数,则会使用 `interleave_datasets` 函数。该参数通常用于流式数据集的组合,并会作为参数传入 `interleave_datasets` 函数中 - stopping_strategy: 可选为"first_exhausted", "all_exhausted",默认为"first_exhausted"。传入interleave_datasets函数中 - shuffle_buffer_size: 该参数用于指定流式数据集的随机buffer大小,默认为1000 -- enable_cache: 数据集预处理使用cache,默认False - download_mode: 数据集下载模式,包含`reuse_dataset_if_exists`和`force_redownload`,默认为reuse_dataset_if_exists - columns: 用于对数据集进行列映射,使数据集满足AutoPreprocessor可以处理的样式,具体查看[这里](../Customization/自定义数据集.md)。你可以传入json字符串,例如:`'{"text1": "query", "text2": "response"}'`,默认为None。 - strict: 如果为True,则数据集只要某行有问题直接抛错,否则会丢弃出错数据样本。默认False diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 49aa0328a6..b9924dccd1 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -53,7 +53,6 @@ Hints: - interleave_prob: Defaults to None. When combining multiple datasets, the `concatenate_datasets` function is used by default. If this parameter is set, the `interleave_datasets` function will be used instead. This parameter is typically used when combining streaming datasets and is passed to the `interleave_datasets` function. - stopping_strategy: Can be either "first_exhausted" or "all_exhausted", with the default being "first_exhausted". This parameter is passed to the `interleave_datasets` function. - shuffle_buffer_size: This parameter is used to specify the shuffle buffer size for streaming datasets. Defaults to 1000. -- enable_cache: Use cache for dataset preprocessing, default is False. - download_mode: Dataset download mode, including `reuse_dataset_if_exists` and `force_redownload`, default is reuse_dataset_if_exists. - columns: Used for column mapping of the dataset to ensure that the dataset conforms to the format that AutoPreprocessor can handle. For more details, see [here](../Customization/Custom-dataset.md). You can pass in a JSON string, for example: `'{"text1": "query", "text2": "response"}'`, with the default being None. - strict: If set to True, any row with an issue in the dataset will throw an error immediately, otherwise, erroneous data samples will be discarded. Default is False. 
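Note: with enable_cache removed from the documentation and from DataArguments, cache reuse for dataset preprocessing is instead controlled per call through the load_from_cache_file argument threaded into the loader and preprocessors in the hunks below (it defaults to False in this commit and is flipped to True later in the series). A minimal sketch of the resulting call path, assuming load_dataset is re-exported at swift.llm as suggested by swift/llm/dataset/__init__.py, and that the dataset id and the (train, val) return convention are illustrative assumptions rather than part of this patch:

    from swift.llm import load_dataset  # assumed re-export; the patch only shows swift/llm/dataset/__init__.py

    # Only the keyword arguments below appear in this series; everything else is assumed.
    train_dataset, val_dataset = load_dataset(
        ['AI-ModelScope/alpaca-gpt4-data-en'],  # illustrative dataset id, not taken from this patch
        split_dataset_ratio=0.01,
        num_proc=4,
        load_from_cache_file=True,  # reuse the HF `datasets` map cache instead of the removed enable_cache flag
    )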
diff --git a/swift/llm/argument/base_args/data_args.py b/swift/llm/argument/base_args/data_args.py index 519eb945a6..ffb337fb91 100644 --- a/swift/llm/argument/base_args/data_args.py +++ b/swift/llm/argument/base_args/data_args.py @@ -22,7 +22,6 @@ class DataArguments: data_seed (Optional[int]): Seed for dataset shuffling. Default is None. dataset_num_proc (int): Number of processes to use for data loading and preprocessing. Default is 1. streaming (bool): Flag to enable streaming of datasets. Default is False. - enable_cache (bool): Flag to load dataset from cache file. Default is False. download_mode (Literal): Mode for downloading datasets. Default is 'reuse_dataset_if_exists'. columns: Used for manual column mapping of datasets. model_name (List[str]): List containing Chinese and English names of the model. Default is [None, None]. @@ -46,7 +45,6 @@ class DataArguments: stopping_strategy: Literal['first_exhausted', 'all_exhausted'] = 'first_exhausted' shuffle_buffer_size: int = 1000 - enable_cache: bool = False download_mode: Literal['force_redownload', 'reuse_dataset_if_exists'] = 'reuse_dataset_if_exists' columns: Optional[Union[dict, str]] = None strict: bool = False @@ -69,8 +67,6 @@ def __post_init__(self): if self.data_seed is None: self.data_seed = self.seed self.columns = self.parse_to_dict(self.columns) - if self.enable_cache: - enable_caching() if len(self.val_dataset) > 0 or self.streaming: self.split_dataset_ratio = 0. if len(self.val_dataset) > 0: diff --git a/swift/llm/dataset/__init__.py b/swift/llm/dataset/__init__.py index d4c6920242..6d8a54c47d 100644 --- a/swift/llm/dataset/__init__.py +++ b/swift/llm/dataset/__init__.py @@ -18,10 +18,12 @@ _get_temporary_cache_files_directory = datasets.fingerprint.get_temporary_cache_files_directory -def _update_fingerprint_mac(*args, **kwargs): +def _update_fingerprint_mac(fingerprint, transform, transform_args): # Prevent different nodes use the same location in unique shared disk mac = _find_local_mac().replace(':', '') - fp = _update_fingerprint(*args, **kwargs) + if 'function' in transform_args and hasattr(transform_args['function'], '__self__'): + transform_args['function'] = str(transform_args['function'].__self__.__class__) + fp = _update_fingerprint(fingerprint, transform, transform_args) fp += '-' + mac if len(fp) > 64: fp = fp[:64] @@ -33,4 +35,3 @@ def _update_fingerprint_mac(*args, **kwargs): datasets.fingerprint.get_temporary_cache_files_directory = get_temporary_cache_files_directory datasets.arrow_dataset.get_temporary_cache_files_directory = get_temporary_cache_files_directory register_dataset_info() -disable_caching() diff --git a/swift/llm/dataset/loader.py b/swift/llm/dataset/loader.py index 13cd09896f..56ce7d672a 100644 --- a/swift/llm/dataset/loader.py +++ b/swift/llm/dataset/loader.py @@ -198,6 +198,7 @@ def _load_dataset_path( dataset_meta: DatasetMeta, *, num_proc: int = 1, + load_from_cache_file: bool = False, strict: bool = False, streaming: bool = False, columns: Optional[Dict[str, str]] = None, @@ -211,7 +212,8 @@ def _load_dataset_path( dataset = hf_load_dataset(file_type, data_files=dataset_path, **kwargs) if columns: dataset = RowPreprocessor.safe_rename_columns(dataset, columns) - dataset = dataset_meta.preprocess_func(dataset, num_proc=num_proc, strict=strict) + dataset = dataset_meta.preprocess_func( + dataset, num_proc=num_proc, load_from_cache_file=load_from_cache_file, strict=strict) if remove_unused_columns: dataset = RowPreprocessor.remove_useless_columns(dataset) return dataset @@ -222,6 
+224,7 @@ def _load_repo_dataset( subset: SubsetDataset, *, num_proc: int = 1, + load_from_cache_file: bool = False, streaming: bool = False, use_hf: Optional[bool] = None, hub_token: Optional[str] = None, @@ -282,7 +285,8 @@ def _load_repo_dataset( dataset = dataset.to_iterable_dataset() if columns: dataset = RowPreprocessor.safe_rename_columns(dataset, columns) - dataset = subset.preprocess_func(dataset, num_proc=num_proc, strict=strict) + dataset = subset.preprocess_func( + dataset, num_proc=num_proc, load_from_cache_file=load_from_cache_file, strict=strict) if remove_unused_columns: dataset = RowPreprocessor.remove_useless_columns(dataset) datasets.append(dataset) @@ -373,6 +377,7 @@ def load( dataset_meta: Optional[DatasetMeta] = None, *, num_proc: int = 1, + load_from_cache_file: bool = False, streaming: bool = False, use_hf: Optional[bool] = None, hub_token: Optional[str] = None, @@ -386,6 +391,7 @@ def load( dataset_syntax.dataset, dataset_meta=dataset_meta, num_proc=num_proc, + load_from_cache_file=load_from_cache_file, strict=strict, streaming=streaming, columns=columns, @@ -402,6 +408,7 @@ def load( use_hf=use_hf, hub_token=hub_token, num_proc=num_proc, + load_from_cache_file=load_from_cache_file, strict=strict, revision=revision, streaming=streaming, @@ -435,6 +442,7 @@ def load_dataset( split_dataset_ratio: float = 0., seed: Union[int, np.random.RandomState, None] = None, num_proc: int = 1, + load_from_cache_file: bool = False, shuffle: bool = False, streaming: bool = False, interleave_prob: Optional[List[float]] = None, @@ -444,7 +452,7 @@ def load_dataset( hub_token: Optional[str] = None, strict: bool = False, download_mode: Literal['force_redownload', 'reuse_dataset_if_exists'] = 'reuse_dataset_if_exists', - columns: Optional[Dict[str, str]] = None, + columns: Optional[Dict[str, str]] = None, # columns_mapping remove_unused_columns: bool = True, # self-cognition model_name: Union[Tuple[str, str], List[str], None] = None, # zh, en @@ -482,6 +490,7 @@ def load_dataset( val_datasets = [] load_kwargs = { 'num_proc': num_proc, + 'load_from_cache_file': load_from_cache_file, 'strict': strict, 'download_mode': download_mode, 'columns': columns, diff --git a/swift/llm/dataset/preprocessor/core.py b/swift/llm/dataset/preprocessor/core.py index 49077f8c6e..15c66a8391 100644 --- a/swift/llm/dataset/preprocessor/core.py +++ b/swift/llm/dataset/preprocessor/core.py @@ -279,6 +279,7 @@ def __call__( dataset: DATASET_TYPE, *, num_proc: int = 1, + load_from_cache_file: bool = False, strict: bool = False, batch_size: Optional[int] = None, ) -> DATASET_TYPE: @@ -290,7 +291,10 @@ def __call__( map_kwargs = {'batched': True, 'batch_size': batch_size} if isinstance(dataset, HfDataset): - map_kwargs['num_proc'] = num_proc + map_kwargs.update({ + 'num_proc': num_proc, + 'load_from_cache_file': load_from_cache_file, + }) # compat GRPO: The solution field will be retained. 
dataset = RowPreprocessor.get_features_dataset(dataset) if 'solution' in dataset.features: @@ -509,8 +513,9 @@ def __call__( dataset: DATASET_TYPE, *, num_proc: int = 1, + load_from_cache_file: bool = False, strict: bool = False, ) -> DATASET_TYPE: dataset = RowPreprocessor.safe_rename_columns(dataset, self.columns) preprocessor = self._get_preprocessor(dataset) - return preprocessor(dataset, num_proc=num_proc, strict=strict) + return preprocessor(dataset, num_proc=num_proc, load_from_cache_file=load_from_cache_file, strict=strict) From 7a7944340f039f56e17880feda411cd18f472e85 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 4 May 2025 15:27:40 +0800 Subject: [PATCH 2/6] update --- swift/llm/dataset/__init__.py | 19 +++++-------------- swift/llm/dataset/preprocessor/core.py | 9 ++++++--- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/swift/llm/dataset/__init__.py b/swift/llm/dataset/__init__.py index 6d8a54c47d..1b7bf36f74 100644 --- a/swift/llm/dataset/__init__.py +++ b/swift/llm/dataset/__init__.py @@ -1,9 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import datasets.fingerprint from datasets import Dataset as HfDataset -from datasets import disable_caching -from swift.utils.torch_utils import _find_local_mac from ..utils import get_temporary_cache_files_directory from . import dataset from .loader import DATASET_TYPE, load_dataset @@ -14,24 +12,17 @@ from .utils import (EncodePreprocessor, GetLengthPreprocessor, IterablePackingDataset, LazyLLMDataset, PackingDataset, sample_dataset) -_update_fingerprint = datasets.fingerprint.update_fingerprint -_get_temporary_cache_files_directory = datasets.fingerprint.get_temporary_cache_files_directory +update_fingerprint_origin = datasets.fingerprint.update_fingerprint -def _update_fingerprint_mac(fingerprint, transform, transform_args): - # Prevent different nodes use the same location in unique shared disk - mac = _find_local_mac().replace(':', '') +def update_fingerprint(fingerprint, transform, transform_args): if 'function' in transform_args and hasattr(transform_args['function'], '__self__'): transform_args['function'] = str(transform_args['function'].__self__.__class__) - fp = _update_fingerprint(fingerprint, transform, transform_args) - fp += '-' + mac - if len(fp) > 64: - fp = fp[:64] - return fp + return update_fingerprint_origin(fingerprint, transform, transform_args) -datasets.fingerprint.update_fingerprint = _update_fingerprint_mac -datasets.arrow_dataset.update_fingerprint = _update_fingerprint_mac +datasets.fingerprint.update_fingerprint = update_fingerprint +datasets.arrow_dataset.update_fingerprint = update_fingerprint datasets.fingerprint.get_temporary_cache_files_directory = get_temporary_cache_files_directory datasets.arrow_dataset.get_temporary_cache_files_directory = get_temporary_cache_files_directory register_dataset_info() diff --git a/swift/llm/dataset/preprocessor/core.py b/swift/llm/dataset/preprocessor/core.py index 15c66a8391..d1d79eb7a4 100644 --- a/swift/llm/dataset/preprocessor/core.py +++ b/swift/llm/dataset/preprocessor/core.py @@ -11,7 +11,7 @@ from datasets import Sequence, Value from swift.llm import history_to_messages -from swift.utils import get_logger +from swift.utils import get_logger, is_dist, is_master, safe_ddp_context DATASET_TYPE = Union[HfDataset, HfIterableDataset] @@ -291,6 +291,8 @@ def __call__( map_kwargs = {'batched': True, 'batch_size': batch_size} if isinstance(dataset, HfDataset): + if is_dist() and not is_master(): + load_from_cache_file = True 
map_kwargs.update({ 'num_proc': num_proc, 'load_from_cache_file': load_from_cache_file, @@ -298,13 +300,14 @@ def __call__( # compat GRPO: The solution field will be retained. dataset = RowPreprocessor.get_features_dataset(dataset) if 'solution' in dataset.features: - dataset = dataset.map(lambda x: {'__#solution': x['solution']}, **map_kwargs) + with safe_ddp_context('dataset_map_solution', True): + dataset = dataset.map(lambda x: {'__#solution': x['solution']}, **map_kwargs) dataset = self._rename_columns(dataset) dataset = self.prepare_dataset(dataset) dataset = self._cast_pil_image(dataset) ignore_max_length_error = True if isinstance(dataset, HfDataset) and num_proc > 1 else False - with self._patch_arrow_writer(): + with self._patch_arrow_writer(), safe_ddp_context('dataset_map', True): try: dataset_mapped = dataset.map( self.batched_preprocess, From d80c218cf3740c81fad0618ab4b9d122489f9ddc Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 4 May 2025 16:55:42 +0800 Subject: [PATCH 3/6] update --- examples/train/lora_sft.sh | 1 + swift/llm/dataset/utils.py | 26 +++++++++++++------------- swift/llm/train/sft.py | 4 ++-- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/examples/train/lora_sft.sh b/examples/train/lora_sft.sh index 07195979bb..849adedf54 100644 --- a/examples/train/lora_sft.sh +++ b/examples/train/lora_sft.sh @@ -1,4 +1,5 @@ # 22GB +# qwen3: https://github.com/modelscope/ms-swift/blob/main/examples/train/think_model/qwen3_demo1.sh CUDA_VISIBLE_DEVICES=0 \ swift sft \ --model Qwen/Qwen2.5-7B-Instruct \ diff --git a/swift/llm/dataset/utils.py b/swift/llm/dataset/utils.py index 82d3dcc7e3..a243095692 100644 --- a/swift/llm/dataset/utils.py +++ b/swift/llm/dataset/utils.py @@ -111,14 +111,14 @@ def __len__(self) -> int: class BasePackingDataset: - def __init__(self, template, dataset, num_workers: int = 1, *, packing_interval: int = 128, strict: bool = False): + def __init__(self, template, dataset, num_proc: int = 1, *, packing_interval: int = 128, strict: bool = False): template._packing = True self.template = template self.dataset = dataset - self.num_workers = num_workers + self.num_proc = num_proc self.packing_interval = packing_interval self.strict = strict - assert num_workers >= 1, f'num_workers: {num_workers}' + assert num_proc >= 1, f'num_proc: {num_proc}' self.workers = [] @staticmethod @@ -150,13 +150,13 @@ def _encode_data(self, data) -> Dict[str, Any]: class PackingDataset(BasePackingDataset, Dataset): - def __init__(self, template, dataset, num_workers: int = 1, *, packing_interval: int = 128, strict: bool = False): - super().__init__(template, dataset, num_workers, packing_interval=packing_interval, strict=strict) - self.prog_bar = tqdm(total=len(dataset), dynamic_ncols=True, desc='Packing') + def __init__(self, template, dataset, num_proc: int = 1, *, packing_interval: int = 128, strict: bool = False): + super().__init__(template, dataset, num_proc, packing_interval=packing_interval, strict=strict) + self.prog_bar = tqdm(total=len(dataset), dynamic_ncols=True, desc=f'Packing (num_proc={num_proc}):') self._queue = mp.Queue() self._terminated_workers = 0 - for i in range(self.num_workers): - shard_dataset = self.dataset.shard(self.num_workers, i) + for i in range(self.num_proc): + shard_dataset = self.dataset.shard(self.num_proc, i) worker = mp.Process(target=self._producer, args=(shard_dataset, ), daemon=True) worker.start() self.workers.append(worker) @@ -172,7 +172,7 @@ def fetch_packing_data(self, res: Optional[list] = None): data = 
self._queue.get() if data is None: self._terminated_workers += 1 - if self._terminated_workers == self.num_workers: + if self._terminated_workers == self.num_proc: break continue self.prog_bar.update(1) @@ -185,7 +185,7 @@ def get_packed_dataset(self): result = [] while True: data = self.fetch_packing_data(data) - is_finished = self._terminated_workers == self.num_workers + is_finished = self._terminated_workers == self.num_proc res, data = self.calculate_matched_group(self.template, data, is_finished=is_finished) result += res if is_finished: @@ -213,17 +213,17 @@ class IterablePackingDataset(BasePackingDataset, IterableDataset): def __init__(self, template, dataset, - num_workers: int = 1, + num_proc: int = 1, *, packing_interval: int = 128, strict: bool = False, cyclic: bool = False): - super().__init__(template, dataset, num_workers, packing_interval=packing_interval, strict=strict) + super().__init__(template, dataset, num_proc, packing_interval=packing_interval, strict=strict) self._in_queue = mp.Queue() self._out_queue = mp.Queue() self.workers = [] self.cyclic = cyclic - for _ in range(self.num_workers): + for _ in range(self.num_proc): worker = mp.Process(target=self._processor, daemon=True) worker.start() self.workers.append(worker) diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index 9ea8a31bb8..3b2a5d922d 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -253,12 +253,12 @@ def _encode_dataset(self, train_dataset, val_dataset): train_dataset = packing_dataset_cls( self.template, train_dataset, - num_workers=args.dataset_num_proc, + num_proc=args.dataset_num_proc, strict=args.strict, **dataset_kwargs) if val_dataset is not None: val_dataset = packing_dataset_cls( - self.template, val_dataset, num_workers=args.dataset_num_proc, strict=args.strict) + self.template, val_dataset, num_proc=args.dataset_num_proc, strict=args.strict) elif args.lazy_tokenize: train_dataset = LazyLLMDataset( train_dataset, template.encode, strict=args.strict, random_state=args.data_seed) From 61063bfcbb88290edac29cf676ea0639bf07e6c2 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 4 May 2025 17:20:07 +0800 Subject: [PATCH 4/6] update --- swift/llm/dataset/loader.py | 8 +- swift/llm/dataset/preprocessor/core.py | 17 +-- swift/llm/dataset/utils.py | 138 ++++++++++--------------- swift/llm/train/sft.py | 6 +- swift/utils/torch_utils.py | 3 +- 5 files changed, 72 insertions(+), 100 deletions(-) diff --git a/swift/llm/dataset/loader.py b/swift/llm/dataset/loader.py index 56ce7d672a..247decd644 100644 --- a/swift/llm/dataset/loader.py +++ b/swift/llm/dataset/loader.py @@ -198,7 +198,7 @@ def _load_dataset_path( dataset_meta: DatasetMeta, *, num_proc: int = 1, - load_from_cache_file: bool = False, + load_from_cache_file: bool = True, strict: bool = False, streaming: bool = False, columns: Optional[Dict[str, str]] = None, @@ -224,7 +224,7 @@ def _load_repo_dataset( subset: SubsetDataset, *, num_proc: int = 1, - load_from_cache_file: bool = False, + load_from_cache_file: bool = True, streaming: bool = False, use_hf: Optional[bool] = None, hub_token: Optional[str] = None, @@ -377,7 +377,7 @@ def load( dataset_meta: Optional[DatasetMeta] = None, *, num_proc: int = 1, - load_from_cache_file: bool = False, + load_from_cache_file: bool = True, streaming: bool = False, use_hf: Optional[bool] = None, hub_token: Optional[str] = None, @@ -442,7 +442,7 @@ def load_dataset( split_dataset_ratio: float = 0., seed: Union[int, np.random.RandomState, None] = None, num_proc: int = 1, - 
load_from_cache_file: bool = False, + load_from_cache_file: bool = True, shuffle: bool = False, streaming: bool = False, interleave_prob: Optional[List[float]] = None, diff --git a/swift/llm/dataset/preprocessor/core.py b/swift/llm/dataset/preprocessor/core.py index d1d79eb7a4..2507bfa493 100644 --- a/swift/llm/dataset/preprocessor/core.py +++ b/swift/llm/dataset/preprocessor/core.py @@ -1,7 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import ast from collections import Counter -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from typing import Any, Callable, Dict, List, Optional, Union import numpy as np @@ -279,7 +279,7 @@ def __call__( dataset: DATASET_TYPE, *, num_proc: int = 1, - load_from_cache_file: bool = False, + load_from_cache_file: bool = True, strict: bool = False, batch_size: Optional[int] = None, ) -> DATASET_TYPE: @@ -290,9 +290,12 @@ def __call__( dataset = sample_dataset(dataset, self.dataset_sample, True, self.random_state) map_kwargs = {'batched': True, 'batch_size': batch_size} + map_context = nullcontext() if isinstance(dataset, HfDataset): - if is_dist() and not is_master(): - load_from_cache_file = True + if not load_from_cache_file: + map_context = safe_ddp_context(None, True) + if is_dist() and not is_master(): + load_from_cache_file = True map_kwargs.update({ 'num_proc': num_proc, 'load_from_cache_file': load_from_cache_file, @@ -300,14 +303,14 @@ def __call__( # compat GRPO: The solution field will be retained. dataset = RowPreprocessor.get_features_dataset(dataset) if 'solution' in dataset.features: - with safe_ddp_context('dataset_map_solution', True): + with map_context: dataset = dataset.map(lambda x: {'__#solution': x['solution']}, **map_kwargs) dataset = self._rename_columns(dataset) dataset = self.prepare_dataset(dataset) dataset = self._cast_pil_image(dataset) ignore_max_length_error = True if isinstance(dataset, HfDataset) and num_proc > 1 else False - with self._patch_arrow_writer(), safe_ddp_context('dataset_map', True): + with self._patch_arrow_writer(), map_context: try: dataset_mapped = dataset.map( self.batched_preprocess, @@ -516,7 +519,7 @@ def __call__( dataset: DATASET_TYPE, *, num_proc: int = 1, - load_from_cache_file: bool = False, + load_from_cache_file: bool = True, strict: bool = False, ) -> DATASET_TYPE: dataset = RowPreprocessor.safe_rename_columns(dataset, self.columns) diff --git a/swift/llm/dataset/utils.py b/swift/llm/dataset/utils.py index a243095692..2e782853d5 100644 --- a/swift/llm/dataset/utils.py +++ b/swift/llm/dataset/utils.py @@ -109,98 +109,56 @@ def __len__(self) -> int: return len(self.dataset) -class BasePackingDataset: +def calculate_matched_group(template, sequences, is_finished: bool = True): + if len(sequences) == 0: + return [], [] + # https://arxiv.org/pdf/2404.10830 + import binpacking + sequences = binpacking.to_constant_volume(sequences, template.max_length, weight_pos=1) + res = [] + if sequences and not is_finished: + sequences, ret_sequences = sequences[:-1], sequences[-1] + else: + ret_sequences = [] + for row in sequences: + packed = template.packing_row(row) + res.append(packed) + return res, ret_sequences + + +class PackingDataset(Dataset): def __init__(self, template, dataset, num_proc: int = 1, *, packing_interval: int = 128, strict: bool = False): - template._packing = True self.template = template self.dataset = dataset self.num_proc = num_proc self.packing_interval = packing_interval - self.strict = strict - assert num_proc >= 1, f'num_proc: 
{num_proc}' - self.workers = [] - - @staticmethod - def calculate_matched_group(template, sequences, is_finished: bool = True): - if len(sequences) == 0: - return [], [] - # https://arxiv.org/pdf/2404.10830 - import binpacking - sequences = binpacking.to_constant_volume(sequences, template.max_length, weight_pos=1) - res = [] - if sequences and not is_finished: - sequences, ret_sequences = sequences[:-1], sequences[-1] - else: - ret_sequences = [] - for row in sequences: - packed = template.packing_row(row) - res.append(packed) - return res, ret_sequences - - def _encode_data(self, data) -> Dict[str, Any]: - res = None - try: - res = self.template.encode(data) - except Exception as e: - if self.strict and not isinstance(e, MaxLengthError): - raise - return res or {} - - -class PackingDataset(BasePackingDataset, Dataset): - - def __init__(self, template, dataset, num_proc: int = 1, *, packing_interval: int = 128, strict: bool = False): - super().__init__(template, dataset, num_proc, packing_interval=packing_interval, strict=strict) - self.prog_bar = tqdm(total=len(dataset), dynamic_ncols=True, desc=f'Packing (num_proc={num_proc}):') - self._queue = mp.Queue() - self._terminated_workers = 0 - for i in range(self.num_proc): - shard_dataset = self.dataset.shard(self.num_proc, i) - worker = mp.Process(target=self._producer, args=(shard_dataset, ), daemon=True) - worker.start() - self.workers.append(worker) - - self.packed_dataset = self.get_packed_dataset() - self.prog_bar.close() - for worker in self.workers: - worker.terminate() - - def fetch_packing_data(self, res: Optional[list] = None): - res = res or [] - for _ in range(self.packing_interval): - data = self._queue.get() - if data is None: - self._terminated_workers += 1 - if self._terminated_workers == self.num_proc: - break - continue - self.prog_bar.update(1) - if data: - res.append((data, len(data['input_ids']))) - return res + dataset = dataset.to_iterable_dataset(num_shards=num_proc) + dataset = EncodePreprocessor(template)(dataset, num_proc=num_proc, strict=strict) + self.packed_dataset = self.get_packed_dataset(dataset) - def get_packed_dataset(self): - data = [] + def get_packed_dataset(self, dataset): + data_list = [] result = [] - while True: - data = self.fetch_packing_data(data) - is_finished = self._terminated_workers == self.num_proc - res, data = self.calculate_matched_group(self.template, data, is_finished=is_finished) + it = iter(dataset) + is_finished = False + prog_bar = tqdm(total=len(dataset), dynamic_ncols=True, desc=f'Packing (num_proc={num_proc}):') + + while not is_finished: + try: + for _ in range(self.packing_interval): + data = next(it) + prog_bar.update(1) + data_list.append((data, len(data['input_ids']))) + except StopIteration: + is_finished = True + res, data = calculate_matched_group(self.template, data_list, is_finished=is_finished) result += res if is_finished: break + prog_bar.close() return result - def _producer(self, shard_dataset): - for data in shard_dataset: - encoded_data = self._encode_data(data) # ignore - self._queue.put(encoded_data) - self._queue.put(None) - while True: - # Wait for the main process to terminate to avoid fd anomalies. 
- time.sleep(0.1) - def __getitem__(self, index): return self.packed_dataset[index].copy() @@ -208,7 +166,7 @@ def __len__(self): return len(self.packed_dataset) -class IterablePackingDataset(BasePackingDataset, IterableDataset): +class IterablePackingDataset(IterableDataset): def __init__(self, template, @@ -218,16 +176,30 @@ def __init__(self, packing_interval: int = 128, strict: bool = False, cyclic: bool = False): - super().__init__(template, dataset, num_proc, packing_interval=packing_interval, strict=strict) + self.template = template + self.dataset = dataset + self.num_proc = num_proc + self.packing_interval = packing_interval + self.strict = strict + self.cyclic = cyclic + self._in_queue = mp.Queue() self._out_queue = mp.Queue() self.workers = [] - self.cyclic = cyclic for _ in range(self.num_proc): worker = mp.Process(target=self._processor, daemon=True) worker.start() self.workers.append(worker) + def _encode_data(self, data) -> Dict[str, Any]: + res = None + try: + res = self.template.encode(data) + except Exception as e: + if self.strict and not isinstance(e, MaxLengthError): + raise + return res or {} + def _processor(self): while True: data = self._in_queue.get() @@ -276,7 +248,7 @@ def __iter__(self): while True: finished = self._put_data_in_queue(iterator) data = self._fetch_data_out_queue(data) - res, data = self.calculate_matched_group(self.template, data, is_finished=finished) + res, data = calculate_matched_group(self.template, data, is_finished=finished) yield from res if finished: break diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index 3b2a5d922d..e76eab9b57 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -251,11 +251,7 @@ def _encode_dataset(self, train_dataset, val_dataset): packing_dataset_cls = IterablePackingDataset if args.streaming else PackingDataset dataset_kwargs = {'cyclic': True} if args.streaming else {} train_dataset = packing_dataset_cls( - self.template, - train_dataset, - num_proc=args.dataset_num_proc, - strict=args.strict, - **dataset_kwargs) + self.template, train_dataset, num_proc=args.dataset_num_proc, strict=args.strict, **dataset_kwargs) if val_dataset is not None: val_dataset = packing_dataset_cls( self.template, val_dataset, num_proc=args.dataset_num_proc, strict=args.strict) diff --git a/swift/utils/torch_utils.py b/swift/utils/torch_utils.py index a6c66e1adb..9b7ab64f22 100644 --- a/swift/utils/torch_utils.py +++ b/swift/utils/torch_utils.py @@ -230,7 +230,7 @@ def _cond(name, module): @contextmanager -def safe_ddp_context(hash_id: str, use_barrier: bool = False): +def safe_ddp_context(hash_id: Optional[str], use_barrier: bool = False): if use_barrier and dist.is_initialized(): if (is_dist() or is_dist_ta()) and not is_master(): dist.barrier() @@ -238,6 +238,7 @@ def safe_ddp_context(hash_id: str, use_barrier: bool = False): if (is_dist() or is_dist_ta()) and is_master(): dist.barrier() else: + assert hash_id is not None lock_dir = os.path.join(get_cache_dir(), 'lockers') os.makedirs(lock_dir, exist_ok=True) file_path = hashlib.sha256(hash_id.encode('utf-8')).hexdigest() + '.lock' From 6d653e369da6c70345054f785ddfc9fadcfbb56e Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 4 May 2025 17:20:57 +0800 Subject: [PATCH 5/6] update --- swift/llm/train/sft.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index 3b2a5d922d..e76eab9b57 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -251,11 +251,7 @@ def 
_encode_dataset(self, train_dataset, val_dataset): packing_dataset_cls = IterablePackingDataset if args.streaming else PackingDataset dataset_kwargs = {'cyclic': True} if args.streaming else {} train_dataset = packing_dataset_cls( - self.template, - train_dataset, - num_proc=args.dataset_num_proc, - strict=args.strict, - **dataset_kwargs) + self.template, train_dataset, num_proc=args.dataset_num_proc, strict=args.strict, **dataset_kwargs) if val_dataset is not None: val_dataset = packing_dataset_cls( self.template, val_dataset, num_proc=args.dataset_num_proc, strict=args.strict) From 393bafb05a615042cd60021e2b7fb016f0205882 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 6 May 2025 11:49:10 +0800 Subject: [PATCH 6/6] update --- requirements/install_all.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/install_all.sh b/requirements/install_all.sh index 07501035a4..3bcce3ffc8 100644 --- a/requirements/install_all.sh +++ b/requirements/install_all.sh @@ -4,7 +4,7 @@ pip install "vllm>=0.5.1" -U pip install "lmdeploy>=0.5" -U --no-deps pip install autoawq -U --no-deps pip install auto_gptq optimum bitsandbytes -U -pip install git+https://github.com/modelscope/ms-swift.git#egg=ms-swift[all] +pip install git+https://github.com/modelscope/ms-swift.git pip install timm -U pip install deepspeed -U pip install qwen_vl_utils qwen_omni_utils decord librosa pyav icecream soundfile -U
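Taken together with the sft.py hunks above, the packing entry point is unchanged apart from the num_workers -> num_proc rename. A minimal sketch of how the reworked classes are wired up, assuming the PackingDataset/IterablePackingDataset exports shown in swift/llm/dataset/__init__.py and that the template and argument values come from the surrounding SwiftSft context rather than from this patch:

    from swift.llm.dataset import IterablePackingDataset, PackingDataset


    def build_packed_dataset(template, train_dataset, dataset_num_proc: int, streaming: bool, strict: bool):
        # Mirrors swift/llm/train/sft.py after this series; `cyclic` only applies to the iterable variant.
        packing_dataset_cls = IterablePackingDataset if streaming else PackingDataset
        dataset_kwargs = {'cyclic': True} if streaming else {}
        return packing_dataset_cls(
            template,
            train_dataset,
            num_proc=dataset_num_proc,  # renamed from num_workers in this series
            strict=strict,
            **dataset_kwargs)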