diff --git a/src/llmcompressor/transformers/finetune/data/custom.py b/src/llmcompressor/transformers/finetune/data/custom.py index e849594e7..7b74f30ac 100644 --- a/src/llmcompressor/transformers/finetune/data/custom.py +++ b/src/llmcompressor/transformers/finetune/data/custom.py @@ -44,56 +44,73 @@ def __init__(self, data_args, split, tokenizer): split=split, tokenizer=tokenizer, ) - self.preprocessing_func = data_args.preprocessing_func - self.remove_columns = data_args.remove_columns def get_raw_dataset(self, *_ignore, **__ignore) -> Union[DatasetDict, Dataset]: """Get the raw dataset and apply preprocessing func if provided""" - dataset = self.data_args.dataset - if isinstance(dataset, DatasetDict) or isinstance(dataset, Dataset): - # user passed in an already instantiated dataset, just use it directly - raw_dataset = dataset - else: - # dataset must be loaded from file or HF Hub - raw_dataset = super().get_raw_dataset() - - if self.preprocessing_func is not None: - if callable(self.preprocessing_func): - func = self.preprocessing_func - elif ":" in self.preprocessing_func: + # load dataset + dataset = ( + self.data_args.dataset + if isinstance(self.data_args.dataset, (DatasetDict, Dataset)) + else super().get_raw_dataset() # load dataset from file or HF Hub + ) + + # preprocess dataset + dataset = self._preprocess_dataset(dataset) + dataset = self._remove_columns_from_dataset(dataset) + + return dataset + + def _preprocess_dataset( + self, dataset: Union[DatasetDict, Dataset] + ) -> Union[DatasetDict, Dataset]: + preprocessing_func = self.data_args.preprocessing_func + + if preprocessing_func is not None: + if callable(preprocessing_func): + pass + + elif ":" in preprocessing_func: # load func_name from "/path/to/file.py:func_name" - func = import_from_path(self.preprocessing_func) + preprocessing_func = import_from_path(preprocessing_func) else: # load from the registry - func = PreprocessingFunctionRegistry.get_value_from_registry( - name=self.preprocessing_func + preprocessing_func = ( + PreprocessingFunctionRegistry.get_value_from_registry( + name=preprocessing_func + ) ) - raw_dataset = self.map( - raw_dataset, - function=func, + dataset = self.map( + dataset, + function=preprocessing_func, batched=False, num_proc=self.data_args.preprocessing_num_workers, desc="Applying custom func to the custom dataset", ) - self.remove_columns = ( - self.remove_columns or self.get_remove_columns_from_dataset(raw_dataset) - ) + return dataset + + def _remove_columns_from_dataset( + self, dataset: Union[DatasetDict, Dataset] + ) -> Union[DatasetDict, Dataset]: + remove_columns = self.data_args.remove_columns + + if not remove_columns: + remove_columns = self._get_remove_columns_from_dataset(dataset) - if self.remove_columns is not None: - raw_dataset = self.map( - raw_dataset, + if remove_columns is not None: + dataset = self.map( + dataset, batched=True, - remove_columns=self.remove_columns, + remove_columns=remove_columns, num_proc=self.data_args.preprocessing_num_workers, desc="Removing unneeded columns", ) - return raw_dataset + return dataset - def get_remove_columns_from_dataset( + def _get_remove_columns_from_dataset( self, raw_dataset: Union[DatasetDict, Dataset] ) -> List[str]: """Remove redandant columns from the dataset for processing"""