-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbuild_dataset.py
60 lines (50 loc) · 1.14 KB
/
build_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import time
from rebar.learners.qlearner import QLearner
from rebar.learners.adp import ADP
from rebar.memory import Memory
import numpy as np
import gym
import torch
import pickle
from envs import Swingup, Reacher, InvertedDoublePendulum, InvertedPendulum, Walker
from copy import deepcopy
from gym.wrappers import TimeLimit
import matplotlib.pyplot as plt
# Training environment taken from the local envs module.
# NOTE(review): this binds the name `InvertedPendulum` directly -- the
# attribute access below (`env.action_space`) assumes it is already a
# usable env object; confirm it is not a class missing `()`.
env = InvertedPendulum
# Separate, time-limited copy used only for the final greedy evaluation
# rollout at the bottom of the script.
play_env = TimeLimit(deepcopy(env), max_episode_steps=2000)
play_env.render()
# ADP agent over a discretized observation space.
agt = ADP(
    action_space=env.action_space,
    observation_space=env.observation_space,
    bins=5,  # discretization bins per observation dimension
    gamma=0.99,  # discount factor
    initial_temp=2000,  # presumably an exploration temperature -- verify against ADP
    delta=0.01  # presumably a convergence threshold -- verify against ADP
)
def play(agt, env):
    """Roll out one greedy (no-exploration) episode and return its total reward.

    The loop is paced with a short sleep so a rendered episode plays back
    at roughly 30 frames per second.
    """
    obs = env.reset()
    episode_return = 0
    finished = False
    while not finished:
        action = agt.get_action(obs, explore=False)
        obs, reward, finished, _ = env.step(action)
        episode_return += reward
        time.sleep(1. / 30.)  # ~30 FPS pacing for viewing
    return episode_return
# Initial state for the on-line data-collection loop below.
s = env.reset()
# Replay buffer sized to hold the full collection run.
m = Memory(
    max_len=1000,
    obs_shape=(5,),  # NOTE(review): hard-coded; presumably matches env.observation_space -- confirm
    action_shape=(1,)
)
for step in range(1000):
    a = agt.get_action(s)
    sp, r, done, _ = env.step(a)
    agt.handle_transition(s, a, r, sp, done)
    m.append((s, a, r, sp, done))
    if done:
        # Episode over: restart from a fresh initial state.
        s = env.reset()
    else:
        # BUGFIX: only advance to sp when the episode continues.  The
        # original assigned `s = sp` unconditionally, clobbering the
        # freshly reset state right after every episode boundary.
        s = sp
# BUGFIX: use a context manager so the file handle is closed (the original
# leaked the handle returned by open()).
with open('dataset.pkl', 'wb') as f:
    pickle.dump(m, f)
# Final greedy evaluation on the time-limited rendering env.
print(play(agt, play_env))