forked from jannerm/diffuser
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplan_guided.py
118 lines (89 loc) · 3.58 KB
/
plan_guided.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import pdb
import diffuser.sampling as sampling
import diffuser.utils as utils
#-----------------------------------------------------------------------------#
#----------------------------------- setup -----------------------------------#
#-----------------------------------------------------------------------------#
## Script arguments. Only the defaults specific to this script are declared
## here: the dataset to plan on and the config module to read settings from.
## Other fields used below (loadbase, savepath, scale, ...) are not declared in
## this class; presumably `utils.Parser` fills them in from `config` — verify there.
## (Documentation is kept out of the class body on purpose: comments/docstrings
## there may be picked up by the argument parser as help text.)
class Parser(utils.Parser):
    dataset: str = 'walker2d-medium-replay-v2'
    config: str = 'config.locomotion'


## build the parser, then parse; 'plan' presumably selects the planning
## section of the config module — confirm in `utils.Parser.parse_args`
parser = Parser()
args = parser.parse_args('plan')
#-----------------------------------------------------------------------------#
#---------------------------------- loading ----------------------------------#
#-----------------------------------------------------------------------------#
def _load_experiment(loadpath, epoch):
    """Load one saved experiment for `args.dataset` with `utils.load_diffusion`.

    Both the diffusion model and the value function are loaded through this
    helper; they share `args.loadbase`, `args.dataset` and `args.seed` and
    differ only in their load path and epoch.
    """
    return utils.load_diffusion(
        args.loadbase, args.dataset, loadpath,
        epoch=epoch, seed=args.seed,
    )


## load the diffusion model first, then the value function
diffusion_experiment = _load_experiment(args.diffusion_loadpath, args.diffusion_epoch)
value_experiment = _load_experiment(args.value_loadpath, args.value_epoch)

## refuse to pair a diffusion model and value function that do not match
utils.check_compatibility(diffusion_experiment, value_experiment)

## pieces of the diffusion experiment used for planning, normalization and rendering
diffusion = diffusion_experiment.ema
dataset = diffusion_experiment.dataset
renderer = diffusion_experiment.renderer

## value guide wraps the value function's `ema` model; `verbose=False` is
## consumed by `utils.Config` itself
value_function = value_experiment.ema
guide_config = utils.Config(args.guide, model=value_function, verbose=False)
guide = guide_config()

## logger settings (rendering cadence and limits, output directory)
_logger_kwargs = dict(
    renderer=renderer,
    logpath=args.savepath,
    vis_freq=args.vis_freq,
    max_render=args.max_render,
)
logger_config = utils.Config(utils.Logger, **_logger_kwargs)

## keyword arguments that configure guided sampling inside the policy
_sampling_kwargs = dict(
    sample_fn=sampling.n_step_guided_p_sample,
    n_guide_steps=args.n_guide_steps,
    t_stopgrad=args.t_stopgrad,
    scale_grad_by_std=args.scale_grad_by_std,
)

## the policy combines the unconditional diffusion model with the value guide
policy_config = utils.Config(
    args.policy,
    guide=guide,
    scale=args.scale,
    diffusion_model=diffusion,
    normalizer=dataset.normalizer,
    preprocess_fns=args.preprocess_fns,
    **_sampling_kwargs,
    verbose=False,
)

## instantiate in the original order: logger, then policy
logger = logger_config()
policy = policy_config()
#-----------------------------------------------------------------------------#
#--------------------------------- main loop ---------------------------------#
#-----------------------------------------------------------------------------#
## environment attached to the diffusion experiment's dataset
env = dataset.env
observation = env.reset()

## observations for rendering
rollout = [observation.copy()]

total_reward = 0

## Pre-bind everything `logger.finish` reads after the loop. Without this,
## `args.max_episode_length == 0` skips the loop body and the final call raises
## NameError on `t` / `score` / `terminal`. On the normal path the loop
## overwrites all three, so behaviour there is unchanged.
## NOTE(review): `t = 0` for a zero-step episode is a convention — confirm
## what the logger expects in that case.
t = 0
terminal = False
score = env.get_normalized_score(total_reward)

for t in range(args.max_episode_length):

    ## periodically echo the output directory into the log
    if t % 10 == 0:
        print(args.savepath, flush=True)

    ## save state for rendering only
    state = env.state_vector().copy()

    ## condition the plan on the current observation (key 0 presumably means
    ## the first timestep of the planned trajectory — verify in the policy)
    conditions = {0: observation}
    action, samples = policy(conditions, batch_size=args.batch_size, verbose=args.verbose)

    ## execute action in environment
    next_observation, reward, terminal, _ = env.step(action)

    ## accumulate return and report per-step progress
    total_reward += reward
    score = env.get_normalized_score(total_reward)
    print(
        f't: {t} | r: {reward:.2f} | R: {total_reward:.2f} | score: {score:.4f} | '
        f'values: {samples.values} | scale: {args.scale}',
        flush=True,
    )

    ## update rollout observations
    rollout.append(next_observation.copy())

    ## the logger was configured with `args.vis_freq` and decides when to render
    logger.log(t, samples, state, rollout)

    if terminal:
        break

    observation = next_observation

## write results to json file at `args.savepath`
logger.finish(t, score, total_reward, terminal, diffusion_experiment, value_experiment)