-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_imitation_learning.py
95 lines (78 loc) · 3.28 KB
/
train_imitation_learning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import environment_imitation
from environment_imitation import Grid_World
from environment_imitation import generatelabel_env,generate_env
import numpy as np
from stable_baselines3.common.monitor import Monitor
import gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.ppo import MlpPolicy
from imitation.algorithms import bc
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.data.types import TrajectoryWithRew
import argparse
def create_trajectories(env, episodes, datapath):
    """Build expert demonstration trajectories for behavioral cloning.

    Replays each recorded expert action sequence in the environment to
    collect the matching observations and rewards.

    Args:
        env: Grid_World environment instance (already constructed).
        episodes: number of episode files in the dataset.
        datapath: root path of the dataset.

    Returns:
        list of imitation ``TrajectoryWithRew`` objects, one per episode.
    """
    actions = ['move', 'turnLeft', 'turnRight', 'finish', 'pickMarker', 'putMarker']
    trajectories = []
    for i in range(episodes):
        env.reset()
        actionlist = generatelabel_env(datapath, training_mode="data", episode=i)
        # Map action names to integer indices once; reused both for stepping
        # the env and as the trajectory's action array (the original looked
        # up actions.index(...) a second time inside the step loop).
        int_actions = [actions.index(action) for action in actionlist]
        config = generate_env(datapath, "data", i)
        state_space = env.get_observation(config)
        # obs must hold len(acts) + 1 entries (initial state + one per step),
        # as TrajectoryWithRew requires.
        obs = [state_space]
        rews = []
        # Replay the expert actions to collect observations and rewards.
        for action_idx in int_actions:
            state_space, reward, _, _ = env.step(action_idx)
            obs.append(state_space)
            rews.append(reward)
        trajectories.append(TrajectoryWithRew(
            obs=np.array(obs).astype(int),
            acts=np.array(int_actions).astype(int),
            # The episode is terminal only if the expert ended with "finish".
            terminal=(actionlist[-1] == "finish"),
            rews=np.array(rews).astype(float),
            infos=None))
    return trajectories
rng = np.random.default_rng(0)
def main(args):
    """Train a behavioral-cloning policy from expert demonstrations.

    Loads the expert episode files, replays them into trajectories, and
    trains an imitation.bc.BC policy, saved to ``bc_trainer``.

    Args:
        args: parsed CLI namespace with ``datapath`` (dataset root) and
            ``n_epochs`` (number of BC training epochs).
    """
    # Fixed seed so BC training is reproducible across runs.
    rng = np.random.default_rng(0)
    # One episode per file in the training sequence directory.
    # (The original computed this twice with identical listdir scans and
    # relied on datapath ending in "/"; os.path.join handles both cases.)
    n_episodes = len(os.listdir(os.path.join(args.datapath, "data/train/seq")))
    print(n_episodes)
    env = Grid_World(n_episodes)
    # Wrap with Monitor to log per-episode stats.
    env = Monitor(env, filename="ppo_log/")
    trajectories = create_trajectories(env, episodes=n_episodes, datapath=args.datapath)
    # Flatten the per-episode trajectories into (obs, act, rew) transitions
    # suitable as BC demonstrations.
    transitions = rollout.flatten_trajectories_with_rew(trajectories)
    bc_trainer = bc.BC(
        observation_space=env.observation_space,
        action_space=env.action_space,
        demonstrations=transitions,
        rng=rng,
    )
    bc_trainer.train(n_epochs=args.n_epochs)
    bc_trainer.save_policy(policy_path="bc_trainer")
def arg_parser():
    """Parse command-line arguments for imitation-learning training.

    Returns:
        argparse.Namespace with ``datapath`` (str) and ``n_epochs`` (int).
    """
    parser = argparse.ArgumentParser(
        description="Train a behavioral-cloning policy on expert demonstrations.")
    # Fixed: the previous help text was copied from an unrelated
    # topic-modeling script and described behavior this script does not have.
    parser.add_argument(
        "--datapath",
        help="root directory of the dataset; episode files are read from "
             "<datapath>/data/train/seq")
    parser.add_argument(
        "--n_epochs", type=int,
        help="number of behavioral-cloning training epochs")
    args = parser.parse_args()
    return args
# Script entry point: parse CLI arguments and launch BC training.
if __name__ == '__main__':
    args = arg_parser()
    main(args)