Commit

save an import for researcher
lucidrains committed Feb 1, 2024
1 parent 7c4ba1a commit e9a582c
Showing 3 changed files with 13 additions and 6 deletions.
README.md (9 changes: 5 additions & 4 deletions)
@@ -26,8 +26,7 @@ from torch import Tensor
 
 from self_rewarding_lm_pytorch import (
     SelfRewardingTrainer,
-    create_mock_dataset,
-    create_default_paper_config
+    create_mock_dataset
 )
 
 from x_transformers import TransformerWrapper, Decoder
@@ -54,7 +53,7 @@ def encode_str(seq_str: str) -> Tensor:
 
 trainer = SelfRewardingTrainer(
     transformer,
-    finetune_configs = create_default_paper_config(
+    finetune_configs = dict(
         train_sft_dataset = sft_dataset,
         self_reward_prompt_dataset = prompt_dataset,
         dpo_num_train_steps = 1000
@@ -157,14 +156,16 @@ from self_rewarding_lm_pytorch import (
 )
 
 trainer = SelfRewardingTrainer(
+    model,
     finetune_configs = [
         SFTConfig(...),
         SelfPlayConfig(...),
         ExternalRewardDPOConfig(...),
         SelfRewardDPOConfig(...),
         SelfPlayConfig(...),
         SelfRewardDPOConfig(...)
-    ]
+    ],
+    ...
 )
 
 trainer()
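Note: the dict form in the README hunk above is shorthand. With this commit the trainer forwards a plain dict to create_default_paper_config internally (see the Python diff below), so it should be equivalent to the old README's explicit call and the extra import is no longer needed. A before/after sketch, reusing sft_dataset and prompt_dataset from the README example:

# before this commit: explicit import and call
from self_rewarding_lm_pytorch import create_default_paper_config

finetune_configs = create_default_paper_config(
    train_sft_dataset = sft_dataset,
    self_reward_prompt_dataset = prompt_dataset,
    dpo_num_train_steps = 1000
)

# after this commit: pass the same keyword arguments as a plain dict,
# and SelfRewardingTrainer builds the default paper config itself
finetune_configs = dict(
    train_sft_dataset = sft_dataset,
    self_reward_prompt_dataset = prompt_dataset,
    dpo_num_train_steps = 1000
)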
self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py (8 changes: 7 additions & 1 deletion)
@@ -674,6 +674,7 @@ class SelfPlayConfig(FinetuneConfig):
 
 # generated default config for paper
 
+@beartype
 def create_default_paper_config(
     *,
     train_sft_dataset: Dataset,
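Note: beartype validates a function's type annotations at call time, so a mistyped argument to the config factory fails immediately instead of deep inside training. A minimal illustration of the decorator's effect, with a toy function rather than the library's:

from beartype import beartype

@beartype
def scale(x: int, factor: int = 2) -> int:
    # annotations are enforced on every call
    return x * factor

scale(3)       # returns 6
# scale('3')   # would raise a beartype call-time type violation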
@@ -720,7 +721,7 @@ def __init__(
         self,
         model: Module,
         *,
-        finetune_configs: List[FinetuneConfig],
+        finetune_configs: Union[Dict, List[FinetuneConfig]],
         tokenizer_encode: Callable[[str], TensorType['seq', int]],
         tokenizer_decode: Callable[[TensorType['seq', int]], str],
         self_reward_prompt_config: Union[RewardConfig, Dict[str, RewardConfig]] = SELF_REWARD_PROMPT_CONFIG,
@@ -737,6 +738,11 @@ def __init__(
         if isinstance(self_reward_prompt_config, RewardConfig):
             self_reward_prompt_config = dict(default = self_reward_prompt_config)
 
+        # finetune config
+
+        if isinstance(finetune_configs, dict):
+            finetune_configs = create_default_paper_config(**finetune_configs)
+
         # model and accelerator
 
         self.model = model
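Note: this coercion is what lets the README drop the import. A minimal standalone sketch of the same constructor pattern, using illustrative stand-in names rather than the library's:

from typing import Dict, List, Union

class FinetuneConfig:
    # stand-in for the library's finetune config class
    def __init__(self, **kwargs):
        self.kwargs = kwargs

def create_default_paper_config(**kwargs) -> List[FinetuneConfig]:
    # stand-in factory: expands keyword arguments into a default schedule
    return [FinetuneConfig(stage = 'sft', **kwargs), FinetuneConfig(stage = 'dpo', **kwargs)]

class Trainer:
    def __init__(self, finetune_configs: Union[Dict, List[FinetuneConfig]]):
        if isinstance(finetune_configs, dict):
            # a plain dict is treated as kwargs for the default factory
            finetune_configs = create_default_paper_config(**finetune_configs)
        self.finetune_configs = finetune_configs

# both call styles are now accepted
Trainer(finetune_configs = dict(num_train_steps = 1000))
Trainer(finetune_configs = [FinetuneConfig(stage = 'sft')])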
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'self-rewarding-lm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.0',
+  version = '0.2.1',
   license='MIT',
   description = 'Self Rewarding LM - Pytorch',
   author = 'Phil Wang',
