Skip to content

Commit

Permalink
clean up CustomDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
kylesayrs committed Nov 28, 2024
1 parent a47137d commit 26769ac
Showing 1 changed file with 46 additions and 29 deletions.
75 changes: 46 additions & 29 deletions src/llmcompressor/transformers/finetune/data/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,56 +44,73 @@ def __init__(self, data_args, split, tokenizer):
split=split,
tokenizer=tokenizer,
)
self.preprocessing_func = data_args.preprocessing_func
self.remove_columns = data_args.remove_columns

def get_raw_dataset(self, *_ignore, **__ignore) -> Union[DatasetDict, Dataset]:
"""Get the raw dataset and apply preprocessing func if provided"""

dataset = self.data_args.dataset
if isinstance(dataset, DatasetDict) or isinstance(dataset, Dataset):
# user passed in an already instantiated dataset, just use it directly
raw_dataset = dataset
else:
# dataset must be loaded from file or HF Hub
raw_dataset = super().get_raw_dataset()

if self.preprocessing_func is not None:
if callable(self.preprocessing_func):
func = self.preprocessing_func
elif ":" in self.preprocessing_func:
# load dataset
dataset = (
self.data_args.dataset
if isinstance(self.data_args.dataset, (DatasetDict, Dataset))
else super().get_raw_dataset() # load dataset from file or HF Hub
)

# preprocess dataset
dataset = self._preprocess_dataset(dataset)
dataset = self._remove_columns_from_dataset(dataset)

return dataset

def _preprocess_dataset(
self, dataset: Union[DatasetDict, Dataset]
) -> Union[DatasetDict, Dataset]:
preprocessing_func = self.data_args.preprocessing_func

if preprocessing_func is not None:
if callable(preprocessing_func):
pass

elif ":" in preprocessing_func:
# load func_name from "/path/to/file.py:func_name"
func = import_from_path(self.preprocessing_func)
preprocessing_func = import_from_path(preprocessing_func)
else:
# load from the registry
func = PreprocessingFunctionRegistry.get_value_from_registry(
name=self.preprocessing_func
preprocessing_func = (
PreprocessingFunctionRegistry.get_value_from_registry(
name=preprocessing_func
)
)

raw_dataset = self.map(
raw_dataset,
function=func,
dataset = self.map(
dataset,
function=preprocessing_func,
batched=False,
num_proc=self.data_args.preprocessing_num_workers,
desc="Applying custom func to the custom dataset",
)

self.remove_columns = (
self.remove_columns or self.get_remove_columns_from_dataset(raw_dataset)
)
return dataset

def _remove_columns_from_dataset(
self, dataset: Union[DatasetDict, Dataset]
) -> Union[DatasetDict, Dataset]:
remove_columns = self.data_args.remove_columns

if not remove_columns:
remove_columns = self._get_remove_columns_from_dataset(dataset)

if self.remove_columns is not None:
raw_dataset = self.map(
raw_dataset,
if remove_columns is not None:
dataset = self.map(
dataset,
batched=True,
remove_columns=self.remove_columns,
remove_columns=remove_columns,
num_proc=self.data_args.preprocessing_num_workers,
desc="Removing unneeded columns",
)

return raw_dataset
return dataset

def get_remove_columns_from_dataset(
def _get_remove_columns_from_dataset(
self, raw_dataset: Union[DatasetDict, Dataset]
) -> List[str]:
"""Remove redandant columns from the dataset for processing"""
Expand Down

0 comments on commit 26769ac

Please sign in to comment.