Redesign when ? #217

Open

fabricerosay opened this issue Oct 21, 2024 · 3 comments

Comments

@fabricerosay

I played a little bit with the dev branch of the redesign by the GSoC author. It is a bit slower than AlphaGPU, I think (pure performance-wise), but with a few adjustments to Gumbel it works impressively well: this is with 32 rollouts and 4 actions, 20,000 envs on a 3080; after around 6 minutes it catches AlphaZero.jl and then gets slightly better on the tests.
[Plot: pascal_pons_benchmark_error_rates]

@jonathan-laurent (Owner)

Yes, @AndrewSpano did some really good work on this. I still haven't had time to finish the redesign but I am not losing hope! You mentioned adjustments to Gumbel: if you changed anything, would you mind submitting a PR?

@fabricerosay (Author)

Essentially, you have to use the improved Gumbel policy as the training target (not the one-hot encoding of the move picked by sequential halving), use that same policy to select moves during self-play (again, not the move returned by the search), and add diversity at the starting position (when resetting an env, randomly choose among all the positions reachable in 0, 1 or 2 plies). My code is messy and the naming is poor, but here are essentially the three functions to add/change in BatchedMCTS:

"""
    gumbel_policy(tree, mcts_config, gumbel)

Returns an array of size (num_envs,) containing the resulting actions selected
by the sequential halving procedure with gumbel for each environment. This function should
be used after `gumbel_explore()` has been run.
"""
function gumbel_policy(tree, mcts_config, current_steps,rng::AbstractRNG)
    num_actions = Val(n_actions(tree))
#    τ = mcts_config.tau
 #   deterministic_move_idx = mcts_config.collapse_tau_move
    c_scale, c_visit = mcts_config.value_scale, mcts_config.max_visit_init
    probs = DeviceArray(mcts_config.device)(rand(rng, Float32, batch_size(tree)))
    actions = zeros(Int16, mcts_config.device, batch_size(tree))
    Devices.foreach(1:batch_size(tree), mcts_config.device) do bid
        # if current_steps[bid] >= 30#deterministic_move_idx 
        #     t=0.3f0
        # else
        #     t=1
        # end
        policy=get_ipolicy(tree,c_scale,c_visit,bid,num_actions,1.0f0)
        actions[bid] = categorical_sample(policy, probs[bid])#gumbel_mcts_action(c_scale, c_visit, tree, bid, gumbel, num_actions)
    end

    return actions
end


"""
    get_root_children_visits(tree, mcts_config)

Returns an array of size (num_actions, num_envs) containing the number of visits for
each action at the root node for each environment. This function should be used after
`gumbel_explore()` or `explore()` has been run.
"""
function get_ipolicy(tree,c_scale,c_visit, bid, num_actions::Val{A}=1.0f0) where {A}
    logits = SVector{A}(imap(aid -> tree.valid_actions[aid, 1, bid] ? tree.logit_prior[aid, 1, bid]/τ : -Inf32, 1:A))+
    transformed_qvalues(c_scale, c_visit, tree, 1, bid, num_actions)
    return softmax(logits)
end
function get_root_ipolicy(tree, mcts_config,c_scale=0.1f0,c_visit=50)
    # compute the policy: π′ = softmax(logits + σ(completedQ))
    num_actions = Val(n_actions(tree))
    #c_scale, c_visit = mcts_config.value_scale, mcts_config.max_visit_init
    ipolicy=zeros(Float32, mcts_config.device, (n_actions(tree), batch_size(tree)))
    Devices.foreach(1:batch_size(tree), mcts_config.device) do bid
        ipolicy[:,bid] .= get_ipolicy(tree,c_scale,c_visit,bid,num_actions)
    end
    ipolicy
end
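For context, this is the policy improvement from the Gumbel MuZero paper ("Policy Improvement by Planning with Gumbel"): π′ = softmax(logits + σ(completedQ)), where the completed Q-values are rescaled by the monotone transform σ(q̂(a)) = (c_visit + max_b N(b)) · c_scale · q̂(a). That transform is presumably what `transformed_qvalues` computes here, with `value_scale` and `max_visit_init` playing the roles of c_scale and c_visit.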

And in train:

if config.use_gumbel_mcts
    t = @elapsed tree, gumbel = MCTS.gumbel_explore(mcts_config, envs, mcts_rng)
    times.explore_times[step] = t

    t = @elapsed begin
        # Sample the self-play action from the improved policy (not the sequential halving move).
        actions = MCTS.gumbel_policy(tree, mcts_config, steps_counter, mcts_rng)
        # Improved root policy, moved to the CPU, to be used as the training target.
        policy = get_root_ipolicy(tree, mcts_config) |> cpu
        # 7 = number of actions in Connect Four.
        policy = [SVector{7}(policy[:, k]) for k in 1:size(policy, 2)]
    end
    times.selection_times[step] = t
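For completeness, here is a hedged sketch of how the improved policy could then be stored as the training target; `replay_buffer`, `vectorize_state`, and the named-tuple layout below are purely illustrative, not the redesign's actual API:

```julia
# Illustrative only: store the improved root policy (rather than a one-hot
# encoding of the selected action) as the policy target for each environment.
for bid in 1:length(envs)
    push!(replay_buffer, (
        state = vectorize_state(envs[bid]),  # hypothetical state-encoding helper
        policy_target = policy[bid],         # improved policy from get_root_ipolicy
    ))
end
```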

The way I choose a random opening is through the env reset, but it is very ugly and not generic. What should be done instead is to randomly choose a number of steps and then play that many random moves from the initial position (for Connect Four I use 2 steps max), as in the sketch below.
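A minimal, game-agnostic sketch of that idea; `reset!`, `valid_actions`, and `act!` are hypothetical environment functions standing in for whatever the batched env interface actually provides:

```julia
using Random: AbstractRNG

# Play between 0 and `max_plies` random legal moves after resetting the env,
# so that self-play games do not all start from the same position.
function randomize_opening!(env, rng::AbstractRNG; max_plies=2)
    reset!(env)
    for _ in 1:rand(rng, 0:max_plies)
        actions = valid_actions(env)
        isempty(actions) && break
        act!(env, rand(rng, actions))
    end
    return env
end
```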

@jonathan-laurent (Owner)

Thanks, this is very useful!
