
AZ much worse than generic solution for simple game #193

Open
70Gage70 opened this issue May 9, 2023 · 2 comments

70Gage70 commented May 9, 2023

I'm trying to train AZ on single-player 21. You have a shuffled deck of cards, and at each step you either "take" a card (adding its value to your total, with Ace = 1, 2 = 2, ..., face cards = 10) or "stop" and receive your current total as the reward. The obvious strategy is to take if the expected value of a draw would leave your total <= 21, and stop otherwise; this gives an average reward of roughly 14. I defined the game and used the exact training parameters from gridworld.jl, and this is the result:

[Image: benchmark_reward plot]

I don't understand (i) why the rewards are much less than 14, and (ii) why AZ is worse than the network.
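For reference, a back-of-the-envelope check of the threshold this strategy implies with a fresh deck (the Monte Carlo script further down uses the remaining deck at each step instead):

# Expected value of one draw from a full deck:
# 4*(1 + 2 + ... + 10) + 12*10 = 340 points over 52 cards.
mean_card = 340 / 52          # ≈ 6.54
# So with a fresh deck, "take" is favored while total + mean_card <= 21,
# i.e. while the running total is at most 14.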

Game

using AlphaZero 
using CommonRLInterface 
const RL = CommonRLInterface
import Random as RNG

# using StaticArrays
# using Crayons

const NONFACE_CARDS = [i for j = 1:4 for i = 1:10]       # four suits of Ace (1) through 10
const FACE_CARDS = [10 for j = 1:4 for i = 1:3]          # four suits of J/Q/K, each worth 10
const STANDARD_DECK = map(UInt8, vcat(NONFACE_CARDS, FACE_CARDS))  # 52 cards total

### MANDATORY INTERFACE

# state = "what the player should look at"
mutable struct Env21 <: AbstractEnv
    deck::Vector{UInt8} 
    state::UInt8 # points
    reward::UInt8
    terminated::Bool
end

function RL.reset!(env::Env21)
    env.deck = RNG.shuffle(STANDARD_DECK)
    env.state = 0
    env.reward = 0
    env.terminated = false

    return nothing
end

function Env21()
    deck = RNG.shuffle(STANDARD_DECK)
    state = 0
    reward = 0
    terminated = false

    return Env21(deck, state, reward, terminated)
end

RL.actions(env::Env21) = [:take, :stop]
RL.observe(env::Env21) = env.state
RL.terminated(env::Env21) = env.terminated

function RL.act!(env::Env21, action)
    if action == :take
        draw = popfirst!(env.deck)
        env.state += draw

        if env.state >= 22 
            env.reward = 0
            env.state = 0 ######################### okay?
            env.terminated = true
        end
    elseif action == :stop
        env.reward = env.state
        env.terminated = true
    else
        error("Invalid action $action")
    end

    return env.reward
end

### TESTING

# env = Env21()
# reset!(env)
# rsum = 0.0
# while !terminated(env)
#     global rsum += act!(env, rand(actions(env))) 
# end
# @show rsum

### MULTIPLAYER INTERFACE

RL.players(env::Env21) = [1]
RL.player(env::Env21) = 1 

### Optional Interface

RL.observations(env::Env21) = map(UInt8, collect(0:21))
RL.clone(env::Env21) = Env21(copy(env.deck), copy(env.state), copy(env.reward), copy(env.terminated))
RL.state(env::Env21) = env.state
RL.setstate!(env::Env21, new_state) = (env.state = new_state)
RL.valid_action_mask(env::Env21) = BitVector([1, 1])

### AlphaZero Interface

function GI.render(env::Env21)
  println(env.deck)
  println(env.state)
  println(env.reward)
  println(env.terminated)

  return nothing
end

function GI.vectorize_state(env::Env21, state)
  v = zeros(Float32, 22)
  v[state + 1] = 1

  return v
end

const action_names = ["take", "stop"]

function GI.action_string(env::Env21, a)
  idx = findfirst(==(a), RL.actions(env))
  return isnothing(idx) ? "?" : action_names[idx]
end

function GI.parse_action(env::Env21, s)
  idx = findfirst(==(s), action_names)
  return isnothing(idx) ? nothing : RL.actions(env)[idx]
end

function GI.read_state(env::Env21)
  return env.state
end

GI.heuristic_value(::Env21) = 0.

GameSpec() = CommonRLInterfaceWrapper.Spec(Env21())
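For completeness, the training wiring is roughly the following (sketched from memory, untested as written; `params`, `Network`, `netparams` and `benchmark` are assumed to be exactly the definitions from the grid-world params file):

# Rough sketch of the training setup, following the grid-world example.
# `params`, `Network`, `netparams` and `benchmark` are assumed to be loaded
# verbatim from AlphaZero.jl's grid-world parameter file.
experiment = AlphaZero.Experiment(
  "twenty-one", GameSpec(), params, Network, netparams, benchmark=benchmark)
AlphaZero.Scripts.train(experiment)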

Canonical strategy

import Random as RNG

const NONFACE_CARDS = [i for j = 1:4 for i = 1:10]
const FACE_CARDS = [10 for j = 1:4 for i = 1:3]
const STANDARD_DECK = map(UInt8, vcat(NONFACE_CARDS, FACE_CARDS))

function mc_run()
    deck = RNG.shuffle(STANDARD_DECK)
    score = 0
    while true
        # expected total after one more draw, based on the remaining deck
        expected_score = score + sum(deck)/length(deck)

        if expected_score > 21
            return score
        else
            score = score + popfirst!(deck)
            if score >= 22
                return 0
            end
        end
    end
end

function mc(n_trials)
    score = 0 
    for i = 1:n_trials
        score = score + mc_run()
    end
    return score/n_trials
end

mc(10000)
@jonathan-laurent (Owner) commented:

I don't have time to look too deeply, but here are a few remarks:

  • AZ not learning a good policy with default hyperparameters is not necessarily a red flag in itself, even for simple games. AZ cannot be used as a black box in general, and tuning is important.
  • The MCTS policy being worse than the network policy is more surprising. Admittedly, I've not tested AZ.jl on a lot of stochastic environments, so there may be subtleties and rough edges here.
  • In particular, there are many ways to handle stochasticity in MCTS, with different tradeoffs. The current implementation is an open-loop MCTS, which, if I remember correctly, deals OK with light stochasticity but can struggle with highly stochastic environments.

My advice:

  • Try and benchmark pure MCTS (with rollouts); see the sketch after this list for a quick stand-in. If it does terribly even with a lot of search, then there may be a bug in the MCTS implementation, or AZ's MCTS implementation may not be suited to your game.
  • Do not hesitate to make the environment smaller in your tests, for example by having very small decks.
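For instance, something along the following lines (untested, and a flat Monte-Carlo baseline rather than a full MCTS) may serve as a first sanity check, using only the CommonRLInterface methods you already defined; the function names here are just placeholders:

import Random

# Flat Monte-Carlo baseline: evaluate each action by random rollouts on cloned
# environments. The clone carries the already-shuffled deck, so we reshuffle
# the remaining cards to avoid peeking at the true future order.
function rollout_value(env, a; n = 200)
    total = 0.0
    for _ in 1:n
        e = RL.clone(env)
        e.deck = Random.shuffle(e.deck)
        r = RL.act!(e, a)
        while !RL.terminated(e)
            r = RL.act!(e, rand(RL.actions(e)))
        end
        total += r   # the last step's return is the episode reward
    end
    return total / n
end

function flat_mc_episode(; n = 200)
    env = Env21()
    r = 0.0
    while !RL.terminated(env)
        vals = [rollout_value(env, a; n = n) for a in RL.actions(env)]
        best = RL.actions(env)[argmax(vals)]
        r = RL.act!(env, best)
    end
    return Float64(r)
end

# Average reward over many episodes:
sum(flat_mc_episode() for _ in 1:1000) / 1000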

@70Gage70 (Author) commented:

Thanks for the tips, appreciate it. I'll certainly take a look at the MCTS.
