Commit b4fa78b

wip
shuishen112 committed Jul 16, 2024
1 parent 1f97122 · commit b4fa78b
Showing 1 changed file with 5 additions and 4 deletions.
projects/modular_llm/train_dpo.py: 9 changes (5 additions & 4 deletions)
@@ -99,14 +99,13 @@ def create_library(args):
# dm = get_datamodule(args)
# args.n_tasks = len(dm._task_names)
# args.task_names = dm._task_names
# if args.router_selector == "arrow_router":
args.trainable_param_names = None
# args.trainable_param_names = ".*prototypes.*"
ref_model = model_class(
**vars(args), tokenizer=dm.tokenizer, expert_library=expert_library
)

if args.rl_training == "dpo":
args.trainable_param_names = ".*prototypes.*"
# args.trainable_param_names = ".*prototypes.*"
model = model_class(
**vars(args), tokenizer=dm.tokenizer, expert_library=expert_library
)
@@ -117,7 +116,9 @@ def create_library(args):
# # ref_model = copy.deepcopy(model)
# ref_model.add_experts_from_library(expert_library)
# patch_prototypes(ref_model, expert_library, args)
module = ExpertModelDPO(model, ref_model, **vars(args))
module = ExpertModelDPO(
**vars(args), expert_model=model, ref_expert_model=ref_model
)

# get metric monitors for models
callbacks = get_monitors(args)
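
Note on the diff above: the ExpertModelDPO module is now constructed with explicit keyword arguments (expert_model, ref_expert_model) alongside **vars(args), rather than positionally. The sketch below is a minimal, hypothetical stand-in (not the repository's actual ExpertModelDPO signature) illustrating why naming the two models keeps them separate from whatever the argument namespace splats in.

# Hypothetical stand-in for illustration only; the real ExpertModelDPO
# in this repository may accept different parameters.
class ExpertModelDPOSketch:
    def __init__(self, expert_model=None, ref_expert_model=None, **kwargs):
        self.model = expert_model          # trainable policy model
        self.ref_model = ref_expert_model  # frozen reference model used by DPO
        self.hparams = kwargs              # everything forwarded via **vars(args)

# Mirrors the new call site:
# module = ExpertModelDPOSketch(**vars(args), expert_model=model, ref_expert_model=ref_model)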