-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathPPO_AerisEnv.py
63 lines (44 loc) · 2.07 KB
/
PPO_AerisEnv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import copy
import gym
import gym_aeris.envs
from gym.spaces import Box, Discrete
from agents import TYPE
from agents.PPOAerisAgent import PPOAerisAgent, PPOAerisRNDAgent, PPOAerisDOPAgent, PPOAerisDOPRefAgent
from experiment.ppo_experiment import ExperimentPPO
from experiment.ppo_nenv_experiment import ExperimentNEnvPPO
from utils.MultiEnvWrapper import MultiEnvParallel
def create_env(env_id):
    """Instantiate a single gym_aeris environment from its id string.

    Args:
        env_id: One of ``'AvoidFragiles-v0'``, ``'AvoidHazards-v0'``,
            ``'TargetNavigate-v0'``, ``'GridTargetSearchAEnv-v0'`` or
            ``'GridTargetSearchBEnv-v0'``.

    Returns:
        A freshly constructed environment object from ``gym_aeris.envs``.

    Raises:
        ValueError: If ``env_id`` is not one of the supported ids. Failing
            here is much clearer than handing ``None`` on to the
            multi-environment wrapper.
    """
    # Factories are lambdas so that only the requested class is looked up
    # and constructed; building the table touches nothing in gym_aeris.
    factories = {
        'AvoidFragiles-v0': lambda: gym_aeris.envs.AvoidFragilesEnv(),
        'AvoidHazards-v0': lambda: gym_aeris.envs.AvoidHazardsEnv(),
        'TargetNavigate-v0': lambda: gym_aeris.envs.TargetNavigateEnv(),
        'GridTargetSearchAEnv-v0': lambda: gym_aeris.envs.GridTargetSearchADiscreteEnv(),
        'GridTargetSearchBEnv-v0': lambda: gym_aeris.envs.GridTargetSearchBDiscreteEnv(),
    }
    try:
        factory = factories[env_id]
    except KeyError:
        raise ValueError(
            'Unknown environment id {0!r}; expected one of: {1}'.format(
                env_id, ', '.join(sorted(factories))
            )
        ) from None
    return factory()
def run_env(env_name, config, trial, agent_class, experiment_type):
    """Train an agent on ``config.n_env`` parallel copies of one environment.

    Args:
        env_name: Environment id accepted by ``create_env``.
        config: Experiment configuration. This function reads ``n_env`` and
            ``num_threads`` directly and passes the whole object on to the
            experiment and agent constructors.
        trial: Forwarded unchanged to the selected experiment method.
        agent_class: Agent class, constructed as
            ``agent_class(input_shape, action_dim, config)``.
        experiment_type: Name of the ``ExperimentNEnvPPO`` method to run,
            e.g. ``'run_baseline'``.

    Raises:
        TypeError: If the environment's action space is neither ``Box`` nor
            ``Discrete``.
        AttributeError: If ``experiment_type`` is not an attribute of the
            experiment object.
    """
    print('Creating {0:d} environments'.format(config.n_env))
    env = MultiEnvParallel([create_env(env_name) for _ in range(config.n_env)], config.n_env, config.num_threads)
    # Always close the parallel environments, even if setup or training
    # raises, so worker resources are not leaked.
    try:
        input_shape = env.observation_space.shape

        action_space = env.action_space
        if isinstance(action_space, Box):
            # Continuous actions: dimensionality is the first shape entry.
            action_dim = action_space.shape[0]
        elif isinstance(action_space, Discrete):
            # Discrete actions: number of distinct actions.
            action_dim = action_space.n
        else:
            raise TypeError(
                'Unsupported action space type: {0}'.format(type(action_space).__name__)
            )

        print('Start training')
        experiment = ExperimentNEnvPPO(env_name, env, config)
        method_to_call = getattr(experiment, experiment_type)
        agent = agent_class(input_shape, action_dim, config)
        method_to_call(agent, trial)
    finally:
        env.close()
def run_baseline(env_name, config, trial, agent_class):
    """Train ``agent_class`` on ``env_name`` using the experiment's baseline loop."""
    run_env(env_name, config, trial, agent_class, experiment_type='run_baseline')
def run_rnd_model(env_name, config, trial, agent_class):
    """Train ``agent_class`` on ``env_name`` via the experiment's ``run_rnd_model`` loop."""
    run_env(env_name, config, trial, agent_class, experiment_type='run_rnd_model')
def run_dop_model(env_name, config, trial, agent_class):
    """Train ``agent_class`` on ``env_name`` via the experiment's ``run_dop_model`` loop."""
    run_env(env_name, config, trial, agent_class, experiment_type='run_dop_model')
def run_dop_ref_model(env_name, config, trial, agent_class):
    """Train ``agent_class`` on ``env_name`` using the experiment's baseline loop."""
    # NOTE(review): unlike the sibling wrappers, this dispatches to
    # 'run_baseline' rather than a method named after itself. Confirm this is
    # intentional (e.g. the reference agent trains like the baseline) and not
    # a copy-paste slip for a 'run_dop_ref_model' experiment method.
    run_env(env_name, config, trial, agent_class, experiment_type='run_baseline')