-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathagent.py
33 lines (24 loc) · 854 Bytes
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import argparse
from simple_agent import RandomAgent, ForwardOnlyAgent
from rl_agent import SACAgent
from gibson2.envs.challenge import Challenge
def get_agent(agent_class, ckpt_path=""):
    """Instantiate the agent identified by ``agent_class``.

    Args:
        agent_class: One of ``"Random"``, ``"ForwardOnly"`` or ``"SAC"``.
        ckpt_path: Checkpoint directory forwarded to ``SACAgent`` as
            ``root_dir``. Ignored by the other agent types.

    Returns:
        A freshly constructed agent instance.

    Raises:
        ValueError: If ``agent_class`` is not a recognised agent name.
            Previously this case silently returned ``None``, which only
            failed later when the agent was used.
    """
    if agent_class == "Random":
        return RandomAgent()
    if agent_class == "ForwardOnly":
        return ForwardOnlyAgent()
    if agent_class == "SAC":
        return SACAgent(root_dir=ckpt_path)
    raise ValueError(
        f"Unknown agent class {agent_class!r}; "
        "expected one of 'Random', 'ForwardOnly', 'SAC'"
    )
def main():
    """Read CLI options, construct the chosen agent and submit it to the challenge."""
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--agent-class",
        type=str,
        default="Random",
        choices=["Random", "ForwardOnly", "SAC"],
    )
    cli.add_argument("--ckpt-path", type=str, default="")
    opts = cli.parse_args()

    # The agent is built before the Challenge object, matching the original order.
    submission = get_agent(agent_class=opts.agent_class, ckpt_path=opts.ckpt_path)
    Challenge().submit(submission)
# Run the submission only when executed as a script, not when imported.
if __name__ == "__main__":
    main()