Redesign when? #217
Comments
Yes, @AndrewSpano did some really good work on this. I still haven't had time to finish the redesign but I am not losing hope! You mentioned adjustments to Gumbel: if you changed anything, would you mind submitting a PR?
Essentially you have to use the improved Gumbel policy as the training target (not the one-hot encoding of the sequential halving move), use that policy for choosing moves during self-play (again, not the move you get from the search), and add diversity at the starting position (when resetting the env, it randomly chooses among all the positions reachable after 0, 1, or 2 plies). My code is messy and the naming is poor, but here are essentially the 3 functions to add/change in BatchedMCTS:

"""
    gumbel_policy(tree, mcts_config, current_steps, rng)

Returns an array of size `(num_envs,)` containing, for each environment, an action
sampled from the improved Gumbel policy at the root (rather than the move selected by
sequential halving). This function should be used after `gumbel_explore()` has been run.
"""
function gumbel_policy(tree, mcts_config, current_steps, rng::AbstractRNG)
    num_actions = Val(n_actions(tree))
    # τ = mcts_config.tau
    # deterministic_move_idx = mcts_config.collapse_tau_move
    c_scale, c_visit = mcts_config.value_scale, mcts_config.max_visit_init
    probs = DeviceArray(mcts_config.device)(rand(rng, Float32, batch_size(tree)))
    actions = zeros(Int16, mcts_config.device, batch_size(tree))
    Devices.foreach(1:batch_size(tree), mcts_config.device) do bid
        # Optional temperature collapse after a fixed move number (disabled here):
        # if current_steps[bid] >= 30  # deterministic_move_idx
        #     t = 0.3f0
        # else
        #     t = 1
        # end
        # Sample from the improved root policy instead of returning the sequential
        # halving move (previously: gumbel_mcts_action(c_scale, c_visit, tree, bid, gumbel, num_actions)).
        policy = get_ipolicy(tree, c_scale, c_visit, bid, num_actions, 1.0f0)
        actions[bid] = categorical_sample(policy, probs[bid])
    end
    return actions
end
"""
    get_ipolicy(tree, c_scale, c_visit, bid, num_actions, τ=1.0f0)

Returns the improved policy `π′ = softmax(logits/τ + σ(completedQ))` over the actions at
the root node of environment `bid` (invalid actions receive a logit of `-Inf`). This
function should be used after `gumbel_explore()` or `explore()` has been run.
"""
function get_ipolicy(tree, c_scale, c_visit, bid, num_actions::Val{A}, τ=1.0f0) where {A}
    logits = SVector{A}(imap(aid -> tree.valid_actions[aid, 1, bid] ? tree.logit_prior[aid, 1, bid] / τ : -Inf32, 1:A)) +
             transformed_qvalues(c_scale, c_visit, tree, 1, bid, num_actions)
    return softmax(logits)
end
function get_root_ipolicy(tree, mcts_config, c_scale=0.1f0, c_visit=50)
    # compute the policy: π′ = softmax(logits + σ(completedQ))
    num_actions = Val(n_actions(tree))
    # c_scale, c_visit = mcts_config.value_scale, mcts_config.max_visit_init
    ipolicy = zeros(Float32, mcts_config.device, (n_actions(tree), batch_size(tree)))
    Devices.foreach(1:batch_size(tree), mcts_config.device) do bid
        ipolicy[:, bid] .= get_ipolicy(tree, c_scale, c_visit, bid, num_actions)
    end
    return ipolicy
end

And in train:

if config.use_gumbel_mcts
    t = @elapsed tree, gumbel = MCTS.gumbel_explore(mcts_config, envs, mcts_rng)
    times.explore_times[step] = t
    t = @elapsed begin
        # Play the action sampled from the improved policy (pass the MCTS rng so
        # the call matches gumbel_policy's signature).
        actions = MCTS.gumbel_policy(tree, mcts_config, steps_counter, mcts_rng)
        # Improved root policy, used as the training target instead of the
        # one-hot encoding of the search move.
        policy = get_root_ipolicy(tree, mcts_config) |> cpu
        policy = [SVector{7}(policy[:, k]) for k in 1:size(policy)[2]]
    end
    times.selection_times[step] = t

How I choose a random opening is through resetting the env, but it is very ugly and not generic. What should really be done is to randomly choose a number of steps and then play that many random moves from the starting position (for Connect Four I use at most 2 steps); a rough sketch is below.
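For illustration only, here is a minimal sketch of that more generic approach, assuming a simple single-env interface (`reset!`, `legal_actions`, and `apply_action!` are placeholder names, not the actual BatchedEnvs API):

using Random: AbstractRNG

# Hypothetical sketch of starting-position diversity: after each reset, play a
# random number of random plies (0 to 2 for Connect Four). The env functions used
# here (reset!, legal_actions, apply_action!) are placeholders, not the real API.
function reset_with_random_opening!(env, rng::AbstractRNG; max_random_plies::Int=2)
    reset!(env)
    nplies = rand(rng, 0:max_random_plies)  # how many random opening moves to play
    for _ in 1:nplies
        actions = legal_actions(env)
        isempty(actions) && break           # nothing left to play (degenerate case)
        apply_action!(env, rand(rng, actions))
    end
    return env
end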
Thanks, this is very useful!
I played a little bit with the dev branch of the redesign by the GSoC author. I think it is a bit slower than AlphaGPU in pure performance terms, but with a few adjustments to Gumbel it works impressively well: with 32 rollouts and 4 actions, and 20,000 envs on a 3080, after around 6 minutes it catches up with AlphaZero.jl and then gets slightly better on the tests.