From bded2ccf5783abec2734088781078240316f13db Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 1 Feb 2024 09:37:37 -0800
Subject: [PATCH] sft trainer auto concats multiple datasets

---
 self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py | 4 ++--
 setup.py                                                | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py b/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py
index 540f4c6..136d806 100644
--- a/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py
+++ b/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py
@@ -623,7 +623,7 @@ class FinetuneConfig:
 
 @dataclass
 class SFTConfig(FinetuneConfig):
-    train_dataset: Dataset
+    train_dataset: Union[Dataset, List[Dataset]]
     valid_dataset: Optional[Dataset] = None
     dropout: float = 0.1
     trainer_kwargs: dict = field(default_factory = dict)
@@ -677,7 +677,7 @@ class SelfPlayConfig(FinetuneConfig):
 @beartype
 def create_default_paper_config(
     *,
-    train_sft_dataset: Dataset,
+    train_sft_dataset: Union[Dataset, List[Dataset]],
     self_reward_prompt_dataset: Union[Dataset, Tuple[Dataset, Dataset]],
     valid_sft_dataset: Optional[Dataset] = None,
     num_generated_preference_pairs = (3964, 6942),
diff --git a/setup.py b/setup.py
index 92c4bb5..aec0fcf 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'self-rewarding-lm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.1',
+  version = '0.2.2',
   license='MIT',
   description = 'Self Rewarding LM - Pytorch',
   author = 'Phil Wang',
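
Note (not part of the patch): the type of train_dataset / train_sft_dataset is widened to Union[Dataset, List[Dataset]], and the subject line says the SFT trainer auto-concatenates multiple datasets. Below is a minimal, hypothetical sketch of how such a list could be normalized into a single dataset, assuming torch.utils.data.ConcatDataset is used; the trainer code that actually does this is not shown in this diff, and resolve_train_dataset is an illustrative name only.

from typing import List, Union

from torch.utils.data import ConcatDataset, Dataset

def resolve_train_dataset(train_dataset: Union[Dataset, List[Dataset]]) -> Dataset:
    # hypothetical helper: if a list (or tuple) of datasets is given,
    # concatenate them into one Dataset; otherwise pass the dataset through
    if isinstance(train_dataset, (list, tuple)):
        return ConcatDataset(train_dataset)
    return train_dataset

With this change, SFTConfig(train_dataset = [dataset_a, dataset_b], ...) and SFTConfig(train_dataset = dataset_a, ...) would both be valid inputs.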