diff --git a/cherry_picker/cherry_picker.py b/cherry_picker/cherry_picker.py index 245e299..b131509 100755 --- a/cherry_picker/cherry_picker.py +++ b/cherry_picker/cherry_picker.py @@ -38,6 +38,9 @@ CHECKING_OUT_DEFAULT_BRANCH CHECKED_OUT_DEFAULT_BRANCH + CHECKING_OUT_PREVIOUS_BRANCH + CHECKED_OUT_PREVIOUS_BRANCH + PUSHING_TO_REMOTE PUSHED_TO_REMOTE PUSHING_TO_REMOTE_FAILED @@ -138,6 +141,11 @@ def set_paused_state(self): save_cfg_vals_to_git_cfg(config_path=self.chosen_config_path) set_state(WORKFLOW_STATES.BACKPORT_PAUSED) + def remember_previous_branch(self): + """Save the current branch into Git config to be able to get back to it later.""" + current_branch = get_current_branch() + save_cfg_vals_to_git_cfg(previous_branch=current_branch) + @property def upstream(self): """Get the remote name to use for upstream branches @@ -184,24 +192,29 @@ def run_cmd(self, cmd): output = subprocess.check_output(cmd, stderr=subprocess.STDOUT) return output.decode("utf-8") - def checkout_branch(self, branch_name): - """git checkout -b """ - cmd = [ - "git", - "checkout", - "-b", - self.get_cherry_pick_branch(branch_name), - f"{self.upstream}/{branch_name}", - ] + def checkout_branch(self, branch_name, *, create_branch=False): + """git checkout [-b] """ + if create_branch: + checked_out_branch = self.get_cherry_pick_branch(branch_name) + cmd = [ + "git", + "checkout", + "-b", + checked_out_branch, + f"{self.upstream}/{branch_name}", + ] + else: + checked_out_branch = branch_name + cmd = ["git", "checkout", branch_name] try: self.run_cmd(cmd) except subprocess.CalledProcessError as err: click.echo( - f"Error checking out the branch {self.get_cherry_pick_branch(branch_name)}." + f"Error checking out the branch {branch_name}." ) click.echo(err.output) raise BranchCheckoutException( - f"Error checking out the branch {self.get_cherry_pick_branch(branch_name)}." + f"Error checking out the branch {branch_name}." ) def get_commit_message(self, commit_sha): @@ -225,11 +238,23 @@ def checkout_default_branch(self): """git checkout default branch""" set_state(WORKFLOW_STATES.CHECKING_OUT_DEFAULT_BRANCH) - cmd = "git", "checkout", self.config["default_branch"] - self.run_cmd(cmd) + self.checkout_branch(self.config["default_branch"]) set_state(WORKFLOW_STATES.CHECKED_OUT_DEFAULT_BRANCH) + def checkout_previous_branch(self): + """git checkout previous branch""" + set_state(WORKFLOW_STATES.CHECKING_OUT_PREVIOUS_BRANCH) + + previous_branch = load_val_from_git_cfg("previous_branch") + if previous_branch is None: + self.checkout_default_branch() + return + + self.checkout_branch(previous_branch) + + set_state(WORKFLOW_STATES.CHECKED_OUT_PREVIOUS_BRANCH) + def status(self): """ git status @@ -363,7 +388,12 @@ def cleanup_branch(self, branch): Switch to the default branch before that. """ set_state(WORKFLOW_STATES.REMOVING_BACKPORT_BRANCH) - self.checkout_default_branch() + try: + self.checkout_previous_branch() + except BranchCheckoutException: + click.echo(f"branch {branch} NOT deleted.") + set_state(WORKFLOW_STATES.REMOVING_BACKPORT_BRANCH_FAILED) + return try: self.delete_branch(branch) except subprocess.CalledProcessError: @@ -378,6 +408,7 @@ def backport(self): raise click.UsageError("At least one branch must be specified.") set_state(WORKFLOW_STATES.BACKPORT_STARTING) self.fetch_upstream() + self.remember_previous_branch() set_state(WORKFLOW_STATES.BACKPORT_LOOPING) for maint_branch in self.sorted_branches: @@ -385,7 +416,7 @@ def backport(self): click.echo(f"Now backporting '{self.commit_sha1}' into '{maint_branch}'") cherry_pick_branch = self.get_cherry_pick_branch(maint_branch) - self.checkout_branch(maint_branch) + self.checkout_branch(maint_branch, create_branch=True) commit_message = "" try: self.cherry_pick() @@ -419,6 +450,7 @@ def backport(self): self.set_paused_state() return # to preserve the correct state set_state(WORKFLOW_STATES.BACKPORT_LOOP_END) + reset_stored_previous_branch() reset_state() def abort_cherry_pick(self): @@ -440,6 +472,7 @@ def abort_cherry_pick(self): if get_current_branch().startswith("backport-"): self.cleanup_branch(get_current_branch()) + reset_stored_previous_branch() reset_stored_config_ref() reset_state() @@ -499,6 +532,7 @@ def continue_cherry_pick(self): ) set_state(WORKFLOW_STATES.CONTINUATION_FAILED) + reset_stored_previous_branch() reset_stored_config_ref() reset_state() @@ -828,6 +862,11 @@ def reset_stored_config_ref(): """Config file pointer is not stored in Git config.""" +def reset_stored_previous_branch(): + """Remove the previous branch information from Git config.""" + wipe_cfg_vals_from_git_cfg("previous_branch") + + def reset_state(): """Remove the progress state from Git config.""" wipe_cfg_vals_from_git_cfg("state") diff --git a/cherry_picker/test_cherry_picker.py b/cherry_picker/test_cherry_picker.py index a2919b0..7f8a793 100644 --- a/cherry_picker/test_cherry_picker.py +++ b/cherry_picker/test_cherry_picker.py @@ -84,6 +84,14 @@ def git_commit(): ) +@pytest.fixture +def git_worktree(): + git_worktree_cmd = "git", "worktree" + return lambda *extra_args: ( + subprocess.run(git_worktree_cmd + extra_args, check=True) + ) + + @pytest.fixture def git_cherry_pick(): git_cherry_pick_cmd = "git", "cherry-pick" @@ -100,12 +108,13 @@ def git_config(): @pytest.fixture def tmp_git_repo_dir(tmpdir, cd, git_init, git_commit, git_config): - cd(tmpdir) + repo_dir = tmpdir.mkdir("tmp-git-repo") + cd(repo_dir) git_init() git_config("--local", "user.name", "Monty Python") git_config("--local", "user.email", "bot@python.org") git_commit("Initial commit", "--allow-empty") - yield tmpdir + yield repo_dir @mock.patch("subprocess.check_output") @@ -545,6 +554,11 @@ def test_paused_flow(tmp_git_repo_dir, git_add, git_commit): WORKFLOW_STATES.CHECKING_OUT_DEFAULT_BRANCH, WORKFLOW_STATES.CHECKED_OUT_DEFAULT_BRANCH, ), + ( + "checkout_previous_branch", + WORKFLOW_STATES.CHECKING_OUT_PREVIOUS_BRANCH, + WORKFLOW_STATES.CHECKED_OUT_PREVIOUS_BRANCH, + ), ), ) def test_start_end_states(method_name, start_state, end_state, tmp_git_repo_dir): @@ -552,6 +566,7 @@ def test_start_end_states(method_name, start_state, end_state, tmp_git_repo_dir) with mock.patch("cherry_picker.cherry_picker.validate_sha", return_value=True): cherry_picker = CherryPicker("origin", "xxx", []) + cherry_picker.remember_previous_branch() assert get_state() == WORKFLOW_STATES.UNSET def _fetch(cmd): @@ -572,6 +587,22 @@ def test_cleanup_branch(tmp_git_repo_dir, git_checkout): git_checkout("-b", "some_branch") cherry_picker.cleanup_branch("some_branch") assert get_state() == WORKFLOW_STATES.REMOVED_BACKPORT_BRANCH + assert get_current_branch() == "main" + + +def test_cleanup_branch_checkout_previous_branch(tmp_git_repo_dir, git_checkout, git_worktree): + assert get_state() == WORKFLOW_STATES.UNSET + + with mock.patch("cherry_picker.cherry_picker.validate_sha", return_value=True): + cherry_picker = CherryPicker("origin", "xxx", []) + assert get_state() == WORKFLOW_STATES.UNSET + + git_checkout("-b", "previous_branch") + cherry_picker.remember_previous_branch() + git_checkout("-b", "some_branch") + cherry_picker.cleanup_branch("some_branch") + assert get_state() == WORKFLOW_STATES.REMOVED_BACKPORT_BRANCH + assert get_current_branch() == "previous_branch" def test_cleanup_branch_fail(tmp_git_repo_dir): @@ -585,6 +616,19 @@ def test_cleanup_branch_fail(tmp_git_repo_dir): assert get_state() == WORKFLOW_STATES.REMOVING_BACKPORT_BRANCH_FAILED +def test_cleanup_branch_checkout_fail(tmp_git_repo_dir, tmpdir, git_checkout, git_worktree): + assert get_state() == WORKFLOW_STATES.UNSET + + with mock.patch("cherry_picker.cherry_picker.validate_sha", return_value=True): + cherry_picker = CherryPicker("origin", "xxx", []) + assert get_state() == WORKFLOW_STATES.UNSET + + git_checkout("-b", "some_branch") + git_worktree("add", str(tmpdir.mkdir("test-worktree")), "main") + cherry_picker.cleanup_branch("some_branch") + assert get_state() == WORKFLOW_STATES.REMOVING_BACKPORT_BRANCH_FAILED + + def test_cherry_pick(tmp_git_repo_dir, git_add, git_branch, git_commit, git_checkout): cherry_pick_target_branches = ("3.8",) pr_remote = "origin"