diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index 1c3a7747bf..4924f98ef4 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -239,15 +239,19 @@ def get_dataset(
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
             logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
-            dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
+            tokenized_data: Union["Dataset", "DatasetDict"] = load_from_disk(data_args.tokenized_path)
             logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
 
             dataset_module: Dict[str, "Dataset"] = {}
-            if "train" in dataset_dict:
-                dataset_module["train_dataset"] = dataset_dict["train"]
+            if isinstance(tokenized_data, DatasetDict):
+                if "train" in tokenized_data:
+                    dataset_module["train_dataset"] = tokenized_data["train"]
 
-            if "validation" in dataset_dict:
-                dataset_module["eval_dataset"] = dataset_dict["validation"]
+                if "validation" in tokenized_data:
+                    dataset_module["eval_dataset"] = tokenized_data["validation"]
+
+            else:  # Dataset
+                dataset_module["train_dataset"] = tokenized_data
 
             if data_args.streaming:
                 dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
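
Context for the change (not part of the patch): `datasets.load_from_disk` returns a `Dataset` when a bare dataset was saved and a `DatasetDict` when a dict of splits was saved, which is why the loader now branches on `isinstance`. A minimal sketch of that behavior, with hypothetical temp paths:

```python
from datasets import Dataset, DatasetDict, load_from_disk

# A bare Dataset round-trips as a Dataset.
Dataset.from_dict({"input_ids": [[1, 2], [3, 4]]}).save_to_disk("/tmp/plain_dataset")
assert isinstance(load_from_disk("/tmp/plain_dataset"), Dataset)

# A DatasetDict (e.g. "train"/"validation" splits) round-trips as a DatasetDict.
DatasetDict({"train": Dataset.from_dict({"input_ids": [[1, 2]]})}).save_to_disk("/tmp/dataset_dict")
assert isinstance(load_from_disk("/tmp/dataset_dict"), DatasetDict)
```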