-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
47 lines (38 loc) · 1003 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
import FlappyBirdEnv
# --- Environments -----------------------------------------------------------
# The `import FlappyBirdEnv` above is what makes the 'FlappyBirdEnv/...' id
# resolvable by gym.make (presumably it registers the env on import — verify).
#
# Training environment, wrapped in Monitor so per-episode reward/length are
# recorded and reported in SB3's training logs.
env = gym.make('FlappyBirdEnv/FlappyBird-v0')
env = Monitor(env)

# Separate environment for periodic evaluation, so evaluation episodes never
# disturb the state of the training environment. It is Monitor-wrapped as
# well, since SB3's evaluation warns that episode rewards/lengths may be
# misreported otherwise.
eval_env = gym.make('FlappyBirdEnv/FlappyBird-v0')
eval_env = Monitor(eval_env)

# --- Model ------------------------------------------------------------------
# DQN with SB3's default hyperparameters. "MlpPolicy" assumes the env exposes
# a flat vector observation space (NOTE(review): confirm against the env's
# observation_space).
model = DQN(
    "MlpPolicy",
    env,
    verbose=1,
    tensorboard_log="tensorboard",
)

# --- Callbacks --------------------------------------------------------------
# Periodic snapshots under ./models/flappy_dqn_<steps>_steps.zip.
# save_freq counts callback calls; with this single (non-vectorized) env that
# equals environment steps.
checkpoint_callback = CheckpointCallback(
    save_freq=10000,
    save_path="./models/",
    name_prefix="flappy_dqn",
)

# Every 10000 calls, run deterministic evaluation episodes on eval_env, keep
# the best-scoring model in ./best_model/ and write results to ./eval_logs/.
eval_callback = EvalCallback(
    eval_env,
    best_model_save_path="./best_model/",
    log_path="./eval_logs/",
    eval_freq=10000,
    deterministic=True,
    render=False,
)

callbacks = [checkpoint_callback, eval_callback]

# Total number of environment steps to train for.
TIMESTEPS = 3000000

# --- Training ---------------------------------------------------------------
try:
    # progress_bar=True needs the optional tqdm/rich packages
    # (installed with stable-baselines3[extra]).
    model.learn(
        total_timesteps=TIMESTEPS,
        callback=callbacks,
        progress_bar=True,
    )
    # Only reached if training completes; intermediate progress is still
    # available from the checkpoint/best-model callbacks.
    model.save("flappy_bird_dqn_final2")
finally:
    # Release BOTH environments even when training raises or is interrupted
    # (e.g. Ctrl-C during a multi-million-step run); previously eval_env was
    # never closed at all.
    env.close()
    eval_env.close()