Commit

Added greedy agent as a strong baseline (search depth 2), fixed corner cases in game logic and UI/interactive play

elliottower committed Feb 3, 2023
1 parent d6c002c commit 7b18e16
Showing 14 changed files with 1,257 additions and 41 deletions.
16 changes: 15 additions & 1 deletion README.md
@@ -41,7 +41,18 @@ from gobblet import gobblet_v1
env = gobblet_v1.env()
```

### Play against a DQL agent trained with Tianshou
### Play against a greedy agent

In the terminal, run the following:
```
python gobblet/examples/example_tianshou_DQN.py --cpu-players 1
```

This will launch a game against a greedy agent, which is a very strong baseline. The agent considers all possible moves with a search depth of 2: it wins if possible, blocks enemy wins, and even forces the enemy to make losing moves.

Note: this policy exploits domain knowledge to reconstruct the internal game board from the observation (perfect information) and directly uses functions from `board.py`. Tianshou policies do not get direct access to the environment, only observations and action masks, so the greedy agent should not be compared directly with other RL agents.
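
For reference, the depth-2 rule described above amounts to something like the following sketch. This is illustrative only, not the repository's implementation; the helpers `legal_moves`, `apply_move`, and `winner` are hypothetical stand-ins for the corresponding logic in `board.py` and `gobblet/game/greedy_policy.py`.

```
# Illustrative sketch only -- not the repo's code. legal_moves(board, player),
# apply_move(board, move) and winner(board) are hypothetical stand-ins.
def greedy_move(board, me, opponent):
    moves = legal_moves(board, me)
    # Depth 1: take an immediate win if one exists.
    for move in moves:
        if winner(apply_move(board, move)) == me:
            return move
    # Depth 2: keep only moves after which the opponent has no winning reply
    # (this both blocks existing threats and avoids creating new ones).
    safe = []
    for move in moves:
        next_board = apply_move(board, move)
        if not any(winner(apply_move(next_board, reply)) == opponent
                   for reply in legal_moves(next_board, opponent)):
            safe.append(move)
    # If every move hands the opponent a win, fall back to any legal move.
    return safe[0] if safe else moves[0]
```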

### Play against a DQN agent trained with Tianshou

In the terminal, run the following:
```
@@ -58,6 +69,9 @@ In the terminal, run the following:
```
python gobblet/examples/example_user_input.py
```

Note: interactive play can be enabled in other scripts using the argument `--num-cpu 1`.

To select a piece size, press a number key `1`, `2`, or `3`, or press `space` to cycle through pieces. Placing a piece is done by clicking on a square on the board. A preview will appear showing legal moves with the selected piece size. Clicking on an already placed piece will pick it up and prompt you to place it in a new location (re-placing in the same location is an illegal move).
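
As a rough illustration of these controls (not the repository's actual input-handling code), piece selection could be handled with a pygame event check like the sketch below; `update_selected_size` and `PIECE_KEYS` are hypothetical names.

```
# Illustrative sketch only -- not the repo's UI code.
import pygame

PIECE_KEYS = {pygame.K_1: 1, pygame.K_2: 2, pygame.K_3: 3}

def update_selected_size(event, selected_size):
    """Return the piece size (1-3) after handling a single pygame event."""
    if event.type == pygame.KEYDOWN:
        if event.key in PIECE_KEYS:
            return PIECE_KEYS[event.key]   # number keys select a size directly
        if event.key == pygame.K_SPACE:
            return selected_size % 3 + 1   # space cycles 1 -> 2 -> 3 -> 1
    return selected_size
```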

### Create screen recording of a game
Binary file added gobblet/__pycache__/__init__.cpython-39.pyc
64 changes: 37 additions & 27 deletions gobblet/examples/example_tianshou_DQN.py
@@ -31,6 +31,7 @@
from gobblet import gobblet_v1
from gobblet.game.collector_manual_policy import ManualPolicyCollector
from gobblet.game.utils import GIFRecorder
from gobblet.game.greedy_policy import GreedyPolicy
import time


@@ -145,12 +146,26 @@ def get_agents(

if agent_opponent is None:
if args.self_play:
agent_opponent = deepcopy(agent_learn)
# Create a new network with the same shape
net_opponent = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
device=args.device,
).to(args.device)
agent_opponent = DQNPolicy(
net_opponent,
optim,
args.gamma,
args.n_step,
target_update_freq=args.target_update_freq,
)
elif args.opponent_path:
agent_opponent = deepcopy(agent_learn)
agent_opponent.load_state_dict(torch.load(args.opponent_path))
else:
agent_opponent = RandomPolicy()
# agent_opponent = RandomPolicy()
agent_opponent = GreedyPolicy() # Greedy policy is a difficult opponent and should yield much better results than random

if args.agent_id == 1:
agents = [agent_learn, agent_opponent]
@@ -220,13 +235,15 @@ def train_fn(epoch, env_step):
policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_train)

def train_fn_selfplay(epoch, env_step):
policy.policies[agents[:]].set_eps(args.eps_train) # Same as train_fn but for both agents instead of only learner
policy.policies[agents[0]].set_eps(args.eps_train) # Same as train_fn but for both agents instead of only learner
policy.policies[agents[1]].set_eps(args.eps_train)

def test_fn(epoch, env_step):
policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)

def test_fn_selfplay(epoch, env_step):
policy.policies[agents[:]].set_eps(args.eps_test) # Same as test_fn but for both agents instead of only learner
policy.policies[agents[0]].set_eps(args.eps_test) # Same as test_fn but for both agents instead of only learner
policy.policies[agents[1]].set_eps(args.eps_test)


def reward_metric(rews):
@@ -242,8 +259,8 @@ def reward_metric(rews):
args.step_per_collect,
args.test_num,
args.batch_size,
train_fn=train_fn if not args.self_play else train_fn_selfplay,
test_fn=test_fn if not args.self_play else train_fn_selfplay,
train_fn=train_fn_selfplay if args.self_play else train_fn,
test_fn=test_fn_selfplay if args.self_play else test_fn,
stop_fn=stop_fn,
save_best_fn=save_best_fn,
update_per_step=args.update_per_step,
@@ -269,7 +286,9 @@ def watch(
if not args.self_play:
policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)
else:
policy.policies[agents[:]].set_eps(args.eps_test)
policy.policies[agents[0]].set_eps(args.eps_test)
policy.policies[agents[1]].set_eps(args.eps_test)

collector = Collector(policy, env, exploration_noise=True)

# First step (while loop stopping conditions are not defined until we run the first step)
@@ -311,6 +330,9 @@ def play(
policy.eval()
policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)

# Experimental: let the CPU agent continue training (TODO: check if this actually changes things meaningfully)
# policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_train)

collector = ManualPolicyCollector(policy, env, exploration_noise=True) # Collector for CPU actions

pettingzoo_env = env.workers[0].env.env # DummyVectorEnv -> Tianshou PettingZoo Wrapper -> PettingZoo Env
@@ -320,24 +342,10 @@ def play(
recorder = None
manual_policy = gobblet_v1.ManualPolicy(env=pettingzoo_env, agent_id=args.player, recorder=recorder) # Gobblet keyboard input requires access to raw_env (uses functions from board)

# Get the first move from the CPU (human goes second))
if args.player == 1:
result = collector.collect(n_step=1, render=args.render)

# Get the first move from the player
else:
observation = {"observation": collector.data.obs.obs.flatten(), # Observation not used for manual_policy, bu
"action_mask": collector.data.obs.mask.flatten()} # Collector mask: [1,54], PettingZoo: [54,]
action = manual_policy(observation, pettingzoo_env.agents[0])

result = collector.collect_result(action=action.reshape(1), render=args.render)

while not (collector.data.terminated or collector.data.truncated):
while pettingzoo_env.agents:
agent_id = collector.data.obs.agent_id
# If it is the players turn and there are less than 2 CPU players (at least one human player)
if agent_id == pettingzoo_env.agents[args.player]:
# action_mask = collector.data.obs.mask[0]
# action = np.random.choice(np.arange(len(action_mask)), p=action_mask / np.sum(action_mask))
observation = {"observation": collector.data.obs.obs.flatten(),
"action_mask": collector.data.obs.mask.flatten()} # PettingZoo expects a dict with this format
action = manual_policy(observation, agent_id)
@@ -346,17 +354,19 @@
else:
result = collector.collect(n_step=1, render=args.render)

rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}")
if collector.data.terminated or collector.data.truncated:
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}")
if recorder is not None:
recorder.end_recording()

if __name__ == "__main__":
# train the agent and watch its performance in a match!
args = get_args()
print("Training agent...")
result, agent = train_agent(args)
print("Starting game...")
if args.cpu_players == 2:
watch(args, agent)
else:
play(args, agent)

#TODO: debug why it seems to not let you move when your smaller pieces are covered (print out the currently selected size and the
play(args, agent)
