use adafactor with accelerate
AlexPiche committed Nov 5, 2024
1 parent 5d87a2a commit 4e79c5f
Showing 5 changed files with 19 additions and 30 deletions.
8 changes: 4 additions & 4 deletions conf/finetune/rl_llama31_8b.yaml
@@ -17,14 +17,14 @@ wandb_resume: always
 # Whether to use only the basename or the full path as the run name
 wandb_use_basename: false
 config_name: meta-llama/Meta-Llama-3.1-8B-Instruct
-learning_rate: 0.000005
-train_batch_size: 1
-gradient_accumulation_passes: 1024
+learning_rate: 0.0000025
+train_batch_size: 2
+gradient_accumulation_passes: 512
 seq_length: 4096
 load_as_bf16: True
 max_train_steps: 100000
 save_checkpoint_steps: ???
-optim: adamw_torch
+optim: adafactor
 objective: rl
 log_each_n_steps: 1
 resume_dataloader: false
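The optimizer switches from AdamW to Adafactor while the per-step budget stays the same (train_batch_size × gradient_accumulation_passes is 1024 sequences both before and after); Adafactor keeps factored second-moment statistics, so its optimizer state is far smaller than AdamW's, which is the usual reason for the swap on 8B-scale models. A minimal sketch of how an `optim: adafactor` setting with a fixed learning rate is typically wired up under Accelerate — the model name comes from the config above, but the rest (loop, scheduler) is assumed and may differ from this repo's finetuning code:

# Hedged sketch: fixed-LR Adafactor prepared with Accelerate (bf16, gradient
# accumulation as in the config). Not this repo's actual training loop.
from accelerate import Accelerator
from transformers import AutoModelForCausalLM
from transformers.optimization import Adafactor

accelerator = Accelerator(gradient_accumulation_steps=512, mixed_precision="bf16")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

# With an explicit lr, Adafactor's relative_step/warmup_init schedules must be off;
# scale_parameter=False keeps the update magnitude tied to the configured 2.5e-6.
optimizer = Adafactor(
    model.parameters(),
    lr=2.5e-6,
    scale_parameter=False,
    relative_step=False,
    warmup_init=False,
)
model, optimizer = accelerator.prepare(model, optimizer)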
2 changes: 2 additions & 0 deletions conf/rl_debug.yaml
@@ -1,5 +1,7 @@
defaults:
  - rl_gsm8k
  - _self_

max_agent_forks: 16
attempts: 1
test_every_n_iterations: -1
5 changes: 2 additions & 3 deletions conf/rl_gsm8k.yaml
@@ -11,7 +11,7 @@ attempts: 64
 force_restart: false
 max_iterations: 100
 discount: 0.99
-implicit_kl: false
+implicit_kl: 0.0
 max_steps: 100
 llm:
   parameters:
@@ -39,8 +39,7 @@ vllm_config:
   --enable-chunked-prefill: ""

 output_dir: outputs/rl_gsm8k
-accelerate_cfg_path: conf/deepspeed/accelerate_local.yaml
-deepspeed_cfg_path: conf/deepspeed/deepspeed_stage3_bf16.json
+accelerate_cfg_path: conf/deepspeed/accelerate_base.yaml

 hydra:
   run:
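Note that implicit_kl changes type from a boolean to a float, i.e. it now reads as a coefficient (here 0.0, disabling the penalty) rather than an on/off flag. As a rough illustration of what an implicit KL coefficient usually does in RL finetuning — this is an assumption about the intent, and the repo's actual objective may be computed differently — the per-token KL estimate against the reference model is subtracted from the reward, scaled by the coefficient:

import torch

def shape_rewards(rewards, policy_logprobs, ref_logprobs, implicit_kl=0.0):
    # Illustrative only: penalize divergence from the reference policy.
    # With implicit_kl=0.0 (the new default) the rewards pass through unchanged.
    kl_estimate = policy_logprobs - ref_logprobs  # per-token log-ratio
    return rewards - implicit_kl * kl_estimate

# toy usage
rewards = torch.zeros(4)
shaped = shape_rewards(rewards, torch.randn(4), torch.randn(4), implicit_kl=0.0)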
8 changes: 5 additions & 3 deletions examples/rl_gsm8k/orchestrate_rl.py
@@ -318,7 +318,9 @@ def main(cfg: DictConfig):

     datasets = [("train", train_agent, train_tapes)]
     if state["iteration"] % cfg.test_every_n_iterations == 0 and cfg.test_every_n_iterations > 0:
-        datasets.append(("test", test_agent, test_tapes))
+        #datasets.append(("test", test_agent, test_tapes))
+        #TODO: for debugging purposes, remove before merging
+        datasets.append(("test", test_agent, train_tapes[:1000]))
     all_results = {}
     with VLLMServiceManager(
         model_name_or_path=assistant_model_path,
@@ -415,7 +417,7 @@ def main(cfg: DictConfig):
         {
             "execution_time/populating_ref_logprobs": time_populating_ref_logprobs,
             "execution_time/starting_assistantmodel_vllm": assistant_vllm_stats["starting_time"],
-            "execution_time/starting_refmodel_vllm": assistant_vllm_stats["starting_time"],
+            "execution_time/starting_refmodel_vllm": refmodel_vllm_stats["starting_time"],
         },
         step=state["iteration"],
     )
@@ -439,7 +441,7 @@ def main(cfg: DictConfig):
     OmegaConf.save(finetune_cfg, config_path)

     start_finetune = time.time()
-    launch_training(str(conf_dir), str(state["iteration"]), cfg.accelerate_cfg_path, cfg.deepspeed_cfg_path)
+    launch_training(str(conf_dir), str(state["iteration"]), cfg.accelerate_cfg_path)
     time_finetune = time.time() - start_finetune
     time_iteration = time.time() - start_iteration
     wandb.log(
26 changes: 6 additions & 20 deletions examples/rl_gsm8k/utils.py
@@ -272,7 +272,7 @@ def calculate_stats(stats):
     }


-def launch_training(config_dir: str, config_name: str, accelerate_cfg_path: str, deepspeed_cfg_path: str):
+def launch_training(config_dir: str, config_name: str, accelerate_cfg_path: str):
     """
     Launch training process with proper GPU configuration and error handling.
@@ -307,27 +307,13 @@ def launch_training(config_dir: str, config_name: str, accelerate_cfg_path: str, deepspeed_cfg_path: str):
     ]

     if num_gpus > 1:
-        #TODO: better handling of multi-gpu training: accelerate or deepspeed
-        if False:
-            base_cmd[2:2] = [
-                "--use_deepspeed",
-                "--num_processes",
-                str(num_gpus),
-                "--deepspeed_config_file",
-                deepspeed_cfg_path,
-            ]
-        else:
-            base_cmd[2:2] = [
-                #"--use_deepspeed",
-                "--multi_gpu",
-                "--num_processes",
-                str(num_gpus),
-                #"--deepspeed_config_file",
-                #deepspeed_cfg_path,
-            ]
+        base_cmd[2:2] = [
+            "--multi_gpu",
+            "--num_processes",
+            str(num_gpus),
+        ]

-    logger.info(f"Launching training with command: {' '.join(base_cmd)}")
+    print(f"Launching training with command: {' '.join(base_cmd)}")
     try:
         result = subprocess.run(
             base_cmd,
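The slice assignment base_cmd[2:2] splices the multi-GPU flags immediately after the "accelerate launch" prefix, ahead of whatever config and script arguments follow. A small self-contained sketch of the effect — the tail of base_cmd here (config flag, script name) is assumed for illustration and not taken from the repo:

# Hypothetical base_cmd; the real one is assembled earlier in launch_training.
accelerate_cfg_path = "conf/deepspeed/accelerate_base.yaml"
num_gpus = 4
base_cmd = ["accelerate", "launch", "--config_file", accelerate_cfg_path, "run_finetune.py"]
if num_gpus > 1:
    # Insert at index 2, i.e. right after ["accelerate", "launch"].
    base_cmd[2:2] = ["--multi_gpu", "--num_processes", str(num_gpus)]
print(" ".join(base_cmd))
# accelerate launch --multi_gpu --num_processes 4 --config_file conf/deepspeed/accelerate_base.yaml run_finetune.py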
