-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun_pe.py
116 lines (96 loc) · 3.95 KB
/
run_pe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""An example PE run."""
import argparse
import os
from pyRDDLGym.core.grounder import RDDLGrounder
from pyRDDLGym.core.parser.reader import RDDLReader
from pyRDDLGym.core.parser.parser import RDDLParser
from pyRDDLGym_symbolic.core.model import RDDLModelXADD
from pyRDDLGym_symbolic.mdp.mdp_parser import MDPParser
from pyRDDLGym_symbolic.mdp.policy_parser import PolicyParser
from pyRDDLGym_symbolic.solver.pe import PolicyEvaluation
# Path templates for the bundled example RDDL files.
# ``{domain}`` and ``{instance}`` are ``str.format`` placeholders filled in at
# run time. The paths are relative, so they resolve against the current
# working directory: launch the script from the directory that contains the
# ``pyRDDLGym_symbolic`` package folder.
_DIR = 'pyRDDLGym_symbolic/examples/files/' '{domain}/'
_DOMAIN_PATH = f'{_DIR}domain.rddl'
_INSTANCE_PATH = f'{_DIR}instance{{instance}}.rddl'
def _compile_xadd_model(domain_file: str, instance_file: str):
    """Read, parse, ground and XADD-compile an RDDL domain/instance pair.

    Args:
        domain_file: Path to the RDDL domain file.
        instance_file: Path to the RDDL instance file.

    Returns:
        A ``(rddl_ast, xadd_model)`` tuple. ``rddl_ast`` is the parsed RDDL
        AST; the caller reads instance-level settings such as
        ``max_nondef_actions`` from it. ``xadd_model`` is the compiled
        ``RDDLModelXADD`` built from the grounded model.
    """
    reader = RDDLReader(domain_file, instance_file)
    rddl_txt = reader.rddltxt
    rddl_parser = RDDLParser(None, False)
    rddl_parser.build()
    rddl_ast = rddl_parser.parse(rddl_txt)

    grounded_model = RDDLGrounder(rddl_ast).ground()

    xadd_model = RDDLModelXADD(grounded_model, reparam=False)
    xadd_model.compile()
    return rddl_ast, xadd_model


def _export_value_dds(mdp, value_dds, sol_dir: str, max_iter: int,
                      save_graph: bool) -> None:
    """Export each per-iteration value XADD (and optionally a PDF graph).

    Args:
        mdp: The parsed MDP. Its ``context`` performs the export.
        value_dds: Per-iteration value diagrams, indexed from 0, as returned
            by the solver under the ``'value_dd'`` key.
        sol_dir: Output directory. It is created if missing.
        max_iter: Upper bound on the number of iterations to export.
        save_graph: When true, also save a PDF rendering of each diagram.
    """
    os.makedirs(sol_dir, exist_ok=True)
    # When early convergence is enabled, the solver presumably stops before
    # max_iter and returns fewer diagrams. Clamp to what was actually
    # returned instead of raising IndexError on the missing iterations.
    num_exported = min(max_iter, len(value_dds))
    for i in range(num_exported):
        iter_num = i + 1
        value_dd = value_dds[i]
        sol_fpath = os.path.join(sol_dir, f'value_dd_iter_{iter_num}.xadd')
        mdp.context.export_xadd(value_dd, fname=sol_fpath)
        if save_graph:
            # Hack: pass an absolute path so the graph is saved into sol_dir.
            # NOTE(review): save_graph presumably resolves relative names
            # against its own default directory; confirm.
            graph_fpath = os.path.join(
                os.path.abspath(sol_dir), f'value_dd_iter_{iter_num}.pdf')
            mdp.context.save_graph(value_dd, file_name=graph_fpath)


def run_vi(args: argparse.Namespace):
    """Run symbolic policy evaluation (PE) and export per-iteration values.

    The name ``run_vi`` is historical and kept for backward compatibility.
    The MDP is built with ``is_vi=False`` and solved with
    ``PolicyEvaluation``.

    Args:
        args: Parsed command-line arguments. Reads ``domain``, ``instance``,
            ``is_linear``, ``policy_fpath``, ``assert_concurrency``,
            ``max_iter``, ``enable_early_convergence``, ``reduce_lp`` and
            ``save_graph``.

    Side effects:
        Writes ``value_dd_iter_<k>.xadd`` files (and, with ``save_graph``,
        ``value_dd_iter_<k>.pdf`` files) under ``<domain dir>/sdp/pe``.
        Prints the per-iteration timings.
    """
    domain_file = _DOMAIN_PATH.format(domain=args.domain)
    instance_file = _INSTANCE_PATH.format(
        domain=args.domain, instance=args.instance)
    rddl_ast, xadd_model = _compile_xadd_model(domain_file, instance_file)

    mdp = MDPParser().parse(
        xadd_model,
        xadd_model.discount,
        concurrency=rddl_ast.instance.max_nondef_actions,
        is_linear=args.is_linear,
        is_vi=False,
    )

    policy = PolicyParser().parse(
        mdp=mdp,
        policy_fname=args.policy_fpath,
        assert_concurrency=args.assert_concurrency,
        concurrency=mdp.max_allowed_actions,
    )

    pe_solver = PolicyEvaluation(
        policy=policy,
        mdp=mdp,
        max_iter=args.max_iter,
        enable_early_convergence=args.enable_early_convergence,
        perform_reduce_lp=args.reduce_lp,
    )
    res = pe_solver.solve()

    # Export the solution next to the domain file.
    sol_dir = os.path.join(os.path.dirname(domain_file), 'sdp', 'pe')
    _export_value_dds(
        mdp, res['value_dd'], sol_dir, args.max_iter, args.save_graph)

    print(f'Times per iterations: {res["time"]}')
if __name__ == "__main__":
    # ``argparse`` is already imported at module level.
    arg_parser = argparse.ArgumentParser(description=__doc__)
    arg_parser.add_argument('--domain', type=str, default='RobotLinear_1D',
                            help='The name of the RDDL environment.')
    arg_parser.add_argument('--instance', type=str, default='0',
                            help='The instance number of the RDDL environment.')
    # Required: the value is passed straight to PolicyParser.parse as
    # ``policy_fname``, and there is no sensible default policy file.
    arg_parser.add_argument('--policy_fpath', type=str, required=True,
                            help='The file path to the policy.')
    arg_parser.add_argument('--max_iter', type=int, default=10,
                            help='The maximum number of iterations')
    arg_parser.add_argument('--enable_early_convergence', action='store_true',
                            help='Whether to enable early convergence.')
    arg_parser.add_argument('--is_linear', action='store_true',
                            help='Whether the MDP is linear or not.')
    arg_parser.add_argument('--reduce_lp', action='store_true',
                            help='Whether to perform the reduce LP function.')
    arg_parser.add_argument('--assert_concurrency', action='store_true',
                            help='Whether to assert concurrency or not')
    arg_parser.add_argument('--save_graph', action='store_true',
                            help='Whether to save the XADD graph to a file.')
    run_vi(arg_parser.parse_args())