From 01d3ac67e604447345b01dfb2bc318200a135d8e Mon Sep 17 00:00:00 2001 From: noiji Date: Thu, 22 Aug 2024 00:23:35 +0900 Subject: [PATCH] Load huggingface data with revision --- src/llamafactory/data/loader.py | 1 + src/llamafactory/data/parser.py | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 069ea1997e..0c3480a683 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -111,6 +111,7 @@ def _load_single_dataset( token=model_args.hf_hub_token, streaming=(data_args.streaming and (dataset_attr.load_from != "file")), trust_remote_code=True, + revision=dataset_attr.revision, ) if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py index 2dccfc5d21..d761f1cf24 100644 --- a/src/llamafactory/data/parser.py +++ b/src/llamafactory/data/parser.py @@ -39,6 +39,7 @@ class DatasetAttr: split: str = "train" folder: Optional[str] = None num_samples: Optional[int] = None + revision: Optional[str] = None # common columns system: Optional[str] = None tools: Optional[str] = None @@ -112,7 +113,11 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) - if (use_modelscope() and has_ms_url) or (not has_hf_url): dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"]) else: - dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) + dataset_attr = DatasetAttr( + "hf_hub", + dataset_name=dataset_info[name]["hf_hub_url"], + revision=dataset_info[name].get("revision") + ) elif "script_url" in dataset_info[name]: dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) else: