From e9a582c52c0391319368dda60b002ff2a259c7bb Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 1 Feb 2024 09:35:02 -0800
Subject: [PATCH] save an import for researcher

---
 README.md                                               | 9 +++++----
 self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py  | 8 +++++++-
 setup.py                                                | 2 +-
 3 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 94db031..eb233be 100644
--- a/README.md
+++ b/README.md
@@ -26,8 +26,7 @@ from torch import Tensor
 
 from self_rewarding_lm_pytorch import (
     SelfRewardingTrainer,
-    create_mock_dataset,
-    create_default_paper_config
+    create_mock_dataset
 )
 
 from x_transformers import TransformerWrapper, Decoder
@@ -54,7 +53,7 @@ def encode_str(seq_str: str) -> Tensor:
 
 trainer = SelfRewardingTrainer(
     transformer,
-    finetune_configs = create_default_paper_config(
+    finetune_configs = dict(
         train_sft_dataset = sft_dataset,
         self_reward_prompt_dataset = prompt_dataset,
         dpo_num_train_steps = 1000
@@ -157,6 +156,7 @@ from self_rewarding_lm_pytorch import (
 )
 
 trainer = SelfRewardingTrainer(
+    model,
     finetune_configs = [
         SFTConfig(...),
         SelfPlayConfig(...),
@@ -164,7 +164,8 @@ trainer = SelfRewardingTrainer(
         SelfRewardDPOConfig(...),
         SelfPlayConfig(...),
         SelfRewardDPOConfig(...)
-    ]
+    ],
+    ...
 )
 
 trainer()
diff --git a/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py b/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py
index 87637f0..540f4c6 100644
--- a/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py
+++ b/self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py
@@ -674,6 +674,7 @@ class SelfPlayConfig(FinetuneConfig):
 
 # generated default config for paper
 
+@beartype
 def create_default_paper_config(
     *,
     train_sft_dataset: Dataset,
@@ -720,7 +721,7 @@ def __init__(
         self,
         model: Module,
         *,
-        finetune_configs: List[FinetuneConfig],
+        finetune_configs: Union[Dict, List[FinetuneConfig]],
         tokenizer_encode: Callable[[str], TensorType['seq', int]],
         tokenizer_decode: Callable[[TensorType['seq', int]], str],
         self_reward_prompt_config: Union[RewardConfig, Dict[str, RewardConfig]] = SELF_REWARD_PROMPT_CONFIG,
@@ -737,6 +738,11 @@ def __init__(
         if isinstance(self_reward_prompt_config, RewardConfig):
             self_reward_prompt_config = dict(default = self_reward_prompt_config)
 
+        # finetune config
+
+        if isinstance(finetune_configs, dict):
+            finetune_configs = create_default_paper_config(**finetune_configs)
+
         # model and accelerator
 
         self.model = model
diff --git a/setup.py b/setup.py
index 0b6d25e..92c4bb5 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'self-rewarding-lm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.0',
+  version = '0.2.1',
   license='MIT',
   description = 'Self Rewarding LM - Pytorch',
   author = 'Phil Wang',
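
Note: a minimal usage sketch of the two `finetune_configs` forms accepted after this patch. The `model`, dataset, and tokenizer objects are placeholders assumed from the repository's README (not introduced by this patch), and the import locations follow the README's import block.

    from self_rewarding_lm_pytorch import (
        SelfRewardingTrainer,
        SFTConfig,
        SelfRewardDPOConfig,
        SelfPlayConfig
    )

    # form 1: a plain dict, forwarded internally to create_default_paper_config(**finetune_configs),
    # so the researcher no longer needs to import create_default_paper_config themselves
    trainer = SelfRewardingTrainer(
        model,
        finetune_configs = dict(
            train_sft_dataset = sft_dataset,
            self_reward_prompt_dataset = prompt_dataset,
            dpo_num_train_steps = 1000
        ),
        tokenizer_encode = encode_str,
        tokenizer_decode = decode_tokens
    )

    # form 2: an explicit list of FinetuneConfig stages, as before
    # (the `...` are placeholders for per-stage arguments, mirroring the README)
    trainer = SelfRewardingTrainer(
        model,
        finetune_configs = [
            SFTConfig(...),
            SelfPlayConfig(...),
            SelfRewardDPOConfig(...)
        ],
        tokenizer_encode = encode_str,
        tokenizer_decode = decode_tokens
    )

    trainer()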