From 5b3c4d61d60265130d1090c8bd40bd215f8ddd4d Mon Sep 17 00:00:00 2001 From: ex-yanminmin001 Date: Mon, 25 Nov 2024 03:02:44 +0000 Subject: [PATCH] add custom dataset config file as input --- src/llamafactory/data/loader.py | 3 ++- src/llamafactory/data/parser.py | 6 ++++-- src/llamafactory/hparams/data_args.py | 4 ++++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 540dff1c4a..f227d57f16 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -164,7 +164,8 @@ def _get_merged_dataset( return None datasets = [] - for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir): + # for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir): + for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir, data_args.custom_config): if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True): raise ValueError("The dataset is not applicable in the current training stage.") diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py index 709d0c900c..2826c4d4d8 100644 --- a/src/llamafactory/data/parser.py +++ b/src/llamafactory/data/parser.py @@ -71,7 +71,8 @@ def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) setattr(self, key, obj.get(key, default)) -def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]: +# def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]: +def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str, custom_config: str) -> List["DatasetAttr"]: r""" Gets the attributes of the datasets. 
""" @@ -84,7 +85,8 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) - if dataset_dir.startswith("REMOTE:"): config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset") else: - config_path = os.path.join(dataset_dir, DATA_CONFIG) + # config_path = os.path.join(dataset_dir, DATA_CONFIG) + config_path = os.path.join(dataset_dir, DATA_CONFIG if custom_config is None else custom_config) try: with open(config_path) as f: diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 2d7e30df96..11a55f0742 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -40,6 +40,10 @@ class DataArguments: dataset_dir: str = field( default="data", metadata={"help": "Path to the folder containing the datasets."}, + ) + custom_config: Optional[str] = field( + default=None, + metadata={"help": "The path of the custom dataset config to use for training."}, ) image_dir: Optional[str] = field( default=None,