Skip to content

Commit e75b4f9

Browse files
authored
[None][feat] Dev DeepConf (#8362)
Signed-off-by: Dong Cao <[email protected]>
1 parent 4143887 commit e75b4f9

File tree

5 files changed

+576
-0
lines changed

5 files changed

+576
-0
lines changed
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
import argparse
2+
import time
3+
4+
from tensorrt_llm.scaffolding import (NativeGenerationController,
5+
ScaffoldingLlm, TRTLLMWorker)
6+
from tensorrt_llm.scaffolding.contrib.DeepConf import (
7+
DeepConfOfflineController, DeepConfOfflineMajorityVoteController,
8+
DeepConfOnlineController, DeepConfOnlineMajorityVoteController)
9+
10+
# Maps the --run_type CLI choice to the DeepConf controller class that
# implements it.  Keys double as the argparse `choices` for --run_type.
_RUN_TYPE_TO_IMPL = {
    "offline": DeepConfOfflineController,
    "online": DeepConfOnlineController,
    "offline_majority_vote": DeepConfOfflineMajorityVoteController,
    "online_majority_vote": DeepConfOnlineMajorityVoteController,
}
16+
17+
18+
def parse_arguments():
    """Parse command-line arguments for the DeepConf example script.

    Returns:
        argparse.Namespace: parsed arguments. ``model_dir`` and ``run_type``
        are required; every other option falls back to the default shown in
        its ``add_argument`` call below.
    """
    parser = argparse.ArgumentParser()
    # e.g. DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B
    parser.add_argument(
        '--model_dir',
        type=str,
        required=True,
        help="Path to the directory containing the generation model")
    parser.add_argument('--run_type',
                        type=str,
                        required=True,
                        choices=list(_RUN_TYPE_TO_IMPL.keys()),
                        help="Type of the run. Available choices: %(choices)s")
    parser.add_argument('--sample_num',
                        type=int,
                        default=20,
                        help="Number of final samples for majority voting")
    parser.add_argument('--conf_group_size',
                        type=int,
                        default=128,
                        help="Sliding-window size used for group confidence")
    parser.add_argument('--conf_threshold',
                        type=float,
                        default=0.5,
                        help="Confidence threshold for early stopping")
    parser.add_argument('--vote_policy',
                        type=str,
                        default="top10_bottom_window_filtered",
                        help="Voting policy used by the majority-vote "
                        "controllers")
    parser.add_argument('--warmup_sample_num',
                        type=int,
                        default=5,
                        help="Number of warmup samples before final sampling")
    parser.add_argument('--confidence_percentile',
                        type=int,
                        default=90,
                        help="Percentile of warmup confidences used to derive "
                        "the online threshold")
    parser.add_argument('--logprobs_topk',
                        type=int,
                        default=20,
                        help="Number of top logprobs requested per token")
    parser.add_argument('--max_tokens',
                        type=int,
                        default=8192,
                        help="Maximum number of generated tokens per sample")
    parser.add_argument('--temperature',
                        type=float,
                        default=0.6,
                        help="Sampling temperature")
    parser.add_argument('--top_p',
                        type=float,
                        default=0.95,
                        help="Nucleus sampling probability mass")
    args = parser.parse_args()
    return args
45+
46+
47+
def run_scaffolding_llm(prompts, proposer_worker, controller):
    """Run *controller* over *prompts* with a ScaffoldingLlm and print results.

    Builds a ScaffoldingLlm that routes generation work to *proposer_worker*,
    times the batched generate call, prints each result's first output text,
    and shuts the LLM (including its workers) down afterwards.
    """
    worker_map = {
        NativeGenerationController.WorkerTag.GENERATION: proposer_worker,
    }
    llm = ScaffoldingLlm(controller, worker_map)

    time_start = time.time()
    results = llm.generate(prompts)
    time_end = time.time()
    print(f"time cost: {time_end - time_start} seconds")

    for i, result in enumerate(results):
        print(f"result {i}:\n{result.outputs[0].text}")

    # Also tear down the attached workers, not just the llm object itself.
    llm.shutdown(shutdown_workers=True)
61+
62+
63+
def test_single_vote_controller(prompts,
                                proposer_worker,
                                conf_group_size,
                                conf_threshold,
                                temperature,
                                max_tokens,
                                logprobs_topk,
                                top_p,
                                run_type="offline",
                                **kwargs):
    """Run one DeepConf single-vote controller ("offline" or "online").

    Wraps a NativeGenerationController with the DeepConf controller selected
    by *run_type* and executes it via run_scaffolding_llm.  Extra keyword
    arguments (from the shared CLI kwargs dict) are accepted and ignored.
    """
    sampling_params = {
        "temperature": temperature,
        "max_tokens": max_tokens,
        "num_logprobs": logprobs_topk,
        "top_p": top_p,
    }
    generation_controller = NativeGenerationController(
        sampling_params=sampling_params)

    controller_cls = _RUN_TYPE_TO_IMPL[run_type]
    controller = controller_cls(
        generation_controller=generation_controller,
        conf_group_size=conf_group_size,
        conf_threshold=conf_threshold,
    )
    run_scaffolding_llm(prompts, proposer_worker, controller)
87+
88+
89+
def test_majority_vote_controller(prompts,
                                  proposer_worker,
                                  conf_group_size,
                                  conf_threshold,
                                  logprobs_topk,
                                  temperature,
                                  max_tokens,
                                  top_p,
                                  sample_num,
                                  warmup_sample_num,
                                  vote_policy,
                                  confidence_percentile,
                                  run_type="offline_majority_vote",
                                  **kwargs):
    """Run a DeepConf majority-vote controller over *prompts*.

    Builds one shared NativeGenerationController, wraps it in an offline
    DeepConf controller for the warmup phase and an online one for the final
    phase, then hands both to the majority-vote controller selected by
    *run_type* ("offline_majority_vote" or "online_majority_vote").
    Extra keyword arguments (from the shared CLI kwargs dict) are ignored.
    """
    generation_controller = NativeGenerationController(
        sampling_params={
            "temperature": temperature,
            "max_tokens": max_tokens,
            "num_logprobs": logprobs_topk,
            "top_p": top_p,
        })
    # Both phase controllers share the same base generation controller and
    # confidence settings.
    DeepConfControllerKwargs = {
        "generation_controller": generation_controller,
        "conf_group_size": conf_group_size,
        "conf_threshold": conf_threshold,
    }
    warmup_generation_controller = DeepConfOfflineController(
        **DeepConfControllerKwargs)
    final_generation_controller = DeepConfOnlineController(
        **DeepConfControllerKwargs)
    DeepConfMajorityVoteControllerImpl = _RUN_TYPE_TO_IMPL[run_type]
    # NOTE(review): the warmup controller is passed twice, both as
    # `generation_controller` and as `warmup_generation_controller` — confirm
    # against the majority-vote controller's API that this double use is
    # intentional rather than a copy-paste slip.
    majority_vote_controller = DeepConfMajorityVoteControllerImpl(
        generation_controller=warmup_generation_controller,
        warmup_generation_controller=warmup_generation_controller,
        final_generation_controller=final_generation_controller,
        sample_num=sample_num,
        vote_policy=vote_policy,
        warmup_sample_num=warmup_sample_num,
        confidence_percentile=confidence_percentile)
    run_scaffolding_llm(prompts, proposer_worker, majority_vote_controller)
129+
130+
131+
def main():
    """Entry point: build a TRT-LLM worker and run the selected controller."""
    args = parse_arguments()
    # Shared keyword arguments for both test helpers; each helper consumes
    # the subset it needs and swallows the rest via **kwargs.
    kwargs = {
        "sample_num": args.sample_num,
        "conf_group_size": args.conf_group_size,
        "conf_threshold": args.conf_threshold,
        "vote_policy": args.vote_policy,
        "warmup_sample_num": args.warmup_sample_num,
        "confidence_percentile": args.confidence_percentile,
        "logprobs_topk": args.logprobs_topk,
        "max_tokens": args.max_tokens,
        "temperature": args.temperature,
        "top_p": args.top_p,
    }

    prompts = [
        "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\r\n\r\n",
        "There exist real numbers $x$ and $y$, both greater than 1, such that $\\log_x\\left(y^x\\right)=\\log_y\\left(x^{4y}\\right)=10$. Find $xy$.",
        "Find the largest possible real part of \\[(75+117i)z+\\frac{96+144i}{z}\\]where $z$ is a complex number with $|z|=4$.",
    ]

    llm_worker = TRTLLMWorker.init_with_new_llm(
        args.model_dir,
        backend="pytorch",
        max_batch_size=32,
        # Use args directly instead of kwargs.get(...): .get would silently
        # pass None if the key were ever renamed.
        max_num_tokens=args.max_tokens,
    )
    print("init llm worker done")

    if args.run_type in ("offline", "online"):
        test_single_vote_controller(prompts,
                                    llm_worker,
                                    run_type=args.run_type,
                                    **kwargs)
    elif args.run_type in ("offline_majority_vote", "online_majority_vote"):
        test_majority_vote_controller(prompts,
                                      llm_worker,
                                      run_type=args.run_type,
                                      **kwargs)

    llm_worker.shutdown()
    print('llm worker shutdown done')


if __name__ == "__main__":
    main()
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Public API of the DeepConf contrib package: re-export the four controller
# variants from the implementation module so callers can import them from
# the package root.
from .deep_conf_controller import (DeepConfOfflineController,
                                   DeepConfOfflineMajorityVoteController,
                                   DeepConfOnlineController,
                                   DeepConfOnlineMajorityVoteController)

# Explicitly declare the names exported by `from ... import *`.
__all__ = [
    "DeepConfOfflineController",
    "DeepConfOnlineController",
    "DeepConfOfflineMajorityVoteController",
    "DeepConfOnlineMajorityVoteController",
]

0 commit comments

Comments
 (0)