diff --git a/lua/octo/utils.lua b/lua/octo/utils.lua index a5542aee..c55ca596 100644 --- a/lua/octo/utils.lua +++ b/lua/octo/utils.lua @@ -449,9 +449,14 @@ end --- Determines if we are locally are in a branch matching the pr head ref --- @param pr PullRequest --- @return boolean -function M.in_pr_branch(pr) +function M.in_pr_branch_locally_tracked(pr) local cmd = "git rev-parse --abbrev-ref --symbolic-full-name @{u}" - local local_branch_with_local_remote = vim.split(string.gsub(vim.fn.system(cmd), "%s+", ""), "/") + local cmd_out = vim.fn.system(cmd) + if vim.v.shell_error ~= 0 then + return false + end + + local local_branch_with_local_remote = vim.split(string.gsub(cmd_out, "%s+", ""), "/") local local_remote = local_branch_with_local_remote[1] local local_branch = table.concat(local_branch_with_local_remote, "/", 2) @@ -465,6 +470,62 @@ function M.in_pr_branch(pr) return false end +-- Determines if we are locally in a branch matting the pr head ref when +-- the remote and branch information is stored in the branch's git config values +-- The gh CLI tool stores remote info directly in {branch.{branch}.x} configuration +-- fields and does not create a remote +function M.in_pr_branch_config_tracked(pr) + local branch_cmd = "git rev-parse --abbrev-ref HEAD" + local branch = vim.fn.system(branch_cmd) + if vim.v.shell_error ~= 0 then + return false + end + + if #branch == 0 then + return false + end + + -- trim white space off branch + branch = string.gsub(branch, "%s+", "") + + local merge_config_cmd = string.format('git config --get-regexp "^branch\\.%s\\.merge"', branch) + + local merge_config = vim.fn.system(merge_config_cmd) + if vim.v.shell_error ~= 0 then + return false + end + + if #merge_config == 0 then + return false + end + + -- split merge_config to key, value with space delimeter + local merge_config_kv = vim.split(merge_config, "%s+") + -- use > 2 since there maybe some garbage white space at the end of the map. + if #merge_config_kv < 2 then + return false + end + + local upstream_branch_ref = merge_config_kv[2] + + -- remove the prefix /refs/heads/ from upstream_branch_ref resulting in + -- branch's name. + local upstream_branch_name = string.gsub(upstream_branch_ref, "^refs/heads/", "") + + if upstream_branch_name:lower() == pr.head_ref_name then + return true + end + + return false +end + +--- Determines if we are locally are in a branch matching the pr head ref +--- @param pr PullRequest +--- @return boolean +function M.in_pr_branch(pr) + return M.in_pr_branch_locally_tracked(pr) or M.in_pr_branch_config_tracked(pr) +end + function M.checkout_pr(pr_number) gh.run { args = { "pr", "checkout", pr_number },