diff --git a/github/githubclient/client.go b/github/githubclient/client.go index bc624e8..9bf329e 100644 --- a/github/githubclient/client.go +++ b/github/githubclient/client.go @@ -58,7 +58,7 @@ type client struct { api *githubv4.Client } -var pullRequestRegex = regexp.MustCompile(`pr/[a-zA-Z0-9_\-]+/([a-zA-Z0-9_\-/]+)/([a-f0-9]{8})$`) +var BranchNameRegex = regexp.MustCompile(`pr/[a-zA-Z0-9_\-]+/([a-zA-Z0-9_\-/]+)/([a-f0-9]{8})$`) func (c *client) GetInfo(ctx context.Context, gitcmd git.GitInterface) *github.GitHubInfo { if c.config.User.LogGitHubCalls { @@ -120,7 +120,7 @@ func (c *client) GetInfo(ctx context.Context, gitcmd git.GitInterface) *github.G ToBranch: node.BaseRefName, } - matches := pullRequestRegex.FindStringSubmatch(node.HeadRefName) + matches := BranchNameRegex.FindStringSubmatch(node.HeadRefName) if matches != nil && matches[1] == branchname { pullRequest.Commit = git.Commit{ CommitID: matches[2], diff --git a/github/githubclient/client_test.go b/github/githubclient/client_test.go index 4b1378e..a56154f 100644 --- a/github/githubclient/client_test.go +++ b/github/githubclient/client_test.go @@ -13,7 +13,7 @@ func TestPullRequestRegex(t *testing.T) { } for _, tc := range tests { - matches := pullRequestRegex.FindStringSubmatch(tc.input) + matches := BranchNameRegex.FindStringSubmatch(tc.input) if tc.branch != matches[1] { t.Fatalf("expected: '%v', actual: '%v'", tc.branch, matches[1]) } diff --git a/spr/spr.go b/spr/spr.go index 6cef58f..ff34878 100644 --- a/spr/spr.go +++ b/spr/spr.go @@ -14,6 +14,7 @@ import ( "github.com/ejoffe/spr/config" "github.com/ejoffe/spr/git" "github.com/ejoffe/spr/github" + "github.com/ejoffe/spr/github/githubclient" "github.com/ejoffe/spr/hook" ) @@ -381,13 +382,20 @@ func (sd *stackediff) fetchAndGetGitHubInfo(ctx context.Context) *github.GitHubI sd.mustgit("fetch", nil) rebaseCommand := fmt.Sprintf("rebase %s/%s --autostash", sd.config.Repo.GitHubRemote, sd.config.Repo.GitHubBranch) - //var output string err := sd.gitcmd.Git(rebaseCommand, nil) if err != nil { return nil } - info := sd.github.GetInfo(ctx, sd.gitcmd) + if githubclient.BranchNameRegex.FindString(info.LocalBranch) != "" { + fmt.Printf("error: don't run spr in a remote pr branch\n") + fmt.Printf(" this could lead to weird duplicate pull requests getting created\n") + fmt.Printf(" in general there is no need to checkout remote branches used for prs\n") + fmt.Printf(" instead use local branches and run spr update to sync your commit stack\n") + fmt.Printf(" with your pull requests on github\n") + fmt.Printf("branch name: %s\n", info.LocalBranch) + return nil + } return info }