MCTS Sampler (#2967)
lxline authored Feb 8, 2025
1 parent e65e5b4 commit 8f0630e
Showing 11 changed files with 803 additions and 17 deletions.
11 changes: 11 additions & 0 deletions docs/source/Instruction/命令行参数.md
@@ -431,6 +431,8 @@ App parameters inherit from [deployment arguments](#部署参数), [Web-UI arguments](#Web-UI参数)

- prm_model: the type of process reward model; either a model id (loaded with the pt engine) or a prm key defined in the plugin (custom inference logic)
- orm_model: the type of outcome reward model, usually a wildcard match or test cases, generally defined in the plugin
- sampler_type: sampling type; currently supports sample and mcts, with dvts planned for the future
- sampler_engine: supports `pt`, `lmdeploy`, `vllm`, `client`, `no`; default is `pt`; the inference engine of the sampling model
- sampler_type: sampling type; currently supports sample (via do_sample), with mcts and dvts planned for the future
- sampler_engine: supports `pt`, `lmdeploy`, `vllm`, `no`; default is `pt`; the inference engine of the sampling model
- output_dir: output directory, default is `sample_output`
@@ -448,6 +450,15 @@
- cache_files: To avoid GPU memory OOM from loading the prm and the generator at the same time, sampling can be done in two steps. In the first step, set prm and orm to `None`, so all results are written to a file. In the second run, set sampler_engine to `no` and pass `--cache_files` with the output file of the first run; the cached results will then be used for prm and orm evaluation, and the final results will be written out.
- Note: When using cache_files, `--dataset` still needs to be passed, because the id in cache_files is the MD5 of the original data and the two pieces of information have to be combined.

#### MCTS
- rollout_depth: the maximum depth during rollout, default is `5`
- rollout_start_depth: the depth at which rollout starts; nodes below this depth only undergo expand operations, default is `3`
- max_iterations: the maximum number of MCTS iterations, default is `100`
- process_reward_rate: the proportion of the process reward when computing value during select, default is `0.0`, i.e. PRM is not used
- exploration_rate: the exploration parameter of the UCT algorithm; larger values favor nodes with fewer visits, default is `0.5`
- api_key: required when using client as the inference engine, default is `EMPTY`
- base_url: required when using client as the inference engine, default is `https://dashscope.aliyuncs.com/compatible-mode/v1`


## Specific model arguments
Specific model arguments can be set via `--model_kwargs` or environment variables, for example: `--model_kwargs '{"fps_max_frames": 12}'` or `FPS_MAX_FRAMES=12`
9 changes: 9 additions & 0 deletions docs/source_en/Instruction/Command-line-parameters.md
@@ -458,6 +458,15 @@ Export Arguments include the [basic arguments](#base-arguments) and [merge argum
- cache_files: To avoid loading both `prm` and `generator` simultaneously and causing GPU memory OOM, sampling can be done in two steps. In the first step, set `prm` and `orm` to `None`, and all results will be output to a file. In the second run, set `sampler_engine` to `no` and pass `--cache_files` with the output file from the first sampling. This will use the results from the first run for `prm` and `orm` evaluation and output the final results.
- Note: When using `cache_files`, the `--dataset` still needs to be provided because the ID for `cache_files` is calculated using the MD5 of the original data. Both pieces of information need to be used together.
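
A minimal sketch of this two-step flow, for illustration only: the model ids and file names below are placeholders, and it assumes `prm_model`/`orm_model` stay at their `None` defaults when the flags are omitted.

```shell
# Step 1: sample without any reward model loaded; generations are written to a file.
swift sample \
    --model Qwen/Qwen2.5-Math-7B-Instruct \
    --sampler_engine vllm \
    --dataset ./datasets/train.jsonl \
    --output_dir sample_output \
    --output_file iter_0.jsonl

# Step 2: disable the generator and score the cached generations with the reward models.
swift sample \
    --model Qwen/Qwen2.5-Math-7B-Instruct \
    --prm_model Qwen/Qwen2.5-Math-PRM-7B \
    --orm_model math \
    --sampler_engine no \
    --dataset ./datasets/train.jsonl \
    --cache_files iter_0.jsonl \
    --output_dir sample_output \
    --output_file iter_0_scored.jsonl
```

The second run never loads the generator; it matches the cached generations to the dataset by MD5 id (hence `--dataset` is still required) and only performs the `prm`/`orm` evaluation.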

#### MCTS
- rollout_depth: The maximum depth during rollouts, default is `5`.
- rollout_start_depth: The depth at which rollouts begin; nodes below this depth will only undergo expand operations, default is `3`.
- max_iterations: The maximum number of iterations for MCTS, default is `100`.
- process_reward_rate: The proportion of process reward used in calculating value during selection, default is `0.0`, meaning PRM is not used.
- exploration_rate: A parameter in the UCT algorithm that balances exploration; a higher value gives more weight to nodes with fewer explorations, default is `0.5`.
- api_key: Required when using the client as an inference engine, default is `EMPTY`.
- base_url: Required when using the client as an inference engine, default is `https://dashscope.aliyuncs.com/compatible-mode/v1`.
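
For intuition, here is the textbook UCB1/UCT form in which these two parameters typically appear; this is only a sketch, and the exact formula used by the sampler may differ:

$$
v(n) = (1-\lambda)\,r_{\text{outcome}}(n) + \lambda\,r_{\text{process}}(n),
\qquad
\mathrm{UCT}(n) = v(n) + c\,\sqrt{\frac{\ln N_{\text{parent}}}{N(n)}}
$$

where $\lambda$ corresponds to `process_reward_rate`, $c$ to `exploration_rate`, and $N(\cdot)$ is a visit count: a larger $c$ inflates the bonus of rarely visited children, matching the description above.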

## Specific Model Arguments

Specific model arguments can be set using `--model_kwargs` or environment variables, for example: `--model_kwargs '{"fps_max_frames": 12}'` or `FPS_MAX_FRAMES=12`.
116 changes: 116 additions & 0 deletions examples/sampler/mcts.py
@@ -0,0 +1,116 @@
import os
import subprocess
import time
from typing import List

import json
from modelscope.msdatasets import MsDataset

conda_prefix = ''


def client_sample(model: str, orm: str, dataset_path: str, iter: int, device_count: int, output_dir: str):
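    """Launch one `swift sample` subprocess per dataset shard and wait for all of them.

    Each shard gets its own cache file and output JSONL under `output_dir`.
    Note that `iter` is the sampling round index (it shadows the Python builtin).
    """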
handlers = []
    # API key for the client engine; one sampling subprocess (with its own cache file) is launched per shard below.
api_key = os.getenv('DASHSCOPE_API_KEY')

for device in range(device_count):

output_file = f'iter_{iter}_proc_{device}.jsonl'
cache_file = f'iter_{iter}_proc_{device}_cache.jsonl'
dataset = f'train_{device:02}.jsonl'

# output_file_path = os.path.join(output_dir, output_file)
cache_file_path = os.path.join(output_dir, cache_file)
single_dataset_path = os.path.join(dataset_path, dataset)

if not os.path.exists(cache_file_path):
open(cache_file_path, 'w').close()
sample_cmd = (f'USE_OPENCOMPASS_EVALUATOR=True '
f'swift sample '
f'--model {model} '
f'--orm_model {orm} '
f'--sampler_type mcts '
f'--process_reward_rate 0 '
f'--stop_words ки '
f'--seed 42 '
f'--api_key {api_key} '
f'--dataset {single_dataset_path} '
f'--max_length 2048 '
f'--system ./scripts/sampler/system_prompt.txt '
f'--load_args false '
f'--sampler_engine client '
f'--max_new_tokens 768 '
f'--override_exist_file true '
f'--num_sampling_per_gpu_batch_size 1 '
f'--num_return_sequences 8 '
f'--exploration_rate 0.2 '
f'--max_iterations 200 '
f'--output_dir {output_dir} '
f'--cache_files {cache_file} '
f'--output_file {output_file} '
f'--temperature 1.0 ')
print(f'Sampling caches of iter {iter}, part {device}.', flush=True)
# env['CUDA_VISIBLE_DEVICES'] = str(device)
handler = subprocess.Popen(
            sample_cmd + f' > mcts_logs/sample_iter_{iter}_proc_{device}_cache.log 2>&1',
env=os.environ.copy(),
shell=True,
executable='/bin/bash')
handlers.append(handler)

datasets = []
for proc, handler in enumerate(handlers):
handler.wait()
assert os.path.exists(os.path.join(output_dir, f'iter_{iter}_proc_{proc}.jsonl'))
        datasets.append(os.path.join(output_dir, f'iter_{iter}_proc_{proc}.jsonl'))  # use output_dir, not a hard-coded path
print(f'Sampling done, files:{datasets}', flush=True)


def split_dataset(ds, split_size, out_path):
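    """Split the dataset into `split_size` JSONL shards under `out_path`.

    Each line is written in chat-messages format
    ({'messages': [user problem, assistant solution]}), so the shards can be passed
    directly to `swift sample --dataset`.
    """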
data_size = int(len(ds) / split_size) + 1

for i in range(split_size):
file_name = f'train_{i:02}.jsonl'
file_path = os.path.join(out_path, file_name)
print(file_path)
ds_split = ds[data_size * i:min(data_size * (i + 1), len(ds))]
print(f"split_size: {len(ds_split['problem'])}")
with open(file_path, 'w', encoding='utf-8') as file:
for problem, solution in zip(ds_split['problem'], ds_split['solution']):
message = {
'messages': [
{
'role': 'user',
'content': problem,
},
{
'role': 'assistant',
'content': solution,
},
]
}
file.write(json.dumps(message, ensure_ascii=False) + '\n')


def main():
server_model = 'qwen-max'
orm = 'math'
device_count = 20
output_dir = 'output/sampler/client_mcts/'
dataset_dir = 'datasets/competition_math/'
log_dir = 'mcts_logs/'

os.makedirs(output_dir, exist_ok=True)
os.makedirs(dataset_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)
ds = MsDataset.load('tastelikefeet/competition_math', subset_name='default', split='train')
split_dataset(ds, device_count, dataset_dir)

ts = time.time()
client_sample(server_model, orm, dataset_dir, 0, device_count, output_dir)
print(f'do sample cost: {(time.time() - ts) / 60:.1f} minutes.', flush=True)


if __name__ == '__main__':
main()
35 changes: 35 additions & 0 deletions examples/sampler/mcts.sh
@@ -0,0 +1,35 @@
export CUDA_VISIBLE_DEVICES=0
export USE_OPENCOMPASS_EVALUATOR=True

swift sample \
--model ./output/Qwen2.5-Math-7B-Instruct/v40-20250126-161112/checkpoint-20 \
--orm_model math \
--sampler_type mcts \
--sampler_engine vllm \
--output_dir ./output/sampler/mcts \
--system ./examples/sampler/system_prompt.txt \
--stop_words ки \
--dataset ./datasets/competition_math/small_test.jsonl \
--num_return_sequences 2 \
--process_reward_rate 0 \
--max_new_tokens 2048

## Train
# nproc_per_node=8
# NPROC_PER_NODE=$nproc_per_node \
# swift sft \
# --model Qwen/Qwen2.5-Math-7B-Instruct \
# --train_type full \
# --torch_dtype bfloat16 \
# --dataset 'datasets/gen_V5.jsonl' \
# --num_train_epochs 1 \
# --per_device_train_batch_size 1 \
# --learning_rate 1e-5 \
# --gradient_accumulation_steps $(expr 128 / $nproc_per_node) \
# --eval_steps 1000 \
# --save_steps 10 \
# --save_total_limit 100 \
# --max_length 10000 \
# --logging_steps 5 \
# --gradient_checkpointing_kwargs '{"use_reentrant": false}' \
# --deepspeed zero3
7 changes: 7 additions & 0 deletions examples/sampler/system_prompt.txt
@@ -0,0 +1,7 @@
You are a math model, you should **think step by step** carefully. Each step should **end with \"ки\"**. Final answer should be in a '\boxed()'.

## Example:
Step1: XXX. ки\n
Step2: XXX. ки\n
Step3: XXX. ки\n
Answer: \boxed(answer). ки\n
29 changes: 27 additions & 2 deletions swift/llm/argument/sampling_args.py
@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import dataclasses
import os
from dataclasses import dataclass
from datetime import datetime
from typing import List, Literal, Optional
@@ -20,8 +21,8 @@ class SamplingArguments(BaseArguments):

# sampler settings
# sample/mcts/dvts/xxx
sampler_type: str = 'sample'
sampler_engine: Literal['pt', 'lmdeploy', 'vllm', 'no'] = 'pt'
sampler_type: Literal['sample', 'mcts'] = 'sample'
sampler_engine: Literal['pt', 'lmdeploy', 'vllm', 'no', 'client'] = 'pt'
output_dir: str = 'sample_output'
output_file: Optional[str] = None
override_exist_file: bool = False
@@ -42,6 +43,21 @@ class SamplingArguments(BaseArguments):
# Vanilla
cache_files: List[str] = dataclasses.field(default_factory=list)

# MCTS
rollout_depth: int = 5
rollout_start_depth: int = 3
max_iterations: int = 100
process_reward_rate: float = 0.0
exploration_rate: float = 0.5
api_key: str = 'EMPTY'
base_url: str = 'https://dashscope.aliyuncs.com/compatible-mode/v1'

def _init_model_info(self):
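        # With the 'client' engine the model is served remotely (see api_key/base_url),
        # so there is no local checkpoint to inspect: set the task type and skip the
        # default model-info initialization.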
if self.sampler_engine != 'client':
return super()._init_model_info()
self.task_type = 'causal_lm'
return

def __post_init__(self):
if self.output_file is None:
now = datetime.now()
@@ -58,4 +74,13 @@ def __post_init__(self):
self.engine_kwargs = json.loads(self.engine_kwargs)
else:
self.engine_kwargs = {}

super().__post_init__()

if self.system is not None:
self.system_message = [{
'role': 'system',
'content': self.system,
}]
else:
self.system_message = []
