src/alignment/data.py (39 changes: 25 additions & 14 deletions)
@@ -13,7 +13,7 @@
 # limitations under the License.

 import logging
-
+import os
 import datasets
 from datasets import DatasetDict, concatenate_datasets

@@ -23,45 +23,57 @@
 logger = logging.getLogger(__name__)


 def get_dataset(args: ScriptArguments) -> DatasetDict:
     """Load a dataset or a mixture of datasets based on the configuration.

     Args:
         args (ScriptArguments): Script arguments containing dataset configuration.

     Returns:
         DatasetDict: The loaded datasets.
     """
     if args.dataset_name and not args.dataset_mixture:
         logger.info(f"Loading dataset: {args.dataset_name}")
-        return datasets.load_dataset(args.dataset_name, args.dataset_config)
+        # Check if it's a local path
+        if os.path.exists(args.dataset_name):
+            logger.info(f"Loading local dataset from disk: {args.dataset_name}")
+            return datasets.load_from_disk(args.dataset_name)
+        else:
+            return datasets.load_dataset(args.dataset_name, args.dataset_config)
     elif args.dataset_mixture:
         logger.info(f"Creating dataset mixture with {len(args.dataset_mixture.datasets)} datasets")
         seed = args.dataset_mixture.seed
         datasets_list = []

         for dataset_config in args.dataset_mixture.datasets:
             logger.info(f"Loading dataset for mixture: {dataset_config.id} (config: {dataset_config.config})")
-            ds = datasets.load_dataset(
-                dataset_config.id,
-                dataset_config.config,
-                split=dataset_config.split,
-            )
+            # Check if it's a local path
+            if os.path.exists(dataset_config.id):
+                logger.info(f"Loading local dataset from disk: {dataset_config.id}")
+                ds = datasets.load_from_disk(dataset_config.id)
+                # Handle split if specified
+                if dataset_config.split and isinstance(ds, DatasetDict):
+                    ds = ds[dataset_config.split]
+            else:
+                ds = datasets.load_dataset(
+                    dataset_config.id,
+                    dataset_config.config,
+                    split=dataset_config.split,
+                )

             if dataset_config.columns is not None:
                 ds = ds.select_columns(dataset_config.columns)
             if dataset_config.weight is not None:
                 ds = ds.shuffle(seed=seed).select(range(int(len(ds) * dataset_config.weight)))
                 logger.info(
                     f"Subsampled dataset '{dataset_config.id}' (config: {dataset_config.config}) with weight={dataset_config.weight} to {len(ds)} examples"
                 )

             datasets_list.append(ds)

         if datasets_list:
             combined_dataset = concatenate_datasets(datasets_list)
             combined_dataset = combined_dataset.shuffle(seed=seed)
             logger.info(f"Created dataset mixture with {len(combined_dataset)} examples")

             if args.dataset_mixture.test_split_size is not None:
                 combined_dataset = combined_dataset.train_test_split(
                     test_size=args.dataset_mixture.test_split_size, seed=seed
@@ -74,6 +86,5 @@ def get_dataset(args: ScriptArguments) -> DatasetDict:
             return DatasetDict({"train": combined_dataset})
         else:
             raise ValueError("No datasets were loaded from the mixture configuration")
-
     else:
         raise ValueError("Either `dataset_name` or `dataset_mixture` must be provided")