""" Runner for rlpyt's A2C """
from Network import *
from rlpyt.samplers.serial.sampler import SerialSampler
from rlpyt.samplers.parallel.gpu.sampler import GpuSampler
from rlpyt.algos.pg.a2c import A2C
from rlpyt.agents.pg.categorical import CategoricalPgAgent
from rlpyt.runners.minibatch_rl import MinibatchRl
from rlpyt.utils.logging.context import logger_context
from EnvWrapper import rlpyt_make
import Config as C
import os
import os.path as osp


def findOptimalAgent(reward, run_ID=0):
    """
    Find the optimal agent for the MDP (see Config for the
    specification) under a custom reward function, using
    rlpyt's implementation of A2C.
    """
    # Pin sampler workers to CPUs and pick the CUDA device from Config.
    cpus = list(range(C.N_PARALLEL))
    affinity = dict(cuda_idx=C.CUDA_IDX, workers_cpus=cpus)
    # Serial sampler: collects batch_T steps from each of batch_B environment copies.
    sampler = SerialSampler(
        EnvCls=rlpyt_make,
        env_kwargs=dict(id=C.ENV, reward=reward),
        batch_T=C.BATCH_T,
        batch_B=C.BATCH_B,
        max_decorrelation_steps=400,
        eval_env_kwargs=dict(id=C.ENV),
        eval_n_envs=5,
        eval_max_steps=2500,
    )
    algo = A2C(
        discount=C.DISCOUNT,
        learning_rate=C.LR,
        value_loss_coeff=C.VALUE_LOSS_COEFF,
        entropy_loss_coeff=C.ENTROPY_LOSS_COEFF,
    )
    # Categorical policy-gradient agent over the discrete action space,
    # parameterised by the AcrobotNet model imported from Network.
    agent = CategoricalPgAgent(ModelCls=AcrobotNet)
    runner = MinibatchRl(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=C.N_STEPS,
        log_interval_steps=C.LOG_STEP,
        affinity=affinity,
    )
    name = "a2c_" + C.ENV.lower()
    log_dir = name
    with logger_context(log_dir, run_ID, name,
                        snapshot_mode='last', override_prefix=True):
        runner.train()
    return agent
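

# A minimal usage sketch, assuming the reward argument is a callable that
# EnvWrapper.rlpyt_make forwards to the wrapped environment (its exact
# signature is defined there, not here). `stepPenalty` below is a
# hypothetical example reward, not part of this project.
if __name__ == "__main__":
    def stepPenalty(*args, **kwargs):
        # Hypothetical reward: a constant -1 per step, which favours
        # reaching a terminal state as quickly as possible.
        return -1.0

    trainedAgent = findOptimalAgent(stepPenalty, run_ID=0)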