diff --git a/.flake8 b/.flake8 deleted file mode 100644 index bcc463adc2..0000000000 --- a/.flake8 +++ /dev/null @@ -1,44 +0,0 @@ -[flake8] -# Ignore E402 ("module level import not at top of file"), -# because even with the lazy import plugin it still triggers -# for lazy_import statements before other imports. -exclude = .git,__pycache__,build,dist,target,.eggs,lib -ignore = - D - I - E12 - E261 - E265 - E266 - E301 - E302 - E303 - E305 - E306 - E401 - E402 - E501 - E502 - E702 - E704 - E722 - E731 - E741 - F401 - F402 - F403 - F405 - F811 - F812 - F821 - F841 - W391 - W503 - W504 - W605 -filename = *.py - -[flake8:local-plugins] -extension = - MC1 = flake8_lazy_import:LazyImport -paths = ./tools/ diff --git a/apport/source_brz.py b/apport/source_brz.py index 3bb05bfcb5..7a045024fb 100644 --- a/apport/source_brz.py +++ b/apport/source_brz.py @@ -8,12 +8,13 @@ from apport.hookutils import * # noqa: F403 -brz_log = os.path.expanduser('~/.brz.log') -dot_brz = os.path.expanduser('~/.config/breezy') +brz_log = os.path.expanduser("~/.brz.log") +dot_brz = os.path.expanduser("~/.config/breezy") + def _add_log_tail(report): # may have already been added in-process - if 'BrzLogTail' in report: + if "BrzLogTail" in report: return with open(brz_log) as f: @@ -23,33 +24,33 @@ def _add_log_tail(report): brz_log_tail = [] blanks = 0 for line in brz_log_lines: - if line == '\n': + if line == "\n": blanks += 1 brz_log_tail.append(line) if blanks >= 2: break brz_log_tail.reverse() - report['BrzLogTail'] = ''.join(brz_log_tail) + report["BrzLogTail"] = "".join(brz_log_tail) def add_info(report): _add_log_tail(report) - if 'BrzPlugins' not in report: + if "BrzPlugins" not in report: # may already be present in-process - report['BrzPlugins'] = command_output(['brz', 'plugins', '-v']) + report["BrzPlugins"] = command_output(["brz", "plugins", "-v"]) # by default assume brz crashes are upstream bugs; this relies on # having a brz entry under /etc/apport/crashdb.conf.d/ - report['CrashDB'] = 'brz' + report["CrashDB"] = "brz" # these may contain some sensitive info (smtp_passwords) # TODO: strip that out and attach the rest - #attach_file_if_exists(report, - # os.path.join(dot_brz, 'breezy.conf', 'BrzConfig') - #attach_file_if_exists(report, - # os.path.join(dot_brz, 'locations.conf', 'BrzLocations') + # attach_file_if_exists(report, + # os.path.join(dot_brz, 'breezy.conf', 'BrzConfig') + # attach_file_if_exists(report, + # os.path.join(dot_brz, 'locations.conf', 'BrzLocations') # vim: expandtab shiftwidth=4 diff --git a/breezy/__init__.py b/breezy/__init__.py index f6e51b0a2f..2fdcac71ab 100644 --- a/breezy/__init__.py +++ b/breezy/__init__.py @@ -39,8 +39,7 @@ import sys __copyright__ = ( - "Copyright 2005-2012 Canonical Ltd.\n" - "Copyright 2017-2023 Breezy developers" + "Copyright 2005-2012 Canonical Ltd.\n" "Copyright 2017-2023 Breezy developers" ) # same format as sys.version_info: "A tuple containing the five components of @@ -50,7 +49,7 @@ # Python version 2.0 is (2, 0, 0, 'final', 0)." Additionally we use a # releaselevel of 'dev' for unreleased under-development code. 
-version_info = (3, 4, 0, 'dev', 0) +version_info = (3, 4, 0, "dev", 0) def _format_version_tuple(version_info): @@ -82,38 +81,38 @@ def _format_version_tuple(version_info): 1.4.0.wibble.0 """ if len(version_info) == 2: - main_version = '%d.%d' % version_info[:2] + main_version = "%d.%d" % version_info[:2] else: - main_version = '%d.%d.%d' % version_info[:3] + main_version = "%d.%d.%d" % version_info[:3] if len(version_info) <= 3: return main_version release_type = version_info[3] sub = version_info[4] - if release_type == 'final' and sub == 0: - sub_string = '' - elif release_type == 'final': - sub_string = '.' + str(sub) - elif release_type == 'dev' and sub == 0: - sub_string = '.dev' - elif release_type == 'dev': - sub_string = '.dev' + str(sub) - elif release_type in ('alpha', 'beta'): + if release_type == "final" and sub == 0: + sub_string = "" + elif release_type == "final": + sub_string = "." + str(sub) + elif release_type == "dev" and sub == 0: + sub_string = ".dev" + elif release_type == "dev": + sub_string = ".dev" + str(sub) + elif release_type in ("alpha", "beta"): if version_info[2] == 0: - main_version = '%d.%d' % version_info[:2] - sub_string = '.' + release_type[0] + str(sub) - elif release_type == 'candidate': - sub_string = '.rc' + str(sub) + main_version = "%d.%d" % version_info[:2] + sub_string = "." + release_type[0] + str(sub) + elif release_type == "candidate": + sub_string = ".rc" + str(sub) else: - return '.'.join(map(str, version_info)) + return ".".join(map(str, version_info)) return main_version + sub_string __version__ = _format_version_tuple(version_info) version_string = __version__ -_core_version_string = '.'.join(map(str, version_info[:3])) +_core_version_string = ".".join(map(str, version_info[:3])) def _patch_filesystem_default_encoding(new_enc): @@ -129,15 +128,14 @@ def _patch_filesystem_default_encoding(new_enc): """ try: import ctypes - pythonapi = getattr(ctypes, 'pythonapi', None) + + pythonapi = getattr(ctypes, "pythonapi", None) if pythonapi is not None: - old_ptr = ctypes.c_void_p.in_dll(pythonapi, - "Py_FileSystemDefaultEncoding") - has_enc = ctypes.c_int.in_dll(pythonapi, - "Py_HasFileSystemDefaultEncoding") + old_ptr = ctypes.c_void_p.in_dll(pythonapi, "Py_FileSystemDefaultEncoding") + has_enc = ctypes.c_int.in_dll(pythonapi, "Py_HasFileSystemDefaultEncoding") as_utf8 = ctypes.PYFUNCTYPE( - ctypes.POINTER(ctypes.c_char), ctypes.py_object)( - ("PyUnicode_AsUTF8", pythonapi)) + ctypes.POINTER(ctypes.c_char), ctypes.py_object + )(("PyUnicode_AsUTF8", pythonapi)) except (ImportError, ValueError): return # No ctypes or not CPython implementation, do nothing new_enc = sys.intern(new_enc) @@ -154,7 +152,7 @@ def _patch_filesystem_default_encoding(new_enc): # just ensure a usable locale is set via the $LANG variable on posix systems. _fs_enc = sys.getfilesystemencoding() if getattr(sys, "_brz_default_fs_enc", None) is not None: - if (_fs_enc is None or codecs.lookup(_fs_enc).name == "ascii"): + if _fs_enc is None or codecs.lookup(_fs_enc).name == "ascii": _fs_enc = _patch_filesystem_default_encoding(sys._brz_default_fs_enc) # type: ignore if _fs_enc is None: _fs_enc = "ascii" @@ -202,8 +200,10 @@ def initialize(setup_ui=True, stdin=None, stdout=None, stderr=None): BzrLibraryState directly. 
""" from breezy import library_state, trace + if setup_ui: import breezy.ui + stdin = stdin or sys.stdin stdout = stdout or sys.stdout stderr = stderr or sys.stderr @@ -225,4 +225,5 @@ def get_global_state(): def test_suite(): import tests + return tests.test_suite() diff --git a/breezy/__main__.py b/breezy/__main__.py index c4fde61709..1aa2642235 100644 --- a/breezy/__main__.py +++ b/breezy/__main__.py @@ -21,23 +21,26 @@ import sys profiling = False -if '--profile-imports' in sys.argv: +if "--profile-imports" in sys.argv: import profile_imports + profile_imports.install() profiling = True if os.name == "posix": import locale + try: - locale.setlocale(locale.LC_ALL, '') + locale.setlocale(locale.LC_ALL, "") except locale.Error as e: sys.stderr.write( - 'brz: warning: %s\n' - ' bzr could not set the application locale.\n' - ' Although this should be no problem for bzr itself, it might\n' - ' cause problems with some plugins. To investigate the issue,\n' - ' look at the output of the locale(1p) tool.\n' % e) + "brz: warning: %s\n" + " bzr could not set the application locale.\n" + " Although this should be no problem for bzr itself, it might\n" + " cause problems with some plugins. To investigate the issue,\n" + " look at the output of the locale(1p) tool.\n" % e + ) # Use better default than ascii with posix filesystems that deal in bytes # natively even when the C locale or no locale at all is given. Note that # we need an immortal string for the hack, hence the lack of a hyphen. @@ -46,6 +49,7 @@ def main(): import breezy.breakin + breezy.breakin.hook_debugger_to_signal() import breezy.commands @@ -69,5 +73,5 @@ def main(): os._exit(exit_val) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/breezy/_annotator_py.py b/breezy/_annotator_py.py index a5d9de6bec..6dad5ee45f 100644 --- a/breezy/_annotator_py.py +++ b/breezy/_annotator_py.py @@ -81,8 +81,9 @@ def _get_needed_keys(self, key): parent_keys = () next_parent_map[key] = () self._update_needed_children(key, parent_keys) - needed_keys.update([key for key in parent_keys - if key not in parent_map]) + needed_keys.update( + [key for key in parent_keys if key not in parent_map] + ) parent_map.update(next_parent_map) # _heads_provider does some graph caching, so it is only valid # while self._parent_map hasn't changed @@ -100,15 +101,15 @@ def _get_needed_texts(self, key, pb=None): """ keys, ann_keys = self._get_needed_keys(key) if pb is not None: - pb.update('getting stream', 0, len(keys)) - stream = self._vf.get_record_stream(keys, 'topological', True) + pb.update("getting stream", 0, len(keys)) + stream = self._vf.get_record_stream(keys, "topological", True) for _idx, record in enumerate(stream): if pb is not None: - pb.update('extracting', 0, len(keys)) - if record.storage_kind == 'absent': + pb.update("extracting", 0, len(keys)) + if record.storage_kind == "absent": raise errors.RevisionNotPresent(record.key, self._vf) this_key = record.key - lines = record.get_bytes_as('lines') + lines = record.get_bytes_as("lines") num_lines = len(lines) self._text_cache[this_key] = lines yield this_key, lines, num_lines @@ -132,28 +133,32 @@ def _get_parent_annotations_and_matches(self, key, text, parent_key): parent_annotations = self._annotations_cache[parent_key] # PatienceSequenceMatcher should probably be part of Policy from patiencediff import PatienceSequenceMatcher - matcher = PatienceSequenceMatcher( - None, parent_lines, text) + + matcher = PatienceSequenceMatcher(None, parent_lines, text) matching_blocks = 
matcher.get_matching_blocks() return parent_annotations, matching_blocks def _update_from_first_parent(self, key, annotations, lines, parent_key): """Reannotate this text relative to its first parent.""" - (parent_annotations, - matching_blocks) = self._get_parent_annotations_and_matches( - key, lines, parent_key) + ( + parent_annotations, + matching_blocks, + ) = self._get_parent_annotations_and_matches(key, lines, parent_key) for parent_idx, lines_idx, match_len in matching_blocks: # For all matching regions we copy across the parent annotations - annotations[lines_idx:lines_idx + match_len] = \ - parent_annotations[parent_idx:parent_idx + match_len] + annotations[lines_idx : lines_idx + match_len] = parent_annotations[ + parent_idx : parent_idx + match_len + ] - def _update_from_other_parents(self, key, annotations, lines, - this_annotation, parent_key): + def _update_from_other_parents( + self, key, annotations, lines, this_annotation, parent_key + ): """Reannotate this text relative to a second (or more) parent.""" - (parent_annotations, - matching_blocks) = self._get_parent_annotations_and_matches( - key, lines, parent_key) + ( + parent_annotations, + matching_blocks, + ) = self._get_parent_annotations_and_matches(key, lines, parent_key) last_ann = None last_parent = None @@ -167,8 +172,8 @@ def _update_from_other_parents(self, key, annotations, lines, for parent_idx, lines_idx, match_len in matching_blocks: # For lines which match this parent, we will now resolve whether # this parent wins over the current annotation - ann_sub = annotations[lines_idx:lines_idx + match_len] - par_sub = parent_annotations[parent_idx:parent_idx + match_len] + ann_sub = annotations[lines_idx : lines_idx + match_len] + par_sub = parent_annotations[parent_idx : parent_idx + match_len] if ann_sub == par_sub: continue for idx in range(match_len): @@ -214,11 +219,11 @@ def _annotate_one(self, key, text, num_lines): annotations = [this_annotation] * num_lines parent_keys = self._parent_map[key] if parent_keys: - self._update_from_first_parent(key, annotations, text, - parent_keys[0]) + self._update_from_first_parent(key, annotations, text, parent_keys[0]) for parent in parent_keys[1:]: - self._update_from_other_parents(key, annotations, text, - this_annotation, parent) + self._update_from_other_parents( + key, annotations, text, this_annotation, parent + ) self._record_annotation(key, parent_keys, annotations) def add_special_text(self, key, parent_keys, text): @@ -246,8 +251,7 @@ def annotate(self, key): lines the text of "key" as a list of lines """ with ui.ui_factory.nested_progress_bar() as pb: - for text_key, text, num_lines in self._get_needed_texts( - key, pb=pb): + for text_key, text, num_lines in self._get_needed_texts(key, pb=pb): self._annotate_one(text_key, text, num_lines) try: annotations = self._annotations_cache[key] @@ -281,6 +285,7 @@ def annotate_flat(self, key): A list of tuples with a single annotation key for each line. 
""" from .annotate import _break_annotation_tie + custom_tiebreaker = _break_annotation_tie annotations, lines = self.annotate(key) out = [] @@ -295,7 +300,8 @@ def annotate_flat(self, key): # get the item out of the set head = next(iter(the_heads)) else: - head = self._resolve_annotation_tie(the_heads, line, - custom_tiebreaker) + head = self._resolve_annotation_tie( + the_heads, line, custom_tiebreaker + ) append((head, line)) return out diff --git a/breezy/_known_graph_py.py b/breezy/_known_graph_py.py index f2ce958652..fbe88e913e 100644 --- a/breezy/_known_graph_py.py +++ b/breezy/_known_graph_py.py @@ -24,7 +24,7 @@ class _KnownGraphNode: """Represents a single object in the known graph.""" - __slots__ = ('key', 'parent_keys', 'child_keys', 'gdfo') + __slots__ = ("key", "parent_keys", "child_keys", "gdfo") def __init__(self, key, parent_keys): self.key = key @@ -34,15 +34,19 @@ def __init__(self, key, parent_keys): self.gdfo = None def __repr__(self): - return '{}({} gdfo:{} par:{} child:{})'.format( - self.__class__.__name__, self.key, self.gdfo, - self.parent_keys, self.child_keys) + return "{}({} gdfo:{} par:{} child:{})".format( + self.__class__.__name__, + self.key, + self.gdfo, + self.parent_keys, + self.child_keys, + ) class _MergeSortNode: """Information about a specific node in the merge graph.""" - __slots__ = ('key', 'merge_depth', 'revno', 'end_of_merge') + __slots__ = ("key", "merge_depth", "revno", "end_of_merge") def __init__(self, key, merge_depth, revno, end_of_merge): self.key = key @@ -92,12 +96,10 @@ def _initialize_nodes(self, parent_map): parent_node.child_keys.append(key) def _find_tails(self): - return [node for node in self._nodes.values() - if not node.parent_keys] + return [node for node in self._nodes.values() if not node.parent_keys] def _find_tips(self): - return [node for node in self._nodes.values() - if not node.child_keys] + return [node for node in self._nodes.values() if not node.child_keys] def _find_gdfo(self): nodes = self._nodes @@ -157,8 +159,9 @@ def add_node(self, key, parent_keys): return # Identical content else: raise ValueError( - f'Parent key mismatch, existing node {key}' - f' has parents of {existing_parent_keys} not {parent_keys}') + f"Parent key mismatch, existing node {key}" + f" has parents of {existing_parent_keys} not {parent_keys}" + ) else: node = _KnownGraphNode(key, parent_keys) nodes[key] = node @@ -301,7 +304,7 @@ def gc_sort(self): prefix_tips = {} for node in tips: if node.key.__class__ is str or len(node.key) == 1: - prefix = '' + prefix = "" else: prefix = node.key[0] prefix_tips.setdefault(prefix, []).append(node) @@ -310,8 +313,7 @@ def gc_sort(self): result = [] for prefix in sorted(prefix_tips): - pending = sorted(prefix_tips[prefix], key=lambda n: n.key, - reverse=True) + pending = sorted(prefix_tips[prefix], key=lambda n: n.key, reverse=True) while pending: node = pending.pop() if node.parent_keys is None: @@ -333,17 +335,21 @@ def gc_sort(self): def merge_sort(self, tip_key): """Compute the merge sorted graph output.""" from breezy import tsort - as_parent_map = {node.key: node.parent_keys - for node in self._nodes.values() - if node.parent_keys is not None} + + as_parent_map = { + node.key: node.parent_keys + for node in self._nodes.values() + if node.parent_keys is not None + } # We intentionally always generate revnos and never force the # mainline_revisions # Strip the sequence_number that merge_sort generates - return [_MergeSortNode(key, merge_depth, revno, end_of_merge) - for _, key, merge_depth, revno, 
end_of_merge - in tsort.merge_sort(as_parent_map, tip_key, - mainline_revisions=None, - generate_revno=True)] + return [ + _MergeSortNode(key, merge_depth, revno, end_of_merge) + for _, key, merge_depth, revno, end_of_merge in tsort.merge_sort( + as_parent_map, tip_key, mainline_revisions=None, generate_revno=True + ) + ] def get_parent_keys(self, key): """Get the parents for a key. diff --git a/breezy/add.py b/breezy/add.py index 179622ad69..32b26604be 100644 --- a/breezy/add.py +++ b/breezy/add.py @@ -50,7 +50,7 @@ def __call__(self, inv, parent_ie, path, kind, _quote=osutils.quotefn): :param kind: The kind of the object being added. """ if self.should_print: - self._to_file.write(f'adding {_quote(path)}\n') + self._to_file.write(f"adding {_quote(path)}\n") return None def skip_file(self, tree, path, kind, stat_value=None): @@ -74,9 +74,9 @@ class AddWithSkipLargeAction(AddAction): _max_size = None def skip_file(self, tree, path, kind, stat_value=None): - if kind != 'file': + if kind != "file": return False - opt_name = 'add.maximum_file_size' + opt_name = "add.maximum_file_size" if self._max_size is None: config = tree.get_config_stack() self._max_size = config.get(opt_name) @@ -85,9 +85,11 @@ def skip_file(self, tree, path, kind, stat_value=None): else: file_size = stat_value.st_size if self._max_size > 0 and file_size > self._max_size: - ui.ui_factory.show_warning(gettext( - "skipping {0} (larger than {1} of {2} bytes)").format( - path, opt_name, self._max_size)) + ui.ui_factory.show_warning( + gettext("skipping {0} (larger than {1} of {2} bytes)").format( + path, opt_name, self._max_size + ) + ) return True return False @@ -96,8 +98,7 @@ class AddFromBaseAction(AddAction): """This class will try to extract file ids from another tree.""" def __init__(self, base_tree, base_path, to_file=None, should_print=None): - super().__init__(to_file=to_file, - should_print=should_print) + super().__init__(to_file=to_file, should_print=should_print) self.base_tree = base_tree self.base_path = base_path @@ -107,12 +108,11 @@ def __call__(self, inv, parent_ie, path, kind): file_id, base_path = self._get_base_file_id(path, parent_ie) if file_id is not None: if self.should_print: - self._to_file.write(f'adding {path} w/ file id from {base_path}\n') + self._to_file.write(f"adding {path} w/ file id from {base_path}\n") else: # we aren't doing anything special, so let the default # reporter happen - file_id = super().__call__( - inv, parent_ie, path, kind) + file_id = super().__call__(inv, parent_ie, path, kind) return file_id def _get_base_file_id(self, path, parent_ie): diff --git a/breezy/annotate.py b/breezy/annotate.py index ddcd6461a5..883eed524e 100644 --- a/breezy/annotate.py +++ b/breezy/annotate.py @@ -33,8 +33,9 @@ from .revision import CURRENT_REVISION, Revision -def annotate_file_tree(tree, path, to_file, verbose=False, full=False, - show_ids=False, branch=None): +def annotate_file_tree( + tree, path, to_file, verbose=False, full=False, show_ids=False, branch=None +): """Annotate path in a tree. The tree should already be read_locked() when annotate_file_tree is called. @@ -64,21 +65,22 @@ def annotate_file_tree(tree, path, to_file, verbose=False, full=False, # Should get some more pending commit attributes, like pending tags, # bugfixes etc. 
try: - committer = branch.get_config_stack().get('email') + committer = branch.get_config_stack().get("email") except errors.NoWhoami: - committer = 'local user' + committer = "local user" current_rev = Revision( - CURRENT_REVISION, - parent_ids=tree.get_parent_ids(), - committer=committer, message="?", - properties={}, - inventory_sha1=None, - timestamp=round(time.time(), 3), - timezone=osutils.local_time_offset()) + CURRENT_REVISION, + parent_ids=tree.get_parent_ids(), + committer=committer, + message="?", + properties={}, + inventory_sha1=None, + timestamp=round(time.time(), 3), + timezone=osutils.local_time_offset(), + ) else: current_rev = None - annotation = list(_expand_annotations( - annotations, branch, current_rev)) + annotation = list(_expand_annotations(annotations, branch, current_rev)) _print_annotations(annotation, verbose, to_file, full, encoding) @@ -100,21 +102,26 @@ def _print_annotations(annotation, verbose, to_file, full, encoding): max_revno_len = max(max_revno_len, 3) # Output the annotations - prevanno = '' - for (revno_str, author, date_str, _line_rev_id, text) in annotation: + prevanno = "" + for revno_str, author, date_str, _line_rev_id, text in annotation: if verbose: - anno = '%-*s %-*s %8s ' % (max_revno_len, revno_str, - max_origin_len, author, date_str) + anno = "%-*s %-*s %8s " % ( + max_revno_len, + revno_str, + max_origin_len, + author, + date_str, + ) else: if len(revno_str) > max_revno_len: - revno_str = revno_str[:max_revno_len - 1] + '>' + revno_str = revno_str[: max_revno_len - 1] + ">" anno = "%-*s %-7s " % (max_revno_len, revno_str, author[:7]) if anno.lstrip() == "" and full: anno = prevanno # GZ 2017-05-21: Writing both unicode annotation and bytes from file # which the given to_file must cope with. to_file.write(anno) - to_file.write(f'| {text.decode(encoding)}\n') + to_file.write(f"| {text.decode(encoding)}\n") prevanno = anno @@ -127,9 +134,10 @@ def _show_id_annotations(annotations, to_file, full, encoding): if full or last_rev_id != origin: this = origin else: - this = b'' - to_file.write('%*s | %s' % ( - max_origin_len, this.decode('utf-8'), text.decode(encoding))) + this = b"" + to_file.write( + "%*s | %s" % (max_origin_len, this.decode("utf-8"), text.decode(encoding)) + ) last_rev_id = origin return @@ -145,6 +153,7 @@ def _expand_annotations(annotations, branch, current_rev=None): :param branch: A locked branch to query for revision details. """ from . import tsort + repository = branch.repository revision_ids = {o for o, t in annotations} if current_rev is not None: @@ -158,19 +167,19 @@ def _expand_annotations(annotations, branch, current_rev=None): # VF.get_known_graph_ancestry(). graph = repository.get_graph() revision_graph = { - key: value for key, value in - graph.iter_ancestry(current_rev.parent_ids) if value is not None} + key: value + for key, value in graph.iter_ancestry(current_rev.parent_ids) + if value is not None + } revision_graph = _strip_NULL_ghosts(revision_graph) revision_graph[last_revision] = current_rev.parent_ids merge_sorted_revisions = tsort.merge_sort( - revision_graph, - last_revision, - None, - generate_revno=True) + revision_graph, last_revision, None, generate_revno=True + ) revision_id_to_revno = { rev_id: revno - for seq_num, rev_id, depth, revno, end_of_merge in - merge_sorted_revisions} + for seq_num, rev_id, depth, revno, end_of_merge in merge_sorted_revisions + } else: # TODO(jelmer): Only look up the revision ids that we need (i.e. those # in revision_ids). 
Possibly add a HPSS call that can look those up @@ -179,28 +188,26 @@ def _expand_annotations(annotations, branch, current_rev=None): last_origin = None revisions = {} if CURRENT_REVISION in revision_ids: - revision_id_to_revno[CURRENT_REVISION] = ( - "%d?" % (branch.revno() + 1),) + revision_id_to_revno[CURRENT_REVISION] = ("%d?" % (branch.revno() + 1),) revisions[CURRENT_REVISION] = current_rev revisions.update( - entry for entry in - repository.iter_revisions(revision_ids) - if entry[1] is not None) + entry + for entry in repository.iter_revisions(revision_ids) + if entry[1] is not None + ) for origin, text in annotations: - text = text.rstrip(b'\r\n') + text = text.rstrip(b"\r\n") if origin == last_origin: - (revno_str, author, date_str) = ('', '', '') + (revno_str, author, date_str) = ("", "", "") else: last_origin = origin if origin not in revisions: - (revno_str, author, date_str) = ('?', '?', '?') + (revno_str, author, date_str) = ("?", "?", "?") else: - revno_str = '.'.join( - str(i) for i in revision_id_to_revno[origin]) + revno_str = ".".join(str(i) for i in revision_id_to_revno[origin]) rev = revisions[origin] tz = rev.timezone or 0 - date_str = time.strftime('%Y%m%d', - time.gmtime(rev.timestamp + tz)) + date_str = time.strftime("%Y%m%d", time.gmtime(rev.timestamp + tz)) # a lazy way to get something like the email address # TODO: Get real email address author = rev.get_apparent_authors()[0] @@ -210,9 +217,13 @@ def _expand_annotations(annotations, branch, current_rev=None): yield (revno_str, author, date_str, origin, text) -def reannotate(parents_lines, new_lines, new_revision_id, - _left_matching_blocks=None, - heads_provider=None): +def reannotate( + parents_lines, + new_lines, + new_revision_id, + _left_matching_blocks=None, + heads_provider=None, +): """Create a new annotated version from new lines and parent annotations. 
:param parents_lines: List of annotated lines for all parents @@ -231,19 +242,25 @@ def reannotate(parents_lines, new_lines, new_revision_id, if len(parents_lines) == 0: lines = [(new_revision_id, line) for line in new_lines] elif len(parents_lines) == 1: - lines = _reannotate(parents_lines[0], new_lines, new_revision_id, - _left_matching_blocks) + lines = _reannotate( + parents_lines[0], new_lines, new_revision_id, _left_matching_blocks + ) elif len(parents_lines) == 2: - left = _reannotate(parents_lines[0], new_lines, new_revision_id, - _left_matching_blocks) - lines = _reannotate_annotated(parents_lines[1], new_lines, - new_revision_id, left, - heads_provider) + left = _reannotate( + parents_lines[0], new_lines, new_revision_id, _left_matching_blocks + ) + lines = _reannotate_annotated( + parents_lines[1], new_lines, new_revision_id, left, heads_provider + ) else: - reannotations = [_reannotate(parents_lines[0], new_lines, - new_revision_id, _left_matching_blocks)] - reannotations.extend(_reannotate(p, new_lines, new_revision_id) - for p in parents_lines[1:]) + reannotations = [ + _reannotate( + parents_lines[0], new_lines, new_revision_id, _left_matching_blocks + ) + ] + reannotations.extend( + _reannotate(p, new_lines, new_revision_id) for p in parents_lines[1:] + ) lines = [] for annos in zip(*reannotations): origins = {a for a, l in annos} @@ -261,26 +278,28 @@ def reannotate(parents_lines, new_lines, new_revision_id, return lines -def _reannotate(parent_lines, new_lines, new_revision_id, - matching_blocks=None): +def _reannotate(parent_lines, new_lines, new_revision_id, matching_blocks=None): import patiencediff + new_cur = 0 if matching_blocks is None: plain_parent_lines = [l for r, l in parent_lines] matcher = patiencediff.PatienceSequenceMatcher( - None, plain_parent_lines, new_lines) + None, plain_parent_lines, new_lines + ) matching_blocks = matcher.get_matching_blocks() lines = [] for i, j, n in matching_blocks: for line in new_lines[new_cur:j]: lines.append((new_revision_id, line)) - lines.extend(parent_lines[i:i + n]) + lines.extend(parent_lines[i : i + n]) new_cur = j + n return lines def _get_matching_blocks(old, new): import patiencediff + matcher = patiencediff.PatienceSequenceMatcher(None, old, new) return matcher.get_matching_blocks() @@ -309,10 +328,18 @@ def _old_break_annotation_tie(annotated_lines): return sorted(annotated_lines)[0] -def _find_matching_unannotated_lines(output_lines, plain_child_lines, - child_lines, start_child, end_child, - right_lines, start_right, end_right, - heads_provider, revision_id): +def _find_matching_unannotated_lines( + output_lines, + plain_child_lines, + child_lines, + start_child, + end_child, + right_lines, + start_right, + end_right, + heads_provider, + revision_id, +): """Find lines in plain_right_lines that match the existing lines. 
:param output_lines: Append final annotated lines to this list @@ -344,8 +371,9 @@ def _find_matching_unannotated_lines(output_lines, plain_child_lines, for right_idx, child_idx, match_len in match_blocks: # All the lines that don't match are just passed along if child_idx > last_child_idx: - output_extend(child_lines[start_child + last_child_idx: - start_child + child_idx]) + output_extend( + child_lines[start_child + last_child_idx : start_child + child_idx] + ) for offset in range(match_len): left = child_lines[start_child + child_idx + offset] right = right_lines[start_right + right_idx + offset] @@ -370,15 +398,15 @@ def _find_matching_unannotated_lines(output_lines, plain_child_lines, # performance degradation as criss-cross merges will # flip-flop the attribution. if _break_annotation_tie is None: - output_append( - _old_break_annotation_tie([left, right])) + output_append(_old_break_annotation_tie([left, right])) else: output_append(_break_annotation_tie([left, right])) last_child_idx = child_idx + match_len -def _reannotate_annotated(right_parent_lines, new_lines, new_revision_id, - annotated_lines, heads_provider): +def _reannotate_annotated( + right_parent_lines, new_lines, new_revision_id, annotated_lines, heads_provider +): """Update the annotations for a node based on another parent. :param right_parent_lines: A list of annotated lines for the right-hand @@ -401,8 +429,7 @@ def _reannotate_annotated(right_parent_lines, new_lines, new_revision_id, # The line just after the last match from the right side last_right_idx = 0 last_left_idx = 0 - matching_left_and_right = _get_matching_blocks(right_parent_lines, - annotated_lines) + matching_left_and_right = _get_matching_blocks(right_parent_lines, annotated_lines) for right_idx, left_idx, match_len in matching_left_and_right: # annotated lines from last_left_idx to left_idx did not match the # lines from last_right_idx to right_idx, the raw lines should be @@ -412,17 +439,22 @@ def _reannotate_annotated(right_parent_lines, new_lines, new_revision_id, lines_extend(annotated_lines[last_left_idx:left_idx]) else: # We need to see if any of the unannotated lines match - _find_matching_unannotated_lines(lines, - new_lines, annotated_lines, - last_left_idx, left_idx, - right_parent_lines, - last_right_idx, right_idx, - heads_provider, - new_revision_id) + _find_matching_unannotated_lines( + lines, + new_lines, + annotated_lines, + last_left_idx, + left_idx, + right_parent_lines, + last_right_idx, + right_idx, + heads_provider, + new_revision_id, + ) last_right_idx = right_idx + match_len last_left_idx = left_idx + match_len # If left and right agree on a range, just push that into the output - lines_extend(annotated_lines[left_idx:left_idx + match_len]) + lines_extend(annotated_lines[left_idx : left_idx + match_len]) return lines diff --git a/breezy/archive/__init__.py b/breezy/archive/__init__.py index 3274002f06..3e8b00f0e3 100644 --- a/breezy/archive/__init__.py +++ b/breezy/archive/__init__.py @@ -22,7 +22,6 @@ class ArchiveFormatInfo: - def __init__(self, extensions): self.extensions = extensions @@ -40,14 +39,15 @@ def extensions(self): def register(self, key, factory, extensions, help=None): """Register an archive format.""" - registry.Registry.register(self, key, factory, help, - ArchiveFormatInfo(extensions)) + registry.Registry.register( + self, key, factory, help, ArchiveFormatInfo(extensions) + ) self._register_extensions(key, extensions) - def register_lazy(self, key, module_name, member_name, extensions, - help=None): - 
registry.Registry.register_lazy(self, key, module_name, member_name, - help, ArchiveFormatInfo(extensions)) + def register_lazy(self, key, module_name, member_name, extensions, help=None): + registry.Registry.register_lazy( + self, key, module_name, member_name, help, ArchiveFormatInfo(extensions) + ) self._register_extensions(key, extensions) def _register_extensions(self, name, extensions): @@ -67,29 +67,45 @@ def get_format_from_filename(self, filename): return None -def create_archive(format, tree, name, root=None, subdir=None, - force_mtime=None, recurse_nested=False) -> Iterator[bytes]: +def create_archive( + format, tree, name, root=None, subdir=None, force_mtime=None, recurse_nested=False +) -> Iterator[bytes]: try: archive_fn = format_registry.get(format) except KeyError as exc: raise errors.NoSuchExportFormat(format) from exc - return cast(Iterator[bytes], - archive_fn( - tree, name, root=root, subdir=subdir, - force_mtime=force_mtime, - recurse_nested=recurse_nested)) + return cast( + Iterator[bytes], + archive_fn( + tree, + name, + root=root, + subdir=subdir, + force_mtime=force_mtime, + recurse_nested=recurse_nested, + ), + ) format_registry = ArchiveFormatRegistry() -format_registry.register_lazy('tar', 'breezy.archive.tar', - 'plain_tar_generator', ['.tar'], ) -format_registry.register_lazy('tgz', 'breezy.archive.tar', - 'tgz_generator', ['.tar.gz', '.tgz']) -format_registry.register_lazy('tbz2', 'breezy.archive.tar', - 'tbz_generator', ['.tar.bz2', '.tbz2']) -format_registry.register_lazy('tlzma', 'breezy.archive.tar', - 'tar_lzma_generator', ['.tar.lzma']) -format_registry.register_lazy('txz', 'breezy.archive.tar', - 'tar_xz_generator', ['.tar.xz']) -format_registry.register_lazy('zip', 'breezy.archive.zip', - 'zip_archive_generator', ['.zip']) +format_registry.register_lazy( + "tar", + "breezy.archive.tar", + "plain_tar_generator", + [".tar"], +) +format_registry.register_lazy( + "tgz", "breezy.archive.tar", "tgz_generator", [".tar.gz", ".tgz"] +) +format_registry.register_lazy( + "tbz2", "breezy.archive.tar", "tbz_generator", [".tar.bz2", ".tbz2"] +) +format_registry.register_lazy( + "tlzma", "breezy.archive.tar", "tar_lzma_generator", [".tar.lzma"] +) +format_registry.register_lazy( + "txz", "breezy.archive.tar", "tar_xz_generator", [".tar.xz"] +) +format_registry.register_lazy( + "zip", "breezy.archive.zip", "zip_archive_generator", [".zip"] +) diff --git a/breezy/archive/tar.py b/breezy/archive/tar.py index fc1054a37c..f8673663c2 100644 --- a/breezy/archive/tar.py +++ b/breezy/archive/tar.py @@ -58,7 +58,7 @@ def prepare_tarball_item(tree, root, final_path, tree_path, entry, force_mtime=N fileobj = BytesIO(content) elif entry.kind in ("directory", "tree-reference"): item.type = tarfile.DIRTYPE - item.name += '/' + item.name += "/" item.size = 0 item.mode = 0o755 fileobj = None @@ -69,11 +69,15 @@ def prepare_tarball_item(tree, root, final_path, tree_path, entry, force_mtime=N item.linkname = tree.get_symlink_target(tree_path) fileobj = None else: - raise errors.BzrError(f"don't know how to export {{{final_path}}} of kind {entry.kind!r}") + raise errors.BzrError( + f"don't know how to export {{{final_path}}} of kind {entry.kind!r}" + ) return (item, fileobj) -def tarball_generator(tree, root, subdir=None, force_mtime=None, format='', recurse_nested=False): +def tarball_generator( + tree, root, subdir=None, force_mtime=None, format="", recurse_nested=False +): """Export tree contents to a tarball. 
Args: @@ -89,9 +93,11 @@ def tarball_generator(tree, root, subdir=None, force_mtime=None, format='', recu buf = BytesIO() with closing(tarfile.open(None, f"w:{format}", buf)) as ball, tree.lock_read(): for final_path, tree_path, entry in _export_iter_entries( - tree, subdir, recurse_nested=recurse_nested): + tree, subdir, recurse_nested=recurse_nested + ): (item, fileobj) = prepare_tarball_item( - tree, root, final_path, tree_path, entry, force_mtime) + tree, root, final_path, tree_path, entry, force_mtime + ) ball.addfile(item, fileobj) # Yield the data that was written so far, rinse, repeat. yield buf.getvalue() @@ -108,15 +114,17 @@ def tgz_generator(tree, dest, root, subdir, force_mtime=None, recurse_nested=Fal """ with tree.lock_read(): import gzip + if force_mtime is not None: root_mtime = force_mtime - elif (getattr(tree, "repository", None) and - getattr(tree, "get_revision_id", None)): + elif getattr(tree, "repository", None) and getattr( + tree, "get_revision_id", None + ): # If this is a revision tree, use the revisions' timestamp rev = tree.repository.get_revision(tree.get_revision_id()) root_mtime = rev.timestamp - elif tree.is_versioned(''): - root_mtime = tree.get_file_mtime('') + elif tree.is_versioned(""): + root_mtime = tree.get_file_mtime("") else: root_mtime = None @@ -126,11 +134,10 @@ def tgz_generator(tree, dest, root, subdir, force_mtime=None, recurse_nested=Fal # dest. (bug 102234) basename = os.path.basename(dest) buf = BytesIO() - zipstream = gzip.GzipFile(basename, 'w', fileobj=buf, - mtime=root_mtime) + zipstream = gzip.GzipFile(basename, "w", fileobj=buf, mtime=root_mtime) for chunk in tarball_generator( - tree, root, subdir, force_mtime, - recurse_nested=recurse_nested): + tree, root, subdir, force_mtime, recurse_nested=recurse_nested + ): zipstream.write(chunk) # Yield the data that was written so far, rinse, repeat. yield buf.getvalue() @@ -148,30 +155,38 @@ def tbz_generator(tree, dest, root, subdir, force_mtime=None, recurse_nested=Fal already exists, it will be clobbered, like with "tar -c". """ return tarball_generator( - tree, root, subdir, force_mtime, format='bz2', - recurse_nested=recurse_nested) + tree, root, subdir, force_mtime, format="bz2", recurse_nested=recurse_nested + ) -def plain_tar_generator(tree, dest, root, subdir, - force_mtime=None, recurse_nested=False): +def plain_tar_generator( + tree, dest, root, subdir, force_mtime=None, recurse_nested=False +): """Export this tree to a new tar file. `dest` will be created holding the contents of this tree; if it already exists, it will be clobbered, like with "tar -c". """ return tarball_generator( - tree, root, subdir, force_mtime, format='', - recurse_nested=recurse_nested) + tree, root, subdir, force_mtime, format="", recurse_nested=recurse_nested + ) def tar_xz_generator(tree, dest, root, subdir, force_mtime=None, recurse_nested=False): return tar_lzma_generator( - tree, dest, root, subdir, force_mtime, "xz", - recurse_nested=recurse_nested) - - -def tar_lzma_generator(tree, dest, root, subdir, force_mtime=None, - compression_format="alone", recurse_nested=False): + tree, dest, root, subdir, force_mtime, "xz", recurse_nested=recurse_nested + ) + + +def tar_lzma_generator( + tree, + dest, + root, + subdir, + force_mtime=None, + compression_format="alone", + recurse_nested=False, +): """Export this tree to a new .tar.lzma file. 
     `dest` will be created holding the contents of this tree; if it
@@ -180,18 +195,19 @@ def tar_lzma_generator(tree, dest, root, subdir, force_mtime=None,
     try:
         import lzma
     except ModuleNotFoundError as exc:
-        raise errors.DependencyNotPresent('lzma', e) from exc
+        raise errors.DependencyNotPresent("lzma", exc) from exc

     compressor = lzma.LZMACompressor(
         format={
-            'xz': lzma.FORMAT_XZ,
-            'raw': lzma.FORMAT_RAW,
-            'alone': lzma.FORMAT_ALONE,
-            }[compression_format])
+            "xz": lzma.FORMAT_XZ,
+            "raw": lzma.FORMAT_RAW,
+            "alone": lzma.FORMAT_ALONE,
+        }[compression_format]
+    )

     for chunk in tarball_generator(
-            tree, root, subdir, force_mtime=force_mtime,
-            recurse_nested=recurse_nested):
+        tree, root, subdir, force_mtime=force_mtime, recurse_nested=recurse_nested
+    ):
         yield compressor.compress(chunk)

     yield compressor.flush()
diff --git a/breezy/archive/zip.py b/breezy/archive/zip.py
index e532584f2b..f7579a6d24 100644
--- a/breezy/archive/zip.py
+++ b/breezy/archive/zip.py
@@ -28,16 +28,17 @@

 # Windows expects this bit to be set in the 'external_attr' section,
 # or it won't consider the entry a directory.
-ZIP_DIRECTORY_BIT = (1 << 4)
-FILE_PERMISSIONS = (0o644 << 16)
-DIR_PERMISSIONS = (0o755 << 16)
+ZIP_DIRECTORY_BIT = 1 << 4
+FILE_PERMISSIONS = 0o644 << 16
+DIR_PERMISSIONS = 0o755 << 16

 _FILE_ATTR = stat.S_IFREG | FILE_PERMISSIONS
 _DIR_ATTR = stat.S_IFDIR | ZIP_DIRECTORY_BIT | DIR_PERMISSIONS


-def zip_archive_generator(tree, dest, root, subdir=None,
-                          force_mtime=None, recurse_nested=False):
+def zip_archive_generator(
+    tree, dest, root, subdir=None, force_mtime=None, recurse_nested=False
+):
     """Export this tree to a new zip file.

     `dest` will be created holding the contents of this tree; if it
@@ -45,10 +46,10 @@ def zip_archive_generator(tree, dest, root, subdir=None,
     """
     compression = zipfile.ZIP_DEFLATED
     with tempfile.SpooledTemporaryFile() as buf:
-        with closing(zipfile.ZipFile(buf, "w", compression)) as zipf, \
-                tree.lock_read():
+        with closing(zipfile.ZipFile(buf, "w", compression)) as zipf, tree.lock_read():
             for dp, tp, ie in _export_iter_entries(
-                    tree, subdir, recurse_nested=recurse_nested):
+                tree, subdir, recurse_nested=recurse_nested
+            ):
                 mutter(" export {%s} kind %s to %s", tp, ie.kind, dest)

                 # zipfile.ZipFile switches all paths to forward
@@ -60,9 +61,7 @@ def zip_archive_generator(tree, dest, root, subdir=None,
                 date_time = time.localtime(mtime)[:6]
                 filename = osutils.pathjoin(root, dp)
                 if ie.kind == "file":
-                    zinfo = zipfile.ZipInfo(
-                        filename=filename,
-                        date_time=date_time)
+                    zinfo = zipfile.ZipInfo(filename=filename, date_time=date_time)
                     zinfo.compress_type = compression
                     zinfo.external_attr = _FILE_ATTR
                     content = tree.get_file_text(tp)
@@ -72,15 +71,15 @@ def zip_archive_generator(tree, dest, root, subdir=None,
                 # to the zip routine that they are really directories and
                 # not just empty files.
zinfo = zipfile.ZipInfo( - filename=filename + '/', - date_time=date_time) + filename=filename + "/", date_time=date_time + ) zinfo.compress_type = compression zinfo.external_attr = _DIR_ATTR - zipf.writestr(zinfo, '') + zipf.writestr(zinfo, "") elif ie.kind == "symlink": zinfo = zipfile.ZipInfo( - filename=(filename + '.lnk'), - date_time=date_time) + filename=(filename + ".lnk"), date_time=date_time + ) zinfo.compress_type = compression zinfo.external_attr = _FILE_ATTR zipf.writestr(zinfo, tree.get_symlink_target(tp)) diff --git a/breezy/atomicfile.py b/breezy/atomicfile.py index ecfb978d4b..7a3bf06812 100644 --- a/breezy/atomicfile.py +++ b/breezy/atomicfile.py @@ -25,9 +25,7 @@ class AtomicFileAlreadyClosed(errors.PathError): - - _fmt = ('"%(function)s" called on an AtomicFile after it was closed:' - ' "%(path)s"') + _fmt = '"%(function)s" called on an AtomicFile after it was closed:' ' "%(path)s"' def __init__(self, path, function): errors.PathError.__init__(self, path=path, extra=None) @@ -43,9 +41,9 @@ class AtomicFile: place or abort() to cancel. """ - __slots__ = ['tmpfilename', 'realfilename', '_fd'] + __slots__ = ["tmpfilename", "realfilename", "_fd"] - def __init__(self, filename, mode='wb', new_mode=None): + def __init__(self, filename, mode="wb", new_mode=None): global _hostname self._fd = None @@ -53,15 +51,19 @@ def __init__(self, filename, mode='wb', new_mode=None): if _hostname is None: _hostname = osutils.get_host_name() - self.tmpfilename = '%s.%d.%s.%s.tmp' % (filename, _pid, _hostname, - osutils.rand_chars(10)) + self.tmpfilename = "%s.%d.%s.%s.tmp" % ( + filename, + _pid, + _hostname, + osutils.rand_chars(10), + ) self.realfilename = filename flags = os.O_EXCL | os.O_CREAT | os.O_WRONLY | osutils.O_NOINHERIT - if mode == 'wb': + if mode == "wb": flags |= osutils.O_BINARY - elif mode != 'wt': + elif mode != "wt": raise ValueError(f"invalid AtomicFile mode {mode!r}") if new_mode is not None: @@ -81,7 +83,7 @@ def __init__(self, filename, mode='wb', new_mode=None): osutils.chmod_if_possible(self.tmpfilename, new_mode) def __repr__(self): - return f'{self.__class__.__name__}({self.realfilename!r})' + return f"{self.__class__.__name__}({self.realfilename!r})" def write(self, data): """Write some data to the file. 
Like file.write().""" @@ -90,20 +92,19 @@ def write(self, data): def _close_tmpfile(self, func_name): """Close the local temp file in preparation for commit or abort.""" if self._fd is None: - raise AtomicFileAlreadyClosed(path=self.realfilename, - function=func_name) + raise AtomicFileAlreadyClosed(path=self.realfilename, function=func_name) fd = self._fd self._fd = None os.close(fd) def commit(self): """Close the file and move to final name.""" - self._close_tmpfile('commit') + self._close_tmpfile("commit") osutils.rename(self.tmpfilename, self.realfilename) def abort(self): """Discard temporary file without committing changes.""" - self._close_tmpfile('abort') + self._close_tmpfile("abort") os.remove(self.tmpfilename) def close(self): diff --git a/breezy/bisect.py b/breezy/bisect.py index 828018ae91..eb2e66e0f0 100644 --- a/breezy/bisect.py +++ b/breezy/bisect.py @@ -37,15 +37,15 @@ def __init__(self, controldir, filename=BISECT_REV_PATH): self._controldir = controldir self._branch = self._controldir.open_branch() if self._controldir.control_transport.has(filename): - self._revid = self._controldir.control_transport.get_bytes( - filename).strip() + self._revid = self._controldir.control_transport.get_bytes(filename).strip() else: self._revid = self._branch.last_revision() def _save(self): """Save the current revision.""" self._controldir.control_transport.put_bytes( - self._filename, self._revid + b"\n") + self._filename, self._revid + b"\n" + ) def get_current_revid(self): """Return the current revision id.""" @@ -79,8 +79,7 @@ def switch(self, revid): revid = self._branch.get_rev_id(revid) elif isinstance(revid, list): revid = revid[0].in_history(working.branch).rev_id - working.revert(None, working.branch.repository.revision_tree(revid), - False) + working.revert(None, working.branch.repository.revision_tree(revid), False) self._revid = revid self._save() @@ -134,14 +133,18 @@ def _find_range_and_middle(self, branch_last_rev=None): with repo.lock_read(): graph = repo.get_graph() rev_sequence = graph.iter_lefthand_ancestry( - last_revid, (_mod_revision.NULL_REVISION,)) + last_revid, (_mod_revision.NULL_REVISION,) + ) high_revid = None low_revid = None between_revs = [] for revision in rev_sequence: between_revs.insert(0, revision) - matches = [x[1] for x in self._items - if x[0] == revision and x[1] in ('yes', 'no')] + matches = [ + x[1] + for x in self._items + if x[0] == revision and x[1] in ("yes", "no") + ] if not matches: continue if len(matches) > 1: @@ -185,8 +188,9 @@ def _switch_wc_to_revno(self, revno, outf): def _set_status(self, revid, status): """Set the bisect status for the given revid.""" if not self.is_done(): - if status != "done" and revid in [x[0] for x in self._items - if x[1] in ['yes', 'no']]: + if status != "done" and revid in [ + x[0] for x in self._items if x[1] in ["yes", "no"] + ]: raise RuntimeError(f"attempting to add revid {revid} twice") self._items.append((revid, status)) @@ -201,16 +205,16 @@ def load(self): revlog = self._open_for_read() for line in revlog: (revid, status) = line.split() - self._items.append((revid, status.decode('ascii'))) + self._items.append((revid, status.decode("ascii"))) def save(self): """Save the bisection log.""" - contents = b''.join( - (b"%s %s\n" % (revid, status.encode('ascii'))) - for (revid, status) in self._items) + contents = b"".join( + (b"%s %s\n" % (revid, status.encode("ascii"))) + for (revid, status) in self._items + ) if self._filename: - self._controldir.control_transport.put_bytes( - self._filename, contents) 
+ self._controldir.control_transport.put_bytes(self._filename, contents) else: sys.stdout.write(contents) @@ -242,9 +246,10 @@ def bisect(self, outf): self._find_range_and_middle() # If we've found the "final" revision, check for a # merge point. - while ((self._middle_revid == self._high_revid or - self._middle_revid == self._low_revid) and - self.is_merge_point(self._middle_revid)): + while ( + self._middle_revid == self._high_revid + or self._middle_revid == self._low_revid + ) and self.is_merge_point(self._middle_revid): for parent in self.get_parent_revids(self._middle_revid): if parent == self._low_revid: continue @@ -252,8 +257,10 @@ def bisect(self, outf): self._find_range_and_middle(parent) break self._switch_wc_to_revno(self._middle_revid, outf) - if self._middle_revid == self._high_revid or \ - self._middle_revid == self._low_revid: + if ( + self._middle_revid == self._high_revid + or self._middle_revid == self._low_revid + ): self.set_current("done") @@ -307,10 +314,12 @@ class cmd_bisect(Command): anything else for no """ - takes_args = ['subcommand', 'args*'] - takes_options = [Option('output', short_name='o', - help='Write log to this file.', type=str), - 'revision', 'directory'] + takes_args = ["subcommand", "args*"] + takes_options = [ + Option("output", short_name="o", help="Write log to this file.", type=str), + "revision", + "directory", + ] def _check(self, controldir): """Check preconditions for most operations to work.""" @@ -336,22 +345,19 @@ def _set_state(self, controldir, revspec, state): bisect_log.save() return False - def run(self, subcommand, args_list, directory='.', revision=None, - output=None): + def run(self, subcommand, args_list, directory=".", revision=None, output=None): """Handle the bisect command.""" log_fn = None - if subcommand in ('yes', 'no', 'move') and revision: + if subcommand in ("yes", "no", "move") and revision: pass - elif subcommand in ('replay', ) and args_list and len(args_list) == 1: + elif subcommand in ("replay",) and args_list and len(args_list) == 1: log_fn = args_list[0] - elif subcommand in ('move', ) and not revision: - raise CommandError( - "The 'bisect move' command requires a revision.") - elif subcommand in ('run', ): + elif subcommand in ("move",) and not revision: + raise CommandError("The 'bisect move' command requires a revision.") + elif subcommand in ("run",): run_script = args_list[0] elif args_list or revision: - raise CommandError( - "Improper arguments to bisect " + subcommand) + raise CommandError("Improper arguments to bisect " + subcommand) controldir, _ = ControlDir.open_containing(directory) @@ -373,8 +379,7 @@ def run(self, subcommand, args_list, directory='.', revision=None, elif subcommand == "run": self.run_bisect(controldir, run_script) else: - raise CommandError( - "Unknown bisect command: " + subcommand) + raise CommandError("Unknown bisect command: " + subcommand) def reset(self, controldir): """Reset the bisect state to no state.""" @@ -428,6 +433,7 @@ def replay(self, controldir, filename): def run_bisect(self, controldir, script): import subprocess + note("Starting bisect.") self.start(controldir) while True: @@ -436,11 +442,11 @@ def run_bisect(self, controldir, script): process.wait() retcode = process.returncode if retcode == 0: - done = self._set_state(controldir, None, 'yes') + done = self._set_state(controldir, None, "yes") elif retcode == 125: break else: - done = self._set_state(controldir, None, 'no') + done = self._set_state(controldir, None, "no") if done: break except RuntimeError: 
diff --git a/breezy/bisect_multi.py b/breezy/bisect_multi.py index ca47b55983..44ab9e4f9b 100644 --- a/breezy/bisect_multi.py +++ b/breezy/bisect_multi.py @@ -17,8 +17,8 @@ """Bisection lookup multiple keys.""" __all__ = [ - 'bisect_multi_bytes', - ] + "bisect_multi_bytes", +] def bisect_multi_bytes(content_lookup, size, keys): diff --git a/breezy/branch.py b/breezy/branch.py index aba30c7dd7..49ff0978fb 100644 --- a/breezy/branch.py +++ b/breezy/branch.py @@ -40,11 +40,11 @@ from .tag import TagConflict, TagUpdates - class UnstackableBranchFormat(errors.BzrError): - - _fmt = ("The branch '%(url)s'(%(format)s) is not a stackable format. " - "You will need to upgrade the branch to permit branch stacking.") + _fmt = ( + "The branch '%(url)s'(%(format)s) is not a stackable format. " + "You will need to upgrade the branch to permit branch stacking." + ) def __init__(self, format, url): errors.BzrError.__init__(self) @@ -53,7 +53,6 @@ def __init__(self, format, url): class BindingUnsupported(errors.UnsupportedOperation): - _fmt = "Branch at %(url)s does not support binding." def __init__(self, branch): @@ -99,7 +98,7 @@ def __init__(self, possible_transports: Optional[List[Transport]] = None) -> Non self._master_branch_cache = None self._merge_sorted_revisions_cache = None self._open_hook(possible_transports) - hooks = Branch.hooks['open'] + hooks = Branch.hooks["open"] for hook in hooks: hook(self) @@ -113,7 +112,7 @@ def _activate_fallback_location(self, url, possible_transports): # This fallback is already configured. This probably only # happens because ControlDir.sprout is a horrible mess. To # avoid confusing _unstack we don't add this a second time. - mutter('duplicate activation of fallback %r on %r', url, self) + mutter("duplicate activation of fallback %r on %r", url, self) return repo = self._get_fallback_repository(url, possible_transports) if repo.has_same_location(self.repository): @@ -131,8 +130,10 @@ def break_lock(self) -> None: raise NotImplementedError(self.break_lock) def _extend_partial_history( - self, stop_index: Optional[int] = None, - stop_revision: Optional[RevisionID] = None) -> None: + self, + stop_index: Optional[int] = None, + stop_revision: Optional[RevisionID] = None, + ) -> None: """Extend the partial history to include a given index. If a stop_index is supplied, stop when that index has been reached. @@ -149,10 +150,12 @@ def _extend_partial_history( if len(self._partial_revision_history_cache) == 0: self._partial_revision_history_cache = [self.last_revision()] repository._iter_for_revno( - self.repository, self._partial_revision_history_cache, - stop_index=stop_index, stop_revision=stop_revision) - if self._partial_revision_history_cache[-1] == \ - _mod_revision.NULL_REVISION: + self.repository, + self._partial_revision_history_cache, + stop_index=stop_index, + stop_revision=stop_revision, + ) + if self._partial_revision_history_cache[-1] == _mod_revision.NULL_REVISION: self._partial_revision_history_cache.pop() def _get_check_refs(self): @@ -161,7 +164,7 @@ def _get_check_refs(self): See breezy.check. """ revid = self.last_revision() - return [('revision-existence', revid), ('lefthand-distance', revid)] + return [("revision-existence", revid), ("lefthand-distance", revid)] @staticmethod def open(base, _unsupported=False, possible_transports=None): @@ -171,22 +174,24 @@ def open(base, _unsupported=False, possible_transports=None): Branch.open(URL) -> a Branch instance. 
""" control = ControlDir.open( - base, possible_transports=possible_transports, - _unsupported=_unsupported) + base, possible_transports=possible_transports, _unsupported=_unsupported + ) return control.open_branch( - unsupported=_unsupported, - possible_transports=possible_transports) + unsupported=_unsupported, possible_transports=possible_transports + ) @staticmethod - def open_from_transport(transport: Transport, name: Optional[str] = None, - _unsupported: bool = False, - possible_transports=None): + def open_from_transport( + transport: Transport, + name: Optional[str] = None, + _unsupported: bool = False, + possible_transports=None, + ): """Open the branch rooted at transport.""" - control = ControlDir.open_from_transport( - transport, _unsupported) + control = ControlDir.open_from_transport(transport, _unsupported) return control.open_branch( - name=name, unsupported=_unsupported, - possible_transports=possible_transports) + name=name, unsupported=_unsupported, possible_transports=possible_transports + ) @staticmethod def open_containing(url, possible_transports=None): @@ -201,8 +206,7 @@ def open_containing(url, possible_transports=None): raised. If there is one, it is returned, along with the unused portion of url. """ - control, relpath = ControlDir.open_containing( - url, possible_transports) + control, relpath = ControlDir.open_containing(url, possible_transports) branch = control.open_branch(possible_transports=possible_transports) return (branch, relpath) @@ -275,21 +279,22 @@ def _get_nick(self, local=False, possible_transports=None): except errors.BzrError as e: # Silently fall back to local implicit nick if the master is # unavailable - mutter("Could not connect to bound branch, " - "falling back to local nick.\n " + str(e)) + mutter( + "Could not connect to bound branch, " + "falling back to local nick.\n " + str(e) + ) return config.get_nickname() def _set_nick(self, nick): - self.get_config().set_user_option('nickname', nick, warn_masked=True) + self.get_config().set_user_option("nickname", nick, warn_masked=True) nick = property(_get_nick, _set_nick) def is_locked(self): raise NotImplementedError(self.is_locked) - def _lefthand_history(self, revision_id, last_rev=None, - other_branch=None): - if debug.debug_flag_enabled('evil'): + def _lefthand_history(self, revision_id, last_rev=None, other_branch=None): + if debug.debug_flag_enabled("evil"): mutter_callsite(4, "_lefthand_history scales with history.") # stop_revision must be a descendant of last_revision graph = self.repository.get_graph() @@ -305,8 +310,7 @@ def _lefthand_history(self, revision_id, last_rev=None, new_history = [] check_not_reserved_id = _mod_revision.check_not_reserved_id # Do not include ghosts or graph origin in revision_history - while (current_rev_id in parents_map - and len(parents_map[current_rev_id]) > 0): + while current_rev_id in parents_map and len(parents_map[current_rev_id]) > 0: check_not_reserved_id(current_rev_id) new_history.append(current_rev_id) current_rev_id = parents_map[current_rev_id][0] @@ -368,15 +372,19 @@ def _do_dotted_revno_to_revision_id(self, revno): try: return self.get_rev_id(revno[0]) except errors.RevisionNotPresent as exc: - raise errors.GhostRevisionsHaveNoRevno(revno[0], exc.revision_id) from exc + raise errors.GhostRevisionsHaveNoRevno( + revno[0], exc.revision_id + ) from exc revision_id_to_revno = self.get_revision_id_to_revno_map() - revision_ids = [revision_id for revision_id, this_revno - in revision_id_to_revno.items() - if revno == this_revno] + 
revision_ids = [ + revision_id + for revision_id, this_revno in revision_id_to_revno.items() + if revno == this_revno + ] if len(revision_ids) == 1: return revision_ids[0] else: - revno_str = '.'.join(map(str, revno)) + revno_str = ".".join(map(str, revno)) raise errors.NoSuchRevision(self, revno_str) def revision_id_to_dotted_revno(self, revision_id): @@ -416,9 +424,8 @@ def get_revision_id_to_revno_map(self): Returns: A dictionary mapping revision_id => dotted revno. This dictionary should not be modified by the caller. """ - if debug.debug_flag_enabled('evil'): - mutter_callsite( - 3, "get_revision_id_to_revno_map scales with ancestry.") + if debug.debug_flag_enabled("evil"): + mutter_callsite(3, "get_revision_id_to_revno_map scales with ancestry.") with self.lock_read(): if self._revision_id_to_revno_cache is not None: mapping = self._revision_id_to_revno_cache @@ -442,13 +449,18 @@ def _gen_revno_map(self): Returns: A dictionary mapping revision_id => dotted revno. """ revision_id_to_revno = { - rev_id: revno for rev_id, depth, revno, end_of_merge - in self.iter_merge_sorted_revisions()} + rev_id: revno + for rev_id, depth, revno, end_of_merge in self.iter_merge_sorted_revisions() + } return revision_id_to_revno - def iter_merge_sorted_revisions(self, start_revision_id=None, - stop_revision_id=None, - stop_rule='exclude', direction='reverse'): + def iter_merge_sorted_revisions( + self, + start_revision_id=None, + stop_revision_id=None, + stop_rule="exclude", + direction="reverse", + ): """Walk the revisions for a branch in merge sorted order. Merge sorted order is the output from a merge-aware, @@ -498,26 +510,29 @@ def iter_merge_sorted_revisions(self, start_revision_id=None, # start_revision_id. if self._merge_sorted_revisions_cache is None: last_revision = self.last_revision() - known_graph = self.repository.get_known_graph_ancestry( - [last_revision]) + known_graph = self.repository.get_known_graph_ancestry([last_revision]) self._merge_sorted_revisions_cache = known_graph.merge_sort( - last_revision) + last_revision + ) filtered = self._filter_merge_sorted_revisions( - self._merge_sorted_revisions_cache, start_revision_id, - stop_revision_id, stop_rule) + self._merge_sorted_revisions_cache, + start_revision_id, + stop_revision_id, + stop_rule, + ) # Make sure we don't return revisions that are not part of the # start_revision_id ancestry. 
filtered = self._filter_start_non_ancestors(filtered) - if direction == 'reverse': + if direction == "reverse": return filtered - if direction == 'forward': + if direction == "forward": return reversed(list(filtered)) else: - raise ValueError(f'invalid direction {direction!r}') + raise ValueError(f"invalid direction {direction!r}") - def _filter_merge_sorted_revisions(self, merge_sorted_revisions, - start_revision_id, stop_revision_id, - stop_rule): + def _filter_merge_sorted_revisions( + self, merge_sorted_revisions, start_revision_id, stop_revision_id, stop_rule + ): """Iterate over an inclusive range of sorted revisions.""" rev_iter = iter(merge_sorted_revisions) if start_revision_id is not None: @@ -535,35 +550,32 @@ def _filter_merge_sorted_revisions(self, merge_sorted_revisions, # Yield everything for node in rev_iter: rev_id = node.key - yield (rev_id, node.merge_depth, node.revno, - node.end_of_merge) - elif stop_rule == 'exclude': + yield (rev_id, node.merge_depth, node.revno, node.end_of_merge) + elif stop_rule == "exclude": for node in rev_iter: rev_id = node.key if rev_id == stop_revision_id: return - yield (rev_id, node.merge_depth, node.revno, - node.end_of_merge) - elif stop_rule == 'include': + yield (rev_id, node.merge_depth, node.revno, node.end_of_merge) + elif stop_rule == "include": for node in rev_iter: rev_id = node.key - yield (rev_id, node.merge_depth, node.revno, - node.end_of_merge) + yield (rev_id, node.merge_depth, node.revno, node.end_of_merge) if rev_id == stop_revision_id: return - elif stop_rule == 'with-merges-without-common-ancestry': + elif stop_rule == "with-merges-without-common-ancestry": # We want to exclude all revisions that are already part of the # stop_revision_id ancestry. graph = self.repository.get_graph() - ancestors = graph.find_unique_ancestors(start_revision_id, - [stop_revision_id]) + ancestors = graph.find_unique_ancestors( + start_revision_id, [stop_revision_id] + ) for node in rev_iter: rev_id = node.key if rev_id not in ancestors: continue - yield (rev_id, node.merge_depth, node.revno, - node.end_of_merge) - elif stop_rule == 'with-merges': + yield (rev_id, node.merge_depth, node.revno, node.end_of_merge) + elif stop_rule == "with-merges": stop_rev = self.repository.get_revision(stop_revision_id) if stop_rev.parent_ids: left_parent = stop_rev.parent_ids[0] @@ -578,10 +590,8 @@ def _filter_merge_sorted_revisions(self, merge_sorted_revisions, if rev_id == left_parent: # reached the left parent after the stop_revision return - if (not reached_stop_revision_id - or rev_id in revision_id_whitelist): - yield (rev_id, node.merge_depth, node.revno, - node.end_of_merge) + if not reached_stop_revision_id or rev_id in revision_id_whitelist: + yield (rev_id, node.merge_depth, node.revno, node.end_of_merge) if reached_stop_revision_id or rev_id == stop_revision_id: # only do the merged revs of rev_id from now on rev = self.repository.get_revision(rev_id) @@ -589,7 +599,7 @@ def _filter_merge_sorted_revisions(self, merge_sorted_revisions, reached_stop_revision_id = True revision_id_whitelist.extend(rev.parent_ids) else: - raise ValueError(f'invalid stop_rule {stop_rule!r}') + raise ValueError(f"invalid stop_rule {stop_rule!r}") def _filter_start_non_ancestors(self, rev_iter): # If we started from a dotted revno, we want to consider it as a tip @@ -621,7 +631,7 @@ def _filter_start_non_ancestors(self, rev_iter): # called in that case. 
-- vila 20100322 return - for (rev_id, merge_depth, revno, end_of_merge) in rev_iter: + for rev_id, merge_depth, revno, end_of_merge in rev_iter: if not clean: if rev_id in whitelist: pmap = self.repository.get_parent_map([rev_id]) @@ -668,12 +678,12 @@ def get_append_revisions_only(self): """Whether it is only possible to append revisions to the history.""" if not self._format.supports_set_append_revisions_only(): return False - return self.get_config_stack().get('append_revisions_only') + return self.get_config_stack().get("append_revisions_only") def set_append_revisions_only(self, enabled: bool) -> None: if not self._format.supports_set_append_revisions_only(): raise errors.UpgradeRequired(self.user_url) - self.get_config_stack().set('append_revisions_only', enabled) + self.get_config_stack().set("append_revisions_only", enabled) def fetch(self, from_branch, stop_revision=None, limit=None, lossy=False): """Copy revisions from from_branch into this branch. @@ -688,7 +698,8 @@ def fetch(self, from_branch, stop_revision=None, limit=None, lossy=False): """ with self.lock_write(): return InterBranch.get(from_branch, self).fetch( - stop_revision, limit=limit, lossy=lossy) + stop_revision, limit=limit, lossy=lossy + ) def get_bound_location(self) -> Optional[str]: """Return the URL of the branch we are bound to. @@ -702,9 +713,17 @@ def get_old_bound_location(self): """Return the URL of the branch we used to be bound to.""" raise errors.UpgradeRequired(self.user_url) - def get_commit_builder(self, parents, config_stack=None, timestamp=None, - timezone=None, committer=None, revprops=None, - revision_id=None, lossy=False): + def get_commit_builder( + self, + parents, + config_stack=None, + timestamp=None, + timezone=None, + committer=None, + revprops=None, + revision_id=None, + lossy=False, + ): """Obtain a CommitBuilder for this branch. Args: @@ -723,12 +742,20 @@ def get_commit_builder(self, parents, config_stack=None, timestamp=None, config_stack = self.get_config_stack() return self.repository.get_commit_builder( - self, parents, config_stack, timestamp, timezone, committer, - revprops, revision_id, lossy) + self, + parents, + config_stack, + timestamp, + timezone, + committer, + revprops, + revision_id, + lossy, + ) def get_master_branch( - self, possible_transports: Optional[List[Transport]] = None - ) -> Optional["Branch"]: + self, possible_transports: Optional[List[Transport]] = None + ) -> Optional["Branch"]: """Return the branch we are bound to. Returns: Either a Branch, or None @@ -746,7 +773,8 @@ def get_stacked_on_url(self) -> str: raise NotImplementedError(self.get_stacked_on_url) def set_last_revision_info( - self, revno: Optional[int], revision_id: RevisionID) -> None: + self, revno: Optional[int], revision_id: RevisionID + ) -> None: """Set the last revision of this branch. 
The caller is responsible for checking that the revno is correct @@ -759,9 +787,12 @@ def set_last_revision_info( """ raise NotImplementedError(self.set_last_revision_info) - def generate_revision_history(self, revision_id: RevisionID, - last_rev: Optional[RevisionID] = None, - other_branch: Optional["Branch"] = None) -> None: + def generate_revision_history( + self, + revision_id: RevisionID, + last_rev: Optional[RevisionID] = None, + other_branch: Optional["Branch"] = None, + ) -> None: """See Branch.generate_revision_history.""" with self.lock_write(): graph = self.repository.get_graph() @@ -769,13 +800,12 @@ def generate_revision_history(self, revision_id: RevisionID, known_revision_ids = [ (last_revid, last_revno), (_mod_revision.NULL_REVISION, 0), - ] + ] if last_rev is not None: if not graph.is_ancestor(last_rev, revision_id): # our previous tip is not merged into stop_revision raise errors.DivergedBranches(self, other_branch) - revno = graph.find_distance_to_null( - revision_id, known_revision_ids) + revno = graph.find_distance_to_null(revision_id, known_revision_ids) self.set_last_revision_info(revno, revision_id) def _set_parent_location(self, url: Optional[str]) -> None: @@ -790,11 +820,11 @@ def set_parent(self, url: Optional[str]) -> None: if url is not None: if isinstance(url, str): try: - url.encode('ascii') + url.encode("ascii") except UnicodeEncodeError as exc: raise urlutils.InvalidURL( - url, "Urls must be 7-bit ascii, " - "use breezy.urlutils.escape") from exc + url, "Urls must be 7-bit ascii, " "use breezy.urlutils.escape" + ) from exc url = urlutils.relative_url(self.base, url) with self.lock_write(): self._set_parent_location(url) @@ -861,7 +891,7 @@ def _gen_revision_history(self): raise NotImplementedError(self._gen_revision_history) def _revision_history(self) -> List[RevisionID]: - if debug.debug_flag_enabled('evil'): + if debug.debug_flag_enabled("evil"): mutter_callsite(3, "revision_history scales with history.") if self._revision_history_cache is not None: history = self._revision_history_cache @@ -893,15 +923,13 @@ def last_revision_info(self) -> Tuple[int, RevisionID]: """ with self.lock_read(): if self._last_revision_info_cache is None: - self._last_revision_info_cache = ( - self._read_last_revision_info()) + self._last_revision_info_cache = self._read_last_revision_info() return self._last_revision_info_cache def _read_last_revision_info(self): raise NotImplementedError(self._read_last_revision_info) - def import_last_revision_info_and_tags(self, source, revno, revid, - *, lossy=False): + def import_last_revision_info_and_tags(self, source, revno, revid, *, lossy=False): """Set the last revision info, importing from another repo if necessary. 
This is used by the bound branch code to upload a revision to @@ -932,7 +960,9 @@ def revision_id_to_revno(self, revision_id: RevisionID) -> int: except ValueError as exc: raise errors.NoSuchRevision(self, revision_id) from exc - def get_rev_id(self, revno: int, history: Optional[List[RevisionID]] = None) -> RevisionID: + def get_rev_id( + self, revno: int, history: Optional[List[RevisionID]] = None + ) -> RevisionID: """Find the revision id of the specified revno.""" with self.lock_read(): if revno == 0: @@ -947,10 +977,15 @@ def get_rev_id(self, revno: int, history: Optional[List[RevisionID]] = None) -> self._extend_partial_history(distance_from_last) return self._partial_revision_history_cache[distance_from_last] - def pull(self, source: "Branch", *, overwrite: bool = False, - stop_revision: Optional[RevisionID] = None, - possible_transports: Optional[List[Transport]] = None, - **kwargs) -> "PullResult": + def pull( + self, + source: "Branch", + *, + overwrite: bool = False, + stop_revision: Optional[RevisionID] = None, + possible_transports: Optional[List[Transport]] = None, + **kwargs, + ) -> "PullResult": """Mirror source into this branch. This branch is considered to be 'local', having low latency. @@ -958,18 +993,28 @@ def pull(self, source: "Branch", *, overwrite: bool = False, Returns: PullResult instance """ return InterBranch.get(source, self).pull( - overwrite=overwrite, stop_revision=stop_revision, - possible_transports=possible_transports, **kwargs) - - def push(self, target: "Branch", *, overwrite: bool = False, - stop_revision: Optional[RevisionID] = None, lossy: bool = False, - **kwargs): + overwrite=overwrite, + stop_revision=stop_revision, + possible_transports=possible_transports, + **kwargs, + ) + + def push( + self, + target: "Branch", + *, + overwrite: bool = False, + stop_revision: Optional[RevisionID] = None, + lossy: bool = False, + **kwargs, + ): """Mirror this branch into target. This branch is considered to be 'local', having low latency. 
""" return InterBranch.get(self, target).push( - overwrite, stop_revision, lossy, **kwargs) + overwrite, stop_revision, lossy, **kwargs + ) def basis_tree(self): """Return `Tree` object for last revision.""" @@ -987,7 +1032,7 @@ def get_parent(self) -> Optional[str]: return parent # This is an old-format absolute path to a local branch # turn it into a url - if parent.startswith('/'): + if parent.startswith("/"): parent = urlutils.local_path_to_url(parent) try: return urlutils.join(self.base[:-1], parent) @@ -997,12 +1042,11 @@ def get_parent(self) -> Optional[str]: def _get_parent_location(self): raise NotImplementedError(self._get_parent_location) - def _set_config_location(self, name, url, *, config=None, - make_relative=False): + def _set_config_location(self, name, url, *, config=None, make_relative=False): if config is None: config = self.get_config_stack() if url is None: - url = '' + url = "" elif make_relative: url = urlutils.relative_url(self.base, url) config.set(name, url) @@ -1011,13 +1055,13 @@ def _get_config_location(self, name: str, *, config=None) -> Optional[str]: if config is None: config = self.get_config_stack() location = config.get(name) - if location == '': + if location == "": location = None return cast(Optional[str], location) def get_child_submit_format(self) -> Optional[str]: """Return the preferred format of submissions to this branch.""" - return cast(Optional[str], self.get_config_stack().get('child_submit_format')) + return cast(Optional[str], self.get_config_stack().get("child_submit_format")) def get_submit_branch(self) -> Optional[str]: """Return the submit location of the branch. @@ -1026,7 +1070,7 @@ def get_submit_branch(self) -> Optional[str]: pattern is that the user can override it by specifying a location. """ - return cast(Optional[str], self.get_config_stack().get('submit_branch')) + return cast(Optional[str], self.get_config_stack().get("submit_branch")) def set_submit_branch(self, location: str) -> None: """Return the submit location of the branch. @@ -1035,14 +1079,14 @@ def set_submit_branch(self, location: str) -> None: pattern is that the user can override it by specifying a location. """ - self.get_config_stack().set('submit_branch', location) + self.get_config_stack().set("submit_branch", location) def get_public_branch(self) -> Optional[str]: """Return the public location of the branch. This is used by merge directives. """ - return self._get_config_location('public_branch') + return self._get_config_location("public_branch") def set_public_branch(self, location: str) -> None: """Return the submit location of the branch. @@ -1051,11 +1095,11 @@ def set_public_branch(self, location: str) -> None: pattern is that the user can override it by specifying a location. 
""" - self._set_config_location('public_branch', location) + self._set_config_location("public_branch", location) def get_push_location(self) -> Optional[str]: """Return None or the location to push this branch to.""" - return cast(str, self.get_config_stack().get('push_location')) + return cast(str, self.get_config_stack().get("push_location")) def set_push_location(self, location: str) -> None: """Set a new push location for this branch.""" @@ -1063,23 +1107,21 @@ def set_push_location(self, location: str) -> None: def _run_post_change_branch_tip_hooks(self, old_revno, old_revid): """Run the post_change_branch_tip hooks.""" - hooks = Branch.hooks['post_change_branch_tip'] + hooks = Branch.hooks["post_change_branch_tip"] if not hooks: return new_revno, new_revid = self.last_revision_info() - params = ChangeBranchTipParams( - self, old_revno, new_revno, old_revid, new_revid) + params = ChangeBranchTipParams(self, old_revno, new_revno, old_revid, new_revid) for hook in hooks: hook(params) def _run_pre_change_branch_tip_hooks(self, new_revno, new_revid): """Run the pre_change_branch_tip hooks.""" - hooks = Branch.hooks['pre_change_branch_tip'] + hooks = Branch.hooks["pre_change_branch_tip"] if not hooks: return old_revno, old_revid = self.last_revision_info() - params = ChangeBranchTipParams( - self, old_revno, new_revno, old_revid, new_revid) + params = ChangeBranchTipParams(self, old_revno, new_revno, old_revid, new_revid) for hook in hooks: hook(params) @@ -1106,9 +1148,15 @@ def check_real_revno(self, revno: int) -> None: if revno < 1 or revno > self.revno(): raise errors.InvalidRevisionNumber(revno) - def clone(self, to_controldir: ControlDir, *, - revision_id: Optional[RevisionID] = None, name: Optional[str] = None, - repository_policy=None, tag_selector=None) -> "Branch": + def clone( + self, + to_controldir: ControlDir, + *, + revision_id: Optional[RevisionID] = None, + name: Optional[str] = None, + repository_policy=None, + tag_selector=None, + ) -> "Branch": """Clone this branch into to_controldir preserving all semantic values. Most API users will want 'create_clone_on_transport', which creates a @@ -1122,12 +1170,21 @@ def clone(self, to_controldir: ControlDir, *, if repository_policy is not None: repository_policy.configure_branch(result) self.copy_content_into( - result, revision_id=revision_id, tag_selector=tag_selector) + result, revision_id=revision_id, tag_selector=tag_selector + ) return result - def sprout(self, to_controldir, *, revision_id=None, repository_policy=None, - repository=None, lossy=False, tag_selector=None, - name=None): + def sprout( + self, + to_controldir, + *, + revision_id=None, + repository_policy=None, + repository=None, + lossy=False, + tag_selector=None, + name=None, + ): """Create a new line of development from the branch, into to_controldir. to_controldir controls the branch format. @@ -1135,8 +1192,7 @@ def sprout(self, to_controldir, *, revision_id=None, repository_policy=None, revision_id: if not None, the revision history in the new branch will be truncated to end with revision_id. 
""" - if (repository_policy is not None - and repository_policy.requires_stacking()): + if repository_policy is not None and repository_policy.requires_stacking(): to_controldir._format.require_stacking(_skip_repo=True) result = to_controldir.create_branch(repository=repository, name=name) if lossy: @@ -1145,7 +1201,8 @@ def sprout(self, to_controldir, *, revision_id=None, repository_policy=None, if repository_policy is not None: repository_policy.configure_branch(result) self.copy_content_into( - result, revision_id=revision_id, tag_selector=tag_selector) + result, revision_id=revision_id, tag_selector=tag_selector + ) master_url = self.get_bound_location() if master_url is None: result.set_parent(self.user_url) @@ -1173,7 +1230,8 @@ def _synchronize_history(self, destination, revision_id): graph = self.repository.get_graph() try: revno = graph.find_distance_to_null( - revision_id, [(source_revision_id, source_revno)]) + revision_id, [(source_revision_id, source_revno)] + ) except errors.GhostRevisionsHaveNoRevno: # Default to 1, if we can't find anything else revno = 1 @@ -1188,7 +1246,8 @@ def copy_content_into(self, destination, *, revision_id=None, tag_selector=None) and should return a boolean to indicate whether a tag should be copied """ return InterBranch.get(self, destination).copy_content_into( - revision_id=revision_id, tag_selector=tag_selector) + revision_id=revision_id, tag_selector=tag_selector + ) def update_references(self, target): if not self._format.supports_reference_locations: @@ -1212,10 +1271,13 @@ def check(self, refs): with self.lock_read(): result = BranchCheckResult(self) last_revno, last_revision_id = self.last_revision_info() - actual_revno = refs[('lefthand-distance', last_revision_id)] + actual_revno = refs[("lefthand-distance", last_revision_id)] if actual_revno != last_revno: - result.errors.append(errors.BzrCheckError( - f'revno does not match len(mainline) {last_revno} != {actual_revno}')) + result.errors.append( + errors.BzrCheckError( + f"revno does not match len(mainline) {last_revno} != {actual_revno}" + ) + ) # TODO: We should probably also check that self.revision_history # matches the repository for older branch formats. # If looking for the code that cross-checks repository parents @@ -1231,10 +1293,17 @@ def _get_checkout_format(self, lightweight=False): format.set_branch_format(self._format) return format - def create_clone_on_transport(self, to_transport, *, revision_id=None, - stacked_on=None, create_prefix=False, - use_existing_dir=False, no_tree=None, - tag_selector=None): + def create_clone_on_transport( + self, + to_transport, + *, + revision_id=None, + stacked_on=None, + create_prefix=False, + use_existing_dir=False, + no_tree=None, + tag_selector=None, + ): """Create a clone of this branch and its bzrdir. 
Args: @@ -1255,14 +1324,26 @@ def create_clone_on_transport(self, to_transport, *, revision_id=None, if revision_id is None: revision_id = self.last_revision() dir_to = self.controldir.clone_on_transport( - to_transport, revision_id=revision_id, stacked_on=stacked_on, - create_prefix=create_prefix, use_existing_dir=use_existing_dir, - no_tree=no_tree, tag_selector=tag_selector) + to_transport, + revision_id=revision_id, + stacked_on=stacked_on, + create_prefix=create_prefix, + use_existing_dir=use_existing_dir, + no_tree=no_tree, + tag_selector=tag_selector, + ) return dir_to.open_branch() - def create_checkout(self, to_location, *, revision_id=None, - lightweight=False, accelerator_tree=None, - hardlink=False, recurse_nested=True): + def create_checkout( + self, + to_location, + *, + revision_id=None, + lightweight=False, + accelerator_tree=None, + hardlink=False, + recurse_nested=True, + ): """Create a checkout of a branch. Args: @@ -1294,8 +1375,10 @@ def create_checkout(self, to_location, *, revision_id=None, pass else: raise errors.AlreadyControlDirError(t.base) from exc - if (checkout.control_transport.base - == self.controldir.control_transport.base): + if ( + checkout.control_transport.base + == self.controldir.control_transport.base + ): # When checking out to the same control directory, # always create a lightweight checkout lightweight = True @@ -1311,21 +1394,24 @@ def create_checkout(self, to_location, *, revision_id=None, # branch tip correctly, and seed it with history. checkout_branch.pull(self, stop_revision=revision_id) from_branch = None - tree = checkout.create_workingtree(revision_id, - from_branch=from_branch, - accelerator_tree=accelerator_tree, - hardlink=hardlink) + tree = checkout.create_workingtree( + revision_id, + from_branch=from_branch, + accelerator_tree=accelerator_tree, + hardlink=hardlink, + ) basis_tree = tree.basis_tree() with basis_tree.lock_read(): for path in basis_tree.iter_references(): reference_parent = tree.reference_parent(path) if reference_parent is None: - warning('Branch location for %s unknown.', path) + warning("Branch location for %s unknown.", path) continue reference_parent.create_checkout( tree.abspath(path), revision_id=basis_tree.get_reference_revision(path), - lightweight=lightweight) + lightweight=lightweight, + ) return tree def reconcile(self, thorough=True): @@ -1345,14 +1431,15 @@ def automatic_tag_name(self, revision_id): revision_id: Revision id of the revision. Returns: A tag name or None if no tag name could be determined. """ - for hook in Branch.hooks['automatic_tag_name']: + for hook in Branch.hooks["automatic_tag_name"]: ret = hook(self, revision_id) if ret is not None: return ret return None - def _check_if_descendant_or_diverged(self, revision_a, revision_b, graph, - other_branch): + def _check_if_descendant_or_diverged( + self, revision_a, revision_b, graph, other_branch + ): """Ensure that revision_b is a descendant of revision_a. This is a helper function for update_revisions. @@ -1361,11 +1448,11 @@ def _check_if_descendant_or_diverged(self, revision_a, revision_b, graph, Returns: True if revision_b is a descendant of revision_a. 
""" relation = self._revision_relations(revision_a, revision_b, graph) - if relation == 'b_descends_from_a': + if relation == "b_descends_from_a": return True - elif relation == 'diverged': + elif relation == "diverged": raise errors.DivergedBranches(self, other_branch) - elif relation == 'a_descends_from_b': + elif relation == "a_descends_from_b": return False else: raise AssertionError(f"invalid relation: {relation!r}") @@ -1377,12 +1464,12 @@ def _revision_relations(self, revision_a, revision_b, graph): """ heads = graph.heads([revision_a, revision_b]) if heads == {revision_b}: - return 'b_descends_from_a' + return "b_descends_from_a" elif heads == {revision_a, revision_b}: # These branches have diverged - return 'diverged' + return "diverged" elif heads == {revision_a}: - return 'a_descends_from_b' + return "a_descends_from_b" else: raise AssertionError(f"invalid heads: {heads!r}") @@ -1399,7 +1486,7 @@ def heads_to_fetch(self): # if_present_fetch are the tags. must_fetch = {self.last_revision()} if_present_fetch = set() - if self.get_config_stack().get('branch.fetch_tags'): + if self.get_config_stack().get("branch.fetch_tags"): try: if_present_fetch = set(self.tags.get_reverse_tag_dict()) except errors.TagsNotSupported: @@ -1414,6 +1501,7 @@ def create_memorytree(self): Returns: An in-memory MutableTree instance """ from . import memorytree + return memorytree.MemoryTree.create_on_branch(self) @@ -1475,15 +1563,16 @@ def get_format_description(self): raise NotImplementedError(self.get_format_description) def _run_post_branch_init_hooks(self, controldir, name, branch): - hooks = Branch.hooks['post_branch_init'] + hooks = Branch.hooks["post_branch_init"] if not hooks: return params = BranchInitHookParams(self, controldir, name, branch) for hook in hooks: hook(params) - def initialize(self, controldir, name=None, repository=None, - append_revisions_only=None): + def initialize( + self, controldir, name=None, repository=None, append_revisions_only=None + ): """Create a branch of this format in controldir. Args: @@ -1516,6 +1605,7 @@ def make_tags(self, branch): on a RemoteBranch. """ from .tag import DisabledTags + return DisabledTags(branch) def network_name(self): @@ -1528,8 +1618,15 @@ def network_name(self): """ raise NotImplementedError(self.network_name) - def open(self, controldir, name=None, _found=False, ignore_fallbacks=False, - found_repository=None, possible_transports=None): + def open( + self, + controldir, + name=None, + _found=False, + ignore_fallbacks=False, + found_repository=None, + possible_transports=None, + ): """Return the branch object for controldir. Args: @@ -1595,21 +1692,27 @@ def __init__(self): """ Hooks.__init__(self, "breezy.branch", "Branch.hooks") self.add_hook( - 'open', + "open", "Called with the Branch object that has been opened after a " - "branch is opened.", (1, 8)) + "branch is opened.", + (1, 8), + ) self.add_hook( - 'post_push', + "post_push", "Called after a push operation completes. post_push is called " "with a breezy.branch.BranchPushResult object and only runs in " - "the bzr client.", (0, 15)) + "the bzr client.", + (0, 15), + ) self.add_hook( - 'post_pull', + "post_pull", "Called after a pull operation completes. post_pull is called " "with a breezy.branch.PullResult object and only runs in the " - "bzr client.", (0, 15)) + "bzr client.", + (0, 15), + ) self.add_hook( - 'pre_commit', + "pre_commit", "Called after a commit is calculated but before it is " "completed. 
pre_commit is called with (local, master, old_revno, " "old_revid, future_revno, future_revid, tree_delta, future_tree" @@ -1618,34 +1721,44 @@ def __init__(self): "basis revision. hooks MUST NOT modify this delta. " " future_tree is an in-memory tree obtained from " "CommitBuilder.revision_tree() and hooks MUST NOT modify this " - "tree.", (0, 91)) + "tree.", + (0, 91), + ) self.add_hook( - 'post_commit', + "post_commit", "Called in the bzr client after a commit has completed. " "post_commit is called with (local, master, old_revno, old_revid, " "new_revno, new_revid). old_revid is NULL_REVISION for the first " - "commit to a branch.", (0, 15)) + "commit to a branch.", + (0, 15), + ) self.add_hook( - 'post_uncommit', + "post_uncommit", "Called in the bzr client after an uncommit completes. " "post_uncommit is called with (local, master, old_revno, " "old_revid, new_revno, new_revid) where local is the local branch " "or None, master is the target branch, and an empty branch " - "receives new_revno of 0, new_revid of None.", (0, 15)) + "receives new_revno of 0, new_revid of None.", + (0, 15), + ) self.add_hook( - 'pre_change_branch_tip', + "pre_change_branch_tip", "Called in bzr client and server before a change to the tip of a " "branch is made. pre_change_branch_tip is called with a " "breezy.branch.ChangeBranchTipParams. Note that push, pull, " - "commit, uncommit will all trigger this hook.", (1, 6)) + "commit, uncommit will all trigger this hook.", + (1, 6), + ) self.add_hook( - 'post_change_branch_tip', + "post_change_branch_tip", "Called in bzr client and server after a change to the tip of a " "branch is made. post_change_branch_tip is called with a " "breezy.branch.ChangeBranchTipParams. Note that push, pull, " - "commit, uncommit will all trigger this hook.", (1, 4)) + "commit, uncommit will all trigger this hook.", + (1, 4), + ) self.add_hook( - 'transform_fallback_location', + "transform_fallback_location", "Called when a stacked branch is activating its fallback " "locations. transform_fallback_location is called with (branch, " "url), and should return a new url. Returning the same url " @@ -1656,26 +1769,33 @@ def __init__(self): "fallback locations have not been activated. When there are " "multiple hooks installed for transform_fallback_location, " "all are called with the url returned from the previous hook." - "The order is however undefined.", (1, 9)) + "The order is however undefined.", + (1, 9), + ) self.add_hook( - 'automatic_tag_name', + "automatic_tag_name", "Called to determine an automatic tag name for a revision. " "automatic_tag_name is called with (branch, revision_id) and " "should return a tag name or None if no tag name could be " "determined. The first non-None tag name returned will be used.", - (2, 2)) + (2, 2), + ) self.add_hook( - 'post_branch_init', + "post_branch_init", "Called after new branch initialization completes. " "post_branch_init is called with a " "breezy.branch.BranchInitHookParams. " "Note that init, branch and checkout (both heavyweight and " - "lightweight) will all trigger this hook.", (2, 2)) + "lightweight) will all trigger this hook.", + (2, 2), + ) self.add_hook( - 'post_switch', + "post_switch", "Called after a checkout switches branch. " "post_switch is called with a " - "breezy.branch.SwitchHookParams.", (2, 2)) + "breezy.branch.SwitchHookParams.", + (2, 2), + ) # install the default hooks into the Branch class. 
@@ -1718,8 +1838,13 @@ def __eq__(self, other): def __repr__(self): return "<{} of {} from ({}, {}) to ({}, {})>".format( - self.__class__.__name__, self.branch, - self.old_revno, self.old_revid, self.new_revno, self.new_revid) + self.__class__.__name__, + self.branch, + self.old_revno, + self.old_revid, + self.new_revno, + self.new_revid, + ) class BranchInitHookParams: @@ -1795,8 +1920,8 @@ def __eq__(self, other): def __repr__(self): return "<{} for {} to ({}, {})>".format( - self.__class__.__name__, self.control_dir, self.to_branch, - self.revision_id) + self.__class__.__name__, self.control_dir, self.to_branch, self.revision_id + ) class BranchFormatRegistry(ControlComponentFormatRegistry): @@ -1809,8 +1934,7 @@ def __init__(self, other_registry=None): def get_default(self): """Return the current default format.""" - if (self._default_format_key is not None - and self._default_format is None): + if self._default_format_key is not None and self._default_format is None: self._default_format = self.get(self._default_format_key) return self._default_format @@ -1854,13 +1978,12 @@ def __repr__(self): class _Result: - def _show_tag_conficts(self, to_file): - if not getattr(self, 'tag_conflicts', None): + if not getattr(self, "tag_conflicts", None): return - to_file.write('Conflicting tags:\n') + to_file.write("Conflicting tags:\n") for name, _value1, _value2 in self.tag_conflicts: - to_file.write(f' {name}\n') + to_file.write(f" {name}\n") class PullResult(_Result): @@ -1896,14 +2019,14 @@ def report(self, to_file: TextIO) -> None: tag_updates = getattr(self, "tag_updates", None) if not is_quiet(): if self.old_revid != self.new_revid: - to_file.write('Now on revision %d.\n' % self.new_revno) + to_file.write("Now on revision %d.\n" % self.new_revno) if tag_updates: - to_file.write(f'{len(tag_updates)} tag(s) updated.\n') + to_file.write(f"{len(tag_updates)} tag(s) updated.\n") if self.old_revid == self.new_revid and not tag_updates: if not tag_conflicts: - to_file.write('No revisions or tags to pull.\n') + to_file.write("No revisions or tags to pull.\n") else: - to_file.write('No revisions to pull.\n') + to_file.write("No revisions to pull.\n") self._show_tag_conficts(to_file) @@ -1948,19 +2071,22 @@ def report(self, to_file: TextIO) -> None: if not is_quiet(): if self.old_revid != self.new_revid: if self.new_revno is not None: - note(gettext('Pushed up to revision %d.'), - self.new_revno) + note(gettext("Pushed up to revision %d."), self.new_revno) else: - note(gettext('Pushed up to revision id %s.'), - self.new_revid.decode('utf-8')) + note( + gettext("Pushed up to revision id %s."), + self.new_revid.decode("utf-8"), + ) if tag_updates: - note(ngettext('%d tag updated.', '%d tags updated.', - len(tag_updates)) % len(tag_updates)) + note( + ngettext("%d tag updated.", "%d tags updated.", len(tag_updates)) + % len(tag_updates) + ) if self.old_revid == self.new_revid and not tag_updates: if not tag_conflicts: - note(gettext('No new revisions or tags to push.')) + note(gettext("No new revisions or tags to push.")) else: - note(gettext('No new revisions to push.')) + note(gettext("No new revisions to push.")) self._show_tag_conficts(to_file) @@ -1982,10 +2108,14 @@ def report_results(self, verbose: bool) -> None: if any. 
""" from .i18n import gettext - note(gettext('checked branch {0} format {1}').format( - self.branch.user_url, self.branch._format)) + + note( + gettext("checked branch {0} format {1}").format( + self.branch.user_url, self.branch._format + ) + ) for error in self.errors: - note(gettext('found error:%s'), error) + note(gettext("found error:%s"), error) class InterBranch(InterObject[Branch]): @@ -2009,10 +2139,14 @@ def _get_branch_formats_to_test(klass): """ raise NotImplementedError(klass._get_branch_formats_to_test) - def pull(self, overwrite: bool = False, - stop_revision: Optional[RevisionID] = None, - possible_transports: Optional[List[Transport]] = None, - local: bool = False, tag_selector=None) -> PullResult: + def pull( + self, + overwrite: bool = False, + stop_revision: Optional[RevisionID] = None, + possible_transports: Optional[List[Transport]] = None, + local: bool = False, + tag_selector=None, + ) -> PullResult: """Mirror source into target branch. The target branch is considered to be 'local', having low latency. @@ -2021,9 +2155,14 @@ def pull(self, overwrite: bool = False, """ raise NotImplementedError(self.pull) - def push(self, overwrite: bool = False, stop_revision: Optional[RevisionID] = None, - lossy: bool = False, _override_hook_source_branch: Optional[Branch] = None, - tag_selector=None): + def push( + self, + overwrite: bool = False, + stop_revision: Optional[RevisionID] = None, + lossy: bool = False, + _override_hook_source_branch: Optional[Branch] = None, + tag_selector=None, + ): """Mirror the source branch into the target branch. The source branch is considered to be 'local', having low latency. @@ -2042,9 +2181,12 @@ def copy_content_into(self, revision_id=None, tag_selector=None): """ raise NotImplementedError(self.copy_content_into) - def fetch(self, stop_revision: Optional[RevisionID] = None, - limit: Optional[int] = None, - lossy: bool = False) -> repository.FetchResult: + def fetch( + self, + stop_revision: Optional[RevisionID] = None, + limit: Optional[int] = None, + lossy: bool = False, + ) -> repository.FetchResult: """Fetch revisions. 
Args: @@ -2088,6 +2230,7 @@ def _get_branch_formats_to_test(klass): @classmethod def unwrap_format(klass, format): from .bzr.remote import RemoteBranchFormat + if isinstance(format, RemoteBranchFormat): format._ensure_real() return format._custom_format @@ -2105,7 +2248,7 @@ def copy_content_into(self, revision_id=None, tag_selector=None): try: parent = self.source.get_parent() except errors.InaccessibleParent as e: - mutter('parent was not accessible to copy: %s', str(e)) + mutter("parent was not accessible to copy: %s", str(e)) else: if parent: self.target.set_parent(parent) @@ -2116,23 +2259,21 @@ def fetch(self, stop_revision=None, limit=None, lossy=False): if self.target.base == self.source.base: return (0, []) from .bzr.fetch import FetchSpecFactory, TargetRepoKinds + with self.source.lock_read(), self.target.lock_write(): fetch_spec_factory = FetchSpecFactory() fetch_spec_factory.source_branch = self.source fetch_spec_factory.source_branch_stop_revision_id = stop_revision fetch_spec_factory.source_repo = self.source.repository fetch_spec_factory.target_repo = self.target.repository - fetch_spec_factory.target_repo_kind = ( - TargetRepoKinds.PREEXISTING) + fetch_spec_factory.target_repo_kind = TargetRepoKinds.PREEXISTING fetch_spec_factory.limit = limit fetch_spec = fetch_spec_factory.make_fetch_spec() return self.target.repository.fetch( - self.source.repository, - lossy=lossy, - fetch_spec=fetch_spec) + self.source.repository, lossy=lossy, fetch_spec=fetch_spec + ) - def _update_revisions(self, stop_revision=None, overwrite=False, - graph=None): + def _update_revisions(self, stop_revision=None, overwrite=False, graph=None): with self.source.lock_read(), self.target.lock_write(): other_revno, other_last_revision = self.source.last_revision_info() stop_revno = None # unknown @@ -2156,24 +2297,34 @@ def _update_revisions(self, stop_revision=None, overwrite=False, if graph is None: graph = self.target.repository.get_graph() if self.target._check_if_descendant_or_diverged( - stop_revision, last_rev, graph, self.source): + stop_revision, last_rev, graph, self.source + ): # stop_revision is a descendant of last_rev, but we aren't # overwriting, so we're done. return if stop_revno is None: if graph is None: graph = self.target.repository.get_graph() - this_revno, this_last_revision = \ - self.target.last_revision_info() + this_revno, this_last_revision = self.target.last_revision_info() stop_revno = graph.find_distance_to_null( - stop_revision, [(other_last_revision, other_revno), - (this_last_revision, this_revno)]) + stop_revision, + [ + (other_last_revision, other_revno), + (this_last_revision, this_revno), + ], + ) self.target.set_last_revision_info(stop_revno, stop_revision) - def pull(self, overwrite=False, stop_revision=None, - possible_transports=None, run_hooks=True, - _override_hook_target=None, local=False, - tag_selector=None): + def pull( + self, + overwrite=False, + stop_revision=None, + possible_transports=None, + run_hooks=True, + _override_hook_target=None, + local=False, + tag_selector=None, + ): """Pull from source into self, updating my master if any. 
Args: @@ -2194,29 +2345,40 @@ def pull(self, overwrite=False, stop_revision=None, normalized = urlutils.normalize_url(bound_location) try: relpath = self.source.user_transport.relpath(normalized) - source_is_master = (relpath == '') + source_is_master = relpath == "" except (errors.PathNotChild, urlutils.InvalidURL): source_is_master = False if not local and bound_location and not source_is_master: # not pulling from master, so we need to update master. - master_branch = self.target.get_master_branch( - possible_transports) + master_branch = self.target.get_master_branch(possible_transports) exit_stack.enter_context(master_branch.lock_write()) if master_branch: # pull from source into master. master_branch.pull( - self.source, overwrite=overwrite, stop_revision=stop_revision, + self.source, + overwrite=overwrite, + stop_revision=stop_revision, run_hooks=False, - tag_selector=tag_selector) + tag_selector=tag_selector, + ) return self._pull( - overwrite, stop_revision, _hook_master=master_branch, + overwrite, + stop_revision, + _hook_master=master_branch, run_hooks=run_hooks, _override_hook_target=_override_hook_target, merge_tags_to_master=not source_is_master, - tag_selector=tag_selector) - - def push(self, overwrite=False, stop_revision=None, lossy=False, - _override_hook_source_branch=None, tag_selector=None): + tag_selector=tag_selector, + ) + + def push( + self, + overwrite=False, + stop_revision=None, + lossy=False, + _override_hook_source_branch=None, + tag_selector=None, + ): """See InterBranch.push. This is the basic concrete implementation of push() @@ -2235,7 +2397,7 @@ def push(self, overwrite=False, stop_revision=None, lossy=False, def _run_hooks(): if _override_hook_source_branch: result.source_branch = _override_hook_source_branch - for hook in Branch.hooks['post_push']: + for hook in Branch.hooks["post_push"]: hook(result) with self.source.lock_read(), self.target.lock_write(): @@ -2250,12 +2412,14 @@ def _run_hooks(): # push into the master from the source branch. master_inter = InterBranch.get(self.source, master_branch) master_inter._basic_push( - overwrite, stop_revision, tag_selector=tag_selector) + overwrite, stop_revision, tag_selector=tag_selector + ) # and push into the target branch from the source. Note # that we push from the source branch again, because it's # considered the highest bandwidth repository. result = self._basic_push( - overwrite, stop_revision, tag_selector=tag_selector) + overwrite, stop_revision, tag_selector=tag_selector + ) result.master_branch = master_branch result.local_branch = self.target _run_hooks() @@ -2263,7 +2427,8 @@ def _run_hooks(): master_branch = None # no master branch result = self._basic_push( - overwrite, stop_revision, tag_selector=tag_selector) + overwrite, stop_revision, tag_selector=tag_selector + ) # TODO: Why set master_branch and local_branch if there's no # binding? Maybe cleaner to just leave them unset? -- mbp # 20070504 @@ -2287,19 +2452,28 @@ def _basic_push(self, overwrite, stop_revision, tag_selector=None): # the target. 
graph = self.source.repository.get_graph(self.target.repository) self._update_revisions( - stop_revision, overwrite=("history" in overwrite), graph=graph) + stop_revision, overwrite=("history" in overwrite), graph=graph + ) if self.source._push_should_merge_tags(): - result.tag_updates, result.tag_conflicts = ( - self.source.tags.merge_to( - self.target.tags, "tags" in overwrite, selector=tag_selector)) + result.tag_updates, result.tag_conflicts = self.source.tags.merge_to( + self.target.tags, "tags" in overwrite, selector=tag_selector + ) self.update_references() result.new_revno, result.new_revid = self.target.last_revision_info() return result - def _pull(self, overwrite=False, stop_revision=None, - possible_transports=None, _hook_master=None, run_hooks=True, - _override_hook_target=None, local=False, - merge_tags_to_master=True, tag_selector=None): + def _pull( + self, + overwrite=False, + stop_revision=None, + possible_transports=None, + _hook_master=None, + run_hooks=True, + _override_hook_target=None, + local=False, + merge_tags_to_master=True, + tag_selector=None, + ): """See Branch.pull. This function is the core worker, used by GenericInterBranch.pull to @@ -2333,22 +2507,22 @@ def _pull(self, overwrite=False, stop_revision=None, # TODO: Branch formats should have a flag that indicates # that revno's are expensive, and pull() should honor that flag. # -- JRV20090506 - result.old_revno, result.old_revid = \ - self.target.last_revision_info() + result.old_revno, result.old_revid = self.target.last_revision_info() overwrite = _fix_overwrite_type(overwrite) self._update_revisions( - stop_revision, overwrite=("history" in overwrite), graph=graph) + stop_revision, overwrite=("history" in overwrite), graph=graph + ) # TODO: The old revid should be specified when merging tags, # so a tags implementation that versions tags can only # pull in the most recent changes. 
-- JRV20090506 - result.tag_updates, result.tag_conflicts = ( - self.source.tags.merge_to( - self.target.tags, "tags" in overwrite, - ignore_master=not merge_tags_to_master, - selector=tag_selector)) + result.tag_updates, result.tag_conflicts = self.source.tags.merge_to( + self.target.tags, + "tags" in overwrite, + ignore_master=not merge_tags_to_master, + selector=tag_selector, + ) self.update_references() - result.new_revno, result.new_revid = ( - self.target.last_revision_info()) + result.new_revno, result.new_revid = self.target.last_revision_info() if _hook_master: result.master_branch = _hook_master result.local_branch = result.target_branch @@ -2356,12 +2530,12 @@ def _pull(self, overwrite=False, stop_revision=None, result.master_branch = result.target_branch result.local_branch = None if run_hooks: - for hook in Branch.hooks['post_pull']: + for hook in Branch.hooks["post_pull"]: hook(result) return result def update_references(self): - if not getattr(self.source._format, 'supports_reference_locations', False): + if not getattr(self.source._format, "supports_reference_locations", False): return reference_dict = self.source._get_all_reference_info() if len(reference_dict) == 0: @@ -2371,13 +2545,13 @@ def update_references(self): target_reference_dict = self.target._get_all_reference_info() for tree_path, (branch_location, file_id) in reference_dict.items(): try: - branch_location = urlutils.rebase_url(branch_location, - old_base, new_base) + branch_location = urlutils.rebase_url( + branch_location, old_base, new_base + ) except urlutils.InvalidRebaseURLs: # Fall back to absolute URL branch_location = urlutils.join(old_base, branch_location) - target_reference_dict.setdefault( - tree_path, (branch_location, file_id)) + target_reference_dict.setdefault(tree_path, (branch_location, file_id)) self.target._set_all_reference_info(target_reference_dict) diff --git a/breezy/branchbuilder.py b/breezy/branchbuilder.py index da8e293ef9..675d61c9f9 100644 --- a/breezy/branchbuilder.py +++ b/breezy/branchbuilder.py @@ -66,25 +66,27 @@ def __init__(self, transport=None, format=None, branch=None): """ if branch is not None: if format is not None: - raise AssertionError( - "branch and format kwargs are mutually exclusive") + raise AssertionError("branch and format kwargs are mutually exclusive") if transport is not None: raise AssertionError( - "branch and transport kwargs are mutually exclusive") + "branch and transport kwargs are mutually exclusive" + ) self._branch = branch else: - if not transport.has('.'): - transport.mkdir('.') + if not transport.has("."): + transport.mkdir(".") if format is None: - format = 'default' + format = "default" if isinstance(format, str): format = controldir.format_registry.make_controldir(format) self._branch = controldir.ControlDir.create_branch_convenience( - transport.base, format=format, force_new_tree=False) + transport.base, format=format, force_new_tree=False + ) self._tree = None - def build_commit(self, parent_ids=None, allow_leftmost_as_ghost=False, - **commit_kwargs): + def build_commit( + self, parent_ids=None, allow_leftmost_as_ghost=False, **commit_kwargs + ): """Build a commit on the branch. 
This makes a commit with no real file content for when you only want @@ -100,25 +102,26 @@ def build_commit(self, parent_ids=None, allow_leftmost_as_ghost=False, base_id = parent_ids[0] if base_id != self._branch.last_revision(): self._move_branch_pointer( - base_id, allow_leftmost_as_ghost=allow_leftmost_as_ghost) + base_id, allow_leftmost_as_ghost=allow_leftmost_as_ghost + ) tree = self._branch.create_memorytree() with tree.lock_write(): if parent_ids is not None: tree.set_parent_ids( - parent_ids, - allow_leftmost_as_ghost=allow_leftmost_as_ghost) - tree.add('') + parent_ids, allow_leftmost_as_ghost=allow_leftmost_as_ghost + ) + tree.add("") return self._do_commit(tree, **commit_kwargs) def _do_commit(self, tree, message=None, message_callback=None, **kwargs): reporter = commit.NullCommitReporter() if message is None and message_callback is None: - message = 'commit %d' % (self._branch.revno() + 1,) - return tree.commit(message, message_callback=message_callback, - reporter=reporter, **kwargs) + message = "commit %d" % (self._branch.revno() + 1,) + return tree.commit( + message, message_callback=message_callback, reporter=reporter, **kwargs + ) - def _move_branch_pointer(self, new_revision_id, - allow_leftmost_as_ghost=False): + def _move_branch_pointer(self, new_revision_id, allow_leftmost_as_ghost=False): """Point self._branch to a different revision id.""" with self._branch.lock_write(): # We don't seem to have a simple set_last_revision(), so we @@ -127,7 +130,8 @@ def _move_branch_pointer(self, new_revision_id, try: g = self._branch.repository.get_graph() new_revno = g.find_distance_to_null( - new_revision_id, [(cur_revision_id, cur_revno)]) + new_revision_id, [(cur_revision_id, cur_revno)] + ) self._branch.set_last_revision_info(new_revno, new_revision_id) except errors.GhostRevisionsHaveNoRevno: if not allow_leftmost_as_ghost: @@ -153,8 +157,9 @@ def start_series(self): Make sure to call 'finish_series' when you are done. """ if self._tree is not None: - raise AssertionError('You cannot start a new series while a' - ' series is already going.') + raise AssertionError( + "You cannot start a new series while a" " series is already going." + ) self._tree = self._branch.create_memorytree() self._tree.lock_write() @@ -163,9 +168,18 @@ def finish_series(self): self._tree.unlock() self._tree = None - def build_snapshot(self, parent_ids, actions, message=None, timestamp=None, - allow_leftmost_as_ghost=False, committer=None, - timezone=None, message_callback=None, revision_id=None): + def build_snapshot( + self, + parent_ids, + actions, + message=None, + timestamp=None, + allow_leftmost_as_ghost=False, + committer=None, + timezone=None, + message_callback=None, + revision_id=None, + ): """Build a commit, shaped in a specific way. Most of the actions are self-explanatory. 
'flush' is special action to @@ -201,7 +215,8 @@ def build_snapshot(self, parent_ids, actions, message=None, timestamp=None, base_id = parent_ids[0] if base_id != self._branch.last_revision(): self._move_branch_pointer( - base_id, allow_leftmost_as_ghost=allow_leftmost_as_ghost) + base_id, allow_leftmost_as_ghost=allow_leftmost_as_ghost + ) if self._tree is not None: tree = self._tree @@ -210,17 +225,17 @@ def build_snapshot(self, parent_ids, actions, message=None, timestamp=None, with tree.lock_write(): if parent_ids is not None: tree.set_parent_ids( - parent_ids, - allow_leftmost_as_ghost=allow_leftmost_as_ghost) + parent_ids, allow_leftmost_as_ghost=allow_leftmost_as_ghost + ) # Unfortunately, MemoryTree.add(directory) just creates an # inventory entry. And the only public function to create a # directory is MemoryTree.mkdir() which creates the directory, but # also always adds it. So we have to use a multi-pass setup. pending = _PendingActions() for action, info in actions: - if action == 'add': + if action == "add": path, file_id, kind, content = info - if kind == 'directory': + if kind == "directory": pending.to_add_directories.append((path, file_id)) else: pending.to_add_files.append(path) @@ -228,38 +243,42 @@ def build_snapshot(self, parent_ids, actions, message=None, timestamp=None, pending.to_add_kinds.append(kind) if content is not None: pending.new_contents[path] = content - elif action == 'modify': + elif action == "modify": path, content = info pending.new_contents[path] = content - elif action == 'unversion': + elif action == "unversion": pending.to_unversion_paths.add(info) - elif action == 'rename': + elif action == "rename": from_relpath, to_relpath = info pending.to_rename.append((from_relpath, to_relpath)) - elif action == 'flush': + elif action == "flush": self._flush_pending(tree, pending) pending = _PendingActions() else: raise ValueError(f'Unknown build action: "{action}"') self._flush_pending(tree, pending) return self._do_commit( - tree, message=message, rev_id=revision_id, - timestamp=timestamp, timezone=timezone, committer=committer, - message_callback=message_callback) + tree, + message=message, + rev_id=revision_id, + timestamp=timestamp, + timezone=timezone, + committer=committer, + message_callback=message_callback, + ) def _flush_pending(self, tree, pending): """Flush the pending actions in 'pending', i.e. 
apply them to tree.""" for path, file_id in pending.to_add_directories: - if path == '': - if tree.has_filename(path) \ - and path in pending.to_unversion_paths: + if path == "": + if tree.has_filename(path) and path in pending.to_unversion_paths: # We're overwriting this path, no need to unversion pending.to_unversion_paths.discard(path) # Special case, because the path already exists if file_id is not None: - tree.add([path], ['directory'], ids=[file_id]) + tree.add([path], ["directory"], ids=[file_id]) else: - tree.add([path], ['directory']) + tree.add([path], ["directory"]) else: if file_id is not None: tree.mkdir(path, file_id) @@ -270,8 +289,9 @@ def _flush_pending(self, tree, pending): if pending.to_unversion_paths: tree.unversion(pending.to_unversion_paths) if tree.supports_file_ids: - tree.add(pending.to_add_files, - pending.to_add_kinds, pending.to_add_file_ids) + tree.add( + pending.to_add_files, pending.to_add_kinds, pending.to_add_file_ids + ) else: tree.add(pending.to_add_files, pending.to_add_kinds) for path, content in pending.new_contents.items(): diff --git a/breezy/breakin.py b/breezy/breakin.py index c8aa4a8fae..9d026262ff 100644 --- a/breezy/breakin.py +++ b/breezy/breakin.py @@ -25,9 +25,12 @@ def _debug(signal_number, interrupted_frame): import pdb import sys - sys.stderr.write(f"** {_breakin_signal_name} received, entering debugger\n" - "** Type 'c' to continue or 'q' to stop the process\n" - f"** Or {_breakin_signal_name} again to quit (and possibly dump core)\n") + + sys.stderr.write( + f"** {_breakin_signal_name} received, entering debugger\n" + "** Type 'c' to continue or 'q' to stop the process\n" + f"** Or {_breakin_signal_name} again to quit (and possibly dump core)\n" + ) # It seems that on Windows, when sys.stderr is to a PIPE, then we need to # flush. Not sure why it is buffered, but that seems to be the case. sys.stderr.flush() @@ -51,14 +54,14 @@ def determine_signal(): # and other platforms defined SIGQUIT. There doesn't seem to be a # platform that defines both. # -- jam 2009-07-30 - sigquit = getattr(signal, 'SIGQUIT', None) - sigbreak = getattr(signal, 'SIGBREAK', None) + sigquit = getattr(signal, "SIGQUIT", None) + sigbreak = getattr(signal, "SIGBREAK", None) if sigquit is not None: _breakin_signal_number = sigquit - _breakin_signal_name = 'SIGQUIT' + _breakin_signal_name = "SIGQUIT" elif sigbreak is not None: _breakin_signal_number = sigbreak - _breakin_signal_name = 'SIGBREAK' + _breakin_signal_name = "SIGBREAK" return _breakin_signal_number @@ -70,7 +73,7 @@ def hook_debugger_to_signal(): hooked into SIGBREAK (C-Pause). """ # when sigquit (C-\) or sigbreak (C-Pause) is received go into pdb - if os.environ.get('BRZ_SIGQUIT_PDB', '1') == '0': + if os.environ.get("BRZ_SIGQUIT_PDB", "1") == "0": # User explicitly requested we don't support this return sig = determine_signal() diff --git a/breezy/bugtracker.py b/breezy/bugtracker.py index f43010a6ba..47e0a04eb5 100644 --- a/breezy/bugtracker.py +++ b/breezy/bugtracker.py @@ -32,11 +32,11 @@ """ - class MalformedBugIdentifier(errors.BzrError): - - _fmt = ('Did not understand bug identifier %(bug_id)s: %(reason)s. ' - 'See "brz help bugs" for more information on this feature.') + _fmt = ( + "Did not understand bug identifier %(bug_id)s: %(reason)s. " + 'See "brz help bugs" for more information on this feature.' 
+ ) def __init__(self, bug_id, reason): self.bug_id = bug_id @@ -44,9 +44,9 @@ def __init__(self, bug_id, reason): class InvalidBugTrackerURL(errors.BzrError): - - _fmt = ("The URL for bug tracker \"%(abbreviation)s\" doesn't " - "contain {id}: %(url)s") + _fmt = ( + 'The URL for bug tracker "%(abbreviation)s" doesn\'t ' "contain {id}: %(url)s" + ) def __init__(self, abbreviation, url): self.abbreviation = abbreviation @@ -54,9 +54,7 @@ def __init__(self, abbreviation, url): class UnknownBugTrackerAbbreviation(errors.BzrError): - - _fmt = ("Cannot find registered bug tracker called %(abbreviation)s " - "on %(branch)s") + _fmt = "Cannot find registered bug tracker called %(abbreviation)s " "on %(branch)s" def __init__(self, abbreviation, branch): self.abbreviation = abbreviation @@ -64,15 +62,13 @@ def __init__(self, abbreviation, branch): class InvalidLineInBugsProperty(errors.BzrError): - - _fmt = ("Invalid line in bugs property: '%(line)s'") + _fmt = "Invalid line in bugs property: '%(line)s'" def __init__(self, line): self.line = line class InvalidBugUrl(errors.BzrError): - _fmt = "Invalid bug URL: %(url)s" def __init__(self, url): @@ -80,8 +76,7 @@ def __init__(self, url): class InvalidBugStatus(errors.BzrError): - - _fmt = ("Invalid bug status: '%(status)s'") + _fmt = "Invalid bug status: '%(status)s'" def __init__(self, status): self.status = status @@ -109,8 +104,7 @@ def get_tracker(self, abbreviated_bugtracker_name, branch): tracker = tracker_type.get(abbreviated_bugtracker_name, branch) if tracker is not None: return tracker - raise UnknownBugTrackerAbbreviation( - abbreviated_bugtracker_name, branch) + raise UnknownBugTrackerAbbreviation(abbreviated_bugtracker_name, branch) def help_topic(self, topic): return _bugs_help @@ -191,7 +185,7 @@ def get(self, abbreviated_bugtracker_name, branch): def check_bug_id(self, bug_id): try: - (project, bug_id) = bug_id.rsplit('/', 1) + (project, bug_id) = bug_id.rsplit("/", 1) except ValueError as exc: raise MalformedBugIdentifier(bug_id, "Expected format: project/id") from exc try: @@ -200,32 +194,35 @@ def check_bug_id(self, bug_id): raise MalformedBugIdentifier(bug_id, "Bug id must be an integer") from exc def _get_bug_url(self, bug_id): - (project, bug_id) = bug_id.rsplit('/', 1) + (project, bug_id) = bug_id.rsplit("/", 1) """Return the URL for bug_id.""" - if '{id}' not in self._base_url: + if "{id}" not in self._base_url: raise InvalidBugTrackerURL(self.abbreviation, self._base_url) - if '{project}' not in self._base_url: + if "{project}" not in self._base_url: raise InvalidBugTrackerURL(self.abbreviation, self._base_url) - return self._base_url.replace( - '{project}', project).replace('{id}', str(bug_id)) + return self._base_url.replace("{project}", project).replace("{id}", str(bug_id)) tracker_registry.register( - 'launchpad', UniqueIntegerBugTracker('lp', 'https://launchpad.net/bugs/')) + "launchpad", UniqueIntegerBugTracker("lp", "https://launchpad.net/bugs/") +) tracker_registry.register( - 'debian', UniqueIntegerBugTracker('deb', 'http://bugs.debian.org/')) + "debian", UniqueIntegerBugTracker("deb", "http://bugs.debian.org/") +) tracker_registry.register( - 'gnome', UniqueIntegerBugTracker( - 'gnome', 'http://bugzilla.gnome.org/show_bug.cgi?id=')) + "gnome", + UniqueIntegerBugTracker("gnome", "http://bugzilla.gnome.org/show_bug.cgi?id="), +) tracker_registry.register( - 'github', ProjectIntegerBugTracker( - 'github', 'https://github.com/{project}/issues/{id}')) + "github", + ProjectIntegerBugTracker("github", 
"https://github.com/{project}/issues/{id}"), +) class URLParametrizedBugTracker(BugTracker): @@ -240,7 +237,8 @@ class URLParametrizedBugTracker(BugTracker): def get(self, abbreviation, branch): config = branch.get_config() url = config.get_user_option( - f"{self.type_name}_{abbreviation}_url", expand=False) + f"{self.type_name}_{abbreviation}_url", expand=False + ) if url is None: return None self._base_url = url @@ -255,8 +253,7 @@ def _get_bug_url(self, bug_id): return urlutils.join(self._base_url, self._bug_area) + str(bug_id) -class URLParametrizedIntegerBugTracker(IntegerBugTracker, - URLParametrizedBugTracker): +class URLParametrizedIntegerBugTracker(IntegerBugTracker, URLParametrizedBugTracker): """A type of bug tracker that only allows integer bug IDs. This can be found on a variety of different sites, and thus needs to have @@ -269,19 +266,18 @@ class URLParametrizedIntegerBugTracker(IntegerBugTracker, """ -tracker_registry.register( - 'trac', URLParametrizedIntegerBugTracker('trac', 'ticket/')) +tracker_registry.register("trac", URLParametrizedIntegerBugTracker("trac", "ticket/")) tracker_registry.register( - 'bugzilla', - URLParametrizedIntegerBugTracker('bugzilla', 'show_bug.cgi?id=')) + "bugzilla", URLParametrizedIntegerBugTracker("bugzilla", "show_bug.cgi?id=") +) class GenericBugTracker(URLParametrizedBugTracker): """Generic bug tracker specified by an URL template.""" def __init__(self): - super().__init__('bugtracker', None) + super().__init__("bugtracker", None) def get(self, abbreviation, branch): self._abbreviation = abbreviation @@ -289,16 +285,16 @@ def get(self, abbreviation, branch): def _get_bug_url(self, bug_id): """Given a validated bug_id, return the bug's web page's URL.""" - if '{id}' not in self._base_url: + if "{id}" not in self._base_url: raise InvalidBugTrackerURL(self._abbreviation, self._base_url) - return self._base_url.replace('{id}', str(bug_id)) + return self._base_url.replace("{id}", str(bug_id)) -tracker_registry.register('generic', GenericBugTracker()) +tracker_registry.register("generic", GenericBugTracker()) -FIXED = 'fixed' -RELATED = 'related' +FIXED = "fixed" +RELATED = "related" ALLOWED_BUG_STATUSES = {FIXED, RELATED} @@ -312,11 +308,11 @@ def encode_fixes_bug_urls(bug_urls): as part of a commit. """ lines = [] - for (url, tag) in bug_urls: - if ' ' in url: + for url, tag in bug_urls: + if " " in url: raise InvalidBugUrl(url) - lines.append(f'{url} {tag}') - return '\n'.join(lines) + lines.append(f"{url} {tag}") + return "\n".join(lines) def decode_bug_urls(bug_lines): diff --git a/breezy/builtins.py b/breezy/builtins.py index feaaef5da8..fc050ed573 100644 --- a/breezy/builtins.py +++ b/breezy/builtins.py @@ -24,7 +24,9 @@ from . 
import controldir, errors, lazy_import, osutils, transport -lazy_import.lazy_import(globals(), """ +lazy_import.lazy_import( + globals(), + """ import time import breezy @@ -46,7 +48,8 @@ ) from breezy.branch import Branch from breezy.i18n import gettext, ngettext -""") +""", +) from .commands import Command, builtin_command_registry, display_command from .option import ListOption, Option, RegistryOption, _parse_revision_str, custom_help @@ -62,8 +65,7 @@ def _get_branch_location(control_dir, possible_transports=None): return control_dir.root_transport.base if target is not None: return target - this_branch = control_dir.open_branch( - possible_transports=possible_transports) + this_branch = control_dir.open_branch(possible_transports=possible_transports) # This may be a heavy checkout, where we want the master branch master_location = this_branch.get_bound_location() if master_location is not None: @@ -81,11 +83,13 @@ def _is_colocated(control_dir, possible_transports=None): """ # This path is meant to be relative to the existing branch this_url = _get_branch_location( - control_dir, possible_transports=possible_transports) + control_dir, possible_transports=possible_transports + ) # Perhaps the target control dir supports colocated branches? try: root = controldir.ControlDir.open( - this_url, possible_transports=possible_transports) + this_url, possible_transports=possible_transports + ) except errors.NotBranchError: return (False, this_url) else: @@ -95,9 +99,10 @@ def _is_colocated(control_dir, possible_transports=None): return (False, this_url) else: return ( - root._format.colocated_branches and - control_dir.control_url == root.control_url, - this_url) + root._format.colocated_branches + and control_dir.control_url == root.control_url, + this_url, + ) def lookup_new_sibling_branch(control_dir, location, possible_transports=None): @@ -108,15 +113,17 @@ def lookup_new_sibling_branch(control_dir, location, possible_transports=None): :return: Full location to the new branch """ from .directory_service import directories + location = directories.dereference(location) - if '/' not in location and '\\' not in location: + if "/" not in location and "\\" not in location: (colocated, this_url) = _is_colocated(control_dir, possible_transports) if colocated: return urlutils.join_segment_parameters( - this_url, {"branch": urlutils.escape(location)}) + this_url, {"branch": urlutils.escape(location)} + ) else: - return urlutils.join(this_url, '..', urlutils.escape(location)) + return urlutils.join(this_url, "..", urlutils.escape(location)) return location @@ -131,12 +138,11 @@ def open_sibling_branch(control_dir, location, possible_transports=None): try: # Perhaps it's a colocated branch? return control_dir.open_branch( - location, possible_transports=possible_transports) + location, possible_transports=possible_transports + ) except (errors.NotBranchError, controldir.NoColocatedBranchSupport): this_url = _get_branch_location(control_dir) - return Branch.open( - urlutils.join( - this_url, '..', urlutils.escape(location))) + return Branch.open(urlutils.join(this_url, "..", urlutils.escape(location))) def open_nearby_branch(near=None, location=None, possible_transports=None): @@ -150,14 +156,11 @@ def open_nearby_branch(near=None, location=None, possible_transports=None): if location is None: location = "." try: - return Branch.open( - location, possible_transports=possible_transports) + return Branch.open(location, possible_transports=possible_transports) except errors.NotBranchError: near = "." 
- cdir = controldir.ControlDir.open( - near, possible_transports=possible_transports) - return open_sibling_branch( - cdir, location, possible_transports=possible_transports) + cdir = controldir.ControlDir.open(near, possible_transports=possible_transports) + return open_sibling_branch(cdir, location, possible_transports=possible_transports) def iter_sibling_branches(control_dir, possible_transports=None): @@ -172,8 +175,7 @@ def iter_sibling_branches(control_dir, possible_transports=None): reference = None if reference is not None: try: - ref_branch = Branch.open( - reference, possible_transports=possible_transports) + ref_branch = Branch.open(reference, possible_transports=possible_transports) except errors.NotBranchError: ref_branch = None else: @@ -186,8 +188,7 @@ def iter_sibling_branches(control_dir, possible_transports=None): else: repo = ref_branch.controldir.find_repository() for branch in repo.find_branches(using=True): - name = urlutils.relative_url( - repo.user_url, branch.user_url).rstrip("/") + name = urlutils.relative_url(repo.user_url, branch.user_url).rstrip("/") yield name, branch @@ -215,14 +216,13 @@ def tree_files_for_add(file_list): file_list = file_list[:] file_list[0] = tree.abspath(relpath) else: - tree = WorkingTree.open_containing('.')[0] + tree = WorkingTree.open_containing(".")[0] if tree.supports_views(): view_files = tree.views.lookup_view() if view_files: file_list = view_files view_str = views.view_display_str(view_files) - note(gettext("Ignoring files outside view. View is %s"), - view_str) + note(gettext("Ignoring files outside view. View is %s"), view_str) return tree, file_list @@ -230,9 +230,10 @@ def _get_one_revision(command_name, revisions): if revisions is None: return None if len(revisions) != 1: - raise errors.CommandError(gettext( - 'brz %s --revision takes exactly one revision identifier') % ( - command_name,)) + raise errors.CommandError( + gettext("brz %s --revision takes exactly one revision identifier") + % (command_name,) + ) return revisions[0] @@ -262,6 +263,7 @@ def _get_one_revision_tree(command_name, revisions, branch=None, tree=None): def _get_view_info_for_change_reporter(tree): """Get the view information from a tree for change reporting.""" from . import views + view_info = None try: current_view = tree.views.get_view_info()[0] @@ -287,6 +289,7 @@ def _open_directory_or_containing_tree_or_branch(filename, directory): # Argument class, representing a file in a branch, where the first occurrence # opens the branch?) + class cmd_status(Command): __doc__ = """Display status summary. 
@@ -346,48 +349,67 @@ class cmd_status(Command): # TODO: --no-recurse/-N, --recurse options - takes_args = ['file*'] - takes_options = ['show-ids', 'revision', 'change', 'verbose', - Option('short', help='Use short status indicators.', - short_name='S'), - Option('versioned', help='Only show versioned files.', - short_name='V'), - Option('no-pending', help='Don\'t show pending merges.'), - Option('no-classify', - help='Do not mark object type using indicator.'), - ] - aliases = ['st', 'stat'] - - encoding_type = 'replace' - _see_also = ['diff', 'revert', 'status-flags'] + takes_args = ["file*"] + takes_options = [ + "show-ids", + "revision", + "change", + "verbose", + Option("short", help="Use short status indicators.", short_name="S"), + Option("versioned", help="Only show versioned files.", short_name="V"), + Option("no-pending", help="Don't show pending merges."), + Option("no-classify", help="Do not mark object type using indicator."), + ] + aliases = ["st", "stat"] + + encoding_type = "replace" + _see_also = ["diff", "revert", "status-flags"] @display_command - def run(self, show_ids=False, file_list=None, revision=None, short=False, - versioned=False, no_pending=False, verbose=False, - no_classify=False): + def run( + self, + show_ids=False, + file_list=None, + revision=None, + short=False, + versioned=False, + no_pending=False, + verbose=False, + no_classify=False, + ): from .status import show_tree_status from .workingtree import WorkingTree if revision and len(revision) > 2: raise errors.CommandError( - gettext('brz status --revision takes exactly' - ' one or two revision specifiers')) + gettext( + "brz status --revision takes exactly" + " one or two revision specifiers" + ) + ) tree, relfile_list = WorkingTree.open_containing_paths(file_list) # Avoid asking for specific files when that is not needed. - if relfile_list == ['']: + if relfile_list == [""]: relfile_list = None # Don't disable pending merges for full trees other than '.'. - if file_list == ['.']: + if file_list == ["."]: no_pending = True # A specific path within a tree was given. 
elif relfile_list is not None: no_pending = True - show_tree_status(tree, show_ids=show_ids, - specific_files=relfile_list, revision=revision, - to_file=self.outf, short=short, versioned=versioned, - show_pending=(not no_pending), verbose=verbose, - classify=not no_classify) + show_tree_status( + tree, + show_ids=show_ids, + specific_files=relfile_list, + revision=revision, + to_file=self.outf, + short=short, + versioned=versioned, + show_pending=(not no_pending), + verbose=verbose, + classify=not no_classify, + ) class cmd_cat_revision(Command): @@ -398,52 +420,58 @@ class cmd_cat_revision(Command): """ hidden = True - takes_args = ['revision_id?'] - takes_options = ['directory', 'revision'] + takes_args = ["revision_id?"] + takes_options = ["directory", "revision"] # cat-revision is more for frontends so should be exact - encoding = 'strict' + encoding = "strict" def print_revision(self, revisions, revid): - stream = revisions.get_record_stream([(revid,)], 'unordered', True) + stream = revisions.get_record_stream([(revid,)], "unordered", True) record = next(stream) - if record.storage_kind == 'absent': + if record.storage_kind == "absent": raise errors.NoSuchRevision(revisions, revid) - revtext = record.get_bytes_as('fulltext') - self.outf.write(revtext.decode('utf-8')) + revtext = record.get_bytes_as("fulltext") + self.outf.write(revtext.decode("utf-8")) @display_command - def run(self, revision_id=None, revision=None, directory='.'): + def run(self, revision_id=None, revision=None, directory="."): if revision_id is not None and revision is not None: - raise errors.CommandError(gettext('You can only supply one of' - ' revision_id or --revision')) + raise errors.CommandError( + gettext("You can only supply one of" " revision_id or --revision") + ) if revision_id is None and revision is None: raise errors.CommandError( - gettext('You must supply either --revision or a revision_id')) + gettext("You must supply either --revision or a revision_id") + ) b = controldir.ControlDir.open_containing_tree_or_branch(directory)[1] revisions = getattr(b.repository, "revisions", None) if revisions is None: raise errors.CommandError( - gettext('Repository %r does not support ' - 'access to raw revision texts') % b.repository) + gettext( + "Repository %r does not support " "access to raw revision texts" + ) + % b.repository + ) with b.repository.lock_read(): # TODO: jam 20060112 should cat-revision always output utf-8? if revision_id is not None: - revision_id = revision_id.encode('utf-8') + revision_id = revision_id.encode("utf-8") try: self.print_revision(revisions, revision_id) except errors.NoSuchRevision as exc: msg = gettext( - "The repository {0} contains no revision {1}.").format( - b.repository.base, revision_id.decode('utf-8')) + "The repository {0} contains no revision {1}." + ).format(b.repository.base, revision_id.decode("utf-8")) raise errors.CommandError(msg) from exc elif revision is not None: for rev in revision: if rev is None: raise errors.CommandError( - gettext('You cannot specify a NULL revision.')) + gettext("You cannot specify a NULL revision.") + ) rev_id = rev.as_revision_id(b) self.print_revision(revisions, rev_id) @@ -456,17 +484,19 @@ class cmd_remove_tree(Command): To re-create the working tree, use "brz checkout". 
""" - _see_also = ['checkout', 'working-trees'] - takes_args = ['location*'] + _see_also = ["checkout", "working-trees"] + takes_args = ["location*"] takes_options = [ - Option('force', - help='Remove the working tree even if it has ' - 'uncommitted or shelved changes.'), - ] + Option( + "force", + help="Remove the working tree even if it has " + "uncommitted or shelved changes.", + ), + ] def run(self, location_list, force=False): if not location_list: - location_list = ['.'] + location_list = ["."] for location in location_list: d = controldir.ControlDir.open(location) @@ -474,22 +504,24 @@ def run(self, location_list, force=False): try: working = d.open_workingtree() except errors.NoWorkingTree as exc: - raise errors.CommandError( - gettext("No working tree to remove")) from exc + raise errors.CommandError(gettext("No working tree to remove")) from exc except errors.NotLocalUrl as exc: raise errors.CommandError( - gettext("You cannot remove the working tree" - " of a remote path")) from exc + gettext("You cannot remove the working tree" " of a remote path") + ) from exc if not force: - if (working.has_changes()): + if working.has_changes(): raise errors.UncommittedChanges(working) if working.get_shelf_manager().last_shelf() is not None: raise errors.ShelvedChanges(working) if working.user_url != working.branch.user_url: raise errors.CommandError( - gettext("You cannot remove the working tree" - " from a lightweight checkout")) + gettext( + "You cannot remove the working tree" + " from a lightweight checkout" + ) + ) d.destroy_workingtree() @@ -511,15 +543,17 @@ class cmd_repair_workingtree(Command): """ takes_options = [ - 'revision', 'directory', - Option('force', - help='Reset the tree even if it doesn\'t appear to be' - ' corrupted.'), + "revision", + "directory", + Option( + "force", help="Reset the tree even if it doesn't appear to be" " corrupted." + ), ] hidden = True - def run(self, revision=None, directory='.', force=False): + def run(self, revision=None, directory=".", force=False): from .workingtree import WorkingTree + tree, _ = WorkingTree.open_containing(directory) self.enter_context(tree.lock_tree_write()) if not force: @@ -529,10 +563,13 @@ def run(self, revision=None, directory='.', force=False): pass # There seems to be a real error here, so we'll reset else: # Refuse - raise errors.CommandError(gettext( - 'The tree does not appear to be corrupt. You probably' - ' want "brz revert" instead. Use "--force" if you are' - ' sure you want to reset the working tree.')) + raise errors.CommandError( + gettext( + "The tree does not appear to be corrupt. You probably" + ' want "brz revert" instead. Use "--force" if you are' + " sure you want to reset the working tree." + ) + ) if revision is None: revision_ids = None else: @@ -541,12 +578,15 @@ def run(self, revision=None, directory='.', force=False): tree.reset_state(revision_ids) except errors.BzrError as exc: if revision_ids is None: - extra = gettext(', the header appears corrupt, try passing ' - '-r -1 to set the state to the last commit') + extra = gettext( + ", the header appears corrupt, try passing " + "-r -1 to set the state to the last commit" + ) else: - extra = '' + extra = "" raise errors.CommandError( - gettext('failed to reset the tree state{0}').format(extra)) from exc + gettext("failed to reset the tree state{0}").format(extra) + ) from exc class cmd_revno(Command): @@ -555,19 +595,21 @@ class cmd_revno(Command): This is equal to the number of revisions on this branch. 
""" - _see_also = ['info'] - takes_args = ['location?'] + _see_also = ["info"] + takes_args = ["location?"] takes_options = [ - Option('tree', help='Show revno of working tree.'), - 'revision', - ] + Option("tree", help="Show revno of working tree."), + "revision", + ] @display_command - def run(self, tree=False, location='.', revision=None): + def run(self, tree=False, location=".", revision=None): from .workingtree import WorkingTree + if revision is not None and tree: raise errors.CommandError( - gettext("--tree and --revision can not be used together")) + gettext("--tree and --revision can not be used together") + ) if tree: try: @@ -582,36 +624,41 @@ def run(self, tree=False, location='.', revision=None): self.enter_context(b.lock_read()) if revision: if len(revision) != 1: - raise errors.CommandError(gettext( - "Revision numbers only make sense for single " - "revisions, not ranges")) + raise errors.CommandError( + gettext( + "Revision numbers only make sense for single " + "revisions, not ranges" + ) + ) revid = revision[0].as_revision_id(b) else: revid = b.last_revision() try: revno_t = b.revision_id_to_dotted_revno(revid) except (errors.NoSuchRevision, errors.GhostRevisionsHaveNoRevno): - revno_t = ('???',) + revno_t = ("???",) revno = ".".join(str(n) for n in revno_t) self.cleanup_now() - self.outf.write(revno + '\n') + self.outf.write(revno + "\n") class cmd_revision_info(Command): __doc__ = """Show revision number and revision id for a given revision identifier. """ hidden = True - takes_args = ['revision_info*'] + takes_args = ["revision_info*"] takes_options = [ - 'revision', - custom_help('directory', help='Branch to examine, ' - 'rather than the one containing the working directory.'), - Option('tree', help='Show revno of working tree.'), - ] + "revision", + custom_help( + "directory", + help="Branch to examine, " + "rather than the one containing the working directory.", + ), + Option("tree", help="Show revno of working tree."), + ] @display_command - def run(self, revision=None, directory='.', tree=False, - revision_info_list=None): + def run(self, revision=None, directory=".", tree=False, revision_info_list=None): from .workingtree import WorkingTree try: @@ -643,16 +690,15 @@ def run(self, revision=None, directory='.', tree=False, for revision_id in revision_ids: try: dotted_revno = b.revision_id_to_dotted_revno(revision_id) - revno = '.'.join(str(i) for i in dotted_revno) + revno = ".".join(str(i) for i in dotted_revno) except errors.NoSuchRevision: - revno = '???' + revno = "???" maxlen = max(maxlen, len(revno)) revinfos.append((revno, revision_id)) self.cleanup_now() for revno, revid in revinfos: - self.outf.write( - '%*s %s\n' % (maxlen, revno, revid.decode('utf-8'))) + self.outf.write("%*s %s\n" % (maxlen, revno, revid.decode("utf-8"))) class cmd_add(Command): @@ -699,64 +745,75 @@ class cmd_add(Command): add.maximum_file_size will be skipped. Named items are never skipped due to file size. 
""" - takes_args = ['file*'] + takes_args = ["file*"] takes_options = [ - Option('no-recurse', - help="Don't recursively add the contents of directories.", - short_name='N'), - Option('dry-run', - help="Show what would be done, but don't actually do " - "anything."), - 'verbose', - Option('file-ids-from', - type=str, - help='Lookup file ids from this tree.'), - ] - encoding_type = 'replace' - _see_also = ['remove', 'ignore'] - - def run(self, file_list, no_recurse=False, dry_run=False, verbose=False, - file_ids_from=None): + Option( + "no-recurse", + help="Don't recursively add the contents of directories.", + short_name="N", + ), + Option( + "dry-run", + help="Show what would be done, but don't actually do " "anything.", + ), + "verbose", + Option("file-ids-from", type=str, help="Lookup file ids from this tree."), + ] + encoding_type = "replace" + _see_also = ["remove", "ignore"] + + def run( + self, + file_list, + no_recurse=False, + dry_run=False, + verbose=False, + file_ids_from=None, + ): import breezy.add from .workingtree import WorkingTree + tree, file_list = tree_files_for_add(file_list) if file_ids_from is not None and not tree.supports_setting_file_ids(): warning( - gettext('Ignoring --file-ids-from, since the tree does not ' - 'support setting file ids.')) + gettext( + "Ignoring --file-ids-from, since the tree does not " + "support setting file ids." + ) + ) file_ids_from = None base_tree = None if file_ids_from is not None: try: - base_tree, base_path = WorkingTree.open_containing( - file_ids_from) + base_tree, base_path = WorkingTree.open_containing(file_ids_from) except errors.NoWorkingTree: - base_branch, base_path = Branch.open_containing( - file_ids_from) + base_branch, base_path = Branch.open_containing(file_ids_from) base_tree = base_branch.basis_tree() action = breezy.add.AddFromBaseAction( - base_tree, base_path, to_file=self.outf, - should_print=(not is_quiet())) + base_tree, base_path, to_file=self.outf, should_print=(not is_quiet()) + ) else: action = breezy.add.AddWithSkipLargeAction( - to_file=self.outf, should_print=(not is_quiet())) + to_file=self.outf, should_print=(not is_quiet()) + ) if base_tree: self.enter_context(base_tree.lock_read()) added, ignored = tree.smart_add( - file_list, not no_recurse, action=action, save=not dry_run) + file_list, not no_recurse, action=action, save=not dry_run + ) self.cleanup_now() if len(ignored) > 0: if verbose: for glob in sorted(ignored): for path in ignored[glob]: self.outf.write( - gettext("ignored {0} matching \"{1}\"\n").format( - path, glob)) + gettext('ignored {0} matching "{1}"\n').format(path, glob) + ) class cmd_mkdir(Command): @@ -765,15 +822,15 @@ class cmd_mkdir(Command): This is equivalent to creating the directory and then adding it. 
""" - takes_args = ['dir+'] + takes_args = ["dir+"] takes_options = [ Option( - 'parents', - help='No error if existing, make parent directories as needed.', - short_name='p' - ) - ] - encoding_type = 'replace' + "parents", + help="No error if existing, make parent directories as needed.", + short_name="p", + ) + ] + encoding_type = "replace" @classmethod def add_file_with_parents(cls, wt, relpath): @@ -788,6 +845,7 @@ def add_file_single(cls, wt, relpath): def run(self, dir_list, parents=False): from .workingtree import WorkingTree + if parents: add_file = self.add_file_with_parents else: @@ -803,13 +861,13 @@ def run(self, dir_list, parents=False): os.mkdir(dir) add_file(wt, relpath) if not is_quiet(): - self.outf.write(gettext('added %s\n') % dir) + self.outf.write(gettext("added %s\n") % dir) class cmd_relpath(Command): __doc__ = """Show path of a file relative to root""" - takes_args = ['filename'] + takes_args = ["filename"] hidden = True @display_command @@ -820,7 +878,7 @@ def run(self, filename): # sys.stdout encoding cannot represent it? tree, relpath = WorkingTree.open_containing(filename) self.outf.write(relpath) - self.outf.write('\n') + self.outf.write("\n") class cmd_inventory(Command): @@ -834,28 +892,36 @@ class cmd_inventory(Command): """ hidden = True - _see_also = ['ls'] + _see_also = ["ls"] takes_options = [ - 'revision', - 'show-ids', - Option('include-root', - help='Include the entry for the root of the tree, if any.'), - Option('kind', - help='List entries of a particular kind: file, directory, ' - 'symlink.', - type=str), - ] - takes_args = ['file*'] + "revision", + "show-ids", + Option( + "include-root", help="Include the entry for the root of the tree, if any." + ), + Option( + "kind", + help="List entries of a particular kind: file, directory, " "symlink.", + type=str, + ), + ] + takes_args = ["file*"] @display_command - def run(self, revision=None, show_ids=False, kind=None, include_root=False, - file_list=None): + def run( + self, + revision=None, + show_ids=False, + kind=None, + include_root=False, + file_list=None, + ): from .workingtree import WorkingTree - if kind and kind not in ['file', 'directory', 'symlink']: - raise errors.CommandError( - gettext('invalid kind %r specified') % (kind,)) - revision = _get_one_revision('inventory', revision) + if kind and kind not in ["file", "directory", "symlink"]: + raise errors.CommandError(gettext("invalid kind %r specified") % (kind,)) + + revision = _get_one_revision("inventory", revision) work_tree, file_list = WorkingTree.open_containing_paths(file_list) self.enter_context(work_tree.lock_read()) if revision is not None: @@ -870,7 +936,8 @@ def run(self, revision=None, show_ids=False, kind=None, include_root=False, self.enter_context(tree.lock_read()) if file_list is not None: paths = tree.find_related_paths_across_trees( - file_list, extra_trees, require_versioned=True) + file_list, extra_trees, require_versioned=True + ) # find_ids_across_trees may include some paths that don't # exist in 'tree'. entries = tree.iter_entries_by_dir(specific_files=paths) @@ -883,11 +950,10 @@ def run(self, revision=None, show_ids=False, kind=None, include_root=False, if path == "" and not include_root: continue if show_ids: - self.outf.write('%-50s %s\n' % ( - path, entry.file_id.decode('utf-8'))) + self.outf.write("%-50s %s\n" % (path, entry.file_id.decode("utf-8"))) else: self.outf.write(path) - self.outf.write('\n') + self.outf.write("\n") class cmd_cp(Command): @@ -906,22 +972,23 @@ class cmd_cp(Command): at the moment. 
""" - takes_args = ['names*'] - aliases = ['copy'] - encoding_type = 'replace' + takes_args = ["names*"] + aliases = ["copy"] + encoding_type = "replace" def run(self, names_list): from .workingtree import WorkingTree + if names_list is None: names_list = [] if len(names_list) < 2: raise errors.CommandError(gettext("missing file argument")) tree, rel_names = WorkingTree.open_containing_paths( - names_list, canonicalize=False) + names_list, canonicalize=False + ) for file_name in rel_names[0:-1]: - if file_name == '': - raise errors.CommandError( - gettext("can not copy root of branch")) + if file_name == "": + raise errors.CommandError(gettext("can not copy root of branch")) self.enter_context(tree.lock_tree_write()) into_existing = osutils.isdir(names_list[-1]) if not into_existing: @@ -929,41 +996,51 @@ def run(self, names_list): (src, dst) = rel_names except IndexError as exc: raise errors.CommandError( - gettext('to copy multiple files the' - ' destination must be a versioned' - ' directory')) from exc + gettext( + "to copy multiple files the" + " destination must be a versioned" + " directory" + ) + ) from exc pairs = [(src, dst)] else: pairs = [ (n, osutils.joinpath([rel_names[-1], osutils.basename(n)])) - for n in rel_names[:-1]] + for n in rel_names[:-1] + ] for src, dst in pairs: try: src_kind = tree.stored_kind(src) except transport.NoSuchFile as exc: raise errors.CommandError( - gettext('Could not copy %s => %s: %s is not versioned.') - % (src, dst, src)) from exc + gettext("Could not copy %s => %s: %s is not versioned.") + % (src, dst, src) + ) from exc if src_kind is None: raise errors.CommandError( - gettext('Could not copy %s => %s . %s is not versioned\\.') % (src, dst, src)) - if src_kind == 'directory': + gettext("Could not copy %s => %s . %s is not versioned\\.") + % (src, dst, src) + ) + if src_kind == "directory": raise errors.CommandError( - gettext('Could not copy %s => %s . %s is a directory.') % ( - src, dst, src)) + gettext("Could not copy %s => %s . %s is a directory.") + % (src, dst, src) + ) dst_parent = osutils.split(dst)[0] - if dst_parent != '': + if dst_parent != "": try: dst_parent_kind = tree.stored_kind(dst_parent) except transport.NoSuchFile as exc: raise errors.CommandError( - gettext('Could not copy %s => %s: %s is not versioned.') - % (src, dst, dst_parent)) from exc - if dst_parent_kind != 'directory': + gettext("Could not copy %s => %s: %s is not versioned.") + % (src, dst, dst_parent) + ) from exc + if dst_parent_kind != "directory": raise errors.CommandError( - gettext('Could not copy to %s: %s is not a directory.') - % (dst_parent, dst_parent)) + gettext("Could not copy to %s: %s is not a directory.") + % (dst_parent, dst_parent) + ) tree.copy_one(src, dst) @@ -989,49 +1066,56 @@ class cmd_mv(Command): Files cannot be moved between branches. 
""" - takes_args = ['names*'] - takes_options = [Option("after", help="Move only the brz identifier" - " of the file, because the file has already been moved."), - Option('auto', help='Automatically guess renames.'), - Option( - 'dry-run', help='Avoid making changes when guessing renames.'), - ] - aliases = ['move', 'rename'] - encoding_type = 'replace' + takes_args = ["names*"] + takes_options = [ + Option( + "after", + help="Move only the brz identifier" + " of the file, because the file has already been moved.", + ), + Option("auto", help="Automatically guess renames."), + Option("dry-run", help="Avoid making changes when guessing renames."), + ] + aliases = ["move", "rename"] + encoding_type = "replace" def run(self, names_list, after=False, auto=False, dry_run=False): from .workingtree import WorkingTree + if auto: return self.run_auto(names_list, after, dry_run) elif dry_run: - raise errors.CommandError(gettext('--dry-run requires --auto.')) + raise errors.CommandError(gettext("--dry-run requires --auto.")) if names_list is None: names_list = [] if len(names_list) < 2: raise errors.CommandError(gettext("missing file argument")) tree, rel_names = WorkingTree.open_containing_paths( - names_list, canonicalize=False) + names_list, canonicalize=False + ) for file_name in rel_names[0:-1]: - if file_name == '': - raise errors.CommandError( - gettext("can not move root of branch")) + if file_name == "": + raise errors.CommandError(gettext("can not move root of branch")) self.enter_context(tree.lock_tree_write()) self._run(tree, names_list, rel_names, after) def run_auto(self, names_list, after, dry_run): from .rename_map import RenameMap from .workingtree import WorkingTree + if names_list is not None and len(names_list) > 1: raise errors.CommandError( - gettext('Only one path may be specified to --auto.')) + gettext("Only one path may be specified to --auto.") + ) if after: raise errors.CommandError( - gettext('--after cannot be specified with --auto.')) + gettext("--after cannot be specified with --auto.") + ) work_tree, file_list = WorkingTree.open_containing_paths( - names_list, default_directory='.') + names_list, default_directory="." + ) self.enter_context(work_tree.lock_tree_write()) - RenameMap.guess_renames( - work_tree.basis_tree(), work_tree, dry_run) + RenameMap.guess_renames(work_tree.basis_tree(), work_tree, dry_run) def _run(self, tree, names_list, rel_names, after): into_existing = osutils.isdir(names_list[-1]) @@ -1041,15 +1125,16 @@ def _run(self, tree, names_list, rel_names, after): # b. 
move directory after the fact (if the source used to be # a directory, but now doesn't exist in the working tree # and the target is an existing directory, just rename it) - if (not tree.case_sensitive - and rel_names[0].lower() == rel_names[1].lower()): + if not tree.case_sensitive and rel_names[0].lower() == rel_names[1].lower(): into_existing = False else: # 'fix' the case of a potential 'from' from_path = tree.get_canonical_path(rel_names[0]) - if (not osutils.lexists(names_list[0]) and - tree.is_versioned(from_path) and - tree.stored_kind(from_path) == "directory"): + if ( + not osutils.lexists(names_list[0]) + and tree.is_versioned(from_path) + and tree.stored_kind(from_path) == "directory" + ): into_existing = False # move/rename if into_existing: @@ -1062,9 +1147,13 @@ def _run(self, tree, names_list, rel_names, after): self.outf.write(f"{src} => {dest}\n") else: if len(names_list) != 2: - raise errors.CommandError(gettext('to mv multiple files the' - ' destination must be a versioned' - ' directory')) + raise errors.CommandError( + gettext( + "to mv multiple files the" + " destination must be a versioned" + " directory" + ) + ) # for cicp file-systems: the src references an existing inventory # item: @@ -1091,16 +1180,15 @@ def _run(self, tree, names_list, rel_names, after): if after: # If 'after' is specified, the tail must refer to a file on disk. if dest_parent: - dest_parent_fq = osutils.pathjoin( - tree.basedir, dest_parent) + dest_parent_fq = osutils.pathjoin(tree.basedir, dest_parent) else: # pathjoin with an empty tail adds a slash, which breaks # relpath :( dest_parent_fq = tree.basedir dest_tail = osutils.canonical_relpath( - dest_parent_fq, - osutils.pathjoin(dest_parent_fq, spec_tail)) + dest_parent_fq, osutils.pathjoin(dest_parent_fq, spec_tail) + ) else: # not 'after', so case as specified is used dest_tail = spec_tail @@ -1144,30 +1232,41 @@ class cmd_pull(Command): with brz send. """ - _see_also = ['push', 'update', 'status-flags', 'send'] - takes_options = ['remember', 'overwrite', 'revision', - custom_help('verbose', - help='Show logs of pulled revisions.'), - custom_help('directory', - help='Branch to pull into, ' - 'rather than the one containing the working directory.'), - Option('local', - help="Perform a local pull in a bound " - "branch. Local pulls are not applied to " - "the master branch." - ), - Option('show-base', - help="Show base revision text in conflicts."), - Option('overwrite-tags', - help="Overwrite tags only."), - ] - takes_args = ['location?'] - encoding_type = 'replace' - - def run(self, location=None, remember=None, overwrite=False, - revision=None, verbose=False, - directory=None, local=False, - show_base=False, overwrite_tags=False): + _see_also = ["push", "update", "status-flags", "send"] + takes_options = [ + "remember", + "overwrite", + "revision", + custom_help("verbose", help="Show logs of pulled revisions."), + custom_help( + "directory", + help="Branch to pull into, " + "rather than the one containing the working directory.", + ), + Option( + "local", + help="Perform a local pull in a bound " + "branch. 
Local pulls are not applied to " + "the master branch.", + ), + Option("show-base", help="Show base revision text in conflicts."), + Option("overwrite-tags", help="Overwrite tags only."), + ] + takes_args = ["location?"] + encoding_type = "replace" + + def run( + self, + location=None, + remember=None, + overwrite=False, + revision=None, + verbose=False, + directory=None, + local=False, + show_base=False, + overwrite_tags=False, + ): from . import mergeable as _mod_mergeable from .workingtree import WorkingTree @@ -1181,7 +1280,7 @@ def run(self, location=None, remember=None, overwrite=False, revision_id = None mergeable = None if directory is None: - directory = '.' + directory = "." try: tree_to = WorkingTree.open_containing(directory)[0] branch_to = tree_to.branch @@ -1200,39 +1299,43 @@ def run(self, location=None, remember=None, overwrite=False, if location is not None: try: mergeable = _mod_mergeable.read_mergeable_from_url( - location, possible_transports=possible_transports) + location, possible_transports=possible_transports + ) except errors.NotABundle: mergeable = None stored_loc = branch_to.get_parent() if location is None: if stored_loc is None: - raise errors.CommandError(gettext("No pull location known or" - " specified.")) + raise errors.CommandError( + gettext("No pull location known or" " specified.") + ) else: - display_url = urlutils.unescape_for_display(stored_loc, - self.outf.encoding) + display_url = urlutils.unescape_for_display( + stored_loc, self.outf.encoding + ) if not is_quiet(): self.outf.write( - gettext("Using saved parent location: %s\n") % display_url) + gettext("Using saved parent location: %s\n") % display_url + ) location = stored_loc - revision = _get_one_revision('pull', revision) + revision = _get_one_revision("pull", revision) if mergeable is not None: if revision is not None: - raise errors.CommandError(gettext( - 'Cannot use -r with merge directives or bundles')) + raise errors.CommandError( + gettext("Cannot use -r with merge directives or bundles") + ) mergeable.install_revisions(branch_to.repository) - base_revision_id, revision_id, verified = \ - mergeable.get_merge_request(branch_to.repository) + base_revision_id, revision_id, verified = mergeable.get_merge_request( + branch_to.repository + ) branch_from = branch_to else: - branch_from = Branch.open(location, - possible_transports=possible_transports) + branch_from = Branch.open(location, possible_transports=possible_transports) self.enter_context(branch_from.lock_read()) # Remembers if asked explicitly or no previous location is set - if (remember - or (remember is None and branch_to.get_parent() is None)): + if remember or (remember is None and branch_to.get_parent() is None): # FIXME: This shouldn't be done before the pull # succeeds... 
-- vila 2012-01-02 branch_to.set_parent(branch_from.base) @@ -1243,23 +1346,27 @@ def run(self, location=None, remember=None, overwrite=False, if tree_to is not None: view_info = _get_view_info_for_change_reporter(tree_to) change_reporter = delta._ChangeReporter( - unversioned_filter=tree_to.is_ignored, - view_info=view_info) + unversioned_filter=tree_to.is_ignored, view_info=view_info + ) result = tree_to.pull( - branch_from, overwrite=overwrite, stop_revision=revision_id, + branch_from, + overwrite=overwrite, + stop_revision=revision_id, change_reporter=change_reporter, - local=local, show_base=show_base) + local=local, + show_base=show_base, + ) else: result = branch_to.pull( - branch_from, overwrite=overwrite, stop_revision=revision_id, - local=local) + branch_from, overwrite=overwrite, stop_revision=revision_id, local=local + ) result.report(self.outf) if verbose and result.old_revid != result.new_revid: log.show_branch_change( - branch_to, self.outf, result.old_revno, - result.old_revid) - if getattr(result, 'tag_conflicts', None): + branch_to, self.outf, result.old_revno, result.old_revid + ) + if getattr(result, "tag_conflicts", None): return 1 else: return 0 @@ -1295,46 +1402,78 @@ class cmd_push(Command): -Olog_format=. """ - _see_also = ['pull', 'update', 'working-trees'] - takes_options = ['remember', 'overwrite', 'verbose', 'revision', - Option('create-prefix', - help='Create the path leading up to the branch ' - 'if it does not already exist.'), - custom_help('directory', - help='Branch to push from, ' - 'rather than the one containing the working directory.'), - Option('use-existing-dir', - help='By default push will fail if the target' - ' directory exists, but does not already' - ' have a control directory. This flag will' - ' allow push to proceed.'), - Option('stacked', - help='Create a stacked branch that references the public location ' - 'of the parent branch.'), - Option('stacked-on', - help='Create a stacked branch that refers to another branch ' - 'for the commit history. Only the work not present in the ' - 'referenced branch is included in the branch created.', - type=str), - Option('strict', - help='Refuse to push if there are uncommitted changes in' - ' the working tree, --no-strict disables the check.'), - Option('no-tree', - help="Don't populate the working tree, even for protocols" - " that support it."), - Option('overwrite-tags', - help="Overwrite tags only."), - Option('lossy', help="Allow lossy push, i.e. dropping metadata " - "that can't be represented in the target.") - ] - takes_args = ['location?'] - encoding_type = 'replace' - - def run(self, location=None, remember=None, overwrite=False, - create_prefix=False, verbose=False, revision=None, - use_existing_dir=False, directory=None, stacked_on=None, - stacked=False, strict=None, no_tree=False, - overwrite_tags=False, lossy=False): + _see_also = ["pull", "update", "working-trees"] + takes_options = [ + "remember", + "overwrite", + "verbose", + "revision", + Option( + "create-prefix", + help="Create the path leading up to the branch " + "if it does not already exist.", + ), + custom_help( + "directory", + help="Branch to push from, " + "rather than the one containing the working directory.", + ), + Option( + "use-existing-dir", + help="By default push will fail if the target" + " directory exists, but does not already" + " have a control directory. 
This flag will" + " allow push to proceed.", + ), + Option( + "stacked", + help="Create a stacked branch that references the public location " + "of the parent branch.", + ), + Option( + "stacked-on", + help="Create a stacked branch that refers to another branch " + "for the commit history. Only the work not present in the " + "referenced branch is included in the branch created.", + type=str, + ), + Option( + "strict", + help="Refuse to push if there are uncommitted changes in" + " the working tree, --no-strict disables the check.", + ), + Option( + "no-tree", + help="Don't populate the working tree, even for protocols" + " that support it.", + ), + Option("overwrite-tags", help="Overwrite tags only."), + Option( + "lossy", + help="Allow lossy push, i.e. dropping metadata " + "that can't be represented in the target.", + ), + ] + takes_args = ["location?"] + encoding_type = "replace" + + def run( + self, + location=None, + remember=None, + overwrite=False, + create_prefix=False, + verbose=False, + revision=None, + use_existing_dir=False, + directory=None, + stacked_on=None, + stacked=False, + strict=None, + no_tree=False, + overwrite_tags=False, + lossy=False, + ): from .location import location_to_url from .push import _show_push_branch @@ -1346,24 +1485,27 @@ def run(self, location=None, remember=None, overwrite=False, overwrite = [] if directory is None: - directory = '.' + directory = "." # Get the source branch - (tree, br_from, - _unused) = controldir.ControlDir.open_containing_tree_or_branch(directory) + (tree, br_from, _unused) = controldir.ControlDir.open_containing_tree_or_branch( + directory + ) # Get the tip's revision_id - revision = _get_one_revision('push', revision) + revision = _get_one_revision("push", revision) if revision is not None: revision_id = revision.in_history(br_from).rev_id else: revision_id = None if tree is not None and revision_id is None: tree.check_changed_or_out_of_date( - strict, 'push_strict', - more_error='Use --no-strict to force the push.', - more_warning='Uncommitted changes will not be pushed.') + strict, + "push_strict", + more_error="Use --no-strict to force the push.", + more_warning="Uncommitted changes will not be pushed.", + ) # Get the stacked_on branch, if any if stacked_on is not None: - stacked_on = location_to_url(stacked_on, 'read') + stacked_on = location_to_url(stacked_on, "read") stacked_on = urlutils.normalize_url(stacked_on) elif stacked: parent_url = br_from.get_parent() @@ -1377,8 +1519,9 @@ def run(self, location=None, remember=None, overwrite=False, # error by the feedback given to them. RBC 20080227. stacked_on = parent_url if not stacked_on: - raise errors.CommandError(gettext( - "Could not determine branch to refer to.")) + raise errors.CommandError( + gettext("Could not determine branch to refer to.") + ) # Get the destination location if location is None: @@ -1386,25 +1529,38 @@ def run(self, location=None, remember=None, overwrite=False, if stored_loc is None: parent_loc = br_from.get_parent() if parent_loc: - raise errors.CommandError(gettext( - "No push location known or specified. To push to the " - "parent branch (at %s), use 'brz push :parent'.") % - urlutils.unescape_for_display(parent_loc, - self.outf.encoding)) + raise errors.CommandError( + gettext( + "No push location known or specified. To push to the " + "parent branch (at %s), use 'brz push :parent'." 
+ ) + % urlutils.unescape_for_display(parent_loc, self.outf.encoding) + ) else: - raise errors.CommandError(gettext( - "No push location known or specified.")) + raise errors.CommandError( + gettext("No push location known or specified.") + ) else: - display_url = urlutils.unescape_for_display(stored_loc, - self.outf.encoding) + display_url = urlutils.unescape_for_display( + stored_loc, self.outf.encoding + ) note(gettext("Using saved push location: %s") % display_url) location = stored_loc - _show_push_branch(br_from, revision_id, location, self.outf, - verbose=verbose, overwrite=overwrite, remember=remember, - stacked_on=stacked_on, create_prefix=create_prefix, - use_existing_dir=use_existing_dir, no_tree=no_tree, - lossy=lossy) + _show_push_branch( + br_from, + revision_id, + location, + self.outf, + verbose=verbose, + overwrite=overwrite, + remember=remember, + stacked_on=stacked_on, + create_prefix=create_prefix, + use_existing_dir=use_existing_dir, + no_tree=no_tree, + lossy=lossy, + ) class cmd_branch(Command): @@ -1421,51 +1577,69 @@ class cmd_branch(Command): parameter, as in "branch foo/bar -r 5". """ - aliase = ['sprout'] - _see_also = ['checkout'] - takes_args = ['from_location', 'to_location?'] - takes_options = ['revision', - Option( - 'hardlink', help='Hard-link working tree files where possible.'), - Option('files-from', type=str, - help="Get file contents from this tree."), - Option('no-tree', - help="Create a branch without a working-tree."), - Option('switch', - help="Switch the checkout in the current directory " - "to the new branch."), - Option('stacked', - help='Create a stacked branch referring to the source branch. ' - 'The new branch will depend on the availability of the source ' - 'branch for all operations.'), - Option('standalone', - help='Do not use a shared repository, even if available.'), - Option('use-existing-dir', - help='By default branch will fail if the target' - ' directory exists, but does not already' - ' have a control directory. This flag will' - ' allow branch to proceed.'), - Option('bind', - help="Bind new branch to from location."), - Option('no-recurse-nested', - help='Do not recursively check out nested trees.'), - Option('colocated-branch', short_name='b', - type=str, help='Name of colocated branch to sprout.'), - ] - - def run(self, from_location, to_location=None, revision=None, - hardlink=False, stacked=False, standalone=False, no_tree=False, - use_existing_dir=False, switch=False, bind=False, - files_from=None, no_recurse_nested=False, colocated_branch=None): + aliase = ["sprout"] + _see_also = ["checkout"] + takes_args = ["from_location", "to_location?"] + takes_options = [ + "revision", + Option("hardlink", help="Hard-link working tree files where possible."), + Option("files-from", type=str, help="Get file contents from this tree."), + Option("no-tree", help="Create a branch without a working-tree."), + Option( + "switch", + help="Switch the checkout in the current directory " "to the new branch.", + ), + Option( + "stacked", + help="Create a stacked branch referring to the source branch. " + "The new branch will depend on the availability of the source " + "branch for all operations.", + ), + Option("standalone", help="Do not use a shared repository, even if available."), + Option( + "use-existing-dir", + help="By default branch will fail if the target" + " directory exists, but does not already" + " have a control directory. 
This flag will" + " allow branch to proceed.", + ), + Option("bind", help="Bind new branch to from location."), + Option("no-recurse-nested", help="Do not recursively check out nested trees."), + Option( + "colocated-branch", + short_name="b", + type=str, + help="Name of colocated branch to sprout.", + ), + ] + + def run( + self, + from_location, + to_location=None, + revision=None, + hardlink=False, + stacked=False, + standalone=False, + no_tree=False, + use_existing_dir=False, + switch=False, + bind=False, + files_from=None, + no_recurse_nested=False, + colocated_branch=None, + ): from breezy import switch as _mod_switch from .workingtree import WorkingTree + accelerator_tree, br_from = controldir.ControlDir.open_tree_or_branch( - from_location, name=colocated_branch) + from_location, name=colocated_branch + ) if no_recurse_nested: - recurse = 'none' + recurse = "none" else: - recurse = 'down' + recurse = "down" if not (hardlink or files_from): # accelerator_tree is usually slower because you have to read N # files (no readahead, lots of seeks, etc), but allow the user to @@ -1473,7 +1647,7 @@ def run(self, from_location, to_location=None, revision=None, accelerator_tree = None if files_from is not None and files_from != from_location: accelerator_tree = WorkingTree.open(files_from) - revision = _get_one_revision('branch', revision) + revision = _get_one_revision("branch", revision) self.enter_context(br_from.lock_read()) if revision is not None: revision_id = revision.as_revision_id(br_from) @@ -1484,18 +1658,18 @@ def run(self, from_location, to_location=None, revision=None, revision_id = br_from.last_revision() if to_location is None: to_location = urlutils.derive_to_location(from_location) - to_transport = transport.get_transport(to_location, purpose='write') + to_transport = transport.get_transport(to_location, purpose="write") try: - to_transport.mkdir('.') + to_transport.mkdir(".") except transport.FileExists: try: - to_dir = controldir.ControlDir.open_from_transport( - to_transport) + to_dir = controldir.ControlDir.open_from_transport(to_transport) except errors.NotBranchError as exc: if not use_existing_dir: raise errors.CommandError( - gettext('Target directory "%s" ' - 'already exists.') % to_location) from exc + gettext('Target directory "%s" ' "already exists.") + % to_location + ) from exc else: to_dir = None else: @@ -1506,27 +1680,37 @@ def run(self, from_location, to_location=None, revision=None, else: raise errors.AlreadyBranchError(to_location) except transport.NoSuchFile as exc: - raise errors.CommandError(gettext('Parent of "%s" does not exist.') - % to_location) from exc + raise errors.CommandError( + gettext('Parent of "%s" does not exist.') % to_location + ) from exc else: to_dir = None if to_dir is None: try: # preserve whatever source format we have. 
to_dir = br_from.controldir.sprout( - to_transport.base, revision_id, + to_transport.base, + revision_id, possible_transports=[to_transport], - accelerator_tree=accelerator_tree, hardlink=hardlink, - stacked=stacked, force_new_repo=standalone, - create_tree_if_local=not no_tree, source_branch=br_from, - recurse=recurse) + accelerator_tree=accelerator_tree, + hardlink=hardlink, + stacked=stacked, + force_new_repo=standalone, + create_tree_if_local=not no_tree, + source_branch=br_from, + recurse=recurse, + ) branch = to_dir.open_branch( possible_transports=[ - br_from.controldir.root_transport, to_transport]) + br_from.controldir.root_transport, + to_transport, + ] + ) except errors.NoSuchRevision as exc: - to_transport.delete_tree('.') + to_transport.delete_tree(".") msg = gettext("The branch {0} has no revision {1}.").format( - from_location, revision) + from_location, revision + ) raise errors.CommandError(msg) from exc else: try: @@ -1534,36 +1718,47 @@ def run(self, from_location, to_location=None, revision=None, except errors.NoRepositoryPresent: to_repo = to_dir.create_repository() to_repo.fetch(br_from.repository, revision_id=revision_id) - branch = br_from.sprout( - to_dir, revision_id=revision_id) + branch = br_from.sprout(to_dir, revision_id=revision_id) br_from.tags.merge_to(branch.tags) # If the source branch is stacked, the new branch may # be stacked whether we asked for that explicitly or not. # We therefore need a try/except here and not just 'if stacked:' try: - note(gettext('Created new stacked branch referring to %s.') % - branch.get_stacked_on_url()) - except (errors.NotStacked, _mod_branch.UnstackableBranchFormat, - errors.UnstackableRepositoryFormat): + note( + gettext("Created new stacked branch referring to %s.") + % branch.get_stacked_on_url() + ) + except ( + errors.NotStacked, + _mod_branch.UnstackableBranchFormat, + errors.UnstackableRepositoryFormat, + ): revno = branch.revno() if revno is not None: - note(ngettext('Branched %d revision.', - 'Branched %d revisions.', - branch.revno()) % revno) + note( + ngettext( + "Branched %d revision.", + "Branched %d revisions.", + branch.revno(), + ) + % revno + ) else: - note(gettext('Created new branch.')) + note(gettext("Created new branch.")) if bind: # Bind to the parent parent_branch = Branch.open(from_location) branch.bind(parent_branch) - note(gettext('New branch bound to %s') % from_location) + note(gettext("New branch bound to %s") % from_location) if switch: # Switch to the new branch - wt, _ = WorkingTree.open_containing('.') + wt, _ = WorkingTree.open_containing(".") _mod_switch.switch(wt.controldir, branch) - note(gettext('Switched to branch: %s'), - urlutils.unescape_for_display(branch.base, 'utf-8')) + note( + gettext("Switched to branch: %s"), + urlutils.unescape_for_display(branch.base, "utf-8"), + ) class cmd_branches(Command): @@ -1573,22 +1768,28 @@ class cmd_branches(Command): location. 
""" - takes_args = ['location?'] + takes_args = ["location?"] takes_options = [ - Option('recursive', short_name='R', - help='Recursively scan for branches rather than ' - 'just looking in the specified location.')] + Option( + "recursive", + short_name="R", + help="Recursively scan for branches rather than " + "just looking in the specified location.", + ) + ] def run(self, location=".", recursive=False): if recursive: - t = transport.get_transport(location, purpose='read') + t = transport.get_transport(location, purpose="read") if not t.listable(): - raise errors.CommandError( - "Can't scan this type of location.") + raise errors.CommandError("Can't scan this type of location.") for b in controldir.ControlDir.find_branches(t): - self.outf.write("%s\n" % urlutils.unescape_for_display( - urlutils.relative_url(t.base, b.base), - self.outf.encoding).rstrip("/")) + self.outf.write( + "%s\n" + % urlutils.unescape_for_display( + urlutils.relative_url(t.base, b.base), self.outf.encoding + ).rstrip("/") + ) else: dir = controldir.ControlDir.open_containing(location)[0] try: @@ -1599,8 +1800,10 @@ def run(self, location=".", recursive=False): for name, branch in iter_sibling_branches(dir): if name == "": continue - active = (active_branch is not None and - active_branch.user_url == branch.user_url) + active = ( + active_branch is not None + and active_branch.user_url == branch.user_url + ) names[name] = active # Only mention the current branch explicitly if it's not # one of the colocated branches @@ -1636,38 +1839,46 @@ class cmd_checkout(Command): to examine old code.) """ - _see_also = ['checkouts', 'branch', 'working-trees', 'remove-tree'] - takes_args = ['branch_location?', 'to_location?'] - takes_options = ['revision', - Option('lightweight', - help="Perform a lightweight checkout. Lightweight " - "checkouts depend on access to the branch for " - "every operation. Normal checkouts can perform " - "common operations like diff and status without " - "such access, and also support local commits." - ), - Option('files-from', type=str, - help="Get file contents from this tree."), - Option('hardlink', - help='Hard-link working tree files where possible.' - ), - ] - aliases = ['co'] - - def run(self, branch_location=None, to_location=None, revision=None, - lightweight=False, files_from=None, hardlink=False): + _see_also = ["checkouts", "branch", "working-trees", "remove-tree"] + takes_args = ["branch_location?", "to_location?"] + takes_options = [ + "revision", + Option( + "lightweight", + help="Perform a lightweight checkout. Lightweight " + "checkouts depend on access to the branch for " + "every operation. 
Normal checkouts can perform " + "common operations like diff and status without " + "such access, and also support local commits.", + ), + Option("files-from", type=str, help="Get file contents from this tree."), + Option("hardlink", help="Hard-link working tree files where possible."), + ] + aliases = ["co"] + + def run( + self, + branch_location=None, + to_location=None, + revision=None, + lightweight=False, + files_from=None, + hardlink=False, + ): from .workingtree import WorkingTree + if branch_location is None: branch_location = osutils.getcwd() to_location = branch_location accelerator_tree, source = controldir.ControlDir.open_tree_or_branch( - branch_location) + branch_location + ) if not (hardlink or files_from): # accelerator_tree is usually slower because you have to read N # files (no readahead, lots of seeks, etc), but allow the user to # explicitly request it accelerator_tree = None - revision = _get_one_revision('checkout', revision) + revision = _get_one_revision("checkout", revision) if files_from is not None and files_from != branch_location: accelerator_tree = WorkingTree.open(files_from) if revision is not None: @@ -1685,30 +1896,36 @@ def run(self, branch_location=None, to_location=None, revision=None, except errors.NoWorkingTree: source.controldir.create_workingtree(revision_id) return - source.create_checkout(to_location, revision_id=revision_id, - lightweight=lightweight, - accelerator_tree=accelerator_tree, - hardlink=hardlink) + source.create_checkout( + to_location, + revision_id=revision_id, + lightweight=lightweight, + accelerator_tree=accelerator_tree, + hardlink=hardlink, + ) class cmd_clone(Command): __doc__ = """Clone a control directory. """ - takes_args = ['from_location', 'to_location?'] - takes_options = ['revision', - Option('no-recurse-nested', - help='Do not recursively check out nested trees.'), - ] + takes_args = ["from_location", "to_location?"] + takes_options = [ + "revision", + Option("no-recurse-nested", help="Do not recursively check out nested trees."), + ] - def run(self, from_location, to_location=None, revision=None, no_recurse_nested=False): + def run( + self, from_location, to_location=None, revision=None, no_recurse_nested=False + ): accelerator_tree, br_from = controldir.ControlDir.open_tree_or_branch( - from_location) + from_location + ) if no_recurse_nested: pass else: pass - revision = _get_one_revision('branch', revision) + revision = _get_one_revision("branch", revision) self.enter_context(br_from.lock_read()) if revision is not None: revision_id = revision.as_revision_id(br_from) @@ -1720,7 +1937,7 @@ def run(self, from_location, to_location=None, revision=None, no_recurse_nested= if to_location is None: to_location = urlutils.derive_to_location(from_location) br_from.controldir.clone(to_location, revision_id=revision_id) - note(gettext('Created new control directory.')) + note(gettext("Created new control directory.")) class cmd_renames(Command): @@ -1729,12 +1946,13 @@ class cmd_renames(Command): # TODO: Option to show renames between two historical versions. # TODO: Only show renames under dir, rather than in the whole branch. - _see_also = ['status'] - takes_args = ['dir?'] + _see_also = ["status"] + takes_args = ["dir?"] @display_command - def run(self, dir='.'): + def run(self, dir="."): from .workingtree import WorkingTree + tree = WorkingTree.open_containing(dir)[0] self.enter_context(tree.lock_read()) old_tree = tree.basis_tree() @@ -1782,32 +2000,36 @@ class cmd_update(Command): current working directory is used. 
""" - _see_also = ['pull', 'working-trees', 'status-flags'] - takes_args = ['dir?'] - takes_options = ['revision', - Option('show-base', - help="Show base revision text in conflicts."), - ] - aliases = ['up'] + _see_also = ["pull", "working-trees", "status-flags"] + takes_args = ["dir?"] + takes_options = [ + "revision", + Option("show-base", help="Show base revision text in conflicts."), + ] + aliases = ["up"] def run(self, dir=None, revision=None, show_base=None): from .workingtree import WorkingTree + if revision is not None and len(revision) != 1: - raise errors.CommandError(gettext( - "brz update --revision takes exactly one revision")) + raise errors.CommandError( + gettext("brz update --revision takes exactly one revision") + ) if dir is None: - tree = WorkingTree.open_containing('.')[0] + tree = WorkingTree.open_containing(".")[0] else: tree, relpath = WorkingTree.open_containing(dir) if relpath: # See bug 557886. - raise errors.CommandError(gettext( - "brz update can only update a whole tree, " - "not a file or subdirectory")) + raise errors.CommandError( + gettext( + "brz update can only update a whole tree, " + "not a file or subdirectory" + ) + ) branch = tree.branch possible_transports = [] - master = branch.get_master_branch( - possible_transports=possible_transports) + master = branch.get_master_branch(possible_transports=possible_transports) if master is not None: branch_location = master.base self.enter_context(tree.lock_write()) @@ -1816,8 +2038,8 @@ def run(self, dir=None, revision=None, show_base=None): self.enter_context(tree.lock_tree_write()) # get rid of the final '/' and be ready for display branch_location = urlutils.unescape_for_display( - branch_location.rstrip('/'), - self.outf.encoding) + branch_location.rstrip("/"), self.outf.encoding + ) existing_pending_merges = tree.get_parent_ids()[1:] if master is None: old_tip = None @@ -1832,33 +2054,47 @@ def run(self, dir=None, revision=None, show_base=None): revision_id = branch.last_revision() if revision_id == tree.last_revision(): revno = branch.revision_id_to_dotted_revno(revision_id) - note(gettext("Tree is up to date at revision {0} of branch {1}" - ).format('.'.join(map(str, revno)), branch_location)) + note( + gettext("Tree is up to date at revision {0} of branch {1}").format( + ".".join(map(str, revno)), branch_location + ) + ) return 0 view_info = _get_view_info_for_change_reporter(tree) change_reporter = delta._ChangeReporter( - unversioned_filter=tree.is_ignored, - view_info=view_info) + unversioned_filter=tree.is_ignored, view_info=view_info + ) try: conflicts = tree.update( change_reporter, possible_transports=possible_transports, revision=revision_id, old_tip=old_tip, - show_base=show_base) + show_base=show_base, + ) except errors.NoSuchRevision as exc: - raise errors.CommandError(gettext( - "branch has no revision %s\n" - "brz update --revision only works" - " for a revision in the branch history") - % (exc.revision)) from exc + raise errors.CommandError( + gettext( + "branch has no revision %s\n" + "brz update --revision only works" + " for a revision in the branch history" + ) + % (exc.revision) + ) from exc revno = tree.branch.revision_id_to_dotted_revno(tree.last_revision()) - note(gettext('Updated to revision {0} of branch {1}').format( - '.'.join(map(str, revno)), branch_location)) + note( + gettext("Updated to revision {0} of branch {1}").format( + ".".join(map(str, revno)), branch_location + ) + ) parent_ids = tree.get_parent_ids() if parent_ids[1:] and parent_ids[1:] != 
existing_pending_merges: - note(gettext('Your local commits will now show as pending merges with ' - "'brz status', and can be committed with 'brz commit'.")) + note( + gettext( + "Your local commits will now show as pending merges with " + "'brz status', and can be committed with 'brz commit'." + ) + ) if conflicts != 0: return 1 else: @@ -1893,10 +2129,10 @@ class cmd_info(Command): brz info -vv """ - _see_also = ['revno', 'working-trees', 'repositories'] - takes_args = ['location?'] - takes_options = ['verbose'] - encoding_type = 'replace' + _see_also = ["revno", "working-trees", "repositories"] + takes_args = ["location?"] + takes_options = ["verbose"] + encoding_type = "replace" @display_command def run(self, location=None, verbose=False): @@ -1905,8 +2141,12 @@ def run(self, location=None, verbose=False): else: noise_level = 0 from .info import show_bzrdir_info - show_bzrdir_info(controldir.ControlDir.open_containing(location)[0], - verbose=noise_level, outfile=self.outf) + + show_bzrdir_info( + controldir.ControlDir.open_containing(location)[0], + verbose=noise_level, + outfile=self.outf, + ) class cmd_remove(Command): @@ -1918,22 +2158,25 @@ class cmd_remove(Command): parameters are given Breezy will scan for files that are being tracked by Breezy but missing in your tree and stop tracking them for you. """ - takes_args = ['file*'] - takes_options = ['verbose', - Option( - 'new', help='Only remove files that have never been committed.'), - RegistryOption.from_kwargs('file-deletion-strategy', - 'The file deletion mode to be used.', - title='Deletion Strategy', value_switches=True, enum_switch=False, - safe='Backup changed files (default).', - keep='Delete from brz but leave the working copy.', - no_backup='Don\'t backup changed files.'), - ] - aliases = ['rm', 'del'] - encoding_type = 'replace' - - def run(self, file_list, verbose=False, new=False, - file_deletion_strategy='safe'): + takes_args = ["file*"] + takes_options = [ + "verbose", + Option("new", help="Only remove files that have never been committed."), + RegistryOption.from_kwargs( + "file-deletion-strategy", + "The file deletion mode to be used.", + title="Deletion Strategy", + value_switches=True, + enum_switch=False, + safe="Backup changed files (default).", + keep="Delete from brz but leave the working copy.", + no_backup="Don't backup changed files.", + ), + ] + aliases = ["rm", "del"] + encoding_type = "replace" + + def run(self, file_list, verbose=False, new=False, file_deletion_strategy="safe"): from .workingtree import WorkingTree tree, file_list = WorkingTree.open_containing_paths(file_list) @@ -1945,11 +2188,10 @@ def run(self, file_list, verbose=False, new=False, # Heuristics should probably all move into tree.remove_smart or # some such? if new: - added = tree.changes_from(tree.basis_tree(), - specific_files=file_list).added + added = tree.changes_from(tree.basis_tree(), specific_files=file_list).added file_list = sorted([f.path[1] for f in added], reverse=True) if len(file_list) == 0: - raise errors.CommandError(gettext('No matching files.')) + raise errors.CommandError(gettext("No matching files.")) elif file_list is None: # missing files show up in iter_changes(basis) as # versioned-with-no-kind. 
@@ -1959,10 +2201,14 @@ def run(self, file_list, verbose=False, new=False, if change.path[1] is not None and change.kind[1] is None: missing.append(change.path[1]) file_list = sorted(missing, reverse=True) - file_deletion_strategy = 'keep' - tree.remove(file_list, verbose=verbose, to_file=self.outf, - keep_files=file_deletion_strategy == 'keep', - force=(file_deletion_strategy == 'no-backup')) + file_deletion_strategy = "keep" + tree.remove( + file_list, + verbose=verbose, + to_file=self.outf, + keep_files=file_deletion_strategy == "keep", + force=(file_deletion_strategy == "no-backup"), + ) class cmd_reconcile(Command): @@ -1984,17 +2230,19 @@ class cmd_reconcile(Command): The branch *MUST* be on a listable system such as local disk or sftp. """ - _see_also = ['check'] - takes_args = ['branch?'] + _see_also = ["check"] + takes_args = ["branch?"] takes_options = [ - Option('canonicalize-chks', - help='Make sure CHKs are in canonical form (repairs ' - 'bug 522637).', - hidden=True), - ] + Option( + "canonicalize-chks", + help="Make sure CHKs are in canonical form (repairs " "bug 522637).", + hidden=True, + ), + ] def run(self, branch=".", canonicalize_chks=False): from .reconcile import reconcile + dir = controldir.ControlDir.open(branch) reconcile(dir, canonicalize_chks=canonicalize_chks) @@ -2002,8 +2250,8 @@ def run(self, branch=".", canonicalize_chks=False): class cmd_revision_history(Command): __doc__ = """Display the list of revision ids on a branch.""" - _see_also = ['log'] - takes_args = ['location?'] + _see_also = ["log"] + takes_args = ["location?"] hidden = True @@ -2012,24 +2260,28 @@ def run(self, location="."): branch = Branch.open_containing(location)[0] self.enter_context(branch.lock_read()) graph = branch.repository.get_graph() - history = list(graph.iter_lefthand_ancestry(branch.last_revision(), - [_mod_revision.NULL_REVISION])) + history = list( + graph.iter_lefthand_ancestry( + branch.last_revision(), [_mod_revision.NULL_REVISION] + ) + ) for revid in reversed(history): self.outf.write(revid) - self.outf.write('\n') + self.outf.write("\n") class cmd_ancestry(Command): __doc__ = """List all revisions merged into this branch.""" - _see_also = ['log', 'revision-history'] - takes_args = ['location?'] + _see_also = ["log", "revision-history"] + takes_args = ["location?"] hidden = True @display_command def run(self, location="."): from .workingtree import WorkingTree + try: wt = WorkingTree.open_containing(location)[0] except errors.NoWorkingTree: @@ -2041,12 +2293,11 @@ def run(self, location="."): self.enter_context(b.repository.lock_read()) graph = b.repository.get_graph() - revisions = [revid for revid, parents in - graph.iter_ancestry([last_revision])] + revisions = [revid for revid, parents in graph.iter_ancestry([last_revision])] for revision_id in reversed(revisions): if _mod_revision.is_null(revision_id): continue - self.outf.write(revision_id.decode('utf-8') + '\n') + self.outf.write(revision_id.decode("utf-8") + "\n") class cmd_init(Command): @@ -2072,36 +2323,47 @@ class cmd_init(Command): brz commit -m "imported project" """ - _see_also = ['init-shared-repository', 'branch', 'checkout'] - takes_args = ['location?'] + _see_also = ["init-shared-repository", "branch", "checkout"] + takes_args = ["location?"] takes_options = [ - Option('create-prefix', - help='Create the path leading up to the branch ' - 'if it does not already exist.'), - RegistryOption('format', - help='Specify a format for this branch. 
' - 'See "help formats" for a full list.', - lazy_registry=('breezy.controldir', 'format_registry'), - converter=lambda name: controldir.format_registry.make_controldir( # type: ignore - name), - value_switches=True, - title="Branch format", - ), - Option('append-revisions-only', - help='Never change revnos or the existing log.' - ' Append revisions to it only.'), - Option('no-tree', - 'Create a branch without a working tree.') - ] - - def run(self, location=None, format=None, append_revisions_only=False, - create_prefix=False, no_tree=False): + Option( + "create-prefix", + help="Create the path leading up to the branch " + "if it does not already exist.", + ), + RegistryOption( + "format", + help="Specify a format for this branch. " + 'See "help formats" for a full list.', + lazy_registry=("breezy.controldir", "format_registry"), + converter=lambda name: controldir.format_registry.make_controldir( # type: ignore + name + ), + value_switches=True, + title="Branch format", + ), + Option( + "append-revisions-only", + help="Never change revnos or the existing log." + " Append revisions to it only.", + ), + Option("no-tree", "Create a branch without a working tree."), + ] + + def run( + self, + location=None, + format=None, + append_revisions_only=False, + create_prefix=False, + no_tree=False, + ): if format is None: - format = controldir.format_registry.make_controldir('default') + format = controldir.format_registry.make_controldir("default") if location is None: - location = '.' + location = "." - to_transport = transport.get_transport(location, purpose='write') + to_transport = transport.get_transport(location, purpose="write") # The path has to exist to initialize a # branch inside of it. @@ -2112,16 +2374,19 @@ def run(self, location=None, format=None, append_revisions_only=False, to_transport.ensure_base() except transport.NoSuchFile as exc: if not create_prefix: - raise errors.CommandError(gettext("Parent directory of %s" - " does not exist." - "\nYou may supply --create-prefix to create all" - " leading parent directories.") - % location) from exc + raise errors.CommandError( + gettext( + "Parent directory of %s" + " does not exist." + "\nYou may supply --create-prefix to create all" + " leading parent directories." + ) + % location + ) from exc to_transport.create_prefix() try: - a_controldir = controldir.ControlDir.open_from_transport( - to_transport) + a_controldir = controldir.ControlDir.open_from_transport(to_transport) except errors.NotBranchError: # really a NotBzrDir error... 
create_branch = controldir.ControlDir.create_branch_convenience @@ -2129,15 +2394,21 @@ def run(self, location=None, format=None, append_revisions_only=False, force_new_tree = False else: force_new_tree = None - branch = create_branch(to_transport.base, format=format, - possible_transports=[to_transport], - force_new_tree=force_new_tree) + branch = create_branch( + to_transport.base, + format=format, + possible_transports=[to_transport], + force_new_tree=force_new_tree, + ) a_controldir = branch.controldir else: from .transport.local import LocalTransport + if a_controldir.has_branch(): - if (isinstance(to_transport, LocalTransport) - and not a_controldir.has_workingtree()): + if ( + isinstance(to_transport, LocalTransport) + and not a_controldir.has_workingtree() + ): raise errors.BranchExistsWithoutWorkingTree(location) raise errors.AlreadyBranchError(location) branch = a_controldir.create_branch() @@ -2147,10 +2418,15 @@ def run(self, location=None, format=None, append_revisions_only=False, try: branch.set_append_revisions_only(True) except errors.UpgradeRequired as exc: - raise errors.CommandError(gettext('This branch format cannot be set' - ' to append-revisions-only. Try --default.')) from exc + raise errors.CommandError( + gettext( + "This branch format cannot be set" + " to append-revisions-only. Try --default." + ) + ) from exc if not is_quiet(): from .info import describe_format, describe_layout + try: tree = a_controldir.open_workingtree(recommend_upgrade=False) except (errors.NoWorkingTree, errors.NotLocalUrl): @@ -2158,8 +2434,9 @@ def run(self, location=None, format=None, append_revisions_only=False, repository = branch.repository layout = describe_layout(repository, branch, tree).lower() format = describe_format(a_controldir, repository, branch, tree) - self.outf.write(gettext("Created a {0} (format: {1})\n").format( - layout, format)) + self.outf.write( + gettext("Created a {0} (format: {1})\n").format(layout, format) + ) if repository.is_shared(): # XXX: maybe this can be refactored into transport.path_or_url() url = repository.controldir.root_transport.external_url() @@ -2198,44 +2475,59 @@ class cmd_init_shared_repository(Command): (add files here) """ - _see_also = ['init', 'branch', 'checkout', 'repositories'] + _see_also = ["init", "branch", "checkout", "repositories"] takes_args = ["location"] - takes_options = [RegistryOption('format', - help='Specify a format for this repository. See' - ' "brz help formats" for details.', - lazy_registry=( - 'breezy.controldir', 'format_registry'), - converter=lambda name: controldir.format_registry.make_controldir( # type: ignore - name), - value_switches=True, title='Repository format'), - Option('no-trees', - help='Branches in the repository will default to' - ' not having a working tree.'), - ] + takes_options = [ + RegistryOption( + "format", + help="Specify a format for this repository. 
See" + ' "brz help formats" for details.', + lazy_registry=("breezy.controldir", "format_registry"), + converter=lambda name: controldir.format_registry.make_controldir( # type: ignore + name + ), + value_switches=True, + title="Repository format", + ), + Option( + "no-trees", + help="Branches in the repository will default to" + " not having a working tree.", + ), + ] aliases = ["init-shared-repo", "init-repo"] def run(self, location, format=None, no_trees=False): if format is None: - format = controldir.format_registry.make_controldir('default') + format = controldir.format_registry.make_controldir("default") if location is None: - location = '.' + location = "." - to_transport = transport.get_transport(location, purpose='write') + to_transport = transport.get_transport(location, purpose="write") if format.fixed_components: repo_format_name = None else: repo_format_name = format.repository_format.get_format_string() - (repo, newdir, require_stacking, repository_policy) = ( - format.initialize_on_transport_ex(to_transport, - create_prefix=True, make_working_trees=not no_trees, - shared_repo=True, force_new_repo=True, - use_existing_dir=True, - repo_format_name=repo_format_name)) + ( + repo, + newdir, + require_stacking, + repository_policy, + ) = format.initialize_on_transport_ex( + to_transport, + create_prefix=True, + make_working_trees=not no_trees, + shared_repo=True, + force_new_repo=True, + use_existing_dir=True, + repo_format_name=repo_format_name, + ) if not is_quiet(): from .info import show_bzrdir_info + show_bzrdir_info(newdir, verbose=0, outfile=self.outf) @@ -2325,108 +2617,157 @@ class cmd_diff(Command): brz diff --using /usr/bin/diff --diff-options -wu """ - _see_also = ['status'] - takes_args = ['file*'] + _see_also = ["status"] + takes_args = ["file*"] takes_options = [ - Option('diff-options', type=str, - help='Pass these options to the external diff program.'), - Option('prefix', type=str, - short_name='p', - help='Set prefixes added to old and new filenames, as ' - 'two values separated by a colon. (eg "old/:new/").'), - Option('old', - help='Branch/tree to compare from.', - type=str, - ), - Option('new', - help='Branch/tree to compare to.', - type=str, - ), - 'revision', - 'change', - Option('using', - help='Use this command to compare files.', - type=str, - ), - RegistryOption('format', - short_name='F', - help='Diff format to use.', - lazy_registry=('breezy.diff', 'format_registry'), - title='Diff format'), - Option('context', - help='How many lines of context to show.', - type=int, - ), + Option( + "diff-options", + type=str, + help="Pass these options to the external diff program.", + ), + Option( + "prefix", + type=str, + short_name="p", + help="Set prefixes added to old and new filenames, as " + 'two values separated by a colon. 
(eg "old/:new/").', + ), + Option( + "old", + help="Branch/tree to compare from.", + type=str, + ), + Option( + "new", + help="Branch/tree to compare to.", + type=str, + ), + "revision", + "change", + Option( + "using", + help="Use this command to compare files.", + type=str, + ), + RegistryOption( + "format", + short_name="F", + help="Diff format to use.", + lazy_registry=("breezy.diff", "format_registry"), + title="Diff format", + ), + Option( + "context", + help="How many lines of context to show.", + type=int, + ), RegistryOption.from_kwargs( - 'color', - help='Color mode to use.', - title='Color Mode', value_switches=False, enum_switch=True, - never='Never colorize output.', - auto='Only colorize output if terminal supports it and STDOUT is a' - ' TTY.', - always='Always colorize output (default).'), + "color", + help="Color mode to use.", + title="Color Mode", + value_switches=False, + enum_switch=True, + never="Never colorize output.", + auto="Only colorize output if terminal supports it and STDOUT is a" " TTY.", + always="Always colorize output (default).", + ), Option( - 'check-style', - help=('Warn if trailing whitespace or spurious changes have been' - ' added.')) - ] + "check-style", + help=( + "Warn if trailing whitespace or spurious changes have been" " added." + ), + ), + ] - aliases = ['di', 'dif'] - encoding_type = 'exact' + aliases = ["di", "dif"] + encoding_type = "exact" @display_command - def run(self, revision=None, file_list=None, diff_options=None, - prefix=None, old=None, new=None, using=None, format=None, - context=None, color='auto'): + def run( + self, + revision=None, + file_list=None, + diff_options=None, + prefix=None, + old=None, + new=None, + using=None, + format=None, + context=None, + color="auto", + ): from .diff import get_trees_and_branches_to_diff_locked, show_diff_trees - if prefix == '0': + if prefix == "0": # diff -p0 format - old_label = '' - new_label = '' - elif prefix == '1' or prefix is None: - old_label = 'old/' - new_label = 'new/' - elif ':' in prefix: + old_label = "" + new_label = "" + elif prefix == "1" or prefix is None: + old_label = "old/" + new_label = "new/" + elif ":" in prefix: old_label, new_label = prefix.split(":") else: - raise errors.CommandError(gettext( - '--prefix expects two values separated by a colon' - ' (eg "old/:new/")')) + raise errors.CommandError( + gettext( + "--prefix expects two values separated by a colon" + ' (eg "old/:new/")' + ) + ) if revision and len(revision) > 2: - raise errors.CommandError(gettext('brz diff --revision takes exactly' - ' one or two revision specifiers')) + raise errors.CommandError( + gettext( + "brz diff --revision takes exactly" + " one or two revision specifiers" + ) + ) if using is not None and format is not None: - raise errors.CommandError(gettext( - '{0} and {1} are mutually exclusive').format( - '--using', '--format')) - - (old_tree, new_tree, - old_branch, new_branch, - specific_files, extra_trees) = get_trees_and_branches_to_diff_locked( - file_list, revision, old, new, self._exit_stack, apply_view=True) + raise errors.CommandError( + gettext("{0} and {1} are mutually exclusive").format( + "--using", "--format" + ) + ) + + ( + old_tree, + new_tree, + old_branch, + new_branch, + specific_files, + extra_trees, + ) = get_trees_and_branches_to_diff_locked( + file_list, revision, old, new, self._exit_stack, apply_view=True + ) # GNU diff on Windows uses ANSI encoding for filenames path_encoding = osutils.get_diff_header_encoding() outf = self.outf - if color == 'auto': + if color 
== "auto": from .terminal import has_ansi_colors + if has_ansi_colors(): - color = 'always' + color = "always" else: - color = 'never' - if 'always' == color: + color = "never" + if "always" == color: from .colordiff import DiffWriter + outf = DiffWriter(outf) - return show_diff_trees(old_tree, new_tree, outf, - specific_files=specific_files, - external_diff_options=diff_options, - old_label=old_label, new_label=new_label, - extra_trees=extra_trees, - path_encoding=path_encoding, - using=using, context=context, - format_cls=format) + return show_diff_trees( + old_tree, + new_tree, + outf, + specific_files=specific_files, + external_diff_options=diff_options, + old_label=old_label, + new_label=new_label, + extra_trees=extra_trees, + path_encoding=path_encoding, + using=using, + context=context, + format_cls=format, + ) class cmd_deleted(Command): @@ -2438,12 +2779,13 @@ class cmd_deleted(Command): # directories with readdir, rather than stating each one. Same # level of effort but possibly much less IO. (Or possibly not, # if the directories are very large...) - _see_also = ['status', 'ls'] - takes_options = ['directory', 'show-ids'] + _see_also = ["status", "ls"] + takes_options = ["directory", "show-ids"] @display_command - def run(self, show_ids=False, directory='.'): + def run(self, show_ids=False, directory="."): from .workingtree import WorkingTree + tree = WorkingTree.open_containing(directory)[0] self.enter_context(tree.lock_read()) old = tree.basis_tree() @@ -2452,9 +2794,9 @@ def run(self, show_ids=False, directory='.'): for change in delta.removed: self.outf.write(change.path[0]) if show_ids: - self.outf.write(' ') + self.outf.write(" ") self.outf.write(change.file_id) - self.outf.write('\n') + self.outf.write("\n") class cmd_modified(Command): @@ -2462,21 +2804,22 @@ class cmd_modified(Command): """ hidden = True - _see_also = ['status', 'ls'] - takes_options = ['directory', 'null'] + _see_also = ["status", "ls"] + takes_options = ["directory", "null"] @display_command - def run(self, null=False, directory='.'): + def run(self, null=False, directory="."): from .workingtree import WorkingTree + tree = WorkingTree.open_containing(directory)[0] self.enter_context(tree.lock_read()) td = tree.changes_from(tree.basis_tree()) self.cleanup_now() for change in td.modified: if null: - self.outf.write(change.path[1] + '\0') + self.outf.write(change.path[1] + "\0") else: - self.outf.write(osutils.quotefn(change.path[1]) + '\n') + self.outf.write(osutils.quotefn(change.path[1]) + "\n") class cmd_added(Command): @@ -2484,12 +2827,13 @@ class cmd_added(Command): """ hidden = True - _see_also = ['status', 'ls'] - takes_options = ['directory', 'null'] + _see_also = ["status", "ls"] + takes_options = ["directory", "null"] @display_command - def run(self, null=False, directory='.'): + def run(self, null=False, directory="."): from .workingtree import WorkingTree + wt = WorkingTree.open_containing(directory)[0] self.enter_context(wt.lock_read()) basis = wt.basis_tree() @@ -2497,14 +2841,14 @@ def run(self, null=False, directory='.'): for path in wt.all_versioned_paths(): if basis.has_filename(path): continue - if path == '': + if path == "": continue if not os.access(osutils.pathjoin(wt.basedir, path), os.F_OK): continue if null: - self.outf.write(path + '\0') + self.outf.write(path + "\0") else: - self.outf.write(osutils.quotefn(path) + '\n') + self.outf.write(osutils.quotefn(path) + "\n") class cmd_root(Command): @@ -2513,14 +2857,15 @@ class cmd_root(Command): The root is the nearest enclosing 
directory with a control directory.""" - takes_args = ['filename?'] + takes_args = ["filename?"] @display_command def run(self, filename=None): """Print the branch root.""" from .workingtree import WorkingTree + tree = WorkingTree.open_containing(filename)[0] - self.outf.write(tree.basedir + '\n') + self.outf.write(tree.basedir + "\n") def _parse_limit(limitstring): @@ -2694,128 +3039,146 @@ class cmd_log(Command): the historycache plugin. This plugin buffers historical information trading disk space for faster speed. """ - takes_args = ['file*'] - _see_also = ['log-formats', 'revisionspec'] + takes_args = ["file*"] + _see_also = ["log-formats", "revisionspec"] takes_options = [ - Option('forward', - help='Show from oldest to newest.'), - 'timezone', - custom_help('verbose', - help='Show files changed in each revision.'), - 'show-ids', - 'revision', - Option('change', - type=breezy.option._parse_revision_str, - short_name='c', - help='Show just the specified revision.' - ' See also "help revisionspec".'), - 'log-format', - RegistryOption('authors', - 'What names to list as authors - first, all or committer.', - title='Authors', - lazy_registry=( - 'breezy.log', 'author_list_registry'), - ), - Option('levels', - short_name='n', - help='Number of levels to display - 0 for all, 1 for flat.', - argname='N', - type=_parse_levels), - Option('message', - help='Show revisions whose message matches this ' - 'regular expression.', - type=str, - hidden=True), - Option('limit', - short_name='l', - help='Limit the output to the first N revisions.', - argname='N', - type=_parse_limit), - Option('show-diff', - short_name='p', - help='Show changes made in each revision as a patch.'), - Option('include-merged', - help='Show merged revisions like --levels 0 does.'), - Option('include-merges', hidden=True, - help='Historical alias for --include-merged.'), - Option('omit-merges', - help='Do not report commits with more than one parent.'), - Option('exclude-common-ancestry', - help='Display only the revisions that are not part' - ' of both ancestries (require -rX..Y).' - ), - Option('signatures', - help='Show digital signature validity.'), - ListOption('match', - short_name='m', - help='Show revisions whose properties match this ' - 'expression.', - type=str), - ListOption('match-message', - help='Show revisions whose message matches this ' - 'expression.', - type=str), - ListOption('match-committer', - help='Show revisions whose committer matches this ' - 'expression.', - type=str), - ListOption('match-author', - help='Show revisions whose authors match this ' - 'expression.', - type=str), - ListOption('match-bugs', - help='Show revisions whose bugs match this ' - 'expression.', - type=str) - ] - encoding_type = 'replace' + Option("forward", help="Show from oldest to newest."), + "timezone", + custom_help("verbose", help="Show files changed in each revision."), + "show-ids", + "revision", + Option( + "change", + type=breezy.option._parse_revision_str, + short_name="c", + help="Show just the specified revision." 
' See also "help revisionspec".', + ), + "log-format", + RegistryOption( + "authors", + "What names to list as authors - first, all or committer.", + title="Authors", + lazy_registry=("breezy.log", "author_list_registry"), + ), + Option( + "levels", + short_name="n", + help="Number of levels to display - 0 for all, 1 for flat.", + argname="N", + type=_parse_levels, + ), + Option( + "message", + help="Show revisions whose message matches this " "regular expression.", + type=str, + hidden=True, + ), + Option( + "limit", + short_name="l", + help="Limit the output to the first N revisions.", + argname="N", + type=_parse_limit, + ), + Option( + "show-diff", + short_name="p", + help="Show changes made in each revision as a patch.", + ), + Option("include-merged", help="Show merged revisions like --levels 0 does."), + Option( + "include-merges", hidden=True, help="Historical alias for --include-merged." + ), + Option("omit-merges", help="Do not report commits with more than one parent."), + Option( + "exclude-common-ancestry", + help="Display only the revisions that are not part" + " of both ancestries (require -rX..Y).", + ), + Option("signatures", help="Show digital signature validity."), + ListOption( + "match", + short_name="m", + help="Show revisions whose properties match this " "expression.", + type=str, + ), + ListOption( + "match-message", + help="Show revisions whose message matches this " "expression.", + type=str, + ), + ListOption( + "match-committer", + help="Show revisions whose committer matches this " "expression.", + type=str, + ), + ListOption( + "match-author", + help="Show revisions whose authors match this " "expression.", + type=str, + ), + ListOption( + "match-bugs", + help="Show revisions whose bugs match this " "expression.", + type=str, + ), + ] + encoding_type = "replace" @display_command - def run(self, file_list=None, timezone='original', - verbose=False, - show_ids=False, - forward=False, - revision=None, - change=None, - log_format=None, - levels=None, - message=None, - limit=None, - show_diff=False, - include_merged=None, - authors=None, - exclude_common_ancestry=False, - signatures=False, - match=None, - match_message=None, - match_committer=None, - match_author=None, - match_bugs=None, - omit_merges=False, - ): + def run( + self, + file_list=None, + timezone="original", + verbose=False, + show_ids=False, + forward=False, + revision=None, + change=None, + log_format=None, + levels=None, + message=None, + limit=None, + show_diff=False, + include_merged=None, + authors=None, + exclude_common_ancestry=False, + signatures=False, + match=None, + match_message=None, + match_committer=None, + match_author=None, + match_bugs=None, + omit_merges=False, + ): from .log import Logger, _get_info_for_log_files, make_log_request_dict - direction = (forward and 'forward') or 'reverse' + + direction = (forward and "forward") or "reverse" if include_merged is None: include_merged = False - if (exclude_common_ancestry - and (revision is None or len(revision) != 2)): - raise errors.CommandError(gettext( - '--exclude-common-ancestry requires -r with two revisions')) + if exclude_common_ancestry and (revision is None or len(revision) != 2): + raise errors.CommandError( + gettext("--exclude-common-ancestry requires -r with two revisions") + ) if include_merged: if levels is None: levels = 0 else: - raise errors.CommandError(gettext( - '{0} and {1} are mutually exclusive').format( - '--levels', '--include-merged')) + raise errors.CommandError( + gettext("{0} and {1} are mutually 
exclusive").format( + "--levels", "--include-merged" + ) + ) if change is not None: if len(change) > 1: raise errors.RangeInChangeOption() if revision is not None: - raise errors.CommandError(gettext( - '{0} and {1} are mutually exclusive').format( - '--revision', '--change')) + raise errors.CommandError( + gettext("{0} and {1} are mutually exclusive").format( + "--revision", "--change" + ) + ) else: revision = change @@ -2824,34 +3187,36 @@ def run(self, file_list=None, timezone='original', if file_list: # find the file ids to log and check for directory filtering b, file_info_list, rev1, rev2 = _get_info_for_log_files( - revision, file_list, self._exit_stack) + revision, file_list, self._exit_stack + ) for relpath, kind in file_info_list: if not kind: - raise errors.CommandError(gettext( - "Path unknown at end or start of revision range: %s") % - relpath) + raise errors.CommandError( + gettext("Path unknown at end or start of revision range: %s") + % relpath + ) # If the relpath is the top of the tree, we log everything - if relpath == '': + if relpath == "": files = [] break else: files.append(relpath) filter_by_dir = filter_by_dir or ( - kind in ['directory', 'tree-reference']) + kind in ["directory", "tree-reference"] + ) else: # log everything # FIXME ? log the current subdir only RBC 20060203 - if revision is not None \ - and len(revision) > 0 and revision[0].get_branch(): + if revision is not None and len(revision) > 0 and revision[0].get_branch(): location = revision[0].get_branch() else: - location = '.' + location = "." dir, relpath = controldir.ControlDir.open_containing(location) b = dir.open_branch() self.enter_context(b.lock_read()) rev1, rev2 = _get_revision_range(revision, b, self.name()) - if b.get_config_stack().get('validate_signatures_in_log'): + if b.get_config_stack().get("validate_signatures_in_log"): signatures = True if signatures: @@ -2863,27 +3228,29 @@ def run(self, file_list=None, timezone='original', if not verbose: delta_type = None else: - delta_type = 'full' + delta_type = "full" if not show_diff: diff_type = None elif files: - diff_type = 'partial' + diff_type = "partial" else: - diff_type = 'full' + diff_type = "full" # Build the log formatter if log_format is None: log_format = log.log_formatter_registry.get_default(b) # Make a non-encoding output to include the diffs - bug 328007 - unencoded_output = ui.ui_factory.make_output_stream( - encoding_type='exact') - lf = log_format(show_ids=show_ids, to_file=self.outf, - to_exact_file=unencoded_output, - show_timezone=timezone, - delta_format=get_verbosity_level(), - levels=levels, - show_advice=levels is None, - author_list_handler=authors) + unencoded_output = ui.ui_factory.make_output_stream(encoding_type="exact") + lf = log_format( + show_ids=show_ids, + to_file=self.outf, + to_exact_file=unencoded_output, + show_timezone=timezone, + delta_format=get_verbosity_level(), + levels=levels, + show_advice=levels is None, + author_list_handler=authors, + ) # Choose the algorithm for doing the logging. It's annoying # having multiple code paths like this but necessary until @@ -2898,32 +3265,40 @@ def run(self, file_list=None, timezone='original', # original algorithm - per-file-graph - for the "single # file that isn't a directory without showing a delta" case. 
partial_history = revision and b.repository._format.supports_chks - match_using_deltas = (len(files) != 1 or filter_by_dir - or delta_type or partial_history) + match_using_deltas = ( + len(files) != 1 or filter_by_dir or delta_type or partial_history + ) match_dict = {} if match: - match_dict[''] = match + match_dict[""] = match if match_message: - match_dict['message'] = match_message + match_dict["message"] = match_message if match_committer: - match_dict['committer'] = match_committer + match_dict["committer"] = match_committer if match_author: - match_dict['author'] = match_author + match_dict["author"] = match_author if match_bugs: - match_dict['bugs'] = match_bugs + match_dict["bugs"] = match_bugs # Build the LogRequest and execute it if len(files) == 0: files = None rqst = make_log_request_dict( - direction=direction, specific_files=files, - start_revision=rev1, end_revision=rev2, limit=limit, - message_search=message, delta_type=delta_type, - diff_type=diff_type, _match_using_deltas=match_using_deltas, - exclude_common_ancestry=exclude_common_ancestry, match=match_dict, - signature=signatures, omit_merges=omit_merges, - ) + direction=direction, + specific_files=files, + start_revision=rev1, + end_revision=rev2, + limit=limit, + message_search=message, + delta_type=delta_type, + diff_type=diff_type, + _match_using_deltas=match_using_deltas, + exclude_common_ancestry=exclude_common_ancestry, + match=match_dict, + signature=signatures, + omit_merges=omit_merges, + ) Logger(b, rqst).show(lf) @@ -2945,9 +3320,10 @@ def _get_revision_range(revisionspec_list, branch, command_name): # b is taken from revision[0].get_branch(), and # show_log will use its revision_history. Having # different branches will lead to weird behaviors. - raise errors.CommandError(gettext( - "brz %s doesn't accept two revisions in different" - " branches.") % command_name) + raise errors.CommandError( + gettext("brz %s doesn't accept two revisions in different" " branches.") + % command_name + ) if start_spec.spec is None: # Avoid loading all the history. rev1 = RevisionInfo(branch, None, None) @@ -2961,8 +3337,9 @@ def _get_revision_range(revisionspec_list, branch, command_name): else: rev2 = end_spec.in_history(branch) else: - raise errors.CommandError(gettext( - 'brz %s --revision takes one or two values.') % command_name) + raise errors.CommandError( + gettext("brz %s --revision takes one or two values.") % command_name + ) return rev1, rev2 @@ -2976,14 +3353,14 @@ def _revision_range_to_revid_range(revision_range): return rev_id1, rev_id2 -def get_log_format(long=False, short=False, line=False, default='long'): +def get_log_format(long=False, short=False, line=False, default="long"): log_format = default if long: - log_format = 'long' + log_format = "long" if short: - log_format = 'short' + log_format = "short" if line: - log_format = 'line' + log_format = "line" return log_format @@ -2999,10 +3376,12 @@ class cmd_touching_revisions(Command): @display_command def run(self, filename): from .workingtree import WorkingTree + tree, relpath = WorkingTree.open_containing(filename) with tree.lock_read(): touching_revs = log.find_touching_revisions( - tree.branch.repository, tree.branch.last_revision(), tree, relpath) + tree.branch.repository, tree.branch.last_revision(), tree, relpath + ) for revno, _revision_id, what in reversed(list(touching_revs)): self.outf.write("%6d %s\n" % (revno, what)) @@ -3011,68 +3390,80 @@ class cmd_ls(Command): __doc__ = """List files in a tree. 
""" - _see_also = ['status', 'cat'] - takes_args = ['path?'] + _see_also = ["status", "cat"] + takes_args = ["path?"] takes_options = [ - 'verbose', - 'revision', - Option('recursive', short_name='R', - help='Recurse into subdirectories.'), - Option('from-root', - help='Print paths relative to the root of the branch.'), - Option('unknown', short_name='u', - help='Print unknown files.'), - Option('versioned', help='Print versioned files.', - short_name='V'), - Option('ignored', short_name='i', - help='Print ignored files.'), - Option('kind', short_name='k', - help=('List entries of a particular kind: file, ' - 'directory, symlink, tree-reference.'), - type=str), - 'null', - 'show-ids', - 'directory', - ] + "verbose", + "revision", + Option("recursive", short_name="R", help="Recurse into subdirectories."), + Option("from-root", help="Print paths relative to the root of the branch."), + Option("unknown", short_name="u", help="Print unknown files."), + Option("versioned", help="Print versioned files.", short_name="V"), + Option("ignored", short_name="i", help="Print ignored files."), + Option( + "kind", + short_name="k", + help=( + "List entries of a particular kind: file, " + "directory, symlink, tree-reference." + ), + type=str, + ), + "null", + "show-ids", + "directory", + ] @display_command - def run(self, revision=None, verbose=False, - recursive=False, from_root=False, - unknown=False, versioned=False, ignored=False, - null=False, kind=None, show_ids=False, path=None, directory=None): + def run( + self, + revision=None, + verbose=False, + recursive=False, + from_root=False, + unknown=False, + versioned=False, + ignored=False, + null=False, + kind=None, + show_ids=False, + path=None, + directory=None, + ): from . import views from .workingtree import WorkingTree - if kind and kind not in ('file', 'directory', 'symlink', 'tree-reference'): - raise errors.CommandError(gettext('invalid kind specified')) + if kind and kind not in ("file", "directory", "symlink", "tree-reference"): + raise errors.CommandError(gettext("invalid kind specified")) if verbose and null: - raise errors.CommandError( - gettext('Cannot set both --verbose and --null')) + raise errors.CommandError(gettext("Cannot set both --verbose and --null")) all = not (unknown or versioned or ignored) - selection = {'I': ignored, '?': unknown, 'V': versioned} + selection = {"I": ignored, "?": unknown, "V": versioned} if path is None: - fs_path = '.' + fs_path = "." else: if from_root: - raise errors.CommandError(gettext('cannot specify both --from-root' - ' and PATH')) + raise errors.CommandError( + gettext("cannot specify both --from-root" " and PATH") + ) fs_path = path - tree, branch, relpath = \ - _open_directory_or_containing_tree_or_branch(fs_path, directory) + tree, branch, relpath = _open_directory_or_containing_tree_or_branch( + fs_path, directory + ) # Calculate the prefix to use prefix = None if from_root: if relpath: - prefix = relpath + '/' - elif fs_path != '.' and not fs_path.endswith('/'): - prefix = fs_path + '/' + prefix = relpath + "/" + elif fs_path != "." 
and not fs_path.endswith("/"): + prefix = fs_path + "/" if revision is not None or tree is None: - tree = _get_one_revision_tree('ls', revision, branch=branch) + tree = _get_one_revision_tree("ls", revision, branch=branch) apply_view = False if isinstance(tree, WorkingTree) and tree.supports_views(): @@ -3084,7 +3475,8 @@ def run(self, revision=None, verbose=False, self.enter_context(tree.lock_read()) for fp, fc, fkind, entry in tree.list_files( - include_root=False, from_dir=relpath, recursive=recursive): + include_root=False, from_dir=relpath, recursive=recursive + ): # Apply additional masking if not all and not selection[fc]: continue @@ -3107,26 +3499,26 @@ def run(self, revision=None, verbose=False, outstring = fp + kindch ui.ui_factory.clear_term() if verbose: - outstring = '%-8s %s' % (fc, outstring) - if show_ids and getattr(entry, 'file_id', None) is not None: - outstring = "%-50s %s" % (outstring, entry.file_id.decode('utf-8')) - self.outf.write(outstring + '\n') + outstring = "%-8s %s" % (fc, outstring) + if show_ids and getattr(entry, "file_id", None) is not None: + outstring = "%-50s %s" % (outstring, entry.file_id.decode("utf-8")) + self.outf.write(outstring + "\n") elif null: - self.outf.write(fp + '\0') + self.outf.write(fp + "\0") if show_ids: - if getattr(entry, 'file_id', None) is not None: - self.outf.write(entry.file_id.decode('utf-8')) - self.outf.write('\0') + if getattr(entry, "file_id", None) is not None: + self.outf.write(entry.file_id.decode("utf-8")) + self.outf.write("\0") self.outf.flush() else: if show_ids: - if getattr(entry, 'file_id', None) is not None: - my_id = entry.file_id.decode('utf-8') + if getattr(entry, "file_id", None) is not None: + my_id = entry.file_id.decode("utf-8") else: - my_id = '' - self.outf.write('%-50s %s\n' % (outstring, my_id)) + my_id = "" + self.outf.write("%-50s %s\n" % (outstring, my_id)) else: - self.outf.write(outstring + '\n') + self.outf.write(outstring + "\n") class cmd_unknowns(Command): @@ -3134,14 +3526,15 @@ class cmd_unknowns(Command): """ hidden = True - _see_also = ['ls'] - takes_options = ['directory'] + _see_also = ["ls"] + takes_options = ["directory"] @display_command - def run(self, directory='.'): + def run(self, directory="."): from .workingtree import WorkingTree + for f in WorkingTree.open_containing(directory)[0].unknowns(): - self.outf.write(osutils.quotefn(f) + '\n') + self.outf.write(osutils.quotefn(f) + "\n") class cmd_ignore(Command): @@ -3219,61 +3612,75 @@ class cmd_ignore(Command): brz ignore "!!*~" """ - _see_also = ['status', 'ignored', 'patterns'] - takes_args = ['name_pattern*'] - takes_options = ['directory', - Option('default-rules', - help='Display the default ignore rules that brz uses.') - ] + _see_also = ["status", "ignored", "patterns"] + takes_args = ["name_pattern*"] + takes_options = [ + "directory", + Option("default-rules", help="Display the default ignore rules that brz uses."), + ] - def run(self, name_pattern_list=None, default_rules=None, - directory='.'): + def run(self, name_pattern_list=None, default_rules=None, directory="."): from breezy import ignores from . 
import globbing, lazy_regex from .workingtree import WorkingTree + if default_rules is not None: # dump the default rules and exit for pattern in ignores.USER_DEFAULTS: self.outf.write(f"{pattern}\n") return if not name_pattern_list: - raise errors.CommandError(gettext("ignore requires at least one " - "NAME_PATTERN or --default-rules.")) - name_pattern_list = [globbing.normalize_pattern(p) - for p in name_pattern_list] - bad_patterns = '' + raise errors.CommandError( + gettext( + "ignore requires at least one " "NAME_PATTERN or --default-rules." + ) + ) + name_pattern_list = [globbing.normalize_pattern(p) for p in name_pattern_list] + bad_patterns = "" bad_patterns_count = 0 for p in name_pattern_list: if not globbing.Globster.is_pattern_valid(p): bad_patterns_count += 1 - bad_patterns += f'\n {p}' + bad_patterns += f"\n {p}" if bad_patterns: - msg = (ngettext('Invalid ignore pattern found. %s', - 'Invalid ignore patterns found. %s', - bad_patterns_count) % bad_patterns) + msg = ( + ngettext( + "Invalid ignore pattern found. %s", + "Invalid ignore patterns found. %s", + bad_patterns_count, + ) + % bad_patterns + ) ui.ui_factory.show_error(msg) - raise lazy_regex.InvalidPattern('') + raise lazy_regex.InvalidPattern("") for name_pattern in name_pattern_list: - if (name_pattern[0] == '/' or - (len(name_pattern) > 1 and name_pattern[1] == ':')): - raise errors.CommandError(gettext( - "NAME_PATTERN should not be an absolute path")) + if name_pattern[0] == "/" or ( + len(name_pattern) > 1 and name_pattern[1] == ":" + ): + raise errors.CommandError( + gettext("NAME_PATTERN should not be an absolute path") + ) tree, relpath = WorkingTree.open_containing(directory) ignores.tree_ignores_add_patterns(tree, name_pattern_list) ignored = globbing.Globster(name_pattern_list) matches = [] self.enter_context(tree.lock_read()) for filename, _fc, _fkind, entry in tree.list_files(): - id = getattr(entry, 'file_id', None) + id = getattr(entry, "file_id", None) if id is not None: if ignored.match(filename): matches.append(filename) if len(matches) > 0: - self.outf.write(gettext("Warning: the following files are version " - "controlled and match your ignore pattern:\n%s" - "\nThese files will continue to be version controlled" - " unless you 'brz remove' them.\n") % ("\n".join(matches),)) + self.outf.write( + gettext( + "Warning: the following files are version " + "controlled and match your ignore pattern:\n%s" + "\nThese files will continue to be version controlled" + " unless you 'brz remove' them.\n" + ) + % ("\n".join(matches),) + ) class cmd_ignored(Command): @@ -3287,21 +3694,22 @@ class cmd_ignored(Command): brz ls --ignored """ - encoding_type = 'replace' - _see_also = ['ignore', 'ls'] - takes_options = ['directory'] + encoding_type = "replace" + _see_also = ["ignore", "ls"] + takes_options = ["directory"] @display_command - def run(self, directory='.'): + def run(self, directory="."): from .workingtree import WorkingTree + tree = WorkingTree.open_containing(directory)[0] self.enter_context(tree.lock_read()) for path, file_class, _kind, _entry in tree.list_files(): - if file_class != 'I': + if file_class != "I": continue # XXX: Slightly inefficient since this was already calculated pat = tree.is_ignored(path) - self.outf.write('%-50s %s\n' % (path, pat)) + self.outf.write("%-50s %s\n" % (path, pat)) class cmd_lookup_revision(Command): @@ -3311,19 +3719,20 @@ class cmd_lookup_revision(Command): brz lookup-revision 33 """ hidden = True - takes_args = ['revno'] - takes_options = ['directory'] + takes_args 
= ["revno"] + takes_options = ["directory"] @display_command - def run(self, revno, directory='.'): + def run(self, revno, directory="."): from .workingtree import WorkingTree + try: revno = int(revno) except ValueError as exc: - raise errors.CommandError(gettext("not a valid revision-number: %r") - % revno) from exc - revid = WorkingTree.open_containing( - directory)[0].branch.get_rev_id(revno) + raise errors.CommandError( + gettext("not a valid revision-number: %r") % revno + ) from exc + revid = WorkingTree.open_containing(directory)[0].branch.get_rev_id(revno) self.outf.write(f"{revid.decode('utf-8')}\n") @@ -3355,51 +3764,68 @@ class cmd_export(Command): zip .zip ================= ========================= """ - encoding = 'exact' - encoding_type = 'exact' - takes_args = ['dest', 'branch_or_subdir?'] - takes_options = ['directory', - Option('format', - help="Type of file to export to.", - type=str), - 'revision', - Option('filters', help='Apply content filters to export the ' - 'convenient form.'), - Option('root', - type=str, - help="Name of the root directory inside the exported file."), - Option('per-file-timestamps', - help='Set modification time of files to that of the last ' - 'revision in which it was changed.'), - Option('uncommitted', - help='Export the working tree contents rather than that of the ' - 'last revision.'), - Option('recurse-nested', - help='Include contents of nested trees.'), - ] - - def run(self, dest, branch_or_subdir=None, revision=None, format=None, - root=None, filters=False, per_file_timestamps=False, uncommitted=False, - directory='.', recurse_nested=False): + encoding = "exact" + encoding_type = "exact" + takes_args = ["dest", "branch_or_subdir?"] + takes_options = [ + "directory", + Option("format", help="Type of file to export to.", type=str), + "revision", + Option( + "filters", help="Apply content filters to export the " "convenient form." 
+ ), + Option( + "root", + type=str, + help="Name of the root directory inside the exported file.", + ), + Option( + "per-file-timestamps", + help="Set modification time of files to that of the last " + "revision in which it was changed.", + ), + Option( + "uncommitted", + help="Export the working tree contents rather than that of the " + "last revision.", + ), + Option("recurse-nested", help="Include contents of nested trees."), + ] + + def run( + self, + dest, + branch_or_subdir=None, + revision=None, + format=None, + root=None, + filters=False, + per_file_timestamps=False, + uncommitted=False, + directory=".", + recurse_nested=False, + ): from .export import export, get_root_name, guess_format if branch_or_subdir is None: branch_or_subdir = directory (tree, b, subdir) = controldir.ControlDir.open_containing_tree_or_branch( - branch_or_subdir) + branch_or_subdir + ) if tree is not None: self.enter_context(tree.lock_read()) if uncommitted: if tree is None: raise errors.CommandError( - gettext("--uncommitted requires a working tree")) + gettext("--uncommitted requires a working tree") + ) export_tree = tree else: export_tree = _get_one_revision_tree( - 'export', revision, branch=b, - tree=tree) + "export", revision, branch=b, tree=tree + ) if format is None: format = guess_format(dest) @@ -3414,16 +3840,25 @@ def run(self, dest, branch_or_subdir=None, revision=None, format=None, if filters: from .filter_tree import ContentFilterTree + export_tree = ContentFilterTree( - export_tree, export_tree._content_filter_stack) + export_tree, export_tree._content_filter_stack + ) try: - export(export_tree, dest, format, root, subdir, - per_file_timestamps=per_file_timestamps, - recurse_nested=recurse_nested) + export( + export_tree, + dest, + format, + root, + subdir, + per_file_timestamps=per_file_timestamps, + recurse_nested=recurse_nested, + ) except errors.NoSuchExportFormat as exc: raise errors.CommandError( - gettext('Unsupported export format: %s') % exc.format) from exc + gettext("Unsupported export format: %s") % exc.format + ) from exc class cmd_cat(Command): @@ -3435,65 +3870,78 @@ class cmd_cat(Command): binary file. """ - _see_also = ['ls'] - takes_options = ['directory', - Option('name-from-revision', - help='The path name in the old tree.'), - Option('filters', help='Apply content filters to display the ' - 'convenience form.'), - 'revision', - ] - takes_args = ['filename'] - encoding_type = 'exact' + _see_also = ["ls"] + takes_options = [ + "directory", + Option("name-from-revision", help="The path name in the old tree."), + Option( + "filters", help="Apply content filters to display the " "convenience form." 
+ ), + "revision", + ] + takes_args = ["filename"] + encoding_type = "exact" @display_command - def run(self, filename, revision=None, name_from_revision=False, - filters=False, directory=None): + def run( + self, + filename, + revision=None, + name_from_revision=False, + filters=False, + directory=None, + ): if revision is not None and len(revision) != 1: - raise errors.CommandError(gettext("brz cat --revision takes exactly" - " one revision specifier")) - tree, branch, relpath = \ - _open_directory_or_containing_tree_or_branch(filename, directory) + raise errors.CommandError( + gettext("brz cat --revision takes exactly" " one revision specifier") + ) + tree, branch, relpath = _open_directory_or_containing_tree_or_branch( + filename, directory + ) self.enter_context(branch.lock_read()) - return self._run(tree, branch, relpath, filename, revision, - name_from_revision, filters) + return self._run( + tree, branch, relpath, filename, revision, name_from_revision, filters + ) - def _run(self, tree, b, relpath, filename, revision, name_from_revision, - filtered): + def _run(self, tree, b, relpath, filename, revision, name_from_revision, filtered): import shutil + if tree is None: tree = b.basis_tree() - rev_tree = _get_one_revision_tree('cat', revision, branch=b) + rev_tree = _get_one_revision_tree("cat", revision, branch=b) self.enter_context(rev_tree.lock_read()) if name_from_revision: # Try in revision if requested if not rev_tree.is_versioned(relpath): - raise errors.CommandError(gettext( - "{0!r} is not present in revision {1}").format( - filename, rev_tree.get_revision_id())) + raise errors.CommandError( + gettext("{0!r} is not present in revision {1}").format( + filename, rev_tree.get_revision_id() + ) + ) rev_tree_path = relpath else: try: - rev_tree_path = _mod_tree.find_previous_path( - tree, rev_tree, relpath) + rev_tree_path = _mod_tree.find_previous_path(tree, rev_tree, relpath) except transport.NoSuchFile: rev_tree_path = None if rev_tree_path is None: # Path didn't exist in working tree if not rev_tree.is_versioned(relpath): - raise errors.CommandError(gettext( - "{0!r} is not present in revision {1}").format( - filename, rev_tree.get_revision_id())) + raise errors.CommandError( + gettext("{0!r} is not present in revision {1}").format( + filename, rev_tree.get_revision_id() + ) + ) else: # Fall back to the same path in the basis tree, if present. rev_tree_path = relpath if filtered: from .filter_tree import ContentFilterTree - filter_tree = ContentFilterTree( - rev_tree, rev_tree._content_filter_stack) + + filter_tree = ContentFilterTree(rev_tree, rev_tree._content_filter_stack) fileobj = filter_tree.get_file(rev_tree_path) else: fileobj = rev_tree.get_file(rev_tree_path) @@ -3567,93 +4015,133 @@ class cmd_commit(Command): one or more bugs. See ``brz help bugs`` for details. 
""" - _see_also = ['add', 'bugs', 'hooks', 'uncommit'] - takes_args = ['selected*'] + _see_also = ["add", "bugs", "hooks", "uncommit"] + takes_args = ["selected*"] takes_options = [ ListOption( - 'exclude', type=str, short_name='x', - help="Do not consider changes made to a given path."), - Option('message', type=str, - short_name='m', - help="Description of the new revision."), - 'verbose', - Option('unchanged', - help='Commit even if nothing has changed.'), - Option('file', type=str, - short_name='F', - argname='msgfile', - help='Take commit message from this file.'), - Option('strict', - help="Refuse to commit if there are unknown " - "files in the working tree."), - Option('commit-time', type=str, - help="Manually set a commit time using commit date " - "format, e.g. '2009-10-10 08:00:00 +0100'."), + "exclude", + type=str, + short_name="x", + help="Do not consider changes made to a given path.", + ), + Option( + "message", type=str, short_name="m", help="Description of the new revision." + ), + "verbose", + Option("unchanged", help="Commit even if nothing has changed."), + Option( + "file", + type=str, + short_name="F", + argname="msgfile", + help="Take commit message from this file.", + ), + Option( + "strict", + help="Refuse to commit if there are unknown " "files in the working tree.", + ), + Option( + "commit-time", + type=str, + help="Manually set a commit time using commit date " + "format, e.g. '2009-10-10 08:00:00 +0100'.", + ), ListOption( - 'bugs', type=str, - help="Link to a related bug. (see \"brz help bugs\")."), + "bugs", type=str, help='Link to a related bug. (see "brz help bugs").' + ), ListOption( - 'fixes', type=str, - help="Mark a bug as being fixed by this revision " - "(see \"brz help bugs\")."), + "fixes", + type=str, + help="Mark a bug as being fixed by this revision " '(see "brz help bugs").', + ), ListOption( - 'author', type=str, - help="Set the author's name, if it's different " - "from the committer."), - Option('local', - help="Perform a local commit in a bound " - "branch. Local commits are not pushed to " - "the master branch until a normal commit " - "is performed." - ), - Option('show-diff', short_name='p', - help='When no message is supplied, show the diff along' - ' with the status summary in the message editor.'), - Option('lossy', - help='When committing to a foreign version control ' - 'system do not push data that can not be natively ' - 'represented.'), ] - aliases = ['ci', 'checkin'] + "author", + type=str, + help="Set the author's name, if it's different " "from the committer.", + ), + Option( + "local", + help="Perform a local commit in a bound " + "branch. Local commits are not pushed to " + "the master branch until a normal commit " + "is performed.", + ), + Option( + "show-diff", + short_name="p", + help="When no message is supplied, show the diff along" + " with the status summary in the message editor.", + ), + Option( + "lossy", + help="When committing to a foreign version control " + "system do not push data that can not be natively " + "represented.", + ), + ] + aliases = ["ci", "checkin"] def _iter_bug_urls(self, bugs, branch, status): default_bugtracker = None # Configure the properties for bug fixing attributes. 
for bug in bugs: - tokens = bug.split(':') + tokens = bug.split(":") if len(tokens) == 1: if default_bugtracker is None: branch_config = branch.get_config_stack() - default_bugtracker = branch_config.get( - "bugtracker") + default_bugtracker = branch_config.get("bugtracker") if default_bugtracker is None: - raise errors.CommandError(gettext( - "No tracker specified for bug %s. Use the form " - "'tracker:id' or specify a default bug tracker " - "using the `bugtracker` option.\nSee " - "\"brz help bugs\" for more information on this " - "feature. Commit refused.") % bug) + raise errors.CommandError( + gettext( + "No tracker specified for bug %s. Use the form " + "'tracker:id' or specify a default bug tracker " + "using the `bugtracker` option.\nSee " + '"brz help bugs" for more information on this ' + "feature. Commit refused." + ) + % bug + ) tag = default_bugtracker bug_id = tokens[0] elif len(tokens) != 2: - raise errors.CommandError(gettext( - "Invalid bug %s. Must be in the form of 'tracker:id'. " - "See \"brz help bugs\" for more information on this " - "feature.\nCommit refused.") % bug) + raise errors.CommandError( + gettext( + "Invalid bug %s. Must be in the form of 'tracker:id'. " + 'See "brz help bugs" for more information on this ' + "feature.\nCommit refused." + ) + % bug + ) else: tag, bug_id = tokens try: yield bugtracker.get_bug_url(tag, branch, bug_id), status except bugtracker.UnknownBugTrackerAbbreviation as exc: - raise errors.CommandError(gettext( - 'Unrecognized bug %s. Commit refused.') % bug) from exc + raise errors.CommandError( + gettext("Unrecognized bug %s. Commit refused.") % bug + ) from exc except bugtracker.MalformedBugIdentifier as exc: - raise errors.CommandError(gettext( - "%s\nCommit refused.") % (exc,)) from exc - - def run(self, message=None, file=None, verbose=False, selected_list=None, - unchanged=False, strict=False, local=False, fixes=None, bugs=None, - author=None, show_diff=False, exclude=None, commit_time=None, - lossy=False): + raise errors.CommandError( + gettext("%s\nCommit refused.") % (exc,) + ) from exc + + def run( + self, + message=None, + file=None, + verbose=False, + selected_list=None, + unchanged=False, + strict=False, + local=False, + fixes=None, + bugs=None, + author=None, + show_diff=False, + exclude=None, + commit_time=None, + lossy=False, + ): import itertools from .commit import PointlessCommit @@ -3671,13 +4159,14 @@ def run(self, message=None, file=None, verbose=False, selected_list=None, try: commit_stamp, offset = patch.parse_patch_date(commit_time) except ValueError as exc: - raise errors.CommandError(gettext( - "Could not parse --commit-time: " + str(exc))) from exc + raise errors.CommandError( + gettext("Could not parse --commit-time: " + str(exc)) + ) from exc properties = {} tree, selected_list = WorkingTree.open_containing_paths(selected_list) - if selected_list == ['']: + if selected_list == [""]: # workaround - commit of root of tree should be exactly the same # as just default commit in that tree, and succeed even though # selected-file merge commit is not done yet @@ -3690,9 +4179,11 @@ def run(self, message=None, file=None, verbose=False, selected_list=None, bug_property = bugtracker.encode_fixes_bug_urls( itertools.chain( self._iter_bug_urls(bugs, tree.branch, bugtracker.RELATED), - self._iter_bug_urls(fixes, tree.branch, bugtracker.FIXED))) + self._iter_bug_urls(fixes, tree.branch, bugtracker.FIXED), + ) + ) if bug_property: - properties['bugs'] = bug_property + properties["bugs"] = bug_property if local and not 
tree.branch.get_bound_location(): raise errors.LocalRequiresBoundBranch() @@ -3708,28 +4199,33 @@ def run(self, message=None, file=None, verbose=False, selected_list=None, if file_exists: warning_msg = ( f'The commit message is a file name: "{message}".\n' - f'(use --file "{message}" to take commit message from that file)') + f'(use --file "{message}" to take commit message from that file)' + ) ui.ui_factory.show_warning(warning_msg) - if '\r' in message: - message = message.replace('\r\n', '\n') - message = message.replace('\r', '\n') + if "\r" in message: + message = message.replace("\r\n", "\n") + message = message.replace("\r", "\n") if file: - raise errors.CommandError(gettext( - "please specify either --message or --file")) + raise errors.CommandError( + gettext("please specify either --message or --file") + ) def get_message(commit_obj): """Callback to get commit message.""" if file: - with open(file, 'rb') as f: + with open(file, "rb") as f: my_message = f.read().decode(osutils.get_user_encoding()) elif message is not None: my_message = message else: # No message supplied: make one up. # text is the status of the tree - text = make_commit_message_template_encoded(tree, - selected_list, diff=show_diff, - output_encoding=osutils.get_user_encoding()) + text = make_commit_message_template_encoded( + tree, + selected_list, + diff=show_diff, + output_encoding=osutils.get_user_encoding(), + ) # start_message is the template generated from hooks # XXX: Warning - looks like hooks return unicode, # make_commit_message_template_encoded returns user encoding. @@ -3737,21 +4233,30 @@ def get_message(commit_obj): # avoid this. my_message = set_commit_message(commit_obj) if my_message is None: - start_message = generate_commit_message_template( - commit_obj) + start_message = generate_commit_message_template(commit_obj) if start_message is not None: start_message = start_message.encode( - osutils.get_user_encoding()) - my_message = edit_commit_message_encoded(text, - start_message=start_message) + osutils.get_user_encoding() + ) + my_message = edit_commit_message_encoded( + text, start_message=start_message + ) if my_message is None: - raise errors.CommandError(gettext("please specify a commit" - " message with either --message or --file")) + raise errors.CommandError( + gettext( + "please specify a commit" + " message with either --message or --file" + ) + ) if my_message == "": - raise errors.CommandError(gettext("Empty commit message specified." - " Please specify a commit message with either" - " --message or --file or leave a blank message" - " with --message \"\".")) + raise errors.CommandError( + gettext( + "Empty commit message specified." + " Please specify a commit message with either" + " --message or --file or leave a blank message" + ' with --message "".' 
+ ) + ) return my_message # The API permits a commit with a filter of [] to mean 'select nothing' @@ -3759,30 +4264,51 @@ def get_message(commit_obj): if not selected_list: selected_list = None try: - tree.commit(message_callback=get_message, - specific_files=selected_list, - allow_pointless=unchanged, strict=strict, local=local, - reporter=None, verbose=verbose, revprops=properties, - authors=author, timestamp=commit_stamp, - timezone=offset, - exclude=tree.safe_relpath_files(exclude), - lossy=lossy) + tree.commit( + message_callback=get_message, + specific_files=selected_list, + allow_pointless=unchanged, + strict=strict, + local=local, + reporter=None, + verbose=verbose, + revprops=properties, + authors=author, + timestamp=commit_stamp, + timezone=offset, + exclude=tree.safe_relpath_files(exclude), + lossy=lossy, + ) except PointlessCommit as exc: - raise errors.CommandError(gettext("No changes to commit." - " Please 'brz add' the files you want to commit, or use" - " --unchanged to force an empty commit.")) from exc + raise errors.CommandError( + gettext( + "No changes to commit." + " Please 'brz add' the files you want to commit, or use" + " --unchanged to force an empty commit." + ) + ) from exc except ConflictsInTree as exc: - raise errors.CommandError(gettext('Conflicts detected in working ' - 'tree. Use "brz conflicts" to list, "brz resolve FILE" to' - ' resolve.')) from exc + raise errors.CommandError( + gettext( + "Conflicts detected in working " + 'tree. Use "brz conflicts" to list, "brz resolve FILE" to' + " resolve." + ) + ) from exc except StrictCommitFailed as exc: - raise errors.CommandError(gettext("Commit refused because there are" - " unknown files in the working tree.")) from exc + raise errors.CommandError( + gettext( + "Commit refused because there are" + " unknown files in the working tree." + ) + ) from exc except errors.BoundBranchOutOfDate as exc: - exc.extra_help = (gettext("\n" - 'To commit to master branch, run update and then commit.\n' - 'You can also pass --local to commit to continue working ' - 'disconnected.')) + exc.extra_help = gettext( + "\n" + "To commit to master branch, run update and then commit.\n" + "You can also pass --local to commit to continue working " + "disconnected." + ) raise @@ -3836,21 +4362,24 @@ class cmd_check(Command): brz check baz """ - _see_also = ['reconcile'] - takes_args = ['path?'] - takes_options = ['verbose', - Option('branch', help="Check the branch related to the" - " current directory."), - Option('repo', help="Check the repository related to the" - " current directory."), - Option('tree', help="Check the working tree related to" - " the current directory.")] - - def run(self, path=None, verbose=False, branch=False, repo=False, - tree=False): + _see_also = ["reconcile"] + takes_args = ["path?"] + takes_options = [ + "verbose", + Option("branch", help="Check the branch related to the" " current directory."), + Option( + "repo", help="Check the repository related to the" " current directory." + ), + Option( + "tree", help="Check the working tree related to" " the current directory." + ), + ] + + def run(self, path=None, verbose=False, branch=False, repo=False, tree=False): from .check import check_dwim + if path is None: - path = '.' + path = "." if not branch and not repo and not tree: branch = repo = tree = True check_dwim(path, verbose, do_branch=branch, do_repo=repo, do_tree=tree) @@ -3889,24 +4418,29 @@ class cmd_upgrade(Command): https://www.breezy-vcs.org/doc/en/upgrade-guide/. 
""" - _see_also = ['check', 'reconcile', 'formats'] - takes_args = ['url?'] + _see_also = ["check", "reconcile", "formats"] + takes_args = ["url?"] takes_options = [ - RegistryOption('format', - help='Upgrade to a specific format. See "brz help' - ' formats" for details.', - lazy_registry=('breezy.controldir', 'format_registry'), - converter=lambda name: controldir.format_registry.make_controldir( # type: ignore - name), - value_switches=True, title='Branch format'), - Option('clean', - help='Remove the backup.bzr directory if successful.'), - Option('dry-run', - help="Show what would be done, but don't actually do anything."), + RegistryOption( + "format", + help='Upgrade to a specific format. See "brz help' + ' formats" for details.', + lazy_registry=("breezy.controldir", "format_registry"), + converter=lambda name: controldir.format_registry.make_controldir( # type: ignore + name + ), + value_switches=True, + title="Branch format", + ), + Option("clean", help="Remove the backup.bzr directory if successful."), + Option( + "dry-run", help="Show what would be done, but don't actually do anything." + ), ] - def run(self, url='.', format=None, clean=False, dry_run=False): + def run(self, url=".", format=None, clean=False, dry_run=False): from .upgrade import upgrade + exceptions = upgrade(url, format, clean_up=clean, dry_run=dry_run) if exceptions: if len(exceptions) == 1: @@ -3928,15 +4462,15 @@ class cmd_whoami(Command): brz whoami "Frank Chu " """ - takes_options = ['directory', - Option('email', - help='Display email address only.'), - Option('branch', - help='Set identity for the current branch instead of ' - 'globally.'), - ] - takes_args = ['name?'] - encoding_type = 'replace' + takes_options = [ + "directory", + Option("email", help="Display email address only."), + Option( + "branch", help="Set identity for the current branch instead of " "globally." + ), + ] + takes_args = ["name?"] + encoding_type = "replace" @display_command def run(self, email=False, branch=False, name=None, directory=None): @@ -3944,41 +4478,44 @@ def run(self, email=False, branch=False, name=None, directory=None): if directory is None: # use branch if we're inside one; otherwise global config try: - c = Branch.open_containing('.')[0].get_config_stack() + c = Branch.open_containing(".")[0].get_config_stack() except errors.NotBranchError: c = _mod_config.GlobalStack() else: c = Branch.open(directory).get_config_stack() - identity = c.get('email') + identity = c.get("email") if email: - self.outf.write(_mod_config.extract_email_address(identity) - + '\n') + self.outf.write(_mod_config.extract_email_address(identity) + "\n") else: - self.outf.write(identity + '\n') + self.outf.write(identity + "\n") return if email: - raise errors.CommandError(gettext("--email can only be used to display existing " - "identity")) + raise errors.CommandError( + gettext("--email can only be used to display existing " "identity") + ) # display a warning if an email address isn't included in the given name. try: _mod_config.extract_email_address(name) except _mod_config.NoEmailInUsername: - warning('"%s" does not seem to contain an email address. ' - 'This is allowed, but not recommended.', name) + warning( + '"%s" does not seem to contain an email address. 
' + "This is allowed, but not recommended.", + name, + ) # use global config unless --branch given if branch: if directory is None: - c = Branch.open_containing('.')[0].get_config_stack() + c = Branch.open_containing(".")[0].get_config_stack() else: b = Branch.open(directory) self.enter_context(b.lock_write()) c = b.get_config_stack() else: c = _mod_config.GlobalStack() - c.set('email', name) + c.set("email", name) class cmd_nick(Command): @@ -3992,11 +4529,11 @@ class cmd_nick(Command): locally. """ - _see_also = ['info'] - takes_args = ['nickname?'] - takes_options = ['directory'] + _see_also = ["info"] + takes_args = ["nickname?"] + takes_options = ["directory"] - def run(self, nickname=None, directory='.'): + def run(self, nickname=None, directory="."): branch = Branch.open_containing(directory)[0] if nickname is None: self.printme(branch) @@ -4005,7 +4542,7 @@ def run(self, nickname=None, directory='.'): @display_command def printme(self, branch): - self.outf.write(f'{branch.nick}\n') + self.outf.write(f"{branch.nick}\n") class cmd_alias(Command): @@ -4029,10 +4566,10 @@ class cmd_alias(Command): brz alias --remove ll """ - takes_args = ['name?'] + takes_args = ["name?"] takes_options = [ - Option('remove', help='Remove the alias.'), - ] + Option("remove", help="Remove the alias."), + ] def run(self, name=None, remove=False): if remove: @@ -4040,16 +4577,17 @@ def run(self, name=None, remove=False): elif name is None: self.print_aliases() else: - equal_pos = name.find('=') + equal_pos = name.find("=") if equal_pos == -1: self.print_alias(name) else: - self.set_alias(name[:equal_pos], name[equal_pos + 1:]) + self.set_alias(name[:equal_pos], name[equal_pos + 1 :]) def remove_alias(self, alias_name): if alias_name is None: - raise errors.CommandError(gettext( - 'brz alias --remove expects an alias to remove.')) + raise errors.CommandError( + gettext("brz alias --remove expects an alias to remove.") + ) # If alias is not found, print something like: # unalias: foo: not found c = _mod_config.GlobalConfig() @@ -4060,17 +4598,17 @@ def print_aliases(self): """Print out the defined aliases in a similar format to bash.""" aliases = _mod_config.GlobalConfig().get_aliases() for key, value in sorted(aliases.items()): - self.outf.write(f'brz alias {key}="{value}\"\n') + self.outf.write(f'brz alias {key}="{value}"\n') @display_command def print_alias(self, alias_name): from .commands import get_alias + alias = get_alias(alias_name) if alias is None: self.outf.write(f"brz alias: {alias_name}: not found\n") else: - self.outf.write( - f"brz alias {alias_name}=\"{' '.join(alias)}\"\n") + self.outf.write(f"brz alias {alias_name}=\"{' '.join(alias)}\"\n") def set_alias(self, alias_name, alias_command): """Save the alias in the global config.""" @@ -4135,87 +4673,123 @@ def get_transport_type(typestring): """Parse and return a transport specifier.""" if typestring == "sftp": from .tests import stub_sftp + return stub_sftp.SFTPAbsoluteServer elif typestring == "memory": from breezy.transport import memory from .tests import test_server + return memory.MemoryServer elif typestring == "fakenfs": from .tests import test_server + return test_server.FakeNFSServer msg = f"No known transport type {typestring}. 
Supported types are: sftp\n" raise errors.CommandError(msg) hidden = True - takes_args = ['testspecs*'] - takes_options = ['verbose', - Option('one', - help='Stop when one test fails.', - short_name='1', - ), - Option('transport', - help='Use a different transport by default ' - 'throughout the test suite.', - type=get_transport_type), - Option('benchmark', - help='Run the benchmarks rather than selftests.', - hidden=True), - Option('lsprof-timed', - help='Generate lsprof output for benchmarked' - ' sections of code.'), - Option('lsprof-tests', - help='Generate lsprof output for each test.'), - Option('first', - help='Run all tests, but run specified tests first.', - short_name='f', - ), - Option('list-only', - help='List the tests instead of running them.'), - RegistryOption('parallel', - help="Run the test suite in parallel.", - lazy_registry=( - 'breezy.tests', 'parallel_registry'), - value_switches=False, - ), - Option('randomize', type=str, argname="SEED", - help='Randomize the order of tests using the given' - ' seed or "now" for the current time.'), - ListOption('exclude', type=str, argname="PATTERN", - short_name='x', - help='Exclude tests that match this regular' - ' expression.'), - Option('subunit1', - help='Output test progress via subunit v1.'), - Option('subunit2', - help='Output test progress via subunit v2.'), - Option('strict', help='Fail on missing dependencies or ' - 'known failures.'), - Option('load-list', type=str, argname='TESTLISTFILE', - help='Load a test id list from a text file.'), - ListOption('debugflag', type=str, short_name='E', - help='Turn on a selftest debug flag.'), - ListOption('starting-with', type=str, argname='TESTID', - param_name='starting_with', short_name='s', - help='Load only the tests starting with TESTID.'), - Option('sync', - help="By default we disable fsync and fdatasync" - " while running the test suite.") - ] - encoding_type = 'replace' + takes_args = ["testspecs*"] + takes_options = [ + "verbose", + Option( + "one", + help="Stop when one test fails.", + short_name="1", + ), + Option( + "transport", + help="Use a different transport by default " "throughout the test suite.", + type=get_transport_type, + ), + Option( + "benchmark", help="Run the benchmarks rather than selftests.", hidden=True + ), + Option( + "lsprof-timed", + help="Generate lsprof output for benchmarked" " sections of code.", + ), + Option("lsprof-tests", help="Generate lsprof output for each test."), + Option( + "first", + help="Run all tests, but run specified tests first.", + short_name="f", + ), + Option("list-only", help="List the tests instead of running them."), + RegistryOption( + "parallel", + help="Run the test suite in parallel.", + lazy_registry=("breezy.tests", "parallel_registry"), + value_switches=False, + ), + Option( + "randomize", + type=str, + argname="SEED", + help="Randomize the order of tests using the given" + ' seed or "now" for the current time.', + ), + ListOption( + "exclude", + type=str, + argname="PATTERN", + short_name="x", + help="Exclude tests that match this regular" " expression.", + ), + Option("subunit1", help="Output test progress via subunit v1."), + Option("subunit2", help="Output test progress via subunit v2."), + Option("strict", help="Fail on missing dependencies or " "known failures."), + Option( + "load-list", + type=str, + argname="TESTLISTFILE", + help="Load a test id list from a text file.", + ), + ListOption( + "debugflag", type=str, short_name="E", help="Turn on a selftest debug flag." 
+ ), + ListOption( + "starting-with", + type=str, + argname="TESTID", + param_name="starting_with", + short_name="s", + help="Load only the tests starting with TESTID.", + ), + Option( + "sync", + help="By default we disable fsync and fdatasync" + " while running the test suite.", + ), + ] + encoding_type = "replace" def __init__(self): Command.__init__(self) self.additional_selftest_args = {} - def run(self, testspecs_list=None, verbose=False, one=False, - transport=None, benchmark=None, - lsprof_timed=None, - first=False, list_only=False, - randomize=None, exclude=None, strict=False, - load_list=None, debugflag=None, starting_with=None, subunit1=False, - subunit2=False, parallel=None, lsprof_tests=False, sync=False): - + def run( + self, + testspecs_list=None, + verbose=False, + one=False, + transport=None, + benchmark=None, + lsprof_timed=None, + first=False, + list_only=False, + randomize=None, + exclude=None, + strict=False, + load_list=None, + debugflag=None, + starting_with=None, + subunit1=False, + subunit2=False, + parallel=None, + lsprof_tests=False, + sync=False, + ): # During selftest, disallow proxying, as it can cause severe # performance penalties and is only needed for thread # safety. The selftest command is assumed to not use threads @@ -4230,74 +4804,89 @@ def run(self, testspecs_list=None, verbose=False, one=False, try: from . import tests except ImportError as exc: - raise errors.CommandError("tests not available. Install the " - "breezy tests to run the breezy testsuite.") from exc + raise errors.CommandError( + "tests not available. Install the " + "breezy tests to run the breezy testsuite." + ) from exc if testspecs_list is not None: - pattern = '|'.join(testspecs_list) + pattern = "|".join(testspecs_list) else: pattern = ".*" if subunit1: try: from .tests import SubUnitBzrRunnerv1 except ImportError as exc: - raise errors.CommandError(gettext( - "subunit not available. subunit needs to be installed " - "to use --subunit.")) from exc - self.additional_selftest_args['runner_class'] = SubUnitBzrRunnerv1 + raise errors.CommandError( + gettext( + "subunit not available. subunit needs to be installed " + "to use --subunit." + ) + ) from exc + self.additional_selftest_args["runner_class"] = SubUnitBzrRunnerv1 # On Windows, disable automatic conversion of '\n' to '\r\n' in # stdout, which would corrupt the subunit stream. # FIXME: This has been fixed in subunit trunk (>0.0.5) so the # following code can be deleted when it's sufficiently deployed # -- vila/mgz 20100514 - if (sys.platform == "win32" - and getattr(sys.stdout, 'fileno', None) is not None): + if ( + sys.platform == "win32" + and getattr(sys.stdout, "fileno", None) is not None + ): import msvcrt + msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY) if subunit2: try: from .tests import SubUnitBzrRunnerv2 except ImportError as exc: - raise errors.CommandError(gettext( - "subunit not available. subunit " - "needs to be installed to use --subunit2.")) from exc - self.additional_selftest_args['runner_class'] = SubUnitBzrRunnerv2 + raise errors.CommandError( + gettext( + "subunit not available. subunit " + "needs to be installed to use --subunit2." 
+ ) + ) from exc + self.additional_selftest_args["runner_class"] = SubUnitBzrRunnerv2 if parallel: - self.additional_selftest_args.setdefault( - 'suite_decorators', []).append(parallel) + self.additional_selftest_args.setdefault("suite_decorators", []).append( + parallel + ) if benchmark: - raise errors.CommandError(gettext( - "--benchmark is no longer supported from brz 2.2; " - "use bzr-usertest instead")) + raise errors.CommandError( + gettext( + "--benchmark is no longer supported from brz 2.2; " + "use bzr-usertest instead" + ) + ) test_suite_factory = None if not exclude: exclude_pattern = None else: - exclude_pattern = '(' + '|'.join(exclude) + ')' + exclude_pattern = "(" + "|".join(exclude) + ")" if not sync: self._disable_fsync() - selftest_kwargs = {"verbose": verbose, - "pattern": pattern, - "stop_on_failure": one, - "transport": transport, - "test_suite_factory": test_suite_factory, - "lsprof_timed": lsprof_timed, - "lsprof_tests": lsprof_tests, - "matching_tests_first": first, - "list_only": list_only, - "random_seed": randomize, - "exclude_pattern": exclude_pattern, - "strict": strict, - "load_list": load_list, - "debug_flags": debugflag, - "starting_with": starting_with - } + selftest_kwargs = { + "verbose": verbose, + "pattern": pattern, + "stop_on_failure": one, + "transport": transport, + "test_suite_factory": test_suite_factory, + "lsprof_timed": lsprof_timed, + "lsprof_tests": lsprof_tests, + "matching_tests_first": first, + "list_only": list_only, + "random_seed": randomize, + "exclude_pattern": exclude_pattern, + "strict": strict, + "load_list": load_list, + "debug_flags": debugflag, + "starting_with": starting_with, + } selftest_kwargs.update(self.additional_selftest_args) # Make deprecation warnings visible, unless -Werror is set - cleanup = symbol_versioning.activate_deprecation_warnings( - override=False) + cleanup = symbol_versioning.activate_deprecation_warnings(override=False) try: result = tests.selftest(**selftest_kwargs) finally: @@ -4306,10 +4895,10 @@ def run(self, testspecs_list=None, verbose=False, one=False, def _disable_fsync(self): """Change the 'os' functionality to not synchronize.""" - self._orig_fsync = getattr(os, 'fsync', None) + self._orig_fsync = getattr(os, "fsync", None) if self._orig_fsync is not None: os.fsync = lambda filedes: None - self._orig_fdatasync = getattr(os, 'fdatasync', None) + self._orig_fdatasync = getattr(os, "fdatasync", None) if self._orig_fdatasync is not None: os.fdatasync = lambda filedes: None @@ -4317,16 +4906,17 @@ def _disable_fsync(self): class cmd_version(Command): __doc__ = """Show version of brz.""" - encoding_type = 'replace' + encoding_type = "replace" takes_options = [ Option("short", help="Print just the version number."), - ] + ] @display_command def run(self, short=False): from .version import show_version + if short: - self.outf.write(breezy.version_string + '\n') + self.outf.write(breezy.version_string + "\n") else: show_version(to_file=self.outf) @@ -4345,7 +4935,7 @@ class cmd_find_merge_base(Command): __doc__ = """Find and print a base revision for merging two branches.""" # TODO: Options to specify revisions on either side, as if # merging only part of the history. 
- takes_args = ['branch', 'other'] + takes_args = ["branch", "other"] hidden = True @display_command @@ -4360,8 +4950,9 @@ def run(self, branch, other): graph = branch1.repository.get_graph(branch2.repository) base_rev_id = graph.find_unique_lca(last1, last2) - self.outf.write(gettext('merge base is revision %s\n') % - base_rev_id.decode('utf-8')) + self.outf.write( + gettext("merge base is revision %s\n") % base_rev_id.decode("utf-8") + ) class cmd_merge(Command): @@ -4446,57 +5037,76 @@ class cmd_merge(Command): brz commit -m 'revision with three parents' """ - encoding_type = 'exact' - _see_also = ['update', 'remerge', 'status-flags', 'send'] - takes_args = ['location?'] + encoding_type = "exact" + _see_also = ["update", "remerge", "status-flags", "send"] + takes_args = ["location?"] takes_options = [ - 'change', - 'revision', - Option('force', - help='Merge even if the destination tree has uncommitted changes.'), - 'merge-type', - 'reprocess', - 'remember', - Option('show-base', help="Show base revision text in " - "conflicts."), - Option('uncommitted', help='Apply uncommitted changes' - ' from a working copy, instead of branch changes.'), - Option('pull', help='If the destination is already' - ' completely merged into the source, pull from the' - ' source rather than merging. When this happens,' - ' you do not need to commit the result.'), - custom_help('directory', - help='Branch to merge into, ' - 'rather than the one containing the working directory.'), - Option('preview', help='Instead of merging, show a diff of the' - ' merge.'), - Option('interactive', help='Select changes interactively.', - short_name='i') + "change", + "revision", + Option( + "force", help="Merge even if the destination tree has uncommitted changes." + ), + "merge-type", + "reprocess", + "remember", + Option("show-base", help="Show base revision text in " "conflicts."), + Option( + "uncommitted", + help="Apply uncommitted changes" + " from a working copy, instead of branch changes.", + ), + Option( + "pull", + help="If the destination is already" + " completely merged into the source, pull from the" + " source rather than merging. When this happens," + " you do not need to commit the result.", + ), + custom_help( + "directory", + help="Branch to merge into, " + "rather than the one containing the working directory.", + ), + Option("preview", help="Instead of merging, show a diff of the" " merge."), + Option("interactive", help="Select changes interactively.", short_name="i"), ] - def run(self, location=None, revision=None, force=False, - merge_type=None, show_base=False, reprocess=None, remember=None, - uncommitted=False, pull=False, - directory=None, - preview=False, - interactive=False, - ): + def run( + self, + location=None, + revision=None, + force=False, + merge_type=None, + show_base=False, + reprocess=None, + remember=None, + uncommitted=False, + pull=False, + directory=None, + preview=False, + interactive=False, + ): from . import mergeable as _mod_mergeable from .workingtree import WorkingTree + if merge_type is None: merge_type = _mod_merge.Merge3Merger if directory is None: - directory = '.' + directory = "." 
possible_transports = [] merger = None allow_pending = True - verified = 'inapplicable' + verified = "inapplicable" tree = WorkingTree.open_containing(directory)[0] if tree.branch.last_revision() == _mod_revision.NULL_REVISION: - raise errors.CommandError(gettext('Merging into empty branches not currently supported, ' - 'https://bugs.launchpad.net/bzr/+bug/308562')) + raise errors.CommandError( + gettext( + "Merging into empty branches not currently supported, " + "https://bugs.launchpad.net/bzr/+bug/308562" + ) + ) # die as quickly as possible if there are uncommitted changes if not force: @@ -4505,73 +5115,84 @@ def run(self, location=None, revision=None, force=False, view_info = _get_view_info_for_change_reporter(tree) change_reporter = delta._ChangeReporter( - unversioned_filter=tree.is_ignored, view_info=view_info) + unversioned_filter=tree.is_ignored, view_info=view_info + ) pb = ui.ui_factory.nested_progress_bar() self.enter_context(pb) self.enter_context(tree.lock_write()) if location is not None: try: mergeable = _mod_mergeable.read_mergeable_from_url( - location, possible_transports=possible_transports) + location, possible_transports=possible_transports + ) except errors.NotABundle: mergeable = None else: if uncommitted: - raise errors.CommandError(gettext('Cannot use --uncommitted' - ' with bundles or merge directives.')) + raise errors.CommandError( + gettext( + "Cannot use --uncommitted" + " with bundles or merge directives." + ) + ) if revision is not None: - raise errors.CommandError(gettext( - 'Cannot use -r with merge directives or bundles')) - merger, verified = _mod_merge.Merger.from_mergeable(tree, - mergeable) + raise errors.CommandError( + gettext("Cannot use -r with merge directives or bundles") + ) + merger, verified = _mod_merge.Merger.from_mergeable(tree, mergeable) if merger is None and uncommitted: if revision is not None and len(revision) > 0: - raise errors.CommandError(gettext('Cannot use --uncommitted and' - ' --revision at the same time.')) + raise errors.CommandError( + gettext( + "Cannot use --uncommitted and" " --revision at the same time." 
+ ) + ) merger = self.get_merger_from_uncommitted(tree, location, None) allow_pending = False if merger is None: - merger, allow_pending = self._get_merger_from_branch(tree, - location, revision, remember, possible_transports, None) + merger, allow_pending = self._get_merger_from_branch( + tree, location, revision, remember, possible_transports, None + ) merger.merge_type = merge_type merger.reprocess = reprocess merger.show_base = show_base self.sanity_check_merger(merger) - if (merger.base_rev_id == merger.other_rev_id and - merger.other_rev_id is not None): + if ( + merger.base_rev_id == merger.other_rev_id + and merger.other_rev_id is not None + ): # check if location is a nonexistent file (and not a branch) to # disambiguate the 'Nothing to do' if merger.interesting_files: - if not merger.other_tree.has_filename( - merger.interesting_files[0]): + if not merger.other_tree.has_filename(merger.interesting_files[0]): note(gettext("merger: ") + str(merger)) raise errors.PathsDoNotExist([location]) - note(gettext('Nothing to do.')) + note(gettext("Nothing to do.")) return 0 if pull and not preview: if merger.interesting_files is not None: - raise errors.CommandError( - gettext('Cannot pull individual files')) - if (merger.base_rev_id == tree.last_revision()): - result = tree.pull(merger.other_branch, False, - merger.other_rev_id) + raise errors.CommandError(gettext("Cannot pull individual files")) + if merger.base_rev_id == tree.last_revision(): + result = tree.pull(merger.other_branch, False, merger.other_rev_id) result.report(self.outf) return 0 if merger.this_basis is None: - raise errors.CommandError(gettext( - "This branch has no commits." - " (perhaps you would prefer 'brz pull')")) + raise errors.CommandError( + gettext( + "This branch has no commits." + " (perhaps you would prefer 'brz pull')" + ) + ) if preview: return self._do_preview(merger) elif interactive: return self._do_interactive(merger) else: - return self._do_merge(merger, change_reporter, allow_pending, - verified) + return self._do_merge(merger, change_reporter, allow_pending, verified) def _get_preview(self, merger): tree_merger = merger.make_merger() @@ -4582,19 +5203,25 @@ def _get_preview(self, merger): def _do_preview(self, merger): from .diff import show_diff_trees + result_tree = self._get_preview(merger) path_encoding = osutils.get_diff_header_encoding() - show_diff_trees(merger.this_tree, result_tree, self.outf, - old_label='', new_label='', - path_encoding=path_encoding) + show_diff_trees( + merger.this_tree, + result_tree, + self.outf, + old_label="", + new_label="", + path_encoding=path_encoding, + ) def _do_merge(self, merger, change_reporter, allow_pending, verified): merger.change_reporter = change_reporter conflict_count = len(merger.do_merge()) if allow_pending: merger.set_pending() - if verified == 'failed': - warning('Preview patch does not match changes') + if verified == "failed": + warning("Preview patch does not match changes") if conflict_count != 0: return 1 else: @@ -4608,21 +5235,27 @@ def _do_interactive(self, merger): and the preview tree. """ from . 
import shelf_ui + result_tree = self._get_preview(merger) writer = breezy.option.diff_writer_registry.get() - shelver = shelf_ui.Shelver(merger.this_tree, result_tree, destroy=True, - reporter=shelf_ui.ApplyReporter(), - diff_writer=writer(self.outf)) + shelver = shelf_ui.Shelver( + merger.this_tree, + result_tree, + destroy=True, + reporter=shelf_ui.ApplyReporter(), + diff_writer=writer(self.outf), + ) try: shelver.run() finally: shelver.finalize() def sanity_check_merger(self, merger): - if (merger.show_base and - merger.merge_type is not _mod_merge.Merge3Merger): - raise errors.CommandError(gettext("Show-base is not supported for this" - " merge type. %s") % merger.merge_type) + if merger.show_base and merger.merge_type is not _mod_merge.Merge3Merger: + raise errors.CommandError( + gettext("Show-base is not supported for this" " merge type. %s") + % merger.merge_type + ) if merger.reprocess is None: if merger.show_base: merger.reprocess = False @@ -4630,41 +5263,54 @@ def sanity_check_merger(self, merger): # Use reprocess if the merger supports it merger.reprocess = merger.merge_type.supports_reprocess if merger.reprocess and not merger.merge_type.supports_reprocess: - raise errors.CommandError(gettext("Conflict reduction is not supported" - " for merge type %s.") % - merger.merge_type) + raise errors.CommandError( + gettext("Conflict reduction is not supported" " for merge type %s.") + % merger.merge_type + ) if merger.reprocess and merger.show_base: - raise errors.CommandError(gettext("Cannot do conflict reduction and" - " show base.")) - - if (merger.merge_type.requires_file_merge_plan and - (not getattr(merger.this_tree, 'plan_file_merge', None) or - not getattr(merger.other_tree, 'plan_file_merge', None) or - (merger.base_tree is not None and - not getattr(merger.base_tree, 'plan_file_merge', None)))): raise errors.CommandError( - gettext('Plan file merge unsupported: ' - 'Merge type incompatible with tree formats.')) + gettext("Cannot do conflict reduction and" " show base.") + ) - def _get_merger_from_branch(self, tree, location, revision, remember, - possible_transports, pb): + if merger.merge_type.requires_file_merge_plan and ( + not getattr(merger.this_tree, "plan_file_merge", None) + or not getattr(merger.other_tree, "plan_file_merge", None) + or ( + merger.base_tree is not None + and not getattr(merger.base_tree, "plan_file_merge", None) + ) + ): + raise errors.CommandError( + gettext( + "Plan file merge unsupported: " + "Merge type incompatible with tree formats." 
+ ) + ) + + def _get_merger_from_branch( + self, tree, location, revision, remember, possible_transports, pb + ): """Produce a merger from a location, assuming it refers to a branch.""" # find the branch locations - other_loc, user_location = self._select_branch_location(tree, location, - revision, -1) + other_loc, user_location = self._select_branch_location( + tree, location, revision, -1 + ) if revision is not None and len(revision) == 2: - base_loc, _unused = self._select_branch_location(tree, - location, revision, 0) + base_loc, _unused = self._select_branch_location( + tree, location, revision, 0 + ) else: base_loc = other_loc # Open the branches - other_branch, other_path = Branch.open_containing(other_loc, - possible_transports) + other_branch, other_path = Branch.open_containing( + other_loc, possible_transports + ) if base_loc == other_loc: base_branch = other_branch else: - base_branch, base_path = Branch.open_containing(base_loc, - possible_transports) + base_branch, base_path = Branch.open_containing( + base_loc, possible_transports + ) # Find the revision ids other_revision_id = None base_revision_id = None @@ -4680,17 +5326,17 @@ def _get_merger_from_branch(self, tree, location, revision, remember, # branch) # - user ask to remember or there is no previous location set to merge # from and user didn't ask to *not* remember - if (user_location is not None - and (remember or - (remember is None and - tree.branch.get_submit_branch() is None))): + if user_location is not None and ( + remember or (remember is None and tree.branch.get_submit_branch() is None) + ): tree.branch.set_submit_branch(other_branch.base) # Merge tags (but don't set them in the master branch yet, the user # might revert this merge). Commit will propagate them. other_branch.tags.merge_to(tree.branch.tags, ignore_master=True) - merger = _mod_merge.Merger.from_revision_ids(tree, - other_revision_id, base_revision_id, other_branch, base_branch) - if other_path != '': + merger = _mod_merge.Merger.from_revision_ids( + tree, other_revision_id, base_revision_id, other_branch, base_branch + ) + if other_path != "": allow_pending = False merger.interesting_files = [other_path] else: @@ -4705,15 +5351,15 @@ def get_merger_from_uncommitted(self, tree, location, pb): :param pb: The progress bar to use for showing progress. """ from .workingtree import WorkingTree + location = self._select_branch_location(tree, location)[0] other_tree, other_path = WorkingTree.open_containing(location) merger = _mod_merge.Merger.from_uncommitted(tree, other_tree, pb) - if other_path != '': + if other_path != "": merger.interesting_files = [other_path] return merger - def _select_branch_location(self, tree, user_location, revision=None, - index=None): + def _select_branch_location(self, tree, user_location, revision=None, index=None): """Select a branch location, according to possible inputs. If provided, branches from ``revision`` are preferred. (Both @@ -4730,13 +5376,12 @@ def _select_branch_location(self, tree, user_location, revision=None, :return: (selected_location, user_location). The default location will be the user-entered location. 
""" - if (revision is not None and index is not None - and revision[index] is not None): + if revision is not None and index is not None and revision[index] is not None: branch = revision[index].get_branch() if branch is not None: return branch, branch if user_location is None: - location = self._get_remembered(tree, 'Merging from') + location = self._get_remembered(tree, "Merging from") else: location = user_location return location, user_location @@ -4753,11 +5398,13 @@ def _get_remembered(self, tree, verb_string): stored_location_type = "parent" mutter("%s", stored_location) if stored_location is None: - raise errors.CommandError( - gettext("No location specified or remembered")) - display_url = urlutils.unescape_for_display(stored_location, 'utf-8') - note(gettext("{0} remembered {1} location {2}").format(verb_string, - stored_location_type, display_url)) + raise errors.CommandError(gettext("No location specified or remembered")) + display_url = urlutils.unescape_for_display(stored_location, "utf-8") + note( + gettext("{0} remembered {1} location {2}").format( + verb_string, stored_location_type, display_url + ) + ) return stored_location @@ -4783,18 +5430,17 @@ class cmd_remerge(Command): brz remerge --merge-type weave --reprocess foobar """ - takes_args = ['file*'] + takes_args = ["file*"] takes_options = [ - 'merge-type', - 'reprocess', - Option('show-base', - help="Show base revision text in conflicts."), - ] - - def run(self, file_list=None, merge_type=None, show_base=False, - reprocess=False): + "merge-type", + "reprocess", + Option("show-base", help="Show base revision text in conflicts."), + ] + + def run(self, file_list=None, merge_type=None, show_base=False, reprocess=False): from .conflicts import restore from .workingtree import WorkingTree + if merge_type is None: merge_type = _mod_merge.Merge3Merger tree, file_list = WorkingTree.open_containing_paths(file_list) @@ -4802,8 +5448,11 @@ def run(self, file_list=None, merge_type=None, show_base=False, parents = tree.get_parent_ids() if len(parents) != 2: raise errors.CommandError( - gettext("Sorry, remerge only works after normal" - " merges. Not cherrypicking or multi-merges.")) + gettext( + "Sorry, remerge only works after normal" + " merges. Not cherrypicking or multi-merges." + ) + ) interesting_files = None new_conflicts = [] conflicts = tree.conflicts() @@ -4816,15 +5465,15 @@ def run(self, file_list=None, merge_type=None, show_base=False, if tree.kind(filename) != "directory": continue - for path, _ie in tree.iter_entries_by_dir( - specific_files=[filename]): + for path, _ie in tree.iter_entries_by_dir(specific_files=[filename]): interesting_files.add(path) new_conflicts = conflicts.select_conflicts(tree, file_list)[0] else: # Remerge only supports resolving contents conflicts - allowed_conflicts = ('text conflict', 'contents conflict') - restore_files = [c.path for c in conflicts - if c.typestring in allowed_conflicts] + allowed_conflicts = ("text conflict", "contents conflict") + restore_files = [ + c.path for c in conflicts if c.typestring in allowed_conflicts + ] _mod_merge.transform_tree(tree, tree.basis_tree(), interesting_files) tree.set_conflicts(new_conflicts) if file_list is not None: @@ -4909,18 +5558,19 @@ class cmd_revert(Command): target branches. 
""" - _see_also = ['cat', 'export', 'merge', 'shelve'] + _see_also = ["cat", "export", "merge", "shelve"] takes_options = [ - 'revision', - Option('no-backup', "Do not save backups of reverted files."), - Option('forget-merges', - 'Remove pending merge marker, without changing any files.'), - ] - takes_args = ['file*'] - - def run(self, revision=None, no_backup=False, file_list=None, - forget_merges=None): + "revision", + Option("no-backup", "Do not save backups of reverted files."), + Option( + "forget-merges", "Remove pending merge marker, without changing any files." + ), + ] + takes_args = ["file*"] + + def run(self, revision=None, no_backup=False, file_list=None, forget_merges=None): from .workingtree import WorkingTree + tree, file_list = WorkingTree.open_containing_paths(file_list) self.enter_context(tree.lock_tree_write()) if forget_merges: @@ -4930,9 +5580,8 @@ def run(self, revision=None, no_backup=False, file_list=None, @staticmethod def _revert_tree_to_revision(tree, revision, file_list, no_backup): - rev_tree = _get_one_revision_tree('revert', revision, tree=tree) - tree.revert(file_list, rev_tree, not no_backup, None, - report_changes=True) + rev_tree = _get_one_revision_tree("revert", revision, tree=tree) + tree.revert(file_list, rev_tree, not no_backup, None, report_changes=True) class cmd_assert_fail(Command): @@ -4949,16 +5598,17 @@ class cmd_help(Command): __doc__ = """Show help on a command or other topic. """ - _see_also = ['topics'] + _see_also = ["topics"] takes_options = [ - Option('long', 'Show help on all commands.'), - ] - takes_args = ['topic?'] - aliases = ['?', '--help', '-?', '-h'] + Option("long", "Show help on all commands."), + ] + takes_args = ["topic?"] + aliases = ["?", "--help", "-?", "-h"] @display_command def run(self, topic=None, long=False): import breezy.help + if topic is None and long: topic = "commands" breezy.help.help(topic) @@ -4969,13 +5619,14 @@ class cmd_shell_complete(Command): For a list of all available commands, say 'brz shell-complete'. """ - takes_args = ['context?'] - aliases = ['s-c'] + takes_args = ["context?"] + aliases = ["s-c"] hidden = True @display_command def run(self, context=None): from . import shellcomplete + shellcomplete.shellcomplete(context) @@ -5014,41 +5665,58 @@ class cmd_missing(Command): brz missing --my-revision ..-10 """ - _see_also = ['merge', 'pull'] - takes_args = ['other_branch?'] + _see_also = ["merge", "pull"] + takes_args = ["other_branch?"] takes_options = [ - 'directory', - Option('reverse', 'Reverse the order of revisions.'), - Option('mine-only', - 'Display changes in the local branch only.'), - Option('this', 'Same as --mine-only.'), - Option('theirs-only', - 'Display changes in the remote branch only.'), - Option('other', 'Same as --theirs-only.'), - 'log-format', - 'show-ids', - 'verbose', - custom_help('revision', - help='Filter on other branch revisions (inclusive). ' - 'See "help revisionspec" for details.'), - Option('my-revision', - type=_parse_revision_str, - help='Filter on local branch revisions (inclusive). 
' - 'See "help revisionspec" for details.'), - Option('include-merged', - 'Show all revisions in addition to the mainline ones.'), - Option('include-merges', hidden=True, - help='Historical alias for --include-merged.'), - ] - encoding_type = 'replace' + "directory", + Option("reverse", "Reverse the order of revisions."), + Option("mine-only", "Display changes in the local branch only."), + Option("this", "Same as --mine-only."), + Option("theirs-only", "Display changes in the remote branch only."), + Option("other", "Same as --theirs-only."), + "log-format", + "show-ids", + "verbose", + custom_help( + "revision", + help="Filter on other branch revisions (inclusive). " + 'See "help revisionspec" for details.', + ), + Option( + "my-revision", + type=_parse_revision_str, + help="Filter on local branch revisions (inclusive). " + 'See "help revisionspec" for details.', + ), + Option( + "include-merged", "Show all revisions in addition to the mainline ones." + ), + Option( + "include-merges", hidden=True, help="Historical alias for --include-merged." + ), + ] + encoding_type = "replace" @display_command - def run(self, other_branch=None, reverse=False, mine_only=False, - theirs_only=False, - log_format=None, long=False, short=False, line=False, - show_ids=False, verbose=False, this=False, other=False, - include_merged=None, revision=None, my_revision=None, - directory='.'): + def run( + self, + other_branch=None, + reverse=False, + mine_only=False, + theirs_only=False, + log_format=None, + long=False, + short=False, + line=False, + show_ids=False, + verbose=False, + this=False, + other=False, + include_merged=None, + revision=None, + my_revision=None, + directory=".", + ): from .missing import find_unmerged, iter_log_revisions def message(s): @@ -5064,11 +5732,11 @@ def message(s): # TODO: We should probably check that we don't have mine-only and # theirs-only set, but it gets complicated because we also have # this and other which could be used. 
- restrict = 'all' + restrict = "all" if mine_only: - restrict = 'local' + restrict = "local" elif theirs_only: - restrict = 'remote' + restrict = "remote" local_branch = Branch.open_containing(directory)[0] self.enter_context(local_branch.lock_read()) @@ -5077,12 +5745,11 @@ def message(s): if other_branch is None: other_branch = parent if other_branch is None: - raise errors.CommandError(gettext("No peer location known" - " or specified.")) - display_url = urlutils.unescape_for_display(parent, - self.outf.encoding) - message(gettext("Using saved parent location: {0}\n").format( - display_url)) + raise errors.CommandError( + gettext("No peer location known" " or specified.") + ) + display_url = urlutils.unescape_for_display(parent, self.outf.encoding) + message(gettext("Using saved parent location: {0}\n").format(display_url)) remote_branch = Branch.open(other_branch) if remote_branch.base == local_branch.base: @@ -5091,40 +5758,44 @@ def message(s): self.enter_context(remote_branch.lock_read()) local_revid_range = _revision_range_to_revid_range( - _get_revision_range(my_revision, local_branch, - self.name())) + _get_revision_range(my_revision, local_branch, self.name()) + ) remote_revid_range = _revision_range_to_revid_range( - _get_revision_range(revision, - remote_branch, self.name())) + _get_revision_range(revision, remote_branch, self.name()) + ) local_extra, remote_extra = find_unmerged( - local_branch, remote_branch, restrict, + local_branch, + remote_branch, + restrict, backward=not reverse, include_merged=include_merged, local_revid_range=local_revid_range, - remote_revid_range=remote_revid_range) + remote_revid_range=remote_revid_range, + ) if log_format is None: registry = log.log_formatter_registry log_format = registry.get_default(local_branch) - lf = log_format(to_file=self.outf, - show_ids=show_ids, - show_timezone='original') + lf = log_format(to_file=self.outf, show_ids=show_ids, show_timezone="original") status_code = 0 if local_extra and not theirs_only: - message(ngettext("You have %d extra revision:\n", - "You have %d extra revisions:\n", - len(local_extra)) % - len(local_extra)) + message( + ngettext( + "You have %d extra revision:\n", + "You have %d extra revisions:\n", + len(local_extra), + ) + % len(local_extra) + ) rev_tag_dict = {} if local_branch.supports_tags(): rev_tag_dict = local_branch.tags.get_reverse_tag_dict() - for revision in iter_log_revisions(local_extra, - local_branch.repository, - verbose, - rev_tag_dict): + for revision in iter_log_revisions( + local_extra, local_branch.repository, verbose, rev_tag_dict + ): lf.log_revision(revision) printed_local = True status_code = 1 @@ -5134,27 +5805,29 @@ def message(s): if remote_extra and not mine_only: if printed_local is True: message("\n\n\n") - message(ngettext("You are missing %d revision:\n", - "You are missing %d revisions:\n", - len(remote_extra)) % - len(remote_extra)) + message( + ngettext( + "You are missing %d revision:\n", + "You are missing %d revisions:\n", + len(remote_extra), + ) + % len(remote_extra) + ) if remote_branch.supports_tags(): rev_tag_dict = remote_branch.tags.get_reverse_tag_dict() - for revision in iter_log_revisions(remote_extra, - remote_branch.repository, - verbose, - rev_tag_dict): + for revision in iter_log_revisions( + remote_extra, remote_branch.repository, verbose, rev_tag_dict + ): lf.log_revision(revision) status_code = 1 if mine_only and not local_extra: # We checked local, and found nothing extra - message(gettext('This branch has no new revisions.\n')) + 
message(gettext("This branch has no new revisions.\n")) elif theirs_only and not remote_extra: # We checked remote, and found nothing extra - message(gettext('Other branch has no new revisions.\n')) - elif not (mine_only or theirs_only or local_extra or - remote_extra): + message(gettext("Other branch has no new revisions.\n")) + elif not (mine_only or theirs_only or local_extra or remote_extra): # We checked both branches, and neither one had extra # revisions message(gettext("Branches are up to date.\n")) @@ -5186,14 +5859,13 @@ class cmd_pack(Command): been. In this case the repository may be unusable. """ - _see_also = ['repositories'] - takes_args = ['branch_or_repo?'] + _see_also = ["repositories"] + takes_args = ["branch_or_repo?"] takes_options = [ - Option('clean-obsolete-packs', - 'Delete obsolete packs to save disk space.'), - ] + Option("clean-obsolete-packs", "Delete obsolete packs to save disk space."), + ] - def run(self, branch_or_repo='.', clean_obsolete_packs=False): + def run(self, branch_or_repo=".", clean_obsolete_packs=False): dir = controldir.ControlDir.open_containing(branch_or_repo)[0] try: branch = dir.open_branch() @@ -5222,35 +5894,35 @@ class cmd_plugins(Command): install them. Instructions are also provided there on how to write new plugins using the Python programming language. """ - takes_options = ['verbose'] + takes_options = ["verbose"] @display_command def run(self, verbose=False): from . import plugin # Don't give writelines a generator as some codecs don't like that - self.outf.writelines( - list(plugin.describe_plugins(show_paths=verbose))) + self.outf.writelines(list(plugin.describe_plugins(show_paths=verbose))) class cmd_testament(Command): __doc__ = """Show testament (signing-form) of a revision.""" takes_options = [ - 'revision', - Option('long', help='Produce long-format testament.'), - Option('strict', - help='Produce a strict-format testament.')] - takes_args = ['branch?'] - encoding_type = 'exact' + "revision", + Option("long", help="Produce long-format testament."), + Option("strict", help="Produce a strict-format testament."), + ] + takes_args = ["branch?"] + encoding_type = "exact" @display_command - def run(self, branch='.', revision=None, long=False, strict=False): + def run(self, branch=".", revision=None, long=False, strict=False): from .bzr.testament import StrictTestament, Testament + if strict is True: testament_class = StrictTestament else: testament_class = Testament - if branch == '.': + if branch == ".": b = Branch.open_containing(branch)[0] else: b = Branch.open(branch) @@ -5278,40 +5950,50 @@ class cmd_annotate(Command): # TODO: annotate directories; showing when each file was last changed # TODO: if the working copy is modified, show annotations on that # with new uncommitted lines marked - aliases = ['ann', 'blame', 'praise'] - takes_args = ['filename'] - takes_options = [Option('all', help='Show annotations on all lines.'), - Option('long', help='Show commit date in annotations.'), - 'revision', - 'show-ids', - 'directory', - ] - encoding_type = 'exact' + aliases = ["ann", "blame", "praise"] + takes_args = ["filename"] + takes_options = [ + Option("all", help="Show annotations on all lines."), + Option("long", help="Show commit date in annotations."), + "revision", + "show-ids", + "directory", + ] + encoding_type = "exact" @display_command - def run(self, filename, all=False, long=False, revision=None, - show_ids=False, directory=None): + def run( + self, + filename, + all=False, + long=False, + revision=None, + 
show_ids=False, + directory=None, + ): from .annotate import annotate_file_tree - wt, branch, relpath = \ - _open_directory_or_containing_tree_or_branch(filename, directory) + + wt, branch, relpath = _open_directory_or_containing_tree_or_branch( + filename, directory + ) if wt is not None: self.enter_context(wt.lock_read()) else: self.enter_context(branch.lock_read()) - tree = _get_one_revision_tree('annotate', revision, branch=branch) + tree = _get_one_revision_tree("annotate", revision, branch=branch) self.enter_context(tree.lock_read()) if wt is not None and revision is None: if not wt.is_versioned(relpath): raise errors.NotVersionedError(relpath) # If there is a tree and we're not annotating historical # versions, annotate the working tree's content. - annotate_file_tree(wt, relpath, self.outf, long, all, - show_ids=show_ids) + annotate_file_tree(wt, relpath, self.outf, long, all, show_ids=show_ids) else: if not tree.is_versioned(relpath): raise errors.NotVersionedError(relpath) - annotate_file_tree(tree, relpath, self.outf, long, all, - show_ids=show_ids, branch=branch) + annotate_file_tree( + tree, relpath, self.outf, long, all, show_ids=show_ids, branch=branch + ) class cmd_re_sign(Command): @@ -5319,28 +6001,32 @@ class cmd_re_sign(Command): # TODO be able to replace existing ones. hidden = True # is this right ? - takes_args = ['revision_id*'] - takes_options = ['directory', 'revision'] + takes_args = ["revision_id*"] + takes_options = ["directory", "revision"] - def run(self, revision_id_list=None, revision=None, directory='.'): + def run(self, revision_id_list=None, revision=None, directory="."): from .workingtree import WorkingTree + if revision_id_list is not None and revision is not None: raise errors.CommandError( - gettext('You can only supply one of revision_id or --revision')) + gettext("You can only supply one of revision_id or --revision") + ) if revision_id_list is None and revision is None: raise errors.CommandError( - gettext('You must supply either --revision or a revision_id')) + gettext("You must supply either --revision or a revision_id") + ) b = WorkingTree.open_containing(directory)[0].branch self.enter_context(b.lock_write()) return self._run(b, revision_id_list, revision) def _run(self, b, revision_id_list, revision): from .repository import WriteGroup + gpg_strategy = gpg.GPGStrategy(b.get_config_stack()) if revision_id_list is not None: with WriteGroup(b.repository): for revision_id in revision_id_list: - revision_id = revision_id.encode('utf-8') + revision_id = revision_id.encode("utf-8") b.repository.sign_revision(revision_id, gpg_strategy) elif revision is not None: if len(revision) == 1: @@ -5357,14 +6043,15 @@ def _run(self, b, revision_id_list, revision): to_revno = b.revno() if from_revno is None or to_revno is None: raise errors.CommandError( - gettext('Cannot sign a range of non-revision-history revisions')) + gettext("Cannot sign a range of non-revision-history revisions") + ) with WriteGroup(b.repository): for revno in range(from_revno, to_revno + 1): - b.repository.sign_revision(b.get_rev_id(revno), - gpg_strategy) + b.repository.sign_revision(b.get_rev_id(revno), gpg_strategy) else: raise errors.CommandError( - gettext('Please supply either one revision, or a range.')) + gettext("Please supply either one revision, or a range.") + ) class cmd_bind(Command): @@ -5379,35 +6066,41 @@ class cmd_bind(Command): that of the master. 
""" - _see_also = ['checkouts', 'unbind'] - takes_args = ['location?'] - takes_options = ['directory'] + _see_also = ["checkouts", "unbind"] + takes_args = ["location?"] + takes_options = ["directory"] - def run(self, location=None, directory='.'): + def run(self, location=None, directory="."): b, relpath = Branch.open_containing(directory) if location is None: try: location = b.get_old_bound_location() except errors.UpgradeRequired as exc: raise errors.CommandError( - gettext('No location supplied. ' - 'This format does not remember old locations.')) from exc + gettext( + "No location supplied. " + "This format does not remember old locations." + ) + ) from exc else: if location is None: if b.get_bound_location() is not None: - raise errors.CommandError( - gettext('Branch is already bound')) + raise errors.CommandError(gettext("Branch is already bound")) else: raise errors.CommandError( - gettext('No location supplied' - ' and no previous location known')) + gettext( + "No location supplied" " and no previous location known" + ) + ) b_other = Branch.open(location) try: b.bind(b_other) except errors.DivergedBranches as exc: raise errors.CommandError( - gettext('These branches have diverged.' - ' Try merging, and then bind again.')) from exc + gettext( + "These branches have diverged." " Try merging, and then bind again." + ) + ) from exc if b.get_config().has_explicit_nickname(): b.nick = b_other.nick @@ -5419,13 +6112,13 @@ class cmd_unbind(Command): commits will be local only. """ - _see_also = ['checkouts', 'bind'] - takes_options = ['directory'] + _see_also = ["checkouts", "bind"] + takes_options = ["directory"] - def run(self, directory='.'): + def run(self, directory="."): b, relpath = Branch.open_containing(directory) if not b.unbind(): - raise errors.CommandError(gettext('Local branch is not bound')) + raise errors.CommandError(gettext("Local branch is not bound")) class cmd_uncommit(Command): @@ -5448,24 +6141,33 @@ class cmd_uncommit(Command): # unreferenced information in 'branch-as-repository' branches. # TODO: jam 20060108 Add the ability for uncommit to remove unreferenced # information in shared branches as well. - _see_also = ['commit'] - takes_options = ['verbose', 'revision', - Option('dry-run', help='Don\'t actually make changes.'), - Option('force', help='Say yes to all questions.'), - Option('keep-tags', - help='Keep tags that point to removed revisions.'), - Option('local', - help="Only remove the commits from the local " - "branch when in a checkout." - ), - ] - takes_args = ['location?'] - encoding_type = 'replace' - - def run(self, location=None, dry_run=False, verbose=False, - revision=None, force=False, local=False, keep_tags=False): + _see_also = ["commit"] + takes_options = [ + "verbose", + "revision", + Option("dry-run", help="Don't actually make changes."), + Option("force", help="Say yes to all questions."), + Option("keep-tags", help="Keep tags that point to removed revisions."), + Option( + "local", + help="Only remove the commits from the local " "branch when in a checkout.", + ), + ] + takes_args = ["location?"] + encoding_type = "replace" + + def run( + self, + location=None, + dry_run=False, + verbose=False, + revision=None, + force=False, + local=False, + keep_tags=False, + ): if location is None: - location = '.' + location = "." 
control, relpath = controldir.ControlDir.open_containing(location) try: tree = control.open_workingtree() @@ -5478,11 +6180,13 @@ def run(self, location=None, dry_run=False, verbose=False, self.enter_context(tree.lock_write()) else: self.enter_context(b.lock_write()) - return self._run(b, tree, dry_run, verbose, revision, force, - local, keep_tags, location) + return self._run( + b, tree, dry_run, verbose, revision, force, local, keep_tags, location + ) - def _run(self, b, tree, dry_run, verbose, revision, force, local, - keep_tags, location): + def _run( + self, b, tree, dry_run, verbose, revision, force, local, keep_tags, location + ): from .log import log_formatter, show_log from .uncommit import uncommit @@ -5503,49 +6207,60 @@ def _run(self, b, tree, dry_run, verbose, revision, force, local, rev_id = b.get_rev_id(revno) if rev_id is None or _mod_revision.is_null(rev_id): - self.outf.write(gettext('No revisions to uncommit.\n')) + self.outf.write(gettext("No revisions to uncommit.\n")) return 1 - lf = log_formatter('short', - to_file=self.outf, - show_timezone='original') + lf = log_formatter("short", to_file=self.outf, show_timezone="original") - show_log(b, - lf, - verbose=False, - direction='forward', - start_revision=revno, - end_revision=last_revno) + show_log( + b, + lf, + verbose=False, + direction="forward", + start_revision=revno, + end_revision=last_revno, + ) if dry_run: - self.outf.write(gettext('Dry-run, pretending to remove' - ' the above revisions.\n')) - else: self.outf.write( - gettext('The above revision(s) will be removed.\n')) + gettext("Dry-run, pretending to remove" " the above revisions.\n") + ) + else: + self.outf.write(gettext("The above revision(s) will be removed.\n")) if not force: if not ui.ui_factory.confirm_action( - gettext('Uncommit these revisions'), - 'breezy.builtins.uncommit', - {}): - self.outf.write(gettext('Canceled\n')) + gettext("Uncommit these revisions"), "breezy.builtins.uncommit", {} + ): + self.outf.write(gettext("Canceled\n")) return 0 - mutter('Uncommitting from {%s} to {%s}', - last_rev_id, rev_id) - uncommit(b, tree=tree, dry_run=dry_run, verbose=verbose, - revno=revno, local=local, keep_tags=keep_tags) - if location != '.': + mutter("Uncommitting from {%s} to {%s}", last_rev_id, rev_id) + uncommit( + b, + tree=tree, + dry_run=dry_run, + verbose=verbose, + revno=revno, + local=local, + keep_tags=keep_tags, + ) + if location != ".": self.outf.write( - gettext('You can restore the old tip by running:\n' - ' brz pull -d %s %s -r revid:%s\n') - % (location, location, last_rev_id.decode('utf-8'))) + gettext( + "You can restore the old tip by running:\n" + " brz pull -d %s %s -r revid:%s\n" + ) + % (location, location, last_rev_id.decode("utf-8")) + ) else: self.outf.write( - gettext('You can restore the old tip by running:\n' - ' brz pull . -r revid:%s\n') - % last_rev_id.decode('utf-8')) + gettext( + "You can restore the old tip by running:\n" + " brz pull . 
-r revid:%s\n" + ) + % last_rev_id.decode("utf-8") + ) class cmd_break_lock(Command): @@ -5566,21 +6281,19 @@ class cmd_break_lock(Command): brz break-lock --conf ~/.config/breezy """ - takes_args = ['location?'] + takes_args = ["location?"] takes_options = [ - Option('config', - help='LOCATION is the directory where the config lock is.'), - Option('force', - help='Do not ask for confirmation before breaking the lock.'), - ] + Option("config", help="LOCATION is the directory where the config lock is."), + Option("force", help="Do not ask for confirmation before breaking the lock."), + ] def run(self, location=None, config=False, force=False): if location is None: - location = '.' + location = "." if force: - ui.ui_factory = ui.ConfirmationUserInterfacePolicy(ui.ui_factory, - None, - {'breezy.lockdir.break': True}) + ui.ui_factory = ui.ConfirmationUserInterfacePolicy( + ui.ui_factory, None, {"breezy.lockdir.break": True} + ) if config: conf = _mod_config.LockableConfig(file_name=location) conf.break_lock() @@ -5609,49 +6322,61 @@ def run(self): class cmd_serve(Command): __doc__ = """Run the brz server.""" - aliases = ['server'] + aliases = ["server"] takes_options = [ - Option('inet', - help='Serve on stdin/out for use from inetd or sshd.'), - RegistryOption('protocol', - help="Protocol to serve.", - lazy_registry=('breezy.transport', - 'transport_server_registry'), - value_switches=True), - Option('listen', - help='Listen for connections on nominated address.', - type=str), - Option('port', - help='Listen for connections on nominated port. Passing 0 as ' - 'the port number will result in a dynamically allocated ' - 'port. The default port depends on the protocol.', - type=int), - custom_help('directory', - help='Serve contents of this directory.'), - Option('allow-writes', - help='By default the server is a readonly server. Supplying ' - '--allow-writes enables write access to the contents of ' - 'the served directory and below. Note that ``brz serve`` ' - 'does not perform authentication, so unless some form of ' - 'external authentication is arranged supplying this ' - 'option leads to global uncontrolled write access to your ' - 'file system.' - ), - Option('client-timeout', type=float, - help='Override the default idle client timeout (5min).'), - ] - - def run(self, listen=None, port=None, inet=False, directory=None, - allow_writes=False, protocol=None, client_timeout=None): + Option("inet", help="Serve on stdin/out for use from inetd or sshd."), + RegistryOption( + "protocol", + help="Protocol to serve.", + lazy_registry=("breezy.transport", "transport_server_registry"), + value_switches=True, + ), + Option("listen", help="Listen for connections on nominated address.", type=str), + Option( + "port", + help="Listen for connections on nominated port. Passing 0 as " + "the port number will result in a dynamically allocated " + "port. The default port depends on the protocol.", + type=int, + ), + custom_help("directory", help="Serve contents of this directory."), + Option( + "allow-writes", + help="By default the server is a readonly server. Supplying " + "--allow-writes enables write access to the contents of " + "the served directory and below. 
Note that ``brz serve`` " + "does not perform authentication, so unless some form of " + "external authentication is arranged supplying this " + "option leads to global uncontrolled write access to your " + "file system.", + ), + Option( + "client-timeout", + type=float, + help="Override the default idle client timeout (5min).", + ), + ] + + def run( + self, + listen=None, + port=None, + inet=False, + directory=None, + allow_writes=False, + protocol=None, + client_timeout=None, + ): from . import location, transport + if directory is None: directory = osutils.getcwd() if protocol is None: protocol = transport.transport_server_registry.get() url = location.location_to_url(directory) if not allow_writes: - url = 'readonly+' + url + url = "readonly+" + url t = transport.get_transport_from_url(url) protocol(t, listen, port, inet, client_timeout) @@ -5670,24 +6395,28 @@ class cmd_join(Command): and all history is preserved. """ - _see_also = ['split'] - takes_args = ['tree'] + _see_also = ["split"] + takes_args = ["tree"] takes_options = [ - Option('reference', help='Join by reference.', hidden=True), - ] + Option("reference", help="Join by reference.", hidden=True), + ] def run(self, tree, reference=False): from .mutabletree import BadReferenceTarget from .workingtree import WorkingTree + sub_tree = WorkingTree.open(tree) parent_dir = osutils.dirname(sub_tree.basedir) containing_tree = WorkingTree.open_containing(parent_dir)[0] repo = containing_tree.branch.repository if not repo.supports_rich_root(): - raise errors.CommandError(gettext( - "Can't join trees because %s doesn't support rich root data.\n" - "You can use brz upgrade on the repository.") - % (repo,)) + raise errors.CommandError( + gettext( + "Can't join trees because %s doesn't support rich root data.\n" + "You can use brz upgrade on the repository." + ) + % (repo,) + ) if reference: try: containing_tree.add_reference(sub_tree) @@ -5695,13 +6424,15 @@ def run(self, tree, reference=False): # XXX: Would be better to just raise a nicely printable # exception from the real origin. Also below. mbp 20070306 raise errors.CommandError( - gettext("Cannot join {0}. {1}").format(tree, exc.reason)) from exc + gettext("Cannot join {0}. {1}").format(tree, exc.reason) + ) from exc else: try: containing_tree.subsume(sub_tree) except errors.BadSubsumeSource as exc: raise errors.CommandError( - gettext("Cannot join {0}. {1}").format(tree, exc.reason)) from exc + gettext("Cannot join {0}. {1}").format(tree, exc.reason) + ) from exc class cmd_split(Command): @@ -5716,11 +6447,12 @@ class cmd_split(Command): branch. Commits in the top-level tree will not apply to the new subtree. """ - _see_also = ['join'] - takes_args = ['tree'] + _see_also = ["join"] + takes_args = ["tree"] def run(self, tree): from .workingtree import WorkingTree + containing_tree, subdir = WorkingTree.open_containing(tree) if not containing_tree.is_versioned(subdir): raise errors.NotVersionedError(subdir) @@ -5747,43 +6479,60 @@ class cmd_merge_directive(Command): after the first use. 
""" - takes_args = ['submit_branch?', 'public_branch?'] + takes_args = ["submit_branch?", "public_branch?"] hidden = True - _see_also = ['send'] + _see_also = ["send"] takes_options = [ - 'directory', + "directory", RegistryOption.from_kwargs( - 'patch-type', - 'The type of patch to include in the directive.', - title='Patch type', + "patch-type", + "The type of patch to include in the directive.", + title="Patch type", value_switches=True, enum_switch=False, - bundle='Bazaar revision bundle (default).', - diff='Normal unified diff.', - plain='No patch, just directive.'), - Option('sign', help='GPG-sign the directive.'), 'revision', - Option('mail-to', type=str, - help='Instead of printing the directive, email to this ' - 'address.'), - Option('message', type=str, short_name='m', - help='Message to use when committing this merge.') - ] - - encoding_type = 'exact' - - def run(self, submit_branch=None, public_branch=None, patch_type='bundle', - sign=False, revision=None, mail_to=None, message=None, - directory='.'): + bundle="Bazaar revision bundle (default).", + diff="Normal unified diff.", + plain="No patch, just directive.", + ), + Option("sign", help="GPG-sign the directive."), + "revision", + Option( + "mail-to", + type=str, + help="Instead of printing the directive, email to this " "address.", + ), + Option( + "message", + type=str, + short_name="m", + help="Message to use when committing this merge.", + ), + ] + + encoding_type = "exact" + + def run( + self, + submit_branch=None, + public_branch=None, + patch_type="bundle", + sign=False, + revision=None, + mail_to=None, + message=None, + directory=".", + ): from . import merge_directive from .revision import NULL_REVISION + include_patch, include_bundle = { - 'plain': (False, False), - 'diff': (True, False), - 'bundle': (True, True), - }[patch_type] + "plain": (False, False), + "diff": (True, False), + "bundle": (True, True), + }[patch_type] branch = Branch.open(directory) stored_submit_branch = branch.get_submit_branch() if submit_branch is None: @@ -5794,8 +6543,7 @@ def run(self, submit_branch=None, public_branch=None, patch_type='bundle', if submit_branch is None: submit_branch = branch.get_parent() if submit_branch is None: - raise errors.CommandError( - gettext('No submit branch specified or known')) + raise errors.CommandError(gettext("No submit branch specified or known")) stored_public_branch = branch.get_public_branch() if public_branch is None: @@ -5804,28 +6552,35 @@ def run(self, submit_branch=None, public_branch=None, patch_type='bundle', # FIXME: Should be done only if we succeed ? 
-- vila 2012-01-03 branch.set_public_branch(public_branch) if not include_bundle and public_branch is None: - raise errors.CommandError( - gettext('No public branch specified or known')) + raise errors.CommandError(gettext("No public branch specified or known")) base_revision_id = None if revision is not None: if len(revision) > 2: raise errors.CommandError( - gettext('brz merge-directive takes ' - 'at most two one revision identifiers')) + gettext( + "brz merge-directive takes " + "at most two one revision identifiers" + ) + ) revision_id = revision[-1].as_revision_id(branch) if len(revision) == 2: base_revision_id = revision[0].as_revision_id(branch) else: revision_id = branch.last_revision() if revision_id == NULL_REVISION: - raise errors.CommandError(gettext('No revisions to bundle.')) + raise errors.CommandError(gettext("No revisions to bundle.")) directive = merge_directive.MergeDirective2.from_objects( - repository=branch.repository, revision_id=revision_id, - time=time.time(), timezone=osutils.local_time_offset(), + repository=branch.repository, + revision_id=revision_id, + time=time.time(), + timezone=osutils.local_time_offset(), target_branch=submit_branch, - public_branch=public_branch, include_patch=include_patch, - include_bundle=include_bundle, message=message, - base_revision_id=base_revision_id) + public_branch=public_branch, + include_patch=include_patch, + include_bundle=include_bundle, + message=message, + base_revision_id=base_revision_id, + ) if mail_to is None: if sign: self.outf.write(directive.to_signed(branch)) @@ -5833,6 +6588,7 @@ def run(self, submit_branch=None, public_branch=None, patch_type='bundle', self.outf.writelines(directive.to_lines()) else: from .smtp_connection import SMTPConnection + message = directive.to_email(mail_to, branch, sign) s = SMTPConnection(branch.get_config_stack()) s.send_email(message) @@ -5907,51 +6663,82 @@ class cmd_send(Command): set them, and use `brz info` to display them. """ - encoding_type = 'exact' + encoding_type = "exact" - _see_also = ['merge', 'pull'] + _see_also = ["merge", "pull"] - takes_args = ['submit_branch?', 'public_branch?'] + takes_args = ["submit_branch?", "public_branch?"] takes_options = [ - Option('no-bundle', - help='Do not include a bundle in the merge directive.'), - Option('no-patch', help='Do not include a preview patch in the merge' - ' directive.'), - Option('remember', - help='Remember submit and public branch.'), - Option('from', - help='Branch to generate the submission from, ' - 'rather than the one containing the working directory.', - short_name='f', - type=str), - Option('output', short_name='o', - help='Write merge directive to this file or directory; ' - 'use - for stdout.', - type=str), - Option('strict', - help='Refuse to send if there are uncommitted changes in' - ' the working tree, --no-strict disables the check.'), - Option('mail-to', help='Mail the request to this address.', - type=str), - 'revision', - 'message', - Option('body', help='Body for the email.', type=str), - RegistryOption('format', - help='Use the specified output format.', - lazy_registry=('breezy.send', 'format_registry')), - ] - - def run(self, submit_branch=None, public_branch=None, no_bundle=False, - no_patch=False, revision=None, remember=None, output=None, - format=None, mail_to=None, message=None, body=None, - strict=None, **kwargs): + Option("no-bundle", help="Do not include a bundle in the merge directive."), + Option( + "no-patch", help="Do not include a preview patch in the merge" " directive." 
+ ), + Option("remember", help="Remember submit and public branch."), + Option( + "from", + help="Branch to generate the submission from, " + "rather than the one containing the working directory.", + short_name="f", + type=str, + ), + Option( + "output", + short_name="o", + help="Write merge directive to this file or directory; " + "use - for stdout.", + type=str, + ), + Option( + "strict", + help="Refuse to send if there are uncommitted changes in" + " the working tree, --no-strict disables the check.", + ), + Option("mail-to", help="Mail the request to this address.", type=str), + "revision", + "message", + Option("body", help="Body for the email.", type=str), + RegistryOption( + "format", + help="Use the specified output format.", + lazy_registry=("breezy.send", "format_registry"), + ), + ] + + def run( + self, + submit_branch=None, + public_branch=None, + no_bundle=False, + no_patch=False, + revision=None, + remember=None, + output=None, + format=None, + mail_to=None, + message=None, + body=None, + strict=None, + **kwargs, + ): from .send import send - return send(submit_branch, revision, public_branch, remember, - format, no_bundle, no_patch, output, - kwargs.get('from', '.'), mail_to, message, body, - self.outf, - strict=strict) + + return send( + submit_branch, + revision, + public_branch, + remember, + format, + no_bundle, + no_patch, + output, + kwargs.get("from", "."), + mail_to, + message, + body, + self.outf, + strict=strict, + ) class cmd_bundle_revisions(cmd_send): @@ -5982,43 +6769,72 @@ class cmd_bundle_revisions(cmd_send): """ takes_options = [ - Option('no-bundle', - help='Do not include a bundle in the merge directive.'), - Option('no-patch', help='Do not include a preview patch in the merge' - ' directive.'), - Option('remember', - help='Remember submit and public branch.'), - Option('from', - help='Branch to generate the submission from, ' - 'rather than the one containing the working directory.', - short_name='f', - type=str), - Option('output', short_name='o', help='Write directive to this file.', - type=str), - Option('strict', - help='Refuse to bundle revisions if there are uncommitted' - ' changes in the working tree, --no-strict disables the check.'), - 'revision', - RegistryOption('format', - help='Use the specified output format.', - lazy_registry=('breezy.send', 'format_registry')), - ] - aliases = ['bundle'] - - _see_also = ['send', 'merge'] + Option("no-bundle", help="Do not include a bundle in the merge directive."), + Option( + "no-patch", help="Do not include a preview patch in the merge" " directive." 
+ ), + Option("remember", help="Remember submit and public branch."), + Option( + "from", + help="Branch to generate the submission from, " + "rather than the one containing the working directory.", + short_name="f", + type=str, + ), + Option( + "output", short_name="o", help="Write directive to this file.", type=str + ), + Option( + "strict", + help="Refuse to bundle revisions if there are uncommitted" + " changes in the working tree, --no-strict disables the check.", + ), + "revision", + RegistryOption( + "format", + help="Use the specified output format.", + lazy_registry=("breezy.send", "format_registry"), + ), + ] + aliases = ["bundle"] + + _see_also = ["send", "merge"] hidden = True - def run(self, submit_branch=None, public_branch=None, no_bundle=False, - no_patch=False, revision=None, remember=False, output=None, - format=None, strict=None, **kwargs): + def run( + self, + submit_branch=None, + public_branch=None, + no_bundle=False, + no_patch=False, + revision=None, + remember=False, + output=None, + format=None, + strict=None, + **kwargs, + ): if output is None: - output = '-' + output = "-" from .send import send - return send(submit_branch, revision, public_branch, remember, - format, no_bundle, no_patch, output, - kwargs.get('from', '.'), None, None, None, - self.outf, strict=strict) + + return send( + submit_branch, + revision, + public_branch, + remember, + format, + no_bundle, + no_patch, + output, + kwargs.get("from", "."), + None, + None, + None, + self.outf, + strict=strict, + ) class cmd_tag(Command): @@ -6043,48 +6859,52 @@ class cmd_tag(Command): details. """ - _see_also = ['commit', 'tags'] - takes_args = ['tag_name?'] + _see_also = ["commit", "tags"] + takes_args = ["tag_name?"] takes_options = [ - Option('delete', - help='Delete this tag rather than placing it.', - ), - custom_help('directory', - help='Branch in which to place the tag.'), - Option('force', - help='Replace existing tags.', - ), - 'revision', - ] - - def run(self, tag_name=None, - delete=None, - directory='.', - force=None, - revision=None, - ): + Option( + "delete", + help="Delete this tag rather than placing it.", + ), + custom_help("directory", help="Branch in which to place the tag."), + Option( + "force", + help="Replace existing tags.", + ), + "revision", + ] + + def run( + self, + tag_name=None, + delete=None, + directory=".", + force=None, + revision=None, + ): branch, relpath = Branch.open_containing(directory) self.enter_context(branch.lock_write()) if delete: if tag_name is None: - raise errors.CommandError( - gettext("No tag specified to delete.")) + raise errors.CommandError(gettext("No tag specified to delete.")) branch.tags.delete_tag(tag_name) - note(gettext('Deleted tag %s.') % tag_name) + note(gettext("Deleted tag %s.") % tag_name) else: if revision: if len(revision) != 1: - raise errors.CommandError(gettext( - "Tags can only be placed on a single revision, " - "not on a range")) + raise errors.CommandError( + gettext( + "Tags can only be placed on a single revision, " + "not on a range" + ) + ) revision_id = revision[0].as_revision_id(branch) else: revision_id = branch.last_revision() if tag_name is None: tag_name = branch.automatic_tag_name(revision_id) if tag_name is None: - raise errors.CommandError(gettext( - "Please specify a tag name.")) + raise errors.CommandError(gettext("Please specify a tag name.")) try: existing_target = branch.tags.lookup_tag(tag_name) except errors.NoSuchTag: @@ -6092,13 +6912,13 @@ def run(self, tag_name=None, if not force and existing_target not in 
(None, revision_id): raise errors.TagAlreadyExists(tag_name) if existing_target == revision_id: - note(gettext('Tag %s already exists for that revision.') % tag_name) + note(gettext("Tag %s already exists for that revision.") % tag_name) else: branch.tags.set_tag(tag_name, revision_id) if existing_target is None: - note(gettext('Created tag %s.') % tag_name) + note(gettext("Created tag %s.") % tag_name) else: - note(gettext('Updated tag %s.') % tag_name) + note(gettext("Updated tag %s.") % tag_name) class cmd_tags(Command): @@ -6107,21 +6927,23 @@ class cmd_tags(Command): This command shows a table of tag names and the revisions they reference. """ - _see_also = ['tag'] + _see_also = ["tag"] takes_options = [ - custom_help('directory', - help='Branch whose tags should be displayed.'), - RegistryOption('sort', - 'Sort tags by different criteria.', title='Sorting', - lazy_registry=('breezy.tag', 'tag_sort_methods') - ), - 'show-ids', - 'revision', + custom_help("directory", help="Branch whose tags should be displayed."), + RegistryOption( + "sort", + "Sort tags by different criteria.", + title="Sorting", + lazy_registry=("breezy.tag", "tag_sort_methods"), + ), + "show-ids", + "revision", ] @display_command - def run(self, directory='.', sort=None, show_ids=False, revision=None): + def run(self, directory=".", sort=None, show_ids=False, revision=None): from .tag import tag_sort_methods + branch, relpath = Branch.open_containing(directory) tags = list(branch.tags.get_tag_dict().items()) @@ -6141,19 +6963,21 @@ def run(self, directory='.', sort=None, show_ids=False, revision=None): try: revno = branch.revision_id_to_dotted_revno(revid) if isinstance(revno, tuple): - revno = '.'.join(map(str, revno)) - except (errors.NoSuchRevision, - errors.GhostRevisionsHaveNoRevno, - errors.UnsupportedOperation): + revno = ".".join(map(str, revno)) + except ( + errors.NoSuchRevision, + errors.GhostRevisionsHaveNoRevno, + errors.UnsupportedOperation, + ): # Bad tag data/merges can lead to tagged revisions # which are not in this branch. Fail gracefully ... - revno = '?' + revno = "?" tags[index] = (tag, revno) else: - tags = [(tag, revid.decode('utf-8')) for (tag, revid) in tags] + tags = [(tag, revid.decode("utf-8")) for (tag, revid) in tags] self.cleanup_now() for tag, revspec in tags: - self.outf.write('%-20s %s\n' % (tag, revspec)) + self.outf.write("%-20s %s\n" % (tag, revspec)) def _tags_for_range(self, branch, revision): rev1, rev2 = _get_revision_range(revision, branch, self.name()) @@ -6173,8 +6997,8 @@ def _tags_for_range(self, branch, revision): tagged_revids = branch.tags.get_reverse_tag_dict() found = [] for r in branch.iter_merge_sorted_revisions( - start_revision_id=revid2, stop_revision_id=revid1, - stop_rule='include'): + start_revision_id=revid2, stop_revision_id=revid1, stop_rule="include" + ): revid_tags = tagged_revids.get(r[0], None) if revid_tags: found.extend([(tag, r[0]) for tag in revid_tags]) @@ -6195,61 +7019,77 @@ class cmd_reconfigure(Command): If none of these is available, --bind-to must be specified. 
""" - _see_also = ['branches', 'checkouts', 'standalone-trees', 'working-trees'] - takes_args = ['location?'] + _see_also = ["branches", "checkouts", "standalone-trees", "working-trees"] + takes_args = ["location?"] takes_options = [ RegistryOption.from_kwargs( - 'tree_type', - title='Tree type', - help='The relation between branch and tree.', - value_switches=True, enum_switch=False, - branch='Reconfigure to be an unbound branch with no working tree.', - tree='Reconfigure to be an unbound branch with a working tree.', - checkout='Reconfigure to be a bound branch with a working tree.', - lightweight_checkout='Reconfigure to be a lightweight' - ' checkout (with no local history).', - ), + "tree_type", + title="Tree type", + help="The relation between branch and tree.", + value_switches=True, + enum_switch=False, + branch="Reconfigure to be an unbound branch with no working tree.", + tree="Reconfigure to be an unbound branch with a working tree.", + checkout="Reconfigure to be a bound branch with a working tree.", + lightweight_checkout="Reconfigure to be a lightweight" + " checkout (with no local history).", + ), RegistryOption.from_kwargs( - 'repository_type', - title='Repository type', - help='Location fo the repository.', - value_switches=True, enum_switch=False, - standalone='Reconfigure to be a standalone branch ' - '(i.e. stop using shared repository).', - use_shared='Reconfigure to use a shared repository.', - ), + "repository_type", + title="Repository type", + help="Location fo the repository.", + value_switches=True, + enum_switch=False, + standalone="Reconfigure to be a standalone branch " + "(i.e. stop using shared repository).", + use_shared="Reconfigure to use a shared repository.", + ), RegistryOption.from_kwargs( - 'repository_trees', - title='Trees in Repository', - help='Whether new branches in the repository have trees.', - value_switches=True, enum_switch=False, - with_trees='Reconfigure repository to create ' - 'working trees on branches by default.', - with_no_trees='Reconfigure repository to not create ' - 'working trees on branches by default.' - ), - Option('bind-to', help='Branch to bind checkout to.', type=str), - Option('force', - help='Perform reconfiguration even if local changes' - ' will be lost.'), - Option('stacked-on', - help='Reconfigure a branch to be stacked on another branch.', - type=str, - ), - Option('unstacked', - help='Reconfigure a branch to be unstacked. This ' - 'may require copying substantial data into it.', - ), - ] - - def run(self, location=None, bind_to=None, force=False, - tree_type=None, repository_type=None, repository_trees=None, - stacked_on=None, unstacked=None): + "repository_trees", + title="Trees in Repository", + help="Whether new branches in the repository have trees.", + value_switches=True, + enum_switch=False, + with_trees="Reconfigure repository to create " + "working trees on branches by default.", + with_no_trees="Reconfigure repository to not create " + "working trees on branches by default.", + ), + Option("bind-to", help="Branch to bind checkout to.", type=str), + Option( + "force", + help="Perform reconfiguration even if local changes" " will be lost.", + ), + Option( + "stacked-on", + help="Reconfigure a branch to be stacked on another branch.", + type=str, + ), + Option( + "unstacked", + help="Reconfigure a branch to be unstacked. 
This " + "may require copying substantial data into it.", + ), + ] + + def run( + self, + location=None, + bind_to=None, + force=False, + tree_type=None, + repository_type=None, + repository_trees=None, + stacked_on=None, + unstacked=None, + ): from . import reconfigure + directory = controldir.ControlDir.open(location) if stacked_on and unstacked: raise errors.CommandError( - gettext("Can't use both --stacked-on and --unstacked")) + gettext("Can't use both --stacked-on and --unstacked") + ) elif stacked_on is not None: reconfigure.ReconfigureStackedOn().apply(directory, stacked_on) elif unstacked: @@ -6257,41 +7097,42 @@ def run(self, location=None, bind_to=None, force=False, # At the moment you can use --stacked-on and a different # reconfiguration shape at the same time; there seems no good reason # to ban it. - if (tree_type is None and - repository_type is None and - repository_trees is None): + if tree_type is None and repository_type is None and repository_trees is None: if stacked_on or unstacked: return else: - raise errors.CommandError(gettext('No target configuration ' - 'specified')) + raise errors.CommandError( + gettext("No target configuration " "specified") + ) reconfiguration = None - if tree_type == 'branch': + if tree_type == "branch": reconfiguration = reconfigure.Reconfigure.to_branch(directory) - elif tree_type == 'tree': + elif tree_type == "tree": reconfiguration = reconfigure.Reconfigure.to_tree(directory) - elif tree_type == 'checkout': - reconfiguration = reconfigure.Reconfigure.to_checkout( - directory, bind_to) - elif tree_type == 'lightweight-checkout': + elif tree_type == "checkout": + reconfiguration = reconfigure.Reconfigure.to_checkout(directory, bind_to) + elif tree_type == "lightweight-checkout": reconfiguration = reconfigure.Reconfigure.to_lightweight_checkout( - directory, bind_to) + directory, bind_to + ) if reconfiguration: reconfiguration.apply(force) reconfiguration = None - if repository_type == 'use-shared': + if repository_type == "use-shared": reconfiguration = reconfigure.Reconfigure.to_use_shared(directory) - elif repository_type == 'standalone': + elif repository_type == "standalone": reconfiguration = reconfigure.Reconfigure.to_standalone(directory) if reconfiguration: reconfiguration.apply(force) reconfiguration = None - if repository_trees == 'with-trees': + if repository_trees == "with-trees": reconfiguration = reconfigure.Reconfigure.set_repository_trees( - directory, True) - elif repository_trees == 'with-no-trees': + directory, True + ) + elif repository_trees == "with-no-trees": reconfiguration = reconfigure.Reconfigure.set_repository_trees( - directory, False) + directory, False + ) if reconfiguration: reconfiguration.apply(force) reconfiguration = None @@ -6320,34 +7161,42 @@ class cmd_switch(Command): that of the master. 
""" - takes_args = ['to_location?'] - takes_options = ['directory', - Option('force', - help='Switch even if local commits will be lost.'), - 'revision', - Option('create-branch', short_name='b', - help='Create the target branch from this one before' - ' switching to it.'), - Option('store', - help='Store and restore uncommitted changes in the' - ' branch.'), - ] - - def run(self, to_location=None, force=False, create_branch=False, - revision=None, directory='.', store=False): + takes_args = ["to_location?"] + takes_options = [ + "directory", + Option("force", help="Switch even if local commits will be lost."), + "revision", + Option( + "create-branch", + short_name="b", + help="Create the target branch from this one before" " switching to it.", + ), + Option("store", help="Store and restore uncommitted changes in the" " branch."), + ] + + def run( + self, + to_location=None, + force=False, + create_branch=False, + revision=None, + directory=".", + store=False, + ): from . import switch + tree_location = directory - revision = _get_one_revision('switch', revision) + revision = _get_one_revision("switch", revision) control_dir = controldir.ControlDir.open_containing(tree_location)[0] possible_transports = [control_dir.root_transport] if to_location is None: if revision is None: - raise errors.CommandError(gettext('You must supply either a' - ' revision or a location')) + raise errors.CommandError( + gettext("You must supply either a" " revision or a location") + ) to_location = tree_location try: - branch = control_dir.open_branch( - possible_transports=possible_transports) + branch = control_dir.open_branch(possible_transports=possible_transports) had_explicit_nick = branch.get_config().has_explicit_nickname() except errors.NotBranchError: branch = None @@ -6357,49 +7206,65 @@ def run(self, to_location=None, force=False, create_branch=False, if create_branch: if branch is None: raise errors.CommandError( - gettext('cannot create branch without source branch')) + gettext("cannot create branch without source branch") + ) to_location = lookup_new_sibling_branch( - control_dir, to_location, - possible_transports=possible_transports) + control_dir, to_location, possible_transports=possible_transports + ) if revision is not None: revision = revision.as_revision_id(branch) to_branch = branch.controldir.sprout( to_location, possible_transports=possible_transports, revision_id=revision, - source_branch=branch).open_branch() + source_branch=branch, + ).open_branch() else: try: to_branch = Branch.open( - to_location, possible_transports=possible_transports) + to_location, possible_transports=possible_transports + ) except errors.NotBranchError: to_branch = open_sibling_branch( - control_dir, to_location, - possible_transports=possible_transports) + control_dir, to_location, possible_transports=possible_transports + ) if revision is not None: revision = revision.as_revision_id(to_branch) possible_transports.append(to_branch.user_transport) try: - switch.switch(control_dir, to_branch, force, revision_id=revision, - store_uncommitted=store, - possible_transports=possible_transports) + switch.switch( + control_dir, + to_branch, + force, + revision_id=revision, + store_uncommitted=store, + possible_transports=possible_transports, + ) except controldir.BranchReferenceLoop as exc: raise errors.CommandError( - gettext('switching would create a branch reference loop. ' - 'Use the "bzr up" command to switch to a ' - 'different revision.')) from exc + gettext( + "switching would create a branch reference loop. 
" + 'Use the "bzr up" command to switch to a ' + "different revision." + ) + ) from exc if had_explicit_nick: branch = control_dir.open_branch() # get the new branch! branch.nick = to_branch.nick if to_branch.name: if to_branch.controldir.control_url != control_dir.control_url: - note(gettext('Switched to branch %s at %s'), - to_branch.name, urlutils.unescape_for_display(to_branch.base, 'utf-8')) + note( + gettext("Switched to branch %s at %s"), + to_branch.name, + urlutils.unescape_for_display(to_branch.base, "utf-8"), + ) else: - note(gettext('Switched to branch %s'), to_branch.name) + note(gettext("Switched to branch %s"), to_branch.name) else: - note(gettext('Switched to branch at %s'), - urlutils.unescape_for_display(to_branch.base, 'utf-8')) + note( + gettext("Switched to branch at %s"), + urlutils.unescape_for_display(to_branch.base, "utf-8"), + ) class cmd_view(Command): @@ -6464,104 +7329,109 @@ class cmd_view(Command): brz view --delete --all """ - takes_args = ['file*'] + takes_args = ["file*"] takes_options = [ - Option('all', - help='Apply list or delete action to all views.', - ), - Option('delete', - help='Delete the view.', - ), - Option('name', - help='Name of the view to define, list or delete.', - type=str, - ), - Option('switch', - help='Name of the view to switch to.', - type=str, - ), - ] - - def run(self, file_list, - all=False, - delete=False, - name=None, - switch=None, - ): + Option( + "all", + help="Apply list or delete action to all views.", + ), + Option( + "delete", + help="Delete the view.", + ), + Option( + "name", + help="Name of the view to define, list or delete.", + type=str, + ), + Option( + "switch", + help="Name of the view to switch to.", + type=str, + ), + ] + + def run( + self, + file_list, + all=False, + delete=False, + name=None, + switch=None, + ): from . 
import views from .workingtree import WorkingTree - tree, file_list = WorkingTree.open_containing_paths(file_list, - apply_view=False) + + tree, file_list = WorkingTree.open_containing_paths(file_list, apply_view=False) current_view, view_dict = tree.views.get_view_info() if name is None: name = current_view if delete: if file_list: - raise errors.CommandError(gettext( - "Both --delete and a file list specified")) + raise errors.CommandError( + gettext("Both --delete and a file list specified") + ) elif switch: - raise errors.CommandError(gettext( - "Both --delete and --switch specified")) + raise errors.CommandError( + gettext("Both --delete and --switch specified") + ) elif all: tree.views.set_view_info(None, {}) self.outf.write(gettext("Deleted all views.\n")) elif name is None: - raise errors.CommandError( - gettext("No current view to delete")) + raise errors.CommandError(gettext("No current view to delete")) else: tree.views.delete_view(name) self.outf.write(gettext("Deleted '%s' view.\n") % name) elif switch: if file_list: - raise errors.CommandError(gettext( - "Both --switch and a file list specified")) + raise errors.CommandError( + gettext("Both --switch and a file list specified") + ) elif all: - raise errors.CommandError(gettext( - "Both --switch and --all specified")) - elif switch == 'off': + raise errors.CommandError(gettext("Both --switch and --all specified")) + elif switch == "off": if current_view is None: - raise errors.CommandError( - gettext("No current view to disable")) + raise errors.CommandError(gettext("No current view to disable")) tree.views.set_view_info(None, view_dict) - self.outf.write(gettext("Disabled '%s' view.\n") % - (current_view)) + self.outf.write(gettext("Disabled '%s' view.\n") % (current_view)) else: tree.views.set_view_info(switch, view_dict) view_str = views.view_display_str(tree.views.lookup_view()) self.outf.write( - gettext("Using '{0}' view: {1}\n").format(switch, view_str)) + gettext("Using '{0}' view: {1}\n").format(switch, view_str) + ) elif all: if view_dict: - self.outf.write(gettext('Views defined:\n')) + self.outf.write(gettext("Views defined:\n")) for view in sorted(view_dict): if view == current_view: active = "=>" else: active = " " view_str = views.view_display_str(view_dict[view]) - self.outf.write('%s %-20s %s\n' % (active, view, view_str)) + self.outf.write("%s %-20s %s\n" % (active, view, view_str)) else: - self.outf.write(gettext('No views defined.\n')) + self.outf.write(gettext("No views defined.\n")) elif file_list: if name is None: # No name given and no current view set - name = 'my' - elif name == 'off': - raise errors.CommandError(gettext( - "Cannot change the 'off' pseudo view")) + name = "my" + elif name == "off": + raise errors.CommandError( + gettext("Cannot change the 'off' pseudo view") + ) tree.views.set_view(name, sorted(file_list)) view_str = views.view_display_str(tree.views.lookup_view()) - self.outf.write( - gettext("Using '{0}' view: {1}\n").format(name, view_str)) + self.outf.write(gettext("Using '{0}' view: {1}\n").format(name, view_str)) else: # list the files if name is None: # No name given and no current view set - self.outf.write(gettext('No current view.\n')) + self.outf.write(gettext("No current view.\n")) else: view_str = views.view_display_str(tree.views.lookup_view(name)) - self.outf.write( - gettext("'{0}' view is: {1}\n").format(name, view_str)) + self.outf.write(gettext("'{0}' view is: {1}\n").format(name, view_str)) class cmd_hooks(Command): @@ -6599,8 +7469,10 @@ class 
cmd_remove_branch(Command): takes_args = ["location?"] - takes_options = ['directory', - Option('force', help='Remove branch even if it is the active branch.')] + takes_options = [ + "directory", + Option("force", help="Remove branch even if it is the active branch."), + ] aliases = ["rmbranch"] @@ -6611,10 +7483,13 @@ def run(self, directory=None, location=None, force=False): active_branch = br.controldir.open_branch(name="") except errors.NotBranchError: active_branch = None - if (active_branch is not None and - br.control_url == active_branch.control_url): + if ( + active_branch is not None + and br.control_url == active_branch.control_url + ): raise errors.CommandError( - gettext("Branch is active. Use --force to remove it.")) + gettext("Branch is active. Use --force to remove it.") + ) br.controldir.destroy_branch(br.name) @@ -6655,33 +7530,52 @@ class cmd_shelve(Command): """ - takes_args = ['file*'] + takes_args = ["file*"] takes_options = [ - 'directory', - 'revision', - Option('all', help='Shelve all changes.'), - 'message', - RegistryOption('writer', 'Method to use for writing diffs.', - breezy.option.diff_writer_registry, - value_switches=True, enum_switch=False), - - Option('list', help='List shelved changes.'), - Option('destroy', - help='Destroy removed changes instead of shelving them.'), + "directory", + "revision", + Option("all", help="Shelve all changes."), + "message", + RegistryOption( + "writer", + "Method to use for writing diffs.", + breezy.option.diff_writer_registry, + value_switches=True, + enum_switch=False, + ), + Option("list", help="List shelved changes."), + Option("destroy", help="Destroy removed changes instead of shelving them."), ] - _see_also = ['unshelve', 'configuration'] - - def run(self, revision=None, all=False, file_list=None, message=None, - writer=None, list=False, destroy=False, directory=None): + _see_also = ["unshelve", "configuration"] + + def run( + self, + revision=None, + all=False, + file_list=None, + message=None, + writer=None, + list=False, + destroy=False, + directory=None, + ): if list: return self.run_for_list(directory=directory) from .shelf_ui import Shelver + if writer is None: writer = breezy.option.diff_writer_registry.get() try: - shelver = Shelver.from_args(writer(self.outf), revision, all, - file_list, message, destroy=destroy, directory=directory) + shelver = Shelver.from_args( + writer(self.outf), + revision, + all, + file_list, + message, + destroy=destroy, + directory=directory, + ) try: shelver.run() finally: @@ -6691,20 +7585,21 @@ def run(self, revision=None, all=False, file_list=None, message=None, def run_for_list(self, directory=None): from .workingtree import WorkingTree + if directory is None: - directory = '.' + directory = "." tree = WorkingTree.open_containing(directory)[0] self.enter_context(tree.lock_read()) manager = tree.get_shelf_manager() shelves = manager.active_shelves() if len(shelves) == 0: - note(gettext('No shelved changes.')) + note(gettext("No shelved changes.")) return 0 for shelf_id in reversed(shelves): - message = manager.get_metadata(shelf_id).get(b'message') + message = manager.get_metadata(shelf_id).get(b"message") if message is None: - message = '' - self.outf.write('%3d: %s\n' % (shelf_id, message)) + message = "" + self.outf.write("%3d: %s\n" % (shelf_id, message)) return 1 @@ -6716,24 +7611,27 @@ class cmd_unshelve(Command): best when the changes don't depend on each other. 
""" - takes_args = ['shelf_id?'] + takes_args = ["shelf_id?"] takes_options = [ - 'directory', + "directory", RegistryOption.from_kwargs( - 'action', help="The action to perform.", - enum_switch=False, value_switches=True, + "action", + help="The action to perform.", + enum_switch=False, + value_switches=True, apply="Apply changes and remove from the shelf.", dry_run="Show changes, but do not apply or remove them.", preview="Instead of unshelving the changes, show the diff that " - "would result from unshelving.", + "would result from unshelving.", delete_only="Delete changes without applying them.", keep="Apply changes but don't delete them.", - ) + ), ] - _see_also = ['shelve'] + _see_also = ["shelve"] - def run(self, shelf_id=None, action='apply', directory='.'): + def run(self, shelf_id=None, action="apply", directory="."): from .shelf_ui import Unshelver + unshelver = Unshelver.from_args(shelf_id, action, directory=directory) try: unshelver.run() @@ -6756,25 +7654,42 @@ class cmd_clean_tree(Command): To check what clean-tree will do, use --dry-run. """ - takes_options = ['directory', - Option('ignored', help='Delete all ignored files.'), - Option('detritus', help='Delete conflict files, merge and revert' - ' backups, and failed selftest dirs.'), - Option('unknown', - help='Delete files unknown to brz (default).'), - Option('dry-run', help='Show files to delete instead of' - ' deleting them.'), - Option('force', help='Do not prompt before deleting.')] - - def run(self, unknown=False, ignored=False, detritus=False, dry_run=False, - force=False, directory='.'): + takes_options = [ + "directory", + Option("ignored", help="Delete all ignored files."), + Option( + "detritus", + help="Delete conflict files, merge and revert" + " backups, and failed selftest dirs.", + ), + Option("unknown", help="Delete files unknown to brz (default)."), + Option("dry-run", help="Show files to delete instead of" " deleting them."), + Option("force", help="Do not prompt before deleting."), + ] + + def run( + self, + unknown=False, + ignored=False, + detritus=False, + dry_run=False, + force=False, + directory=".", + ): from .clean_tree import clean_tree + if not (unknown or ignored or detritus): unknown = True if dry_run: force = True - clean_tree(directory, unknown=unknown, ignored=ignored, - detritus=detritus, dry_run=dry_run, no_prompt=force) + clean_tree( + directory, + unknown=unknown, + ignored=ignored, + detritus=detritus, + dry_run=dry_run, + no_prompt=force, + ) class cmd_reference(Command): @@ -6787,23 +7702,26 @@ class cmd_reference(Command): hidden = True - takes_args = ['path?', 'location?'] + takes_args = ["path?", "location?"] takes_options = [ - 'directory', - Option('force-unversioned', - help='Set reference even if path is not versioned.'), - ] - - def run(self, path=None, directory='.', location=None, force_unversioned=False): - tree, branch, relpath = ( - controldir.ControlDir.open_containing_tree_or_branch(directory)) + "directory", + Option( + "force-unversioned", help="Set reference even if path is not versioned." 
+ ), + ] + + def run(self, path=None, directory=".", location=None, force_unversioned=False): + tree, branch, relpath = controldir.ControlDir.open_containing_tree_or_branch( + directory + ) if tree is None: tree = branch.basis_tree() if path is None: with tree.lock_read(): info = [ (path, tree.get_reference_info(path, branch)) - for path in tree.iter_references()] + for path in tree.iter_references() + ] self._display_reference_info(tree, branch, info) else: if not tree.is_versioned(path) and not force_unversioned: @@ -6819,24 +7737,30 @@ def _display_reference_info(self, tree, branch, info): for path, location in info: ref_list.append((path, location)) for path, location in sorted(ref_list): - self.outf.write(f'{path} {location}\n') + self.outf.write(f"{path} {location}\n") class cmd_export_pot(Command): __doc__ = """Export command helps and error messages in po format.""" hidden = True - takes_options = [Option('plugin', - help='Export help text from named command ' - '(defaults to all built in commands).', - type=str), - Option('include-duplicates', - help='Output multiple copies of the same msgid ' - 'string if it appears more than once.'), - ] + takes_options = [ + Option( + "plugin", + help="Export help text from named command " + "(defaults to all built in commands).", + type=str, + ), + Option( + "include-duplicates", + help="Output multiple copies of the same msgid " + "string if it appears more than once.", + ), + ] def run(self, plugin=None, include_duplicates=False): from .export_pot import export_pot + export_pot(self.outf, plugin, include_duplicates) @@ -6854,10 +7778,11 @@ class cmd_import(Command): stripped when extracting the tarball. This is not done for directories. """ - takes_args = ['source', 'tree?'] + takes_args = ["source", "tree?"] def run(self, source, tree=None): from .upstream_import import do_import + do_import(source, tree) @@ -6867,11 +7792,12 @@ class cmd_link_tree(Command): Only files with identical content and execute bit will be linked. 
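For context, the hard-linking described just above only happens when two files have identical content and the same executable bit. A stand-alone sketch of that check, using plain os/filecmp calls and hypothetical paths rather than Breezy's transform machinery:

    import filecmp
    import os
    import stat

    def maybe_hardlink(source_path, target_path):
        # Link target_path to source_path only if content and the
        # executable bit match (illustrative sketch, not link_tree itself).
        if not filecmp.cmp(source_path, target_path, shallow=False):
            return False
        src_exec = os.stat(source_path).st_mode & stat.S_IXUSR
        tgt_exec = os.stat(target_path).st_mode & stat.S_IXUSR
        if bool(src_exec) != bool(tgt_exec):
            return False
        os.unlink(target_path)
        os.link(source_path, target_path)
        return True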
""" - takes_args = ['location'] + takes_args = ["location"] def run(self, location): from .transform import link_tree from .workingtree import WorkingTree + target_tree = WorkingTree.open_containing(".")[0] source_tree = WorkingTree.open(location) with target_tree.lock_write(), source_tree.lock_read(): @@ -6885,21 +7811,22 @@ class cmd_fetch_ghosts(Command): """ hidden = True - aliases = ['fetch-missing'] - takes_args = ['branch?'] - takes_options = [Option('no-fix', help="Skip additional synchonization.")] + aliases = ["fetch-missing"] + takes_args = ["branch?"] + takes_options = [Option("no-fix", help="Skip additional synchonization.")] def run(self, branch=None, no_fix=False): from .fetch_ghosts import GhostFetcher + installed, failed = GhostFetcher.from_cmdline(branch).run() if len(installed) > 0: self.outf.write("Installed:\n") for rev in installed: - self.outf.write(rev.decode('utf-8') + "\n") + self.outf.write(rev.decode("utf-8") + "\n") if len(failed) > 0: self.outf.write("Still missing:\n") for rev in failed: - self.outf.write(rev.decode('utf-8') + "\n") + self.outf.write(rev.decode("utf-8") + "\n") if not no_fix and len(installed) > 0: cmd_reconcile().run(".") @@ -6931,80 +7858,133 @@ class cmd_grep(Command): [1] http://docs.python.org/library/re.html#regular-expression-syntax """ - encoding_type = 'replace' - takes_args = ['pattern', 'path*'] + encoding_type = "replace" + takes_args = ["pattern", "path*"] takes_options = [ - 'verbose', - 'revision', - Option('color', type=str, argname='when', - help='Show match in color. WHEN is never, always or auto.'), - Option('diff', short_name='p', - help='Grep for pattern in changeset for each revision.'), - ListOption('exclude', type=str, argname='glob', short_name='X', - help="Skip files whose base name matches GLOB."), - ListOption('include', type=str, argname='glob', short_name='I', - help="Search only files whose base name matches GLOB."), - Option('files-with-matches', short_name='l', - help='Print only the name of each input file in ' - 'which PATTERN is found.'), - Option('files-without-match', short_name='L', - help='Print only the name of each input file in ' - 'which PATTERN is not found.'), - Option('fixed-string', short_name='F', - help='Interpret PATTERN is a single fixed string (not regex).'), - Option('from-root', - help='Search for pattern starting from the root of the branch. ' - '(implies --recursive)'), - Option('ignore-case', short_name='i', - help='Ignore case distinctions while matching.'), - Option('levels', - help='Number of levels to display - 0 for all, 1 for collapsed ' - '(1 is default).', - argname='N', - type=_parse_levels), - Option('line-number', short_name='n', - help='Show 1-based line number.'), - Option('no-recursive', - help="Don't recurse into subdirectories. (default is --recursive)"), - Option('null', short_name='Z', - help='Write an ASCII NUL (\\0) separator ' - 'between output lines rather than a newline.'), - ] + "verbose", + "revision", + Option( + "color", + type=str, + argname="when", + help="Show match in color. 
WHEN is never, always or auto.", + ), + Option( + "diff", + short_name="p", + help="Grep for pattern in changeset for each revision.", + ), + ListOption( + "exclude", + type=str, + argname="glob", + short_name="X", + help="Skip files whose base name matches GLOB.", + ), + ListOption( + "include", + type=str, + argname="glob", + short_name="I", + help="Search only files whose base name matches GLOB.", + ), + Option( + "files-with-matches", + short_name="l", + help="Print only the name of each input file in " "which PATTERN is found.", + ), + Option( + "files-without-match", + short_name="L", + help="Print only the name of each input file in " + "which PATTERN is not found.", + ), + Option( + "fixed-string", + short_name="F", + help="Interpret PATTERN is a single fixed string (not regex).", + ), + Option( + "from-root", + help="Search for pattern starting from the root of the branch. " + "(implies --recursive)", + ), + Option( + "ignore-case", + short_name="i", + help="Ignore case distinctions while matching.", + ), + Option( + "levels", + help="Number of levels to display - 0 for all, 1 for collapsed " + "(1 is default).", + argname="N", + type=_parse_levels, + ), + Option("line-number", short_name="n", help="Show 1-based line number."), + Option( + "no-recursive", + help="Don't recurse into subdirectories. (default is --recursive)", + ), + Option( + "null", + short_name="Z", + help="Write an ASCII NUL (\\0) separator " + "between output lines rather than a newline.", + ), + ] @display_command - def run(self, verbose=False, ignore_case=False, no_recursive=False, - from_root=False, null=False, levels=None, line_number=False, - path_list=None, revision=None, pattern=None, include=None, - exclude=None, fixed_string=False, files_with_matches=False, - files_without_match=False, color=None, diff=False): + def run( + self, + verbose=False, + ignore_case=False, + no_recursive=False, + from_root=False, + null=False, + levels=None, + line_number=False, + path_list=None, + revision=None, + pattern=None, + include=None, + exclude=None, + fixed_string=False, + files_with_matches=False, + files_without_match=False, + color=None, + diff=False, + ): import re from breezy import terminal from . import grep + if path_list is None: - path_list = ['.'] + path_list = ["."] else: if from_root: - raise errors.CommandError( - 'cannot specify both --from-root and PATH.') + raise errors.CommandError("cannot specify both --from-root and PATH.") if files_with_matches and files_without_match: raise errors.CommandError( - 'cannot specify both ' - '-l/--files-with-matches and -L/--files-without-matches.') + "cannot specify both " + "-l/--files-with-matches and -L/--files-without-matches." + ) global_config = _mod_config.GlobalConfig() if color is None: - color = global_config.get_user_option('grep_color') + color = global_config.get_user_option("grep_color") if color is None: - color = 'auto' + color = "auto" - if color not in ['always', 'never', 'auto']: - raise errors.CommandError('Valid values for --color are ' - '"always", "never" or "auto".') + if color not in ["always", "never", "auto"]: + raise errors.CommandError( + "Valid values for --color are " '"always", "never" or "auto".' 
+ ) if levels is None: levels = 1 @@ -7014,9 +7994,9 @@ def run(self, verbose=False, ignore_case=False, no_recursive=False, # print revision numbers as we may be showing multiple revisions print_revno = True - eol_marker = '\n' + eol_marker = "\n" if null: - eol_marker = '\0' + eol_marker = "\0" if not ignore_case and grep.is_fixed_string(pattern): # if the pattern isalnum, implicitly use to -F for faster grep @@ -7034,13 +8014,14 @@ def run(self, verbose=False, ignore_case=False, no_recursive=False, if not fixed_string: patternc = grep.compile_pattern( - pattern.encode(grep._user_encoding), re_flags) + pattern.encode(grep._user_encoding), re_flags + ) - if color == 'always': + if color == "always": show_color = True - elif color == 'never': + elif color == "never": show_color = False - elif color == 'auto': + elif color == "auto": show_color = terminal.has_ansi_colors() opts = grep.GrepOptions() @@ -7087,28 +8068,37 @@ def run(self, verbose=False, ignore_case=False, no_recursive=False, class cmd_patch(Command): """Apply a named patch to the current tree.""" - takes_args = ['filename?'] - takes_options = [Option('strip', type=int, short_name='p', - help=("Strip the smallest prefix containing num " - "leading slashes from filenames.")), - Option('silent', help='Suppress chatter.')] + takes_args = ["filename?"] + takes_options = [ + Option( + "strip", + type=int, + short_name="p", + help=( + "Strip the smallest prefix containing num " + "leading slashes from filenames." + ), + ), + Option("silent", help="Suppress chatter."), + ] def run(self, filename=None, strip=None, silent=False): from .workingtree import WorkingTree, patch_tree - wt = WorkingTree.open_containing('.')[0] + + wt = WorkingTree.open_containing(".")[0] if strip is None: strip = 1 my_file = None if filename is None: - my_file = getattr(sys.stdin, 'buffer', sys.stdin) + my_file = getattr(sys.stdin, "buffer", sys.stdin) else: - my_file = open(filename, 'rb') + my_file = open(filename, "rb") patches = [my_file.read()] from io import BytesIO + b = BytesIO() patch_tree(wt, patches, strip, quiet=is_quiet(), out=b) - self.outf.write(b.getvalue().decode('utf-8', 'replace')) - + self.outf.write(b.getvalue().decode("utf-8", "replace")) class cmd_resolve_location(Command): @@ -7119,32 +8109,33 @@ class cmd_resolve_location(Command): brz resolve-location lp:brz """ - takes_args = ['location'] + takes_args = ["location"] hidden = True def run(self, location): from .location import location_to_url + url = location_to_url(location) display_url = urlutils.unescape_for_display(url, self.outf.encoding) - self.outf.write(f'{display_url}\n') + self.outf.write(f"{display_url}\n") def _register_lazy_builtins(): # register lazy builtins from other modules; called at startup and should # be only called once. 
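The loop that follows registers each built-in command lazily: only the command name, its aliases and the module that defines it are recorded at startup, and the module is imported the first time the command is actually used. A rough sketch of that idea with a hypothetical registry class (not Breezy's real builtin_command_registry):

    import importlib

    class LazyRegistry:
        # Toy registry in the spirit of register_lazy() (illustrative only).
        def __init__(self):
            self._lazy = {}  # command name -> (aliases, module name)

        def register_lazy(self, name, aliases, module_name):
            # Just remember where the command lives; import nothing yet.
            self._lazy[name] = (aliases, module_name)

        def get(self, name):
            # The defining module is imported only on first use.
            aliases, module_name = self._lazy[name]
            module = importlib.import_module(module_name)
            return getattr(module, name)

    registry = LazyRegistry()
    registry.register_lazy("cmd_ping", [], "breezy.bzr.smart.ping")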
- for (name, aliases, module_name) in [ - ('cmd_bisect', [], 'breezy.bisect'), - ('cmd_bundle_info', [], 'breezy.bzr.bundle.commands'), - ('cmd_config', [], 'breezy.config'), - ('cmd_dump_btree', [], 'breezy.bzr.debug_commands'), - ('cmd_file_id', [], 'breezy.bzr.debug_commands'), - ('cmd_file_path', [], 'breezy.bzr.debug_commands'), - ('cmd_version_info', [], 'breezy.cmd_version_info'), - ('cmd_resolve', ['resolved'], 'breezy.conflicts'), - ('cmd_conflicts', [], 'breezy.conflicts'), - ('cmd_ping', [], 'breezy.bzr.smart.ping'), - ('cmd_sign_my_commits', [], 'breezy.commit_signature_commands'), - ('cmd_verify_signatures', [], 'breezy.commit_signature_commands'), - ('cmd_test_script', [], 'breezy.cmd_test_script'), - ]: + for name, aliases, module_name in [ + ("cmd_bisect", [], "breezy.bisect"), + ("cmd_bundle_info", [], "breezy.bzr.bundle.commands"), + ("cmd_config", [], "breezy.config"), + ("cmd_dump_btree", [], "breezy.bzr.debug_commands"), + ("cmd_file_id", [], "breezy.bzr.debug_commands"), + ("cmd_file_path", [], "breezy.bzr.debug_commands"), + ("cmd_version_info", [], "breezy.cmd_version_info"), + ("cmd_resolve", ["resolved"], "breezy.conflicts"), + ("cmd_conflicts", [], "breezy.conflicts"), + ("cmd_ping", [], "breezy.bzr.smart.ping"), + ("cmd_sign_my_commits", [], "breezy.commit_signature_commands"), + ("cmd_verify_signatures", [], "breezy.commit_signature_commands"), + ("cmd_test_script", [], "breezy.cmd_test_script"), + ]: builtin_command_registry.register_lazy(name, aliases, module_name) diff --git a/breezy/bzr/__init__.py b/breezy/bzr/__init__.py index bbdce15cd4..9ca8161829 100644 --- a/breezy/bzr/__init__.py +++ b/breezy/bzr/__init__.py @@ -32,9 +32,10 @@ class LineEndingError(errors.BzrError): - - _fmt = ("Line ending corrupted for file: %(file)s; " - "Maybe your files got corrupted in transport?") + _fmt = ( + "Line ending corrupted for file: %(file)s; " + "Maybe your files got corrupted in transport?" 
+ ) def __init__(self, file): self.file = file @@ -43,7 +44,9 @@ def __init__(self, file): class BzrProber(controldir.Prober): """Prober for formats that use a .bzr/ control directory.""" - formats = registry.FormatRegistry["BzrDirFormat", None](controldir.network_format_registry) + formats = registry.FormatRegistry["BzrDirFormat", None]( + controldir.network_format_registry + ) """The known .bzr formats.""" @classmethod @@ -58,27 +61,26 @@ def probe_transport(klass, transport): except _mod_transport.NoSuchFile as e: raise errors.NotBranchError(path=transport.base) from e except errors.BadHttpRequest as e: - if e.reason == 'no such method: .bzr': + if e.reason == "no such method: .bzr": # hgweb raise errors.NotBranchError(path=transport.base) from e raise try: - first_line = format_string[:format_string.index(b"\n") + 1] + first_line = format_string[: format_string.index(b"\n") + 1] except ValueError: first_line = format_string - if (first_line.startswith(b' format 3 # state._dirblocks = sorted(state._dirblocks, diff --git a/breezy/bzr/_groupcompress_py.py b/breezy/bzr/_groupcompress_py.py index facb7fba50..ffcec2eb4d 100644 --- a/breezy/bzr/_groupcompress_py.py +++ b/breezy/bzr/_groupcompress_py.py @@ -50,8 +50,7 @@ def _flush_insert(self): if not self.cur_insert_lines: return if self.cur_insert_len > 127: - raise AssertionError('We cannot insert more than 127 bytes' - ' at a time.') + raise AssertionError("We cannot insert more than 127 bytes" " at a time.") self.out_lines.append(bytes([self.cur_insert_len])) self.index_lines.append(False) self.out_lines.extend(self.cur_insert_lines) @@ -70,15 +69,16 @@ def _insert_long_line(self, line): next_len = min(127, line_len - start_index) self.out_lines.append(bytes([next_len])) self.index_lines.append(False) - self.out_lines.append(line[start_index:start_index + next_len]) + self.out_lines.append(line[start_index : start_index + next_len]) # We don't index long lines, because we won't be able to match # a line split across multiple inserts anway self.index_lines.append(False) def add_insert(self, lines): if self.cur_insert_lines != []: - raise AssertionError('self.cur_insert_lines must be empty when' - ' adding a new insert') + raise AssertionError( + "self.cur_insert_lines must be empty when" " adding a new insert" + ) for line in lines: if len(line) > 127: self._insert_long_line(line) @@ -119,9 +119,11 @@ def _update_matching_lines(self, new_lines, index): matches = self._matching_lines start_idx = len(self.lines) if len(new_lines) != len(index): - raise AssertionError('The number of lines to be indexed does' - ' not match the index/don\'t index flags: %d != %d' - % (len(new_lines), len(index))) + raise AssertionError( + "The number of lines to be indexed does" + " not match the index/don't index flags: %d != %d" + % (len(new_lines), len(index)) + ) for idx, do_index in enumerate(index): if not do_index: continue @@ -173,8 +175,9 @@ def _get_longest_match(self, lines, pos): else: # We have a match started, compare to see if any of the # current matches can be continued - next_locations = locations.intersection([loc + 1 for loc - in prev_locations]) + next_locations = locations.intersection( + [loc + 1 for loc in prev_locations] + ) if next_locations: # At least one of the regions continues to match prev_locations = set(next_locations) @@ -225,7 +228,9 @@ def get_matching_blocks(self, lines, soft=False): if block[-1] < min_match_bytes: # This block may be a 'short' block, check old_start, new_start, range_len = block - matched_bytes = 
sum(map(len, lines[new_start:new_start + range_len])) + matched_bytes = sum( + map(len, lines[new_start : new_start + range_len]) + ) if matched_bytes < min_match_bytes: block = None if block is not None: @@ -247,14 +252,17 @@ def extend_lines(self, lines, index): endpoint += len(line) self.line_offsets.append(endpoint) if len(self.line_offsets) != len(self.lines): - raise AssertionError('Somehow the line offset indicator' - ' got out of sync with the line counter.') + raise AssertionError( + "Somehow the line offset indicator" + " got out of sync with the line counter." + ) self.endpoint = endpoint - def _flush_insert(self, start_linenum, end_linenum, - new_lines, out_lines, index_lines): + def _flush_insert( + self, start_linenum, end_linenum, new_lines, out_lines, index_lines + ): """Add an 'insert' request to the data stream.""" - bytes_to_insert = b''.join(new_lines[start_linenum:end_linenum]) + bytes_to_insert = b"".join(new_lines[start_linenum:end_linenum]) insert_length = len(bytes_to_insert) # Each insert instruction is at most 127 bytes long for start_byte in range(0, insert_length, 127): @@ -262,13 +270,12 @@ def _flush_insert(self, start_linenum, end_linenum, out_lines.append(bytes([insert_count])) # Don't index the 'insert' instruction index_lines.append(False) - insert = bytes_to_insert[start_byte:start_byte + insert_count] + insert = bytes_to_insert[start_byte : start_byte + insert_count] as_lines = osutils.split_lines(insert) out_lines.extend(as_lines) index_lines.extend([True] * len(as_lines)) - def _flush_copy(self, old_start_linenum, num_lines, - out_lines, index_lines): + def _flush_copy(self, old_start_linenum, num_lines, out_lines, index_lines): if old_start_linenum == 0: first_byte = 0 else: @@ -286,10 +293,9 @@ def _flush_copy(self, old_start_linenum, num_lines, def make_delta(self, new_lines, bytes_length, soft=False): """Compute the delta for this content versus the original content.""" # reserved for content type, content length - out_lines = [b'', b'', encode_base128_int(bytes_length)] + out_lines = [b"", b"", encode_base128_int(bytes_length)] index_lines = [False, False, False] - output_handler = _OutputHandler(out_lines, index_lines, - self._MIN_MATCH_BYTES) + output_handler = _OutputHandler(out_lines, index_lines, self._MIN_MATCH_BYTES) blocks = self.get_matching_blocks(new_lines, soft=soft) current_line_num = 0 # We either copy a range (while there are reusable lines) or we @@ -297,8 +303,7 @@ def make_delta(self, new_lines, bytes_length, soft=False): for old_start, new_start, range_len in blocks: if new_start != current_line_num: # non-matching region, insert the content - output_handler.add_insert( - new_lines[current_line_num:new_start]) + output_handler.add_insert(new_lines[current_line_num:new_start]) current_line_num = new_start + range_len if range_len: # Convert the line based offsets into byte based offsets @@ -314,10 +319,11 @@ def make_delta(self, new_lines, bytes_length, soft=False): def make_delta(source_bytes, target_bytes): """Create a delta from source to target.""" if not isinstance(source_bytes, bytes): - raise TypeError('source is not bytes') + raise TypeError("source is not bytes") if not isinstance(target_bytes, bytes): - raise TypeError('target is not bytes') + raise TypeError("target is not bytes") line_locations = LinesDeltaIndex(osutils.split_lines(source_bytes)) - delta, _ = line_locations.make_delta(osutils.split_lines(target_bytes), - bytes_length=len(target_bytes)) - return b''.join(delta) + delta, _ = line_locations.make_delta( 
+ osutils.split_lines(target_bytes), bytes_length=len(target_bytes) + ) + return b"".join(delta) diff --git a/breezy/bzr/_knit_load_data_py.py b/breezy/bzr/_knit_load_data_py.py index c2d7bd4784..348471fb4c 100644 --- a/breezy/bzr/_knit_load_data_py.py +++ b/breezy/bzr/_knit_load_data_py.py @@ -34,7 +34,7 @@ def _load_data_py(kndx, fp): history_top = len(history) - 1 for line in fp.readlines(): rec = line.split() - if len(rec) < 5 or rec[-1] != b':': + if len(rec) < 5 or rec[-1] != b":": # corrupt line. # FIXME: in the future we should determine if its a # short write - and ignore it @@ -44,7 +44,7 @@ def _load_data_py(kndx, fp): try: parents = [] for value in rec[4:-1]: - if value[:1] == b'.': + if value[:1] == b".": # uncompressed reference parent_id = value[1:] else: @@ -65,13 +65,15 @@ def _load_data_py(kndx, fp): try: pos = int(pos) except ValueError as e: - raise KnitCorrupt(kndx._filename, - f"invalid position on line {rec!r}: {e}") from e + raise KnitCorrupt( + kndx._filename, f"invalid position on line {rec!r}: {e}" + ) from e try: size = int(size) except ValueError as e: - raise KnitCorrupt(kndx._filename, - f"invalid size on line {rec!r}: {e}") from e + raise KnitCorrupt( + kndx._filename, f"invalid size on line {rec!r}: {e}" + ) from e # See kndx._cache_version # only want the _history index to reference the 1st @@ -82,10 +84,12 @@ def _load_data_py(kndx, fp): history.append(version_id) else: index = cache[version_id][5] - cache[version_id] = (version_id, - options.split(b','), - pos, - size, - tuple(parents), - index) + cache[version_id] = ( + version_id, + options.split(b","), + pos, + size, + tuple(parents), + index, + ) # end kndx._cache_version diff --git a/breezy/bzr/_static_tuple_py.py b/breezy/bzr/_static_tuple_py.py index c2929e27f0..5ee46e8a9f 100644 --- a/breezy/bzr/_static_tuple_py.py +++ b/breezy/bzr/_static_tuple_py.py @@ -38,18 +38,20 @@ def __init__(self, *args): """Create a new 'StaticTuple'.""" num_keys = len(args) if num_keys < 0 or num_keys > 255: - raise TypeError('StaticTuple(...) takes from 0 to 255 items') + raise TypeError("StaticTuple(...) takes from 0 to 255 items") for bit in args: if type(bit) not in _valid_types: - raise TypeError('StaticTuple can only point to' - ' StaticTuple, str, unicode, int, float, bool, or' - f' None not {type(bit)}') + raise TypeError( + "StaticTuple can only point to" + " StaticTuple, str, unicode, int, float, bool, or" + f" None not {type(bit)}" + ) # We don't need to pass args to tuple.__init__, because that was # already handled in __new__. tuple.__init__(self) def __repr__(self): - return f'{self.__class__.__name__}{tuple.__repr__(self)}' + return f"{self.__class__.__name__}{tuple.__repr__(self)}" def __reduce__(self): return (StaticTuple, tuple(self)) diff --git a/breezy/bzr/branch.py b/breezy/bzr/branch.py index be240735b1..7839ed55eb 100644 --- a/breezy/bzr/branch.py +++ b/breezy/bzr/branch.py @@ -20,13 +20,16 @@ from ..lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ from breezy import ( config as _mod_config, lockdir, ui, ) -""") +""", +) from .. import errors, urlutils from .. 
import revision as _mod_revision @@ -72,18 +75,22 @@ class BzrBranch(Branch, _RelockDebugMixin): def control_transport(self) -> _mod_transport.Transport: return self._transport - def __init__(self, *, a_controldir: bzrdir.BzrDir, name: str, - _repository: MetaDirRepository, - _control_files: lockable_files.LockableFiles, - _format=None, - ignore_fallbacks=False, - possible_transports=None): + def __init__( + self, + *, + a_controldir: bzrdir.BzrDir, + name: str, + _repository: MetaDirRepository, + _control_files: lockable_files.LockableFiles, + _format=None, + ignore_fallbacks=False, + possible_transports=None, + ): """Create new branch object at a particular location.""" self.controldir = a_controldir - self._user_transport = self.controldir.transport.clone('..') + self._user_transport = self.controldir.transport.clone("..") if name != "": - self._user_transport.set_segment_parameter( - "branch", urlutils.escape(name)) + self._user_transport.set_segment_parameter("branch", urlutils.escape(name)) self._base = self._user_transport.base self.name = name self._format = _format @@ -95,7 +102,7 @@ def __init__(self, *, a_controldir: bzrdir.BzrDir, name: str, self._tags_bytes = None def __str__(self): - return f'{self.__class__.__name__}({self.user_url})' + return f"{self.__class__.__name__}({self.user_url})" __repr__ = __str__ @@ -119,7 +126,7 @@ def _get_config(self): :return: An object supporting get_option and set_option. """ - return _mod_config.TransportConfig(self._transport, 'branch.conf') + return _mod_config.TransportConfig(self._transport, "branch.conf") def _get_config_store(self): if self.conf_store is None: @@ -143,14 +150,14 @@ def store_uncommitted(self, creator): """ branch = self._uncommitted_branch() if creator is None: - branch._transport.delete('stored-transform') + branch._transport.delete("stored-transform") return - if branch._transport.has('stored-transform'): + if branch._transport.has("stored-transform"): raise errors.ChangesAlreadyStored transform = BytesIO() creator.write_shelf(transform) transform.seek(0) - branch._transport.put_file('stored-transform', transform) + branch._transport.put_file("stored-transform", transform) def get_unshelver(self, tree): """Return a shelf.Unshelver for this branch and tree. @@ -160,10 +167,11 @@ def get_unshelver(self, tree): """ branch = self._uncommitted_branch() try: - transform = branch._transport.get('stored-transform') + transform = branch._transport.get("stored-transform") except _mod_transport.NoSuchFile: return None from ..shelf import Unshelver + return Unshelver.from_tree_and_shelf(tree, transform) def is_locked(self) -> bool: @@ -177,7 +185,7 @@ def lock_write(self, token=None): :return: A BranchWriteLockResult. """ if not self.is_locked(): - self._note_lock('w') + self._note_lock("w") self.repository._warn_if_deprecated(self) self.repository.lock_write() took_lock = True @@ -185,8 +193,8 @@ def lock_write(self, token=None): took_lock = False try: return BranchWriteLockResult( - self.unlock, - self.control_files.lock_write(token=token)) + self.unlock, self.control_files.lock_write(token=token) + ) except BaseException: if took_lock: self.repository.unlock() @@ -198,7 +206,7 @@ def lock_read(self): :return: A breezy.lock.LogicalLockResult. 
""" if not self.is_locked(): - self._note_lock('r') + self._note_lock("r") self.repository._warn_if_deprecated(self) self.repository.lock_read() took_lock = True @@ -235,8 +243,7 @@ def get_physical_lock_status(self): def set_last_revision_info(self, revno, revision_id): if not revision_id or not isinstance(revision_id, bytes): - raise errors.InvalidRevisionId( - revision_id=revision_id, branch=self) + raise errors.InvalidRevisionId(revision_id=revision_id, branch=self) with self.lock_write(): old_revno, old_revid = self.last_revision_info() if self.get_append_revisions_only(): @@ -252,14 +259,14 @@ def basis_tree(self): return self.repository.revision_tree(self.last_revision()) def _get_parent_location(self): - _locs = ['parent', 'pull', 'x-pull'] + _locs = ["parent", "pull", "x-pull"] for l in _locs: try: contents = self._transport.get_bytes(l) except _mod_transport.NoSuchFile: pass else: - return contents.strip(b'\n').decode('utf-8') + return contents.strip(b"\n").decode("utf-8") return None def get_stacked_on_url(self): @@ -268,17 +275,18 @@ def get_stacked_on_url(self): def set_push_location(self, location): """See Branch.set_push_location.""" self.get_config().set_user_option( - 'push_location', location, - store=_mod_config.STORE_LOCATION_NORECURSE) + "push_location", location, store=_mod_config.STORE_LOCATION_NORECURSE + ) def _set_parent_location(self, url): if url is None: - self._transport.delete('parent') + self._transport.delete("parent") else: if isinstance(url, str): - url = url.encode('utf-8') - self._transport.put_bytes('parent', url + b'\n', - mode=self.controldir._get_file_mode()) + url = url.encode("utf-8") + self._transport.put_bytes( + "parent", url + b"\n", mode=self.controldir._get_file_mode() + ) def unbind(self): """If bound, unbind.""" @@ -313,7 +321,7 @@ def bind(self, other): def get_bound_location(self): try: - return self._transport.get_bytes('bound')[:-1].decode('utf-8') + return self._transport.get_bytes("bound")[:-1].decode("utf-8") except _mod_transport.NoSuchFile: return None @@ -324,8 +332,7 @@ def get_master_branch(self, possible_transports=None): """ with self.lock_read(): if self._master_branch_cache is None: - self._master_branch_cache = self._get_master_branch( - possible_transports) + self._master_branch_cache = self._get_master_branch(possible_transports) return self._master_branch_cache def _get_master_branch(self, possible_transports): @@ -333,11 +340,9 @@ def _get_master_branch(self, possible_transports): if not bound_loc: return None try: - return Branch.open(bound_loc, - possible_transports=possible_transports) + return Branch.open(bound_loc, possible_transports=possible_transports) except (errors.NotBranchError, ConnectionError) as exc: - raise errors.BoundBranchConnectionFailure( - self, bound_loc, exc) from exc + raise errors.BoundBranchConnectionFailure(self, bound_loc, exc) from exc def set_bound_location(self, location): """Set the target where this branch is bound to. 
@@ -348,11 +353,13 @@ def set_bound_location(self, location): self._master_branch_cache = None if location: self._transport.put_bytes( - 'bound', location.encode('utf-8') + b'\n', - mode=self.controldir._get_file_mode()) + "bound", + location.encode("utf-8") + b"\n", + mode=self.controldir._get_file_mode(), + ) else: try: - self._transport.delete('bound') + self._transport.delete("bound") except _mod_transport.NoSuchFile: return False return True @@ -369,15 +376,17 @@ def update(self, possible_transports=None): old_tip = self.last_revision() self.pull(master, overwrite=True) if self.repository.get_graph().is_ancestor( - old_tip, self.last_revision()): + old_tip, self.last_revision() + ): return None return old_tip return None def _read_last_revision_info(self): from .. import cache_utf8 - revision_string = self._transport.get_bytes('last-revision') - revno, revision_id = revision_string.rstrip(b'\n').split(b' ', 1) + + revision_string = self._transport.get_bytes("last-revision") + revno, revision_id = revision_string.rstrip(b"\n").split(b" ", 1) revision_id = cache_utf8.get_cached_utf8(revision_id) revno = int(revno) return revno, revision_id @@ -389,9 +398,10 @@ def _write_last_revision_info(self, revno, revision_id): Does not update the revision_history cache. """ - out_string = b'%d %s\n' % (revno, revision_id) - self._transport.put_bytes('last-revision', out_string, - mode=self.controldir._get_file_mode()) + out_string = b"%d %s\n" % (revno, revision_id) + self._transport.put_bytes( + "last-revision", out_string, mode=self.controldir._get_file_mode() + ) def update_feature_flags(self, updated_flags): """Update the feature flags for this branch. @@ -401,8 +411,7 @@ def update_feature_flags(self, updated_flags): """ with self.lock_write(): self._format._update_feature_flags(updated_flags) - self.control_transport.put_bytes( - 'format', self._format.as_string()) + self.control_transport.put_bytes("format", self._format.as_string()) def _get_tags_bytes(self): """Get the bytes of a serialised tags dict. @@ -418,7 +427,7 @@ def _get_tags_bytes(self): """ with self.lock_read(): if self._tags_bytes is None: - self._tags_bytes = self._transport.get_bytes('tags') + self._tags_bytes = self._transport.get_bytes("tags") return self._tags_bytes def _set_tags_bytes(self, bytes): @@ -428,7 +437,7 @@ def _set_tags_bytes(self, bytes): """ with self.lock_write(): self._tags_bytes = bytes - return self._transport.put_bytes('tags', bytes) + return self._transport.put_bytes("tags", bytes) def _clear_cached_state(self): super()._clear_cached_state() @@ -437,6 +446,7 @@ def _clear_cached_state(self): def reconcile(self, thorough=True): """Make sure the data stored in this branch is consistent.""" from .reconcile import BranchReconciler + with self.lock_write(): reconciler = BranchReconciler(self, thorough=thorough) return reconciler.reconcile() @@ -463,13 +473,16 @@ def reference_parent(self, file_id, path, possible_transports=None): try: return Branch.open_from_transport( self.controldir.root_transport.clone(path), - possible_transports=possible_transports) + possible_transports=possible_transports, + ) except errors.NotBranchError: return None return Branch.open( urlutils.join( - urlutils.strip_segment_parameters(self.user_url), branch_location), - possible_transports=possible_transports) + urlutils.strip_segment_parameters(self.user_url), branch_location + ), + possible_transports=possible_transports, + ) def set_stacked_on_url(self, url: str) -> None: """Set the URL this branch is stacked against. 
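Among the helpers above, _read_last_revision_info and _write_last_revision_info round-trip a one-line last-revision control file of the form "<revno> <revision-id>\n". A self-contained sketch of that encoding, with a made-up revision id:

    def serialise_last_revision(revno, revision_id):
        # Same shape as the file written by _write_last_revision_info above.
        return b"%d %s\n" % (revno, revision_id)

    def parse_last_revision(data):
        # Mirrors _read_last_revision_info: strip the newline, split once.
        revno, revision_id = data.rstrip(b"\n").split(b" ", 1)
        return int(revno), revision_id

    line = serialise_last_revision(42, b"jrandom@example.com-20230101000000-deadbeef")
    assert parse_last_revision(line) == (42, b"jrandom@example.com-20230101000000-deadbeef")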
@@ -489,21 +502,26 @@ def set_stacked_on_url(self, url: str) -> None: if not url: try: self.get_stacked_on_url() - except (errors.NotStacked, UnstackableBranchFormat, - errors.UnstackableRepositoryFormat): + except ( + errors.NotStacked, + UnstackableBranchFormat, + errors.UnstackableRepositoryFormat, + ): return self._unstack() else: self._activate_fallback_location( - url, possible_transports=[self.controldir.root_transport]) + url, possible_transports=[self.controldir.root_transport] + ) # write this out after the repository is stacked to avoid setting a # stacked config that doesn't work. - self._set_config_location('stacked_on_location', url) + self._set_config_location("stacked_on_location", url) def _check_stackable_repo(self) -> None: if not self.repository._format.supports_external_lookups: raise errors.UnstackableRepositoryFormat( - self.repository._format, self.repository.user_url) + self.repository._format, self.repository.user_url + ) def _unstack(self): """Change a branch to be unstacked, copying data as needed. @@ -511,6 +529,7 @@ def _unstack(self): Don't call this directly, use set_stacked_on_url(None). """ from .vf_search import NotInOtherForRevs + with ui.ui_factory.nested_progress_bar(): # The basic approach here is to fetch the tip of the branch, # including all available ghosts, from the existing stacked @@ -523,7 +542,9 @@ def _unstack(self): raise AssertionError( "can't cope with fallback repositories " "of {!r} (fallbacks: {!r})".format( - old_repository, old_repository._fallback_repositories)) + old_repository, old_repository._fallback_repositories + ) + ) # Open the new repository object. # Repositories don't offer an interface to remove fallback # repositories today; take the conceptually simpler option and just @@ -532,12 +553,12 @@ def _unstack(self): # stream from one of them to the other. This does mean doing # separate SSH connection setup, but unstacking is not a # common operation so it's tolerable. - new_bzrdir = ControlDir.open( - self.controldir.root_transport.base) + new_bzrdir = ControlDir.open(self.controldir.root_transport.base) new_repository = new_bzrdir.find_repository() if new_repository._fallback_repositories: raise AssertionError( - f"didn't expect {self.repository!r} to have fallback_repositories") + f"didn't expect {self.repository!r} to have fallback_repositories" + ) # Replace self.repository with the new repository. # Do our best to transfer the lock state (i.e. lock-tokens and # lock count) of self.repository to the new repository. @@ -563,7 +584,8 @@ def _unstack(self): old_lock_count += 1 if old_lock_count == 0: raise AssertionError( - 'old_repository should have been locked at least once.') + "old_repository should have been locked at least once." + ) for _i in range(old_lock_count - 1): self.repository.lock_write() # Fetch from the old repository into the new. 
@@ -576,9 +598,12 @@ def _unstack(self): except errors.TagsNotSupported: tags_to_fetch = set() fetch_spec = NotInOtherForRevs( - self.repository, old_repository, + self.repository, + old_repository, required_ids=[self.last_revision()], - if_present_ids=tags_to_fetch, find_ghosts=True).execute() + if_present_ids=tags_to_fetch, + find_ghosts=True, + ).execute() self.repository.fetch(old_repository, fetch_spec=fetch_spec) def break_lock(self) -> None: @@ -606,22 +631,27 @@ def _open_hook(self, possible_transports=None): possible_transports = [self.controldir.root_transport] try: url = self.get_stacked_on_url() - except (errors.UnstackableRepositoryFormat, errors.NotStacked, - UnstackableBranchFormat): + except ( + errors.UnstackableRepositoryFormat, + errors.NotStacked, + UnstackableBranchFormat, + ): pass else: - for hook in Branch.hooks['transform_fallback_location']: + for hook in Branch.hooks["transform_fallback_location"]: url = hook(self, url) if url is None: hook_name = Branch.hooks.get_hook_name(hook) raise AssertionError( "'transform_fallback_location' hook %s returned " - "None, not a URL." % hook_name) + "None, not a URL." % hook_name + ) self._activate_fallback_location( - url, possible_transports=possible_transports) + url, possible_transports=possible_transports + ) def __init__(self, *args, **kwargs): - self._ignore_fallbacks = kwargs.get('ignore_fallbacks', False) + self._ignore_fallbacks = kwargs.get("ignore_fallbacks", False) super().__init__(*args, **kwargs) self._last_revision_info_cache = None self._reference_info = None @@ -650,13 +680,12 @@ def _gen_revision_history(self): def _set_parent_location(self, url): """Set the parent branch.""" with self.lock_write(): - self._set_config_location( - 'parent_location', url, make_relative=True) + self._set_config_location("parent_location", url, make_relative=True) def _get_parent_location(self): """Set the parent branch.""" with self.lock_read(): - return self._get_config_location('parent_location') + return self._get_config_location("parent_location") def _set_all_reference_info(self, info_dict): """Replace all reference info stored in a branch. 
@@ -666,13 +695,14 @@ def _set_all_reference_info(self, info_dict): s = BytesIO() writer = rio.RioWriter(s) for file_id, (branch_location, tree_path) in info_dict.items(): - stanza = rio.Stanza(file_id=file_id.decode('utf-8'), - branch_location=branch_location) + stanza = rio.Stanza( + file_id=file_id.decode("utf-8"), branch_location=branch_location + ) if tree_path is not None: - stanza.add('tree_path', tree_path) + stanza.add("tree_path", tree_path) writer.write_stanza(stanza) with self.lock_write(): - self._transport.put_bytes('references', s.getvalue()) + self._transport.put_bytes("references", s.getvalue()) self._reference_info = info_dict def _get_all_reference_info(self): @@ -684,13 +714,15 @@ def _get_all_reference_info(self): if self._reference_info is not None: return self._reference_info try: - with self._transport.get('references') as rio_file: + with self._transport.get("references") as rio_file: stanzas = rio.read_stanzas(rio_file) info_dict = { - s.get('file_id').encode('utf-8'): ( - s.get('branch_location'), - s.get('tree_path') if 'tree_path' in s else None) - for s in stanzas} + s.get("file_id").encode("utf-8"): ( + s.get("branch_location"), + s.get("tree_path") if "tree_path" in s else None, + ) + for s in stanzas + } except _mod_transport.NoSuchFile: info_dict = {} self._reference_info = info_dict @@ -719,22 +751,21 @@ def get_reference_info(self, file_id): def set_push_location(self, location): """See Branch.set_push_location.""" - self._set_config_location('push_location', location) + self._set_config_location("push_location", location) def set_bound_location(self, location): """See Branch.set_push_location.""" self._master_branch_cache = None conf = self.get_config_stack() if location is None: - if not conf.get('bound'): + if not conf.get("bound"): return False else: - conf.set('bound', 'False') + conf.set("bound", "False") return True else: - self._set_config_location('bound_location', location, - config=conf) - conf.set('bound', 'True') + self._set_config_location("bound_location", location, config=conf) + conf.set("bound", "True") return True def _get_bound_location(self, bound): @@ -743,9 +774,9 @@ def _get_bound_location(self, bound): Return None if the bound parameter does not match """ conf = self.get_config_stack() - if conf.get('bound') != bound: + if conf.get("bound") != bound: return None - return self._get_config_location('bound_location', config=conf) + return self._get_config_location("bound_location", config=conf) def get_bound_location(self): """See Branch.get_bound_location.""" @@ -762,8 +793,7 @@ def get_stacked_on_url(self): # stacked_on_location is only ever defined in branch.conf, so don't # waste effort reading the whole stack of config files. 
conf = _mod_config.BranchOnlyStack(self) - stacked_url = self._get_config_location('stacked_on_location', - config=conf) + stacked_url = self._get_config_location("stacked_on_location", config=conf) if stacked_url is None: raise errors.NotStacked(self) # TODO(jelmer): Clean this up for pad.lv/1696545 @@ -802,7 +832,8 @@ def revision_id_to_revno(self, revision_id): self._extend_partial_history(stop_revision=revision_id) except errors.RevisionNotPresent as exc: raise errors.GhostRevisionsHaveNoRevno( - revision_id, exc.revision_id) from exc + revision_id, exc.revision_id + ) from exc index = len(self._partial_revision_history_cache) - 1 if index < 0: raise errors.NoSuchRevision(self, revision_id) from e @@ -815,11 +846,10 @@ class BzrBranch7(BzrBranch8): """A branch with support for a fallback repository.""" def set_reference_info(self, file_id, branch_location, tree_path=None): - super().set_reference_info( - file_id, branch_location, tree_path) + super().set_reference_info(file_id, branch_location, tree_path) format_string = BzrBranchFormat8.get_format_string() - mutter('Upgrading branch to format %r', format_string) - self._transport.put_bytes('format', format_string) + mutter("Upgrading branch to format %r", format_string) + self._transport.put_bytes("format", format_string) class BzrBranch6(BzrBranch7): @@ -851,8 +881,9 @@ def find_format(klass, controldir, name=None): format_string = transport.get_bytes("format") except _mod_transport.NoSuchFile as exc: raise errors.NotBranchError( - path=transport.base, controldir=controldir) from exc - return klass._find_format(format_registry, 'branch', format_string) + path=transport.base, controldir=controldir + ) from exc + return klass._find_format(format_registry, "branch", format_string) def _branch_class(self): """What class to instantiate on open calls.""" @@ -866,8 +897,7 @@ def _get_initial_config(self, append_revisions_only=None): # as that is the default. return b"" - def _initialize_helper(self, a_controldir, utf8_files, name=None, - repository=None): + def _initialize_helper(self, a_controldir, utf8_files, name=None, repository=None): """Initialize a branch in a control dir, with specified files. 
:param a_controldir: The bzrdir to initialize the branch in @@ -878,28 +908,34 @@ def _initialize_helper(self, a_controldir, utf8_files, name=None, """ if name is None: name = a_controldir._get_selected_branch() - mutter('creating branch %r in %s', self, a_controldir.user_url) + mutter("creating branch %r in %s", self, a_controldir.user_url) branch_transport = a_controldir.get_branch_transport(self, name=name) - control_files = lockable_files.LockableFiles(branch_transport, - 'lock', lockdir.LockDir) + control_files = lockable_files.LockableFiles( + branch_transport, "lock", lockdir.LockDir + ) control_files.create_lock() control_files.lock_write() try: - utf8_files += [('format', self.as_string())] - for (filename, content) in utf8_files: + utf8_files += [("format", self.as_string())] + for filename, content in utf8_files: branch_transport.put_bytes( - filename, content, - mode=a_controldir._get_file_mode()) + filename, content, mode=a_controldir._get_file_mode() + ) finally: control_files.unlock() - branch = self.open(a_controldir, name, _found=True, - found_repository=repository) + branch = self.open(a_controldir, name, _found=True, found_repository=repository) self._run_post_branch_init_hooks(a_controldir, name, branch) return branch - def open(self, a_controldir, name=None, _found=False, - ignore_fallbacks=False, found_repository=None, - possible_transports=None): + def open( + self, + a_controldir, + name=None, + _found=False, + ignore_fallbacks=False, + found_repository=None, + possible_transports=None, + ): """See BranchFormat.open().""" if name is None: name = a_controldir._get_selected_branch() @@ -909,18 +945,24 @@ def open(self, a_controldir, name=None, _found=False, raise AssertionError(f"wrong format {format!r} found for {self!r}") transport = a_controldir.get_branch_transport(None, name=name) try: - control_files = lockable_files.LockableFiles(transport, 'lock', - lockdir.LockDir) + control_files = lockable_files.LockableFiles( + transport, "lock", lockdir.LockDir + ) if found_repository is None: found_repository = a_controldir.find_repository() return self._branch_class()( - _format=self, _control_files=control_files, name=name, - a_controldir=a_controldir, _repository=found_repository, + _format=self, + _control_files=control_files, + name=name, + a_controldir=a_controldir, + _repository=found_repository, ignore_fallbacks=ignore_fallbacks, - possible_transports=possible_transports) + possible_transports=possible_transports, + ) except _mod_transport.NoSuchFile as exc: raise errors.NotBranchError( - path=transport.base, controldir=a_controldir) from exc + path=transport.base, controldir=a_controldir + ) from exc @property def _matchingcontroldir(self): @@ -934,14 +976,21 @@ def supports_tags(self): def supports_leaving_lock(self): return True - def check_support_status(self, allow_unsupported, recommend_upgrade=True, - basedir=None): + def check_support_status( + self, allow_unsupported, recommend_upgrade=True, basedir=None + ): BranchFormat.check_support_status( - self, allow_unsupported=allow_unsupported, - recommend_upgrade=recommend_upgrade, basedir=basedir) + self, + allow_unsupported=allow_unsupported, + recommend_upgrade=recommend_upgrade, + basedir=basedir, + ) bzrdir.BzrFormat.check_support_status( - self, allow_unsupported=allow_unsupported, - recommend_upgrade=recommend_upgrade, basedir=basedir) + self, + allow_unsupported=allow_unsupported, + recommend_upgrade=recommend_upgrade, + basedir=basedir, + ) class BzrBranchFormat6(BranchFormatMetadir): @@ -967,20 
+1016,21 @@ def get_format_description(self): """See BranchFormat.get_format_description().""" return "Branch format 6" - def initialize(self, a_controldir, name=None, repository=None, - append_revisions_only=None): + def initialize( + self, a_controldir, name=None, repository=None, append_revisions_only=None + ): """Create a branch of this format in a_controldir.""" utf8_files = [ - ('last-revision', b'0 null:\n'), - ('branch.conf', self._get_initial_config(append_revisions_only)), - ('tags', b''), - ] - return self._initialize_helper( - a_controldir, utf8_files, name, repository) + ("last-revision", b"0 null:\n"), + ("branch.conf", self._get_initial_config(append_revisions_only)), + ("tags", b""), + ] + return self._initialize_helper(a_controldir, utf8_files, name, repository) def make_tags(self, branch): """See breezy.branch.BranchFormat.make_tags().""" from .tag import BasicTags + return BasicTags(branch) def supports_set_append_revisions_only(self): @@ -1004,21 +1054,22 @@ def get_format_description(self): """See BranchFormat.get_format_description().""" return "Branch format 8" - def initialize(self, a_controldir, name=None, repository=None, - append_revisions_only=None): + def initialize( + self, a_controldir, name=None, repository=None, append_revisions_only=None + ): """Create a branch of this format in a_controldir.""" - utf8_files = [('last-revision', b'0 null:\n'), - ('branch.conf', - self._get_initial_config(append_revisions_only)), - ('tags', b''), - ('references', b'') - ] - return self._initialize_helper( - a_controldir, utf8_files, name, repository) + utf8_files = [ + ("last-revision", b"0 null:\n"), + ("branch.conf", self._get_initial_config(append_revisions_only)), + ("tags", b""), + ("references", b""), + ] + return self._initialize_helper(a_controldir, utf8_files, name, repository) def make_tags(self, branch): """See breezy.branch.BranchFormat.make_tags().""" from .tag import BasicTags + return BasicTags(branch) def supports_set_append_revisions_only(self): @@ -1039,16 +1090,16 @@ class BzrBranchFormat7(BranchFormatMetadir): This format was introduced in bzr 1.6. 
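The various initialize() implementations in this part of the diff all seed a new branch with the same few control files; format 8 additionally creates an empty references file. Summarised as a sketch (the format marker itself is written separately by _initialize_helper):

    # Initial contents of a new branch's control directory, per the
    # utf8_files lists above (sketch; "format" is added by _initialize_helper).
    initial_control_files = {
        "last-revision": b"0 null:\n",  # revision number 0, null: revision id
        "branch.conf": b"",             # empty by default; a non-default
                                        # append_revisions_only setting goes here
        "tags": b"",
        "references": b"",              # created by format 8 only
    }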
""" - def initialize(self, a_controldir, name=None, repository=None, - append_revisions_only=None): + def initialize( + self, a_controldir, name=None, repository=None, append_revisions_only=None + ): """Create a branch of this format in a_controldir.""" - utf8_files = [('last-revision', b'0 null:\n'), - ('branch.conf', - self._get_initial_config(append_revisions_only)), - ('tags', b''), - ] - return self._initialize_helper( - a_controldir, utf8_files, name, repository) + utf8_files = [ + ("last-revision", b"0 null:\n"), + ("branch.conf", self._get_initial_config(append_revisions_only)), + ("tags", b""), + ] + return self._initialize_helper(a_controldir, utf8_files, name, repository) def _branch_class(self): return BzrBranch7 @@ -1071,6 +1122,7 @@ def supports_stacking(self): def make_tags(self, branch): """See breezy.branch.BranchFormat.make_tags().""" from .tag import BasicTags + return BasicTags(branch) # This is a white lie; as soon as you set a reference location, we upgrade @@ -1102,8 +1154,7 @@ def get_reference(self, a_controldir, name=None): """See BranchFormat.get_reference().""" transport = a_controldir.get_branch_transport(None, name=name) url = urlutils.strip_segment_parameters(a_controldir.user_url) - return urlutils.join( - url, transport.get_bytes('location').decode('utf-8')) + return urlutils.join(url, transport.get_bytes("location").decode("utf-8")) def _write_reference(self, a_controldir, transport, to_branch): to_url = to_branch.user_url @@ -1113,47 +1164,71 @@ def _write_reference(self, a_controldir, transport, to_branch): # does not support relative URLs. See pad.lv/1803845 -- jelmer # to_url = urlutils.relative_url( # a_controldir.user_url, to_branch.user_url) - transport.put_bytes('location', to_url.encode('utf-8')) + transport.put_bytes("location", to_url.encode("utf-8")) def set_reference(self, a_controldir, name, to_branch): """See BranchFormat.set_reference().""" transport = a_controldir.get_branch_transport(None, name=name) self._write_reference(a_controldir, transport, to_branch) - def initialize(self, a_controldir, name=None, target_branch=None, - repository=None, append_revisions_only=None): + def initialize( + self, + a_controldir, + name=None, + target_branch=None, + repository=None, + append_revisions_only=None, + ): """Create a branch of this format in a_controldir.""" if target_branch is None: # this format does not implement branch itself, thus the implicit # creation contract must see it as uninitializable raise errors.UninitializableFormat(self) - mutter('creating branch reference in %s', a_controldir.user_url) + mutter("creating branch reference in %s", a_controldir.user_url) if a_controldir._format.fixed_components: raise errors.IncompatibleFormat(self, a_controldir._format) if name is None: name = a_controldir._get_selected_branch() branch_transport = a_controldir.get_branch_transport(self, name=name) self._write_reference(a_controldir, branch_transport, target_branch) - branch_transport.put_bytes('format', self.as_string()) - branch = self.open(a_controldir, name, _found=True, - possible_transports=[target_branch.controldir.root_transport]) + branch_transport.put_bytes("format", self.as_string()) + branch = self.open( + a_controldir, + name, + _found=True, + possible_transports=[target_branch.controldir.root_transport], + ) self._run_post_branch_init_hooks(a_controldir, name, branch) return branch def _make_reference_clone_function(self, a_branch): """Create a clone() routine for a branch dynamically.""" - def clone(to_bzrdir, revision_id=None, 
repository_policy=None, name=None, - tag_selector=None): + + def clone( + to_bzrdir, + revision_id=None, + repository_policy=None, + name=None, + tag_selector=None, + ): """See Branch.clone().""" return self.initialize(to_bzrdir, target_branch=a_branch, name=name) # cannot obey revision_id limits when cloning a reference ... # FIXME RBC 20060210 either nuke revision_id for clone, or # emit some sort of warning/error to the caller ?! + return clone - def open(self, a_controldir, name=None, _found=False, location=None, - possible_transports=None, ignore_fallbacks=False, - found_repository=None): + def open( + self, + a_controldir, + name=None, + _found=False, + location=None, + possible_transports=None, + ignore_fallbacks=False, + found_repository=None, + ): """Return the branch that the branch reference in a_controldir points at. :param a_controldir: A BzrDir that contains a branch. @@ -1175,11 +1250,10 @@ def open(self, a_controldir, name=None, _found=False, location=None, raise AssertionError(f"wrong format {format!r} found for {self!r}") if location is None: location = self.get_reference(a_controldir, name) - real_bzrdir = ControlDir.open( - location, possible_transports=possible_transports) + real_bzrdir = ControlDir.open(location, possible_transports=possible_transports) result = real_bzrdir.open_branch( - ignore_fallbacks=ignore_fallbacks, - possible_transports=possible_transports) + ignore_fallbacks=ignore_fallbacks, possible_transports=possible_transports + ) # this changes the behaviour of result.clone to create a new reference # rather than a copy of the content of the branch. # I did not use a proxy object because that needs much more extensive @@ -1212,11 +1286,11 @@ def convert(self, branch): # Copying done; now update target format new_branch._transport.put_bytes( - 'format', format.as_string(), - mode=new_branch.controldir._get_file_mode()) + "format", format.as_string(), mode=new_branch.controldir._get_file_mode() + ) # Clean up old files - new_branch._transport.delete('revision-history') + new_branch._transport.delete("revision-history") with branch.lock_write(): try: branch.set_parent(None) @@ -1230,9 +1304,9 @@ class Converter6to7: def convert(self, branch): format = BzrBranchFormat7() - branch._set_config_location('stacked_on_location', '') + branch._set_config_location("stacked_on_location", "") # update target format - branch._transport.put_bytes('format', format.as_string()) + branch._transport.put_bytes("format", format.as_string()) class Converter7to8: @@ -1240,6 +1314,6 @@ class Converter7to8: def convert(self, branch): format = BzrBranchFormat8() - branch._transport.put_bytes('references', b'') + branch._transport.put_bytes("references", b"") # update target format - branch._transport.put_bytes('format', format.as_string()) + branch._transport.put_bytes("format", format.as_string()) diff --git a/breezy/bzr/btree_index.py b/breezy/bzr/btree_index.py index d31c0d97d2..57ead36135 100644 --- a/breezy/bzr/btree_index.py +++ b/breezy/bzr/btree_index.py @@ -21,11 +21,14 @@ from ..lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ import math import tempfile import zlib -""") +""", +) from .. import chunk_writer, debug, fifo_cache, lru_cache, osutils, trace, transport from . 
import index as _mod_index @@ -67,7 +70,7 @@ def finish_node(self, pad=True): self.spool.write(b"\x00" * _RESERVED_HEADER_BYTES) elif self.nodes == 1: # We got bigger than 1 node, switch to a temp file - spool = tempfile.TemporaryFile(prefix='bzr-index-row-') + spool = tempfile.TemporaryFile(prefix="bzr-index-row-") spool.write(self.spool.getvalue()) self.spool = spool skipped_bytes = 0 @@ -77,8 +80,9 @@ def finish_node(self, pad=True): self.spool.writelines(byte_lines) remainder = (self.spool.tell() + skipped_bytes) % _PAGE_SIZE if remainder != 0: - raise AssertionError("incorrect node length: %d, %d" - % (self.spool.tell(), remainder)) + raise AssertionError( + "incorrect node length: %d, %d" % (self.spool.tell(), remainder) + ) self.nodes += 1 self.writer = None @@ -130,8 +134,9 @@ def __init__(self, reference_lists=0, key_elements=1, spill_at=100000): :param spill_at: Optional parameter controlling the maximum number of nodes that BTreeBuilder will hold in memory. """ - _mod_index.GraphIndexBuilder.__init__(self, reference_lists=reference_lists, - key_elements=key_elements) + _mod_index.GraphIndexBuilder.__init__( + self, reference_lists=reference_lists, key_elements=key_elements + ) self._spill_at = spill_at self._backing_indices = [] # A map of {key: (node_refs, value)} @@ -181,14 +186,14 @@ def _spill_mem_keys_to_disk(self): size 4x. On the fifth create a single new one, etc. """ if self._combine_backing_indices: - (new_backing_file, size, - backing_pos) = self._spill_mem_keys_and_combine() + (new_backing_file, size, backing_pos) = self._spill_mem_keys_and_combine() else: new_backing_file, size = self._spill_mem_keys_without_combining() # Note: The transport here isn't strictly needed, because we will use # direct access to the new_backing._file object - new_backing = BTreeGraphIndex(transport.get_transport_from_path('.'), - '', size) + new_backing = BTreeGraphIndex( + transport.get_transport_from_path("."), "", size + ) # GC will clean up the file new_backing._file = new_backing_file if self._combine_backing_indices: @@ -214,9 +219,9 @@ def _spill_mem_keys_and_combine(self): break iterators_to_combine.append(backing.iter_all_entries()) backing_pos = pos + 1 - new_backing_file, size = \ - self._write_nodes(self._iter_smallest(iterators_to_combine), - allow_optimize=False) + new_backing_file, size = self._write_nodes( + self._iter_smallest(iterators_to_combine), allow_optimize=False + ) return new_backing_file, size, backing_pos def add_nodes(self, nodes): @@ -225,10 +230,10 @@ def add_nodes(self, nodes): :param nodes: An iterable of (key, node_refs, value) entries to add. """ if self.reference_lists: - for (key, value, node_refs) in nodes: + for key, value, node_refs in nodes: self.add_node(key, value, node_refs) else: - for (key, value) in nodes: + for key, value in nodes: self.add_node(key, value) def _iter_mem_nodes(self): @@ -256,8 +261,11 @@ def _iter_smallest(self, iterators_to_combine): last = None while True: # Decorate candidates with the value to allow 2.4's min to be used. 
- candidates = [(item[1][1], item) for item - in enumerate(current_values) if item[1] is not None] + candidates = [ + (item[1][1], item) + for item in enumerate(current_values) + if item[1] is not None + ] if not len(candidates): return selected = min(candidates) @@ -299,16 +307,19 @@ def _add_key(self, string_key, line, rows, allow_optimize=True): else: optimize_for_size = False internal_row.writer = chunk_writer.ChunkWriter( - length, 0, optimize_for_size=optimize_for_size) + length, 0, optimize_for_size=optimize_for_size + ) internal_row.writer.write(_INTERNAL_FLAG) - internal_row.writer.write(_INTERNAL_OFFSET - + b"%d\n" % rows[pos + 1].nodes) + internal_row.writer.write( + _INTERNAL_OFFSET + b"%d\n" % rows[pos + 1].nodes + ) # add a new leaf length = _PAGE_SIZE if rows[-1].nodes == 0: length -= _RESERVED_HEADER_BYTES # padded rows[-1].writer = chunk_writer.ChunkWriter( - length, optimize_for_size=self._optimize_for_size) + length, optimize_for_size=self._optimize_for_size + ) rows[-1].writer.write(_LEAF_FLAG) if rows[-1].writer.write(line): # if we failed to write, despite having an empty page to write to, @@ -334,8 +345,8 @@ def _add_key(self, string_key, line, rows, allow_optimize=True): # division point, then we need a new root: if new_row: # We need a new row - if debug.debug_flag_enabled('index'): - trace.mutter('Inserting new global row.') + if debug.debug_flag_enabled("index"): + trace.mutter("Inserting new global row.") new_row = _InternalBuilderRow() reserved_bytes = 0 rows.insert(0, new_row) @@ -343,13 +354,12 @@ def _add_key(self, string_key, line, rows, allow_optimize=True): new_row.writer = chunk_writer.ChunkWriter( _PAGE_SIZE - _RESERVED_HEADER_BYTES, reserved_bytes, - optimize_for_size=self._optimize_for_size) + optimize_for_size=self._optimize_for_size, + ) new_row.writer.write(_INTERNAL_FLAG) - new_row.writer.write(_INTERNAL_OFFSET - + b"%d\n" % (rows[1].nodes - 1)) + new_row.writer.write(_INTERNAL_OFFSET + b"%d\n" % (rows[1].nodes - 1)) new_row.writer.write(key_line) - self._add_key(string_key, line, rows, - allow_optimize=allow_optimize) + self._add_key(string_key, line, rows, allow_optimize=allow_optimize) def _write_nodes(self, node_iterator, allow_optimize=True): """Write node_iterator out as a B+Tree. 
@@ -382,29 +392,33 @@ def _write_nodes(self, node_iterator, allow_optimize=True): rows.append(_LeafBuilderRow()) key_count += 1 string_key, line = _btree_serializer._flatten_node( - node, self.reference_lists) - self._add_key(string_key, line, rows, - allow_optimize=allow_optimize) + node, self.reference_lists + ) + self._add_key(string_key, line, rows, allow_optimize=allow_optimize) for row in reversed(rows): - pad = (not isinstance(row, _LeafBuilderRow)) + pad = not isinstance(row, _LeafBuilderRow) row.finish_node(pad=pad) lines = [_BTSIGNATURE] - lines.append(b'%s%d\n' % (_OPTION_NODE_REFS, self.reference_lists)) - lines.append(b'%s%d\n' % (_OPTION_KEY_ELEMENTS, self._key_length)) - lines.append(b'%s%d\n' % (_OPTION_LEN, key_count)) + lines.append(b"%s%d\n" % (_OPTION_NODE_REFS, self.reference_lists)) + lines.append(b"%s%d\n" % (_OPTION_KEY_ELEMENTS, self._key_length)) + lines.append(b"%s%d\n" % (_OPTION_LEN, key_count)) row_lengths = [row.nodes for row in rows] - lines.append(_OPTION_ROW_LENGTHS + ','.join( - map(str, row_lengths)).encode('ascii') + b'\n') + lines.append( + _OPTION_ROW_LENGTHS + + ",".join(map(str, row_lengths)).encode("ascii") + + b"\n" + ) if row_lengths and row_lengths[-1] > 1: - result = tempfile.NamedTemporaryFile(prefix='bzr-index-') + result = tempfile.NamedTemporaryFile(prefix="bzr-index-") else: result = BytesIO() result.writelines(lines) position = sum(map(len, lines)) if position > _RESERVED_HEADER_BYTES: - raise AssertionError("Could not fit the header in the" - " reserved space: %d > %d" - % (position, _RESERVED_HEADER_BYTES)) + raise AssertionError( + "Could not fit the header in the" + " reserved space: %d > %d" % (position, _RESERVED_HEADER_BYTES) + ) # write the rows out: for row in rows: reserved = _RESERVED_HEADER_BYTES # reserved space for first node @@ -420,10 +434,11 @@ def _write_nodes(self, node_iterator, allow_optimize=True): copied_len = osutils.pumpfile(row.spool, result) if copied_len != (row.nodes - 1) * _PAGE_SIZE: if not isinstance(row, _LeafBuilderRow): - raise AssertionError("Incorrect amount of data copied" - " expected: %d, got: %d" - % ((row.nodes - 1) * _PAGE_SIZE, - copied_len)) + raise AssertionError( + "Incorrect amount of data copied" + " expected: %d, got: %d" + % ((row.nodes - 1) * _PAGE_SIZE, copied_len) + ) result.flush() size = result.tell() result.seek(0) @@ -444,9 +459,8 @@ def iter_all_entries(self): no defined order for the result iteration - it will be in the most efficient order for the index (in this case dictionary hash order). """ - if debug.debug_flag_enabled('evil'): - trace.mutter_callsite( - 3, "iter_all_entries scales with size of history.") + if debug.debug_flag_enabled("evil"): + trace.mutter_callsite(3, "iter_all_entries scales with size of history.") # Doing serial rather than ordered would be faster; but this shouldn't # be getting called routinely anyway. iterators = [self._iter_mem_nodes()] @@ -564,7 +578,8 @@ def key_count(self): return len(self._nodes) + sum( backing.key_count() for backing in self._backing_indices - if backing is not None) + if backing is not None + ) def validate(self): """In memory index's have no known corruption at the moment.""" @@ -581,13 +596,14 @@ def __lt__(self, other): class _LeafNode(dict): """A leaf node for a serialised B+Tree index.""" - __slots__ = ('min_key', 'max_key', '_keys') + __slots__ = ("min_key", "max_key", "_keys") def __init__(self, bytes, key_length, ref_list_length): """Parse bytes to create a leaf node object.""" # splitlines mangles the \r delimiters.. 
don't use it. key_list = _btree_serializer._parse_leaf_lines( - bytes, key_length, ref_list_length) + bytes, key_length, ref_list_length + ) if key_list: self.min_key = key_list[0][0] self.max_key = key_list[-1][0] @@ -610,23 +626,23 @@ def all_keys(self): class _InternalNode: """An internal node for a serialised B+Tree index.""" - __slots__ = ('keys', 'offset') + __slots__ = ("keys", "offset") def __init__(self, bytes): """Parse bytes to create an internal node object.""" # splitlines mangles the \r delimiters.. don't use it. - self.keys = self._parse_lines(bytes.split(b'\n')) + self.keys = self._parse_lines(bytes.split(b"\n")) def _parse_lines(self, lines): nodes = [] self.offset = int(lines[1][7:]) as_st = static_tuple.StaticTuple.from_sequence for line in lines[2:]: - if line == b'': + if line == b"": break # GZ 2017-05-24: Used to intern() each chunk of line as well, need # to recheck performance and perhaps adapt StaticTuple to adjust. - nodes.append(as_st(line.split(b'\0')).intern()) + nodes.append(as_st(line.split(b"\0")).intern()) return nodes @@ -637,8 +653,7 @@ class BTreeGraphIndex: memory except when very large walks are done. """ - def __init__(self, transport, name, size, unlimited_cache=False, - offset=0): + def __init__(self, transport, name, size, unlimited_cache=False, offset=0): """Create a B+Tree index object on the index name. :param transport: The transport to read data for the index from. @@ -686,11 +701,12 @@ def __eq__(self, other): isinstance(self, type(other)) and self._transport == other._transport and self._name == other._name - and self._size == other._size) + and self._size == other._size + ) def __lt__(self, other): if isinstance(other, type(self)): - return ((self._name, self._size) < (other._name, other._size)) + return (self._name, self._size) < (other._name, other._size) # Always sort existing indexes before ones that are still being built. if isinstance(other, BTreeBuilder): return True @@ -744,8 +760,10 @@ def _compute_total_pages_in_index(self): Otherwise it will be computed based on the length of the index. 
""" if self._size is None: - raise AssertionError('_compute_total_pages_in_index should not be' - ' called when self._size is None') + raise AssertionError( + "_compute_total_pages_in_index should not be" + " called when self._size is None" + ) if self._root_node is not None: # This is the number of pages as defined by the header return self._row_offsets[-1] @@ -767,19 +785,22 @@ def _expand_offsets(self, offsets): :param offsets: The offsets to be read :return: A list of offsets to download """ - if debug.debug_flag_enabled('index'): - trace.mutter('expanding: %s\toffsets: %s', self._name, offsets) + if debug.debug_flag_enabled("index"): + trace.mutter("expanding: %s\toffsets: %s", self._name, offsets) if len(offsets) >= self._recommended_pages: # Don't add more, we are already requesting more than enough - if debug.debug_flag_enabled('index'): - trace.mutter(' not expanding large request (%s >= %s)', - len(offsets), self._recommended_pages) + if debug.debug_flag_enabled("index"): + trace.mutter( + " not expanding large request (%s >= %s)", + len(offsets), + self._recommended_pages, + ) return offsets if self._size is None: # Don't try anything, because we don't know where the file ends - if debug.debug_flag_enabled('index'): - trace.mutter(' not expanding without knowing index size') + if debug.debug_flag_enabled("index"): + trace.mutter(" not expanding without knowing index size") return offsets total_pages = self._compute_total_pages_in_index() cached_offsets = self._get_offsets_to_cached_pages() @@ -788,12 +809,11 @@ def _expand_offsets(self, offsets): if total_pages - len(cached_offsets) <= self._recommended_pages: # Read whatever is left if cached_offsets: - expanded = [x for x in range(total_pages) - if x not in cached_offsets] + expanded = [x for x in range(total_pages) if x not in cached_offsets] else: expanded = list(range(total_pages)) - if debug.debug_flag_enabled('index'): - trace.mutter(' reading all unread pages: %s', expanded) + if debug.debug_flag_enabled("index"): + trace.mutter(" reading all unread pages: %s", expanded) return expanded if self._root_node is None: @@ -812,15 +832,16 @@ def _expand_offsets(self, offsets): # then it isn't worth expanding our request. Once we've read at # least 2 nodes, then we are probably doing a search, and we # start expanding our requests. 
- if debug.debug_flag_enabled('index'): - trace.mutter(' not expanding on first reads') + if debug.debug_flag_enabled("index"): + trace.mutter(" not expanding on first reads") return offsets - final_offsets = self._expand_to_neighbors(offsets, cached_offsets, - total_pages) + final_offsets = self._expand_to_neighbors( + offsets, cached_offsets, total_pages + ) final_offsets = sorted(final_offsets) - if debug.debug_flag_enabled('index'): - trace.mutter('expanded: %s', final_offsets) + if debug.debug_flag_enabled("index"): + trace.mutter("expanded: %s", final_offsets) return final_offsets def _expand_to_neighbors(self, offsets, cached_offsets, total_pages): @@ -846,16 +867,20 @@ def _expand_to_neighbors(self, offsets, cached_offsets, total_pages): if first is None: first, end = self._find_layer_first_and_end(pos) previous = pos - 1 - if (previous > 0 and - previous not in cached_offsets and - previous not in final_offsets and - previous >= first): + if ( + previous > 0 + and previous not in cached_offsets + and previous not in final_offsets + and previous >= first + ): next_tips.add(previous) after = pos + 1 - if (after < total_pages and - after not in cached_offsets and - after not in final_offsets and - after < end): + if ( + after < total_pages + and after not in cached_offsets + and after not in final_offsets + and after < end + ): next_tips.add(after) # This would keep us from going bigger than # recommended_pages by only expanding the first offsets. @@ -884,8 +909,10 @@ def external_references(self, ref_list_num): if self._root_node is None: self._get_root_node() if ref_list_num + 1 > self.node_ref_lists: - raise ValueError('No ref list %d, index has %d ref lists' - % (ref_list_num, self.node_ref_lists)) + raise ValueError( + "No ref list %d, index has %d ref lists" + % (ref_list_num, self.node_ref_lists) + ) keys = set() refs = set() for node in self.iter_all_entries(): @@ -974,9 +1001,8 @@ def iter_all_entries(self): There is no defined order for the result iteration - it will be in the most efficient order for the index. """ - if debug.debug_flag_enabled('evil'): - trace.mutter_callsite( - 3, "iter_all_entries scales with size of history.") + if debug.debug_flag_enabled("evil"): + trace.mutter_callsite(3, "iter_all_entries scales with size of history.") if not self.key_count(): return if self._row_offsets[-1] == 1: @@ -1020,6 +1046,7 @@ def _multi_bisect_right(in_keys, fixed_keys): :return: A list of (integer position, [key list]) tuples. 
""" import bisect + if not in_keys: return [] if not fixed_keys: @@ -1116,8 +1143,9 @@ def _walk_through_internal_nodes(self, keys): node = nodes[node_index] positions = self._multi_bisect_right(sub_keys, node.keys) node_offset = next_row_start + node.offset - next_nodes_and_keys.extend([(node_offset + pos, s_keys) - for pos, s_keys in positions]) + next_nodes_and_keys.extend( + [(node_offset + pos, s_keys) for pos, s_keys in positions] + ) keys_at_index = next_nodes_and_keys # We should now be at the _LeafNodes node_indexes = [idx for idx, s_keys in keys_at_index] @@ -1213,8 +1241,10 @@ def _find_ancestors(self, keys, ref_list_num, parent_map, missing_keys): missing_keys.update(keys) return set() if ref_list_num >= self.node_ref_lists: - raise ValueError('No ref list %d, index has %d ref lists' - % (ref_list_num, self.node_ref_lists)) + raise ValueError( + "No ref list %d, index has %d ref lists" + % (ref_list_num, self.node_ref_lists) + ) # The main trick we are trying to accomplish is that when we find a # key listing its parents, we expect that the parent key is also likely @@ -1305,8 +1335,9 @@ def _find_ancestors(self, keys, ref_list_num, parent_map, missing_keys): # parents_not_on_page could have been found on a different page, or be # known to be missing. So cull out everything that has already been # found. - search_keys = parents_not_on_page.difference( - parent_map).difference(missing_keys) + search_keys = parents_not_on_page.difference(parent_map).difference( + missing_keys + ) return search_keys def iter_entries_prefix(self, keys): @@ -1410,45 +1441,46 @@ def _parse_header_from_bytes(self, bytes): :return: An offset, data tuple such as readv yields, for the unparsed data. (which may be of length 0). """ - signature = bytes[0:len(self._signature())] + signature = bytes[0 : len(self._signature())] if not signature == self._signature(): raise _mod_index.BadIndexFormatSignature(self._name, BTreeGraphIndex) - lines = bytes[len(self._signature()):].splitlines() + lines = bytes[len(self._signature()) :].splitlines() options_line = lines[0] if not options_line.startswith(_OPTION_NODE_REFS): raise _mod_index.BadIndexOptions(self) try: - self.node_ref_lists = int(options_line[len(_OPTION_NODE_REFS):]) + self.node_ref_lists = int(options_line[len(_OPTION_NODE_REFS) :]) except ValueError as e: raise _mod_index.BadIndexOptions(self) from e options_line = lines[1] if not options_line.startswith(_OPTION_KEY_ELEMENTS): raise _mod_index.BadIndexOptions(self) try: - self._key_length = int(options_line[len(_OPTION_KEY_ELEMENTS):]) + self._key_length = int(options_line[len(_OPTION_KEY_ELEMENTS) :]) except ValueError as e: raise _mod_index.BadIndexOptions(self) from e options_line = lines[2] if not options_line.startswith(_OPTION_LEN): raise _mod_index.BadIndexOptions(self) try: - self._key_count = int(options_line[len(_OPTION_LEN):]) + self._key_count = int(options_line[len(_OPTION_LEN) :]) except ValueError as e: raise _mod_index.BadIndexOptions(self) from e options_line = lines[3] if not options_line.startswith(_OPTION_ROW_LENGTHS): raise _mod_index.BadIndexOptions(self) try: - self._row_lengths = [int(length) for length in - options_line[len(_OPTION_ROW_LENGTHS):].split( - b',') - if length] + self._row_lengths = [ + int(length) + for length in options_line[len(_OPTION_ROW_LENGTHS) :].split(b",") + if length + ] except ValueError as e: raise _mod_index.BadIndexOptions(self) from e self._compute_row_offsets() # calculate the bytes we have processed - header_end = (len(signature) + sum(map(len, 
lines[0:4])) + 4) + header_end = len(signature) + sum(map(len, lines[0:4])) + 4 return header_end, bytes[header_end:] def _read_nodes(self, nodes): @@ -1469,7 +1501,7 @@ def _read_nodes(self, nodes): ranges = [] base_offset = self._base_offset for index in nodes: - offset = (index * _PAGE_SIZE) + offset = index * _PAGE_SIZE size = _PAGE_SIZE if index == 0: # Root node - special case @@ -1482,22 +1514,26 @@ def _read_nodes(self, nodes): num_bytes = len(bytes) self._size = num_bytes - base_offset # the whole thing should be parsed out of 'bytes' - ranges = [(start, min(_PAGE_SIZE, num_bytes - start)) - for start in range( - base_offset, num_bytes, _PAGE_SIZE)] + ranges = [ + (start, min(_PAGE_SIZE, num_bytes - start)) + for start in range(base_offset, num_bytes, _PAGE_SIZE) + ] break else: if offset > self._size: - raise AssertionError('tried to read past the end' - f' of the file {offset} > {self._size}') + raise AssertionError( + "tried to read past the end" + f" of the file {offset} > {self._size}" + ) size = min(size, self._size - offset) ranges.append((base_offset + offset, size)) if not ranges: return elif bytes is not None: # already have the whole file - data_ranges = [(start, bytes[start:start + size]) - for start, size in ranges] + data_ranges = [ + (start, bytes[start : start + size]) for start, size in ranges + ] elif self._file is None: data_ranges = self._transport.readv(self._name, ranges) else: @@ -1514,8 +1550,7 @@ def _read_nodes(self, nodes): continue bytes = zlib.decompress(data) if bytes.startswith(_LEAF_FLAG): - node = self._leaf_factory(bytes, self._key_length, - self.node_ref_lists) + node = self._leaf_factory(bytes, self._key_length, self.node_ref_lists) elif bytes.startswith(_INTERNAL_FLAG): node = _InternalNode(bytes) else: @@ -1544,6 +1579,7 @@ def validate(self): try: from . 
import _btree_serializer_pyx as _btree_serializer # type: ignore + _gcchk_factory = _btree_serializer._parse_into_chk # type: ignore except ImportError as e: osutils.failed_to_load_extension(e) diff --git a/breezy/bzr/bundle/apply_bundle.py b/breezy/bzr/bundle/apply_bundle.py index 3de1be7d6e..f6ea1557dc 100644 --- a/breezy/bzr/bundle/apply_bundle.py +++ b/breezy/bzr/bundle/apply_bundle.py @@ -27,7 +27,7 @@ def install_bundle(repository, bundle_reader): - custom_install = getattr(bundle_reader, 'install', None) + custom_install = getattr(bundle_reader, "install", None) if custom_install is not None: return custom_install(repository) with repository.lock_write(), ui.ui_factory.nested_progress_bar() as pb: @@ -36,20 +36,19 @@ def install_bundle(repository, bundle_reader): pb.update(gettext("Install revisions"), i, len(real_revisions)) if repository.has_revision(revision.revision_id): continue - cset_tree = bundle_reader.revision_tree(repository, - revision.revision_id) + cset_tree = bundle_reader.revision_tree(repository, revision.revision_id) install_revision(repository, revision, cset_tree) -def merge_bundle(reader, tree, check_clean, merge_type, - reprocess, show_base, change_reporter=None): +def merge_bundle( + reader, tree, check_clean, merge_type, reprocess, show_base, change_reporter=None +): """Merge a revision bundle into the current tree.""" with ui.ui_factory.nested_progress_bar() as pb: pp = ProgressPhase("Merge phase", 6, pb) pp.next_phase() install_bundle(tree.branch.repository, reader) - merger = Merger(tree.branch, this_tree=tree, - change_reporter=change_reporter) + merger = Merger(tree.branch, this_tree=tree, change_reporter=change_reporter) merger.pp = pp merger.pp.next_phase() if check_clean and tree.has_changes(): diff --git a/breezy/bzr/bundle/bundle_data.py b/breezy/bzr/bundle/bundle_data.py index 119f385b04..0827cc44ca 100644 --- a/breezy/bzr/bundle/bundle_data.py +++ b/breezy/bzr/bundle/bundle_data.py @@ -59,37 +59,37 @@ def as_revision(self): properties = {} if self.properties: for property in self.properties: - key_end = property.find(': ') + key_end = property.find(": ") if key_end == -1: - if not property.endswith(':'): + if not property.endswith(":"): raise ValueError(property) key = str(property[:-1]) - value = '' + value = "" else: key = str(property[:key_end]) - value = property[key_end + 2:] + value = property[key_end + 2 :] properties[key] = value - return Revision(revision_id=self.revision_id, - committer=self.committer, - timestamp=float(self.timestamp), - timezone=int(self.timezone), - inventory_sha1=self.inventory_sha1, - message='\n'.join(self.message), - parent_ids=self.parent_ids or [], - properties=properties) + return Revision( + revision_id=self.revision_id, + committer=self.committer, + timestamp=float(self.timestamp), + timezone=int(self.timezone), + inventory_sha1=self.inventory_sha1, + message="\n".join(self.message), + parent_ids=self.parent_ids or [], + properties=properties, + ) @staticmethod def from_revision(revision): revision_info = RevisionInfo(revision.revision_id) - date = osutils.format_highres_date(revision.timestamp, - revision.timezone) + date = osutils.format_highres_date(revision.timestamp, revision.timezone) revision_info.date = date revision_info.timezone = revision.timezone revision_info.timestamp = revision.timestamp - revision_info.message = revision.message.split('\n') - revision_info.properties = [': '.join(p) for p in - revision.properties.items()] + revision_info.message = revision.message.split("\n") + 
revision_info.properties = [": ".join(p) for p in revision.properties.items()] return revision_info @@ -135,8 +135,7 @@ def complete_info(self): for rev in self.revisions: if rev.timestamp is None: if rev.date is not None: - rev.timestamp, rev.timezone = \ - osutils.unpack_highres_date(rev.date) + rev.timestamp, rev.timezone = osutils.unpack_highres_date(rev.date) else: rev.timestamp = self.timestamp rev.timezone = self.timezone @@ -167,7 +166,7 @@ def _get_target(self): return self.revisions[0].revision_id return None - target = property(_get_target, doc='The target revision id') + target = property(_get_target, doc="The target revision id") def get_revision(self, revision_id): for r in self.real_revisions: @@ -190,8 +189,7 @@ def revision_tree(self, repository, revision_id, base=None): self._validate_references_from_repository(repository) self.get_revision_info(revision_id) inventory_revision_id = revision_id - bundle_tree = BundleTree(repository.revision_tree(base), - inventory_revision_id) + bundle_tree = BundleTree(repository.revision_tree(base), inventory_revision_id) self._update_tree(bundle_tree, revision_id) inv = bundle_tree.inventory @@ -211,16 +209,18 @@ def _validate_references_from_repository(self, repository): def add_sha(d, revision_id, sha1): if revision_id is None: if sha1 is not None: - raise BzrError('A Null revision should always' - 'have a null sha1 hash') + raise BzrError( + "A Null revision should always" "have a null sha1 hash" + ) return if revision_id in d: # This really should have been validated as part # of _validate_revisions but lets do it again if sha1 != d[revision_id]: - raise BzrError('** Revision {!r} referenced with 2 different' - ' sha hashes {} != {}'.format(revision_id, - sha1, d[revision_id])) + raise BzrError( + f"** Revision {revision_id!r} referenced with 2 different" + f" sha hashes {sha1} != {d[revision_id]}" + ) else: d[revision_id] = sha1 @@ -231,20 +231,20 @@ def add_sha(d, revision_id, sha1): checked[rev_info.revision_id] = True add_sha(rev_to_sha, rev_info.revision_id, rev_info.sha1) - for (_rev, rev_info) in zip(self.real_revisions, self.revisions): + for _rev, rev_info in zip(self.real_revisions, self.revisions): add_sha(inv_to_sha, rev_info.revision_id, rev_info.inventory_sha1) count = 0 missing = {} for revision_id, sha1 in rev_to_sha.items(): if repository.has_revision(revision_id): - StrictTestament.from_revision(repository, - revision_id) - local_sha1 = self._testament_sha1_from_revision(repository, - revision_id) + StrictTestament.from_revision(repository, revision_id) + local_sha1 = self._testament_sha1_from_revision(repository, revision_id) if sha1 != local_sha1: - raise BzrError(f'sha1 mismatch. For revision id {{{revision_id}}}' - f'local: {local_sha1}, bundle: {sha1}') + raise BzrError( + f"sha1 mismatch. For revision id {{{revision_id}}}" + f"local: {local_sha1}, bundle: {sha1}" + ) else: count += 1 elif revision_id not in checked: @@ -252,9 +252,11 @@ def add_sha(d, revision_id, sha1): if len(missing) > 0: # I don't know if this is an error yet - warning('Not all revision hashes could be validated.' - ' Unable validate %d hashes' % len(missing)) - mutter('Verified %d sha hashes for the bundle.' % count) + warning( + "Not all revision hashes could be validated." + " Unable validate %d hashes" % len(missing) + ) + mutter("Verified %d sha hashes for the bundle." 
% count) self._validated_revisions_against_repo = True def _validate_inventory(self, inv, revision_id): @@ -269,10 +271,12 @@ def _validate_inventory(self, inv, revision_id): if rev.revision_id != revision_id: raise AssertionError() if sha1 != rev.inventory_sha1: - with open(',,bogus-inv', 'wb') as f: + with open(",,bogus-inv", "wb") as f: f.writelines(cs) - warning(f'Inventory sha hash mismatch for revision {revision_id}. {sha1}' - f' != {rev.inventory_sha1}') + warning( + f"Inventory sha hash mismatch for revision {revision_id}. {sha1}" + f" != {rev.inventory_sha1}" + ) def _testament(self, revision, tree): raise NotImplementedError(self._testament) @@ -293,8 +297,7 @@ def _validate_revision(self, tree, revision_id): if sha1 != rev_info.sha1: raise TestamentMismatch(rev.revision_id, rev_info.sha1, sha1) if rev.revision_id in rev_to_sha1: - raise BzrError('Revision {%s} given twice in the list' - % (rev.revision_id)) + raise BzrError("Revision {%s} given twice in the list" % (rev.revision_id)) rev_to_sha1[rev.revision_id] = sha1 def _update_tree(self, bundle_tree, revision_id): @@ -319,36 +322,37 @@ def extra_info(info, new_path): encoding = None for info_item in info: try: - name, value = info_item.split(':', 1) + name, value = info_item.split(":", 1) except ValueError as e: - raise ValueError(f'Value {info_item!r} has no colon') from e - if name == 'last-changed': + raise ValueError(f"Value {info_item!r} has no colon") from e + if name == "last-changed": last_changed = value - elif name == 'executable': - val = (value == 'yes') + elif name == "executable": + val = value == "yes" bundle_tree.note_executable(new_path, val) - elif name == 'target': + elif name == "target": bundle_tree.note_target(new_path, value) - elif name == 'encoding': + elif name == "encoding": encoding = value return last_changed, encoding def do_patch(path, lines, encoding): - if encoding == 'base64': - patch = base64.b64decode(b''.join(lines)) + if encoding == "base64": + patch = base64.b64decode(b"".join(lines)) elif encoding is None: - patch = b''.join(lines) + patch = b"".join(lines) else: raise ValueError(encoding) bundle_tree.note_patch(path, patch) def renamed(kind, extra, lines): - info = extra.split(' // ') + info = extra.split(" // ") if len(info) < 2: - raise BzrError('renamed action lines need both a from and to' - ': %r' % extra) + raise BzrError( + "renamed action lines need both a from and to" ": %r" % extra + ) old_path = info[0] - if info[1].startswith('=> '): + if info[1].startswith("=> "): new_path = info[1][3:] else: new_path = info[1] @@ -360,27 +364,31 @@ def renamed(kind, extra, lines): do_patch(new_path, lines, encoding) def removed(kind, extra, lines): - info = extra.split(' // ') + info = extra.split(" // ") if len(info) > 1: # TODO: in the future we might allow file ids to be # given for removed entries - raise BzrError('removed action lines should only have the path' - ': %r' % extra) + raise BzrError( + "removed action lines should only have the path" ": %r" % extra + ) path = info[0] bundle_tree.note_deletion(path) def added(kind, extra, lines): - info = extra.split(' // ') + info = extra.split(" // ") if len(info) <= 1: - raise BzrError('add action lines require the path and file id' - ': %r' % extra) + raise BzrError( + "add action lines require the path and file id" ": %r" % extra + ) elif len(info) > 5: - raise BzrError('add action lines have fewer than 5 entries.' - ': %r' % extra) + raise BzrError( + "add action lines have fewer than 5 entries." 
": %r" % extra + ) path = info[0] - if not info[1].startswith('file-id:'): - raise BzrError('The file-id should follow the path for an add' - ': %r' % extra) + if not info[1].startswith("file-id:"): + raise BzrError( + "The file-id should follow the path for an add" ": %r" % extra + ) # This will be Unicode because of how the stream is read. Turn it # back into a utf8 file_id file_id = cache_utf8.encode(info[1][8:]) @@ -390,15 +398,16 @@ def added(kind, extra, lines): bundle_tree.note_executable(path, False) last_changed, encoding = extra_info(info[2:], path) get_rev_id(last_changed, path, kind) - if kind == 'directory': + if kind == "directory": return do_patch(path, lines, encoding) def modified(kind, extra, lines): - info = extra.split(' // ') + info = extra.split(" // ") if len(info) < 1: - raise BzrError('modified action lines have at least' - 'the path in them: %r' % extra) + raise BzrError( + "modified action lines have at least" "the path in them: %r" % extra + ) path = info[0] last_modified, encoding = extra_info(info[1:], path) @@ -407,30 +416,33 @@ def modified(kind, extra, lines): do_patch(path, lines, encoding) valid_actions = { - 'renamed': renamed, - 'removed': removed, - 'added': added, - 'modified': modified + "renamed": renamed, + "removed": removed, + "added": added, + "modified": modified, } - for action_line, lines in \ - self.get_revision_info(revision_id).tree_actions: - first = action_line.find(' ') + for action_line, lines in self.get_revision_info(revision_id).tree_actions: + first = action_line.find(" ") if first == -1: - raise BzrError(f'Bogus action line (no opening space): {action_line!r}') - second = action_line.find(' ', first + 1) + raise BzrError(f"Bogus action line (no opening space): {action_line!r}") + second = action_line.find(" ", first + 1) if second == -1: - raise BzrError('Bogus action line' - ' (missing second space): %r' % action_line) + raise BzrError( + "Bogus action line" " (missing second space): %r" % action_line + ) action = action_line[:first] - kind = action_line[first + 1:second] - if kind not in ('file', 'directory', 'symlink'): - raise BzrError('Bogus action line' - f' (invalid object kind {kind!r}): {action_line!r}') - extra = action_line[second + 1:] + kind = action_line[first + 1 : second] + if kind not in ("file", "directory", "symlink"): + raise BzrError( + "Bogus action line" + f" (invalid object kind {kind!r}): {action_line!r}" + ) + extra = action_line[second + 1 :] if action not in valid_actions: - raise BzrError('Bogus action line' - ' (unrecognized action): %r' % action_line) + raise BzrError( + "Bogus action line" " (unrecognized action): %r" % action_line + ) valid_actions[action](kind, extra, lines) def install_revisions(self, target_repo, stream_input=True): @@ -447,11 +459,10 @@ def get_merge_request(self, target_repo): Returns suggested base, suggested target, and patch verification status """ - return None, self.target, 'inapplicable' + return None, self.target, "inapplicable" class BundleTree(InventoryTree): - def __init__(self, base_tree, revision_id): self.base_tree = base_tree self._renamed = {} # Mapping from old_path => new_path @@ -480,17 +491,18 @@ def note_rename(self, old_path, new_path): self._renamed[new_path] = old_path self._renamed_r[old_path] = new_path - def note_id(self, new_id, new_path, kind='file'): + def note_id(self, new_id, new_path, kind="file"): """Files that don't exist in base need a new id.""" self._new_id[new_path] = new_id self._new_id_r[new_id] = new_path self._kinds[new_path] = kind 
def note_last_changed(self, file_id, revision_id): - if (file_id in self._last_changed - and self._last_changed[file_id] != revision_id): - raise BzrError(f'Mismatched last-changed revision for file_id {{{file_id}}}' - f': {self._last_changed[file_id]} != {revision_id}') + if file_id in self._last_changed and self._last_changed[file_id] != revision_id: + raise BzrError( + f"Mismatched last-changed revision for file_id {{{file_id}}}" + f": {self._last_changed[file_id]} != {revision_id}" + ) self._last_changed[file_id] = revision_id def note_patch(self, new_path, patch): @@ -510,7 +522,7 @@ def note_executable(self, new_path, executable): def old_path(self, new_path): """Get the old_path (path in the base_tree) for the file at new_path.""" - if new_path[:1] in ('\\', '/'): + if new_path[:1] in ("\\", "/"): raise ValueError(new_path) old_path = self._renamed.get(new_path) if old_path is not None: @@ -519,7 +531,7 @@ def old_path(self, new_path): # dirname is not '' doesn't work, because # dirname may be a unicode entry, and is # requires the objects to be identical - if dirname != '': + if dirname != "": old_dir = self.old_path(dirname) if old_dir is None: old_path = None @@ -537,7 +549,7 @@ def new_path(self, old_path): """Get the new_path (path in the target_tree) for the file at old_path in the base tree. """ - if old_path[:1] in ('\\', '/'): + if old_path[:1] in ("\\", "/"): raise ValueError(old_path) new_path = self._renamed_r.get(old_path) if new_path is not None: @@ -545,7 +557,7 @@ def new_path(self, old_path): if new_path in self._renamed: return None dirname, basename = os.path.split(old_path) - if dirname != '': + if dirname != "": new_dir = self.new_path(dirname) if new_dir is None: new_path = None @@ -571,7 +583,7 @@ def path2id(self, path): return None return self.base_tree.path2id(old_path) - def id2path(self, file_id, recurse='down'): + def id2path(self, file_id, recurse="down"): """Return the new path in the target tree of the file with id file_id.""" path = self._new_id_r.get(file_id) if path is not None: @@ -601,16 +613,14 @@ def get_file(self, path): patch_original = self.base_tree.get_file(old_path) file_patch = self.patches.get(path) if file_patch is None: - if (patch_original is None and - self.kind(path) == 'directory'): + if patch_original is None and self.kind(path) == "directory": return BytesIO() if patch_original is None: raise AssertionError(f"None: {file_id}") return patch_original - if file_patch.startswith(b'\\'): - raise ValueError( - f'Malformed patch for {file_id}, {file_patch!r}') + if file_patch.startswith(b"\\"): + raise ValueError(f"Malformed patch for {file_id}, {file_patch!r}") return patched_file(file_patch, patch_original) def get_symlink_target(self, path): @@ -671,10 +681,11 @@ def _get_inventory(self): This need to be called before ever accessing self.inventory """ from os.path import basename, dirname + inv = Inventory(None, self.revision_id) def add_entry(path, file_id): - if path == '': + if path == "": parent_id = None else: parent_path = dirname(path) @@ -684,21 +695,20 @@ def add_entry(path, file_id): revision_id = self.get_last_changed(path) name = basename(path) - if kind == 'directory': + if kind == "directory": ie = InventoryDirectory(file_id, name, parent_id) - elif kind == 'file': + elif kind == "file": ie = InventoryFile(file_id, name, parent_id) ie.executable = self.is_executable(path) - elif kind == 'symlink': + elif kind == "symlink": ie = InventoryLink(file_id, name, parent_id) ie.symlink_target = self.get_symlink_target(path) 
ie.revision = revision_id - if kind == 'file': + if kind == "file": ie.text_size, ie.text_sha1 = self.get_size_and_sha1(path) if ie.text_size is None: - raise BzrError( - f'Got a text_size of None for file_id {file_id!r}') + raise BzrError(f"Got a text_size of None for file_id {file_id!r}") inv.add(ie) sorted_entries = self.sorted_path_id() @@ -737,7 +747,7 @@ def list_files(self, include_root=False, from_dir=None, recursive=True): # skip the root for compatibility with the current apis. next(entries) for path, entry in entries: - yield path, 'V', entry.kind, entry + yield path, "V", entry.kind, entry def sorted_path_id(self): paths = [] @@ -745,7 +755,7 @@ def sorted_path_id(self): paths.append(result) for id in self.base_tree.all_file_ids(): try: - path = self.id2path(id, recurse='none') + path = self.id2path(id, recurse="none") except NoSuchId: continue paths.append((path, id)) @@ -757,9 +767,9 @@ def patched_file(file_patch, original): """Produce a file-like object with the patched version of a text.""" from ...osutils import IterableFile from ...patches import iter_patched + if file_patch == b"": return IterableFile(()) # string.splitlines(True) also splits on '\r', but the iter_patched code # only expects to iterate over '\n' style lines - return IterableFile(iter_patched(original, - BytesIO(file_patch).readlines())) + return IterableFile(iter_patched(original, BytesIO(file_patch).readlines())) diff --git a/breezy/bzr/bundle/commands.py b/breezy/bzr/bundle/commands.py index 2f436843d2..3bc62f8bf9 100644 --- a/breezy/bzr/bundle/commands.py +++ b/breezy/bzr/bundle/commands.py @@ -31,9 +31,9 @@ class cmd_bundle_info(Command): __doc__ = """Output interesting stats about a bundle""" hidden = True - takes_args = ['location'] - takes_options = ['verbose'] - encoding_type = 'exact' + takes_args = ["location"] + takes_options = ["verbose"] + encoding_type = "exact" def run(self, location, verbose=False): from breezy import merge_directive, osutils @@ -41,6 +41,7 @@ def run(self, location, verbose=False): from breezy.i18n import gettext from ...mergeable import read_mergeable_from_url + term_encoding = osutils.get_terminal_encoding() bundle_info = read_mergeable_from_url(location) if isinstance(bundle_info, merge_directive.BaseMergeDirective): @@ -48,48 +49,64 @@ def run(self, location, verbose=False): bundle_info = read_bundle(bundle_file) else: if verbose: - raise errors.CommandError(gettext( - '--verbose requires a merge directive')) - reader_method = getattr(bundle_info, 'get_bundle_reader', None) + raise errors.CommandError( + gettext("--verbose requires a merge directive") + ) + reader_method = getattr(bundle_info, "get_bundle_reader", None) if reader_method is None: - raise errors.CommandError( - gettext('Bundle format not supported')) + raise errors.CommandError(gettext("Bundle format not supported")) by_kind = {} file_ids = set() - for bytes, parents, repo_kind, revision_id, file_id\ - in reader_method().iter_records(): + for ( + bytes, + parents, + repo_kind, + revision_id, + file_id, + ) in reader_method().iter_records(): by_kind.setdefault(repo_kind, []).append( - (bytes, parents, repo_kind, revision_id, file_id)) + (bytes, parents, repo_kind, revision_id, file_id) + ) if file_id is not None: file_ids.add(file_id) - self.outf.write(gettext('Records\n')) + self.outf.write(gettext("Records\n")) for kind, records in sorted(by_kind.items()): - multiparent = sum(1 for b, m, k, r, f in records if - len(m.get('parents', [])) > 1) - self.outf.write(gettext('{0}: {1} ({2} 
multiparent)\n').format( - kind, len(records), multiparent)) - self.outf.write(gettext('unique files: %d\n') % len(file_ids)) - self.outf.write('\n') + multiparent = sum( + 1 for b, m, k, r, f in records if len(m.get("parents", [])) > 1 + ) + self.outf.write( + gettext("{0}: {1} ({2} multiparent)\n").format( + kind, len(records), multiparent + ) + ) + self.outf.write(gettext("unique files: %d\n") % len(file_ids)) + self.outf.write("\n") nicks = set() committers = set() for revision in bundle_info.real_revisions: - if 'branch-nick' in revision.properties: - nicks.add(revision.properties['branch-nick']) + if "branch-nick" in revision.properties: + nicks.add(revision.properties["branch-nick"]) committers.add(revision.committer) - self.outf.write(gettext('Revisions\n')) - self.outf.write((gettext('nicks: %s\n') - % ', '.join(sorted(nicks))).encode(term_encoding, 'replace')) - self.outf.write(gettext('committers: \n%s\n') % - '\n'.join(sorted(committers)).encode(term_encoding, 'replace')) + self.outf.write(gettext("Revisions\n")) + self.outf.write( + (gettext("nicks: %s\n") % ", ".join(sorted(nicks))).encode( + term_encoding, "replace" + ) + ) + self.outf.write( + gettext("committers: \n%s\n") + % "\n".join(sorted(committers)).encode(term_encoding, "replace") + ) if verbose: - self.outf.write('\n') + self.outf.write("\n") bundle_file.seek(0) bundle_file.readline() bundle_file.readline() import bz2 + content = bz2.decompress(bundle_file.read()) self.outf.write(gettext("Decoded contents\n")) self.outf.write(content) - self.outf.write('\n') + self.outf.write("\n") diff --git a/breezy/bzr/bundle/serializer/__init__.py b/breezy/bzr/bundle/serializer/__init__.py index fafd6fcebc..0b505f03e1 100644 --- a/breezy/bzr/bundle/serializer/__init__.py +++ b/breezy/bzr/bundle/serializer/__init__.py @@ -26,19 +26,21 @@ # For backwards-compatibility # New bundles should try to use this header format -BUNDLE_HEADER = b'# Bazaar revision bundle v' +BUNDLE_HEADER = b"# Bazaar revision bundle v" BUNDLE_HEADER_RE = re.compile( - br'^# Bazaar revision bundle v(?P<version>\d+[\w.]*)(?P<lineending>\r?)\n$') + rb"^# Bazaar revision bundle v(?P<version>\d+[\w.]*)(?P<lineending>\r?)\n$" +) CHANGESET_OLD_HEADER_RE = re.compile( - br'^# Bazaar-NG changeset v(?P<version>\d+[\w.]*)(?P<lineending>\r?)\n$') + rb"^# Bazaar-NG changeset v(?P<version>\d+[\w.]*)(?P<lineending>\r?)\n$" +) def _get_bundle_header(version): - return b''.join([BUNDLE_HEADER, version.encode('ascii'), b'\n']) + return b"".join([BUNDLE_HEADER, version.encode("ascii"), b"\n"]) def _get_filename(f): - return getattr(f, 'name', '') + return getattr(f, "name", "") def read_bundle(f): @@ -51,31 +53,28 @@ def read_bundle(f): for line in f: m = BUNDLE_HEADER_RE.match(line) if m: - if m.group('lineending') != b'': + if m.group("lineending") != b"": raise errors.UnsupportedEOLMarker() - version = m.group('version') + version = m.group("version") break elif line.startswith(BUNDLE_HEADER): - raise errors.MalformedHeader( - 'Extra characters after version number') + raise errors.MalformedHeader("Extra characters after version number") m = CHANGESET_OLD_HEADER_RE.match(line) if m: - version = m.group('version') - raise errors.BundleNotSupported(version, - 'old format bundles not supported') + version = m.group("version") + raise errors.BundleNotSupported(version, "old format bundles not supported") if version is None: - raise errors.NotABundle('Did not find an opening header') + raise errors.NotABundle("Did not find an opening header") - return get_serializer(version.decode('ascii')).read(f) + return get_serializer(version.decode("ascii")).read(f) def
get_serializer(version): try: serializer = serializer_registry.get(version) except KeyError as e: - raise errors.BundleNotSupported(version, - 'unknown bundle format') from e + raise errors.BundleNotSupported(version, "unknown bundle format") from e return serializer(version) @@ -91,8 +90,7 @@ def write(source, revision_ids, f, version=None, forced_bases=None): if forced_bases is None: forced_bases = {} with source.lock_read(): - return get_serializer(version).write(source, revision_ids, - forced_bases, f) + return get_serializer(version).write(source, revision_ids, forced_bases, f) def write_bundle(repository, revision_id, base_revision_id, out, format=None): @@ -106,8 +104,9 @@ def write_bundle(repository, revision_id, base_revision_id, out, format=None): :return: List of revision ids written """ with repository.lock_read(): - return get_serializer(format).write_bundle(repository, revision_id, - base_revision_id, out) + return get_serializer(format).write_bundle( + repository, revision_id, base_revision_id, out + ) class BundleSerializer: @@ -142,19 +141,17 @@ def write_bundle(self, repository, target, base, fileobj): def binary_diff(old_filename, old_lines, new_filename, new_lines, to_file): temp = BytesIO() - internal_diff(old_filename, old_lines, new_filename, new_lines, temp, - allow_binary=True) + internal_diff( + old_filename, old_lines, new_filename, new_lines, temp, allow_binary=True + ) temp.seek(0) base64.encode(temp, to_file) - to_file.write(b'\n') + to_file.write(b"\n") serializer_registry = registry.Registry[str, BundleSerializer, None]() -serializer_registry.register_lazy( - '0.8', __name__ + '.v08', 'BundleSerializerV08') -serializer_registry.register_lazy( - '0.9', __name__ + '.v09', 'BundleSerializerV09') -serializer_registry.register_lazy('4', __name__ + '.v4', - 'BundleSerializerV4') -serializer_registry.default_key = '4' +serializer_registry.register_lazy("0.8", __name__ + ".v08", "BundleSerializerV08") +serializer_registry.register_lazy("0.9", __name__ + ".v09", "BundleSerializerV09") +serializer_registry.register_lazy("4", __name__ + ".v4", "BundleSerializerV4") +serializer_registry.default_key = "4" diff --git a/breezy/bzr/bundle/serializer/v08.py b/breezy/bzr/bundle/serializer/v08.py index fd2940e5df..07e7555b53 100644 --- a/breezy/bzr/bundle/serializer/v08.py +++ b/breezy/bzr/bundle/serializer/v08.py @@ -27,7 +27,7 @@ from ..bundle_data import BundleInfo, RevisionInfo from . 
import BundleSerializer, _get_bundle_header, binary_diff -bool_text = {True: 'yes', False: 'no'} +bool_text = {True: "yes", False: "no"} class Action: @@ -46,7 +46,7 @@ def __init__(self, name, parameters=None, properties=None): def add_utf8_property(self, name, value): """Add a property whose value is currently utf8 to the action.""" - self.properties.append((name, value.decode('utf8'))) + self.properties.append((name, value.decode("utf8"))) def add_property(self, name, value): """Add a property to the action.""" @@ -58,26 +58,25 @@ def add_bool_property(self, name, value): def write(self, to_file): """Write action as to a file.""" - p_texts = [' '.join([self.name] + self.parameters)] + p_texts = [" ".join([self.name] + self.parameters)] for prop in self.properties: if len(prop) == 1: p_texts.append(prop[0]) else: - p_texts.append('{}:{}'.format(*prop)) - text = ['=== '] - text.append(' // '.join(p_texts)) - text_line = ''.join(text).encode('utf-8') + p_texts.append("{}:{}".format(*prop)) + text = ["=== "] + text.append(" // ".join(p_texts)) + text_line = "".join(text).encode("utf-8") available = 79 while len(text_line) > available: to_file.write(text_line[:available]) text_line = text_line[available:] - to_file.write(b'\n... ') - available = 79 - len(b'... ') - to_file.write(text_line + b'\n') + to_file.write(b"\n... ") + available = 79 - len(b"... ") + to_file.write(text_line + b"\n") class BundleSerializerV08(BundleSerializer): - def read(self, f): """Read the rest of the bundles from the supplied file. @@ -88,7 +87,7 @@ def read(self, f): def check_compatible(self): if self.source.supports_rich_root(): - raise errors.IncompatibleBundleFormat('0.8', repr(self.source)) + raise errors.IncompatibleBundleFormat("0.8", repr(self.source)) def write(self, source, revision_ids, forced_bases, f): """Write the bundless to the supplied files. @@ -114,10 +113,8 @@ def write_bundle(self, repository, revision_id, base_revision_id, out): if base_revision_id is NULL_REVISION: base_revision_id = None graph = repository.get_graph() - revision_ids = graph.find_unique_ancestors(revision_id, - [base_revision_id]) - revision_ids = list(repository.get_graph().iter_topo_order( - revision_ids)) + revision_ids = graph.find_unique_ancestors(revision_id, [base_revision_id]) + revision_ids = list(repository.get_graph().iter_topo_order(revision_ids)) revision_ids.reverse() self.write(repository, revision_ids, forced_bases, out) return revision_ids @@ -125,8 +122,8 @@ def write_bundle(self, repository, revision_id, base_revision_id, out): def _write_main_header(self): """Write the header for the changes.""" f = self.to_file - f.write(_get_bundle_header('0.8')) - f.write(b'#\n') + f.write(_get_bundle_header("0.8")) + f.write(b"#\n") def _write(self, key, value, indent=1, trailing_space_when_empty=False): r"""Write out meta information, with proper indenting, etc. @@ -138,32 +135,32 @@ def _write(self, key, value, indent=1, trailing_space_when_empty=False): write an extra space. 
""" if indent < 1: - raise ValueError('indentation must be greater than 0') + raise ValueError("indentation must be greater than 0") f = self.to_file - f.write(b'#' + (b' ' * indent)) - f.write(key.encode('utf-8')) + f.write(b"#" + (b" " * indent)) + f.write(key.encode("utf-8")) if not value: - if trailing_space_when_empty and value == '': - f.write(b': \n') + if trailing_space_when_empty and value == "": + f.write(b": \n") else: - f.write(b':\n') + f.write(b":\n") elif isinstance(value, bytes): - f.write(b': ') + f.write(b": ") f.write(value) - f.write(b'\n') + f.write(b"\n") elif isinstance(value, str): - f.write(b': ') - f.write(value.encode('utf-8')) - f.write(b'\n') + f.write(b": ") + f.write(value.encode("utf-8")) + f.write(b"\n") else: - f.write(b':\n') + f.write(b":\n") for entry in value: - f.write(b'#' + (b' ' * (indent + 2))) + f.write(b"#" + (b" " * (indent + 2))) if isinstance(entry, bytes): f.write(entry) else: - f.write(entry.encode('utf-8')) - f.write(b'\n') + f.write(entry.encode("utf-8")) + f.write(b"\n") def _write_revisions(self, pb): """Write the information for all of the revisions.""" @@ -195,59 +192,59 @@ def _write_revisions(self, pb): base_tree = last_rev_tree else: base_tree = self.source.revision_tree(base_id) - force_binary = (i != 0) - self._write_revision(rev, rev_tree, base_id, base_tree, - explicit_base, force_binary) + force_binary = i != 0 + self._write_revision( + rev, rev_tree, base_id, base_tree, explicit_base, force_binary + ) last_rev_id = base_id last_rev_tree = base_tree def _testament_sha1(self, revision_id): - return StrictTestament.from_revision(self.source, - revision_id).as_sha1() + return StrictTestament.from_revision(self.source, revision_id).as_sha1() - def _write_revision(self, rev, rev_tree, base_rev, base_tree, - explicit_base, force_binary): + def _write_revision( + self, rev, rev_tree, base_rev, base_tree, explicit_base, force_binary + ): """Write out the information for a revision.""" + def w(key, value): self._write(key, value, indent=1) - w('message', rev.message.split('\n')) - w('committer', rev.committer) - w('date', format_highres_date(rev.timestamp, rev.timezone)) - self.to_file.write(b'\n') + w("message", rev.message.split("\n")) + w("committer", rev.committer) + w("date", format_highres_date(rev.timestamp, rev.timezone)) + self.to_file.write(b"\n") self._write_delta(rev_tree, base_tree, rev.revision_id, force_binary) - w('revision id', rev.revision_id) - w('sha1', self._testament_sha1(rev.revision_id)) - w('inventory sha1', rev.inventory_sha1) + w("revision id", rev.revision_id) + w("sha1", self._testament_sha1(rev.revision_id)) + w("inventory sha1", rev.inventory_sha1) if rev.parent_ids: - w('parent ids', rev.parent_ids) + w("parent ids", rev.parent_ids) if explicit_base: - w('base id', base_rev) + w("base id", base_rev) if rev.properties: - self._write('properties', None, indent=1) + self._write("properties", None, indent=1) for name, value in sorted(rev.properties.items()): - self._write(name, value, indent=3, - trailing_space_when_empty=True) + self._write(name, value, indent=3, trailing_space_when_empty=True) # Add an extra blank space at the end - self.to_file.write(b'\n') + self.to_file.write(b"\n") def _write_action(self, name, parameters, properties=None): if properties is None: properties = [] - p_texts = ['{}:{}'.format(*v) for v in properties] - self.to_file.write(b'=== ') - self.to_file.write(' '.join([name] + parameters).encode('utf-8')) - self.to_file.write(' // '.join(p_texts).encode('utf-8')) - 
self.to_file.write(b'\n') - - def _write_delta(self, new_tree, old_tree, default_revision_id, - force_binary): + p_texts = ["{}:{}".format(*v) for v in properties] + self.to_file.write(b"=== ") + self.to_file.write(" ".join([name] + parameters).encode("utf-8")) + self.to_file.write(" // ".join(p_texts).encode("utf-8")) + self.to_file.write(b"\n") + + def _write_delta(self, new_tree, old_tree, default_revision_id, force_binary): """Write out the changes between the trees.""" - DEVNULL = '/dev/null' + DEVNULL = "/dev/null" def do_diff(file_id, old_path, new_path, action, force_binary): def tree_lines(tree, path, require_text=False): @@ -266,54 +263,78 @@ def tree_lines(tree, path, require_text=False): old_lines = tree_lines(old_tree, old_path, require_text=True) new_lines = tree_lines(new_tree, new_path, require_text=True) action.write(self.to_file) - internal_diff(old_path, old_lines, new_path, new_lines, - self.to_file) + internal_diff(old_path, old_lines, new_path, new_lines, self.to_file) except errors.BinaryFile: old_lines = tree_lines(old_tree, old_path, require_text=False) new_lines = tree_lines(new_tree, new_path, require_text=False) - action.add_property('encoding', 'base64') + action.add_property("encoding", "base64") action.write(self.to_file) - binary_diff(old_path, old_lines, new_path, new_lines, - self.to_file) + binary_diff(old_path, old_lines, new_path, new_lines, self.to_file) - def finish_action(action, file_id, kind, meta_modified, text_modified, - old_path, new_path): + def finish_action( + action, file_id, kind, meta_modified, text_modified, old_path, new_path + ): entry = new_tree.root_inventory.get_entry(file_id) if entry.revision != default_revision_id: - action.add_utf8_property('last-changed', entry.revision) + action.add_utf8_property("last-changed", entry.revision) if meta_modified: - action.add_bool_property('executable', entry.executable) + action.add_bool_property("executable", entry.executable) if text_modified and kind == "symlink": - action.add_property('target', entry.symlink_target) + action.add_property("target", entry.symlink_target) if text_modified and kind == "file": do_diff(file_id, old_path, new_path, action, force_binary) else: action.write(self.to_file) - delta = new_tree.changes_from(old_tree, want_unchanged=True, - include_root=True) + delta = new_tree.changes_from(old_tree, want_unchanged=True, include_root=True) for change in delta.removed: - action = Action('removed', [change.kind[0], change.path[0]]).write(self.to_file) + action = Action("removed", [change.kind[0], change.path[0]]).write( + self.to_file + ) # TODO(jelmer): Treat copied specially here? 
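The "=== " action lines written by `Action.write` above are capped at 79 bytes, with overflow continued on "... "-prefixed lines that `_read_one_patch` further down re-joins. A minimal round-trip sketch of that wrapping, using made-up helper names and in-memory I/O (illustrative only, not part of the patch):

    from io import BytesIO

    def write_action_line(to_file, text_line, width=79):
        # Hypothetical stand-in for the wrapping in Action.write: emit at
        # most `width` bytes per line, then continue on "... " lines.
        available = width
        while len(text_line) > available:
            to_file.write(text_line[:available])
            text_line = text_line[available:]
            to_file.write(b"\n... ")
            available = width - len(b"... ")
        to_file.write(text_line + b"\n")

    def read_action_line(lines):
        # Re-join a "=== " line with any "... " continuations that follow.
        action = lines[0][4:-1]
        for line in lines[1:]:
            if not line.startswith(b"... "):
                break
            action += line[4:-1]
        return action

    f = BytesIO()
    write_action_line(f, b"=== added file " + b"x" * 100)
    assert read_action_line(f.getvalue().splitlines(keepends=True)) == b"added file " + b"x" * 100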
for change in delta.added + delta.copied: action = Action( - 'added', [change.kind[1], change.path[1]], - [('file-id', change.file_id.decode('utf-8'))]) - meta_modified = (change.kind[1] == 'file' and - change.executable[1]) - finish_action(action, change.file_id, change.kind[1], meta_modified, change.changed_content, - DEVNULL, change.path[1]) + "added", + [change.kind[1], change.path[1]], + [("file-id", change.file_id.decode("utf-8"))], + ) + meta_modified = change.kind[1] == "file" and change.executable[1] + finish_action( + action, + change.file_id, + change.kind[1], + meta_modified, + change.changed_content, + DEVNULL, + change.path[1], + ) for change in delta.renamed: - action = Action('renamed', [change.kind[1], change.path[0]], [(change.path[1],)]) - finish_action(action, change.file_id, change.kind[1], change.meta_modified(), change.changed_content, - change.path[0], change.path[1]) + action = Action( + "renamed", [change.kind[1], change.path[0]], [(change.path[1],)] + ) + finish_action( + action, + change.file_id, + change.kind[1], + change.meta_modified(), + change.changed_content, + change.path[0], + change.path[1], + ) for change in delta.modified: - action = Action('modified', [change.kind[1], change.path[1]]) - finish_action(action, change.file_id, change.kind[1], change.meta_modified(), change.changed_content, - change.path[0], change.path[1]) + action = Action("modified", [change.kind[1], change.path[1]]) + finish_action( + action, + change.file_id, + change.kind[1], + change.meta_modified(), + change.changed_content, + change.path[0], + change.path[1], + ) for change in delta.unchanged: new_rev = new_tree.get_file_revision(change.path[1]) @@ -321,8 +342,8 @@ def finish_action(action, file_id, kind, meta_modified, text_modified, continue old_rev = old_tree.get_file_revision(change.path[0]) if new_rev != old_rev: - action = Action('modified', [change.kind[1], change.path[1]]) - action.add_utf8_property('last-changed', new_rev) + action = Action("modified", [change.kind[1], change.path[1]]) + action.add_utf8_property("last-changed", new_rev) action.write(self.to_file) @@ -377,11 +398,11 @@ def _next(self): last = self._next_line self._next_line = line if last is not None: - #mutter('yielding line: %r' % last) + # mutter('yielding line: %r' % last) yield last last = self._next_line self._next_line = None - #mutter('yielding line: %r' % last) + # mutter('yielding line: %r' % last) yield last def _read_revision_header(self): @@ -390,9 +411,9 @@ def _read_revision_header(self): for line in self._next(): # The bzr header is terminated with a blank line # which does not start with '#' - if line is None or line == b'\n': + if line is None or line == b"\n": break - if not line.startswith(b'#'): + if not line.startswith(b"#"): continue found_something = True self._handle_next(line) @@ -403,51 +424,53 @@ def _read_revision_header(self): def _read_next_entry(self, line, indent=1): """Read in a key-value pair.""" - if not line.startswith(b'#'): - raise errors.MalformedHeader('Bzr header did not start with #') - line = line[1:-1].decode('utf-8') # Remove the '#' and '\n' - if line[:indent] == ' ' * indent: + if not line.startswith(b"#"): + raise errors.MalformedHeader("Bzr header did not start with #") + line = line[1:-1].decode("utf-8") # Remove the '#' and '\n' + if line[:indent] == " " * indent: line = line[indent:] if not line: return None, None # Ignore blank lines - loc = line.find(': ') + loc = line.find(": ") if loc != -1: key = line[:loc] - value = line[loc + 2:] + value = 
line[loc + 2 :] if not value: value = self._read_many(indent=indent + 2) - elif line[-1:] == ':': + elif line[-1:] == ":": key = line[:-1] value = self._read_many(indent=indent + 2) else: - raise errors.MalformedHeader('While looking for key: value pairs,' - ' did not find the colon %r' % (line)) + raise errors.MalformedHeader( + "While looking for key: value pairs," + " did not find the colon %r" % (line) + ) - key = key.replace(' ', '_') - #mutter('found %s: %s' % (key, value)) + key = key.replace(" ", "_") + # mutter('found %s: %s' % (key, value)) return key, value def _handle_next(self, line): if line is None: return key, value = self._read_next_entry(line, indent=1) - mutter(f'_handle_next {key!r} => {value!r}') + mutter(f"_handle_next {key!r} => {value!r}") if key is None: return revision_info = self.info.revisions[-1] if key in revision_info.__dict__: if getattr(revision_info, key) is None: - if key in ('file_id', 'revision_id', 'base_id'): - value = value.encode('utf8') - elif key in ('parent_ids'): - value = [v.encode('utf8') for v in value] - elif key in ('testament_sha1', 'inventory_sha1', 'sha1'): - value = value.encode('ascii') + if key in ("file_id", "revision_id", "base_id"): + value = value.encode("utf8") + elif key in ("parent_ids"): + value = [v.encode("utf8") for v in value] + elif key in ("testament_sha1", "inventory_sha1", "sha1"): + value = value.encode("ascii") setattr(revision_info, key, value) else: - raise errors.MalformedHeader(f'Duplicated Key: {key}') + raise errors.MalformedHeader(f"Duplicated Key: {key}") else: # What do we do with a key we don't recognize raise errors.MalformedHeader(f'Unknown Key: "{key}"') @@ -460,13 +483,13 @@ def _read_many(self, indent): does not start properly indented. """ values = [] - start = b'#' + (b' ' * indent) + start = b"#" + (b" " * indent) if self._next_line is None or not self._next_line.startswith(start): return values for line in self._next(): - values.append(line[len(start):-1].decode('utf-8')) + values.append(line[len(start) : -1].decode("utf-8")) if self._next_line is None or not self._next_line.startswith(start): break return values @@ -477,32 +500,33 @@ def _read_one_patch(self): :return: action, lines, do_continue """ - #mutter('_read_one_patch: %r' % self._next_line) + # mutter('_read_one_patch: %r' % self._next_line) # Peek and see if there are no patches - if self._next_line is None or self._next_line.startswith(b'#'): + if self._next_line is None or self._next_line.startswith(b"#"): return None, [], False first = True lines = [] for line in self._next(): if first: - if not line.startswith(b'==='): - raise errors.MalformedPatches('The first line of all patches' - ' should be a bzr meta line "==="' - ': %r' % line) - action = line[4:-1].decode('utf-8') - elif line.startswith(b'... '): - action += line[len(b'... '):-1].decode('utf-8') - - if (self._next_line is not None and - self._next_line.startswith(b'===')): + if not line.startswith(b"==="): + raise errors.MalformedPatches( + "The first line of all patches" + ' should be a bzr meta line "==="' + ": %r" % line + ) + action = line[4:-1].decode("utf-8") + elif line.startswith(b"... "): + action += line[len(b"... ") : -1].decode("utf-8") + + if self._next_line is not None and self._next_line.startswith(b"==="): return action, lines, True - elif self._next_line is None or self._next_line.startswith(b'#'): + elif self._next_line is None or self._next_line.startswith(b"#"): return action, lines, False if first: first = False - elif not line.startswith(b'... 
'): + elif not line.startswith(b"... "): lines.append(line) return action, lines, False @@ -528,16 +552,15 @@ def _read_footer(self): self._handle_next(line) if self._next_line is None: break - if not self._next_line.startswith(b'#'): + if not self._next_line.startswith(b"#"): # Consume the trailing \n and stop processing next(self._next()) break class BundleInfo08(BundleInfo): - def _update_tree(self, bundle_tree, revision_id): - bundle_tree.note_last_changed('', revision_id) + bundle_tree.note_last_changed("", revision_id) BundleInfo._update_tree(self, bundle_tree, revision_id) def _testament_sha1_from_revision(self, repository, revision_id): diff --git a/breezy/bzr/bundle/serializer/v09.py b/breezy/bzr/bundle/serializer/v09.py index e8fdf47959..775f010575 100644 --- a/breezy/bzr/bundle/serializer/v09.py +++ b/breezy/bzr/bundle/serializer/v09.py @@ -37,11 +37,10 @@ def check_compatible(self): def _write_main_header(self): """Write the header for the changes.""" f = self.to_file - f.write(_get_bundle_header('0.9') + b'#\n') + f.write(_get_bundle_header("0.9") + b"#\n") def _testament_sha1(self, revision_id): - return StrictTestament3.from_revision(self.source, - revision_id).as_sha1() + return StrictTestament3.from_revision(self.source, revision_id).as_sha1() def read(self, f): """Read the rest of the bundles from the supplied file. diff --git a/breezy/bzr/bundle/serializer/v4.py b/breezy/bzr/bundle/serializer/v4.py index 67ec99cf6c..65d404e7fe 100644 --- a/breezy/bzr/bundle/serializer/v4.py +++ b/breezy/bzr/bundle/serializer/v4.py @@ -34,8 +34,7 @@ class _MPDiffInventoryGenerator(_mod_versionedfile._MPDiffGenerator): """Generate Inventory diffs serialized inventories.""" def __init__(self, repo, inventory_keys): - super().__init__(repo.inventories, - inventory_keys) + super().__init__(repo.inventories, inventory_keys) self.repo = repo self.sha1s = {} @@ -88,8 +87,8 @@ def _write_encoded(self, bytes): def begin(self): """Start writing the bundle.""" - self._fileobj.write(bundle_serializer._get_bundle_header('4')) - self._fileobj.write(b'#\n') + self._fileobj.write(bundle_serializer._get_bundle_header("4")) + self._fileobj.write(b"#\n") self._container.begin() def end(self): @@ -97,8 +96,9 @@ def end(self): self._container.end() self._fileobj.write(self._compressor.flush()) - def add_multiparent_record(self, mp_bytes, sha1, parents, repo_kind, - revision_id, file_id): + def add_multiparent_record( + self, mp_bytes, sha1, parents, repo_kind, revision_id, file_id + ): """Add a record for a multi-parent diff. :mp_bytes: A multi-parent diff, as a bytestring @@ -109,9 +109,7 @@ def add_multiparent_record(self, mp_bytes, sha1, parents, repo_kind, :revision_id: The revision id of the mpdiff being added. :file_id: The file-id of the file, or None for inventories. """ - metadata = {b'parents': parents, - b'storage_kind': b'mpdiff', - b'sha1': sha1} + metadata = {b"parents": parents, b"storage_kind": b"mpdiff", b"sha1": sha1} self._add_record(mp_bytes, metadata, repo_kind, revision_id, file_id) def add_fulltext_record(self, bytes, parents, repo_kind, revision_id): @@ -123,8 +121,13 @@ def add_fulltext_record(self, bytes, parents, repo_kind, revision_id): 'signature' :revision_id: The revision id of the fulltext being added. 
""" - self._add_record(bytes, {b'parents': parents, - b'storage_kind': b'fulltext'}, repo_kind, revision_id, None) + self._add_record( + bytes, + {b"parents": parents, b"storage_kind": b"fulltext"}, + repo_kind, + revision_id, + None, + ) def add_info_record(self, kwargs): """Add an info record to the bundle. @@ -132,29 +135,31 @@ def add_info_record(self, kwargs): Any parameters may be supplied, except 'self' and 'storage_kind'. Values must be lists, strings, integers, dicts, or a combination. """ - kwargs[b'storage_kind'] = b'header' - self._add_record(None, kwargs, 'info', None, None) + kwargs[b"storage_kind"] = b"header" + self._add_record(None, kwargs, "info", None, None) @staticmethod def encode_name(content_kind, revision_id, file_id=None): """Encode semantic ids as a container name.""" - if content_kind not in ('revision', 'file', 'inventory', 'signature', - 'info'): + if content_kind not in ("revision", "file", "inventory", "signature", "info"): raise ValueError(content_kind) - if content_kind == 'file': + if content_kind == "file": if file_id is None: raise AssertionError() else: if file_id is not None: raise AssertionError() - if content_kind == 'info': + if content_kind == "info": if revision_id is not None: raise AssertionError() elif revision_id is None: raise AssertionError() - names = [n.replace(b'/', b'//') for n in - (content_kind.encode('ascii'), revision_id, file_id) if n is not None] - return b'/'.join(names) + names = [ + n.replace(b"/", b"//") + for n in (content_kind.encode("ascii"), revision_id, file_id) + if n is not None + ] + return b"/".join(names) def _add_record(self, bytes, metadata, repo_kind, revision_id, file_id): """Add a bundle record to the container. @@ -165,8 +170,10 @@ def _add_record(self, bytes, metadata, repo_kind, revision_id, file_id): """ name = self.encode_name(repo_kind, revision_id, file_id) encoded_metadata = bencode.bencode(metadata) - self._container.add_bytes_record([encoded_metadata], len(encoded_metadata), [(name, )]) - if metadata[b'storage_kind'] != b'header': + self._container.add_bytes_record( + [encoded_metadata], len(encoded_metadata), [(name,)] + ) + if metadata[b"storage_kind"] != b"header": self._container.add_bytes_record([bytes], len(bytes), []) @@ -187,7 +194,7 @@ def __init__(self, fileobj, stream_input=True): once is (currently) faster. """ line = fileobj.readline() - if line != '\n': + if line != "\n": fileobj.readline() self.patch_lines = [] if stream_input: @@ -212,13 +219,13 @@ def decode_name(name): :retval: content_kind, revision_id, file_id """ - segments = re.split(b'(//?)', name) - names = [b''] + segments = re.split(b"(//?)", name) + names = [b""] for segment in segments: - if segment == b'//': - names[-1] += b'/' - elif segment == b'/': - names.append(b'') + if segment == b"//": + names[-1] += b"/" + elif segment == b"/": + names.append(b"") else: names[-1] += segment content_kind = names[0] @@ -228,7 +235,7 @@ def decode_name(name): revision_id = names[1] if len(names) > 2: file_id = names[2] - return content_kind.decode('ascii'), revision_id, file_id + return content_kind.decode("ascii"), revision_id, file_id def iter_records(self): """Iterate through bundle records. 
@@ -239,9 +246,9 @@ def iter_records(self): iterator = pack.iter_records_from_file(self._container_file) for names, bytes in iterator: if len(names) != 1: - raise errors.BadBundle(f'Record has {len(names)} names instead of 1') + raise errors.BadBundle(f"Record has {len(names)} names instead of 1") metadata = bencode.bdecode(bytes) - if metadata[b'storage_kind'] == b'header': + if metadata[b"storage_kind"] == b"header": bytes = None else: _unused, bytes = next(iterator) @@ -271,9 +278,11 @@ def read(self, file): @staticmethod def get_source_serializer(info): """Retrieve the serializer for a given info object.""" - format_name = info[b'serializer'].decode('ascii') + format_name = info[b"serializer"].decode("ascii") inventory_serializer = serializer.inventory_format_registry.get(format_name) - revision_serializer = serializer.revision_format_registry.get({'7': '5', '6': '5'}.get(format_name, format_name)) + revision_serializer = serializer.revision_format_registry.get( + {"7": "5", "6": "5"}.get(format_name, format_name) + ) return (revision_serializer, inventory_serializer) @@ -298,8 +307,14 @@ def __init__(self, base, target, repository, fileobj, revision_ids=None): def do_write(self): """Write all data to the bundle.""" - trace.note(ngettext('Bundling %d revision.', 'Bundling %d revisions.', - len(self.revision_ids)), len(self.revision_ids)) + trace.note( + ngettext( + "Bundling %d revision.", + "Bundling %d revisions.", + len(self.revision_ids), + ), + len(self.revision_ids), + ) with self.repository.lock_read(): self.bundle.begin() self.write_info() @@ -311,26 +326,31 @@ def do_write(self): def write_info(self): """Write format info.""" serializer_format = self.repository.get_serializer_format() - supports_rich_root = {True: 1, False: 0}[ - self.repository.supports_rich_root()] - self.bundle.add_info_record({b'serializer': serializer_format, - b'supports_rich_root': supports_rich_root}) + supports_rich_root = {True: 1, False: 0}[self.repository.supports_rich_root()] + self.bundle.add_info_record( + { + b"serializer": serializer_format, + b"supports_rich_root": supports_rich_root, + } + ) def write_files(self): """Write bundle records for all revisions of all files.""" text_keys = [] altered_fileids = self.repository.fileids_altered_by_revision_ids( - self.revision_ids) + self.revision_ids + ) for file_id, revision_ids in altered_fileids.items(): for revision_id in revision_ids: text_keys.append((file_id, revision_id)) - self._add_mp_records_keys('file', self.repository.texts, text_keys) + self._add_mp_records_keys("file", self.repository.texts, text_keys) def write_revisions(self): """Write bundle records for all revisions and signatures.""" inv_vf = self.repository.inventories - topological_order = [key[-1] for key in multiparent.topo_iter_keys( - inv_vf, self.revision_keys)] + topological_order = [ + key[-1] for key in multiparent.topo_iter_keys(inv_vf, self.revision_keys) + ] revision_order = topological_order if self.target is not None and self.target in self.revision_ids: # Make sure the target revision is always the last entry @@ -342,8 +362,9 @@ def write_revisions(self): # inventories.make_mpdiffs() contains all the data about the tree # shape. Formats without support_altered_by_hack require # chk_bytes/etc, so we use a different code path. 
- self._add_mp_records_keys('inventory', inv_vf, - [(revid,) for revid in topological_order]) + self._add_mp_records_keys( + "inventory", inv_vf, [(revid,) for revid in topological_order] + ) else: # Inventories should always be added in pure-topological order, so # that we can apply the mpdiff for the child to the parent texts. @@ -360,27 +381,33 @@ def _add_inventory_mpdiffs_from_serializer(self, revision_order): the other side. """ inventory_key_order = [(r,) for r in revision_order] - generator = _MPDiffInventoryGenerator(self.repository, - inventory_key_order) + generator = _MPDiffInventoryGenerator(self.repository, inventory_key_order) for revision_id, parent_ids, sha1, diff in generator.iter_diffs(): - text = b''.join(diff.to_patch()) - self.bundle.add_multiparent_record(text, sha1, parent_ids, - 'inventory', revision_id, None) + text = b"".join(diff.to_patch()) + self.bundle.add_multiparent_record( + text, sha1, parent_ids, "inventory", revision_id, None + ) def _add_revision_texts(self, revision_order): parent_map = self.repository.get_parent_map(revision_order) - revision_to_bytes = self.repository._revision_serializer.write_revision_to_string + revision_to_bytes = ( + self.repository._revision_serializer.write_revision_to_string + ) revisions = self.repository.get_revisions(revision_order) for revision in revisions: revision_id = revision.revision_id parents = parent_map.get(revision_id, None) revision_text = revision_to_bytes(revision) - self.bundle.add_fulltext_record(revision_text, parents, - 'revision', revision_id) + self.bundle.add_fulltext_record( + revision_text, parents, "revision", revision_id + ) try: self.bundle.add_fulltext_record( - self.repository.get_signature_text( - revision_id), parents, 'signature', revision_id) + self.repository.get_signature_text(revision_id), + parents, + "signature", + revision_id, + ) except errors.NoSuchRevision: pass @@ -405,17 +432,21 @@ def _add_mp_records_keys(self, repo_kind, vf, keys): mpdiffs = vf.make_mpdiffs(ordered_keys) sha1s = vf.get_sha1s(ordered_keys) parent_map = vf.get_parent_map(ordered_keys) - for mpdiff, item_key, in zip(mpdiffs, ordered_keys): + for ( + mpdiff, + item_key, + ) in zip(mpdiffs, ordered_keys): sha1 = sha1s[item_key] parents = [key[-1] for key in parent_map[item_key]] - text = b''.join(mpdiff.to_patch()) + text = b"".join(mpdiff.to_patch()) # Infer file id records as appropriate. if len(item_key) == 2: file_id = item_key[0] else: file_id = None - self.bundle.add_multiparent_record(text, sha1, parents, repo_kind, - item_key[-1], file_id) + self.bundle.add_multiparent_record( + text, sha1, parents, repo_kind, item_key[-1], file_id + ) class BundleInfoV4: @@ -439,8 +470,9 @@ def install_revisions(self, repository, stream_input=True): (currently) faster. """ with repository.lock_write(): - ri = RevisionInstaller(self.get_bundle_reader(stream_input), - self._serializer, repository) + ri = RevisionInstaller( + self.get_bundle_reader(stream_input), self._serializer, repository + ) return ri.install() def get_merge_request(self, target_repo): @@ -448,7 +480,7 @@ def get_merge_request(self, target_repo): Returns suggested base, suggested target, and patch verification status """ - return None, self.target, 'inapplicable' + return None, self.target, "inapplicable" def get_bundle_reader(self, stream_input=True): """Return a new BundleReader for the associated bundle. 
@@ -464,15 +496,23 @@ def _get_real_revisions(self): if self.__real_revisions is None: self.__real_revisions = [] bundle_reader = self.get_bundle_reader() - for bytes, metadata, repo_kind, _revision_id, _file_id in \ - bundle_reader.iter_records(): - if repo_kind == 'info': - revision_serializer, inventory_serializer =\ - self._serializer.get_source_serializer(metadata) - if repo_kind == 'revision': + for ( + bytes, + metadata, + repo_kind, + _revision_id, + _file_id, + ) in bundle_reader.iter_records(): + if repo_kind == "info": + ( + revision_serializer, + inventory_serializer, + ) = self._serializer.get_source_serializer(metadata) + if repo_kind == "revision": rev = revision_serializer.read_revision_from_string(bytes) self.__real_revisions.append(rev) return self.__real_revisions + real_revisions = property(_get_real_revisions) def _get_revisions(self): @@ -480,7 +520,8 @@ def _get_revisions(self): self.__revisions = [] for revision in self.real_revisions: self.__revisions.append( - bundle_data.RevisionInfo.from_revision(revision)) + bundle_data.RevisionInfo.from_revision(revision) + ) return self.__revisions revisions = property(_get_revisions) @@ -513,45 +554,49 @@ def _install_in_write_group(self): pending_file_records = [] pending_inventory_records = [] target_revision = None - for bytes, metadata, repo_kind, revision_id, file_id in\ - self._container.iter_records(): - if repo_kind == 'info': + for ( + bytes, + metadata, + repo_kind, + revision_id, + file_id, + ) in self._container.iter_records(): + if repo_kind == "info": if self._info is not None: raise AssertionError() self._handle_info(metadata) - if (pending_file_records and - (repo_kind, file_id) != ('file', current_file)): + if pending_file_records and (repo_kind, file_id) != ("file", current_file): # Flush the data for a single file - prevents memory # spiking due to buffering all files in memory. 
- self._install_mp_records_keys(self._repository.texts, - pending_file_records) + self._install_mp_records_keys( + self._repository.texts, pending_file_records + ) current_file = None del pending_file_records[:] - if len(pending_inventory_records) > 0 and repo_kind != 'inventory': + if len(pending_inventory_records) > 0 and repo_kind != "inventory": self._install_inventory_records(pending_inventory_records) pending_inventory_records = [] - if repo_kind == 'inventory': - pending_inventory_records.append( - ((revision_id,), metadata, bytes)) - if repo_kind == 'revision': + if repo_kind == "inventory": + pending_inventory_records.append(((revision_id,), metadata, bytes)) + if repo_kind == "revision": target_revision = revision_id self._install_revision(revision_id, metadata, bytes) - if repo_kind == 'signature': + if repo_kind == "signature": self._install_signature(revision_id, metadata, bytes) - if repo_kind == 'file': + if repo_kind == "file": current_file = file_id - pending_file_records.append( - ((file_id, revision_id), metadata, bytes)) - self._install_mp_records_keys( - self._repository.texts, pending_file_records) + pending_file_records.append(((file_id, revision_id), metadata, bytes)) + self._install_mp_records_keys(self._repository.texts, pending_file_records) return target_revision def _handle_info(self, info): """Extract data from an info record.""" self._info = info - (self._source_revision_serializer, self._source_inventory_serializer) = self._serializer.get_source_serializer(info) - if (info[b'supports_rich_root'] == 0 and - self._repository.supports_rich_root()): + ( + self._source_revision_serializer, + self._source_inventory_serializer, + ) = self._serializer.get_source_serializer(info) + if info[b"supports_rich_root"] == 0 and self._repository.supports_rich_root(): self.update_root = True else: self.update_root = False @@ -560,8 +605,11 @@ def _install_mp_records(self, versionedfile, records): if len(records) == 0: return d_func = multiparent.MultiParent.from_patch - vf_records = [(r, m['parents'], m['sha1'], d_func(t)) for r, m, t in - records if r not in versionedfile] + vf_records = [ + (r, m["parents"], m["sha1"], d_func(t)) + for r, m, t in records + if r not in versionedfile + ] versionedfile.add_mpdiffs(vf_records) def _install_mp_records_keys(self, versionedfile, records): @@ -577,12 +625,13 @@ def _install_mp_records_keys(self, versionedfile, records): prefix = key[:1] else: prefix = () - parents = [prefix + (parent,) for parent in meta[b'parents']] - vf_records.append((key, parents, meta[b'sha1'], d_func(text))) + parents = [prefix + (parent,) for parent in meta[b"parents"]] + vf_records.append((key, parents, meta[b"sha1"], d_func(text))) versionedfile.add_mpdiffs(vf_records) - def _get_parent_inventory_texts(self, inventory_text_cache, - inventory_cache, parent_ids): + def _get_parent_inventory_texts( + self, inventory_text_cache, inventory_cache, parent_ids + ): cached_parent_texts = {} remaining_parent_ids = [] for parent_id in parent_ids: @@ -600,7 +649,8 @@ def _get_parent_inventory_texts(self, inventory_text_cache, # installed yet.) 
parent_keys = [(r,) for r in remaining_parent_ids] present_parent_map = self._repository.inventories.get_parent_map( - parent_keys) + parent_keys + ) present_parent_ids = [] ghosts = set() for p_id in remaining_parent_ids: @@ -609,23 +659,26 @@ def _get_parent_inventory_texts(self, inventory_text_cache, else: ghosts.add(p_id) to_lines = self._source_inventory_serializer.write_inventory_to_chunks - for parent_inv in self._repository.iter_inventories( - present_parent_ids): - p_text = b''.join(to_lines(parent_inv)) + for parent_inv in self._repository.iter_inventories(present_parent_ids): + p_text = b"".join(to_lines(parent_inv)) inventory_cache[parent_inv.revision_id] = parent_inv cached_parent_texts[parent_inv.revision_id] = p_text inventory_text_cache[parent_inv.revision_id] = p_text - parent_texts = [cached_parent_texts[parent_id] - for parent_id in parent_ids - if parent_id not in ghosts] + parent_texts = [ + cached_parent_texts[parent_id] + for parent_id in parent_ids + if parent_id not in ghosts + ] return parent_texts def _install_inventory_records(self, records): - if (self._info[b'serializer'] == self._repository._inventory_serializer.format_num - and self._repository._inventory_serializer.support_altered_by_hack): - return self._install_mp_records_keys(self._repository.inventories, - records) + if ( + self._info[b"serializer"] + == self._repository._inventory_serializer.format_num + and self._repository._inventory_serializer.support_altered_by_hack + ): + return self._install_mp_records_keys(self._repository.inventories, records) # Use a 10MB text cache, since these are string xml inventories. Note # that 10MB is fairly small for large projects (a single inventory can # be >5MB). Another possibility is to cache 10-20 inventory texts @@ -638,27 +691,31 @@ def _install_inventory_records(self, records): with ui.ui_factory.nested_progress_bar() as pb: num_records = len(records) for idx, (key, metadata, bytes) in enumerate(records): - pb.update('installing inventory', idx, num_records) + pb.update("installing inventory", idx, num_records) revision_id = key[-1] - parent_ids = metadata[b'parents'] + parent_ids = metadata[b"parents"] # Note: This assumes the local ghosts are identical to the # ghosts in the source, as the Bundle serialization # format doesn't record ghosts. - p_texts = self._get_parent_inventory_texts(inventory_text_cache, - inventory_cache, - parent_ids) + p_texts = self._get_parent_inventory_texts( + inventory_text_cache, inventory_cache, parent_ids + ) # Why does to_lines() take strings as the source, it seems that # it would have to cast to a list of lines, which we get back # as lines and then cast back to a string. - target_lines = multiparent.MultiParent.from_patch(bytes - ).to_lines(p_texts) + target_lines = multiparent.MultiParent.from_patch(bytes).to_lines( + p_texts + ) sha1 = osutils.sha_strings(target_lines) - if sha1 != metadata[b'sha1']: + if sha1 != metadata[b"sha1"]: raise errors.BadBundle("Can't convert to target format") # Add this to the cache so we don't have to extract it again. 
- inventory_text_cache[revision_id] = b''.join(target_lines) - target_inv = self._source_inventory_serializer.read_inventory_from_lines( - target_lines) + inventory_text_cache[revision_id] = b"".join(target_lines) + target_inv = ( + self._source_inventory_serializer.read_inventory_from_lines( + target_lines + ) + ) del target_lines self._handle_root(target_inv, parent_ids) parent_inv = None @@ -666,12 +723,14 @@ def _install_inventory_records(self, records): parent_inv = inventory_cache.get(parent_ids[0], None) try: if parent_inv is None: - self._repository.add_inventory(revision_id, target_inv, - parent_ids) + self._repository.add_inventory( + revision_id, target_inv, parent_ids + ) else: delta = target_inv._make_delta(parent_inv) - self._repository.add_inventory_by_delta(parent_ids[0], - delta, revision_id, parent_ids) + self._repository.add_inventory_by_delta( + parent_ids[0], delta, revision_id, parent_ids + ) except serializer.UnsupportedInventoryKind as e: raise errors.IncompatibleRevision(repr(self._repository)) from e inventory_cache[revision_id] = target_inv @@ -680,8 +739,7 @@ def _handle_root(self, target_inv, parent_ids): revision_id = target_inv.revision_id if self.update_root: text_key = (target_inv.root.file_id, revision_id) - parent_keys = [(target_inv.root.file_id, parent) for - parent in parent_ids] + parent_keys = [(target_inv.root.file_id, parent) for parent in parent_ids] self._repository.texts.add_lines(text_key, parent_keys, []) elif not self._repository.supports_rich_root(): if target_inv.root.revision != revision_id: diff --git a/breezy/bzr/bzrdir.py b/breezy/bzr/bzrdir.py index 083aa8d8cc..84c1093fab 100644 --- a/breezy/bzr/bzrdir.py +++ b/breezy/bzr/bzrdir.py @@ -31,7 +31,9 @@ from ..lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ from breezy import ( branch as _mod_branch, repository, @@ -52,7 +54,8 @@ workingtree_4, ) from breezy.i18n import gettext -""") +""", +) from .. import config, controldir, errors, lockdir, osutils from .. import transport as _mod_transport @@ -61,17 +64,17 @@ class MissingFeature(errors.BzrError): - - _fmt = ("Missing feature %(feature)s not provided by this " - "version of Breezy or any plugin.") + _fmt = ( + "Missing feature %(feature)s not provided by this " + "version of Breezy or any plugin." + ) def __init__(self, feature): self.feature = feature class FeatureAlreadyRegistered(errors.BzrError): - - _fmt = 'The feature %(feature)s has already been registered.' + _fmt = "The feature %(feature)s has already been registered." def __init__(self, feature): self.feature = feature @@ -120,16 +123,23 @@ def check_conversion_target(self, target_format): # fetched compatibly with the target. target_repo_format = target_format.repository_format try: - self.open_repository()._format.check_conversion_target( - target_repo_format) + self.open_repository()._format.check_conversion_target(target_repo_format) except errors.NoRepositoryPresent: # No repo, no problem. pass - def clone_on_transport(self, transport, revision_id=None, - force_new_repo=False, preserve_stacking=False, stacked_on=None, - create_prefix=False, use_existing_dir=True, no_tree=False, - tag_selector=None): + def clone_on_transport( + self, + transport, + revision_id=None, + force_new_repo=False, + preserve_stacking=False, + stacked_on=None, + create_prefix=False, + use_existing_dir=True, + no_tree=False, + tag_selector=None, + ): """Clone this bzrdir and its contents to transport verbatim. 
:param transport: The transport for the location to produce the clone @@ -151,7 +161,7 @@ def clone_on_transport(self, transport, revision_id=None, # We may want to create a repo/branch/tree, if we do so what format # would we want for each: - require_stacking = (stacked_on is not None) + require_stacking = stacked_on is not None format = self.cloning_metadir(require_stacking) # Figure out what objects we want: @@ -161,7 +171,7 @@ def clone_on_transport(self, transport, revision_id=None, local_repo = None local_branches = self.get_branches() try: - local_active_branch = local_branches[''] + local_active_branch = local_branches[""] except KeyError: pass else: @@ -171,16 +181,17 @@ def clone_on_transport(self, transport, revision_id=None, if preserve_stacking: try: stacked_on = local_active_branch.get_stacked_on_url() - except (_mod_branch.UnstackableBranchFormat, - errors.UnstackableRepositoryFormat, - errors.NotStacked): + except ( + _mod_branch.UnstackableBranchFormat, + errors.UnstackableRepositoryFormat, + errors.NotStacked, + ): pass # Bug: We create a metadir without knowing if it can support stacking, # we should look up the policy needs first, or just use it as a hint, # or something. if local_repo: - make_working_trees = (local_repo.make_working_trees() and - not no_tree) + make_working_trees = local_repo.make_working_trees() and not no_tree want_shared = local_repo.is_shared() repo_format_name = format.repository_format.network_name() else: @@ -188,13 +199,22 @@ def clone_on_transport(self, transport, revision_id=None, want_shared = False repo_format_name = None - result_repo, result, require_stacking, repository_policy = \ - format.initialize_on_transport_ex( - transport, use_existing_dir=use_existing_dir, - create_prefix=create_prefix, force_new_repo=force_new_repo, - stacked_on=stacked_on, stack_on_pwd=self.root_transport.base, - repo_format_name=repo_format_name, - make_working_trees=make_working_trees, shared_repo=want_shared) + ( + result_repo, + result, + require_stacking, + repository_policy, + ) = format.initialize_on_transport_ex( + transport, + use_existing_dir=use_existing_dir, + create_prefix=create_prefix, + force_new_repo=force_new_repo, + stacked_on=stacked_on, + stack_on_pwd=self.root_transport.base, + repo_format_name=repo_format_name, + make_working_trees=make_working_trees, + shared_repo=want_shared, + ) if repo_format_name: try: # If the result repository is in the same place as the @@ -203,11 +223,14 @@ def clone_on_transport(self, transport, revision_id=None, # copied, and finally if we are copying up to a specific # revision_id then we can use the pending-ancestry-result which # does not require traversing all of history to describe it. - if (result_repo.user_url == result.user_url + if ( + result_repo.user_url == result.user_url and not require_stacking - and revision_id is not None): + and revision_id is not None + ): fetch_spec = vf_search.PendingAncestryResult( - [revision_id], local_repo) + [revision_id], local_repo + ) result_repo.fetch(local_repo, fetch_spec=fetch_spec) else: result_repo.fetch(local_repo, revision_id=revision_id) @@ -215,19 +238,22 @@ def clone_on_transport(self, transport, revision_id=None, result_repo.unlock() else: if result_repo is not None: - raise AssertionError(f'result_repo not None({result_repo!r})') + raise AssertionError(f"result_repo not None({result_repo!r})") # 1 if there is a branch present # make sure its content is available in the target repository # clone it. 
for name, local_branch in local_branches.items(): local_branch.clone( - result, revision_id=(None if name != '' else revision_id), + result, + revision_id=(None if name != "" else revision_id), repository_policy=repository_policy, - name=name, tag_selector=tag_selector) + name=name, + tag_selector=tag_selector, + ) try: # Cheaper to check if the target is not local, than to try making # the tree and fail. - result.root_transport.local_abspath('.') + result.root_transport.local_abspath(".") if result_repo is None or result_repo.make_working_trees(): self.open_workingtree().clone(result, revision_id=revision_id) except (errors.NoWorkingTree, errors.NotLocalUrl): @@ -240,8 +266,13 @@ def _make_tail(self, url): t = _mod_transport.get_transport(url) t.ensure_base() - def determine_repository_policy(self, force_new_repo=False, stack_on=None, - stack_on_pwd=None, require_stacking=False): + def determine_repository_policy( + self, + force_new_repo=False, + stack_on=None, + stack_on_pwd=None, + require_stacking=False, + ): """Return an object representing a policy to use. This controls whether a new repository is created, and the format of @@ -255,6 +286,7 @@ def determine_repository_policy(self, force_new_repo=False, stack_on=None, :param stack_on_pwd: If stack_on is relative, the location it is relative to. """ + def repository_policy(found_bzrdir): stack_on = None stack_on_pwd = None @@ -270,8 +302,10 @@ def repository_policy(found_bzrdir): except errors.NoRepositoryPresent: repository = None else: - if (found_bzrdir.user_url != self.user_url and - not repository.is_shared()): + if ( + found_bzrdir.user_url != self.user_url + and not repository.is_shared() + ): # Don't look higher, can't use a higher shared repo. repository = None stop = True @@ -281,12 +315,15 @@ def repository_policy(found_bzrdir): return None, False if repository: return UseExistingRepository( - repository, stack_on, stack_on_pwd, - require_stacking=require_stacking), True + repository, + stack_on, + stack_on_pwd, + require_stacking=require_stacking, + ), True else: return CreateRepository( - self, stack_on, stack_on_pwd, - require_stacking=require_stacking), True + self, stack_on, stack_on_pwd, require_stacking=require_stacking + ), True if not force_new_repo: if stack_on is None: @@ -296,12 +333,16 @@ def repository_policy(found_bzrdir): else: try: return UseExistingRepository( - self.open_repository(), stack_on, stack_on_pwd, - require_stacking=require_stacking) + self.open_repository(), + stack_on, + stack_on_pwd, + require_stacking=require_stacking, + ) except errors.NoRepositoryPresent: pass - return CreateRepository(self, stack_on, stack_on_pwd, - require_stacking=require_stacking) + return CreateRepository( + self, stack_on, stack_on_pwd, require_stacking=require_stacking + ) def _find_or_create_repository(self, force_new_repo): """Create a new repository if needed, returning the repository.""" @@ -335,11 +376,20 @@ def _find_source_repo(self, exit_stack, source_branch): exit_stack.enter_context(source_branch.lock_read()) return source_branch, source_repository - def sprout(self, url, revision_id=None, force_new_repo=False, - recurse='down', possible_transports=None, - accelerator_tree=None, hardlink=False, stacked=False, - source_branch=None, create_tree_if_local=True, - lossy=False): + def sprout( + self, + url, + revision_id=None, + force_new_repo=False, + recurse="down", + possible_transports=None, + accelerator_tree=None, + hardlink=False, + stacked=False, + source_branch=None, + create_tree_if_local=True, + 
lossy=False, + ): """Create a copy of this controldir prepared for use as a new line of development. @@ -373,20 +423,18 @@ def sprout(self, url, revision_id=None, force_new_repo=False, if possible_transports is None: possible_transports = [] else: - possible_transports = list(possible_transports) + [ - self.root_transport] - target_transport = _mod_transport.get_transport(url, - possible_transports) + possible_transports = list(possible_transports) + [self.root_transport] + target_transport = _mod_transport.get_transport(url, possible_transports) target_transport.ensure_base() cloning_format = self.cloning_metadir(stacked) # Create/update the result branch try: - result = controldir.ControlDir.open_from_transport( - target_transport) + result = controldir.ControlDir.open_from_transport(target_transport) except errors.NotBranchError: result = cloning_format.initialize_on_transport(target_transport) source_branch, source_repository = self._find_source_repo( - stack, source_branch) + stack, source_branch + ) fetch_spec_factory.source_branch = source_branch # if a stacked branch wasn't requested, we don't create one # even if the origin was stacked @@ -395,9 +443,11 @@ def sprout(self, url, revision_id=None, force_new_repo=False, else: stacked_branch_url = None repository_policy = result.determine_repository_policy( - force_new_repo, stacked_branch_url, require_stacking=stacked) + force_new_repo, stacked_branch_url, require_stacking=stacked + ) result_repo, is_new_repo = repository_policy.acquire_repository( - possible_transports=possible_transports) + possible_transports=possible_transports + ) stack.enter_context(result_repo.lock_write()) fetch_spec_factory.source_repo = source_repository fetch_spec_factory.target_repo = result_repo @@ -421,29 +471,38 @@ def sprout(self, url, revision_id=None, force_new_repo=False, result_branch.generate_revision_history(revision_id) else: result_branch = source_branch.sprout( - result, revision_id=revision_id, - repository_policy=repository_policy, repository=result_repo) + result, + revision_id=revision_id, + repository_policy=repository_policy, + repository=result_repo, + ) mutter(f"created new branch {result_branch!r}") # Create/update the result working tree - if (create_tree_if_local and not result.has_workingtree() - and isinstance(target_transport, local.LocalTransport) - and (result_repo is None or result_repo.make_working_trees()) - and result.open_branch( - name="", - possible_transports=possible_transports).name == result_branch.name): + if ( + create_tree_if_local + and not result.has_workingtree() + and isinstance(target_transport, local.LocalTransport) + and (result_repo is None or result_repo.make_working_trees()) + and result.open_branch( + name="", possible_transports=possible_transports + ).name + == result_branch.name + ): wt = result.create_workingtree( - accelerator_tree=accelerator_tree, hardlink=hardlink, - from_branch=result_branch) + accelerator_tree=accelerator_tree, + hardlink=hardlink, + from_branch=result_branch, + ) with wt.lock_write(): - if not wt.is_versioned(''): + if not wt.is_versioned(""): try: - wt.set_root_id(self.open_workingtree.path2id('')) + wt.set_root_id(self.open_workingtree.path2id("")) except errors.NoWorkingTree: pass else: wt = None - if recurse == 'down': + if recurse == "down": tree = None if wt is not None: tree = wt @@ -461,17 +520,22 @@ def sprout(self, url, revision_id=None, force_new_repo=False, for path in subtrees: target = urlutils.join(url, urlutils.escape(path)) sublocation = 
tree.reference_parent( - path, branch=result_branch, - possible_transports=possible_transports) + path, + branch=result_branch, + possible_transports=possible_transports, + ) if sublocation is None: warning( - 'Ignoring nested tree %s, parent location unknown.', - path) + "Ignoring nested tree %s, parent location unknown.", path + ) continue sublocation.controldir.sprout( - target, basis.get_reference_revision(path), - force_new_repo=force_new_repo, recurse=recurse, - stacked=stacked) + target, + basis.get_reference_revision(path), + force_new_repo=force_new_repo, + recurse=recurse, + stacked=stacked, + ) return result def _available_backup_name(self, base): @@ -487,14 +551,16 @@ def backup_bzrdir(self): :return: Tuple with old path name and new path name """ with ui.ui_factory.nested_progress_bar(): - old_path = self.root_transport.abspath('.bzr') - backup_dir = self._available_backup_name('backup.bzr') + old_path = self.root_transport.abspath(".bzr") + backup_dir = self._available_backup_name("backup.bzr") new_path = self.root_transport.abspath(backup_dir) ui.ui_factory.note( - gettext('making backup of {0}\n to {1}').format( - urlutils.unescape_for_display(old_path, 'utf-8'), - urlutils.unescape_for_display(new_path, 'utf-8'))) - self.root_transport.copy_tree('.bzr', backup_dir) + gettext("making backup of {0}\n to {1}").format( + urlutils.unescape_for_display(old_path, "utf-8"), + urlutils.unescape_for_display(new_path, "utf-8"), + ) + ) + self.root_transport.copy_tree(".bzr", backup_dir) return (old_path, new_path) def retire_bzrdir(self, limit=10000): @@ -510,10 +576,13 @@ def retire_bzrdir(self, limit=10000): i = 0 while True: try: - to_path = '.bzr.retired.%d' % i - self.root_transport.rename('.bzr', to_path) - note(gettext("renamed {0} to {1}").format( - self.root_transport.abspath('.bzr'), to_path)) + to_path = ".bzr.retired.%d" % i + self.root_transport.rename(".bzr", to_path) + note( + gettext("renamed {0} to {1}").format( + self.root_transport.abspath(".bzr"), to_path + ) + ) return except (errors.TransportError, OSError, errors.PathError): i += 1 @@ -540,14 +609,13 @@ def _find_containing(self, evaluate): result, stop = evaluate(found_bzrdir) if stop: return result - next_transport = found_bzrdir.root_transport.clone('..') - if (found_bzrdir.user_url == next_transport.base): + next_transport = found_bzrdir.root_transport.clone("..") + if found_bzrdir.user_url == next_transport.base: # top of the file system return None # find the next containing bzrdir try: - found_bzrdir = self.open_containing_from_transport( - next_transport)[0] + found_bzrdir = self.open_containing_from_transport(next_transport)[0] except errors.NotBranchError: return None @@ -558,6 +626,7 @@ def find_repository(self): new branches as well as to hook existing branches up to their repository. """ + def usable_repository(found_bzrdir): # does it have a repository ? try: @@ -588,7 +657,7 @@ def _find_creation_modes(self): return self._mode_check_done = True try: - st = self.transport.stat('.') + st = self.transport.stat(".") except errors.TransportNotPossible: self._dir_mode = None self._file_mode = None @@ -597,7 +666,7 @@ def _find_creation_modes(self): # directories and files are read-write for this user. 
This is # mostly a workaround for filesystems which lie about being able to # write to a directory (cygwin & win32) - if (st.st_mode & 0o7777 == 00000): + if st.st_mode & 0o7777 == 00000: # FTP allows stat but does not return dir/file modes self._dir_mode = None self._file_mode = None @@ -638,7 +707,7 @@ def __init__(self, _transport, _format): self._format = _format # these are also under the more standard names of # control_transport and user_transport - self.transport = _transport.clone('.bzr') + self.transport = _transport.clone(".bzr") self.root_transport = _transport self._mode_check_done = False @@ -706,8 +775,7 @@ def cloning_metadir(self, require_stacking=False): return format # We have a repository, so set a working tree? (Why? This seems to # contradict the stated return value in the docstring). - tree_format = ( - repository._format._matchingcontroldir.workingtree_format) + tree_format = repository._format._matchingcontroldir.workingtree_format format.workingtree_format = tree_format.__class__() if require_stacking: format.require_stacking() @@ -762,12 +830,18 @@ def create(cls, base, format=None, possible_transports=None) -> "BzrDir": can be reused to share a remote connection. """ if cls is not BzrDir: - raise AssertionError("BzrDir.create always creates the " - "default format, not one of %r" % cls) + raise AssertionError( + "BzrDir.create always creates the " + "default format, not one of %r" % cls + ) if format is None: format = BzrDirFormat.get_default_format() - return cast("BzrDir", controldir.ControlDir.create( - base, format=format, possible_transports=possible_transports)) + return cast( + "BzrDir", + controldir.ControlDir.create( + base, format=format, possible_transports=possible_transports + ), + ) def __repr__(self): return f"<{self.__class__.__name__} at {self.user_url!r}>" @@ -781,7 +855,7 @@ def update_feature_flags(self, updated_flags): self.control_files.lock_write() try: self._format._update_feature_flags(updated_flags) - self.transport.put_bytes('branch-format', self._format.as_string()) + self.transport.put_bytes("branch-format", self._format.as_string()) finally: self.control_files.unlock() @@ -806,8 +880,8 @@ def _get_branch_path(self, name): :return: Relative path to branch """ if name == "": - return 'branch' - return urlutils.join('branches', urlutils.escape(name)) + return "branch" + return urlutils.join("branches", urlutils.escape(name)) def _read_branch_list(self): """Read the branch list. @@ -815,14 +889,14 @@ def _read_branch_list(self): :return: List of branch names. 
""" try: - f = self.control_transport.get('branch-list') + f = self.control_transport.get("branch-list") except _mod_transport.NoSuchFile: return [] ret = [] try: for name in f: - ret.append(name.rstrip(b"\n").decode('utf-8')) + ret.append(name.rstrip(b"\n").decode("utf-8")) finally: f.close() return ret @@ -833,27 +907,31 @@ def _write_branch_list(self, branches): :param branches: List of utf-8 branch names to write """ self.transport.put_bytes( - 'branch-list', - b"".join([name.encode('utf-8') + b"\n" for name in branches])) + "branch-list", b"".join([name.encode("utf-8") + b"\n" for name in branches]) + ) def __init__(self, _transport, _format): super().__init__(_transport, _format) self.control_files = lockable_files.LockableFiles( - self.control_transport, self._format._lock_file_name, - self._format._lock_class) + self.control_transport, + self._format._lock_file_name, + self._format._lock_class, + ) def can_convert_format(self): """See BzrDir.can_convert_format().""" return True - def create_branch(self, name=None, repository=None, - append_revisions_only=None): + def create_branch(self, name=None, repository=None, append_revisions_only=None): """See ControlDir.create_branch.""" if name is None: name = self._get_selected_branch() return self._format.get_branch_format().initialize( - self, name=name, repository=repository, - append_revisions_only=append_revisions_only) + self, + name=name, + repository=repository, + append_revisions_only=append_revisions_only, + ) def destroy_branch(self, name=None): """See ControlDir.destroy_branch.""" @@ -875,7 +953,8 @@ def destroy_branch(self, name=None): self.transport.delete_tree(path) except _mod_transport.NoSuchFile as e: raise errors.NotBranchError( - path=urlutils.join(self.transport.base, path), controldir=self) from e + path=urlutils.join(self.transport.base, path), controldir=self + ) from e def create_repository(self, shared=False): """See BzrDir.create_repository.""" @@ -884,16 +963,21 @@ def create_repository(self, shared=False): def destroy_repository(self): """See BzrDir.destroy_repository.""" try: - self.transport.delete_tree('repository') + self.transport.delete_tree("repository") except _mod_transport.NoSuchFile as e: raise errors.NoRepositoryPresent(self) from e - def create_workingtree(self, revision_id=None, from_branch=None, - accelerator_tree=None, hardlink=False): + def create_workingtree( + self, revision_id=None, from_branch=None, accelerator_tree=None, hardlink=False + ): """See BzrDir.create_workingtree.""" return self._format.workingtree_format.initialize( - self, revision_id, from_branch=from_branch, - accelerator_tree=accelerator_tree, hardlink=hardlink) + self, + revision_id, + from_branch=from_branch, + accelerator_tree=accelerator_tree, + hardlink=hardlink, + ) def destroy_workingtree(self): """See BzrDir.destroy_workingtree.""" @@ -907,7 +991,7 @@ def destroy_workingtree(self): self.destroy_workingtree_metadata() def destroy_workingtree_metadata(self): - self.transport.delete_tree('checkout') + self.transport.delete_tree("checkout") def find_branch_format(self, name=None): """Find the branch 'format' for this bzrdir. @@ -915,24 +999,29 @@ def find_branch_format(self, name=None): This might be a synthetic object for e.g. RemoteBranch and SVN. 
""" from .branch import BranchFormatMetadir + return BranchFormatMetadir.find_format(self, name=name) def _get_mkdir_mode(self): """Figure out the mode to use when creating a bzrdir subdir.""" temp_control = lockable_files.LockableFiles( - self.transport, '', lockable_files.TransportLock) + self.transport, "", lockable_files.TransportLock + ) return temp_control._dir_mode def get_branch_reference(self, name=None): """See BzrDir.get_branch_reference().""" from .branch import BranchFormatMetadir + format = BranchFormatMetadir.find_format(self, name=name) return format.get_reference(self, name=name) def set_branch_reference(self, target_branch, name=None): format = _mod_bzrbranch.BranchReferenceFormat() - if (self.control_url == target_branch.controldir.control_url - and name == target_branch.name): + if ( + self.control_url == target_branch.controldir.control_url + and name == target_branch.name + ): raise controldir.BranchReferenceLoop(target_branch) return format.initialize(self, target_branch=target_branch, name=name) @@ -958,8 +1047,7 @@ def get_branch_transport(self, branch_format, name=None): dirname = urlutils.dirname(name) if dirname != "" and dirname in branches: raise errors.ParentBranchExists(name) - child_branches = [ - b.startswith(name + "/") for b in branches] + child_branches = [b.startswith(name + "/") for b in branches] if any(child_branches): raise errors.AlreadyBranchError(name) branches.append(name) @@ -978,30 +1066,30 @@ def get_branch_transport(self, branch_format, name=None): def get_repository_transport(self, repository_format): """See BzrDir.get_repository_transport().""" if repository_format is None: - return self.transport.clone('repository') + return self.transport.clone("repository") try: repository_format.get_format_string() except NotImplementedError as e: raise errors.IncompatibleFormat(repository_format, self._format) from e try: - self.transport.mkdir('repository', mode=self._get_mkdir_mode()) + self.transport.mkdir("repository", mode=self._get_mkdir_mode()) except _mod_transport.FileExists: pass - return self.transport.clone('repository') + return self.transport.clone("repository") def get_workingtree_transport(self, workingtree_format): """See BzrDir.get_workingtree_transport().""" if workingtree_format is None: - return self.transport.clone('checkout') + return self.transport.clone("checkout") try: workingtree_format.get_format_string() except NotImplementedError as e: raise errors.IncompatibleFormat(workingtree_format, self._format) from e try: - self.transport.mkdir('checkout', mode=self._get_mkdir_mode()) + self.transport.mkdir("checkout", mode=self._get_mkdir_mode()) except _mod_transport.FileExists: pass - return self.transport.clone('checkout') + return self.transport.clone("checkout") def branch_names(self): """See ControlDir.branch_names.""" @@ -1035,6 +1123,7 @@ def has_workingtree(self): ahead and try, and not ask permission first. """ from .workingtree import WorkingTreeFormatMetaDir + try: WorkingTreeFormatMetaDir.find_format_string(self) except errors.NoWorkingTree: @@ -1043,35 +1132,41 @@ def has_workingtree(self): def needs_format_conversion(self, format): """See BzrDir.needs_format_conversion().""" - if (not isinstance(self._format, format.__class__) - or self._format.get_format_string() != format.get_format_string()): + if ( + not isinstance(self._format, format.__class__) + or self._format.get_format_string() != format.get_format_string() + ): # it is not a meta dir format, conversion is needed. 
return True # we might want to push this down to the repository? try: - if not isinstance(self.open_repository()._format, - format.repository_format.__class__): + if not isinstance( + self.open_repository()._format, format.repository_format.__class__ + ): # the repository needs an upgrade. return True except errors.NoRepositoryPresent: pass for branch in self.list_branches(): - if not isinstance(branch._format, - format.get_branch_format().__class__): + if not isinstance(branch._format, format.get_branch_format().__class__): # the branch needs an upgrade. return True try: my_wt = self.open_workingtree(recommend_upgrade=False) - if not isinstance(my_wt._format, - format.workingtree_format.__class__): + if not isinstance(my_wt._format, format.workingtree_format.__class__): # the workingtree needs an upgrade. return True except (errors.NoWorkingTree, errors.NotLocalUrl): pass return False - def open_branch(self, name=None, unsupported=False, - ignore_fallbacks=False, possible_transports=None): + def open_branch( + self, + name=None, + unsupported=False, + ignore_fallbacks=False, + possible_transports=None, + ): """See ControlDir.open_branch.""" if name is None: name = self._get_selected_branch() @@ -1082,28 +1177,34 @@ def open_branch(self, name=None, unsupported=False, else: possible_transports = list(possible_transports) possible_transports.append(self.root_transport) - return format.open(self, name=name, - _found=True, ignore_fallbacks=ignore_fallbacks, - possible_transports=possible_transports) + return format.open( + self, + name=name, + _found=True, + ignore_fallbacks=ignore_fallbacks, + possible_transports=possible_transports, + ) def open_repository(self, unsupported=False): """See BzrDir.open_repository.""" from .repository import RepositoryFormatMetaDir + format = RepositoryFormatMetaDir.find_format(self) format.check_support_status(unsupported) return format.open(self, _found=True) - def open_workingtree(self, unsupported=False, - recommend_upgrade=True): + def open_workingtree(self, unsupported=False, recommend_upgrade=True): """See BzrDir.open_workingtree.""" from .workingtree import WorkingTreeFormatMetaDir + format = WorkingTreeFormatMetaDir.find_format(self) - format.check_support_status(unsupported, recommend_upgrade, - basedir=self.root_transport.base) + format.check_support_status( + unsupported, recommend_upgrade, basedir=self.root_transport.base + ) return format.open(self, _found=True) def _get_config(self): - return config.TransportConfig(self.transport, 'control.conf') + return config.TransportConfig(self.transport, "control.conf") class BzrFormat: @@ -1150,8 +1251,9 @@ def unregister_feature(cls, name): """Unregister a feature.""" cls._present_features.remove(name) - def check_support_status(self, allow_unsupported, recommend_upgrade=True, - basedir=None): + def check_support_status( + self, allow_unsupported, recommend_upgrade=True, basedir=None + ): for name, necessity in self.features.items(): if name in self._present_features: continue @@ -1161,8 +1263,7 @@ def check_support_status(self, allow_unsupported, recommend_upgrade=True, elif necessity == b"required": raise MissingFeature(name) else: - mutter("treating unknown necessity as require for %s", - name) + mutter("treating unknown necessity as require for %s", name) raise MissingFeature(name) @classmethod @@ -1174,30 +1275,34 @@ def get_format_string(cls): def from_string(cls, text): format_string = cls.get_format_string() if not text.startswith(format_string): - raise AssertionError( - f"Invalid format header 
{text!r} for {cls!r}") - lines = text[len(format_string):].splitlines() + raise AssertionError(f"Invalid format header {text!r} for {cls!r}") + lines = text[len(format_string) :].splitlines() ret = cls() for lineno, line in enumerate(lines): try: (necessity, feature) = line.split(b" ", 1) except ValueError as e: - raise errors.ParseFormatError(format=cls, lineno=lineno + 2, - line=line, text=text) from e + raise errors.ParseFormatError( + format=cls, lineno=lineno + 2, line=line, text=text + ) from e ret.features[feature] = necessity return ret def as_string(self): """Return the string representation of this format.""" lines = [self.get_format_string()] - lines.extend([(item[1] + b" " + item[0] + b"\n") - for item in sorted(self.features.items())]) + lines.extend( + [ + (item[1] + b" " + item[0] + b"\n") + for item in sorted(self.features.items()) + ] + ) return b"".join(lines) @classmethod def _find_format(klass, registry, kind, format_string): try: - first_line = format_string[:format_string.index(b"\n") + 1] + first_line = format_string[: format_string.index(b"\n") + 1] except ValueError: first_line = format_string try: @@ -1214,8 +1319,7 @@ def network_name(self): return self.as_string() def __eq__(self, other): - return (self.__class__ is other.__class__ - and self.features == other.features) + return self.__class__ is other.__class__ and self.features == other.features def _update_feature_flags(self, updated_flags): """Update the feature flags in this format. @@ -1244,7 +1348,7 @@ class BzrDirFormat(BzrFormat, controldir.ControlDirFormat): object will be created every system load. """ - _lock_file_name = 'branch-lock' + _lock_file_name = "branch-lock" # _lock_class must be set in subclasses to the lock type, typ. # TransportLock or LockDir @@ -1264,14 +1368,24 @@ def initialize_on_transport(self, transport): if not isinstance(self, BzrDirMetaFormat1): return self._initialize_on_transport_vfs(transport) from .remote import RemoteBzrDirFormat + remote_format = RemoteBzrDirFormat() self._supply_sub_formats_to(remote_format) return remote_format.initialize_on_transport(transport) - def initialize_on_transport_ex(self, transport, use_existing_dir=False, - create_prefix=False, force_new_repo=False, stacked_on=None, - stack_on_pwd=None, repo_format_name=None, make_working_trees=None, - shared_repo=False, vfs_only=False): + def initialize_on_transport_ex( + self, + transport, + use_existing_dir=False, + create_prefix=False, + force_new_repo=False, + stacked_on=None, + stack_on_pwd=None, + repo_format_name=None, + make_working_trees=None, + shared_repo=False, + vfs_only=False, + ): """Create this format on transport. The directory to initialize will be created. 
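The `from_string`/`as_string` hunks above are pure reformatting, but the layout they handle is easy to miss in diff form: a format file is the fixed format string followed by one `necessity feature` pair per line, serialised in sorted feature order. A minimal standalone sketch of that round trip, assuming a made-up header string (this is illustrative, not the Breezy `BzrFormat` API):

# Illustrative round trip of the feature-flag header handled by
# BzrFormat.from_string()/as_string() above.  FORMAT_STRING is a made-up
# header, not a real Breezy format string.
FORMAT_STRING = b"Example Control Format 1\n"


def parse_features(text):
    """Return {feature: necessity} from a serialised format header."""
    if not text.startswith(FORMAT_STRING):
        raise ValueError("unexpected format header")
    features = {}
    for line in text[len(FORMAT_STRING):].splitlines():
        necessity, feature = line.split(b" ", 1)
        features[feature] = necessity
    return features


def serialise_features(features):
    """Header first, then one 'necessity feature' line per flag, sorted by feature."""
    lines = [FORMAT_STRING]
    lines.extend(
        necessity + b" " + feature + b"\n"
        for feature, necessity in sorted(features.items())
    )
    return b"".join(lines)


flags = {b"nested-trees": b"optional", b"colocated-branches": b"required"}
assert parse_features(serialise_features(flags)) == flags
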
@@ -1314,26 +1428,30 @@ def initialize_on_transport_ex(self, transport, use_existing_dir=False, remote_dir_format._network_name = self.network_name() self._supply_sub_formats_to(remote_dir_format) return remote_dir_format.initialize_on_transport_ex( - transport, use_existing_dir=use_existing_dir, - create_prefix=create_prefix, force_new_repo=force_new_repo, - stacked_on=stacked_on, stack_on_pwd=stack_on_pwd, + transport, + use_existing_dir=use_existing_dir, + create_prefix=create_prefix, + force_new_repo=force_new_repo, + stacked_on=stacked_on, + stack_on_pwd=stack_on_pwd, repo_format_name=repo_format_name, make_working_trees=make_working_trees, - shared_repo=shared_repo) + shared_repo=shared_repo, + ) # XXX: Refactor the create_prefix/no_create_prefix code into a # common helper function # The destination may not exist - if so make it according to policy. def make_directory(transport): - transport.mkdir('.') + transport.mkdir(".") return transport def redirected(transport, e, redirection_notice): note(redirection_notice) return transport._redirected_to(e.source, e.target) + try: - transport = do_catching_redirections(make_directory, transport, - redirected) + transport = do_catching_redirections(make_directory, transport, redirected) except _mod_transport.FileExists: if not use_existing_dir: raise @@ -1342,7 +1460,7 @@ def redirected(transport, e, redirection_notice): raise transport.create_prefix() - require_stacking = (stacked_on is not None) + require_stacking = stacked_on is not None # Now the target directory exists, but doesn't have a .bzr # directory. So we need to create it, along with any work to create # all of the dependent branches, etc. @@ -1351,17 +1469,22 @@ def redirected(transport, e, redirection_notice): if repo_format_name: try: # use a custom format - result._format.repository_format = \ + result._format.repository_format = ( repository.network_format_registry.get(repo_format_name) + ) except AttributeError: # The format didn't permit it to be set. pass # A repository is desired, either in-place or shared. 
repository_policy = result.determine_repository_policy( - force_new_repo, stacked_on, stack_on_pwd, - require_stacking=require_stacking) + force_new_repo, + stacked_on, + stack_on_pwd, + require_stacking=require_stacking, + ) result_repo, is_new_repo = repository_policy.acquire_repository( - make_working_trees, shared_repo) + make_working_trees, shared_repo + ) if not require_stacking and repository_policy._require_stacking: require_stacking = True result._format.require_stacking() @@ -1379,35 +1502,41 @@ def _initialize_on_transport_vfs(self, transport): """ # Since we are creating a .bzr directory, inherit the # mode from the root directory - temp_control = lockable_files.LockableFiles(transport, - '', lockable_files.TransportLock) + temp_control = lockable_files.LockableFiles( + transport, "", lockable_files.TransportLock + ) try: - temp_control._transport.mkdir('.bzr', - # FIXME: RBC 20060121 don't peek under - # the covers - mode=temp_control._dir_mode) + temp_control._transport.mkdir( + ".bzr", + # FIXME: RBC 20060121 don't peek under + # the covers + mode=temp_control._dir_mode, + ) except _mod_transport.FileExists as e: raise errors.AlreadyControlDirError(transport.base) from e - if sys.platform == 'win32' and isinstance(transport, local.LocalTransport): - win32utils.set_file_attr_hidden(transport._abspath('.bzr')) + if sys.platform == "win32" and isinstance(transport, local.LocalTransport): + win32utils.set_file_attr_hidden(transport._abspath(".bzr")) file_mode = temp_control._file_mode del temp_control - bzrdir_transport = transport.clone('.bzr') - utf8_files = [('README', - b"This is a Bazaar control directory.\n" - b"Do not change any files in this directory.\n" - b"See http://bazaar.canonical.com/ for more information about Bazaar.\n"), - ('branch-format', self.as_string()), - ] + bzrdir_transport = transport.clone(".bzr") + utf8_files = [ + ( + "README", + b"This is a Bazaar control directory.\n" + b"Do not change any files in this directory.\n" + b"See http://bazaar.canonical.com/ for more information about Bazaar.\n", + ), + ("branch-format", self.as_string()), + ] # NB: no need to escape relative paths that are url safe. - control_files = lockable_files.LockableFiles(bzrdir_transport, - self._lock_file_name, self._lock_class) + control_files = lockable_files.LockableFiles( + bzrdir_transport, self._lock_file_name, self._lock_class + ) control_files.create_lock() control_files.lock_write() try: - for (filename, content) in utf8_files: - bzrdir_transport.put_bytes(filename, content, - mode=file_mode) + for filename, content in utf8_files: + bzrdir_transport.put_bytes(filename, content, mode=file_mode) finally: control_files.unlock() return self.open(transport, _found=True) @@ -1420,8 +1549,11 @@ def open(self, transport, _found=False): if not _found: found_format = controldir.ControlDirFormat.find_format(transport) if not isinstance(found_format, self.__class__): - raise AssertionError("{} was asked to open {}, but it seems to need " - "format {}".format(self, transport, found_format)) + raise AssertionError( + "{} was asked to open {}, but it seems to need " "format {}".format( + self, transport, found_format + ) + ) # Allow subclasses - use the found format. 
self._supply_sub_formats_to(found_format) return found_format._open(transport) @@ -1452,13 +1584,21 @@ def supports_transport(self, transport): # bzr formats can be opened over all known transports return True - def check_support_status(self, allow_unsupported, recommend_upgrade=True, - basedir=None): - controldir.ControlDirFormat.check_support_status(self, - allow_unsupported=allow_unsupported, recommend_upgrade=recommend_upgrade, - basedir=basedir) - BzrFormat.check_support_status(self, allow_unsupported=allow_unsupported, - recommend_upgrade=recommend_upgrade, basedir=basedir) + def check_support_status( + self, allow_unsupported, recommend_upgrade=True, basedir=None + ): + controldir.ControlDirFormat.check_support_status( + self, + allow_unsupported=allow_unsupported, + recommend_upgrade=recommend_upgrade, + basedir=basedir, + ) + BzrFormat.check_support_status( + self, + allow_unsupported=allow_unsupported, + recommend_upgrade=recommend_upgrade, + basedir=basedir, + ) @classmethod def is_control_filename(klass, filename): @@ -1476,12 +1616,12 @@ def is_control_filename(klass, filename): # it was extracted from WorkingTree.is_control_filename. If the # method's contract is extended beyond the current trivial # implementation, please add new tests for it to the appropriate place. - return filename == '.bzr' or filename.startswith('.bzr/') + return filename == ".bzr" or filename.startswith(".bzr/") @classmethod def get_default_format(klass): """Return the current default format.""" - return controldir.format_registry.get('bzr')() + return controldir.format_registry.get("bzr")() class BzrDirMetaFormat1(BzrDirFormat): @@ -1526,14 +1666,16 @@ def __ne__(self, other): def get_branch_format(self): if self._branch_format is None: from .branch import format_registry as branch_format_registry + self._branch_format = branch_format_registry.get_default() return self._branch_format def set_branch_format(self, format): self._branch_format = format - def require_stacking(self, stack_on=None, possible_transports=None, - _skip_repo=False): + def require_stacking( + self, stack_on=None, possible_transports=None, _skip_repo=False + ): """We have a request to stack, try to ensure the formats support it. :param stack_on: If supplied, it is the URL to a branch that we want to @@ -1559,8 +1701,9 @@ def get_target_branch(): target[:] = [None, True, True] return target try: - target_dir = BzrDir.open(stack_on, - possible_transports=possible_transports) + target_dir = BzrDir.open( + stack_on, possible_transports=possible_transports + ) except errors.NotBranchError: # Nothing there, don't change formats target[:] = [None, True, False] @@ -1578,8 +1721,7 @@ def get_target_branch(): target[:] = [target_branch, True, False] return target - if (not _skip_repo - and not self.repository_format.supports_external_lookups): + if not _skip_repo and not self.repository_format.supports_external_lookups: # We need to upgrade the Repository. target_branch, _, do_upgrade = get_target_branch() if target_branch is None: @@ -1588,7 +1730,9 @@ def get_target_branch(): # stack_on is inaccessible, JFDI. # TODO: bad monkey, hard-coded formats... 
if self.repository_format.rich_root_data: - new_repo_format = knitpack_repo.RepositoryFormatKnitPack5RichRoot() + new_repo_format = ( + knitpack_repo.RepositoryFormatKnitPack5RichRoot() + ) else: new_repo_format = knitpack_repo.RepositoryFormatKnitPack5() else: @@ -1602,9 +1746,13 @@ def get_target_branch(): new_repo_format = None if new_repo_format is not None: self.repository_format = new_repo_format - note(gettext('Source repository format does not support stacking,' - ' using format:\n %s'), - new_repo_format.get_format_description()) + note( + gettext( + "Source repository format does not support stacking," + " using format:\n %s" + ), + new_repo_format.get_format_description(), + ) if not self.get_branch_format().supports_stacking(): # We just checked the repo, now lets check if we need to @@ -1614,6 +1762,7 @@ def get_target_branch(): if do_upgrade: # TODO: bad monkey, hard-coded formats... from .branch import BzrBranchFormat7 + new_branch_format = BzrBranchFormat7() else: new_branch_format = target_branch._format @@ -1622,19 +1771,25 @@ def get_target_branch(): if new_branch_format is not None: # Does support stacking, use its format. self.set_branch_format(new_branch_format) - note(gettext('Source branch format does not support stacking,' - ' using format:\n %s'), - new_branch_format.get_format_description()) + note( + gettext( + "Source branch format does not support stacking," + " using format:\n %s" + ), + new_branch_format.get_format_description(), + ) def get_converter(self, format=None): """See BzrDirFormat.get_converter().""" if format is None: format = BzrDirFormat.get_default_format() - if (isinstance(self, BzrDirMetaFormat1) - and isinstance(format, BzrDirMetaFormat1Colo)): + if isinstance(self, BzrDirMetaFormat1) and isinstance( + format, BzrDirMetaFormat1Colo + ): return ConvertMetaToColo(format) - if (isinstance(self, BzrDirMetaFormat1Colo) - and isinstance(format, BzrDirMetaFormat1)): + if isinstance(self, BzrDirMetaFormat1Colo) and isinstance( + format, BzrDirMetaFormat1 + ): return ConvertMetaToColo(format) if not isinstance(self, format.__class__): # converting away from metadir is not implemented @@ -1664,14 +1819,14 @@ def __return_repository_format(self): if self._repository_format: return self._repository_format from .repository import format_registry + return format_registry.get_default() def _set_repository_format(self, value): """Allow changing the repository format for metadir formats.""" self._repository_format = value - repository_format = property(__return_repository_format, - _set_repository_format) + repository_format = property(__return_repository_format, _set_repository_format) def _supply_sub_formats_to(self, other_format): """Give other_format the same values for sub formats as this has. @@ -1685,7 +1840,7 @@ def _supply_sub_formats_to(self, other_format): :return: None. 
""" super()._supply_sub_formats_to(other_format) - if getattr(self, '_repository_format', None) is not None: + if getattr(self, "_repository_format", None) is not None: other_format.repository_format = self.repository_format if self._branch_format is not None: other_format._branch_format = self._branch_format @@ -1695,6 +1850,7 @@ def _supply_sub_formats_to(self, other_format): def __get_workingtree_format(self): if self._workingtree_format is None: from .workingtree import format_registry as wt_format_registry + self._workingtree_format = wt_format_registry.get_default() return self._workingtree_format @@ -1704,8 +1860,7 @@ def __set_workingtree_format(self, wt_format): def __repr__(self): return f"<{self.__class__.__name__!r}>" - workingtree_format = property(__get_workingtree_format, - __set_workingtree_format) + workingtree_format = property(__get_workingtree_format, __set_workingtree_format) class BzrDirMetaFormat1Colo(BzrDirMetaFormat1): @@ -1748,7 +1903,7 @@ def convert(self, to_convert, pb): with ui.ui_factory.nested_progress_bar() as self.pb: self.count = 0 self.total = 1 - self.step('checking repository format') + self.step("checking repository format") try: repo = self.controldir.open_repository() except errors.NoRepositoryPresent: @@ -1757,11 +1912,13 @@ def convert(self, to_convert, pb): repo_fmt = self.target_format.repository_format if not isinstance(repo._format, repo_fmt.__class__): from ..repository import CopyConverter - ui.ui_factory.note(gettext('starting repository conversion')) + + ui.ui_factory.note(gettext("starting repository conversion")) if not repo_fmt.supports_overriding_transport: raise AssertionError( "Repository in metadir does not support " - "overriding transport") + "overriding transport" + ) converter = CopyConverter(self.target_format.repository_format) converter.convert(repo, pb) for branch in self.controldir.list_branches(): @@ -1771,21 +1928,26 @@ def convert(self, to_convert, pb): old = branch._format.__class__ new = self.target_format.get_branch_format().__class__ while old != new: - if (old == fullhistorybranch.BzrBranchFormat5 - and new in (_mod_bzrbranch.BzrBranchFormat6, - _mod_bzrbranch.BzrBranchFormat7, - _mod_bzrbranch.BzrBranchFormat8)): + if old == fullhistorybranch.BzrBranchFormat5 and new in ( + _mod_bzrbranch.BzrBranchFormat6, + _mod_bzrbranch.BzrBranchFormat7, + _mod_bzrbranch.BzrBranchFormat8, + ): branch_converter = _mod_bzrbranch.Converter5to6() - elif (old == _mod_bzrbranch.BzrBranchFormat6 - and new in (_mod_bzrbranch.BzrBranchFormat7, - _mod_bzrbranch.BzrBranchFormat8)): + elif old == _mod_bzrbranch.BzrBranchFormat6 and new in ( + _mod_bzrbranch.BzrBranchFormat7, + _mod_bzrbranch.BzrBranchFormat8, + ): branch_converter = _mod_bzrbranch.Converter6to7() - elif (old == _mod_bzrbranch.BzrBranchFormat7 - and new is _mod_bzrbranch.BzrBranchFormat8): + elif ( + old == _mod_bzrbranch.BzrBranchFormat7 + and new is _mod_bzrbranch.BzrBranchFormat8 + ): branch_converter = _mod_bzrbranch.Converter7to8() else: - raise errors.BadConversionTarget("No converter", new, - branch._format) + raise errors.BadConversionTarget( + "No converter", new, branch._format + ) branch_converter.convert(branch) branch = self.controldir.open_branch() old = branch._format.__class__ @@ -1796,20 +1958,32 @@ def convert(self, to_convert, pb): else: # TODO: conversions of Branch and Tree should be done by # InterXFormat lookups - if (isinstance(tree, workingtree_3.WorkingTree3) + if ( + isinstance(tree, workingtree_3.WorkingTree3) and not isinstance(tree, 
workingtree_4.DirStateWorkingTree) - and isinstance(self.target_format.workingtree_format, - workingtree_4.DirStateWorkingTreeFormat)): + and isinstance( + self.target_format.workingtree_format, + workingtree_4.DirStateWorkingTreeFormat, + ) + ): workingtree_4.Converter3to4().convert(tree) - if (isinstance(tree, workingtree_4.DirStateWorkingTree) + if ( + isinstance(tree, workingtree_4.DirStateWorkingTree) and not isinstance(tree, workingtree_4.WorkingTree5) - and isinstance(self.target_format.workingtree_format, - workingtree_4.WorkingTreeFormat5)): + and isinstance( + self.target_format.workingtree_format, + workingtree_4.WorkingTreeFormat5, + ) + ): workingtree_4.Converter4to5().convert(tree) - if (isinstance(tree, workingtree_4.DirStateWorkingTree) + if ( + isinstance(tree, workingtree_4.DirStateWorkingTree) and not isinstance(tree, workingtree_4.WorkingTree6) - and isinstance(self.target_format.workingtree_format, - workingtree_4.WorkingTreeFormat6)): + and isinstance( + self.target_format.workingtree_format, + workingtree_4.WorkingTreeFormat6, + ) + ): workingtree_4.Converter4or5to6().convert(tree) return to_convert @@ -1826,16 +2000,16 @@ def __init__(self, target_format): def convert(self, to_convert, pb): """See Converter.convert().""" - to_convert.transport.put_bytes('branch-format', - self.target_format.as_string()) + to_convert.transport.put_bytes("branch-format", self.target_format.as_string()) return BzrDir.open_from_transport(to_convert.root_transport) class CreateRepository(controldir.RepositoryAcquisitionPolicy): """A policy of creating a new repository.""" - def __init__(self, controldir, stack_on=None, stack_on_pwd=None, - require_stacking=False): + def __init__( + self, controldir, stack_on=None, stack_on_pwd=None, require_stacking=False + ): """Constructor. :param controldir: The controldir to create the repository on. @@ -1843,12 +2017,12 @@ def __init__(self, controldir, stack_on=None, stack_on_pwd=None, :param stack_on_pwd: If stack_on is relative, the location it is relative to. """ - super().__init__( - stack_on, stack_on_pwd, require_stacking) + super().__init__(stack_on, stack_on_pwd, require_stacking) self._controldir = controldir - def acquire_repository(self, make_working_trees=None, shared=False, - possible_transports=None): + def acquire_repository( + self, make_working_trees=None, shared=False, possible_transports=None + ): """Implementation of RepositoryAcquisitionPolicy.acquire_repository. Creates the desired repository in the controldir we already have. @@ -1861,15 +2035,18 @@ def acquire_repository(self, make_working_trees=None, shared=False, stack_on = self._get_full_stack_on() if stack_on: format = self._controldir._format - format.require_stacking(stack_on=stack_on, - possible_transports=possible_transports) + format.require_stacking( + stack_on=stack_on, possible_transports=possible_transports + ) if not self._require_stacking: # We have picked up automatic stacking somewhere. 
- note(gettext('Using default stacking branch {0} at {1}').format( - self._stack_on, self._stack_on_pwd)) + note( + gettext("Using default stacking branch {0} at {1}").format( + self._stack_on, self._stack_on_pwd + ) + ) repository = self._controldir.create_repository(shared=shared) - self._add_fallback(repository, - possible_transports=possible_transports) + self._add_fallback(repository, possible_transports=possible_transports) if make_working_trees is not None: repository.set_make_working_trees(make_working_trees) return repository, True @@ -1878,8 +2055,9 @@ def acquire_repository(self, make_working_trees=None, shared=False, class UseExistingRepository(controldir.RepositoryAcquisitionPolicy): """A policy of reusing an existing repository.""" - def __init__(self, repository, stack_on=None, stack_on_pwd=None, - require_stacking=False): + def __init__( + self, repository, stack_on=None, stack_on_pwd=None, require_stacking=False + ): """Constructor. :param repository: The repository to use. @@ -1887,12 +2065,12 @@ def __init__(self, repository, stack_on=None, stack_on_pwd=None, :param stack_on_pwd: If stack_on is relative, the location it is relative to. """ - super().__init__( - stack_on, stack_on_pwd, require_stacking) + super().__init__(stack_on, stack_on_pwd, require_stacking) self._repository = repository - def acquire_repository(self, make_working_trees=None, shared=False, - possible_transports=None): + def acquire_repository( + self, make_working_trees=None, shared=False, possible_transports=None + ): """Implementation of RepositoryAcquisitionPolicy.acquire_repository. Returns an existing repository to use. @@ -1902,8 +2080,7 @@ def acquire_repository(self, make_working_trees=None, shared=False, else: possible_transports = list(possible_transports) possible_transports.append(self._repository.controldir.transport) - self._add_fallback(self._repository, - possible_transports=possible_transports) + self._add_fallback(self._repository, possible_transports=possible_transports) return self._repository, False diff --git a/breezy/bzr/check.py b/breezy/bzr/check.py index 393ea32912..6c51a4017d 100644 --- a/breezy/bzr/check.py +++ b/breezy/bzr/check.py @@ -89,17 +89,17 @@ def check(self, callback_refs=None, check_repo=True): if callback_refs is None: callback_refs = {} with self.repository.lock_read(), ui.ui_factory.nested_progress_bar() as self.progress: - self.progress.update(gettext('check'), 0, 4) + self.progress.update(gettext("check"), 0, 4) if self.check_repo: - self.progress.update(gettext('checking revisions'), 0) + self.progress.update(gettext("checking revisions"), 0) self.check_revisions() - self.progress.update(gettext('checking commit contents'), 1) + self.progress.update(gettext("checking commit contents"), 1) self.repository._check_inventories(self) - self.progress.update(gettext('checking file graphs'), 2) + self.progress.update(gettext("checking file graphs"), 2) # check_weaves is done after the revision scan so that # revision index is known to be valid. self.check_weaves() - self.progress.update(gettext('checking branches and trees'), 3) + self.progress.update(gettext("checking branches and trees"), 3) if callback_refs: repo = self.repository # calculate all refs, and callback the objects requesting them. 
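The next hunk reformats the part of `check()` that resolves callback references. Read outside the diff, the flow is: every ref is a `(kind, value)` tuple, `trees` refs are answered immediately, while `lefthand-distance` and `revision-existence` refs are batched so the revision graph is queried once. A condensed sketch of that flow; `repo` and `graph` stand in for the repository and its graph object, and only the two graph lookups used in the hunk are assumed:

def resolve_refs(callback_refs, repo, graph):
    # callback_refs is an iterable of (kind, value) tuples, as in the hunk below.
    refs, distances, existences = {}, set(), set()
    for kind, value in callback_refs:
        if kind == "trees":
            refs[(kind, value)] = repo.revision_tree(value)
        elif kind == "lefthand-distance":
            distances.add(value)
        elif kind == "revision-existence":
            existences.add(value)
        else:
            raise AssertionError(f"unknown ref kind for ref {(kind, value)}")
    # Batch the graph queries: one distance query, one parent-map query.
    for key, distance in graph.find_lefthand_distances(distances).items():
        refs[("lefthand-distance", key)] = distance
        if key in existences and distance > 0:
            refs[("revision-existence", key)] = True
            existences.discard(key)
    for key in graph.get_parent_map(existences):
        refs[("revision-existence", key)] = True
        existences.discard(key)
    for key in existences:
        refs[("revision-existence", key)] = False
    return refs
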
@@ -115,27 +115,26 @@ def check(self, callback_refs=None, check_repo=True): for ref, wantlist in callback_refs.items(): wanting_items.update(wantlist) kind, value = ref - if kind == 'trees': + if kind == "trees": refs[ref] = repo.revision_tree(value) - elif kind == 'lefthand-distance': + elif kind == "lefthand-distance": distances.add(value) - elif kind == 'revision-existence': + elif kind == "revision-existence": existences.add(value) else: - raise AssertionError( - f'unknown ref kind for ref {ref}') + raise AssertionError(f"unknown ref kind for ref {ref}") node_distances = repo.get_graph().find_lefthand_distances(distances) for key, distance in node_distances.items(): - refs[('lefthand-distance', key)] = distance + refs[("lefthand-distance", key)] = distance if key in existences and distance > 0: - refs[('revision-existence', key)] = True + refs[("revision-existence", key)] = True existences.remove(key) parent_map = repo.get_graph().get_parent_map(existences) for key in parent_map: - refs[('revision-existence', key)] = True + refs[("revision-existence", key)] = True existences.remove(key) for key in existences: - refs[('revision-existence', key)] = False + refs[("revision-existence", key)] = False for item in wanting_items: if isinstance(item, WorkingTree): item._check(refs) @@ -160,7 +159,8 @@ def _check_revisions(self, revisions_iterator): def check_revisions(self): """Scan revisions, checking data directly available as we go.""" revision_iterator = self.repository.iter_revisions( - self.repository.all_revision_ids()) + self.repository.all_revision_ids() + ) revision_iterator = self._check_revisions(revision_iterator) # We read the all revisions here: # - doing this allows later code to depend on the revision index. @@ -174,7 +174,8 @@ def check_revisions(self): pass else: bad_revisions = self.repository._find_inconsistent_revision_parents( - revision_iterator) + revision_iterator + ) self.revs_with_bad_parents_in_index = list(bad_revisions) def report_results(self, verbose): @@ -184,62 +185,87 @@ def report_results(self, verbose): result.report_results(verbose) def _report_repo_results(self, verbose): - note(gettext('checked repository {0} format {1}').format( - self.repository.user_url, - self.repository._format)) - note(gettext('%6d revisions'), self.checked_rev_cnt) - note(gettext('%6d file-ids'), len(self.checked_weaves)) + note( + gettext("checked repository {0} format {1}").format( + self.repository.user_url, self.repository._format + ) + ) + note(gettext("%6d revisions"), self.checked_rev_cnt) + note(gettext("%6d file-ids"), len(self.checked_weaves)) if verbose: - note(gettext('%6d unreferenced text versions'), - len(self.unreferenced_versions)) + note( + gettext("%6d unreferenced text versions"), + len(self.unreferenced_versions), + ) if verbose and len(self.unreferenced_versions): for file_id, revision_id in self.unreferenced_versions: - note(gettext('unreferenced version: {{{0}}} in {1}').format( - revision_id.decode('utf-8'), file_id.decode('utf-8'))) + note( + gettext("unreferenced version: {{{0}}} in {1}").format( + revision_id.decode("utf-8"), file_id.decode("utf-8") + ) + ) if self.missing_inventory_sha_cnt: - note(gettext('%6d revisions are missing inventory_sha1'), - self.missing_inventory_sha_cnt) + note( + gettext("%6d revisions are missing inventory_sha1"), + self.missing_inventory_sha_cnt, + ) if self.missing_revision_cnt: - note(gettext('%6d revisions are mentioned but not present'), - self.missing_revision_cnt) + note( + gettext("%6d revisions are mentioned but 
not present"), + self.missing_revision_cnt, + ) if len(self.ghosts): - note(gettext('%6d ghost revisions'), len(self.ghosts)) + note(gettext("%6d ghost revisions"), len(self.ghosts)) if verbose: for ghost in self.ghosts: - note(' %s', ghost.decode('utf-8')) + note(" %s", ghost.decode("utf-8")) if len(self.missing_parent_links): - note(gettext('%6d revisions missing parents in ancestry'), - len(self.missing_parent_links)) + note( + gettext("%6d revisions missing parents in ancestry"), + len(self.missing_parent_links), + ) if verbose: for link, linkers in self.missing_parent_links.items(): - note(gettext(' %s should be in the ancestry for:'), - link.decode('utf-8')) + note( + gettext(" %s should be in the ancestry for:"), + link.decode("utf-8"), + ) for linker in linkers: - note(' * %s', linker.decode('utf-8')) + note(" * %s", linker.decode("utf-8")) if len(self.inconsistent_parents): - note(gettext('%6d inconsistent parents'), - len(self.inconsistent_parents)) + note(gettext("%6d inconsistent parents"), len(self.inconsistent_parents)) if verbose: for info in self.inconsistent_parents: revision_id, file_id, found_parents, correct_parents = info - note(gettext(' * {0} version {1} has parents ({2}) ' - 'but should have ({3})').format( - file_id.decode('utf-8'), revision_id.decode('utf-8'), - ', '.join(p.decode('utf-8') for p in found_parents), - ', '.join(p.decode('utf-8') for p in correct_parents))) + note( + gettext( + " * {0} version {1} has parents ({2}) " + "but should have ({3})" + ).format( + file_id.decode("utf-8"), + revision_id.decode("utf-8"), + ", ".join(p.decode("utf-8") for p in found_parents), + ", ".join(p.decode("utf-8") for p in correct_parents), + ) + ) if self.revs_with_bad_parents_in_index: - note(gettext( - '%6d revisions have incorrect parents in the revision index'), - len(self.revs_with_bad_parents_in_index)) + note( + gettext("%6d revisions have incorrect parents in the revision index"), + len(self.revs_with_bad_parents_in_index), + ) if verbose: for item in self.revs_with_bad_parents_in_index: revision_id, index_parents, actual_parents = item - note(gettext( - ' {0} has wrong parents in index: ' - '({1}) should be ({2})').format( - revision_id.decode('utf-8'), - ', '.join(p.decode('utf-8') for p in index_parents), - ', '.join(p.decode('utf-8') for p in actual_parents))) + note( + gettext( + " {0} has wrong parents in index: " + "({1}) should be ({2})" + ).format( + revision_id.decode("utf-8"), + ", ".join(p.decode("utf-8") for p in index_parents), + ", ".join(p.decode("utf-8") for p in actual_parents), + ) + ) for item in self._report_items: note(item) @@ -250,9 +276,11 @@ def _check_one_rev(self, rev_id, rev): :param rev: A revision or None to indicate a missing revision. """ if rev.revision_id != rev_id: - self._report_items.append(gettext( - 'Mismatched internal revid {{{0}}} and index revid {{{1}}}').format( - rev.revision_id.decode('utf-8'), rev_id.decode('utf-8'))) + self._report_items.append( + gettext( + "Mismatched internal revid {{{0}}} and index revid {{{1}}}" + ).format(rev.revision_id.decode("utf-8"), rev_id.decode("utf-8")) + ) rev_id = rev.revision_id # Check this revision tree etc, and count as seen when we encounter a # reference to it. 
@@ -265,8 +293,9 @@ def _check_one_rev(self, rev_id, rev): self.ghosts.add(parent) self.ancestors[rev_id] = tuple(rev.parent_ids) or (NULL_REVISION,) - self.add_pending_item(rev_id, ('inventories', rev_id), 'inventory', - rev.inventory_sha1) + self.add_pending_item( + rev_id, ("inventories", rev_id), "inventory", rev.inventory_sha1 + ) self.checked_rev_cnt += 1 def add_pending_item(self, referer, key, kind, sha1): @@ -280,9 +309,12 @@ def add_pending_item(self, referer, key, kind, sha1): existing = self.pending_keys.get(key) if existing: if sha1 != existing[1]: - self._report_items.append(gettext('Multiple expected sha1s for {0}. {{{1}}}' - ' expects {{{2}}}, {{{3}}} expects {{{4}}}').format( - key, referer, sha1, existing[1], existing[0])) + self._report_items.append( + gettext( + "Multiple expected sha1s for {0}. {{{1}}}" + " expects {{{2}}}, {{{3}}} expects {{{4}}}" + ).format(key, referer, sha1, existing[1], existing[0]) + ) else: self.pending_keys[key] = (kind, sha1, referer) @@ -292,18 +324,20 @@ def check_weaves(self): self._check_weaves(storebar) def _check_weaves(self, storebar): - storebar.update('text-index', 0, 2) + storebar.update("text-index", 0, 2) if self.repository._format.fast_deltas: # We haven't considered every fileid instance so far. weave_checker = self.repository._get_versioned_file_checker( - ancestors=self.ancestors) + ancestors=self.ancestors + ) else: weave_checker = self.repository._get_versioned_file_checker( - text_key_references=self.text_key_references, - ancestors=self.ancestors) - storebar.update('file-graph', 1) + text_key_references=self.text_key_references, ancestors=self.ancestors + ) + storebar.update("file-graph", 1) wrongs, unused_versions = weave_checker.check_file_version_parents( - self.repository.texts) + self.repository.texts + ) self.checked_weaves = weave_checker.file_ids for text_key, (stored_parents, correct_parents) in wrongs.items(): # XXX not ready for id join/split operations. @@ -312,11 +346,12 @@ def _check_weaves(self, storebar): weave_parents = tuple([parent[-1] for parent in stored_parents]) correct_parents = tuple([parent[-1] for parent in correct_parents]) self.inconsistent_parents.append( - (revision_id, weave_id, weave_parents, correct_parents)) + (revision_id, weave_id, weave_parents, correct_parents) + ) self.unreferenced_versions.update(unused_versions) def _add_entry_to_text_key_references(self, inv, entry): - if not self.rich_roots and entry.name == '': + if not self.rich_roots and entry.name == "": return key = (entry.file_id, entry.revision) self.text_key_references.setdefault(key, False) diff --git a/breezy/bzr/chk_map.py b/breezy/bzr/chk_map.py index e781251724..368a0dd9f3 100644 --- a/breezy/bzr/chk_map.py +++ b/breezy/bzr/chk_map.py @@ -62,7 +62,7 @@ def _get_cache(): We need a function to do this because in a new thread the _thread_caches threading.local object does not have the cache initialized yet. 
""" - page_cache = getattr(_thread_caches, 'page_cache', None) + page_cache = getattr(_thread_caches, "page_cache", None) if page_cache is None: # We are caching bytes so len(value) is perfectly accurate page_cache = lru_cache.LRUSizeCache(_PAGE_CACHE_SIZE) @@ -82,17 +82,17 @@ def clear_cache(): def _search_key_plain(key): """Map the key tuple into a search string that just uses the key bytes.""" - return b'\x00'.join(key) + return b"\x00".join(key) search_key_registry = registry.Registry[bytes, Callable[[bytes], bytes], None]() -search_key_registry.register(b'plain', _search_key_plain) +search_key_registry.register(b"plain", _search_key_plain) class CHKMap: """A persistent map from string to string backed by a CHK store.""" - __slots__ = ('_store', '_root_node', '_search_key_func') + __slots__ = ("_store", "_root_node", "_search_key_func") def __init__(self, store, root_key, search_key_func=None): """Create a CHKMap object. @@ -124,12 +124,14 @@ def apply_delta(self, delta): has_deletes = False # Check preconditions first. as_st = StaticTuple.from_sequence - new_items = {as_st(key) for (old, key, value) in delta - if key is not None and old is None} + new_items = { + as_st(key) for (old, key, value) in delta if key is not None and old is None + } existing_new = list(self.iteritems(key_filter=new_items)) if existing_new: - raise errors.InconsistentDeltaDelta(delta, - f"New items are already in the map {existing_new!r}.") + raise errors.InconsistentDeltaDelta( + delta, f"New items are already in the map {existing_new!r}." + ) # Now apply changes. for old, new, _value in delta: if old is not None and old != new: @@ -160,8 +162,7 @@ def _get_node(self, node): """ if isinstance(node, StaticTuple): bytes = self._read_bytes(node) - return _deserialise(bytes, node, - search_key_func=self._search_key_func) + return _deserialise(bytes, node, search_key_func=self._search_key_func) else: return node @@ -169,49 +170,66 @@ def _read_bytes(self, key): try: return _get_cache()[key] except KeyError: - stream = self._store.get_record_stream([key], 'unordered', True) - bytes = next(stream).get_bytes_as('fulltext') + stream = self._store.get_record_stream([key], "unordered", True) + bytes = next(stream).get_bytes_as("fulltext") _get_cache()[key] = bytes return bytes - def _dump_tree(self, include_keys=False, encoding='utf-8'): + def _dump_tree(self, include_keys=False, encoding="utf-8"): """Return the tree in a string representation.""" self._ensure_root() - def decode(x): return x.decode(encoding) - res = self._dump_tree_node(self._root_node, prefix=b'', indent='', - decode=decode, include_keys=include_keys) - res.append('') # Give a trailing '\n' - return '\n'.join(res) + + def decode(x): + return x.decode(encoding) + + res = self._dump_tree_node( + self._root_node, + prefix=b"", + indent="", + decode=decode, + include_keys=include_keys, + ) + res.append("") # Give a trailing '\n' + return "\n".join(res) def _dump_tree_node(self, node, prefix, indent, decode, include_keys=True): """For this node and all children, generate a string representation.""" result = [] if not include_keys: - key_str = '' + key_str = "" else: node_key = node.key() if node_key is not None: - key_str = f' {decode(node_key[0])}' + key_str = f" {decode(node_key[0])}" else: - key_str = ' None' - result.append(f'{indent}{decode(prefix)!r} {node.__class__.__name__}{key_str}') + key_str = " None" + result.append(f"{indent}{decode(prefix)!r} {node.__class__.__name__}{key_str}") if isinstance(node, InternalNode): # Trigger all child nodes to 
get loaded list(node._iter_nodes(self._store)) for prefix, sub in sorted(node._items.items()): - result.extend(self._dump_tree_node(sub, prefix, indent + ' ', - decode=decode, include_keys=include_keys)) + result.extend( + self._dump_tree_node( + sub, + prefix, + indent + " ", + decode=decode, + include_keys=include_keys, + ) + ) else: for key, value in sorted(node._items.items()): # Don't use prefix nor indent here to line up when used in # tests in conjunction with assertEqualDiff - result.append(' {!r} {!r}'.format( - tuple([decode(ke) for ke in key]), decode(value))) + result.append( + f" {tuple([decode(ke) for ke in key])!r} {decode(value)!r}" + ) return result @classmethod - def from_dict(klass, store, initial_value, maximum_size=0, key_width=1, - search_key_func=None): + def from_dict( + klass, store, initial_value, maximum_size=0, key_width=1, search_key_func=None + ): """Create a CHKMap in store with initial_value as the content. :param store: The store to record initial_value in, a VersionedFiles @@ -227,16 +245,21 @@ def from_dict(klass, store, initial_value, maximum_size=0, key_width=1, multiple pages. :return: The root chk of the resulting CHKMap. """ - root_key = klass._create_directly(store, initial_value, - maximum_size=maximum_size, key_width=key_width, - search_key_func=search_key_func) + root_key = klass._create_directly( + store, + initial_value, + maximum_size=maximum_size, + key_width=key_width, + search_key_func=search_key_func, + ) if not isinstance(root_key, StaticTuple): - raise AssertionError(f'we got a {type(root_key)} instead of a StaticTuple') + raise AssertionError(f"we got a {type(root_key)} instead of a StaticTuple") return root_key @classmethod - def _create_via_map(klass, store, initial_value, maximum_size=0, - key_width=1, search_key_func=None): + def _create_via_map( + klass, store, initial_value, maximum_size=0, key_width=1, search_key_func=None + ): result = klass(store, None, search_key_func=search_key_func) result._root_node.set_maximum_size(maximum_size) result._root_node._key_width = key_width @@ -247,25 +270,24 @@ def _create_via_map(klass, store, initial_value, maximum_size=0, return root_key @classmethod - def _create_directly(klass, store, initial_value, maximum_size=0, - key_width=1, search_key_func=None): + def _create_directly( + klass, store, initial_value, maximum_size=0, key_width=1, search_key_func=None + ): node = LeafNode(search_key_func=search_key_func) node.set_maximum_size(maximum_size) node._key_width = key_width as_st = StaticTuple.from_sequence - node._items = {as_st(key): val - for key, val in initial_value.items()} - node._raw_size = sum(node._key_value_len(key, value) - for key, value in node._items.items()) + node._items = {as_st(key): val for key, val in initial_value.items()} + node._raw_size = sum( + node._key_value_len(key, value) for key, value in node._items.items() + ) node._len = len(node._items) node._compute_search_prefix() node._compute_serialised_prefix() - if (node._len > 1 and - maximum_size and - node._current_size() > maximum_size): + if node._len > 1 and maximum_size and node._current_size() > maximum_size: prefix, node_details = node._split(store) if len(node_details) == 1: - raise AssertionError('Failed to split using node._split') + raise AssertionError("Failed to split using node._split") node = InternalNode(prefix, search_key_func=search_key_func) node.set_maximum_size(maximum_size) node._key_width = key_width @@ -353,22 +375,22 @@ def process_common_leaf_nodes(self_node, basis_node): prefix = 
basis._search_key_func(key) heapq.heappush(basis_pending, (prefix, key, value, path)) - def process_common_prefix_nodes(self_node, self_path, - basis_node, basis_path): + def process_common_prefix_nodes(self_node, self_path, basis_node, basis_path): # Would it be more efficient if we could request both at the same # time? self_node = self._get_node(self_node) basis_node = basis._get_node(basis_node) - if (isinstance(self_node, InternalNode) and - isinstance(basis_node, InternalNode)): + if isinstance(self_node, InternalNode) and isinstance( + basis_node, InternalNode + ): # Matching internal nodes process_common_internal_nodes(self_node, basis_node) - elif (isinstance(self_node, LeafNode) and - isinstance(basis_node, LeafNode)): + elif isinstance(self_node, LeafNode) and isinstance(basis_node, LeafNode): process_common_leaf_nodes(self_node, basis_node) else: process_node(self_node, self_path, self, self_pending) process_node(basis_node, basis_path, basis, basis_pending) + process_common_prefix_nodes(self_node, None, basis_node, None) excluded_keys = set() @@ -460,12 +482,12 @@ def check_excluded(key_path): self_details = heapq.heappop(self_pending) basis_details = heapq.heappop(basis_pending) if self_details[2] != basis_details[2]: - yield (self_details[1], - basis_details[2], self_details[2]) + yield (self_details[1], basis_details[2], self_details[2]) continue # At least one side wasn't a simple value - if (self._node_key(self_pending[0][2]) - == self._node_key(basis_pending[0][2])): + if self._node_key(self_pending[0][2]) == self._node_key( + basis_pending[0][2] + ): # Identical pointers, skip (and don't bother adding to # excluded, it won't turn up again. heapq.heappop(self_pending) @@ -476,15 +498,16 @@ def check_excluded(key_path): # Both sides start with the same prefix, so process # them in parallel self_prefix, _, self_node, self_path = heapq.heappop( - self_pending) + self_pending + ) basis_prefix, _, basis_node, basis_path = heapq.heappop( - basis_pending) + basis_pending + ) if self_prefix != basis_prefix: - raise AssertionError( - f'{self_prefix!r} != {basis_prefix!r}') + raise AssertionError(f"{self_prefix!r} != {basis_prefix!r}") process_common_prefix_nodes( - self_node, self_path, - basis_node, basis_path) + self_node, self_path, basis_node, basis_path + ) continue if read_self: prefix, key, node, path = heapq.heappop(self_pending) @@ -530,8 +553,9 @@ def map(self, key, value): if len(node_details) == 1: self._root_node = node_details[0][1] else: - self._root_node = InternalNode(prefix, - search_key_func=self._search_key_func) + self._root_node = InternalNode( + prefix, search_key_func=self._search_key_func + ) self._root_node.set_maximum_size(node_details[0][1].maximum_size) self._root_node._key_width = node_details[0][1]._key_width for split, node in node_details: @@ -551,8 +575,7 @@ def unmap(self, key, check_remap=True): key = StaticTuple.from_sequence(key) self._ensure_root() if isinstance(self._root_node, InternalNode): - unmapped = self._root_node.unmap(self._store, key, - check_remap=check_remap) + unmapped = self._root_node.unmap(self._store, key, check_remap=check_remap) else: unmapped = self._root_node.unmap(self._store, key) self._root_node = unmapped @@ -582,9 +605,16 @@ class Node: adding the header bytes, and without prefix compression. 
""" - __slots__ = ('_key', '_len', '_maximum_size', '_key_width', - '_raw_size', '_items', '_search_prefix', '_search_key_func' - ) + __slots__ = ( + "_key", + "_len", + "_maximum_size", + "_key_width", + "_raw_size", + "_items", + "_search_prefix", + "_search_key_func", + ) def __init__(self, key_width=1): """Create a node. @@ -606,10 +636,16 @@ def __init__(self, key_width=1): def __repr__(self): items_str = str(sorted(self._items)) if len(items_str) > 20: - items_str = items_str[:16] + '...]' - return '{}(key:{} len:{} size:{} max:{} prefix:{} items:{})'.format( - self.__class__.__name__, self._key, self._len, self._raw_size, - self._maximum_size, self._search_prefix, items_str) + items_str = items_str[:16] + "...]" + return "{}(key:{} len:{} size:{} max:{} prefix:{} items:{})".format( + self.__class__.__name__, + self._key, + self._len, + self._raw_size, + self._maximum_size, + self._search_prefix, + items_str, + ) def key(self): return self._key @@ -646,7 +682,7 @@ def common_prefix(cls, prefix, key): if left != right: pos -= 1 break - common = prefix[:pos + 1] + common = prefix[: pos + 1] return common @classmethod @@ -665,7 +701,7 @@ def common_prefix_for_keys(cls, keys): if not common_prefix: # if common_prefix is the empty string, then we know it won't # change further - return b'' + return b"" return common_prefix @@ -681,7 +717,7 @@ class LeafNode(Node): the key/value pairs. """ - __slots__ = ('_common_serialised_prefix',) + __slots__ = ("_common_serialised_prefix",) def __init__(self, search_key_func=None): Node.__init__(self) @@ -695,10 +731,17 @@ def __init__(self, search_key_func=None): def __repr__(self): items_str = str(sorted(self._items)) if len(items_str) > 20: - items_str = items_str[:16] + '...]' - return \ - '{}(key:{} len:{} size:{} max:{} prefix:{} keywidth:{} items:{})'.format(self.__class__.__name__, self._key, self._len, self._raw_size, - self._maximum_size, self._search_prefix, self._key_width, items_str) + items_str = items_str[:16] + "...]" + return "{}(key:{} len:{} size:{} max:{} prefix:{} keywidth:{} items:{})".format( + self.__class__.__name__, + self._key, + self._len, + self._raw_size, + self._maximum_size, + self._search_prefix, + self._key_width, + items_str, + ) def _current_size(self): """Answer the current serialised size of this node. @@ -714,13 +757,19 @@ def _current_size(self): # And then that common prefix will not be stored in any of the # entry lines prefix_len = len(self._common_serialised_prefix) - bytes_for_items = (self._raw_size - (prefix_len * self._len)) - return (9 + # 'chkleaf:\n' + - len(str(self._maximum_size)) + 1 + - len(str(self._key_width)) + 1 + - len(str(self._len)) + 1 + - prefix_len + 1 + - bytes_for_items) + bytes_for_items = self._raw_size - (prefix_len * self._len) + return ( + 9 # 'chkleaf:\n' + + + len(str(self._maximum_size)) + + 1 + + len(str(self._key_width)) + + 1 + + len(str(self._len)) + + 1 + + prefix_len + + 1 + + bytes_for_items + ) @classmethod def deserialise(klass, bytes, key, search_key_func=None): @@ -730,8 +779,7 @@ def deserialise(klass, bytes, key, search_key_func=None): :param key: The key that the serialised node has. """ key = expect_static_tuple(key) - return _deserialise_leaf_node(bytes, key, - search_key_func=search_key_func) + return _deserialise_leaf_node(bytes, key, search_key_func=search_key_func) def iteritems(self, store, key_filter=None): """Iterate over items in the node. 
@@ -769,9 +817,14 @@ def iteritems(self, store, key_filter=None): def _key_value_len(self, key, value): # TODO: Should probably be done without actually joining the key, but # then that can be done via the C extension - return (len(self._serialise_key(key)) + 1 + - len(b'%d' % value.count(b'\n')) + 1 + - len(value) + 1) + return ( + len(self._serialise_key(key)) + + 1 + + len(b"%d" % value.count(b"\n")) + + 1 + + len(value) + + 1 + ) def _search_key(self, key): return self._search_key_func(key) @@ -792,23 +845,27 @@ def _map_no_split(self, key, value): self._common_serialised_prefix = serialised_key else: self._common_serialised_prefix = self.common_prefix( - self._common_serialised_prefix, serialised_key) + self._common_serialised_prefix, serialised_key + ) search_key = self._search_key(key) if self._search_prefix is _unknown: self._compute_search_prefix() if self._search_prefix is None: self._search_prefix = search_key else: - self._search_prefix = self.common_prefix( - self._search_prefix, search_key) - if (self._len > 1 and - self._maximum_size and - self._current_size() > self._maximum_size): + self._search_prefix = self.common_prefix(self._search_prefix, search_key) + if ( + self._len > 1 + and self._maximum_size + and self._current_size() > self._maximum_size + ): # Check to see if all of the search_keys for this node are # identical. We allow the node to grow under that circumstance # (we could track this as common state, but it is infrequent) - if (search_key != self._search_prefix or - not self._are_search_keys_identical()): + if ( + search_key != self._search_prefix + or not self._are_search_keys_identical() + ): return True return False @@ -821,7 +878,7 @@ def _split(self, store): :return: (common_serialised_prefix, [(node_serialised_prefix, node)]) """ if self._search_prefix is _unknown: - raise AssertionError('Search prefix must be known') + raise AssertionError("Search prefix must be known") common_prefix = self._search_prefix split_at = len(common_prefix) + 1 result = {} @@ -837,7 +894,7 @@ def _split(self, store): # may get a '\00' node anywhere, but won't have keys of # different lengths. if len(prefix) < split_at: - prefix += b'\x00' * (split_at - len(prefix)) + prefix += b"\x00" * (split_at - len(prefix)) if prefix not in result: node = LeafNode(search_key_func=self._search_key_func) node.set_maximum_size(self._maximum_size) @@ -851,8 +908,9 @@ def _split(self, store): # This node has been split and is now found via a different # path result.pop(prefix) - new_node = InternalNode(sub_prefix, - search_key_func=self._search_key_func) + new_node = InternalNode( + sub_prefix, search_key_func=self._search_key_func + ) new_node.set_maximum_size(self._maximum_size) new_node._key_width = self._key_width for split, node in node_details: @@ -870,10 +928,10 @@ def map(self, store, key, value): return self._split(store) else: if self._search_prefix is _unknown: - raise AssertionError(f'{self._search_prefix!r} must be known') + raise AssertionError(f"{self._search_prefix!r} must be known") return self._search_prefix, [(b"", self)] - _serialise_key = b'\x00'.join + _serialise_key = b"\x00".join def serialise(self, store): """Serialise the LeafNode to store. 
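The `_current_size()` arithmetic above and the `serialise()` hunk below describe the same on-disk shape: a `chkleaf:` page is a small header (maximum size, key width, item count, common serialised-key prefix) followed by one `key\x00linecount` entry per item, prefix-compressed, with the value's lines after it. A simplified sketch of writing that layout; it skips the store, the page cache and the `chunks_to_lines` edge cases, so it is an illustration of the format rather than a drop-in serializer:

import os


def serialise_leaf(items, maximum_size=4096, key_width=1):
    """`items` maps key tuples of bytes, e.g. (b"file-id", b"rev-id"), to byte values."""
    serialise_key = b"\x00".join
    serialised = {key: serialise_key(key) for key in items}
    # Longest common prefix of the serialised keys; stored once in the header.
    common = os.path.commonprefix(list(serialised.values())) if serialised else b""
    lines = [
        b"chkleaf:\n",
        b"%d\n" % maximum_size,
        b"%d\n" % key_width,
        b"%d\n" % len(items),
        common + b"\n",
    ]
    for key, value in sorted(items.items()):
        value_lines = (value + b"\n").splitlines(keepends=True)
        entry = b"%s\x00%d\n" % (serialised[key], len(value_lines))
        lines.append(entry[len(common):])  # prefix compression of the key line
        lines.extend(value_lines)
    return b"".join(lines)


page = serialise_leaf({(b"file-id", b"rev-id"): b"sha1 and size"})
assert page.startswith(b"chkleaf:\n")
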
@@ -886,28 +944,32 @@ def serialise(self, store): lines.append(b"%d\n" % self._key_width) lines.append(b"%d\n" % self._len) if self._common_serialised_prefix is None: - lines.append(b'\n') + lines.append(b"\n") if len(self._items) != 0: - raise AssertionError('If _common_serialised_prefix is None' - ' we should have no items') + raise AssertionError( + "If _common_serialised_prefix is None" " we should have no items" + ) else: - lines.append(b'%s\n' % (self._common_serialised_prefix,)) + lines.append(b"%s\n" % (self._common_serialised_prefix,)) prefix_len = len(self._common_serialised_prefix) for key, value in sorted(self._items.items()): # Always add a final newline - value_lines = osutils.chunks_to_lines([value + b'\n']) - serialized = b"%s\x00%d\n" % (self._serialise_key(key), - len(value_lines)) + value_lines = osutils.chunks_to_lines([value + b"\n"]) + serialized = b"%s\x00%d\n" % (self._serialise_key(key), len(value_lines)) if not serialized.startswith(self._common_serialised_prefix): - raise AssertionError(f'We thought the common prefix was {self._common_serialised_prefix!r}' - f' but entry {serialized!r} does not have it in common') + raise AssertionError( + f"We thought the common prefix was {self._common_serialised_prefix!r}" + f" but entry {serialized!r} does not have it in common" + ) lines.append(serialized[prefix_len:]) lines.extend(value_lines) sha1, _, _ = store.add_lines((None,), (), lines) - self._key = StaticTuple(b"sha1:" + sha1,).intern() - data = b''.join(lines) + self._key = StaticTuple( + b"sha1:" + sha1, + ).intern() + data = b"".join(lines) if len(data) != self._current_size(): - raise AssertionError('Invalid _current_size') + raise AssertionError("Invalid _current_size") _get_cache()[self._key] = data return [self._key] @@ -948,8 +1010,7 @@ def _compute_serialised_prefix(self): unique within this node. """ serialised_keys = [self._serialise_key(key) for key in self._items] - self._common_serialised_prefix = self.common_prefix_for_keys( - serialised_keys) + self._common_serialised_prefix = self.common_prefix_for_keys(serialised_keys) return self._common_serialised_prefix def unmap(self, store, key): @@ -978,9 +1039,9 @@ class InternalNode(Node): LeafNode or InternalNode. """ - __slots__ = ('_node_width',) + __slots__ = ("_node_width",) - def __init__(self, prefix=b'', search_key_func=None): + def __init__(self, prefix=b"", search_key_func=None): Node.__init__(self) # The size of an internalnode with default values and no children. # How many octets key prefixes within this node are. 
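The `add_node()` hunk that follows enforces the invariant that makes internal-node lookups cheap: each child is registered under the parent's search prefix plus exactly one more byte, and `_node_width` is the length of those child prefixes. Routing a key is then just a slice of its (padded) search key. A toy illustration of that routing, with made-up child nodes; the real code splits upward instead of raising when a key does not fit:

def route(search_prefix: bytes, children: dict, search_key: bytes):
    """`children` maps prefixes of length len(search_prefix) + 1 to child nodes.
    Real nodes pad short search keys with NUL bytes up to the node width."""
    node_width = len(search_prefix) + 1
    padded = (search_key + b"\x00" * node_width)[:node_width]
    if not padded.startswith(search_prefix):
        raise KeyError("key does not belong under this node")
    return children.get(padded)


children = {b"ab": "leaf-ab", b"ac": "leaf-ac"}  # parent search prefix is b"a"
assert route(b"a", children, b"abXYZ") == "leaf-ab"
assert route(b"a", children, b"adXYZ") is None
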
@@ -1000,23 +1061,33 @@ def add_node(self, prefix, node): if self._search_prefix is None: raise AssertionError("_search_prefix should not be None") if not prefix.startswith(self._search_prefix): - raise AssertionError(f"prefixes mismatch: {prefix} must start with {self._search_prefix}") + raise AssertionError( + f"prefixes mismatch: {prefix} must start with {self._search_prefix}" + ) if len(prefix) != len(self._search_prefix) + 1: - raise AssertionError("prefix wrong length: len(%s) is not %d" % - (prefix, len(self._search_prefix) + 1)) + raise AssertionError( + "prefix wrong length: len(%s) is not %d" + % (prefix, len(self._search_prefix) + 1) + ) self._len += len(node) if not len(self._items): self._node_width = len(prefix) if self._node_width != len(self._search_prefix) + 1: - raise AssertionError("node width mismatch: %d is not %d" % - (self._node_width, len(self._search_prefix) + 1)) + raise AssertionError( + "node width mismatch: %d is not %d" + % (self._node_width, len(self._search_prefix) + 1) + ) self._items[prefix] = node self._key = None def _current_size(self): """Answer the current serialised size of this node.""" - return (self._raw_size + len(str(self._len)) + len(str(self._key_width)) - + len(str(self._maximum_size))) + return ( + self._raw_size + + len(str(self._len)) + + len(str(self._key_width)) + + len(str(self._maximum_size)) + ) @classmethod def deserialise(klass, bytes, key, search_key_func=None): @@ -1027,8 +1098,7 @@ def deserialise(klass, bytes, key, search_key_func=None): :return: An InternalNode instance. """ key = expect_static_tuple(key) - return _deserialise_internal_node(bytes, key, - search_key_func=search_key_func) + return _deserialise_internal_node(bytes, key, search_key_func=search_key_func) def iteritems(self, store, key_filter=None): for node, node_filter in self._iter_nodes(store, key_filter=key_filter): @@ -1105,13 +1175,11 @@ def _iter_nodes(self, store, key_filter=None, batch_size=None): length_filters = {} for key in key_filter: search_prefix = self._search_prefix_filter(key) - length_filter = length_filters.setdefault( - len(search_prefix), set()) + length_filter = length_filters.setdefault(len(search_prefix), set()) length_filter.add(search_prefix) prefix_to_keys.setdefault(search_prefix, []).append(key) - if (self._node_width in length_filters and - len(length_filters) == 1): + if self._node_width in length_filters and len(length_filters) == 1: # all of the search prefixes match exactly _node_width. This # means that everything is an exact match, and we can do a # lookup into self._items, rather than iterating over the items @@ -1152,8 +1220,9 @@ def _iter_nodes(self, store, key_filter=None, batch_size=None): except KeyError: continue else: - node = _deserialise(bytes, key, - search_key_func=self._search_key_func) + node = _deserialise( + bytes, key, search_key_func=self._search_key_func + ) prefix, node_key_filter = keys[key] self._items[prefix] = node found_keys.add(key) @@ -1167,15 +1236,16 @@ def _iter_nodes(self, store, key_filter=None, batch_size=None): batch_size = len(keys) key_order = list(keys) for batch_start in range(0, len(key_order), batch_size): - batch = key_order[batch_start:batch_start + batch_size] + batch = key_order[batch_start : batch_start + batch_size] # We have to fully consume the stream so there is no pending # I/O, so we buffer the nodes for now. 
- stream = store.get_record_stream(batch, 'unordered', True) + stream = store.get_record_stream(batch, "unordered", True) node_and_filters = [] for record in stream: - bytes = record.get_bytes_as('fulltext') - node = _deserialise(bytes, record.key, - search_key_func=self._search_key_func) + bytes = record.get_bytes_as("fulltext") + node = _deserialise( + bytes, record.key, search_key_func=self._search_key_func + ) prefix, node_key_filter = keys[record.key] node_and_filters.append((node, node_key_filter)) self._items[prefix] = node @@ -1188,23 +1258,21 @@ def map(self, store, key, value): raise AssertionError("can't map in an empty InternalNode.") search_key = self._search_key(key) if self._node_width != len(self._search_prefix) + 1: - raise AssertionError("node width mismatch: %d is not %d" % - (self._node_width, len(self._search_prefix) + 1)) + raise AssertionError( + "node width mismatch: %d is not %d" + % (self._node_width, len(self._search_prefix) + 1) + ) if not search_key.startswith(self._search_prefix): # This key doesn't fit in this index, so we need to split at the # point where it would fit, insert self into that internal node, # and then map this key into that node. - new_prefix = self.common_prefix(self._search_prefix, - search_key) - new_parent = InternalNode(new_prefix, - search_key_func=self._search_key_func) + new_prefix = self.common_prefix(self._search_prefix, search_key) + new_parent = InternalNode(new_prefix, search_key_func=self._search_key_func) new_parent.set_maximum_size(self._maximum_size) new_parent._key_width = self._key_width - new_parent.add_node(self._search_prefix[:len(new_prefix) + 1], - self) + new_parent.add_node(self._search_prefix[: len(new_prefix) + 1], self) return new_parent.map(store, key, value) - children = [node for node, _ in self._iter_nodes( - store, key_filter=[key])] + children = [node for node, _ in self._iter_nodes(store, key_filter=[key])] if children: child = children[0] else: @@ -1239,15 +1307,20 @@ def map(self, store, key, value): # amount is over a configurable limit. new_size = child._current_size() shrinkage = old_size - new_size - if (shrinkage > 0 and new_size < _INTERESTING_NEW_SIZE or - shrinkage > _INTERESTING_SHRINKAGE_LIMIT): + if ( + shrinkage > 0 + and new_size < _INTERESTING_NEW_SIZE + or shrinkage > _INTERESTING_SHRINKAGE_LIMIT + ): trace.mutter( "checking remap as size shrunk by %d to be %d", - shrinkage, new_size) + shrinkage, + new_size, + ) new_node = self._check_remap(store) if new_node._search_prefix is None: raise AssertionError("_search_prefix should not be None") - return new_node._search_prefix, [(b'', new_node)] + return new_node._search_prefix, [(b"", new_node)] # child has overflown - create a new intermediate node. 
# XXX: This is where we might want to try and expand our depth # to refer to more bytes of every child (which would give us @@ -1290,7 +1363,7 @@ def serialise(self, store): lines.append(b"%d\n" % self._len) if self._search_prefix is None: raise AssertionError("_search_prefix should not be None") - lines.append(b'%s\n' % (self._search_prefix,)) + lines.append(b"%s\n" % (self._search_prefix,)) prefix_len = len(self._search_prefix) for prefix, node in sorted(self._items.items()): if isinstance(node, StaticTuple): @@ -1299,22 +1372,28 @@ def serialise(self, store): key = node._key[0] serialised = b"%s\x00%s\n" % (prefix, key) if not serialised.startswith(self._search_prefix): - raise AssertionError(f"prefixes mismatch: {serialised} must start with {self._search_prefix}") + raise AssertionError( + f"prefixes mismatch: {serialised} must start with {self._search_prefix}" + ) lines.append(serialised[prefix_len:]) sha1, _, _ = store.add_lines((None,), (), lines) - self._key = StaticTuple(b"sha1:" + sha1,).intern() - _get_cache()[self._key] = b''.join(lines) + self._key = StaticTuple( + b"sha1:" + sha1, + ).intern() + _get_cache()[self._key] = b"".join(lines) yield self._key def _search_key(self, key): """Return the serialised key for key in this node.""" # search keys are fixed width. All will be self._node_width wide, so we # pad as necessary. - return (self._search_key_func(key) + b'\x00' * self._node_width)[:self._node_width] + return (self._search_key_func(key) + b"\x00" * self._node_width)[ + : self._node_width + ] def _search_prefix_filter(self, key): """Serialise key for use as a prefix filter in iteritems.""" - return self._search_key_func(key)[:self._node_width] + return self._search_key_func(key)[: self._node_width] def _split(self, offset): """Split this node into smaller nodes starting at offset. @@ -1352,8 +1431,7 @@ def unmap(self, store, key, check_remap=True): """Remove key from this node and its children.""" if not len(self._items): raise AssertionError("can't unmap in an empty InternalNode.") - children = [node for node, _ - in self._iter_nodes(store, key_filter=[key])] + children = [node for node, _ in self._iter_nodes(store, key_filter=[key])] if children: child = children[0] else: @@ -1433,8 +1511,7 @@ def _deserialise(data, key, search_key_func): if data.startswith(b"chkleaf:\n"): node = LeafNode.deserialise(data, key, search_key_func=search_key_func) elif data.startswith(b"chknode:\n"): - node = InternalNode.deserialise(data, key, - search_key_func=search_key_func) + node = InternalNode.deserialise(data, key, search_key_func=search_key_func) else: raise AssertionError("Unknown node type.") return node @@ -1451,8 +1528,7 @@ class CHKMapDifference: but it won't yield (key,value) pairs that are common. """ - def __init__(self, store, new_root_keys, old_root_keys, - search_key_func, pb=None): + def __init__(self, store, new_root_keys, old_root_keys, search_key_func, pb=None): # TODO: Should we add a StaticTuple barrier here? It would be nice to # force callers to use StaticTuple, because there will often be # lots of keys passed in here. And even if we cast it locally, @@ -1493,15 +1569,16 @@ def _read_nodes_from_store(self, keys): # only 1 time during this code. (We may want to evaluate saving the # raw bytes into the page cache, which would allow a working tree # update after the fetch to not have to read the bytes again.) 
- stream = self._store.get_record_stream(keys, 'unordered', True) + stream = self._store.get_record_stream(keys, "unordered", True) for record in stream: if self._pb is not None: self._pb.tick() - if record.storage_kind == 'absent': + if record.storage_kind == "absent": raise errors.NoSuchRevision(self._store, record.key) - bytes = record.get_bytes_as('fulltext') - node = _deserialise(bytes, record.key, - search_key_func=self._search_key_func) + bytes = record.get_bytes_as("fulltext") + node = _deserialise( + bytes, record.key, search_key_func=self._search_key_func + ) if isinstance(node, InternalNode): # Note we don't have to do node.refs() because we know that # there are no children that have been pushed into this node @@ -1522,11 +1599,11 @@ def _read_nodes_from_store(self, keys): def _read_old_roots(self): old_chks_to_enqueue = [] all_old_chks = self._all_old_chks - for _record, _node, prefix_refs, items in \ - self._read_nodes_from_store(self._old_root_keys): + for _record, _node, prefix_refs, items in self._read_nodes_from_store( + self._old_root_keys + ): # Uninteresting node - prefix_refs = [p_r for p_r in prefix_refs - if p_r[1] not in all_old_chks] + prefix_refs = [p_r for p_r in prefix_refs if p_r[1] not in all_old_chks] new_refs = [p_r[1] for p_r in prefix_refs] all_old_chks.update(new_refs) # TODO: This might be a good time to turn items into StaticTuple @@ -1577,13 +1654,14 @@ def _read_all_roots(self): # added a second time processed_new_refs = self._processed_new_refs processed_new_refs.update(new_keys) - for record, _node, prefix_refs, items in \ - self._read_nodes_from_store(new_keys): + for record, _node, prefix_refs, items in self._read_nodes_from_store(new_keys): # At this level, we now know all the uninteresting references # So we filter and queue up whatever is remaining - prefix_refs = [p_r for p_r in prefix_refs - if p_r[1] not in self._all_old_chks and - p_r[1] not in processed_new_refs] + prefix_refs = [ + p_r + for p_r in prefix_refs + if p_r[1] not in self._all_old_chks and p_r[1] not in processed_new_refs + ] refs = [p_r[1] for p_r in prefix_refs] new_prefixes.update([p_r[0] for p_r in prefix_refs]) self._new_queue.extend(refs) @@ -1594,11 +1672,9 @@ def _read_all_roots(self): # TODO: This might be a good time to cast to StaticTuple, as # self._new_item_queue will hold the contents of multiple # records for an extended lifetime - new_items = [item for item in items - if item not in self._all_old_items] + new_items = [item for item in items if item not in self._all_old_items] self._new_item_queue.extend(new_items) - new_prefixes.update([self._search_key_func(item[0]) - for item in new_items]) + new_prefixes.update([self._search_key_func(item[0]) for item in new_items]) processed_new_refs.update(refs) yield record # For new_prefixes we have the full length prefixes queued up. @@ -1618,8 +1694,7 @@ def _flush_new_queue(self): all_old_chks = self._all_old_chks processed_new_refs = self._processed_new_refs all_old_items = self._all_old_items - new_items = [item for item in self._new_item_queue - if item not in all_old_items] + new_items = [item for item in self._new_item_queue if item not in all_old_items] self._new_item_queue = [] if new_items: yield None, new_items @@ -1637,8 +1712,7 @@ def _flush_new_queue(self): if all_old_items: # using the 'if' check saves about 145s => 141s, when # streaming initial branch of Launchpad data. 
- items = [item for item in items - if item not in all_old_items] + items = [item for item in items if item not in all_old_items] yield record, items next_refs_update([p_r[1] for p_r in p_refs]) del p_refs @@ -1679,8 +1753,9 @@ def process(self): yield record, items -def iter_interesting_nodes(store, interesting_root_keys, - uninteresting_root_keys, pb=None): +def iter_interesting_nodes( + store, interesting_root_keys, uninteresting_root_keys, pb=None +): """Given root keys, find interesting nodes. Evaluate nodes referenced by interesting_root_keys. Ones that are also @@ -1693,10 +1768,13 @@ def iter_interesting_nodes(store, interesting_root_keys, :return: Yield (interesting record, {interesting key:values}) """ - iterator = CHKMapDifference(store, interesting_root_keys, - uninteresting_root_keys, - search_key_func=store._search_key_func, - pb=pb) + iterator = CHKMapDifference( + store, + interesting_root_keys, + uninteresting_root_keys, + search_key_func=store._search_key_func, + pb=pb, + ) return iterator.process() @@ -1711,8 +1789,8 @@ def iter_interesting_nodes(store, interesting_root_keys, except ImportError as e: osutils.failed_to_load_extension(e) from ._chk_map_py import _deserialise_internal_node, _deserialise_leaf_node -search_key_registry.register(b'hash-16-way', _search_key_16) -search_key_registry.register(b'hash-255-way', _search_key_255) +search_key_registry.register(b"hash-16-way", _search_key_16) +search_key_registry.register(b"hash-255-way", _search_key_255) def _check_key(key): @@ -1722,10 +1800,10 @@ def _check_key(key): to debug problems. """ if not isinstance(key, StaticTuple): - raise TypeError(f'key {key!r} is not StaticTuple but {type(key)}') + raise TypeError(f"key {key!r} is not StaticTuple but {type(key)}") if len(key) != 1: - raise ValueError(f'key {key!r} should have length 1, not {len(key)}') + raise ValueError(f"key {key!r} should have length 1, not {len(key)}") if not isinstance(key[0], str): - raise TypeError(f'key {key!r} should hold a str, not {type(key[0])!r}') - if not key[0].startswith('sha1:'): - raise ValueError(f'key {key!r} should point to a sha1:') + raise TypeError(f"key {key!r} should hold a str, not {type(key[0])!r}") + if not key[0].startswith("sha1:"): + raise ValueError(f"key {key!r} should point to a sha1:") diff --git a/breezy/bzr/chk_serializer.py b/breezy/bzr/chk_serializer.py index a3ff40306d..9890f6f3be 100644 --- a/breezy/bzr/chk_serializer.py +++ b/breezy/bzr/chk_serializer.py @@ -23,24 +23,27 @@ class CHKSerializer(serializer.InventorySerializer): """A CHKInventory based serializer with 'plain' behaviour.""" support_altered_by_hack = False - supported_kinds = {'file', 'directory', 'symlink', 'tree-reference'} + supported_kinds = {"file", "directory", "symlink", "tree-reference"} def __init__(self, format_num, node_size, search_key_name): self.format_num = format_num self.maximum_size = node_size self.search_key_name = search_key_name - def _unpack_inventory(self, elt, revision_id=None, entry_cache=None, - return_from_cache=False): + def _unpack_inventory( + self, elt, revision_id=None, entry_cache=None, return_from_cache=False + ): """Construct from XML Element.""" from .xml_serializer import unpack_inventory_entry, unpack_inventory_flat - inv = unpack_inventory_flat(elt, self.format_num, - unpack_inventory_entry, entry_cache, - return_from_cache) + + inv = unpack_inventory_flat( + elt, self.format_num, unpack_inventory_entry, entry_cache, return_from_cache + ) return inv - def read_inventory_from_lines(self, xml_lines, 
revision_id=None, - entry_cache=None, return_from_cache=False): + def read_inventory_from_lines( + self, xml_lines, revision_id=None, entry_cache=None, return_from_cache=False + ): """Read xml_string into an inventory object. :param xml_string: The xml to read. @@ -55,21 +58,24 @@ def read_inventory_from_lines(self, xml_lines, revision_id=None, make some operations significantly faster. """ from .xml_serializer import ParseError, fromstringlist + try: return self._unpack_inventory( - fromstringlist(xml_lines), revision_id, + fromstringlist(xml_lines), + revision_id, entry_cache=entry_cache, - return_from_cache=return_from_cache) + return_from_cache=return_from_cache, + ) except ParseError as e: raise serializer.UnexpectedInventoryFormat(e) from e def read_inventory(self, f, revision_id=None): """Read an inventory from a file-like object.""" from .xml_serializer import ParseError + try: try: - return self._unpack_inventory(self._read_element(f), - revision_id=None) + return self._unpack_inventory(self._read_element(f), revision_id=None) finally: f.close() except ParseError as e: @@ -94,28 +100,36 @@ def write_inventory(self, inv, f, working=False): :return: The inventory as a list of lines. """ from .xml_serializer import encode_and_escape, serialize_inventory_flat + output = [] append = output.append if inv.revision_id is not None: - revid = b''.join( - [b' revision_id="', - encode_and_escape(inv.revision_id), b'"']) + revid = b"".join( + [b' revision_id="', encode_and_escape(inv.revision_id), b'"'] + ) else: revid = b"" - append(b'\n' % ( - self.format_num, revid)) - append(b'\n' % ( - encode_and_escape(inv.root.file_id), - encode_and_escape(inv.root.name), - encode_and_escape(inv.root.revision))) - serialize_inventory_flat(inv, append, root_id=None, - supported_kinds=self.supported_kinds, - working=working) + append(b'\n' % (self.format_num, revid)) + append( + b'\n' + % ( + encode_and_escape(inv.root.file_id), + encode_and_escape(inv.root.name), + encode_and_escape(inv.root.revision), + ) + ) + serialize_inventory_flat( + inv, + append, + root_id=None, + supported_kinds=self.supported_kinds, + working=working, + ) if f is not None: f.writelines(output) return output # A CHKInventory based serializer with 'plain' behaviour. -inventory_chk_serializer_255_bigpage_9 = CHKSerializer(b'9', 65536, b'hash-255-way') -inventory_chk_serializer_255_bigpage_10 = CHKSerializer(b'10', 65536, b'hash-255-way') +inventory_chk_serializer_255_bigpage_9 = CHKSerializer(b"9", 65536, b"hash-255-way") +inventory_chk_serializer_255_bigpage_10 = CHKSerializer(b"10", 65536, b"hash-255-way") diff --git a/breezy/bzr/conflicts.py b/breezy/bzr/conflicts.py index 22b54a16a3..bebb4fae4c 100644 --- a/breezy/bzr/conflicts.py +++ b/breezy/bzr/conflicts.py @@ -19,13 +19,16 @@ from ..lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ from breezy import ( cache_utf8, transform, ) -""") +""", +) from .. import errors, osutils from .. import transport as _mod_transport @@ -33,7 +36,7 @@ from ..conflicts import ConflictList as BaseConflictList from . 
import rio -CONFLICT_SUFFIXES = ('.THIS', '.BASE', '.OTHER') +CONFLICT_SUFFIXES = (".THIS", ".BASE", ".OTHER") class Conflict(BaseConflict): @@ -54,7 +57,7 @@ def as_stanza(self): s = rio.Stanza(type=self.typestring, path=self.path) if self.file_id is not None: # Stanza requires Unicode apis - s.add('file_id', self.file_id.decode('utf8')) + s.add("file_id", self.file_id.decode("utf8")) return s def _cmp_list(self): @@ -84,7 +87,7 @@ def describe(self): def __repr__(self): rdict = dict(self.__dict__) - rdict['class'] = self.__class__.__name__ + rdict["class"] = self.__class__.__name__ return self.rformat % rdict @staticmethod @@ -108,9 +111,9 @@ def do(self, action, tree): :param tree: The tree passed as a parameter to the method. """ - meth = getattr(self, f'action_{action}', None) + meth = getattr(self, f"action_{action}", None) if meth is None: - raise NotImplementedError(self.__class__.__name__ + '.' + action) + raise NotImplementedError(self.__class__.__name__ + "." + action) meth(tree) def action_auto(self, tree): @@ -133,7 +136,6 @@ def _resolve_with_cleanups(self, tree, *args, **kwargs): class ConflictList(BaseConflictList): - @staticmethod def from_stanzas(stanzas): """Produce a new ConflictList from an iterable of stanzas.""" @@ -147,8 +149,7 @@ def to_stanzas(self): for conflict in self: yield conflict.as_stanza() - def select_conflicts(self, tree, paths, ignore_misses=False, - recurse=False): + def select_conflicts(self, tree, paths, ignore_misses=False, recurse=False): """Select the conflicts associated with paths in a tree. File-ids are also used for this. @@ -166,7 +167,7 @@ def select_conflicts(self, tree, paths, ignore_misses=False, for conflict in self: selected = False - for key in ('path', 'conflict_path'): + for key in ("path", "conflict_path"): cpath = getattr(conflict, key, None) if cpath is None: continue @@ -178,7 +179,7 @@ def select_conflicts(self, tree, paths, ignore_misses=False, selected = True selected_paths.add(cpath) - for key in ('file_id', 'conflict_file_id'): + for key in ("file_id", "conflict_file_id"): cfile_id = getattr(conflict, key, None) if cfile_id is None: continue @@ -201,16 +202,14 @@ def select_conflicts(self, tree, paths, ignore_misses=False, return new_conflicts, selected_conflicts - - class PathConflict(Conflict): """A conflict was encountered merging file paths.""" - typestring = 'path conflict' + typestring = "path conflict" - format = 'Path conflict: %(path)s / %(conflict_path)s' + format = "Path conflict: %(path)s / %(conflict_path)s" - rformat = '%(class)s(%(path)r, %(conflict_path)r, %(file_id)r)' + rformat = "%(class)s(%(path)r, %(conflict_path)r, %(file_id)r)" def __init__(self, path, conflict_path=None, file_id=None): Conflict.__init__(self, path, file_id) @@ -219,7 +218,7 @@ def __init__(self, path, conflict_path=None, file_id=None): def as_stanza(self): s = Conflict.as_stanza(self) if self.conflict_path is not None: - s.add('conflict_path', self.conflict_path) + s.add("conflict_path", self.conflict_path) return s def associated_filenames(self): @@ -235,16 +234,16 @@ def _resolve(self, tt, file_id, path, winner): :param winner: 'this' or 'other' indicates which side is the winner. 
""" path_to_create = None - if winner == 'this': - if self.path == '': + if winner == "this": + if self.path == "": return # Nothing to do - if self.conflict_path == '': + if self.conflict_path == "": path_to_create = self.path revid = tt._tree.get_parent_ids()[0] - elif winner == 'other': - if self.conflict_path == '': + elif winner == "other": + if self.conflict_path == "": return # Nothing to do - if self.path == '': + if self.path == "": path_to_create = self.conflict_path # FIXME: If there are more than two parents we may need to # iterate. Taking the last parent is the safer bet in the mean @@ -252,12 +251,11 @@ def _resolve(self, tt, file_id, path, winner): revid = tt._tree.get_parent_ids()[-1] else: # Programmer error - raise AssertionError(f'bad winner: {winner!r}') + raise AssertionError(f"bad winner: {winner!r}") if path_to_create is not None: tid = tt.trans_id_tree_path(path_to_create) tree = self._revision_tree(tt._tree, revid) - transform.create_from_tree( - tt, tid, tree, tree.id2path(file_id)) + transform.create_from_tree(tt, tid, tree, tree.id2path(file_id)) tt.version_file(tid, file_id=file_id) else: tid = tt.trans_id_file_id(file_id) @@ -275,7 +273,7 @@ def _infer_file_id(self, tree): # Establish which path we should use to find back the file-id possible_paths = [] for p in (self.path, self.conflict_path): - if p == '': + if p == "": # special hard-coded path continue if p is not None: @@ -292,26 +290,23 @@ def _infer_file_id(self, tree): def action_take_this(self, tree): if self.file_id is not None: - self._resolve_with_cleanups(tree, self.file_id, self.path, - winner='this') + self._resolve_with_cleanups(tree, self.file_id, self.path, winner="this") else: # Prior to bug #531967 we need to find back the file_id and restore # the content from there revtree, file_id = self._infer_file_id(tree) - tree.revert([revtree.id2path(file_id)], - old_tree=revtree, backups=False) + tree.revert([revtree.id2path(file_id)], old_tree=revtree, backups=False) def action_take_other(self, tree): if self.file_id is not None: - self._resolve_with_cleanups(tree, self.file_id, - self.conflict_path, - winner='other') + self._resolve_with_cleanups( + tree, self.file_id, self.conflict_path, winner="other" + ) else: # Prior to bug #531967 we need to find back the file_id and restore # the content from there revtree, file_id = self._infer_file_id(tree) - tree.revert([revtree.id2path(file_id)], - old_tree=revtree, backups=False) + tree.revert([revtree.id2path(file_id)], old_tree=revtree, backups=False) class ContentsConflict(PathConflict): @@ -319,12 +314,12 @@ class ContentsConflict(PathConflict): has_files = True - typestring = 'contents conflict' + typestring = "contents conflict" - format = 'Contents conflict in %(path)s' + format = "Contents conflict in %(path)s" def associated_filenames(self): - return [self.path + suffix for suffix in ('.BASE', '.OTHER')] + return [self.path + suffix for suffix in (".BASE", ".OTHER")] def _resolve(self, tt, suffix_to_remove): """Resolve the conflict. @@ -339,7 +334,8 @@ def _resolve(self, tt, suffix_to_remove): # Delete 'item.THIS' or 'item.OTHER' depending on # suffix_to_remove tt.delete_contents( - tt.trans_id_tree_path(self.path + '.' + suffix_to_remove)) + tt.trans_id_tree_path(self.path + "." 
+ suffix_to_remove) + ) except _mod_transport.NoSuchFile: # There are valid cases where 'item.suffix_to_remove' either # never existed or was already deleted (including the case @@ -364,10 +360,10 @@ def _resolve(self, tt, suffix_to_remove): tt.apply() def action_take_this(self, tree): - self._resolve_with_cleanups(tree, 'OTHER') + self._resolve_with_cleanups(tree, "OTHER") def action_take_other(self, tree): - self._resolve_with_cleanups(tree, 'THIS') + self._resolve_with_cleanups(tree, "THIS") # TODO: There should be a base revid attribute to better inform the user about @@ -377,13 +373,13 @@ class TextConflict(Conflict): has_files = True - typestring = 'text conflict' + typestring = "text conflict" - format = 'Text conflict in %(path)s' + format = "Text conflict in %(path)s" - rformat = '%(class)s(%(path)r, %(file_id)r)' + rformat = "%(class)s(%(path)r, %(file_id)r)" - _conflict_re = re.compile(b'^(<{7}|={7}|>{7})') + _conflict_re = re.compile(b"^(<{7}|={7}|>{7})") def associated_filenames(self): return [self.path + suffix for suffix in CONFLICT_SUFFIXES] @@ -402,14 +398,12 @@ def _resolve(self, tt, winner_suffix): # item will exist after the conflict has been resolved anyway. item_tid = tt.trans_id_file_id(self.file_id) item_parent_tid = tt.get_tree_parent(item_tid) - winner_path = self.path + '.' + winner_suffix + winner_path = self.path + "." + winner_suffix winner_tid = tt.trans_id_tree_path(winner_path) winner_parent_tid = tt.get_tree_parent(winner_tid) # Switch the paths to preserve the content - tt.adjust_path(osutils.basename(self.path), - winner_parent_tid, winner_tid) - tt.adjust_path(osutils.basename(winner_path), - item_parent_tid, item_tid) + tt.adjust_path(osutils.basename(self.path), winner_parent_tid, winner_tid) + tt.adjust_path(osutils.basename(winner_path), item_parent_tid, item_tid) # Associate the file_id to the right content tt.unversion_file(item_tid) tt.version_file(winner_tid, file_id=self.file_id) @@ -422,7 +416,7 @@ def action_auto(self, tree): kind = tree.kind(self.path) except _mod_transport.NoSuchFile: return - if kind != 'file': + if kind != "file": raise NotImplementedError("Conflict is not a file") conflict_markers_in_line = self._conflict_re.search # GZ 2012-07-27: What if not tree.has_id(self.file_id) due to removal? @@ -432,10 +426,10 @@ def action_auto(self, tree): raise NotImplementedError("Conflict markers present") def action_take_this(self, tree): - self._resolve_with_cleanups(tree, 'THIS') + self._resolve_with_cleanups(tree, "THIS") def action_take_other(self, tree): - self._resolve_with_cleanups(tree, 'OTHER') + self._resolve_with_cleanups(tree, "OTHER") class HandledConflict(Conflict): @@ -454,7 +448,7 @@ def _cmp_list(self): def as_stanza(self): s = Conflict.as_stanza(self) - s.add('action', self.action) + s.add("action", self.action) return s def associated_filenames(self): @@ -467,11 +461,14 @@ class HandledPathConflict(HandledConflict): This is intended to be a base class. 
""" - rformat = "%(class)s(%(action)r, %(path)r, %(conflict_path)r,"\ + rformat = ( + "%(class)s(%(action)r, %(path)r, %(conflict_path)r," " %(file_id)r, %(conflict_file_id)r)" + ) - def __init__(self, action, path, conflict_path, file_id=None, - conflict_file_id=None): + def __init__( + self, action, path, conflict_path, file_id=None, conflict_file_id=None + ): HandledConflict.__init__(self, action, path, file_id) self.conflict_path = conflict_path # the factory blindly transfers the Stanza values to __init__, @@ -481,14 +478,16 @@ def __init__(self, action, path, conflict_path, file_id=None, self.conflict_file_id = conflict_file_id def _cmp_list(self): - return HandledConflict._cmp_list(self) + [self.conflict_path, - self.conflict_file_id] + return HandledConflict._cmp_list(self) + [ + self.conflict_path, + self.conflict_file_id, + ] def as_stanza(self): s = HandledConflict.as_stanza(self) - s.add('conflict_path', self.conflict_path) + s.add("conflict_path", self.conflict_path) if self.conflict_file_id is not None: - s.add('conflict_file_id', self.conflict_file_id.decode('utf8')) + s.add("conflict_file_id", self.conflict_file_id.decode("utf8")) return s @@ -496,17 +495,17 @@ def as_stanza(self): class DuplicateID(HandledPathConflict): """Two files want the same file_id.""" - typestring = 'duplicate id' + typestring = "duplicate id" - format = 'Conflict adding id to %(conflict_path)s. %(action)s %(path)s.' + format = "Conflict adding id to %(conflict_path)s. %(action)s %(path)s." class DuplicateEntry(HandledPathConflict): """Two directory entries want to have the same name.""" - typestring = 'duplicate' + typestring = "duplicate" - format = 'Conflict adding file %(conflict_path)s. %(action)s %(path)s.' + format = "Conflict adding file %(conflict_path)s. %(action)s %(path)s." def action_take_this(self, tree): tree.remove([self.conflict_path], force=True, keep_files=False) @@ -528,9 +527,9 @@ class ParentLoop(HandledPathConflict): merge A and B """ - typestring = 'parent loop' + typestring = "parent loop" - format = 'Conflict moving %(path)s into %(conflict_path)s. %(action)s.' + format = "Conflict moving %(path)s into %(conflict_path)s. %(action)s." def action_take_this(self, tree): # just acccept brz proposal @@ -543,8 +542,7 @@ def action_take_other(self, tree): cp_tid = tt.trans_id_file_id(self.conflict_file_id) cparent_tid = tt.get_tree_parent(cp_tid) tt.adjust_path(osutils.basename(self.path), cparent_tid, cp_tid) - tt.adjust_path(osutils.basename(self.conflict_path), - parent_tid, p_tid) + tt.adjust_path(osutils.basename(self.conflict_path), parent_tid, p_tid) tt.apply() @@ -554,10 +552,12 @@ class UnversionedParent(HandledConflict): and the other added a versioned file to it. """ - typestring = 'unversioned parent' + typestring = "unversioned parent" - format = 'Conflict because %(path)s is not versioned, but has versioned'\ - ' children. %(action)s.' + format = ( + "Conflict because %(path)s is not versioned, but has versioned" + " children. %(action)s." + ) # FIXME: We silently do nothing to make tests pass, but most probably the # conflict shouldn't exist (the long story is that the conflict is @@ -576,9 +576,9 @@ class MissingParent(HandledConflict): See also: DeletingParent (same situation, THIS and OTHER reversed). """ - typestring = 'missing parent' + typestring = "missing parent" - format = 'Conflict adding files to %(path)s. %(action)s.' + format = "Conflict adding files to %(path)s. %(action)s." 
def action_take_this(self, tree): tree.remove([self.path], force=True, keep_files=False) @@ -594,10 +594,9 @@ class DeletingParent(HandledConflict): the THIS added a file to it. """ - typestring = 'deleting parent' + typestring = "deleting parent" - format = "Conflict: can't delete %(path)s because it is not empty. "\ - "%(action)s." + format = "Conflict: can't delete %(path)s because it is not empty. " "%(action)s." # FIXME: It's a bit strange that the default action is not coherent with # MissingParent from the *user* pov. @@ -615,17 +614,18 @@ class NonDirectoryParent(HandledConflict): an attempt to change the kind of a directory with files. """ - typestring = 'non-directory parent' + typestring = "non-directory parent" - format = "Conflict: %(path)s is not a directory, but has files in it."\ - " %(action)s." + format = ( + "Conflict: %(path)s is not a directory, but has files in it." " %(action)s." + ) # FIXME: .OTHER should be used instead of .new when the conflict is created def action_take_this(self, tree): # FIXME: we should preserve that path when the conflict is generated ! - if self.path.endswith('.new'): - conflict_path = self.path[:-(len('.new'))] + if self.path.endswith(".new"): + conflict_path = self.path[: -(len(".new"))] tree.remove([self.path], force=True, keep_files=False) tree.add(conflict_path) else: @@ -633,8 +633,8 @@ def action_take_this(self, tree): def action_take_other(self, tree): # FIXME: we should preserve that path when the conflict is generated ! - if self.path.endswith('.new'): - conflict_path = self.path[:-(len('.new'))] + if self.path.endswith(".new"): + conflict_path = self.path[: -(len(".new"))] tree.remove([conflict_path], force=True, keep_files=False) tree.rename_one(self.path, conflict_path) else: @@ -651,6 +651,15 @@ def register_types(*conflict_types): ctype[conflict_type.typestring] = conflict_type -register_types(ContentsConflict, TextConflict, PathConflict, DuplicateID, - DuplicateEntry, ParentLoop, UnversionedParent, MissingParent, - DeletingParent, NonDirectoryParent) +register_types( + ContentsConflict, + TextConflict, + PathConflict, + DuplicateID, + DuplicateEntry, + ParentLoop, + UnversionedParent, + MissingParent, + DeletingParent, + NonDirectoryParent, +) diff --git a/breezy/bzr/debug_commands.py b/breezy/bzr/debug_commands.py index 19e1949155..03553fdb20 100644 --- a/breezy/bzr/debug_commands.py +++ b/breezy/bzr/debug_commands.py @@ -41,11 +41,14 @@ class cmd_dump_btree(Command): # rather than only going through iter_all_entries. However, this is # good enough for a start hidden = True - encoding_type = 'exact' - takes_args = ['path'] - takes_options = [Option('raw', help='Write the uncompressed bytes out,' - ' rather than the parsed tuples.'), - ] + encoding_type = "exact" + takes_args = ["path"] + takes_options = [ + Option( + "raw", + help="Write the uncompressed bytes out," " rather than the parsed tuples.", + ), + ] def run(self, path, raw=False): dirname, basename = osutils.split(path) @@ -70,22 +73,23 @@ def _dump_raw_bytes(self, trans, basename): # This is because the first page of every row starts with an # uncompressed header. 
bt, bytes = self._get_index_and_bytes(trans, basename) - for page_idx, page_start in enumerate(range(0, len(bytes), - btree_index._PAGE_SIZE)): + for page_idx, page_start in enumerate( + range(0, len(bytes), btree_index._PAGE_SIZE) + ): page_end = min(page_start + btree_index._PAGE_SIZE, len(bytes)) page_bytes = bytes[page_start:page_end] if page_idx == 0: - self.outf.write('Root node:\n') + self.outf.write("Root node:\n") header_end, data = bt._parse_header_from_bytes(page_bytes) self.outf.write(page_bytes[:header_end]) page_bytes = data - self.outf.write('\nPage %d\n' % (page_idx,)) + self.outf.write("\nPage %d\n" % (page_idx,)) if len(page_bytes) == 0: - self.outf.write('(empty)\n') + self.outf.write("(empty)\n") else: decomp_bytes = zlib.decompress(page_bytes) self.outf.write(decomp_bytes) - self.outf.write('\n') + self.outf.write("\n") def _dump_entries(self, trans, basename): try: @@ -107,14 +111,15 @@ def _dump_entries(self, trans, basename): refs_as_tuples = static_tuple.as_tuples(refs) if refs_as_tuples is not None: refs_as_tuples = tuple( - tuple(tuple(r.decode('utf-8') - for r in t1) for t1 in t2) - for t2 in refs_as_tuples) + tuple(tuple(r.decode("utf-8") for r in t1) for t1 in t2) + for t2 in refs_as_tuples + ) as_tuple = ( - tuple([r.decode('utf-8') for r in node[1]]), - node[2].decode('utf-8'), - refs_as_tuples) - self.outf.write(f'{as_tuple}\n') + tuple([r.decode("utf-8") for r in node[1]]), + node[2].decode("utf-8"), + refs_as_tuples, + ) + self.outf.write(f"{as_tuple}\n") class cmd_file_id(Command): @@ -126,8 +131,8 @@ class cmd_file_id(Command): """ hidden = True - _see_also = ['inventory', 'ls'] - takes_args = ['filename'] + _see_also = ["inventory", "ls"] + takes_args = ["filename"] @display_command def run(self, filename): @@ -136,7 +141,7 @@ def run(self, filename): if file_id is None: raise errors.NotVersionedError(filename) else: - self.outf.write(file_id.decode('utf-8') + '\n') + self.outf.write(file_id.decode("utf-8") + "\n") class cmd_file_path(Command): @@ -147,7 +152,7 @@ class cmd_file_path(Command): """ hidden = True - takes_args = ['filename'] + takes_args = ["filename"] @display_command def run(self, filename): diff --git a/breezy/bzr/dirstate.py b/breezy/bzr/dirstate.py index 555743ab70..55820b5f9a 100644 --- a/breezy/bzr/dirstate.py +++ b/breezy/bzr/dirstate.py @@ -250,7 +250,6 @@ class DirstateCorrupt(errors.BzrError): - _fmt = "The dirstate file (%(state)s) appears to be corrupt: %(msg)s" def __init__(self, state, msg): @@ -299,27 +298,27 @@ class DirState: """ _kind_to_minikind = { - 'absent': b'a', - 'file': b'f', - 'directory': b'd', - 'relocated': b'r', - 'symlink': b'l', - 'tree-reference': b't', - } + "absent": b"a", + "file": b"f", + "directory": b"d", + "relocated": b"r", + "symlink": b"l", + "tree-reference": b"t", + } _minikind_to_kind = { - b'a': 'absent', - b'f': 'file', - b'd': 'directory', - b'l': 'symlink', - b'r': 'relocated', - b't': 'tree-reference', - } + b"a": "absent", + b"f": "file", + b"d": "directory", + b"l": "symlink", + b"r": "relocated", + b"t": "tree-reference", + } _stat_to_minikind = { - stat.S_IFDIR: b'd', - stat.S_IFREG: b'f', - stat.S_IFLNK: b'l', + stat.S_IFDIR: b"d", + stat.S_IFREG: b"f", + stat.S_IFLNK: b"l", } - _to_yesno = {True: b'y', False: b'n'} # TODO profile the performance gain + _to_yesno = {True: b"y", False: b"n"} # TODO profile the performance gain # of using int conversion rather than a dict here. AND BLAME ANDREW IF # it is faster. 
@@ -335,14 +334,15 @@ class DirState: # A pack_stat (the x's) that is just noise and will never match the output # of base64 encode. - NULLSTAT = b'x' * 32 - NULL_PARENT_DETAILS = static_tuple.StaticTuple(b'a', b'', 0, False, b'') + NULLSTAT = b"x" * 32 + NULL_PARENT_DETAILS = static_tuple.StaticTuple(b"a", b"", 0, False, b"") - HEADER_FORMAT_2 = b'#bazaar dirstate flat format 2\n' - HEADER_FORMAT_3 = b'#bazaar dirstate flat format 3\n' + HEADER_FORMAT_2 = b"#bazaar dirstate flat format 2\n" + HEADER_FORMAT_3 = b"#bazaar dirstate flat format 3\n" - def __init__(self, path, sha1_provider, worth_saving_limit=0, - use_filesystem_for_exec=True): + def __init__( + self, path, sha1_provider, worth_saving_limit=0, use_filesystem_for_exec=True + ): """Create a DirState object. :param path: The path at which the dirstate file on disk should live. @@ -385,7 +385,7 @@ def __init__(self, path, sha1_provider, worth_saving_limit=0, self._split_path_cache = {} self._bisect_page_size = DirState.BISECT_PAGE_SIZE self._sha1_provider = sha1_provider - if debug.debug_flag_enabled('hashcache'): + if debug.debug_flag_enabled("hashcache"): self._sha1_file = self._sha1_file_and_mutter else: self._sha1_file = self._sha1_provider.sha1 @@ -399,8 +399,7 @@ def __init__(self, path, sha1_provider, worth_saving_limit=0, self._known_hash_changes = set() # How many hash changed entries can we have without saving self._worth_saving_limit = worth_saving_limit - self._config_stack = config.LocationStack(urlutils.local_path_to_url( - path)) + self._config_stack = config.LocationStack(urlutils.local_path_to_url(path)) self._use_filesystem_for_exec = use_filesystem_for_exec def __repr__(self): @@ -414,12 +413,13 @@ def _mark_modified(self, hash_changed_entries=None, header_modified=False): :param header_modified: mark the header modified as well, not just the dirblocks. """ - #trace.mutter_callsite(3, "modified hash entries: %s", hash_changed_entries) + # trace.mutter_callsite(3, "modified hash entries: %s", hash_changed_entries) if hash_changed_entries: - self._known_hash_changes.update( - [e[0] for e in hash_changed_entries]) - if self._dirblock_state in (DirState.NOT_IN_MEMORY, - DirState.IN_MEMORY_UNMODIFIED): + self._known_hash_changes.update([e[0] for e in hash_changed_entries]) + if self._dirblock_state in ( + DirState.NOT_IN_MEMORY, + DirState.IN_MEMORY_UNMODIFIED, + ): # If the dirstate is already marked a IN_MEMORY_MODIFIED, then # that takes precedence. self._dirblock_state = DirState.IN_MEMORY_HASH_MODIFIED @@ -472,47 +472,47 @@ def add(self, path, file_id, kind, stat, fingerprint): raise errors.InvalidNormalization(path) # you should never have files called . or ..; just add the directory # in the parent, or according to the special treatment for the root - if basename == '.' or basename == '..': + if basename == "." or basename == "..": raise inventory.InvalidEntryName(path) # now that we've normalised, we need the correct utf8 path and # dirname and basename elements. This single encode and split should be # faster than three separate encodes. 
- utf8path = (dirname + '/' + basename).strip('/').encode('utf8') + utf8path = (dirname + "/" + basename).strip("/").encode("utf8") dirname, basename = osutils.split(utf8path) # uses __class__ for speed; the check is needed for safety if file_id.__class__ is not bytes: - raise AssertionError( - f"must be a utf8 file_id not {type(file_id)}") + raise AssertionError(f"must be a utf8 file_id not {type(file_id)}") # Make sure the file_id does not exist in this tree rename_from = None - file_id_entry = self._get_entry( - 0, fileid_utf8=file_id, include_deleted=True) + file_id_entry = self._get_entry(0, fileid_utf8=file_id, include_deleted=True) if file_id_entry != (None, None): - if file_id_entry[1][0][0] == b'a': + if file_id_entry[1][0][0] == b"a": if file_id_entry[0] != (dirname, basename, file_id): # set the old name's current operation to rename - self.update_minimal(file_id_entry[0], - b'r', - path_utf8=b'', - packed_stat=b'', - fingerprint=utf8path - ) + self.update_minimal( + file_id_entry[0], + b"r", + path_utf8=b"", + packed_stat=b"", + fingerprint=utf8path, + ) rename_from = file_id_entry[0][0:2] else: - path = osutils.pathjoin( - file_id_entry[0][0], file_id_entry[0][1]) + path = osutils.pathjoin(file_id_entry[0][0], file_id_entry[0][1]) kind = DirState._minikind_to_kind[file_id_entry[1][0][0]] - info = f'{kind}:{path}' + info = f"{kind}:{path}" raise inventory.DuplicateFileId(file_id, info) - first_key = (dirname, basename, b'') + first_key = (dirname, basename, b"") block_index, present = self._find_block_index_from_key(first_key) if present: # check the path is not in the tree block = self._dirblocks[block_index][1] entry_index, _ = self._find_entry_index(first_key, block) - while (entry_index < len(block) and - block[entry_index][0][0:2] == first_key[0:2]): - if block[entry_index][1][0][0] not in (b'a', b'r'): + while ( + entry_index < len(block) + and block[entry_index][0][0:2] == first_key[0:2] + ): + if block[entry_index][1][0][0] not in (b"a", b"r"): # this path is in the dirstate in the current tree. raise Exception("adding already added path!") entry_index += 1 @@ -521,8 +521,12 @@ def add(self, path, file_id, kind, stat, fingerprint): # might be because the directory was empty, or not loaded yet. 
Look # for a parent entry, if not found, raise NotVersionedError parent_dir, parent_base = osutils.split(dirname) - parent_block_idx, parent_entry_idx, _, parent_present = \ - self._get_block_entry_index(parent_dir, parent_base, 0) + ( + parent_block_idx, + parent_entry_idx, + _, + parent_present, + ) = self._get_block_entry_index(parent_dir, parent_base, 0) if not parent_present: raise errors.NotVersionedError(path, str(self)) self._ensure_block(parent_block_idx, parent_entry_idx, dirname) @@ -538,37 +542,53 @@ def add(self, path, file_id, kind, stat, fingerprint): minikind = DirState._kind_to_minikind[kind] if rename_from is not None: if rename_from[0]: - old_path_utf8 = b'%s/%s' % rename_from + old_path_utf8 = b"%s/%s" % rename_from else: old_path_utf8 = rename_from[1] - parent_info[0] = (b'r', old_path_utf8, 0, False, b'') - if kind == 'file': - entry_data = entry_key, [ - (minikind, fingerprint, size, False, packed_stat), - ] + parent_info - elif kind == 'directory': - entry_data = entry_key, [ - (minikind, b'', 0, False, packed_stat), - ] + parent_info - elif kind == 'symlink': - entry_data = entry_key, [ - (minikind, fingerprint, size, False, packed_stat), - ] + parent_info - elif kind == 'tree-reference': - entry_data = entry_key, [ - (minikind, fingerprint, 0, False, packed_stat), - ] + parent_info + parent_info[0] = (b"r", old_path_utf8, 0, False, b"") + if kind == "file": + entry_data = ( + entry_key, + [ + (minikind, fingerprint, size, False, packed_stat), + ] + + parent_info, + ) + elif kind == "directory": + entry_data = ( + entry_key, + [ + (minikind, b"", 0, False, packed_stat), + ] + + parent_info, + ) + elif kind == "symlink": + entry_data = ( + entry_key, + [ + (minikind, fingerprint, size, False, packed_stat), + ] + + parent_info, + ) + elif kind == "tree-reference": + entry_data = ( + entry_key, + [ + (minikind, fingerprint, 0, False, packed_stat), + ] + + parent_info, + ) else: - raise errors.BzrError(f'unknown kind {kind!r}') + raise errors.BzrError(f"unknown kind {kind!r}") entry_index, present = self._find_entry_index(entry_key, block) if not present: block.insert(entry_index, entry_data) else: - if block[entry_index][1][0][0] != b'a': + if block[entry_index][1][0][0] != b"a": raise AssertionError(f" {basename!r}({file_id!r}) already added") block[entry_index][1][0] = entry_data[1][0] - if kind == 'directory': + if kind == "directory": # insert a new dirblock self._ensure_block(block_index, entry_index, utf8path) self._mark_modified() @@ -632,7 +652,7 @@ def _bisect(self, paths): count += 1 if count > max_count: - raise errors.BzrError('Too many seeks, most likely a bug.') + raise errors.BzrError("Too many seeks, most likely a bug.") mid = max(low, (low + high - page_size) // 2) @@ -643,7 +663,7 @@ def _bisect(self, paths): block = state_file.read(read_size) start = mid - entries = block.split(b'\n') + entries = block.split(b"\n") if len(entries) < 2: # We didn't find a '\n', so we cannot have found any records. 
@@ -657,13 +677,13 @@ def _bisect(self, paths): # Check the first and last entries, in case they are partial, or if # we don't care about the rest of this page first_entry_num = 0 - first_fields = entries[0].split(b'\0') + first_fields = entries[0].split(b"\0") if len(first_fields) < entry_field_count: # We didn't get the complete first entry # so move start, and grab the next, which # should be a full entry start += len(entries[0]) + 1 - first_fields = entries[1].split(b'\0') + first_fields = entries[1].split(b"\0") first_entry_num = 1 if len(first_fields) <= 2: @@ -677,7 +697,7 @@ def _bisect(self, paths): # after this first record. after = start if first_fields[1]: - first_path = first_fields[1] + b'/' + first_fields[2] + first_path = first_fields[1] + b"/" + first_fields[2] else: first_path = first_fields[2] first_loc = bisect_path_left(cur_files, first_path) @@ -693,18 +713,18 @@ def _bisect(self, paths): # Parse the last entry last_entry_num = len(entries) - 1 - last_fields = entries[last_entry_num].split(b'\0') + last_fields = entries[last_entry_num].split(b"\0") if len(last_fields) < entry_field_count: # The very last hunk was not complete, # read the previous hunk after = mid + len(block) - len(entries[-1]) last_entry_num -= 1 - last_fields = entries[last_entry_num].split(b'\0') + last_fields = entries[last_entry_num].split(b"\0") else: after = mid + len(block) if last_fields[1]: - last_path = last_fields[1] + b'/' + last_fields[2] + last_path = last_fields[1] + b"/" + last_fields[2] else: last_path = last_fields[2] last_loc = bisect_path_right(post, last_path) @@ -734,9 +754,9 @@ def _bisect(self, paths): # TODO: jam 20070223 We are already splitting here, so # shouldn't we just split the whole thing rather # than doing the split again in add_one_record? - fields = entries[num].split(b'\0') + fields = entries[num].split(b"\0") if fields[1]: - path = fields[1] + b'/' + fields[2] + path = fields[1] + b"/" + fields[2] else: path = fields[2] paths.setdefault(path, []).append(fields) @@ -824,7 +844,7 @@ def _bisect_dirblocks(self, dir_list): count += 1 if count > max_count: - raise errors.BzrError('Too many seeks, most likely a bug.') + raise errors.BzrError("Too many seeks, most likely a bug.") mid = max(low, (low + high - page_size) // 2) @@ -835,7 +855,7 @@ def _bisect_dirblocks(self, dir_list): block = state_file.read(read_size) start = mid - entries = block.split(b'\n') + entries = block.split(b"\n") if len(entries) < 2: # We didn't find a '\n', so we cannot have found any records. 
@@ -849,13 +869,13 @@ def _bisect_dirblocks(self, dir_list): # Check the first and last entries, in case they are partial, or if # we don't care about the rest of this page first_entry_num = 0 - first_fields = entries[0].split(b'\0') + first_fields = entries[0].split(b"\0") if len(first_fields) < entry_field_count: # We didn't get the complete first entry # so move start, and grab the next, which # should be a full entry start += len(entries[0]) + 1 - first_fields = entries[1].split(b'\0') + first_fields = entries[1].split(b"\0") first_entry_num = 1 if len(first_fields) <= 1: @@ -882,13 +902,13 @@ def _bisect_dirblocks(self, dir_list): # Parse the last entry last_entry_num = len(entries) - 1 - last_fields = entries[last_entry_num].split(b'\0') + last_fields = entries[last_entry_num].split(b"\0") if len(last_fields) < entry_field_count: # The very last hunk was not complete, # read the previous hunk after = mid + len(block) - len(entries[-1]) last_entry_num -= 1 - last_fields = entries[last_entry_num].split(b'\0') + last_fields = entries[last_entry_num].split(b"\0") else: after = mid + len(block) @@ -920,7 +940,7 @@ def _bisect_dirblocks(self, dir_list): # TODO: jam 20070223 We are already splitting here, so # shouldn't we just split the whole thing rather # than doing the split again in add_one_record? - fields = entries[num].split(b'\0') + fields = entries[num].split(b"\0") paths.setdefault(fields[1], []).append(fields) for cur_dir in middle_files: @@ -975,7 +995,7 @@ def _bisect_recursive(self, paths): is_dir = False for tree_info in trees_info: minikind = tree_info[0] - if minikind == b'd': + if minikind == b"d": if is_dir: # We already processed this one as a directory, # we don't need to do the extra work again. @@ -985,7 +1005,7 @@ def _bisect_recursive(self, paths): is_dir = True if path not in processed_dirs: pending_dirs.add(path) - elif minikind == b'r': + elif minikind == b"r": # Rename, we need to directly search the target # which is contained in the fingerprint column dir_name = osutils.split(tree_info[1]) @@ -1018,8 +1038,7 @@ def _discard_merge_parents(self): return # only require all dirblocks if we are doing a full-pass removal. self._read_dirblocks_if_needed() - dead_patterns = {(b'a', b'r'), (b'a', b'a'), - (b'r', b'r'), (b'r', b'a')} + dead_patterns = {(b"a", b"r"), (b"a", b"a"), (b"r", b"r"), (b"r", b"a")} def iter_entries_removable(): for block in self._dirblocks: @@ -1034,6 +1053,7 @@ def iter_entries_removable(): else: for pos in reversed(deleted_positions): del block[1][pos] + # if the first parent is a ghost: if parents[0] in self.get_ghosts(): empty_parent = [DirState.NULL_PARENT_DETAILS] @@ -1048,8 +1068,7 @@ def iter_entries_removable(): self._mark_modified(header_modified=True) def _empty_parent_info(self): - return [DirState.NULL_PARENT_DETAILS] * (len(self._parents) - - len(self._ghosts)) + return [DirState.NULL_PARENT_DETAILS] * (len(self._parents) - len(self._ghosts)) def _ensure_block(self, parent_block_index, parent_row_index, dirname): """Ensure a block for dirname exists. @@ -1071,18 +1090,19 @@ def _ensure_block(self, parent_block_index, parent_row_index, dirname): :param dirname: The utf8 dirname to ensure there is a block for. :return: The index for the block. 
""" - if dirname == b'' and parent_row_index == 0 and parent_block_index == 0: + if dirname == b"" and parent_row_index == 0 and parent_block_index == 0: # This is the signature of the root row, and the # contents-of-root row is always index 1 return 1 # the basename of the directory must be the end of its full name. - if not (parent_block_index == -1 and - parent_block_index == -1 and dirname == b''): + if not ( + parent_block_index == -1 and parent_block_index == -1 and dirname == b"" + ): if not dirname.endswith( - self._dirblocks[parent_block_index][1][parent_row_index][0][1]): + self._dirblocks[parent_block_index][1][parent_row_index][0][1] + ): raise AssertionError(f"bad dirname {dirname!r}") - block_index, present = self._find_block_index_from_key( - (dirname, b'', b'')) + block_index, present = self._find_block_index_from_key((dirname, b"", b"")) if not present: # In future, when doing partial parsing, this should load and # populate the entire block. @@ -1099,14 +1119,13 @@ def _entries_to_current_state(self, new_entries): to prevent unneeded overhead when callers have a sorted list already. :return: Nothing. """ - if new_entries[0][0][0:2] != (b'', b''): - raise AssertionError( - f"Missing root row {new_entries[0][0]!r}") + if new_entries[0][0][0:2] != (b"", b""): + raise AssertionError(f"Missing root row {new_entries[0][0]!r}") # The two blocks here are deliberate: the root block and the # contents-of-root block. - self._dirblocks = [(b'', []), (b'', [])] + self._dirblocks = [(b"", []), (b"", [])] current_block = self._dirblocks[0][1] - current_dirname = b'' + current_dirname = b"" append_entry = current_block.append for entry in new_entries: if entry[0][0] != current_dirname: @@ -1128,7 +1147,7 @@ def _split_root_dirblock_into_contents(self): # The above loop leaves the "root block" entries mixed with the # "contents-of-root block". But we don't want an if check on # all entries, so instead we just fix it up here. - if self._dirblocks[1] != (b'', []): + if self._dirblocks[1] != (b"", []): raise ValueError(f"bad dirblock start {self._dirblocks[1]!r}") root_block = [] contents_of_root_block = [] @@ -1137,13 +1156,13 @@ def _split_root_dirblock_into_contents(self): root_block.append(entry) else: contents_of_root_block.append(entry) - self._dirblocks[0] = (b'', root_block) - self._dirblocks[1] = (b'', contents_of_root_block) + self._dirblocks[0] = (b"", root_block) + self._dirblocks[1] = (b"", contents_of_root_block) def _entries_for_path(self, path): """Return a list with all the entries that match path for all ids.""" dirname, basename = os.path.split(path) - key = (dirname, basename, b'') + key = (dirname, basename, b"") block_index, present = self._find_block_index_from_key(key) if not present: # the block which should contain path is absent. @@ -1152,8 +1171,7 @@ def _entries_for_path(self, path): block = self._dirblocks[block_index][1] entry_index, _ = self._find_entry_index(key, block) # we may need to look at multiple entries at this path: walk while the specific_files match. 
- while (entry_index < len(block) and - block[entry_index][0][0:2] == key[0:2]): + while entry_index < len(block) and block[entry_index][0][0:2] == key[0:2]: result.append(block[entry_index]) entry_index += 1 return result @@ -1173,10 +1191,10 @@ def _entry_to_line(entry): # minikind entire_entry[tree_offset + 0] = tree_data[0] # size - entire_entry[tree_offset + 2] = b'%d' % tree_data[2] + entire_entry[tree_offset + 2] = b"%d" % tree_data[2] # executable entire_entry[tree_offset + 3] = DirState._to_yesno[tree_data[3]] - return b'\0'.join(entire_entry) + return b"\0".join(entire_entry) def _find_block(self, key, add_if_missing=False): """Return the block that key should be present in. @@ -1205,24 +1223,29 @@ def _find_block_index_from_key(self, key): :return: The block index, True if the block for the key is present. """ - if key[0:2] == (b'', b''): + if key[0:2] == (b"", b""): return 0, True try: - if (self._last_block_index is not None and - self._dirblocks[self._last_block_index][0] == key[0]): + if ( + self._last_block_index is not None + and self._dirblocks[self._last_block_index][0] == key[0] + ): return self._last_block_index, True except IndexError: pass - block_index = bisect_dirblock(self._dirblocks, key[0], 1, - cache=self._split_path_cache) + block_index = bisect_dirblock( + self._dirblocks, key[0], 1, cache=self._split_path_cache + ) # _right returns one-past-where-key is so we have to subtract # one to use it. we use _right here because there are two # b'' blocks - the root, and the contents of root # we always have a minimum of 2 in self._dirblocks: root and # root-contents, and for b'', we get 2 back, so this is # simple and correct: - present = (block_index < len(self._dirblocks) and - self._dirblocks[block_index][0] == key[0]) + present = ( + block_index < len(self._dirblocks) + and self._dirblocks[block_index][0] == key[0] + ) self._last_block_index = block_index # Reset the entry index cache to the beginning of the block. self._last_entry_index = -1 @@ -1240,16 +1263,16 @@ def _find_entry_index(self, key, block): entry_index = self._last_entry_index + 1 # A hit is when the key is after the last slot, and before or # equal to the next slot. - if ((entry_index > 0 and block[entry_index - 1][0] < key) and - key <= block[entry_index][0]): + if ( + entry_index > 0 and block[entry_index - 1][0] < key + ) and key <= block[entry_index][0]: self._last_entry_index = entry_index - present = (block[entry_index][0] == key) + present = block[entry_index][0] == key return entry_index, present except IndexError: pass entry_index = bisect.bisect_left(block, (key, [])) - present = (entry_index < len_block and - block[entry_index][0] == key) + present = entry_index < len_block and block[entry_index][0] == key self._last_entry_index = entry_index return entry_index, present @@ -1264,8 +1287,7 @@ def from_tree(tree, dir_state_filename, sha1_provider=None): :return: a DirState object which is currently locked for writing. 
(it was locked by DirState.initialize) """ - result = DirState.initialize(dir_state_filename, - sha1_provider=sha1_provider) + result = DirState.initialize(dir_state_filename, sha1_provider=sha1_provider) try: with contextlib.ExitStack() as exit_stack: exit_stack.enter_context(tree.lock_read()) @@ -1273,8 +1295,7 @@ def from_tree(tree, dir_state_filename, sha1_provider=None): len(parent_ids) parent_trees = [] for parent_id in parent_ids: - parent_tree = tree.branch.repository.revision_tree( - parent_id) + parent_tree = tree.branch.repository.revision_tree(parent_id) parent_trees.append((parent_id, parent_tree)) exit_stack.enter_context(parent_tree.lock_read()) result.set_parent_trees(parent_trees, []) @@ -1312,32 +1333,34 @@ def update_by_delta(self, delta): delta.sort() for old_path, new_path, file_id, inv_entry in delta: if not isinstance(file_id, bytes): - raise AssertionError( - f"must be a utf8 file_id not {type(file_id)}") + raise AssertionError(f"must be a utf8 file_id not {type(file_id)}") if (file_id in insertions) or (file_id in removals): - self._raise_invalid(old_path or new_path, file_id, - "repeated file_id") + self._raise_invalid(old_path or new_path, file_id, "repeated file_id") if old_path is not None: - old_path = old_path.encode('utf-8') + old_path = old_path.encode("utf-8") removals[file_id] = old_path else: new_ids.add(file_id) if new_path is not None: if inv_entry is None: - self._raise_invalid(new_path, file_id, - "new_path with no entry") - new_path = new_path.encode('utf-8') + self._raise_invalid(new_path, file_id, "new_path with no entry") + new_path = new_path.encode("utf-8") dirname_utf8, basename = osutils.split(new_path) if basename: parents.add((dirname_utf8, inv_entry.parent_id)) key = (dirname_utf8, basename, file_id) minikind = DirState._kind_to_minikind[inv_entry.kind] - if minikind == b't': - fingerprint = inv_entry.reference_revision or b'' + if minikind == b"t": + fingerprint = inv_entry.reference_revision or b"" else: - fingerprint = b'' - insertions[file_id] = (key, minikind, inv_entry.executable, - fingerprint, new_path) + fingerprint = b"" + insertions[file_id] = ( + key, + minikind, + inv_entry.executable, + fingerprint, + new_path, + ) # Transform moves into delete+add pairs if None not in (old_path, new_path): for child in self._iter_child_entries(0, old_path): @@ -1348,16 +1371,19 @@ def update_by_delta(self, delta): minikind = child[1][0][0] fingerprint = child[1][0][4] executable = child[1][0][3] - old_child_path = osutils.pathjoin(child_dirname, - child_basename) + old_child_path = osutils.pathjoin(child_dirname, child_basename) removals[child[0][2]] = old_child_path - child_suffix = child_dirname[len(old_path):] - new_child_dirname = (new_path + child_suffix) + child_suffix = child_dirname[len(old_path) :] + new_child_dirname = new_path + child_suffix key = (new_child_dirname, child_basename, child[0][2]) - new_child_path = osutils.pathjoin(new_child_dirname, - child_basename) - insertions[child[0][2]] = (key, minikind, executable, - fingerprint, new_child_path) + new_child_path = osutils.pathjoin(new_child_dirname, child_basename) + insertions[child[0][2]] = ( + key, + minikind, + executable, + fingerprint, + new_child_path, + ) self._check_delta_ids_absent(new_ids, 0) try: self._apply_removals(removals.items()) @@ -1366,56 +1392,62 @@ def update_by_delta(self, delta): self._after_delta_check_parents(parents, 0) except errors.BzrError as e: self._changes_aborted = True - if 'integrity error' not in str(e): + if "integrity error" not in 
str(e): raise # _get_entry raises BzrError when a request is inconsistent; we # want such errors to be shown as InconsistentDelta - and that # fits the behaviour we trigger. - raise errors.InconsistentDeltaDelta(delta, f"error from _get_entry. {e}") from e + raise errors.InconsistentDeltaDelta( + delta, f"error from _get_entry. {e}" + ) from e def _apply_removals(self, removals): - for file_id, path in sorted(removals, reverse=True, - key=operator.itemgetter(1)): + for file_id, path in sorted(removals, reverse=True, key=operator.itemgetter(1)): dirname, basename = osutils.split(path) - block_i, entry_i, d_present, f_present = \ - self._get_block_entry_index(dirname, basename, 0) + block_i, entry_i, d_present, f_present = self._get_block_entry_index( + dirname, basename, 0 + ) try: entry = self._dirblocks[block_i][1][entry_i] except IndexError: - self._raise_invalid(path, file_id, - "Wrong path for old path.") - if not f_present or entry[1][0][0] in (b'a', b'r'): - self._raise_invalid(path, file_id, - "Wrong path for old path.") + self._raise_invalid(path, file_id, "Wrong path for old path.") + if not f_present or entry[1][0][0] in (b"a", b"r"): + self._raise_invalid(path, file_id, "Wrong path for old path.") if file_id != entry[0][2]: - self._raise_invalid(path, file_id, - "Attempt to remove path has wrong id - found %r." - % entry[0][2]) + self._raise_invalid( + path, + file_id, + "Attempt to remove path has wrong id - found %r." % entry[0][2], + ) self._make_absent(entry) # See if we have a malformed delta: deleting a directory must not # leave crud behind. This increases the number of bisects needed # substantially, but deletion or renames of large numbers of paths # is rare enough it shouldn't be an issue (famous last words?) RBC # 20080730. - block_i, entry_i, d_present, f_present = \ - self._get_block_entry_index(path, b'', 0) + block_i, entry_i, d_present, f_present = self._get_block_entry_index( + path, b"", 0 + ) if d_present: # The dir block is still present in the dirstate; this could # be due to it being in a parent tree, or a corrupt delta. for child_entry in self._dirblocks[block_i][1]: - if child_entry[1][0][0] not in (b'r', b'a'): - self._raise_invalid(path, entry[0][2], - "The file id was deleted but its children were " - "not deleted.") + if child_entry[1][0][0] not in (b"r", b"a"): + self._raise_invalid( + path, + entry[0][2], + "The file id was deleted but its children were " + "not deleted.", + ) def _apply_insertions(self, adds): try: for key, minikind, executable, fingerprint, path_utf8 in sorted(adds): - self.update_minimal(key, minikind, executable, fingerprint, - path_utf8=path_utf8) + self.update_minimal( + key, minikind, executable, fingerprint, path_utf8=path_utf8 + ) except errors.NotVersionedError: - self._raise_invalid(path_utf8.decode('utf8'), key[2], - "Missing parent") + self._raise_invalid(path_utf8.decode("utf8"), key[2], "Missing parent") def update_basis_by_delta(self, delta, new_revid): """Update the parents of this tree after a commit. @@ -1462,7 +1494,7 @@ def update_basis_by_delta(self, delta, new_revid): # expanding them recursively as needed. # At the same time, to reduce interface friction we convert the input # inventory entries to dirstate. - root_only = ('', '') + root_only = ("", "") # Accumulate parent references (path_utf8, id), to check for parentless # items or items placed under files/links/tree-references. 
We get # references from every item in the delta that is not a deletion and @@ -1473,17 +1505,16 @@ def update_basis_by_delta(self, delta, new_revid): new_ids = set() for old_path, new_path, file_id, inv_entry in delta: if file_id.__class__ is not bytes: - raise AssertionError( - f"must be a utf8 file_id not {type(file_id)}") + raise AssertionError(f"must be a utf8 file_id not {type(file_id)}") if inv_entry is not None and file_id != inv_entry.file_id: - self._raise_invalid(new_path, file_id, - f"mismatched entry file_id {inv_entry!r}") + self._raise_invalid( + new_path, file_id, f"mismatched entry file_id {inv_entry!r}" + ) if new_path is None: new_path_utf8 = None else: if inv_entry is None: - self._raise_invalid(new_path, file_id, - "new_path with no entry") + self._raise_invalid(new_path, file_id, "new_path with no entry") new_path_utf8 = encode(new_path) # note the parent for validation dirname_utf8, basename_utf8 = osutils.split(new_path_utf8) @@ -1494,8 +1525,9 @@ def update_basis_by_delta(self, delta, new_revid): else: old_path_utf8 = encode(old_path) if old_path is None: - adds.append((None, new_path_utf8, file_id, - inv_to_entry(inv_entry), True)) + adds.append( + (None, new_path_utf8, file_id, inv_to_entry(inv_entry), True) + ) new_ids.add(file_id) elif new_path is None: deletes.append((old_path_utf8, None, file_id, None, True)) @@ -1508,8 +1540,9 @@ def update_basis_by_delta(self, delta, new_revid): # handle that case, we can avoid a lot of add+delete # pairs for objects that stay put. # elif old_path == new_path: - changes.append((old_path_utf8, new_path_utf8, file_id, - inv_to_entry(inv_entry))) + changes.append( + (old_path_utf8, new_path_utf8, file_id, inv_to_entry(inv_entry)) + ) else: # Renames: # Because renames must preserve their children we must have @@ -1526,34 +1559,35 @@ def update_basis_by_delta(self, delta, new_revid): self._update_basis_apply_deletes(deletes) deletes = [] # Split into an add/delete pair recursively. - adds.append((old_path_utf8, new_path_utf8, file_id, - inv_to_entry(inv_entry), False)) + adds.append( + ( + old_path_utf8, + new_path_utf8, + file_id, + inv_to_entry(inv_entry), + False, + ) + ) # Expunge deletes that we've seen so that deleted/renamed # children of a rename directory are handled correctly. - new_deletes = reversed(list( - self._iter_child_entries(1, old_path_utf8))) + new_deletes = reversed(list(self._iter_child_entries(1, old_path_utf8))) # Remove the current contents of the tree at orig_path, and # reinsert at the correct new path. 
for entry in new_deletes: child_dirname, child_basename, child_file_id = entry[0] if child_dirname: - source_path = child_dirname + b'/' + child_basename + source_path = child_dirname + b"/" + child_basename else: source_path = child_basename if new_path_utf8: - target_path = \ - new_path_utf8 + source_path[len(old_path_utf8):] + target_path = new_path_utf8 + source_path[len(old_path_utf8) :] else: - if old_path_utf8 == b'': - raise AssertionError("cannot rename directory to" - " itself") - target_path = source_path[len(old_path_utf8) + 1:] - adds.append( - (None, target_path, entry[0][2], entry[1][1], False)) - deletes.append( - (source_path, target_path, entry[0][2], None, False)) - deletes.append( - (old_path_utf8, new_path_utf8, file_id, None, False)) + if old_path_utf8 == b"": + raise AssertionError("cannot rename directory to" " itself") + target_path = source_path[len(old_path_utf8) + 1 :] + adds.append((None, target_path, entry[0][2], entry[1][1], False)) + deletes.append((source_path, target_path, entry[0][2], None, False)) + deletes.append((old_path_utf8, new_path_utf8, file_id, None, False)) self._check_delta_ids_absent(new_ids, 1) try: @@ -1567,12 +1601,14 @@ def update_basis_by_delta(self, delta, new_revid): self._after_delta_check_parents(parents, 1) except errors.BzrError as e: self._changes_aborted = True - if 'integrity error' not in str(e): + if "integrity error" not in str(e): raise # _get_entry raises BzrError when a request is inconsistent; we # want such errors to be shown as InconsistentDelta - and that # fits the behaviour we trigger. - raise errors.InconsistentDeltaDelta(delta, f"error from _get_entry. {e}") from e + raise errors.InconsistentDeltaDelta( + delta, f"error from _get_entry. {e}" + ) from e self._mark_modified(header_modified=True) self._id_index = None @@ -1585,8 +1621,9 @@ def _check_delta_ids_absent(self, new_ids, tree_index): id_index = self._get_id_index() for file_id in new_ids: for key in id_index.get(file_id): - block_i, entry_i, d_present, f_present = \ - self._get_block_entry_index(key[0], key[1], tree_index) + block_i, entry_i, d_present, f_present = self._get_block_entry_index( + key[0], key[1], tree_index + ) if not f_present: # In a different tree continue @@ -1594,9 +1631,12 @@ def _check_delta_ids_absent(self, new_ids, tree_index): if entry[0][2] != file_id: # Different file_id, so not what we want. continue - self._raise_invalid((b"%s/%s" % key[0:2]).decode('utf8'), file_id, - "This file_id is new in the delta but already present in " - "the target") + self._raise_invalid( + (b"%s/%s" % key[0:2]).decode("utf8"), + file_id, + "This file_id is new in the delta but already present in " + "the target", + ) def _raise_invalid(self, path, file_id, reason): self._changes_aborted = True @@ -1631,31 +1671,44 @@ def _update_basis_apply_adds(self, adds): # However, it might have just been an empty directory. Look for # the parent in the basis-so-far before throwing an error. parent_dir, parent_base = osutils.split(dirname) - parent_block_idx, parent_entry_idx, _, parent_present = \ - self._get_block_entry_index(parent_dir, parent_base, 1) + ( + parent_block_idx, + parent_entry_idx, + _, + parent_present, + ) = self._get_block_entry_index(parent_dir, parent_base, 1) if not parent_present: - self._raise_invalid(new_path, file_id, - "Unable to find block for this record." - " Was the parent added?") + self._raise_invalid( + new_path, + file_id, + "Unable to find block for this record." 
+ " Was the parent added?", + ) self._ensure_block(parent_block_idx, parent_entry_idx, dirname) block = self._dirblocks[block_index][1] entry_index, present = self._find_entry_index(entry_key, block) if real_add: if old_path is not None: - self._raise_invalid(new_path, file_id, - f'considered a real add but still had old_path at {old_path}') + self._raise_invalid( + new_path, + file_id, + f"considered a real add but still had old_path at {old_path}", + ) if present: entry = block[entry_index] basis_kind = entry[1][1][0] - if basis_kind == b'a': + if basis_kind == b"a": entry[1][1] = new_details - elif basis_kind == b'r': + elif basis_kind == b"r": raise NotImplementedError() else: - self._raise_invalid(new_path, file_id, - "An entry was marked as a new add" - " but the basis target already existed") + self._raise_invalid( + new_path, + file_id, + "An entry was marked as a new add" + " but the basis target already existed", + ) else: # The exact key was not found in the block. However, we need to # check if there is a key next to us that would have matched. @@ -1670,20 +1723,23 @@ def _update_basis_apply_adds(self, adds): continue if maybe_entry[0][2] == file_id: raise AssertionError( - '_find_entry_index didnt find a key match' - f' but walking the data did, for {entry_key}') + "_find_entry_index didnt find a key match" + f" but walking the data did, for {entry_key}" + ) basis_kind = maybe_entry[1][1][0] - if basis_kind not in (b'a', b'r'): - self._raise_invalid(new_path, file_id, - "we have an add record for path, but the path" - f" is already present with another file_id {maybe_entry[0][2]}") - - entry = (entry_key, [DirState.NULL_PARENT_DETAILS, - new_details]) + if basis_kind not in (b"a", b"r"): + self._raise_invalid( + new_path, + file_id, + "we have an add record for path, but the path" + f" is already present with another file_id {maybe_entry[0][2]}", + ) + + entry = (entry_key, [DirState.NULL_PARENT_DETAILS, new_details]) block.insert(entry_index, entry) active_kind = entry[1][0][0] - if active_kind == b'a': + if active_kind == b"a": # The active record shows up as absent, this could be genuine, # or it could be present at some other location. We need to # verify. @@ -1694,37 +1750,44 @@ def _update_basis_apply_adds(self, adds): # need keys = id_index.get(file_id) for key in keys: - block_i, entry_i, d_present, f_present = \ - self._get_block_entry_index(key[0], key[1], 0) + ( + block_i, + entry_i, + d_present, + f_present, + ) = self._get_block_entry_index(key[0], key[1], 0) if not f_present: continue active_entry = self._dirblocks[block_i][1][entry_i] - if (active_entry[0][2] != file_id): + if active_entry[0][2] != file_id: # Some other file is at this path, we don't need to # link it. continue real_active_kind = active_entry[1][0][0] - if real_active_kind in (b'a', b'r'): + if real_active_kind in (b"a", b"r"): # We found a record, which was not *this* record, # which matches the file_id, but is not actually # present. Something seems *really* wrong. - self._raise_invalid(new_path, file_id, - "We found a tree0 entry that doesnt make sense") + self._raise_invalid( + new_path, + file_id, + "We found a tree0 entry that doesnt make sense", + ) # Now, we've found a tree0 entry which matches the file_id # but is at a different location. So update them to be # rename records. 
active_dir, active_name = active_entry[0][:2] if active_dir: - active_path = active_dir + b'/' + active_name + active_path = active_dir + b"/" + active_name else: active_path = active_name - active_entry[1][1] = st(b'r', new_path, 0, False, b'') - entry[1][0] = st(b'r', active_path, 0, False, b'') - elif active_kind == b'r': + active_entry[1][1] = st(b"r", new_path, 0, False, b"") + entry[1][0] = st(b"r", active_path, 0, False, b"") + elif active_kind == b"r": raise NotImplementedError() new_kind = new_details[0] - if new_kind == b'd': + if new_kind == b"d": self._ensure_block(block_index, entry_index, new_path) def _update_basis_apply_changes(self, changes): @@ -1736,9 +1799,10 @@ def _update_basis_apply_changes(self, changes): for _old_path, new_path, file_id, new_details in changes: # the entry for this file_id must be in tree 0. entry = self._get_entry(1, file_id, new_path) - if entry[0] is None or entry[1][1][0] in (b'a', b'r'): - self._raise_invalid(new_path, file_id, - 'changed entry considered not present') + if entry[0] is None or entry[1][1][0] in (b"a", b"r"): + self._raise_invalid( + new_path, file_id, "changed entry considered not present" + ) entry[1][1] = new_details def _update_basis_apply_deletes(self, deletes): @@ -1759,45 +1823,54 @@ def _update_basis_apply_deletes(self, deletes): self._raise_invalid(old_path, file_id, "bad delete delta") # the entry for this file_id must be in tree 1. dirname, basename = osutils.split(old_path) - block_index, entry_index, dir_present, file_present = \ - self._get_block_entry_index(dirname, basename, 1) + ( + block_index, + entry_index, + dir_present, + file_present, + ) = self._get_block_entry_index(dirname, basename, 1) if not file_present: - self._raise_invalid(old_path, file_id, - 'basis tree does not contain removed entry') + self._raise_invalid( + old_path, file_id, "basis tree does not contain removed entry" + ) entry = self._dirblocks[block_index][1][entry_index] # The state of the entry in the 'active' WT active_kind = entry[1][0][0] if entry[0][2] != file_id: - self._raise_invalid(old_path, file_id, - 'mismatched file_id in tree 1') + self._raise_invalid(old_path, file_id, "mismatched file_id in tree 1") dir_block = () old_kind = entry[1][1][0] - if active_kind in b'ar': + if active_kind in b"ar": # The active tree doesn't have this file_id. # The basis tree is changing this record. If this is a # rename, then we don't want the record here at all # anymore. If it is just an in-place change, we want the # record here, but we'll add it if we need to. So we just # delete it - if active_kind == b'r': + if active_kind == b"r": active_path = entry[1][0][1] active_entry = self._get_entry(0, file_id, active_path) - if active_entry[1][1][0] != b'r': - self._raise_invalid(old_path, file_id, - "Dirstate did not have matching rename entries") - elif active_entry[1][0][0] in b'ar': - self._raise_invalid(old_path, file_id, - "Dirstate had a rename pointing at an inactive" - " tree0") + if active_entry[1][1][0] != b"r": + self._raise_invalid( + old_path, + file_id, + "Dirstate did not have matching rename entries", + ) + elif active_entry[1][0][0] in b"ar": + self._raise_invalid( + old_path, + file_id, + "Dirstate had a rename pointing at an inactive" " tree0", + ) active_entry[1][1] = null del self._dirblocks[block_index][1][entry_index] - if old_kind == b'd': + if old_kind == b"d": # This was a directory, and the active tree says it # doesn't exist, and now the basis tree says it doesn't # exist. 
Remove its dirblock if present - (dir_block_index, - present) = self._find_block_index_from_key( - (old_path, b'', b'')) + (dir_block_index, present) = self._find_block_index_from_key( + (old_path, b"", b"") + ) if present: dir_block = self._dirblocks[dir_block_index][1] if not dir_block: @@ -1807,16 +1880,19 @@ def _update_basis_apply_deletes(self, deletes): # There is still an active record, so just mark this # removed. entry[1][1] = null - block_i, entry_i, d_present, f_present = \ - self._get_block_entry_index(old_path, b'', 1) + block_i, entry_i, d_present, f_present = self._get_block_entry_index( + old_path, b"", 1 + ) if d_present: dir_block = self._dirblocks[block_i][1] for child_entry in dir_block: child_basis_kind = child_entry[1][1][0] - if child_basis_kind not in b'ar': - self._raise_invalid(old_path, file_id, - "The file id was deleted but its children were " - "not deleted.") + if child_basis_kind not in b"ar": + self._raise_invalid( + old_path, + file_id, + "The file id was deleted but its children were " "not deleted.", + ) def _after_delta_check_parents(self, parents, index): """Check that parents required by the delta are all intact. @@ -1831,15 +1907,20 @@ def _after_delta_check_parents(self, parents, index): # has the right file id. entry = self._get_entry(index, file_id, dirname_utf8) if entry[1] is None: - self._raise_invalid(dirname_utf8.decode('utf8'), - file_id, "This parent is not present.") + self._raise_invalid( + dirname_utf8.decode("utf8"), file_id, "This parent is not present." + ) # Parents of things must be directories - if entry[1][index][0] != b'd': - self._raise_invalid(dirname_utf8.decode('utf8'), - file_id, "This parent is not a directory.") - - def _observed_sha1(self, entry, sha1, stat_value, - _stat_to_minikind=_stat_to_minikind): + if entry[1][index][0] != b"d": + self._raise_invalid( + dirname_utf8.decode("utf8"), + file_id, + "This parent is not a directory.", + ) + + def _observed_sha1( + self, entry, sha1, stat_value, _stat_to_minikind=_stat_to_minikind + ): """Note the sha1 of a file. :param entry: The entry the sha1 is for. @@ -1851,13 +1932,20 @@ def _observed_sha1(self, entry, sha1, stat_value, except KeyError: # Unhandled kind return None - if minikind == b'f': + if minikind == b"f": if self._cutoff_time is None: self._sha_cutoff_time() - if (stat_value.st_mtime < self._cutoff_time - and stat_value.st_ctime < self._cutoff_time): - entry[1][0] = (b'f', sha1, stat_value.st_size, entry[1][0][3], - pack_stat(stat_value)) + if ( + stat_value.st_mtime < self._cutoff_time + and stat_value.st_ctime < self._cutoff_time + ): + entry[1][0] = ( + b"f", + sha1, + stat_value.st_size, + entry[1][0][3], + pack_stat(stat_value), + ) self._mark_modified([entry]) def _sha_cutoff_time(self): @@ -1905,9 +1993,9 @@ def _read_link(abspath, old_link): # paths are produced by UnicodeDirReader on purpose. 
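One detail worth calling out from the _observed_sha1 hunk above: a hash is written back into the entry only when both st_mtime and st_ctime are older than the cutoff computed by _sha_cutoff_time, so a file that may still be changing is never cached with a possibly stale hash. A minimal sketch of that guard alone; the helper name and the 3-second window are illustrative, not the real cutoff logic:

import os
import time

def can_cache_hash(stat_value, cutoff_time):
    # Record a freshly observed hash only when the file has not been touched
    # since the cutoff; a file that may still be changing is never cached.
    return (
        stat_value.st_mtime < cutoff_time
        and stat_value.st_ctime < cutoff_time
    )

st = os.stat(__file__)
print(can_cache_hash(st, time.time() - 3))  # False for a file written just now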
abspath = os.fsencode(abspath) target = os.readlink(abspath) - if sys.getfilesystemencoding() not in ('utf-8', 'ascii'): + if sys.getfilesystemencoding() not in ("utf-8", "ascii"): # Change encoding if needed - target = os.fsdecode(target).encode('UTF-8') + target = os.fsdecode(target).encode("UTF-8") return target def get_ghosts(self): @@ -1917,8 +2005,10 @@ def get_ghosts(self): def get_lines(self): """Serialise the entire dirstate to a sequence of lines.""" - if (self._header_state == DirState.IN_MEMORY_UNMODIFIED and - self._dirblock_state == DirState.IN_MEMORY_UNMODIFIED): + if ( + self._header_state == DirState.IN_MEMORY_UNMODIFIED + and self._dirblock_state == DirState.IN_MEMORY_UNMODIFIED + ): # read what's on disk. self._state_file.seek(0) return self._state_file.readlines() @@ -1939,74 +2029,96 @@ def _get_fields_to_entry(self): # This is intentionally unrolled for performance num_present_parents = self._num_present_parents() if num_present_parents == 0: + def fields_to_entry_0_parents(fields, _int=int): path_name_file_id_key = (fields[0], fields[1], fields[2]) - return (path_name_file_id_key, [ - ( # Current tree - fields[3], # minikind - fields[4], # fingerprint - _int(fields[5]), # size - fields[6] == b'y', # executable - fields[7], # packed_stat or revision_id - )]) + return ( + path_name_file_id_key, + [ + ( # Current tree + fields[3], # minikind + fields[4], # fingerprint + _int(fields[5]), # size + fields[6] == b"y", # executable + fields[7], # packed_stat or revision_id + ) + ], + ) + return fields_to_entry_0_parents elif num_present_parents == 1: + def fields_to_entry_1_parent(fields, _int=int): path_name_file_id_key = (fields[0], fields[1], fields[2]) - return (path_name_file_id_key, [ - ( # Current tree - fields[3], # minikind - fields[4], # fingerprint - _int(fields[5]), # size - fields[6] == b'y', # executable - fields[7], # packed_stat or revision_id - ), - ( # Parent 1 - fields[8], # minikind - fields[9], # fingerprint - _int(fields[10]), # size - fields[11] == b'y', # executable - fields[12], # packed_stat or revision_id - ), - ]) + return ( + path_name_file_id_key, + [ + ( # Current tree + fields[3], # minikind + fields[4], # fingerprint + _int(fields[5]), # size + fields[6] == b"y", # executable + fields[7], # packed_stat or revision_id + ), + ( # Parent 1 + fields[8], # minikind + fields[9], # fingerprint + _int(fields[10]), # size + fields[11] == b"y", # executable + fields[12], # packed_stat or revision_id + ), + ], + ) + return fields_to_entry_1_parent elif num_present_parents == 2: + def fields_to_entry_2_parents(fields, _int=int): path_name_file_id_key = (fields[0], fields[1], fields[2]) - return (path_name_file_id_key, [ - ( # Current tree - fields[3], # minikind - fields[4], # fingerprint - _int(fields[5]), # size - fields[6] == b'y', # executable - fields[7], # packed_stat or revision_id - ), - ( # Parent 1 - fields[8], # minikind - fields[9], # fingerprint - _int(fields[10]), # size - fields[11] == b'y', # executable - fields[12], # packed_stat or revision_id - ), - ( # Parent 2 - fields[13], # minikind - fields[14], # fingerprint - _int(fields[15]), # size - fields[16] == b'y', # executable - fields[17], # packed_stat or revision_id - ), - ]) + return ( + path_name_file_id_key, + [ + ( # Current tree + fields[3], # minikind + fields[4], # fingerprint + _int(fields[5]), # size + fields[6] == b"y", # executable + fields[7], # packed_stat or revision_id + ), + ( # Parent 1 + fields[8], # minikind + fields[9], # fingerprint + _int(fields[10]), # size + 
fields[11] == b"y", # executable + fields[12], # packed_stat or revision_id + ), + ( # Parent 2 + fields[13], # minikind + fields[14], # fingerprint + _int(fields[15]), # size + fields[16] == b"y", # executable + fields[17], # packed_stat or revision_id + ), + ], + ) + return fields_to_entry_2_parents else: + def fields_to_entry_n_parents(fields, _int=int): path_name_file_id_key = (fields[0], fields[1], fields[2]) - trees = [(fields[cur], # minikind - fields[cur + 1], # fingerprint - _int(fields[cur + 2]), # size - fields[cur + 3] == b'y', # executable - fields[cur + 4], # stat or revision_id - ) for cur in range(3, len(fields) - 1, 5)] + trees = [ + ( + fields[cur], # minikind + fields[cur + 1], # fingerprint + _int(fields[cur + 2]), # size + fields[cur + 3] == b"y", # executable + fields[cur + 4], # stat or revision_id + ) + for cur in range(3, len(fields) - 1, 5) + ] return path_name_file_id_key, trees + return fields_to_entry_n_parents def get_parent_ids(self): @@ -2035,7 +2147,7 @@ def _get_block_entry_index(self, dirname, basename, tree_index): tree present there. """ self._read_dirblocks_if_needed() - key = dirname, basename, b'' + key = dirname, basename, b"" block_index, present = self._find_block_index_from_key(key) if not present: # no such directory - return the dir index and 0 for the row. @@ -2045,14 +2157,15 @@ def _get_block_entry_index(self, dirname, basename, tree_index): # linear search through entries at this path to find the one # requested. while entry_index < len(block) and block[entry_index][0][1] == basename: - if block[entry_index][1][tree_index][0] not in (b'a', b'r'): + if block[entry_index][1][tree_index][0] not in (b"a", b"r"): # neither absent or relocated return block_index, entry_index, True, True entry_index += 1 return block_index, entry_index, True, False - def _get_entry(self, tree_index, fileid_utf8=None, path_utf8=None, - include_deleted=False): + def _get_entry( + self, tree_index, fileid_utf8=None, path_utf8=None, include_deleted=False + ): """Get the dirstate entry for path in tree tree_index. If either file_id or path is supplied, it is used as the key to lookup. @@ -2074,29 +2187,36 @@ def _get_entry(self, tree_index, fileid_utf8=None, path_utf8=None, self._read_dirblocks_if_needed() if path_utf8 is not None: if not isinstance(path_utf8, bytes): - raise errors.BzrError(f'path_utf8 is not bytes: {type(path_utf8)} {path_utf8!r}') + raise errors.BzrError( + f"path_utf8 is not bytes: {type(path_utf8)} {path_utf8!r}" + ) # path lookups are faster dirname, basename = osutils.split(path_utf8) - block_index, entry_index, dir_present, file_present = \ - self._get_block_entry_index(dirname, basename, tree_index) + ( + block_index, + entry_index, + dir_present, + file_present, + ) = self._get_block_entry_index(dirname, basename, tree_index) if not file_present: return None, None entry = self._dirblocks[block_index][1][entry_index] - if not (entry[0][2] and entry[1][tree_index][0] not in (b'a', b'r')): - raise AssertionError('unversioned entry?') + if not (entry[0][2] and entry[1][tree_index][0] not in (b"a", b"r")): + raise AssertionError("unversioned entry?") if fileid_utf8: if entry[0][2] != fileid_utf8: self._changes_aborted = True - raise errors.BzrError('integrity error ? : mismatching' - ' tree_index, file_id and path') + raise errors.BzrError( + "integrity error ? 
: mismatching" + " tree_index, file_id and path" + ) return entry else: possible_keys = self._get_id_index().get(fileid_utf8) if not possible_keys: return None, None for key in possible_keys: - block_index, present = \ - self._find_block_index_from_key(key) + block_index, present = self._find_block_index_from_key(key) # strange, probably indicates an out of date # id index - for now, allow this. if not present: @@ -2111,23 +2231,25 @@ def _get_entry(self, tree_index, fileid_utf8=None, path_utf8=None, # TODO: We might want to assert that entry[0][2] == # fileid_utf8. # GZ 2017-06-09: Hoist set of minkinds somewhere - if entry[1][tree_index][0] in {b'f', b'd', b'l', b't'}: + if entry[1][tree_index][0] in {b"f", b"d", b"l", b"t"}: # this is the result we are looking for: the # real home of this file_id in this tree. return entry - if entry[1][tree_index][0] == b'a': + if entry[1][tree_index][0] == b"a": # there is no home for this entry in this tree if include_deleted: return entry return None, None - if entry[1][tree_index][0] != b'r': + if entry[1][tree_index][0] != b"r": raise AssertionError( - "entry {!r} has invalid minikind {!r} for tree {!r}".format(entry, - entry[1][tree_index][0], - tree_index)) + "entry {!r} has invalid minikind {!r} for tree {!r}".format( + entry, entry[1][tree_index][0], tree_index + ) + ) real_path = entry[1][tree_index][1] - return self._get_entry(tree_index, fileid_utf8=fileid_utf8, - path_utf8=real_path) + return self._get_entry( + tree_index, fileid_utf8=fileid_utf8, path_utf8=real_path + ) return None, None @classmethod @@ -2151,12 +2273,16 @@ def initialize(cls, path, sha1_provider=None): sha1_provider = DefaultSHA1Provider() result = cls(path, sha1_provider) # root dir and root dir contents with no children. - empty_tree_dirblocks = [(b'', []), (b'', [])] + empty_tree_dirblocks = [(b"", []), (b"", [])] # a new root directory, with a NULLSTAT. 
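The empty-tree root row appended just below is the smallest possible instance of the entry layout that the fields_to_entry helpers above label field by field. Restated as plain data, with placeholder constants (the real inventory.ROOT_ID and DirState.NULLSTAT values live in breezy itself), this is not part of the patch:

ROOT_ID = b"TREE_ROOT"   # assumed stand-in for breezy's inventory.ROOT_ID
NULLSTAT = b"x" * 32     # placeholder for DirState.NULLSTAT

root_entry = (
    (b"", b"", ROOT_ID),                  # key: (dirname_utf8, basename_utf8, file_id)
    [
        (b"d", b"", 0, False, NULLSTAT),  # tree 0 details: minikind, fingerprint,
                                          # size, executable, packed_stat/revision_id
    ],
)
print(root_entry[0][2], root_entry[1][0][0])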
empty_tree_dirblocks[0][1].append( - ((b'', b'', inventory.ROOT_ID), [ - (b'd', b'', 0, False, DirState.NULLSTAT), - ])) + ( + (b"", b"", inventory.ROOT_ID), + [ + (b"d", b"", 0, False, DirState.NULLSTAT), + ], + ) + ) result.lock_write() try: result._set_data([], empty_tree_dirblocks) @@ -2179,13 +2305,12 @@ def _iter_child_entries(self, tree_index, path_utf8): """ pending_dirs = [] next_pending_dirs = [path_utf8] - absent = (b'a', b'r') + absent = (b"a", b"r") while next_pending_dirs: pending_dirs = next_pending_dirs next_pending_dirs = [] for path in pending_dirs: - block_index, present = self._find_block_index_from_key( - (path, b'', b'')) + block_index, present = self._find_block_index_from_key((path, b"", b"")) if block_index == 0: block_index = 1 if len(self._dirblocks) == 1: @@ -2200,9 +2325,9 @@ def _iter_child_entries(self, tree_index, path_utf8): kind = entry[1][tree_index][0] if kind not in absent: yield entry - if kind == b'd': + if kind == b"d": if entry[0][0]: - path = entry[0][0] + b'/' + entry[0][1] + path = entry[0][0] + b"/" + entry[0][1] else: path = entry[0][1] next_pending_dirs.append(path) @@ -2233,16 +2358,28 @@ def _get_id_index(self): @classmethod def _make_deleted_row(cls, fileid_utf8, parents): """Return a deleted row for fileid_utf8.""" - return (b'/', b'RECYCLED.BIN', b'file', fileid_utf8, 0, DirState.NULLSTAT, - b''), parents + return ( + b"/", + b"RECYCLED.BIN", + b"file", + fileid_utf8, + 0, + DirState.NULLSTAT, + b"", + ), parents def _num_present_parents(self): """The number of parent entries in each record row.""" return len(self._parents) - len(self._ghosts) @classmethod - def on_file(cls, path, sha1_provider=None, worth_saving_limit=0, - use_filesystem_for_exec=True): + def on_file( + cls, + path, + sha1_provider=None, + worth_saving_limit=0, + use_filesystem_for_exec=True, + ): """Construct a DirState on the file at path "path". :param path: The path at which the dirstate file on disk should live. 
@@ -2257,9 +2394,12 @@ def on_file(cls, path, sha1_provider=None, worth_saving_limit=0, """ if sha1_provider is None: sha1_provider = DefaultSHA1Provider() - result = cls(path, sha1_provider, - worth_saving_limit=worth_saving_limit, - use_filesystem_for_exec=use_filesystem_for_exec) + result = cls( + path, + sha1_provider, + worth_saving_limit=worth_saving_limit, + use_filesystem_for_exec=use_filesystem_for_exec, + ) return result def _read_dirblocks_if_needed(self): @@ -2283,11 +2423,11 @@ def _read_header(self): """ self._read_prelude() parent_line = self._state_file.readline() - info = parent_line.split(b'\0') + info = parent_line.split(b"\0") int(info[0]) self._parents = info[1:-1] ghost_line = self._state_file.readline() - info = ghost_line.split(b'\0') + info = ghost_line.split(b"\0") int(info[1]) self._ghosts = info[2:-1] self._header_state = DirState.IN_MEMORY_UNMODIFIED @@ -2312,16 +2452,15 @@ def _read_prelude(self): """ header = self._state_file.readline() if header != DirState.HEADER_FORMAT_3: - raise errors.BzrError( - f'invalid header line: {header!r}') + raise errors.BzrError(f"invalid header line: {header!r}") crc_line = self._state_file.readline() - if not crc_line.startswith(b'crc32: '): - raise errors.BzrError(f'missing crc32 checksum: {crc_line!r}') - self.crc_expected = int(crc_line[len(b'crc32: '):-1]) + if not crc_line.startswith(b"crc32: "): + raise errors.BzrError(f"missing crc32 checksum: {crc_line!r}") + self.crc_expected = int(crc_line[len(b"crc32: ") : -1]) num_entries_line = self._state_file.readline() - if not num_entries_line.startswith(b'num_entries: '): - raise errors.BzrError('missing num_entries line') - self._num_entries = int(num_entries_line[len(b'num_entries: '):-1]) + if not num_entries_line.startswith(b"num_entries: "): + raise errors.BzrError("missing num_entries line") + self._num_entries = int(num_entries_line[len(b"num_entries: ") : -1]) def sha1_from_stat(self, path, stat_result): """Find a sha1 given a stat lookup.""" @@ -2332,7 +2471,7 @@ def _get_packed_stat_index(self): if self._packed_stat_index is None: index = {} for _key, tree_details in self._iter_entries(): - if tree_details[0][0] == b'f': + if tree_details[0][0] == b"f": index[tree_details[0][4]] = tree_details[0][1] self._packed_stat_index = index return self._packed_stat_index @@ -2354,8 +2493,7 @@ def save(self): if self._changes_aborted: # Should this be a warning? For now, I'm expecting that places that # mark it inconsistent will warn, making a warning here redundant. - trace.mutter('Not saving DirState because ' - '_changes_aborted is set.') + trace.mutter("Not saving DirState because " "_changes_aborted is set.") return # TODO: Since we now distinguish IN_MEMORY_MODIFIED from # IN_MEMORY_HASH_MODIFIED, we should only fail quietly if we fail @@ -2365,7 +2503,7 @@ def save(self): return grabbed_write_lock = False - if self._lock_state != 'w': + if self._lock_state != "w": grabbed_write_lock, new_lock = self._lock_token.temporary_write_lock() # Switch over to the new lock, as the old one may be closed. 
# TODO: jam 20070315 We should validate the disk file has @@ -2394,13 +2532,15 @@ def save(self): def _maybe_fdatasync(self): """Flush to disk if possible and if not configured off.""" - if self._config_stack.get('dirstate.fdatasync'): + if self._config_stack.get("dirstate.fdatasync"): osutils.fdatasync(self._state_file.fileno()) def _worth_saving(self): """Is it worth saving the dirstate or not?""" - if (self._header_state == DirState.IN_MEMORY_MODIFIED - or self._dirblock_state == DirState.IN_MEMORY_MODIFIED): + if ( + self._header_state == DirState.IN_MEMORY_MODIFIED + or self._dirblock_state == DirState.IN_MEMORY_MODIFIED + ): return True if self._dirblock_state == DirState.IN_MEMORY_HASH_MODIFIED: if self._worth_saving_limit == -1: @@ -2453,12 +2593,12 @@ def set_path_id(self, path, new_id): # Nothing to change. return if new_id.__class__ != bytes: - raise AssertionError( - f"must be a utf8 file_id not {type(new_id)}") + raise AssertionError(f"must be a utf8 file_id not {type(new_id)}") # mark the old path absent, and insert a new root path self._make_absent(entry) - self.update_minimal((b'', b'', new_id), b'd', - path_utf8=b'', packed_stat=entry[1][0][4]) + self.update_minimal( + (b"", b"", new_id), b"d", path_utf8=b"", packed_stat=entry[1][0][4] + ) self._mark_modified() def set_parent_trees(self, trees, ghosts): @@ -2515,10 +2655,11 @@ def set_parent_trees(self, trees, ghosts): # one: the current tree for entry in self._iter_entries(): # skip entries not in the current tree - if entry[1][0][0] in (b'a', b'r'): # absent, relocated + if entry[1][0][0] in (b"a", b"r"): # absent, relocated continue - by_path[entry[0]] = [entry[1][0]] + \ - [DirState.NULL_PARENT_DETAILS] * parent_count + by_path[entry[0]] = [entry[1][0]] + [ + DirState.NULL_PARENT_DETAILS + ] * parent_count # TODO: Possibly inline this, since we know it isn't present yet # id_index[entry[0][2]] = (entry[0],) id_index.add(entry[0]) @@ -2531,8 +2672,9 @@ def set_parent_trees(self, trees, ghosts): # any fileid in this tree as we set the by_path[id] to: # already_processed_tree_details + new_details + new_location_suffix # the suffix is from tree_index+1:parent_count+1. - new_location_suffix = [ - DirState.NULL_PARENT_DETAILS] * (parent_count - tree_index) + new_location_suffix = [DirState.NULL_PARENT_DETAILS] * ( + parent_count - tree_index + ) # now stitch in all the entries from this tree last_dirname = None for path, entry in tree.iter_entries_by_dir(): @@ -2546,7 +2688,7 @@ def set_parent_trees(self, trees, ghosts): # all the mappings are valid and have correct relocation # records where needed. file_id = entry.file_id - path_utf8 = path.encode('utf8') + path_utf8 = path.encode("utf8") dirname, basename = osutils.split(path_utf8) if dirname == last_dirname: # Try to re-use objects as much as possible @@ -2566,16 +2708,16 @@ def set_parent_trees(self, trees, ghosts): # other trees, so put absent pointers there # This is the vertical axis in the matrix, all pointing # to the real path. - by_path[entry_key][tree_index] = st(b'r', path_utf8, 0, - False, b'') + by_path[entry_key][tree_index] = st( + b"r", path_utf8, 0, False, b"" + ) # by path consistency: Insert into an existing path record # (trivial), or add a new one with relocation pointers for the # other tree indexes. if new_entry_key in entry_keys: # there is already an entry where this data belongs, just # insert it. 
- by_path[new_entry_key][tree_index] = \ - _inv_entry_to_details(entry) + by_path[new_entry_key][tree_index] = _inv_entry_to_details(entry) else: # add relocated entries to the horizontal axis - this row # mapping from path,id. We need to look up the correct path @@ -2590,16 +2732,14 @@ def set_parent_trees(self, trees, ghosts): else: # grab any one entry, use it to find the right path. a_key = next(iter(entry_keys)) - if by_path[a_key][lookup_index][0] in (b'r', b'a'): + if by_path[a_key][lookup_index][0] in (b"r", b"a"): # its a pointer or missing statement, use it as # is. - new_details.append( - by_path[a_key][lookup_index]) + new_details.append(by_path[a_key][lookup_index]) else: # we have the right key, make a pointer to it. - real_path = (b'/'.join(a_key[0:2])).strip(b'/') - new_details.append(st(b'r', real_path, 0, False, - b'')) + real_path = (b"/".join(a_key[0:2])).strip(b"/") + new_details.append(st(b"r", real_path, 0, False, b"")) new_details.append(_inv_entry_to_details(entry)) new_details.extend(new_location_suffix) by_path[new_entry_key] = new_details @@ -2634,9 +2774,10 @@ def _key(entry, _split_dirs=split_dirs, _st=static_tuple.StaticTuple): try: split = _split_dirs[dirpath] except KeyError: - split = _st.from_sequence(dirpath.split(b'/')) + split = _st.from_sequence(dirpath.split(b"/")) _split_dirs[dirpath] = split return _st(split, fname, file_id) + return sorted(entry_list, key=_key) def set_state_from_inventory(self, new_inv): @@ -2647,10 +2788,11 @@ def set_state_from_inventory(self, new_inv): :param new_inv: The inventory object to set current state from. """ - if debug.debug_flag_enabled('evil'): - trace.mutter_callsite(1, - "set_state_from_inventory called; please mutate the tree instead") - tracing = debug.debug_flag_enabled('dirstate') + if debug.debug_flag_enabled("evil"): + trace.mutter_callsite( + 1, "set_state_from_inventory called; please mutate the tree instead" + ) + tracing = debug.debug_flag_enabled("dirstate") if tracing: trace.mutter("set_state_from_inventory trace:") self._read_dirblocks_if_needed() @@ -2680,50 +2822,58 @@ def advance(iterator): return next(iterator) except StopIteration: return None + while current_new or current_old: # skip entries in old that are not really there - if current_old and current_old[1][0][0] in (b'a', b'r'): + if current_old and current_old[1][0][0] in (b"a", b"r"): # relocated or absent current_old = advance(old_iterator) continue if current_new: # convert new into dirblock style - new_path_utf8 = current_new[0].encode('utf8') + new_path_utf8 = current_new[0].encode("utf8") new_dirname, new_basename = osutils.split(new_path_utf8) new_id = current_new[1].file_id new_entry_key = (new_dirname, new_basename, new_id) - current_new_minikind = \ - DirState._kind_to_minikind[current_new[1].kind] - if current_new_minikind == b't': - fingerprint = current_new[1].reference_revision or b'' + current_new_minikind = DirState._kind_to_minikind[current_new[1].kind] + if current_new_minikind == b"t": + fingerprint = current_new[1].reference_revision or b"" else: # We normally only insert or remove records, or update # them when it has significantly changed. Then we want to # erase its fingerprint. Unaffected records should # normally not be updated at all. 
- fingerprint = b'' + fingerprint = b"" else: # for safety disable variables - new_path_utf8 = new_dirname = new_basename = new_id = \ - new_entry_key = None + new_path_utf8 = ( + new_dirname + ) = new_basename = new_id = new_entry_key = None # 5 cases, we dont have a value that is strictly greater than everything, so # we make both end conditions explicit if not current_old: # old is finished: insert current_new into the state. if tracing: - trace.mutter("Appending from new '%s'.", - new_path_utf8.decode('utf8')) - self.update_minimal(new_entry_key, current_new_minikind, - executable=current_new[1].executable, - path_utf8=new_path_utf8, fingerprint=fingerprint, - fullscan=True) + trace.mutter( + "Appending from new '%s'.", new_path_utf8.decode("utf8") + ) + self.update_minimal( + new_entry_key, + current_new_minikind, + executable=current_new[1].executable, + path_utf8=new_path_utf8, + fingerprint=fingerprint, + fullscan=True, + ) current_new = advance(new_iterator) elif not current_new: # new is finished if tracing: - trace.mutter("Truncating from old '%s/%s'.", - current_old[0][0].decode('utf8'), - current_old[0][1].decode('utf8')) + trace.mutter( + "Truncating from old '%s/%s'.", + current_old[0][0].decode("utf8"), + current_old[0][1].decode("utf8"), + ) self._make_absent(current_old) current_old = advance(old_iterator) elif new_entry_key == current_old[0]: @@ -2734,38 +2884,54 @@ def advance(iterator): # TODO: update the record if anything significant has changed. # the minimal required trigger is if the execute bit or cached # kind has changed. - if (current_old[1][0][3] != current_new[1].executable or - current_old[1][0][0] != current_new_minikind): + if ( + current_old[1][0][3] != current_new[1].executable + or current_old[1][0][0] != current_new_minikind + ): if tracing: - trace.mutter("Updating in-place change '%s'.", - new_path_utf8.decode('utf8')) - self.update_minimal(current_old[0], current_new_minikind, - executable=current_new[1].executable, - path_utf8=new_path_utf8, fingerprint=fingerprint, - fullscan=True) + trace.mutter( + "Updating in-place change '%s'.", + new_path_utf8.decode("utf8"), + ) + self.update_minimal( + current_old[0], + current_new_minikind, + executable=current_new[1].executable, + path_utf8=new_path_utf8, + fingerprint=fingerprint, + fullscan=True, + ) # both sides are dealt with, move on current_old = advance(old_iterator) current_new = advance(new_iterator) - elif (lt_by_dirs(new_dirname, current_old[0][0]) - or (new_dirname == current_old[0][0] and - new_entry_key[1:] < current_old[0][1:])): + elif lt_by_dirs(new_dirname, current_old[0][0]) or ( + new_dirname == current_old[0][0] + and new_entry_key[1:] < current_old[0][1:] + ): # new comes before: # add a entry for this and advance new if tracing: - trace.mutter("Inserting from new '%s'.", - new_path_utf8.decode('utf8')) - self.update_minimal(new_entry_key, current_new_minikind, - executable=current_new[1].executable, - path_utf8=new_path_utf8, fingerprint=fingerprint, - fullscan=True) + trace.mutter( + "Inserting from new '%s'.", new_path_utf8.decode("utf8") + ) + self.update_minimal( + new_entry_key, + current_new_minikind, + executable=current_new[1].executable, + path_utf8=new_path_utf8, + fingerprint=fingerprint, + fullscan=True, + ) current_new = advance(new_iterator) else: # we've advanced past the place where the old key would be, # without seeing it in the new list. so it must be gone. 
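The loop being reindented in this set_state_from_inventory hunk is a sorted-merge of two streams, existing dirstate rows and new inventory entries, where the key comparison decides between append, in-place update, insert and delete. The same control flow over ordinary sorted mappings, stripped of all dirstate detail (illustrative only, not the dirstate API):

def merge_sorted(old, new):
    """Yield ('add'|'update'|'delete', key) by walking two key-sorted mappings."""
    old_keys, new_keys = sorted(old), sorted(new)
    i = j = 0
    while i < len(old_keys) or j < len(new_keys):
        if j >= len(new_keys):                  # new side exhausted: leftovers are gone
            yield "delete", old_keys[i]
            i += 1
        elif i >= len(old_keys):                # old side exhausted: leftovers are new
            yield "add", new_keys[j]
            j += 1
        elif old_keys[i] == new_keys[j]:        # same key: update in place if changed
            if old[old_keys[i]] != new[new_keys[j]]:
                yield "update", old_keys[i]
            i += 1
            j += 1
        elif new_keys[j] < old_keys[i]:         # new key sorts first: insert it
            yield "add", new_keys[j]
            j += 1
        else:                                   # old key not seen in new: delete it
            yield "delete", old_keys[i]
            i += 1

print(list(merge_sorted({"a": 1, "b": 2}, {"b": 3, "c": 4})))
# [('delete', 'a'), ('update', 'b'), ('add', 'c')]

The real loop additionally skips absent and relocated rows (b"a"/b"r") before comparing, as the first branch above this hunk shows, and compares UTF-8 encoded paths rather than plain strings.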
if tracing: - trace.mutter("Deleting from old '%s/%s'.", - current_old[0][0].decode('utf8'), - current_old[0][1].decode('utf8')) + trace.mutter( + "Deleting from old '%s/%s'.", + current_old[0][0].decode("utf8"), + current_old[0][1].decode("utf8"), + ) self._make_absent(current_old) current_old = advance(old_iterator) self._mark_modified() @@ -2784,9 +2950,11 @@ def set_state_from_scratch(self, working_inv, parent_trees, parent_ghosts): self._requires_lock() # root dir and root dir contents with no children. We have to have a # root for set_state_from_inventory to work correctly. - empty_root = ((b'', b'', inventory.ROOT_ID), - [(b'd', b'', 0, False, DirState.NULLSTAT)]) - empty_tree_dirblocks = [(b'', [empty_root]), (b'', [])] + empty_root = ( + (b"", b"", inventory.ROOT_ID), + [(b"d", b"", 0, False, DirState.NULLSTAT)], + ) + empty_tree_dirblocks = [(b"", [empty_root]), (b"", [])] self._set_data([], empty_tree_dirblocks) self.set_state_from_inventory(working_inv) self.set_parent_trees(parent_trees, parent_ghosts) @@ -2803,12 +2971,13 @@ def _make_absent(self, current_old): all_remaining_keys = set() # Dont check the working tree, because it's going. for details in current_old[1][1:]: - if details[0] not in (b'a', b'r'): # absent, relocated + if details[0] not in (b"a", b"r"): # absent, relocated all_remaining_keys.add(current_old[0]) - elif details[0] == b'r': # relocated + elif details[0] == b"r": # relocated # record the key for the real path. all_remaining_keys.add( - tuple(osutils.split(details[1])) + (current_old[0][2],)) + tuple(osutils.split(details[1])) + (current_old[0][2],) + ) # absent rows are not present at any path. last_reference = current_old[0] not in all_remaining_keys if last_reference: @@ -2816,11 +2985,9 @@ def _make_absent(self, current_old): # absent), and relocated or absent entries for the other trees: # Remove it, its meaningless. block = self._find_block(current_old[0]) - entry_index, present = self._find_entry_index( - current_old[0], block[1]) + entry_index, present = self._find_entry_index(current_old[0], block[1]) if not present: - raise AssertionError( - f'could not find entry for {current_old}') + raise AssertionError(f"could not find entry for {current_old}") block[1].pop(entry_index) # if we have an id_index in use, remove this key from it for this id. if self._id_index is not None: @@ -2830,27 +2997,35 @@ def _make_absent(self, current_old): # (if there were other trees with the id present at this path), or may # be relocations. 
for update_key in all_remaining_keys: - update_block_index, present = \ - self._find_block_index_from_key(update_key) + update_block_index, present = self._find_block_index_from_key(update_key) if not present: - raise AssertionError( - f'could not find block for {update_key}') - update_entry_index, present = \ - self._find_entry_index( - update_key, self._dirblocks[update_block_index][1]) + raise AssertionError(f"could not find block for {update_key}") + update_entry_index, present = self._find_entry_index( + update_key, self._dirblocks[update_block_index][1] + ) if not present: - raise AssertionError( - f'could not find entry for {update_key}') - update_tree_details = self._dirblocks[update_block_index][1][update_entry_index][1] + raise AssertionError(f"could not find entry for {update_key}") + update_tree_details = self._dirblocks[update_block_index][1][ + update_entry_index + ][1] # it must not be absent at the moment - if update_tree_details[0][0] == b'a': # absent - raise AssertionError(f'bad row {update_tree_details!r}') + if update_tree_details[0][0] == b"a": # absent + raise AssertionError(f"bad row {update_tree_details!r}") update_tree_details[0] = DirState.NULL_PARENT_DETAILS self._mark_modified() return last_reference - def update_minimal(self, key, minikind, executable=False, fingerprint=b'', - packed_stat=None, size=0, path_utf8=None, fullscan=False): + def update_minimal( + self, + key, + minikind, + executable=False, + fingerprint=b"", + packed_stat=None, + size=0, + path_utf8=None, + fullscan=False, + ): """Update an entry to the state in tree 0. This will either create a new entry at 'key' or update an existing one. @@ -2886,18 +3061,20 @@ def update_minimal(self, key, minikind, executable=False, fingerprint=b'', if not present: # New record. Check there isn't a entry at this path already. if not fullscan: - low_index, _ = self._find_entry_index(key[0:2] + (b'',), block) + low_index, _ = self._find_entry_index(key[0:2] + (b"",), block) while low_index < len(block): entry = block[low_index] if entry[0][0:2] == key[0:2]: - if entry[1][0][0] not in (b'a', b'r'): + if entry[1][0][0] not in (b"a", b"r"): # This entry has the same path (but a different id) as # the new entry we're adding, and is present in ths # tree. self._raise_invalid( - (b"%s/%s" % key[0:2]).decode('utf8'), key[2], + (b"%s/%s" % key[0:2]).decode("utf8"), + key[2], "Attempt to add item at path already occupied by " - "id %r" % entry[0][2]) + "id %r" % entry[0][2], + ) low_index += 1 else: break @@ -2919,26 +3096,28 @@ def update_minimal(self, key, minikind, executable=False, fingerprint=b'', # the test for existing kinds is different: this can be # factored out to a helper though. other_block_index, present = self._find_block_index_from_key( - other_key) + other_key + ) if not present: - raise AssertionError(f'could not find block for {other_key}') + raise AssertionError(f"could not find block for {other_key}") other_block = self._dirblocks[other_block_index][1] other_entry_index, present = self._find_entry_index( - other_key, other_block) + other_key, other_block + ) if not present: raise AssertionError( - f'update_minimal: could not find other entry for {other_key}') + f"update_minimal: could not find other entry for {other_key}" + ) if path_utf8 is None: - raise AssertionError('no path') + raise AssertionError("no path") # Turn this other location into a reference to the new # location. 
This also updates the aliased iterator # (current_old in set_state_from_inventory) so that the old # entry, if not already examined, is skipped over by that # loop. other_entry = other_block[other_entry_index] - other_entry[1][0] = (b'r', path_utf8, 0, False, b'') - if self._maybe_remove_row(other_block, other_entry_index, - id_index): + other_entry[1][0] = (b"r", path_utf8, 0, False, b"") + if self._maybe_remove_row(other_block, other_entry_index, id_index): # If the row holding this was removed, we need to # recompute where this entry goes entry_index, _ = self._find_entry_index(key, block) @@ -2957,27 +3136,29 @@ def update_minimal(self, key, minikind, executable=False, fingerprint=b'', # TODO: optimise this to reduce memory use in highly # fragmented situations by reusing the relocation # records. - update_block_index, present = \ - self._find_block_index_from_key(other_key) + update_block_index, present = self._find_block_index_from_key( + other_key + ) if not present: - raise AssertionError( - f'could not find block for {other_key}') - update_entry_index, present = \ - self._find_entry_index( - other_key, self._dirblocks[update_block_index][1]) + raise AssertionError(f"could not find block for {other_key}") + update_entry_index, present = self._find_entry_index( + other_key, self._dirblocks[update_block_index][1] + ) if not present: raise AssertionError( - f'update_minimal: could not find entry for {other_key}') - update_details = self._dirblocks[update_block_index][1][update_entry_index][1][lookup_index] - if update_details[0] in (b'a', b'r'): # relocated, absent + f"update_minimal: could not find entry for {other_key}" + ) + update_details = self._dirblocks[update_block_index][1][ + update_entry_index + ][1][lookup_index] + if update_details[0] in (b"a", b"r"): # relocated, absent # its a pointer or absent in lookup_index's tree, use # it as is. new_entry[1].append(update_details) else: # we have the right key, make a pointer to it. pointer_path = osutils.pathjoin(*other_key[0:2]) - new_entry[1].append( - (b'r', pointer_path, 0, False, b'')) + new_entry[1].append((b"r", pointer_path, 0, False, b"")) block.insert(entry_index, new_entry) id_index.add(key) else: @@ -2993,12 +3174,14 @@ def update_minimal(self, key, minikind, executable=False, fingerprint=b'', # that were absent - where parent entries are - and they need to be # converted to relocated. if path_utf8 is None: - raise AssertionError('no path') + raise AssertionError("no path") existing_keys = id_index.get(key[2]) if key not in existing_keys: - raise AssertionError('We found the entry in the blocks, but' - ' the key is not in the id_index.' - f' key: {key}, existing_keys: {existing_keys}') + raise AssertionError( + "We found the entry in the blocks, but" + " the key is not in the id_index." + f" key: {key}, existing_keys: {existing_keys}" + ) for entry_key in existing_keys: # TODO:PROFILING: It might be faster to just update # rather than checking if we need to, and then overwrite @@ -3008,20 +3191,25 @@ def update_minimal(self, key, minikind, executable=False, fingerprint=b'', # other trees, so put absent pointers there # This is the vertical axis in the matrix, all pointing # to the real path. 
- block_index, present = self._find_block_index_from_key( - entry_key) + block_index, present = self._find_block_index_from_key(entry_key) if not present: - raise AssertionError('not present: %r', entry_key) + raise AssertionError("not present: %r", entry_key) entry_index, present = self._find_entry_index( - entry_key, self._dirblocks[block_index][1]) + entry_key, self._dirblocks[block_index][1] + ) if not present: - raise AssertionError('not present: %r', entry_key) - self._dirblocks[block_index][1][entry_index][1][0] = \ - (b'r', path_utf8, 0, False, b'') + raise AssertionError("not present: %r", entry_key) + self._dirblocks[block_index][1][entry_index][1][0] = ( + b"r", + path_utf8, + 0, + False, + b"", + ) # add a containing dirblock if needed. - if new_details[0] == b'd': + if new_details[0] == b"d": # GZ 2017-06-09: Using pathjoin why? - subdir_key = (osutils.pathjoin(*key[0:2]), b'', b'') + subdir_key = (osutils.pathjoin(*key[0:2]), b"", b"") block_index, present = self._find_block_index_from_key(subdir_key) if not present: self._dirblocks.insert(block_index, (subdir_key[0], [])) @@ -3037,7 +3225,7 @@ def _maybe_remove_row(self, block, index, id_index): present_in_row = False entry = block[index] for column in entry[1]: - if column[0] not in (b'a', b'r'): + if column[0] not in (b"a", b"r"): present_in_row = True break if not present_in_row: @@ -3066,26 +3254,28 @@ def _validate(self): # # -- mbp 20070325 from pprint import pformat + self._read_dirblocks_if_needed() if len(self._dirblocks) > 0: - if not self._dirblocks[0][0] == b'': + if not self._dirblocks[0][0] == b"": raise AssertionError( "dirblocks don't start with root block:\n" - + pformat(self._dirblocks)) + + pformat(self._dirblocks) + ) if len(self._dirblocks) > 1: - if not self._dirblocks[1][0] == b'': + if not self._dirblocks[1][0] == b"": raise AssertionError( - "dirblocks missing root directory:\n" - + pformat(self._dirblocks)) + "dirblocks missing root directory:\n" + pformat(self._dirblocks) + ) # the dirblocks are sorted by their path components, name, and dir id - dir_names = [d[0].split(b'/') - for d in self._dirblocks[1:]] + dir_names = [d[0].split(b"/") for d in self._dirblocks[1:]] if dir_names != sorted(dir_names): raise AssertionError( - "dir names are not in sorted order:\n" + - pformat(self._dirblocks) + - "\nkeys:\n" + - pformat(dir_names)) + "dir names are not in sorted order:\n" + + pformat(self._dirblocks) + + "\nkeys:\n" + + pformat(dir_names) + ) for dirblock in self._dirblocks: # within each dirblock, the entries are sorted by filename and # then by id. @@ -3093,10 +3283,12 @@ def _validate(self): if dirblock[0] != entry[0][0]: raise AssertionError( f"entry key for {entry!r}" - f"doesn't match directory name in\n{pformat(dirblock)!r}") + f"doesn't match directory name in\n{pformat(dirblock)!r}" + ) if dirblock[1] != sorted(dirblock[1]): raise AssertionError( - f"dirblock for {dirblock[0]!r} is not sorted:\n{pformat(dirblock)}") + f"dirblock for {dirblock[0]!r} is not sorted:\n{pformat(dirblock)}" + ) def check_valid_parent(): """Check that the current entry has a valid parent. @@ -3106,17 +3298,19 @@ def check_valid_parent(): current tree. (It is invalid to have a non-absent file in an absent directory.) 
""" - if entry[0][0:2] == (b'', b''): + if entry[0][0:2] == (b"", b""): # There should be no parent for the root row return parent_entry = self._get_entry(tree_index, path_utf8=entry[0][0]) if parent_entry == (None, None): raise AssertionError( - f"no parent entry for: {this_path} in tree {tree_index}") - if parent_entry[1][tree_index][0] != b'd': + f"no parent entry for: {this_path} in tree {tree_index}" + ) + if parent_entry[1][tree_index][0] != b"d": raise AssertionError( f"Parent entry for {this_path} is not marked as a valid" - f" directory. {parent_entry}") + f" directory. {parent_entry}" + ) # For each file id, for each tree: either # the file id is not present at all; all rows with that id in the @@ -3135,64 +3329,68 @@ def check_valid_parent(): if len(entry[1]) != tree_count: raise AssertionError( "wrong number of entry details for row\n%s" - ",\nexpected %d" % - (pformat(entry), tree_count)) + ",\nexpected %d" % (pformat(entry), tree_count) + ) absent_positions = 0 for tree_index, tree_state in enumerate(entry[1]): this_tree_map = id_path_maps[tree_index] minikind = tree_state[0] - if minikind in (b'a', b'r'): + if minikind in (b"a", b"r"): absent_positions += 1 # have we seen this id before in this column? if file_id in this_tree_map: previous_path, previous_loc = this_tree_map[file_id] # any later mention of this file must be consistent with # what was said before - if minikind == b'a': + if minikind == b"a": if previous_path is not None: raise AssertionError( "file {} is absent in row {!r} but also present " - "at {!r}".format(file_id.decode('utf-8'), entry, previous_path)) - elif minikind == b'r': + "at {!r}".format( + file_id.decode("utf-8"), entry, previous_path + ) + ) + elif minikind == b"r": target_location = tree_state[1] if previous_path != target_location: raise AssertionError( - f"file {file_id} relocation in row {entry!r} but also at {previous_path!r}") + f"file {file_id} relocation in row {entry!r} but also at {previous_path!r}" + ) else: # a file, directory, etc - may have been previously # pointed to by a relocation, which must point here if previous_path != this_path: raise AssertionError( "entry {!r} inconsistent with previous path {!r} " - "seen at {!r}".format(entry, previous_path, previous_loc)) + "seen at {!r}".format( + entry, previous_path, previous_loc + ) + ) check_valid_parent() else: - if minikind == b'a': + if minikind == b"a": # absent; should not occur anywhere else this_tree_map[file_id] = None, this_path - elif minikind == b'r': + elif minikind == b"r": # relocation, must occur at expected location this_tree_map[file_id] = tree_state[1], this_path else: this_tree_map[file_id] = this_path, this_path check_valid_parent() if absent_positions == tree_count: - raise AssertionError( - f"entry {entry!r} has no data for any tree.") + raise AssertionError(f"entry {entry!r} has no data for any tree.") if self._id_index is not None: for entry_key in self._id_index.iter_all(): # And that from this entry key, we can look up the original # record - block_index, present = self._find_block_index_from_key( - entry_key) + block_index, present = self._find_block_index_from_key(entry_key) if not present: - raise AssertionError( - 'missing block for entry key: %r', entry_key) + raise AssertionError("missing block for entry key: %r", entry_key) entry_index, present = self._find_entry_index( - entry_key, self._dirblocks[block_index][1]) + entry_key, self._dirblocks[block_index][1] + ) if not present: - raise AssertionError( - 'missing entry for key: %r', entry_key) + 
raise AssertionError("missing entry for key: %r", entry_key) def _wipe_state(self): """Forget all state information about the dirstate.""" @@ -3217,7 +3415,7 @@ def lock_read(self): # any modification. If not modified, we can just leave things # alone self._lock_token = _transport_rs.ReadLock(self._filename) - self._lock_state = 'r' + self._lock_state = "r" self._state_file = self._lock_token.f self._wipe_state() return lock.LogicalLockResult(self.unlock) @@ -3231,7 +3429,7 @@ def lock_write(self): # any modification. If not modified, we can just leave things # alone self._lock_token = _transport_rs.WriteLock(self._filename) - self._lock_state = 'w' + self._lock_state = "w" self._state_file = self._lock_token.f self._wipe_state() return lock.LogicalLockResult(self.unlock, self._lock_token) @@ -3256,8 +3454,9 @@ def _requires_lock(self): raise errors.ObjectNotLocked(self) -def py_update_entry(state, entry, abspath, stat_value, - _stat_to_minikind=DirState._stat_to_minikind): +def py_update_entry( + state, entry, abspath, stat_value, _stat_to_minikind=DirState._stat_to_minikind +): """Update the entry based on what is actually on disk. This function only calculates the sha if it needs to - if the entry is @@ -3277,18 +3476,22 @@ def py_update_entry(state, entry, abspath, stat_value, # Unhandled kind return None packed_stat = pack_stat(stat_value) - (saved_minikind, saved_link_or_sha1, saved_file_size, - saved_executable, saved_packed_stat) = entry[1][0] + ( + saved_minikind, + saved_link_or_sha1, + saved_file_size, + saved_executable, + saved_packed_stat, + ) = entry[1][0] if not isinstance(saved_minikind, bytes): raise TypeError(saved_minikind) - if minikind == b'd' and saved_minikind == b't': - minikind = b't' - if (minikind == saved_minikind - and packed_stat == saved_packed_stat): + if minikind == b"d" and saved_minikind == b"t": + minikind = b"t" + if minikind == saved_minikind and packed_stat == saved_packed_stat: # The stat hasn't changed since we saved, so we can re-use the # saved sha hash. - if minikind == b'd': + if minikind == b"d": return None # size should also be in packed_stat @@ -3299,15 +3502,16 @@ def py_update_entry(state, entry, abspath, stat_value, # process this entry. link_or_sha1 = None worth_saving = True - if minikind == b'f': - executable = state._is_executable(stat_value.st_mode, - saved_executable) + if minikind == b"f": + executable = state._is_executable(stat_value.st_mode, saved_executable) if state._cutoff_time is None: state._sha_cutoff_time() - if (stat_value.st_mtime < state._cutoff_time + if ( + stat_value.st_mtime < state._cutoff_time and stat_value.st_ctime < state._cutoff_time and len(entry[1]) > 1 - and entry[1][1][0] != b'a'): + and entry[1][1][0] != b"a" + ): # Could check for size changes for further optimised # avoidance of sha1's. However the most prominent case of # over-shaing is during initial add, which this catches. @@ -3315,59 +3519,89 @@ def py_update_entry(state, entry, abspath, stat_value, # are calculated at the same time, so checking just the size # gains nothing w.r.t. performance. 
link_or_sha1 = state._sha1_file(abspath) - entry[1][0] = (b'f', link_or_sha1, stat_value.st_size, - executable, packed_stat) + entry[1][0] = ( + b"f", + link_or_sha1, + stat_value.st_size, + executable, + packed_stat, + ) else: - entry[1][0] = (b'f', b'', stat_value.st_size, - executable, DirState.NULLSTAT) + entry[1][0] = (b"f", b"", stat_value.st_size, executable, DirState.NULLSTAT) worth_saving = False - elif minikind == b'd': + elif minikind == b"d": link_or_sha1 = None - entry[1][0] = (b'd', b'', 0, False, packed_stat) - if saved_minikind != b'd': + entry[1][0] = (b"d", b"", 0, False, packed_stat) + if saved_minikind != b"d": # This changed from something into a directory. Make sure we # have a directory block for it. This doesn't happen very # often, so this doesn't have to be super fast. - block_index, entry_index, dir_present, file_present = \ - state._get_block_entry_index(entry[0][0], entry[0][1], 0) - state._ensure_block(block_index, entry_index, - osutils.pathjoin(entry[0][0], entry[0][1])) + ( + block_index, + entry_index, + dir_present, + file_present, + ) = state._get_block_entry_index(entry[0][0], entry[0][1], 0) + state._ensure_block( + block_index, entry_index, osutils.pathjoin(entry[0][0], entry[0][1]) + ) else: worth_saving = False - elif minikind == b'l': - if saved_minikind == b'l': + elif minikind == b"l": + if saved_minikind == b"l": worth_saving = False link_or_sha1 = state._read_link(abspath, saved_link_or_sha1) if state._cutoff_time is None: state._sha_cutoff_time() - if (stat_value.st_mtime < state._cutoff_time - and stat_value.st_ctime < state._cutoff_time): - entry[1][0] = (b'l', link_or_sha1, stat_value.st_size, - False, packed_stat) + if ( + stat_value.st_mtime < state._cutoff_time + and stat_value.st_ctime < state._cutoff_time + ): + entry[1][0] = (b"l", link_or_sha1, stat_value.st_size, False, packed_stat) else: - entry[1][0] = (b'l', b'', stat_value.st_size, - False, DirState.NULLSTAT) + entry[1][0] = (b"l", b"", stat_value.st_size, False, DirState.NULLSTAT) if worth_saving: state._mark_modified([entry]) return link_or_sha1 class ProcessEntryPython: - - __slots__ = ["old_dirname_to_file_id", "new_dirname_to_file_id", - "last_source_parent", "last_target_parent", "include_unchanged", - "partial", "use_filesystem_for_exec", "utf8_decode", - "searched_specific_files", "search_specific_files", - "searched_exact_paths", "search_specific_file_parents", "seen_ids", - "state", "source_index", "target_index", "want_unversioned", "tree"] - - def __init__(self, include_unchanged, use_filesystem_for_exec, - search_specific_files, state, source_index, target_index, - want_unversioned, tree): + __slots__ = [ + "old_dirname_to_file_id", + "new_dirname_to_file_id", + "last_source_parent", + "last_target_parent", + "include_unchanged", + "partial", + "use_filesystem_for_exec", + "utf8_decode", + "searched_specific_files", + "search_specific_files", + "searched_exact_paths", + "search_specific_file_parents", + "seen_ids", + "state", + "source_index", + "target_index", + "want_unversioned", + "tree", + ] + + def __init__( + self, + include_unchanged, + use_filesystem_for_exec, + search_specific_files, + state, + source_index, + target_index, + want_unversioned, + tree, + ): self.old_dirname_to_file_id = {} self.new_dirname_to_file_id = {} # Are we doing a partial iter_changes? - self.partial = search_specific_files != {''} + self.partial = search_specific_files != {""} # Using a list so that we can access the values and change them in # nested scope. 
Each one is [path, file_id, entry] self.last_source_parent = [None, None] @@ -3394,7 +3628,7 @@ def __init__(self, include_unchanged, use_filesystem_for_exec, self.target_index = target_index if target_index != 0: # A lot of code in here depends on target_index == 0 - raise errors.BzrError('unsupported target index') + raise errors.BzrError("unsupported target index") self.want_unversioned = want_unversioned self.tree = tree @@ -3420,16 +3654,17 @@ def _process_entry(self, entry, path_info, pathjoin=osutils.pathjoin): else: source_details = entry[1][self.source_index] # GZ 2017-06-09: Eck, more sets. - _fdltr = {b'f', b'd', b'l', b't', b'r'} - _fdlt = {b'f', b'd', b'l', b't'} - _ra = (b'r', b'a') + _fdltr = {b"f", b"d", b"l", b"t", b"r"} + _fdlt = {b"f", b"d", b"l", b"t"} + _ra = (b"r", b"a") target_details = entry[1][self.target_index] target_minikind = target_details[0] if path_info is not None and target_minikind in _fdlt: if not (self.target_index == 0): raise AssertionError() - link_or_sha1 = update_entry(self.state, entry, - abspath=path_info[4], stat_value=path_info[3]) + link_or_sha1 = update_entry( + self.state, entry, abspath=path_info[4], stat_value=path_info[3] + ) # The entry may have been modified by update_entry target_details = entry[1][self.target_index] target_minikind = target_details[0] @@ -3443,26 +3678,28 @@ def _process_entry(self, entry, path_info, pathjoin=osutils.pathjoin): # | | | diff check on source-target # r | fdlt | a | dangling file that was present in the basis. # | | | ??? - if source_minikind == b'r': + if source_minikind == b"r": # add the source to the search path to find any children it # has. TODO ? : only add if it is a container ? - if not osutils.is_inside_any(self.searched_specific_files, - source_details[1]): + if not osutils.is_inside_any( + self.searched_specific_files, source_details[1] + ): self.search_specific_files.add(source_details[1]) # generate the old path; this is needed for stating later # as well. old_path = source_details[1] old_dirname, old_basename = os.path.split(old_path) path = pathjoin(entry[0][0], entry[0][1]) - old_entry = self.state._get_entry(self.source_index, - path_utf8=old_path) + old_entry = self.state._get_entry(self.source_index, path_utf8=old_path) # update the source details variable to be the real # location. if old_entry == (None, None): - raise DirstateCorrupt(self.state._filename, - "entry '{}/{}' is considered renamed from {!r}" - " but source does not exist\n" - "entry: {}".format(entry[0][0], entry[0][1], old_path, entry)) + raise DirstateCorrupt( + self.state._filename, + "entry '{}/{}' is considered renamed from {!r}" + " but source does not exist\n" + "entry: {}".format(entry[0][0], entry[0][1], old_path, entry), + ) source_details = old_entry[1][self.source_index] source_minikind = source_details[0] else: @@ -3477,18 +3714,18 @@ def _process_entry(self, entry, path_info, pathjoin=osutils.pathjoin): else: # source and target are both versioned and disk file is present. target_kind = path_info[2] - if target_kind == 'directory': + if target_kind == "directory": if path is None: old_path = path = pathjoin(old_dirname, old_basename) self.new_dirname_to_file_id[path] = file_id - if source_minikind != b'd': + if source_minikind != b"d": content_change = True else: # directories have no fingerprint content_change = False target_exec = False - elif target_kind == 'file': - if source_minikind != b'f': + elif target_kind == "file": + if source_minikind != b"f": content_change = True else: # Check the sha. 
We can't just rely on the size as @@ -3496,12 +3733,12 @@ def _process_entry(self, entry, path_info, pathjoin=osutils.pathjoin): # map to the same content if link_or_sha1 is None: # Stat cache miss: - statvalue, link_or_sha1 = \ - self.state._sha1_provider.stat_and_sha1( - path_info[4]) - self.state._observed_sha1(entry, link_or_sha1, - statvalue) - content_change = (link_or_sha1 != source_details[1]) + ( + statvalue, + link_or_sha1, + ) = self.state._sha1_provider.stat_and_sha1(path_info[4]) + self.state._observed_sha1(entry, link_or_sha1, statvalue) + content_change = link_or_sha1 != source_details[1] # Target details is updated at update_entry time if self.use_filesystem_for_exec: # We don't need S_ISREG here, because we are sure @@ -3509,14 +3746,14 @@ def _process_entry(self, entry, path_info, pathjoin=osutils.pathjoin): target_exec = bool(stat.S_IEXEC & path_info[3].st_mode) else: target_exec = target_details[3] - elif target_kind == 'symlink': - if source_minikind != b'l': + elif target_kind == "symlink": + if source_minikind != b"l": content_change = True else: - content_change = (link_or_sha1 != source_details[1]) + content_change = link_or_sha1 != source_details[1] target_exec = False - elif target_kind == 'tree-reference': - if source_minikind != b't': + elif target_kind == "tree-reference": + if source_minikind != b"t": content_change = True else: content_change = False @@ -3525,7 +3762,7 @@ def _process_entry(self, entry, path_info, pathjoin=osutils.pathjoin): if path is None: path = pathjoin(old_dirname, old_basename) raise errors.BadFileKindError(path, path_info[2]) - if source_minikind == b'd': + if source_minikind == b"d": if path is None: old_path = path = pathjoin(old_dirname, old_basename) self.old_dirname_to_file_id[old_path] = file_id @@ -3536,8 +3773,9 @@ def _process_entry(self, entry, path_info, pathjoin=osutils.pathjoin): try: source_parent_id = self.old_dirname_to_file_id[old_dirname] except KeyError: - source_parent_entry = self.state._get_entry(self.source_index, - path_utf8=old_dirname) + source_parent_entry = self.state._get_entry( + self.source_index, path_utf8=old_dirname + ) source_parent_id = source_parent_entry[0][2] if source_parent_id == entry[0][2]: # This is the root, so the parent is None @@ -3554,11 +3792,13 @@ def _process_entry(self, entry, path_info, pathjoin=osutils.pathjoin): except KeyError as e: # TODO: We don't always need to do the lookup, because the # parent entry will be the same as the source entry. 
- target_parent_entry = self.state._get_entry(self.target_index, - path_utf8=new_dirname) + target_parent_entry = self.state._get_entry( + self.target_index, path_utf8=new_dirname + ) if target_parent_entry == (None, None): raise AssertionError( - f"Could not find target parent in wt: {new_dirname}\nparent of: {entry}") from e + f"Could not find target parent in wt: {new_dirname}\nparent of: {entry}" + ) from e target_parent_id = target_parent_entry[0][2] if target_parent_id == entry[0][2]: # This is the root, so the parent is None @@ -3568,24 +3808,25 @@ def _process_entry(self, entry, path_info, pathjoin=osutils.pathjoin): self.last_target_parent[1] = target_parent_id source_exec = source_details[3] - changed = (content_change - or source_parent_id != target_parent_id - or old_basename != entry[0][1] - or source_exec != target_exec - ) + changed = ( + content_change + or source_parent_id != target_parent_id + or old_basename != entry[0][1] + or source_exec != target_exec + ) if not changed and not self.include_unchanged: return None, False else: if old_path is None: old_path = path = pathjoin(old_dirname, old_basename) - old_path_u = self.utf8_decode(old_path, 'surrogateescape')[0] + old_path_u = self.utf8_decode(old_path, "surrogateescape")[0] path_u = old_path_u else: - old_path_u = self.utf8_decode(old_path, 'surrogateescape')[0] + old_path_u = self.utf8_decode(old_path, "surrogateescape")[0] if old_path == path: path_u = old_path_u else: - path_u = self.utf8_decode(path, 'surrogateescape')[0] + path_u = self.utf8_decode(path, "surrogateescape")[0] source_kind = DirState._minikind_to_kind[source_minikind] return InventoryTreeChange( entry[0][2], @@ -3593,17 +3834,21 @@ def _process_entry(self, entry, path_info, pathjoin=osutils.pathjoin): content_change, (True, True), (source_parent_id, target_parent_id), - (self.utf8_decode(old_basename, 'surrogateescape')[ - 0], self.utf8_decode(entry[0][1], 'surrogateescape')[0]), + ( + self.utf8_decode(old_basename, "surrogateescape")[0], + self.utf8_decode(entry[0][1], "surrogateescape")[0], + ), (source_kind, target_kind), - (source_exec, target_exec)), changed - elif source_minikind in b'a' and target_minikind in _fdlt: + (source_exec, target_exec), + ), changed + elif source_minikind in b"a" and target_minikind in _fdlt: # looks like a new file path = pathjoin(entry[0][0], entry[0][1]) # parent id is the entry for the path in the target tree # TODO: these are the same for an entire directory: cache em. - parent_id = self.state._get_entry(self.target_index, - path_utf8=entry[0][0])[0][2] + parent_id = self.state._get_entry(self.target_index, path_utf8=entry[0][0])[ + 0 + ][2] if parent_id == entry[0][2]: parent_id = None if path_info is not None: @@ -3613,56 +3858,62 @@ def _process_entry(self, entry, path_info, pathjoin=osutils.pathjoin): # is a file or not. target_exec = bool( stat.S_ISREG(path_info[3].st_mode) - and stat.S_IEXEC & path_info[3].st_mode) + and stat.S_IEXEC & path_info[3].st_mode + ) else: target_exec = target_details[3] return InventoryTreeChange( entry[0][2], - (None, self.utf8_decode(path, 'surrogateescape')[0]), + (None, self.utf8_decode(path, "surrogateescape")[0]), True, (False, True), (None, parent_id), - (None, self.utf8_decode(entry[0][1], 'surrogateescape')[0]), + (None, self.utf8_decode(entry[0][1], "surrogateescape")[0]), (None, path_info[2]), - (None, target_exec)), True + (None, target_exec), + ), True else: # Its a missing file, report it as such. 
return InventoryTreeChange( entry[0][2], - (None, self.utf8_decode(path, 'surrogateescape')[0]), + (None, self.utf8_decode(path, "surrogateescape")[0]), False, (False, True), (None, parent_id), - (None, self.utf8_decode(entry[0][1], 'surrogateescape')[0]), + (None, self.utf8_decode(entry[0][1], "surrogateescape")[0]), (None, None), - (None, False)), True - elif source_minikind in _fdlt and target_minikind in b'a': + (None, False), + ), True + elif source_minikind in _fdlt and target_minikind in b"a": # unversioned, possibly, or possibly not deleted: we dont care. # if its still on disk, *and* theres no other entry at this # path [we dont know this in this routine at the moment - # perhaps we should change this - then it would be an unknown. old_path = pathjoin(entry[0][0], entry[0][1]) # parent id is the entry for the path in the target tree - parent_id = self.state._get_entry( - self.source_index, path_utf8=entry[0][0])[0][2] + parent_id = self.state._get_entry(self.source_index, path_utf8=entry[0][0])[ + 0 + ][2] if parent_id == entry[0][2]: parent_id = None return InventoryTreeChange( entry[0][2], - (self.utf8_decode(old_path, 'surrogateescape')[0], None), + (self.utf8_decode(old_path, "surrogateescape")[0], None), True, (True, False), (parent_id, None), - (self.utf8_decode(entry[0][1], 'surrogateescape')[0], None), + (self.utf8_decode(entry[0][1], "surrogateescape")[0], None), (DirState._minikind_to_kind[source_minikind], None), - (source_details[3], None)), True - elif source_minikind in _fdlt and target_minikind in b'r': + (source_details[3], None), + ), True + elif source_minikind in _fdlt and target_minikind in b"r": # a rename; could be a true rename, or a rename inherited from # a renamed parent. TODO: handle this efficiently. Its not # common case to rename dirs though, so a correct but slow # implementation will do. - if not osutils.is_inside_any(self.searched_specific_files, - target_details[1]): + if not osutils.is_inside_any( + self.searched_specific_files, target_details[1] + ): self.search_specific_files.add(target_details[1]) elif source_minikind in _ra and target_minikind in _ra: # neither of the selected trees contain this file, @@ -3670,8 +3921,10 @@ def _process_entry(self, entry, path_info, pathjoin=osutils.pathjoin): # is indirectly via test_too_much.TestCommands.test_conflicts. pass else: - raise AssertionError("don't know how to compare " - f"source_minikind={source_minikind!r}, target_minikind={target_minikind!r}") + raise AssertionError( + "don't know how to compare " + f"source_minikind={source_minikind!r}, target_minikind={target_minikind!r}" + ) return None, None def __iter__(self): @@ -3691,10 +3944,12 @@ def _gather_result_for_consistency(self, result): if new_path: # Not the root and not a delete: queue up the parents of the path. self.search_specific_file_parents.update( - p.encode('utf8', 'surrogateescape') for p in osutils.parent_directories(new_path)) + p.encode("utf8", "surrogateescape") + for p in osutils.parent_directories(new_path) + ) # Add the root directory which parent_directories does not # provide. - self.search_specific_file_parents.add(b'') + self.search_specific_file_parents.add(b"") def iter_changes(self): """Iterate over the changes.""" @@ -3746,7 +4001,7 @@ def iter_changes(self): # TODO: the pending list should be lexically sorted? the # interface doesn't require it. 
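# Descriptive note, inferred from the loop body below: one search root is
# popped per outer iteration, and everything beneath it is covered by walking
# the on-disk listing and the dirstate blocks in parallel. Any additional
# roots discovered along the way (for example rename and relocation targets)
# are pushed back onto search_specific_files so they get their own pass.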
current_root = search_specific_files.pop() - current_root_unicode = current_root.decode('utf8') + current_root_unicode = current_root.decode("utf8") searched_specific_files.add(current_root) # process the entries for this containing directory: the rest will be # found by their parents recursively. @@ -3758,15 +4013,20 @@ def iter_changes(self): # the path does not exist: let _process_entry know that. root_dir_info = None else: - root_dir_info = (b'', current_root, - osutils.file_kind_from_stat_mode( - root_stat.st_mode), root_stat, - root_abspath) - if root_dir_info[2] == 'directory': + root_dir_info = ( + b"", + current_root, + osutils.file_kind_from_stat_mode(root_stat.st_mode), + root_stat, + root_abspath, + ) + if root_dir_info[2] == "directory": if self.tree._directory_is_tree_reference( - current_root.decode('utf8')): - root_dir_info = root_dir_info[:2] + \ - ('tree-reference',) + root_dir_info[3:] + current_root.decode("utf8") + ): + root_dir_info = ( + root_dir_info[:2] + ("tree-reference",) + root_dir_info[3:] + ) if not root_entries and not root_dir_info: # this specified path is not present at all, skip it. @@ -3783,7 +4043,8 @@ def iter_changes(self): if self.want_unversioned and not path_handled and root_dir_info: new_executable = bool( stat.S_ISREG(root_dir_info[3].st_mode) - and stat.S_IEXEC & root_dir_info[3].st_mode) + and stat.S_IEXEC & root_dir_info[3].st_mode + ) yield InventoryTreeChange( None, (None, current_root_unicode), @@ -3792,43 +4053,43 @@ def iter_changes(self): (None, None), (None, splitpath(current_root_unicode)[-1]), (None, root_dir_info[2]), - (None, new_executable) - ) - initial_key = (current_root, b'', b'') + (None, new_executable), + ) + initial_key = (current_root, b"", b"") block_index, _ = self.state._find_block_index_from_key(initial_key) if block_index == 0: # we have processed the total root already, but because the # initial key matched it we should skip it here. block_index += 1 - if root_dir_info and root_dir_info[2] == 'tree-reference': + if root_dir_info and root_dir_info[2] == "tree-reference": current_dir_info = None else: - dir_iterator = osutils._walkdirs_utf8( - root_abspath, prefix=current_root) + dir_iterator = osutils._walkdirs_utf8(root_abspath, prefix=current_root) try: current_dir_info = next(dir_iterator) except (FileNotFoundError, NotADirectoryError, ValueError): current_dir_info = None else: - if current_dir_info[0][0] == b'': + if current_dir_info[0][0] == b"": # remove .bzr from iteration - bzr_index = bisect.bisect_left( - current_dir_info[1], (b'.bzr',)) - if current_dir_info[1][bzr_index][0] != b'.bzr': + bzr_index = bisect.bisect_left(current_dir_info[1], (b".bzr",)) + if current_dir_info[1][bzr_index][0] != b".bzr": raise AssertionError() del current_dir_info[1][bzr_index] # walk until both the directory listing and the versioned metadata # are exhausted. 
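# Descriptive note on the walk that follows: current_dir_info is fed by
# osutils._walkdirs_utf8 with what actually exists on disk, while
# current_block holds the versioned entries from self.state._dirblocks.
# The loop compares their directory names (via _lt_by_dirs) and advances
# whichever cursor is behind, so unversioned paths and entries missing from
# disk are each observed exactly once.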
- if (block_index < len(self.state._dirblocks) and - osutils.is_inside(current_root, - self.state._dirblocks[block_index][0])): + if block_index < len(self.state._dirblocks) and osutils.is_inside( + current_root, self.state._dirblocks[block_index][0] + ): current_block = self.state._dirblocks[block_index] else: current_block = None - while (current_dir_info is not None or - current_block is not None): - if (current_dir_info and current_block - and current_dir_info[0][0] != current_block[0]): + while current_dir_info is not None or current_block is not None: + if ( + current_dir_info + and current_block + and current_dir_info[0][0] != current_block[0] + ): if _lt_by_dirs(current_dir_info[0][0], current_block[0]): # filesystem data refers to paths not covered by the dirblock. # this has two possibilities: @@ -3843,28 +4104,42 @@ def iter_changes(self): while path_index < len(current_dir_info[1]): current_path_info = current_dir_info[1][path_index] if self.want_unversioned: - if current_path_info[2] == 'directory': + if current_path_info[2] == "directory": if self.tree._directory_is_tree_reference( - current_path_info[0].decode('utf8')): - current_path_info = current_path_info[:2] + \ - ('tree-reference',) + \ - current_path_info[3:] + current_path_info[0].decode("utf8") + ): + current_path_info = ( + current_path_info[:2] + + ("tree-reference",) + + current_path_info[3:] + ) new_executable = bool( stat.S_ISREG(current_path_info[3].st_mode) - and stat.S_IEXEC & current_path_info[3].st_mode) + and stat.S_IEXEC & current_path_info[3].st_mode + ) yield InventoryTreeChange( None, - (None, utf8_decode(current_path_info[0], 'surrogateescape')[0]), + ( + None, + utf8_decode( + current_path_info[0], "surrogateescape" + )[0], + ), True, (False, False), (None, None), - (None, utf8_decode(current_path_info[1], 'surrogateescape')[0]), + ( + None, + utf8_decode( + current_path_info[1], "surrogateescape" + )[0], + ), (None, current_path_info[2]), - (None, new_executable)) + (None, new_executable), + ) # dont descend into this unversioned path if it is # a dir - if current_path_info[2] in ('directory', - 'tree-reference'): + if current_path_info[2] in ("directory", "tree-reference"): del current_dir_info[1][path_index] path_index -= 1 path_index += 1 @@ -3885,17 +4160,18 @@ def iter_changes(self): for current_entry in current_block[1]: # entry referring to file not present on disk. # advance the entry only, after processing. 
- result, changed = _process_entry( - current_entry, None) + result, changed = _process_entry(current_entry, None) if changed is not None: if changed: self._gather_result_for_consistency(result) if changed or self.include_unchanged: yield result block_index += 1 - if (block_index < len(self.state._dirblocks) and - osutils.is_inside(current_root, - self.state._dirblocks[block_index][0])): + if block_index < len( + self.state._dirblocks + ) and osutils.is_inside( + current_root, self.state._dirblocks[block_index][0] + ): current_block = self.state._dirblocks[block_index] else: current_block = None @@ -3909,17 +4185,20 @@ def iter_changes(self): path_index = 0 if current_dir_info and path_index < len(current_dir_info[1]): current_path_info = current_dir_info[1][path_index] - if current_path_info[2] == 'directory': + if current_path_info[2] == "directory": if self.tree._directory_is_tree_reference( - current_path_info[0].decode('utf8')): - current_path_info = current_path_info[:2] + \ - ('tree-reference',) + current_path_info[3:] + current_path_info[0].decode("utf8") + ): + current_path_info = ( + current_path_info[:2] + + ("tree-reference",) + + current_path_info[3:] + ) else: current_path_info = None advance_path = True path_handled = False - while (current_entry is not None or - current_path_info is not None): + while current_entry is not None or current_path_info is not None: if current_entry is None: # the check for path_handled when the path is advanced # will yield this path if needed. @@ -3927,14 +4206,16 @@ def iter_changes(self): elif current_path_info is None: # no path is fine: the per entry code will handle it. result, changed = _process_entry( - current_entry, current_path_info) + current_entry, current_path_info + ) if changed is not None: if changed: self._gather_result_for_consistency(result) if changed or self.include_unchanged: yield result - elif (current_entry[0][1] != current_path_info[1] - or current_entry[1][self.target_index][0] in (b'a', b'r')): + elif current_entry[0][1] != current_path_info[1] or current_entry[ + 1 + ][self.target_index][0] in (b"a", b"r"): # The current path on disk doesn't match the dirblock # record. Either the dirblock is marked as absent, or # the file on disk is not present at all in the @@ -3950,8 +4231,7 @@ def iter_changes(self): else: # entry referring to file not present on disk. # advance the entry only, after processing. 
- result, changed = _process_entry( - current_entry, None) + result, changed = _process_entry(current_entry, None) if changed is not None: if changed: self._gather_result_for_consistency(result) @@ -3960,7 +4240,8 @@ def iter_changes(self): advance_path = False else: result, changed = _process_entry( - current_entry, current_path_info) + current_entry, current_path_info + ) if changed is not None: path_handled = True if changed: @@ -3981,37 +4262,48 @@ def iter_changes(self): if self.want_unversioned: new_executable = bool( stat.S_ISREG(current_path_info[3].st_mode) - and stat.S_IEXEC & current_path_info[3].st_mode) + and stat.S_IEXEC & current_path_info[3].st_mode + ) relpath_unicode = utf8_decode( - current_path_info[0], 'surrogateescape')[0] + current_path_info[0], "surrogateescape" + )[0] yield InventoryTreeChange( None, (None, relpath_unicode), True, (False, False), (None, None), - (None, utf8_decode(current_path_info[1], 'surrogateescape')[0]), + ( + None, + utf8_decode( + current_path_info[1], "surrogateescape" + )[0], + ), (None, current_path_info[2]), - (None, new_executable)) + (None, new_executable), + ) # dont descend into this unversioned path if it is # a dir - if current_path_info[2] in ('directory'): + if current_path_info[2] in ("directory"): del current_dir_info[1][path_index] path_index -= 1 # dont descend the disk iterator into any tree # paths. - if current_path_info[2] == 'tree-reference': + if current_path_info[2] == "tree-reference": del current_dir_info[1][path_index] path_index -= 1 path_index += 1 if path_index < len(current_dir_info[1]): current_path_info = current_dir_info[1][path_index] - if current_path_info[2] == 'directory': + if current_path_info[2] == "directory": if self.tree._directory_is_tree_reference( - current_path_info[0].decode('utf8')): - current_path_info = current_path_info[:2] + \ - ('tree-reference',) + \ - current_path_info[3:] + current_path_info[0].decode("utf8") + ): + current_path_info = ( + current_path_info[:2] + + ("tree-reference",) + + current_path_info[3:] + ) else: current_path_info = None path_handled = False @@ -4019,9 +4311,9 @@ def iter_changes(self): advance_path = True # reset the advance flagg. if current_block is not None: block_index += 1 - if (block_index < len(self.state._dirblocks) and - osutils.is_inside(current_root, - self.state._dirblocks[block_index][0])): + if block_index < len(self.state._dirblocks) and osutils.is_inside( + current_root, self.state._dirblocks[block_index][0] + ): current_block = self.state._dirblocks[block_index] else: current_block = None @@ -4055,26 +4347,30 @@ def _iter_specific_file_parents(self): found_item = False for candidate_entry in path_entries: # Find entries present in target at this path: - if candidate_entry[1][self.target_index][0] not in (b'a', b'r'): + if candidate_entry[1][self.target_index][0] not in (b"a", b"r"): found_item = True selected_entries.append(candidate_entry) # Find entries present in source at this path: - elif (self.source_index is not None and - candidate_entry[1][self.source_index][0] not in (b'a', b'r')): + elif self.source_index is not None and candidate_entry[1][ + self.source_index + ][0] not in (b"a", b"r"): found_item = True - if candidate_entry[1][self.target_index][0] == b'a': + if candidate_entry[1][self.target_index][0] == b"a": # Deleted, emit it here. selected_entries.append(candidate_entry) else: # renamed, emit it when we process the directory it # ended up at. 
self.search_specific_file_parents.add( - candidate_entry[1][self.target_index][1]) + candidate_entry[1][self.target_index][1] + ) if not found_item: raise AssertionError( "Missing entry for specific path parent {!r}, {!r}".format( - path_utf8, path_entries)) - path_info = self._path_info(path_utf8, path_utf8.decode('utf8')) + path_utf8, path_entries + ) + ) + path_info = self._path_info(path_utf8, path_utf8.decode("utf8")) for entry in selected_entries: if entry[0][2] in self.seen_ids: continue @@ -4082,43 +4378,44 @@ def _iter_specific_file_parents(self): if changed is None: raise AssertionError( "Got entry<->path mismatch for specific path " - f"{path_utf8!r} entry {entry!r} path_info {path_info!r} ") + f"{path_utf8!r} entry {entry!r} path_info {path_info!r} " + ) # Only include changes - we're outside the users requested # expansion. if changed: self._gather_result_for_consistency(result) - if (result.kind[0] == 'directory' and - result.kind[1] != 'directory'): + if result.kind[0] == "directory" and result.kind[1] != "directory": # This stopped being a directory, the old children have # to be included. - if entry[1][self.source_index][0] == b'r': + if entry[1][self.source_index][0] == b"r": # renamed, take the source path entry_path_utf8 = entry[1][self.source_index][1] else: entry_path_utf8 = path_utf8 - initial_key = (entry_path_utf8, b'', b'') + initial_key = (entry_path_utf8, b"", b"") block_index, _ = self.state._find_block_index_from_key( - initial_key) + initial_key + ) if block_index == 0: # The children of the root are in block index 1. block_index += 1 current_block = None if block_index < len(self.state._dirblocks): current_block = self.state._dirblocks[block_index] - if not osutils.is_inside( - entry_path_utf8, current_block[0]): + if not osutils.is_inside(entry_path_utf8, current_block[0]): # No entries for this directory at all. current_block = None if current_block is not None: for entry in current_block[1]: - if entry[1][self.source_index][0] in (b'a', b'r'): + if entry[1][self.source_index][0] in (b"a", b"r"): # Not in the source tree, so doesn't have to be # included. continue # Path of the entry itself. self.search_specific_file_parents.add( - osutils.pathjoin(*entry[0][:2])) + osutils.pathjoin(*entry[0][:2]) + ) if changed or self.include_unchanged: yield result self.searched_exact_paths.add(path_utf8) @@ -4134,15 +4431,21 @@ def _path_info(self, utf8_path, unicode_path): except FileNotFoundError: # the path does not exist. return None - utf8_basename = utf8_path.rsplit(b'/', 1)[-1] - dir_info = (utf8_path, utf8_basename, - osutils.file_kind_from_stat_mode(stat.st_mode), stat, - abspath) - if dir_info[2] == 'directory': - if self.tree._directory_is_tree_reference( - unicode_path): - self.root_dir_info = self.root_dir_info[:2] + \ - ('tree-reference',) + self.root_dir_info[3:] + utf8_basename = utf8_path.rsplit(b"/", 1)[-1] + dir_info = ( + utf8_path, + utf8_basename, + osutils.file_kind_from_stat_mode(stat.st_mode), + stat, + abspath, + ) + if dir_info[2] == "directory": + if self.tree._directory_is_tree_reference(unicode_path): + self.root_dir_info = ( + self.root_dir_info[:2] + + ("tree-reference",) + + self.root_dir_info[3:] + ) return dir_info diff --git a/breezy/bzr/fetch.py b/breezy/bzr/fetch.py index 2df7abd2b5..1d4d5e1a36 100644 --- a/breezy/bzr/fetch.py +++ b/breezy/bzr/fetch.py @@ -39,8 +39,14 @@ class RepoFetcher: the logic in InterRepository.fetch(). 
""" - def __init__(self, to_repository, from_repository, last_revision=None, - find_ghosts=True, fetch_spec=None): + def __init__( + self, + to_repository, + from_repository, + last_revision=None, + find_ghosts=True, + fetch_spec=None, + ): """Create a repo fetcher. Args: @@ -60,9 +66,13 @@ def __init__(self, to_repository, from_repository, last_revision=None, self._fetch_spec = fetch_spec self.find_ghosts = find_ghosts with self.from_repository.lock_read(): - mutter("Using fetch logic to copy between %s(%s) and %s(%s)", - str(self.from_repository), str(self.from_repository._format), - str(self.to_repository), str(self.to_repository._format)) + mutter( + "Using fetch logic to copy between %s(%s) and %s(%s)", + str(self.from_repository), + str(self.from_repository._format), + str(self.to_repository), + str(self.to_repository._format), + ) self.__fetch() def __fetch(self): @@ -83,7 +93,7 @@ def __fetch(self): pb.show_pct = pb.show_count = False pb.update(gettext("Finding revisions"), 0, 2) search_result = self._revids_to_fetch() - mutter('fetching: %s', str(search_result)) + mutter("fetching: %s", str(search_result)) if search_result.is_empty(): return pb.update(gettext("Fetching revisions"), 1, 2) @@ -99,32 +109,37 @@ def _fetch_everything_for_search(self, search): # item_keys_introduced_by should have a richer API than it does at the # moment, so that it can feed the progress information back to this # function? - if (self.from_repository._format.rich_root_data and - not self.to_repository._format.rich_root_data): + if ( + self.from_repository._format.rich_root_data + and not self.to_repository._format.rich_root_data + ): raise errors.IncompatibleRepositories( - self.from_repository, self.to_repository, - "different rich-root support") + self.from_repository, self.to_repository, "different rich-root support" + ) with ui.ui_factory.nested_progress_bar() as pb: pb.update("Get stream source") - source = self.from_repository._get_source( - self.to_repository._format) + source = self.from_repository._get_source(self.to_repository._format) stream = source.get_stream(search) from_format = self.from_repository._format pb.update("Inserting stream") resume_tokens, missing_keys = self.sink.insert_stream( - stream, from_format, []) + stream, from_format, [] + ) if missing_keys: pb.update("Missing keys") stream = source.get_stream_for_missing_keys(missing_keys) pb.update("Inserting missing keys") resume_tokens, missing_keys = self.sink.insert_stream( - stream, from_format, resume_tokens) + stream, from_format, resume_tokens + ) if missing_keys: raise AssertionError( - f"second push failed to complete a fetch {missing_keys!r}.") + f"second push failed to complete a fetch {missing_keys!r}." + ) if resume_tokens: raise AssertionError( - f"second push failed to commit the fetch {resume_tokens!r}.") + f"second push failed to commit the fetch {resume_tokens!r}." + ) pb.update("Finishing stream") self.sink.finished() @@ -137,6 +152,7 @@ def _revids_to_fetch(self): PendingAncestryResult, EmptySearchResult, etc.) """ from . import vf_search + if self._fetch_spec is not None: # The fetch spec is already a concrete search result. 
return self._fetch_spec @@ -145,14 +161,16 @@ def _revids_to_fetch(self): # explicit limit of no revisions needed return vf_search.EmptySearchResult() elif self._last_revision is not None: - return vf_search.NotInOtherForRevs(self.to_repository, - self.from_repository, [ - self._last_revision], - find_ghosts=self.find_ghosts).execute() + return vf_search.NotInOtherForRevs( + self.to_repository, + self.from_repository, + [self._last_revision], + find_ghosts=self.find_ghosts, + ).execute() else: # self._last_revision is None: - return vf_search.EverythingNotInOther(self.to_repository, - self.from_repository, - find_ghosts=self.find_ghosts).execute() + return vf_search.EverythingNotInOther( + self.to_repository, self.from_repository, find_ghosts=self.find_ghosts + ).execute() class Inter1and2Helper: @@ -195,8 +213,8 @@ def iter_rev_trees(self, revs): def _find_root_ids(self, revs, parent_map, graph): revision_root = {} for tree in self.iter_rev_trees(revs): - root_id = tree.path2id('') - revision_id = tree.get_file_revision('') + root_id = tree.path2id("") + revision_id = tree.get_file_revision("") revision_root[revision_id] = root_id # Find out which parents we don't already know root ids for parents = set(parent_map.values()) @@ -205,7 +223,7 @@ def _find_root_ids(self, revs, parent_map, graph): # Limit to revisions present in the versionedfile parents = graph.get_parent_map(parents) for tree in self.iter_rev_trees(parents): - root_id = tree.path2id('') + root_id = tree.path2id("") revision_root[tree.get_revision_id()] = root_id return revision_root @@ -216,12 +234,12 @@ def generate_root_texts(self, revs): revs: the revisions to include """ from ..tsort import topo_sort + graph = self.source.get_graph() parent_map = graph.get_parent_map(revs) rev_order = topo_sort(parent_map) rev_id_to_root_id = self._find_root_ids(revs, parent_map, graph) - root_id_order = [(rev_id_to_root_id[rev_id], rev_id) for rev_id in - rev_order] + root_id_order = [(rev_id_to_root_id[rev_id], rev_id) for rev_id in rev_order] # Guaranteed stable, this groups all the file id operations together # retaining topological order within the revisions of a file id. # File id splits and joins would invalidate this, but they don't exist @@ -231,12 +249,14 @@ def generate_root_texts(self, revs): if len(revs) > self.known_graph_threshold: graph = self.source.get_known_graph_ancestry(revs) new_roots_stream = _new_root_data_stream( - root_id_order, rev_id_to_root_id, parent_map, self.source, graph) - return [('texts', new_roots_stream)] + root_id_order, rev_id_to_root_id, parent_map, self.source, graph + ) + return [("texts", new_roots_stream)] def _new_root_data_stream( - root_keys_to_create, rev_id_to_root_id_map, parent_map, repo, graph=None): + root_keys_to_create, rev_id_to_root_id_map, parent_map, repo, graph=None +): """Generate a texts substream of synthesised root entries. Used in fetches that do rich-root upgrades. @@ -253,15 +273,18 @@ def _new_root_data_stream( graph: a graph to use instead of repo.get_graph(). 
""" from .versionedfile import ChunkedContentFactory + for root_key in root_keys_to_create: root_id, rev_id = root_key parent_keys = _parent_keys_for_root_version( - root_id, rev_id, rev_id_to_root_id_map, parent_map, repo, graph) + root_id, rev_id, rev_id_to_root_id_map, parent_map, repo, graph + ) yield ChunkedContentFactory(root_key, parent_keys, None, []) def _parent_keys_for_root_version( - root_id, rev_id, rev_id_to_root_id_map, parent_map, repo, graph=None): + root_id, rev_id, rev_id_to_root_id_map, parent_map, repo, graph=None +): """Get the parent keys for a given root id. A helper function for _new_root_data_stream. @@ -284,7 +307,7 @@ def _parent_keys_for_root_version( # But set parent_root_id to None since we don't really know parent_root_id = None else: - parent_root_id = tree.path2id('') + parent_root_id = tree.path2id("") rev_id_to_root_id_map[parent_id] = None # XXX: why not: # rev_id_to_root_id_map[parent_id] = parent_root_id @@ -307,8 +330,8 @@ def _parent_keys_for_root_version( else: try: parent_ids.append( - tree.get_file_revision( - tree.id2path(root_id, recurse='none'))) + tree.get_file_revision(tree.id2path(root_id, recurse="none")) + ) except errors.NoSuchId: # not in the tree pass @@ -330,9 +353,9 @@ class TargetRepoKinds: They are the possible values of FetchSpecFactory.target_repo_kinds. """ - PREEXISTING = 'preexisting' - STACKED = 'stacked' - EMPTY = 'empty' + PREEXISTING = "preexisting" + STACKED = "stacked" + EMPTY = "empty" class FetchSpecFactory: @@ -372,13 +395,14 @@ def add_revision_ids(self, revision_ids): def make_fetch_spec(self): """Build a SearchResult or PendingAncestryResult or etc.""" from . import vf_search + if self.target_repo_kind is None or self.source_repo is None: - raise AssertionError( - f'Incomplete FetchSpecFactory: {self.__dict__!r}') + raise AssertionError(f"Incomplete FetchSpecFactory: {self.__dict__!r}") if len(self._explicit_rev_ids) == 0 and self.source_branch is None: if self.limit is not None: raise NotImplementedError( - "limit is only supported with a source branch set") + "limit is only supported with a source branch set" + ) # Caller hasn't specified any revisions or source branch if self.target_repo_kind == TargetRepoKinds.EMPTY: return vf_search.EverythingResult(self.source_repo) @@ -386,7 +410,8 @@ def make_fetch_spec(self): # We want everything not already in the target (or target's # fallbacks). 
return vf_search.EverythingNotInOther( - self.target_repo, self.source_repo).execute() + self.target_repo, self.source_repo + ).execute() heads_to_fetch = set(self._explicit_rev_ids) if self.source_branch is not None: must_fetch, if_present_fetch = self.source_branch.heads_to_fetch() @@ -412,11 +437,14 @@ def make_fetch_spec(self): if self.limit is not None: graph = self.source_repo.get_graph() topo_order = list(graph.iter_topo_order(ret.get_keys())) - result_set = topo_order[:self.limit] - ret = self.source_repo.revision_ids_to_search_result( - result_set) + result_set = topo_order[: self.limit] + ret = self.source_repo.revision_ids_to_search_result(result_set) return ret else: - return vf_search.NotInOtherForRevs(self.target_repo, self.source_repo, - required_ids=heads_to_fetch, if_present_ids=if_present_fetch, - limit=self.limit).execute() + return vf_search.NotInOtherForRevs( + self.target_repo, + self.source_repo, + required_ids=heads_to_fetch, + if_present_ids=if_present_fetch, + limit=self.limit, + ).execute() diff --git a/breezy/bzr/fullhistory.py b/breezy/bzr/fullhistory.py index 0884a014b7..dfa09b1fe8 100644 --- a/breezy/bzr/fullhistory.py +++ b/breezy/bzr/fullhistory.py @@ -28,15 +28,14 @@ class FullHistoryBzrBranch(BzrBranch): def set_last_revision_info(self, revno, revision_id): if not revision_id or not isinstance(revision_id, bytes): - raise errors.InvalidRevisionId( - revision_id=revision_id, branch=self) + raise errors.InvalidRevisionId(revision_id=revision_id, branch=self) with self.lock_write(): # this old format stores the full history, but this api doesn't # provide it, so we must generate, and might as well check it's # correct history = self._lefthand_history(revision_id) if len(history) != revno: - raise AssertionError('%d != %d' % (len(history), revno)) + raise AssertionError("%d != %d" % (len(history), revno)) self._set_revision_history(history) def _read_last_revision_info(self): @@ -48,12 +47,12 @@ def _read_last_revision_info(self): return (0, _mod_revision.NULL_REVISION) def _set_revision_history(self, rev_history): - if debug.debug_flag_enabled('evil'): + if debug.debug_flag_enabled("evil"): mutter_callsite(3, "set_revision_history scales with history.") check_not_reserved_id = _mod_revision.check_not_reserved_id for rev_id in rev_history: check_not_reserved_id(rev_id) - if Branch.hooks['post_change_branch_tip']: + if Branch.hooks["post_change_branch_tip"]: # Don't calculate the last_revision_info() if there are no hooks # that will use it. old_revno, old_revid = self.last_revision_info() @@ -65,7 +64,7 @@ def _set_revision_history(self, rev_history): self._write_revision_history(rev_history) self._clear_cached_state() self._cache_revision_history(rev_history) - if Branch.hooks['post_change_branch_tip']: + if Branch.hooks["post_change_branch_tip"]: self._run_post_change_branch_tip_hooks(old_revno, old_revid) def _write_revision_history(self, history): @@ -75,20 +74,21 @@ def _write_revision_history(self, history): It is intended to be called by set_revision_history. """ self._transport.put_bytes( - 'revision-history', b'\n'.join(history), - mode=self.controldir._get_file_mode()) + "revision-history", + b"\n".join(history), + mode=self.controldir._get_file_mode(), + ) def _gen_revision_history(self): - history = self._transport.get_bytes('revision-history').split(b'\n') - if history[-1:] == [b'']: + history = self._transport.get_bytes("revision-history").split(b"\n") + if history[-1:] == [b""]: # There shouldn't be a trailing newline, but just in case. 
history.pop() return history def _synchronize_history(self, destination, revision_id): if not isinstance(destination, FullHistoryBzrBranch): - super(BzrBranch, self)._synchronize_history( - destination, revision_id) + super(BzrBranch, self)._synchronize_history(destination, revision_id) return if revision_id == _mod_revision.NULL_REVISION: new_history = [] @@ -96,14 +96,13 @@ def _synchronize_history(self, destination, revision_id): new_history = self._revision_history() if revision_id is not None and new_history != []: try: - new_history = new_history[:new_history.index(revision_id) + 1] + new_history = new_history[: new_history.index(revision_id) + 1] except ValueError: rev = self.repository.get_revision(revision_id) new_history = _mod_revision.get_history(self.repository, rev)[1:] destination._set_revision_history(new_history) - def generate_revision_history(self, revision_id, last_rev=None, - other_branch=None): + def generate_revision_history(self, revision_id, last_rev=None, other_branch=None): """Create a new revision history that will finish with revision_id. :param revision_id: the new tip to use. @@ -113,8 +112,9 @@ def generate_revision_history(self, revision_id, last_rev=None, raise with respect to. """ with self.lock_write(): - self._set_revision_history(self._lefthand_history(revision_id, - last_rev, other_branch)) + self._set_revision_history( + self._lefthand_history(revision_id, last_rev, other_branch) + ) class BzrBranch5(FullHistoryBzrBranch): @@ -149,14 +149,16 @@ def get_format_description(self): """See BranchFormat.get_format_description().""" return "Branch format 5" - def initialize(self, a_controldir, name=None, repository=None, - append_revisions_only=None): + def initialize( + self, a_controldir, name=None, repository=None, append_revisions_only=None + ): """Create a branch of this format in a_controldir.""" if append_revisions_only: raise errors.UpgradeRequired(a_controldir.user_url) - utf8_files = [('revision-history', b''), - ('branch-name', b''), - ] + utf8_files = [ + ("revision-history", b""), + ("branch-name", b""), + ] return self._initialize_helper(a_controldir, utf8_files, name, repository) def supports_tags(self): diff --git a/breezy/bzr/groupcompress.py b/breezy/bzr/groupcompress.py index bdca924c06..49be5c2d9c 100644 --- a/breezy/bzr/groupcompress.py +++ b/breezy/bzr/groupcompress.py @@ -25,13 +25,16 @@ from ..lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ from breezy.bzr import ( knit, static_tuple, ) -""") +""", +) from .. 
import errors, osutils, trace from ..lru_cache import LRUSizeCache @@ -51,7 +54,7 @@ BATCH_SIZE = 2**16 # osutils.sha_string(b'') -_null_sha1 = b'da39a3ee5e6b4b0d3255bfef95601890afd80709' +_null_sha1 = b"da39a3ee5e6b4b0d3255bfef95601890afd80709" def sort_gc_optimal(parent_map): @@ -69,7 +72,7 @@ def sort_gc_optimal(parent_map): per_prefix_map = {} for key, value in parent_map.items(): if isinstance(key, bytes) or len(key) == 1: - prefix = b'' + prefix = b"" else: prefix = key[0] try: @@ -84,7 +87,6 @@ def sort_gc_optimal(parent_map): class DecompressCorruption(errors.BzrError): - _fmt = "Corruption while decompressing repository file%(orig_error)s" def __init__(self, orig_error=None): @@ -108,9 +110,9 @@ class GroupCompressBlock: """ # Group Compress Block v1 Zlib - GCB_HEADER = b'gcb1z\n' + GCB_HEADER = b"gcb1z\n" # Group Compress Block v1 Lzma - GCB_LZ_HEADER = b'gcb1l\n' + GCB_LZ_HEADER = b"gcb1l\n" GCB_KNOWN_HEADERS = (GCB_HEADER, GCB_LZ_HEADER) def __init__(self): @@ -136,33 +138,34 @@ def _ensure_content(self, num_bytes=None): content. If None, consume everything """ if self._content_length is None: - raise AssertionError('self._content_length should never be None') + raise AssertionError("self._content_length should never be None") if num_bytes is None: num_bytes = self._content_length - elif (self._content_length is not None - and num_bytes > self._content_length): + elif self._content_length is not None and num_bytes > self._content_length: raise AssertionError( - 'requested num_bytes (%d) > content length (%d)' - % (num_bytes, self._content_length)) + "requested num_bytes (%d) > content length (%d)" + % (num_bytes, self._content_length) + ) # Expand the content if required if self._content is None: if self._content_chunks is not None: - self._content = b''.join(self._content_chunks) + self._content = b"".join(self._content_chunks) self._content_chunks = None if self._content is None: # We join self._z_content_chunks here, because if we are # decompressing, then it is *very* likely that we have a single # chunk if self._z_content_chunks is None: - raise AssertionError('No content to decompress') - z_content = b''.join(self._z_content_chunks) - if z_content == b'': - self._content = b'' - elif self._compressor_name == 'lzma': + raise AssertionError("No content to decompress") + z_content = b"".join(self._z_content_chunks) + if z_content == b"": + self._content = b"" + elif self._compressor_name == "lzma": # We don't do partial lzma decomp yet import pylzma + self._content = pylzma.decompress(z_content) - elif self._compressor_name == 'zlib': + elif self._compressor_name == "zlib": # Start a zlib decompressor if num_bytes * 4 > self._content_length * 3: # If we are requesting more that 3/4ths of the content, @@ -174,11 +177,12 @@ def _ensure_content(self, num_bytes=None): # Seed the decompressor with the uncompressed bytes, so # that the rest of the code is simplified self._content = self._z_content_decompressor.decompress( - z_content, num_bytes + _ZLIB_DECOMP_WINDOW) + z_content, num_bytes + _ZLIB_DECOMP_WINDOW + ) if not self._z_content_decompressor.unconsumed_tail: self._z_content_decompressor = None else: - raise AssertionError(f'Unknown compressor: {self._compressor_name!r}') + raise AssertionError(f"Unknown compressor: {self._compressor_name!r}") # Any bytes remaining to be decompressed will be in the decompressors # 'unconsumed_tail' @@ -187,21 +191,22 @@ def _ensure_content(self, num_bytes=None): return # If we got this far, and don't have a decompressor, something is wrong 
if self._z_content_decompressor is None: - raise AssertionError( - 'No decompressor to decompress %d bytes' % num_bytes) + raise AssertionError("No decompressor to decompress %d bytes" % num_bytes) remaining_decomp = self._z_content_decompressor.unconsumed_tail if not remaining_decomp: - raise AssertionError('Nothing left to decompress') + raise AssertionError("Nothing left to decompress") needed_bytes = num_bytes - len(self._content) # We always set max_size to 32kB over the minimum needed, so that # zlib will give us as much as we really want. # TODO: If this isn't good enough, we could make a loop here, # that keeps expanding the request until we get enough self._content += self._z_content_decompressor.decompress( - remaining_decomp, needed_bytes + _ZLIB_DECOMP_WINDOW) + remaining_decomp, needed_bytes + _ZLIB_DECOMP_WINDOW + ) if len(self._content) < num_bytes: - raise AssertionError('%d bytes wanted, only %d available' - % (num_bytes, len(self._content))) + raise AssertionError( + "%d bytes wanted, only %d available" % (num_bytes, len(self._content)) + ) if not self._z_content_decompressor.unconsumed_tail: # The stream is finished self._z_content_decompressor = None @@ -216,16 +221,18 @@ def _parse_bytes(self, data, pos): # At present, we have 2 integers for the compressed and uncompressed # content. In base10 (ascii) 14 bytes can represent > 1TB, so to avoid # checking too far, cap the search to 14 bytes. - pos2 = data.index(b'\n', pos, pos + 14) + pos2 = data.index(b"\n", pos, pos + 14) self._z_content_length = int(data[pos:pos2]) pos = pos2 + 1 - pos2 = data.index(b'\n', pos, pos + 14) + pos2 = data.index(b"\n", pos, pos + 14) self._content_length = int(data[pos:pos2]) pos = pos2 + 1 if len(data) != (pos + self._z_content_length): # XXX: Define some GCCorrupt error ? - raise AssertionError('Invalid bytes: (%d) != %d + %d' % - (len(data), pos, self._z_content_length)) + raise AssertionError( + "Invalid bytes: (%d) != %d + %d" + % (len(data), pos, self._z_content_length) + ) self._z_content_chunks = (data[pos:],) @property @@ -235,7 +242,7 @@ def _z_content(self): Meant only to be used by the test suite. 
""" if self._z_content_chunks is not None: - return b''.join(self._z_content_chunks) + return b"".join(self._z_content_chunks) return None @classmethod @@ -243,13 +250,15 @@ def from_bytes(cls, bytes): out = cls() header = bytes[:6] if header not in cls.GCB_KNOWN_HEADERS: - raise ValueError(f'bytes did not start with any of {cls.GCB_KNOWN_HEADERS!r}') + raise ValueError( + f"bytes did not start with any of {cls.GCB_KNOWN_HEADERS!r}" + ) if header == cls.GCB_HEADER: - out._compressor_name = 'zlib' + out._compressor_name = "zlib" elif header == cls.GCB_LZ_HEADER: - out._compressor_name = 'lzma' + out._compressor_name = "lzma" else: - raise ValueError(f'unknown compressor: {header!r}') + raise ValueError(f"unknown compressor: {header!r}") out._parse_bytes(bytes, 6) return out @@ -267,19 +276,20 @@ def extract(self, key, start, end, sha1=None): # base128 integer for the content size, then the actual content # We know that the variable-length integer won't be longer than 5 # bytes (it takes 5 bytes to encode 2^32) - c = self._content[start:start + 1] - if c == b'f': + c = self._content[start : start + 1] + if c == b"f": pass else: - if c != b'd': - raise ValueError(f'Unknown content control code: {c}') - content_len, len_len = decode_base128_int( - self._content[start + 1:start + 6]) + if c != b"d": + raise ValueError(f"Unknown content control code: {c}") + content_len, len_len = decode_base128_int(self._content[start + 1 : start + 6]) content_start = start + 1 + len_len if end != content_start + content_len: - raise ValueError('end != len according to field header' - f' {end} != {content_start + content_len}') - if c == b'f': + raise ValueError( + "end != len according to field header" + f" {end} != {content_start + content_len}" + ) + if c == b"f": return [self._content[content_start:end]] # Must be type delta as checked above return [apply_delta_to_source(self._content, content_start, end)] @@ -325,9 +335,9 @@ def to_chunks(self): """Create the byte stream as a series of 'chunks'.""" self._create_z_content() header = self.GCB_HEADER - chunks = [b'%s%d\n%d\n' - % (header, self._z_content_length, self._content_length), - ] + chunks = [ + b"%s%d\n%d\n" % (header, self._z_content_length, self._content_length), + ] chunks.extend(self._z_content_chunks) total_len = sum(map(len, chunks)) return total_len, chunks @@ -335,7 +345,7 @@ def to_chunks(self): def to_bytes(self): """Encode the information into a byte stream.""" total_len, chunks = self.to_chunks() - return b''.join(chunks) + return b"".join(chunks) def _dump(self, include_text=False): """Take this block, and spit out a human-readable structure. 
@@ -351,57 +361,61 @@ def _dump(self, include_text=False): result = [] pos = 0 while pos < self._content_length: - kind = self._content[pos:pos + 1] + kind = self._content[pos : pos + 1] pos += 1 - if kind not in (b'f', b'd'): - raise ValueError(f'invalid kind character: {kind!r}') - content_len, len_len = decode_base128_int( - self._content[pos:pos + 5]) + if kind not in (b"f", b"d"): + raise ValueError(f"invalid kind character: {kind!r}") + content_len, len_len = decode_base128_int(self._content[pos : pos + 5]) pos += len_len if content_len + pos > self._content_length: - raise ValueError('invalid content_len %d for record @ pos %d' - % (content_len, pos - len_len - 1)) - if kind == b'f': # Fulltext + raise ValueError( + "invalid content_len %d for record @ pos %d" + % (content_len, pos - len_len - 1) + ) + if kind == b"f": # Fulltext if include_text: - text = self._content[pos:pos + content_len] - result.append((b'f', content_len, text)) + text = self._content[pos : pos + content_len] + result.append((b"f", content_len, text)) else: - result.append((b'f', content_len)) - elif kind == b'd': # Delta - delta_content = self._content[pos:pos + content_len] + result.append((b"f", content_len)) + elif kind == b"d": # Delta + delta_content = self._content[pos : pos + content_len] delta_info = [] # The first entry in a delta is the decompressed length decomp_len, delta_pos = decode_base128_int(delta_content) - result.append((b'd', content_len, decomp_len, delta_info)) + result.append((b"d", content_len, decomp_len, delta_info)) measured_len = 0 while delta_pos < content_len: c = delta_content[delta_pos] delta_pos += 1 if c & 0x80: # Copy - (offset, length, - delta_pos) = decode_copy_instruction(delta_content, c, - delta_pos) + (offset, length, delta_pos) = decode_copy_instruction( + delta_content, c, delta_pos + ) if include_text: - text = self._content[offset:offset + length] - delta_info.append((b'c', offset, length, text)) + text = self._content[offset : offset + length] + delta_info.append((b"c", offset, length, text)) else: - delta_info.append((b'c', offset, length)) + delta_info.append((b"c", offset, length)) measured_len += length else: # Insert if include_text: - txt = delta_content[delta_pos:delta_pos + c] + txt = delta_content[delta_pos : delta_pos + c] else: - txt = b'' - delta_info.append((b'i', c, txt)) + txt = b"" + delta_info.append((b"i", c, txt)) measured_len += c delta_pos += c if delta_pos != content_len: - raise ValueError('Delta consumed a bad number of bytes:' - ' %d != %d' % (delta_pos, content_len)) + raise ValueError( + "Delta consumed a bad number of bytes:" + " %d != %d" % (delta_pos, content_len) + ) if measured_len != decomp_len: - raise ValueError('Delta claimed fulltext was %d bytes, but' - ' extraction resulted in %d bytes' - % (decomp_len, measured_len)) + raise ValueError( + "Delta claimed fulltext was %d bytes, but" + " extraction resulted in %d bytes" % (decomp_len, measured_len) + ) pos += content_len return result @@ -432,15 +446,15 @@ def __init__(self, key, parents, manager, start, end, first): # the object? 
self._manager = manager self._chunks = None - self.storage_kind = 'groupcompress-block' + self.storage_kind = "groupcompress-block" if not first: - self.storage_kind = 'groupcompress-block-ref' + self.storage_kind = "groupcompress-block-ref" self._first = first self._start = start self._end = end def __repr__(self): - return f'{self.__class__.__name__}({self.key}, first={self._first})' + return f"{self.__class__.__name__}({self.key}, first={self._first})" def _extract_bytes(self): # Grab and cache the raw bytes for this entry @@ -462,28 +476,26 @@ def get_bytes_as(self, storage_kind): # wire bytes, something... return self._manager._wire_bytes() else: - return b'' - if storage_kind in ('fulltext', 'chunked', 'lines'): + return b"" + if storage_kind in ("fulltext", "chunked", "lines"): if self._chunks is None: self._extract_bytes() - if storage_kind == 'fulltext': - return b''.join(self._chunks) - elif storage_kind == 'chunked': + if storage_kind == "fulltext": + return b"".join(self._chunks) + elif storage_kind == "chunked": return self._chunks else: return osutils.chunks_to_lines(self._chunks) - raise UnavailableRepresentation(self.key, storage_kind, - self.storage_kind) + raise UnavailableRepresentation(self.key, storage_kind, self.storage_kind) def iter_bytes_as(self, storage_kind): if self._chunks is None: self._extract_bytes() - if storage_kind == 'chunked': + if storage_kind == "chunked": return iter(self._chunks) - elif storage_kind == 'lines': + elif storage_kind == "lines": return osutils.chunks_to_lines_iter(iter(self._chunks)) - raise UnavailableRepresentation(self.key, storage_kind, - self.storage_kind) + raise UnavailableRepresentation(self.key, storage_kind, self.storage_kind) class _LazyGroupContentManager: @@ -523,8 +535,7 @@ def add_factory(self, key, parents, start, end): else: first = False # Note that this creates a reference cycle.... - factory = _LazyGroupCompressFactory(key, parents, self, - start, end, first=first) + factory = _LazyGroupCompressFactory(key, parents, self, start, end, first=first) # max() works here, but as a function call, doing a compare seems to be # significantly faster, timeit says 250ms for max() and 100ms for the # comparison @@ -547,8 +558,11 @@ def _trim_block(self, last_byte): # None of the factories need to be adjusted, because the content is # located in an identical place. 
Just that some of the unreferenced # trailing bytes are stripped - trace.mutter('stripping trailing bytes from groupcompress block' - ' %d => %d', self._block._content_length, last_byte) + trace.mutter( + "stripping trailing bytes from groupcompress block" " %d => %d", + self._block._content_length, + last_byte, + ) new_block = GroupCompressBlock() self._block._ensure_content(last_byte) new_block.set_content(self._block._content[:last_byte]) @@ -564,13 +578,13 @@ def _rebuild_block(self): old_length = self._block._content_length end_point = 0 for factory in self._factories: - chunks = factory.get_bytes_as('chunked') + chunks = factory.get_bytes_as("chunked") chunks_len = factory.size if chunks_len is None: chunks_len = sum(map(len, chunks)) - (found_sha1, start_point, end_point, - type) = compressor.compress( - factory.key, chunks, chunks_len, factory.sha1) + (found_sha1, start_point, end_point, type) = compressor.compress( + factory.key, chunks, chunks_len, factory.sha1 + ) # Now update this factory with the new offsets, etc factory.sha1 = found_sha1 factory._start = start_point @@ -588,9 +602,12 @@ def _rebuild_block(self): # to a very low value, causing poor compression. delta = time.time() - tstart self._block = new_block - trace.mutter('creating new compressed block on-the-fly in %.3fs' - ' %d bytes => %d bytes', delta, old_length, - self._block._content_length) + trace.mutter( + "creating new compressed block on-the-fly in %.3fs" " %d bytes => %d bytes", + delta, + old_length, + self._block._content_length, + ) def _prepare_for_extract(self): """A _LazyGroupCompressFactory is about to extract to fulltext.""" @@ -616,7 +633,7 @@ def _check_rebuild_action(self): # using at the beginning of the block? If so, we can just trim the # tail, rather than rebuilding from scratch. if total_bytes_used * 2 > last_byte_used: - return 'trim', last_byte_used, total_bytes_used + return "trim", last_byte_used, total_bytes_used # We are using a small amount of the data, and it isn't just packed # nicely at the front, so rebuild the content. @@ -629,7 +646,7 @@ def _check_rebuild_action(self): # expanding many deltas into fulltexts, as well. # If we build a cheap enough 'strip', then we could try a strip, # if that expands the content, we then rebuild. - return 'rebuild', last_byte_used, total_bytes_used + return "rebuild", last_byte_used, total_bytes_used def check_is_well_utilized(self): """Is the current block considered 'well utilized'? @@ -697,12 +714,12 @@ def _check_rebuild_block(self): action, last_byte_used, total_bytes_used = self._check_rebuild_action() if action is None: return - if action == 'trim': + if action == "trim": self._trim_block(last_byte_used) - elif action == 'rebuild': + elif action == "rebuild": self._rebuild_block() else: - raise ValueError(f'unknown rebuild action: {action!r}') + raise ValueError(f"unknown rebuild action: {action!r}") def _wire_bytes(self): """Return a byte stream suitable for transmitting over the wire.""" @@ -714,7 +731,7 @@ def _wire_bytes(self): # \n #
# - lines = [b'groupcompress-block\n'] + lines = [b"groupcompress-block\n"] # The minimal info we need is the key, the start offset, and the # parents. The length and type are encoded in the record itself. # However, passing in the other bits makes it easier. The list of @@ -725,80 +742,89 @@ def _wire_bytes(self): # 1 line for end byte header_lines = [] for factory in self._factories: - key_bytes = b'\x00'.join(factory.key) + key_bytes = b"\x00".join(factory.key) parents = factory.parents if parents is None: - parent_bytes = b'None:' + parent_bytes = b"None:" else: - parent_bytes = b'\t'.join(b'\x00'.join(key) for key in parents) - record_header = b'%s\n%s\n%d\n%d\n' % ( - key_bytes, parent_bytes, factory._start, factory._end) + parent_bytes = b"\t".join(b"\x00".join(key) for key in parents) + record_header = b"%s\n%s\n%d\n%d\n" % ( + key_bytes, + parent_bytes, + factory._start, + factory._end, + ) header_lines.append(record_header) # TODO: Can we break the refcycle at this point and set # factory._manager = None? - header_bytes = b''.join(header_lines) + header_bytes = b"".join(header_lines) del header_lines header_bytes_len = len(header_bytes) z_header_bytes = zlib.compress(header_bytes) del header_bytes z_header_bytes_len = len(z_header_bytes) block_bytes_len, block_chunks = self._block.to_chunks() - lines.append(b'%d\n%d\n%d\n' % ( - z_header_bytes_len, header_bytes_len, block_bytes_len)) + lines.append( + b"%d\n%d\n%d\n" % (z_header_bytes_len, header_bytes_len, block_bytes_len) + ) lines.append(z_header_bytes) lines.extend(block_chunks) del z_header_bytes, block_chunks # TODO: This is a point where we will double the memory consumption. To # avoid this, we probably have to switch to a 'chunked' api - return b''.join(lines) + return b"".join(lines) @classmethod def from_bytes(cls, bytes): # TODO: This does extra string copying, probably better to do it a # different way. 
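# Illustrative aside, not part of the patch: the wire layout built by
# _wire_bytes() above is a storage-kind line, three decimal length lines, a
# zlib-compressed header of "key / parents / start / end" records, then the
# raw block chunks.  Everything below is hand-rolled sample data, not Breezy API.
import zlib

records = [(b"file-id\x00rev-1", b"None:", 0, 42)]
header = b"".join(b"%s\n%s\n%d\n%d\n" % rec for rec in records)
z_header = zlib.compress(header)
block = b"<block chunks would go here>"
wire = b"".join([
    b"groupcompress-block\n",
    b"%d\n%d\n%d\n" % (len(z_header), len(header), len(block)),
    z_header,
    block,
])

# Parsing mirrors from_bytes(): peel off the four text lines, then slice.
kind, z_len, h_len, b_len, rest = wire.split(b"\n", 4)
assert kind == b"groupcompress-block"
assert zlib.decompress(rest[:int(z_len)]) == header
assert rest[int(z_len):] == block
assert int(h_len) == len(header) and int(b_len) == len(block)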
At a minimum this creates 2 copies of the # compressed content - (storage_kind, z_header_len, header_len, - block_len, rest) = bytes.split(b'\n', 4) + (storage_kind, z_header_len, header_len, block_len, rest) = bytes.split( + b"\n", 4 + ) del bytes - if storage_kind != b'groupcompress-block': - raise ValueError(f'Unknown storage kind: {storage_kind}') + if storage_kind != b"groupcompress-block": + raise ValueError(f"Unknown storage kind: {storage_kind}") z_header_len = int(z_header_len) if len(rest) < z_header_len: - raise ValueError('Compressed header len shorter than all bytes') + raise ValueError("Compressed header len shorter than all bytes") z_header = rest[:z_header_len] header_len = int(header_len) header = zlib.decompress(z_header) if len(header) != header_len: - raise ValueError('invalid length for decompressed bytes') + raise ValueError("invalid length for decompressed bytes") del z_header block_len = int(block_len) if len(rest) != z_header_len + block_len: - raise ValueError('Invalid length for block') + raise ValueError("Invalid length for block") block_bytes = rest[z_header_len:] del rest # So now we have a valid GCB, we just need to parse the factories that # were sent to us - header_lines = header.split(b'\n') + header_lines = header.split(b"\n") del header last = header_lines.pop() - if last != b'': - raise ValueError('header lines did not end with a trailing' - ' newline') + if last != b"": + raise ValueError("header lines did not end with a trailing" " newline") if len(header_lines) % 4 != 0: - raise ValueError('The header was not an even multiple of 4 lines') + raise ValueError("The header was not an even multiple of 4 lines") block = GroupCompressBlock.from_bytes(block_bytes) del block_bytes result = cls(block) for start in range(0, len(header_lines), 4): # intern()? - key = tuple(header_lines[start].split(b'\x00')) + key = tuple(header_lines[start].split(b"\x00")) parents_line = header_lines[start + 1] - if parents_line == b'None:': + if parents_line == b"None:": parents = None else: - parents = tuple([tuple(segment.split(b'\x00')) - for segment in parents_line.split(b'\t') - if segment]) + parents = tuple( + [ + tuple(segment.split(b"\x00")) + for segment in parents_line.split(b"\t") + if segment + ] + ) start_offset = int(header_lines[start + 2]) end_offset = int(header_lines[start + 3]) result.add_factory(key, parents, start_offset, end_offset) @@ -806,14 +832,13 @@ def from_bytes(cls, bytes): def network_block_to_records(storage_kind, bytes, line_end): - if storage_kind != 'groupcompress-block': - raise ValueError(f'Unknown storage kind: {storage_kind}') + if storage_kind != "groupcompress-block": + raise ValueError(f"Unknown storage kind: {storage_kind}") manager = _LazyGroupContentManager.from_bytes(bytes) return manager.get_record_stream() class _CommonGroupCompressor: - def __init__(self, settings=None): """Create a GroupCompressor.""" self.chunks = [] @@ -828,8 +853,7 @@ def __init__(self, settings=None): else: self._settings = settings - def compress(self, key, chunks, length, expected_sha, nostore_sha=None, - soft=False): + def compress(self, key, chunks, length, expected_sha, nostore_sha=None, soft=False): """Compress lines with label key. :param key: A key tuple. 
It is stored in the output @@ -854,7 +878,7 @@ def compress(self, key, chunks, length, expected_sha, nostore_sha=None, if length == 0: # empty, like a dir entry, etc if nostore_sha == _null_sha1: raise ExistingContent() - return _null_sha1, 0, 0, 'fulltext' + return _null_sha1, 0, 0, "fulltext" # we assume someone knew what they were doing when they passed it in if expected_sha is not None: sha1 = expected_sha @@ -864,7 +888,7 @@ def compress(self, key, chunks, length, expected_sha, nostore_sha=None, if sha1 == nostore_sha: raise ExistingContent() if key[-1] is None: - key = key[:-1] + (b'sha1:' + sha1,) + key = key[:-1] + (b"sha1:" + sha1,) start, end, type = self._compress(key, chunks, length, length / 2, soft) return sha1, start, end, type @@ -896,29 +920,32 @@ def extract(self, key): :param key: The key to extract. :return: An iterable over chunks and the sha1. """ - (start_byte, start_chunk, end_byte, - end_chunk) = self.labels_deltas[key] + (start_byte, start_chunk, end_byte, end_chunk) = self.labels_deltas[key] delta_chunks = self.chunks[start_chunk:end_chunk] - stored_bytes = b''.join(delta_chunks) + stored_bytes = b"".join(delta_chunks) kind = stored_bytes[:1] - if kind == b'f': + if kind == b"f": fulltext_len, offset = decode_base128_int(stored_bytes[1:10]) data_len = fulltext_len + 1 + offset if data_len != len(stored_bytes): - raise ValueError('Index claimed fulltext len, but stored bytes' - f' claim {len(stored_bytes)} != {data_len}') - data = [stored_bytes[offset + 1:]] + raise ValueError( + "Index claimed fulltext len, but stored bytes" + f" claim {len(stored_bytes)} != {data_len}" + ) + data = [stored_bytes[offset + 1 :]] else: - if kind != b'd': - raise ValueError(f'Unknown content kind, bytes claim {kind}') + if kind != b"d": + raise ValueError(f"Unknown content kind, bytes claim {kind}") # XXX: This is inefficient at best - source = b''.join(self.chunks[:start_chunk]) + source = b"".join(self.chunks[:start_chunk]) delta_len, offset = decode_base128_int(stored_bytes[1:10]) data_len = delta_len + 1 + offset if data_len != len(stored_bytes): - raise ValueError('Index claimed delta len, but stored bytes' - f' claim {len(stored_bytes)} != {data_len}') - data = [apply_delta(source, stored_bytes[offset + 1:])] + raise ValueError( + "Index claimed delta len, but stored bytes" + f" claim {len(stored_bytes)} != {data_len}" + ) + data = [apply_delta(source, stored_bytes[offset + 1 :])] data_sha1 = osutils.sha_strings(data) return data, data_sha1 @@ -939,7 +966,7 @@ def pop_last(self): more compression. """ self._delta_index = None - del self.chunks[self._last[0]:] + del self.chunks[self._last[0] :] self.endpoint = self._last[1] self._last = None @@ -949,7 +976,6 @@ def ratio(self): class PythonGroupCompressor(_CommonGroupCompressor): - def __init__(self, settings=None): """Create a GroupCompressor. 
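# Purely illustrative, not from the patch: compress() above hands _compress()
# a max_delta_size of length / 2, so a delta is only kept when it comes out
# smaller than half of the fulltext; otherwise an 'f' record is stored instead
# of a 'd' record.  Toy restatement of that threshold:
def pick_record_kind(fulltext_len, delta_len):
    return "delta" if delta_len <= fulltext_len / 2 else "fulltext"

assert pick_record_kind(1000, 80) == "delta"
assert pick_record_kind(1000, 700) == "fulltext"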
@@ -964,19 +990,20 @@ def _compress(self, key, chunks, input_len, max_delta_size, soft=False): """See _CommonGroupCompressor._compress.""" new_lines = osutils.chunks_to_lines(chunks) out_lines, index_lines = self._delta_index.make_delta( - new_lines, bytes_length=input_len, soft=soft) + new_lines, bytes_length=input_len, soft=soft + ) delta_length = sum(map(len, out_lines)) if delta_length > max_delta_size: # The delta is longer than the fulltext, insert a fulltext - type = 'fulltext' - out_lines = [b'f', encode_base128_int(input_len)] + type = "fulltext" + out_lines = [b"f", encode_base128_int(input_len)] out_lines.extend(new_lines) index_lines = [False, False] index_lines.extend([True] * len(new_lines)) else: # this is a worthy delta, output it - type = 'delta' - out_lines[0] = b'd' + type = "delta" + out_lines[0] = b"d" # Update the delta_length to include those two encoded integers out_lines[1] = encode_base128_int(delta_length) # Before insertion @@ -987,8 +1014,7 @@ def _compress(self, key, chunks, input_len, max_delta_size, soft=False): self.endpoint = self._delta_index.endpoint self.input_bytes += input_len chunk_end = len(self.chunks) - self.labels_deltas[key] = (start, chunk_start, - self.endpoint, chunk_end) + self.labels_deltas[key] = (start, chunk_start, self.endpoint, chunk_end) return start, self.endpoint, type @@ -1011,7 +1037,7 @@ class PyrexGroupCompressor(_CommonGroupCompressor): def __init__(self, settings=None): super().__init__(settings) - max_bytes_to_index = self._settings.get('max_bytes_to_index', 0) + max_bytes_to_index = self._settings.get("max_bytes_to_index", 0) self._delta_index = DeltaIndex(max_bytes_to_index=max_bytes_to_index) def _compress(self, key, chunks, input_len, max_delta_size, soft=False): @@ -1026,22 +1052,24 @@ def _compress(self, key, chunks, input_len, max_delta_size, soft=False): # inventory pages, and 5.8% increase for text pages # new_chunks = ['label:%s\nsha1:%s\n' % (label, sha1)] if self._delta_index._source_offset != self.endpoint: - raise AssertionError('_source_offset != endpoint' - ' somehow the DeltaIndex got out of sync with' - ' the output lines') - bytes = b''.join(chunks) + raise AssertionError( + "_source_offset != endpoint" + " somehow the DeltaIndex got out of sync with" + " the output lines" + ) + bytes = b"".join(chunks) delta = self._delta_index.make_delta(bytes, max_delta_size) if delta is None: - type = 'fulltext' + type = "fulltext" enc_length = encode_base128_int(input_len) len_mini_header = 1 + len(enc_length) self._delta_index.add_source(bytes, len_mini_header) - new_chunks = [b'f', enc_length] + chunks + new_chunks = [b"f", enc_length] + chunks else: - type = 'delta' + type = "delta" enc_length = encode_base128_int(len(delta)) len_mini_header = 1 + len(enc_length) - new_chunks = [b'd', enc_length, delta] + new_chunks = [b"d", enc_length, delta] self._delta_index.add_delta_source(delta, len_mini_header) # Before insertion start = self.endpoint @@ -1050,11 +1078,12 @@ def _compress(self, key, chunks, input_len, max_delta_size, soft=False): self._output_chunks(new_chunks) self.input_bytes += input_len chunk_end = len(self.chunks) - self.labels_deltas[key] = (start, chunk_start, - self.endpoint, chunk_end) + self.labels_deltas[key] = (start, chunk_start, self.endpoint, chunk_end) if not self._delta_index._source_offset == self.endpoint: - raise AssertionError('the delta index is out of sync' - f'with the output lines {self._delta_index._source_offset} != {self.endpoint}') + raise AssertionError( + "the delta index is out 
of sync" + f"with the output lines {self._delta_index._source_offset} != {self.endpoint}" + ) return start, self.endpoint, type def _output_chunks(self, new_chunks): @@ -1081,25 +1110,30 @@ def make_pack_factory(graph, delta, keylength, inconsistency_fatal=True): """ from .pack import ContainerWriter from .pack_repo import _DirectPackAccess + def factory(transport): parents = graph ref_length = 0 if graph: ref_length = 1 - graph_index = BTreeBuilder(reference_lists=ref_length, - key_elements=keylength) - stream = transport.open_write_stream('newpack') + graph_index = BTreeBuilder(reference_lists=ref_length, key_elements=keylength) + stream = transport.open_write_stream("newpack") writer = ContainerWriter(stream.write) writer.begin() - index = _GCGraphIndex(graph_index, lambda: True, parents=parents, - add_callback=graph_index.add_nodes, - inconsistency_fatal=inconsistency_fatal) + index = _GCGraphIndex( + graph_index, + lambda: True, + parents=parents, + add_callback=graph_index.add_nodes, + inconsistency_fatal=inconsistency_fatal, + ) access = _DirectPackAccess({}) - access.set_writer(writer, graph_index, (transport, 'newpack')) + access.set_writer(writer, graph_index, (transport, "newpack")) result = GroupCompressVersionedFiles(index, access, delta) result.stream = stream result.writer = writer return result + return factory @@ -1197,13 +1231,15 @@ def yield_factories(self, full_flush=False): if block_read_memo != read_memo: raise AssertionError( "block_read_memo out of sync with read_memo" - f"({block_read_memo!r} != {read_memo!r})") + f"({block_read_memo!r} != {read_memo!r})" + ) self.batch_memos[read_memo] = block memos_to_get_stack.pop() else: block = self.batch_memos[read_memo] - self.manager = _LazyGroupContentManager(block, - get_compressor_settings=self._get_compressor_settings) + self.manager = _LazyGroupContentManager( + block, get_compressor_settings=self._get_compressor_settings + ) self.last_read_memo = read_memo start, end = index_memo[3:5] self.manager.add_factory(key, parents, start, end) @@ -1227,11 +1263,11 @@ class GroupCompressVersionedFiles(VersionedFilesWithFallbacks): # versus running out of memory trying to track everything. The default max # gives 100% sampling of a 1MB file. _DEFAULT_MAX_BYTES_TO_INDEX = 1024 * 1024 - _DEFAULT_COMPRESSOR_SETTINGS = {'max_bytes_to_index': - _DEFAULT_MAX_BYTES_TO_INDEX} + _DEFAULT_COMPRESSOR_SETTINGS = {"max_bytes_to_index": _DEFAULT_MAX_BYTES_TO_INDEX} - def __init__(self, index, access, delta=True, _unadded_refs=None, - _group_cache=None): + def __init__( + self, index, access, delta=True, _unadded_refs=None, _group_cache=None + ): """Create a GroupCompressVersionedFiles object. :param index: The index object storing access and graph data. 
@@ -1254,14 +1290,25 @@ def __init__(self, index, access, delta=True, _unadded_refs=None, def without_fallbacks(self): """Return a clone of this object without any fallbacks configured.""" - return GroupCompressVersionedFiles(self._index, self._access, - self._delta, _unadded_refs=dict( - self._unadded_refs), - _group_cache=self._group_cache) - - def add_lines(self, key, parents, lines, parent_texts=None, - left_matching_blocks=None, nostore_sha=None, random_id=False, - check_content=True): + return GroupCompressVersionedFiles( + self._index, + self._access, + self._delta, + _unadded_refs=dict(self._unadded_refs), + _group_cache=self._group_cache, + ) + + def add_lines( + self, + key, + parents, + lines, + parent_texts=None, + left_matching_blocks=None, + nostore_sha=None, + random_id=False, + check_content=True, + ): r"""Add a text to the store. :param key: The key tuple of the text to add. @@ -1301,12 +1348,22 @@ def add_lines(self, key, parents, lines, parent_texts=None, self._check_lines_are_lines(lines) return self.add_content( ChunkedContentFactory( - key, parents, osutils.sha_strings(lines), lines, chunks_are_lines=True), - parent_texts, left_matching_blocks, nostore_sha, random_id) - - def add_content(self, factory, parent_texts=None, - left_matching_blocks=None, nostore_sha=None, - random_id=False): + key, parents, osutils.sha_strings(lines), lines, chunks_are_lines=True + ), + parent_texts, + left_matching_blocks, + nostore_sha, + random_id, + ) + + def add_content( + self, + factory, + parent_texts=None, + left_matching_blocks=None, + nostore_sha=None, + random_id=False, + ): """Add a text to the store. :param factory: A ContentFactory that can be used to retrieve the key, @@ -1339,8 +1396,11 @@ def add_content(self, factory, parent_texts=None, # an empty tuple instead. parents = () # double handling for now. Make it work until then. - sha1, length = list(self._insert_record_stream( - [factory], random_id=random_id, nostore_sha=nostore_sha))[0] + sha1, length = list( + self._insert_record_stream( + [factory], random_id=random_id, nostore_sha=nostore_sha + ) + )[0] return sha1, length, None def add_fallback_versioned_files(self, a_versioned_files): @@ -1357,17 +1417,18 @@ def annotate(self, key): def get_annotator(self): from ..annotate import Annotator + return Annotator(self) def check(self, progress_bar=None, keys=None): """See VersionedFiles.check().""" if keys is None: keys = self.keys() - for record in self.get_record_stream(keys, 'unordered', True): - for _chunk in record.iter_bytes_as('chunked'): + for record in self.get_record_stream(keys, "unordered", True): + for _chunk in record.iter_bytes_as("chunked"): pass else: - return self.get_record_stream(keys, 'unordered', True) + return self.get_record_stream(keys, "unordered", True) def clear_cache(self): """See VersionedFiles.clear_cache().""" @@ -1485,18 +1546,18 @@ def get_record_stream(self, keys, ordering, include_delta_closure): keys = set(keys) if not keys: return - if (not self._index.has_graph - and ordering in ('topological', 'groupcompress')): + if not self._index.has_graph and ordering in ("topological", "groupcompress"): # Cannot topological order when no graph has been stored. 
# but we allow 'as-requested' or 'unordered' - ordering = 'unordered' + ordering = "unordered" remaining_keys = keys while True: try: keys = set(remaining_keys) - for content_factory in self._get_remaining_record_stream(keys, - orig_keys, ordering, include_delta_closure): + for content_factory in self._get_remaining_record_stream( + keys, orig_keys, ordering, include_delta_closure + ): remaining_keys.discard(content_factory.key) yield content_factory return @@ -1538,7 +1599,8 @@ def _get_ordered_source_keys(self, ordering, parent_map, key_to_source_map): the defined order, regardless of source. """ from .. import tsort - if ordering == 'topological': + + if ordering == "topological": present_keys = tsort.topo_sort(parent_map) else: # ordering == 'groupcompress' @@ -1558,8 +1620,9 @@ def _get_ordered_source_keys(self, ordering, parent_map, key_to_source_map): source_keys[-1][1].append(key) return source_keys - def _get_as_requested_source_keys(self, orig_keys, locations, unadded_keys, - key_to_source_map): + def _get_as_requested_source_keys( + self, orig_keys, locations, unadded_keys, key_to_source_map + ): source_keys = [] current_source = None for key in orig_keys: @@ -1575,12 +1638,12 @@ def _get_as_requested_source_keys(self, orig_keys, locations, unadded_keys, source_keys[-1][1].append(key) return source_keys - def _get_io_ordered_source_keys(self, locations, unadded_keys, - source_result): + def _get_io_ordered_source_keys(self, locations, unadded_keys, source_result): def get_group(key): # This is the group the bytes are stored in, followed by the # location in the group return locations[key][0] + # We don't have an ordering for keys in the in-memory object, but # lets process the in-memory ones first. present_keys = list(unadded_keys) @@ -1590,8 +1653,9 @@ def get_group(key): source_keys.extend(source_result) return source_keys - def _get_remaining_record_stream(self, keys, orig_keys, ordering, - include_delta_closure): + def _get_remaining_record_stream( + self, keys, orig_keys, ordering, include_delta_closure + ): """Get a stream of records for keys. :param keys: The keys to include. @@ -1607,35 +1671,41 @@ def _get_remaining_record_stream(self, keys, orig_keys, ordering, unadded_keys = set(self._unadded_refs).intersection(keys) missing = keys.difference(locations) missing.difference_update(unadded_keys) - (fallback_parent_map, key_to_source_map, - source_result) = self._find_from_fallback(missing) - if ordering in ('topological', 'groupcompress'): + ( + fallback_parent_map, + key_to_source_map, + source_result, + ) = self._find_from_fallback(missing) + if ordering in ("topological", "groupcompress"): # would be better to not globally sort initially but instead # start with one key, recurse to its oldest parent, then grab # everything in the same group, etc. 
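# Illustrative aside, not from the patch: the 'topological' ordering requested
# above only requires that every key is yielded after its parents.  A toy
# stand-in for what tsort.topo_sort provides (recursion-based, no cycle
# handling, for illustration only):
def toy_topo_sort(parent_map):
    order, seen = [], set()

    def visit(key):
        if key in seen:
            return
        seen.add(key)
        for parent in parent_map.get(key, ()):
            if parent in parent_map:   # ignore ghosts outside the map
                visit(parent)
        order.append(key)

    for key in parent_map:
        visit(key)
    return order

order = toy_topo_sort({b"C": (b"B",), b"B": (b"A",), b"A": ()})
assert order.index(b"A") < order.index(b"B") < order.index(b"C")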
- parent_map = {key: details[2] for key, details in - locations.items()} + parent_map = {key: details[2] for key, details in locations.items()} for key in unadded_keys: parent_map[key] = self._unadded_refs[key] parent_map.update(fallback_parent_map) - source_keys = self._get_ordered_source_keys(ordering, parent_map, - key_to_source_map) - elif ordering == 'as-requested': - source_keys = self._get_as_requested_source_keys(orig_keys, - locations, unadded_keys, key_to_source_map) + source_keys = self._get_ordered_source_keys( + ordering, parent_map, key_to_source_map + ) + elif ordering == "as-requested": + source_keys = self._get_as_requested_source_keys( + orig_keys, locations, unadded_keys, key_to_source_map + ) else: # We want to yield the keys in a semi-optimal (read-wise) ordering. # Otherwise we thrash the _group_cache and destroy performance - source_keys = self._get_io_ordered_source_keys(locations, - unadded_keys, source_result) + source_keys = self._get_io_ordered_source_keys( + locations, unadded_keys, source_result + ) for key in missing: yield AbsentContentFactory(key) # Batch up as many keys as we can until either: # - we encounter an unadded ref, or # - we run out of keys, or # - the total bytes to retrieve for this batch > BATCH_SIZE - batcher = _BatchingBlockFetcher(self, locations, - get_compressor_settings=self._get_compressor_settings) + batcher = _BatchingBlockFetcher( + self, locations, get_compressor_settings=self._get_compressor_settings + ) for source, keys in source_keys: if source is self: for key in keys: @@ -1652,20 +1722,22 @@ def _get_remaining_record_stream(self, keys, orig_keys, ordering, yield from batcher.yield_factories() else: yield from batcher.yield_factories(full_flush=True) - yield from source.get_record_stream(keys, ordering, - include_delta_closure) + yield from source.get_record_stream( + keys, ordering, include_delta_closure + ) yield from batcher.yield_factories(full_flush=True) def get_sha1s(self, keys): """See VersionedFiles.get_sha1s().""" result = {} - for record in self.get_record_stream(keys, 'unordered', True): + for record in self.get_record_stream(keys, "unordered", True): if record.sha1 is not None: result[record.key] = record.sha1 else: - if record.storage_kind != 'absent': + if record.storage_kind != "absent": result[record.key] = osutils.sha_strings( - record.iter_bytes_as('chunked')) + record.iter_bytes_as("chunked") + ) return result def insert_record_stream(self, stream): @@ -1684,30 +1756,34 @@ def insert_record_stream(self, stream): def _get_compressor_settings(self): from ..config import GlobalConfig + if self._max_bytes_to_index is None: # TODO: VersionedFiles don't know about their containing # repository, so they don't have much of an idea about their # location. So for now, this is only a global option. 
c = GlobalConfig() - val = c.get_user_option('bzr.groupcompress.max_bytes_to_index') + val = c.get_user_option("bzr.groupcompress.max_bytes_to_index") if val is not None: try: val = int(val) except ValueError: - trace.warning('Value for ' - '"bzr.groupcompress.max_bytes_to_index"' - f' {val!r} is not an integer') + trace.warning( + "Value for " + '"bzr.groupcompress.max_bytes_to_index"' + f" {val!r} is not an integer" + ) val = None if val is None: val = self._DEFAULT_MAX_BYTES_TO_INDEX self._max_bytes_to_index = val - return {'max_bytes_to_index': self._max_bytes_to_index} + return {"max_bytes_to_index": self._max_bytes_to_index} def _make_group_compressor(self): return GroupCompressor(self._get_compressor_settings()) - def _insert_record_stream(self, stream, random_id=False, nostore_sha=None, - reuse_blocks=True): + def _insert_record_stream( + self, stream, random_id=False, nostore_sha=None, reuse_blocks=True + ): """Internal core to insert a record stream into this container. This helper function has a different interface than insert_record_stream @@ -1733,6 +1809,7 @@ def get_adapter(adapter_key): adapter = adapter_factory(self) adapters[adapter_key] = adapter return adapter + # This will go up to fulltexts for gc to gc fetching, which isn't # ideal. self._compressor = self._make_group_compressor() @@ -1749,8 +1826,7 @@ def flush(): # the fulltext content at this point. Note that sometimes we # will want it later (streaming CHK pages), but most of the # time we won't (everything else) - index, start, length = self._access.add_raw_record( - None, bytes_len, chunks) + index, start, length = self._access.add_raw_record(None, bytes_len, chunks) nodes = [] for key, reads, refs in keys_to_add: nodes.append((key, b"%d %d %s" % (start, length, reads), refs)) @@ -1769,12 +1845,17 @@ def flush(): reuse_this_block = reuse_blocks for record in stream: # Raise an error when a record is missing. - if record.storage_kind == 'absent': + if record.storage_kind == "absent": raise errors.RevisionNotPresent(record.key, self) if random_id: if record.key in inserted_keys: - trace.note(gettext('Insert claimed random_id=True,' - ' but then inserted %r two times'), record.key) + trace.note( + gettext( + "Insert claimed random_id=True," + " but then inserted %r two times" + ), + record.key, + ) continue inserted_keys.add(record.key) if reuse_blocks: @@ -1783,7 +1864,7 @@ def flush(): # We only check on the first record (groupcompress-block) not # on all of the (groupcompress-block-ref) entries. 
# The reuse_this_block flag is then kept for as long as - if record.storage_kind == 'groupcompress-block': + if record.storage_kind == "groupcompress-block": # Check to see if we really want to re-use this block insert_manager = record._manager reuse_this_block = insert_manager.check_is_well_utilized() @@ -1791,25 +1872,33 @@ def flush(): reuse_this_block = False if reuse_this_block: # We still want to reuse this block - if record.storage_kind == 'groupcompress-block': + if record.storage_kind == "groupcompress-block": # Insert the raw block into the target repo insert_manager = record._manager bytes_len, chunks = record._manager._block.to_chunks() _, start, length = self._access.add_raw_record( - None, bytes_len, chunks) + None, bytes_len, chunks + ) block_start = start block_length = length - if record.storage_kind in ('groupcompress-block', - 'groupcompress-block-ref'): + if record.storage_kind in ( + "groupcompress-block", + "groupcompress-block-ref", + ): if insert_manager is None: - raise AssertionError('No insert_manager set') + raise AssertionError("No insert_manager set") if insert_manager is not record._manager: - raise AssertionError('insert_manager does not match' - ' the current record, we cannot be positive' - ' that the appropriate content was inserted.' - ) - value = b"%d %d %d %d" % (block_start, block_length, - record._start, record._end) + raise AssertionError( + "insert_manager does not match" + " the current record, we cannot be positive" + " that the appropriate content was inserted." + ) + value = b"%d %d %d %d" % ( + block_start, + block_length, + record._start, + record._end, + ) nodes = [(record.key, value, (record.parents,))] # TODO: Consider buffering up many nodes to be added, not # sure how much overhead this has, but we're seeing @@ -1817,38 +1906,44 @@ def flush(): self._index.add_records(nodes, random_id=random_id) continue try: - chunks = record.get_bytes_as('chunked') + chunks = record.get_bytes_as("chunked") except UnavailableRepresentation: - adapter_key = record.storage_kind, 'chunked' + adapter_key = record.storage_kind, "chunked" adapter = get_adapter(adapter_key) - chunks = adapter.get_bytes(record, 'chunked') + chunks = adapter.get_bytes(record, "chunked") chunks_len = record.size if chunks_len is None: chunks_len = sum(map(len, chunks)) if len(record.key) > 1: prefix = record.key[0] - soft = (prefix == last_prefix) + soft = prefix == last_prefix else: prefix = None soft = False if max_fulltext_len < chunks_len: max_fulltext_len = chunks_len max_fulltext_prefix = prefix - (found_sha1, start_point, end_point, - type) = self._compressor.compress( - record.key, chunks, chunks_len, record.sha1, soft=soft, - nostore_sha=nostore_sha) + (found_sha1, start_point, end_point, type) = self._compressor.compress( + record.key, + chunks, + chunks_len, + record.sha1, + soft=soft, + nostore_sha=nostore_sha, + ) # delta_ratio = float(chunks_len) / (end_point - start_point) # Check if we want to continue to include that text - if (prefix == max_fulltext_prefix - and end_point < 2 * max_fulltext_len): + if prefix == max_fulltext_prefix and end_point < 2 * max_fulltext_len: # As long as we are on the same file_id, we will fill at least # 2 * max_fulltext_len start_new_block = False elif end_point > 4 * 1024 * 1024: start_new_block = True - elif (prefix is not None and prefix != last_prefix - and end_point > 2 * 1024 * 1024): + elif ( + prefix is not None + and prefix != last_prefix + and end_point > 2 * 1024 * 1024 + ): start_new_block = True else: start_new_block = 
False @@ -1857,11 +1952,11 @@ def flush(): self._compressor.pop_last() flush() max_fulltext_len = chunks_len - (found_sha1, start_point, end_point, - type) = self._compressor.compress( - record.key, chunks, chunks_len, record.sha1) + (found_sha1, start_point, end_point, type) = self._compressor.compress( + record.key, chunks, chunks_len, record.sha1 + ) if record.key[-1] is None: - key = record.key[:-1] + (b'sha1:' + found_sha1,) + key = record.key[:-1] + (b"sha1:" + found_sha1,) else: key = record.key self._unadded_refs[key] = record.parents @@ -1872,8 +1967,7 @@ def flush(): else: parents = None refs = static_tuple.StaticTuple(parents) - keys_to_add.append( - (key, b'%d %d' % (start_point, end_point), refs)) + keys_to_add.append((key, b"%d %d" % (start_point, end_point), refs)) if len(keys_to_add): flush() self._compressor = None @@ -1904,22 +1998,23 @@ def iter_lines_added_or_present_in_keys(self, keys, pb=None): # we don't care about inclusions, the caller cares. # but we need to setup a list of records to visit. # we need key, position, length - for key_idx, record in enumerate(self.get_record_stream(keys, - 'unordered', True)): + for key_idx, record in enumerate( + self.get_record_stream(keys, "unordered", True) + ): # XXX: todo - optimise to use less than full texts. key = record.key if pb is not None: - pb.update('Walking content', key_idx, total) - if record.storage_kind == 'absent': + pb.update("Walking content", key_idx, total) + if record.storage_kind == "absent": raise errors.RevisionNotPresent(key, self) - for line in record.iter_bytes_as('lines'): + for line in record.iter_bytes_as("lines"): yield line, key if pb is not None: - pb.update('Walking content', total, total) + pb.update("Walking content", total, total) def keys(self): """See VersionedFiles.keys.""" - if debug.debug_flag_enabled('evil'): + if debug.debug_flag_enabled("evil"): trace.mutter_callsite(2, "keys scales with size of history") sources = [self._index] + self._immediate_fallback_vfs result = set() @@ -1935,24 +2030,40 @@ class _GCBuildDetails: api, without taking as much memory. 
""" - __slots__ = ('_index', '_group_start', '_group_end', '_basis_end', - '_delta_end', '_parents') + __slots__ = ( + "_index", + "_group_start", + "_group_end", + "_basis_end", + "_delta_end", + "_parents", + ) - method = 'group' + method = "group" compression_parent = None def __init__(self, parents, position_info): self._parents = parents - (self._index, self._group_start, self._group_end, self._basis_end, - self._delta_end) = position_info + ( + self._index, + self._group_start, + self._group_end, + self._basis_end, + self._delta_end, + ) = position_info def __repr__(self): - return f'{self.__class__.__name__}({self.index_memo}, {self._parents})' + return f"{self.__class__.__name__}({self.index_memo}, {self._parents})" @property def index_memo(self): - return (self._index, self._group_start, self._group_end, - self._basis_end, self._delta_end) + return ( + self._index, + self._group_start, + self._group_end, + self._basis_end, + self._delta_end, + ) @property def record_details(self): @@ -1969,7 +2080,7 @@ def __getitem__(self, offset): elif offset == 3: return self.record_details else: - raise IndexError('offset out of range') + raise IndexError("offset out of range") def __len__(self): return 4 @@ -1978,9 +2089,16 @@ def __len__(self): class _GCGraphIndex: """Mapper from GroupCompressVersionedFiles needs into GraphIndex storage.""" - def __init__(self, graph_index, is_locked, parents=True, - add_callback=None, track_external_parent_refs=False, - inconsistency_fatal=True, track_new_keys=False): + def __init__( + self, + graph_index, + is_locked, + parents=True, + add_callback=None, + track_external_parent_refs=False, + inconsistency_fatal=True, + track_new_keys=False, + ): """Construct a _GCGraphIndex on a graph_index. :param graph_index: An implementation of breezy.index.GraphIndex. 
@@ -2009,8 +2127,7 @@ def __init__(self, graph_index, is_locked, parents=True, # repeated over and over, this creates a surplus of ints self._int_cache = {} if track_external_parent_refs: - self._key_dependencies = _KeyRefs( - track_new_keys=track_new_keys) + self._key_dependencies = _KeyRefs(track_new_keys=track_new_keys) else: self._key_dependencies = None @@ -2034,33 +2151,36 @@ def add_records(self, records, random_id=False): changed = False keys = {} - for (key, value, refs) in records: + for key, value, refs in records: if not self._parents: if refs: for ref in refs: if ref: - raise knit.KnitCorrupt(self, - "attempt to add node with parents " - "in parentless index.") + raise knit.KnitCorrupt( + self, + "attempt to add node with parents " + "in parentless index.", + ) refs = () changed = True keys[key] = (value, refs) # check for dups if not random_id: present_nodes = self._get_entries(keys) - for (_index, key, value, node_refs) in present_nodes: + for _index, key, value, node_refs in present_nodes: # Sometimes these are passed as a list rather than a tuple node_refs = static_tuple.as_tuples(node_refs) passed = static_tuple.as_tuples(keys[key]) if node_refs != passed[1]: - details = f'{key} {value, node_refs} {passed}' + details = f"{key} {value, node_refs} {passed}" if self._inconsistency_fatal: - raise knit.KnitCorrupt(self, "inconsistent details" - " in add_records: %s" % - details) + raise knit.KnitCorrupt( + self, "inconsistent details" " in add_records: %s" % details + ) else: - trace.warning("inconsistent details in skipped" - " record: %s", details) + trace.warning( + "inconsistent details in skipped" " record: %s", details + ) del keys[key] changed = True if changed: @@ -2144,7 +2264,8 @@ def get_missing_parents(self): # Copied from _KnitGraphIndex.get_missing_parents # We may have false positives, so filter those out. self._key_dependencies.satisfy_refs_for_keys( - self.get_parent_map(self._key_dependencies.get_unsatisfied_refs())) + self.get_parent_map(self._key_dependencies.get_unsatisfied_refs()) + ) return frozenset(self._key_dependencies.get_unsatisfied_refs()) def get_build_details(self, keys): @@ -2187,7 +2308,7 @@ def keys(self): def _node_to_position(self, node): """Convert an index value to position details.""" - bits = node[2].split(b' ') + bits = node[2].split(b" ") # It would be nice not to read the entire gzip. # start and stop are put into _int_cache because they are very common. # They define the 'group' that an entry is in, and many groups can have @@ -2240,6 +2361,7 @@ def scan_unvalidated_index(self, graph_index): try: from ._groupcompress_pyx import DeltaIndex + GroupCompressor = PyrexGroupCompressor except ImportError as e: osutils.failed_to_load_extension(e) diff --git a/breezy/bzr/groupcompress_repo.py b/breezy/bzr/groupcompress_repo.py index ca6d700a65..156301f9b2 100644 --- a/breezy/bzr/groupcompress_repo.py +++ b/breezy/bzr/groupcompress_repo.py @@ -41,8 +41,7 @@ class GCPack(NewPack): - - def __init__(self, pack_collection, upload_suffix='', file_mode=None): + def __init__(self, pack_collection, upload_suffix="", file_mode=None): """Create a NewPack instance. :param pack_collection: A PackCollection into which this is being @@ -64,23 +63,24 @@ def __init__(self, pack_collection, upload_suffix='', file_mode=None): chk_index = index_builder_class(reference_lists=0) else: chk_index = None - Pack.__init__(self, - # Revisions: parents list, no text compression. 
- index_builder_class(reference_lists=1), - # Inventory: We want to map compression only, but currently the - # knit code hasn't been updated enough to understand that, so we - # have a regular 2-list index giving parents and compression - # source. - index_builder_class(reference_lists=1), - # Texts: per file graph, for all fileids - so one reference list - # and two elements in the key tuple. - index_builder_class(reference_lists=1, key_elements=2), - # Signatures: Just blobs to store, no compression, no parents - # listing. - index_builder_class(reference_lists=0), - # CHK based storage - just blobs, no compression or parents. - chk_index=chk_index - ) + Pack.__init__( + self, + # Revisions: parents list, no text compression. + index_builder_class(reference_lists=1), + # Inventory: We want to map compression only, but currently the + # knit code hasn't been updated enough to understand that, so we + # have a regular 2-list index giving parents and compression + # source. + index_builder_class(reference_lists=1), + # Texts: per file graph, for all fileids - so one reference list + # and two elements in the key tuple. + index_builder_class(reference_lists=1, key_elements=2), + # Signatures: Just blobs to store, no compression, no parents + # listing. + index_builder_class(reference_lists=0), + # CHK based storage - just blobs, no compression or parents. + chk_index=chk_index, + ) self._pack_collection = pack_collection # When we make readonly indices, we need this. self.index_class = pack_collection._index_class @@ -108,11 +108,16 @@ def __init__(self, pack_collection, upload_suffix='', file_mode=None): self.start_time = time.time() # open an output stream for the data added to the pack. self.write_stream = self.upload_transport.open_write_stream( - self.random_name, mode=self._file_mode) - if debug.debug_flag_enabled('pack'): - trace.mutter('%s: create_pack: pack stream open: %s%s t+%6.3fs', - time.ctime(), self.upload_transport.base, self.random_name, - time.time() - self.start_time) + self.random_name, mode=self._file_mode + ) + if debug.debug_flag_enabled("pack"): + trace.mutter( + "%s: create_pack: pack stream open: %s%s t+%6.3fs", + time.ctime(), + self.upload_transport.base, + self.random_name, + time.time() - self.start_time, + ) # A list of byte sequences to be written to the new pack, and the # aggregate size of them. Stored as a list rather than separate # variables so that the _write_data closure below can update them. @@ -123,23 +128,29 @@ def __init__(self, pack_collection, upload_suffix='', file_mode=None): # so that the variables are locals, and faster than accessing object # members. - def _write_data(data, flush=False, _buffer=self._buffer, - _write=self.write_stream.write, _update=self._hash.update): + def _write_data( + data, + flush=False, + _buffer=self._buffer, + _write=self.write_stream.write, + _update=self._hash.update, + ): _buffer[0].append(data) _buffer[1] += len(data) # buffer cap if _buffer[1] > self._cache_limit or flush: - data = b''.join(_buffer[0]) + data = b"".join(_buffer[0]) _write(data) _update(data) _buffer[:] = [[], 0] + # expose this on self, for the occasion when clients want to add data. self._write_data = _write_data # a pack writer object to serialise pack records. self._writer = pack.ContainerWriter(self._write_data) self._writer.begin() # what state is the pack in? 
(open, finished, aborted) - self._state = 'open' + self._state = "open" # no name until we finish writing the content self.name = None @@ -157,7 +168,6 @@ def _check_references(self): class ResumedGCPack(ResumedPack): - def _check_references(self): """Make sure our external compression parents are present.""" # See GCPack._check_references for why this is empty @@ -171,16 +181,20 @@ def _get_external_refs(self, index): class GCCHKPacker(Packer): """This class understand what it takes to collect a GCCHK repo.""" - def __init__(self, pack_collection, packs, suffix, revision_ids=None, - reload_func=None): - super().__init__(pack_collection, packs, suffix, - revision_ids=revision_ids, - reload_func=reload_func) + def __init__( + self, pack_collection, packs, suffix, revision_ids=None, reload_func=None + ): + super().__init__( + pack_collection, + packs, + suffix, + revision_ids=revision_ids, + reload_func=reload_func, + ) self._pack_collection = pack_collection # ATM, We only support this for GCCHK repositories if pack_collection.chk_index is None: - raise AssertionError( - 'pack_collection.chk_index should not be None') + raise AssertionError("pack_collection.chk_index should not be None") self._gather_text_refs = False self._chk_id_roots = [] self._chk_p_id_roots = [] @@ -190,12 +204,12 @@ def __init__(self, pack_collection, packs, suffix, revision_ids=None, def _get_progress_stream(self, source_vf, keys, message, pb): def pb_stream(): - substream = source_vf.get_record_stream( - keys, 'groupcompress', True) + substream = source_vf.get_record_stream(keys, "groupcompress", True) for idx, record in enumerate(substream): if pb is not None: pb.update(message, idx + 1, len(keys)) yield record + return pb_stream() def _get_filtered_inv_stream(self, source_vf, keys, message, pb=None): @@ -205,21 +219,20 @@ def _get_filtered_inv_stream(self, source_vf, keys, message, pb=None): def _filtered_inv_stream(): id_roots_set = set() p_id_roots_set = set() - stream = source_vf.get_record_stream(keys, 'groupcompress', True) + stream = source_vf.get_record_stream(keys, "groupcompress", True) for idx, record in enumerate(stream): # Inventories should always be with revisions; assume success. 
- lines = record.get_bytes_as('lines') - chk_inv = inventory.CHKInventory.deserialise( - None, lines, record.key) + lines = record.get_bytes_as("lines") + chk_inv = inventory.CHKInventory.deserialise(None, lines, record.key) if pb is not None: - pb.update('inv', idx, total_keys) + pb.update("inv", idx, total_keys) key = chk_inv.id_to_entry.key() if key not in id_roots_set: self._chk_id_roots.append(key) id_roots_set.add(key) p_id_map = chk_inv.parent_id_basename_to_file_id if p_id_map is None: - raise AssertionError('Parent id -> file_id map not set') + raise AssertionError("Parent id -> file_id map not set") key = p_id_map.key() if key not in p_id_roots_set: p_id_roots_set.add(key) @@ -229,6 +242,7 @@ def _filtered_inv_stream(): # don't need these sets anymore id_roots_set.clear() p_id_roots_set.clear() + return _filtered_inv_stream() def _get_chk_streams(self, source_vf, keys, pb=None): @@ -272,8 +286,10 @@ def handle_internal_node(node): # pages as 'external_references' so that we # always fill them in for stacked branches if value not in next_keys and value in remaining_keys: # noqa: B023 - keys_by_search_prefix.setdefault(prefix, # noqa: B023 - []).append(value) + keys_by_search_prefix.setdefault( # noqa: B023 + prefix, + [], + ).append(value) next_keys.add(value) # noqa: B023 def handle_leaf_node(node): @@ -283,27 +299,32 @@ def handle_leaf_node(node): self._text_refs.add(chk_map._bytes_to_text_key(bytes)) def next_stream(): - stream = source_vf.get_record_stream(cur_keys, # noqa: B023 - 'as-requested', True) + stream = source_vf.get_record_stream( + cur_keys, # noqa: B023 + "as-requested", + True, + ) for record in stream: - if record.storage_kind == 'absent': + if record.storage_kind == "absent": # An absent CHK record: we assume that the missing # record is in a different pack - e.g. a page not # altered by the commit we're packing. continue - bytes = record.get_bytes_as('fulltext') + bytes = record.get_bytes_as("fulltext") # We don't care about search_key_func for this code, # because we only care about external references. 
- node = chk_map._deserialise(bytes, record.key, - search_key_func=None) + node = chk_map._deserialise( + bytes, record.key, search_key_func=None + ) if isinstance(node, chk_map.InternalNode): handle_internal_node(node) elif parse_leaf_nodes: handle_leaf_node(node) counter[0] += 1 if pb is not None: - pb.update('chk node', counter[0], total_keys) + pb.update("chk node", counter[0], total_keys) yield record + yield next_stream() # Double check that we won't be emitting any keys twice # If we get rid of the pre-calculation of all keys, we could @@ -316,43 +337,46 @@ def next_stream(): cur_keys = [] for prefix in sorted(keys_by_search_prefix): cur_keys.extend(keys_by_search_prefix.pop(prefix)) - for stream in _get_referenced_stream(self._chk_id_roots, - self._gather_text_refs): + + for stream in _get_referenced_stream( + self._chk_id_roots, self._gather_text_refs + ): yield stream del self._chk_id_roots # while it isn't really possible for chk_id_roots to not be in the # local group of packs, it is possible that the tree shape has not # changed recently, so we need to filter _chk_p_id_roots by the # available keys - chk_p_id_roots = [key for key in self._chk_p_id_roots - if key in remaining_keys] + chk_p_id_roots = [key for key in self._chk_p_id_roots if key in remaining_keys] del self._chk_p_id_roots for stream in _get_referenced_stream(chk_p_id_roots, False): yield stream if remaining_keys: - trace.mutter('There were %d keys in the chk index, %d of which' - ' were not referenced', total_keys, - len(remaining_keys)) + trace.mutter( + "There were %d keys in the chk index, %d of which" + " were not referenced", + total_keys, + len(remaining_keys), + ) if self.revision_ids is None: - stream = source_vf.get_record_stream(remaining_keys, - 'unordered', True) + stream = source_vf.get_record_stream(remaining_keys, "unordered", True) yield stream def _build_vf(self, index_name, parents, delta, for_write=False): """Build a VersionedFiles instance on top of this group of packs.""" - index_name = index_name + '_index' + index_name = index_name + "_index" index_to_pack = {} - access = _DirectPackAccess(index_to_pack, - reload_func=self._reload_func) + access = _DirectPackAccess(index_to_pack, reload_func=self._reload_func) if for_write: # Use new_pack if self.new_pack is None: - raise AssertionError('No new pack has been set') + raise AssertionError("No new pack has been set") index = getattr(self.new_pack, index_name) index_to_pack[index] = self.new_pack.access_tuple() index.set_optimize(for_size=True) - access.set_writer(self.new_pack._writer, index, - self.new_pack.access_tuple()) + access.set_writer( + self.new_pack._writer, index, self.new_pack.access_tuple() + ) add_callback = index.add_nodes else: indices = [] @@ -363,50 +387,58 @@ def _build_vf(self, index_name, parents, delta, for_write=False): index = _mod_index.CombinedGraphIndex(indices) add_callback = None vf = GroupCompressVersionedFiles( - _GCGraphIndex(index, - add_callback=add_callback, - parents=parents, - is_locked=self._pack_collection.repo.is_locked), + _GCGraphIndex( + index, + add_callback=add_callback, + parents=parents, + is_locked=self._pack_collection.repo.is_locked, + ), access=access, - delta=delta) + delta=delta, + ) return vf def _build_vfs(self, index_name, parents, delta): """Build the source and target VersionedFiles.""" - source_vf = self._build_vf(index_name, parents, - delta, for_write=False) - target_vf = self._build_vf(index_name, parents, - delta, for_write=True) + source_vf = self._build_vf(index_name, parents, 
delta, for_write=False) + target_vf = self._build_vf(index_name, parents, delta, for_write=True) return source_vf, target_vf - def _copy_stream(self, source_vf, target_vf, keys, message, vf_to_stream, - pb_offset): - trace.mutter('repacking %d %s', len(keys), message) - self.pb.update(f'repacking {message}', pb_offset) + def _copy_stream( + self, source_vf, target_vf, keys, message, vf_to_stream, pb_offset + ): + trace.mutter("repacking %d %s", len(keys), message) + self.pb.update(f"repacking {message}", pb_offset) with ui.ui_factory.nested_progress_bar() as child_pb: stream = vf_to_stream(source_vf, keys, message, child_pb) for _, _ in target_vf._insert_record_stream( - stream, random_id=True, reuse_blocks=False): + stream, random_id=True, reuse_blocks=False + ): pass def _copy_revision_texts(self): - source_vf, target_vf = self._build_vfs('revision', True, False) + source_vf, target_vf = self._build_vfs("revision", True, False) if not self.revision_keys: # We are doing a full fetch, aka 'pack' self.revision_keys = source_vf.keys() - self._copy_stream(source_vf, target_vf, self.revision_keys, - 'revisions', self._get_progress_stream, 1) + self._copy_stream( + source_vf, + target_vf, + self.revision_keys, + "revisions", + self._get_progress_stream, + 1, + ) def _copy_inventory_texts(self): - source_vf, target_vf = self._build_vfs('inventory', True, True) + source_vf, target_vf = self._build_vfs("inventory", True, True) # It is not sufficient to just use self.revision_keys, as stacked # repositories can have more inventories than they have revisions. # One alternative would be to do something with # get_parent_map(self.revision_keys), but that shouldn't be any faster # than this. inventory_keys = source_vf.keys() - missing_inventories = set( - self.revision_keys).difference(inventory_keys) + missing_inventories = set(self.revision_keys).difference(inventory_keys) if missing_inventories: # Go back to the original repo, to see if these are really missing # https://bugs.launchpad.net/bzr/+bug/437003 @@ -418,50 +450,67 @@ def _copy_inventory_texts(self): really_missing = missing_inventories.difference(pmap) if really_missing: missing_inventories = sorted(really_missing) - raise ValueError(f'We are missing inventories for revisions: {missing_inventories}') - self._copy_stream(source_vf, target_vf, inventory_keys, - 'inventories', self._get_filtered_inv_stream, 2) + raise ValueError( + f"We are missing inventories for revisions: {missing_inventories}" + ) + self._copy_stream( + source_vf, + target_vf, + inventory_keys, + "inventories", + self._get_filtered_inv_stream, + 2, + ) def _get_chk_vfs_for_copy(self): - return self._build_vfs('chk', False, False) + return self._build_vfs("chk", False, False) def _copy_chk_texts(self): source_vf, target_vf = self._get_chk_vfs_for_copy() # TODO: This is technically spurious... 
if it is a performance issue, # remove it total_keys = source_vf.keys() - trace.mutter('repacking chk: %d id_to_entry roots,' - ' %d p_id_map roots, %d total keys', - len(self._chk_id_roots), len(self._chk_p_id_roots), - len(total_keys)) - self.pb.update('repacking chk', 3) + trace.mutter( + "repacking chk: %d id_to_entry roots," " %d p_id_map roots, %d total keys", + len(self._chk_id_roots), + len(self._chk_p_id_roots), + len(total_keys), + ) + self.pb.update("repacking chk", 3) with ui.ui_factory.nested_progress_bar() as child_pb: - for stream in self._get_chk_streams(source_vf, total_keys, - pb=child_pb): + for stream in self._get_chk_streams(source_vf, total_keys, pb=child_pb): for _, _ in target_vf._insert_record_stream( - stream, random_id=True, reuse_blocks=False): + stream, random_id=True, reuse_blocks=False + ): pass def _copy_text_texts(self): - source_vf, target_vf = self._build_vfs('text', True, True) + source_vf, target_vf = self._build_vfs("text", True, True) # XXX: We don't walk the chk map to determine referenced (file_id, # revision_id) keys. We don't do it yet because you really need # to filter out the ones that are present in the parents of the # rev just before the ones you are copying, otherwise the filter # is grabbing too many keys... text_keys = source_vf.keys() - self._copy_stream(source_vf, target_vf, text_keys, - 'texts', self._get_progress_stream, 4) + self._copy_stream( + source_vf, target_vf, text_keys, "texts", self._get_progress_stream, 4 + ) def _copy_signature_texts(self): - source_vf, target_vf = self._build_vfs('signature', False, False) + source_vf, target_vf = self._build_vfs("signature", False, False) signature_keys = source_vf.keys() signature_keys.intersection(self.revision_keys) - self._copy_stream(source_vf, target_vf, signature_keys, - 'signatures', self._get_progress_stream, 5) + self._copy_stream( + source_vf, + target_vf, + signature_keys, + "signatures", + self._get_progress_stream, + 5, + ) def _create_pack_from_packs(self): - self.pb.update('repacking', 0, 7) + self.pb.update("repacking", 0, 7) self.new_pack = self.open_pack() # Is this necessary for GC ? self.new_pack.set_write_cache_size(1024 * 1024) @@ -479,11 +528,12 @@ def _create_pack_from_packs(self): old_pack = self.packs[0] if old_pack.name == self.new_pack._hash.hexdigest(): # The single old pack was already optimally packed. 
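# Purely illustrative, not from the patch: the "already optimally packed"
# check above works because a pack's name is a digest of its content, so a
# repack of a single pack that reproduces identical bytes reproduces the same
# name.  hashlib.sha1 below is just a stand-in for whatever digest the pack
# writer actually uses.
from hashlib import sha1

old_pack_bytes = b"...existing pack contents..."
old_name = sha1(old_pack_bytes).hexdigest()
new_name = sha1(old_pack_bytes).hexdigest()   # repack produced identical bytes
if new_name == old_name:
    pass  # nothing changed: abort the new pack instead of allocating it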
- trace.mutter('single pack %s was already optimally packed', - old_pack.name) + trace.mutter( + "single pack %s was already optimally packed", old_pack.name + ) self.new_pack.abort() return None - self.pb.update('finishing repack', 6, 7) + self.pb.update("finishing repack", 6, 7) self.new_pack.finish() self._pack_collection.allocate(self.new_pack) return self.new_pack @@ -502,26 +552,33 @@ def __init__(self, *args, **kwargs): self._gather_text_refs = True def _copy_inventory_texts(self): - source_vf, target_vf = self._build_vfs('inventory', True, True) - self._copy_stream(source_vf, target_vf, self.revision_keys, - 'inventories', self._get_filtered_inv_stream, 2) + source_vf, target_vf = self._build_vfs("inventory", True, True) + self._copy_stream( + source_vf, + target_vf, + self.revision_keys, + "inventories", + self._get_filtered_inv_stream, + 2, + ) if source_vf.keys() != self.revision_keys: self._data_changed = True def _copy_text_texts(self): """Generate what texts we should have and then copy.""" - source_vf, target_vf = self._build_vfs('text', True, True) - trace.mutter('repacking %d texts', len(self._text_refs)) + source_vf, target_vf = self._build_vfs("text", True, True) + trace.mutter("repacking %d texts", len(self._text_refs)) self.pb.update("repacking texts", 4) # we have three major tasks here: # 1) generate the ideal index repo = self._pack_collection.repo # We want the one we just wrote, so base it on self.new_pack - revision_vf = self._build_vf('revision', True, False, for_write=True) + revision_vf = self._build_vf("revision", True, False, for_write=True) ancestor_keys = revision_vf.get_parent_map(revision_vf.keys()) # Strip keys back into revision_ids. - ancestors = {k[0]: tuple([p[0] for p in parents]) - for k, parents in ancestor_keys.items()} + ancestors = { + k[0]: tuple([p[0] for p in parents]) for k, parents in ancestor_keys.items() + } del ancestor_keys # TODO: _generate_text_key_index should be much cheaper to generate from # a chk repository, rather than the current implementation @@ -562,12 +619,14 @@ def _copy_text_texts(self): # 3) bulk copy the data, updating records than need it def _update_parents_for_texts(): - stream = source_vf.get_record_stream(self._text_refs, - 'groupcompress', False) + stream = source_vf.get_record_stream( + self._text_refs, "groupcompress", False + ) for record in stream: if record.key in new_parent_keys: record.parents = new_parent_keys[record.key] yield record + target_vf.insert_record_stream(_update_parents_for_texts()) def _use_pack(self, new_pack): @@ -592,12 +651,12 @@ def _exhaust_stream(self, source_vf, keys, message, vf_to_stream, pb_offset): This is useful to get the side-effects of generating a stream. """ - self.pb.update(f'scanning {message}', pb_offset) + self.pb.update(f"scanning {message}", pb_offset) with ui.ui_factory.nested_progress_bar() as child_pb: list(vf_to_stream(source_vf, keys, message, child_pb)) def _copy_inventory_texts(self): - source_vf, target_vf = self._build_vfs('inventory', True, True) + source_vf, target_vf = self._build_vfs("inventory", True, True) source_chk_vf, target_chk_vf = self._get_chk_vfs_for_copy() inventory_keys = source_vf.keys() # First, copy the existing CHKs on the assumption that most of them @@ -606,23 +665,33 @@ def _copy_inventory_texts(self): # few unused CHKs. # (Iterate but don't insert _get_filtered_inv_stream to populate the # variables needed by GCCHKPacker._copy_chk_texts.) 
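# --- Editorial sketch (not part of the patch): _exhaust_stream above iterates a record
# stream only so that _get_filtered_inv_stream can populate self._chk_id_roots and
# friends; the yielded records themselves are thrown away. The same idiom in isolation:

def filtered_stream(records, roots_seen):
    for record in records:
        roots_seen.add(record["root"])   # the side effect the caller wants
        yield record                     # the value the caller ignores

roots_seen = set()
list(filtered_stream([{"root": "r1"}, {"root": "r2"}], roots_seen))  # exhaust it
assert roots_seen == {"r1", "r2"}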
- self._exhaust_stream(source_vf, inventory_keys, 'inventories', - self._get_filtered_inv_stream, 2) + self._exhaust_stream( + source_vf, inventory_keys, "inventories", self._get_filtered_inv_stream, 2 + ) GCCHKPacker._copy_chk_texts(self) # Now copy and fix the inventories, and any regenerated CHKs. def chk_canonicalizing_inv_stream(source_vf, keys, message, pb=None): return self._get_filtered_canonicalizing_inv_stream( - source_vf, keys, message, pb, source_chk_vf, target_chk_vf) - self._copy_stream(source_vf, target_vf, inventory_keys, - 'inventories', chk_canonicalizing_inv_stream, 4) + source_vf, keys, message, pb, source_chk_vf, target_chk_vf + ) + + self._copy_stream( + source_vf, + target_vf, + inventory_keys, + "inventories", + chk_canonicalizing_inv_stream, + 4, + ) def _copy_chk_texts(self): # No-op; in this class this happens during _copy_inventory_texts. pass - def _get_filtered_canonicalizing_inv_stream(self, source_vf, keys, message, - pb=None, source_chk_vf=None, target_chk_vf=None): + def _get_filtered_canonicalizing_inv_stream( + self, source_vf, keys, message, pb=None, source_chk_vf=None, target_chk_vf=None + ): """Filter the texts of inventories, regenerating CHKs to make sure they are canonical. """ @@ -630,15 +699,16 @@ def _get_filtered_canonicalizing_inv_stream(self, source_vf, keys, message, target_chk_vf = versionedfile.NoDupeAddLinesDecorator(target_chk_vf) def _filtered_inv_stream(): - stream = source_vf.get_record_stream(keys, 'groupcompress', True) + stream = source_vf.get_record_stream(keys, "groupcompress", True) search_key_name = None for idx, record in enumerate(stream): # Inventories should always be with revisions; assume success. - lines = record.get_bytes_as('lines') + lines = record.get_bytes_as("lines") chk_inv = inventory.CHKInventory.deserialise( - source_chk_vf, lines, record.key) + source_chk_vf, lines, record.key + ) if pb is not None: - pb.update('inv', idx, total_keys) + pb.update("inv", idx, total_keys) chk_inv.id_to_entry._ensure_root() if search_key_name is None: # Find the name corresponding to the search_key_func @@ -647,31 +717,44 @@ def _filtered_inv_stream(): if func == chk_inv.id_to_entry._search_key_func: break canonical_inv = inventory.CHKInventory.from_inventory( - target_chk_vf, chk_inv, + target_chk_vf, + chk_inv, maximum_size=chk_inv.id_to_entry._root_node._maximum_size, - search_key_name=search_key_name) + search_key_name=search_key_name, + ) if chk_inv.id_to_entry.key() != canonical_inv.id_to_entry.key(): trace.mutter( - 'Non-canonical CHK map for id_to_entry of inv: {} ' - '(root is {}, should be {})'.format(chk_inv.revision_id, - chk_inv.id_to_entry.key()[ - 0], - canonical_inv.id_to_entry.key()[0])) + "Non-canonical CHK map for id_to_entry of inv: {} " + "(root is {}, should be {})".format( + chk_inv.revision_id, + chk_inv.id_to_entry.key()[0], + canonical_inv.id_to_entry.key()[0], + ) + ) self._data_changed = True p_id_map = chk_inv.parent_id_basename_to_file_id p_id_map._ensure_root() canon_p_id_map = canonical_inv.parent_id_basename_to_file_id if p_id_map.key() != canon_p_id_map.key(): trace.mutter( - 'Non-canonical CHK map for parent_id_to_basename of ' - 'inv: {} (root is {}, should be {})'.format(chk_inv.revision_id, p_id_map.key()[0], - canon_p_id_map.key()[0])) + "Non-canonical CHK map for parent_id_to_basename of " + "inv: {} (root is {}, should be {})".format( + chk_inv.revision_id, + p_id_map.key()[0], + canon_p_id_map.key()[0], + ) + ) self._data_changed = True yield versionedfile.ChunkedContentFactory( - 
record.key, record.parents, record.sha1, canonical_inv.to_lines(), - chunks_are_lines=True) + record.key, + record.parents, + record.sha1, + canonical_inv.to_lines(), + chunks_are_lines=True, + ) # We have finished processing all of the inventory records, we # don't need these sets anymore + return _filtered_inv_stream() def _use_pack(self, new_pack): @@ -680,7 +763,6 @@ def _use_pack(self, new_pack): class GCRepositoryPackCollection(RepositoryPackCollection): - pack_factory = GCPack resumed_pack_factory = ResumedGCPack normal_packer_class = GCCHKPacker @@ -705,14 +787,15 @@ def _check_new_inventories(self): no_fallback_inv_index = self.repo.inventories._index no_fallback_chk_bytes_index = self.repo.chk_bytes._index no_fallback_texts_index = self.repo.texts._index - inv_parent_map = no_fallback_inv_index.get_parent_map( - new_revisions_keys) + inv_parent_map = no_fallback_inv_index.get_parent_map(new_revisions_keys) # Are any inventories for corresponding to the new revisions missing? corresponding_invs = set(inv_parent_map) missing_corresponding = set(new_revisions_keys) missing_corresponding.difference_update(corresponding_invs) if missing_corresponding: - problems.append(f"inventories missing for revisions {sorted(missing_corresponding)}") + problems.append( + f"inventories missing for revisions {sorted(missing_corresponding)}" + ) return problems # Are any chk root entries missing for any inventories? This includes # any present parent inventories, which may be used when calculating @@ -722,49 +805,57 @@ def _check_new_inventories(self): all_inv_keys.update(parent_inv_keys) # Filter out ghost parents. all_inv_keys.intersection_update( - no_fallback_inv_index.get_parent_map(all_inv_keys)) - parent_invs_only_keys = all_inv_keys.symmetric_difference( - corresponding_invs) + no_fallback_inv_index.get_parent_map(all_inv_keys) + ) + parent_invs_only_keys = all_inv_keys.symmetric_difference(corresponding_invs) inv_ids = [key[-1] for key in all_inv_keys] parent_invs_only_ids = [key[-1] for key in parent_invs_only_keys] root_key_info = _build_interesting_key_sets( - self.repo, inv_ids, parent_invs_only_ids) + self.repo, inv_ids, parent_invs_only_ids + ) expected_chk_roots = root_key_info.all_keys() present_chk_roots = no_fallback_chk_bytes_index.get_parent_map( - expected_chk_roots) + expected_chk_roots + ) missing_chk_roots = expected_chk_roots.difference(present_chk_roots) if missing_chk_roots: problems.append( f"missing referenced chk root keys: {sorted(missing_chk_roots)}." "Run 'brz reconcile --canonicalize-chks' on the affected " - "repository.") + "repository." + ) # Don't bother checking any further. return problems # Find all interesting chk_bytes records, and make sure they are # present, as well as the text keys they reference. 
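# --- Editorial sketch (not part of the patch): the presence checks in
# _check_new_inventories reduce to "ask the index which of these keys it knows about,
# then take the set difference". A plain dict stands in for get_parent_map here.

def find_missing(index, wanted):
    present = {key for key in wanted if key in index}   # like get_parent_map(wanted)
    return set(wanted) - present

index = {("inv-1",): (), ("inv-2",): ()}
assert find_missing(index, {("inv-1",), ("inv-3",)}) == {("inv-3",)}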
chk_bytes_no_fallbacks = self.repo.chk_bytes.without_fallbacks() - chk_bytes_no_fallbacks._search_key_func = \ - self.repo.chk_bytes._search_key_func + chk_bytes_no_fallbacks._search_key_func = self.repo.chk_bytes._search_key_func chk_diff = chk_map.iter_interesting_nodes( - chk_bytes_no_fallbacks, root_key_info.interesting_root_keys, - root_key_info.uninteresting_root_keys) + chk_bytes_no_fallbacks, + root_key_info.interesting_root_keys, + root_key_info.uninteresting_root_keys, + ) text_keys = set() try: - for _record in _filter_text_keys(chk_diff, text_keys, - chk_map._bytes_to_text_key): + for _record in _filter_text_keys( + chk_diff, text_keys, chk_map._bytes_to_text_key + ): pass except errors.NoSuchRevision: # XXX: It would be nice if we could give a more precise error here. problems.append("missing chk node(s) for id_to_entry maps") chk_diff = chk_map.iter_interesting_nodes( - chk_bytes_no_fallbacks, root_key_info.interesting_pid_root_keys, - root_key_info.uninteresting_pid_root_keys) + chk_bytes_no_fallbacks, + root_key_info.interesting_pid_root_keys, + root_key_info.uninteresting_pid_root_keys, + ) try: for _interesting_rec, _interesting_map in chk_diff: pass except errors.NoSuchRevision: problems.append( - "missing chk node(s) for parent_id_basename_to_file_id maps") + "missing chk node(s) for parent_id_basename_to_file_id maps" + ) present_text_keys = no_fallback_texts_index.get_parent_map(text_keys) missing_text_keys = text_keys.difference(present_text_keys) if missing_text_keys: @@ -775,54 +866,89 @@ def _check_new_inventories(self): class CHKInventoryRepository(PackRepository): """subclass of PackRepository that uses CHK based inventories.""" - def __init__(self, _format, a_controldir, control_files, _commit_builder_class, - _revision_serializer, _inventory_serializer): + def __init__( + self, + _format, + a_controldir, + control_files, + _commit_builder_class, + _revision_serializer, + _inventory_serializer, + ): """Overridden to change pack collection class.""" - super().__init__(_format, a_controldir, control_files, _commit_builder_class, _revision_serializer, _inventory_serializer) - index_transport = self._transport.clone('indices') - self._pack_collection = GCRepositoryPackCollection(self, - self._transport, index_transport, - self._transport.clone( - 'upload'), - self._transport.clone( - 'packs'), - _format.index_builder_class, - _format.index_class, - use_chk_index=self._format.supports_chks, - ) + super().__init__( + _format, + a_controldir, + control_files, + _commit_builder_class, + _revision_serializer, + _inventory_serializer, + ) + index_transport = self._transport.clone("indices") + self._pack_collection = GCRepositoryPackCollection( + self, + self._transport, + index_transport, + self._transport.clone("upload"), + self._transport.clone("packs"), + _format.index_builder_class, + _format.index_class, + use_chk_index=self._format.supports_chks, + ) self.inventories = GroupCompressVersionedFiles( - _GCGraphIndex(self._pack_collection.inventory_index.combined_index, - add_callback=self._pack_collection.inventory_index.add_callback, - parents=True, is_locked=self.is_locked, - inconsistency_fatal=False), - access=self._pack_collection.inventory_index.data_access) + _GCGraphIndex( + self._pack_collection.inventory_index.combined_index, + add_callback=self._pack_collection.inventory_index.add_callback, + parents=True, + is_locked=self.is_locked, + inconsistency_fatal=False, + ), + access=self._pack_collection.inventory_index.data_access, + ) self.revisions = 
GroupCompressVersionedFiles( - _GCGraphIndex(self._pack_collection.revision_index.combined_index, - add_callback=self._pack_collection.revision_index.add_callback, - parents=True, is_locked=self.is_locked, - track_external_parent_refs=True, track_new_keys=True), + _GCGraphIndex( + self._pack_collection.revision_index.combined_index, + add_callback=self._pack_collection.revision_index.add_callback, + parents=True, + is_locked=self.is_locked, + track_external_parent_refs=True, + track_new_keys=True, + ), access=self._pack_collection.revision_index.data_access, - delta=False) + delta=False, + ) self.signatures = GroupCompressVersionedFiles( - _GCGraphIndex(self._pack_collection.signature_index.combined_index, - add_callback=self._pack_collection.signature_index.add_callback, - parents=False, is_locked=self.is_locked, - inconsistency_fatal=False), + _GCGraphIndex( + self._pack_collection.signature_index.combined_index, + add_callback=self._pack_collection.signature_index.add_callback, + parents=False, + is_locked=self.is_locked, + inconsistency_fatal=False, + ), access=self._pack_collection.signature_index.data_access, - delta=False) + delta=False, + ) self.texts = GroupCompressVersionedFiles( - _GCGraphIndex(self._pack_collection.text_index.combined_index, - add_callback=self._pack_collection.text_index.add_callback, - parents=True, is_locked=self.is_locked, - inconsistency_fatal=False), - access=self._pack_collection.text_index.data_access) + _GCGraphIndex( + self._pack_collection.text_index.combined_index, + add_callback=self._pack_collection.text_index.add_callback, + parents=True, + is_locked=self.is_locked, + inconsistency_fatal=False, + ), + access=self._pack_collection.text_index.data_access, + ) # No parents, individual CHK pages don't have specific ancestry self.chk_bytes = GroupCompressVersionedFiles( - _GCGraphIndex(self._pack_collection.chk_index.combined_index, - add_callback=self._pack_collection.chk_index.add_callback, - parents=False, is_locked=self.is_locked, - inconsistency_fatal=False), - access=self._pack_collection.chk_index.data_access) + _GCGraphIndex( + self._pack_collection.chk_index.combined_index, + add_callback=self._pack_collection.chk_index.add_callback, + parents=False, + is_locked=self.is_locked, + inconsistency_fatal=False, + ), + access=self._pack_collection.chk_index.data_access, + ) search_key_name = self._format._inventory_serializer.search_key_name search_key_func = chk_map.search_key_registry.get(search_key_name) self.chk_bytes._search_key_func = search_key_func @@ -847,12 +973,16 @@ def _add_inventory_checked(self, revision_id, inv, parents): """ # make inventory serializer = self._format._inventory_serializer - result = inventory.CHKInventory.from_inventory(self.chk_bytes, inv, - maximum_size=serializer.maximum_size, - search_key_name=serializer.search_key_name) + result = inventory.CHKInventory.from_inventory( + self.chk_bytes, + inv, + maximum_size=serializer.maximum_size, + search_key_name=serializer.search_key_name, + ) inv_lines = result.to_lines() - return self._inventory_add_lines(revision_id, parents, - inv_lines, check_content=False) + return self._inventory_add_lines( + revision_id, parents, inv_lines, check_content=False + ) def _create_inv_from_null(self, delta, revision_id): """This will mutate new_inv directly. 
@@ -868,18 +998,23 @@ def _create_inv_from_null(self, delta, revision_id): parent_id_basename_dict = {} for old_path, new_path, file_id, entry in delta: if old_path is not None: - raise ValueError('Invalid delta, somebody tried to delete {!r}' - ' from the NULL_REVISION'.format((old_path, file_id))) + raise ValueError( + f"Invalid delta, somebody tried to delete {(old_path, file_id)!r}" + " from the NULL_REVISION" + ) if new_path is None: - raise ValueError('Invalid delta, delta from NULL_REVISION has' - f' no new_path {file_id!r}') - if new_path == '': + raise ValueError( + "Invalid delta, delta from NULL_REVISION has" + f" no new_path {file_id!r}" + ) + if new_path == "": new_inv.root_id = file_id - parent_id_basename_key = StaticTuple(b'', b'').intern() + parent_id_basename_key = StaticTuple(b"", b"").intern() else: - utf8_entry_name = entry.name.encode('utf-8') - parent_id_basename_key = StaticTuple(entry.parent_id, - utf8_entry_name).intern() + utf8_entry_name = entry.name.encode("utf-8") + parent_id_basename_key = StaticTuple( + entry.parent_id, utf8_entry_name + ).intern() new_value = entry_to_bytes(entry) # Populate Caches? # new_inv._path_to_fileid_cache[new_path] = file_id @@ -887,12 +1022,23 @@ def _create_inv_from_null(self, delta, revision_id): id_to_entry_dict[key] = new_value parent_id_basename_dict[parent_id_basename_key] = file_id - new_inv._populate_from_dicts(self.chk_bytes, id_to_entry_dict, - parent_id_basename_dict, maximum_size=serializer.maximum_size) + new_inv._populate_from_dicts( + self.chk_bytes, + id_to_entry_dict, + parent_id_basename_dict, + maximum_size=serializer.maximum_size, + ) return new_inv - def add_inventory_by_delta(self, basis_revision_id, delta, new_revision_id, - parents, basis_inv=None, propagate_caches=False): + def add_inventory_by_delta( + self, + basis_revision_id, + delta, + new_revision_id, + parents, + basis_inv=None, + propagate_caches=False, + ): """Add a new inventory expressed as a delta against another revision. 
:param basis_revision_id: The inventory id the delta was created @@ -926,36 +1072,38 @@ def add_inventory_by_delta(self, basis_revision_id, delta, new_revision_id, if new_inv.root_id is None: raise errors.RootMissing() inv_lines = new_inv.to_lines() - return self._inventory_add_lines(new_revision_id, parents, - inv_lines, check_content=False), new_inv + return self._inventory_add_lines( + new_revision_id, parents, inv_lines, check_content=False + ), new_inv else: basis_tree = self.revision_tree(basis_revision_id) basis_tree.lock_read() basis_inv = basis_tree.root_inventory try: - result = basis_inv.create_by_apply_delta(delta, new_revision_id, - propagate_caches=propagate_caches) + result = basis_inv.create_by_apply_delta( + delta, new_revision_id, propagate_caches=propagate_caches + ) inv_lines = result.to_lines() - return self._inventory_add_lines(new_revision_id, parents, - inv_lines, check_content=False), result + return self._inventory_add_lines( + new_revision_id, parents, inv_lines, check_content=False + ), result finally: if basis_tree is not None: basis_tree.unlock() def _deserialise_inventory(self, revision_id, lines): - return inventory.CHKInventory.deserialise(self.chk_bytes, lines, - (revision_id,)) + return inventory.CHKInventory.deserialise(self.chk_bytes, lines, (revision_id,)) def _iter_inventories(self, revision_ids, ordering): """Iterate over many inventory objects.""" if ordering is None: - ordering = 'unordered' + ordering = "unordered" keys = [(revision_id,) for revision_id in revision_ids] stream = self.inventories.get_record_stream(keys, ordering, True) texts = {} for record in stream: - if record.storage_kind != 'absent': - texts[record.key] = record.get_bytes_as('lines') + if record.storage_kind != "absent": + texts[record.key] = record.get_bytes_as("lines") else: texts[record.key] = None for key in keys: @@ -963,8 +1111,10 @@ def _iter_inventories(self, revision_ids, ordering): if lines is None: yield (None, key[-1]) else: - yield (inventory.CHKInventory.deserialise( - self.chk_bytes, lines, key), key[-1]) + yield ( + inventory.CHKInventory.deserialise(self.chk_bytes, lines, key), + key[-1], + ) def _get_inventory_xml(self, revision_id): """Get serialized inventory as a string.""" @@ -974,7 +1124,8 @@ def _get_inventory_xml(self, revision_id): # serializer directly; this also isn't ideal, but there isn't an xml # iteration interface offered at all for repositories. return self._inventory_serializer.write_inventory_to_lines( - self.get_inventory(revision_id)) + self.get_inventory(revision_id) + ) def _find_present_inventory_keys(self, revision_keys): parent_map = self.inventories.get_parent_map(revision_keys) @@ -1001,26 +1152,26 @@ def fileids_altered_by_revision_ids(self, revision_ids, _inv_weave=None): # code paths to allow missing inventories to be tolerated. 
# However, we only want to tolerate missing parent # inventories, not missing inventories for revision_ids - present_parent_inv_keys = self._find_present_inventory_keys( - parent_keys) + present_parent_inv_keys = self._find_present_inventory_keys(parent_keys) present_parent_inv_ids = {k[-1] for k in present_parent_inv_keys} inventories_to_read = set(revision_ids) inventories_to_read.update(present_parent_inv_ids) root_key_info = _build_interesting_key_sets( - self, inventories_to_read, present_parent_inv_ids) + self, inventories_to_read, present_parent_inv_ids + ) interesting_root_keys = root_key_info.interesting_root_keys uninteresting_root_keys = root_key_info.uninteresting_root_keys chk_bytes = self.chk_bytes - for _record, items in chk_map.iter_interesting_nodes(chk_bytes, - interesting_root_keys, uninteresting_root_keys, - pb=pb): + for _record, items in chk_map.iter_interesting_nodes( + chk_bytes, interesting_root_keys, uninteresting_root_keys, pb=pb + ): for _name, bytes in items: (name_utf8, file_id, revision_id) = bytes_to_info(bytes) # TODO: consider interning file_id, revision_id here, or # pushing that intern() into bytes_to_info() # TODO: rich_root should always be True here, for all # repositories that support chk_bytes - if not rich_root and name_utf8 == '': + if not rich_root and name_utf8 == "": continue try: file_id_revisions[file_id].add(revision_id) @@ -1062,9 +1213,9 @@ def reconcile_canonicalize_chks(self): form. """ from .reconcile import PackReconciler + with self.lock_write(): - reconciler = PackReconciler( - self, thorough=True, canonicalize_chks=True) + reconciler = PackReconciler(self, thorough=True, canonicalize_chks=True) return reconciler.reconcile() def _reconcile_pack(self, collection, packs, extension, revs, pb): @@ -1077,8 +1228,10 @@ def _canonicalize_chks_pack(self, collection, packs, extension, revs, pb): def _get_source(self, to_format): """Return a source for streaming from this repository.""" - if (self._format._inventory_serializer == to_format._inventory_serializer and - self._format._revision_serializer == to_format._revision_serializer): + if ( + self._format._inventory_serializer == to_format._inventory_serializer + and self._format._revision_serializer == to_format._revision_serializer + ): # We must be exactly the same format, otherwise stuff like the chk # page layout might be different. 
# Actually, this test is just slightly looser than exact so that @@ -1104,18 +1257,17 @@ def _find_inconsistent_revision_parents(self, revisions_iterator=None): if revision is None: pass parent_map = vf.get_parent_map([(revid,)]) - parents_according_to_index = tuple(parent[-1] for parent in - parent_map[(revid,)]) + parents_according_to_index = tuple( + parent[-1] for parent in parent_map[(revid,)] + ) parents_according_to_revision = tuple(revision.parent_ids) if parents_according_to_index != parents_according_to_revision: - yield (revid, parents_according_to_index, - parents_according_to_revision) + yield (revid, parents_according_to_index, parents_according_to_revision) def _check_for_inconsistent_revision_parents(self): inconsistencies = list(self._find_inconsistent_revision_parents()) if inconsistencies: - raise errors.BzrCheckError( - "Revision index has inconsistent parents.") + raise errors.BzrCheckError("Revision index has inconsistent parents.") class GroupCHKStreamSource(StreamSource): @@ -1126,7 +1278,7 @@ def __init__(self, from_repository, to_format): super().__init__(from_repository, to_format) self._revision_keys = None self._text_keys = None - self._text_fetch_order = 'groupcompress' + self._text_fetch_order = "groupcompress" self._chk_id_roots = None self._chk_p_id_roots = None @@ -1143,24 +1295,22 @@ def _filtered_inv_stream(): id_roots_set = set() p_id_roots_set = set() source_vf = self.from_repository.inventories - stream = source_vf.get_record_stream(inventory_keys, - 'groupcompress', True) + stream = source_vf.get_record_stream(inventory_keys, "groupcompress", True) for record in stream: - if record.storage_kind == 'absent': + if record.storage_kind == "absent": if allow_absent: continue else: raise errors.NoSuchRevision(self, record.key) - lines = record.get_bytes_as('lines') - chk_inv = inventory.CHKInventory.deserialise(None, lines, - record.key) + lines = record.get_bytes_as("lines") + chk_inv = inventory.CHKInventory.deserialise(None, lines, record.key) key = chk_inv.id_to_entry.key() if key not in id_roots_set: self._chk_id_roots.append(key) id_roots_set.add(key) p_id_map = chk_inv.parent_id_basename_to_file_id if p_id_map is None: - raise AssertionError('Parent id -> file_id map not set') + raise AssertionError("Parent id -> file_id map not set") key = p_id_map.key() if key not in p_id_roots_set: p_id_roots_set.add(key) @@ -1170,7 +1320,8 @@ def _filtered_inv_stream(): # don't need these sets anymore id_roots_set.clear() p_id_roots_set.clear() - return ('inventories', _filtered_inv_stream()) + + return ("inventories", _filtered_inv_stream()) def _get_filtered_chk_streams(self, excluded_revision_keys): self._text_keys = set() @@ -1184,42 +1335,48 @@ def _get_filtered_chk_streams(self, excluded_revision_keys): # TODO: Update Repository.iter_inventories() to add # ignore_missing=True present_keys = self.from_repository._find_present_inventory_keys( - excluded_revision_keys) + excluded_revision_keys + ) present_ids = [k[-1] for k in present_keys] uninteresting_root_keys = set() uninteresting_pid_root_keys = set() for inv in self.from_repository.iter_inventories(present_ids): uninteresting_root_keys.add(inv.id_to_entry.key()) - uninteresting_pid_root_keys.add( - inv.parent_id_basename_to_file_id.key()) + uninteresting_pid_root_keys.add(inv.parent_id_basename_to_file_id.key()) chk_bytes = self.from_repository.chk_bytes def _filter_id_to_entry(): - interesting_nodes = chk_map.iter_interesting_nodes(chk_bytes, - self._chk_id_roots, uninteresting_root_keys) - for record 
in _filter_text_keys(interesting_nodes, self._text_keys, - chk_map._bytes_to_text_key): + interesting_nodes = chk_map.iter_interesting_nodes( + chk_bytes, self._chk_id_roots, uninteresting_root_keys + ) + for record in _filter_text_keys( + interesting_nodes, self._text_keys, chk_map._bytes_to_text_key + ): if record is not None: yield record # Consumed self._chk_id_roots = None - yield 'chk_bytes', _filter_id_to_entry() + + yield "chk_bytes", _filter_id_to_entry() def _get_parent_id_basename_to_file_id_pages(): - for record, _items in chk_map.iter_interesting_nodes(chk_bytes, - self._chk_p_id_roots, uninteresting_pid_root_keys): + for record, _items in chk_map.iter_interesting_nodes( + chk_bytes, self._chk_p_id_roots, uninteresting_pid_root_keys + ): if record is not None: yield record # Consumed self._chk_p_id_roots = None - yield 'chk_bytes', _get_parent_id_basename_to_file_id_pages() + + yield "chk_bytes", _get_parent_id_basename_to_file_id_pages() def _get_text_stream(self): # Note: We know we don't have to handle adding root keys, because both # the source and target are the identical network name. text_stream = self.from_repository.texts.get_record_stream( - self._text_keys, self._text_fetch_order, False) - return ('texts', text_stream) + self._text_keys, self._text_fetch_order, False + ) + return ("texts", text_stream) def get_stream(self, search): def wrap_and_count(pb, rc, stream): @@ -1228,7 +1385,7 @@ def wrap_and_count(pb, rc, stream): for record in stream: if count == rc.STEP: rc.increment(count) - pb.update('Estimate', rc.current, rc.max) + pb.update("Estimate", rc.current, rc.max) count = 0 count += 1 yield record @@ -1238,15 +1395,13 @@ def wrap_and_count(pb, rc, stream): rc = self._record_counter self._record_counter.setup(len(revision_ids)) for stream_info in self._fetch_revision_texts(revision_ids): - yield (stream_info[0], - wrap_and_count(pb, rc, stream_info[1])) + yield (stream_info[0], wrap_and_count(pb, rc, stream_info[1])) self._revision_keys = [(rev_id,) for rev_id in revision_ids] # TODO: The keys to exclude might be part of the search recipe # For now, exclude all parents that are at the edge of ancestry, for # which we have inventories from_repo = self.from_repository - parent_keys = from_repo._find_parent_keys_of_revisions( - self._revision_keys) + parent_keys = from_repo._find_parent_keys_of_revisions(self._revision_keys) self.from_repository.revisions.clear_cache() self.from_repository.signatures.clear_cache() # Clear the repo's get_parent_map cache too. @@ -1261,7 +1416,7 @@ def wrap_and_count(pb, rc, stream): s = self._get_text_stream() yield (s[0], wrap_and_count(pb, rc, s[1])) self.from_repository.texts.clear_cache() - pb.update('Done', rc.max, rc.max) + pb.update("Done", rc.max, rc.max) def get_stream_for_missing_keys(self, missing_keys): # missing keys can only occur when we are byte copying and not @@ -1269,19 +1424,22 @@ def get_stream_for_missing_keys(self, missing_keys): # unreconstructable deltas ever). 
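# --- Editorial sketch (not part of the patch): wrap_and_count above is a pass-through
# generator that ticks the progress bar once every STEP records while the real stream
# flows through untouched. The same shape with a plain callback instead of a UI object:

def wrap_and_count_sketch(stream, step, on_progress):
    count = 0
    seen = 0
    for record in stream:
        if count == step:
            seen += count
            on_progress(seen)
            count = 0
        count += 1
        yield record

ticks = []
out = list(wrap_and_count_sketch(range(7), 3, ticks.append))
assert out == list(range(7)) and ticks == [3, 6]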
missing_inventory_keys = set() for key in missing_keys: - if key[0] != 'inventories': - raise AssertionError('The only missing keys we should' - f' be filling in are inventory keys, not {key[0]}') + if key[0] != "inventories": + raise AssertionError( + "The only missing keys we should" + f" be filling in are inventory keys, not {key[0]}" + ) missing_inventory_keys.add(key[1:]) if self._chk_id_roots or self._chk_p_id_roots: - raise AssertionError('Cannot call get_stream_for_missing_keys' - ' until all of get_stream() has been consumed.') + raise AssertionError( + "Cannot call get_stream_for_missing_keys" + " until all of get_stream() has been consumed." + ) # Yield the inventory stream, so we can find the chk stream # Some of the missing_keys will be missing because they are ghosts. # As such, we can ignore them. The Sink is required to verify there are # no unavailable texts when the ghost inventories are not filled in. - yield self._get_inventory_stream(missing_inventory_keys, - allow_absent=True) + yield self._get_inventory_stream(missing_inventory_keys, allow_absent=True) # We use the empty set for excluded_revision_keys, to make it clear # that we want to transmit all referenced chk pages. yield from self._get_filtered_chk_streams(set()) @@ -1298,8 +1456,7 @@ def all_interesting(self): return self.interesting_root_keys.union(self.interesting_pid_root_keys) def all_uninteresting(self): - return self.uninteresting_root_keys.union( - self.uninteresting_pid_root_keys) + return self.uninteresting_root_keys.union(self.uninteresting_pid_root_keys) def all_keys(self): return self.all_interesting().union(self.all_uninteresting()) @@ -1307,7 +1464,7 @@ def all_keys(self): def _build_interesting_key_sets(repo, inventory_ids, parent_only_inv_ids): result = _InterestingKeyInfo() - for inv in repo.iter_inventories(inventory_ids, 'unordered'): + for inv in repo.iter_inventories(inventory_ids, "unordered"): root_key = inv.id_to_entry.key() pid_root_key = inv.parent_id_basename_to_file_id.key() if inv.revision_id in parent_only_inv_ids: @@ -1349,7 +1506,7 @@ class RepositoryFormat2a(RepositoryFormatPack): # operations, because the source can be smart about extracting # multiple in-a-row (and sharing strings). Topological is better # for remote, because we access less data. - _fetch_order = 'unordered' + _fetch_order = "unordered" # essentially ignored by the groupcompress code. 
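# --- Editorial sketch (not part of the patch): _build_interesting_key_sets above sorts
# CHK root keys into "interesting" (inventories being transmitted) and "uninteresting"
# (parent-only inventories the receiver is assumed to have). A minimal model of that
# bookkeeping, using plain ids and tuples in place of real inventories:

def partition_roots(root_keys_by_inv, parent_only_ids):
    interesting, uninteresting = set(), set()
    for inv_id, root_key in root_keys_by_inv.items():
        bucket = uninteresting if inv_id in parent_only_ids else interesting
        bucket.add(root_key)
    return interesting, uninteresting

roots = {"r1": ("sha1:aaa",), "r2": ("sha1:bbb",), "r3": ("sha1:ccc",)}
interesting, uninteresting = partition_roots(roots, parent_only_ids={"r3"})
assert interesting == {("sha1:aaa",), ("sha1:bbb",)}
assert uninteresting == {("sha1:ccc",)}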
_fetch_uses_deltas = False fast_deltas = True @@ -1357,44 +1514,46 @@ class RepositoryFormat2a(RepositoryFormatPack): supports_tree_reference = True def _get_matching_bzrdir(self): - return controldir.format_registry.make_controldir('2a') + return controldir.format_registry.make_controldir("2a") def _ignore_setting_bzrdir(self, format): pass - _matchingcontroldir = property( - _get_matching_bzrdir, _ignore_setting_bzrdir) + _matchingcontroldir = property(_get_matching_bzrdir, _ignore_setting_bzrdir) @classmethod def get_format_string(cls): - return b'Bazaar repository format 2a (needs bzr 1.16 or later)\n' + return b"Bazaar repository format 2a (needs bzr 1.16 or later)\n" def get_format_description(self): """See RepositoryFormat.get_format_description().""" - return ("Repository format 2a - rich roots, group compression" - " and chk inventories") + return ( + "Repository format 2a - rich roots, group compression" + " and chk inventories" + ) class RepositoryFormat2aSubtree(RepositoryFormat2a): """A 2a repository format that supports nested trees.""" def _get_matching_bzrdir(self): - return controldir.format_registry.make_controldir('development-subtree') + return controldir.format_registry.make_controldir("development-subtree") def _ignore_setting_bzrdir(self, format): pass - _matchingcontroldir = property( - _get_matching_bzrdir, _ignore_setting_bzrdir) + _matchingcontroldir = property(_get_matching_bzrdir, _ignore_setting_bzrdir) @classmethod def get_format_string(cls): - return b'Bazaar development format 8\n' + return b"Bazaar development format 8\n" def get_format_description(self): """See RepositoryFormat.get_format_description().""" - return ("Development repository format 8 - nested trees, " - "group compression and chk inventories") + return ( + "Development repository format 8 - nested trees, " + "group compression and chk inventories" + ) experimental = True supports_tree_reference = True diff --git a/breezy/bzr/index.py b/breezy/bzr/index.py index 2544016cbc..4bd52e7733 100644 --- a/breezy/bzr/index.py +++ b/breezy/bzr/index.py @@ -17,12 +17,12 @@ """Indexing facilities.""" __all__ = [ - 'CombinedGraphIndex', - 'GraphIndex', - 'GraphIndexBuilder', - 'GraphIndexPrefixAdapter', - 'InMemoryGraphIndex', - ] + "CombinedGraphIndex", + "GraphIndex", + "GraphIndexBuilder", + "GraphIndexPrefixAdapter", + "InMemoryGraphIndex", +] import re from bisect import bisect_right @@ -41,7 +41,6 @@ class BadIndexFormatSignature(errors.BzrError): - _fmt = "%(value)s is not an index of type %(_type)s." def __init__(self, value, _type): @@ -51,7 +50,6 @@ def __init__(self, value, _type): class BadIndexData(errors.BzrError): - _fmt = "Error in data for index %(value)s." def __init__(self, value): @@ -60,7 +58,6 @@ def __init__(self, value): class BadIndexDuplicateKey(errors.BzrError): - _fmt = "The key '%(key)s' is already in index '%(index)s'." def __init__(self, key, index): @@ -70,7 +67,6 @@ def __init__(self, key, index): class BadIndexKey(errors.BzrError): - _fmt = "The key '%(key)s' is not a valid key." def __init__(self, key): @@ -79,7 +75,6 @@ def __init__(self, key): class BadIndexOptions(errors.BzrError): - _fmt = "Could not parse options for index %(value)s." def __init__(self, value): @@ -88,7 +83,6 @@ def __init__(self, value): class BadIndexValue(errors.BzrError): - _fmt = "The value '%(value)s' is not a valid value." 
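# --- Editorial sketch (not part of the patch): the index error classes above follow the
# usual breezy pattern of a class-level _fmt template interpolated from instance
# attributes when the error is rendered. ToyIndexError below is a hypothetical,
# self-contained version of that pattern, not an actual breezy class.

class ToyIndexError(Exception):
    _fmt = "The key %(key)r is not a valid key."

    def __init__(self, **attrs):
        super().__init__()
        self.__dict__.update(attrs)

    def __str__(self):
        return self._fmt % self.__dict__

assert str(ToyIndexError(key=("bad",))) == "The key ('bad',) is not a valid key."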
def __init__(self, value): @@ -96,8 +90,8 @@ def __init__(self, value): self.value = value -_whitespace_re = re.compile(b'[\t\n\x0b\x0c\r\x00 ]') -_newline_null_re = re.compile(b'[\n\0]') +_whitespace_re = re.compile(b"[\t\n\x0b\x0c\r\x00 ]") +_newline_null_re = re.compile(b"[\n\0]") def _has_key_from_parent_map(self, key): @@ -106,7 +100,7 @@ def _has_key_from_parent_map(self, key): If it's possible to check for multiple keys at once through calling get_parent_map that should be faster. """ - return (key in self.get_parent_map([key])) + return key in self.get_parent_map([key]) def _missing_keys_from_parent_map(self, keys): @@ -156,7 +150,11 @@ def _check_key(self, key): if self._key_length != len(key): raise BadIndexKey(key) for element in key: - if not element or not isinstance(element, bytes) or _whitespace_re.search(element) is not None: + if ( + not element + or not isinstance(element, bytes) + or _whitespace_re.search(element) is not None + ): raise BadIndexKey(key) def _external_references(self): @@ -252,8 +250,7 @@ def _check_key_ref_value(self, key, references, value): if reference not in self._nodes: self._check_key(reference) absent_references.append(reference) - reference_list = as_st([as_st(ref).intern() - for ref in reference_list]) + reference_list = as_st([as_st(ref).intern() for ref in reference_list]) node_refs.append(reference_list) return as_st(node_refs), absent_references @@ -268,17 +265,18 @@ def add_node(self, key, value, references=()): :param value: The value to associate with the key. It may be any bytes as long as it does not contain \0 or \n. """ - (node_refs, - absent_references) = self._check_key_ref_value(key, references, value) - if key in self._nodes and self._nodes[key][0] != b'a': + (node_refs, absent_references) = self._check_key_ref_value( + key, references, value + ) + if key in self._nodes and self._nodes[key][0] != b"a": raise BadIndexDuplicateKey(key, self) for reference in absent_references: # There may be duplicates, but I don't think it is worth worrying # about - self._nodes[reference] = (b'a', (), b'') + self._nodes[reference] = (b"a", (), b"") self._absent_keys.update(absent_references) self._absent_keys.discard(key) - self._nodes[key] = (b'', node_refs, value) + self._nodes[key] = (b"", node_refs, value) if self._nodes_by_key is not None and self._key_length > 1: self._update_nodes_by_key(key, value, node_refs) @@ -296,10 +294,10 @@ def finish(self): should be written to disk. """ lines = [_SIGNATURE] - lines.append(b'%s%d\n' % (_OPTION_NODE_REFS, self.reference_lists)) - lines.append(b'%s%d\n' % (_OPTION_KEY_ELEMENTS, self._key_length)) + lines.append(b"%s%d\n" % (_OPTION_NODE_REFS, self.reference_lists)) + lines.append(b"%s%d\n" % (_OPTION_KEY_ELEMENTS, self._key_length)) key_count = len(self._nodes) - len(self._absent_keys) - lines.append(b'%s%d\n' % (_OPTION_LEN, key_count)) + lines.append(b"%s%d\n" % (_OPTION_LEN, key_count)) prefix_length = sum(len(x) for x in lines) # references are byte offsets. To avoid having to do nasty # polynomial work to resolve offsets (references to later in the @@ -353,7 +351,7 @@ def finish(self): # how many digits are needed to represent the total byte count? 
digits = 1 possible_total_bytes = non_ref_bytes + total_references * digits - while 10 ** digits < possible_total_bytes: + while 10**digits < possible_total_bytes: digits += 1 possible_total_bytes = non_ref_bytes + total_references * digits expected_bytes = possible_total_bytes + 1 # terminating newline @@ -362,24 +360,27 @@ def finish(self): for key, non_ref_bytes, total_references in key_offset_info: key_addresses[key] = non_ref_bytes + total_references * digits # serialise - format_string = b'%%0%dd' % digits + format_string = b"%%0%dd" % digits for key, (absent, references, value) in nodes: flattened_references = [] for ref_list in references: ref_addresses = [] for reference in ref_list: - ref_addresses.append(format_string % - key_addresses[reference]) - flattened_references.append(b'\r'.join(ref_addresses)) - string_key = b'\x00'.join(key) - lines.append(b"%s\x00%s\x00%s\x00%s\n" % (string_key, absent, - b'\t'.join(flattened_references), value)) - lines.append(b'\n') - result = BytesIO(b''.join(lines)) + ref_addresses.append(format_string % key_addresses[reference]) + flattened_references.append(b"\r".join(ref_addresses)) + string_key = b"\x00".join(key) + lines.append( + b"%s\x00%s\x00%s\x00%s\n" + % (string_key, absent, b"\t".join(flattened_references), value) + ) + lines.append(b"\n") + result = BytesIO(b"".join(lines)) if expected_bytes and len(result.getvalue()) != expected_bytes: - raise errors.BzrError('Failed index creation. Internal error:' - ' mismatched output length and expected length: %d %d' % - (len(result.getvalue()), expected_bytes)) + raise errors.BzrError( + "Failed index creation. Internal error:" + " mismatched output length and expected length: %d %d" + % (len(result.getvalue()), expected_bytes) + ) return result def set_optimize(self, for_size=None, combine_backing_indices=None): @@ -410,8 +411,7 @@ def find_ancestry(self, keys, ref_list_num): for _, key, _value, ref_lists in self.iter_entries(pending): parent_keys = ref_lists[ref_list_num] parent_map[key] = parent_keys - next_pending.update([p for p in parent_keys if p not in - parent_map]) + next_pending.update([p for p in parent_keys if p not in parent_map]) missing_keys.update(pending.difference(parent_map)) pending = next_pending return parent_map, missing_keys @@ -476,18 +476,20 @@ def __init__(self, transport, name, size, unlimited_cache=False, offset=0): def __eq__(self, other): """Equal when self and other were created with the same parameters.""" return ( - isinstance(self, type(other)) and - self._transport == other._transport and - self._name == other._name and - self._size == other._size) + isinstance(self, type(other)) + and self._transport == other._transport + and self._name == other._name + and self._size == other._size + ) def __ne__(self, other): return not self.__eq__(other) def __lt__(self, other): # We don't really care about the order, just that there is an order. 
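# --- Editorial sketch (not part of the patch): the fixed-width reference offsets written
# by GraphIndexBuilder.finish() create a circular dependency -- the byte addresses depend
# on how wide each offset field is, and the field width depends on the final byte count --
# which the loop above resolves by growing `digits` until 10**digits covers the total.
# The same computation in isolation:

def offset_digits(non_ref_bytes, total_references):
    digits = 1
    total = non_ref_bytes + total_references * digits
    while 10**digits < total:
        digits += 1
        total = non_ref_bytes + total_references * digits
    return digits

digits = offset_digits(non_ref_bytes=80, total_references=5)
assert digits == 2
fmt = b"%%0%dd" % digits      # same trick as finish(): a zero-padded decimal format
assert fmt % 7 == b"07"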
- if (not isinstance(other, GraphIndex) and - not isinstance(other, InMemoryGraphIndex)): + if not isinstance(other, GraphIndex) and not isinstance( + other, InMemoryGraphIndex + ): raise TypeError(other) return hash(self) < hash(other) @@ -505,15 +507,14 @@ def _buffer_all(self, stream=None): if self._nodes is not None: # We already did this return - if debug.debug_flag_enabled('index'): - trace.mutter('Reading entire index %s', - self._transport.abspath(self._name)) + if debug.debug_flag_enabled("index"): + trace.mutter("Reading entire index %s", self._transport.abspath(self._name)) if stream is None: stream = self._transport.get(self._name) if self._base_offset != 0 or not hasattr(stream, "readline"): # This is wasteful, but it is better than dealing with # adjusting all the offsets, etc. - stream = BytesIO(stream.read()[self._base_offset:]) + stream = BytesIO(stream.read()[self._base_offset :]) try: self._read_prefix(stream) self._expected_elements = 3 + self._key_length @@ -524,7 +525,7 @@ def _buffer_all(self, stream=None): self._nodes_by_key = None trailers = 0 pos = stream.tell() - lines = stream.read().split(b'\n') + lines = stream.read().split(b"\n") finally: stream.close() del lines[-1] @@ -555,8 +556,10 @@ def external_references(self, ref_list_num): """Return references that are not present in this index.""" self._buffer_all() if ref_list_num + 1 > self.node_ref_lists: - raise ValueError('No ref list %d, index has %d ref lists' - % (ref_list_num, self.node_ref_lists)) + raise ValueError( + "No ref list %d, index has %d ref lists" + % (ref_list_num, self.node_ref_lists) + ) refs = set() nodes = self._nodes for _key, (_value, ref_lists) in nodes.items(): @@ -591,9 +594,8 @@ def iter_all_entries(self): There is no defined order for the result iteration - it will be in the most efficient order for the index. 
""" - if debug.debug_flag_enabled('evil'): - trace.mutter_callsite(3, - "iter_all_entries scales with size of history.") + if debug.debug_flag_enabled("evil"): + trace.mutter_callsite(3, "iter_all_entries scales with size of history.") if self._nodes is None: self._buffer_all() if self.node_ref_lists: @@ -611,21 +613,21 @@ def _read_prefix(self, stream): if not options_line.startswith(_OPTION_NODE_REFS): raise BadIndexOptions(self) try: - self.node_ref_lists = int(options_line[len(_OPTION_NODE_REFS):-1]) + self.node_ref_lists = int(options_line[len(_OPTION_NODE_REFS) : -1]) except ValueError as e: raise BadIndexOptions(self) from e options_line = stream.readline() if not options_line.startswith(_OPTION_KEY_ELEMENTS): raise BadIndexOptions(self) try: - self._key_length = int(options_line[len(_OPTION_KEY_ELEMENTS):-1]) + self._key_length = int(options_line[len(_OPTION_KEY_ELEMENTS) : -1]) except ValueError as e: raise BadIndexOptions(self) from e options_line = stream.readline() if not options_line.startswith(_OPTION_LEN): raise BadIndexOptions(self) try: - self._key_count = int(options_line[len(_OPTION_LEN):-1]) + self._key_count = int(options_line[len(_OPTION_LEN) : -1]) except ValueError as e: raise BadIndexOptions(self) from e @@ -641,8 +643,7 @@ def _resolve_references(self, references): """ node_refs = [] for ref_list in references: - node_refs.append( - tuple([self._keys_by_offset[ref][0] for ref in ref_list])) + node_refs.append(tuple([self._keys_by_offset[ref][0] for ref in ref_list])) return tuple(node_refs) @staticmethod @@ -684,7 +685,7 @@ def _parsed_key_index(self, key): asking for 'b' will return 1 asking for 'e' will return 1 """ - search_key = (key, b'') + search_key = (key, b"") return self._find_index(self._parsed_key_map, search_key) def _is_parsed(self, offset): @@ -718,6 +719,7 @@ def iter_entries(self, keys): key supplied that is in the index will be returned. """ from .. import bisect_multi + keys = set(keys) if not keys: return [] @@ -736,8 +738,12 @@ def iter_entries(self, keys): if self._nodes is not None: return self._iter_entries_from_total_buffer(keys) else: - return (result[1] for result in bisect_multi.bisect_multi_bytes( - self._lookup_keys_via_location, self._size, keys)) + return ( + result[1] + for result in bisect_multi.bisect_multi_bytes( + self._lookup_keys_via_location, self._size, keys + ) + ) def iter_entries_prefix(self, keys): """Iterate over keys within the index using prefix matching. @@ -835,19 +841,26 @@ def _lookup_keys_via_location(self, location_keys): # We have the key parsed. continue index = self._parsed_key_index(key) - if (len(self._parsed_key_map) and - self._parsed_key_map[index][0] <= key and - (self._parsed_key_map[index][1] >= key or - # end of the file has been parsed - self._parsed_byte_map[index][1] == self._size)): + if ( + len(self._parsed_key_map) + and self._parsed_key_map[index][0] <= key + and ( + self._parsed_key_map[index][1] >= key + or + # end of the file has been parsed + self._parsed_byte_map[index][1] == self._size + ) + ): # the key has been parsed, so no lookup is needed even if its # not present. 
continue # - if we have examined this part of the file already - yes index = self._parsed_byte_index(location) - if (len(self._parsed_byte_map) and - self._parsed_byte_map[index][0] <= location and - self._parsed_byte_map[index][1] > location): + if ( + len(self._parsed_byte_map) + and self._parsed_byte_map[index][0] <= location + and self._parsed_byte_map[index][1] > location + ): # the byte region has been parsed, so no read is needed. continue length = 800 @@ -869,11 +882,9 @@ def _lookup_keys_via_location(self, location_keys): result.append(((location, key), False)) elif self.node_ref_lists: value, refs = self._nodes[key] - result.append(((location, key), - (self, key, value, refs))) + result.append(((location, key), (self, key, value, refs))) else: - result.append(((location, key), - (self, key, self._nodes[key]))) + result.append(((location, key), (self, key, self._nodes[key]))) return result # generate results: # - figure out <, >, missing, present @@ -897,19 +908,26 @@ def _lookup_keys_via_location(self, location_keys): pending_locations.update(wanted_locations) pending_references.append((location, key)) continue - result.append(((location, key), (self, key, - value, self._resolve_references(refs)))) + result.append( + ( + (location, key), + (self, key, value, self._resolve_references(refs)), + ) + ) else: - result.append(((location, key), - (self, key, self._bisect_nodes[key]))) + result.append( + ((location, key), (self, key, self._bisect_nodes[key])) + ) continue else: # has the region the key should be in, been parsed? index = self._parsed_key_index(key) - if (self._parsed_key_map[index][0] <= key and - (self._parsed_key_map[index][1] >= key or - # end of the file has been parsed - self._parsed_byte_map[index][1] == self._size)): + if self._parsed_key_map[index][0] <= key and ( + self._parsed_key_map[index][1] >= key + or + # end of the file has been parsed + self._parsed_byte_map[index][1] == self._size + ): result.append(((location, key), False)) continue # no, is the key above or below the probed location: @@ -943,8 +961,9 @@ def _lookup_keys_via_location(self, location_keys): for location, key in pending_references: # answer key references we had to look-up-late. value, refs = self._bisect_nodes[key] - result.append(((location, key), (self, key, - value, self._resolve_references(refs)))) + result.append( + ((location, key), (self, key, value, self._resolve_references(refs))) + ) return result def _parse_header_from_bytes(self, bytes): @@ -954,34 +973,33 @@ def _parse_header_from_bytes(self, bytes): :return: An offset, data tuple such as readv yields, for the unparsed data. (which may length 0). 
""" - signature = bytes[0:len(self._signature())] + signature = bytes[0 : len(self._signature())] if not signature == self._signature(): raise BadIndexFormatSignature(self._name, GraphIndex) - lines = bytes[len(self._signature()):].splitlines() + lines = bytes[len(self._signature()) :].splitlines() options_line = lines[0] if not options_line.startswith(_OPTION_NODE_REFS): raise BadIndexOptions(self) try: - self.node_ref_lists = int(options_line[len(_OPTION_NODE_REFS):]) + self.node_ref_lists = int(options_line[len(_OPTION_NODE_REFS) :]) except ValueError as e: raise BadIndexOptions(self) from e options_line = lines[1] if not options_line.startswith(_OPTION_KEY_ELEMENTS): raise BadIndexOptions(self) try: - self._key_length = int(options_line[len(_OPTION_KEY_ELEMENTS):]) + self._key_length = int(options_line[len(_OPTION_KEY_ELEMENTS) :]) except ValueError as e: raise BadIndexOptions(self) from e options_line = lines[2] if not options_line.startswith(_OPTION_LEN): raise BadIndexOptions(self) try: - self._key_count = int(options_line[len(_OPTION_LEN):]) + self._key_count = int(options_line[len(_OPTION_LEN) :]) except ValueError as e: raise BadIndexOptions(self) from e # calculate the bytes we have processed - header_end = (len(signature) + len(lines[0]) + len(lines[1]) + - len(lines[2]) + 3) + header_end = len(signature) + len(lines[0]) + len(lines[1]) + len(lines[2]) + 3 self._parsed_bytes(0, (), header_end, ()) # setup parsing state self._expected_elements = 3 + self._key_length @@ -1009,8 +1027,7 @@ def _parse_region(self, offset, data): return # print "[%d:%d]" % (offset, end), \ # self._parsed_byte_map[index:index + 2] - high_parsed, last_segment = self._parse_segment( - offset, data, end, index) + high_parsed, last_segment = self._parse_segment(offset, data, end, index) if last_segment: return @@ -1088,38 +1105,39 @@ def _parse_segment(self, offset, data, end, index): if not start_adjacent: # work around python bug in rfind if trim_start is None: - trim_start = data.find(b'\n') + 1 + trim_start = data.find(b"\n") + 1 else: - trim_start = data.find(b'\n', trim_start) + 1 + trim_start = data.find(b"\n", trim_start) + 1 if not (trim_start != 0): - raise AssertionError('no \n was present') + raise AssertionError("no \n was present") # print 'removing start', offset, trim_start, repr(data[:trim_start]) if not end_adjacent: # work around python bug in rfind if trim_end is None: - trim_end = data.rfind(b'\n') + 1 + trim_end = data.rfind(b"\n") + 1 else: - trim_end = data.rfind(b'\n', None, trim_end) + 1 + trim_end = data.rfind(b"\n", None, trim_end) + 1 if not (trim_end != 0): - raise AssertionError('no \n was present') + raise AssertionError("no \n was present") # print 'removing end', offset, trim_end, repr(data[trim_end:]) # adjust offset and data to the parseable data. trimmed_data = data[trim_start:trim_end] if not (trimmed_data): - raise AssertionError('read unneeded data [%d:%d] from [%d:%d]' - % (trim_start, trim_end, offset, offset + len(data))) + raise AssertionError( + "read unneeded data [%d:%d] from [%d:%d]" + % (trim_start, trim_end, offset, offset + len(data)) + ) if trim_start: offset += trim_start # print "parsing", repr(trimmed_data) # splitlines mangles the \r delimiters.. don't use it. 
- lines = trimmed_data.split(b'\n') + lines = trimmed_data.split(b"\n") del lines[-1] pos = offset first_key, last_key, nodes, _ = self._parse_lines(lines, pos) for key, value in nodes: self._bisect_nodes[key] = value - self._parsed_bytes(offset, first_key, - offset + len(trimmed_data), last_key) + self._parsed_bytes(offset, first_key, offset + len(trimmed_data), last_key) return offset + len(trimmed_data), last_segment def _parse_lines(self, lines, pos): @@ -1128,27 +1146,27 @@ def _parse_lines(self, lines, pos): trailers = 0 nodes = [] for line in lines: - if line == b'': + if line == b"": # must be at the end if self._size: if not (self._size == pos + 1): raise AssertionError(f"{self._size} {pos}") trailers += 1 continue - elements = line.split(b'\0') + elements = line.split(b"\0") if len(elements) != self._expected_elements: raise BadIndexData(self) # keys are tuples. Each element is a string that may occur many # times, so we intern them to save space. AB, RC, 200807 - key = tuple(elements[:self._key_length]) + key = tuple(elements[: self._key_length]) if first_key is None: first_key = key absent, references, value = elements[-3:] ref_lists = [] - for ref_string in references.split(b'\t'): - ref_lists.append(tuple([ - int(ref) for ref in ref_string.split(b'\r') if ref - ])) + for ref_string in references.split(b"\t"): + ref_lists.append( + tuple([int(ref) for ref in ref_string.split(b"\r") if ref]) + ) ref_lists = tuple(ref_lists) self._keys_by_offset[pos] = (key, absent, ref_lists, value) pos += len(line) + 1 # +1 for the \n @@ -1184,29 +1202,39 @@ def _parsed_bytes(self, start, start_key, end, end_key): # extend lower region # extend higher region # combine two regions - if (index + 1 < len(self._parsed_byte_map) and - self._parsed_byte_map[index][1] == start and - self._parsed_byte_map[index + 1][0] == end): + if ( + index + 1 < len(self._parsed_byte_map) + and self._parsed_byte_map[index][1] == start + and self._parsed_byte_map[index + 1][0] == end + ): # combine two regions - self._parsed_byte_map[index] = (self._parsed_byte_map[index][0], - self._parsed_byte_map[index + 1][1]) - self._parsed_key_map[index] = (self._parsed_key_map[index][0], - self._parsed_key_map[index + 1][1]) + self._parsed_byte_map[index] = ( + self._parsed_byte_map[index][0], + self._parsed_byte_map[index + 1][1], + ) + self._parsed_key_map[index] = ( + self._parsed_key_map[index][0], + self._parsed_key_map[index + 1][1], + ) del self._parsed_byte_map[index + 1] del self._parsed_key_map[index + 1] elif self._parsed_byte_map[index][1] == start: # extend the lower entry - self._parsed_byte_map[index] = ( - self._parsed_byte_map[index][0], end) - self._parsed_key_map[index] = ( - self._parsed_key_map[index][0], end_key) - elif (index + 1 < len(self._parsed_byte_map) and - self._parsed_byte_map[index + 1][0] == end): + self._parsed_byte_map[index] = (self._parsed_byte_map[index][0], end) + self._parsed_key_map[index] = (self._parsed_key_map[index][0], end_key) + elif ( + index + 1 < len(self._parsed_byte_map) + and self._parsed_byte_map[index + 1][0] == end + ): # extend the higher entry self._parsed_byte_map[index + 1] = ( - start, self._parsed_byte_map[index + 1][1]) + start, + self._parsed_byte_map[index + 1][1], + ) self._parsed_key_map[index + 1] = ( - start_key, self._parsed_key_map[index + 1][1]) + start_key, + self._parsed_key_map[index + 1][1], + ) else: # new entry self._parsed_byte_map.insert(index + 1, new_value) @@ -1228,10 +1256,10 @@ def _read_and_parse(self, readv_ranges): base_offset = 
self._base_offset if base_offset != 0: # Rewrite the ranges for the offset - readv_ranges = [(start + base_offset, size) - for start, size in readv_ranges] - readv_data = self._transport.readv(self._name, readv_ranges, True, - self._size + self._base_offset) + readv_ranges = [(start + base_offset, size) for start, size in readv_ranges] + readv_data = self._transport.readv( + self._name, readv_ranges, True, self._size + self._base_offset + ) # parse for offset, data in readv_data: offset -= base_offset @@ -1447,7 +1475,7 @@ def _move_to_front(self, hit_indices): _move_to_front propagates to all objects in self._sibling_indices by calling _move_to_front_by_name. """ - if self._indices[:len(hit_indices)] == hit_indices: + if self._indices[: len(hit_indices)] == hit_indices: # The 'hit_indices' are already at the front (and in the same # order), no need to re-order return @@ -1461,10 +1489,13 @@ def _move_to_front_by_index(self, hit_indices): Returns a list of names corresponding to the hit_indices param. """ indices_info = zip(self._index_names, self._indices) - if debug.debug_flag_enabled('index'): + if debug.debug_flag_enabled("index"): indices_info = list(indices_info) - trace.mutter('CombinedGraphIndex reordering: currently %r, ' - 'promoting %r', indices_info, hit_indices) + trace.mutter( + "CombinedGraphIndex reordering: currently %r, " "promoting %r", + indices_info, + hit_indices, + ) hit_names = [] unhit_names = [] new_hit_indices = [] @@ -1477,8 +1508,8 @@ def _move_to_front_by_index(self, hit_indices): if len(new_hit_indices) == len(hit_indices): # We've found all of the hit entries, everything else is # unhit - unhit_names.extend(self._index_names[offset + 1:]) - unhit_indices.extend(self._indices[offset + 1:]) + unhit_names.extend(self._index_names[offset + 1 :]) + unhit_indices.extend(self._indices[offset + 1 :]) break else: unhit_names.append(name) @@ -1486,8 +1517,8 @@ def _move_to_front_by_index(self, hit_indices): self._indices = new_hit_indices + unhit_indices self._index_names = hit_names + unhit_names - if debug.debug_flag_enabled('index'): - trace.mutter('CombinedGraphIndex reordered: %r', self._indices) + if debug.debug_flag_enabled("index"): + trace.mutter("CombinedGraphIndex reordered: %r", self._indices) return hit_names def _move_to_front_by_name(self, hit_names): @@ -1545,8 +1576,9 @@ def find_ancestry(self, keys, ref_list_num): # TODO: ref_list_num should really be a parameter, since # CombinedGraphIndex does not know what the ref lists # mean. - search_keys = index._find_ancestors(search_keys, - ref_list_num, parent_map, index_missing_keys) + search_keys = index._find_ancestors( + search_keys, ref_list_num, parent_map, index_missing_keys + ) # print ' \t \t%2d\t%4d\t%5d\t%5d' % ( # sub_generation, len(search_keys), # len(parent_map), len(index_missing_keys)) @@ -1592,12 +1624,13 @@ def _try_reload(self, error): """ if self._reload_func is None: return False - trace.mutter( - 'Trying to reload after getting exception: %s', str(error)) + trace.mutter("Trying to reload after getting exception: %s", str(error)) if not self._reload_func(): # We tried to reload, but nothing changed, so we fail anyway - trace.mutter('_reload_func indicated nothing has changed.' - ' Raising original exception.') + trace.mutter( + "_reload_func indicated nothing has changed." + " Raising original exception." + ) return False return True @@ -1631,10 +1664,10 @@ def add_nodes(self, nodes): :param nodes: An iterable of (key, node_refs, value) entries to add. 
""" if self.reference_lists: - for (key, value, node_refs) in nodes: + for key, value, node_refs in nodes: self.add_node(key, value, node_refs) else: - for (key, value) in nodes: + for key, value in nodes: self.add_node(key, value) def iter_all_entries(self): @@ -1644,9 +1677,8 @@ def iter_all_entries(self): defined order for the result iteration - it will be in the most efficient order for the index (in this case dictionary hash order). """ - if debug.debug_flag_enabled('evil'): - trace.mutter_callsite(3, - "iter_all_entries scales with size of history.") + if debug.debug_flag_enabled("evil"): + trace.mutter_callsite(3, "iter_all_entries scales with size of history.") if self.reference_lists: for key, (absent, references, value) in self._nodes.items(): if not absent: @@ -1725,8 +1757,9 @@ def validate(self): def __lt__(self, other): # We don't really care about the order, just that there is an order. - if (not isinstance(other, GraphIndex) and - not isinstance(other, InMemoryGraphIndex)): + if not isinstance(other, GraphIndex) and not isinstance( + other, InMemoryGraphIndex + ): raise TypeError(other) return hash(self) < hash(other) @@ -1741,8 +1774,7 @@ class GraphIndexPrefixAdapter: nodes and references being added will have prefix prepended. """ - def __init__(self, adapted, prefix, missing_key_length, - add_nodes_callback=None): + def __init__(self, adapted, prefix, missing_key_length, add_nodes_callback=None): """Construct an adapter against adapted with prefix.""" self.adapted = adapted self.prefix_key = prefix + (None,) * missing_key_length @@ -1761,17 +1793,17 @@ def add_nodes(self, nodes): try: # Add prefix_key to each reference node_refs is a tuple of tuples, # so split it apart, and add prefix_key to the internal reference - for (key, value, node_refs) in nodes: - adjusted_references = ( - tuple(tuple(self.prefix + ref_node for ref_node in ref_list) - for ref_list in node_refs)) - translated_nodes.append((self.prefix + key, value, - adjusted_references)) + for key, value, node_refs in nodes: + adjusted_references = tuple( + tuple(self.prefix + ref_node for ref_node in ref_list) + for ref_list in node_refs + ) + translated_nodes.append((self.prefix + key, value, adjusted_references)) except ValueError: # XXX: TODO add an explicit interface for getting the reference list # status, to handle this bit of user-friendliness in the API more # explicitly. - for (key, value) in nodes: + for key, value in nodes: translated_nodes.append((self.prefix + key, value)) self.add_nodes_callback(translated_nodes) @@ -1786,21 +1818,29 @@ def add_node(self, key, value, references=()): :param value: The value to associate with the key. It may be any bytes as long as it does not contain \0 or \n. 
""" - self.add_nodes(((key, value, references), )) + self.add_nodes(((key, value, references),)) def _strip_prefix(self, an_iter): """Strip prefix data from nodes and return it.""" for node in an_iter: # cross checks - if node[1][:self.prefix_len] != self.prefix: + if node[1][: self.prefix_len] != self.prefix: raise BadIndexData(self) for ref_list in node[3]: for ref_node in ref_list: - if ref_node[:self.prefix_len] != self.prefix: + if ref_node[: self.prefix_len] != self.prefix: raise BadIndexData(self) - yield node[0], node[1][self.prefix_len:], node[2], ( - tuple(tuple(ref_node[self.prefix_len:] for ref_node in ref_list) - for ref_list in node[3])) + yield ( + node[0], + node[1][self.prefix_len :], + node[2], + ( + tuple( + tuple(ref_node[self.prefix_len :] for ref_node in ref_list) + for ref_list in node[3] + ) + ), + ) def iter_all_entries(self): """Iterate over all keys within the index. @@ -1822,8 +1862,9 @@ def iter_entries(self, keys): defined order for the result iteration - it will be in the most efficient order for the index (keys iteration order in this case). """ - return self._strip_prefix(self.adapted.iter_entries( - self.prefix + key for key in keys)) + return self._strip_prefix( + self.adapted.iter_entries(self.prefix + key for key in keys) + ) def iter_entries_prefix(self, keys): """Iterate over keys within the index using prefix matching. @@ -1842,8 +1883,9 @@ def iter_entries_prefix(self, keys): will be returned, and every match that is in the index will be returned. """ - return self._strip_prefix(self.adapted.iter_entries_prefix( - self.prefix + key for key in keys)) + return self._strip_prefix( + self.adapted.iter_entries_prefix(self.prefix + key for key in keys) + ) def key_count(self): """Return an estimate of the number of keys in this index. @@ -1895,7 +1937,7 @@ def _iter_entries_prefix(index_or_builder, nodes_by_key, keys): for value in values_view: # each value is the key:value:node refs tuple # ready to yield. - yield (index_or_builder, ) + value + yield (index_or_builder,) + value else: # the last thing looked up was a terminal element - yield (index_or_builder, ) + key_dict + yield (index_or_builder,) + key_dict diff --git a/breezy/bzr/inventory.py b/breezy/bzr/inventory.py index b629294374..da36ed31d1 100644 --- a/breezy/bzr/inventory.py +++ b/breezy/bzr/inventory.py @@ -27,13 +27,16 @@ from ..lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ from breezy.bzr import ( chk_map, generate_ids, ) -""") +""", +) from .. import errors, osutils from .._bzr_rs import ROOT_ID @@ -48,7 +51,6 @@ class InvalidEntryName(errors.InternalBzrError): - _fmt = "Invalid entry name: %(name)s" def __init__(self, name): @@ -57,7 +59,6 @@ def __init__(self, name): class DuplicateFileId(errors.BzrError): - _fmt = "File id {%(file_id)s} already exists in inventory as %(entry)s" def __init__(self, file_id, entry): @@ -96,9 +97,11 @@ def id2path(self, file_id): :raises NoSuchId: If file_id is not present in the inventory. """ # get all names, skipping root - return '/'.join(reversed( - [parent.name for parent in - self._iter_file_id_parents(file_id)][:-1])) + return "/".join( + reversed( + [parent.name for parent in self._iter_file_id_parents(file_id)][:-1] + ) + ) def iter_entries(self, from_dir=None, recursive=True): """Return (path, entry) pairs, in order by name. 
@@ -111,7 +114,7 @@ def iter_entries(self, from_dir=None, recursive=True): if self.root is None: return from_dir = self.root.file_id - yield '', self.root + yield "", self.root elif not isinstance(from_dir, bytes): from_dir = from_dir.file_id @@ -122,7 +125,7 @@ def iter_entries(self, from_dir=None, recursive=True): yield from children return children = deque(children) - stack = [('', children)] + stack = [("", children)] while stack: from_dir_relpath, children = stack[-1] @@ -133,15 +136,17 @@ def iter_entries(self, from_dir=None, recursive=True): # and 'f' doesn't begin with one, we can do a string op, rather # than the checks of pathjoin(), though this means that all paths # start with a slash - path = from_dir_relpath + '/' + name + path = from_dir_relpath + "/" + name yield path[1:], ie - if ie.kind != 'directory': + if ie.kind != "directory": continue # But do this child first - new_children = [(c.name, c) for c in self.iter_sorted_children(ie.file_id)] + new_children = [ + (c.name, c) for c in self.iter_sorted_children(ie.file_id) + ] new_children = deque(new_children) stack.append((path, new_children)) # Break out of inner loop, so that we start outer loop with child @@ -187,8 +192,7 @@ def iter_entries_by_dir(self, from_dir=None, specific_file_ids=None): if self.root is None: return # Optimize a common case - if (specific_file_ids is not None - and len(specific_file_ids) == 1): + if specific_file_ids is not None and len(specific_file_ids) == 1: file_id = list(specific_file_ids)[0] if file_id is not None: try: @@ -199,9 +203,8 @@ def iter_entries_by_dir(self, from_dir=None, specific_file_ids=None): yield path, self.get_entry(file_id) return from_dir = self.root - if (specific_file_ids is None - or self.root.file_id in specific_file_ids): - yield '', self.root + if specific_file_ids is None or self.root.file_id in specific_file_ids: + yield "", self.root elif isinstance(from_dir, bytes): from_dir = self.get_entry(from_dir) else: @@ -222,27 +225,26 @@ def add_ancestors(file_id): if parent_id not in parents: parents.add(parent_id) add_ancestors(parent_id) + for file_id in specific_file_ids: add_ancestors(file_id) else: parents = None - stack = [('', from_dir)] + stack = [("", from_dir)] while stack: cur_relpath, cur_dir = stack.pop() child_dirs = [] for child_ie in self.iter_sorted_children(cur_dir.file_id): - child_relpath = cur_relpath + child_ie.name - if (specific_file_ids is None - or child_ie.file_id in specific_file_ids): + if specific_file_ids is None or child_ie.file_id in specific_file_ids: yield child_relpath, child_ie - if child_ie.kind == 'directory': + if child_ie.kind == "directory": if parents is None or child_ie.file_id in parents: - child_dirs.append((child_relpath + '/', child_ie)) + child_dirs.append((child_relpath + "/", child_ie)) stack.extend(reversed(child_dirs)) def _make_delta(self, old: "CommonInventory"): @@ -256,13 +258,21 @@ def _make_delta(self, old: "CommonInventory"): for file_id in deletes: delta.append((old.id2path(file_id), None, file_id, None)) for file_id in adds: - delta.append((None, self.id2path(file_id), - file_id, self.get_entry(file_id))) + delta.append( + (None, self.id2path(file_id), file_id, self.get_entry(file_id)) + ) for file_id in common: if old.get_entry(file_id) != self.get_entry(file_id): - delta.append((old.id2path(file_id), self.id2path(file_id), - file_id, self.get_entry(file_id))) + delta.append( + ( + old.id2path(file_id), + self.id2path(file_id), + file_id, + self.get_entry(file_id), + ) + ) from .inventory_delta import 
InventoryDelta + return InventoryDelta(delta) def make_entry(self, kind, name, parent_id, file_id=None): @@ -280,11 +290,11 @@ def descend(dir_ie, dir_path): for ie in self.iter_sorted_children(dir_ie.file_id): child_path = osutils.pathjoin(dir_path, ie.name) accum.append((child_path, ie)) - if ie.kind == 'directory': + if ie.kind == "directory": descend(ie, child_path) if self.root is not None: - descend(self.root, '') + descend(self.root, "") return accum def get_entry_by_path_partial(self, relpath): @@ -311,8 +321,8 @@ def get_entry_by_path_partial(self, relpath): cie = self.get_child(parent.file_id, f) if cie is None: return None, None, None - if cie.kind == 'tree-reference': - return cie, names[:i + 1], names[i + 1:] + if cie.kind == "tree-reference": + return cie, names[: i + 1], names[i + 1 :] parent = cie except KeyError: # or raise an error? @@ -393,9 +403,8 @@ def filter(self, specific_fileids): directories_to_expand = set() for _path, entry in entries: file_id = entry.file_id - if (file_id in specific_fileids or - entry.parent_id in directories_to_expand): - if entry.kind == 'directory': + if file_id in specific_fileids or entry.parent_id in directories_to_expand: + if entry.kind == "directory": directories_to_expand.add(file_id) elif file_id not in interesting_parents: continue @@ -463,14 +472,14 @@ def __init__(self, root_id=ROOT_ID, revision_id=None): self._byid = {} self._children = {} if root_id is not None: - self._set_root(InventoryDirectory(root_id, '', None)) + self._set_root(InventoryDirectory(root_id, "", None)) self.revision_id = revision_id def change_root_id(self, file_id): # unlinkit from the byid index children = self._children.pop(self.root.file_id) del self._byid[self.root.file_id] - self.root = InventoryDirectory(file_id, '', None) + self.root = InventoryDirectory(file_id, "", None) # and link it into the index with the new changed id. self._byid[self.root.file_id] = self.root self._children[self.root.file_id] = children @@ -480,7 +489,7 @@ def change_root_id(self, file_id): def rename_id(self, old_file_id, new_file_id): self._byid[new_file_id] = self._byid.pop(old_file_id) - if self._byid[new_file_id].kind == 'directory': + if self._byid[new_file_id].kind == "directory": self._children[new_file_id] = self._children.pop(old_file_id) for child in self._children[new_file_id].values(): child.parent_id = new_file_id @@ -489,10 +498,10 @@ def rename_id(self, old_file_id, new_file_id): def __repr__(self): # More than one page of ouput is not useful anymore to debug max_len = 2048 - closing = '...}' + closing = "...}" contents = repr(self._byid) if len(contents) > max_len: - contents = contents[:(max_len - len(closing))] + closing + contents = contents[: (max_len - len(closing))] + closing return f"" def get_children(self, file_id): @@ -552,11 +561,15 @@ def apply_delta(self, delta): # starting with the longest paths, thus ensuring parents are examined # after their children, which means that everything we examine has no # modified children remaining by the time we examine it. 
- for old_path, file_id in sorted(((op, f) for op, np, f, e in delta - if op is not None), reverse=True): + for old_path, file_id in sorted( + ((op, f) for op, np, f, e in delta if op is not None), reverse=True + ): if self.id2path(file_id) != old_path: - raise errors.InconsistentDelta(old_path, file_id, - f"Entry was at wrong other path {self.id2path(file_id)!r}.") + raise errors.InconsistentDelta( + old_path, + file_id, + f"Entry was at wrong other path {self.id2path(file_id)!r}.", + ) # Remove file_id and the unaltered children. If file_id is not # being deleted it will be reinserted back later. ie = self._byid.pop(file_id) @@ -575,30 +588,40 @@ def apply_delta(self, delta): # longest, ensuring that items which were modified and whose parents in # the resulting inventory were also modified, are inserted after their # parents. - for new_path, _f, new_entry in sorted((np, f, e) for op, np, f, e in - delta if np is not None): + for new_path, _f, new_entry in sorted( + (np, f, e) for op, np, f, e in delta if np is not None + ): try: self.add(new_entry) except DuplicateFileId as ex: - raise errors.InconsistentDelta(new_path, new_entry.file_id, - "New id is already present in target.") from ex + raise errors.InconsistentDelta( + new_path, new_entry.file_id, "New id is already present in target." + ) from ex except AttributeError as ex: - raise errors.InconsistentDelta(new_path, new_entry.file_id, - "Parent is not a directory.") from ex + raise errors.InconsistentDelta( + new_path, new_entry.file_id, "Parent is not a directory." + ) from ex if self.id2path(new_entry.file_id) != new_path: - raise errors.InconsistentDelta(new_path, new_entry.file_id, - "New path is not consistent with parent path.") - if new_entry.kind == 'directory': + raise errors.InconsistentDelta( + new_path, + new_entry.file_id, + "New path is not consistent with parent path.", + ) + if new_entry.kind == "directory": self._children[new_entry.file_id] = children.pop(new_entry.file_id, {}) if len(children): # Get the parent id that was deleted parent_id, children = children.popitem() - raise errors.InconsistentDelta("", parent_id, - "The file id was deleted but its children were not deleted.") - - def create_by_apply_delta(self, inventory_delta, new_revision_id, - propagate_caches=False): + raise errors.InconsistentDelta( + "", + parent_id, + "The file id was deleted but its children were not deleted.", + ) + + def create_by_apply_delta( + self, inventory_delta, new_revision_id, propagate_caches=False + ): """See CHKInventory.create_by_apply_delta().""" new_inv = self.copy() new_inv.apply_delta(inventory_delta) @@ -609,7 +632,7 @@ def _set_root(self, ie): self.root = ie self._byid = {self.root.file_id: self.root} if self.root.file_id in self._children: - raise AssertionError('Root id already in children') + raise AssertionError("Root id already in children") self._children = {self.root.file_id: {}} def copy(self): @@ -681,21 +704,25 @@ def add(self, entry): try: parent = self._byid[entry.parent_id] except KeyError as e: - raise errors.InconsistentDelta("", entry.parent_id, - "Parent not in inventory.") from e - if parent.kind != 'directory': - raise errors.InconsistentDelta(self.id2path(entry.parent_id), - entry.file_id, - "Parent is not a directory.") + raise errors.InconsistentDelta( + "", entry.parent_id, "Parent not in inventory." 
+ ) from e + if parent.kind != "directory": + raise errors.InconsistentDelta( + self.id2path(entry.parent_id), + entry.file_id, + "Parent is not a directory.", + ) siblings = self._children[parent.file_id] if entry.name in siblings: raise errors.InconsistentDelta( self.id2path(siblings[entry.name].file_id), entry.file_id, - "Path already versioned") + "Path already versioned", + ) siblings[entry.name] = entry self._byid[entry.file_id] = entry - if entry.kind == 'directory': + if entry.kind == "directory": self._children[entry.file_id] = {} return entry @@ -711,10 +738,10 @@ def add_path(self, relpath, kind, file_id=None, parent_id=None): if len(parts) == 0: if file_id is None: file_id = generate_ids.gen_root_id() - self.root = InventoryDirectory(file_id, '', None) + self.root = InventoryDirectory(file_id, "", None) self._byid = {self.root.file_id: self.root} if self.root.file_id in self._children: - raise AssertionError('Root id already in children') + raise AssertionError("Root id already in children") self._children = {self.root.file_id: {}} return self.root else: @@ -767,7 +794,7 @@ def __ne__(self, other): return not self.__eq__(other) def __hash__(self): - raise ValueError('not hashable') + raise ValueError("not hashable") def _iter_file_id_parents(self, file_id): """Yield the parents of file_id up to the root.""" @@ -780,7 +807,7 @@ def _iter_file_id_parents(self, file_id): file_id = ie.parent_id def has_id(self, file_id): - return (file_id in self._byid) + return file_id in self._byid def _make_delta(self, old: CommonInventory): """Make an inventory delta from two inventories.""" @@ -798,8 +825,9 @@ def _make_delta(self, old: CommonInventory): for file_id in deletes: delta.append((old.id2path(file_id), None, file_id, None)) for file_id in adds: - delta.append((None, self.id2path(file_id), - file_id, self.get_entry(file_id))) + delta.append( + (None, self.id2path(file_id), file_id, self.get_entry(file_id)) + ) for file_id in common: new_ie = new_getter(file_id) old_ie = old_getter(file_id) @@ -810,9 +838,11 @@ def _make_delta(self, old: CommonInventory): if old_ie is new_ie or old_ie == new_ie: continue else: - delta.append((old.id2path(file_id), self.id2path(file_id), - file_id, new_ie)) + delta.append( + (old.id2path(file_id), self.id2path(file_id), file_id, new_ie) + ) from .inventory_delta import InventoryDelta + return InventoryDelta(delta) def remove_recursive_id(self, file_id): @@ -825,17 +855,17 @@ def remove_recursive_id(self, file_id): while to_find_delete: ie = to_find_delete.pop() to_delete.append(ie) - if ie.kind == 'directory': + if ie.kind == "directory": to_find_delete.extend(self.get_children(ie.file_id).values()) for ie in reversed(to_delete): del self._byid[ie.file_id] - if ie.kind == 'directory': + if ie.kind == "directory": if self._children[ie.file_id]: - raise AssertionError('Directory not empty') + raise AssertionError("Directory not empty") del self._children[ie.file_id] else: if ie.parent_id is None: - raise AssertionError('Root id already in children') + raise AssertionError("Root id already in children") if ie.parent_id is not None: del self._children[ie.parent_id][ie.name] else: @@ -855,12 +885,15 @@ def rename(self, file_id, new_parent_id, new_name): new_parent = self._byid[new_parent_id] if new_name in self._children[new_parent.file_id]: - raise errors.BzrError(f"{new_name!r} already exists in {self.id2path(new_parent_id)!r}") + raise errors.BzrError( + f"{new_name!r} already exists in {self.id2path(new_parent_id)!r}" + ) new_parent_idpath = 
self.get_idpath(new_parent_id) if file_id in new_parent_idpath: raise errors.BzrError( - f"cannot move directory {self.id2path(file_id)!r} into a subdirectory of itself, {self.id2path(new_parent_id)!r}") + f"cannot move directory {self.id2path(file_id)!r} into a subdirectory of itself, {self.id2path(new_parent_id)!r}" + ) file_ie = self._byid[file_id] old_parent = self._byid[file_ie.parent_id] @@ -934,19 +967,32 @@ def get_children(self, dir_id): return children # No longer supported if self.parent_id_basename_to_file_id is None: - raise AssertionError("Inventories without" - " parent_id_basename_to_file_id are no longer supported") + raise AssertionError( + "Inventories without" + " parent_id_basename_to_file_id are no longer supported" + ) result = {} # XXX: Todo - use proxy objects for the children rather than loading # all when the attribute is referenced. child_keys = set() - for (_parent_id, _name_utf8), file_id in self.parent_id_basename_to_file_id.iteritems( - key_filter=[StaticTuple(dir_id,)]): - child_keys.add(StaticTuple(file_id,)) + for ( + _parent_id, + _name_utf8, + ), file_id in self.parent_id_basename_to_file_id.iteritems( + key_filter=[ + StaticTuple( + dir_id, + ) + ] + ): + child_keys.add( + StaticTuple( + file_id, + ) + ) cached = set() for file_id_key in child_keys: - entry = self._fileid_to_entry_cache.get( - file_id_key[0], None) + entry = self._fileid_to_entry_cache.get(file_id_key[0], None) if entry is not None: result[entry.name] = entry cached.add(file_id_key) @@ -994,11 +1040,10 @@ def _expand_fileids_to_parents_and_children(self, file_ids): children_of_parent_id = {} # It is okay if some of the fileids are missing for entry in self._getitems(file_ids): - if entry.kind == 'directory': + if entry.kind == "directory": directories_to_expand.add(entry.file_id) interesting.add(entry.parent_id) - children_of_parent_id.setdefault(entry.parent_id, set() - ).add(entry.file_id) + children_of_parent_id.setdefault(entry.parent_id, set()).add(entry.file_id) # Now, interesting has all of the direct parents, but not the # parents of those parents. 
It also may have some duplicates with @@ -1012,8 +1057,9 @@ def _expand_fileids_to_parents_and_children(self, file_ids): next_parents = set() for entry in self._getitems(remaining_parents): next_parents.add(entry.parent_id) - children_of_parent_id.setdefault(entry.parent_id, set() - ).add(entry.file_id) + children_of_parent_id.setdefault(entry.parent_id, set()).add( + entry.file_id + ) # Remove any search tips we've already processed remaining_parents = next_parents.difference(interesting) interesting.update(remaining_parents) @@ -1023,17 +1069,23 @@ def _expand_fileids_to_parents_and_children(self, file_ids): while directories_to_expand: # Expand directories by looking in the # parent_id_basename_to_file_id map - keys = [StaticTuple(f,).intern() for f in directories_to_expand] + keys = [ + StaticTuple( + f, + ).intern() + for f in directories_to_expand + ] directories_to_expand = set() items = self.parent_id_basename_to_file_id.iteritems(keys) next_file_ids = {item[1] for item in items} next_file_ids = next_file_ids.difference(interesting) interesting.update(next_file_ids) for entry in self._getitems(next_file_ids): - if entry.kind == 'directory': + if entry.kind == "directory": directories_to_expand.add(entry.file_id) - children_of_parent_id.setdefault(entry.parent_id, set() - ).add(entry.file_id) + children_of_parent_id.setdefault(entry.parent_id, set()).add( + entry.file_id + ) return interesting, children_of_parent_id def filter(self, specific_fileids): @@ -1044,9 +1096,10 @@ def filter(self, specific_fileids): The result may or may not reference the underlying inventory so it should be treated as immutable. """ - (interesting, - parent_to_children) = self._expand_fileids_to_parents_and_children( - specific_fileids) + ( + interesting, + parent_to_children, + ) = self._expand_fileids_to_parents_and_children(specific_fileids) # There is some overlap here, but we assume that all interesting items # are in the _fileid_to_entry_cache because we had to read them to # determine if they were a dir we wanted to recurse, or just a file @@ -1061,12 +1114,11 @@ def filter(self, specific_fileids): # parent_to_children with at least the tree root.) return other cache = self._fileid_to_entry_cache - remaining_children = deque( - parent_to_children[self.root_id]) + remaining_children = deque(parent_to_children[self.root_id]) while remaining_children: file_id = remaining_children.popleft() ie = cache[file_id] - if ie.kind == 'directory': + if ie.kind == "directory": ie = ie.copy() # We create a copy to depopulate the .children attribute # TODO: depending on the uses of 'other' we should probably alwyas # '.copy()' to prevent someone from mutating other and @@ -1082,8 +1134,9 @@ def _bytes_to_entry(self, bytes): self._fileid_to_entry_cache[result.file_id] = result return result - def create_by_apply_delta(self, inventory_delta, new_revision_id, - propagate_caches=False): + def create_by_apply_delta( + self, inventory_delta, new_revision_id, propagate_caches=False + ): """Create a new CHKInventory by applying inventory_delta to this one. 
See the inventory developers documentation for the theory behind @@ -1101,15 +1154,15 @@ def create_by_apply_delta(self, inventory_delta, new_revision_id, if propagate_caches: # Just propagate the path-to-fileid cache for now result._path_to_fileid_cache = self._path_to_fileid_cache.copy() - search_key_func = chk_map.search_key_registry.get( - self._search_key_name) + search_key_func = chk_map.search_key_registry.get(self._search_key_name) self.id_to_entry._ensure_root() maximum_size = self.id_to_entry._root_node.maximum_size result.revision_id = new_revision_id result.id_to_entry = chk_map.CHKMap( self.id_to_entry._store, self.id_to_entry.key(), - search_key_func=search_key_func) + search_key_func=search_key_func, + ) result.id_to_entry._ensure_root() result.id_to_entry._root_node.set_maximum_size(maximum_size) # Change to apply to the parent_id_basename delta. The dict maps @@ -1122,7 +1175,8 @@ def create_by_apply_delta(self, inventory_delta, new_revision_id, result.parent_id_basename_to_file_id = chk_map.CHKMap( self.parent_id_basename_to_file_id._store, self.parent_id_basename_to_file_id.key(), - search_key_func=search_key_func) + search_key_func=search_key_func, + ) result.parent_id_basename_to_file_id._ensure_root() self.parent_id_basename_to_file_id._ensure_root() result_p_id_root = result.parent_id_basename_to_file_id._root_node @@ -1147,7 +1201,7 @@ def create_by_apply_delta(self, inventory_delta, new_revision_id, altered = set() for old_path, new_path, file_id, entry in inventory_delta: # file id changes - if new_path == '': + if new_path == "": result.root_id = file_id if new_path is None: # Make a delete: @@ -1161,7 +1215,9 @@ def create_by_apply_delta(self, inventory_delta, new_revision_id, pass deletes.add(file_id) else: - new_key = StaticTuple(file_id,) + new_key = StaticTuple( + file_id, + ) new_value = _chk_inventory_entry_to_bytes(entry) # Update caches. It's worth doing this whether # we're propagating the old caches or not. @@ -1170,11 +1226,15 @@ def create_by_apply_delta(self, inventory_delta, new_revision_id, if old_path is None: old_key = None else: - old_key = StaticTuple(file_id,) + old_key = StaticTuple( + file_id, + ) if self.id2path(file_id) != old_path: - raise errors.InconsistentDelta(old_path, file_id, - "Entry was at wrong other path %r." % - self.id2path(file_id)) + raise errors.InconsistentDelta( + old_path, + file_id, + "Entry was at wrong other path %r." % self.id2path(file_id), + ) altered.add(file_id) id_to_entry_delta.append(StaticTuple(old_key, new_key, new_value)) if result.parent_id_basename_to_file_id is not None: @@ -1196,23 +1256,27 @@ def create_by_apply_delta(self, inventory_delta, new_revision_id, # Transform a change into explicit delete/add preserving # a possible match on the key from a different file id. if old_key is not None: - parent_id_basename_delta.setdefault( - old_key, [None, None])[0] = old_key + parent_id_basename_delta.setdefault(old_key, [None, None])[ + 0 + ] = old_key if new_key is not None: - parent_id_basename_delta.setdefault( - new_key, [None, None])[1] = new_value + parent_id_basename_delta.setdefault(new_key, [None, None])[ + 1 + ] = new_value # validate that deletes are complete. for file_id in deletes: entry = self.get_entry(file_id) - if entry.kind != 'directory': + if entry.kind != "directory": continue # This loop could potentially be better by using the id_basename # map to just get the child file ids. 
for child in self.iter_sorted_children(entry.file_id): if child.file_id not in altered: - raise errors.InconsistentDelta(self.id2path(child.file_id), - child.file_id, "Child not deleted or reparented when " - "parent deleted.") + raise errors.InconsistentDelta( + self.id2path(child.file_id), + child.file_id, + "Child not deleted or reparented when " "parent deleted.", + ) result.id_to_entry.apply_delta(id_to_entry_delta) if parent_id_basename_delta: # Transform the parent_id_basename delta data into a linear delta @@ -1226,18 +1290,25 @@ def create_by_apply_delta(self, inventory_delta, new_revision_id, else: delta_list.append((old_key, None, None)) result.parent_id_basename_to_file_id.apply_delta(delta_list) - parents.discard(('', None)) + parents.discard(("", None)) for parent_path, parent in parents: try: - if result.get_entry(parent).kind != 'directory': - raise errors.InconsistentDelta(result.id2path(parent), parent, - 'Not a directory, but given children') + if result.get_entry(parent).kind != "directory": + raise errors.InconsistentDelta( + result.id2path(parent), + parent, + "Not a directory, but given children", + ) except errors.NoSuchId as e: - raise errors.InconsistentDelta("", parent, - "Parent is not present in resulting inventory.") from e + raise errors.InconsistentDelta( + "", parent, "Parent is not present in resulting inventory." + ) from e if result.path2id(parent_path) != parent: - raise errors.InconsistentDelta(parent_path, parent, - f"Parent has wrong path {result.path2id(parent_path)!r}.") + raise errors.InconsistentDelta( + parent_path, + parent, + f"Parent has wrong path {result.path2id(parent_path)!r}.", + ) return result @classmethod @@ -1250,54 +1321,72 @@ def deserialise(klass, chk_store, lines, expected_revision_id): for. 
:return: A CHKInventory """ - if not lines[-1].endswith(b'\n'): + if not lines[-1].endswith(b"\n"): raise ValueError("last line should have trailing eol\n") - if lines[0] != b'chkinventory:\n': + if lines[0] != b"chkinventory:\n": raise ValueError(f"not a serialised CHKInventory: {bytes!r}") info = {} - allowed_keys = frozenset((b'root_id', b'revision_id', - b'parent_id_basename_to_file_id', - b'search_key_name', b'id_to_entry')) + allowed_keys = frozenset( + ( + b"root_id", + b"revision_id", + b"parent_id_basename_to_file_id", + b"search_key_name", + b"id_to_entry", + ) + ) for line in lines[1:]: - key, value = line.rstrip(b'\n').split(b': ', 1) + key, value = line.rstrip(b"\n").split(b": ", 1) if key not in allowed_keys: - raise errors.BzrError(f'Unknown key in inventory: {key!r}\n{bytes!r}') + raise errors.BzrError(f"Unknown key in inventory: {key!r}\n{bytes!r}") if key in info: - raise errors.BzrError(f'Duplicate key in inventory: {key!r}\n{bytes!r}') + raise errors.BzrError(f"Duplicate key in inventory: {key!r}\n{bytes!r}") info[key] = value - revision_id = info[b'revision_id'] - root_id = info[b'root_id'] - search_key_name = info.get(b'search_key_name', b'plain') - parent_id_basename_to_file_id = info.get( - b'parent_id_basename_to_file_id', None) - if not parent_id_basename_to_file_id.startswith(b'sha1:'): - raise ValueError('parent_id_basename_to_file_id should be a sha1' - f' key not {parent_id_basename_to_file_id!r}') - id_to_entry = info[b'id_to_entry'] - if not id_to_entry.startswith(b'sha1:'): - raise ValueError(f'id_to_entry should be a sha1 key not {id_to_entry!r}') + revision_id = info[b"revision_id"] + root_id = info[b"root_id"] + search_key_name = info.get(b"search_key_name", b"plain") + parent_id_basename_to_file_id = info.get(b"parent_id_basename_to_file_id", None) + if not parent_id_basename_to_file_id.startswith(b"sha1:"): + raise ValueError( + "parent_id_basename_to_file_id should be a sha1" + f" key not {parent_id_basename_to_file_id!r}" + ) + id_to_entry = info[b"id_to_entry"] + if not id_to_entry.startswith(b"sha1:"): + raise ValueError(f"id_to_entry should be a sha1 key not {id_to_entry!r}") result = CHKInventory(search_key_name) result.revision_id = revision_id result.root_id = root_id - search_key_func = chk_map.search_key_registry.get( - result._search_key_name) + search_key_func = chk_map.search_key_registry.get(result._search_key_name) if parent_id_basename_to_file_id is not None: result.parent_id_basename_to_file_id = chk_map.CHKMap( - chk_store, StaticTuple(parent_id_basename_to_file_id,), - search_key_func=search_key_func) + chk_store, + StaticTuple( + parent_id_basename_to_file_id, + ), + search_key_func=search_key_func, + ) else: result.parent_id_basename_to_file_id = None - result.id_to_entry = chk_map.CHKMap(chk_store, - StaticTuple(id_to_entry,), - search_key_func=search_key_func) + result.id_to_entry = chk_map.CHKMap( + chk_store, + StaticTuple( + id_to_entry, + ), + search_key_func=search_key_func, + ) if (result.revision_id,) != expected_revision_id: - raise ValueError(f"Mismatched revision id and expected: {result.revision_id!r}, {expected_revision_id!r}") + raise ValueError( + f"Mismatched revision id and expected: {result.revision_id!r}, {expected_revision_id!r}" + ) return result @classmethod - def from_inventory(klass, chk_store, inventory, maximum_size=0, search_key_name=b'plain'): + def from_inventory( + klass, chk_store, inventory, maximum_size=0, search_key_name=b"plain" + ): """Create a CHKInventory from an existing inventory. 
The content of inventory is copied into the chk_store, and a @@ -1316,38 +1405,51 @@ def from_inventory(klass, chk_store, inventory, maximum_size=0, search_key_name= id_to_entry_dict = {} parent_id_basename_dict = {} for _path, entry in inventory.iter_entries(): - key = StaticTuple(entry.file_id,).intern() + key = StaticTuple( + entry.file_id, + ).intern() id_to_entry_dict[key] = _chk_inventory_entry_to_bytes(entry) p_id_key = parent_id_basename_key(entry) parent_id_basename_dict[p_id_key] = entry.file_id - result._populate_from_dicts(chk_store, id_to_entry_dict, - parent_id_basename_dict, maximum_size=maximum_size) + result._populate_from_dicts( + chk_store, + id_to_entry_dict, + parent_id_basename_dict, + maximum_size=maximum_size, + ) return result - def _populate_from_dicts(self, chk_store, id_to_entry_dict, - parent_id_basename_dict, maximum_size): - search_key_func = chk_map.search_key_registry.get( - self._search_key_name) - root_key = chk_map.CHKMap.from_dict(chk_store, id_to_entry_dict, - maximum_size=maximum_size, key_width=1, - search_key_func=search_key_func) - self.id_to_entry = chk_map.CHKMap(chk_store, root_key, - search_key_func) - root_key = chk_map.CHKMap.from_dict(chk_store, - parent_id_basename_dict, - maximum_size=maximum_size, key_width=2, - search_key_func=search_key_func) - self.parent_id_basename_to_file_id = chk_map.CHKMap(chk_store, - root_key, search_key_func) + def _populate_from_dicts( + self, chk_store, id_to_entry_dict, parent_id_basename_dict, maximum_size + ): + search_key_func = chk_map.search_key_registry.get(self._search_key_name) + root_key = chk_map.CHKMap.from_dict( + chk_store, + id_to_entry_dict, + maximum_size=maximum_size, + key_width=1, + search_key_func=search_key_func, + ) + self.id_to_entry = chk_map.CHKMap(chk_store, root_key, search_key_func) + root_key = chk_map.CHKMap.from_dict( + chk_store, + parent_id_basename_dict, + maximum_size=maximum_size, + key_width=2, + search_key_func=search_key_func, + ) + self.parent_id_basename_to_file_id = chk_map.CHKMap( + chk_store, root_key, search_key_func + ) def _parent_id_basename_key(self, entry): """Create a key for a entry in a parent_id_basename_to_file_id index.""" if entry.parent_id is not None: parent_id = entry.parent_id else: - parent_id = b'' - return StaticTuple(parent_id, entry.name.encode('utf8')).intern() + parent_id = b"" + return StaticTuple(parent_id, entry.name.encode("utf8")).intern() def get_entry(self, file_id): """Map a single file_id -> InventoryEntry.""" @@ -1358,7 +1460,16 @@ def get_entry(self, file_id): return result try: return self._bytes_to_entry( - next(self.id_to_entry.iteritems([StaticTuple(file_id,)]))[1]) + next( + self.id_to_entry.iteritems( + [ + StaticTuple( + file_id, + ) + ] + ) + )[1] + ) except StopIteration as e: # really we're passing an inventory, not a tree... raise errors.NoSuchId(self, file_id) from e @@ -1377,7 +1488,12 @@ def _getitems(self, file_ids): remaining.append(file_id) else: result.append(entry) - file_keys = [StaticTuple(f,).intern() for f in remaining] + file_keys = [ + StaticTuple( + f, + ).intern() + for f in remaining + ] for _file_key, value in self.id_to_entry.iteritems(file_keys): entry = self._bytes_to_entry(value) result.append(entry) @@ -1388,8 +1504,20 @@ def has_id(self, file_id): # Perhaps have an explicit 'contains' method on CHKMap ? 
if self._fileid_to_entry_cache.get(file_id, None) is not None: return True - return len(list( - self.id_to_entry.iteritems([StaticTuple(file_id,)]))) == 1 + return ( + len( + list( + self.id_to_entry.iteritems( + [ + StaticTuple( + file_id, + ) + ] + ) + ) + ) + == 1 + ) def is_root(self, file_id): return file_id == self.root_id @@ -1441,11 +1569,13 @@ def _preload_cache(self): last_parent_id = last_parent_ie = None pid_items = self.parent_id_basename_to_file_id.iteritems() for key, child_file_id in pid_items: - if key == (b'', b''): # This is the root + if key == (b"", b""): # This is the root if child_file_id != self.root_id: - raise ValueError('Data inconsistency detected.' - ' We expected data with key ("","") to match' - f' the root id, but {child_file_id} != {self.root_id}') + raise ValueError( + "Data inconsistency detected." + ' We expected data with key ("","") to match' + f" the root id, but {child_file_id} != {self.root_id}" + ) continue parent_id, basename = key ie = cache[child_file_id] @@ -1453,24 +1583,32 @@ def _preload_cache(self): parent_ie = last_parent_ie else: parent_ie = cache[parent_id] - if parent_ie.kind != 'directory': - raise ValueError('Data inconsistency detected.' - ' An entry in the parent_id_basename_to_file_id map' - f' has parent_id {{{parent_id}}} but the kind of that object' - f' is {parent_ie.kind!r} not "directory"') + if parent_ie.kind != "directory": + raise ValueError( + "Data inconsistency detected." + " An entry in the parent_id_basename_to_file_id map" + f" has parent_id {{{parent_id}}} but the kind of that object" + f' is {parent_ie.kind!r} not "directory"' + ) siblings = self._children_cache.setdefault(parent_ie.file_id, {}) - basename = basename.decode('utf-8') + basename = basename.decode("utf-8") if basename in siblings: existing_ie = siblings[basename] if existing_ie != ie: - raise ValueError('Data inconsistency detected.' - f' Two entries with basename {basename!r} were found' - f' in the parent entry {{{parent_id}}}') + raise ValueError( + "Data inconsistency detected." + f" Two entries with basename {basename!r} were found" + f" in the parent entry {{{parent_id}}}" + ) if basename != ie.name: - raise ValueError('Data inconsistency detected.' - ' In the parent_id_basename_to_file_id map, file_id' - ' {{{}}} is listed as having basename {!r}, but in the' - ' id_to_entry map it is {!r}'.format(child_file_id, basename, ie.name)) + raise ValueError( + "Data inconsistency detected." 
+ " In the parent_id_basename_to_file_id map, file_id" + " {{{}}} is listed as having basename {!r}, but in the" + " id_to_entry map it is {!r}".format( + child_file_id, basename, ie.name + ) + ) siblings[basename] = ie self._fully_cached = True @@ -1484,8 +1622,9 @@ def iter_changes(self, basis): # We want: (file_id, (path_in_source, path_in_target), # changed_content, versioned, parent, name, kind, # executable) - for key, basis_value, self_value in \ - self.id_to_entry.iter_changes(basis.id_to_entry): + for key, basis_value, self_value in self.id_to_entry.iter_changes( + basis.id_to_entry + ): file_id = key[0] if basis_value is not None: basis_entry = basis._bytes_to_entry(basis_value) @@ -1523,30 +1662,40 @@ def iter_changes(self, basis): changed_content = False if kind[0] != kind[1]: changed_content = True - elif kind[0] == 'file': - if (self_entry.text_size != basis_entry.text_size - or self_entry.text_sha1 != basis_entry.text_sha1): + elif kind[0] == "file": + if ( + self_entry.text_size != basis_entry.text_size + or self_entry.text_sha1 != basis_entry.text_sha1 + ): changed_content = True - elif kind[0] == 'symlink': + elif kind[0] == "symlink": if self_entry.symlink_target != basis_entry.symlink_target: changed_content = True - elif kind[0] == 'tree-reference': - if (self_entry.reference_revision - != basis_entry.reference_revision): + elif kind[0] == "tree-reference": + if self_entry.reference_revision != basis_entry.reference_revision: changed_content = True parent = (basis_parent, self_parent) name = (basis_name, self_name) executable = (basis_executable, self_executable) - if (not changed_content and - parent[0] == parent[1] and - name[0] == name[1] and - executable[0] == executable[1]): + if ( + not changed_content + and parent[0] == parent[1] + and name[0] == name[1] + and executable[0] == executable[1] + ): # Could happen when only the revision changed for a directory # for instance. continue yield ( - file_id, (path_in_source, path_in_target), changed_content, - versioned, parent, name, kind, executable) + file_id, + (path_in_source, path_in_target), + changed_content, + versioned, + parent, + name, + kind, + executable, + ) def __len__(self): """Return the number of entries in the inventory.""" @@ -1557,8 +1706,9 @@ def _make_delta(self, old: CommonInventory): if not isinstance(old, CHKInventory): return CommonInventory._make_delta(self, old) delta = [] - for key, old_value, self_value in \ - self.id_to_entry.iter_changes(old.id_to_entry): + for key, old_value, self_value in self.id_to_entry.iter_changes( + old.id_to_entry + ): file_id = key[0] if old_value is not None: old_path = old.id2path(file_id) @@ -1573,6 +1723,7 @@ def _make_delta(self, old: CommonInventory): new_path = None delta.append((old_path, new_path, file_id, entry)) from .inventory_delta import InventoryDelta + return InventoryDelta(delta) def path2id(self, relpath): @@ -1597,17 +1748,19 @@ def path2id(self, relpath): if cur_path is None: cur_path = basename else: - cur_path = cur_path + '/' + basename - basename_utf8 = basename.encode('utf8') + cur_path = cur_path + "/" + basename + basename_utf8 = basename.encode("utf8") file_id = self._path_to_fileid_cache.get(cur_path, None) if file_id is None: key_filter = [StaticTuple(current_id, basename_utf8)] items = parent_id_index.iteritems(key_filter) for (parent_id, name_utf8), file_id in items: # noqa: B007 if parent_id != current_id or name_utf8 != basename_utf8: - raise errors.BzrError("corrupt inventory lookup! 
" - "{!r} {!r} {!r} {!r}".format(parent_id, current_id, name_utf8, - basename_utf8)) + raise errors.BzrError( + "corrupt inventory lookup! " "{!r} {!r} {!r} {!r}".format( + parent_id, current_id, name_utf8, basename_utf8 + ) + ) if file_id is None: return None else: @@ -1618,21 +1771,24 @@ def path2id(self, relpath): def to_lines(self): """Serialise the inventory to lines.""" lines = [b"chkinventory:\n"] - if self._search_key_name != b'plain': + if self._search_key_name != b"plain": # custom ordering grouping things that don't change together - lines.append(b'search_key_name: %s\n' % ( - self._search_key_name)) + lines.append(b"search_key_name: %s\n" % (self._search_key_name)) lines.append(b"root_id: %s\n" % self.root_id) - lines.append(b'parent_id_basename_to_file_id: %s\n' % - (self.parent_id_basename_to_file_id.key()[0],)) + lines.append( + b"parent_id_basename_to_file_id: %s\n" + % (self.parent_id_basename_to_file_id.key()[0],) + ) lines.append(b"revision_id: %s\n" % self.revision_id) lines.append(b"id_to_entry: %s\n" % (self.id_to_entry.key()[0],)) else: lines.append(b"revision_id: %s\n" % self.revision_id) lines.append(b"root_id: %s\n" % self.root_id) if self.parent_id_basename_to_file_id is not None: - lines.append(b'parent_id_basename_to_file_id: %s\n' % - (self.parent_id_basename_to_file_id.key()[0],)) + lines.append( + b"parent_id_basename_to_file_id: %s\n" + % (self.parent_id_basename_to_file_id.key()[0],) + ) lines.append(b"id_to_entry: %s\n" % (self.id_to_entry.key()[0],)) return lines @@ -1643,10 +1799,10 @@ def root(self): entry_factory = { - 'directory': InventoryDirectory, - 'file': InventoryFile, - 'symlink': InventoryLink, - 'tree-reference': TreeReference + "directory": InventoryDirectory, + "file": InventoryFile, + "symlink": InventoryLink, + "tree-reference": TreeReference, } @@ -1671,6 +1827,7 @@ def make_entry(kind, name, parent_id, file_id=None): ensure_normalized_name = _mod_inventory_rs.ensure_normalized_name is_valid_name = _mod_inventory_rs.is_valid_name + def mutable_inventory_from_tree(tree): """Create a new inventory that has the same contents as a specified tree. @@ -1683,6 +1840,8 @@ def mutable_inventory_from_tree(tree): return inv -chk_inventory_bytes_to_utf8name_key = _mod_inventory_rs.chk_inventory_bytes_to_utf8name_key +chk_inventory_bytes_to_utf8name_key = ( + _mod_inventory_rs.chk_inventory_bytes_to_utf8name_key +) _chk_inventory_bytes_to_entry = _mod_inventory_rs.chk_inventory_bytes_to_entry _chk_inventory_entry_to_bytes = _mod_inventory_rs.chk_inventory_entry_to_bytes diff --git a/breezy/bzr/inventory_delta.py b/breezy/bzr/inventory_delta.py index 7b43c5d7eb..e9d6cd32d0 100644 --- a/breezy/bzr/inventory_delta.py +++ b/breezy/bzr/inventory_delta.py @@ -22,7 +22,7 @@ - InventoryDeltaSerializer - object to read/write inventory deltas. """ -__all__ = ['InventoryDeltaSerializer'] +__all__ = ["InventoryDeltaSerializer"] from .._bzr_rs import inventory as _inventory_delta_rs @@ -62,7 +62,12 @@ def delta_to_lines(self, old_name, new_name, delta_to_new): :return: The serialized delta as lines. 
""" return _inventory_delta_rs.serialize_inventory_delta( - old_name, new_name, delta_to_new, self._versioned_root, self._tree_references) + old_name, + new_name, + delta_to_new, + self._versioned_root, + self._tree_references, + ) class InventoryDeltaDeserializer: @@ -90,4 +95,6 @@ def parse_text_bytes(self, lines): :return: (parent_id, new_id, versioned_root, tree_references, inventory_delta) """ - return _inventory_delta_rs.parse_inventory_delta(lines, self._allow_versioned_root, self._allow_tree_references) + return _inventory_delta_rs.parse_inventory_delta( + lines, self._allow_versioned_root, self._allow_tree_references + ) diff --git a/breezy/bzr/inventorytree.py b/breezy/bzr/inventorytree.py index 967b456455..23acbe44b2 100644 --- a/breezy/bzr/inventorytree.py +++ b/breezy/bzr/inventorytree.py @@ -29,14 +29,17 @@ from ..revisiontree import RevisionTree from ..transport.local import file_kind, file_stat -lazy_import.lazy_import(globals(), """ +lazy_import.lazy_import( + globals(), + """ from breezy import ( add, ) from breezy.bzr import ( inventory as _mod_inventory, ) -""") +""", +) from ..tree import ( FileTimestampUnavailable, InterTree, @@ -48,23 +51,47 @@ class InventoryTreeChange(TreeChange): - - __slots__ = TreeChange.__slots__ + ['file_id', 'parent_id'] - - def __init__(self, file_id, path, changed_content, versioned, parent_id, - name, kind, executable, copied=False): + __slots__ = TreeChange.__slots__ + ["file_id", "parent_id"] + + def __init__( + self, + file_id, + path, + changed_content, + versioned, + parent_id, + name, + kind, + executable, + copied=False, + ): self.file_id = file_id self.parent_id = parent_id super().__init__( - path=path, changed_content=changed_content, versioned=versioned, - name=name, kind=kind, executable=executable, copied=copied) + path=path, + changed_content=changed_content, + versioned=versioned, + name=name, + kind=kind, + executable=executable, + copied=copied, + ) def __repr__(self): return f"{self.__class__.__name__}{self._as_tuple()!r}" def _as_tuple(self): - return (self.file_id, self.path, self.changed_content, self.versioned, - self.parent_id, self.name, self.kind, self.executable, self.copied) + return ( + self.file_id, + self.path, + self.changed_content, + self.versioned, + self.parent_id, + self.name, + self.kind, + self.executable, + self.copied, + ) def __eq__(self, other): if isinstance(other, TreeChange): @@ -78,7 +105,7 @@ def __lt__(self, other): def meta_modified(self): if self.versioned == (True, True): - return (self.executable[0] != self.executable[1]) + return self.executable[0] != self.executable[1] return False def is_reparented(self): @@ -87,18 +114,24 @@ def is_reparented(self): @property def renamed(self): return ( - not self.copied and - None not in self.name and - None not in self.parent_id and - (self.name[0] != self.name[1] or self.parent_id[0] != self.parent_id[1])) + not self.copied + and None not in self.name + and None not in self.parent_id + and (self.name[0] != self.name[1] or self.parent_id[0] != self.parent_id[1]) + ) def discard_new(self): return self.__class__( - self.file_id, (self.path[0], None), self.changed_content, - (self.versioned[0], None), (self.parent_id[0], None), - (self.name[0], None), (self.kind[0], None), + self.file_id, + (self.path[0], None), + self.changed_content, + (self.versioned[0], None), + (self.parent_id[0], None), + (self.name[0], None), + (self.kind[0], None), (self.executable[0], None), - copied=False) + copied=False, + ) def _filesize(f) -> int: @@ -125,13 +158,12 @@ def 
supports_symlinks(self): @classmethod def is_special_path(cls, path): - return path.startswith('.bzr') + return path.startswith(".bzr") def _get_root_inventory(self): return self._inventory - root_inventory = property(_get_root_inventory, - doc="Root inventory of this tree") + root_inventory = property(_get_root_inventory, doc="Root inventory of this tree") supports_file_ids = True @@ -143,13 +175,13 @@ def _unpack_file_id(self, file_id): """ if isinstance(file_id, tuple): if len(file_id) != 1: - raise ValueError( - f"nested trees not yet supported: {file_id!r}") + raise ValueError(f"nested trees not yet supported: {file_id!r}") file_id = file_id[0] return self.root_inventory, file_id - def find_related_paths_across_trees(self, paths, trees=None, - require_versioned=True): + def find_related_paths_across_trees( + self, paths, trees=None, require_versioned=True + ): """Find related paths in tree corresponding to specified filenames in any of `lookup_trees`. @@ -168,8 +200,7 @@ def find_related_paths_across_trees(self, paths, trees=None, trees = [] if paths is None: return None - file_ids = self.paths2ids( - paths, trees, require_versioned=require_versioned) + file_ids = self.paths2ids(paths, trees, require_versioned=require_versioned) ret = set() for file_id in file_ids: try: @@ -204,7 +235,6 @@ def path2id(self, path): with self.lock_read(): return self._path2inv_file_id(path)[1] - def is_versioned(self, path): return self.path2id(path) is not None @@ -229,7 +259,9 @@ def _path2inv_ie(self, path): while remaining: ie, base, remaining = inv.get_entry_by_path_partial(remaining) if remaining: - inv = self._get_nested_tree('/'.join(base), ie.file_id, ie.reference_revision).root_inventory + inv = self._get_nested_tree( + "/".join(base), ie.file_id, ie.reference_revision + ).root_inventory if ie is None: return None, None return inv, ie @@ -245,7 +277,7 @@ def _path2inv_file_id(self, path): return None, None return inv, ie.file_id - def id2path(self, file_id, recurse='down'): + def id2path(self, file_id, recurse="down"): """Return the path for a file id. :raises NoSuchId: @@ -254,10 +286,11 @@ def id2path(self, file_id, recurse='down'): try: return inventory.id2path(file_id) except errors.NoSuchId as e: - if recurse == 'down': - if debug.debug_flag_enabled('evil'): + if recurse == "down": + if debug.debug_flag_enabled("evil"): trace.mutter_callsite( - 2, "id2path with nested trees scales with tree size.") + 2, "id2path with nested trees scales with tree size." + ) for path in self.iter_references(): subtree = self.get_nested_tree(path) try: @@ -272,8 +305,7 @@ def all_file_ids(self): def all_versioned_paths(self): return {path for path, entry in self.iter_entries_by_dir()} - def iter_entries_by_dir(self, specific_files=None, - recurse_nested=False): + def iter_entries_by_dir(self, specific_files=None, recurse_nested=False): """Walk the tree in 'by_dir' order. This will yield each entry in the tree as a (path, entry) tuple. 
@@ -287,15 +319,22 @@ def iter_entries_by_dir(self, specific_files=None, for path in specific_files: inventory, inv_file_id = self._path2inv_file_id(path) if inventory and inventory is not self.root_inventory: - raise AssertionError(f"{inventory!r} != {self.root_inventory!r}") + raise AssertionError( + f"{inventory!r} != {self.root_inventory!r}" + ) inventory_file_ids.add(inv_file_id) else: inventory_file_ids = None + def iter_entries(inv): - for p, e in inv.iter_entries_by_dir(specific_file_ids=inventory_file_ids): - if e.kind == 'tree-reference' and recurse_nested: + for p, e in inv.iter_entries_by_dir( + specific_file_ids=inventory_file_ids + ): + if e.kind == "tree-reference" and recurse_nested: try: - subtree = self._get_nested_tree(p, e.file_id, e.reference_revision) + subtree = self._get_nested_tree( + p, e.file_id, e.reference_revision + ) except errors.NotBranchError: yield p, e else: @@ -305,21 +344,24 @@ def iter_entries(inv): yield (osutils.pathjoin(p, subp) if subp else p), e else: yield p, e + return iter_entries(self.root_inventory) def _get_plan_merge_data(self, path, other, base): from . import versionedfile + file_id = self.path2id(path) vf = versionedfile._PlanMergeVersionedFile(file_id) - last_revision_a = self._get_file_revision( - path, file_id, vf, b'this:') + last_revision_a = self._get_file_revision(path, file_id, vf, b"this:") last_revision_b = other._get_file_revision( - other.id2path(file_id), file_id, vf, b'other:') + other.id2path(file_id), file_id, vf, b"other:" + ) if base is None: last_revision_base = None else: last_revision_base = base._get_file_revision( - base.id2path(file_id), file_id, vf, b'base:') + base.id2path(file_id), file_id, vf, b"base:" + ) return vf, last_revision_a, last_revision_b, last_revision_base def plan_file_merge(self, path, other, base=None): @@ -332,8 +374,7 @@ def plan_file_merge(self, path, other, base=None): """ data = self._get_plan_merge_data(path, other, base) vf, last_revision_a, last_revision_b, last_revision_base = data - return vf.plan_merge(last_revision_a, last_revision_b, - last_revision_base) + return vf.plan_merge(last_revision_a, last_revision_b, last_revision_base) def plan_file_lca_merge(self, path, other, base=None): """Generate a merge plan based lca-newness. @@ -345,8 +386,7 @@ def plan_file_lca_merge(self, path, other, base=None): """ data = self._get_plan_merge_data(path, other, base) vf, last_revision_a, last_revision_b, last_revision_base = data - return vf.plan_lca_merge(last_revision_a, last_revision_b, - last_revision_base) + return vf.plan_lca_merge(last_revision_a, last_revision_b, last_revision_base) def _iter_parent_trees(self): """Iterate through parent trees, defaulting to Tree.revision_tree.""" @@ -359,14 +399,17 @@ def _iter_parent_trees(self): def _get_file_revision(self, path, file_id, vf, tree_revision): """Ensure that file_id, tree_revision is in vf to plan the merge.""" from . 
import versionedfile + last_revision = tree_revision parent_keys = [ - (file_id, t.get_file_revision(path)) for t in - self._iter_parent_trees()] + (file_id, t.get_file_revision(path)) for t in self._iter_parent_trees() + ] with self.get_file(path) as f: vf.add_content( versionedfile.FileContentFactory( - (file_id, last_revision), parent_keys, f, size=_filesize(f))) + (file_id, last_revision), parent_keys, f, size=_filesize(f) + ) + ) repo = self.branch.repository base_vf = repo.texts if base_vf not in vf.fallback_versionedfiles: @@ -375,6 +418,7 @@ def _get_file_revision(self, path, file_id, vf, tree_revision): def preview_transform(self, pb=None): from .transform import TransformPreview + return TransformPreview(self, pb=pb) @@ -393,8 +437,7 @@ def find_ids_across_trees(filenames, trees, require_versioned=True): """ if not filenames: return None - specified_path_ids = _find_ids_across_trees(filenames, trees, - require_versioned) + specified_path_ids = _find_ids_across_trees(filenames, trees, require_versioned) return _find_children_across_trees(specified_path_ids, trees) @@ -457,7 +500,6 @@ def _find_children_across_trees(specified_ids, trees): class MutableInventoryTree(MutableTree, InventoryTree): - def apply_inventory_delta(self, changes): """Apply changes to the inventory as an atomic operation. @@ -467,6 +509,7 @@ def apply_inventory_delta(self, changes): :seealso Inventory.apply_delta: For details on the changes parameter. """ from .inventory_delta import InventoryDelta + with self.lock_tree_write(): self.flush() inv = self.root_inventory @@ -510,8 +553,8 @@ def has_changes(self, _from_tree=None): # working copy as compared to the repository. # Also, exclude root as mention in the above fast path. changes = filter( - lambda c: c[6][0] != 'symlink' and c[4] != (None, None), - changes) + lambda c: c[6][0] != "symlink" and c[4] != (None, None), changes + ) try: next(iter(changes)) except StopIteration: @@ -546,7 +589,7 @@ def smart_add(self, file_list, recurse=True, action=None, save=True): """ with self.lock_tree_write(): # Not all mutable trees can have conflicts - if getattr(self, 'conflicts', None) is not None: + if getattr(self, "conflicts", None) is not None: # Collect all related files without checking whether they exist or # are versioned. It's cheaper to do that once for all conflicts # than trying to find the relevant conflict for each added file. 
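
The apply_inventory_delta docstring above points at Inventory.apply_delta, and the update_basis_by_delta hunk below applies one of these deltas directly; the delta format in both places is a sequence of (old_path, new_path, file_id, new_entry) tuples, with None marking a missing side for pure adds or deletes. A hedged sketch of a single-file rename expressed as such a delta, assuming a plain list is acceptable wherever an InventoryDelta is and reusing the invented ids from the earlier sketch:

# Illustrative sketch only; file ids and the plain-list delta are assumptions.
from breezy.bzr.inventory import Inventory, make_entry

inv = Inventory()
inv.add(make_entry("directory", "src", inv.root.file_id, b"src-id"))
inv.add(make_entry("file", "hello.c", b"src-id", b"hello-id"))

# Rename src/hello.c to src/goodbye.c: same file id, new entry with the new name.
renamed = make_entry("file", "goodbye.c", b"src-id", b"hello-id")
delta = [("src/hello.c", "src/goodbye.c", b"hello-id", renamed)]

inv.apply_delta(delta)
assert inv.id2path(b"hello-id") == "src/goodbye.c"
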
@@ -600,12 +643,12 @@ def update_basis_by_delta(self, new_revid, delta): # it only makes sense when apply_delta is cheaper than get_inventory() inventory = _mod_inventory.mutable_inventory_from_tree(basis) inventory.apply_delta(delta) - rev_tree = InventoryRevisionTree(self.branch.repository, - inventory, new_revid) + rev_tree = InventoryRevisionTree(self.branch.repository, inventory, new_revid) self.set_parent_trees([(new_revid, rev_tree)]) def transform(self, pb=None): from .transform import InventoryTreeTransform + return InventoryTreeTransform(self, pb=pb) def add(self, files, kinds=None, ids=None): @@ -640,7 +683,7 @@ def add(self, files, kinds=None, ids=None): if kinds is not None: kinds = [kinds] - files = [path.strip('/') for path in files] + files = [path.strip("/") for path in files] if ids is None: ids = [None] * len(files) @@ -698,8 +741,7 @@ def _get_ie(self, inv_path): # Find a 'best fit' match if the filesystem is case-insensitive inv_path = self.tree._fix_case_of_inventory_path(inv_path) try: - return next(self.tree.iter_entries_by_dir( - specific_files=[inv_path]))[1] + return next(self.tree.iter_entries_by_dir(specific_files=[inv_path]))[1] except StopIteration: return None @@ -713,9 +755,9 @@ def _convert_to_directory(self, this_ie, inv_path): # Same as in _add_one below, if the inventory doesn't # think this is a directory, update the inventory this_ie = _mod_inventory.InventoryDirectory( - this_ie.file_id, this_ie.name, this_ie.parent_id) - self._invdelta[inv_path] = (inv_path, inv_path, this_ie.file_id, - this_ie) + this_ie.file_id, this_ie.name, this_ie.parent_id + ) + self._invdelta[inv_path] = (inv_path, inv_path, this_ie.file_id, this_ie) return this_ie def _add_one_and_parent(self, parent_ie, path, kind, inv_path): @@ -743,19 +785,20 @@ def _add_one_and_parent(self, parent_ie, path, kind, inv_path): # note that the dirname use leads to some extra str copying etc but as # there are a limited number of dirs we can be nested under, it should # generally find it very fast and not recurse after that. - parent_ie = self._add_one_and_parent(None, - dirname, 'directory', - inv_dirname) + parent_ie = self._add_one_and_parent( + None, dirname, "directory", inv_dirname + ) # if the parent exists, but isn't a directory, we have to do the # kind change now -- really the inventory shouldn't pretend to know # the kind of wt files, but it does. - if parent_ie.kind != 'directory': + if parent_ie.kind != "directory": # nb: this relies on someone else checking that the path we're using # doesn't contain symlinks. parent_ie = self._convert_to_directory(parent_ie, inv_dirname) file_id = self.action(self.tree, parent_ie, path, kind) - entry = _mod_inventory.make_entry(kind, basename, parent_ie.file_id, - file_id=file_id) + entry = _mod_inventory.make_entry( + kind, basename, parent_ie.file_id, file_id=file_id + ) self._invdelta[inv_path] = (None, inv_path, entry.file_id, entry) self.added.append(inv_path) return entry @@ -767,7 +810,7 @@ def _gather_dirs_to_add(self, user_dirs): is_inside = osutils.is_inside_or_parent_of_any for path in sorted(user_dirs): - if (prev_dir is None or not is_inside([prev_dir], path)): + if prev_dir is None or not is_inside([prev_dir], path): inv_path, this_ie = user_dirs[path] yield (path, inv_path, this_ie, None) prev_dir = path @@ -791,7 +834,7 @@ def add(self, file_list, recurse=True): # no paths supplied: add the entire tree. 
# FIXME: this assumes we are running in a working tree subdir :-/ # -- vila 20100208 - file_list = ['.'] + file_list = ["."] # expand any symlinks in the directory part, while leaving the # filename alone @@ -819,9 +862,8 @@ def add(self, file_list, recurse=True): inv_path, _ = osutils.normalized_filename(filepath) this_ie = self._get_ie(inv_path) if this_ie is None: - this_ie = self._add_one_and_parent( - None, filepath, kind, inv_path) - if kind == 'directory': + this_ie = self._add_one_and_parent(None, filepath, kind, inv_path) + if kind == "directory": # schedule the dir for scanning user_dirs[filepath] = (inv_path, this_ie) @@ -831,7 +873,7 @@ def add(self, file_list, recurse=True): things_to_add = list(self._gather_dirs_to_add(user_dirs)) - illegalpath_re = re.compile(r'[\r\n]') + illegalpath_re = re.compile(r"[\r\n]") for directory, inv_path, this_ie, parent_ie in things_to_add: # directory is tree-relative abspath = self.tree.abspath(directory) @@ -851,8 +893,9 @@ def add(self, file_list, recurse=True): if self.action.skip_file(self.tree, abspath, kind, stat_value): continue if not _mod_inventory.InventoryEntry.versionable_kind(kind): - trace.warning("skipping %s (can't add file of kind '%s')", - abspath, kind) + trace.warning( + "skipping %s (can't add file of kind '%s')", abspath, kind + ) continue if illegalpath_re.search(directory): trace.warning(f"skipping {abspath!r} (contains \\n or \\r)") @@ -861,11 +904,11 @@ def add(self, file_list, recurse=True): # If the file looks like one generated for a conflict, don't # add it. trace.warning( - 'skipping %s (generated to help resolve conflicts)', - abspath) + "skipping %s (generated to help resolve conflicts)", abspath + ) continue - if kind == 'directory' and directory != '': + if kind == "directory" and directory != "": try: transport = _mod_transport.get_transport_from_path(abspath) controldir.ControlDirFormat.find_format(transport) @@ -890,11 +933,10 @@ def add(self, file_list, recurse=True): # 20070306 trace.warning("skipping nested tree %r", abspath) else: - this_ie = self._add_one_and_parent(parent_ie, directory, kind, - inv_path) + this_ie = self._add_one_and_parent(parent_ie, directory, kind, inv_path) - if kind == 'directory' and not sub_tree: - if this_ie.kind != 'directory': + if kind == "directory" and not sub_tree: + if this_ie.kind != "directory": this_ie = self._convert_to_directory(this_ie, inv_path) for subf in sorted(os.listdir(abspath)): @@ -924,15 +966,12 @@ def add(self, file_list, recurse=True): # outer loop we would ignore user files. 
ignore_glob = self.tree.is_ignored(subp) if ignore_glob is not None: - self.ignored.setdefault( - ignore_glob, []).append(subp) + self.ignored.setdefault(ignore_glob, []).append(subp) else: - things_to_add.append( - (subp, sub_invp, None, this_ie)) + things_to_add.append((subp, sub_invp, None, this_ie)) class InventoryRevisionTree(RevisionTree, InventoryTree): - def __init__(self, repository, inv, revision_id): RevisionTree.__init__(self, repository, revision_id) self._inventory = inv @@ -980,19 +1019,21 @@ def reference_parent(self, path, branch=None, possible_transports=None): parent_url = branch.get_reference_info(file_id)[0] else: subdir = ControlDir.open_from_transport( - self._repository.user_transport.clone(path)) + self._repository.user_transport.clone(path) + ) parent_url = subdir.open_branch().get_parent() if parent_url is None: return None return _mod_branch.Branch.open( - parent_url, - possible_transports=possible_transports) + parent_url, possible_transports=possible_transports + ) def get_reference_info(self, path, branch=None): return branch.get_reference_info(self.path2id(path))[0] - def list_files(self, include_root=False, from_dir=None, recursive=True, - recurse_nested=False): + def list_files( + self, include_root=False, from_dir=None, recursive=True, recurse_nested=False + ): # The only files returned by this are those from the version if from_dir is None: from_dir_id = None @@ -1007,32 +1048,35 @@ def list_files(self, include_root=False, from_dir=None, recursive=True, # skip the root for compatibility with the current apis. next(entries) for path, entry in entries: - if entry.kind == 'tree-reference' and recurse_nested: + if entry.kind == "tree-reference" and recurse_nested: subtree = self._get_nested_tree( - path, entry.file_id, entry.reference_revision) + path, entry.file_id, entry.reference_revision + ) for subpath, status, kind, entry in subtree.list_files( - include_root=True, recurse_nested=recurse_nested, - recursive=recursive): + include_root=True, + recurse_nested=recurse_nested, + recursive=recursive, + ): if subpath: full_subpath = osutils.pathjoin(path, subpath) else: full_subpath = path yield full_subpath, status, kind, entry else: - yield path, 'V', entry.kind, entry + yield path, "V", entry.kind, entry def iter_child_entries(self, path): inv, ie = self._path2inv_ie(path) if ie is None: raise _mod_transport.NoSuchFile(path) - if ie.kind != 'directory': + if ie.kind != "directory": raise errors.NotADirectory(path) return inv.iter_sorted_children(ie.file_id) def get_symlink_target(self, path): # Inventories store symlink targets in unicode ie = self._path2ie(path) - if ie.kind != 'symlink': + if ie.kind != "symlink": return None return ie.symlink_target @@ -1043,7 +1087,8 @@ def _get_nested_tree(self, path, file_id, reference_revision): # Just a guess.. 
try: subdir = ControlDir.open_from_transport( - self._repository.user_transport.clone(path)) + self._repository.user_transport.clone(path) + ) except errors.NotBranchError as e: raise MissingNestedTree(path) from e subrepo = subdir.find_repository() @@ -1051,9 +1096,10 @@ def _get_nested_tree(self, path, file_id, reference_revision): revtree = subrepo.revision_tree(reference_revision) except errors.NoSuchRevision as e: raise MissingNestedTree(path) from e - if file_id is not None and file_id != revtree.path2id(''): - raise AssertionError('invalid root id: {!r} != {!r}'.format( - file_id, revtree.path2id(''))) + if file_id is not None and file_id != revtree.path2id(""): + raise AssertionError( + "invalid root id: {!r} != {!r}".format(file_id, revtree.path2id("")) + ) return revtree def get_nested_tree(self, path): @@ -1068,11 +1114,11 @@ def path_content_summary(self, path): try: entry = self._path2ie(path) except _mod_transport.NoSuchFile: - return ('missing', None, None, None) + return ("missing", None, None, None) kind = entry.kind - if kind == 'file': + if kind == "file": return (kind, entry.text_size, entry.executable, entry.text_sha1) - elif kind == 'symlink': + elif kind == "symlink": return (kind, None, None, entry.symlink_target) else: return (kind, None, None, None) @@ -1083,7 +1129,7 @@ def _comparison_data(self, entry, path): return entry.kind, entry.executable, None def walkdirs(self, prefix=""): - _directory = 'directory' + _directory = "directory" inv, top_id = self._path2inv_file_id(prefix) if top_id is None: pending = [] @@ -1093,7 +1139,7 @@ def walkdirs(self, prefix=""): dirblock = [] root, file_id = pending.pop() if root: - relroot = root + '/' + relroot = root + "/" else: relroot = "" # FIXME: stash the node in pending @@ -1112,8 +1158,9 @@ def iter_files_bytes(self, desired_files): This version is implemented on top of Repository.iter_files_bytes """ - repo_desired_files = [(self.path2id(f), self.get_file_revision(f), i) - for f, i in desired_files] + repo_desired_files = [ + (self.path2id(f), self.get_file_revision(f), i) for f, i in desired_files + ] try: yield from self._repository.iter_files_bytes(repo_desired_files) except errors.RevisionNotPresent as e: @@ -1131,27 +1178,28 @@ def __eq__(self, other): if self is other: return True if isinstance(other, InventoryRevisionTree): - return (self.root_inventory == other.root_inventory) + return self.root_inventory == other.root_inventory return False def __ne__(self, other): return not (self == other) def __hash__(self): - raise ValueError('not hashable') + raise ValueError("not hashable") class InterInventoryTree(InterTree): """InterTree implementation for InventoryTree objects.""" + @classmethod def is_compatible(kls, source, target): # The default implementation is naive and uses the public API, so # it works for all trees. - return (isinstance(source, InventoryTree) and - isinstance(target, InventoryTree)) + return isinstance(source, InventoryTree) and isinstance(target, InventoryTree) - def _changes_from_entries(self, source_entry, target_entry, source_path, - target_path): + def _changes_from_entries( + self, source_entry, target_entry, source_path, target_path + ): """Generate a iter_changes tuple between source_entry and target_entry. :param source_entry: An inventory entry from self.source, or None. 
@@ -1171,8 +1219,9 @@ def _changes_from_entries(self, source_entry, target_entry, source_path, source_versioned = True source_name = source_entry.name source_parent = source_entry.parent_id - source_kind, source_executable, source_stat = \ - self.source._comparison_data(source_entry, source_path) + source_kind, source_executable, source_stat = self.source._comparison_data( + source_entry, source_path + ) else: source_versioned = False source_name = None @@ -1183,8 +1232,9 @@ def _changes_from_entries(self, source_entry, target_entry, source_path, target_versioned = True target_name = target_entry.name target_parent = target_entry.parent_id - target_kind, target_executable, target_stat = \ - self.target._comparison_data(target_entry, target_path) + target_kind, target_executable, target_stat = self.target._comparison_data( + target_entry, target_path + ) else: target_versioned = False target_name = None @@ -1196,35 +1246,54 @@ def _changes_from_entries(self, source_entry, target_entry, source_path, changed_content = False if source_kind != target_kind: changed_content = True - elif source_kind == 'file': + elif source_kind == "file": if not self.file_content_matches( - source_path, target_path, - source_stat, target_stat): + source_path, target_path, source_stat, target_stat + ): changed_content = True - elif source_kind == 'symlink': - if (self.source.get_symlink_target(source_path) != - self.target.get_symlink_target(target_path)): + elif source_kind == "symlink": + if self.source.get_symlink_target( + source_path + ) != self.target.get_symlink_target(target_path): changed_content = True - elif source_kind == 'tree-reference': - if (self.source.get_reference_revision(source_path) - != self.target.get_reference_revision(target_path)): + elif source_kind == "tree-reference": + if self.source.get_reference_revision( + source_path + ) != self.target.get_reference_revision(target_path): changed_content = True parent = (source_parent, target_parent) name = (source_name, target_name) executable = (source_executable, target_executable) - if (changed_content is not False or versioned[0] != versioned[1] or - parent[0] != parent[1] or name[0] != name[1] or - executable[0] != executable[1]): + if ( + changed_content is not False + or versioned[0] != versioned[1] + or parent[0] != parent[1] + or name[0] != name[1] + or executable[0] != executable[1] + ): changes = True else: changes = False return InventoryTreeChange( - file_id, (source_path, target_path), changed_content, - versioned, parent, name, kind, executable), changes - - def iter_changes(self, include_unchanged=False, - specific_files=None, pb=None, extra_trees=None, - require_versioned=True, want_unversioned=False): + file_id, + (source_path, target_path), + changed_content, + versioned, + parent, + name, + kind, + executable, + ), changes + + def iter_changes( + self, + include_unchanged=False, + specific_files=None, + pb=None, + extra_trees=None, + require_versioned=True, + want_unversioned=False, + ): """Generate an iterator of changes between trees. 
A tuple is returned: @@ -1270,11 +1339,15 @@ def iter_changes(self, include_unchanged=False, source_specific_files = [] else: target_specific_files = self.target.find_related_paths_across_trees( - specific_files, [self.source] + extra_trees, - require_versioned=require_versioned) + specific_files, + [self.source] + extra_trees, + require_versioned=require_versioned, + ) source_specific_files = self.source.find_related_paths_across_trees( - specific_files, [self.target] + extra_trees, - require_versioned=require_versioned) + specific_files, + [self.target] + extra_trees, + require_versioned=require_versioned, + ) if specific_files is not None: # reparented or added entries must have their parents included # so that valid deltas can be created. The seen_parents set @@ -1285,19 +1358,25 @@ def iter_changes(self, include_unchanged=False, seen_parents = set() seen_dirs = set() if want_unversioned: - all_unversioned = sorted([(p.split('/'), p) for p in - self.target.extras() - if specific_files is None or - osutils.is_inside_any(specific_files, p)]) + all_unversioned = sorted( + [ + (p.split("/"), p) + for p in self.target.extras() + if specific_files is None + or osutils.is_inside_any(specific_files, p) + ] + ) all_unversioned = deque(all_unversioned) else: all_unversioned = deque() to_paths = {} - from_entries_by_dir = list(self.source.iter_entries_by_dir( - specific_files=source_specific_files)) + from_entries_by_dir = list( + self.source.iter_entries_by_dir(specific_files=source_specific_files) + ) from_data = dict(from_entries_by_dir) - to_entries_by_dir = list(self.target.iter_entries_by_dir( - specific_files=target_specific_files)) + to_entries_by_dir = list( + self.target.iter_entries_by_dir(specific_files=target_specific_files) + ) path_equivs = self.find_source_paths([p for p, e in to_entries_by_dir]) num_entries = len(from_entries_by_dir) + len(to_entries_by_dir) entry_count = 0 @@ -1306,31 +1385,40 @@ def iter_changes(self, include_unchanged=False, # executable it values when execute is not supported. 
fake_entry = TreeFile() for target_path, target_entry in to_entries_by_dir: - while (all_unversioned and - all_unversioned[0][0] < target_path.split('/')): + while all_unversioned and all_unversioned[0][0] < target_path.split("/"): unversioned_path = all_unversioned.popleft() - target_kind, target_executable, target_stat = \ - self.target._comparison_data( - fake_entry, unversioned_path[1]) + ( + target_kind, + target_executable, + target_stat, + ) = self.target._comparison_data(fake_entry, unversioned_path[1]) yield InventoryTreeChange( - None, (None, unversioned_path[1]), True, (False, False), + None, + (None, unversioned_path[1]), + True, + (False, False), (None, None), (None, unversioned_path[0][-1]), (None, target_kind), - (None, target_executable)) + (None, target_executable), + ) source_path = path_equivs[target_path] if source_path is not None: source_entry = from_data.get(source_path) else: source_entry = None result, changes = self._changes_from_entries( - source_entry, target_entry, source_path=source_path, target_path=target_path) + source_entry, + target_entry, + source_path=source_path, + target_path=target_path, + ) to_paths[result.file_id] = result.path[1] entry_count += 1 if result.versioned[0]: entry_count += 1 if pb is not None: - pb.update('comparing files', entry_count, num_entries) + pb.update("comparing files", entry_count, num_entries) if changes or include_unchanged: if specific_files is not None: precise_file_ids.add(result.parent_id[1]) @@ -1339,7 +1427,7 @@ def iter_changes(self, include_unchanged=False, # Ensure correct behaviour for reparented/added specific files. if specific_files is not None: # Record output dirs - if result.kind[1] == 'directory': + if result.kind[1] == "directory": seen_dirs.add(result.file_id) # Record parents of reparented/added entries. if not result.versioned[0] or result.is_reparented(): @@ -1347,14 +1435,19 @@ def iter_changes(self, include_unchanged=False, while all_unversioned: # yield any trailing unversioned paths unversioned_path = all_unversioned.popleft() - to_kind, to_executable, to_stat = \ - self.target._comparison_data(fake_entry, unversioned_path[1]) + to_kind, to_executable, to_stat = self.target._comparison_data( + fake_entry, unversioned_path[1] + ) yield InventoryTreeChange( - None, (None, unversioned_path[1]), True, (False, False), + None, + (None, unversioned_path[1]), + True, + (False, False), (None, None), (None, unversioned_path[0][-1]), (None, to_kind), - (None, to_executable)) + (None, to_executable), + ) # Yield all remaining source paths for path, from_entry in from_entries_by_dir: file_id = from_entry.file_id @@ -1364,24 +1457,31 @@ def iter_changes(self, include_unchanged=False, to_path = self.find_target_path(path) entry_count += 1 if pb is not None: - pb.update('comparing files', entry_count, num_entries) + pb.update("comparing files", entry_count, num_entries) versioned = (True, False) parent = (from_entry.parent_id, None) name = (from_entry.name, None) - from_kind, from_executable, stat_value = \ - self.source._comparison_data(from_entry, path) + from_kind, from_executable, stat_value = self.source._comparison_data( + from_entry, path + ) kind = (from_kind, None) executable = (from_executable, None) changed_content = from_kind is not None # the parent's path is necessarily known at this point. 
changed_file_ids.append(file_id) yield InventoryTreeChange( - file_id, (path, to_path), changed_content, versioned, parent, - name, kind, executable) + file_id, + (path, to_path), + changed_content, + versioned, + parent, + name, + kind, + executable, + ) changed_file_ids = set(changed_file_ids) if specific_files is not None: - for result in self._handle_precise_ids(precise_file_ids, - changed_file_ids): + for result in self._handle_precise_ids(precise_file_ids, changed_file_ids): yield result @staticmethod @@ -1402,8 +1502,9 @@ def _get_entry(tree, path): except StopIteration: return None - def _handle_precise_ids(self, precise_file_ids, changed_file_ids, - discarded_changes=None): + def _handle_precise_ids( + self, precise_file_ids, changed_file_ids, discarded_changes=None + ): """Fill out a partial iter_changes to be consistent. :param precise_file_ids: The file ids of parents that were seen during @@ -1458,39 +1559,37 @@ def _handle_precise_ids(self, precise_file_ids, changed_file_ids, source_path = None source_entry = None else: - source_entry = self._get_entry( - self.source, source_path) + source_entry = self._get_entry(self.source, source_path) try: target_path = self.target.id2path(file_id) except errors.NoSuchId: target_path = None target_entry = None else: - target_entry = self._get_entry( - self.target, target_path) + target_entry = self._get_entry(self.target, target_path) result, changes = self._changes_from_entries( - source_entry, target_entry, source_path, target_path) + source_entry, target_entry, source_path, target_path + ) else: changes = True # Get this parents parent to examine. new_parent_id = result.parent_id[1] precise_file_ids.add(new_parent_id) if changes: - if (result.kind[0] == 'directory' and - result.kind[1] != 'directory'): + if result.kind[0] == "directory" and result.kind[1] != "directory": # This stopped being a directory, the old children have # to be included. if source_entry is None: # Reusing a discarded change. - source_entry = self._get_entry( - self.source, result.path[0]) + source_entry = self._get_entry(self.source, result.path[0]) precise_file_ids.update( child.file_id - for child in self.source.iter_child_entries(result.path[0])) + for child in self.source.iter_child_entries(result.path[0]) + ) changed_file_ids.add(result.file_id) yield result - def find_target_path(self, path, recurse='none'): + def find_target_path(self, path, recurse="none"): """Find target tree path. :param path: Path to search for (exists in source) @@ -1505,7 +1604,7 @@ def find_target_path(self, path, recurse='none'): except errors.NoSuchId: return None - def find_source_path(self, path, recurse='none'): + def find_source_path(self, path, recurse="none"): """Find the source tree path. 
:param path: Path to search for (exists in target) @@ -1529,8 +1628,7 @@ class InterCHKRevisionTree(InterInventoryTree): @staticmethod def is_compatible(source, target): - if (isinstance(source, RevisionTree) and - isinstance(target, RevisionTree)): + if isinstance(source, RevisionTree) and isinstance(target, RevisionTree): try: # Only CHK inventories have id_to_entry attribute source.root_inventory.id_to_entry # noqa: B018 @@ -1540,9 +1638,15 @@ def is_compatible(source, target): pass return False - def iter_changes(self, include_unchanged=False, - specific_files=None, pb=None, extra_trees=None, - require_versioned=True, want_unversioned=False): + def iter_changes( + self, + include_unchanged=False, + specific_files=None, + pb=None, + extra_trees=None, + require_versioned=True, + want_unversioned=False, + ): if extra_trees is None: extra_trees = [] lookup_trees = [self.source] @@ -1554,15 +1658,17 @@ def iter_changes(self, include_unchanged=False, if specific_files == []: specific_file_ids = [] else: - specific_file_ids = self.target.paths2ids(specific_files, - lookup_trees, require_versioned=require_versioned) + specific_file_ids = self.target.paths2ids( + specific_files, lookup_trees, require_versioned=require_versioned + ) # FIXME: It should be possible to delegate include_unchanged handling # to CHKInventory.iter_changes and do a better job there -- vila # 20090304 changed_file_ids = set() # FIXME: nested tree support for result in self.target.root_inventory.iter_changes( - self.source.root_inventory): + self.source.root_inventory + ): result = InventoryTreeChange(*result) if specific_file_ids is not None: if result.file_id not in specific_file_ids: @@ -1575,8 +1681,9 @@ def iter_changes(self, include_unchanged=False, yield result changed_file_ids.add(result.file_id) if specific_file_ids is not None: - for result in self._handle_precise_ids(precise_file_ids, - changed_file_ids, discarded_changes=discarded_changes): + for result in self._handle_precise_ids( + precise_file_ids, changed_file_ids, discarded_changes=discarded_changes + ): yield result if include_unchanged: # CHKMap avoid being O(tree), so we go to O(tree) only if @@ -1586,8 +1693,10 @@ def iter_changes(self, include_unchanged=False, # FIXME: Support nested trees changed_file_ids = set(changed_file_ids) for relpath, entry in self.target.root_inventory.iter_entries(): - if (specific_file_ids is not None and - entry.file_id not in specific_file_ids): + if ( + specific_file_ids is not None + and entry.file_id not in specific_file_ids + ): continue if entry.file_id not in changed_file_ids: yield InventoryTreeChange( @@ -1598,7 +1707,8 @@ def iter_changes(self, include_unchanged=False, (entry.parent_id, entry.parent_id), (entry.name, entry.name), (entry.kind, entry.kind), - (entry.executable, entry.executable)) + (entry.executable, entry.executable), + ) InterTree.register_optimiser(InterCHKRevisionTree) diff --git a/breezy/bzr/knit.py b/breezy/bzr/knit.py index 08897fb6d2..2cb21d1acb 100644 --- a/breezy/bzr/knit.py +++ b/breezy/bzr/knit.py @@ -57,7 +57,9 @@ from ..lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ import gzip from breezy import ( @@ -73,7 +75,8 @@ from breezy.bzr import pack_repo from breezy.i18n import gettext -""") +""", +) from .. import annotate, debug, errors, osutils, trace from .. import transport as _mod_transport from ..bzr.versionedfile import ( @@ -105,18 +108,16 @@ # position after writing to work out where it was located. 
we may need to # bypass python file buffering. -DATA_SUFFIX = '.knit' -INDEX_SUFFIX = '.kndx' +DATA_SUFFIX = ".knit" +INDEX_SUFFIX = ".kndx" _STREAM_MIN_BUFFER_SIZE = 5 * 1024 * 1024 class KnitError(InternalBzrError): - _fmt = "Knit error" class KnitCorrupt(KnitError): - _fmt = "Knit %(filename)s corrupt: %(how)s" def __init__(self, filename, how): @@ -126,10 +127,11 @@ def __init__(self, filename, how): class SHA1KnitCorrupt(KnitCorrupt): - - _fmt = ("Knit %(filename)s corrupt: sha-1 of reconstructed text does not " - "match expected sha-1. key %(key)s expected sha %(expected)s actual " - "sha %(actual)s") + _fmt = ( + "Knit %(filename)s corrupt: sha-1 of reconstructed text does not " + "match expected sha-1. key %(key)s expected sha %(expected)s actual " + "sha %(actual)s" + ) def __init__(self, filename, actual, expected, key, content): KnitError.__init__(self) @@ -144,7 +146,7 @@ class KnitDataStreamIncompatible(KnitError): # Not raised anymore, as we can convert data streams. In future we may # need it again for more exotic cases, so we're keeping it around for now. - _fmt = "Cannot insert knit data stream of format \"%(stream_format)s\" into knit of format \"%(target_format)s\"." + _fmt = 'Cannot insert knit data stream of format "%(stream_format)s" into knit of format "%(target_format)s".' def __init__(self, stream_format, target_format): self.stream_format = stream_format @@ -154,14 +156,13 @@ def __init__(self, stream_format, target_format): class KnitDataStreamUnknown(KnitError): # Indicates a data stream we don't know how to handle. - _fmt = "Cannot parse knit data stream of format \"%(stream_format)s\"." + _fmt = 'Cannot parse knit data stream of format "%(stream_format)s".' def __init__(self, stream_format): self.stream_format = stream_format class KnitHeaderError(KnitError): - _fmt = 'Knit header error: %(badline)r unexpected for file "%(filename)s".' def __init__(self, badline, filename): @@ -176,8 +177,10 @@ class KnitIndexUnknownMethod(KnitError): Currently only 'fulltext' and 'line-delta' are supported. 
""" - _fmt = ("Knit index %(filename)s does not have a known method" - " in options: %(options)r") + _fmt = ( + "Knit index %(filename)s does not have a known method" + " in options: %(options)r" + ) def __init__(self, filename, options): KnitError.__init__(self) @@ -204,33 +207,31 @@ class FTAnnotatedToUnannotated(KnitAdapter): """An adapter from FT annotated knits to unannotated ones.""" def get_bytes(self, factory, target_storage_kind): - if target_storage_kind != 'knit-ft-gz': + if target_storage_kind != "knit-ft-gz": raise UnavailableRepresentation( - factory.key, target_storage_kind, factory.storage_kind) + factory.key, target_storage_kind, factory.storage_kind + ) annotated_compressed_bytes = factory._raw_record - rec, contents = \ - self._data._parse_record_unchecked(annotated_compressed_bytes) + rec, contents = self._data._parse_record_unchecked(annotated_compressed_bytes) content = self._annotate_factory.parse_fulltext(contents, rec[1]) - size, chunks = self._data._record_to_data( - (rec[1],), rec[3], content.text()) - return b''.join(chunks) + size, chunks = self._data._record_to_data((rec[1],), rec[3], content.text()) + return b"".join(chunks) class DeltaAnnotatedToUnannotated(KnitAdapter): """An adapter for deltas from annotated to unannotated.""" def get_bytes(self, factory, target_storage_kind): - if target_storage_kind != 'knit-delta-gz': + if target_storage_kind != "knit-delta-gz": raise UnavailableRepresentation( - factory.key, target_storage_kind, factory.storage_kind) + factory.key, target_storage_kind, factory.storage_kind + ) annotated_compressed_bytes = factory._raw_record - rec, contents = \ - self._data._parse_record_unchecked(annotated_compressed_bytes) - delta = self._annotate_factory.parse_line_delta(contents, rec[1], - plain=True) + rec, contents = self._data._parse_record_unchecked(annotated_compressed_bytes) + delta = self._annotate_factory.parse_line_delta(contents, rec[1], plain=True) contents = self._plain_factory.lower_line_delta(delta) size, chunks = self._data._record_to_data((rec[1],), rec[3], contents) - return b''.join(chunks) + return b"".join(chunks) class FTAnnotatedToFullText(KnitAdapter): @@ -238,16 +239,17 @@ class FTAnnotatedToFullText(KnitAdapter): def get_bytes(self, factory, target_storage_kind): annotated_compressed_bytes = factory._raw_record - rec, contents = \ - self._data._parse_record_unchecked(annotated_compressed_bytes) - content, delta = self._annotate_factory.parse_record(factory.key[-1], - contents, factory._build_details, None) - if target_storage_kind == 'fulltext': - return b''.join(content.text()) - elif target_storage_kind in ('chunked', 'lines'): + rec, contents = self._data._parse_record_unchecked(annotated_compressed_bytes) + content, delta = self._annotate_factory.parse_record( + factory.key[-1], contents, factory._build_details, None + ) + if target_storage_kind == "fulltext": + return b"".join(content.text()) + elif target_storage_kind in ("chunked", "lines"): return content.text() raise UnavailableRepresentation( - factory.key, target_storage_kind, factory.storage_kind) + factory.key, target_storage_kind, factory.storage_kind + ) class DeltaAnnotatedToFullText(KnitAdapter): @@ -255,28 +257,28 @@ class DeltaAnnotatedToFullText(KnitAdapter): def get_bytes(self, factory, target_storage_kind): annotated_compressed_bytes = factory._raw_record - rec, contents = \ - self._data._parse_record_unchecked(annotated_compressed_bytes) - delta = self._annotate_factory.parse_line_delta(contents, rec[1], - plain=True) + rec, contents 
= self._data._parse_record_unchecked(annotated_compressed_bytes) + delta = self._annotate_factory.parse_line_delta(contents, rec[1], plain=True) compression_parent = factory.parents[0] - basis_entry = next(self._basis_vf.get_record_stream( - [compression_parent], 'unordered', True)) - if basis_entry.storage_kind == 'absent': + basis_entry = next( + self._basis_vf.get_record_stream([compression_parent], "unordered", True) + ) + if basis_entry.storage_kind == "absent": raise errors.RevisionNotPresent(compression_parent, self._basis_vf) - basis_lines = basis_entry.get_bytes_as('lines') + basis_lines = basis_entry.get_bytes_as("lines") # Manually apply the delta because we have one annotated content and # one plain. basis_content = PlainKnitContent(basis_lines, compression_parent) basis_content.apply_delta(delta, rec[1]) basis_content._should_strip_eol = factory._build_details[1] - if target_storage_kind == 'fulltext': - return b''.join(basis_content.text()) - elif target_storage_kind in ('chunked', 'lines'): + if target_storage_kind == "fulltext": + return b"".join(basis_content.text()) + elif target_storage_kind in ("chunked", "lines"): return basis_content.text() raise UnavailableRepresentation( - factory.key, target_storage_kind, factory.storage_kind) + factory.key, target_storage_kind, factory.storage_kind + ) class FTPlainToFullText(KnitAdapter): @@ -284,16 +286,17 @@ class FTPlainToFullText(KnitAdapter): def get_bytes(self, factory, target_storage_kind): compressed_bytes = factory._raw_record - rec, contents = \ - self._data._parse_record_unchecked(compressed_bytes) - content, delta = self._plain_factory.parse_record(factory.key[-1], - contents, factory._build_details, None) - if target_storage_kind == 'fulltext': - return b''.join(content.text()) - elif target_storage_kind in ('chunked', 'lines'): + rec, contents = self._data._parse_record_unchecked(compressed_bytes) + content, delta = self._plain_factory.parse_record( + factory.key[-1], contents, factory._build_details, None + ) + if target_storage_kind == "fulltext": + return b"".join(content.text()) + elif target_storage_kind in ("chunked", "lines"): return content.text() raise UnavailableRepresentation( - factory.key, target_storage_kind, factory.storage_kind) + factory.key, target_storage_kind, factory.storage_kind + ) class DeltaPlainToFullText(KnitAdapter): @@ -301,27 +304,29 @@ class DeltaPlainToFullText(KnitAdapter): def get_bytes(self, factory, target_storage_kind): compressed_bytes = factory._raw_record - rec, contents = \ - self._data._parse_record_unchecked(compressed_bytes) + rec, contents = self._data._parse_record_unchecked(compressed_bytes) self._plain_factory.parse_line_delta(contents, rec[1]) compression_parent = factory.parents[0] # XXX: string splitting overhead. - basis_entry = next(self._basis_vf.get_record_stream( - [compression_parent], 'unordered', True)) - if basis_entry.storage_kind == 'absent': + basis_entry = next( + self._basis_vf.get_record_stream([compression_parent], "unordered", True) + ) + if basis_entry.storage_kind == "absent": raise errors.RevisionNotPresent(compression_parent, self._basis_vf) - basis_lines = basis_entry.get_bytes_as('lines') + basis_lines = basis_entry.get_bytes_as("lines") basis_content = PlainKnitContent(basis_lines, compression_parent) # Manually apply the delta because we have one annotated content and # one plain. 
- content, _ = self._plain_factory.parse_record(rec[1], contents, - factory._build_details, basis_content) - if target_storage_kind == 'fulltext': - return b''.join(content.text()) - elif target_storage_kind in ('chunked', 'lines'): + content, _ = self._plain_factory.parse_record( + rec[1], contents, factory._build_details, basis_content + ) + if target_storage_kind == "fulltext": + return b"".join(content.text()) + elif target_storage_kind in ("chunked", "lines"): return content.text() raise UnavailableRepresentation( - factory.key, target_storage_kind, factory.storage_kind) + factory.key, target_storage_kind, factory.storage_kind + ) class KnitContentFactory(ContentFactory): @@ -330,8 +335,17 @@ class KnitContentFactory(ContentFactory): :seealso ContentFactory: """ - def __init__(self, key, parents, build_details, sha1, raw_record, - annotated, knit=None, network_bytes=None): + def __init__( + self, + key, + parents, + build_details, + sha1, + raw_record, + annotated, + knit=None, + network_bytes=None, + ): """Create a KnitContentFactory for key. :param key: The key. @@ -348,15 +362,15 @@ def __init__(self, key, parents, build_details, sha1, raw_record, self.sha1 = sha1 self.key = key self.parents = parents - if build_details[0] == 'line-delta': - kind = 'delta' + if build_details[0] == "line-delta": + kind = "delta" else: - kind = 'ft' + kind = "ft" if annotated: - annotated_kind = 'annotated-' + annotated_kind = "annotated-" else: - annotated_kind = '' - self.storage_kind = f'knit-{annotated_kind}{kind}-gz' + annotated_kind = "" + self.storage_kind = f"knit-{annotated_kind}{kind}-gz" self._raw_record = raw_record self._network_bytes = network_bytes self._build_details = build_details @@ -365,19 +379,22 @@ def __init__(self, key, parents, build_details, sha1, raw_record, def _create_network_bytes(self): """Create a fully serialised network version for transmission.""" # storage_kind, key, parents, Noeol, raw_record - key_bytes = b'\x00'.join(self.key) + key_bytes = b"\x00".join(self.key) if self.parents is None: - parent_bytes = b'None:' + parent_bytes = b"None:" else: - parent_bytes = b'\t'.join(b'\x00'.join(key) - for key in self.parents) + parent_bytes = b"\t".join(b"\x00".join(key) for key in self.parents) if self._build_details[1]: - noeol = b'N' + noeol = b"N" else: - noeol = b' ' + noeol = b" " network_bytes = b"%s\n%s\n%s\n%s%s" % ( - self.storage_kind.encode('ascii'), key_bytes, - parent_bytes, noeol, self._raw_record) + self.storage_kind.encode("ascii"), + key_bytes, + parent_bytes, + noeol, + self._raw_record, + ) self._network_bytes = network_bytes def get_bytes_as(self, storage_kind): @@ -385,8 +402,11 @@ def get_bytes_as(self, storage_kind): if self._network_bytes is None: self._create_network_bytes() return self._network_bytes - if ('-ft-' in self.storage_kind - and storage_kind in ('chunked', 'fulltext', 'lines')): + if "-ft-" in self.storage_kind and storage_kind in ( + "chunked", + "fulltext", + "lines", + ): adapter_key = (self.storage_kind, storage_kind) adapter_factory = adapter_registry.get(adapter_key) adapter = adapter_factory(None) @@ -394,12 +414,11 @@ def get_bytes_as(self, storage_kind): if self._knit is not None: # Not redundant with direct conversion above - that only handles # fulltext cases. 
- if storage_kind in ('chunked', 'lines'): + if storage_kind in ("chunked", "lines"): return self._knit.get_lines(self.key[0]) - elif storage_kind == 'fulltext': + elif storage_kind == "fulltext": return self._knit.get_text(self.key[0]) - raise UnavailableRepresentation(self.key, storage_kind, - self.storage_kind) + raise UnavailableRepresentation(self.key, storage_kind, self.storage_kind) def iter_bytes_as(self, storage_kind): return iter(self.get_bytes_as(storage_kind)) @@ -439,22 +458,22 @@ def get_bytes_as(self, storage_kind): else: # all the keys etc are contained in the bytes returned in the # first record. - return b'' - if storage_kind in ('chunked', 'fulltext', 'lines'): + return b"" + if storage_kind in ("chunked", "fulltext", "lines"): chunks = self._generator._get_one_work(self.key).text() - if storage_kind in ('chunked', 'lines'): + if storage_kind in ("chunked", "lines"): return chunks else: - return b''.join(chunks) - raise UnavailableRepresentation(self.key, storage_kind, - self.storage_kind) + return b"".join(chunks) + raise UnavailableRepresentation(self.key, storage_kind, self.storage_kind) def iter_bytes_as(self, storage_kind): - if storage_kind in ('chunked', 'lines'): + if storage_kind in ("chunked", "lines"): chunks = self._generator._get_one_work(self.key).text() return iter(chunks) - raise errors.UnavailableRepresentation(self.key, storage_kind, - self.storage_kind) + raise errors.UnavailableRepresentation( + self.key, storage_kind, self.storage_kind + ) def knit_delta_closure_to_records(storage_kind, bytes, line_end): @@ -475,29 +494,42 @@ def knit_network_to_record(storage_kind, bytes, line_end): :param bytes: The bytes of the record on the network. """ start = line_end - line_end = bytes.find(b'\n', start) - key = tuple(bytes[start:line_end].split(b'\x00')) + line_end = bytes.find(b"\n", start) + key = tuple(bytes[start:line_end].split(b"\x00")) start = line_end + 1 - line_end = bytes.find(b'\n', start) + line_end = bytes.find(b"\n", start) parent_line = bytes[start:line_end] - if parent_line == b'None:': + if parent_line == b"None:": parents = None else: parents = tuple( - [tuple(segment.split(b'\x00')) for segment in parent_line.split(b'\t') - if segment]) + [ + tuple(segment.split(b"\x00")) + for segment in parent_line.split(b"\t") + if segment + ] + ) start = line_end + 1 - noeol = bytes[start:start + 1] == b'N' - if 'ft' in storage_kind: - method = 'fulltext' + noeol = bytes[start : start + 1] == b"N" + if "ft" in storage_kind: + method = "fulltext" else: - method = 'line-delta' + method = "line-delta" build_details = (method, noeol) start = start + 1 raw_record = bytes[start:] - annotated = 'annotated' in storage_kind - return [KnitContentFactory(key, parents, build_details, None, raw_record, - annotated, network_bytes=bytes)] + annotated = "annotated" in storage_kind + return [ + KnitContentFactory( + key, + parents, + build_details, + None, + raw_record, + annotated, + network_bytes=bytes, + ) + ] class KnitContent: @@ -518,11 +550,12 @@ def apply_delta(self, delta, new_version_id): def line_delta_iter(self, new_lines): """Generate line-based delta from this content to new_lines.""" import patiencediff + new_texts = new_lines.text() old_texts = self.text() s = patiencediff.PatienceSequenceMatcher(None, old_texts, new_texts) for tag, i1, i2, j1, j2 in s.get_opcodes(): - if tag == 'equal': + if tag == "equal": continue # ofrom, oto, length, data yield i1, i2, j2 - j1, new_lines._lines[j1:j2] @@ -569,7 +602,7 @@ def annotate(self): lines = self._lines[:] if 
self._should_strip_eol: origin, last_line = lines[-1] - lines[-1] = (origin, last_line.rstrip(b'\n')) + lines[-1] = (origin, last_line.rstrip(b"\n")) return lines def apply_delta(self, delta, new_version_id): @@ -577,7 +610,7 @@ def apply_delta(self, delta, new_version_id): offset = 0 lines = self._lines for start, end, count, delta_lines in delta: - lines[offset + start:offset + end] = delta_lines + lines[offset + start : offset + end] = delta_lines offset = offset + (start - end) + count def text(self): @@ -587,10 +620,11 @@ def text(self): # most commonly (only?) caused by the internal form of the knit # missing annotation information because of a bug - see thread # around 20071015 - raise KnitCorrupt(self, - f"line in annotated knit missing annotation information: {e}") from e + raise KnitCorrupt( + self, f"line in annotated knit missing annotation information: {e}" + ) from e if self._should_strip_eol: - lines[-1] = lines[-1].rstrip(b'\n') + lines[-1] = lines[-1].rstrip(b"\n") return lines def copy(self): @@ -619,7 +653,7 @@ def apply_delta(self, delta, new_version_id): offset = 0 lines = self._lines for start, end, count, delta_lines in delta: - lines[offset + start:offset + end] = delta_lines + lines[offset + start : offset + end] = delta_lines offset = offset + (start - end) + count self._version_id = new_version_id @@ -630,15 +664,16 @@ def text(self): lines = self._lines if self._should_strip_eol: lines = lines[:] - lines[-1] = lines[-1].rstrip(b'\n') + lines[-1] = lines[-1].rstrip(b"\n") return lines class _KnitFactory: """Base class for common Factory functions.""" - def parse_record(self, version_id, record, record_details, - base_content, copy_base_content=True): + def parse_record( + self, version_id, record, record_details, base_content, copy_base_content=True + ): """Parse a record into a full content object. :param version_id: The official version id for this content @@ -654,7 +689,7 @@ def parse_record(self, version_id, record, record_details, delta may be None """ method, noeol = record_details - if method == 'line-delta': + if method == "line-delta": if copy_base_content: content = base_content.copy() else: @@ -689,7 +724,7 @@ def parse_fulltext(self, content, version_id): # but the code itself doesn't really depend on that. # Figure out a way to not require the overhead of turning the # list back into tuples. - lines = (tuple(line.split(b' ', 1)) for line in content) + lines = (tuple(line.split(b" ", 1)) for line in content) return AnnotatedKnitContent(lines) def parse_line_delta(self, lines, version_id, plain=False): @@ -712,7 +747,7 @@ def parse_line_delta(self, lines, version_id, plain=False): cache = {} def cache_and_return(line): - origin, text = line.split(b' ', 1) + origin, text = line.split(b" ", 1) return cache.setdefault(origin, origin), text # walk through the lines parsing. 
@@ -720,21 +755,19 @@ def cache_and_return(line): # loop to minimise any performance impact if plain: for header in lines: - start, end, count = (int(n) for n in header.split(b',')) - contents = [next(lines).split(b' ', 1)[1] - for _ in range(count)] + start, end, count = (int(n) for n in header.split(b",")) + contents = [next(lines).split(b" ", 1)[1] for _ in range(count)] result.append((start, end, count, contents)) else: for header in lines: - start, end, count = (int(n) for n in header.split(b',')) - contents = [tuple(next(lines).split(b' ', 1)) - for _ in range(count)] + start, end, count = (int(n) for n in header.split(b",")) + contents = [tuple(next(lines).split(b" ", 1)) for _ in range(count)] result.append((start, end, count, contents)) return result def get_fulltext_content(self, lines): """Extract just the content lines from a fulltext.""" - return (line.split(b' ', 1)[1] for line in lines) + return (line.split(b" ", 1)[1] for line in lines) def get_linedelta_content(self, lines): """Extract just the content from a line delta. @@ -744,10 +777,10 @@ def get_linedelta_content(self, lines): """ lines = iter(lines) for header in lines: - header = header.split(b',') + header = header.split(b",") count = int(header[2]) for _ in range(count): - origin, text = next(lines).split(b' ', 1) + origin, text = next(lines).split(b" ", 1) yield text def lower_fulltext(self, content): @@ -755,7 +788,7 @@ def lower_fulltext(self, content): see parse_fulltext which this inverts. """ - return [b'%s %s' % (o, t) for o, t in content._lines] + return [b"%s %s" % (o, t) for o, t in content._lines] def lower_line_delta(self, delta): """Convert a delta into a serializable form. @@ -766,9 +799,8 @@ def lower_line_delta(self, delta): # the origin is a valid utf-8 line, eventually we could remove it out = [] for start, end, c, lines in delta: - out.append(b'%d,%d,%d\n' % (start, end, c)) - out.extend(origin + b' ' + text - for origin, text in lines) + out.append(b"%d,%d,%d\n" % (start, end, c)) + out.extend(origin + b" " + text for origin, text in lines) return out def annotate(self, knit, key): @@ -810,8 +842,8 @@ def parse_line_delta_iter(self, lines, version_id): while cur < num_lines: header = lines[cur] cur += 1 - start, end, c = (int(n) for n in header.split(b',')) - yield start, end, c, lines[cur:cur + c] + start, end, c = (int(n) for n in header.split(b",")) + yield start, end, c, lines[cur : cur + c] cur += c def parse_line_delta(self, lines, version_id): @@ -829,7 +861,7 @@ def get_linedelta_content(self, lines): """ lines = iter(lines) for header in lines: - header = header.split(b',') + header = header.split(b",") count = int(header[2]) for _ in range(count): yield next(lines) @@ -840,7 +872,7 @@ def lower_fulltext(self, content): def lower_line_delta(self, delta): out = [] for start, end, c, lines in delta: - out.append(b'%d,%d,%d\n' % (start, end, c)) + out.append(b"%d,%d,%d\n" % (start, end, c)) out.extend(lines) return out @@ -858,11 +890,12 @@ def make_file_factory(annotated, mapper): :param annotated: knit annotations are wanted. :param mapper: The mapper from keys to paths. 
""" + def factory(transport): - index = _KndxIndex(transport, mapper, lambda: None, - lambda: True, lambda: True) + index = _KndxIndex(transport, mapper, lambda: None, lambda: True, lambda: True) access = _KnitKeyAccess(transport, mapper) return KnitVersionedFiles(index, access, annotated=annotated) + return factory @@ -876,6 +909,7 @@ def make_pack_factory(graph, delta, keylength): :param delta: Delta compress contents. :param keylength: How long should keys be. """ + def factory(transport): parents = graph or delta ref_length = 0 @@ -886,20 +920,26 @@ def factory(transport): max_delta_chain = 200 else: max_delta_chain = 0 - graph_index = _mod_index.InMemoryGraphIndex(reference_lists=ref_length, - key_elements=keylength) - stream = transport.open_write_stream('newpack') + graph_index = _mod_index.InMemoryGraphIndex( + reference_lists=ref_length, key_elements=keylength + ) + stream = transport.open_write_stream("newpack") writer = pack.ContainerWriter(stream.write) writer.begin() - index = _KnitGraphIndex(graph_index, lambda: True, parents=parents, - deltas=delta, add_callback=graph_index.add_nodes) + index = _KnitGraphIndex( + graph_index, + lambda: True, + parents=parents, + deltas=delta, + add_callback=graph_index.add_nodes, + ) access = pack_repo._DirectPackAccess({}) - access.set_writer(writer, graph_index, (transport, 'newpack')) - result = KnitVersionedFiles(index, access, - max_delta_chain=max_delta_chain) + access.set_writer(writer, graph_index, (transport, "newpack")) + result = KnitVersionedFiles(index, access, max_delta_chain=max_delta_chain) result.stream = stream result.writer = writer return result + return factory @@ -933,8 +973,7 @@ def _get_total_build_size(self, keys, positions): if compression_parent not in all_build_index_memos: next_keys.add(compression_parent) build_keys = next_keys - return sum(index_memo[2] - for index_memo in all_build_index_memos.values()) + return sum(index_memo[2] for index_memo in all_build_index_memos.values()) class KnitVersionedFiles(VersionedFilesWithFallbacks): @@ -948,8 +987,9 @@ class KnitVersionedFiles(VersionedFilesWithFallbacks): *this* vfs; if there are fallbacks they must be queried separately. """ - def __init__(self, index, data_access, max_delta_chain=200, - annotated=False, reload_func=None): + def __init__( + self, index, data_access, max_delta_chain=200, annotated=False, reload_func=None + ): """Create a KnitVersionedFiles with index and data_access. :param index: The index for the knit data. @@ -978,9 +1018,13 @@ def __repr__(self): def without_fallbacks(self): """Return a clone of this object without any fallbacks configured.""" - return KnitVersionedFiles(self._index, self._access, - self._max_delta_chain, self._factory.annotated, - self._reload_func) + return KnitVersionedFiles( + self._index, + self._access, + self._max_delta_chain, + self._factory.annotated, + self._reload_func, + ) def add_fallback_versioned_files(self, a_versioned_files): """Add a source of texts for texts not present in this knit. 
@@ -989,9 +1033,17 @@ def add_fallback_versioned_files(self, a_versioned_files): """ self._immediate_fallback_vfs.append(a_versioned_files) - def add_lines(self, key, parents, lines, parent_texts=None, - left_matching_blocks=None, nostore_sha=None, random_id=False, - check_content=True): + def add_lines( + self, + key, + parents, + lines, + parent_texts=None, + left_matching_blocks=None, + nostore_sha=None, + random_id=False, + check_content=True, + ): """See VersionedFiles.add_lines().""" self._index._check_write_ok() self._check_add(key, lines, random_id, check_content) @@ -1000,14 +1052,26 @@ def add_lines(self, key, parents, lines, parent_texts=None, # indexes can't directly store that, so we give them # an empty tuple instead. parents = () - line_bytes = b''.join(lines) - return self._add(key, lines, parents, - parent_texts, left_matching_blocks, nostore_sha, random_id, - line_bytes=line_bytes) - - def add_content(self, content_factory, parent_texts=None, - left_matching_blocks=None, nostore_sha=None, - random_id=False): + line_bytes = b"".join(lines) + return self._add( + key, + lines, + parents, + parent_texts, + left_matching_blocks, + nostore_sha, + random_id, + line_bytes=line_bytes, + ) + + def add_content( + self, + content_factory, + parent_texts=None, + left_matching_blocks=None, + nostore_sha=None, + random_id=False, + ): """See VersionedFiles.add_content().""" self._index._check_write_ok() key = content_factory.key @@ -1018,15 +1082,30 @@ def add_content(self, content_factory, parent_texts=None, # indexes can't directly store that, so we give them # an empty tuple instead. parents = () - lines = content_factory.get_bytes_as('lines') - line_bytes = content_factory.get_bytes_as('fulltext') - return self._add(key, lines, parents, - parent_texts, left_matching_blocks, nostore_sha, random_id, - line_bytes=line_bytes) - - def _add(self, key, lines, parents, parent_texts, - left_matching_blocks, nostore_sha, random_id, - line_bytes): + lines = content_factory.get_bytes_as("lines") + line_bytes = content_factory.get_bytes_as("fulltext") + return self._add( + key, + lines, + parents, + parent_texts, + left_matching_blocks, + nostore_sha, + random_id, + line_bytes=line_bytes, + ) + + def _add( + self, + key, + lines, + parents, + parent_texts, + left_matching_blocks, + nostore_sha, + random_id, + line_bytes, + ): """Add a set of lines on top of version specified by parents. Any versions not present will be converted into ghosts. @@ -1058,8 +1137,7 @@ def _add(self, key, lines, parents, parent_texts, present_parents.append(parent) # Currently we can only compress against the left most present parent. - if (len(present_parents) == 0 - or present_parents[0] != parents[0]): + if len(present_parents) == 0 or present_parents[0] != parents[0]: delta = False else: # To speed the extract of texts the delta chain is limited @@ -1073,8 +1151,8 @@ def _add(self, key, lines, parents, parent_texts, # Note: line_bytes is not modified to add a newline, that is tracked # via the no_eol flag. 'lines' *is* modified, because that is the # general values needed by the Content code. 
- if line_bytes and not line_bytes.endswith(b'\n'): - options.append(b'no-eol') + if line_bytes and not line_bytes.endswith(b"\n"): + options.append(b"no-eol") no_eol = True # Copy the existing list, or create a new one if lines is None: @@ -1082,7 +1160,7 @@ def _add(self, key, lines, parents, parent_texts, else: lines = lines[:] # Replace the last line with one that ends in a final newline - lines[-1] = lines[-1] + b'\n' + lines[-1] = lines[-1] + b"\n" if lines is None: lines = osutils.split_lines(line_bytes) @@ -1090,7 +1168,7 @@ def _add(self, key, lines, parents, parent_texts, if not isinstance(element, bytes): raise TypeError(f"key contains non-bytestrings: {key!r}") if key[-1] is None: - key = key[:-1] + (b'sha1:' + digest,) + key = key[:-1] + (b"sha1:" + digest,) elif not isinstance(key[-1], bytes): raise TypeError(f"key contains non-bytestrings: {key!r}") # Knit hunks are still last-element only @@ -1102,25 +1180,29 @@ def _add(self, key, lines, parents, parent_texts, content._should_strip_eol = True if delta or (self._factory.annotated and len(present_parents) > 0): # Merge annotations from parent texts if needed. - delta_hunks = self._merge_annotations(content, present_parents, - parent_texts, delta, self._factory.annotated, - left_matching_blocks) + delta_hunks = self._merge_annotations( + content, + present_parents, + parent_texts, + delta, + self._factory.annotated, + left_matching_blocks, + ) if delta: - options.append(b'line-delta') + options.append(b"line-delta") store_lines = self._factory.lower_line_delta(delta_hunks) size, data = self._record_to_data(key, digest, store_lines) else: - options.append(b'fulltext') + options.append(b"fulltext") # isinstance is slower and we have no hierarchy. if self._factory.__class__ is KnitPlainFactory: # Use the already joined bytes saving iteration time in # _record_to_data. dense_lines = [line_bytes] if no_eol: - dense_lines.append(b'\n') - size, data = self._record_to_data(key, digest, - lines, dense_lines) + dense_lines.append(b"\n") + size, data = self._record_to_data(key, digest, lines, dense_lines) else: # get mixed annotation + content and feed it into the # serialiser. 
@@ -1129,8 +1211,8 @@ def _add(self, key, lines, parents, parent_texts, access_memo = self._access.add_raw_record(key, size, data) self._index.add_records( - ((key, options, access_memo, parents),), - random_id=random_id) + ((key, options, access_memo, parents),), random_id=random_id + ) return digest, text_length, content def annotate(self, key): @@ -1146,7 +1228,7 @@ def check(self, progress_bar=None, keys=None): return self._logical_check() else: # At the moment, check does not extra work over get_record_stream - return self.get_record_stream(keys, 'unordered', True) + return self.get_record_stream(keys, "unordered", True) def _logical_check(self): # This doesn't actually test extraction of everything, but that will @@ -1156,12 +1238,13 @@ def _logical_check(self): keys = self._index.keys() parent_map = self.get_parent_map(keys) for key in keys: - if self._index.get_method(key) != 'fulltext': + if self._index.get_method(key) != "fulltext": compression_parent = parent_map[key][0] if compression_parent not in parent_map: - raise KnitCorrupt(self, - "Missing basis parent {} for {}".format( - compression_parent, key)) + raise KnitCorrupt( + self, + f"Missing basis parent {compression_parent} for {key}", + ) for fallback_vfs in self._immediate_fallback_vfs: fallback_vfs.check() @@ -1193,8 +1276,9 @@ def _check_header_version(self, rec, version_id): These have the last component of the key embedded in the record. """ if rec[1] != version_id: - raise KnitCorrupt(self, - f'unexpected version, wanted {version_id!r}, got {rec[1]!r}') + raise KnitCorrupt( + self, f"unexpected version, wanted {version_id!r}, got {rec[1]!r}" + ) def _check_should_delta(self, parent): """Iterate back through the parent listing, looking for a fulltext. @@ -1265,12 +1349,10 @@ def _get_components_positions(self, keys, allow_missing=False): current_components = set(pending_components) pending_components = set() for key, details in build_details.items(): - (index_memo, compression_parent, parents, - record_details) = details + (index_memo, compression_parent, parents, record_details) = details if compression_parent is not None: pending_components.add(compression_parent) - component_data[key] = self._build_details_to_components( - details) + component_data[key] = self._build_details_to_components(details) missing = current_components.difference(build_details) if missing and not allow_missing: raise errors.RevisionNotPresent(missing.pop(), self) @@ -1338,8 +1420,7 @@ def _get_record_map(self, keys, allow_missing=False): :param allow_missing: If some records are missing, rather than error, just return the data that could be generated. """ - raw_map = self._get_record_map_unparsed(keys, - allow_missing=allow_missing) + raw_map = self._get_record_map_unparsed(keys, allow_missing=allow_missing) return self._raw_map_to_record_map(raw_map) def _raw_map_to_record_map(self, raw_map): @@ -1369,12 +1450,12 @@ def _get_record_map_unparsed(self, keys, allow_missing=False): # operation? We wouldn't want to end up with a broken chain. 
while True: try: - position_map = self._get_components_positions(keys, - allow_missing=allow_missing) + position_map = self._get_components_positions( + keys, allow_missing=allow_missing + ) # key = component_id, r = record_details, i_m = index_memo, # n = next - records = [(key, i_m) for key, (r, i_m, n) - in position_map.items()] + records = [(key, i_m) for key, (r, i_m, n) in position_map.items()] # Sort by the index memo, so that we request records from the # same pack file together, and in forward-sorted order records.sort(key=operator.itemgetter(1)) @@ -1409,7 +1490,7 @@ def _split_by_prefix(cls, keys): prefix_order = [] for key in keys: if len(key) == 1: - prefix = b'' + prefix = b"" else: prefix = key[0] @@ -1420,8 +1501,9 @@ def _split_by_prefix(cls, keys): prefix_order.append(prefix) return split_by_prefix, prefix_order - def _group_keys_for_io(self, keys, non_local_keys, positions, - _min_buffer_size=_STREAM_MIN_BUFFER_SIZE): + def _group_keys_for_io( + self, keys, non_local_keys, positions, _min_buffer_size=_STREAM_MIN_BUFFER_SIZE + ): """For the given keys, group them into 'best-sized' requests. The idea is to avoid making 1 request per file, but to never try to @@ -1481,32 +1563,33 @@ def get_record_stream(self, keys, ordering, include_delta_closure): return if not self._index.has_graph: # Cannot sort when no graph has been stored. - ordering = 'unordered' + ordering = "unordered" remaining_keys = keys while True: try: keys = set(remaining_keys) - for content_factory in self._get_remaining_record_stream(keys, - ordering, include_delta_closure): + for content_factory in self._get_remaining_record_stream( + keys, ordering, include_delta_closure + ): remaining_keys.discard(content_factory.key) yield content_factory return except pack_repo.RetryWithNewPacks as e: self._access.reload_or_raise(e) - def _get_remaining_record_stream(self, keys, ordering, - include_delta_closure): + def _get_remaining_record_stream(self, keys, ordering, include_delta_closure): """This function is the 'retry' portion for get_record_stream.""" if include_delta_closure: - positions = self._get_components_positions( - keys, allow_missing=True) + positions = self._get_components_positions(keys, allow_missing=True) else: build_details = self._index.get_build_details(keys) # map from key to # (record_details, access_memo, compression_parent_key) - positions = {key: self._build_details_to_components(details) - for key, details in build_details.items()} + positions = { + key: self._build_details_to_components(details) + for key, details in build_details.items() + } absent_keys = keys.difference(set(positions)) # There may be more absent keys : if we're missing the basis component # and are trying to include the delta closure. @@ -1543,8 +1626,8 @@ def _get_remaining_record_stream(self, keys, ordering, needed_from_fallback.add(key) # Double index lookups here : need a unified api ? 
global_map, parent_maps = self._get_parent_map_with_sources(keys) - if ordering in ('topological', 'groupcompress'): - if ordering == 'topological': + if ordering in ("topological", "groupcompress"): + if ordering == "topological": # Global topological sort present_keys = tsort.topo_sort(global_map) else: @@ -1562,9 +1645,11 @@ def _get_remaining_record_stream(self, keys, ordering, current_source = key_source source_keys[-1][1].append(key) else: - if ordering != 'unordered': - raise AssertionError('valid values for ordering are:' - f' "unordered", "groupcompress" or "topological" not: {ordering!r}') + if ordering != "unordered": + raise AssertionError( + "valid values for ordering are:" + f' "unordered", "groupcompress" or "topological" not: {ordering!r}' + ) # Just group by source; remote sources first. present_keys = [] source_keys = [] @@ -1591,12 +1676,14 @@ def _get_remaining_record_stream(self, keys, ordering, # XXX: get_content_maps performs its own index queries; allow state # to be passed in. non_local_keys = needed_from_fallback - absent_keys - for keys, non_local_keys in self._group_keys_for_io(present_keys, # noqa: B020 - non_local_keys, - positions): - generator = _VFContentMapGenerator(self, keys, non_local_keys, - global_map, - ordering=ordering) + for keys, non_local_keys in self._group_keys_for_io( # noqa: B020 + present_keys, + non_local_keys, + positions, + ): + generator = _VFContentMapGenerator( + self, keys, non_local_keys, global_map, ordering=ordering + ) yield from generator.get_record_stream() else: for source, keys in source_keys: @@ -1605,13 +1692,20 @@ def _get_remaining_record_stream(self, keys, ordering, records = [(key, positions[key][1]) for key in keys] for key, raw_data in self._read_records_iter_unchecked(records): (record_details, index_memo, _) = positions[key] - yield KnitContentFactory(key, global_map[key], - record_details, None, raw_data, self._factory.annotated, None) + yield KnitContentFactory( + key, + global_map[key], + record_details, + None, + raw_data, + self._factory.annotated, + None, + ) else: - vf = self._immediate_fallback_vfs[parent_maps.index( - source) - 1] - yield from vf.get_record_stream(keys, ordering, - include_delta_closure) + vf = self._immediate_fallback_vfs[parent_maps.index(source) - 1] + yield from vf.get_record_stream( + keys, ordering, include_delta_closure + ) def get_sha1s(self, keys): """See VersionedFiles.get_sha1s().""" @@ -1639,6 +1733,7 @@ def insert_record_stream(self, stream): :return: None :seealso VersionedFiles.get_record_stream: """ + def get_adapter(adapter_key): try: return adapters[adapter_key] @@ -1647,6 +1742,7 @@ def get_adapter(adapter_key): adapter = adapter_factory(self) adapters[adapter_key] = adapter return adapter + delta_types = set() if self._factory.annotated: # self is annotated, we need annotated knits to use directly. @@ -1682,7 +1778,7 @@ def get_adapter(adapter_key): buffered_index_entries = {} for record in stream: kind = record.storage_kind - if kind.startswith('knit-') and kind.endswith('-gz'): + if kind.startswith("knit-") and kind.endswith("-gz"): # Check that the ID in the header of the raw knit bytes matches # the record metadata. raw_data = record._raw_record @@ -1697,13 +1793,14 @@ def get_adapter(adapter_key): else: compression_parent = None # Raise an error when a record is missing. 
- if record.storage_kind == 'absent': + if record.storage_kind == "absent": raise RevisionNotPresent([record.key], self) - elif ((record.storage_kind in knit_types) and - (compression_parent is None or - not self._immediate_fallback_vfs or - compression_parent in self._index or - compression_parent not in self)): + elif (record.storage_kind in knit_types) and ( + compression_parent is None + or not self._immediate_fallback_vfs + or compression_parent in self._index + or compression_parent not in self + ): # we can insert the knit record literally if either it has no # compression parent OR we already have its basis in this kvf # OR the basis is not present even in the fallbacks. In the @@ -1726,9 +1823,9 @@ def get_adapter(adapter_key): # It's a knit record, it has a _raw_record field (even if # it was reconstituted from a network stream). bytes = record._raw_record - options = [record._build_details[0].encode('ascii')] + options = [record._build_details[0].encode("ascii")] if record._build_details[1]: - options.append(b'no-eol') + options.append(b"no-eol") # Just blat it across. # Note: This does end up adding data on duplicate keys. As # modern repositories use atomic insertions this should not @@ -1738,9 +1835,10 @@ def get_adapter(adapter_key): # needed by in the kndx index support raising on a duplicate # add with identical parents and options. access_memo = self._access.add_raw_record( - record.key, len(bytes), [bytes]) + record.key, len(bytes), [bytes] + ) index_entry = (record.key, options, access_memo, parents) - if b'fulltext' not in options: + if b"fulltext" not in options: # Not a fulltext, so we need to make sure the compression # parent will also be present. # Note that pack backed knits don't need to buffer here @@ -1753,13 +1851,14 @@ def get_adapter(adapter_key): # KnitVersionedFiles, not in a fallback. if compression_parent not in self._index: pending = buffered_index_entries.setdefault( - compression_parent, []) + compression_parent, [] + ) pending.append(index_entry) buffered = True if not buffered: self._index.add_records([index_entry]) - elif record.storage_kind in ('chunked', 'file'): - self.add_lines(record.key, parents, record.get_bytes_as('lines')) + elif record.storage_kind in ("chunked", "file"): + self.add_lines(record.key, parents, record.get_bytes_as("lines")) else: # Not suitable for direct insertion as a # delta, either because it's not the right format, or this @@ -1769,11 +1868,11 @@ def get_adapter(adapter_key): self._access.flush() try: # Try getting a fulltext directly from the record. 
- lines = record.get_bytes_as('lines') + lines = record.get_bytes_as("lines") except UnavailableRepresentation: - adapter_key = record.storage_kind, 'lines' + adapter_key = record.storage_kind, "lines" adapter = get_adapter(adapter_key) - lines = adapter.get_bytes(record, 'lines') + lines = adapter.get_bytes(record, "lines") try: self.add_lines(record.key, parents, lines) except errors.RevisionAlreadyPresent: @@ -1787,7 +1886,8 @@ def get_adapter(adapter_key): index_entries = buffered_index_entries[key] self._index.add_records(index_entries) added_keys.extend( - [index_entry[0] for index_entry in index_entries]) + [index_entry[0] for index_entry in index_entries] + ) del buffered_index_entries[key] if buffered_index_entries: # There were index entries buffered at the end of the stream, @@ -1797,8 +1897,7 @@ def get_adapter(adapter_key): for key in buffered_index_entries: index_entries = buffered_index_entries[key] all_entries.extend(index_entries) - self._index.add_records( - all_entries, missing_compression_parents=True) + self._index.add_records(all_entries, missing_compression_parents=True) def get_missing_compression_parent_keys(self): """Return an iterable of keys of missing compression parents. @@ -1852,17 +1951,15 @@ def iter_lines_added_or_present_in_keys(self, keys, pb=None): if key in keys: key_records.append((key, details[0])) records_iter = enumerate(self._read_records_iter(key_records)) - for (key_idx, (key, data, _sha_value)) in records_iter: - pb.update(gettext('Walking content'), key_idx, total) + for key_idx, (key, data, _sha_value) in records_iter: + pb.update(gettext("Walking content"), key_idx, total) compression_parent = build_details[key][1] if compression_parent is None: # fulltext - line_iterator = self._factory.get_fulltext_content( - data) + line_iterator = self._factory.get_fulltext_content(data) else: # Delta - line_iterator = self._factory.get_linedelta_content( - data) + line_iterator = self._factory.get_linedelta_content(data) # Now that we are yielding the data for this key, remove it # from the list keys.remove(key) @@ -1891,21 +1988,28 @@ def iter_lines_added_or_present_in_keys(self, keys, pb=None): source_keys.add(key) yield line, key keys.difference_update(source_keys) - pb.update(gettext('Walking content'), total, total) + pb.update(gettext("Walking content"), total, total) def _make_line_delta(self, delta_seq, new_content): """Generate a line delta from delta_seq and new_content.""" diff_hunks = [] for op in delta_seq.get_opcodes(): - if op[0] == 'equal': + if op[0] == "equal": continue diff_hunks.append( - (op[1], op[2], op[4] - op[3], new_content._lines[op[3]:op[4]])) + (op[1], op[2], op[4] - op[3], new_content._lines[op[3] : op[4]]) + ) return diff_hunks - def _merge_annotations(self, content, parents, parent_texts=None, - delta=None, annotated=None, - left_matching_blocks=None): + def _merge_annotations( + self, + content, + parents, + parent_texts=None, + delta=None, + annotated=None, + left_matching_blocks=None, + ): """Merge annotations for content and generate deltas. 
This is done by comparing the annotations based on changes to the text @@ -1915,6 +2019,7 @@ def _merge_annotations(self, content, parents, parent_texts=None, if parent_texts is None: parent_texts = {} import patiencediff + if left_matching_blocks is not None: delta_seq = diff._PrematchedMatcher(left_matching_blocks) else: @@ -1922,25 +2027,26 @@ def _merge_annotations(self, content, parents, parent_texts=None, if annotated: for parent_key in parents: merge_content = self._get_content(parent_key, parent_texts) - if (parent_key == parents[0] and delta_seq is not None): + if parent_key == parents[0] and delta_seq is not None: seq = delta_seq else: seq = patiencediff.PatienceSequenceMatcher( - None, merge_content.text(), content.text()) + None, merge_content.text(), content.text() + ) for i, j, n in seq.get_matching_blocks(): if n == 0: continue # this copies (origin, text) pairs across to the new # content for any line that matches the last-checked # parent. - content._lines[j:j + n] = merge_content._lines[i:i + n] + content._lines[j : j + n] = merge_content._lines[i : i + n] # XXX: Robert says the following block is a workaround for a # now-fixed bug and it can probably be deleted. -- mbp 20080618 - if content._lines and not content._lines[-1][1].endswith(b'\n'): + if content._lines and not content._lines[-1][1].endswith(b"\n"): # The copied annotation was from a line without a trailing EOL, # reinstate one for the content object, to ensure correct # serialization. - line = content._lines[-1][1] + b'\n' + line = content._lines[-1][1] + b"\n" content._lines[-1] = (content._lines[-1][0], line) if delta: if delta_seq is None: @@ -1948,7 +2054,8 @@ def _merge_annotations(self, content, parents, parent_texts=None, new_texts = content.text() old_texts = reference_content.text() delta_seq = patiencediff.PatienceSequenceMatcher( - None, old_texts, new_texts) + None, old_texts, new_texts + ) return self._make_line_delta(delta_seq, content) def _parse_record(self, version_id, data): @@ -1966,13 +2073,14 @@ def _parse_record_header(self, key, raw_data): :return: the header and the decompressor stream. 
as (stream, header_record) """ - df = gzip.GzipFile(mode='rb', fileobj=BytesIO(raw_data)) + df = gzip.GzipFile(mode="rb", fileobj=BytesIO(raw_data)) try: # Current serialise rec = self._check_header(key, df.readline()) except Exception as e: - raise KnitCorrupt(self, - f"While reading {{{key}}} got {e.__class__.__name__}({str(e)})") from e + raise KnitCorrupt( + self, f"While reading {{{key}}} got {e.__class__.__name__}({str(e)})" + ) from e return df, rec def _parse_record_unchecked(self, data): @@ -1980,22 +2088,28 @@ def _parse_record_unchecked(self, data): # 4168 calls in 2880 217 internal # 4168 calls to _parse_record_header in 2121 # 4168 calls to readlines in 330 - with gzip.GzipFile(mode='rb', fileobj=BytesIO(data)) as df: + with gzip.GzipFile(mode="rb", fileobj=BytesIO(data)) as df: try: record_contents = df.readlines() except Exception as e: - raise KnitCorrupt(self, f"Corrupt compressed record {data!r}, got {e.__class__.__name__}({str(e)})") from e + raise KnitCorrupt( + self, + f"Corrupt compressed record {data!r}, got {e.__class__.__name__}({str(e)})", + ) from e header = record_contents.pop(0) rec = self._split_header(header) last_line = record_contents.pop() if len(record_contents) != int(rec[2]): - raise KnitCorrupt(self, - 'incorrect number of lines {} != {}' - ' for version {{{}}} {}'.format(len(record_contents), int(rec[2]), - rec[1], record_contents)) - if last_line != b'end %s\n' % rec[1]: - raise KnitCorrupt(self, - f'unexpected version end line {last_line!r}, wanted {rec[1]!r}') + raise KnitCorrupt( + self, + f"incorrect number of lines {len(record_contents)} != {int(rec[2])}" + f" for version {{{rec[1]}}} {record_contents}", + ) + if last_line != b"end %s\n" % rec[1]: + raise KnitCorrupt( + self, + f"unexpected version end line {last_line!r}, wanted {rec[1]!r}", + ) return rec, record_contents def _read_records_iter(self, records): @@ -2020,7 +2134,8 @@ def _read_records_iter(self, records): # The transport optimizes the fetching as well # (ie, reads continuous ranges.) raw_data = self._access.get_raw_records( - [index_memo for key, index_memo in needed_records]) + [index_memo for key, index_memo in needed_records] + ) for (key, _index_memo), data in zip(needed_records, raw_data): content, digest = self._parse_record(key[-1], data) @@ -2053,8 +2168,7 @@ def _read_records_iter_unchecked(self, records): # uses readv so nice and fast we hope. if len(records): # grab the disk data needed. 
- needed_offsets = [index_memo for key, index_memo - in records] + needed_offsets = [index_memo for key, index_memo in records] raw_records = self._access.get_raw_records(needed_offsets) for key, _index_memo in records: @@ -2079,23 +2193,21 @@ def _record_to_data(self, key, digest, lines, dense_lines=None): chunks.append(b"end " + key[-1] + b"\n") for chunk in chunks: if not isinstance(chunk, bytes): - raise AssertionError( - f'data must be plain bytes was {type(chunk)}') - if lines and not lines[-1].endswith(b'\n'): - raise ValueError(f'corrupt lines value {lines!r}') + raise AssertionError(f"data must be plain bytes was {type(chunk)}") + if lines and not lines[-1].endswith(b"\n"): + raise ValueError(f"corrupt lines value {lines!r}") compressed_chunks = tuned_gzip.chunks_to_gzip(chunks) return sum(map(len, compressed_chunks)), compressed_chunks def _split_header(self, line): rec = line.split() if len(rec) != 4: - raise KnitCorrupt(self, - 'unexpected number of elements in record header') + raise KnitCorrupt(self, "unexpected number of elements in record header") return rec def keys(self): """See VersionedFiles.keys.""" - if debug.debug_flag_enabled('evil'): + if debug.debug_flag_enabled("evil"): trace.mutter_callsite(2, "keys scales with size of history") sources = [self._index] + self._immediate_fallback_vfs result = set() @@ -2107,7 +2219,7 @@ def keys(self): class _ContentMapGenerator: """Generate texts or expose raw deltas for a set of texts.""" - def __init__(self, ordering='unordered'): + def __init__(self, ordering="unordered"): self._ordering = ordering def _get_content(self, key): @@ -2117,7 +2229,7 @@ def _get_content(self, key): if key in self.nonlocal_keys: record = next(self.get_record_stream()) # Create a content object on the fly - lines = record.get_bytes_as('lines') + lines = record.get_bytes_as("lines") return PlainKnitContent(lines, record.key) else: # local keys we can ask for directly @@ -2147,9 +2259,8 @@ def _work(self): break # Loop over fallback repositories asking them for texts - ignore # any missing from a particular fallback. - for record in source.get_record_stream(missing_keys, - self._ordering, True): - if record.storage_kind == 'absent': + for record in source.get_record_stream(missing_keys, self._ordering, True): + if record.storage_kind == "absent": # Not in thie particular stream, may be in one of the # other fallback vfs objects. continue @@ -2157,7 +2268,7 @@ def _work(self): yield record if self._raw_record_map is None: - raise AssertionError('_raw_record_map should have been filled') + raise AssertionError("_raw_record_map should have been filled") first = True for key in self.keys: if key in self.nonlocal_keys: @@ -2178,8 +2289,7 @@ def _get_one_work(self, requested_key): # final output. multiple_versions = len(self.keys) != 1 if self._record_map is None: - self._record_map = self.vf._raw_map_to_record_map( - self._raw_record_map) + self._record_map = self.vf._raw_map_to_record_map(self._raw_record_map) record_map = self._record_map # raw_record_map is key: # Have read and parsed records at this point. 
@@ -2202,14 +2312,17 @@ def _get_one_work(self, requested_key): break content = None - for (component_id, record, record_details, - digest) in reversed(components): # noqa: B007 + for component_id, record, record_details, digest in reversed(components): # noqa: B007 if component_id in self._contents_map: content = self._contents_map[component_id] else: content, delta = self._factory.parse_record( - key[-1], record, record_details, content, - copy_base_content=multiple_versions) + key[-1], + record, + record_details, + content, + copy_base_content=multiple_versions, + ) if multiple_versions: self._contents_map[component_id] = content @@ -2237,15 +2350,18 @@ def _wire_bytes(self): """ lines = [] # kind marker for dispatch on the far side, - lines.append(b'knit-delta-closure') + lines.append(b"knit-delta-closure") # Annotated or not if self.vf._factory.annotated: - lines.append(b'annotated') + lines.append(b"annotated") else: - lines.append(b'') + lines.append(b"") # then the list of keys - lines.append(b'\t'.join(b'\x00'.join(key) for key in self.keys - if key not in self.nonlocal_keys)) + lines.append( + b"\t".join( + b"\x00".join(key) for key in self.keys if key not in self.nonlocal_keys + ) + ) # then the _raw_record_map in serialised form: map_byte_list = [] # for each item in the map: @@ -2256,37 +2372,53 @@ def _wire_bytes(self): # one line with next ('' for None) # one line with byte count of the record bytes # the record bytes - for key, (record_bytes, (method, noeol), next) in ( - self._raw_record_map.items()): - key_bytes = b'\x00'.join(key) + for key, (record_bytes, (method, noeol), next) in self._raw_record_map.items(): + key_bytes = b"\x00".join(key) parents = self.global_map.get(key, None) if parents is None: - parent_bytes = b'None:' + parent_bytes = b"None:" else: - parent_bytes = b'\t'.join(b'\x00'.join(key) for key in parents) - method_bytes = method.encode('ascii') + parent_bytes = b"\t".join(b"\x00".join(key) for key in parents) + method_bytes = method.encode("ascii") if noeol: noeol_bytes = b"T" else: noeol_bytes = b"F" if next: - next_bytes = b'\x00'.join(next) + next_bytes = b"\x00".join(next) else: - next_bytes = b'' - map_byte_list.append(b'\n'.join( - [key_bytes, parent_bytes, method_bytes, noeol_bytes, next_bytes, - b'%d' % len(record_bytes), record_bytes])) - map_bytes = b''.join(map_byte_list) + next_bytes = b"" + map_byte_list.append( + b"\n".join( + [ + key_bytes, + parent_bytes, + method_bytes, + noeol_bytes, + next_bytes, + b"%d" % len(record_bytes), + record_bytes, + ] + ) + ) + map_bytes = b"".join(map_byte_list) lines.append(map_bytes) - bytes = b'\n'.join(lines) + bytes = b"\n".join(lines) return bytes class _VFContentMapGenerator(_ContentMapGenerator): """Content map generator reading from a VersionedFiles object.""" - def __init__(self, versioned_files, keys, nonlocal_keys=None, - global_map=None, raw_record_map=None, ordering='unordered'): + def __init__( + self, + versioned_files, + keys, + nonlocal_keys=None, + global_map=None, + raw_record_map=None, + ordering="unordered", + ): """Create a _ContentMapGenerator. :param versioned_files: The versioned files that the texts are being @@ -2320,8 +2452,9 @@ def __init__(self, versioned_files, keys, nonlocal_keys=None, # texts. 
self._record_map = None if raw_record_map is None: - self._raw_record_map = self.vf._get_record_map_unparsed(keys, - allow_missing=True) + self._raw_record_map = self.vf._get_record_map_unparsed( + keys, allow_missing=True + ) else: self._raw_record_map = raw_record_map # the factory for parsing records @@ -2343,65 +2476,67 @@ def __init__(self, bytes, line_end): self.vf = KnitVersionedFiles(None, None) start = line_end # Annotated or not - line_end = bytes.find(b'\n', start) + line_end = bytes.find(b"\n", start) line = bytes[start:line_end] start = line_end + 1 - if line == b'annotated': + if line == b"annotated": self._factory = KnitAnnotateFactory() else: self._factory = KnitPlainFactory() # list of keys to emit in get_record_stream - line_end = bytes.find(b'\n', start) + line_end = bytes.find(b"\n", start) line = bytes[start:line_end] start = line_end + 1 self.keys = [ - tuple(segment.split(b'\x00')) for segment in line.split(b'\t') - if segment] + tuple(segment.split(b"\x00")) for segment in line.split(b"\t") if segment + ] # now a loop until the end. XXX: It would be nice if this was just a # bunch of the same records as get_record_stream(..., False) gives, but # there is a decent sized gap stopping that at the moment. end = len(bytes) while start < end: # 1 line with key - line_end = bytes.find(b'\n', start) - key = tuple(bytes[start:line_end].split(b'\x00')) + line_end = bytes.find(b"\n", start) + key = tuple(bytes[start:line_end].split(b"\x00")) start = line_end + 1 # 1 line with parents (None: for None, '' for ()) - line_end = bytes.find(b'\n', start) + line_end = bytes.find(b"\n", start) line = bytes[start:line_end] - if line == b'None:': + if line == b"None:": parents = None else: parents = tuple( - tuple(segment.split(b'\x00')) for segment in line.split(b'\t') - if segment) + tuple(segment.split(b"\x00")) + for segment in line.split(b"\t") + if segment + ) self.global_map[key] = parents start = line_end + 1 # one line with method - line_end = bytes.find(b'\n', start) + line_end = bytes.find(b"\n", start) line = bytes[start:line_end] - method = line.decode('ascii') + method = line.decode("ascii") start = line_end + 1 # one line with noeol - line_end = bytes.find(b'\n', start) + line_end = bytes.find(b"\n", start) line = bytes[start:line_end] noeol = line == b"T" start = line_end + 1 # one line with next (b'' for None) - line_end = bytes.find(b'\n', start) + line_end = bytes.find(b"\n", start) line = bytes[start:line_end] if not line: next = None else: - next = tuple(bytes[start:line_end].split(b'\x00')) + next = tuple(bytes[start:line_end].split(b"\x00")) start = line_end + 1 # one line with byte count of the record bytes - line_end = bytes.find(b'\n', start) + line_end = bytes.find(b"\n", start) line = bytes[start:line_end] count = int(line) start = line_end + 1 # the record bytes - record_bytes = bytes[start:start + count] + record_bytes = bytes[start : start + count] start = start + count # put it in the map self._raw_record_map[key] = (record_bytes, (method, noeol), next) @@ -2509,7 +2644,7 @@ def add_records(self, records, random_id=False, missing_compression_parents=Fals for record in records: key = record[0] prefix = key[:-1] - path = self._mapper.map(key) + '.kndx' + path = self._mapper.map(key) + ".kndx" path_keys = paths.setdefault(path, (prefix, [])) path_keys[1].append(record) for path in sorted(paths): @@ -2526,17 +2661,22 @@ def add_records(self, records, random_id=False, missing_compression_parents=Fals if parents is None: # kndx indices cannot be parentless. 
parents = () - line = b' '.join([ - b'\n' - + key[-1], b','.join(options), b'%d' % pos, b'%d' % size, - self._dictionary_compress(parents), b':']) + line = b" ".join( + [ + b"\n" + key[-1], + b",".join(options), + b"%d" % pos, + b"%d" % size, + self._dictionary_compress(parents), + b":", + ] + ) if not isinstance(line, bytes): - raise AssertionError( - f'data must be utf8 was {type(line)}') + raise AssertionError(f"data must be utf8 was {type(line)}") lines.append(line) self._cache_key(key, options, pos, size, parents) if len(orig_history): - self._transport.append_bytes(path, b''.join(lines)) + self._transport.append_bytes(path, b"".join(lines)) else: self._init_index(path, lines) except: @@ -2578,16 +2718,11 @@ def _cache_key(self, key, options, pos, size, parent_keys): history.append(version_id) else: index = cache[version_id][5] - cache[version_id] = (version_id, - options, - pos, - size, - parents, - index) + cache[version_id] = (version_id, options, pos, size, parents, index) def check_header(self, fp): line = fp.readline() - if line == b'': + if line == b"": # An empty file can actually be treated as though the file doesn't # exist yet. raise _mod_transport.NoSuchFile(self) @@ -2606,7 +2741,7 @@ def _check_write_ok(self): raise errors.ObjectNotLocked(self) if self._get_scope() != self._scope: self._reset_cache() - if self._mode != 'w': + if self._mode != "w": raise errors.ReadOnlyObjectDirtiedError(self) def get_build_details(self, keys): @@ -2637,23 +2772,22 @@ def get_build_details(self, keys): if not isinstance(method, str): raise TypeError(method) parents = parent_map[key] - if method == 'fulltext': + if method == "fulltext": compression_parent = None else: compression_parent = parents[0] - noeol = b'no-eol' in self.get_options(key) + noeol = b"no-eol" in self.get_options(key) index_memo = self.get_position(key) - result[key] = (index_memo, compression_parent, - parents, (method, noeol)) + result[key] = (index_memo, compression_parent, parents, (method, noeol)) return result def get_method(self, key): """Return compression method of specified key.""" options = self.get_options(key) - if b'fulltext' in options: - return 'fulltext' - elif b'line-delta' in options: - return 'line-delta' + if b"fulltext" in options: + return "fulltext" + elif b"line-delta" in options: + return "line-delta" else: raise KnitIndexUnknownMethod(self, options) @@ -2688,11 +2822,9 @@ def find_ancestry(self, keys): except KeyError: missing_keys.add(key) else: - parent_keys = tuple([prefix + (suffix,) - for suffix in suffix_parents]) + parent_keys = tuple([prefix + (suffix,) for suffix in suffix_parents]) parent_map[key] = parent_keys - pending_keys.extend([p for p in parent_keys - if p not in parent_map]) + pending_keys.extend([p for p in parent_keys if p not in parent_map]) return parent_map, missing_keys def get_parent_map(self, keys): @@ -2715,8 +2847,7 @@ def get_parent_map(self, keys): except KeyError: pass else: - result[key] = tuple(prefix + (suffix,) for - suffix in suffix_parents) + result[key] = tuple(prefix + (suffix,) for suffix in suffix_parents) return result def get_position(self, key): @@ -2740,8 +2871,7 @@ def _init_index(self, path, extra_lines=None): sio.write(self.HEADER) sio.writelines(extra_lines) sio.seek(0) - self._transport.put_file_non_atomic(path, sio, - create_parent_dir=True) + self._transport.put_file_non_atomic(path, sio, create_parent_dir=True) # self._create_parent_dir) # mode=self._file_mode, # dir_mode=self._dir_mode) @@ -2778,7 +2908,7 @@ def _load_prefixes(self, 
prefixes): self._history = [] self._filename = prefix try: - path = self._mapper.map(prefix) + '.kndx' + path = self._mapper.map(prefix) + ".kndx" with self._transport.get(path) as fp: # _load_data may raise NoSuchFile if the target knit is # completely empty. @@ -2815,7 +2945,7 @@ def _dictionary_compress(self, keys): '.' prefix. """ if not keys: - return b'' + return b"" result_list = [] prefix = keys[0][:-1] cache = self._kndx_cache[prefix][0] @@ -2825,11 +2955,11 @@ def _dictionary_compress(self, keys): raise ValueError(f"mismatched prefixes for {keys!r}") if key[-1] in cache: # -- inlined lookup() -- - result_list.append(b'%d' % cache[key[-1]][5]) + result_list.append(b"%d" % cache[key[-1]][5]) # -- end lookup () -- else: - result_list.append(b'.' + key[-1]) - return b' '.join(result_list) + result_list.append(b"." + key[-1]) + return b" ".join(result_list) def _reset_cache(self): # Possibly this should be a LRU cache. A dictionary from key_prefix to @@ -2838,9 +2968,9 @@ def _reset_cache(self): self._scope = self._get_scope() allow_writes = self._allow_writes() if allow_writes: - self._mode = 'w' + self._mode = "w" else: - self._mode = 'r' + self._mode = "r" def _sort_keys_by_io(self, keys, positions): """Figure out an optimal order to read the records for the given keys. @@ -2853,6 +2983,7 @@ def _sort_keys_by_io(self, keys, positions): _get_components_positions() :return: None """ + def get_sort_key(key): index_memo = positions[key][1] # Group by prefix and position. index_memo[0] is the key, so it is @@ -2860,6 +2991,7 @@ def get_sort_key(key): # index_memo[1] is the position, and index_memo[2] is the size, # which doesn't matter for the sort return index_memo[0][:-1], index_memo[1] + return keys.sort(key=get_sort_key) _get_total_build_size = _get_total_build_size @@ -2875,8 +3007,15 @@ def _split_key(self, key): class _KnitGraphIndex: """A KnitVersionedFiles index layered on GraphIndex.""" - def __init__(self, graph_index, is_locked, deltas=False, parents=True, - add_callback=None, track_external_parent_refs=False): + def __init__( + self, + graph_index, + is_locked, + deltas=False, + parents=True, + add_callback=None, + track_external_parent_refs=False, + ): """Construct a KnitGraphIndex on a graph_index. :param graph_index: An implementation of breezy.index.GraphIndex. @@ -2901,8 +3040,9 @@ def __init__(self, graph_index, is_locked, deltas=False, parents=True, if deltas and not parents: # XXX: TODO: Delta tree and parent graph should be conceptually # separate. - raise KnitCorrupt(self, "Cannot do delta compression without " - "parent tracking.") + raise KnitCorrupt( + self, "Cannot do delta compression without " "parent tracking." + ) self.has_graph = parents self._is_locked = is_locked self._missing_compression_parents = set() @@ -2914,8 +3054,7 @@ def __init__(self, graph_index, is_locked, deltas=False, parents=True, def __repr__(self): return f"{self.__class__.__name__}({self._graph_index!r})" - def add_records(self, records, random_id=False, - missing_compression_parents=False): + def add_records(self, records, random_id=False, missing_compression_parents=False): """Add multiple records to the index. 
This function does not insert data into the Immutable GraphIndex @@ -2940,50 +3079,54 @@ def add_records(self, records, random_id=False, keys = {} compression_parents = set() key_dependencies = self._key_dependencies - for (key, options, access_memo, parents) in records: + for key, options, access_memo, parents in records: if self._parents: parents = tuple(parents) if key_dependencies is not None: key_dependencies.add_references(key, parents) index, pos, size = access_memo - if b'no-eol' in options: - value = b'N' + if b"no-eol" in options: + value = b"N" else: - value = b' ' + value = b" " value += b"%d %d" % (pos, size) if not self._deltas: - if b'line-delta' in options: + if b"line-delta" in options: raise KnitCorrupt( - self, "attempt to add line-delta in non-delta knit") + self, "attempt to add line-delta in non-delta knit" + ) if self._parents: if self._deltas: - if b'line-delta' in options: + if b"line-delta" in options: node_refs = (parents, (parents[0],)) if missing_compression_parents: compression_parents.add(parents[0]) else: node_refs = (parents, ()) else: - node_refs = (parents, ) + node_refs = (parents,) else: if parents: - raise KnitCorrupt(self, "attempt to add node with parents " - "in parentless index.") + raise KnitCorrupt( + self, "attempt to add node with parents " "in parentless index." + ) node_refs = () keys[key] = (value, node_refs) # check for dups if not random_id: present_nodes = self._get_entries(keys) - for (_index, key, value, node_refs) in present_nodes: + for _index, key, value, node_refs in present_nodes: parents = node_refs[:1] # Sometimes these are passed as a list rather than a tuple passed = static_tuple.as_tuples(keys[key]) passed_parents = passed[1][:1] - if (value[0:1] != keys[key][0][0:1] - or parents != passed_parents): + if value[0:1] != keys[key][0][0:1] or parents != passed_parents: node_refs = static_tuple.as_tuples(node_refs) - raise KnitCorrupt(self, "inconsistent details in add_records" - f": {(value, node_refs)} {passed}") + raise KnitCorrupt( + self, + "inconsistent details in add_records" + f": {(value, node_refs)} {passed}", + ) del keys[key] result = [] if self._parents: @@ -3038,7 +3181,8 @@ def get_missing_parents(self): # groupcompress._GCGraphIndex.get_missing_parents # We may have false positives, so filter those out. 
self._key_dependencies.satisfy_refs_for_keys( - self.get_parent_map(self._key_dependencies.get_unsatisfied_refs())) + self.get_parent_map(self._key_dependencies.get_unsatisfied_refs()) + ) return frozenset(self._key_dependencies.get_unsatisfied_refs()) def _check_read(self): @@ -3059,7 +3203,8 @@ def _compression_parent(self, an_entry): return None if len(compression_parents) != 1: raise AssertionError( - f"Too many compression parents: {compression_parents!r}") + f"Too many compression parents: {compression_parents!r}" + ) return compression_parents[0] def get_build_details(self, keys): @@ -3094,14 +3239,17 @@ def get_build_details(self, keys): compression_parent_key = None else: compression_parent_key = self._compression_parent(entry) - noeol = (entry[2][0:1] == b'N') + noeol = entry[2][0:1] == b"N" if compression_parent_key: - method = 'line-delta' + method = "line-delta" else: - method = 'fulltext' - result[key] = (self._node_to_position(entry), - compression_parent_key, parents, - (method, noeol)) + method = "fulltext" + result[key] = ( + self._node_to_position(entry), + compression_parent_key, + parents, + (method, noeol), + ) return result def _get_entries(self, keys, check_present=False): @@ -3131,11 +3279,11 @@ def get_method(self, key): def _get_method(self, node): if not self._deltas: - return 'fulltext' + return "fulltext" if self._compression_parent(node): - return 'line-delta' + return "line-delta" else: - return 'fulltext' + return "fulltext" def _get_node(self, key): try: @@ -3149,9 +3297,9 @@ def get_options(self, key): e.g. ['foo', 'bar'] """ node = self._get_node(key) - options = [self._get_method(node).encode('ascii')] - if node[2][0:1] == b'N': - options.append(b'no-eol') + options = [self._get_method(node).encode("ascii")] + if node[2][0:1] == b"N": + options.append(b"no-eol") return options def find_ancestry(self, keys): @@ -3199,7 +3347,7 @@ def keys(self): def _node_to_position(self, node): """Convert an index value to position details.""" - bits = node[2][1:].split(b' ') + bits = node[2][1:].split(b" ") return node[0], int(bits[0]), int(bits[1]) def _sort_keys_by_io(self, keys, positions): @@ -3213,6 +3361,7 @@ def _sort_keys_by_io(self, keys, positions): _get_components_positions() :return: None """ + def get_index_memo(key): # index_memo is at offset [1]. It is made up of (GraphIndex, # position, size). GI is an object, which will be unique for each @@ -3220,6 +3369,7 @@ def get_index_memo(key): # position. Size doesn't matter, but it isn't worth breaking up the # tuple. return positions[key][1] + return keys.sort(key=get_index_memo) _get_total_build_size = _get_total_build_size @@ -3252,10 +3402,10 @@ def add_raw_record(self, key, size, raw_data): """ path = self._mapper.map(key) try: - base = self._transport.append_bytes(path + '.knit', b''.join(raw_data)) + base = self._transport.append_bytes(path + ".knit", b"".join(raw_data)) except _mod_transport.NoSuchFile: self._transport.mkdir(osutils.dirname(path)) - base = self._transport.append_bytes(path + '.knit', b''.join(raw_data)) + base = self._transport.append_bytes(path + ".knit", b"".join(raw_data)) # if base == 0: # chmod. return (key, base, size) @@ -3273,17 +3423,16 @@ def add_raw_records(self, key_sizes, raw_data): opaque index memo. For _KnitKeyAccess the memo is (key, pos, length), where the key is the record key. 
""" - raw_data = b''.join(raw_data) + raw_data = b"".join(raw_data) if not isinstance(raw_data, bytes): - raise AssertionError( - f'data must be plain bytes was {type(raw_data)}') + raise AssertionError(f"data must be plain bytes was {type(raw_data)}") result = [] offset = 0 # TODO: This can be tuned for writing to sftp and other servers where # append() is relatively expensive by grouping the writes to each key # prefix. for key, size in key_sizes: - record_bytes = [raw_data[offset:offset + size]] + record_bytes = [raw_data[offset : offset + size]] result.append(self.add_raw_record(key, size, record_bytes)) offset += size return result @@ -3305,7 +3454,7 @@ def get_raw_records(self, memos_for_retrieval): # first pass, group into same-index request to minimise readv's issued. request_lists = [] current_prefix = None - for (key, offset, length) in memos_for_retrieval: + for key, offset, length in memos_for_retrieval: if current_prefix == key[:-1]: current_list.append((offset, length)) else: @@ -3317,7 +3466,7 @@ def get_raw_records(self, memos_for_retrieval): if current_prefix is not None: request_lists.append((current_prefix, current_list)) for prefix, read_vector in request_lists: - path = self._mapper.map(prefix) + '.knit' + path = self._mapper.map(prefix) + ".knit" for _pos, data in self._transport.readv(path, read_vector): yield data @@ -3382,14 +3531,14 @@ def _get_build_graph(self, key): # new_nodes = self._vf._index._get_entries(this_iteration) pending = set() for key, details in build_details.items(): - (index_memo, compression_parent, parent_keys, - record_details) = details + (index_memo, compression_parent, parent_keys, record_details) = details self._parent_map[key] = parent_keys self._heads_provider = None records.append((key, index_memo)) # Do we actually need to check _annotated_lines? 
- pending.update([p for p in parent_keys - if p not in self._all_build_details]) + pending.update( + [p for p in parent_keys if p not in self._all_build_details] + ) if parent_keys: for parent_key in parent_keys: if parent_key in self._num_needed_children: @@ -3415,8 +3564,9 @@ def _get_build_graph(self, key): self._num_needed_children[parent_key] += 1 else: self._num_needed_children[parent_key] = 1 - pending.update([p for p in parent_keys - if p not in self._all_build_details]) + pending.update( + [p for p in parent_keys if p not in self._all_build_details] + ) else: raise errors.RevisionNotPresent(key, self._vf) # Generally we will want to read the records in reverse order, because @@ -3434,9 +3584,10 @@ def _get_needed_texts(self, key, pb=None): try: records, ann_keys = self._get_build_graph(key) for idx, (sub_key, text, num_lines) in enumerate( - self._extract_texts(records)): + self._extract_texts(records) + ): if pb is not None: - pb.update(gettext('annotating'), idx, len(records)) + pb.update(gettext("annotating"), idx, len(records)) yield sub_key, text, num_lines for sub_key in ann_keys: text = self._text_cache[sub_key] @@ -3450,18 +3601,19 @@ def _get_needed_texts(self, key, pb=None): def _cache_delta_blocks(self, key, compression_parent, delta, lines): parent_lines = self._text_cache[compression_parent] - blocks = list(KnitContent.get_line_delta_blocks( - delta, parent_lines, lines)) + blocks = list(KnitContent.get_line_delta_blocks(delta, parent_lines, lines)) self._matching_blocks[(key, compression_parent)] = blocks - def _expand_record(self, key, parent_keys, compression_parent, record, - record_details): + def _expand_record( + self, key, parent_keys, compression_parent, record, record_details + ): delta = None if compression_parent: if compression_parent not in self._content_objects: # Waiting for the parent self._pending_deltas.setdefault(compression_parent, []).append( - (key, parent_keys, record, record_details)) + (key, parent_keys, record, record_details) + ) return None # We have the basis parent, so expand the delta num = self._num_compression_children[compression_parent] @@ -3479,12 +3631,13 @@ def _expand_record(self, key, parent_keys, compression_parent, record, # The alternative is to copy the lines into text cache, but then we # are copying anyway, so just do it here. content, delta = self._vf._factory.parse_record( - key, record, record_details, base_content, - copy_base_content=True) + key, record, record_details, base_content, copy_base_content=True + ) else: # Fulltext record content, _ = self._vf._factory.parse_record( - key, record, record_details, None) + key, record, record_details, None + ) if self._num_compression_children.get(key, 0) > 0: self._content_objects[key] = content lines = content.text() @@ -3509,8 +3662,9 @@ def _get_parent_annotations_and_matches(self, key, text, parent_key): blocks = self._matching_blocks.pop(block_key) parent_annotations = self._annotations_cache[parent_key] return parent_annotations, blocks - return annotate.Annotator._get_parent_annotations_and_matches(self, - key, text, parent_key) + return annotate.Annotator._get_parent_annotations_and_matches( + self, key, text, parent_key + ) def _process_pending(self, key): """The content for 'key' was just processed. 
@@ -3522,17 +3676,22 @@ def _process_pending(self, key): compression_parent = key children = self._pending_deltas.pop(key) for child_key, parent_keys, record, record_details in children: - self._expand_record(child_key, parent_keys, - compression_parent, - record, record_details) + self._expand_record( + child_key, parent_keys, compression_parent, record, record_details + ) if self._check_ready_for_annotations(child_key, parent_keys): to_return.append(child_key) # Also check any children that are waiting for this parent to be # annotation ready if key in self._pending_annotation: children = self._pending_annotation.pop(key) - to_return.extend([c for c, p_keys in children - if self._check_ready_for_annotations(c, p_keys)]) + to_return.extend( + [ + c + for c, p_keys in children + if self._check_ready_for_annotations(c, p_keys) + ] + ) return to_return def _check_ready_for_annotations(self, key, parent_keys): @@ -3546,8 +3705,9 @@ def _check_ready_for_annotations(self, key, parent_keys): # still waiting on at least one parent text, so queue it up # Note that if there are multiple parents, we need to wait # for all of them. - self._pending_annotation.setdefault(parent_key, - []).append((key, parent_keys)) + self._pending_annotation.setdefault(parent_key, []).append( + (key, parent_keys) + ) return False return True @@ -3582,19 +3742,19 @@ def _extract_texts(self, records): # code can know when it can stop caching fulltexts, as well. # Children that are missing their compression parent - for (key, record, _digest) in self._vf._read_records_iter(records): + for key, record, _digest in self._vf._read_records_iter(records): # ghosts? details = self._all_build_details[key] (_, compression_parent, parent_keys, record_details) = details - lines = self._expand_record(key, parent_keys, compression_parent, - record, record_details) + lines = self._expand_record( + key, parent_keys, compression_parent, record, record_details + ) if lines is None: # Pending delta should be queued up continue # At this point, we may be able to yield this content, if all # parents are also finished - yield_this_text = self._check_ready_for_annotations(key, - parent_keys) + yield_this_text = self._check_ready_for_annotations(key, parent_keys) if yield_this_text: # All parents present yield key, lines, len(lines) diff --git a/breezy/bzr/knitpack_repo.py b/breezy/bzr/knitpack_repo.py index a4a970d93e..e48cf74754 100644 --- a/breezy/bzr/knitpack_repo.py +++ b/breezy/bzr/knitpack_repo.py @@ -20,7 +20,9 @@ from .. 
import transport as _mod_transport from ..lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ import time from breezy import ( @@ -36,7 +38,8 @@ KnitPlainFactory, KnitVersionedFiles, ) -""") +""", +) from ..bzr import btree_index from ..bzr.index import ( @@ -60,50 +63,82 @@ class KnitPackRepository(PackRepository, KnitRepository): - - def __init__(self, _format, a_controldir, control_files, _commit_builder_class, - _revision_serializer, _inventory_serializer): - PackRepository.__init__(self, _format, a_controldir, control_files, - _commit_builder_class, _revision_serializer, _inventory_serializer) + def __init__( + self, + _format, + a_controldir, + control_files, + _commit_builder_class, + _revision_serializer, + _inventory_serializer, + ): + PackRepository.__init__( + self, + _format, + a_controldir, + control_files, + _commit_builder_class, + _revision_serializer, + _inventory_serializer, + ) if self._format.supports_chks: raise AssertionError("chk not supported") - index_transport = self._transport.clone('indices') - self._pack_collection = KnitRepositoryPackCollection(self, - self._transport, - index_transport, - self._transport.clone( - 'upload'), - self._transport.clone( - 'packs'), - _format.index_builder_class, - _format.index_class, - use_chk_index=False, - ) + index_transport = self._transport.clone("indices") + self._pack_collection = KnitRepositoryPackCollection( + self, + self._transport, + index_transport, + self._transport.clone("upload"), + self._transport.clone("packs"), + _format.index_builder_class, + _format.index_class, + use_chk_index=False, + ) self.inventories = KnitVersionedFiles( - _KnitGraphIndex(self._pack_collection.inventory_index.combined_index, - add_callback=self._pack_collection.inventory_index.add_callback, - deltas=True, parents=True, is_locked=self.is_locked), + _KnitGraphIndex( + self._pack_collection.inventory_index.combined_index, + add_callback=self._pack_collection.inventory_index.add_callback, + deltas=True, + parents=True, + is_locked=self.is_locked, + ), data_access=self._pack_collection.inventory_index.data_access, - max_delta_chain=200) + max_delta_chain=200, + ) self.revisions = KnitVersionedFiles( - _KnitGraphIndex(self._pack_collection.revision_index.combined_index, - add_callback=self._pack_collection.revision_index.add_callback, - deltas=False, parents=True, is_locked=self.is_locked, - track_external_parent_refs=True), + _KnitGraphIndex( + self._pack_collection.revision_index.combined_index, + add_callback=self._pack_collection.revision_index.add_callback, + deltas=False, + parents=True, + is_locked=self.is_locked, + track_external_parent_refs=True, + ), data_access=self._pack_collection.revision_index.data_access, - max_delta_chain=0) + max_delta_chain=0, + ) self.signatures = KnitVersionedFiles( - _KnitGraphIndex(self._pack_collection.signature_index.combined_index, - add_callback=self._pack_collection.signature_index.add_callback, - deltas=False, parents=False, is_locked=self.is_locked), + _KnitGraphIndex( + self._pack_collection.signature_index.combined_index, + add_callback=self._pack_collection.signature_index.add_callback, + deltas=False, + parents=False, + is_locked=self.is_locked, + ), data_access=self._pack_collection.signature_index.data_access, - max_delta_chain=0) + max_delta_chain=0, + ) self.texts = KnitVersionedFiles( - _KnitGraphIndex(self._pack_collection.text_index.combined_index, - add_callback=self._pack_collection.text_index.add_callback, - deltas=True, parents=True, 
is_locked=self.is_locked), + _KnitGraphIndex( + self._pack_collection.text_index.combined_index, + add_callback=self._pack_collection.text_index.add_callback, + deltas=True, + parents=True, + is_locked=self.is_locked, + ), data_access=self._pack_collection.text_index.data_access, - max_delta_chain=200) + max_delta_chain=200, + ) self.chk_bytes = None # True when the repository object is 'write locked' (as opposed to the # physical lock only taken out around changes to the pack-names list.) @@ -139,11 +174,13 @@ class RepositoryFormatKnitPack1(RepositoryFormatPack): @property def _revision_serializer(self): from .xml5 import revision_serializer_v5 + return revision_serializer_v5 @property def _inventory_serializer(self): from .xml5 import inventory_serializer_v5 + return inventory_serializer_v5 # What index classes to use @@ -151,13 +188,12 @@ def _inventory_serializer(self): index_class = GraphIndex def _get_matching_bzrdir(self): - return controldir.format_registry.make_controldir('pack-0.92') + return controldir.format_registry.make_controldir("pack-0.92") def _ignore_setting_bzrdir(self, format): pass - _matchingcontroldir = property( - _get_matching_bzrdir, _ignore_setting_bzrdir) + _matchingcontroldir = property(_get_matching_bzrdir, _ignore_setting_bzrdir) @classmethod def get_format_string(cls): @@ -188,11 +224,13 @@ class RepositoryFormatKnitPack3(RepositoryFormatPack): @property def _revision_serializer(self): from .xml5 import revision_serializer_v5 + return revision_serializer_v5 @property def _inventory_serializer(self): from .xml7 import inventory_serializer_v7 + return inventory_serializer_v7 # What index classes to use @@ -200,19 +238,19 @@ def _inventory_serializer(self): index_class = GraphIndex def _get_matching_bzrdir(self): - return controldir.format_registry.make_controldir( - 'pack-0.92-subtree') + return controldir.format_registry.make_controldir("pack-0.92-subtree") def _ignore_setting_bzrdir(self, format): pass - _matchingcontroldir = property( - _get_matching_bzrdir, _ignore_setting_bzrdir) + _matchingcontroldir = property(_get_matching_bzrdir, _ignore_setting_bzrdir) @classmethod def get_format_string(cls): """See RepositoryFormat.get_format_string().""" - return b"Bazaar pack repository format 1 with subtree support (needs bzr 0.92)\n" + return ( + b"Bazaar pack repository format 1 with subtree support (needs bzr 0.92)\n" + ) def get_format_description(self): """See RepositoryFormat.get_format_description().""" @@ -236,11 +274,13 @@ class RepositoryFormatKnitPack4(RepositoryFormatPack): @property def _revision_serializer(self): from .xml5 import revision_serializer_v5 + return revision_serializer_v5 @property def _inventory_serializer(self): from .xml6 import inventory_serializer_v6 + return inventory_serializer_v6 # What index classes to use @@ -248,20 +288,17 @@ def _inventory_serializer(self): index_class = GraphIndex def _get_matching_bzrdir(self): - return controldir.format_registry.make_controldir( - 'rich-root-pack') + return controldir.format_registry.make_controldir("rich-root-pack") def _ignore_setting_bzrdir(self, format): pass - _matchingcontroldir = property( - _get_matching_bzrdir, _ignore_setting_bzrdir) + _matchingcontroldir = property(_get_matching_bzrdir, _ignore_setting_bzrdir) @classmethod def get_format_string(cls): """See RepositoryFormat.get_format_string().""" - return (b"Bazaar pack repository format 1 with rich root" - b" (needs bzr 1.0)\n") + return b"Bazaar pack repository format 1 with rich root" b" (needs bzr 1.0)\n" def 
get_format_description(self): """See RepositoryFormat.get_format_description().""" @@ -285,21 +322,22 @@ class RepositoryFormatKnitPack5(RepositoryFormatPack): @property def _revision_serializer(self): from .xml5 import revision_serializer_v5 + return revision_serializer_v5 @property def _inventory_serializer(self): from .xml5 import inventory_serializer_v5 + return inventory_serializer_v5 def _get_matching_bzrdir(self): - return controldir.format_registry.make_controldir('1.6') + return controldir.format_registry.make_controldir("1.6") def _ignore_setting_bzrdir(self, format): pass - _matchingcontroldir = property( - _get_matching_bzrdir, _ignore_setting_bzrdir) + _matchingcontroldir = property(_get_matching_bzrdir, _ignore_setting_bzrdir) @classmethod def get_format_string(cls): @@ -330,22 +368,22 @@ class RepositoryFormatKnitPack5RichRoot(RepositoryFormatPack): @property def _revision_serializer(self): from .xml5 import revision_serializer_v5 + return revision_serializer_v5 @property def _inventory_serializer(self): from .xml6 import inventory_serializer_v6 + return inventory_serializer_v6 def _get_matching_bzrdir(self): - return controldir.format_registry.make_controldir( - '1.6.1-rich-root') + return controldir.format_registry.make_controldir("1.6.1-rich-root") def _ignore_setting_bzrdir(self, format): pass - _matchingcontroldir = property( - _get_matching_bzrdir, _ignore_setting_bzrdir) + _matchingcontroldir = property(_get_matching_bzrdir, _ignore_setting_bzrdir) @classmethod def get_format_string(cls): @@ -380,24 +418,24 @@ class RepositoryFormatKnitPack5RichRootBroken(RepositoryFormatPack): @property def _revision_serializer(self): from .xml5 import revision_serializer_v5 + return revision_serializer_v5 @property def _inventory_serializer(self): from .xml7 import inventory_serializer_v7 + return inventory_serializer_v7 def _get_matching_bzrdir(self): - matching = controldir.format_registry.make_controldir( - '1.6.1-rich-root') + matching = controldir.format_registry.make_controldir("1.6.1-rich-root") matching.repository_format = self return matching def _ignore_setting_bzrdir(self, format): pass - _matchingcontroldir = property( - _get_matching_bzrdir, _ignore_setting_bzrdir) + _matchingcontroldir = property(_get_matching_bzrdir, _ignore_setting_bzrdir) @classmethod def get_format_string(cls): @@ -405,8 +443,10 @@ def get_format_string(cls): return b"Bazaar RepositoryFormatKnitPack5RichRoot (bzr 1.6)\n" def get_format_description(self): - return ("Packs 5 rich-root (adds stacking support, requires bzr 1.6)" - " (deprecated)") + return ( + "Packs 5 rich-root (adds stacking support, requires bzr 1.6)" + " (deprecated)" + ) def is_deprecated(self): return True @@ -429,21 +469,22 @@ class RepositoryFormatKnitPack6(RepositoryFormatPack): @property def _revision_serializer(self): from .xml5 import revision_serializer_v5 + return revision_serializer_v5 @property def _inventory_serializer(self): from .xml5 import inventory_serializer_v5 + return inventory_serializer_v5 def _get_matching_bzrdir(self): - return controldir.format_registry.make_controldir('1.9') + return controldir.format_registry.make_controldir("1.9") def _ignore_setting_bzrdir(self, format): pass - _matchingcontroldir = property( - _get_matching_bzrdir, _ignore_setting_bzrdir) + _matchingcontroldir = property(_get_matching_bzrdir, _ignore_setting_bzrdir) @classmethod def get_format_string(cls): @@ -473,22 +514,22 @@ class RepositoryFormatKnitPack6RichRoot(RepositoryFormatPack): @property def _revision_serializer(self): 
from .xml5 import revision_serializer_v5 + return revision_serializer_v5 @property def _inventory_serializer(self): from .xml6 import inventory_serializer_v6 + return inventory_serializer_v6 def _get_matching_bzrdir(self): - return controldir.format_registry.make_controldir( - '1.9-rich-root') + return controldir.format_registry.make_controldir("1.9-rich-root") def _ignore_setting_bzrdir(self, format): pass - _matchingcontroldir = property( - _get_matching_bzrdir, _ignore_setting_bzrdir) + _matchingcontroldir = property(_get_matching_bzrdir, _ignore_setting_bzrdir) @classmethod def get_format_string(cls): @@ -521,33 +562,37 @@ class RepositoryFormatPackDevelopment2Subtree(RepositoryFormatPack): @property def _revision_serializer(self): from .xml5 import revision_serializer_v5 + return revision_serializer_v5 @property def _inventory_serializer(self): from .xml7 import inventory_serializer_v7 + return inventory_serializer_v7 def _get_matching_bzrdir(self): - return controldir.format_registry.make_controldir( - 'development5-subtree') + return controldir.format_registry.make_controldir("development5-subtree") def _ignore_setting_bzrdir(self, format): pass - _matchingcontroldir = property( - _get_matching_bzrdir, _ignore_setting_bzrdir) + _matchingcontroldir = property(_get_matching_bzrdir, _ignore_setting_bzrdir) @classmethod def get_format_string(cls): """See RepositoryFormat.get_format_string().""" - return (b"Bazaar development format 2 with subtree support " - b"(needs bzr.dev from before 1.8)\n") + return ( + b"Bazaar development format 2 with subtree support " + b"(needs bzr.dev from before 1.8)\n" + ) def get_format_description(self): """See RepositoryFormat.get_format_description().""" - return ("Development repository format, currently the same as " - "1.6.1-subtree with B+Tree indices.\n") + return ( + "Development repository format, currently the same as " + "1.6.1-subtree with B+Tree indices.\n" + ) class KnitPackStreamSource(StreamSource): @@ -565,55 +610,61 @@ class KnitPackStreamSource(StreamSource): def __init__(self, from_repository, to_format): super().__init__(from_repository, to_format) self._text_keys = None - self._text_fetch_order = 'unordered' + self._text_fetch_order = "unordered" def _get_filtered_inv_stream(self, revision_ids): from_repo = self.from_repository parent_ids = from_repo._find_parent_ids_of_revisions(revision_ids) parent_keys = [(p,) for p in parent_ids] find_text_keys = from_repo._inventory_serializer._find_text_key_references - parent_text_keys = set(find_text_keys( - from_repo._inventory_xml_lines_for_keys(parent_keys))) + parent_text_keys = set( + find_text_keys(from_repo._inventory_xml_lines_for_keys(parent_keys)) + ) content_text_keys = set() knit = KnitVersionedFiles(None, None) factory = KnitPlainFactory() def find_text_keys_from_content(record): - if record.storage_kind not in ('knit-delta-gz', 'knit-ft-gz'): - raise ValueError("Unknown content storage kind for" - f" inventory text: {record.storage_kind}") + if record.storage_kind not in ("knit-delta-gz", "knit-ft-gz"): + raise ValueError( + "Unknown content storage kind for" + f" inventory text: {record.storage_kind}" + ) # It's a knit record, it has a _raw_record field (even if it was # reconstituted from a network stream). 
raw_data = record._raw_record # read the entire thing revision_id = record.key[-1] content, _ = knit._parse_record(revision_id, raw_data) - if record.storage_kind == 'knit-delta-gz': + if record.storage_kind == "knit-delta-gz": line_iterator = factory.get_linedelta_content(content) - elif record.storage_kind == 'knit-ft-gz': + elif record.storage_kind == "knit-ft-gz": line_iterator = factory.get_fulltext_content(content) - content_text_keys.update(find_text_keys( - [(line, revision_id) for line in line_iterator])) + content_text_keys.update( + find_text_keys([(line, revision_id) for line in line_iterator]) + ) + revision_keys = [(r,) for r in revision_ids] def _filtered_inv_stream(): source_vf = from_repo.inventories - stream = source_vf.get_record_stream(revision_keys, - 'unordered', False) + stream = source_vf.get_record_stream(revision_keys, "unordered", False) for record in stream: - if record.storage_kind == 'absent': + if record.storage_kind == "absent": raise errors.NoSuchRevision(from_repo, record.key) find_text_keys_from_content(record) yield record self._text_keys = content_text_keys - parent_text_keys - return ('inventories', _filtered_inv_stream()) + + return ("inventories", _filtered_inv_stream()) def _get_text_stream(self): # Note: We know we don't have to handle adding root keys, because both # the source and target are the identical network name. text_stream = self.from_repository.texts.get_record_stream( - self._text_keys, self._text_fetch_order, False) - return ('texts', text_stream) + self._text_keys, self._text_fetch_order, False + ) + return ("texts", text_stream) def get_stream(self, search): revision_ids = search.get_keys() @@ -626,11 +677,16 @@ def get_stream(self, search): class KnitPacker(Packer): """Packer that works with knit packs.""" - def __init__(self, pack_collection, packs, suffix, revision_ids=None, - reload_func=None): - super().__init__(pack_collection, packs, suffix, - revision_ids=revision_ids, - reload_func=reload_func) + def __init__( + self, pack_collection, packs, suffix, revision_ids=None, reload_func=None + ): + super().__init__( + pack_collection, + packs, + suffix, + revision_ids=revision_ids, + reload_func=reload_func, + ) def _pack_map_and_index_list(self, index_attribute): """Convert a list of packs to an index pack map and index list. @@ -661,18 +717,19 @@ def _index_contents(self, indices, key_filter=None): else: return all_index.iter_entries(key_filter) - def _copy_nodes(self, nodes, index_map, writer, write_index, - output_lines=None): + def _copy_nodes(self, nodes, index_map, writer, write_index, output_lines=None): """Copy knit nodes between packs with no graph references. :param output_lines: Output full texts of copied items. 
""" with ui.ui_factory.nested_progress_bar() as pb: - return self._do_copy_nodes(nodes, index_map, writer, - write_index, pb, output_lines=output_lines) + return self._do_copy_nodes( + nodes, index_map, writer, write_index, pb, output_lines=output_lines + ) - def _do_copy_nodes(self, nodes, index_map, writer, write_index, pb, - output_lines=None): + def _do_copy_nodes( + self, nodes, index_map, writer, write_index, pb, output_lines=None + ): # for record verification knit = KnitVersionedFiles(None, None) # plan a readv on each source pack: @@ -693,7 +750,7 @@ def _do_copy_nodes(self, nodes, index_map, writer, write_index, pb, pack_readv_requests = [] for key, value in items: # ---- KnitGraphIndex.get_position - bits = value[1:].split(b' ') + bits = value[1:].split(b" ") offset, length = int(bits[0]), int(bits[1]) pack_readv_requests.append((offset, length, (key, value[0:1]))) # linear scan up the pack @@ -702,14 +759,16 @@ def _do_copy_nodes(self, nodes, index_map, writer, write_index, pb, pack_obj = index_map[index] transport, path = pack_obj.access_tuple() try: - reader = pack.make_readv_reader(transport, path, - [offset[0:2] for offset in pack_readv_requests]) + reader = pack.make_readv_reader( + transport, path, [offset[0:2] for offset in pack_readv_requests] + ) except _mod_transport.NoSuchFile: if self._reload_func is not None: self._reload_func() raise for (names, read_func), (_1, _2, (key, eol_flag)) in zip( - reader.iter_records(), pack_readv_requests): + reader.iter_records(), pack_readv_requests + ): raw_data = read_func(None) # check the header only if output_lines is not None: @@ -722,19 +781,41 @@ def _do_copy_nodes(self, nodes, index_map, writer, write_index, pb, pb.update("Copied record", record_index) record_index += 1 - def _copy_nodes_graph(self, index_map, writer, write_index, - readv_group_iter, total_items, output_lines=False): + def _copy_nodes_graph( + self, + index_map, + writer, + write_index, + readv_group_iter, + total_items, + output_lines=False, + ): """Copy knit nodes between packs. :param output_lines: Return lines present in the copied data as an iterator of line,version_id. 
""" with ui.ui_factory.nested_progress_bar() as pb: - yield from self._do_copy_nodes_graph(index_map, writer, - write_index, output_lines, pb, readv_group_iter, total_items) - - def _do_copy_nodes_graph(self, index_map, writer, write_index, - output_lines, pb, readv_group_iter, total_items): + yield from self._do_copy_nodes_graph( + index_map, + writer, + write_index, + output_lines, + pb, + readv_group_iter, + total_items, + ) + + def _do_copy_nodes_graph( + self, + index_map, + writer, + write_index, + output_lines, + pb, + readv_group_iter, + total_items, + ): # for record verification knit = KnitVersionedFiles(None, None) # for line extraction when requested (inventories only) @@ -753,7 +834,8 @@ def _do_copy_nodes_graph(self, index_map, writer, write_index, self._reload_func() raise for (names, read_func), (key, eol_flag, references) in zip( - reader.iter_records(), node_vector): + reader.iter_records(), node_vector + ): raw_data = read_func(None) if output_lines: # read the entire thing @@ -769,8 +851,7 @@ def _do_copy_nodes_graph(self, index_map, writer, write_index, df, _ = knit._parse_record_header(key, raw_data) df.close() pos, size = writer.add_bytes_record([raw_data], len(raw_data), names) - write_index.add_node(key, eol_flag + b"%d %d" % - (pos, size), references) + write_index.add_node(key, eol_flag + b"%d %d" % (pos, size), references) pb.update("Copied record", record_index) record_index += 1 @@ -778,11 +859,11 @@ def _process_inventory_lines(self, inv_lines): """Use up the inv_lines generator and setup a text key filter.""" repo = self._pack_collection.repo fileid_revisions = repo._find_file_ids_from_xml_inventory_lines( - inv_lines, self.revision_keys) + inv_lines, self.revision_keys + ) text_filter = [] for fileid, file_revids in fileid_revisions.items(): - text_filter.extend([(fileid, file_revid) - for file_revid in file_revids]) + text_filter.extend([(fileid, file_revid) for file_revid in file_revids]) self._text_filter = text_filter def _copy_inventory_texts(self): @@ -792,7 +873,8 @@ def _copy_inventory_texts(self): # is missed, so do not change it to query separately without cross # checking like the text key check below. inventory_index_map, inventory_indices = self._pack_map_and_index_list( - 'inventory_index') + "inventory_index" + ) inv_nodes = self._index_contents(inventory_indices, inv_keys) # copy inventory keys and adjust values # XXX: Should be a helper function to allow different inv representation @@ -801,21 +883,29 @@ def _copy_inventory_texts(self): total_items, readv_group_iter = self._least_readv_node_readv(inv_nodes) # Only grab the output lines if we will be processing them output_lines = bool(self.revision_ids) - inv_lines = self._copy_nodes_graph(inventory_index_map, - self.new_pack._writer, self.new_pack.inventory_index, - readv_group_iter, total_items, output_lines=output_lines) + inv_lines = self._copy_nodes_graph( + inventory_index_map, + self.new_pack._writer, + self.new_pack.inventory_index, + readv_group_iter, + total_items, + output_lines=output_lines, + ) if self.revision_ids: self._process_inventory_lines(inv_lines) else: # eat the iterator to cause it to execute. 
list(inv_lines) self._text_filter = None - if debug.debug_flag_enabled('pack'): - trace.mutter('%s: create_pack: inventories copied: %s%s %d items t+%6.3fs', - time.ctime(), self._pack_collection._upload_transport.base, - self.new_pack.random_name, - self.new_pack.inventory_index.key_count(), - time.time() - self.new_pack.start_time) + if debug.debug_flag_enabled("pack"): + trace.mutter( + "%s: create_pack: inventories copied: %s%s %d items t+%6.3fs", + time.ctime(), + self._pack_collection._upload_transport.base, + self.new_pack.random_name, + self.new_pack.inventory_index.key_count(), + time.time() - self.new_pack.start_time, + ) def _update_pack_order(self, entries, index_to_pack_map): """Determine how we want our packs to be ordered. @@ -836,53 +926,59 @@ def _update_pack_order(self, entries, index_to_pack_map): packs.append(index_to_pack_map[index]) seen_indexes.add(index) if len(packs) == len(self.packs): - if debug.debug_flag_enabled('pack'): - trace.mutter('Not changing pack list, all packs used.') + if debug.debug_flag_enabled("pack"): + trace.mutter("Not changing pack list, all packs used.") return seen_packs = set(packs) for pack in self.packs: if pack not in seen_packs: packs.append(pack) seen_packs.add(pack) - if debug.debug_flag_enabled('pack'): + if debug.debug_flag_enabled("pack"): old_names = [p.access_tuple()[1] for p in self.packs] new_names = [p.access_tuple()[1] for p in packs] - trace.mutter('Reordering packs\nfrom: %s\n to: %s', - old_names, new_names) + trace.mutter("Reordering packs\nfrom: %s\n to: %s", old_names, new_names) self.packs = packs def _copy_revision_texts(self): # select revisions if self.revision_ids: - revision_keys = [(revision_id,) - for revision_id in self.revision_ids] + revision_keys = [(revision_id,) for revision_id in self.revision_ids] else: revision_keys = None # select revision keys revision_index_map, revision_indices = self._pack_map_and_index_list( - 'revision_index') + "revision_index" + ) revision_nodes = self._index_contents(revision_indices, revision_keys) revision_nodes = list(revision_nodes) self._update_pack_order(revision_nodes, revision_index_map) # copy revision keys and adjust values self.pb.update("Copying revision texts", 1) - total_items, readv_group_iter = self._revision_node_readv( - revision_nodes) - list(self._copy_nodes_graph(revision_index_map, self.new_pack._writer, - self.new_pack.revision_index, readv_group_iter, total_items)) - if debug.debug_flag_enabled('pack'): - trace.mutter('%s: create_pack: revisions copied: %s%s %d items t+%6.3fs', - time.ctime(), self._pack_collection._upload_transport.base, - self.new_pack.random_name, - self.new_pack.revision_index.key_count(), - time.time() - self.new_pack.start_time) + total_items, readv_group_iter = self._revision_node_readv(revision_nodes) + list( + self._copy_nodes_graph( + revision_index_map, + self.new_pack._writer, + self.new_pack.revision_index, + readv_group_iter, + total_items, + ) + ) + if debug.debug_flag_enabled("pack"): + trace.mutter( + "%s: create_pack: revisions copied: %s%s %d items t+%6.3fs", + time.ctime(), + self._pack_collection._upload_transport.base, + self.new_pack.random_name, + self.new_pack.revision_index.key_count(), + time.time() - self.new_pack.start_time, + ) self._revision_keys = revision_keys def _get_text_nodes(self): - text_index_map, text_indices = self._pack_map_and_index_list( - 'text_index') - return text_index_map, self._index_contents(text_indices, - self._text_filter) + text_index_map, text_indices = 
self._pack_map_and_index_list("text_index") + return text_index_map, self._index_contents(text_indices, self._text_filter) def _copy_text_texts(self): # select text keys @@ -900,17 +996,21 @@ def _copy_text_texts(self): if missing_text_keys: # TODO: raise a specific error that can handle many missing # keys. - trace.mutter("missing keys during fetch: %r", - missing_text_keys) + trace.mutter("missing keys during fetch: %r", missing_text_keys) a_missing_key = missing_text_keys.pop() - raise errors.RevisionNotPresent(a_missing_key[1], - a_missing_key[0]) + raise errors.RevisionNotPresent(a_missing_key[1], a_missing_key[0]) # copy text keys and adjust values self.pb.update("Copying content texts", 3) - total_items, readv_group_iter = self._least_readv_node_readv( - text_nodes) - list(self._copy_nodes_graph(text_index_map, self.new_pack._writer, - self.new_pack.text_index, readv_group_iter, total_items)) + total_items, readv_group_iter = self._least_readv_node_readv(text_nodes) + list( + self._copy_nodes_graph( + text_index_map, + self.new_pack._writer, + self.new_pack.text_index, + readv_group_iter, + total_items, + ) + ) self._log_copied_texts() def _create_pack_from_packs(self): @@ -920,35 +1020,49 @@ def _create_pack_from_packs(self): # buffer data - we won't be reading-back during the pack creation and # this makes a significant difference on sftp pushes. new_pack.set_write_cache_size(1024 * 1024) - if debug.debug_flag_enabled('pack'): - plain_pack_list = [f'{a_pack.pack_transport.base}{a_pack.name}' - for a_pack in self.packs] + if debug.debug_flag_enabled("pack"): + plain_pack_list = [ + f"{a_pack.pack_transport.base}{a_pack.name}" for a_pack in self.packs + ] if self.revision_ids is not None: rev_count = len(self.revision_ids) else: - rev_count = 'all' - trace.mutter('%s: create_pack: creating pack from source packs: ' - '%s%s %s revisions wanted %s t=0', - time.ctime(), self._pack_collection._upload_transport.base, new_pack.random_name, - plain_pack_list, rev_count) + rev_count = "all" + trace.mutter( + "%s: create_pack: creating pack from source packs: " + "%s%s %s revisions wanted %s t=0", + time.ctime(), + self._pack_collection._upload_transport.base, + new_pack.random_name, + plain_pack_list, + rev_count, + ) self._copy_revision_texts() self._copy_inventory_texts() self._copy_text_texts() # select signature keys signature_filter = self._revision_keys # same keyspace signature_index_map, signature_indices = self._pack_map_and_index_list( - 'signature_index') - signature_nodes = self._index_contents(signature_indices, - signature_filter) + "signature_index" + ) + signature_nodes = self._index_contents(signature_indices, signature_filter) # copy signature keys and adjust values self.pb.update("Copying signature texts", 4) - self._copy_nodes(signature_nodes, signature_index_map, new_pack._writer, - new_pack.signature_index) - if debug.debug_flag_enabled('pack'): - trace.mutter('%s: create_pack: revision signatures copied: %s%s %d items t+%6.3fs', - time.ctime(), self._pack_collection._upload_transport.base, new_pack.random_name, - new_pack.signature_index.key_count(), - time.time() - new_pack.start_time) + self._copy_nodes( + signature_nodes, + signature_index_map, + new_pack._writer, + new_pack.signature_index, + ) + if debug.debug_flag_enabled("pack"): + trace.mutter( + "%s: create_pack: revision signatures copied: %s%s %d items t+%6.3fs", + time.ctime(), + self._pack_collection._upload_transport.base, + new_pack.random_name, + new_pack.signature_index.key_count(), + time.time() - 
new_pack.start_time, + ) new_pack._check_references() if not self._use_pack(new_pack): new_pack.abort() @@ -983,10 +1097,11 @@ def _least_readv_node_readv(self, nodes): pack_readv_requests = [] for key, value, references in items: # ---- KnitGraphIndex.get_position - bits = value[1:].split(b' ') + bits = value[1:].split(b" ") offset, length = int(bits[0]), int(bits[1]) pack_readv_requests.append( - ((offset, length), (key, value[0:1], references))) + ((offset, length), (key, value[0:1], references)) + ) # linear scan up the pack to maximum range combining. pack_readv_requests.sort() # split out the readv and the node data. @@ -1035,9 +1150,10 @@ def _copy_text_texts(self): # we have three major tasks here: # 1) generate the ideal index repo = self._pack_collection.repo - ancestors = {key[0]: tuple(ref[0] for ref in refs[0]) for - _1, key, _2, refs in - self.new_pack.revision_index.iter_all_entries()} + ancestors = { + key[0]: tuple(ref[0] for ref in refs[0]) + for _1, key, _2, refs in self.new_pack.revision_index.iter_all_entries() + } ideal_index = repo._generate_text_key_index(self._text_refs, ancestors) # 2) generate a text_nodes list that contains all the deltas that can # be used as-is, with corrected parents. @@ -1067,8 +1183,9 @@ def _copy_text_texts(self): # today. Either way, we can preserve the representation as # long as we change the refs to be inserted. self._data_changed = True - ok_nodes.append((node[0], node[1], node[2], - (ideal_parents, node[3][1]))) + ok_nodes.append( + (node[0], node[1], node[2], (ideal_parents, node[3][1])) + ) self._data_changed = True else: # Reinsert this text completely @@ -1079,8 +1196,15 @@ def _copy_text_texts(self): del text_nodes # 3) bulk copy the ok data total_items, readv_group_iter = self._least_readv_node_readv(ok_nodes) - list(self._copy_nodes_graph(text_index_map, self.new_pack._writer, - self.new_pack.text_index, readv_group_iter, total_items)) + list( + self._copy_nodes_graph( + text_index_map, + self.new_pack._writer, + self.new_pack.text_index, + readv_group_iter, + total_items, + ) + ) # 4) adhoc copy all the other texts. # We have to topologically insert all texts otherwise we can fail to # reconcile when parts of a single delta chain are preserved intact, @@ -1098,17 +1222,29 @@ def _copy_text_texts(self): repo.get_transaction() GraphIndexPrefixAdapter( self.new_pack.text_index, - ('blank', ), 1, - add_nodes_callback=self.new_pack.text_index.add_nodes) + ("blank",), + 1, + add_nodes_callback=self.new_pack.text_index.add_nodes, + ) data_access = _DirectPackAccess( - {self.new_pack.text_index: self.new_pack.access_tuple()}) - data_access.set_writer(self.new_pack._writer, self.new_pack.text_index, - self.new_pack.access_tuple()) + {self.new_pack.text_index: self.new_pack.access_tuple()} + ) + data_access.set_writer( + self.new_pack._writer, + self.new_pack.text_index, + self.new_pack.access_tuple(), + ) output_texts = KnitVersionedFiles( - _KnitGraphIndex(self.new_pack.text_index, - add_callback=self.new_pack.text_index.add_nodes, - deltas=True, parents=True, is_locked=repo.is_locked), - data_access=data_access, max_delta_chain=200) + _KnitGraphIndex( + self.new_pack.text_index, + add_callback=self.new_pack.text_index.add_nodes, + deltas=True, + parents=True, + is_locked=repo.is_locked, + ), + data_access=data_access, + max_delta_chain=200, + ) for key, parent_keys in bad_texts: # We refer to the new pack to delta data being output. 
# A possible improvement would be to catch errors on short reads @@ -1118,16 +1254,22 @@ def _copy_text_texts(self): for parent_key in parent_keys: if parent_key[0] != key[0]: # Graph parents must match the fileid - raise errors.BzrError(f'Mismatched key parent {key!r}:{parent_keys!r}') + raise errors.BzrError( + f"Mismatched key parent {key!r}:{parent_keys!r}" + ) parents.append(parent_key[1]) - text_lines = next(repo.texts.get_record_stream( - [key], 'unordered', True)).get_bytes_as('lines') - output_texts.add_lines(key, parent_keys, text_lines, - random_id=True, check_content=False) + text_lines = next( + repo.texts.get_record_stream([key], "unordered", True) + ).get_bytes_as("lines") + output_texts.add_lines( + key, parent_keys, text_lines, random_id=True, check_content=False + ) # 5) check that nothing inserted has a reference outside the keyspace. missing_text_keys = self.new_pack.text_index._external_references() if missing_text_keys: - raise errors.BzrCheckError(f'Reference to missing compression parents {missing_text_keys!r}') + raise errors.BzrCheckError( + f"Reference to missing compression parents {missing_text_keys!r}" + ) self._log_copied_texts() def _use_pack(self, new_pack): @@ -1171,10 +1313,11 @@ def _revision_node_readv(self, revision_nodes): for key in reversed(order): index, value, references = by_key[key] # ---- KnitGraphIndex.get_position - bits = value[1:].split(b' ') + bits = value[1:].split(b" ") offset, length = int(bits[0]), int(bits[1]) requests.append( - (index, [(offset, length)], [(key, value[0:1], references)])) + (index, [(offset, length)], [(key, value[0:1], references)]) + ) # TODO: combine requests in the same index that are in ascending order. return total, requests diff --git a/breezy/bzr/knitrepo.py b/breezy/bzr/knitrepo.py index d29faad509..cd35636912 100644 --- a/breezy/bzr/knitrepo.py +++ b/breezy/bzr/knitrepo.py @@ -18,7 +18,9 @@ from ..lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ from breezy import ( transactions, ) @@ -27,7 +29,8 @@ lockable_files, versionedfile, ) -""") +""", +) from .. import controldir, errors, lockdir, trace from .. import revision as _mod_revision from .. 
import transport as _mod_transport @@ -43,25 +46,23 @@ class _KnitParentsProvider: - def __init__(self, knit): self._knit = knit def __repr__(self): - return f'KnitParentsProvider({self._knit!r})' + return f"KnitParentsProvider({self._knit!r})" def get_parent_map(self, keys): """See graph.StackedParentsProvider.get_parent_map.""" parent_map = {} for revision_id in keys: if revision_id is None: - raise ValueError('get_parent_map(None) is not valid') + raise ValueError("get_parent_map(None) is not valid") if revision_id == _mod_revision.NULL_REVISION: parent_map[revision_id] = () else: try: - parents = tuple( - self._knit.get_parents_with_ghosts(revision_id)) + parents = tuple(self._knit.get_parents_with_ghosts(revision_id)) except errors.RevisionNotPresent: continue else: @@ -72,19 +73,17 @@ def get_parent_map(self, keys): class _KnitsParentsProvider: - def __init__(self, knit, prefix=()): """Create a parent provider for string keys mapped to tuple keys.""" self._knit = knit self._prefix = prefix def __repr__(self): - return f'KnitsParentsProvider({self._knit!r})' + return f"KnitsParentsProvider({self._knit!r})" def get_parent_map(self, keys): """See graph.StackedParentsProvider.get_parent_map.""" - parent_map = self._knit.get_parent_map( - [self._prefix + (key,) for key in keys]) + parent_map = self._knit.get_parent_map([self._prefix + (key,) for key in keys]) result = {} for key, parents in parent_map.items(): revid = key[-1] @@ -110,10 +109,16 @@ class KnitRepository(MetaDirVersionedFileRepository): _revision_serializer: RevisionSerializer _inventory_serializer: InventorySerializer - def __init__(self, _format, a_controldir, control_files, _commit_builder_class, - _revision_serializer, _inventory_serializer): - super().__init__( - _format, a_controldir, control_files) + def __init__( + self, + _format, + a_controldir, + control_files, + _commit_builder_class, + _revision_serializer, + _inventory_serializer, + ): + super().__init__(_format, a_controldir, control_files) self._commit_builder_class = _commit_builder_class self._revision_serializer = _revision_serializer self._inventory_serializer = _inventory_serializer @@ -128,16 +133,16 @@ def _activate_new_inventory(self): """Put a replacement inventory.new into use as inventories.""" # Copy the content across t = self._transport - t.copy('inventory.new.kndx', 'inventory.kndx') + t.copy("inventory.new.kndx", "inventory.kndx") try: - t.copy('inventory.new.knit', 'inventory.knit') + t.copy("inventory.new.knit", "inventory.knit") except _mod_transport.NoSuchFile: # empty inventories knit - t.delete('inventory.knit') + t.delete("inventory.knit") # delete the temp inventory - t.delete('inventory.new.kndx') + t.delete("inventory.new.kndx") try: - t.delete('inventory.new.knit') + t.delete("inventory.new.knit") except _mod_transport.NoSuchFile: # empty inventories knit pass @@ -147,33 +152,32 @@ def _activate_new_inventory(self): def _backup_inventory(self): t = self._transport - t.copy('inventory.kndx', 'inventory.backup.kndx') - t.copy('inventory.knit', 'inventory.backup.knit') + t.copy("inventory.kndx", "inventory.backup.kndx") + t.copy("inventory.knit", "inventory.backup.knit") def _move_file_id(self, from_id, to_id): - t = self._transport.clone('knits') + t = self._transport.clone("knits") from_rel_url = self.texts._index._mapper.map((from_id, None)) to_rel_url = self.texts._index._mapper.map((to_id, None)) # We expect both files to always exist in this case. 
- for suffix in ('.knit', '.kndx'): + for suffix in (".knit", ".kndx"): t.rename(from_rel_url + suffix, to_rel_url + suffix) def _remove_file_id(self, file_id): - t = self._transport.clone('knits') + t = self._transport.clone("knits") rel_url = self.texts._index._mapper.map((file_id, None)) - for suffix in ('.kndx', '.knit'): + for suffix in (".kndx", ".knit"): try: t.delete(rel_url + suffix) except _mod_transport.NoSuchFile: pass def _temp_inventories(self): - result = self._format._get_inventories(self._transport, self, - 'inventory.new') + result = self._format._get_inventories(self._transport, self, "inventory.new") # Reconciling when the output has no revisions would result in no # writes - but we want to ensure there is an inventory for # compatibility with older clients that don't lazy-load. - result.get_parent_map([(b'A',)]) + result.get_parent_map([(b"A",)]) return result def get_revision(self, revision_id): @@ -197,6 +201,7 @@ def _refresh_data(self): def reconcile(self, other=None, thorough=False): """Reconcile this repository.""" from .reconcile import KnitReconciler + with self.lock_write(): reconciler = KnitReconciler(self, thorough=thorough) return reconciler.reconcile() @@ -232,19 +237,22 @@ class RepositoryFormatKnit(MetaDirVersionedFileRepositoryFormat): @property def _revision_serializer(self): from .xml5 import revision_serializer_v5 + return revision_serializer_v5 @property def _inventory_serializer(self): from .xml5 import inventory_serializer_v5 + return inventory_serializer_v5 + # Knit based repositories handle ghosts reasonably well. supports_ghosts = True # External lookups are not supported in this format. supports_external_lookups = False # No CHK support. supports_chks = False - _fetch_order = 'topological' + _fetch_order = "topological" _fetch_uses_deltas = True fast_deltas = False supports_funky_characters = True @@ -252,37 +260,60 @@ def _inventory_serializer(self): # parent to the revision text. 
revision_graph_can_have_wrong_parents = True - def _get_inventories(self, repo_transport, repo, name='inventory'): + def _get_inventories(self, repo_transport, repo, name="inventory"): mapper = versionedfile.ConstantMapper(name) - index = _mod_knit._KndxIndex(repo_transport, mapper, - repo.get_transaction, repo.is_write_locked, repo.is_locked) + index = _mod_knit._KndxIndex( + repo_transport, + mapper, + repo.get_transaction, + repo.is_write_locked, + repo.is_locked, + ) access = _mod_knit._KnitKeyAccess(repo_transport, mapper) return _mod_knit.KnitVersionedFiles(index, access, annotated=False) def _get_revisions(self, repo_transport, repo): - mapper = versionedfile.ConstantMapper('revisions') - index = _mod_knit._KndxIndex(repo_transport, mapper, - repo.get_transaction, repo.is_write_locked, repo.is_locked) + mapper = versionedfile.ConstantMapper("revisions") + index = _mod_knit._KndxIndex( + repo_transport, + mapper, + repo.get_transaction, + repo.is_write_locked, + repo.is_locked, + ) access = _mod_knit._KnitKeyAccess(repo_transport, mapper) - return _mod_knit.KnitVersionedFiles(index, access, max_delta_chain=0, - annotated=False) + return _mod_knit.KnitVersionedFiles( + index, access, max_delta_chain=0, annotated=False + ) def _get_signatures(self, repo_transport, repo): - mapper = versionedfile.ConstantMapper('signatures') - index = _mod_knit._KndxIndex(repo_transport, mapper, - repo.get_transaction, repo.is_write_locked, repo.is_locked) + mapper = versionedfile.ConstantMapper("signatures") + index = _mod_knit._KndxIndex( + repo_transport, + mapper, + repo.get_transaction, + repo.is_write_locked, + repo.is_locked, + ) access = _mod_knit._KnitKeyAccess(repo_transport, mapper) - return _mod_knit.KnitVersionedFiles(index, access, max_delta_chain=0, - annotated=False) + return _mod_knit.KnitVersionedFiles( + index, access, max_delta_chain=0, annotated=False + ) def _get_texts(self, repo_transport, repo): mapper = versionedfile.HashEscapedPrefixMapper() - base_transport = repo_transport.clone('knits') - index = _mod_knit._KndxIndex(base_transport, mapper, - repo.get_transaction, repo.is_write_locked, repo.is_locked) + base_transport = repo_transport.clone("knits") + index = _mod_knit._KndxIndex( + base_transport, + mapper, + repo.get_transaction, + repo.is_write_locked, + repo.is_locked, + ) access = _mod_knit._KnitKeyAccess(base_transport, mapper) - return _mod_knit.KnitVersionedFiles(index, access, max_delta_chain=200, - annotated=True) + return _mod_knit.KnitVersionedFiles( + index, access, max_delta_chain=200, annotated=True + ) def initialize(self, a_controldir, shared=False): """Create a knit format 1 repository. @@ -292,24 +323,22 @@ def initialize(self, a_controldir, shared=False): :param shared: If true the repository will be initialized as a shared repository. 
""" - trace.mutter('creating repository in %s.', a_controldir.transport.base) - dirs = ['knits'] + trace.mutter("creating repository in %s.", a_controldir.transport.base) + dirs = ["knits"] files = [] - utf8_files = [('format', self.get_format_string())] + utf8_files = [("format", self.get_format_string())] - self._upload_blank_content( - a_controldir, dirs, files, utf8_files, shared) + self._upload_blank_content(a_controldir, dirs, files, utf8_files, shared) repo_transport = a_controldir.get_repository_transport(None) - lockable_files.LockableFiles(repo_transport, - 'lock', lockdir.LockDir) + lockable_files.LockableFiles(repo_transport, "lock", lockdir.LockDir) transactions.WriteTransaction() result = self.open(a_controldir=a_controldir, _found=True) result.lock_write() # the revision id here is irrelevant: it will not be stored, and cannot # already exist, we do this to create files on disk for older clients. - result.inventories.get_parent_map([(b'A',)]) - result.revisions.get_parent_map([(b'A',)]) - result.signatures.get_parent_map([(b'A',)]) + result.inventories.get_parent_map([(b"A",)]) + result.revisions.get_parent_map([(b"A",)]) + result.signatures.get_parent_map([(b"A",)]) result.unlock() self._run_post_repo_init_hooks(result, a_controldir, shared) return result @@ -327,14 +356,17 @@ def open(self, a_controldir, _found=False, _override_transport=None): repo_transport = _override_transport else: repo_transport = a_controldir.get_repository_transport(None) - control_files = lockable_files.LockableFiles(repo_transport, - 'lock', lockdir.LockDir) - repo = self.repository_class(_format=self, - a_controldir=a_controldir, - control_files=control_files, - _commit_builder_class=self._commit_builder_class, - _revision_serializer=self._revision_serializer, - _inventory_serializer=self._inventory_serializer) + control_files = lockable_files.LockableFiles( + repo_transport, "lock", lockdir.LockDir + ) + repo = self.repository_class( + _format=self, + a_controldir=a_controldir, + control_files=control_files, + _commit_builder_class=self._commit_builder_class, + _revision_serializer=self._revision_serializer, + _inventory_serializer=self._inventory_serializer, + ) repo.revisions = self._get_revisions(repo_transport, repo) repo.signatures = self._get_signatures(repo_transport, repo) repo.inventories = self._get_inventories(repo_transport, repo) @@ -366,11 +398,13 @@ class RepositoryFormatKnit1(RepositoryFormatKnit): @property def _revision_serializer(self): from .xml5 import revision_serializer_v5 + return revision_serializer_v5 @property def _inventory_serializer(self): from .xml5 import inventory_serializer_v5 + return inventory_serializer_v5 def __ne__(self, other): @@ -411,21 +445,22 @@ class RepositoryFormatKnit3(RepositoryFormatKnit): @property def _revision_serializer(self): from .xml5 import revision_serializer_v5 + return revision_serializer_v5 @property def _inventory_serializer(self): from .xml7 import inventory_serializer_v7 + return inventory_serializer_v7 def _get_matching_bzrdir(self): - return controldir.format_registry.make_controldir('dirstate-with-subtree') + return controldir.format_registry.make_controldir("dirstate-with-subtree") def _ignore_setting_bzrdir(self, format): pass - _matchingcontroldir = property( - _get_matching_bzrdir, _ignore_setting_bzrdir) + _matchingcontroldir = property(_get_matching_bzrdir, _ignore_setting_bzrdir) @classmethod def get_format_string(cls): @@ -461,26 +496,27 @@ class RepositoryFormatKnit4(RepositoryFormatKnit): @property def 
_revision_serializer(self): from .xml5 import revision_serializer_v5 + return revision_serializer_v5 @property def _inventory_serializer(self): from .xml6 import inventory_serializer_v6 + return inventory_serializer_v6 def _get_matching_bzrdir(self): - return controldir.format_registry.make_controldir('rich-root') + return controldir.format_registry.make_controldir("rich-root") def _ignore_setting_bzrdir(self, format): pass - _matchingcontroldir = property( - _get_matching_bzrdir, _ignore_setting_bzrdir) + _matchingcontroldir = property(_get_matching_bzrdir, _ignore_setting_bzrdir) @classmethod def get_format_string(cls): """See RepositoryFormat.get_format_string().""" - return b'Bazaar Knit Repository Format 4 (bzr 1.0)\n' + return b"Bazaar Knit Repository Format 4 (bzr 1.0)\n" def get_format_description(self): """See RepositoryFormat.get_format_description().""" @@ -503,31 +539,33 @@ def is_compatible(source, target): overly general. """ try: - are_knits = (isinstance(source._format, RepositoryFormatKnit) - and isinstance(target._format, RepositoryFormatKnit)) + are_knits = isinstance(source._format, RepositoryFormatKnit) and isinstance( + target._format, RepositoryFormatKnit + ) except AttributeError: return False return are_knits and InterRepository._same_model(source, target) - def search_missing_revision_ids(self, - find_ghosts=True, revision_ids=None, if_present_ids=None, - limit=None): + def search_missing_revision_ids( + self, find_ghosts=True, revision_ids=None, if_present_ids=None, limit=None + ): """See InterRepository.search_missing_revision_ids().""" import itertools + with self.lock_read(): source_ids_set = self._present_source_revisions_for( - revision_ids, if_present_ids) + revision_ids, if_present_ids + ) # source_ids is the worst possible case we may need to pull. # now we want to filter source_ids against what we actually # have in target, but don't try to check for existence where we know # we do not have a revision as that would be pointless. target_ids = set(self.target.all_revision_ids()) - possibly_present_revisions = target_ids.intersection( - source_ids_set) + possibly_present_revisions = target_ids.intersection(source_ids_set) actually_present_revisions = set( - self.target._eliminate_revisions_not_present(possibly_present_revisions)) - required_revisions = source_ids_set.difference( - actually_present_revisions) + self.target._eliminate_revisions_not_present(possibly_present_revisions) + ) + required_revisions = source_ids_set.difference(actually_present_revisions) if revision_ids is not None: # we used get_ancestry to determine source_ids then we are assured all # revisions referenced are present as they are installed in topological order. @@ -538,7 +576,8 @@ def search_missing_revision_ids(self, # we only have an estimate of whats available and need to validate # that against the revision records. 
result_set = set( - self.source._eliminate_revisions_not_present(required_revisions)) + self.source._eliminate_revisions_not_present(required_revisions) + ) if limit is not None: topo_ordered = self.source.get_graph().iter_topo_order(result_set) result_set = set(itertools.islice(topo_ordered, limit)) diff --git a/breezy/bzr/lockable_files.py b/breezy/bzr/lockable_files.py index c8ef0e4280..0e4d01cbb3 100644 --- a/breezy/bzr/lockable_files.py +++ b/breezy/bzr/lockable_files.py @@ -49,7 +49,9 @@ class LockableFiles: _token_from_lock: Optional[lock.LockToken] _transaction: Optional[transactions.Transaction] - def __init__(self, transport: Transport, lock_name: str, lock_class: Type[lock.Lock]) -> None: + def __init__( + self, transport: Transport, lock_name: str, lock_class: Type[lock.Lock] + ) -> None: """Create a LockableFiles group. :param transport: Transport pointing to the directory holding the @@ -65,9 +67,12 @@ def __init__(self, transport: Transport, lock_name: str, lock_class: Type[lock.L self._lock_count = 0 self._find_modes() esc_name = self._escape(lock_name) - self._lock = lock_class(transport, esc_name, - file_modebits=self._file_mode, - dir_modebits=self._dir_mode) + self._lock = lock_class( + transport, + esc_name, + file_modebits=self._file_mode, + dir_modebits=self._dir_mode, + ) self._counted_lock = counted_lock.CountedLock(self._lock) def create_lock(self) -> None: @@ -79,10 +84,10 @@ def create_lock(self) -> None: self._lock.create(mode=self._dir_mode) def __repr__(self): - return f'{self.__class__.__name__}({self._transport!r})' + return f"{self.__class__.__name__}({self._transport!r})" def __str__(self): - return f'LockableFiles({self.lock_name}, {self._transport.base})' + return f"LockableFiles({self.lock_name}, {self._transport.base})" def break_lock(self) -> None: """Break the lock of this lockable files group if it is held. @@ -93,8 +98,8 @@ def break_lock(self) -> None: def _escape(self, file_or_path: str) -> str: """DEPRECATED: Do not use outside this class.""" - if file_or_path == '': - return '' + if file_or_path == "": + return "" return urlutils.escape(file_or_path) def _find_modes(self) -> None: @@ -106,7 +111,7 @@ def _find_modes(self) -> None: # once all the _get_text_store methods etc no longer use them. # -- mbp 20080512 try: - st = self._transport.stat('.') + st = self._transport.stat(".") except errors.TransportNotPossible: self._dir_mode = 0o755 self._file_mode = 0o644 @@ -127,7 +132,9 @@ def dont_leave_in_place(self) -> None: """Set this LockableFiles to clear the physical lock on unlock.""" self._lock.dont_leave_in_place() - def lock_write(self, token: Optional[lock.LockToken] = None) -> Optional[lock.LockToken]: + def lock_write( + self, token: Optional[lock.LockToken] = None + ) -> Optional[lock.LockToken]: """Lock this group of files for writing. :param token: if this is already locked, then lock_write will fail @@ -143,8 +150,7 @@ def lock_write(self, token: Optional[lock.LockToken] = None) -> Optional[lock.Lo fact. 
""" if self._lock_mode: - if (self._lock_mode != 'w' - or not self.get_transaction().writeable()): + if self._lock_mode != "w" or not self.get_transaction().writeable(): raise errors.ReadOnlyError(self) self._lock.validate_token(token) self._lock_count += 1 @@ -152,7 +158,7 @@ def lock_write(self, token: Optional[lock.LockToken] = None) -> Optional[lock.Lo else: token_from_lock = self._lock.lock_write(token=token) # traceback.print_stack() - self._lock_mode = 'w' + self._lock_mode = "w" self._lock_count = 1 self._set_write_transaction() self._token_from_lock = token_from_lock @@ -160,13 +166,13 @@ def lock_write(self, token: Optional[lock.LockToken] = None) -> Optional[lock.Lo def lock_read(self) -> None: if self._lock_mode: - if self._lock_mode not in ('r', 'w'): + if self._lock_mode not in ("r", "w"): raise ValueError(f"invalid lock mode {self._lock_mode!r}") self._lock_count += 1 else: self._lock.lock_read() # traceback.print_stack() - self._lock_mode = 'r' + self._lock_mode = "r" self._lock_count = 1 self._set_read_transaction() @@ -225,13 +231,13 @@ def get_transaction(self) -> transactions.Transaction: def _set_transaction(self, new_transaction): """Set a new active transaction.""" if self._transaction is not None: - raise errors.LockError(f'Branch {self} is in a transaction already.') + raise errors.LockError(f"Branch {self} is in a transaction already.") self._transaction = new_transaction def _finish_transaction(self): """Exit the current transaction.""" if self._transaction is None: - raise errors.LockError(f'Branch {self} is not in a transaction') + raise errors.LockError(f"Branch {self} is not in a transaction") transaction = self._transaction self._transaction = None transaction.finish() @@ -249,7 +255,9 @@ class TransportLock: always local). 
""" - def __init__(self, transport: Transport, escaped_name: str, file_modebits, dir_modebits): + def __init__( + self, transport: Transport, escaped_name: str, file_modebits, dir_modebits + ): self._transport = transport self._escaped_name = escaped_name self._file_modebits = file_modebits @@ -282,8 +290,7 @@ def peek(self): def create(self, mode=None): """Create lock mechanism.""" # for old-style locks, create the file now - self._transport.put_bytes(self._escaped_name, b'', - mode=self._file_modebits) + self._transport.put_bytes(self._escaped_name, b"", mode=self._file_modebits) def validate_token(self, token): if token is not None: diff --git a/breezy/bzr/pack.py b/breezy/bzr/pack.py index 0b350a7681..45e9c8d6e7 100644 --- a/breezy/bzr/pack.py +++ b/breezy/bzr/pack.py @@ -28,7 +28,7 @@ FORMAT_ONE = b"Bazaar pack format 1 (introduced in 0.18)" -_whitespace_re = re.compile(b'[\t\n\x0b\x0c\r ]') +_whitespace_re = re.compile(b"[\t\n\x0b\x0c\r ]") class ContainerError(errors.BzrError): @@ -36,7 +36,6 @@ class ContainerError(errors.BzrError): class UnknownContainerFormatError(ContainerError): - _fmt = "Unrecognised container format: %(container_format)r" def __init__(self, container_format): @@ -44,12 +43,10 @@ def __init__(self, container_format): class UnexpectedEndOfContainerError(ContainerError): - _fmt = "Unexpected end of container stream" class UnknownRecordTypeError(ContainerError): - _fmt = "Unknown record type: %(record_type)r" def __init__(self, record_type): @@ -57,7 +54,6 @@ def __init__(self, record_type): class InvalidRecordError(ContainerError): - _fmt = "Invalid record: %(reason)s" def __init__(self, reason): @@ -65,7 +61,6 @@ def __init__(self, reason): class ContainerHasExcessDataError(ContainerError): - _fmt = "Container has data after end marker: %(excess)r" def __init__(self, excess): @@ -73,7 +68,6 @@ def __init__(self, excess): class DuplicateRecordNameError(ContainerError): - _fmt = "Container has multiple records with the same name: %(name)s" def __init__(self, name): @@ -102,7 +96,7 @@ def _check_name_encoding(name): :raises InvalidRecordError: if name is not valid UTF-8. """ try: - name.decode('utf-8') + name.decode("utf-8") except UnicodeDecodeError as e: raise InvalidRecordError(str(e)) from e @@ -135,10 +129,10 @@ def bytes_header(self, length, names): # half-written record if a name is bad! 
for name in name_tuple: _check_name(name) - byte_sections.append(b'\x00'.join(name_tuple) + b"\n") + byte_sections.append(b"\x00".join(name_tuple) + b"\n") # End of headers byte_sections.append(b"\n") - return b''.join(byte_sections) + return b"".join(byte_sections) def bytes_record(self, bytes, names): """Return the bytes for a Bytes record with the given name and @@ -202,8 +196,9 @@ def add_bytes_record(self, chunks, length, names): """ current_offset = self.current_offset if length < self._JOIN_WRITES_THRESHOLD: - self.write_func(self._serialiser.bytes_header(length, names) - + b''.join(chunks)) + self.write_func( + self._serialiser.bytes_header(length, names) + b"".join(chunks) + ) else: self.write_func(self._serialiser.bytes_header(length, names)) for chunk in chunks: @@ -237,8 +232,7 @@ def __init__(self, readv_result): self._string = None def _next(self): - if (self._string is None or - self._string.tell() == self._string_length): + if self._string is None or self._string.tell() == self._string_length: offset, data = next(self.readv_result) self._string_length = len(data) self._string = BytesIO(data) @@ -247,17 +241,18 @@ def read(self, length): self._next() result = self._string.read(length) if len(result) < length: - raise errors.BzrError('wanted %d bytes but next ' - 'hunk only contains %d: %r...' % - (length, len(result), result[:20])) + raise errors.BzrError( + "wanted %d bytes but next " + "hunk only contains %d: %r..." % (length, len(result), result[:20]) + ) return result def readline(self): """Note that readline will not cross readv segments.""" self._next() result = self._string.readline() - if self._string.tell() == self._string_length and result[-1:] != b'\n': - raise errors.BzrError(f'short readline in the readvfile hunk: {result!r}') + if self._string.tell() == self._string_length and result[-1:] != b"\n": + raise errors.BzrError(f"short readline in the readvfile hunk: {result!r}") return result @@ -271,13 +266,11 @@ def make_readv_reader(transport, filename, requested_records): """ readv_blocks = [(0, len(FORMAT_ONE) + 1)] readv_blocks.extend(requested_records) - result = ContainerReader(ReadVFile( - transport.readv(filename, readv_blocks))) + result = ContainerReader(ReadVFile(transport.readv(filename, readv_blocks))) return result class BaseReader: - def __init__(self, source_file): """Constructor. @@ -291,9 +284,9 @@ def reader_func(self, length=None): def _read_line(self): line = self._source.readline() - if not line.endswith(b'\n'): + if not line.endswith(b"\n"): raise UnexpectedEndOfContainerError() - return line.rstrip(b'\n') + return line.rstrip(b"\n") class ContainerReader(BaseReader): @@ -350,14 +343,14 @@ def _iter_record_objects(self): record_kind = self.reader_func(1) except StopIteration: return - if record_kind == b'B': + if record_kind == b"B": # Bytes record. reader = BytesRecordReader(self._source) yield reader - elif record_kind == b'E': + elif record_kind == b"E": # End marker. There are no more records. return - elif record_kind == b'': + elif record_kind == b"": # End of stream encountered, but no End Marker record seen, so # this container is incomplete. raise UnexpectedEndOfContainerError() @@ -393,12 +386,11 @@ def validate(self): raise DuplicateRecordNameError(name_tuple[0]) all_names.add(name_tuple) excess_bytes = self.reader_func(1) - if excess_bytes != b'': + if excess_bytes != b"": raise ContainerHasExcessDataError(excess_bytes) class BytesRecordReader(BaseReader): - def read(self): """Read this record. 
@@ -415,16 +407,15 @@ def read(self): try: length = int(length_line) except ValueError as e: - raise InvalidRecordError( - f"{length_line!r} is not a valid length.") from e + raise InvalidRecordError(f"{length_line!r} is not a valid length.") from e # Read the list of names. names = [] while True: name_line = self._read_line() - if name_line == b'': + if name_line == b"": break - name_tuple = tuple(name_line.split(b'\x00')) + name_tuple = tuple(name_line.split(b"\x00")) for name in name_tuple: _check_name(name) names.append(name_tuple) @@ -465,7 +456,7 @@ class ContainerPushParser: """ def __init__(self): - self._buffer = b'' + self._buffer = b"" self._state_handler = self._state_expecting_format_line self._parsed_records = [] self._reset_current_record() @@ -482,8 +473,10 @@ def accept_bytes(self, bytes): last_buffer_length = None cur_buffer_length = len(self._buffer) last_state_handler = None - while (cur_buffer_length != last_buffer_length - or last_state_handler != self._state_handler): + while ( + cur_buffer_length != last_buffer_length + or last_state_handler != self._state_handler + ): last_buffer_length = cur_buffer_length last_state_handler = self._state_handler self._state_handler() @@ -505,10 +498,10 @@ def _consume_line(self): If a newline byte is not found in the buffer, the buffer is unchanged and this returns None instead. """ - newline_pos = self._buffer.find(b'\n') + newline_pos = self._buffer.find(b"\n") if newline_pos != -1: line = self._buffer[:newline_pos] - self._buffer = self._buffer[newline_pos + 1:] + self._buffer = self._buffer[newline_pos + 1 :] return line else: return None @@ -524,9 +517,9 @@ def _state_expecting_record_type(self): if len(self._buffer) >= 1: record_type = self._buffer[:1] self._buffer = self._buffer[1:] - if record_type == b'B': + if record_type == b"B": self._state_handler = self._state_expecting_length - elif record_type == b'E': + elif record_type == b"E": self.finished = True self._state_handler = self._state_expecting_nothing else: @@ -538,24 +531,23 @@ def _state_expecting_length(self): try: self._current_record_length = int(line) except ValueError as e: - raise InvalidRecordError( - f"{line!r} is not a valid length.") from e + raise InvalidRecordError(f"{line!r} is not a valid length.") from e self._state_handler = self._state_expecting_name def _state_expecting_name(self): encoded_name_parts = self._consume_line() - if encoded_name_parts == b'': + if encoded_name_parts == b"": self._state_handler = self._state_expecting_body elif encoded_name_parts: - name_parts = tuple(encoded_name_parts.split(b'\x00')) + name_parts = tuple(encoded_name_parts.split(b"\x00")) for name_part in name_parts: _check_name(name_part) self._current_record_names.append(name_parts) def _state_expecting_body(self): if len(self._buffer) >= self._current_record_length: - body_bytes = self._buffer[:self._current_record_length] - self._buffer = self._buffer[self._current_record_length:] + body_bytes = self._buffer[: self._current_record_length] + self._buffer = self._buffer[self._current_record_length :] record = (self._current_record_names, body_bytes) self._parsed_records.append(record) self._reset_current_record() diff --git a/breezy/bzr/pack_repo.py b/breezy/bzr/pack_repo.py index 16325f8d9b..8b6df2c492 100644 --- a/breezy/bzr/pack_repo.py +++ b/breezy/bzr/pack_repo.py @@ -22,7 +22,9 @@ from ..lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ import time from breezy import ( @@ -37,7 +39,8 @@ from breezy.bzr.index import ( 
CombinedGraphIndex, ) -""") +""", +) from .. import debug, errors, lockdir, osutils from .. import transport as _mod_transport from ..bzr import btree_index, lockable_files @@ -65,8 +68,10 @@ class RetryWithNewPacks(errors.BzrError): internal_error = True - _fmt = ("Pack files have changed, reload and retry. context: %(context)s" - " %(orig_error)s") + _fmt = ( + "Pack files have changed, reload and retry. context: %(context)s" + " %(orig_error)s" + ) def __init__(self, context, reload_occurred, exc_info): """Create a new RetryWithNewPacks error. @@ -98,8 +103,10 @@ class RetryAutopack(RetryWithNewPacks): internal_error = True - _fmt = ("Pack files have changed, reload and try autopack again." - " context: %(context)s %(orig_error)s") + _fmt = ( + "Pack files have changed, reload and try autopack again." + " context: %(context)s %(orig_error)s" + ) class PackCommitBuilder(VersionedFileCommitBuilder): @@ -109,15 +116,35 @@ class PackCommitBuilder(VersionedFileCommitBuilder): added text, reducing memory and object pressure. """ - def __init__(self, repository, parents, config, timestamp=None, - timezone=None, committer=None, revprops=None, - revision_id=None, lossy=False, owns_transaction=True): - VersionedFileCommitBuilder.__init__(self, repository, parents, config, - timestamp=timestamp, timezone=timezone, committer=committer, - revprops=revprops, revision_id=revision_id, lossy=lossy, - owns_transaction=owns_transaction) + def __init__( + self, + repository, + parents, + config, + timestamp=None, + timezone=None, + committer=None, + revprops=None, + revision_id=None, + lossy=False, + owns_transaction=True, + ): + VersionedFileCommitBuilder.__init__( + self, + repository, + parents, + config, + timestamp=timestamp, + timezone=timezone, + committer=committer, + revprops=revprops, + revision_id=revision_id, + lossy=lossy, + owns_transaction=owns_transaction, + ) self._file_graph = graph.Graph( - repository._pack_collection.text_index.combined_index) + repository._pack_collection.text_index.combined_index + ) def _heads(self, file_id, revision_ids): keys = [(file_id, revision_id) for revision_id in revision_ids] @@ -134,15 +161,21 @@ class Pack: # A map of index 'type' to the file extension and position in the # index_sizes array. index_definitions = { - 'chk': ('.cix', 4), - 'revision': ('.rix', 0), - 'inventory': ('.iix', 1), - 'text': ('.tix', 2), - 'signature': ('.six', 3), - } - - def __init__(self, revision_index, inventory_index, text_index, - signature_index, chk_index=None): + "chk": (".cix", 4), + "revision": (".rix", 0), + "inventory": (".iix", 1), + "text": (".tix", 2), + "signature": (".six", 3), + } + + def __init__( + self, + revision_index, + inventory_index, + text_index, + signature_index, + chk_index=None, + ): """Create a pack instance. 
:param revision_index: A GraphIndex for determining what revisions are @@ -177,28 +210,34 @@ def _check_references(self): (See ) """ missing_items = {} - for (index_name, external_refs, index) in [ - ('texts', - self._get_external_refs(self.text_index), - self._pack_collection.text_index.combined_index), - ('inventories', - self._get_external_refs(self.inventory_index), - self._pack_collection.inventory_index.combined_index), - ]: + for index_name, external_refs, index in [ + ( + "texts", + self._get_external_refs(self.text_index), + self._pack_collection.text_index.combined_index, + ), + ( + "inventories", + self._get_external_refs(self.inventory_index), + self._pack_collection.inventory_index.combined_index, + ), + ]: missing = external_refs.difference( - k for (idx, k, v, r) in - index.iter_entries(external_refs)) + k for (idx, k, v, r) in index.iter_entries(external_refs) + ) if missing: missing_items[index_name] = sorted(missing) if missing_items: from pprint import pformat + raise errors.BzrCheckError( f"Newly created pack file {self!r} has delta references to " - f"items not in its repository:\n{pformat(missing_items)}") + f"items not in its repository:\n{pformat(missing_items)}" + ) def file_name(self): """Get the file name for the pack on disk.""" - return self.name + '.pack' + return self.name + ".pack" def get_revision_count(self): return self.revision_index.key_count() @@ -213,59 +252,88 @@ def index_offset(self, index_type): def inventory_index_name(self, name): """The inv index is the name + .iix.""" - return self.index_name('inventory', name) + return self.index_name("inventory", name) def revision_index_name(self, name): """The revision index is the name + .rix.""" - return self.index_name('revision', name) + return self.index_name("revision", name) def signature_index_name(self, name): """The signature index is the name + .six.""" - return self.index_name('signature', name) + return self.index_name("signature", name) def text_index_name(self, name): """The text index is the name + .tix.""" - return self.index_name('text', name) + return self.index_name("text", name) def _replace_index_with_readonly(self, index_type): unlimited_cache = False - if index_type == 'chk': + if index_type == "chk": unlimited_cache = True - index = self.index_class(self.index_transport, - self.index_name(index_type, self.name), - self.index_sizes[self.index_offset( - index_type)], - unlimited_cache=unlimited_cache) - if index_type == 'chk': + index = self.index_class( + self.index_transport, + self.index_name(index_type, self.name), + self.index_sizes[self.index_offset(index_type)], + unlimited_cache=unlimited_cache, + ) + if index_type == "chk": index._leaf_factory = btree_index._gcchk_factory - setattr(self, index_type + '_index', index) + setattr(self, index_type + "_index", index) def __lt__(self, other): if not isinstance(other, Pack): raise TypeError(other) - return (id(self) < id(other)) + return id(self) < id(other) def __hash__(self): - return hash((type(self), self.revision_index, self.inventory_index, - self.text_index, self.signature_index, self.chk_index)) + return hash( + ( + type(self), + self.revision_index, + self.inventory_index, + self.text_index, + self.signature_index, + self.chk_index, + ) + ) class ExistingPack(Pack): """An in memory proxy for an existing .pack and its disk indices.""" - def __init__(self, pack_transport, name, revision_index, inventory_index, - text_index, signature_index, chk_index=None): + def __init__( + self, + pack_transport, + name, + revision_index, 
+ inventory_index, + text_index, + signature_index, + chk_index=None, + ): """Create an ExistingPack object. :param pack_transport: The transport where the pack file resides. :param name: The name of the pack on disk in the pack_transport. """ - Pack.__init__(self, revision_index, inventory_index, text_index, - signature_index, chk_index) + Pack.__init__( + self, + revision_index, + inventory_index, + text_index, + signature_index, + chk_index, + ) self.name = name self.pack_transport = pack_transport - if None in (revision_index, inventory_index, text_index, - signature_index, name, pack_transport): + if None in ( + revision_index, + inventory_index, + text_index, + signature_index, + name, + pack_transport, + ): raise AssertionError() def __eq__(self, other): @@ -276,54 +344,78 @@ def __ne__(self, other): def __repr__(self): return "<{}.{} object at 0x{:x}, {}, {}".format( - self.__class__.__module__, self.__class__.__name__, id(self), - self.pack_transport, self.name) + self.__class__.__module__, + self.__class__.__name__, + id(self), + self.pack_transport, + self.name, + ) def __hash__(self): return hash((type(self), self.name)) class ResumedPack(ExistingPack): - - def __init__(self, name, revision_index, inventory_index, text_index, - signature_index, upload_transport, pack_transport, index_transport, - pack_collection, chk_index=None): + def __init__( + self, + name, + revision_index, + inventory_index, + text_index, + signature_index, + upload_transport, + pack_transport, + index_transport, + pack_collection, + chk_index=None, + ): """Create a ResumedPack object.""" - ExistingPack.__init__(self, pack_transport, name, revision_index, - inventory_index, text_index, signature_index, - chk_index=chk_index) + ExistingPack.__init__( + self, + pack_transport, + name, + revision_index, + inventory_index, + text_index, + signature_index, + chk_index=chk_index, + ) self.upload_transport = upload_transport self.index_transport = index_transport self.index_sizes = [None, None, None, None] indices = [ - ('revision', revision_index), - ('inventory', inventory_index), - ('text', text_index), - ('signature', signature_index), - ] + ("revision", revision_index), + ("inventory", inventory_index), + ("text", text_index), + ("signature", signature_index), + ] if chk_index is not None: - indices.append(('chk', chk_index)) + indices.append(("chk", chk_index)) self.index_sizes.append(None) for index_type, index in indices: offset = self.index_offset(index_type) self.index_sizes[offset] = index._size self.index_class = pack_collection._index_class self._pack_collection = pack_collection - self._state = 'resumed' + self._state = "resumed" # XXX: perhaps check that the .pack file exists? 
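As an aside to the index-name helpers in the hunk above: each pack index is just the pack name plus a fixed suffix (.rix, .iix, .tix, .six, and .cix for CHK-backed formats). A standalone sketch of that convention, with the mapping and helper name invented for illustration:

    SUFFIXES = {
        "revision": ".rix",
        "inventory": ".iix",
        "text": ".tix",
        "signature": ".six",
        "chk": ".cix",
    }

    def index_file_name(index_type, pack_name):
        # e.g. the revision index for pack "0123abcd" is stored as 0123abcd.rix
        return pack_name + SUFFIXES[index_type]

    assert index_file_name("revision", "0123abcd") == "0123abcd.rix"
    assert index_file_name("chk", "0123abcd") == "0123abcd.cix"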
def access_tuple(self): - if self._state == 'finished': + if self._state == "finished": return Pack.access_tuple(self) - elif self._state == 'resumed': + elif self._state == "resumed": return self.upload_transport, self.file_name() else: raise AssertionError(self._state) def abort(self): self.upload_transport.delete(self.file_name()) - indices = [self.revision_index, self.inventory_index, self.text_index, - self.signature_index] + indices = [ + self.revision_index, + self.inventory_index, + self.text_index, + self.signature_index, + ] if self.chk_index is not None: indices.append(self.chk_index) for index in indices: @@ -331,17 +423,17 @@ def abort(self): def finish(self): self._check_references() - index_types = ['revision', 'inventory', 'text', 'signature'] + index_types = ["revision", "inventory", "text", "signature"] if self.chk_index is not None: - index_types.append('chk') + index_types.append("chk") for index_type in index_types: old_name = self.index_name(index_type, self.name) - new_name = '../indices/' + old_name + new_name = "../indices/" + old_name self.upload_transport.move(old_name, new_name) self._replace_index_with_readonly(index_type) - new_name = '../packs/' + self.file_name() + new_name = "../packs/" + self.file_name() self.upload_transport.move(self.file_name(), new_name) - self._state = 'finished' + self._state = "finished" def _get_external_refs(self, index): """Return compression parents for this index that are not present. @@ -355,7 +447,7 @@ def _get_external_refs(self, index): class NewPack(Pack): """An in memory proxy for a pack which is being created.""" - def __init__(self, pack_collection, upload_suffix='', file_mode=None): + def __init__(self, pack_collection, upload_suffix="", file_mode=None): """Create a NewPack instance. :param pack_collection: A PackCollection into which this is being inserted. @@ -370,23 +462,24 @@ def __init__(self, pack_collection, upload_suffix='', file_mode=None): chk_index = index_builder_class(reference_lists=0) else: chk_index = None - Pack.__init__(self, - # Revisions: parents list, no text compression. - index_builder_class(reference_lists=1), - # Inventory: We want to map compression only, but currently the - # knit code hasn't been updated enough to understand that, so we - # have a regular 2-list index giving parents and compression - # source. - index_builder_class(reference_lists=2), - # Texts: compression and per file graph, for all fileids - so two - # reference lists and two elements in the key tuple. - index_builder_class(reference_lists=2, key_elements=2), - # Signatures: Just blobs to store, no compression, no parents - # listing. - index_builder_class(reference_lists=0), - # CHK based storage - just blobs, no compression or parents. - chk_index=chk_index - ) + Pack.__init__( + self, + # Revisions: parents list, no text compression. + index_builder_class(reference_lists=1), + # Inventory: We want to map compression only, but currently the + # knit code hasn't been updated enough to understand that, so we + # have a regular 2-list index giving parents and compression + # source. + index_builder_class(reference_lists=2), + # Texts: compression and per file graph, for all fileids - so two + # reference lists and two elements in the key tuple. + index_builder_class(reference_lists=2, key_elements=2), + # Signatures: Just blobs to store, no compression, no parents + # listing. + index_builder_class(reference_lists=0), + # CHK based storage - just blobs, no compression or parents. 
+ chk_index=chk_index, + ) self._pack_collection = pack_collection # When we make readonly indices, we need this. self.index_class = pack_collection._index_class @@ -414,11 +507,16 @@ def __init__(self, pack_collection, upload_suffix='', file_mode=None): self.start_time = time.time() # open an output stream for the data added to the pack. self.write_stream = self.upload_transport.open_write_stream( - self.random_name, mode=self._file_mode) - if debug.debug_flag_enabled('pack'): - mutter('%s: create_pack: pack stream open: %s%s t+%6.3fs', - time.ctime(), self.upload_transport.base, self.random_name, - time.time() - self.start_time) + self.random_name, mode=self._file_mode + ) + if debug.debug_flag_enabled("pack"): + mutter( + "%s: create_pack: pack stream open: %s%s t+%6.3fs", + time.ctime(), + self.upload_transport.base, + self.random_name, + time.time() - self.start_time, + ) # A list of byte sequences to be written to the new pack, and the # aggregate size of them. Stored as a list rather than separate # variables so that the _write_data closure below can update them. @@ -429,29 +527,35 @@ def __init__(self, pack_collection, upload_suffix='', file_mode=None): # so that the variables are locals, and faster than accessing object # members. - def _write_data(bytes, flush=False, _buffer=self._buffer, - _write=self.write_stream.write, _update=self._hash.update): + def _write_data( + bytes, + flush=False, + _buffer=self._buffer, + _write=self.write_stream.write, + _update=self._hash.update, + ): _buffer[0].append(bytes) _buffer[1] += len(bytes) # buffer cap if _buffer[1] > self._cache_limit or flush: - bytes = b''.join(_buffer[0]) + bytes = b"".join(_buffer[0]) _write(bytes) _update(bytes) _buffer[:] = [[], 0] + # expose this on self, for the occasion when clients want to add data. self._write_data = _write_data # a pack writer object to serialise pack records. self._writer = pack.ContainerWriter(self._write_data) self._writer.begin() # what state is the pack in? (open, finished, aborted) - self._state = 'open' + self._state = "open" # no name until we finish writing the content self.name = None def abort(self): """Cancel creating this pack.""" - self._state = 'aborted' + self._state = "aborted" self.write_stream.close() # Remove the temporary pack file. self.upload_transport.delete(self.random_name) @@ -459,27 +563,29 @@ def abort(self): def access_tuple(self): """Return a tuple (transport, name) for the pack content.""" - if self._state == 'finished': + if self._state == "finished": return Pack.access_tuple(self) - elif self._state == 'open': + elif self._state == "open": return self.upload_transport, self.random_name else: raise AssertionError(self._state) def data_inserted(self): """True if data has been added to this pack.""" - return bool(self.get_revision_count() or - self.inventory_index.key_count() or - self.text_index.key_count() or - self.signature_index.key_count() or - (self.chk_index is not None and self.chk_index.key_count())) + return bool( + self.get_revision_count() + or self.inventory_index.key_count() + or self.text_index.key_count() + or self.signature_index.key_count() + or (self.chk_index is not None and self.chk_index.key_count()) + ) def finish_content(self): if self.name is not None: return self._writer.end() if self._buffer[1]: - self._write_data(b'', flush=True) + self._write_data(b"", flush=True) self.name = self._hash.hexdigest() def finish(self, suspend=False): @@ -502,19 +608,20 @@ def finish(self, suspend=False): # visible is smaller. 
On the other hand none will be seen until # they're in the names list. self.index_sizes = [None, None, None, None] - self._write_index('revision', self.revision_index, 'revision', - suspend) - self._write_index('inventory', self.inventory_index, 'inventory', - suspend) - self._write_index('text', self.text_index, 'file texts', suspend) - self._write_index('signature', self.signature_index, - 'revision signatures', suspend) + self._write_index("revision", self.revision_index, "revision", suspend) + self._write_index("inventory", self.inventory_index, "inventory", suspend) + self._write_index("text", self.text_index, "file texts", suspend) + self._write_index( + "signature", self.signature_index, "revision signatures", suspend + ) if self.chk_index is not None: self.index_sizes.append(None) - self._write_index('chk', self.chk_index, - 'content hash bytes', suspend) + self._write_index("chk", self.chk_index, "content hash bytes", suspend) self.write_stream.close( - want_fdatasync=self._pack_collection.config_stack.get('repository.fdatasync')) + want_fdatasync=self._pack_collection.config_stack.get( + "repository.fdatasync" + ) + ) # Note that this will clobber an existing pack with the same name, # without checking for hash collisions. While this is undesirable this # is something that can be rectified in a subsequent release. One way @@ -526,21 +633,26 @@ def finish(self, suspend=False): # - try for HASH.pack # - try for temporary-name # - refresh the pack-list to see if the pack is now absent - new_name = self.name + '.pack' + new_name = self.name + ".pack" if not suspend: - new_name = '../packs/' + new_name + new_name = "../packs/" + new_name self.upload_transport.move(self.random_name, new_name) - self._state = 'finished' - if debug.debug_flag_enabled('pack'): + self._state = "finished" + if debug.debug_flag_enabled("pack"): # XXX: size might be interesting? - mutter('%s: create_pack: pack finished: %s%s->%s t+%6.3fs', - time.ctime(), self.upload_transport.base, self.random_name, - new_name, time.time() - self.start_time) + mutter( + "%s: create_pack: pack finished: %s%s->%s t+%6.3fs", + time.ctime(), + self.upload_transport.base, + self.random_name, + new_name, + time.time() - self.start_time, + ) def flush(self): """Flush any current data.""" if self._buffer[1]: - bytes = b''.join(self._buffer[0]) + bytes = b"".join(self._buffer[0]) self.write_stream.write(bytes) self._hash.update(bytes) self._buffer[:] = [[], 0] @@ -565,17 +677,24 @@ def _write_index(self, index_type, index, label, suspend=False): transport = self.index_transport index_tempfile = index.finish() index_bytes = index_tempfile.read() - write_stream = transport.open_write_stream(index_name, - mode=self._file_mode) + write_stream = transport.open_write_stream(index_name, mode=self._file_mode) write_stream.write(index_bytes) write_stream.close( - want_fdatasync=self._pack_collection.config_stack.get('repository.fdatasync')) + want_fdatasync=self._pack_collection.config_stack.get( + "repository.fdatasync" + ) + ) self.index_sizes[self.index_offset(index_type)] = len(index_bytes) - if debug.debug_flag_enabled('pack'): + if debug.debug_flag_enabled("pack"): # XXX: size might be interesting? 
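The _write_data closure reformatted above buffers chunks in a shared [chunks, size] pair so the transport write and the running content hash both see the same joined bytes. A minimal standalone sketch of that buffer-and-flush pattern; the function name, the md5 choice and the 512 KiB limit are assumptions for illustration only:

    import hashlib

    def make_buffered_writer(write, update, cache_limit=512 * 1024):
        buffer = [[], 0]  # pending chunks and their aggregate size

        def write_data(data, flush=False):
            buffer[0].append(data)
            buffer[1] += len(data)
            if buffer[1] > cache_limit or flush:
                joined = b"".join(buffer[0])
                write(joined)   # one large write instead of many small ones
                update(joined)  # the hash sees exactly the bytes written
                buffer[:] = [[], 0]

        return write_data

    digest = hashlib.md5()
    chunks = []
    writer = make_buffered_writer(chunks.append, digest.update)
    writer(b"first record")
    writer(b"", flush=True)  # force out whatever is still buffered
    assert chunks == [b"first record"]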
- mutter('%s: create_pack: wrote %s index: %s%s t+%6.3fs', - time.ctime(), label, self.upload_transport.base, - self.random_name, time.time() - self.start_time) + mutter( + "%s: create_pack: wrote %s index: %s%s t+%6.3fs", + time.ctime(), + label, + self.upload_transport.base, + self.random_name, + time.time() - self.start_time, + ) # Replace the writable index on this object with a readonly, # presently unloaded index. We should alter # the index layer to make its finish() error if add_node is @@ -608,9 +727,9 @@ def __init__(self, reload_func=None, flush_func=None): self._reload_func = reload_func self.index_to_pack = {} self.combined_index = CombinedGraphIndex([], reload_func=reload_func) - self.data_access = _DirectPackAccess(self.index_to_pack, - reload_func=reload_func, - flush_func=flush_func) + self.data_access = _DirectPackAccess( + self.index_to_pack, reload_func=reload_func, flush_func=flush_func + ) self.add_callback = None def add_index(self, index, pack): @@ -638,7 +757,8 @@ def add_writable_index(self, index, pack): """ if self.add_callback is not None: raise AssertionError( - f"{self} already has a writable index through {self.add_callback}") + f"{self} already has a writable index through {self.add_callback}" + ) # allow writing: queue writes to a new index self.add_index(index, pack) # Updates the index to packs mapping as a side effect, @@ -662,8 +782,10 @@ def remove_index(self, index): pos = self.combined_index._indices.index(index) del self.combined_index._indices[pos] del self.combined_index._index_names[pos] - if (self.add_callback is not None and - getattr(index, 'add_nodes', None) == self.add_callback): + if ( + self.add_callback is not None + and getattr(index, "add_nodes", None) == self.add_callback + ): self.add_callback = None self.data_access.set_writer(None, None, (None, None)) @@ -671,8 +793,9 @@ def remove_index(self, index): class Packer: """Create a pack from packs.""" - def __init__(self, pack_collection, packs, suffix, revision_ids=None, - reload_func=None): + def __init__( + self, pack_collection, packs, suffix, revision_ids=None, reload_func=None + ): """Create a Packer. :param pack_collection: A RepositoryPackCollection object where the @@ -717,16 +840,18 @@ def pack(self, pb=None): # XXX: - duplicate code warning with start_write_group; fix before # considering 'done'. if self._pack_collection._new_pack is not None: - raise errors.BzrError('call to {}.pack() while another pack is' - ' being written.'.format(self.__class__.__name__)) + raise errors.BzrError( + "call to {}.pack() while another pack is" " being written.".format( + self.__class__.__name__ + ) + ) if self.revision_ids is not None: if len(self.revision_ids) == 0: # silly fetch request. 
return None else: self.revision_ids = frozenset(self.revision_ids) - self.revision_keys = frozenset((revid,) for revid in - self.revision_ids) + self.revision_keys = frozenset((revid,) for revid in self.revision_ids) if pb is None: self.pb = ui.ui_factory.nested_progress_bar() else: @@ -739,9 +864,11 @@ def pack(self, pb=None): def open_pack(self): """Open a pack for the pack we are creating.""" - new_pack = self._pack_collection.pack_factory(self._pack_collection, - upload_suffix=self.suffix, - file_mode=self._pack_collection.repo.controldir._get_file_mode()) + new_pack = self._pack_collection.pack_factory( + self._pack_collection, + upload_suffix=self.suffix, + file_mode=self._pack_collection.repo.controldir._get_file_mode(), + ) # We know that we will process all nodes in order, and don't need to # query, so don't combine any indices spilled to disk until we are done new_pack.revision_index.set_optimize(combine_backing_indices=False) @@ -770,12 +897,15 @@ def _create_pack_from_packs(self): raise NotImplementedError(self._create_pack_from_packs) def _log_copied_texts(self): - if debug.debug_flag_enabled('pack'): - mutter('%s: create_pack: file texts copied: %s%s %d items t+%6.3fs', - time.ctime(), self._pack_collection._upload_transport.base, - self.new_pack.random_name, - self.new_pack.text_index.key_count(), - time.time() - self.new_pack.start_time) + if debug.debug_flag_enabled("pack"): + mutter( + "%s: create_pack: file texts copied: %s%s %d items t+%6.3fs", + time.ctime(), + self._pack_collection._upload_transport.base, + self.new_pack.random_name, + self.new_pack.text_index.key_count(), + time.time() - self.new_pack.start_time, + ) def _use_pack(self, new_pack): """Return True if new_pack should be used. @@ -797,9 +927,17 @@ class RepositoryPackCollection: normal_packer_class: Type[Packer] optimising_packer_class: Type[Packer] - def __init__(self, repo, transport, index_transport, upload_transport, - pack_transport, index_builder_class, index_class, - use_chk_index): + def __init__( + self, + repo, + transport, + index_transport, + upload_transport, + pack_transport, + index_builder_class, + index_class, + use_chk_index, + ): """Create a new RepositoryPackCollection. 
:param transport: Addresses the repository base directory @@ -820,8 +958,7 @@ def __init__(self, repo, transport, index_transport, upload_transport, self._pack_transport = pack_transport self._index_builder_class = index_builder_class self._index_class = index_class - self._suffix_offsets = {'.rix': 0, '.iix': 1, '.tix': 2, '.six': 3, - '.cix': 4} + self._suffix_offsets = {".rix": 0, ".iix": 1, ".tix": 2, ".six": 3, ".cix": 4} self.packs = [] # name:Pack mapping self._names = None @@ -836,8 +973,12 @@ def __init__(self, repo, transport, index_transport, upload_transport, self.inventory_index = AggregateIndex(self.reload_pack_names, flush) self.text_index = AggregateIndex(self.reload_pack_names, flush) self.signature_index = AggregateIndex(self.reload_pack_names, flush) - all_indices = [self.revision_index, self.inventory_index, - self.text_index, self.signature_index] + all_indices = [ + self.revision_index, + self.inventory_index, + self.text_index, + self.signature_index, + ] if use_chk_index: self.chk_index = AggregateIndex(self.reload_pack_names, flush) all_indices.append(self.chk_index) @@ -849,13 +990,14 @@ def __init__(self, repo, transport, index_transport, upload_transport, all_combined = [agg_idx.combined_index for agg_idx in all_indices] for combined_idx in all_combined: combined_idx.set_sibling_indices( - set(all_combined).difference([combined_idx])) + set(all_combined).difference([combined_idx]) + ) # resumed packs self._resumed_packs = [] self.config_stack = config.LocationStack(self.transport.base) def __repr__(self): - return f'{self.__class__.__name__}({self.repo!r})' + return f"{self.__class__.__name__}({self.repo!r})" def add_pack_to_memory(self, pack): """Make a Pack object available to the repository to satisfy queries. @@ -863,8 +1005,7 @@ def add_pack_to_memory(self, pack): :param pack: A Pack object. """ if pack.name in self._packs_by_name: - raise AssertionError( - f'pack {pack.name} already in _packs_by_name') + raise AssertionError(f"pack {pack.name} already in _packs_by_name") self.packs.append(pack) self._packs_by_name[pack.name] = pack self.revision_index.add_index(pack.revision_index, pack) @@ -935,22 +1076,31 @@ def _do_autopack(self): continue existing_packs.append((revision_count, pack)) pack_operations = self.plan_autopack_combinations( - existing_packs, pack_distribution) + existing_packs, pack_distribution + ) num_new_packs = len(pack_operations) num_old_packs = sum([len(po[1]) for po in pack_operations]) num_revs_affected = sum([po[0] for po in pack_operations]) - mutter('Auto-packing repository %s, which has %d pack files, ' - 'containing %d revisions. Packing %d files into %d affecting %d' - ' revisions', str( - self), total_packs, total_revisions, num_old_packs, - num_new_packs, num_revs_affected) - result = self._execute_pack_operations(pack_operations, packer_class=self.normal_packer_class, - reload_func=self._restart_autopack) - mutter('Auto-packing repository %s completed', str(self)) + mutter( + "Auto-packing repository %s, which has %d pack files, " + "containing %d revisions. 
Packing %d files into %d affecting %d" + " revisions", + str(self), + total_packs, + total_revisions, + num_old_packs, + num_new_packs, + num_revs_affected, + ) + result = self._execute_pack_operations( + pack_operations, + packer_class=self.normal_packer_class, + reload_func=self._restart_autopack, + ) + mutter("Auto-packing repository %s completed", str(self)) return result - def _execute_pack_operations(self, pack_operations, packer_class, - reload_func=None): + def _execute_pack_operations(self, pack_operations, packer_class, reload_func=None): """Execute a series of pack operations. :param pack_operations: A list of [revision_count, packs_to_combine]. @@ -961,8 +1111,7 @@ def _execute_pack_operations(self, pack_operations, packer_class, # we may have no-ops from the setup logic if len(packs) == 0: continue - packer = packer_class(self, packs, '.autopack', - reload_func=reload_func) + packer = packer_class(self, packs, ".autopack", reload_func=reload_func) try: result = packer.pack() except RetryWithNewPacks: @@ -982,8 +1131,9 @@ def _execute_pack_operations(self, pack_operations, packer_class, to_be_obsoleted = [] for _, packs in pack_operations: to_be_obsoleted.extend(packs) - result = self._save_pack_names(clear_obsolete_packs=True, - obsolete_packs=to_be_obsoleted) + result = self._save_pack_names( + clear_obsolete_packs=True, obsolete_packs=to_be_obsoleted + ) return result def _flush_new_pack(self): @@ -1011,9 +1161,14 @@ def pack(self, hint=None, clean_obsolete_packs=False): total_revisions = self.revision_index.combined_index.key_count() # XXX: the following may want to be a class, to pack with a given # policy. - mutter('Packing repository %s, which has %d pack files, ' - 'containing %d revisions with hint %r.', str(self), total_packs, - total_revisions, hint) + mutter( + "Packing repository %s, which has %d pack files, " + "containing %d revisions with hint %r.", + str(self), + total_packs, + total_revisions, + hint, + ) while True: try: self._try_pack_operations(hint) @@ -1036,9 +1191,11 @@ def _try_pack_operations(self, hint): # or this pack was included in the hint. pack_operations[-1][0] += pack.get_revision_count() pack_operations[-1][1].append(pack) - self._execute_pack_operations(pack_operations, - packer_class=self.optimising_packer_class, - reload_func=self._restart_pack_operations) + self._execute_pack_operations( + pack_operations, + packer_class=self.optimising_packer_class, + reload_func=self._restart_pack_operations, + ) def plan_autopack_combinations(self, existing_packs, pack_distribution): """Plan a pack operation. @@ -1087,8 +1244,10 @@ def plan_autopack_combinations(self, existing_packs, pack_distribution): final_rev_count += num_revs final_pack_list.extend(pack_files) if len(final_pack_list) == 1: - raise AssertionError('We somehow generated an autopack with a' - ' single pack file being moved.') + raise AssertionError( + "We somehow generated an autopack with a" + " single pack file being moved." 
+ ) return [] return [[final_rev_count, final_pack_list]] @@ -1105,7 +1264,7 @@ def ensure_loaded(self): self._names = {} self._packs_at_load = set() for _index, key, value in self._iter_disk_pack_index(): - name = key[0].decode('ascii') + name = key[0].decode("ascii") self._names[name] = self._parse_index_sizes(value) self._packs_at_load.add((name, value)) result = True @@ -1117,7 +1276,7 @@ def ensure_loaded(self): def _parse_index_sizes(self, value): """Parse a string of index sizes.""" - return tuple(int(digits) for digits in value.split(b' ')) + return tuple(int(digits) for digits in value.split(b" ")) def get_pack_by_name(self, name): """Get a Pack object by name. @@ -1128,16 +1287,23 @@ def get_pack_by_name(self, name): try: return self._packs_by_name[name] except KeyError: - rev_index = self._make_index(name, '.rix') - inv_index = self._make_index(name, '.iix') - txt_index = self._make_index(name, '.tix') - sig_index = self._make_index(name, '.six') + rev_index = self._make_index(name, ".rix") + inv_index = self._make_index(name, ".iix") + txt_index = self._make_index(name, ".tix") + sig_index = self._make_index(name, ".six") if self.chk_index is not None: - chk_index = self._make_index(name, '.cix', is_chk=True) + chk_index = self._make_index(name, ".cix", is_chk=True) else: chk_index = None - result = ExistingPack(self._pack_transport, name, rev_index, - inv_index, txt_index, sig_index, chk_index) + result = ExistingPack( + self._pack_transport, + name, + rev_index, + inv_index, + txt_index, + sig_index, + chk_index, + ) self.add_pack_to_memory(result) return result @@ -1147,25 +1313,33 @@ def _resume_pack(self, name): :param name: The name of the pack - e.g. '123456' :return: A Pack object. """ - if not re.match('[a-f0-9]{32}', name): + if not re.match("[a-f0-9]{32}", name): # Tokens should be md5sums of the suspended pack file, i.e. 32 hex # digits. 
raise errors.UnresumableWriteGroup( - self.repo, [name], 'Malformed write group token') + self.repo, [name], "Malformed write group token" + ) try: - rev_index = self._make_index(name, '.rix', resume=True) - inv_index = self._make_index(name, '.iix', resume=True) - txt_index = self._make_index(name, '.tix', resume=True) - sig_index = self._make_index(name, '.six', resume=True) + rev_index = self._make_index(name, ".rix", resume=True) + inv_index = self._make_index(name, ".iix", resume=True) + txt_index = self._make_index(name, ".tix", resume=True) + sig_index = self._make_index(name, ".six", resume=True) if self.chk_index is not None: - chk_index = self._make_index(name, '.cix', resume=True, - is_chk=True) + chk_index = self._make_index(name, ".cix", resume=True, is_chk=True) else: chk_index = None - result = self.resumed_pack_factory(name, rev_index, inv_index, - txt_index, sig_index, self._upload_transport, - self._pack_transport, self._index_transport, self, - chk_index=chk_index) + result = self.resumed_pack_factory( + name, + rev_index, + inv_index, + txt_index, + sig_index, + self._upload_transport, + self._pack_transport, + self._index_transport, + self, + chk_index=chk_index, + ) except _mod_transport.NoSuchFile as e: raise errors.UnresumableWriteGroup(self.repo, [name], str(e)) from e self.add_pack_to_memory(result) @@ -1180,8 +1354,7 @@ def allocate(self, a_new_pack): """ self.ensure_loaded() if a_new_pack.name in self._names: - raise errors.BzrError( - f'Pack {a_new_pack.name!r} already exists in {self}') + raise errors.BzrError(f"Pack {a_new_pack.name!r} already exists in {self}") self._names[a_new_pack.name] = tuple(a_new_pack.index_sizes) self.add_pack_to_memory(a_new_pack) @@ -1192,8 +1365,7 @@ def _iter_disk_pack_index(self): detect updates from others during our write operation. :return: An iterator of the index contents. """ - return self._index_class(self.transport, 'pack-names', None - ).iter_all_entries() + return self._index_class(self.transport, "pack-names", None).iter_all_entries() def _make_index(self, name, suffix, resume=False, is_chk=False): size_offset = self._suffix_offsets[suffix] @@ -1204,8 +1376,9 @@ def _make_index(self, name, suffix, resume=False, is_chk=False): else: transport = self._index_transport index_size = self._names[name][size_offset] - index = self._index_class(transport, index_name, index_size, - unlimited_cache=is_chk) + index = self._index_class( + transport, index_name, index_size, unlimited_cache=is_chk + ) if is_chk and self._index_class is btree_index.BTreeGraphIndex: index._leaf_factory = btree_index._gcchk_factory return index @@ -1244,30 +1417,33 @@ def _obsolete_packs(self, packs): for pack in packs: try: try: - pack.pack_transport.move(pack.file_name(), - '../obsolete_packs/' + pack.file_name()) + pack.pack_transport.move( + pack.file_name(), "../obsolete_packs/" + pack.file_name() + ) except _mod_transport.NoSuchFile: # perhaps obsolete_packs was removed? Let's create it and # try again try: - pack.pack_transport.mkdir('../obsolete_packs/') + pack.pack_transport.mkdir("../obsolete_packs/") except _mod_transport.FileExists: pass - pack.pack_transport.move(pack.file_name(), - '../obsolete_packs/' + pack.file_name()) + pack.pack_transport.move( + pack.file_name(), "../obsolete_packs/" + pack.file_name() + ) except (errors.PathError, errors.TransportError) as e: # TODO: Should these be warnings or mutters? 
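A quick illustration of the resume-token check in _resume_pack above: tokens are expected to look like 32-hex-digit pack names (md5-style digests), and anything else is rejected with UnresumableWriteGroup before any index is opened. The helper below is made up for the example and only mirrors the regex from the hunk:

    import re

    def looks_like_pack_token(name):
        # 32 lowercase hex digits, as produced by hashing the pack content
        return re.match("[a-f0-9]{32}", name) is not None

    assert looks_like_pack_token("d41d8cd98f00b204e9800998ecf8427e")
    assert not looks_like_pack_token("not-a-write-group-token")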
mutter(f"couldn't rename obsolete pack, skipping it:\n{e}") # TODO: Probably needs to know all possible indices for this pack # - or maybe list the directory and move all indices matching this # name whether we recognize it or not? - suffixes = ['.iix', '.six', '.tix', '.rix'] + suffixes = [".iix", ".six", ".tix", ".rix"] if self.chk_index is not None: - suffixes.append('.cix') + suffixes.append(".cix") for suffix in suffixes: try: - self._index_transport.move(pack.name + suffix, - '../obsolete_packs/' + pack.name + suffix) + self._index_transport.move( + pack.name + suffix, "../obsolete_packs/" + pack.name + suffix + ) except (errors.PathError, errors.TransportError) as e: mutter(f"couldn't rename obsolete index, skipping it:\n{e}") @@ -1282,14 +1458,14 @@ def pack_distribution(self, total_revisions): digits = reversed(str(total_revisions)) result = [] for exponent, count in enumerate(digits): - size = 10 ** exponent + size = 10**exponent for _pos in range(int(count)): result.append(size) return list(reversed(result)) def _pack_tuple(self, name): """Return a tuple with the transport and file name for a pack name.""" - return self._pack_transport, name + '.pack' + return self._pack_transport, name + ".pack" def _remove_pack_from_memory(self, pack): """Remove pack from the packs accessed by this repository. @@ -1307,7 +1483,7 @@ def _remove_pack_indices(self, pack, ignore_missing=False): :param ignore_missing: Suppress KeyErrors from calling remove_index. """ for index_type in Pack.index_definitions: - attr_name = index_type + '_index' + attr_name = index_type + "_index" aggregate_index = getattr(self, attr_name) if aggregate_index is not None: pack_index = getattr(pack, attr_name) @@ -1354,14 +1530,13 @@ def _diff_pack_names(self): # load the disk nodes across disk_nodes = set() for _index, key, value in self._iter_disk_pack_index(): - disk_nodes.add((key[0].decode('ascii'), value)) + disk_nodes.add((key[0].decode("ascii"), value)) orig_disk_nodes = set(disk_nodes) # do a two-way diff against our original content current_nodes = set() for name, sizes in self._names.items(): - current_nodes.add( - (name, b' '.join(b'%d' % size for size in sizes))) + current_nodes.add((name, b" ".join(b"%d" % size for size in sizes))) # Packs no longer present in the repository, which were present when we # locked the repository @@ -1437,15 +1612,22 @@ def _save_pack_names(self, clear_obsolete_packs=False, obsolete_packs=None): self.lock_names() try: builder = self._index_builder_class() - (disk_nodes, deleted_nodes, new_nodes, - orig_disk_nodes) = self._diff_pack_names() + ( + disk_nodes, + deleted_nodes, + new_nodes, + orig_disk_nodes, + ) = self._diff_pack_names() # TODO: handle same-name, index-size-changes here - # e.g. use the value from disk, not ours, *unless* we're the one # changing it. for name, value in disk_nodes: - builder.add_node((name.encode('ascii'), ), value) - self.transport.put_file('pack-names', builder.finish(), - mode=self.repo.controldir._get_file_mode()) + builder.add_node((name.encode("ascii"),), value) + self.transport.put_file( + "pack-names", + builder.finish(), + mode=self.repo.controldir._get_file_mode(), + ) self._packs_at_load = disk_nodes if clear_obsolete_packs: to_preserve = None @@ -1462,8 +1644,9 @@ def _save_pack_names(self, clear_obsolete_packs=False, obsolete_packs=None): # disk yet. However, the new pack object is not easily # accessible here (it would have to be passed through the # autopacking code, etc.) 
- obsolete_packs = [o for o in obsolete_packs - if o.name not in already_obsolete] + obsolete_packs = [ + o for o in obsolete_packs if o.name not in already_obsolete + ] self._obsolete_packs(obsolete_packs) return [new_node[0] for new_node in new_nodes] @@ -1485,14 +1668,19 @@ def reload_pack_names(self): if first_read: return True # out the new value. - (disk_nodes, deleted_nodes, new_nodes, - orig_disk_nodes) = self._diff_pack_names() + ( + disk_nodes, + deleted_nodes, + new_nodes, + orig_disk_nodes, + ) = self._diff_pack_names() # _packs_at_load is meant to be the explicit list of names in # 'pack-names' at then start. As such, it should not contain any # pending names that haven't been written out yet. self._packs_at_load = orig_disk_nodes - (removed, added, - modified) = self._syncronize_pack_names_from_disk_nodes(disk_nodes) + (removed, added, modified) = self._syncronize_pack_names_from_disk_nodes( + disk_nodes + ) if removed or added or modified: return True return False @@ -1520,16 +1708,16 @@ def _clear_obsolete_packs(self, preserve=None): were found in obsolete_packs. """ found = [] - obsolete_pack_transport = self.transport.clone('obsolete_packs') + obsolete_pack_transport = self.transport.clone("obsolete_packs") if preserve is None: preserve = set() try: - obsolete_pack_files = obsolete_pack_transport.list_dir('.') + obsolete_pack_files = obsolete_pack_transport.list_dir(".") except _mod_transport.NoSuchFile: return found for filename in obsolete_pack_files: name, ext = osutils.splitext(filename) - if ext == '.pack': + if ext == ".pack": found.append(name) if name in preserve: continue @@ -1543,24 +1731,25 @@ def _start_write_group(self): # Do not permit preparation for writing if we're not in a 'write lock'. if not self.repo.is_write_locked(): raise errors.NotWriteLocked(self) - self._new_pack = self.pack_factory(self, upload_suffix='.pack', - file_mode=self.repo.controldir._get_file_mode()) + self._new_pack = self.pack_factory( + self, upload_suffix=".pack", file_mode=self.repo.controldir._get_file_mode() + ) # allow writing: queue writes to a new index - self.revision_index.add_writable_index(self._new_pack.revision_index, - self._new_pack) - self.inventory_index.add_writable_index(self._new_pack.inventory_index, - self._new_pack) - self.text_index.add_writable_index(self._new_pack.text_index, - self._new_pack) + self.revision_index.add_writable_index( + self._new_pack.revision_index, self._new_pack + ) + self.inventory_index.add_writable_index( + self._new_pack.inventory_index, self._new_pack + ) + self.text_index.add_writable_index(self._new_pack.text_index, self._new_pack) self._new_pack.text_index.set_optimize(combine_backing_indices=False) - self.signature_index.add_writable_index(self._new_pack.signature_index, - self._new_pack) + self.signature_index.add_writable_index( + self._new_pack.signature_index, self._new_pack + ) if self.chk_index is not None: - self.chk_index.add_writable_index(self._new_pack.chk_index, - self._new_pack) + self.chk_index.add_writable_index(self._new_pack.chk_index, self._new_pack) self.repo.chk_bytes._index._add_callback = self.chk_index.add_callback - self._new_pack.chk_index.set_optimize( - combine_backing_indices=False) + self._new_pack.chk_index.set_optimize(combine_backing_indices=False) self.repo.inventories._index._add_callback = self.inventory_index.add_callback self.repo.revisions._index._add_callback = self.revision_index.add_callback @@ -1572,19 +1761,21 @@ def _abort_write_group(self): # forget what names there are if 
self._new_pack is not None: with contextlib.ExitStack() as stack: - stack.callback(setattr, self, '_new_pack', None) + stack.callback(setattr, self, "_new_pack", None) # If we aborted while in the middle of finishing the write # group, _remove_pack_indices could fail because the indexes are # already gone. But they're not there we shouldn't fail in this # case, so we pass ignore_missing=True. - stack.callback(self._remove_pack_indices, self._new_pack, - ignore_missing=True) + stack.callback( + self._remove_pack_indices, self._new_pack, ignore_missing=True + ) self._new_pack.abort() for resumed_pack in self._resumed_packs: with contextlib.ExitStack() as stack: # See comment in previous finally block. - stack.callback(self._remove_pack_indices, resumed_pack, - ignore_missing=True) + stack.callback( + self._remove_pack_indices, resumed_pack, ignore_missing=True + ) resumed_pack.abort() del self._resumed_packs[:] @@ -1606,21 +1797,23 @@ def _check_new_inventories(self): def _commit_write_group(self): all_missing = set() for prefix, versioned_file in ( - ('revisions', self.repo.revisions), - ('inventories', self.repo.inventories), - ('texts', self.repo.texts), - ('signatures', self.repo.signatures), - ): + ("revisions", self.repo.revisions), + ("inventories", self.repo.inventories), + ("texts", self.repo.texts), + ("signatures", self.repo.signatures), + ): missing = versioned_file.get_missing_compression_parent_keys() all_missing.update([(prefix,) + key for key in missing]) if all_missing: raise errors.BzrCheckError( - f"Repository {self.repo} has missing compression parent(s) {sorted(all_missing)!r} ") + f"Repository {self.repo} has missing compression parent(s) {sorted(all_missing)!r} " + ) problems = self._check_new_inventories() if problems: - problems_summary = '\n'.join(problems) + problems_summary = "\n".join(problems) raise errors.BzrCheckError( - "Cannot add revision(s) to repository: " + problems_summary) + "Cannot add revision(s) to repository: " + problems_summary + ) self._remove_pack_indices(self._new_pack) any_new_content = False if self._new_pack.data_inserted(): @@ -1702,8 +1895,15 @@ class PackRepository(MetaDirVersionedFileRepository): _revision_serializer: RevisionSerializer _inventory_serializer: InventorySerializer - def __init__(self, _format, a_controldir, control_files, _commit_builder_class, - _revision_serializer, _inventory_serializer): + def __init__( + self, + _format, + a_controldir, + control_files, + _commit_builder_class, + _revision_serializer, + _inventory_serializer, + ): MetaDirRepository.__init__(self, _format, a_controldir, control_files) self._commit_builder_class = _commit_builder_class self._revision_serializer = _revision_serializer @@ -1711,7 +1911,8 @@ def __init__(self, _format, a_controldir, control_files, _commit_builder_class, self._reconcile_fixes_text_parents = True if self._format.supports_external_lookups: self._unstacked_provider = graph.CachingParentsProvider( - self._make_parents_provider_unstacked()) + self._make_parents_provider_unstacked() + ) else: self._unstacked_provider = graph.CachingParentsProvider(self) self._unstacked_provider.disable_cache() @@ -1728,8 +1929,9 @@ def _abort_write_group(self): def _make_parents_provider(self): if not self._format.supports_external_lookups: return self._unstacked_provider - return graph.StackedParentsProvider(_LazyListJoin( - [self._unstacked_provider], self._fallback_repositories)) + return graph.StackedParentsProvider( + _LazyListJoin([self._unstacked_provider], self._fallback_repositories) + 
) def _refresh_data(self): if not self.is_locked(): @@ -1791,9 +1993,9 @@ def lock_write(self, token=None): if self._write_lock_count == 1: self._transaction = transactions.WriteTransaction() if not locked: - if debug.debug_flag_enabled('relock') and self._prev_lock == 'w': - note('%r was write locked again', self) - self._prev_lock = 'w' + if debug.debug_flag_enabled("relock") and self._prev_lock == "w": + note("%r was write locked again", self) + self._prev_lock = "w" self._unstacked_provider.enable_cache() for repo in self._fallback_repositories: # Writes don't affect fallback repos @@ -1812,9 +2014,9 @@ def lock_read(self): else: self.control_files.lock_read() if not locked: - if debug.debug_flag_enabled('relock') and self._prev_lock == 'r': - note('%r was read locked again', self) - self._prev_lock = 'r' + if debug.debug_flag_enabled("relock") and self._prev_lock == "r": + note("%r was read locked again", self) + self._prev_lock = "r" self._unstacked_provider.enable_cache() for repo in self._fallback_repositories: repo.lock_read() @@ -1837,11 +2039,13 @@ def pack(self, hint=None, clean_obsolete_packs=False): """ with self.lock_write(): self._pack_collection.pack( - hint=hint, clean_obsolete_packs=clean_obsolete_packs) + hint=hint, clean_obsolete_packs=clean_obsolete_packs + ) def reconcile(self, other=None, thorough=False): """Reconcile this repository.""" from .reconcile import PackReconciler + with self.lock_write(): reconciler = PackReconciler(self, thorough=thorough) return reconciler.reconcile() @@ -1857,7 +2061,8 @@ def unlock(self): self._transaction = None self._write_lock_count = 0 raise errors.BzrError( - f'Must end write group before releasing write lock on {self}') + f"Must end write group before releasing write lock on {self}" + ) if self._write_lock_count: self._write_lock_count -= 1 if not self._write_lock_count: @@ -1921,14 +2126,13 @@ def initialize(self, a_controldir, shared=False): :param shared: If true the repository will be initialized as a shared repository. 
""" - mutter('creating repository in %s.', a_controldir.transport.base) - dirs = ['indices', 'obsolete_packs', 'packs', 'upload'] + mutter("creating repository in %s.", a_controldir.transport.base) + dirs = ["indices", "obsolete_packs", "packs", "upload"] builder = self.index_builder_class() - files = [('pack-names', builder.finish())] - utf8_files = [('format', self.get_format_string())] + files = [("pack-names", builder.finish())] + utf8_files = [("format", self.get_format_string())] - self._upload_blank_content( - a_controldir, dirs, files, utf8_files, shared) + self._upload_blank_content(a_controldir, dirs, files, utf8_files, shared) repository = self.open(a_controldir=a_controldir, _found=True) self._run_post_repo_init_hooks(repository, a_controldir, shared) return repository @@ -1946,14 +2150,17 @@ def open(self, a_controldir, _found=False, _override_transport=None): repo_transport = _override_transport else: repo_transport = a_controldir.get_repository_transport(None) - control_files = lockable_files.LockableFiles(repo_transport, - 'lock', lockdir.LockDir) - return self.repository_class(_format=self, - a_controldir=a_controldir, - control_files=control_files, - _commit_builder_class=self._commit_builder_class, - _revision_serializer=self._revision_serializer, - _inventory_serializer=self._inventory_serializer) + control_files = lockable_files.LockableFiles( + repo_transport, "lock", lockdir.LockDir + ) + return self.repository_class( + _format=self, + a_controldir=a_controldir, + control_files=control_files, + _commit_builder_class=self._commit_builder_class, + _revision_serializer=self._revision_serializer, + _inventory_serializer=self._inventory_serializer, + ) class RetryPackOperations(RetryWithNewPacks): @@ -1965,8 +2172,10 @@ class RetryPackOperations(RetryWithNewPacks): internal_error = True - _fmt = ("Pack files have changed, reload and try pack again." - " context: %(context)s %(orig_error)s") + _fmt = ( + "Pack files have changed, reload and try pack again." + " context: %(context)s %(orig_error)s" + ) class _DirectPackAccess: @@ -2000,8 +2209,7 @@ def add_raw_record(self, key, size, raw_data): (index, pos, length), where the index field is the write_index object supplied to the PackAccess object. """ - p_offset, p_length = self._container_writer.add_bytes_record( - raw_data, size, []) + p_offset, p_length = self._container_writer.add_bytes_record(raw_data, size, []) return (self._write_index, p_offset, p_length) def add_raw_records(self, key_sizes, raw_data): @@ -2018,15 +2226,15 @@ def add_raw_records(self, key_sizes, raw_data): length), where the index field is the write_index object supplied to the PackAccess object. 
""" - raw_data = b''.join(raw_data) + raw_data = b"".join(raw_data) if not isinstance(raw_data, bytes): - raise AssertionError( - f'data must be plain bytes was {type(raw_data)}') + raise AssertionError(f"data must be plain bytes was {type(raw_data)}") result = [] offset = 0 for key, size in key_sizes: result.append( - self.add_raw_record(key, size, [raw_data[offset:offset + size]])) + self.add_raw_record(key, size, [raw_data[offset : offset + size]]) + ) offset += size return result @@ -2050,7 +2258,7 @@ def get_raw_records(self, memos_for_retrieval): # first pass, group into same-index requests request_lists = [] current_index = None - for (index, offset, length) in memos_for_retrieval: + for index, offset, length in memos_for_retrieval: if current_index == index: current_list.append((offset, length)) else: @@ -2072,9 +2280,9 @@ def get_raw_records(self, memos_for_retrieval): # If we don't have a _reload_func there is nothing that can # be done raise - raise RetryWithNewPacks(index, - reload_occurred=True, - exc_info=sys.exc_info()) from e + raise RetryWithNewPacks( + index, reload_occurred=True, exc_info=sys.exc_info() + ) from e try: reader = pack.make_readv_reader(transport, path, offsets) for _names, read_func in reader.iter_records(): @@ -2084,9 +2292,11 @@ def get_raw_records(self, memos_for_retrieval): # missing on disk, we need to trigger a reload, and start over. if self._reload_func is None: raise - raise RetryWithNewPacks(transport.abspath(path), - reload_occurred=False, - exc_info=sys.exc_info()) from e + raise RetryWithNewPacks( + transport.abspath(path), + reload_occurred=False, + exc_info=sys.exc_info(), + ) from e def set_writer(self, writer, index, transport_packname): """Set a writer to use for adding data.""" diff --git a/breezy/bzr/reconcile.py b/breezy/bzr/reconcile.py index f9ec38297f..0638f23fd9 100644 --- a/breezy/bzr/reconcile.py +++ b/breezy/bzr/reconcile.py @@ -17,11 +17,11 @@ """Reconcilers are able to fix some potential data errors in a branch.""" __all__ = [ - 'BranchReconciler', - 'KnitReconciler', - 'PackReconciler', - 'VersionedFileRepoReconciler', - ] + "BranchReconciler", + "KnitReconciler", + "PackReconciler", + "VersionedFileRepoReconciler", +] from .. import errors, ui from .. import revision as _mod_revision @@ -66,8 +66,7 @@ def reconcile(self): * `garbage_inventories`: The number of inventory objects without revisions that were garbage collected. """ - with self.repo.lock_write(), \ - ui.ui_factory.nested_progress_bar() as self.pb: + with self.repo.lock_write(), ui.ui_factory.nested_progress_bar() as self.pb: self._reconcile_steps() ret = ReconcileResult() ret.aborted = self.aborted @@ -88,7 +87,7 @@ def _reweave_inventory(self): (self.thorough) are treated as requiring the reweave. 
""" self.repo.get_transaction() - self.pb.update(gettext('Reading inventory data')) + self.pb.update(gettext("Reading inventory data")) self.inventory = self.repo.inventories self.revisions = self.repo.revisions # the total set of revisions to process @@ -106,32 +105,33 @@ def _reweave_inventory(self): self._check_garbage_inventories() # if there are no inconsistent_parents and # (no garbage inventories or we are not doing a thorough check) - if (not self.inconsistent_parents - and (not self.garbage_inventories or not self.thorough)): - ui.ui_factory.note(gettext('Inventory ok.')) + if not self.inconsistent_parents and ( + not self.garbage_inventories or not self.thorough + ): + ui.ui_factory.note(gettext("Inventory ok.")) return - self.pb.update(gettext('Backing up inventory'), 0, 0) + self.pb.update(gettext("Backing up inventory"), 0, 0) self.repo._backup_inventory() - ui.ui_factory.note(gettext('Backup inventory created.')) + ui.ui_factory.note(gettext("Backup inventory created.")) new_inventories = self.repo._temp_inventories() # we have topological order of revisions and non ghost parents ready. self._setup_steps(len(self._rev_graph)) revision_keys = [(rev_id,) for rev_id in topo_sort(self._rev_graph)] stream = self._change_inv_parents( - self.inventory.get_record_stream(revision_keys, 'unordered', True), + self.inventory.get_record_stream(revision_keys, "unordered", True), self._new_inv_parents, - set(revision_keys)) + set(revision_keys), + ) new_inventories.insert_record_stream(stream) # if this worked, the set of new_inventories.keys should equal # self.pending - if not (set(new_inventories.keys()) - == {(revid,) for revid in self.pending}): + if not (set(new_inventories.keys()) == {(revid,) for revid in self.pending}): raise AssertionError() - self.pb.update(gettext('Writing weave')) + self.pb.update(gettext("Writing weave")) self.repo._activate_new_inventory() self.inventory = None - ui.ui_factory.note(gettext('Inventory regenerated.')) + ui.ui_factory.note(gettext("Inventory regenerated.")) def _new_inv_parents(self, revision_key): """Lookup ghost-filtered parents for revision_key.""" @@ -146,13 +146,14 @@ def _change_inv_parents(self, stream, get_parents, all_revision_keys): # The check for the left most parent only handles knit # compressors, but this code only applies to knit and weave # repositories anyway. - chunks = record.get_bytes_as('chunked') - yield ChunkedContentFactory(record.key, wanted_parents, record.sha1, chunks) + chunks = record.get_bytes_as("chunked") + yield ChunkedContentFactory( + record.key, wanted_parents, record.sha1, chunks + ) else: - adapted_record = AdapterFactory( - record.key, wanted_parents, record) + adapted_record = AdapterFactory(record.key, wanted_parents, record) yield adapted_record - self._reweave_step('adding inventories') + self._reweave_step("adding inventories") def _setup_steps(self, new_total): """Setup the markers we need to control the progress bar.""" @@ -163,14 +164,14 @@ def _graph_revision(self, rev_id): """Load a revision into the revision graph.""" # pick a random revision # analyse revision id rev_id and put it in the stack. 
- self._reweave_step('loading revisions') + self._reweave_step("loading revisions") rev = self.repo.get_revision_reconcile(rev_id) parents = [] for parent in rev.parent_ids: if self._parent_is_available(parent): parents.append(parent) else: - mutter('found ghost %s', parent) + mutter("found ghost %s", parent) self._rev_graph[rev_id] = parents def _check_garbage_inventories(self): @@ -186,7 +187,7 @@ def _check_garbage_inventories(self): garbage = inventories.difference(revisions) self.garbage_inventories = len(garbage) for revision_key in garbage: - mutter('Garbage inventory {%s} found.', revision_key[-1]) + mutter("Garbage inventory {%s} found.", revision_key[-1]) def _parent_is_available(self, parent): """True if parent is a fully available revision. @@ -196,8 +197,8 @@ def _parent_is_available(self, parent): """ if parent in self._rev_graph: return True - inv_present = (1 == len(self.inventory.get_parent_map([(parent,)]))) - return (inv_present and self.repo.has_revision(parent)) + inv_present = 1 == len(self.inventory.get_parent_map([(parent,)])) + return inv_present and self.repo.has_revision(parent) def _reweave_step(self, message): """Mark a single step of regeneration complete.""" @@ -226,24 +227,24 @@ def _reconcile_steps(self): def _load_indexes(self): """Load indexes for the reconciliation.""" self.transaction = self.repo.get_transaction() - self.pb.update(gettext('Reading indexes'), 0, 2) + self.pb.update(gettext("Reading indexes"), 0, 2) self.inventory = self.repo.inventories - self.pb.update(gettext('Reading indexes'), 1, 2) + self.pb.update(gettext("Reading indexes"), 1, 2) self.repo._check_for_inconsistent_revision_parents() self.revisions = self.repo.revisions - self.pb.update(gettext('Reading indexes'), 2, 2) + self.pb.update(gettext("Reading indexes"), 2, 2) def _gc_inventory(self): """Remove inventories that are not referenced from the revision store.""" - self.pb.update(gettext('Checking unused inventories'), 0, 1) + self.pb.update(gettext("Checking unused inventories"), 0, 1) self._check_garbage_inventories() - self.pb.update(gettext('Checking unused inventories'), 1, 3) + self.pb.update(gettext("Checking unused inventories"), 1, 3) if not self.garbage_inventories: - ui.ui_factory.note(gettext('Inventory ok.')) + ui.ui_factory.note(gettext("Inventory ok.")) return - self.pb.update(gettext('Backing up inventory'), 0, 0) + self.pb.update(gettext("Backing up inventory"), 0, 0) self.repo._backup_inventory() - ui.ui_factory.note(gettext('Backup Inventory created')) + ui.ui_factory.note(gettext("Backup Inventory created")) # asking for '' should never return a non-empty weave new_inventories = self.repo._temp_inventories() # we have topological order of revisions and non ghost parents ready. 
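The _graph_revision/_parent_is_available pair above rebuilds the revision graph while silently dropping ghost parents, i.e. parents whose revisions are not actually present in the repository. A toy illustration of that filtering, with made-up revision ids:

    present = {b"rev-1", b"rev-2", b"rev-3"}
    parent_map = {
        b"rev-1": [],
        b"rev-2": [b"rev-1", b"ghost-1"],  # ghost-1 was never fetched
        b"rev-3": [b"rev-2"],
    }
    rev_graph = {
        rev_id: [p for p in parents if p in present]
        for rev_id, parents in parent_map.items()
    }
    assert rev_graph[b"rev-2"] == [b"rev-1"]  # the ghost is dropped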
@@ -252,18 +253,19 @@ def _gc_inventory(self): [key[-1] for key in revision_keys] self._setup_steps(len(revision_keys)) stream = self._change_inv_parents( - self.inventory.get_record_stream(revision_keys, 'unordered', True), + self.inventory.get_record_stream(revision_keys, "unordered", True), graph.__getitem__, - set(revision_keys)) + set(revision_keys), + ) new_inventories.insert_record_stream(stream) # if this worked, the set of new_inventory_vf.names should equal # the revisionds list if set(new_inventories.keys()) != set(revision_keys): raise AssertionError() - self.pb.update(gettext('Writing weave')) + self.pb.update(gettext("Writing weave")) self.repo._activate_new_inventory() self.inventory = None - ui.ui_factory.note(gettext('Inventory regenerated.')) + ui.ui_factory.note(gettext("Inventory regenerated.")) def _fix_text_parents(self): """Fix bad versionedfile parent entries. @@ -276,11 +278,11 @@ def _fix_text_parents(self): """ self.repo.get_transaction() versions = [key[-1] for key in self.revisions.keys()] - mutter('Prepopulating revision text cache with %d revisions', - len(versions)) + mutter("Prepopulating revision text cache with %d revisions", len(versions)) vf_checker = self.repo._get_versioned_file_checker() bad_parents, unused_versions = vf_checker.check_file_version_parents( - self.repo.texts, self.pb) + self.repo.texts, self.pb + ) text_index = vf_checker.text_index per_id_bad_parents = {} for key in unused_versions: @@ -301,27 +303,32 @@ def _fix_text_parents(self): versions_list.append(text_key[1]) # Do the reconcile of individual weaves. for num, file_id in enumerate(per_id_bad_parents): - self.pb.update(gettext('Fixing text parents'), num, - len(per_id_bad_parents)) + self.pb.update(gettext("Fixing text parents"), num, len(per_id_bad_parents)) versions_with_bad_parents = per_id_bad_parents[file_id] - id_unused_versions = {key[-1] for key in unused_versions - if key[0] == file_id} + id_unused_versions = { + key[-1] for key in unused_versions if key[0] == file_id + } if file_id in file_id_versions: file_versions = file_id_versions[file_id] else: # This id was present in the disk store but is not referenced # by any revision at all. 
file_versions = [] - self._fix_text_parent(file_id, versions_with_bad_parents, - id_unused_versions, file_versions) + self._fix_text_parent( + file_id, versions_with_bad_parents, id_unused_versions, file_versions + ) - def _fix_text_parent(self, file_id, versions_with_bad_parents, - unused_versions, all_versions): + def _fix_text_parent( + self, file_id, versions_with_bad_parents, unused_versions, all_versions + ): """Fix bad versionedfile entries in a single versioned file.""" - mutter('fixing text parent: %r (%d versions)', file_id, - len(versions_with_bad_parents)) - mutter('(%d are unused)', len(unused_versions)) - new_file_id = b'temp:%s' % file_id + mutter( + "fixing text parent: %r (%d versions)", + file_id, + len(versions_with_bad_parents), + ) + mutter("(%d are unused)", len(unused_versions)) + new_file_id = b"temp:%s" % file_id new_parents = {} needed_keys = set() for version in all_versions: @@ -333,17 +340,18 @@ def _fix_text_parent(self, file_id, versions_with_bad_parents, pmap = self.repo.texts.get_parent_map([(file_id, version)]) parents = [key[-1] for key in pmap[(file_id, version)]] new_parents[(new_file_id, version)] = [ - (new_file_id, parent) for parent in parents] + (new_file_id, parent) for parent in parents + ] needed_keys.add((file_id, version)) def fix_parents(stream): for record in stream: - chunks = record.get_bytes_as('chunked') + chunks = record.get_bytes_as("chunked") new_key = (new_file_id, record.key[-1]) parents = new_parents[new_key] yield ChunkedContentFactory(new_key, parents, record.sha1, chunks) - stream = self.repo.texts.get_record_stream( - needed_keys, 'topological', True) + + stream = self.repo.texts.get_record_stream(needed_keys, "topological", True) self.repo._remove_file_id(new_file_id) self.repo.texts.insert_record_stream(fix_parents(stream)) self.repo._remove_file_id(file_id) @@ -369,10 +377,8 @@ class PackReconciler(VersionedFileRepoReconciler): # - unlock the names list # https://bugs.launchpad.net/bzr/+bug/154173 - def __init__(self, repo, other=None, thorough=False, - canonicalize_chks=False): - super().__init__(repo, other=other, - thorough=thorough) + def __init__(self, repo, other=None, thorough=False, canonicalize_chks=False): + super().__init__(repo, other=other, thorough=thorough) self.canonicalize_chks = canonicalize_chks def _reconcile_steps(self): @@ -385,22 +391,25 @@ def _reconcile_steps(self): try: packs = collection.all_packs() all_revisions = self.repo.all_revision_ids() - total_inventories = len(list( - collection.inventory_index.combined_index.iter_all_entries())) + total_inventories = len( + list(collection.inventory_index.combined_index.iter_all_entries()) + ) if len(all_revisions): if self.canonicalize_chks: reconcile_meth = self.repo._canonicalize_chks_pack else: reconcile_meth = self.repo._reconcile_pack - new_pack = reconcile_meth(collection, packs, ".reconcile", - all_revisions, self.pb) + new_pack = reconcile_meth( + collection, packs, ".reconcile", all_revisions, self.pb + ) if new_pack is not None: self._discard_and_save(packs) else: # only make a new pack when there is data to copy. 
self._discard_and_save(packs) - self.garbage_inventories = total_inventories - len(list( - collection.inventory_index.combined_index.iter_all_entries())) + self.garbage_inventories = total_inventories - len( + list(collection.inventory_index.combined_index.iter_all_entries()) + ) finally: collection._unlock_names() @@ -429,8 +438,7 @@ def __init__(self, a_branch, thorough=False): self.branch = a_branch def reconcile(self): - with self.branch.lock_write(), \ - ui.ui_factory.nested_progress_bar() as self.pb: + with self.branch.lock_write(), ui.ui_factory.nested_progress_bar() as self.pb: ret = ReconcileResult() ret.fixed_history = self._reconcile_steps() return ret @@ -444,7 +452,8 @@ def _reconcile_revision_history(self): graph = self.branch.repository.get_graph() try: for revid in graph.iter_lefthand_ancestry( - last_revision_id, (_mod_revision.NULL_REVISION,)): + last_revision_id, (_mod_revision.NULL_REVISION,) + ): real_history.append(revid) except errors.RevisionNotPresent: pass # Hit a ghost left hand parent @@ -454,12 +463,13 @@ def _reconcile_revision_history(self): # set_revision_history, as this will regenerate it again. # Not really worth a whole BranchReconciler class just for this, # though. - ui.ui_factory.note(gettext('Fixing last revision info {0} ' - ' => {1}').format( - last_revno, len(real_history))) - self.branch.set_last_revision_info(len(real_history), - last_revision_id) + ui.ui_factory.note( + gettext("Fixing last revision info {0} " " => {1}").format( + last_revno, len(real_history) + ) + ) + self.branch.set_last_revision_info(len(real_history), last_revision_id) return True else: - ui.ui_factory.note(gettext('revision_history ok.')) + ui.ui_factory.note(gettext("revision_history ok.")) return False diff --git a/breezy/bzr/remote.py b/breezy/bzr/remote.py index fdebdd51e1..b1ff744c20 100644 --- a/breezy/bzr/remote.py +++ b/breezy/bzr/remote.py @@ -113,11 +113,13 @@ def _call_with_body_bytes(self, method, args, body_bytes, **err_context): except errors.ErrorFromSmartServer as err: self._translate_error(err, **err_context) - def _call_with_body_bytes_expecting_body(self, method, args, body_bytes, - **err_context): + def _call_with_body_bytes_expecting_body( + self, method, args, body_bytes, **err_context + ): try: return self._client.call_with_body_bytes_expecting_body( - method, args, body_bytes) + method, args, body_bytes + ) except errors.ErrorFromSmartServer as err: self._translate_error(err, **err_context) @@ -125,9 +127,9 @@ def _call_with_body_bytes_expecting_body(self, method, args, body_bytes, def response_tuple_to_repo_format(response): """Convert a response tuple describing a repository format to a format.""" format = RemoteRepositoryFormat() - format._rich_root_data = (response[0] == b'yes') - format._supports_tree_reference = (response[1] == b'yes') - format._supports_external_lookups = (response[2] == b'yes') + format._rich_root_data = response[0] == b"yes" + format._supports_tree_reference = response[1] == b"yes" + format._supports_external_lookups = response[2] == b"yes" format._network_name = response[3] return format @@ -135,6 +137,7 @@ def response_tuple_to_repo_format(response): # Note that RemoteBzrDirProber lives in breezy.bzrdir so breezy.bzr.remote # does not have to be imported unless a remote format is involved. 
+ class RemoteBzrDirFormat(_mod_bzrdir.BzrDirMetaFormat1): """Format representing bzrdirs accessed via a smart server.""" @@ -156,13 +159,12 @@ def __repr__(self): def get_format_description(self): if self._network_name: try: - real_format = controldir.network_format_registry.get( - self._network_name) + real_format = controldir.network_format_registry.get(self._network_name) except KeyError: pass else: - return 'Remote: ' + real_format.get_format_description() - return 'bzr remote bzrdir' + return "Remote: " + real_format.get_format_description() + return "bzr remote bzrdir" def get_format_string(self): raise NotImplementedError(self.get_format_string) @@ -184,12 +186,11 @@ def initialize_on_transport(self, transport): client = _SmartClient(client_medium) path = client.remote_path_from_transport(transport) try: - response = client.call(b'BzrDirFormat.initialize', path) + response = client.call(b"BzrDirFormat.initialize", path) except errors.ErrorFromSmartServer as err: _translate_error(err, path=path) - if response[0] != b'ok': - raise errors.SmartProtocolError( - f'unexpected response code {response}') + if response[0] != b"ok": + raise errors.SmartProtocolError(f"unexpected response code {response}") format = RemoteBzrDirFormat() self._supply_sub_formats_to(format) return RemoteBzrDir(transport, format) @@ -197,26 +198,34 @@ def initialize_on_transport(self, transport): def parse_NoneTrueFalse(self, arg): if not arg: return None - if arg == b'False': + if arg == b"False": return False - if arg == b'True': + if arg == b"True": return True raise AssertionError(f"invalid arg {arg!r}") def _serialize_NoneTrueFalse(self, arg): if arg is False: - return b'False' + return b"False" if arg: - return b'True' - return b'' + return b"True" + return b"" def _serialize_NoneString(self, arg): - return arg or b'' - - def initialize_on_transport_ex(self, transport, use_existing_dir=False, - create_prefix=False, force_new_repo=False, stacked_on=None, - stack_on_pwd=None, repo_format_name=None, make_working_trees=None, - shared_repo=False): + return arg or b"" + + def initialize_on_transport_ex( + self, + transport, + use_existing_dir=False, + create_prefix=False, + force_new_repo=False, + stacked_on=None, + stack_on_pwd=None, + repo_format_name=None, + make_working_trees=None, + shared_repo=False, + ): try: # hand off the request to the smart server client_medium = transport.get_smart_medium() @@ -228,7 +237,7 @@ def initialize_on_transport_ex(self, transport, use_existing_dir=False, if client_medium.should_probe(): try: server_version = client_medium.protocol_version() - if server_version != '2': + if server_version != "2": do_vfs = True else: do_vfs = False @@ -247,19 +256,46 @@ def initialize_on_transport_ex(self, transport, use_existing_dir=False, # TODO: lookup the local format from a server hint. 
local_dir_format = _mod_bzrdir.BzrDirMetaFormat1() self._supply_sub_formats_to(local_dir_format) - return local_dir_format.initialize_on_transport_ex(transport, - use_existing_dir=use_existing_dir, create_prefix=create_prefix, - force_new_repo=force_new_repo, stacked_on=stacked_on, - stack_on_pwd=stack_on_pwd, repo_format_name=repo_format_name, - make_working_trees=make_working_trees, shared_repo=shared_repo, - vfs_only=True) - return self._initialize_on_transport_ex_rpc(client, path, transport, - use_existing_dir, create_prefix, force_new_repo, stacked_on, - stack_on_pwd, repo_format_name, make_working_trees, shared_repo) - - def _initialize_on_transport_ex_rpc(self, client, path, transport, - use_existing_dir, create_prefix, force_new_repo, stacked_on, - stack_on_pwd, repo_format_name, make_working_trees, shared_repo): + return local_dir_format.initialize_on_transport_ex( + transport, + use_existing_dir=use_existing_dir, + create_prefix=create_prefix, + force_new_repo=force_new_repo, + stacked_on=stacked_on, + stack_on_pwd=stack_on_pwd, + repo_format_name=repo_format_name, + make_working_trees=make_working_trees, + shared_repo=shared_repo, + vfs_only=True, + ) + return self._initialize_on_transport_ex_rpc( + client, + path, + transport, + use_existing_dir, + create_prefix, + force_new_repo, + stacked_on, + stack_on_pwd, + repo_format_name, + make_working_trees, + shared_repo, + ) + + def _initialize_on_transport_ex_rpc( + self, + client, + path, + transport, + use_existing_dir, + create_prefix, + force_new_repo, + stacked_on, + stack_on_pwd, + repo_format_name, + make_working_trees, + shared_repo, + ): args = [] args.append(self._serialize_NoneTrueFalse(use_existing_dir)) args.append(self._serialize_NoneTrueFalse(create_prefix)) @@ -268,32 +304,41 @@ def _initialize_on_transport_ex_rpc(self, client, path, transport, # stack_on_pwd is often/usually our transport if stack_on_pwd: try: - stack_on_pwd = transport.relpath(stack_on_pwd).encode('utf-8') + stack_on_pwd = transport.relpath(stack_on_pwd).encode("utf-8") if not stack_on_pwd: - stack_on_pwd = b'.' + stack_on_pwd = b"." 
except errors.PathNotChild: pass args.append(self._serialize_NoneString(stack_on_pwd)) args.append(self._serialize_NoneString(repo_format_name)) args.append(self._serialize_NoneTrueFalse(make_working_trees)) args.append(self._serialize_NoneTrueFalse(shared_repo)) - request_network_name = self._network_name or \ - _mod_bzrdir.BzrDirFormat.get_default_format().network_name() + request_network_name = ( + self._network_name + or _mod_bzrdir.BzrDirFormat.get_default_format().network_name() + ) try: - response = client.call(b'BzrDirFormat.initialize_ex_1.16', - request_network_name, path, *args) + response = client.call( + b"BzrDirFormat.initialize_ex_1.16", request_network_name, path, *args + ) except errors.UnknownSmartMethod: client._medium._remember_remote_is_before((1, 16)) local_dir_format = _mod_bzrdir.BzrDirMetaFormat1() self._supply_sub_formats_to(local_dir_format) - return local_dir_format.initialize_on_transport_ex(transport, - use_existing_dir=use_existing_dir, create_prefix=create_prefix, - force_new_repo=force_new_repo, stacked_on=stacked_on, - stack_on_pwd=stack_on_pwd, repo_format_name=repo_format_name, - make_working_trees=make_working_trees, shared_repo=shared_repo, - vfs_only=True) + return local_dir_format.initialize_on_transport_ex( + transport, + use_existing_dir=use_existing_dir, + create_prefix=create_prefix, + force_new_repo=force_new_repo, + stacked_on=stacked_on, + stack_on_pwd=stack_on_pwd, + repo_format_name=repo_format_name, + make_working_trees=make_working_trees, + shared_repo=shared_repo, + vfs_only=True, + ) except errors.ErrorFromSmartServer as err: - _translate_error(err, path=path.decode('utf-8')) + _translate_error(err, path=path.decode("utf-8")) repo_path = response[0] bzrdir_name = response[6] require_stacking = response[7] @@ -304,23 +349,23 @@ def _initialize_on_transport_ex_rpc(self, client, path, transport, bzrdir = RemoteBzrDir(transport, format, _client=client) if repo_path: repo_format = response_tuple_to_repo_format(response[1:]) - if repo_path == b'.': - repo_path = b'' - repo_path = repo_path.decode('utf-8') + if repo_path == b".": + repo_path = b"" + repo_path = repo_path.decode("utf-8") if repo_path: repo_bzrdir_format = RemoteBzrDirFormat() repo_bzrdir_format._network_name = response[5] - repo_bzr = RemoteBzrDir(transport.clone(repo_path), - repo_bzrdir_format) + repo_bzr = RemoteBzrDir(transport.clone(repo_path), repo_bzrdir_format) else: repo_bzr = bzrdir final_stack = response[8] or None if final_stack: - final_stack = final_stack.decode('utf-8') + final_stack = final_stack.decode("utf-8") final_stack_pwd = response[9] or None if final_stack_pwd: final_stack_pwd = urlutils.join( - transport.base, final_stack_pwd.decode('utf-8')) + transport.base, final_stack_pwd.decode("utf-8") + ) remote_repo = RemoteRepository(repo_bzr, repo_format) if len(response) > 10: # Updated server verb that locks remotely. @@ -330,8 +375,9 @@ def _initialize_on_transport_ex_rpc(self, client, path, transport, remote_repo.dont_leave_lock_in_place() else: remote_repo.lock_write() - policy = _mod_bzrdir.UseExistingRepository(remote_repo, - final_stack, final_stack_pwd, require_stacking) + policy = _mod_bzrdir.UseExistingRepository( + remote_repo, final_stack, final_stack_pwd, require_stacking + ) policy.acquire_repository() else: remote_repo = None @@ -356,7 +402,7 @@ def __return_repository_format(self): # repository format has been asked for, tell the RemoteRepositoryFormat # that it should use that for init() etc. 
result = RemoteRepositoryFormat() - custom_format = getattr(self, '_repository_format', None) + custom_format = getattr(self, "_repository_format", None) if custom_format: if isinstance(custom_format, RemoteRepositoryFormat): return custom_format @@ -377,8 +423,9 @@ def get_branch_format(self): result = new_result return result - repository_format = property(__return_repository_format, - _mod_bzrdir.BzrDirMetaFormat1._set_repository_format) # .im_func) + repository_format = property( + __return_repository_format, _mod_bzrdir.BzrDirMetaFormat1._set_repository_format + ) # .im_func) class RemoteControlStore(_mod_config.IniFileStore): @@ -414,17 +461,18 @@ def _ensure_real(self): self._real_store = _mod_config.ControlStore(self.controldir) def external_url(self): - return urlutils.join(self.branch.user_url, 'control.conf') + return urlutils.join(self.branch.user_url, "control.conf") def _load_content(self): path = self.controldir._path_for_remote_call(self.controldir._client) try: response, handler = self.controldir._call_expecting_body( - b'BzrDir.get_config_file', path) + b"BzrDir.get_config_file", path + ) except errors.UnknownSmartMethod: self._ensure_real() return self._real_store._load_content() - if len(response) and response[0] != b'ok': + if len(response) and response[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response) return handler.read_body_bytes() @@ -473,7 +521,7 @@ def __init__(self, transport, format, _client=None, _force_probe=False): self._probe_bzrdir() def __repr__(self): - return f'{self.__class__.__name__}({self._client!r})' + return f"{self.__class__.__name__}({self._client!r})" def _probe_bzrdir(self): medium = self._client._medium @@ -489,13 +537,13 @@ def _probe_bzrdir(self): self._rpc_open(path) def _rpc_open_2_1(self, path): - response = self._call(b'BzrDir.open_2.1', path) - if response == (b'no',): + response = self._call(b"BzrDir.open_2.1", path) + if response == (b"no",): raise errors.NotBranchError(path=self.root_transport.base) - elif response[0] == b'yes': - if response[1] == b'yes': + elif response[0] == b"yes": + if response[1] == b"yes": self._has_working_tree = True - elif response[1] == b'no': + elif response[1] == b"no": self._has_working_tree = False else: raise errors.UnexpectedSmartServerResponse(response) @@ -503,10 +551,10 @@ def _rpc_open_2_1(self, path): raise errors.UnexpectedSmartServerResponse(response) def _rpc_open(self, path): - response = self._call(b'BzrDir.open', path) - if response not in [(b'yes',), (b'no',)]: + response = self._call(b"BzrDir.open", path) + if response not in [(b"yes",), (b"no",)]: raise errors.UnexpectedSmartServerResponse(response) - if response == (b'no',): + if response == (b"no",): raise errors.NotBranchError(path=self.root_transport.base) def _ensure_real(self): @@ -515,14 +563,16 @@ def _ensure_real(self): Used before calls to self._real_bzrdir. 
""" if not self._real_bzrdir: - if debug.debug_flag_enabled('hpssvfs'): + if debug.debug_flag_enabled("hpssvfs"): import traceback - warning('VFS BzrDir access triggered\n%s', - ''.join(traceback.format_stack())) + + warning( + "VFS BzrDir access triggered\n%s", "".join(traceback.format_stack()) + ) self._real_bzrdir = _mod_bzrdir.BzrDir.open_from_transport( - self.root_transport, probers=[_mod_bzr.BzrProber]) - self._format._network_name = \ - self._real_bzrdir._format.network_name() + self.root_transport, probers=[_mod_bzr.BzrProber] + ) + self._format._network_name = self._real_bzrdir._format.network_name() def _translate_error(self, err, **context): _translate_error(err, bzrdir=self, **context) @@ -544,8 +594,7 @@ def checkout_metadir(self): return self._vfs_checkout_metadir() path = self._path_for_remote_call(self._client) try: - response = self._client.call(b'BzrDir.checkout_metadir', - path) + response = self._client.call(b"BzrDir.checkout_metadir", path) except errors.UnknownSmartMethod: medium._remember_remote_is_before((2, 5)) return self._vfs_checkout_metadir() @@ -555,39 +604,39 @@ def checkout_metadir(self): try: format = controldir.network_format_registry.get(control_name) except KeyError as e: - raise errors.UnknownFormatError(kind='control', - format=control_name) from e + raise errors.UnknownFormatError(kind="control", format=control_name) from e if repo_name: try: - repo_format = _mod_repository.network_format_registry.get( - repo_name) + repo_format = _mod_repository.network_format_registry.get(repo_name) except KeyError as e: - raise errors.UnknownFormatError(kind='repository', - format=repo_name) from e + raise errors.UnknownFormatError( + kind="repository", format=repo_name + ) from e format.repository_format = repo_format if branch_name: try: format.set_branch_format( - branch.network_format_registry.get(branch_name)) + branch.network_format_registry.get(branch_name) + ) except KeyError as e: - raise errors.UnknownFormatError(kind='branch', - format=branch_name) from e + raise errors.UnknownFormatError( + kind="branch", format=branch_name + ) from e return format def _vfs_cloning_metadir(self, require_stacking=False): self._ensure_real() - return self._real_bzrdir.cloning_metadir( - require_stacking=require_stacking) + return self._real_bzrdir.cloning_metadir(require_stacking=require_stacking) def cloning_metadir(self, require_stacking=False): medium = self._client._medium if medium._is_remote_before((1, 13)): return self._vfs_cloning_metadir(require_stacking=require_stacking) - verb = b'BzrDir.cloning_metadir' + verb = b"BzrDir.cloning_metadir" if require_stacking: - stacking = b'True' + stacking = b"True" else: - stacking = b'False' + stacking = b"False" path = self._path_for_remote_call(self._client) try: response = self._call(verb, path, stacking) @@ -595,7 +644,7 @@ def cloning_metadir(self, require_stacking=False): medium._remember_remote_is_before((1, 13)) return self._vfs_cloning_metadir(require_stacking=require_stacking) except UnknownErrorFromSmartServer as err: - if err.error_tuple != (b'BranchReference',): + if err.error_tuple != (b"BranchReference",): raise # We need to resolve the branch reference to determine the # cloning_metadir. 
This causes unnecessary RPCs to open the @@ -612,30 +661,31 @@ def cloning_metadir(self, require_stacking=False): try: format = controldir.network_format_registry.get(control_name) except KeyError as e: - raise errors.UnknownFormatError( - kind='control', format=control_name) from e + raise errors.UnknownFormatError(kind="control", format=control_name) from e if repo_name: try: format.repository_format = _mod_repository.network_format_registry.get( - repo_name) + repo_name + ) except KeyError as e: - raise errors.UnknownFormatError(kind='repository', - format=repo_name) from e - if branch_ref == b'ref': + raise errors.UnknownFormatError( + kind="repository", format=repo_name + ) from e + if branch_ref == b"ref": # XXX: we need possible_transports here to avoid reopening the # connection to the referenced location ref_bzrdir = _mod_bzrdir.BzrDir.open(branch_name) branch_format = ref_bzrdir.cloning_metadir().get_branch_format() format.set_branch_format(branch_format) - elif branch_ref == b'branch': + elif branch_ref == b"branch": if branch_name: try: - branch_format = branch.network_format_registry.get( - branch_name) + branch_format = branch.network_format_registry.get(branch_name) except KeyError as e: - raise errors.UnknownFormatError(kind='branch', - format=branch_name) from e + raise errors.UnknownFormatError( + kind="branch", format=branch_name + ) from e format.set_branch_format(branch_format) else: raise errors.UnexpectedSmartServerResponse(response) @@ -654,30 +704,32 @@ def destroy_repository(self): """See BzrDir.destroy_repository.""" path = self._path_for_remote_call(self._client) try: - response = self._call(b'BzrDir.destroy_repository', path) + response = self._call(b"BzrDir.destroy_repository", path) except errors.UnknownSmartMethod: self._ensure_real() self._real_bzrdir.destroy_repository() return - if response[0] != b'ok': - raise SmartProtocolError( - f'unexpected response code {response}') + if response[0] != b"ok": + raise SmartProtocolError(f"unexpected response code {response}") - def create_branch(self, name=None, repository=None, - append_revisions_only=None): + def create_branch(self, name=None, repository=None, append_revisions_only=None): if name is None: name = self._get_selected_branch() if name != "": raise controldir.NoColocatedBranchSupport(self) # as per meta1 formats - just delegate to the format object which may # be parameterised. 
- real_branch = self._format.get_branch_format().initialize(self, - name=name, repository=repository, - append_revisions_only=append_revisions_only) + real_branch = self._format.get_branch_format().initialize( + self, + name=name, + repository=repository, + append_revisions_only=append_revisions_only, + ) if not isinstance(real_branch, RemoteBranch): if not isinstance(repository, RemoteRepository): raise AssertionError( - f'need a RemoteRepository to use with RemoteBranch, got {repository!r}') + f"need a RemoteRepository to use with RemoteBranch, got {repository!r}" + ) result = RemoteBranch(self, repository, real_branch, name=name) else: result = real_branch @@ -699,22 +751,22 @@ def destroy_branch(self, name=None): path = self._path_for_remote_call(self._client) try: if name != "": - args = (name, ) + args = (name,) else: args = () - response = self._call(b'BzrDir.destroy_branch', path, *args) + response = self._call(b"BzrDir.destroy_branch", path, *args) except errors.UnknownSmartMethod: self._ensure_real() self._real_bzrdir.destroy_branch(name=name) self._next_open_branch_result = None return self._next_open_branch_result = None - if response[0] != b'ok': - raise SmartProtocolError( - f'unexpected response code {response}') + if response[0] != b"ok": + raise SmartProtocolError(f"unexpected response code {response}") - def create_workingtree(self, revision_id=None, from_branch=None, - accelerator_tree=None, hardlink=False): + def create_workingtree( + self, revision_id=None, from_branch=None, accelerator_tree=None, hardlink=False + ): raise errors.NotLocalUrl(self.transport.base) def find_branch_format(self, name=None): @@ -728,8 +780,7 @@ def find_branch_format(self, name=None): def branch_names(self): path = self._path_for_remote_call(self._client) try: - response, handler = self._call_expecting_body( - b'BzrDir.get_branches', path) + response, handler = self._call_expecting_body(b"BzrDir.get_branches", path) except errors.UnknownSmartMethod: self._ensure_real() return self._real_bzrdir.branch_names() @@ -738,15 +789,14 @@ def branch_names(self): body = bencode.bdecode(handler.read_body_bytes()) ret = [] for name, _value in body.items(): - name = name.decode('utf-8') + name = name.decode("utf-8") ret.append(name) return ret def get_branches(self, possible_transports=None, ignore_fallbacks=False): path = self._path_for_remote_call(self._client) try: - response, handler = self._call_expecting_body( - b'BzrDir.get_branches', path) + response, handler = self._call_expecting_body(b"BzrDir.get_branches", path) except errors.UnknownSmartMethod: self._ensure_real() return self._real_bzrdir.get_branches() @@ -755,11 +805,14 @@ def get_branches(self, possible_transports=None, ignore_fallbacks=False): body = bencode.bdecode(handler.read_body_bytes()) ret = {} for name, value in body.items(): - name = name.decode('utf-8') + name = name.decode("utf-8") ret[name] = self._open_branch( - name, value[0].decode('ascii'), value[1], + name, + value[0].decode("ascii"), + value[1], possible_transports=possible_transports, - ignore_fallbacks=ignore_fallbacks) + ignore_fallbacks=ignore_fallbacks, + ) return ret def set_branch_reference(self, target_branch, name=None): @@ -778,8 +831,8 @@ def get_branch_reference(self, name=None): if name != "": raise controldir.NoColocatedBranchSupport(self) response = self._get_branch_reference() - if response[0] == 'ref': - return response[1].decode('utf-8') + if response[0] == "ref": + return response[1].decode("utf-8") else: return None @@ -793,10 +846,10 @@ def 
_get_branch_reference(self): path = self._path_for_remote_call(self._client) medium = self._client._medium candidate_calls = [ - (b'BzrDir.open_branchV3', (2, 1)), - (b'BzrDir.open_branchV2', (1, 13)), - (b'BzrDir.open_branch', None), - ] + (b"BzrDir.open_branchV3", (2, 1)), + (b"BzrDir.open_branchV2", (1, 13)), + (b"BzrDir.open_branch", None), + ] for verb, required_version in candidate_calls: if required_version and medium._is_remote_before(required_version): continue @@ -808,74 +861,97 @@ def _get_branch_reference(self): medium._remember_remote_is_before(required_version) else: break - if verb == b'BzrDir.open_branch': - if response[0] != b'ok': + if verb == b"BzrDir.open_branch": + if response[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response) - if response[1] != b'': - return ('ref', response[1]) + if response[1] != b"": + return ("ref", response[1]) else: - return ('branch', b'') - if response[0] not in (b'ref', b'branch'): + return ("branch", b"") + if response[0] not in (b"ref", b"branch"): raise errors.UnexpectedSmartServerResponse(response) - return (response[0].decode('ascii'), response[1]) + return (response[0].decode("ascii"), response[1]) def _get_tree_branch(self, name=None): """See BzrDir._get_tree_branch().""" return None, self.open_branch(name=name) - def _open_branch(self, name, kind, location_or_format, - ignore_fallbacks=False, possible_transports=None): - if kind == 'ref': + def _open_branch( + self, + name, + kind, + location_or_format, + ignore_fallbacks=False, + possible_transports=None, + ): + if kind == "ref": # a branch reference, use the existing BranchReference logic. format = BranchReferenceFormat() - ref_loc = urlutils.join(self.user_url, location_or_format.decode('utf-8')) - return format.open(self, name=name, _found=True, - location=ref_loc, - ignore_fallbacks=ignore_fallbacks, - possible_transports=possible_transports) + ref_loc = urlutils.join(self.user_url, location_or_format.decode("utf-8")) + return format.open( + self, + name=name, + _found=True, + location=ref_loc, + ignore_fallbacks=ignore_fallbacks, + possible_transports=possible_transports, + ) branch_format_name = location_or_format if not branch_format_name: branch_format_name = None format = RemoteBranchFormat(network_name=branch_format_name) - return RemoteBranch(self, self.find_repository(), format=format, - setup_stacking=not ignore_fallbacks, name=name, - possible_transports=possible_transports) - - def open_branch(self, name=None, unsupported=False, - ignore_fallbacks=False, possible_transports=None): + return RemoteBranch( + self, + self.find_repository(), + format=format, + setup_stacking=not ignore_fallbacks, + name=name, + possible_transports=possible_transports, + ) + + def open_branch( + self, + name=None, + unsupported=False, + ignore_fallbacks=False, + possible_transports=None, + ): if name is None: name = self._get_selected_branch() if name != "": raise controldir.NoColocatedBranchSupport(self) if unsupported: - raise NotImplementedError( - 'unsupported flag support not implemented yet.') + raise NotImplementedError("unsupported flag support not implemented yet.") if self._next_open_branch_result is not None: # See create_branch for details. 
result = self._next_open_branch_result self._next_open_branch_result = None return result response = self._get_branch_reference() - return self._open_branch(name, response[0], response[1], - possible_transports=possible_transports, - ignore_fallbacks=ignore_fallbacks) + return self._open_branch( + name, + response[0], + response[1], + possible_transports=possible_transports, + ignore_fallbacks=ignore_fallbacks, + ) def _open_repo_v1(self, path): - verb = b'BzrDir.find_repository' + verb = b"BzrDir.find_repository" response = self._call(verb, path) - if response[0] != b'ok': + if response[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response) # servers that only support the v1 method don't support external # references either. self._ensure_real() repo = self._real_bzrdir.open_repository() - response = response + (b'no', repo._format.network_name()) + response = response + (b"no", repo._format.network_name()) return response, repo def _open_repo_v2(self, path): - verb = b'BzrDir.find_repositoryV2' + verb = b"BzrDir.find_repositoryV2" response = self._call(verb, path) - if response[0] != b'ok': + if response[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response) self._ensure_real() repo = self._real_bzrdir.open_repository() @@ -883,7 +959,7 @@ def _open_repo_v2(self, path): return response, repo def _open_repo_v3(self, path): - verb = b'BzrDir.find_repositoryV3' + verb = b"BzrDir.find_repositoryV3" medium = self._client._medium if medium._is_remote_before((1, 13)): raise errors.UnknownSmartMethod(verb) @@ -892,28 +968,26 @@ def _open_repo_v3(self, path): except errors.UnknownSmartMethod: medium._remember_remote_is_before((1, 13)) raise - if response[0] != b'ok': + if response[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response) return response, None def open_repository(self): path = self._path_for_remote_call(self._client) response = None - for probe in [self._open_repo_v3, self._open_repo_v2, - self._open_repo_v1]: + for probe in [self._open_repo_v3, self._open_repo_v2, self._open_repo_v1]: try: response, real_repo = probe(path) break except errors.UnknownSmartMethod: pass if response is None: - raise errors.UnknownSmartMethod(b'BzrDir.find_repository{3,2,}') - if response[0] != b'ok': + raise errors.UnknownSmartMethod(b"BzrDir.find_repository{3,2,}") + if response[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response) if len(response) != 6: - raise SmartProtocolError( - f'incorrect response length {response}') - if response[1] == b'': + raise SmartProtocolError(f"incorrect response length {response}") + if response[1] == b"": # repo is at this dir. format = response_tuple_to_repo_format(response[2:]) # Used to support creating a real format instance when needed. 
@@ -930,15 +1004,14 @@ def has_workingtree(self): if self._has_working_tree is None: path = self._path_for_remote_call(self._client) try: - response = self._call(b'BzrDir.has_workingtree', path) + response = self._call(b"BzrDir.has_workingtree", path) except errors.UnknownSmartMethod: self._ensure_real() self._has_working_tree = self._real_bzrdir.has_workingtree() else: - if response[0] not in (b'yes', b'no'): - raise SmartProtocolError( - f'unexpected response code {response}') - self._has_working_tree = (response[0] == b'yes') + if response[0] not in (b"yes", b"no"): + raise SmartProtocolError(f"unexpected response code {response}") + self._has_working_tree = response[0] == b"yes" return self._has_working_tree def open_workingtree(self, recommend_upgrade=True): @@ -950,10 +1023,11 @@ def open_workingtree(self, recommend_upgrade=True): def _path_for_remote_call(self, client): """Return the path to be used for this bzrdir in a remote call.""" remote_path = client.remote_path_from_transport(self.root_transport) - remote_path = remote_path.decode('utf-8') + remote_path = remote_path.decode("utf-8") base_url, segment_parameters = urlutils.split_segment_parameters_raw( - remote_path) - base_url = base_url.encode('utf-8') + remote_path + ) + base_url = base_url.encode("utf-8") return base_url def get_branch_transport(self, branch_format, name=None): @@ -984,27 +1058,43 @@ def _get_config_store(self): class RemoteInventoryTree(InventoryRevisionTree): - def __init__(self, repository, inv, revision_id): super().__init__(repository, inv, revision_id) - def archive(self, format, name, root=None, subdir=None, force_mtime=None, recurse_nested=False): + def archive( + self, + format, + name, + root=None, + subdir=None, + force_mtime=None, + recurse_nested=False, + ): if recurse_nested: # For now, just fall back to non-HPSS mode if nested trees are involved. return super().archive( - format, name, root, subdir, force_mtime=force_mtime, - recurse_nested=recurse_nested) + format, + name, + root, + subdir, + force_mtime=force_mtime, + recurse_nested=recurse_nested, + ) ret = self._repository._revision_archive( - self.get_revision_id(), format, name, root, subdir, - force_mtime=force_mtime) + self.get_revision_id(), format, name, root, subdir, force_mtime=force_mtime + ) if ret is None: return super().archive( - format, name, root, subdir, force_mtime=force_mtime, - recurse_nested=recurse_nested) + format, + name, + root, + subdir, + force_mtime=force_mtime, + recurse_nested=recurse_nested, + ) return ret - def annotate_iter(self, path, - default_revision=_mod_revision.CURRENT_REVISION): + def annotate_iter(self, path, default_revision=_mod_revision.CURRENT_REVISION): """Return an iterator of revision_id, line tuples. For working trees (and mutable trees in general), the special @@ -1015,11 +1105,13 @@ def annotate_iter(self, path, this value. 
""" ret = self._repository._annotate_file_revision( - self.get_revision_id(), path, file_id=None, - default_revision=default_revision) + self.get_revision_id(), + path, + file_id=None, + default_revision=default_revision, + ) if ret is None: - return super().annotate_iter( - path, default_revision=default_revision) + return super().annotate_iter(path, default_revision=default_revision) return ret @@ -1088,55 +1180,58 @@ def supports_chks(self): def supports_external_lookups(self): if self._supports_external_lookups is None: self._ensure_real() - self._supports_external_lookups = \ + self._supports_external_lookups = ( self._custom_format.supports_external_lookups + ) return self._supports_external_lookups @property def supports_funky_characters(self): if self._supports_funky_characters is None: self._ensure_real() - self._supports_funky_characters = \ + self._supports_funky_characters = ( self._custom_format.supports_funky_characters + ) return self._supports_funky_characters @property def supports_nesting_repositories(self): if self._supports_nesting_repositories is None: self._ensure_real() - self._supports_nesting_repositories = \ + self._supports_nesting_repositories = ( self._custom_format.supports_nesting_repositories + ) return self._supports_nesting_repositories @property def supports_tree_reference(self): if self._supports_tree_reference is None: self._ensure_real() - self._supports_tree_reference = \ - self._custom_format.supports_tree_reference + self._supports_tree_reference = self._custom_format.supports_tree_reference return self._supports_tree_reference @property def revision_graph_can_have_wrong_parents(self): if self._revision_graph_can_have_wrong_parents is None: self._ensure_real() - self._revision_graph_can_have_wrong_parents = \ + self._revision_graph_can_have_wrong_parents = ( self._custom_format.revision_graph_can_have_wrong_parents + ) return self._revision_graph_can_have_wrong_parents def _vfs_initialize(self, a_controldir, shared): """Helper for common code in initialize.""" if self._custom_format: # Custom format requested - result = self._custom_format.initialize( - a_controldir, shared=shared) + result = self._custom_format.initialize(a_controldir, shared=shared) elif self._creating_bzrdir is not None: # Use the format that the repository we were created to back # has. prior_repo = self._creating_bzrdir.open_repository() prior_repo._ensure_real() result = prior_repo._real_repository._format.initialize( - a_controldir, shared=shared) + a_controldir, shared=shared + ) else: # assume that a_bzr is a RemoteBzrDir but the smart server didn't # support remote initialization. @@ -1165,17 +1260,16 @@ def initialize(self, a_controldir, shared=False): network_name = self._network_name else: # Select the current breezy default and ask for that. 
- reference_bzrdir_format = controldir.format_registry.get( - 'default')() + reference_bzrdir_format = controldir.format_registry.get("default")() reference_format = reference_bzrdir_format.repository_format network_name = reference_format.network_name() # 2) try direct creation via RPC path = a_controldir._path_for_remote_call(a_controldir._client) - verb = b'BzrDir.create_repository' + verb = b"BzrDir.create_repository" if shared: - shared_str = b'True' + shared_str = b"True" else: - shared_str = b'False' + shared_str = b"False" try: response = a_controldir._call(verb, path, network_name, shared_str) except errors.UnknownSmartMethod: @@ -1193,17 +1287,19 @@ def initialize(self, a_controldir, shared=False): def open(self, a_controldir): if not isinstance(a_controldir, RemoteBzrDir): - raise AssertionError(f'{a_controldir!r} is not a RemoteBzrDir') + raise AssertionError(f"{a_controldir!r} is not a RemoteBzrDir") return a_controldir.open_repository() def _ensure_real(self): if self._custom_format is None: try: self._custom_format = _mod_repository.network_format_registry.get( - self._network_name) + self._network_name + ) except KeyError as e: raise errors.UnknownFormatError( - kind='repository', format=self._network_name) from e + kind="repository", format=self._network_name + ) from e @property def _fetch_order(self): @@ -1222,7 +1318,7 @@ def _fetch_reconcile(self): def get_format_description(self): self._ensure_real() - return 'Remote: ' + self._custom_format.get_format_description() + return "Remote: " + self._custom_format.get_format_description() def __eq__(self, other): return self.__class__ is other.__class__ @@ -1249,8 +1345,7 @@ def _inventory_serializer(self): return self._custom_format._inventory_serializer -class RemoteRepository(_mod_repository.Repository, _RpcHelper, - lock._RelockDebugMixin): +class RemoteRepository(_mod_repository.Repository, _RpcHelper, lock._RelockDebugMixin): """Repository accessed over rpc. For the moment most operations are performed using local transport-backed @@ -1260,9 +1355,13 @@ class RemoteRepository(_mod_repository.Repository, _RpcHelper, _format: RemoteRepositoryFormat _real_repository: Optional[_mod_repository.Repository] - def __init__(self, remote_bzrdir: RemoteBzrDir, - format: RemoteRepositoryFormat, - real_repository: Optional[_mod_repository.Repository] = None, _client=None): + def __init__( + self, + remote_bzrdir: RemoteBzrDir, + format: RemoteRepositoryFormat, + real_repository: Optional[_mod_repository.Repository] = None, + _client=None, + ): """Create a RemoteRepository instance. :param remote_bzrdir: The bzrdir hosting this repository. @@ -1291,7 +1390,8 @@ def __init__(self, remote_bzrdir: RemoteBzrDir, # Cache of revision parents; misses are cached during read locks, and # write locks when no _real_repository has been set. 
self._unstacked_provider = graph.CachingParentsProvider( - get_parent_map=self._get_parent_map_rpc) + get_parent_map=self._get_parent_map_rpc + ) self._unstacked_provider.disable_cache() # For tests: # These depend on the actual remote format, so force them off for @@ -1334,26 +1434,30 @@ def abort_write_group(self, suppress_errors=False): if self._real_repository: self._ensure_real() return self._real_repository.abort_write_group( - suppress_errors=suppress_errors) + suppress_errors=suppress_errors + ) if not self.is_in_write_group(): if suppress_errors: - mutter('(suppressed) not in write group') + mutter("(suppressed) not in write group") return raise errors.BzrError("not in write group") path = self.controldir._path_for_remote_call(self._client) try: - response = self._call(b'Repository.abort_write_group', path, - self._lock_token, - [token.encode('utf-8') for token in self._write_group_tokens]) + response = self._call( + b"Repository.abort_write_group", + path, + self._lock_token, + [token.encode("utf-8") for token in self._write_group_tokens], + ) except Exception as exc: self._write_group = None if not suppress_errors: raise - mutter('abort_write_group failed') + mutter("abort_write_group failed") log_exception_quietly() - note(gettext('bzr: ERROR (ignored): %s'), exc) + note(gettext("bzr: ERROR (ignored): %s"), exc) else: - if response != (b'ok', ): + if response != (b"ok",): raise errors.UnexpectedSmartServerResponse(response) self._write_group_tokens = None @@ -1381,9 +1485,13 @@ def commit_write_group(self): if not self.is_in_write_group(): raise errors.BzrError("not in write group") path = self.controldir._path_for_remote_call(self._client) - response = self._call(b'Repository.commit_write_group', path, - self._lock_token, [token.encode('utf-8') for token in self._write_group_tokens]) - if response != (b'ok', ): + response = self._call( + b"Repository.commit_write_group", + path, + self._lock_token, + [token.encode("utf-8") for token in self._write_group_tokens], + ) + if response != (b"ok",): raise errors.UnexpectedSmartServerResponse(response) self._write_group_tokens = None # Refresh data after writing to the repository. 
@@ -1394,12 +1502,16 @@ def resume_write_group(self, tokens): return self._real_repository.resume_write_group(tokens) path = self.controldir._path_for_remote_call(self._client) try: - response = self._call(b'Repository.check_write_group', path, - self._lock_token, [token.encode('utf-8') for token in tokens]) + response = self._call( + b"Repository.check_write_group", + path, + self._lock_token, + [token.encode("utf-8") for token in tokens], + ) except errors.UnknownSmartMethod: self._ensure_real() return self._real_repository.resume_write_group(tokens) - if response != (b'ok', ): + if response != (b"ok",): raise errors.UnexpectedSmartServerResponse(response) self._write_group_tokens = tokens @@ -1413,12 +1525,12 @@ def suspend_write_group(self): def get_missing_parent_inventories(self, check_for_missing_texts=True): self._ensure_real() return self._real_repository.get_missing_parent_inventories( - check_for_missing_texts=check_for_missing_texts) + check_for_missing_texts=check_for_missing_texts + ) def _get_rev_id_for_revno_vfs(self, revno, known_pair): self._ensure_real() - return self._real_repository.get_rev_id_for_revno( - revno, known_pair) + return self._real_repository.get_rev_id_for_revno(revno, known_pair) def get_rev_id_for_revno(self, revno, known_pair): """See Repository.get_rev_id_for_revno.""" @@ -1427,7 +1539,8 @@ def get_rev_id_for_revno(self, revno, known_pair): if self._client._medium._is_remote_before((1, 17)): return self._get_rev_id_for_revno_vfs(revno, known_pair) response = self._call( - b'Repository.get_rev_id_for_revno', path, revno, known_pair) + b"Repository.get_rev_id_for_revno", path, revno, known_pair + ) except errors.UnknownSmartMethod: self._client._medium._remember_remote_is_before((1, 17)) return self._get_rev_id_for_revno_vfs(revno, known_pair) @@ -1436,21 +1549,22 @@ def get_rev_id_for_revno(self, revno, known_pair): # ValueError instead of returning revno-outofbounds if len(e.error_tuple) < 3: raise - if e.error_tuple[:2] != (b'error', b'ValueError'): + if e.error_tuple[:2] != (b"error", b"ValueError"): raise m = re.match( - br"requested revno \(([0-9]+)\) is later than given " - br"known revno \(([0-9]+)\)", e.error_tuple[2]) + rb"requested revno \(([0-9]+)\) is later than given " + rb"known revno \(([0-9]+)\)", + e.error_tuple[2], + ) if not m: raise raise errors.RevnoOutOfBounds(int(m.group(1)), (0, int(m.group(2)))) from e - if response[0] == b'ok': + if response[0] == b"ok": return True, response[1] - elif response[0] == b'history-incomplete': + elif response[0] == b"history-incomplete": known_pair = response[1:3] for fallback in self._fallback_repositories: - found, result = fallback.get_rev_id_for_revno( - revno, known_pair) + found, result = fallback.get_rev_id_for_revno(revno, known_pair) if found: return True, result else: @@ -1474,14 +1588,16 @@ def _ensure_real(self): invocation. If in doubt chat to the bzr network team. 
""" if self._real_repository is None: - if debug.debug_flag_enabled('hpssvfs'): + if debug.debug_flag_enabled("hpssvfs"): import traceback - warning('VFS Repository access triggered\n%s', - ''.join(traceback.format_stack())) + + warning( + "VFS Repository access triggered\n%s", + "".join(traceback.format_stack()), + ) self._unstacked_provider.missing_keys.clear() self.controldir._ensure_real() - self._set_real_repository( - self.controldir._real_bzrdir.open_repository()) + self._set_real_repository(self.controldir._real_bzrdir.open_repository()) def _translate_error(self, err, **context): self.controldir._translate_error(err, repository=self, **context) @@ -1511,21 +1627,22 @@ def _generate_text_key_index(self): def _get_revision_graph(self, revision_id: RevisionID): """Private method for using with old (< 1.2) servers to fallback.""" if revision_id is None: - revision_id = b'' + revision_id = b"" elif _mod_revision.is_null(revision_id): return {} path = self.controldir._path_for_remote_call(self._client) response = self._call_expecting_body( - b'Repository.get_revision_graph', path, revision_id) + b"Repository.get_revision_graph", path, revision_id + ) response_tuple, response_handler = response - if response_tuple[0] != b'ok': + if response_tuple[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response_tuple) coded = response_handler.read_body_bytes() - if coded == b'': + if coded == b"": # no revisions in this repository! return {} - lines = coded.split(b'\n') + lines = coded.split(b"\n") revision_graph = {} for line in lines: d = tuple(line.split()) @@ -1582,8 +1699,10 @@ def has_same_location(self, other): # TODO: Move to RepositoryBase and unify with the regular Repository # one; unfortunately the tests rely on slightly different behaviour at # present -- mbp 20090710 - return (self.__class__ is other.__class__ - and self.controldir.transport.base == other.controldir.transport.base) + return ( + self.__class__ is other.__class__ + and self.controldir.transport.base == other.controldir.transport.base + ) def get_graph(self, other_repository=None): """Return the graph for this repository format.""" @@ -1593,8 +1712,11 @@ def get_graph(self, other_repository=None): def get_known_graph_ancestry(self, revision_ids): """Return the known graph for a set of revision ids and their ancestors.""" with self.lock_read(): - revision_graph = {key: value for key, value in - self.get_graph().iter_ancestry(revision_ids) if value is not None} + revision_graph = { + key: value + for key, value in self.get_graph().iter_ancestry(revision_ids) + if value is not None + } revision_graph = _mod_repository._strip_NULL_ghosts(revision_graph) return graph.KnownGraph(revision_graph) @@ -1603,29 +1725,30 @@ def gather_stats(self, revid=None, committers=None): path = self.controldir._path_for_remote_call(self._client) # revid can be None to indicate no revisions, not just NULL_REVISION if revid is None or _mod_revision.is_null(revid): - fmt_revid = b'' + fmt_revid = b"" else: fmt_revid = revid if committers is None or not committers: - fmt_committers = b'no' + fmt_committers = b"no" else: - fmt_committers = b'yes' + fmt_committers = b"yes" response_tuple, response_handler = self._call_expecting_body( - b'Repository.gather_stats', path, fmt_revid, fmt_committers) - if response_tuple[0] != b'ok': + b"Repository.gather_stats", path, fmt_revid, fmt_committers + ) + if response_tuple[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response_tuple) body = response_handler.read_body_bytes() result = {} - 
for line in body.split(b'\n'): + for line in body.split(b"\n"): if not line: continue - key, val_text = line.split(b':') - key = key.decode('ascii') - if key in ('revisions', 'size', 'committers'): + key, val_text = line.split(b":") + key = key.decode("ascii") + if key in ("revisions", "size", "committers"): result[key] = int(val_text) - elif key in ('firstrev', 'latestrev'): - values = val_text.split(b' ')[1:] + elif key in ("firstrev", "latestrev"): + values = val_text.split(b" ")[1:] result[key] = (float(values[0]), int(values[1])) return result @@ -1640,13 +1763,13 @@ def get_physical_lock_status(self): """See Repository.get_physical_lock_status().""" path = self.controldir._path_for_remote_call(self._client) try: - response = self._call(b'Repository.get_physical_lock_status', path) + response = self._call(b"Repository.get_physical_lock_status", path) except errors.UnknownSmartMethod: self._ensure_real() return self._real_repository.get_physical_lock_status() - if response[0] not in (b'yes', b'no'): + if response[0] not in (b"yes", b"no"): raise errors.UnexpectedSmartServerResponse(response) - return (response[0] == b'yes') + return response[0] == b"yes" def is_in_write_group(self): """Return True if there is an open write group. @@ -1664,14 +1787,13 @@ def is_locked(self): def is_shared(self): """See Repository.is_shared().""" path = self.controldir._path_for_remote_call(self._client) - response = self._call(b'Repository.is_shared', path) - if response[0] not in (b'yes', b'no'): - raise SmartProtocolError( - f'unexpected response code {response}') - return response[0] == b'yes' + response = self._call(b"Repository.is_shared", path) + if response[0] not in (b"yes", b"no"): + raise SmartProtocolError(f"unexpected response code {response}") + return response[0] == b"yes" def is_write_locked(self): - return self._lock_mode == 'w' + return self._lock_mode == "w" def _warn_if_deprecated(self, branch=None): # If we have a real repository, the check will be done there, if we @@ -1685,8 +1807,8 @@ def lock_read(self): """ # wrong eventually - want a local lock cache context if not self._lock_mode: - self._note_lock('r') - self._lock_mode = 'r' + self._note_lock("r") + self._lock_mode = "r" self._lock_count = 1 self._unstacked_provider.enable_cache(cache_misses=True) if self._real_repository is not None: @@ -1700,11 +1822,10 @@ def lock_read(self): def _remote_lock_write(self, token): path = self.controldir._path_for_remote_call(self._client) if token is None: - token = b'' - err_context = {'token': token} - response = self._call(b'Repository.lock_write', path, token, - **err_context) - if response[0] == b'ok': + token = b"" + err_context = {"token": token} + response = self._call(b"Repository.lock_write", path, token, **err_context) + if response[0] == b"ok": ok, token = response return token else: @@ -1712,7 +1833,7 @@ def _remote_lock_write(self, token): def lock_write(self, token=None, _skip_rpc=False): if not self._lock_mode: - self._note_lock('w') + self._note_lock("w") if _skip_rpc: if self._lock_token is not None: if token != self._lock_token: @@ -1730,14 +1851,14 @@ def lock_write(self, token=None, _skip_rpc=False): self._leave_lock = True else: self._leave_lock = False - self._lock_mode = 'w' + self._lock_mode = "w" self._lock_count = 1 cache_misses = self._real_repository is None self._unstacked_provider.enable_cache(cache_misses=cache_misses) for repo in self._fallback_repositories: # Writes don't affect fallback repos repo.lock_read() - elif self._lock_mode == 'r': + elif 
self._lock_mode == "r": raise errors.ReadOnlyError(self) else: self._lock_count += 1 @@ -1764,7 +1885,7 @@ def _set_real_repository(self, repository: _mod_repository.Repository): # We cannot do this [currently] if the repository is locked - # synchronised state might be lost. if self.is_locked(): - raise AssertionError('_real_repository is already set') + raise AssertionError("_real_repository is already set") if isinstance(repository, RemoteRepository): raise AssertionError() self._real_repository = repository @@ -1782,19 +1903,20 @@ def _set_real_repository(self, repository: _mod_repository.Repository): # 3) new servers, RemoteRepository.ensure_real is triggered before # RemoteBranch.ensure real, in this case we get a repo with no fallbacks # and need to populate it. - if (self._fallback_repositories - and len(self._real_repository._fallback_repositories) - != len(self._fallback_repositories)): + if self._fallback_repositories and len( + self._real_repository._fallback_repositories + ) != len(self._fallback_repositories): if len(self._real_repository._fallback_repositories): raise AssertionError( - "cannot cleanly remove existing _fallback_repositories") + "cannot cleanly remove existing _fallback_repositories" + ) for fb in self._fallback_repositories: self._real_repository.add_fallback_repository(fb) - if self._lock_mode == 'w': + if self._lock_mode == "w": # if we are already locked, the real repository must be able to # acquire the lock with our token. self._real_repository.lock_write(self._lock_token) - elif self._lock_mode == 'r': + elif self._lock_mode == "r": self._real_repository.lock_read() if self._write_group_tokens is not None: # if we are already in a write group, resume it @@ -1815,28 +1937,27 @@ def start_write_group(self): if not self.is_write_locked(): raise errors.NotWriteLocked(self) if self._write_group_tokens is not None: - raise errors.BzrError('already in a write group') + raise errors.BzrError("already in a write group") path = self.controldir._path_for_remote_call(self._client) try: - response = self._call(b'Repository.start_write_group', path, - self._lock_token) + response = self._call( + b"Repository.start_write_group", path, self._lock_token + ) except (errors.UnknownSmartMethod, errors.UnsuspendableWriteGroup): self._ensure_real() return self._real_repository.start_write_group() - if response[0] != b'ok': + if response[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response) - self._write_group_tokens = [ - token.decode('utf-8') for token in response[1]] + self._write_group_tokens = [token.decode("utf-8") for token in response[1]] def _unlock(self, token): path = self.controldir._path_for_remote_call(self._client) if not token: # with no token the remote repository is not persistently locked. return - err_context = {'token': token} - response = self._call(b'Repository.unlock', path, token, - **err_context) - if response == (b'ok',): + err_context = {"token": token} + response = self._call(b"Repository.unlock", path, token, **err_context) + if response == (b"ok",): return else: raise errors.UnexpectedSmartServerResponse(response) @@ -1864,7 +1985,7 @@ def unlock(self): finally: # The rpc-level lock should be released even if there was a # problem releasing the vfs-based lock. - if old_mode == 'w': + if old_mode == "w": # Only write-locked repositories need to make a remote method # call to perform the unlock. 
old_token = self._lock_token @@ -1884,7 +2005,7 @@ def break_lock(self): except errors.UnknownSmartMethod: self._ensure_real() return self._real_repository.break_lock() - if response != (b'ok',): + if response != (b"ok",): raise errors.UnexpectedSmartServerResponse(response) def _get_tarball(self, compression): @@ -1893,14 +2014,16 @@ def _get_tarball(self, compression): Returns None if the server does not support sending tarballs. """ import tempfile + path = self.controldir._path_for_remote_call(self._client) try: response, protocol = self._call_expecting_body( - b'Repository.tarball', path, compression.encode('ascii')) + b"Repository.tarball", path, compression.encode("ascii") + ) except errors.UnknownSmartMethod: protocol.cancel_read_body() return None - if response[0] == b'ok': + if response[0] == b"ok": # Extract the tarball and return it t = tempfile.NamedTemporaryFile() # TODO: rpc layer should read directly into it... @@ -1927,8 +2050,7 @@ def _create_sprouting_repo(self, a_controldir, shared): # Most control formats need the repository to be specifically # created, but on some old all-in-one formats it's not needed try: - dest_repo = self._format.initialize( - a_controldir, shared=shared) + dest_repo = self._format.initialize(a_controldir, shared=shared) except errors.UninitializableFormat: dest_repo = a_controldir.open_repository() return dest_repo @@ -1938,26 +2060,37 @@ def _create_sprouting_repo(self, a_controldir, shared): def revision_tree(self, revision_id): with self.lock_read(): if revision_id == _mod_revision.NULL_REVISION: - return InventoryRevisionTree(self, - Inventory(root_id=None), _mod_revision.NULL_REVISION) + return InventoryRevisionTree( + self, Inventory(root_id=None), _mod_revision.NULL_REVISION + ) else: return list(self.revision_trees([revision_id]))[0] def get_serializer_format(self): path = self.controldir._path_for_remote_call(self._client) try: - response = self._call(b'VersionedFileRepository.get_serializer_format', - path) + response = self._call( + b"VersionedFileRepository.get_serializer_format", path + ) except errors.UnknownSmartMethod: self._ensure_real() return self._real_repository.get_serializer_format() - if response[0] != b'ok': + if response[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response) return response[1] - def get_commit_builder(self, branch, parents, config, timestamp=None, - timezone=None, committer=None, revprops=None, - revision_id=None, lossy=False): + def get_commit_builder( + self, + branch, + parents, + config, + timestamp=None, + timezone=None, + committer=None, + revprops=None, + revision_id=None, + lossy=False, + ): """Obtain a CommitBuilder for this repository. :param branch: Branch to commit to. @@ -1972,13 +2105,23 @@ def get_commit_builder(self, branch, parents, config, timestamp=None, represented, when pushing to a foreign VCS """ if self._fallback_repositories and not self._format.supports_chks: - raise errors.BzrError("Cannot commit directly to a stacked branch" - " in pre-2a formats. See " - "https://bugs.launchpad.net/bzr/+bug/375013 for details.") + raise errors.BzrError( + "Cannot commit directly to a stacked branch" + " in pre-2a formats. See " + "https://bugs.launchpad.net/bzr/+bug/375013 for details." 
+ ) commit_builder_kls = vf_repository.VersionedFileCommitBuilder - result = commit_builder_kls(self, parents, config, - timestamp, timezone, committer, revprops, revision_id, - lossy) + result = commit_builder_kls( + self, + parents, + config, + timestamp, + timezone, + committer, + revprops, + revision_id, + lossy, + ) self.start_write_group() return result @@ -1989,7 +2132,8 @@ def add_fallback_repository(self, repository): """ if not self._format.supports_external_lookups: raise errors.UnstackableRepositoryFormat( - self._format.network_name(), self.base) + self._format.network_name(), self.base + ) # We need to accumulate additional repositories here, to pass them in # on various RPC's. # @@ -2005,8 +2149,9 @@ def add_fallback_repository(self, repository): # _real_branch had its get_stacked_on_url method called), then the # repository to be added may already be in the _real_repositories list. if self._real_repository is not None: - fallback_locations = [repo.user_url for repo in - self._real_repository._fallback_repositories] + fallback_locations = [ + repo.user_url for repo in self._real_repository._fallback_repositories + ] if repository.user_url not in fallback_locations: self._real_repository.add_fallback_repository(repository) @@ -2017,19 +2162,30 @@ def _check_fallback_repository(self, repository): :param repository: A repository to fallback to. """ - return _mod_repository.InterRepository._assert_same_model( - self, repository) + return _mod_repository.InterRepository._assert_same_model(self, repository) def add_inventory(self, revid, inv, parents): self._ensure_real() return self._real_repository.add_inventory(revid, inv, parents) - def add_inventory_by_delta(self, basis_revision_id, delta, new_revision_id, - parents, basis_inv=None, propagate_caches=False): + def add_inventory_by_delta( + self, + basis_revision_id, + delta, + new_revision_id, + parents, + basis_inv=None, + propagate_caches=False, + ): self._ensure_real() - return self._real_repository.add_inventory_by_delta(basis_revision_id, - delta, new_revision_id, parents, basis_inv=basis_inv, - propagate_caches=propagate_caches) + return self._real_repository.add_inventory_by_delta( + basis_revision_id, + delta, + new_revision_id, + parents, + basis_inv=basis_inv, + propagate_caches=propagate_caches, + ) def add_revision(self, revision_id, rev, inv=None): _mod_revision.check_not_reserved_id(revision_id) @@ -2037,12 +2193,12 @@ def add_revision(self, revision_id, rev, inv=None): # check inventory present if not self.inventories.get_parent_map([key]): if inv is None: - raise errors.WeaveRevisionNotPresent(revision_id, - self.inventories) + raise errors.WeaveRevisionNotPresent(revision_id, self.inventories) else: # yes, this is not suitable for adding with ghosts. 
- rev.inventory_sha1 = self.add_inventory(revision_id, inv, - rev.parent_ids) + rev.inventory_sha1 = self.add_inventory( + revision_id, inv, rev.parent_ids + ) else: rev.inventory_sha1 = self.inventories.get_sha1s([key])[key] self._add_revision(rev) @@ -2054,8 +2210,19 @@ def _add_revision(self, rev): key = (rev.revision_id,) parents = tuple((parent,) for parent in rev.parent_ids) self._write_group_tokens, missing_keys = self._get_sink().insert_stream( - [('revisions', [ChunkedContentFactory(key, parents, None, lines, chunks_are_lines=True)])], - self._format, self._write_group_tokens) + [ + ( + "revisions", + [ + ChunkedContentFactory( + key, parents, None, lines, chunks_are_lines=True + ) + ], + ) + ], + self._format, + self._write_group_tokens, + ) def get_inventory(self, revision_id): with self.lock_read(): @@ -2063,13 +2230,14 @@ def get_inventory(self, revision_id): def _iter_inventories_rpc(self, revision_ids, ordering): if ordering is None: - ordering = 'unordered' + ordering = "unordered" path = self.controldir._path_for_remote_call(self._client) body = b"\n".join(revision_ids) - response_tuple, response_handler = ( - self._call_with_body_bytes_expecting_body( - b"VersionedFileRepository.get_inventories", - (path, ordering.encode('ascii')), body)) + response_tuple, response_handler = self._call_with_body_bytes_expecting_body( + b"VersionedFileRepository.get_inventories", + (path, ordering.encode("ascii")), + body, + ) if response_tuple[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response_tuple) deserializer = inventory_delta.InventoryDeltaDeserializer() @@ -2082,25 +2250,31 @@ def _iter_inventories_rpc(self, revision_ids, ordering): if src_format.network_name() != self._format.network_name(): raise AssertionError( "Mismatched RemoteRepository and stream src {!r}, {!r}".format( - src_format.network_name(), self._format.network_name())) + src_format.network_name(), self._format.network_name() + ) + ) # ignore the src format, it's not really relevant - prev_inv = Inventory(root_id=None, - revision_id=_mod_revision.NULL_REVISION) + prev_inv = Inventory(root_id=None, revision_id=_mod_revision.NULL_REVISION) # there should be just one substream, with inventory deltas try: substream_kind, substream = next(stream) except StopIteration: return if substream_kind != "inventory-deltas": - raise AssertionError( - f"Unexpected stream {substream_kind!r} received") + raise AssertionError(f"Unexpected stream {substream_kind!r} received") for record in substream: - (parent_id, new_id, versioned_root, tree_references, invdelta) = ( - deserializer.parse_text_bytes(record.get_bytes_as("lines"))) + ( + parent_id, + new_id, + versioned_root, + tree_references, + invdelta, + ) = deserializer.parse_text_bytes(record.get_bytes_as("lines")) invdelta = InventoryDelta(invdelta) if parent_id != prev_inv.revision_id: - raise AssertionError("invalid base {!r} != {!r}".format(parent_id, - prev_inv.revision_id)) + raise AssertionError( + f"invalid base {parent_id!r} != {prev_inv.revision_id!r}" + ) inv = prev_inv.create_by_apply_delta(invdelta, new_id) yield inv, inv.revision_id prev_inv = inv @@ -2122,9 +2296,8 @@ def iter_inventories(self, revision_ids, ordering=None): buffering if necessary). :return: An iterator of inventories. 
""" - if ((None in revision_ids) or - (_mod_revision.NULL_REVISION in revision_ids)): - raise ValueError('cannot get null revision inventory') + if (None in revision_ids) or (_mod_revision.NULL_REVISION in revision_ids): + raise ValueError("cannot get null revision inventory") for inv, revid in self._iter_inventories(revision_ids, ordering): if inv is None: raise errors.NoSuchRevision(self, revid) @@ -2142,11 +2315,11 @@ def _iter_inventories(self, revision_ids, ordering=None): next_revid = order.pop() else: order_as_requested = False - if ordering != 'unordered' and self._fallback_repositories: - raise ValueError(f'unsupported ordering {ordering!r}') + if ordering != "unordered" and self._fallback_repositories: + raise ValueError(f"unsupported ordering {ordering!r}") iter_inv_fns = [self._iter_inventories_rpc] + [ - fallback._iter_inventories for fallback in - self._fallback_repositories] + fallback._iter_inventories for fallback in self._fallback_repositories + ] try: for iter_inv in iter_inv_fns: request = [revid for revid in revision_ids if revid in missing] @@ -2154,7 +2327,7 @@ def _iter_inventories(self, revision_ids, ordering=None): if inv is None: continue missing.remove(inv.revision_id) - if ordering != 'unordered': + if ordering != "unordered": invs[revid] = inv else: yield inv, revid @@ -2196,7 +2369,8 @@ def get_transaction(self): def clone(self, a_controldir, revision_id=None): with self.lock_read(): dest_repo = self._create_sprouting_repo( - a_controldir, shared=self.is_shared()) + a_controldir, shared=self.is_shared() + ) self.copy_content_into(dest_repo, revision_id) return dest_repo @@ -2204,14 +2378,13 @@ def make_working_trees(self): """See Repository.make_working_trees.""" path = self.controldir._path_for_remote_call(self._client) try: - response = self._call(b'Repository.make_working_trees', path) + response = self._call(b"Repository.make_working_trees", path) except errors.UnknownSmartMethod: self._ensure_real() return self._real_repository.make_working_trees() - if response[0] not in (b'yes', b'no'): - raise SmartProtocolError( - f'unexpected response code {response}') - return response[0] == b'yes' + if response[0] not in (b"yes", b"no"): + raise SmartProtocolError(f"unexpected response code {response}") + return response[0] == b"yes" def refresh_data(self): """Re-read any data needed to synchronise with disk. @@ -2238,13 +2411,19 @@ def revision_ids_to_search_result(self, result_set): included_keys = result_set.intersection(result_parents) start_keys = result_set.difference(included_keys) exclude_keys = result_parents.difference(result_set) - result = vf_search.SearchResult(start_keys, exclude_keys, - len(result_set), result_set) + result = vf_search.SearchResult( + start_keys, exclude_keys, len(result_set), result_set + ) return result - def search_missing_revision_ids(self, other, - find_ghosts=True, revision_ids=None, if_present_ids=None, - limit=None): + def search_missing_revision_ids( + self, + other, + find_ghosts=True, + revision_ids=None, + if_present_ids=None, + limit=None, + ): """Return the revision ids that other has that this does not. These are returned in topological order. 
@@ -2254,40 +2433,44 @@ def search_missing_revision_ids(self, other, with self.lock_read(): inter_repo = _mod_repository.InterRepository.get(other, self) return inter_repo.search_missing_revision_ids( - find_ghosts=find_ghosts, revision_ids=revision_ids, - if_present_ids=if_present_ids, limit=limit) - - def fetch(self, source, revision_id=None, find_ghosts=False, - fetch_spec=None, lossy=False): + find_ghosts=find_ghosts, + revision_ids=revision_ids, + if_present_ids=if_present_ids, + limit=limit, + ) + + def fetch( + self, source, revision_id=None, find_ghosts=False, fetch_spec=None, lossy=False + ): # No base implementation to use as RemoteRepository is not a subclass # of Repository; so this is a copy of Repository.fetch(). if fetch_spec is not None and revision_id is not None: - raise AssertionError( - "fetch_spec and revision_id are mutually exclusive.") + raise AssertionError("fetch_spec and revision_id are mutually exclusive.") if self.is_in_write_group(): - raise errors.InternalBzrError( - "May not fetch while in a write group.") + raise errors.InternalBzrError("May not fetch while in a write group.") # fast path same-url fetch operations - if (self.has_same_location(source) and - fetch_spec is None and - self._has_same_fallbacks(source)): + if ( + self.has_same_location(source) + and fetch_spec is None + and self._has_same_fallbacks(source) + ): # check that last_revision is in 'from' and then return a # no-operation. - if (revision_id is not None - and not _mod_revision.is_null(revision_id)): + if revision_id is not None and not _mod_revision.is_null(revision_id): self.get_revision(revision_id) return _mod_repository.FetchResult(0) # if there is no specific appropriate InterRepository, this will get # the InterRepository base class, which raises an # IncompatibleRepositories when asked to fetch. 
inter = _mod_repository.InterRepository.get(source, self) - if (fetch_spec is not None - and not getattr(inter, "supports_fetch_spec", False)): - raise errors.UnsupportedOperation( - f"fetch_spec not supported for {inter!r}") - return inter.fetch(revision_id=revision_id, - find_ghosts=find_ghosts, fetch_spec=fetch_spec, - lossy=lossy) + if fetch_spec is not None and not getattr(inter, "supports_fetch_spec", False): + raise errors.UnsupportedOperation(f"fetch_spec not supported for {inter!r}") + return inter.fetch( + revision_id=revision_id, + find_ghosts=find_ghosts, + fetch_spec=fetch_spec, + lossy=lossy, + ) def create_bundle(self, target, base, fileobj, format=None): self._ensure_real() @@ -2300,22 +2483,20 @@ def fileids_altered_by_revision_ids(self, revision_ids): def _get_versioned_file_checker(self, revisions, revision_versions_cache): self._ensure_real() return self._real_repository._get_versioned_file_checker( - revisions, revision_versions_cache) + revisions, revision_versions_cache + ) def _iter_files_bytes_rpc(self, desired_files, absent): path = self.controldir._path_for_remote_call(self._client) lines = [] identifiers = [] - for (file_id, revid, identifier) in desired_files: - lines.append(b''.join([ - file_id, - b'\0', - revid])) + for file_id, revid, identifier in desired_files: + lines.append(b"".join([file_id, b"\0", revid])) identifiers.append(identifier) - (response_tuple, response_handler) = ( - self._call_with_body_bytes_expecting_body( - b"Repository.iter_files_bytes", (path, ), b"\n".join(lines))) - if response_tuple != (b'ok', ): + (response_tuple, response_handler) = self._call_with_body_bytes_expecting_body( + b"Repository.iter_files_bytes", (path,), b"\n".join(lines) + ) + if response_tuple != (b"ok",): response_handler.cancel_read_body() raise errors.UnexpectedSmartServerResponse(response_tuple) byte_stream = response_handler.read_streamed_body() @@ -2331,6 +2512,7 @@ def decompress_stream(start, byte_stream, unused): yield decompressor.decompress(data) yield decompressor.flush() unused.append(decompressor.unused_data) + unused = b"" while True: while b"\n" not in unused: @@ -2349,23 +2531,29 @@ def decompress_stream(start, byte_stream, unused): else: raise errors.UnexpectedSmartServerResponse(args) unused_chunks = [] - yield (identifiers[idx], - decompress_stream(rest, byte_stream, unused_chunks)) + yield ( + identifiers[idx], + decompress_stream(rest, byte_stream, unused_chunks), + ) unused = b"".join(unused_chunks) def iter_files_bytes(self, desired_files): """See Repository.iter_file_bytes.""" try: absent = {} - for (identifier, bytes_iterator) in self._iter_files_bytes_rpc( - desired_files, absent): + for identifier, bytes_iterator in self._iter_files_bytes_rpc( + desired_files, absent + ): yield identifier, bytes_iterator for fallback in self._fallback_repositories: if not absent: break - desired_files = [(key[0], key[1], identifier) - for identifier, key in absent.items()] - for (identifier, bytes_iterator) in fallback.iter_files_bytes(desired_files): + desired_files = [ + (key[0], key[1], identifier) for identifier, key in absent.items() + ] + for identifier, bytes_iterator in fallback.iter_files_bytes( + desired_files + ): del absent[identifier] yield identifier, bytes_iterator if absent: @@ -2373,12 +2561,14 @@ def iter_files_bytes(self, desired_files): # for just one. 
missing_identifier = next(iter(absent)) missing_key = absent[missing_identifier] - raise errors.RevisionNotPresent(revision_id=missing_key[1], - file_id=missing_key[0]) + raise errors.RevisionNotPresent( + revision_id=missing_key[1], file_id=missing_key[0] + ) except errors.UnknownSmartMethod: self._ensure_real() - for (identifier, bytes_iterator) in ( - self._real_repository.iter_files_bytes(desired_files)): + for identifier, bytes_iterator in self._real_repository.iter_files_bytes( + desired_files + ): yield identifier, bytes_iterator def get_cached_parent_map(self, revision_ids): @@ -2419,7 +2609,7 @@ def _get_parent_map_rpc(self, keys): keys = set(keys) if None in keys: - raise ValueError('get_parent_map(None) is not valid') + raise ValueError("get_parent_map(None) is not valid") if NULL_REVISION in keys: keys.discard(NULL_REVISION) found_parents = {NULL_REVISION: ()} @@ -2449,34 +2639,39 @@ def _get_parent_map_rpc(self, keys): # Repository is not locked, so there's no cache. parents_map = {} if _DEFAULT_SEARCH_DEPTH <= 0: - (start_set, stop_keys, - key_count) = vf_search.search_result_from_parent_map( - parents_map, self._unstacked_provider.missing_keys) + (start_set, stop_keys, key_count) = vf_search.search_result_from_parent_map( + parents_map, self._unstacked_provider.missing_keys + ) else: - (start_set, stop_keys, - key_count) = vf_search.limited_search_result_from_parent_map( - parents_map, self._unstacked_provider.missing_keys, - keys, depth=_DEFAULT_SEARCH_DEPTH) - recipe = ('manual', start_set, stop_keys, key_count) + ( + start_set, + stop_keys, + key_count, + ) = vf_search.limited_search_result_from_parent_map( + parents_map, + self._unstacked_provider.missing_keys, + keys, + depth=_DEFAULT_SEARCH_DEPTH, + ) + recipe = ("manual", start_set, stop_keys, key_count) body = self._serialise_search_recipe(recipe) path = self.controldir._path_for_remote_call(self._client) for key in keys: if not isinstance(key, bytes): - raise ValueError( - f"key {key!r} not a bytes string") - verb = b'Repository.get_parent_map' - args = (path, b'include-missing:') + tuple(keys) + raise ValueError(f"key {key!r} not a bytes string") + verb = b"Repository.get_parent_map" + args = (path, b"include-missing:") + tuple(keys) try: - response = self._call_with_body_bytes_expecting_body( - verb, args, body) + response = self._call_with_body_bytes_expecting_body(verb, args, body) except errors.UnknownSmartMethod: # Server does not support this method, so get the whole graph. # Worse, we have to force a disconnection, because the server now # doesn't realise it has a body on the wire to consume, so the # only way to recover is to abandon the connection. warning( - 'Server is too old for fast get_parent_map, reconnecting. ' - '(Upgrade the server to Bazaar 1.2 to avoid this)') + "Server is too old for fast get_parent_map, reconnecting. " + "(Upgrade the server to Bazaar 1.2 to avoid this)" + ) medium.disconnect() # To avoid having to disconnect repeatedly, we keep track of the # fact the server doesn't understand remote methods added in 1.2. @@ -2484,15 +2679,15 @@ def _get_parent_map_rpc(self, keys): # Recurse just once and we should use the fallback code. 
return self._get_parent_map_rpc(keys) response_tuple, response_handler = response - if response_tuple[0] not in [b'ok']: + if response_tuple[0] not in [b"ok"]: response_handler.cancel_read_body() raise errors.UnexpectedSmartServerResponse(response_tuple) - if response_tuple[0] == b'ok': + if response_tuple[0] == b"ok": coded = bz2.decompress(response_handler.read_body_bytes()) - if coded == b'': + if coded == b"": # no revisions found return {} - lines = coded.split(b'\n') + lines = coded.split(b"\n") revision_graph = {} for line in lines: d = tuple(line.split()) @@ -2500,7 +2695,7 @@ def _get_parent_map_rpc(self, keys): revision_graph[d[0]] = d[1:] else: # No parents: - if d[0].startswith(b'missing:'): + if d[0].startswith(b"missing:"): revid = d[0][8:] self._unstacked_provider.note_missing_key(revid) else: @@ -2514,7 +2709,8 @@ def get_signature_text(self, revision_id): path = self.controldir._path_for_remote_call(self._client) try: response_tuple, response_handler = self._call_expecting_body( - b'Repository.get_revision_signature_text', path, revision_id) + b"Repository.get_revision_signature_text", path, revision_id + ) except errors.UnknownSmartMethod: self._ensure_real() return self._real_repository.get_signature_text(revision_id) @@ -2526,7 +2722,7 @@ def get_signature_text(self, revision_id): pass raise err else: - if response_tuple[0] != b'ok': + if response_tuple[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response_tuple) return response_handler.read_body_bytes() @@ -2539,25 +2735,27 @@ def _get_inventory_xml(self, revision_id): def reconcile(self, other=None, thorough=False): from ..reconcile import ReconcileResult + with self.lock_write(): path = self.controldir._path_for_remote_call(self._client) try: response, handler = self._call_expecting_body( - b'Repository.reconcile', path, self._lock_token) + b"Repository.reconcile", path, self._lock_token + ) except (errors.UnknownSmartMethod, errors.TokenLockingNotSupported): self._ensure_real() return self._real_repository.reconcile(other=other, thorough=thorough) - if response != (b'ok', ): + if response != (b"ok",): raise errors.UnexpectedSmartServerResponse(response) body = handler.read_body_bytes() result = ReconcileResult() result.garbage_inventories = None result.inconsistent_parents = None result.aborted = None - for line in body.split(b'\n'): + for line in body.split(b"\n"): if not line: continue - key, val_text = line.split(b':') + key, val_text = line.split(b":") if key == b"garbage_inventories": result.garbage_inventories = int(val_text) elif key == b"inconsistent_parents": @@ -2570,11 +2768,12 @@ def all_revision_ids(self): path = self.controldir._path_for_remote_call(self._client) try: response_tuple, response_handler = self._call_expecting_body( - b"Repository.all_revision_ids", path) + b"Repository.all_revision_ids", path + ) except errors.UnknownSmartMethod: self._ensure_real() return self._real_repository.all_revision_ids() - if response_tuple != (b"ok", ): + if response_tuple != (b"ok",): raise errors.UnexpectedSmartServerResponse(response_tuple) revids = set(response_handler.read_body_bytes().splitlines()) for fallback in self._fallback_repositories: @@ -2616,8 +2815,11 @@ def get_revision_reconcile(self, revision_id): def check(self, revision_ids=None, callback_refs=None, check_repo=True): with self.lock_read(): self._ensure_real() - return self._real_repository.check(revision_ids=revision_ids, - callback_refs=callback_refs, check_repo=check_repo) + return self._real_repository.check( + 
revision_ids=revision_ids, + callback_refs=callback_refs, + check_repo=check_repo, + ) def copy_content_into(self, destination, revision_id=None): """Make a complete copy of the content in self into destination. @@ -2635,13 +2837,13 @@ def _copy_repository_tarball(self, to_bzrdir, revision_id=None): # TODO: Maybe a progress bar while streaming the tarball? note(gettext("Copying repository content as tarball...")) - tar_file = self._get_tarball('bz2') + tar_file = self._get_tarball("bz2") if tar_file is None: return None destination = to_bzrdir.create_repository() - with tarfile.open('repository', fileobj=tar_file, - mode='r|bz2') as tar, \ - osutils.TemporaryDirectory() as tmpdir: + with tarfile.open( + "repository", fileobj=tar_file, mode="r|bz2" + ) as tar, osutils.TemporaryDirectory() as tmpdir: tar.extractall(tmpdir) tmp_bzrdir = _mod_bzrdir.BzrDir.open(tmpdir) tmp_repo = tmp_bzrdir.open_repository() @@ -2665,19 +2867,22 @@ def pack(self, hint=None, clean_obsolete_packs=False): if hint is None: body = b"" else: - body = b"".join([l.encode('ascii') + b"\n" for l in hint]) + body = b"".join([l.encode("ascii") + b"\n" for l in hint]) with self.lock_write(): path = self.controldir._path_for_remote_call(self._client) try: response, handler = self._call_with_body_bytes_expecting_body( - b'Repository.pack', (path, self._lock_token, - str(clean_obsolete_packs).encode('ascii')), body) + b"Repository.pack", + (path, self._lock_token, str(clean_obsolete_packs).encode("ascii")), + body, + ) except errors.UnknownSmartMethod: self._ensure_real() - return self._real_repository.pack(hint=hint, - clean_obsolete_packs=clean_obsolete_packs) + return self._real_repository.pack( + hint=hint, clean_obsolete_packs=clean_obsolete_packs + ) handler.cancel_read_body() - if response != (b'ok', ): + if response != (b"ok",): raise errors.UnexpectedSmartServerResponse(response) @property @@ -2697,12 +2902,13 @@ def set_make_working_trees(self, new_value): path = self.controldir._path_for_remote_call(self._client) try: response = self._call( - b'Repository.set_make_working_trees', path, new_value_str) + b"Repository.set_make_working_trees", path, new_value_str + ) except errors.UnknownSmartMethod: self._ensure_real() self._real_repository.set_make_working_trees(new_value) else: - if response[0] != b'ok': + if response[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response) @property @@ -2717,8 +2923,7 @@ def signatures(self): def sign_revision(self, revision_id, gpg_strategy): with self.lock_write(): - testament = _mod_testament.Testament.from_revision( - self, revision_id) + testament = _mod_testament.Testament.from_revision(self, revision_id) plaintext = testament.as_short_text() self.store_revision_signature(gpg_strategy, plaintext, revision_id) @@ -2735,12 +2940,12 @@ def texts(self): def _iter_revisions_rpc(self, revision_ids): body = b"\n".join(revision_ids) path = self.controldir._path_for_remote_call(self._client) - response_tuple, response_handler = ( - self._call_with_body_bytes_expecting_body( - b"Repository.iter_revisions", (path, ), body)) + response_tuple, response_handler = self._call_with_body_bytes_expecting_body( + b"Repository.iter_revisions", (path,), body + ) if response_tuple[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response_tuple) - serializer_format = response_tuple[1].decode('ascii') + serializer_format = response_tuple[1].decode("ascii") serializer = revision_format_registry.get(serializer_format) byte_stream = response_handler.read_streamed_body() decompressor = 
zlib.decompressobj() @@ -2761,8 +2966,7 @@ def _iter_revisions_rpc(self, revision_ids): def iter_revisions(self, revision_ids): for rev_id in revision_ids: if not rev_id or not isinstance(rev_id, bytes): - raise errors.InvalidRevisionId( - revision_id=rev_id, branch=self) + raise errors.InvalidRevisionId(revision_id=rev_id, branch=self) with self.lock_read(): try: missing = set(revision_ids) @@ -2772,7 +2976,7 @@ def iter_revisions(self, revision_ids): for fallback in self._fallback_repositories: if not missing: break - for (revid, rev) in fallback.iter_revisions(missing): + for revid, rev in fallback.iter_revisions(missing): if rev is not None: yield (revid, rev) missing.remove(revid) @@ -2803,35 +3007,32 @@ def add_signature_text(self, revision_id, signature): # If there is a real repository the write group will # be in the real repository as well, so use that: self._ensure_real() - return self._real_repository.add_signature_text( - revision_id, signature) + return self._real_repository.add_signature_text(revision_id, signature) path = self.controldir._path_for_remote_call(self._client) response, handler = self._call_with_body_bytes_expecting_body( - b'Repository.add_signature_text', (path, self._lock_token, - revision_id) + - tuple([token.encode('utf-8') - for token in self._write_group_tokens]), - signature) + b"Repository.add_signature_text", + (path, self._lock_token, revision_id) + + tuple([token.encode("utf-8") for token in self._write_group_tokens]), + signature, + ) handler.cancel_read_body() self.refresh_data() - if response[0] != b'ok': + if response[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response) - self._write_group_tokens = [token.decode( - 'utf-8') for token in response[1:]] + self._write_group_tokens = [token.decode("utf-8") for token in response[1:]] def has_signature_for_revision_id(self, revision_id): path = self.controldir._path_for_remote_call(self._client) try: - response = self._call(b'Repository.has_signature_for_revision_id', - path, revision_id) + response = self._call( + b"Repository.has_signature_for_revision_id", path, revision_id + ) except errors.UnknownSmartMethod: self._ensure_real() - return self._real_repository.has_signature_for_revision_id( - revision_id) - if response[0] not in (b'yes', b'no'): - raise SmartProtocolError( - f'unexpected response code {response}') - if response[0] == b'yes': + return self._real_repository.has_signature_for_revision_id(revision_id) + if response[0] not in (b"yes", b"no"): + raise SmartProtocolError(f"unexpected response code {response}") + if response[0] == b"yes": return True for fallback in self._fallback_repositories: if fallback.has_signature_for_revision_id(revision_id): @@ -2844,8 +3045,7 @@ def verify_revision_signature(self, revision_id, gpg_strategy): return gpg.SIGNATURE_NOT_SIGNED, None signature = self.get_signature_text(revision_id) - testament = _mod_testament.Testament.from_revision( - self, revision_id) + testament = _mod_testament.Testament.from_revision(self, revision_id) (status, key, signed_plaintext) = gpg_strategy.verify(signature) if testament.as_short_text() != signed_plaintext: @@ -2854,13 +3054,15 @@ def verify_revision_signature(self, revision_id, gpg_strategy): def item_keys_introduced_by(self, revision_ids, _files_pb=None): self._ensure_real() - return self._real_repository.item_keys_introduced_by(revision_ids, - _files_pb=_files_pb) + return self._real_repository.item_keys_introduced_by( + revision_ids, _files_pb=_files_pb + ) def _find_inconsistent_revision_parents(self, 
revisions_iterator=None): self._ensure_real() return self._real_repository._find_inconsistent_revision_parents( - revisions_iterator) + revisions_iterator + ) def _check_for_inconsistent_revision_parents(self): self._ensure_real() @@ -2870,8 +3072,9 @@ def _make_parents_provider(self, other=None): providers = [self._unstacked_provider] if other is not None: providers.insert(0, other) - return graph.StackedParentsProvider(_LazyListJoin( - providers, self._fallback_repositories)) + return graph.StackedParentsProvider( + _LazyListJoin(providers, self._fallback_repositories) + ) def _serialise_search_recipe(self, recipe): """Serialise a graph search recipe. @@ -2879,67 +3082,74 @@ def _serialise_search_recipe(self, recipe): :param recipe: A search recipe (start, stop, count). :return: Serialised bytes. """ - start_keys = b' '.join(recipe[1]) - stop_keys = b' '.join(recipe[2]) - count = str(recipe[3]).encode('ascii') - return b'\n'.join((start_keys, stop_keys, count)) + start_keys = b" ".join(recipe[1]) + stop_keys = b" ".join(recipe[2]) + count = str(recipe[3]).encode("ascii") + return b"\n".join((start_keys, stop_keys, count)) def _serialise_search_result(self, search_result): parts = search_result.get_network_struct() - return b'\n'.join(parts) + return b"\n".join(parts) def autopack(self): path = self.controldir._path_for_remote_call(self._client) try: - response = self._call(b'PackRepository.autopack', path) + response = self._call(b"PackRepository.autopack", path) except errors.UnknownSmartMethod: self._ensure_real() self._real_repository._pack_collection.autopack() return self.refresh_data() - if response[0] != b'ok': + if response[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response) - def _revision_archive(self, revision_id, format, name, root, subdir, - force_mtime=None): + def _revision_archive( + self, revision_id, format, name, root, subdir, force_mtime=None + ): path = self.controldir._path_for_remote_call(self._client) - format = format or '' - root = root or '' - subdir = subdir or '' + format = format or "" + root = root or "" + subdir = subdir or "" force_mtime = int(force_mtime) if force_mtime is not None else None try: response, protocol = self._call_expecting_body( - b'Repository.revision_archive', path, + b"Repository.revision_archive", + path, revision_id, - format.encode('ascii'), - os.path.basename(name).encode('utf-8'), - root.encode('utf-8'), - subdir.encode('utf-8'), - force_mtime) + format.encode("ascii"), + os.path.basename(name).encode("utf-8"), + root.encode("utf-8"), + subdir.encode("utf-8"), + force_mtime, + ) except errors.UnknownSmartMethod: return None - if response[0] == b'ok': + if response[0] == b"ok": return iter([protocol.read_body_bytes()]) raise errors.UnexpectedSmartServerResponse(response) def _annotate_file_revision(self, revid, tree_path, file_id, default_revision): path = self.controldir._path_for_remote_call(self._client) - tree_path = tree_path.encode('utf-8') - file_id = file_id or b'' - default_revision = default_revision or b'' + tree_path = tree_path.encode("utf-8") + file_id = file_id or b"" + default_revision = default_revision or b"" try: response, handler = self._call_expecting_body( - b'Repository.annotate_file_revision', path, - revid, tree_path, file_id, default_revision) + b"Repository.annotate_file_revision", + path, + revid, + tree_path, + file_id, + default_revision, + ) except errors.UnknownSmartMethod: return None - if response[0] != b'ok': + if response[0] != b"ok": raise 
errors.UnexpectedSmartServerResponse(response) return map(tuple, bencode.bdecode(handler.read_body_bytes())) class RemoteStreamSink(vf_repository.StreamSink): - def _insert_real(self, stream, src_format, resume_tokens): self.target_repo._ensure_real() sink = self.target_repo._real_repository._get_sink() @@ -2949,26 +3159,27 @@ def _insert_real(self, stream, src_format, resume_tokens): return result def insert_missing_keys(self, source, missing_keys): - if (isinstance(source, RemoteStreamSource) - and source.from_repository._client._medium == self.target_repo._client._medium): + if ( + isinstance(source, RemoteStreamSource) + and source.from_repository._client._medium + == self.target_repo._client._medium + ): # Streaming from and to the same medium is tricky, since we don't support # more than one concurrent request. For now, just force VFS. stream = source._get_real_stream_for_missing_keys(missing_keys) else: stream = source.get_stream_for_missing_keys(missing_keys) - return self.insert_stream_without_locking(stream, - self.target_repo._format) + return self.insert_stream_without_locking(stream, self.target_repo._format) def insert_stream(self, stream, src_format, resume_tokens): target = self.target_repo target._unstacked_provider.missing_keys.clear() - candidate_calls = [(b'Repository.insert_stream_1.19', (1, 19))] + candidate_calls = [(b"Repository.insert_stream_1.19", (1, 19))] if target._lock_token: - candidate_calls.append( - (b'Repository.insert_stream_locked', (1, 14))) - lock_args = (target._lock_token or b'',) + candidate_calls.append((b"Repository.insert_stream_locked", (1, 14))) + lock_args = (target._lock_token or b"",) else: - candidate_calls.append((b'Repository.insert_stream', (1, 13))) + candidate_calls.append((b"Repository.insert_stream", (1, 13))) lock_args = () client = target._client medium = client._medium @@ -2989,7 +3200,8 @@ def insert_stream(self, stream, src_format, resume_tokens): byte_stream = smart_repo._stream_to_byte_stream([], src_format) try: response = client.call_with_body_stream( - (verb, path, b'') + lock_args, byte_stream) + (verb, path, b"") + lock_args, byte_stream + ) except errors.UnknownSmartMethod: medium._remember_remote_is_before(required_version) else: @@ -3006,13 +3218,12 @@ def insert_stream(self, stream, src_format, resume_tokens): # deltas we'll interrupt the smart insert_stream request and # fallback to VFS. stream = self._stop_stream_if_inventory_delta(stream) - byte_stream = smart_repo._stream_to_byte_stream( - stream, src_format) - resume_tokens = b' '.join([token.encode('utf-8') - for token in resume_tokens]) + byte_stream = smart_repo._stream_to_byte_stream(stream, src_format) + resume_tokens = b" ".join([token.encode("utf-8") for token in resume_tokens]) response = client.call_with_body_stream( - (verb, path, resume_tokens) + lock_args, byte_stream) - if response[0][0] not in (b'ok', b'missing-basis'): + (verb, path, resume_tokens) + lock_args, byte_stream + ) + if response[0][0] not in (b"ok", b"missing-basis"): raise errors.UnexpectedSmartServerResponse(response) if self._last_substream is not None: # The stream included an inventory-delta record, but the remote @@ -3020,10 +3231,12 @@ def insert_stream(self, stream, src_format, resume_tokens): # rest of the stream via VFS. 
self.target_repo.refresh_data() return self._resume_stream_with_vfs(response, src_format) - if response[0][0] == b'missing-basis': + if response[0][0] == b"missing-basis": tokens, missing_keys = bencode.bdecode_as_tuple(response[0][1]) - resume_tokens = [token.decode('utf-8') for token in tokens] - return resume_tokens, {(entry[0].decode('utf-8'), ) + entry[1:] for entry in missing_keys} + resume_tokens = [token.decode("utf-8") for token in tokens] + return resume_tokens, { + (entry[0].decode("utf-8"),) + entry[1:] for entry in missing_keys + } else: self.target_repo.refresh_data() return [], set() @@ -3032,9 +3245,9 @@ def _resume_stream_with_vfs(self, response, src_format): """Resume sending a stream via VFS, first resending the record and substream that couldn't be sent via an insert_stream verb. """ - if response[0][0] == b'missing-basis': + if response[0][0] == b"missing-basis": tokens, missing_keys = bencode.bdecode_as_tuple(response[0][1]) - tokens = [token.decode('utf-8') for token in tokens] + tokens = [token.decode("utf-8") for token in tokens] # Ignore missing_keys, we haven't finished inserting yet else: tokens = [] @@ -3046,9 +3259,10 @@ def resume_substream(): def resume_stream(): # Finish sending the interrupted substream - yield ('inventory-deltas', resume_substream()) + yield ("inventory-deltas", resume_substream()) # Then simply continue sending the rest of the stream. yield from self._last_stream + return self._insert_real(resume_stream(), src_format, tokens) def _stop_stream_if_inventory_delta(self, stream): @@ -3061,7 +3275,7 @@ def _stop_stream_if_inventory_delta(self, stream): """ stream_iter = iter(stream) for substream_kind, substream in stream_iter: - if substream_kind == 'inventory-deltas': + if substream_kind == "inventory-deltas": self._last_substream = substream self._last_stream = stream_iter return @@ -3073,8 +3287,10 @@ class RemoteStreamSource(vf_repository.StreamSource): """Stream data from a remote server.""" def get_stream(self, search): - if (self.from_repository._fallback_repositories - and self.to_format._fetch_order == 'topological'): + if ( + self.from_repository._fallback_repositories + and self.to_format._fetch_order == "topological" + ): return self._real_stream(self.from_repository, search) sources = [] seen = set() @@ -3103,22 +3319,30 @@ def get_stream_for_missing_keys(self, missing_keys): return self._get_real_stream_for_missing_keys(missing_keys) path = self.from_repository.controldir._path_for_remote_call(client) args = (path, self.to_format.network_name()) - search_bytes = b'\n'.join( - [b'%s\t%s' % (key[0].encode('utf-8'), key[1]) for key in missing_keys]) + search_bytes = b"\n".join( + [b"%s\t%s" % (key[0].encode("utf-8"), key[1]) for key in missing_keys] + ) try: - response, handler = self.from_repository._call_with_body_bytes_expecting_body( - b'Repository.get_stream_for_missing_keys', args, search_bytes) + ( + response, + handler, + ) = self.from_repository._call_with_body_bytes_expecting_body( + b"Repository.get_stream_for_missing_keys", args, search_bytes + ) except (errors.UnknownSmartMethod, errors.UnknownFormatError): return self._get_real_stream_for_missing_keys(missing_keys) - if response[0] != b'ok': + if response[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response) byte_stream = handler.read_streamed_body() - src_format, stream = smart_repo._byte_stream_to_stream(byte_stream, - self._record_counter) + src_format, stream = smart_repo._byte_stream_to_stream( + byte_stream, self._record_counter + ) if 
src_format.network_name() != self.from_repository._format.network_name(): raise AssertionError( "Mismatched RemoteRepository and stream src {!r}, {!r}".format( - src_format.network_name(), repo._format.network_name())) + src_format.network_name(), repo._format.network_name() + ) + ) return stream def _real_stream(self, repo, search): @@ -3160,8 +3384,9 @@ def _get_stream(self, repo, search): search_bytes = repo._serialise_search_result(search) args = (path, self.to_format.network_name()) candidate_verbs = [ - (b'Repository.get_stream_1.19', (1, 19)), - (b'Repository.get_stream', (1, 13))] + (b"Repository.get_stream_1.19", (1, 19)), + (b"Repository.get_stream", (1, 13)), + ] found_verb = False for verb, version in candidate_verbs: @@ -3169,13 +3394,14 @@ def _get_stream(self, repo, search): continue try: response = repo._call_with_body_bytes_expecting_body( - verb, args, search_bytes) + verb, args, search_bytes + ) except errors.UnknownSmartMethod: medium._remember_remote_is_before(version) except UnknownErrorFromSmartServer as e: if isinstance(search, vf_search.EverythingResult): error_verb = e.error_from_smart_server.error_verb - if error_verb == b'BadSearch': + if error_verb == b"BadSearch": # Pre-2.4 servers don't support this sort of search. # XXX: perhaps falling back to VFS on BadSearch is a # good idea in general? It might provide a little bit @@ -3189,15 +3415,18 @@ def _get_stream(self, repo, search): break if not found_verb: return self._real_stream(repo, search) - if response_tuple[0] != b'ok': + if response_tuple[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response_tuple) byte_stream = response_handler.read_streamed_body() - src_format, stream = smart_repo._byte_stream_to_stream(byte_stream, - self._record_counter) + src_format, stream = smart_repo._byte_stream_to_stream( + byte_stream, self._record_counter + ) if src_format.network_name() != repo._format.network_name(): raise AssertionError( "Mismatched RemoteRepository and stream src {!r}, {!r}".format( - src_format.network_name(), repo._format.network_name())) + src_format.network_name(), repo._format.network_name() + ) + ) return stream def missing_parents_chain(self, search, sources): @@ -3206,7 +3435,9 @@ def missing_parents_chain(self, search, sources): :param search: The overall search to satisfy with streams. :param sources: A list of Repository objects to query. 
""" - self.from_revision_serialiser = self.from_repository._format._revision_serializer + self.from_revision_serialiser = ( + self.from_repository._format._revision_serializer + ) self.seen_revs = set() self.referenced_revs = set() # If there are heads in the search, or the key count is > 0, we are not @@ -3215,7 +3446,7 @@ def missing_parents_chain(self, search, sources): source = sources.pop(0) stream = self._get_stream(source, search) for kind, substream in stream: - if kind != 'revisions': + if kind != "revisions": yield kind, substream else: yield kind, self.missing_parents_rev_handler(substream) @@ -3228,9 +3459,10 @@ def missing_parents_chain(self, search, sources): def missing_parents_rev_handler(self, substream): for content in substream: - revision_bytes = content.get_bytes_as('fulltext') + revision_bytes = content.get_bytes_as("fulltext") revision = self.from_revision_serialiser.read_revision_from_string( - revision_bytes) + revision_bytes + ) self.seen_revs.add(content.key[-1]) self.referenced_revs.update(revision.parent_ids) yield content @@ -3247,8 +3479,8 @@ def __init__(self, bzrdir, _client): self._client = _client self._need_find_modes = True LockableFiles.__init__( - self, bzrdir.get_branch_transport(None), - 'lock', lockdir.LockDir) + self, bzrdir.get_branch_transport(None), "lock", lockdir.LockDir + ) def _find_modes(self): # RemoteBranches don't let the client set the mode of control files. @@ -3257,7 +3489,6 @@ def _find_modes(self): class RemoteBranchFormat(branch.BranchFormat): - def __init__(self, network_name=None): super().__init__() self._matchingcontroldir = RemoteBzrDirFormat() @@ -3266,52 +3497,63 @@ def __init__(self, network_name=None): self._network_name = network_name def __eq__(self, other): - return (isinstance(other, RemoteBranchFormat) - and self.__dict__ == other.__dict__) + return isinstance(other, RemoteBranchFormat) and self.__dict__ == other.__dict__ def _ensure_real(self): if self._custom_format is None: try: self._custom_format = branch.network_format_registry.get( - self._network_name) + self._network_name + ) except KeyError as e: - raise errors.UnknownFormatError(kind='branch', format=self._network_name) from e + raise errors.UnknownFormatError( + kind="branch", format=self._network_name + ) from e def get_format_description(self): self._ensure_real() - return 'Remote: ' + self._custom_format.get_format_description() + return "Remote: " + self._custom_format.get_format_description() def network_name(self): return self._network_name def open(self, a_controldir, name=None, ignore_fallbacks=False): - return a_controldir.open_branch(name=name, - ignore_fallbacks=ignore_fallbacks) + return a_controldir.open_branch(name=name, ignore_fallbacks=ignore_fallbacks) - def _vfs_initialize(self, a_controldir, name, append_revisions_only, - repository=None): + def _vfs_initialize( + self, a_controldir, name, append_revisions_only, repository=None + ): # Initialisation when using a local bzrdir object, or a non-vfs init # method is not available on the server. # self._custom_format is always set - the start of initialize ensures # that. 
if isinstance(a_controldir, RemoteBzrDir): a_controldir._ensure_real() - result = self._custom_format.initialize(a_controldir._real_bzrdir, - name=name, append_revisions_only=append_revisions_only, - repository=repository) + result = self._custom_format.initialize( + a_controldir._real_bzrdir, + name=name, + append_revisions_only=append_revisions_only, + repository=repository, + ) else: # We assume the bzrdir is parameterised; it may not be. - result = self._custom_format.initialize(a_controldir, name=name, - append_revisions_only=append_revisions_only, - repository=repository) - if (isinstance(a_controldir, RemoteBzrDir) - and not isinstance(result, RemoteBranch)): - result = RemoteBranch(a_controldir, a_controldir.find_repository(), result, - name=name) + result = self._custom_format.initialize( + a_controldir, + name=name, + append_revisions_only=append_revisions_only, + repository=repository, + ) + if isinstance(a_controldir, RemoteBzrDir) and not isinstance( + result, RemoteBranch + ): + result = RemoteBranch( + a_controldir, a_controldir.find_repository(), result, name=name + ) return result - def initialize(self, a_controldir, name=None, repository=None, - append_revisions_only=None): + def initialize( + self, a_controldir, name=None, repository=None, append_revisions_only=None + ): if name is None: name = a_controldir._get_selected_branch() # 1) get the network name to use. @@ -3319,62 +3561,72 @@ def initialize(self, a_controldir, name=None, repository=None, network_name = self._custom_format.network_name() else: # Select the current breezy default and ask for that. - reference_bzrdir_format = controldir.format_registry.get( - 'default')() + reference_bzrdir_format = controldir.format_registry.get("default")() reference_format = reference_bzrdir_format.get_branch_format() self._custom_format = reference_format network_name = reference_format.network_name() # Being asked to create on a non RemoteBzrDir: if not isinstance(a_controldir, RemoteBzrDir): - return self._vfs_initialize(a_controldir, name=name, - append_revisions_only=append_revisions_only, - repository=repository) + return self._vfs_initialize( + a_controldir, + name=name, + append_revisions_only=append_revisions_only, + repository=repository, + ) medium = a_controldir._client._medium if medium._is_remote_before((1, 13)): - return self._vfs_initialize(a_controldir, name=name, - append_revisions_only=append_revisions_only, - repository=repository) + return self._vfs_initialize( + a_controldir, + name=name, + append_revisions_only=append_revisions_only, + repository=repository, + ) # Creating on a remote bzr dir. # 2) try direct creation via RPC path = a_controldir._path_for_remote_call(a_controldir._client) if name != "": # XXX JRV20100304: Support creating colocated branches raise controldir.NoColocatedBranchSupport(self) - verb = b'BzrDir.create_branch' + verb = b"BzrDir.create_branch" try: response = a_controldir._call(verb, path, network_name) except errors.UnknownSmartMethod: # Fallback - use vfs methods medium._remember_remote_is_before((1, 13)) - return self._vfs_initialize(a_controldir, name=name, - append_revisions_only=append_revisions_only, - repository=repository) - if response[0] != b'ok': + return self._vfs_initialize( + a_controldir, + name=name, + append_revisions_only=append_revisions_only, + repository=repository, + ) + if response[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response) # Turn the response into a RemoteRepository object. 
format = RemoteBranchFormat(network_name=response[1]) repo_format = response_tuple_to_repo_format(response[3:]) - repo_path = response[2].decode('utf-8') + repo_path = response[2].decode("utf-8") if repository is not None: remote_repo_url = urlutils.join(a_controldir.user_url, repo_path) - url_diff = urlutils.relative_url(repository.user_url, - remote_repo_url) - if url_diff != '.': + url_diff = urlutils.relative_url(repository.user_url, remote_repo_url) + if url_diff != ".": raise AssertionError( - f'repository.user_url {repository.user_url!r} does not match URL from server ' - f'response ({a_controldir.user_url!r} + {repo_path!r})') + f"repository.user_url {repository.user_url!r} does not match URL from server " + f"response ({a_controldir.user_url!r} + {repo_path!r})" + ) remote_repo = repository else: - if repo_path == '': + if repo_path == "": repo_bzrdir = a_controldir else: repo_bzrdir = RemoteBzrDir( - a_controldir.root_transport.clone( - repo_path), a_controldir._format, - a_controldir._client) + a_controldir.root_transport.clone(repo_path), + a_controldir._format, + a_controldir._client, + ) remote_repo = RemoteRepository(repo_bzrdir, repo_format) - remote_branch = RemoteBranch(a_controldir, remote_repo, - format=format, setup_stacking=False, name=name) + remote_branch = RemoteBranch( + a_controldir, remote_repo, format=format, setup_stacking=False, name=name + ) if append_revisions_only: remote_branch.set_append_revisions_only(append_revisions_only) # XXX: We know this is a new branch, so it must have revno 0, revid @@ -3437,17 +3689,18 @@ def __init__(self, branch): self._real_store = None def external_url(self): - return urlutils.join(self.branch.user_url, 'branch.conf') + return urlutils.join(self.branch.user_url, "branch.conf") def _load_content(self): path = self.branch._remote_path() try: response, handler = self.branch._call_expecting_body( - b'Branch.get_config_file', path) + b"Branch.get_config_file", path + ) except errors.UnknownSmartMethod: self._ensure_real() return self._real_store._load_content() - if len(response) and response[0] != b'ok': + if len(response) and response[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response) return handler.read_body_bytes() @@ -3455,14 +3708,15 @@ def _save_content(self, content): path = self.branch._remote_path() try: response, handler = self.branch._call_with_body_bytes_expecting_body( - b'Branch.put_config_file', (path, - self.branch._lock_token, self.branch._repo_lock_token), - content) + b"Branch.put_config_file", + (path, self.branch._lock_token, self.branch._repo_lock_token), + content, + ) except errors.UnknownSmartMethod: self._ensure_real() return self._real_store._save_content(content) handler.cancel_read_body() - if response != (b'ok', ): + if response != (b"ok",): raise errors.UnexpectedSmartServerResponse(response) def _ensure_real(self): @@ -3485,11 +3739,17 @@ class RemoteBranch(branch.Branch, _RpcHelper, lock._RelockDebugMixin): def control_transport(self) -> _mod_transport.Transport: return self._transport # type: ignore - def __init__(self, remote_bzrdir: RemoteBzrDir, remote_repository: RemoteRepository, - real_branch: Optional["bzrbranch.BzrBranch"] = None, - _client=None, format=None, setup_stacking: bool = True, - name: Optional[str] = None, - possible_transports: Optional[List[_mod_transport.Transport]] = None): + def __init__( + self, + remote_bzrdir: RemoteBzrDir, + remote_repository: RemoteRepository, + real_branch: Optional["bzrbranch.BzrBranch"] = None, + _client=None, + format=None, + 
setup_stacking: bool = True, + name: Optional[str] = None, + possible_transports: Optional[List[_mod_transport.Transport]] = None, + ): """Create a RemoteBranch instance. :param real_branch: An optional local implementation of the branch @@ -3543,8 +3803,7 @@ def __init__(self, remote_bzrdir: RemoteBzrDir, remote_repository: RemoteReposit if format is None: self._format = RemoteBranchFormat() if self._real_branch is not None: - self._format._network_name = \ - self._real_branch._format.network_name() + self._format._network_name = self._real_branch._format.network_name() else: self._format = format # when we do _ensure_real we may need to pass ignore_fallbacks to the @@ -3555,11 +3814,10 @@ def __init__(self, remote_bzrdir: RemoteBzrDir, remote_repository: RemoteReposit self._ensure_real() if not self._real_branch: raise AssertionError - self._format._network_name = \ - self._real_branch._format.network_name() + self._format._network_name = self._real_branch._format.network_name() self.tags = self._format.make_tags(self) # The base class init is not called, so we duplicate this: - hooks = branch.Branch.hooks['open'] + hooks = branch.Branch.hooks["open"] for hook in hooks: hook(self) self._is_stacked = False @@ -3571,8 +3829,11 @@ def _setup_stacking(self, possible_transports): # the vfs branch. try: fallback_url = self.get_stacked_on_url() - except (errors.NotStacked, branch.UnstackableBranchFormat, - errors.UnstackableRepositoryFormat): + except ( + errors.NotStacked, + branch.UnstackableBranchFormat, + errors.UnstackableRepositoryFormat, + ): return self._is_stacked = True if possible_transports is None: @@ -3580,8 +3841,9 @@ def _setup_stacking(self, possible_transports): else: possible_transports = list(possible_transports) possible_transports.append(self.controldir.root_transport) - self._activate_fallback_location(fallback_url, - possible_transports=possible_transports) + self._activate_fallback_location( + fallback_url, possible_transports=possible_transports + ) def _get_config(self): return RemoteBranchConfig(self) @@ -3624,11 +3886,13 @@ def _ensure_real(self): """ if self._real_branch is None: if not vfs.vfs_enabled(): - raise AssertionError('smart server vfs must be enabled ' - 'to use vfs implementation') + raise AssertionError( + "smart server vfs must be enabled " "to use vfs implementation" + ) self.controldir._ensure_real() self._real_branch = self.controldir._real_bzrdir.open_branch( - ignore_fallbacks=self._real_ignore_fallbacks, name=self._name) + ignore_fallbacks=self._real_ignore_fallbacks, name=self._name + ) # The remote branch and the real branch shares the same store. If # we don't, there will always be cases where one of the stores # doesn't see an update made on the other. @@ -3643,9 +3907,9 @@ def _ensure_real(self): # Give the real branch the remote repository to let fast-pathing # happen. self._real_branch.repository = self.repository - if self._lock_mode == 'r': + if self._lock_mode == "r": self._real_branch.lock_read() - elif self._lock_mode == 'w': + elif self._lock_mode == "w": self._real_branch.lock_write(token=self._lock_token) def _translate_error(self, err, **context): @@ -3675,20 +3939,22 @@ def control_files(self): # because it triggers an _ensure_real that we otherwise might not need. 
if self._control_files is None: self._control_files = RemoteBranchLockableFiles( - self.controldir, self._client) + self.controldir, self._client + ) return self._control_files def get_physical_lock_status(self): """See Branch.get_physical_lock_status().""" try: - response = self._client.call(b'Branch.get_physical_lock_status', - self._remote_path()) + response = self._client.call( + b"Branch.get_physical_lock_status", self._remote_path() + ) except errors.UnknownSmartMethod: self._ensure_real() return self._real_branch.get_physical_lock_status() - if response[0] not in (b'yes', b'no'): + if response[0] not in (b"yes", b"no"): raise errors.UnexpectedSmartServerResponse(response) - return (response[0] == b'yes') + return response[0] == b"yes" def get_stacked_on_url(self): """Get the URL this branch is stacked against. @@ -3702,8 +3968,9 @@ def get_stacked_on_url(self): try: # there may not be a repository yet, so we can't use # self._translate_error, so we can't use self._call either. - response = self._client.call(b'Branch.get_stacked_on_url', - self._remote_path()) + response = self._client.call( + b"Branch.get_stacked_on_url", self._remote_path() + ) except errors.ErrorFromSmartServer as err: # there may not be a repository yet, so we can't call through # its _translate_error @@ -3711,14 +3978,15 @@ def get_stacked_on_url(self): except errors.UnknownSmartMethod: self._ensure_real() return self._real_branch.get_stacked_on_url() - if response[0] != b'ok': + if response[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response) - return response[1].decode('utf-8') + return response[1].decode("utf-8") def _check_stackable_repo(self) -> None: if not self.repository._format.supports_external_lookups: raise errors.UnstackableRepositoryFormat( - self.repository._format, self.repository.user_url) + self.repository._format, self.repository.user_url + ) def _unstack(self): """Change a branch to be unstacked, copying data as needed. @@ -3737,7 +4005,9 @@ def _unstack(self): raise AssertionError( "can't cope with fallback repositories " "of {!r} (fallbacks: {!r})".format( - old_repository, old_repository._fallback_repositories)) + old_repository, old_repository._fallback_repositories + ) + ) # Open the new repository object. # Repositories don't offer an interface to remove fallback # repositories today; take the conceptually simpler option and just @@ -3746,12 +4016,12 @@ def _unstack(self): # stream from one of them to the other. This does mean doing # separate SSH connection setup, but unstacking is not a # common operation so it's tolerable. - new_bzrdir = controldir.ControlDir.open( - self.controldir.root_transport.base) + new_bzrdir = controldir.ControlDir.open(self.controldir.root_transport.base) new_repository = new_bzrdir.find_repository() if new_repository._fallback_repositories: raise AssertionError( - f"didn't expect {self.repository!r} to have fallback_repositories") + f"didn't expect {self.repository!r} to have fallback_repositories" + ) # Replace self.repository with the new repository. # Do our best to transfer the lock state (i.e. lock-tokens and # lock count) of self.repository to the new repository. @@ -3781,7 +4051,8 @@ def _unstack(self): old_lock_count += 1 if old_lock_count == 0: raise AssertionError( - 'old_repository should have been locked at least once.') + "old_repository should have been locked at least once." + ) for _i in range(old_lock_count - 1): self.repository.lock_write() # Fetch from the old repository into the new. 
@@ -3794,9 +4065,12 @@ def _unstack(self): except errors.TagsNotSupported: tags_to_fetch = set() fetch_spec = vf_search.NotInOtherForRevs( - self.repository, old_repository, + self.repository, + old_repository, required_ids=[self.last_revision()], - if_present_ids=tags_to_fetch, find_ghosts=True).execute() + if_present_ids=tags_to_fetch, + find_ghosts=True, + ).execute() self.repository.fetch(old_repository, fetch_spec=fetch_spec) def set_stacked_on_url(self, url): @@ -3810,16 +4084,20 @@ def set_stacked_on_url(self, url): if not url: try: self.get_stacked_on_url() - except (errors.NotStacked, UnstackableBranchFormat, - errors.UnstackableRepositoryFormat): + except ( + errors.NotStacked, + UnstackableBranchFormat, + errors.UnstackableRepositoryFormat, + ): return self._unstack() else: self._activate_fallback_location( - url, possible_transports=[self.controldir.root_transport]) + url, possible_transports=[self.controldir.root_transport] + ) # write this out after the repository is stacked to avoid setting a # stacked config that doesn't work. - self._set_config_location('stacked_on_location', url) + self._set_config_location("stacked_on_location", url) # We need the stacked_on_url to be visible both locally (to not query # it repeatedly) and remotely (so smart verbs can get it server side) # Without the following line, @@ -3847,8 +4125,7 @@ def _get_tags_bytes_via_hpss(self): if medium._is_remote_before((1, 13)): return self._vfs_get_tags_bytes() try: - response = self._call( - b'Branch.get_tags_bytes', self._remote_path()) + response = self._call(b"Branch.get_tags_bytes", self._remote_path()) except errors.UnknownSmartMethod: medium._remember_remote_is_before((1, 13)) return self._vfs_get_tags_bytes() @@ -3866,10 +4143,8 @@ def _set_tags_bytes(self, bytes): self._vfs_set_tags_bytes(bytes) return try: - args = ( - self._remote_path(), self._lock_token, self._repo_lock_token) - self._call_with_body_bytes( - b'Branch.set_tags_bytes', args, bytes) + args = (self._remote_path(), self._lock_token, self._repo_lock_token) + self._call_with_body_bytes(b"Branch.set_tags_bytes", args, bytes) except errors.UnknownSmartMethod: medium._remember_remote_is_before((1, 18)) self._vfs_set_tags_bytes(bytes) @@ -3881,8 +4156,8 @@ def lock_read(self): """ self.repository.lock_read() if not self._lock_mode: - self._note_lock('r') - self._lock_mode = 'r' + self._note_lock("r") + self._lock_mode = "r" self._lock_count = 1 if self._real_branch is not None: self._real_branch.lock_read() @@ -3892,38 +4167,42 @@ def lock_read(self): def _remote_lock_write(self, token): if token is None: - branch_token = repo_token = b'' + branch_token = repo_token = b"" else: branch_token = token repo_token = self.repository.lock_write().repository_token self.repository.unlock() - err_context = {'token': token} + err_context = {"token": token} try: response = self._call( - b'Branch.lock_write', self._remote_path(), branch_token, - repo_token or b'', **err_context) + b"Branch.lock_write", + self._remote_path(), + branch_token, + repo_token or b"", + **err_context, + ) except errors.LockContention as e: # The LockContention from the server doesn't have any # information about the lock_url. We re-raise LockContention # with valid lock_url. 
- raise errors.LockContention('(remote lock)', self.repository.base.split('.bzr/')[0]) from e - if response[0] != b'ok': + raise errors.LockContention( + "(remote lock)", self.repository.base.split(".bzr/")[0] + ) from e + if response[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response) ok, branch_token, repo_token = response return branch_token, repo_token def lock_write(self, token=None): if not self._lock_mode: - self._note_lock('w') + self._note_lock("w") # Lock the branch and repo in one remote call. remote_tokens = self._remote_lock_write(token) self._lock_token, self._repo_lock_token = remote_tokens if not self._lock_token: - raise SmartProtocolError( - 'Remote server did not return a token!') + raise SmartProtocolError("Remote server did not return a token!") # Tell the self.repository object that it is locked. - self.repository.lock_write( - self._repo_lock_token, _skip_rpc=True) + self.repository.lock_write(self._repo_lock_token, _skip_rpc=True) if self._real_branch is not None: self._real_branch.lock_write(token=self._lock_token) @@ -3931,9 +4210,9 @@ def lock_write(self, token=None): self._leave_lock = True else: self._leave_lock = False - self._lock_mode = 'w' + self._lock_mode = "w" self._lock_count = 1 - elif self._lock_mode == 'r': + elif self._lock_mode == "r": raise errors.ReadOnlyError(self) else: if token is not None: @@ -3948,11 +4227,15 @@ def lock_write(self, token=None): return BranchWriteLockResult(self.unlock, self._lock_token or None) def _unlock(self, branch_token, repo_token): - err_context = {'token': str((branch_token, repo_token))} + err_context = {"token": str((branch_token, repo_token))} response = self._call( - b'Branch.unlock', self._remote_path(), branch_token, - repo_token or b'', **err_context) - if response == (b'ok',): + b"Branch.unlock", + self._remote_path(), + branch_token, + repo_token or b"", + **err_context, + ) + if response == (b"ok",): return raise errors.UnexpectedSmartServerResponse(response) @@ -3967,20 +4250,19 @@ def unlock(self): mode = self._lock_mode self._lock_mode = None if self._real_branch is not None: - if (not self._leave_lock and mode == 'w' - and self._repo_lock_token): + if not self._leave_lock and mode == "w" and self._repo_lock_token: # If this RemoteBranch will remove the physical lock # for the repository, make sure the _real_branch # doesn't do it first. (Because the _real_branch's # repository is set to be the RemoteRepository.) self._real_branch.repository.leave_lock_in_place() self._real_branch.unlock() - if mode != 'w': + if mode != "w": # Only write-locked branched need to make a remote method # call to perform the unlock. 
return if not self._lock_token: - raise AssertionError('Locked, but no token!') + raise AssertionError("Locked, but no token!") branch_token = self._lock_token repo_token = self._repo_lock_token self._lock_token = None @@ -3992,12 +4274,11 @@ def unlock(self): def break_lock(self): try: - response = self._call( - b'Branch.break_lock', self._remote_path()) + response = self._call(b"Branch.break_lock", self._remote_path()) except errors.UnknownSmartMethod: self._ensure_real() return self._real_branch.break_lock() - if response != (b'ok',): + if response != (b"ok",): raise errors.UnexpectedSmartServerResponse(response) def leave_lock_in_place(self): @@ -4016,10 +4297,8 @@ def get_rev_id(self, revno, history=None): with self.lock_read(): last_revision_info = self.last_revision_info() if revno < 0: - raise errors.RevnoOutOfBounds( - revno, (0, last_revision_info[0])) - ok, result = self.repository.get_rev_id_for_revno( - revno, last_revision_info) + raise errors.RevnoOutOfBounds(revno, (0, last_revision_info[0])) + ok, result = self.repository.get_rev_id_for_revno(revno, last_revision_info) if ok: return result missing_parent = result[1] @@ -4032,11 +4311,9 @@ def get_rev_id(self, revno, history=None): raise errors.NoSuchRevision(self, missing_parent) def _read_last_revision_info(self): - response = self._call( - b'Branch.last_revision_info', self._remote_path()) - if response[0] != b'ok': - raise SmartProtocolError( - f'unexpected response code {response}') + response = self._call(b"Branch.last_revision_info", self._remote_path()) + if response[0] != b"ok": + raise SmartProtocolError(f"unexpected response code {response}") revno = int(response[1]) last_revision = response[2] return (revno, last_revision) @@ -4047,19 +4324,25 @@ def _gen_revision_history(self): self._ensure_real() return self._real_branch._gen_revision_history() response_tuple, response_handler = self._call_expecting_body( - b'Branch.revision_history', self._remote_path()) - if response_tuple[0] != b'ok': + b"Branch.revision_history", self._remote_path() + ) + if response_tuple[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response_tuple) - result = response_handler.read_body_bytes().split(b'\x00') - if result == ['']: + result = response_handler.read_body_bytes().split(b"\x00") + if result == [""]: return [] return result def _remote_path(self): return self.controldir._path_for_remote_call(self._client) - def _set_last_revision_descendant(self, revision_id, other_branch, - allow_diverged=False, allow_overwrite_descendant=False): + def _set_last_revision_descendant( + self, + revision_id, + other_branch, + allow_diverged=False, + allow_overwrite_descendant=False, + ): # This performs additional work to meet the hook contract; while its # undesirable, we have to synthesise the revno to call the hook, and # not calling the hook is worse as it means changes can't be prevented. 
@@ -4069,14 +4352,19 @@ def _set_last_revision_descendant(self, revision_id, other_branch, old_revno, old_revid = self.last_revision_info() history = self._lefthand_history(revision_id) self._run_pre_change_branch_tip_hooks(len(history), revision_id) - err_context = {'other_branch': other_branch} - response = self._call(b'Branch.set_last_revision_ex', - self._remote_path(), self._lock_token, self._repo_lock_token, - revision_id, int(allow_diverged), int( - allow_overwrite_descendant), - **err_context) + err_context = {"other_branch": other_branch} + response = self._call( + b"Branch.set_last_revision_ex", + self._remote_path(), + self._lock_token, + self._repo_lock_token, + revision_id, + int(allow_diverged), + int(allow_overwrite_descendant), + **err_context, + ) self._clear_cached_state() - if len(response) != 3 and response[0] != b'ok': + if len(response) != 3 and response[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response) new_revno, new_revision_id = response[1:] self._last_revision_info_cache = new_revno, new_revision_id @@ -4096,10 +4384,14 @@ def _set_last_revision(self, revision_id): history = self._lefthand_history(revision_id) self._run_pre_change_branch_tip_hooks(len(history), revision_id) self._clear_cached_state() - response = self._call(b'Branch.set_last_revision', - self._remote_path(), self._lock_token, self._repo_lock_token, - revision_id) - if response != (b'ok',): + response = self._call( + b"Branch.set_last_revision", + self._remote_path(), + self._lock_token, + self._repo_lock_token, + revision_id, + ) + if response != (b"ok",): raise errors.UnexpectedSmartServerResponse(response) self._run_post_change_branch_tip_hooks(old_revno, old_revid) @@ -4108,16 +4400,16 @@ def _get_parent_location(self): if medium._is_remote_before((1, 13)): return self._vfs_get_parent_location() try: - response = self._call(b'Branch.get_parent', self._remote_path()) + response = self._call(b"Branch.get_parent", self._remote_path()) except errors.UnknownSmartMethod: medium._remember_remote_is_before((1, 13)) return self._vfs_get_parent_location() if len(response) != 1: raise errors.UnexpectedSmartServerResponse(response) parent_location = response[0] - if parent_location == b'': + if parent_location == b"": return None - return parent_location.decode('utf-8') + return parent_location.decode("utf-8") def _vfs_get_parent_location(self): self._ensure_real() @@ -4128,12 +4420,16 @@ def _set_parent_location(self, url): if medium._is_remote_before((1, 15)): return self._vfs_set_parent_location(url) try: - call_url = url or '' + call_url = url or "" if isinstance(call_url, str): - call_url = call_url.encode('utf-8') - response = self._call(b'Branch.set_parent_location', - self._remote_path(), self._lock_token, self._repo_lock_token, - call_url) + call_url = call_url.encode("utf-8") + response = self._call( + b"Branch.set_parent_location", + self._remote_path(), + self._lock_token, + self._repo_lock_token, + call_url, + ) except errors.UnknownSmartMethod: medium._remember_remote_is_before((1, 15)) return self._vfs_set_parent_location(url) @@ -4144,21 +4440,36 @@ def _vfs_set_parent_location(self, url): self._ensure_real() return self._real_branch._set_parent_location(url) - def pull(self, source, overwrite=False, stop_revision=None, - **kwargs): + def pull(self, source, overwrite=False, stop_revision=None, **kwargs): with self.lock_write(): self._clear_cached_state_of_remote_branch_only() self._ensure_real() return self._real_branch.pull( - source, overwrite=overwrite, 
stop_revision=stop_revision, - _override_hook_target=self, **kwargs) - - def push(self, target, overwrite=False, stop_revision=None, lossy=False, tag_selector=None): + source, + overwrite=overwrite, + stop_revision=stop_revision, + _override_hook_target=self, + **kwargs, + ) + + def push( + self, + target, + overwrite=False, + stop_revision=None, + lossy=False, + tag_selector=None, + ): with self.lock_read(): self._ensure_real() return self._real_branch.push( - target, overwrite=overwrite, stop_revision=stop_revision, lossy=lossy, - _override_hook_source_branch=self, tag_selector=tag_selector) + target, + overwrite=overwrite, + stop_revision=stop_revision, + lossy=lossy, + _override_hook_source_branch=self, + tag_selector=tag_selector, + ) def peek_lock_mode(self): return self._lock_mode @@ -4173,19 +4484,20 @@ def revision_id_to_dotted_revno(self, revision_id): """ with self.lock_read(): try: - response = self._call(b'Branch.revision_id_to_revno', - self._remote_path(), revision_id) + response = self._call( + b"Branch.revision_id_to_revno", self._remote_path(), revision_id + ) except errors.UnknownSmartMethod: self._ensure_real() return self._real_branch.revision_id_to_dotted_revno(revision_id) except UnknownErrorFromSmartServer as e: # Deal with older versions of bzr/brz that didn't explicitly # wrap GhostRevisionsHaveNoRevno. - if e.error_tuple[1] == b'GhostRevisionsHaveNoRevno': + if e.error_tuple[1] == b"GhostRevisionsHaveNoRevno": (revid, ghost_revid) = re.findall(b"{([^}]+)}", e.error_tuple[2]) raise errors.GhostRevisionsHaveNoRevno(revid, ghost_revid) from e raise - if response[0] == b'ok': + if response[0] == b"ok": return tuple([int(x) for x in response[1:]]) else: raise errors.UnexpectedSmartServerResponse(response) @@ -4197,12 +4509,13 @@ def revision_id_to_revno(self, revision_id): """ with self.lock_read(): try: - response = self._call(b'Branch.revision_id_to_revno', - self._remote_path(), revision_id) + response = self._call( + b"Branch.revision_id_to_revno", self._remote_path(), revision_id + ) except errors.UnknownSmartMethod: self._ensure_real() return self._real_branch.revision_id_to_revno(revision_id) - if response[0] == b'ok': + if response[0] == b"ok": if len(response) == 2: return int(response[1]) raise NoSuchRevision(self, revision_id) @@ -4215,19 +4528,23 @@ def set_last_revision_info(self, revno, revision_id): old_revno, old_revid = self.last_revision_info() self._run_pre_change_branch_tip_hooks(revno, revision_id) if not revision_id or not isinstance(revision_id, bytes): - raise errors.InvalidRevisionId( - revision_id=revision_id, branch=self) + raise errors.InvalidRevisionId(revision_id=revision_id, branch=self) try: - response = self._call(b'Branch.set_last_revision_info', - self._remote_path(), self._lock_token, self._repo_lock_token, - str(revno).encode('ascii'), revision_id) + response = self._call( + b"Branch.set_last_revision_info", + self._remote_path(), + self._lock_token, + self._repo_lock_token, + str(revno).encode("ascii"), + revision_id, + ) except errors.UnknownSmartMethod: self._ensure_real() self._clear_cached_state_of_remote_branch_only() self._real_branch.set_last_revision_info(revno, revision_id) self._last_revision_info_cache = revno, revision_id return - if response == (b'ok',): + if response == (b"ok",): self._clear_cached_state() self._last_revision_info_cache = revno, revision_id self._run_post_change_branch_tip_hooks(old_revno, old_revid) @@ -4238,15 +4555,18 @@ def set_last_revision_info(self, revno, revision_id): else: raise 
errors.UnexpectedSmartServerResponse(response) - def generate_revision_history(self, revision_id, last_rev=None, - other_branch=None): + def generate_revision_history(self, revision_id, last_rev=None, other_branch=None): with self.lock_write(): medium = self._client._medium if not medium._is_remote_before((1, 6)): # Use a smart method for 1.6 and above servers try: - self._set_last_revision_descendant(revision_id, other_branch, - allow_diverged=True, allow_overwrite_descendant=True) + self._set_last_revision_descendant( + revision_id, + other_branch, + allow_diverged=True, + allow_overwrite_descendant=True, + ) return except errors.UnknownSmartMethod: medium._remember_remote_is_before((1, 6)) @@ -4256,17 +4576,16 @@ def generate_revision_history(self, revision_id, last_rev=None, known_revision_ids = [ (last_revid, last_revno), (_mod_revision.NULL_REVISION, 0), - ] + ] if last_rev is not None: if not graph.is_ancestor(last_rev, revision_id): # our previous tip is not merged into stop_revision raise errors.DivergedBranches(self, other_branch) - revno = graph.find_distance_to_null( - revision_id, known_revision_ids) + revno = graph.find_distance_to_null(revision_id, known_revision_ids) self.set_last_revision_info(revno, revision_id) def set_push_location(self, location): - self._set_config_location('push_location', location) + self._set_config_location("push_location", location) def heads_to_fetch(self): if self._format._use_default_local_heads_to_fetch(): @@ -4285,7 +4604,7 @@ def heads_to_fetch(self): return self._vfs_heads_to_fetch() def _rpc_heads_to_fetch(self): - response = self._call(b'Branch.heads_to_fetch', self._remote_path()) + response = self._call(b"Branch.heads_to_fetch", self._remote_path()) if len(response) != 2: raise errors.UnexpectedSmartServerResponse(response) must_fetch, if_present_fetch = response @@ -4298,6 +4617,7 @@ def _vfs_heads_to_fetch(self): def reconcile(self, thorough=True): """Make sure the data stored in this branch is consistent.""" from .reconcile import BranchReconciler + with self.lock_write(): reconciler = BranchReconciler(self, thorough=thorough) return reconciler.reconcile() @@ -4313,8 +4633,7 @@ def set_reference_info(self, file_id, branch_location, tree_path=None): if not self._format.supports_reference_locations: raise errors.UnsupportedOperation(self.set_reference_info, self) self._ensure_real() - self._real_branch.set_reference_info( - file_id, branch_location, tree_path) + self._real_branch.set_reference_info(file_id, branch_location, tree_path) def _set_all_reference_info(self, reference_info): if not self._format.supports_reference_locations: @@ -4327,15 +4646,16 @@ def _get_all_reference_info(self): return {} try: response, handler = self._call_expecting_body( - b'Branch.get_all_reference_info', self._remote_path()) + b"Branch.get_all_reference_info", self._remote_path() + ) except errors.UnknownSmartMethod: self._ensure_real() return self._real_branch._get_all_reference_info() - if len(response) and response[0] != b'ok': + if len(response) and response[0] != b"ok": raise errors.UnexpectedSmartServerResponse(response) ret = {} - for (f, u, p) in bencode.bdecode(handler.read_body_bytes()): - ret[f] = (u.decode('utf-8'), p.decode('utf-8') if p else None) + for f, u, p in bencode.bdecode(handler.read_body_bytes()): + ret[f] = (u.decode("utf-8"), p.decode("utf-8") if p else None) return ret def reference_parent(self, file_id, path, possible_transports=None): @@ -4349,13 +4669,16 @@ def reference_parent(self, file_id, path, 
possible_transports=None): try: return branch.Branch.open_from_transport( self.controldir.root_transport.clone(path), - possible_transports=possible_transports) + possible_transports=possible_transports, + ) except errors.NotBranchError: return None return branch.Branch.open( urlutils.join( - urlutils.strip_segment_parameters(self.user_url), branch_location), - possible_transports=possible_transports) + urlutils.strip_segment_parameters(self.user_url), branch_location + ), + possible_transports=possible_transports, + ) class RemoteConfig: @@ -4390,16 +4713,16 @@ def get_option(self, name, section=None, default=None): value = section_obj.get(name, default) except errors.UnknownSmartMethod: value = self._vfs_get_option(name, section, default) - for hook in _mod_config.OldConfigHooks['get']: + for hook in _mod_config.OldConfigHooks["get"]: hook(self, name, value) return value def _response_to_configobj(self, response): - if len(response[0]) and response[0][0] != b'ok': + if len(response[0]) and response[0][0] != b"ok": raise errors.UnexpectedSmartServerResponse(response) lines = response[1].read_body_bytes().splitlines() - conf = _mod_config.ConfigObj(lines, encoding='utf-8') - for hook in _mod_config.OldConfigHooks['load']: + conf = _mod_config.ConfigObj(lines, encoding="utf-8") + for hook in _mod_config.OldConfigHooks["load"]: hook(self) return conf @@ -4413,7 +4736,8 @@ def __init__(self, branch): def _get_configobj(self): path = self._branch._remote_path() response = self._branch._client.call_expecting_body( - b'Branch.get_config_file', path) + b"Branch.get_config_file", path + ) return self._response_to_configobj(response) def set_option(self, value, name, section=None): @@ -4442,10 +4766,15 @@ def _set_config_option(self, value, name, section): raise TypeError(value) try: path = self._branch._remote_path() - response = self._branch._client.call(b'Branch.set_config_option', - path, self._branch._lock_token, self._branch._repo_lock_token, - value.encode('utf-8'), name.encode('utf-8'), - (section or '').encode('utf-8')) + response = self._branch._client.call( + b"Branch.set_config_option", + path, + self._branch._lock_token, + self._branch._repo_lock_token, + value.encode("utf-8"), + name.encode("utf-8"), + (section or "").encode("utf-8"), + ) except errors.UnknownSmartMethod: medium = self._branch._client._medium medium._remember_remote_is_before((1, 14)) @@ -4457,9 +4786,9 @@ def _serialize_option_dict(self, option_dict): utf8_dict = {} for key, value in option_dict.items(): if isinstance(key, str): - key = key.encode('utf8') + key = key.encode("utf8") if isinstance(value, str): - value = value.encode('utf8') + value = value.encode("utf8") utf8_dict[key] = value return bencode.bencode(utf8_dict) @@ -4468,9 +4797,14 @@ def _set_config_option_dict(self, value, name, section): path = self._branch._remote_path() serialised_dict = self._serialize_option_dict(value) response = self._branch._client.call( - b'Branch.set_config_option_dict', - path, self._branch._lock_token, self._branch._repo_lock_token, - serialised_dict, name.encode('utf-8'), (section or '').encode('utf-8')) + b"Branch.set_config_option_dict", + path, + self._branch._lock_token, + self._branch._repo_lock_token, + serialised_dict, + name.encode("utf-8"), + (section or "").encode("utf-8"), + ) except errors.UnknownSmartMethod: medium = self._branch._client._medium medium._remember_remote_is_before((2, 2)) @@ -4483,8 +4817,7 @@ def _real_object(self): return self._branch._real_branch def _vfs_set_option(self, value, name, 
section=None): - return self._real_object()._get_config().set_option( - value, name, section) + return self._real_object()._get_config().set_option(value, name, section) class RemoteBzrDirConfig(RemoteConfig): @@ -4495,17 +4828,15 @@ def __init__(self, bzrdir): def _get_configobj(self): medium = self._bzrdir._client._medium - verb = b'BzrDir.get_config_file' + verb = b"BzrDir.get_config_file" if medium._is_remote_before((1, 15)): raise errors.UnknownSmartMethod(verb) path = self._bzrdir._path_for_remote_call(self._bzrdir._client) - response = self._bzrdir._call_expecting_body( - verb, path) + response = self._bzrdir._call_expecting_body(verb, path) return self._response_to_configobj(response) def _vfs_get_option(self, name, section, default): - return self._real_object()._get_config().get_option( - name, section, default) + return self._real_object()._get_config().get_option(name, section, default) def set_option(self, value, name, section=None): """Set the value associated with a named option. @@ -4514,8 +4845,7 @@ def set_option(self, value, name, section=None): :param name: The name of the value to set :param section: The section the option is in (if any) """ - return self._real_object()._get_config().set_option( - value, name, section) + return self._real_object()._get_config().set_option(value, name, section) def _real_object(self): self._bzrdir._ensure_real() @@ -4540,11 +4870,12 @@ def _translate_error(err, **context): If the error from the server doesn't match a known pattern, then UnknownErrorFromSmartServer is raised. """ + def find(name): try: return context[name] except KeyError: - mutter('Missing key \'%s\' in context %r', name, context) + mutter("Missing key '%s' in context %r", name, context) raise err from None def get_path(): @@ -4552,13 +4883,14 @@ def get_path(): arg. 
""" try: - return context['path'] + return context["path"] except KeyError: try: - return err.error_args[0].decode('utf-8') + return err.error_args[0].decode("utf-8") except IndexError: - mutter('Missing key \'path\' in context %r', context) + mutter("Missing key 'path' in context %r", context) raise err from None + if not isinstance(err.error_verb, bytes): raise TypeError(err.error_verb) try: @@ -4575,137 +4907,199 @@ def get_path(): raise translator(err) -error_translators.register(b'NoSuchRevision', - lambda err, find, get_path: NoSuchRevision( - find('branch'), err.error_args[0])) -error_translators.register(b'nosuchrevision', - lambda err, find, get_path: NoSuchRevision( - find('repository'), err.error_args[0])) error_translators.register( - b'revno-outofbounds', + b"NoSuchRevision", + lambda err, find, get_path: NoSuchRevision(find("branch"), err.error_args[0]), +) +error_translators.register( + b"nosuchrevision", + lambda err, find, get_path: NoSuchRevision(find("repository"), err.error_args[0]), +) +error_translators.register( + b"revno-outofbounds", lambda err, find, get_path: errors.RevnoOutOfBounds( - err.error_args[0], (err.error_args[1], err.error_args[2]))) + err.error_args[0], (err.error_args[1], err.error_args[2]) + ), +) def _translate_nobranch_error(err, find, get_path): if len(err.error_args) >= 1: - extra = err.error_args[0].decode('utf-8') + extra = err.error_args[0].decode("utf-8") else: extra = None - return errors.NotBranchError(path=find('bzrdir').root_transport.base, - detail=extra) - - -error_translators.register(b'nobranch', _translate_nobranch_error) -error_translators.register(b'norepository', - lambda err, find, get_path: errors.NoRepositoryPresent( - find('bzrdir'))) -error_translators.register(b'UnlockableTransport', - lambda err, find, get_path: errors.UnlockableTransport( - find('bzrdir').root_transport)) -error_translators.register(b'TokenMismatch', - lambda err, find, get_path: errors.TokenMismatch( - find('token'), '(remote token)')) -error_translators.register(b'Diverged', - lambda err, find, get_path: errors.DivergedBranches( - find('branch'), find('other_branch'))) -error_translators.register(b'NotStacked', - lambda err, find, get_path: errors.NotStacked(branch=find('branch'))) + return errors.NotBranchError(path=find("bzrdir").root_transport.base, detail=extra) + + +error_translators.register(b"nobranch", _translate_nobranch_error) +error_translators.register( + b"norepository", + lambda err, find, get_path: errors.NoRepositoryPresent(find("bzrdir")), +) +error_translators.register( + b"UnlockableTransport", + lambda err, find, get_path: errors.UnlockableTransport( + find("bzrdir").root_transport + ), +) +error_translators.register( + b"TokenMismatch", + lambda err, find, get_path: errors.TokenMismatch(find("token"), "(remote token)"), +) +error_translators.register( + b"Diverged", + lambda err, find, get_path: errors.DivergedBranches( + find("branch"), find("other_branch") + ), +) +error_translators.register( + b"NotStacked", lambda err, find, get_path: errors.NotStacked(branch=find("branch")) +) def _translate_PermissionDenied(err, find, get_path): path = get_path() if len(err.error_args) >= 2: - extra = err.error_args[1].decode('utf-8') + extra = err.error_args[1].decode("utf-8") else: extra = None return errors.PermissionDenied(path, extra=extra) -error_translators.register(b'PermissionDenied', _translate_PermissionDenied) -error_translators.register(b'ReadError', - lambda err, find, get_path: errors.ReadError(get_path())) 
-error_translators.register(b'NoSuchFile', - lambda err, find, get_path: _mod_transport.NoSuchFile(get_path())) -error_translators.register(b'TokenLockingNotSupported', - lambda err, find, get_path: errors.TokenLockingNotSupported( - find('repository'))) -error_translators.register(b'UnsuspendableWriteGroup', - lambda err, find, get_path: errors.UnsuspendableWriteGroup( - repository=find('repository'))) -error_translators.register(b'UnresumableWriteGroup', - lambda err, find, get_path: errors.UnresumableWriteGroup( - repository=find('repository'), write_groups=err.error_args[0], - reason=err.error_args[1])) -error_translators.register(b'AlreadyControlDir', - lambda err, find, get_path: errors.AlreadyControlDirError(get_path())) - -no_context_error_translators.register(b'GhostRevisionsHaveNoRevno', - lambda err: errors.GhostRevisionsHaveNoRevno(*err.error_args)) -no_context_error_translators.register(b'IncompatibleRepositories', - lambda err: errors.IncompatibleRepositories( - err.error_args[0].decode('utf-8'), err.error_args[1].decode('utf-8'), err.error_args[2].decode('utf-8'))) -no_context_error_translators.register(b'LockContention', - lambda err: errors.LockContention('(remote lock)')) -no_context_error_translators.register(b'LockFailed', - lambda err: errors.LockFailed(err.error_args[0].decode('utf-8'), err.error_args[1].decode('utf-8'))) -no_context_error_translators.register(b'TipChangeRejected', - lambda err: errors.TipChangeRejected(err.error_args[0].decode('utf8'))) -no_context_error_translators.register(b'UnstackableBranchFormat', - lambda err: branch.UnstackableBranchFormat(*err.error_args)) -no_context_error_translators.register(b'UnstackableRepositoryFormat', - lambda err: errors.UnstackableRepositoryFormat(*err.error_args)) -no_context_error_translators.register(b'FileExists', - lambda err: _mod_transport.FileExists(err.error_args[0].decode('utf-8'))) -no_context_error_translators.register(b'DirectoryNotEmpty', - lambda err: errors.DirectoryNotEmpty(err.error_args[0].decode('utf-8'))) -no_context_error_translators.register(b'UnknownFormat', - lambda err: errors.UnknownFormatError( - err.error_args[0].decode('ascii'), err.error_args[0].decode('ascii'))) -no_context_error_translators.register(b'InvalidURL', - lambda err: urlutils.InvalidURL( - err.error_args[0].decode('utf-8'), err.error_args[1].decode('utf-8'))) +error_translators.register(b"PermissionDenied", _translate_PermissionDenied) +error_translators.register( + b"ReadError", lambda err, find, get_path: errors.ReadError(get_path()) +) +error_translators.register( + b"NoSuchFile", lambda err, find, get_path: _mod_transport.NoSuchFile(get_path()) +) +error_translators.register( + b"TokenLockingNotSupported", + lambda err, find, get_path: errors.TokenLockingNotSupported(find("repository")), +) +error_translators.register( + b"UnsuspendableWriteGroup", + lambda err, find, get_path: errors.UnsuspendableWriteGroup( + repository=find("repository") + ), +) +error_translators.register( + b"UnresumableWriteGroup", + lambda err, find, get_path: errors.UnresumableWriteGroup( + repository=find("repository"), + write_groups=err.error_args[0], + reason=err.error_args[1], + ), +) +error_translators.register( + b"AlreadyControlDir", + lambda err, find, get_path: errors.AlreadyControlDirError(get_path()), +) + +no_context_error_translators.register( + b"GhostRevisionsHaveNoRevno", + lambda err: errors.GhostRevisionsHaveNoRevno(*err.error_args), +) +no_context_error_translators.register( + b"IncompatibleRepositories", + lambda err: 
errors.IncompatibleRepositories( + err.error_args[0].decode("utf-8"), + err.error_args[1].decode("utf-8"), + err.error_args[2].decode("utf-8"), + ), +) +no_context_error_translators.register( + b"LockContention", lambda err: errors.LockContention("(remote lock)") +) +no_context_error_translators.register( + b"LockFailed", + lambda err: errors.LockFailed( + err.error_args[0].decode("utf-8"), err.error_args[1].decode("utf-8") + ), +) +no_context_error_translators.register( + b"TipChangeRejected", + lambda err: errors.TipChangeRejected(err.error_args[0].decode("utf8")), +) +no_context_error_translators.register( + b"UnstackableBranchFormat", + lambda err: branch.UnstackableBranchFormat(*err.error_args), +) +no_context_error_translators.register( + b"UnstackableRepositoryFormat", + lambda err: errors.UnstackableRepositoryFormat(*err.error_args), +) +no_context_error_translators.register( + b"FileExists", + lambda err: _mod_transport.FileExists(err.error_args[0].decode("utf-8")), +) +no_context_error_translators.register( + b"DirectoryNotEmpty", + lambda err: errors.DirectoryNotEmpty(err.error_args[0].decode("utf-8")), +) +no_context_error_translators.register( + b"UnknownFormat", + lambda err: errors.UnknownFormatError( + err.error_args[0].decode("ascii"), err.error_args[0].decode("ascii") + ), +) +no_context_error_translators.register( + b"InvalidURL", + lambda err: urlutils.InvalidURL( + err.error_args[0].decode("utf-8"), err.error_args[1].decode("utf-8") + ), +) def _translate_short_readv_error(err): args = err.error_args return errors.ShortReadvError( - args[0].decode('utf-8'), - int(args[1].decode('ascii')), int(args[2].decode('ascii')), - int(args[3].decode('ascii'))) + args[0].decode("utf-8"), + int(args[1].decode("ascii")), + int(args[2].decode("ascii")), + int(args[3].decode("ascii")), + ) -no_context_error_translators.register(b'ShortReadvError', - _translate_short_readv_error) +no_context_error_translators.register(b"ShortReadvError", _translate_short_readv_error) def _translate_unicode_error(err): - encoding = err.error_args[0].decode('ascii') - val = err.error_args[1].decode('utf-8') - start = int(err.error_args[2].decode('ascii')) - end = int(err.error_args[3].decode('ascii')) - reason = err.error_args[4].decode('utf-8') - if val.startswith('u:'): - val = val[2:].decode('utf-8') - elif val.startswith('s:'): - val = val[2:].decode('base64') - if err.error_verb == 'UnicodeDecodeError': + encoding = err.error_args[0].decode("ascii") + val = err.error_args[1].decode("utf-8") + start = int(err.error_args[2].decode("ascii")) + end = int(err.error_args[3].decode("ascii")) + reason = err.error_args[4].decode("utf-8") + if val.startswith("u:"): + val = val[2:].decode("utf-8") + elif val.startswith("s:"): + val = val[2:].decode("base64") + if err.error_verb == "UnicodeDecodeError": raise UnicodeDecodeError(encoding, val, start, end, reason) - elif err.error_verb == 'UnicodeEncodeError': + elif err.error_verb == "UnicodeEncodeError": raise UnicodeEncodeError(encoding, val, start, end, reason) -no_context_error_translators.register(b'UnicodeEncodeError', - _translate_unicode_error) -no_context_error_translators.register(b'UnicodeDecodeError', - _translate_unicode_error) -no_context_error_translators.register(b'ReadOnlyError', - lambda err: errors.TransportNotPossible('readonly transport')) -no_context_error_translators.register(b'MemoryError', - lambda err: errors.BzrError("remote server out of memory\n" - "Retry non-remotely, or contact the server admin for details.")) 
-no_context_error_translators.register(b'RevisionNotPresent', - lambda err: errors.RevisionNotPresent(err.error_args[0].decode('utf-8'), err.error_args[1].decode('utf-8'))) - -no_context_error_translators.register(b'BzrCheckError', - lambda err: errors.BzrCheckError(msg=err.error_args[0].decode('utf-8'))) +no_context_error_translators.register(b"UnicodeEncodeError", _translate_unicode_error) +no_context_error_translators.register(b"UnicodeDecodeError", _translate_unicode_error) +no_context_error_translators.register( + b"ReadOnlyError", lambda err: errors.TransportNotPossible("readonly transport") +) +no_context_error_translators.register( + b"MemoryError", + lambda err: errors.BzrError( + "remote server out of memory\n" + "Retry non-remotely, or contact the server admin for details." + ), +) +no_context_error_translators.register( + b"RevisionNotPresent", + lambda err: errors.RevisionNotPresent( + err.error_args[0].decode("utf-8"), err.error_args[1].decode("utf-8") + ), +) + +no_context_error_translators.register( + b"BzrCheckError", + lambda err: errors.BzrCheckError(msg=err.error_args[0].decode("utf-8")), +) diff --git a/breezy/bzr/repository.py b/breezy/bzr/repository.py index 8f7231d0aa..a31ce2c9cf 100644 --- a/breezy/bzr/repository.py +++ b/breezy/bzr/repository.py @@ -33,13 +33,12 @@ class MetaDirRepository(Repository): _format: "RepositoryFormatMetaDir" def __init__(self, _format, a_bzrdir, control_files): - super().__init__( - _format, a_bzrdir, control_files) + super().__init__(_format, a_bzrdir, control_files) self._transport = control_files._transport def is_shared(self): """Return True if this repository is flagged as a shared repository.""" - return self._transport.has('shared-storage') + return self._transport.has("shared-storage") def set_make_working_trees(self, new_value): """Set the policy flag for making working trees when creating branches. @@ -53,17 +52,17 @@ def set_make_working_trees(self, new_value): with self.lock_write(): if new_value: try: - self._transport.delete('no-working-trees') + self._transport.delete("no-working-trees") except _mod_transport.NoSuchFile: pass else: self._transport.put_bytes( - 'no-working-trees', b'', - mode=self.controldir._get_file_mode()) + "no-working-trees", b"", mode=self.controldir._get_file_mode() + ) def make_working_trees(self): """Returns the policy for making working trees on new branches.""" - return not self._transport.has('no-working-trees') + return not self._transport.has("no-working-trees") def update_feature_flags(self, updated_flags): """Update the feature flags for this branch. @@ -73,8 +72,7 @@ def update_feature_flags(self, updated_flags): """ with self.lock_write(): self._format._update_feature_flags(updated_flags) - self.control_transport.put_bytes( - 'format', self._format.as_string()) + self.control_transport.put_bytes("format", self._format.as_string()) def _find_parent_ids_of_revisions(self, revision_ids): """Find all parent ids that are mentioned in the revision graph. 
@@ -82,8 +80,9 @@ def _find_parent_ids_of_revisions(self, revision_ids): :return: set of revisions that are parents of revision_ids which are not part of revision_ids themselves """ - parent_ids = set(itertools.chain.from_iterable( - self.get_parent_map(revision_ids).values())) + parent_ids = set( + itertools.chain.from_iterable(self.get_parent_map(revision_ids).values()) + ) parent_ids.difference_update(revision_ids) parent_ids.discard(_mod_revision.NULL_REVISION) return parent_ids @@ -113,8 +112,9 @@ def _create_control_files(self, a_bzrdir): # FIXME: RBC 20060125 don't peek under the covers # NB: no need to escape relative paths that are url safe. repository_transport = a_bzrdir.get_repository_transport(self) - control_files = lockable_files.LockableFiles(repository_transport, - 'lock', lockdir.LockDir) + control_files = lockable_files.LockableFiles( + repository_transport, "lock", lockdir.LockDir + ) control_files.create_lock() return control_files @@ -124,16 +124,18 @@ def _upload_blank_content(self, a_bzrdir, dirs, files, utf8_files, shared): control_files.lock_write() transport = control_files._transport if shared is True: - utf8_files += [('shared-storage', b'')] + utf8_files += [("shared-storage", b"")] try: for dir in dirs: transport.mkdir(dir, mode=a_bzrdir._get_dir_mode()) - for (filename, content_stream) in files: - transport.put_file(filename, content_stream, - mode=a_bzrdir._get_file_mode()) - for (filename, content_bytes) in utf8_files: - transport.put_bytes_non_atomic(filename, content_bytes, - mode=a_bzrdir._get_file_mode()) + for filename, content_stream in files: + transport.put_file( + filename, content_stream, mode=a_bzrdir._get_file_mode() + ) + for filename, content_bytes in utf8_files: + transport.put_bytes_non_atomic( + filename, content_bytes, mode=a_bzrdir._get_file_mode() + ) finally: control_files.unlock() @@ -150,12 +152,20 @@ def find_format(klass, a_bzrdir): format_string = transport.get_bytes("format") except _mod_transport.NoSuchFile as e: raise errors.NoRepositoryPresent(a_bzrdir) from e - return klass._find_format(format_registry, 'repository', format_string) - - def check_support_status(self, allow_unsupported, recommend_upgrade=True, - basedir=None): - RepositoryFormat.check_support_status(self, - allow_unsupported=allow_unsupported, recommend_upgrade=recommend_upgrade, - basedir=basedir) - bzrdir.BzrFormat.check_support_status(self, allow_unsupported=allow_unsupported, - recommend_upgrade=recommend_upgrade, basedir=basedir) + return klass._find_format(format_registry, "repository", format_string) + + def check_support_status( + self, allow_unsupported, recommend_upgrade=True, basedir=None + ): + RepositoryFormat.check_support_status( + self, + allow_unsupported=allow_unsupported, + recommend_upgrade=recommend_upgrade, + basedir=basedir, + ) + bzrdir.BzrFormat.check_support_status( + self, + allow_unsupported=allow_unsupported, + recommend_upgrade=recommend_upgrade, + basedir=basedir, + ) diff --git a/breezy/bzr/rio_patch.py b/breezy/bzr/rio_patch.py index 6e81b195de..799703aaf9 100644 --- a/breezy/bzr/rio_patch.py +++ b/breezy/bzr/rio_patch.py @@ -50,62 +50,60 @@ def to_patch_lines(stanza, max_width=72): max_rio_width = max_width - 4 lines = [] for pline in stanza.to_lines(): - for line in pline.split(b'\n')[:-1]: - line = re.sub(b'\\\\', b'\\\\\\\\', line) + for line in pline.split(b"\n")[:-1]: + line = re.sub(b"\\\\", b"\\\\\\\\", line) while len(line) > 0: partline = line[:max_rio_width] line = line[max_rio_width:] - if len(line) > 0 and line[:1] 
!= [b' ']: + if len(line) > 0 and line[:1] != [b" "]: break_index = -1 - break_index = partline.rfind(b' ', -20) + break_index = partline.rfind(b" ", -20) if break_index < 3: - break_index = partline.rfind(b'-', -20) + break_index = partline.rfind(b"-", -20) break_index += 1 if break_index < 3: - break_index = partline.rfind(b'/', -20) + break_index = partline.rfind(b"/", -20) if break_index >= 3: line = partline[break_index:] + line partline = partline[:break_index] if len(line) > 0: - line = b' ' + line - partline = re.sub(b'\r', b'\\\\r', partline) + line = b" " + line + partline = re.sub(b"\r", b"\\\\r", partline) blank_line = False if len(line) > 0: - partline += b'\\' - elif re.search(b' $', partline): - partline += b'\\' + partline += b"\\" + elif re.search(b" $", partline): + partline += b"\\" blank_line = True - lines.append(b'# ' + partline + b'\n') + lines.append(b"# " + partline + b"\n") if blank_line: - lines.append(b'# \n') + lines.append(b"# \n") return lines def _patch_stanza_iter(line_iter): - map = {b'\\\\': b'\\', - b'\\r': b'\r', - b'\\\n': b''} + map = {b"\\\\": b"\\", b"\\r": b"\r", b"\\\n": b""} def mapget(match): return map[match.group(0)] last_line = None for line in line_iter: - if line.startswith(b'# '): + if line.startswith(b"# "): line = line[2:] - elif line.startswith(b'#'): + elif line.startswith(b"#"): line = line[1:] else: raise ValueError(f"bad line {line!r}") if last_line is not None and len(line) > 2: line = line[2:] - line = re.sub(b'\r', b'', line) - line = re.sub(b'\\\\(.|\n)', mapget, line) + line = re.sub(b"\r", b"", line) + line = re.sub(b"\\\\(.|\n)", mapget, line) if last_line is None: last_line = line else: last_line += line - if last_line[-1:] == b'\n': + if last_line[-1:] == b"\n": yield last_line last_line = None if last_line is not None: diff --git a/breezy/bzr/serializer.py b/breezy/bzr/serializer.py index df32fa94d7..7889947407 100644 --- a/breezy/bzr/serializer.py +++ b/breezy/bzr/serializer.py @@ -20,12 +20,10 @@ class BadInventoryFormat(errors.BzrError): - _fmt = "Root class for inventory serialization errors" class UnexpectedInventoryFormat(BadInventoryFormat): - _fmt = "The inventory was not in the expected format:\n %(msg)s" def __init__(self, msg): @@ -33,7 +31,6 @@ def __init__(self, msg): class UnsupportedInventoryKind(errors.BzrError): - _fmt = """Unsupported entry kind %(kind)s""" def __init__(self, kind): @@ -93,8 +90,9 @@ def write_inventory_to_lines(self, inv): """ raise NotImplementedError(self.write_inventory_to_lines) - def read_inventory_from_lines(self, lines, revision_id=None, - entry_cache=None, return_from_cache=False): + def read_inventory_from_lines( + self, lines, revision_id=None, entry_cache=None, return_from_cache=False + ): """Read bytestring chunks into an inventory object. :param lines: The serialized inventory to read. 
@@ -125,17 +123,29 @@ class SerializerRegistry(registry.Registry): revision_format_registry = SerializerRegistry() -revision_format_registry.register_lazy('5', 'breezy._bzr_rs', 'revision_serializer_v5') -revision_format_registry.register_lazy('8', 'breezy._bzr_rs', 'revision_serializer_v8') -revision_format_registry.register_lazy('10', 'breezy._bzr_rs', 'revision_bencode_serializer') +revision_format_registry.register_lazy("5", "breezy._bzr_rs", "revision_serializer_v5") +revision_format_registry.register_lazy("8", "breezy._bzr_rs", "revision_serializer_v8") +revision_format_registry.register_lazy( + "10", "breezy._bzr_rs", "revision_bencode_serializer" +) inventory_format_registry = SerializerRegistry() -inventory_format_registry.register_lazy('5', 'breezy.bzr.xml5', 'inventory_serializer_v5') -inventory_format_registry.register_lazy('6', 'breezy.bzr.xml6', 'inventory_serializer_v6') -inventory_format_registry.register_lazy('7', 'breezy.bzr.xml7', 'inventory_serializer_v7') -inventory_format_registry.register_lazy('8', 'breezy.bzr.xml8', 'inventory_serializer_v8') -inventory_format_registry.register_lazy('9', 'breezy.bzr.chk_serializer', - 'inventory_chk_serializer_255_bigpage_9') -inventory_format_registry.register_lazy('10', 'breezy.bzr.chk_serializer', - 'inventory_chk_serializer_255_bigpage_10') +inventory_format_registry.register_lazy( + "5", "breezy.bzr.xml5", "inventory_serializer_v5" +) +inventory_format_registry.register_lazy( + "6", "breezy.bzr.xml6", "inventory_serializer_v6" +) +inventory_format_registry.register_lazy( + "7", "breezy.bzr.xml7", "inventory_serializer_v7" +) +inventory_format_registry.register_lazy( + "8", "breezy.bzr.xml8", "inventory_serializer_v8" +) +inventory_format_registry.register_lazy( + "9", "breezy.bzr.chk_serializer", "inventory_chk_serializer_255_bigpage_9" +) +inventory_format_registry.register_lazy( + "10", "breezy.bzr.chk_serializer", "inventory_chk_serializer_255_bigpage_10" +) diff --git a/breezy/bzr/smart/branch.py b/breezy/bzr/smart/branch.py index 86e31551d5..e09b2d1859 100644 --- a/breezy/bzr/smart/branch.py +++ b/breezy/bzr/smart/branch.py @@ -64,31 +64,32 @@ def do_with_branch(self, branch, branch_token, repo_token, *args): processed. The physical lock state won't be changed. """ # XXX: write a test for LockContention - with branch.repository.lock_write(token=repo_token), \ - branch.lock_write(token=branch_token): + with branch.repository.lock_write(token=repo_token), branch.lock_write( + token=branch_token + ): return self.do_with_locked_branch(branch, *args) class SmartServerBranchBreakLock(SmartServerBranchRequest): - def do_with_branch(self, branch): """Break a branch lock.""" branch.break_lock() - return SuccessfulSmartServerResponse((b'ok', ), ) + return SuccessfulSmartServerResponse( + (b"ok",), + ) class SmartServerBranchGetConfigFile(SmartServerBranchRequest): - def do_with_branch(self, branch): """Return the content of branch.conf. The body is not utf8 decoded - its the literal bytestream from disk. 
""" try: - content = branch.control_transport.get_bytes('branch.conf') + content = branch.control_transport.get_bytes("branch.conf") except _mod_transport.NoSuchFile: - content = b'' - return SuccessfulSmartServerResponse((b'ok', ), content) + content = b"" + return SuccessfulSmartServerResponse((b"ok",), content) class SmartServerBranchPutConfigFile(SmartServerBranchRequest): @@ -109,23 +110,21 @@ def do_with_branch(self, branch, branch_token, repo_token): return None def do_body(self, body_bytes): - with self._branch.repository.lock_write(token=self._repo_token), \ - self._branch.lock_write(token=self._branch_token): - self._branch.control_transport.put_bytes( - 'branch.conf', body_bytes) - return SuccessfulSmartServerResponse((b'ok', )) + with self._branch.repository.lock_write( + token=self._repo_token + ), self._branch.lock_write(token=self._branch_token): + self._branch.control_transport.put_bytes("branch.conf", body_bytes) + return SuccessfulSmartServerResponse((b"ok",)) class SmartServerBranchGetParent(SmartServerBranchRequest): - def do_with_branch(self, branch): """Return the parent of branch.""" - parent = branch._get_parent_location() or '' - return SuccessfulSmartServerResponse((parent.encode('utf-8'),)) + parent = branch._get_parent_location() or "" + return SuccessfulSmartServerResponse((parent.encode("utf-8"),)) class SmartServerBranchGetTagsBytes(SmartServerBranchRequest): - def do_with_branch(self, branch): """Return the _get_tags_bytes for a branch.""" bytes = branch._get_tags_bytes() @@ -133,10 +132,10 @@ def do_with_branch(self, branch): class SmartServerBranchSetTagsBytes(SmartServerLockedBranchRequest): - - def __init__(self, backing_transport, root_client_path='/', jail_root=None): + def __init__(self, backing_transport, root_client_path="/", jail_root=None): SmartServerLockedBranchRequest.__init__( - self, backing_transport, root_client_path, jail_root) + self, backing_transport, root_client_path, jail_root + ) self.locked = False def do_with_locked_branch(self, branch): @@ -169,7 +168,6 @@ def do_end(self): class SmartServerBranchHeadsToFetch(SmartServerBranchRequest): - def do_with_branch(self, branch): """Return the heads-to-fetch for a Branch as two bencoded lists. @@ -178,19 +176,16 @@ def do_with_branch(self, branch): New in 2.4. """ must_fetch, if_present_fetch = branch.heads_to_fetch() - return SuccessfulSmartServerResponse( - (list(must_fetch), list(if_present_fetch))) + return SuccessfulSmartServerResponse((list(must_fetch), list(if_present_fetch))) class SmartServerBranchRequestGetStackedOnURL(SmartServerBranchRequest): - def do_with_branch(self, branch): stacked_on_url = branch.get_stacked_on_url() - return SuccessfulSmartServerResponse((b'ok', stacked_on_url.encode('ascii'))) + return SuccessfulSmartServerResponse((b"ok", stacked_on_url.encode("ascii"))) class SmartServerRequestRevisionHistory(SmartServerBranchRequest): - def do_with_branch(self, branch): r"""Get the revision history for the branch. 
@@ -200,14 +195,15 @@ def do_with_branch(self, branch): with branch.lock_read(): graph = branch.repository.get_graph() stop_revisions = (None, _mod_revision.NULL_REVISION) - history = list(graph.iter_lefthand_ancestry( - branch.last_revision(), stop_revisions)) + history = list( + graph.iter_lefthand_ancestry(branch.last_revision(), stop_revisions) + ) return SuccessfulSmartServerResponse( - (b'ok', ), (b'\x00'.join(reversed(history)))) + (b"ok",), (b"\x00".join(reversed(history))) + ) class SmartServerBranchRequestLastRevisionInfo(SmartServerBranchRequest): - def do_with_branch(self, branch): """Return branch.last_revision_info(). @@ -215,11 +211,11 @@ def do_with_branch(self, branch): """ revno, last_revision = branch.last_revision_info() return SuccessfulSmartServerResponse( - (b'ok', str(revno).encode('ascii'), last_revision)) + (b"ok", str(revno).encode("ascii"), last_revision) + ) class SmartServerBranchRequestRevisionIdToRevno(SmartServerBranchRequest): - def do_with_branch(self, branch, revid): """Return branch.revision_id_to_revno(). @@ -230,13 +226,14 @@ def do_with_branch(self, branch, revid): try: dotted_revno = branch.revision_id_to_dotted_revno(revid) except errors.NoSuchRevision: - return FailedSmartServerResponse((b'NoSuchRevision', revid)) + return FailedSmartServerResponse((b"NoSuchRevision", revid)) except errors.GhostRevisionsHaveNoRevno as e: return FailedSmartServerResponse( - (b'GhostRevisionsHaveNoRevno', e.revision_id, - e.ghost_revision_id)) + (b"GhostRevisionsHaveNoRevno", e.revision_id, e.ghost_revision_id) + ) return SuccessfulSmartServerResponse( - (b'ok', ) + tuple([b'%d' % x for x in dotted_revno])) + (b"ok",) + tuple([b"%d" % x for x in dotted_revno]) + ) class SmartServerSetTipRequest(SmartServerLockedBranchRequest): @@ -250,8 +247,8 @@ def do_with_locked_branch(self, branch, *args): except errors.TipChangeRejected as e: msg = e.msg if isinstance(msg, str): - msg = msg.encode('utf-8') - return FailedSmartServerResponse((b'TipChangeRejected', msg)) + msg = msg.encode("utf-8") + return FailedSmartServerResponse((b"TipChangeRejected", msg)) class SmartServerBranchRequestSetConfigOption(SmartServerLockedBranchRequest): @@ -261,8 +258,10 @@ def do_with_locked_branch(self, branch, value, name, section): if not section: section = None branch._get_config().set_option( - value.decode('utf-8'), name.decode('utf-8'), - section.decode('utf-8') if section is not None else None) + value.decode("utf-8"), + name.decode("utf-8"), + section.decode("utf-8") if section is not None else None, + ) return SuccessfulSmartServerResponse(()) @@ -276,32 +275,32 @@ def do_with_locked_branch(self, branch, value_dict, name, section): utf8_dict = bencode.bdecode(value_dict) value_dict = {} for key, value in utf8_dict.items(): - value_dict[key.decode('utf8')] = value.decode('utf8') + value_dict[key.decode("utf8")] = value.decode("utf8") if not section: section = None else: - section = section.decode('utf-8') - branch._get_config().set_option(value_dict, name.decode('utf-8'), section) + section = section.decode("utf-8") + branch._get_config().set_option(value_dict, name.decode("utf-8"), section) return SuccessfulSmartServerResponse(()) class SmartServerBranchRequestSetLastRevision(SmartServerSetTipRequest): - def do_tip_change_with_locked_branch(self, branch, new_last_revision_id): - if new_last_revision_id == b'null:': + if new_last_revision_id == b"null:": branch.set_last_revision_info(0, new_last_revision_id) else: if not branch.repository.has_revision(new_last_revision_id): return 
FailedSmartServerResponse( - (b'NoSuchRevision', new_last_revision_id)) + (b"NoSuchRevision", new_last_revision_id) + ) branch.generate_revision_history(new_last_revision_id, None, None) - return SuccessfulSmartServerResponse((b'ok',)) + return SuccessfulSmartServerResponse((b"ok",)) class SmartServerBranchRequestSetLastRevisionEx(SmartServerSetTipRequest): - - def do_tip_change_with_locked_branch(self, branch, new_last_revision_id, - allow_divergence, allow_overwrite_descendant): + def do_tip_change_with_locked_branch( + self, branch, new_last_revision_id, allow_divergence, allow_overwrite_descendant + ): """Set the last revision of the branch. New in 1.6. @@ -330,20 +329,19 @@ def do_tip_change_with_locked_branch(self, branch, new_last_revision_id, graph = branch.repository.get_graph() if not allow_divergence or do_not_overwrite_descendant: relation = branch._revision_relations( - last_rev, new_last_revision_id, graph) - if relation == 'diverged' and not allow_divergence: - return FailedSmartServerResponse((b'Diverged',)) - if relation == 'a_descends_from_b' and do_not_overwrite_descendant: - return SuccessfulSmartServerResponse( - (b'ok', last_revno, last_rev)) + last_rev, new_last_revision_id, graph + ) + if relation == "diverged" and not allow_divergence: + return FailedSmartServerResponse((b"Diverged",)) + if relation == "a_descends_from_b" and do_not_overwrite_descendant: + return SuccessfulSmartServerResponse((b"ok", last_revno, last_rev)) new_revno = graph.find_distance_to_null( - new_last_revision_id, [(last_rev, last_revno)]) + new_last_revision_id, [(last_rev, last_revno)] + ) branch.set_last_revision_info(new_revno, new_last_revision_id) except errors.GhostRevisionsHaveNoRevno: - return FailedSmartServerResponse( - (b'NoSuchRevision', new_last_revision_id)) - return SuccessfulSmartServerResponse( - (b'ok', new_revno, new_last_revision_id)) + return FailedSmartServerResponse((b"NoSuchRevision", new_last_revision_id)) + return SuccessfulSmartServerResponse((b"ok", new_revno, new_last_revision_id)) class SmartServerBranchRequestSetLastRevisionInfo(SmartServerSetTipRequest): @@ -353,14 +351,12 @@ class SmartServerBranchRequestSetLastRevisionInfo(SmartServerSetTipRequest): New in breezy 1.4. 
""" - def do_tip_change_with_locked_branch(self, branch, new_revno, - new_last_revision_id): + def do_tip_change_with_locked_branch(self, branch, new_revno, new_last_revision_id): try: branch.set_last_revision_info(int(new_revno), new_last_revision_id) except errors.NoSuchRevision: - return FailedSmartServerResponse( - (b'NoSuchRevision', new_last_revision_id)) - return SuccessfulSmartServerResponse((b'ok',)) + return FailedSmartServerResponse((b"NoSuchRevision", new_last_revision_id)) + return SuccessfulSmartServerResponse((b"ok",)) class SmartServerBranchRequestSetParentLocation(SmartServerLockedBranchRequest): @@ -370,57 +366,54 @@ class SmartServerBranchRequestSetParentLocation(SmartServerLockedBranchRequest): """ def do_with_locked_branch(self, branch, location): - branch._set_parent_location(location.decode('utf-8')) + branch._set_parent_location(location.decode("utf-8")) return SuccessfulSmartServerResponse(()) class SmartServerBranchRequestLockWrite(SmartServerBranchRequest): - - def do_with_branch(self, branch, branch_token=b'', repo_token=b''): - if branch_token == b'': + def do_with_branch(self, branch, branch_token=b"", repo_token=b""): + if branch_token == b"": branch_token = None - if repo_token == b'': + if repo_token == b"": repo_token = None try: - repo_token = branch.repository.lock_write( - token=repo_token).repository_token + repo_token = branch.repository.lock_write(token=repo_token).repository_token try: - branch_token = branch.lock_write( - token=branch_token).token + branch_token = branch.lock_write(token=branch_token).token finally: # this leaves the repository with 1 lock branch.repository.unlock() except errors.LockContention: - return FailedSmartServerResponse((b'LockContention',)) + return FailedSmartServerResponse((b"LockContention",)) except errors.TokenMismatch: - return FailedSmartServerResponse((b'TokenMismatch',)) + return FailedSmartServerResponse((b"TokenMismatch",)) except errors.UnlockableTransport: - return FailedSmartServerResponse((b'UnlockableTransport',)) + return FailedSmartServerResponse((b"UnlockableTransport",)) except errors.LockFailed as e: - return FailedSmartServerResponse((b'LockFailed', - str(e.lock).encode('utf-8'), str(e.why).encode('utf-8'))) + return FailedSmartServerResponse( + (b"LockFailed", str(e.lock).encode("utf-8"), str(e.why).encode("utf-8")) + ) if repo_token is None: - repo_token = b'' + repo_token = b"" else: branch.repository.leave_lock_in_place() branch.leave_lock_in_place() branch.unlock() - return SuccessfulSmartServerResponse((b'ok', branch_token, repo_token)) + return SuccessfulSmartServerResponse((b"ok", branch_token, repo_token)) class SmartServerBranchRequestUnlock(SmartServerBranchRequest): - def do_with_branch(self, branch, branch_token, repo_token): try: with branch.repository.lock_write(token=repo_token): branch.lock_write(token=branch_token) except errors.TokenMismatch: - return FailedSmartServerResponse((b'TokenMismatch',)) + return FailedSmartServerResponse((b"TokenMismatch",)) if repo_token: branch.repository.dont_leave_lock_in_place() branch.dont_leave_lock_in_place() branch.unlock() - return SuccessfulSmartServerResponse((b'ok',)) + return SuccessfulSmartServerResponse((b"ok",)) class SmartServerBranchRequestGetPhysicalLockStatus(SmartServerBranchRequest): @@ -431,9 +424,9 @@ class SmartServerBranchRequestGetPhysicalLockStatus(SmartServerBranchRequest): def do_with_branch(self, branch): if branch.get_physical_lock_status(): - return SuccessfulSmartServerResponse((b'yes',)) + return 
SuccessfulSmartServerResponse((b"yes",)) else: - return SuccessfulSmartServerResponse((b'no',)) + return SuccessfulSmartServerResponse((b"no",)) class SmartServerBranchRequestGetAllReferenceInfo(SmartServerBranchRequest): @@ -444,7 +437,14 @@ class SmartServerBranchRequestGetAllReferenceInfo(SmartServerBranchRequest): def do_with_branch(self, branch): all_reference_info = branch._get_all_reference_info() - content = bencode.bencode([ - (key, value[0].encode('utf-8'), value[1].encode('utf-8') if value[1] else b'') - for (key, value) in all_reference_info.items()]) - return SuccessfulSmartServerResponse((b'ok', ), content) + content = bencode.bencode( + [ + ( + key, + value[0].encode("utf-8"), + value[1].encode("utf-8") if value[1] else b"", + ) + for (key, value) in all_reference_info.items() + ] + ) + return SuccessfulSmartServerResponse((b"ok",), content) diff --git a/breezy/bzr/smart/bzrdir.py b/breezy/bzr/smart/bzrdir.py index c66fece086..213a224030 100644 --- a/breezy/bzr/smart/bzrdir.py +++ b/breezy/bzr/smart/bzrdir.py @@ -30,7 +30,6 @@ class SmartServerRequestOpenBzrDir(SmartServerRequest): - def do(self, path): try: t = self.transport_from_client_path(path) @@ -40,20 +39,19 @@ def do(self, path): # Ideally we'd return a FailedSmartServerResponse here rather than # a "successful" negative, but we want to be compatibile with # clients that don't anticipate errors from this method. - answer = b'no' + answer = b"no" else: bzr_prober = BzrProber() try: bzr_prober.probe_transport(t) except (errors.NotBranchError, errors.UnknownFormatError): - answer = b'no' + answer = b"no" else: - answer = b'yes' + answer = b"yes" return SuccessfulSmartServerResponse((answer,)) class SmartServerRequestOpenBzrDir_2_1(SmartServerRequest): - def do(self, path): """Is there a BzrDir present, and if so does it have a working tree? @@ -64,60 +62,56 @@ def do(self, path): except errors.PathNotChild: # The client is trying to ask about a path that they have no access # to. 
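The Branch.get_all_reference_info handler above serialises its result with bencode before returning it as the response body. A sketch of what that body looks like on the wire, using a minimal hand-rolled bencoder and made-up reference data (Breezy ships its own bencode module; the str key here is only for readability):

```python
# Minimal bencode encoder (sketch) showing the shape of the reference-info
# response body; not Breezy's bencode module, and the data below is invented.
def bencode(obj):
    if isinstance(obj, bytes):
        return b"%d:%s" % (len(obj), obj)
    if isinstance(obj, int):
        return b"i%de" % obj
    if isinstance(obj, (list, tuple)):
        return b"l" + b"".join(bencode(item) for item in obj) + b"e"
    raise TypeError(f"cannot bencode {obj!r}")

reference_info = {"subtree": ("http://example.com/trunk", None)}
content = bencode(
    [
        (key.encode("utf-8"), url.encode("utf-8"), extra.encode("utf-8") if extra else b"")
        for key, (url, extra) in reference_info.items()
    ]
)
print(content)  # b'll7:subtree24:http://example.com/trunk0:ee'
```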
- return SuccessfulSmartServerResponse((b'no',)) + return SuccessfulSmartServerResponse((b"no",)) try: bd = BzrDir.open_from_transport(t) except errors.NotBranchError: - answer = (b'no',) + answer = (b"no",) else: - answer = (b'yes',) + answer = (b"yes",) if bd.has_workingtree(): - answer += (b'yes',) + answer += (b"yes",) else: - answer += (b'no',) + answer += (b"no",) return SuccessfulSmartServerResponse(answer) class SmartServerRequestBzrDir(SmartServerRequest): - def do(self, path, *args): """Open a BzrDir at path, and return `self.do_bzrdir_request(*args)`.""" try: self._bzrdir = BzrDir.open_from_transport( - self.transport_from_client_path(path)) + self.transport_from_client_path(path) + ) except errors.NotBranchError: - return FailedSmartServerResponse((b'nobranch',)) + return FailedSmartServerResponse((b"nobranch",)) return self.do_bzrdir_request(*args) def _boolean_to_yes_no(self, a_boolean): if a_boolean: - return b'yes' + return b"yes" else: - return b'no' + return b"no" def _format_to_capabilities(self, repo_format): rich_root = self._boolean_to_yes_no(repo_format.rich_root_data) - tree_ref = self._boolean_to_yes_no( - repo_format.supports_tree_reference) - external_lookup = self._boolean_to_yes_no( - repo_format.supports_external_lookups) + tree_ref = self._boolean_to_yes_no(repo_format.supports_tree_reference) + external_lookup = self._boolean_to_yes_no(repo_format.supports_external_lookups) return rich_root, tree_ref, external_lookup def _repo_relpath(self, current_transport, repository): """Get the relative path for repository from current_transport.""" # the relpath of the bzrdir in the found repository gives us the # path segments to pop-out. - relpath = repository.user_transport.relpath( - current_transport.base) + relpath = repository.user_transport.relpath(current_transport.base) if len(relpath): - segments = ['..'] * len(relpath.split('/')) + segments = [".."] * len(relpath.split("/")) else: segments = [] - return '/'.join(segments) + return "/".join(segments) class SmartServerBzrDirRequestDestroyBranch(SmartServerRequestBzrDir): - def do_bzrdir_request(self, name=None): """Destroy the branch with the specified name. @@ -126,14 +120,14 @@ def do_bzrdir_request(self, name=None): """ try: self._bzrdir.destroy_branch( - name.decode('utf-8') if name is not None else None) + name.decode("utf-8") if name is not None else None + ) except errors.NotBranchError: - return FailedSmartServerResponse((b'nobranch',)) - return SuccessfulSmartServerResponse((b'ok',)) + return FailedSmartServerResponse((b"nobranch",)) + return SuccessfulSmartServerResponse((b"ok",)) class SmartServerBzrDirRequestHasWorkingTree(SmartServerRequestBzrDir): - def do_bzrdir_request(self, name=None): """Check whether there is a working tree present. @@ -143,13 +137,12 @@ def do_bzrdir_request(self, name=None): Otherwise 'no'. """ if self._bzrdir.has_workingtree(): - return SuccessfulSmartServerResponse((b'yes', )) + return SuccessfulSmartServerResponse((b"yes",)) else: - return SuccessfulSmartServerResponse((b'no', )) + return SuccessfulSmartServerResponse((b"no",)) class SmartServerBzrDirRequestDestroyRepository(SmartServerRequestBzrDir): - def do_bzrdir_request(self, name=None): """Destroy the repository. 
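_repo_relpath above turns the bzrdir's position inside the found repository into a "../.." style path that leads from the bzrdir back up to the repository root. The same arithmetic, lifted out as a standalone function over a plain relpath string (the transport objects are left out):

```python
# The path arithmetic from _repo_relpath above, over a plain relpath string:
# the bzrdir's relative path inside the repository becomes the '../..' path
# back to the repository root.
def repo_relpath_from(relpath_inside_repo):
    if len(relpath_inside_repo):
        segments = [".."] * len(relpath_inside_repo.split("/"))
    else:
        segments = []
    return "/".join(segments)

print(repr(repo_relpath_from("")))               # ''    - bzrdir is the repo root
print(repr(repo_relpath_from("project")))        # '..'
print(repr(repo_relpath_from("project/trunk")))  # '../..'
```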
@@ -160,12 +153,11 @@ def do_bzrdir_request(self, name=None): try: self._bzrdir.destroy_repository() except errors.NoRepositoryPresent: - return FailedSmartServerResponse((b'norepository',)) - return SuccessfulSmartServerResponse((b'ok',)) + return FailedSmartServerResponse((b"norepository",)) + return SuccessfulSmartServerResponse((b"ok",)) class SmartServerBzrDirRequestCloningMetaDir(SmartServerRequestBzrDir): - def do_bzrdir_request(self, require_stacking): """Get the format that should be used when cloning from this dir. @@ -185,24 +177,23 @@ def do_bzrdir_request(self, require_stacking): # The server shouldn't try to resolve references, and it quite # possibly can't reach them anyway. The client needs to resolve # the branch reference to determine the cloning_metadir. - return FailedSmartServerResponse((b'BranchReference',)) + return FailedSmartServerResponse((b"BranchReference",)) if require_stacking == b"True": require_stacking = True else: require_stacking = False - control_format = self._bzrdir.cloning_metadir( - require_stacking=require_stacking) + control_format = self._bzrdir.cloning_metadir(require_stacking=require_stacking) control_name = control_format.network_name() if not control_format.fixed_components: - branch_name = (b'branch', - control_format.get_branch_format().network_name()) + branch_name = (b"branch", control_format.get_branch_format().network_name()) repository_name = control_format.repository_format.network_name() else: # Only MetaDir has delegated formats today. - branch_name = (b'branch', b'') - repository_name = b'' - return SuccessfulSmartServerResponse((control_name, repository_name, - branch_name)) + branch_name = (b"branch", b"") + repository_name = b"" + return SuccessfulSmartServerResponse( + (control_name, repository_name, branch_name) + ) class SmartServerBzrDirRequestCheckoutMetaDir(SmartServerRequestBzrDir): @@ -227,21 +218,19 @@ def do_bzrdir_request(self): # The server shouldn't try to resolve references, and it quite # possibly can't reach them anyway. The client needs to resolve # the branch reference to determine the cloning_metadir. - return FailedSmartServerResponse((b'BranchReference',)) + return FailedSmartServerResponse((b"BranchReference",)) control_format = self._bzrdir.checkout_metadir() control_name = control_format.network_name() if not control_format.fixed_components: branch_name = control_format.get_branch_format().network_name() repo_name = control_format.repository_format.network_name() else: - branch_name = b'' - repo_name = b'' - return SuccessfulSmartServerResponse( - (control_name, repo_name, branch_name)) + branch_name = b"" + repo_name = b"" + return SuccessfulSmartServerResponse((control_name, repo_name, branch_name)) class SmartServerRequestCreateBranch(SmartServerRequestBzrDir): - def do(self, path, network_name): """Create a branch in the bzr dir at path. 
@@ -260,25 +249,32 @@ def do(self, path, network_name): :return: ('ok', branch_format, repo_path, rich_root, tree_ref, external_lookup, repo_format) """ - bzrdir = BzrDir.open_from_transport( - self.transport_from_client_path(path)) + bzrdir = BzrDir.open_from_transport(self.transport_from_client_path(path)) format = branch.network_format_registry.get(network_name) bzrdir.branch_format = format result = format.initialize(bzrdir, name="") rich_root, tree_ref, external_lookup = self._format_to_capabilities( - result.repository._format) + result.repository._format + ) branch_format = result._format.network_name() repo_format = result.repository._format.network_name() - repo_path = self._repo_relpath(bzrdir.root_transport, - result.repository) + repo_path = self._repo_relpath(bzrdir.root_transport, result.repository) # branch format, repo relpath, rich_root, tree_ref, external_lookup, # repo_network_name - return SuccessfulSmartServerResponse((b'ok', branch_format, repo_path, - rich_root, tree_ref, external_lookup, repo_format)) + return SuccessfulSmartServerResponse( + ( + b"ok", + branch_format, + repo_path, + rich_root, + tree_ref, + external_lookup, + repo_format, + ) + ) class SmartServerRequestCreateRepository(SmartServerRequestBzrDir): - def do(self, path, network_name, shared): """Create a repository in the bzr dir at path. @@ -298,20 +294,20 @@ def do(self, path, network_name, shared): parameter. :return: (ok, rich_root, tree_ref, external_lookup, network_name) """ - bzrdir = BzrDir.open_from_transport( - self.transport_from_client_path(path)) - shared = shared == b'True' + bzrdir = BzrDir.open_from_transport(self.transport_from_client_path(path)) + shared = shared == b"True" format = repository.network_format_registry.get(network_name) bzrdir.repository_format = format result = format.initialize(bzrdir, shared=shared) rich_root, tree_ref, external_lookup = self._format_to_capabilities( - result._format) - return SuccessfulSmartServerResponse((b'ok', rich_root, tree_ref, - external_lookup, result._format.network_name())) + result._format + ) + return SuccessfulSmartServerResponse( + (b"ok", rich_root, tree_ref, external_lookup, result._format.network_name()) + ) class SmartServerRequestFindRepository(SmartServerRequestBzrDir): - def _find(self, path): """Try to find a repository from path upwards. @@ -324,18 +320,17 @@ def _find(self, path): :raises errors.NoRepositoryPresent: When there is no repository present. """ - bzrdir = BzrDir.open_from_transport( - self.transport_from_client_path(path)) + bzrdir = BzrDir.open_from_transport(self.transport_from_client_path(path)) repository = bzrdir.find_repository() path = self._repo_relpath(bzrdir.root_transport, repository) rich_root, tree_ref, external_lookup = self._format_to_capabilities( - repository._format) + repository._format + ) network_name = repository._format.network_name() return path, rich_root, tree_ref, external_lookup, network_name class SmartServerRequestFindRepositoryV1(SmartServerRequestFindRepository): - def do(self, path): """Try to find a repository from path upwards. 
@@ -352,13 +347,14 @@ def do(self, path): """ try: path, rich_root, tree_ref, external_lookup, name = self._find(path) - return SuccessfulSmartServerResponse((b'ok', path.encode('utf-8'), rich_root, tree_ref)) + return SuccessfulSmartServerResponse( + (b"ok", path.encode("utf-8"), rich_root, tree_ref) + ) except errors.NoRepositoryPresent: - return FailedSmartServerResponse((b'norepository', )) + return FailedSmartServerResponse((b"norepository",)) class SmartServerRequestFindRepositoryV2(SmartServerRequestFindRepository): - def do(self, path): """Try to find a repository from path upwards. @@ -377,13 +373,13 @@ def do(self, path): try: path, rich_root, tree_ref, external_lookup, name = self._find(path) return SuccessfulSmartServerResponse( - (b'ok', path.encode('utf-8'), rich_root, tree_ref, external_lookup)) + (b"ok", path.encode("utf-8"), rich_root, tree_ref, external_lookup) + ) except errors.NoRepositoryPresent: - return FailedSmartServerResponse((b'norepository', )) + return FailedSmartServerResponse((b"norepository",)) class SmartServerRequestFindRepositoryV3(SmartServerRequestFindRepository): - def do(self, path): """Try to find a repository from path upwards. @@ -401,13 +397,20 @@ def do(self, path): try: path, rich_root, tree_ref, external_lookup, name = self._find(path) return SuccessfulSmartServerResponse( - (b'ok', path.encode('utf-8'), rich_root, tree_ref, external_lookup, name)) + ( + b"ok", + path.encode("utf-8"), + rich_root, + tree_ref, + external_lookup, + name, + ) + ) except errors.NoRepositoryPresent: - return FailedSmartServerResponse((b'norepository', )) + return FailedSmartServerResponse((b"norepository",)) class SmartServerBzrDirRequestConfigFile(SmartServerRequestBzrDir): - def do_bzrdir_request(self): """Get the configuration bytes for a config file in bzrdir. @@ -415,14 +418,13 @@ def do_bzrdir_request(self): """ config = self._bzrdir._get_config() if config is None: - content = b'' + content = b"" else: content = config._get_config_file().read() return SuccessfulSmartServerResponse((), content) class SmartServerBzrDirRequestGetBranches(SmartServerRequestBzrDir): - def do_bzrdir_request(self): """Get the branches in a control directory. @@ -437,17 +439,15 @@ def do_bzrdir_request(self): branch_ref = self._bzrdir.get_branch_reference(name=name) if branch_ref is not None: branch_ref = urlutils.relative_url(self._bzrdir.user_url, branch_ref) - value = (b"ref", branch_ref.encode('utf-8')) + value = (b"ref", branch_ref.encode("utf-8")) else: b = self._bzrdir.open_branch(name=name, ignore_fallbacks=True) value = (b"branch", b._format.network_name()) - ret[name.encode('utf-8')] = value - return SuccessfulSmartServerResponse( - (b"success", ), bencode.bencode(ret)) + ret[name.encode("utf-8")] = value + return SuccessfulSmartServerResponse((b"success",), bencode.bencode(ret)) class SmartServerRequestInitializeBzrDir(SmartServerRequest): - def do(self, path): """Initialize a bzrdir at path. 
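The three find_repository verbs above differ only in how many fields their success tuple carries: V1 stops at tree_ref, V2 adds external_lookup, and V3 adds the repository format's network name. A client-side sketch that normalises the three shapes; the parsing is illustrative rather than Breezy's real client code, and the format name below is a placeholder:

```python
# Sketch: normalise the three find_repository success tuples into one dict.
# Field names follow the docstrings above.
def parse_find_repository_response(response):
    if response[0] != b"ok":
        raise ValueError(f"unexpected response: {response!r}")
    fields = ("path", "rich_root", "tree_ref", "external_lookup", "network_name")
    parsed = dict(zip(fields, response[1:]))
    parsed.setdefault("external_lookup", None)   # absent in V1
    parsed.setdefault("network_name", None)      # absent in V1 and V2
    return parsed

v1 = (b"ok", b".", b"yes", b"no")
v3 = (b"ok", b"..", b"yes", b"no", b"yes", b"<repository format network name>")
print(parse_find_repository_response(v1))
print(parse_find_repository_response(v3))
```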
@@ -456,17 +456,16 @@ def do(self, path): """ target_transport = self.transport_from_client_path(path) BzrDirFormat.get_default_format().initialize_on_transport(target_transport) - return SuccessfulSmartServerResponse((b'ok', )) + return SuccessfulSmartServerResponse((b"ok",)) class SmartServerRequestBzrDirInitializeEx(SmartServerRequestBzrDir): - def parse_NoneTrueFalse(self, arg): if not arg: return None - if arg == b'False': + if arg == b"False": return False - if arg == b'True': + if arg == b"True": return True raise AssertionError(f"invalid arg {arg!r}") @@ -476,18 +475,28 @@ def parse_NoneBytestring(self, arg): def parse_NoneString(self, arg): if not arg: return None - return arg.decode('utf-8') + return arg.decode("utf-8") def _serialize_NoneTrueFalse(self, arg): if arg is False: - return b'False' + return b"False" if not arg: - return b'' - return b'True' - - def do(self, bzrdir_network_name, path, use_existing_dir, create_prefix, - force_new_repo, stacked_on, stack_on_pwd, repo_format_name, - make_working_trees, shared_repo): + return b"" + return b"True" + + def do( + self, + bzrdir_network_name, + path, + use_existing_dir, + create_prefix, + force_new_repo, + stacked_on, + stack_on_pwd, + repo_format_name, + make_working_trees, + shared_repo, + ): """Initialize a bzrdir at path as per BzrDirFormat.initialize_on_transport_ex. @@ -508,29 +517,35 @@ def do(self, bzrdir_network_name, path, use_existing_dir, create_prefix, stack_on_pwd = self.parse_NoneString(stack_on_pwd) make_working_trees = self.parse_NoneTrueFalse(make_working_trees) shared_repo = self.parse_NoneTrueFalse(shared_repo) - if stack_on_pwd == b'.': - stack_on_pwd = target_transport.base.encode('utf-8') + if stack_on_pwd == b".": + stack_on_pwd = target_transport.base.encode("utf-8") repo_format_name = self.parse_NoneBytestring(repo_format_name) - repo, bzrdir, stacking, repository_policy = \ - format.initialize_on_transport_ex(target_transport, - use_existing_dir=use_existing_dir, create_prefix=create_prefix, - force_new_repo=force_new_repo, stacked_on=stacked_on, - stack_on_pwd=stack_on_pwd, repo_format_name=repo_format_name, - make_working_trees=make_working_trees, shared_repo=shared_repo) + repo, bzrdir, stacking, repository_policy = format.initialize_on_transport_ex( + target_transport, + use_existing_dir=use_existing_dir, + create_prefix=create_prefix, + force_new_repo=force_new_repo, + stacked_on=stacked_on, + stack_on_pwd=stack_on_pwd, + repo_format_name=repo_format_name, + make_working_trees=make_working_trees, + shared_repo=shared_repo, + ) if repo is None: - repo_path = '' - repo_name = b'' - rich_root = tree_ref = external_lookup = b'' - repo_bzrdir_name = b'' + repo_path = "" + repo_name = b"" + rich_root = tree_ref = external_lookup = b"" + repo_bzrdir_name = b"" final_stack = None final_stack_pwd = None - repo_lock_token = b'' + repo_lock_token = b"" else: repo_path = self._repo_relpath(bzrdir.root_transport, repo) - if repo_path == '': - repo_path = '.' + if repo_path == "": + repo_path = "." rich_root, tree_ref, external_lookup = self._format_to_capabilities( - repo._format) + repo._format + ) repo_name = repo._format.network_name() repo_bzrdir_name = repo.controldir._format.network_name() final_stack = repository_policy._stack_on @@ -538,48 +553,55 @@ def do(self, bzrdir_network_name, path, use_existing_dir, create_prefix, # It is returned locked, but we need to do the lock to get the lock # token. 
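The BzrDirInitializeEx request above passes several tri-state flags over the wire using the parse_NoneTrueFalse / _serialize_NoneTrueFalse pair shown in this hunk. Reproduced standalone so the round trip is easy to check:

```python
# The tri-state (None/True/False) wire encoding from the hunk above,
# standalone, with a round-trip check over all three values.
def parse_NoneTrueFalse(arg):
    if not arg:
        return None
    if arg == b"False":
        return False
    if arg == b"True":
        return True
    raise AssertionError(f"invalid arg {arg!r}")

def serialize_NoneTrueFalse(arg):
    if arg is False:
        return b"False"
    if not arg:
        return b""
    return b"True"

for value in (None, True, False):
    wire = serialize_NoneTrueFalse(value)
    assert parse_NoneTrueFalse(wire) is value, (value, wire)
    print(value, "->", wire)
```

An empty argument decodes to None, which is how callers leave a flag unspecified.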
repo.unlock() - repo_lock_token = repo.lock_write().repository_token or b'' + repo_lock_token = repo.lock_write().repository_token or b"" if repo_lock_token: repo.leave_lock_in_place() repo.unlock() - final_stack = final_stack or '' - final_stack_pwd = final_stack_pwd or '' + final_stack = final_stack or "" + final_stack_pwd = final_stack_pwd or "" # We want this to be relative to the bzrdir. if final_stack_pwd: final_stack_pwd = urlutils.relative_url( - target_transport.base, final_stack_pwd) + target_transport.base, final_stack_pwd + ) # Can't meaningfully return a root path. - if final_stack.startswith('/'): + if final_stack.startswith("/"): client_path = self._root_client_path + final_stack[1:] - final_stack = urlutils.relative_url( - self._root_client_path, client_path) - final_stack_pwd = '.' + final_stack = urlutils.relative_url(self._root_client_path, client_path) + final_stack_pwd = "." - return SuccessfulSmartServerResponse((repo_path.encode('utf-8'), - rich_root, tree_ref, external_lookup, repo_name, repo_bzrdir_name, - bzrdir._format.network_name(), - self._serialize_NoneTrueFalse( - stacking), final_stack.encode('utf-8'), - final_stack_pwd.encode('utf-8'), repo_lock_token)) + return SuccessfulSmartServerResponse( + ( + repo_path.encode("utf-8"), + rich_root, + tree_ref, + external_lookup, + repo_name, + repo_bzrdir_name, + bzrdir._format.network_name(), + self._serialize_NoneTrueFalse(stacking), + final_stack.encode("utf-8"), + final_stack_pwd.encode("utf-8"), + repo_lock_token, + ) + ) class SmartServerRequestOpenBranch(SmartServerRequestBzrDir): - def do_bzrdir_request(self): """Open a branch at path and return the branch reference or branch.""" try: reference_url = self._bzrdir.get_branch_reference() if reference_url is None: - reference_url = '' - return SuccessfulSmartServerResponse((b'ok', reference_url.encode('utf-8'))) + reference_url = "" + return SuccessfulSmartServerResponse((b"ok", reference_url.encode("utf-8"))) except errors.NotBranchError: - return FailedSmartServerResponse((b'nobranch',)) + return FailedSmartServerResponse((b"nobranch",)) class SmartServerRequestOpenBranchV2(SmartServerRequestBzrDir): - def do_bzrdir_request(self): """Open a branch at path and return the reference or format.""" try: @@ -587,15 +609,16 @@ def do_bzrdir_request(self): if reference_url is None: br = self._bzrdir.open_branch(ignore_fallbacks=True) format = br._format.network_name() - return SuccessfulSmartServerResponse((b'branch', format)) + return SuccessfulSmartServerResponse((b"branch", format)) else: - return SuccessfulSmartServerResponse((b'ref', reference_url.encode('utf-8'))) + return SuccessfulSmartServerResponse( + (b"ref", reference_url.encode("utf-8")) + ) except errors.NotBranchError: - return FailedSmartServerResponse((b'nobranch',)) + return FailedSmartServerResponse((b"nobranch",)) class SmartServerRequestOpenBranchV3(SmartServerRequestBzrDir): - def do_bzrdir_request(self): """Open a branch at path and return the reference or format. 
@@ -611,17 +634,19 @@ def do_bzrdir_request(self): if reference_url is None: br = self._bzrdir.open_branch(ignore_fallbacks=True) format = br._format.network_name() - return SuccessfulSmartServerResponse((b'branch', format)) + return SuccessfulSmartServerResponse((b"branch", format)) else: - return SuccessfulSmartServerResponse((b'ref', reference_url.encode('utf-8'))) + return SuccessfulSmartServerResponse( + (b"ref", reference_url.encode("utf-8")) + ) except errors.NotBranchError as e: # Stringify the exception so that its .detail attribute will be # filled out. str(e) - resp = (b'nobranch',) + resp = (b"nobranch",) detail = e.detail if detail: - if detail.startswith(': '): + if detail.startswith(": "): detail = detail[2:] - resp += (detail.encode('utf-8'),) + resp += (detail.encode("utf-8"),) return FailedSmartServerResponse(resp) diff --git a/breezy/bzr/smart/client.py b/breezy/bzr/smart/client.py index a17bec711d..3bc96dcf69 100644 --- a/breezy/bzr/smart/client.py +++ b/breezy/bzr/smart/client.py @@ -21,7 +21,6 @@ class _SmartClient: - def __init__(self, medium, headers=None): """Constructor. @@ -29,19 +28,31 @@ def __init__(self, medium, headers=None): """ self._medium = medium if headers is None: - self._headers = { - b'Software version': breezy.__version__.encode('utf-8')} + self._headers = {b"Software version": breezy.__version__.encode("utf-8")} else: self._headers = dict(headers) def __repr__(self): - return f'{self.__class__.__name__}({self._medium!r})' - - def _call_and_read_response(self, method, args, body=None, readv_body=None, - body_stream=None, expect_response_body=True): - request = _SmartClientRequest(self, method, args, body=body, - readv_body=readv_body, body_stream=body_stream, - expect_response_body=expect_response_body) + return f"{self.__class__.__name__}({self._medium!r})" + + def _call_and_read_response( + self, + method, + args, + body=None, + readv_body=None, + body_stream=None, + expect_response_body=True, + ): + request = _SmartClientRequest( + self, + method, + args, + body=body, + readv_body=readv_body, + body_stream=body_stream, + expect_response_body=expect_response_body, + ) return request.call_and_read_response() def call(self, method, *args): @@ -58,44 +69,46 @@ def call_expecting_body(self, method, *args): result, smart_protocol = smart_client.call_expecting_body(...) 
body = smart_protocol.read_body_bytes() """ - return self._call_and_read_response( - method, args, expect_response_body=True) + return self._call_and_read_response(method, args, expect_response_body=True) def call_with_body_bytes(self, method, args, body): """Call a method on the remote server with body bytes.""" if not isinstance(method, bytes): - raise TypeError(f'method must be a byte string, not {method!r}') + raise TypeError(f"method must be a byte string, not {method!r}") for arg in args: if not isinstance(arg, bytes): - raise TypeError(f'args must be byte strings, not {args!r}') + raise TypeError(f"args must be byte strings, not {args!r}") if not isinstance(body, bytes): - raise TypeError(f'body must be byte string, not {body!r}') + raise TypeError(f"body must be byte string, not {body!r}") response, response_handler = self._call_and_read_response( - method, args, body=body, expect_response_body=False) + method, args, body=body, expect_response_body=False + ) return response def call_with_body_bytes_expecting_body(self, method, args, body): """Call a method on the remote server with body bytes.""" if not isinstance(method, bytes): - raise TypeError(f'method must be a byte string, not {method!r}') + raise TypeError(f"method must be a byte string, not {method!r}") for arg in args: if not isinstance(arg, bytes): - raise TypeError(f'args must be byte strings, not {args!r}') + raise TypeError(f"args must be byte strings, not {args!r}") if not isinstance(body, bytes): - raise TypeError(f'body must be byte string, not {body!r}') + raise TypeError(f"body must be byte string, not {body!r}") response, response_handler = self._call_and_read_response( - method, args, body=body, expect_response_body=True) + method, args, body=body, expect_response_body=True + ) return (response, response_handler) def call_with_body_readv_array(self, args, body): response, response_handler = self._call_and_read_response( - args[0], args[1:], readv_body=body, expect_response_body=True) + args[0], args[1:], readv_body=body, expect_response_body=True + ) return (response, response_handler) def call_with_body_stream(self, args, stream): response, response_handler = self._call_and_read_response( - args[0], args[1:], body_stream=stream, - expect_response_body=False) + args[0], args[1:], body_stream=stream, expect_response_body=False + ) return (response, response_handler) def remote_path_from_transport(self, transport): @@ -105,7 +118,7 @@ def remote_path_from_transport(self, transport): anything but path, so it is only safe to use it in requests sent over the medium from the matching transport. """ - return self._medium.remote_path_from_transport(transport).encode('utf-8') + return self._medium.remote_path_from_transport(transport).encode("utf-8") class _SmartClientRequest: @@ -121,8 +134,16 @@ class _SmartClientRequest: get the response from the server. """ - def __init__(self, client, method, args, body=None, readv_body=None, - body_stream=None, expect_response_body=True): + def __init__( + self, + client, + method, + args, + body=None, + readv_body=None, + body_stream=None, + expect_response_body=True, + ): self.client = client self.method = method self.args = args @@ -148,26 +169,28 @@ def call_and_read_response(self): def _is_safe_to_send_twice(self): """Check if the current method is re-entrant safe.""" - if self.body_stream is not None or debug.debug_flag_enabled('noretry'): + if self.body_stream is not None or debug.debug_flag_enabled("noretry"): # We can't restart a body stream that has already been consumed. 
return False from breezy.bzr.smart import request as _mod_request + request_type = _mod_request.request_handlers.get_info(self.method) - if request_type in ('read', 'idem', 'semi'): + if request_type in ("read", "idem", "semi"): return True # If we have gotten this far, 'stream' cannot be retried, because we # already consumed the local stream. - if request_type in ('semivfs', 'mutate', 'stream'): + if request_type in ("semivfs", "mutate", "stream"): return False - trace.mutter(f'Unknown request type: {request_type} for method {self.method}') + trace.mutter(f"Unknown request type: {request_type} for method {self.method}") return False def _run_call_hooks(self): - if not _SmartClient.hooks['call']: + if not _SmartClient.hooks["call"]: return - params = CallHookParams(self.method, self.args, self.body, - self.readv_body, self.client._medium) - for hook in _SmartClient.hooks['call']: + params = CallHookParams( + self.method, self.args, self.body, self.readv_body, self.client._medium + ) + for hook in _SmartClient.hooks["call"]: hook(params) def _call(self, protocol_version): @@ -179,18 +202,21 @@ def _call(self, protocol_version): response_handler = self._send(protocol_version) try: response_tuple = response_handler.read_response_tuple( - expect_body=self.expect_response_body) + expect_body=self.expect_response_body + ) except ConnectionResetError: self.client._medium.reset() if not self._is_safe_to_send_twice(): raise - trace.warning(f'ConnectionReset reading response for {self.method!r}, retrying') + trace.warning( + f"ConnectionReset reading response for {self.method!r}, retrying" + ) trace.log_exception_quietly() - encoder, response_handler = self._construct_protocol( - protocol_version) + encoder, response_handler = self._construct_protocol(protocol_version) self._send_no_retry(encoder) response_tuple = response_handler.read_response_tuple( - expect_body=self.expect_response_body) + expect_body=self.expect_response_body + ) return (response_tuple, response_handler) def _call_determining_protocol_version(self): @@ -210,9 +236,10 @@ def _call_determining_protocol_version(self): # TODO: We could recover from this without disconnecting if # we recognise the protocol version. trace.warning( - 'Server does not understand Bazaar network protocol %d,' - ' reconnecting. (Upgrade the server to avoid this.)' - % (protocol_version,)) + "Server does not understand Bazaar network protocol %d," + " reconnecting. (Upgrade the server to avoid this.)" + % (protocol_version,) + ) self.client._medium.disconnect() last_err = err continue @@ -225,7 +252,8 @@ def _call_determining_protocol_version(self): self.client._medium._protocol_version = protocol_version return response_tuple, response_handler raise errors.SmartProtocolError( - 'Server is not a Bazaar server: ' + str(last_err)) + "Server is not a Bazaar server: " + str(last_err) + ) def _construct_protocol(self, version): """Build the encoding stack for a given protocol version.""" @@ -234,7 +262,8 @@ def _construct_protocol(self, version): request_encoder = protocol.ProtocolThreeRequester(request) response_handler = message.ConventionalResponseHandler() response_proto = protocol.ProtocolThreeDecoder( - response_handler, expect_version_marker=True) + response_handler, expect_version_marker=True + ) response_handler.setProtoAndMediumRequest(response_proto, request) elif version == 2: request_encoder = protocol.SmartClientRequestProtocolTwo(request) @@ -265,19 +294,18 @@ def _send(self, protocol_version): # Connection is dead, so close our end of it. 
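_is_safe_to_send_twice above is the guard that decides whether a request may be resent after a ConnectionReset: only verbs registered as 'read', 'idem' or 'semi' qualify, a consumed body stream never does, and unknown types are treated as unsafe. The same decision as a standalone function, with the registry lookup replaced by a plain argument and the 'noretry' debug flag by a keyword:

```python
# Standalone version of the retry decision in _is_safe_to_send_twice above.
RETRY_SAFE = {"read", "idem", "semi"}
NO_RETRY = {"semivfs", "mutate", "stream"}

def is_safe_to_send_twice(request_type, has_body_stream, noretry=False):
    if has_body_stream or noretry:
        # A body stream that was already consumed cannot be replayed.
        return False
    if request_type in RETRY_SAFE:
        return True
    if request_type in NO_RETRY:
        return False
    return False  # unknown types are treated as unsafe, as in the real code

print(is_safe_to_send_twice("idem", has_body_stream=False))    # True
print(is_safe_to_send_twice("idem", has_body_stream=True))     # False
print(is_safe_to_send_twice("mutate", has_body_stream=False))  # False
```

Treating unknown request types as unsafe errs on the side of never repeating a mutation.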
self.client._medium.reset() - if ((debug.debug_flag_enabled('noretry')) or - (self.body_stream is not None and - encoder.body_stream_started)): + if (debug.debug_flag_enabled("noretry")) or ( + self.body_stream is not None and encoder.body_stream_started + ): # We can't restart a body_stream that has been partially # consumed, so we don't retry. # Note: We don't have to worry about # SmartClientRequestProtocolOne or Two, because they don't # support client-side body streams. raise - trace.warning(f'ConnectionReset calling {self.method!r}, retrying') + trace.warning(f"ConnectionReset calling {self.method!r}, retrying") trace.log_exception_quietly() - encoder, response_handler = self._construct_protocol( - protocol_version) + encoder, response_handler = self._construct_protocol(protocol_version) self._send_no_retry(encoder) return response_handler @@ -286,43 +314,41 @@ def _send_no_retry(self, encoder): encoder.set_headers(self.client._headers) if self.body is not None: if self.readv_body is not None: - raise AssertionError( - "body and readv_body are mutually exclusive.") + raise AssertionError("body and readv_body are mutually exclusive.") if self.body_stream is not None: - raise AssertionError( - "body and body_stream are mutually exclusive.") - encoder.call_with_body_bytes( - (self.method, ) + self.args, self.body) + raise AssertionError("body and body_stream are mutually exclusive.") + encoder.call_with_body_bytes((self.method,) + self.args, self.body) elif self.readv_body is not None: if self.body_stream is not None: raise AssertionError( - "readv_body and body_stream are mutually exclusive.") - encoder.call_with_body_readv_array((self.method, ) + self.args, - self.readv_body) + "readv_body and body_stream are mutually exclusive." + ) + encoder.call_with_body_readv_array( + (self.method,) + self.args, self.readv_body + ) elif self.body_stream is not None: - encoder.call_with_body_stream((self.method, ) + self.args, - self.body_stream) + encoder.call_with_body_stream((self.method,) + self.args, self.body_stream) else: encoder.call(self.method, *self.args) class SmartClientHooks(hooks.Hooks): - def __init__(self): - hooks.Hooks.__init__( - self, "breezy.bzr.smart.client", "_SmartClient.hooks") - self.add_hook('call', - "Called when the smart client is submitting a request to the " - "smart server. Called with a breezy.bzr.smart.client.CallHookParams " - "object. Streaming request bodies, and responses, are not " - "accessible.", None) + hooks.Hooks.__init__(self, "breezy.bzr.smart.client", "_SmartClient.hooks") + self.add_hook( + "call", + "Called when the smart client is submitting a request to the " + "smart server. Called with a breezy.bzr.smart.client.CallHookParams " + "object. 
Streaming request bodies, and responses, are not " + "accessible.", + None, + ) _SmartClient.hooks = SmartClientHooks() # type: ignore class CallHookParams: - def __init__(self, method, args, body, readv_body, medium): self.method = method self.args = args @@ -331,9 +357,8 @@ def __init__(self, method, args, body, readv_body, medium): self.medium = medium def __repr__(self): - attrs = {k: v for k, v in self.__dict__.items() - if v is not None} - return f'<{self.__class__.__name__} {attrs!r}>' + attrs = {k: v for k, v in self.__dict__.items() if v is not None} + return f"<{self.__class__.__name__} {attrs!r}>" def __eq__(self, other): if not isinstance(other, type(self)): diff --git a/breezy/bzr/smart/medium.py b/breezy/bzr/smart/medium.py index 86f8d2d186..1b65179ac2 100644 --- a/breezy/bzr/smart/medium.py +++ b/breezy/bzr/smart/medium.py @@ -35,7 +35,9 @@ from ...lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ import select import socket import weakref @@ -48,7 +50,8 @@ from breezy.i18n import gettext from breezy.bzr.smart import client, protocol, request, signals, vfs from breezy.transport import ssh -""") +""", +) from ... import debug, errors, osutils, trace # Throughout this module buffer size parameters are either limited to be at @@ -59,9 +62,10 @@ class HpssVfsRequestNotAllowed(errors.BzrError): - - _fmt = ("VFS requests over the smart server are not allowed. Encountered: " - "%(method)s, %(arguments)s.") + _fmt = ( + "VFS requests over the smart server are not allowed. Encountered: " + "%(method)s, %(arguments)s." + ) def __init__(self, method, arguments): self.method = method @@ -91,10 +95,10 @@ def _get_protocol_factory_for_bytes(bytes): """ if bytes.startswith(protocol.MESSAGE_VERSION_THREE): protocol_factory = protocol.build_server_protocol_three - bytes = bytes[len(protocol.MESSAGE_VERSION_THREE):] + bytes = bytes[len(protocol.MESSAGE_VERSION_THREE) :] elif bytes.startswith(protocol.REQUEST_VERSION_TWO): protocol_factory = protocol.SmartServerRequestProtocolTwo - bytes = bytes[len(protocol.REQUEST_VERSION_TWO):] + bytes = bytes[len(protocol.REQUEST_VERSION_TWO) :] else: protocol_factory = protocol.SmartServerRequestProtocolOne return protocol_factory, bytes @@ -109,16 +113,16 @@ def _get_line(read_bytes_func): :returns: a tuple of two strs: (line, excess) """ newline_pos = -1 - bytes = b'' + bytes = b"" while newline_pos == -1: new_bytes = read_bytes_func(1) bytes += new_bytes - if new_bytes == b'': + if new_bytes == b"": # Ran out of bytes before receiving a complete line. 
- return bytes, b'' - newline_pos = bytes.find(b'\n') - line = bytes[:newline_pos + 1] - excess = bytes[newline_pos + 1:] + return bytes, b"" + newline_pos = bytes.find(b"\n") + line = bytes[: newline_pos + 1] + excess = bytes[newline_pos + 1 :] return line, excess @@ -138,16 +142,18 @@ def _push_back(self, data): raise TypeError(data) if self._push_back_buffer is not None: raise AssertionError( - f"_push_back called when self._push_back_buffer is {self._push_back_buffer!r}") - if data == b'': + f"_push_back called when self._push_back_buffer is {self._push_back_buffer!r}" + ) + if data == b"": return self._push_back_buffer = data def _get_push_back_buffer(self): - if self._push_back_buffer == b'': + if self._push_back_buffer == b"": raise AssertionError( - f'{self}._push_back_buffer should never be the empty string, ' - 'which can be confused with EOF') + f"{self}._push_back_buffer should never be the empty string, " + "which can be confused with EOF" + ) bytes = self._push_back_buffer self._push_back_buffer = None return bytes @@ -192,7 +198,7 @@ def _report_activity(self, bytes, direction): _bad_file_descriptor = (errno.EBADF,) -if sys.platform == 'win32': +if sys.platform == "win32": # Given on Windows if you pass a closed socket to select.select. Probably # also given if you pass a file handle to select. WSAENOTSOCK = 10038 @@ -219,7 +225,7 @@ class SmartServerStreamMedium(SmartMedium): _timer = time.time - def __init__(self, backing_transport, root_client_path='/', timeout=None): + def __init__(self, backing_transport, root_client_path="/", timeout=None): """Construct new server. :param backing_transport: Transport for the directory served. @@ -229,7 +235,7 @@ def __init__(self, backing_transport, root_client_path='/', timeout=None): self.root_client_path = root_client_path self.finished = False if timeout is None: - raise AssertionError('You must supply a timeout.') + raise AssertionError("You must supply a timeout.") self._client_timeout = timeout self._client_poll_timeout = min(timeout / 10.0, 1.0) SmartMedium.__init__(self) @@ -239,12 +245,13 @@ def serve(self): # Keep a reference to stderr because the sys module's globals get set to # None during interpreter shutdown. from sys import stderr + try: while not self.finished: server_protocol = self._build_protocol() self._serve_one_request(server_protocol) except errors.ConnectionTimeout as e: - trace.note(f'{e}') + trace.note(f"{e}") trace.log_exception_quietly() self._disconnect_client() # We reported it, no reason to make a big fuss. 
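_get_line above reads one byte at a time until it sees a newline and hands anything past the newline back as "excess" so it can be pushed back into the decoder. Driving the same logic against an in-memory buffer (renamed get_line here; with single-byte reads the excess split is purely defensive and stays empty):

```python
# Standalone run of the _get_line logic above against an in-memory stream.
import io

def get_line(read_bytes_func):
    newline_pos = -1
    data = b""
    while newline_pos == -1:
        new_bytes = read_bytes_func(1)
        data += new_bytes
        if new_bytes == b"":
            return data, b""          # EOF before a complete line
        newline_pos = data.find(b"\n")
    return data[: newline_pos + 1], data[newline_pos + 1 :]

stream = io.BytesIO(b"hello version 2\nsecond line without newline")
print(get_line(stream.read))  # (b'hello version 2\n', b'')
print(get_line(stream.read))  # (b'second line without newline', b'') - EOF case
```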
@@ -256,7 +263,7 @@ def serve(self): def _stop_gracefully(self): """When we finish this message, stop looking for more.""" - trace.mutter(f'Stopping {self}') + trace.mutter(f"Stopping {self}") self.finished = True def _disconnect_client(self): @@ -293,7 +300,8 @@ def _build_protocol(self): bytes = self._get_line() protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes) protocol = protocol_factory( - self.backing_transport, self._write_out, self.root_client_path) + self.backing_transport, self._write_out, self.root_client_path + ) protocol.accept_bytes(unused_bytes) return protocol @@ -313,8 +321,8 @@ def _wait_on_descriptor(self, fd, timeout_seconds): try: rs, _, xs = select.select([fd], [], [fd], poll_timeout) except OSError as e: - err = getattr(e, 'errno', None) - if err is None and getattr(e, 'args', None) is not None: + err = getattr(e, "errno", None) + if err is None and getattr(e, "args", None) is not None: # select.error doesn't have 'errno', it just has args[0] err = e.args[0] if err in _bad_file_descriptor: @@ -327,7 +335,9 @@ def _wait_on_descriptor(self, fd, timeout_seconds): return # Socket may already be closed if rs or xs: return - raise errors.ConnectionTimeout(f'disconnecting client after {timeout_seconds:.1f} seconds') + raise errors.ConnectionTimeout( + f"disconnecting client after {timeout_seconds:.1f} seconds" + ) def _serve_one_request(self, protocol): """Read one request from input, process, send back a response. @@ -356,31 +366,30 @@ def _read_bytes(self, desired_count): class SmartServerSocketStreamMedium(SmartServerStreamMedium): - - def __init__(self, sock, backing_transport, root_client_path='/', - timeout=None): + def __init__(self, sock, backing_transport, root_client_path="/", timeout=None): """Constructor. :param sock: the socket the server will read from. It will be put into blocking mode. """ SmartServerStreamMedium.__init__( - self, backing_transport, root_client_path=root_client_path, - timeout=timeout) + self, backing_transport, root_client_path=root_client_path, timeout=timeout + ) sock.setblocking(True) self.socket = sock # Get the getpeername now, as we might be closed later when we care. try: self._client_info = sock.getpeername() except OSError: - self._client_info = '' + self._client_info = "" def __str__(self): - return f'{self.__class__.__name__}(client={self._client_info})' + return f"{self.__class__.__name__}(client={self._client_info})" def __repr__(self): - return '{}.{}(client={})'.format(self.__module__, self.__class__.__name__, - self._client_info) + return "{}.{}(client={})".format( + self.__module__, self.__class__.__name__, self._client_info + ) def _serve_one_request_unguarded(self, protocol): while protocol.next_read_size(): @@ -388,7 +397,7 @@ def _serve_one_request_unguarded(self, protocol): # than MAX_SOCKET_CHUNK ready, the socket will just return a # short read immediately rather than block. 
bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK) - if bytes == b'': + if bytes == b"": self.finished = True return protocol.accept_bytes(bytes) @@ -412,8 +421,7 @@ def _wait_for_bytes_with_timeout(self, timeout_seconds): return self._wait_on_descriptor(self.socket, timeout_seconds) def _read_bytes(self, desired_count): - return osutils.read_bytes_from_socket( - self.socket, self._report_activity) + return osutils.read_bytes_from_socket(self.socket, self._report_activity) def terminate_due_to_error(self): # TODO: This should log to a server log file, but no such thing @@ -424,15 +432,15 @@ def terminate_due_to_error(self): def _write_out(self, bytes): tstart = osutils.perf_counter() osutils.send_all(self.socket, bytes, self._report_activity) - if debug.debug_flag_enabled('hpss'): + if debug.debug_flag_enabled("hpss"): thread_id = _thread.get_ident() - trace.mutter('%12s: [%s] %d bytes to the socket in %.3fs' - % ('wrote', thread_id, len(bytes), - osutils.perf_counter() - tstart)) + trace.mutter( + "%12s: [%s] %d bytes to the socket in %.3fs" + % ("wrote", thread_id, len(bytes), osutils.perf_counter() - tstart) + ) class SmartServerPipeStreamMedium(SmartServerStreamMedium): - def __init__(self, in_file, out_file, backing_transport, timeout=None): """Construct new server. @@ -440,13 +448,13 @@ def __init__(self, in_file, out_file, backing_transport, timeout=None): :param out_file: Python file to write responses. :param backing_transport: Transport for the directory served. """ - SmartServerStreamMedium.__init__(self, backing_transport, - timeout=timeout) - if sys.platform == 'win32': + SmartServerStreamMedium.__init__(self, backing_transport, timeout=timeout) + if sys.platform == "win32": # force binary mode for files import msvcrt + for f in (in_file, out_file): - fileno = getattr(f, 'fileno', None) + fileno = getattr(f, "fileno", None) if fileno: msvcrt.setmode(fileno(), os.O_BINARY) self._in = in_file @@ -474,7 +482,7 @@ def _serve_one_request_unguarded(self, protocol): self._out.flush() return bytes = self.read_bytes(bytes_to_read) - if bytes == b'': + if bytes == b"": # Connection has been closed. self.finished = True self._out.flush() @@ -496,8 +504,7 @@ def _wait_for_bytes_with_timeout(self, timeout_seconds): :return: None, this will raise ConnectionTimeout if we time out before data is available. """ - if (getattr(self._in, 'fileno', None) is None - or sys.platform == 'win32'): + if getattr(self._in, "fileno", None) is None or sys.platform == "win32": # You can't select() file descriptors on Windows. return try: @@ -638,11 +645,12 @@ def _read_bytes(self, count): def read_line(self): line = self._read_line() - if not line.endswith(b'\n'): + if not line.endswith(b"\n"): # end of file encountered reading from server raise ConnectionResetError( "Unexpected end of message. Please check connectivity " - "and permissions, and report a bug if problems persist.") + "and permissions, and report a bug if problems persist." 
+ ) return line def _read_line(self): @@ -659,7 +667,8 @@ class _VfsRefuser: def __init__(self): client._SmartClient.hooks.install_named_hook( - 'call', self.check_vfs, 'vfs refuser') + "call", self.check_vfs, "vfs refuser" + ) def check_vfs(self, params): try: @@ -682,7 +691,8 @@ class _DebugCounter: def __init__(self): self.counts = weakref.WeakKeyDictionary() client._SmartClient.hooks.install_named_hook( - 'call', self.increment_call_count, 'hpss call counter') + "call", self.increment_call_count, "hpss call counter" + ) breezy.get_global_state().exit_stack.callback(self.flush_all) def track(self, medium): @@ -693,8 +703,7 @@ def track(self, medium): """ medium_repr = repr(medium) # Add this medium to the WeakKeyDictionary - self.counts[medium] = {"count": 0, "vfs_count": 0, - "medium_repr": medium_repr} + self.counts[medium] = {"count": 0, "vfs_count": 0, "medium_repr": medium_repr} # Weakref callbacks are fired in reverse order of their association # with the referenced object. So we add a weakref *after* adding to # the WeakKeyDict so that we can report the value from it before the @@ -704,27 +713,33 @@ def track(self, medium): def increment_call_count(self, params): # Increment the count in the WeakKeyDictionary value = self.counts[params.medium] - value['count'] += 1 + value["count"] += 1 try: request_method = request.request_handlers.get(params.method) except KeyError: # A method we don't know about doesn't count as a VFS method. return if issubclass(request_method, vfs.VfsRequest): - value['vfs_count'] += 1 + value["vfs_count"] += 1 def done(self, ref): value = self.counts[ref] count, vfs_count, medium_repr = ( - value['count'], value['vfs_count'], value['medium_repr']) + value["count"], + value["vfs_count"], + value["medium_repr"], + ) # In case this callback is invoked for the same ref twice (by the # weakref callback and by the atexit function), set the call count back # to 0 so this item won't be reported twice. - value['count'] = 0 - value['vfs_count'] = 0 + value["count"] = 0 + value["vfs_count"] = 0 if count != 0: - trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format( - count, vfs_count, medium_repr)) + trace.note( + gettext("HPSS calls: {0} ({1} vfs) {2}").format( + count, vfs_count, medium_repr + ) + ) def flush_all(self): for ref in list(self.counts.keys()): @@ -750,12 +765,12 @@ def __init__(self, base): # can be based on what we've seen so far. self._remote_version_is_before = None # Install debug hook function if debug flag is set. - if debug.debug_flag_enabled('hpss'): + if debug.debug_flag_enabled("hpss"): global _debug_counter if _debug_counter is None: _debug_counter = _DebugCounter() _debug_counter.track(self) - if debug.debug_flag_enabled('hpss_client_no_vfs'): + if debug.debug_flag_enabled("hpss_client_no_vfs"): global _vfs_refuser if _vfs_refuser is None: _vfs_refuser = _VfsRefuser() @@ -787,19 +802,25 @@ def _remember_remote_is_before(self, version_tuple): :seealso: _is_remote_before """ - if (self._remote_version_is_before is not None and - version_tuple > self._remote_version_is_before): + if ( + self._remote_version_is_before is not None + and version_tuple > self._remote_version_is_before + ): # We have been told that the remote side is older than some version # which is newer than a previously supplied older-than version. # This indicates that some smart verb call is not guarded # appropriately (it should simply not have been tried). 
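_DebugCounter above keeps per-connection call counts in a WeakKeyDictionary and reports them from a weakref callback once the medium is garbage collected. A simplified stand-in for that pattern, using a plain dict keyed by weak references and a placeholder Medium class rather than the module's real bookkeeping:

```python
# Simplified sketch of the _DebugCounter pattern: counters keyed by weak
# references, reported from the weakref callback when the object dies.
import weakref

class Medium:
    pass  # placeholder for a client medium

class CallCounter:
    def __init__(self):
        self.counts = {}                       # weakref -> stats

    def track(self, medium):
        ref = weakref.ref(medium, self.done)   # done() fires when medium dies
        self.counts[ref] = {"count": 0, "medium_repr": repr(medium)}

    def increment(self, medium):
        self.counts[weakref.ref(medium)]["count"] += 1

    def done(self, ref):
        stats = self.counts.pop(ref)
        if stats["count"]:
            print("calls:", stats["count"], stats["medium_repr"])

counter = CallCounter()
medium = Medium()
counter.track(medium)
counter.increment(medium)
counter.increment(medium)
del medium   # callback fires and prints "calls: 2 <...Medium object...>"
```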
trace.mutter( "_remember_remote_is_before(%r) called, but " - "_remember_remote_is_before(%r) was called previously.", version_tuple, self._remote_version_is_before) - if debug.debug_flag_enabled('hpss'): + "_remember_remote_is_before(%r) was called previously.", + version_tuple, + self._remote_version_is_before, + ) + if debug.debug_flag_enabled("hpss"): ui.ui_factory.show_warning( f"_remember_remote_is_before({version_tuple!r}) called, but " - f"_remember_remote_is_before({self._remote_version_is_before!r}) was called previously.") + f"_remember_remote_is_before({self._remote_version_is_before!r}) was called previously." + ) return self._remote_version_is_before = version_tuple @@ -812,8 +833,7 @@ def protocol_version(self): medium_request = self.get_request() # Send a 'hello' request in protocol version one, for maximum # backwards compatibility. - client_protocol = protocol.SmartClientRequestProtocolOne( - medium_request) + client_protocol = protocol.SmartClientRequestProtocolOne(medium_request) client_protocol.query_version() self._done_hello = True except errors.SmartProtocolError as e: @@ -821,7 +841,7 @@ def protocol_version(self): # result. self._protocol_version_error = e raise - return '2' + return "2" def should_probe(self): """Should RemoteBzrDirFormat.probe_transport send a smart request on @@ -852,7 +872,7 @@ def remote_path_from_transport(self, transport): anything but path, so it is only safe to use it in requests sent over the medium from the matching transport. """ - medium_base = urlutils.join(self.base, '/') + medium_base = urlutils.join(self.base, "/") rel_url = urlutils.relative_url(medium_base, transport.base) return urlutils.unquote(rel_url) @@ -922,9 +942,10 @@ def _accept_bytes(self, data): except OSError as e: if e.errno in (errno.EINVAL, errno.EPIPE): raise ConnectionResetError( - "Error trying to write to subprocess", e) from e + "Error trying to write to subprocess", e + ) from e raise - self._report_activity(len(data), 'write') + self._report_activity(len(data), "write") def _flush(self): """See SmartClientStreamMedium._flush().""" @@ -937,15 +958,16 @@ def _read_bytes(self, count): """See SmartClientStreamMedium._read_bytes.""" bytes_to_read = min(count, _MAX_READ_SIZE) data = self._readable_pipe.read(bytes_to_read) - self._report_activity(len(data), 'read') + self._report_activity(len(data), "read") return data class SSHParams: """A set of parameters for starting a remote bzr via SSH.""" - def __init__(self, host, port=None, username=None, password=None, - bzr_remote_path='bzr'): + def __init__( + self, host, port=None, username=None, password=None, bzr_remote_path="bzr" + ): self.host = host self.port = port self.username = username @@ -971,7 +993,7 @@ def __init__(self, base, ssh_params, vendor=None): self._ssh_params = ssh_params # for the benefit of progress making a short description of this # transport - self._scheme = 'bzr+ssh' + self._scheme = "bzr+ssh" # SmartClientStreamMedium stores the repr of this object in its # _DebugCounter so we have to store all the values used in our repr # method before calling the super init. 
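_remember_remote_is_before above records an upper bound on the server's version once a verb turns out to be unsupported, so later calls can skip verbs the server cannot know about; _is_remote_before (not shown in this hunk) is its read side. A simplified standalone sketch of that gate, with names and comparison semantics inferred rather than copied from Breezy:

```python
# Simplified sketch of the client-side version gate; not Breezy's exact code.
class VersionGate:
    def __init__(self):
        self._remote_version_is_before = None

    def remember_remote_is_before(self, version_tuple):
        if (self._remote_version_is_before is not None
                and version_tuple > self._remote_version_is_before):
            # A looser bound than one already recorded means some verb was
            # not guarded properly; the real code warns and keeps the old one.
            return
        self._remote_version_is_before = version_tuple

    def is_remote_before(self, version_tuple):
        if self._remote_version_is_before is None:
            return False        # nothing known about the server yet
        return self._remote_version_is_before <= version_tuple

gate = VersionGate()
print(gate.is_remote_before((1, 6)))     # False
gate.remember_remote_is_before((1, 6))   # e.g. a 1.6-era verb was rejected
print(gate.is_remote_before((1, 6)))     # True - skip 1.6-era verbs from now on
print(gate.is_remote_before((1, 2)))     # False - older verbs may still work
```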
@@ -981,19 +1003,20 @@ def __init__(self, base, ssh_params, vendor=None): def __repr__(self): if self._ssh_params.port is None: - maybe_port = '' + maybe_port = "" else: - maybe_port = f':{self._ssh_params.port}' + maybe_port = f":{self._ssh_params.port}" if self._ssh_params.username is None: - maybe_user = '' + maybe_user = "" else: - maybe_user = f'{self._ssh_params.username}@' + maybe_user = f"{self._ssh_params.username}@" return "{}({}://{}{}{}/)".format( self.__class__.__name__, self._scheme, maybe_user, self._ssh_params.host, - maybe_port) + maybe_port, + ) def _accept_bytes(self, bytes): """See SmartClientStreamMedium.accept_bytes.""" @@ -1017,22 +1040,33 @@ def _ensure_connection(self): vendor = ssh._get_ssh_vendor() else: vendor = self._vendor - self._ssh_connection = vendor.connect_ssh(self._ssh_params.username, - self._ssh_params.password, self._ssh_params.host, - self._ssh_params.port, - command=[self._ssh_params.bzr_remote_path, 'serve', '--inet', - '--directory=/', '--allow-writes']) + self._ssh_connection = vendor.connect_ssh( + self._ssh_params.username, + self._ssh_params.password, + self._ssh_params.host, + self._ssh_params.port, + command=[ + self._ssh_params.bzr_remote_path, + "serve", + "--inet", + "--directory=/", + "--allow-writes", + ], + ) io_kind, io_object = self._ssh_connection.get_sock_or_pipes() - if io_kind == 'socket': + if io_kind == "socket": self._real_medium = SmartClientAlreadyConnectedSocketMedium( - self.base, io_object) - elif io_kind == 'pipes': + self.base, io_object + ) + elif io_kind == "pipes": read_from, write_to = io_object self._real_medium = SmartSimplePipesClientMedium( - read_from, write_to, self.base) + read_from, write_to, self.base + ) else: raise AssertionError( - f"Unexpected io_kind {io_kind!r} from {self._ssh_connection!r}") + f"Unexpected io_kind {io_kind!r} from {self._ssh_connection!r}" + ) for hook in transport.Transport.hooks["post_connect"]: hook(self) @@ -1084,8 +1118,7 @@ def _read_bytes(self, count): """See SmartClientMedium.read_bytes.""" if not self._connected: raise errors.MediumNotConnected(self) - return osutils.read_bytes_from_socket( - self._socket, self._report_activity) + return osutils.read_bytes_from_socket(self._socket, self._report_activity) def disconnect(self): """See SmartClientMedium.disconnect().""" @@ -1114,18 +1147,20 @@ def _ensure_connection(self): else: port = int(self._port) try: - sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC, - socket.SOCK_STREAM, 0, 0) + sockaddrs = socket.getaddrinfo( + self._host, port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, 0 + ) except socket.gaierror as e: (err_num, err_msg) = e.args - raise ConnectionError("failed to lookup %s:%d: %s" % (self._host, port, err_msg)) from e + raise ConnectionError( + "failed to lookup %s:%d: %s" % (self._host, port, err_msg) + ) from e # Initialize err in case there are no addresses returned: last_err = socket.error(f"no address found for {self._host}") - for (family, socktype, proto, _canonname, sockaddr) in sockaddrs: + for family, socktype, proto, _canonname, sockaddr in sockaddrs: try: self._socket = socket.socket(family, socktype, proto) - self._socket.setsockopt(socket.IPPROTO_TCP, - socket.TCP_NODELAY, 1) + self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self._socket.connect(sockaddr) except OSError as err: if self._socket is not None: @@ -1143,7 +1178,9 @@ def _ensure_connection(self): err_msg = last_err.args[1] if isinstance(last_err, ConnectionError): raise last_err - raise 
ConnectionError("failed to connect to %s:%d: %s" % (self._host, port, err_msg)) + raise ConnectionError( + "failed to connect to %s:%d: %s" % (self._host, port, err_msg) + ) self._connected = True for hook in transport.Transport.hooks["post_connect"]: hook(self) @@ -1167,10 +1204,11 @@ def _ensure_connection(self): class TooManyConcurrentRequests(errors.InternalBzrError): - - _fmt = ("The medium '%(medium)s' has reached its concurrent request limit." - " Be sure to finish_writing and finish_reading on the" - " currently open request.") + _fmt = ( + "The medium '%(medium)s' has reached its concurrent request limit." + " Be sure to finish_writing and finish_reading on the" + " currently open request." + ) def __init__(self, medium): self.medium = medium diff --git a/breezy/bzr/smart/message.py b/breezy/bzr/smart/message.py index 6648a3f22c..a79dc3fbeb 100644 --- a/breezy/bzr/smart/message.py +++ b/breezy/bzr/smart/message.py @@ -94,7 +94,7 @@ def __init__(self, request_handler, responder): MessageHandler.__init__(self) self.request_handler = request_handler self.responder = responder - self.expecting = 'args' + self.expecting = "args" self._should_finish_body = False self._response_sent = False @@ -108,59 +108,63 @@ def protocol_error(self, exception): def byte_part_received(self, byte): if not isinstance(byte, bytes): raise TypeError(byte) - if self.expecting == 'body': - if byte == b'S': + if self.expecting == "body": + if byte == b"S": # Success. Nothing more to come except the end of message. - self.expecting = 'end' - elif byte == b'E': + self.expecting = "end" + elif byte == b"E": # Error. Expect an error structure. - self.expecting = 'error' + self.expecting = "error" else: raise errors.SmartProtocolError( - f'Non-success status byte in request body: {byte!r}') + f"Non-success status byte in request body: {byte!r}" + ) else: - raise errors.SmartProtocolError( - f'Unexpected message part: byte({byte!r})') + raise errors.SmartProtocolError(f"Unexpected message part: byte({byte!r})") def structure_part_received(self, structure): - if self.expecting == 'args': + if self.expecting == "args": self._args_received(structure) - elif self.expecting == 'error': + elif self.expecting == "error": self._error_received(structure) else: raise errors.SmartProtocolError( - f'Unexpected message part: structure({structure!r})') + f"Unexpected message part: structure({structure!r})" + ) def _args_received(self, args): - self.expecting = 'body' + self.expecting = "body" self.request_handler.args_received(args) if self.request_handler.finished_reading: self._response_sent = True self.responder.send_response(self.request_handler.response) - self.expecting = 'end' + self.expecting = "end" def _error_received(self, error_args): - self.expecting = 'end' + self.expecting = "end" self.request_handler.post_body_error_received(error_args) def bytes_part_received(self, bytes): - if self.expecting == 'body': + if self.expecting == "body": self._should_finish_body = True self.request_handler.accept_body(bytes) else: raise errors.SmartProtocolError( - f'Unexpected message part: bytes({bytes!r})') + f"Unexpected message part: bytes({bytes!r})" + ) def end_received(self): - if self.expecting not in ['body', 'end']: + if self.expecting not in ["body", "end"]: raise errors.SmartProtocolError( - f'End of message received prematurely (while expecting {self.expecting})') - self.expecting = 'nothing' + f"End of message received prematurely (while expecting {self.expecting})" + ) + self.expecting = "nothing" 
self.request_handler.end_received() if not self.request_handler.finished_reading: raise errors.SmartProtocolError( "Complete conventional request was received, but request " - "handler has not finished reading.") + "handler has not finished reading." + ) if not self._response_sent: self.responder.send_response(self.request_handler.response) @@ -204,7 +208,6 @@ def cancel_read_body(self): class ConventionalResponseHandler(MessageHandler, ResponseHandler): - def __init__(self): MessageHandler.__init__(self) self.status = None @@ -223,18 +226,19 @@ def setProtoAndMediumRequest(self, protocol_decoder, medium_request): def byte_part_received(self, byte): if not isinstance(byte, bytes): raise TypeError(byte) - if byte not in [b'E', b'S']: - raise errors.SmartProtocolError( - f'Unknown response status: {byte!r}') + if byte not in [b"E", b"S"]: + raise errors.SmartProtocolError(f"Unknown response status: {byte!r}") if self._body_started: if self._body_stream_status is not None: raise errors.SmartProtocolError( - f'Unexpected byte part received: {byte!r}') + f"Unexpected byte part received: {byte!r}" + ) self._body_stream_status = byte else: if self.status is not None: raise errors.SmartProtocolError( - f'Unexpected byte part received: {byte!r}') + f"Unexpected byte part received: {byte!r}" + ) self.status = byte def bytes_part_received(self, bytes): @@ -244,16 +248,19 @@ def bytes_part_received(self, bytes): def structure_part_received(self, structure): if not isinstance(structure, tuple): raise errors.SmartProtocolError( - f'Args structure is not a sequence: {structure!r}') + f"Args structure is not a sequence: {structure!r}" + ) if not self._body_started: if self.args is not None: raise errors.SmartProtocolError( - f'Unexpected structure received: {structure!r} (already got {self.args!r})') + f"Unexpected structure received: {structure!r} (already got {self.args!r})" + ) self.args = structure else: - if self._body_stream_status != b'E': + if self._body_stream_status != b"E": raise errors.SmartProtocolError( - f'Unexpected structure received after body: {structure!r}') + f"Unexpected structure received after body: {structure!r}" + ) self._body_error_args = structure def _wait_for_response_args(self): @@ -272,17 +279,19 @@ def _read_more(self): self._medium_request.finished_reading() return data = self._medium_request.read_bytes(next_read_size) - if data == b'': + if data == b"": # end of file encountered reading from server - if debug.debug_flag_enabled('hpss'): + if debug.debug_flag_enabled("hpss"): mutter( - 'decoder state: buf[:10]=%r, state_accept=%s', + "decoder state: buf[:10]=%r, state_accept=%s", self._protocol_decoder._get_in_buffer()[:10], - self._protocol_decoder.state_accept.__name__) + self._protocol_decoder.state_accept.__name__, + ) raise ConnectionResetError( "Unexpected end of message. " "Please check connectivity and permissions, and report a bug " - "if problems persist.") + "if problems persist." + ) self._protocol_decoder.accept_bytes(data) def protocol_error(self, exception): @@ -296,9 +305,9 @@ def read_response_tuple(self, expect_body=False): self._wait_for_response_args() if not expect_body: self._wait_for_response_end() - if debug.debug_flag_enabled('hpss'): - mutter(' result: %r', self.args) - if self.status == b'E': + if debug.debug_flag_enabled("hpss"): + mutter(" result: %r", self.args) + if self.status == b"E": self._wait_for_response_end() _raise_smart_server_error(self.args) return tuple(self.args) @@ -316,9 +325,9 @@ def read_body_bytes(self, count=-1): # != -1. 
(2008/04/30, Andrew Bennetts) if self._body is None: self._wait_for_response_end() - body_bytes = b''.join(self._bytes_parts) - if debug.debug_flag_enabled('hpss'): - mutter(' %d body bytes read', len(body_bytes)) + body_bytes = b"".join(self._bytes_parts) + if debug.debug_flag_enabled("hpss"): + mutter(" %d body bytes read", len(body_bytes)) self._body = BytesIO(body_bytes) self._bytes_parts = None return self._body.read(count) @@ -327,11 +336,11 @@ def read_streamed_body(self): while not self.finished_reading: while self._bytes_parts: bytes_part = self._bytes_parts.popleft() - if debug.debug_flag_enabled('hpssdetail'): - mutter(' %d byte part read', len(bytes_part)) + if debug.debug_flag_enabled("hpssdetail"): + mutter(" %d byte part read", len(bytes_part)) yield bytes_part self._read_more() - if self._body_stream_status == b'E': + if self._body_stream_status == b"E": _raise_smart_server_error(self._body_error_args) def cancel_read_body(self): @@ -343,6 +352,6 @@ def _raise_smart_server_error(error_tuple): Specific error translation is handled by breezy.bzr.remote._translate_error """ - if error_tuple[0] == b'UnknownMethod': + if error_tuple[0] == b"UnknownMethod": raise errors.UnknownSmartMethod(error_tuple[1]) raise errors.ErrorFromSmartServer(error_tuple) diff --git a/breezy/bzr/smart/packrepository.py b/breezy/bzr/smart/packrepository.py index a3fa19f684..f345ba08d2 100644 --- a/breezy/bzr/smart/packrepository.py +++ b/breezy/bzr/smart/packrepository.py @@ -21,13 +21,12 @@ class SmartServerPackRepositoryAutopack(SmartServerRepositoryRequest): - def do_repository_request(self, repository): - pack_collection = getattr(repository, '_pack_collection', None) + pack_collection = getattr(repository, "_pack_collection", None) if pack_collection is None: # This is a not a pack repo, so asking for an autopack is just a # no-op. - return SuccessfulSmartServerResponse((b'ok',)) + return SuccessfulSmartServerResponse((b"ok",)) with repository.lock_write(): repository._pack_collection.autopack() - return SuccessfulSmartServerResponse((b'ok',)) + return SuccessfulSmartServerResponse((b"ok",)) diff --git a/breezy/bzr/smart/ping.py b/breezy/bzr/smart/ping.py index b6137c21a2..71ecb02001 100644 --- a/breezy/bzr/smart/ping.py +++ b/breezy/bzr/smart/ping.py @@ -29,10 +29,11 @@ class cmd_ping(Command): smart protocol, and reports the response. """ - takes_args = ['location'] + takes_args = ["location"] def run(self, location): from breezy.bzr.smart.client import _SmartClient + transport = get_transport(location) try: medium = transport.get_smart_medium() @@ -41,11 +42,12 @@ def run(self, location): client = _SmartClient(medium) # Use call_expecting_body (even though we don't expect a body) so that # we can see the response headers (if any) via the handler object. 
- response, handler = client.call_expecting_body(b'hello') + response, handler = client.call_expecting_body(b"hello") handler.cancel_read_body() - self.outf.write(f'Response: {response!r}\n') - if getattr(handler, 'headers', None) is not None: + self.outf.write(f"Response: {response!r}\n") + if getattr(handler, "headers", None) is not None: headers = { - k.decode('utf-8'): v.decode('utf-8') - for (k, v) in handler.headers.items()} - self.outf.write(f'Headers: {headers!r}\n') + k.decode("utf-8"): v.decode("utf-8") + for (k, v) in handler.headers.items() + } + self.outf.write(f"Headers: {headers!r}\n") diff --git a/breezy/bzr/smart/protocol.py b/breezy/bzr/smart/protocol.py index e181bdb723..baeea96f2b 100644 --- a/breezy/bzr/smart/protocol.py +++ b/breezy/bzr/smart/protocol.py @@ -36,17 +36,15 @@ # Protocol version strings. These are sent as prefixes of bzr requests and # responses to identify the protocol version being used. (There are no version # one strings because that version doesn't send any). -REQUEST_VERSION_TWO = b'bzr request 2\n' -RESPONSE_VERSION_TWO = b'bzr response 2\n' +REQUEST_VERSION_TWO = b"bzr request 2\n" +RESPONSE_VERSION_TWO = b"bzr response 2\n" -MESSAGE_VERSION_THREE = b'bzr message 3 (bzr 1.6)\n' +MESSAGE_VERSION_THREE = b"bzr message 3 (bzr 1.6)\n" RESPONSE_VERSION_THREE = REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE class SmartMessageHandlerError(errors.InternalBzrError): - - _fmt = ("The message handler raised an exception:\n" - "%(traceback_text)s") + _fmt = "The message handler raised an exception:\n" "%(traceback_text)s" def __init__(self, exc_info): import traceback @@ -55,8 +53,9 @@ def __init__(self, exc_info): self.exc_type, self.exc_value, self.exc_tb = exc_info self.exc_info = exc_info traceback_strings = traceback.format_exception( - self.exc_type, self.exc_value, self.exc_tb) - self.traceback_text = ''.join(traceback_strings) + self.exc_type, self.exc_value, self.exc_tb + ) + self.traceback_text = "".join(traceback_strings) def _recv_tuple(from_file): @@ -65,11 +64,11 @@ def _recv_tuple(from_file): def _decode_tuple(req_line): - if req_line is None or req_line == b'': + if req_line is None or req_line == b"": return None - if not req_line.endswith(b'\n'): + if not req_line.endswith(b"\n"): raise errors.SmartProtocolError(f"request {req_line!r} not terminated") - return tuple(req_line[:-1].split(b'\x01')) + return tuple(req_line[:-1].split(b"\x01")) def _encode_tuple(args): @@ -77,7 +76,7 @@ def _encode_tuple(args): for arg in args: if isinstance(arg, str): raise TypeError(args) - return b'\x01'.join(args) + b'\n' + return b"\x01".join(args) + b"\n" class Requester: @@ -121,27 +120,28 @@ class SmartProtocolBase: # support multiple chunks? 
def _encode_bulk_data(self, body): """Encode body as a bulk data chunk.""" - return b''.join((b'%d\n' % len(body), body, b'done\n')) + return b"".join((b"%d\n" % len(body), body, b"done\n")) def _serialise_offsets(self, offsets): """Serialise a readv offset list.""" txt = [] for start, length in offsets: - txt.append(b'%d,%d' % (start, length)) - return b'\n'.join(txt) + txt.append(b"%d,%d" % (start, length)) + return b"\n".join(txt) class SmartServerRequestProtocolOne(SmartProtocolBase): """Server-side encoding and decoding logic for smart version 1.""" - def __init__(self, backing_transport, write_func, root_client_path='/', - jail_root=None): + def __init__( + self, backing_transport, write_func, root_client_path="/", jail_root=None + ): self._backing_transport = backing_transport self._root_client_path = root_client_path self._jail_root = jail_root - self.unused_data = b'' + self.unused_data = b"" self._finished = False - self.in_buffer = b'' + self.in_buffer = b"" self._has_dispatched = False self.request = None self._body_decoder = None @@ -156,38 +156,45 @@ def accept_bytes(self, data): raise ValueError(data) self.in_buffer += data if not self._has_dispatched: - if b'\n' not in self.in_buffer: + if b"\n" not in self.in_buffer: # no command line yet return self._has_dispatched = True try: - first_line, self.in_buffer = self.in_buffer.split(b'\n', 1) - first_line += b'\n' + first_line, self.in_buffer = self.in_buffer.split(b"\n", 1) + first_line += b"\n" req_args = _decode_tuple(first_line) self.request = request.SmartServerRequestHandler( - self._backing_transport, commands=request.request_handlers, + self._backing_transport, + commands=request.request_handlers, root_client_path=self._root_client_path, - jail_root=self._jail_root) + jail_root=self._jail_root, + ) self.request.args_received(req_args) if self.request.finished_reading: # trivial request self.unused_data = self.in_buffer - self.in_buffer = b'' + self.in_buffer = b"" self._send_response(self.request.response) except KeyboardInterrupt: raise except errors.UnknownSmartMethod as err: protocol_error = errors.SmartProtocolError( - f"bad request '{err.verb.decode('ascii')}'") + f"bad request '{err.verb.decode('ascii')}'" + ) failure = request.FailedSmartServerResponse( - (b'error', str(protocol_error).encode('utf-8'))) + (b"error", str(protocol_error).encode("utf-8")) + ) self._send_response(failure) return except Exception as exception: # everything else: pass to client, flush, and quit log_exception_quietly() - self._send_response(request.FailedSmartServerResponse( - (b'error', str(exception).encode('utf-8')))) + self._send_response( + request.FailedSmartServerResponse( + (b"error", str(exception).encode("utf-8")) + ) + ) return if self._has_dispatched: @@ -195,7 +202,7 @@ def accept_bytes(self, data): # nothing to do.XXX: this routine should be a single state # machine too. 
self.unused_data += self.in_buffer - self.in_buffer = b'' + self.in_buffer = b"" return if self._body_decoder is None: self._body_decoder = LengthPrefixedBodyDecoder() @@ -210,16 +217,15 @@ def accept_bytes(self, data): if self.request.response is not None: self._send_response(self.request.response) self.unused_data = self.in_buffer - self.in_buffer = b'' + self.in_buffer = b"" else: if self.request.finished_reading: - raise AssertionError( - "no response and we have finished reading.") + raise AssertionError("no response and we have finished reading.") def _send_response(self, response): """Send a smart server response down the output stream.""" if self._finished: - raise AssertionError('response already sent') + raise AssertionError("response already sent") args = response.args body = response.body self._finished = True @@ -267,9 +273,9 @@ class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne): def _write_success_or_failure_prefix(self, response): """Write the protocol specific success/failure prefix.""" if response.is_successful(): - self._write_func(b'success\n') + self._write_func(b"success\n") else: - self._write_func(b'failed\n') + self._write_func(b"failed\n") def _write_protocol_version(self): r"""Write any prefixes this protocol requires. @@ -280,18 +286,17 @@ def _write_protocol_version(self): def _send_response(self, response): """Send a smart server response down the output stream.""" - if (self._finished): - raise AssertionError('response already sent') + if self._finished: + raise AssertionError("response already sent") self._finished = True self._write_protocol_version() self._write_success_or_failure_prefix(response) self._write_func(_encode_tuple(response.args)) if response.body is not None: if not isinstance(response.body, bytes): - raise AssertionError('body must be bytes') + raise AssertionError("body must be bytes") if response.body_stream is not None: - raise AssertionError( - 'body_stream and body cannot both be set') + raise AssertionError("body_stream and body cannot both be set") data = self._encode_bulk_data(response.body) self._write_func(data) elif response.body_stream is not None: @@ -299,23 +304,24 @@ def _send_response(self, response): def _send_stream(stream, write_func): - write_func(b'chunked\n') + write_func(b"chunked\n") _send_chunks(stream, write_func) - write_func(b'END\n') + write_func(b"END\n") def _send_chunks(stream, write_func): for chunk in stream: if isinstance(chunk, bytes): - data = f"{len(chunk):x}\n".encode('ascii') + chunk + data = f"{len(chunk):x}\n".encode("ascii") + chunk write_func(data) elif isinstance(chunk, request.FailedSmartServerResponse): - write_func(b'ERR\n') + write_func(b"ERR\n") _send_chunks(chunk.args, write_func) return else: raise errors.BzrError( - f'Chunks must be str or FailedSmartServerResponse, got {chunk!r}') + f"Chunks must be str or FailedSmartServerResponse, got {chunk!r}" + ) class _NeedMoreBytes(Exception): @@ -347,17 +353,21 @@ def __init__(self): self.finished_reading = False self._in_buffer_list = [] self._in_buffer_len = 0 - self.unused_data = b'' + self.unused_data = b"" self.bytes_left = None self._number_needed_bytes = None def _get_in_buffer(self): if len(self._in_buffer_list) == 1: return self._in_buffer_list[0] - in_buffer = b''.join(self._in_buffer_list) + in_buffer = b"".join(self._in_buffer_list) if len(in_buffer) != self._in_buffer_len: raise AssertionError( - "Length of buffer did not match expected value: {} != {}".format(*self._in_buffer_len), len(in_buffer)) + "Length of buffer 
did not match expected value: {} != {}".format( + *self._in_buffer_len + ), + len(in_buffer), + ) self._in_buffer_list = [in_buffer] return in_buffer @@ -371,8 +381,10 @@ def _get_in_bytes(self, count): """ # check if we can yield the bytes from just the first entry in our list if len(self._in_buffer_list) == 0: - raise AssertionError('Callers must be sure we have buffered bytes' - ' before calling _get_in_bytes') + raise AssertionError( + "Callers must be sure we have buffered bytes" + " before calling _get_in_bytes" + ) if len(self._in_buffer_list[0]) > count: return self._in_buffer_list[0][:count] # We can't yield it from the first buffer, so collapse all buffers, and @@ -457,7 +469,7 @@ def next_read_size(self): elif self.state_accept == self._state_accept_reading_unused: return 1 elif self.state_accept == self._state_accept_expecting_header: - return max(0, len('chunked\n') - self._in_buffer_len) + return max(0, len("chunked\n") - self._in_buffer_len) else: raise AssertionError(f"Impossible state: {self.state_accept!r}") @@ -469,14 +481,14 @@ def read_next_chunk(self): def _extract_line(self): in_buf = self._get_in_buffer() - pos = in_buf.find(b'\n') + pos = in_buf.find(b"\n") if pos == -1: # We haven't read a complete line yet, so request more bytes before # we continue. raise _NeedMoreBytes(1) line = in_buf[:pos] # Trim the prefix (including '\n' delimiter) from the _in_buffer. - self._set_in_buffer(in_buf[pos + 1:]) + self._set_in_buffer(in_buf[pos + 1 :]) return line def _finished(self): @@ -492,20 +504,19 @@ def _finished(self): def _state_accept_expecting_header(self): prefix = self._extract_line() - if prefix == b'chunked': + if prefix == b"chunked": self.state_accept = self._state_accept_expecting_length else: - raise errors.SmartProtocolError( - f'Bad chunked body header: "{prefix}"') + raise errors.SmartProtocolError(f'Bad chunked body header: "{prefix}"') def _state_accept_expecting_length(self): prefix = self._extract_line() - if prefix == b'ERR': + if prefix == b"ERR": self.error = True self.error_in_progress = [] self._state_accept_expecting_length() return - elif prefix == b'END': + elif prefix == b"END": # We've read the end-of-body marker. # Any further bytes are unused data, including the bytes left in # the _in_buffer. 
@@ -513,14 +524,14 @@ def _state_accept_expecting_length(self): return else: self.bytes_left = int(prefix, 16) - self.chunk_in_progress = b'' + self.chunk_in_progress = b"" self.state_accept = self._state_accept_reading_chunk def _state_accept_reading_chunk(self): in_buf = self._get_in_buffer() in_buffer_len = len(in_buf) - self.chunk_in_progress += in_buf[:self.bytes_left] - self._set_in_buffer(in_buf[self.bytes_left:]) + self.chunk_in_progress += in_buf[: self.bytes_left] + self._set_in_buffer(in_buf[self.bytes_left :]) self.bytes_left -= in_buffer_len if self.bytes_left <= 0: # Finished with chunk @@ -544,8 +555,8 @@ def __init__(self): _StatefulDecoder.__init__(self) self.state_accept = self._state_accept_expecting_length self.state_read = self._state_read_no_data - self._body = b'' - self._trailer_buffer = b'' + self._body = b"" + self._trailer_buffer = b"" def next_read_size(self): if self.bytes_left is not None: @@ -569,11 +580,11 @@ def read_pending_data(self): def _state_accept_expecting_length(self): in_buf = self._get_in_buffer() - pos = in_buf.find(b'\n') + pos = in_buf.find(b"\n") if pos == -1: return self.bytes_left = int(in_buf[:pos]) - self._set_in_buffer(in_buf[pos + 1:]) + self._set_in_buffer(in_buf[pos + 1 :]) self.state_accept = self._state_accept_reading_body self.state_read = self._state_read_body_buffer @@ -585,8 +596,8 @@ def _state_accept_reading_body(self): if self.bytes_left <= 0: # Finished with body if self.bytes_left != 0: - self._trailer_buffer = self._body[self.bytes_left:] - self._body = self._body[:self.bytes_left] + self._trailer_buffer = self._body[self.bytes_left :] + self._body = self._body[: self.bytes_left] self.bytes_left = None self.state_accept = self._state_accept_reading_trailer @@ -595,8 +606,8 @@ def _state_accept_reading_trailer(self): self._set_in_buffer(None) # TODO: what if the trailer does not match "done\n"? Should this raise # a ProtocolViolation exception? - if self._trailer_buffer.startswith(b'done\n'): - self.unused_data = self._trailer_buffer[len(b'done\n'):] + if self._trailer_buffer.startswith(b"done\n"): + self.unused_data = self._trailer_buffer[len(b"done\n") :] self.state_accept = self._state_accept_reading_unused self.finished_reading = True @@ -605,16 +616,17 @@ def _state_accept_reading_unused(self): self._set_in_buffer(None) def _state_read_no_data(self): - return b'' + return b"" def _state_read_body_buffer(self): result = self._body - self._body = b'' + self._body = b"" return result -class SmartClientRequestProtocolOne(SmartProtocolBase, Requester, - message.ResponseHandler): +class SmartClientRequestProtocolOne( + SmartProtocolBase, Requester, message.ResponseHandler +): """The client-side protocol for smart version 1.""" def __init__(self, request): @@ -633,10 +645,10 @@ def set_headers(self, headers): self._headers = dict(headers) def call(self, *args): - if debug.debug_flag_enabled('hpss'): - mutter('hpss call: %s', repr(args)[1:-1]) - if getattr(self._request._medium, 'base', None) is not None: - mutter(' (to %s)', self._request._medium.base) + if debug.debug_flag_enabled("hpss"): + mutter("hpss call: %s", repr(args)[1:-1]) + if getattr(self._request._medium, "base", None) is not None: + mutter(" (to %s)", self._request._medium.base) self._request_start_time = osutils.perf_counter() self._write_args(args) self._request.finished_writing() @@ -647,15 +659,14 @@ def call_with_body_bytes(self, args, body): After calling this, call read_response_tuple to find the result out. 
""" - if debug.debug_flag_enabled('hpss'): - mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20]) - if getattr(self._request._medium, '_path', None) is not None: - mutter(' (to %s)', - self._request._medium._path) - mutter(' %d bytes', len(body)) + if debug.debug_flag_enabled("hpss"): + mutter("hpss call w/body: %s (%r...)", repr(args)[1:-1], body[:20]) + if getattr(self._request._medium, "_path", None) is not None: + mutter(" (to %s)", self._request._medium._path) + mutter(" %d bytes", len(body)) self._request_start_time = osutils.perf_counter() - if debug.debug_flag_enabled('hpssdetail'): - mutter('hpss body content: %s', body) + if debug.debug_flag_enabled("hpssdetail"): + mutter("hpss body content: %s", body) self._write_args(args) bytes = self._encode_bulk_data(body) self._request.accept_bytes(bytes) @@ -668,19 +679,18 @@ def call_with_body_readv_array(self, args, body): The body is encoded with one line per readv offset pair. The numbers in each pair are separated by a comma, and no trailing \n is emitted. """ - if debug.debug_flag_enabled('hpss'): - mutter('hpss call w/readv: %s', repr(args)[1:-1]) - if getattr(self._request._medium, '_path', None) is not None: - mutter(' (to %s)', - self._request._medium._path) + if debug.debug_flag_enabled("hpss"): + mutter("hpss call w/readv: %s", repr(args)[1:-1]) + if getattr(self._request._medium, "_path", None) is not None: + mutter(" (to %s)", self._request._medium._path) self._request_start_time = osutils.perf_counter() self._write_args(args) readv_bytes = self._serialise_offsets(body) bytes = self._encode_bulk_data(readv_bytes) self._request.accept_bytes(bytes) self._request.finished_writing() - if debug.debug_flag_enabled('hpss'): - mutter(' %d bytes in readv request', len(readv_bytes)) + if debug.debug_flag_enabled("hpss"): + mutter(" %d bytes in readv request", len(readv_bytes)) self._last_verb = args[0] def call_with_body_stream(self, args, stream): @@ -702,14 +712,16 @@ def cancel_read_body(self): def _read_response_tuple(self): result = self._recv_tuple() - if debug.debug_flag_enabled('hpss'): + if debug.debug_flag_enabled("hpss"): if self._request_start_time is not None: - mutter(' result: %6.3fs %s', - osutils.perf_counter() - self._request_start_time, - repr(result)[1:-1]) + mutter( + " result: %6.3fs %s", + osutils.perf_counter() - self._request_start_time, + repr(result)[1:-1], + ) self._request_start_time = None else: - mutter(' result: %s', repr(result)[1:-1]) + mutter(" result: %s", repr(result)[1:-1]) return result def read_response_tuple(self, expect_body=False): @@ -731,24 +743,24 @@ def _raise_args_if_error(self, result_tuple): # returned in response to existing version 1 smart requests. Responses # starting with these codes are always "failed" responses. 
v1_error_codes = [ - b'norepository', - b'NoSuchFile', - b'FileExists', - b'DirectoryNotEmpty', - b'ShortReadvError', - b'UnicodeEncodeError', - b'UnicodeDecodeError', - b'ReadOnlyError', - b'nobranch', - b'NoSuchRevision', - b'nosuchrevision', - b'LockContention', - b'UnlockableTransport', - b'LockFailed', - b'TokenMismatch', - b'ReadError', - b'PermissionDenied', - ] + b"norepository", + b"NoSuchFile", + b"FileExists", + b"DirectoryNotEmpty", + b"ShortReadvError", + b"UnicodeEncodeError", + b"UnicodeDecodeError", + b"ReadOnlyError", + b"nobranch", + b"NoSuchRevision", + b"nosuchrevision", + b"LockContention", + b"UnlockableTransport", + b"LockFailed", + b"TokenMismatch", + b"ReadError", + b"PermissionDenied", + ] if result_tuple[0] in v1_error_codes: self._request.finished_reading() raise errors.ErrorFromSmartServer(result_tuple) @@ -762,10 +774,15 @@ def _response_is_unknown_method(self, result_tuple): :param verb: The verb used in that call. :raises: UnexpectedSmartServerResponse """ - if (result_tuple == (b'error', b"Generic bzr smart protocol error: " - b"bad request '" + self._last_verb + b"'") - or result_tuple == (b'error', b"Generic bzr smart protocol error: " - b"bad request u'%s'" % self._last_verb)): + if result_tuple == ( + b"error", + b"Generic bzr smart protocol error: " + b"bad request '" + self._last_verb + b"'", + ) or result_tuple == ( + b"error", + b"Generic bzr smart protocol error: " + b"bad request u'%s'" % self._last_verb, + ): # The response will have no body, so we've finished reading. self._request.finished_reading() raise errors.UnknownSmartMethod(self._last_verb) @@ -782,17 +799,19 @@ def read_body_bytes(self, count=-1): while not _body_decoder.finished_reading: bytes = self._request.read_bytes(_body_decoder.next_read_size()) - if bytes == b'': + if bytes == b"": # end of file encountered reading from server raise ConnectionResetError( - "Connection lost while reading response body.") + "Connection lost while reading response body." + ) _body_decoder.accept_bytes(bytes) self._request.finished_reading() self._body_buffer = BytesIO(_body_decoder.read_pending_data()) # XXX: TODO check the trailer result. 
- if debug.debug_flag_enabled('hpss'): - mutter(' %d body bytes read', - len(self._body_buffer.getvalue())) + if debug.debug_flag_enabled("hpss"): + mutter( + " %d body bytes read", len(self._body_buffer.getvalue()) + ) return self._body_buffer.read(count) def _recv_tuple(self): @@ -801,11 +820,11 @@ def _recv_tuple(self): def query_version(self): """Return protocol version number of the server.""" - self.call(b'hello') + self.call(b"hello") resp = self.read_response_tuple() - if resp == (b'ok', b'1'): + if resp == (b"ok", b"1"): return 1 - elif resp == (b'ok', b'2'): + elif resp == (b"ok", b"2"): return 2 else: raise errors.SmartProtocolError(f"bad response {resp!r}") @@ -843,18 +862,17 @@ def read_response_tuple(self, expect_body=False): response_status = self._request.read_line() result = SmartClientRequestProtocolOne._read_response_tuple(self) self._response_is_unknown_method(result) - if response_status == b'success\n': + if response_status == b"success\n": self.response_status = True if not expect_body: self._request.finished_reading() return result - elif response_status == b'failed\n': + elif response_status == b"failed\n": self.response_status = False self._request.finished_reading() raise errors.ErrorFromSmartServer(result) else: - raise errors.SmartProtocolError( - f'bad protocol status {response_status!r}') + raise errors.SmartProtocolError(f"bad protocol status {response_status!r}") def _write_protocol_version(self): """Write any prefixes this protocol requires. @@ -870,32 +888,34 @@ def read_streamed_body(self): _body_decoder = ChunkedBodyDecoder() while not _body_decoder.finished_reading: bytes = self._request.read_bytes(_body_decoder.next_read_size()) - if bytes == b'': + if bytes == b"": # end of file encountered reading from server raise ConnectionResetError( - "Connection lost while reading streamed body.") + "Connection lost while reading streamed body." + ) _body_decoder.accept_bytes(bytes) for body_bytes in iter(_body_decoder.read_next_chunk, None): - if debug.debug_flag_enabled('hpss') and isinstance(body_bytes, str): - mutter(' %d byte chunk read', - len(body_bytes)) + if debug.debug_flag_enabled("hpss") and isinstance(body_bytes, str): + mutter(" %d byte chunk read", len(body_bytes)) yield body_bytes self._request.finished_reading() -def build_server_protocol_three(backing_transport, write_func, - root_client_path, jail_root=None): +def build_server_protocol_three( + backing_transport, write_func, root_client_path, jail_root=None +): request_handler = request.SmartServerRequestHandler( - backing_transport, commands=request.request_handlers, - root_client_path=root_client_path, jail_root=jail_root) + backing_transport, + commands=request.request_handlers, + root_client_path=root_client_path, + jail_root=jail_root, + ) responder = ProtocolThreeResponder(write_func) - message_handler = message.ConventionalRequestHandler( - request_handler, responder) + message_handler = message.ConventionalRequestHandler(request_handler, responder) return ProtocolThreeDecoder(message_handler) class ProtocolThreeDecoder(_StatefulDecoder): - response_marker = RESPONSE_VERSION_THREE request_marker = REQUEST_VERSION_THREE @@ -930,7 +950,7 @@ def accept_bytes(self, bytes): # The state machine is ready to continue decoding, but the # exception has interrupted the loop that runs the state machine. # So we call accept_bytes again to restart it. - self.accept_bytes(b'') + self.accept_bytes(b"") except Exception as exception: # The decoder itself has raised an exception. 
We cannot continue # decoding. @@ -952,7 +972,7 @@ def _extract_length_prefixed_bytes(self): # A length prefix by itself is 4 bytes, and we don't even have that # many yet. raise _NeedMoreBytes(4) - (length,) = struct.unpack('!L', self._get_in_bytes(4)) + (length,) = struct.unpack("!L", self._get_in_bytes(4)) end_of_bytes = 4 + length if self._in_buffer_len < end_of_bytes: # We haven't yet read as many bytes as the length-prefix says there @@ -969,7 +989,9 @@ def _extract_prefixed_bencoded_data(self): try: decoded = bdecode_as_tuple(prefixed_bytes) except ValueError as e: - raise errors.SmartProtocolError(f'Bytes {prefixed_bytes!r} not bencoded') from e + raise errors.SmartProtocolError( + f"Bytes {prefixed_bytes!r} not bencoded" + ) from e return decoded def _extract_single_byte(self): @@ -999,14 +1021,13 @@ def _state_accept_expecting_protocol_version(self): raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE)) if not in_buf.startswith(MESSAGE_VERSION_THREE): raise errors.UnexpectedProtocolVersionMarker(in_buf) - self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):]) + self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE) :]) self.state_accept = self._state_accept_expecting_headers def _state_accept_expecting_headers(self): decoded = self._extract_prefixed_bencoded_data() if not isinstance(decoded, dict): - raise errors.SmartProtocolError( - f'Header object {decoded!r} is not a dict') + raise errors.SmartProtocolError(f"Header object {decoded!r} is not a dict") self.state_accept = self._state_accept_expecting_message_part try: self.message_handler.headers_received(decoded) @@ -1015,17 +1036,18 @@ def _state_accept_expecting_headers(self): def _state_accept_expecting_message_part(self): message_part_kind = self._extract_single_byte() - if message_part_kind == b'o': + if message_part_kind == b"o": self.state_accept = self._state_accept_expecting_one_byte - elif message_part_kind == b's': + elif message_part_kind == b"s": self.state_accept = self._state_accept_expecting_structure - elif message_part_kind == b'b': + elif message_part_kind == b"b": self.state_accept = self._state_accept_expecting_bytes - elif message_part_kind == b'e': + elif message_part_kind == b"e": self.done() else: raise errors.SmartProtocolError( - f'Bad message kind byte: {message_part_kind!r}') + f"Bad message kind byte: {message_part_kind!r}" + ) def _state_accept_expecting_one_byte(self): byte = self._extract_single_byte() @@ -1083,7 +1105,6 @@ def next_read_size(self): class _ProtocolThreeEncoder: - response_marker = request_marker = MESSAGE_VERSION_THREE BUFFER_SIZE = 1024 * 1024 # 1 MiB buffer before flushing @@ -1105,7 +1126,7 @@ def _write_func(self, bytes): def flush(self): if self._buf: - self._real_write_func(b''.join(self._buf)) + self._real_write_func(b"".join(self._buf)) del self._buf[:] self._buf_len = 0 @@ -1113,57 +1134,55 @@ def _serialise_offsets(self, offsets): """Serialise a readv offset list.""" txt = [] for start, length in offsets: - txt.append(b'%d,%d' % (start, length)) - return b'\n'.join(txt) + txt.append(b"%d,%d" % (start, length)) + return b"\n".join(txt) def _write_protocol_version(self): self._write_func(MESSAGE_VERSION_THREE) def _write_prefixed_bencode(self, structure): bytes = bencode(structure) - self._write_func(struct.pack('!L', len(bytes))) + self._write_func(struct.pack("!L", len(bytes))) self._write_func(bytes) def _write_headers(self, headers): self._write_prefixed_bencode(headers) def _write_structure(self, args): - self._write_func(b's') + self._write_func(b"s") utf8_args = [] 
for arg in args: if isinstance(arg, str): - utf8_args.append(arg.encode('utf8')) + utf8_args.append(arg.encode("utf8")) else: utf8_args.append(arg) self._write_prefixed_bencode(utf8_args) def _write_end(self): - self._write_func(b'e') + self._write_func(b"e") self.flush() def _write_prefixed_body(self, bytes): - self._write_func(b'b') - self._write_func(struct.pack('!L', len(bytes))) + self._write_func(b"b") + self._write_func(struct.pack("!L", len(bytes))) self._write_func(bytes) def _write_chunked_body_start(self): - self._write_func(b'oC') + self._write_func(b"oC") def _write_error_status(self): - self._write_func(b'oE') + self._write_func(b"oE") def _write_success_status(self): - self._write_func(b'oS') + self._write_func(b"oS") class ProtocolThreeResponder(_ProtocolThreeEncoder): - def __init__(self, write_func): _ProtocolThreeEncoder.__init__(self, write_func) self.response_sent = False - self._headers = { - b'Software version': breezy.__version__.encode('utf-8')} - if debug.debug_flag_enabled('hpss'): + self._headers = {b"Software version": breezy.__version__.encode("utf-8")} + if debug.debug_flag_enabled("hpss"): self._thread_id = _thread.get_ident() self._response_start_time = None @@ -1171,41 +1190,42 @@ def _trace(self, action, message, extra_bytes=None, include_time=False): if self._response_start_time is None: self._response_start_time = osutils.perf_counter() if include_time: - t = f'{osutils.perf_counter() - self._response_start_time:5.3f}s ' + t = f"{osutils.perf_counter() - self._response_start_time:5.3f}s " else: - t = '' + t = "" if extra_bytes is None: - extra = '' + extra = "" else: - extra = ' ' + repr(extra_bytes[:40]) + extra = " " + repr(extra_bytes[:40]) if len(extra) > 33: - extra = extra[:29] + extra[-1] + '...' - mutter('%12s: [%s] %s%s%s' - % (action, self._thread_id, t, message, extra)) + extra = extra[:29] + extra[-1] + "..." + mutter("%12s: [%s] %s%s%s" % (action, self._thread_id, t, message, extra)) def send_error(self, exception): if self.response_sent: raise AssertionError( - f"send_error({exception}) called, but response already sent.") + f"send_error({exception}) called, but response already sent." + ) if isinstance(exception, errors.UnknownSmartMethod): failure = request.FailedSmartServerResponse( - (b'UnknownMethod', exception.verb)) + (b"UnknownMethod", exception.verb) + ) self.send_response(failure) return - if debug.debug_flag_enabled('hpss'): - self._trace('error', str(exception)) + if debug.debug_flag_enabled("hpss"): + self._trace("error", str(exception)) self.response_sent = True self._write_protocol_version() self._write_headers(self._headers) self._write_error_status() - self._write_structure( - (b'error', str(exception).encode('utf-8', 'replace'))) + self._write_structure((b"error", str(exception).encode("utf-8", "replace"))) self._write_end() def send_response(self, response): if self.response_sent: raise AssertionError( - f"send_response({response!r}) called, but response already sent.") + f"send_response({response!r}) called, but response already sent." 
+ ) self.response_sent = True self._write_protocol_version() self._write_headers(self._headers) @@ -1213,14 +1233,18 @@ def send_response(self, response): self._write_success_status() else: self._write_error_status() - if debug.debug_flag_enabled('hpss'): - self._trace('response', repr(response.args)) + if debug.debug_flag_enabled("hpss"): + self._trace("response", repr(response.args)) self._write_structure(response.args) if response.body is not None: self._write_prefixed_body(response.body) - if debug.debug_flag_enabled('hpss'): - self._trace('body', f'{len(response.body)} bytes', - response.body, include_time=True) + if debug.debug_flag_enabled("hpss"): + self._trace( + "body", + f"{len(response.body)} bytes", + response.body, + include_time=True, + ) elif response.body_stream is not None: count = num_bytes = 0 first_chunk = None @@ -1241,19 +1265,24 @@ def send_response(self, response): first_chunk = chunk self._write_prefixed_body(chunk) self.flush() - if debug.debug_flag_enabled('hpssdetail'): + if debug.debug_flag_enabled("hpssdetail"): # Not worth timing separately, as _write_func is # actually buffered - self._trace('body chunk', - f'{len(chunk)} bytes', - chunk, suppress_time=True) - if debug.debug_flag_enabled('hpss'): - self._trace('body stream', - '%d bytes %d chunks' % (num_bytes, count), - first_chunk) + self._trace( + "body chunk", + f"{len(chunk)} bytes", + chunk, + suppress_time=True, + ) + if debug.debug_flag_enabled("hpss"): + self._trace( + "body stream", + "%d bytes %d chunks" % (num_bytes, count), + first_chunk, + ) self._write_end() - if debug.debug_flag_enabled('hpss'): - self._trace('response end', '', include_time=True) + if debug.debug_flag_enabled("hpss"): + self._trace("response end", "", include_time=True) def _iter_with_errors(iterable): @@ -1289,14 +1318,13 @@ def _iter_with_errors(iterable): except (KeyboardInterrupt, SystemExit): raise except Exception: - mutter('_iter_with_errors caught error') + mutter("_iter_with_errors caught error") log_exception_quietly() yield sys.exc_info(), None return class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester): - def __init__(self, medium_request): _ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes) self._medium_request = medium_request @@ -1307,11 +1335,11 @@ def set_headers(self, headers): self._headers = headers.copy() def call(self, *args): - if debug.debug_flag_enabled('hpss'): - mutter('hpss call: %s', repr(args)[1:-1]) - base = getattr(self._medium_request._medium, 'base', None) + if debug.debug_flag_enabled("hpss"): + mutter("hpss call: %s", repr(args)[1:-1]) + base = getattr(self._medium_request._medium, "base", None) if base is not None: - mutter(' (to %s)', base) + mutter(" (to %s)", base) self._request_start_time = osutils.perf_counter() self._write_protocol_version() self._write_headers(self._headers) @@ -1324,12 +1352,12 @@ def call_with_body_bytes(self, args, body): After calling this, call read_response_tuple to find the result out. 
""" - if debug.debug_flag_enabled('hpss'): - mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20]) - path = getattr(self._medium_request._medium, '_path', None) + if debug.debug_flag_enabled("hpss"): + mutter("hpss call w/body: %s (%r...)", repr(args)[1:-1], body[:20]) + path = getattr(self._medium_request._medium, "_path", None) if path is not None: - mutter(' (to %s)', path) - mutter(' %d bytes', len(body)) + mutter(" (to %s)", path) + mutter(" %d bytes", len(body)) self._request_start_time = osutils.perf_counter() self._write_protocol_version() self._write_headers(self._headers) @@ -1344,28 +1372,28 @@ def call_with_body_readv_array(self, args, body): The body is encoded with one line per readv offset pair. The numbers in each pair are separated by a comma, and no trailing \n is emitted. """ - if debug.debug_flag_enabled('hpss'): - mutter('hpss call w/readv: %s', repr(args)[1:-1]) - path = getattr(self._medium_request._medium, '_path', None) + if debug.debug_flag_enabled("hpss"): + mutter("hpss call w/readv: %s", repr(args)[1:-1]) + path = getattr(self._medium_request._medium, "_path", None) if path is not None: - mutter(' (to %s)', path) + mutter(" (to %s)", path) self._request_start_time = osutils.perf_counter() self._write_protocol_version() self._write_headers(self._headers) self._write_structure(args) readv_bytes = self._serialise_offsets(body) - if debug.debug_flag_enabled('hpss'): - mutter(' %d bytes in readv request', len(readv_bytes)) + if debug.debug_flag_enabled("hpss"): + mutter(" %d bytes in readv request", len(readv_bytes)) self._write_prefixed_body(readv_bytes) self._write_end() self._medium_request.finished_writing() def call_with_body_stream(self, args, stream): - if debug.debug_flag_enabled('hpss'): - mutter('hpss call w/body stream: %r', args) - path = getattr(self._medium_request._medium, '_path', None) + if debug.debug_flag_enabled("hpss"): + mutter("hpss call w/body stream: %r", args) + path = getattr(self._medium_request._medium, "_path", None) if path is not None: - mutter(' (to %s)', path) + mutter(" (to %s)", path) self._request_start_time = osutils.perf_counter() self.body_stream_started = False self._write_protocol_version() @@ -1384,7 +1412,7 @@ def call_with_body_stream(self, args, stream): self._write_error_status() # Currently the client unconditionally sends ('error',) as the # error args. - self._write_structure((b'error',)) + self._write_structure((b"error",)) self._write_end() self._medium_request.finished_writing() (exc_type, exc_val, exc_tb) = exc_info diff --git a/breezy/bzr/smart/repository.py b/breezy/bzr/smart/repository.py index f1c814dc04..013374664a 100644 --- a/breezy/bzr/smart/repository.py +++ b/breezy/bzr/smart/repository.py @@ -81,21 +81,21 @@ def recreate_search(self, repository, search_bytes, discard_excess=False): recreate_search trusts that clients will look for missing things they expected and get it from elsewhere. 
""" - if search_bytes == b'everything': + if search_bytes == b"everything": return vf_search.EverythingResult(repository), None - lines = search_bytes.split(b'\n') - if lines[0] == b'ancestry-of': + lines = search_bytes.split(b"\n") + if lines[0] == b"ancestry-of": heads = lines[1:] search_result = vf_search.PendingAncestryResult(heads, repository) return search_result, None - elif lines[0] == b'search': - return self.recreate_search_from_recipe(repository, lines[1:], - discard_excess=discard_excess) + elif lines[0] == b"search": + return self.recreate_search_from_recipe( + repository, lines[1:], discard_excess=discard_excess + ) else: - return (None, FailedSmartServerResponse((b'BadSearch',))) + return (None, FailedSmartServerResponse((b"BadSearch",))) - def recreate_search_from_recipe(self, repository, lines, - discard_excess=False): + def recreate_search_from_recipe(self, repository, lines, discard_excess=False): """Recreate a specific revision search (vs a from-tip search). :param discard_excess: If True, and the search refers to data we don't @@ -103,12 +103,11 @@ def recreate_search_from_recipe(self, repository, lines, recreate_search trusts that clients will look for missing things they expected and get it from elsewhere. """ - start_keys = set(lines[0].split(b' ')) - exclude_keys = set(lines[1].split(b' ')) - revision_count = int(lines[2].decode('ascii')) + start_keys = set(lines[0].split(b" ")) + exclude_keys = set(lines[1].split(b" ")) + revision_count = int(lines[2].decode("ascii")) with repository.lock_read(): - search = repository.get_graph()._make_breadth_first_searcher( - start_keys) + search = repository.get_graph()._make_breadth_first_searcher(start_keys) while True: try: next_revs = next(search) @@ -116,15 +115,16 @@ def recreate_search_from_recipe(self, repository, lines, break search.stop_searching_any(exclude_keys.intersection(next_revs)) (started_keys, excludes, included_keys) = search.get_state() - if (not discard_excess and len(included_keys) != revision_count): + if not discard_excess and len(included_keys) != revision_count: # we got back a different amount of data than expected, this # gets reported as NoSuchRevision, because less revisions # indicates missing revisions, and more should never happen as # the excludes list considers ghosts and ensures that ghost # filling races are not a problem. 
- return (None, FailedSmartServerResponse((b'NoSuchRevision',))) - search_result = vf_search.SearchResult(started_keys, excludes, - len(included_keys), included_keys) + return (None, FailedSmartServerResponse((b"NoSuchRevision",))) + search_result = vf_search.SearchResult( + started_keys, excludes, len(included_keys), included_keys + ) return (search_result, None) @@ -142,7 +142,7 @@ class SmartServerRepositoryBreakLock(SmartServerRepositoryRequest): def do_repository_request(self, repository): repository.break_lock() - return SuccessfulSmartServerResponse((b'ok', )) + return SuccessfulSmartServerResponse((b"ok",)) _lsprof_count = 0 @@ -183,8 +183,14 @@ def do_body(self, body_bytes): with repository.lock_read(): return self._do_repository_request(body_bytes) - def _expand_requested_revs(self, repo_graph, revision_ids, client_seen_revs, - include_missing, max_size=65536): + def _expand_requested_revs( + self, + repo_graph, + revision_ids, + client_seen_revs, + include_missing, + max_size=65536, + ): result = {} queried_revs = set() estimator = zlib_util.ZLibEstimator(max_size) @@ -209,22 +215,27 @@ def _expand_requested_revs(self, repo_graph, revision_ids, client_seen_revs, missing_rev = True encoded_id = b"missing:" + revision_id parents = [] - if (revision_id not in client_seen_revs - and (not missing_rev or include_missing)): + if revision_id not in client_seen_revs and ( + not missing_rev or include_missing + ): # Client does not have this revision, give it to it. # add parents to the result result[encoded_id] = parents # Approximate the serialized cost of this revision_id. - line = encoded_id + b' ' + b' '.join(parents) + b'\n' + line = encoded_id + b" " + b" ".join(parents) + b"\n" estimator.add_content(line) # get all the directly asked for parents, and then flesh out to # 64K (compressed) or so. We do one level of depth at a time to # stay in sync with the client. The 250000 magic number is # estimated compression ratio taken from bzr.dev itself. 
if self.no_extra_results or (first_loop_done and estimator.full()): - trace.mutter('size: %d, z_size: %d' - % (estimator._uncompressed_size_added, - estimator._compressed_size_added)) + trace.mutter( + "size: %d, z_size: %d" + % ( + estimator._uncompressed_size_added, + estimator._compressed_size_added, + ) + ) next_revs = set() break # don't query things we've already queried @@ -235,12 +246,11 @@ def _expand_requested_revs(self, repo_graph, revision_ids, client_seen_revs, def _do_repository_request(self, body_bytes): repository = self._repository revision_ids = set(self._revision_ids) - include_missing = b'include-missing:' in revision_ids + include_missing = b"include-missing:" in revision_ids if include_missing: - revision_ids.remove(b'include-missing:') - body_lines = body_bytes.split(b'\n') - search_result, error = self.recreate_search_from_recipe( - repository, body_lines) + revision_ids.remove(b"include-missing:") + body_lines = body_bytes.split(b"\n") + search_result, error = self.recreate_search_from_recipe(repository, body_lines) if error is not None: return error # TODO might be nice to start up the search again; but thats not @@ -250,21 +260,20 @@ def _do_repository_request(self, body_bytes): client_seen_revs.difference_update(revision_ids) repo_graph = repository.get_graph() - result = self._expand_requested_revs(repo_graph, revision_ids, - client_seen_revs, include_missing) + result = self._expand_requested_revs( + repo_graph, revision_ids, client_seen_revs, include_missing + ) # sorting trivially puts lexographically similar revision ids together. # Compression FTW. lines = [] for revision, parents in sorted(result.items()): - lines.append(b' '.join((revision, ) + tuple(parents))) + lines.append(b" ".join((revision,) + tuple(parents))) - return SuccessfulSmartServerResponse( - (b'ok', ), bz2.compress(b'\n'.join(lines))) + return SuccessfulSmartServerResponse((b"ok",), bz2.compress(b"\n".join(lines))) class SmartServerRepositoryGetRevisionGraph(SmartServerRepositoryReadLocked): - def do_readlocked_repository_request(self, repository, revision_id): """Return the result of repository.get_revision_graph(revision_id). @@ -292,45 +301,43 @@ def do_readlocked_repository_request(self, repository, revision_id): # Note that we return an empty body, rather than omitting the body. # This way the client knows that it can always expect to find a body # in the response for this method, even in the error case. - return FailedSmartServerResponse((b'nosuchrevision', revision_id), b'') + return FailedSmartServerResponse((b"nosuchrevision", revision_id), b"") for revision, parents in revision_graph.items(): - lines.append(b' '.join((revision, ) + tuple(parents))) + lines.append(b" ".join((revision,) + tuple(parents))) - return SuccessfulSmartServerResponse((b'ok', ), b'\n'.join(lines)) + return SuccessfulSmartServerResponse((b"ok",), b"\n".join(lines)) class SmartServerRepositoryGetRevIdForRevno(SmartServerRepositoryReadLocked): - - def do_readlocked_repository_request(self, repository, revno, - known_pair): + def do_readlocked_repository_request(self, repository, revno, known_pair): """Find the revid for a given revno, given a known revno/revid pair. New in 1.17. 
""" try: - found_flag, result = repository.get_rev_id_for_revno( - revno, known_pair) + found_flag, result = repository.get_rev_id_for_revno(revno, known_pair) except errors.NoSuchRevision as err: if err.revision != known_pair[1]: raise AssertionError( - 'get_rev_id_for_revno raised RevisionNotPresent for ' - 'non-initial revision: ' + err.revision) from err - return FailedSmartServerResponse( - (b'nosuchrevision', err.revision)) + "get_rev_id_for_revno raised RevisionNotPresent for " + "non-initial revision: " + err.revision + ) from err + return FailedSmartServerResponse((b"nosuchrevision", err.revision)) except errors.RevnoOutOfBounds as e: return FailedSmartServerResponse( - (b'revno-outofbounds', e.revno, e.minimum, e.maximum)) + (b"revno-outofbounds", e.revno, e.minimum, e.maximum) + ) if found_flag: - return SuccessfulSmartServerResponse((b'ok', result)) + return SuccessfulSmartServerResponse((b"ok", result)) else: earliest_revno, earliest_revid = result return SuccessfulSmartServerResponse( - (b'history-incomplete', earliest_revno, earliest_revid)) + (b"history-incomplete", earliest_revno, earliest_revid) + ) class SmartServerRepositoryGetSerializerFormat(SmartServerRepositoryRequest): - def do_repository_request(self, repository): """Return the serializer format for this repository. @@ -340,11 +347,10 @@ def do_repository_request(self, repository): :return: A smart server response (b'ok', FORMAT) """ serializer = repository.get_serializer_format() - return SuccessfulSmartServerResponse((b'ok', serializer)) + return SuccessfulSmartServerResponse((b"ok", serializer)) class SmartServerRequestHasRevision(SmartServerRepositoryRequest): - def do_repository_request(self, repository, revision_id): """Return ok if a specific revision is in the repository at path. @@ -354,14 +360,12 @@ def do_repository_request(self, repository, revision_id): present. ('no', ) if it is missing. """ if repository.has_revision(revision_id): - return SuccessfulSmartServerResponse((b'yes', )) + return SuccessfulSmartServerResponse((b"yes",)) else: - return SuccessfulSmartServerResponse((b'no', )) + return SuccessfulSmartServerResponse((b"no",)) -class SmartServerRequestHasSignatureForRevisionId( - SmartServerRepositoryRequest): - +class SmartServerRequestHasSignatureForRevisionId(SmartServerRepositoryRequest): def do_repository_request(self, repository, revision_id): """Return ok if a signature is present for a revision. @@ -375,16 +379,14 @@ def do_repository_request(self, repository, revision_id): """ try: if repository.has_signature_for_revision_id(revision_id): - return SuccessfulSmartServerResponse((b'yes', )) + return SuccessfulSmartServerResponse((b"yes",)) else: - return SuccessfulSmartServerResponse((b'no', )) + return SuccessfulSmartServerResponse((b"no",)) except errors.NoSuchRevision: - return FailedSmartServerResponse( - (b'nosuchrevision', revision_id)) + return FailedSmartServerResponse((b"nosuchrevision", revision_id)) class SmartServerRepositoryGatherStats(SmartServerRepositoryRequest): - def do_repository_request(self, repository, revid, committers): """Return the result of repository.gather_stats(). 
@@ -400,37 +402,35 @@ def do_repository_request(self, repository, revid, committers): But containing only fields returned by the gather_stats() call """ - if revid == b'': + if revid == b"": decoded_revision_id = None else: decoded_revision_id = revid - if committers == b'yes': + if committers == b"yes": decoded_committers = True else: decoded_committers = None try: - stats = repository.gather_stats(decoded_revision_id, - decoded_committers) + stats = repository.gather_stats(decoded_revision_id, decoded_committers) except errors.NoSuchRevision: - return FailedSmartServerResponse((b'nosuchrevision', revid)) + return FailedSmartServerResponse((b"nosuchrevision", revid)) - body = b'' - if 'committers' in stats: - body += b'committers: %d\n' % stats['committers'] - if 'firstrev' in stats: - body += b'firstrev: %.3f %d\n' % stats['firstrev'] - if 'latestrev' in stats: - body += b'latestrev: %.3f %d\n' % stats['latestrev'] - if 'revisions' in stats: - body += b'revisions: %d\n' % stats['revisions'] - if 'size' in stats: - body += b'size: %d\n' % stats['size'] + body = b"" + if "committers" in stats: + body += b"committers: %d\n" % stats["committers"] + if "firstrev" in stats: + body += b"firstrev: %.3f %d\n" % stats["firstrev"] + if "latestrev" in stats: + body += b"latestrev: %.3f %d\n" % stats["latestrev"] + if "revisions" in stats: + body += b"revisions: %d\n" % stats["revisions"] + if "size" in stats: + body += b"size: %d\n" % stats["size"] - return SuccessfulSmartServerResponse((b'ok', ), body) + return SuccessfulSmartServerResponse((b"ok",), body) -class SmartServerRepositoryGetRevisionSignatureText( - SmartServerRepositoryRequest): +class SmartServerRepositoryGetRevisionSignatureText(SmartServerRepositoryRequest): """Return the signature text of a revision. New in 2.5. @@ -446,13 +446,11 @@ def do_repository_request(self, repository, revision_id): try: text = repository.get_signature_text(revision_id) except errors.NoSuchRevision as err: - return FailedSmartServerResponse( - (b'nosuchrevision', err.revision)) - return SuccessfulSmartServerResponse((b'ok', ), text) + return FailedSmartServerResponse((b"nosuchrevision", err.revision)) + return SuccessfulSmartServerResponse((b"ok",), text) class SmartServerRepositoryIsShared(SmartServerRepositoryRequest): - def do_repository_request(self, repository): """Return the result of repository.is_shared(). @@ -461,13 +459,12 @@ def do_repository_request(self, repository): shared, and ('no', ) if it is not. """ if repository.is_shared(): - return SuccessfulSmartServerResponse((b'yes', )) + return SuccessfulSmartServerResponse((b"yes",)) else: - return SuccessfulSmartServerResponse((b'no', )) + return SuccessfulSmartServerResponse((b"no",)) class SmartServerRepositoryMakeWorkingTrees(SmartServerRepositoryRequest): - def do_repository_request(self, repository): """Return the result of repository.make_working_trees(). @@ -478,36 +475,33 @@ def do_repository_request(self, repository): working trees, and ('no', ) if it is not. """ if repository.make_working_trees(): - return SuccessfulSmartServerResponse((b'yes', )) + return SuccessfulSmartServerResponse((b"yes",)) else: - return SuccessfulSmartServerResponse((b'no', )) + return SuccessfulSmartServerResponse((b"no",)) class SmartServerRepositoryLockWrite(SmartServerRepositoryRequest): - - def do_repository_request(self, repository, token=b''): + def do_repository_request(self, repository, token=b""): # XXX: this probably should not have a token. 
- if token == b'': + if token == b"": token = None try: token = repository.lock_write(token=token).repository_token except errors.LockContention: - return FailedSmartServerResponse((b'LockContention',)) + return FailedSmartServerResponse((b"LockContention",)) except errors.UnlockableTransport: - return FailedSmartServerResponse((b'UnlockableTransport',)) + return FailedSmartServerResponse((b"UnlockableTransport",)) except errors.LockFailed as e: - return FailedSmartServerResponse((b'LockFailed', - str(e.lock), str(e.why))) + return FailedSmartServerResponse((b"LockFailed", str(e.lock), str(e.why))) if token is not None: repository.leave_lock_in_place() repository.unlock() if token is None: - token = b'' - return SuccessfulSmartServerResponse((b'ok', token)) + token = b"" + return SuccessfulSmartServerResponse((b"ok", token)) class SmartServerRepositoryGetStream(SmartServerRepositoryRequest): - def do_repository_request(self, repository, to_network_name): """Get a stream for inserting into a to_format repository. @@ -525,7 +519,8 @@ def do_repository_request(self, repository, to_network_name): self._to_format = network_format_registry.get(to_network_name) if self._should_fake_unknown(): return FailedSmartServerResponse( - (b'UnknownMethod', b'Repository.get_stream')) + (b"UnknownMethod", b"Repository.get_stream") + ) return None # Signal that we want a body. def _should_fake_unknown(self): @@ -549,10 +544,12 @@ def _should_fake_unknown(self): if not from_format.supports_chks: # Source not CHK: that's ok return False - if (to_format.supports_chks + if ( + to_format.supports_chks and from_format.repository_class is to_format.repository_class - and from_format._revision_serializer == to_format._revision_serializer - and from_format._inventory_serializer == to_format._inventory_serializer): + and from_format._revision_serializer == to_format._revision_serializer + and from_format._inventory_serializer == to_format._inventory_serializer + ): # Source is CHK, but target matches: that's ok # (e.g. 2a->2a, or CHK2->2a) return False @@ -564,8 +561,9 @@ def do_body(self, body_bytes): repository = self._repository repository.lock_read() try: - search_result, error = self.recreate_search(repository, body_bytes, - discard_excess=True) + search_result, error = self.recreate_search( + repository, body_bytes, discard_excess=True + ) if error is not None: repository.unlock() return error @@ -577,8 +575,9 @@ def do_body(self, body_bytes): repository.unlock() finally: raise - return SuccessfulSmartServerResponse((b'ok',), - body_stream=self.body_stream(stream, repository)) + return SuccessfulSmartServerResponse( + (b"ok",), body_stream=self.body_stream(stream, repository) + ) def body_stream(self, stream, repository): byte_stream = _stream_to_byte_stream(stream, repository._format) @@ -588,7 +587,7 @@ def body_stream(self, stream, repository): # This shouldn't be able to happen, but as we don't buffer # everything it can in theory happen. 
repository.unlock() - yield FailedSmartServerResponse((b'NoSuchRevision', e.revision_id)) + yield FailedSmartServerResponse((b"NoSuchRevision", e.revision_id)) else: repository.unlock() @@ -611,12 +610,12 @@ def _stream_to_byte_stream(stream, src_format): """Convert a record stream to a self delimited byte stream.""" pack_writer = pack.ContainerSerialiser() yield pack_writer.begin() - yield pack_writer.bytes_record(src_format.network_name(), b'') + yield pack_writer.bytes_record(src_format.network_name(), b"") for substream_type, substream in stream: for record in substream: - if record.storage_kind in ('chunked', 'fulltext'): + if record.storage_kind in ("chunked", "fulltext"): serialised = record_to_fulltext_bytes(record) - elif record.storage_kind == 'absent': + elif record.storage_kind == "absent": raise ValueError(f"Absent factory for {record.key}") else: serialised = record.get_bytes_as(record.storage_kind) @@ -624,7 +623,9 @@ def _stream_to_byte_stream(stream, src_format): # Some streams embed the whole stream into the wire # representation of the first record, which means that # later records have no wire representation: we skip them. - yield pack_writer.bytes_record(serialised, [(substream_type.encode('ascii'),)]) + yield pack_writer.bytes_record( + serialised, [(substream_type.encode("ascii"),)] + ) yield pack_writer.end() @@ -678,7 +679,7 @@ def iter_substream_bytes(self): self.first_bytes = None for record in self.iter_pack_records: record_names, record_bytes = record - record_name, = record_names + (record_name,) = record_names substream_type = record_name[0] if substream_type != self.current_type: # end of a substream, seed the next substream. @@ -689,11 +690,12 @@ def iter_substream_bytes(self): def record_stream(self): """Yield substream_type, substream from the byte stream.""" + def wrap_and_count(pb, rc, substream): """Yield records from stream while showing progress.""" counter = 0 if rc: - if self.current_type != 'revisions' and self.key_count != 0: + if self.current_type != "revisions" and self.key_count != 0: # As we know the number of revisions now (in self.key_count) # we can setup and use record_counter (rc). if not rc.is_initialized(): @@ -702,9 +704,9 @@ def wrap_and_count(pb, rc, substream): if rc: if rc.is_initialized() and counter == rc.STEP: rc.increment(counter) - pb.update('Estimate', rc.current, rc.max) + pb.update("Estimate", rc.current, rc.max) counter = 0 - if self.current_type == 'revisions': + if self.current_type == "revisions": # Total records is proportional to number of revs # to fetch. With remote, we used self.key_count to # track the number of revs. Once we have the revs @@ -712,7 +714,7 @@ def wrap_and_count(pb, rc, substream): # from 'Estimating..' to 'Estimate' above. self.key_count += 1 if counter == rc.STEP: - pb.update('Estimating..', self.key_count) + pb.update("Estimating..", self.key_count) counter = 0 counter += 1 yield record @@ -723,15 +725,17 @@ def wrap_and_count(pb, rc, substream): try: # Make and consume sub generators, one per substream type: while self.first_bytes is not None: - substream = NetworkRecordStream( - self.iter_substream_bytes()) + substream = NetworkRecordStream(self.iter_substream_bytes()) # after substream is fully consumed, self.current_type is set # to the next type, and self.first_bytes is set to the matching # bytes. 
- yield self.current_type.decode('ascii'), wrap_and_count(pb, rc, substream) + yield ( + self.current_type.decode("ascii"), + wrap_and_count(pb, rc, substream), + ) finally: if rc: - pb.update('Done', rc.max, rc.max) + pb.update("Done", rc.max, rc.max) def seed_state(self): """Prepare the _ByteStreamDecoder to decode from the pack stream.""" @@ -758,15 +762,14 @@ def _byte_stream_to_stream(byte_stream, record_counter=None): class SmartServerRepositoryUnlock(SmartServerRepositoryRequest): - def do_repository_request(self, repository, token): try: repository.lock_write(token=token) except errors.TokenMismatch: - return FailedSmartServerResponse((b'TokenMismatch',)) + return FailedSmartServerResponse((b"TokenMismatch",)) repository.dont_leave_lock_in_place() repository.unlock() - return SuccessfulSmartServerResponse((b'ok',)) + return SuccessfulSmartServerResponse((b"ok",)) class SmartServerRepositoryGetPhysicalLockStatus(SmartServerRepositoryRequest): @@ -777,20 +780,19 @@ class SmartServerRepositoryGetPhysicalLockStatus(SmartServerRepositoryRequest): def do_repository_request(self, repository): if repository.get_physical_lock_status(): - return SuccessfulSmartServerResponse((b'yes', )) + return SuccessfulSmartServerResponse((b"yes",)) else: - return SuccessfulSmartServerResponse((b'no', )) + return SuccessfulSmartServerResponse((b"no",)) class SmartServerRepositorySetMakeWorkingTrees(SmartServerRepositoryRequest): - def do_repository_request(self, repository, str_bool_new_value): - if str_bool_new_value == b'True': + if str_bool_new_value == b"True": new_value = True else: new_value = False repository.set_make_working_trees(new_value) - return SuccessfulSmartServerResponse((b'ok',)) + return SuccessfulSmartServerResponse((b"ok",)) class SmartServerRepositoryTarball(SmartServerRepositoryRequest): @@ -808,13 +810,13 @@ class SmartServerRepositoryTarball(SmartServerRepositoryRequest): def do_repository_request(self, repository, compression): tmp_dirname, tmp_repo = self._copy_to_tempdir(repository) try: - controldir_name = tmp_dirname + '/.bzr' + controldir_name = tmp_dirname + "/.bzr" return self._tarfile_response(controldir_name, compression) finally: osutils.rmtree(tmp_dirname) def _copy_to_tempdir(self, from_repo): - tmp_dirname = tempfile.mkdtemp(prefix='tmpbzrclone') + tmp_dirname = tempfile.mkdtemp(prefix="tmpbzrclone") tmp_bzrdir = from_repo.controldir._format.initialize(tmp_dirname) tmp_repo = from_repo._format.initialize(tmp_bzrdir) from_repo.copy_content_into(tmp_repo) @@ -825,15 +827,17 @@ def _tarfile_response(self, tmp_dirname, compression): self._tarball_of_dir(tmp_dirname, compression, temp.file) # all finished; write the tempfile out to the network temp.seek(0) - return SuccessfulSmartServerResponse((b'ok',), temp.read()) + return SuccessfulSmartServerResponse((b"ok",), temp.read()) # FIXME: Don't read the whole thing into memory here; rather stream # it out from the file onto the network. mbp 20070411 def _tarball_of_dir(self, dirname, compression, ofile): import tarfile + filename = os.path.basename(ofile.name) - with tarfile.open(fileobj=ofile, name=filename, - mode='w|' + compression) as tarball: + with tarfile.open( + fileobj=ofile, name=filename, mode="w|" + compression + ) as tarball: # The tarball module only accepts ascii names, and (i guess) # packs them with their 8bit names. 
We know all the files # within the repository have ASCII names so the should be safe @@ -841,9 +845,9 @@ def _tarball_of_dir(self, dirname, compression, ofile): dirname = dirname.encode(sys.getfilesystemencoding()) # python's tarball module includes the whole path by default so # override it - if not dirname.endswith('.bzr'): + if not dirname.endswith(".bzr"): raise ValueError(dirname) - tarball.add(dirname, '.bzr') # recursive by default + tarball.add(dirname, ".bzr") # recursive by default class SmartServerRepositoryInsertStreamLocked(SmartServerRepositoryRequest): @@ -862,8 +866,7 @@ def do_repository_request(self, repository, resume_tokens, lock_token): self.do_insert_stream_request(repository, resume_tokens) def do_insert_stream_request(self, repository, resume_tokens): - tokens = [token.decode('utf-8') - for token in resume_tokens.split(b' ') if token] + tokens = [token.decode("utf-8") for token in resume_tokens.split(b" ") if token] self.tokens = tokens self.repository = repository self.queue = queue.Queue() @@ -875,10 +878,10 @@ def do_chunk(self, body_stream_chunk): def _inserter_thread(self): try: - src_format, stream = _byte_stream_to_stream( - self.blocking_byte_stream()) + src_format, stream = _byte_stream_to_stream(self.blocking_byte_stream()) self.insert_result = self.repository._get_sink().insert_stream( - stream, src_format, self.tokens) + stream, src_format, self.tokens + ) self.insert_ok = True except BaseException: self.insert_exception = sys.exc_info() @@ -907,14 +910,16 @@ def do_end(self): # bzip needed? missing keys should typically be a small set. # Should this be a streaming body response ? missing_keys = sorted( - [(entry[0].encode('utf-8'),) + entry[1:] for entry in missing_keys]) - bytes = bencode.bencode(( - [token.encode('utf-8') for token in write_group_tokens], missing_keys)) + [(entry[0].encode("utf-8"),) + entry[1:] for entry in missing_keys] + ) + bytes = bencode.bencode( + ([token.encode("utf-8") for token in write_group_tokens], missing_keys) + ) self.repository.unlock() - return SuccessfulSmartServerResponse((b'missing-basis', bytes)) + return SuccessfulSmartServerResponse((b"missing-basis", bytes)) else: self.repository.unlock() - return SuccessfulSmartServerResponse((b'ok', )) + return SuccessfulSmartServerResponse((b"ok",)) class SmartServerRepositoryInsertStream_1_19(SmartServerRepositoryInsertStreamLocked): @@ -931,7 +936,8 @@ class SmartServerRepositoryInsertStream_1_19(SmartServerRepositoryInsertStreamLo def do_repository_request(self, repository, resume_tokens, lock_token=None): """StreamSink.insert_stream for a remote repository.""" SmartServerRepositoryInsertStreamLocked.do_repository_request( - self, repository, resume_tokens, lock_token) + self, repository, resume_tokens, lock_token + ) class SmartServerRepositoryInsertStream(SmartServerRepositoryInsertStreamLocked): @@ -956,8 +962,9 @@ class SmartServerRepositoryAddSignatureText(SmartServerRepositoryRequest): New in 2.5. """ - def do_repository_request(self, repository, lock_token, revision_id, - *write_group_tokens): + def do_repository_request( + self, repository, lock_token, revision_id, *write_group_tokens + ): """Add a revision signature text. 
:param repository: Repository to operate on @@ -967,8 +974,9 @@ def do_repository_request(self, repository, lock_token, revision_id, """ self._lock_token = lock_token self._revision_id = revision_id - self._write_group_tokens = [token.decode( - 'utf-8') for token in write_group_tokens] + self._write_group_tokens = [ + token.decode("utf-8") for token in write_group_tokens + ] return None def do_body(self, body_bytes): @@ -981,12 +989,13 @@ def do_body(self, body_bytes): with self._repository.lock_write(token=self._lock_token): self._repository.resume_write_group(self._write_group_tokens) try: - self._repository.add_signature_text(self._revision_id, - body_bytes) + self._repository.add_signature_text(self._revision_id, body_bytes) finally: new_write_group_tokens = self._repository.suspend_write_group() return SuccessfulSmartServerResponse( - (b'ok', ) + tuple([token.encode('utf-8') for token in new_write_group_tokens])) + (b"ok",) + + tuple([token.encode("utf-8") for token in new_write_group_tokens]) + ) class SmartServerRepositoryStartWriteGroup(SmartServerRepositoryRequest): @@ -1002,8 +1011,8 @@ def do_repository_request(self, repository, lock_token): try: tokens = repository.suspend_write_group() except errors.UnsuspendableWriteGroup: - return FailedSmartServerResponse((b'UnsuspendableWriteGroup',)) - return SuccessfulSmartServerResponse((b'ok', tokens)) + return FailedSmartServerResponse((b"UnsuspendableWriteGroup",)) + return SuccessfulSmartServerResponse((b"ok", tokens)) class SmartServerRepositoryCommitWriteGroup(SmartServerRepositoryRequest): @@ -1012,17 +1021,21 @@ class SmartServerRepositoryCommitWriteGroup(SmartServerRepositoryRequest): New in 2.5. """ - def do_repository_request(self, repository, lock_token, - write_group_tokens): + def do_repository_request(self, repository, lock_token, write_group_tokens): """Commit a write group.""" with repository.lock_write(token=lock_token): try: repository.resume_write_group( - [token.decode('utf-8') for token in write_group_tokens]) + [token.decode("utf-8") for token in write_group_tokens] + ) except errors.UnresumableWriteGroup as e: return FailedSmartServerResponse( - (b'UnresumableWriteGroup', [token.encode('utf-8') for token - in e.write_groups], e.reason.encode('utf-8'))) + ( + b"UnresumableWriteGroup", + [token.encode("utf-8") for token in e.write_groups], + e.reason.encode("utf-8"), + ) + ) try: repository.commit_write_group() except: @@ -1030,7 +1043,7 @@ def do_repository_request(self, repository, lock_token, # FIXME JRV 2011-11-19: What if the write_group_tokens # have changed? 
raise - return SuccessfulSmartServerResponse((b'ok', )) + return SuccessfulSmartServerResponse((b"ok",)) class SmartServerRepositoryAbortWriteGroup(SmartServerRepositoryRequest): @@ -1044,14 +1057,18 @@ def do_repository_request(self, repository, lock_token, write_group_tokens): with repository.lock_write(token=lock_token): try: repository.resume_write_group( - [token.decode('utf-8') for token in write_group_tokens]) + [token.decode("utf-8") for token in write_group_tokens] + ) except errors.UnresumableWriteGroup as e: return FailedSmartServerResponse( - (b'UnresumableWriteGroup', - [token.encode('utf-8') for token in e.write_groups], - e.reason.encode('utf-8'))) + ( + b"UnresumableWriteGroup", + [token.encode("utf-8") for token in e.write_groups], + e.reason.encode("utf-8"), + ) + ) repository.abort_write_group() - return SuccessfulSmartServerResponse((b'ok', )) + return SuccessfulSmartServerResponse((b"ok",)) class SmartServerRepositoryCheckWriteGroup(SmartServerRepositoryRequest): @@ -1065,15 +1082,19 @@ def do_repository_request(self, repository, lock_token, write_group_tokens): with repository.lock_write(token=lock_token): try: repository.resume_write_group( - [token.decode('utf-8') for token in write_group_tokens]) + [token.decode("utf-8") for token in write_group_tokens] + ) except errors.UnresumableWriteGroup as e: return FailedSmartServerResponse( - (b'UnresumableWriteGroup', - [token.encode('utf-8') for token in e.write_groups], - e.reason.encode('utf-8'))) + ( + b"UnresumableWriteGroup", + [token.encode("utf-8") for token in e.write_groups], + e.reason.encode("utf-8"), + ) + ) else: repository.suspend_write_group() - return SuccessfulSmartServerResponse((b'ok', )) + return SuccessfulSmartServerResponse((b"ok",)) class SmartServerRepositoryAllRevisionIds(SmartServerRepositoryRequest): @@ -1084,7 +1105,7 @@ class SmartServerRepositoryAllRevisionIds(SmartServerRepositoryRequest): def do_repository_request(self, repository): revids = repository.all_revision_ids() - return SuccessfulSmartServerResponse((b"ok", ), b"\n".join(revids)) + return SuccessfulSmartServerResponse((b"ok",), b"\n".join(revids)) class SmartServerRepositoryReconcile(SmartServerRepositoryRequest): @@ -1097,8 +1118,7 @@ def do_repository_request(self, repository, lock_token): try: repository.lock_write(token=lock_token) except errors.TokenLockingNotSupported: - return FailedSmartServerResponse( - (b'TokenLockingNotSupported', )) + return FailedSmartServerResponse((b"TokenLockingNotSupported",)) try: reconciler = repository.reconcile() finally: @@ -1106,8 +1126,8 @@ def do_repository_request(self, repository, lock_token): body = [ b"garbage_inventories: %d\n" % reconciler.garbage_inventories, b"inconsistent_parents: %d\n" % reconciler.inconsistent_parents, - ] - return SuccessfulSmartServerResponse((b'ok', ), b"".join(body)) + ] + return SuccessfulSmartServerResponse((b"ok",), b"".join(body)) class SmartServerRepositoryPack(SmartServerRepositoryRequest): @@ -1119,7 +1139,7 @@ class SmartServerRepositoryPack(SmartServerRepositoryRequest): def do_repository_request(self, repository, lock_token, clean_obsolete_packs): self._repository = repository self._lock_token = lock_token - if clean_obsolete_packs == b'True': + if clean_obsolete_packs == b"True": self._clean_obsolete_packs = True else: self._clean_obsolete_packs = False @@ -1132,7 +1152,9 @@ def do_body(self, body_bytes): hint = body_bytes.splitlines() with self._repository.lock_write(token=self._lock_token): self._repository.pack(hint, self._clean_obsolete_packs) 
- return SuccessfulSmartServerResponse((b"ok", ), ) + return SuccessfulSmartServerResponse( + (b"ok",), + ) class SmartServerRepositoryIterFilesBytes(SmartServerRepositoryRequest): @@ -1160,17 +1182,21 @@ def body_stream(self, repository, desired_files): text_keys = {} for i, key in enumerate(desired_files): text_keys[key] = i - for record in repository.texts.get_record_stream(text_keys, - 'unordered', True): + for record in repository.texts.get_record_stream( + text_keys, "unordered", True + ): identifier = text_keys[record.key] - if record.storage_kind == 'absent': - yield b"absent\0%s\0%s\0%d\n" % (record.key[0], - record.key[1], identifier) + if record.storage_kind == "absent": + yield b"absent\0%s\0%s\0%d\n" % ( + record.key[0], + record.key[1], + identifier, + ) # FIXME: Way to abort early? continue yield b"ok\0%d\n" % identifier compressor = zlib.compressobj() - for bytes in record.iter_bytes_as('chunked'): + for bytes in record.iter_bytes_as("chunked"): data = compressor.compress(bytes) if data: yield data @@ -1179,10 +1205,10 @@ def body_stream(self, repository, desired_files): yield data def do_body(self, body_bytes): - desired_files = [ - tuple(l.split(b"\0")) for l in body_bytes.splitlines()] - return SuccessfulSmartServerResponse((b'ok', ), - body_stream=self.body_stream(self._repository, desired_files)) + desired_files = [tuple(l.split(b"\0")) for l in body_bytes.splitlines()] + return SuccessfulSmartServerResponse( + (b"ok",), body_stream=self.body_stream(self._repository, desired_files) + ) def do_repository_request(self, repository): # Signal that we want a body @@ -1210,16 +1236,18 @@ def do_repository_request(self, repository): def do_body(self, body_bytes): revision_ids = body_bytes.split(b"\n") return SuccessfulSmartServerResponse( - (b'ok', self._repository.get_serializer_format()), - body_stream=self.body_stream(self._repository, revision_ids)) + (b"ok", self._repository.get_serializer_format()), + body_stream=self.body_stream(self._repository, revision_ids), + ) def body_stream(self, repository, revision_ids): with self._repository.lock_read(): for record in repository.revisions.get_record_stream( - [(revid,) for revid in revision_ids], 'unordered', True): - if record.storage_kind == 'absent': + [(revid,) for revid in revision_ids], "unordered", True + ): + if record.storage_kind == "absent": continue - yield zlib.compress(record.get_bytes_as('fulltext')) + yield zlib.compress(record.get_bytes_as("fulltext")) class SmartServerRepositoryGetInventories(SmartServerRepositoryRequest): @@ -1237,47 +1265,51 @@ class SmartServerRepositoryGetInventories(SmartServerRepositoryRequest): """ def _inventory_delta_stream(self, repository, ordering, revids): - prev_inv = _mod_inventory.Inventory(root_id=None, - revision_id=_mod_revision.NULL_REVISION) + prev_inv = _mod_inventory.Inventory( + root_id=None, revision_id=_mod_revision.NULL_REVISION + ) serializer = inventory_delta.InventoryDeltaSerializer( - repository.supports_rich_root(), - repository._format.supports_tree_reference) + repository.supports_rich_root(), repository._format.supports_tree_reference + ) with repository.lock_read(): for inv, _revid in repository._iter_inventories(revids, ordering): if inv is None: continue inv_delta = inv._make_delta(prev_inv) lines = serializer.delta_to_lines( - prev_inv.revision_id, inv.revision_id, inv_delta) + prev_inv.revision_id, inv.revision_id, inv_delta + ) yield ChunkedContentFactory( - inv.revision_id, None, None, lines, - chunks_are_lines=True) + inv.revision_id, None, None, 
lines, chunks_are_lines=True + ) prev_inv = inv def body_stream(self, repository, ordering, revids): - substream = self._inventory_delta_stream(repository, - ordering, revids) - return _stream_to_byte_stream([('inventory-deltas', substream)], - repository._format) + substream = self._inventory_delta_stream(repository, ordering, revids) + return _stream_to_byte_stream( + [("inventory-deltas", substream)], repository._format + ) def do_body(self, body_bytes): - return SuccessfulSmartServerResponse((b'ok', ), - body_stream=self.body_stream(self._repository, self._ordering, - body_bytes.splitlines())) + return SuccessfulSmartServerResponse( + (b"ok",), + body_stream=self.body_stream( + self._repository, self._ordering, body_bytes.splitlines() + ), + ) def do_repository_request(self, repository, ordering): - ordering = ordering.decode('ascii') - if ordering == 'unordered': + ordering = ordering.decode("ascii") + if ordering == "unordered": # inventory deltas for a topologically sorted stream # are likely to be smaller - ordering = 'topological' + ordering = "topological" self._ordering = ordering # Signal that we want a body return None class SmartServerRepositoryGetStreamForMissingKeys(SmartServerRepositoryRequest): - def do_repository_request(self, repository, to_network_name): """Get a stream for missing keys. @@ -1289,7 +1321,8 @@ def do_repository_request(self, repository, to_network_name): self._to_format = network_format_registry.get(to_network_name) except KeyError: return FailedSmartServerResponse( - (b'UnknownFormat', b'repository', to_network_name)) + (b"UnknownFormat", b"repository", to_network_name) + ) return None # Signal that we want a body. def do_body(self, body_bytes): @@ -1298,9 +1331,9 @@ def do_body(self, body_bytes): try: source = repository._get_source(self._to_format) keys = [] - for entry in body_bytes.split(b'\n'): - (kind, revid) = entry.split(b'\t') - keys.append((kind.decode('utf-8'), revid)) + for entry in body_bytes.split(b"\n"): + (kind, revid) = entry.split(b"\t") + keys.append((kind.decode("utf-8"), revid)) stream = source.get_stream_for_missing_keys(keys) except Exception: try: @@ -1308,8 +1341,9 @@ def do_body(self, body_bytes): repository.unlock() finally: raise - return SuccessfulSmartServerResponse((b'ok',), - body_stream=self.body_stream(stream, repository)) + return SuccessfulSmartServerResponse( + (b"ok",), body_stream=self.body_stream(stream, repository) + ) def body_stream(self, stream, repository): byte_stream = _stream_to_byte_stream(stream, repository._format) @@ -1319,15 +1353,15 @@ def body_stream(self, stream, repository): # This shouldn't be able to happen, but as we don't buffer # everything it can in theory happen. repository.unlock() - yield FailedSmartServerResponse((b'NoSuchRevision', e.revision_id)) + yield FailedSmartServerResponse((b"NoSuchRevision", e.revision_id)) else: repository.unlock() class SmartServerRepositoryRevisionArchive(SmartServerRepositoryRequest): - - def do_repository_request(self, repository, revision_id, format, name, - root, subdir=None, force_mtime=None): + def do_repository_request( + self, repository, revision_id, format, name, root, subdir=None, force_mtime=None + ): """Stream an archive file for a specific revision. :param repository: The repository to stream from. 
:param revision_id: Revision for which to export the tree @@ -1338,15 +1372,21 @@ def do_repository_request(self, repository, revision_id, format, name, """ tree = repository.revision_tree(revision_id) if subdir is not None: - subdir = subdir.decode('utf-8') + subdir = subdir.decode("utf-8") if root is not None: - root = root.decode('utf-8') - name = name.decode('utf-8') - return SuccessfulSmartServerResponse((b'ok',), - body_stream=self.body_stream( - tree, format.decode( - 'utf-8'), os.path.basename(name), root, subdir, - force_mtime)) + root = root.decode("utf-8") + name = name.decode("utf-8") + return SuccessfulSmartServerResponse( + (b"ok",), + body_stream=self.body_stream( + tree, + format.decode("utf-8"), + os.path.basename(name), + root, + subdir, + force_mtime, + ), + ) def body_stream(self, tree, format, name, root, subdir=None, force_mtime=None): with tree.lock_read(): @@ -1354,9 +1394,9 @@ def body_stream(self, tree, format, name, root, subdir=None, force_mtime=None): class SmartServerRepositoryAnnotateFileRevision(SmartServerRepositoryRequest): - - def do_repository_request(self, repository, revision_id, tree_path, - file_id=None, default_revision=None): + def do_repository_request( + self, repository, revision_id, tree_path, file_id=None, default_revision=None + ): """Stream an archive file for a specific revision. :param repository: The repository to stream from. @@ -1367,6 +1407,7 @@ def do_repository_request(self, repository, revision_id, tree_path, """ tree = repository.revision_tree(revision_id) with tree.lock_read(): - body = bencode.bencode(list(tree.annotate_iter( - tree_path.decode('utf-8'), default_revision))) - return SuccessfulSmartServerResponse((b'ok',), body=body) + body = bencode.bencode( + list(tree.annotate_iter(tree_path.decode("utf-8"), default_revision)) + ) + return SuccessfulSmartServerResponse((b"ok",), body=body) diff --git a/breezy/bzr/smart/request.py b/breezy/bzr/smart/request.py index b22caf9374..6f54af2850 100644 --- a/breezy/bzr/smart/request.py +++ b/breezy/bzr/smart/request.py @@ -43,7 +43,6 @@ class DisabledMethod(errors.InternalBzrError): - _fmt = "The smart server method '%(class_name)s' is disabled." def __init__(self, class_name): @@ -53,12 +52,14 @@ def __init__(self, class_name): def _install_hook(): from breezy.bzr import bzrdir + bzrdir.BzrDir.hooks.install_named_hook( - 'pre_open', _pre_open_hook, 'checking server jail') + "pre_open", _pre_open_hook, "checking server jail" + ) def _pre_open_hook(transport): - allowed_transports = getattr(jail_info, 'transports', None) + allowed_transports = getattr(jail_info, "transports", None) if allowed_transports is None: return abspath = transport.base @@ -86,7 +87,7 @@ class SmartServerRequest: # XXX: rename this class to BaseSmartServerRequestHandler ? A request # *handler* is a different concept to the request. - def __init__(self, backing_transport, root_client_path='/', jail_root=None): + def __init__(self, backing_transport, root_client_path="/", jail_root=None): """Constructor. 
:param backing_transport: the base transport to be used when performing @@ -104,10 +105,10 @@ def __init__(self, backing_transport, root_client_path='/', jail_root=None): jail_root = backing_transport self._jail_root = jail_root if root_client_path is not None: - if not root_client_path.startswith('/'): - root_client_path = '/' + root_client_path - if not root_client_path.endswith('/'): - root_client_path += '/' + if not root_client_path.startswith("/"): + root_client_path = "/" + root_client_path + if not root_client_path.endswith("/"): + root_client_path += "/" self._root_client_path = root_client_path self._body_chunks = [] @@ -143,8 +144,8 @@ def do_body(self, body_bytes): Must return a SmartServerResponse. """ - if body_bytes != b'': - raise errors.SmartProtocolError('Request does not expect a body') + if body_bytes != b"": + raise errors.SmartProtocolError("Request does not expect a body") def do_chunk(self, chunk_bytes): """Called with each body chunk if the request has a streamed body. @@ -155,7 +156,7 @@ def do_chunk(self, chunk_bytes): def do_end(self): """Called when the end of the request has been received.""" - body_bytes = b''.join(self._body_chunks) + body_bytes = b"".join(self._body_chunks) self._body_chunks = None return self.do_body(body_bytes) @@ -176,20 +177,20 @@ def translate_client_path(self, client_path): (unlike the untranslated client_path, which must not be used with the backing transport). """ - client_path = client_path.decode('utf-8') + client_path = client_path.decode("utf-8") if self._root_client_path is None: # no translation necessary! return client_path - if not client_path.startswith('/'): - client_path = '/' + client_path - if client_path + '/' == self._root_client_path: - return '.' + if not client_path.startswith("/"): + client_path = "/" + client_path + if client_path + "/" == self._root_client_path: + return "." if client_path.startswith(self._root_client_path): - path = client_path[len(self._root_client_path):] - relpath = urlutils.joinpath('/', path) - if not relpath.startswith('/'): + path = client_path[len(self._root_client_path) :] + relpath = urlutils.joinpath("/", path) + if not relpath.startswith("/"): raise ValueError(relpath) - return urlutils.escape('.' + relpath) + return urlutils.escape("." + relpath) else: raise errors.PathNotChild(client_path, self._root_client_path) @@ -221,17 +222,18 @@ def __init__(self, args, body=None, body_stream=None): """ self.args = args if body is not None and body_stream is not None: - raise errors.BzrError( - "'body' and 'body_stream' are mutually exclusive.") + raise errors.BzrError("'body' and 'body_stream' are mutually exclusive.") self.body = body self.body_stream = body_stream def __eq__(self, other): if other is None: return False - return (other.args == self.args - and other.body == self.body - and other.body_stream is self.body_stream) + return ( + other.args == self.args + and other.body == self.body + and other.body_stream is self.body_stream + ) def __repr__(self): return f"<{self.__class__.__name__} args={self.args!r} body={self.body!r}>" @@ -270,8 +272,7 @@ class SmartServerRequestHandler: # TODO: Better way of representing the body for commands that take it, # and allow it to be streamed into the server. - def __init__(self, backing_transport, commands, root_client_path, - jail_root=None): + def __init__(self, backing_transport, commands, root_client_path, jail_root=None): """Constructor. :param backing_transport: a Transport to handle requests for. 
@@ -287,7 +288,7 @@ def __init__(self, backing_transport, commands, root_client_path, self.response = None self.finished_reading = False self._command = None - if debug.debug_flag_enabled('hpss'): + if debug.debug_flag_enabled("hpss"): self._request_start_time = osutils.perf_counter() self._thread_id = get_ident() @@ -297,17 +298,16 @@ def _trace(self, action, message, extra_bytes=None, include_time=False): # that just putting it in a helper doesn't help a lot. And some state # is taken from the instance. if include_time: - t = f'{osutils.perf_counter() - self._request_start_time:5.3f}s ' + t = f"{osutils.perf_counter() - self._request_start_time:5.3f}s " else: - t = '' + t = "" if extra_bytes is None: - extra = '' + extra = "" else: - extra = ' ' + repr(extra_bytes[:40]) + extra = " " + repr(extra_bytes[:40]) if len(extra) > 33: - extra = extra[:29] + extra[-1] + '...' - trace.mutter('%12s: [%s] %s%s%s' - % (action, self._thread_id, t, message, extra)) + extra = extra[:29] + extra[-1] + "..." + trace.mutter("%12s: [%s] %s%s%s" % (action, self._thread_id, t, message, extra)) def accept_body(self, bytes): """Accept body data.""" @@ -315,17 +315,16 @@ def accept_body(self, bytes): # no active command object, so ignore the event. return self._run_handler_code(self._command.do_chunk, (bytes,), {}) - if debug.debug_flag_enabled('hpss'): - self._trace('accept body', - f'{len(bytes)} bytes', bytes) + if debug.debug_flag_enabled("hpss"): + self._trace("accept body", f"{len(bytes)} bytes", bytes) def end_of_body(self): """No more body data will be received.""" self._run_handler_code(self._command.do_end, (), {}) # cannot read after this. self.finished_reading = True - if debug.debug_flag_enabled('hpss'): - self._trace('end of body', '', include_time=True) + if debug.debug_flag_enabled("hpss"): + self._trace("end of body", "", include_time=True) def _run_handler_code(self, callable, args, kwargs): """Run some handler specific code 'callable'. @@ -360,8 +359,8 @@ def _call_converting_errors(self, callable, args, kwargs): def headers_received(self, headers): # Just a no-op at the moment. - if debug.debug_flag_enabled('hpss'): - self._trace('headers', repr(headers)) + if debug.debug_flag_enabled("hpss"): + self._trace("headers", repr(headers)) def args_received(self, args): cmd = args[0] @@ -369,19 +368,20 @@ def args_received(self, args): try: command = self._commands.get(cmd) except LookupError as e: - if debug.debug_flag_enabled('hpss'): - self._trace('hpss unknown request', - cmd, repr(args)[1:-1]) + if debug.debug_flag_enabled("hpss"): + self._trace("hpss unknown request", cmd, repr(args)[1:-1]) raise errors.UnknownSmartMethod(cmd) from e - if debug.debug_flag_enabled('hpss'): + if debug.debug_flag_enabled("hpss"): from . import vfs + if issubclass(command, vfs.VfsRequest): - action = 'hpss vfs req' + action = "hpss vfs req" else: - action = 'hpss request' - self._trace(action, f'{cmd} {repr(args)[1:-1]}') + action = "hpss request" + self._trace(action, f"{cmd} {repr(args)[1:-1]}") self._command = command( - self._backing_transport, self._root_client_path, self._jail_root) + self._backing_transport, self._root_client_path, self._jail_root + ) self._run_handler_code(self._command.execute, args, {}) def end_received(self): @@ -389,8 +389,8 @@ def end_received(self): # no active command object, so ignore the event. 
return self._run_handler_code(self._command.do_end, (), {}) - if debug.debug_flag_enabled('hpss'): - self._trace('end', '', include_time=True) + if debug.debug_flag_enabled("hpss"): + self._trace("end", "", include_time=True) def post_body_error_received(self, error_args): # Just a no-op at the moment. @@ -399,31 +399,44 @@ def post_body_error_received(self, error_args): def _translate_error(err): if isinstance(err, _mod_transport.NoSuchFile): - return (b'NoSuchFile', err.path.encode('utf-8')) + return (b"NoSuchFile", err.path.encode("utf-8")) elif isinstance(err, _mod_transport.FileExists): - return (b'FileExists', err.path.encode('utf-8')) + return (b"FileExists", err.path.encode("utf-8")) elif isinstance(err, errors.DirectoryNotEmpty): - return (b'DirectoryNotEmpty', err.path.encode('utf-8')) + return (b"DirectoryNotEmpty", err.path.encode("utf-8")) elif isinstance(err, errors.IncompatibleRepositories): - return (b'IncompatibleRepositories', str(err.source), str(err.target), - str(err.details)) + return ( + b"IncompatibleRepositories", + str(err.source), + str(err.target), + str(err.details), + ) elif isinstance(err, errors.ShortReadvError): - return (b'ShortReadvError', err.path.encode('utf-8') if err.path is not None else None, - str(err.offset).encode('ascii') if err.offset is not None else None, - str(err.length).encode('ascii') if err.length is not None else None, - str(err.actual).encode('ascii') if err.actual is not None else None) + return ( + b"ShortReadvError", + err.path.encode("utf-8") if err.path is not None else None, + str(err.offset).encode("ascii") if err.offset is not None else None, + str(err.length).encode("ascii") if err.length is not None else None, + str(err.actual).encode("ascii") if err.actual is not None else None, + ) elif isinstance(err, errors.RevisionNotPresent): - return (b'RevisionNotPresent', err.revision_id, err.file_id) + return (b"RevisionNotPresent", err.revision_id, err.file_id) elif isinstance(err, errors.UnstackableRepositoryFormat): - return ((b'UnstackableRepositoryFormat', - str(err.format).encode('utf-8'), err.url.encode('utf-8'))) + return ( + b"UnstackableRepositoryFormat", + str(err.format).encode("utf-8"), + err.url.encode("utf-8"), + ) elif isinstance(err, _mod_branch.UnstackableBranchFormat): - return (b'UnstackableBranchFormat', str(err.format).encode('utf-8'), - err.url.encode('utf-8')) + return ( + b"UnstackableBranchFormat", + str(err.format).encode("utf-8"), + err.url.encode("utf-8"), + ) elif isinstance(err, errors.NotStacked): - return (b'NotStacked',) + return (b"NotStacked",) elif isinstance(err, errors.BzrCheckError): - return (b'BzrCheckError', err.msg.encode('utf-8')) + return (b"BzrCheckError", err.msg.encode("utf-8")) elif isinstance(err, UnicodeError): # If it is a DecodeError, than most likely we are starting # with a plain string @@ -432,39 +445,50 @@ def _translate_error(err): # XXX: UTF-8 might have \x01 (our protocol v1 and v2 seperator # byte) in it, so this encoding could cause broken responses. # Newer clients use protocol v3, so will be fine. 
- val = 'u:' + str_or_unicode.encode('utf-8') + val = "u:" + str_or_unicode.encode("utf-8") else: - val = 's:' + str_or_unicode.encode('base64') + val = "s:" + str_or_unicode.encode("base64") # This handles UnicodeEncodeError or UnicodeDecodeError - return (err.__class__.__name__, err.encoding, val, str(err.start), - str(err.end), err.reason) + return ( + err.__class__.__name__, + err.encoding, + val, + str(err.start), + str(err.end), + err.reason, + ) elif isinstance(err, errors.TransportNotPossible): if err.msg == "readonly transport": - return (b'ReadOnlyError', ) + return (b"ReadOnlyError",) elif isinstance(err, errors.ReadError): # cannot read the file - return (b'ReadError', err.path) + return (b"ReadError", err.path) elif isinstance(err, errors.PermissionDenied): - return (b'PermissionDenied', err.path.encode('utf-8'), err.extra.encode('utf-8')) + return ( + b"PermissionDenied", + err.path.encode("utf-8"), + err.extra.encode("utf-8"), + ) elif isinstance(err, errors.TokenMismatch): - return (b'TokenMismatch', err.given_token, err.lock_token) + return (b"TokenMismatch", err.given_token, err.lock_token) elif isinstance(err, errors.LockContention): - return (b'LockContention',) + return (b"LockContention",) elif isinstance(err, errors.GhostRevisionsHaveNoRevno): - return (b'GhostRevisionsHaveNoRevno', err.revision_id, err.ghost_revision_id) + return (b"GhostRevisionsHaveNoRevno", err.revision_id, err.ghost_revision_id) elif isinstance(err, urlutils.InvalidURL): - return (b'InvalidURL', err.path.encode('utf-8'), err.extra.encode('utf-8')) + return (b"InvalidURL", err.path.encode("utf-8"), err.extra.encode("utf-8")) elif isinstance(err, MemoryError): # GZ 2011-02-24: Copy breezy.trace -Dmem_dump functionality here? - return (b'MemoryError',) + return (b"MemoryError",) elif isinstance(err, errors.AlreadyControlDirError): - return (b'AlreadyControlDir', err.path) + return (b"AlreadyControlDir", err.path) # Unserialisable error. Log it, and return a generic error trace.log_exception_quietly() - return (b'error', - trace._qualified_exception_name( - err.__class__, True).encode('utf-8'), - str(err).encode('utf-8')) + return ( + b"error", + trace._qualified_exception_name(err.__class__, True).encode("utf-8"), + str(err).encode("utf-8"), + ) class HelloRequest(SmartServerRequest): @@ -473,7 +497,7 @@ class HelloRequest(SmartServerRequest): """ def do(self): - return SuccessfulSmartServerResponse((b'ok', b'2')) + return SuccessfulSmartServerResponse((b"ok", b"2")) class GetBundleRequest(SmartServerRequest): @@ -501,9 +525,9 @@ class SmartServerIsReadonly(SmartServerRequest): def do(self): if self._backing_transport.is_readonly(): - answer = b'yes' + answer = b"yes" else: - answer = b'no' + answer = b"no" return SuccessfulSmartServerResponse((answer,)) @@ -533,266 +557,515 @@ def do(self): # file. If append succeeds, it moves the file pointer. 
request_handlers = registry.Registry[bytes, SmartServerRequest, str]() request_handlers.register_lazy( - b'append', 'breezy.bzr.smart.vfs', 'AppendRequest', info='mutate') -request_handlers.register_lazy( - b'Branch.break_lock', 'breezy.bzr.smart.branch', - 'SmartServerBranchBreakLock', info='idem') -request_handlers.register_lazy( - b'Branch.get_config_file', 'breezy.bzr.smart.branch', - 'SmartServerBranchGetConfigFile', info='read') -request_handlers.register_lazy( - b'Branch.get_parent', 'breezy.bzr.smart.branch', 'SmartServerBranchGetParent', - info='read') -request_handlers.register_lazy( - b'Branch.put_config_file', 'breezy.bzr.smart.branch', - 'SmartServerBranchPutConfigFile', info='idem') -request_handlers.register_lazy( - b'Branch.get_tags_bytes', 'breezy.bzr.smart.branch', - 'SmartServerBranchGetTagsBytes', info='read') -request_handlers.register_lazy( - b'Branch.set_tags_bytes', 'breezy.bzr.smart.branch', - 'SmartServerBranchSetTagsBytes', info='idem') -request_handlers.register_lazy( - b'Branch.heads_to_fetch', 'breezy.bzr.smart.branch', - 'SmartServerBranchHeadsToFetch', info='read') -request_handlers.register_lazy( - b'Branch.get_stacked_on_url', 'breezy.bzr.smart.branch', - 'SmartServerBranchRequestGetStackedOnURL', info='read') -request_handlers.register_lazy( - b'Branch.get_physical_lock_status', 'breezy.bzr.smart.branch', - 'SmartServerBranchRequestGetPhysicalLockStatus', info='read') -request_handlers.register_lazy( - b'Branch.last_revision_info', 'breezy.bzr.smart.branch', - 'SmartServerBranchRequestLastRevisionInfo', info='read') -request_handlers.register_lazy( - b'Branch.lock_write', 'breezy.bzr.smart.branch', - 'SmartServerBranchRequestLockWrite', info='semi') -request_handlers.register_lazy( - b'Branch.revision_history', 'breezy.bzr.smart.branch', - 'SmartServerRequestRevisionHistory', info='read') -request_handlers.register_lazy( - b'Branch.set_config_option', 'breezy.bzr.smart.branch', - 'SmartServerBranchRequestSetConfigOption', info='idem') -request_handlers.register_lazy( - b'Branch.set_config_option_dict', 'breezy.bzr.smart.branch', - 'SmartServerBranchRequestSetConfigOptionDict', info='idem') -request_handlers.register_lazy( - b'Branch.set_last_revision', 'breezy.bzr.smart.branch', - 'SmartServerBranchRequestSetLastRevision', info='idem') -request_handlers.register_lazy( - b'Branch.set_last_revision_info', 'breezy.bzr.smart.branch', - 'SmartServerBranchRequestSetLastRevisionInfo', info='idem') -request_handlers.register_lazy( - b'Branch.set_last_revision_ex', 'breezy.bzr.smart.branch', - 'SmartServerBranchRequestSetLastRevisionEx', info='idem') -request_handlers.register_lazy( - b'Branch.set_parent_location', 'breezy.bzr.smart.branch', - 'SmartServerBranchRequestSetParentLocation', info='idem') -request_handlers.register_lazy( - b'Branch.unlock', 'breezy.bzr.smart.branch', - 'SmartServerBranchRequestUnlock', info='semi') -request_handlers.register_lazy( - b'Branch.revision_id_to_revno', 'breezy.bzr.smart.branch', - 'SmartServerBranchRequestRevisionIdToRevno', info='read') -request_handlers.register_lazy( - b'Branch.get_all_reference_info', 'breezy.bzr.smart.branch', - 'SmartServerBranchRequestGetAllReferenceInfo', info='read') -request_handlers.register_lazy( - b'BzrDir.checkout_metadir', 'breezy.bzr.smart.bzrdir', - 'SmartServerBzrDirRequestCheckoutMetaDir', info='read') -request_handlers.register_lazy( - b'BzrDir.cloning_metadir', 'breezy.bzr.smart.bzrdir', - 'SmartServerBzrDirRequestCloningMetaDir', info='read') -request_handlers.register_lazy( - 
b'BzrDir.create_branch', 'breezy.bzr.smart.bzrdir', - 'SmartServerRequestCreateBranch', info='semi') -request_handlers.register_lazy( - b'BzrDir.create_repository', 'breezy.bzr.smart.bzrdir', - 'SmartServerRequestCreateRepository', info='semi') -request_handlers.register_lazy( - b'BzrDir.find_repository', 'breezy.bzr.smart.bzrdir', - 'SmartServerRequestFindRepositoryV1', info='read') -request_handlers.register_lazy( - b'BzrDir.find_repositoryV2', 'breezy.bzr.smart.bzrdir', - 'SmartServerRequestFindRepositoryV2', info='read') -request_handlers.register_lazy( - b'BzrDir.find_repositoryV3', 'breezy.bzr.smart.bzrdir', - 'SmartServerRequestFindRepositoryV3', info='read') -request_handlers.register_lazy( - b'BzrDir.get_branches', 'breezy.bzr.smart.bzrdir', - 'SmartServerBzrDirRequestGetBranches', info='read') -request_handlers.register_lazy( - b'BzrDir.get_config_file', 'breezy.bzr.smart.bzrdir', - 'SmartServerBzrDirRequestConfigFile', info='read') -request_handlers.register_lazy( - b'BzrDir.destroy_branch', 'breezy.bzr.smart.bzrdir', - 'SmartServerBzrDirRequestDestroyBranch', info='semi') -request_handlers.register_lazy( - b'BzrDir.destroy_repository', 'breezy.bzr.smart.bzrdir', - 'SmartServerBzrDirRequestDestroyRepository', info='semi') -request_handlers.register_lazy( - b'BzrDir.has_workingtree', 'breezy.bzr.smart.bzrdir', - 'SmartServerBzrDirRequestHasWorkingTree', info='read') -request_handlers.register_lazy( - b'BzrDirFormat.initialize', 'breezy.bzr.smart.bzrdir', - 'SmartServerRequestInitializeBzrDir', info='semi') -request_handlers.register_lazy( - b'BzrDirFormat.initialize_ex_1.16', 'breezy.bzr.smart.bzrdir', - 'SmartServerRequestBzrDirInitializeEx', info='semi') -request_handlers.register_lazy( - b'BzrDir.open', 'breezy.bzr.smart.bzrdir', 'SmartServerRequestOpenBzrDir', - info='read') -request_handlers.register_lazy( - b'BzrDir.open_2.1', 'breezy.bzr.smart.bzrdir', - 'SmartServerRequestOpenBzrDir_2_1', info='read') -request_handlers.register_lazy( - b'BzrDir.open_branch', 'breezy.bzr.smart.bzrdir', - 'SmartServerRequestOpenBranch', info='read') -request_handlers.register_lazy( - b'BzrDir.open_branchV2', 'breezy.bzr.smart.bzrdir', - 'SmartServerRequestOpenBranchV2', info='read') -request_handlers.register_lazy( - b'BzrDir.open_branchV3', 'breezy.bzr.smart.bzrdir', - 'SmartServerRequestOpenBranchV3', info='read') -request_handlers.register_lazy( - b'delete', 'breezy.bzr.smart.vfs', 'DeleteRequest', info='semivfs') -request_handlers.register_lazy( - b'get', 'breezy.bzr.smart.vfs', 'GetRequest', info='read') -request_handlers.register_lazy( - b'get_bundle', 'breezy.bzr.smart.request', 'GetBundleRequest', info='read') -request_handlers.register_lazy( - b'has', 'breezy.bzr.smart.vfs', 'HasRequest', info='read') -request_handlers.register_lazy( - b'hello', 'breezy.bzr.smart.request', 'HelloRequest', info='read') -request_handlers.register_lazy( - b'iter_files_recursive', 'breezy.bzr.smart.vfs', 'IterFilesRecursiveRequest', - info='read') -request_handlers.register_lazy( - b'list_dir', 'breezy.bzr.smart.vfs', 'ListDirRequest', info='read') -request_handlers.register_lazy( - b'mkdir', 'breezy.bzr.smart.vfs', 'MkdirRequest', info='semivfs') -request_handlers.register_lazy( - b'move', 'breezy.bzr.smart.vfs', 'MoveRequest', info='semivfs') -request_handlers.register_lazy( - b'put', 'breezy.bzr.smart.vfs', 'PutRequest', info='idem') -request_handlers.register_lazy( - b'put_non_atomic', 'breezy.bzr.smart.vfs', 'PutNonAtomicRequest', info='idem') -request_handlers.register_lazy( - b'readv', 
'breezy.bzr.smart.vfs', 'ReadvRequest', info='read') -request_handlers.register_lazy( - b'rename', 'breezy.bzr.smart.vfs', 'RenameRequest', info='semivfs') -request_handlers.register_lazy( - b'Repository.add_signature_text', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryAddSignatureText', info='idem') -request_handlers.register_lazy( - b'Repository.annotate_file_revision', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryAnnotateFileRevision', info='read') -request_handlers.register_lazy( - b'Repository.all_revision_ids', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryAllRevisionIds', info='read') -request_handlers.register_lazy( - b'PackRepository.autopack', 'breezy.bzr.smart.packrepository', - 'SmartServerPackRepositoryAutopack', info='idem') -request_handlers.register_lazy( - b'Repository.break_lock', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryBreakLock', info='idem') -request_handlers.register_lazy( - b'Repository.gather_stats', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryGatherStats', info='read') -request_handlers.register_lazy( - b'Repository.get_parent_map', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryGetParentMap', info='read') -request_handlers.register_lazy( - b'Repository.get_revision_graph', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryGetRevisionGraph', info='read') -request_handlers.register_lazy( - b'Repository.get_revision_signature_text', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryGetRevisionSignatureText', info='read') -request_handlers.register_lazy( - b'Repository.has_revision', 'breezy.bzr.smart.repository', - 'SmartServerRequestHasRevision', info='read') -request_handlers.register_lazy( - b'Repository.has_signature_for_revision_id', 'breezy.bzr.smart.repository', - 'SmartServerRequestHasSignatureForRevisionId', info='read') -request_handlers.register_lazy( - b'Repository.insert_stream', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryInsertStream', info='stream') -request_handlers.register_lazy( - b'Repository.insert_stream_1.19', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryInsertStream_1_19', info='stream') -request_handlers.register_lazy( - b'Repository.insert_stream_locked', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryInsertStreamLocked', info='stream') -request_handlers.register_lazy( - b'Repository.is_shared', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryIsShared', info='read') -request_handlers.register_lazy( - b'Repository.iter_files_bytes', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryIterFilesBytes', info='read') -request_handlers.register_lazy( - b'Repository.lock_write', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryLockWrite', info='semi') -request_handlers.register_lazy( - b'Repository.make_working_trees', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryMakeWorkingTrees', info='read') -request_handlers.register_lazy( - b'Repository.set_make_working_trees', 'breezy.bzr.smart.repository', - 'SmartServerRepositorySetMakeWorkingTrees', info='idem') -request_handlers.register_lazy( - b'Repository.unlock', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryUnlock', info='semi') -request_handlers.register_lazy( - b'Repository.get_physical_lock_status', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryGetPhysicalLockStatus', info='read') -request_handlers.register_lazy( - b'Repository.get_rev_id_for_revno', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryGetRevIdForRevno', info='read') 
-request_handlers.register_lazy( - b'Repository.get_stream', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryGetStream', info='read') -request_handlers.register_lazy( - b'Repository.get_stream_1.19', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryGetStream_1_19', info='read') -request_handlers.register_lazy( - b'Repository.get_stream_for_missing_keys', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryGetStreamForMissingKeys', info='read') -request_handlers.register_lazy( - b'Repository.iter_revisions', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryIterRevisions', info='read') -request_handlers.register_lazy( - b'Repository.pack', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryPack', info='idem') -request_handlers.register_lazy( - b'Repository.start_write_group', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryStartWriteGroup', info='semi') -request_handlers.register_lazy( - b'Repository.commit_write_group', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryCommitWriteGroup', info='semi') -request_handlers.register_lazy( - b'Repository.abort_write_group', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryAbortWriteGroup', info='semi') -request_handlers.register_lazy( - b'Repository.check_write_group', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryCheckWriteGroup', info='read') -request_handlers.register_lazy( - b'Repository.reconcile', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryReconcile', info='idem') -request_handlers.register_lazy( - b'Repository.revision_archive', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryRevisionArchive', info='read') -request_handlers.register_lazy( - b'Repository.tarball', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryTarball', info='read') -request_handlers.register_lazy( - b'VersionedFileRepository.get_serializer_format', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryGetSerializerFormat', info='read') -request_handlers.register_lazy( - b'VersionedFileRepository.get_inventories', 'breezy.bzr.smart.repository', - 'SmartServerRepositoryGetInventories', info='read') -request_handlers.register_lazy( - b'rmdir', 'breezy.bzr.smart.vfs', 'RmdirRequest', info='semivfs') -request_handlers.register_lazy( - b'stat', 'breezy.bzr.smart.vfs', 'StatRequest', info='read') -request_handlers.register_lazy( - b'Transport.is_readonly', 'breezy.bzr.smart.request', - 'SmartServerIsReadonly', info='read') + b"append", "breezy.bzr.smart.vfs", "AppendRequest", info="mutate" +) +request_handlers.register_lazy( + b"Branch.break_lock", + "breezy.bzr.smart.branch", + "SmartServerBranchBreakLock", + info="idem", +) +request_handlers.register_lazy( + b"Branch.get_config_file", + "breezy.bzr.smart.branch", + "SmartServerBranchGetConfigFile", + info="read", +) +request_handlers.register_lazy( + b"Branch.get_parent", + "breezy.bzr.smart.branch", + "SmartServerBranchGetParent", + info="read", +) +request_handlers.register_lazy( + b"Branch.put_config_file", + "breezy.bzr.smart.branch", + "SmartServerBranchPutConfigFile", + info="idem", +) +request_handlers.register_lazy( + b"Branch.get_tags_bytes", + "breezy.bzr.smart.branch", + "SmartServerBranchGetTagsBytes", + info="read", +) +request_handlers.register_lazy( + b"Branch.set_tags_bytes", + "breezy.bzr.smart.branch", + "SmartServerBranchSetTagsBytes", + info="idem", +) +request_handlers.register_lazy( + b"Branch.heads_to_fetch", + "breezy.bzr.smart.branch", + "SmartServerBranchHeadsToFetch", + info="read", +) +request_handlers.register_lazy( + 
b"Branch.get_stacked_on_url", + "breezy.bzr.smart.branch", + "SmartServerBranchRequestGetStackedOnURL", + info="read", +) +request_handlers.register_lazy( + b"Branch.get_physical_lock_status", + "breezy.bzr.smart.branch", + "SmartServerBranchRequestGetPhysicalLockStatus", + info="read", +) +request_handlers.register_lazy( + b"Branch.last_revision_info", + "breezy.bzr.smart.branch", + "SmartServerBranchRequestLastRevisionInfo", + info="read", +) +request_handlers.register_lazy( + b"Branch.lock_write", + "breezy.bzr.smart.branch", + "SmartServerBranchRequestLockWrite", + info="semi", +) +request_handlers.register_lazy( + b"Branch.revision_history", + "breezy.bzr.smart.branch", + "SmartServerRequestRevisionHistory", + info="read", +) +request_handlers.register_lazy( + b"Branch.set_config_option", + "breezy.bzr.smart.branch", + "SmartServerBranchRequestSetConfigOption", + info="idem", +) +request_handlers.register_lazy( + b"Branch.set_config_option_dict", + "breezy.bzr.smart.branch", + "SmartServerBranchRequestSetConfigOptionDict", + info="idem", +) +request_handlers.register_lazy( + b"Branch.set_last_revision", + "breezy.bzr.smart.branch", + "SmartServerBranchRequestSetLastRevision", + info="idem", +) +request_handlers.register_lazy( + b"Branch.set_last_revision_info", + "breezy.bzr.smart.branch", + "SmartServerBranchRequestSetLastRevisionInfo", + info="idem", +) +request_handlers.register_lazy( + b"Branch.set_last_revision_ex", + "breezy.bzr.smart.branch", + "SmartServerBranchRequestSetLastRevisionEx", + info="idem", +) +request_handlers.register_lazy( + b"Branch.set_parent_location", + "breezy.bzr.smart.branch", + "SmartServerBranchRequestSetParentLocation", + info="idem", +) +request_handlers.register_lazy( + b"Branch.unlock", + "breezy.bzr.smart.branch", + "SmartServerBranchRequestUnlock", + info="semi", +) +request_handlers.register_lazy( + b"Branch.revision_id_to_revno", + "breezy.bzr.smart.branch", + "SmartServerBranchRequestRevisionIdToRevno", + info="read", +) +request_handlers.register_lazy( + b"Branch.get_all_reference_info", + "breezy.bzr.smart.branch", + "SmartServerBranchRequestGetAllReferenceInfo", + info="read", +) +request_handlers.register_lazy( + b"BzrDir.checkout_metadir", + "breezy.bzr.smart.bzrdir", + "SmartServerBzrDirRequestCheckoutMetaDir", + info="read", +) +request_handlers.register_lazy( + b"BzrDir.cloning_metadir", + "breezy.bzr.smart.bzrdir", + "SmartServerBzrDirRequestCloningMetaDir", + info="read", +) +request_handlers.register_lazy( + b"BzrDir.create_branch", + "breezy.bzr.smart.bzrdir", + "SmartServerRequestCreateBranch", + info="semi", +) +request_handlers.register_lazy( + b"BzrDir.create_repository", + "breezy.bzr.smart.bzrdir", + "SmartServerRequestCreateRepository", + info="semi", +) +request_handlers.register_lazy( + b"BzrDir.find_repository", + "breezy.bzr.smart.bzrdir", + "SmartServerRequestFindRepositoryV1", + info="read", +) +request_handlers.register_lazy( + b"BzrDir.find_repositoryV2", + "breezy.bzr.smart.bzrdir", + "SmartServerRequestFindRepositoryV2", + info="read", +) +request_handlers.register_lazy( + b"BzrDir.find_repositoryV3", + "breezy.bzr.smart.bzrdir", + "SmartServerRequestFindRepositoryV3", + info="read", +) +request_handlers.register_lazy( + b"BzrDir.get_branches", + "breezy.bzr.smart.bzrdir", + "SmartServerBzrDirRequestGetBranches", + info="read", +) +request_handlers.register_lazy( + b"BzrDir.get_config_file", + "breezy.bzr.smart.bzrdir", + "SmartServerBzrDirRequestConfigFile", + info="read", +) +request_handlers.register_lazy( + 
b"BzrDir.destroy_branch", + "breezy.bzr.smart.bzrdir", + "SmartServerBzrDirRequestDestroyBranch", + info="semi", +) +request_handlers.register_lazy( + b"BzrDir.destroy_repository", + "breezy.bzr.smart.bzrdir", + "SmartServerBzrDirRequestDestroyRepository", + info="semi", +) +request_handlers.register_lazy( + b"BzrDir.has_workingtree", + "breezy.bzr.smart.bzrdir", + "SmartServerBzrDirRequestHasWorkingTree", + info="read", +) +request_handlers.register_lazy( + b"BzrDirFormat.initialize", + "breezy.bzr.smart.bzrdir", + "SmartServerRequestInitializeBzrDir", + info="semi", +) +request_handlers.register_lazy( + b"BzrDirFormat.initialize_ex_1.16", + "breezy.bzr.smart.bzrdir", + "SmartServerRequestBzrDirInitializeEx", + info="semi", +) +request_handlers.register_lazy( + b"BzrDir.open", + "breezy.bzr.smart.bzrdir", + "SmartServerRequestOpenBzrDir", + info="read", +) +request_handlers.register_lazy( + b"BzrDir.open_2.1", + "breezy.bzr.smart.bzrdir", + "SmartServerRequestOpenBzrDir_2_1", + info="read", +) +request_handlers.register_lazy( + b"BzrDir.open_branch", + "breezy.bzr.smart.bzrdir", + "SmartServerRequestOpenBranch", + info="read", +) +request_handlers.register_lazy( + b"BzrDir.open_branchV2", + "breezy.bzr.smart.bzrdir", + "SmartServerRequestOpenBranchV2", + info="read", +) +request_handlers.register_lazy( + b"BzrDir.open_branchV3", + "breezy.bzr.smart.bzrdir", + "SmartServerRequestOpenBranchV3", + info="read", +) +request_handlers.register_lazy( + b"delete", "breezy.bzr.smart.vfs", "DeleteRequest", info="semivfs" +) +request_handlers.register_lazy( + b"get", "breezy.bzr.smart.vfs", "GetRequest", info="read" +) +request_handlers.register_lazy( + b"get_bundle", "breezy.bzr.smart.request", "GetBundleRequest", info="read" +) +request_handlers.register_lazy( + b"has", "breezy.bzr.smart.vfs", "HasRequest", info="read" +) +request_handlers.register_lazy( + b"hello", "breezy.bzr.smart.request", "HelloRequest", info="read" +) +request_handlers.register_lazy( + b"iter_files_recursive", + "breezy.bzr.smart.vfs", + "IterFilesRecursiveRequest", + info="read", +) +request_handlers.register_lazy( + b"list_dir", "breezy.bzr.smart.vfs", "ListDirRequest", info="read" +) +request_handlers.register_lazy( + b"mkdir", "breezy.bzr.smart.vfs", "MkdirRequest", info="semivfs" +) +request_handlers.register_lazy( + b"move", "breezy.bzr.smart.vfs", "MoveRequest", info="semivfs" +) +request_handlers.register_lazy( + b"put", "breezy.bzr.smart.vfs", "PutRequest", info="idem" +) +request_handlers.register_lazy( + b"put_non_atomic", "breezy.bzr.smart.vfs", "PutNonAtomicRequest", info="idem" +) +request_handlers.register_lazy( + b"readv", "breezy.bzr.smart.vfs", "ReadvRequest", info="read" +) +request_handlers.register_lazy( + b"rename", "breezy.bzr.smart.vfs", "RenameRequest", info="semivfs" +) +request_handlers.register_lazy( + b"Repository.add_signature_text", + "breezy.bzr.smart.repository", + "SmartServerRepositoryAddSignatureText", + info="idem", +) +request_handlers.register_lazy( + b"Repository.annotate_file_revision", + "breezy.bzr.smart.repository", + "SmartServerRepositoryAnnotateFileRevision", + info="read", +) +request_handlers.register_lazy( + b"Repository.all_revision_ids", + "breezy.bzr.smart.repository", + "SmartServerRepositoryAllRevisionIds", + info="read", +) +request_handlers.register_lazy( + b"PackRepository.autopack", + "breezy.bzr.smart.packrepository", + "SmartServerPackRepositoryAutopack", + info="idem", +) +request_handlers.register_lazy( + b"Repository.break_lock", + "breezy.bzr.smart.repository", 
+ "SmartServerRepositoryBreakLock", + info="idem", +) +request_handlers.register_lazy( + b"Repository.gather_stats", + "breezy.bzr.smart.repository", + "SmartServerRepositoryGatherStats", + info="read", +) +request_handlers.register_lazy( + b"Repository.get_parent_map", + "breezy.bzr.smart.repository", + "SmartServerRepositoryGetParentMap", + info="read", +) +request_handlers.register_lazy( + b"Repository.get_revision_graph", + "breezy.bzr.smart.repository", + "SmartServerRepositoryGetRevisionGraph", + info="read", +) +request_handlers.register_lazy( + b"Repository.get_revision_signature_text", + "breezy.bzr.smart.repository", + "SmartServerRepositoryGetRevisionSignatureText", + info="read", +) +request_handlers.register_lazy( + b"Repository.has_revision", + "breezy.bzr.smart.repository", + "SmartServerRequestHasRevision", + info="read", +) +request_handlers.register_lazy( + b"Repository.has_signature_for_revision_id", + "breezy.bzr.smart.repository", + "SmartServerRequestHasSignatureForRevisionId", + info="read", +) +request_handlers.register_lazy( + b"Repository.insert_stream", + "breezy.bzr.smart.repository", + "SmartServerRepositoryInsertStream", + info="stream", +) +request_handlers.register_lazy( + b"Repository.insert_stream_1.19", + "breezy.bzr.smart.repository", + "SmartServerRepositoryInsertStream_1_19", + info="stream", +) +request_handlers.register_lazy( + b"Repository.insert_stream_locked", + "breezy.bzr.smart.repository", + "SmartServerRepositoryInsertStreamLocked", + info="stream", +) +request_handlers.register_lazy( + b"Repository.is_shared", + "breezy.bzr.smart.repository", + "SmartServerRepositoryIsShared", + info="read", +) +request_handlers.register_lazy( + b"Repository.iter_files_bytes", + "breezy.bzr.smart.repository", + "SmartServerRepositoryIterFilesBytes", + info="read", +) +request_handlers.register_lazy( + b"Repository.lock_write", + "breezy.bzr.smart.repository", + "SmartServerRepositoryLockWrite", + info="semi", +) +request_handlers.register_lazy( + b"Repository.make_working_trees", + "breezy.bzr.smart.repository", + "SmartServerRepositoryMakeWorkingTrees", + info="read", +) +request_handlers.register_lazy( + b"Repository.set_make_working_trees", + "breezy.bzr.smart.repository", + "SmartServerRepositorySetMakeWorkingTrees", + info="idem", +) +request_handlers.register_lazy( + b"Repository.unlock", + "breezy.bzr.smart.repository", + "SmartServerRepositoryUnlock", + info="semi", +) +request_handlers.register_lazy( + b"Repository.get_physical_lock_status", + "breezy.bzr.smart.repository", + "SmartServerRepositoryGetPhysicalLockStatus", + info="read", +) +request_handlers.register_lazy( + b"Repository.get_rev_id_for_revno", + "breezy.bzr.smart.repository", + "SmartServerRepositoryGetRevIdForRevno", + info="read", +) +request_handlers.register_lazy( + b"Repository.get_stream", + "breezy.bzr.smart.repository", + "SmartServerRepositoryGetStream", + info="read", +) +request_handlers.register_lazy( + b"Repository.get_stream_1.19", + "breezy.bzr.smart.repository", + "SmartServerRepositoryGetStream_1_19", + info="read", +) +request_handlers.register_lazy( + b"Repository.get_stream_for_missing_keys", + "breezy.bzr.smart.repository", + "SmartServerRepositoryGetStreamForMissingKeys", + info="read", +) +request_handlers.register_lazy( + b"Repository.iter_revisions", + "breezy.bzr.smart.repository", + "SmartServerRepositoryIterRevisions", + info="read", +) +request_handlers.register_lazy( + b"Repository.pack", + "breezy.bzr.smart.repository", + "SmartServerRepositoryPack", + 
info="idem", +) +request_handlers.register_lazy( + b"Repository.start_write_group", + "breezy.bzr.smart.repository", + "SmartServerRepositoryStartWriteGroup", + info="semi", +) +request_handlers.register_lazy( + b"Repository.commit_write_group", + "breezy.bzr.smart.repository", + "SmartServerRepositoryCommitWriteGroup", + info="semi", +) +request_handlers.register_lazy( + b"Repository.abort_write_group", + "breezy.bzr.smart.repository", + "SmartServerRepositoryAbortWriteGroup", + info="semi", +) +request_handlers.register_lazy( + b"Repository.check_write_group", + "breezy.bzr.smart.repository", + "SmartServerRepositoryCheckWriteGroup", + info="read", +) +request_handlers.register_lazy( + b"Repository.reconcile", + "breezy.bzr.smart.repository", + "SmartServerRepositoryReconcile", + info="idem", +) +request_handlers.register_lazy( + b"Repository.revision_archive", + "breezy.bzr.smart.repository", + "SmartServerRepositoryRevisionArchive", + info="read", +) +request_handlers.register_lazy( + b"Repository.tarball", + "breezy.bzr.smart.repository", + "SmartServerRepositoryTarball", + info="read", +) +request_handlers.register_lazy( + b"VersionedFileRepository.get_serializer_format", + "breezy.bzr.smart.repository", + "SmartServerRepositoryGetSerializerFormat", + info="read", +) +request_handlers.register_lazy( + b"VersionedFileRepository.get_inventories", + "breezy.bzr.smart.repository", + "SmartServerRepositoryGetInventories", + info="read", +) +request_handlers.register_lazy( + b"rmdir", "breezy.bzr.smart.vfs", "RmdirRequest", info="semivfs" +) +request_handlers.register_lazy( + b"stat", "breezy.bzr.smart.vfs", "StatRequest", info="read" +) +request_handlers.register_lazy( + b"Transport.is_readonly", + "breezy.bzr.smart.request", + "SmartServerIsReadonly", + info="read", +) diff --git a/breezy/bzr/smart/server.py b/breezy/bzr/smart/server.py index 83cf58b1ca..e4b62e31a4 100644 --- a/breezy/bzr/smart/server.py +++ b/breezy/bzr/smart/server.py @@ -29,7 +29,9 @@ from ...i18n import gettext from ...lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ from breezy.bzr.smart import ( medium, signals, @@ -42,7 +44,8 @@ config, urlutils, ) -""") +""", +) class SmartTCPServer: @@ -63,8 +66,7 @@ class SmartTCPServer: _timer = time.time - def __init__(self, backing_transport, root_client_path='/', - client_timeout=None): + def __init__(self, backing_transport, root_client_path="/", client_timeout=None): """Construct a new server. To actually start it running, call either start_background_thread or @@ -95,18 +97,19 @@ def start_server(self, host, port): # module's globals get set to None during interpreter shutdown. 
from socket import error as socket_error from socket import timeout as socket_timeout + self._socket_error = socket_error self._socket_timeout = socket_timeout - addrs = socket.getaddrinfo(host, port, socket.AF_UNSPEC, - socket.SOCK_STREAM, 0, socket.AI_PASSIVE)[0] + addrs = socket.getaddrinfo( + host, port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE + )[0] (family, socktype, proto, canonname, sockaddr) = addrs self._server_socket = socket.socket(family, socktype, proto) # SO_REUSERADDR has a different meaning on Windows - if sys.platform != 'win32': - self._server_socket.setsockopt(socket.SOL_SOCKET, - socket.SO_REUSEADDR, 1) + if sys.platform != "win32": + self._server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) try: self._server_socket.bind(sockaddr) except self._socket_error as message: @@ -150,19 +153,19 @@ def _backing_urls(self): def run_server_started_hooks(self, backing_urls=None): if backing_urls is None: backing_urls = self._backing_urls() - for hook in SmartTCPServer.hooks['server_started']: + for hook in SmartTCPServer.hooks["server_started"]: hook(backing_urls, self.get_url()) - for hook in SmartTCPServer.hooks['server_started_ex']: + for hook in SmartTCPServer.hooks["server_started_ex"]: hook(backing_urls, self) def run_server_stopped_hooks(self, backing_urls=None): if backing_urls is None: backing_urls = self._backing_urls() - for hook in SmartTCPServer.hooks['server_stopped']: + for hook in SmartTCPServer.hooks["server_stopped"]: hook(backing_urls, self.get_url()) def _stop_gracefully(self): - trace.note(gettext('Requested to stop gracefully')) + trace.note(gettext("Requested to stop gracefully")) self._should_terminate = True self._gracefully_stopping = True for handler, _ in self._active_connections: @@ -172,18 +175,22 @@ def _wait_for_clients_to_disconnect(self): self._poll_active_connections() if not self._active_connections: return - trace.note(gettext('Waiting for %d client(s) to finish') - % (len(self._active_connections),)) + trace.note( + gettext("Waiting for %d client(s) to finish") + % (len(self._active_connections),) + ) t_next_log = self._timer() + self._LOG_WAITING_TIMEOUT while self._active_connections: now = self._timer() if now >= t_next_log: - trace.note(gettext('Still waiting for %d client(s) to finish') - % (len(self._active_connections),)) + trace.note( + gettext("Still waiting for %d client(s) to finish") + % (len(self._active_connections),) + ) t_next_log = now + self._LOG_WAITING_TIMEOUT self._poll_active_connections(self._SHUTDOWN_POLL_TIMEOUT) - def serve(self, thread_name_suffix=''): + def serve(self, thread_name_suffix=""): # Note: There is a temptation to do # signals.register_on_hangup(id(self), self._stop_gracefully) # However, that creates a temporary object which is a bound @@ -212,8 +219,7 @@ def serve(self, thread_name_suffix=''): # can get EINTR, any other socket errors should get # logged. if e.args[0] not in (errno.EBADF, errno.EINTR): - trace.warning(gettext("listening socket error: %s") - % (e,)) + trace.warning(gettext("listening socket error: %s") % (e,)) else: if self._should_terminate: conn.close() @@ -247,8 +253,11 @@ def get_url(self): def _make_handler(self, conn): return medium.SmartServerSocketStreamMedium( - conn, self.backing_transport, self.root_client_path, - timeout=self._client_timeout) + conn, + self.backing_transport, + self.root_client_path, + timeout=self._client_timeout, + ) def _poll_active_connections(self, timeout=0.0): """Check to see if any active connections have finished. 
@@ -272,20 +281,24 @@ def serve_conn(self, conn, thread_name_suffix): # propagates to the newly accepted socket. conn.setblocking(True) conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - thread_name = 'smart-server-child' + thread_name_suffix + thread_name = "smart-server-child" + thread_name_suffix handler = self._make_handler(conn) connection_thread = threading.Thread( - None, handler.serve, name=thread_name, daemon=True) + None, handler.serve, name=thread_name, daemon=True + ) self._active_connections.append((handler, connection_thread)) connection_thread.start() return connection_thread - def start_background_thread(self, thread_name_suffix=''): + def start_background_thread(self, thread_name_suffix=""): self._started.clear() self._server_thread = threading.Thread( - None, self.serve, args=(thread_name_suffix,), - name='server-' + self.get_url(), - daemon=True) + None, + self.serve, + args=(thread_name_suffix,), + name="server-" + self.get_url(), + daemon=True, + ) self._server_thread.start() self._started.wait() @@ -325,25 +338,36 @@ def __init__(self): notified. """ Hooks.__init__(self, "breezy.bzr.smart.server", "SmartTCPServer.hooks") - self.add_hook('server_started', - "Called by the bzr server when it starts serving a directory. " - "server_started is called with (backing urls, public url), " - "where backing_url is a list of URLs giving the " - "server-specific directory locations, and public_url is the " - "public URL for the directory being served.", (0, 16)) - self.add_hook('server_started_ex', - "Called by the bzr server when it starts serving a directory. " - "server_started is called with (backing_urls, server_obj).", - (1, 17)) - self.add_hook('server_stopped', - "Called by the bzr server when it stops serving a directory. " - "server_stopped is called with the same parameters as the " - "server_started hook: (backing_urls, public_url).", (0, 16)) - self.add_hook('server_exception', - "Called by the bzr server when an exception occurs. " - "server_exception is called with the sys.exc_info() tuple " - "return true for the hook if the exception has been handled, " - "in which case the server will exit normally.", (2, 4)) + self.add_hook( + "server_started", + "Called by the bzr server when it starts serving a directory. " + "server_started is called with (backing urls, public url), " + "where backing_url is a list of URLs giving the " + "server-specific directory locations, and public_url is the " + "public URL for the directory being served.", + (0, 16), + ) + self.add_hook( + "server_started_ex", + "Called by the bzr server when it starts serving a directory. " + "server_started is called with (backing_urls, server_obj).", + (1, 17), + ) + self.add_hook( + "server_stopped", + "Called by the bzr server when it stops serving a directory. " + "server_stopped is called with the same parameters as the " + "server_started hook: (backing_urls, public_url).", + (0, 16), + ) + self.add_hook( + "server_exception", + "Called by the bzr server when an exception occurs. 
" + "server_exception is called with the sys.exc_info() tuple " + "return true for the hook if the exception has been handled, " + "in which case the server will exit normally.", + (2, 4), + ) SmartTCPServer.hooks = SmartServerHooks() # type: ignore @@ -364,8 +388,8 @@ def _local_path_for_transport(transport): return None else: # Strip readonly prefix - if base_url.startswith('readonly+'): - base_url = base_url[len('readonly+'):] + if base_url.startswith("readonly+"): + base_url = base_url[len("readonly+") :] try: return urlutils.local_path_from_url(base_url) except urlutils.InvalidURL: @@ -399,12 +423,12 @@ def _expand_userdirs(self, path): the translated path is joe. """ result = path - if path.startswith('~'): + if path.startswith("~"): expanded = self.userdir_expander(path) - if not expanded.endswith('/'): - expanded += '/' + if not expanded.endswith("/"): + expanded += "/" if expanded.startswith(self.base_path): - result = expanded[len(self.base_path):] + result = expanded[len(self.base_path) :] return result def _make_expand_userdirs_filter(self, transport): @@ -416,16 +440,14 @@ def _make_backing_transport(self, transport): chroot_server = chroot.ChrootServer(transport) chroot_server.start_server() self.cleanups.append(chroot_server.stop_server) - transport = _mod_transport.get_transport_from_url( - chroot_server.get_url()) + transport = _mod_transport.get_transport_from_url(chroot_server.get_url()) if self.base_path is not None: # Decorate the server's backing transport with a filter that can # expand homedirs. expand_userdirs = self._make_expand_userdirs_filter(transport) expand_userdirs.start_server() self.cleanups.append(expand_userdirs.stop_server) - transport = _mod_transport.get_transport_from_url( - expand_userdirs.get_url()) + transport = _mod_transport.get_transport_from_url(expand_userdirs.get_url()) self.transport = transport def _get_stdin_stdout(self): @@ -434,21 +456,20 @@ def _get_stdin_stdout(self): def _make_smart_server(self, host, port, inet, timeout): if timeout is None: c = config.GlobalStack() - timeout = c.get('serve.client_timeout') + timeout = c.get("serve.client_timeout") if inet: stdin, stdout = self._get_stdin_stdout() smart_server = medium.SmartServerPipeStreamMedium( - stdin, stdout, self.transport, timeout=timeout) + stdin, stdout, self.transport, timeout=timeout + ) else: if host is None: host = medium.BZR_DEFAULT_INTERFACE if port is None: port = medium.BZR_DEFAULT_PORT - smart_server = SmartTCPServer(self.transport, - client_timeout=timeout) + smart_server = SmartTCPServer(self.transport, client_timeout=timeout) smart_server.start_server(host, port) - trace.note(gettext('listening on port: %s'), - str(smart_server.port)) + trace.note(gettext("listening on port: %s"), str(smart_server.port)) self.smart_server = smart_server def _change_globals(self): @@ -464,6 +485,7 @@ def _change_globals(self): def restore_default_ui_factory_and_lockdir_timeout(): ui.ui_factory = old_factory lockdir._DEFAULT_TIMEOUT_SECONDS = old_lockdir_timeout + self.cleanups.append(restore_default_ui_factory_and_lockdir_timeout) ui.ui_factory = ui.SilentUIFactory() lockdir._DEFAULT_TIMEOUT_SECONDS = 0 @@ -471,6 +493,7 @@ def restore_default_ui_factory_and_lockdir_timeout(): def restore_signals(): signals.restore_sighup_handler(orig) + self.cleanups.append(restore_signals) def set_up(self, transport, host, port, inet, timeout): @@ -496,7 +519,7 @@ def serve_bzr(transport, host=None, port=None, inet=False, timeout=None): bzr_server.smart_server.serve() except BaseException: 
hook_caught_exception = False - for hook in SmartTCPServer.hooks['server_exception']: + for hook in SmartTCPServer.hooks["server_exception"]: hook_caught_exception = hook(sys.exc_info()) if not hook_caught_exception: raise diff --git a/breezy/bzr/smart/signals.py b/breezy/bzr/smart/signals.py index 80fb9daede..d446bc1c71 100644 --- a/breezy/bzr/smart/signals.py +++ b/breezy/bzr/smart/signals.py @@ -39,7 +39,7 @@ def _sighup_handler(signal_number, interrupted_frame): """ if _on_sighup is None: return - trace.mutter('Caught SIGHUP, sending graceful shutdown requests.') + trace.mutter("Caught SIGHUP, sending graceful shutdown requests.") for ref in _on_sighup.valuerefs(): try: cb = ref() @@ -48,7 +48,7 @@ def _sighup_handler(signal_number, interrupted_frame): except KeyboardInterrupt: raise except Exception: - trace.mutter('Error occurred while running SIGHUP handlers:') + trace.mutter("Error occurred while running SIGHUP handlers:") trace.log_exception_quietly() @@ -110,5 +110,5 @@ def unregister_on_hangup(identifier): except Exception: # This usually runs as a tear-down step. So we don't want to propagate # most exceptions. - trace.mutter('Error occurred during unregister_on_hangup:') + trace.mutter("Error occurred during unregister_on_hangup:") trace.log_exception_quietly() diff --git a/breezy/bzr/smart/vfs.py b/breezy/bzr/smart/vfs.py index 848bb297c2..a8f70281d1 100644 --- a/breezy/bzr/smart/vfs.py +++ b/breezy/bzr/smart/vfs.py @@ -33,7 +33,7 @@ def _deserialise_optional_mode(mode): # XXX: FIXME this should be on the protocol object. Later protocol versions # might serialise modes differently. - if mode == b'': + if mode == b"": return None else: return int(mode) @@ -46,7 +46,7 @@ def vfs_enabled(): :return: ``True`` if it is enabled. """ - return 'BRZ_NO_SMART_VFS' not in os.environ + return "BRZ_NO_SMART_VFS" not in os.environ class VfsRequest(request.SmartServerRequest): @@ -69,23 +69,20 @@ def translate_client_path(self, relpath): class HasRequest(VfsRequest): - def do(self, relpath): relpath = self.translate_client_path(relpath) - r = self._backing_transport.has(relpath) and b'yes' or b'no' + r = self._backing_transport.has(relpath) and b"yes" or b"no" return request.SuccessfulSmartServerResponse((r,)) class GetRequest(VfsRequest): - def do(self, relpath): relpath = self.translate_client_path(relpath) backing_bytes = self._backing_transport.get_bytes(relpath) - return request.SuccessfulSmartServerResponse((b'ok',), backing_bytes) + return request.SuccessfulSmartServerResponse((b"ok",), backing_bytes) class AppendRequest(VfsRequest): - def do(self, relpath, mode): relpath = self.translate_client_path(relpath) self._relpath = relpath @@ -93,91 +90,88 @@ def do(self, relpath, mode): def do_body(self, body_bytes): old_length = self._backing_transport.append_bytes( - self._relpath, body_bytes, self._mode) - return request.SuccessfulSmartServerResponse((b'appended', str(old_length).encode('ascii'))) + self._relpath, body_bytes, self._mode + ) + return request.SuccessfulSmartServerResponse( + (b"appended", str(old_length).encode("ascii")) + ) class DeleteRequest(VfsRequest): - def do(self, relpath): relpath = self.translate_client_path(relpath) self._backing_transport.delete(relpath) - return request.SuccessfulSmartServerResponse((b'ok', )) + return request.SuccessfulSmartServerResponse((b"ok",)) class IterFilesRecursiveRequest(VfsRequest): - def do(self, relpath): - if not relpath.endswith(b'/'): - relpath += b'/' + if not relpath.endswith(b"/"): + relpath += b"/" relpath = 
self.translate_client_path(relpath) transport = self._backing_transport.clone(relpath) filenames = transport.iter_files_recursive() - return request.SuccessfulSmartServerResponse((b'names',) + tuple(filenames)) + return request.SuccessfulSmartServerResponse((b"names",) + tuple(filenames)) class ListDirRequest(VfsRequest): - def do(self, relpath): - if not relpath.endswith(b'/'): - relpath += b'/' + if not relpath.endswith(b"/"): + relpath += b"/" relpath = self.translate_client_path(relpath) filenames = self._backing_transport.list_dir(relpath) - return request.SuccessfulSmartServerResponse((b'names',) + tuple([filename.encode('utf-8') for filename in filenames])) + return request.SuccessfulSmartServerResponse( + (b"names",) + tuple([filename.encode("utf-8") for filename in filenames]) + ) class MkdirRequest(VfsRequest): - def do(self, relpath, mode): relpath = self.translate_client_path(relpath) - self._backing_transport.mkdir(relpath, - _deserialise_optional_mode(mode)) - return request.SuccessfulSmartServerResponse((b'ok',)) + self._backing_transport.mkdir(relpath, _deserialise_optional_mode(mode)) + return request.SuccessfulSmartServerResponse((b"ok",)) class MoveRequest(VfsRequest): - def do(self, rel_from, rel_to): rel_from = self.translate_client_path(rel_from) rel_to = self.translate_client_path(rel_to) self._backing_transport.move(rel_from, rel_to) - return request.SuccessfulSmartServerResponse((b'ok',)) + return request.SuccessfulSmartServerResponse((b"ok",)) class PutRequest(VfsRequest): - def do(self, relpath, mode): relpath = self.translate_client_path(relpath) self._relpath = relpath self._mode = _deserialise_optional_mode(mode) def do_body(self, body_bytes): - self._backing_transport.put_bytes( - self._relpath, body_bytes, self._mode) - return request.SuccessfulSmartServerResponse((b'ok',)) + self._backing_transport.put_bytes(self._relpath, body_bytes, self._mode) + return request.SuccessfulSmartServerResponse((b"ok",)) class PutNonAtomicRequest(VfsRequest): - def do(self, relpath, mode, create_parent, dir_mode): relpath = self.translate_client_path(relpath) self._relpath = relpath self._dir_mode = _deserialise_optional_mode(dir_mode) self._mode = _deserialise_optional_mode(mode) # a boolean would be nicer XXX - self._create_parent = (create_parent == b'T') + self._create_parent = create_parent == b"T" def do_body(self, body_bytes): - self._backing_transport.put_bytes_non_atomic(self._relpath, - body_bytes, - mode=self._mode, - create_parent_dir=self._create_parent, - dir_mode=self._dir_mode) - return request.SuccessfulSmartServerResponse((b'ok',)) + self._backing_transport.put_bytes_non_atomic( + self._relpath, + body_bytes, + mode=self._mode, + create_parent_dir=self._create_parent, + dir_mode=self._dir_mode, + ) + return request.SuccessfulSmartServerResponse((b"ok",)) class ReadvRequest(VfsRequest): - def do(self, relpath): relpath = self.translate_client_path(relpath) self._relpath = relpath @@ -185,44 +179,48 @@ def do(self, relpath): def do_body(self, body_bytes): """Accept offsets for a readv request.""" offsets = self._deserialise_offsets(body_bytes) - backing_bytes = b''.join(bytes for offset, bytes in - self._backing_transport.readv(self._relpath, offsets)) - return request.SuccessfulSmartServerResponse((b'readv',), backing_bytes) + backing_bytes = b"".join( + bytes + for offset, bytes in self._backing_transport.readv(self._relpath, offsets) + ) + return request.SuccessfulSmartServerResponse((b"readv",), backing_bytes) def _deserialise_offsets(self, text): # XXX: 
FIXME this should be on the protocol object. offsets = [] - for line in text.split(b'\n'): + for line in text.split(b"\n"): if not line: continue - start, length = line.split(b',') + start, length = line.split(b",") offsets.append((int(start), int(length))) return offsets class RenameRequest(VfsRequest): - def do(self, rel_from, rel_to): rel_from = self.translate_client_path(rel_from) rel_to = self.translate_client_path(rel_to) self._backing_transport.rename(rel_from, rel_to) - return request.SuccessfulSmartServerResponse((b'ok', )) + return request.SuccessfulSmartServerResponse((b"ok",)) class RmdirRequest(VfsRequest): - def do(self, relpath): relpath = self.translate_client_path(relpath) self._backing_transport.rmdir(relpath) - return request.SuccessfulSmartServerResponse((b'ok', )) + return request.SuccessfulSmartServerResponse((b"ok",)) class StatRequest(VfsRequest): - def do(self, relpath): - if not relpath.endswith(b'/'): - relpath += b'/' + if not relpath.endswith(b"/"): + relpath += b"/" relpath = self.translate_client_path(relpath) stat = self._backing_transport.stat(relpath) return request.SuccessfulSmartServerResponse( - (b'stat', str(stat.st_size).encode('ascii'), oct(stat.st_mode).encode('ascii'))) + ( + b"stat", + str(stat.st_size).encode("ascii"), + oct(stat.st_mode).encode("ascii"), + ) + ) diff --git a/breezy/bzr/static_tuple.py b/breezy/bzr/static_tuple.py index 2d118aca4b..d5d5e9eaa6 100644 --- a/breezy/bzr/static_tuple.py +++ b/breezy/bzr/static_tuple.py @@ -22,6 +22,7 @@ from ._static_tuple_c import StaticTuple except ImportError as e: from .. import osutils + osutils.failed_to_load_extension(e) from ._static_tuple_py import StaticTuple @@ -35,10 +36,10 @@ def expect_static_tuple(obj): As apis are improved, we will probably eventually stop calling this as it adds overhead we shouldn't need. """ - if not debug.debug_flag_enabled('static_tuple'): + if not debug.debug_flag_enabled("static_tuple"): return StaticTuple.from_sequence(obj) if not isinstance(obj, StaticTuple): - raise TypeError(f'We expected a StaticTuple not a {type(obj)}') + raise TypeError(f"We expected a StaticTuple not a {type(obj)}") return obj diff --git a/breezy/bzr/tag.py b/breezy/bzr/tag.py index 1c6413727c..14e41473f9 100644 --- a/breezy/bzr/tag.py +++ b/breezy/bzr/tag.py @@ -53,9 +53,11 @@ def get_tag_dict(self): tag_content = self.branch._get_tags_bytes() except _mod_transport.NoSuchFile: # ugly, but only abentley should see this :) - trace.warning(f'No branch/tags file in {self.branch}. ' - 'This branch was probably created by bzr 0.15pre. ' - 'Create an empty file to silence this message.') + trace.warning( + f"No branch/tags file in {self.branch}. " + "This branch was probably created by bzr 0.15pre. " + "Create an empty file to silence this message." 
+ ) return {} return self._deserialize_tag_dict(tag_content) @@ -86,20 +88,21 @@ def _set_tag_dict(self, new_dict): return self.branch._set_tags_bytes(self._serialize_tag_dict(new_dict)) def _serialize_tag_dict(self, tag_dict): - td = {k.encode('utf-8'): v - for k, v in tag_dict.items()} + td = {k.encode("utf-8"): v for k, v in tag_dict.items()} return bencode.bencode(td) def _deserialize_tag_dict(self, tag_content): """Convert the tag file into a dictionary of tags.""" # was a special case to make initialization easy, an empty definition # is an empty dictionary - if tag_content == b'': + if tag_content == b"": return {} try: r = {} for k, v in bencode.bdecode(tag_content).items(): - r[k.decode('utf-8')] = v + r[k.decode("utf-8")] = v return r except ValueError as e: - raise ValueError(f"failed to deserialize tag dictionary {tag_content!r}: {e}") from e + raise ValueError( + f"failed to deserialize tag dictionary {tag_content!r}: {e}" + ) from e diff --git a/breezy/bzr/testament.py b/breezy/bzr/testament.py index dec93693bf..21b0893fb9 100644 --- a/breezy/bzr/testament.py +++ b/breezy/bzr/testament.py @@ -86,8 +86,8 @@ class Testament: - compared to a revision """ - long_header = 'bazaar-ng testament version 1\n' - short_header = 'bazaar-ng testament short form 1\n' + long_header = "bazaar-ng testament version 1\n" + short_header = "bazaar-ng testament short form 1\n" include_root = False @classmethod @@ -112,8 +112,9 @@ def __init__(self, rev, tree): self.message = rev.message self.parent_ids = rev.parent_ids[:] if not isinstance(tree, Tree): - raise TypeError("As of bzr 2.4 Testament.__init__() takes a " - "Revision and a Tree.") + raise TypeError( + "As of bzr 2.4 Testament.__init__() takes a " "Revision and a Tree." + ) self.tree = tree self.revprops = copy(rev.properties) if contains_whitespace(self.revision_id): @@ -131,80 +132,87 @@ def as_text_lines(self): a = r.append a(self.long_header) a(f"revision-id: {self.revision_id.decode('utf-8')}\n") - a(f'committer: {self.committer}\n') - a('timestamp: %d\n' % self.timestamp) - a('timezone: %d\n' % self.timezone) + a(f"committer: {self.committer}\n") + a("timestamp: %d\n" % self.timestamp) + a("timezone: %d\n" % self.timezone) # inventory length contains the root, which is not shown here - a('parents:\n') + a("parents:\n") for parent_id in sorted(self.parent_ids): if contains_whitespace(parent_id): raise ValueError(parent_id) a(f" {parent_id.decode('utf-8')}\n") - a('message:\n') + a("message:\n") for l in self.message.splitlines(): - a(f' {l}\n') - a('inventory:\n') + a(f" {l}\n") + a("inventory:\n") for path, ie in self._get_entries(): a(self._entry_to_line(path, ie)) r.extend(self._revprops_to_lines()) - return [line.encode('utf-8') for line in r] + return [line.encode("utf-8") for line in r] def _get_entries(self): - return ((path, ie) for (path, file_class, kind, ie) in - self.tree.list_files(include_root=self.include_root)) + return ( + (path, ie) + for (path, file_class, kind, ie) in self.tree.list_files( + include_root=self.include_root + ) + ) def _escape_path(self, path): if contains_linebreaks(path): raise ValueError(path) if not isinstance(path, str): # TODO(jelmer): Clean this up for pad.lv/1696545 - path = path.decode('ascii') - return path.replace('\\', '/').replace(' ', '\\ ') + path = path.decode("ascii") + return path.replace("\\", "/").replace(" ", "\\ ") def _entry_to_line(self, path, ie): """Turn an inventory entry into a testament line.""" if contains_whitespace(ie.file_id): raise ValueError(ie.file_id) - content = 
'' - content_spacer = '' - if ie.kind == 'file': + content = "" + content_spacer = "" + if ie.kind == "file": # TODO: avoid switching on kind if not ie.text_sha1: raise AssertionError() - content = ie.text_sha1.decode('ascii') - content_spacer = ' ' - elif ie.kind == 'symlink': + content = ie.text_sha1.decode("ascii") + content_spacer = " " + elif ie.kind == "symlink": if not ie.symlink_target: raise AssertionError() content = self._escape_path(ie.symlink_target) - content_spacer = ' ' - - l = ' {} {} {}{}{}\n'.format(ie.kind, self._escape_path(path), - ie.file_id.decode('utf8'), - content_spacer, content) + content_spacer = " " + + l = " {} {} {}{}{}\n".format( + ie.kind, + self._escape_path(path), + ie.file_id.decode("utf8"), + content_spacer, + content, + ) return l def as_text(self): - return b''.join(self.as_text_lines()) + return b"".join(self.as_text_lines()) def as_short_text(self): """Return short digest-based testament.""" - return (self.short_header.encode('ascii') + - b'revision-id: %s\n' - b'sha1: %s\n' - % (self.revision_id, self.as_sha1())) + return self.short_header.encode( + "ascii" + ) + b"revision-id: %s\n" b"sha1: %s\n" % (self.revision_id, self.as_sha1()) def _revprops_to_lines(self): """Pack up revision properties.""" if not self.revprops: return [] - r = ['properties:\n'] + r = ["properties:\n"] for name, value in sorted(self.revprops.items()): if contains_whitespace(name): raise ValueError(name) - r.append(f' {name}:\n') + r.append(f" {name}:\n") for line in value.splitlines(): - r.append(f' {line}\n') + r.append(f" {line}\n") return r def as_sha1(self): @@ -214,14 +222,14 @@ def as_sha1(self): class StrictTestament(Testament): """This testament format is for use as a checksum in bundle format 0.8.""" - long_header = 'bazaar-ng testament version 2.1\n' - short_header = 'bazaar-ng testament short form 2.1\n' + long_header = "bazaar-ng testament version 2.1\n" + short_header = "bazaar-ng testament short form 2.1\n" include_root = False def _entry_to_line(self, path, ie): l = Testament._entry_to_line(self, path, ie)[:-1] - l += ' ' + ie.revision.decode('utf-8') - l += {True: ' yes\n', False: ' no\n'}[ie.executable] + l += " " + ie.revision.decode("utf-8") + l += {True: " yes\n", False: " no\n"}[ie.executable] return l @@ -231,8 +239,8 @@ class StrictTestament3(StrictTestament): It differs from StrictTestament by including data about the tree root. """ - long_header = 'bazaar testament version 3 strict\n' - short_header = 'bazaar testament short form 3 strict\n' + long_header = "bazaar testament version 3 strict\n" + short_header = "bazaar testament short form 3 strict\n" include_root = True def _escape_path(self, path): @@ -240,7 +248,7 @@ def _escape_path(self, path): raise ValueError(path) if not isinstance(path, str): # TODO(jelmer): Clean this up for pad.lv/1696545 - path = path.decode('ascii') - if path == '': - path = '.' - return path.replace('\\', '/').replace(' ', '\\ ') + path = path.decode("ascii") + if path == "": + path = "." + return path.replace("\\", "/").replace(" ", "\\ ") diff --git a/breezy/bzr/tests/__init__.py b/breezy/bzr/tests/__init__.py index fc5cae98c9..3c96ef8e78 100644 --- a/breezy/bzr/tests/__init__.py +++ b/breezy/bzr/tests/__init__.py @@ -31,62 +31,65 @@ def load_tests(loader, basic_tests, pattern): # add the tests for this module suite.addTests(basic_tests) - prefix = __name__ + '.' + prefix = __name__ + "." 
testmod_names = [ - 'blackbox', - 'test_dirstate', - 'per_bzrdir', - 'per_inventory', - 'per_pack_repository', - 'per_repository_chk', - 'per_repository_vf', - 'per_versionedfile', - 'test__btree_serializer', - 'test__chk_map', - 'test__dirstate_helpers', - 'test__groupcompress', - 'test__simple_set', - 'test__static_tuple', - 'test_btree_index', - 'test_bundle', - 'test_bzrdir', - 'test_chk_map', - 'test_chk_serializer', - 'test_conflicts', - 'test_generate_ids', - 'test_groupcompress', - 'test_hashcache', - 'test_index', - 'test_inv', - 'test_inventory_delta', - 'test_knit', - 'test_lockable_files', - 'test_matchers', - 'test_pack', - 'test_read_bundle', - 'test_remote', - 'test_repository', - 'test_rio', - 'test_smart', - 'test_smart_request', - 'test_smart_signals', - 'test_smart_transport', - 'test_serializer', - 'test_tag', - 'test_testament', - 'test_tuned_gzip', - 'test_transform', - 'test_versionedfile', - 'test_vf_search', - 'test_vfs_ratchet', - 'test_workingtree', - 'test_workingtree_4', - 'test_weave', - 'test_xml', - ] + "blackbox", + "test_dirstate", + "per_bzrdir", + "per_inventory", + "per_pack_repository", + "per_repository_chk", + "per_repository_vf", + "per_versionedfile", + "test__btree_serializer", + "test__chk_map", + "test__dirstate_helpers", + "test__groupcompress", + "test__simple_set", + "test__static_tuple", + "test_btree_index", + "test_bundle", + "test_bzrdir", + "test_chk_map", + "test_chk_serializer", + "test_conflicts", + "test_generate_ids", + "test_groupcompress", + "test_hashcache", + "test_index", + "test_inv", + "test_inventory_delta", + "test_knit", + "test_lockable_files", + "test_matchers", + "test_pack", + "test_read_bundle", + "test_remote", + "test_repository", + "test_rio", + "test_smart", + "test_smart_request", + "test_smart_signals", + "test_smart_transport", + "test_serializer", + "test_tag", + "test_testament", + "test_tuned_gzip", + "test_transform", + "test_versionedfile", + "test_vf_search", + "test_vfs_ratchet", + "test_workingtree", + "test_workingtree_4", + "test_weave", + "test_xml", + ] # add the tests for the sub modules - suite.addTests(loader.loadTestsFromModuleNames( - [prefix + module_name for module_name in testmod_names])) + suite.addTests( + loader.loadTestsFromModuleNames( + [prefix + module_name for module_name in testmod_names] + ) + ) return suite diff --git a/breezy/bzr/tests/blackbox/__init__.py b/breezy/bzr/tests/blackbox/__init__.py index ff45ec4ce5..0cf167be8f 100644 --- a/breezy/bzr/tests/blackbox/__init__.py +++ b/breezy/bzr/tests/blackbox/__init__.py @@ -23,18 +23,19 @@ """ - - def load_tests(loader, basic_tests, pattern): suite = loader.suiteClass() # add the tests for this module suite.addTests(basic_tests) - prefix = __name__ + '.' + prefix = __name__ + "." 
testmod_names = [ - 'test_dump_btree', - ] + "test_dump_btree", + ] # add the tests for the sub modules - suite.addTests(loader.loadTestsFromModuleNames( - [prefix + module_name for module_name in testmod_names])) + suite.addTests( + loader.loadTestsFromModuleNames( + [prefix + module_name for module_name in testmod_names] + ) + ) return suite diff --git a/breezy/bzr/tests/blackbox/test_dump_btree.py b/breezy/bzr/tests/blackbox/test_dump_btree.py index 0f177bd175..5a4f097d6d 100644 --- a/breezy/bzr/tests/blackbox/test_dump_btree.py +++ b/breezy/bzr/tests/blackbox/test_dump_btree.py @@ -23,106 +23,103 @@ class TestDumpBtree(tests.TestCaseWithTransport): - def create_sample_btree_index(self): - builder = btree_index.BTreeBuilder( - reference_lists=1, key_elements=2) - builder.add_node((b'test', b'key1'), b'value', - (((b'ref', b'entry'),),)) - builder.add_node((b'test', b'key2'), b'value2', - (((b'ref', b'entry2'),),)) - builder.add_node((b'test2', b'key3'), b'value3', - (((b'ref', b'entry3'),),)) + builder = btree_index.BTreeBuilder(reference_lists=1, key_elements=2) + builder.add_node((b"test", b"key1"), b"value", (((b"ref", b"entry"),),)) + builder.add_node((b"test", b"key2"), b"value2", (((b"ref", b"entry2"),),)) + builder.add_node((b"test2", b"key3"), b"value3", (((b"ref", b"entry3"),),)) out_f = builder.finish() try: - self.build_tree_contents([('test.btree', out_f.read())]) + self.build_tree_contents([("test.btree", out_f.read())]) finally: out_f.close() def test_dump_btree_smoke(self): self.create_sample_btree_index() - out, err = self.run_bzr('dump-btree test.btree') + out, err = self.run_bzr("dump-btree test.btree") self.assertEqualDiff( "(('test', 'key1'), 'value', ((('ref', 'entry'),),))\n" "(('test', 'key2'), 'value2', ((('ref', 'entry2'),),))\n" "(('test2', 'key3'), 'value3', ((('ref', 'entry3'),),))\n", - out) + out, + ) def test_dump_btree_http_smoke(self): self.transport_readonly_server = http_server.HttpServer self.create_sample_btree_index() - url = self.get_readonly_url('test.btree') - out, err = self.run_bzr(['dump-btree', url]) + url = self.get_readonly_url("test.btree") + out, err = self.run_bzr(["dump-btree", url]) self.assertEqualDiff( "(('test', 'key1'), 'value', ((('ref', 'entry'),),))\n" "(('test', 'key2'), 'value2', ((('ref', 'entry2'),),))\n" "(('test2', 'key3'), 'value3', ((('ref', 'entry3'),),))\n", - out) + out, + ) def test_dump_btree_raw_smoke(self): self.create_sample_btree_index() - out, err = self.run_bzr('dump-btree test.btree --raw') + out, err = self.run_bzr("dump-btree test.btree --raw") self.assertEqualDiff( - 'Root node:\n' - 'B+Tree Graph Index 2\n' - 'node_ref_lists=1\n' - 'key_elements=2\n' - 'len=3\n' - 'row_lengths=1\n' - '\n' - 'Page 0\n' - 'type=leaf\n' - 'test\0key1\0ref\0entry\0value\n' - 'test\0key2\0ref\0entry2\0value2\n' - 'test2\0key3\0ref\0entry3\0value3\n' - '\n', - out) + "Root node:\n" + "B+Tree Graph Index 2\n" + "node_ref_lists=1\n" + "key_elements=2\n" + "len=3\n" + "row_lengths=1\n" + "\n" + "Page 0\n" + "type=leaf\n" + "test\0key1\0ref\0entry\0value\n" + "test\0key2\0ref\0entry2\0value2\n" + "test2\0key3\0ref\0entry3\0value3\n" + "\n", + out, + ) def test_dump_btree_no_refs_smoke(self): # A BTree index with no ref lists (such as *.cix) can be dumped without # errors. 
- builder = btree_index.BTreeBuilder( - reference_lists=0, key_elements=2) - builder.add_node((b'test', b'key1'), b'value') + builder = btree_index.BTreeBuilder(reference_lists=0, key_elements=2) + builder.add_node((b"test", b"key1"), b"value") out_f = builder.finish() try: - self.build_tree_contents([('test.btree', out_f.read())]) + self.build_tree_contents([("test.btree", out_f.read())]) finally: out_f.close() - out, err = self.run_bzr('dump-btree test.btree') + out, err = self.run_bzr("dump-btree test.btree") def create_sample_empty_btree_index(self): - builder = btree_index.BTreeBuilder( - reference_lists=1, key_elements=2) + builder = btree_index.BTreeBuilder(reference_lists=1, key_elements=2) out_f = builder.finish() try: - self.build_tree_contents([('test.btree', out_f.read())]) + self.build_tree_contents([("test.btree", out_f.read())]) finally: out_f.close() def test_dump_empty_btree_smoke(self): self.create_sample_empty_btree_index() - out, err = self.run_bzr('dump-btree test.btree') + out, err = self.run_bzr("dump-btree test.btree") self.assertEqualDiff("", out) def test_dump_empty_btree_http_smoke(self): self.transport_readonly_server = http_server.HttpServer self.create_sample_empty_btree_index() - url = self.get_readonly_url('test.btree') - out, err = self.run_bzr(['dump-btree', url]) + url = self.get_readonly_url("test.btree") + out, err = self.run_bzr(["dump-btree", url]) self.assertEqualDiff("", out) def test_dump_empty_btree_raw_smoke(self): self.create_sample_empty_btree_index() - out, err = self.run_bzr('dump-btree test.btree --raw') + out, err = self.run_bzr("dump-btree test.btree --raw") self.assertEqualDiff( - 'Root node:\n' - 'B+Tree Graph Index 2\n' - 'node_ref_lists=1\n' - 'key_elements=2\n' - 'len=0\n' - 'row_lengths=\n' - '\n' - 'Page 0\n' - '(empty)\n', - out) + "Root node:\n" + "B+Tree Graph Index 2\n" + "node_ref_lists=1\n" + "key_elements=2\n" + "len=0\n" + "row_lengths=\n" + "\n" + "Page 0\n" + "(empty)\n", + out, + ) diff --git a/breezy/bzr/tests/matchers.py b/breezy/bzr/tests/matchers.py index 3560e81955..735d953a4d 100644 --- a/breezy/bzr/tests/matchers.py +++ b/breezy/bzr/tests/matchers.py @@ -27,8 +27,8 @@ """ __all__ = [ - 'ContainsNoVfsCalls', - ] + "ContainsNoVfsCalls", +] from testtools.matchers import Matcher, Mismatch @@ -44,16 +44,19 @@ def __init__(self, vfs_calls): self.vfs_calls = vfs_calls def describe(self): - return "no VFS calls expected, got: %s" % ",".join([ - "{}({})".format(c.method, - ", ".join([repr(a) for a in c.args])) for c in self.vfs_calls]) + return "no VFS calls expected, got: %s" % ",".join( + [ + "{}({})".format(c.method, ", ".join([repr(a) for a in c.args])) + for c in self.vfs_calls + ] + ) class ContainsNoVfsCalls(Matcher): """Ensure that none of the specified calls are HPSS calls.""" def __str__(self): - return 'ContainsNoVfsCalls()' + return "ContainsNoVfsCalls()" @classmethod def match(cls, hpss_calls): diff --git a/breezy/bzr/tests/per_bzrdir/__init__.py b/breezy/bzr/tests/per_bzrdir/__init__.py index 32a57bfcca..6a49f13d1a 100644 --- a/breezy/bzr/tests/per_bzrdir/__init__.py +++ b/breezy/bzr/tests/per_bzrdir/__init__.py @@ -40,7 +40,6 @@ class TestCaseWithBzrDir(TestCaseWithTransport): - def setUp(self): super().setUp() self.controldir = None @@ -56,35 +55,45 @@ def get_default_format(self): def load_tests(loader, standard_tests, pattern): test_per_bzrdir = [ - 'breezy.bzr.tests.per_bzrdir.test_bzrdir', - ] + "breezy.bzr.tests.per_bzrdir.test_bzrdir", + ] submod_tests = 
loader.loadTestsFromModuleNames(test_per_bzrdir) - formats = [format for format in ControlDirFormat.known_formats() - if isinstance(format, BzrDirFormat)] + formats = [ + format + for format in ControlDirFormat.known_formats() + if isinstance(format, BzrDirFormat) + ] scenarios = make_scenarios( default_transport, None, # None here will cause a readonly decorator to be created # by the TestCaseWithTransport.get_readonly_transport method. None, - formats) + formats, + ) # This will always add scenarios using the smart server. from ...remote import RemoteBzrDirFormat # test the remote server behaviour when backed with a MemoryTransport # Once for the current version - scenarios.extend(make_scenarios( - memory.MemoryServer, - test_server.SmartTCPServer_for_testing, - test_server.ReadonlySmartTCPServer_for_testing, - [(RemoteBzrDirFormat())], - name_suffix='-default')) + scenarios.extend( + make_scenarios( + memory.MemoryServer, + test_server.SmartTCPServer_for_testing, + test_server.ReadonlySmartTCPServer_for_testing, + [(RemoteBzrDirFormat())], + name_suffix="-default", + ) + ) # And once with < 1.6 - the 'v2' protocol. - scenarios.extend(make_scenarios( - memory.MemoryServer, - test_server.SmartTCPServer_for_testing_v2_only, - test_server.ReadonlySmartTCPServer_for_testing_v2_only, - [(RemoteBzrDirFormat())], - name_suffix='-v2')) + scenarios.extend( + make_scenarios( + memory.MemoryServer, + test_server.SmartTCPServer_for_testing_v2_only, + test_server.ReadonlySmartTCPServer_for_testing_v2_only, + [(RemoteBzrDirFormat())], + name_suffix="-v2", + ) + ) # add the tests for the sub modules return multiply_tests(submod_tests, scenarios, standard_tests) diff --git a/breezy/bzr/tests/per_bzrdir/test_bzrdir.py b/breezy/bzr/tests/per_bzrdir/test_bzrdir.py index 4b53d27c03..fcaa788775 100644 --- a/breezy/bzr/tests/per_bzrdir/test_bzrdir.py +++ b/breezy/bzr/tests/per_bzrdir/test_bzrdir.py @@ -73,7 +73,6 @@ def get_format_string(self): class TestBzrDir(TestCaseWithBzrDir): - # Many of these tests test for disk equality rather than checking # for semantic equivalence. 
This works well for some tests but # is not good at handling changes in representation or the addition @@ -101,24 +100,26 @@ def assertDirectoriesEqual(self, source, target, ignore_list=None): """ if ignore_list is None: ignore_list = [] - directories = ['.'] + directories = ["."] while directories: dir = directories.pop() for path in set(source.list_dir(dir) + target.list_dir(dir)): - path = dir + '/' + path + path = dir + "/" + path if path in ignore_list: continue try: stat = source.stat(path) except transport.NoSuchFile: - self.fail(f'{path} not in source') + self.fail(f"{path} not in source") if S_ISDIR(stat.st_mode): self.assertTrue(S_ISDIR(target.stat(path).st_mode)) directories.append(path) else: - self.assertEqualDiff(source.get_bytes(path), - target.get_bytes(path), - f"text for file {path!r} differs:\n") + self.assertEqualDiff( + source.get_bytes(path), + target.get_bytes(path), + f"text for file {path!r} differs:\n", + ) def assertRepositoryHasSameItems(self, left_repo, right_repo): """Require left_repo and right_repo to contain the same data.""" @@ -130,34 +131,40 @@ def assertRepositoryHasSameItems(self, left_repo, right_repo): with left_repo.lock_read(), right_repo.lock_read(): # revs all_revs = left_repo.all_revision_ids() - self.assertEqual(left_repo.all_revision_ids(), - right_repo.all_revision_ids()) + self.assertEqual( + left_repo.all_revision_ids(), right_repo.all_revision_ids() + ) for rev_id in left_repo.all_revision_ids(): - self.assertEqual(left_repo.get_revision(rev_id), - right_repo.get_revision(rev_id)) + self.assertEqual( + left_repo.get_revision(rev_id), right_repo.get_revision(rev_id) + ) # Assert the revision trees (and thus the inventories) are equal - def sort_key(rev_tree): return rev_tree.get_revision_id() - rev_trees_a = sorted( - left_repo.revision_trees(all_revs), key=sort_key) - rev_trees_b = sorted( - right_repo.revision_trees(all_revs), key=sort_key) + def sort_key(rev_tree): + return rev_tree.get_revision_id() + + rev_trees_a = sorted(left_repo.revision_trees(all_revs), key=sort_key) + rev_trees_b = sorted(right_repo.revision_trees(all_revs), key=sort_key) for tree_a, tree_b in zip(rev_trees_a, rev_trees_b): self.assertEqual([], list(tree_a.iter_changes(tree_b))) # texts text_index = left_repo._generate_text_key_index() - self.assertEqual(text_index, - right_repo._generate_text_key_index()) + self.assertEqual(text_index, right_repo._generate_text_key_index()) desired_files = [] for file_id, revision_id in text_index: - desired_files.append( - (file_id, revision_id, (file_id, revision_id))) - left_texts = [(identifier, b"".join(bytes_iterator)) for - (identifier, bytes_iterator) in - left_repo.iter_files_bytes(desired_files)] - right_texts = [(identifier, b"".join(bytes_iterator)) for - (identifier, bytes_iterator) in - right_repo.iter_files_bytes(desired_files)] + desired_files.append((file_id, revision_id, (file_id, revision_id))) + left_texts = [ + (identifier, b"".join(bytes_iterator)) + for (identifier, bytes_iterator) in left_repo.iter_files_bytes( + desired_files + ) + ] + right_texts = [ + (identifier, b"".join(bytes_iterator)) + for (identifier, bytes_iterator) in right_repo.iter_files_bytes( + desired_files + ) + ] left_texts.sort() right_texts.sort() self.assertEqual(left_texts, right_texts) @@ -170,9 +177,15 @@ def sort_key(rev_tree): return rev_tree.get_revision_id() right_text = right_repo.get_signature_text(rev_id) self.assertEqual(left_text, right_text) - def sproutOrSkip(self, from_bzrdir, to_url, revision_id=None, - 
force_new_repo=False, accelerator_tree=None, - create_tree_if_local=True): + def sproutOrSkip( + self, + from_bzrdir, + to_url, + revision_id=None, + force_new_repo=False, + accelerator_tree=None, + create_tree_if_local=True, + ): """Sprout from_bzrdir into to_url, or raise TestSkipped. A simple wrapper for from_bzrdir.sprout that translates NotLocalUrl into @@ -180,12 +193,15 @@ def sproutOrSkip(self, from_bzrdir, to_url, revision_id=None, """ to_transport = transport.get_transport(to_url) if not isinstance(to_transport, LocalTransport): - raise TestSkipped('Cannot sprout to remote bzrdirs.') - target = from_bzrdir.sprout(to_url, revision_id=revision_id, - force_new_repo=force_new_repo, - possible_transports=[to_transport], - accelerator_tree=accelerator_tree, - create_tree_if_local=create_tree_if_local) + raise TestSkipped("Cannot sprout to remote bzrdirs.") + target = from_bzrdir.sprout( + to_url, + revision_id=revision_id, + force_new_repo=force_new_repo, + possible_transports=[to_transport], + accelerator_tree=accelerator_tree, + create_tree_if_local=create_tree_if_local, + ) return target def skipIfNoWorkingTree(self, a_controldir): @@ -196,8 +212,9 @@ def skipIfNoWorkingTree(self, a_controldir): try: a_controldir.open_workingtree() except (errors.NotLocalUrl, errors.NoWorkingTree) as e: - raise TestSkipped("bzrdir on transport %r has no working tree" - % a_controldir.transport) from e + raise TestSkipped( + "bzrdir on transport %r has no working tree" % a_controldir.transport + ) from e def createWorkingTreeOrSkip(self, a_controldir): """Create a working tree on a_controldir, or raise TestSkipped. @@ -213,341 +230,400 @@ def createWorkingTreeOrSkip(self, a_controldir): revision_id=None, from_branch=None, accelerator_tree=None, - hardlink=False) + hardlink=False, + ) except errors.NotLocalUrl as e: - raise TestSkipped("cannot make working tree with transport %r" - % a_controldir.transport) from e + raise TestSkipped( + "cannot make working tree with transport %r" % a_controldir.transport + ) from e def test_clone_bzrdir_repository_under_shared_force_new_repo(self): - tree = self.make_branch_and_tree('commit_tree') - self.build_tree(['commit_tree/foo']) - tree.add('foo') - tree.commit('revision 1', rev_id=b'1') - dir = self.make_controldir('source') + tree = self.make_branch_and_tree("commit_tree") + self.build_tree(["commit_tree/foo"]) + tree.add("foo") + tree.commit("revision 1", rev_id=b"1") + dir = self.make_controldir("source") repo = dir.create_repository() repo.fetch(tree.branch.repository) - self.assertTrue(repo.has_revision(b'1')) + self.assertTrue(repo.has_revision(b"1")) try: - self.make_repository('target', shared=True) + self.make_repository("target", shared=True) except errors.IncompatibleFormat: return - target = dir.clone(self.get_url('target/child'), force_new_repo=True) + target = dir.clone(self.get_url("target/child"), force_new_repo=True) self.assertNotEqual(dir.transport.base, target.transport.base) - self.assertDirectoriesEqual(dir.root_transport, target.root_transport, - ['./.bzr/repository', - ]) + self.assertDirectoriesEqual( + dir.root_transport, + target.root_transport, + [ + "./.bzr/repository", + ], + ) self.assertRepositoryHasSameItems(tree.branch.repository, repo) def test_clone_bzrdir_branch_and_repo(self): - tree = self.make_branch_and_tree('commit_tree') - self.build_tree(['commit_tree/foo']) - tree.add('foo') - tree.commit('revision 1') - source = self.make_branch('source') + tree = self.make_branch_and_tree("commit_tree") + 
self.build_tree(["commit_tree/foo"]) + tree.add("foo") + tree.commit("revision 1") + source = self.make_branch("source") tree.branch.repository.copy_content_into(source.repository) tree.branch.copy_content_into(source) dir = source.controldir - target = dir.clone(self.get_url('target')) + target = dir.clone(self.get_url("target")) self.assertNotEqual(dir.transport.base, target.transport.base) - self.assertDirectoriesEqual(dir.root_transport, target.root_transport, - [ - './.bzr/basis-inventory-cache', - './.bzr/checkout/stat-cache', - './.bzr/merge-hashes', - './.bzr/repository', - './.bzr/stat-cache', - ]) + self.assertDirectoriesEqual( + dir.root_transport, + target.root_transport, + [ + "./.bzr/basis-inventory-cache", + "./.bzr/checkout/stat-cache", + "./.bzr/merge-hashes", + "./.bzr/repository", + "./.bzr/stat-cache", + ], + ) self.assertRepositoryHasSameItems( - tree.branch.repository, target.open_repository()) + tree.branch.repository, target.open_repository() + ) def test_clone_on_transport(self): - a_dir = self.make_controldir('source') - target_transport = a_dir.root_transport.clone('..').clone('target') + a_dir = self.make_controldir("source") + target_transport = a_dir.root_transport.clone("..").clone("target") target = a_dir.clone_on_transport(target_transport) self.assertNotEqual(a_dir.transport.base, target.transport.base) - self.assertDirectoriesEqual(a_dir.root_transport, target.root_transport, - ['./.bzr/merge-hashes']) + self.assertDirectoriesEqual( + a_dir.root_transport, target.root_transport, ["./.bzr/merge-hashes"] + ) def test_clone_bzrdir_empty(self): - dir = self.make_controldir('source') - target = dir.clone(self.get_url('target')) + dir = self.make_controldir("source") + target = dir.clone(self.get_url("target")) self.assertNotEqual(dir.transport.base, target.transport.base) - self.assertDirectoriesEqual(dir.root_transport, target.root_transport, - ['./.bzr/merge-hashes']) + self.assertDirectoriesEqual( + dir.root_transport, target.root_transport, ["./.bzr/merge-hashes"] + ) def test_clone_bzrdir_empty_force_new_ignored(self): # the force_new_repo parameter should have no effect on an empty # bzrdir's clone logic - dir = self.make_controldir('source') - target = dir.clone(self.get_url('target'), force_new_repo=True) + dir = self.make_controldir("source") + target = dir.clone(self.get_url("target"), force_new_repo=True) self.assertNotEqual(dir.transport.base, target.transport.base) - self.assertDirectoriesEqual(dir.root_transport, target.root_transport, - ['./.bzr/merge-hashes']) + self.assertDirectoriesEqual( + dir.root_transport, target.root_transport, ["./.bzr/merge-hashes"] + ) def test_clone_bzrdir_repository(self): - tree = self.make_branch_and_tree('commit_tree') - self.build_tree( - ['foo'], transport=tree.controldir.transport.clone('..')) - tree.add('foo') - tree.commit('revision 1', rev_id=b'1') - dir = self.make_controldir('source') + tree = self.make_branch_and_tree("commit_tree") + self.build_tree(["foo"], transport=tree.controldir.transport.clone("..")) + tree.add("foo") + tree.commit("revision 1", rev_id=b"1") + dir = self.make_controldir("source") repo = dir.create_repository() repo.fetch(tree.branch.repository) - self.assertTrue(repo.has_revision(b'1')) - target = dir.clone(self.get_url('target')) + self.assertTrue(repo.has_revision(b"1")) + target = dir.clone(self.get_url("target")) self.assertNotEqual(dir.transport.base, target.transport.base) - self.assertDirectoriesEqual(dir.root_transport, target.root_transport, - [ - './.bzr/merge-hashes', - 
'./.bzr/repository', - ]) - self.assertRepositoryHasSameItems(tree.branch.repository, - target.open_repository()) + self.assertDirectoriesEqual( + dir.root_transport, + target.root_transport, + [ + "./.bzr/merge-hashes", + "./.bzr/repository", + ], + ) + self.assertRepositoryHasSameItems( + tree.branch.repository, target.open_repository() + ) def test_clone_bzrdir_tree_branch_repo(self): - tree = self.make_branch_and_tree('source') - self.build_tree(['source/foo']) - tree.add('foo') - tree.commit('revision 1') + tree = self.make_branch_and_tree("source") + self.build_tree(["source/foo"]) + tree.add("foo") + tree.commit("revision 1") dir = tree.controldir - target = dir.clone(self.get_url('target')) + target = dir.clone(self.get_url("target")) self.skipIfNoWorkingTree(target) self.assertNotEqual(dir.transport.base, target.transport.base) - self.assertDirectoriesEqual(dir.root_transport, target.root_transport, - ['./.bzr/stat-cache', - './.bzr/checkout/dirstate', - './.bzr/checkout/stat-cache', - './.bzr/checkout/merge-hashes', - './.bzr/merge-hashes', - './.bzr/repository', - ]) - self.assertRepositoryHasSameItems(tree.branch.repository, - target.open_branch().repository) + self.assertDirectoriesEqual( + dir.root_transport, + target.root_transport, + [ + "./.bzr/stat-cache", + "./.bzr/checkout/dirstate", + "./.bzr/checkout/stat-cache", + "./.bzr/checkout/merge-hashes", + "./.bzr/merge-hashes", + "./.bzr/repository", + ], + ) + self.assertRepositoryHasSameItems( + tree.branch.repository, target.open_branch().repository + ) target.open_workingtree().revert() def test_revert_inventory(self): - tree = self.make_branch_and_tree('source') - self.build_tree(['source/foo']) - tree.add('foo') - tree.commit('revision 1') + tree = self.make_branch_and_tree("source") + self.build_tree(["source/foo"]) + tree.add("foo") + tree.commit("revision 1") dir = tree.controldir - target = dir.clone(self.get_url('target')) + target = dir.clone(self.get_url("target")) self.skipIfNoWorkingTree(target) - self.assertDirectoriesEqual(dir.root_transport, target.root_transport, - ['./.bzr/stat-cache', - './.bzr/checkout/dirstate', - './.bzr/checkout/stat-cache', - './.bzr/checkout/merge-hashes', - './.bzr/merge-hashes', - './.bzr/repository', - ]) - self.assertRepositoryHasSameItems(tree.branch.repository, - target.open_branch().repository) + self.assertDirectoriesEqual( + dir.root_transport, + target.root_transport, + [ + "./.bzr/stat-cache", + "./.bzr/checkout/dirstate", + "./.bzr/checkout/stat-cache", + "./.bzr/checkout/merge-hashes", + "./.bzr/merge-hashes", + "./.bzr/repository", + ], + ) + self.assertRepositoryHasSameItems( + tree.branch.repository, target.open_branch().repository + ) target.open_workingtree().revert() - self.assertDirectoriesEqual(dir.root_transport, target.root_transport, - ['./.bzr/stat-cache', - './.bzr/checkout/dirstate', - './.bzr/checkout/stat-cache', - './.bzr/checkout/merge-hashes', - './.bzr/merge-hashes', - './.bzr/repository', - ]) - self.assertRepositoryHasSameItems(tree.branch.repository, - target.open_branch().repository) + self.assertDirectoriesEqual( + dir.root_transport, + target.root_transport, + [ + "./.bzr/stat-cache", + "./.bzr/checkout/dirstate", + "./.bzr/checkout/stat-cache", + "./.bzr/checkout/merge-hashes", + "./.bzr/merge-hashes", + "./.bzr/repository", + ], + ) + self.assertRepositoryHasSameItems( + tree.branch.repository, target.open_branch().repository + ) def test_clone_bzrdir_tree_branch_reference(self): # a tree with a branch reference (aka a checkout) # should 
stay a checkout on clone. - referenced_branch = self.make_branch('referencced') - dir = self.make_controldir('source') + referenced_branch = self.make_branch("referencced") + dir = self.make_controldir("source") try: dir.set_branch_reference(referenced_branch) except errors.IncompatibleFormat: # this is ok too, not all formats have to support references. return self.createWorkingTreeOrSkip(dir) - target = dir.clone(self.get_url('target')) + target = dir.clone(self.get_url("target")) self.skipIfNoWorkingTree(target) self.assertNotEqual(dir.transport.base, target.transport.base) - self.assertDirectoriesEqual(dir.root_transport, target.root_transport, - ['./.bzr/stat-cache', - './.bzr/checkout/stat-cache', - './.bzr/checkout/merge-hashes', - './.bzr/merge-hashes', - './.bzr/repository/inventory.knit', - ]) + self.assertDirectoriesEqual( + dir.root_transport, + target.root_transport, + [ + "./.bzr/stat-cache", + "./.bzr/checkout/stat-cache", + "./.bzr/checkout/merge-hashes", + "./.bzr/merge-hashes", + "./.bzr/repository/inventory.knit", + ], + ) def test_clone_bzrdir_branch_and_repo_into_shared_repo_force_new_repo(self): # by default cloning into a shared repo uses the shared repo. - tree = self.make_branch_and_tree('commit_tree') - self.build_tree(['commit_tree/foo']) - tree.add('foo') - tree.commit('revision 1') - source = self.make_branch('source') + tree = self.make_branch_and_tree("commit_tree") + self.build_tree(["commit_tree/foo"]) + tree.add("foo") + tree.commit("revision 1") + source = self.make_branch("source") tree.branch.repository.copy_content_into(source.repository) tree.branch.copy_content_into(source) try: - self.make_repository('target', shared=True) + self.make_repository("target", shared=True) except errors.IncompatibleFormat: return dir = source.controldir - target = dir.clone(self.get_url('target/child'), force_new_repo=True) + target = dir.clone(self.get_url("target/child"), force_new_repo=True) self.assertNotEqual(dir.transport.base, target.transport.base) repo = target.open_repository() - self.assertDirectoriesEqual(dir.root_transport, target.root_transport, - ['./.bzr/repository', - ]) + self.assertDirectoriesEqual( + dir.root_transport, + target.root_transport, + [ + "./.bzr/repository", + ], + ) self.assertRepositoryHasSameItems(tree.branch.repository, repo) def test_clone_bzrdir_branch_reference(self): # cloning should preserve the reference status of the branch in a bzrdir - referenced_branch = self.make_branch('referencced') - dir = self.make_controldir('source') + referenced_branch = self.make_branch("referencced") + dir = self.make_controldir("source") try: dir.set_branch_reference(referenced_branch) except errors.IncompatibleFormat: # this is ok too, not all formats have to support references. 
return - target = dir.clone(self.get_url('target')) + target = dir.clone(self.get_url("target")) self.assertNotEqual(dir.transport.base, target.transport.base) self.assertDirectoriesEqual(dir.root_transport, target.root_transport) def test_sprout_bzrdir_repository(self): - tree = self.make_branch_and_tree('commit_tree') - self.build_tree( - ['foo'], transport=tree.controldir.transport.clone('..')) - tree.add('foo') - tree.commit('revision 1', rev_id=b'1') - dir = self.make_controldir('source') + tree = self.make_branch_and_tree("commit_tree") + self.build_tree(["foo"], transport=tree.controldir.transport.clone("..")) + tree.add("foo") + tree.commit("revision 1", rev_id=b"1") + dir = self.make_controldir("source") repo = dir.create_repository() repo.fetch(tree.branch.repository) - self.assertTrue(repo.has_revision(b'1')) + self.assertTrue(repo.has_revision(b"1")) try: - self.assertTrue( - _mod_revision.is_null(dir.open_branch().last_revision())) + self.assertTrue(_mod_revision.is_null(dir.open_branch().last_revision())) except errors.NotBranchError: pass - target = dir.sprout(self.get_url('target')) + target = dir.sprout(self.get_url("target")) self.assertNotEqual(dir.transport.base, target.transport.base) # testing inventory isn't reasonable for repositories - self.assertDirectoriesEqual(dir.root_transport, target.root_transport, - [ - './.bzr/branch', - './.bzr/checkout', - './.bzr/inventory', - './.bzr/parent', - './.bzr/repository/inventory.knit', - ]) + self.assertDirectoriesEqual( + dir.root_transport, + target.root_transport, + [ + "./.bzr/branch", + "./.bzr/checkout", + "./.bzr/inventory", + "./.bzr/parent", + "./.bzr/repository/inventory.knit", + ], + ) try: - local_inventory = dir.transport.local_abspath('inventory') + local_inventory = dir.transport.local_abspath("inventory") except errors.NotLocalUrl: return try: # If we happen to have a tree, we'll guarantee everything # except for the tree root is the same. 
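
An illustrative sketch (not part of this patch): the clone and sprout tests in this hunk drive the two ControlDir copy operations — clone() duplicates the control directory in the same format (the push path), while sprout() seeds a new branch and tree from the source (the branch/checkout path). The sketch below reuses only calls that appear in these tests and assumes breezy's TestCaseWithTransport harness; the class and method names are placeholders.

from breezy.tests import TestCaseWithTransport


class CloneSproutSketch(TestCaseWithTransport):

    def test_clone_and_sprout(self):
        # clone() copies the control directory in the same format (push path);
        # sprout() seeds a new branch/tree from the source (branch path).
        tree = self.make_branch_and_tree("source")
        self.build_tree(["source/foo"])
        tree.add("foo")
        revid = tree.commit("revision 1")
        source_dir = tree.controldir
        cloned = source_dir.clone(self.get_url("cloned"))
        sprouted = source_dir.sprout(self.get_url("sprouted"))
        # Both copies end up with the committed revision in their repository.
        for target in (cloned, sprouted):
            self.assertTrue(target.open_repository().has_revision(revid))
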
- with open(local_inventory, 'rb') as inventory_f: - self.assertContainsRe(inventory_f.read(), - b'\n\n') + with open(local_inventory, "rb") as inventory_f: + self.assertContainsRe( + inventory_f.read(), b'\n\n' + ) except FileNotFoundError: pass def test_sprout_bzrdir_branch_and_repo(self): - tree = self.make_branch_and_tree('commit_tree') - self.build_tree(['commit_tree/foo']) - tree.add('foo') - tree.commit('revision 1') - source = self.make_branch('source') + tree = self.make_branch_and_tree("commit_tree") + self.build_tree(["commit_tree/foo"]) + tree.add("foo") + tree.commit("revision 1") + source = self.make_branch("source") tree.branch.repository.copy_content_into(source.repository) tree.controldir.open_branch().copy_content_into(source) dir = source.controldir - target = dir.sprout(self.get_url('target')) + target = dir.sprout(self.get_url("target")) self.assertNotEqual(dir.transport.base, target.transport.base) target_repo = target.open_repository() self.assertRepositoryHasSameItems(source.repository, target_repo) - self.assertDirectoriesEqual(dir.root_transport, target.root_transport, - [ - './.bzr/basis-inventory-cache', - './.bzr/branch/branch.conf', - './.bzr/branch/parent', - './.bzr/checkout', - './.bzr/checkout/inventory', - './.bzr/checkout/stat-cache', - './.bzr/inventory', - './.bzr/parent', - './.bzr/repository', - './.bzr/stat-cache', - './foo', - ]) + self.assertDirectoriesEqual( + dir.root_transport, + target.root_transport, + [ + "./.bzr/basis-inventory-cache", + "./.bzr/branch/branch.conf", + "./.bzr/branch/parent", + "./.bzr/checkout", + "./.bzr/checkout/inventory", + "./.bzr/checkout/stat-cache", + "./.bzr/inventory", + "./.bzr/parent", + "./.bzr/repository", + "./.bzr/stat-cache", + "./foo", + ], + ) def test_sprout_bzrdir_tree_branch_repo(self): - tree = self.make_branch_and_tree('source') - self.build_tree( - ['foo'], transport=tree.controldir.transport.clone('..')) - tree.add('foo') - tree.commit('revision 1') + tree = self.make_branch_and_tree("source") + self.build_tree(["foo"], transport=tree.controldir.transport.clone("..")) + tree.add("foo") + tree.commit("revision 1") dir = tree.controldir - target = self.sproutOrSkip(dir, self.get_url('target')) + target = self.sproutOrSkip(dir, self.get_url("target")) self.assertNotEqual(dir.transport.base, target.transport.base) - self.assertDirectoriesEqual(dir.root_transport, target.root_transport, - [ - './.bzr/branch', - './.bzr/checkout/dirstate', - './.bzr/checkout/stat-cache', - './.bzr/checkout/inventory', - './.bzr/inventory', - './.bzr/parent', - './.bzr/repository', - './.bzr/stat-cache', - ]) + self.assertDirectoriesEqual( + dir.root_transport, + target.root_transport, + [ + "./.bzr/branch", + "./.bzr/checkout/dirstate", + "./.bzr/checkout/stat-cache", + "./.bzr/checkout/inventory", + "./.bzr/inventory", + "./.bzr/parent", + "./.bzr/repository", + "./.bzr/stat-cache", + ], + ) self.assertRepositoryHasSameItems( - tree.branch.repository, target.open_repository()) + tree.branch.repository, target.open_repository() + ) def test_sprout_branch_no_source_branch(self): try: - repo = self.make_repository('source', shared=True) + repo = self.make_repository("source", shared=True) except errors.IncompatibleFormat: return if isinstance(self.bzrdir_format, RemoteBzrDirFormat): - self.skipTest('remote formats not supported') - branch = controldir.ControlDir.create_branch_convenience('source/trunk') + self.skipTest("remote formats not supported") + branch = 
controldir.ControlDir.create_branch_convenience("source/trunk") tree = branch.controldir.open_workingtree() - self.build_tree(['source/trunk/foo']) - tree.add('foo') - tree.commit('revision 1') - rev2 = tree.commit('revision 2', allow_pointless=True) + self.build_tree(["source/trunk/foo"]) + tree.add("foo") + tree.commit("revision 1") + rev2 = tree.commit("revision 2", allow_pointless=True) target = self.sproutOrSkip( - repo.controldir, self.get_url('target'), revision_id=rev2) + repo.controldir, self.get_url("target"), revision_id=rev2 + ) self.assertEqual([rev2], target.open_workingtree().get_parent_ids()) def test_retire_bzrdir(self): - bd = self.make_controldir('.') + bd = self.make_controldir(".") transport = bd.root_transport # must not overwrite existing directories - self.build_tree(['.bzr.retired.0/', '.bzr.retired.0/junk', ], - transport=transport) - self.assertTrue(transport.has('.bzr')) + self.build_tree( + [ + ".bzr.retired.0/", + ".bzr.retired.0/junk", + ], + transport=transport, + ) + self.assertTrue(transport.has(".bzr")) bd.retire_bzrdir() - self.assertFalse(transport.has('.bzr')) - self.assertTrue(transport.has('.bzr.retired.1')) + self.assertFalse(transport.has(".bzr")) + self.assertTrue(transport.has(".bzr.retired.1")) def test_retire_bzrdir_limited(self): - bd = self.make_controldir('.') + bd = self.make_controldir(".") transport = bd.root_transport # must not overwrite existing directories - self.build_tree(['.bzr.retired.0/', '.bzr.retired.0/junk', ], - transport=transport) - self.assertTrue(transport.has('.bzr')) - self.assertRaises((FileExists, errors.DirectoryNotEmpty), - bd.retire_bzrdir, limit=0) + self.build_tree( + [ + ".bzr.retired.0/", + ".bzr.retired.0/junk", + ], + transport=transport, + ) + self.assertTrue(transport.has(".bzr")) + self.assertRaises( + (FileExists, errors.DirectoryNotEmpty), bd.retire_bzrdir, limit=0 + ) def test_get_branch_transport(self): - dir = self.make_controldir('.') + dir = self.make_controldir(".") # without a format, get_branch_transport gives use a transport # which -may- point to an existing dir. self.assertIsInstance( dir.get_branch_transport(None), - (transport.Transport, transport._transport_rs.Transport) + (transport.Transport, transport._transport_rs.Transport), ) # with a given format, either the bzr dir supports identifiable # branches, or it supports anonymous branch formats, but not both. @@ -555,22 +631,24 @@ def test_get_branch_transport(self): identifiable_format = IdentifiableTestBranchFormat() try: found_transport = dir.get_branch_transport(anonymous_format) - self.assertRaises(errors.IncompatibleFormat, - dir.get_branch_transport, - identifiable_format) + self.assertRaises( + errors.IncompatibleFormat, dir.get_branch_transport, identifiable_format + ) except errors.IncompatibleFormat: found_transport = dir.get_branch_transport(identifiable_format) - self.assertIsInstance(found_transport, (transport.Transport, transport._transport_rs.Transport)) + self.assertIsInstance( + found_transport, (transport.Transport, transport._transport_rs.Transport) + ) # and the dir which has been initialized for us must exist. - found_transport.list_dir('.') + found_transport.list_dir(".") def test_get_repository_transport(self): - dir = self.make_controldir('.') + dir = self.make_controldir(".") # without a format, get_repository_transport gives use a transport # which -may- point to an existing dir. 
self.assertIsInstance( dir.get_repository_transport(None), - (transport.Transport, transport._transport_rs.Transport) + (transport.Transport, transport._transport_rs.Transport), ) # with a given format, either the bzr dir supports identifiable # repositories, or it supports anonymous repository formats, but not both. @@ -578,17 +656,21 @@ def test_get_repository_transport(self): identifiable_format = IdentifiableTestRepositoryFormat() try: found_transport = dir.get_repository_transport(anonymous_format) - self.assertRaises(errors.IncompatibleFormat, - dir.get_repository_transport, - identifiable_format) + self.assertRaises( + errors.IncompatibleFormat, + dir.get_repository_transport, + identifiable_format, + ) except errors.IncompatibleFormat: found_transport = dir.get_repository_transport(identifiable_format) - self.assertIsInstance(found_transport, (transport.Transport, transport._transport_rs.Transport)) + self.assertIsInstance( + found_transport, (transport.Transport, transport._transport_rs.Transport) + ) # and the dir which has been initialized for us must exist. - found_transport.list_dir('.') + found_transport.list_dir(".") def test_get_workingtree_transport(self): - dir = self.make_controldir('.') + dir = self.make_controldir(".") # without a format, get_workingtree_transport gives use a transport # which -may- point to an existing dir. self.assertIsInstance( @@ -601,15 +683,18 @@ def test_get_workingtree_transport(self): identifiable_format = IdentifiableTestWorkingTreeFormat() try: found_transport = dir.get_workingtree_transport(anonymous_format) - self.assertRaises(errors.IncompatibleFormat, - dir.get_workingtree_transport, - identifiable_format) + self.assertRaises( + errors.IncompatibleFormat, + dir.get_workingtree_transport, + identifiable_format, + ) except errors.IncompatibleFormat: - found_transport = dir.get_workingtree_transport( - identifiable_format) - self.assertIsInstance(found_transport, (transport.Transport, transport._transport_rs.Transport)) + found_transport = dir.get_workingtree_transport(identifiable_format) + self.assertIsInstance( + found_transport, (transport.Transport, transport._transport_rs.Transport) + ) # and the dir which has been initialized for us must exist. - found_transport.list_dir('.') + found_transport.list_dir(".") def assertInitializeEx(self, t, need_meta=False, **kwargs): """Execute initialize_on_transport_ex and check it succeeded correctly. @@ -623,10 +708,13 @@ def assertInitializeEx(self, t, need_meta=False, **kwargs): :return: the resulting repo, control dir tuple. """ if not self.bzrdir_format.is_initializable(): - raise TestNotApplicable("control dir format is not " - "initializable") - repo, control, require_stacking, repo_policy = \ - self.bzrdir_format.initialize_on_transport_ex(t, **kwargs) + raise TestNotApplicable("control dir format is not " "initializable") + ( + repo, + control, + require_stacking, + repo_policy, + ) = self.bzrdir_format.initialize_on_transport_ex(t, **kwargs) if repo is not None: # Repositories are open write-locked self.assertTrue(repo.is_write_locked()) @@ -639,10 +727,12 @@ def assertInitializeEx(self, t, need_meta=False, **kwargs): # needs a metaformat, because clone is used for push. 
expected_format = bzrdir.BzrDirMetaFormat1() if not isinstance(expected_format, RemoteBzrDirFormat): - self.assertEqual(control._format.network_name(), - expected_format.network_name()) - self.assertEqual(control._format.network_name(), - opened._format.network_name()) + self.assertEqual( + control._format.network_name(), expected_format.network_name() + ) + self.assertEqual( + control._format.network_name(), opened._format.network_name() + ) self.assertEqual(control.__class__, opened.__class__) return repo, control @@ -651,30 +741,30 @@ def test_format_initialize_on_transport_ex_default_stack_on(self): # a stacking policy on the target, the location of the fallback # repository is the same as the external location of the stacked-on # branch. - balloon = self.make_controldir('balloon') + balloon = self.make_controldir("balloon") if isinstance(balloon._format, bzrdir.BzrDirMetaFormat1): - stack_on = self.make_branch('stack-on', format='1.9') + stack_on = self.make_branch("stack-on", format="1.9") else: - stack_on = self.make_branch('stack-on') + stack_on = self.make_branch("stack-on") if not stack_on.repository._format.supports_nesting_repositories: raise TestNotApplicable("requires nesting repositories") - config = self.make_controldir('.').get_config() + config = self.make_controldir(".").get_config() try: - config.set_default_stack_on('stack-on') + config.set_default_stack_on("stack-on") except errors.BzrError as e: - raise TestNotApplicable('Only relevant for stackable formats.') from e + raise TestNotApplicable("Only relevant for stackable formats.") from e # Initialize a bzrdir subject to the policy. - t = self.get_transport('stacked') - repo_fmt = controldir.format_registry.make_controldir('1.9') + t = self.get_transport("stacked") + repo_fmt = controldir.format_registry.make_controldir("1.9") repo_name = repo_fmt.repository_format.network_name() repo, control = self.assertInitializeEx( - t, need_meta=True, repo_format_name=repo_name, stacked_on=None) + t, need_meta=True, repo_format_name=repo_name, stacked_on=None + ) # self.addCleanup(repo.unlock) # There's one fallback repo, with a public location. self.assertLength(1, repo._fallback_repositories) fallback_repo = repo._fallback_repositories[0] - self.assertEqual( - stack_on.base, fallback_repo.controldir.root_transport.base) + self.assertEqual(stack_on.base, fallback_repo.controldir.root_transport.base) # The bzrdir creates a branch in stacking-capable format. new_branch = control.create_branch() self.assertTrue(new_branch._format.supports_stacking()) @@ -683,19 +773,15 @@ def test_no_leftover_dirs(self): # bug 886196: development-colo uses a branch-lock directory # in the user directory rather than the control directory. 
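
An illustrative sketch (not part of this patch): the two tests just below exercise colocated branches, where one control directory holds several named branches. This sketch uses the 'development-colo' format that appears below; the class name and the exact expected set of branch names are my assumptions.

from breezy.tests import TestCaseWithTransport


class ColocatedBranchSketch(TestCaseWithTransport):

    def test_branch_names_lists_colocated_branches(self):
        # 'development-colo' supports several branches in one control dir;
        # branch_names() reports the unnamed default branch as "".
        branch = self.make_branch(".", format="development-colo")
        branch.controldir.create_branch(name="another")
        self.assertEqual({"", "another"}, set(branch.controldir.branch_names()))
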
if not self.bzrdir_format.colocated_branches: - raise TestNotApplicable( - "format does not support colocated branches") - branch = self.make_branch('.', format='development-colo') + raise TestNotApplicable("format does not support colocated branches") + branch = self.make_branch(".", format="development-colo") branch.controldir.create_branch(name="another-colocated-branch") - self.assertEqual( - branch.controldir.user_transport.list_dir("."), - [".bzr"]) + self.assertEqual(branch.controldir.user_transport.list_dir("."), [".bzr"]) def test_get_branches(self): - repo = self.make_repository('branch-1') + repo = self.make_repository("branch-1") if not repo.controldir._format.colocated_branches: - raise TestNotApplicable('Format does not support colocation') - target_branch = repo.controldir.create_branch(name='foo') + raise TestNotApplicable("Format does not support colocation") + target_branch = repo.controldir.create_branch(name="foo") repo.controldir.set_branch_reference(target_branch) - self.assertEqual({"", 'foo'}, - set(repo.controldir.branch_names())) + self.assertEqual({"", "foo"}, set(repo.controldir.branch_names())) diff --git a/breezy/bzr/tests/per_inventory/__init__.py b/breezy/bzr/tests/per_inventory/__init__.py index 8253100ac2..1082bf3314 100644 --- a/breezy/bzr/tests/per_inventory/__init__.py +++ b/breezy/bzr/tests/per_inventory/__init__.py @@ -23,36 +23,43 @@ def load_tests(loader, basic_tests, pattern): """Generate suite containing all parameterized tests.""" modules_to_test = [ - 'breezy.bzr.tests.per_inventory.basics', - ] + "breezy.bzr.tests.per_inventory.basics", + ] from ...inventory import CHKInventory, Inventory def inv_to_chk_inv(test, inv): """CHKInventory needs a backing VF, so we create one.""" factory = groupcompress.make_pack_factory(True, True, 1) - trans = test.get_transport('chk-inv') + trans = test.get_transport("chk-inv") trans.ensure_base() vf = factory(trans) # We intentionally use a non-standard maximum_size, so that we are more # likely to trigger splits, and get increased test coverage. 
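
An illustrative sketch (not part of this patch): the fixture above converts a plain Inventory into a CHKInventory backed by a group-compress store, with a deliberately small maximum_size so the CHK pages split. A self-contained version of the same conversion, assuming breezy's TestCaseWithMemoryTransport harness and the breezy.bzr.groupcompress module path; the class and method names are placeholders.

from breezy.bzr import groupcompress, inventory
from breezy.tests import TestCaseWithMemoryTransport


class ChkConversionSketch(TestCaseWithMemoryTransport):

    def test_plain_inventory_converts_to_chk(self):
        # A one-entry inventory with its root revision set, as the fixture
        # above requires before conversion.
        inv = inventory.Inventory(b"tree-root")
        inv.revision_id = b"rev-1"
        inv.root.revision = b"rev-1"
        # Group-compress store backing the CHK pages, same parameters as above.
        trans = self.get_transport("chk-inv")
        trans.ensure_base()
        chk_store = groupcompress.make_pack_factory(True, True, 1)(trans)
        chk_inv = inventory.CHKInventory.from_inventory(
            chk_store, inv, maximum_size=100, search_key_name=b"hash-255-way"
        )
        # The root entry is addressable in the converted inventory.
        self.assertEqual("", chk_inv.id2path(b"tree-root"))
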
- chk_inv = CHKInventory.from_inventory(vf, inv, - maximum_size=100, - search_key_name=b'hash-255-way') + chk_inv = CHKInventory.from_inventory( + vf, inv, maximum_size=100, search_key_name=b"hash-255-way" + ) return chk_inv - scenarios = [('Inventory', {'_inventory_class': Inventory, - '_inv_to_test_inv': lambda test, inv: inv - }), - ('CHKInventory', {'_inventory_class': CHKInventory, - '_inv_to_test_inv': inv_to_chk_inv, - })] + + scenarios = [ + ( + "Inventory", + {"_inventory_class": Inventory, "_inv_to_test_inv": lambda test, inv: inv}, + ), + ( + "CHKInventory", + { + "_inventory_class": CHKInventory, + "_inv_to_test_inv": inv_to_chk_inv, + }, + ), + ] # add the tests for the sub modules return tests.multiply_tests( - loader.loadTestsFromModuleNames(modules_to_test), - scenarios, basic_tests) + loader.loadTestsFromModuleNames(modules_to_test), scenarios, basic_tests + ) class TestCaseWithInventory(tests.TestCaseWithMemoryTransport): - _inventory_class = None # set by load_tests _inv_to_test_inv = None # set by load_tests diff --git a/breezy/bzr/tests/per_inventory/basics.py b/breezy/bzr/tests/per_inventory/basics.py index fb211ab321..d724f5ce9f 100644 --- a/breezy/bzr/tests/per_inventory/basics.py +++ b/breezy/bzr/tests/per_inventory/basics.py @@ -28,43 +28,45 @@ class TestInventory(TestCaseWithInventory): - def make_init_inventory(self): - inv = inventory.Inventory(b'tree-root') - inv.revision_id = b'initial-rev' - inv.root.revision = b'initial-rev' + inv = inventory.Inventory(b"tree-root") + inv.revision_id = b"initial-rev" + inv.root.revision = b"initial-rev" return self.inv_to_test_inv(inv) - def make_file(self, file_id, name, parent_id, content=b'content\n', - revision=b'new-test-rev'): + def make_file( + self, file_id, name, parent_id, content=b"content\n", revision=b"new-test-rev" + ): ie = InventoryFile(file_id, name, parent_id) ie.text_sha1 = osutils.sha_string(content) ie.text_size = len(content) ie.revision = revision return ie - def make_link(self, file_id, name, parent_id, target='link-target\n'): + def make_link(self, file_id, name, parent_id, target="link-target\n"): ie = InventoryLink(file_id, name, parent_id) ie.symlink_target = target return ie def prepare_inv_with_nested_dirs(self): - inv = inventory.Inventory(b'tree-root') - inv.root.revision = b'revision' - for args in [('src', 'directory', b'src-id'), - ('doc', 'directory', b'doc-id'), - ('src/hello.c', 'file', b'hello-id'), - ('src/bye.c', 'file', b'bye-id'), - ('zz', 'file', b'zz-id'), - ('src/sub/', 'directory', b'sub-id'), - ('src/zz.c', 'file', b'zzc-id'), - ('src/sub/a', 'file', b'a-id'), - ('Makefile', 'file', b'makefile-id')]: + inv = inventory.Inventory(b"tree-root") + inv.root.revision = b"revision" + for args in [ + ("src", "directory", b"src-id"), + ("doc", "directory", b"doc-id"), + ("src/hello.c", "file", b"hello-id"), + ("src/bye.c", "file", b"bye-id"), + ("zz", "file", b"zz-id"), + ("src/sub/", "directory", b"sub-id"), + ("src/zz.c", "file", b"zzc-id"), + ("src/sub/a", "file", b"a-id"), + ("Makefile", "file", b"makefile-id"), + ]: ie = inv.add_path(*args) - ie.revision = b'revision' - if args[1] == 'file': - ie.text_sha1 = osutils.sha_string(b'content\n') - ie.text_size = len(b'content\n') + ie.revision = b"revision" + if args[1] == "file": + ie.text_sha1 = osutils.sha_string(b"content\n") + ie.text_size = len(b"content\n") return self.inv_to_test_inv(inv) @@ -77,287 +79,404 @@ class TestInventoryCreateByApplyDelta(TestInventory): def test_add(self): inv = self.make_init_inventory() - inv = 
inv.create_by_apply_delta(InventoryDelta([ - (None, "a", b"a-id", self.make_file(b'a-id', 'a', b'tree-root')), - ]), b'new-test-rev') - self.assertEqual('a', inv.id2path(b'a-id')) + inv = inv.create_by_apply_delta( + InventoryDelta( + [ + (None, "a", b"a-id", self.make_file(b"a-id", "a", b"tree-root")), + ] + ), + b"new-test-rev", + ) + self.assertEqual("a", inv.id2path(b"a-id")) def test_delete(self): inv = self.make_init_inventory() - inv = inv.create_by_apply_delta(InventoryDelta([ - (None, "a", b"a-id", self.make_file(b'a-id', 'a', b'tree-root')), - ]), b'new-rev-1') - self.assertEqual('a', inv.id2path(b'a-id')) - inv = inv.create_by_apply_delta(InventoryDelta([ - ("a", None, b"a-id", None), - ]), b'new-rev-2') - self.assertRaises(errors.NoSuchId, inv.id2path, b'a-id') + inv = inv.create_by_apply_delta( + InventoryDelta( + [ + (None, "a", b"a-id", self.make_file(b"a-id", "a", b"tree-root")), + ] + ), + b"new-rev-1", + ) + self.assertEqual("a", inv.id2path(b"a-id")) + inv = inv.create_by_apply_delta( + InventoryDelta( + [ + ("a", None, b"a-id", None), + ] + ), + b"new-rev-2", + ) + self.assertRaises(errors.NoSuchId, inv.id2path, b"a-id") def test_rename(self): inv = self.make_init_inventory() - inv = inv.create_by_apply_delta(InventoryDelta([ - (None, "a", b"a-id", self.make_file(b'a-id', 'a', b'tree-root')), - ]), b'new-rev-1') - self.assertEqual('a', inv.id2path(b'a-id')) - a_ie = inv.get_entry(b'a-id') + inv = inv.create_by_apply_delta( + InventoryDelta( + [ + (None, "a", b"a-id", self.make_file(b"a-id", "a", b"tree-root")), + ] + ), + b"new-rev-1", + ) + self.assertEqual("a", inv.id2path(b"a-id")) + a_ie = inv.get_entry(b"a-id") b_ie = self.make_file(a_ie.file_id, "b", a_ie.parent_id) inv = inv.create_by_apply_delta( - InventoryDelta([("a", "b", b"a-id", b_ie)]), b'new-rev-2') - self.assertEqual("b", inv.id2path(b'a-id')) + InventoryDelta([("a", "b", b"a-id", b_ie)]), b"new-rev-2" + ) + self.assertEqual("b", inv.id2path(b"a-id")) def test_illegal(self): # A file-id cannot appear in a delta more than once inv = self.make_init_inventory() - self.assertRaises(errors.InconsistentDelta, inv.create_by_apply_delta, InventoryDelta([ - (None, "a", b"id-1", self.make_file(b'id-1', 'a', b'tree-root')), - (None, "b", b"id-1", self.make_file(b'id-1', 'b', b'tree-root')), - ]), b'new-rev-1') + self.assertRaises( + errors.InconsistentDelta, + inv.create_by_apply_delta, + InventoryDelta( + [ + (None, "a", b"id-1", self.make_file(b"id-1", "a", b"tree-root")), + (None, "b", b"id-1", self.make_file(b"id-1", "b", b"tree-root")), + ] + ), + b"new-rev-1", + ) class TestInventoryReads(TestInventory): - def test_is_root(self): """Ensure our root-checking code is accurate.""" inv = self.make_init_inventory() - self.assertTrue(inv.is_root(b'tree-root')) - self.assertFalse(inv.is_root(b'booga')) - ie = inv.get_entry(b'tree-root').copy() - ie._file_id = b'booga' - inv = inv.create_by_apply_delta(InventoryDelta([("", None, b"tree-root", None), - (None, "", b"booga", ie)]), b'new-rev-2') - self.assertFalse(inv.is_root(b'TREE_ROOT')) - self.assertTrue(inv.is_root(b'booga')) + self.assertTrue(inv.is_root(b"tree-root")) + self.assertFalse(inv.is_root(b"booga")) + ie = inv.get_entry(b"tree-root").copy() + ie._file_id = b"booga" + inv = inv.create_by_apply_delta( + InventoryDelta([("", None, b"tree-root", None), (None, "", b"booga", ie)]), + b"new-rev-2", + ) + self.assertFalse(inv.is_root(b"TREE_ROOT")) + self.assertTrue(inv.is_root(b"booga")) def test_ids(self): """Test detection of files within selected 
directories.""" - inv = inventory.Inventory(b'TREE_ROOT') - inv.root.revision = b'revision' - for args in [('src', 'directory', b'src-id'), - ('doc', 'directory', b'doc-id'), - ('src/hello.c', 'file'), - ('src/bye.c', 'file', b'bye-id'), - ('Makefile', 'file')]: + inv = inventory.Inventory(b"TREE_ROOT") + inv.root.revision = b"revision" + for args in [ + ("src", "directory", b"src-id"), + ("doc", "directory", b"doc-id"), + ("src/hello.c", "file"), + ("src/bye.c", "file", b"bye-id"), + ("Makefile", "file"), + ]: ie = inv.add_path(*args) - ie.revision = b'revision' - if args[1] == 'file': - ie.text_sha1 = osutils.sha_string(b'content\n') - ie.text_size = len(b'content\n') + ie.revision = b"revision" + if args[1] == "file": + ie.text_sha1 = osutils.sha_string(b"content\n") + ie.text_size = len(b"content\n") inv = self.inv_to_test_inv(inv) - self.assertEqual(inv.path2id('src'), b'src-id') - self.assertEqual(inv.path2id('src/bye.c'), b'bye-id') + self.assertEqual(inv.path2id("src"), b"src-id") + self.assertEqual(inv.path2id("src/bye.c"), b"bye-id") def test_get_entry_by_path_partial(self): - inv = inventory.Inventory(b'TREE_ROOT') - inv.root.revision = b'revision' - for args in [('src', 'directory', b'src-id'), - ('doc', 'directory', b'doc-id'), - ('src/hello.c', 'file'), - ('src/bye.c', 'file', b'bye-id'), - ('Makefile', 'file'), - ('external', 'tree-reference', b'other-root')]: + inv = inventory.Inventory(b"TREE_ROOT") + inv.root.revision = b"revision" + for args in [ + ("src", "directory", b"src-id"), + ("doc", "directory", b"doc-id"), + ("src/hello.c", "file"), + ("src/bye.c", "file", b"bye-id"), + ("Makefile", "file"), + ("external", "tree-reference", b"other-root"), + ]: ie = inv.add_path(*args) - ie.revision = b'revision' - if args[1] == 'file': - ie.text_sha1 = osutils.sha_string(b'content\n') - ie.text_size = len(b'content\n') - if args[1] == 'tree-reference': - ie.reference_revision = b'reference' + ie.revision = b"revision" + if args[1] == "file": + ie.text_sha1 = osutils.sha_string(b"content\n") + ie.text_size = len(b"content\n") + if args[1] == "tree-reference": + ie.reference_revision = b"reference" inv = self.inv_to_test_inv(inv) # Standard lookups - ie, resolved, remaining = inv.get_entry_by_path_partial('') - self.assertEqual((ie.file_id, resolved, remaining), (b'TREE_ROOT', [], [])) - ie, resolved, remaining = inv.get_entry_by_path_partial('src') - self.assertEqual((ie.file_id, resolved, remaining), (b'src-id', ['src'], [])) - ie, resolved, remaining = inv.get_entry_by_path_partial('src/bye.c') - self.assertEqual((ie.file_id, resolved, remaining), (b'bye-id', ['src', 'bye.c'], [])) + ie, resolved, remaining = inv.get_entry_by_path_partial("") + self.assertEqual((ie.file_id, resolved, remaining), (b"TREE_ROOT", [], [])) + ie, resolved, remaining = inv.get_entry_by_path_partial("src") + self.assertEqual((ie.file_id, resolved, remaining), (b"src-id", ["src"], [])) + ie, resolved, remaining = inv.get_entry_by_path_partial("src/bye.c") + self.assertEqual( + (ie.file_id, resolved, remaining), (b"bye-id", ["src", "bye.c"], []) + ) # Paths in the external tree - ie, resolved, remaining = inv.get_entry_by_path_partial('external') - self.assertEqual((ie.file_id, resolved, remaining), (b'other-root', ['external'], [])) - ie, resolved, remaining = inv.get_entry_by_path_partial('external/blah') - self.assertEqual((ie.file_id, resolved, remaining), (b'other-root', ['external'], ['blah'])) + ie, resolved, remaining = inv.get_entry_by_path_partial("external") + self.assertEqual( + 
(ie.file_id, resolved, remaining), (b"other-root", ["external"], []) + ) + ie, resolved, remaining = inv.get_entry_by_path_partial("external/blah") + self.assertEqual( + (ie.file_id, resolved, remaining), (b"other-root", ["external"], ["blah"]) + ) # Nonexistant paths - ie, resolved, remaining = inv.get_entry_by_path_partial('foo.c') + ie, resolved, remaining = inv.get_entry_by_path_partial("foo.c") self.assertEqual((ie, resolved, remaining), (None, None, None)) def test_non_directory_children(self): """Test path2id when a parent directory has no children.""" - inv = inventory.Inventory(b'tree-root') - inv.add(self.make_file(b'file-id', 'file', b'tree-root')) - inv.add(self.make_link(b'link-id', 'link', b'tree-root')) - self.assertIs(None, inv.path2id('file/subfile')) - self.assertIs(None, inv.path2id('link/subfile')) + inv = inventory.Inventory(b"tree-root") + inv.add(self.make_file(b"file-id", "file", b"tree-root")) + inv.add(self.make_link(b"link-id", "link", b"tree-root")) + self.assertIs(None, inv.path2id("file/subfile")) + self.assertIs(None, inv.path2id("link/subfile")) def test_is_unmodified(self): - f1 = self.make_file(b'file-id', 'file', b'tree-root') - f1.revision = b'rev' + f1 = self.make_file(b"file-id", "file", b"tree-root") + f1.revision = b"rev" self.assertTrue(f1.is_unmodified(f1)) - f2 = self.make_file(b'file-id', 'file', b'tree-root') - f2.revision = b'rev' + f2 = self.make_file(b"file-id", "file", b"tree-root") + f2.revision = b"rev" self.assertTrue(f1.is_unmodified(f2)) - f3 = self.make_file(b'file-id', 'file', b'tree-root') + f3 = self.make_file(b"file-id", "file", b"tree-root") self.assertFalse(f1.is_unmodified(f3)) - f4 = self.make_file(b'file-id', 'file', b'tree-root') - f4.revision = b'rev1' + f4 = self.make_file(b"file-id", "file", b"tree-root") + f4.revision = b"rev1" self.assertFalse(f1.is_unmodified(f4)) def test_iter_entries(self): inv = self.prepare_inv_with_nested_dirs() # Test all entries - self.assertEqual([ - ('', b'tree-root'), - ('Makefile', b'makefile-id'), - ('doc', b'doc-id'), - ('src', b'src-id'), - ('src/bye.c', b'bye-id'), - ('src/hello.c', b'hello-id'), - ('src/sub', b'sub-id'), - ('src/sub/a', b'a-id'), - ('src/zz.c', b'zzc-id'), - ('zz', b'zz-id'), - ], [(path, ie.file_id) for path, ie in inv.iter_entries()]) + self.assertEqual( + [ + ("", b"tree-root"), + ("Makefile", b"makefile-id"), + ("doc", b"doc-id"), + ("src", b"src-id"), + ("src/bye.c", b"bye-id"), + ("src/hello.c", b"hello-id"), + ("src/sub", b"sub-id"), + ("src/sub/a", b"a-id"), + ("src/zz.c", b"zzc-id"), + ("zz", b"zz-id"), + ], + [(path, ie.file_id) for path, ie in inv.iter_entries()], + ) # Test a subdirectory - self.assertEqual([ - ('bye.c', b'bye-id'), - ('hello.c', b'hello-id'), - ('sub', b'sub-id'), - ('sub/a', b'a-id'), - ('zz.c', b'zzc-id'), - ], [(path, ie.file_id) for path, ie in inv.iter_entries( - from_dir=b'src-id')]) + self.assertEqual( + [ + ("bye.c", b"bye-id"), + ("hello.c", b"hello-id"), + ("sub", b"sub-id"), + ("sub/a", b"a-id"), + ("zz.c", b"zzc-id"), + ], + [(path, ie.file_id) for path, ie in inv.iter_entries(from_dir=b"src-id")], + ) # Test not recursing at the root level - self.assertEqual([ - ('', b'tree-root'), - ('Makefile', b'makefile-id'), - ('doc', b'doc-id'), - ('src', b'src-id'), - ('zz', b'zz-id'), - ], [(path, ie.file_id) for path, ie in inv.iter_entries( - recursive=False)]) + self.assertEqual( + [ + ("", b"tree-root"), + ("Makefile", b"makefile-id"), + ("doc", b"doc-id"), + ("src", b"src-id"), + ("zz", b"zz-id"), + ], + [(path, ie.file_id) for 
path, ie in inv.iter_entries(recursive=False)], + ) # Test not recursing at a subdirectory level - self.assertEqual([ - ('bye.c', b'bye-id'), - ('hello.c', b'hello-id'), - ('sub', b'sub-id'), - ('zz.c', b'zzc-id'), - ], [(path, ie.file_id) for path, ie in inv.iter_entries( - from_dir=b'src-id', recursive=False)]) + self.assertEqual( + [ + ("bye.c", b"bye-id"), + ("hello.c", b"hello-id"), + ("sub", b"sub-id"), + ("zz.c", b"zzc-id"), + ], + [ + (path, ie.file_id) + for path, ie in inv.iter_entries(from_dir=b"src-id", recursive=False) + ], + ) def test_iter_entries_by_dir(self): - inv = self. prepare_inv_with_nested_dirs() - self.assertEqual([ - ('', b'tree-root'), - ('Makefile', b'makefile-id'), - ('doc', b'doc-id'), - ('src', b'src-id'), - ('zz', b'zz-id'), - ('src/bye.c', b'bye-id'), - ('src/hello.c', b'hello-id'), - ('src/sub', b'sub-id'), - ('src/zz.c', b'zzc-id'), - ('src/sub/a', b'a-id'), - ], [(path, ie.file_id) for path, ie in inv.iter_entries_by_dir()]) - self.assertEqual([ - ('', b'tree-root'), - ('Makefile', b'makefile-id'), - ('doc', b'doc-id'), - ('src', b'src-id'), - ('zz', b'zz-id'), - ('src/bye.c', b'bye-id'), - ('src/hello.c', b'hello-id'), - ('src/sub', b'sub-id'), - ('src/zz.c', b'zzc-id'), - ('src/sub/a', b'a-id'), - ], [(path, ie.file_id) for path, ie in inv.iter_entries_by_dir( - specific_file_ids={b'a-id', b'zzc-id', b'doc-id', b'tree-root', - b'hello-id', b'bye-id', b'zz-id', b'src-id', b'makefile-id', - b'sub-id'})]) - - self.assertEqual([ - ('Makefile', b'makefile-id'), - ('doc', b'doc-id'), - ('zz', b'zz-id'), - ('src/bye.c', b'bye-id'), - ('src/hello.c', b'hello-id'), - ('src/zz.c', b'zzc-id'), - ('src/sub/a', b'a-id'), - ], [(path, ie.file_id) for path, ie in inv.iter_entries_by_dir( - specific_file_ids={b'a-id', b'zzc-id', b'doc-id', - b'hello-id', b'bye-id', b'zz-id', b'makefile-id'})]) - - self.assertEqual([ - ('Makefile', b'makefile-id'), - ('src/bye.c', b'bye-id'), - ], [(path, ie.file_id) for path, ie in inv.iter_entries_by_dir( - specific_file_ids={b'bye-id', b'makefile-id'})]) - - self.assertEqual([ - ('Makefile', b'makefile-id'), - ('src/bye.c', b'bye-id'), - ], [(path, ie.file_id) for path, ie in inv.iter_entries_by_dir( - specific_file_ids={b'bye-id', b'makefile-id'})]) - - self.assertEqual([ - ('src/bye.c', b'bye-id'), - ], [(path, ie.file_id) for path, ie in inv.iter_entries_by_dir( - specific_file_ids={b'bye-id'})]) + inv = self.prepare_inv_with_nested_dirs() + self.assertEqual( + [ + ("", b"tree-root"), + ("Makefile", b"makefile-id"), + ("doc", b"doc-id"), + ("src", b"src-id"), + ("zz", b"zz-id"), + ("src/bye.c", b"bye-id"), + ("src/hello.c", b"hello-id"), + ("src/sub", b"sub-id"), + ("src/zz.c", b"zzc-id"), + ("src/sub/a", b"a-id"), + ], + [(path, ie.file_id) for path, ie in inv.iter_entries_by_dir()], + ) + self.assertEqual( + [ + ("", b"tree-root"), + ("Makefile", b"makefile-id"), + ("doc", b"doc-id"), + ("src", b"src-id"), + ("zz", b"zz-id"), + ("src/bye.c", b"bye-id"), + ("src/hello.c", b"hello-id"), + ("src/sub", b"sub-id"), + ("src/zz.c", b"zzc-id"), + ("src/sub/a", b"a-id"), + ], + [ + (path, ie.file_id) + for path, ie in inv.iter_entries_by_dir( + specific_file_ids={ + b"a-id", + b"zzc-id", + b"doc-id", + b"tree-root", + b"hello-id", + b"bye-id", + b"zz-id", + b"src-id", + b"makefile-id", + b"sub-id", + } + ) + ], + ) + + self.assertEqual( + [ + ("Makefile", b"makefile-id"), + ("doc", b"doc-id"), + ("zz", b"zz-id"), + ("src/bye.c", b"bye-id"), + ("src/hello.c", b"hello-id"), + ("src/zz.c", b"zzc-id"), + ("src/sub/a", b"a-id"), + ], + [ + 
(path, ie.file_id) + for path, ie in inv.iter_entries_by_dir( + specific_file_ids={ + b"a-id", + b"zzc-id", + b"doc-id", + b"hello-id", + b"bye-id", + b"zz-id", + b"makefile-id", + } + ) + ], + ) + + self.assertEqual( + [ + ("Makefile", b"makefile-id"), + ("src/bye.c", b"bye-id"), + ], + [ + (path, ie.file_id) + for path, ie in inv.iter_entries_by_dir( + specific_file_ids={b"bye-id", b"makefile-id"} + ) + ], + ) + + self.assertEqual( + [ + ("Makefile", b"makefile-id"), + ("src/bye.c", b"bye-id"), + ], + [ + (path, ie.file_id) + for path, ie in inv.iter_entries_by_dir( + specific_file_ids={b"bye-id", b"makefile-id"} + ) + ], + ) + + self.assertEqual( + [ + ("src/bye.c", b"bye-id"), + ], + [ + (path, ie.file_id) + for path, ie in inv.iter_entries_by_dir(specific_file_ids={b"bye-id"}) + ], + ) class TestInventoryFiltering(TestInventory): - def test_inv_filter_empty(self): inv = self.prepare_inv_with_nested_dirs() new_inv = inv.filter(set()) - self.assertEqual([ - ('', b'tree-root'), - ], [(path, ie.file_id) for path, ie in new_inv.iter_entries()]) + self.assertEqual( + [ + ("", b"tree-root"), + ], + [(path, ie.file_id) for path, ie in new_inv.iter_entries()], + ) def test_inv_filter_files(self): inv = self.prepare_inv_with_nested_dirs() - new_inv = inv.filter({b'zz-id', b'hello-id', b'a-id'}) - self.assertEqual([ - ('', b'tree-root'), - ('src', b'src-id'), - ('src/hello.c', b'hello-id'), - ('src/sub', b'sub-id'), - ('src/sub/a', b'a-id'), - ('zz', b'zz-id'), - ], [(path, ie.file_id) for path, ie in new_inv.iter_entries()]) + new_inv = inv.filter({b"zz-id", b"hello-id", b"a-id"}) + self.assertEqual( + [ + ("", b"tree-root"), + ("src", b"src-id"), + ("src/hello.c", b"hello-id"), + ("src/sub", b"sub-id"), + ("src/sub/a", b"a-id"), + ("zz", b"zz-id"), + ], + [(path, ie.file_id) for path, ie in new_inv.iter_entries()], + ) def test_inv_filter_dirs(self): inv = self.prepare_inv_with_nested_dirs() - new_inv = inv.filter({b'doc-id', b'sub-id'}) - self.assertEqual([ - ('', b'tree-root'), - ('doc', b'doc-id'), - ('src', b'src-id'), - ('src/sub', b'sub-id'), - ('src/sub/a', b'a-id'), - ], [(path, ie.file_id) for path, ie in new_inv.iter_entries()]) + new_inv = inv.filter({b"doc-id", b"sub-id"}) + self.assertEqual( + [ + ("", b"tree-root"), + ("doc", b"doc-id"), + ("src", b"src-id"), + ("src/sub", b"sub-id"), + ("src/sub/a", b"a-id"), + ], + [(path, ie.file_id) for path, ie in new_inv.iter_entries()], + ) def test_inv_filter_files_and_dirs(self): inv = self.prepare_inv_with_nested_dirs() - new_inv = inv.filter({b'makefile-id', b'src-id'}) - self.assertEqual([ - ('', b'tree-root'), - ('Makefile', b'makefile-id'), - ('src', b'src-id'), - ('src/bye.c', b'bye-id'), - ('src/hello.c', b'hello-id'), - ('src/sub', b'sub-id'), - ('src/sub/a', b'a-id'), - ('src/zz.c', b'zzc-id'), - ], [(path, ie.file_id) for path, ie in new_inv.iter_entries()]) + new_inv = inv.filter({b"makefile-id", b"src-id"}) + self.assertEqual( + [ + ("", b"tree-root"), + ("Makefile", b"makefile-id"), + ("src", b"src-id"), + ("src/bye.c", b"bye-id"), + ("src/hello.c", b"hello-id"), + ("src/sub", b"sub-id"), + ("src/sub/a", b"a-id"), + ("src/zz.c", b"zzc-id"), + ], + [(path, ie.file_id) for path, ie in new_inv.iter_entries()], + ) def test_inv_filter_entry_not_present(self): inv = self.prepare_inv_with_nested_dirs() - new_inv = inv.filter({b'not-present-id'}) - self.assertEqual([ - ('', b'tree-root'), - ], [(path, ie.file_id) for path, ie in new_inv.iter_entries()]) + new_inv = inv.filter({b"not-present-id"}) + self.assertEqual( + [ + ("", 
b"tree-root"), + ], + [(path, ie.file_id) for path, ie in new_inv.iter_entries()], + ) diff --git a/breezy/bzr/tests/per_pack_repository.py b/breezy/bzr/tests/per_pack_repository.py index 83920142c3..5cd92007fa 100644 --- a/breezy/bzr/tests/per_pack_repository.py +++ b/breezy/bzr/tests/per_pack_repository.py @@ -49,13 +49,13 @@ def get_format(self): def test_attribute__fetch_order(self): """Packs do not need ordered data retrieval.""" format = self.get_format() - repo = self.make_repository('.', format=format) - self.assertEqual('unordered', repo._format._fetch_order) + repo = self.make_repository(".", format=format) + self.assertEqual("unordered", repo._format._fetch_order) def test_attribute__fetch_uses_deltas(self): """Packs reuse deltas.""" format = self.get_format() - repo = self.make_repository('.', format=format) + repo = self.make_repository(".", format=format) if isinstance(format.repository_format, RepositoryFormat2a): # TODO: This is currently a workaround. CHK format repositories # ignore the 'deltas' flag, but during conversions, we can't @@ -67,7 +67,7 @@ def test_attribute__fetch_uses_deltas(self): def test_disk_layout(self): format = self.get_format() - repo = self.make_repository('.', format=format) + repo = self.make_repository(".", format=format) # in case of side effects of locking. repo.lock_write() repo.unlock() @@ -78,54 +78,56 @@ def test_disk_layout(self): self.check_databases(t) def check_format(self, t): - with t.get('format') as f: + with t.get("format") as f: self.assertEqualDiff( - self.format_string.encode('ascii'), # from scenario - f.read()) + self.format_string.encode("ascii"), # from scenario + f.read(), + ) def assertHasNoKndx(self, t, knit_name): """Assert that knit_name has no index on t.""" - self.assertFalse(t.has(knit_name + '.kndx')) + self.assertFalse(t.has(knit_name + ".kndx")) def assertHasNoKnit(self, t, knit_name): """Assert that knit_name exists on t.""" # no default content - self.assertFalse(t.has(knit_name + '.knit')) + self.assertFalse(t.has(knit_name + ".knit")) def check_databases(self, t): """Check knit content for a repository.""" # check conversion worked - self.assertHasNoKndx(t, 'inventory') - self.assertHasNoKnit(t, 'inventory') - self.assertHasNoKndx(t, 'revisions') - self.assertHasNoKnit(t, 'revisions') - self.assertHasNoKndx(t, 'signatures') - self.assertHasNoKnit(t, 'signatures') - self.assertFalse(t.has('knits')) + self.assertHasNoKndx(t, "inventory") + self.assertHasNoKnit(t, "inventory") + self.assertHasNoKndx(t, "revisions") + self.assertHasNoKnit(t, "revisions") + self.assertHasNoKndx(t, "signatures") + self.assertHasNoKnit(t, "signatures") + self.assertFalse(t.has("knits")) # revision-indexes file-container directory - self.assertEqual([], - list(self.index_class(t, 'pack-names', None).iter_all_entries())) - self.assertTrue(S_ISDIR(t.stat('packs').st_mode)) - self.assertTrue(S_ISDIR(t.stat('upload').st_mode)) - self.assertTrue(S_ISDIR(t.stat('indices').st_mode)) - self.assertTrue(S_ISDIR(t.stat('obsolete_packs').st_mode)) + self.assertEqual( + [], list(self.index_class(t, "pack-names", None).iter_all_entries()) + ) + self.assertTrue(S_ISDIR(t.stat("packs").st_mode)) + self.assertTrue(S_ISDIR(t.stat("upload").st_mode)) + self.assertTrue(S_ISDIR(t.stat("indices").st_mode)) + self.assertTrue(S_ISDIR(t.stat("obsolete_packs").st_mode)) def test_shared_disk_layout(self): format = self.get_format() - repo = self.make_repository('.', shared=True, format=format) + repo = self.make_repository(".", shared=True, format=format) # we 
want: t = repo.controldir.get_repository_transport(None) self.check_format(t) # XXX: no locks left when unlocked at the moment # self.assertEqualDiff('', t.get('lock').read()) # We should have a 'shared-storage' marker file. - with t.get('shared-storage') as f: - self.assertEqualDiff(b'', f.read()) + with t.get("shared-storage") as f: + self.assertEqualDiff(b"", f.read()) self.check_databases(t) def test_shared_no_tree_disk_layout(self): format = self.get_format() - repo = self.make_repository('.', shared=True, format=format) + repo = self.make_repository(".", shared=True, format=format) repo.set_make_working_trees(False) # we want: t = repo.controldir.get_repository_transport(None) @@ -133,86 +135,84 @@ def test_shared_no_tree_disk_layout(self): # XXX: no locks left when unlocked at the moment # self.assertEqualDiff('', t.get('lock').read()) # We should have a 'shared-storage' marker file. - with t.get('shared-storage') as f: - self.assertEqualDiff(b'', f.read()) + with t.get("shared-storage") as f: + self.assertEqualDiff(b"", f.read()) # We should have a marker for the no-working-trees flag. - with t.get('no-working-trees') as f: - self.assertEqualDiff(b'', f.read()) + with t.get("no-working-trees") as f: + self.assertEqualDiff(b"", f.read()) # The marker should go when we toggle the setting. repo.set_make_working_trees(True) - self.assertFalse(t.has('no-working-trees')) + self.assertFalse(t.has("no-working-trees")) self.check_databases(t) def test_adding_revision_creates_pack_indices(self): format = self.get_format() - tree = self.make_branch_and_tree('.', format=format) - trans = tree.branch.repository.controldir.get_repository_transport( - None) - self.assertEqual([], - list(self.index_class(trans, 'pack-names', None).iter_all_entries())) - tree.commit('foobarbaz') - index = self.index_class(trans, 'pack-names', None) + tree = self.make_branch_and_tree(".", format=format) + trans = tree.branch.repository.controldir.get_repository_transport(None) + self.assertEqual( + [], list(self.index_class(trans, "pack-names", None).iter_all_entries()) + ) + tree.commit("foobarbaz") + index = self.index_class(trans, "pack-names", None) index_nodes = list(index.iter_all_entries()) self.assertEqual(1, len(index_nodes)) node = index_nodes[0] name = node[1][0] # the pack sizes should be listed in the index pack_value = node[2] - sizes = [int(digits) for digits in pack_value.split(b' ')] - for size, suffix in zip(sizes, ['.rix', '.iix', '.tix', '.six']): + sizes = [int(digits) for digits in pack_value.split(b" ")] + for size, suffix in zip(sizes, [".rix", ".iix", ".tix", ".six"]): stat = trans.stat(f"indices/{name.decode('ascii')}{suffix}") self.assertEqual(size, stat.st_size) def test_pulling_nothing_leads_to_no_new_names(self): format = self.get_format() - tree1 = self.make_branch_and_tree('1', format=format) - tree2 = self.make_branch_and_tree('2', format=format) + tree1 = self.make_branch_and_tree("1", format=format) + tree2 = self.make_branch_and_tree("2", format=format) tree1.branch.repository.fetch(tree2.branch.repository) - trans = tree1.branch.repository.controldir.get_repository_transport( - None) - self.assertEqual([], - list(self.index_class(trans, 'pack-names', None).iter_all_entries())) + trans = tree1.branch.repository.controldir.get_repository_transport(None) + self.assertEqual( + [], list(self.index_class(trans, "pack-names", None).iter_all_entries()) + ) def test_commit_across_pack_shape_boundary_autopacks(self): format = self.get_format() - tree = self.make_branch_and_tree('.', 
format=format) - trans = tree.branch.repository.controldir.get_repository_transport( - None) + tree = self.make_branch_and_tree(".", format=format) + trans = tree.branch.repository.controldir.get_repository_transport(None) # This test could be a little cheaper by replacing the packs # attribute on the repository to allow a different pack distribution # and max packs policy - so we are checking the policy is honoured # in the test. But for now 11 commits is not a big deal in a single # test. for x in range(9): - tree.commit(f'commit {x}') + tree.commit(f"commit {x}") # there should be 9 packs: - index = self.index_class(trans, 'pack-names', None) + index = self.index_class(trans, "pack-names", None) self.assertEqual(9, len(list(index.iter_all_entries()))) # insert some files in obsolete_packs which should be removed by pack. - trans.put_bytes('obsolete_packs/foo', b'123') - trans.put_bytes('obsolete_packs/bar', b'321') + trans.put_bytes("obsolete_packs/foo", b"123") + trans.put_bytes("obsolete_packs/bar", b"321") # committing one more should coalesce to 1 of 10. - tree.commit('commit triggering pack') - index = self.index_class(trans, 'pack-names', None) + tree.commit("commit triggering pack") + index = self.index_class(trans, "pack-names", None) self.assertEqual(1, len(list(index.iter_all_entries()))) # packing should not damage data tree = tree.controldir.open_workingtree() - tree.branch.repository.check( - [tree.branch.last_revision()]) + tree.branch.repository.check([tree.branch.last_revision()]) nb_files = 5 # .pack, .rix, .iix, .tix, .six if tree.branch.repository._format.supports_chks: nb_files += 1 # .cix # We should have 10 x nb_files files in the obsolete_packs directory. - obsolete_files = list(trans.list_dir('obsolete_packs')) - self.assertFalse('foo' in obsolete_files) - self.assertFalse('bar' in obsolete_files) + obsolete_files = list(trans.list_dir("obsolete_packs")) + self.assertFalse("foo" in obsolete_files) + self.assertFalse("bar" in obsolete_files) self.assertEqual(10 * nb_files, len(obsolete_files)) # XXX: Todo check packs obsoleted correctly - old packs and indices # in the obsolete_packs directory. large_pack_name = list(index.iter_all_entries())[0][1][0] # finally, committing again should not touch the large pack. - tree.commit('commit not triggering pack') - index = self.index_class(trans, 'pack-names', None) + tree.commit("commit not triggering pack") + index = self.index_class(trans, "pack-names", None) self.assertEqual(2, len(list(index.iter_all_entries()))) pack_names = [node[1][0] for node in index.iter_all_entries()] self.assertTrue(large_pack_name in pack_names) @@ -221,7 +221,7 @@ def test_commit_write_group_returns_new_pack_names(self): # This test doesn't need real disk. self.vfs_transport_factory = memory.MemoryServer format = self.get_format() - repo = self.make_repository('foo', format=format) + repo = self.make_repository("foo", format=format) with repo.lock_write(): # All current pack repository styles autopack at 10 revisions; and # autopack as well as regular commit write group needs to return @@ -229,18 +229,22 @@ def test_commit_write_group_returns_new_pack_names(self): # clean way to test both the autopack logic and the normal code # path without doing this loop. 
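
An illustrative sketch (not part of this patch): the loop that follows stages data inside a write group. The protocol is to stage between start_write_group() and commit_write_group(), and to abort_write_group() if anything goes wrong; repository.WriteGroup is the context-manager form used elsewhere in this patch. The file and revision ids below are hypothetical.

from breezy import repository
from breezy.tests import TestCaseWithTransport


class WriteGroupSketch(TestCaseWithTransport):

    def test_write_group_protocol(self):
        # Data is staged between start_write_group() and commit_write_group();
        # abort_write_group() discards the staged data if anything goes wrong.
        repo = self.make_repository(".")
        with repo.lock_write():
            repo.start_write_group()
            try:
                repo.texts.add_lines((b"file-id", b"rev-id"), [], [b"a line\n"])
            except Exception:
                repo.abort_write_group()
                raise
            else:
                repo.commit_write_group()
        # The context-manager form seen elsewhere in this patch is equivalent:
        with repo.lock_write(), repository.WriteGroup(repo):
            repo.texts.add_lines((b"file-id", b"rev-id-2"), [], [b"a line\n"])
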
for pos in range(10): - revid = b'%d' % pos + revid = b"%d" % pos repo.start_write_group() try: inv = inventory.Inventory(revision_id=revid) inv.root.revision = revid repo.texts.add_lines((inv.root.file_id, revid), [], []) - rev = _mod_revision.Revision(timestamp=0, timezone=None, - committer="Foo Bar ", message="Message", - parent_ids=[], - properties={}, - inventory_sha1=None, - revision_id=revid) + rev = _mod_revision.Revision( + timestamp=0, + timezone=None, + committer="Foo Bar ", + message="Message", + parent_ids=[], + properties={}, + inventory_sha1=None, + revision_id=revid, + ) repo.add_revision(revid, rev, inv=inv) except: repo.abort_write_group() @@ -262,36 +266,40 @@ def test_fail_obsolete_deletion(self): bzrdir = self.get_format().initialize_on_transport(t) repo = bzrdir.create_repository() repo_transport = bzrdir.get_repository_transport(None) - self.assertTrue(repo_transport.has('obsolete_packs')) + self.assertTrue(repo_transport.has("obsolete_packs")) # these files are in use by another client and typically can't be deleted - repo_transport.put_bytes('obsolete_packs/.nfsblahblah', b'contents') + repo_transport.put_bytes("obsolete_packs/.nfsblahblah", b"contents") repo._pack_collection._clear_obsolete_packs() - self.assertTrue(repo_transport.has('obsolete_packs/.nfsblahblah')) + self.assertTrue(repo_transport.has("obsolete_packs/.nfsblahblah")) def test_pack_collection_sets_sibling_indices(self): """The CombinedGraphIndex objects in the pack collection are all siblings of each other, so that search-order reorderings will be copied to each other. """ - repo = self.make_repository('repo') + repo = self.make_repository("repo") pack_coll = repo._pack_collection - indices = {pack_coll.revision_index, pack_coll.inventory_index, - pack_coll.text_index, pack_coll.signature_index} + indices = { + pack_coll.revision_index, + pack_coll.inventory_index, + pack_coll.text_index, + pack_coll.signature_index, + } if pack_coll.chk_index is not None: indices.add(pack_coll.chk_index) combined_indices = {idx.combined_index for idx in indices} for combined_index in combined_indices: self.assertEqual( combined_indices.difference([combined_index]), - combined_index._sibling_indices) + combined_index._sibling_indices, + ) def test_pack_with_signatures(self): format = self.get_format() - tree = self.make_branch_and_tree('.', format=format) - trans = tree.branch.repository.controldir.get_repository_transport( - None) - revid1 = tree.commit('start') - revid2 = tree.commit('more work') + tree = self.make_branch_and_tree(".", format=format) + trans = tree.branch.repository.controldir.get_repository_transport(None) + revid1 = tree.commit("start") + revid2 = tree.commit("more work") strategy = gpg.LoopbackGPGStrategy(None) repo = tree.branch.repository self.addCleanup(repo.lock_write().unlock) @@ -303,20 +311,19 @@ def test_pack_with_signatures(self): repo.commit_write_group() tree.branch.repository.pack() # there should be 1 pack: - index = self.index_class(trans, 'pack-names', None) + index = self.index_class(trans, "pack-names", None) self.assertEqual(1, len(list(index.iter_all_entries()))) self.assertEqual(2, len(tree.branch.repository.all_revision_ids())) def test_pack_after_two_commits_packs_everything(self): format = self.get_format() - tree = self.make_branch_and_tree('.', format=format) - trans = tree.branch.repository.controldir.get_repository_transport( - None) - tree.commit('start') - tree.commit('more work') + tree = self.make_branch_and_tree(".", format=format) + trans = 
tree.branch.repository.controldir.get_repository_transport(None) + tree.commit("start") + tree.commit("more work") tree.branch.repository.pack() # there should be 1 pack: - index = self.index_class(trans, 'pack-names', None) + index = self.index_class(trans, "pack-names", None) self.assertEqual(1, len(list(index.iter_all_entries()))) self.assertEqual(2, len(tree.branch.repository.all_revision_ids())) @@ -327,51 +334,55 @@ def test_pack_preserves_all_inventories(self): # after a pack operation. However, it is harder to test that, then just # test that all inventory texts are preserved. format = self.get_format() - builder = self.make_branch_builder('source', format=format) + builder = self.make_branch_builder("source", format=format) builder.start_series() - builder.build_snapshot(None, [ - ('add', ('', b'root-id', 'directory', None))], - revision_id=b'A-id') - builder.build_snapshot(None, [ - ('add', ('file', b'file-id', 'file', b'B content\n'))], - revision_id=b'B-id') - builder.build_snapshot(None, [ - ('modify', ('file', b'C content\n'))], - revision_id=b'C-id') + builder.build_snapshot( + None, [("add", ("", b"root-id", "directory", None))], revision_id=b"A-id" + ) + builder.build_snapshot( + None, + [("add", ("file", b"file-id", "file", b"B content\n"))], + revision_id=b"B-id", + ) + builder.build_snapshot( + None, [("modify", ("file", b"C content\n"))], revision_id=b"C-id" + ) builder.finish_series() b = builder.get_branch() b.lock_read() self.addCleanup(b.unlock) - repo = self.make_repository('repo', shared=True, format=format) + repo = self.make_repository("repo", shared=True, format=format) repo.lock_write() self.addCleanup(repo.unlock) - repo.fetch(b.repository, revision_id=b'B-id') - inv = next(b.repository.iter_inventories([b'C-id'])) + repo.fetch(b.repository, revision_id=b"B-id") + inv = next(b.repository.iter_inventories([b"C-id"])) repo.start_write_group() - repo.add_inventory(b'C-id', inv, [b'B-id']) + repo.add_inventory(b"C-id", inv, [b"B-id"]) repo.commit_write_group() - self.assertEqual([(b'A-id',), (b'B-id',), (b'C-id',)], - sorted(repo.inventories.keys())) + self.assertEqual( + [(b"A-id",), (b"B-id",), (b"C-id",)], sorted(repo.inventories.keys()) + ) repo.pack() - self.assertEqual([(b'A-id',), (b'B-id',), (b'C-id',)], - sorted(repo.inventories.keys())) + self.assertEqual( + [(b"A-id",), (b"B-id",), (b"C-id",)], sorted(repo.inventories.keys()) + ) # Content should be preserved as well - self.assertEqual(inv, next(repo.iter_inventories([b'C-id']))) + self.assertEqual(inv, next(repo.iter_inventories([b"C-id"]))) def test_pack_layout(self): # Test that the ordering of revisions in pack repositories is # tip->ancestor format = self.get_format() - tree = self.make_branch_and_tree('.', format=format) - tree.branch.repository.controldir.get_repository_transport( - None) - tree.commit('start', rev_id=b'1') - tree.commit('more work', rev_id=b'2') + tree = self.make_branch_and_tree(".", format=format) + tree.branch.repository.controldir.get_repository_transport(None) + tree.commit("start", rev_id=b"1") + tree.commit("more work", rev_id=b"2") tree.branch.repository.pack() tree.lock_read() self.addCleanup(tree.unlock) pack = tree.branch.repository._pack_collection.get_pack_by_name( - tree.branch.repository._pack_collection.names()[0]) + tree.branch.repository._pack_collection.names()[0] + ) # revision access tends to be tip->ancestor, so ordering that way on # disk is a good idea. 
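
An illustrative sketch (not part of this patch): the pack tests above rely on autopacking kicking in at the tenth commit, and on pack() being available for an explicit repack. This assumes the default 2a pack-based format; the class and method names are placeholders, and no assertions are made about pack counts since those require the scenario's index_class.

from breezy.tests import TestCaseWithTransport


class AutopackSketch(TestCaseWithTransport):

    def test_tenth_commit_triggers_autopack(self):
        # Nine commits leave nine small packs; the tenth commit autopacks
        # them, which is what the pack-shape boundary test above relies on.
        tree = self.make_branch_and_tree(".")
        for i in range(9):
            tree.commit("commit %d" % i)
        tree.commit("commit triggering pack")
        # A full repack can also be requested explicitly at any time.
        tree.branch.repository.pack()
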
pos_1 = pos_2 = None @@ -382,17 +393,17 @@ def test_pack_layout(self): else: # eol_flag, start, len pos = int(val[1:].split()[0]) - if key == (b'1',): + if key == (b"1",): pos_1 = pos else: pos_2 = pos - self.assertTrue(pos_2 < pos_1, f'rev 1 came before rev 2 {pos_1} > {pos_2}') + self.assertTrue(pos_2 < pos_1, f"rev 1 came before rev 2 {pos_1} > {pos_2}") def test_pack_repositories_support_multiple_write_locks(self): format = self.get_format() - self.make_repository('.', shared=True, format=format) - r1 = repository.Repository.open('.') - r2 = repository.Repository.open('.') + self.make_repository(".", shared=True, format=format) + r1 = repository.Repository.open(".") + r2 = repository.Repository.open(".") r1.lock_write() self.addCleanup(r1.unlock) r2.lock_write() @@ -400,14 +411,15 @@ def test_pack_repositories_support_multiple_write_locks(self): def _add_text(self, repo, fileid): """Add a text to the repository within a write group.""" - repo.texts.add_lines((fileid, b'samplerev+' + fileid), [], - [b'smaplerev+' + fileid]) + repo.texts.add_lines( + (fileid, b"samplerev+" + fileid), [], [b"smaplerev+" + fileid] + ) def test_concurrent_writers_merge_new_packs(self): format = self.get_format() - self.make_repository('.', shared=True, format=format) - r1 = repository.Repository.open('.') - r2 = repository.Repository.open('.') + self.make_repository(".", shared=True, format=format) + r1 = repository.Repository.open(".") + r2 = repository.Repository.open(".") with r1.lock_write(): # access enough data to load the names list list(r1.all_revision_ids()) @@ -418,8 +430,8 @@ def test_concurrent_writers_merge_new_packs(self): try: r2.start_write_group() try: - self._add_text(r1, b'fileidr1') - self._add_text(r2, b'fileidr2') + self._add_text(r1, b"fileidr1") + self._add_text(r2, b"fileidr2") except: r2.abort_write_group() raise @@ -441,19 +453,20 @@ def test_concurrent_writers_merge_new_packs(self): # Now both repositories should know about both names r1._pack_collection.ensure_loaded() r2._pack_collection.ensure_loaded() - self.assertEqual(r1._pack_collection.names(), - r2._pack_collection.names()) + self.assertEqual( + r1._pack_collection.names(), r2._pack_collection.names() + ) self.assertEqual(2, len(r1._pack_collection.names())) def test_concurrent_writer_second_preserves_dropping_a_pack(self): format = self.get_format() - self.make_repository('.', shared=True, format=format) - r1 = repository.Repository.open('.') - r2 = repository.Repository.open('.') + self.make_repository(".", shared=True, format=format) + r1 = repository.Repository.open(".") + r2 = repository.Repository.open(".") # add a pack to drop with r1.lock_write(): with repository.WriteGroup(r1): - self._add_text(r1, b'fileidr1') + self._add_text(r1, b"fileidr1") r1._pack_collection.ensure_loaded() name_to_drop = r1._pack_collection.all_packs()[0].name with r1.lock_write(): @@ -468,9 +481,10 @@ def test_concurrent_writer_second_preserves_dropping_a_pack(self): try: # in r1, drop the pack r1._pack_collection._remove_pack_from_memory( - r1._pack_collection.get_pack_by_name(name_to_drop)) + r1._pack_collection.get_pack_by_name(name_to_drop) + ) # in r2, add a pack - self._add_text(r2, b'fileidr2') + self._add_text(r2, b"fileidr2") except: r2.abort_write_group() raise @@ -496,18 +510,19 @@ def test_concurrent_writer_second_preserves_dropping_a_pack(self): # Now both repositories should now about just one name. 
r1._pack_collection.ensure_loaded() r2._pack_collection.ensure_loaded() - self.assertEqual(r1._pack_collection.names(), - r2._pack_collection.names()) + self.assertEqual( + r1._pack_collection.names(), r2._pack_collection.names() + ) self.assertEqual(1, len(r1._pack_collection.names())) self.assertFalse(name_to_drop in r1._pack_collection.names()) def test_concurrent_pack_triggers_reload(self): # create 2 packs, which we will then collapse - tree = self.make_branch_and_tree('tree') + tree = self.make_branch_and_tree("tree") with tree.lock_write(): - rev1 = tree.commit('one') - rev2 = tree.commit('two') - r2 = repository.Repository.open('tree') + rev1 = tree.commit("one") + rev2 = tree.commit("two") + r2 = repository.Repository.open("tree") with r2.lock_read(): # Now r2 has read the pack-names file, but will need to reload # it after r1 has repacked @@ -515,19 +530,18 @@ def test_concurrent_pack_triggers_reload(self): self.assertEqual({rev2: (rev1,)}, r2.get_parent_map([rev2])) def test_concurrent_pack_during_get_record_reloads(self): - tree = self.make_branch_and_tree('tree') + tree = self.make_branch_and_tree("tree") with tree.lock_write(): - rev1 = tree.commit('one') - rev2 = tree.commit('two') + rev1 = tree.commit("one") + rev2 = tree.commit("two") keys = [(rev1,), (rev2,)] - r2 = repository.Repository.open('tree') + r2 = repository.Repository.open("tree") with r2.lock_read(): # At this point, we will start grabbing a record stream, and # trigger a repack mid-way packed = False result = {} - record_stream = r2.revisions.get_record_stream(keys, - 'unordered', False) + record_stream = r2.revisions.get_record_stream(keys, "unordered", False) for record in record_stream: result[record.key] = record if not packed: @@ -538,11 +552,11 @@ def test_concurrent_pack_during_get_record_reloads(self): self.assertEqual(sorted(keys), sorted(result.keys())) def test_concurrent_pack_during_autopack(self): - tree = self.make_branch_and_tree('tree') + tree = self.make_branch_and_tree("tree") with tree.lock_write(): for i in range(9): - tree.commit('rev %d' % (i,)) - r2 = repository.Repository.open('tree') + tree.commit("rev %d" % (i,)) + r2 = repository.Repository.open("tree") with r2.lock_write(): # Monkey patch so that pack occurs while the other repo is # autopacking. This is slightly bad, but all current pack @@ -559,8 +573,9 @@ def trigger_during_auto(*args, **kwargs): r2.pack() autopack_count[0] += 1 return ret + r1._pack_collection.pack_distribution = trigger_during_auto - tree.commit('autopack-rev') + tree.commit("autopack-rev") # This triggers 2 autopacks. The first one causes r2.pack() to # fire, but r2 doesn't see the new pack file yet. 
The # autopack restarts and sees there are 2 files and there @@ -569,7 +584,7 @@ def trigger_during_auto(*args, **kwargs): self.assertEqual([2], autopack_count) def test_lock_write_does_not_physically_lock(self): - repo = self.make_repository('.', format=self.get_format()) + repo = self.make_repository(".", format=self.get_format()) repo.lock_write() self.addCleanup(repo.unlock) self.assertFalse(repo.get_physical_lock_status()) @@ -580,25 +595,24 @@ def prepare_for_break_lock(self): ui.ui_factory = ui.CannedInputUIFactory([True]) def test_break_lock_breaks_physical_lock(self): - repo = self.make_repository('.', format=self.get_format()) + repo = self.make_repository(".", format=self.get_format()) repo._pack_collection.lock_names() repo.control_files.leave_in_place() repo.unlock() - repo2 = repository.Repository.open('.') + repo2 = repository.Repository.open(".") self.assertTrue(repo.get_physical_lock_status()) self.prepare_for_break_lock() repo2.break_lock() self.assertFalse(repo.get_physical_lock_status()) def test_broken_physical_locks_error_on__unlock_names_lock(self): - repo = self.make_repository('.', format=self.get_format()) + repo = self.make_repository(".", format=self.get_format()) repo._pack_collection.lock_names() self.assertTrue(repo.get_physical_lock_status()) - repo2 = repository.Repository.open('.') + repo2 = repository.Repository.open(".") self.prepare_for_break_lock() repo2.break_lock() - self.assertRaises(errors.LockBroken, - repo._pack_collection._unlock_names) + self.assertRaises(errors.LockBroken, repo._pack_collection._unlock_names) def test_fetch_without_find_ghosts_ignores_ghosts(self): # we want two repositories at this point: @@ -613,9 +627,8 @@ def test_fetch_without_find_ghosts_ignores_ghosts(self): # 'references' 'references' # 'tip' - # In this test we fetch 'tip' which should not fetch 'ghost' - has_ghost = self.make_repository('has_ghost', format=self.get_format()) - missing_ghost = self.make_repository('missing_ghost', - format=self.get_format()) + has_ghost = self.make_repository("has_ghost", format=self.get_format()) + missing_ghost = self.make_repository("missing_ghost", format=self.get_format()) def add_commit(repo, revision_id, parent_ids): repo.lock_write() @@ -625,35 +638,36 @@ def add_commit(repo, revision_id, parent_ids): root_id = inv.root.file_id sha1 = repo.add_inventory(revision_id, inv, []) repo.texts.add_lines((root_id, revision_id), [], []) - rev = _mod_revision.Revision(timestamp=0, - timezone=None, - committer="Foo Bar ", - properties={}, - message="Message", - inventory_sha1=sha1, - parent_ids=parent_ids, - revision_id=revision_id) + rev = _mod_revision.Revision( + timestamp=0, + timezone=None, + committer="Foo Bar ", + properties={}, + message="Message", + inventory_sha1=sha1, + parent_ids=parent_ids, + revision_id=revision_id, + ) repo.add_revision(revision_id, rev) repo.commit_write_group() repo.unlock() - add_commit(has_ghost, b'ghost', []) - add_commit(has_ghost, b'references', [b'ghost']) - add_commit(missing_ghost, b'references', [b'ghost']) - add_commit(has_ghost, b'tip', [b'references']) - missing_ghost.fetch(has_ghost, b'tip') + + add_commit(has_ghost, b"ghost", []) + add_commit(has_ghost, b"references", [b"ghost"]) + add_commit(missing_ghost, b"references", [b"ghost"]) + add_commit(has_ghost, b"tip", [b"references"]) + missing_ghost.fetch(has_ghost, b"tip") # missing ghost now has tip and not ghost. 
- missing_ghost.get_revision(b'tip') - missing_ghost.get_inventory(b'tip') - self.assertRaises(errors.NoSuchRevision, - missing_ghost.get_revision, b'ghost') - self.assertRaises(errors.NoSuchRevision, - missing_ghost.get_inventory, b'ghost') + missing_ghost.get_revision(b"tip") + missing_ghost.get_inventory(b"tip") + self.assertRaises(errors.NoSuchRevision, missing_ghost.get_revision, b"ghost") + self.assertRaises(errors.NoSuchRevision, missing_ghost.get_inventory, b"ghost") def make_write_ready_repo(self): format = self.get_format() if isinstance(format.repository_format, RepositoryFormat2a): raise TestNotApplicable("No missing compression parents") - repo = self.make_repository('.', format=format) + repo = self.make_repository(".", format=format) repo.lock_write() self.addCleanup(repo.unlock) repo.start_write_group() @@ -662,36 +676,38 @@ def make_write_ready_repo(self): def test_missing_inventories_compression_parent_prevents_commit(self): repo = self.make_write_ready_repo() - key = ('junk',) + key = ("junk",) repo.inventories._index._missing_compression_parents.add(key) self.assertRaises(errors.BzrCheckError, repo.commit_write_group) self.assertRaises(errors.BzrCheckError, repo.commit_write_group) def test_missing_revisions_compression_parent_prevents_commit(self): repo = self.make_write_ready_repo() - key = ('junk',) + key = ("junk",) repo.revisions._index._missing_compression_parents.add(key) self.assertRaises(errors.BzrCheckError, repo.commit_write_group) self.assertRaises(errors.BzrCheckError, repo.commit_write_group) def test_missing_signatures_compression_parent_prevents_commit(self): repo = self.make_write_ready_repo() - key = ('junk',) + key = ("junk",) repo.signatures._index._missing_compression_parents.add(key) self.assertRaises(errors.BzrCheckError, repo.commit_write_group) self.assertRaises(errors.BzrCheckError, repo.commit_write_group) def test_missing_text_compression_parent_prevents_commit(self): repo = self.make_write_ready_repo() - key = ('some', 'junk') + key = ("some", "junk") repo.texts._index._missing_compression_parents.add(key) self.assertRaises(errors.BzrCheckError, repo.commit_write_group) self.assertRaises(errors.BzrCheckError, repo.commit_write_group) def test_supports_external_lookups(self): - repo = self.make_repository('.', format=self.get_format()) - self.assertEqual(self.format_supports_external_lookups, - repo._format.supports_external_lookups) + repo = self.make_repository(".", format=self.get_format()) + self.assertEqual( + self.format_supports_external_lookups, + repo._format.supports_external_lookups, + ) def _lock_write(self, write_lockable): """Lock write_lockable, add a cleanup and return the result. @@ -709,27 +725,27 @@ def test_abort_write_group_does_not_raise_when_suppressed(self): Also requires that the exception is logged. 
""" self.vfs_transport_factory = memory.MemoryServer - repo = self.make_repository('repo', format=self.get_format()) + repo = self.make_repository("repo", format=self.get_format()) token = self._lock_write(repo).repository_token repo.start_write_group() # Damage the repository on the filesystem - self.get_transport('').rename('repo', 'foo') + self.get_transport("").rename("repo", "foo") # abort_write_group will not raise an error self.assertEqual(None, repo.abort_write_group(suppress_errors=True)) # But it does log an error log = self.get_log() - self.assertContainsRe(log, 'abort_write_group failed') - self.assertContainsRe(log, r'INFO brz: ERROR \(ignored\):') + self.assertContainsRe(log, "abort_write_group failed") + self.assertContainsRe(log, r"INFO brz: ERROR \(ignored\):") if token is not None: repo.leave_lock_in_place() def test_abort_write_group_does_raise_when_not_suppressed(self): self.vfs_transport_factory = memory.MemoryServer - repo = self.make_repository('repo', format=self.get_format()) + repo = self.make_repository("repo", format=self.get_format()) token = self._lock_write(repo).repository_token repo.start_write_group() # Damage the repository on the filesystem - self.get_transport('').rename('repo', 'foo') + self.get_transport("").rename("repo", "foo") # abort_write_group will not raise an error self.assertRaises(Exception, repo.abort_write_group) if token is not None: @@ -737,32 +753,33 @@ def test_abort_write_group_does_raise_when_not_suppressed(self): def test_suspend_write_group(self): self.vfs_transport_factory = memory.MemoryServer - repo = self.make_repository('repo', format=self.get_format()) + repo = self.make_repository("repo", format=self.get_format()) self._lock_write(repo).repository_token # noqa: B018 repo.start_write_group() - repo.texts.add_lines((b'file-id', b'revid'), (), [b'lines']) + repo.texts.add_lines((b"file-id", b"revid"), (), [b"lines"]) wg_tokens = repo.suspend_write_group() - expected_pack_name = wg_tokens[0] + '.pack' - expected_names = [wg_tokens[0] + ext for ext in - ('.rix', '.iix', '.tix', '.six')] + expected_pack_name = wg_tokens[0] + ".pack" + expected_names = [ + wg_tokens[0] + ext for ext in (".rix", ".iix", ".tix", ".six") + ] if repo.chk_bytes is not None: - expected_names.append(wg_tokens[0] + '.cix') + expected_names.append(wg_tokens[0] + ".cix") expected_names.append(expected_pack_name) upload_transport = repo._pack_collection._upload_transport - limbo_files = upload_transport.list_dir('') + limbo_files = upload_transport.list_dir("") self.assertEqual(sorted(expected_names), sorted(limbo_files)) md5 = hashlib.md5(upload_transport.get_bytes(expected_pack_name)) # noqa: S324 self.assertEqual(wg_tokens[0], md5.hexdigest()) def test_resume_chk_bytes(self): self.vfs_transport_factory = memory.MemoryServer - repo = self.make_repository('repo', format=self.get_format()) + repo = self.make_repository("repo", format=self.get_format()) if repo.chk_bytes is None: - raise TestNotApplicable('no chk_bytes for this repository') + raise TestNotApplicable("no chk_bytes for this repository") self._lock_write(repo).repository_token # noqa: B018 repo.start_write_group() - text = b'a bit of text\n' - key = (b'sha1:' + osutils.sha_string(text),) + text = b"a bit of text\n" + key = (b"sha1:" + osutils.sha_string(text),) repo.chk_bytes.add_lines(key, (), [text]) wg_tokens = repo.suspend_write_group() same_repo = repo.controldir.open_repository() @@ -771,19 +788,22 @@ def test_resume_chk_bytes(self): same_repo.resume_write_group(wg_tokens) 
self.assertEqual([key], list(same_repo.chk_bytes.keys())) self.assertEqual( - text, next(same_repo.chk_bytes.get_record_stream( - [key], 'unordered', True)).get_bytes_as('fulltext')) + text, + next( + same_repo.chk_bytes.get_record_stream([key], "unordered", True) + ).get_bytes_as("fulltext"), + ) same_repo.abort_write_group() self.assertEqual([], list(same_repo.chk_bytes.keys())) def test_resume_write_group_then_abort(self): # Create a repo, start a write group, insert some data, suspend. self.vfs_transport_factory = memory.MemoryServer - repo = self.make_repository('repo', format=self.get_format()) + repo = self.make_repository("repo", format=self.get_format()) self._lock_write(repo).repository_token # noqa: B018 repo.start_write_group() - text_key = (b'file-id', b'revid') - repo.texts.add_lines(text_key, (), [b'lines']) + text_key = (b"file-id", b"revid") + repo.texts.add_lines(text_key, (), [b"lines"]) wg_tokens = repo.suspend_write_group() # Get a fresh repository object for the repo on the filesystem. same_repo = repo.controldir.open_repository() @@ -792,18 +812,16 @@ def test_resume_write_group_then_abort(self): self.addCleanup(same_repo.unlock) same_repo.resume_write_group(wg_tokens) same_repo.abort_write_group() - self.assertEqual( - [], same_repo._pack_collection._upload_transport.list_dir('')) - self.assertEqual( - [], same_repo._pack_collection._pack_transport.list_dir('')) + self.assertEqual([], same_repo._pack_collection._upload_transport.list_dir("")) + self.assertEqual([], same_repo._pack_collection._pack_transport.list_dir("")) def test_commit_resumed_write_group(self): self.vfs_transport_factory = memory.MemoryServer - repo = self.make_repository('repo', format=self.get_format()) + repo = self.make_repository("repo", format=self.get_format()) self._lock_write(repo).repository_token # noqa: B018 repo.start_write_group() - text_key = (b'file-id', b'revid') - repo.texts.add_lines(text_key, (), [b'lines']) + text_key = (b"file-id", b"revid") + repo.texts.add_lines(text_key, (), [b"lines"]) wg_tokens = repo.suspend_write_group() # Get a fresh repository object for the repo on the filesystem. 
same_repo = repo.controldir.open_repository() @@ -812,35 +830,34 @@ def test_commit_resumed_write_group(self): self.addCleanup(same_repo.unlock) same_repo.resume_write_group(wg_tokens) same_repo.commit_write_group() - expected_pack_name = wg_tokens[0] + '.pack' - expected_names = [wg_tokens[0] + ext for ext in - ('.rix', '.iix', '.tix', '.six')] + expected_pack_name = wg_tokens[0] + ".pack" + expected_names = [ + wg_tokens[0] + ext for ext in (".rix", ".iix", ".tix", ".six") + ] if repo.chk_bytes is not None: - expected_names.append(wg_tokens[0] + '.cix') - self.assertEqual( - [], same_repo._pack_collection._upload_transport.list_dir('')) - index_names = repo._pack_collection._index_transport.list_dir('') + expected_names.append(wg_tokens[0] + ".cix") + self.assertEqual([], same_repo._pack_collection._upload_transport.list_dir("")) + index_names = repo._pack_collection._index_transport.list_dir("") self.assertEqual(sorted(expected_names), sorted(index_names)) - pack_names = repo._pack_collection._pack_transport.list_dir('') + pack_names = repo._pack_collection._pack_transport.list_dir("") self.assertEqual([expected_pack_name], pack_names) def test_resume_malformed_token(self): self.vfs_transport_factory = memory.MemoryServer # Make a repository with a suspended write group - repo = self.make_repository('repo', format=self.get_format()) + repo = self.make_repository("repo", format=self.get_format()) self._lock_write(repo).repository_token # noqa: B018 repo.start_write_group() - text_key = (b'file-id', b'revid') - repo.texts.add_lines(text_key, (), [b'lines']) + text_key = (b"file-id", b"revid") + repo.texts.add_lines(text_key, (), [b"lines"]) wg_tokens = repo.suspend_write_group() # Make a new repository - new_repo = self.make_repository('new_repo', format=self.get_format()) + new_repo = self.make_repository("new_repo", format=self.get_format()) self._lock_write(new_repo).repository_token # noqa: B018 - hacked_wg_token = ( - '../../../../repo/.bzr/repository/upload/' + wg_tokens[0]) + hacked_wg_token = "../../../../repo/.bzr/repository/upload/" + wg_tokens[0] self.assertRaises( - errors.UnresumableWriteGroup, - new_repo.resume_write_group, [hacked_wg_token]) + errors.UnresumableWriteGroup, new_repo.resume_write_group, [hacked_wg_token] + ) class TestPackRepositoryStacking(TestCaseWithTransport): @@ -862,75 +879,79 @@ def test_stack_checks_rich_root_compatibility(self): # TODO: Possibly this should be run per-repository-format and raise # TestNotApplicable on formats that don't support stacking. -- mbp # 20080729 - repo = self.make_repository('repo', format=self.get_format()) + repo = self.make_repository("repo", format=self.get_format()) if repo.supports_rich_root(): # can only stack on repositories that have compatible internal # metadata - if getattr(repo._format, 'supports_tree_reference', False): - matching_format_name = '2a' + if getattr(repo._format, "supports_tree_reference", False): + matching_format_name = "2a" else: if repo._format.supports_chks: - matching_format_name = '2a' + matching_format_name = "2a" else: - matching_format_name = 'rich-root-pack' - mismatching_format_name = 'pack-0.92' + matching_format_name = "rich-root-pack" + mismatching_format_name = "pack-0.92" else: # We don't have a non-rich-root CHK format. 
if repo._format.supports_chks: raise AssertionError("no non-rich-root CHK formats known") else: - matching_format_name = 'pack-0.92' - mismatching_format_name = 'pack-0.92-subtree' - base = self.make_repository('base', format=matching_format_name) + matching_format_name = "pack-0.92" + mismatching_format_name = "pack-0.92-subtree" + base = self.make_repository("base", format=matching_format_name) repo.add_fallback_repository(base) # you can't stack on something with incompatible data - bad_repo = self.make_repository('mismatch', - format=mismatching_format_name) - e = self.assertRaises(errors.IncompatibleRepositories, - repo.add_fallback_repository, bad_repo) - self.assertContainsRe(str(e), - r'(?m)KnitPackRepository.*/mismatch/.*\nis not compatible with\n' - r'.*Repository.*/repo/.*\n' - r'different rich-root support') + bad_repo = self.make_repository("mismatch", format=mismatching_format_name) + e = self.assertRaises( + errors.IncompatibleRepositories, repo.add_fallback_repository, bad_repo + ) + self.assertContainsRe( + str(e), + r"(?m)KnitPackRepository.*/mismatch/.*\nis not compatible with\n" + r".*Repository.*/repo/.*\n" + r"different rich-root support", + ) def test_stack_checks_serializers_compatibility(self): - repo = self.make_repository('repo', format=self.get_format()) - if getattr(repo._format, 'supports_tree_reference', False): + repo = self.make_repository("repo", format=self.get_format()) + if getattr(repo._format, "supports_tree_reference", False): # can only stack on repositories that have compatible internal # metadata - matching_format_name = '2a' - mismatching_format_name = 'rich-root-pack' + matching_format_name = "2a" + mismatching_format_name = "rich-root-pack" else: if repo.supports_rich_root(): if repo._format.supports_chks: - matching_format_name = '2a' + matching_format_name = "2a" else: - matching_format_name = 'rich-root-pack' - mismatching_format_name = 'pack-0.92-subtree' + matching_format_name = "rich-root-pack" + mismatching_format_name = "pack-0.92-subtree" else: - raise TestNotApplicable('No formats use non-v5 serializer' - ' without having rich-root also set') - base = self.make_repository('base', format=matching_format_name) + raise TestNotApplicable( + "No formats use non-v5 serializer" + " without having rich-root also set" + ) + base = self.make_repository("base", format=matching_format_name) repo.add_fallback_repository(base) # you can't stack on something with incompatible data - bad_repo = self.make_repository('mismatch', - format=mismatching_format_name) - e = self.assertRaises(errors.IncompatibleRepositories, - repo.add_fallback_repository, bad_repo) - self.assertContainsRe(str(e), - r'(?m)KnitPackRepository.*/mismatch/.*\nis not compatible with\n' - r'.*Repository.*/repo/.*\n' - r'different inventory serializers') + bad_repo = self.make_repository("mismatch", format=mismatching_format_name) + e = self.assertRaises( + errors.IncompatibleRepositories, repo.add_fallback_repository, bad_repo + ) + self.assertContainsRe( + str(e), + r"(?m)KnitPackRepository.*/mismatch/.*\nis not compatible with\n" + r".*Repository.*/repo/.*\n" + r"different inventory serializers", + ) def test_adding_pack_does_not_record_pack_names_from_other_repositories(self): - base = self.make_branch_and_tree('base', format=self.get_format()) - base.commit('foo') - referencing = self.make_branch_and_tree( - 'repo', format=self.get_format()) - referencing.branch.repository.add_fallback_repository( - base.branch.repository) - local_tree = 
referencing.branch.create_checkout('local') - local_tree.commit('bar') + base = self.make_branch_and_tree("base", format=self.get_format()) + base.commit("foo") + referencing = self.make_branch_and_tree("repo", format=self.get_format()) + referencing.branch.repository.add_fallback_repository(base.branch.repository) + local_tree = referencing.branch.create_checkout("local") + local_tree.commit("bar") new_instance = referencing.controldir.open_repository() new_instance.lock_read() self.addCleanup(new_instance.unlock) @@ -939,66 +960,65 @@ def test_adding_pack_does_not_record_pack_names_from_other_repositories(self): def test_autopack_only_considers_main_repo_packs(self): format = self.get_format() - base = self.make_branch_and_tree('base', format=format) - base.commit('foo') - tree = self.make_branch_and_tree('repo', format=format) + base = self.make_branch_and_tree("base", format=format) + base.commit("foo") + tree = self.make_branch_and_tree("repo", format=format) tree.branch.repository.add_fallback_repository(base.branch.repository) - trans = tree.branch.repository.controldir.get_repository_transport( - None) + trans = tree.branch.repository.controldir.get_repository_transport(None) # This test could be a little cheaper by replacing the packs # attribute on the repository to allow a different pack distribution # and max packs policy - so we are checking the policy is honoured # in the test. But for now 11 commits is not a big deal in a single # test. - local_tree = tree.branch.create_checkout('local') + local_tree = tree.branch.create_checkout("local") for x in range(9): - local_tree.commit(f'commit {x}') + local_tree.commit(f"commit {x}") # there should be 9 packs: - index = self.index_class(trans, 'pack-names', None) + index = self.index_class(trans, "pack-names", None) self.assertEqual(9, len(list(index.iter_all_entries()))) # committing one more should coalesce to 1 of 10. - local_tree.commit('commit triggering pack') - index = self.index_class(trans, 'pack-names', None) + local_tree.commit("commit triggering pack") + index = self.index_class(trans, "pack-names", None) self.assertEqual(1, len(list(index.iter_all_entries()))) # packing should not damage data tree = tree.controldir.open_workingtree() - tree.branch.repository.check( - [tree.branch.last_revision()]) + tree.branch.repository.check([tree.branch.last_revision()]) nb_files = 5 # .pack, .rix, .iix, .tix, .six if tree.branch.repository._format.supports_chks: nb_files += 1 # .cix # We should have 10 x nb_files files in the obsolete_packs directory. - obsolete_files = list(trans.list_dir('obsolete_packs')) - self.assertFalse('foo' in obsolete_files) - self.assertFalse('bar' in obsolete_files) + obsolete_files = list(trans.list_dir("obsolete_packs")) + self.assertFalse("foo" in obsolete_files) + self.assertFalse("bar" in obsolete_files) self.assertEqual(10 * nb_files, len(obsolete_files)) # XXX: Todo check packs obsoleted correctly - old packs and indices # in the obsolete_packs directory. large_pack_name = list(index.iter_all_entries())[0][1][0] # finally, committing again should not touch the large pack. 
- local_tree.commit('commit not triggering pack') - index = self.index_class(trans, 'pack-names', None) + local_tree.commit("commit not triggering pack") + index = self.index_class(trans, "pack-names", None) self.assertEqual(2, len(list(index.iter_all_entries()))) pack_names = [node[1][0] for node in index.iter_all_entries()] self.assertTrue(large_pack_name in pack_names) class TestKeyDependencies(TestCaseWithTransport): - def get_format(self): return controldir.format_registry.make_controldir(self.format_name) def create_source_and_target(self): - builder = self.make_branch_builder('source', format=self.get_format()) + builder = self.make_branch_builder("source", format=self.get_format()) builder.start_series() - builder.build_snapshot(None, [ - ('add', ('', b'root-id', 'directory', None))], - revision_id=b'A-id') builder.build_snapshot( - [b'A-id', b'ghost-id'], [], - revision_id=b'B-id', ) + None, [("add", ("", b"root-id", "directory", None))], revision_id=b"A-id" + ) + builder.build_snapshot( + [b"A-id", b"ghost-id"], + [], + revision_id=b"B-id", + ) builder.finish_series() - repo = self.make_repository('target', format=self.get_format()) + repo = self.make_repository("target", format=self.get_format()) b = builder.get_branch() b.lock_read() self.addCleanup(b.unlock) @@ -1010,11 +1030,12 @@ def test_key_dependencies_cleared_on_abort(self): source_repo, target_repo = self.create_source_and_target() target_repo.start_write_group() try: - stream = source_repo.revisions.get_record_stream([(b'B-id',)], - 'unordered', True) + stream = source_repo.revisions.get_record_stream( + [(b"B-id",)], "unordered", True + ) target_repo.revisions.insert_record_stream(stream) key_refs = target_repo.revisions._index._key_dependencies - self.assertEqual([(b'B-id',)], sorted(key_refs.get_referrers())) + self.assertEqual([(b"B-id",)], sorted(key_refs.get_referrers())) finally: target_repo.abort_write_group() self.assertEqual([], sorted(key_refs.get_referrers())) @@ -1023,11 +1044,12 @@ def test_key_dependencies_cleared_on_suspend(self): source_repo, target_repo = self.create_source_and_target() target_repo.start_write_group() try: - stream = source_repo.revisions.get_record_stream([(b'B-id',)], - 'unordered', True) + stream = source_repo.revisions.get_record_stream( + [(b"B-id",)], "unordered", True + ) target_repo.revisions.insert_record_stream(stream) key_refs = target_repo.revisions._index._key_dependencies - self.assertEqual([(b'B-id',)], sorted(key_refs.get_referrers())) + self.assertEqual([(b"B-id",)], sorted(key_refs.get_referrers())) finally: target_repo.suspend_write_group() self.assertEqual([], sorted(key_refs.get_referrers())) @@ -1038,27 +1060,28 @@ def test_key_dependencies_cleared_on_commit(self): try: # Copy all texts, inventories, and chks so that nothing is missing # for revision B-id. 
- for vf_name in ['texts', 'chk_bytes', 'inventories']: + for vf_name in ["texts", "chk_bytes", "inventories"]: source_vf = getattr(source_repo, vf_name, None) if source_vf is None: continue target_vf = getattr(target_repo, vf_name) stream = source_vf.get_record_stream( - source_vf.keys(), 'unordered', True) + source_vf.keys(), "unordered", True + ) target_vf.insert_record_stream(stream) # Copy just revision B-id stream = source_repo.revisions.get_record_stream( - [(b'B-id',)], 'unordered', True) + [(b"B-id",)], "unordered", True + ) target_repo.revisions.insert_record_stream(stream) key_refs = target_repo.revisions._index._key_dependencies - self.assertEqual([(b'B-id',)], sorted(key_refs.get_referrers())) + self.assertEqual([(b"B-id",)], sorted(key_refs.get_referrers())) finally: target_repo.commit_write_group() self.assertEqual([], sorted(key_refs.get_referrers())) class TestSmartServerAutopack(TestCaseWithTransport): - def setUp(self): super().setUp() # Create a smart server that publishes whatever the backing VFS server @@ -1067,7 +1090,8 @@ def setUp(self): self.start_server(self.smart_server, self.get_server()) # Log all HPSS calls into self.hpss_calls. client._SmartClient.hooks.install_named_hook( - 'call', self.capture_hpss_call, None) + "call", self.capture_hpss_call, None + ) self.hpss_calls = [] def capture_hpss_call(self, params): @@ -1079,24 +1103,30 @@ def get_format(self): def test_autopack_or_streaming_rpc_is_used_when_using_hpss(self): # Make local and remote repos format = self.get_format() - tree = self.make_branch_and_tree('local', format=format) - self.make_branch_and_tree('remote', format=format) - remote_branch_url = self.smart_server.get_url() + 'remote' - remote_branch = controldir.ControlDir.open( - remote_branch_url).open_branch() + tree = self.make_branch_and_tree("local", format=format) + self.make_branch_and_tree("remote", format=format) + remote_branch_url = self.smart_server.get_url() + "remote" + remote_branch = controldir.ControlDir.open(remote_branch_url).open_branch() # Make 9 local revisions, and push them one at a time to the remote # repo to produce 9 pack files. 
for x in range(9): - tree.commit(f'commit {x}') + tree.commit(f"commit {x}") tree.branch.push(remote_branch) # Make one more push to trigger an autopack self.hpss_calls = [] - tree.commit('commit triggering pack') + tree.commit("commit triggering pack") tree.branch.push(remote_branch) - autopack_calls = len([call for call in self.hpss_calls if call - == b'PackRepository.autopack']) - streaming_calls = len([call for call in self.hpss_calls if call in - (b'Repository.insert_stream', b'Repository.insert_stream_1.19')]) + autopack_calls = len( + [call for call in self.hpss_calls if call == b"PackRepository.autopack"] + ) + streaming_calls = len( + [ + call + for call in self.hpss_calls + if call + in (b"Repository.insert_stream", b"Repository.insert_stream_1.19") + ] + ) if autopack_calls: # Non streaming server self.assertEqual(1, autopack_calls) @@ -1114,39 +1144,52 @@ def load_tests(loader, basic_tests, pattern): # these give the bzrdir canned format name, and the repository on-disk # format string scenarios_params = [ - {'format_name': 'pack-0.92', - 'format_string': "Bazaar pack repository format 1 (needs bzr 0.92)\n", - 'format_supports_external_lookups': False, - 'index_class': GraphIndex}, - {'format_name': 'pack-0.92-subtree', - 'format_string': "Bazaar pack repository format 1 " - "with subtree support (needs bzr 0.92)\n", - 'format_supports_external_lookups': False, - 'index_class': GraphIndex}, - {'format_name': '1.6', - 'format_string': "Bazaar RepositoryFormatKnitPack5 (bzr 1.6)\n", - 'format_supports_external_lookups': True, - 'index_class': GraphIndex}, - {'format_name': '1.6.1-rich-root', - 'format_string': "Bazaar RepositoryFormatKnitPack5RichRoot " - "(bzr 1.6.1)\n", - 'format_supports_external_lookups': True, - 'index_class': GraphIndex}, - {'format_name': '1.9', - 'format_string': "Bazaar RepositoryFormatKnitPack6 (bzr 1.9)\n", - 'format_supports_external_lookups': True, - 'index_class': BTreeGraphIndex}, - {'format_name': '1.9-rich-root', - 'format_string': "Bazaar RepositoryFormatKnitPack6RichRoot " - "(bzr 1.9)\n", - 'format_supports_external_lookups': True, - 'index_class': BTreeGraphIndex}, - {'format_name': '2a', - 'format_string': "Bazaar repository format 2a " - "(needs bzr 1.16 or later)\n", - 'format_supports_external_lookups': True, - 'index_class': BTreeGraphIndex}, - ] + { + "format_name": "pack-0.92", + "format_string": "Bazaar pack repository format 1 (needs bzr 0.92)\n", + "format_supports_external_lookups": False, + "index_class": GraphIndex, + }, + { + "format_name": "pack-0.92-subtree", + "format_string": "Bazaar pack repository format 1 " + "with subtree support (needs bzr 0.92)\n", + "format_supports_external_lookups": False, + "index_class": GraphIndex, + }, + { + "format_name": "1.6", + "format_string": "Bazaar RepositoryFormatKnitPack5 (bzr 1.6)\n", + "format_supports_external_lookups": True, + "index_class": GraphIndex, + }, + { + "format_name": "1.6.1-rich-root", + "format_string": "Bazaar RepositoryFormatKnitPack5RichRoot " + "(bzr 1.6.1)\n", + "format_supports_external_lookups": True, + "index_class": GraphIndex, + }, + { + "format_name": "1.9", + "format_string": "Bazaar RepositoryFormatKnitPack6 (bzr 1.9)\n", + "format_supports_external_lookups": True, + "index_class": BTreeGraphIndex, + }, + { + "format_name": "1.9-rich-root", + "format_string": "Bazaar RepositoryFormatKnitPack6RichRoot " "(bzr 1.9)\n", + "format_supports_external_lookups": True, + "index_class": BTreeGraphIndex, + }, + { + "format_name": "2a", + "format_string": "Bazaar 
repository format 2a " + "(needs bzr 1.16 or later)\n", + "format_supports_external_lookups": True, + "index_class": BTreeGraphIndex, + }, + ] # name of the scenario is the format name - scenarios = [(s['format_name'], s) for s in scenarios_params] + scenarios = [(s["format_name"], s) for s in scenarios_params] return tests.multiply_tests(basic_tests, scenarios, loader.suiteClass()) diff --git a/breezy/bzr/tests/per_repository_chk/__init__.py b/breezy/bzr/tests/per_repository_chk/__init__.py index 512ec6d54f..13e2722701 100644 --- a/breezy/bzr/tests/per_repository_chk/__init__.py +++ b/breezy/bzr/tests/per_repository_chk/__init__.py @@ -36,7 +36,6 @@ class TestCaseWithRepositoryCHK(TestCaseWithRepository): - def make_repository(self, path, format=None): TestCaseWithRepository.make_repository(self, path, format=format) return repository.Repository.open(self.get_transport(path).base) @@ -46,7 +45,7 @@ def load_tests(loader, standard_tests, pattern): supported_scenarios = [] unsupported_scenarios = [] for test_name, scenario_info in all_repository_format_scenarios(): - format = scenario_info['repository_format'] + format = scenario_info["repository_format"] # For remote repositories, we test both with, and without a backing chk # capable format: change the format we use to create the repo to direct # formats, and then the overridden make_repository in @@ -54,22 +53,22 @@ def load_tests(loader, standard_tests, pattern): # with the chosen backing format. if isinstance(format, remote.RemoteRepositoryFormat): with_support = dict(scenario_info) - with_support['repository_format'] = RepositoryFormat2a() - supported_scenarios.append( - (test_name + "(Supported)", with_support)) + with_support["repository_format"] = RepositoryFormat2a() + supported_scenarios.append((test_name + "(Supported)", with_support)) no_support = dict(scenario_info) - no_support['repository_format'] = RepositoryFormatKnitPack5() - unsupported_scenarios.append( - (test_name + "(Not Supported)", no_support)) + no_support["repository_format"] = RepositoryFormatKnitPack5() + unsupported_scenarios.append((test_name + "(Not Supported)", no_support)) elif format.supports_chks: supported_scenarios.append((test_name, scenario_info)) else: unsupported_scenarios.append((test_name, scenario_info)) result = loader.suiteClass() - supported_tests = loader.loadTestsFromModuleNames([ - 'breezy.bzr.tests.per_repository_chk.test_supported']) - unsupported_tests = loader.loadTestsFromModuleNames([ - 'breezy.bzr.tests.per_repository_chk.test_unsupported']) + supported_tests = loader.loadTestsFromModuleNames( + ["breezy.bzr.tests.per_repository_chk.test_supported"] + ) + unsupported_tests = loader.loadTestsFromModuleNames( + ["breezy.bzr.tests.per_repository_chk.test_unsupported"] + ) multiply_tests(supported_tests, supported_scenarios, result) multiply_tests(unsupported_tests, unsupported_scenarios, result) return result diff --git a/breezy/bzr/tests/per_repository_chk/test_supported.py b/breezy/bzr/tests/per_repository_chk/test_supported.py index ce50936b4f..8c3df66277 100644 --- a/breezy/bzr/tests/per_repository_chk/test_supported.py +++ b/breezy/bzr/tests/per_repository_chk/test_supported.py @@ -26,41 +26,49 @@ class TestCHKSupport(TestCaseWithRepositoryCHK): - def test_chk_bytes_attribute_is_VersionedFiles(self): - repo = self.make_repository('.') + repo = self.make_repository(".") self.assertIsInstance(repo.chk_bytes, VersionedFiles) def test_add_bytes_to_chk_bytes_store(self): - repo = self.make_repository('.') + repo = 
self.make_repository(".") with repo.lock_write(), repository.WriteGroup(repo): sha1, len, _ = repo.chk_bytes.add_lines( - (None,), None, [b"foo\n", b"bar\n"], random_id=True) - self.assertEqual( - b'4e48e2c9a3d2ca8a708cb0cc545700544efb5021', sha1) + (None,), None, [b"foo\n", b"bar\n"], random_id=True + ) + self.assertEqual(b"4e48e2c9a3d2ca8a708cb0cc545700544efb5021", sha1) self.assertEqual( - {(b'sha1:4e48e2c9a3d2ca8a708cb0cc545700544efb5021',)}, - repo.chk_bytes.keys()) + {(b"sha1:4e48e2c9a3d2ca8a708cb0cc545700544efb5021",)}, + repo.chk_bytes.keys(), + ) # And after an unlock/lock pair with repo.lock_read(): self.assertEqual( - {(b'sha1:4e48e2c9a3d2ca8a708cb0cc545700544efb5021',)}, - repo.chk_bytes.keys()) + {(b"sha1:4e48e2c9a3d2ca8a708cb0cc545700544efb5021",)}, + repo.chk_bytes.keys(), + ) # and reopening repo = repo.controldir.open_repository() with repo.lock_read(): self.assertEqual( - {(b'sha1:4e48e2c9a3d2ca8a708cb0cc545700544efb5021',)}, - repo.chk_bytes.keys()) + {(b"sha1:4e48e2c9a3d2ca8a708cb0cc545700544efb5021",)}, + repo.chk_bytes.keys(), + ) def test_pack_preserves_chk_bytes_store(self): leaf_lines = [b"chkleaf:\n", b"0\n", b"1\n", b"0\n", b"\n"] leaf_sha1 = osutils.sha_strings(leaf_lines) - node_lines = [b"chknode:\n", b"0\n", b"1\n", b"1\n", b"foo\n", - b"\x00sha1:%s\n" % (leaf_sha1,)] + node_lines = [ + b"chknode:\n", + b"0\n", + b"1\n", + b"1\n", + b"foo\n", + b"\x00sha1:%s\n" % (leaf_sha1,), + ] node_sha1 = osutils.sha_strings(node_lines) - expected_set = {(b'sha1:' + leaf_sha1,), (b'sha1:' + node_sha1,)} - repo = self.make_repository('.') + expected_set = {(b"sha1:" + leaf_sha1,), (b"sha1:" + node_sha1,)} + repo = self.make_repository(".") with repo.lock_write(): with repository.WriteGroup(repo): # Internal node pointing at a leaf. @@ -76,17 +84,18 @@ def test_pack_preserves_chk_bytes_store(self): self.assertEqual(expected_set, repo.chk_bytes.keys()) def test_chk_bytes_are_fully_buffered(self): - repo = self.make_repository('.') + repo = self.make_repository(".") repo.lock_write() self.addCleanup(repo.unlock) with repository.WriteGroup(repo): - sha1, len, _ = repo.chk_bytes.add_lines((None,), - None, [b"foo\n", b"bar\n"], random_id=True) - self.assertEqual(b'4e48e2c9a3d2ca8a708cb0cc545700544efb5021', - sha1) + sha1, len, _ = repo.chk_bytes.add_lines( + (None,), None, [b"foo\n", b"bar\n"], random_id=True + ) + self.assertEqual(b"4e48e2c9a3d2ca8a708cb0cc545700544efb5021", sha1) self.assertEqual( - {(b'sha1:4e48e2c9a3d2ca8a708cb0cc545700544efb5021',)}, - repo.chk_bytes.keys()) + {(b"sha1:4e48e2c9a3d2ca8a708cb0cc545700544efb5021",)}, + repo.chk_bytes.keys(), + ) # This may not always be correct if we change away from BTreeGraphIndex # in the future. But for now, lets check that chk_bytes are fully # buffered @@ -121,37 +130,41 @@ def test_missing_chk_root_for_inventory(self): """commit_write_group fails with BzrCheckError when the chk root record for a new inventory is missing. 
""" - repo = self.make_repository('damaged-repo') - builder = self.make_branch_builder('simple-branch') - builder.build_snapshot(None, [ - ('add', ('', b'root-id', 'directory', None)), - ('add', ('file', b'file-id', 'file', b'content\n'))], - revision_id=b'A-id') + repo = self.make_repository("damaged-repo") + builder = self.make_branch_builder("simple-branch") + builder.build_snapshot( + None, + [ + ("add", ("", b"root-id", "directory", None)), + ("add", ("file", b"file-id", "file", b"content\n")), + ], + revision_id=b"A-id", + ) b = builder.get_branch() b.lock_read() self.addCleanup(b.unlock) repo.lock_write() repo.start_write_group() # Now, add the objects manually - text_keys = [(b'file-id', b'A-id'), (b'root-id', b'A-id')] + text_keys = [(b"file-id", b"A-id"), (b"root-id", b"A-id")] # Directly add the texts, inventory, and revision object for 'A-id' -- # but don't add the chk_bytes. src_repo = b.repository - repo.texts.insert_record_stream(src_repo.texts.get_record_stream( - text_keys, 'unordered', True)) + repo.texts.insert_record_stream( + src_repo.texts.get_record_stream(text_keys, "unordered", True) + ) repo.inventories.insert_record_stream( - src_repo.inventories.get_record_stream( - [(b'A-id',)], 'unordered', True)) + src_repo.inventories.get_record_stream([(b"A-id",)], "unordered", True) + ) repo.revisions.insert_record_stream( - src_repo.revisions.get_record_stream( - [(b'A-id',)], 'unordered', True)) + src_repo.revisions.get_record_stream([(b"A-id",)], "unordered", True) + ) # Make sure the presence of the missing data in a fallback does not # avoid the error. repo.add_fallback_repository(b.repository) self.assertRaises(errors.BzrCheckError, repo.commit_write_group) reopened_repo = self.reopen_repo_and_resume_write_group(repo) - self.assertRaises( - errors.BzrCheckError, reopened_repo.commit_write_group) + self.assertRaises(errors.BzrCheckError, reopened_repo.commit_write_group) reopened_repo.abort_write_group() def test_missing_chk_root_for_unchanged_inventory(self): @@ -166,22 +179,26 @@ def test_missing_chk_root_for_unchanged_inventory(self): (In principle the chk records are unnecessary in this case, but in practice bzr 2.0rc1 (at least) expects to find them.) """ - repo = self.make_repository('damaged-repo') + repo = self.make_repository("damaged-repo") # Make a branch where the last two revisions have identical # inventories. - builder = self.make_branch_builder('simple-branch') - builder.build_snapshot(None, [ - ('add', ('', b'root-id', 'directory', None)), - ('add', ('file', b'file-id', 'file', b'content\n'))], - revision_id=b'A-id') - builder.build_snapshot(None, [], revision_id=b'B-id') - builder.build_snapshot(None, [], revision_id=b'C-id') + builder = self.make_branch_builder("simple-branch") + builder.build_snapshot( + None, + [ + ("add", ("", b"root-id", "directory", None)), + ("add", ("file", b"file-id", "file", b"content\n")), + ], + revision_id=b"A-id", + ) + builder.build_snapshot(None, [], revision_id=b"B-id") + builder.build_snapshot(None, [], revision_id=b"C-id") b = builder.get_branch() b.lock_read() self.addCleanup(b.unlock) # check our setup: B-id and C-id should have identical chk root keys. 
- inv_b = b.repository.get_inventory(b'B-id') - inv_c = b.repository.get_inventory(b'C-id') + inv_b = b.repository.get_inventory(b"B-id") + inv_c = b.repository.get_inventory(b"C-id") if not isinstance(repo, RemoteRepository): # Remote repositories always return plain inventories self.assertEqual(inv_b.id_to_entry.key(), inv_c.id_to_entry.key()) @@ -195,72 +212,75 @@ def test_missing_chk_root_for_unchanged_inventory(self): src_repo = b.repository repo.inventories.insert_record_stream( src_repo.inventories.get_record_stream( - [(b'B-id',), (b'C-id',)], 'unordered', True)) + [(b"B-id",), (b"C-id",)], "unordered", True + ) + ) repo.revisions.insert_record_stream( - src_repo.revisions.get_record_stream( - [(b'C-id',)], 'unordered', True)) + src_repo.revisions.get_record_stream([(b"C-id",)], "unordered", True) + ) # Make sure the presence of the missing data in a fallback does not # avoid the error. repo.add_fallback_repository(b.repository) self.assertRaises(errors.BzrCheckError, repo.commit_write_group) reopened_repo = self.reopen_repo_and_resume_write_group(repo) - self.assertRaises( - errors.BzrCheckError, reopened_repo.commit_write_group) + self.assertRaises(errors.BzrCheckError, reopened_repo.commit_write_group) reopened_repo.abort_write_group() def test_missing_chk_leaf_for_inventory(self): """commit_write_group fails with BzrCheckError when the chk root record for a parent inventory of a new revision is missing. """ - repo = self.make_repository('damaged-repo') + repo = self.make_repository("damaged-repo") if isinstance(repo, RemoteRepository): - raise TestNotApplicable( - "Unable to obtain CHKInventory from remote repo") + raise TestNotApplicable("Unable to obtain CHKInventory from remote repo") b = self.make_branch_with_multiple_chk_nodes() src_repo = b.repository src_repo.lock_read() self.addCleanup(src_repo.unlock) # Now, manually insert objects for a stacked repo with only revision # C-id, *except* drop the non-root chk records. - inv_b = src_repo.get_inventory(b'B-id') - inv_c = src_repo.get_inventory(b'C-id') + inv_b = src_repo.get_inventory(b"B-id") + inv_c = src_repo.get_inventory(b"C-id") chk_root_keys_only = [ - inv_b.id_to_entry.key(), inv_b.parent_id_basename_to_file_id.key(), - inv_c.id_to_entry.key(), inv_c.parent_id_basename_to_file_id.key()] + inv_b.id_to_entry.key(), + inv_b.parent_id_basename_to_file_id.key(), + inv_c.id_to_entry.key(), + inv_c.parent_id_basename_to_file_id.key(), + ] all_chks = src_repo.chk_bytes.keys() for key_to_drop in all_chks.difference(chk_root_keys_only): all_chks.discard(key_to_drop) repo.lock_write() repo.start_write_group() repo.chk_bytes.insert_record_stream( - src_repo.chk_bytes.get_record_stream( - all_chks, 'unordered', True)) + src_repo.chk_bytes.get_record_stream(all_chks, "unordered", True) + ) repo.texts.insert_record_stream( - src_repo.texts.get_record_stream( - src_repo.texts.keys(), 'unordered', True)) + src_repo.texts.get_record_stream(src_repo.texts.keys(), "unordered", True) + ) repo.inventories.insert_record_stream( src_repo.inventories.get_record_stream( - [(b'B-id',), (b'C-id',)], 'unordered', True)) + [(b"B-id",), (b"C-id",)], "unordered", True + ) + ) repo.revisions.insert_record_stream( - src_repo.revisions.get_record_stream( - [(b'C-id',)], 'unordered', True)) + src_repo.revisions.get_record_stream([(b"C-id",)], "unordered", True) + ) # Make sure the presence of the missing data in a fallback does not # avoid the error. 
repo.add_fallback_repository(b.repository) self.assertRaises(errors.BzrCheckError, repo.commit_write_group) reopened_repo = self.reopen_repo_and_resume_write_group(repo) - self.assertRaises( - errors.BzrCheckError, reopened_repo.commit_write_group) + self.assertRaises(errors.BzrCheckError, reopened_repo.commit_write_group) reopened_repo.abort_write_group() def test_missing_chk_root_for_parent_inventory(self): """commit_write_group fails with BzrCheckError when the chk root record for a parent inventory of a new revision is missing. """ - repo = self.make_repository('damaged-repo') + repo = self.make_repository("damaged-repo") if isinstance(repo, RemoteRepository): - raise TestNotApplicable( - "Unable to obtain CHKInventory from remote repo") + raise TestNotApplicable("Unable to obtain CHKInventory from remote repo") b = self.make_branch_with_multiple_chk_nodes() b.lock_read() self.addCleanup(b.unlock) @@ -269,55 +289,67 @@ def test_missing_chk_root_for_parent_inventory(self): # We need (b'revisions', b'C-id'), (b'inventories', b'C-id'), # (b'inventories', b'B-id'), and the corresponding chk roots for those # inventories. - inv_c = b.repository.get_inventory(b'C-id') + inv_c = b.repository.get_inventory(b"C-id") chk_keys_for_c_only = [ - inv_c.id_to_entry.key(), inv_c.parent_id_basename_to_file_id.key()] + inv_c.id_to_entry.key(), + inv_c.parent_id_basename_to_file_id.key(), + ] repo.lock_write() repo.start_write_group() src_repo = b.repository repo.chk_bytes.insert_record_stream( - src_repo.chk_bytes.get_record_stream( - chk_keys_for_c_only, 'unordered', True)) + src_repo.chk_bytes.get_record_stream(chk_keys_for_c_only, "unordered", True) + ) repo.inventories.insert_record_stream( src_repo.inventories.get_record_stream( - [(b'B-id',), (b'C-id',)], 'unordered', True)) + [(b"B-id",), (b"C-id",)], "unordered", True + ) + ) repo.revisions.insert_record_stream( - src_repo.revisions.get_record_stream( - [(b'C-id',)], 'unordered', True)) + src_repo.revisions.get_record_stream([(b"C-id",)], "unordered", True) + ) # Make sure the presence of the missing data in a fallback does not # avoid the error. repo.add_fallback_repository(b.repository) self.assertRaises(errors.BzrCheckError, repo.commit_write_group) reopened_repo = self.reopen_repo_and_resume_write_group(repo) - self.assertRaises( - errors.BzrCheckError, reopened_repo.commit_write_group) + self.assertRaises(errors.BzrCheckError, reopened_repo.commit_write_group) reopened_repo.abort_write_group() def make_branch_with_multiple_chk_nodes(self): # add and modify files with very long file-ids, so that the chk map # will need more than just a root node. 
- builder = self.make_branch_builder('simple-branch') + builder = self.make_branch_builder("simple-branch") file_adds = [] file_modifies = [] - for char in 'abc': + for char in "abc": name = char * 10000 file_adds.append( - ('add', ('file-' + name, (f'file-{name}-id').encode(), 'file', - f'content {name}\n'.encode()))) + ( + "add", + ( + "file-" + name, + (f"file-{name}-id").encode(), + "file", + f"content {name}\n".encode(), + ), + ) + ) file_modifies.append( - ('modify', ('file-' + name, - f'new content {name}\n'.encode()))) - builder.build_snapshot(None, [ - ('add', ('', b'root-id', 'directory', None))] + - file_adds, - revision_id=b'A-id') - builder.build_snapshot(None, [], revision_id=b'B-id') - builder.build_snapshot(None, file_modifies, revision_id=b'C-id') + ("modify", ("file-" + name, f"new content {name}\n".encode())) + ) + builder.build_snapshot( + None, + [("add", ("", b"root-id", "directory", None))] + file_adds, + revision_id=b"A-id", + ) + builder.build_snapshot(None, [], revision_id=b"B-id") + builder.build_snapshot(None, file_modifies, revision_id=b"C-id") return builder.get_branch() def test_missing_text_record(self): """commit_write_group fails with BzrCheckError when a text is missing.""" - repo = self.make_repository('damaged-repo') + repo = self.make_repository("damaged-repo") b = self.make_branch_with_multiple_chk_nodes() src_repo = b.repository src_repo.lock_read() @@ -325,26 +357,29 @@ def test_missing_text_record(self): # Now, manually insert objects for a stacked repo with only revision # C-id, *except* drop one changed text. all_texts = src_repo.texts.keys() - all_texts.remove((b'file-%s-id' % (b'c' * 10000,), b'C-id')) + all_texts.remove((b"file-%s-id" % (b"c" * 10000,), b"C-id")) repo.lock_write() repo.start_write_group() repo.chk_bytes.insert_record_stream( src_repo.chk_bytes.get_record_stream( - src_repo.chk_bytes.keys(), 'unordered', True)) + src_repo.chk_bytes.keys(), "unordered", True + ) + ) repo.texts.insert_record_stream( - src_repo.texts.get_record_stream( - all_texts, 'unordered', True)) + src_repo.texts.get_record_stream(all_texts, "unordered", True) + ) repo.inventories.insert_record_stream( src_repo.inventories.get_record_stream( - [(b'B-id',), (b'C-id',)], 'unordered', True)) + [(b"B-id",), (b"C-id",)], "unordered", True + ) + ) repo.revisions.insert_record_stream( - src_repo.revisions.get_record_stream( - [(b'C-id',)], 'unordered', True)) + src_repo.revisions.get_record_stream([(b"C-id",)], "unordered", True) + ) # Make sure the presence of the missing data in a fallback does not # avoid the error. 
repo.add_fallback_repository(b.repository) self.assertRaises(errors.BzrCheckError, repo.commit_write_group) reopened_repo = self.reopen_repo_and_resume_write_group(repo) - self.assertRaises( - errors.BzrCheckError, reopened_repo.commit_write_group) + self.assertRaises(errors.BzrCheckError, reopened_repo.commit_write_group) reopened_repo.abort_write_group() diff --git a/breezy/bzr/tests/per_repository_chk/test_unsupported.py b/breezy/bzr/tests/per_repository_chk/test_unsupported.py index 183d730fd5..bc71adfcc7 100644 --- a/breezy/bzr/tests/per_repository_chk/test_unsupported.py +++ b/breezy/bzr/tests/per_repository_chk/test_unsupported.py @@ -24,7 +24,6 @@ class TestNoCHKSupport(TestCaseWithRepositoryCHK): - def test_chk_bytes_attribute_is_None(self): - repo = self.make_repository('.') + repo = self.make_repository(".") self.assertEqual(None, repo.chk_bytes) diff --git a/breezy/bzr/tests/per_repository_vf/__init__.py b/breezy/bzr/tests/per_repository_vf/__init__.py index 5b9a8a3a82..6bd0a69c2a 100644 --- a/breezy/bzr/tests/per_repository_vf/__init__.py +++ b/breezy/bzr/tests/per_repository_vf/__init__.py @@ -29,7 +29,7 @@ def all_repository_vf_format_scenarios(): scenarios = [] for test_name, scenario_info in all_repository_format_scenarios(): - format = scenario_info['repository_format'] + format = scenario_info["repository_format"] if format.supports_full_versioned_files: scenarios.append((test_name, scenario_info)) return scenarios @@ -37,19 +37,20 @@ def all_repository_vf_format_scenarios(): def load_tests(loader, basic_tests, pattern): testmod_names = [ - 'test_add_inventory_by_delta', - 'test_check', - 'test_check_reconcile', - 'test_find_text_key_references', - 'test__generate_text_key_index', - 'test_fetch', - 'test_fileid_involved', - 'test_merge_directive', - 'test_reconcile', - 'test_refresh_data', - 'test_repository', - 'test_write_group', - ] - basic_tests.addTest(loader.loadTestsFromModuleNames( - [f"{__name__}.{tmn}" for tmn in testmod_names])) + "test_add_inventory_by_delta", + "test_check", + "test_check_reconcile", + "test_find_text_key_references", + "test__generate_text_key_index", + "test_fetch", + "test_fileid_involved", + "test_merge_directive", + "test_reconcile", + "test_refresh_data", + "test_repository", + "test_write_group", + ] + basic_tests.addTest( + loader.loadTestsFromModuleNames([f"{__name__}.{tmn}" for tmn in testmod_names]) + ) return basic_tests diff --git a/breezy/bzr/tests/per_repository_vf/helpers.py b/breezy/bzr/tests/per_repository_vf/helpers.py index 2ac65ca91d..4ad48d898c 100644 --- a/breezy/bzr/tests/per_repository_vf/helpers.py +++ b/breezy/bzr/tests/per_repository_vf/helpers.py @@ -25,7 +25,6 @@ class TestCaseWithBrokenRevisionIndex(TestCaseWithRepository): - def make_repo_with_extra_ghost_index(self): """Make a corrupt repository. @@ -42,28 +41,32 @@ def make_repo_with_extra_ghost_index(self): # pretty deprecated. Ideally these tests should apply to any repo # where repo.revision_graph_can_have_wrong_parents() is True, but # at the moment we only know how to corrupt knit repos. 
- raise TestNotApplicable( - f"{self.repository_format} isn't a knit format") + raise TestNotApplicable(f"{self.repository_format} isn't a knit format") - repo = self.make_repository('broken') + repo = self.make_repository("broken") with repo.lock_write(), WriteGroup(repo): - inv = inventory.Inventory(revision_id=b'revision-id') - inv.root.revision = b'revision-id' - inv_sha1 = repo.add_inventory(b'revision-id', inv, []) + inv = inventory.Inventory(revision_id=b"revision-id") + inv.root.revision = b"revision-id" + inv_sha1 = repo.add_inventory(b"revision-id", inv, []) if repo.supports_rich_root(): root_id = inv.root.file_id - repo.texts.add_lines((root_id, b'revision-id'), [], []) - revision = _mod_revision.Revision(b'revision-id', - committer='jrandom@example.com', timestamp=0, - inventory_sha1=inv_sha1, timezone=0, message='message', - properties={}, - parent_ids=[]) + repo.texts.add_lines((root_id, b"revision-id"), [], []) + revision = _mod_revision.Revision( + b"revision-id", + committer="jrandom@example.com", + timestamp=0, + inventory_sha1=inv_sha1, + timezone=0, + message="message", + properties={}, + parent_ids=[], + ) # Manually add the revision text using the RevisionStore API, with # bad parents. lines = repo._revision_serializer.write_revision_to_lines(revision) - repo.revisions.add_lines((revision.revision_id,), - [(b'incorrect-parent',)], - lines) + repo.revisions.add_lines( + (revision.revision_id,), [(b"incorrect-parent",)], lines + ) repo.lock_write() self.addCleanup(repo.unlock) diff --git a/breezy/bzr/tests/per_repository_vf/test__generate_text_key_index.py b/breezy/bzr/tests/per_repository_vf/test__generate_text_key_index.py index f9b75fda31..8709173068 100644 --- a/breezy/bzr/tests/per_repository_vf/test__generate_text_key_index.py +++ b/breezy/bzr/tests/per_repository_vf/test__generate_text_key_index.py @@ -29,11 +29,10 @@ class TestGenerateTextKeyIndex(TestCaseWithRepository): - scenarios = all_repository_vf_format_scenarios() def test_empty(self): - repo = self.make_repository('.') + repo = self.make_repository(".") repo.lock_read() self.addCleanup(repo.unlock) self.assertEqual({}, repo._generate_text_key_index()) diff --git a/breezy/bzr/tests/per_repository_vf/test_add_inventory_by_delta.py b/breezy/bzr/tests/per_repository_vf/test_add_inventory_by_delta.py index 9b0c3ac166..28e2e8235b 100644 --- a/breezy/bzr/tests/per_repository_vf/test_add_inventory_by_delta.py +++ b/breezy/bzr/tests/per_repository_vf/test_add_inventory_by_delta.py @@ -29,10 +29,9 @@ class TestAddInventoryByDelta(TestCaseWithRepository): - scenarios = all_repository_vf_format_scenarios() - def _get_repo_in_write_group(self, path='repository'): + def _get_repo_in_write_group(self, path="repository"): repo = self.make_repository(path) repo.lock_write() self.addCleanup(repo.unlock) @@ -42,27 +41,38 @@ def _get_repo_in_write_group(self, path='repository'): def test_basis_missing_errors(self): repo = self._get_repo_in_write_group() try: - self.assertRaises(errors.NoSuchRevision, - repo.add_inventory_by_delta, "missing-revision", [], - "new-revision", ["missing-revision"]) + self.assertRaises( + errors.NoSuchRevision, + repo.add_inventory_by_delta, + "missing-revision", + [], + "new-revision", + ["missing-revision"], + ) finally: repo.abort_write_group() def test_not_in_write_group_errors(self): - repo = self.make_repository('repository') + repo = self.make_repository("repository") repo.lock_write() self.addCleanup(repo.unlock) - self.assertRaises(AssertionError, repo.add_inventory_by_delta, - 
"missing-revision", [], "new-revision", ["missing-revision"]) + self.assertRaises( + AssertionError, + repo.add_inventory_by_delta, + "missing-revision", + [], + "new-revision", + ["missing-revision"], + ) def make_inv_delta(self, old, new): """Make an inventory delta from two inventories.""" - by_id = getattr(old, '_byid', None) + by_id = getattr(old, "_byid", None) if by_id is None: old_ids = {entry.file_id for (_n, entry) in old.iter_entries()} else: old_ids = set(by_id) - by_id = getattr(new, '_byid', None) + by_id = getattr(new, "_byid", None) if by_id is None: new_ids = {entry.file_id for (_n, entry) in new.iter_entries()} else: @@ -75,42 +85,42 @@ def make_inv_delta(self, old, new): for file_id in deletes: delta.append((old.id2path(file_id), None, file_id, None)) for file_id in adds: - delta.append((None, new.id2path(file_id), - file_id, new.get_entry(file_id))) + delta.append((None, new.id2path(file_id), file_id, new.get_entry(file_id))) for file_id in common: if old.get_entry(file_id) != new.get_entry(file_id): - delta.append((old.id2path(file_id), new.id2path(file_id), - file_id, new[file_id])) + delta.append( + (old.id2path(file_id), new.id2path(file_id), file_id, new[file_id]) + ) return InventoryDelta(delta) def test_same_validator(self): # Adding an inventory via delta or direct results in the same # validator. - tree = self.make_branch_and_tree('tree') + tree = self.make_branch_and_tree("tree") revid = tree.commit("empty post") # tree.basis_tree() always uses a plain Inventory from the dirstate, we # want the same format inventory as we have in the repository - revtree = tree.branch.repository.revision_tree( - tree.branch.last_revision()) + revtree = tree.branch.repository.revision_tree(tree.branch.last_revision()) tree.basis_tree() revtree.lock_read() self.addCleanup(revtree.unlock) old_inv = tree.branch.repository.revision_tree( - revision.NULL_REVISION).root_inventory + revision.NULL_REVISION + ).root_inventory new_inv = revtree.root_inventory delta = self.make_inv_delta(old_inv, new_inv) - repo_direct = self._get_repo_in_write_group('direct') + repo_direct = self._get_repo_in_write_group("direct") add_validator = repo_direct.add_inventory(revid, new_inv, []) repo_direct.commit_write_group() - repo_delta = self._get_repo_in_write_group('delta') + repo_delta = self._get_repo_in_write_group("delta") try: delta_validator, inv = repo_delta.add_inventory_by_delta( - revision.NULL_REVISION, delta, revid, []) + revision.NULL_REVISION, delta, revid, [] + ) except: repo_delta.abort_write_group() raise else: repo_delta.commit_write_group() self.assertEqual(add_validator, delta_validator) - self.assertEqual(list(new_inv.iter_entries()), - list(inv.iter_entries())) + self.assertEqual(list(new_inv.iter_entries()), list(inv.iter_entries())) diff --git a/breezy/bzr/tests/per_repository_vf/test_check.py b/breezy/bzr/tests/per_repository_vf/test_check.py index dcb19adca2..36d80975f6 100644 --- a/breezy/bzr/tests/per_repository_vf/test_check.py +++ b/breezy/bzr/tests/per_repository_vf/test_check.py @@ -31,7 +31,6 @@ class TestFindInconsistentRevisionParents(TestCaseWithBrokenRevisionIndex): - scenarios = all_repository_vf_format_scenarios() def test__find_inconsistent_revision_parents(self): @@ -40,8 +39,9 @@ def test__find_inconsistent_revision_parents(self): """ repo = self.make_repo_with_extra_ghost_index() self.assertEqual( - [(b'revision-id', (b'incorrect-parent',), ())], - list(repo._find_inconsistent_revision_parents())) + [(b"revision-id", (b"incorrect-parent",), ())], + 
list(repo._find_inconsistent_revision_parents()), + ) def test__check_for_inconsistent_revision_parents(self): """_check_for_inconsistent_revision_parents raises BzrCheckError if @@ -49,17 +49,16 @@ def test__check_for_inconsistent_revision_parents(self): """ repo = self.make_repo_with_extra_ghost_index() self.assertRaises( - errors.BzrCheckError, - repo._check_for_inconsistent_revision_parents) + errors.BzrCheckError, repo._check_for_inconsistent_revision_parents + ) def test__check_for_inconsistent_revision_parents_on_clean_repo(self): """_check_for_inconsistent_revision_parents does nothing if there are no broken revisions. """ - repo = self.make_repository('empty-repo') + repo = self.make_repository("empty-repo") if not repo._format.revision_graph_can_have_wrong_parents: - raise TestNotApplicable( - f'{repo!r} cannot have corrupt revision index.') + raise TestNotApplicable(f"{repo!r} cannot have corrupt revision index.") with repo.lock_read(): repo._check_for_inconsistent_revision_parents() # nothing happens @@ -67,25 +66,26 @@ def test_check_reports_bad_ancestor(self): repo = self.make_repo_with_extra_ghost_index() # XXX: check requires a non-empty revision IDs list, but it ignores the # contents of it! - check_object = repo.check(['ignored']) + check_object = repo.check(["ignored"]) check_object.report_results(verbose=False) - self.assertContainsRe(self.get_log(), - '1 revisions have incorrect parents in the revision index') + self.assertContainsRe( + self.get_log(), "1 revisions have incorrect parents in the revision index" + ) check_object.report_results(verbose=True) self.assertContainsRe( self.get_log(), "revision-id has wrong parents in index: " - r"\(incorrect-parent\) should be \(\)") + r"\(incorrect-parent\) should be \(\)", + ) class TestCallbacks(TestCaseWithRepository): - scenarios = all_repository_vf_format_scenarios() def test_callback_tree_and_branch(self): # use a real tree to get actual refs that will work - tree = self.make_branch_and_tree('foo') - revid = tree.commit('foo') + tree = self.make_branch_and_tree("foo") + revid = tree.commit("foo") tree.lock_read() self.addCleanup(tree.unlock) needed_refs = {} @@ -102,27 +102,26 @@ def test_callback_tree_and_branch(self): self.assertNotEqual([], self.callbacks) def tree_callback(self, refs): - self.callbacks.append(('tree', refs)) + self.callbacks.append(("tree", refs)) return self.tree_check(refs) def branch_callback(self, refs): - self.callbacks.append(('branch', refs)) + self.callbacks.append(("branch", refs)) return self.branch_check(refs) class TestNoSpuriousInconsistentAncestors(TestCaseWithRepository): - scenarios = all_repository_vf_format_scenarios() def test_two_files_different_versions_no_inconsistencies_bug_165071(self): """Two files, with different versions can be clean.""" - tree = self.make_branch_and_tree('.') - self.build_tree(['foo']) - tree.smart_add(['.']) - revid1 = tree.commit('1') - self.build_tree(['bar']) - tree.smart_add(['.']) - revid2 = tree.commit('2') + tree = self.make_branch_and_tree(".") + self.build_tree(["foo"]) + tree.smart_add(["."]) + revid1 = tree.commit("1") + self.build_tree(["bar"]) + tree.smart_add(["."]) + revid2 = tree.commit("2") check_object = tree.branch.repository.check([revid1, revid2]) check_object.report_results(verbose=True) self.assertContainsRe(self.get_log(), "0 unreferenced text versions") diff --git a/breezy/bzr/tests/per_repository_vf/test_check_reconcile.py b/breezy/bzr/tests/per_repository_vf/test_check_reconcile.py index 55b9104686..d81e2ca5f8 100644 --- 
a/breezy/bzr/tests/per_repository_vf/test_check_reconcile.py +++ b/breezy/bzr/tests/per_repository_vf/test_check_reconcile.py @@ -61,13 +61,25 @@ class BrokenRepoScenario: def __init__(self, test_case): self.test_case = test_case - def make_one_file_inventory(self, repo, revision, parents, - inv_revision=None, root_revision=None, - file_contents=None, make_file_version=True): + def make_one_file_inventory( + self, + repo, + revision, + parents, + inv_revision=None, + root_revision=None, + file_contents=None, + make_file_version=True, + ): return self.test_case.make_one_file_inventory( - repo, revision, parents, inv_revision=inv_revision, - root_revision=root_revision, file_contents=file_contents, - make_file_version=make_file_version) + repo, + revision, + parents, + inv_revision=inv_revision, + root_revision=root_revision, + file_contents=file_contents, + make_file_version=make_file_version, + ) def add_revision(self, repo, revision_id, inv, parent_ids): return self.test_case.add_revision(repo, revision_id, inv, parent_ids) @@ -90,10 +102,10 @@ class UndamagedRepositoryScenario(BrokenRepoScenario): """ def all_versions_after_reconcile(self): - return (b'rev1a', ) + return (b"rev1a",) def populated_parents(self): - return (((), b'rev1a'), ) + return (((), b"rev1a"),) def corrected_parents(self): # Same as the populated parents, because there was nothing wrong. @@ -104,23 +116,22 @@ def check_regexes(self, repo): def populate_repository(self, repo): # make rev1a: A well-formed revision, containing 'a-file' - inv = self.make_one_file_inventory( - repo, b'rev1a', [], root_revision=b'rev1a') - self.add_revision(repo, b'rev1a', inv, []) + inv = self.make_one_file_inventory(repo, b"rev1a", [], root_revision=b"rev1a") + self.add_revision(repo, b"rev1a", inv, []) self.versioned_root = repo.supports_rich_root() def repository_text_key_references(self): result = {} if self.versioned_root: - result.update({(b'TREE_ROOT', b'rev1a'): True}) - result.update({(b'a-file-id', b'rev1a'): True}) + result.update({(b"TREE_ROOT", b"rev1a"): True}) + result.update({(b"a-file-id", b"rev1a"): True}) return result def repository_text_keys(self): - return {(b'a-file-id', b'rev1a'): [NULL_REVISION]} + return {(b"a-file-id", b"rev1a"): [NULL_REVISION]} def versioned_repository_text_keys(self): - return {(b'TREE_ROOT', b'rev1a'): [NULL_REVISION]} + return {(b"TREE_ROOT", b"rev1a"): [NULL_REVISION]} class FileParentIsNotInRevisionAncestryScenario(BrokenRepoScenario): @@ -132,62 +143,62 @@ class FileParentIsNotInRevisionAncestryScenario(BrokenRepoScenario): """ def all_versions_after_reconcile(self): - return (b'rev1a', b'rev2') + return (b"rev1a", b"rev2") def populated_parents(self): return ( - ((), b'rev1a'), - ((), b'rev1b'), # Will be gc'd - ((b'rev1a', b'rev1b'), b'rev2')) # Will have parents trimmed + ((), b"rev1a"), + ((), b"rev1b"), # Will be gc'd + ((b"rev1a", b"rev1b"), b"rev2"), + ) # Will have parents trimmed def corrected_parents(self): - return ( - ((), b'rev1a'), - (None, b'rev1b'), - ((b'rev1a',), b'rev2')) + return (((), b"rev1a"), (None, b"rev1b"), ((b"rev1a",), b"rev2")) def check_regexes(self, repo): - return [r"\* a-file-id version rev2 has parents \(rev1a, rev1b\) " - r"but should have \(rev1a\)", - "1 unreferenced text versions", - ] + return [ + r"\* a-file-id version rev2 has parents \(rev1a, rev1b\) " + r"but should have \(rev1a\)", + "1 unreferenced text versions", + ] def populate_repository(self, repo): # make rev1a: A well-formed revision, containing 'a-file' - inv = 
self.make_one_file_inventory( - repo, b'rev1a', [], root_revision=b'rev1a') - self.add_revision(repo, b'rev1a', inv, []) + inv = self.make_one_file_inventory(repo, b"rev1a", [], root_revision=b"rev1a") + self.add_revision(repo, b"rev1a", inv, []) # make rev1b, which has no Revision, but has an Inventory, and # a-file - inv = self.make_one_file_inventory( - repo, b'rev1b', [], root_revision=b'rev1b') - repo.add_inventory(b'rev1b', inv, []) + inv = self.make_one_file_inventory(repo, b"rev1b", [], root_revision=b"rev1b") + repo.add_inventory(b"rev1b", inv, []) # make rev2, with a-file. # a-file has 'rev1b' as an ancestor, even though this is not # mentioned by 'rev1a', making it an unreferenced ancestor - inv = self.make_one_file_inventory( - repo, b'rev2', [b'rev1a', b'rev1b']) - self.add_revision(repo, b'rev2', inv, [b'rev1a']) + inv = self.make_one_file_inventory(repo, b"rev2", [b"rev1a", b"rev1b"]) + self.add_revision(repo, b"rev2", inv, [b"rev1a"]) self.versioned_root = repo.supports_rich_root() def repository_text_key_references(self): result = {} if self.versioned_root: - result.update({(b'TREE_ROOT', b'rev1a'): True, - (b'TREE_ROOT', b'rev2'): True}) - result.update({(b'a-file-id', b'rev1a'): True, - (b'a-file-id', b'rev2'): True}) + result.update( + {(b"TREE_ROOT", b"rev1a"): True, (b"TREE_ROOT", b"rev2"): True} + ) + result.update({(b"a-file-id", b"rev1a"): True, (b"a-file-id", b"rev2"): True}) return result def repository_text_keys(self): - return {(b'a-file-id', b'rev1a'): [NULL_REVISION], - (b'a-file-id', b'rev2'): [(b'a-file-id', b'rev1a')]} + return { + (b"a-file-id", b"rev1a"): [NULL_REVISION], + (b"a-file-id", b"rev2"): [(b"a-file-id", b"rev1a")], + } def versioned_repository_text_keys(self): - return {(b'TREE_ROOT', b'rev1a'): [NULL_REVISION], - (b'TREE_ROOT', b'rev2'): [(b'TREE_ROOT', b'rev1a')]} + return { + (b"TREE_ROOT", b"rev1a"): [NULL_REVISION], + (b"TREE_ROOT", b"rev2"): [(b"TREE_ROOT", b"rev1a")], + } class FileParentHasInaccessibleInventoryScenario(BrokenRepoScenario): @@ -200,28 +211,24 @@ class FileParentHasInaccessibleInventoryScenario(BrokenRepoScenario): """ def all_versions_after_reconcile(self): - return (b'rev2', b'rev3') + return (b"rev2", b"rev3") def populated_parents(self): - return ( - ((), b'rev2'), - ((b'rev1c',), b'rev3')) + return (((), b"rev2"), ((b"rev1c",), b"rev3")) def corrected_parents(self): - return ( - ((), b'rev2'), - ((), b'rev3')) + return (((), b"rev2"), ((), b"rev3")) def check_regexes(self, repo): - return [r"\* a-file-id version rev3 has parents " - r"\(rev1c\) but should have \(\)", - ] + return [ + r"\* a-file-id version rev3 has parents " r"\(rev1c\) but should have \(\)", + ] def populate_repository(self, repo): # make rev2, with a-file # a-file is sane - inv = self.make_one_file_inventory(repo, b'rev2', []) - self.add_revision(repo, b'rev2', inv, []) + inv = self.make_one_file_inventory(repo, b"rev2", []) + self.add_revision(repo, b"rev2", inv, []) # make ghost revision rev1c, with a version of a-file present so # that we generate a knit delta against this version. In real life @@ -229,31 +236,35 @@ def populate_repository(self, repo): # generated against a revision that was present at the time. So # currently we have the full history of a-file present even though # the inventory and revision objects are not. 
- self.make_one_file_inventory(repo, b'rev1c', []) + self.make_one_file_inventory(repo, b"rev1c", []) # make rev3 with a-file # a-file refers to 'rev1c', which is a ghost in this repository, so # a-file cannot have rev1c as its ancestor. - inv = self.make_one_file_inventory(repo, b'rev3', [b'rev1c']) - self.add_revision(repo, b'rev3', inv, [b'rev1c', b'rev1a']) + inv = self.make_one_file_inventory(repo, b"rev3", [b"rev1c"]) + self.add_revision(repo, b"rev3", inv, [b"rev1c", b"rev1a"]) self.versioned_root = repo.supports_rich_root() def repository_text_key_references(self): result = {} if self.versioned_root: - result.update({(b'TREE_ROOT', b'rev2'): True, - (b'TREE_ROOT', b'rev3'): True}) - result.update({(b'a-file-id', b'rev2'): True, - (b'a-file-id', b'rev3'): True}) + result.update( + {(b"TREE_ROOT", b"rev2"): True, (b"TREE_ROOT", b"rev3"): True} + ) + result.update({(b"a-file-id", b"rev2"): True, (b"a-file-id", b"rev3"): True}) return result def repository_text_keys(self): - return {(b'a-file-id', b'rev2'): [NULL_REVISION], - (b'a-file-id', b'rev3'): [NULL_REVISION]} + return { + (b"a-file-id", b"rev2"): [NULL_REVISION], + (b"a-file-id", b"rev3"): [NULL_REVISION], + } def versioned_repository_text_keys(self): - return {(b'TREE_ROOT', b'rev2'): [NULL_REVISION], - (b'TREE_ROOT', b'rev3'): [NULL_REVISION]} + return { + (b"TREE_ROOT", b"rev2"): [NULL_REVISION], + (b"TREE_ROOT", b"rev3"): [NULL_REVISION], + } class FileParentsNotReferencedByAnyInventoryScenario(BrokenRepoScenario): @@ -269,30 +280,32 @@ class FileParentsNotReferencedByAnyInventoryScenario(BrokenRepoScenario): """ def all_versions_after_reconcile(self): - return (b'rev1a', b'rev2c', b'rev4', b'rev5') + return (b"rev1a", b"rev2c", b"rev4", b"rev5") def populated_parents(self): return [ - ((b'rev1a',), b'rev2'), - ((b'rev1a',), b'rev2b'), - ((b'rev2',), b'rev3'), - ((b'rev2',), b'rev4'), - ((b'rev2', b'rev2c'), b'rev5')] + ((b"rev1a",), b"rev2"), + ((b"rev1a",), b"rev2b"), + ((b"rev2",), b"rev3"), + ((b"rev2",), b"rev4"), + ((b"rev2", b"rev2c"), b"rev5"), + ] def corrected_parents(self): return ( # rev2 and rev2b have been removed. - (None, b'rev2'), - (None, b'rev2b'), + (None, b"rev2"), + (None, b"rev2b"), # rev3's accessible parent inventories all have rev1a as the last # modifier. - ((b'rev1a',), b'rev3'), + ((b"rev1a",), b"rev3"), # rev1a features in both rev4's parents but should only appear once # in the result - ((b'rev1a',), b'rev4'), + ((b"rev1a",), b"rev4"), # rev2c is the head of rev1a and rev2c, the inventory provided # per-file last-modified revisions. - ((b'rev2c',), b'rev5')) + ((b"rev2c",), b"rev5"), + ) def check_regexes(self, repo): if repo.supports_rich_root(): @@ -313,20 +326,18 @@ def check_regexes(self, repo): r"a-file-id version rev4 has parents \(rev2\) " r"but should have \(rev1a\)", "%d inconsistent parents" % count, - ] + ] def populate_repository(self, repo): # make rev1a: A well-formed revision, containing 'a-file' - inv = self.make_one_file_inventory( - repo, b'rev1a', [], root_revision=b'rev1a') - self.add_revision(repo, b'rev1a', inv, []) + inv = self.make_one_file_inventory(repo, b"rev1a", [], root_revision=b"rev1a") + self.add_revision(repo, b"rev1a", inv, []) # make rev2, with a-file. # a-file is unmodified from rev1a, and an unreferenced rev2 file # version is present in the repository. 
- self.make_one_file_inventory( - repo, b'rev2', [b'rev1a'], inv_revision=b'rev1a') - self.add_revision(repo, b'rev2', inv, [b'rev1a']) + self.make_one_file_inventory(repo, b"rev2", [b"rev1a"], inv_revision=b"rev1a") + self.add_revision(repo, b"rev2", inv, [b"rev1a"]) # make rev3 with a-file # a-file has 'rev2' as its ancestor, but the revision in 'rev2' was @@ -335,8 +346,8 @@ def populate_repository(self, repo): # ghost, so only the details from rev1a are available for # determining whether a delta is acceptable, or a full is needed, # and what the correct parents are. - inv = self.make_one_file_inventory(repo, b'rev3', [b'rev2']) - self.add_revision(repo, b'rev3', inv, [b'rev1c', b'rev1a']) + inv = self.make_one_file_inventory(repo, b"rev3", [b"rev2"]) + self.add_revision(repo, b"rev3", inv, [b"rev1c", b"rev1a"]) # In rev2b, the true last-modifying-revision of a-file is rev1a, # inherited from rev2, but there is a version rev2b of the file, which @@ -345,8 +356,9 @@ def populate_repository(self, repo): # a-file-rev2b. # ??? This is to test deduplication in fixing rev4 inv = self.make_one_file_inventory( - repo, b'rev2b', [b'rev1a'], inv_revision=b'rev1a') - self.add_revision(repo, b'rev2b', inv, [b'rev1a']) + repo, b"rev2b", [b"rev1a"], inv_revision=b"rev1a" + ) + self.add_revision(repo, b"rev2b", inv, [b"rev1a"]) # rev4 is for testing that when the last modified of a file in # multiple parent revisions is the same, that it only appears once @@ -356,13 +368,13 @@ def populate_repository(self, repo): # a-file, and is a merge of rev2 and rev2b, so it should end up with # a parent of just rev1a - the starting file parents list is simply # completely wrong. - inv = self.make_one_file_inventory(repo, b'rev4', [b'rev2']) - self.add_revision(repo, b'rev4', inv, [b'rev2', b'rev2b']) + inv = self.make_one_file_inventory(repo, b"rev4", [b"rev2"]) + self.add_revision(repo, b"rev4", inv, [b"rev2", b"rev2b"]) # rev2c changes a-file from rev1a, so the version it of a-file it # introduces is a head revision when rev5 is checked. - inv = self.make_one_file_inventory(repo, b'rev2c', [b'rev1a']) - self.add_revision(repo, b'rev2c', inv, [b'rev1a']) + inv = self.make_one_file_inventory(repo, b"rev2c", [b"rev1a"]) + self.add_revision(repo, b"rev2c", inv, [b"rev1a"]) # rev5 descends from rev2 and rev2c; as rev2 does not alter a-file, # but rev2c does, this should use rev2c as the parent for the per @@ -370,44 +382,60 @@ def populate_repository(self, repo): # available, because we use the heads of the revision parents for # the inventory modification revisions of the file to determine the # parents for the per file graph. 
- inv = self.make_one_file_inventory(repo, b'rev5', [b'rev2', b'rev2c']) - self.add_revision(repo, b'rev5', inv, [b'rev2', b'rev2c']) + inv = self.make_one_file_inventory(repo, b"rev5", [b"rev2", b"rev2c"]) + self.add_revision(repo, b"rev5", inv, [b"rev2", b"rev2c"]) self.versioned_root = repo.supports_rich_root() def repository_text_key_references(self): result = {} if self.versioned_root: - result.update({(b'TREE_ROOT', b'rev1a'): True, - (b'TREE_ROOT', b'rev2'): True, - (b'TREE_ROOT', b'rev2b'): True, - (b'TREE_ROOT', b'rev2c'): True, - (b'TREE_ROOT', b'rev3'): True, - (b'TREE_ROOT', b'rev4'): True, - (b'TREE_ROOT', b'rev5'): True}) - result.update({(b'a-file-id', b'rev1a'): True, - (b'a-file-id', b'rev2c'): True, - (b'a-file-id', b'rev3'): True, - (b'a-file-id', b'rev4'): True, - (b'a-file-id', b'rev5'): True}) + result.update( + { + (b"TREE_ROOT", b"rev1a"): True, + (b"TREE_ROOT", b"rev2"): True, + (b"TREE_ROOT", b"rev2b"): True, + (b"TREE_ROOT", b"rev2c"): True, + (b"TREE_ROOT", b"rev3"): True, + (b"TREE_ROOT", b"rev4"): True, + (b"TREE_ROOT", b"rev5"): True, + } + ) + result.update( + { + (b"a-file-id", b"rev1a"): True, + (b"a-file-id", b"rev2c"): True, + (b"a-file-id", b"rev3"): True, + (b"a-file-id", b"rev4"): True, + (b"a-file-id", b"rev5"): True, + } + ) return result def repository_text_keys(self): - return {(b'a-file-id', b'rev1a'): [NULL_REVISION], - (b'a-file-id', b'rev2c'): [(b'a-file-id', b'rev1a')], - (b'a-file-id', b'rev3'): [(b'a-file-id', b'rev1a')], - (b'a-file-id', b'rev4'): [(b'a-file-id', b'rev1a')], - (b'a-file-id', b'rev5'): [(b'a-file-id', b'rev2c')]} + return { + (b"a-file-id", b"rev1a"): [NULL_REVISION], + (b"a-file-id", b"rev2c"): [(b"a-file-id", b"rev1a")], + (b"a-file-id", b"rev3"): [(b"a-file-id", b"rev1a")], + (b"a-file-id", b"rev4"): [(b"a-file-id", b"rev1a")], + (b"a-file-id", b"rev5"): [(b"a-file-id", b"rev2c")], + } def versioned_repository_text_keys(self): - return {(b'TREE_ROOT', b'rev1a'): [NULL_REVISION], - (b'TREE_ROOT', b'rev2'): [(b'TREE_ROOT', b'rev1a')], - (b'TREE_ROOT', b'rev2b'): [(b'TREE_ROOT', b'rev1a')], - (b'TREE_ROOT', b'rev2c'): [(b'TREE_ROOT', b'rev1a')], - (b'TREE_ROOT', b'rev3'): [(b'TREE_ROOT', b'rev1a')], - (b'TREE_ROOT', b'rev4'): - [(b'TREE_ROOT', b'rev2'), (b'TREE_ROOT', b'rev2b')], - (b'TREE_ROOT', b'rev5'): - [(b'TREE_ROOT', b'rev2'), (b'TREE_ROOT', b'rev2c')]} + return { + (b"TREE_ROOT", b"rev1a"): [NULL_REVISION], + (b"TREE_ROOT", b"rev2"): [(b"TREE_ROOT", b"rev1a")], + (b"TREE_ROOT", b"rev2b"): [(b"TREE_ROOT", b"rev1a")], + (b"TREE_ROOT", b"rev2c"): [(b"TREE_ROOT", b"rev1a")], + (b"TREE_ROOT", b"rev3"): [(b"TREE_ROOT", b"rev1a")], + (b"TREE_ROOT", b"rev4"): [ + (b"TREE_ROOT", b"rev2"), + (b"TREE_ROOT", b"rev2b"), + ], + (b"TREE_ROOT", b"rev5"): [ + (b"TREE_ROOT", b"rev2"), + (b"TREE_ROOT", b"rev2c"), + ], + } class UnreferencedFileParentsFromNoOpMergeScenario(BrokenRepoScenario): @@ -417,95 +445,117 @@ class UnreferencedFileParentsFromNoOpMergeScenario(BrokenRepoScenario): """ def all_versions_after_reconcile(self): - return (b'rev1a', b'rev1b', b'rev2', b'rev4') + return (b"rev1a", b"rev1b", b"rev2", b"rev4") def populated_parents(self): return ( - ((), b'rev1a'), - ((), b'rev1b'), - ((b'rev1a', b'rev1b'), b'rev2'), - (None, b'rev3'), - ((b'rev2',), b'rev4'), - ) + ((), b"rev1a"), + ((), b"rev1b"), + ((b"rev1a", b"rev1b"), b"rev2"), + (None, b"rev3"), + ((b"rev2",), b"rev4"), + ) def corrected_parents(self): return ( - ((), b'rev1a'), - ((), b'rev1b'), - ((), b'rev2'), - (None, b'rev3'), - ((b'rev2',), b'rev4'), - ) + 
((), b"rev1a"), + ((), b"rev1b"), + ((), b"rev2"), + (None, b"rev3"), + ((b"rev2",), b"rev4"), + ) def corrected_fulltexts(self): - return [b'rev2'] + return [b"rev2"] def check_regexes(self, repo): return [] def populate_repository(self, repo): # make rev1a: A well-formed revision, containing 'a-file' - inv1a = self.make_one_file_inventory( - repo, b'rev1a', [], root_revision=b'rev1a') - self.add_revision(repo, b'rev1a', inv1a, []) + inv1a = self.make_one_file_inventory(repo, b"rev1a", [], root_revision=b"rev1a") + self.add_revision(repo, b"rev1a", inv1a, []) # make rev1b: A well-formed revision, containing 'a-file' # rev1b of a-file has the exact same contents as rev1a. file_contents = next( - repo.texts.get_record_stream([(b'a-file-id', b'rev1a')], - "unordered", False)).get_bytes_as('fulltext') + repo.texts.get_record_stream([(b"a-file-id", b"rev1a")], "unordered", False) + ).get_bytes_as("fulltext") inv = self.make_one_file_inventory( - repo, b'rev1b', [], root_revision=b'rev1b', - file_contents=file_contents) - self.add_revision(repo, b'rev1b', inv, []) + repo, b"rev1b", [], root_revision=b"rev1b", file_contents=file_contents + ) + self.add_revision(repo, b"rev1b", inv, []) # make rev2, a merge of rev1a and rev1b, with a-file. # a-file is unmodified from rev1a and rev1b, but a new version is # wrongly present anyway. inv = self.make_one_file_inventory( - repo, b'rev2', [b'rev1a', b'rev1b'], inv_revision=b'rev1a', - file_contents=file_contents) - self.add_revision(repo, b'rev2', inv, [b'rev1a', b'rev1b']) + repo, + b"rev2", + [b"rev1a", b"rev1b"], + inv_revision=b"rev1a", + file_contents=file_contents, + ) + self.add_revision(repo, b"rev2", inv, [b"rev1a", b"rev1b"]) # rev3: a-file unchanged from rev2, but wrongly referencing rev2 of the # file in its inventory. inv = self.make_one_file_inventory( - repo, b'rev3', [b'rev2'], inv_revision=b'rev2', - file_contents=file_contents, make_file_version=False) - self.add_revision(repo, b'rev3', inv, [b'rev2']) + repo, + b"rev3", + [b"rev2"], + inv_revision=b"rev2", + file_contents=file_contents, + make_file_version=False, + ) + self.add_revision(repo, b"rev3", inv, [b"rev2"]) # rev4: a modification of a-file on top of rev3. 
- inv = self.make_one_file_inventory(repo, b'rev4', [b'rev2']) - self.add_revision(repo, b'rev4', inv, [b'rev3']) + inv = self.make_one_file_inventory(repo, b"rev4", [b"rev2"]) + self.add_revision(repo, b"rev4", inv, [b"rev3"]) self.versioned_root = repo.supports_rich_root() def repository_text_key_references(self): result = {} if self.versioned_root: - result.update({(b'TREE_ROOT', b'rev1a'): True, - (b'TREE_ROOT', b'rev1b'): True, - (b'TREE_ROOT', b'rev2'): True, - (b'TREE_ROOT', b'rev3'): True, - (b'TREE_ROOT', b'rev4'): True}) - result.update({(b'a-file-id', b'rev1a'): True, - (b'a-file-id', b'rev1b'): True, - (b'a-file-id', b'rev2'): False, - (b'a-file-id', b'rev4'): True}) + result.update( + { + (b"TREE_ROOT", b"rev1a"): True, + (b"TREE_ROOT", b"rev1b"): True, + (b"TREE_ROOT", b"rev2"): True, + (b"TREE_ROOT", b"rev3"): True, + (b"TREE_ROOT", b"rev4"): True, + } + ) + result.update( + { + (b"a-file-id", b"rev1a"): True, + (b"a-file-id", b"rev1b"): True, + (b"a-file-id", b"rev2"): False, + (b"a-file-id", b"rev4"): True, + } + ) return result def repository_text_keys(self): - return {(b'a-file-id', b'rev1a'): [NULL_REVISION], - (b'a-file-id', b'rev1b'): [NULL_REVISION], - (b'a-file-id', b'rev2'): [NULL_REVISION], - (b'a-file-id', b'rev4'): [(b'a-file-id', b'rev2')]} + return { + (b"a-file-id", b"rev1a"): [NULL_REVISION], + (b"a-file-id", b"rev1b"): [NULL_REVISION], + (b"a-file-id", b"rev2"): [NULL_REVISION], + (b"a-file-id", b"rev4"): [(b"a-file-id", b"rev2")], + } def versioned_repository_text_keys(self): - return {(b'TREE_ROOT', b'rev1a'): [NULL_REVISION], - (b'TREE_ROOT', b'rev1b'): [NULL_REVISION], - (b'TREE_ROOT', b'rev2'): - [(b'TREE_ROOT', b'rev1a'), (b'TREE_ROOT', b'rev1b')], - (b'TREE_ROOT', b'rev3'): [(b'TREE_ROOT', b'rev2')], - (b'TREE_ROOT', b'rev4'): [(b'TREE_ROOT', b'rev3')]} + return { + (b"TREE_ROOT", b"rev1a"): [NULL_REVISION], + (b"TREE_ROOT", b"rev1b"): [NULL_REVISION], + (b"TREE_ROOT", b"rev2"): [ + (b"TREE_ROOT", b"rev1a"), + (b"TREE_ROOT", b"rev1b"), + ], + (b"TREE_ROOT", b"rev3"): [(b"TREE_ROOT", b"rev2")], + (b"TREE_ROOT", b"rev4"): [(b"TREE_ROOT", b"rev3")], + } class TooManyParentsScenario(BrokenRepoScenario): @@ -516,19 +566,21 @@ class TooManyParentsScenario(BrokenRepoScenario): """ def all_versions_after_reconcile(self): - return (b'bad-parent', b'good-parent', b'broken-revision') + return (b"bad-parent", b"good-parent", b"broken-revision") def populated_parents(self): return ( - ((), b'bad-parent'), - ((b'bad-parent',), b'good-parent'), - ((b'good-parent', b'bad-parent'), b'broken-revision')) + ((), b"bad-parent"), + ((b"bad-parent",), b"good-parent"), + ((b"good-parent", b"bad-parent"), b"broken-revision"), + ) def corrected_parents(self): return ( - ((), b'bad-parent'), - ((b'bad-parent',), b'good-parent'), - ((b'good-parent',), b'broken-revision')) + ((), b"bad-parent"), + ((b"bad-parent",), b"good-parent"), + ((b"good-parent",), b"broken-revision"), + ) def check_regexes(self, repo): if repo.supports_rich_root(): @@ -538,47 +590,61 @@ def check_regexes(self, repo): else: count = 1 return ( - ' %d inconsistent parents' % count, - (r" \* a-file-id version broken-revision has parents " - r"\(good-parent, bad-parent\) but " - r"should have \(good-parent\)")) + " %d inconsistent parents" % count, + ( + r" \* a-file-id version broken-revision has parents " + r"\(good-parent, bad-parent\) but " + r"should have \(good-parent\)" + ), + ) def populate_repository(self, repo): inv = self.make_one_file_inventory( - repo, b'bad-parent', (), 
root_revision=b'bad-parent') - self.add_revision(repo, b'bad-parent', inv, ()) + repo, b"bad-parent", (), root_revision=b"bad-parent" + ) + self.add_revision(repo, b"bad-parent", inv, ()) - inv = self.make_one_file_inventory( - repo, b'good-parent', (b'bad-parent',)) - self.add_revision(repo, b'good-parent', inv, (b'bad-parent',)) + inv = self.make_one_file_inventory(repo, b"good-parent", (b"bad-parent",)) + self.add_revision(repo, b"good-parent", inv, (b"bad-parent",)) inv = self.make_one_file_inventory( - repo, b'broken-revision', (b'good-parent', b'bad-parent')) - self.add_revision(repo, b'broken-revision', inv, (b'good-parent',)) + repo, b"broken-revision", (b"good-parent", b"bad-parent") + ) + self.add_revision(repo, b"broken-revision", inv, (b"good-parent",)) self.versioned_root = repo.supports_rich_root() def repository_text_key_references(self): result = {} if self.versioned_root: - result.update({(b'TREE_ROOT', b'bad-parent'): True, - (b'TREE_ROOT', b'broken-revision'): True, - (b'TREE_ROOT', b'good-parent'): True}) - result.update({(b'a-file-id', b'bad-parent'): True, - (b'a-file-id', b'broken-revision'): True, - (b'a-file-id', b'good-parent'): True}) + result.update( + { + (b"TREE_ROOT", b"bad-parent"): True, + (b"TREE_ROOT", b"broken-revision"): True, + (b"TREE_ROOT", b"good-parent"): True, + } + ) + result.update( + { + (b"a-file-id", b"bad-parent"): True, + (b"a-file-id", b"broken-revision"): True, + (b"a-file-id", b"good-parent"): True, + } + ) return result def repository_text_keys(self): - return {(b'a-file-id', b'bad-parent'): [NULL_REVISION], - (b'a-file-id', b'broken-revision'): - [(b'a-file-id', b'good-parent')], - (b'a-file-id', b'good-parent'): [(b'a-file-id', b'bad-parent')]} + return { + (b"a-file-id", b"bad-parent"): [NULL_REVISION], + (b"a-file-id", b"broken-revision"): [(b"a-file-id", b"good-parent")], + (b"a-file-id", b"good-parent"): [(b"a-file-id", b"bad-parent")], + } def versioned_repository_text_keys(self): - return {(b'TREE_ROOT', b'bad-parent'): [NULL_REVISION], - (b'TREE_ROOT', b'broken-revision'): - [(b'TREE_ROOT', b'good-parent')], - (b'TREE_ROOT', b'good-parent'): [(b'TREE_ROOT', b'bad-parent')]} + return { + (b"TREE_ROOT", b"bad-parent"): [NULL_REVISION], + (b"TREE_ROOT", b"broken-revision"): [(b"TREE_ROOT", b"good-parent")], + (b"TREE_ROOT", b"good-parent"): [(b"TREE_ROOT", b"bad-parent")], + } class ClaimedFileParentDidNotModifyFileScenario(BrokenRepoScenario): @@ -592,19 +658,21 @@ class ClaimedFileParentDidNotModifyFileScenario(BrokenRepoScenario): """ def all_versions_after_reconcile(self): - return (b'basis', b'current') + return (b"basis", b"current") def populated_parents(self): return ( - ((), b'basis'), - ((b'basis',), b'modified-something-else'), - ((b'modified-something-else',), b'current')) + ((), b"basis"), + ((b"basis",), b"modified-something-else"), + ((b"modified-something-else",), b"current"), + ) def corrected_parents(self): return ( - ((), b'basis'), - (None, b'modified-something-else'), - ((b'basis',), b'current')) + ((), b"basis"), + (None, b"modified-something-else"), + ((b"basis",), b"current"), + ) def check_regexes(self, repo): if repo.supports_rich_root(): @@ -617,47 +685,56 @@ def check_regexes(self, repo): "%d inconsistent parents" % count, r"\* a-file-id version current has parents " r"\(modified-something-else\) but should have \(basis\)", - ) + ) def populate_repository(self, repo): - inv = self.make_one_file_inventory(repo, b'basis', ()) - self.add_revision(repo, b'basis', inv, ()) + inv = 
self.make_one_file_inventory(repo, b"basis", ()) + self.add_revision(repo, b"basis", inv, ()) # 'modified-something-else' is a correctly recorded revision, but it # does not modify the file we are looking at, so the inventory for that # file in this revision points to 'basis'. inv = self.make_one_file_inventory( - repo, b'modified-something-else', (b'basis',), inv_revision=b'basis') - self.add_revision(repo, b'modified-something-else', inv, (b'basis',)) + repo, b"modified-something-else", (b"basis",), inv_revision=b"basis" + ) + self.add_revision(repo, b"modified-something-else", inv, (b"basis",)) # The 'current' revision has 'modified-something-else' as its parent, # but the 'current' version of 'a-file' should have 'basis' as its # parent. inv = self.make_one_file_inventory( - repo, b'current', (b'modified-something-else',)) - self.add_revision(repo, b'current', inv, (b'modified-something-else',)) + repo, b"current", (b"modified-something-else",) + ) + self.add_revision(repo, b"current", inv, (b"modified-something-else",)) self.versioned_root = repo.supports_rich_root() def repository_text_key_references(self): result = {} if self.versioned_root: - result.update({(b'TREE_ROOT', b'basis'): True, - (b'TREE_ROOT', b'current'): True, - (b'TREE_ROOT', b'modified-something-else'): True}) - result.update({(b'a-file-id', b'basis'): True, - (b'a-file-id', b'current'): True}) + result.update( + { + (b"TREE_ROOT", b"basis"): True, + (b"TREE_ROOT", b"current"): True, + (b"TREE_ROOT", b"modified-something-else"): True, + } + ) + result.update( + {(b"a-file-id", b"basis"): True, (b"a-file-id", b"current"): True} + ) return result def repository_text_keys(self): - return {(b'a-file-id', b'basis'): [NULL_REVISION], - (b'a-file-id', b'current'): [(b'a-file-id', b'basis')]} + return { + (b"a-file-id", b"basis"): [NULL_REVISION], + (b"a-file-id", b"current"): [(b"a-file-id", b"basis")], + } def versioned_repository_text_keys(self): - return {(b'TREE_ROOT', b'basis'): [b'null:'], - (b'TREE_ROOT', b'current'): - [(b'TREE_ROOT', b'modified-something-else')], - (b'TREE_ROOT', b'modified-something-else'): - [(b'TREE_ROOT', b'basis')]} + return { + (b"TREE_ROOT", b"basis"): [b"null:"], + (b"TREE_ROOT", b"current"): [(b"TREE_ROOT", b"modified-something-else")], + (b"TREE_ROOT", b"modified-something-else"): [(b"TREE_ROOT", b"basis")], + } class IncorrectlyOrderedParentsScenario(BrokenRepoScenario): @@ -673,22 +750,28 @@ class IncorrectlyOrderedParentsScenario(BrokenRepoScenario): """ def all_versions_after_reconcile(self): - return [b'parent-1', b'parent-2', b'broken-revision-1-2', - b'broken-revision-2-1'] + return [ + b"parent-1", + b"parent-2", + b"broken-revision-1-2", + b"broken-revision-2-1", + ] def populated_parents(self): return ( - ((), b'parent-1'), - ((), b'parent-2'), - ((b'parent-2', b'parent-1'), b'broken-revision-1-2'), - ((b'parent-1', b'parent-2'), b'broken-revision-2-1')) + ((), b"parent-1"), + ((), b"parent-2"), + ((b"parent-2", b"parent-1"), b"broken-revision-1-2"), + ((b"parent-1", b"parent-2"), b"broken-revision-2-1"), + ) def corrected_parents(self): return ( - ((), b'parent-1'), - ((), b'parent-2'), - ((b'parent-1', b'parent-2'), b'broken-revision-1-2'), - ((b'parent-2', b'parent-1'), b'broken-revision-2-1')) + ((), b"parent-1"), + ((), b"parent-2"), + ((b"parent-1", b"parent-2"), b"broken-revision-1-2"), + ((b"parent-2", b"parent-1"), b"broken-revision-2-1"), + ) def check_regexes(self, repo): if repo.supports_rich_root(): @@ -704,54 +787,75 @@ def check_regexes(self, repo): r"\(parent-1, 
parent-2\)", r"\* a-file-id version broken-revision-2-1 has parents " r"\(parent-1, parent-2\) but should have " - r"\(parent-2, parent-1\)") + r"\(parent-2, parent-1\)", + ) def populate_repository(self, repo): - inv = self.make_one_file_inventory(repo, b'parent-1', []) - self.add_revision(repo, b'parent-1', inv, []) + inv = self.make_one_file_inventory(repo, b"parent-1", []) + self.add_revision(repo, b"parent-1", inv, []) - inv = self.make_one_file_inventory(repo, b'parent-2', []) - self.add_revision(repo, b'parent-2', inv, []) + inv = self.make_one_file_inventory(repo, b"parent-2", []) + self.add_revision(repo, b"parent-2", inv, []) inv = self.make_one_file_inventory( - repo, b'broken-revision-1-2', [b'parent-2', b'parent-1']) - self.add_revision( - repo, b'broken-revision-1-2', inv, [b'parent-1', b'parent-2']) + repo, b"broken-revision-1-2", [b"parent-2", b"parent-1"] + ) + self.add_revision(repo, b"broken-revision-1-2", inv, [b"parent-1", b"parent-2"]) inv = self.make_one_file_inventory( - repo, b'broken-revision-2-1', [b'parent-1', b'parent-2']) - self.add_revision( - repo, b'broken-revision-2-1', inv, [b'parent-2', b'parent-1']) + repo, b"broken-revision-2-1", [b"parent-1", b"parent-2"] + ) + self.add_revision(repo, b"broken-revision-2-1", inv, [b"parent-2", b"parent-1"]) self.versioned_root = repo.supports_rich_root() def repository_text_key_references(self): result = {} if self.versioned_root: - result.update({(b'TREE_ROOT', b'broken-revision-1-2'): True, - (b'TREE_ROOT', b'broken-revision-2-1'): True, - (b'TREE_ROOT', b'parent-1'): True, - (b'TREE_ROOT', b'parent-2'): True}) - result.update({(b'a-file-id', b'broken-revision-1-2'): True, - (b'a-file-id', b'broken-revision-2-1'): True, - (b'a-file-id', b'parent-1'): True, - (b'a-file-id', b'parent-2'): True}) + result.update( + { + (b"TREE_ROOT", b"broken-revision-1-2"): True, + (b"TREE_ROOT", b"broken-revision-2-1"): True, + (b"TREE_ROOT", b"parent-1"): True, + (b"TREE_ROOT", b"parent-2"): True, + } + ) + result.update( + { + (b"a-file-id", b"broken-revision-1-2"): True, + (b"a-file-id", b"broken-revision-2-1"): True, + (b"a-file-id", b"parent-1"): True, + (b"a-file-id", b"parent-2"): True, + } + ) return result def repository_text_keys(self): - return {(b'a-file-id', b'broken-revision-1-2'): - [(b'a-file-id', b'parent-1'), (b'a-file-id', b'parent-2')], - (b'a-file-id', b'broken-revision-2-1'): - [(b'a-file-id', b'parent-2'), (b'a-file-id', b'parent-1')], - (b'a-file-id', b'parent-1'): [NULL_REVISION], - (b'a-file-id', b'parent-2'): [NULL_REVISION]} + return { + (b"a-file-id", b"broken-revision-1-2"): [ + (b"a-file-id", b"parent-1"), + (b"a-file-id", b"parent-2"), + ], + (b"a-file-id", b"broken-revision-2-1"): [ + (b"a-file-id", b"parent-2"), + (b"a-file-id", b"parent-1"), + ], + (b"a-file-id", b"parent-1"): [NULL_REVISION], + (b"a-file-id", b"parent-2"): [NULL_REVISION], + } def versioned_repository_text_keys(self): - return {(b'TREE_ROOT', b'broken-revision-1-2'): - [(b'TREE_ROOT', b'parent-1'), (b'TREE_ROOT', b'parent-2')], - (b'TREE_ROOT', b'broken-revision-2-1'): - [(b'TREE_ROOT', b'parent-2'), (b'TREE_ROOT', b'parent-1')], - (b'TREE_ROOT', b'parent-1'): [NULL_REVISION], - (b'TREE_ROOT', b'parent-2'): [NULL_REVISION]} + return { + (b"TREE_ROOT", b"broken-revision-1-2"): [ + (b"TREE_ROOT", b"parent-1"), + (b"TREE_ROOT", b"parent-2"), + ], + (b"TREE_ROOT", b"broken-revision-2-1"): [ + (b"TREE_ROOT", b"parent-2"), + (b"TREE_ROOT", b"parent-1"), + ], + (b"TREE_ROOT", b"parent-1"): [NULL_REVISION], + (b"TREE_ROOT", b"parent-2"): 
[NULL_REVISION], + } all_broken_scenario_classes = [ @@ -763,15 +867,16 @@ def versioned_repository_text_keys(self): ClaimedFileParentDidNotModifyFileScenario, IncorrectlyOrderedParentsScenario, UnreferencedFileParentsFromNoOpMergeScenario, - ] +] def broken_scenarios_for_all_formats(): format_scenarios = all_repository_vf_format_scenarios() # test_check_reconcile needs to be parameterized by format *and* by broken # repository scenario. - broken_scenarios = [(s.__name__, {'scenario_class': s}) - for s in all_broken_scenario_classes] + broken_scenarios = [ + (s.__name__, {"scenario_class": s}) for s in all_broken_scenario_classes + ] return multiply_scenarios(format_scenarios, broken_scenarios) @@ -782,7 +887,7 @@ class TestFileParentReconciliation(TestCaseWithRepository): def make_populated_repository(self, factory): """Create a new repository populated by the given factory.""" - repo = self.make_repository('broken-repo') + repo = self.make_repository("broken-repo") with repo.lock_write(), WriteGroup(repo): factory(repo) return repo @@ -802,15 +907,28 @@ def add_revision(self, repo, revision_id, inv, parent_ids): root_id = inv.root.file_id repo.texts.add_lines((root_id, revision_id), [], []) repo.add_inventory(revision_id, inv, parent_ids) - revision = Revision(revision_id, committer='jrandom@example.com', - timestamp=0, inventory_sha1=b'', timezone=0, message='foo', - properties={}, - parent_ids=parent_ids) + revision = Revision( + revision_id, + committer="jrandom@example.com", + timestamp=0, + inventory_sha1=b"", + timezone=0, + message="foo", + properties={}, + parent_ids=parent_ids, + ) repo.add_revision(revision_id, revision, inv) - def make_one_file_inventory(self, repo, revision, parents, - inv_revision=None, root_revision=None, - file_contents=None, make_file_version=True): + def make_one_file_inventory( + self, + repo, + revision, + parents, + inv_revision=None, + root_revision=None, + file_contents=None, + make_file_version=True, + ): """Make an inventory containing a version of a file with ID 'a-file'. 
The file's ID will be 'a-file', and its filename will be 'a file name', @@ -830,46 +948,52 @@ def make_one_file_inventory(self, repo, revision, parents, inv = Inventory(revision_id=revision) if root_revision is not None: inv.root.revision = root_revision - file_id = b'a-file-id' - entry = InventoryFile(file_id, 'a file name', b'TREE_ROOT') + file_id = b"a-file-id" + entry = InventoryFile(file_id, "a file name", b"TREE_ROOT") if inv_revision is not None: entry.revision = inv_revision else: entry.revision = revision entry.text_size = 0 if file_contents is None: - file_contents = b'%sline\n' % entry.revision + file_contents = b"%sline\n" % entry.revision entry.text_sha1 = osutils.sha_string(file_contents) inv.add(entry) if make_file_version: - repo.texts.add_lines((file_id, revision), - [(file_id, parent) for parent in parents], [file_contents]) + repo.texts.add_lines( + (file_id, revision), + [(file_id, parent) for parent in parents], + [file_contents], + ) return inv def require_repo_suffers_text_parent_corruption(self, repo): if not repo._reconcile_fixes_text_parents: raise TestNotApplicable( - "Format does not support text parent reconciliation") + "Format does not support text parent reconciliation" + ) def file_parents(self, repo, revision_id): - key = (b'a-file-id', revision_id) + key = (b"a-file-id", revision_id) parent_map = repo.texts.get_parent_map([key]) return tuple(parent[-1] for parent in parent_map[key]) def assertFileVersionAbsent(self, repo, revision_id): - self.assertEqual({}, - repo.texts.get_parent_map([(b'a-file-id', revision_id)])) + self.assertEqual({}, repo.texts.get_parent_map([(b"a-file-id", revision_id)])) - def assertParentsMatch(self, expected_parents_for_versions, repo, - when_description): + def assertParentsMatch(self, expected_parents_for_versions, repo, when_description): for expected_parents, version in expected_parents_for_versions: if expected_parents is None: self.assertFileVersionAbsent(repo, version) else: found_parents = self.file_parents(repo, version) - self.assertEqual(expected_parents, found_parents, - "{} reconcile {} has parents {}, should have {}.".format(when_description, version, found_parents, - expected_parents)) + self.assertEqual( + expected_parents, + found_parents, + "{} reconcile {} has parents {}, should have {}.".format( + when_description, version, found_parents, expected_parents + ), + ) def prepare_test_repository(self): """Prepare a repository to test with from the test scenario. @@ -889,7 +1013,7 @@ def shas_for_versions_of_file(self, repo, versions): :returns: A dict of `{version: hash}`. """ - keys = [(b'a-file-id', version) for version in versions] + keys = [(b"a-file-id", version) for version in versions] return repo.texts.get_sha1s(keys) def test_reconcile_behaviour(self): @@ -898,20 +1022,21 @@ def test_reconcile_behaviour(self): """ repo, scenario = self.prepare_test_repository() with repo.lock_read(): - self.assertParentsMatch(scenario.populated_parents(), repo, - b'before') + self.assertParentsMatch(scenario.populated_parents(), repo, b"before") vf_shas = self.shas_for_versions_of_file( - repo, scenario.all_versions_after_reconcile()) + repo, scenario.all_versions_after_reconcile() + ) repo.reconcile(thorough=True) with repo.lock_read(): - self.assertParentsMatch(scenario.corrected_parents(), repo, - b'after') + self.assertParentsMatch(scenario.corrected_parents(), repo, b"after") # The contents of the versions in the versionedfile should be the # same after the reconcile. 
self.assertEqual( vf_shas, self.shas_for_versions_of_file( - repo, scenario.all_versions_after_reconcile())) + repo, scenario.all_versions_after_reconcile() + ), + ) # Scenario.corrected_fulltexts contains texts which the test wants # to assert are now fulltexts. However this is an abstraction @@ -921,12 +1046,14 @@ def test_reconcile_behaviour(self): # (we specify it this way because a store can use arbitrary # compression pointers in principle. for file_version in scenario.corrected_fulltexts(): - key = (b'a-file-id', file_version) + key = (b"a-file-id", file_version) self.assertEqual({key: ()}, repo.texts.get_parent_map([key])) self.assertIsInstance( - next(repo.texts.get_record_stream([key], 'unordered', - True)).get_bytes_as('fulltext'), - bytes) + next( + repo.texts.get_record_stream([key], "unordered", True) + ).get_bytes_as("fulltext"), + bytes, + ) def test_check_behaviour(self): """Populate a repository and check it, and verify the output.""" @@ -942,13 +1069,15 @@ def test_find_text_key_references(self): repo, scenario = self.prepare_test_repository() repo.lock_read() self.addCleanup(repo.unlock) - self.assertEqual(scenario.repository_text_key_references(), - repo.find_text_key_references()) + self.assertEqual( + scenario.repository_text_key_references(), repo.find_text_key_references() + ) def test__generate_text_key_index(self): """Test that the generated text key index has all entries.""" repo, scenario = self.prepare_test_repository() repo.lock_read() self.addCleanup(repo.unlock) - self.assertEqual(scenario.repository_text_key_index(), - repo._generate_text_key_index()) + self.assertEqual( + scenario.repository_text_key_index(), repo._generate_text_key_index() + ) diff --git a/breezy/bzr/tests/per_repository_vf/test_fetch.py b/breezy/bzr/tests/per_repository_vf/test_fetch.py index 2ac5d2fa0d..e8fbbf9db1 100644 --- a/breezy/bzr/tests/per_repository_vf/test_fetch.py +++ b/breezy/bzr/tests/per_repository_vf/test_fetch.py @@ -37,20 +37,26 @@ def test_no_absent_records_in_stream_with_ghosts(self): # doesn't actually gain coverage there; need a specific set of # permutations to cover it. # bug lp:376255 was reported about this. 
- builder = self.make_branch_builder('repo') + builder = self.make_branch_builder("repo") builder.start_series() - builder.build_snapshot([b'ghost'], - [('add', ('', b'ROOT_ID', 'directory', ''))], - allow_leftmost_as_ghost=True, revision_id=b'tip') + builder.build_snapshot( + [b"ghost"], + [("add", ("", b"ROOT_ID", "directory", ""))], + allow_leftmost_as_ghost=True, + revision_id=b"tip", + ) builder.finish_series() b = builder.get_branch() b.lock_read() self.addCleanup(b.unlock) repo = b.repository source = repo._get_source(repo._format) - search = vf_search.PendingAncestryResult([b'tip'], repo) + search = vf_search.PendingAncestryResult([b"tip"], repo) stream = source.get_stream(search) for substream_type, substream in stream: for record in substream: - self.assertNotEqual('absent', record.storage_kind, - f"Absent record for {(substream_type,) + record.key}") + self.assertNotEqual( + "absent", + record.storage_kind, + f"Absent record for {(substream_type,) + record.key}", + ) diff --git a/breezy/bzr/tests/per_repository_vf/test_fileid_involved.py b/breezy/bzr/tests/per_repository_vf/test_fileid_involved.py index 7582dfae82..62b5368513 100644 --- a/breezy/bzr/tests/per_repository_vf/test_fileid_involved.py +++ b/breezy/bzr/tests/per_repository_vf/test_fileid_involved.py @@ -31,33 +31,37 @@ class FileIdInvolvedWGhosts(TestCaseWithRepository): - scenarios = all_repository_vf_format_scenarios() def create_branch_with_ghost_text(self): - builder = self.make_branch_builder('ghost') - builder.build_snapshot(None, [ - ('add', ('', b'root-id', 'directory', None)), - ('add', ('a', b'a-file-id', 'file', b'some content\n'))], - revision_id=b'A-id') + builder = self.make_branch_builder("ghost") + builder.build_snapshot( + None, + [ + ("add", ("", b"root-id", "directory", None)), + ("add", ("a", b"a-file-id", "file", b"some content\n")), + ], + revision_id=b"A-id", + ) b = builder.get_branch() - old_rt = b.repository.revision_tree(b'A-id') + old_rt = b.repository.revision_tree(b"A-id") new_inv = inventory.mutable_inventory_from_tree(old_rt) - new_inv.revision_id = b'B-id' - new_inv.get_entry(b'a-file-id').revision = b'ghost-id' - new_rev = _mod_revision.Revision(b'B-id', - timestamp=time.time(), - timezone=0, - message='Committing against a ghost', - committer='Joe Foo ', - properties={}, - parent_ids=(b'A-id', b'ghost-id'), - inventory_sha1=None, - ) + new_inv.revision_id = b"B-id" + new_inv.get_entry(b"a-file-id").revision = b"ghost-id" + new_rev = _mod_revision.Revision( + b"B-id", + timestamp=time.time(), + timezone=0, + message="Committing against a ghost", + committer="Joe Foo ", + properties={}, + parent_ids=(b"A-id", b"ghost-id"), + inventory_sha1=None, + ) b.lock_write() self.addCleanup(b.unlock) b.repository.start_write_group() - b.repository.add_revision(b'B-id', new_rev, new_inv) + b.repository.add_revision(b"B-id", new_rev, new_inv) self.disable_commit_write_group_paranoia(b.repository) b.repository.commit_write_group() return b @@ -68,8 +72,9 @@ def disable_commit_write_group_paranoia(self, repo): repo.abort_write_group() raise tests.TestSkipped( "repository format does not support storing revisions with " - "missing texts.") - pack_coll = getattr(repo, '_pack_collection', None) + "missing texts." + ) + pack_coll = getattr(repo, "_pack_collection", None) if pack_coll is not None: # Monkey-patch the pack collection instance to allow storing # incomplete revisions. 
@@ -79,66 +84,71 @@ def test_file_ids_include_ghosts(self): b = self.create_branch_with_ghost_text() repo = b.repository self.assertEqual( - {b'a-file-id': {b'ghost-id'}}, - repo.fileids_altered_by_revision_ids([b'B-id'])) + {b"a-file-id": {b"ghost-id"}}, + repo.fileids_altered_by_revision_ids([b"B-id"]), + ) def test_file_ids_uses_fallbacks(self): - builder = self.make_branch_builder('source', - format=self.bzrdir_format) + builder = self.make_branch_builder("source", format=self.bzrdir_format) repo = builder.get_branch().repository if not repo._format.supports_external_lookups: - raise tests.TestNotApplicable('format does not support stacking') + raise tests.TestNotApplicable("format does not support stacking") builder.start_series() - builder.build_snapshot(None, [ - ('add', ('', b'root-id', 'directory', None)), - ('add', ('file', b'file-id', 'file', b'contents\n'))], - revision_id=b'A-id') - builder.build_snapshot([b'A-id'], [ - ('modify', ('file', b'new-content\n'))], - revision_id=b'B-id') - builder.build_snapshot([b'B-id'], [ - ('modify', ('file', b'yet more content\n'))], - revision_id=b'C-id') + builder.build_snapshot( + None, + [ + ("add", ("", b"root-id", "directory", None)), + ("add", ("file", b"file-id", "file", b"contents\n")), + ], + revision_id=b"A-id", + ) + builder.build_snapshot( + [b"A-id"], [("modify", ("file", b"new-content\n"))], revision_id=b"B-id" + ) + builder.build_snapshot( + [b"B-id"], + [("modify", ("file", b"yet more content\n"))], + revision_id=b"C-id", + ) builder.finish_series() source_b = builder.get_branch() source_b.lock_read() self.addCleanup(source_b.unlock) - base = self.make_branch('base') - base.pull(source_b, stop_revision=b'B-id') - stacked = self.make_branch('stacked') - stacked.set_stacked_on_url('../base') - stacked.pull(source_b, stop_revision=b'C-id') + base = self.make_branch("base") + base.pull(source_b, stop_revision=b"B-id") + stacked = self.make_branch("stacked") + stacked.set_stacked_on_url("../base") + stacked.pull(source_b, stop_revision=b"C-id") stacked.lock_read() self.addCleanup(stacked.unlock) repo = stacked.repository - keys = {b'file-id': {b'A-id'}} + keys = {b"file-id": {b"A-id"}} if stacked.repository.supports_rich_root(): - keys[b'root-id'] = {b'A-id'} - self.assertEqual(keys, repo.fileids_altered_by_revision_ids([b'A-id'])) + keys[b"root-id"] = {b"A-id"} + self.assertEqual(keys, repo.fileids_altered_by_revision_ids([b"A-id"])) class FileIdInvolvedBase(TestCaseWithRepository): - def touch(self, tree, filename): # use the trees transport to not depend on the tree's location or type. 
- tree.controldir.root_transport.append_bytes( - filename, b"appended line\n") + tree.controldir.root_transport.append_bytes(filename, b"appended line\n") def compare_tree_fileids(self, branch, old_rev, new_rev): old_tree = self.branch.repository.revision_tree(old_rev) new_tree = self.branch.repository.revision_tree(new_rev) delta = new_tree.changes_from(old_tree) - l2 = [change.file_id for change in delta.added] + \ - [change.file_id for change in delta.renamed] + \ - [change.file_id for change in delta.modified] + \ - [change.file_id for change in delta.copied] + l2 = ( + [change.file_id for change in delta.added] + + [change.file_id for change in delta.renamed] + + [change.file_id for change in delta.modified] + + [change.file_id for change in delta.copied] + ) return set(l2) class TestFileIdInvolved(FileIdInvolvedBase): - scenarios = all_repository_vf_format_scenarios() def setUp(self): @@ -161,34 +171,38 @@ def setUp(self): # J changes: 'b-file-id-2006-01-01-defg' # K changes: 'c-funkyquiji%bo' - main_wt = self.make_branch_and_tree('main') + main_wt = self.make_branch_and_tree("main") main_branch = main_wt.branch self.build_tree(["main/a", "main/b", "main/c"]) main_wt.add( - ['a', 'b', 'c'], - ids=[b'a-file-id-2006-01-01-abcd', - b'b-file-id-2006-01-01-defg', - b'c-funkyquiji%bo']) + ["a", "b", "c"], + ids=[ + b"a-file-id-2006-01-01-abcd", + b"b-file-id-2006-01-01-defg", + b"c-funkyquiji%bo", + ], + ) try: main_wt.commit("Commit one", rev_id=b"rev-A") except errors.IllegalPath as e: # TODO: jam 20060701 Consider raising a different exception # newer formats do support this, and nothin can done to # correct this test - its not a bug. - if sys.platform == 'win32': - raise tests.TestSkipped('Old repository formats do not' - ' support file ids with <> on win32') from e + if sys.platform == "win32": + raise tests.TestSkipped( + "Old repository formats do not" " support file ids with <> on win32" + ) from e # This is not a known error condition raise # -------- end A ----------- - bt1 = self.make_branch_and_tree('branch1') + bt1 = self.make_branch_and_tree("branch1") bt1.pull(main_branch) b1 = bt1.branch self.build_tree(["branch1/d"]) - bt1.add(['d'], ids=[b'file-d']) + bt1.add(["d"], ids=[b"file-d"]) bt1.commit("branch1, Commit one", rev_id=b"rev-E") # -------- end E ----------- @@ -198,10 +212,10 @@ def setUp(self): # -------- end B ----------- - bt2 = self.make_branch_and_tree('branch2') + bt2 = self.make_branch_and_tree("branch2") bt2.pull(main_branch) branch2_branch = bt2.branch - set_executability(bt2, 'b', True) + set_executability(bt2, "b", True) bt2.commit("branch2, Commit one", rev_id=b"rev-J") # -------- end J ----------- @@ -237,42 +251,54 @@ def setUp(self): def test_fileids_altered_between_two_revs(self): self.branch.lock_read() self.addCleanup(self.branch.unlock) - self.branch.repository.fileids_altered_by_revision_ids( - [b"rev-J", b"rev-K"]) + self.branch.repository.fileids_altered_by_revision_ids([b"rev-J", b"rev-K"]) self.assertEqual( - {b'b-file-id-2006-01-01-defg': {b'rev-J'}, - b'c-funkyquiji%bo': {b'rev-K'} - }, - self.branch.repository.fileids_altered_by_revision_ids([b"rev-J", b"rev-K"])) + { + b"b-file-id-2006-01-01-defg": {b"rev-J"}, + b"c-funkyquiji%bo": {b"rev-K"}, + }, + self.branch.repository.fileids_altered_by_revision_ids( + [b"rev-J", b"rev-K"] + ), + ) self.assertEqual( - {b'b-file-id-2006-01-01-defg': {b'rev-'}, - b'file-d': {b'rev-F'}, - }, - self.branch.repository.fileids_altered_by_revision_ids([b'rev-', b'rev-F'])) + { + b"b-file-id-2006-01-01-defg": 
{b"rev-"}, + b"file-d": {b"rev-F"}, + }, + self.branch.repository.fileids_altered_by_revision_ids( + [b"rev-", b"rev-F"] + ), + ) self.assertEqual( { - b'b-file-id-2006-01-01-defg': {b'rev-', b'rev-G', b'rev-J'}, - b'c-funkyquiji%bo': {b'rev-K'}, - b'file-d': {b'rev-F'}, - }, + b"b-file-id-2006-01-01-defg": {b"rev-", b"rev-G", b"rev-J"}, + b"c-funkyquiji%bo": {b"rev-K"}, + b"file-d": {b"rev-F"}, + }, self.branch.repository.fileids_altered_by_revision_ids( - [b'rev-', b'rev-G', b'rev-F', b'rev-K', b'rev-J'])) + [b"rev-", b"rev-G", b"rev-F", b"rev-K", b"rev-J"] + ), + ) self.assertEqual( - {b'a-file-id-2006-01-01-abcd': {b'rev-B'}, - b'b-file-id-2006-01-01-defg': {b'rev-', b'rev-G', b'rev-J'}, - b'c-funkyquiji%bo': {b'rev-K'}, - b'file-d': {b'rev-F'}, - }, + { + b"a-file-id-2006-01-01-abcd": {b"rev-B"}, + b"b-file-id-2006-01-01-defg": {b"rev-", b"rev-G", b"rev-J"}, + b"c-funkyquiji%bo": {b"rev-K"}, + b"file-d": {b"rev-F"}, + }, self.branch.repository.fileids_altered_by_revision_ids( - [b'rev-G', b'rev-F', b'rev-C', b'rev-B', b'rev-', b'rev-K', b'rev-J'])) + [b"rev-G", b"rev-F", b"rev-C", b"rev-B", b"rev-", b"rev-K", b"rev-J"] + ), + ) def fileids_altered_by_revision_ids(self, revision_ids): """This is a wrapper to strip TREE_ROOT if it occurs.""" repo = self.branch.repository - root_id = self.branch.basis_tree().path2id('') + root_id = self.branch.basis_tree().path2id("") result = repo.fileids_altered_by_revision_ids(revision_ids) if root_id in result: del result[root_id] @@ -282,19 +308,21 @@ def test_fileids_altered_by_revision_ids(self): self.branch.lock_read() self.addCleanup(self.branch.unlock) self.assertEqual( - {b'a-file-id-2006-01-01-abcd': {b'rev-A'}, - b'b-file-id-2006-01-01-defg': {b'rev-A'}, - b'c-funkyquiji%bo': {b'rev-A'}, - }, - self.fileids_altered_by_revision_ids([b"rev-A"])) + { + b"a-file-id-2006-01-01-abcd": {b"rev-A"}, + b"b-file-id-2006-01-01-defg": {b"rev-A"}, + b"c-funkyquiji%bo": {b"rev-A"}, + }, + self.fileids_altered_by_revision_ids([b"rev-A"]), + ) self.assertEqual( - {b'a-file-id-2006-01-01-abcd': {b'rev-B'} - }, - self.branch.repository.fileids_altered_by_revision_ids([b"rev-B"])) + {b"a-file-id-2006-01-01-abcd": {b"rev-B"}}, + self.branch.repository.fileids_altered_by_revision_ids([b"rev-B"]), + ) self.assertEqual( - {b'b-file-id-2006-01-01-defg': {b'rev-'} - }, - self.branch.repository.fileids_altered_by_revision_ids([b"rev-"])) + {b"b-file-id-2006-01-01-defg": {b"rev-"}}, + self.branch.repository.fileids_altered_by_revision_ids([b"rev-"]), + ) def test_fileids_involved_full_compare(self): # this tests that the result of each fileid_involved calculation @@ -305,8 +333,11 @@ def test_fileids_involved_full_compare(self): self.branch.lock_read() self.addCleanup(self.branch.unlock) graph = self.branch.repository.get_graph() - history = list(graph.iter_lefthand_ancestry(self.branch.last_revision(), - [_mod_revision.NULL_REVISION])) + history = list( + graph.iter_lefthand_ancestry( + self.branch.last_revision(), [_mod_revision.NULL_REVISION] + ) + ) history.reverse() if len(history) < 2: @@ -317,81 +348,81 @@ def test_fileids_involved_full_compare(self): for end in range(start + 1, len(history)): end_id = history[end] unique_revs = graph.find_unique_ancestors(end_id, [start_id]) - l1 = self.branch.repository.fileids_altered_by_revision_ids( - unique_revs) + l1 = self.branch.repository.fileids_altered_by_revision_ids(unique_revs) l1 = set(l1.keys()) l2 = self.compare_tree_fileids(self.branch, start_id, end_id) self.assertEqual(l1, l2) class 
TestFileIdInvolvedNonAscii(FileIdInvolvedBase): - scenarios = all_repository_vf_format_scenarios() def test_utf8_file_ids_and_revision_ids(self): - main_wt = self.make_branch_and_tree('main') + main_wt = self.make_branch_and_tree("main") self.build_tree(["main/a"]) - file_id = 'a-f\xedle-id'.encode() - main_wt.add(['a'], ids=[file_id]) - revision_id = 'r\xe9v-a'.encode() + file_id = "a-f\xedle-id".encode() + main_wt.add(["a"], ids=[file_id]) + revision_id = "r\xe9v-a".encode() try: - main_wt.commit('a', rev_id=revision_id) + main_wt.commit("a", rev_id=revision_id) except errors.NonAsciiRevisionId as e: - raise tests.TestSkipped('non-ascii revision ids not supported by %s' - % self.repository_format) from e + raise tests.TestSkipped( + "non-ascii revision ids not supported by %s" % self.repository_format + ) from e repo = main_wt.branch.repository repo.lock_read() self.addCleanup(repo.unlock) file_ids = repo.fileids_altered_by_revision_ids([revision_id]) - root_id = main_wt.basis_tree().path2id('') + root_id = main_wt.basis_tree().path2id("") if root_id in file_ids: - self.assertEqual({file_id: {revision_id}, - root_id: {revision_id} - }, file_ids) + self.assertEqual({file_id: {revision_id}, root_id: {revision_id}}, file_ids) else: self.assertEqual({file_id: {revision_id}}, file_ids) class TestFileIdInvolvedSuperset(FileIdInvolvedBase): - scenarios = all_repository_vf_format_scenarios() def setUp(self): super().setUp() self.branch = None - main_wt = self.make_branch_and_tree('main') + main_wt = self.make_branch_and_tree("main") main_branch = main_wt.branch self.build_tree(["main/a", "main/b", "main/c"]) main_wt.add( - ['a', 'b', 'c'], - ids=[b'a-file-id-2006-01-01-abcd', - b'b-file-id-2006-01-01-defg', - b'c-funkyquiji\'"%bo']) + ["a", "b", "c"], + ids=[ + b"a-file-id-2006-01-01-abcd", + b"b-file-id-2006-01-01-defg", + b"c-funkyquiji'\"%bo", + ], + ) try: main_wt.commit("Commit one", rev_id=b"rev-A") except errors.IllegalPath as e: # TODO: jam 20060701 Consider raising a different exception # newer formats do support this, and nothin can done to # correct this test - its not a bug. 
- if sys.platform == 'win32': - raise tests.TestSkipped('Old repository formats do not' - ' support file ids with <> on win32') from e + if sys.platform == "win32": + raise tests.TestSkipped( + "Old repository formats do not" " support file ids with <> on win32" + ) from e # This is not a known error condition raise - branch2_wt = self.make_branch_and_tree('branch2') + branch2_wt = self.make_branch_and_tree("branch2") branch2_wt.pull(main_branch) branch2_bzrdir = branch2_wt.controldir branch2_branch = branch2_bzrdir.open_branch() - set_executability(branch2_wt, 'b', True) + set_executability(branch2_wt, "b", True) branch2_wt.commit("branch2, Commit one", rev_id=b"rev-J") main_wt.merge_from_branch(branch2_branch) - set_executability(main_wt, 'b', False) + set_executability(main_wt, "b", False) main_wt.commit("merge branch1, rev-22", rev_id=b"rev-G") # end G @@ -404,15 +435,17 @@ def test_fileid_involved_full_compare2(self): self.branch.lock_read() self.addCleanup(self.branch.unlock) graph = self.branch.repository.get_graph() - history = list(graph.iter_lefthand_ancestry(self.branch.last_revision(), - [_mod_revision.NULL_REVISION])) + history = list( + graph.iter_lefthand_ancestry( + self.branch.last_revision(), [_mod_revision.NULL_REVISION] + ) + ) history.reverse() old_rev = history[0] new_rev = history[1] unique_revs = graph.find_unique_ancestors(new_rev, [old_rev]) - l1 = self.branch.repository.fileids_altered_by_revision_ids( - unique_revs) + l1 = self.branch.repository.fileids_altered_by_revision_ids(unique_revs) l1 = set(l1.keys()) l2 = self.compare_tree_fileids(self.branch, old_rev, new_rev) diff --git a/breezy/bzr/tests/per_repository_vf/test_find_text_key_references.py b/breezy/bzr/tests/per_repository_vf/test_find_text_key_references.py index 68103997ec..d520d579c0 100644 --- a/breezy/bzr/tests/per_repository_vf/test_find_text_key_references.py +++ b/breezy/bzr/tests/per_repository_vf/test_find_text_key_references.py @@ -29,11 +29,10 @@ class TestFindTextKeyReferences(TestCaseWithRepository): - scenarios = all_repository_vf_format_scenarios() def test_empty(self): - repo = self.make_repository('.') + repo = self.make_repository(".") repo.lock_read() self.addCleanup(repo.unlock) self.assertEqual({}, repo.find_text_key_references()) diff --git a/breezy/bzr/tests/per_repository_vf/test_merge_directive.py b/breezy/bzr/tests/per_repository_vf/test_merge_directive.py index f67d9f6c48..2be8449d06 100644 --- a/breezy/bzr/tests/per_repository_vf/test_merge_directive.py +++ b/breezy/bzr/tests/per_repository_vf/test_merge_directive.py @@ -33,43 +33,50 @@ class TestMergeDirective(TestCaseWithRepository): - scenarios = all_repository_vf_format_scenarios() def make_two_branches(self): - builder = self.make_branch_builder('source') + builder = self.make_branch_builder("source") builder.start_series() - builder.build_snapshot(None, [ - ('add', ('', b'root-id', 'directory', None)), - ('add', ('f', b'f-id', 'file', b'initial content\n')), - ], revision_id=b'A') - builder.build_snapshot([b'A'], [ - ('modify', ('f', b'new content\n')), - ], revision_id=b'B') + builder.build_snapshot( + None, + [ + ("add", ("", b"root-id", "directory", None)), + ("add", ("f", b"f-id", "file", b"initial content\n")), + ], + revision_id=b"A", + ) + builder.build_snapshot( + [b"A"], + [ + ("modify", ("f", b"new content\n")), + ], + revision_id=b"B", + ) builder.finish_series() b1 = builder.get_branch() - b2 = b1.controldir.sprout('target', revision_id=b'A').open_branch() + b2 = b1.controldir.sprout("target", 
revision_id=b"A").open_branch() return b1, b2 def create_merge_directive(self, source_branch, submit_url): return merge_directive.MergeDirective2.from_objects( repository=source_branch.repository, revision_id=source_branch.last_revision(), - time=1247775710, timezone=0, - target_branch=submit_url) + time=1247775710, + timezone=0, + target_branch=submit_url, + ) def test_create_merge_directive(self): source_branch, target_branch = self.make_two_branches() - directive = self.create_merge_directive(source_branch, - target_branch.base) + directive = self.create_merge_directive(source_branch, target_branch.base) self.assertIsInstance(directive, merge_directive.MergeDirective2) def test_create_and_install_directive(self): source_branch, target_branch = self.make_two_branches() - directive = self.create_merge_directive(source_branch, - target_branch.base) + directive = self.create_merge_directive(source_branch, target_branch.base) chk_map.clear_cache() directive.install_revisions(target_branch.repository) - rt = target_branch.repository.revision_tree(b'B') + rt = target_branch.repository.revision_tree(b"B") with rt.lock_read(): - self.assertEqualDiff(b'new content\n', rt.get_file_text('f')) + self.assertEqualDiff(b"new content\n", rt.get_file_text("f")) diff --git a/breezy/bzr/tests/per_repository_vf/test_reconcile.py b/breezy/bzr/tests/per_repository_vf/test_reconcile.py index b7aaaed2d3..01373ddd5e 100644 --- a/breezy/bzr/tests/per_repository_vf/test_reconcile.py +++ b/breezy/bzr/tests/per_repository_vf/test_reconcile.py @@ -37,7 +37,6 @@ class TestReconcile(TestCaseWithRepository): - scenarios = all_repository_vf_format_scenarios() def checkUnreconciled(self, d, reconciler): @@ -51,12 +50,11 @@ def checkUnreconciled(self, d, reconciler): def checkNoBackupInventory(self, a_bzr_dir): """Check that there is no backup inventory in aBzrDir.""" repo = a_bzr_dir.open_repository() - for path in repo.control_transport.list_dir('.'): - self.assertNotIn('inventory.backup', path) + for path in repo.control_transport.list_dir("."): + self.assertNotIn("inventory.backup", path) class TestBadRevisionParents(TestCaseWithBrokenRevisionIndex): - scenarios = all_repository_vf_format_scenarios() def test_aborts_if_bad_parents_in_index(self): @@ -69,29 +67,31 @@ def test_aborts_if_bad_parents_in_index(self): """ repo = self.make_repo_with_extra_ghost_index() result = repo.reconcile(thorough=True) - self.assertTrue(result.aborted, - "reconcile should have aborted due to bad parents.") + self.assertTrue( + result.aborted, "reconcile should have aborted due to bad parents." + ) def test_does_not_abort_on_clean_repo(self): - repo = self.make_repository('.') + repo = self.make_repository(".") result = repo.reconcile(thorough=True) - self.assertFalse(result.aborted, - "reconcile should not have aborted on an unbroken repository.") + self.assertFalse( + result.aborted, + "reconcile should not have aborted on an unbroken repository.", + ) class TestsNeedingReweave(TestReconcile): - def setUp(self): super().setUp() t = self.get_transport() # an empty inventory with no revision for testing with. 
- repo = self.make_repository('inventory_without_revision') + repo = self.make_repository("inventory_without_revision") repo.lock_write() repo.start_write_group() - inv = Inventory(revision_id=b'missing') - inv.root.revision = b'missing' - repo.add_inventory(b'missing', inv, []) + inv = Inventory(revision_id=b"missing") + inv.root.revision = b"missing" + repo.add_inventory(b"missing", inv, []) repo.commit_write_group() repo.unlock() @@ -103,45 +103,48 @@ def add_commit(repo, revision_id, parent_ids): root_id = inv.root.file_id sha1 = repo.add_inventory(revision_id, inv, parent_ids) repo.texts.add_lines((root_id, revision_id), [], []) - rev = breezy.revision.Revision(timestamp=0, - timezone=None, - committer="Foo Bar ", - properties={}, - message="Message", - inventory_sha1=sha1, - parent_ids=parent_ids, - revision_id=revision_id) + rev = breezy.revision.Revision( + timestamp=0, + timezone=None, + committer="Foo Bar ", + properties={}, + message="Message", + inventory_sha1=sha1, + parent_ids=parent_ids, + revision_id=revision_id, + ) repo.add_revision(revision_id, rev) repo.commit_write_group() repo.unlock() + # an empty inventory with no revision for testing with. # this is referenced by 'references_missing' to let us test # that all the cached data is correctly converted into ghost links # and the referenced inventory still cleaned. - repo = self.make_repository('inventory_without_revision_and_ghost') + repo = self.make_repository("inventory_without_revision_and_ghost") repo.lock_write() repo.start_write_group() - repo.add_inventory(b'missing', inv, []) + repo.add_inventory(b"missing", inv, []) repo.commit_write_group() repo.unlock() - add_commit(repo, b'references_missing', [b'missing']) + add_commit(repo, b"references_missing", [b"missing"]) # a inventory with no parents and the revision has parents.. # i.e. a ghost. - repo = self.make_repository('inventory_one_ghost') - add_commit(repo, b'ghost', [b'the_ghost']) + repo = self.make_repository("inventory_one_ghost") + add_commit(repo, b"ghost", [b"the_ghost"]) # a inventory with a ghost that can be corrected now. 
- t.copy_tree('inventory_one_ghost', 'inventory_ghost_present') - bzrdir_url = self.get_url('inventory_ghost_present') + t.copy_tree("inventory_one_ghost", "inventory_ghost_present") + bzrdir_url = self.get_url("inventory_ghost_present") bzrdir = BzrDir.open(bzrdir_url) repo = bzrdir.open_repository() - add_commit(repo, b'the_ghost', []) + add_commit(repo, b"the_ghost", []) def checkEmptyReconcile(self, **kwargs): """Check a reconcile on an empty repository.""" - self.make_repository('empty') - d = BzrDir.open(self.get_url('empty')) + self.make_repository("empty") + d = BzrDir.open(self.get_url("empty")) # calling on a empty repository should do nothing result = d.find_repository().reconcile(**kwargs) # no inconsistent parents should have been found @@ -156,7 +159,7 @@ def test_reconcile_empty(self): self.checkEmptyReconcile() def test_repo_has_reconcile_does_inventory_gc_attribute(self): - repo = self.make_repository('repo') + repo = self.make_repository("repo") self.assertNotEqual(None, repo._reconcile_does_inventory_gc) def test_reconcile_empty_thorough(self): @@ -165,11 +168,11 @@ def test_reconcile_empty_thorough(self): def test_convenience_reconcile_inventory_without_revision_reconcile(self): # smoke test for the all in one ui tool - bzrdir_url = self.get_url('inventory_without_revision') + bzrdir_url = self.get_url("inventory_without_revision") bzrdir = BzrDir.open(bzrdir_url) repo = bzrdir.open_repository() if not repo._reconcile_does_inventory_gc: - raise TestSkipped('Irrelevant test') + raise TestSkipped("Irrelevant test") reconcile(bzrdir) # now the backup should have it but not the current inventory repo = bzrdir.open_repository() @@ -177,11 +180,11 @@ def test_convenience_reconcile_inventory_without_revision_reconcile(self): def test_reweave_inventory_without_revision(self): # an excess inventory on its own is only reconciled by using thorough - d_url = self.get_url('inventory_without_revision') + d_url = self.get_url("inventory_without_revision") d = BzrDir.open(d_url) repo = d.open_repository() if not repo._reconcile_does_inventory_gc: - raise TestSkipped('Irrelevant test') + raise TestSkipped("Irrelevant test") self.checkUnreconciled(d, repo.reconcile()) result = repo.reconcile(thorough=True) # no bad parents @@ -190,11 +193,10 @@ def test_reweave_inventory_without_revision(self): self.assertEqual(1, result.garbage_inventories) self.check_missing_was_removed(repo) - def check_thorough_reweave_missing_revision(self, a_bzr_dir, reconcile, - **kwargs): + def check_thorough_reweave_missing_revision(self, a_bzr_dir, reconcile, **kwargs): # actual low level test. repo = a_bzr_dir.open_repository() - if not repo.has_revision(b'missing'): + if not repo.has_revision(b"missing"): # the repo handles ghosts without corruption, so reconcile has # nothing to do here. 
Specifically, this test has the inventory # 'missing' present and the revision 'missing' missing, so clearly @@ -205,8 +207,7 @@ def check_thorough_reweave_missing_revision(self, a_bzr_dir, reconcile, expected_inconsistent_parents = 1 reconciler = reconcile(**kwargs) # some number of inconsistent parents should have been found - self.assertEqual(expected_inconsistent_parents, - reconciler.inconsistent_parents) + self.assertEqual(expected_inconsistent_parents, reconciler.inconsistent_parents) # and one garbage inventories self.assertEqual(1, reconciler.garbage_inventories) # now the backup should have it but not the current inventory @@ -214,46 +215,46 @@ def check_thorough_reweave_missing_revision(self, a_bzr_dir, reconcile, self.check_missing_was_removed(repo) # and the parent list for 'references_missing' should have that # revision a ghost now. - self.assertFalse(repo.has_revision(b'missing')) + self.assertFalse(repo.has_revision(b"missing")) def check_missing_was_removed(self, repo): if repo._reconcile_backsup_inventory: backed_up = False - for path in repo.control_transport.list_dir('.'): - if 'inventory.backup' in path: + for path in repo.control_transport.list_dir("."): + if "inventory.backup" in path: backed_up = True self.assertTrue(backed_up) # Not clear how to do this at an interface level: # self.assertTrue('missing' in backup.versions()) - self.assertRaises(errors.NoSuchRevision, repo.get_inventory, 'missing') + self.assertRaises(errors.NoSuchRevision, repo.get_inventory, "missing") def test_reweave_inventory_without_revision_reconciler(self): # smoke test for the all in one Reconciler class, # other tests use the lower level repo.reconcile() - d_url = self.get_url('inventory_without_revision_and_ghost') + d_url = self.get_url("inventory_without_revision_and_ghost") d = BzrDir.open(d_url) if not d.open_repository()._reconcile_does_inventory_gc: - raise TestSkipped('Irrelevant test') + raise TestSkipped("Irrelevant test") def reconcile(): reconciler = Reconciler(d) return reconciler.reconcile() + self.check_thorough_reweave_missing_revision(d, reconcile) def test_reweave_inventory_without_revision_and_ghost(self): # actual low level test. 
- d_url = self.get_url('inventory_without_revision_and_ghost') + d_url = self.get_url("inventory_without_revision_and_ghost") d = BzrDir.open(d_url) repo = d.open_repository() if not repo._reconcile_does_inventory_gc: - raise TestSkipped('Irrelevant test') + raise TestSkipped("Irrelevant test") # nothing should have been altered yet : inventories without # revisions are not data loss incurring for current format - self.check_thorough_reweave_missing_revision(d, repo.reconcile, - thorough=True) + self.check_thorough_reweave_missing_revision(d, repo.reconcile, thorough=True) def test_reweave_inventory_preserves_a_revision_with_ghosts(self): - d = BzrDir.open(self.get_url('inventory_one_ghost')) + d = BzrDir.open(self.get_url("inventory_one_ghost")) reconciler = d.open_repository().reconcile(thorough=True) # no inconsistent parents should have been found: # the lack of a parent for ghost is normal @@ -262,19 +263,18 @@ def test_reweave_inventory_preserves_a_revision_with_ghosts(self): self.assertEqual(0, reconciler.garbage_inventories) # now the current inventory should still have 'ghost' repo = d.open_repository() - repo.get_inventory(b'ghost') - self.assertThat([b'ghost', b'the_ghost'], - MatchesAncestry(repo, b'ghost')) + repo.get_inventory(b"ghost") + self.assertThat([b"ghost", b"the_ghost"], MatchesAncestry(repo, b"ghost")) def test_reweave_inventory_fixes_ancestryfor_a_present_ghost(self): - d = BzrDir.open(self.get_url('inventory_ghost_present')) + d = BzrDir.open(self.get_url("inventory_ghost_present")) repo = d.open_repository() - m = MatchesAncestry(repo, b'ghost') - if m.match([b'the_ghost', b'ghost']) is None: + m = MatchesAncestry(repo, b"ghost") + if m.match([b"the_ghost", b"ghost"]) is None: # the repo handles ghosts without corruption, so reconcile has # nothing to do return - self.assertThat([b'ghost'], m) + self.assertThat([b"ghost"], m) reconciler = repo.reconcile() # this is a data corrupting error, so a normal reconcile should fix it. 
# one inconsistent parents should have been found : the @@ -284,42 +284,43 @@ def test_reweave_inventory_fixes_ancestryfor_a_present_ghost(self): self.assertEqual(0, reconciler.garbage_inventories) # now the current inventory should still have 'ghost' repo = d.open_repository() - repo.get_inventory(b'ghost') - repo.get_inventory(b'the_ghost') - self.assertThat([b'the_ghost', b'ghost'], - MatchesAncestry(repo, b'ghost')) - self.assertThat([b'the_ghost'], - MatchesAncestry(repo, b'the_ghost')) + repo.get_inventory(b"ghost") + repo.get_inventory(b"the_ghost") + self.assertThat([b"the_ghost", b"ghost"], MatchesAncestry(repo, b"ghost")) + self.assertThat([b"the_ghost"], MatchesAncestry(repo, b"the_ghost")) def test_text_from_ghost_revision(self): - repo = self.make_repository('text-from-ghost') - inv = Inventory(revision_id=b'final-revid') - inv.root.revision = b'root-revid' - ie = inv.add_path('bla', 'file', b'myfileid') - ie.revision = b'ghostrevid' + repo = self.make_repository("text-from-ghost") + inv = Inventory(revision_id=b"final-revid") + inv.root.revision = b"root-revid" + ie = inv.add_path("bla", "file", b"myfileid") + ie.revision = b"ghostrevid" ie.text_size = 42 ie.text_sha1 = b"bee68c8acd989f5f1765b4660695275948bf5c00" - rev = breezy.revision.Revision(timestamp=0, - timezone=None, - committer="Foo Bar ", - properties={}, - message="Message", - parent_ids=[], - inventory_sha1=None, - revision_id=b'final-revid') + rev = breezy.revision.Revision( + timestamp=0, + timezone=None, + committer="Foo Bar ", + properties={}, + message="Message", + parent_ids=[], + inventory_sha1=None, + revision_id=b"final-revid", + ) with repo.lock_write(): repo.start_write_group() try: - repo.add_revision(b'final-revid', rev, inv) + repo.add_revision(b"final-revid", rev, inv) try: - repo.texts.add_lines((b'myfileid', b'ghostrevid'), - ((b'myfileid', b'ghost-text-parent'),), - [b"line1\n", b"line2\n"]) + repo.texts.add_lines( + (b"myfileid", b"ghostrevid"), + ((b"myfileid", b"ghost-text-parent"),), + [b"line1\n", b"line2\n"], + ) except errors.RevisionNotPresent as e: raise TestSkipped("text ghost parents not supported") from e if repo.supports_rich_root(): - repo.texts.add_lines((inv.root.file_id, inv.root.revision), - [], []) + repo.texts.add_lines((inv.root.file_id, inv.root.revision), [], []) finally: repo.commit_write_group() repo.reconcile(thorough=True) @@ -348,17 +349,16 @@ def setUp(self): # we should add a lower level api to allow constructing such cases. # first off the common logic: - self.first_tree = self.make_branch_and_tree('wrong-first-parent') - self.second_tree = self.make_branch_and_tree( - 'reversed-secondary-parents') + self.first_tree = self.make_branch_and_tree("wrong-first-parent") + self.second_tree = self.make_branch_and_tree("reversed-secondary-parents") for t in [self.first_tree, self.second_tree]: - t.commit('1', rev_id=b'1') + t.commit("1", rev_id=b"1") uncommit(t.branch, tree=t) - t.commit('2', rev_id=b'2') + t.commit("2", rev_id=b"2") uncommit(t.branch, tree=t) - t.commit('3', rev_id=b'3') + t.commit("3", rev_id=b"3") uncommit(t.branch, tree=t) - #second_tree = self.make_branch_and_tree('reversed-secondary-parents') + # second_tree = self.make_branch_and_tree('reversed-secondary-parents') # second_tree.pull(tree) # XXX won't copy the repo? 
repo_secondary = self.second_tree.branch.repository @@ -366,21 +366,23 @@ def setUp(self): repo = self.first_tree.branch.repository repo.lock_write() repo.start_write_group() - inv = Inventory(revision_id=b'wrong-first-parent') - inv.root.revision = b'wrong-first-parent' + inv = Inventory(revision_id=b"wrong-first-parent") + inv.root.revision = b"wrong-first-parent" if repo.supports_rich_root(): root_id = inv.root.file_id - repo.texts.add_lines((root_id, b'wrong-first-parent'), [], []) - sha1 = repo.add_inventory(b'wrong-first-parent', inv, [b'2', b'1']) - rev = Revision(timestamp=0, - timezone=0, - committer="Foo Bar ", - message="Message", - inventory_sha1=sha1, - properties={}, - parent_ids=[b'1', b'2'], - revision_id=b'wrong-first-parent') - repo.add_revision(b'wrong-first-parent', rev) + repo.texts.add_lines((root_id, b"wrong-first-parent"), [], []) + sha1 = repo.add_inventory(b"wrong-first-parent", inv, [b"2", b"1"]) + rev = Revision( + timestamp=0, + timezone=0, + committer="Foo Bar ", + message="Message", + inventory_sha1=sha1, + properties={}, + parent_ids=[b"1", b"2"], + revision_id=b"wrong-first-parent", + ) + repo.add_revision(b"wrong-first-parent", rev) repo.commit_write_group() repo.unlock() @@ -388,22 +390,23 @@ def setUp(self): repo = repo_secondary repo.lock_write() repo.start_write_group() - inv = Inventory(revision_id=b'wrong-secondary-parent') - inv.root.revision = b'wrong-secondary-parent' + inv = Inventory(revision_id=b"wrong-secondary-parent") + inv.root.revision = b"wrong-secondary-parent" if repo.supports_rich_root(): root_id = inv.root.file_id - repo.texts.add_lines((root_id, b'wrong-secondary-parent'), [], []) - sha1 = repo.add_inventory( - b'wrong-secondary-parent', inv, [b'1', b'3', b'2']) - rev = Revision(timestamp=0, - timezone=None, - committer="Foo Bar ", - message="Message", - inventory_sha1=sha1, - parent_ids=[b'1', b'2', b'3'], - properties={}, - revision_id=b'wrong-secondary-parent') - repo.add_revision(b'wrong-secondary-parent', rev) + repo.texts.add_lines((root_id, b"wrong-secondary-parent"), [], []) + sha1 = repo.add_inventory(b"wrong-secondary-parent", inv, [b"1", b"3", b"2"]) + rev = Revision( + timestamp=0, + timezone=None, + committer="Foo Bar ", + message="Message", + inventory_sha1=sha1, + parent_ids=[b"1", b"2", b"3"], + properties={}, + revision_id=b"wrong-secondary-parent", + ) + repo.add_revision(b"wrong-secondary-parent", rev) repo.commit_write_group() repo.unlock() @@ -412,10 +415,11 @@ def test_reconcile_wrong_order(self): repo = self.first_tree.branch.repository with repo.lock_read(): g = repo.get_graph() - if g.get_parent_map([b'wrong-first-parent'])[b'wrong-first-parent'] \ - == (b'1', b'2'): - raise TestSkipped( - 'wrong-first-parent is not setup for testing') + if g.get_parent_map([b"wrong-first-parent"])[b"wrong-first-parent"] == ( + b"1", + b"2", + ): + raise TestSkipped("wrong-first-parent is not setup for testing") self.checkUnreconciled(repo.controldir, repo.reconcile()) # nothing should have been altered yet : inventories without # revisions are not data loss incurring for current format @@ -429,8 +433,9 @@ def test_reconcile_wrong_order(self): self.addCleanup(repo.unlock) g = repo.get_graph() self.assertEqual( - {b'wrong-first-parent': (b'1', b'2')}, - g.get_parent_map([b'wrong-first-parent'])) + {b"wrong-first-parent": (b"1", b"2")}, + g.get_parent_map([b"wrong-first-parent"]), + ) def test_reconcile_wrong_order_secondary_inventory(self): # a wrong order in the parents for inventories is ignored. 
diff --git a/breezy/bzr/tests/per_repository_vf/test_refresh_data.py b/breezy/bzr/tests/per_repository_vf/test_refresh_data.py index c0b4b4f768..2091763de9 100644 --- a/breezy/bzr/tests/per_repository_vf/test_refresh_data.py +++ b/breezy/bzr/tests/per_repository_vf/test_refresh_data.py @@ -28,18 +28,17 @@ class TestRefreshData(TestCaseWithRepository): - scenarios = all_repository_vf_format_scenarios() def fetch_new_revision_into_concurrent_instance(self, repo, token): """Create a new revision (revid 'new-rev') and fetch it into a concurrent instance of repo. """ - source = self.make_branch_and_memory_tree('source') + source = self.make_branch_and_memory_tree("source") source.lock_write() self.addCleanup(source.unlock) - source.add([''], [b'root-id']) - revid = source.commit('foo', rev_id=b'new-rev') + source.add([""], [b"root-id"]) + revid = source.commit("foo", rev_id=b"new-rev") # Force data reading on weaves/knits repo.all_revision_ids() repo.revisions.keys() @@ -50,19 +49,21 @@ def fetch_new_revision_into_concurrent_instance(self, repo, token): try: server_repo.lock_write(token) except errors.TokenLockingNotSupported: - self.skipTest('Cannot concurrently insert into repo format %r' - % self.repository_format) + self.skipTest( + "Cannot concurrently insert into repo format %r" + % self.repository_format + ) try: server_repo.fetch(source.branch.repository, revid) finally: server_repo.unlock() def test_refresh_data_after_fetch_new_data_visible_in_write_group(self): - tree = self.make_branch_and_memory_tree('target') + tree = self.make_branch_and_memory_tree("target") tree.lock_write() self.addCleanup(tree.unlock) - tree.add([''], ids=[b'root-id']) - tree.commit('foo', rev_id=b'commit-in-target') + tree.add([""], ids=[b"root-id"]) + tree.commit("foo", rev_id=b"commit-in-target") repo = tree.branch.repository token = repo.lock_write().repository_token self.addCleanup(repo.unlock) @@ -77,14 +78,13 @@ def test_refresh_data_after_fetch_new_data_visible_in_write_group(self): pass else: self.assertEqual( - [b'commit-in-target', b'new-rev'], - sorted(repo.all_revision_ids())) + [b"commit-in-target", b"new-rev"], sorted(repo.all_revision_ids()) + ) def test_refresh_data_after_fetch_new_data_visible(self): - repo = self.make_repository('target') + repo = self.make_repository("target") token = repo.lock_write().repository_token self.addCleanup(repo.unlock) self.fetch_new_revision_into_concurrent_instance(repo, token) repo.refresh_data() - self.assertNotEqual( - {}, repo.get_graph().get_parent_map([b'new-rev'])) + self.assertNotEqual({}, repo.get_graph().get_parent_map([b"new-rev"])) diff --git a/breezy/bzr/tests/per_repository_vf/test_repository.py b/breezy/bzr/tests/per_repository_vf/test_repository.py index b85b8ceebf..3ae729e668 100644 --- a/breezy/bzr/tests/per_repository_vf/test_repository.py +++ b/breezy/bzr/tests/per_repository_vf/test_repository.py @@ -32,170 +32,180 @@ class TestRepository(TestCaseWithRepository): - scenarios = all_repository_vf_format_scenarios() def assertFormatAttribute(self, attribute, allowed_values): """Assert that the format has an attribute 'attribute'.""" - repo = self.make_repository('repo') + repo = self.make_repository("repo") self.assertSubset([getattr(repo._format, attribute)], allowed_values) def test_attribute__fetch_order(self): """Test the _fetch_order attribute.""" - self.assertFormatAttribute( - '_fetch_order', ('topological', 'unordered')) + self.assertFormatAttribute("_fetch_order", ("topological", "unordered")) def 
test_attribute__fetch_uses_deltas(self): """Test the _fetch_uses_deltas attribute.""" - self.assertFormatAttribute('_fetch_uses_deltas', (True, False)) + self.assertFormatAttribute("_fetch_uses_deltas", (True, False)) def test_attribute_inventories_store(self): """Test the existence of the inventories attribute.""" - tree = self.make_branch_and_tree('tree') + tree = self.make_branch_and_tree("tree") repo = tree.branch.repository self.assertIsInstance(repo.inventories, versionedfile.VersionedFiles) def test_attribute_inventories_basics(self): """Test basic aspects of the inventories attribute.""" - tree = self.make_branch_and_tree('tree') + tree = self.make_branch_and_tree("tree") repo = tree.branch.repository - rev_id = (tree.commit('a'),) + rev_id = (tree.commit("a"),) tree.lock_read() self.addCleanup(tree.unlock) self.assertEqual({rev_id}, set(repo.inventories.keys())) def test_attribute_revision_store(self): """Test the existence of the revisions attribute.""" - tree = self.make_branch_and_tree('tree') + tree = self.make_branch_and_tree("tree") repo = tree.branch.repository - self.assertIsInstance(repo.revisions, - versionedfile.VersionedFiles) + self.assertIsInstance(repo.revisions, versionedfile.VersionedFiles) def test_attribute_revision_store_basics(self): """Test the basic behaviour of the revisions attribute.""" - tree = self.make_branch_and_tree('tree') + tree = self.make_branch_and_tree("tree") repo = tree.branch.repository repo.lock_write() try: self.assertEqual(set(), set(repo.revisions.keys())) revid = (tree.commit("foo"),) self.assertEqual({revid}, set(repo.revisions.keys())) - self.assertEqual({revid: ()}, - repo.revisions.get_parent_map([revid])) + self.assertEqual({revid: ()}, repo.revisions.get_parent_map([revid])) finally: repo.unlock() - tree2 = self.make_branch_and_tree('tree2') + tree2 = self.make_branch_and_tree("tree2") tree2.pull(tree.branch) - left_id = (tree2.commit('left'),) - right_id = (tree.commit('right'),) + left_id = (tree2.commit("left"),) + right_id = (tree.commit("right"),) tree.merge_from_branch(tree2.branch) - merge_id = (tree.commit('merged'),) + merge_id = (tree.commit("merged"),) repo.lock_read() self.addCleanup(repo.unlock) - self.assertEqual({revid, left_id, right_id, merge_id}, - set(repo.revisions.keys())) - self.assertEqual({revid: (), left_id: (revid,), right_id: (revid,), - merge_id: (right_id, left_id)}, - repo.revisions.get_parent_map(repo.revisions.keys())) + self.assertEqual( + {revid, left_id, right_id, merge_id}, set(repo.revisions.keys()) + ) + self.assertEqual( + { + revid: (), + left_id: (revid,), + right_id: (revid,), + merge_id: (right_id, left_id), + }, + repo.revisions.get_parent_map(repo.revisions.keys()), + ) def test_attribute_signature_store(self): """Test the existence of the signatures attribute.""" - tree = self.make_branch_and_tree('tree') + tree = self.make_branch_and_tree("tree") repo = tree.branch.repository - self.assertIsInstance(repo.signatures, - versionedfile.VersionedFiles) + self.assertIsInstance(repo.signatures, versionedfile.VersionedFiles) def test_exposed_versioned_files_are_marked_dirty(self): - repo = self.make_repository('.') + repo = self.make_repository(".") repo.lock_write() signatures = repo.signatures revisions = repo.revisions inventories = repo.inventories repo.unlock() - self.assertRaises(errors.ObjectNotLocked, - signatures.keys) - self.assertRaises(errors.ObjectNotLocked, - revisions.keys) - self.assertRaises(errors.ObjectNotLocked, - inventories.keys) - 
self.assertRaises(errors.ObjectNotLocked, - signatures.add_lines, ('foo',), [], []) - self.assertRaises(errors.ObjectNotLocked, - revisions.add_lines, ('foo',), [], []) - self.assertRaises(errors.ObjectNotLocked, - inventories.add_lines, ('foo',), [], []) + self.assertRaises(errors.ObjectNotLocked, signatures.keys) + self.assertRaises(errors.ObjectNotLocked, revisions.keys) + self.assertRaises(errors.ObjectNotLocked, inventories.keys) + self.assertRaises( + errors.ObjectNotLocked, signatures.add_lines, ("foo",), [], [] + ) + self.assertRaises(errors.ObjectNotLocked, revisions.add_lines, ("foo",), [], []) + self.assertRaises( + errors.ObjectNotLocked, inventories.add_lines, ("foo",), [], [] + ) def test__get_sink(self): - repo = self.make_repository('repo') + repo = self.make_repository("repo") sink = repo._get_sink() self.assertIsInstance(sink, vf_repository.StreamSink) def test_get_serializer_format(self): - repo = self.make_repository('.') + repo = self.make_repository(".") format = repo.get_serializer_format() self.assertEqual(repo._inventory_serializer.format_num, format) def test_add_revision_inventory_sha1(self): - inv = inventory.Inventory(revision_id=b'A') - root = inventory.InventoryDirectory(b'fixed-root', '', None) - root.revision = b'A' + inv = inventory.Inventory(revision_id=b"A") + root = inventory.InventoryDirectory(b"fixed-root", "", None) + root.revision = b"A" inv.add(root) # Insert the inventory on its own to an identical repository, to get # its sha1. - reference_repo = self.make_repository('reference_repo') + reference_repo = self.make_repository("reference_repo") reference_repo.lock_write() reference_repo.start_write_group() - inv_sha1 = reference_repo.add_inventory(b'A', inv, []) + inv_sha1 = reference_repo.add_inventory(b"A", inv, []) reference_repo.abort_write_group() reference_repo.unlock() # Now insert a revision with this inventory, and it should get the same # sha1. 
- repo = self.make_repository('repo') + repo = self.make_repository("repo") repo.lock_write() repo.start_write_group() - repo.texts.add_lines((b'fixed-root', b'A'), [], []) - repo.add_revision(b'A', _mod_revision.Revision( - b'A', committer='B', timestamp=0, - timezone=0, message='C', parent_ids=[], properties={}, inventory_sha1=None), inv=inv) + repo.texts.add_lines((b"fixed-root", b"A"), [], []) + repo.add_revision( + b"A", + _mod_revision.Revision( + b"A", + committer="B", + timestamp=0, + timezone=0, + message="C", + parent_ids=[], + properties={}, + inventory_sha1=None, + ), + inv=inv, + ) repo.commit_write_group() repo.unlock() repo.lock_read() - self.assertEqual(inv_sha1, repo.get_revision(b'A').inventory_sha1) + self.assertEqual(inv_sha1, repo.get_revision(b"A").inventory_sha1) repo.unlock() def test_install_revisions(self): - wt = self.make_branch_and_tree('source') - wt.commit('A', allow_pointless=True, rev_id=b'A') + wt = self.make_branch_and_tree("source") + wt.commit("A", allow_pointless=True, rev_id=b"A") repo = wt.branch.repository repo.lock_write() repo.start_write_group() - repo.sign_revision(b'A', gpg.LoopbackGPGStrategy(None)) + repo.sign_revision(b"A", gpg.LoopbackGPGStrategy(None)) repo.commit_write_group() repo.unlock() repo.lock_read() self.addCleanup(repo.unlock) - repo2 = self.make_repository('repo2') - revision = repo.get_revision(b'A') - tree = repo.revision_tree(b'A') - signature = repo.get_signature_text(b'A') + repo2 = self.make_repository("repo2") + revision = repo.get_revision(b"A") + tree = repo.revision_tree(b"A") + signature = repo.get_signature_text(b"A") repo2.lock_write() self.addCleanup(repo2.unlock) vf_repository.install_revisions(repo2, [(revision, tree, signature)]) - self.assertEqual(revision, repo2.get_revision(b'A')) - self.assertEqual(signature, repo2.get_signature_text(b'A')) + self.assertEqual(revision, repo2.get_revision(b"A")) + self.assertEqual(signature, repo2.get_signature_text(b"A")) def test_attribute_text_store(self): """Test the existence of the texts attribute.""" - tree = self.make_branch_and_tree('tree') + tree = self.make_branch_and_tree("tree") repo = tree.branch.repository - self.assertIsInstance(repo.texts, - versionedfile.VersionedFiles) + self.assertIsInstance(repo.texts, versionedfile.VersionedFiles) def test_iter_inventories_is_ordered(self): # just a smoke test - tree = self.make_branch_and_tree('a') - first_revision = tree.commit('') - second_revision = tree.commit('') + tree = self.make_branch_and_tree("a") + first_revision = tree.commit("") + second_revision = tree.commit("") tree.lock_read() self.addCleanup(tree.unlock) revs = (first_revision, second_revision) @@ -205,19 +215,19 @@ def test_iter_inventories_is_ordered(self): def test_item_keys_introduced_by(self): # Make a repo with one revision and one versioned file. 
- tree = self.make_branch_and_tree('t') - self.build_tree(['t/foo']) - tree.add('foo', ids=b'file1') - tree.commit('message', rev_id=b'rev_id') + tree = self.make_branch_and_tree("t") + self.build_tree(["t/foo"]) + tree.add("foo", ids=b"file1") + tree.commit("message", rev_id=b"rev_id") repo = tree.branch.repository repo.lock_write() repo.start_write_group() try: - repo.sign_revision(b'rev_id', gpg.LoopbackGPGStrategy(None)) + repo.sign_revision(b"rev_id", gpg.LoopbackGPGStrategy(None)) except errors.UnsupportedOperation: signature_texts = [] else: - signature_texts = [b'rev_id'] + signature_texts = [b"rev_id"] repo.commit_write_group() repo.unlock() repo.lock_read() @@ -230,14 +240,15 @@ def test_item_keys_introduced_by(self): # * signatures # * revisions expected_item_keys = [ - ('file', b'file1', [b'rev_id']), - ('inventory', None, [b'rev_id']), - ('signatures', None, signature_texts), - ('revisions', None, [b'rev_id'])] - item_keys = list(repo.item_keys_introduced_by([b'rev_id'])) + ("file", b"file1", [b"rev_id"]), + ("inventory", None, [b"rev_id"]), + ("signatures", None, signature_texts), + ("revisions", None, [b"rev_id"]), + ] + item_keys = list(repo.item_keys_introduced_by([b"rev_id"])) item_keys = [ - (kind, file_id, list(versions)) - for (kind, file_id, versions) in item_keys] + (kind, file_id, list(versions)) for (kind, file_id, versions) in item_keys + ] if repo.supports_rich_root(): # Check for the root versioned file in the item_keys, then remove @@ -245,8 +256,8 @@ def test_item_keys_introduced_by(self): # expected_record_names. # Note that the file keys can be in any order, so this test is # written to allow that. - inv = repo.get_inventory(b'rev_id') - root_item_key = ('file', inv.root.file_id, [b'rev_id']) + inv = repo.get_inventory(b"rev_id") + root_item_key = ("file", inv.root.file_id, [b"rev_id"]) self.assertIn(root_item_key, item_keys) item_keys.remove(root_item_key) @@ -254,23 +265,23 @@ def test_item_keys_introduced_by(self): def test_attribute_text_store_basics(self): """Test the basic behaviour of the text store.""" - tree = self.make_branch_and_tree('tree') + tree = self.make_branch_and_tree("tree") repo = tree.branch.repository file_id = b"Foo:Bar" file_key = (file_id,) with tree.lock_write(): self.assertEqual(set(), set(repo.texts.keys())) - tree.add(['foo'], ['file'], [file_id]) - tree.put_file_bytes_non_atomic( - 'foo', b'content\n') + tree.add(["foo"], ["file"], [file_id]) + tree.put_file_bytes_non_atomic("foo", b"content\n") try: rev_key = (tree.commit("foo"),) except errors.IllegalPath as e: raise tests.TestNotApplicable( - f'file_id {file_id!r} cannot be stored on this' - ' platform for this repo format') from e + f"file_id {file_id!r} cannot be stored on this" + " platform for this repo format" + ) from e if repo._format.rich_root_data: - root_commit = (tree.path2id(''),) + rev_key + root_commit = (tree.path2id(""),) + rev_key keys = {root_commit} parents = {root_commit: ()} else: @@ -279,28 +290,26 @@ def test_attribute_text_store_basics(self): keys.add(file_key + rev_key) parents[file_key + rev_key] = () self.assertEqual(keys, set(repo.texts.keys())) - self.assertEqual(parents, - repo.texts.get_parent_map(repo.texts.keys())) - tree2 = self.make_branch_and_tree('tree2') + self.assertEqual(parents, repo.texts.get_parent_map(repo.texts.keys())) + tree2 = self.make_branch_and_tree("tree2") tree2.pull(tree.branch) - tree2.put_file_bytes_non_atomic('foo', b'right\n') - right_key = (tree2.commit('right'),) + tree2.put_file_bytes_non_atomic("foo", b"right\n") 
+ right_key = (tree2.commit("right"),) keys.add(file_key + right_key) parents[file_key + right_key] = (file_key + rev_key,) - tree.put_file_bytes_non_atomic('foo', b'left\n') - left_key = (tree.commit('left'),) + tree.put_file_bytes_non_atomic("foo", b"left\n") + left_key = (tree.commit("left"),) keys.add(file_key + left_key) parents[file_key + left_key] = (file_key + rev_key,) tree.merge_from_branch(tree2.branch) - tree.put_file_bytes_non_atomic('foo', b'merged\n') + tree.put_file_bytes_non_atomic("foo", b"merged\n") try: tree.auto_resolve() except errors.UnsupportedOperation: pass - merge_key = (tree.commit('merged'),) + merge_key = (tree.commit("merged"),) keys.add(file_key + merge_key) - parents[file_key + merge_key] = (file_key + left_key, - file_key + right_key) + parents[file_key + merge_key] = (file_key + left_key, file_key + right_key) repo.lock_read() self.addCleanup(repo.unlock) self.assertEqual(keys, set(repo.texts.keys())) @@ -308,36 +317,36 @@ def test_attribute_text_store_basics(self): class TestCaseWithComplexRepository(TestCaseWithRepository): - scenarios = all_repository_vf_format_scenarios() def setUp(self): super().setUp() - tree_a = self.make_branch_and_tree('a') + tree_a = self.make_branch_and_tree("a") self.controldir = tree_a.branch.controldir # add a corrupt inventory 'orphan' # this may need some generalising for knits. with tree_a.lock_write(), _mod_repository.WriteGroup(tree_a.branch.repository): inv_file = tree_a.branch.repository.inventories - inv_file.add_lines((b'orphan',), [], []) + inv_file.add_lines((b"orphan",), [], []) # add a real revision 'rev1' - tree_a.commit('rev1', rev_id=b'rev1', allow_pointless=True) + tree_a.commit("rev1", rev_id=b"rev1", allow_pointless=True) # add a real revision 'rev2' based on rev1 - tree_a.commit('rev2', rev_id=b'rev2', allow_pointless=True) + tree_a.commit("rev2", rev_id=b"rev2", allow_pointless=True) # add a reference to a ghost - tree_a.add_parent_tree_id(b'ghost1') + tree_a.add_parent_tree_id(b"ghost1") try: - tree_a.commit('rev3', rev_id=b'rev3', allow_pointless=True) + tree_a.commit("rev3", rev_id=b"rev3", allow_pointless=True) except errors.RevisionNotPresent as e: raise tests.TestNotApplicable( - "Cannot test with ghosts for this format.") from e + "Cannot test with ghosts for this format." + ) from e # add another reference to a ghost, and a second ghost. 
- tree_a.add_parent_tree_id(b'ghost1') - tree_a.add_parent_tree_id(b'ghost2') - tree_a.commit('rev4', rev_id=b'rev4', allow_pointless=True) + tree_a.add_parent_tree_id(b"ghost1") + tree_a.add_parent_tree_id(b"ghost2") + tree_a.commit("rev4", rev_id=b"rev4", allow_pointless=True) def test_revision_trees(self): - revision_ids = [b'rev1', b'rev2', b'rev3', b'rev4'] + revision_ids = [b"rev1", b"rev2", b"rev3", b"rev4"] repository = self.controldir.open_repository() repository.lock_read() self.addCleanup(repository.unlock) @@ -351,93 +360,112 @@ def test_get_revision_deltas(self): repository = self.controldir.open_repository() repository.lock_read() self.addCleanup(repository.unlock) - revisions = [repository.get_revision(r) for r in - [b'rev1', b'rev2', b'rev3', b'rev4']] + revisions = [ + repository.get_revision(r) for r in [b"rev1", b"rev2", b"rev3", b"rev4"] + ] deltas1 = list(repository.get_revision_deltas(revisions)) - deltas2 = [repository.get_revision_delta(r.revision_id) - for r in revisions] + deltas2 = [repository.get_revision_delta(r.revision_id) for r in revisions] self.assertEqual(deltas1, deltas2) def test_all_revision_ids(self): # all_revision_ids -> all revisions - self.assertEqual({b'rev1', b'rev2', b'rev3', b'rev4'}, - set(self.controldir.open_repository().all_revision_ids())) + self.assertEqual( + {b"rev1", b"rev2", b"rev3", b"rev4"}, + set(self.controldir.open_repository().all_revision_ids()), + ) def test_reserved_id(self): - repo = self.make_repository('repository') + repo = self.make_repository("repository") with repo.lock_write(), _mod_repository.WriteGroup(repo): - self.assertRaises(errors.ReservedId, repo.add_inventory, - b'reserved:', None, None) - self.assertRaises(errors.ReservedId, repo.add_inventory_by_delta, - "foo", [], b'reserved:', None) self.assertRaises( - errors.ReservedId, repo.add_revision, b'reserved:', None) + errors.ReservedId, repo.add_inventory, b"reserved:", None, None + ) + self.assertRaises( + errors.ReservedId, + repo.add_inventory_by_delta, + "foo", + [], + b"reserved:", + None, + ) + self.assertRaises(errors.ReservedId, repo.add_revision, b"reserved:", None) class TestCaseWithCorruptRepository(TestCaseWithRepository): - scenarios = all_repository_vf_format_scenarios() def setUp(self): super().setUp() # a inventory with no parents and the revision has parents.. # i.e. a ghost. 
- repo = self.make_repository('inventory_with_unnecessary_ghost') + repo = self.make_repository("inventory_with_unnecessary_ghost") repo.lock_write() repo.start_write_group() - inv = inventory.Inventory(revision_id=b'ghost') - inv.root.revision = b'ghost' + inv = inventory.Inventory(revision_id=b"ghost") + inv.root.revision = b"ghost" if repo.supports_rich_root(): root_id = inv.root.file_id - repo.texts.add_lines((root_id, b'ghost'), [], []) - sha1 = repo.add_inventory(b'ghost', inv, []) + repo.texts.add_lines((root_id, b"ghost"), [], []) + sha1 = repo.add_inventory(b"ghost", inv, []) rev = _mod_revision.Revision( - timestamp=0, timezone=None, committer="Foo Bar ", - message="Message", inventory_sha1=sha1, revision_id=b'ghost', - parent_ids=[b'the_ghost'], properties={}) + timestamp=0, + timezone=None, + committer="Foo Bar ", + message="Message", + inventory_sha1=sha1, + revision_id=b"ghost", + parent_ids=[b"the_ghost"], + properties={}, + ) try: - repo.add_revision(b'ghost', rev) + repo.add_revision(b"ghost", rev) except (errors.NoSuchRevision, errors.RevisionNotPresent) as e: raise tests.TestNotApplicable( - "Cannot test with ghosts for this format.") from e + "Cannot test with ghosts for this format." + ) from e - inv = inventory.Inventory(revision_id=b'the_ghost') - inv.root.revision = b'the_ghost' + inv = inventory.Inventory(revision_id=b"the_ghost") + inv.root.revision = b"the_ghost" if repo.supports_rich_root(): root_id = inv.root.file_id - repo.texts.add_lines((root_id, b'the_ghost'), [], []) - sha1 = repo.add_inventory(b'the_ghost', inv, []) + repo.texts.add_lines((root_id, b"the_ghost"), [], []) + sha1 = repo.add_inventory(b"the_ghost", inv, []) rev = _mod_revision.Revision( - timestamp=0, timezone=None, committer="Foo Bar ", - message="Message", inventory_sha1=sha1, revision_id=b'the_ghost', + timestamp=0, + timezone=None, + committer="Foo Bar ", + message="Message", + inventory_sha1=sha1, + revision_id=b"the_ghost", properties={}, - parent_ids=[]) - repo.add_revision(b'the_ghost', rev) + parent_ids=[], + ) + repo.add_revision(b"the_ghost", rev) # check its setup usefully inv_weave = repo.inventories - possible_parents = (None, ((b'ghost',),)) - self.assertSubset(inv_weave.get_parent_map([(b'ghost',)])[(b'ghost',)], - possible_parents) + possible_parents = (None, ((b"ghost",),)) + self.assertSubset( + inv_weave.get_parent_map([(b"ghost",)])[(b"ghost",)], possible_parents + ) repo.commit_write_group() repo.unlock() def test_corrupt_revision_access_asserts_if_reported_wrong(self): - repo_url = self.get_url('inventory_with_unnecessary_ghost') + repo_url = self.get_url("inventory_with_unnecessary_ghost") repo = _mod_repository.Repository.open(repo_url) - m = MatchesAncestry(repo, b'ghost') + m = MatchesAncestry(repo, b"ghost") reported_wrong = False try: - if m.match([b'the_ghost', b'ghost']) is not None: + if m.match([b"the_ghost", b"ghost"]) is not None: reported_wrong = True except errors.CorruptRepository: # caught the bad data: return if not reported_wrong: return - self.assertRaises(errors.CorruptRepository, - repo.get_revision, b'ghost') + self.assertRaises(errors.CorruptRepository, repo.get_revision, b"ghost") def test_corrupt_revision_get_revision_reconcile(self): - repo_url = self.get_url('inventory_with_unnecessary_ghost') + repo_url = self.get_url("inventory_with_unnecessary_ghost") repo = _mod_repository.Repository.open(repo_url) - repo.get_revision_reconcile(b'ghost') + repo.get_revision_reconcile(b"ghost") diff --git 
a/breezy/bzr/tests/per_repository_vf/test_write_group.py b/breezy/bzr/tests/per_repository_vf/test_write_group.py index 29d10a4e8a..254f3df759 100644 --- a/breezy/bzr/tests/per_repository_vf/test_write_group.py +++ b/breezy/bzr/tests/per_repository_vf/test_write_group.py @@ -31,12 +31,11 @@ class TestGetMissingParentInventories(TestCaseWithRepository): - scenarios = all_repository_vf_format_scenarios() def test_empty_get_missing_parent_inventories(self): """A new write group has no missing parent inventories.""" - repo = self.make_repository('.') + repo = self.make_repository(".") repo.lock_write() repo.start_write_group() try: @@ -46,46 +45,50 @@ def test_empty_get_missing_parent_inventories(self): repo.unlock() def branch_trunk_and_make_tree(self, trunk_repo, relpath): - tree = self.make_branch_and_memory_tree('branch') + tree = self.make_branch_and_memory_tree("branch") trunk_repo.lock_read() self.addCleanup(trunk_repo.unlock) - tree.branch.repository.fetch(trunk_repo, revision_id=b'rev-1') - tree.set_parent_ids([b'rev-1']) + tree.branch.repository.fetch(trunk_repo, revision_id=b"rev-1") + tree.set_parent_ids([b"rev-1"]) return tree def make_first_commit(self, repo): trunk = repo.controldir.create_branch() tree = memorytree.MemoryTree.create_on_branch(trunk) tree.lock_write() - tree.add([''], ['directory'], [b'TREE_ROOT']) - tree.add(['dir'], ['directory'], [b'dir-id']) - tree.add(['filename'], ['file'], [b'file-id'], ) - tree.put_file_bytes_non_atomic('filename', b'content\n') - tree.commit('Trunk commit', rev_id=b'rev-0') - tree.commit('Trunk commit', rev_id=b'rev-1') + tree.add([""], ["directory"], [b"TREE_ROOT"]) + tree.add(["dir"], ["directory"], [b"dir-id"]) + tree.add( + ["filename"], + ["file"], + [b"file-id"], + ) + tree.put_file_bytes_non_atomic("filename", b"content\n") + tree.commit("Trunk commit", rev_id=b"rev-0") + tree.commit("Trunk commit", rev_id=b"rev-1") tree.unlock() def make_new_commit_in_new_repo(self, trunk_repo, parents=None): - tree = self.branch_trunk_and_make_tree(trunk_repo, 'branch') + tree = self.branch_trunk_and_make_tree(trunk_repo, "branch") tree.set_parent_ids(parents) - tree.commit('Branch commit', rev_id=b'rev-2') + tree.commit("Branch commit", rev_id=b"rev-2") branch_repo = tree.branch.repository branch_repo.lock_read() self.addCleanup(branch_repo.unlock) return branch_repo - def make_stackable_repo(self, relpath='trunk'): + def make_stackable_repo(self, relpath="trunk"): if isinstance(self.repository_format, remote.RemoteRepositoryFormat): # RemoteRepository by default builds a default format real # repository, but the default format is unstackble. So explicitly # make a stackable real repository and use that. - repo = self.make_repository(relpath, format='1.9') + repo = self.make_repository(relpath, format="1.9") dir = controldir.ControlDir.open(self.get_url(relpath)) repo = dir.open_repository() else: repo = self.make_repository(relpath) if not repo._format.supports_external_lookups: - raise tests.TestNotApplicable('format not stackable') + raise tests.TestNotApplicable("format not stackable") repo.controldir._format.set_branch_format(bzrbranch.BzrBranchFormat7()) return repo @@ -117,28 +120,32 @@ def test_ghost_revision(self): self.addCleanup(trunk_repo.unlock) # Branch the trunk, add a new commit. 
branch_repo = self.make_new_commit_in_new_repo( - trunk_repo, parents=[b'rev-1', b'ghost-rev']) - inv = branch_repo.get_inventory(b'rev-2') + trunk_repo, parents=[b"rev-1", b"ghost-rev"] + ) + inv = branch_repo.get_inventory(b"rev-2") # Make a new repo stacked on trunk, and then copy into it: # - all texts in rev-2 # - the new inventory (rev-2) # - the new revision (rev-2) - repo = self.make_stackable_repo('stacked') + repo = self.make_stackable_repo("stacked") repo.lock_write() repo.start_write_group() # Add all texts from in rev-2 inventory. Note that this has to exclude # the root if the repo format does not support rich roots. rich_root = branch_repo._format.rich_root_data all_texts = [ - (ie.file_id, ie.revision) for (_n, ie) in inv.iter_entries() - if rich_root or inv.id2path(ie.file_id) != ''] + (ie.file_id, ie.revision) + for (_n, ie) in inv.iter_entries() + if rich_root or inv.id2path(ie.file_id) != "" + ] repo.texts.insert_record_stream( - branch_repo.texts.get_record_stream(all_texts, 'unordered', False)) + branch_repo.texts.get_record_stream(all_texts, "unordered", False) + ) # Add inventory and revision for rev-2. - repo.add_inventory(b'rev-2', inv, [b'rev-1', b'ghost-rev']) + repo.add_inventory(b"rev-2", inv, [b"rev-1", b"ghost-rev"]) repo.revisions.insert_record_stream( - branch_repo.revisions.get_record_stream( - [(b'rev-2',)], 'unordered', False)) + branch_repo.revisions.get_record_stream([(b"rev-2",)], "unordered", False) + ) # Now, no inventories are reported as missing, even though there is a # ghost. self.assertEqual(set(), repo.get_missing_parent_inventories()) @@ -163,67 +170,75 @@ def test_get_missing_parent_inventories(self): trunk_repo.lock_read() self.addCleanup(trunk_repo.unlock) # Branch the trunk, add a new commit. - branch_repo = self.make_new_commit_in_new_repo( - trunk_repo, parents=[b'rev-1']) - inv = branch_repo.get_inventory(b'rev-2') + branch_repo = self.make_new_commit_in_new_repo(trunk_repo, parents=[b"rev-1"]) + inv = branch_repo.get_inventory(b"rev-2") # Make a new repo stacked on trunk, and copy the new commit's revision # and inventory records to it. - repo = self.make_stackable_repo('stacked') + repo = self.make_stackable_repo("stacked") repo.lock_write() repo.start_write_group() # Insert a single fulltext inv (using add_inventory because it's # simpler than insert_record_stream) - repo.add_inventory(b'rev-2', inv, [b'rev-1']) + repo.add_inventory(b"rev-2", inv, [b"rev-1"]) repo.revisions.insert_record_stream( - branch_repo.revisions.get_record_stream( - [(b'rev-2',)], 'unordered', False)) + branch_repo.revisions.get_record_stream([(b"rev-2",)], "unordered", False) + ) # There should be no missing compression parents - self.assertEqual(set(), - repo.inventories.get_missing_compression_parent_keys()) + self.assertEqual(set(), repo.inventories.get_missing_compression_parent_keys()) self.assertEqual( - {('inventories', b'rev-1')}, - repo.get_missing_parent_inventories()) + {("inventories", b"rev-1")}, repo.get_missing_parent_inventories() + ) # Resuming the write group does not affect # get_missing_parent_inventories. reopened_repo = self.reopen_repo_and_resume_write_group(repo) self.assertEqual( - {('inventories', b'rev-1')}, - reopened_repo.get_missing_parent_inventories()) + {("inventories", b"rev-1")}, reopened_repo.get_missing_parent_inventories() + ) # Adding the parent inventory satisfies get_missing_parent_inventories. 
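Condensed, the stacked-repository case that test_get_missing_parent_inventories checks looks like the sketch below; it reuses the names from the hunk above and assumes repo is a write-locked stackable repository inside a write group and branch_repo already holds rev-1 and rev-2:

    inv = branch_repo.get_inventory(b"rev-2")
    repo.add_inventory(b"rev-2", inv, [b"rev-1"])
    repo.revisions.insert_record_stream(
        branch_repo.revisions.get_record_stream([(b"rev-2",)], "unordered", False)
    )
    # rev-1's inventory has not been copied yet, so it is reported as missing ...
    self.assertEqual(
        {("inventories", b"rev-1")}, repo.get_missing_parent_inventories()
    )
    # ... and inserting it clears the set
    repo.inventories.insert_record_stream(
        branch_repo.inventories.get_record_stream([(b"rev-1",)], "unordered", False)
    )
    self.assertEqual(set(), repo.get_missing_parent_inventories())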
reopened_repo.inventories.insert_record_stream( - branch_repo.inventories.get_record_stream( - [(b'rev-1',)], 'unordered', False)) - self.assertEqual( - set(), reopened_repo.get_missing_parent_inventories()) + branch_repo.inventories.get_record_stream([(b"rev-1",)], "unordered", False) + ) + self.assertEqual(set(), reopened_repo.get_missing_parent_inventories()) reopened_repo.abort_write_group() def test_get_missing_parent_inventories_check(self): - builder = self.make_branch_builder('test') - builder.build_snapshot([b'ghost-parent-id'], [ - ('add', ('', b'root-id', 'directory', None)), - ('add', ('file', b'file-id', 'file', b'content\n'))], - allow_leftmost_as_ghost=True, revision_id=b'A-id') + builder = self.make_branch_builder("test") + builder.build_snapshot( + [b"ghost-parent-id"], + [ + ("add", ("", b"root-id", "directory", None)), + ("add", ("file", b"file-id", "file", b"content\n")), + ], + allow_leftmost_as_ghost=True, + revision_id=b"A-id", + ) b = builder.get_branch() b.lock_read() self.addCleanup(b.unlock) - repo = self.make_repository('test-repo') + repo = self.make_repository("test-repo") repo.lock_write() self.addCleanup(repo.unlock) repo.start_write_group() self.addCleanup(repo.abort_write_group) # Now, add the objects manually - text_keys = [(b'file-id', b'A-id')] + text_keys = [(b"file-id", b"A-id")] if repo.supports_rich_root(): - text_keys.append((b'root-id', b'A-id')) + text_keys.append((b"root-id", b"A-id")) # Directly add the texts, inventory, and revision object for b'A-id' - repo.texts.insert_record_stream(b.repository.texts.get_record_stream( - text_keys, 'unordered', True)) - repo.add_revision(b'A-id', b.repository.get_revision(b'A-id'), - b.repository.get_inventory(b'A-id')) + repo.texts.insert_record_stream( + b.repository.texts.get_record_stream(text_keys, "unordered", True) + ) + repo.add_revision( + b"A-id", + b.repository.get_revision(b"A-id"), + b.repository.get_inventory(b"A-id"), + ) get_missing = repo.get_missing_parent_inventories if repo._format.supports_external_lookups: - self.assertEqual({('inventories', b'ghost-parent-id')}, - get_missing(check_for_missing_texts=False)) + self.assertEqual( + {("inventories", b"ghost-parent-id")}, + get_missing(check_for_missing_texts=False), + ) self.assertEqual(set(), get_missing(check_for_missing_texts=True)) self.assertEqual(set(), get_missing()) else: @@ -233,11 +248,13 @@ def test_get_missing_parent_inventories_check(self): self.assertEqual(set(), get_missing()) def test_insert_stream_passes_resume_info(self): - repo = self.make_repository('test-repo') - if (not repo._format.supports_external_lookups or - isinstance(repo, remote.RemoteRepository)): + repo = self.make_repository("test-repo") + if not repo._format.supports_external_lookups or isinstance( + repo, remote.RemoteRepository + ): raise tests.TestNotApplicable( - 'only valid for direct connections to resumable repos') + "only valid for direct connections to resumable repos" + ) # log calls to get_missing_parent_inventories, so that we can assert it # is called with the correct parameters call_log = [] @@ -246,6 +263,7 @@ def test_insert_stream_passes_resume_info(self): def get_missing(check_for_missing_texts=True): call_log.append(check_for_missing_texts) return orig(check_for_missing_texts=check_for_missing_texts) + repo.get_missing_parent_inventories = get_missing repo.lock_write() self.addCleanup(repo.unlock) @@ -256,47 +274,80 @@ def get_missing(check_for_missing_texts=True): repo.start_write_group() # We need to insert something, or 
suspend_write_group won't actually # create a token - repo.texts.insert_record_stream([versionedfile.FulltextContentFactory( - (b'file-id', b'rev-id'), (), None, b'lines\n')]) + repo.texts.insert_record_stream( + [ + versionedfile.FulltextContentFactory( + (b"file-id", b"rev-id"), (), None, b"lines\n" + ) + ] + ) tokens = repo.suspend_write_group() self.assertNotEqual([], tokens) sink.insert_stream((), repo._format, tokens) self.assertEqual([True], call_log) def test_insert_stream_without_locking_fails_without_lock(self): - repo = self.make_repository('test-repo') + repo = self.make_repository("test-repo") sink = repo._get_sink() - stream = [('texts', [versionedfile.FulltextContentFactory( - (b'file-id', b'rev-id'), (), None, b'lines\n')])] - self.assertRaises(errors.ObjectNotLocked, - sink.insert_stream_without_locking, stream, repo._format) + stream = [ + ( + "texts", + [ + versionedfile.FulltextContentFactory( + (b"file-id", b"rev-id"), (), None, b"lines\n" + ) + ], + ) + ] + self.assertRaises( + errors.ObjectNotLocked, + sink.insert_stream_without_locking, + stream, + repo._format, + ) def test_insert_stream_without_locking_fails_without_write_group(self): - repo = self.make_repository('test-repo') + repo = self.make_repository("test-repo") self.addCleanup(repo.lock_write().unlock) sink = repo._get_sink() - stream = [('texts', [versionedfile.FulltextContentFactory( - (b'file-id', b'rev-id'), (), None, b'lines\n')])] - self.assertRaises(errors.BzrError, - sink.insert_stream_without_locking, stream, repo._format) + stream = [ + ( + "texts", + [ + versionedfile.FulltextContentFactory( + (b"file-id", b"rev-id"), (), None, b"lines\n" + ) + ], + ) + ] + self.assertRaises( + errors.BzrError, sink.insert_stream_without_locking, stream, repo._format + ) def test_insert_stream_without_locking(self): - repo = self.make_repository('test-repo') + repo = self.make_repository("test-repo") self.addCleanup(repo.lock_write().unlock) repo.start_write_group() sink = repo._get_sink() - stream = [('texts', [versionedfile.FulltextContentFactory( - (b'file-id', b'rev-id'), (), None, b'lines\n')])] + stream = [ + ( + "texts", + [ + versionedfile.FulltextContentFactory( + (b"file-id", b"rev-id"), (), None, b"lines\n" + ) + ], + ) + ] missing_keys = sink.insert_stream_without_locking(stream, repo._format) repo.commit_write_group() self.assertEqual(set(), missing_keys) class TestResumeableWriteGroup(TestCaseWithRepository): - scenarios = all_repository_vf_format_scenarios() - def make_write_locked_repo(self, relpath='repo'): + def make_write_locked_repo(self, relpath="repo"): repo = self.make_repository(relpath) repo.lock_write() self.addCleanup(repo.unlock) @@ -309,7 +360,7 @@ def reopen_repo(self, repo): return same_repo def require_suspendable_write_groups(self, reason): - repo = self.make_repository('__suspend_test') + repo = self.make_repository("__suspend_test") repo.lock_write() self.addCleanup(repo.unlock) repo.start_write_group() @@ -324,7 +375,7 @@ def test_suspend_write_group(self): repo.start_write_group() # Add some content so this isn't an empty write group (which may return # 0 tokens) - repo.texts.add_lines((b'file-id', b'revid'), (), [b'lines']) + repo.texts.add_lines((b"file-id", b"revid"), (), [b"lines"]) try: wg_tokens = repo.suspend_write_group() except errors.UnsuspendableWriteGroup: @@ -348,8 +399,8 @@ def test_resume_write_group_then_abort(self): repo.start_write_group() # Add some content so this isn't an empty write group (which may return # 0 tokens) - text_key = (b'file-id', b'revid') 
- repo.texts.add_lines(text_key, (), [b'lines']) + text_key = (b"file-id", b"revid") + repo.texts.add_lines(text_key, (), [b"lines"]) try: wg_tokens = repo.suspend_write_group() except errors.UnsuspendableWriteGroup: @@ -357,9 +408,10 @@ def test_resume_write_group_then_abort(self): # support resuming them either. repo.abort_write_group() self.assertRaises( - errors.UnsuspendableWriteGroup, repo.resume_write_group, []) + errors.UnsuspendableWriteGroup, repo.resume_write_group, [] + ) else: - #self.assertEqual([], list(repo.texts.keys())) + # self.assertEqual([], list(repo.texts.keys())) same_repo = self.reopen_repo(repo) same_repo.resume_write_group(wg_tokens) self.assertEqual([text_key], list(same_repo.texts.keys())) @@ -370,19 +422,20 @@ def test_resume_write_group_then_abort(self): def test_multiple_resume_write_group(self): self.require_suspendable_write_groups( - 'Cannot test resume on repo that does not support suspending') + "Cannot test resume on repo that does not support suspending" + ) repo = self.make_write_locked_repo() repo.start_write_group() # Add some content so this isn't an empty write group (which may return # 0 tokens) - first_key = (b'file-id', b'revid') - repo.texts.add_lines(first_key, (), [b'lines']) + first_key = (b"file-id", b"revid") + repo.texts.add_lines(first_key, (), [b"lines"]) wg_tokens = repo.suspend_write_group() same_repo = self.reopen_repo(repo) same_repo.resume_write_group(wg_tokens) self.assertTrue(same_repo.is_in_write_group()) - second_key = (b'file-id', b'second-revid') - same_repo.texts.add_lines(second_key, (first_key,), [b'more lines']) + second_key = (b"file-id", b"second-revid") + same_repo.texts.add_lines(second_key, (first_key,), [b"more lines"]) try: new_wg_tokens = same_repo.suspend_write_group() except: @@ -398,13 +451,14 @@ def test_multiple_resume_write_group(self): def test_no_op_suspend_resume(self): self.require_suspendable_write_groups( - 'Cannot test resume on repo that does not support suspending') + "Cannot test resume on repo that does not support suspending" + ) repo = self.make_write_locked_repo() repo.start_write_group() # Add some content so this isn't an empty write group (which may return # 0 tokens) - text_key = (b'file-id', b'revid') - repo.texts.add_lines(text_key, (), [b'lines']) + text_key = (b"file-id", b"revid") + repo.texts.add_lines(text_key, (), [b"lines"]) wg_tokens = repo.suspend_write_group() same_repo = self.reopen_repo(repo) same_repo.resume_write_group(wg_tokens) @@ -417,25 +471,27 @@ def test_no_op_suspend_resume(self): def test_read_after_suspend_fails(self): self.require_suspendable_write_groups( - 'Cannot test suspend on repo that does not support suspending') + "Cannot test suspend on repo that does not support suspending" + ) repo = self.make_write_locked_repo() repo.start_write_group() # Add some content so this isn't an empty write group (which may return # 0 tokens) - text_key = (b'file-id', b'revid') - repo.texts.add_lines(text_key, (), [b'lines']) + text_key = (b"file-id", b"revid") + repo.texts.add_lines(text_key, (), [b"lines"]) repo.suspend_write_group() self.assertEqual([], list(repo.texts.keys())) def test_read_after_second_suspend_fails(self): self.require_suspendable_write_groups( - 'Cannot test suspend on repo that does not support suspending') + "Cannot test suspend on repo that does not support suspending" + ) repo = self.make_write_locked_repo() repo.start_write_group() # Add some content so this isn't an empty write group (which may return # 0 tokens) - text_key = (b'file-id', 
b'revid') - repo.texts.add_lines(text_key, (), [b'lines']) + text_key = (b"file-id", b"revid") + repo.texts.add_lines(text_key, (), [b"lines"]) wg_tokens = repo.suspend_write_group() same_repo = self.reopen_repo(repo) same_repo.resume_write_group(wg_tokens) @@ -444,13 +500,14 @@ def test_read_after_second_suspend_fails(self): def test_read_after_resume_abort_fails(self): self.require_suspendable_write_groups( - 'Cannot test suspend on repo that does not support suspending') + "Cannot test suspend on repo that does not support suspending" + ) repo = self.make_write_locked_repo() repo.start_write_group() # Add some content so this isn't an empty write group (which may return # 0 tokens) - text_key = (b'file-id', b'revid') - repo.texts.add_lines(text_key, (), [b'lines']) + text_key = (b"file-id", b"revid") + repo.texts.add_lines(text_key, (), [b"lines"]) wg_tokens = repo.suspend_write_group() same_repo = self.reopen_repo(repo) same_repo.resume_write_group(wg_tokens) @@ -459,93 +516,105 @@ def test_read_after_resume_abort_fails(self): def test_cannot_resume_aborted_write_group(self): self.require_suspendable_write_groups( - 'Cannot test resume on repo that does not support suspending') + "Cannot test resume on repo that does not support suspending" + ) repo = self.make_write_locked_repo() repo.start_write_group() # Add some content so this isn't an empty write group (which may return # 0 tokens) - text_key = (b'file-id', b'revid') - repo.texts.add_lines(text_key, (), [b'lines']) + text_key = (b"file-id", b"revid") + repo.texts.add_lines(text_key, (), [b"lines"]) wg_tokens = repo.suspend_write_group() same_repo = self.reopen_repo(repo) same_repo.resume_write_group(wg_tokens) same_repo.abort_write_group() same_repo = self.reopen_repo(repo) self.assertRaises( - errors.UnresumableWriteGroup, same_repo.resume_write_group, - wg_tokens) + errors.UnresumableWriteGroup, same_repo.resume_write_group, wg_tokens + ) def test_commit_resumed_write_group_no_new_data(self): self.require_suspendable_write_groups( - 'Cannot test resume on repo that does not support suspending') + "Cannot test resume on repo that does not support suspending" + ) repo = self.make_write_locked_repo() repo.start_write_group() # Add some content so this isn't an empty write group (which may return # 0 tokens) - text_key = (b'file-id', b'revid') - repo.texts.add_lines(text_key, (), [b'lines']) + text_key = (b"file-id", b"revid") + repo.texts.add_lines(text_key, (), [b"lines"]) wg_tokens = repo.suspend_write_group() same_repo = self.reopen_repo(repo) same_repo.resume_write_group(wg_tokens) same_repo.commit_write_group() self.assertEqual([text_key], list(same_repo.texts.keys())) self.assertEqual( - b'lines', next(same_repo.texts.get_record_stream([text_key], - 'unordered', True)).get_bytes_as('fulltext')) + b"lines", + next( + same_repo.texts.get_record_stream([text_key], "unordered", True) + ).get_bytes_as("fulltext"), + ) self.assertRaises( - errors.UnresumableWriteGroup, same_repo.resume_write_group, - wg_tokens) + errors.UnresumableWriteGroup, same_repo.resume_write_group, wg_tokens + ) def test_commit_resumed_write_group_plus_new_data(self): self.require_suspendable_write_groups( - 'Cannot test resume on repo that does not support suspending') + "Cannot test resume on repo that does not support suspending" + ) repo = self.make_write_locked_repo() repo.start_write_group() # Add some content so this isn't an empty write group (which may return # 0 tokens) - first_key = (b'file-id', b'revid') - repo.texts.add_lines(first_key, (), 
[b'lines']) + first_key = (b"file-id", b"revid") + repo.texts.add_lines(first_key, (), [b"lines"]) wg_tokens = repo.suspend_write_group() same_repo = self.reopen_repo(repo) same_repo.resume_write_group(wg_tokens) - second_key = (b'file-id', b'second-revid') - same_repo.texts.add_lines(second_key, (first_key,), [b'more lines']) + second_key = (b"file-id", b"second-revid") + same_repo.texts.add_lines(second_key, (first_key,), [b"more lines"]) same_repo.commit_write_group() + self.assertEqual({first_key, second_key}, set(same_repo.texts.keys())) self.assertEqual( - {first_key, second_key}, set(same_repo.texts.keys())) - self.assertEqual( - b'lines', next(same_repo.texts.get_record_stream([first_key], - 'unordered', True)).get_bytes_as('fulltext')) + b"lines", + next( + same_repo.texts.get_record_stream([first_key], "unordered", True) + ).get_bytes_as("fulltext"), + ) self.assertEqual( - b'more lines', next(same_repo.texts.get_record_stream([second_key], - 'unordered', True)).get_bytes_as('fulltext')) + b"more lines", + next( + same_repo.texts.get_record_stream([second_key], "unordered", True) + ).get_bytes_as("fulltext"), + ) def make_source_with_delta_record(self): # Make a source repository with a delta record in it. - source_repo = self.make_write_locked_repo('source') + source_repo = self.make_write_locked_repo("source") source_repo.start_write_group() - key_base = (b'file-id', b'base') - key_delta = (b'file-id', b'delta') + key_base = (b"file-id", b"base") + key_delta = (b"file-id", b"delta") def text_stream(): + yield versionedfile.FulltextContentFactory(key_base, (), None, b"lines\n") yield versionedfile.FulltextContentFactory( - key_base, (), None, b'lines\n') - yield versionedfile.FulltextContentFactory( - key_delta, (key_base,), None, b'more\nlines\n') + key_delta, (key_base,), None, b"more\nlines\n" + ) + source_repo.texts.insert_record_stream(text_stream()) source_repo.commit_write_group() return source_repo def test_commit_resumed_write_group_with_missing_parents(self): self.require_suspendable_write_groups( - 'Cannot test resume on repo that does not support suspending') + "Cannot test resume on repo that does not support suspending" + ) source_repo = self.make_source_with_delta_record() - key_delta = (b'file-id', b'delta') + key_delta = (b"file-id", b"delta") # Start a write group, insert just a delta. repo = self.make_write_locked_repo() repo.start_write_group() - stream = source_repo.texts.get_record_stream( - [key_delta], 'unordered', False) + stream = source_repo.texts.get_record_stream([key_delta], "unordered", False) repo.texts.insert_record_stream(stream) # It's either not commitable due to the missing compression parent, or # the stacked location has already filled in the fulltext. @@ -557,37 +626,37 @@ def test_commit_resumed_write_group_with_missing_parents(self): else: same_repo = self.reopen_repo(repo) same_repo.lock_read() - record = next(same_repo.texts.get_record_stream([key_delta], - 'unordered', True)) - self.assertEqual(b'more\nlines\n', record.get_bytes_as('fulltext')) + record = next( + same_repo.texts.get_record_stream([key_delta], "unordered", True) + ) + self.assertEqual(b"more\nlines\n", record.get_bytes_as("fulltext")) return # Merely suspending and resuming doesn't make it commitable either. 
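The resumable-write-group tests that follow all share one core sequence: suspend a non-empty write group, reopen the repository, resume with the returned tokens, then commit or abort. A compressed sketch of that flow, guarded in the tests by require_suspendable_write_groups and using the helpers defined on this class:

    repo = self.make_write_locked_repo()
    repo.start_write_group()
    text_key = (b"file-id", b"revid")
    repo.texts.add_lines(text_key, (), [b"lines"])  # an empty group may return no tokens
    wg_tokens = repo.suspend_write_group()
    same_repo = self.reopen_repo(repo)
    same_repo.resume_write_group(wg_tokens)
    same_repo.commit_write_group()
    self.assertEqual([text_key], list(same_repo.texts.keys()))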
wg_tokens = repo.suspend_write_group() same_repo = self.reopen_repo(repo) same_repo.resume_write_group(wg_tokens) - self.assertRaises( - errors.BzrCheckError, same_repo.commit_write_group) + self.assertRaises(errors.BzrCheckError, same_repo.commit_write_group) same_repo.abort_write_group() def test_commit_resumed_write_group_adding_missing_parents(self): self.require_suspendable_write_groups( - 'Cannot test resume on repo that does not support suspending') + "Cannot test resume on repo that does not support suspending" + ) source_repo = self.make_source_with_delta_record() - key_delta = (b'file-id', b'delta') + key_delta = (b"file-id", b"delta") # Start a write group. repo = self.make_write_locked_repo() repo.start_write_group() # Add some content so this isn't an empty write group (which may return # 0 tokens) - text_key = (b'file-id', b'revid') - repo.texts.add_lines(text_key, (), [b'lines']) + text_key = (b"file-id", b"revid") + repo.texts.add_lines(text_key, (), [b"lines"]) # Suspend it, then resume it. wg_tokens = repo.suspend_write_group() same_repo = self.reopen_repo(repo) same_repo.resume_write_group(wg_tokens) # Add a record with a missing compression parent - stream = source_repo.texts.get_record_stream( - [key_delta], 'unordered', False) + stream = source_repo.texts.get_record_stream([key_delta], "unordered", False) same_repo.texts.insert_record_stream(stream) # Just like if we'd added that record without a suspend/resume cycle, # commit_write_group fails. @@ -600,31 +669,31 @@ def test_commit_resumed_write_group_adding_missing_parents(self): # insert_record_stream already gave it a fulltext. same_repo = self.reopen_repo(repo) same_repo.lock_read() - record = next(same_repo.texts.get_record_stream([key_delta], - 'unordered', True)) - self.assertEqual(b'more\nlines\n', record.get_bytes_as('fulltext')) + record = next( + same_repo.texts.get_record_stream([key_delta], "unordered", True) + ) + self.assertEqual(b"more\nlines\n", record.get_bytes_as("fulltext")) return same_repo.abort_write_group() def test_add_missing_parent_after_resume(self): self.require_suspendable_write_groups( - 'Cannot test resume on repo that does not support suspending') + "Cannot test resume on repo that does not support suspending" + ) source_repo = self.make_source_with_delta_record() - key_base = (b'file-id', b'base') - key_delta = (b'file-id', b'delta') + key_base = (b"file-id", b"base") + key_delta = (b"file-id", b"delta") # Start a write group, insert just a delta. repo = self.make_write_locked_repo() repo.start_write_group() - stream = source_repo.texts.get_record_stream( - [key_delta], 'unordered', False) + stream = source_repo.texts.get_record_stream([key_delta], "unordered", False) repo.texts.insert_record_stream(stream) # Suspend it, then resume it. wg_tokens = repo.suspend_write_group() same_repo = self.reopen_repo(repo) same_repo.resume_write_group(wg_tokens) # Fill in the missing compression parent. - stream = source_repo.texts.get_record_stream( - [key_base], 'unordered', False) + stream = source_repo.texts.get_record_stream([key_base], "unordered", False) same_repo.texts.insert_record_stream(stream) same_repo.commit_write_group() @@ -633,7 +702,8 @@ def test_suspend_empty_initial_write_group(self): list. 
""" self.require_suspendable_write_groups( - 'Cannot test suspend on repo that does not support suspending') + "Cannot test suspend on repo that does not support suspending" + ) repo = self.make_write_locked_repo() repo.start_write_group() wg_tokens = repo.suspend_write_group() @@ -642,7 +712,8 @@ def test_suspend_empty_initial_write_group(self): def test_resume_empty_initial_write_group(self): """Resuming an empty token list is equivalent to start_write_group.""" self.require_suspendable_write_groups( - 'Cannot test resume on repo that does not support suspending') + "Cannot test resume on repo that does not support suspending" + ) repo = self.make_write_locked_repo() repo.resume_write_group([]) repo.abort_write_group() diff --git a/breezy/bzr/tests/per_versionedfile.py b/breezy/bzr/tests/per_versionedfile.py index 9116084fd8..c786e82933 100644 --- a/breezy/bzr/tests/per_versionedfile.py +++ b/breezy/bzr/tests/per_versionedfile.py @@ -63,30 +63,33 @@ def get_diamond_vf(f, trailing_eol=True, left_only=False): :param trailing_eol: If True end the last line with \n. """ parents = { - b'origin': (), - b'base': ((b'origin',),), - b'left': ((b'base',),), - b'right': ((b'base',),), - b'merged': ((b'left',), (b'right',)), - } + b"origin": (), + b"base": ((b"origin",),), + b"left": ((b"base",),), + b"right": ((b"base",),), + b"merged": ((b"left",), (b"right",)), + } # insert a diamond graph to exercise deltas and merges. if trailing_eol: - last_char = b'\n' + last_char = b"\n" else: - last_char = b'' - f.add_lines(b'origin', [], [b'origin' + last_char]) - f.add_lines(b'base', [b'origin'], [b'base' + last_char]) - f.add_lines(b'left', [b'base'], [b'base\n', b'left' + last_char]) + last_char = b"" + f.add_lines(b"origin", [], [b"origin" + last_char]) + f.add_lines(b"base", [b"origin"], [b"base" + last_char]) + f.add_lines(b"left", [b"base"], [b"base\n", b"left" + last_char]) if not left_only: - f.add_lines(b'right', [b'base'], - [b'base\n', b'right' + last_char]) - f.add_lines(b'merged', [b'left', b'right'], - [b'base\n', b'left\n', b'right\n', b'merged' + last_char]) + f.add_lines(b"right", [b"base"], [b"base\n", b"right" + last_char]) + f.add_lines( + b"merged", + [b"left", b"right"], + [b"base\n", b"left\n", b"right\n", b"merged" + last_char], + ) return f, parents -def get_diamond_files(files, key_length, trailing_eol=True, left_only=False, - nograph=False, nokeys=False): +def get_diamond_files( + files, key_length, trailing_eol=True, left_only=False, nograph=False, nokeys=False +): r"""Get a diamond graph to exercise deltas and merges. This creates a 5-node graph in files. If files supports 2-length keys two @@ -108,12 +111,12 @@ def get_diamond_files(files, key_length, trailing_eol=True, left_only=False, if key_length == 1: prefixes = [()] else: - prefixes = [(b'FileA',), (b'FileB',)] + prefixes = [(b"FileA",), (b"FileB",)] # insert a diamond graph to exercise deltas and merges. if trailing_eol: - last_char = b'\n' + last_char = b"\n" else: - last_char = b'' + last_char = b"" result = [] def get_parents(suffix_list): @@ -125,31 +128,49 @@ def get_parents(suffix_list): def get_key(suffix): if nokeys: - return (None, ) + return (None,) else: return (suffix,) + # we loop over each key because that spreads the inserts across prefixes, # which is how commit operates. 
for prefix in prefixes: - result.append(files.add_lines(prefix + get_key(b'origin'), (), - [b'origin' + last_char])) + result.append( + files.add_lines(prefix + get_key(b"origin"), (), [b"origin" + last_char]) + ) for prefix in prefixes: - result.append(files.add_lines(prefix + get_key(b'base'), - get_parents([(b'origin',)]), [b'base' + last_char])) + result.append( + files.add_lines( + prefix + get_key(b"base"), + get_parents([(b"origin",)]), + [b"base" + last_char], + ) + ) for prefix in prefixes: - result.append(files.add_lines(prefix + get_key(b'left'), - get_parents([(b'base',)]), - [b'base\n', b'left' + last_char])) + result.append( + files.add_lines( + prefix + get_key(b"left"), + get_parents([(b"base",)]), + [b"base\n", b"left" + last_char], + ) + ) if not left_only: for prefix in prefixes: - result.append(files.add_lines(prefix + get_key(b'right'), - get_parents([(b'base',)]), - [b'base\n', b'right' + last_char])) + result.append( + files.add_lines( + prefix + get_key(b"right"), + get_parents([(b"base",)]), + [b"base\n", b"right" + last_char], + ) + ) for prefix in prefixes: - result.append(files.add_lines(prefix + get_key(b'merged'), - get_parents( - [(b'left',), (b'right',)]), - [b'base\n', b'left\n', b'right\n', b'merged' + last_char])) + result.append( + files.add_lines( + prefix + get_key(b"merged"), + get_parents([(b"left",), (b"right",)]), + [b"base\n", b"left\n", b"right\n", b"merged" + last_char], + ) + ) return result @@ -162,29 +183,28 @@ class VersionedFileTestMixIn: """ def get_transaction(self): - if not hasattr(self, '_transaction'): + if not hasattr(self, "_transaction"): self._transaction = None return self._transaction def test_add(self): f = self.get_file() - f.add_lines(b'r0', [], [b'a\n', b'b\n']) - f.add_lines(b'r1', [b'r0'], [b'b\n', b'c\n']) + f.add_lines(b"r0", [], [b"a\n", b"b\n"]) + f.add_lines(b"r1", [b"r0"], [b"b\n", b"c\n"]) def verify_file(f): versions = f.versions() - self.assertTrue(b'r0' in versions) - self.assertTrue(b'r1' in versions) - self.assertEqual(f.get_lines(b'r0'), [b'a\n', b'b\n']) - self.assertEqual(f.get_text(b'r0'), b'a\nb\n') - self.assertEqual(f.get_lines(b'r1'), [b'b\n', b'c\n']) + self.assertTrue(b"r0" in versions) + self.assertTrue(b"r1" in versions) + self.assertEqual(f.get_lines(b"r0"), [b"a\n", b"b\n"]) + self.assertEqual(f.get_text(b"r0"), b"a\nb\n") + self.assertEqual(f.get_lines(b"r1"), [b"b\n", b"c\n"]) self.assertEqual(2, len(f)) self.assertEqual(2, f.num_versions()) - self.assertRaises(RevisionNotPresent, - f.add_lines, b'r2', [b'foo'], []) - self.assertRaises(RevisionAlreadyPresent, - f.add_lines, b'r1', [], []) + self.assertRaises(RevisionNotPresent, f.add_lines, b"r2", [b"foo"], []) + self.assertRaises(RevisionAlreadyPresent, f.add_lines, b"r1", [], []) + verify_file(f) # this checks that reopen with create=True does not break anything. f = self.reopen_file(create=True) @@ -193,34 +213,35 @@ def verify_file(f): def test_adds_with_parent_texts(self): f = self.get_file() parent_texts = {} - _, _, parent_texts[b'r0'] = f.add_lines(b'r0', [], [b'a\n', b'b\n']) + _, _, parent_texts[b"r0"] = f.add_lines(b"r0", [], [b"a\n", b"b\n"]) try: - _, _, parent_texts[b'r1'] = f.add_lines_with_ghosts(b'r1', - [b'r0', b'ghost'], [b'b\n', b'c\n'], parent_texts=parent_texts) + _, _, parent_texts[b"r1"] = f.add_lines_with_ghosts( + b"r1", [b"r0", b"ghost"], [b"b\n", b"c\n"], parent_texts=parent_texts + ) except NotImplementedError: # if the format doesn't support ghosts, just add normally. 
- _, _, parent_texts[b'r1'] = f.add_lines(b'r1', - [b'r0'], [b'b\n', b'c\n'], parent_texts=parent_texts) - f.add_lines(b'r2', [b'r1'], [b'c\n', b'd\n'], - parent_texts=parent_texts) - self.assertNotEqual(None, parent_texts[b'r0']) - self.assertNotEqual(None, parent_texts[b'r1']) + _, _, parent_texts[b"r1"] = f.add_lines( + b"r1", [b"r0"], [b"b\n", b"c\n"], parent_texts=parent_texts + ) + f.add_lines(b"r2", [b"r1"], [b"c\n", b"d\n"], parent_texts=parent_texts) + self.assertNotEqual(None, parent_texts[b"r0"]) + self.assertNotEqual(None, parent_texts[b"r1"]) def verify_file(f): versions = f.versions() - self.assertTrue(b'r0' in versions) - self.assertTrue(b'r1' in versions) - self.assertTrue(b'r2' in versions) - self.assertEqual(f.get_lines(b'r0'), [b'a\n', b'b\n']) - self.assertEqual(f.get_lines(b'r1'), [b'b\n', b'c\n']) - self.assertEqual(f.get_lines(b'r2'), [b'c\n', b'd\n']) + self.assertTrue(b"r0" in versions) + self.assertTrue(b"r1" in versions) + self.assertTrue(b"r2" in versions) + self.assertEqual(f.get_lines(b"r0"), [b"a\n", b"b\n"]) + self.assertEqual(f.get_lines(b"r1"), [b"b\n", b"c\n"]) + self.assertEqual(f.get_lines(b"r2"), [b"c\n", b"d\n"]) self.assertEqual(3, f.num_versions()) - origins = f.annotate(b'r1') - self.assertEqual(origins[0][0], b'r0') - self.assertEqual(origins[1][0], b'r1') - origins = f.annotate(b'r2') - self.assertEqual(origins[0][0], b'r1') - self.assertEqual(origins[1][0], b'r2') + origins = f.annotate(b"r1") + self.assertEqual(origins[0][0], b"r0") + self.assertEqual(origins[1][0], b"r1") + origins = f.annotate(b"r2") + self.assertEqual(origins[0][0], b"r1") + self.assertEqual(origins[1][0], b"r2") verify_file(f) f = self.reopen_file() @@ -230,11 +251,20 @@ def test_add_unicode_content(self): # unicode content is not permitted in versioned files. # versioned files version sequences of bytes only. vf = self.get_file() - self.assertRaises(errors.BzrBadParameterUnicode, - vf.add_lines, b'a', [], [b'a\n', 'b\n', b'c\n']) + self.assertRaises( + errors.BzrBadParameterUnicode, + vf.add_lines, + b"a", + [], + [b"a\n", "b\n", b"c\n"], + ) self.assertRaises( (errors.BzrBadParameterUnicode, NotImplementedError), - vf.add_lines_with_ghosts, b'a', [], [b'a\n', 'b\n', b'c\n']) + vf.add_lines_with_ghosts, + b"a", + [], + [b"a\n", "b\n", b"c\n"], + ) def test_add_follows_left_matching_blocks(self): """If we change left_matching_blocks, delta changes. 
@@ -245,60 +275,79 @@ def test_add_follows_left_matching_blocks(self): vf = self.get_file() if isinstance(vf, WeaveFile): raise TestSkipped("WeaveFile ignores left_matching_blocks") - vf.add_lines(b'1', [], [b'a\n']) - vf.add_lines(b'2', [b'1'], [b'a\n', b'a\n', b'a\n'], - left_matching_blocks=[(0, 0, 1), (1, 3, 0)]) - self.assertEqual([b'a\n', b'a\n', b'a\n'], vf.get_lines(b'2')) - vf.add_lines(b'3', [b'1'], [b'a\n', b'a\n', b'a\n'], - left_matching_blocks=[(0, 2, 1), (1, 3, 0)]) - self.assertEqual([b'a\n', b'a\n', b'a\n'], vf.get_lines(b'3')) + vf.add_lines(b"1", [], [b"a\n"]) + vf.add_lines( + b"2", + [b"1"], + [b"a\n", b"a\n", b"a\n"], + left_matching_blocks=[(0, 0, 1), (1, 3, 0)], + ) + self.assertEqual([b"a\n", b"a\n", b"a\n"], vf.get_lines(b"2")) + vf.add_lines( + b"3", + [b"1"], + [b"a\n", b"a\n", b"a\n"], + left_matching_blocks=[(0, 2, 1), (1, 3, 0)], + ) + self.assertEqual([b"a\n", b"a\n", b"a\n"], vf.get_lines(b"3")) def test_inline_newline_throws(self): # \r characters are not permitted in lines being added vf = self.get_file() - self.assertRaises(errors.BzrBadParameterContainsNewline, - vf.add_lines, b'a', [], [b'a\n\n']) + self.assertRaises( + errors.BzrBadParameterContainsNewline, vf.add_lines, b"a", [], [b"a\n\n"] + ) self.assertRaises( (errors.BzrBadParameterContainsNewline, NotImplementedError), - vf.add_lines_with_ghosts, b'a', [], [b'a\n\n']) + vf.add_lines_with_ghosts, + b"a", + [], + [b"a\n\n"], + ) # but inline CR's are allowed - vf.add_lines(b'a', [], [b'a\r\n']) + vf.add_lines(b"a", [], [b"a\r\n"]) try: - vf.add_lines_with_ghosts(b'b', [], [b'a\r\n']) + vf.add_lines_with_ghosts(b"b", [], [b"a\r\n"]) except NotImplementedError: pass def test_add_reserved(self): vf = self.get_file() - self.assertRaises(errors.ReservedId, - vf.add_lines, b'a:', [], [b'a\n', b'b\n', b'c\n']) + self.assertRaises( + errors.ReservedId, vf.add_lines, b"a:", [], [b"a\n", b"b\n", b"c\n"] + ) def test_add_lines_nostoresha(self): """When nostore_sha is supplied using old content raises.""" vf = self.get_file() - empty_text = (b'a', []) - sample_text_nl = (b'b', [b"foo\n", b"bar\n"]) - sample_text_no_nl = (b'c', [b"foo\n", b"bar"]) + empty_text = (b"a", []) + sample_text_nl = (b"b", [b"foo\n", b"bar\n"]) + sample_text_no_nl = (b"c", [b"foo\n", b"bar"]) shas = [] for version, lines in (empty_text, sample_text_nl, sample_text_no_nl): sha, _, _ = vf.add_lines(version, [], lines) shas.append(sha) # we now have a copy of all the lines in the vf. for sha, (version, lines) in zip( - shas, (empty_text, sample_text_nl, sample_text_no_nl)): - self.assertRaises(ExistingContent, - vf.add_lines, version + b"2", [], lines, - nostore_sha=sha) + shas, (empty_text, sample_text_nl, sample_text_no_nl) + ): + self.assertRaises( + ExistingContent, + vf.add_lines, + version + b"2", + [], + lines, + nostore_sha=sha, + ) # and no new version should have been added. 
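test_add_lines_nostoresha above pins down the nostore_sha contract: re-adding content whose sha1 is already stored raises rather than creating a new version. Reduced to its essentials, with the same names as the test:

    vf = self.get_file()
    lines = [b"foo\n", b"bar\n"]
    sha, _size, _ = vf.add_lines(b"b", [], lines)  # add_lines returns the sha1 first
    # identical content under a new version id with nostore_sha=sha raises ...
    self.assertRaises(ExistingContent, vf.add_lines, b"b2", [], lines, nostore_sha=sha)
    # ... and the new version is not recorded
    self.assertRaises(errors.RevisionNotPresent, vf.get_lines, b"b2")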
- self.assertRaises(errors.RevisionNotPresent, vf.get_lines, - version + b"2") + self.assertRaises(errors.RevisionNotPresent, vf.get_lines, version + b"2") def test_add_lines_with_ghosts_nostoresha(self): """When nostore_sha is supplied using old content raises.""" vf = self.get_file() - empty_text = (b'a', []) - sample_text_nl = (b'b', [b"foo\n", b"bar\n"]) - sample_text_no_nl = (b'c', [b"foo\n", b"bar"]) + empty_text = (b"a", []) + sample_text_nl = (b"b", [b"foo\n", b"bar\n"]) + sample_text_no_nl = (b"c", [b"foo\n", b"bar"]) shas = [] for version, lines in (empty_text, sample_text_nl, sample_text_no_nl): sha, _, _ = vf.add_lines(version, [], lines) @@ -306,24 +355,29 @@ def test_add_lines_with_ghosts_nostoresha(self): # we now have a copy of all the lines in the vf. # is the test applicable to this vf implementation? try: - vf.add_lines_with_ghosts(b'd', [], []) + vf.add_lines_with_ghosts(b"d", [], []) except NotImplementedError as e: raise TestSkipped("add_lines_with_ghosts is optional") from e for sha, (version, lines) in zip( - shas, (empty_text, sample_text_nl, sample_text_no_nl)): - self.assertRaises(ExistingContent, - vf.add_lines_with_ghosts, version + b"2", [], lines, - nostore_sha=sha) + shas, (empty_text, sample_text_nl, sample_text_no_nl) + ): + self.assertRaises( + ExistingContent, + vf.add_lines_with_ghosts, + version + b"2", + [], + lines, + nostore_sha=sha, + ) # and no new version should have been added. - self.assertRaises(errors.RevisionNotPresent, vf.get_lines, - version + b"2") + self.assertRaises(errors.RevisionNotPresent, vf.get_lines, version + b"2") def test_add_lines_return_value(self): # add_lines should return the sha1 and the text size. vf = self.get_file() - empty_text = (b'a', []) - sample_text_nl = (b'b', [b"foo\n", b"bar\n"]) - sample_text_no_nl = (b'c', [b"foo\n", b"bar"]) + empty_text = (b"a", []) + sample_text_nl = (b"b", [b"foo\n", b"bar\n"]) + sample_text_no_nl = (b"c", [b"foo\n", b"bar"]) # check results for the three cases: for version, lines in (empty_text, sample_text_nl, sample_text_no_nl): # the first two elements are the same for all versioned files: @@ -331,18 +385,21 @@ def test_add_lines_return_value(self): # additional data is returned in additional tuple elements. result = vf.add_lines(version, [], lines) self.assertEqual(3, len(result)) - self.assertEqual((osutils.sha_strings(lines), sum(map(len, lines))), - result[0:2]) + self.assertEqual( + (osutils.sha_strings(lines), sum(map(len, lines))), result[0:2] + ) # parents should not affect the result: lines = sample_text_nl[1] - self.assertEqual((osutils.sha_strings(lines), sum(map(len, lines))), - vf.add_lines(b'd', [b'b', b'c'], lines)[0:2]) + self.assertEqual( + (osutils.sha_strings(lines), sum(map(len, lines))), + vf.add_lines(b"d", [b"b", b"c"], lines)[0:2], + ) def test_get_reserved(self): vf = self.get_file() - self.assertRaises(errors.ReservedId, vf.get_texts, [b'b:']) - self.assertRaises(errors.ReservedId, vf.get_lines, b'b:') - self.assertRaises(errors.ReservedId, vf.get_text, b'b:') + self.assertRaises(errors.ReservedId, vf.get_texts, [b"b:"]) + self.assertRaises(errors.ReservedId, vf.get_lines, b"b:") + self.assertRaises(errors.ReservedId, vf.get_text, b"b:") def test_add_unchanged_last_line_noeol_snapshot(self): """Add a text with an unchanged last line with no eol should work.""" @@ -356,18 +413,18 @@ def test_add_unchanged_last_line_noeol_snapshot(self): # happens). 
for length in range(20): version_lines = {} - vf = self.get_file('case-%d' % length) - prefix = b'step-%d' + vf = self.get_file("case-%d" % length) + prefix = b"step-%d" parents = [] for step in range(length): version = prefix % step - lines = ([b'prelude \n'] * step) + [b'line'] + lines = ([b"prelude \n"] * step) + [b"line"] vf.add_lines(version, parents, lines) version_lines[version] = lines parents = [version] - vf.add_lines(b'no-eol', parents, [b'line']) + vf.add_lines(b"no-eol", parents, [b"line"]) vf.get_texts(version_lines.keys()) - self.assertEqualDiff(b'line', vf.get_text(b'no-eol')) + self.assertEqualDiff(b"line", vf.get_text(b"no-eol")) def test_get_texts_eol_variation(self): # similar to the failure in @@ -378,7 +435,7 @@ def test_get_texts_eol_variation(self): version_lines = {} parents = [] for i in range(4): - version = b'v%d' % i + version = b"v%d" % i if i % 2: lines = sample_text_nl else: @@ -389,8 +446,7 @@ def test_get_texts_eol_variation(self): # file. Using it here ensures that a broken internal implementation # (which is what this test tests) will generate a correct line # delta (which is to say, an empty delta). - vf.add_lines(version, parents, lines, - left_matching_blocks=[(0, 0, 1)]) + vf.add_lines(version, parents, lines, left_matching_blocks=[(0, 0, 1)]) parents = [version] versions.append(version) version_lines[version] = lines @@ -405,108 +461,123 @@ def test_add_lines_with_matching_blocks_noeol_last_line(self): # reuse the last line unaltered (which can cause annotation reuse). # Test adding this in two situations: # On top of a new insertion - vf = self.get_file('fulltext') - vf.add_lines(b'noeol', [], [b'line']) - vf.add_lines(b'noeol2', [b'noeol'], [b'newline\n', b'line'], - left_matching_blocks=[(0, 1, 1)]) - self.assertEqualDiff(b'newline\nline', vf.get_text(b'noeol2')) + vf = self.get_file("fulltext") + vf.add_lines(b"noeol", [], [b"line"]) + vf.add_lines( + b"noeol2", + [b"noeol"], + [b"newline\n", b"line"], + left_matching_blocks=[(0, 1, 1)], + ) + self.assertEqualDiff(b"newline\nline", vf.get_text(b"noeol2")) # On top of a delta - vf = self.get_file('delta') - vf.add_lines(b'base', [], [b'line']) - vf.add_lines(b'noeol', [b'base'], [b'prelude\n', b'line']) - vf.add_lines(b'noeol2', [b'noeol'], [b'newline\n', b'line'], - left_matching_blocks=[(1, 1, 1)]) - self.assertEqualDiff(b'newline\nline', vf.get_text(b'noeol2')) + vf = self.get_file("delta") + vf.add_lines(b"base", [], [b"line"]) + vf.add_lines(b"noeol", [b"base"], [b"prelude\n", b"line"]) + vf.add_lines( + b"noeol2", + [b"noeol"], + [b"newline\n", b"line"], + left_matching_blocks=[(1, 1, 1)], + ) + self.assertEqualDiff(b"newline\nline", vf.get_text(b"noeol2")) def test_make_mpdiffs(self): from breezy import multiparent - vf = self.get_file('foo') + + vf = self.get_file("foo") self._setup_for_deltas(vf) - new_vf = self.get_file('bar') + new_vf = self.get_file("bar") for version in multiparent.topo_iter(vf): mpdiff = vf.make_mpdiffs([version])[0] - new_vf.add_mpdiffs([(version, vf.get_parent_map([version])[version], - vf.get_sha1s([version])[version], mpdiff)]) - self.assertEqualDiff(vf.get_text(version), - new_vf.get_text(version)) + new_vf.add_mpdiffs( + [ + ( + version, + vf.get_parent_map([version])[version], + vf.get_sha1s([version])[version], + mpdiff, + ) + ] + ) + self.assertEqualDiff(vf.get_text(version), new_vf.get_text(version)) def test_make_mpdiffs_with_ghosts(self): - vf = self.get_file('foo') + vf = self.get_file("foo") try: - vf.add_lines_with_ghosts(b'text', [b'ghost'], 
[b'line\n']) + vf.add_lines_with_ghosts(b"text", [b"ghost"], [b"line\n"]) except NotImplementedError: # old Weave formats do not allow ghosts return - self.assertRaises(errors.RevisionNotPresent, - vf.make_mpdiffs, [b'ghost']) + self.assertRaises(errors.RevisionNotPresent, vf.make_mpdiffs, [b"ghost"]) def _setup_for_deltas(self, f): - self.assertFalse(f.has_version('base')) + self.assertFalse(f.has_version("base")) # add texts that should trip the knit maximum delta chain threshold # as well as doing parallel chains of data in knits. # this is done by two chains of 25 insertions - f.add_lines(b'base', [], [b'line\n']) - f.add_lines(b'noeol', [b'base'], [b'line']) + f.add_lines(b"base", [], [b"line\n"]) + f.add_lines(b"noeol", [b"base"], [b"line"]) # detailed eol tests: # shared last line with parent no-eol - f.add_lines(b'noeolsecond', [b'noeol'], [b'line\n', b'line']) + f.add_lines(b"noeolsecond", [b"noeol"], [b"line\n", b"line"]) # differing last line with parent, both no-eol - f.add_lines(b'noeolnotshared', [b'noeolsecond'], [b'line\n', b'phone']) + f.add_lines(b"noeolnotshared", [b"noeolsecond"], [b"line\n", b"phone"]) # add eol following a noneol parent, change content - f.add_lines(b'eol', [b'noeol'], [b'phone\n']) + f.add_lines(b"eol", [b"noeol"], [b"phone\n"]) # add eol following a noneol parent, no change content - f.add_lines(b'eolline', [b'noeol'], [b'line\n']) + f.add_lines(b"eolline", [b"noeol"], [b"line\n"]) # noeol with no parents: - f.add_lines(b'noeolbase', [], [b'line']) + f.add_lines(b"noeolbase", [], [b"line"]) # noeol preceeding its leftmost parent in the output: # this is done by making it a merge of two parents with no common # anestry: noeolbase and noeol with the # later-inserted parent the leftmost. - f.add_lines(b'eolbeforefirstparent', [ - b'noeolbase', b'noeol'], [b'line']) + f.add_lines(b"eolbeforefirstparent", [b"noeolbase", b"noeol"], [b"line"]) # two identical eol texts - f.add_lines(b'noeoldup', [b'noeol'], [b'line']) - next_parent = b'base' - text_name = b'chain1-' - text = [b'line\n'] - sha1s = {0: b'da6d3141cb4a5e6f464bf6e0518042ddc7bfd079', - 1: b'45e21ea146a81ea44a821737acdb4f9791c8abe7', - 2: b'e1f11570edf3e2a070052366c582837a4fe4e9fa', - 3: b'26b4b8626da827088c514b8f9bbe4ebf181edda1', - 4: b'e28a5510be25ba84d31121cff00956f9970ae6f6', - 5: b'd63ec0ce22e11dcf65a931b69255d3ac747a318d', - 6: b'2c2888d288cb5e1d98009d822fedfe6019c6a4ea', - 7: b'95c14da9cafbf828e3e74a6f016d87926ba234ab', - 8: b'779e9a0b28f9f832528d4b21e17e168c67697272', - 9: b'1f8ff4e5c6ff78ac106fcfe6b1e8cb8740ff9a8f', - 10: b'131a2ae712cf51ed62f143e3fbac3d4206c25a05', - 11: b'c5a9d6f520d2515e1ec401a8f8a67e6c3c89f199', - 12: b'31a2286267f24d8bedaa43355f8ad7129509ea85', - 13: b'dc2a7fe80e8ec5cae920973973a8ee28b2da5e0a', - 14: b'2c4b1736566b8ca6051e668de68650686a3922f2', - 15: b'5912e4ecd9b0c07be4d013e7e2bdcf9323276cde', - 16: b'b0d2e18d3559a00580f6b49804c23fea500feab3', - 17: b'8e1d43ad72f7562d7cb8f57ee584e20eb1a69fc7', - 18: b'5cf64a3459ae28efa60239e44b20312d25b253f3', - 19: b'1ebed371807ba5935958ad0884595126e8c4e823', - 20: b'2aa62a8b06fb3b3b892a3292a068ade69d5ee0d3', - 21: b'01edc447978004f6e4e962b417a4ae1955b6fe5d', - 22: b'd8d8dc49c4bf0bab401e0298bb5ad827768618bb', - 23: b'c21f62b1c482862983a8ffb2b0c64b3451876e3f', - 24: b'c0593fe795e00dff6b3c0fe857a074364d5f04fc', - 25: b'dd1a1cf2ba9cc225c3aff729953e6364bf1d1855', - } + f.add_lines(b"noeoldup", [b"noeol"], [b"line"]) + next_parent = b"base" + text_name = b"chain1-" + text = [b"line\n"] + sha1s = { + 0: 
b"da6d3141cb4a5e6f464bf6e0518042ddc7bfd079", + 1: b"45e21ea146a81ea44a821737acdb4f9791c8abe7", + 2: b"e1f11570edf3e2a070052366c582837a4fe4e9fa", + 3: b"26b4b8626da827088c514b8f9bbe4ebf181edda1", + 4: b"e28a5510be25ba84d31121cff00956f9970ae6f6", + 5: b"d63ec0ce22e11dcf65a931b69255d3ac747a318d", + 6: b"2c2888d288cb5e1d98009d822fedfe6019c6a4ea", + 7: b"95c14da9cafbf828e3e74a6f016d87926ba234ab", + 8: b"779e9a0b28f9f832528d4b21e17e168c67697272", + 9: b"1f8ff4e5c6ff78ac106fcfe6b1e8cb8740ff9a8f", + 10: b"131a2ae712cf51ed62f143e3fbac3d4206c25a05", + 11: b"c5a9d6f520d2515e1ec401a8f8a67e6c3c89f199", + 12: b"31a2286267f24d8bedaa43355f8ad7129509ea85", + 13: b"dc2a7fe80e8ec5cae920973973a8ee28b2da5e0a", + 14: b"2c4b1736566b8ca6051e668de68650686a3922f2", + 15: b"5912e4ecd9b0c07be4d013e7e2bdcf9323276cde", + 16: b"b0d2e18d3559a00580f6b49804c23fea500feab3", + 17: b"8e1d43ad72f7562d7cb8f57ee584e20eb1a69fc7", + 18: b"5cf64a3459ae28efa60239e44b20312d25b253f3", + 19: b"1ebed371807ba5935958ad0884595126e8c4e823", + 20: b"2aa62a8b06fb3b3b892a3292a068ade69d5ee0d3", + 21: b"01edc447978004f6e4e962b417a4ae1955b6fe5d", + 22: b"d8d8dc49c4bf0bab401e0298bb5ad827768618bb", + 23: b"c21f62b1c482862983a8ffb2b0c64b3451876e3f", + 24: b"c0593fe795e00dff6b3c0fe857a074364d5f04fc", + 25: b"dd1a1cf2ba9cc225c3aff729953e6364bf1d1855", + } for depth in range(26): - new_version = text_name + b'%d' % depth - text = text + [b'line\n'] + new_version = text_name + b"%d" % depth + text = text + [b"line\n"] f.add_lines(new_version, [next_parent], text) next_parent = new_version - next_parent = b'base' - text_name = b'chain2-' - text = [b'line\n'] + next_parent = b"base" + text_name = b"chain2-" + text = [b"line\n"] for depth in range(26): - new_version = text_name + b'%d' % depth - text = text + [b'line\n'] + new_version = text_name + b"%d" % depth + text = text + [b"line\n"] f.add_lines(new_version, [next_parent], text) next_parent = new_version return sha1s @@ -514,35 +585,34 @@ def _setup_for_deltas(self, f): def test_ancestry(self): f = self.get_file() self.assertEqual(set(), f.get_ancestry([])) - f.add_lines(b'r0', [], [b'a\n', b'b\n']) - f.add_lines(b'r1', [b'r0'], [b'b\n', b'c\n']) - f.add_lines(b'r2', [b'r0'], [b'b\n', b'c\n']) - f.add_lines(b'r3', [b'r2'], [b'b\n', b'c\n']) - f.add_lines(b'rM', [b'r1', b'r2'], [b'b\n', b'c\n']) + f.add_lines(b"r0", [], [b"a\n", b"b\n"]) + f.add_lines(b"r1", [b"r0"], [b"b\n", b"c\n"]) + f.add_lines(b"r2", [b"r0"], [b"b\n", b"c\n"]) + f.add_lines(b"r3", [b"r2"], [b"b\n", b"c\n"]) + f.add_lines(b"rM", [b"r1", b"r2"], [b"b\n", b"c\n"]) self.assertEqual(set(), f.get_ancestry([])) - f.get_ancestry([b'rM']) + f.get_ancestry([b"rM"]) - self.assertRaises(RevisionNotPresent, - f.get_ancestry, [b'rM', b'rX']) + self.assertRaises(RevisionNotPresent, f.get_ancestry, [b"rM", b"rX"]) - self.assertEqual(set(f.get_ancestry(b'rM')), - set(f.get_ancestry(b'rM'))) + self.assertEqual(set(f.get_ancestry(b"rM")), set(f.get_ancestry(b"rM"))) def test_mutate_after_finish(self): - self._transaction = 'before' + self._transaction = "before" f = self.get_file() - self._transaction = 'after' - self.assertRaises(errors.OutSideTransaction, f.add_lines, b'', [], []) - self.assertRaises(errors.OutSideTransaction, - f.add_lines_with_ghosts, b'', [], []) + self._transaction = "after" + self.assertRaises(errors.OutSideTransaction, f.add_lines, b"", [], []) + self.assertRaises( + errors.OutSideTransaction, f.add_lines_with_ghosts, b"", [], [] + ) def test_copy_to(self): f = self.get_file() - f.add_lines(b'0', [], [b'a\n']) + 
f.add_lines(b"0", [], [b"a\n"]) t = MemoryTransport() - f.copy_to('foo', t) + f.copy_to("foo", t) for suffix in self.get_factory().get_suffixes(): - self.assertTrue(t.has('foo' + suffix)) + self.assertTrue(t.has("foo" + suffix)) def test_get_suffixes(self): self.get_file() @@ -551,37 +621,29 @@ def test_get_suffixes(self): def test_get_parent_map(self): f = self.get_file() - f.add_lines(b'r0', [], [b'a\n', b'b\n']) - self.assertEqual( - {b'r0': ()}, f.get_parent_map([b'r0'])) - f.add_lines(b'r1', [b'r0'], [b'a\n', b'b\n']) - self.assertEqual( - {b'r1': (b'r0',)}, f.get_parent_map([b'r1'])) + f.add_lines(b"r0", [], [b"a\n", b"b\n"]) + self.assertEqual({b"r0": ()}, f.get_parent_map([b"r0"])) + f.add_lines(b"r1", [b"r0"], [b"a\n", b"b\n"]) + self.assertEqual({b"r1": (b"r0",)}, f.get_parent_map([b"r1"])) + self.assertEqual({b"r0": (), b"r1": (b"r0",)}, f.get_parent_map([b"r0", b"r1"])) + f.add_lines(b"r2", [], [b"a\n", b"b\n"]) + f.add_lines(b"r3", [], [b"a\n", b"b\n"]) + f.add_lines(b"m", [b"r0", b"r1", b"r2", b"r3"], [b"a\n", b"b\n"]) + self.assertEqual({b"m": (b"r0", b"r1", b"r2", b"r3")}, f.get_parent_map([b"m"])) + self.assertEqual({}, f.get_parent_map(b"y")) self.assertEqual( - {b'r0': (), - b'r1': (b'r0',)}, - f.get_parent_map([b'r0', b'r1'])) - f.add_lines(b'r2', [], [b'a\n', b'b\n']) - f.add_lines(b'r3', [], [b'a\n', b'b\n']) - f.add_lines(b'm', [b'r0', b'r1', b'r2', b'r3'], [b'a\n', b'b\n']) - self.assertEqual( - {b'm': (b'r0', b'r1', b'r2', b'r3')}, f.get_parent_map([b'm'])) - self.assertEqual({}, f.get_parent_map(b'y')) - self.assertEqual( - {b'r0': (), - b'r1': (b'r0',)}, - f.get_parent_map([b'r0', b'y', b'r1'])) + {b"r0": (), b"r1": (b"r0",)}, f.get_parent_map([b"r0", b"y", b"r1"]) + ) def test_annotate(self): f = self.get_file() - f.add_lines(b'r0', [], [b'a\n', b'b\n']) - f.add_lines(b'r1', [b'r0'], [b'c\n', b'b\n']) - origins = f.annotate(b'r1') - self.assertEqual(origins[0][0], b'r1') - self.assertEqual(origins[1][0], b'r0') + f.add_lines(b"r0", [], [b"a\n", b"b\n"]) + f.add_lines(b"r1", [b"r0"], [b"c\n", b"b\n"]) + origins = f.annotate(b"r1") + self.assertEqual(origins[0][0], b"r1") + self.assertEqual(origins[1][0], b"r0") - self.assertRaises(RevisionNotPresent, - f.annotate, b'foo') + self.assertRaises(RevisionNotPresent, f.annotate, b"foo") def test_detection(self): # Test weaves detect corruption. 
@@ -592,23 +654,23 @@ def test_detection(self): w = self.get_file_corrupted_text() - self.assertEqual(b'hello\n', w.get_text(b'v1')) - self.assertRaises(WeaveInvalidChecksum, w.get_text, b'v2') - self.assertRaises(WeaveInvalidChecksum, w.get_lines, b'v2') + self.assertEqual(b"hello\n", w.get_text(b"v1")) + self.assertRaises(WeaveInvalidChecksum, w.get_text, b"v2") + self.assertRaises(WeaveInvalidChecksum, w.get_lines, b"v2") self.assertRaises(WeaveInvalidChecksum, w.check) w = self.get_file_corrupted_checksum() - self.assertEqual(b'hello\n', w.get_text(b'v1')) - self.assertRaises(WeaveInvalidChecksum, w.get_text, b'v2') - self.assertRaises(WeaveInvalidChecksum, w.get_lines, b'v2') + self.assertEqual(b"hello\n", w.get_text(b"v1")) + self.assertRaises(WeaveInvalidChecksum, w.get_text, b"v2") + self.assertRaises(WeaveInvalidChecksum, w.get_lines, b"v2") self.assertRaises(WeaveInvalidChecksum, w.check) def get_file_corrupted_text(self): """Return a versioned file with corrupt text but valid metadata.""" raise NotImplementedError(self.get_file_corrupted_text) - def reopen_file(self, name='foo'): + def reopen_file(self, name="foo"): """Open the versioned file from disk again.""" raise NotImplementedError(self.reopen_file) @@ -619,7 +681,6 @@ def test_iter_lines_added_or_present_in_versions(self): # more changes to muck up. class InstrumentedProgress(progress.ProgressTask): - def __init__(self): progress.ProgressTask.__init__(self) self.updates = [] @@ -629,52 +690,65 @@ def update(self, msg=None, current=None, total=None): vf = self.get_file() # add a base to get included - vf.add_lines(b'base', [], [b'base\n']) + vf.add_lines(b"base", [], [b"base\n"]) # add a ancestor to be included on one side - vf.add_lines(b'lancestor', [], [b'lancestor\n']) + vf.add_lines(b"lancestor", [], [b"lancestor\n"]) # add a ancestor to be included on the other side - vf.add_lines(b'rancestor', [b'base'], [b'rancestor\n']) + vf.add_lines(b"rancestor", [b"base"], [b"rancestor\n"]) # add a child of rancestor with no eofile-nl - vf.add_lines(b'child', [b'rancestor'], [b'base\n', b'child\n']) + vf.add_lines(b"child", [b"rancestor"], [b"base\n", b"child\n"]) # add a child of lancestor and base to join the two roots - vf.add_lines(b'otherchild', - [b'lancestor', b'base'], - [b'base\n', b'lancestor\n', b'otherchild\n']) + vf.add_lines( + b"otherchild", + [b"lancestor", b"base"], + [b"base\n", b"lancestor\n", b"otherchild\n"], + ) def iter_with_versions(versions, expected): # now we need to see what lines are returned, and how often. lines = {} progress = InstrumentedProgress() # iterate over the lines - for line in vf.iter_lines_added_or_present_in_versions(versions, - pb=progress): + for line in vf.iter_lines_added_or_present_in_versions( + versions, pb=progress + ): lines.setdefault(line, 0) lines[line] += 1 if [] != progress.updates: self.assertEqual(expected, progress.updates) return lines - lines = iter_with_versions([b'child', b'otherchild'], - [('Walking content', 0, 2), - ('Walking content', 1, 2), - ('Walking content', 2, 2)]) + + lines = iter_with_versions( + [b"child", b"otherchild"], + [ + ("Walking content", 0, 2), + ("Walking content", 1, 2), + ("Walking content", 2, 2), + ], + ) # we must see child and otherchild - self.assertTrue(lines[(b'child\n', b'child')] > 0) - self.assertTrue(lines[(b'otherchild\n', b'otherchild')] > 0) + self.assertTrue(lines[(b"child\n", b"child")] > 0) + self.assertTrue(lines[(b"otherchild\n", b"otherchild")] > 0) # we dont care if we got more than that. 
# test all lines - lines = iter_with_versions(None, [('Walking content', 0, 5), - ('Walking content', 1, 5), - ('Walking content', 2, 5), - ('Walking content', 3, 5), - ('Walking content', 4, 5), - ('Walking content', 5, 5)]) + lines = iter_with_versions( + None, + [ + ("Walking content", 0, 5), + ("Walking content", 1, 5), + ("Walking content", 2, 5), + ("Walking content", 3, 5), + ("Walking content", 4, 5), + ("Walking content", 5, 5), + ], + ) # all lines must be seen at least once - self.assertTrue(lines[(b'base\n', b'base')] > 0) - self.assertTrue(lines[(b'lancestor\n', b'lancestor')] > 0) - self.assertTrue(lines[(b'rancestor\n', b'rancestor')] > 0) - self.assertTrue(lines[(b'child\n', b'child')] > 0) - self.assertTrue(lines[(b'otherchild\n', b'otherchild')] > 0) + self.assertTrue(lines[(b"base\n", b"base")] > 0) + self.assertTrue(lines[(b"lancestor\n", b"lancestor")] > 0) + self.assertTrue(lines[(b"rancestor\n", b"rancestor")] > 0) + self.assertTrue(lines[(b"child\n", b"child")] > 0) + self.assertTrue(lines[(b"otherchild\n", b"otherchild")] > 0) def test_add_lines_with_ghosts(self): # some versioned file formats allow lines to be added with parent @@ -684,41 +758,43 @@ def test_add_lines_with_ghosts(self): vf = self.get_file() # add a revision with ghost parents # The preferred form is utf8, but we should translate when needed - parent_id_unicode = 'b\xbfse' - parent_id_utf8 = parent_id_unicode.encode('utf8') + parent_id_unicode = "b\xbfse" + parent_id_utf8 = parent_id_unicode.encode("utf8") try: - vf.add_lines_with_ghosts(b'notbxbfse', [parent_id_utf8], []) + vf.add_lines_with_ghosts(b"notbxbfse", [parent_id_utf8], []) except NotImplementedError: # check the other ghost apis are also not implemented - self.assertRaises(NotImplementedError, - vf.get_ancestry_with_ghosts, [b'foo']) - self.assertRaises(NotImplementedError, - vf.get_parents_with_ghosts, b'foo') + self.assertRaises( + NotImplementedError, vf.get_ancestry_with_ghosts, [b"foo"] + ) + self.assertRaises(NotImplementedError, vf.get_parents_with_ghosts, b"foo") return vf = self.reopen_file() # test key graph related apis: getncestry, _graph, get_parents # has_version # - these are ghost unaware and must not be reflect ghosts - self.assertEqual({b'notbxbfse'}, vf.get_ancestry(b'notbxbfse')) + self.assertEqual({b"notbxbfse"}, vf.get_ancestry(b"notbxbfse")) self.assertFalse(vf.has_version(parent_id_utf8)) # we have _with_ghost apis to give us ghost information. - self.assertEqual({parent_id_utf8, b'notbxbfse'}, - vf.get_ancestry_with_ghosts([b'notbxbfse'])) - self.assertEqual([parent_id_utf8], - vf.get_parents_with_ghosts(b'notbxbfse')) + self.assertEqual( + {parent_id_utf8, b"notbxbfse"}, vf.get_ancestry_with_ghosts([b"notbxbfse"]) + ) + self.assertEqual([parent_id_utf8], vf.get_parents_with_ghosts(b"notbxbfse")) # if we add something that is a ghost of another, it should correct the # results of the prior apis vf.add_lines(parent_id_utf8, [], []) - self.assertEqual({parent_id_utf8, b'notbxbfse'}, - vf.get_ancestry([b'notbxbfse'])) - self.assertEqual({b'notbxbfse': (parent_id_utf8,)}, - vf.get_parent_map([b'notbxbfse'])) + self.assertEqual( + {parent_id_utf8, b"notbxbfse"}, vf.get_ancestry([b"notbxbfse"]) + ) + self.assertEqual( + {b"notbxbfse": (parent_id_utf8,)}, vf.get_parent_map([b"notbxbfse"]) + ) self.assertTrue(vf.has_version(parent_id_utf8)) # we have _with_ghost apis to give us ghost information. 
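A compressed view of the ghost behaviour this hunk checks: plain graph queries never report ghost parents, the _with_ghosts variants do, and adding the ghost's text later makes it visible everywhere. The sketch assumes a vf whose add_lines_with_ghosts is implemented (the test returns early when it is not) and reuses the test's ids:

    parent_id_utf8 = "b\xbfse".encode("utf8")
    vf.add_lines_with_ghosts(b"notbxbfse", [parent_id_utf8], [])
    # ghost-unaware APIs ignore the missing parent
    self.assertEqual({b"notbxbfse"}, vf.get_ancestry(b"notbxbfse"))
    self.assertFalse(vf.has_version(parent_id_utf8))
    # the _with_ghosts APIs expose it
    self.assertEqual(
        {parent_id_utf8, b"notbxbfse"}, vf.get_ancestry_with_ghosts([b"notbxbfse"])
    )
    self.assertEqual([parent_id_utf8], vf.get_parents_with_ghosts(b"notbxbfse"))
    # once the ghost's text is added, the plain APIs see it as well
    vf.add_lines(parent_id_utf8, [], [])
    self.assertTrue(vf.has_version(parent_id_utf8))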
- self.assertEqual({parent_id_utf8, b'notbxbfse'}, - vf.get_ancestry_with_ghosts([b'notbxbfse'])) - self.assertEqual([parent_id_utf8], - vf.get_parents_with_ghosts(b'notbxbfse')) + self.assertEqual( + {parent_id_utf8, b"notbxbfse"}, vf.get_ancestry_with_ghosts([b"notbxbfse"]) + ) + self.assertEqual([parent_id_utf8], vf.get_parents_with_ghosts(b"notbxbfse")) def test_add_lines_with_ghosts_after_normal_revs(self): # some versioned file formats allow lines to be added with parent @@ -728,160 +804,184 @@ def test_add_lines_with_ghosts_after_normal_revs(self): vf = self.get_file() # probe for ghost support try: - vf.add_lines_with_ghosts(b'base', [], [b'line\n', b'line_b\n']) + vf.add_lines_with_ghosts(b"base", [], [b"line\n", b"line_b\n"]) except NotImplementedError: return - vf.add_lines_with_ghosts(b'references_ghost', - [b'base', b'a_ghost'], - [b'line\n', b'line_b\n', b'line_c\n']) - origins = vf.annotate(b'references_ghost') - self.assertEqual((b'base', b'line\n'), origins[0]) - self.assertEqual((b'base', b'line_b\n'), origins[1]) - self.assertEqual((b'references_ghost', b'line_c\n'), origins[2]) + vf.add_lines_with_ghosts( + b"references_ghost", + [b"base", b"a_ghost"], + [b"line\n", b"line_b\n", b"line_c\n"], + ) + origins = vf.annotate(b"references_ghost") + self.assertEqual((b"base", b"line\n"), origins[0]) + self.assertEqual((b"base", b"line_b\n"), origins[1]) + self.assertEqual((b"references_ghost", b"line_c\n"), origins[2]) def test_readonly_mode(self): t = self.get_transport() factory = self.get_factory() - vf = factory('id', t, 0o777, create=True, access_mode='w') - vf = factory('id', t, access_mode='r') - self.assertRaises(errors.ReadOnlyError, vf.add_lines, b'base', [], []) - self.assertRaises(errors.ReadOnlyError, - vf.add_lines_with_ghosts, - b'base', - [], - []) + vf = factory("id", t, 0o777, create=True, access_mode="w") + vf = factory("id", t, access_mode="r") + self.assertRaises(errors.ReadOnlyError, vf.add_lines, b"base", [], []) + self.assertRaises( + errors.ReadOnlyError, vf.add_lines_with_ghosts, b"base", [], [] + ) def test_get_sha1s(self): # check the sha1 data is available vf = self.get_file() # a simple file - vf.add_lines(b'a', [], [b'a\n']) + vf.add_lines(b"a", [], [b"a\n"]) # the same file, different metadata - vf.add_lines(b'b', [b'a'], [b'a\n']) + vf.add_lines(b"b", [b"a"], [b"a\n"]) # a file differing only in last newline. 
- vf.add_lines(b'c', [], [b'a']) - self.assertEqual({ - b'a': b'3f786850e387550fdab836ed7e6dc881de23001b', - b'c': b'86f7e437faa5a7fce15d1ddcb9eaeaea377667b8', - b'b': b'3f786850e387550fdab836ed7e6dc881de23001b', + vf.add_lines(b"c", [], [b"a"]) + self.assertEqual( + { + b"a": b"3f786850e387550fdab836ed7e6dc881de23001b", + b"c": b"86f7e437faa5a7fce15d1ddcb9eaeaea377667b8", + b"b": b"3f786850e387550fdab836ed7e6dc881de23001b", }, - vf.get_sha1s([b'a', b'c', b'b'])) + vf.get_sha1s([b"a", b"c", b"b"]), + ) class TestWeave(TestCaseWithMemoryTransport, VersionedFileTestMixIn): - - def get_file(self, name='foo'): - return WeaveFile(name, self.get_transport(), - create=True, - get_scope=self.get_transaction) + def get_file(self, name="foo"): + return WeaveFile( + name, self.get_transport(), create=True, get_scope=self.get_transaction + ) def get_file_corrupted_text(self): - w = WeaveFile('foo', self.get_transport(), - create=True, - get_scope=self.get_transaction) - w.add_lines(b'v1', [], [b'hello\n']) - w.add_lines(b'v2', [b'v1'], [b'hello\n', b'there\n']) + w = WeaveFile( + "foo", self.get_transport(), create=True, get_scope=self.get_transaction + ) + w.add_lines(b"v1", [], [b"hello\n"]) + w.add_lines(b"v2", [b"v1"], [b"hello\n", b"there\n"]) # We are going to invasively corrupt the text # Make sure the internals of weave are the same - self.assertEqual([(b'{', 0), b'hello\n', (b'}', None), (b'{', 1), b'there\n', (b'}', None) - ], w._weave) + self.assertEqual( + [(b"{", 0), b"hello\n", (b"}", None), (b"{", 1), b"there\n", (b"}", None)], + w._weave, + ) - self.assertEqual([b'f572d396fae9206628714fb2ce00f72e94f2258f', b'90f265c6e75f1c8f9ab76dcf85528352c5f215ef' - ], w._sha1s) + self.assertEqual( + [ + b"f572d396fae9206628714fb2ce00f72e94f2258f", + b"90f265c6e75f1c8f9ab76dcf85528352c5f215ef", + ], + w._sha1s, + ) w.check() # Corrupted - w._weave[4] = b'There\n' + w._weave[4] = b"There\n" return w def get_file_corrupted_checksum(self): w = self.get_file_corrupted_text() # Corrected - w._weave[4] = b'there\n' - self.assertEqual(b'hello\nthere\n', w.get_text(b'v2')) + w._weave[4] = b"there\n" + self.assertEqual(b"hello\nthere\n", w.get_text(b"v2")) # Invalid checksum, first digit changed - w._sha1s[1] = b'f0f265c6e75f1c8f9ab76dcf85528352c5f215ef' + w._sha1s[1] = b"f0f265c6e75f1c8f9ab76dcf85528352c5f215ef" return w - def reopen_file(self, name='foo', create=False): - return WeaveFile(name, self.get_transport(), - create=create, - get_scope=self.get_transaction) + def reopen_file(self, name="foo", create=False): + return WeaveFile( + name, self.get_transport(), create=create, get_scope=self.get_transaction + ) def test_no_implicit_create(self): - self.assertRaises(transport.NoSuchFile, - WeaveFile, - 'foo', - self.get_transport(), - get_scope=self.get_transaction) + self.assertRaises( + transport.NoSuchFile, + WeaveFile, + "foo", + self.get_transport(), + get_scope=self.get_transaction, + ) def get_factory(self): return WeaveFile class TestPlanMergeVersionedFile(TestCaseWithMemoryTransport): - def setUp(self): super().setUp() mapper = PrefixMapper() factory = make_file_factory(True, mapper) - self.vf1 = factory(self.get_transport('root-1')) - self.vf2 = factory(self.get_transport('root-2')) - self.plan_merge_vf = versionedfile._PlanMergeVersionedFile('root') + self.vf1 = factory(self.get_transport("root-1")) + self.vf2 = factory(self.get_transport("root-2")) + self.plan_merge_vf = versionedfile._PlanMergeVersionedFile("root") self.plan_merge_vf.fallback_versionedfiles.extend([self.vf1, self.vf2]) def 
test_add_lines(self): - self.plan_merge_vf.add_lines((b'root', b'a:'), [], []) - self.assertRaises(ValueError, self.plan_merge_vf.add_lines, - (b'root', b'a'), [], []) - self.assertRaises(ValueError, self.plan_merge_vf.add_lines, - (b'root', b'a:'), None, []) - self.assertRaises(ValueError, self.plan_merge_vf.add_lines, - (b'root', b'a:'), [], None) + self.plan_merge_vf.add_lines((b"root", b"a:"), [], []) + self.assertRaises( + ValueError, self.plan_merge_vf.add_lines, (b"root", b"a"), [], [] + ) + self.assertRaises( + ValueError, self.plan_merge_vf.add_lines, (b"root", b"a:"), None, [] + ) + self.assertRaises( + ValueError, self.plan_merge_vf.add_lines, (b"root", b"a:"), [], None + ) def setup_abcde(self): - self.vf1.add_lines((b'root', b'A'), [], [b'a']) - self.vf1.add_lines((b'root', b'B'), [(b'root', b'A')], [b'b']) - self.vf2.add_lines((b'root', b'C'), [], [b'c']) - self.vf2.add_lines((b'root', b'D'), [(b'root', b'C')], [b'd']) - self.plan_merge_vf.add_lines((b'root', b'E:'), - [(b'root', b'B'), (b'root', b'D')], [b'e']) + self.vf1.add_lines((b"root", b"A"), [], [b"a"]) + self.vf1.add_lines((b"root", b"B"), [(b"root", b"A")], [b"b"]) + self.vf2.add_lines((b"root", b"C"), [], [b"c"]) + self.vf2.add_lines((b"root", b"D"), [(b"root", b"C")], [b"d"]) + self.plan_merge_vf.add_lines( + (b"root", b"E:"), [(b"root", b"B"), (b"root", b"D")], [b"e"] + ) def test_get_parents(self): self.setup_abcde() - self.assertEqual({(b'root', b'B'): ((b'root', b'A'),)}, - self.plan_merge_vf.get_parent_map([(b'root', b'B')])) - self.assertEqual({(b'root', b'D'): ((b'root', b'C'),)}, - self.plan_merge_vf.get_parent_map([(b'root', b'D')])) - self.assertEqual({(b'root', b'E:'): ((b'root', b'B'), (b'root', b'D'))}, - self.plan_merge_vf.get_parent_map([(b'root', b'E:')])) - self.assertEqual({}, - self.plan_merge_vf.get_parent_map([(b'root', b'F')])) - self.assertEqual({ - (b'root', b'B'): ((b'root', b'A'),), - (b'root', b'D'): ((b'root', b'C'),), - (b'root', b'E:'): ((b'root', b'B'), (b'root', b'D')), + self.assertEqual( + {(b"root", b"B"): ((b"root", b"A"),)}, + self.plan_merge_vf.get_parent_map([(b"root", b"B")]), + ) + self.assertEqual( + {(b"root", b"D"): ((b"root", b"C"),)}, + self.plan_merge_vf.get_parent_map([(b"root", b"D")]), + ) + self.assertEqual( + {(b"root", b"E:"): ((b"root", b"B"), (b"root", b"D"))}, + self.plan_merge_vf.get_parent_map([(b"root", b"E:")]), + ) + self.assertEqual({}, self.plan_merge_vf.get_parent_map([(b"root", b"F")])) + self.assertEqual( + { + (b"root", b"B"): ((b"root", b"A"),), + (b"root", b"D"): ((b"root", b"C"),), + (b"root", b"E:"): ((b"root", b"B"), (b"root", b"D")), }, self.plan_merge_vf.get_parent_map( - [(b'root', b'B'), (b'root', b'D'), (b'root', b'E:'), (b'root', b'F')])) + [(b"root", b"B"), (b"root", b"D"), (b"root", b"E:"), (b"root", b"F")] + ), + ) def test_get_record_stream(self): self.setup_abcde() def get_record(suffix): - return next(self.plan_merge_vf.get_record_stream( - [(b'root', suffix)], 'unordered', True)) - self.assertEqual(b'a', get_record(b'A').get_bytes_as('fulltext')) - self.assertEqual(b'a', b''.join(get_record(b'A').iter_bytes_as('chunked'))) - self.assertEqual(b'c', get_record(b'C').get_bytes_as('fulltext')) - self.assertEqual(b'e', get_record(b'E:').get_bytes_as('fulltext')) - self.assertEqual('absent', get_record('F').storage_kind) + return next( + self.plan_merge_vf.get_record_stream( + [(b"root", suffix)], "unordered", True + ) + ) + self.assertEqual(b"a", get_record(b"A").get_bytes_as("fulltext")) + self.assertEqual(b"a", 
b"".join(get_record(b"A").iter_bytes_as("chunked"))) + self.assertEqual(b"c", get_record(b"C").get_bytes_as("fulltext")) + self.assertEqual(b"e", get_record(b"E:").get_bytes_as("fulltext")) + self.assertEqual("absent", get_record("F").storage_kind) -class TestReadonlyHttpMixin: +class TestReadonlyHttpMixin: def get_transaction(self): return 1 @@ -889,56 +989,55 @@ def test_readonly_http_works(self): # we should be able to read from http with a versioned file. self.get_file() # try an empty file access - readonly_vf = self.get_factory()('foo', - transport.get_transport_from_url(self.get_readonly_url('.'))) + readonly_vf = self.get_factory()( + "foo", transport.get_transport_from_url(self.get_readonly_url(".")) + ) self.assertEqual([], readonly_vf.versions()) def test_readonly_http_works_with_feeling(self): # we should be able to read from http with a versioned file. vf = self.get_file() # now with feeling. - vf.add_lines(b'1', [], [b'a\n']) - vf.add_lines(b'2', [b'1'], [b'b\n', b'a\n']) - readonly_vf = self.get_factory()('foo', - transport.get_transport_from_url(self.get_readonly_url('.'))) - self.assertEqual([b'1', b'2'], vf.versions()) - self.assertEqual([b'1', b'2'], readonly_vf.versions()) + vf.add_lines(b"1", [], [b"a\n"]) + vf.add_lines(b"2", [b"1"], [b"b\n", b"a\n"]) + readonly_vf = self.get_factory()( + "foo", transport.get_transport_from_url(self.get_readonly_url(".")) + ) + self.assertEqual([b"1", b"2"], vf.versions()) + self.assertEqual([b"1", b"2"], readonly_vf.versions()) for version in readonly_vf.versions(): readonly_vf.get_lines(version) class TestWeaveHTTP(TestCaseWithWebserver, TestReadonlyHttpMixin): - def get_file(self): - return WeaveFile('foo', self.get_transport(), - create=True, - get_scope=self.get_transaction) + return WeaveFile( + "foo", self.get_transport(), create=True, get_scope=self.get_transaction + ) def get_factory(self): return WeaveFile class MergeCasesMixin: - def doMerge(self, base, a, b, mp): - def addcrlf(x): - return x + b'\n' + return x + b"\n" w = self.get_file() - w.add_lines(b'text0', [], list(map(addcrlf, base))) - w.add_lines(b'text1', [b'text0'], list(map(addcrlf, a))) - w.add_lines(b'text2', [b'text0'], list(map(addcrlf, b))) + w.add_lines(b"text0", [], list(map(addcrlf, base))) + w.add_lines(b"text1", [b"text0"], list(map(addcrlf, a))) + w.add_lines(b"text2", [b"text0"], list(map(addcrlf, b))) self.log_contents(w) - self.log('merge plan:') - p = list(w.plan_merge(b'text1', b'text2')) + self.log("merge plan:") + p = list(w.plan_merge(b"text1", b"text2")) for state, line in p: if line: - self.log('%12s | %s' % (state, line[:-1])) + self.log("%12s | %s" % (state, line[:-1])) - self.log('merge:') + self.log("merge:") mt = BytesIO() mt.writelines(w.weave_merge(p)) mt.seek(0) @@ -948,51 +1047,55 @@ def addcrlf(x): self.assertEqual(mt.readlines(), mp) def testOneInsert(self): - self.doMerge([], - [b'aa'], - [], - [b'aa']) + self.doMerge([], [b"aa"], [], [b"aa"]) def testSeparateInserts(self): - self.doMerge([b'aaa', b'bbb', b'ccc'], - [b'aaa', b'xxx', b'bbb', b'ccc'], - [b'aaa', b'bbb', b'yyy', b'ccc'], - [b'aaa', b'xxx', b'bbb', b'yyy', b'ccc']) + self.doMerge( + [b"aaa", b"bbb", b"ccc"], + [b"aaa", b"xxx", b"bbb", b"ccc"], + [b"aaa", b"bbb", b"yyy", b"ccc"], + [b"aaa", b"xxx", b"bbb", b"yyy", b"ccc"], + ) def testSameInsert(self): - self.doMerge([b'aaa', b'bbb', b'ccc'], - [b'aaa', b'xxx', b'bbb', b'ccc'], - [b'aaa', b'xxx', b'bbb', b'yyy', b'ccc'], - [b'aaa', b'xxx', b'bbb', b'yyy', b'ccc']) - overlapped_insert_expected = [b'aaa', b'xxx', 
b'yyy', b'bbb'] + self.doMerge( + [b"aaa", b"bbb", b"ccc"], + [b"aaa", b"xxx", b"bbb", b"ccc"], + [b"aaa", b"xxx", b"bbb", b"yyy", b"ccc"], + [b"aaa", b"xxx", b"bbb", b"yyy", b"ccc"], + ) + + overlapped_insert_expected = [b"aaa", b"xxx", b"yyy", b"bbb"] def testOverlappedInsert(self): - self.doMerge([b'aaa', b'bbb'], - [b'aaa', b'xxx', b'yyy', b'bbb'], - [b'aaa', b'xxx', b'bbb'], self.overlapped_insert_expected) + self.doMerge( + [b"aaa", b"bbb"], + [b"aaa", b"xxx", b"yyy", b"bbb"], + [b"aaa", b"xxx", b"bbb"], + self.overlapped_insert_expected, + ) # really it ought to reduce this to # [b'aaa', b'xxx', b'yyy', b'bbb'] def testClashReplace(self): - self.doMerge([b'aaa'], - [b'xxx'], - [b'yyy', b'zzz'], - [b'<<<<<<< ', b'xxx', b'=======', b'yyy', b'zzz', - b'>>>>>>> ']) + self.doMerge( + [b"aaa"], + [b"xxx"], + [b"yyy", b"zzz"], + [b"<<<<<<< ", b"xxx", b"=======", b"yyy", b"zzz", b">>>>>>> "], + ) def testNonClashInsert1(self): - self.doMerge([b'aaa'], - [b'xxx', b'aaa'], - [b'yyy', b'zzz'], - [b'<<<<<<< ', b'xxx', b'aaa', b'=======', b'yyy', b'zzz', - b'>>>>>>> ']) + self.doMerge( + [b"aaa"], + [b"xxx", b"aaa"], + [b"yyy", b"zzz"], + [b"<<<<<<< ", b"xxx", b"aaa", b"=======", b"yyy", b"zzz", b">>>>>>> "], + ) def testNonClashInsert2(self): - self.doMerge([b'aaa'], - [b'aaa'], - [b'yyy', b'zzz'], - [b'yyy', b'zzz']) + self.doMerge([b"aaa"], [b"aaa"], [b"yyy", b"zzz"], [b"yyy", b"zzz"]) def testDeleteAndModify(self): """Clashing delete and modification. @@ -1004,30 +1107,32 @@ def testDeleteAndModify(self): # skippd, not working yet return - self.doMerge([b'aaa', b'bbb', b'ccc'], - [b'aaa', b'ddd', b'ccc'], - [b'aaa', b'ccc'], - [b'<<<<<<<< ', b'aaa', b'=======', b'>>>>>>> ', b'ccc']) + self.doMerge( + [b"aaa", b"bbb", b"ccc"], + [b"aaa", b"ddd", b"ccc"], + [b"aaa", b"ccc"], + [b"<<<<<<<< ", b"aaa", b"=======", b">>>>>>> ", b"ccc"], + ) def _test_merge_from_strings(self, base, a, b, expected): w = self.get_file() - w.add_lines(b'text0', [], base.splitlines(True)) - w.add_lines(b'text1', [b'text0'], a.splitlines(True)) - w.add_lines(b'text2', [b'text0'], b.splitlines(True)) - self.log('merge plan:') - p = list(w.plan_merge(b'text1', b'text2')) + w.add_lines(b"text0", [], base.splitlines(True)) + w.add_lines(b"text1", [b"text0"], a.splitlines(True)) + w.add_lines(b"text2", [b"text0"], b.splitlines(True)) + self.log("merge plan:") + p = list(w.plan_merge(b"text1", b"text2")) for state, line in p: if line: - self.log('%12s | %s' % (state, line[:-1])) - self.log('merge result:') - result_text = b''.join(w.weave_merge(p)) + self.log("%12s | %s" % (state, line[:-1])) + self.log("merge result:") + result_text = b"".join(w.weave_merge(p)) self.log(result_text) self.assertEqualDiff(result_text, expected) def test_weave_merge_conflicts(self): # does weave merge properly handle plans that end with unchanged? 
- result = b''.join(self.get_file().weave_merge([('new-a', b'hello\n')])) - self.assertEqual(result, b'hello\n') + result = b"".join(self.get_file().weave_merge([("new-a", b"hello\n")])) + self.assertEqual(result, b"hello\n") def test_deletion_extended(self): """One side deletes, the other deletes more.""" @@ -1154,66 +1259,68 @@ def test_sync_on_deletion(self): class TestWeaveMerge(TestCaseWithMemoryTransport, MergeCasesMixin): - - def get_file(self, name='foo'): - return WeaveFile(name, self.get_transport(), - create=True) + def get_file(self, name="foo"): + return WeaveFile(name, self.get_transport(), create=True) def log_contents(self, w): - self.log('weave is:') + self.log("weave is:") tmpf = BytesIO() write_weave(w, tmpf) self.log(tmpf.getvalue()) - overlapped_insert_expected = [b'aaa', b'<<<<<<< ', b'xxx', b'yyy', b'=======', - b'xxx', b'>>>>>>> ', b'bbb'] + overlapped_insert_expected = [ + b"aaa", + b"<<<<<<< ", + b"xxx", + b"yyy", + b"=======", + b"xxx", + b">>>>>>> ", + b"bbb", + ] class TestContentFactoryAdaption(TestCaseWithMemoryTransport): - def test_select_adaptor(self): """Test expected adapters exist.""" # One scenario for each lookup combination we expect to use. # Each is source_kind, requested_kind, adapter class scenarios = [ - ('knit-delta-gz', 'fulltext', _mod_knit.DeltaPlainToFullText), - ('knit-delta-gz', 'lines', _mod_knit.DeltaPlainToFullText), - ('knit-delta-gz', 'chunked', _mod_knit.DeltaPlainToFullText), - ('knit-ft-gz', 'fulltext', _mod_knit.FTPlainToFullText), - ('knit-ft-gz', 'lines', _mod_knit.FTPlainToFullText), - ('knit-ft-gz', 'chunked', _mod_knit.FTPlainToFullText), - ('knit-annotated-delta-gz', 'knit-delta-gz', - _mod_knit.DeltaAnnotatedToUnannotated), - ('knit-annotated-delta-gz', 'fulltext', - _mod_knit.DeltaAnnotatedToFullText), - ('knit-annotated-ft-gz', 'knit-ft-gz', - _mod_knit.FTAnnotatedToUnannotated), - ('knit-annotated-ft-gz', 'fulltext', - _mod_knit.FTAnnotatedToFullText), - ('knit-annotated-ft-gz', 'lines', - _mod_knit.FTAnnotatedToFullText), - ('knit-annotated-ft-gz', 'chunked', - _mod_knit.FTAnnotatedToFullText), - ] + ("knit-delta-gz", "fulltext", _mod_knit.DeltaPlainToFullText), + ("knit-delta-gz", "lines", _mod_knit.DeltaPlainToFullText), + ("knit-delta-gz", "chunked", _mod_knit.DeltaPlainToFullText), + ("knit-ft-gz", "fulltext", _mod_knit.FTPlainToFullText), + ("knit-ft-gz", "lines", _mod_knit.FTPlainToFullText), + ("knit-ft-gz", "chunked", _mod_knit.FTPlainToFullText), + ( + "knit-annotated-delta-gz", + "knit-delta-gz", + _mod_knit.DeltaAnnotatedToUnannotated, + ), + ("knit-annotated-delta-gz", "fulltext", _mod_knit.DeltaAnnotatedToFullText), + ("knit-annotated-ft-gz", "knit-ft-gz", _mod_knit.FTAnnotatedToUnannotated), + ("knit-annotated-ft-gz", "fulltext", _mod_knit.FTAnnotatedToFullText), + ("knit-annotated-ft-gz", "lines", _mod_knit.FTAnnotatedToFullText), + ("knit-annotated-ft-gz", "chunked", _mod_knit.FTAnnotatedToFullText), + ] for source, requested, klass in scenarios: - adapter_factory = versionedfile.adapter_registry.get( - (source, requested)) + adapter_factory = versionedfile.adapter_registry.get((source, requested)) adapter = adapter_factory(None) self.assertIsInstance(adapter, klass) def get_knit(self, annotated=True): - mapper = ConstantMapper('knit') + mapper = ConstantMapper("knit") transport = self.get_transport() return make_file_factory(annotated, mapper)(transport) def helpGetBytes(self, f, ft_name, ft_adapter, delta_name, delta_adapter): """Grab the interested adapted texts for tests.""" # origin is a fulltext - 
entries = f.get_record_stream([(b'origin',)], 'unordered', False) + entries = f.get_record_stream([(b"origin",)], "unordered", False) base = next(entries) ft_data = ft_adapter.get_bytes(base, ft_name) # merged is both a delta and multiple parents. - entries = f.get_record_stream([(b'merged',)], 'unordered', False) + entries = f.get_record_stream([(b"merged",)], "unordered", False) merged = next(entries) delta_data = delta_adapter.get_bytes(merged, delta_name) return ft_data, delta_data @@ -1224,17 +1331,23 @@ def test_deannotation_noeol(self): f = self.get_knit() get_diamond_files(f, 1, trailing_eol=False) ft_data, delta_data = self.helpGetBytes( - f, 'knit-ft-gz', _mod_knit.FTAnnotatedToUnannotated(None), - 'knit-delta-gz', _mod_knit.DeltaAnnotatedToUnannotated(None)) + f, + "knit-ft-gz", + _mod_knit.FTAnnotatedToUnannotated(None), + "knit-delta-gz", + _mod_knit.DeltaAnnotatedToUnannotated(None), + ) self.assertEqual( - b'version origin 1 b284f94827db1fa2970d9e2014f080413b547a7e\n' - b'origin\n' - b'end origin\n', - GzipFile(mode='rb', fileobj=BytesIO(ft_data)).read()) + b"version origin 1 b284f94827db1fa2970d9e2014f080413b547a7e\n" + b"origin\n" + b"end origin\n", + GzipFile(mode="rb", fileobj=BytesIO(ft_data)).read(), + ) self.assertEqual( - b'version merged 4 32c2e79763b3f90e8ccde37f9710b6629c25a796\n' - b'1,2,3\nleft\nright\nmerged\nend merged\n', - GzipFile(mode='rb', fileobj=BytesIO(delta_data)).read()) + b"version merged 4 32c2e79763b3f90e8ccde37f9710b6629c25a796\n" + b"1,2,3\nleft\nright\nmerged\nend merged\n", + GzipFile(mode="rb", fileobj=BytesIO(delta_data)).read(), + ) def test_deannotation(self): """Test converting annotated knits to unannotated knits.""" @@ -1242,17 +1355,23 @@ def test_deannotation(self): f = self.get_knit() get_diamond_files(f, 1) ft_data, delta_data = self.helpGetBytes( - f, 'knit-ft-gz', _mod_knit.FTAnnotatedToUnannotated(None), - 'knit-delta-gz', _mod_knit.DeltaAnnotatedToUnannotated(None)) + f, + "knit-ft-gz", + _mod_knit.FTAnnotatedToUnannotated(None), + "knit-delta-gz", + _mod_knit.DeltaAnnotatedToUnannotated(None), + ) self.assertEqual( - b'version origin 1 00e364d235126be43292ab09cb4686cf703ddc17\n' - b'origin\n' - b'end origin\n', - GzipFile(mode='rb', fileobj=BytesIO(ft_data)).read()) + b"version origin 1 00e364d235126be43292ab09cb4686cf703ddc17\n" + b"origin\n" + b"end origin\n", + GzipFile(mode="rb", fileobj=BytesIO(ft_data)).read(), + ) self.assertEqual( - b'version merged 3 ed8bce375198ea62444dc71952b22cfc2b09226d\n' - b'2,2,2\nright\nmerged\nend merged\n', - GzipFile(mode='rb', fileobj=BytesIO(delta_data)).read()) + b"version merged 3 ed8bce375198ea62444dc71952b22cfc2b09226d\n" + b"2,2,2\nright\nmerged\nend merged\n", + GzipFile(mode="rb", fileobj=BytesIO(delta_data)).read(), + ) def test_annotated_to_fulltext_no_eol(self): """Test adapting annotated knits to full texts (for -> weaves).""" @@ -1263,12 +1382,17 @@ def test_annotated_to_fulltext_no_eol(self): # must have the base lines requested from it. 
logged_vf = versionedfile.RecordingVersionedFilesDecorator(f) ft_data, delta_data = self.helpGetBytes( - f, 'fulltext', _mod_knit.FTAnnotatedToFullText(None), - 'fulltext', _mod_knit.DeltaAnnotatedToFullText(logged_vf)) - self.assertEqual(b'origin', ft_data) - self.assertEqual(b'base\nleft\nright\nmerged', delta_data) - self.assertEqual([('get_record_stream', [(b'left',)], 'unordered', - True)], logged_vf.calls) + f, + "fulltext", + _mod_knit.FTAnnotatedToFullText(None), + "fulltext", + _mod_knit.DeltaAnnotatedToFullText(logged_vf), + ) + self.assertEqual(b"origin", ft_data) + self.assertEqual(b"base\nleft\nright\nmerged", delta_data) + self.assertEqual( + [("get_record_stream", [(b"left",)], "unordered", True)], logged_vf.calls + ) def test_annotated_to_fulltext(self): """Test adapting annotated knits to full texts (for -> weaves).""" @@ -1279,12 +1403,17 @@ def test_annotated_to_fulltext(self): # must have the base lines requested from it. logged_vf = versionedfile.RecordingVersionedFilesDecorator(f) ft_data, delta_data = self.helpGetBytes( - f, 'fulltext', _mod_knit.FTAnnotatedToFullText(None), - 'fulltext', _mod_knit.DeltaAnnotatedToFullText(logged_vf)) - self.assertEqual(b'origin\n', ft_data) - self.assertEqual(b'base\nleft\nright\nmerged\n', delta_data) - self.assertEqual([('get_record_stream', [(b'left',)], 'unordered', - True)], logged_vf.calls) + f, + "fulltext", + _mod_knit.FTAnnotatedToFullText(None), + "fulltext", + _mod_knit.DeltaAnnotatedToFullText(logged_vf), + ) + self.assertEqual(b"origin\n", ft_data) + self.assertEqual(b"base\nleft\nright\nmerged\n", delta_data) + self.assertEqual( + [("get_record_stream", [(b"left",)], "unordered", True)], logged_vf.calls + ) def test_unannotated_to_fulltext(self): """Test adapting unannotated knits to full texts. @@ -1298,12 +1427,17 @@ def test_unannotated_to_fulltext(self): # must have the base lines requested from it. logged_vf = versionedfile.RecordingVersionedFilesDecorator(f) ft_data, delta_data = self.helpGetBytes( - f, 'fulltext', _mod_knit.FTPlainToFullText(None), - 'fulltext', _mod_knit.DeltaPlainToFullText(logged_vf)) - self.assertEqual(b'origin\n', ft_data) - self.assertEqual(b'base\nleft\nright\nmerged\n', delta_data) - self.assertEqual([('get_record_stream', [(b'left',)], 'unordered', - True)], logged_vf.calls) + f, + "fulltext", + _mod_knit.FTPlainToFullText(None), + "fulltext", + _mod_knit.DeltaPlainToFullText(logged_vf), + ) + self.assertEqual(b"origin\n", ft_data) + self.assertEqual(b"base\nleft\nright\nmerged\n", delta_data) + self.assertEqual( + [("get_record_stream", [(b"left",)], "unordered", True)], logged_vf.calls + ) def test_unannotated_to_fulltext_no_eol(self): """Test adapting unannotated knits to full texts. @@ -1317,12 +1451,17 @@ def test_unannotated_to_fulltext_no_eol(self): # must have the base lines requested from it. 
logged_vf = versionedfile.RecordingVersionedFilesDecorator(f) ft_data, delta_data = self.helpGetBytes( - f, 'fulltext', _mod_knit.FTPlainToFullText(None), - 'fulltext', _mod_knit.DeltaPlainToFullText(logged_vf)) - self.assertEqual(b'origin', ft_data) - self.assertEqual(b'base\nleft\nright\nmerged', delta_data) - self.assertEqual([('get_record_stream', [(b'left',)], 'unordered', - True)], logged_vf.calls) + f, + "fulltext", + _mod_knit.FTPlainToFullText(None), + "fulltext", + _mod_knit.DeltaPlainToFullText(logged_vf), + ) + self.assertEqual(b"origin", ft_data) + self.assertEqual(b"base\nleft\nright\nmerged", delta_data) + self.assertEqual( + [("get_record_stream", [(b"left",)], "unordered", True)], logged_vf.calls + ) class TestKeyMapper(TestCaseWithMemoryTransport): @@ -1330,36 +1469,33 @@ class TestKeyMapper(TestCaseWithMemoryTransport): def test_identity_mapper(self): mapper = versionedfile.ConstantMapper("inventory") - self.assertEqual("inventory", mapper.map((b'foo@ar',))) - self.assertEqual("inventory", mapper.map((b'quux',))) + self.assertEqual("inventory", mapper.map((b"foo@ar",))) + self.assertEqual("inventory", mapper.map((b"quux",))) def test_prefix_mapper(self): - #format5: plain + # format5: plain mapper = versionedfile.PrefixMapper() self.assertEqual("file-id", mapper.map((b"file-id", b"revision-id"))) self.assertEqual("new-id", mapper.map((b"new-id", b"revision-id"))) - self.assertEqual((b'file-id',), mapper.unmap("file-id")) - self.assertEqual((b'new-id',), mapper.unmap("new-id")) + self.assertEqual((b"file-id",), mapper.unmap("file-id")) + self.assertEqual((b"new-id",), mapper.unmap("new-id")) def test_hash_prefix_mapper(self): - #format6: hash + plain + # format6: hash + plain mapper = versionedfile.HashPrefixMapper() - self.assertEqual( - "9b/file-id", mapper.map((b"file-id", b"revision-id"))) + self.assertEqual("9b/file-id", mapper.map((b"file-id", b"revision-id"))) self.assertEqual("45/new-id", mapper.map((b"new-id", b"revision-id"))) - self.assertEqual((b'file-id',), mapper.unmap("9b/file-id")) - self.assertEqual((b'new-id',), mapper.unmap("45/new-id")) + self.assertEqual((b"file-id",), mapper.unmap("9b/file-id")) + self.assertEqual((b"new-id",), mapper.unmap("45/new-id")) def test_hash_escaped_mapper(self): - #knit1: hash + escaped + # knit1: hash + escaped mapper = versionedfile.HashEscapedPrefixMapper() self.assertEqual("88/%2520", mapper.map((b" ", b"revision-id"))) - self.assertEqual("ed/fil%2545-%2549d", mapper.map((b"filE-Id", - b"revision-id"))) - self.assertEqual("88/ne%2557-%2549d", mapper.map((b"neW-Id", - b"revision-id"))) - self.assertEqual((b'filE-Id',), mapper.unmap("ed/fil%2545-%2549d")) - self.assertEqual((b'neW-Id',), mapper.unmap("88/ne%2557-%2549d")) + self.assertEqual("ed/fil%2545-%2549d", mapper.map((b"filE-Id", b"revision-id"))) + self.assertEqual("88/ne%2557-%2549d", mapper.map((b"neW-Id", b"revision-id"))) + self.assertEqual((b"filE-Id",), mapper.unmap("ed/fil%2545-%2549d")) + self.assertEqual((b"neW-Id",), mapper.unmap("88/ne%2557-%2549d")) class TestVersionedFiles(TestCaseWithMemoryTransport): @@ -1376,88 +1512,118 @@ class TestVersionedFiles(TestCaseWithMemoryTransport): # individual graph nocompression knits in packs (revisions) # plain text knits in packs (texts) len_one_scenarios = [ - ('weave-named', { - 'cleanup': None, - 'factory': make_versioned_files_factory(WeaveFile, - ConstantMapper('inventory')), - 'graph': True, - 'key_length': 1, - 'support_partial_insertion': False, - }), - ('named-knit', { - 'cleanup': None, - 'factory': 
make_file_factory(False, ConstantMapper('revisions')), - 'graph': True, - 'key_length': 1, - 'support_partial_insertion': False, - }), - ('named-nograph-nodelta-knit-pack', { - 'cleanup': cleanup_pack_knit, - 'factory': make_pack_factory(False, False, 1), - 'graph': False, - 'key_length': 1, - 'support_partial_insertion': False, - }), - ('named-graph-knit-pack', { - 'cleanup': cleanup_pack_knit, - 'factory': make_pack_factory(True, True, 1), - 'graph': True, - 'key_length': 1, - 'support_partial_insertion': True, - }), - ('named-graph-nodelta-knit-pack', { - 'cleanup': cleanup_pack_knit, - 'factory': make_pack_factory(True, False, 1), - 'graph': True, - 'key_length': 1, - 'support_partial_insertion': False, - }), - ('groupcompress-nograph', { - 'cleanup': groupcompress.cleanup_pack_group, - 'factory': groupcompress.make_pack_factory(False, False, 1), - 'graph': False, - 'key_length': 1, - 'support_partial_insertion': False, - }), - ] + ( + "weave-named", + { + "cleanup": None, + "factory": make_versioned_files_factory( + WeaveFile, ConstantMapper("inventory") + ), + "graph": True, + "key_length": 1, + "support_partial_insertion": False, + }, + ), + ( + "named-knit", + { + "cleanup": None, + "factory": make_file_factory(False, ConstantMapper("revisions")), + "graph": True, + "key_length": 1, + "support_partial_insertion": False, + }, + ), + ( + "named-nograph-nodelta-knit-pack", + { + "cleanup": cleanup_pack_knit, + "factory": make_pack_factory(False, False, 1), + "graph": False, + "key_length": 1, + "support_partial_insertion": False, + }, + ), + ( + "named-graph-knit-pack", + { + "cleanup": cleanup_pack_knit, + "factory": make_pack_factory(True, True, 1), + "graph": True, + "key_length": 1, + "support_partial_insertion": True, + }, + ), + ( + "named-graph-nodelta-knit-pack", + { + "cleanup": cleanup_pack_knit, + "factory": make_pack_factory(True, False, 1), + "graph": True, + "key_length": 1, + "support_partial_insertion": False, + }, + ), + ( + "groupcompress-nograph", + { + "cleanup": groupcompress.cleanup_pack_group, + "factory": groupcompress.make_pack_factory(False, False, 1), + "graph": False, + "key_length": 1, + "support_partial_insertion": False, + }, + ), + ] len_two_scenarios = [ - ('weave-prefix', { - 'cleanup': None, - 'factory': make_versioned_files_factory(WeaveFile, - PrefixMapper()), - 'graph': True, - 'key_length': 2, - 'support_partial_insertion': False, - }), - ('annotated-knit-escape', { - 'cleanup': None, - 'factory': make_file_factory(True, HashEscapedPrefixMapper()), - 'graph': True, - 'key_length': 2, - 'support_partial_insertion': False, - }), - ('plain-knit-pack', { - 'cleanup': cleanup_pack_knit, - 'factory': make_pack_factory(True, True, 2), - 'graph': True, - 'key_length': 2, - 'support_partial_insertion': True, - }), - ('groupcompress', { - 'cleanup': groupcompress.cleanup_pack_group, - 'factory': groupcompress.make_pack_factory(True, False, 1), - 'graph': True, - 'key_length': 1, - 'support_partial_insertion': False, - }), - ] + ( + "weave-prefix", + { + "cleanup": None, + "factory": make_versioned_files_factory(WeaveFile, PrefixMapper()), + "graph": True, + "key_length": 2, + "support_partial_insertion": False, + }, + ), + ( + "annotated-knit-escape", + { + "cleanup": None, + "factory": make_file_factory(True, HashEscapedPrefixMapper()), + "graph": True, + "key_length": 2, + "support_partial_insertion": False, + }, + ), + ( + "plain-knit-pack", + { + "cleanup": cleanup_pack_knit, + "factory": make_pack_factory(True, True, 2), + "graph": True, + 
"key_length": 2, + "support_partial_insertion": True, + }, + ), + ( + "groupcompress", + { + "cleanup": groupcompress.cleanup_pack_group, + "factory": groupcompress.make_pack_factory(True, False, 1), + "graph": True, + "key_length": 1, + "support_partial_insertion": False, + }, + ), + ] scenarios = len_one_scenarios + len_two_scenarios - def get_versionedfiles(self, relpath='files'): + def get_versionedfiles(self, relpath="files"): transport = self.get_transport(relpath) - if relpath != '.': - transport.mkdir('.') + if relpath != ".": + transport.mkdir(".") files = self.factory(transport) if self.cleanup is not None: self.addCleanup(self.cleanup, files) @@ -1468,64 +1634,65 @@ def get_simple_key(self, suffix): if self.key_length == 1: return (suffix,) else: - return (b'FileA',) + (suffix,) + return (b"FileA",) + (suffix,) def test_add_fallback_implies_without_fallbacks(self): - f = self.get_versionedfiles('files') - if getattr(f, 'add_fallback_versioned_files', None) is None: + f = self.get_versionedfiles("files") + if getattr(f, "add_fallback_versioned_files", None) is None: raise TestNotApplicable(f"{f.__class__.__name__} doesn't support fallbacks") - g = self.get_versionedfiles('fallback') - key_a = self.get_simple_key(b'a') - g.add_lines(key_a, [], [b'\n']) + g = self.get_versionedfiles("fallback") + key_a = self.get_simple_key(b"a") + g.add_lines(key_a, [], [b"\n"]) f.add_fallback_versioned_files(g) self.assertTrue(key_a in f.get_parent_map([key_a])) - self.assertFalse( - key_a in f.without_fallbacks().get_parent_map([key_a])) + self.assertFalse(key_a in f.without_fallbacks().get_parent_map([key_a])) def test_add_lines(self): f = self.get_versionedfiles() - key0 = self.get_simple_key(b'r0') - key1 = self.get_simple_key(b'r1') - self.get_simple_key(b'r2') - self.get_simple_key(b'foo') - f.add_lines(key0, [], [b'a\n', b'b\n']) + key0 = self.get_simple_key(b"r0") + key1 = self.get_simple_key(b"r1") + self.get_simple_key(b"r2") + self.get_simple_key(b"foo") + f.add_lines(key0, [], [b"a\n", b"b\n"]) if self.graph: - f.add_lines(key1, [key0], [b'b\n', b'c\n']) + f.add_lines(key1, [key0], [b"b\n", b"c\n"]) else: - f.add_lines(key1, [], [b'b\n', b'c\n']) + f.add_lines(key1, [], [b"b\n", b"c\n"]) keys = f.keys() self.assertTrue(key0 in keys) self.assertTrue(key1 in keys) records = [] - for record in f.get_record_stream([key0, key1], 'unordered', True): - records.append((record.key, record.get_bytes_as('fulltext'))) + for record in f.get_record_stream([key0, key1], "unordered", True): + records.append((record.key, record.get_bytes_as("fulltext"))) records.sort() - self.assertEqual([(key0, b'a\nb\n'), (key1, b'b\nc\n')], records) + self.assertEqual([(key0, b"a\nb\n"), (key1, b"b\nc\n")], records) def test_add_chunks(self): f = self.get_versionedfiles() - key0 = self.get_simple_key(b'r0') - key1 = self.get_simple_key(b'r1') - self.get_simple_key(b'r2') - self.get_simple_key(b'foo') + key0 = self.get_simple_key(b"r0") + key1 = self.get_simple_key(b"r1") + self.get_simple_key(b"r2") + self.get_simple_key(b"foo") + def add_chunks(key, parents, chunks): factory = ChunkedContentFactory( - key, parents, osutils.sha_strings(chunks), chunks) + key, parents, osutils.sha_strings(chunks), chunks + ) return f.add_content(factory) - add_chunks(key0, [], [b'a', b'\nb\n']) + add_chunks(key0, [], [b"a", b"\nb\n"]) if self.graph: - add_chunks(key1, [key0], [b'b', b'\n', b'c\n']) + add_chunks(key1, [key0], [b"b", b"\n", b"c\n"]) else: - add_chunks(key1, [], [b'b\n', b'c\n']) + add_chunks(key1, [], [b"b\n", 
b"c\n"]) keys = f.keys() self.assertIn(key0, keys) self.assertIn(key1, keys) records = [] - for record in f.get_record_stream([key0, key1], 'unordered', True): - records.append((record.key, record.get_bytes_as('fulltext'))) + for record in f.get_record_stream([key0, key1], "unordered", True): + records.append((record.key, record.get_bytes_as("fulltext"))) records.sort() - self.assertEqual([(key0, b'a\nb\n'), (key1, b'b\nc\n')], records) + self.assertEqual([(key0, b"a\nb\n"), (key1, b"b\nc\n")], records) def test_annotate(self): files = self.get_versionedfiles() @@ -1533,38 +1700,37 @@ def test_annotate(self): if self.key_length == 1: prefix = () else: - prefix = (b'FileA',) + prefix = (b"FileA",) # introduced full text - origins = files.annotate(prefix + (b'origin',)) - self.assertEqual([ - (prefix + (b'origin',), b'origin\n')], - origins) + origins = files.annotate(prefix + (b"origin",)) + self.assertEqual([(prefix + (b"origin",), b"origin\n")], origins) # a delta - origins = files.annotate(prefix + (b'base',)) - self.assertEqual([ - (prefix + (b'base',), b'base\n')], - origins) + origins = files.annotate(prefix + (b"base",)) + self.assertEqual([(prefix + (b"base",), b"base\n")], origins) # a merge - origins = files.annotate(prefix + (b'merged',)) + origins = files.annotate(prefix + (b"merged",)) if self.graph: - self.assertEqual([ - (prefix + (b'base',), b'base\n'), - (prefix + (b'left',), b'left\n'), - (prefix + (b'right',), b'right\n'), - (prefix + (b'merged',), b'merged\n') + self.assertEqual( + [ + (prefix + (b"base",), b"base\n"), + (prefix + (b"left",), b"left\n"), + (prefix + (b"right",), b"right\n"), + (prefix + (b"merged",), b"merged\n"), ], - origins) + origins, + ) else: # Without a graph everything is new. - self.assertEqual([ - (prefix + (b'merged',), b'base\n'), - (prefix + (b'merged',), b'left\n'), - (prefix + (b'merged',), b'right\n'), - (prefix + (b'merged',), b'merged\n') + self.assertEqual( + [ + (prefix + (b"merged",), b"base\n"), + (prefix + (b"merged",), b"left\n"), + (prefix + (b"merged",), b"right\n"), + (prefix + (b"merged",), b"merged\n"), ], - origins) - self.assertRaises(RevisionNotPresent, - files.annotate, prefix + ('missing-key',)) + origins, + ) + self.assertRaises(RevisionNotPresent, files.annotate, prefix + ("missing-key",)) def test_check_no_parameters(self): self.get_versionedfiles() @@ -1583,8 +1749,9 @@ def test_check_with_keys_becomes_generator(self): entries = files.check(keys=keys) seen = set() # Texts output should be fulltexts. - self.capture_stream(files, entries, seen.add, - files.get_parent_map(keys), require_fulltext=True) + self.capture_stream( + files, entries, seen.add, files.get_parent_map(keys), require_fulltext=True + ) # All texts should be output. 
self.assertEqual(set(keys), seen) @@ -1596,40 +1763,45 @@ def test_construct(self): """Each parameterised test can be constructed on a transport.""" self.get_versionedfiles() - def get_diamond_files(self, files, trailing_eol=True, left_only=False, - nokeys=False): - return get_diamond_files(files, self.key_length, - trailing_eol=trailing_eol, nograph=not self.graph, - left_only=left_only, nokeys=nokeys) + def get_diamond_files( + self, files, trailing_eol=True, left_only=False, nokeys=False + ): + return get_diamond_files( + files, + self.key_length, + trailing_eol=trailing_eol, + nograph=not self.graph, + left_only=left_only, + nokeys=nokeys, + ) def _add_content_nostoresha(self, add_lines): """When nostore_sha is supplied using old content raises.""" vf = self.get_versionedfiles() - empty_text = (b'a', []) - sample_text_nl = (b'b', [b"foo\n", b"bar\n"]) - sample_text_no_nl = (b'c', [b"foo\n", b"bar"]) + empty_text = (b"a", []) + sample_text_nl = (b"b", [b"foo\n", b"bar\n"]) + sample_text_no_nl = (b"c", [b"foo\n", b"bar"]) shas = [] for version, lines in (empty_text, sample_text_nl, sample_text_no_nl): if add_lines: - sha, _, _ = vf.add_lines(self.get_simple_key(version), [], - lines) + sha, _, _ = vf.add_lines(self.get_simple_key(version), [], lines) else: - sha, _, _ = vf.add_lines(self.get_simple_key(version), [], - lines) + sha, _, _ = vf.add_lines(self.get_simple_key(version), [], lines) shas.append(sha) # we now have a copy of all the lines in the vf. for sha, (version, lines) in zip( - shas, (empty_text, sample_text_nl, sample_text_no_nl)): + shas, (empty_text, sample_text_nl, sample_text_no_nl) + ): new_key = self.get_simple_key(version + b"2") - self.assertRaises(ExistingContent, - vf.add_lines, new_key, [], lines, - nostore_sha=sha) - self.assertRaises(ExistingContent, - vf.add_lines, new_key, [], lines, - nostore_sha=sha) + self.assertRaises( + ExistingContent, vf.add_lines, new_key, [], lines, nostore_sha=sha + ) + self.assertRaises( + ExistingContent, vf.add_lines, new_key, [], lines, nostore_sha=sha + ) # and no new version should have been added. 
- record = next(vf.get_record_stream([new_key], 'unordered', True)) - self.assertEqual('absent', record.storage_kind) + record = next(vf.get_record_stream([new_key], "unordered", True)) + self.assertEqual("absent", record.storage_kind) def test_add_lines_nostoresha(self): self._add_content_nostoresha(add_lines=True) @@ -1644,26 +1816,32 @@ def test_add_lines_return(self): self.assertEqual(3, len(add)) results.append(add[:2]) if self.key_length == 1: - self.assertEqual([ - (b'00e364d235126be43292ab09cb4686cf703ddc17', 7), - (b'51c64a6f4fc375daf0d24aafbabe4d91b6f4bb44', 5), - (b'a8478686da38e370e32e42e8a0c220e33ee9132f', 10), - (b'9ef09dfa9d86780bdec9219a22560c6ece8e0ef1', 11), - (b'ed8bce375198ea62444dc71952b22cfc2b09226d', 23)], - results) + self.assertEqual( + [ + (b"00e364d235126be43292ab09cb4686cf703ddc17", 7), + (b"51c64a6f4fc375daf0d24aafbabe4d91b6f4bb44", 5), + (b"a8478686da38e370e32e42e8a0c220e33ee9132f", 10), + (b"9ef09dfa9d86780bdec9219a22560c6ece8e0ef1", 11), + (b"ed8bce375198ea62444dc71952b22cfc2b09226d", 23), + ], + results, + ) elif self.key_length == 2: - self.assertEqual([ - (b'00e364d235126be43292ab09cb4686cf703ddc17', 7), - (b'00e364d235126be43292ab09cb4686cf703ddc17', 7), - (b'51c64a6f4fc375daf0d24aafbabe4d91b6f4bb44', 5), - (b'51c64a6f4fc375daf0d24aafbabe4d91b6f4bb44', 5), - (b'a8478686da38e370e32e42e8a0c220e33ee9132f', 10), - (b'a8478686da38e370e32e42e8a0c220e33ee9132f', 10), - (b'9ef09dfa9d86780bdec9219a22560c6ece8e0ef1', 11), - (b'9ef09dfa9d86780bdec9219a22560c6ece8e0ef1', 11), - (b'ed8bce375198ea62444dc71952b22cfc2b09226d', 23), - (b'ed8bce375198ea62444dc71952b22cfc2b09226d', 23)], - results) + self.assertEqual( + [ + (b"00e364d235126be43292ab09cb4686cf703ddc17", 7), + (b"00e364d235126be43292ab09cb4686cf703ddc17", 7), + (b"51c64a6f4fc375daf0d24aafbabe4d91b6f4bb44", 5), + (b"51c64a6f4fc375daf0d24aafbabe4d91b6f4bb44", 5), + (b"a8478686da38e370e32e42e8a0c220e33ee9132f", 10), + (b"a8478686da38e370e32e42e8a0c220e33ee9132f", 10), + (b"9ef09dfa9d86780bdec9219a22560c6ece8e0ef1", 11), + (b"9ef09dfa9d86780bdec9219a22560c6ece8e0ef1", 11), + (b"ed8bce375198ea62444dc71952b22cfc2b09226d", 23), + (b"ed8bce375198ea62444dc71952b22cfc2b09226d", 23), + ], + results, + ) def test_add_lines_no_key_generates_chk_key(self): files = self.get_versionedfiles() @@ -1675,148 +1853,180 @@ def test_add_lines_no_key_generates_chk_key(self): self.assertEqual(3, len(add)) results.append(add[:2]) if self.key_length == 1: - self.assertEqual([ - (b'00e364d235126be43292ab09cb4686cf703ddc17', 7), - (b'51c64a6f4fc375daf0d24aafbabe4d91b6f4bb44', 5), - (b'a8478686da38e370e32e42e8a0c220e33ee9132f', 10), - (b'9ef09dfa9d86780bdec9219a22560c6ece8e0ef1', 11), - (b'ed8bce375198ea62444dc71952b22cfc2b09226d', 23)], - results) + self.assertEqual( + [ + (b"00e364d235126be43292ab09cb4686cf703ddc17", 7), + (b"51c64a6f4fc375daf0d24aafbabe4d91b6f4bb44", 5), + (b"a8478686da38e370e32e42e8a0c220e33ee9132f", 10), + (b"9ef09dfa9d86780bdec9219a22560c6ece8e0ef1", 11), + (b"ed8bce375198ea62444dc71952b22cfc2b09226d", 23), + ], + results, + ) # Check the added items got CHK keys. 
- self.assertEqual({ - (b'sha1:00e364d235126be43292ab09cb4686cf703ddc17',), - (b'sha1:51c64a6f4fc375daf0d24aafbabe4d91b6f4bb44',), - (b'sha1:9ef09dfa9d86780bdec9219a22560c6ece8e0ef1',), - (b'sha1:a8478686da38e370e32e42e8a0c220e33ee9132f',), - (b'sha1:ed8bce375198ea62444dc71952b22cfc2b09226d',), + self.assertEqual( + { + (b"sha1:00e364d235126be43292ab09cb4686cf703ddc17",), + (b"sha1:51c64a6f4fc375daf0d24aafbabe4d91b6f4bb44",), + (b"sha1:9ef09dfa9d86780bdec9219a22560c6ece8e0ef1",), + (b"sha1:a8478686da38e370e32e42e8a0c220e33ee9132f",), + (b"sha1:ed8bce375198ea62444dc71952b22cfc2b09226d",), }, - files.keys()) + files.keys(), + ) elif self.key_length == 2: - self.assertEqual([ - (b'00e364d235126be43292ab09cb4686cf703ddc17', 7), - (b'00e364d235126be43292ab09cb4686cf703ddc17', 7), - (b'51c64a6f4fc375daf0d24aafbabe4d91b6f4bb44', 5), - (b'51c64a6f4fc375daf0d24aafbabe4d91b6f4bb44', 5), - (b'a8478686da38e370e32e42e8a0c220e33ee9132f', 10), - (b'a8478686da38e370e32e42e8a0c220e33ee9132f', 10), - (b'9ef09dfa9d86780bdec9219a22560c6ece8e0ef1', 11), - (b'9ef09dfa9d86780bdec9219a22560c6ece8e0ef1', 11), - (b'ed8bce375198ea62444dc71952b22cfc2b09226d', 23), - (b'ed8bce375198ea62444dc71952b22cfc2b09226d', 23)], - results) + self.assertEqual( + [ + (b"00e364d235126be43292ab09cb4686cf703ddc17", 7), + (b"00e364d235126be43292ab09cb4686cf703ddc17", 7), + (b"51c64a6f4fc375daf0d24aafbabe4d91b6f4bb44", 5), + (b"51c64a6f4fc375daf0d24aafbabe4d91b6f4bb44", 5), + (b"a8478686da38e370e32e42e8a0c220e33ee9132f", 10), + (b"a8478686da38e370e32e42e8a0c220e33ee9132f", 10), + (b"9ef09dfa9d86780bdec9219a22560c6ece8e0ef1", 11), + (b"9ef09dfa9d86780bdec9219a22560c6ece8e0ef1", 11), + (b"ed8bce375198ea62444dc71952b22cfc2b09226d", 23), + (b"ed8bce375198ea62444dc71952b22cfc2b09226d", 23), + ], + results, + ) # Check the added items got CHK keys. 
- self.assertEqual({ - (b'FileA', b'sha1:00e364d235126be43292ab09cb4686cf703ddc17'), - (b'FileA', b'sha1:51c64a6f4fc375daf0d24aafbabe4d91b6f4bb44'), - (b'FileA', b'sha1:9ef09dfa9d86780bdec9219a22560c6ece8e0ef1'), - (b'FileA', b'sha1:a8478686da38e370e32e42e8a0c220e33ee9132f'), - (b'FileA', b'sha1:ed8bce375198ea62444dc71952b22cfc2b09226d'), - (b'FileB', b'sha1:00e364d235126be43292ab09cb4686cf703ddc17'), - (b'FileB', b'sha1:51c64a6f4fc375daf0d24aafbabe4d91b6f4bb44'), - (b'FileB', b'sha1:9ef09dfa9d86780bdec9219a22560c6ece8e0ef1'), - (b'FileB', b'sha1:a8478686da38e370e32e42e8a0c220e33ee9132f'), - (b'FileB', b'sha1:ed8bce375198ea62444dc71952b22cfc2b09226d'), + self.assertEqual( + { + (b"FileA", b"sha1:00e364d235126be43292ab09cb4686cf703ddc17"), + (b"FileA", b"sha1:51c64a6f4fc375daf0d24aafbabe4d91b6f4bb44"), + (b"FileA", b"sha1:9ef09dfa9d86780bdec9219a22560c6ece8e0ef1"), + (b"FileA", b"sha1:a8478686da38e370e32e42e8a0c220e33ee9132f"), + (b"FileA", b"sha1:ed8bce375198ea62444dc71952b22cfc2b09226d"), + (b"FileB", b"sha1:00e364d235126be43292ab09cb4686cf703ddc17"), + (b"FileB", b"sha1:51c64a6f4fc375daf0d24aafbabe4d91b6f4bb44"), + (b"FileB", b"sha1:9ef09dfa9d86780bdec9219a22560c6ece8e0ef1"), + (b"FileB", b"sha1:a8478686da38e370e32e42e8a0c220e33ee9132f"), + (b"FileB", b"sha1:ed8bce375198ea62444dc71952b22cfc2b09226d"), }, - files.keys()) + files.keys(), + ) def test_empty_lines(self): """Empty files can be stored.""" f = self.get_versionedfiles() - key_a = self.get_simple_key(b'a') + key_a = self.get_simple_key(b"a") f.add_lines(key_a, [], []) - self.assertEqual(b'', - next(f.get_record_stream([key_a], 'unordered', True - )).get_bytes_as('fulltext')) - key_b = self.get_simple_key(b'b') + self.assertEqual( + b"", + next(f.get_record_stream([key_a], "unordered", True)).get_bytes_as( + "fulltext" + ), + ) + key_b = self.get_simple_key(b"b") f.add_lines(key_b, self.get_parents([key_a]), []) - self.assertEqual(b'', - next(f.get_record_stream([key_b], 'unordered', True - )).get_bytes_as('fulltext')) + self.assertEqual( + b"", + next(f.get_record_stream([key_b], "unordered", True)).get_bytes_as( + "fulltext" + ), + ) def test_newline_only(self): f = self.get_versionedfiles() - key_a = self.get_simple_key(b'a') - f.add_lines(key_a, [], [b'\n']) - self.assertEqual(b'\n', - next(f.get_record_stream([key_a], 'unordered', True - )).get_bytes_as('fulltext')) - key_b = self.get_simple_key(b'b') - f.add_lines(key_b, self.get_parents([key_a]), [b'\n']) - self.assertEqual(b'\n', - next(f.get_record_stream([key_b], 'unordered', True - )).get_bytes_as('fulltext')) + key_a = self.get_simple_key(b"a") + f.add_lines(key_a, [], [b"\n"]) + self.assertEqual( + b"\n", + next(f.get_record_stream([key_a], "unordered", True)).get_bytes_as( + "fulltext" + ), + ) + key_b = self.get_simple_key(b"b") + f.add_lines(key_b, self.get_parents([key_a]), [b"\n"]) + self.assertEqual( + b"\n", + next(f.get_record_stream([key_b], "unordered", True)).get_bytes_as( + "fulltext" + ), + ) def test_get_known_graph_ancestry(self): f = self.get_versionedfiles() if not self.graph: - raise TestNotApplicable('ancestry info only relevant with graph.') - key_a = self.get_simple_key(b'a') - key_b = self.get_simple_key(b'b') - key_c = self.get_simple_key(b'c') + raise TestNotApplicable("ancestry info only relevant with graph.") + key_a = self.get_simple_key(b"a") + key_b = self.get_simple_key(b"b") + key_c = self.get_simple_key(b"c") # A # |\ # | B # |/ # C - f.add_lines(key_a, [], [b'\n']) - f.add_lines(key_b, [key_a], [b'\n']) - f.add_lines(key_c, [key_a, 
key_b], [b'\n']) + f.add_lines(key_a, [], [b"\n"]) + f.add_lines(key_b, [key_a], [b"\n"]) + f.add_lines(key_c, [key_a, key_b], [b"\n"]) kg = f.get_known_graph_ancestry([key_c]) self.assertIsInstance(kg, _mod_graph.KnownGraph) self.assertEqual([key_a, key_b, key_c], list(kg.topo_sort())) def test_known_graph_with_fallbacks(self): - f = self.get_versionedfiles('files') + f = self.get_versionedfiles("files") if not self.graph: - raise TestNotApplicable('ancestry info only relevant with graph.') - if getattr(f, 'add_fallback_versioned_files', None) is None: + raise TestNotApplicable("ancestry info only relevant with graph.") + if getattr(f, "add_fallback_versioned_files", None) is None: raise TestNotApplicable(f"{f.__class__.__name__} doesn't support fallbacks") - key_a = self.get_simple_key(b'a') - key_b = self.get_simple_key(b'b') - key_c = self.get_simple_key(b'c') + key_a = self.get_simple_key(b"a") + key_b = self.get_simple_key(b"b") + key_c = self.get_simple_key(b"c") # A only in fallback # |\ # | B # |/ # C - g = self.get_versionedfiles('fallback') - g.add_lines(key_a, [], [b'\n']) + g = self.get_versionedfiles("fallback") + g.add_lines(key_a, [], [b"\n"]) f.add_fallback_versioned_files(g) - f.add_lines(key_b, [key_a], [b'\n']) - f.add_lines(key_c, [key_a, key_b], [b'\n']) + f.add_lines(key_b, [key_a], [b"\n"]) + f.add_lines(key_c, [key_a, key_b], [b"\n"]) kg = f.get_known_graph_ancestry([key_c]) self.assertEqual([key_a, key_b, key_c], list(kg.topo_sort())) def test_get_record_stream_empty(self): """An empty stream can be requested without error.""" f = self.get_versionedfiles() - entries = f.get_record_stream([], 'unordered', False) + entries = f.get_record_stream([], "unordered", False) self.assertEqual([], list(entries)) def assertValidStorageKind(self, storage_kind): """Assert that storage_kind is a valid storage_kind.""" - self.assertSubset([storage_kind], - ['mpdiff', 'knit-annotated-ft', 'knit-annotated-delta', - 'knit-ft', 'knit-delta', 'chunked', 'fulltext', - 'knit-annotated-ft-gz', 'knit-annotated-delta-gz', 'knit-ft-gz', - 'knit-delta-gz', - 'knit-delta-closure', 'knit-delta-closure-ref', - 'groupcompress-block', 'groupcompress-block-ref']) - - def capture_stream(self, f, entries, on_seen, parents, - require_fulltext=False): + self.assertSubset( + [storage_kind], + [ + "mpdiff", + "knit-annotated-ft", + "knit-annotated-delta", + "knit-ft", + "knit-delta", + "chunked", + "fulltext", + "knit-annotated-ft-gz", + "knit-annotated-delta-gz", + "knit-ft-gz", + "knit-delta-gz", + "knit-delta-closure", + "knit-delta-closure-ref", + "groupcompress-block", + "groupcompress-block-ref", + ], + ) + + def capture_stream(self, f, entries, on_seen, parents, require_fulltext=False): """Capture a stream for testing.""" for factory in entries: on_seen(factory.key) self.assertValidStorageKind(factory.storage_kind) if factory.sha1 is not None: - self.assertEqual(f.get_sha1s([factory.key])[factory.key], - factory.sha1) + self.assertEqual(f.get_sha1s([factory.key])[factory.key], factory.sha1) self.assertEqual(parents[factory.key], factory.parents) - self.assertIsInstance(factory.get_bytes_as(factory.storage_kind), - bytes) + self.assertIsInstance(factory.get_bytes_as(factory.storage_kind), bytes) if require_fulltext: - factory.get_bytes_as('fulltext') + factory.get_bytes_as("fulltext") def test_get_record_stream_interface(self): """Each item in a stream has to provide a regular interface.""" @@ -1824,7 +2034,7 @@ def test_get_record_stream_interface(self): self.get_diamond_files(files) keys, _ = 
self.get_keys_and_sort_order() parent_map = files.get_parent_map(keys) - entries = files.get_record_stream(keys, 'unordered', False) + entries = files.get_record_stream(keys, "unordered", False) seen = set() self.capture_stream(files, entries, seen.add, parent_map) self.assertEqual(set(keys), seen) @@ -1832,43 +2042,57 @@ def test_get_record_stream_interface(self): def get_keys_and_sort_order(self): """Get diamond test keys list, and their sort ordering.""" if self.key_length == 1: - keys = [(b'merged',), (b'left',), (b'right',), (b'base',)] - sort_order = {(b'merged',): 2, (b'left',): 1, - (b'right',): 1, (b'base',): 0} + keys = [(b"merged",), (b"left",), (b"right",), (b"base",)] + sort_order = {(b"merged",): 2, (b"left",): 1, (b"right",): 1, (b"base",): 0} else: keys = [ - (b'FileA', b'merged'), (b'FileA', b'left'), (b'FileA', b'right'), - (b'FileA', b'base'), - (b'FileB', b'merged'), (b'FileB', b'left'), (b'FileB', b'right'), - (b'FileB', b'base'), - ] + (b"FileA", b"merged"), + (b"FileA", b"left"), + (b"FileA", b"right"), + (b"FileA", b"base"), + (b"FileB", b"merged"), + (b"FileB", b"left"), + (b"FileB", b"right"), + (b"FileB", b"base"), + ] sort_order = { - (b'FileA', b'merged'): 2, (b'FileA', b'left'): 1, (b'FileA', b'right'): 1, - (b'FileA', b'base'): 0, - (b'FileB', b'merged'): 2, (b'FileB', b'left'): 1, (b'FileB', b'right'): 1, - (b'FileB', b'base'): 0, - } + (b"FileA", b"merged"): 2, + (b"FileA", b"left"): 1, + (b"FileA", b"right"): 1, + (b"FileA", b"base"): 0, + (b"FileB", b"merged"): 2, + (b"FileB", b"left"): 1, + (b"FileB", b"right"): 1, + (b"FileB", b"base"): 0, + } return keys, sort_order def get_keys_and_groupcompress_sort_order(self): """Get diamond test keys list, and their groupcompress sort ordering.""" if self.key_length == 1: - keys = [(b'merged',), (b'left',), (b'right',), (b'base',)] - sort_order = {(b'merged',): 0, (b'left',): 1, - (b'right',): 1, (b'base',): 2} + keys = [(b"merged",), (b"left",), (b"right",), (b"base",)] + sort_order = {(b"merged",): 0, (b"left",): 1, (b"right",): 1, (b"base",): 2} else: keys = [ - (b'FileA', b'merged'), (b'FileA', b'left'), (b'FileA', b'right'), - (b'FileA', b'base'), - (b'FileB', b'merged'), (b'FileB', b'left'), (b'FileB', b'right'), - (b'FileB', b'base'), - ] + (b"FileA", b"merged"), + (b"FileA", b"left"), + (b"FileA", b"right"), + (b"FileA", b"base"), + (b"FileB", b"merged"), + (b"FileB", b"left"), + (b"FileB", b"right"), + (b"FileB", b"base"), + ] sort_order = { - (b'FileA', b'merged'): 0, (b'FileA', b'left'): 1, (b'FileA', b'right'): 1, - (b'FileA', b'base'): 2, - (b'FileB', b'merged'): 3, (b'FileB', b'left'): 4, (b'FileB', b'right'): 4, - (b'FileB', b'base'): 5, - } + (b"FileA", b"merged"): 0, + (b"FileA", b"left"): 1, + (b"FileA", b"right"): 1, + (b"FileA", b"base"): 2, + (b"FileB", b"merged"): 3, + (b"FileB", b"left"): 4, + (b"FileB", b"right"): 4, + (b"FileB", b"base"): 5, + } return keys, sort_order def test_get_record_stream_interface_ordered(self): @@ -1877,7 +2101,7 @@ def test_get_record_stream_interface_ordered(self): self.get_diamond_files(files) keys, sort_order = self.get_keys_and_sort_order() parent_map = files.get_parent_map(keys) - entries = files.get_record_stream(keys, 'topological', False) + entries = files.get_record_stream(keys, "topological", False) seen = [] self.capture_stream(files, entries, seen.append, parent_map) self.assertStreamOrder(sort_order, seen, keys) @@ -1888,21 +2112,22 @@ def test_get_record_stream_interface_ordered_with_delta_closure(self): self.get_diamond_files(files) keys, 
sort_order = self.get_keys_and_sort_order() parent_map = files.get_parent_map(keys) - entries = files.get_record_stream(keys, 'topological', True) + entries = files.get_record_stream(keys, "topological", True) seen = [] for factory in entries: seen.append(factory.key) self.assertValidStorageKind(factory.storage_kind) - self.assertSubset([factory.sha1], - [None, files.get_sha1s([factory.key])[factory.key]]) + self.assertSubset( + [factory.sha1], [None, files.get_sha1s([factory.key])[factory.key]] + ) self.assertEqual(parent_map[factory.key], factory.parents) # self.assertEqual(files.get_text(factory.key), - ft_bytes = factory.get_bytes_as('fulltext') + ft_bytes = factory.get_bytes_as("fulltext") self.assertIsInstance(ft_bytes, bytes) - chunked_bytes = factory.get_bytes_as('chunked') - self.assertEqualDiff(ft_bytes, b''.join(chunked_bytes)) - chunked_bytes = factory.iter_bytes_as('chunked') - self.assertEqualDiff(ft_bytes, b''.join(chunked_bytes)) + chunked_bytes = factory.get_bytes_as("chunked") + self.assertEqualDiff(ft_bytes, b"".join(chunked_bytes)) + chunked_bytes = factory.iter_bytes_as("chunked") + self.assertEqualDiff(ft_bytes, b"".join(chunked_bytes)) self.assertStreamOrder(sort_order, seen, keys) @@ -1912,7 +2137,7 @@ def test_get_record_stream_interface_groupcompress(self): self.get_diamond_files(files) keys, sort_order = self.get_keys_and_groupcompress_sort_order() parent_map = files.get_parent_map(keys) - entries = files.get_record_stream(keys, 'groupcompress', False) + entries = files.get_record_stream(keys, "groupcompress", False) seen = [] self.capture_stream(files, entries, seen.append, parent_map) self.assertStreamOrder(sort_order, seen, keys) @@ -1922,14 +2147,16 @@ def assertStreamOrder(self, sort_order, seen, keys): if self.key_length == 1: lows = {(): 0} else: - lows = {(b'FileA',): 0, (b'FileB',): 0} + lows = {(b"FileA",): 0, (b"FileB",): 0} if not self.graph: self.assertEqual(set(keys), set(seen)) else: for key in seen: sort_pos = sort_order[key] - self.assertTrue(sort_pos >= lows[key[:-1]], - f"Out of order in sorted stream: {key!r}, {seen!r}") + self.assertTrue( + sort_pos >= lows[key[:-1]], + f"Out of order in sorted stream: {key!r}, {seen!r}", + ) lows[key[:-1]] = sort_pos def test_get_record_stream_unknown_storage_kind_raises(self): @@ -1937,16 +2164,20 @@ def test_get_record_stream_unknown_storage_kind_raises(self): files = self.get_versionedfiles() self.get_diamond_files(files) if self.key_length == 1: - keys = [(b'merged',), (b'left',), (b'right',), (b'base',)] + keys = [(b"merged",), (b"left",), (b"right",), (b"base",)] else: keys = [ - (b'FileA', b'merged'), (b'FileA', b'left'), (b'FileA', b'right'), - (b'FileA', b'base'), - (b'FileB', b'merged'), (b'FileB', b'left'), (b'FileB', b'right'), - (b'FileB', b'base'), - ] + (b"FileA", b"merged"), + (b"FileA", b"left"), + (b"FileA", b"right"), + (b"FileA", b"base"), + (b"FileB", b"merged"), + (b"FileB", b"left"), + (b"FileB", b"right"), + (b"FileB", b"base"), + ] parent_map = files.get_parent_map(keys) - entries = files.get_record_stream(keys, 'unordered', False) + entries = files.get_record_stream(keys, "unordered", False) # We track the contents because we should be able to try, fail a # particular kind and then ask for one that works and continue. 
seen = set() @@ -1954,48 +2185,52 @@ def test_get_record_stream_unknown_storage_kind_raises(self): seen.add(factory.key) self.assertValidStorageKind(factory.storage_kind) if factory.sha1 is not None: - self.assertEqual(files.get_sha1s([factory.key])[factory.key], - factory.sha1) + self.assertEqual( + files.get_sha1s([factory.key])[factory.key], factory.sha1 + ) self.assertEqual(parent_map[factory.key], factory.parents) # currently no stream emits mpdiff - self.assertRaises(UnavailableRepresentation, - factory.get_bytes_as, 'mpdiff') - self.assertIsInstance(factory.get_bytes_as(factory.storage_kind), - bytes) + self.assertRaises(UnavailableRepresentation, factory.get_bytes_as, "mpdiff") + self.assertIsInstance(factory.get_bytes_as(factory.storage_kind), bytes) self.assertEqual(set(keys), seen) def test_get_record_stream_missing_records_are_absent(self): files = self.get_versionedfiles() self.get_diamond_files(files) if self.key_length == 1: - keys = [(b'merged',), (b'left',), (b'right',), - (b'absent',), (b'base',)] + keys = [(b"merged",), (b"left",), (b"right",), (b"absent",), (b"base",)] else: keys = [ - (b'FileA', b'merged'), (b'FileA', b'left'), (b'FileA', b'right'), - (b'FileA', b'absent'), (b'FileA', b'base'), - (b'FileB', b'merged'), (b'FileB', b'left'), (b'FileB', b'right'), - (b'FileB', b'absent'), (b'FileB', b'base'), - (b'absent', b'absent'), - ] + (b"FileA", b"merged"), + (b"FileA", b"left"), + (b"FileA", b"right"), + (b"FileA", b"absent"), + (b"FileA", b"base"), + (b"FileB", b"merged"), + (b"FileB", b"left"), + (b"FileB", b"right"), + (b"FileB", b"absent"), + (b"FileB", b"base"), + (b"absent", b"absent"), + ] parent_map = files.get_parent_map(keys) - entries = files.get_record_stream(keys, 'unordered', False) + entries = files.get_record_stream(keys, "unordered", False) self.assertAbsentRecord(files, keys, parent_map, entries) - entries = files.get_record_stream(keys, 'topological', False) + entries = files.get_record_stream(keys, "topological", False) self.assertAbsentRecord(files, keys, parent_map, entries) def assertRecordHasContent(self, record, bytes): """Assert that record has the bytes bytes.""" - self.assertEqual(bytes, record.get_bytes_as('fulltext')) - self.assertEqual(bytes, b''.join(record.get_bytes_as('chunked'))) + self.assertEqual(bytes, record.get_bytes_as("fulltext")) + self.assertEqual(bytes, b"".join(record.get_bytes_as("chunked"))) def test_get_record_stream_native_formats_are_wire_ready_one_ft(self): files = self.get_versionedfiles() - key = self.get_simple_key(b'foo') - files.add_lines(key, (), [b'my text\n', b'content']) - stream = files.get_record_stream([key], 'unordered', False) + key = self.get_simple_key(b"foo") + files.add_lines(key, (), [b"my text\n", b"content"]) + stream = files.get_record_stream([key], "unordered", False) record = next(stream) - if record.storage_kind in ('chunked', 'fulltext'): + if record.storage_kind in ("chunked", "fulltext"): # chunked and fulltext representations are for direct use not wire # serialisation: check they are able to be used directly. To send # such records over the wire translation will be needed. 
@@ -2007,12 +2242,12 @@ def test_get_record_stream_native_formats_are_wire_ready_one_ft(self): records = [] for record in network_stream: records.append(record) - self.assertEqual(source_record.storage_kind, - record.storage_kind) + self.assertEqual(source_record.storage_kind, record.storage_kind) self.assertEqual(source_record.parents, record.parents) self.assertEqual( source_record.get_bytes_as(source_record.storage_kind), - record.get_bytes_as(record.storage_kind)) + record.get_bytes_as(record.storage_kind), + ) self.assertEqual(1, len(records)) def assertStreamMetaEqual(self, records, expected, stream): @@ -2030,8 +2265,7 @@ def assertStreamMetaEqual(self, records, expected, stream): self.assertEqual(ref_record.parents, record.parents) yield record - def stream_to_bytes_or_skip_counter(self, skipped_records, full_texts, - stream): + def stream_to_bytes_or_skip_counter(self, skipped_records, full_texts, stream): """Convert a stream to a bytes iterator. :param skipped_records: A list with one element to increment when a @@ -2042,7 +2276,7 @@ def stream_to_bytes_or_skip_counter(self, skipped_records, full_texts, :return: An iterator over the bytes of each record. """ for record in stream: - if record.storage_kind in ('chunked', 'fulltext'): + if record.storage_kind in ("chunked", "fulltext"): skipped_records[0] += 1 # check the content is correct for direct use. self.assertRecordHasContent(record, full_texts[record.key]) @@ -2051,32 +2285,33 @@ def stream_to_bytes_or_skip_counter(self, skipped_records, full_texts, def test_get_record_stream_native_formats_are_wire_ready_ft_delta(self): files = self.get_versionedfiles() - target_files = self.get_versionedfiles('target') - key = self.get_simple_key(b'ft') - key_delta = self.get_simple_key(b'delta') - files.add_lines(key, (), [b'my text\n', b'content']) + target_files = self.get_versionedfiles("target") + key = self.get_simple_key(b"ft") + key_delta = self.get_simple_key(b"delta") + files.add_lines(key, (), [b"my text\n", b"content"]) if self.graph: delta_parents = (key,) else: delta_parents = () - files.add_lines(key_delta, delta_parents, [ - b'different\n', b'content\n']) - local = files.get_record_stream([key, key_delta], 'unordered', False) - ref = files.get_record_stream([key, key_delta], 'unordered', False) + files.add_lines(key_delta, delta_parents, [b"different\n", b"content\n"]) + local = files.get_record_stream([key, key_delta], "unordered", False) + ref = files.get_record_stream([key, key_delta], "unordered", False) skipped_records = [0] full_texts = { key: b"my text\ncontent", key_delta: b"different\ncontent\n", - } + } byte_stream = self.stream_to_bytes_or_skip_counter( - skipped_records, full_texts, local) + skipped_records, full_texts, local + ) network_stream = versionedfile.NetworkRecordStream(byte_stream).read() records = [] # insert the stream from the network into a versioned files object so we can # check the content was carried across correctly without doing delta # inspection. target_files.insert_record_stream( - self.assertStreamMetaEqual(records, ref, network_stream)) + self.assertStreamMetaEqual(records, ref, network_stream) + ) # No duplicates on the wire thank you! 
self.assertEqual(2, len(records) + skipped_records[0]) if len(records): @@ -2086,35 +2321,37 @@ def test_get_record_stream_native_formats_are_wire_ready_ft_delta(self): def test_get_record_stream_native_formats_are_wire_ready_delta(self): # copy a delta over the wire files = self.get_versionedfiles() - target_files = self.get_versionedfiles('target') - key = self.get_simple_key(b'ft') - key_delta = self.get_simple_key(b'delta') - files.add_lines(key, (), [b'my text\n', b'content']) + target_files = self.get_versionedfiles("target") + key = self.get_simple_key(b"ft") + key_delta = self.get_simple_key(b"delta") + files.add_lines(key, (), [b"my text\n", b"content"]) if self.graph: delta_parents = (key,) else: delta_parents = () - files.add_lines(key_delta, delta_parents, [ - b'different\n', b'content\n']) + files.add_lines(key_delta, delta_parents, [b"different\n", b"content\n"]) # Copy the basis text across so we can reconstruct the delta during # insertion into target. - target_files.insert_record_stream(files.get_record_stream([key], - 'unordered', False)) - local = files.get_record_stream([key_delta], 'unordered', False) - ref = files.get_record_stream([key_delta], 'unordered', False) + target_files.insert_record_stream( + files.get_record_stream([key], "unordered", False) + ) + local = files.get_record_stream([key_delta], "unordered", False) + ref = files.get_record_stream([key_delta], "unordered", False) skipped_records = [0] full_texts = { key_delta: b"different\ncontent\n", - } + } byte_stream = self.stream_to_bytes_or_skip_counter( - skipped_records, full_texts, local) + skipped_records, full_texts, local + ) network_stream = versionedfile.NetworkRecordStream(byte_stream).read() records = [] # insert the stream from the network into a versioned files object so we can # check the content was carried across correctly without doing delta # inspection during check_stream. target_files.insert_record_stream( - self.assertStreamMetaEqual(records, ref, network_stream)) + self.assertStreamMetaEqual(records, ref, network_stream) + ) # No duplicates on the wire thank you! self.assertEqual(1, len(records) + skipped_records[0]) if len(records): @@ -2124,23 +2361,23 @@ def test_get_record_stream_native_formats_are_wire_ready_delta(self): def test_get_record_stream_wire_ready_delta_closure_included(self): # copy a delta over the wire with the ability to get its full text. 
files = self.get_versionedfiles() - key = self.get_simple_key(b'ft') - key_delta = self.get_simple_key(b'delta') - files.add_lines(key, (), [b'my text\n', b'content']) + key = self.get_simple_key(b"ft") + key_delta = self.get_simple_key(b"delta") + files.add_lines(key, (), [b"my text\n", b"content"]) if self.graph: delta_parents = (key,) else: delta_parents = () - files.add_lines(key_delta, delta_parents, [ - b'different\n', b'content\n']) - local = files.get_record_stream([key_delta], 'unordered', True) - ref = files.get_record_stream([key_delta], 'unordered', True) + files.add_lines(key_delta, delta_parents, [b"different\n", b"content\n"]) + local = files.get_record_stream([key_delta], "unordered", True) + ref = files.get_record_stream([key_delta], "unordered", True) skipped_records = [0] full_texts = { key_delta: b"different\ncontent\n", - } + } byte_stream = self.stream_to_bytes_or_skip_counter( - skipped_records, full_texts, local) + skipped_records, full_texts, local + ) network_stream = versionedfile.NetworkRecordStream(byte_stream).read() records = [] # insert the stream from the network into a versioned files object so we can @@ -2157,8 +2394,8 @@ def assertAbsentRecord(self, files, keys, parents, entries): seen = set() for factory in entries: seen.add(factory.key) - if factory.key[-1] == b'absent': - self.assertEqual('absent', factory.storage_kind) + if factory.key[-1] == b"absent": + self.assertEqual("absent", factory.storage_kind) self.assertEqual(None, factory.sha1) self.assertEqual(None, factory.parents) else: @@ -2167,8 +2404,7 @@ def assertAbsentRecord(self, files, keys, parents, entries): sha1 = files.get_sha1s([factory.key])[factory.key] self.assertEqual(sha1, factory.sha1) self.assertEqual(parents[factory.key], factory.parents) - self.assertIsInstance(factory.get_bytes_as(factory.storage_kind), - bytes) + self.assertIsInstance(factory.get_bytes_as(factory.storage_kind), bytes) self.assertEqual(set(keys), seen) def test_filter_absent_records(self): @@ -2182,19 +2418,20 @@ def test_filter_absent_records(self): # absent keys is still delivered). 
present_keys = list(keys) if self.key_length == 1: - keys.insert(2, (b'extra',)) + keys.insert(2, (b"extra",)) else: - keys.insert(2, (b'extra', b'extra')) - entries = files.get_record_stream(keys, 'unordered', False) + keys.insert(2, (b"extra", b"extra")) + entries = files.get_record_stream(keys, "unordered", False) seen = set() - self.capture_stream(files, versionedfile.filter_absent(entries), seen.add, - parent_map) + self.capture_stream( + files, versionedfile.filter_absent(entries), seen.add, parent_map + ) self.assertEqual(set(present_keys), seen) def get_mapper(self): """Get a mapper suitable for the key length of the test interface.""" if self.key_length == 1: - return ConstantMapper('source') + return ConstantMapper("source") else: return HashEscapedPrefixMapper() @@ -2208,58 +2445,76 @@ def get_parents(self, parents): def test_get_annotator(self): files = self.get_versionedfiles() self.get_diamond_files(files) - origin_key = self.get_simple_key(b'origin') - base_key = self.get_simple_key(b'base') - left_key = self.get_simple_key(b'left') - right_key = self.get_simple_key(b'right') - merged_key = self.get_simple_key(b'merged') + origin_key = self.get_simple_key(b"origin") + base_key = self.get_simple_key(b"base") + left_key = self.get_simple_key(b"left") + right_key = self.get_simple_key(b"right") + merged_key = self.get_simple_key(b"merged") # annotator = files.get_annotator() # introduced full text origins, lines = files.get_annotator().annotate(origin_key) self.assertEqual([(origin_key,)], origins) - self.assertEqual([b'origin\n'], lines) + self.assertEqual([b"origin\n"], lines) # a delta origins, lines = files.get_annotator().annotate(base_key) self.assertEqual([(base_key,)], origins) # a merge origins, lines = files.get_annotator().annotate(merged_key) if self.graph: - self.assertEqual([ - (base_key,), - (left_key,), - (right_key,), - (merged_key,), - ], origins) + self.assertEqual( + [ + (base_key,), + (left_key,), + (right_key,), + (merged_key,), + ], + origins, + ) else: # Without a graph everything is new. 
- self.assertEqual([ - (merged_key,), - (merged_key,), - (merged_key,), - (merged_key,), - ], origins) - self.assertRaises(RevisionNotPresent, - files.get_annotator().annotate, self.get_simple_key(b'missing-key')) + self.assertEqual( + [ + (merged_key,), + (merged_key,), + (merged_key,), + (merged_key,), + ], + origins, + ) + self.assertRaises( + RevisionNotPresent, + files.get_annotator().annotate, + self.get_simple_key(b"missing-key"), + ) def test_get_parent_map(self): files = self.get_versionedfiles() if self.key_length == 1: parent_details = [ - ((b'r0',), self.get_parents(())), - ((b'r1',), self.get_parents(((b'r0',),))), - ((b'r2',), self.get_parents(())), - ((b'r3',), self.get_parents(())), - ((b'm',), self.get_parents(((b'r0',), (b'r1',), (b'r2',), (b'r3',)))), - ] + ((b"r0",), self.get_parents(())), + ((b"r1",), self.get_parents(((b"r0",),))), + ((b"r2",), self.get_parents(())), + ((b"r3",), self.get_parents(())), + ((b"m",), self.get_parents(((b"r0",), (b"r1",), (b"r2",), (b"r3",)))), + ] else: parent_details = [ - ((b'FileA', b'r0'), self.get_parents(())), - ((b'FileA', b'r1'), self.get_parents(((b'FileA', b'r0'),))), - ((b'FileA', b'r2'), self.get_parents(())), - ((b'FileA', b'r3'), self.get_parents(())), - ((b'FileA', b'm'), self.get_parents(((b'FileA', b'r0'), - (b'FileA', b'r1'), (b'FileA', b'r2'), (b'FileA', b'r3')))), - ] + ((b"FileA", b"r0"), self.get_parents(())), + ((b"FileA", b"r1"), self.get_parents(((b"FileA", b"r0"),))), + ((b"FileA", b"r2"), self.get_parents(())), + ((b"FileA", b"r3"), self.get_parents(())), + ( + (b"FileA", b"m"), + self.get_parents( + ( + (b"FileA", b"r0"), + (b"FileA", b"r1"), + (b"FileA", b"r2"), + (b"FileA", b"r3"), + ) + ), + ), + ] for key, parents in parent_details: files.add_lines(key, parents, []) # immediately after adding it should be queryable. @@ -2272,9 +2527,9 @@ def test_get_parent_map(self): # Absent keys are just not included in the result. keys = list(all_parents.keys()) if self.key_length == 1: - keys.insert(1, (b'missing',)) + keys.insert(1, (b"missing",)) else: - keys.insert(1, (b'missing', b'missing')) + keys.insert(1, (b"missing", b"missing")) # Absent keys are just ignored self.assertEqual(all_parents, files.get_parent_map(keys)) @@ -2282,22 +2537,26 @@ def test_get_sha1s(self): files = self.get_versionedfiles() self.get_diamond_files(files) if self.key_length == 1: - keys = [(b'base',), (b'origin',), (b'left',), - (b'merged',), (b'right',)] + keys = [(b"base",), (b"origin",), (b"left",), (b"merged",), (b"right",)] else: # ask for shas from different prefixes. 
keys = [ - (b'FileA', b'base'), (b'FileB', b'origin'), (b'FileA', b'left'), - (b'FileA', b'merged'), (b'FileB', b'right'), - ] - self.assertEqual({ - keys[0]: b'51c64a6f4fc375daf0d24aafbabe4d91b6f4bb44', - keys[1]: b'00e364d235126be43292ab09cb4686cf703ddc17', - keys[2]: b'a8478686da38e370e32e42e8a0c220e33ee9132f', - keys[3]: b'ed8bce375198ea62444dc71952b22cfc2b09226d', - keys[4]: b'9ef09dfa9d86780bdec9219a22560c6ece8e0ef1', + (b"FileA", b"base"), + (b"FileB", b"origin"), + (b"FileA", b"left"), + (b"FileA", b"merged"), + (b"FileB", b"right"), + ] + self.assertEqual( + { + keys[0]: b"51c64a6f4fc375daf0d24aafbabe4d91b6f4bb44", + keys[1]: b"00e364d235126be43292ab09cb4686cf703ddc17", + keys[2]: b"a8478686da38e370e32e42e8a0c220e33ee9132f", + keys[3]: b"ed8bce375198ea62444dc71952b22cfc2b09226d", + keys[4]: b"9ef09dfa9d86780bdec9219a22560c6ece8e0ef1", }, - files.get_sha1s(keys)) + files.get_sha1s(keys), + ) def test_insert_record_stream_empty(self): """Inserting an empty record stream should work.""" @@ -2309,30 +2568,29 @@ def assertIdenticalVersionedFile(self, expected, actual): self.assertEqual(set(actual.keys()), set(expected.keys())) actual_parents = actual.get_parent_map(actual.keys()) if self.graph: - self.assertEqual( - actual_parents, expected.get_parent_map(expected.keys())) + self.assertEqual(actual_parents, expected.get_parent_map(expected.keys())) else: for _key, parents in actual_parents.items(): self.assertEqual(None, parents) for key in actual.keys(): - actual_text = next(actual.get_record_stream( - [key], 'unordered', True)).get_bytes_as('fulltext') - expected_text = next(expected.get_record_stream( - [key], 'unordered', True)).get_bytes_as('fulltext') + actual_text = next( + actual.get_record_stream([key], "unordered", True) + ).get_bytes_as("fulltext") + expected_text = next( + expected.get_record_stream([key], "unordered", True) + ).get_bytes_as("fulltext") self.assertEqual(actual_text, expected_text) def test_insert_record_stream_fulltexts(self): """Any file should accept a stream of fulltexts.""" files = self.get_versionedfiles() mapper = self.get_mapper() - source_transport = self.get_transport('source') - source_transport.mkdir('.') + source_transport = self.get_transport("source") + source_transport.mkdir(".") # weaves always output fulltexts. - source = make_versioned_files_factory(WeaveFile, mapper)( - source_transport) + source = make_versioned_files_factory(WeaveFile, mapper)(source_transport) self.get_diamond_files(source, trailing_eol=False) - stream = source.get_record_stream(source.keys(), 'topological', - False) + stream = source.get_record_stream(source.keys(), "topological", False) files.insert_record_stream(stream) self.assertIdenticalVersionedFile(source, files) @@ -2340,14 +2598,12 @@ def test_insert_record_stream_fulltexts_noeol(self): """Any file should accept a stream of fulltexts.""" files = self.get_versionedfiles() mapper = self.get_mapper() - source_transport = self.get_transport('source') - source_transport.mkdir('.') + source_transport = self.get_transport("source") + source_transport.mkdir(".") # weaves always output fulltexts. 
- source = make_versioned_files_factory(WeaveFile, mapper)( - source_transport) + source = make_versioned_files_factory(WeaveFile, mapper)(source_transport) self.get_diamond_files(source, trailing_eol=False) - stream = source.get_record_stream(source.keys(), 'topological', - False) + stream = source.get_record_stream(source.keys(), "topological", False) files.insert_record_stream(stream) self.assertIdenticalVersionedFile(source, files) @@ -2355,12 +2611,11 @@ def test_insert_record_stream_annotated_knits(self): """Any file should accept a stream from plain knits.""" files = self.get_versionedfiles() mapper = self.get_mapper() - source_transport = self.get_transport('source') - source_transport.mkdir('.') + source_transport = self.get_transport("source") + source_transport.mkdir(".") source = make_file_factory(True, mapper)(source_transport) self.get_diamond_files(source) - stream = source.get_record_stream(source.keys(), 'topological', - False) + stream = source.get_record_stream(source.keys(), "topological", False) files.insert_record_stream(stream) self.assertIdenticalVersionedFile(source, files) @@ -2368,12 +2623,11 @@ def test_insert_record_stream_annotated_knits_noeol(self): """Any file should accept a stream from plain knits.""" files = self.get_versionedfiles() mapper = self.get_mapper() - source_transport = self.get_transport('source') - source_transport.mkdir('.') + source_transport = self.get_transport("source") + source_transport.mkdir(".") source = make_file_factory(True, mapper)(source_transport) self.get_diamond_files(source, trailing_eol=False) - stream = source.get_record_stream(source.keys(), 'topological', - False) + stream = source.get_record_stream(source.keys(), "topological", False) files.insert_record_stream(stream) self.assertIdenticalVersionedFile(source, files) @@ -2381,12 +2635,11 @@ def test_insert_record_stream_plain_knits(self): """Any file should accept a stream from plain knits.""" files = self.get_versionedfiles() mapper = self.get_mapper() - source_transport = self.get_transport('source') - source_transport.mkdir('.') + source_transport = self.get_transport("source") + source_transport.mkdir(".") source = make_file_factory(False, mapper)(source_transport) self.get_diamond_files(source) - stream = source.get_record_stream(source.keys(), 'topological', - False) + stream = source.get_record_stream(source.keys(), "topological", False) files.insert_record_stream(stream) self.assertIdenticalVersionedFile(source, files) @@ -2394,56 +2647,84 @@ def test_insert_record_stream_plain_knits_noeol(self): """Any file should accept a stream from plain knits.""" files = self.get_versionedfiles() mapper = self.get_mapper() - source_transport = self.get_transport('source') - source_transport.mkdir('.') + source_transport = self.get_transport("source") + source_transport.mkdir(".") source = make_file_factory(False, mapper)(source_transport) self.get_diamond_files(source, trailing_eol=False) - stream = source.get_record_stream(source.keys(), 'topological', - False) + stream = source.get_record_stream(source.keys(), "topological", False) files.insert_record_stream(stream) self.assertIdenticalVersionedFile(source, files) def test_insert_record_stream_existing_keys(self): """Inserting keys already in a file should not error.""" files = self.get_versionedfiles() - source = self.get_versionedfiles('source') + source = self.get_versionedfiles("source") self.get_diamond_files(source) # insert some keys into f. 
self.get_diamond_files(files, left_only=True) - stream = source.get_record_stream(source.keys(), 'topological', - False) + stream = source.get_record_stream(source.keys(), "topological", False) files.insert_record_stream(stream) self.assertIdenticalVersionedFile(source, files) def test_insert_record_stream_missing_keys(self): """Inserting a stream with absent keys should raise an error.""" files = self.get_versionedfiles() - source = self.get_versionedfiles('source') - stream = source.get_record_stream([(b'missing',) * self.key_length], - 'topological', False) - self.assertRaises(errors.RevisionNotPresent, files.insert_record_stream, - stream) + source = self.get_versionedfiles("source") + stream = source.get_record_stream( + [(b"missing",) * self.key_length], "topological", False + ) + self.assertRaises(errors.RevisionNotPresent, files.insert_record_stream, stream) def test_insert_record_stream_out_of_order(self): """An out of order stream can either error or work.""" files = self.get_versionedfiles() - source = self.get_versionedfiles('source') + source = self.get_versionedfiles("source") self.get_diamond_files(source) if self.key_length == 1: - origin_keys = [(b'origin',)] - end_keys = [(b'merged',), (b'left',)] - start_keys = [(b'right',), (b'base',)] + origin_keys = [(b"origin",)] + end_keys = [(b"merged",), (b"left",)] + start_keys = [(b"right",), (b"base",)] else: - origin_keys = [(b'FileA', b'origin'), (b'FileB', b'origin')] - end_keys = [(b'FileA', b'merged',), (b'FileA', b'left',), - (b'FileB', b'merged',), (b'FileB', b'left',)] - start_keys = [(b'FileA', b'right',), (b'FileA', b'base',), - (b'FileB', b'right',), (b'FileB', b'base',)] - origin_entries = source.get_record_stream( - origin_keys, 'unordered', False) - end_entries = source.get_record_stream(end_keys, 'topological', False) - start_entries = source.get_record_stream( - start_keys, 'topological', False) + origin_keys = [(b"FileA", b"origin"), (b"FileB", b"origin")] + end_keys = [ + ( + b"FileA", + b"merged", + ), + ( + b"FileA", + b"left", + ), + ( + b"FileB", + b"merged", + ), + ( + b"FileB", + b"left", + ), + ] + start_keys = [ + ( + b"FileA", + b"right", + ), + ( + b"FileA", + b"base", + ), + ( + b"FileB", + b"right", + ), + ( + b"FileB", + b"base", + ), + ] + origin_entries = source.get_record_stream(origin_keys, "unordered", False) + end_entries = source.get_record_stream(end_keys, "topological", False) + start_entries = source.get_record_stream(start_keys, "topological", False) entries = itertools.chain(origin_entries, end_entries, start_entries) try: files.insert_record_stream(entries) @@ -2456,20 +2737,20 @@ def test_insert_record_stream_out_of_order(self): def test_insert_record_stream_long_parent_chain_out_of_order(self): """An out of order stream can either error or work.""" if not self.graph: - raise TestNotApplicable('ancestry info only relevant with graph.') + raise TestNotApplicable("ancestry info only relevant with graph.") # Create a reasonably long chain of records based on each other, where # most will be deltas. 
- source = self.get_versionedfiles('source') + source = self.get_versionedfiles("source") parents = () keys = [] - content = [(b'same same %d\n' % n) for n in range(500)] - letters = b'abcdefghijklmnopqrstuvwxyz' + content = [(b"same same %d\n" % n) for n in range(500)] + letters = b"abcdefghijklmnopqrstuvwxyz" for i in range(len(letters)): - letter = letters[i:i + 1] - key = (b'key-' + letter,) + letter = letters[i : i + 1] + key = (b"key-" + letter,) if self.key_length == 2: - key = (b'prefix',) + key - content.append(b'content for ' + letter + b'\n') + key = (b"prefix",) + key + content.append(b"content for " + letter + b"\n") source.add_lines(key, parents, content) keys.append(key) parents = (key,) @@ -2477,7 +2758,7 @@ def test_insert_record_stream_long_parent_chain_out_of_order(self): # rest ultimately depend upon, and insert it into a new vf. streams = [] for key in reversed(keys): - streams.append(source.get_record_stream([key], 'unordered', False)) + streams.append(source.get_record_stream([key], "unordered", False)) deltas = itertools.chain.from_iterable(streams[:-1]) files = self.get_versionedfiles() try: @@ -2497,11 +2778,12 @@ def get_knit_delta_source(self): regardless of this test's scenario. """ mapper = self.get_mapper() - source_transport = self.get_transport('source') - source_transport.mkdir('.') + source_transport = self.get_transport("source") + source_transport.mkdir(".") source = make_file_factory(False, mapper)(source_transport) - get_diamond_files(source, self.key_length, trailing_eol=True, - nograph=False, left_only=False) + get_diamond_files( + source, self.key_length, trailing_eol=True, nograph=False, left_only=False + ) return source def test_insert_record_stream_delta_missing_basis_no_corruption(self): @@ -2510,20 +2792,19 @@ def test_insert_record_stream_delta_missing_basis_no_corruption(self): not added. 
""" source = self.get_knit_delta_source() - keys = [self.get_simple_key(b'origin'), self.get_simple_key(b'merged')] - entries = source.get_record_stream(keys, 'unordered', False) + keys = [self.get_simple_key(b"origin"), self.get_simple_key(b"merged")] + entries = source.get_record_stream(keys, "unordered", False) files = self.get_versionedfiles() if self.support_partial_insertion: - self.assertEqual([], - list(files.get_missing_compression_parent_keys())) + self.assertEqual([], list(files.get_missing_compression_parent_keys())) files.insert_record_stream(entries) missing_bases = files.get_missing_compression_parent_keys() - self.assertEqual({self.get_simple_key(b'left')}, - set(missing_bases)) + self.assertEqual({self.get_simple_key(b"left")}, set(missing_bases)) self.assertEqual(set(keys), set(files.get_parent_map(keys))) else: self.assertRaises( - errors.RevisionNotPresent, files.insert_record_stream, entries) + errors.RevisionNotPresent, files.insert_record_stream, entries + ) files.check() def test_insert_record_stream_delta_missing_basis_can_be_added_later(self): @@ -2535,28 +2816,28 @@ def test_insert_record_stream_delta_missing_basis_can_be_added_later(self): """ if not self.support_partial_insertion: raise TestNotApplicable( - 'versioned file scenario does not support partial insertion') + "versioned file scenario does not support partial insertion" + ) source = self.get_knit_delta_source() - entries = source.get_record_stream([self.get_simple_key(b'origin'), - self.get_simple_key(b'merged')], 'unordered', False) + entries = source.get_record_stream( + [self.get_simple_key(b"origin"), self.get_simple_key(b"merged")], + "unordered", + False, + ) files = self.get_versionedfiles() files.insert_record_stream(entries) missing_bases = files.get_missing_compression_parent_keys() - self.assertEqual({self.get_simple_key(b'left')}, - set(missing_bases)) + self.assertEqual({self.get_simple_key(b"left")}, set(missing_bases)) # 'merged' is inserted (although a commit of a write group involving # this versionedfiles would fail). - merged_key = self.get_simple_key(b'merged') - self.assertEqual( - [merged_key], list(files.get_parent_map([merged_key]).keys())) + merged_key = self.get_simple_key(b"merged") + self.assertEqual([merged_key], list(files.get_parent_map([merged_key]).keys())) # Add the full delta closure of the missing records - missing_entries = source.get_record_stream( - missing_bases, 'unordered', True) + missing_entries = source.get_record_stream(missing_bases, "unordered", True) files.insert_record_stream(missing_entries) # Now 'merged' is fully inserted (and a commit would succeed). self.assertEqual([], list(files.get_missing_compression_parent_keys())) - self.assertEqual( - [merged_key], list(files.get_parent_map([merged_key]).keys())) + self.assertEqual([merged_key], list(files.get_parent_map([merged_key]).keys())) files.check() def test_iter_lines_added_or_present_in_keys(self): @@ -2566,7 +2847,6 @@ def test_iter_lines_added_or_present_in_keys(self): # more changes to muck up. 
class InstrumentedProgress(progress.ProgressTask): - def __init__(self): progress.ProgressTask.__init__(self) self.updates = [] @@ -2576,131 +2856,164 @@ def update(self, msg=None, current=None, total=None): files = self.get_versionedfiles() # add a base to get included - files.add_lines(self.get_simple_key(b'base'), (), [b'base\n']) + files.add_lines(self.get_simple_key(b"base"), (), [b"base\n"]) # add a ancestor to be included on one side - files.add_lines(self.get_simple_key( - b'lancestor'), (), [b'lancestor\n']) + files.add_lines(self.get_simple_key(b"lancestor"), (), [b"lancestor\n"]) # add a ancestor to be included on the other side - files.add_lines(self.get_simple_key(b'rancestor'), - self.get_parents([self.get_simple_key(b'base')]), [b'rancestor\n']) + files.add_lines( + self.get_simple_key(b"rancestor"), + self.get_parents([self.get_simple_key(b"base")]), + [b"rancestor\n"], + ) # add a child of rancestor with no eofile-nl - files.add_lines(self.get_simple_key(b'child'), - self.get_parents([self.get_simple_key(b'rancestor')]), - [b'base\n', b'child\n']) + files.add_lines( + self.get_simple_key(b"child"), + self.get_parents([self.get_simple_key(b"rancestor")]), + [b"base\n", b"child\n"], + ) # add a child of lancestor and base to join the two roots - files.add_lines(self.get_simple_key(b'otherchild'), - self.get_parents([self.get_simple_key(b'lancestor'), - self.get_simple_key(b'base')]), - [b'base\n', b'lancestor\n', b'otherchild\n']) + files.add_lines( + self.get_simple_key(b"otherchild"), + self.get_parents( + [self.get_simple_key(b"lancestor"), self.get_simple_key(b"base")] + ), + [b"base\n", b"lancestor\n", b"otherchild\n"], + ) def iter_with_keys(keys, expected): # now we need to see what lines are returned, and how often. lines = {} progress = InstrumentedProgress() # iterate over the lines - for line in files.iter_lines_added_or_present_in_keys(keys, - pb=progress): + for line in files.iter_lines_added_or_present_in_keys(keys, pb=progress): lines.setdefault(line, 0) lines[line] += 1 if [] != progress.updates: self.assertEqual(expected, progress.updates) return lines + lines = iter_with_keys( - [self.get_simple_key(b'child'), - self.get_simple_key(b'otherchild')], - [('Walking content', 0, 2), - ('Walking content', 1, 2), - ('Walking content', 2, 2)]) + [self.get_simple_key(b"child"), self.get_simple_key(b"otherchild")], + [ + ("Walking content", 0, 2), + ("Walking content", 1, 2), + ("Walking content", 2, 2), + ], + ) # we must see child and otherchild - self.assertTrue(lines[(b'child\n', self.get_simple_key(b'child'))] > 0) + self.assertTrue(lines[(b"child\n", self.get_simple_key(b"child"))] > 0) self.assertTrue( - lines[(b'otherchild\n', self.get_simple_key(b'otherchild'))] > 0) + lines[(b"otherchild\n", self.get_simple_key(b"otherchild"))] > 0 + ) # we dont care if we got more than that. 
# test all lines - lines = iter_with_keys(files.keys(), - [('Walking content', 0, 5), - ('Walking content', 1, 5), - ('Walking content', 2, 5), - ('Walking content', 3, 5), - ('Walking content', 4, 5), - ('Walking content', 5, 5)]) + lines = iter_with_keys( + files.keys(), + [ + ("Walking content", 0, 5), + ("Walking content", 1, 5), + ("Walking content", 2, 5), + ("Walking content", 3, 5), + ("Walking content", 4, 5), + ("Walking content", 5, 5), + ], + ) # all lines must be seen at least once - self.assertTrue(lines[(b'base\n', self.get_simple_key(b'base'))] > 0) - self.assertTrue( - lines[(b'lancestor\n', self.get_simple_key(b'lancestor'))] > 0) + self.assertTrue(lines[(b"base\n", self.get_simple_key(b"base"))] > 0) + self.assertTrue(lines[(b"lancestor\n", self.get_simple_key(b"lancestor"))] > 0) + self.assertTrue(lines[(b"rancestor\n", self.get_simple_key(b"rancestor"))] > 0) + self.assertTrue(lines[(b"child\n", self.get_simple_key(b"child"))] > 0) self.assertTrue( - lines[(b'rancestor\n', self.get_simple_key(b'rancestor'))] > 0) - self.assertTrue(lines[(b'child\n', self.get_simple_key(b'child'))] > 0) - self.assertTrue( - lines[(b'otherchild\n', self.get_simple_key(b'otherchild'))] > 0) + lines[(b"otherchild\n", self.get_simple_key(b"otherchild"))] > 0 + ) def test_make_mpdiffs(self): from breezy import multiparent - files = self.get_versionedfiles('source') + + files = self.get_versionedfiles("source") # add texts that should trip the knit maximum delta chain threshold # as well as doing parallel chains of data in knits. # this is done by two chains of 25 insertions - files.add_lines(self.get_simple_key(b'base'), [], [b'line\n']) - files.add_lines(self.get_simple_key(b'noeol'), - self.get_parents([self.get_simple_key(b'base')]), [b'line']) + files.add_lines(self.get_simple_key(b"base"), [], [b"line\n"]) + files.add_lines( + self.get_simple_key(b"noeol"), + self.get_parents([self.get_simple_key(b"base")]), + [b"line"], + ) # detailed eol tests: # shared last line with parent no-eol - files.add_lines(self.get_simple_key(b'noeolsecond'), - self.get_parents([self.get_simple_key(b'noeol')]), - [b'line\n', b'line']) + files.add_lines( + self.get_simple_key(b"noeolsecond"), + self.get_parents([self.get_simple_key(b"noeol")]), + [b"line\n", b"line"], + ) # differing last line with parent, both no-eol - files.add_lines(self.get_simple_key(b'noeolnotshared'), - self.get_parents( - [self.get_simple_key(b'noeolsecond')]), - [b'line\n', b'phone']) + files.add_lines( + self.get_simple_key(b"noeolnotshared"), + self.get_parents([self.get_simple_key(b"noeolsecond")]), + [b"line\n", b"phone"], + ) # add eol following a noneol parent, change content - files.add_lines(self.get_simple_key(b'eol'), - self.get_parents([self.get_simple_key(b'noeol')]), [b'phone\n']) + files.add_lines( + self.get_simple_key(b"eol"), + self.get_parents([self.get_simple_key(b"noeol")]), + [b"phone\n"], + ) # add eol following a noneol parent, no change content - files.add_lines(self.get_simple_key(b'eolline'), - self.get_parents([self.get_simple_key(b'noeol')]), [b'line\n']) + files.add_lines( + self.get_simple_key(b"eolline"), + self.get_parents([self.get_simple_key(b"noeol")]), + [b"line\n"], + ) # noeol with no parents: - files.add_lines(self.get_simple_key(b'noeolbase'), [], [b'line']) + files.add_lines(self.get_simple_key(b"noeolbase"), [], [b"line"]) # noeol preceeding its leftmost parent in the output: # this is done by making it a merge of two parents with no common # anestry: noeolbase and noeol with the # 
later-inserted parent the leftmost. - files.add_lines(self.get_simple_key(b'eolbeforefirstparent'), - self.get_parents([self.get_simple_key(b'noeolbase'), - self.get_simple_key(b'noeol')]), - [b'line']) + files.add_lines( + self.get_simple_key(b"eolbeforefirstparent"), + self.get_parents( + [self.get_simple_key(b"noeolbase"), self.get_simple_key(b"noeol")] + ), + [b"line"], + ) # two identical eol texts - files.add_lines(self.get_simple_key(b'noeoldup'), - self.get_parents([self.get_simple_key(b'noeol')]), [b'line']) - next_parent = self.get_simple_key(b'base') - text_name = b'chain1-' - text = [b'line\n'] + files.add_lines( + self.get_simple_key(b"noeoldup"), + self.get_parents([self.get_simple_key(b"noeol")]), + [b"line"], + ) + next_parent = self.get_simple_key(b"base") + text_name = b"chain1-" + text = [b"line\n"] for depth in range(26): - new_version = self.get_simple_key(text_name + b'%d' % depth) - text = text + [b'line\n'] + new_version = self.get_simple_key(text_name + b"%d" % depth) + text = text + [b"line\n"] files.add_lines(new_version, self.get_parents([next_parent]), text) next_parent = new_version - next_parent = self.get_simple_key(b'base') - text_name = b'chain2-' - text = [b'line\n'] + next_parent = self.get_simple_key(b"base") + text_name = b"chain2-" + text = [b"line\n"] for depth in range(26): - new_version = self.get_simple_key(text_name + b'%d' % depth) - text = text + [b'line\n'] + new_version = self.get_simple_key(text_name + b"%d" % depth) + text = text + [b"line\n"] files.add_lines(new_version, self.get_parents([next_parent]), text) next_parent = new_version - target = self.get_versionedfiles('target') + target = self.get_versionedfiles("target") for key in multiparent.topo_iter_keys(files, files.keys()): mpdiff = files.make_mpdiffs([key])[0] parents = files.get_parent_map([key])[key] or [] - target.add_mpdiffs( - [(key, parents, files.get_sha1s([key])[key], mpdiff)]) + target.add_mpdiffs([(key, parents, files.get_sha1s([key])[key], mpdiff)]) self.assertEqualDiff( - next(files.get_record_stream([key], 'unordered', - True)).get_bytes_as('fulltext'), - next(target.get_record_stream([key], 'unordered', - True)).get_bytes_as('fulltext') - ) + next(files.get_record_stream([key], "unordered", True)).get_bytes_as( + "fulltext" + ), + next(target.get_record_stream([key], "unordered", True)).get_bytes_as( + "fulltext" + ), + ) def test_keys(self): # While use is discouraged, versions() is still needed by aspects of @@ -2708,9 +3021,12 @@ def test_keys(self): files = self.get_versionedfiles() self.assertEqual(set(), set(files.keys())) if self.key_length == 1: - key = (b'foo',) + key = (b"foo",) else: - key = (b'foo', b'bar',) + key = ( + b"foo", + b"bar", + ) files.add_lines(key, (), []) self.assertEqual({key}, set(files.keys())) @@ -2729,36 +3045,35 @@ def setUp(self): super().setUp() self._lines = {} self._parent_map = {} - self.texts = VirtualVersionedFiles(self._get_parent_map, - self._lines.get) + self.texts = VirtualVersionedFiles(self._get_parent_map, self._lines.get) def test_add_lines(self): - self.assertRaises(NotImplementedError, - self.texts.add_lines, b"foo", [], []) + self.assertRaises(NotImplementedError, self.texts.add_lines, b"foo", [], []) def test_add_mpdiffs(self): - self.assertRaises(NotImplementedError, - self.texts.add_mpdiffs, []) + self.assertRaises(NotImplementedError, self.texts.add_mpdiffs, []) def test_check_noerrors(self): self.texts.check() def test_insert_record_stream(self): - self.assertRaises(NotImplementedError, 
self.texts.insert_record_stream, - []) + self.assertRaises(NotImplementedError, self.texts.insert_record_stream, []) def test_get_sha1s_nonexistent(self): self.assertEqual({}, self.texts.get_sha1s([(b"NONEXISTENT",)])) def test_get_sha1s(self): self._lines[b"key"] = [b"dataline1", b"dataline2"] - self.assertEqual({(b"key",): osutils.sha_strings(self._lines[b"key"])}, - self.texts.get_sha1s([(b"key",)])) + self.assertEqual( + {(b"key",): osutils.sha_strings(self._lines[b"key"])}, + self.texts.get_sha1s([(b"key",)]), + ) def test_get_parent_map(self): self._parent_map = {b"G": (b"A", b"B")} - self.assertEqual({(b"G",): ((b"A",), (b"B",))}, - self.texts.get_parent_map([(b"G",), (b"L",)])) + self.assertEqual( + {(b"G",): ((b"A",), (b"B",))}, self.texts.get_parent_map([(b"G",), (b"L",)]) + ) def test_get_record_stream(self): self._lines[b"A"] = [b"FOO", b"BAR"] @@ -2778,21 +3093,21 @@ def test_iter_lines_added_or_present_in_keys(self): self._lines[b"B"] = [b"HEY"] self._lines[b"C"] = [b"Alberta"] it = self.texts.iter_lines_added_or_present_in_keys([(b"A",), (b"B",)]) - self.assertEqual(sorted([(b"FOO", b"A"), (b"BAR", b"A"), (b"HEY", b"B")]), - sorted(it)) + self.assertEqual( + sorted([(b"FOO", b"A"), (b"BAR", b"A"), (b"HEY", b"B")]), sorted(it) + ) class TestOrderingVersionedFilesDecorator(TestCaseWithMemoryTransport): - def get_ordering_vf(self, key_priority): - builder = self.make_branch_builder('test') + builder = self.make_branch_builder("test") builder.start_series() - builder.build_snapshot(None, [ - ('add', ('', b'TREE_ROOT', 'directory', None))], - revision_id=b'A') - builder.build_snapshot([b'A'], [], revision_id=b'B') - builder.build_snapshot([b'B'], [], revision_id=b'C') - builder.build_snapshot([b'C'], [], revision_id=b'D') + builder.build_snapshot( + None, [("add", ("", b"TREE_ROOT", "directory", None))], revision_id=b"A" + ) + builder.build_snapshot([b"A"], [], revision_id=b"B") + builder.build_snapshot([b"B"], [], revision_id=b"C") + builder.build_snapshot([b"C"], [], revision_id=b"D") builder.finish_series() b = builder.get_branch() b.lock_read() @@ -2805,37 +3120,35 @@ def test_get_empty(self): self.assertEqual([], vf.calls) def test_get_record_stream_topological(self): - vf = self.get_ordering_vf( - {(b'A',): 3, (b'B',): 2, (b'C',): 4, (b'D',): 1}) - request_keys = [(b'B',), (b'C',), (b'D',), (b'A',)] - keys = [r.key for r in vf.get_record_stream(request_keys, - 'topological', False)] + vf = self.get_ordering_vf({(b"A",): 3, (b"B",): 2, (b"C",): 4, (b"D",): 1}) + request_keys = [(b"B",), (b"C",), (b"D",), (b"A",)] + keys = [r.key for r in vf.get_record_stream(request_keys, "topological", False)] # We should have gotten the keys in topological order - self.assertEqual([(b'A',), (b'B',), (b'C',), (b'D',)], keys) + self.assertEqual([(b"A",), (b"B",), (b"C",), (b"D",)], keys) # And recorded that the request was made - self.assertEqual([('get_record_stream', request_keys, 'topological', - False)], vf.calls) + self.assertEqual( + [("get_record_stream", request_keys, "topological", False)], vf.calls + ) def test_get_record_stream_ordered(self): - vf = self.get_ordering_vf( - {(b'A',): 3, (b'B',): 2, (b'C',): 4, (b'D',): 1}) - request_keys = [(b'B',), (b'C',), (b'D',), (b'A',)] - keys = [r.key for r in vf.get_record_stream(request_keys, - 'unordered', False)] + vf = self.get_ordering_vf({(b"A",): 3, (b"B",): 2, (b"C",): 4, (b"D",): 1}) + request_keys = [(b"B",), (b"C",), (b"D",), (b"A",)] + keys = [r.key for r in vf.get_record_stream(request_keys, "unordered", False)] # They should 
be returned based on their priority - self.assertEqual([(b'D',), (b'B',), (b'A',), (b'C',)], keys) + self.assertEqual([(b"D",), (b"B",), (b"A",), (b"C",)], keys) # And the request recorded - self.assertEqual([('get_record_stream', request_keys, 'unordered', - False)], vf.calls) + self.assertEqual( + [("get_record_stream", request_keys, "unordered", False)], vf.calls + ) def test_get_record_stream_implicit_order(self): - vf = self.get_ordering_vf({(b'B',): 2, (b'D',): 1}) - request_keys = [(b'B',), (b'C',), (b'D',), (b'A',)] - keys = [r.key for r in vf.get_record_stream(request_keys, - 'unordered', False)] + vf = self.get_ordering_vf({(b"B",): 2, (b"D",): 1}) + request_keys = [(b"B",), (b"C",), (b"D",), (b"A",)] + keys = [r.key for r in vf.get_record_stream(request_keys, "unordered", False)] # A and C are not in the map, so they get sorted to the front. A comes # before C alphabetically, so it comes back first - self.assertEqual([(b'A',), (b'C',), (b'D',), (b'B',)], keys) + self.assertEqual([(b"A",), (b"C",), (b"D",), (b"B",)], keys) # And the request recorded - self.assertEqual([('get_record_stream', request_keys, 'unordered', - False)], vf.calls) + self.assertEqual( + [("get_record_stream", request_keys, "unordered", False)], vf.calls + ) diff --git a/breezy/bzr/tests/test__btree_serializer.py b/breezy/bzr/tests/test__btree_serializer.py index ff14eb34a6..1208f0a570 100644 --- a/breezy/bzr/tests/test__btree_serializer.py +++ b/breezy/bzr/tests/test__btree_serializer.py @@ -25,7 +25,6 @@ class TestBtreeSerializer(tests.TestCase): - _test_needs_features = [compiled_btreeparser_feature] @property @@ -34,21 +33,23 @@ def module(self): class TestHexAndUnhex(TestBtreeSerializer): - def assertHexlify(self, as_binary): - self.assertEqual(binascii.hexlify(as_binary), - self.module._py_hexlify(as_binary)) + self.assertEqual( + binascii.hexlify(as_binary), self.module._py_hexlify(as_binary) + ) def assertUnhexlify(self, as_hex): ba_unhex = binascii.unhexlify(as_hex) mod_unhex = self.module._py_unhexlify(as_hex) if ba_unhex != mod_unhex: if mod_unhex is None: - mod_hex = b'' + mod_hex = b"" else: mod_hex = binascii.hexlify(mod_unhex) - self.fail('_py_unhexlify returned a different answer' - f' from binascii:\n {binascii.hexlify(ba_unhex)!r}\n != {mod_hex!r}') + self.fail( + "_py_unhexlify returned a different answer" + f" from binascii:\n {binascii.hexlify(ba_unhex)!r}\n != {mod_hex!r}" + ) def assertFailUnhexlify(self, as_hex): # Invalid hex content @@ -57,33 +58,32 @@ def assertFailUnhexlify(self, as_hex): def test_to_hex(self): raw_bytes = bytes(range(256)) for i in range(0, 240, 20): - self.assertHexlify(raw_bytes[i:i + 20]) + self.assertHexlify(raw_bytes[i : i + 20]) self.assertHexlify(raw_bytes[240:] + raw_bytes[0:4]) def test_from_hex(self): - self.assertUnhexlify(b'0123456789abcdef0123456789abcdef01234567') - self.assertUnhexlify(b'123456789abcdef0123456789abcdef012345678') - self.assertUnhexlify(b'0123456789ABCDEF0123456789ABCDEF01234567') - self.assertUnhexlify(b'123456789ABCDEF0123456789ABCDEF012345678') + self.assertUnhexlify(b"0123456789abcdef0123456789abcdef01234567") + self.assertUnhexlify(b"123456789abcdef0123456789abcdef012345678") + self.assertUnhexlify(b"0123456789ABCDEF0123456789ABCDEF01234567") + self.assertUnhexlify(b"123456789ABCDEF0123456789ABCDEF012345678") hex_chars = binascii.hexlify(bytes(range(256))) for i in range(0, 480, 40): - self.assertUnhexlify(hex_chars[i:i + 40]) + self.assertUnhexlify(hex_chars[i : i + 40]) self.assertUnhexlify(hex_chars[480:] + hex_chars[0:8]) 
def test_from_invalid_hex(self): - self.assertFailUnhexlify(b'123456789012345678901234567890123456789X') - self.assertFailUnhexlify(b'12345678901234567890123456789012345678X9') + self.assertFailUnhexlify(b"123456789012345678901234567890123456789X") + self.assertFailUnhexlify(b"12345678901234567890123456789012345678X9") def test_bad_argument(self): - self.assertRaises(ValueError, self.module._py_unhexlify, '1a') - self.assertRaises(ValueError, self.module._py_unhexlify, b'1b') + self.assertRaises(ValueError, self.module._py_unhexlify, "1a") + self.assertRaises(ValueError, self.module._py_unhexlify, b"1b") -_hex_form = b'123456789012345678901234567890abcdefabcd' +_hex_form = b"123456789012345678901234567890abcdefabcd" class Test_KeyToSha1(TestBtreeSerializer): - def assertKeyToSha1(self, expected, key): if expected is None: expected_bin = None @@ -93,14 +93,14 @@ def assertKeyToSha1(self, expected, key): if expected_bin != actual_sha1: if actual_sha1 is not None: binascii.hexlify(actual_sha1) - self.fail(f'_key_to_sha1 returned:\n {actual_sha1}\n != {expected}') + self.fail(f"_key_to_sha1 returned:\n {actual_sha1}\n != {expected}") def test_simple(self): - self.assertKeyToSha1(_hex_form, (b'sha1:' + _hex_form,)) + self.assertKeyToSha1(_hex_form, (b"sha1:" + _hex_form,)) def test_invalid_not_tuple(self): self.assertKeyToSha1(None, _hex_form) - self.assertKeyToSha1(None, b'sha1:' + _hex_form) + self.assertKeyToSha1(None, b"sha1:" + _hex_form) def test_invalid_empty(self): self.assertKeyToSha1(None, ()) @@ -111,19 +111,17 @@ def test_invalid_not_string(self): def test_invalid_not_sha1(self): self.assertKeyToSha1(None, (_hex_form,)) - self.assertKeyToSha1(None, (b'sha2:' + _hex_form,)) + self.assertKeyToSha1(None, (b"sha2:" + _hex_form,)) def test_invalid_not_hex(self): - self.assertKeyToSha1(None, - (b'sha1:abcdefghijklmnopqrstuvwxyz12345678901234',)) + self.assertKeyToSha1(None, (b"sha1:abcdefghijklmnopqrstuvwxyz12345678901234",)) class Test_Sha1ToKey(TestBtreeSerializer): - def assertSha1ToKey(self, hex_sha1): bin_sha1 = binascii.unhexlify(hex_sha1) key = self.module._py_sha1_to_key(bin_sha1) - self.assertEqual((b'sha1:' + hex_sha1,), key) + self.assertEqual((b"sha1:" + hex_sha1,), key) def test_simple(self): self.assertSha1ToKey(_hex_form) @@ -174,43 +172,47 @@ def test_simple(self): class TestGCCKHSHA1LeafNode(TestBtreeSerializer): - def assertInvalid(self, data): """Ensure that we get a proper error when trying to parse invalid bytes. 
(mostly this is testing that bad input doesn't cause us to segfault) """ self.assertRaises( - (ValueError, TypeError), self.module._parse_into_chk, data, 1, 0) + (ValueError, TypeError), self.module._parse_into_chk, data, 1, 0 + ) def test_non_bytes(self): - self.assertInvalid('type=leaf\n') + self.assertInvalid("type=leaf\n") def test_not_leaf(self): - self.assertInvalid(b'type=internal\n') + self.assertInvalid(b"type=internal\n") def test_empty_leaf(self): - leaf = self.module._parse_into_chk(b'type=leaf\n', 1, 0) + leaf = self.module._parse_into_chk(b"type=leaf\n", 1, 0) self.assertEqual(0, len(leaf)) self.assertEqual([], leaf.all_items()) self.assertEqual([], leaf.all_keys()) # It should allow any key to be queried - self.assertNotIn(('key',), leaf) + self.assertNotIn(("key",), leaf) def test_one_key_leaf(self): leaf = self.module._parse_into_chk(_one_key_content, 1, 0) self.assertEqual(1, len(leaf)) - sha_key = (b'sha1:' + _hex_form,) + sha_key = (b"sha1:" + _hex_form,) self.assertEqual([sha_key], leaf.all_keys()) - self.assertEqual([(sha_key, (b'1 2 3 4', ()))], leaf.all_items()) + self.assertEqual([(sha_key, (b"1 2 3 4", ()))], leaf.all_items()) self.assertIn(sha_key, leaf) def test_large_offsets(self): leaf = self.module._parse_into_chk(_large_offsets, 1, 0) - self.assertEqual([b'12345678901 1234567890 0 1', - b'2147483648 2147483647 0 1', - b'4294967296 4294967295 4294967294 1', - ], [x[1][0] for x in leaf.all_items()]) + self.assertEqual( + [ + b"12345678901 1234567890 0 1", + b"2147483648 2147483647 0 1", + b"4294967296 4294967295 4294967294 1", + ], + [x[1][0] for x in leaf.all_items()], + ) def test_many_key_leaf(self): leaf = self.module._parse_into_chk(_multi_key_content, 1, 0) @@ -218,7 +220,7 @@ def test_many_key_leaf(self): all_keys = leaf.all_keys() self.assertEqual(8, len(leaf.all_keys())) for idx, key in enumerate(all_keys): - self.assertEqual(b'%d' % idx, leaf[key][0].split()[0]) + self.assertEqual(b"%d" % idx, leaf[key][0].split()[0]) def test_common_shift(self): # The keys were deliberately chosen so that the first 5 bits all @@ -231,12 +233,11 @@ def test_common_shift(self): # (defined as the 8-bits that come after the common prefix) lst = [1, 13, 28, 180, 190, 193, 210, 239] offsets = leaf._get_offsets() - self.assertEqual([bisect.bisect_left(lst, x) for x in range(0, 257)], - offsets) + self.assertEqual([bisect.bisect_left(lst, x) for x in range(0, 257)], offsets) for idx, val in enumerate(lst): self.assertEqual(idx, offsets[val]) for idx, key in enumerate(leaf.all_keys()): - self.assertEqual(b'%d' % idx, leaf[key][0].split()[0]) + self.assertEqual(b"%d" % idx, leaf[key][0].split()[0]) def test_multi_key_same_offset(self): # there is no common prefix, though there are some common bits @@ -245,12 +246,11 @@ def test_multi_key_same_offset(self): offsets = leaf._get_offsets() # The interesting byte is just the first 8-bits of the key lst = [8, 200, 205, 205, 205, 205, 206, 206] - self.assertEqual([bisect.bisect_left(lst, x) for x in range(0, 257)], - offsets) + self.assertEqual([bisect.bisect_left(lst, x) for x in range(0, 257)], offsets) for val in lst: self.assertEqual(lst.index(val), offsets[val]) for idx, key in enumerate(leaf.all_keys()): - self.assertEqual(b'%d' % idx, leaf[key][0].split()[0]) + self.assertEqual(b"%d" % idx, leaf[key][0].split()[0]) def test_all_common_prefix(self): # The first 32 bits of all hashes are the same. 
This is going to be @@ -259,23 +259,22 @@ def test_all_common_prefix(self): self.assertEqual(0, leaf.common_shift) lst = [0x78] * 8 offsets = leaf._get_offsets() - self.assertEqual([bisect.bisect_left(lst, x) for x in range(0, 257)], - offsets) + self.assertEqual([bisect.bisect_left(lst, x) for x in range(0, 257)], offsets) for val in lst: self.assertEqual(lst.index(val), offsets[val]) for idx, key in enumerate(leaf.all_keys()): - self.assertEqual(b'%d' % idx, leaf[key][0].split()[0]) + self.assertEqual(b"%d" % idx, leaf[key][0].split()[0]) def test_many_entries(self): # Again, this is almost impossible, but we should still work # It would be hard to fit more that 120 entries in a 4k page, much less # more than 256 of them. but hey, weird stuff happens sometimes - lines = [b'type=leaf\n'] + lines = [b"type=leaf\n"] for i in range(500): - key_str = b'sha1:%04x%s' % (i, _hex_form[:36]) + key_str = b"sha1:%04x%s" % (i, _hex_form[:36]) key = (key_str,) - lines.append(b'%s\0\0%d %d %d %d\n' % (key_str, i, i, i, i)) - data = b''.join(lines) + lines.append(b"%s\0\0%d %d %d %d\n" % (key_str, i, i, i, i)) + data = b"".join(lines) leaf = self.module._parse_into_chk(data, 1, 0) self.assertEqual(24 - 7, leaf.common_shift) offsets = leaf._get_offsets() @@ -286,18 +285,17 @@ def test_many_entries(self): # We truncate because offsets is an unsigned char. So the bisection # will just say 'greater than the last one' for all the rest lst = lst[:255] - self.assertEqual([bisect.bisect_left(lst, x) for x in range(0, 257)], - offsets) + self.assertEqual([bisect.bisect_left(lst, x) for x in range(0, 257)], offsets) for val in lst: self.assertEqual(lst.index(val), offsets[val]) for idx, key in enumerate(leaf.all_keys()): - self.assertEqual(b'%d' % idx, leaf[key][0].split()[0]) + self.assertEqual(b"%d" % idx, leaf[key][0].split()[0]) def test__sizeof__(self): # We can't use the exact numbers because of platform variations, etc. # But what we really care about is that it does get bigger with more # content. 
- leaf0 = self.module._parse_into_chk(b'type=leaf\n', 1, 0) + leaf0 = self.module._parse_into_chk(b"type=leaf\n", 1, 0) leaf1 = self.module._parse_into_chk(_one_key_content, 1, 0) leafN = self.module._parse_into_chk(_multi_key_content, 1, 0) sizeof_1 = leaf1.__sizeof__() - leaf0.__sizeof__() diff --git a/breezy/bzr/tests/test__chk_map.py b/breezy/bzr/tests/test__chk_map.py index 8d48b40b59..e7a1396d71 100644 --- a/breezy/bzr/tests/test__chk_map.py +++ b/breezy/bzr/tests/test__chk_map.py @@ -24,33 +24,41 @@ def load_tests(loader, standard_tests, pattern): - suite, _ = tests.permute_tests_for_extension(standard_tests, loader, - 'breezy.bzr._chk_map_py', 'breezy.bzr._chk_map_pyx') + suite, _ = tests.permute_tests_for_extension( + standard_tests, loader, "breezy.bzr._chk_map_py", "breezy.bzr._chk_map_pyx" + ) return suite class TestDeserialiseLeafNode(tests.TestCase): - module = None def assertDeserialiseErrors(self, text): - self.assertRaises((ValueError, IndexError), - self.module._deserialise_leaf_node, text, b'not-a-real-sha') + self.assertRaises( + (ValueError, IndexError), + self.module._deserialise_leaf_node, + text, + b"not-a-real-sha", + ) def test_raises_on_non_leaf(self): - self.assertDeserialiseErrors(b'') - self.assertDeserialiseErrors(b'short\n') - self.assertDeserialiseErrors(b'chknotleaf:\n') - self.assertDeserialiseErrors(b'chkleaf:x\n') - self.assertDeserialiseErrors(b'chkleaf:\n') - self.assertDeserialiseErrors(b'chkleaf:\nnotint\n') - self.assertDeserialiseErrors(b'chkleaf:\n10\n') - self.assertDeserialiseErrors(b'chkleaf:\n10\n256\n') - self.assertDeserialiseErrors(b'chkleaf:\n10\n256\n10\n') + self.assertDeserialiseErrors(b"") + self.assertDeserialiseErrors(b"short\n") + self.assertDeserialiseErrors(b"chknotleaf:\n") + self.assertDeserialiseErrors(b"chkleaf:x\n") + self.assertDeserialiseErrors(b"chkleaf:\n") + self.assertDeserialiseErrors(b"chkleaf:\nnotint\n") + self.assertDeserialiseErrors(b"chkleaf:\n10\n") + self.assertDeserialiseErrors(b"chkleaf:\n10\n256\n") + self.assertDeserialiseErrors(b"chkleaf:\n10\n256\n10\n") def test_deserialise_empty(self): node = self.module._deserialise_leaf_node( - b"chkleaf:\n10\n1\n0\n\n", stuple(b"sha1:1234",)) + b"chkleaf:\n10\n1\n0\n\n", + stuple( + b"sha1:1234", + ), + ) self.assertEqual(0, len(node)) self.assertEqual(10, node.maximum_size) self.assertEqual((b"sha1:1234",), node.key()) @@ -61,145 +69,184 @@ def test_deserialise_empty(self): def test_deserialise_items(self): node = self.module._deserialise_leaf_node( b"chkleaf:\n0\n1\n2\n\nfoo bar\x001\nbaz\nquux\x001\nblarh\n", - (b"sha1:1234",)) + (b"sha1:1234",), + ) self.assertEqual(2, len(node)) - self.assertEqual([((b"foo bar",), b"baz"), ((b"quux",), b"blarh")], - sorted(node.iteritems(None))) + self.assertEqual( + [((b"foo bar",), b"baz"), ((b"quux",), b"blarh")], + sorted(node.iteritems(None)), + ) def test_deserialise_item_with_null_width_1(self): node = self.module._deserialise_leaf_node( b"chkleaf:\n0\n1\n2\n\nfoo\x001\nbar\x00baz\nquux\x001\nblarh\n", - (b"sha1:1234",)) + (b"sha1:1234",), + ) self.assertEqual(2, len(node)) - self.assertEqual([((b"foo",), b"bar\x00baz"), ((b"quux",), b"blarh")], - sorted(node.iteritems(None))) + self.assertEqual( + [((b"foo",), b"bar\x00baz"), ((b"quux",), b"blarh")], + sorted(node.iteritems(None)), + ) def test_deserialise_item_with_null_width_2(self): node = self.module._deserialise_leaf_node( b"chkleaf:\n0\n2\n2\n\nfoo\x001\x001\nbar\x00baz\n" b"quux\x00\x001\nblarh\n", - (b"sha1:1234",)) + (b"sha1:1234",), + ) self.assertEqual(2, 
len(node)) - self.assertEqual([((b"foo", b"1"), b"bar\x00baz"), ((b"quux", b""), b"blarh")], - sorted(node.iteritems(None))) + self.assertEqual( + [((b"foo", b"1"), b"bar\x00baz"), ((b"quux", b""), b"blarh")], + sorted(node.iteritems(None)), + ) def test_iteritems_selected_one_of_two_items(self): node = self.module._deserialise_leaf_node( b"chkleaf:\n0\n1\n2\n\nfoo bar\x001\nbaz\nquux\x001\nblarh\n", - (b"sha1:1234",)) + (b"sha1:1234",), + ) self.assertEqual(2, len(node)) - self.assertEqual([((b"quux",), b"blarh")], - sorted(node.iteritems(None, [(b"quux",), (b"qaz",)]))) + self.assertEqual( + [((b"quux",), b"blarh")], + sorted(node.iteritems(None, [(b"quux",), (b"qaz",)])), + ) def test_deserialise_item_with_common_prefix(self): node = self.module._deserialise_leaf_node( b"chkleaf:\n0\n2\n2\nfoo\x00\n1\x001\nbar\x00baz\n2\x001\nblarh\n", - (b"sha1:1234",)) + (b"sha1:1234",), + ) self.assertEqual(2, len(node)) - self.assertEqual([((b"foo", b"1"), b"bar\x00baz"), ((b"foo", b"2"), b"blarh")], - sorted(node.iteritems(None))) + self.assertEqual( + [((b"foo", b"1"), b"bar\x00baz"), ((b"foo", b"2"), b"blarh")], + sorted(node.iteritems(None)), + ) self.assertIs(chk_map._unknown, node._search_prefix) - self.assertEqual(b'foo\x00', node._common_serialised_prefix) + self.assertEqual(b"foo\x00", node._common_serialised_prefix) def test_deserialise_multi_line(self): node = self.module._deserialise_leaf_node( b"chkleaf:\n0\n2\n2\nfoo\x00\n1\x002\nbar\nbaz\n2\x002\nblarh\n\n", - (b"sha1:1234",)) + (b"sha1:1234",), + ) self.assertEqual(2, len(node)) - self.assertEqual([((b"foo", b"1"), b"bar\nbaz"), - ((b"foo", b"2"), b"blarh\n"), - ], sorted(node.iteritems(None))) + self.assertEqual( + [ + ((b"foo", b"1"), b"bar\nbaz"), + ((b"foo", b"2"), b"blarh\n"), + ], + sorted(node.iteritems(None)), + ) self.assertIs(chk_map._unknown, node._search_prefix) - self.assertEqual(b'foo\x00', node._common_serialised_prefix) + self.assertEqual(b"foo\x00", node._common_serialised_prefix) def test_key_after_map(self): node = self.module._deserialise_leaf_node( - b"chkleaf:\n10\n1\n0\n\n", (b"sha1:1234",)) + b"chkleaf:\n10\n1\n0\n\n", (b"sha1:1234",) + ) node.map(None, (b"foo bar",), b"baz quux") self.assertEqual(None, node.key()) def test_key_after_unmap(self): node = self.module._deserialise_leaf_node( b"chkleaf:\n0\n1\n2\n\nfoo bar\x001\nbaz\nquux\x001\nblarh\n", - (b"sha1:1234",)) + (b"sha1:1234",), + ) node.unmap(None, (b"foo bar",)) self.assertEqual(None, node.key()) class TestDeserialiseInternalNode(tests.TestCase): - module = None def assertDeserialiseErrors(self, text): - self.assertRaises((ValueError, IndexError), - self.module._deserialise_internal_node, text, - stuple(b'not-a-real-sha',)) + self.assertRaises( + (ValueError, IndexError), + self.module._deserialise_internal_node, + text, + stuple( + b"not-a-real-sha", + ), + ) def test_raises_on_non_internal(self): - self.assertDeserialiseErrors(b'') - self.assertDeserialiseErrors(b'short\n') - self.assertDeserialiseErrors(b'chknotnode:\n') - self.assertDeserialiseErrors(b'chknode:x\n') - self.assertDeserialiseErrors(b'chknode:\n') - self.assertDeserialiseErrors(b'chknode:\nnotint\n') - self.assertDeserialiseErrors(b'chknode:\n10\n') - self.assertDeserialiseErrors(b'chknode:\n10\n256\n') - self.assertDeserialiseErrors(b'chknode:\n10\n256\n10\n') + self.assertDeserialiseErrors(b"") + self.assertDeserialiseErrors(b"short\n") + self.assertDeserialiseErrors(b"chknotnode:\n") + self.assertDeserialiseErrors(b"chknode:x\n") + self.assertDeserialiseErrors(b"chknode:\n") + 
self.assertDeserialiseErrors(b"chknode:\nnotint\n") + self.assertDeserialiseErrors(b"chknode:\n10\n") + self.assertDeserialiseErrors(b"chknode:\n10\n256\n") + self.assertDeserialiseErrors(b"chknode:\n10\n256\n10\n") # no trailing newline - self.assertDeserialiseErrors(b'chknode:\n10\n256\n0\n1\nfo') + self.assertDeserialiseErrors(b"chknode:\n10\n256\n0\n1\nfo") def test_deserialise_one(self): node = self.module._deserialise_internal_node( - b"chknode:\n10\n1\n1\n\na\x00sha1:abcd\n", stuple(b'sha1:1234',)) + b"chknode:\n10\n1\n1\n\na\x00sha1:abcd\n", + stuple( + b"sha1:1234", + ), + ) self.assertIsInstance(node, chk_map.InternalNode) self.assertEqual(1, len(node)) self.assertEqual(10, node.maximum_size) self.assertEqual((b"sha1:1234",), node.key()) - self.assertEqual(b'', node._search_prefix) - self.assertEqual({b'a': (b'sha1:abcd',)}, node._items) + self.assertEqual(b"", node._search_prefix) + self.assertEqual({b"a": (b"sha1:abcd",)}, node._items) def test_deserialise_with_prefix(self): node = self.module._deserialise_internal_node( b"chknode:\n10\n1\n1\npref\na\x00sha1:abcd\n", - stuple(b'sha1:1234',)) + stuple( + b"sha1:1234", + ), + ) self.assertIsInstance(node, chk_map.InternalNode) self.assertEqual(1, len(node)) self.assertEqual(10, node.maximum_size) self.assertEqual((b"sha1:1234",), node.key()) - self.assertEqual(b'pref', node._search_prefix) - self.assertEqual({b'prefa': (b'sha1:abcd',)}, node._items) + self.assertEqual(b"pref", node._search_prefix) + self.assertEqual({b"prefa": (b"sha1:abcd",)}, node._items) node = self.module._deserialise_internal_node( b"chknode:\n10\n1\n1\npref\n\x00sha1:abcd\n", - stuple(b'sha1:1234',)) + stuple( + b"sha1:1234", + ), + ) self.assertIsInstance(node, chk_map.InternalNode) self.assertEqual(1, len(node)) self.assertEqual(10, node.maximum_size) self.assertEqual((b"sha1:1234",), node.key()) - self.assertEqual(b'pref', node._search_prefix) - self.assertEqual({b'pref': (b'sha1:abcd',)}, node._items) + self.assertEqual(b"pref", node._search_prefix) + self.assertEqual({b"pref": (b"sha1:abcd",)}, node._items) def test_deserialise_pref_with_null(self): node = self.module._deserialise_internal_node( b"chknode:\n10\n1\n1\npref\x00fo\n\x00sha1:abcd\n", - stuple(b'sha1:1234',)) + stuple( + b"sha1:1234", + ), + ) self.assertIsInstance(node, chk_map.InternalNode) self.assertEqual(1, len(node)) self.assertEqual(10, node.maximum_size) self.assertEqual((b"sha1:1234",), node.key()) - self.assertEqual(b'pref\x00fo', node._search_prefix) - self.assertEqual({b'pref\x00fo': (b'sha1:abcd',)}, node._items) + self.assertEqual(b"pref\x00fo", node._search_prefix) + self.assertEqual({b"pref\x00fo": (b"sha1:abcd",)}, node._items) def test_deserialise_with_null_pref(self): node = self.module._deserialise_internal_node( b"chknode:\n10\n1\n1\npref\x00fo\n\x00\x00sha1:abcd\n", - stuple(b'sha1:1234',)) + stuple( + b"sha1:1234", + ), + ) self.assertIsInstance(node, chk_map.InternalNode) self.assertEqual(1, len(node)) self.assertEqual(10, node.maximum_size) self.assertEqual((b"sha1:1234",), node.key()) - self.assertEqual(b'pref\x00fo', node._search_prefix) - self.assertEqual({b'pref\x00fo\x00': (b'sha1:abcd',)}, node._items) - - - + self.assertEqual(b"pref\x00fo", node._search_prefix) + self.assertEqual({b"pref\x00fo\x00": (b"sha1:abcd",)}, node._items) diff --git a/breezy/bzr/tests/test__dirstate_helpers.py b/breezy/bzr/tests/test__dirstate_helpers.py index 52f7e117f7..0c3aef85ae 100644 --- a/breezy/bzr/tests/test__dirstate_helpers.py +++ 
b/breezy/bzr/tests/test__dirstate_helpers.py @@ -31,27 +31,27 @@ compiled_dirstate_helpers_feature = features.ModuleAvailableFeature( - 'breezy.bzr._dirstate_helpers_pyx') + "breezy.bzr._dirstate_helpers_pyx" +) # FIXME: we should also parametrize against SHA1Provider ! -ue_scenarios = [('dirstate_Python', - {'update_entry': dirstate.py_update_entry})] +ue_scenarios = [("dirstate_Python", {"update_entry": dirstate.py_update_entry})] if compiled_dirstate_helpers_feature.available(): update_entry = compiled_dirstate_helpers_feature.module.update_entry - ue_scenarios.append(('dirstate_Pyrex', {'update_entry': update_entry})) + ue_scenarios.append(("dirstate_Pyrex", {"update_entry": update_entry})) -pe_scenarios = [('dirstate_Python', - {'_process_entry': dirstate.ProcessEntryPython})] +pe_scenarios = [("dirstate_Python", {"_process_entry": dirstate.ProcessEntryPython})] if compiled_dirstate_helpers_feature.available(): process_entry = compiled_dirstate_helpers_feature.module.ProcessEntryC - pe_scenarios.append(('dirstate_Pyrex', {'_process_entry': process_entry})) + pe_scenarios.append(("dirstate_Pyrex", {"_process_entry": process_entry})) -helper_scenarios = [('dirstate_Python', {'helpers': _dirstate_helpers_py})] +helper_scenarios = [("dirstate_Python", {"helpers": _dirstate_helpers_py})] if compiled_dirstate_helpers_feature.available(): - helper_scenarios.append(('dirstate_Pyrex', - {'helpers': compiled_dirstate_helpers_feature.module})) + helper_scenarios.append( + ("dirstate_Pyrex", {"helpers": compiled_dirstate_helpers_feature.module}) + ) class TestBisectPathMixin: @@ -99,11 +99,13 @@ def assertBisect(self, paths, split_paths, path, exists=True): split_path = self.split_for_dirblocks([path])[0] bisect_func, offset = self.get_bisect() bisect_split_idx = bisect_func(split_paths, split_path) - self.assertEqual(bisect_split_idx, bisect_path_idx, - '{} disagreed. {} != {}' - ' for key {!r}'.format(bisect_path.__name__, - bisect_split_idx, bisect_path_idx, path) - ) + self.assertEqual( + bisect_split_idx, + bisect_path_idx, + "{} disagreed. 
{} != {}" " for key {!r}".format( + bisect_path.__name__, bisect_split_idx, bisect_path_idx, path + ), + ) if exists: self.assertEqual(path, paths[bisect_path_idx + offset]) @@ -111,25 +113,25 @@ def split_for_dirblocks(self, paths): dir_split_paths = [] for path in paths: dirname, basename = os.path.split(path) - dir_split_paths.append((dirname.split(b'/'), basename)) + dir_split_paths.append((dirname.split(b"/"), basename)) dir_split_paths.sort() return dir_split_paths def test_simple(self): """In the simple case it works just like bisect_left.""" - paths = [b'', b'a', b'b', b'c', b'd'] + paths = [b"", b"a", b"b", b"c", b"d"] split_paths = self.split_for_dirblocks(paths) for path in paths: self.assertBisect(paths, split_paths, path, exists=True) - self.assertBisect(paths, split_paths, b'_', exists=False) - self.assertBisect(paths, split_paths, b'aa', exists=False) - self.assertBisect(paths, split_paths, b'bb', exists=False) - self.assertBisect(paths, split_paths, b'cc', exists=False) - self.assertBisect(paths, split_paths, b'dd', exists=False) - self.assertBisect(paths, split_paths, b'a/a', exists=False) - self.assertBisect(paths, split_paths, b'b/b', exists=False) - self.assertBisect(paths, split_paths, b'c/c', exists=False) - self.assertBisect(paths, split_paths, b'd/d', exists=False) + self.assertBisect(paths, split_paths, b"_", exists=False) + self.assertBisect(paths, split_paths, b"aa", exists=False) + self.assertBisect(paths, split_paths, b"bb", exists=False) + self.assertBisect(paths, split_paths, b"cc", exists=False) + self.assertBisect(paths, split_paths, b"dd", exists=False) + self.assertBisect(paths, split_paths, b"a/a", exists=False) + self.assertBisect(paths, split_paths, b"b/b", exists=False) + self.assertBisect(paths, split_paths, b"c/c", exists=False) + self.assertBisect(paths, split_paths, b"d/d", exists=False) def test_involved(self): """This is where bisect_path_* diverges slightly.""" @@ -167,40 +169,53 @@ def test_involved(self): # So all the root-directory paths, then all the # first sub directory, etc. 
paths = [ # content of '/' - b'', b'a', b'a-a', b'a-z', b'a=a', b'a=z', - # content of 'a/' - b'a/a', b'a/a-a', b'a/a-z', - b'a/a=a', b'a/a=z', - b'a/z', b'a/z-a', b'a/z-z', - b'a/z=a', b'a/z=z', - # content of 'a/a/' - b'a/a/a', b'a/a/z', - # content of 'a/a-a' - b'a/a-a/a', - # content of 'a/a-z' - b'a/a-z/z', - # content of 'a/a=a' - b'a/a=a/a', - # content of 'a/a=z' - b'a/a=z/z', - # content of 'a/z/' - b'a/z/a', b'a/z/z', - # content of 'a-a' - b'a-a/a', - # content of 'a-z' - b'a-z/z', - # content of 'a=a' - b'a=a/a', - # content of 'a=z' - b'a=z/z', - ] + b"", + b"a", + b"a-a", + b"a-z", + b"a=a", + b"a=z", + # content of 'a/' + b"a/a", + b"a/a-a", + b"a/a-z", + b"a/a=a", + b"a/a=z", + b"a/z", + b"a/z-a", + b"a/z-z", + b"a/z=a", + b"a/z=z", + # content of 'a/a/' + b"a/a/a", + b"a/a/z", + # content of 'a/a-a' + b"a/a-a/a", + # content of 'a/a-z' + b"a/a-z/z", + # content of 'a/a=a' + b"a/a=a/a", + # content of 'a/a=z' + b"a/a=z/z", + # content of 'a/z/' + b"a/z/a", + b"a/z/z", + # content of 'a-a' + b"a-a/a", + # content of 'a-z' + b"a-z/z", + # content of 'a=a' + b"a=a/a", + # content of 'a=z' + b"a=z/z", + ] split_paths = self.split_for_dirblocks(paths) sorted_paths = [] for dir_parts, basename in split_paths: - if dir_parts == [b'']: + if dir_parts == [b""]: sorted_paths.append(basename) else: - sorted_paths.append(b'/'.join(dir_parts + [basename])) + sorted_paths.append(b"/".join(dir_parts + [basename])) self.assertEqual(sorted_paths, paths) @@ -213,6 +228,7 @@ class TestBisectPathLeft(tests.TestCase, TestBisectPathMixin): def get_bisect_path(self): from ..dirstate import bisect_path_left + return bisect_path_left def get_bisect(self): @@ -224,6 +240,7 @@ class TestBisectPathRight(tests.TestCase, TestBisectPathMixin): def get_bisect_path(self): from ..dirstate import bisect_path_right + return bisect_path_right def get_bisect(self): @@ -258,74 +275,73 @@ def assertCmpByDirs(self, expected, str1, str2): def test_cmp_empty(self): """Compare against the empty string.""" - self.assertCmpByDirs(0, b'', b'') - self.assertCmpByDirs(1, b'a', b'') - self.assertCmpByDirs(1, b'ab', b'') - self.assertCmpByDirs(1, b'abc', b'') - self.assertCmpByDirs(1, b'abcd', b'') - self.assertCmpByDirs(1, b'abcde', b'') - self.assertCmpByDirs(1, b'abcdef', b'') - self.assertCmpByDirs(1, b'abcdefg', b'') - self.assertCmpByDirs(1, b'abcdefgh', b'') - self.assertCmpByDirs(1, b'abcdefghi', b'') - self.assertCmpByDirs(1, b'test/ing/a/path/', b'') + self.assertCmpByDirs(0, b"", b"") + self.assertCmpByDirs(1, b"a", b"") + self.assertCmpByDirs(1, b"ab", b"") + self.assertCmpByDirs(1, b"abc", b"") + self.assertCmpByDirs(1, b"abcd", b"") + self.assertCmpByDirs(1, b"abcde", b"") + self.assertCmpByDirs(1, b"abcdef", b"") + self.assertCmpByDirs(1, b"abcdefg", b"") + self.assertCmpByDirs(1, b"abcdefgh", b"") + self.assertCmpByDirs(1, b"abcdefghi", b"") + self.assertCmpByDirs(1, b"test/ing/a/path/", b"") def test_cmp_same_str(self): """Compare the same string.""" - self.assertCmpByDirs(0, b'a', b'a') - self.assertCmpByDirs(0, b'ab', b'ab') - self.assertCmpByDirs(0, b'abc', b'abc') - self.assertCmpByDirs(0, b'abcd', b'abcd') - self.assertCmpByDirs(0, b'abcde', b'abcde') - self.assertCmpByDirs(0, b'abcdef', b'abcdef') - self.assertCmpByDirs(0, b'abcdefg', b'abcdefg') - self.assertCmpByDirs(0, b'abcdefgh', b'abcdefgh') - self.assertCmpByDirs(0, b'abcdefghi', b'abcdefghi') - self.assertCmpByDirs(0, b'testing a long string', - b'testing a long string') - self.assertCmpByDirs(0, b'x' * 10000, b'x' * 10000) - 
self.assertCmpByDirs(0, b'a/b', b'a/b') - self.assertCmpByDirs(0, b'a/b/c', b'a/b/c') - self.assertCmpByDirs(0, b'a/b/c/d', b'a/b/c/d') - self.assertCmpByDirs(0, b'a/b/c/d/e', b'a/b/c/d/e') + self.assertCmpByDirs(0, b"a", b"a") + self.assertCmpByDirs(0, b"ab", b"ab") + self.assertCmpByDirs(0, b"abc", b"abc") + self.assertCmpByDirs(0, b"abcd", b"abcd") + self.assertCmpByDirs(0, b"abcde", b"abcde") + self.assertCmpByDirs(0, b"abcdef", b"abcdef") + self.assertCmpByDirs(0, b"abcdefg", b"abcdefg") + self.assertCmpByDirs(0, b"abcdefgh", b"abcdefgh") + self.assertCmpByDirs(0, b"abcdefghi", b"abcdefghi") + self.assertCmpByDirs(0, b"testing a long string", b"testing a long string") + self.assertCmpByDirs(0, b"x" * 10000, b"x" * 10000) + self.assertCmpByDirs(0, b"a/b", b"a/b") + self.assertCmpByDirs(0, b"a/b/c", b"a/b/c") + self.assertCmpByDirs(0, b"a/b/c/d", b"a/b/c/d") + self.assertCmpByDirs(0, b"a/b/c/d/e", b"a/b/c/d/e") def test_simple_paths(self): """Compare strings that act like normal string comparison.""" - self.assertCmpByDirs(-1, b'a', b'b') - self.assertCmpByDirs(-1, b'aa', b'ab') - self.assertCmpByDirs(-1, b'ab', b'bb') - self.assertCmpByDirs(-1, b'aaa', b'aab') - self.assertCmpByDirs(-1, b'aab', b'abb') - self.assertCmpByDirs(-1, b'abb', b'bbb') - self.assertCmpByDirs(-1, b'aaaa', b'aaab') - self.assertCmpByDirs(-1, b'aaab', b'aabb') - self.assertCmpByDirs(-1, b'aabb', b'abbb') - self.assertCmpByDirs(-1, b'abbb', b'bbbb') - self.assertCmpByDirs(-1, b'aaaaa', b'aaaab') - self.assertCmpByDirs(-1, b'a/a', b'a/b') - self.assertCmpByDirs(-1, b'a/b', b'b/b') - self.assertCmpByDirs(-1, b'a/a/a', b'a/a/b') - self.assertCmpByDirs(-1, b'a/a/b', b'a/b/b') - self.assertCmpByDirs(-1, b'a/b/b', b'b/b/b') - self.assertCmpByDirs(-1, b'a/a/a/a', b'a/a/a/b') - self.assertCmpByDirs(-1, b'a/a/a/b', b'a/a/b/b') - self.assertCmpByDirs(-1, b'a/a/b/b', b'a/b/b/b') - self.assertCmpByDirs(-1, b'a/b/b/b', b'b/b/b/b') - self.assertCmpByDirs(-1, b'a/a/a/a/a', b'a/a/a/a/b') + self.assertCmpByDirs(-1, b"a", b"b") + self.assertCmpByDirs(-1, b"aa", b"ab") + self.assertCmpByDirs(-1, b"ab", b"bb") + self.assertCmpByDirs(-1, b"aaa", b"aab") + self.assertCmpByDirs(-1, b"aab", b"abb") + self.assertCmpByDirs(-1, b"abb", b"bbb") + self.assertCmpByDirs(-1, b"aaaa", b"aaab") + self.assertCmpByDirs(-1, b"aaab", b"aabb") + self.assertCmpByDirs(-1, b"aabb", b"abbb") + self.assertCmpByDirs(-1, b"abbb", b"bbbb") + self.assertCmpByDirs(-1, b"aaaaa", b"aaaab") + self.assertCmpByDirs(-1, b"a/a", b"a/b") + self.assertCmpByDirs(-1, b"a/b", b"b/b") + self.assertCmpByDirs(-1, b"a/a/a", b"a/a/b") + self.assertCmpByDirs(-1, b"a/a/b", b"a/b/b") + self.assertCmpByDirs(-1, b"a/b/b", b"b/b/b") + self.assertCmpByDirs(-1, b"a/a/a/a", b"a/a/a/b") + self.assertCmpByDirs(-1, b"a/a/a/b", b"a/a/b/b") + self.assertCmpByDirs(-1, b"a/a/b/b", b"a/b/b/b") + self.assertCmpByDirs(-1, b"a/b/b/b", b"b/b/b/b") + self.assertCmpByDirs(-1, b"a/a/a/a/a", b"a/a/a/a/b") def test_tricky_paths(self): - self.assertCmpByDirs(1, b'ab/cd/ef', b'ab/cc/ef') - self.assertCmpByDirs(1, b'ab/cd/ef', b'ab/c/ef') - self.assertCmpByDirs(-1, b'ab/cd/ef', b'ab/cd-ef') - self.assertCmpByDirs(-1, b'ab/cd', b'ab/cd-') - self.assertCmpByDirs(-1, b'ab/cd', b'ab-cd') + self.assertCmpByDirs(1, b"ab/cd/ef", b"ab/cc/ef") + self.assertCmpByDirs(1, b"ab/cd/ef", b"ab/c/ef") + self.assertCmpByDirs(-1, b"ab/cd/ef", b"ab/cd-ef") + self.assertCmpByDirs(-1, b"ab/cd", b"ab/cd-") + self.assertCmpByDirs(-1, b"ab/cd", b"ab-cd") def test_cmp_non_ascii(self): - self.assertCmpByDirs(-1, b'\xc2\xb5', 
b'\xc3\xa5') # u'\xb5', u'\xe5' - self.assertCmpByDirs(-1, b'a', b'\xc3\xa5') # u'a', u'\xe5' - self.assertCmpByDirs(-1, b'b', b'\xc2\xb5') # u'b', u'\xb5' - self.assertCmpByDirs(-1, b'a/b', b'a/\xc3\xa5') # u'a/b', u'a/\xe5' - self.assertCmpByDirs(-1, b'b/a', b'b/\xc2\xb5') # u'b/a', u'b/\xb5' + self.assertCmpByDirs(-1, b"\xc2\xb5", b"\xc3\xa5") # u'\xb5', u'\xe5' + self.assertCmpByDirs(-1, b"a", b"\xc3\xa5") # u'a', u'\xe5' + self.assertCmpByDirs(-1, b"b", b"\xc2\xb5") # u'b', u'\xb5' + self.assertCmpByDirs(-1, b"a/b", b"a/\xc3\xa5") # u'a/b', u'a/\xe5' + self.assertCmpByDirs(-1, b"b/a", b"b/\xc2\xb5") # u'b/a', u'b/\xb5' class TestLtPathByDirblock(tests.TestCase): @@ -341,6 +357,7 @@ class TestLtPathByDirblock(tests.TestCase): def get_lt_path_by_dirblock(self): """Get a specific implementation of lt_path_by_dirblock.""" from ..dirstate import lt_path_by_dirblock + return lt_path_by_dirblock def assertLtPathByDirblock(self, paths): @@ -351,102 +368,158 @@ def assertLtPathByDirblock(self, paths): :param paths: a sorted list of paths to compare """ + # First, make sure the paths being passed in are correct def _key(p): dirname, basename = os.path.split(p) - return dirname.split(b'/'), basename + return dirname.split(b"/"), basename + self.assertEqual(sorted(paths, key=_key), paths) lt_path_by_dirblock = self.get_lt_path_by_dirblock() for idx1, path1 in enumerate(paths): for idx2, path2 in enumerate(paths): lt_result = lt_path_by_dirblock(path1, path2) - self.assertEqual(idx1 < idx2, lt_result, - '{} did not state that {!r} < {!r}, lt={}'.format(lt_path_by_dirblock.__name__, - path1, path2, lt_result)) + self.assertEqual( + idx1 < idx2, + lt_result, + "{} did not state that {!r} < {!r}, lt={}".format( + lt_path_by_dirblock.__name__, path1, path2, lt_result + ), + ) def test_cmp_simple_paths(self): """Compare against the empty string.""" - self.assertLtPathByDirblock( - [b'', b'a', b'ab', b'abc', b'a/b/c', b'b/d/e']) - self.assertLtPathByDirblock([b'kl', b'ab/cd', b'ab/ef', b'gh/ij']) + self.assertLtPathByDirblock([b"", b"a", b"ab", b"abc", b"a/b/c", b"b/d/e"]) + self.assertLtPathByDirblock([b"kl", b"ab/cd", b"ab/ef", b"gh/ij"]) def test_tricky_paths(self): - self.assertLtPathByDirblock([ - # Contents of '' - b'', b'a', b'a-a', b'a=a', b'b', - # Contents of 'a' - b'a/a', b'a/a-a', b'a/a=a', b'a/b', - # Contents of 'a/a' - b'a/a/a', b'a/a/a-a', b'a/a/a=a', - # Contents of 'a/a/a' - b'a/a/a/a', b'a/a/a/b', - # Contents of 'a/a/a-a', - b'a/a/a-a/a', b'a/a/a-a/b', - # Contents of 'a/a/a=a', - b'a/a/a=a/a', b'a/a/a=a/b', - # Contents of 'a/a-a' - b'a/a-a/a', - # Contents of 'a/a-a/a' - b'a/a-a/a/a', b'a/a-a/a/b', - # Contents of 'a/a=a' - b'a/a=a/a', - # Contents of 'a/b' - b'a/b/a', b'a/b/b', - # Contents of 'a-a', - b'a-a/a', b'a-a/b', - # Contents of 'a=a', - b'a=a/a', b'a=a/b', - # Contents of 'b', - b'b/a', b'b/b', - ]) - self.assertLtPathByDirblock([ - # content of '/' - b'', b'a', b'a-a', b'a-z', b'a=a', b'a=z', - # content of 'a/' - b'a/a', b'a/a-a', b'a/a-z', - b'a/a=a', b'a/a=z', - b'a/z', b'a/z-a', b'a/z-z', - b'a/z=a', b'a/z=z', - # content of 'a/a/' - b'a/a/a', b'a/a/z', - # content of 'a/a-a' - b'a/a-a/a', - # content of 'a/a-z' - b'a/a-z/z', - # content of 'a/a=a' - b'a/a=a/a', - # content of 'a/a=z' - b'a/a=z/z', - # content of 'a/z/' - b'a/z/a', b'a/z/z', - # content of 'a-a' - b'a-a/a', - # content of 'a-z' - b'a-z/z', - # content of 'a=a' - b'a=a/a', - # content of 'a=z' - b'a=z/z', - ]) + self.assertLtPathByDirblock( + [ + # Contents of '' + b"", + b"a", + b"a-a", + b"a=a", + 
b"b", + # Contents of 'a' + b"a/a", + b"a/a-a", + b"a/a=a", + b"a/b", + # Contents of 'a/a' + b"a/a/a", + b"a/a/a-a", + b"a/a/a=a", + # Contents of 'a/a/a' + b"a/a/a/a", + b"a/a/a/b", + # Contents of 'a/a/a-a', + b"a/a/a-a/a", + b"a/a/a-a/b", + # Contents of 'a/a/a=a', + b"a/a/a=a/a", + b"a/a/a=a/b", + # Contents of 'a/a-a' + b"a/a-a/a", + # Contents of 'a/a-a/a' + b"a/a-a/a/a", + b"a/a-a/a/b", + # Contents of 'a/a=a' + b"a/a=a/a", + # Contents of 'a/b' + b"a/b/a", + b"a/b/b", + # Contents of 'a-a', + b"a-a/a", + b"a-a/b", + # Contents of 'a=a', + b"a=a/a", + b"a=a/b", + # Contents of 'b', + b"b/a", + b"b/b", + ] + ) + self.assertLtPathByDirblock( + [ + # content of '/' + b"", + b"a", + b"a-a", + b"a-z", + b"a=a", + b"a=z", + # content of 'a/' + b"a/a", + b"a/a-a", + b"a/a-z", + b"a/a=a", + b"a/a=z", + b"a/z", + b"a/z-a", + b"a/z-z", + b"a/z=a", + b"a/z=z", + # content of 'a/a/' + b"a/a/a", + b"a/a/z", + # content of 'a/a-a' + b"a/a-a/a", + # content of 'a/a-z' + b"a/a-z/z", + # content of 'a/a=a' + b"a/a=a/a", + # content of 'a/a=z' + b"a/a=z/z", + # content of 'a/z/' + b"a/z/a", + b"a/z/z", + # content of 'a-a' + b"a-a/a", + # content of 'a-z' + b"a-z/z", + # content of 'a=a' + b"a=a/a", + # content of 'a=z' + b"a=z/z", + ] + ) def test_nonascii(self): - self.assertLtPathByDirblock([ - # content of '/' - b'', b'a', b'\xc2\xb5', b'\xc3\xa5', - # content of 'a' - b'a/a', b'a/\xc2\xb5', b'a/\xc3\xa5', - # content of 'a/a' - b'a/a/a', b'a/a/\xc2\xb5', b'a/a/\xc3\xa5', - # content of 'a/\xc2\xb5' - b'a/\xc2\xb5/a', b'a/\xc2\xb5/\xc2\xb5', b'a/\xc2\xb5/\xc3\xa5', - # content of 'a/\xc3\xa5' - b'a/\xc3\xa5/a', b'a/\xc3\xa5/\xc2\xb5', b'a/\xc3\xa5/\xc3\xa5', - # content of '\xc2\xb5' - b'\xc2\xb5/a', b'\xc2\xb5/\xc2\xb5', b'\xc2\xb5/\xc3\xa5', - # content of '\xc2\xe5' - b'\xc3\xa5/a', b'\xc3\xa5/\xc2\xb5', b'\xc3\xa5/\xc3\xa5', - ]) + self.assertLtPathByDirblock( + [ + # content of '/' + b"", + b"a", + b"\xc2\xb5", + b"\xc3\xa5", + # content of 'a' + b"a/a", + b"a/\xc2\xb5", + b"a/\xc3\xa5", + # content of 'a/a' + b"a/a/a", + b"a/a/\xc2\xb5", + b"a/a/\xc3\xa5", + # content of 'a/\xc2\xb5' + b"a/\xc2\xb5/a", + b"a/\xc2\xb5/\xc2\xb5", + b"a/\xc2\xb5/\xc3\xa5", + # content of 'a/\xc3\xa5' + b"a/\xc3\xa5/a", + b"a/\xc3\xa5/\xc2\xb5", + b"a/\xc3\xa5/\xc3\xa5", + # content of '\xc2\xb5' + b"\xc2\xb5/a", + b"\xc2\xb5/\xc2\xb5", + b"\xc2\xb5/\xc3\xa5", + # content of '\xc2\xe5' + b"\xc3\xa5/a", + b"\xc3\xa5/\xc2\xb5", + b"\xc3\xa5/\xc3\xa5", + ] + ) class TestReadDirblocks(test_dirstate.TestCaseWithDirState): @@ -463,6 +536,7 @@ class TestReadDirblocks(test_dirstate.TestCaseWithDirState): def get_read_dirblocks(self): from .._dirstate_helpers_py import _read_dirblocks + return _read_dirblocks def test_smoketest(self): @@ -470,12 +544,10 @@ def test_smoketest(self): tree, state, expected = self.create_basic_dirstate() del tree state._read_header_if_needed() - self.assertEqual(dirstate.DirState.NOT_IN_MEMORY, - state._dirblock_state) + self.assertEqual(dirstate.DirState.NOT_IN_MEMORY, state._dirblock_state) read_dirblocks = self.get_read_dirblocks() read_dirblocks(state) - self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, - state._dirblock_state) + self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, state._dirblock_state) def test_trailing_garbage(self): tree, state, expected = self.create_basic_dirstate() @@ -483,17 +555,16 @@ def test_trailing_garbage(self): # on Win32, if you've opened the file with FILE_SHARE_READ, trying to # open it in append mode will fail. 
state.unlock() - f = open('dirstate', 'ab') + f = open("dirstate", "ab") try: # Add bogus trailing garbage - f.write(b'bogus\n') + f.write(b"bogus\n") finally: f.close() state.lock_read() - e = self.assertRaises(dirstate.DirstateCorrupt, - state._read_dirblocks_if_needed) + e = self.assertRaises(dirstate.DirstateCorrupt, state._read_dirblocks_if_needed) # Make sure we mention the bogus characters in the error - self.assertContainsRe(str(e), 'bogus') + self.assertContainsRe(str(e), "bogus") class TestCompiledReadDirblocks(TestReadDirblocks): @@ -503,6 +574,7 @@ class TestCompiledReadDirblocks(TestReadDirblocks): def get_read_dirblocks(self): from .._dirstate_helpers_pyx import _read_dirblocks + return _read_dirblocks @@ -531,218 +603,213 @@ def test_update_entry(self): def test_process_entry(self): if compiled_dirstate_helpers_feature.available(): from .._dirstate_helpers_pyx import ProcessEntryC + self.assertIs(ProcessEntryC, dirstate._process_entry) else: from ..dirstate import ProcessEntryPython + self.assertIs(ProcessEntryPython, dirstate._process_entry) class TestUpdateEntry(test_dirstate.TestCaseWithDirState): """Test the DirState.update_entry functions.""" - scenarios = multiply_scenarios( - dir_reader_scenarios(), ue_scenarios) + scenarios = multiply_scenarios(dir_reader_scenarios(), ue_scenarios) # Set by load_tests update_entry = None def setUp(self): super().setUp() - self.overrideAttr(dirstate, 'update_entry', self.update_entry) + self.overrideAttr(dirstate, "update_entry", self.update_entry) def get_state_with_a(self): """Create a DirState tracking a single object named 'a'.""" - state = test_dirstate.InstrumentedDirState.initialize('dirstate') + state = test_dirstate.InstrumentedDirState.initialize("dirstate") self.addCleanup(state.unlock) - state.add('a', b'a-id', 'file', None, b'') - entry = state._get_entry(0, path_utf8=b'a') + state.add("a", b"a-id", "file", None, b"") + entry = state._get_entry(0, path_utf8=b"a") return state, entry def test_observed_sha1_cachable(self): state, entry = self.get_state_with_a() state.save() atime = time.time() - 10 - self.build_tree(['a']) - statvalue = test_dirstate._FakeStat.from_stat(os.lstat('a')) + self.build_tree(["a"]) + statvalue = test_dirstate._FakeStat.from_stat(os.lstat("a")) statvalue.st_mtime = statvalue.st_ctime = atime - self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, - state._dirblock_state) + self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, state._dirblock_state) state._observed_sha1(entry, b"foo", statvalue) - self.assertEqual(b'foo', entry[1][0][1]) + self.assertEqual(b"foo", entry[1][0][1]) packed_stat = dirstate.pack_stat(statvalue) self.assertEqual(packed_stat, entry[1][0][4]) - self.assertEqual(dirstate.DirState.IN_MEMORY_HASH_MODIFIED, - state._dirblock_state) + self.assertEqual( + dirstate.DirState.IN_MEMORY_HASH_MODIFIED, state._dirblock_state + ) def test_observed_sha1_not_cachable(self): state, entry = self.get_state_with_a() state.save() oldval = entry[1][0][1] oldstat = entry[1][0][4] - self.build_tree(['a']) - statvalue = os.lstat('a') - self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, - state._dirblock_state) + self.build_tree(["a"]) + statvalue = os.lstat("a") + self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, state._dirblock_state) state._observed_sha1(entry, "foo", statvalue) self.assertEqual(oldval, entry[1][0][1]) self.assertEqual(oldstat, entry[1][0][4]) - self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, - state._dirblock_state) + 
self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, state._dirblock_state) def test_update_entry(self): state, _ = self.get_state_with_a() - tree = self.make_branch_and_tree('tree') + tree = self.make_branch_and_tree("tree") tree.lock_write() - empty_revid = tree.commit('empty') - self.build_tree(['tree/a']) - tree.add(['a'], ids=[b'a-id']) - with_a_id = tree.commit('with_a') + empty_revid = tree.commit("empty") + self.build_tree(["tree/a"]) + tree.add(["a"], ids=[b"a-id"]) + with_a_id = tree.commit("with_a") self.addCleanup(tree.unlock) state.set_parent_trees( - [(empty_revid, tree.branch.repository.revision_tree(empty_revid))], - []) - entry = state._get_entry(0, path_utf8=b'a') - self.build_tree(['a']) + [(empty_revid, tree.branch.repository.revision_tree(empty_revid))], [] + ) + entry = state._get_entry(0, path_utf8=b"a") + self.build_tree(["a"]) # Add one where we don't provide the stat or sha already - self.assertEqual((b'', b'a', b'a-id'), entry[0]) - self.assertEqual((b'f', b'', 0, False, dirstate.DirState.NULLSTAT), - entry[1][0]) + self.assertEqual((b"", b"a", b"a-id"), entry[0]) + self.assertEqual((b"f", b"", 0, False, dirstate.DirState.NULLSTAT), entry[1][0]) # Flush the buffers to disk state.save() - self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, - state._dirblock_state) + self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, state._dirblock_state) - stat_value = os.lstat('a') + stat_value = os.lstat("a") packed_stat = dirstate.pack_stat(stat_value) - link_or_sha1 = self.update_entry(state, entry, abspath=b'a', - stat_value=stat_value) + link_or_sha1 = self.update_entry( + state, entry, abspath=b"a", stat_value=stat_value + ) self.assertEqual(None, link_or_sha1) # The dirblock entry should not have computed or cached the file's # sha1, but it did update the files' st_size. However, this is not # worth writing a dirstate file for, so we leave the state UNMODIFIED - self.assertEqual((b'f', b'', 14, False, dirstate.DirState.NULLSTAT), - entry[1][0]) - self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, - state._dirblock_state) + self.assertEqual( + (b"f", b"", 14, False, dirstate.DirState.NULLSTAT), entry[1][0] + ) + self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, state._dirblock_state) mode = stat_value.st_mode - self.assertEqual([('is_exec', mode, False)], state._log) + self.assertEqual([("is_exec", mode, False)], state._log) state.save() - self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, - state._dirblock_state) + self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, state._dirblock_state) # Roll the clock back so the file is guaranteed to look too new. We # should still not compute the sha1. state.adjust_time(-10) del state._log[:] - link_or_sha1 = self.update_entry(state, entry, abspath=b'a', - stat_value=stat_value) - self.assertEqual([('is_exec', mode, False)], state._log) + link_or_sha1 = self.update_entry( + state, entry, abspath=b"a", stat_value=stat_value + ) + self.assertEqual([("is_exec", mode, False)], state._log) self.assertEqual(None, link_or_sha1) - self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, - state._dirblock_state) - self.assertEqual((b'f', b'', 14, False, dirstate.DirState.NULLSTAT), - entry[1][0]) + self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, state._dirblock_state) + self.assertEqual( + (b"f", b"", 14, False, dirstate.DirState.NULLSTAT), entry[1][0] + ) state.save() # If it is cachable (the clock has moved forward) but new it still # won't calculate the sha or cache it. 
state.adjust_time(+20) del state._log[:] - link_or_sha1 = dirstate.update_entry(state, entry, abspath=b'a', - stat_value=stat_value) + link_or_sha1 = dirstate.update_entry( + state, entry, abspath=b"a", stat_value=stat_value + ) self.assertEqual(None, link_or_sha1) - self.assertEqual([('is_exec', mode, False)], state._log) - self.assertEqual((b'f', b'', 14, False, dirstate.DirState.NULLSTAT), - entry[1][0]) - self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, - state._dirblock_state) + self.assertEqual([("is_exec", mode, False)], state._log) + self.assertEqual( + (b"f", b"", 14, False, dirstate.DirState.NULLSTAT), entry[1][0] + ) + self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, state._dirblock_state) # If the file is no longer new, and the clock has been moved forward # sufficiently, it will cache the sha. del state._log[:] state.set_parent_trees( - [(with_a_id, tree.branch.repository.revision_tree(with_a_id))], - []) - entry = state._get_entry(0, path_utf8=b'a') - - link_or_sha1 = self.update_entry(state, entry, abspath=b'a', - stat_value=stat_value) - self.assertEqual(b'b50e5406bb5e153ebbeb20268fcf37c87e1ecfb6', - link_or_sha1) - self.assertEqual([('is_exec', mode, False), ('sha1', b'a')], - state._log) - self.assertEqual((b'f', link_or_sha1, 14, False, packed_stat), - entry[1][0]) + [(with_a_id, tree.branch.repository.revision_tree(with_a_id))], [] + ) + entry = state._get_entry(0, path_utf8=b"a") + + link_or_sha1 = self.update_entry( + state, entry, abspath=b"a", stat_value=stat_value + ) + self.assertEqual(b"b50e5406bb5e153ebbeb20268fcf37c87e1ecfb6", link_or_sha1) + self.assertEqual([("is_exec", mode, False), ("sha1", b"a")], state._log) + self.assertEqual((b"f", link_or_sha1, 14, False, packed_stat), entry[1][0]) # Subsequent calls will just return the cached value del state._log[:] - link_or_sha1 = self.update_entry(state, entry, abspath=b'a', - stat_value=stat_value) - self.assertEqual(b'b50e5406bb5e153ebbeb20268fcf37c87e1ecfb6', - link_or_sha1) + link_or_sha1 = self.update_entry( + state, entry, abspath=b"a", stat_value=stat_value + ) + self.assertEqual(b"b50e5406bb5e153ebbeb20268fcf37c87e1ecfb6", link_or_sha1) self.assertEqual([], state._log) - self.assertEqual((b'f', link_or_sha1, 14, False, packed_stat), - entry[1][0]) + self.assertEqual((b"f", link_or_sha1, 14, False, packed_stat), entry[1][0]) def test_update_entry_symlink(self): """Update entry should read symlinks.""" self.requireFeature(features.SymlinkFeature(self.test_dir)) state, entry = self.get_state_with_a() state.save() - self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, - state._dirblock_state) - os.symlink('target', 'a') + self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, state._dirblock_state) + os.symlink("target", "a") state.adjust_time(-10) # Make the symlink look new - stat_value = os.lstat('a') + stat_value = os.lstat("a") packed_stat = dirstate.pack_stat(stat_value) - link_or_sha1 = self.update_entry(state, entry, abspath=b'a', - stat_value=stat_value) - self.assertEqual(b'target', link_or_sha1) - self.assertEqual([('read_link', b'a', b'')], state._log) + link_or_sha1 = self.update_entry( + state, entry, abspath=b"a", stat_value=stat_value + ) + self.assertEqual(b"target", link_or_sha1) + self.assertEqual([("read_link", b"a", b"")], state._log) # Dirblock is not updated (the link is too new) - self.assertEqual([(b'l', b'', 6, False, dirstate.DirState.NULLSTAT)], - entry[1]) + self.assertEqual([(b"l", b"", 6, False, dirstate.DirState.NULLSTAT)], entry[1]) # The file entry turned 
into a symlink, that is considered # HASH modified worthy. - self.assertEqual(dirstate.DirState.IN_MEMORY_HASH_MODIFIED, - state._dirblock_state) + self.assertEqual( + dirstate.DirState.IN_MEMORY_HASH_MODIFIED, state._dirblock_state + ) # Because the stat_value looks new, we should re-read the target del state._log[:] - link_or_sha1 = self.update_entry(state, entry, abspath=b'a', - stat_value=stat_value) - self.assertEqual(b'target', link_or_sha1) - self.assertEqual([('read_link', b'a', b'')], state._log) - self.assertEqual([(b'l', b'', 6, False, dirstate.DirState.NULLSTAT)], - entry[1]) + link_or_sha1 = self.update_entry( + state, entry, abspath=b"a", stat_value=stat_value + ) + self.assertEqual(b"target", link_or_sha1) + self.assertEqual([("read_link", b"a", b"")], state._log) + self.assertEqual([(b"l", b"", 6, False, dirstate.DirState.NULLSTAT)], entry[1]) state.save() state.adjust_time(+20) # Skip into the future, all files look old del state._log[:] - link_or_sha1 = self.update_entry(state, entry, abspath=b'a', - stat_value=stat_value) + link_or_sha1 = self.update_entry( + state, entry, abspath=b"a", stat_value=stat_value + ) # The symlink stayed a symlink. So while it is new enough to cache, we # don't bother setting the flag, because it is not really worth saving # (when we stat the symlink, we'll have paged in the target.) - self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, - state._dirblock_state) - self.assertEqual(b'target', link_or_sha1) + self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, state._dirblock_state) + self.assertEqual(b"target", link_or_sha1) # We need to re-read the link because only now can we cache it - self.assertEqual([('read_link', b'a', b'')], state._log) - self.assertEqual([(b'l', b'target', 6, False, packed_stat)], - entry[1]) + self.assertEqual([("read_link", b"a", b"")], state._log) + self.assertEqual([(b"l", b"target", 6, False, packed_stat)], entry[1]) del state._log[:] # Another call won't re-read the link self.assertEqual([], state._log) - link_or_sha1 = self.update_entry(state, entry, abspath=b'a', - stat_value=stat_value) - self.assertEqual(b'target', link_or_sha1) - self.assertEqual([(b'l', b'target', 6, False, packed_stat)], - entry[1]) + link_or_sha1 = self.update_entry( + state, entry, abspath=b"a", stat_value=stat_value + ) + self.assertEqual(b"target", link_or_sha1) + self.assertEqual([(b"l", b"target", 6, False, packed_stat)], entry[1]) def do_update_entry(self, state, entry, abspath): stat_value = os.lstat(abspath) @@ -750,74 +817,67 @@ def do_update_entry(self, state, entry, abspath): def test_update_entry_dir(self): state, entry = self.get_state_with_a() - self.build_tree(['a/']) - self.assertIs(None, self.do_update_entry(state, entry, b'a')) + self.build_tree(["a/"]) + self.assertIs(None, self.do_update_entry(state, entry, b"a")) def test_update_entry_dir_unchanged(self): state, entry = self.get_state_with_a() - self.build_tree(['a/']) + self.build_tree(["a/"]) state.adjust_time(+20) - self.assertIs(None, self.do_update_entry(state, entry, b'a')) + self.assertIs(None, self.do_update_entry(state, entry, b"a")) # a/ used to be a file, but is now a directory, worth saving - self.assertEqual(dirstate.DirState.IN_MEMORY_MODIFIED, - state._dirblock_state) + self.assertEqual(dirstate.DirState.IN_MEMORY_MODIFIED, state._dirblock_state) state.save() - self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, - state._dirblock_state) + self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, state._dirblock_state) # No changes to a/ 
means not worth saving. - self.assertIs(None, self.do_update_entry(state, entry, b'a')) - self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, - state._dirblock_state) + self.assertIs(None, self.do_update_entry(state, entry, b"a")) + self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, state._dirblock_state) # Change the last-modified time for the directory t = time.time() - 100.0 try: - os.utime('a', (t, t)) + os.utime("a", (t, t)) except OSError as e: # It looks like Win32 + FAT doesn't allow to change times on a dir. raise tests.TestSkipped("can't update mtime of a dir on FAT") from e saved_packed_stat = entry[1][0][-1] - self.assertIs(None, self.do_update_entry(state, entry, b'a')) + self.assertIs(None, self.do_update_entry(state, entry, b"a")) # We *do* go ahead and update the information in the dirblocks, but we # don't bother setting IN_MEMORY_MODIFIED because it is trivial to # recompute. self.assertNotEqual(saved_packed_stat, entry[1][0][-1]) - self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, - state._dirblock_state) + self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, state._dirblock_state) def test_update_entry_file_unchanged(self): state, _ = self.get_state_with_a() - tree = self.make_branch_and_tree('tree') + tree = self.make_branch_and_tree("tree") tree.lock_write() - self.build_tree(['tree/a']) - tree.add(['a'], ids=[b'a-id']) - with_a_id = tree.commit('witha') + self.build_tree(["tree/a"]) + tree.add(["a"], ids=[b"a-id"]) + with_a_id = tree.commit("witha") self.addCleanup(tree.unlock) state.set_parent_trees( - [(with_a_id, tree.branch.repository.revision_tree(with_a_id))], - []) - entry = state._get_entry(0, path_utf8=b'a') - self.build_tree(['a']) - sha1sum = b'b50e5406bb5e153ebbeb20268fcf37c87e1ecfb6' + [(with_a_id, tree.branch.repository.revision_tree(with_a_id))], [] + ) + entry = state._get_entry(0, path_utf8=b"a") + self.build_tree(["a"]) + sha1sum = b"b50e5406bb5e153ebbeb20268fcf37c87e1ecfb6" state.adjust_time(+20) - self.assertEqual(sha1sum, self.do_update_entry(state, entry, b'a')) - self.assertEqual(dirstate.DirState.IN_MEMORY_MODIFIED, - state._dirblock_state) + self.assertEqual(sha1sum, self.do_update_entry(state, entry, b"a")) + self.assertEqual(dirstate.DirState.IN_MEMORY_MODIFIED, state._dirblock_state) state.save() - self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, - state._dirblock_state) - self.assertEqual(sha1sum, self.do_update_entry(state, entry, b'a')) - self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, - state._dirblock_state) + self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, state._dirblock_state) + self.assertEqual(sha1sum, self.do_update_entry(state, entry, b"a")) + self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, state._dirblock_state) def test_update_entry_tree_reference(self): - state = test_dirstate.InstrumentedDirState.initialize('dirstate') + state = test_dirstate.InstrumentedDirState.initialize("dirstate") self.addCleanup(state.unlock) - state.add('r', b'r-id', 'tree-reference', None, b'') - self.build_tree(['r/']) - entry = state._get_entry(0, path_utf8=b'r') - self.do_update_entry(state, entry, 'r') - entry = state._get_entry(0, path_utf8=b'r') - self.assertEqual(b't', entry[1][0][0]) + state.add("r", b"r-id", "tree-reference", None, b"") + self.build_tree(["r/"]) + entry = state._get_entry(0, path_utf8=b"r") + self.do_update_entry(state, entry, "r") + entry = state._get_entry(0, path_utf8=b"r") + self.assertEqual(b"t", entry[1][0][0]) def create_and_test_file(self, state, entry): 
"""Create a file at 'a' and verify the state finds it during update. @@ -825,14 +885,13 @@ def create_and_test_file(self, state, entry): The state should already be versioning *something* at 'a'. This makes sure that state.update_entry recognizes it as a file. """ - self.build_tree(['a']) - stat_value = os.lstat('a') + self.build_tree(["a"]) + stat_value = os.lstat("a") packed_stat = dirstate.pack_stat(stat_value) - link_or_sha1 = self.do_update_entry(state, entry, abspath='a') + link_or_sha1 = self.do_update_entry(state, entry, abspath="a") self.assertEqual(None, link_or_sha1) - self.assertEqual([(b'f', b'', 14, False, dirstate.DirState.NULLSTAT)], - entry[1]) + self.assertEqual([(b"f", b"", 14, False, dirstate.DirState.NULLSTAT)], entry[1]) return packed_stat def create_and_test_dir(self, state, entry): @@ -841,13 +900,13 @@ def create_and_test_dir(self, state, entry): The state should already be versioning *something* at 'a'. This makes sure that state.update_entry recognizes it as a directory. """ - self.build_tree(['a/']) - stat_value = os.lstat('a') + self.build_tree(["a/"]) + stat_value = os.lstat("a") packed_stat = dirstate.pack_stat(stat_value) - link_or_sha1 = self.do_update_entry(state, entry, abspath=b'a') + link_or_sha1 = self.do_update_entry(state, entry, abspath=b"a") self.assertIs(None, link_or_sha1) - self.assertEqual([(b'd', b'', 0, False, packed_stat)], entry[1]) + self.assertEqual([(b"d", b"", 0, False, packed_stat)], entry[1]) return packed_stat @@ -862,15 +921,14 @@ def create_and_test_symlink(self, state, entry): support. """ # caller should care about skipping test on platforms without symlinks - os.symlink('path/to/foo', 'a') + os.symlink("path/to/foo", "a") - stat_value = os.lstat('a') + stat_value = os.lstat("a") packed_stat = dirstate.pack_stat(stat_value) - link_or_sha1 = self.do_update_entry(state, entry, abspath=b'a') - self.assertEqual(b'path/to/foo', link_or_sha1) - self.assertEqual([(b'l', b'path/to/foo', 11, False, packed_stat)], - entry[1]) + link_or_sha1 = self.do_update_entry(state, entry, abspath=b"a") + self.assertEqual(b"path/to/foo", link_or_sha1) + self.assertEqual([(b"l", b"path/to/foo", 11, False, packed_stat)], entry[1]) return packed_stat def test_update_file_to_dir(self): @@ -881,7 +939,7 @@ def test_update_file_to_dir(self): # The file sha1 won't be cached unless the file is old state.adjust_time(+10) self.create_and_test_file(state, entry) - os.remove('a') + os.remove("a") self.create_and_test_dir(state, entry) def test_update_file_to_symlink(self): @@ -891,7 +949,7 @@ def test_update_file_to_symlink(self): # The file sha1 won't be cached unless the file is old state.adjust_time(+10) self.create_and_test_file(state, entry) - os.remove('a') + os.remove("a") self.create_and_test_symlink(state, entry) def test_update_dir_to_file(self): @@ -900,7 +958,7 @@ def test_update_dir_to_file(self): # The file sha1 won't be cached unless the file is old state.adjust_time(+10) self.create_and_test_dir(state, entry) - os.rmdir('a') + os.rmdir("a") self.create_and_test_file(state, entry) def test_update_dir_to_symlink(self): @@ -910,7 +968,7 @@ def test_update_dir_to_symlink(self): # The symlink target won't be cached if it isn't old state.adjust_time(+10) self.create_and_test_dir(state, entry) - os.rmdir('a') + os.rmdir("a") self.create_and_test_symlink(state, entry) def test_update_symlink_to_file(self): @@ -920,7 +978,7 @@ def test_update_symlink_to_file(self): # The symlink and file info won't be cached unless old state.adjust_time(+10) 
self.create_and_test_symlink(state, entry) - os.remove('a') + os.remove("a") self.create_and_test_file(state, entry) def test_update_symlink_to_dir(self): @@ -930,12 +988,12 @@ def test_update_symlink_to_dir(self): # The symlink target won't be cached if it isn't old state.adjust_time(+10) self.create_and_test_symlink(state, entry) - os.remove('a') + os.remove("a") self.create_and_test_dir(state, entry) def test__is_executable_win32(self): state, entry = self.get_state_with_a() - self.build_tree(['a']) + self.build_tree(["a"]) # Make sure we are using the version of _is_executable that doesn't # check the filesystem mode. @@ -944,48 +1002,44 @@ def test__is_executable_win32(self): # The file on disk is not executable, but we are marking it as though # it is. With _use_filesystem_for_exec disabled we ignore what is on # disk. - entry[1][0] = (b'f', b'', 0, True, dirstate.DirState.NULLSTAT) + entry[1][0] = (b"f", b"", 0, True, dirstate.DirState.NULLSTAT) - stat_value = os.lstat('a') + stat_value = os.lstat("a") dirstate.pack_stat(stat_value) state.adjust_time(-10) # Make sure everything is new - self.update_entry(state, entry, abspath=b'a', stat_value=stat_value) + self.update_entry(state, entry, abspath=b"a", stat_value=stat_value) # The row is updated, but the executable bit stays set. - self.assertEqual([(b'f', b'', 14, True, dirstate.DirState.NULLSTAT)], - entry[1]) + self.assertEqual([(b"f", b"", 14, True, dirstate.DirState.NULLSTAT)], entry[1]) # Make the disk object look old enough to cache (but it won't cache the # sha as it is a new file). state.adjust_time(+20) - self.update_entry(state, entry, abspath=b'a', stat_value=stat_value) - self.assertEqual([(b'f', b'', 14, True, dirstate.DirState.NULLSTAT)], - entry[1]) + self.update_entry(state, entry, abspath=b"a", stat_value=stat_value) + self.assertEqual([(b"f", b"", 14, True, dirstate.DirState.NULLSTAT)], entry[1]) def _prepare_tree(self): # Create a tree - text = b'Hello World\n' - tree = self.make_branch_and_tree('tree') - self.build_tree_contents([('tree/a file', text)]) - tree.add('a file', ids=b'a-file-id') + text = b"Hello World\n" + tree = self.make_branch_and_tree("tree") + self.build_tree_contents([("tree/a file", text)]) + tree.add("a file", ids=b"a-file-id") # Note: dirstate does not sha prior to the first commit # so commit now in order for the test to work - tree.commit('first') + tree.commit("first") return tree, text def test_sha1provider_sha1_used(self): tree, text = self._prepare_tree() - state = dirstate.DirState.from_tree(tree, 'dirstate', - UppercaseSHA1Provider()) + state = dirstate.DirState.from_tree(tree, "dirstate", UppercaseSHA1Provider()) self.addCleanup(state.unlock) expected_sha = osutils.sha_string(text.upper() + b"foo") - entry = state._get_entry(0, path_utf8=b'a file') + entry = state._get_entry(0, path_utf8=b"a file") self.assertNotEqual((None, None), entry) state._sha_cutoff_time() state._cutoff_time += 10 - sha1 = self.update_entry(state, entry, 'tree/a file', - os.lstat('tree/a file')) + sha1 = self.update_entry(state, entry, "tree/a file", os.lstat("tree/a file")) self.assertEqual(expected_sha, sha1) def test_sha1provider_stat_and_sha1_used(self): @@ -996,9 +1050,10 @@ def test_sha1provider_stat_and_sha1_used(self): state._sha1_provider = UppercaseSHA1Provider() # If we used the standard provider, it would look like nothing has # changed - file_ids_changed = [change.file_id for change - in tree.iter_changes(tree.basis_tree())] - self.assertEqual([b'a-file-id'], file_ids_changed) + file_ids_changed 
= [ + change.file_id for change in tree.iter_changes(tree.basis_tree()) + ] + self.assertEqual([b"a-file-id"], file_ids_changed) class UppercaseSHA1Provider(dirstate.SHA1Provider): @@ -1008,15 +1063,14 @@ def sha1(self, abspath): return self.stat_and_sha1(abspath)[1] def stat_and_sha1(self, abspath): - with open(abspath, 'rb') as file_obj: + with open(abspath, "rb") as file_obj: statvalue = os.fstat(file_obj.fileno()) - text = b''.join(file_obj.readlines()) + text = b"".join(file_obj.readlines()) sha1 = osutils.sha_string(text.upper() + b"foo") return statvalue, sha1 class TestProcessEntry(test_dirstate.TestCaseWithDirState): - scenarios = multiply_scenarios(dir_reader_scenarios(), pe_scenarios) # Set by load_tests @@ -1024,55 +1078,55 @@ class TestProcessEntry(test_dirstate.TestCaseWithDirState): def setUp(self): super().setUp() - self.overrideAttr(dirstate, '_process_entry', self._process_entry) + self.overrideAttr(dirstate, "_process_entry", self._process_entry) def assertChangedFileIds(self, expected, tree): with tree.lock_read(): - file_ids = [info.file_id for info - in tree.iter_changes(tree.basis_tree())] + file_ids = [info.file_id for info in tree.iter_changes(tree.basis_tree())] self.assertEqual(sorted(expected), sorted(file_ids)) def test_exceptions_raised(self): # This is a direct test of bug #495023, it relies on osutils.is_inside # getting called in an inner function. Which makes it a bit brittle, # but at least it does reproduce the bug. - tree = self.make_branch_and_tree('tree') - self.build_tree(['tree/file', 'tree/dir/', 'tree/dir/sub', - 'tree/dir2/', 'tree/dir2/sub2']) - tree.add(['file', 'dir', 'dir/sub', 'dir2', 'dir2/sub2']) - tree.commit('first commit') + tree = self.make_branch_and_tree("tree") + self.build_tree( + ["tree/file", "tree/dir/", "tree/dir/sub", "tree/dir2/", "tree/dir2/sub2"] + ) + tree.add(["file", "dir", "dir/sub", "dir2", "dir2/sub2"]) + tree.commit("first commit") tree.lock_read() self.addCleanup(tree.unlock) basis_tree = tree.basis_tree() def is_inside_raises(*args, **kwargs): - raise RuntimeError('stop this') - self.overrideAttr(dirstate, 'is_inside', is_inside_raises) + raise RuntimeError("stop this") + + self.overrideAttr(dirstate, "is_inside", is_inside_raises) try: from .. 
import _dirstate_helpers_pyx except ImportError: pass else: - self.overrideAttr(_dirstate_helpers_pyx, - 'is_inside', is_inside_raises) - self.overrideAttr(osutils, 'is_inside', is_inside_raises) + self.overrideAttr(_dirstate_helpers_pyx, "is_inside", is_inside_raises) + self.overrideAttr(osutils, "is_inside", is_inside_raises) self.assertListRaises(RuntimeError, tree.iter_changes, basis_tree) def test_simple_changes(self): - tree = self.make_branch_and_tree('tree') - self.build_tree(['tree/file']) - tree.add(['file'], ids=[b'file-id']) - self.assertChangedFileIds([tree.path2id(''), b'file-id'], tree) - tree.commit('one') + tree = self.make_branch_and_tree("tree") + self.build_tree(["tree/file"]) + tree.add(["file"], ids=[b"file-id"]) + self.assertChangedFileIds([tree.path2id(""), b"file-id"], tree) + tree.commit("one") self.assertChangedFileIds([], tree) def test_sha1provider_stat_and_sha1_used(self): - tree = self.make_branch_and_tree('tree') - self.build_tree(['tree/file']) - tree.add(['file'], ids=[b'file-id']) - tree.commit('one') + tree = self.make_branch_and_tree("tree") + self.build_tree(["tree/file"]) + tree.add(["file"], ids=[b"file-id"]) + tree.commit("one") tree.lock_write() self.addCleanup(tree.unlock) state = tree._current_dirstate() state._sha1_provider = UppercaseSHA1Provider() - self.assertChangedFileIds([b'file-id'], tree) + self.assertChangedFileIds([b"file-id"], tree) diff --git a/breezy/bzr/tests/test__groupcompress.py b/breezy/bzr/tests/test__groupcompress.py index f1413254aa..d93b732f57 100644 --- a/breezy/bzr/tests/test__groupcompress.py +++ b/breezy/bzr/tests/test__groupcompress.py @@ -27,24 +27,19 @@ def module_scenarios(): scenarios = [ - ('python', {'_gc_module': _groupcompress_py}), - ] + ("python", {"_gc_module": _groupcompress_py}), + ] if compiled_groupcompress_feature.available(): gc_module = compiled_groupcompress_feature.module - scenarios.append(('C', - {'_gc_module': gc_module})) + scenarios.append(("C", {"_gc_module": gc_module})) return scenarios def two_way_scenarios(): - scenarios = [ - ('PR', {'make_delta': _groupcompress_py.make_delta}) - ] + scenarios = [("PR", {"make_delta": _groupcompress_py.make_delta})] if compiled_groupcompress_feature.available(): gc_module = compiled_groupcompress_feature.module - scenarios.extend([ - ('CR', {'make_delta': gc_module.make_delta}) - ]) + scenarios.extend([("CR", {"make_delta": gc_module.make_delta})]) return scenarios @@ -52,7 +47,8 @@ def two_way_scenarios(): compiled_groupcompress_feature = features.ModuleAvailableFeature( - 'breezy.bzr._groupcompress_pyx') + "breezy.bzr._groupcompress_pyx" +) _text1 = b"""\ This is a bit @@ -113,7 +109,6 @@ def two_way_scenarios(): class TestMakeAndApplyDelta(tests.TestCase): - scenarios = module_scenarios() _gc_module = None # Set by load_tests @@ -124,53 +119,57 @@ def setUp(self): self.apply_delta_to_source = _groupcompress_rs.apply_delta_to_source def test_make_delta_is_typesafe(self): - self.make_delta(b'a string', b'another string') + self.make_delta(b"a string", b"another string") def _check_make_delta(string1, string2): self.assertRaises(TypeError, self.make_delta, string1, string2) - _check_make_delta(b'a string', object()) - _check_make_delta(b'a string', 'not a string') - _check_make_delta(object(), b'a string') - _check_make_delta('not a string', b'a string') + _check_make_delta(b"a string", object()) + _check_make_delta(b"a string", "not a string") + _check_make_delta(object(), b"a string") + _check_make_delta("not a string", b"a string") def 
test_make_noop_delta(self): ident_delta = self.make_delta(_text1, _text1) - self.assertEqual(b'M\x90M', ident_delta) + self.assertEqual(b"M\x90M", ident_delta) ident_delta = self.make_delta(_text2, _text2) - self.assertEqual(b'N\x90N', ident_delta) + self.assertEqual(b"N\x90N", ident_delta) ident_delta = self.make_delta(_text3, _text3) - self.assertEqual(b'\x87\x01\x90\x87', ident_delta) + self.assertEqual(b"\x87\x01\x90\x87", ident_delta) def assertDeltaIn(self, delta1, delta2, delta): """Make sure that the delta bytes match one of the expectations.""" # In general, the python delta matcher gives different results than the # pyrex delta matcher. Both should be valid deltas, though. if delta not in (delta1, delta2): - self.fail(b"Delta bytes:\n" - b" %r\n" - b"not in %r\n" - b" or %r" - % (delta, delta1, delta2)) + self.fail( + b"Delta bytes:\n" + b" %r\n" + b"not in %r\n" + b" or %r" % (delta, delta1, delta2) + ) def test_make_delta(self): delta = self.make_delta(_text1, _text2) self.assertDeltaIn( - b'N\x90/\x1fdiffer from\nagainst other text\n', - b'N\x90\x1d\x1ewhich is meant to differ from\n\x91:\x13', - delta) + b"N\x90/\x1fdiffer from\nagainst other text\n", + b"N\x90\x1d\x1ewhich is meant to differ from\n\x91:\x13", + delta, + ) delta = self.make_delta(_text2, _text1) self.assertDeltaIn( - b'M\x90/\x1ebe matched\nagainst other text\n', - b'M\x90\x1d\x1dwhich is meant to be matched\n\x91;\x13', - delta) + b"M\x90/\x1ebe matched\nagainst other text\n", + b"M\x90\x1d\x1dwhich is meant to be matched\n\x91;\x13", + delta, + ) delta = self.make_delta(_text3, _text1) - self.assertEqual(b'M\x90M', delta) + self.assertEqual(b"M\x90M", delta) delta = self.make_delta(_text3, _text2) self.assertDeltaIn( - b'N\x90/\x1fdiffer from\nagainst other text\n', - b'N\x90\x1d\x1ewhich is meant to differ from\n\x91:\x13', - delta) + b"N\x90/\x1fdiffer from\nagainst other text\n", + b"N\x90\x1d\x1ewhich is meant to differ from\n\x91:\x13", + delta, + ) def test_make_delta_with_large_copies(self): # We want to have a copy that is larger than 64kB, which forces us to @@ -178,53 +177,55 @@ def test_make_delta_with_large_copies(self): big_text = _text3 * 1220 delta = self.make_delta(big_text, big_text) self.assertDeltaIn( - b'\xdc\x86\x0a' # Encoding the length of the uncompressed text - b'\x80' # Copy 64kB, starting at byte 0 - b'\x84\x01' # and another 64kB starting at 64kB - b'\xb4\x02\x5c\x83', # And the bit of tail. - None, # Both implementations should be identical - delta) + b"\xdc\x86\x0a" # Encoding the length of the uncompressed text + b"\x80" # Copy 64kB, starting at byte 0 + b"\x84\x01" # and another 64kB starting at 64kB + b"\xb4\x02\x5c\x83", # And the bit of tail. 
+ None, # Both implementations should be identical + delta, + ) def test_apply_delta_is_typesafe(self): - self.apply_delta(_text1, b'M\x90M') - self.assertRaises(TypeError, self.apply_delta, object(), b'M\x90M') - self.assertRaises((ValueError, TypeError), self.apply_delta, - _text1.decode('latin1'), b'M\x90M') - self.assertRaises((ValueError, TypeError), self.apply_delta, _text1, 'M\x90M') + self.apply_delta(_text1, b"M\x90M") + self.assertRaises(TypeError, self.apply_delta, object(), b"M\x90M") + self.assertRaises( + (ValueError, TypeError), + self.apply_delta, + _text1.decode("latin1"), + b"M\x90M", + ) + self.assertRaises((ValueError, TypeError), self.apply_delta, _text1, "M\x90M") self.assertRaises(TypeError, self.apply_delta, _text1, object()) def test_apply_delta(self): - target = self.apply_delta(_text1, - b'N\x90/\x1fdiffer from\nagainst other text\n') + target = self.apply_delta( + _text1, b"N\x90/\x1fdiffer from\nagainst other text\n" + ) self.assertEqual(_text2, target) - target = self.apply_delta(_text2, - b'M\x90/\x1ebe matched\nagainst other text\n') + target = self.apply_delta(_text2, b"M\x90/\x1ebe matched\nagainst other text\n") self.assertEqual(_text1, target) def test_apply_delta_to_source_is_safe(self): - self.assertRaises(TypeError, - self.apply_delta_to_source, object(), 0, 1) - self.assertRaises(TypeError, - self.apply_delta_to_source, 'unicode str', 0, 1) + self.assertRaises(TypeError, self.apply_delta_to_source, object(), 0, 1) + self.assertRaises(TypeError, self.apply_delta_to_source, "unicode str", 0, 1) # end > length - self.assertRaises(ValueError, - self.apply_delta_to_source, b'foo', 1, 4) + self.assertRaises(ValueError, self.apply_delta_to_source, b"foo", 1, 4) # start > length - self.assertRaises(ValueError, - self.apply_delta_to_source, b'foo', 5, 3) + self.assertRaises(ValueError, self.apply_delta_to_source, b"foo", 5, 3) # start > end - self.assertRaises(ValueError, - self.apply_delta_to_source, b'foo', 3, 2) + self.assertRaises(ValueError, self.apply_delta_to_source, b"foo", 3, 2) def test_apply_delta_to_source(self): - source_and_delta = (_text1 - + b'N\x90/\x1fdiffer from\nagainst other text\n') - self.assertEqual(_text2, self.apply_delta_to_source(source_and_delta, - len(_text1), len(source_and_delta))) + source_and_delta = _text1 + b"N\x90/\x1fdiffer from\nagainst other text\n" + self.assertEqual( + _text2, + self.apply_delta_to_source( + source_and_delta, len(_text1), len(source_and_delta) + ), + ) class TestMakeAndApplyCompatible(tests.TestCase): - scenarios = two_way_scenarios() make_delta = None # Set by load_tests @@ -246,7 +247,6 @@ def test_direct(self): class TestDeltaIndex(tests.TestCase): - def setUp(self): super().setUp() # This test isn't multiplied, because we only have DeltaIndex for the @@ -256,8 +256,8 @@ def setUp(self): self._gc_module = compiled_groupcompress_feature.module def test_repr(self): - di = self._gc_module.DeltaIndex(b'test text\n') - self.assertEqual('DeltaIndex(1, 10)', repr(di)) + di = self._gc_module.DeltaIndex(b"test text\n") + self.assertEqual("DeltaIndex(1, 10)", repr(di)) def test_sizeof(self): di = self._gc_module.DeltaIndex() @@ -280,19 +280,24 @@ def test__dump_index_simple(self): hash_list, entry_list = di._dump_index() self.assertEqual(16, len(hash_list)) self.assertEqual(68, len(entry_list)) - just_entries = [(idx, text_offset, hash_val) - for idx, (text_offset, hash_val) - in enumerate(entry_list) - if text_offset != 0 or hash_val != 0] + just_entries = [ + (idx, text_offset, hash_val) + for idx, 
(text_offset, hash_val) in enumerate(entry_list) + if text_offset != 0 or hash_val != 0 + ] rabin_hash = self._gc_module._rabin_hash - self.assertEqual([(8, 16, rabin_hash(_text1[1:17])), - (25, 48, rabin_hash(_text1[33:49])), - (34, 32, rabin_hash(_text1[17:33])), - (47, 64, rabin_hash(_text1[49:65])), - ], just_entries) + self.assertEqual( + [ + (8, 16, rabin_hash(_text1[1:17])), + (25, 48, rabin_hash(_text1[33:49])), + (34, 32, rabin_hash(_text1[17:33])), + (47, 64, rabin_hash(_text1[49:65])), + ], + just_entries, + ) # This ensures that the hash map points to the location we expect it to for entry_idx, _text_offset, hash_val in just_entries: - self.assertEqual(entry_idx, hash_list[hash_val & 0xf]) + self.assertEqual(entry_idx, hash_list[hash_val & 0xF]) def test__dump_index_two_sources(self): di = self._gc_module.DeltaIndex() @@ -303,25 +308,29 @@ def test__dump_index_two_sources(self): hash_list, entry_list = di._dump_index() self.assertEqual(16, len(hash_list)) self.assertEqual(68, len(entry_list)) - just_entries = [(idx, text_offset, hash_val) - for idx, (text_offset, hash_val) - in enumerate(entry_list) - if text_offset != 0 or hash_val != 0] + just_entries = [ + (idx, text_offset, hash_val) + for idx, (text_offset, hash_val) in enumerate(entry_list) + if text_offset != 0 or hash_val != 0 + ] rabin_hash = self._gc_module._rabin_hash - self.assertEqual([(8, 16, rabin_hash(_text1[1:17])), - (9, start2 + 16, rabin_hash(_text2[1:17])), - (25, 48, rabin_hash(_text1[33:49])), - (30, start2 + 64, rabin_hash(_text2[49:65])), - (34, 32, rabin_hash(_text1[17:33])), - (35, start2 + 32, rabin_hash(_text2[17:33])), - (43, start2 + 48, rabin_hash(_text2[33:49])), - (47, 64, rabin_hash(_text1[49:65])), - ], just_entries) + self.assertEqual( + [ + (8, 16, rabin_hash(_text1[1:17])), + (9, start2 + 16, rabin_hash(_text2[1:17])), + (25, 48, rabin_hash(_text1[33:49])), + (30, start2 + 64, rabin_hash(_text2[49:65])), + (34, 32, rabin_hash(_text1[17:33])), + (35, start2 + 32, rabin_hash(_text2[17:33])), + (43, start2 + 48, rabin_hash(_text2[33:49])), + (47, 64, rabin_hash(_text1[49:65])), + ], + just_entries, + ) # Each entry should be in the appropriate hash bucket. 
for entry_idx, _text_offset, hash_val in just_entries: - hash_idx = hash_val & 0xf - self.assertTrue( - hash_list[hash_idx] <= entry_idx < hash_list[hash_idx + 1]) + hash_idx = hash_val & 0xF + self.assertTrue(hash_list[hash_idx] <= entry_idx < hash_list[hash_idx + 1]) def test_first_add_source_doesnt_index_until_make_delta(self): di = self._gc_module.DeltaIndex() @@ -332,7 +341,7 @@ def test_first_add_source_doesnt_index_until_make_delta(self): # generated, and will generate a proper delta delta = di.make_delta(_text2) self.assertTrue(di._has_index()) - self.assertEqual(b'N\x90/\x1fdiffer from\nagainst other text\n', delta) + self.assertEqual(b"N\x90/\x1fdiffer from\nagainst other text\n", delta) def test_add_source_max_bytes_to_index(self): di = self._gc_module.DeltaIndex() @@ -343,17 +352,25 @@ def test_add_source_max_bytes_to_index(self): hash_list, entry_list = di._dump_index() self.assertEqual(16, len(hash_list)) self.assertEqual(67, len(entry_list)) - just_entries = sorted([(text_offset, hash_val) - for text_offset, hash_val in entry_list - if text_offset != 0 or hash_val != 0]) + just_entries = sorted( + [ + (text_offset, hash_val) + for text_offset, hash_val in entry_list + if text_offset != 0 or hash_val != 0 + ] + ) rabin_hash = self._gc_module._rabin_hash - self.assertEqual([(25, rabin_hash(_text1[10:26])), - (50, rabin_hash(_text1[35:51])), - (75, rabin_hash(_text1[60:76])), - (start2 + 44, rabin_hash(_text3[29:45])), - (start2 + 88, rabin_hash(_text3[73:89])), - (start2 + 132, rabin_hash(_text3[117:133])), - ], just_entries) + self.assertEqual( + [ + (25, rabin_hash(_text1[10:26])), + (50, rabin_hash(_text1[35:51])), + (75, rabin_hash(_text1[60:76])), + (start2 + 44, rabin_hash(_text3[29:45])), + (start2 + 88, rabin_hash(_text3[73:89])), + (start2 + 132, rabin_hash(_text3[117:133])), + ], + just_entries, + ) def test_second_add_source_triggers_make_index(self): di = self._gc_module.DeltaIndex() @@ -366,36 +383,38 @@ def test_second_add_source_triggers_make_index(self): def test_make_delta(self): di = self._gc_module.DeltaIndex(_text1) delta = di.make_delta(_text2) - self.assertEqual(b'N\x90/\x1fdiffer from\nagainst other text\n', delta) + self.assertEqual(b"N\x90/\x1fdiffer from\nagainst other text\n", delta) def test_delta_against_multiple_sources(self): di = self._gc_module.DeltaIndex() di.add_source(_first_text, 0) self.assertEqual(len(_first_text), di._source_offset) di.add_source(_second_text, 0) - self.assertEqual(len(_first_text) + len(_second_text), - di._source_offset) + self.assertEqual(len(_first_text) + len(_second_text), di._source_offset) delta = di.make_delta(_third_text) result = _groupcompress_rs.apply_delta(_first_text + _second_text, delta) self.assertEqualDiff(_third_text, result) - self.assertEqual(b'\x85\x01\x90\x14\x0chas some in ' - b'\x91v6\x03and\x91d"\x91:\n', delta) + self.assertEqual( + b"\x85\x01\x90\x14\x0chas some in " b'\x91v6\x03and\x91d"\x91:\n', delta + ) def test_delta_with_offsets(self): di = self._gc_module.DeltaIndex() di.add_source(_first_text, 5) self.assertEqual(len(_first_text) + 5, di._source_offset) di.add_source(_second_text, 10) - self.assertEqual(len(_first_text) + len(_second_text) + 15, - di._source_offset) + self.assertEqual(len(_first_text) + len(_second_text) + 15, di._source_offset) delta = di.make_delta(_third_text) self.assertIsNot(None, delta) result = _groupcompress_rs.apply_delta( - b'12345' + _first_text + b'1234567890' + _second_text, delta) + b"12345" + _first_text + b"1234567890" + _second_text, delta + ) 
self.assertIsNot(None, result) self.assertEqualDiff(_third_text, result) - self.assertEqual(b'\x85\x01\x91\x05\x14\x0chas some in ' - b'\x91\x856\x03and\x91s"\x91?\n', delta) + self.assertEqual( + b"\x85\x01\x91\x05\x14\x0chas some in " b'\x91\x856\x03and\x91s"\x91?\n', + delta, + ) def test_delta_with_delta_bytes(self): di = self._gc_module.DeltaIndex() @@ -403,8 +422,9 @@ def test_delta_with_delta_bytes(self): di.add_source(_first_text, 0) self.assertEqual(len(_first_text), di._source_offset) delta = di.make_delta(_second_text) - self.assertEqual(b'h\tsome more\x91\x019' - b'&previous text\nand has some extra text\n', delta) + self.assertEqual( + b"h\tsome more\x91\x019" b"&previous text\nand has some extra text\n", delta + ) di.add_delta_source(delta, 0) source += delta self.assertEqual(len(_first_text) + len(delta), di._source_offset) @@ -416,8 +436,11 @@ def test_delta_with_delta_bytes(self): # Note that we don't match the 'common with the', because it isn't long # enough to match in the original text, and those bytes are not present # in the delta for the second text. - self.assertEqual(b'\x85\x01\x90\x14\x1chas some in common with the ' - b'\x91S&\x03and\x91\x18,', second_delta) + self.assertEqual( + b"\x85\x01\x90\x14\x1chas some in common with the " + b"\x91S&\x03and\x91\x18,", + second_delta, + ) # Add this delta, and create a new delta for the same text. We should # find the remaining text, and only insert the short 'and' text. di.add_delta_source(second_delta, 0) @@ -425,30 +448,35 @@ def test_delta_with_delta_bytes(self): third_delta = di.make_delta(_third_text) result = _groupcompress_rs.apply_delta(source, third_delta) self.assertEqualDiff(_third_text, result) - self.assertEqual(b'\x85\x01\x90\x14\x91\x7e\x1c' - b'\x91S&\x03and\x91\x18,', third_delta) + self.assertEqual( + b"\x85\x01\x90\x14\x91\x7e\x1c" b"\x91S&\x03and\x91\x18,", third_delta + ) # Now create a delta, which we know won't be able to be 'fit' into the # existing index fourth_delta = di.make_delta(_fourth_text) - self.assertEqual(_fourth_text, - _groupcompress_rs.apply_delta(source, fourth_delta)) - self.assertEqual(b'\x80\x01' - b'\x7f123456789012345\nsame rabin hash\n' - b'123456789012345\nsame rabin hash\n' - b'123456789012345\nsame rabin hash\n' - b'123456789012345\nsame rabin hash' - b'\x01\n', fourth_delta) + self.assertEqual( + _fourth_text, _groupcompress_rs.apply_delta(source, fourth_delta) + ) + self.assertEqual( + b"\x80\x01" + b"\x7f123456789012345\nsame rabin hash\n" + b"123456789012345\nsame rabin hash\n" + b"123456789012345\nsame rabin hash\n" + b"123456789012345\nsame rabin hash" + b"\x01\n", + fourth_delta, + ) di.add_delta_source(fourth_delta, 0) source += fourth_delta # With the next delta, everything should be found fifth_delta = di.make_delta(_fourth_text) - self.assertEqual(_fourth_text, - _groupcompress_rs.apply_delta(source, fifth_delta)) - self.assertEqual(b'\x80\x01\x91\xa7\x7f\x01\n', fifth_delta) + self.assertEqual( + _fourth_text, _groupcompress_rs.apply_delta(source, fifth_delta) + ) + self.assertEqual(b"\x80\x01\x91\xa7\x7f\x01\n", fifth_delta) class TestCopyInstruction(tests.TestCase): - def assertEncode(self, expected, offset, length): data = _groupcompress_py.encode_copy_instruction(offset, length) self.assertEqual(expected, data) @@ -460,80 +488,80 @@ def assertDecode(self, exp_offset, exp_length, exp_newpos, data, pos): self.assertEqual((exp_offset, exp_length, exp_newpos), out) def test_encode_no_length(self): - self.assertEncode(b'\x80', 0, 64 * 1024) - 
self.assertEncode(b'\x81\x01', 1, 64 * 1024) - self.assertEncode(b'\x81\x0a', 10, 64 * 1024) - self.assertEncode(b'\x81\xff', 255, 64 * 1024) - self.assertEncode(b'\x82\x01', 256, 64 * 1024) - self.assertEncode(b'\x83\x01\x01', 257, 64 * 1024) - self.assertEncode(b'\x8F\xff\xff\xff\xff', 0xFFFFFFFF, 64 * 1024) - self.assertEncode(b'\x8E\xff\xff\xff', 0xFFFFFF00, 64 * 1024) - self.assertEncode(b'\x8D\xff\xff\xff', 0xFFFF00FF, 64 * 1024) - self.assertEncode(b'\x8B\xff\xff\xff', 0xFF00FFFF, 64 * 1024) - self.assertEncode(b'\x87\xff\xff\xff', 0x00FFFFFF, 64 * 1024) - self.assertEncode(b'\x8F\x04\x03\x02\x01', 0x01020304, 64 * 1024) + self.assertEncode(b"\x80", 0, 64 * 1024) + self.assertEncode(b"\x81\x01", 1, 64 * 1024) + self.assertEncode(b"\x81\x0a", 10, 64 * 1024) + self.assertEncode(b"\x81\xff", 255, 64 * 1024) + self.assertEncode(b"\x82\x01", 256, 64 * 1024) + self.assertEncode(b"\x83\x01\x01", 257, 64 * 1024) + self.assertEncode(b"\x8F\xff\xff\xff\xff", 0xFFFFFFFF, 64 * 1024) + self.assertEncode(b"\x8E\xff\xff\xff", 0xFFFFFF00, 64 * 1024) + self.assertEncode(b"\x8D\xff\xff\xff", 0xFFFF00FF, 64 * 1024) + self.assertEncode(b"\x8B\xff\xff\xff", 0xFF00FFFF, 64 * 1024) + self.assertEncode(b"\x87\xff\xff\xff", 0x00FFFFFF, 64 * 1024) + self.assertEncode(b"\x8F\x04\x03\x02\x01", 0x01020304, 64 * 1024) def test_encode_no_offset(self): - self.assertEncode(b'\x90\x01', 0, 1) - self.assertEncode(b'\x90\x0a', 0, 10) - self.assertEncode(b'\x90\xff', 0, 255) - self.assertEncode(b'\xA0\x01', 0, 256) - self.assertEncode(b'\xB0\x01\x01', 0, 257) - self.assertEncode(b'\xB0\xff\xff', 0, 0xFFFF) + self.assertEncode(b"\x90\x01", 0, 1) + self.assertEncode(b"\x90\x0a", 0, 10) + self.assertEncode(b"\x90\xff", 0, 255) + self.assertEncode(b"\xA0\x01", 0, 256) + self.assertEncode(b"\xB0\x01\x01", 0, 257) + self.assertEncode(b"\xB0\xff\xff", 0, 0xFFFF) # Special case, if copy == 64KiB, then we store exactly 0 # Note that this puns with a copy of exactly 0 bytes, but we don't care # about that, as we would never actually copy 0 bytes - self.assertEncode(b'\x80', 0, 64 * 1024) + self.assertEncode(b"\x80", 0, 64 * 1024) def test_encode(self): - self.assertEncode(b'\x91\x01\x01', 1, 1) - self.assertEncode(b'\x91\x09\x0a', 9, 10) - self.assertEncode(b'\x91\xfe\xff', 254, 255) - self.assertEncode(b'\xA2\x02\x01', 512, 256) - self.assertEncode(b'\xB3\x02\x01\x01\x01', 258, 257) - self.assertEncode(b'\xB0\x01\x01', 0, 257) + self.assertEncode(b"\x91\x01\x01", 1, 1) + self.assertEncode(b"\x91\x09\x0a", 9, 10) + self.assertEncode(b"\x91\xfe\xff", 254, 255) + self.assertEncode(b"\xA2\x02\x01", 512, 256) + self.assertEncode(b"\xB3\x02\x01\x01\x01", 258, 257) + self.assertEncode(b"\xB0\x01\x01", 0, 257) # Special case, if copy == 64KiB, then we store exactly 0 # Note that this puns with a copy of exactly 0 bytes, but we don't care # about that, as we would never actually copy 0 bytes - self.assertEncode(b'\x81\x0a', 10, 64 * 1024) + self.assertEncode(b"\x81\x0a", 10, 64 * 1024) def test_decode_no_length(self): # If length is 0, it is interpreted as 64KiB # The shortest possible instruction is a copy of 64KiB from offset 0 - self.assertDecode(0, 65536, 1, b'\x80', 0) - self.assertDecode(1, 65536, 2, b'\x81\x01', 0) - self.assertDecode(10, 65536, 2, b'\x81\x0a', 0) - self.assertDecode(255, 65536, 2, b'\x81\xff', 0) - self.assertDecode(256, 65536, 2, b'\x82\x01', 0) - self.assertDecode(257, 65536, 3, b'\x83\x01\x01', 0) - self.assertDecode(0xFFFFFFFF, 65536, 5, b'\x8F\xff\xff\xff\xff', 0) - self.assertDecode(0xFFFFFF00, 65536, 4, 
b'\x8E\xff\xff\xff', 0) - self.assertDecode(0xFFFF00FF, 65536, 4, b'\x8D\xff\xff\xff', 0) - self.assertDecode(0xFF00FFFF, 65536, 4, b'\x8B\xff\xff\xff', 0) - self.assertDecode(0x00FFFFFF, 65536, 4, b'\x87\xff\xff\xff', 0) - self.assertDecode(0x01020304, 65536, 5, b'\x8F\x04\x03\x02\x01', 0) + self.assertDecode(0, 65536, 1, b"\x80", 0) + self.assertDecode(1, 65536, 2, b"\x81\x01", 0) + self.assertDecode(10, 65536, 2, b"\x81\x0a", 0) + self.assertDecode(255, 65536, 2, b"\x81\xff", 0) + self.assertDecode(256, 65536, 2, b"\x82\x01", 0) + self.assertDecode(257, 65536, 3, b"\x83\x01\x01", 0) + self.assertDecode(0xFFFFFFFF, 65536, 5, b"\x8F\xff\xff\xff\xff", 0) + self.assertDecode(0xFFFFFF00, 65536, 4, b"\x8E\xff\xff\xff", 0) + self.assertDecode(0xFFFF00FF, 65536, 4, b"\x8D\xff\xff\xff", 0) + self.assertDecode(0xFF00FFFF, 65536, 4, b"\x8B\xff\xff\xff", 0) + self.assertDecode(0x00FFFFFF, 65536, 4, b"\x87\xff\xff\xff", 0) + self.assertDecode(0x01020304, 65536, 5, b"\x8F\x04\x03\x02\x01", 0) def test_decode_no_offset(self): - self.assertDecode(0, 1, 2, b'\x90\x01', 0) - self.assertDecode(0, 10, 2, b'\x90\x0a', 0) - self.assertDecode(0, 255, 2, b'\x90\xff', 0) - self.assertDecode(0, 256, 2, b'\xA0\x01', 0) - self.assertDecode(0, 257, 3, b'\xB0\x01\x01', 0) - self.assertDecode(0, 65535, 3, b'\xB0\xff\xff', 0) + self.assertDecode(0, 1, 2, b"\x90\x01", 0) + self.assertDecode(0, 10, 2, b"\x90\x0a", 0) + self.assertDecode(0, 255, 2, b"\x90\xff", 0) + self.assertDecode(0, 256, 2, b"\xA0\x01", 0) + self.assertDecode(0, 257, 3, b"\xB0\x01\x01", 0) + self.assertDecode(0, 65535, 3, b"\xB0\xff\xff", 0) # Special case, if copy == 64KiB, then we store exactly 0 # Note that this puns with a copy of exactly 0 bytes, but we don't care # about that, as we would never actually copy 0 bytes - self.assertDecode(0, 65536, 1, b'\x80', 0) + self.assertDecode(0, 65536, 1, b"\x80", 0) def test_decode(self): - self.assertDecode(1, 1, 3, b'\x91\x01\x01', 0) - self.assertDecode(9, 10, 3, b'\x91\x09\x0a', 0) - self.assertDecode(254, 255, 3, b'\x91\xfe\xff', 0) - self.assertDecode(512, 256, 3, b'\xA2\x02\x01', 0) - self.assertDecode(258, 257, 5, b'\xB3\x02\x01\x01\x01', 0) - self.assertDecode(0, 257, 3, b'\xB0\x01\x01', 0) + self.assertDecode(1, 1, 3, b"\x91\x01\x01", 0) + self.assertDecode(9, 10, 3, b"\x91\x09\x0a", 0) + self.assertDecode(254, 255, 3, b"\x91\xfe\xff", 0) + self.assertDecode(512, 256, 3, b"\xA2\x02\x01", 0) + self.assertDecode(258, 257, 5, b"\xB3\x02\x01\x01\x01", 0) + self.assertDecode(0, 257, 3, b"\xB0\x01\x01", 0) def test_decode_not_start(self): - self.assertDecode(1, 1, 6, b'abc\x91\x01\x01def', 3) - self.assertDecode(9, 10, 5, b'ab\x91\x09\x0ade', 2) - self.assertDecode(254, 255, 6, b'not\x91\xfe\xffcopy', 3) + self.assertDecode(1, 1, 6, b"abc\x91\x01\x01def", 3) + self.assertDecode(9, 10, 5, b"ab\x91\x09\x0ade", 2) + self.assertDecode(254, 255, 6, b"not\x91\xfe\xffcopy", 3) diff --git a/breezy/bzr/tests/test__simple_set.py b/breezy/bzr/tests/test__simple_set.py index cc0a911bdf..2a07c9eaed 100644 --- a/breezy/bzr/tests/test__simple_set.py +++ b/breezy/bzr/tests/test__simple_set.py @@ -47,7 +47,6 @@ def __eq__(self, other): class _BadSecondHash(_Hashable): - def __init__(self, the_hash): _Hashable.__init__(self, the_hash) self._first = True @@ -57,19 +56,17 @@ def __hash__(self): self._first = False return self.hash else: - raise ValueError('I can only be hashed once.') + raise ValueError("I can only be hashed once.") class _BadCompare(_Hashable): - def __eq__(self, other): - raise RuntimeError('I refuse to 
play nice') + raise RuntimeError("I refuse to play nice") __hash__ = _Hashable.__hash__ class _NoImplementCompare(_Hashable): - def __eq__(self, other): return NotImplemented @@ -79,11 +76,11 @@ def __eq__(self, other): # Even though this is an extension, we don't permute the tests for a python # version. As the plain python version is just a dict or set compiled_simpleset_feature = features.ModuleAvailableFeature( - 'breezy.bzr._simple_set_pyx') + "breezy.bzr._simple_set_pyx" +) class TestSimpleSet(tests.TestCase): - _test_needs_features = [compiled_simpleset_feature] module = _simple_set_pyx @@ -112,25 +109,25 @@ def assertRefcount(self, count, obj): def test_initial(self): obj = self.module.SimpleSet() self.assertEqual(0, len(obj)) - self.assertFillState(0, 0, 0x3ff, obj) + self.assertFillState(0, 0, 0x3FF, obj) def test__lookup(self): # These are carefully chosen integers to force hash collisions in the # algorithm, based on the initial set size of 1024 obj = self.module.SimpleSet() - self.assertLookup(643, '', obj, _Hashable(643)) - self.assertLookup(643, '', obj, _Hashable(643 + 1024)) - self.assertLookup(643, '', obj, _Hashable(643 + 50 * 1024)) + self.assertLookup(643, "", obj, _Hashable(643)) + self.assertLookup(643, "", obj, _Hashable(643 + 1024)) + self.assertLookup(643, "", obj, _Hashable(643 + 50 * 1024)) def test__lookup_collision(self): obj = self.module.SimpleSet() k1 = _Hashable(643) k2 = _Hashable(643 + 1024) - self.assertLookup(643, '', obj, k1) - self.assertLookup(643, '', obj, k2) + self.assertLookup(643, "", obj, k1) + self.assertLookup(643, "", obj, k2) obj.add(k1) self.assertLookup(643, k1, obj, k1) - self.assertLookup(644, '', obj, k2) + self.assertLookup(644, "", obj, k2) def test__lookup_after_resize(self): obj = self.module.SimpleSet() @@ -165,38 +162,38 @@ def test_get_set_del_with_collisions(self): k4 = _Hashable(h4) k5 = _Hashable(h5) k6 = _Hashable(h6) - self.assertLookup(643, '', obj, k1) - self.assertLookup(643, '', obj, k2) - self.assertLookup(643, '', obj, k3) - self.assertLookup(643, '', obj, k4) - self.assertLookup(644, '', obj, k5) - self.assertLookup(644, '', obj, k6) + self.assertLookup(643, "", obj, k1) + self.assertLookup(643, "", obj, k2) + self.assertLookup(643, "", obj, k3) + self.assertLookup(643, "", obj, k4) + self.assertLookup(644, "", obj, k5) + self.assertLookup(644, "", obj, k6) obj.add(k1) self.assertIn(k1, obj) self.assertNotIn(k2, obj) self.assertNotIn(k3, obj) self.assertNotIn(k4, obj) self.assertLookup(643, k1, obj, k1) - self.assertLookup(644, '', obj, k2) - self.assertLookup(644, '', obj, k3) - self.assertLookup(644, '', obj, k4) - self.assertLookup(644, '', obj, k5) - self.assertLookup(644, '', obj, k6) + self.assertLookup(644, "", obj, k2) + self.assertLookup(644, "", obj, k3) + self.assertLookup(644, "", obj, k4) + self.assertLookup(644, "", obj, k5) + self.assertLookup(644, "", obj, k6) self.assertIs(k1, obj[k1]) self.assertIs(k2, obj.add(k2)) self.assertIs(k2, obj[k2]) self.assertLookup(643, k1, obj, k1) self.assertLookup(644, k2, obj, k2) - self.assertLookup(646, '', obj, k3) - self.assertLookup(646, '', obj, k4) - self.assertLookup(645, '', obj, k5) - self.assertLookup(645, '', obj, k6) + self.assertLookup(646, "", obj, k3) + self.assertLookup(646, "", obj, k4) + self.assertLookup(645, "", obj, k5) + self.assertLookup(645, "", obj, k6) self.assertLookup(643, k1, obj, _Hashable(h1)) self.assertLookup(644, k2, obj, _Hashable(h2)) - self.assertLookup(646, '', obj, _Hashable(h3)) - self.assertLookup(646, '', obj, _Hashable(h4)) 
- self.assertLookup(645, '', obj, _Hashable(h5)) - self.assertLookup(645, '', obj, _Hashable(h6)) + self.assertLookup(646, "", obj, _Hashable(h3)) + self.assertLookup(646, "", obj, _Hashable(h4)) + self.assertLookup(645, "", obj, _Hashable(h5)) + self.assertLookup(645, "", obj, _Hashable(h6)) obj.add(k3) self.assertIs(k3, obj[k3]) self.assertIn(k1, obj) @@ -205,10 +202,10 @@ def test_get_set_del_with_collisions(self): self.assertNotIn(k4, obj) obj.discard(k1) - self.assertLookup(643, '', obj, k1) + self.assertLookup(643, "", obj, k1) self.assertLookup(644, k2, obj, k2) self.assertLookup(646, k3, obj, k3) - self.assertLookup(643, '', obj, k4) + self.assertLookup(643, "", obj, k4) self.assertNotIn(k1, obj) self.assertIn(k2, obj) self.assertIn(k3, obj) @@ -216,26 +213,26 @@ def test_get_set_del_with_collisions(self): def test_add(self): obj = self.module.SimpleSet() - self.assertFillState(0, 0, 0x3ff, obj) + self.assertFillState(0, 0, 0x3FF, obj) # We use this clumsy notation, because otherwise the refcounts are off. # I'm guessing the python compiler sees it is a static tuple, and adds # it to the function variables, or somesuch - k1 = tuple(['foo']) # noqa: C409 + k1 = tuple(["foo"]) # noqa: C409 self.assertRefcount(1, k1) self.assertIs(k1, obj.add(k1)) - self.assertFillState(1, 1, 0x3ff, obj) + self.assertFillState(1, 1, 0x3FF, obj) self.assertRefcount(2, k1) ktest = obj[k1] self.assertRefcount(3, k1) self.assertIs(k1, ktest) del ktest self.assertRefcount(2, k1) - k2 = tuple(['foo']) # noqa: C409 + k2 = tuple(["foo"]) # noqa: C409 self.assertRefcount(1, k2) self.assertIsNot(k1, k2) # doesn't add anything, so the counters shouldn't be adjusted self.assertIs(k1, obj.add(k2)) - self.assertFillState(1, 1, 0x3ff, obj) + self.assertFillState(1, 1, 0x3FF, obj) self.assertRefcount(2, k1) # not changed self.assertRefcount(1, k2) # not incremented self.assertIs(k1, obj[k1]) @@ -244,24 +241,24 @@ def test_add(self): self.assertRefcount(1, k2) # Deleting an entry should remove the fill, but not the used obj.discard(k1) - self.assertFillState(0, 1, 0x3ff, obj) + self.assertFillState(0, 1, 0x3FF, obj) self.assertRefcount(1, k1) - k3 = tuple(['bar']) # noqa: C409 + k3 = tuple(["bar"]) # noqa: C409 self.assertRefcount(1, k3) self.assertIs(k3, obj.add(k3)) - self.assertFillState(1, 2, 0x3ff, obj) + self.assertFillState(1, 2, 0x3FF, obj) self.assertRefcount(2, k3) self.assertIs(k2, obj.add(k2)) - self.assertFillState(2, 2, 0x3ff, obj) + self.assertFillState(2, 2, 0x3FF, obj) self.assertRefcount(1, k1) self.assertRefcount(2, k2) self.assertRefcount(2, k3) def test_discard(self): obj = self.module.SimpleSet() - k1 = tuple(['foo']) # noqa: C409 - k2 = tuple(['foo']) # noqa: C409 - k3 = tuple(['bar']) # noqa: C409 + k1 = tuple(["foo"]) # noqa: C409 + k2 = tuple(["foo"]) # noqa: C409 + k3 = tuple(["bar"]) # noqa: C409 self.assertRefcount(1, k1) self.assertRefcount(1, k2) self.assertRefcount(1, k3) @@ -284,26 +281,26 @@ def test__resize(self): obj.add(k2) obj.add(k3) obj.discard(k2) - self.assertFillState(2, 3, 0x3ff, obj) + self.assertFillState(2, 3, 0x3FF, obj) self.assertEqual(1024, obj._py_resize(500)) # Doesn't change the size, but does change the content - self.assertFillState(2, 2, 0x3ff, obj) + self.assertFillState(2, 2, 0x3FF, obj) obj.add(k2) obj.discard(k3) - self.assertFillState(2, 3, 0x3ff, obj) + self.assertFillState(2, 3, 0x3FF, obj) self.assertEqual(4096, obj._py_resize(4095)) - self.assertFillState(2, 2, 0xfff, obj) + self.assertFillState(2, 2, 0xFFF, obj) self.assertIn(k1, obj) self.assertIn(k2, 
obj) self.assertNotIn(k3, obj) obj.add(k2) self.assertIn(k2, obj) obj.discard(k2) - self.assertEqual((591, ''), obj._test_lookup(k2)) - self.assertFillState(1, 2, 0xfff, obj) + self.assertEqual((591, ""), obj._test_lookup(k2)) + self.assertFillState(1, 2, 0xFFF, obj) self.assertEqual(2048, obj._py_resize(1024)) - self.assertFillState(1, 1, 0x7ff, obj) - self.assertEqual((591, ''), obj._test_lookup(k2)) + self.assertFillState(1, 1, 0x7FF, obj) + self.assertEqual((591, ""), obj._test_lookup(k2)) def test_second_hash_failure(self): obj = self.module.SimpleSet() @@ -328,24 +325,23 @@ def test_richcompare_not_implemented(self): # NotImplemented, which means we treat them as not equal k1 = _NoImplementCompare(200) k2 = _NoImplementCompare(200) - self.assertLookup(200, '', obj, k1) - self.assertLookup(200, '', obj, k2) + self.assertLookup(200, "", obj, k1) + self.assertLookup(200, "", obj, k2) self.assertIs(k1, obj.add(k1)) self.assertLookup(200, k1, obj, k1) - self.assertLookup(201, '', obj, k2) + self.assertLookup(201, "", obj, k2) self.assertIs(k2, obj.add(k2)) self.assertIs(k1, obj[k1]) def test_add_and_remove_lots_of_items(self): obj = self.module.SimpleSet() - chars = ('ABCDEFGHIJKLMNOPQRSTUVWXYZ' - 'abcdefghijklmnopqrstuvwxyz1234567890') + chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" "abcdefghijklmnopqrstuvwxyz1234567890" for i in chars: for j in chars: k = (i, j) obj.add(k) num = len(chars) * len(chars) - self.assertFillState(num, num, 0x1fff, obj) + self.assertFillState(num, num, 0x1FFF, obj) # Now delete all of the entries and it should shrink again for i in chars: for j in chars: @@ -353,15 +349,15 @@ def test_add_and_remove_lots_of_items(self): obj.discard(k) # It should be back to 1024 wide mask, though there may still be some # dummy values in there - self.assertFillState(0, obj.fill, 0x3ff, obj) + self.assertFillState(0, obj.fill, 0x3FF, obj) # but there should be fewer than 1/5th dummy entries self.assertLess(obj.fill, 1024 / 5) def test__iter__(self): obj = self.module.SimpleSet() - k1 = ('1',) - k2 = ('1', '2') - k3 = ('3', '4') + k1 = ("1",) + k2 = ("1", "2") + k3 = ("3", "4") obj.add(k1) obj.add(k2) obj.add(k3) @@ -371,7 +367,7 @@ def test__iter__(self): self.assertEqual(sorted([k1, k2, k3]), sorted(all)) iterator = iter(obj) self.assertIn(next(iterator), all) - obj.add(('foo',)) + obj.add(("foo",)) # Set changed size self.assertRaises(RuntimeError, next, iterator) # And even removing an item still causes it to fail diff --git a/breezy/bzr/tests/test__static_tuple.py b/breezy/bzr/tests/test__static_tuple.py index d84d454653..c12e182852 100644 --- a/breezy/bzr/tests/test__static_tuple.py +++ b/breezy/bzr/tests/test__static_tuple.py @@ -29,13 +29,15 @@ def load_tests(loader, standard_tests, pattern): """Parameterize tests for all versions of groupcompress.""" global compiled_static_tuple_feature suite, compiled_static_tuple_feature = tests.permute_tests_for_extension( - standard_tests, loader, 'breezy.bzr._static_tuple_py', - 'breezy.bzr._static_tuple_c') + standard_tests, + loader, + "breezy.bzr._static_tuple_py", + "breezy.bzr._static_tuple_c", + ) return suite class TestStaticTuple(tests.TestCase): - def assertRefcount(self, count, obj): """Assert that the refcount for obj is what we expect. 
@@ -54,31 +56,31 @@ def assertRefcount(self, count, obj): self.assertEqual(count, sys.getrefcount(obj) - 2) def test_create(self): - self.module.StaticTuple('foo') - self.module.StaticTuple('foo', 'bar') + self.module.StaticTuple("foo") + self.module.StaticTuple("foo", "bar") def test_create_bad_args(self): - args_256 = ['a'] * 256 + args_256 = ["a"] * 256 # too many args self.assertRaises(TypeError, self.module.StaticTuple, *args_256) - args_300 = ['a'] * 300 + args_300 = ["a"] * 300 self.assertRaises(TypeError, self.module.StaticTuple, *args_300) # not a string self.assertRaises(TypeError, self.module.StaticTuple, object()) def test_concat(self): - st1 = self.module.StaticTuple('foo') - st2 = self.module.StaticTuple('bar') - st3 = self.module.StaticTuple('foo', 'bar') + st1 = self.module.StaticTuple("foo") + st2 = self.module.StaticTuple("bar") + st3 = self.module.StaticTuple("foo", "bar") st4 = st1 + st2 self.assertEqual(st3, st4) self.assertIsInstance(st4, self.module.StaticTuple) def test_concat_with_tuple(self): - st1 = self.module.StaticTuple('foo') - t2 = ('bar',) - st3 = self.module.StaticTuple('foo', 'bar') - st4 = self.module.StaticTuple('bar', 'foo') + st1 = self.module.StaticTuple("foo") + t2 = ("bar",) + st3 = self.module.StaticTuple("foo", "bar") + st4 = self.module.StaticTuple("bar", "foo") st5 = st1 + t2 st6 = t2 + st1 self.assertEqual(st3, st5) @@ -93,85 +95,85 @@ def test_concat_with_tuple(self): self.assertIsInstance(st6, self.module.StaticTuple) def test_concat_with_bad_tuple(self): - st1 = self.module.StaticTuple('foo') + st1 = self.module.StaticTuple("foo") t2 = (object(),) # Using st1.__add__ doesn't give the same results as doing the '+' form self.assertRaises(TypeError, lambda: st1 + t2) def test_concat_with_non_tuple(self): - st1 = self.module.StaticTuple('foo') + st1 = self.module.StaticTuple("foo") self.assertRaises(TypeError, lambda: st1 + 10) def test_as_tuple(self): - k = self.module.StaticTuple('foo') + k = self.module.StaticTuple("foo") t = k.as_tuple() - self.assertEqual(('foo',), t) + self.assertEqual(("foo",), t) self.assertIsInstance(t, tuple) self.assertNotIsInstance(t, self.module.StaticTuple) - k = self.module.StaticTuple('foo', 'bar') + k = self.module.StaticTuple("foo", "bar") t = k.as_tuple() - self.assertEqual(('foo', 'bar'), t) + self.assertEqual(("foo", "bar"), t) k2 = self.module.StaticTuple(1, k) t = k2.as_tuple() self.assertIsInstance(t, tuple) # For pickling to work, we need to keep the sub-items as StaticTuple so # that it knows that they also need to be converted. 
self.assertIsInstance(t[1], self.module.StaticTuple) - self.assertEqual((1, ('foo', 'bar')), t) + self.assertEqual((1, ("foo", "bar")), t) def test_as_tuples(self): - k1 = self.module.StaticTuple('foo', 'bar') + k1 = self.module.StaticTuple("foo", "bar") t = static_tuple.as_tuples(k1) self.assertIsInstance(t, tuple) - self.assertEqual(('foo', 'bar'), t) + self.assertEqual(("foo", "bar"), t) k2 = self.module.StaticTuple(1, k1) t = static_tuple.as_tuples(k2) self.assertIsInstance(t, tuple) self.assertIsInstance(t[1], tuple) - self.assertEqual((1, ('foo', 'bar')), t) + self.assertEqual((1, ("foo", "bar")), t) mixed = (1, k1) t = static_tuple.as_tuples(mixed) self.assertIsInstance(t, tuple) self.assertIsInstance(t[1], tuple) - self.assertEqual((1, ('foo', 'bar')), t) + self.assertEqual((1, ("foo", "bar")), t) def test_len(self): k = self.module.StaticTuple() self.assertEqual(0, len(k)) - k = self.module.StaticTuple('foo') + k = self.module.StaticTuple("foo") self.assertEqual(1, len(k)) - k = self.module.StaticTuple('foo', 'bar') + k = self.module.StaticTuple("foo", "bar") self.assertEqual(2, len(k)) - k = self.module.StaticTuple('foo', 'bar', 'b', 'b', 'b', 'b', 'b') + k = self.module.StaticTuple("foo", "bar", "b", "b", "b", "b", "b") self.assertEqual(7, len(k)) - args = ['foo'] * 255 + args = ["foo"] * 255 k = self.module.StaticTuple(*args) self.assertEqual(255, len(k)) def test_hold_other_static_tuples(self): - k = self.module.StaticTuple('foo', 'bar') + k = self.module.StaticTuple("foo", "bar") k2 = self.module.StaticTuple(k, k) self.assertEqual(2, len(k2)) self.assertIs(k, k2[0]) self.assertIs(k, k2[1]) def test_getitem(self): - k = self.module.StaticTuple('foo', 'bar', 'b', 'b', 'b', 'b', 'z') - self.assertEqual('foo', k[0]) - self.assertEqual('foo', k[0]) - self.assertEqual('foo', k[0]) - self.assertEqual('z', k[6]) - self.assertEqual('z', k[-1]) + k = self.module.StaticTuple("foo", "bar", "b", "b", "b", "b", "z") + self.assertEqual("foo", k[0]) + self.assertEqual("foo", k[0]) + self.assertEqual("foo", k[0]) + self.assertEqual("z", k[6]) + self.assertEqual("z", k[-1]) self.assertRaises(IndexError, k.__getitem__, 7) self.assertRaises(IndexError, k.__getitem__, 256 + 7) self.assertRaises(IndexError, k.__getitem__, 12024) # Python's [] resolver handles the negative arguments, so we can't # really test StaticTuple_item() with negative values. 
- self.assertRaises(TypeError, k.__getitem__, 'not-an-int') - self.assertRaises(TypeError, k.__getitem__, '5') + self.assertRaises(TypeError, k.__getitem__, "not-an-int") + self.assertRaises(TypeError, k.__getitem__, "5") def test_refcount(self): - f = 'fo' + 'oo' + f = "fo" + "oo" num_refs = sys.getrefcount(f) - 1 # sys.getrefcount() adds one k = self.module.StaticTuple(f) self.assertRefcount(num_refs + 1, f) @@ -187,7 +189,7 @@ def test_refcount(self): self.assertRefcount(num_refs, f) def test__repr__(self): - k = self.module.StaticTuple('foo', 'bar', 'baz', 'bing') + k = self.module.StaticTuple("foo", "bar", "baz", "bing") self.assertEqual("StaticTuple('foo', 'bar', 'baz', 'bing')", repr(k)) def assertCompareEqual(self, k1, k2): @@ -207,6 +209,7 @@ def test_holds_int(self): class subint(int): pass + # But not a subclass, because subint could introduce refcycles self.assertRaises(TypeError, self.module.StaticTuple, subint(2)) @@ -215,22 +218,24 @@ def test_holds_float(self): class subfloat(float): pass + self.assertRaises(TypeError, self.module.StaticTuple, subfloat(1.5)) def test_holds_bytes(self): - self.module.StaticTuple(b'astring') + self.module.StaticTuple(b"astring") class substr(bytes): pass - self.assertRaises(TypeError, self.module.StaticTuple, substr(b'a')) + + self.assertRaises(TypeError, self.module.StaticTuple, substr(b"a")) def test_holds_unicode(self): - self.module.StaticTuple('\xb5') + self.module.StaticTuple("\xb5") class subunicode(str): pass - self.assertRaises(TypeError, self.module.StaticTuple, - subunicode('\xb5')) + + self.assertRaises(TypeError, self.module.StaticTuple, subunicode("\xb5")) def test_hold_bool(self): self.module.StaticTuple(True) @@ -238,36 +243,33 @@ def test_hold_bool(self): # Cannot subclass bool def test_compare_same_obj(self): - k1 = self.module.StaticTuple('foo', 'bar') + k1 = self.module.StaticTuple("foo", "bar") self.assertCompareEqual(k1, k1) k2 = self.module.StaticTuple(k1, k1) self.assertCompareEqual(k2, k2) - k3 = self.module.StaticTuple('foo', 1, None, '\xb5', 1.2, 2**65, True, - k1) + k3 = self.module.StaticTuple("foo", 1, None, "\xb5", 1.2, 2**65, True, k1) self.assertCompareEqual(k3, k3) def test_compare_equivalent_obj(self): - k1 = self.module.StaticTuple('foo', 'bar') - k2 = self.module.StaticTuple('foo', 'bar') + k1 = self.module.StaticTuple("foo", "bar") + k2 = self.module.StaticTuple("foo", "bar") self.assertCompareEqual(k1, k2) self.module.StaticTuple(k1, k2) self.module.StaticTuple(k2, k1) self.assertCompareEqual(k1, k2) - k5 = self.module.StaticTuple('foo', 1, None, '\xb5', 1.2, 2**65, True, - k1) - k6 = self.module.StaticTuple('foo', 1, None, '\xb5', 1.2, 2**65, True, - k1) + k5 = self.module.StaticTuple("foo", 1, None, "\xb5", 1.2, 2**65, True, k1) + k6 = self.module.StaticTuple("foo", 1, None, "\xb5", 1.2, 2**65, True, k1) self.assertCompareEqual(k5, k6) k7 = self.module.StaticTuple(None) k8 = self.module.StaticTuple(None) self.assertCompareEqual(k7, k8) def test_compare_similar_obj(self): - k1 = self.module.StaticTuple('foo' + ' bar', 'bar' + ' baz') - k2 = self.module.StaticTuple('fo' + 'o bar', 'ba' + 'r baz') + k1 = self.module.StaticTuple("foo" + " bar", "bar" + " baz") + k2 = self.module.StaticTuple("fo" + "o bar", "ba" + "r baz") self.assertCompareEqual(k1, k2) - k3 = self.module.StaticTuple('foo ' + 'bar', 'bar ' + 'baz') - k4 = self.module.StaticTuple('f' + 'oo bar', 'b' + 'ar baz') + k3 = self.module.StaticTuple("foo " + "bar", "bar " + "baz") + k4 = self.module.StaticTuple("f" + "oo bar", "b" + "ar baz") k5 = 
self.module.StaticTuple(k1, k2) k6 = self.module.StaticTuple(k3, k4) self.assertCompareEqual(k5, k6) @@ -306,17 +308,17 @@ def assertCompareNoRelation(self, k1, k2, mismatched_types=False): k1 < k2 # noqa: B015 def test_compare_vs_none(self): - k1 = self.module.StaticTuple('baz', 'bing') + k1 = self.module.StaticTuple("baz", "bing") self.assertCompareDifferent(None, k1, mismatched_types=True) def test_compare_cross_class(self): - k1 = self.module.StaticTuple('baz', 'bing') + k1 = self.module.StaticTuple("baz", "bing") self.assertCompareNoRelation(10, k1, mismatched_types=True) - self.assertCompareNoRelation('baz', k1, mismatched_types=True) + self.assertCompareNoRelation("baz", k1, mismatched_types=True) def test_compare_all_different_same_width(self): - k1 = self.module.StaticTuple('baz', 'bing') - k2 = self.module.StaticTuple('foo', 'bar') + k1 = self.module.StaticTuple("baz", "bing") + k2 = self.module.StaticTuple("foo", "bar") self.assertCompareDifferent(k1, k2) k3 = self.module.StaticTuple(k1, k2) k4 = self.module.StaticTuple(k2, k1) @@ -327,35 +329,34 @@ def test_compare_all_different_same_width(self): k7 = self.module.StaticTuple(1.2) k8 = self.module.StaticTuple(2.4) self.assertCompareDifferent(k7, k8) - k9 = self.module.StaticTuple('s\xb5') - k10 = self.module.StaticTuple('s\xe5') + k9 = self.module.StaticTuple("s\xb5") + k10 = self.module.StaticTuple("s\xe5") self.assertCompareDifferent(k9, k10) def test_compare_some_different(self): - k1 = self.module.StaticTuple('foo', 'bar') - k2 = self.module.StaticTuple('foo', 'zzz') + k1 = self.module.StaticTuple("foo", "bar") + k2 = self.module.StaticTuple("foo", "zzz") self.assertCompareDifferent(k1, k2) k3 = self.module.StaticTuple(k1, k1) k4 = self.module.StaticTuple(k1, k2) self.assertCompareDifferent(k3, k4) - k5 = self.module.StaticTuple('foo', None) + k5 = self.module.StaticTuple("foo", None) self.assertCompareDifferent(k5, k1, mismatched_types=True) self.assertCompareDifferent(k5, k2, mismatched_types=True) def test_compare_diff_width(self): - k1 = self.module.StaticTuple('foo') - k2 = self.module.StaticTuple('foo', 'bar') + k1 = self.module.StaticTuple("foo") + k2 = self.module.StaticTuple("foo", "bar") self.assertCompareDifferent(k1, k2) k3 = self.module.StaticTuple(k1) k4 = self.module.StaticTuple(k1, k2) self.assertCompareDifferent(k3, k4) def test_compare_different_types(self): - k1 = self.module.StaticTuple('foo', 'bar') - k2 = self.module.StaticTuple('foo', 1, None, '\xb5', 1.2, 2**65, True, - k1) + k1 = self.module.StaticTuple("foo", "bar") + k2 = self.module.StaticTuple("foo", 1, None, "\xb5", 1.2, 2**65, True, k1) self.assertCompareNoRelation(k1, k2, mismatched_types=True) - k3 = self.module.StaticTuple('foo') + k3 = self.module.StaticTuple("foo") self.assertCompareDifferent(k3, k1) k4 = self.module.StaticTuple(None) self.assertCompareDifferent(k4, k1, mismatched_types=True) @@ -363,65 +364,88 @@ def test_compare_different_types(self): self.assertCompareNoRelation(k1, k5, mismatched_types=True) def test_compare_to_tuples(self): - k1 = self.module.StaticTuple('foo') - self.assertCompareEqual(k1, ('foo',)) - self.assertCompareEqual(('foo',), k1) - self.assertCompareDifferent(k1, ('foo', 'bar')) - self.assertCompareDifferent(k1, ('foo', 10)) - - k2 = self.module.StaticTuple('foo', 'bar') - self.assertCompareEqual(k2, ('foo', 'bar')) - self.assertCompareEqual(('foo', 'bar'), k2) - self.assertCompareDifferent(k2, ('foo', 'zzz')) - self.assertCompareDifferent(('foo',), k2) - self.assertCompareDifferent(('foo', 'aaa'), k2) - 
self.assertCompareDifferent(('baz', 'bing'), k2) - self.assertCompareDifferent(('foo', 10), k2, mismatched_types=True) + k1 = self.module.StaticTuple("foo") + self.assertCompareEqual(k1, ("foo",)) + self.assertCompareEqual(("foo",), k1) + self.assertCompareDifferent(k1, ("foo", "bar")) + self.assertCompareDifferent(k1, ("foo", 10)) + + k2 = self.module.StaticTuple("foo", "bar") + self.assertCompareEqual(k2, ("foo", "bar")) + self.assertCompareEqual(("foo", "bar"), k2) + self.assertCompareDifferent(k2, ("foo", "zzz")) + self.assertCompareDifferent(("foo",), k2) + self.assertCompareDifferent(("foo", "aaa"), k2) + self.assertCompareDifferent(("baz", "bing"), k2) + self.assertCompareDifferent(("foo", 10), k2, mismatched_types=True) k3 = self.module.StaticTuple(k1, k2) - self.assertCompareEqual(k3, (('foo',), ('foo', 'bar'))) - self.assertCompareEqual((('foo',), ('foo', 'bar')), k3) - self.assertCompareEqual(k3, (k1, ('foo', 'bar'))) - self.assertCompareEqual((k1, ('foo', 'bar')), k3) + self.assertCompareEqual(k3, (("foo",), ("foo", "bar"))) + self.assertCompareEqual((("foo",), ("foo", "bar")), k3) + self.assertCompareEqual(k3, (k1, ("foo", "bar"))) + self.assertCompareEqual((k1, ("foo", "bar")), k3) def test_compare_mixed_depths(self): stuple = self.module.StaticTuple - k1 = stuple(stuple('a',), stuple('b',)) - k2 = stuple(stuple(stuple('c',), stuple('d',)), - stuple('b',)) + k1 = stuple( + stuple( + "a", + ), + stuple( + "b", + ), + ) + k2 = stuple( + stuple( + stuple( + "c", + ), + stuple( + "d", + ), + ), + stuple( + "b", + ), + ) # This requires comparing a StaticTuple to a 'string', and then # interpreting that value in the next higher StaticTuple. This used to # generate a PyErr_BadIternalCall. We now fall back to *something*. self.assertCompareNoRelation(k1, k2, mismatched_types=True) def test_hash(self): - k = self.module.StaticTuple('foo') - self.assertEqual(hash(k), hash(('foo',))) - k = self.module.StaticTuple('foo', 'bar', 'baz', 'bing') - as_tuple = ('foo', 'bar', 'baz', 'bing') + k = self.module.StaticTuple("foo") + self.assertEqual(hash(k), hash(("foo",))) + k = self.module.StaticTuple("foo", "bar", "baz", "bing") + as_tuple = ("foo", "bar", "baz", "bing") self.assertEqual(hash(k), hash(as_tuple)) - x = {k: 'foo'} + x = {k: "foo"} # Because k == , it replaces the slot, rather than having both # present in the dict. 
- self.assertEqual('foo', x[as_tuple]) - x[as_tuple] = 'bar' - self.assertEqual({as_tuple: 'bar'}, x) + self.assertEqual("foo", x[as_tuple]) + x[as_tuple] = "bar" + self.assertEqual({as_tuple: "bar"}, x) k2 = self.module.StaticTuple(k) - as_tuple2 = (('foo', 'bar', 'baz', 'bing'),) + as_tuple2 = (("foo", "bar", "baz", "bing"),) self.assertEqual(hash(k2), hash(as_tuple2)) - k3 = self.module.StaticTuple('foo', 1, None, '\xb5', 1.2, 2**65, True, - k) - as_tuple3 = ('foo', 1, None, '\xb5', 1.2, 2**65, True, k) + k3 = self.module.StaticTuple("foo", 1, None, "\xb5", 1.2, 2**65, True, k) + as_tuple3 = ("foo", 1, None, "\xb5", 1.2, 2**65, True, k) self.assertEqual(hash(as_tuple3), hash(k3)) def test_slice(self): - k = self.module.StaticTuple('foo', 'bar', 'baz', 'bing') - self.assertEqual(('foo', 'bar'), k[:2]) - self.assertEqual(('baz',), k[2:-1]) - self.assertEqual(('foo', 'baz',), k[::2]) - self.assertRaises(TypeError, k.__getitem__, 'not_slice') + k = self.module.StaticTuple("foo", "bar", "baz", "bing") + self.assertEqual(("foo", "bar"), k[:2]) + self.assertEqual(("baz",), k[2:-1]) + self.assertEqual( + ( + "foo", + "baz", + ), + k[::2], + ) + self.assertRaises(TypeError, k.__getitem__, "not_slice") def test_referents(self): # We implement tp_traverse so that things like 'meliae' can measure the @@ -430,32 +454,37 @@ def test_referents(self): # helper func, but that won't work for the generic implementation... self.requireFeature(features.meliae) from meliae import scanner - strs = ['foo', 'bar', 'baz', 'bing'] + + strs = ["foo", "bar", "baz", "bing"] k = self.module.StaticTuple(*strs) if self.module is _static_tuple_py: refs = strs + [self.module.StaticTuple] else: refs = strs + def key(k): if isinstance(k, type): return (0, k) if isinstance(k, str): return (1, k) raise TypeError(k) + self.assertEqual( - sorted(refs, key=key), - sorted(scanner.get_referents(k), key=key)) + sorted(refs, key=key), sorted(scanner.get_referents(k), key=key) + ) def test_nested_referents(self): self.requireFeature(features.meliae) from meliae import scanner - strs = ['foo', 'bar', 'baz', 'bing'] + + strs = ["foo", "bar", "baz", "bing"] k1 = self.module.StaticTuple(*strs[:2]) k2 = self.module.StaticTuple(*strs[2:]) k3 = self.module.StaticTuple(k1, k2) refs = [k1, k2] if self.module is _static_tuple_py: refs.append(self.module.StaticTuple) + def key(k): if isinstance(k, type): return (0, k) @@ -463,16 +492,17 @@ def key(k): return (1, k) raise TypeError(k) - self.assertEqual(sorted(refs, key=key), - sorted(scanner.get_referents(k3), key=key)) + self.assertEqual( + sorted(refs, key=key), sorted(scanner.get_referents(k3), key=key) + ) def test_empty_is_singleton(self): key = self.module.StaticTuple() self.assertIs(key, self.module._empty_tuple) def test_intern(self): - unique_str1 = 'unique str ' + osutils.rand_chars(20) - unique_str2 = 'unique str ' + osutils.rand_chars(20) + unique_str1 = "unique str " + osutils.rand_chars(20) + unique_str2 = "unique str " + osutils.rand_chars(20) key = self.module.StaticTuple(unique_str1, unique_str2) self.assertNotIn(key, self.module._interned_tuples) key2 = self.module.StaticTuple(unique_str1, unique_str2) @@ -488,8 +518,8 @@ def test_intern(self): def test__c_intern_handles_refcount(self): if self.module is _static_tuple_py: return # Not applicable - unique_str1 = 'unique str ' + osutils.rand_chars(20) - unique_str2 = 'unique str ' + osutils.rand_chars(20) + unique_str1 = "unique str " + osutils.rand_chars(20) + unique_str2 = "unique str " + osutils.rand_chars(20) key = 
self.module.StaticTuple(unique_str1, unique_str2) self.assertRefcount(1, key) self.assertNotIn(key, self.module._interned_tuples) @@ -523,8 +553,8 @@ def test__c_intern_handles_refcount(self): def test__c_keys_are_not_immortal(self): if self.module is _static_tuple_py: return # Not applicable - unique_str1 = 'unique str ' + osutils.rand_chars(20) - unique_str2 = 'unique str ' + osutils.rand_chars(20) + unique_str1 = "unique str " + osutils.rand_chars(20) + unique_str2 = "unique str " + osutils.rand_chars(20) key = self.module.StaticTuple(unique_str1, unique_str2) self.assertNotIn(key, self.module._interned_tuples) self.assertRefcount(1, key) @@ -546,53 +576,52 @@ def test__c_has_C_API(self): self.assertIsNot(None, self.module._C_API) def test_from_sequence_tuple(self): - st = self.module.StaticTuple.from_sequence(('foo', 'bar')) + st = self.module.StaticTuple.from_sequence(("foo", "bar")) self.assertIsInstance(st, self.module.StaticTuple) - self.assertEqual(('foo', 'bar'), st) + self.assertEqual(("foo", "bar"), st) def test_from_sequence_str(self): - st = self.module.StaticTuple.from_sequence('foo') + st = self.module.StaticTuple.from_sequence("foo") self.assertIsInstance(st, self.module.StaticTuple) - self.assertEqual(('f', 'o', 'o'), st) + self.assertEqual(("f", "o", "o"), st) def test_from_sequence_list(self): - st = self.module.StaticTuple.from_sequence(['foo', 'bar']) + st = self.module.StaticTuple.from_sequence(["foo", "bar"]) self.assertIsInstance(st, self.module.StaticTuple) - self.assertEqual(('foo', 'bar'), st) + self.assertEqual(("foo", "bar"), st) def test_from_sequence_static_tuple(self): - st = self.module.StaticTuple('foo', 'bar') + st = self.module.StaticTuple("foo", "bar") st2 = self.module.StaticTuple.from_sequence(st) # If the source is a StaticTuple already, we return the exact object self.assertIs(st, st2) def test_from_sequence_not_sequence(self): - self.assertRaises(TypeError, - self.module.StaticTuple.from_sequence, object()) - self.assertRaises(TypeError, - self.module.StaticTuple.from_sequence, 10) + self.assertRaises(TypeError, self.module.StaticTuple.from_sequence, object()) + self.assertRaises(TypeError, self.module.StaticTuple.from_sequence, 10) def test_from_sequence_incorrect_args(self): - self.assertRaises(TypeError, - self.module.StaticTuple.from_sequence, object(), 'a') - self.assertRaises(TypeError, - self.module.StaticTuple.from_sequence, foo='a') + self.assertRaises( + TypeError, self.module.StaticTuple.from_sequence, object(), "a" + ) + self.assertRaises(TypeError, self.module.StaticTuple.from_sequence, foo="a") def test_from_sequence_iterable(self): - st = self.module.StaticTuple.from_sequence(iter(['foo', 'bar'])) + st = self.module.StaticTuple.from_sequence(iter(["foo", "bar"])) self.assertIsInstance(st, self.module.StaticTuple) - self.assertEqual(('foo', 'bar'), st) + self.assertEqual(("foo", "bar"), st) def test_from_sequence_generator(self): def generate_tuple(): - yield 'foo' - yield 'bar' + yield "foo" + yield "bar" + st = self.module.StaticTuple.from_sequence(generate_tuple()) self.assertIsInstance(st, self.module.StaticTuple) - self.assertEqual(('foo', 'bar'), st) + self.assertEqual(("foo", "bar"), st) def test_pickle(self): - st = self.module.StaticTuple('foo', 'bar') + st = self.module.StaticTuple("foo", "bar") pickled = pickle.dumps(st) unpickled = pickle.loads(pickled) # noqa: S301 self.assertEqual(unpickled, st) @@ -604,7 +633,7 @@ def test_pickle_empty(self): self.assertIs(st, unpickled) def test_pickle_nested(self): - st = 
self.module.StaticTuple('foo', self.module.StaticTuple('bar')) + st = self.module.StaticTuple("foo", self.module.StaticTuple("bar")) pickled = pickle.dumps(st) unpickled = pickle.loads(pickled) # noqa: S301 self.assertEqual(unpickled, st) @@ -616,30 +645,28 @@ def test_static_tuple_thunk(self): if compiled_static_tuple_feature.available(): # We will be using the C version return - self.assertIs(static_tuple.StaticTuple, - self.module.StaticTuple) + self.assertIs(static_tuple.StaticTuple, self.module.StaticTuple) class TestEnsureStaticTuple(tests.TestCase): - def test_is_static_tuple(self): - st = static_tuple.StaticTuple('foo') + st = static_tuple.StaticTuple("foo") st2 = static_tuple.expect_static_tuple(st) self.assertIs(st, st2) def test_is_tuple(self): - t = ('foo',) + t = ("foo",) st = static_tuple.expect_static_tuple(t) self.assertIsInstance(st, static_tuple.StaticTuple) self.assertEqual(t, st) def test_flagged_is_static_tuple(self): - debug.set_debug_flag('static_tuple') - st = static_tuple.StaticTuple('foo') + debug.set_debug_flag("static_tuple") + st = static_tuple.StaticTuple("foo") st2 = static_tuple.expect_static_tuple(st) self.assertIs(st, st2) def test_flagged_is_tuple(self): - debug.set_debug_flag('static_tuple') - t = ('foo',) + debug.set_debug_flag("static_tuple") + t = ("foo",) self.assertRaises(TypeError, static_tuple.expect_static_tuple, t) diff --git a/breezy/bzr/tests/test_btree_index.py b/breezy/bzr/tests/test_btree_index.py index c79bb0bd4d..37b946a16b 100644 --- a/breezy/bzr/tests/test_btree_index.py +++ b/breezy/bzr/tests/test_btree_index.py @@ -30,15 +30,16 @@ def btreeparser_scenarios(): import breezy.bzr._btree_serializer_py as py_module - scenarios = [('python', {'parse_btree': py_module})] + + scenarios = [("python", {"parse_btree": py_module})] if compiled_btreeparser_feature.available(): - scenarios.append(('C', - {'parse_btree': compiled_btreeparser_feature.module})) + scenarios.append(("C", {"parse_btree": compiled_btreeparser_feature.module})) return scenarios compiled_btreeparser_feature = features.ModuleAvailableFeature( - 'breezy.bzr._btree_serializer_pyx') + "breezy.bzr._btree_serializer_pyx" +) class BTreeTestCase(TestCaseWithTransport): @@ -47,12 +48,14 @@ class BTreeTestCase(TestCaseWithTransport): def setUp(self): super().setUp() - self.overrideAttr(btree_index, '_RESERVED_HEADER_BYTES', 100) + self.overrideAttr(btree_index, "_RESERVED_HEADER_BYTES", 100) def make_nodes(self, count, key_elements, reference_lists): """Generate count*key_elements sample nodes.""" + def _pos_to_key(pos, lead=b""): return (lead + (b"%d" % pos) * 40,) + keys = [] for prefix_pos in range(key_elements): if key_elements - 1: @@ -76,12 +79,10 @@ def _pos_to_key(pos, lead=b""): for ref_pos in range(list_pos + pos % 2): if pos % 2: # refer to a nearby key - refs[-1].append(prefix - + _pos_to_key(pos - 1, b"ref")) + refs[-1].append(prefix + _pos_to_key(pos - 1, b"ref")) else: # serial of this ref in the ref list - refs[-1].append(prefix - + _pos_to_key(ref_pos, b"ref")) + refs[-1].append(prefix + _pos_to_key(ref_pos, b"ref")) refs[-1] = tuple(refs[-1]) refs = tuple(refs) else: @@ -91,7 +92,7 @@ def _pos_to_key(pos, lead=b""): def shrink_page_size(self): """Shrink the default page size so that less fits in a page.""" - self.overrideAttr(btree_index, '_PAGE_SIZE') + self.overrideAttr(btree_index, "_PAGE_SIZE") btree_index._PAGE_SIZE = 2048 def assertEqualApproxCompressed(self, expected, actual, slop=6): @@ -101,12 +102,12 @@ def assertEqualApproxCompressed(self, expected, 
actual, slop=6): slightly bogus, but zlib is stable enough that this mostly works. """ if not expected - slop < actual < expected + slop: - self.fail("Expected around %d bytes compressed but got %d" % - (expected, actual)) + self.fail( + "Expected around %d bytes compressed but got %d" % (expected, actual) + ) class TestBTreeBuilder(BTreeTestCase): - def test_clear_cache(self): builder = btree_index.BTreeBuilder(reference_lists=0, key_elements=1) # This is a no-op, but we need the api to be consistent with other @@ -122,7 +123,8 @@ def test_empty_1_0(self): self.assertEqual( b"B+Tree Graph Index 2\nnode_ref_lists=0\nkey_elements=1\nlen=0\n" b"row_lengths=\n", - content) + content, + ) def test_empty_2_1(self): builder = btree_index.BTreeBuilder(key_elements=2, reference_lists=1) @@ -133,7 +135,8 @@ def test_empty_2_1(self): self.assertEqual( b"B+Tree Graph Index 2\nnode_ref_lists=1\nkey_elements=2\nlen=0\n" b"row_lengths=\n", - content) + content, + ) def test_root_leaf_1_0(self): builder = btree_index.BTreeBuilder(key_elements=1, reference_lists=0) @@ -148,15 +151,18 @@ def test_root_leaf_1_0(self): self.assertEqual( b"B+Tree Graph Index 2\nnode_ref_lists=0\nkey_elements=1\nlen=5\n" b"row_lengths=1\n", - content[:73]) + content[:73], + ) node_content = content[73:] node_bytes = zlib.decompress(node_content) - expected_node = (b"type=leaf\n" - b"0000000000000000000000000000000000000000\x00\x00value:0\n" - b"1111111111111111111111111111111111111111\x00\x00value:1\n" - b"2222222222222222222222222222222222222222\x00\x00value:2\n" - b"3333333333333333333333333333333333333333\x00\x00value:3\n" - b"4444444444444444444444444444444444444444\x00\x00value:4\n") + expected_node = ( + b"type=leaf\n" + b"0000000000000000000000000000000000000000\x00\x00value:0\n" + b"1111111111111111111111111111111111111111\x00\x00value:1\n" + b"2222222222222222222222222222222222222222\x00\x00value:2\n" + b"3333333333333333333333333333333333333333\x00\x00value:3\n" + b"4444444444444444444444444444444444444444\x00\x00value:4\n" + ) self.assertEqual(expected_node, node_bytes) def test_root_leaf_2_2(self): @@ -172,7 +178,8 @@ def test_root_leaf_2_2(self): self.assertEqual( b"B+Tree Graph Index 2\nnode_ref_lists=2\nkey_elements=2\nlen=10\n" b"row_lengths=1\n", - content[:74]) + content[:74], + ) node_content = content[74:] node_bytes = zlib.decompress(node_content) expected_node = ( @@ -188,7 +195,7 @@ def test_root_leaf_2_2(self): b"1111111111111111111111111111111111111111\x003333333333333333333333333333333333333333\x001111111111111111111111111111111111111111\x00ref2222222222222222222222222222222222222222\t1111111111111111111111111111111111111111\x00ref2222222222222222222222222222222222222222\r1111111111111111111111111111111111111111\x00ref2222222222222222222222222222222222222222\x00value:3\n" b"1111111111111111111111111111111111111111\x004444444444444444444444444444444444444444\x00\t1111111111111111111111111111111111111111\x00ref0000000000000000000000000000000000000000\x00value:4\n" b"" - ) + ) self.assertEqual(expected_node, node_bytes) def test_2_leaves_1_0(self): @@ -204,15 +211,13 @@ def test_2_leaves_1_0(self): self.assertEqual( b"B+Tree Graph Index 2\nnode_ref_lists=0\nkey_elements=1\nlen=400\n" b"row_lengths=1,2\n", - content[:77]) + content[:77], + ) root = content[77:4096] leaf1 = content[4096:8192] leaf2 = content[8192:] root_bytes = zlib.decompress(root) - expected_root = ( - b"type=internal\n" - b"offset=0\n" - ) + (b"307" * 40) + b"\n" + expected_root = (b"type=internal\n" b"offset=0\n") + (b"307" * 40) + 
b"\n" self.assertEqual(expected_root, root_bytes) # We already know serialisation works for leaves, check key selection: leaf1_bytes = zlib.decompress(leaf1) @@ -238,7 +243,8 @@ def test_last_page_rounded_1_layer(self): self.assertEqual( b"B+Tree Graph Index 2\nnode_ref_lists=0\nkey_elements=1\nlen=10\n" b"row_lengths=1\n", - content[:74]) + content[:74], + ) # Check thelast page is well formed leaf2 = content[74:] leaf2_bytes = zlib.decompress(leaf2) @@ -260,7 +266,8 @@ def test_last_page_not_rounded_2_layer(self): self.assertEqual( b"B+Tree Graph Index 2\nnode_ref_lists=0\nkey_elements=1\nlen=400\n" b"row_lengths=1,2\n", - content[:77]) + content[:77], + ) # Check the last page is well formed leaf2 = content[8192:] leaf2_bytes = zlib.decompress(leaf2) @@ -282,14 +289,15 @@ def test_three_level_tree_details(self): for node in nodes: builder.add_node(*node) - t = transport.get_transport_from_url('trace+' + self.get_url('')) - size = t.put_file('index', self.time(builder.finish)) + t = transport.get_transport_from_url("trace+" + self.get_url("")) + size = t.put_file("index", self.time(builder.finish)) del builder - index = btree_index.BTreeGraphIndex(t, 'index', size) + index = btree_index.BTreeGraphIndex(t, "index", size) # Seed the metadata, we're using internal calls now. index.key_count() - self.assertEqual(3, len(index._row_lengths), - f"Not enough rows: {index._row_lengths!r}") + self.assertEqual( + 3, len(index._row_lengths), f"Not enough rows: {index._row_lengths!r}" + ) self.assertEqual(4, len(index._row_offsets)) self.assertEqual(sum(index._row_lengths), index._row_offsets[-1]) internal_nodes = index._get_internal_nodes([0, 1, 2]) @@ -319,7 +327,8 @@ def test_2_leaves_2_2(self): self.assertEqual( b"B+Tree Graph Index 2\nnode_ref_lists=2\nkey_elements=2\nlen=200\n" b"row_lengths=1,3\n", - content[:77]) + content[:77], + ) root = content[77:4096] content[4096:8192] content[8192:12288] @@ -327,10 +336,16 @@ def test_2_leaves_2_2(self): root_bytes = zlib.decompress(root) expected_root = ( b"type=internal\n" - b"offset=0\n" + - (b"0" * 40) + b"\x00" + (b"91" * 40) + b"\n" + - (b"1" * 40) + b"\x00" + (b"81" * 40) + b"\n" - ) + b"offset=0\n" + + (b"0" * 40) + + b"\x00" + + (b"91" * 40) + + b"\n" + + (b"1" * 40) + + b"\x00" + + (b"81" * 40) + + b"\n" + ) self.assertEqual(expected_root, root_bytes) # We assume the other leaf nodes have been written correctly - layering # FTW. @@ -390,14 +405,20 @@ def test_spill_index_stress_1_1(self): builder.add_node(*nodes[12]) # Test that memory and disk are both used for query methods; and that # None is skipped over happily. 
- self.assertEqual([(builder,) + node for node in sorted(nodes[:13])], - list(builder.iter_all_entries())) + self.assertEqual( + [(builder,) + node for node in sorted(nodes[:13])], + list(builder.iter_all_entries()), + ) # Two nodes - one memory one disk - self.assertEqual({(builder,) + node for node in nodes[11:13]}, - set(builder.iter_entries([nodes[12][0], nodes[11][0]]))) + self.assertEqual( + {(builder,) + node for node in nodes[11:13]}, + set(builder.iter_entries([nodes[12][0], nodes[11][0]])), + ) self.assertEqual(13, builder.key_count()) - self.assertEqual({(builder,) + node for node in nodes[11:13]}, - set(builder.iter_entries_prefix([nodes[12][0], nodes[11][0]]))) + self.assertEqual( + {(builder,) + node for node in nodes[11:13]}, + set(builder.iter_entries_prefix([nodes[12][0], nodes[11][0]])), + ) builder.add_node(*nodes[13]) self.assertEqual(3, len(builder._backing_indices)) self.assertEqual(2, builder._backing_indices[0].key_count()) @@ -411,9 +432,9 @@ def test_spill_index_stress_1_1(self): self.assertEqual(None, builder._backing_indices[2]) self.assertEqual(16, builder._backing_indices[3].key_count()) # Now finish, and check we got a correctly ordered tree - t = self.get_transport('') - size = t.put_file('index', builder.finish()) - index = btree_index.BTreeGraphIndex(t, 'index', size) + t = self.get_transport("") + size = t.put_file("index", builder.finish()) + index = btree_index.BTreeGraphIndex(t, "index", size) nodes = list(index.iter_all_entries()) self.assertEqual(sorted(nodes), nodes) self.assertEqual(16, len(nodes)) @@ -464,14 +485,20 @@ def test_spill_index_stress_1_1_no_combine(self): self.assertEqual(2, backing_index.key_count()) # Test that memory and disk are both used for query methods; and that # None is skipped over happily. - self.assertEqual([(builder,) + node for node in sorted(nodes[:13])], - list(builder.iter_all_entries())) + self.assertEqual( + [(builder,) + node for node in sorted(nodes[:13])], + list(builder.iter_all_entries()), + ) # Two nodes - one memory one disk - self.assertEqual({(builder,) + node for node in nodes[11:13]}, - set(builder.iter_entries([nodes[12][0], nodes[11][0]]))) + self.assertEqual( + {(builder,) + node for node in nodes[11:13]}, + set(builder.iter_entries([nodes[12][0], nodes[11][0]])), + ) self.assertEqual(13, builder.key_count()) - self.assertEqual({(builder,) + node for node in nodes[11:13]}, - set(builder.iter_entries_prefix([nodes[12][0], nodes[11][0]]))) + self.assertEqual( + {(builder,) + node for node in nodes[11:13]}, + set(builder.iter_entries_prefix([nodes[12][0], nodes[11][0]])), + ) builder.add_node(*nodes[13]) builder.add_node(*nodes[14]) builder.add_node(*nodes[15]) @@ -479,9 +506,9 @@ def test_spill_index_stress_1_1_no_combine(self): for backing_index in builder._backing_indices: self.assertEqual(2, backing_index.key_count()) # Now finish, and check we got a correctly ordered tree - transport = self.get_transport('') - size = transport.put_file('index', builder.finish()) - index = btree_index.BTreeGraphIndex(transport, 'index', size) + transport = self.get_transport("") + size = transport.put_file("index", builder.finish()) + index = btree_index.BTreeGraphIndex(transport, "index", size) nodes = list(index.iter_all_entries()) self.assertEqual(sorted(nodes), nodes) self.assertEqual(16, len(nodes)) @@ -505,8 +532,9 @@ def test_set_optimize(self): def test_spill_index_stress_2_2(self): # test that references and longer keys don't confuse things. 
- builder = btree_index.BTreeBuilder(key_elements=2, reference_lists=2, - spill_at=2) + builder = btree_index.BTreeBuilder( + key_elements=2, reference_lists=2, spill_at=2 + ) nodes = self.make_nodes(16, 2, 2) builder.add_node(*nodes[0]) # Test the parts of the index that take up memory are doing so @@ -565,14 +593,20 @@ def test_spill_index_stress_2_2(self): builder.add_node(*nodes[12]) # Test that memory and disk are both used for query methods; and that # None is skipped over happily. - self.assertEqual([(builder,) + node for node in sorted(nodes[:13])], - list(builder.iter_all_entries())) + self.assertEqual( + [(builder,) + node for node in sorted(nodes[:13])], + list(builder.iter_all_entries()), + ) # Two nodes - one memory one disk - self.assertEqual({(builder,) + node for node in nodes[11:13]}, - set(builder.iter_entries([nodes[12][0], nodes[11][0]]))) + self.assertEqual( + {(builder,) + node for node in nodes[11:13]}, + set(builder.iter_entries([nodes[12][0], nodes[11][0]])), + ) self.assertEqual(13, builder.key_count()) - self.assertEqual({(builder,) + node for node in nodes[11:13]}, - set(builder.iter_entries_prefix([nodes[12][0], nodes[11][0]]))) + self.assertEqual( + {(builder,) + node for node in nodes[11:13]}, + set(builder.iter_entries_prefix([nodes[12][0], nodes[11][0]])), + ) builder.add_node(*nodes[13]) self.assertEqual(3, len(builder._backing_indices)) self.assertEqual(2, builder._backing_indices[0].key_count()) @@ -586,9 +620,9 @@ def test_spill_index_stress_2_2(self): self.assertEqual(None, builder._backing_indices[2]) self.assertEqual(16, builder._backing_indices[3].key_count()) # Now finish, and check we got a correctly ordered tree - transport = self.get_transport('') - size = transport.put_file('index', builder.finish()) - index = btree_index.BTreeGraphIndex(transport, 'index', size) + transport = self.get_transport("") + size = transport.put_file("index", builder.finish()) + index = btree_index.BTreeGraphIndex(transport, "index", size) nodes = list(index.iter_all_entries()) self.assertEqual(sorted(nodes), nodes) self.assertEqual(16, len(nodes)) @@ -603,36 +637,35 @@ def test_spill_index_duplicate_key_caught_on_finish(self): class TestBTreeIndex(BTreeTestCase): - def make_index(self, ref_lists=0, key_elements=1, nodes=None): if nodes is None: nodes = [] - builder = btree_index.BTreeBuilder(reference_lists=ref_lists, - key_elements=key_elements) + builder = btree_index.BTreeBuilder( + reference_lists=ref_lists, key_elements=key_elements + ) for key, value, references in nodes: builder.add_node(key, value, references) stream = builder.finish() - trans = transport.get_transport_from_url('trace+' + self.get_url()) - size = trans.put_file('index', stream) - return btree_index.BTreeGraphIndex(trans, 'index', size) + trans = transport.get_transport_from_url("trace+" + self.get_url()) + size = trans.put_file("index", stream) + return btree_index.BTreeGraphIndex(trans, "index", size) - def make_index_with_offset(self, ref_lists=1, key_elements=1, nodes=None, - offset=0): + def make_index_with_offset(self, ref_lists=1, key_elements=1, nodes=None, offset=0): if nodes is None: nodes = [] - builder = btree_index.BTreeBuilder(key_elements=key_elements, - reference_lists=ref_lists) + builder = btree_index.BTreeBuilder( + key_elements=key_elements, reference_lists=ref_lists + ) builder.add_nodes(nodes) - transport = self.get_transport('') + transport = self.get_transport("") # NamedTemporaryFile dies on builder.finish().read(). weird. 
temp_file = builder.finish() content = temp_file.read() del temp_file size = len(content) - transport.put_bytes('index', (b' ' * offset) + content) - return btree_index.BTreeGraphIndex(transport, 'index', size=size, - offset=offset) + transport.put_bytes("index", (b" " * offset) + content) + return btree_index.BTreeGraphIndex(transport, "index", size=size, offset=offset) def test_clear_cache(self): nodes = self.make_nodes(160, 2, 2) @@ -651,101 +684,98 @@ def test_clear_cache(self): # will save round-trips. This assertion isn't very strong, # becuase without a 3-level index, we don't have any internal # nodes cached. - self.assertEqual(internal_node_pre_clear, - set(index._internal_node_cache)) + self.assertEqual(internal_node_pre_clear, set(index._internal_node_cache)) self.assertEqual(0, len(index._leaf_node_cache)) def test_trivial_constructor(self): - t = transport.get_transport_from_url('trace+' + self.get_url('')) - btree_index.BTreeGraphIndex(t, 'index', None) + t = transport.get_transport_from_url("trace+" + self.get_url("")) + btree_index.BTreeGraphIndex(t, "index", None) # Checks the page size at load, but that isn't logged yet. self.assertEqual([], t._activity) def test_with_size_constructor(self): - t = transport.get_transport_from_url('trace+' + self.get_url('')) - btree_index.BTreeGraphIndex(t, 'index', 1) + t = transport.get_transport_from_url("trace+" + self.get_url("")) + btree_index.BTreeGraphIndex(t, "index", 1) # Checks the page size at load, but that isn't logged yet. self.assertEqual([], t._activity) def test_empty_key_count_no_size(self): builder = btree_index.BTreeBuilder(key_elements=1, reference_lists=0) - t = transport.get_transport_from_url('trace+' + self.get_url('')) - t.put_file('index', builder.finish()) - index = btree_index.BTreeGraphIndex(t, 'index', None) + t = transport.get_transport_from_url("trace+" + self.get_url("")) + t.put_file("index", builder.finish()) + index = btree_index.BTreeGraphIndex(t, "index", None) del t._activity[:] self.assertEqual([], t._activity) self.assertEqual(0, index.key_count()) # The entire index should have been requested (as we generally have the # size available, and doing many small readvs is inappropriate). # We can't tell how much was actually read here, but - check the code. 
- self.assertEqual([('get', 'index')], t._activity) + self.assertEqual([("get", "index")], t._activity) def test_empty_key_count(self): builder = btree_index.BTreeBuilder(key_elements=1, reference_lists=0) - t = transport.get_transport_from_url('trace+' + self.get_url('')) - size = t.put_file('index', builder.finish()) + t = transport.get_transport_from_url("trace+" + self.get_url("")) + size = t.put_file("index", builder.finish()) self.assertEqual(72, size) - index = btree_index.BTreeGraphIndex(t, 'index', size) + index = btree_index.BTreeGraphIndex(t, "index", size) del t._activity[:] self.assertEqual([], t._activity) self.assertEqual(0, index.key_count()) # The entire index should have been read, as 4K > size - self.assertEqual([('readv', 'index', [(0, 72)], False, None)], - t._activity) + self.assertEqual([("readv", "index", [(0, 72)], False, None)], t._activity) def test_non_empty_key_count_2_2(self): builder = btree_index.BTreeBuilder(key_elements=2, reference_lists=2) nodes = self.make_nodes(35, 2, 2) for node in nodes: builder.add_node(*node) - t = transport.get_transport_from_url('trace+' + self.get_url('')) - size = t.put_file('index', builder.finish()) - index = btree_index.BTreeGraphIndex(t, 'index', size) + t = transport.get_transport_from_url("trace+" + self.get_url("")) + size = t.put_file("index", builder.finish()) + index = btree_index.BTreeGraphIndex(t, "index", size) del t._activity[:] self.assertEqual([], t._activity) self.assertEqual(70, index.key_count()) # The entire index should have been read, as it is one page long. - self.assertEqual([('readv', 'index', [(0, size)], False, None)], - t._activity) + self.assertEqual([("readv", "index", [(0, size)], False, None)], t._activity) self.assertEqualApproxCompressed(1173, size) def test_with_offset_no_size(self): - index = self.make_index_with_offset(key_elements=1, ref_lists=1, - offset=1234, - nodes=self.make_nodes(200, 1, 1)) + index = self.make_index_with_offset( + key_elements=1, ref_lists=1, offset=1234, nodes=self.make_nodes(200, 1, 1) + ) index._size = None # throw away the size info self.assertEqual(200, index.key_count()) def test_with_small_offset(self): - index = self.make_index_with_offset(key_elements=1, ref_lists=1, - offset=1234, - nodes=self.make_nodes(200, 1, 1)) + index = self.make_index_with_offset( + key_elements=1, ref_lists=1, offset=1234, nodes=self.make_nodes(200, 1, 1) + ) self.assertEqual(200, index.key_count()) def test_with_large_offset(self): - index = self.make_index_with_offset(key_elements=1, ref_lists=1, - offset=123456, - nodes=self.make_nodes(200, 1, 1)) + index = self.make_index_with_offset( + key_elements=1, ref_lists=1, offset=123456, nodes=self.make_nodes(200, 1, 1) + ) self.assertEqual(200, index.key_count()) def test__read_nodes_no_size_one_page_reads_once(self): - self.make_index(nodes=[((b'key',), b'value', ())]) - trans = transport.get_transport_from_url('trace+' + self.get_url()) - index = btree_index.BTreeGraphIndex(trans, 'index', None) + self.make_index(nodes=[((b"key",), b"value", ())]) + trans = transport.get_transport_from_url("trace+" + self.get_url()) + index = btree_index.BTreeGraphIndex(trans, "index", None) del trans._activity[:] nodes = dict(index._read_nodes([0])) self.assertEqual({0}, set(nodes)) node = nodes[0] - self.assertEqual([(b'key',)], node.all_keys()) - self.assertEqual([('get', 'index')], trans._activity) + self.assertEqual([(b"key",)], node.all_keys()) + self.assertEqual([("get", "index")], trans._activity) def 
test__read_nodes_no_size_multiple_pages(self): index = self.make_index(2, 2, nodes=self.make_nodes(160, 2, 2)) index.key_count() num_pages = index._row_offsets[-1] # Reopen with a traced transport and no size - trans = transport.get_transport_from_url('trace+' + self.get_url()) - index = btree_index.BTreeGraphIndex(trans, 'index', None) + trans = transport.get_transport_from_url("trace+" + self.get_url()) + index = btree_index.BTreeGraphIndex(trans, "index", None) del trans._activity[:] nodes = dict(index._read_nodes([0])) self.assertEqual(list(range(num_pages)), sorted(nodes)) @@ -755,31 +785,29 @@ def test_2_levels_key_count_2_2(self): nodes = self.make_nodes(160, 2, 2) for node in nodes: builder.add_node(*node) - t = transport.get_transport_from_url('trace+' + self.get_url('')) - size = t.put_file('index', builder.finish()) + t = transport.get_transport_from_url("trace+" + self.get_url("")) + size = t.put_file("index", builder.finish()) self.assertEqualApproxCompressed(17692, size) - index = btree_index.BTreeGraphIndex(t, 'index', size) + index = btree_index.BTreeGraphIndex(t, "index", size) del t._activity[:] self.assertEqual([], t._activity) self.assertEqual(320, index.key_count()) # The entire index should not have been read. - self.assertEqual([('readv', 'index', [(0, 4096)], False, None)], - t._activity) + self.assertEqual([("readv", "index", [(0, 4096)], False, None)], t._activity) def test_validate_one_page(self): builder = btree_index.BTreeBuilder(key_elements=2, reference_lists=2) nodes = self.make_nodes(45, 2, 2) for node in nodes: builder.add_node(*node) - t = transport.get_transport_from_url('trace+' + self.get_url('')) - size = t.put_file('index', builder.finish()) - index = btree_index.BTreeGraphIndex(t, 'index', size) + t = transport.get_transport_from_url("trace+" + self.get_url("")) + size = t.put_file("index", builder.finish()) + index = btree_index.BTreeGraphIndex(t, "index", size) del t._activity[:] self.assertEqual([], t._activity) index.validate() # The entire index should have been read linearly. - self.assertEqual([('readv', 'index', [(0, size)], False, None)], - t._activity) + self.assertEqual([("readv", "index", [(0, size)], False, None)], t._activity) self.assertEqualApproxCompressed(1488, size) def test_validate_two_pages(self): @@ -787,84 +815,88 @@ def test_validate_two_pages(self): nodes = self.make_nodes(80, 2, 2) for node in nodes: builder.add_node(*node) - t = transport.get_transport_from_url('trace+' + self.get_url('')) - size = t.put_file('index', builder.finish()) + t = transport.get_transport_from_url("trace+" + self.get_url("")) + size = t.put_file("index", builder.finish()) # Root page, 2 leaf pages self.assertEqualApproxCompressed(9339, size) - index = btree_index.BTreeGraphIndex(t, 'index', size) + index = btree_index.BTreeGraphIndex(t, "index", size) del t._activity[:] self.assertEqual([], t._activity) index.validate() rem = size - 8192 # Number of remaining bytes after second block # The entire index should have been read linearly. self.assertEqual( - [('readv', 'index', [(0, 4096)], False, None), - ('readv', 'index', [(4096, 4096), (8192, rem)], False, None)], - t._activity) + [ + ("readv", "index", [(0, 4096)], False, None), + ("readv", "index", [(4096, 4096), (8192, rem)], False, None), + ], + t._activity, + ) # XXX: TODO: write some badly-ordered nodes, and some pointers-to-wrong # node and make validate find them. 
def test_eq_ne(self): # two indices are equal when constructed with the same parameters: - t1 = transport.get_transport_from_url('trace+' + self.get_url('')) + t1 = transport.get_transport_from_url("trace+" + self.get_url("")) t2 = self.get_transport() self.assertEqual( - btree_index.BTreeGraphIndex(t1, 'index', None), - btree_index.BTreeGraphIndex(t1, 'index', None) + btree_index.BTreeGraphIndex(t1, "index", None), + btree_index.BTreeGraphIndex(t1, "index", None), ) self.assertEqual( - btree_index.BTreeGraphIndex(t1, 'index', 20), - btree_index.BTreeGraphIndex(t1, 'index', 20) + btree_index.BTreeGraphIndex(t1, "index", 20), + btree_index.BTreeGraphIndex(t1, "index", 20), ) self.assertNotEqual( - btree_index.BTreeGraphIndex(t1, 'index', 20), - btree_index.BTreeGraphIndex(t2, 'index', 20) + btree_index.BTreeGraphIndex(t1, "index", 20), + btree_index.BTreeGraphIndex(t2, "index", 20), ) self.assertNotEqual( - btree_index.BTreeGraphIndex(t1, 'inde1', 20), - btree_index.BTreeGraphIndex(t1, 'inde2', 20) + btree_index.BTreeGraphIndex(t1, "inde1", 20), + btree_index.BTreeGraphIndex(t1, "inde2", 20), ) self.assertNotEqual( - btree_index.BTreeGraphIndex(t1, 'index', 10), - btree_index.BTreeGraphIndex(t1, 'index', 20) + btree_index.BTreeGraphIndex(t1, "index", 10), + btree_index.BTreeGraphIndex(t1, "index", 20), ) self.assertEqual( - btree_index.BTreeGraphIndex(t1, 'index', None), - btree_index.BTreeGraphIndex(t1, 'index', None) + btree_index.BTreeGraphIndex(t1, "index", None), + btree_index.BTreeGraphIndex(t1, "index", None), ) self.assertEqual( - btree_index.BTreeGraphIndex(t1, 'index', 20), - btree_index.BTreeGraphIndex(t1, 'index', 20) + btree_index.BTreeGraphIndex(t1, "index", 20), + btree_index.BTreeGraphIndex(t1, "index", 20), ) self.assertNotEqual( - btree_index.BTreeGraphIndex(t1, 'index', 20), - btree_index.BTreeGraphIndex(t2, 'index', 20) + btree_index.BTreeGraphIndex(t1, "index", 20), + btree_index.BTreeGraphIndex(t2, "index", 20), ) self.assertNotEqual( - btree_index.BTreeGraphIndex(t1, 'inde1', 20), - btree_index.BTreeGraphIndex(t1, 'inde2', 20) + btree_index.BTreeGraphIndex(t1, "inde1", 20), + btree_index.BTreeGraphIndex(t1, "inde2", 20), ) self.assertNotEqual( - btree_index.BTreeGraphIndex(t1, 'index', 10), - btree_index.BTreeGraphIndex(t1, 'index', 20) + btree_index.BTreeGraphIndex(t1, "index", 10), + btree_index.BTreeGraphIndex(t1, "index", 20), ) def test_key_too_big(self): # the size that matters here is the _compressed_ size of the key, so we can't # do a simple character repeat. 
- bigKey = b''.join(b'%d' % n for n in range(btree_index._PAGE_SIZE)) - self.assertRaises(_mod_index.BadIndexKey, - self.make_index, - nodes=[((bigKey,), b'value', ())]) + bigKey = b"".join(b"%d" % n for n in range(btree_index._PAGE_SIZE)) + self.assertRaises( + _mod_index.BadIndexKey, self.make_index, nodes=[((bigKey,), b"value", ())] + ) def test_iter_all_only_root_no_size(self): - self.make_index(nodes=[((b'key',), b'value', ())]) - t = transport.get_transport_from_url('trace+' + self.get_url('')) - index = btree_index.BTreeGraphIndex(t, 'index', None) + self.make_index(nodes=[((b"key",), b"value", ())]) + t = transport.get_transport_from_url("trace+" + self.get_url("")) + index = btree_index.BTreeGraphIndex(t, "index", None) del t._activity[:] - self.assertEqual([((b'key',), b'value')], - [x[1:] for x in index.iter_all_entries()]) - self.assertEqual([('get', 'index')], t._activity) + self.assertEqual( + [((b"key",), b"value")], [x[1:] for x in index.iter_all_entries()] + ) + self.assertEqual([("get", "index")], t._activity) def test_iter_all_entries_reads(self): # iterating all entries reads the header, then does a linear @@ -876,11 +908,11 @@ def test_iter_all_entries_reads(self): nodes = self.make_nodes(10000, 2, 2) for node in nodes: builder.add_node(*node) - t = transport.get_transport_from_url('trace+' + self.get_url('')) - size = t.put_file('index', builder.finish()) + t = transport.get_transport_from_url("trace+" + self.get_url("")) + size = t.put_file("index", builder.finish()) page_size = btree_index._PAGE_SIZE del builder - index = btree_index.BTreeGraphIndex(t, 'index', size) + index = btree_index.BTreeGraphIndex(t, "index", size) del t._activity[:] self.assertEqual([], t._activity) found_nodes = self.time(list, index.iter_all_entries()) @@ -888,8 +920,9 @@ def test_iter_all_entries_reads(self): for node in found_nodes: self.assertIs(node[0], index) bare_nodes.append(node[1:]) - self.assertEqual(3, len(index._row_lengths), - f"Not enough rows: {index._row_lengths!r}") + self.assertEqual( + 3, len(index._row_lengths), f"Not enough rows: {index._row_lengths!r}" + ) # Should be as long as the nodes we supplied self.assertEqual(20000, len(found_nodes)) # Should have the same content @@ -907,11 +940,12 @@ def test_iter_all_entries_reads(self): readv_request.append((offset, page_size)) # The last page is truncated readv_request[-1] = (readv_request[-1][0], size % page_size) - expected = [('readv', 'index', [(0, page_size)], False, None), - ('readv', 'index', readv_request, False, None)] + expected = [ + ("readv", "index", [(0, page_size)], False, None), + ("readv", "index", readv_request, False, None), + ] if expected != t._activity: - self.assertEqualDiff(pprint.pformat(expected), - pprint.pformat(t._activity)) + self.assertEqualDiff(pprint.pformat(expected), pprint.pformat(t._activity)) def test_iter_entries_references_2_refs_resolved(self): # iterating some entries reads just the pages needed. 
For now, to @@ -921,10 +955,10 @@ def test_iter_entries_references_2_refs_resolved(self): nodes = self.make_nodes(160, 2, 2) for node in nodes: builder.add_node(*node) - t = transport.get_transport_from_url('trace+' + self.get_url('')) - size = t.put_file('index', builder.finish()) + t = transport.get_transport_from_url("trace+" + self.get_url("")) + size = t.put_file("index", builder.finish()) del builder - index = btree_index.BTreeGraphIndex(t, 'index', size) + index = btree_index.BTreeGraphIndex(t, "index", size) del t._activity[:] self.assertEqual([], t._activity) # search for one key @@ -938,68 +972,116 @@ def test_iter_entries_references_2_refs_resolved(self): # Should have the same content self.assertEqual(nodes[30], bare_nodes[0]) # Should have read the root node, then one leaf page: - self.assertEqual([('readv', 'index', [(0, 4096)], False, None), - ('readv', 'index', [(8192, 4096), ], False, None)], - t._activity) + self.assertEqual( + [ + ("readv", "index", [(0, 4096)], False, None), + ( + "readv", + "index", + [ + (8192, 4096), + ], + False, + None, + ), + ], + t._activity, + ) def test_iter_key_prefix_1_element_key_None(self): index = self.make_index() - self.assertRaises(_mod_index.BadIndexKey, list, - index.iter_entries_prefix([(None, )])) + self.assertRaises( + _mod_index.BadIndexKey, list, index.iter_entries_prefix([(None,)]) + ) def test_iter_key_prefix_wrong_length(self): index = self.make_index() - self.assertRaises(_mod_index.BadIndexKey, list, - index.iter_entries_prefix([(b'foo', None)])) + self.assertRaises( + _mod_index.BadIndexKey, list, index.iter_entries_prefix([(b"foo", None)]) + ) index = self.make_index(key_elements=2) - self.assertRaises(_mod_index.BadIndexKey, list, - index.iter_entries_prefix([(b'foo', )])) - self.assertRaises(_mod_index.BadIndexKey, list, - index.iter_entries_prefix([(b'foo', None, None)])) + self.assertRaises( + _mod_index.BadIndexKey, list, index.iter_entries_prefix([(b"foo",)]) + ) + self.assertRaises( + _mod_index.BadIndexKey, + list, + index.iter_entries_prefix([(b"foo", None, None)]), + ) def test_iter_key_prefix_1_key_element_no_refs(self): - index = self.make_index(nodes=[ - ((b'name', ), b'data', ()), - ((b'ref', ), b'refdata', ())]) - self.assertEqual({(index, (b'name', ), b'data'), - (index, (b'ref', ), b'refdata')}, - set(index.iter_entries_prefix([(b'name', ), (b'ref', )]))) + index = self.make_index( + nodes=[((b"name",), b"data", ()), ((b"ref",), b"refdata", ())] + ) + self.assertEqual( + {(index, (b"name",), b"data"), (index, (b"ref",), b"refdata")}, + set(index.iter_entries_prefix([(b"name",), (b"ref",)])), + ) def test_iter_key_prefix_1_key_element_refs(self): - index = self.make_index(1, nodes=[ - ((b'name', ), b'data', ([(b'ref', )], )), - ((b'ref', ), b'refdata', ([], ))]) - self.assertEqual({(index, (b'name', ), b'data', (((b'ref',),),)), - (index, (b'ref', ), b'refdata', ((), ))}, - set(index.iter_entries_prefix([(b'name', ), (b'ref', )]))) + index = self.make_index( + 1, + nodes=[ + ((b"name",), b"data", ([(b"ref",)],)), + ((b"ref",), b"refdata", ([],)), + ], + ) + self.assertEqual( + { + (index, (b"name",), b"data", (((b"ref",),),)), + (index, (b"ref",), b"refdata", ((),)), + }, + set(index.iter_entries_prefix([(b"name",), (b"ref",)])), + ) def test_iter_key_prefix_2_key_element_no_refs(self): - index = self.make_index(key_elements=2, nodes=[ - ((b'name', b'fin1'), b'data', ()), - ((b'name', b'fin2'), b'beta', ()), - ((b'ref', b'erence'), b'refdata', ())]) - self.assertEqual({(index, (b'name', b'fin1'), b'data'), 
- (index, (b'ref', b'erence'), b'refdata')}, - set(index.iter_entries_prefix( - [(b'name', b'fin1'), (b'ref', b'erence')]))) - self.assertEqual({(index, (b'name', b'fin1'), b'data'), - (index, (b'name', b'fin2'), b'beta')}, - set(index.iter_entries_prefix([(b'name', None)]))) + index = self.make_index( + key_elements=2, + nodes=[ + ((b"name", b"fin1"), b"data", ()), + ((b"name", b"fin2"), b"beta", ()), + ((b"ref", b"erence"), b"refdata", ()), + ], + ) + self.assertEqual( + { + (index, (b"name", b"fin1"), b"data"), + (index, (b"ref", b"erence"), b"refdata"), + }, + set(index.iter_entries_prefix([(b"name", b"fin1"), (b"ref", b"erence")])), + ) + self.assertEqual( + { + (index, (b"name", b"fin1"), b"data"), + (index, (b"name", b"fin2"), b"beta"), + }, + set(index.iter_entries_prefix([(b"name", None)])), + ) def test_iter_key_prefix_2_key_element_refs(self): - index = self.make_index(1, key_elements=2, nodes=[ - ((b'name', b'fin1'), b'data', ([(b'ref', b'erence')], )), - ((b'name', b'fin2'), b'beta', ([], )), - ((b'ref', b'erence'), b'refdata', ([], ))]) + index = self.make_index( + 1, + key_elements=2, + nodes=[ + ((b"name", b"fin1"), b"data", ([(b"ref", b"erence")],)), + ((b"name", b"fin2"), b"beta", ([],)), + ((b"ref", b"erence"), b"refdata", ([],)), + ], + ) self.assertEqual( - {(index, (b'name', b'fin1'), b'data', (((b'ref', b'erence'),),)), - (index, (b'ref', b'erence'), b'refdata', ((), ))}, - set(index.iter_entries_prefix( - [(b'name', b'fin1'), (b'ref', b'erence')]))) + { + (index, (b"name", b"fin1"), b"data", (((b"ref", b"erence"),),)), + (index, (b"ref", b"erence"), b"refdata", ((),)), + }, + set(index.iter_entries_prefix([(b"name", b"fin1"), (b"ref", b"erence")])), + ) self.assertEqual( - {(index, (b'name', b'fin1'), b'data', (((b'ref', b'erence'),),)), - (index, (b'name', b'fin2'), b'beta', ((), ))}, - set(index.iter_entries_prefix([(b'name', None)]))) + { + (index, (b"name", b"fin1"), b"data", (((b"ref", b"erence"),),)), + (index, (b"name", b"fin2"), b"beta", ((),)), + }, + set(index.iter_entries_prefix([(b"name", None)])), + ) # XXX: external_references tests are duplicated in test_index. We # probably should have per_graph_index tests... 
@@ -1008,57 +1090,67 @@ def test_external_references_no_refs(self): self.assertRaises(ValueError, index.external_references, 0) def test_external_references_no_results(self): - index = self.make_index(ref_lists=1, nodes=[ - ((b'key',), b'value', ([],))]) + index = self.make_index(ref_lists=1, nodes=[((b"key",), b"value", ([],))]) self.assertEqual(set(), index.external_references(0)) def test_external_references_missing_ref(self): - missing_key = (b'missing',) - index = self.make_index(ref_lists=1, nodes=[ - ((b'key',), b'value', ([missing_key],))]) + missing_key = (b"missing",) + index = self.make_index( + ref_lists=1, nodes=[((b"key",), b"value", ([missing_key],))] + ) self.assertEqual({missing_key}, index.external_references(0)) def test_external_references_multiple_ref_lists(self): - missing_key = (b'missing',) - index = self.make_index(ref_lists=2, nodes=[ - ((b'key',), b'value', ([], [missing_key]))]) + missing_key = (b"missing",) + index = self.make_index( + ref_lists=2, nodes=[((b"key",), b"value", ([], [missing_key]))] + ) self.assertEqual(set(), index.external_references(0)) self.assertEqual({missing_key}, index.external_references(1)) def test_external_references_two_records(self): - index = self.make_index(ref_lists=1, nodes=[ - ((b'key-1',), b'value', ([(b'key-2',)],)), - ((b'key-2',), b'value', ([],)), - ]) + index = self.make_index( + ref_lists=1, + nodes=[ + ((b"key-1",), b"value", ([(b"key-2",)],)), + ((b"key-2",), b"value", ([],)), + ], + ) self.assertEqual(set(), index.external_references(0)) def test__find_ancestors_one_page(self): - key1 = (b'key-1',) - key2 = (b'key-2',) - index = self.make_index(ref_lists=1, key_elements=1, nodes=[ - (key1, b'value', ([key2],)), - (key2, b'value', ([],)), - ]) + key1 = (b"key-1",) + key2 = (b"key-2",) + index = self.make_index( + ref_lists=1, + key_elements=1, + nodes=[ + (key1, b"value", ([key2],)), + (key2, b"value", ([],)), + ], + ) parent_map = {} missing_keys = set() - search_keys = index._find_ancestors( - [key1], 0, parent_map, missing_keys) + search_keys = index._find_ancestors([key1], 0, parent_map, missing_keys) self.assertEqual({key1: (key2,), key2: ()}, parent_map) self.assertEqual(set(), missing_keys) self.assertEqual(set(), search_keys) def test__find_ancestors_one_page_w_missing(self): - key1 = (b'key-1',) - key2 = (b'key-2',) - key3 = (b'key-3',) - index = self.make_index(ref_lists=1, key_elements=1, nodes=[ - (key1, b'value', ([key2],)), - (key2, b'value', ([],)), - ]) + key1 = (b"key-1",) + key2 = (b"key-2",) + key3 = (b"key-3",) + index = self.make_index( + ref_lists=1, + key_elements=1, + nodes=[ + (key1, b"value", ([key2],)), + (key2, b"value", ([],)), + ], + ) parent_map = {} missing_keys = set() - search_keys = index._find_ancestors([key2, key3], 0, parent_map, - missing_keys) + search_keys = index._find_ancestors([key2, key3], 0, parent_map, missing_keys) self.assertEqual({key2: ()}, parent_map) # we know that key3 is missing because we read the page that it would # otherwise be on @@ -1066,44 +1158,49 @@ def test__find_ancestors_one_page_w_missing(self): self.assertEqual(set(), search_keys) def test__find_ancestors_one_parent_missing(self): - key1 = (b'key-1',) - key2 = (b'key-2',) - key3 = (b'key-3',) - index = self.make_index(ref_lists=1, key_elements=1, nodes=[ - (key1, b'value', ([key2],)), - (key2, b'value', ([key3],)), - ]) + key1 = (b"key-1",) + key2 = (b"key-2",) + key3 = (b"key-3",) + index = self.make_index( + ref_lists=1, + key_elements=1, + nodes=[ + (key1, b"value", ([key2],)), + (key2, b"value", 
([key3],)), + ], + ) parent_map = {} missing_keys = set() - search_keys = index._find_ancestors([key1], 0, parent_map, - missing_keys) + search_keys = index._find_ancestors([key1], 0, parent_map, missing_keys) self.assertEqual({key1: (key2,), key2: (key3,)}, parent_map) self.assertEqual(set(), missing_keys) # all we know is that key3 wasn't present on the page we were reading # but if you look, the last key is key2 which comes before key3, so we # don't know whether key3 would land on this page or not. self.assertEqual({key3}, search_keys) - search_keys = index._find_ancestors(search_keys, 0, parent_map, - missing_keys) + search_keys = index._find_ancestors(search_keys, 0, parent_map, missing_keys) # passing it back in, we are sure it is 'missing' self.assertEqual({key1: (key2,), key2: (key3,)}, parent_map) self.assertEqual({key3}, missing_keys) self.assertEqual(set(), search_keys) def test__find_ancestors_dont_search_known(self): - key1 = (b'key-1',) - key2 = (b'key-2',) - key3 = (b'key-3',) - index = self.make_index(ref_lists=1, key_elements=1, nodes=[ - (key1, b'value', ([key2],)), - (key2, b'value', ([key3],)), - (key3, b'value', ([],)), - ]) + key1 = (b"key-1",) + key2 = (b"key-2",) + key3 = (b"key-3",) + index = self.make_index( + ref_lists=1, + key_elements=1, + nodes=[ + (key1, b"value", ([key2],)), + (key2, b"value", ([key3],)), + (key3, b"value", ([],)), + ], + ) # We already know about key2, so we won't try to search for key3 parent_map = {key2: (key3,)} missing_keys = set() - search_keys = index._find_ancestors([key1], 0, parent_map, - missing_keys) + search_keys = index._find_ancestors([key1], 0, parent_map, missing_keys) self.assertEqual({key1: (key2,), key2: (key3,)}, parent_map) self.assertEqual(set(), missing_keys) self.assertEqual(set(), search_keys) @@ -1116,11 +1213,13 @@ def test__find_ancestors_multiple_pages(self): ref_lists = ((),) rev_keys = [] for i in range(400): - rev_id = ('{}-{}-{}'.format(email, - osutils.compact_date(start_time + i), - osutils.rand_chars(16))).encode('ascii') + rev_id = ( + "{}-{}-{}".format( + email, osutils.compact_date(start_time + i), osutils.rand_chars(16) + ) + ).encode("ascii") rev_key = (rev_id,) - nodes.append((rev_key, b'value', ref_lists)) + nodes.append((rev_key, b"value", ref_lists)) # We have a ref 'list' of length 1, with a list of parents, with 1 # parent which is a key ref_lists = ((rev_key,),) @@ -1145,14 +1244,12 @@ def test__find_ancestors_multiple_pages(self): # parent of that key parent_map = {} missing_keys = set() - search_keys = index._find_ancestors([next_key], 0, parent_map, - missing_keys) + search_keys = index._find_ancestors([next_key], 0, parent_map, missing_keys) self.assertEqual([min_l2_key, next_key], sorted(parent_map)) self.assertEqual(set(), missing_keys) self.assertEqual({max_l1_key}, search_keys) parent_map = {} - search_keys = index._find_ancestors([max_l1_key], 0, parent_map, - missing_keys) + search_keys = index._find_ancestors([max_l1_key], 0, parent_map, missing_keys) self.assertEqual(l1.all_keys(), sorted(parent_map)) self.assertEqual(set(), missing_keys) self.assertEqual(set(), search_keys) @@ -1161,11 +1258,12 @@ def test__find_ancestors_empty_index(self): index = self.make_index(ref_lists=1, key_elements=1, nodes=[]) parent_map = {} missing_keys = set() - search_keys = index._find_ancestors([('one',), ('two',)], 0, parent_map, - missing_keys) + search_keys = index._find_ancestors( + [("one",), ("two",)], 0, parent_map, missing_keys + ) self.assertEqual(set(), search_keys) self.assertEqual({}, 
parent_map) - self.assertEqual({('one',), ('two',)}, missing_keys) + self.assertEqual({("one",), ("two",)}, missing_keys) def test_supports_unlimited_cache(self): builder = btree_index.BTreeBuilder(reference_lists=0, key_elements=1) @@ -1176,28 +1274,28 @@ def test_supports_unlimited_cache(self): builder.add_node(*node) stream = builder.finish() trans = self.get_transport() - size = trans.put_file('index', stream) - index = btree_index.BTreeGraphIndex(trans, 'index', size) + size = trans.put_file("index", stream) + index = btree_index.BTreeGraphIndex(trans, "index", size) self.assertEqual(500, index.key_count()) # We have an internal node self.assertEqual(2, len(index._row_lengths)) # We have at least 2 leaf nodes self.assertGreaterEqual(index._row_lengths[-1], 2) self.assertIsInstance(index._leaf_node_cache, lru_cache.LRUCache) - self.assertEqual(btree_index._NODE_CACHE_SIZE, - index._leaf_node_cache._max_cache) + self.assertEqual( + btree_index._NODE_CACHE_SIZE, index._leaf_node_cache._max_cache + ) self.assertIsInstance(index._internal_node_cache, fifo_cache.FIFOCache) self.assertEqual(100, index._internal_node_cache._max_cache) # No change if unlimited_cache=False is passed - index = btree_index.BTreeGraphIndex(trans, 'index', size, - unlimited_cache=False) + index = btree_index.BTreeGraphIndex(trans, "index", size, unlimited_cache=False) self.assertIsInstance(index._leaf_node_cache, lru_cache.LRUCache) - self.assertEqual(btree_index._NODE_CACHE_SIZE, - index._leaf_node_cache._max_cache) + self.assertEqual( + btree_index._NODE_CACHE_SIZE, index._leaf_node_cache._max_cache + ) self.assertIsInstance(index._internal_node_cache, fifo_cache.FIFOCache) self.assertEqual(100, index._internal_node_cache._max_cache) - index = btree_index.BTreeGraphIndex(trans, 'index', size, - unlimited_cache=True) + index = btree_index.BTreeGraphIndex(trans, "index", size, unlimited_cache=True) self.assertIsInstance(index._leaf_node_cache, dict) self.assertIs(type(index._internal_node_cache), dict) # Exercise the lookup code @@ -1206,105 +1304,145 @@ def test_supports_unlimited_cache(self): class TestBTreeNodes(BTreeTestCase): - scenarios = btreeparser_scenarios() def setUp(self): super().setUp() - self.overrideAttr(btree_index, '_btree_serializer', self.parse_btree) + self.overrideAttr(btree_index, "_btree_serializer", self.parse_btree) def test_LeafNode_1_0(self): - node_bytes = (b"type=leaf\n" - b"0000000000000000000000000000000000000000\x00\x00value:0\n" - b"1111111111111111111111111111111111111111\x00\x00value:1\n" - b"2222222222222222222222222222222222222222\x00\x00value:2\n" - b"3333333333333333333333333333333333333333\x00\x00value:3\n" - b"4444444444444444444444444444444444444444\x00\x00value:4\n") + node_bytes = ( + b"type=leaf\n" + b"0000000000000000000000000000000000000000\x00\x00value:0\n" + b"1111111111111111111111111111111111111111\x00\x00value:1\n" + b"2222222222222222222222222222222222222222\x00\x00value:2\n" + b"3333333333333333333333333333333333333333\x00\x00value:3\n" + b"4444444444444444444444444444444444444444\x00\x00value:4\n" + ) node = btree_index._LeafNode(node_bytes, 1, 0) # We do direct access, or don't care about order, to leaf nodes most of # the time, so a dict is useful: - self.assertEqual({ - (b"0000000000000000000000000000000000000000",): (b"value:0", ()), - (b"1111111111111111111111111111111111111111",): (b"value:1", ()), - (b"2222222222222222222222222222222222222222",): (b"value:2", ()), - (b"3333333333333333333333333333333333333333",): (b"value:3", ()), - 
(b"4444444444444444444444444444444444444444",): (b"value:4", ()), - }, dict(node.all_items())) + self.assertEqual( + { + (b"0000000000000000000000000000000000000000",): (b"value:0", ()), + (b"1111111111111111111111111111111111111111",): (b"value:1", ()), + (b"2222222222222222222222222222222222222222",): (b"value:2", ()), + (b"3333333333333333333333333333333333333333",): (b"value:3", ()), + (b"4444444444444444444444444444444444444444",): (b"value:4", ()), + }, + dict(node.all_items()), + ) def test_LeafNode_2_2(self): - node_bytes = (b"type=leaf\n" - b"00\x0000\x00\t00\x00ref00\x00value:0\n" - b"00\x0011\x0000\x00ref00\t00\x00ref00\r01\x00ref01\x00value:1\n" - b"11\x0033\x0011\x00ref22\t11\x00ref22\r11\x00ref22\x00value:3\n" - b"11\x0044\x00\t11\x00ref00\x00value:4\n" - b"" - ) + node_bytes = ( + b"type=leaf\n" + b"00\x0000\x00\t00\x00ref00\x00value:0\n" + b"00\x0011\x0000\x00ref00\t00\x00ref00\r01\x00ref01\x00value:1\n" + b"11\x0033\x0011\x00ref22\t11\x00ref22\r11\x00ref22\x00value:3\n" + b"11\x0044\x00\t11\x00ref00\x00value:4\n" + b"" + ) node = btree_index._LeafNode(node_bytes, 2, 2) # We do direct access, or don't care about order, to leaf nodes most of # the time, so a dict is useful: - self.assertEqual({ - (b'00', b'00'): (b'value:0', ((), ((b'00', b'ref00'),))), - (b'00', b'11'): (b'value:1', (((b'00', b'ref00'),), - ((b'00', b'ref00'), (b'01', b'ref01')))), - (b'11', b'33'): (b'value:3', (((b'11', b'ref22'),), - ((b'11', b'ref22'), (b'11', b'ref22')))), - (b'11', b'44'): (b'value:4', ((), ((b'11', b'ref00'),))) - }, dict(node.all_items())) + self.assertEqual( + { + (b"00", b"00"): (b"value:0", ((), ((b"00", b"ref00"),))), + (b"00", b"11"): ( + b"value:1", + (((b"00", b"ref00"),), ((b"00", b"ref00"), (b"01", b"ref01"))), + ), + (b"11", b"33"): ( + b"value:3", + (((b"11", b"ref22"),), ((b"11", b"ref22"), (b"11", b"ref22"))), + ), + (b"11", b"44"): (b"value:4", ((), ((b"11", b"ref00"),))), + }, + dict(node.all_items()), + ) def test_InternalNode_1(self): - node_bytes = (b"type=internal\n" - b"offset=1\n" - b"0000000000000000000000000000000000000000\n" - b"1111111111111111111111111111111111111111\n" - b"2222222222222222222222222222222222222222\n" - b"3333333333333333333333333333333333333333\n" - b"4444444444444444444444444444444444444444\n" - ) + node_bytes = ( + b"type=internal\n" + b"offset=1\n" + b"0000000000000000000000000000000000000000\n" + b"1111111111111111111111111111111111111111\n" + b"2222222222222222222222222222222222222222\n" + b"3333333333333333333333333333333333333333\n" + b"4444444444444444444444444444444444444444\n" + ) node = btree_index._InternalNode(node_bytes) # We want to bisect to find the right children from this node, so a # vector is most useful. 
- self.assertEqual([ - (b"0000000000000000000000000000000000000000",), - (b"1111111111111111111111111111111111111111",), - (b"2222222222222222222222222222222222222222",), - (b"3333333333333333333333333333333333333333",), - (b"4444444444444444444444444444444444444444",), - ], node.keys) + self.assertEqual( + [ + (b"0000000000000000000000000000000000000000",), + (b"1111111111111111111111111111111111111111",), + (b"2222222222222222222222222222222222222222",), + (b"3333333333333333333333333333333333333333",), + (b"4444444444444444444444444444444444444444",), + ], + node.keys, + ) self.assertEqual(1, node.offset) def assertFlattened(self, expected, key, value, refs): flat_key, flat_line = self.parse_btree._flatten_node( - (None, key, value, refs), bool(refs)) - self.assertEqual(b'\x00'.join(key), flat_key) + (None, key, value, refs), bool(refs) + ) + self.assertEqual(b"\x00".join(key), flat_key) self.assertEqual(expected, flat_line) def test__flatten_node(self): - self.assertFlattened(b'key\0\0value\n', (b'key',), b'value', []) - self.assertFlattened(b'key\0tuple\0\0value str\n', - (b'key', b'tuple'), b'value str', []) - self.assertFlattened(b'key\0tuple\0triple\0\0value str\n', - (b'key', b'tuple', b'triple'), b'value str', []) - self.assertFlattened(b'k\0t\0s\0ref\0value str\n', - (b'k', b't', b's'), b'value str', [[(b'ref',)]]) - self.assertFlattened(b'key\0tuple\0ref\0key\0value str\n', - (b'key', b'tuple'), b'value str', [[(b'ref', b'key')]]) - self.assertFlattened(b"00\x0000\x00\t00\x00ref00\x00value:0\n", - (b'00', b'00'), b'value:0', ((), ((b'00', b'ref00'),))) + self.assertFlattened(b"key\0\0value\n", (b"key",), b"value", []) + self.assertFlattened( + b"key\0tuple\0\0value str\n", (b"key", b"tuple"), b"value str", [] + ) + self.assertFlattened( + b"key\0tuple\0triple\0\0value str\n", + (b"key", b"tuple", b"triple"), + b"value str", + [], + ) + self.assertFlattened( + b"k\0t\0s\0ref\0value str\n", + (b"k", b"t", b"s"), + b"value str", + [[(b"ref",)]], + ) + self.assertFlattened( + b"key\0tuple\0ref\0key\0value str\n", + (b"key", b"tuple"), + b"value str", + [[(b"ref", b"key")]], + ) + self.assertFlattened( + b"00\x0000\x00\t00\x00ref00\x00value:0\n", + (b"00", b"00"), + b"value:0", + ((), ((b"00", b"ref00"),)), + ) self.assertFlattened( b"00\x0011\x0000\x00ref00\t00\x00ref00\r01\x00ref01\x00value:1\n", - (b'00', b'11'), b'value:1', - (((b'00', b'ref00'),), ((b'00', b'ref00'), (b'01', b'ref01')))) + (b"00", b"11"), + b"value:1", + (((b"00", b"ref00"),), ((b"00", b"ref00"), (b"01", b"ref01"))), + ) self.assertFlattened( b"11\x0033\x0011\x00ref22\t11\x00ref22\r11\x00ref22\x00value:3\n", - (b'11', b'33'), b'value:3', - (((b'11', b'ref22'),), ((b'11', b'ref22'), (b'11', b'ref22')))) + (b"11", b"33"), + b"value:3", + (((b"11", b"ref22"),), ((b"11", b"ref22"), (b"11", b"ref22"))), + ) self.assertFlattened( b"11\x0044\x00\t11\x00ref00\x00value:4\n", - (b'11', b'44'), b'value:4', ((), ((b'11', b'ref00'),))) + (b"11", b"44"), + b"value:4", + ((), ((b"11", b"ref00"),)), + ) class TestCompiledBtree(tests.TestCase): - def test_exists(self): # This is just to let the user know if they don't have the feature # available @@ -1312,43 +1450,48 @@ def test_exists(self): class TestMultiBisectRight(tests.TestCase): - def assertMultiBisectRight(self, offsets, search_keys, fixed_keys): - self.assertEqual(offsets, - btree_index.BTreeGraphIndex._multi_bisect_right( - search_keys, fixed_keys)) + self.assertEqual( + offsets, + btree_index.BTreeGraphIndex._multi_bisect_right(search_keys, fixed_keys), + ) def 
test_after(self): - self.assertMultiBisectRight([(1, ['b'])], ['b'], ['a']) - self.assertMultiBisectRight([(3, ['e', 'f', 'g'])], - ['e', 'f', 'g'], ['a', 'b', 'c']) + self.assertMultiBisectRight([(1, ["b"])], ["b"], ["a"]) + self.assertMultiBisectRight( + [(3, ["e", "f", "g"])], ["e", "f", "g"], ["a", "b", "c"] + ) def test_before(self): - self.assertMultiBisectRight([(0, ['a'])], ['a'], ['b']) - self.assertMultiBisectRight([(0, ['a', 'b', 'c', 'd'])], - ['a', 'b', 'c', 'd'], ['e', 'f', 'g']) + self.assertMultiBisectRight([(0, ["a"])], ["a"], ["b"]) + self.assertMultiBisectRight( + [(0, ["a", "b", "c", "d"])], ["a", "b", "c", "d"], ["e", "f", "g"] + ) def test_exact(self): - self.assertMultiBisectRight([(1, ['a'])], ['a'], ['a']) + self.assertMultiBisectRight([(1, ["a"])], ["a"], ["a"]) + self.assertMultiBisectRight([(1, ["a"]), (2, ["b"])], ["a", "b"], ["a", "b"]) self.assertMultiBisectRight( - [(1, ['a']), (2, ['b'])], ['a', 'b'], ['a', 'b']) - self.assertMultiBisectRight([(1, ['a']), (3, ['c'])], - ['a', 'c'], ['a', 'b', 'c']) + [(1, ["a"]), (3, ["c"])], ["a", "c"], ["a", "b", "c"] + ) def test_inbetween(self): - self.assertMultiBisectRight([(1, ['b'])], ['b'], ['a', 'c']) - self.assertMultiBisectRight([(1, ['b', 'c', 'd']), (2, ['f', 'g'])], - ['b', 'c', 'd', 'f', 'g'], ['a', 'e', 'h']) + self.assertMultiBisectRight([(1, ["b"])], ["b"], ["a", "c"]) + self.assertMultiBisectRight( + [(1, ["b", "c", "d"]), (2, ["f", "g"])], + ["b", "c", "d", "f", "g"], + ["a", "e", "h"], + ) def test_mixed(self): - self.assertMultiBisectRight([(0, ['a', 'b']), (2, ['d', 'e']), - (4, ['g', 'h'])], - ['a', 'b', 'd', 'e', 'g', 'h'], - ['c', 'd', 'f', 'g']) + self.assertMultiBisectRight( + [(0, ["a", "b"]), (2, ["d", "e"]), (4, ["g", "h"])], + ["a", "b", "d", "e", "g", "h"], + ["c", "d", "f", "g"], + ) class TestExpandOffsets(tests.TestCase): - def make_index(self, size, recommended_pages=None): """Make an index with a generic size. @@ -1356,44 +1499,57 @@ def make_index(self, size, recommended_pages=None): BTreeGraphIndex with the recommended information. 
""" index = btree_index.BTreeGraphIndex( - transport.get_transport_from_url('memory:///'), - 'test-index', size=size) + transport.get_transport_from_url("memory:///"), "test-index", size=size + ) if recommended_pages is not None: index._recommended_pages = recommended_pages return index def set_cached_offsets(self, index, cached_offsets): """Monkeypatch to give a canned answer for _get_offsets_for...().""" + def _get_offsets_to_cached_pages(): cached = set(cached_offsets) return cached + index._get_offsets_to_cached_pages = _get_offsets_to_cached_pages - def prepare_index(self, index, node_ref_lists, key_length, key_count, - row_lengths, cached_offsets): + def prepare_index( + self, index, node_ref_lists, key_length, key_count, row_lengths, cached_offsets + ): """Setup the BTreeGraphIndex with some pre-canned information.""" index.node_ref_lists = node_ref_lists index._key_length = key_length index._key_count = key_count index._row_lengths = row_lengths index._compute_row_offsets() - index._root_node = btree_index._InternalNode(b'internal\noffset=0\n') + index._root_node = btree_index._InternalNode(b"internal\noffset=0\n") self.set_cached_offsets(index, cached_offsets) def make_100_node_index(self): index = self.make_index(4096 * 100, 6) # Consider we've already made a single request at the middle - self.prepare_index(index, node_ref_lists=0, key_length=1, - key_count=1000, row_lengths=[1, 99], - cached_offsets=[0, 50]) + self.prepare_index( + index, + node_ref_lists=0, + key_length=1, + key_count=1000, + row_lengths=[1, 99], + cached_offsets=[0, 50], + ) return index def make_1000_node_index(self): index = self.make_index(4096 * 1000, 6) # Pretend we've already made a single request in the middle - self.prepare_index(index, node_ref_lists=0, key_length=1, - key_count=90000, row_lengths=[1, 9, 990], - cached_offsets=[0, 5, 500]) + self.prepare_index( + index, + node_ref_lists=0, + key_length=1, + key_count=90000, + row_lengths=[1, 9, 990], + cached_offsets=[0, 5, 500], + ) return index def assertNumPages(self, expected_pages, index, size): @@ -1401,8 +1557,11 @@ def assertNumPages(self, expected_pages, index, size): self.assertEqual(expected_pages, index._compute_total_pages_in_index()) def assertExpandOffsets(self, expected, index, offsets): - self.assertEqual(expected, index._expand_offsets(offsets), - f'We did not get the expected value after expanding {offsets}') + self.assertEqual( + expected, + index._expand_offsets(offsets), + f"We did not get the expected value after expanding {offsets}", + ) def test_default_recommended_pages(self): index = self.make_index(None) @@ -1445,9 +1604,14 @@ def test_read_all_from_root(self): def test_read_all_when_cached(self): # We've read enough that we can grab all the rest in a single request index = self.make_index(4096 * 10, 5) - self.prepare_index(index, node_ref_lists=0, key_length=1, - key_count=1000, row_lengths=[1, 9], - cached_offsets=[0, 1, 2, 5, 6]) + self.prepare_index( + index, + node_ref_lists=0, + key_length=1, + key_count=1000, + row_lengths=[1, 9], + cached_offsets=[0, 1, 2, 5, 6], + ) # It should fill the remaining nodes, regardless of the one requested self.assertExpandOffsets([3, 4, 7, 8, 9], index, [3]) self.assertExpandOffsets([3, 4, 7, 8, 9], index, [8]) @@ -1469,8 +1633,7 @@ def test_include_neighbors(self): # Requesting many nodes will expand all locations equally self.assertExpandOffsets([1, 2, 3, 80, 81, 82], index, [2, 81]) - self.assertExpandOffsets([1, 2, 3, 9, 10, 11, 80, 81, 82], index, - [2, 10, 81]) + 
self.assertExpandOffsets([1, 2, 3, 9, 10, 11, 80, 81, 82], index, [2, 10, 81]) def test_stop_at_cached(self): index = self.make_100_node_index() @@ -1525,5 +1688,4 @@ def test_small_requests_unexpanded(self): self.assertExpandOffsets([2, 3, 4, 5, 6, 7], index, [2]) self.assertExpandOffsets([2, 3, 4, 5, 6, 7], index, [4]) self.set_cached_offsets(index, [0, 1, 2, 3, 4, 5, 6, 7, 100]) - self.assertExpandOffsets([102, 103, 104, 105, 106, 107, 108], index, - [105]) + self.assertExpandOffsets([102, 103, 104, 105, 106, 107, 108], index, [105]) diff --git a/breezy/bzr/tests/test_bundle.py b/breezy/bzr/tests/test_bundle.py index 79c960066c..a995447406 100644 --- a/breezy/bzr/tests/test_bundle.py +++ b/breezy/bzr/tests/test_bundle.py @@ -37,9 +37,9 @@ def get_text(vf, key): """Get the fulltext for a given revision id that is present in the vf.""" - stream = vf.get_record_stream([key], 'unordered', True) + stream = vf.get_record_stream([key], "unordered", True) record = next(stream) - return record.get_bytes_as('fulltext') + return record.get_bytes_as("fulltext") def get_inventory_text(repo, revision_id): @@ -49,14 +49,14 @@ def get_inventory_text(repo, revision_id): class MockTree(InventoryTree): - def __init__(self): from ..inventory import ROOT_ID, InventoryDirectory + object.__init__(self) self.paths = {ROOT_ID: ""} self.ids = {"": ROOT_ID} self.contents = {} - self.root = InventoryDirectory(ROOT_ID, '', None) + self.root = InventoryDirectory(ROOT_ID, "", None) inventory = property(lambda x: x) root_inventory = property(lambda x: x) @@ -95,29 +95,30 @@ def iter_entries(self): def kind(self, path): if path in self.contents: - kind = 'file' + kind = "file" else: - kind = 'directory' + kind = "directory" return kind def make_entry(self, file_id, path): from ..inventory import InventoryDirectory, InventoryFile, InventoryLink + if not isinstance(file_id, bytes): raise TypeError(file_id) name = os.path.basename(path) kind = self.kind(path) parent_id = self.parent_id(file_id) text_sha_1, text_size = self.contents_stats(path) - if kind == 'directory': + if kind == "directory": ie = InventoryDirectory(file_id, name, parent_id) - elif kind == 'file': + elif kind == "file": ie = InventoryFile(file_id, name, parent_id) ie.text_sha1 = text_sha_1 ie.text_size = text_size - elif kind == 'symlink': + elif kind == "symlink": ie = InventoryLink(file_id, name, parent_id) else: - raise errors.BzrError(f'unknown kind {kind!r}') + raise errors.BzrError(f"unknown kind {kind!r}") return ie def add_dir(self, file_id, path): @@ -135,7 +136,7 @@ def add_file(self, file_id, path, contents): def path2id(self, path): return self.ids.get(path) - def id2path(self, file_id, recurse='down'): + def id2path(self, file_id, recurse="down"): try: return self.paths[file_id] except KeyError as e: @@ -175,16 +176,16 @@ def make_tree_1(self): mtree.add_dir(b"b", "grandparent/parent") mtree.add_file(b"c", "grandparent/parent/file", b"Hello\n") mtree.add_dir(b"d", "grandparent/alt_parent") - return BundleTree(mtree, b''), mtree + return BundleTree(mtree, b""), mtree def test_renames(self): """Ensure that file renames have the proper effect on children.""" btree = self.make_tree_1()[0] self.assertEqual(btree.old_path("grandparent"), "grandparent") - self.assertEqual(btree.old_path("grandparent/parent"), - "grandparent/parent") - self.assertEqual(btree.old_path("grandparent/parent/file"), - "grandparent/parent/file") + self.assertEqual(btree.old_path("grandparent/parent"), "grandparent/parent") + self.assertEqual( + 
btree.old_path("grandparent/parent/file"), "grandparent/parent/file" + ) self.assertEqual(btree.id2path(b"a"), "grandparent") self.assertEqual(btree.id2path(b"b"), "grandparent/parent") @@ -227,8 +228,7 @@ def test_renames(self): self.assertIsNone(btree.path2id("grandparent2/parent")) self.assertIsNone(btree.path2id("grandparent2/parent/file")) - btree.note_rename("grandparent/parent/file", - "grandparent2/parent2/file2") + btree.note_rename("grandparent/parent/file", "grandparent2/parent2/file2") self.assertEqual(btree.id2path(b"a"), "grandparent2") self.assertEqual(btree.id2path(b"b"), "grandparent2/parent2") self.assertEqual(btree.id2path(b"c"), "grandparent2/parent2/file2") @@ -242,8 +242,7 @@ def test_renames(self): def test_moves(self): """Ensure that file moves have the proper effect on children.""" btree = self.make_tree_1()[0] - btree.note_rename("grandparent/parent/file", - "grandparent/alt_parent/file") + btree.note_rename("grandparent/parent/file", "grandparent/alt_parent/file") self.assertEqual(btree.id2path(b"c"), "grandparent/alt_parent/file") self.assertEqual(btree.path2id("grandparent/alt_parent/file"), b"c") self.assertIsNone(btree.path2id("grandparent/parent/file")) @@ -256,8 +255,7 @@ def unified_diff(self, old, new): def make_tree_2(self): btree = self.make_tree_1()[0] - btree.note_rename("grandparent/parent/file", - "grandparent/alt_parent/file") + btree.note_rename("grandparent/parent/file", "grandparent/alt_parent/file") self.assertRaises(errors.NoSuchId, btree.id2path, b"e") self.assertFalse(btree.is_versioned("grandparent/parent/file")) btree.note_id(b"e", "grandparent/parent/file") @@ -268,8 +266,8 @@ def test_adds(self): btree = self.make_tree_2() add_patch = self.unified_diff([], [b"Extra cheese\n"]) btree.note_patch("grandparent/parent/file", add_patch) - btree.note_id(b'f', 'grandparent/parent/symlink', kind='symlink') - btree.note_target('grandparent/parent/symlink', 'venus') + btree.note_id(b"f", "grandparent/parent/symlink", kind="symlink") + btree.note_target("grandparent/parent/symlink", "venus") self.adds_test(btree) def adds_test(self, btree): @@ -278,15 +276,16 @@ def adds_test(self, btree): with btree.get_file("grandparent/parent/file") as f: self.assertEqual(f.read(), b"Extra cheese\n") self.assertEqual( - btree.get_symlink_target('grandparent/parent/symlink'), 'venus') + btree.get_symlink_target("grandparent/parent/symlink"), "venus" + ) def make_tree_3(self): btree, mtree = self.make_tree_1() mtree.add_file(b"e", "grandparent/parent/topping", b"Anchovies\n") - btree.note_rename("grandparent/parent/file", - "grandparent/alt_parent/file") - btree.note_rename("grandparent/parent/topping", - "grandparent/alt_parent/stopping") + btree.note_rename("grandparent/parent/file", "grandparent/alt_parent/file") + btree.note_rename( + "grandparent/parent/topping", "grandparent/alt_parent/stopping" + ) return btree def get_file_test(self, btree): @@ -318,55 +317,63 @@ def sorted_ids(self, tree): def test_iteration(self): """Ensure that iteration through ids works properly.""" btree = self.make_tree_1()[0] - self.assertEqual(self.sorted_ids(btree), - [inventory.ROOT_ID, b'a', b'b', b'c', b'd']) + self.assertEqual( + self.sorted_ids(btree), [inventory.ROOT_ID, b"a", b"b", b"c", b"d"] + ) btree.note_deletion("grandparent/parent/file") btree.note_id(b"e", "grandparent/alt_parent/fool", kind="directory") - btree.note_last_changed("grandparent/alt_parent/fool", - b"revisionidiguess") - self.assertEqual(self.sorted_ids(btree), - [inventory.ROOT_ID, b'a', b'b', b'd', 
b'e']) + btree.note_last_changed("grandparent/alt_parent/fool", b"revisionidiguess") + self.assertEqual( + self.sorted_ids(btree), [inventory.ROOT_ID, b"a", b"b", b"d", b"e"] + ) class BundleTester1(tests.TestCaseWithTransport): - def test_mismatched_bundle(self): format = bzrdir.BzrDirMetaFormat1() format.repository_format = knitrepo.RepositoryFormatKnit3() - serializer = BundleSerializerV08('0.8') - b = self.make_branch('.', format=format) - self.assertRaises(errors.IncompatibleBundleFormat, serializer.write, - b.repository, [], {}, BytesIO()) + serializer = BundleSerializerV08("0.8") + b = self.make_branch(".", format=format) + self.assertRaises( + errors.IncompatibleBundleFormat, + serializer.write, + b.repository, + [], + {}, + BytesIO(), + ) def test_matched_bundle(self): """Don't raise IncompatibleBundleFormat for knit2 and bundle0.9.""" format = bzrdir.BzrDirMetaFormat1() format.repository_format = knitrepo.RepositoryFormatKnit3() - serializer = BundleSerializerV09('0.9') - b = self.make_branch('.', format=format) + serializer = BundleSerializerV09("0.9") + b = self.make_branch(".", format=format) serializer.write(b.repository, [], {}, BytesIO()) def test_mismatched_model(self): """Try copying a bundle from knit2 to knit1.""" format = bzrdir.BzrDirMetaFormat1() format.repository_format = knitrepo.RepositoryFormatKnit3() - source = self.make_branch_and_tree('source', format=format) - source.commit('one', rev_id=b'one-id') - source.commit('two', rev_id=b'two-id') + source = self.make_branch_and_tree("source", format=format) + source.commit("one", rev_id=b"one-id") + source.commit("two", rev_id=b"two-id") text = BytesIO() - write_bundle(source.branch.repository, b'two-id', b'null:', text, - format='0.9') + write_bundle(source.branch.repository, b"two-id", b"null:", text, format="0.9") text.seek(0) format = bzrdir.BzrDirMetaFormat1() format.repository_format = knitrepo.RepositoryFormatKnit1() - target = self.make_branch('target', format=format) - self.assertRaises(errors.IncompatibleRevision, install_bundle, - target.repository, read_bundle(text)) + target = self.make_branch("target", format=format) + self.assertRaises( + errors.IncompatibleRevision, + install_bundle, + target.repository, + read_bundle(text), + ) class BundleTester: - def bzrdir_format(self): format = bzrdir.BzrDirMetaFormat1() format.repository_format = knitrepo.RepositoryFormatKnit1() @@ -375,8 +382,7 @@ def bzrdir_format(self): def make_branch_and_tree(self, path, format=None): if format is None: format = self.bzrdir_format() - return tests.TestCaseWithTransport.make_branch_and_tree( - self, path, format) + return tests.TestCaseWithTransport.make_branch_and_tree(self, path, format) def make_branch(self, path, format=None, name=None): if format is None: @@ -385,16 +391,18 @@ def make_branch(self, path, format=None, name=None): def create_bundle_text(self, base_rev_id, rev_id): bundle_txt = BytesIO() - rev_ids = write_bundle(self.b1.repository, rev_id, base_rev_id, - bundle_txt, format=self.format) + rev_ids = write_bundle( + self.b1.repository, rev_id, base_rev_id, bundle_txt, format=self.format + ) bundle_txt.seek(0) - self.assertEqual(bundle_txt.readline(), - b'# Bazaar revision bundle v%s\n' % self.format.encode('ascii')) - self.assertEqual(bundle_txt.readline(), b'#\n') + self.assertEqual( + bundle_txt.readline(), + b"# Bazaar revision bundle v%s\n" % self.format.encode("ascii"), + ) + self.assertEqual(bundle_txt.readline(), b"#\n") self.b1.repository.get_revision(rev_id) - 
self.assertEqual(bundle_txt.readline().decode('utf-8'), - '# message:\n') + self.assertEqual(bundle_txt.readline().decode("utf-8"), "# message:\n") bundle_txt.seek(0) return bundle_txt, rev_ids @@ -416,15 +424,20 @@ def get_valid_bundle(self, base_rev_id, rev_id, checkout_dir=None): # only will match if everything is okay, but lets be explicit about # it branch_rev = repository.get_revision(bundle_rev.revision_id) - for a in ('inventory_sha1', 'revision_id', 'parent_ids', - 'timestamp', 'timezone', 'message', 'committer', - 'parent_ids', 'properties'): - self.assertEqual(getattr(branch_rev, a), - getattr(bundle_rev, a)) - self.assertEqual(len(branch_rev.parent_ids), - len(bundle_rev.parent_ids)) - self.assertEqual(rev_ids, - [r.revision_id for r in bundle.real_revisions]) + for a in ( + "inventory_sha1", + "revision_id", + "parent_ids", + "timestamp", + "timezone", + "message", + "committer", + "parent_ids", + "properties", + ): + self.assertEqual(getattr(branch_rev, a), getattr(bundle_rev, a)) + self.assertEqual(len(branch_rev.parent_ids), len(bundle_rev.parent_ids)) + self.assertEqual(rev_ids, [r.revision_id for r in bundle.real_revisions]) self.valid_apply_bundle(base_rev_id, bundle, checkout_dir=checkout_dir) return bundle @@ -436,24 +449,23 @@ def get_invalid_bundle(self, base_rev_id, rev_id): :return: The in-memory bundle """ bundle_txt, rev_ids = self.create_bundle_text(base_rev_id, rev_id) - new_text = bundle_txt.getvalue().replace(b'executable:no', - b'executable:yes') + new_text = bundle_txt.getvalue().replace(b"executable:no", b"executable:yes") bundle_txt = BytesIO(new_text) bundle = read_bundle(bundle_txt) self.valid_apply_bundle(base_rev_id, bundle) return bundle def test_non_bundle(self): - self.assertRaises(errors.NotABundle, - read_bundle, BytesIO(b'#!/bin/sh\n')) + self.assertRaises(errors.NotABundle, read_bundle, BytesIO(b"#!/bin/sh\n")) def test_malformed(self): - self.assertRaises(errors.BadBundle, read_bundle, - BytesIO(b'# Bazaar revision bundle v')) + self.assertRaises( + errors.BadBundle, read_bundle, BytesIO(b"# Bazaar revision bundle v") + ) def test_crlf_bundle(self): try: - read_bundle(BytesIO(b'# Bazaar revision bundle v0.8\r\n')) + read_bundle(BytesIO(b"# Bazaar revision bundle v0.8\r\n")) except errors.BadBundle: # It is currently permitted for bundles with crlf line endings to # make read_bundle raise a BadBundle, but this should be fixed. @@ -463,14 +475,15 @@ def test_crlf_bundle(self): def get_checkout(self, rev_id, checkout_dir=None): """Get a new tree, with the specified revision in it.""" if checkout_dir is None: - checkout_dir = tempfile.mkdtemp(prefix='test-branch-', dir='.') + checkout_dir = tempfile.mkdtemp(prefix="test-branch-", dir=".") else: if not os.path.exists(checkout_dir): os.mkdir(checkout_dir) tree = self.make_branch_and_tree(checkout_dir) s = BytesIO() - ancestors = write_bundle(self.b1.repository, rev_id, b'null:', s, - format=self.format) + ancestors = write_bundle( + self.b1.repository, rev_id, b"null:", s, format=self.format + ) s.seek(0) self.assertIsInstance(s.getvalue(), bytes) install_bundle(tree.branch.repository, read_bundle(s)) @@ -480,8 +493,9 @@ def get_checkout(self, rev_id, checkout_dir=None): with old.lock_read(), new.lock_read(): # Check that there aren't any inventory level changes delta = new.changes_from(old) - self.assertFalse(delta.has_changed(), - f'Revision {ancestor} not copied correctly.') + self.assertFalse( + delta.has_changed(), f"Revision {ancestor} not copied correctly." 
+ ) # Now check that the file contents are all correct for path in old.all_versioned_paths(): @@ -489,14 +503,14 @@ def get_checkout(self, rev_id, checkout_dir=None): old_file = old.get_file(path) except _mod_transport.NoSuchFile: continue - self.assertEqual( - old_file.read(), new.get_file(path).read()) + self.assertEqual(old_file.read(), new.get_file(path).read()) if not _mod_revision.is_null(rev_id): tree.branch.generate_revision_history(rev_id) tree.update() delta = tree.changes_from(self.b1.repository.revision_tree(rev_id)) - self.assertFalse(delta.has_changed(), - f'Working tree has modifications: {delta}') + self.assertFalse( + delta.has_changed(), f"Working tree has modifications: {delta}" + ) return tree def valid_apply_bundle(self, base_rev_id, info, checkout_dir=None): @@ -516,21 +530,22 @@ def _valid_apply_bundle(self, base_rev_id, info, to_tree): original_parents = to_tree.get_parent_ids() self.assertTrue(repository.has_revision(base_rev_id)) for rev in info.real_revisions: - self.assertTrue(not repository.has_revision(rev.revision_id), - 'Revision {%s} present before applying bundle' - % rev.revision_id) + self.assertTrue( + not repository.has_revision(rev.revision_id), + "Revision {%s} present before applying bundle" % rev.revision_id, + ) merge_bundle(info, to_tree, True, merge.Merge3Merger, False, False) for rev in info.real_revisions: - self.assertTrue(repository.has_revision(rev.revision_id), - 'Missing revision {%s} after applying bundle' - % rev.revision_id) + self.assertTrue( + repository.has_revision(rev.revision_id), + "Missing revision {%s} after applying bundle" % rev.revision_id, + ) self.assertTrue(to_tree.branch.repository.has_revision(info.target)) # Do we also want to verify that all the texts have been added? - self.assertEqual(original_parents + [info.target], - to_tree.get_parent_ids()) + self.assertEqual(original_parents + [info.target], to_tree.get_parent_ids()) rev = info.real_revisions[-1] base_tree = self.b1.repository.revision_tree(rev.revision_id) @@ -549,264 +564,283 @@ def _valid_apply_bundle(self, base_rev_id, info, to_tree): # Check that the meta information is the same to_path = InterTree.get(base_tree, to_tree).find_target_path(path) self.assertEqual( - base_tree.get_file_size(path), - to_tree.get_file_size(to_path)) + base_tree.get_file_size(path), to_tree.get_file_size(to_path) + ) self.assertEqual( - base_tree.get_file_sha1(path), - to_tree.get_file_sha1(to_path)) + base_tree.get_file_sha1(path), to_tree.get_file_sha1(to_path) + ) # Check that the contents are the same # This is pretty expensive # self.assertEqual(base_tree.get_file(fileid).read(), # to_tree.get_file(fileid).read()) def test_bundle(self): - self.tree1 = self.make_branch_and_tree('b1') + self.tree1 = self.make_branch_and_tree("b1") self.b1 = self.tree1.branch - self.build_tree_contents([('b1/one', b'one\n')]) - self.tree1.add('one', ids=b'one-id') - self.tree1.set_root_id(b'root-id') - self.tree1.commit('add one', rev_id=b'a@cset-0-1') + self.build_tree_contents([("b1/one", b"one\n")]) + self.tree1.add("one", ids=b"one-id") + self.tree1.set_root_id(b"root-id") + self.tree1.commit("add one", rev_id=b"a@cset-0-1") - self.get_valid_bundle(b'null:', b'a@cset-0-1') + self.get_valid_bundle(b"null:", b"a@cset-0-1") # Make sure we can handle files with spaces, tabs, other # bogus characters - self.build_tree([ - 'b1/with space.txt', 'b1/dir/', 'b1/dir/filein subdir.c', 'b1/dir/WithCaps.txt', 'b1/dir/ pre space', 'b1/sub/', 'b1/sub/sub/', 'b1/sub/sub/nonempty.txt' - ]) - 
self.build_tree_contents([('b1/sub/sub/emptyfile.txt', b''), - ('b1/dir/nolastnewline.txt', b'bloop')]) + self.build_tree( + [ + "b1/with space.txt", + "b1/dir/", + "b1/dir/filein subdir.c", + "b1/dir/WithCaps.txt", + "b1/dir/ pre space", + "b1/sub/", + "b1/sub/sub/", + "b1/sub/sub/nonempty.txt", + ] + ) + self.build_tree_contents( + [("b1/sub/sub/emptyfile.txt", b""), ("b1/dir/nolastnewline.txt", b"bloop")] + ) tt = self.tree1.transform() - tt.new_file('executable', tt.root, [b'#!/bin/sh\n'], b'exe-1', True) + tt.new_file("executable", tt.root, [b"#!/bin/sh\n"], b"exe-1", True) tt.apply() # have to fix length of file-id so that we can predictably rewrite # a (length-prefixed) record containing it later. - self.tree1.add('with space.txt', ids=b'withspace-id') - self.tree1.add([ - 'dir', 'dir/filein subdir.c', 'dir/WithCaps.txt', 'dir/ pre space', 'dir/nolastnewline.txt', 'sub', 'sub/sub', 'sub/sub/nonempty.txt', 'sub/sub/emptyfile.txt' - ]) - self.tree1.commit('add whitespace', rev_id=b'a@cset-0-2') - - self.get_valid_bundle(b'a@cset-0-1', b'a@cset-0-2') + self.tree1.add("with space.txt", ids=b"withspace-id") + self.tree1.add( + [ + "dir", + "dir/filein subdir.c", + "dir/WithCaps.txt", + "dir/ pre space", + "dir/nolastnewline.txt", + "sub", + "sub/sub", + "sub/sub/nonempty.txt", + "sub/sub/emptyfile.txt", + ] + ) + self.tree1.commit("add whitespace", rev_id=b"a@cset-0-2") + + self.get_valid_bundle(b"a@cset-0-1", b"a@cset-0-2") # Check a rollup bundle - self.get_valid_bundle(b'null:', b'a@cset-0-2') + self.get_valid_bundle(b"null:", b"a@cset-0-2") # Now delete entries - self.tree1.remove( - ['sub/sub/nonempty.txt', 'sub/sub/emptyfile.txt', 'sub/sub' - ]) + self.tree1.remove(["sub/sub/nonempty.txt", "sub/sub/emptyfile.txt", "sub/sub"]) tt = self.tree1.transform() - trans_id = tt.trans_id_tree_path('executable') + trans_id = tt.trans_id_tree_path("executable") tt.set_executability(False, trans_id) tt.apply() - self.tree1.commit('removed', rev_id=b'a@cset-0-3') - - self.get_valid_bundle(b'a@cset-0-2', b'a@cset-0-3') - self.assertRaises((errors.TestamentMismatch, - errors.VersionedFileInvalidChecksum, - errors.BadBundle), self.get_invalid_bundle, - b'a@cset-0-2', b'a@cset-0-3') + self.tree1.commit("removed", rev_id=b"a@cset-0-3") + + self.get_valid_bundle(b"a@cset-0-2", b"a@cset-0-3") + self.assertRaises( + ( + errors.TestamentMismatch, + errors.VersionedFileInvalidChecksum, + errors.BadBundle, + ), + self.get_invalid_bundle, + b"a@cset-0-2", + b"a@cset-0-3", + ) # Check a rollup bundle - self.get_valid_bundle(b'null:', b'a@cset-0-3') + self.get_valid_bundle(b"null:", b"a@cset-0-3") # Now move the directory - self.tree1.rename_one('dir', 'sub/dir') - self.tree1.commit('rename dir', rev_id=b'a@cset-0-4') + self.tree1.rename_one("dir", "sub/dir") + self.tree1.commit("rename dir", rev_id=b"a@cset-0-4") - self.get_valid_bundle(b'a@cset-0-3', b'a@cset-0-4') + self.get_valid_bundle(b"a@cset-0-3", b"a@cset-0-4") # Check a rollup bundle - self.get_valid_bundle(b'null:', b'a@cset-0-4') + self.get_valid_bundle(b"null:", b"a@cset-0-4") # Modified files - with open('b1/sub/dir/WithCaps.txt', 'ab') as f: - f.write(b'\nAdding some text\n') - with open('b1/sub/dir/ pre space', 'ab') as f: - f.write( - b'\r\nAdding some\r\nDOS format lines\r\n') - with open('b1/sub/dir/nolastnewline.txt', 'ab') as f: - f.write(b'\n') - self.tree1.rename_one('sub/dir/ pre space', - 'sub/ start space') - self.tree1.commit('Modified files', rev_id=b'a@cset-0-5') - self.get_valid_bundle(b'a@cset-0-4', b'a@cset-0-5') - - 
self.tree1.rename_one('sub/dir/WithCaps.txt', 'temp') - self.tree1.rename_one('with space.txt', 'WithCaps.txt') - self.tree1.rename_one('temp', 'with space.txt') - self.tree1.commit('swap filenames', rev_id=b'a@cset-0-6', - verbose=False) - self.get_valid_bundle(b'a@cset-0-5', b'a@cset-0-6') - other = self.get_checkout(b'a@cset-0-5') - tree1_inv = get_inventory_text(self.tree1.branch.repository, - b'a@cset-0-5') - tree2_inv = get_inventory_text(other.branch.repository, - b'a@cset-0-5') + with open("b1/sub/dir/WithCaps.txt", "ab") as f: + f.write(b"\nAdding some text\n") + with open("b1/sub/dir/ pre space", "ab") as f: + f.write(b"\r\nAdding some\r\nDOS format lines\r\n") + with open("b1/sub/dir/nolastnewline.txt", "ab") as f: + f.write(b"\n") + self.tree1.rename_one("sub/dir/ pre space", "sub/ start space") + self.tree1.commit("Modified files", rev_id=b"a@cset-0-5") + self.get_valid_bundle(b"a@cset-0-4", b"a@cset-0-5") + + self.tree1.rename_one("sub/dir/WithCaps.txt", "temp") + self.tree1.rename_one("with space.txt", "WithCaps.txt") + self.tree1.rename_one("temp", "with space.txt") + self.tree1.commit("swap filenames", rev_id=b"a@cset-0-6", verbose=False) + self.get_valid_bundle(b"a@cset-0-5", b"a@cset-0-6") + other = self.get_checkout(b"a@cset-0-5") + tree1_inv = get_inventory_text(self.tree1.branch.repository, b"a@cset-0-5") + tree2_inv = get_inventory_text(other.branch.repository, b"a@cset-0-5") self.assertEqualDiff(tree1_inv, tree2_inv) - other.rename_one('sub/dir/nolastnewline.txt', 'sub/nolastnewline.txt') - other.commit('rename file', rev_id=b'a@cset-0-6b') + other.rename_one("sub/dir/nolastnewline.txt", "sub/nolastnewline.txt") + other.commit("rename file", rev_id=b"a@cset-0-6b") self.tree1.merge_from_branch(other.branch) - self.tree1.commit('Merge', rev_id=b'a@cset-0-7', - verbose=False) - self.get_valid_bundle(b'a@cset-0-6', b'a@cset-0-7') + self.tree1.commit("Merge", rev_id=b"a@cset-0-7", verbose=False) + self.get_valid_bundle(b"a@cset-0-6", b"a@cset-0-7") def _test_symlink_bundle(self, link_name, link_target, new_link_target): - link_id = b'link-1' + link_id = b"link-1" self.requireFeature(features.SymlinkFeature(self.test_dir)) - self.tree1 = self.make_branch_and_tree('b1') + self.tree1 = self.make_branch_and_tree("b1") self.b1 = self.tree1.branch tt = self.tree1.transform() tt.new_symlink(link_name, tt.root, link_target, link_id) tt.apply() - self.tree1.commit('add symlink', rev_id=b'l@cset-0-1') - bundle = self.get_valid_bundle(b'null:', b'l@cset-0-1') - if getattr(bundle, 'revision_tree', None) is not None: + self.tree1.commit("add symlink", rev_id=b"l@cset-0-1") + bundle = self.get_valid_bundle(b"null:", b"l@cset-0-1") + if getattr(bundle, "revision_tree", None) is not None: # Not all bundle formats supports revision_tree - bund_tree = bundle.revision_tree(self.b1.repository, b'l@cset-0-1') - self.assertEqual( - link_target, bund_tree.get_symlink_target(link_name)) + bund_tree = bundle.revision_tree(self.b1.repository, b"l@cset-0-1") + self.assertEqual(link_target, bund_tree.get_symlink_target(link_name)) tt = self.tree1.transform() trans_id = tt.trans_id_tree_path(link_name) - tt.adjust_path('link2', tt.root, trans_id) + tt.adjust_path("link2", tt.root, trans_id) tt.delete_contents(trans_id) tt.create_symlink(new_link_target, trans_id) tt.apply() - self.tree1.commit('rename and change symlink', rev_id=b'l@cset-0-2') - bundle = self.get_valid_bundle(b'l@cset-0-1', b'l@cset-0-2') - if getattr(bundle, 'revision_tree', None) is not None: + self.tree1.commit("rename and change 
symlink", rev_id=b"l@cset-0-2") + bundle = self.get_valid_bundle(b"l@cset-0-1", b"l@cset-0-2") + if getattr(bundle, "revision_tree", None) is not None: # Not all bundle formats supports revision_tree - bund_tree = bundle.revision_tree(self.b1.repository, b'l@cset-0-2') - self.assertEqual(new_link_target, - bund_tree.get_symlink_target('link2')) + bund_tree = bundle.revision_tree(self.b1.repository, b"l@cset-0-2") + self.assertEqual(new_link_target, bund_tree.get_symlink_target("link2")) tt = self.tree1.transform() - trans_id = tt.trans_id_tree_path('link2') + trans_id = tt.trans_id_tree_path("link2") tt.delete_contents(trans_id) - tt.create_symlink('jupiter', trans_id) + tt.create_symlink("jupiter", trans_id) tt.apply() - self.tree1.commit('just change symlink target', rev_id=b'l@cset-0-3') - bundle = self.get_valid_bundle(b'l@cset-0-2', b'l@cset-0-3') + self.tree1.commit("just change symlink target", rev_id=b"l@cset-0-3") + bundle = self.get_valid_bundle(b"l@cset-0-2", b"l@cset-0-3") tt = self.tree1.transform() - trans_id = tt.trans_id_tree_path('link2') + trans_id = tt.trans_id_tree_path("link2") tt.delete_contents(trans_id) tt.apply() - self.tree1.commit('Delete symlink', rev_id=b'l@cset-0-4') - bundle = self.get_valid_bundle(b'l@cset-0-3', b'l@cset-0-4') + self.tree1.commit("Delete symlink", rev_id=b"l@cset-0-4") + bundle = self.get_valid_bundle(b"l@cset-0-3", b"l@cset-0-4") def test_symlink_bundle(self): - self._test_symlink_bundle('link', 'bar/foo', 'mars') + self._test_symlink_bundle("link", "bar/foo", "mars") def test_unicode_symlink_bundle(self): self.requireFeature(features.UnicodeFilenameFeature) - self._test_symlink_bundle('\N{Euro Sign}link', - 'bar/\N{Euro Sign}foo', - 'mars\N{Euro Sign}') + self._test_symlink_bundle( + "\N{Euro Sign}link", "bar/\N{Euro Sign}foo", "mars\N{Euro Sign}" + ) def test_binary_bundle(self): - self.tree1 = self.make_branch_and_tree('b1') + self.tree1 = self.make_branch_and_tree("b1") self.b1 = self.tree1.branch tt = self.tree1.transform() # Add - tt.new_file('file', tt.root, [ - b'\x00\n\x00\r\x01\n\x02\r\xff'], b'binary-1') - tt.new_file('file2', tt.root, [b'\x01\n\x02\r\x03\n\x04\r\xff'], - b'binary-2') + tt.new_file("file", tt.root, [b"\x00\n\x00\r\x01\n\x02\r\xff"], b"binary-1") + tt.new_file("file2", tt.root, [b"\x01\n\x02\r\x03\n\x04\r\xff"], b"binary-2") tt.apply() - self.tree1.commit('add binary', rev_id=b'b@cset-0-1') - self.get_valid_bundle(b'null:', b'b@cset-0-1') + self.tree1.commit("add binary", rev_id=b"b@cset-0-1") + self.get_valid_bundle(b"null:", b"b@cset-0-1") # Delete tt = self.tree1.transform() - trans_id = tt.trans_id_tree_path('file') + trans_id = tt.trans_id_tree_path("file") tt.delete_contents(trans_id) tt.apply() - self.tree1.commit('delete binary', rev_id=b'b@cset-0-2') - self.get_valid_bundle(b'b@cset-0-1', b'b@cset-0-2') + self.tree1.commit("delete binary", rev_id=b"b@cset-0-2") + self.get_valid_bundle(b"b@cset-0-1", b"b@cset-0-2") # Rename & modify tt = self.tree1.transform() - trans_id = tt.trans_id_tree_path('file2') - tt.adjust_path('file3', tt.root, trans_id) + trans_id = tt.trans_id_tree_path("file2") + tt.adjust_path("file3", tt.root, trans_id) tt.delete_contents(trans_id) - tt.create_file([b'file\rcontents\x00\n\x00'], trans_id) + tt.create_file([b"file\rcontents\x00\n\x00"], trans_id) tt.apply() - self.tree1.commit('rename and modify binary', rev_id=b'b@cset-0-3') - self.get_valid_bundle(b'b@cset-0-2', b'b@cset-0-3') + self.tree1.commit("rename and modify binary", rev_id=b"b@cset-0-3") + 
self.get_valid_bundle(b"b@cset-0-2", b"b@cset-0-3") # Modify tt = self.tree1.transform() - trans_id = tt.trans_id_tree_path('file3') + trans_id = tt.trans_id_tree_path("file3") tt.delete_contents(trans_id) - tt.create_file([b'\x00file\rcontents'], trans_id) + tt.create_file([b"\x00file\rcontents"], trans_id) tt.apply() - self.tree1.commit('just modify binary', rev_id=b'b@cset-0-4') - self.get_valid_bundle(b'b@cset-0-3', b'b@cset-0-4') + self.tree1.commit("just modify binary", rev_id=b"b@cset-0-4") + self.get_valid_bundle(b"b@cset-0-3", b"b@cset-0-4") # Rollup - self.get_valid_bundle(b'null:', b'b@cset-0-4') + self.get_valid_bundle(b"null:", b"b@cset-0-4") def test_last_modified(self): - self.tree1 = self.make_branch_and_tree('b1') + self.tree1 = self.make_branch_and_tree("b1") self.b1 = self.tree1.branch tt = self.tree1.transform() - tt.new_file('file', tt.root, [b'file'], b'file') + tt.new_file("file", tt.root, [b"file"], b"file") tt.apply() - self.tree1.commit('create file', rev_id=b'a@lmod-0-1') + self.tree1.commit("create file", rev_id=b"a@lmod-0-1") tt = self.tree1.transform() - trans_id = tt.trans_id_tree_path('file') + trans_id = tt.trans_id_tree_path("file") tt.delete_contents(trans_id) - tt.create_file([b'file2'], trans_id) + tt.create_file([b"file2"], trans_id) tt.apply() - self.tree1.commit('modify text', rev_id=b'a@lmod-0-2a') + self.tree1.commit("modify text", rev_id=b"a@lmod-0-2a") - other = self.get_checkout(b'a@lmod-0-1') + other = self.get_checkout(b"a@lmod-0-1") tt = other.transform() - trans_id = tt.trans_id_tree_path('file2') + trans_id = tt.trans_id_tree_path("file2") tt.delete_contents(trans_id) - tt.create_file([b'file2'], trans_id) + tt.create_file([b"file2"], trans_id) tt.apply() - other.commit('modify text in another tree', rev_id=b'a@lmod-0-2b') + other.commit("modify text in another tree", rev_id=b"a@lmod-0-2b") self.tree1.merge_from_branch(other.branch) - self.tree1.commit('Merge', rev_id=b'a@lmod-0-3', - verbose=False) - self.tree1.commit('Merge', rev_id=b'a@lmod-0-4') - self.get_valid_bundle(b'a@lmod-0-2a', b'a@lmod-0-4') + self.tree1.commit("Merge", rev_id=b"a@lmod-0-3", verbose=False) + self.tree1.commit("Merge", rev_id=b"a@lmod-0-4") + self.get_valid_bundle(b"a@lmod-0-2a", b"a@lmod-0-4") def test_hide_history(self): - self.tree1 = self.make_branch_and_tree('b1') + self.tree1 = self.make_branch_and_tree("b1") self.b1 = self.tree1.branch - with open('b1/one', 'wb') as f: - f.write(b'one\n') - self.tree1.add('one') - self.tree1.commit('add file', rev_id=b'a@cset-0-1') - with open('b1/one', 'wb') as f: - f.write(b'two\n') - self.tree1.commit('modify', rev_id=b'a@cset-0-2') - with open('b1/one', 'wb') as f: - f.write(b'three\n') - self.tree1.commit('modify', rev_id=b'a@cset-0-3') + with open("b1/one", "wb") as f: + f.write(b"one\n") + self.tree1.add("one") + self.tree1.commit("add file", rev_id=b"a@cset-0-1") + with open("b1/one", "wb") as f: + f.write(b"two\n") + self.tree1.commit("modify", rev_id=b"a@cset-0-2") + with open("b1/one", "wb") as f: + f.write(b"three\n") + self.tree1.commit("modify", rev_id=b"a@cset-0-3") bundle_file = BytesIO() - write_bundle(self.tree1.branch.repository, b'a@cset-0-3', - b'a@cset-0-1', bundle_file, format=self.format) - self.assertNotContainsRe(bundle_file.getvalue(), b'\btwo\b') - self.assertContainsRe(self.get_raw(bundle_file), b'one') - self.assertContainsRe(self.get_raw(bundle_file), b'three') + write_bundle( + self.tree1.branch.repository, + b"a@cset-0-3", + b"a@cset-0-1", + bundle_file, + format=self.format, + ) + 
self.assertNotContainsRe(bundle_file.getvalue(), b"\btwo\b") + self.assertContainsRe(self.get_raw(bundle_file), b"one") + self.assertContainsRe(self.get_raw(bundle_file), b"three") def test_bundle_same_basis(self): """Ensure using the basis as the target doesn't cause an error.""" - self.tree1 = self.make_branch_and_tree('b1') - self.tree1.commit('add file', rev_id=b'a@cset-0-1') + self.tree1 = self.make_branch_and_tree("b1") + self.tree1.commit("add file", rev_id=b"a@cset-0-1") bundle_file = BytesIO() - write_bundle(self.tree1.branch.repository, b'a@cset-0-1', - b'a@cset-0-1', bundle_file) + write_bundle( + self.tree1.branch.repository, b"a@cset-0-1", b"a@cset-0-1", bundle_file + ) @staticmethod def get_raw(bundle_file): @@ -815,353 +849,369 @@ def get_raw(bundle_file): def test_unicode_bundle(self): self.requireFeature(features.UnicodeFilenameFeature) # Handle international characters - os.mkdir('b1') - f = open('b1/with Dod\N{Euro Sign}', 'wb') + os.mkdir("b1") + f = open("b1/with Dod\N{Euro Sign}", "wb") - self.tree1 = self.make_branch_and_tree('b1') + self.tree1 = self.make_branch_and_tree("b1") self.b1 = self.tree1.branch - f.write(('A file\n' - 'With international man of mystery\n' - 'William Dod\xe9\n').encode()) + f.write( + ( + "A file\n" "With international man of mystery\n" "William Dod\xe9\n" + ).encode() + ) f.close() - self.tree1.add(['with Dod\N{Euro Sign}'], ids=[b'withdod-id']) - self.tree1.commit('i18n commit from William Dod\xe9', - rev_id=b'i18n-1', committer='William Dod\xe9') + self.tree1.add(["with Dod\N{Euro Sign}"], ids=[b"withdod-id"]) + self.tree1.commit( + "i18n commit from William Dod\xe9", + rev_id=b"i18n-1", + committer="William Dod\xe9", + ) # Add - self.get_valid_bundle(b'null:', b'i18n-1') + self.get_valid_bundle(b"null:", b"i18n-1") # Modified - f = open('b1/with Dod\N{Euro Sign}', 'wb') - f.write('Modified \xb5\n'.encode()) + f = open("b1/with Dod\N{Euro Sign}", "wb") + f.write("Modified \xb5\n".encode()) f.close() - self.tree1.commit('modified', rev_id=b'i18n-2') + self.tree1.commit("modified", rev_id=b"i18n-2") - self.get_valid_bundle(b'i18n-1', b'i18n-2') + self.get_valid_bundle(b"i18n-1", b"i18n-2") # Renamed - self.tree1.rename_one('with Dod\N{Euro Sign}', 'B\N{Euro Sign}gfors') - self.tree1.commit('renamed, the new i18n man', rev_id=b'i18n-3', - committer='Erik B\xe5gfors') + self.tree1.rename_one("with Dod\N{Euro Sign}", "B\N{Euro Sign}gfors") + self.tree1.commit( + "renamed, the new i18n man", rev_id=b"i18n-3", committer="Erik B\xe5gfors" + ) - self.get_valid_bundle(b'i18n-2', b'i18n-3') + self.get_valid_bundle(b"i18n-2", b"i18n-3") # Removed - self.tree1.remove(['B\N{Euro Sign}gfors']) - self.tree1.commit('removed', rev_id=b'i18n-4') + self.tree1.remove(["B\N{Euro Sign}gfors"]) + self.tree1.commit("removed", rev_id=b"i18n-4") - self.get_valid_bundle(b'i18n-3', b'i18n-4') + self.get_valid_bundle(b"i18n-3", b"i18n-4") # Rollup - self.get_valid_bundle(b'null:', b'i18n-4') + self.get_valid_bundle(b"null:", b"i18n-4") def test_whitespace_bundle(self): - if sys.platform in ('win32', 'cygwin'): - raise tests.TestSkipped('Windows doesn\'t support filenames' - ' with tabs or trailing spaces') - self.tree1 = self.make_branch_and_tree('b1') + if sys.platform in ("win32", "cygwin"): + raise tests.TestSkipped( + "Windows doesn't support filenames" " with tabs or trailing spaces" + ) + self.tree1 = self.make_branch_and_tree("b1") self.b1 = self.tree1.branch - self.build_tree(['b1/trailing space ']) - self.tree1.add(['trailing space ']) + 
self.build_tree(["b1/trailing space "]) + self.tree1.add(["trailing space "]) # TODO: jam 20060701 Check for handling files with '\t' characters # once we actually support them # Added - self.tree1.commit('funky whitespace', rev_id=b'white-1') + self.tree1.commit("funky whitespace", rev_id=b"white-1") - self.get_valid_bundle(b'null:', b'white-1') + self.get_valid_bundle(b"null:", b"white-1") # Modified - with open('b1/trailing space ', 'ab') as f: - f.write(b'add some text\n') - self.tree1.commit('add text', rev_id=b'white-2') + with open("b1/trailing space ", "ab") as f: + f.write(b"add some text\n") + self.tree1.commit("add text", rev_id=b"white-2") - self.get_valid_bundle(b'white-1', b'white-2') + self.get_valid_bundle(b"white-1", b"white-2") # Renamed - self.tree1.rename_one('trailing space ', ' start and end space ') - self.tree1.commit('rename', rev_id=b'white-3') + self.tree1.rename_one("trailing space ", " start and end space ") + self.tree1.commit("rename", rev_id=b"white-3") - self.get_valid_bundle(b'white-2', b'white-3') + self.get_valid_bundle(b"white-2", b"white-3") # Removed - self.tree1.remove([' start and end space ']) - self.tree1.commit('removed', rev_id=b'white-4') + self.tree1.remove([" start and end space "]) + self.tree1.commit("removed", rev_id=b"white-4") - self.get_valid_bundle(b'white-3', b'white-4') + self.get_valid_bundle(b"white-3", b"white-4") # Now test a complet roll-up - self.get_valid_bundle(b'null:', b'white-4') + self.get_valid_bundle(b"null:", b"white-4") def test_alt_timezone_bundle(self): - self.tree1 = self.make_branch_and_memory_tree('b1') + self.tree1 = self.make_branch_and_memory_tree("b1") self.b1 = self.tree1.branch builder = treebuilder.TreeBuilder() self.tree1.lock_write() builder.start_tree(self.tree1) - builder.build(['newfile']) + builder.build(["newfile"]) builder.finish_tree() # Asia/Colombo offset = 5 hours 30 minutes - self.tree1.commit('non-hour offset timezone', rev_id=b'tz-1', - timezone=19800, timestamp=1152544886.0) + self.tree1.commit( + "non-hour offset timezone", + rev_id=b"tz-1", + timezone=19800, + timestamp=1152544886.0, + ) - bundle = self.get_valid_bundle(b'null:', b'tz-1') + bundle = self.get_valid_bundle(b"null:", b"tz-1") rev = bundle.revisions[0] - self.assertEqual('Mon 2006-07-10 20:51:26.000000000 +0530', rev.date) + self.assertEqual("Mon 2006-07-10 20:51:26.000000000 +0530", rev.date) self.assertEqual(19800, rev.timezone) self.assertEqual(1152544886.0, rev.timestamp) self.tree1.unlock() def test_bundle_root_id(self): - self.tree1 = self.make_branch_and_tree('b1') + self.tree1 = self.make_branch_and_tree("b1") self.b1 = self.tree1.branch - self.tree1.commit('message', rev_id=b'revid1') - bundle = self.get_valid_bundle(b'null:', b'revid1') - tree = self.get_bundle_tree(bundle, b'revid1') - root_revision = tree.get_file_revision('') - self.assertEqual(b'revid1', root_revision) + self.tree1.commit("message", rev_id=b"revid1") + bundle = self.get_valid_bundle(b"null:", b"revid1") + tree = self.get_bundle_tree(bundle, b"revid1") + root_revision = tree.get_file_revision("") + self.assertEqual(b"revid1", root_revision) def test_install_revisions(self): - self.tree1 = self.make_branch_and_tree('b1') + self.tree1 = self.make_branch_and_tree("b1") self.b1 = self.tree1.branch - self.tree1.commit('message', rev_id=b'rev2a') - bundle = self.get_valid_bundle(b'null:', b'rev2a') - branch2 = self.make_branch('b2') - self.assertFalse(branch2.repository.has_revision(b'rev2a')) + self.tree1.commit("message", rev_id=b"rev2a") + bundle = 
self.get_valid_bundle(b"null:", b"rev2a") + branch2 = self.make_branch("b2") + self.assertFalse(branch2.repository.has_revision(b"rev2a")) target_revision = bundle.install_revisions(branch2.repository) - self.assertTrue(branch2.repository.has_revision(b'rev2a')) - self.assertEqual(b'rev2a', target_revision) + self.assertTrue(branch2.repository.has_revision(b"rev2a")) + self.assertEqual(b"rev2a", target_revision) def test_bundle_empty_property(self): """Test serializing revision properties with an empty value.""" - tree = self.make_branch_and_memory_tree('tree') + tree = self.make_branch_and_memory_tree("tree") tree.lock_write() self.addCleanup(tree.unlock) - tree.add([''], ids=[b'TREE_ROOT']) - tree.commit('One', revprops={'one': 'two', - 'empty': ''}, rev_id=b'rev1') + tree.add([""], ids=[b"TREE_ROOT"]) + tree.commit("One", revprops={"one": "two", "empty": ""}, rev_id=b"rev1") self.b1 = tree.branch - bundle_sio, revision_ids = self.create_bundle_text(b'null:', b'rev1') + bundle_sio, revision_ids = self.create_bundle_text(b"null:", b"rev1") bundle = read_bundle(bundle_sio) revision_info = bundle.revisions[0] - self.assertEqual(b'rev1', revision_info.revision_id) + self.assertEqual(b"rev1", revision_info.revision_id) rev = revision_info.as_revision() - self.assertEqual({'branch-nick': 'tree', 'empty': '', 'one': 'two'}, - rev.properties) + self.assertEqual( + {"branch-nick": "tree", "empty": "", "one": "two"}, rev.properties + ) def test_bundle_sorted_properties(self): """For stability the writer should write properties in sorted order.""" - tree = self.make_branch_and_memory_tree('tree') + tree = self.make_branch_and_memory_tree("tree") tree.lock_write() self.addCleanup(tree.unlock) - tree.add([''], ids=[b'TREE_ROOT']) - tree.commit('One', rev_id=b'rev1', - revprops={'a': '4', 'b': '3', 'c': '2', 'd': '1'}) + tree.add([""], ids=[b"TREE_ROOT"]) + tree.commit( + "One", rev_id=b"rev1", revprops={"a": "4", "b": "3", "c": "2", "d": "1"} + ) self.b1 = tree.branch - bundle_sio, revision_ids = self.create_bundle_text(b'null:', b'rev1') + bundle_sio, revision_ids = self.create_bundle_text(b"null:", b"rev1") bundle = read_bundle(bundle_sio) revision_info = bundle.revisions[0] - self.assertEqual(b'rev1', revision_info.revision_id) + self.assertEqual(b"rev1", revision_info.revision_id) rev = revision_info.as_revision() - self.assertEqual({'branch-nick': 'tree', 'a': '4', 'b': '3', 'c': '2', - 'd': '1'}, rev.properties) + self.assertEqual( + {"branch-nick": "tree", "a": "4", "b": "3", "c": "2", "d": "1"}, + rev.properties, + ) def test_bundle_unicode_properties(self): """We should be able to round trip a non-ascii property.""" - tree = self.make_branch_and_memory_tree('tree') + tree = self.make_branch_and_memory_tree("tree") tree.lock_write() self.addCleanup(tree.unlock) - tree.add([''], ids=[b'TREE_ROOT']) + tree.add([""], ids=[b"TREE_ROOT"]) # Revisions themselves do not require anything about revision property # keys, other than that they are a basestring, and do not contain # whitespace. # However, Testaments assert than they are str(), and thus should not # be Unicode. 
- tree.commit('One', rev_id=b'rev1', - revprops={'omega': '\u03a9', 'alpha': '\u03b1'}) + tree.commit( + "One", rev_id=b"rev1", revprops={"omega": "\u03a9", "alpha": "\u03b1"} + ) self.b1 = tree.branch - bundle_sio, revision_ids = self.create_bundle_text(b'null:', b'rev1') + bundle_sio, revision_ids = self.create_bundle_text(b"null:", b"rev1") bundle = read_bundle(bundle_sio) revision_info = bundle.revisions[0] - self.assertEqual(b'rev1', revision_info.revision_id) + self.assertEqual(b"rev1", revision_info.revision_id) rev = revision_info.as_revision() - self.assertEqual({'branch-nick': 'tree', 'omega': '\u03a9', - 'alpha': '\u03b1'}, rev.properties) + self.assertEqual( + {"branch-nick": "tree", "omega": "\u03a9", "alpha": "\u03b1"}, + rev.properties, + ) def test_bundle_with_ghosts(self): - tree = self.make_branch_and_tree('tree') + tree = self.make_branch_and_tree("tree") self.b1 = tree.branch - self.build_tree_contents([('tree/file', b'content1')]) - tree.add(['file']) - tree.commit('rev1') - self.build_tree_contents([('tree/file', b'content2')]) - tree.add_parent_tree_id(b'ghost') - tree.commit('rev2', rev_id=b'rev2') - self.get_valid_bundle(b'null:', b'rev2') + self.build_tree_contents([("tree/file", b"content1")]) + tree.add(["file"]) + tree.commit("rev1") + self.build_tree_contents([("tree/file", b"content2")]) + tree.add_parent_tree_id(b"ghost") + tree.commit("rev2", rev_id=b"rev2") + self.get_valid_bundle(b"null:", b"rev2") def make_simple_tree(self, format=None): - tree = self.make_branch_and_tree('b1', format=format) + tree = self.make_branch_and_tree("b1", format=format) self.b1 = tree.branch - self.build_tree(['b1/file']) - tree.add('file') + self.build_tree(["b1/file"]) + tree.add("file") return tree def test_across_serializers(self): - tree = self.make_simple_tree('knit') - tree.commit('hello', rev_id=b'rev1') - tree.commit('hello', rev_id=b'rev2') - bundle = read_bundle(self.create_bundle_text(b'null:', b'rev2')[0]) - repo = self.make_repository('repo', format='dirstate-with-subtree') + tree = self.make_simple_tree("knit") + tree.commit("hello", rev_id=b"rev1") + tree.commit("hello", rev_id=b"rev2") + bundle = read_bundle(self.create_bundle_text(b"null:", b"rev2")[0]) + repo = self.make_repository("repo", format="dirstate-with-subtree") bundle.install_revisions(repo) - inv_text = b''.join(repo._get_inventory_xml(b'rev2')) + inv_text = b"".join(repo._get_inventory_xml(b"rev2")) self.assertNotContainsRe(inv_text, b'format="5"') self.assertContainsRe(inv_text, b'format="7"') def make_repo_with_installed_revisions(self): - tree = self.make_simple_tree('knit') - tree.commit('hello', rev_id=b'rev1') - tree.commit('hello', rev_id=b'rev2') - bundle = read_bundle(self.create_bundle_text(b'null:', b'rev2')[0]) - repo = self.make_repository('repo', format='dirstate-with-subtree') + tree = self.make_simple_tree("knit") + tree.commit("hello", rev_id=b"rev1") + tree.commit("hello", rev_id=b"rev2") + bundle = read_bundle(self.create_bundle_text(b"null:", b"rev2")[0]) + repo = self.make_repository("repo", format="dirstate-with-subtree") bundle.install_revisions(repo) return repo def test_across_models(self): repo = self.make_repo_with_installed_revisions() - inv = repo.get_inventory(b'rev2') - self.assertEqual(b'rev2', inv.root.revision) + inv = repo.get_inventory(b"rev2") + self.assertEqual(b"rev2", inv.root.revision) root_id = inv.root.file_id repo.lock_read() self.addCleanup(repo.unlock) - self.assertEqual({(root_id, b'rev1'): (), - (root_id, b'rev2'): ((root_id, b'rev1'),)}, - 
repo.texts.get_parent_map([(root_id, b'rev1'), (root_id, b'rev2')])) + self.assertEqual( + {(root_id, b"rev1"): (), (root_id, b"rev2"): ((root_id, b"rev1"),)}, + repo.texts.get_parent_map([(root_id, b"rev1"), (root_id, b"rev2")]), + ) def test_inv_hash_across_serializers(self): repo = self.make_repo_with_installed_revisions() - recorded_inv_sha1 = repo.get_revision(b'rev2').inventory_sha1 - xml = b''.join(repo._get_inventory_xml(b'rev2')) + recorded_inv_sha1 = repo.get_revision(b"rev2").inventory_sha1 + xml = b"".join(repo._get_inventory_xml(b"rev2")) self.assertEqual(osutils.sha_string(xml), recorded_inv_sha1) def test_across_models_incompatible(self): - tree = self.make_simple_tree('dirstate-with-subtree') - tree.commit('hello', rev_id=b'rev1') - tree.commit('hello', rev_id=b'rev2') + tree = self.make_simple_tree("dirstate-with-subtree") + tree.commit("hello", rev_id=b"rev1") + tree.commit("hello", rev_id=b"rev2") try: - bundle = read_bundle(self.create_bundle_text(b'null:', b'rev1')[0]) + bundle = read_bundle(self.create_bundle_text(b"null:", b"rev1")[0]) except errors.IncompatibleBundleFormat as e: raise tests.TestSkipped("Format 0.8 doesn't work with knit3") from e - repo = self.make_repository('repo', format='knit') + repo = self.make_repository("repo", format="knit") bundle.install_revisions(repo) - bundle = read_bundle(self.create_bundle_text(b'null:', b'rev2')[0]) - self.assertRaises(errors.IncompatibleRevision, - bundle.install_revisions, repo) + bundle = read_bundle(self.create_bundle_text(b"null:", b"rev2")[0]) + self.assertRaises(errors.IncompatibleRevision, bundle.install_revisions, repo) def test_get_merge_request(self): tree = self.make_simple_tree() - tree.commit('hello', rev_id=b'rev1') - tree.commit('hello', rev_id=b'rev2') - bundle = read_bundle(self.create_bundle_text(b'null:', b'rev1')[0]) + tree.commit("hello", rev_id=b"rev1") + tree.commit("hello", rev_id=b"rev2") + bundle = read_bundle(self.create_bundle_text(b"null:", b"rev1")[0]) result = bundle.get_merge_request(tree.branch.repository) - self.assertEqual((None, b'rev1', 'inapplicable'), result) + self.assertEqual((None, b"rev1", "inapplicable"), result) def test_with_subtree(self): - tree = self.make_branch_and_tree('tree', - format='dirstate-with-subtree') + tree = self.make_branch_and_tree("tree", format="dirstate-with-subtree") self.b1 = tree.branch - self.make_branch_and_tree('tree/subtree', - format='dirstate-with-subtree') - tree.add('subtree') - tree.commit('hello', rev_id=b'rev1') + self.make_branch_and_tree("tree/subtree", format="dirstate-with-subtree") + tree.add("subtree") + tree.commit("hello", rev_id=b"rev1") try: - bundle = read_bundle(self.create_bundle_text(b'null:', b'rev1')[0]) + bundle = read_bundle(self.create_bundle_text(b"null:", b"rev1")[0]) except errors.IncompatibleBundleFormat as e: raise tests.TestSkipped("Format 0.8 doesn't work with knit3") from e if isinstance(bundle, v09.BundleInfo09): raise tests.TestSkipped("Format 0.9 doesn't work with subtrees") - repo = self.make_repository('repo', format='knit') - self.assertRaises(errors.IncompatibleRevision, - bundle.install_revisions, repo) - repo2 = self.make_repository('repo2', format='dirstate-with-subtree') + repo = self.make_repository("repo", format="knit") + self.assertRaises(errors.IncompatibleRevision, bundle.install_revisions, repo) + repo2 = self.make_repository("repo2", format="dirstate-with-subtree") bundle.install_revisions(repo2) def test_revision_id_with_slash(self): - self.tree1 = self.make_branch_and_tree('tree') + 
self.tree1 = self.make_branch_and_tree("tree") self.b1 = self.tree1.branch try: - self.tree1.commit('Revision/id/with/slashes', rev_id=b'rev/id') + self.tree1.commit("Revision/id/with/slashes", rev_id=b"rev/id") except ValueError as e: raise tests.TestSkipped( - "Repository doesn't support revision ids with slashes") from e - self.get_valid_bundle(b'null:', b'rev/id') + "Repository doesn't support revision ids with slashes" + ) from e + self.get_valid_bundle(b"null:", b"rev/id") def test_skip_file(self): """Make sure we don't accidentally write to the wrong versionedfile.""" - self.tree1 = self.make_branch_and_tree('tree') + self.tree1 = self.make_branch_and_tree("tree") self.b1 = self.tree1.branch # rev1 is not present in bundle, done by fetch - self.build_tree_contents([('tree/file2', b'contents1')]) - self.tree1.add('file2', ids=b'file2-id') - self.tree1.commit('rev1', rev_id=b'reva') - self.build_tree_contents([('tree/file3', b'contents2')]) + self.build_tree_contents([("tree/file2", b"contents1")]) + self.tree1.add("file2", ids=b"file2-id") + self.tree1.commit("rev1", rev_id=b"reva") + self.build_tree_contents([("tree/file3", b"contents2")]) # rev2 is present in bundle, and done by fetch # having file1 in the bunle causes file1's versionedfile to be opened. - self.tree1.add('file3', ids=b'file3-id') - rev2 = self.tree1.commit('rev2') + self.tree1.add("file3", ids=b"file3-id") + rev2 = self.tree1.commit("rev2") # Updating file2 should not cause an attempt to add to file1's vf - target = self.tree1.controldir.sprout('target').open_workingtree() - self.build_tree_contents([('tree/file2', b'contents3')]) - self.tree1.commit('rev3', rev_id=b'rev3') - bundle = self.get_valid_bundle(b'reva', b'rev3') - if getattr(bundle, 'get_bundle_reader', None) is None: - raise tests.TestSkipped('Bundle format cannot provide reader') + target = self.tree1.controldir.sprout("target").open_workingtree() + self.build_tree_contents([("tree/file2", b"contents3")]) + self.tree1.commit("rev3", rev_id=b"rev3") + bundle = self.get_valid_bundle(b"reva", b"rev3") + if getattr(bundle, "get_bundle_reader", None) is None: + raise tests.TestSkipped("Bundle format cannot provide reader") file_ids = { - (f, r) for b, m, k, r, f in bundle.get_bundle_reader().iter_records() - if f is not None} - self.assertEqual( - {(b'file2-id', b'rev3'), (b'file3-id', rev2)}, file_ids) + (f, r) + for b, m, k, r, f in bundle.get_bundle_reader().iter_records() + if f is not None + } + self.assertEqual({(b"file2-id", b"rev3"), (b"file3-id", rev2)}, file_ids) bundle.install_revisions(target.branch.repository) class V08BundleTester(BundleTester, tests.TestCaseWithTransport): - - format = '0.8' + format = "0.8" def test_bundle_empty_property(self): """Test serializing revision properties with an empty value.""" - tree = self.make_branch_and_memory_tree('tree') + tree = self.make_branch_and_memory_tree("tree") tree.lock_write() self.addCleanup(tree.unlock) - tree.add([''], ids=[b'TREE_ROOT']) - tree.commit('One', revprops={'one': 'two', - 'empty': ''}, rev_id=b'rev1') + tree.add([""], ids=[b"TREE_ROOT"]) + tree.commit("One", revprops={"one": "two", "empty": ""}, rev_id=b"rev1") self.b1 = tree.branch - bundle_sio, revision_ids = self.create_bundle_text(b'null:', b'rev1') - self.assertContainsRe(bundle_sio.getvalue(), - b'# properties:\n' - b'# branch-nick: tree\n' - b'# empty: \n' - b'# one: two\n' - ) + bundle_sio, revision_ids = self.create_bundle_text(b"null:", b"rev1") + self.assertContainsRe( + bundle_sio.getvalue(), + b"# properties:\n" + 
b"# branch-nick: tree\n" + b"# empty: \n" + b"# one: two\n", + ) bundle = read_bundle(bundle_sio) revision_info = bundle.revisions[0] - self.assertEqual(b'rev1', revision_info.revision_id) + self.assertEqual(b"rev1", revision_info.revision_id) rev = revision_info.as_revision() - self.assertEqual({'branch-nick': 'tree', 'empty': '', 'one': 'two'}, - rev.properties) + self.assertEqual( + {"branch-nick": "tree", "empty": "", "one": "two"}, rev.properties + ) def get_bundle_tree(self, bundle, revision_id): - repository = self.make_repository('repo') - return bundle.revision_tree(repository, b'revid1') + repository = self.make_repository("repo") + return bundle.revision_tree(repository, b"revid1") def test_bundle_empty_property_alt(self): r"""Test serializing revision properties with an empty value. @@ -1171,91 +1221,99 @@ def test_bundle_empty_property_alt(self): empty value as ':\n'. This tests make sure that all newer bzr versions can handle th second form. """ - tree = self.make_branch_and_memory_tree('tree') + tree = self.make_branch_and_memory_tree("tree") tree.lock_write() self.addCleanup(tree.unlock) - tree.add([''], ids=[b'TREE_ROOT']) - tree.commit('One', revprops={'one': 'two', - 'empty': ''}, rev_id=b'rev1') + tree.add([""], ids=[b"TREE_ROOT"]) + tree.commit("One", revprops={"one": "two", "empty": ""}, rev_id=b"rev1") self.b1 = tree.branch - bundle_sio, revision_ids = self.create_bundle_text(b'null:', b'rev1') + bundle_sio, revision_ids = self.create_bundle_text(b"null:", b"rev1") txt = bundle_sio.getvalue() - loc = txt.find(b'# empty: ') + len(b'# empty:') + loc = txt.find(b"# empty: ") + len(b"# empty:") # Create a new bundle, which strips the trailing space after empty - bundle_sio = BytesIO(txt[:loc] + txt[loc + 1:]) - - self.assertContainsRe(bundle_sio.getvalue(), - b'# properties:\n' - b'# branch-nick: tree\n' - b'# empty:\n' - b'# one: two\n' - ) + bundle_sio = BytesIO(txt[:loc] + txt[loc + 1 :]) + + self.assertContainsRe( + bundle_sio.getvalue(), + b"# properties:\n" + b"# branch-nick: tree\n" + b"# empty:\n" + b"# one: two\n", + ) bundle = read_bundle(bundle_sio) revision_info = bundle.revisions[0] - self.assertEqual(b'rev1', revision_info.revision_id) + self.assertEqual(b"rev1", revision_info.revision_id) rev = revision_info.as_revision() - self.assertEqual({'branch-nick': 'tree', 'empty': '', 'one': 'two'}, - rev.properties) + self.assertEqual( + {"branch-nick": "tree", "empty": "", "one": "two"}, rev.properties + ) def test_bundle_sorted_properties(self): """For stability the writer should write properties in sorted order.""" - tree = self.make_branch_and_memory_tree('tree') + tree = self.make_branch_and_memory_tree("tree") tree.lock_write() self.addCleanup(tree.unlock) - tree.add([''], ids=[b'TREE_ROOT']) - tree.commit('One', rev_id=b'rev1', - revprops={'a': '4', 'b': '3', 'c': '2', 'd': '1'}) + tree.add([""], ids=[b"TREE_ROOT"]) + tree.commit( + "One", rev_id=b"rev1", revprops={"a": "4", "b": "3", "c": "2", "d": "1"} + ) self.b1 = tree.branch - bundle_sio, revision_ids = self.create_bundle_text(b'null:', b'rev1') - self.assertContainsRe(bundle_sio.getvalue(), - b'# properties:\n' - b'# a: 4\n' - b'# b: 3\n' - b'# branch-nick: tree\n' - b'# c: 2\n' - b'# d: 1\n' - ) + bundle_sio, revision_ids = self.create_bundle_text(b"null:", b"rev1") + self.assertContainsRe( + bundle_sio.getvalue(), + b"# properties:\n" + b"# a: 4\n" + b"# b: 3\n" + b"# branch-nick: tree\n" + b"# c: 2\n" + b"# d: 1\n", + ) bundle = read_bundle(bundle_sio) revision_info = bundle.revisions[0] - 
self.assertEqual(b'rev1', revision_info.revision_id) + self.assertEqual(b"rev1", revision_info.revision_id) rev = revision_info.as_revision() - self.assertEqual({'branch-nick': 'tree', 'a': '4', 'b': '3', 'c': '2', - 'd': '1'}, rev.properties) + self.assertEqual( + {"branch-nick": "tree", "a": "4", "b": "3", "c": "2", "d": "1"}, + rev.properties, + ) def test_bundle_unicode_properties(self): """We should be able to round trip a non-ascii property.""" - tree = self.make_branch_and_memory_tree('tree') + tree = self.make_branch_and_memory_tree("tree") tree.lock_write() self.addCleanup(tree.unlock) - tree.add([''], ids=[b'TREE_ROOT']) + tree.add([""], ids=[b"TREE_ROOT"]) # Revisions themselves do not require anything about revision property # keys, other than that they are a basestring, and do not contain # whitespace. # However, Testaments assert than they are str(), and thus should not # be Unicode. - tree.commit('One', rev_id=b'rev1', - revprops={'omega': '\u03a9', 'alpha': '\u03b1'}) + tree.commit( + "One", rev_id=b"rev1", revprops={"omega": "\u03a9", "alpha": "\u03b1"} + ) self.b1 = tree.branch - bundle_sio, revision_ids = self.create_bundle_text(b'null:', b'rev1') - self.assertContainsRe(bundle_sio.getvalue(), - b'# properties:\n' - b'# alpha: \xce\xb1\n' - b'# branch-nick: tree\n' - b'# omega: \xce\xa9\n' - ) + bundle_sio, revision_ids = self.create_bundle_text(b"null:", b"rev1") + self.assertContainsRe( + bundle_sio.getvalue(), + b"# properties:\n" + b"# alpha: \xce\xb1\n" + b"# branch-nick: tree\n" + b"# omega: \xce\xa9\n", + ) bundle = read_bundle(bundle_sio) revision_info = bundle.revisions[0] - self.assertEqual(b'rev1', revision_info.revision_id) + self.assertEqual(b"rev1", revision_info.revision_id) rev = revision_info.as_revision() - self.assertEqual({'branch-nick': 'tree', 'omega': '\u03a9', - 'alpha': '\u03b1'}, rev.properties) + self.assertEqual( + {"branch-nick": "tree", "omega": "\u03a9", "alpha": "\u03b1"}, + rev.properties, + ) class V09BundleKnit2Tester(V08BundleTester): - - format = '0.9' + format = "0.9" def bzrdir_format(self): format = bzrdir.BzrDirMetaFormat1() @@ -1264,8 +1322,7 @@ def bzrdir_format(self): class V09BundleKnit1Tester(V08BundleTester): - - format = '0.9' + format = "0.9" def bzrdir_format(self): format = bzrdir.BzrDirMetaFormat1() @@ -1274,8 +1331,7 @@ def bzrdir_format(self): class V4BundleTester(BundleTester, tests.TestCaseWithTransport): - - format = '4' + format = "4" def get_valid_bundle(self, base_rev_id, rev_id, checkout_dir=None): """Create a bundle from base_rev_id -> rev_id in built-in branch. 
@@ -1295,17 +1351,21 @@ def get_valid_bundle(self, base_rev_id, rev_id, checkout_dir=None): # only will match if everything is okay, but lets be explicit about # it branch_rev = repository.get_revision(bundle_rev.revision_id) - for a in ('inventory_sha1', 'revision_id', 'parent_ids', - 'timestamp', 'timezone', 'message', 'committer', - 'parent_ids', 'properties'): - self.assertEqual(getattr(branch_rev, a), - getattr(bundle_rev, a)) - self.assertEqual(len(branch_rev.parent_ids), - len(bundle_rev.parent_ids)) - self.assertEqual(set(rev_ids), - {r.revision_id for r in bundle.real_revisions}) - self.valid_apply_bundle(base_rev_id, bundle, - checkout_dir=checkout_dir) + for a in ( + "inventory_sha1", + "revision_id", + "parent_ids", + "timestamp", + "timezone", + "message", + "committer", + "parent_ids", + "properties", + ): + self.assertEqual(getattr(branch_rev, a), getattr(bundle_rev, a)) + self.assertEqual(len(branch_rev.parent_ids), len(bundle_rev.parent_ids)) + self.assertEqual(set(rev_ids), {r.revision_id for r in bundle.real_revisions}) + self.valid_apply_bundle(base_rev_id, bundle, checkout_dir=checkout_dir) return bundle @@ -1316,14 +1376,16 @@ def get_invalid_bundle(self, base_rev_id, rev_id): :return: The in-memory bundle """ from ..bundle import serializer + bundle_txt, rev_ids = self.create_bundle_text(base_rev_id, rev_id) - new_text = self.get_raw(BytesIO(b''.join(bundle_txt))) - new_text = new_text.replace(b' rev_id in built-in branch. @@ -1443,17 +1508,19 @@ def get_invalid_bundle(self, base_rev_id, rev_id): :return: The in-memory bundle """ from ..bundle import serializer + bundle_txt, rev_ids = self.create_bundle_text(base_rev_id, rev_id) - new_text = self.get_raw(BytesIO(b''.join(bundle_txt))) + new_text = self.get_raw(BytesIO(b"".join(bundle_txt))) # We are going to be replacing some text to set the executable bit on a # file. Make sure the text replacement actually works correctly. - self.assertContainsRe(new_text, b'(?m)B244\n\ni 1\n\n' - b'\n' - b'c 0 1 1 2\n' - b'c 1 3 3 2\n', bytes) + b"\n" + b"c 0 1 1 2\n" + b"c 1 3 3 2\n", + bytes, + ) def test_single_inv_no_parents_as_xml(self): self.make_merged_branch() - sio = self.make_bundle_just_inventories(b'null:', b'a@cset-0-1', - [b'a@cset-0-1']) + sio = self.make_bundle_just_inventories( + b"null:", b"a@cset-0-1", [b"a@cset-0-1"] + ) reader = v4.BundleReader(sio, stream_input=False) records = list(reader.iter_records()) self.assertEqual(1, len(records)) - (bytes, metadata, repo_kind, revision_id, - file_id) = records[0] + (bytes, metadata, repo_kind, revision_id, file_id) = records[0] self.assertIs(None, file_id) - self.assertEqual(b'a@cset-0-1', revision_id) - self.assertEqual('inventory', repo_kind) - self.assertEqual({b'parents': [], - b'sha1': b'a13f42b142d544aac9b085c42595d304150e31a2', - b'storage_kind': b'mpdiff', - }, metadata) + self.assertEqual(b"a@cset-0-1", revision_id) + self.assertEqual("inventory", repo_kind) + self.assertEqual( + { + b"parents": [], + b"sha1": b"a13f42b142d544aac9b085c42595d304150e31a2", + b"storage_kind": b"mpdiff", + }, + metadata, + ) # We should have an mpdiff that takes some lines from both parents. 
self.assertEqualDiff( - b'i 4\n' + b"i 4\n" b'\n' b'\n' @@ -1543,112 +1637,138 @@ def test_single_inv_no_parents_as_xml(self): b' revision="a@cset-0-1"' b' text_sha1="09c2f8647e14e49e922b955c194102070597c2d1"' b' text_size="17" />\n' - b'\n' - b'\n', bytes) + b"\n" + b"\n", + bytes, + ) def test_multiple_inventories_as_xml(self): self.make_merged_branch() - sio = self.make_bundle_just_inventories(b'a@cset-0-1', b'a@cset-0-3', - [b'a@cset-0-2a', b'a@cset-0-2b', b'a@cset-0-3']) + sio = self.make_bundle_just_inventories( + b"a@cset-0-1", + b"a@cset-0-3", + [b"a@cset-0-2a", b"a@cset-0-2b", b"a@cset-0-3"], + ) reader = v4.BundleReader(sio, stream_input=False) records = list(reader.iter_records()) self.assertEqual(3, len(records)) revision_ids = [rev_id for b, m, k, rev_id, f in records] - self.assertEqual([b'a@cset-0-2a', b'a@cset-0-2b', b'a@cset-0-3'], - revision_ids) + self.assertEqual([b"a@cset-0-2a", b"a@cset-0-2b", b"a@cset-0-3"], revision_ids) metadata_2a = records[0][1] - self.assertEqual({b'parents': [b'a@cset-0-1'], - b'sha1': b'1e105886d62d510763e22885eec733b66f5f09bf', - b'storage_kind': b'mpdiff', - }, metadata_2a) + self.assertEqual( + { + b"parents": [b"a@cset-0-1"], + b"sha1": b"1e105886d62d510763e22885eec733b66f5f09bf", + b"storage_kind": b"mpdiff", + }, + metadata_2a, + ) metadata_2b = records[1][1] - self.assertEqual({b'parents': [b'a@cset-0-1'], - b'sha1': b'f03f12574bdb5ed2204c28636c98a8547544ccd8', - b'storage_kind': b'mpdiff', - }, metadata_2b) + self.assertEqual( + { + b"parents": [b"a@cset-0-1"], + b"sha1": b"f03f12574bdb5ed2204c28636c98a8547544ccd8", + b"storage_kind": b"mpdiff", + }, + metadata_2b, + ) metadata_3 = records[2][1] - self.assertEqual({b'parents': [b'a@cset-0-2a', b'a@cset-0-2b'], - b'sha1': b'09c53b0c4de0895e11a2aacc34fef60a6e70865c', - b'storage_kind': b'mpdiff', - }, metadata_3) + self.assertEqual( + { + b"parents": [b"a@cset-0-2a", b"a@cset-0-2b"], + b"sha1": b"09c53b0c4de0895e11a2aacc34fef60a6e70865c", + b"storage_kind": b"mpdiff", + }, + metadata_3, + ) bytes_2a = records[0][0] self.assertEqualDiff( - b'i 1\n' + b"i 1\n" b'\n' - b'\n' - b'c 0 1 1 1\n' - b'i 1\n' + b"\n" + b"c 0 1 1 1\n" + b"i 1\n" b'\n' - b'\n' - b'c 0 3 3 1\n', bytes_2a) + b"\n" + b"c 0 3 3 1\n", + bytes_2a, + ) bytes_2b = records[1][0] self.assertEqualDiff( - b'i 1\n' + b"i 1\n" b'\n' - b'\n' - b'c 0 1 1 2\n' - b'i 1\n' + b"\n" + b"c 0 1 1 2\n" + b"i 1\n" b'\n' - b'\n' - b'c 0 3 4 1\n', bytes_2b) + b"\n" + b"c 0 3 4 1\n", + bytes_2b, + ) bytes_3 = records[2][0] self.assertEqualDiff( - b'i 1\n' + b"i 1\n" b'\n' - b'\n' - b'c 0 1 1 2\n' - b'c 1 3 3 2\n', bytes_3) + b"\n" + b"c 0 1 1 2\n" + b"c 1 3 3 2\n", + bytes_3, + ) def test_creating_bundle_preserves_chk_pages(self): self.make_merged_branch() - target = self.b1.controldir.sprout('target', - revision_id=b'a@cset-0-2a').open_branch() - bundle_txt, rev_ids = self.create_bundle_text(b'a@cset-0-2a', - b'a@cset-0-3') - self.assertEqual({b'a@cset-0-2b', b'a@cset-0-3'}, set(rev_ids)) + target = self.b1.controldir.sprout( + "target", revision_id=b"a@cset-0-2a" + ).open_branch() + bundle_txt, rev_ids = self.create_bundle_text(b"a@cset-0-2a", b"a@cset-0-3") + self.assertEqual({b"a@cset-0-2b", b"a@cset-0-3"}, set(rev_ids)) bundle = read_bundle(bundle_txt) target.lock_write() self.addCleanup(target.unlock) install_bundle(target.repository, bundle) - inv1 = next(self.b1.repository.inventories.get_record_stream([ - (b'a@cset-0-3',)], 'unordered', - True)).get_bytes_as('fulltext') - inv2 = next(target.repository.inventories.get_record_stream([ - 
(b'a@cset-0-3',)], 'unordered', - True)).get_bytes_as('fulltext') + inv1 = next( + self.b1.repository.inventories.get_record_stream( + [(b"a@cset-0-3",)], "unordered", True + ) + ).get_bytes_as("fulltext") + inv2 = next( + target.repository.inventories.get_record_stream( + [(b"a@cset-0-3",)], "unordered", True + ) + ).get_bytes_as("fulltext") self.assertEqualDiff(inv1, inv2) class MungedBundleTester: - def build_test_bundle(self): - wt = self.make_branch_and_tree('b1') + wt = self.make_branch_and_tree("b1") - self.build_tree(['b1/one']) - wt.add('one') - wt.commit('add one', rev_id=b'a@cset-0-1') - self.build_tree(['b1/two']) - wt.add('two') - wt.commit('add two', rev_id=b'a@cset-0-2', - revprops={'branch-nick': 'test'}) + self.build_tree(["b1/one"]) + wt.add("one") + wt.commit("add one", rev_id=b"a@cset-0-1") + self.build_tree(["b1/two"]) + wt.add("two") + wt.commit("add two", rev_id=b"a@cset-0-2", revprops={"branch-nick": "test"}) bundle_txt = BytesIO() - rev_ids = write_bundle(wt.branch.repository, b'a@cset-0-2', - b'a@cset-0-1', bundle_txt, self.format) - self.assertEqual({b'a@cset-0-2'}, set(rev_ids)) + rev_ids = write_bundle( + wt.branch.repository, b"a@cset-0-2", b"a@cset-0-1", bundle_txt, self.format + ) + self.assertEqual({b"a@cset-0-2"}, set(rev_ids)) bundle_txt.seek(0, 0) return bundle_txt def check_valid(self, bundle): """Check that after whatever munging, the final object is valid.""" - self.assertEqual([b'a@cset-0-2'], - [r.revision_id for r in bundle.real_revisions]) + self.assertEqual( + [b"a@cset-0-2"], [r.revision_id for r in bundle.real_revisions] + ) def test_extra_whitespace(self): bundle_txt = self.build_test_bundle() @@ -1657,7 +1777,7 @@ def test_extra_whitespace(self): # Adding one extra newline used to give us # TypeError: float() argument must be a string or a number bundle_txt.seek(0, 2) - bundle_txt.write(b'\n') + bundle_txt.write(b"\n") bundle_txt.seek(0) bundle = read_bundle(bundle_txt) @@ -1670,7 +1790,7 @@ def test_extra_whitespace_2(self): # Adding two extra newlines used to give us # MalformedPatches: The first line of all patches should be ... 
bundle_txt.seek(0, 2) - bundle_txt.write(b'\n\n') + bundle_txt.write(b"\n\n") bundle_txt.seek(0) bundle = read_bundle(bundle_txt) @@ -1678,8 +1798,7 @@ def test_extra_whitespace_2(self): class MungedBundleTesterV09(tests.TestCaseWithTransport, MungedBundleTester): - - format = '0.9' + format = "0.9" def test_missing_trailing_whitespace(self): bundle_txt = self.build_test_bundle() @@ -1690,7 +1809,7 @@ def test_missing_trailing_whitespace(self): # test is concerned with the exact case where the serializer # creates a blank line at the end, and fails if that # line is stripped - self.assertEqual(b'\n\n', raw[-2:]) + self.assertEqual(b"\n\n", raw[-2:]) bundle_txt = BytesIO(raw[:-1]) bundle = read_bundle(bundle_txt) @@ -1699,8 +1818,7 @@ def test_missing_trailing_whitespace(self): def test_opening_text(self): bundle_txt = self.build_test_bundle() - bundle_txt = BytesIO( - b"Some random\nemail comments\n" + bundle_txt.getvalue()) + bundle_txt = BytesIO(b"Some random\nemail comments\n" + bundle_txt.getvalue()) bundle = read_bundle(bundle_txt) self.check_valid(bundle) @@ -1708,84 +1826,117 @@ def test_opening_text(self): def test_trailing_text(self): bundle_txt = self.build_test_bundle() - bundle_txt = BytesIO( - bundle_txt.getvalue() + b"Some trailing\nrandom\ntext\n") + bundle_txt = BytesIO(bundle_txt.getvalue() + b"Some trailing\nrandom\ntext\n") bundle = read_bundle(bundle_txt) self.check_valid(bundle) class MungedBundleTesterV4(tests.TestCaseWithTransport, MungedBundleTester): - - format = '4' + format = "4" class TestBundleWriterReader(tests.TestCase): - def test_roundtrip_record(self): fileobj = BytesIO() writer = v4.BundleWriter(fileobj) writer.begin() - writer.add_info_record({b'foo': b'bar'}) - writer._add_record(b"Record body", {b'parents': [b'1', b'3'], - b'storage_kind': b'fulltext'}, 'file', b'revid', b'fileid') + writer.add_info_record({b"foo": b"bar"}) + writer._add_record( + b"Record body", + {b"parents": [b"1", b"3"], b"storage_kind": b"fulltext"}, + "file", + b"revid", + b"fileid", + ) writer.end() fileobj.seek(0) reader = v4.BundleReader(fileobj, stream_input=True) record_iter = reader.iter_records() record = next(record_iter) - self.assertEqual((None, {b'foo': b'bar', b'storage_kind': b'header'}, - 'info', None, None), record) + self.assertEqual( + (None, {b"foo": b"bar", b"storage_kind": b"header"}, "info", None, None), + record, + ) record = next(record_iter) - self.assertEqual((b"Record body", {b'storage_kind': b'fulltext', - b'parents': [b'1', b'3']}, 'file', b'revid', b'fileid'), - record) + self.assertEqual( + ( + b"Record body", + {b"storage_kind": b"fulltext", b"parents": [b"1", b"3"]}, + "file", + b"revid", + b"fileid", + ), + record, + ) def test_roundtrip_record_memory_hungry(self): fileobj = BytesIO() writer = v4.BundleWriter(fileobj) writer.begin() - writer.add_info_record({b'foo': b'bar'}) - writer._add_record(b"Record body", {b'parents': [b'1', b'3'], - b'storage_kind': b'fulltext'}, 'file', b'revid', b'fileid') + writer.add_info_record({b"foo": b"bar"}) + writer._add_record( + b"Record body", + {b"parents": [b"1", b"3"], b"storage_kind": b"fulltext"}, + "file", + b"revid", + b"fileid", + ) writer.end() fileobj.seek(0) reader = v4.BundleReader(fileobj, stream_input=False) record_iter = reader.iter_records() record = next(record_iter) - self.assertEqual((None, {b'foo': b'bar', b'storage_kind': b'header'}, - 'info', None, None), record) + self.assertEqual( + (None, {b"foo": b"bar", b"storage_kind": b"header"}, "info", None, None), + record, + ) record = 
next(record_iter) - self.assertEqual((b"Record body", {b'storage_kind': b'fulltext', - b'parents': [b'1', b'3']}, 'file', b'revid', b'fileid'), - record) + self.assertEqual( + ( + b"Record body", + {b"storage_kind": b"fulltext", b"parents": [b"1", b"3"]}, + "file", + b"revid", + b"fileid", + ), + record, + ) def test_encode_name(self): - self.assertEqual(b'revision/rev1', - v4.BundleWriter.encode_name('revision', b'rev1')) - self.assertEqual(b'file/rev//1/file-id-1', - v4.BundleWriter.encode_name('file', b'rev/1', b'file-id-1')) - self.assertEqual(b'info', - v4.BundleWriter.encode_name('info', None, None)) + self.assertEqual( + b"revision/rev1", v4.BundleWriter.encode_name("revision", b"rev1") + ) + self.assertEqual( + b"file/rev//1/file-id-1", + v4.BundleWriter.encode_name("file", b"rev/1", b"file-id-1"), + ) + self.assertEqual(b"info", v4.BundleWriter.encode_name("info", None, None)) def test_decode_name(self): - self.assertEqual(('revision', b'rev1', None), - v4.BundleReader.decode_name(b'revision/rev1')) - self.assertEqual(('file', b'rev/1', b'file-id-1'), - v4.BundleReader.decode_name(b'file/rev//1/file-id-1')) - self.assertEqual(('info', None, None), - v4.BundleReader.decode_name(b'info')) + self.assertEqual( + ("revision", b"rev1", None), v4.BundleReader.decode_name(b"revision/rev1") + ) + self.assertEqual( + ("file", b"rev/1", b"file-id-1"), + v4.BundleReader.decode_name(b"file/rev//1/file-id-1"), + ) + self.assertEqual(("info", None, None), v4.BundleReader.decode_name(b"info")) def test_too_many_names(self): fileobj = BytesIO() writer = v4.BundleWriter(fileobj) writer.begin() - writer.add_info_record({b'foo': b'bar'}) - writer._container.add_bytes_record([b'blah'], len(b'blah'), [(b'two', ), (b'names', )]) + writer.add_info_record({b"foo": b"bar"}) + writer._container.add_bytes_record( + [b"blah"], len(b"blah"), [(b"two",), (b"names",)] + ) writer.end() fileobj.seek(0) record_iter = v4.BundleReader(fileobj).iter_records() record = next(record_iter) - self.assertEqual((None, {b'foo': b'bar', b'storage_kind': b'header'}, - 'info', None, None), record) + self.assertEqual( + (None, {b"foo": b"bar", b"storage_kind": b"header"}, "info", None, None), + record, + ) self.assertRaises(errors.BadBundle, next, record_iter) diff --git a/breezy/bzr/tests/test_bzrdir.py b/breezy/bzr/tests/test_bzrdir.py index 7b92e58a5a..e251219dc6 100644 --- a/breezy/bzr/tests/test_bzrdir.py +++ b/breezy/bzr/tests/test_bzrdir.py @@ -57,21 +57,20 @@ class TestDefaultFormat(TestCase): - def test_get_set_default_format(self): old_format = bzrdir.BzrDirFormat.get_default_format() # default is BzrDirMetaFormat1 self.assertIsInstance(old_format, bzrdir.BzrDirMetaFormat1) - current_default = controldir.format_registry.aliases()['bzr'] - controldir.format_registry.register('sample', SampleBzrDirFormat, help='Sample') - self.addCleanup(controldir.format_registry.remove, 'sample') - controldir.format_registry.register_alias('bzr', 'sample') + current_default = controldir.format_registry.aliases()["bzr"] + controldir.format_registry.register("sample", SampleBzrDirFormat, help="Sample") + self.addCleanup(controldir.format_registry.remove, "sample") + controldir.format_registry.register_alias("bzr", "sample") # creating a bzr dir should now create an instrumented dir. 
try: - result = bzrdir.BzrDir.create('memory:///') + result = bzrdir.BzrDir.create("memory:///") self.assertIsInstance(result, SampleBzrDir) finally: - controldir.format_registry.register_alias('bzr', current_default) + controldir.format_registry.register_alias("bzr", current_default) self.assertEqual(old_format, bzrdir.BzrDirFormat.get_default_format()) @@ -80,107 +79,141 @@ class DeprecatedBzrDirFormat(bzrdir.BzrDirFormat): class TestFormatRegistry(TestCase): - def make_format_registry(self): my_format_registry = controldir.ControlDirFormatRegistry() - my_format_registry.register('deprecated', DeprecatedBzrDirFormat, - 'Some format. Slower and unawesome and deprecated.', - deprecated=True) - my_format_registry.register_lazy('lazy', __name__, - 'DeprecatedBzrDirFormat', 'Format registered lazily', - deprecated=True) - bzr.register_metadir(my_format_registry, 'knit', - 'breezy.bzr.knitrepo.RepositoryFormatKnit1', - 'Format using knits', - ) - my_format_registry.set_default('knit') - bzr.register_metadir(my_format_registry, - 'branch6', - 'breezy.bzr.knitrepo.RepositoryFormatKnit3', - 'Experimental successor to knit. Use at your own risk.', - branch_format='breezy.bzr.branch.BzrBranchFormat6', - experimental=True) - bzr.register_metadir(my_format_registry, - 'hidden format', - 'breezy.bzr.knitrepo.RepositoryFormatKnit3', - 'Experimental successor to knit. Use at your own risk.', - branch_format='breezy.bzr.branch.BzrBranchFormat6', hidden=True) - my_format_registry.register('hiddendeprecated', DeprecatedBzrDirFormat, - 'Old format. Slower and does not support things. ', hidden=True) - my_format_registry.register_lazy('hiddenlazy', __name__, - 'DeprecatedBzrDirFormat', 'Format registered lazily', - deprecated=True, hidden=True) + my_format_registry.register( + "deprecated", + DeprecatedBzrDirFormat, + "Some format. Slower and unawesome and deprecated.", + deprecated=True, + ) + my_format_registry.register_lazy( + "lazy", + __name__, + "DeprecatedBzrDirFormat", + "Format registered lazily", + deprecated=True, + ) + bzr.register_metadir( + my_format_registry, + "knit", + "breezy.bzr.knitrepo.RepositoryFormatKnit1", + "Format using knits", + ) + my_format_registry.set_default("knit") + bzr.register_metadir( + my_format_registry, + "branch6", + "breezy.bzr.knitrepo.RepositoryFormatKnit3", + "Experimental successor to knit. Use at your own risk.", + branch_format="breezy.bzr.branch.BzrBranchFormat6", + experimental=True, + ) + bzr.register_metadir( + my_format_registry, + "hidden format", + "breezy.bzr.knitrepo.RepositoryFormatKnit3", + "Experimental successor to knit. Use at your own risk.", + branch_format="breezy.bzr.branch.BzrBranchFormat6", + hidden=True, + ) + my_format_registry.register( + "hiddendeprecated", + DeprecatedBzrDirFormat, + "Old format. Slower and does not support things. 
", + hidden=True, + ) + my_format_registry.register_lazy( + "hiddenlazy", + __name__, + "DeprecatedBzrDirFormat", + "Format registered lazily", + deprecated=True, + hidden=True, + ) return my_format_registry def test_format_registry(self): my_format_registry = self.make_format_registry() - my_bzrdir = my_format_registry.make_controldir('lazy') + my_bzrdir = my_format_registry.make_controldir("lazy") self.assertIsInstance(my_bzrdir, DeprecatedBzrDirFormat) - my_bzrdir = my_format_registry.make_controldir('deprecated') + my_bzrdir = my_format_registry.make_controldir("deprecated") self.assertIsInstance(my_bzrdir, DeprecatedBzrDirFormat) - my_bzrdir = my_format_registry.make_controldir('default') - self.assertIsInstance(my_bzrdir.repository_format, - knitrepo.RepositoryFormatKnit1) - my_bzrdir = my_format_registry.make_controldir('knit') - self.assertIsInstance(my_bzrdir.repository_format, - knitrepo.RepositoryFormatKnit1) - my_bzrdir = my_format_registry.make_controldir('branch6') - self.assertIsInstance(my_bzrdir.get_branch_format(), - breezy.bzr.branch.BzrBranchFormat6) + my_bzrdir = my_format_registry.make_controldir("default") + self.assertIsInstance( + my_bzrdir.repository_format, knitrepo.RepositoryFormatKnit1 + ) + my_bzrdir = my_format_registry.make_controldir("knit") + self.assertIsInstance( + my_bzrdir.repository_format, knitrepo.RepositoryFormatKnit1 + ) + my_bzrdir = my_format_registry.make_controldir("branch6") + self.assertIsInstance( + my_bzrdir.get_branch_format(), breezy.bzr.branch.BzrBranchFormat6 + ) def test_get_help(self): my_format_registry = self.make_format_registry() - self.assertEqual('Format registered lazily', - my_format_registry.get_help('lazy')) - self.assertEqual('Format using knits', - my_format_registry.get_help('knit')) - self.assertEqual('Format using knits', - my_format_registry.get_help('default')) - self.assertEqual('Some format. Slower and unawesome and deprecated.', - my_format_registry.get_help('deprecated')) + self.assertEqual( + "Format registered lazily", my_format_registry.get_help("lazy") + ) + self.assertEqual("Format using knits", my_format_registry.get_help("knit")) + self.assertEqual("Format using knits", my_format_registry.get_help("default")) + self.assertEqual( + "Some format. 
Slower and unawesome and deprecated.", + my_format_registry.get_help("deprecated"), + ) def test_help_topic(self): topics = help_topics.HelpTopicRegistry() registry = self.make_format_registry() - topics.register('current-formats', registry.help_topic, - 'Current formats') - topics.register('other-formats', registry.help_topic, - 'Other formats') - new = topics.get_detail('current-formats') - rest = topics.get_detail('other-formats') - experimental, deprecated = rest.split('Deprecated formats') - self.assertContainsRe(new, 'formats-help') - self.assertContainsRe(new, - ':knit:\n \\(native\\) \\(default\\) Format using knits\n') - self.assertContainsRe(experimental, - ':branch6:\n \\(native\\) Experimental successor to knit') - self.assertContainsRe(deprecated, - ':lazy:\n \\(native\\) Format registered lazily\n') - self.assertNotContainsRe(new, 'hidden') + topics.register("current-formats", registry.help_topic, "Current formats") + topics.register("other-formats", registry.help_topic, "Other formats") + new = topics.get_detail("current-formats") + rest = topics.get_detail("other-formats") + experimental, deprecated = rest.split("Deprecated formats") + self.assertContainsRe(new, "formats-help") + self.assertContainsRe( + new, ":knit:\n \\(native\\) \\(default\\) Format using knits\n" + ) + self.assertContainsRe( + experimental, ":branch6:\n \\(native\\) Experimental successor to knit" + ) + self.assertContainsRe( + deprecated, ":lazy:\n \\(native\\) Format registered lazily\n" + ) + self.assertNotContainsRe(new, "hidden") def test_set_default_repository(self): - default_factory = controldir.format_registry.get('default') - old_default = [k for k, v in controldir.format_registry.iteritems() - if v == default_factory and k != 'default'][0] - controldir.format_registry.set_default_repository( - 'dirstate-with-subtree') + default_factory = controldir.format_registry.get("default") + old_default = [ + k + for k, v in controldir.format_registry.iteritems() + if v == default_factory and k != "default" + ][0] + controldir.format_registry.set_default_repository("dirstate-with-subtree") try: - self.assertIs(controldir.format_registry.get('dirstate-with-subtree'), - controldir.format_registry.get('default')) + self.assertIs( + controldir.format_registry.get("dirstate-with-subtree"), + controldir.format_registry.get("default"), + ) self.assertIs( repository.format_registry.get_default().__class__, - knitrepo.RepositoryFormatKnit3) + knitrepo.RepositoryFormatKnit3, + ) finally: controldir.format_registry.set_default_repository(old_default) def test_aliases(self): a_registry = controldir.ControlDirFormatRegistry() - a_registry.register('deprecated', DeprecatedBzrDirFormat, - 'Old format. Slower and does not support stuff', - deprecated=True) - a_registry.register_alias('deprecatedalias', 'deprecated') - self.assertEqual({'deprecatedalias': 'deprecated'}, - a_registry.aliases()) + a_registry.register( + "deprecated", + DeprecatedBzrDirFormat, + "Old format. 
Slower and does not support stuff", + deprecated=True, + ) + a_registry.register_alias("deprecatedalias", "deprecated") + self.assertEqual({"deprecatedalias": "deprecated"}, a_registry.aliases()) class SampleBranch(breezy.branch.Branch): @@ -232,8 +265,8 @@ def get_format_string(self): def initialize_on_transport(self, t): """Create a bzr dir.""" - t.mkdir('.bzr') - t.put_bytes('.bzr/branch-format', self.get_format_string()) + t.mkdir(".bzr") + t.put_bytes(".bzr/branch-format", self.get_format_string()) return SampleBzrDir(t, self) def is_supported(self): @@ -248,14 +281,12 @@ def from_string(cls, format_string): class BzrDirFormatTest1(bzrdir.BzrDirMetaFormat1): - @staticmethod def get_format_string(): return b"Test format 1" class BzrDirFormatTest2(bzrdir.BzrDirMetaFormat1): - @staticmethod def get_format_string(): return b"Test format 2" @@ -267,14 +298,18 @@ class TestBzrDirFormat(TestCaseWithTransport): def test_find_format(self): # is the right format object found for a branch? # create a branch with a few known format objects. - bzr.BzrProber.formats.register(BzrDirFormatTest1.get_format_string(), - BzrDirFormatTest1()) - self.addCleanup(bzr.BzrProber.formats.remove, - BzrDirFormatTest1.get_format_string()) - bzr.BzrProber.formats.register(BzrDirFormatTest2.get_format_string(), - BzrDirFormatTest2()) - self.addCleanup(bzr.BzrProber.formats.remove, - BzrDirFormatTest2.get_format_string()) + bzr.BzrProber.formats.register( + BzrDirFormatTest1.get_format_string(), BzrDirFormatTest1() + ) + self.addCleanup( + bzr.BzrProber.formats.remove, BzrDirFormatTest1.get_format_string() + ) + bzr.BzrProber.formats.register( + BzrDirFormatTest2.get_format_string(), BzrDirFormatTest2() + ) + self.addCleanup( + bzr.BzrProber.formats.remove, BzrDirFormatTest2.get_format_string() + ) t = self.get_transport() self.build_tree(["foo/", "bar/"], transport=t) @@ -283,39 +318,49 @@ def check_format(format, url): t = _mod_transport.get_transport_from_path(url) found_format = bzrdir.BzrDirFormat.find_format(t) self.assertIsInstance(found_format, format.__class__) + check_format(BzrDirFormatTest1(), "foo") check_format(BzrDirFormatTest2(), "bar") def test_find_format_nothing_there(self): - self.assertRaises(NotBranchError, - bzrdir.BzrDirFormat.find_format, - _mod_transport.get_transport_from_path('.')) + self.assertRaises( + NotBranchError, + bzrdir.BzrDirFormat.find_format, + _mod_transport.get_transport_from_path("."), + ) def test_find_format_unknown_format(self): t = self.get_transport() - t.mkdir('.bzr') - t.put_bytes('.bzr/branch-format', b'') - self.assertRaises(UnknownFormatError, - bzrdir.BzrDirFormat.find_format, - _mod_transport.get_transport_from_path('.')) + t.mkdir(".bzr") + t.put_bytes(".bzr/branch-format", b"") + self.assertRaises( + UnknownFormatError, + bzrdir.BzrDirFormat.find_format, + _mod_transport.get_transport_from_path("."), + ) def test_find_format_line_endings(self): t = self.get_transport() - t.mkdir('.bzr') - t.put_bytes('.bzr/branch-format', b'Corrupt line endings\r\n') - self.assertRaises(bzr.LineEndingError, - bzrdir.BzrDirFormat.find_format, - _mod_transport.get_transport_from_path('.')) + t.mkdir(".bzr") + t.put_bytes(".bzr/branch-format", b"Corrupt line endings\r\n") + self.assertRaises( + bzr.LineEndingError, + bzrdir.BzrDirFormat.find_format, + _mod_transport.get_transport_from_path("."), + ) def test_find_format_html(self): t = self.get_transport() - t.mkdir('.bzr') + t.mkdir(".bzr") t.put_bytes( - '.bzr/branch-format', - b'') + ".bzr/branch-format", + b'', + ) 
self.assertRaises( - NotBranchError, bzrdir.BzrDirFormat.find_format, - _mod_transport.get_transport_from_path('.')) + NotBranchError, + bzrdir.BzrDirFormat.find_format, + _mod_transport.get_transport_from_path("."), + ) def test_register_unregister_format(self): format = SampleBzrDirFormat() @@ -327,80 +372,82 @@ def test_register_unregister_format(self): # which bzrdir.Open will refuse (not supported) self.assertRaises(UnsupportedFormatError, bzrdir.BzrDir.open, url) # which bzrdir.open_containing will refuse (not supported) - self.assertRaises(UnsupportedFormatError, - bzrdir.BzrDir.open_containing, url) + self.assertRaises(UnsupportedFormatError, bzrdir.BzrDir.open_containing, url) # but open_downlevel will work t = _mod_transport.get_transport_from_url(url) self.assertEqual(format.open(t), bzrdir.BzrDir.open_unsupported(url)) # unregister the format bzr.BzrProber.formats.remove(format.get_format_string()) # now open_downlevel should fail too. - self.assertRaises(UnknownFormatError, - bzrdir.BzrDir.open_unsupported, url) + self.assertRaises(UnknownFormatError, bzrdir.BzrDir.open_unsupported, url) def test_create_branch_and_repo_uses_default(self): format = SampleBzrDirFormat() - branch = bzrdir.BzrDir.create_branch_and_repo(self.get_url(), - format=format) + branch = bzrdir.BzrDir.create_branch_and_repo(self.get_url(), format=format) self.assertIsInstance(branch, SampleBranch) def test_create_branch_and_repo_under_shared(self): # creating a branch and repo in a shared repo uses the # shared repository - format = controldir.format_registry.make_controldir('knit') - self.make_repository('.', shared=True, format=format) + format = controldir.format_registry.make_controldir("knit") + self.make_repository(".", shared=True, format=format) branch = bzrdir.BzrDir.create_branch_and_repo( - self.get_url('child'), format=format) - self.assertRaises(errors.NoRepositoryPresent, - branch.controldir.open_repository) + self.get_url("child"), format=format + ) + self.assertRaises(errors.NoRepositoryPresent, branch.controldir.open_repository) def test_create_branch_and_repo_under_shared_force_new(self): # creating a branch and repo in a shared repo can be forced to # make a new repo - format = controldir.format_registry.make_controldir('knit') - self.make_repository('.', shared=True, format=format) - branch = bzrdir.BzrDir.create_branch_and_repo(self.get_url('child'), - force_new_repo=True, - format=format) + format = controldir.format_registry.make_controldir("knit") + self.make_repository(".", shared=True, format=format) + branch = bzrdir.BzrDir.create_branch_and_repo( + self.get_url("child"), force_new_repo=True, format=format + ) branch.controldir.open_repository() def test_create_standalone_working_tree(self): format = SampleBzrDirFormat() # note this is deliberately readonly, as this failure should # occur before any writes. - self.assertRaises(errors.NotLocalUrl, - bzrdir.BzrDir.create_standalone_workingtree, - self.get_readonly_url(), format=format) - tree = bzrdir.BzrDir.create_standalone_workingtree('.', - format=format) - self.assertEqual('A tree', tree) + self.assertRaises( + errors.NotLocalUrl, + bzrdir.BzrDir.create_standalone_workingtree, + self.get_readonly_url(), + format=format, + ) + tree = bzrdir.BzrDir.create_standalone_workingtree(".", format=format) + self.assertEqual("A tree", tree) def test_create_standalone_working_tree_under_shared_repo(self): # create standalone working tree always makes a repo. 
- format = controldir.format_registry.make_controldir('knit') - self.make_repository('.', shared=True, format=format) + format = controldir.format_registry.make_controldir("knit") + self.make_repository(".", shared=True, format=format) # note this is deliberately readonly, as this failure should # occur before any writes. - self.assertRaises(errors.NotLocalUrl, - bzrdir.BzrDir.create_standalone_workingtree, - self.get_readonly_url('child'), format=format) - tree = bzrdir.BzrDir.create_standalone_workingtree('child', - format=format) + self.assertRaises( + errors.NotLocalUrl, + bzrdir.BzrDir.create_standalone_workingtree, + self.get_readonly_url("child"), + format=format, + ) + tree = bzrdir.BzrDir.create_standalone_workingtree("child", format=format) tree.controldir.open_repository() def test_create_branch_convenience(self): # outside a repo the default convenience output is a repo+branch_tree - format = controldir.format_registry.make_controldir('knit') - branch = bzrdir.BzrDir.create_branch_convenience('.', format=format) + format = controldir.format_registry.make_controldir("knit") + branch = bzrdir.BzrDir.create_branch_convenience(".", format=format) branch.controldir.open_workingtree() branch.controldir.open_repository() def test_create_branch_convenience_possible_transports(self): """Check that the optional 'possible_transports' is recognized.""" - format = controldir.format_registry.make_controldir('knit') + format = controldir.format_registry.make_controldir("knit") t = self.get_transport() branch = bzrdir.BzrDir.create_branch_convenience( - '.', format=format, possible_transports=[t]) + ".", format=format, possible_transports=[t] + ) branch.controldir.open_workingtree() branch.controldir.open_repository() @@ -408,139 +455,136 @@ def test_create_branch_convenience_root(self): """Creating a branch at the root of a fs should work.""" self.vfs_transport_factory = memory.MemoryServer # outside a repo the default convenience output is a repo+branch_tree - format = controldir.format_registry.make_controldir('knit') - branch = bzrdir.BzrDir.create_branch_convenience(self.get_url(), - format=format) - self.assertRaises(errors.NoWorkingTree, - branch.controldir.open_workingtree) + format = controldir.format_registry.make_controldir("knit") + branch = bzrdir.BzrDir.create_branch_convenience(self.get_url(), format=format) + self.assertRaises(errors.NoWorkingTree, branch.controldir.open_workingtree) branch.controldir.open_repository() def test_create_branch_convenience_under_shared_repo(self): # inside a repo the default convenience output is a branch+ follow the # repo tree policy - format = controldir.format_registry.make_controldir('knit') - self.make_repository('.', shared=True, format=format) - branch = bzrdir.BzrDir.create_branch_convenience('child', - format=format) + format = controldir.format_registry.make_controldir("knit") + self.make_repository(".", shared=True, format=format) + branch = bzrdir.BzrDir.create_branch_convenience("child", format=format) branch.controldir.open_workingtree() - self.assertRaises(errors.NoRepositoryPresent, - branch.controldir.open_repository) + self.assertRaises(errors.NoRepositoryPresent, branch.controldir.open_repository) def test_create_branch_convenience_under_shared_repo_force_no_tree(self): # inside a repo the default convenience output is a branch+ follow the # repo tree policy but we can override that - format = controldir.format_registry.make_controldir('knit') - self.make_repository('.', shared=True, format=format) - branch = 
bzrdir.BzrDir.create_branch_convenience('child', - force_new_tree=False, format=format) - self.assertRaises(errors.NoWorkingTree, - branch.controldir.open_workingtree) - self.assertRaises(errors.NoRepositoryPresent, - branch.controldir.open_repository) + format = controldir.format_registry.make_controldir("knit") + self.make_repository(".", shared=True, format=format) + branch = bzrdir.BzrDir.create_branch_convenience( + "child", force_new_tree=False, format=format + ) + self.assertRaises(errors.NoWorkingTree, branch.controldir.open_workingtree) + self.assertRaises(errors.NoRepositoryPresent, branch.controldir.open_repository) def test_create_branch_convenience_under_shared_repo_no_tree_policy(self): # inside a repo the default convenience output is a branch+ follow the # repo tree policy - format = controldir.format_registry.make_controldir('knit') - repo = self.make_repository('.', shared=True, format=format) + format = controldir.format_registry.make_controldir("knit") + repo = self.make_repository(".", shared=True, format=format) repo.set_make_working_trees(False) - branch = bzrdir.BzrDir.create_branch_convenience('child', - format=format) - self.assertRaises(errors.NoWorkingTree, - branch.controldir.open_workingtree) - self.assertRaises(errors.NoRepositoryPresent, - branch.controldir.open_repository) - - def test_create_branch_convenience_under_shared_repo_no_tree_policy_force_tree(self): + branch = bzrdir.BzrDir.create_branch_convenience("child", format=format) + self.assertRaises(errors.NoWorkingTree, branch.controldir.open_workingtree) + self.assertRaises(errors.NoRepositoryPresent, branch.controldir.open_repository) + + def test_create_branch_convenience_under_shared_repo_no_tree_policy_force_tree( + self + ): # inside a repo the default convenience output is a branch+ follow the # repo tree policy but we can override that - format = controldir.format_registry.make_controldir('knit') - repo = self.make_repository('.', shared=True, format=format) + format = controldir.format_registry.make_controldir("knit") + repo = self.make_repository(".", shared=True, format=format) repo.set_make_working_trees(False) - branch = bzrdir.BzrDir.create_branch_convenience('child', - force_new_tree=True, format=format) + branch = bzrdir.BzrDir.create_branch_convenience( + "child", force_new_tree=True, format=format + ) branch.controldir.open_workingtree() - self.assertRaises(errors.NoRepositoryPresent, - branch.controldir.open_repository) + self.assertRaises(errors.NoRepositoryPresent, branch.controldir.open_repository) def test_create_branch_convenience_under_shared_repo_force_new_repo(self): # inside a repo the default convenience output is overridable to give # repo+branch+tree - format = controldir.format_registry.make_controldir('knit') - self.make_repository('.', shared=True, format=format) - branch = bzrdir.BzrDir.create_branch_convenience('child', - force_new_repo=True, format=format) + format = controldir.format_registry.make_controldir("knit") + self.make_repository(".", shared=True, format=format) + branch = bzrdir.BzrDir.create_branch_convenience( + "child", force_new_repo=True, format=format + ) branch.controldir.open_repository() branch.controldir.open_workingtree() class TestRepositoryAcquisitionPolicy(TestCaseWithTransport): - def test_acquire_repository_standalone(self): """The default acquisition policy should create a standalone branch.""" - my_bzrdir = self.make_controldir('.') + my_bzrdir = self.make_controldir(".") repo_policy = my_bzrdir.determine_repository_policy() repo, 
is_new = repo_policy.acquire_repository() - self.assertEqual(repo.controldir.root_transport.base, - my_bzrdir.root_transport.base) + self.assertEqual( + repo.controldir.root_transport.base, my_bzrdir.root_transport.base + ) self.assertFalse(repo.is_shared()) def test_determine_stacking_policy(self): - parent_bzrdir = self.make_controldir('.') - child_bzrdir = self.make_controldir('child') - parent_bzrdir.get_config().set_default_stack_on('http://example.org') + parent_bzrdir = self.make_controldir(".") + child_bzrdir = self.make_controldir("child") + parent_bzrdir.get_config().set_default_stack_on("http://example.org") repo_policy = child_bzrdir.determine_repository_policy() - self.assertEqual('http://example.org', repo_policy._stack_on) + self.assertEqual("http://example.org", repo_policy._stack_on) def test_determine_stacking_policy_relative(self): - parent_bzrdir = self.make_controldir('.') - child_bzrdir = self.make_controldir('child') - parent_bzrdir.get_config().set_default_stack_on('child2') + parent_bzrdir = self.make_controldir(".") + child_bzrdir = self.make_controldir("child") + parent_bzrdir.get_config().set_default_stack_on("child2") repo_policy = child_bzrdir.determine_repository_policy() - self.assertEqual('child2', repo_policy._stack_on) - self.assertEqual(parent_bzrdir.root_transport.base, - repo_policy._stack_on_pwd) + self.assertEqual("child2", repo_policy._stack_on) + self.assertEqual(parent_bzrdir.root_transport.base, repo_policy._stack_on_pwd) - def prepare_default_stacking(self, child_format='1.6'): - parent_bzrdir = self.make_controldir('.') - child_branch = self.make_branch('child', format=child_format) + def prepare_default_stacking(self, child_format="1.6"): + parent_bzrdir = self.make_controldir(".") + child_branch = self.make_branch("child", format=child_format) parent_bzrdir.get_config().set_default_stack_on(child_branch.base) - new_child_transport = parent_bzrdir.transport.clone('child2') + new_child_transport = parent_bzrdir.transport.clone("child2") return child_branch, new_child_transport def test_clone_on_transport_obeys_stacking_policy(self): child_branch, new_child_transport = self.prepare_default_stacking() - new_child = child_branch.controldir.clone_on_transport( - new_child_transport) - self.assertEqual(child_branch.base, - new_child.open_branch().get_stacked_on_url()) + new_child = child_branch.controldir.clone_on_transport(new_child_transport) + self.assertEqual( + child_branch.base, new_child.open_branch().get_stacked_on_url() + ) def test_default_stacking_with_stackable_branch_unstackable_repo(self): # Make stackable source branch with an unstackable repo format. - source_bzrdir = self.make_controldir('source') + source_bzrdir = self.make_controldir("source") knitpack_repo.RepositoryFormatKnitPack1().initialize(source_bzrdir) - breezy.bzr.branch.BzrBranchFormat7().initialize( - source_bzrdir) + breezy.bzr.branch.BzrBranchFormat7().initialize(source_bzrdir) # Make a directory with a default stacking policy - parent_bzrdir = self.make_controldir('parent') - stacked_on = self.make_branch('parent/stacked-on', format='pack-0.92') + parent_bzrdir = self.make_controldir("parent") + stacked_on = self.make_branch("parent/stacked-on", format="pack-0.92") parent_bzrdir.get_config().set_default_stack_on(stacked_on.base) # Clone source into directory - source_bzrdir.clone(self.get_url('parent/target')) + source_bzrdir.clone(self.get_url("parent/target")) def test_format_initialize_on_transport_ex_stacked_on(self): # trunk is a stackable format. 
Note that its in the same server area # which is what launchpad does, but not sufficient to exercise the # general case. - self.make_branch('trunk', format='1.9') - t = self.get_transport('stacked') - old_fmt = controldir.format_registry.make_controldir('pack-0.92') + self.make_branch("trunk", format="1.9") + t = self.get_transport("stacked") + old_fmt = controldir.format_registry.make_controldir("pack-0.92") repo_name = old_fmt.repository_format.network_name() # Should end up with a 1.9 format (stackable) - repo, control, require_stacking, repo_policy = \ - old_fmt.initialize_on_transport_ex(t, - repo_format_name=repo_name, stacked_on='../trunk', - stack_on_pwd=t.base) + ( + repo, + control, + require_stacking, + repo_policy, + ) = old_fmt.initialize_on_transport_ex( + t, repo_format_name=repo_name, stacked_on="../trunk", stack_on_pwd=t.base + ) if repo is not None: # Repositories are open write-locked self.assertTrue(repo.is_write_locked()) @@ -550,93 +594,108 @@ def test_format_initialize_on_transport_ex_stacked_on(self): self.assertIsInstance(control, bzrdir.BzrDir) opened = bzrdir.BzrDir.open(t.base) if not isinstance(old_fmt, remote.RemoteBzrDirFormat): - self.assertEqual(control._format.network_name(), - old_fmt.network_name()) - self.assertEqual(control._format.network_name(), - opened._format.network_name()) + self.assertEqual(control._format.network_name(), old_fmt.network_name()) + self.assertEqual( + control._format.network_name(), opened._format.network_name() + ) self.assertEqual(control.__class__, opened.__class__) self.assertLength(1, repo._fallback_repositories) def test_sprout_obeys_stacking_policy(self): child_branch, new_child_transport = self.prepare_default_stacking() new_child = child_branch.controldir.sprout(new_child_transport.base) - self.assertEqual(child_branch.base, - new_child.open_branch().get_stacked_on_url()) + self.assertEqual( + child_branch.base, new_child.open_branch().get_stacked_on_url() + ) def test_clone_ignores_policy_for_unsupported_formats(self): child_branch, new_child_transport = self.prepare_default_stacking( - child_format='pack-0.92') - new_child = child_branch.controldir.clone_on_transport( - new_child_transport) - self.assertRaises(branch.UnstackableBranchFormat, - new_child.open_branch().get_stacked_on_url) + child_format="pack-0.92" + ) + new_child = child_branch.controldir.clone_on_transport(new_child_transport) + self.assertRaises( + branch.UnstackableBranchFormat, new_child.open_branch().get_stacked_on_url + ) def test_sprout_ignores_policy_for_unsupported_formats(self): child_branch, new_child_transport = self.prepare_default_stacking( - child_format='pack-0.92') + child_format="pack-0.92" + ) new_child = child_branch.controldir.sprout(new_child_transport.base) - self.assertRaises(branch.UnstackableBranchFormat, - new_child.open_branch().get_stacked_on_url) + self.assertRaises( + branch.UnstackableBranchFormat, new_child.open_branch().get_stacked_on_url + ) def test_sprout_upgrades_format_if_stacked_specified(self): child_branch, new_child_transport = self.prepare_default_stacking( - child_format='pack-0.92') - new_child = child_branch.controldir.sprout(new_child_transport.base, - stacked=True) - self.assertEqual(child_branch.controldir.root_transport.base, - new_child.open_branch().get_stacked_on_url()) + child_format="pack-0.92" + ) + new_child = child_branch.controldir.sprout( + new_child_transport.base, stacked=True + ) + self.assertEqual( + child_branch.controldir.root_transport.base, + 
new_child.open_branch().get_stacked_on_url(), + ) repo = new_child.open_repository() self.assertTrue(repo._format.supports_external_lookups) self.assertFalse(repo.supports_rich_root()) def test_clone_on_transport_upgrades_format_if_stacked_on_specified(self): child_branch, new_child_transport = self.prepare_default_stacking( - child_format='pack-0.92') - new_child = child_branch.controldir.clone_on_transport(new_child_transport, - stacked_on=child_branch.controldir.root_transport.base) - self.assertEqual(child_branch.controldir.root_transport.base, - new_child.open_branch().get_stacked_on_url()) + child_format="pack-0.92" + ) + new_child = child_branch.controldir.clone_on_transport( + new_child_transport, stacked_on=child_branch.controldir.root_transport.base + ) + self.assertEqual( + child_branch.controldir.root_transport.base, + new_child.open_branch().get_stacked_on_url(), + ) repo = new_child.open_repository() self.assertTrue(repo._format.supports_external_lookups) self.assertFalse(repo.supports_rich_root()) def test_sprout_upgrades_to_rich_root_format_if_needed(self): child_branch, new_child_transport = self.prepare_default_stacking( - child_format='rich-root-pack') - new_child = child_branch.controldir.sprout(new_child_transport.base, - stacked=True) + child_format="rich-root-pack" + ) + new_child = child_branch.controldir.sprout( + new_child_transport.base, stacked=True + ) repo = new_child.open_repository() self.assertTrue(repo._format.supports_external_lookups) self.assertTrue(repo.supports_rich_root()) def test_add_fallback_repo_handles_absolute_urls(self): - stack_on = self.make_branch('stack_on', format='1.6') - repo = self.make_repository('repo', format='1.6') + stack_on = self.make_branch("stack_on", format="1.6") + repo = self.make_repository("repo", format="1.6") policy = bzrdir.UseExistingRepository(repo, stack_on.base) policy._add_fallback(repo) def test_add_fallback_repo_handles_relative_urls(self): - stack_on = self.make_branch('stack_on', format='1.6') - repo = self.make_repository('repo', format='1.6') - policy = bzrdir.UseExistingRepository(repo, '.', stack_on.base) + stack_on = self.make_branch("stack_on", format="1.6") + repo = self.make_repository("repo", format="1.6") + policy = bzrdir.UseExistingRepository(repo, ".", stack_on.base) policy._add_fallback(repo) def test_configure_relative_branch_stacking_url(self): - stack_on = self.make_branch('stack_on', format='1.6') - stacked = self.make_branch('stack_on/stacked', format='1.6') - policy = bzrdir.UseExistingRepository(stacked.repository, - '.', stack_on.base) + stack_on = self.make_branch("stack_on", format="1.6") + stacked = self.make_branch("stack_on/stacked", format="1.6") + policy = bzrdir.UseExistingRepository(stacked.repository, ".", stack_on.base) policy.configure_branch(stacked) - self.assertEqual('..', stacked.get_stacked_on_url()) + self.assertEqual("..", stacked.get_stacked_on_url()) def test_relative_branch_stacking_to_absolute(self): - self.make_branch('stack_on', format='1.6') - stacked = self.make_branch('stack_on/stacked', format='1.6') - policy = bzrdir.UseExistingRepository(stacked.repository, - '.', self.get_readonly_url('stack_on')) + self.make_branch("stack_on", format="1.6") + stacked = self.make_branch("stack_on/stacked", format="1.6") + policy = bzrdir.UseExistingRepository( + stacked.repository, ".", self.get_readonly_url("stack_on") + ) policy.configure_branch(stacked) - self.assertEqual(self.get_readonly_url('stack_on'), - stacked.get_stacked_on_url()) + self.assertEqual( + 
self.get_readonly_url("stack_on"), stacked.get_stacked_on_url() + ) class ChrootedTests(TestCaseWithTransport): @@ -656,178 +715,200 @@ def local_branch_path(self, branch): return os.path.realpath(urlutils.local_path_from_url(branch.base)) def test_open_containing(self): - self.assertRaises(NotBranchError, bzrdir.BzrDir.open_containing, - self.get_readonly_url('')) - self.assertRaises(NotBranchError, bzrdir.BzrDir.open_containing, - self.get_readonly_url('g/p/q')) + self.assertRaises( + NotBranchError, bzrdir.BzrDir.open_containing, self.get_readonly_url("") + ) + self.assertRaises( + NotBranchError, + bzrdir.BzrDir.open_containing, + self.get_readonly_url("g/p/q"), + ) bzrdir.BzrDir.create(self.get_url()) - branch, relpath = bzrdir.BzrDir.open_containing( - self.get_readonly_url('')) - self.assertEqual('', relpath) - branch, relpath = bzrdir.BzrDir.open_containing( - self.get_readonly_url('g/p/q')) - self.assertEqual('g/p/q', relpath) + branch, relpath = bzrdir.BzrDir.open_containing(self.get_readonly_url("")) + self.assertEqual("", relpath) + branch, relpath = bzrdir.BzrDir.open_containing(self.get_readonly_url("g/p/q")) + self.assertEqual("g/p/q", relpath) def test_open_containing_tree_branch_or_repository_empty(self): - self.assertRaises(errors.NotBranchError, - bzrdir.BzrDir.open_containing_tree_branch_or_repository, - self.get_readonly_url('')) + self.assertRaises( + errors.NotBranchError, + bzrdir.BzrDir.open_containing_tree_branch_or_repository, + self.get_readonly_url(""), + ) def test_open_containing_tree_branch_or_repository_all(self): - self.make_branch_and_tree('topdir') - tree, branch, repo, relpath = \ - bzrdir.BzrDir.open_containing_tree_branch_or_repository( - 'topdir/foo') - self.assertEqual(os.path.realpath('topdir'), - os.path.realpath(tree.basedir)) - self.assertEqual(os.path.realpath('topdir'), - self.local_branch_path(branch)) + self.make_branch_and_tree("topdir") + ( + tree, + branch, + repo, + relpath, + ) = bzrdir.BzrDir.open_containing_tree_branch_or_repository("topdir/foo") + self.assertEqual(os.path.realpath("topdir"), os.path.realpath(tree.basedir)) + self.assertEqual(os.path.realpath("topdir"), self.local_branch_path(branch)) self.assertEqual( - osutils.realpath(os.path.join('topdir', '.bzr', 'repository')), - repo.controldir.transport.local_abspath('repository')) - self.assertEqual(relpath, 'foo') + osutils.realpath(os.path.join("topdir", ".bzr", "repository")), + repo.controldir.transport.local_abspath("repository"), + ) + self.assertEqual(relpath, "foo") def test_open_containing_tree_branch_or_repository_no_tree(self): - self.make_branch('branch') - tree, branch, repo, relpath = \ - bzrdir.BzrDir.open_containing_tree_branch_or_repository( - 'branch/foo') + self.make_branch("branch") + ( + tree, + branch, + repo, + relpath, + ) = bzrdir.BzrDir.open_containing_tree_branch_or_repository("branch/foo") self.assertEqual(tree, None) - self.assertEqual(os.path.realpath('branch'), - self.local_branch_path(branch)) + self.assertEqual(os.path.realpath("branch"), self.local_branch_path(branch)) self.assertEqual( - osutils.realpath(os.path.join('branch', '.bzr', 'repository')), - repo.controldir.transport.local_abspath('repository')) - self.assertEqual(relpath, 'foo') + osutils.realpath(os.path.join("branch", ".bzr", "repository")), + repo.controldir.transport.local_abspath("repository"), + ) + self.assertEqual(relpath, "foo") def test_open_containing_tree_branch_or_repository_repo(self): - self.make_repository('repo') - tree, branch, repo, relpath = \ - 
bzrdir.BzrDir.open_containing_tree_branch_or_repository( - 'repo') + self.make_repository("repo") + ( + tree, + branch, + repo, + relpath, + ) = bzrdir.BzrDir.open_containing_tree_branch_or_repository("repo") self.assertEqual(tree, None) self.assertEqual(branch, None) self.assertEqual( - osutils.realpath(os.path.join('repo', '.bzr', 'repository')), - repo.controldir.transport.local_abspath('repository')) - self.assertEqual(relpath, '') + osutils.realpath(os.path.join("repo", ".bzr", "repository")), + repo.controldir.transport.local_abspath("repository"), + ) + self.assertEqual(relpath, "") def test_open_containing_tree_branch_or_repository_shared_repo(self): - self.make_repository('shared', shared=True) - bzrdir.BzrDir.create_branch_convenience('shared/branch', - force_new_tree=False) - tree, branch, repo, relpath = \ - bzrdir.BzrDir.open_containing_tree_branch_or_repository( - 'shared/branch') + self.make_repository("shared", shared=True) + bzrdir.BzrDir.create_branch_convenience("shared/branch", force_new_tree=False) + ( + tree, + branch, + repo, + relpath, + ) = bzrdir.BzrDir.open_containing_tree_branch_or_repository("shared/branch") self.assertEqual(tree, None) - self.assertEqual(os.path.realpath('shared/branch'), - self.local_branch_path(branch)) self.assertEqual( - osutils.realpath(os.path.join('shared', '.bzr', 'repository')), - repo.controldir.transport.local_abspath('repository')) - self.assertEqual(relpath, '') + os.path.realpath("shared/branch"), self.local_branch_path(branch) + ) + self.assertEqual( + osutils.realpath(os.path.join("shared", ".bzr", "repository")), + repo.controldir.transport.local_abspath("repository"), + ) + self.assertEqual(relpath, "") def test_open_containing_tree_branch_or_repository_branch_subdir(self): - self.make_branch_and_tree('foo') - self.build_tree(['foo/bar/']) - tree, branch, repo, relpath = \ - bzrdir.BzrDir.open_containing_tree_branch_or_repository( - 'foo/bar') - self.assertEqual(os.path.realpath('foo'), - os.path.realpath(tree.basedir)) - self.assertEqual(os.path.realpath('foo'), - self.local_branch_path(branch)) + self.make_branch_and_tree("foo") + self.build_tree(["foo/bar/"]) + ( + tree, + branch, + repo, + relpath, + ) = bzrdir.BzrDir.open_containing_tree_branch_or_repository("foo/bar") + self.assertEqual(os.path.realpath("foo"), os.path.realpath(tree.basedir)) + self.assertEqual(os.path.realpath("foo"), self.local_branch_path(branch)) self.assertEqual( - osutils.realpath(os.path.join('foo', '.bzr', 'repository')), - repo.controldir.transport.local_abspath('repository')) - self.assertEqual(relpath, 'bar') + osutils.realpath(os.path.join("foo", ".bzr", "repository")), + repo.controldir.transport.local_abspath("repository"), + ) + self.assertEqual(relpath, "bar") def test_open_containing_tree_branch_or_repository_repo_subdir(self): - self.make_repository('bar') - self.build_tree(['bar/baz/']) - tree, branch, repo, relpath = \ - bzrdir.BzrDir.open_containing_tree_branch_or_repository( - 'bar/baz') + self.make_repository("bar") + self.build_tree(["bar/baz/"]) + ( + tree, + branch, + repo, + relpath, + ) = bzrdir.BzrDir.open_containing_tree_branch_or_repository("bar/baz") self.assertEqual(tree, None) self.assertEqual(branch, None) self.assertEqual( - osutils.realpath(os.path.join('bar', '.bzr', 'repository')), - repo.controldir.transport.local_abspath('repository')) - self.assertEqual(relpath, 'baz') + osutils.realpath(os.path.join("bar", ".bzr", "repository")), + repo.controldir.transport.local_abspath("repository"), + ) + 
self.assertEqual(relpath, "baz") def test_open_containing_from_transport(self): - self.assertRaises(NotBranchError, - bzrdir.BzrDir.open_containing_from_transport, - _mod_transport.get_transport_from_url(self.get_readonly_url(''))) - self.assertRaises(NotBranchError, - bzrdir.BzrDir.open_containing_from_transport, - _mod_transport.get_transport_from_url( - self.get_readonly_url('g/p/q'))) + self.assertRaises( + NotBranchError, + bzrdir.BzrDir.open_containing_from_transport, + _mod_transport.get_transport_from_url(self.get_readonly_url("")), + ) + self.assertRaises( + NotBranchError, + bzrdir.BzrDir.open_containing_from_transport, + _mod_transport.get_transport_from_url(self.get_readonly_url("g/p/q")), + ) bzrdir.BzrDir.create(self.get_url()) branch, relpath = bzrdir.BzrDir.open_containing_from_transport( - _mod_transport.get_transport_from_url( - self.get_readonly_url(''))) - self.assertEqual('', relpath) + _mod_transport.get_transport_from_url(self.get_readonly_url("")) + ) + self.assertEqual("", relpath) branch, relpath = bzrdir.BzrDir.open_containing_from_transport( - _mod_transport.get_transport_from_url( - self.get_readonly_url('g/p/q'))) - self.assertEqual('g/p/q', relpath) + _mod_transport.get_transport_from_url(self.get_readonly_url("g/p/q")) + ) + self.assertEqual("g/p/q", relpath) def test_open_containing_tree_or_branch(self): - self.make_branch_and_tree('topdir') + self.make_branch_and_tree("topdir") tree, branch, relpath = bzrdir.BzrDir.open_containing_tree_or_branch( - 'topdir/foo') - self.assertEqual(os.path.realpath('topdir'), - os.path.realpath(tree.basedir)) - self.assertEqual(os.path.realpath('topdir'), - self.local_branch_path(branch)) + "topdir/foo" + ) + self.assertEqual(os.path.realpath("topdir"), os.path.realpath(tree.basedir)) + self.assertEqual(os.path.realpath("topdir"), self.local_branch_path(branch)) self.assertIs(tree.controldir, branch.controldir) - self.assertEqual('foo', relpath) + self.assertEqual("foo", relpath) # opening from non-local should not return the tree tree, branch, relpath = bzrdir.BzrDir.open_containing_tree_or_branch( - self.get_readonly_url('topdir/foo')) + self.get_readonly_url("topdir/foo") + ) self.assertEqual(None, tree) - self.assertEqual('foo', relpath) + self.assertEqual("foo", relpath) # without a tree: - self.make_branch('topdir/foo') + self.make_branch("topdir/foo") tree, branch, relpath = bzrdir.BzrDir.open_containing_tree_or_branch( - 'topdir/foo') + "topdir/foo" + ) self.assertIsNone(tree) - self.assertEqual(os.path.realpath('topdir/foo'), - self.local_branch_path(branch)) - self.assertEqual('', relpath) + self.assertEqual(os.path.realpath("topdir/foo"), self.local_branch_path(branch)) + self.assertEqual("", relpath) def test_open_tree_or_branch(self): - self.make_branch_and_tree('topdir') - tree, branch = bzrdir.BzrDir.open_tree_or_branch('topdir') - self.assertEqual(os.path.realpath('topdir'), - os.path.realpath(tree.basedir)) - self.assertEqual(os.path.realpath('topdir'), - self.local_branch_path(branch)) + self.make_branch_and_tree("topdir") + tree, branch = bzrdir.BzrDir.open_tree_or_branch("topdir") + self.assertEqual(os.path.realpath("topdir"), os.path.realpath(tree.basedir)) + self.assertEqual(os.path.realpath("topdir"), self.local_branch_path(branch)) self.assertIs(tree.controldir, branch.controldir) # opening from non-local should not return the tree tree, branch = bzrdir.BzrDir.open_tree_or_branch( - self.get_readonly_url('topdir')) + self.get_readonly_url("topdir") + ) self.assertEqual(None, tree) # without a tree: - 
self.make_branch('topdir/foo') - tree, branch = bzrdir.BzrDir.open_tree_or_branch('topdir/foo') + self.make_branch("topdir/foo") + tree, branch = bzrdir.BzrDir.open_tree_or_branch("topdir/foo") self.assertIsNone(tree) - self.assertEqual(os.path.realpath('topdir/foo'), - self.local_branch_path(branch)) + self.assertEqual(os.path.realpath("topdir/foo"), self.local_branch_path(branch)) def test_open_tree_or_branch_named(self): - tree = self.make_branch_and_tree('topdir') + tree = self.make_branch_and_tree("topdir") self.assertRaises( - NotBranchError, - bzrdir.BzrDir.open_tree_or_branch, 'topdir', name='missing') - tree.branch.controldir.create_branch('named') - tree, branch = bzrdir.BzrDir.open_tree_or_branch('topdir', name='named') - self.assertEqual(os.path.realpath('topdir'), - os.path.realpath(tree.basedir)) - self.assertEqual(os.path.realpath('topdir'), - self.local_branch_path(branch)) - self.assertEqual(branch.name, 'named') + NotBranchError, bzrdir.BzrDir.open_tree_or_branch, "topdir", name="missing" + ) + tree.branch.controldir.create_branch("named") + tree, branch = bzrdir.BzrDir.open_tree_or_branch("topdir", name="named") + self.assertEqual(os.path.realpath("topdir"), os.path.realpath(tree.basedir)) + self.assertEqual(os.path.realpath("topdir"), self.local_branch_path(branch)) + self.assertEqual(branch.name, "named") self.assertIs(tree.controldir, branch.controldir) def test_open_from_transport(self): @@ -846,51 +927,51 @@ def test_open_from_transport_no_bzrdir(self): def test_open_from_transport_bzrdir_in_parent(self): bzrdir.BzrDir.create(self.get_url()) t = self.get_transport() - t.mkdir('subdir') - t = t.clone('subdir') + t.mkdir("subdir") + t = t.clone("subdir") self.assertRaises(NotBranchError, bzrdir.BzrDir.open_from_transport, t) def test_sprout_recursive(self): - tree = self.make_branch_and_tree('tree1') - sub_tree = self.make_branch_and_tree('tree1/subtree') - sub_tree.set_root_id(b'subtree-root') + tree = self.make_branch_and_tree("tree1") + sub_tree = self.make_branch_and_tree("tree1/subtree") + sub_tree.set_root_id(b"subtree-root") tree.add_reference(sub_tree) - tree.set_reference_info('subtree', sub_tree.branch.user_url) - self.build_tree(['tree1/subtree/file']) - sub_tree.add('file') - tree.commit('Initial commit') - tree2 = tree.controldir.sprout('tree2').open_workingtree() + tree.set_reference_info("subtree", sub_tree.branch.user_url) + self.build_tree(["tree1/subtree/file"]) + sub_tree.add("file") + tree.commit("Initial commit") + tree2 = tree.controldir.sprout("tree2").open_workingtree() tree2.lock_read() self.addCleanup(tree2.unlock) - self.assertPathExists('tree2/subtree/file') - self.assertEqual('tree-reference', tree2.kind('subtree')) + self.assertPathExists("tree2/subtree/file") + self.assertEqual("tree-reference", tree2.kind("subtree")) def test_cloning_metadir(self): """Ensure that cloning metadir is suitable.""" - bzrdir = self.make_controldir('bzrdir') + bzrdir = self.make_controldir("bzrdir") bzrdir.cloning_metadir() - branch = self.make_branch('branch', format='knit') + branch = self.make_branch("branch", format="knit") format = branch.controldir.cloning_metadir() - self.assertIsInstance(format.workingtree_format, - workingtree_4.WorkingTreeFormat6) + self.assertIsInstance( + format.workingtree_format, workingtree_4.WorkingTreeFormat6 + ) def test_sprout_recursive_treeless(self): - tree = self.make_branch_and_tree('tree1', - format='development-subtree') - sub_tree = self.make_branch_and_tree('tree1/subtree', - format='development-subtree') + tree = 
self.make_branch_and_tree("tree1", format="development-subtree") + sub_tree = self.make_branch_and_tree( + "tree1/subtree", format="development-subtree" + ) tree.add_reference(sub_tree) - tree.set_reference_info('subtree', sub_tree.branch.user_url) - self.build_tree(['tree1/subtree/file']) - sub_tree.add('file') - tree.commit('Initial commit') + tree.set_reference_info("subtree", sub_tree.branch.user_url) + self.build_tree(["tree1/subtree/file"]) + sub_tree.add("file") + tree.commit("Initial commit") # The following line force the orhaning to reveal bug #634470 - tree.branch.get_config_stack().set('transform.orphan_policy', 'move') + tree.branch.get_config_stack().set("transform.orphan_policy", "move") tree.controldir.destroy_workingtree() # FIXME: subtree/.bzr is left here which allows the test to pass (or # fail :-( ) -- vila 20100909 - repo = self.make_repository('repo', shared=True, - format='development-subtree') + repo = self.make_repository("repo", shared=True, format="development-subtree") repo.set_make_working_trees(False) # FIXME: we just deleted the workingtree and now we want to use it ???? # At a minimum, we should use tree.branch below (but this fails too @@ -900,33 +981,35 @@ def test_sprout_recursive_treeless(self): # by bzrdir.BzrDirMeta1.destroy_workingtree when it ignores the # [DeletingParent('Not deleting', u'subtree', None)] conflict). See bug # #634470. -- vila 20100909 - tree.controldir.sprout('repo/tree2') - self.assertPathExists('repo/tree2/subtree') - self.assertPathDoesNotExist('repo/tree2/subtree/file') + tree.controldir.sprout("repo/tree2") + self.assertPathExists("repo/tree2/subtree") + self.assertPathDoesNotExist("repo/tree2/subtree/file") def make_foo_bar_baz(self): - foo = bzrdir.BzrDir.create_branch_convenience('foo').controldir - bar = self.make_branch('foo/bar').controldir - baz = self.make_branch('baz').controldir + foo = bzrdir.BzrDir.create_branch_convenience("foo").controldir + bar = self.make_branch("foo/bar").controldir + baz = self.make_branch("baz").controldir return foo, bar, baz def test_find_controldirs(self): foo, bar, baz = self.make_foo_bar_baz() t = self.get_transport() - self.assertEqualBzrdirs( - [baz, foo, bar], bzrdir.BzrDir.find_controldirs(t)) + self.assertEqualBzrdirs([baz, foo, bar], bzrdir.BzrDir.find_controldirs(t)) def make_fake_permission_denied_transport(self, transport, paths): """Create a transport that raises PermissionDenied for some paths.""" + def filter(path): if path in paths: raise errors.PermissionDenied(path) return path + path_filter_server = pathfilter.PathFilteringServer(transport, filter) path_filter_server.start_server() self.addCleanup(path_filter_server.stop_server) path_filter_transport = pathfilter.PathFilteringTransport( - path_filter_server, '.') + path_filter_server, "." 
+ ) return (path_filter_server, path_filter_transport) def assertBranchUrlsEndWith(self, expect_url_suffix, actual_bzrdirs): @@ -937,26 +1020,29 @@ def assertBranchUrlsEndWith(self, expect_url_suffix, actual_bzrdirs): def test_find_controldirs_permission_denied(self): foo, bar, baz = self.make_foo_bar_baz() t = self.get_transport() - path_filter_server, path_filter_transport = \ - self.make_fake_permission_denied_transport(t, ['foo']) + ( + path_filter_server, + path_filter_transport, + ) = self.make_fake_permission_denied_transport(t, ["foo"]) # local transport - self.assertBranchUrlsEndWith('/baz/', - bzrdir.BzrDir.find_controldirs(path_filter_transport)) + self.assertBranchUrlsEndWith( + "/baz/", bzrdir.BzrDir.find_controldirs(path_filter_transport) + ) # smart server - smart_transport = self.make_smart_server('.', - backing_server=path_filter_server) - self.assertBranchUrlsEndWith('/baz/', - bzrdir.BzrDir.find_controldirs(smart_transport)) + smart_transport = self.make_smart_server(".", backing_server=path_filter_server) + self.assertBranchUrlsEndWith( + "/baz/", bzrdir.BzrDir.find_controldirs(smart_transport) + ) def test_find_controldirs_list_current(self): def list_current(transport): - return [s for s in transport.list_dir('') if s != 'baz'] + return [s for s in transport.list_dir("") if s != "baz"] foo, bar, baz = self.make_foo_bar_baz() t = self.get_transport() self.assertEqualBzrdirs( - [foo, bar], - bzrdir.BzrDir.find_controldirs(t, list_current=list_current)) + [foo, bar], bzrdir.BzrDir.find_controldirs(t, list_current=list_current) + ) def test_find_controldirs_evaluate(self): def evaluate(bzrdir): @@ -969,8 +1055,10 @@ def evaluate(bzrdir): foo, bar, baz = self.make_foo_bar_baz() t = self.get_transport() - self.assertEqual([baz.root_transport.base, foo.root_transport.base], - list(bzrdir.BzrDir.find_controldirs(t, evaluate=evaluate))) + self.assertEqual( + [baz.root_transport.base, foo.root_transport.base], + list(bzrdir.BzrDir.find_controldirs(t, evaluate=evaluate)), + ) def assertEqualBzrdirs(self, first, second): first = list(first) @@ -980,9 +1068,9 @@ def assertEqualBzrdirs(self, first, second): self.assertEqual(x.root_transport.base, y.root_transport.base) def test_find_branches(self): - self.make_repository('', shared=True) + self.make_repository("", shared=True) foo, bar, baz = self.make_foo_bar_baz() - self.make_controldir('foo/qux') + self.make_controldir("foo/qux") t = self.get_transport() branches = bzrdir.BzrDir.find_branches(t) self.assertEqual(baz.root_transport.base, branches[0].base) @@ -990,24 +1078,22 @@ def test_find_branches(self): self.assertEqual(bar.root_transport.base, branches[2].base) # ensure this works without a top-level repo - branches = bzrdir.BzrDir.find_branches(t.clone('foo')) + branches = bzrdir.BzrDir.find_branches(t.clone("foo")) self.assertEqual(foo.root_transport.base, branches[0].base) self.assertEqual(bar.root_transport.base, branches[1].base) class TestMissingRepoBranchesSkipped(TestCaseWithMemoryTransport): - def test_find_controldirs_missing_repo(self): t = self.get_transport() - arepo = self.make_repository('arepo', shared=True) - abranch_url = arepo.user_url + '/abranch' + arepo = self.make_repository("arepo", shared=True) + abranch_url = arepo.user_url + "/abranch" bzrdir.BzrDir.create(abranch_url).create_branch() - t.delete_tree('arepo/.bzr') - self.assertRaises(errors.NoRepositoryPresent, - branch.Branch.open, abranch_url) - self.make_branch('baz') + t.delete_tree("arepo/.bzr") + self.assertRaises(errors.NoRepositoryPresent, 
branch.Branch.open, abranch_url) + self.make_branch("baz") for actual_bzrdir in bzrdir.BzrDir.find_branches(t): - self.assertEndsWith(actual_bzrdir.user_url, '/baz/') + self.assertEndsWith(actual_bzrdir.user_url, "/baz/") class TestMeta1DirFormat(TestCaseWithTransport): @@ -1016,27 +1102,27 @@ class TestMeta1DirFormat(TestCaseWithTransport): def test_right_base_dirs(self): dir = bzrdir.BzrDirMetaFormat1().initialize(self.get_url()) t = dir.transport - branch_base = t.clone('branch').base + branch_base = t.clone("branch").base self.assertEqual(branch_base, dir.get_branch_transport(None).base) - self.assertEqual(branch_base, - dir.get_branch_transport(BzrBranchFormat5()).base) - repository_base = t.clone('repository').base - self.assertEqual( - repository_base, dir.get_repository_transport(None).base) + self.assertEqual(branch_base, dir.get_branch_transport(BzrBranchFormat5()).base) + repository_base = t.clone("repository").base + self.assertEqual(repository_base, dir.get_repository_transport(None).base) repository_format = repository.format_registry.get_default() - self.assertEqual(repository_base, - dir.get_repository_transport(repository_format).base) - checkout_base = t.clone('checkout').base self.assertEqual( - checkout_base, dir.get_workingtree_transport(None).base) - self.assertEqual(checkout_base, - dir.get_workingtree_transport(workingtree_3.WorkingTreeFormat3()).base) + repository_base, dir.get_repository_transport(repository_format).base + ) + checkout_base = t.clone("checkout").base + self.assertEqual(checkout_base, dir.get_workingtree_transport(None).base) + self.assertEqual( + checkout_base, + dir.get_workingtree_transport(workingtree_3.WorkingTreeFormat3()).base, + ) def test_meta1dir_uses_lockdir(self): """Meta1 format uses a LockDir to guard the whole directory, not a file.""" dir = bzrdir.BzrDirMetaFormat1().initialize(self.get_url()) t = dir.transport - self.assertIsDirectory('branch-lock', t) + self.assertIsDirectory("branch-lock", t) def test_comparison(self): """Equality and inequality behave properly. @@ -1044,42 +1130,38 @@ def test_comparison(self): Metadirs should compare equal iff they have the same repo, branch and tree formats. 
""" - mydir = controldir.format_registry.make_controldir('knit') + mydir = controldir.format_registry.make_controldir("knit") self.assertEqual(mydir, mydir) self.assertEqual(mydir, mydir) - otherdir = controldir.format_registry.make_controldir('knit') + otherdir = controldir.format_registry.make_controldir("knit") self.assertEqual(otherdir, mydir) self.assertEqual(otherdir, mydir) - otherdir2 = controldir.format_registry.make_controldir( - 'development-subtree') + otherdir2 = controldir.format_registry.make_controldir("development-subtree") self.assertNotEqual(otherdir2, mydir) self.assertNotEqual(otherdir2, mydir) def test_with_features(self): - tree = self.make_branch_and_tree('tree', format='2a') + tree = self.make_branch_and_tree("tree", format="2a") tree.controldir.update_feature_flags({b"bar": b"required"}) - self.assertRaises(bzrdir.MissingFeature, bzrdir.BzrDir.open, 'tree') - bzrdir.BzrDirMetaFormat1.register_feature(b'bar') - self.addCleanup(bzrdir.BzrDirMetaFormat1.unregister_feature, b'bar') - dir = bzrdir.BzrDir.open('tree') + self.assertRaises(bzrdir.MissingFeature, bzrdir.BzrDir.open, "tree") + bzrdir.BzrDirMetaFormat1.register_feature(b"bar") + self.addCleanup(bzrdir.BzrDirMetaFormat1.unregister_feature, b"bar") + dir = bzrdir.BzrDir.open("tree") self.assertEqual(b"required", dir._format.features.get(b"bar")) - tree.controldir.update_feature_flags({ - b"bar": None, - b"nonexistant": None}) - dir = bzrdir.BzrDir.open('tree') + tree.controldir.update_feature_flags({b"bar": None, b"nonexistant": None}) + dir = bzrdir.BzrDir.open("tree") self.assertEqual({}, dir._format.features) def test_needs_conversion_different_working_tree(self): # meta1dirs need an conversion if any element is not the default. - new_format = controldir.format_registry.make_controldir('dirstate') - tree = self.make_branch_and_tree('tree', format='knit') - self.assertTrue(tree.controldir.needs_format_conversion( - new_format)) + new_format = controldir.format_registry.make_controldir("dirstate") + tree = self.make_branch_and_tree("tree", format="knit") + self.assertTrue(tree.controldir.needs_format_conversion(new_format)) def test_initialize_on_format_uses_smart_transport(self): self.setup_smart_server_with_call_log() - new_format = controldir.format_registry.make_controldir('dirstate') - transport = self.get_transport('target') + new_format = controldir.format_registry.make_controldir("dirstate") + transport = self.get_transport("target") transport.ensure_base() self.reset_smart_call_log() instance = new_format.initialize_on_transport(transport) @@ -1102,44 +1184,45 @@ def setUp(self): def test_create_branch_convenience(self): # outside a repo the default convenience output is a repo+branch_tree - format = controldir.format_registry.make_controldir('knit') + format = controldir.format_registry.make_controldir("knit") branch = bzrdir.BzrDir.create_branch_convenience( - self.get_url('foo'), format=format) - self.assertRaises(errors.NoWorkingTree, - branch.controldir.open_workingtree) + self.get_url("foo"), format=format + ) + self.assertRaises(errors.NoWorkingTree, branch.controldir.open_workingtree) branch.controldir.open_repository() def test_create_branch_convenience_force_tree_not_local_fails(self): # outside a repo the default convenience output is a repo+branch_tree - format = controldir.format_registry.make_controldir('knit') - self.assertRaises(errors.NotLocalUrl, - bzrdir.BzrDir.create_branch_convenience, - self.get_url('foo'), - force_new_tree=True, - format=format) + format = 
controldir.format_registry.make_controldir("knit") + self.assertRaises( + errors.NotLocalUrl, + bzrdir.BzrDir.create_branch_convenience, + self.get_url("foo"), + force_new_tree=True, + format=format, + ) t = self.get_transport() - self.assertFalse(t.has('foo')) + self.assertFalse(t.has("foo")) def test_clone(self): # clone into a nonlocal path works - format = controldir.format_registry.make_controldir('knit') - branch = bzrdir.BzrDir.create_branch_convenience('local', - format=format) + format = controldir.format_registry.make_controldir("knit") + branch = bzrdir.BzrDir.create_branch_convenience("local", format=format) branch.controldir.open_workingtree() - result = branch.controldir.clone(self.get_url('remote')) - self.assertRaises(errors.NoWorkingTree, - result.open_workingtree) + result = branch.controldir.clone(self.get_url("remote")) + self.assertRaises(errors.NoWorkingTree, result.open_workingtree) result.open_branch() result.open_repository() def test_checkout_metadir(self): # checkout_metadir has reasonable working tree format even when no # working tree is present - self.make_branch('branch-knit2', format='dirstate-with-subtree') - my_bzrdir = bzrdir.BzrDir.open(self.get_url('branch-knit2')) + self.make_branch("branch-knit2", format="dirstate-with-subtree") + my_bzrdir = bzrdir.BzrDir.open(self.get_url("branch-knit2")) checkout_format = my_bzrdir.checkout_metadir() - self.assertIsInstance(checkout_format.workingtree_format, - workingtree_4.WorkingTreeFormat4) + self.assertIsInstance( + checkout_format.workingtree_format, workingtree_4.WorkingTreeFormat4 + ) class TestHTTPRedirectionsBase: @@ -1174,86 +1257,86 @@ def test_loop(self): # Both servers redirect to each other creating a loop self.new_server.redirect_to(self.old_server.host, self.old_server.port) # Starting from either server should loop - old_url = self._qualified_url(self.old_server.host, - self.old_server.port) + old_url = self._qualified_url(self.old_server.host, self.old_server.port) oldt = self._transport(old_url) - self.assertRaises(errors.NotBranchError, - bzrdir.BzrDir.open_from_transport, oldt) - new_url = self._qualified_url(self.new_server.host, - self.new_server.port) + self.assertRaises( + errors.NotBranchError, bzrdir.BzrDir.open_from_transport, oldt + ) + new_url = self._qualified_url(self.new_server.host, self.new_server.port) newt = self._transport(new_url) - self.assertRaises(errors.NotBranchError, - bzrdir.BzrDir.open_from_transport, newt) + self.assertRaises( + errors.NotBranchError, bzrdir.BzrDir.open_from_transport, newt + ) def test_qualifier_preserved(self): - self.make_branch_and_tree('branch') - old_url = self._qualified_url(self.old_server.host, - self.old_server.port) - start = self._transport(old_url).clone('branch') + self.make_branch_and_tree("branch") + old_url = self._qualified_url(self.old_server.host, self.old_server.port) + start = self._transport(old_url).clone("branch") bdir = bzrdir.BzrDir.open_from_transport(start) # Redirection should preserve the qualifier, hence the transport class # itself. 
self.assertIsInstance(bdir.root_transport, type(start)) -class TestHTTPRedirections(TestHTTPRedirectionsBase, - http_utils.TestCaseWithTwoWebservers): +class TestHTTPRedirections( + TestHTTPRedirectionsBase, http_utils.TestCaseWithTwoWebservers +): """Tests redirections for urllib implementation.""" _transport = HttpTransport def _qualified_url(self, host, port): - result = f'http://{host}:{port}' + result = f"http://{host}:{port}" self.permit_url(result) return result -class TestHTTPRedirections_nosmart(TestHTTPRedirectionsBase, - http_utils.TestCaseWithTwoWebservers): +class TestHTTPRedirections_nosmart( + TestHTTPRedirectionsBase, http_utils.TestCaseWithTwoWebservers +): """Tests redirections for the nosmart decorator.""" _transport = NoSmartTransportDecorator def _qualified_url(self, host, port): - result = f'nosmart+http://{host}:{port}' + result = f"nosmart+http://{host}:{port}" self.permit_url(result) return result -class TestHTTPRedirections_readonly(TestHTTPRedirectionsBase, - http_utils.TestCaseWithTwoWebservers): +class TestHTTPRedirections_readonly( + TestHTTPRedirectionsBase, http_utils.TestCaseWithTwoWebservers +): """Tests redirections for readonly decorator.""" _transport = ReadonlyTransportDecorator def _qualified_url(self, host, port): - result = f'readonly+http://{host}:{port}' + result = f"readonly+http://{host}:{port}" self.permit_url(result) return result class TestDotBzrHidden(TestCaseWithTransport): - - ls = ['ls'] - if sys.platform == 'win32': - ls = [os.environ['COMSPEC'], '/C', 'dir', '/B'] + ls = ["ls"] + if sys.platform == "win32": + ls = [os.environ["COMSPEC"], "/C", "dir", "/B"] def get_ls(self): - f = subprocess.Popen(self.ls, stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + f = subprocess.Popen(self.ls, stdout=subprocess.PIPE, stderr=subprocess.PIPE) out, err = f.communicate() - self.assertEqual(0, f.returncode, f'Calling {self.ls} failed: {err}') + self.assertEqual(0, f.returncode, f"Calling {self.ls} failed: {err}") return out.splitlines() def test_dot_bzr_hidden(self): - bzrdir.BzrDir.create('.') - self.build_tree(['a']) - self.assertEqual([b'a'], self.get_ls()) + bzrdir.BzrDir.create(".") + self.build_tree(["a"]) + self.assertEqual([b"a"], self.get_ls()) def test_dot_bzr_hidden_with_url(self): - bzrdir.BzrDir.create(urlutils.local_path_to_url('.')) - self.build_tree(['a']) - self.assertEqual([b'a'], self.get_ls()) + bzrdir.BzrDir.create(urlutils.local_path_to_url(".")) + self.build_tree(["a"]) + self.assertEqual([b"a"], self.get_ls()) class _TestBzrDirFormat(bzrdir.BzrDirMetaFormat1): @@ -1302,11 +1385,11 @@ def __init__(self, transport, *args, **kwargs): self._parent = None def sprout(self, *args, **kwargs): - self.calls.append('sprout') + self.calls.append("sprout") return _TestBranch(self._transport) def copy_content_into(self, destination, revision_id=None): - self.calls.append('copy_content_into') + self.calls.append("copy_content_into") def last_revision(self): return _mod_revision.NULL_REVISION @@ -1315,7 +1398,7 @@ def get_parent(self): return self._parent def _get_config(self): - return config.TransportConfig(self._transport, 'branch.conf') + return config.TransportConfig(self._transport, "branch.conf") def _get_config_store(self): return config.BranchStore(self) @@ -1331,7 +1414,6 @@ def unlock(self): class TestBzrDirSprout(TestCaseWithMemoryTransport): - def test_sprout_uses_branch_sprout(self): """BzrDir.sprout calls Branch.sprout. @@ -1345,7 +1427,7 @@ def test_sprout_uses_branch_sprout(self): overridden to satisfy that.
""" # Make an instrumented bzrdir. - t = self.get_transport('source') + t = self.get_transport("source") t.ensure_base() source_bzrdir = _TestBzrDirFormat().initialize_on_transport(t) # The instrumented bzrdir has a test_branch attribute that logs calls @@ -1354,25 +1436,24 @@ def test_sprout_uses_branch_sprout(self): self.assertEqual([], source_bzrdir.test_branch.calls) # Sprout the bzrdir - target_url = self.get_url('target') - source_bzrdir.sprout(target_url, recurse='no') + target_url = self.get_url("target") + source_bzrdir.sprout(target_url, recurse="no") # The bzrdir called the branch's sprout method. - self.assertSubset(['sprout'], source_bzrdir.test_branch.calls) + self.assertSubset(["sprout"], source_bzrdir.test_branch.calls) def test_sprout_parent(self): - grandparent_tree = self.make_branch('grandparent') - parent = grandparent_tree.controldir.sprout('parent').open_branch() - branch_tree = parent.controldir.sprout('branch').open_branch() - self.assertContainsRe(branch_tree.get_parent(), '/parent/$') + grandparent_tree = self.make_branch("grandparent") + parent = grandparent_tree.controldir.sprout("parent").open_branch() + branch_tree = parent.controldir.sprout("branch").open_branch() + self.assertContainsRe(branch_tree.get_parent(), "/parent/$") class TestBzrDirHooks(TestCaseWithMemoryTransport): - def test_pre_open_called(self): calls = [] - bzrdir.BzrDir.hooks.install_named_hook('pre_open', calls.append, None) - transport = self.get_transport('foo') + bzrdir.BzrDir.hooks.install_named_hook("pre_open", calls.append, None) + transport = self.get_transport("foo") url = transport.base self.assertRaises(errors.NotBranchError, bzrdir.BzrDir.open, url) self.assertEqual([transport.base], [t.base for t in calls]) @@ -1384,32 +1465,34 @@ def fail_once(transport): count[0] += 1 if count[0] == 1: raise errors.BzrError("fail") - bzrdir.BzrDir.hooks.install_named_hook('pre_open', fail_once, None) - transport = self.get_transport('foo') + + bzrdir.BzrDir.hooks.install_named_hook("pre_open", fail_once, None) + transport = self.get_transport("foo") url = transport.base err = self.assertRaises(errors.BzrError, bzrdir.BzrDir.open, url) - self.assertEqual('fail', err._preformatted_string) + self.assertEqual("fail", err._preformatted_string) def test_post_repo_init(self): from ...controldir import RepoInitHookParams + calls = [] - bzrdir.BzrDir.hooks.install_named_hook('post_repo_init', - calls.append, None) - self.make_repository('foo') + bzrdir.BzrDir.hooks.install_named_hook("post_repo_init", calls.append, None) + self.make_repository("foo") self.assertLength(1, calls) params = calls[0] self.assertIsInstance(params, RepoInitHookParams) - self.assertTrue(hasattr(params, 'controldir')) - self.assertTrue(hasattr(params, 'repository')) + self.assertTrue(hasattr(params, "controldir")) + self.assertTrue(hasattr(params, "repository")) def test_post_repo_init_hook_repr(self): param_reprs = [] - bzrdir.BzrDir.hooks.install_named_hook('post_repo_init', - lambda params: param_reprs.append(repr(params)), None) - self.make_repository('foo') + bzrdir.BzrDir.hooks.install_named_hook( + "post_repo_init", lambda params: param_reprs.append(repr(params)), None + ) + self.make_repository("foo") self.assertLength(1, param_reprs) param_repr = param_reprs[0] - self.assertStartsWith(param_repr, 'e' - b'l8:timezonei3600ee' - b'l10:propertiesd11:branch-nick6:+trunkee' - b'l9:timestamp14:1242300770.844e' - b'l11:revision-id50:pqm@pqm.ubuntu.com-20090514113250-jntkkpminfn3e0tze' - b'l10:parent-ids' - b'l' - 
b'50:pqm@pqm.ubuntu.com-20090514104039-kggemn7lrretzpvc' - b'48:jelmer@samba.org-20090510012654-jp9ufxquekaokbeo' - b'ee' - b'l14:inventory-sha140:4a2c7fb50e077699242cf6eb16a61779c7b680a7e' - b'l7:message35:(Jelmer) Move dpush to InterBranch.e' - b'e') +_working_revision_bencode1 = ( + b"l" + b"l6:formati10ee" + b"l9:committer54:Canonical.com Patch Queue Manager <pqm@pqm.ubuntu.com>e" + b"l8:timezonei3600ee" + b"l10:propertiesd11:branch-nick6:+trunkee" + b"l9:timestamp14:1242300770.844e" + b"l11:revision-id50:pqm@pqm.ubuntu.com-20090514113250-jntkkpminfn3e0tze" + b"l10:parent-ids" + b"l" + b"50:pqm@pqm.ubuntu.com-20090514104039-kggemn7lrretzpvc" + b"48:jelmer@samba.org-20090510012654-jp9ufxquekaokbeo" + b"ee" + b"l14:inventory-sha140:4a2c7fb50e077699242cf6eb16a61779c7b680a7e" + b"l7:message35:(Jelmer) Move dpush to InterBranch.e" + b"e" +) -_working_revision_bencode1_no_timezone = (b'l' - b'l6:formati10ee' - b'l9:committer54:Canonical.com Patch Queue Manager <pqm@pqm.ubuntu.com>e' - b'l9:timestamp14:1242300770.844e' - b'l10:propertiesd11:branch-nick6:+trunkee' - b'l11:revision-id50:pqm@pqm.ubuntu.com-20090514113250-jntkkpminfn3e0tze' - b'l10:parent-ids' - b'l' - b'50:pqm@pqm.ubuntu.com-20090514104039-kggemn7lrretzpvc' - b'48:jelmer@samba.org-20090510012654-jp9ufxquekaokbeo' - b'ee' - b'l14:inventory-sha140:4a2c7fb50e077699242cf6eb16a61779c7b680a7e' - b'l7:message35:(Jelmer) Move dpush to InterBranch.e' - b'e') +_working_revision_bencode1_no_timezone = ( + b"l" + b"l6:formati10ee" + b"l9:committer54:Canonical.com Patch Queue Manager <pqm@pqm.ubuntu.com>e" + b"l9:timestamp14:1242300770.844e" + b"l10:propertiesd11:branch-nick6:+trunkee" + b"l11:revision-id50:pqm@pqm.ubuntu.com-20090514113250-jntkkpminfn3e0tze" + b"l10:parent-ids" + b"l" + b"50:pqm@pqm.ubuntu.com-20090514104039-kggemn7lrretzpvc" + b"48:jelmer@samba.org-20090510012654-jp9ufxquekaokbeo" + b"ee" + b"l14:inventory-sha140:4a2c7fb50e077699242cf6eb16a61779c7b680a7e" + b"l7:message35:(Jelmer) Move dpush to InterBranch.e" + b"e" +) class TestBEncodeSerializer1(TestCase): @@ -56,45 +60,68 @@ class TestBEncodeSerializer1(TestCase): def test_unpack_revision(self): """Test unpacking a revision.""" rev = revision_bencode_serializer.read_revision_from_string( - _working_revision_bencode1) - self.assertEqual(rev.committer, - "Canonical.com Patch Queue Manager <pqm@pqm.ubuntu.com>") - self.assertEqual(rev.inventory_sha1, - b"4a2c7fb50e077699242cf6eb16a61779c7b680a7") - self.assertEqual([b"pqm@pqm.ubuntu.com-20090514104039-kggemn7lrretzpvc", - b"jelmer@samba.org-20090510012654-jp9ufxquekaokbeo"], - rev.parent_ids) + _working_revision_bencode1 + ) + self.assertEqual( + rev.committer, "Canonical.com Patch Queue Manager <pqm@pqm.ubuntu.com>" + ) + self.assertEqual( + rev.inventory_sha1, b"4a2c7fb50e077699242cf6eb16a61779c7b680a7" + ) + self.assertEqual( + [ + b"pqm@pqm.ubuntu.com-20090514104039-kggemn7lrretzpvc", + b"jelmer@samba.org-20090510012654-jp9ufxquekaokbeo", + ], + rev.parent_ids, + ) self.assertEqual("(Jelmer) Move dpush to InterBranch.", rev.message) - self.assertEqual(b"pqm@pqm.ubuntu.com-20090514113250-jntkkpminfn3e0tz", - rev.revision_id) + self.assertEqual( + b"pqm@pqm.ubuntu.com-20090514113250-jntkkpminfn3e0tz", rev.revision_id + ) self.assertEqual({"branch-nick": "+trunk"}, rev.properties) self.assertEqual(3600, rev.timezone) def test_written_form_matches(self): rev = revision_bencode_serializer.read_revision_from_string( - _working_revision_bencode1) + _working_revision_bencode1 + ) as_str = revision_bencode_serializer.write_revision_to_string(rev) self.assertEqualDiff(_working_revision_bencode1, as_str) def 
test_unpack_revision_no_timezone(self): rev = revision_bencode_serializer.read_revision_from_string( - _working_revision_bencode1_no_timezone) + _working_revision_bencode1_no_timezone + ) self.assertEqual(None, rev.timezone) def assertRoundTrips(self, serializer, orig_rev): lines = serializer.write_revision_to_lines(orig_rev) - new_rev = serializer.read_revision_from_string(b''.join(lines)) + new_rev = serializer.read_revision_from_string(b"".join(lines)) self.assertEqual(orig_rev, new_rev) def test_roundtrips_non_ascii(self): rev = Revision( - b"revid1", message="\n\xe5me", - committer='Erik B\xe5gfors', timestamp=1242385452, + b"revid1", + message="\n\xe5me", + committer="Erik B\xe5gfors", + timestamp=1242385452, inventory_sha1=b"4a2c7fb50e077699242cf6eb16a61779c7b680a7", - parent_ids=[], properties={}, - timezone=3600) + parent_ids=[], + properties={}, + timezone=3600, + ) self.assertRoundTrips(revision_bencode_serializer, rev) def test_roundtrips_xml_invalid_chars(self): - rev = Revision(b"revid1", properties={}, parent_ids=[], message="\t\ue000", committer='Erik B\xe5gfors', timestamp=1242385452, timezone=3600, inventory_sha1=b"4a2c7fb50e077699242cf6eb16a61779c7b680a7") + rev = Revision( + b"revid1", + properties={}, + parent_ids=[], + message="\t\ue000", + committer="Erik B\xe5gfors", + timestamp=1242385452, + timezone=3600, + inventory_sha1=b"4a2c7fb50e077699242cf6eb16a61779c7b680a7", + ) self.assertRoundTrips(revision_bencode_serializer, rev) diff --git a/breezy/bzr/tests/test_conflicts.py b/breezy/bzr/tests/test_conflicts.py index d4d8619b38..c1a607a775 100644 --- a/breezy/bzr/tests/test_conflicts.py +++ b/breezy/bzr/tests/test_conflicts.py @@ -27,15 +27,13 @@ class TestPerConflict(tests.TestCase): - scenarios = scenarios.multiply_scenarios(vary_by_conflicts()) def test_stringification(self): text = str(self.conflict) self.assertContainsString(text, self.conflict.path) self.assertContainsString(text.lower(), "conflict") - self.assertContainsString(repr(self.conflict), - self.conflict.__class__.__name__) + self.assertContainsString(repr(self.conflict), self.conflict.__class__.__name__) def test_stanza_roundtrip(self): p = self.conflict @@ -47,79 +45,81 @@ def test_stanza_roundtrip(self): if o.file_id is not None: self.assertIsInstance(o.file_id, bytes) - conflict_path = getattr(o, 'conflict_path', None) + conflict_path = getattr(o, "conflict_path", None) if conflict_path is not None: self.assertIsInstance(conflict_path, str) - conflict_file_id = getattr(o, 'conflict_file_id', None) + conflict_file_id = getattr(o, "conflict_file_id", None) if conflict_file_id is not None: self.assertIsInstance(conflict_file_id, bytes) def test_stanzification(self): stanza = self.conflict.as_stanza() - if 'file_id' in stanza: + if "file_id" in stanza: # In Stanza form, the file_id has to be unicode. 
- self.assertStartsWith(stanza.get('file_id'), '\xeed') - self.assertStartsWith(stanza.get('path'), 'p\xe5th') - if 'conflict_path' in stanza: - self.assertStartsWith(stanza.get('conflict_path'), 'p\xe5th') - if 'conflict_file_id' in stanza: - self.assertStartsWith(stanza.get('conflict_file_id'), '\xeed') + self.assertStartsWith(stanza.get("file_id"), "\xeed") + self.assertStartsWith(stanza.get("path"), "p\xe5th") + if "conflict_path" in stanza: + self.assertStartsWith(stanza.get("conflict_path"), "p\xe5th") + if "conflict_file_id" in stanza: + self.assertStartsWith(stanza.get("conflict_file_id"), "\xeed") class TestConflicts(tests.TestCaseWithTransport): - def test_resolve_conflict_dir(self): - tree = self.make_branch_and_tree('.') - self.build_tree_contents([('hello', b'hello world4'), - ('hello.THIS', b'hello world2'), - ('hello.BASE', b'hello world1'), - ]) - os.mkdir('hello.OTHER') - tree.add('hello', ids=b'q') - l = bzr_conflicts.ConflictList([bzr_conflicts.TextConflict('hello')]) + tree = self.make_branch_and_tree(".") + self.build_tree_contents( + [ + ("hello", b"hello world4"), + ("hello.THIS", b"hello world2"), + ("hello.BASE", b"hello world1"), + ] + ) + os.mkdir("hello.OTHER") + tree.add("hello", ids=b"q") + l = bzr_conflicts.ConflictList([bzr_conflicts.TextConflict("hello")]) l.remove_files(tree) def test_select_conflicts(self): - tree = self.make_branch_and_tree('.') + tree = self.make_branch_and_tree(".") clist = bzr_conflicts.ConflictList def check_select(not_selected, selected, paths, **kwargs): self.assertEqual( (not_selected, selected), - tree_conflicts.select_conflicts(tree, paths, **kwargs)) + tree_conflicts.select_conflicts(tree, paths, **kwargs), + ) - foo = bzr_conflicts.ContentsConflict('foo') - bar = bzr_conflicts.ContentsConflict('bar') + foo = bzr_conflicts.ContentsConflict("foo") + bar = bzr_conflicts.ContentsConflict("bar") tree_conflicts = clist([foo, bar]) - check_select(clist([bar]), clist([foo]), ['foo']) - check_select(clist(), tree_conflicts, - [''], ignore_misses=True, recurse=True) + check_select(clist([bar]), clist([foo]), ["foo"]) + check_select(clist(), tree_conflicts, [""], ignore_misses=True, recurse=True) - foobaz = bzr_conflicts.ContentsConflict('foo/baz') + foobaz = bzr_conflicts.ContentsConflict("foo/baz") tree_conflicts = clist([foobaz, bar]) - check_select(clist([bar]), clist([foobaz]), - ['foo'], ignore_misses=True, recurse=True) + check_select( + clist([bar]), clist([foobaz]), ["foo"], ignore_misses=True, recurse=True + ) - qux = bzr_conflicts.PathConflict('qux', 'foo/baz') + qux = bzr_conflicts.PathConflict("qux", "foo/baz") tree_conflicts = clist([qux]) - check_select(clist(), tree_conflicts, - ['foo'], ignore_misses=True, recurse=True) - check_select(tree_conflicts, clist(), ['foo'], ignore_misses=True) + check_select(clist(), tree_conflicts, ["foo"], ignore_misses=True, recurse=True) + check_select(tree_conflicts, clist(), ["foo"], ignore_misses=True) def test_resolve_conflicts_recursive(self): - tree = self.make_branch_and_tree('.') - self.build_tree(['dir/', 'dir/hello']) - tree.add(['dir', 'dir/hello']) + tree = self.make_branch_and_tree(".") + self.build_tree(["dir/", "dir/hello"]) + tree.add(["dir", "dir/hello"]) - dirhello = [bzr_conflicts.TextConflict('dir/hello')] + dirhello = [bzr_conflicts.TextConflict("dir/hello")] tree.set_conflicts(dirhello) - resolve(tree, ['dir'], recursive=False, ignore_misses=True) + resolve(tree, ["dir"], recursive=False, ignore_misses=True) self.assertEqual(dirhello, tree.conflicts()) - resolve(tree, 
['dir'], recursive=True, ignore_misses=True) + resolve(tree, ["dir"], recursive=True, ignore_misses=True) self.assertEqual(bzr_conflicts.ConflictList([]), tree.conflicts()) diff --git a/breezy/bzr/tests/test_dirstate.py b/breezy/bzr/tests/test_dirstate.py index 025f26d855..0da82b41af 100644 --- a/breezy/bzr/tests/test_dirstate.py +++ b/breezy/bzr/tests/test_dirstate.py @@ -43,13 +43,15 @@ class TestErrors(tests.TestCase): - def test_dirstate_corrupt(self): - error = dirstate.DirstateCorrupt('.bzr/checkout/dirstate', - 'trailing garbage: "x"') - self.assertEqualDiff("The dirstate file (.bzr/checkout/dirstate)" - " appears to be corrupt: trailing garbage: \"x\"", - str(error)) + error = dirstate.DirstateCorrupt( + ".bzr/checkout/dirstate", 'trailing garbage: "x"' + ) + self.assertEqualDiff( + "The dirstate file (.bzr/checkout/dirstate)" + ' appears to be corrupt: trailing garbage: "x"', + str(error), + ) load_tests = load_tests_apply_scenarios @@ -66,23 +68,25 @@ class TestCaseWithDirState(tests.TestCaseWithTransport): def setUp(self): super().setUp() - self.overrideAttr(osutils, - '_selected_dir_reader', self._dir_reader_class()) + self.overrideAttr(osutils, "_selected_dir_reader", self._dir_reader_class()) def create_empty_dirstate(self): """Return a locked but empty dirstate.""" - state = dirstate.DirState.initialize('dirstate') + state = dirstate.DirState.initialize("dirstate") return state def create_dirstate_with_root(self): """Return a write-locked state with a single root entry.""" - packed_stat = b'AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk' - root_entry_direntry = (b'', b'', b'a-root-value'), [ - (b'd', b'', 0, False, packed_stat), - ] + packed_stat = b"AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk" + root_entry_direntry = ( + (b"", b"", b"a-root-value"), + [ + (b"d", b"", 0, False, packed_stat), + ], + ) dirblocks = [] - dirblocks.append((b'', [root_entry_direntry])) - dirblocks.append((b'', [])) + dirblocks.append((b"", [root_entry_direntry])) + dirblocks.append((b"", [])) state = self.create_empty_dirstate() try: state._set_data([], dirblocks) @@ -94,10 +98,13 @@ def create_dirstate_with_root(self): def create_dirstate_with_root_and_subdir(self): """Return a locked DirState with a root and a subdir.""" - packed_stat = b'AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk' - subdir_entry = (b'', b'subdir', b'subdir-id'), [ - (b'd', b'', 0, False, packed_stat), - ] + packed_stat = b"AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk" + subdir_entry = ( + (b"", b"subdir", b"subdir-id"), + [ + (b"d", b"", 0, False, packed_stat), + ], + ) state = self.create_dirstate_with_root() try: dirblocks = list(state._dirblocks) @@ -125,41 +132,68 @@ def create_complex_dirstate(self): :return: The dirstate, still write-locked. 
""" - packed_stat = b'AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk' - null_sha = b'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx' - root_entry = (b'', b'', b'a-root-value'), [ - (b'd', b'', 0, False, packed_stat), - ] - a_entry = (b'', b'a', b'a-dir'), [ - (b'd', b'', 0, False, packed_stat), - ] - b_entry = (b'', b'b', b'b-dir'), [ - (b'd', b'', 0, False, packed_stat), - ] - c_entry = (b'', b'c', b'c-file'), [ - (b'f', null_sha, 10, False, packed_stat), - ] - d_entry = (b'', b'd', b'd-file'), [ - (b'f', null_sha, 20, False, packed_stat), - ] - e_entry = (b'a', b'e', b'e-dir'), [ - (b'd', b'', 0, False, packed_stat), - ] - f_entry = (b'a', b'f', b'f-file'), [ - (b'f', null_sha, 30, False, packed_stat), - ] - g_entry = (b'b', b'g', b'g-file'), [ - (b'f', null_sha, 30, False, packed_stat), - ] - h_entry = (b'b', b'h\xc3\xa5', b'h-\xc3\xa5-file'), [ - (b'f', null_sha, 40, False, packed_stat), - ] + packed_stat = b"AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk" + null_sha = b"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + root_entry = ( + (b"", b"", b"a-root-value"), + [ + (b"d", b"", 0, False, packed_stat), + ], + ) + a_entry = ( + (b"", b"a", b"a-dir"), + [ + (b"d", b"", 0, False, packed_stat), + ], + ) + b_entry = ( + (b"", b"b", b"b-dir"), + [ + (b"d", b"", 0, False, packed_stat), + ], + ) + c_entry = ( + (b"", b"c", b"c-file"), + [ + (b"f", null_sha, 10, False, packed_stat), + ], + ) + d_entry = ( + (b"", b"d", b"d-file"), + [ + (b"f", null_sha, 20, False, packed_stat), + ], + ) + e_entry = ( + (b"a", b"e", b"e-dir"), + [ + (b"d", b"", 0, False, packed_stat), + ], + ) + f_entry = ( + (b"a", b"f", b"f-file"), + [ + (b"f", null_sha, 30, False, packed_stat), + ], + ) + g_entry = ( + (b"b", b"g", b"g-file"), + [ + (b"f", null_sha, 30, False, packed_stat), + ], + ) + h_entry = ( + (b"b", b"h\xc3\xa5", b"h-\xc3\xa5-file"), + [ + (b"f", null_sha, 40, False, packed_stat), + ], + ) dirblocks = [] - dirblocks.append((b'', [root_entry])) - dirblocks.append((b'', [a_entry, b_entry, c_entry, d_entry])) - dirblocks.append((b'a', [e_entry, f_entry])) - dirblocks.append((b'b', [g_entry, h_entry])) - state = dirstate.DirState.initialize('dirstate') + dirblocks.append((b"", [root_entry])) + dirblocks.append((b"", [a_entry, b_entry, c_entry, d_entry])) + dirblocks.append((b"a", [e_entry, f_entry])) + dirblocks.append((b"b", [g_entry, h_entry])) + state = dirstate.DirState.initialize("dirstate") state._validate() try: state._set_data([], dirblocks) @@ -193,7 +227,7 @@ def check_state_with_reopen(self, expected_result, state): finally: state.unlock() del state - state = dirstate.DirState.on_file('dirstate') + state = dirstate.DirState.on_file("dirstate") state.lock_read() try: self.assertEqual(expected_result[1], list(state._iter_entries())) @@ -211,83 +245,105 @@ def create_basic_dirstate(self): b-c f """ - tree = self.make_branch_and_tree('tree') - paths = ['a', 'b/', 'b/c', 'b/d/', 'b/d/e', 'b-c', 'f'] - file_ids = [b'a-id', b'b-id', b'c-id', - b'd-id', b'e-id', b'b-c-id', b'f-id'] - self.build_tree(['tree/' + p for p in paths]) - tree.set_root_id(b'TREE_ROOT') - tree.add([p.rstrip('/') for p in paths], ids=file_ids) - tree.commit('initial', rev_id=b'rev-1') - revision_id = b'rev-1' + tree = self.make_branch_and_tree("tree") + paths = ["a", "b/", "b/c", "b/d/", "b/d/e", "b-c", "f"] + file_ids = [b"a-id", b"b-id", b"c-id", b"d-id", b"e-id", b"b-c-id", b"f-id"] + self.build_tree(["tree/" + p for p in paths]) + tree.set_root_id(b"TREE_ROOT") + tree.add([p.rstrip("/") for p in paths], ids=file_ids) + tree.commit("initial", rev_id=b"rev-1") + revision_id = 
b"rev-1" # a_packed_stat = dirstate.pack_stat(os.stat('tree/a')) - t = self.get_transport('tree') - a_text = t.get_bytes('a') + t = self.get_transport("tree") + a_text = t.get_bytes("a") a_sha = osutils.sha_string(a_text) a_len = len(a_text) # b_packed_stat = dirstate.pack_stat(os.stat('tree/b')) # c_packed_stat = dirstate.pack_stat(os.stat('tree/b/c')) - c_text = t.get_bytes('b/c') + c_text = t.get_bytes("b/c") c_sha = osutils.sha_string(c_text) c_len = len(c_text) # d_packed_stat = dirstate.pack_stat(os.stat('tree/b/d')) # e_packed_stat = dirstate.pack_stat(os.stat('tree/b/d/e')) - e_text = t.get_bytes('b/d/e') + e_text = t.get_bytes("b/d/e") e_sha = osutils.sha_string(e_text) e_len = len(e_text) - b_c_text = t.get_bytes('b-c') + b_c_text = t.get_bytes("b-c") b_c_sha = osutils.sha_string(b_c_text) b_c_len = len(b_c_text) # f_packed_stat = dirstate.pack_stat(os.stat('tree/f')) - f_text = t.get_bytes('f') + f_text = t.get_bytes("f") f_sha = osutils.sha_string(f_text) f_len = len(f_text) null_stat = dirstate.DirState.NULLSTAT expected = { - b'': ((b'', b'', b'TREE_ROOT'), [ - (b'd', b'', 0, False, null_stat), - (b'd', b'', 0, False, revision_id), - ]), - b'a': ((b'', b'a', b'a-id'), [ - (b'f', b'', 0, False, null_stat), - (b'f', a_sha, a_len, False, revision_id), - ]), - b'b': ((b'', b'b', b'b-id'), [ - (b'd', b'', 0, False, null_stat), - (b'd', b'', 0, False, revision_id), - ]), - b'b/c': ((b'b', b'c', b'c-id'), [ - (b'f', b'', 0, False, null_stat), - (b'f', c_sha, c_len, False, revision_id), - ]), - b'b/d': ((b'b', b'd', b'd-id'), [ - (b'd', b'', 0, False, null_stat), - (b'd', b'', 0, False, revision_id), - ]), - b'b/d/e': ((b'b/d', b'e', b'e-id'), [ - (b'f', b'', 0, False, null_stat), - (b'f', e_sha, e_len, False, revision_id), - ]), - b'b-c': ((b'', b'b-c', b'b-c-id'), [ - (b'f', b'', 0, False, null_stat), - (b'f', b_c_sha, b_c_len, False, revision_id), - ]), - b'f': ((b'', b'f', b'f-id'), [ - (b'f', b'', 0, False, null_stat), - (b'f', f_sha, f_len, False, revision_id), - ]), + b"": ( + (b"", b"", b"TREE_ROOT"), + [ + (b"d", b"", 0, False, null_stat), + (b"d", b"", 0, False, revision_id), + ], + ), + b"a": ( + (b"", b"a", b"a-id"), + [ + (b"f", b"", 0, False, null_stat), + (b"f", a_sha, a_len, False, revision_id), + ], + ), + b"b": ( + (b"", b"b", b"b-id"), + [ + (b"d", b"", 0, False, null_stat), + (b"d", b"", 0, False, revision_id), + ], + ), + b"b/c": ( + (b"b", b"c", b"c-id"), + [ + (b"f", b"", 0, False, null_stat), + (b"f", c_sha, c_len, False, revision_id), + ], + ), + b"b/d": ( + (b"b", b"d", b"d-id"), + [ + (b"d", b"", 0, False, null_stat), + (b"d", b"", 0, False, revision_id), + ], + ), + b"b/d/e": ( + (b"b/d", b"e", b"e-id"), + [ + (b"f", b"", 0, False, null_stat), + (b"f", e_sha, e_len, False, revision_id), + ], + ), + b"b-c": ( + (b"", b"b-c", b"b-c-id"), + [ + (b"f", b"", 0, False, null_stat), + (b"f", b_c_sha, b_c_len, False, revision_id), + ], + ), + b"f": ( + (b"", b"f", b"f-id"), + [ + (b"f", b"", 0, False, null_stat), + (b"f", f_sha, f_len, False, revision_id), + ], + ), } - state = dirstate.DirState.from_tree(tree, 'dirstate') + state = dirstate.DirState.from_tree(tree, "dirstate") try: state.save() finally: state.unlock() # Use a different object, to make sure nothing is pre-cached in memory. 
- state = dirstate.DirState.on_file('dirstate') + state = dirstate.DirState.on_file("dirstate") state.lock_read() self.addCleanup(state.unlock) - self.assertEqual(dirstate.DirState.NOT_IN_MEMORY, - state._dirblock_state) + self.assertEqual(dirstate.DirState.NOT_IN_MEMORY, state._dirblock_state) # This code is only really tested if we actually have to make more # than one read, so set the page size to something smaller. # We want it to contain about 2.2 records, so that we have a couple @@ -304,27 +360,41 @@ def create_duplicated_dirstate(self): tree, state, expected = self.create_basic_dirstate() # Now we will just remove and add every file so we get an extra entry # per entry. Unversion in reverse order so we handle subdirs - tree.unversion(['f', 'b-c', 'b/d/e', 'b/d', 'b/c', 'b', 'a']) - tree.add(['a', 'b', 'b/c', 'b/d', 'b/d/e', 'b-c', 'f'], - ids=[b'a-id2', b'b-id2', b'c-id2', b'd-id2', b'e-id2', b'b-c-id2', b'f-id2']) + tree.unversion(["f", "b-c", "b/d/e", "b/d", "b/c", "b", "a"]) + tree.add( + ["a", "b", "b/c", "b/d", "b/d/e", "b-c", "f"], + ids=[ + b"a-id2", + b"b-id2", + b"c-id2", + b"d-id2", + b"e-id2", + b"b-c-id2", + b"f-id2", + ], + ) # Update the expected dictionary. - for path in [b'a', b'b', b'b/c', b'b/d', b'b/d/e', b'b-c', b'f']: + for path in [b"a", b"b", b"b/c", b"b/d", b"b/d/e", b"b-c", b"f"]: orig = expected[path] - path2 = path + b'2' + path2 = path + b"2" # This record was deleted in the current tree - expected[path] = (orig[0], [dirstate.DirState.NULL_PARENT_DETAILS, - orig[1][1]]) - new_key = (orig[0][0], orig[0][1], orig[0][2] + b'2') + expected[path] = ( + orig[0], + [dirstate.DirState.NULL_PARENT_DETAILS, orig[1][1]], + ) + new_key = (orig[0][0], orig[0][1], orig[0][2] + b"2") # And didn't exist in the basis tree - expected[path2] = (new_key, [orig[1][0], - dirstate.DirState.NULL_PARENT_DETAILS]) + expected[path2] = ( + new_key, + [orig[1][0], dirstate.DirState.NULL_PARENT_DETAILS], + ) # We will replace the 'dirstate' file underneath 'state', but that is # okay as long as we unlock 'state' first.
state.unlock() try: - new_state = dirstate.DirState.from_tree(tree, 'dirstate') + new_state = dirstate.DirState.from_tree(tree, "dirstate") try: new_state.save() finally: @@ -342,30 +412,33 @@ def create_renamed_dirstate(self): """ tree, state, expected = self.create_basic_dirstate() # Rename a file - tree.rename_one('a', 'b/g') + tree.rename_one("a", "b/g") # And a directory - tree.rename_one('b/d', 'h') - - old_a = expected[b'a'] - expected[b'a'] = ( - old_a[0], [(b'r', b'b/g', 0, False, b''), old_a[1][1]]) - expected[b'b/g'] = ((b'b', b'g', b'a-id'), [old_a[1][0], - (b'r', b'a', 0, False, b'')]) - old_d = expected[b'b/d'] - expected[b'b/d'] = (old_d[0], - [(b'r', b'h', 0, False, b''), old_d[1][1]]) - expected[b'h'] = ((b'', b'h', b'd-id'), [old_d[1][0], - (b'r', b'b/d', 0, False, b'')]) - - old_e = expected[b'b/d/e'] - expected[b'b/d/e'] = (old_e[0], [(b'r', b'h/e', 0, False, b''), - old_e[1][1]]) - expected[b'h/e'] = ((b'h', b'e', b'e-id'), [old_e[1][0], - (b'r', b'b/d/e', 0, False, b'')]) + tree.rename_one("b/d", "h") + + old_a = expected[b"a"] + expected[b"a"] = (old_a[0], [(b"r", b"b/g", 0, False, b""), old_a[1][1]]) + expected[b"b/g"] = ( + (b"b", b"g", b"a-id"), + [old_a[1][0], (b"r", b"a", 0, False, b"")], + ) + old_d = expected[b"b/d"] + expected[b"b/d"] = (old_d[0], [(b"r", b"h", 0, False, b""), old_d[1][1]]) + expected[b"h"] = ( + (b"", b"h", b"d-id"), + [old_d[1][0], (b"r", b"b/d", 0, False, b"")], + ) + + old_e = expected[b"b/d/e"] + expected[b"b/d/e"] = (old_e[0], [(b"r", b"h/e", 0, False, b""), old_e[1][1]]) + expected[b"h/e"] = ( + (b"h", b"e", b"e-id"), + [old_e[1][0], (b"r", b"b/d/e", 0, False, b"")], + ) state.unlock() try: - new_state = dirstate.DirState.from_tree(tree, 'dirstate') + new_state = dirstate.DirState.from_tree(tree, "dirstate") try: new_state.save() finally: @@ -376,30 +449,55 @@ def create_renamed_dirstate(self): class TestTreeToDirState(TestCaseWithDirState): - def test_empty_to_dirstate(self): """We should be able to create a dirstate for an empty tree.""" # There are no files on disk and no parents - tree = self.make_branch_and_tree('tree') - expected_result = ([], [ - ((b'', b'', tree.path2id('')), # common details - [(b'd', b'', 0, False, dirstate.DirState.NULLSTAT), # current tree - ])]) - state = dirstate.DirState.from_tree(tree, 'dirstate') + tree = self.make_branch_and_tree("tree") + expected_result = ( + [], + [ + ( + (b"", b"", tree.path2id("")), # common details + [ + ( + b"d", + b"", + 0, + False, + dirstate.DirState.NULLSTAT, + ), # current tree + ], + ) + ], + ) + state = dirstate.DirState.from_tree(tree, "dirstate") state._validate() self.check_state_with_reopen(expected_result, state) def test_1_parents_empty_to_dirstate(self): # create a parent by doing a commit - tree = self.make_branch_and_tree('tree') - rev_id = tree.commit('first post') + tree = self.make_branch_and_tree("tree") + rev_id = tree.commit("first post") dirstate.pack_stat(os.stat(tree.basedir)) - expected_result = ([rev_id], [ - ((b'', b'', tree.path2id('')), # common details - [(b'd', b'', 0, False, dirstate.DirState.NULLSTAT), # current tree - (b'd', b'', 0, False, rev_id), # first parent details - ])]) - state = dirstate.DirState.from_tree(tree, 'dirstate') + expected_result = ( + [rev_id], + [ + ( + (b"", b"", tree.path2id("")), # common details + [ + ( + b"d", + b"", + 0, + False, + dirstate.DirState.NULLSTAT, + ), # current tree + (b"d", b"", 0, False, rev_id), # first parent details + ], + ) + ], + ) + state = dirstate.DirState.from_tree(tree, "dirstate") 
self.check_state_with_reopen(expected_result, state) state.lock_read() try: @@ -409,18 +507,31 @@ def test_1_parents_empty_to_dirstate(self): def test_2_parents_empty_to_dirstate(self): # create a parent by doing a commit - tree = self.make_branch_and_tree('tree') - rev_id = tree.commit('first post') - tree2 = tree.controldir.sprout('tree2').open_workingtree() - rev_id2 = tree2.commit('second post', allow_pointless=True) + tree = self.make_branch_and_tree("tree") + rev_id = tree.commit("first post") + tree2 = tree.controldir.sprout("tree2").open_workingtree() + rev_id2 = tree2.commit("second post", allow_pointless=True) tree.merge_from_branch(tree2.branch) - expected_result = ([rev_id, rev_id2], [ - ((b'', b'', tree.path2id('')), # common details - [(b'd', b'', 0, False, dirstate.DirState.NULLSTAT), # current tree - (b'd', b'', 0, False, rev_id), # first parent details - (b'd', b'', 0, False, rev_id), # second parent details - ])]) - state = dirstate.DirState.from_tree(tree, 'dirstate') + expected_result = ( + [rev_id, rev_id2], + [ + ( + (b"", b"", tree.path2id("")), # common details + [ + ( + b"d", + b"", + 0, + False, + dirstate.DirState.NULLSTAT, + ), # current tree + (b"d", b"", 0, False, rev_id), # first parent details + (b"d", b"", 0, False, rev_id), # second parent details + ], + ) + ], + ) + state = dirstate.DirState.from_tree(tree, "dirstate") self.check_state_with_reopen(expected_result, state) state.lock_read() try: @@ -431,85 +542,158 @@ def test_2_parents_empty_to_dirstate(self): def test_empty_unknowns_are_ignored_to_dirstate(self): """We should be able to create a dirstate for an empty tree.""" # There are no files on disk and no parents - tree = self.make_branch_and_tree('tree') - self.build_tree(['tree/unknown']) - expected_result = ([], [ - ((b'', b'', tree.path2id('')), # common details - [(b'd', b'', 0, False, dirstate.DirState.NULLSTAT), # current tree - ])]) - state = dirstate.DirState.from_tree(tree, 'dirstate') + tree = self.make_branch_and_tree("tree") + self.build_tree(["tree/unknown"]) + expected_result = ( + [], + [ + ( + (b"", b"", tree.path2id("")), # common details + [ + ( + b"d", + b"", + 0, + False, + dirstate.DirState.NULLSTAT, + ), # current tree + ], + ) + ], + ) + state = dirstate.DirState.from_tree(tree, "dirstate") self.check_state_with_reopen(expected_result, state) def get_tree_with_a_file(self): - tree = self.make_branch_and_tree('tree') - self.build_tree(['tree/a file']) - tree.add('a file', ids=b'a-file-id') + tree = self.make_branch_and_tree("tree") + self.build_tree(["tree/a file"]) + tree.add("a file", ids=b"a-file-id") return tree def test_non_empty_no_parents_to_dirstate(self): """We should be able to create a dirstate for an empty tree.""" # There are files on disk and no parents tree = self.get_tree_with_a_file() - expected_result = ([], [ - ((b'', b'', tree.path2id('')), # common details - [(b'd', b'', 0, False, dirstate.DirState.NULLSTAT), # current tree - ]), - ((b'', b'a file', b'a-file-id'), # common - [(b'f', b'', 0, False, dirstate.DirState.NULLSTAT), # current - ]), - ]) - state = dirstate.DirState.from_tree(tree, 'dirstate') + expected_result = ( + [], + [ + ( + (b"", b"", tree.path2id("")), # common details + [ + ( + b"d", + b"", + 0, + False, + dirstate.DirState.NULLSTAT, + ), # current tree + ], + ), + ( + (b"", b"a file", b"a-file-id"), # common + [ + (b"f", b"", 0, False, dirstate.DirState.NULLSTAT), # current + ], + ), + ], + ) + state = dirstate.DirState.from_tree(tree, "dirstate") 
self.check_state_with_reopen(expected_result, state) def test_1_parents_not_empty_to_dirstate(self): # create a parent by doing a commit tree = self.get_tree_with_a_file() - rev_id = tree.commit('first post') + rev_id = tree.commit("first post") # change the current content to be different this will alter stat, sha # and length: - self.build_tree_contents([('tree/a file', b'new content\n')]) - expected_result = ([rev_id], [ - ((b'', b'', tree.path2id('')), # common details - [(b'd', b'', 0, False, dirstate.DirState.NULLSTAT), # current tree - (b'd', b'', 0, False, rev_id), # first parent details - ]), - ((b'', b'a file', b'a-file-id'), # common - [(b'f', b'', 0, False, dirstate.DirState.NULLSTAT), # current - (b'f', b'c3ed76e4bfd45ff1763ca206055bca8e9fc28aa8', 24, False, - rev_id), # first parent - ]), - ]) - state = dirstate.DirState.from_tree(tree, 'dirstate') + self.build_tree_contents([("tree/a file", b"new content\n")]) + expected_result = ( + [rev_id], + [ + ( + (b"", b"", tree.path2id("")), # common details + [ + ( + b"d", + b"", + 0, + False, + dirstate.DirState.NULLSTAT, + ), # current tree + (b"d", b"", 0, False, rev_id), # first parent details + ], + ), + ( + (b"", b"a file", b"a-file-id"), # common + [ + (b"f", b"", 0, False, dirstate.DirState.NULLSTAT), # current + ( + b"f", + b"c3ed76e4bfd45ff1763ca206055bca8e9fc28aa8", + 24, + False, + rev_id, + ), # first parent + ], + ), + ], + ) + state = dirstate.DirState.from_tree(tree, "dirstate") self.check_state_with_reopen(expected_result, state) def test_2_parents_not_empty_to_dirstate(self): # create a parent by doing a commit tree = self.get_tree_with_a_file() - rev_id = tree.commit('first post') - tree2 = tree.controldir.sprout('tree2').open_workingtree() + rev_id = tree.commit("first post") + tree2 = tree.controldir.sprout("tree2").open_workingtree() # change the current content to be different this will alter stat, sha # and length: - self.build_tree_contents([('tree2/a file', b'merge content\n')]) - rev_id2 = tree2.commit('second post') + self.build_tree_contents([("tree2/a file", b"merge content\n")]) + rev_id2 = tree2.commit("second post") tree.merge_from_branch(tree2.branch) # change the current content to be different this will alter stat, sha # and length again, giving us three distinct values: - self.build_tree_contents([('tree/a file', b'new content\n')]) - expected_result = ([rev_id, rev_id2], [ - ((b'', b'', tree.path2id('')), # common details - [(b'd', b'', 0, False, dirstate.DirState.NULLSTAT), # current tree - (b'd', b'', 0, False, rev_id), # first parent details - (b'd', b'', 0, False, rev_id), # second parent details - ]), - ((b'', b'a file', b'a-file-id'), # common - [(b'f', b'', 0, False, dirstate.DirState.NULLSTAT), # current - (b'f', b'c3ed76e4bfd45ff1763ca206055bca8e9fc28aa8', 24, False, - rev_id), # first parent - (b'f', b'314d796174c9412647c3ce07dfb5d36a94e72958', 14, False, - rev_id2), # second parent - ]), - ]) - state = dirstate.DirState.from_tree(tree, 'dirstate') + self.build_tree_contents([("tree/a file", b"new content\n")]) + expected_result = ( + [rev_id, rev_id2], + [ + ( + (b"", b"", tree.path2id("")), # common details + [ + ( + b"d", + b"", + 0, + False, + dirstate.DirState.NULLSTAT, + ), # current tree + (b"d", b"", 0, False, rev_id), # first parent details + (b"d", b"", 0, False, rev_id), # second parent details + ], + ), + ( + (b"", b"a file", b"a-file-id"), # common + [ + (b"f", b"", 0, False, dirstate.DirState.NULLSTAT), # current + ( + b"f", + b"c3ed76e4bfd45ff1763ca206055bca8e9fc28aa8", 
+ 24, + False, + rev_id, + ), # first parent + ( + b"f", + b"314d796174c9412647c3ce07dfb5d36a94e72958", + 14, + False, + rev_id2, + ), # second parent + ], + ), + ], + ) + state = dirstate.DirState.from_tree(tree, "dirstate") self.check_state_with_reopen(expected_result, state) def test_colliding_fileids(self): @@ -519,15 +703,20 @@ def test_colliding_fileids(self): # create some trees to test from parents = [] for i in range(7): - tree = self.make_branch_and_tree('tree%d' % i) - self.build_tree(['tree%d/name' % i, ]) - tree.add(['name'], ids=[b'file-id%d' % i]) - revision_id = b'revid-%d' % i - tree.commit('message', rev_id=revision_id) - parents.append((revision_id, - tree.branch.repository.revision_tree(revision_id))) + tree = self.make_branch_and_tree("tree%d" % i) + self.build_tree( + [ + "tree%d/name" % i, + ] + ) + tree.add(["name"], ids=[b"file-id%d" % i]) + revision_id = b"revid-%d" % i + tree.commit("message", rev_id=revision_id) + parents.append( + (revision_id, tree.branch.repository.revision_tree(revision_id)) + ) # now fold these trees into a dirstate - state = dirstate.DirState.initialize('dirstate') + state = dirstate.DirState.initialize("dirstate") try: state.set_parent_trees(parents, []) state._validate() @@ -536,45 +725,50 @@ def test_colliding_fileids(self): class TestDirStateOnFile(TestCaseWithDirState): - def create_updated_dirstate(self): - self.build_tree(['a-file']) - tree = self.make_branch_and_tree('.') - tree.add(['a-file'], ids=[b'a-id']) - tree.commit('add a-file') + self.build_tree(["a-file"]) + tree = self.make_branch_and_tree(".") + tree.add(["a-file"], ids=[b"a-id"]) + tree.commit("add a-file") # Save and unlock the state, re-open it in readonly mode - state = dirstate.DirState.from_tree(tree, 'dirstate') + state = dirstate.DirState.from_tree(tree, "dirstate") state.save() state.unlock() - state = dirstate.DirState.on_file('dirstate') + state = dirstate.DirState.on_file("dirstate") state.lock_read() return state def test_construct_with_path(self): - tree = self.make_branch_and_tree('tree') - state = dirstate.DirState.from_tree(tree, 'dirstate.from_tree') + tree = self.make_branch_and_tree("tree") + state = dirstate.DirState.from_tree(tree, "dirstate.from_tree") # we want to be able to get the lines of the dirstate that we will # write to disk. lines = state.get_lines() state.unlock() - self.build_tree_contents([('dirstate', b''.join(lines))]) + self.build_tree_contents([("dirstate", b"".join(lines))]) # get a state object # no parents, default tree content - expected_result = ([], [ - ((b'', b'', tree.path2id('')), # common details - # current tree details, but new from_tree skips statting, it - # uses set_state_from_inventory, and thus depends on the - # inventory state. - [(b'd', b'', 0, False, dirstate.DirState.NULLSTAT), - ]) - ]) - state = dirstate.DirState.on_file('dirstate') + expected_result = ( + [], + [ + ( + (b"", b"", tree.path2id("")), # common details + # current tree details, but new from_tree skips statting, it + # uses set_state_from_inventory, and thus depends on the + # inventory state. 
+ [ + (b"d", b"", 0, False, dirstate.DirState.NULLSTAT), + ], + ) + ], + ) + state = dirstate.DirState.on_file("dirstate") state.lock_write() # check_state_with_reopen will save() and unlock it self.check_state_with_reopen(expected_result, state) def test_can_save_clean_on_file(self): - tree = self.make_branch_and_tree('tree') - state = dirstate.DirState.from_tree(tree, 'dirstate') + tree = self.make_branch_and_tree("tree") + state = dirstate.DirState.from_tree(tree, "dirstate") try: # doing a save should work here as there have been no changes. state.save() @@ -586,7 +780,7 @@ def test_can_save_clean_on_file(self): def test_can_save_in_read_lock(self): state = self.create_updated_dirstate() try: - entry = state._get_entry(0, path_utf8=b'a-file') + entry = state._get_entry(0, path_utf8=b"a-file") # The current size should be 0 (default) self.assertEqual(0, entry[1][0][2]) # We should have a real entry. @@ -594,16 +788,16 @@ def test_can_save_in_read_lock(self): # Set the cutoff-time into the future, so things look cacheable state._sha_cutoff_time() state._cutoff_time += 10.0 - st = os.lstat('a-file') - sha1sum = dirstate.update_entry(state, entry, 'a-file', st) + st = os.lstat("a-file") + sha1sum = dirstate.update_entry(state, entry, "a-file", st) # We updated the current sha1sum because the file is cacheable - self.assertEqual(b'ecc5374e9ed82ad3ea3b4d452ea995a5fd3e70e3', - sha1sum) + self.assertEqual(b"ecc5374e9ed82ad3ea3b4d452ea995a5fd3e70e3", sha1sum) # The dirblock has been updated self.assertEqual(st.st_size, entry[1][0][2]) - self.assertEqual(dirstate.DirState.IN_MEMORY_HASH_MODIFIED, - state._dirblock_state) + self.assertEqual( + dirstate.DirState.IN_MEMORY_HASH_MODIFIED, state._dirblock_state + ) del entry # Now, since we are the only one holding a lock, we should be able @@ -613,10 +807,10 @@ def test_can_save_in_read_lock(self): state.unlock() # Re-open the file, and ensure that the state has been updated. - state = dirstate.DirState.on_file('dirstate') + state = dirstate.DirState.on_file("dirstate") state.lock_read() try: - entry = state._get_entry(0, path_utf8=b'a-file') + entry = state._get_entry(0, path_utf8=b"a-file") self.assertEqual(st.st_size, entry[1][0][2]) finally: state.unlock() @@ -625,24 +819,24 @@ def test_save_fails_quietly_if_locked(self): """If dirstate is locked, save will fail without complaining.""" state = self.create_updated_dirstate() try: - entry = state._get_entry(0, path_utf8=b'a-file') + entry = state._get_entry(0, path_utf8=b"a-file") # No cached sha1 yet. - self.assertEqual(b'', entry[1][0][1]) + self.assertEqual(b"", entry[1][0][1]) # Set the cutoff-time into the future, so things look cacheable state._sha_cutoff_time() state._cutoff_time += 10.0 - st = os.lstat('a-file') - sha1sum = dirstate.update_entry(state, entry, 'a-file', st) - self.assertEqual(b'ecc5374e9ed82ad3ea3b4d452ea995a5fd3e70e3', - sha1sum) - self.assertEqual(dirstate.DirState.IN_MEMORY_HASH_MODIFIED, - state._dirblock_state) + st = os.lstat("a-file") + sha1sum = dirstate.update_entry(state, entry, "a-file", st) + self.assertEqual(b"ecc5374e9ed82ad3ea3b4d452ea995a5fd3e70e3", sha1sum) + self.assertEqual( + dirstate.DirState.IN_MEMORY_HASH_MODIFIED, state._dirblock_state + ) # Now, before we try to save, grab another dirstate, and take out a # read lock. # TODO: jam 20070315 Ideally this would be locked by another # process. To make sure the file is really OS locked. 
- state2 = dirstate.DirState.on_file('dirstate') + state2 = dirstate.DirState.on_file("dirstate") state2.lock_read() try: # This won't actually write anything, because it couldn't grab @@ -659,46 +853,60 @@ def test_save_fails_quietly_if_locked(self): state.unlock() # The file on disk should not be modified. - state = dirstate.DirState.on_file('dirstate') + state = dirstate.DirState.on_file("dirstate") state.lock_read() try: - entry = state._get_entry(0, path_utf8=b'a-file') - self.assertEqual(b'', entry[1][0][1]) + entry = state._get_entry(0, path_utf8=b"a-file") + self.assertEqual(b"", entry[1][0][1]) finally: state.unlock() def test_save_refuses_if_changes_aborted(self): - self.build_tree(['a-file', 'a-dir/']) - state = dirstate.DirState.initialize('dirstate') + self.build_tree(["a-file", "a-dir/"]) + state = dirstate.DirState.initialize("dirstate") try: # No stat and no sha1 sum. - state.add('a-file', b'a-file-id', 'file', None, b'') + state.add("a-file", b"a-file-id", "file", None, b"") state.save() finally: state.unlock() # The dirstate should include TREE_ROOT and 'a-file' and nothing else expected_blocks = [ - (b'', [((b'', b'', b'TREE_ROOT'), - [(b'd', b'', 0, False, dirstate.DirState.NULLSTAT)])]), - (b'', [((b'', b'a-file', b'a-file-id'), - [(b'f', b'', 0, False, dirstate.DirState.NULLSTAT)])]), + ( + b"", + [ + ( + (b"", b"", b"TREE_ROOT"), + [(b"d", b"", 0, False, dirstate.DirState.NULLSTAT)], + ) + ], + ), + ( + b"", + [ + ( + (b"", b"a-file", b"a-file-id"), + [(b"f", b"", 0, False, dirstate.DirState.NULLSTAT)], + ) + ], + ), ] - state = dirstate.DirState.on_file('dirstate') + state = dirstate.DirState.on_file("dirstate") state.lock_write() try: state._read_dirblocks_if_needed() self.assertEqual(expected_blocks, state._dirblocks) # Now modify the state, but mark it as inconsistent - state.add('a-dir', b'a-dir-id', 'directory', None, b'') + state.add("a-dir", b"a-dir-id", "directory", None, b"") state._changes_aborted = True state.save() finally: state.unlock() - state = dirstate.DirState.on_file('dirstate') + state = dirstate.DirState.on_file("dirstate") state.lock_read() try: state._read_dirblocks_if_needed() @@ -708,14 +916,25 @@ def test_save_refuses_if_changes_aborted(self): class TestDirStateInitialize(TestCaseWithDirState): - def test_initialize(self): - expected_result = ([], [ - ((b'', b'', b'TREE_ROOT'), # common details - [(b'd', b'', 0, False, dirstate.DirState.NULLSTAT), # current tree - ]) - ]) - state = dirstate.DirState.initialize('dirstate') + expected_result = ( + [], + [ + ( + (b"", b"", b"TREE_ROOT"), # common details + [ + ( + b"d", + b"", + 0, + False, + dirstate.DirState.NULLSTAT, + ), # current tree + ], + ) + ], + ) + state = dirstate.DirState.initialize("dirstate") try: self.assertIsInstance(state, dirstate.DirState) lines = state.get_lines() @@ -724,52 +943,63 @@ def test_initialize(self): # On win32 you can't read from a locked file, even within the same # process. So we have to unlock and release before we check the file # contents. 
- self.assertFileEqual(b''.join(lines), 'dirstate') + self.assertFileEqual(b"".join(lines), "dirstate") state.lock_read() # check_state_with_reopen will unlock self.check_state_with_reopen(expected_result, state) class TestDirStateManipulations(TestCaseWithDirState): - def make_minimal_tree(self): - tree1 = self.make_branch_and_memory_tree('tree1') + tree1 = self.make_branch_and_memory_tree("tree1") tree1.lock_write() self.addCleanup(tree1.unlock) - tree1.add('') - revid1 = tree1.commit('foo') + tree1.add("") + revid1 = tree1.commit("foo") return tree1, revid1 def test_update_minimal_updates_id_index(self): state = self.create_dirstate_with_root_and_subdir() self.addCleanup(state.unlock) id_index = state._get_id_index() - self.assertEqual([b'a-root-value', b'subdir-id'], sorted(id_index.file_ids())) - state.add('file-name', b'file-id', 'file', None, '') - self.assertEqual([b'a-root-value', b'file-id', b'subdir-id'], - sorted(id_index.file_ids())) - state.update_minimal((b'', b'new-name', b'file-id'), b'f', - path_utf8=b'new-name') - self.assertEqual([b'a-root-value', b'file-id', b'subdir-id'], - sorted(id_index.file_ids())) - self.assertEqual([(b'', b'new-name', b'file-id')], - sorted(id_index.get(b'file-id'))) + self.assertEqual([b"a-root-value", b"subdir-id"], sorted(id_index.file_ids())) + state.add("file-name", b"file-id", "file", None, "") + self.assertEqual( + [b"a-root-value", b"file-id", b"subdir-id"], sorted(id_index.file_ids()) + ) + state.update_minimal( + (b"", b"new-name", b"file-id"), b"f", path_utf8=b"new-name" + ) + self.assertEqual( + [b"a-root-value", b"file-id", b"subdir-id"], sorted(id_index.file_ids()) + ) + self.assertEqual( + [(b"", b"new-name", b"file-id")], sorted(id_index.get(b"file-id")) + ) state._validate() def test_set_state_from_inventory_no_content_no_parents(self): # setting the current inventory is a slow but important api to support. 
tree1, revid1 = self.make_minimal_tree() inv = tree1.root_inventory - root_id = inv.path2id('') - expected_result = [], [ - ((b'', b'', root_id), [ - (b'd', b'', 0, False, dirstate.DirState.NULLSTAT)])] - state = dirstate.DirState.initialize('dirstate') + root_id = inv.path2id("") + expected_result = ( + [], + [ + ( + (b"", b"", root_id), + [(b"d", b"", 0, False, dirstate.DirState.NULLSTAT)], + ) + ], + ) + state = dirstate.DirState.initialize("dirstate") try: state.set_state_from_inventory(inv) - self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, - state._header_state) - self.assertEqual(dirstate.DirState.IN_MEMORY_MODIFIED, - state._dirblock_state) + self.assertEqual( + dirstate.DirState.IN_MEMORY_UNMODIFIED, state._header_state + ) + self.assertEqual( + dirstate.DirState.IN_MEMORY_MODIFIED, state._dirblock_state + ) except: state.unlock() raise @@ -780,17 +1010,23 @@ def test_set_state_from_inventory_no_content_no_parents(self): def test_set_state_from_scratch_no_parents(self): tree1, revid1 = self.make_minimal_tree() inv = tree1.root_inventory - root_id = inv.path2id('') - expected_result = [], [ - ((b'', b'', root_id), [ - (b'd', b'', 0, False, dirstate.DirState.NULLSTAT)])] - state = dirstate.DirState.initialize('dirstate') + root_id = inv.path2id("") + expected_result = ( + [], + [ + ( + (b"", b"", root_id), + [(b"d", b"", 0, False, dirstate.DirState.NULLSTAT)], + ) + ], + ) + state = dirstate.DirState.initialize("dirstate") try: state.set_state_from_scratch(inv, [], []) - self.assertEqual(dirstate.DirState.IN_MEMORY_MODIFIED, - state._header_state) - self.assertEqual(dirstate.DirState.IN_MEMORY_MODIFIED, - state._dirblock_state) + self.assertEqual(dirstate.DirState.IN_MEMORY_MODIFIED, state._header_state) + self.assertEqual( + dirstate.DirState.IN_MEMORY_MODIFIED, state._dirblock_state + ) except: state.unlock() raise @@ -801,19 +1037,18 @@ def test_set_state_from_scratch_no_parents(self): def test_set_state_from_scratch_identical_parent(self): tree1, revid1 = self.make_minimal_tree() inv = tree1.root_inventory - root_id = inv.path2id('') + root_id = inv.path2id("") rev_tree1 = tree1.branch.repository.revision_tree(revid1) - d_entry = (b'd', b'', 0, False, dirstate.DirState.NULLSTAT) - parent_entry = (b'd', b'', 0, False, revid1) - expected_result = [revid1], [ - ((b'', b'', root_id), [d_entry, parent_entry])] - state = dirstate.DirState.initialize('dirstate') + d_entry = (b"d", b"", 0, False, dirstate.DirState.NULLSTAT) + parent_entry = (b"d", b"", 0, False, revid1) + expected_result = [revid1], [((b"", b"", root_id), [d_entry, parent_entry])] + state = dirstate.DirState.initialize("dirstate") try: state.set_state_from_scratch(inv, [(revid1, rev_tree1)], []) - self.assertEqual(dirstate.DirState.IN_MEMORY_MODIFIED, - state._header_state) - self.assertEqual(dirstate.DirState.IN_MEMORY_MODIFIED, - state._dirblock_state) + self.assertEqual(dirstate.DirState.IN_MEMORY_MODIFIED, state._header_state) + self.assertEqual( + dirstate.DirState.IN_MEMORY_MODIFIED, state._dirblock_state + ) except: state.unlock() raise @@ -826,42 +1061,62 @@ def test_set_state_from_inventory_preserves_hashcache(self): # set_state_from_inventory should preserve the stat and hash value for # workingtree files that are not changed by the inventory. - tree = self.make_branch_and_tree('.') + tree = self.make_branch_and_tree(".") # depends on the default format using dirstate... 
with tree.lock_write(): # make a dirstate with some valid hashcache data # file on disk, but that's not needed for this test - foo_contents = b'contents of foo' - self.build_tree_contents([('foo', foo_contents)]) - tree.add('foo', ids=b'foo-id') + foo_contents = b"contents of foo" + self.build_tree_contents([("foo", foo_contents)]) + tree.add("foo", ids=b"foo-id") - foo_stat = os.stat('foo') + foo_stat = os.stat("foo") foo_packed = dirstate.pack_stat(foo_stat) foo_sha = osutils.sha_string(foo_contents) foo_size = len(foo_contents) # should not be cached yet, because the file's too fresh self.assertEqual( - ((b'', b'foo', b'foo-id',), - [(b'f', b'', 0, False, dirstate.DirState.NULLSTAT)]), - tree._dirstate._get_entry(0, b'foo-id')) + ( + ( + b"", + b"foo", + b"foo-id", + ), + [(b"f", b"", 0, False, dirstate.DirState.NULLSTAT)], + ), + tree._dirstate._get_entry(0, b"foo-id"), + ) # poke in some hashcache information - it wouldn't normally be # stored because it's too fresh tree._dirstate.update_minimal( - (b'', b'foo', b'foo-id'), - b'f', False, foo_sha, foo_packed, foo_size, b'foo') + (b"", b"foo", b"foo-id"), + b"f", + False, + foo_sha, + foo_packed, + foo_size, + b"foo", + ) # now should be cached self.assertEqual( - ((b'', b'foo', b'foo-id',), - [(b'f', foo_sha, foo_size, False, foo_packed)]), - tree._dirstate._get_entry(0, b'foo-id')) + ( + ( + b"", + b"foo", + b"foo-id", + ), + [(b"f", foo_sha, foo_size, False, foo_packed)], + ), + tree._dirstate._get_entry(0, b"foo-id"), + ) # extract the inventory, and add something to it inv = tree._get_root_inventory() # should see the file we poked in... - self.assertTrue(inv.has_id(b'foo-id')) - self.assertTrue(inv.has_filename('foo')) - inv.add_path('bar', 'file', b'bar-id') + self.assertTrue(inv.has_id(b"foo-id")) + self.assertTrue(inv.has_filename("foo")) + inv.add_path("bar", "file", b"bar-id") tree._dirstate._validate() # this used to cause it to lose its hashcache tree._dirstate.set_state_from_inventory(inv) @@ -871,46 +1126,65 @@ def test_set_state_from_inventory_preserves_hashcache(self): # now check that the state still has the original hashcache value state = tree._dirstate state._validate() - foo_tuple = state._get_entry(0, path_utf8=b'foo') + foo_tuple = state._get_entry(0, path_utf8=b"foo") self.assertEqual( - ((b'', b'foo', b'foo-id',), - [(b'f', foo_sha, len(foo_contents), False, - dirstate.pack_stat(foo_stat))]), - foo_tuple) + ( + ( + b"", + b"foo", + b"foo-id", + ), + [ + ( + b"f", + foo_sha, + len(foo_contents), + False, + dirstate.pack_stat(foo_stat), + ) + ], + ), + foo_tuple, + ) def test_set_state_from_inventory_mixed_paths(self): - tree1 = self.make_branch_and_tree('tree1') - self.build_tree(['tree1/a/', 'tree1/a/b/', 'tree1/a-b/', - 'tree1/a/b/foo', 'tree1/a-b/bar']) + tree1 = self.make_branch_and_tree("tree1") + self.build_tree( + ["tree1/a/", "tree1/a/b/", "tree1/a-b/", "tree1/a/b/foo", "tree1/a-b/bar"] + ) tree1.lock_write() try: - tree1.add(['a', 'a/b', 'a-b', 'a/b/foo', 'a-b/bar'], - ids=[b'a-id', b'b-id', b'a-b-id', b'foo-id', b'bar-id']) - tree1.commit('rev1', rev_id=b'rev1') - root_id = tree1.path2id('') + tree1.add( + ["a", "a/b", "a-b", "a/b/foo", "a-b/bar"], + ids=[b"a-id", b"b-id", b"a-b-id", b"foo-id", b"bar-id"], + ) + tree1.commit("rev1", rev_id=b"rev1") + root_id = tree1.path2id("") inv = tree1.root_inventory finally: tree1.unlock() - expected_result1 = [(b'', b'', root_id, b'd'), - (b'', b'a', b'a-id', b'd'), - (b'', b'a-b', b'a-b-id', b'd'), - (b'a', b'b', b'b-id', b'd'), - (b'a/b', b'foo', b'foo-id', 
b'f'), - (b'a-b', b'bar', b'bar-id', b'f'), - ] - expected_result2 = [(b'', b'', root_id, b'd'), - (b'', b'a', b'a-id', b'd'), - (b'', b'a-b', b'a-b-id', b'd'), - (b'a-b', b'bar', b'bar-id', b'f'), - ] - state = dirstate.DirState.initialize('dirstate') + expected_result1 = [ + (b"", b"", root_id, b"d"), + (b"", b"a", b"a-id", b"d"), + (b"", b"a-b", b"a-b-id", b"d"), + (b"a", b"b", b"b-id", b"d"), + (b"a/b", b"foo", b"foo-id", b"f"), + (b"a-b", b"bar", b"bar-id", b"f"), + ] + expected_result2 = [ + (b"", b"", root_id, b"d"), + (b"", b"a", b"a-id", b"d"), + (b"", b"a-b", b"a-b-id", b"d"), + (b"a-b", b"bar", b"bar-id", b"f"), + ] + state = dirstate.DirState.initialize("dirstate") try: state.set_state_from_inventory(inv) values = [] for entry in state._iter_entries(): values.append(entry[0] + entry[1][0][:1]) self.assertEqual(expected_result1, values) - inv.delete(b'b-id') + inv.delete(b"b-id") state.set_state_from_inventory(inv) values = [] for entry in state._iter_entries(): @@ -921,33 +1195,35 @@ def test_set_state_from_inventory_mixed_paths(self): def test_set_path_id_no_parents(self): """The id of a path can be changed trivally with no parents.""" - state = dirstate.DirState.initialize('dirstate') + state = dirstate.DirState.initialize("dirstate") try: # check precondition to be sure the state does change appropriately. - root_entry = ((b'', b'', b'TREE_ROOT'), [ - (b'd', b'', 0, False, b'x' * 32)]) + root_entry = ((b"", b"", b"TREE_ROOT"), [(b"d", b"", 0, False, b"x" * 32)]) self.assertEqual([root_entry], list(state._iter_entries())) - self.assertEqual(root_entry, state._get_entry(0, path_utf8=b'')) - self.assertEqual(root_entry, - state._get_entry(0, fileid_utf8=b'TREE_ROOT')) - self.assertEqual((None, None), - state._get_entry(0, fileid_utf8=b'second-root-id')) - state.set_path_id(b'', b'second-root-id') - new_root_entry = ((b'', b'', b'second-root-id'), - [(b'd', b'', 0, False, b'x' * 32)]) + self.assertEqual(root_entry, state._get_entry(0, path_utf8=b"")) + self.assertEqual(root_entry, state._get_entry(0, fileid_utf8=b"TREE_ROOT")) + self.assertEqual( + (None, None), state._get_entry(0, fileid_utf8=b"second-root-id") + ) + state.set_path_id(b"", b"second-root-id") + new_root_entry = ( + (b"", b"", b"second-root-id"), + [(b"d", b"", 0, False, b"x" * 32)], + ) expected_rows = [new_root_entry] self.assertEqual(expected_rows, list(state._iter_entries())) + self.assertEqual(new_root_entry, state._get_entry(0, path_utf8=b"")) + self.assertEqual( + new_root_entry, state._get_entry(0, fileid_utf8=b"second-root-id") + ) self.assertEqual( - new_root_entry, state._get_entry(0, path_utf8=b'')) - self.assertEqual(new_root_entry, - state._get_entry(0, fileid_utf8=b'second-root-id')) - self.assertEqual((None, None), - state._get_entry(0, fileid_utf8=b'TREE_ROOT')) + (None, None), state._get_entry(0, fileid_utf8=b"TREE_ROOT") + ) # should work across save too state.save() finally: state.unlock() - state = dirstate.DirState.on_file('dirstate') + state = dirstate.DirState.on_file("dirstate") state.lock_read() try: state._validate() @@ -957,53 +1233,61 @@ def test_set_path_id_no_parents(self): def test_set_path_id_with_parents(self): """Set the root file id in a dirstate with parents.""" - mt = self.make_branch_and_tree('mt') + mt = self.make_branch_and_tree("mt") # in case the default tree format uses a different root id - mt.set_root_id(b'TREE_ROOT') - mt.commit('foo', rev_id=b'parent-revid') - rt = mt.branch.repository.revision_tree(b'parent-revid') - state = dirstate.DirState.initialize('dirstate') 
+ mt.set_root_id(b"TREE_ROOT") + mt.commit("foo", rev_id=b"parent-revid") + rt = mt.branch.repository.revision_tree(b"parent-revid") + state = dirstate.DirState.initialize("dirstate") state._validate() try: - state.set_parent_trees([(b'parent-revid', rt)], ghosts=[]) - root_entry = ((b'', b'', b'TREE_ROOT'), - [(b'd', b'', 0, False, b'x' * 32), - (b'd', b'', 0, False, b'parent-revid')]) - self.assertEqual(root_entry, state._get_entry(0, path_utf8=b'')) - self.assertEqual(root_entry, - state._get_entry(0, fileid_utf8=b'TREE_ROOT')) - self.assertEqual((None, None), - state._get_entry(0, fileid_utf8=b'Asecond-root-id')) - state.set_path_id(b'', b'Asecond-root-id') + state.set_parent_trees([(b"parent-revid", rt)], ghosts=[]) + root_entry = ( + (b"", b"", b"TREE_ROOT"), + [ + (b"d", b"", 0, False, b"x" * 32), + (b"d", b"", 0, False, b"parent-revid"), + ], + ) + self.assertEqual(root_entry, state._get_entry(0, path_utf8=b"")) + self.assertEqual(root_entry, state._get_entry(0, fileid_utf8=b"TREE_ROOT")) + self.assertEqual( + (None, None), state._get_entry(0, fileid_utf8=b"Asecond-root-id") + ) + state.set_path_id(b"", b"Asecond-root-id") state._validate() # now see that it is what we expected - old_root_entry = ((b'', b'', b'TREE_ROOT'), - [(b'a', b'', 0, False, b''), - (b'd', b'', 0, False, b'parent-revid')]) - new_root_entry = ((b'', b'', b'Asecond-root-id'), - [(b'd', b'', 0, False, b''), - (b'a', b'', 0, False, b'')]) + old_root_entry = ( + (b"", b"", b"TREE_ROOT"), + [(b"a", b"", 0, False, b""), (b"d", b"", 0, False, b"parent-revid")], + ) + new_root_entry = ( + (b"", b"", b"Asecond-root-id"), + [(b"d", b"", 0, False, b""), (b"a", b"", 0, False, b"")], + ) expected_rows = [new_root_entry, old_root_entry] state._validate() self.assertEqual(expected_rows, list(state._iter_entries())) + self.assertEqual(new_root_entry, state._get_entry(0, path_utf8=b"")) + self.assertEqual(old_root_entry, state._get_entry(1, path_utf8=b"")) self.assertEqual( - new_root_entry, state._get_entry(0, path_utf8=b'')) + (None, None), state._get_entry(0, fileid_utf8=b"TREE_ROOT") + ) self.assertEqual( - old_root_entry, state._get_entry(1, path_utf8=b'')) - self.assertEqual((None, None), - state._get_entry(0, fileid_utf8=b'TREE_ROOT')) - self.assertEqual(old_root_entry, - state._get_entry(1, fileid_utf8=b'TREE_ROOT')) - self.assertEqual(new_root_entry, - state._get_entry(0, fileid_utf8=b'Asecond-root-id')) - self.assertEqual((None, None), - state._get_entry(1, fileid_utf8=b'Asecond-root-id')) + old_root_entry, state._get_entry(1, fileid_utf8=b"TREE_ROOT") + ) + self.assertEqual( + new_root_entry, state._get_entry(0, fileid_utf8=b"Asecond-root-id") + ) + self.assertEqual( + (None, None), state._get_entry(1, fileid_utf8=b"Asecond-root-id") + ) # should work across save too state.save() finally: state.unlock() # now flush & check we get the same - state = dirstate.DirState.on_file('dirstate') + state = dirstate.DirState.on_file("dirstate") state.lock_read() try: state._validate() @@ -1014,36 +1298,39 @@ def test_set_path_id_with_parents(self): state.lock_write() try: state._validate() - state.set_path_id(b'', b'tree-root-2') + state.set_path_id(b"", b"tree-root-2") state._validate() finally: state.unlock() def test_set_parent_trees_no_content(self): # set_parent_trees is a slow but important api to support. 
- tree1 = self.make_branch_and_memory_tree('tree1') + tree1 = self.make_branch_and_memory_tree("tree1") tree1.lock_write() try: - tree1.add('') - revid1 = tree1.commit('foo') + tree1.add("") + revid1 = tree1.commit("foo") finally: tree1.unlock() - branch2 = tree1.branch.controldir.clone('tree2').open_branch() + branch2 = tree1.branch.controldir.clone("tree2").open_branch() tree2 = memorytree.MemoryTree.create_on_branch(branch2) tree2.lock_write() try: - revid2 = tree2.commit('foo') - root_id = tree2.path2id('') + revid2 = tree2.commit("foo") + root_id = tree2.path2id("") finally: tree2.unlock() - state = dirstate.DirState.initialize('dirstate') + state = dirstate.DirState.initialize("dirstate") try: - state.set_path_id(b'', root_id) + state.set_path_id(b"", root_id) state.set_parent_trees( - ((revid1, tree1.branch.repository.revision_tree(revid1)), - (revid2, tree2.branch.repository.revision_tree(revid2)), - (b'ghost-rev', None)), - [b'ghost-rev']) + ( + (revid1, tree1.branch.repository.revision_tree(revid1)), + (revid2, tree2.branch.repository.revision_tree(revid2)), + (b"ghost-rev", None), + ), + [b"ghost-rev"], + ) # check we can reopen and use the dirstate after setting parent # trees. state._validate() @@ -1051,36 +1338,50 @@ def test_set_parent_trees_no_content(self): state._validate() finally: state.unlock() - state = dirstate.DirState.on_file('dirstate') + state = dirstate.DirState.on_file("dirstate") state.lock_write() try: - self.assertEqual([revid1, revid2, b'ghost-rev'], - state.get_parent_ids()) + self.assertEqual([revid1, revid2, b"ghost-rev"], state.get_parent_ids()) # iterating the entire state ensures that the state is parsable. list(state._iter_entries()) # be sure that it sets not appends - change it state.set_parent_trees( - ((revid1, tree1.branch.repository.revision_tree(revid1)), - (b'ghost-rev', None)), - [b'ghost-rev']) + ( + (revid1, tree1.branch.repository.revision_tree(revid1)), + (b"ghost-rev", None), + ), + [b"ghost-rev"], + ) # and now put it back. state.set_parent_trees( - ((revid1, tree1.branch.repository.revision_tree(revid1)), - (revid2, tree2.branch.repository.revision_tree(revid2)), - (b'ghost-rev', tree2.branch.repository.revision_tree( - _mod_revision.NULL_REVISION))), - [b'ghost-rev']) - self.assertEqual([revid1, revid2, b'ghost-rev'], - state.get_parent_ids()) + ( + (revid1, tree1.branch.repository.revision_tree(revid1)), + (revid2, tree2.branch.repository.revision_tree(revid2)), + ( + b"ghost-rev", + tree2.branch.repository.revision_tree( + _mod_revision.NULL_REVISION + ), + ), + ), + [b"ghost-rev"], + ) + self.assertEqual([revid1, revid2, b"ghost-rev"], state.get_parent_ids()) # the ghost should be recorded as such by set_parent_trees. - self.assertEqual([b'ghost-rev'], state.get_ghosts()) + self.assertEqual([b"ghost-rev"], state.get_ghosts()) self.assertEqual( - [((b'', b'', root_id), [ - (b'd', b'', 0, False, dirstate.DirState.NULLSTAT), - (b'd', b'', 0, False, revid1), - (b'd', b'', 0, False, revid1) - ])], - list(state._iter_entries())) + [ + ( + (b"", b"", root_id), + [ + (b"d", b"", 0, False, dirstate.DirState.NULLSTAT), + (b"d", b"", 0, False, revid1), + (b"d", b"", 0, False, revid1), + ], + ) + ], + list(state._iter_entries()), + ) finally: state.unlock() @@ -1089,46 +1390,68 @@ def test_set_parent_trees_file_missing_from_tree(self): # they should get listed just once by id, even if they are in two # separate trees. # set_parent_trees is a slow but important api to support. 
- tree1 = self.make_branch_and_memory_tree('tree1') + tree1 = self.make_branch_and_memory_tree("tree1") tree1.lock_write() try: - tree1.add('') - tree1.add(['a file'], ['file'], [b'file-id']) - tree1.put_file_bytes_non_atomic('a file', b'file-content') - revid1 = tree1.commit('foo') + tree1.add("") + tree1.add(["a file"], ["file"], [b"file-id"]) + tree1.put_file_bytes_non_atomic("a file", b"file-content") + revid1 = tree1.commit("foo") finally: tree1.unlock() - branch2 = tree1.branch.controldir.clone('tree2').open_branch() + branch2 = tree1.branch.controldir.clone("tree2").open_branch() tree2 = memorytree.MemoryTree.create_on_branch(branch2) tree2.lock_write() try: - tree2.put_file_bytes_non_atomic('a file', b'new file-content') - revid2 = tree2.commit('foo') - root_id = tree2.path2id('') + tree2.put_file_bytes_non_atomic("a file", b"new file-content") + revid2 = tree2.commit("foo") + root_id = tree2.path2id("") finally: tree2.unlock() # check the layout in memory - expected_result = [revid1, revid2], [ - ((b'', b'', root_id), [ - (b'd', b'', 0, False, dirstate.DirState.NULLSTAT), - (b'd', b'', 0, False, revid1), - (b'd', b'', 0, False, revid1) - ]), - ((b'', b'a file', b'file-id'), [ - (b'a', b'', 0, False, b''), - (b'f', b'2439573625385400f2a669657a7db6ae7515d371', 12, False, - revid1), - (b'f', b'542e57dc1cda4af37cb8e55ec07ce60364bb3c7d', 16, False, - revid2) - ]) - ] - state = dirstate.DirState.initialize('dirstate') + expected_result = ( + [revid1, revid2], + [ + ( + (b"", b"", root_id), + [ + (b"d", b"", 0, False, dirstate.DirState.NULLSTAT), + (b"d", b"", 0, False, revid1), + (b"d", b"", 0, False, revid1), + ], + ), + ( + (b"", b"a file", b"file-id"), + [ + (b"a", b"", 0, False, b""), + ( + b"f", + b"2439573625385400f2a669657a7db6ae7515d371", + 12, + False, + revid1, + ), + ( + b"f", + b"542e57dc1cda4af37cb8e55ec07ce60364bb3c7d", + 16, + False, + revid2, + ), + ], + ), + ], + ) + state = dirstate.DirState.initialize("dirstate") try: - state.set_path_id(b'', root_id) + state.set_path_id(b"", root_id) state.set_parent_trees( - ((revid1, tree1.branch.repository.revision_tree(revid1)), - (revid2, tree2.branch.repository.revision_tree(revid2)), - ), []) + ( + (revid1, tree1.branch.repository.revision_tree(revid1)), + (revid2, tree2.branch.repository.revision_tree(revid2)), + ), + [], + ) except: state.unlock() raise @@ -1142,27 +1465,39 @@ def test_set_parent_trees_file_missing_from_tree(self): def test_add_path_to_root_no_parents_all_data(self): # The most trivial addition of a path is when there are no parents and # its in the root and all data about the file is supplied - self.build_tree(['a file']) - stat = os.lstat('a file') + self.build_tree(["a file"]) + stat = os.lstat("a file") # the 1*20 is the sha1 pretend value. 
- state = dirstate.DirState.initialize('dirstate') + state = dirstate.DirState.initialize("dirstate") expected_entries = [ - ((b'', b'', b'TREE_ROOT'), [ - (b'd', b'', 0, False, dirstate.DirState.NULLSTAT), # current tree - ]), - ((b'', b'a file', b'a-file-id'), [ - (b'f', b'1' * 20, 19, False, dirstate.pack_stat(stat)), # current tree - ]), - ] + ( + (b"", b"", b"TREE_ROOT"), + [ + (b"d", b"", 0, False, dirstate.DirState.NULLSTAT), # current tree + ], + ), + ( + (b"", b"a file", b"a-file-id"), + [ + ( + b"f", + b"1" * 20, + 19, + False, + dirstate.pack_stat(stat), + ), # current tree + ], + ), + ] try: - state.add('a file', b'a-file-id', 'file', stat, b'1' * 20) + state.add("a file", b"a-file-id", "file", stat, b"1" * 20) # having added it, it should be in the output of iter_entries. self.assertEqual(expected_entries, list(state._iter_entries())) # saving and reloading should not affect this. state.save() finally: state.unlock() - state = dirstate.DirState.on_file('dirstate') + state = dirstate.DirState.on_file("dirstate") state.lock_read() self.addCleanup(state.unlock) self.assertEqual(expected_entries, list(state._iter_entries())) @@ -1174,35 +1509,48 @@ def test_add_path_to_unversioned_directory(self): once dirstate is stable and if it is merged with WorkingTree3, consider removing this copy of the test. """ - self.build_tree(['unversioned/', 'unversioned/a file']) - state = dirstate.DirState.initialize('dirstate') + self.build_tree(["unversioned/", "unversioned/a file"]) + state = dirstate.DirState.initialize("dirstate") self.addCleanup(state.unlock) - self.assertRaises(errors.NotVersionedError, state.add, - 'unversioned/a file', b'a-file-id', 'file', None, None) + self.assertRaises( + errors.NotVersionedError, + state.add, + "unversioned/a file", + b"a-file-id", + "file", + None, + None, + ) def test_add_directory_to_root_no_parents_all_data(self): # The most trivial addition of a dir is when there are no parents and # its in the root and all data about the file is supplied - self.build_tree(['a dir/']) - stat = os.lstat('a dir') + self.build_tree(["a dir/"]) + stat = os.lstat("a dir") expected_entries = [ - ((b'', b'', b'TREE_ROOT'), [ - (b'd', b'', 0, False, dirstate.DirState.NULLSTAT), # current tree - ]), - ((b'', b'a dir', b'a-dir-id'), [ - (b'd', b'', 0, False, dirstate.pack_stat(stat)), # current tree - ]), - ] - state = dirstate.DirState.initialize('dirstate') + ( + (b"", b"", b"TREE_ROOT"), + [ + (b"d", b"", 0, False, dirstate.DirState.NULLSTAT), # current tree + ], + ), + ( + (b"", b"a dir", b"a-dir-id"), + [ + (b"d", b"", 0, False, dirstate.pack_stat(stat)), # current tree + ], + ), + ] + state = dirstate.DirState.initialize("dirstate") try: - state.add('a dir', b'a-dir-id', 'directory', stat, None) + state.add("a dir", b"a-dir-id", "directory", stat, None) # having added it, it should be in the output of iter_entries. self.assertEqual(expected_entries, list(state._iter_entries())) # saving and reloading should not affect this. 
state.save() finally: state.unlock() - state = dirstate.DirState.on_file('dirstate') + state = dirstate.DirState.on_file("dirstate") state.lock_read() self.addCleanup(state.unlock) state._validate() @@ -1216,83 +1564,104 @@ def _test_add_symlink_to_root_no_parents_all_data(self, link_name, target): os.symlink(target, link_name) stat = os.lstat(link_name) expected_entries = [ - ((b'', b'', b'TREE_ROOT'), [ - (b'd', b'', 0, False, dirstate.DirState.NULLSTAT), # current tree - ]), - ((b'', link_name.encode('UTF-8'), b'a-link-id'), [ - (b'l', target.encode('UTF-8'), stat[6], - False, dirstate.pack_stat(stat)), # current tree - ]), - ] - state = dirstate.DirState.initialize('dirstate') + ( + (b"", b"", b"TREE_ROOT"), + [ + (b"d", b"", 0, False, dirstate.DirState.NULLSTAT), # current tree + ], + ), + ( + (b"", link_name.encode("UTF-8"), b"a-link-id"), + [ + ( + b"l", + target.encode("UTF-8"), + stat[6], + False, + dirstate.pack_stat(stat), + ), # current tree + ], + ), + ] + state = dirstate.DirState.initialize("dirstate") try: - state.add(link_name, b'a-link-id', 'symlink', stat, - target.encode('UTF-8')) + state.add(link_name, b"a-link-id", "symlink", stat, target.encode("UTF-8")) # having added it, it should be in the output of iter_entries. self.assertEqual(expected_entries, list(state._iter_entries())) # saving and reloading should not affect this. state.save() finally: state.unlock() - state = dirstate.DirState.on_file('dirstate') + state = dirstate.DirState.on_file("dirstate") state.lock_read() self.addCleanup(state.unlock) self.assertEqual(expected_entries, list(state._iter_entries())) def test_add_symlink_to_root_no_parents_all_data(self): - self._test_add_symlink_to_root_no_parents_all_data( - 'a link', 'target') + self._test_add_symlink_to_root_no_parents_all_data("a link", "target") def test_add_symlink_unicode_to_root_no_parents_all_data(self): self.requireFeature(features.UnicodeFilenameFeature) self._test_add_symlink_to_root_no_parents_all_data( - '\N{Euro Sign}link', 'targ\N{Euro Sign}et') + "\N{Euro Sign}link", "targ\N{Euro Sign}et" + ) def test_add_directory_and_child_no_parents_all_data(self): # after adding a directory, we should be able to add children to it. 
- self.build_tree(['a dir/', 'a dir/a file']) - dirstat = os.lstat('a dir') - filestat = os.lstat('a dir/a file') + self.build_tree(["a dir/", "a dir/a file"]) + dirstat = os.lstat("a dir") + filestat = os.lstat("a dir/a file") expected_entries = [ - ((b'', b'', b'TREE_ROOT'), [ - (b'd', b'', 0, False, dirstate.DirState.NULLSTAT), # current tree - ]), - ((b'', b'a dir', b'a-dir-id'), [ - (b'd', b'', 0, False, dirstate.pack_stat(dirstat)), # current tree - ]), - ((b'a dir', b'a file', b'a-file-id'), [ - (b'f', b'1' * 20, 25, False, - dirstate.pack_stat(filestat)), # current tree details - ]), - ] - state = dirstate.DirState.initialize('dirstate') + ( + (b"", b"", b"TREE_ROOT"), + [ + (b"d", b"", 0, False, dirstate.DirState.NULLSTAT), # current tree + ], + ), + ( + (b"", b"a dir", b"a-dir-id"), + [ + (b"d", b"", 0, False, dirstate.pack_stat(dirstat)), # current tree + ], + ), + ( + (b"a dir", b"a file", b"a-file-id"), + [ + ( + b"f", + b"1" * 20, + 25, + False, + dirstate.pack_stat(filestat), + ), # current tree details + ], + ), + ] + state = dirstate.DirState.initialize("dirstate") try: - state.add('a dir', b'a-dir-id', 'directory', dirstat, None) - state.add('a dir/a file', b'a-file-id', - 'file', filestat, b'1' * 20) + state.add("a dir", b"a-dir-id", "directory", dirstat, None) + state.add("a dir/a file", b"a-file-id", "file", filestat, b"1" * 20) # added it, it should be in the output of iter_entries. self.assertEqual(expected_entries, list(state._iter_entries())) # saving and reloading should not affect this. state.save() finally: state.unlock() - state = dirstate.DirState.on_file('dirstate') + state = dirstate.DirState.on_file("dirstate") state.lock_read() self.addCleanup(state.unlock) self.assertEqual(expected_entries, list(state._iter_entries())) def test_add_tree_reference(self): # make a dirstate and add a tree reference - state = dirstate.DirState.initialize('dirstate') + state = dirstate.DirState.initialize("dirstate") expected_entry = ( - (b'', b'subdir', b'subdir-id'), - [(b't', b'subtree-123123', 0, False, - b'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')], - ) + (b"", b"subdir", b"subdir-id"), + [(b"t", b"subtree-123123", 0, False, b"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")], + ) try: - state.add('subdir', b'subdir-id', 'tree-reference', - None, b'subtree-123123') - entry = state._get_entry(0, b'subdir-id', b'subdir') + state.add("subdir", b"subdir-id", "tree-reference", None, b"subtree-123123") + entry = state._get_entry(0, b"subdir-id", b"subdir") self.assertEqual(entry, expected_entry) state._validate() state.save() @@ -1302,20 +1671,22 @@ def test_add_tree_reference(self): state.lock_read() self.addCleanup(state.unlock) state._validate() - entry2 = state._get_entry(0, b'subdir-id', b'subdir') + entry2 = state._get_entry(0, b"subdir-id", b"subdir") self.assertEqual(entry, entry2) self.assertEqual(entry, expected_entry) # and lookup by id should work too - entry2 = state._get_entry(0, fileid_utf8=b'subdir-id') + entry2 = state._get_entry(0, fileid_utf8=b"subdir-id") self.assertEqual(entry, expected_entry) def test_add_forbidden_names(self): - state = dirstate.DirState.initialize('dirstate') + state = dirstate.DirState.initialize("dirstate") self.addCleanup(state.unlock) - self.assertRaises(errors.BzrError, - state.add, '.', b'ass-id', 'directory', None, None) - self.assertRaises(errors.BzrError, - state.add, '..', b'ass-id', 'directory', None, None) + self.assertRaises( + errors.BzrError, state.add, ".", b"ass-id", "directory", None, None + ) + self.assertRaises( + errors.BzrError, 
state.add, "..", b"ass-id", "directory", None, None + ) def test_set_state_with_rename_b_a_bug_395556(self): # bug 395556 uncovered a bug where the dirstate ends up with a false @@ -1323,22 +1694,23 @@ def test_set_state_with_rename_b_a_bug_395556(self): # absent or relocated records. This then leads to further corruption # when a commit occurs, as the incorrect relocation gathers an # incorrect absent in tree 1, and future changes go to pot. - tree1 = self.make_branch_and_tree('tree1') - self.build_tree(['tree1/b']) + tree1 = self.make_branch_and_tree("tree1") + self.build_tree(["tree1/b"]) with tree1.lock_write(): - tree1.add(['b'], ids=[b'b-id']) - root_id = tree1.path2id('') + tree1.add(["b"], ids=[b"b-id"]) + root_id = tree1.path2id("") inv = tree1.root_inventory - state = dirstate.DirState.initialize('dirstate') + state = dirstate.DirState.initialize("dirstate") try: # Set the initial state with 'b' state.set_state_from_inventory(inv) - inv.rename(b'b-id', root_id, 'a') + inv.rename(b"b-id", root_id, "a") # Set the new state with 'a', which currently corrupts. state.set_state_from_inventory(inv) - expected_result1 = [(b'', b'', root_id, b'd'), - (b'', b'a', b'b-id', b'f'), - ] + expected_result1 = [ + (b"", b"", root_id, b"d"), + (b"", b"a", b"b-id", b"f"), + ] values = [] for entry in state._iter_entries(): values.append(entry[0] + entry[1][0][:1]) @@ -1348,7 +1720,6 @@ def test_set_state_with_rename_b_a_bug_395556(self): class TestDirStateHashUpdates(TestCaseWithDirState): - def do_update_entry(self, state, path): entry = state._get_entry(0, path_utf8=path) stat = os.lstat(path) @@ -1368,56 +1739,61 @@ def _read_state_content(self, state): return state._state_file.read() def test_worth_saving_limit_avoids_writing(self): - tree = self.make_branch_and_tree('.') - self.build_tree(['c', 'd']) + tree = self.make_branch_and_tree(".") + self.build_tree(["c", "d"]) tree.lock_write() - tree.add(['c', 'd'], ids=[b'c-id', b'd-id']) - tree.commit('add c and d') - state = InstrumentedDirState.on_file(tree.current_dirstate()._filename, - worth_saving_limit=2) + tree.add(["c", "d"], ids=[b"c-id", b"d-id"]) + tree.commit("add c and d") + state = InstrumentedDirState.on_file( + tree.current_dirstate()._filename, worth_saving_limit=2 + ) tree.unlock() state.lock_write() self.addCleanup(state.unlock) state._read_dirblocks_if_needed() state.adjust_time(+20) # Allow things to be cached - self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, - state._dirblock_state) + self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, state._dirblock_state) content = self._read_state_content(state) - self.do_update_entry(state, b'c') + self.do_update_entry(state, b"c") self.assertEqual(1, len(state._known_hash_changes)) - self.assertEqual(dirstate.DirState.IN_MEMORY_HASH_MODIFIED, - state._dirblock_state) + self.assertEqual( + dirstate.DirState.IN_MEMORY_HASH_MODIFIED, state._dirblock_state + ) state.save() # It should not have set the state to IN_MEMORY_UNMODIFIED because the # hash values haven't been written out. 
- self.assertEqual(dirstate.DirState.IN_MEMORY_HASH_MODIFIED, - state._dirblock_state) + self.assertEqual( + dirstate.DirState.IN_MEMORY_HASH_MODIFIED, state._dirblock_state + ) self.assertEqual(content, self._read_state_content(state)) - self.assertEqual(dirstate.DirState.IN_MEMORY_HASH_MODIFIED, - state._dirblock_state) - self.do_update_entry(state, b'd') + self.assertEqual( + dirstate.DirState.IN_MEMORY_HASH_MODIFIED, state._dirblock_state + ) + self.do_update_entry(state, b"d") self.assertEqual(2, len(state._known_hash_changes)) state.save() - self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, - state._dirblock_state) + self.assertEqual(dirstate.DirState.IN_MEMORY_UNMODIFIED, state._dirblock_state) self.assertEqual(0, len(state._known_hash_changes)) class TestGetLines(TestCaseWithDirState): - def test_get_line_with_2_rows(self): state = self.create_dirstate_with_root_and_subdir() try: - self.assertEqual([b'#bazaar dirstate flat format 3\n', - b'crc32: 41262208\n', - b'num_entries: 2\n', - b'0\x00\n\x00' - b'0\x00\n\x00' - b'\x00\x00a-root-value\x00' - b'd\x00\x000\x00n\x00AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk\x00\n\x00' - b'\x00subdir\x00subdir-id\x00' - b'd\x00\x000\x00n\x00AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk\x00\n\x00' - ], state.get_lines()) + self.assertEqual( + [ + b"#bazaar dirstate flat format 3\n", + b"crc32: 41262208\n", + b"num_entries: 2\n", + b"0\x00\n\x00" + b"0\x00\n\x00" + b"\x00\x00a-root-value\x00" + b"d\x00\x000\x00n\x00AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk\x00\n\x00" + b"\x00subdir\x00subdir-id\x00" + b"d\x00\x000\x00n\x00AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk\x00\n\x00", + ], + state.get_lines(), + ) finally: state.unlock() @@ -1425,87 +1801,119 @@ def test_entry_to_line(self): state = self.create_dirstate_with_root() try: self.assertEqual( - b'\x00\x00a-root-value\x00d\x00\x000\x00n' - b'\x00AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk', - state._entry_to_line(state._dirblocks[0][1][0])) + b"\x00\x00a-root-value\x00d\x00\x000\x00n" + b"\x00AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk", + state._entry_to_line(state._dirblocks[0][1][0]), + ) finally: state.unlock() def test_entry_to_line_with_parent(self): - packed_stat = b'AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk' - root_entry = (b'', b'', b'a-root-value'), [ - (b'd', b'', 0, False, packed_stat), # current tree details - # first: a pointer to the current location - (b'a', b'dirname/basename', 0, False, b''), - ] - state = dirstate.DirState.initialize('dirstate') + packed_stat = b"AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk" + root_entry = ( + (b"", b"", b"a-root-value"), + [ + (b"d", b"", 0, False, packed_stat), # current tree details + # first: a pointer to the current location + (b"a", b"dirname/basename", 0, False, b""), + ], + ) + state = dirstate.DirState.initialize("dirstate") try: self.assertEqual( - b'\x00\x00a-root-value\x00' - b'd\x00\x000\x00n\x00AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk\x00' - b'a\x00dirname/basename\x000\x00n\x00', - state._entry_to_line(root_entry)) + b"\x00\x00a-root-value\x00" + b"d\x00\x000\x00n\x00AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk\x00" + b"a\x00dirname/basename\x000\x00n\x00", + state._entry_to_line(root_entry), + ) finally: state.unlock() def test_entry_to_line_with_two_parents_at_different_paths(self): # / in the tree, at / in one parent and /dirname/basename in the other. 
- packed_stat = b'AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk' - root_entry = (b'', b'', b'a-root-value'), [ - (b'd', b'', 0, False, packed_stat), # current tree details - (b'd', b'', 0, False, b'rev_id'), # first parent details - # second: a pointer to the current location - (b'a', b'dirname/basename', 0, False, b''), - ] - state = dirstate.DirState.initialize('dirstate') + packed_stat = b"AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk" + root_entry = ( + (b"", b"", b"a-root-value"), + [ + (b"d", b"", 0, False, packed_stat), # current tree details + (b"d", b"", 0, False, b"rev_id"), # first parent details + # second: a pointer to the current location + (b"a", b"dirname/basename", 0, False, b""), + ], + ) + state = dirstate.DirState.initialize("dirstate") try: self.assertEqual( - b'\x00\x00a-root-value\x00' - b'd\x00\x000\x00n\x00AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk\x00' - b'd\x00\x000\x00n\x00rev_id\x00' - b'a\x00dirname/basename\x000\x00n\x00', - state._entry_to_line(root_entry)) + b"\x00\x00a-root-value\x00" + b"d\x00\x000\x00n\x00AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk\x00" + b"d\x00\x000\x00n\x00rev_id\x00" + b"a\x00dirname/basename\x000\x00n\x00", + state._entry_to_line(root_entry), + ) finally: state.unlock() def test_iter_entries(self): # we should be able to iterate the dirstate entries from end to end # this is for get_lines to be easy to read. - packed_stat = b'AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk' + packed_stat = b"AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk" dirblocks = [] - root_entries = [((b'', b'', b'a-root-value'), [ - (b'd', b'', 0, False, packed_stat), # current tree details - ])] - dirblocks.append(('', root_entries)) + root_entries = [ + ( + (b"", b"", b"a-root-value"), + [ + (b"d", b"", 0, False, packed_stat), # current tree details + ], + ) + ] + dirblocks.append(("", root_entries)) # add two files in the root - subdir_entry = (b'', b'subdir', b'subdir-id'), [ - (b'd', b'', 0, False, packed_stat), # current tree details - ] - afile_entry = (b'', b'afile', b'afile-id'), [ - (b'f', b'sha1value', 34, False, packed_stat), # current tree details - ] - dirblocks.append(('', [subdir_entry, afile_entry])) + subdir_entry = ( + (b"", b"subdir", b"subdir-id"), + [ + (b"d", b"", 0, False, packed_stat), # current tree details + ], + ) + afile_entry = ( + (b"", b"afile", b"afile-id"), + [ + (b"f", b"sha1value", 34, False, packed_stat), # current tree details + ], + ) + dirblocks.append(("", [subdir_entry, afile_entry])) # and one in subdir - file_entry2 = (b'subdir', b'2file', b'2file-id'), [ - (b'f', b'sha1value', 23, False, packed_stat), # current tree details - ] - dirblocks.append(('subdir', [file_entry2])) - state = dirstate.DirState.initialize('dirstate') + file_entry2 = ( + (b"subdir", b"2file", b"2file-id"), + [ + (b"f", b"sha1value", 23, False, packed_stat), # current tree details + ], + ) + dirblocks.append(("subdir", [file_entry2])) + state = dirstate.DirState.initialize("dirstate") try: state._set_data([], dirblocks) - expected_entries = [root_entries[0], subdir_entry, afile_entry, - file_entry2] + expected_entries = [root_entries[0], subdir_entry, afile_entry, file_entry2] self.assertEqual(expected_entries, list(state._iter_entries())) finally: state.unlock() class TestGetBlockRowIndex(TestCaseWithDirState): - - def assertBlockRowIndexEqual(self, block_index, row_index, dir_present, - file_present, state, dirname, basename, tree_index): - self.assertEqual((block_index, row_index, dir_present, file_present), - state._get_block_entry_index(dirname, basename, tree_index)) + def assertBlockRowIndexEqual( + self, + 
block_index, + row_index, + dir_present, + file_present, + state, + dirname, + basename, + tree_index, + ): + self.assertEqual( + (block_index, row_index, dir_present, file_present), + state._get_block_entry_index(dirname, basename, tree_index), + ) if dir_present: block = state._dirblocks[block_index] self.assertEqual(dirname, block[0]) @@ -1517,56 +1925,45 @@ def assertBlockRowIndexEqual(self, block_index, row_index, dir_present, def test_simple_structure(self): state = self.create_dirstate_with_root_and_subdir() self.addCleanup(state.unlock) - self.assertBlockRowIndexEqual( - 1, 0, True, True, state, b'', b'subdir', 0) - self.assertBlockRowIndexEqual( - 1, 0, True, False, state, b'', b'bdir', 0) - self.assertBlockRowIndexEqual( - 1, 1, True, False, state, b'', b'zdir', 0) - self.assertBlockRowIndexEqual( - 2, 0, False, False, state, b'a', b'foo', 0) - self.assertBlockRowIndexEqual(2, 0, False, False, state, - b'subdir', b'foo', 0) + self.assertBlockRowIndexEqual(1, 0, True, True, state, b"", b"subdir", 0) + self.assertBlockRowIndexEqual(1, 0, True, False, state, b"", b"bdir", 0) + self.assertBlockRowIndexEqual(1, 1, True, False, state, b"", b"zdir", 0) + self.assertBlockRowIndexEqual(2, 0, False, False, state, b"a", b"foo", 0) + self.assertBlockRowIndexEqual(2, 0, False, False, state, b"subdir", b"foo", 0) def test_complex_structure_exists(self): state = self.create_complex_dirstate() self.addCleanup(state.unlock) # Make sure we can find everything that exists - self.assertBlockRowIndexEqual(0, 0, True, True, state, b'', b'', 0) - self.assertBlockRowIndexEqual(1, 0, True, True, state, b'', b'a', 0) - self.assertBlockRowIndexEqual(1, 1, True, True, state, b'', b'b', 0) - self.assertBlockRowIndexEqual(1, 2, True, True, state, b'', b'c', 0) - self.assertBlockRowIndexEqual(1, 3, True, True, state, b'', b'd', 0) - self.assertBlockRowIndexEqual(2, 0, True, True, state, b'a', b'e', 0) - self.assertBlockRowIndexEqual(2, 1, True, True, state, b'a', b'f', 0) - self.assertBlockRowIndexEqual(3, 0, True, True, state, b'b', b'g', 0) - self.assertBlockRowIndexEqual(3, 1, True, True, state, - b'b', b'h\xc3\xa5', 0) + self.assertBlockRowIndexEqual(0, 0, True, True, state, b"", b"", 0) + self.assertBlockRowIndexEqual(1, 0, True, True, state, b"", b"a", 0) + self.assertBlockRowIndexEqual(1, 1, True, True, state, b"", b"b", 0) + self.assertBlockRowIndexEqual(1, 2, True, True, state, b"", b"c", 0) + self.assertBlockRowIndexEqual(1, 3, True, True, state, b"", b"d", 0) + self.assertBlockRowIndexEqual(2, 0, True, True, state, b"a", b"e", 0) + self.assertBlockRowIndexEqual(2, 1, True, True, state, b"a", b"f", 0) + self.assertBlockRowIndexEqual(3, 0, True, True, state, b"b", b"g", 0) + self.assertBlockRowIndexEqual(3, 1, True, True, state, b"b", b"h\xc3\xa5", 0) def test_complex_structure_missing(self): state = self.create_complex_dirstate() self.addCleanup(state.unlock) # Make sure things would be inserted in the right locations # '_' comes before 'a' - self.assertBlockRowIndexEqual(0, 0, True, True, state, b'', b'', 0) - self.assertBlockRowIndexEqual(1, 0, True, False, state, b'', b'_', 0) - self.assertBlockRowIndexEqual(1, 1, True, False, state, b'', b'aa', 0) - self.assertBlockRowIndexEqual(1, 4, True, False, state, - b'', b'h\xc3\xa5', 0) - self.assertBlockRowIndexEqual(2, 0, False, False, state, b'_', b'a', 0) - self.assertBlockRowIndexEqual( - 3, 0, False, False, state, b'aa', b'a', 0) - self.assertBlockRowIndexEqual( - 4, 0, False, False, state, b'bb', b'a', 0) + self.assertBlockRowIndexEqual(0, 0, True, 
True, state, b"", b"", 0) + self.assertBlockRowIndexEqual(1, 0, True, False, state, b"", b"_", 0) + self.assertBlockRowIndexEqual(1, 1, True, False, state, b"", b"aa", 0) + self.assertBlockRowIndexEqual(1, 4, True, False, state, b"", b"h\xc3\xa5", 0) + self.assertBlockRowIndexEqual(2, 0, False, False, state, b"_", b"a", 0) + self.assertBlockRowIndexEqual(3, 0, False, False, state, b"aa", b"a", 0) + self.assertBlockRowIndexEqual(4, 0, False, False, state, b"bb", b"a", 0) # This would be inserted between a/ and b/ - self.assertBlockRowIndexEqual( - 3, 0, False, False, state, b'a/e', b'a', 0) + self.assertBlockRowIndexEqual(3, 0, False, False, state, b"a/e", b"a", 0) # Put at the end - self.assertBlockRowIndexEqual(4, 0, False, False, state, b'e', b'a', 0) + self.assertBlockRowIndexEqual(4, 0, False, False, state, b"e", b"a", 0) class TestGetEntry(TestCaseWithDirState): - def assertEntryEqual(self, dirname, basename, file_id, state, path, index): """Check that the right entry is returned for a request to getEntry.""" entry = state._get_entry(index, path_utf8=path) @@ -1579,34 +1976,34 @@ def assertEntryEqual(self, dirname, basename, file_id, state, path, index): def test_simple_structure(self): state = self.create_dirstate_with_root_and_subdir() self.addCleanup(state.unlock) - self.assertEntryEqual(b'', b'', b'a-root-value', state, b'', 0) - self.assertEntryEqual( - b'', b'subdir', b'subdir-id', state, b'subdir', 0) - self.assertEntryEqual(None, None, None, state, b'missing', 0) - self.assertEntryEqual(None, None, None, state, b'missing/foo', 0) - self.assertEntryEqual(None, None, None, state, b'subdir/foo', 0) + self.assertEntryEqual(b"", b"", b"a-root-value", state, b"", 0) + self.assertEntryEqual(b"", b"subdir", b"subdir-id", state, b"subdir", 0) + self.assertEntryEqual(None, None, None, state, b"missing", 0) + self.assertEntryEqual(None, None, None, state, b"missing/foo", 0) + self.assertEntryEqual(None, None, None, state, b"subdir/foo", 0) def test_complex_structure_exists(self): state = self.create_complex_dirstate() self.addCleanup(state.unlock) - self.assertEntryEqual(b'', b'', b'a-root-value', state, b'', 0) - self.assertEntryEqual(b'', b'a', b'a-dir', state, b'a', 0) - self.assertEntryEqual(b'', b'b', b'b-dir', state, b'b', 0) - self.assertEntryEqual(b'', b'c', b'c-file', state, b'c', 0) - self.assertEntryEqual(b'', b'd', b'd-file', state, b'd', 0) - self.assertEntryEqual(b'a', b'e', b'e-dir', state, b'a/e', 0) - self.assertEntryEqual(b'a', b'f', b'f-file', state, b'a/f', 0) - self.assertEntryEqual(b'b', b'g', b'g-file', state, b'b/g', 0) - self.assertEntryEqual(b'b', b'h\xc3\xa5', b'h-\xc3\xa5-file', state, - b'b/h\xc3\xa5', 0) + self.assertEntryEqual(b"", b"", b"a-root-value", state, b"", 0) + self.assertEntryEqual(b"", b"a", b"a-dir", state, b"a", 0) + self.assertEntryEqual(b"", b"b", b"b-dir", state, b"b", 0) + self.assertEntryEqual(b"", b"c", b"c-file", state, b"c", 0) + self.assertEntryEqual(b"", b"d", b"d-file", state, b"d", 0) + self.assertEntryEqual(b"a", b"e", b"e-dir", state, b"a/e", 0) + self.assertEntryEqual(b"a", b"f", b"f-file", state, b"a/f", 0) + self.assertEntryEqual(b"b", b"g", b"g-file", state, b"b/g", 0) + self.assertEntryEqual( + b"b", b"h\xc3\xa5", b"h-\xc3\xa5-file", state, b"b/h\xc3\xa5", 0 + ) def test_complex_structure_missing(self): state = self.create_complex_dirstate() self.addCleanup(state.unlock) - self.assertEntryEqual(None, None, None, state, b'_', 0) - self.assertEntryEqual(None, None, None, state, b'_\xc3\xa5', 0) - self.assertEntryEqual(None, 
None, None, state, b'a/b', 0) - self.assertEntryEqual(None, None, None, state, b'c/d', 0) + self.assertEntryEqual(None, None, None, state, b"_", 0) + self.assertEntryEqual(None, None, None, state, b"_\xc3\xa5", 0) + self.assertEntryEqual(None, None, None, state, b"a/b", 0) + self.assertEntryEqual(None, None, None, state, b"c/d", 0) def test_get_entry_uninitialized(self): """Calling get_entry will load data if it needs to.""" @@ -1616,20 +2013,17 @@ def test_get_entry_uninitialized(self): finally: state.unlock() del state - state = dirstate.DirState.on_file('dirstate') + state = dirstate.DirState.on_file("dirstate") state.lock_read() try: - self.assertEqual(dirstate.DirState.NOT_IN_MEMORY, - state._header_state) - self.assertEqual(dirstate.DirState.NOT_IN_MEMORY, - state._dirblock_state) - self.assertEntryEqual(b'', b'', b'a-root-value', state, b'', 0) + self.assertEqual(dirstate.DirState.NOT_IN_MEMORY, state._header_state) + self.assertEqual(dirstate.DirState.NOT_IN_MEMORY, state._dirblock_state) + self.assertEntryEqual(b"", b"", b"a-root-value", state, b"", 0) finally: state.unlock() class TestIterChildEntries(TestCaseWithDirState): - def create_dirstate_with_two_trees(self): r"""This dirstate contains multiple files and directories. @@ -1653,67 +2047,102 @@ def create_dirstate_with_two_trees(self): :return: The dirstate, still write-locked. """ - packed_stat = b'AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk' - null_sha = b'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx' + packed_stat = b"AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk" + null_sha = b"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" NULL_PARENT_DETAILS = dirstate.DirState.NULL_PARENT_DETAILS - root_entry = (b'', b'', b'a-root-value'), [ - (b'd', b'', 0, False, packed_stat), - (b'd', b'', 0, False, b'parent-revid'), - ] - a_entry = (b'', b'a', b'a-dir'), [ - (b'd', b'', 0, False, packed_stat), - (b'd', b'', 0, False, b'parent-revid'), - ] - b_entry = (b'', b'b', b'b-dir'), [ - (b'd', b'', 0, False, packed_stat), - (b'd', b'', 0, False, b'parent-revid'), - ] - c_entry = (b'', b'c', b'c-file'), [ - (b'f', null_sha, 10, False, packed_stat), - (b'r', b'b/j', 0, False, b''), - ] - d_entry = (b'', b'd', b'd-file'), [ - (b'f', null_sha, 20, False, packed_stat), - (b'f', b'd', 20, False, b'parent-revid'), - ] - e_entry = (b'a', b'e', b'e-dir'), [ - (b'd', b'', 0, False, packed_stat), - (b'd', b'', 0, False, b'parent-revid'), - ] - f_entry = (b'a', b'f', b'f-file'), [ - (b'f', null_sha, 30, False, packed_stat), - (b'f', b'f', 20, False, b'parent-revid'), - ] - g_entry = (b'b', b'g', b'g-file'), [ - (b'f', null_sha, 30, False, packed_stat), - NULL_PARENT_DETAILS, - ] - h_entry1 = (b'b', b'h\xc3\xa5', b'h-\xc3\xa5-file1'), [ - (b'f', null_sha, 40, False, packed_stat), - NULL_PARENT_DETAILS, - ] - h_entry2 = (b'b', b'h\xc3\xa5', b'h-\xc3\xa5-file2'), [ - NULL_PARENT_DETAILS, - (b'f', b'h', 20, False, b'parent-revid'), - ] - i_entry = (b'b', b'i', b'i-file'), [ - NULL_PARENT_DETAILS, - (b'f', b'h', 20, False, b'parent-revid'), - ] - j_entry = (b'b', b'j', b'c-file'), [ - (b'r', b'c', 0, False, b''), - (b'f', b'j', 20, False, b'parent-revid'), - ] + root_entry = ( + (b"", b"", b"a-root-value"), + [ + (b"d", b"", 0, False, packed_stat), + (b"d", b"", 0, False, b"parent-revid"), + ], + ) + a_entry = ( + (b"", b"a", b"a-dir"), + [ + (b"d", b"", 0, False, packed_stat), + (b"d", b"", 0, False, b"parent-revid"), + ], + ) + b_entry = ( + (b"", b"b", b"b-dir"), + [ + (b"d", b"", 0, False, packed_stat), + (b"d", b"", 0, False, b"parent-revid"), + ], + ) + c_entry = ( + (b"", b"c", b"c-file"), + [ + 
(b"f", null_sha, 10, False, packed_stat), + (b"r", b"b/j", 0, False, b""), + ], + ) + d_entry = ( + (b"", b"d", b"d-file"), + [ + (b"f", null_sha, 20, False, packed_stat), + (b"f", b"d", 20, False, b"parent-revid"), + ], + ) + e_entry = ( + (b"a", b"e", b"e-dir"), + [ + (b"d", b"", 0, False, packed_stat), + (b"d", b"", 0, False, b"parent-revid"), + ], + ) + f_entry = ( + (b"a", b"f", b"f-file"), + [ + (b"f", null_sha, 30, False, packed_stat), + (b"f", b"f", 20, False, b"parent-revid"), + ], + ) + g_entry = ( + (b"b", b"g", b"g-file"), + [ + (b"f", null_sha, 30, False, packed_stat), + NULL_PARENT_DETAILS, + ], + ) + h_entry1 = ( + (b"b", b"h\xc3\xa5", b"h-\xc3\xa5-file1"), + [ + (b"f", null_sha, 40, False, packed_stat), + NULL_PARENT_DETAILS, + ], + ) + h_entry2 = ( + (b"b", b"h\xc3\xa5", b"h-\xc3\xa5-file2"), + [ + NULL_PARENT_DETAILS, + (b"f", b"h", 20, False, b"parent-revid"), + ], + ) + i_entry = ( + (b"b", b"i", b"i-file"), + [ + NULL_PARENT_DETAILS, + (b"f", b"h", 20, False, b"parent-revid"), + ], + ) + j_entry = ( + (b"b", b"j", b"c-file"), + [ + (b"r", b"c", 0, False, b""), + (b"f", b"j", 20, False, b"parent-revid"), + ], + ) dirblocks = [] - dirblocks.append((b'', [root_entry])) - dirblocks.append((b'', [a_entry, b_entry, c_entry, d_entry])) - dirblocks.append((b'a', [e_entry, f_entry])) - dirblocks.append( - (b'b', [g_entry, h_entry1, h_entry2, i_entry, j_entry])) - state = dirstate.DirState.initialize('dirstate') + dirblocks.append((b"", [root_entry])) + dirblocks.append((b"", [a_entry, b_entry, c_entry, d_entry])) + dirblocks.append((b"a", [e_entry, f_entry])) + dirblocks.append((b"b", [g_entry, h_entry1, h_entry2, i_entry, j_entry])) + state = dirstate.DirState.initialize("dirstate") state._validate() try: - state._set_data([b'parent'], dirblocks) + state._set_data([b"parent"], dirblocks) except: state.unlock() raise @@ -1726,8 +2155,7 @@ def test_iter_children_b(self): expected_result.append(dirblocks[3][1][2]) # h2 expected_result.append(dirblocks[3][1][3]) # i expected_result.append(dirblocks[3][1][4]) # j - self.assertEqual(expected_result, - list(state._iter_child_entries(1, b'b'))) + self.assertEqual(expected_result, list(state._iter_child_entries(1, b"b"))) def test_iter_child_root(self): state, dirblocks = self.create_dirstate_with_two_trees() @@ -1741,8 +2169,7 @@ def test_iter_child_root(self): expected_result.append(dirblocks[3][1][2]) # h2 expected_result.append(dirblocks[3][1][3]) # i expected_result.append(dirblocks[3][1][4]) # j - self.assertEqual(expected_result, - list(state._iter_child_entries(1, b''))) + self.assertEqual(expected_result, list(state._iter_child_entries(1, b""))) class TestDirstateSortOrder(tests.TestCaseWithTransport): @@ -1755,54 +2182,71 @@ def test_add_sorting(self): at a single depth. 'a/a' should come before 'a-a', even though it doesn't lexicographically. 
""" - dirs = ['a', 'a/a', 'a/a/a', 'a/a/a/a', - 'a-a', 'a/a-a', 'a/a/a-a', 'a/a/a/a-a', - ] - null_sha = b'' - state = dirstate.DirState.initialize('dirstate') + dirs = [ + "a", + "a/a", + "a/a/a", + "a/a/a/a", + "a-a", + "a/a-a", + "a/a/a-a", + "a/a/a/a-a", + ] + null_sha = b"" + state = dirstate.DirState.initialize("dirstate") self.addCleanup(state.unlock) - fake_stat = os.stat('dirstate') + fake_stat = os.stat("dirstate") for d in dirs: - d_id = d.encode('utf-8').replace(b'/', b'_') + b'-id' - file_path = d + '/f' - file_id = file_path.encode('utf-8').replace(b'/', b'_') + b'-id' - state.add(d, d_id, 'directory', fake_stat, null_sha) - state.add(file_path, file_id, 'file', fake_stat, null_sha) - - expected = [b'', b'', b'a', - b'a/a', b'a/a/a', b'a/a/a/a', - b'a/a/a/a-a', b'a/a/a-a', b'a/a-a', b'a-a', - ] - - def split(p): return p.split(b'/') + d_id = d.encode("utf-8").replace(b"/", b"_") + b"-id" + file_path = d + "/f" + file_id = file_path.encode("utf-8").replace(b"/", b"_") + b"-id" + state.add(d, d_id, "directory", fake_stat, null_sha) + state.add(file_path, file_id, "file", fake_stat, null_sha) + + expected = [ + b"", + b"", + b"a", + b"a/a", + b"a/a/a", + b"a/a/a/a", + b"a/a/a/a-a", + b"a/a/a-a", + b"a/a-a", + b"a-a", + ] + + def split(p): + return p.split(b"/") + self.assertEqual(sorted(expected, key=split), expected) dirblock_names = [d[0] for d in state._dirblocks] self.assertEqual(expected, dirblock_names) def test_set_parent_trees_correct_order(self): """After calling set_parent_trees() we should maintain the order.""" - dirs = ['a', 'a-a', 'a/a'] - null_sha = b'' - state = dirstate.DirState.initialize('dirstate') + dirs = ["a", "a-a", "a/a"] + null_sha = b"" + state = dirstate.DirState.initialize("dirstate") self.addCleanup(state.unlock) - fake_stat = os.stat('dirstate') + fake_stat = os.stat("dirstate") for d in dirs: - d_id = d.encode('utf-8').replace(b'/', b'_') + b'-id' - file_path = d + '/f' - file_id = file_path.encode('utf-8').replace(b'/', b'_') + b'-id' - state.add(d, d_id, 'directory', fake_stat, null_sha) - state.add(file_path, file_id, 'file', fake_stat, null_sha) + d_id = d.encode("utf-8").replace(b"/", b"_") + b"-id" + file_path = d + "/f" + file_id = file_path.encode("utf-8").replace(b"/", b"_") + b"-id" + state.add(d, d_id, "directory", fake_stat, null_sha) + state.add(file_path, file_id, "file", fake_stat, null_sha) - expected = [b'', b'', b'a', b'a/a', b'a-a'] + expected = [b"", b"", b"a", b"a/a", b"a-a"] dirblock_names = [d[0] for d in state._dirblocks] self.assertEqual(expected, dirblock_names) # *really* cheesy way to just get an empty tree - repo = self.make_repository('repo') + repo = self.make_repository("repo") empty_tree = repo.revision_tree(_mod_revision.NULL_REVISION) - state.set_parent_trees([('null:', empty_tree)], []) + state.set_parent_trees([("null:", empty_tree)], []) dirblock_names = [d[0] for d in state._dirblocks] self.assertEqual(expected, dirblock_names) @@ -1811,11 +2255,15 @@ def test_set_parent_trees_correct_order(self): class InstrumentedDirState(dirstate.DirState): """An DirState with instrumented sha1 functionality.""" - def __init__(self, path, sha1_provider, worth_saving_limit=0, - use_filesystem_for_exec=True): + def __init__( + self, path, sha1_provider, worth_saving_limit=0, use_filesystem_for_exec=True + ): super().__init__( - path, sha1_provider, worth_saving_limit=worth_saving_limit, - use_filesystem_for_exec=use_filesystem_for_exec) + path, + sha1_provider, + worth_saving_limit=worth_saving_limit, + 
use_filesystem_for_exec=use_filesystem_for_exec, + ) self._time_offset = 0 self._log = [] # member is dynamically set in DirState.__init__ to turn on trace @@ -1827,21 +2275,20 @@ def _sha_cutoff_time(self): self._cutoff_time = timestamp + self._time_offset def _sha1_file_and_log(self, abspath): - self._log.append(('sha1', abspath)) + self._log.append(("sha1", abspath)) return self._sha1_provider.sha1(abspath) def _read_link(self, abspath, old_link): - self._log.append(('read_link', abspath, old_link)) + self._log.append(("read_link", abspath, old_link)) return super()._read_link(abspath, old_link) def _lstat(self, abspath, entry): - self._log.append(('lstat', abspath)) + self._log.append(("lstat", abspath)) return super()._lstat(abspath, entry) def _is_executable(self, mode, old_executable): - self._log.append(('is_exec', mode, old_executable)) - return super()._is_executable(mode, - old_executable) + self._log.append(("is_exec", mode, old_executable)) + return super()._is_executable(mode, old_executable) def adjust_time(self, secs): """Move the clock forward or back. @@ -1867,12 +2314,12 @@ def __init__(self, size, mtime, ctime, dev, ino, mode): @staticmethod def from_stat(st): - return _FakeStat(st.st_size, st.st_mtime, st.st_ctime, st.st_dev, - st.st_ino, st.st_mode) + return _FakeStat( + st.st_size, st.st_mtime, st.st_ctime, st.st_dev, st.st_ino, st.st_mode + ) class TestPackStat(tests.TestCaseWithTransport): - def assertPackStat(self, expected, stat_value): """Check the packed and serialized form of a stat value.""" self.assertEqual(expected, dirstate.pack_stat(stat_value)) @@ -1880,25 +2327,25 @@ def assertPackStat(self, expected, stat_value): def test_pack_stat_int(self): st = _FakeStat(6859, 1172758614, 1172758617, 777, 6499538, 0o100644) # Make sure that all parameters have an impact on the packed stat. - self.assertPackStat(b'AAAay0Xm4FZF5uBZAAADCQBjLNIAAIGk', st) + self.assertPackStat(b"AAAay0Xm4FZF5uBZAAADCQBjLNIAAIGk", st) st.st_size = 7000 # ay0 => bWE - self.assertPackStat(b'AAAbWEXm4FZF5uBZAAADCQBjLNIAAIGk', st) + self.assertPackStat(b"AAAbWEXm4FZF5uBZAAADCQBjLNIAAIGk", st) st.st_mtime = 1172758620 # 4FZ => 4Fx - self.assertPackStat(b'AAAbWEXm4FxF5uBZAAADCQBjLNIAAIGk', st) + self.assertPackStat(b"AAAbWEXm4FxF5uBZAAADCQBjLNIAAIGk", st) st.st_ctime = 1172758630 # uBZ => uBm - self.assertPackStat(b'AAAbWEXm4FxF5uBmAAADCQBjLNIAAIGk', st) + self.assertPackStat(b"AAAbWEXm4FxF5uBmAAADCQBjLNIAAIGk", st) st.st_dev = 888 # DCQ => DeA - self.assertPackStat(b'AAAbWEXm4FxF5uBmAAADeABjLNIAAIGk', st) + self.assertPackStat(b"AAAbWEXm4FxF5uBmAAADeABjLNIAAIGk", st) st.st_ino = 6499540 # LNI => LNQ - self.assertPackStat(b'AAAbWEXm4FxF5uBmAAADeABjLNQAAIGk', st) + self.assertPackStat(b"AAAbWEXm4FxF5uBmAAADeABjLNQAAIGk", st) st.st_mode = 0o100744 # IGk => IHk - self.assertPackStat(b'AAAbWEXm4FxF5uBmAAADeABjLNQAAIHk', st) + self.assertPackStat(b"AAAbWEXm4FxF5uBmAAADeABjLNQAAIHk", st) def test_pack_stat_float(self): """On some platforms mtime and ctime are floats. 
@@ -1906,21 +2353,20 @@ def test_pack_stat_float(self): Make sure we don't get warnings or errors, and that we ignore changes < 1s """ - st = _FakeStat(7000, 1172758614.0, 1172758617.0, - 777, 6499538, 0o100644) + st = _FakeStat(7000, 1172758614.0, 1172758617.0, 777, 6499538, 0o100644) # These should all be the same as the integer counterparts - self.assertPackStat(b'AAAbWEXm4FZF5uBZAAADCQBjLNIAAIGk', st) + self.assertPackStat(b"AAAbWEXm4FZF5uBZAAADCQBjLNIAAIGk", st) st.st_mtime = 1172758620.0 # FZF5 => FxF5 - self.assertPackStat(b'AAAbWEXm4FxF5uBZAAADCQBjLNIAAIGk', st) + self.assertPackStat(b"AAAbWEXm4FxF5uBZAAADCQBjLNIAAIGk", st) st.st_ctime = 1172758630.0 # uBZ => uBm - self.assertPackStat(b'AAAbWEXm4FxF5uBmAAADCQBjLNIAAIGk', st) + self.assertPackStat(b"AAAbWEXm4FxF5uBmAAADCQBjLNIAAIGk", st) # fractional seconds are discarded, so no change from above st.st_mtime = 1172758620.453 - self.assertPackStat(b'AAAbWEXm4FxF5uBmAAADCQBjLNIAAIGk', st) + self.assertPackStat(b"AAAbWEXm4FxF5uBmAAADCQBjLNIAAIGk", st) st.st_ctime = 1172758630.228 - self.assertPackStat(b'AAAbWEXm4FxF5uBmAAADCQBjLNIAAIGk', st) + self.assertPackStat(b"AAAbWEXm4FxF5uBmAAADCQBjLNIAAIGk", st) class TestBisect(TestCaseWithDirState): @@ -2004,168 +2450,183 @@ def test_bisect_each(self): tree, state, expected = self.create_basic_dirstate() # Bisect should return the rows for the specified files. - self.assertBisect(expected, [[b'']], state, [b'']) - self.assertBisect(expected, [[b'a']], state, [b'a']) - self.assertBisect(expected, [[b'b']], state, [b'b']) - self.assertBisect(expected, [[b'b/c']], state, [b'b/c']) - self.assertBisect(expected, [[b'b/d']], state, [b'b/d']) - self.assertBisect(expected, [[b'b/d/e']], state, [b'b/d/e']) - self.assertBisect(expected, [[b'b-c']], state, [b'b-c']) - self.assertBisect(expected, [[b'f']], state, [b'f']) + self.assertBisect(expected, [[b""]], state, [b""]) + self.assertBisect(expected, [[b"a"]], state, [b"a"]) + self.assertBisect(expected, [[b"b"]], state, [b"b"]) + self.assertBisect(expected, [[b"b/c"]], state, [b"b/c"]) + self.assertBisect(expected, [[b"b/d"]], state, [b"b/d"]) + self.assertBisect(expected, [[b"b/d/e"]], state, [b"b/d/e"]) + self.assertBisect(expected, [[b"b-c"]], state, [b"b-c"]) + self.assertBisect(expected, [[b"f"]], state, [b"f"]) def test_bisect_multi(self): """Bisect can be used to find multiple records at the same time.""" tree, state, expected = self.create_basic_dirstate() # Bisect should be capable of finding multiple entries at the same time - self.assertBisect(expected, [[b'a'], [b'b'], [b'f']], - state, [b'a', b'b', b'f']) - self.assertBisect(expected, [[b'f'], [b'b/d'], [b'b/d/e']], - state, [b'f', b'b/d', b'b/d/e']) - self.assertBisect(expected, [[b'b'], [b'b-c'], [b'b/c']], - state, [b'b', b'b-c', b'b/c']) + self.assertBisect(expected, [[b"a"], [b"b"], [b"f"]], state, [b"a", b"b", b"f"]) + self.assertBisect( + expected, [[b"f"], [b"b/d"], [b"b/d/e"]], state, [b"f", b"b/d", b"b/d/e"] + ) + self.assertBisect( + expected, [[b"b"], [b"b-c"], [b"b/c"]], state, [b"b", b"b-c", b"b/c"] + ) def test_bisect_one_page(self): """Test bisect when there is only 1 page to read.""" tree, state, expected = self.create_basic_dirstate() state._bisect_page_size = 5000 - self.assertBisect(expected, [[b'']], state, [b'']) - self.assertBisect(expected, [[b'a']], state, [b'a']) - self.assertBisect(expected, [[b'b']], state, [b'b']) - self.assertBisect(expected, [[b'b/c']], state, [b'b/c']) - self.assertBisect(expected, [[b'b/d']], state, [b'b/d']) - self.assertBisect(expected, 
[[b'b/d/e']], state, [b'b/d/e']) - self.assertBisect(expected, [[b'b-c']], state, [b'b-c']) - self.assertBisect(expected, [[b'f']], state, [b'f']) - self.assertBisect(expected, [[b'a'], [b'b'], [b'f']], - state, [b'a', b'b', b'f']) - self.assertBisect(expected, [[b'b/d'], [b'b/d/e'], [b'f']], - state, [b'b/d', b'b/d/e', b'f']) - self.assertBisect(expected, [[b'b'], [b'b/c'], [b'b-c']], - state, [b'b', b'b/c', b'b-c']) + self.assertBisect(expected, [[b""]], state, [b""]) + self.assertBisect(expected, [[b"a"]], state, [b"a"]) + self.assertBisect(expected, [[b"b"]], state, [b"b"]) + self.assertBisect(expected, [[b"b/c"]], state, [b"b/c"]) + self.assertBisect(expected, [[b"b/d"]], state, [b"b/d"]) + self.assertBisect(expected, [[b"b/d/e"]], state, [b"b/d/e"]) + self.assertBisect(expected, [[b"b-c"]], state, [b"b-c"]) + self.assertBisect(expected, [[b"f"]], state, [b"f"]) + self.assertBisect(expected, [[b"a"], [b"b"], [b"f"]], state, [b"a", b"b", b"f"]) + self.assertBisect( + expected, [[b"b/d"], [b"b/d/e"], [b"f"]], state, [b"b/d", b"b/d/e", b"f"] + ) + self.assertBisect( + expected, [[b"b"], [b"b/c"], [b"b-c"]], state, [b"b", b"b/c", b"b-c"] + ) def test_bisect_duplicate_paths(self): """When bisecting for a path, handle multiple entries.""" tree, state, expected = self.create_duplicated_dirstate() # Now make sure that both records are properly returned. - self.assertBisect(expected, [[b'']], state, [b'']) - self.assertBisect(expected, [[b'a', b'a2']], state, [b'a']) - self.assertBisect(expected, [[b'b', b'b2']], state, [b'b']) - self.assertBisect(expected, [[b'b/c', b'b/c2']], state, [b'b/c']) - self.assertBisect(expected, [[b'b/d', b'b/d2']], state, [b'b/d']) - self.assertBisect(expected, [[b'b/d/e', b'b/d/e2']], - state, [b'b/d/e']) - self.assertBisect(expected, [[b'b-c', b'b-c2']], state, [b'b-c']) - self.assertBisect(expected, [[b'f', b'f2']], state, [b'f']) + self.assertBisect(expected, [[b""]], state, [b""]) + self.assertBisect(expected, [[b"a", b"a2"]], state, [b"a"]) + self.assertBisect(expected, [[b"b", b"b2"]], state, [b"b"]) + self.assertBisect(expected, [[b"b/c", b"b/c2"]], state, [b"b/c"]) + self.assertBisect(expected, [[b"b/d", b"b/d2"]], state, [b"b/d"]) + self.assertBisect(expected, [[b"b/d/e", b"b/d/e2"]], state, [b"b/d/e"]) + self.assertBisect(expected, [[b"b-c", b"b-c2"]], state, [b"b-c"]) + self.assertBisect(expected, [[b"f", b"f2"]], state, [b"f"]) def test_bisect_page_size_too_small(self): """If the page size is too small, we will auto increase it.""" tree, state, expected = self.create_basic_dirstate() state._bisect_page_size = 50 - self.assertBisect(expected, [None], state, [b'b/e']) - self.assertBisect(expected, [[b'a']], state, [b'a']) - self.assertBisect(expected, [[b'b']], state, [b'b']) - self.assertBisect(expected, [[b'b/c']], state, [b'b/c']) - self.assertBisect(expected, [[b'b/d']], state, [b'b/d']) - self.assertBisect(expected, [[b'b/d/e']], state, [b'b/d/e']) - self.assertBisect(expected, [[b'b-c']], state, [b'b-c']) - self.assertBisect(expected, [[b'f']], state, [b'f']) + self.assertBisect(expected, [None], state, [b"b/e"]) + self.assertBisect(expected, [[b"a"]], state, [b"a"]) + self.assertBisect(expected, [[b"b"]], state, [b"b"]) + self.assertBisect(expected, [[b"b/c"]], state, [b"b/c"]) + self.assertBisect(expected, [[b"b/d"]], state, [b"b/d"]) + self.assertBisect(expected, [[b"b/d/e"]], state, [b"b/d/e"]) + self.assertBisect(expected, [[b"b-c"]], state, [b"b-c"]) + self.assertBisect(expected, [[b"f"]], state, [b"f"]) def test_bisect_missing(self): 
"""Test that bisect return None if it cannot find a path.""" tree, state, expected = self.create_basic_dirstate() - self.assertBisect(expected, [None], state, [b'foo']) - self.assertBisect(expected, [None], state, [b'b/foo']) - self.assertBisect(expected, [None], state, [b'bar/foo']) - self.assertBisect(expected, [None], state, [b'b-c/foo']) + self.assertBisect(expected, [None], state, [b"foo"]) + self.assertBisect(expected, [None], state, [b"b/foo"]) + self.assertBisect(expected, [None], state, [b"bar/foo"]) + self.assertBisect(expected, [None], state, [b"b-c/foo"]) - self.assertBisect(expected, [[b'a'], None, [b'b/d']], - state, [b'a', b'foo', b'b/d']) + self.assertBisect( + expected, [[b"a"], None, [b"b/d"]], state, [b"a", b"foo", b"b/d"] + ) def test_bisect_rename(self): """Check that we find a renamed row.""" tree, state, expected = self.create_renamed_dirstate() # Search for the pre and post renamed entries - self.assertBisect(expected, [[b'a']], state, [b'a']) - self.assertBisect(expected, [[b'b/g']], state, [b'b/g']) - self.assertBisect(expected, [[b'b/d']], state, [b'b/d']) - self.assertBisect(expected, [[b'h']], state, [b'h']) + self.assertBisect(expected, [[b"a"]], state, [b"a"]) + self.assertBisect(expected, [[b"b/g"]], state, [b"b/g"]) + self.assertBisect(expected, [[b"b/d"]], state, [b"b/d"]) + self.assertBisect(expected, [[b"h"]], state, [b"h"]) # What about b/d/e? shouldn't that also get 2 directory entries? - self.assertBisect(expected, [[b'b/d/e']], state, [b'b/d/e']) - self.assertBisect(expected, [[b'h/e']], state, [b'h/e']) + self.assertBisect(expected, [[b"b/d/e"]], state, [b"b/d/e"]) + self.assertBisect(expected, [[b"h/e"]], state, [b"h/e"]) def test_bisect_dirblocks(self): tree, state, expected = self.create_duplicated_dirstate() - self.assertBisectDirBlocks(expected, - [[b'', b'a', b'a2', b'b', b'b2', - b'b-c', b'b-c2', b'f', b'f2']], - state, [b'']) - self.assertBisectDirBlocks(expected, - [[b'b/c', b'b/c2', b'b/d', b'b/d2']], state, [b'b']) - self.assertBisectDirBlocks(expected, - [[b'b/d/e', b'b/d/e2']], state, [b'b/d']) - self.assertBisectDirBlocks(expected, - [[b'', b'a', b'a2', b'b', b'b2', b'b-c', b'b-c2', b'f', b'f2'], - [b'b/c', b'b/c2', b'b/d', b'b/d2'], - [b'b/d/e', b'b/d/e2'], - ], state, [b'', b'b', b'b/d']) + self.assertBisectDirBlocks( + expected, + [[b"", b"a", b"a2", b"b", b"b2", b"b-c", b"b-c2", b"f", b"f2"]], + state, + [b""], + ) + self.assertBisectDirBlocks( + expected, [[b"b/c", b"b/c2", b"b/d", b"b/d2"]], state, [b"b"] + ) + self.assertBisectDirBlocks(expected, [[b"b/d/e", b"b/d/e2"]], state, [b"b/d"]) + self.assertBisectDirBlocks( + expected, + [ + [b"", b"a", b"a2", b"b", b"b2", b"b-c", b"b-c2", b"f", b"f2"], + [b"b/c", b"b/c2", b"b/d", b"b/d2"], + [b"b/d/e", b"b/d/e2"], + ], + state, + [b"", b"b", b"b/d"], + ) def test_bisect_dirblocks_missing(self): tree, state, expected = self.create_basic_dirstate() - self.assertBisectDirBlocks(expected, [[b'b/d/e'], None], - state, [b'b/d', b'b/e']) + self.assertBisectDirBlocks( + expected, [[b"b/d/e"], None], state, [b"b/d", b"b/e"] + ) # Files don't show up in this search - self.assertBisectDirBlocks(expected, [None], state, [b'a']) - self.assertBisectDirBlocks(expected, [None], state, [b'b/c']) - self.assertBisectDirBlocks(expected, [None], state, [b'c']) - self.assertBisectDirBlocks(expected, [None], state, [b'b/d/e']) - self.assertBisectDirBlocks(expected, [None], state, [b'f']) + self.assertBisectDirBlocks(expected, [None], state, [b"a"]) + self.assertBisectDirBlocks(expected, [None], state, 
[b"b/c"]) + self.assertBisectDirBlocks(expected, [None], state, [b"c"]) + self.assertBisectDirBlocks(expected, [None], state, [b"b/d/e"]) + self.assertBisectDirBlocks(expected, [None], state, [b"f"]) def test_bisect_recursive_each(self): tree, state, expected = self.create_basic_dirstate() - self.assertBisectRecursive(expected, [b'a'], state, [b'a']) - self.assertBisectRecursive(expected, [b'b/c'], state, [b'b/c']) - self.assertBisectRecursive(expected, [b'b/d/e'], state, [b'b/d/e']) - self.assertBisectRecursive(expected, [b'b-c'], state, [b'b-c']) - self.assertBisectRecursive(expected, [b'b/d', b'b/d/e'], - state, [b'b/d']) - self.assertBisectRecursive(expected, [b'b', b'b/c', b'b/d', b'b/d/e'], - state, [b'b']) - self.assertBisectRecursive(expected, [b'', b'a', b'b', b'b-c', b'f', b'b/c', - b'b/d', b'b/d/e'], - state, [b'']) + self.assertBisectRecursive(expected, [b"a"], state, [b"a"]) + self.assertBisectRecursive(expected, [b"b/c"], state, [b"b/c"]) + self.assertBisectRecursive(expected, [b"b/d/e"], state, [b"b/d/e"]) + self.assertBisectRecursive(expected, [b"b-c"], state, [b"b-c"]) + self.assertBisectRecursive(expected, [b"b/d", b"b/d/e"], state, [b"b/d"]) + self.assertBisectRecursive( + expected, [b"b", b"b/c", b"b/d", b"b/d/e"], state, [b"b"] + ) + self.assertBisectRecursive( + expected, + [b"", b"a", b"b", b"b-c", b"f", b"b/c", b"b/d", b"b/d/e"], + state, + [b""], + ) def test_bisect_recursive_multiple(self): tree, state, expected = self.create_basic_dirstate() + self.assertBisectRecursive(expected, [b"a", b"b/c"], state, [b"a", b"b/c"]) self.assertBisectRecursive( - expected, [b'a', b'b/c'], state, [b'a', b'b/c']) - self.assertBisectRecursive(expected, [b'b/d', b'b/d/e'], - state, [b'b/d', b'b/d/e']) + expected, [b"b/d", b"b/d/e"], state, [b"b/d", b"b/d/e"] + ) def test_bisect_recursive_missing(self): tree, state, expected = self.create_basic_dirstate() - self.assertBisectRecursive(expected, [], state, [b'd']) - self.assertBisectRecursive(expected, [], state, [b'b/e']) - self.assertBisectRecursive(expected, [], state, [b'g']) - self.assertBisectRecursive(expected, [b'a'], state, [b'a', b'g']) + self.assertBisectRecursive(expected, [], state, [b"d"]) + self.assertBisectRecursive(expected, [], state, [b"b/e"]) + self.assertBisectRecursive(expected, [], state, [b"g"]) + self.assertBisectRecursive(expected, [b"a"], state, [b"a", b"g"]) def test_bisect_recursive_renamed(self): tree, state, expected = self.create_renamed_dirstate() # Looking for either renamed item should find the other - self.assertBisectRecursive(expected, [b'a', b'b/g'], state, [b'a']) - self.assertBisectRecursive(expected, [b'a', b'b/g'], state, [b'b/g']) + self.assertBisectRecursive(expected, [b"a", b"b/g"], state, [b"a"]) + self.assertBisectRecursive(expected, [b"a", b"b/g"], state, [b"b/g"]) # Looking in the containing directory should find the rename target, # and anything in a subdir of the renamed target. 
- self.assertBisectRecursive(expected, [b'a', b'b', b'b/c', b'b/d', - b'b/d/e', b'b/g', b'h', b'h/e'], - state, [b'b']) + self.assertBisectRecursive( + expected, + [b"a", b"b", b"b/c", b"b/d", b"b/d/e", b"b/g", b"h", b"h/e"], + state, + [b"b"], + ) class TestDirstateValidation(TestCaseWithDirState): - def test_validate_correct_dirstate(self): state = self.create_complex_dirstate() state._validate() @@ -2184,12 +2645,13 @@ def test_dirblock_not_sorted(self): # we're appending to the dirblock, but this name comes before some of # the existing names; that's wrong last_dirblock[1].append( - ((b'h', b'aaaa', b'a-id'), - [(b'a', b'', 0, False, b''), - (b'a', b'', 0, False, b'')])) - e = self.assertRaises(AssertionError, - state._validate) - self.assertContainsRe(str(e), 'not sorted') + ( + (b"h", b"aaaa", b"a-id"), + [(b"a", b"", 0, False, b""), (b"a", b"", 0, False, b"")], + ) + ) + e = self.assertRaises(AssertionError, state._validate) + self.assertContainsRe(str(e), "not sorted") def test_dirblock_name_mismatch(self): tree, state, expected = self.create_renamed_dirstate() @@ -2197,13 +2659,13 @@ def test_dirblock_name_mismatch(self): last_dirblock = state._dirblocks[-1] # add an entry with the wrong directory name last_dirblock[1].append( - ((b'', b'z', b'a-id'), - [(b'a', b'', 0, False, b''), - (b'a', b'', 0, False, b'')])) - e = self.assertRaises(AssertionError, - state._validate) - self.assertContainsRe(str(e), - "doesn't match directory name") + ( + (b"", b"z", b"a-id"), + [(b"a", b"", 0, False, b""), (b"a", b"", 0, False, b"")], + ) + ) + e = self.assertRaises(AssertionError, state._validate) + self.assertContainsRe(str(e), "doesn't match directory name") def test_dirblock_missing_rename(self): tree, state, expected = self.create_renamed_dirstate() @@ -2212,28 +2674,30 @@ def test_dirblock_missing_rename(self): # make another entry for a-id, without a correct 'r' pointer to # the real occurrence in the working tree last_dirblock[1].append( - ((b'h', b'z', b'a-id'), - [(b'a', b'', 0, False, b''), - (b'a', b'', 0, False, b'')])) - e = self.assertRaises(AssertionError, - state._validate) - self.assertContainsRe(str(e), - 'file a-id is absent in row') + ( + (b"h", b"z", b"a-id"), + [(b"a", b"", 0, False, b""), (b"a", b"", 0, False, b"")], + ) + ) + e = self.assertRaises(AssertionError, state._validate) + self.assertContainsRe(str(e), "file a-id is absent in row") class TestDirstateTreeReference(TestCaseWithDirState): - def test_reference_revision_is_none(self): - tree = self.make_branch_and_tree('tree', format='development-subtree') - subtree = self.make_branch_and_tree('tree/subtree', - format='development-subtree') - subtree.set_root_id(b'subtree') + tree = self.make_branch_and_tree("tree", format="development-subtree") + subtree = self.make_branch_and_tree( + "tree/subtree", format="development-subtree" + ) + subtree.set_root_id(b"subtree") tree.add_reference(subtree) - tree.add('subtree') - state = dirstate.DirState.from_tree(tree, 'dirstate') - key = (b'', b'subtree', b'subtree') - expected = (b'', [(key, - [(b't', b'', 0, False, b'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')])]) + tree.add("subtree") + state = dirstate.DirState.from_tree(tree, "dirstate") + key = (b"", b"subtree", b"subtree") + expected = ( + b"", + [(key, [(b"t", b"", 0, False, b"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")])], + ) try: self.assertEqual(expected, state._find_block(key)) @@ -2242,7 +2706,6 @@ def test_reference_revision_is_none(self): class TestDiscardMergeParents(TestCaseWithDirState): - def 
test_discard_no_parents(self): # This should be a no-op state = self.create_empty_dirstate() @@ -2252,18 +2715,21 @@ def test_discard_no_parents(self): def test_discard_one_parent(self): # No-op - packed_stat = b'AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk' - root_entry_direntry = (b'', b'', b'a-root-value'), [ - (b'd', b'', 0, False, packed_stat), - (b'd', b'', 0, False, packed_stat), - ] + packed_stat = b"AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk" + root_entry_direntry = ( + (b"", b"", b"a-root-value"), + [ + (b"d", b"", 0, False, packed_stat), + (b"d", b"", 0, False, packed_stat), + ], + ) dirblocks = [] - dirblocks.append((b'', [root_entry_direntry])) - dirblocks.append((b'', [])) + dirblocks.append((b"", [root_entry_direntry])) + dirblocks.append((b"", [])) state = self.create_empty_dirstate() self.addCleanup(state.unlock) - state._set_data([b'parent-id'], dirblocks[:]) + state._set_data([b"parent-id"], dirblocks[:]) state._validate() state._discard_merge_parents() @@ -2272,100 +2738,128 @@ def test_discard_one_parent(self): def test_discard_simple(self): # No-op - packed_stat = b'AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk' - root_entry_direntry = (b'', b'', b'a-root-value'), [ - (b'd', b'', 0, False, packed_stat), - (b'd', b'', 0, False, packed_stat), - (b'd', b'', 0, False, packed_stat), - ] - expected_root_entry_direntry = (b'', b'', b'a-root-value'), [ - (b'd', b'', 0, False, packed_stat), - (b'd', b'', 0, False, packed_stat), - ] + packed_stat = b"AAAAREUHaIpFB2iKAAADAQAtkqUAAIGk" + root_entry_direntry = ( + (b"", b"", b"a-root-value"), + [ + (b"d", b"", 0, False, packed_stat), + (b"d", b"", 0, False, packed_stat), + (b"d", b"", 0, False, packed_stat), + ], + ) + expected_root_entry_direntry = ( + (b"", b"", b"a-root-value"), + [ + (b"d", b"", 0, False, packed_stat), + (b"d", b"", 0, False, packed_stat), + ], + ) dirblocks = [] - dirblocks.append((b'', [root_entry_direntry])) - dirblocks.append((b'', [])) + dirblocks.append((b"", [root_entry_direntry])) + dirblocks.append((b"", [])) state = self.create_empty_dirstate() self.addCleanup(state.unlock) - state._set_data([b'parent-id', b'merged-id'], dirblocks[:]) + state._set_data([b"parent-id", b"merged-id"], dirblocks[:]) state._validate() # This should strip of the extra column state._discard_merge_parents() state._validate() - expected_dirblocks = [(b'', [expected_root_entry_direntry]), (b'', [])] + expected_dirblocks = [(b"", [expected_root_entry_direntry]), (b"", [])] self.assertEqual(expected_dirblocks, state._dirblocks) def test_discard_absent(self): """If entries are only in a merge, discard should remove the entries.""" null_stat = dirstate.DirState.NULLSTAT - present_dir = (b'd', b'', 0, False, null_stat) - present_file = (b'f', b'', 0, False, null_stat) + present_dir = (b"d", b"", 0, False, null_stat) + present_file = (b"f", b"", 0, False, null_stat) absent = dirstate.DirState.NULL_PARENT_DETAILS - root_key = (b'', b'', b'a-root-value') - file_in_root_key = (b'', b'file-in-root', b'a-file-id') - file_in_merged_key = (b'', b'file-in-merged', b'b-file-id') - dirblocks = [(b'', [(root_key, [present_dir, present_dir, present_dir])]), - (b'', [(file_in_merged_key, - [absent, absent, present_file]), - (file_in_root_key, - [present_file, present_file, present_file]), - ]), - ] + root_key = (b"", b"", b"a-root-value") + file_in_root_key = (b"", b"file-in-root", b"a-file-id") + file_in_merged_key = (b"", b"file-in-merged", b"b-file-id") + dirblocks = [ + (b"", [(root_key, [present_dir, present_dir, present_dir])]), + ( + b"", + [ + (file_in_merged_key, 
[absent, absent, present_file]), + (file_in_root_key, [present_file, present_file, present_file]), + ], + ), + ] state = self.create_empty_dirstate() self.addCleanup(state.unlock) - state._set_data([b'parent-id', b'merged-id'], dirblocks[:]) + state._set_data([b"parent-id", b"merged-id"], dirblocks[:]) state._validate() - exp_dirblocks = [(b'', [(root_key, [present_dir, present_dir])]), - (b'', [(file_in_root_key, - [present_file, present_file]), - ]), - ] + exp_dirblocks = [ + (b"", [(root_key, [present_dir, present_dir])]), + ( + b"", + [ + (file_in_root_key, [present_file, present_file]), + ], + ), + ] state._discard_merge_parents() state._validate() self.assertEqual(exp_dirblocks, state._dirblocks) def test_discard_renamed(self): null_stat = dirstate.DirState.NULLSTAT - present_dir = (b'd', b'', 0, False, null_stat) - present_file = (b'f', b'', 0, False, null_stat) + present_dir = (b"d", b"", 0, False, null_stat) + present_file = (b"f", b"", 0, False, null_stat) absent = dirstate.DirState.NULL_PARENT_DETAILS - root_key = (b'', b'', b'a-root-value') - file_in_root_key = (b'', b'file-in-root', b'a-file-id') + root_key = (b"", b"", b"a-root-value") + file_in_root_key = (b"", b"file-in-root", b"a-file-id") # Renamed relative to parent - file_rename_s_key = (b'', b'file-s', b'b-file-id') - file_rename_t_key = (b'', b'file-t', b'b-file-id') + file_rename_s_key = (b"", b"file-s", b"b-file-id") + file_rename_t_key = (b"", b"file-t", b"b-file-id") # And one that is renamed between the parents, but absent in this - key_in_1 = (b'', b'file-in-1', b'c-file-id') - key_in_2 = (b'', b'file-in-2', b'c-file-id') + key_in_1 = (b"", b"file-in-1", b"c-file-id") + key_in_2 = (b"", b"file-in-2", b"c-file-id") dirblocks = [ - (b'', [(root_key, [present_dir, present_dir, present_dir])]), - (b'', [(key_in_1, - [absent, present_file, (b'r', b'file-in-2', b'c-file-id')]), - (key_in_2, - [absent, (b'r', b'file-in-1', b'c-file-id'), present_file]), - (file_in_root_key, - [present_file, present_file, present_file]), - (file_rename_s_key, - [(b'r', b'file-t', b'b-file-id'), absent, present_file]), - (file_rename_t_key, - [present_file, absent, (b'r', b'file-s', b'b-file-id')]), - ]), + (b"", [(root_key, [present_dir, present_dir, present_dir])]), + ( + b"", + [ + ( + key_in_1, + [absent, present_file, (b"r", b"file-in-2", b"c-file-id")], + ), + ( + key_in_2, + [absent, (b"r", b"file-in-1", b"c-file-id"), present_file], + ), + (file_in_root_key, [present_file, present_file, present_file]), + ( + file_rename_s_key, + [(b"r", b"file-t", b"b-file-id"), absent, present_file], + ), + ( + file_rename_t_key, + [present_file, absent, (b"r", b"file-s", b"b-file-id")], + ), + ], + ), ] exp_dirblocks = [ - (b'', [(root_key, [present_dir, present_dir])]), - (b'', [(key_in_1, [absent, present_file]), - (file_in_root_key, [present_file, present_file]), - (file_rename_t_key, [present_file, absent]), - ]), + (b"", [(root_key, [present_dir, present_dir])]), + ( + b"", + [ + (key_in_1, [absent, present_file]), + (file_in_root_key, [present_file, present_file]), + (file_rename_t_key, [present_file, absent]), + ], + ), ] state = self.create_empty_dirstate() self.addCleanup(state.unlock) - state._set_data([b'parent-id', b'merged-id'], dirblocks[:]) + state._set_data([b"parent-id", b"merged-id"], dirblocks[:]) state._validate() state._discard_merge_parents() @@ -2374,31 +2868,35 @@ def test_discard_renamed(self): def test_discard_all_subdir(self): null_stat = dirstate.DirState.NULLSTAT - present_dir = (b'd', b'', 0, False, null_stat) - 
present_file = (b'f', b'', 0, False, null_stat) + present_dir = (b"d", b"", 0, False, null_stat) + present_file = (b"f", b"", 0, False, null_stat) absent = dirstate.DirState.NULL_PARENT_DETAILS - root_key = (b'', b'', b'a-root-value') - subdir_key = (b'', b'sub', b'dir-id') - child1_key = (b'sub', b'child1', b'child1-id') - child2_key = (b'sub', b'child2', b'child2-id') - child3_key = (b'sub', b'child3', b'child3-id') + root_key = (b"", b"", b"a-root-value") + subdir_key = (b"", b"sub", b"dir-id") + child1_key = (b"sub", b"child1", b"child1-id") + child2_key = (b"sub", b"child2", b"child2-id") + child3_key = (b"sub", b"child3", b"child3-id") dirblocks = [ - (b'', [(root_key, [present_dir, present_dir, present_dir])]), - (b'', [(subdir_key, [present_dir, present_dir, present_dir])]), - (b'sub', [(child1_key, [absent, absent, present_file]), - (child2_key, [absent, absent, present_file]), - (child3_key, [absent, absent, present_file]), - ]), + (b"", [(root_key, [present_dir, present_dir, present_dir])]), + (b"", [(subdir_key, [present_dir, present_dir, present_dir])]), + ( + b"sub", + [ + (child1_key, [absent, absent, present_file]), + (child2_key, [absent, absent, present_file]), + (child3_key, [absent, absent, present_file]), + ], + ), ] exp_dirblocks = [ - (b'', [(root_key, [present_dir, present_dir])]), - (b'', [(subdir_key, [present_dir, present_dir])]), - (b'sub', []), + (b"", [(root_key, [present_dir, present_dir])]), + (b"", [(subdir_key, [present_dir, present_dir])]), + (b"sub", []), ] state = self.create_empty_dirstate() self.addCleanup(state.unlock) - state._set_data([b'parent-id', b'merged-id'], dirblocks[:]) + state._set_data([b"parent-id", b"merged-id"], dirblocks[:]) state._validate() state._discard_merge_parents() @@ -2407,7 +2905,6 @@ def test_discard_all_subdir(self): class Test_InvEntryToDetails(tests.TestCase): - def assertDetails(self, expected, inv_entry): details = dirstate._inv_entry_to_details(inv_entry) self.assertEqual(expected, details) @@ -2419,36 +2916,36 @@ def assertDetails(self, expected, inv_entry): self.assertIsInstance(tree_data, bytes) def test_unicode_symlink(self): - inv_entry = inventory.InventoryLink(b'link-file-id', - 'nam\N{Euro Sign}e', - b'link-parent-id') - inv_entry.revision = b'link-revision-id' - target = 'link-targ\N{Euro Sign}t' + inv_entry = inventory.InventoryLink( + b"link-file-id", "nam\N{Euro Sign}e", b"link-parent-id" + ) + inv_entry.revision = b"link-revision-id" + target = "link-targ\N{Euro Sign}t" inv_entry.symlink_target = target - self.assertDetails((b'l', target.encode('UTF-8'), 0, False, - b'link-revision-id'), inv_entry) + self.assertDetails( + (b"l", target.encode("UTF-8"), 0, False, b"link-revision-id"), inv_entry + ) class TestSHA1Provider(tests.TestCaseInTempDir): - def test_sha1provider_is_an_interface(self): p = dirstate.SHA1Provider() self.assertRaises(NotImplementedError, p.sha1, "foo") self.assertRaises(NotImplementedError, p.stat_and_sha1, "foo") def test_defaultsha1provider_sha1(self): - text = b'test\r\nwith\nall\rpossible line endings\r\n' - self.build_tree_contents([('foo', text)]) + text = b"test\r\nwith\nall\rpossible line endings\r\n" + self.build_tree_contents([("foo", text)]) expected_sha = osutils.sha_string(text) p = dirstate.DefaultSHA1Provider() - self.assertEqual(expected_sha, p.sha1('foo')) + self.assertEqual(expected_sha, p.sha1("foo")) def test_defaultsha1provider_stat_and_sha1(self): - text = b'test\r\nwith\nall\rpossible line endings\r\n' - self.build_tree_contents([('foo', text)]) + text = 
b"test\r\nwith\nall\rpossible line endings\r\n" + self.build_tree_contents([("foo", text)]) expected_sha = osutils.sha_string(text) p = dirstate.DefaultSHA1Provider() - statvalue, sha1 = p.stat_and_sha1('foo') + statvalue, sha1 = p.stat_and_sha1("foo") self.assertEqual(len(text), statvalue.st_size) self.assertEqual(expected_sha, sha1) @@ -2457,7 +2954,7 @@ class _Repo: """A minimal api to get InventoryRevisionTree to work.""" def __init__(self): - default_format = controldir.format_registry.make_controldir('default') + default_format = controldir.format_registry.make_controldir("default") self._format = default_format.repository_format def lock_read(self): @@ -2468,9 +2965,8 @@ def unlock(self): class TestUpdateBasisByDelta(tests.TestCase): - def path_to_ie(self, path, file_id, rev_id, dir_ids): - if path.endswith('/'): + if path.endswith("/"): is_dir = True path = path[:-1] else: @@ -2479,36 +2975,36 @@ def path_to_ie(self, path, file_id, rev_id, dir_ids): try: dir_id = dir_ids[dirname] except KeyError: - dir_id = osutils.basename(dirname).encode('utf-8') + b'-id' + dir_id = osutils.basename(dirname).encode("utf-8") + b"-id" if is_dir: ie = inventory.InventoryDirectory(file_id, basename, dir_id) dir_ids[path] = file_id else: ie = inventory.InventoryFile(file_id, basename, dir_id) ie.text_size = 0 - ie.text_sha1 = b'' + ie.text_sha1 = b"" ie.revision = rev_id return ie def create_tree_from_shape(self, rev_id, shape): - dir_ids = {'': b'root-id'} - inv = inventory.Inventory(b'root-id', rev_id) + dir_ids = {"": b"root-id"} + inv = inventory.Inventory(b"root-id", rev_id) for info in shape: if len(info) == 2: path, file_id = info ie_rev_id = rev_id else: path, file_id, ie_rev_id = info - if path == '': + if path == "": # Replace the root entry inv.rename_id(inv.root.file_id, file_id) - dir_ids[''] = file_id + dir_ids[""] = file_id else: inv.add(self.path_to_ie(path, file_id, ie_rev_id, dir_ids)) return inventorytree.InventoryRevisionTree(_Repo(), inv, rev_id) def create_empty_dirstate(self): - fd, path = tempfile.mkstemp(prefix='bzr-dirstate') + fd, path = tempfile.mkstemp(prefix="bzr-dirstate") self.addCleanup(os.remove, path) os.close(fd) state = dirstate.DirState.initialize(path) @@ -2517,10 +3013,10 @@ def create_empty_dirstate(self): def create_inv_delta(self, delta, rev_id): """Translate a 'delta shape' into an actual InventoryDelta.""" - dir_ids = {'': b'root-id'} + dir_ids = {"": b"root-id"} inv_delta = [] for old_path, new_path, file_id in delta: - if old_path is not None and old_path.endswith('/'): + if old_path is not None and old_path.endswith("/"): # Don't have to actually do anything for this, because only # new_path creates InventoryEntries old_path = old_path[:-1] @@ -2539,26 +3035,28 @@ def assertUpdate(self, active, basis, target): and assert that the DirState is still valid, and that its stored content matches the target_shape. 
""" - active_tree = self.create_tree_from_shape(b'active', active) - basis_tree = self.create_tree_from_shape(b'basis', basis) - target_tree = self.create_tree_from_shape(b'target', target) + active_tree = self.create_tree_from_shape(b"active", active) + basis_tree = self.create_tree_from_shape(b"basis", basis) + target_tree = self.create_tree_from_shape(b"target", target) state = self.create_empty_dirstate() - state.set_state_from_scratch(active_tree.root_inventory, - [(b'basis', basis_tree)], []) - delta = target_tree.root_inventory._make_delta( - basis_tree.root_inventory) - state.update_basis_by_delta(delta, b'target') + state.set_state_from_scratch( + active_tree.root_inventory, [(b"basis", basis_tree)], [] + ) + delta = target_tree.root_inventory._make_delta(basis_tree.root_inventory) + state.update_basis_by_delta(delta, b"target") state._validate() dirstate_tree = workingtree_4.DirStateRevisionTree( - state, b'target', _Repo(), None) + state, b"target", _Repo(), None + ) # The target now that delta has been applied should match the # RevisionTree self.assertEqual([], list(dirstate_tree.iter_changes(target_tree))) # And the dirblock state should be identical to the state if we created # it from scratch. state2 = self.create_empty_dirstate() - state2.set_state_from_scratch(active_tree.root_inventory, - [(b'target', target_tree)], []) + state2.set_state_from_scratch( + active_tree.root_inventory, [(b"target", target_tree)], [] + ) self.assertEqual(state2._dirblocks, state._dirblocks) return state @@ -2574,14 +3072,16 @@ def assertBadDelta(self, active, basis, delta): Renaming a dir is: ('old/', 'new/', b'dir-id') etc. """ - active_tree = self.create_tree_from_shape(b'active', active) - basis_tree = self.create_tree_from_shape(b'basis', basis) - inv_delta = self.create_inv_delta(delta, b'target') + active_tree = self.create_tree_from_shape(b"active", active) + basis_tree = self.create_tree_from_shape(b"basis", basis) + inv_delta = self.create_inv_delta(delta, b"target") state = self.create_empty_dirstate() - state.set_state_from_scratch(active_tree.root_inventory, - [(b'basis', basis_tree)], []) - self.assertRaises(errors.InconsistentDelta, - state.update_basis_by_delta, inv_delta, b'target') + state.set_state_from_scratch( + active_tree.root_inventory, [(b"basis", basis_tree)], [] + ) + self.assertRaises( + errors.InconsistentDelta, state.update_basis_by_delta, inv_delta, b"target" + ) # try: ## state.update_basis_by_delta(inv_delta, b'target') # except errors.InconsistentDelta, e: @@ -2593,316 +3093,278 @@ def assertBadDelta(self, active, basis, delta): def test_remove_file_matching_active_state(self): self.assertUpdate( active=[], - basis=[('file', b'file-id')], + basis=[("file", b"file-id")], target=[], - ) + ) def test_remove_file_present_in_active_state(self): self.assertUpdate( - active=[('file', b'file-id')], - basis=[('file', b'file-id')], + active=[("file", b"file-id")], + basis=[("file", b"file-id")], target=[], - ) + ) def test_remove_file_present_elsewhere_in_active_state(self): self.assertUpdate( - active=[('other-file', b'file-id')], - basis=[('file', b'file-id')], + active=[("other-file", b"file-id")], + basis=[("file", b"file-id")], target=[], - ) + ) def test_remove_file_active_state_has_diff_file(self): self.assertUpdate( - active=[('file', b'file-id-2')], - basis=[('file', b'file-id')], + active=[("file", b"file-id-2")], + basis=[("file", b"file-id")], target=[], - ) + ) def test_remove_file_active_state_has_diff_file_and_file_elsewhere(self): self.assertUpdate( - 
active=[('file', b'file-id-2'), - ('other-file', b'file-id')], - basis=[('file', b'file-id')], + active=[("file", b"file-id-2"), ("other-file", b"file-id")], + basis=[("file", b"file-id")], target=[], - ) + ) def test_add_file_matching_active_state(self): self.assertUpdate( - active=[('file', b'file-id')], + active=[("file", b"file-id")], basis=[], - target=[('file', b'file-id')], - ) + target=[("file", b"file-id")], + ) def test_add_file_in_empty_dir_not_matching_active_state(self): self.assertUpdate( active=[], - basis=[('dir/', b'dir-id')], - target=[('dir/', b'dir-id', b'basis'), ('dir/file', b'file-id')], - ) + basis=[("dir/", b"dir-id")], + target=[("dir/", b"dir-id", b"basis"), ("dir/file", b"file-id")], + ) def test_add_file_missing_in_active_state(self): self.assertUpdate( active=[], basis=[], - target=[('file', b'file-id')], - ) + target=[("file", b"file-id")], + ) def test_add_file_elsewhere_in_active_state(self): self.assertUpdate( - active=[('other-file', b'file-id')], + active=[("other-file", b"file-id")], basis=[], - target=[('file', b'file-id')], - ) + target=[("file", b"file-id")], + ) def test_add_file_active_state_has_diff_file_and_file_elsewhere(self): self.assertUpdate( - active=[('other-file', b'file-id'), - ('file', b'file-id-2')], + active=[("other-file", b"file-id"), ("file", b"file-id-2")], basis=[], - target=[('file', b'file-id')], - ) + target=[("file", b"file-id")], + ) def test_rename_file_matching_active_state(self): self.assertUpdate( - active=[('other-file', b'file-id')], - basis=[('file', b'file-id')], - target=[('other-file', b'file-id')], - ) + active=[("other-file", b"file-id")], + basis=[("file", b"file-id")], + target=[("other-file", b"file-id")], + ) def test_rename_file_missing_in_active_state(self): self.assertUpdate( active=[], - basis=[('file', b'file-id')], - target=[('other-file', b'file-id')], - ) + basis=[("file", b"file-id")], + target=[("other-file", b"file-id")], + ) def test_rename_file_present_elsewhere_in_active_state(self): self.assertUpdate( - active=[('third', b'file-id')], - basis=[('file', b'file-id')], - target=[('other-file', b'file-id')], - ) + active=[("third", b"file-id")], + basis=[("file", b"file-id")], + target=[("other-file", b"file-id")], + ) def test_rename_file_active_state_has_diff_source_file(self): self.assertUpdate( - active=[('file', b'file-id-2')], - basis=[('file', b'file-id')], - target=[('other-file', b'file-id')], - ) + active=[("file", b"file-id-2")], + basis=[("file", b"file-id")], + target=[("other-file", b"file-id")], + ) def test_rename_file_active_state_has_diff_target_file(self): self.assertUpdate( - active=[('other-file', b'file-id-2')], - basis=[('file', b'file-id')], - target=[('other-file', b'file-id')], - ) + active=[("other-file", b"file-id-2")], + basis=[("file", b"file-id")], + target=[("other-file", b"file-id")], + ) def test_rename_file_active_has_swapped_files(self): self.assertUpdate( - active=[('file', b'file-id'), - ('other-file', b'file-id-2')], - basis=[('file', b'file-id'), - ('other-file', b'file-id-2')], - target=[('file', b'file-id-2'), - ('other-file', b'file-id')]) + active=[("file", b"file-id"), ("other-file", b"file-id-2")], + basis=[("file", b"file-id"), ("other-file", b"file-id-2")], + target=[("file", b"file-id-2"), ("other-file", b"file-id")], + ) def test_rename_file_basis_has_swapped_files(self): self.assertUpdate( - active=[('file', b'file-id'), - ('other-file', b'file-id-2')], - basis=[('file', b'file-id-2'), - ('other-file', b'file-id')], - target=[('file', b'file-id'), 
- ('other-file', b'file-id-2')]) + active=[("file", b"file-id"), ("other-file", b"file-id-2")], + basis=[("file", b"file-id-2"), ("other-file", b"file-id")], + target=[("file", b"file-id"), ("other-file", b"file-id-2")], + ) def test_rename_directory_with_contents(self): self.assertUpdate( # active matches basis - active=[('dir1/', b'dir-id'), - ('dir1/file', b'file-id')], - basis=[('dir1/', b'dir-id'), - ('dir1/file', b'file-id')], - target=[('dir2/', b'dir-id'), - ('dir2/file', b'file-id')]) + active=[("dir1/", b"dir-id"), ("dir1/file", b"file-id")], + basis=[("dir1/", b"dir-id"), ("dir1/file", b"file-id")], + target=[("dir2/", b"dir-id"), ("dir2/file", b"file-id")], + ) self.assertUpdate( # active matches target - active=[('dir2/', b'dir-id'), - ('dir2/file', b'file-id')], - basis=[('dir1/', b'dir-id'), - ('dir1/file', b'file-id')], - target=[('dir2/', b'dir-id'), - ('dir2/file', b'file-id')]) + active=[("dir2/", b"dir-id"), ("dir2/file", b"file-id")], + basis=[("dir1/", b"dir-id"), ("dir1/file", b"file-id")], + target=[("dir2/", b"dir-id"), ("dir2/file", b"file-id")], + ) self.assertUpdate( # active empty active=[], - basis=[('dir1/', b'dir-id'), - ('dir1/file', b'file-id')], - target=[('dir2/', b'dir-id'), - ('dir2/file', b'file-id')]) + basis=[("dir1/", b"dir-id"), ("dir1/file", b"file-id")], + target=[("dir2/", b"dir-id"), ("dir2/file", b"file-id")], + ) self.assertUpdate( # active present at other location - active=[('dir3/', b'dir-id'), - ('dir3/file', b'file-id')], - basis=[('dir1/', b'dir-id'), - ('dir1/file', b'file-id')], - target=[('dir2/', b'dir-id'), - ('dir2/file', b'file-id')]) + active=[("dir3/", b"dir-id"), ("dir3/file", b"file-id")], + basis=[("dir1/", b"dir-id"), ("dir1/file", b"file-id")], + target=[("dir2/", b"dir-id"), ("dir2/file", b"file-id")], + ) self.assertUpdate( # active has different ids - active=[('dir1/', b'dir1-id'), - ('dir1/file', b'file1-id'), - ('dir2/', b'dir2-id'), - ('dir2/file', b'file2-id')], - basis=[('dir1/', b'dir-id'), - ('dir1/file', b'file-id')], - target=[('dir2/', b'dir-id'), - ('dir2/file', b'file-id')]) + active=[ + ("dir1/", b"dir1-id"), + ("dir1/file", b"file1-id"), + ("dir2/", b"dir2-id"), + ("dir2/file", b"file2-id"), + ], + basis=[("dir1/", b"dir-id"), ("dir1/file", b"file-id")], + target=[("dir2/", b"dir-id"), ("dir2/file", b"file-id")], + ) def test_invalid_file_not_present(self): self.assertBadDelta( - active=[('file', b'file-id')], - basis=[('file', b'file-id')], - delta=[('other-file', 'file', b'file-id')]) + active=[("file", b"file-id")], + basis=[("file", b"file-id")], + delta=[("other-file", "file", b"file-id")], + ) def test_invalid_new_id_same_path(self): # The bad entry comes after self.assertBadDelta( - active=[('file', b'file-id')], - basis=[('file', b'file-id')], - delta=[(None, 'file', b'file-id-2')]) + active=[("file", b"file-id")], + basis=[("file", b"file-id")], + delta=[(None, "file", b"file-id-2")], + ) # The bad entry comes first self.assertBadDelta( - active=[('file', b'file-id-2')], - basis=[('file', b'file-id-2')], - delta=[(None, 'file', b'file-id')]) + active=[("file", b"file-id-2")], + basis=[("file", b"file-id-2")], + delta=[(None, "file", b"file-id")], + ) def test_invalid_existing_id(self): self.assertBadDelta( - active=[('file', b'file-id')], - basis=[('file', b'file-id')], - delta=[(None, 'file', b'file-id')]) + active=[("file", b"file-id")], + basis=[("file", b"file-id")], + delta=[(None, "file", b"file-id")], + ) def test_invalid_parent_missing(self): self.assertBadDelta( - active=[], - basis=[], 
- delta=[(None, 'path/path2', b'file-id')]) + active=[], basis=[], delta=[(None, "path/path2", b"file-id")] + ) # Note: we force the active tree to have the directory, by knowing how # path_to_ie handles entries with missing parents self.assertBadDelta( - active=[('path/', b'path-id')], + active=[("path/", b"path-id")], basis=[], - delta=[(None, 'path/path2', b'file-id')]) + delta=[(None, "path/path2", b"file-id")], + ) self.assertBadDelta( - active=[('path/', b'path-id'), - ('path/path2', b'file-id')], + active=[("path/", b"path-id"), ("path/path2", b"file-id")], basis=[], - delta=[(None, 'path/path2', b'file-id')]) + delta=[(None, "path/path2", b"file-id")], + ) def test_renamed_dir_same_path(self): # We replace the parent directory, with another parent dir. But the C # file doesn't look like it has been moved. self.assertUpdate( # Same as basis - active=[('dir/', b'A-id'), - ('dir/B', b'B-id')], - basis=[('dir/', b'A-id'), - ('dir/B', b'B-id')], - target=[('dir/', b'C-id'), - ('dir/B', b'B-id')]) + active=[("dir/", b"A-id"), ("dir/B", b"B-id")], + basis=[("dir/", b"A-id"), ("dir/B", b"B-id")], + target=[("dir/", b"C-id"), ("dir/B", b"B-id")], + ) self.assertUpdate( # Same as target - active=[('dir/', b'C-id'), - ('dir/B', b'B-id')], - basis=[('dir/', b'A-id'), - ('dir/B', b'B-id')], - target=[('dir/', b'C-id'), - ('dir/B', b'B-id')]) + active=[("dir/", b"C-id"), ("dir/B", b"B-id")], + basis=[("dir/", b"A-id"), ("dir/B", b"B-id")], + target=[("dir/", b"C-id"), ("dir/B", b"B-id")], + ) self.assertUpdate( # empty active active=[], - basis=[('dir/', b'A-id'), - ('dir/B', b'B-id')], - target=[('dir/', b'C-id'), - ('dir/B', b'B-id')]) + basis=[("dir/", b"A-id"), ("dir/B", b"B-id")], + target=[("dir/", b"C-id"), ("dir/B", b"B-id")], + ) self.assertUpdate( # different active - active=[('dir/', b'D-id'), - ('dir/B', b'B-id')], - basis=[('dir/', b'A-id'), - ('dir/B', b'B-id')], - target=[('dir/', b'C-id'), - ('dir/B', b'B-id')]) + active=[("dir/", b"D-id"), ("dir/B", b"B-id")], + basis=[("dir/", b"A-id"), ("dir/B", b"B-id")], + target=[("dir/", b"C-id"), ("dir/B", b"B-id")], + ) def test_parent_child_swap(self): self.assertUpdate( # Same as basis - active=[('A/', b'A-id'), - ('A/B/', b'B-id'), - ('A/B/C', b'C-id')], - basis=[('A/', b'A-id'), - ('A/B/', b'B-id'), - ('A/B/C', b'C-id')], - target=[('A/', b'B-id'), - ('A/B/', b'A-id'), - ('A/B/C', b'C-id')]) + active=[("A/", b"A-id"), ("A/B/", b"B-id"), ("A/B/C", b"C-id")], + basis=[("A/", b"A-id"), ("A/B/", b"B-id"), ("A/B/C", b"C-id")], + target=[("A/", b"B-id"), ("A/B/", b"A-id"), ("A/B/C", b"C-id")], + ) self.assertUpdate( # Same as target - active=[('A/', b'B-id'), - ('A/B/', b'A-id'), - ('A/B/C', b'C-id')], - basis=[('A/', b'A-id'), - ('A/B/', b'B-id'), - ('A/B/C', b'C-id')], - target=[('A/', b'B-id'), - ('A/B/', b'A-id'), - ('A/B/C', b'C-id')]) + active=[("A/", b"B-id"), ("A/B/", b"A-id"), ("A/B/C", b"C-id")], + basis=[("A/", b"A-id"), ("A/B/", b"B-id"), ("A/B/C", b"C-id")], + target=[("A/", b"B-id"), ("A/B/", b"A-id"), ("A/B/C", b"C-id")], + ) self.assertUpdate( # empty active active=[], - basis=[('A/', b'A-id'), - ('A/B/', b'B-id'), - ('A/B/C', b'C-id')], - target=[('A/', b'B-id'), - ('A/B/', b'A-id'), - ('A/B/C', b'C-id')]) + basis=[("A/", b"A-id"), ("A/B/", b"B-id"), ("A/B/C", b"C-id")], + target=[("A/", b"B-id"), ("A/B/", b"A-id"), ("A/B/C", b"C-id")], + ) self.assertUpdate( # different active - active=[('D/', b'A-id'), - ('D/E/', b'B-id'), - ('F', b'C-id')], - basis=[('A/', b'A-id'), - ('A/B/', b'B-id'), - ('A/B/C', b'C-id')], - 
target=[('A/', b'B-id'), - ('A/B/', b'A-id'), - ('A/B/C', b'C-id')]) + active=[("D/", b"A-id"), ("D/E/", b"B-id"), ("F", b"C-id")], + basis=[("A/", b"A-id"), ("A/B/", b"B-id"), ("A/B/C", b"C-id")], + target=[("A/", b"B-id"), ("A/B/", b"A-id"), ("A/B/C", b"C-id")], + ) def test_change_root_id(self): self.assertUpdate( # same as basis - active=[('', b'root-id'), - ('file', b'file-id')], - basis=[('', b'root-id'), - ('file', b'file-id')], - target=[('', b'target-root-id'), - ('file', b'file-id')]) + active=[("", b"root-id"), ("file", b"file-id")], + basis=[("", b"root-id"), ("file", b"file-id")], + target=[("", b"target-root-id"), ("file", b"file-id")], + ) self.assertUpdate( # same as target - active=[('', b'target-root-id'), - ('file', b'file-id')], - basis=[('', b'root-id'), - ('file', b'file-id')], - target=[('', b'target-root-id'), - ('file', b'root-id')]) + active=[("", b"target-root-id"), ("file", b"file-id")], + basis=[("", b"root-id"), ("file", b"file-id")], + target=[("", b"target-root-id"), ("file", b"root-id")], + ) self.assertUpdate( # all different - active=[('', b'active-root-id'), - ('file', b'file-id')], - basis=[('', b'root-id'), - ('file', b'file-id')], - target=[('', b'target-root-id'), - ('file', b'root-id')]) + active=[("", b"active-root-id"), ("file", b"file-id")], + basis=[("", b"root-id"), ("file", b"file-id")], + target=[("", b"target-root-id"), ("file", b"root-id")], + ) def test_change_file_absent_in_active(self): self.assertUpdate( - active=[], - basis=[('file', b'file-id')], - target=[('file', b'file-id')]) + active=[], basis=[("file", b"file-id")], target=[("file", b"file-id")] + ) def test_invalid_changed_file(self): self.assertBadDelta( # Not present in basis - active=[('file', b'file-id')], + active=[("file", b"file-id")], basis=[], - delta=[('file', 'file', b'file-id')]) + delta=[("file", "file", b"file-id")], + ) self.assertBadDelta( # present at another location in basis - active=[('file', b'file-id')], - basis=[('other-file', b'file-id')], - delta=[('file', 'file', b'file-id')]) + active=[("file", b"file-id")], + basis=[("other-file", b"file-id")], + delta=[("file", "file", b"file-id")], + ) class TestBisectDirblock(tests.TestCase): @@ -2924,13 +3386,15 @@ def assertBisect(self, dirblocks, split_dirblocks, path, *args, **kwargs): """ self.assertIsInstance(dirblocks, list) bisect_split_idx = dirstate.bisect_dirblock(dirblocks, path, *args, **kwargs) - split_dirblock = (path.split(b'/'), []) - bisect_left_idx = bisect.bisect_left(split_dirblocks, split_dirblock, - *args) - self.assertEqual(bisect_left_idx, bisect_split_idx, - 'bisect_split disagreed. {} != {}' - ' for key {!r}'.format(bisect_left_idx, bisect_split_idx, path) - ) + split_dirblock = (path.split(b"/"), []) + bisect_left_idx = bisect.bisect_left(split_dirblocks, split_dirblock, *args) + self.assertEqual( + bisect_left_idx, + bisect_split_idx, + "bisect_split disagreed. {} != {}" " for key {!r}".format( + bisect_left_idx, bisect_split_idx, path + ), + ) def paths_to_dirblocks(self, paths): """Convert a list of paths into dirblock form. @@ -2938,50 +3402,90 @@ def paths_to_dirblocks(self, paths): Also, ensure that the paths are in proper sorted order. 
""" dirblocks = [(path, []) for path in paths] - split_dirblocks = [(path.split(b'/'), []) for path in paths] + split_dirblocks = [(path.split(b"/"), []) for path in paths] self.assertEqual(sorted(split_dirblocks), split_dirblocks) return dirblocks, split_dirblocks def test_simple(self): """In the simple case it works just like bisect_left.""" - paths = [b'', b'a', b'b', b'c', b'd'] + paths = [b"", b"a", b"b", b"c", b"d"] dirblocks, split_dirblocks = self.paths_to_dirblocks(paths) for path in paths: self.assertBisect(dirblocks, split_dirblocks, path) - self.assertBisect(dirblocks, split_dirblocks, b'_') - self.assertBisect(dirblocks, split_dirblocks, b'aa') - self.assertBisect(dirblocks, split_dirblocks, b'bb') - self.assertBisect(dirblocks, split_dirblocks, b'cc') - self.assertBisect(dirblocks, split_dirblocks, b'dd') - self.assertBisect(dirblocks, split_dirblocks, b'a/a') - self.assertBisect(dirblocks, split_dirblocks, b'b/b') - self.assertBisect(dirblocks, split_dirblocks, b'c/c') - self.assertBisect(dirblocks, split_dirblocks, b'd/d') + self.assertBisect(dirblocks, split_dirblocks, b"_") + self.assertBisect(dirblocks, split_dirblocks, b"aa") + self.assertBisect(dirblocks, split_dirblocks, b"bb") + self.assertBisect(dirblocks, split_dirblocks, b"cc") + self.assertBisect(dirblocks, split_dirblocks, b"dd") + self.assertBisect(dirblocks, split_dirblocks, b"a/a") + self.assertBisect(dirblocks, split_dirblocks, b"b/b") + self.assertBisect(dirblocks, split_dirblocks, b"c/c") + self.assertBisect(dirblocks, split_dirblocks, b"d/d") def test_involved(self): """This is where bisect_left diverges slightly.""" - paths = [b'', b'a', - b'a/a', b'a/a/a', b'a/a/z', b'a/a-a', b'a/a-z', - b'a/z', b'a/z/a', b'a/z/z', b'a/z-a', b'a/z-z', - b'a-a', b'a-z', - b'z', b'z/a/a', b'z/a/z', b'z/a-a', b'z/a-z', - b'z/z', b'z/z/a', b'z/z/z', b'z/z-a', b'z/z-z', - b'z-a', b'z-z', - ] + paths = [ + b"", + b"a", + b"a/a", + b"a/a/a", + b"a/a/z", + b"a/a-a", + b"a/a-z", + b"a/z", + b"a/z/a", + b"a/z/z", + b"a/z-a", + b"a/z-z", + b"a-a", + b"a-z", + b"z", + b"z/a/a", + b"z/a/z", + b"z/a-a", + b"z/a-z", + b"z/z", + b"z/z/a", + b"z/z/z", + b"z/z-a", + b"z/z-z", + b"z-a", + b"z-z", + ] dirblocks, split_dirblocks = self.paths_to_dirblocks(paths) for path in paths: self.assertBisect(dirblocks, split_dirblocks, path) def test_involved_cached(self): """This is where bisect_left diverges slightly.""" - paths = [b'', b'a', - b'a/a', b'a/a/a', b'a/a/z', b'a/a-a', b'a/a-z', - b'a/z', b'a/z/a', b'a/z/z', b'a/z-a', b'a/z-z', - b'a-a', b'a-z', - b'z', b'z/a/a', b'z/a/z', b'z/a-a', b'z/a-z', - b'z/z', b'z/z/a', b'z/z/z', b'z/z-a', b'z/z-z', - b'z-a', b'z-z', - ] + paths = [ + b"", + b"a", + b"a/a", + b"a/a/a", + b"a/a/z", + b"a/a-a", + b"a/a-z", + b"a/z", + b"a/z/a", + b"a/z/z", + b"a/z-a", + b"a/z-z", + b"a-a", + b"a-z", + b"z", + b"z/a/a", + b"z/a/z", + b"z/a-a", + b"z/a-z", + b"z/z", + b"z/z/a", + b"z/z/z", + b"z/z-a", + b"z/z-z", + b"z-a", + b"z-z", + ] cache = {} dirblocks, split_dirblocks = self.paths_to_dirblocks(paths) for path in paths: @@ -2993,10 +3497,17 @@ def _unpack_stat(packed_stat): This is meant as a debugging tool, should not be used in real code. 
""" - (st_size, st_mtime, st_ctime, st_dev, st_ino, - st_mode) = struct.unpack('>6L', binascii.a2b_base64(packed_stat)) - return {'st_size': st_size, 'st_mtime': st_mtime, 'st_ctime': st_ctime, - 'st_dev': st_dev, 'st_ino': st_ino, 'st_mode': st_mode} + (st_size, st_mtime, st_ctime, st_dev, st_ino, st_mode) = struct.unpack( + ">6L", binascii.a2b_base64(packed_stat) + ) + return { + "st_size": st_size, + "st_mtime": st_mtime, + "st_ctime": st_ctime, + "st_dev": st_dev, + "st_ino": st_ino, + "st_mode": st_mode, + } class TestPackStatRobust(tests.TestCase): @@ -3008,5 +3519,3 @@ def pack(self, statlike_tuple): @staticmethod def unpack_field(packed_string, stat_field): return _unpack_stat(packed_string)[stat_field] - - diff --git a/breezy/bzr/tests/test_generate_ids.py b/breezy/bzr/tests/test_generate_ids.py index 9bb991a726..60f29a6736 100644 --- a/breezy/bzr/tests/test_generate_ids.py +++ b/breezy/bzr/tests/test_generate_ids.py @@ -29,28 +29,28 @@ def assertGenFileId(self, regex, filename): The file id should be ascii, and should be an 8-bit string """ file_id = generate_ids.gen_file_id(filename) - self.assertContainsRe(file_id, b'^' + regex + b'$') + self.assertContainsRe(file_id, b"^" + regex + b"$") # It should be a utf8 file_id, not a unicode one self.assertIsInstance(file_id, bytes) # gen_file_id should always return ascii file ids. - file_id.decode('ascii') + file_id.decode("ascii") def test_gen_file_id(self): gen_file_id = generate_ids.gen_file_id # We try to use the filename if possible - self.assertStartsWith(gen_file_id('bar'), b'bar-') + self.assertStartsWith(gen_file_id("bar"), b"bar-") # but we squash capitalization, and remove non word characters - self.assertStartsWith(gen_file_id('Mwoo oof\t m'), b'mwoooofm-') + self.assertStartsWith(gen_file_id("Mwoo oof\t m"), b"mwoooofm-") # We also remove leading '.' 
characters to prevent hidden file-ids - self.assertStartsWith(gen_file_id('..gam.py'), b'gam.py-') - self.assertStartsWith(gen_file_id('..Mwoo oof\t m'), b'mwoooofm-') + self.assertStartsWith(gen_file_id("..gam.py"), b"gam.py-") + self.assertStartsWith(gen_file_id("..Mwoo oof\t m"), b"mwoooofm-") # we remove unicode characters, and still don't end up with a # hidden file id - self.assertStartsWith(gen_file_id('\xe5\xb5.txt'), b'txt-') + self.assertStartsWith(gen_file_id("\xe5\xb5.txt"), b"txt-") # Our current method of generating unique ids adds 33 characters # plus an serial number (log10(N) characters) @@ -58,34 +58,33 @@ def test_gen_file_id(self): # be <= 20 characters, so the maximum length should now be approx < 60 # Test both case squashing and length restriction - fid = gen_file_id('A' * 50 + '.txt') - self.assertStartsWith(fid, b'a' * 20 + b'-') + fid = gen_file_id("A" * 50 + ".txt") + self.assertStartsWith(fid, b"a" * 20 + b"-") self.assertLess(len(fid), 60) # restricting length happens after the other actions, so # we preserve as much as possible - fid = gen_file_id('\xe5\xb5..aBcd\tefGhijKLMnop\tqrstuvwxyz') - self.assertStartsWith(fid, b'abcdefghijklmnopqrst-') + fid = gen_file_id("\xe5\xb5..aBcd\tefGhijKLMnop\tqrstuvwxyz") + self.assertStartsWith(fid, b"abcdefghijklmnopqrst-") self.assertLess(len(fid), 60) def test_file_ids_are_ascii(self): - tail = br'-\d{14}-[a-z0-9]{16}-\d+' - self.assertGenFileId(b'foo' + tail, 'foo') - self.assertGenFileId(b'foo' + tail, 'foo') - self.assertGenFileId(b'bar' + tail, 'bar') - self.assertGenFileId(b'br' + tail, 'b\xe5r') + tail = rb"-\d{14}-[a-z0-9]{16}-\d+" + self.assertGenFileId(b"foo" + tail, "foo") + self.assertGenFileId(b"foo" + tail, "foo") + self.assertGenFileId(b"bar" + tail, "bar") + self.assertGenFileId(b"br" + tail, "b\xe5r") def test__next_id_suffix_increments(self): - ids = [ - generate_ids._next_id_suffix(suffix="foo-") for i in range(10)] - ns = [int(id.split(b'-')[-1]) for id in ids] + ids = [generate_ids._next_id_suffix(suffix="foo-") for i in range(10)] + ns = [int(id.split(b"-")[-1]) for id in ids] for i in range(1, len(ns)): self.assertEqual(ns[i] - 1, ns[i - 1]) def test_gen_root_id(self): # Mostly just make sure gen_root_id() exists root_id = generate_ids.gen_root_id() - self.assertStartsWith(root_id, b'tree_root-') + self.assertStartsWith(root_id, b"tree_root-") class TestGenRevisionId(tests.TestCase): @@ -94,47 +93,49 @@ class TestGenRevisionId(tests.TestCase): def assertGenRevisionId(self, regex, username, timestamp=None): """gen_revision_id should create a revision id matching the regex.""" revision_id = generate_ids.gen_revision_id(username, timestamp) - self.assertContainsRe(revision_id, b'^' + regex + b'$') + self.assertContainsRe(revision_id, b"^" + regex + b"$") # It should be a utf8 revision_id, not a unicode one self.assertIsInstance(revision_id, bytes) # gen_revision_id should always return ascii revision ids. 
- revision_id.decode('ascii') + revision_id.decode("ascii") def test_timestamp(self): """Passing a timestamp should cause it to be used.""" + self.assertGenRevisionId(rb"user@host-\d{14}-[a-z0-9]{16}", "user@host") self.assertGenRevisionId( - br'user@host-\d{14}-[a-z0-9]{16}', 'user@host') - self.assertGenRevisionId(b'user@host-20061102205056-[a-z0-9]{16}', - 'user@host', 1162500656.688) - self.assertGenRevisionId(br'user@host-20061102205024-[a-z0-9]{16}', - 'user@host', 1162500624.000) + b"user@host-20061102205056-[a-z0-9]{16}", "user@host", 1162500656.688 + ) + self.assertGenRevisionId( + rb"user@host-20061102205024-[a-z0-9]{16}", "user@host", 1162500624.000 + ) def test_gen_revision_id_email(self): """gen_revision_id uses email address if present.""" - regex = br'user\+joe_bar@foo-bar\.com-\d{14}-[a-z0-9]{16}' - self.assertGenRevisionId(regex, 'user+joe_bar@foo-bar.com') - self.assertGenRevisionId(regex, '') - self.assertGenRevisionId(regex, 'Joe Bar ') - self.assertGenRevisionId(regex, 'Joe Bar ') - self.assertGenRevisionId( - regex, 'Joe B\xe5r ') + regex = rb"user\+joe_bar@foo-bar\.com-\d{14}-[a-z0-9]{16}" + self.assertGenRevisionId(regex, "user+joe_bar@foo-bar.com") + self.assertGenRevisionId(regex, "") + self.assertGenRevisionId(regex, "Joe Bar ") + self.assertGenRevisionId(regex, "Joe Bar ") + self.assertGenRevisionId(regex, "Joe B\xe5r ") def test_gen_revision_id_user(self): """If there is no email, fall back to the whole username.""" - tail = br'-\d{14}-[a-z0-9]{16}' - self.assertGenRevisionId(b'joe_bar' + tail, 'Joe Bar') - self.assertGenRevisionId(b'joebar' + tail, 'joebar') - self.assertGenRevisionId(b'joe_br' + tail, 'Joe B\xe5r') - self.assertGenRevisionId(br'joe_br_user\+joe_bar_foo-bar.com' + tail, - 'Joe B\xe5r ') + tail = rb"-\d{14}-[a-z0-9]{16}" + self.assertGenRevisionId(b"joe_bar" + tail, "Joe Bar") + self.assertGenRevisionId(b"joebar" + tail, "joebar") + self.assertGenRevisionId(b"joe_br" + tail, "Joe B\xe5r") + self.assertGenRevisionId( + rb"joe_br_user\+joe_bar_foo-bar.com" + tail, + "Joe B\xe5r ", + ) def test_revision_ids_are_ascii(self): """gen_revision_id should always return an ascii revision id.""" - tail = br'-\d{14}-[a-z0-9]{16}' - self.assertGenRevisionId(b'joe_bar' + tail, 'Joe Bar') - self.assertGenRevisionId(b'joe_bar' + tail, 'Joe Bar') - self.assertGenRevisionId(b'joe@foo' + tail, 'Joe Bar ') + tail = rb"-\d{14}-[a-z0-9]{16}" + self.assertGenRevisionId(b"joe_bar" + tail, "Joe Bar") + self.assertGenRevisionId(b"joe_bar" + tail, "Joe Bar") + self.assertGenRevisionId(b"joe@foo" + tail, "Joe Bar ") # We cheat a little with this one, because email-addresses shouldn't # contain non-ascii characters, but generate_ids should strip them # anyway. 
- self.assertGenRevisionId(b'joe@f' + tail, 'Joe Bar ') + self.assertGenRevisionId(b"joe@f" + tail, "Joe Bar ") diff --git a/breezy/bzr/tests/test_groupcompress.py b/breezy/bzr/tests/test_groupcompress.py index 5402d85097..57c7af1d6c 100644 --- a/breezy/bzr/tests/test_groupcompress.py +++ b/breezy/bzr/tests/test_groupcompress.py @@ -28,11 +28,10 @@ def group_compress_implementation_scenarios(): scenarios = [ - ('python', {'compressor': groupcompress.PythonGroupCompressor}), - ] + ("python", {"compressor": groupcompress.PythonGroupCompressor}), + ] if compiled_groupcompress_feature.available(): - scenarios.append(('C', - {'compressor': groupcompress.PyrexGroupCompressor})) + scenarios.append(("C", {"compressor": groupcompress.PyrexGroupCompressor})) return scenarios @@ -40,9 +39,8 @@ def group_compress_implementation_scenarios(): class TestGroupCompressor(tests.TestCase): - def _chunks_to_repr_lines(self, chunks): - return '\n'.join(map(repr, b''.join(chunks).split(b'\n'))) + return "\n".join(map(repr, b"".join(chunks).split(b"\n"))) def assertEqualDiffEncoded(self, expected, actual): """Compare the actual content to the expected content. @@ -53,8 +51,9 @@ def assertEqualDiffEncoded(self, expected, actual): We will transform the chunks back into lines, and then run 'repr()' over them to handle non-ascii characters. """ - self.assertEqualDiff(self._chunks_to_repr_lines(expected), - self._chunks_to_repr_lines(actual)) + self.assertEqualDiff( + self._chunks_to_repr_lines(expected), self._chunks_to_repr_lines(actual) + ) class TestAllGroupCompressors(TestGroupCompressor): @@ -70,12 +69,13 @@ def test_empty_delta(self): def test_one_nosha_delta(self): # diff against NUKK compressor = self.compressor() - text = b'strange\ncommon\n' + text = b"strange\ncommon\n" sha1, start_point, end_point, _ = compressor.compress( - ('label',), [text], len(text), None) - self.assertEqual(sha_string(b'strange\ncommon\n'), sha1) - expected_lines = b'f\x0fstrange\ncommon\n' - self.assertEqual(expected_lines, b''.join(compressor.chunks)) + ("label",), [text], len(text), None + ) + self.assertEqual(sha_string(b"strange\ncommon\n"), sha1) + expected_lines = b"f\x0fstrange\ncommon\n" + self.assertEqual(expected_lines, b"".join(compressor.chunks)) self.assertEqual(0, start_point) self.assertEqual(len(expected_lines), end_point) @@ -83,106 +83,109 @@ def test_empty_content(self): compressor = self.compressor() # Adding empty bytes should return the 'null' record sha1, start_point, end_point, kind = compressor.compress( - ('empty',), [], 0, None) + ("empty",), [], 0, None + ) self.assertEqual(0, start_point) self.assertEqual(0, end_point) - self.assertEqual('fulltext', kind) + self.assertEqual("fulltext", kind) self.assertEqual(groupcompress._null_sha1, sha1) self.assertEqual(0, compressor.endpoint) self.assertEqual([], compressor.chunks) # Even after adding some content - text = b'some\nbytes\n' - compressor.compress(('content',), [text], len(text), None) + text = b"some\nbytes\n" + compressor.compress(("content",), [text], len(text), None) self.assertGreater(compressor.endpoint, 0) sha1, start_point, end_point, kind = compressor.compress( - ('empty2',), [], 0, None) + ("empty2",), [], 0, None + ) self.assertEqual(0, start_point) self.assertEqual(0, end_point) - self.assertEqual('fulltext', kind) + self.assertEqual("fulltext", kind) self.assertEqual(groupcompress._null_sha1, sha1) def test_extract_from_compressor(self): # Knit fetching will try to reconstruct texts locally which results in # reading something that is in 
the compressor stream already. compressor = self.compressor() - text = b'strange\ncommon long line\nthat needs a 16 byte match\n' - sha1_1, _, _, _ = compressor.compress( - ('label',), [text], len(text), None) + text = b"strange\ncommon long line\nthat needs a 16 byte match\n" + sha1_1, _, _, _ = compressor.compress(("label",), [text], len(text), None) list(compressor.chunks) - text = b'common long line\nthat needs a 16 byte match\ndifferent\n' + text = b"common long line\nthat needs a 16 byte match\ndifferent\n" sha1_2, _, end_point, _ = compressor.compress( - ('newlabel',), [text], len(text), None) + ("newlabel",), [text], len(text), None + ) # get the first out - self.assertEqual(([b'strange\ncommon long line\n' - b'that needs a 16 byte match\n'], sha1_1), - compressor.extract(('label',))) + self.assertEqual( + ([b"strange\ncommon long line\n" b"that needs a 16 byte match\n"], sha1_1), + compressor.extract(("label",)), + ) # and the second - self.assertEqual(([b'common long line\nthat needs a 16 byte match\n' - b'different\n'], sha1_2), - compressor.extract(('newlabel',))) + self.assertEqual( + ( + [b"common long line\nthat needs a 16 byte match\n" b"different\n"], + sha1_2, + ), + compressor.extract(("newlabel",)), + ) def test_pop_last(self): compressor = self.compressor() - text = b'some text\nfor the first entry\n' - _, _, _, _ = compressor.compress( - ('key1',), [text], len(text), None) + text = b"some text\nfor the first entry\n" + _, _, _, _ = compressor.compress(("key1",), [text], len(text), None) expected_lines = list(compressor.chunks) - text = b'some text\nfor the second entry\n' - _, _, _, _ = compressor.compress( - ('key2',), [text], len(text), None) + text = b"some text\nfor the second entry\n" + _, _, _, _ = compressor.compress(("key2",), [text], len(text), None) compressor.pop_last() self.assertEqual(expected_lines, compressor.chunks) class TestPyrexGroupCompressor(TestGroupCompressor): - _test_needs_features = [compiled_groupcompress_feature] compressor = groupcompress.PyrexGroupCompressor def test_stats(self): compressor = self.compressor() - chunks = [b'strange\n', - b'common very very long line\n', - b'plus more text\n'] - compressor.compress( - ('label',), chunks, sum(map(len, chunks)), None) + chunks = [b"strange\n", b"common very very long line\n", b"plus more text\n"] + compressor.compress(("label",), chunks, sum(map(len, chunks)), None) chunks = [ - b'common very very long line\n', - b'plus more text\n', - b'different\n', - b'moredifferent\n'] - compressor.compress( - ('newlabel',), - chunks, sum(map(len, chunks)), None) + b"common very very long line\n", + b"plus more text\n", + b"different\n", + b"moredifferent\n", + ] + compressor.compress(("newlabel",), chunks, sum(map(len, chunks)), None) chunks = [ - b'new\n', - b'common very very long line\n', - b'plus more text\n', - b'different\n', - b'moredifferent\n'] - compressor.compress( - ('label3',), chunks, sum(map(len, chunks)), None) + b"new\n", + b"common very very long line\n", + b"plus more text\n", + b"different\n", + b"moredifferent\n", + ] + compressor.compress(("label3",), chunks, sum(map(len, chunks)), None) self.assertAlmostEqual(1.9, compressor.ratio(), 1) def test_two_nosha_delta(self): compressor = self.compressor() - text = b'strange\ncommon long line\nthat needs a 16 byte match\n' - sha1_1, _, _, _ = compressor.compress(('label',), [text], len(text), None) + text = b"strange\ncommon long line\nthat needs a 16 byte match\n" + sha1_1, _, _, _ = compressor.compress(("label",), [text], len(text), 
None) expected_lines = list(compressor.chunks) - text = b'common long line\nthat needs a 16 byte match\ndifferent\n' + text = b"common long line\nthat needs a 16 byte match\ndifferent\n" sha1_2, start_point, end_point, _ = compressor.compress( - ('newlabel',), [text], len(text), None) + ("newlabel",), [text], len(text), None + ) self.assertEqual(sha_string(text), sha1_2) - expected_lines.extend([ - # 'delta', delta length - b'd\x0f', - # source and target length - b'\x36', - # copy the line common - b'\x91\x0a\x2c', # copy, offset 0x0a, len 0x2c - # add the line different, and the trailing newline - b'\x0adifferent\n', # insert 10 bytes - ]) + expected_lines.extend( + [ + # 'delta', delta length + b"d\x0f", + # source and target length + b"\x36", + # copy the line common + b"\x91\x0a\x2c", # copy, offset 0x0a, len 0x2c + # add the line different, and the trailing newline + b"\x0adifferent\n", # insert 10 bytes + ] + ) self.assertEqualDiffEncoded(expected_lines, compressor.chunks) self.assertEqual(sum(map(len, expected_lines)), end_point) @@ -190,83 +193,83 @@ def test_three_nosha_delta(self): # The first interesting test: make a change that should use lines from # both parents. compressor = self.compressor() - text = b'strange\ncommon very very long line\nwith some extra text\n' - sha1_1, _, _, _ = compressor.compress( - ('label',), [text], len(text), None) - text = b'different\nmoredifferent\nand then some more\n' - sha1_2, _, _, _ = compressor.compress( - ('newlabel',), [text], len(text), None) + text = b"strange\ncommon very very long line\nwith some extra text\n" + sha1_1, _, _, _ = compressor.compress(("label",), [text], len(text), None) + text = b"different\nmoredifferent\nand then some more\n" + sha1_2, _, _, _ = compressor.compress(("newlabel",), [text], len(text), None) expected_lines = list(compressor.chunks) - text = (b'new\ncommon very very long line\nwith some extra text\n' - b'different\nmoredifferent\nand then some more\n') + text = ( + b"new\ncommon very very long line\nwith some extra text\n" + b"different\nmoredifferent\nand then some more\n" + ) sha1_3, start_point, end_point, _ = compressor.compress( - ('label3',), [text], len(text), None) + ("label3",), [text], len(text), None + ) self.assertEqual(sha_string(text), sha1_3) - expected_lines.extend([ - # 'delta', delta length - b'd\x0b', - # source and target length - b'\x5f' - # insert new - b'\x03new', - # Copy of first parent 'common' range - b'\x91\x09\x31' # copy, offset 0x09, 0x31 bytes - # Copy of second parent 'different' range - b'\x91\x3c\x2b' # copy, offset 0x3c, 0x2b bytes - ]) + expected_lines.extend( + [ + # 'delta', delta length + b"d\x0b", + # source and target length + b"\x5f" + # insert new + b"\x03new", + # Copy of first parent 'common' range + b"\x91\x09\x31" # copy, offset 0x09, 0x31 bytes + # Copy of second parent 'different' range + b"\x91\x3c\x2b", # copy, offset 0x3c, 0x2b bytes + ] + ) self.assertEqualDiffEncoded(expected_lines, compressor.chunks) self.assertEqual(sum(map(len, expected_lines)), end_point) class TestPythonGroupCompressor(TestGroupCompressor): - compressor = groupcompress.PythonGroupCompressor def test_stats(self): compressor = self.compressor() - chunks = [b'strange\n', - b'common very very long line\n', - b'plus more text\n'] - compressor.compress( - ('label',), chunks, sum(map(len, chunks)), None) + chunks = [b"strange\n", b"common very very long line\n", b"plus more text\n"] + compressor.compress(("label",), chunks, sum(map(len, chunks)), None) chunks = [ - b'common very very 
long line\n', - b'plus more text\n', - b'different\n', - b'moredifferent\n'] - compressor.compress( - ('newlabel',), chunks, sum(map(len, chunks)), None) + b"common very very long line\n", + b"plus more text\n", + b"different\n", + b"moredifferent\n", + ] + compressor.compress(("newlabel",), chunks, sum(map(len, chunks)), None) chunks = [ - b'new\n', - b'common very very long line\n', - b'plus more text\n', - b'different\n', - b'moredifferent\n'] - compressor.compress( - ('label3',), - chunks, sum(map(len, chunks)), None) + b"new\n", + b"common very very long line\n", + b"plus more text\n", + b"different\n", + b"moredifferent\n", + ] + compressor.compress(("label3",), chunks, sum(map(len, chunks)), None) self.assertAlmostEqual(1.9, compressor.ratio(), 1) def test_two_nosha_delta(self): compressor = self.compressor() - text = b'strange\ncommon long line\nthat needs a 16 byte match\n' - sha1_1, _, _, _ = compressor.compress( - ('label',), [text], len(text), None) + text = b"strange\ncommon long line\nthat needs a 16 byte match\n" + sha1_1, _, _, _ = compressor.compress(("label",), [text], len(text), None) expected_lines = list(compressor.chunks) - text = b'common long line\nthat needs a 16 byte match\ndifferent\n' + text = b"common long line\nthat needs a 16 byte match\ndifferent\n" sha1_2, start_point, end_point, _ = compressor.compress( - ('newlabel',), [text], len(text), None) + ("newlabel",), [text], len(text), None + ) self.assertEqual(sha_string(text), sha1_2) - expected_lines.extend([ - # 'delta', delta length - b'd\x0f', - # target length - b'\x36', - # copy the line common - b'\x91\x0a\x2c', # copy, offset 0x0a, len 0x2c - # add the line different, and the trailing newline - b'\x0adifferent\n', # insert 10 bytes - ]) + expected_lines.extend( + [ + # 'delta', delta length + b"d\x0f", + # target length + b"\x36", + # copy the line common + b"\x91\x0a\x2c", # copy, offset 0x0a, len 0x2c + # add the line different, and the trailing newline + b"\x0adifferent\n", # insert 10 bytes + ] + ) self.assertEqualDiffEncoded(expected_lines, compressor.chunks) self.assertEqual(sum(map(len, expected_lines)), end_point) @@ -274,44 +277,47 @@ def test_three_nosha_delta(self): # The first interesting test: make a change that should use lines from # both parents. 
compressor = self.compressor() - text = b'strange\ncommon very very long line\nwith some extra text\n' - sha1_1, _, _, _ = compressor.compress( - ('label',), [text], len(text), None) - text = b'different\nmoredifferent\nand then some more\n' - sha1_2, _, _, _ = compressor.compress( - ('newlabel',), [text], len(text), None) + text = b"strange\ncommon very very long line\nwith some extra text\n" + sha1_1, _, _, _ = compressor.compress(("label",), [text], len(text), None) + text = b"different\nmoredifferent\nand then some more\n" + sha1_2, _, _, _ = compressor.compress(("newlabel",), [text], len(text), None) expected_lines = list(compressor.chunks) - text = (b'new\ncommon very very long line\nwith some extra text\n' - b'different\nmoredifferent\nand then some more\n') + text = ( + b"new\ncommon very very long line\nwith some extra text\n" + b"different\nmoredifferent\nand then some more\n" + ) sha1_3, start_point, end_point, _ = compressor.compress( - ('label3',), [text], len(text), None) + ("label3",), [text], len(text), None + ) self.assertEqual(sha_string(text), sha1_3) - expected_lines.extend([ - # 'delta', delta length - b'd\x0c', - # target length - b'\x5f' - # insert new - b'\x04new\n', - # Copy of first parent 'common' range - b'\x91\x0a\x30' # copy, offset 0x0a, 0x30 bytes - # Copy of second parent 'different' range - b'\x91\x3c\x2b' # copy, offset 0x3c, 0x2b bytes - ]) + expected_lines.extend( + [ + # 'delta', delta length + b"d\x0c", + # target length + b"\x5f" + # insert new + b"\x04new\n", + # Copy of first parent 'common' range + b"\x91\x0a\x30" # copy, offset 0x0a, 0x30 bytes + # Copy of second parent 'different' range + b"\x91\x3c\x2b", # copy, offset 0x3c, 0x2b bytes + ] + ) self.assertEqualDiffEncoded(expected_lines, compressor.chunks) self.assertEqual(sum(map(len, expected_lines)), end_point) class TestGroupCompressBlock(tests.TestCase): - def make_block(self, key_to_text): """Create a GroupCompressBlock, filling it with the given texts.""" compressor = groupcompress.GroupCompressor() for key in sorted(key_to_text): - compressor.compress( - key, [key_to_text[key]], len(key_to_text[key]), None) - locs = {key: (start, end) for key, (start, _, end, _) - in compressor.labels_deltas.items()} + compressor.compress(key, [key_to_text[key]], len(key_to_text[key]), None) + locs = { + key: (start, end) + for key, (start, _, end, _) in compressor.labels_deltas.items() + } block = compressor.flush() raw_bytes = block.to_bytes() # Go through from_bytes(to_bytes()) so that we start with a compressed @@ -319,36 +325,38 @@ def make_block(self, key_to_text): return locs, groupcompress.GroupCompressBlock.from_bytes(raw_bytes) def test_from_empty_bytes(self): - self.assertRaises(ValueError, - groupcompress.GroupCompressBlock.from_bytes, b'') + self.assertRaises(ValueError, groupcompress.GroupCompressBlock.from_bytes, b"") def test_from_minimal_bytes(self): - block = groupcompress.GroupCompressBlock.from_bytes( - b'gcb1z\n0\n0\n') + block = groupcompress.GroupCompressBlock.from_bytes(b"gcb1z\n0\n0\n") self.assertIsInstance(block, groupcompress.GroupCompressBlock) self.assertIs(None, block._content) - self.assertEqual(b'', block._z_content) + self.assertEqual(b"", block._z_content) block._ensure_content() - self.assertEqual(b'', block._content) - self.assertEqual(b'', block._z_content) + self.assertEqual(b"", block._content) + self.assertEqual(b"", block._z_content) block._ensure_content() # Ensure content is safe to call 2x def test_from_invalid(self): - self.assertRaises(ValueError, - 
groupcompress.GroupCompressBlock.from_bytes, - b'this is not a valid header') + self.assertRaises( + ValueError, + groupcompress.GroupCompressBlock.from_bytes, + b"this is not a valid header", + ) def test_from_bytes(self): - content = (b'a tiny bit of content\n') + content = b"a tiny bit of content\n" z_content = zlib.compress(content) z_bytes = ( - b'gcb1z\n' # group compress block v1 plain - b'%d\n' # Length of compressed content - b'%d\n' # Length of uncompressed content - b'%s' # Compressed content - ) % (len(z_content), len(content), z_content) - block = groupcompress.GroupCompressBlock.from_bytes( - z_bytes) + ( + b"gcb1z\n" # group compress block v1 plain + b"%d\n" # Length of compressed content + b"%d\n" # Length of uncompressed content + b"%s" # Compressed content + ) + % (len(z_content), len(content), z_content) + ) + block = groupcompress.GroupCompressBlock.from_bytes(z_bytes) self.assertEqual(z_content, block._z_content) self.assertIs(None, block._content) self.assertEqual(len(z_content), block._z_content_length) @@ -358,51 +366,61 @@ def test_from_bytes(self): self.assertEqual(content, block._content) def test_to_chunks(self): - content_chunks = [b'this is some content\n', - b'this content will be compressed\n'] + content_chunks = [ + b"this is some content\n", + b"this content will be compressed\n", + ] content_len = sum(map(len, content_chunks)) - content = b''.join(content_chunks) + content = b"".join(content_chunks) gcb = groupcompress.GroupCompressBlock() gcb.set_chunked_content(content_chunks, content_len) total_len, block_chunks = gcb.to_chunks() - block_bytes = b''.join(block_chunks) + block_bytes = b"".join(block_chunks) self.assertEqual(gcb._z_content_length, len(gcb._z_content)) self.assertEqual(total_len, len(block_bytes)) self.assertEqual(gcb._content_length, content_len) - expected_header = (b'gcb1z\n' # group compress block v1 zlib - b'%d\n' # Length of compressed content - b'%d\n' # Length of uncompressed content - ) % (gcb._z_content_length, gcb._content_length) + expected_header = ( + ( + b"gcb1z\n" # group compress block v1 zlib + b"%d\n" # Length of compressed content + b"%d\n" # Length of uncompressed content + ) + % (gcb._z_content_length, gcb._content_length) + ) # The first chunk should be the header chunk. 
It is small, fixed size, # and there is no compelling reason to split it up self.assertEqual(expected_header, block_chunks[0]) self.assertStartsWith(block_bytes, expected_header) - remaining_bytes = block_bytes[len(expected_header):] + remaining_bytes = block_bytes[len(expected_header) :] raw_bytes = zlib.decompress(remaining_bytes) self.assertEqual(content, raw_bytes) def test_to_bytes(self): - content = (b'this is some content\n' - b'this content will be compressed\n') + content = b"this is some content\n" b"this content will be compressed\n" gcb = groupcompress.GroupCompressBlock() gcb.set_content(content) data = gcb.to_bytes() self.assertEqual(gcb._z_content_length, len(gcb._z_content)) self.assertEqual(gcb._content_length, len(content)) - expected_header = (b'gcb1z\n' # group compress block v1 zlib - b'%d\n' # Length of compressed content - b'%d\n' # Length of uncompressed content - ) % (gcb._z_content_length, gcb._content_length) + expected_header = ( + ( + b"gcb1z\n" # group compress block v1 zlib + b"%d\n" # Length of compressed content + b"%d\n" # Length of uncompressed content + ) + % (gcb._z_content_length, gcb._content_length) + ) self.assertStartsWith(data, expected_header) - remaining_bytes = data[len(expected_header):] + remaining_bytes = data[len(expected_header) :] raw_bytes = zlib.decompress(remaining_bytes) self.assertEqual(content, raw_bytes) # we should get the same results if using the chunked version gcb = groupcompress.GroupCompressBlock() - gcb.set_chunked_content([b'this is some content\n' - b'this content will be compressed\n'], - len(content)) + gcb.set_chunked_content( + [b"this is some content\n" b"this content will be compressed\n"], + len(content), + ) old_data = data data = gcb.to_bytes() self.assertEqual(old_data, data) @@ -414,18 +432,18 @@ def test_partial_decomp(self): # compresses a bit too well, we want a combination, so we combine a sha # hash with compressible data. 
for i in range(2048): - next_content = b'%d\nThis is a bit of duplicate text\n' % (i,) + next_content = b"%d\nThis is a bit of duplicate text\n" % (i,) content_chunks.append(next_content) next_sha1 = osutils.sha_string(next_content) - content_chunks.append(next_sha1 + b'\n') - content = b''.join(content_chunks) + content_chunks.append(next_sha1 + b"\n") + content = b"".join(content_chunks) self.assertEqual(158634, len(content)) z_content = zlib.compress(content) self.assertEqual(57182, len(z_content)) block = groupcompress.GroupCompressBlock() block._z_content_chunks = (z_content,) block._z_content_length = len(z_content) - block._compressor_name = 'zlib' + block._compressor_name = "zlib" block._content_length = 158634 self.assertIs(None, block._content) block._ensure_content(100) @@ -434,7 +452,7 @@ def test_partial_decomp(self): self.assertGreaterEqual(len(block._content), 100) # We have not decompressed the whole content self.assertLess(len(block._content), 158634) - self.assertEqualDiff(content[:len(block._content)], block._content) + self.assertEqualDiff(content[: len(block._content)], block._content) # ensuring content that we already have shouldn't cause any more data # to be extracted cur_len = len(block._content) @@ -445,7 +463,7 @@ def test_partial_decomp(self): block._ensure_content(cur_len) self.assertGreaterEqual(len(block._content), cur_len) self.assertLess(len(block._content), 158634) - self.assertEqualDiff(content[:len(block._content)], block._content) + self.assertEqualDiff(content[: len(block._content)], block._content) # And now lets finish block._ensure_content(158634) self.assertEqualDiff(content, block._content) @@ -459,18 +477,18 @@ def test__ensure_all_content(self): # compresses a bit too well, we want a combination, so we combine a sha # hash with compressible data. 
for i in range(2048): - next_content = b'%d\nThis is a bit of duplicate text\n' % (i,) + next_content = b"%d\nThis is a bit of duplicate text\n" % (i,) content_chunks.append(next_content) next_sha1 = osutils.sha_string(next_content) - content_chunks.append(next_sha1 + b'\n') - content = b''.join(content_chunks) + content_chunks.append(next_sha1 + b"\n") + content = b"".join(content_chunks) self.assertEqual(158634, len(content)) z_content = zlib.compress(content) self.assertEqual(57182, len(z_content)) block = groupcompress.GroupCompressBlock() block._z_content_chunks = (z_content,) block._z_content_length = len(z_content) - block._compressor_name = 'zlib' + block._compressor_name = "zlib" block._content_length = 158634 self.assertIs(None, block._content) # The first _ensure_content got all of the required data @@ -481,35 +499,52 @@ def test__ensure_all_content(self): self.assertIs(None, block._z_content_decompressor) def test__dump(self): - dup_content = b'some duplicate content\nwhich is sufficiently long\n' - key_to_text = {(b'1',): dup_content + b'1 unique\n', - (b'2',): dup_content + b'2 extra special\n'} + dup_content = b"some duplicate content\nwhich is sufficiently long\n" + key_to_text = { + (b"1",): dup_content + b"1 unique\n", + (b"2",): dup_content + b"2 extra special\n", + } locs, block = self.make_block(key_to_text) - self.assertEqual([(b'f', len(key_to_text[(b'1',)])), - (b'd', 21, len(key_to_text[(b'2',)]), - [(b'c', 2, len(dup_content)), - (b'i', len(b'2 extra special\n'), b'') - ]), - ], block._dump()) - - -class TestCaseWithGroupCompressVersionedFiles( - tests.TestCaseWithMemoryTransport): - - def make_test_vf(self, create_graph, keylength=1, do_cleanup=True, - dir='.', inconsistency_fatal=True): + self.assertEqual( + [ + (b"f", len(key_to_text[(b"1",)])), + ( + b"d", + 21, + len(key_to_text[(b"2",)]), + [ + (b"c", 2, len(dup_content)), + (b"i", len(b"2 extra special\n"), b""), + ], + ), + ], + block._dump(), + ) + + +class TestCaseWithGroupCompressVersionedFiles(tests.TestCaseWithMemoryTransport): + def make_test_vf( + self, + create_graph, + keylength=1, + do_cleanup=True, + dir=".", + inconsistency_fatal=True, + ): t = self.get_transport(dir) t.ensure_base() - vf = groupcompress.make_pack_factory(graph=create_graph, - delta=False, keylength=keylength, - inconsistency_fatal=inconsistency_fatal)(t) + vf = groupcompress.make_pack_factory( + graph=create_graph, + delta=False, + keylength=keylength, + inconsistency_fatal=inconsistency_fatal, + )(t) if do_cleanup: self.addCleanup(groupcompress.cleanup_pack_group, vf) return vf class TestGroupCompressVersionedFiles(TestCaseWithGroupCompressVersionedFiles): - def make_g_index(self, name, ref_lists=0, nodes=None): if nodes is None: nodes = [] @@ -522,62 +557,82 @@ def make_g_index(self, name, ref_lists=0, nodes=None): return btree_index.BTreeGraphIndex(trans, name, size) def make_g_index_missing_parent(self): - graph_index = self.make_g_index('missing_parent', 1, - [((b'parent', ), b'2 78 2 10', ([],)), - ((b'tip', ), b'2 78 2 10', - ([(b'parent', ), (b'missing-parent', )],)), - ]) + graph_index = self.make_g_index( + "missing_parent", + 1, + [ + ((b"parent",), b"2 78 2 10", ([],)), + ((b"tip",), b"2 78 2 10", ([(b"parent",), (b"missing-parent",)],)), + ], + ) return graph_index def test_get_record_stream_as_requested(self): # Consider promoting 'as-requested' to general availability, and # make this a VF interface test - vf = self.make_test_vf(False, dir='source') - vf.add_lines((b'a',), (), [b'lines\n']) - 
vf.add_lines((b'b',), (), [b'lines\n']) - vf.add_lines((b'c',), (), [b'lines\n']) - vf.add_lines((b'd',), (), [b'lines\n']) + vf = self.make_test_vf(False, dir="source") + vf.add_lines((b"a",), (), [b"lines\n"]) + vf.add_lines((b"b",), (), [b"lines\n"]) + vf.add_lines((b"c",), (), [b"lines\n"]) + vf.add_lines((b"d",), (), [b"lines\n"]) vf.writer.end() - keys = [record.key for record in vf.get_record_stream( - [(b'a',), (b'b',), (b'c',), (b'd',)], - 'as-requested', False)] - self.assertEqual([(b'a',), (b'b',), (b'c',), (b'd',)], keys) - keys = [record.key for record in vf.get_record_stream( - [(b'b',), (b'a',), (b'd',), (b'c',)], - 'as-requested', False)] - self.assertEqual([(b'b',), (b'a',), (b'd',), (b'c',)], keys) + keys = [ + record.key + for record in vf.get_record_stream( + [(b"a",), (b"b",), (b"c",), (b"d",)], "as-requested", False + ) + ] + self.assertEqual([(b"a",), (b"b",), (b"c",), (b"d",)], keys) + keys = [ + record.key + for record in vf.get_record_stream( + [(b"b",), (b"a",), (b"d",), (b"c",)], "as-requested", False + ) + ] + self.assertEqual([(b"b",), (b"a",), (b"d",), (b"c",)], keys) # It should work even after being repacked into another VF - vf2 = self.make_test_vf(False, dir='target') - vf2.insert_record_stream(vf.get_record_stream( - [(b'b',), (b'a',), (b'd',), (b'c',)], 'as-requested', False)) + vf2 = self.make_test_vf(False, dir="target") + vf2.insert_record_stream( + vf.get_record_stream( + [(b"b",), (b"a",), (b"d",), (b"c",)], "as-requested", False + ) + ) vf2.writer.end() - keys = [record.key for record in vf2.get_record_stream( - [(b'a',), (b'b',), (b'c',), (b'd',)], - 'as-requested', False)] - self.assertEqual([(b'a',), (b'b',), (b'c',), (b'd',)], keys) - keys = [record.key for record in vf2.get_record_stream( - [(b'b',), (b'a',), (b'd',), (b'c',)], - 'as-requested', False)] - self.assertEqual([(b'b',), (b'a',), (b'd',), (b'c',)], keys) + keys = [ + record.key + for record in vf2.get_record_stream( + [(b"a",), (b"b",), (b"c",), (b"d",)], "as-requested", False + ) + ] + self.assertEqual([(b"a",), (b"b",), (b"c",), (b"d",)], keys) + keys = [ + record.key + for record in vf2.get_record_stream( + [(b"b",), (b"a",), (b"d",), (b"c",)], "as-requested", False + ) + ] + self.assertEqual([(b"b",), (b"a",), (b"d",), (b"c",)], keys) def test_get_record_stream_max_bytes_to_index_default(self): - vf = self.make_test_vf(True, dir='source') - vf.add_lines((b'a',), (), [b'lines\n']) + vf = self.make_test_vf(True, dir="source") + vf.add_lines((b"a",), (), [b"lines\n"]) vf.writer.end() - record = next(vf.get_record_stream([(b'a',)], 'unordered', True)) - self.assertEqual(vf._DEFAULT_COMPRESSOR_SETTINGS, - record._manager._get_compressor_settings()) + record = next(vf.get_record_stream([(b"a",)], "unordered", True)) + self.assertEqual( + vf._DEFAULT_COMPRESSOR_SETTINGS, record._manager._get_compressor_settings() + ) def test_get_record_stream_accesses_compressor_settings(self): - vf = self.make_test_vf(True, dir='source') - vf.add_lines((b'a',), (), [b'lines\n']) + vf = self.make_test_vf(True, dir="source") + vf.add_lines((b"a",), (), [b"lines\n"]) vf.writer.end() vf._max_bytes_to_index = 1234 - record = next(vf.get_record_stream([(b'a',)], 'unordered', True)) - self.assertEqual({'max_bytes_to_index': 1234}, - record._manager._get_compressor_settings()) + record = next(vf.get_record_stream([(b"a",)], "unordered", True)) + self.assertEqual( + {"max_bytes_to_index": 1234}, record._manager._get_compressor_settings() + ) @staticmethod def grouped_stream(revision_ids, first_parents=()): 
@@ -585,79 +640,83 @@ def grouped_stream(revision_ids, first_parents=()): for revision_id in revision_ids: key = (revision_id,) record = versionedfile.FulltextContentFactory( - key, parents, None, - b'some content that is\n' - b'identical except for\n' - b'revision_id:%s\n' % (revision_id,)) + key, + parents, + None, + b"some content that is\n" + b"identical except for\n" + b"revision_id:%s\n" % (revision_id,), + ) yield record parents = (key,) def test_insert_record_stream_reuses_blocks(self): - vf = self.make_test_vf(True, dir='source') + vf = self.make_test_vf(True, dir="source") # One group, a-d - vf.insert_record_stream(self.grouped_stream([b'a', b'b', b'c', b'd'])) + vf.insert_record_stream(self.grouped_stream([b"a", b"b", b"c", b"d"])) # Second group, e-h - vf.insert_record_stream(self.grouped_stream( - [b'e', b'f', b'g', b'h'], first_parents=((b'd',),))) + vf.insert_record_stream( + self.grouped_stream([b"e", b"f", b"g", b"h"], first_parents=((b"d",),)) + ) block_bytes = {} stream = vf.get_record_stream( - [(r.encode(),) for r in 'abcdefgh'], 'unordered', False) + [(r.encode(),) for r in "abcdefgh"], "unordered", False + ) num_records = 0 for record in stream: - if record.key in [(b'a',), (b'e',)]: - self.assertEqual('groupcompress-block', record.storage_kind) + if record.key in [(b"a",), (b"e",)]: + self.assertEqual("groupcompress-block", record.storage_kind) else: - self.assertEqual('groupcompress-block-ref', - record.storage_kind) + self.assertEqual("groupcompress-block-ref", record.storage_kind) block_bytes[record.key] = record._manager._block._z_content num_records += 1 self.assertEqual(8, num_records) - for r in 'abcd': + for r in "abcd": key = (r.encode(),) - self.assertIs(block_bytes[key], block_bytes[(b'a',)]) - self.assertNotEqual(block_bytes[key], block_bytes[(b'e',)]) - for r in 'efgh': + self.assertIs(block_bytes[key], block_bytes[(b"a",)]) + self.assertNotEqual(block_bytes[key], block_bytes[(b"e",)]) + for r in "efgh": key = (r.encode(),) - self.assertIs(block_bytes[key], block_bytes[(b'e',)]) - self.assertNotEqual(block_bytes[key], block_bytes[(b'a',)]) + self.assertIs(block_bytes[key], block_bytes[(b"e",)]) + self.assertNotEqual(block_bytes[key], block_bytes[(b"a",)]) # Now copy the blocks into another vf, and ensure that the blocks are # preserved without creating new entries - vf2 = self.make_test_vf(True, dir='target') - keys = [(r.encode(),) for r in 'abcdefgh'] + vf2 = self.make_test_vf(True, dir="target") + keys = [(r.encode(),) for r in "abcdefgh"] # ordering in 'groupcompress' order, should actually swap the groups in # the target vf, but the groups themselves should not be disturbed. 
def small_size_stream(): - for record in vf.get_record_stream(keys, 'groupcompress', False): - record._manager._full_enough_block_size = \ + for record in vf.get_record_stream(keys, "groupcompress", False): + record._manager._full_enough_block_size = ( record._manager._block._content_length + ) yield record vf2.insert_record_stream(small_size_stream()) - stream = vf2.get_record_stream(keys, 'groupcompress', False) + stream = vf2.get_record_stream(keys, "groupcompress", False) vf2.writer.end() num_records = 0 for record in stream: num_records += 1 - self.assertEqual(block_bytes[record.key], - record._manager._block._z_content) + self.assertEqual(block_bytes[record.key], record._manager._block._z_content) self.assertEqual(8, num_records) def test_insert_record_stream_packs_on_the_fly(self): - vf = self.make_test_vf(True, dir='source') + vf = self.make_test_vf(True, dir="source") # One group, a-d - vf.insert_record_stream(self.grouped_stream([b'a', b'b', b'c', b'd'])) + vf.insert_record_stream(self.grouped_stream([b"a", b"b", b"c", b"d"])) # Second group, e-h - vf.insert_record_stream(self.grouped_stream( - [b'e', b'f', b'g', b'h'], first_parents=((b'd',),))) + vf.insert_record_stream( + self.grouped_stream([b"e", b"f", b"g", b"h"], first_parents=((b"d",),)) + ) # Now copy the blocks into another vf, and see that the # insert_record_stream rebuilt a new block on-the-fly because of # under-utilization - vf2 = self.make_test_vf(True, dir='target') - keys = [(r.encode(),) for r in 'abcdefgh'] - vf2.insert_record_stream(vf.get_record_stream( - keys, 'groupcompress', False)) - stream = vf2.get_record_stream(keys, 'groupcompress', False) + vf2 = self.make_test_vf(True, dir="target") + keys = [(r.encode(),) for r in "abcdefgh"] + vf2.insert_record_stream(vf.get_record_stream(keys, "groupcompress", False)) + stream = vf2.get_record_stream(keys, "groupcompress", False) vf2.writer.end() num_records = 0 # All of the records should be recombined into a single block @@ -671,28 +730,30 @@ def test_insert_record_stream_packs_on_the_fly(self): self.assertEqual(8, num_records) def test__insert_record_stream_no_reuse_block(self): - vf = self.make_test_vf(True, dir='source') + vf = self.make_test_vf(True, dir="source") # One group, a-d - vf.insert_record_stream(self.grouped_stream([b'a', b'b', b'c', b'd'])) + vf.insert_record_stream(self.grouped_stream([b"a", b"b", b"c", b"d"])) # Second group, e-h - vf.insert_record_stream(self.grouped_stream( - [b'e', b'f', b'g', b'h'], first_parents=((b'd',),))) + vf.insert_record_stream( + self.grouped_stream([b"e", b"f", b"g", b"h"], first_parents=((b"d",),)) + ) vf.writer.end() - keys = [(r.encode(),) for r in 'abcdefgh'] - self.assertEqual(8, len(list( - vf.get_record_stream(keys, 'unordered', False)))) + keys = [(r.encode(),) for r in "abcdefgh"] + self.assertEqual(8, len(list(vf.get_record_stream(keys, "unordered", False)))) # Now copy the blocks into another vf, and ensure that the blocks are # preserved without creating new entries - vf2 = self.make_test_vf(True, dir='target') + vf2 = self.make_test_vf(True, dir="target") # ordering in 'groupcompress' order, should actually swap the groups in # the target vf, but the groups themselves should not be disturbed. 
- list(vf2._insert_record_stream(vf.get_record_stream( - keys, 'groupcompress', False), - reuse_blocks=False)) + list( + vf2._insert_record_stream( + vf.get_record_stream(keys, "groupcompress", False), reuse_blocks=False + ) + ) vf2.writer.end() # After inserting with reuse_blocks=False, we should have everything in # a single new block. - stream = vf2.get_record_stream(keys, 'groupcompress', False) + stream = vf2.get_record_stream(keys, "groupcompress", False) block = None for record in stream: if block is None: @@ -703,44 +764,52 @@ def test__insert_record_stream_no_reuse_block(self): def test_add_missing_noncompression_parent_unvalidated_index(self): unvalidated = self.make_g_index_missing_parent() combined = _mod_index.CombinedGraphIndex([unvalidated]) - index = groupcompress._GCGraphIndex(combined, - is_locked=lambda: True, parents=True, - track_external_parent_refs=True) + index = groupcompress._GCGraphIndex( + combined, + is_locked=lambda: True, + parents=True, + track_external_parent_refs=True, + ) index.scan_unvalidated_index(unvalidated) - self.assertEqual( - frozenset([(b'missing-parent',)]), index.get_missing_parents()) + self.assertEqual(frozenset([(b"missing-parent",)]), index.get_missing_parents()) def test_track_external_parent_refs(self): - g_index = self.make_g_index('empty', 1, []) + g_index = self.make_g_index("empty", 1, []) mod_index = btree_index.BTreeBuilder(1, 1) combined = _mod_index.CombinedGraphIndex([g_index, mod_index]) - index = groupcompress._GCGraphIndex(combined, - is_locked=lambda: True, parents=True, - add_callback=mod_index.add_nodes, - track_external_parent_refs=True) - index.add_records([ - ((b'new-key',), b'2 10 2 10', [((b'parent-1',), (b'parent-2',))])]) + index = groupcompress._GCGraphIndex( + combined, + is_locked=lambda: True, + parents=True, + add_callback=mod_index.add_nodes, + track_external_parent_refs=True, + ) + index.add_records( + [((b"new-key",), b"2 10 2 10", [((b"parent-1",), (b"parent-2",))])] + ) self.assertEqual( - frozenset([(b'parent-1',), (b'parent-2',)]), - index.get_missing_parents()) + frozenset([(b"parent-1",), (b"parent-2",)]), index.get_missing_parents() + ) def make_source_with_b(self, a_parent, path): source = self.make_test_vf(True, dir=path) - source.add_lines((b'a',), (), [b'lines\n']) + source.add_lines((b"a",), (), [b"lines\n"]) if a_parent: - b_parents = ((b'a',),) + b_parents = ((b"a",),) else: b_parents = () - source.add_lines((b'b',), b_parents, [b'lines\n']) + source.add_lines((b"b",), b_parents, [b"lines\n"]) return source def do_inconsistent_inserts(self, inconsistency_fatal): - target = self.make_test_vf(True, dir='target', - inconsistency_fatal=inconsistency_fatal) + target = self.make_test_vf( + True, dir="target", inconsistency_fatal=inconsistency_fatal + ) for x in range(2): - source = self.make_source_with_b(x == 1, f'source{x}') - target.insert_record_stream(source.get_record_stream( - [(b'b',)], 'unordered', False)) + source = self.make_source_with_b(x == 1, f"source{x}") + target.insert_record_stream( + source.get_record_stream([(b"b",)], "unordered", False) + ) def test_inconsistent_redundant_inserts_warn(self): """Should not insert a record that is already present.""" @@ -748,6 +817,7 @@ def test_inconsistent_redundant_inserts_warn(self): def warning(template, args): warnings.append(template % args) + _trace_warning = trace.warning trace.warning = warning try: @@ -758,21 +828,25 @@ def warning(template, args): "\n".join(warnings), r"^inconsistent details in skipped record: \(b?'b',\)" r" \(b?'42 32 
0 8', \(\(\),\)\)" - r" \(b?'74 32 0 8', \(\(\(b?'a',\),\),\)\)$") + r" \(b?'74 32 0 8', \(\(\(b?'a',\),\),\)\)$", + ) def test_inconsistent_redundant_inserts_raises(self): - e = self.assertRaises(knit.KnitCorrupt, self.do_inconsistent_inserts, - inconsistency_fatal=True) - self.assertContainsRe(str(e), r"Knit.* corrupt: inconsistent details" - r" in add_records:" - r" \(b?'b',\) \(b?'42 32 0 8', \(\(\),\)\)" - r" \(b?'74 32 0 8', \(\(\(b?'a',\),\),\)\)") + e = self.assertRaises( + knit.KnitCorrupt, self.do_inconsistent_inserts, inconsistency_fatal=True + ) + self.assertContainsRe( + str(e), + r"Knit.* corrupt: inconsistent details" + r" in add_records:" + r" \(b?'b',\) \(b?'42 32 0 8', \(\(\),\)\)" + r" \(b?'74 32 0 8', \(\(\(b?'a',\),\),\)\)", + ) def test_clear_cache(self): - vf = self.make_source_with_b(True, 'source') + vf = self.make_source_with_b(True, "source") vf.writer.end() - for _record in vf.get_record_stream([(b'a',), (b'b',)], 'unordered', - True): + for _record in vf.get_record_stream([(b"a",), (b"b",)], "unordered", True): pass self.assertGreater(len(vf._group_cache), 0) vf.clear_cache() @@ -780,12 +854,12 @@ def test_clear_cache(self): class TestGroupCompressConfig(tests.TestCaseWithTransport): - def make_test_vf(self): - t = self.get_transport('.') + t = self.get_transport(".") t.ensure_base() - factory = groupcompress.make_pack_factory(graph=True, - delta=False, keylength=1, inconsistency_fatal=True) + factory = groupcompress.make_pack_factory( + graph=True, delta=False, keylength=1, inconsistency_fatal=True + ) vf = factory(t) self.addCleanup(groupcompress.cleanup_pack_group, vf) return vf @@ -793,15 +867,15 @@ def make_test_vf(self): def test_max_bytes_to_index_default(self): vf = self.make_test_vf() gc = vf._make_group_compressor() - self.assertEqual(vf._DEFAULT_MAX_BYTES_TO_INDEX, - vf._max_bytes_to_index) + self.assertEqual(vf._DEFAULT_MAX_BYTES_TO_INDEX, vf._max_bytes_to_index) if isinstance(gc, groupcompress.PyrexGroupCompressor): - self.assertEqual(vf._DEFAULT_MAX_BYTES_TO_INDEX, - gc._delta_index._max_bytes_to_index) + self.assertEqual( + vf._DEFAULT_MAX_BYTES_TO_INDEX, gc._delta_index._max_bytes_to_index + ) def test_max_bytes_to_index_in_config(self): c = config.GlobalConfig() - c.set_user_option('bzr.groupcompress.max_bytes_to_index', '10000') + c.set_user_option("bzr.groupcompress.max_bytes_to_index", "10000") vf = self.make_test_vf() gc = vf._make_group_compressor() self.assertEqual(10000, vf._max_bytes_to_index) @@ -810,16 +884,16 @@ def test_max_bytes_to_index_in_config(self): def test_max_bytes_to_index_bad_config(self): c = config.GlobalConfig() - c.set_user_option('bzr.groupcompress.max_bytes_to_index', 'boogah') + c.set_user_option("bzr.groupcompress.max_bytes_to_index", "boogah") vf = self.make_test_vf() # TODO: This is triggering a warning, we might want to trap and make # sure it is readable. gc = vf._make_group_compressor() - self.assertEqual(vf._DEFAULT_MAX_BYTES_TO_INDEX, - vf._max_bytes_to_index) + self.assertEqual(vf._DEFAULT_MAX_BYTES_TO_INDEX, vf._max_bytes_to_index) if isinstance(gc, groupcompress.PyrexGroupCompressor): - self.assertEqual(vf._DEFAULT_MAX_BYTES_TO_INDEX, - gc._delta_index._max_bytes_to_index) + self.assertEqual( + vf._DEFAULT_MAX_BYTES_TO_INDEX, gc._delta_index._max_bytes_to_index + ) class StubGCVF: @@ -842,45 +916,44 @@ def test_add_key_new_read_memo(self): # where index_memo is: (idx, offset, len, factory_start, factory_end) # and (idx, offset, size) is known as the 'read_memo', identifying the # raw bytes needed. 
- read_memo = ('fake index', 100, 50) - locations = { - ('key',): (read_memo + (None, None), None, None, None)} + read_memo = ("fake index", 100, 50) + locations = {("key",): (read_memo + (None, None), None, None, None)} batcher = groupcompress._BatchingBlockFetcher(StubGCVF(), locations) - total_size = batcher.add_key(('key',)) + total_size = batcher.add_key(("key",)) self.assertEqual(50, total_size) - self.assertEqual([('key',)], batcher.keys) + self.assertEqual([("key",)], batcher.keys) self.assertEqual([read_memo], batcher.memos_to_get) def test_add_key_duplicate_read_memo(self): """read_memos that occur multiple times in a batch will only be fetched once. """ - read_memo = ('fake index', 100, 50) + read_memo = ("fake index", 100, 50) # Two keys, both sharing the same read memo (but different overall # index_memos). locations = { - ('key1',): (read_memo + (0, 1), None, None, None), - ('key2',): (read_memo + (1, 2), None, None, None)} + ("key1",): (read_memo + (0, 1), None, None, None), + ("key2",): (read_memo + (1, 2), None, None, None), + } batcher = groupcompress._BatchingBlockFetcher(StubGCVF(), locations) - total_size = batcher.add_key(('key1',)) - total_size = batcher.add_key(('key2',)) + total_size = batcher.add_key(("key1",)) + total_size = batcher.add_key(("key2",)) self.assertEqual(50, total_size) - self.assertEqual([('key1',), ('key2',)], batcher.keys) + self.assertEqual([("key1",), ("key2",)], batcher.keys) self.assertEqual([read_memo], batcher.memos_to_get) def test_add_key_cached_read_memo(self): """Adding a key with a cached read_memo will not cause that read_memo to be added to the list to fetch. """ - read_memo = ('fake index', 100, 50) + read_memo = ("fake index", 100, 50) gcvf = StubGCVF() - gcvf._group_cache[read_memo] = 'fake block' - locations = { - ('key',): (read_memo + (None, None), None, None, None)} + gcvf._group_cache[read_memo] = "fake block" + locations = {("key",): (read_memo + (None, None), None, None, None)} batcher = groupcompress._BatchingBlockFetcher(gcvf, locations) - total_size = batcher.add_key(('key',)) + total_size = batcher.add_key(("key",)) self.assertEqual(0, total_size) - self.assertEqual([('key',)], batcher.keys) + self.assertEqual([("key",)], batcher.keys) self.assertEqual([], batcher.memos_to_get) def test_yield_factories_empty(self): @@ -890,57 +963,58 @@ def test_yield_factories_empty(self): def test_yield_factories_calls_get_blocks(self): """Uncached memos are retrieved via get_blocks.""" - read_memo1 = ('fake index', 100, 50) - read_memo2 = ('fake index', 150, 40) + read_memo1 = ("fake index", 100, 50) + read_memo2 = ("fake index", 150, 40) gcvf = StubGCVF( canned_get_blocks=[ (read_memo1, groupcompress.GroupCompressBlock()), - (read_memo2, groupcompress.GroupCompressBlock())]) + (read_memo2, groupcompress.GroupCompressBlock()), + ] + ) locations = { - ('key1',): (read_memo1 + (0, 0), None, None, None), - ('key2',): (read_memo2 + (0, 0), None, None, None)} + ("key1",): (read_memo1 + (0, 0), None, None, None), + ("key2",): (read_memo2 + (0, 0), None, None, None), + } batcher = groupcompress._BatchingBlockFetcher(gcvf, locations) - batcher.add_key(('key1',)) - batcher.add_key(('key2',)) + batcher.add_key(("key1",)) + batcher.add_key(("key2",)) factories = list(batcher.yield_factories(full_flush=True)) self.assertLength(2, factories) keys = [f.key for f in factories] kinds = [f.storage_kind for f in factories] - self.assertEqual([('key1',), ('key2',)], keys) - self.assertEqual(['groupcompress-block', 'groupcompress-block'], kinds) + 
self.assertEqual([("key1",), ("key2",)], keys) + self.assertEqual(["groupcompress-block", "groupcompress-block"], kinds) def test_yield_factories_flushing(self): """yield_factories holds back on yielding results from the final block unless passed full_flush=True. """ fake_block = groupcompress.GroupCompressBlock() - read_memo = ('fake index', 100, 50) + read_memo = ("fake index", 100, 50) gcvf = StubGCVF() gcvf._group_cache[read_memo] = fake_block - locations = { - ('key',): (read_memo + (0, 0), None, None, None)} + locations = {("key",): (read_memo + (0, 0), None, None, None)} batcher = groupcompress._BatchingBlockFetcher(gcvf, locations) - batcher.add_key(('key',)) + batcher.add_key(("key",)) self.assertEqual([], list(batcher.yield_factories())) factories = list(batcher.yield_factories(full_flush=True)) self.assertLength(1, factories) - self.assertEqual(('key',), factories[0].key) - self.assertEqual('groupcompress-block', factories[0].storage_kind) + self.assertEqual(("key",), factories[0].key) + self.assertEqual("groupcompress-block", factories[0].storage_kind) class TestLazyGroupCompress(tests.TestCaseWithTransport): - _texts = { - (b'key1',): b"this is a text\n" + (b"key1",): b"this is a text\n" b"with a reasonable amount of compressible bytes\n" b"which can be shared between various other texts\n", - (b'key2',): b"another text\n" + (b"key2",): b"another text\n" b"with a reasonable amount of compressible bytes\n" b"which can be shared between various other texts\n", - (b'key3',): b"yet another text which won't be extracted\n" + (b"key3",): b"yet another text which won't be extracted\n" b"with a reasonable amount of compressible bytes\n" b"which can be shared between various other texts\n", - (b'key4',): b"this will be extracted\n" + (b"key4",): b"this will be extracted\n" b"but references most of its bytes from\n" b"yet another text which won't be extracted\n" b"with a reasonable amount of compressible bytes\n" @@ -951,10 +1025,11 @@ def make_block(self, key_to_text): """Create a GroupCompressBlock, filling it with the given texts.""" compressor = groupcompress.GroupCompressor() for key in sorted(key_to_text): - compressor.compress( - key, [key_to_text[key]], len(key_to_text[key]), None) - locs = {key: (start, end) for key, (start, _, end, _) - in compressor.labels_deltas.items()} + compressor.compress(key, [key_to_text[key]], len(key_to_text[key]), None) + locs = { + key: (start, end) + for key, (start, _, end, _) in compressor.labels_deltas.items() + } block = compressor.flush() raw_bytes = block.to_bytes() return locs, groupcompress.GroupCompressBlock.from_bytes(raw_bytes) @@ -973,26 +1048,26 @@ def make_block_and_full_manager(self, texts): def test_get_fulltexts(self): locations, block = self.make_block(self._texts) manager = groupcompress._LazyGroupContentManager(block) - self.add_key_to_manager((b'key1',), locations, block, manager) - self.add_key_to_manager((b'key2',), locations, block, manager) + self.add_key_to_manager((b"key1",), locations, block, manager) + self.add_key_to_manager((b"key2",), locations, block, manager) result_order = [] for record in manager.get_record_stream(): result_order.append(record.key) text = self._texts[record.key] - self.assertEqual(text, record.get_bytes_as('fulltext')) - self.assertEqual([(b'key1',), (b'key2',)], result_order) + self.assertEqual(text, record.get_bytes_as("fulltext")) + self.assertEqual([(b"key1",), (b"key2",)], result_order) # If we build the manager in the opposite order, we should get them # back in the opposite order manager 
= groupcompress._LazyGroupContentManager(block) - self.add_key_to_manager((b'key2',), locations, block, manager) - self.add_key_to_manager((b'key1',), locations, block, manager) + self.add_key_to_manager((b"key2",), locations, block, manager) + self.add_key_to_manager((b"key1",), locations, block, manager) result_order = [] for record in manager.get_record_stream(): result_order.append(record.key) text = self._texts[record.key] - self.assertEqual(text, record.get_bytes_as('fulltext')) - self.assertEqual([(b'key2',), (b'key1',)], result_order) + self.assertEqual(text, record.get_bytes_as("fulltext")) + self.assertEqual([(b"key2",), (b"key1",)], result_order) def test__wire_bytes_no_keys(self): locations, block = self.make_block(self._texts) @@ -1002,59 +1077,60 @@ def test__wire_bytes_no_keys(self): # We should have triggered a strip, since we aren't using any content stripped_block = manager._block.to_bytes() self.assertGreater(block_length, len(stripped_block)) - empty_z_header = zlib.compress(b'') - self.assertEqual(b'groupcompress-block\n' - b'8\n' # len(compress('')) - b'0\n' # len('') - b'%d\n' # compressed block len - b'%s' # zheader - b'%s' # block - % (len(stripped_block), empty_z_header, - stripped_block), - wire_bytes) + empty_z_header = zlib.compress(b"") + self.assertEqual( + b"groupcompress-block\n" + b"8\n" # len(compress('')) + b"0\n" # len('') + b"%d\n" # compressed block len + b"%s" # zheader + b"%s" % (len(stripped_block), empty_z_header, stripped_block), # block + wire_bytes, + ) def test__wire_bytes(self): locations, block = self.make_block(self._texts) manager = groupcompress._LazyGroupContentManager(block) - self.add_key_to_manager((b'key1',), locations, block, manager) - self.add_key_to_manager((b'key4',), locations, block, manager) + self.add_key_to_manager((b"key1",), locations, block, manager) + self.add_key_to_manager((b"key4",), locations, block, manager) block_bytes = block.to_bytes() wire_bytes = manager._wire_bytes() - (storage_kind, z_header_len, header_len, - block_len, rest) = wire_bytes.split(b'\n', 4) + (storage_kind, z_header_len, header_len, block_len, rest) = wire_bytes.split( + b"\n", 4 + ) z_header_len = int(z_header_len) header_len = int(header_len) block_len = int(block_len) - self.assertEqual(b'groupcompress-block', storage_kind) + self.assertEqual(b"groupcompress-block", storage_kind) self.assertEqual(34, z_header_len) self.assertEqual(26, header_len) self.assertEqual(len(block_bytes), block_len) z_header = rest[:z_header_len] header = zlib.decompress(z_header) self.assertEqual(header_len, len(header)) - entry1 = locations[(b'key1',)] - entry4 = locations[(b'key4',)] - self.assertEqualDiff(b'key1\n' - b'\n' # no parents - b'%d\n' # start offset - b'%d\n' # end offset - b'key4\n' - b'\n' - b'%d\n' - b'%d\n' - % (entry1[0], entry1[1], - entry4[0], entry4[1]), - header) + entry1 = locations[(b"key1",)] + entry4 = locations[(b"key4",)] + self.assertEqualDiff( + b"key1\n" + b"\n" # no parents + b"%d\n" # start offset + b"%d\n" # end offset + b"key4\n" + b"\n" + b"%d\n" + b"%d\n" % (entry1[0], entry1[1], entry4[0], entry4[1]), + header, + ) z_block = rest[z_header_len:] self.assertEqual(block_bytes, z_block) def test_from_bytes(self): locations, block = self.make_block(self._texts) manager = groupcompress._LazyGroupContentManager(block) - self.add_key_to_manager((b'key1',), locations, block, manager) - self.add_key_to_manager((b'key4',), locations, block, manager) + self.add_key_to_manager((b"key1",), locations, block, manager) + 
self.add_key_to_manager((b"key4",), locations, block, manager) wire_bytes = manager._wire_bytes() - self.assertStartsWith(wire_bytes, b'groupcompress-block\n') + self.assertStartsWith(wire_bytes, b"groupcompress-block\n") manager = groupcompress._LazyGroupContentManager.from_bytes(wire_bytes) self.assertIsInstance(manager, groupcompress._LazyGroupContentManager) self.assertEqual(2, len(manager._factories)) @@ -1063,8 +1139,8 @@ def test_from_bytes(self): for record in manager.get_record_stream(): result_order.append(record.key) text = self._texts[record.key] - self.assertEqual(text, record.get_bytes_as('fulltext')) - self.assertEqual([(b'key1',), (b'key4',)], result_order) + self.assertEqual(text, record.get_bytes_as("fulltext")) + self.assertEqual([(b"key1",), (b"key4",)], result_order) def test__check_rebuild_no_changes(self): block, manager = self.make_block_and_full_manager(self._texts) @@ -1075,29 +1151,27 @@ def test__check_rebuild_only_one(self): locations, block = self.make_block(self._texts) manager = groupcompress._LazyGroupContentManager(block) # Request just the first key, which should trigger a 'strip' action - self.add_key_to_manager((b'key1',), locations, block, manager) + self.add_key_to_manager((b"key1",), locations, block, manager) manager._check_rebuild_block() self.assertIsNot(block, manager._block) self.assertGreater(block._content_length, manager._block._content_length) # We should be able to still get the content out of this block, though # it should only have 1 entry for record in manager.get_record_stream(): - self.assertEqual((b'key1',), record.key) - self.assertEqual(self._texts[record.key], - record.get_bytes_as('fulltext')) + self.assertEqual((b"key1",), record.key) + self.assertEqual(self._texts[record.key], record.get_bytes_as("fulltext")) def test__check_rebuild_middle(self): locations, block = self.make_block(self._texts) manager = groupcompress._LazyGroupContentManager(block) # Request a small key in the middle should trigger a 'rebuild' - self.add_key_to_manager((b'key4',), locations, block, manager) + self.add_key_to_manager((b"key4",), locations, block, manager) manager._check_rebuild_block() self.assertIsNot(block, manager._block) self.assertGreater(block._content_length, manager._block._content_length) for record in manager.get_record_stream(): - self.assertEqual((b'key4',), record.key) - self.assertEqual(self._texts[record.key], - record.get_bytes_as('fulltext')) + self.assertEqual((b"key4",), record.key) + self.assertEqual(self._texts[record.key], record.get_bytes_as("fulltext")) def test_manager_default_compressor_settings(self): locations, old_block = self.make_block(self._texts) @@ -1105,40 +1179,46 @@ def test_manager_default_compressor_settings(self): gcvf = groupcompress.GroupCompressVersionedFiles # It doesn't greedily evaluate _max_bytes_to_index self.assertIs(None, manager._compressor_settings) - self.assertEqual(gcvf._DEFAULT_COMPRESSOR_SETTINGS, - manager._get_compressor_settings()) + self.assertEqual( + gcvf._DEFAULT_COMPRESSOR_SETTINGS, manager._get_compressor_settings() + ) def test_manager_custom_compressor_settings(self): locations, old_block = self.make_block(self._texts) called = [] def compressor_settings(): - called.append('called') + called.append("called") return (10,) - manager = groupcompress._LazyGroupContentManager(old_block, - get_compressor_settings=compressor_settings) + + manager = groupcompress._LazyGroupContentManager( + old_block, get_compressor_settings=compressor_settings + ) # It doesn't greedily evaluate 
compressor_settings self.assertIs(None, manager._compressor_settings) self.assertEqual((10,), manager._get_compressor_settings()) self.assertEqual((10,), manager._get_compressor_settings()) self.assertEqual((10,), manager._compressor_settings) # Only called 1 time - self.assertEqual(['called'], called) + self.assertEqual(["called"], called) def test__rebuild_handles_compressor_settings(self): - if not isinstance(groupcompress.GroupCompressor, - groupcompress.PyrexGroupCompressor): - raise tests.TestNotApplicable('pure-python compressor' - ' does not handle compressor_settings') + if not isinstance( + groupcompress.GroupCompressor, groupcompress.PyrexGroupCompressor + ): + raise tests.TestNotApplicable( + "pure-python compressor" " does not handle compressor_settings" + ) locations, old_block = self.make_block(self._texts) - manager = groupcompress._LazyGroupContentManager(old_block, - get_compressor_settings=lambda: {'max_bytes_to_index': 32}) + manager = groupcompress._LazyGroupContentManager( + old_block, get_compressor_settings=lambda: {"max_bytes_to_index": 32} + ) gc = manager._make_group_compressor() self.assertEqual(32, gc._delta_index._max_bytes_to_index) - self.add_key_to_manager((b'key3',), locations, old_block, manager) - self.add_key_to_manager((b'key4',), locations, old_block, manager) + self.add_key_to_manager((b"key3",), locations, old_block, manager) + self.add_key_to_manager((b"key4",), locations, old_block, manager) action, last_byte, total_bytes = manager._check_rebuild_action() - self.assertEqual('rebuild', action) + self.assertEqual("rebuild", action) manager._rebuild_block() new_block = manager._block self.assertIsNot(old_block, new_block) @@ -1163,14 +1243,14 @@ def test_check_is_well_utilized_all_keys(self): def test_check_is_well_utilized_mixed_keys(self): texts = {} - f1k1 = (b'f1', b'k1') - f1k2 = (b'f1', b'k2') - f2k1 = (b'f2', b'k1') - f2k2 = (b'f2', b'k2') - texts[f1k1] = self._texts[(b'key1',)] - texts[f1k2] = self._texts[(b'key2',)] - texts[f2k1] = self._texts[(b'key3',)] - texts[f2k2] = self._texts[(b'key4',)] + f1k1 = (b"f1", b"k1") + f1k2 = (b"f1", b"k2") + f2k1 = (b"f2", b"k1") + f2k2 = (b"f2", b"k2") + texts[f1k1] = self._texts[(b"key1",)] + texts[f1k2] = self._texts[(b"key2",)] + texts[f2k1] = self._texts[(b"key3",)] + texts[f2k2] = self._texts[(b"key4",)] block, manager = self.make_block_and_full_manager(texts) self.assertFalse(manager.check_is_well_utilized()) manager._full_enough_block_size = block._content_length @@ -1184,67 +1264,67 @@ def test_check_is_well_utilized_partial_use(self): locations, block = self.make_block(self._texts) manager = groupcompress._LazyGroupContentManager(block) manager._full_enough_block_size = block._content_length - self.add_key_to_manager((b'key1',), locations, block, manager) - self.add_key_to_manager((b'key2',), locations, block, manager) + self.add_key_to_manager((b"key1",), locations, block, manager) + self.add_key_to_manager((b"key2",), locations, block, manager) # Just using the content from key1 and 2 is not enough to be considered # 'complete' self.assertFalse(manager.check_is_well_utilized()) # However if we add key3, then we have enough, as we only require 75% # consumption - self.add_key_to_manager((b'key4',), locations, block, manager) + self.add_key_to_manager((b"key4",), locations, block, manager) self.assertTrue(manager.check_is_well_utilized()) class Test_GCBuildDetails(tests.TestCase): - def test_acts_like_tuple(self): # _GCBuildDetails inlines some of the data that used to be spread out # across a 
bunch of tuples - bd = groupcompress._GCBuildDetails((('parent1',), ('parent2',)), - ('INDEX', 10, 20, 0, 5)) + bd = groupcompress._GCBuildDetails( + (("parent1",), ("parent2",)), ("INDEX", 10, 20, 0, 5) + ) self.assertEqual(4, len(bd)) - self.assertEqual(('INDEX', 10, 20, 0, 5), bd[0]) + self.assertEqual(("INDEX", 10, 20, 0, 5), bd[0]) self.assertEqual(None, bd[1]) # Compression Parent is always None - self.assertEqual((('parent1',), ('parent2',)), bd[2]) - self.assertEqual(('group', None), bd[3]) # Record details + self.assertEqual((("parent1",), ("parent2",)), bd[2]) + self.assertEqual(("group", None), bd[3]) # Record details def test__repr__(self): - bd = groupcompress._GCBuildDetails((('parent1',), ('parent2',)), - ('INDEX', 10, 20, 0, 5)) - self.assertEqual("_GCBuildDetails(('INDEX', 10, 20, 0, 5)," - " (('parent1',), ('parent2',)))", - repr(bd)) + bd = groupcompress._GCBuildDetails( + (("parent1",), ("parent2",)), ("INDEX", 10, 20, 0, 5) + ) + self.assertEqual( + "_GCBuildDetails(('INDEX', 10, 20, 0, 5)," " (('parent1',), ('parent2',)))", + repr(bd), + ) class TestBase128Int(tests.TestCase): - def assertEqualEncode(self, bytes, val): self.assertEqual(bytes, groupcompress.encode_base128_int(val)) def assertEqualDecode(self, val, num_decode, bytes): - self.assertEqual((val, num_decode), - groupcompress.decode_base128_int(bytes)) + self.assertEqual((val, num_decode), groupcompress.decode_base128_int(bytes)) def test_encode(self): - self.assertEqualEncode(b'\x01', 1) - self.assertEqualEncode(b'\x02', 2) - self.assertEqualEncode(b'\x7f', 127) - self.assertEqualEncode(b'\x80\x01', 128) - self.assertEqualEncode(b'\xff\x01', 255) - self.assertEqualEncode(b'\x80\x02', 256) - self.assertEqualEncode(b'\xff\xff\xff\xff\x0f', 0xFFFFFFFF) + self.assertEqualEncode(b"\x01", 1) + self.assertEqualEncode(b"\x02", 2) + self.assertEqualEncode(b"\x7f", 127) + self.assertEqualEncode(b"\x80\x01", 128) + self.assertEqualEncode(b"\xff\x01", 255) + self.assertEqualEncode(b"\x80\x02", 256) + self.assertEqualEncode(b"\xff\xff\xff\xff\x0f", 0xFFFFFFFF) def test_decode(self): - self.assertEqualDecode(1, 1, b'\x01') - self.assertEqualDecode(2, 1, b'\x02') - self.assertEqualDecode(127, 1, b'\x7f') - self.assertEqualDecode(128, 2, b'\x80\x01') - self.assertEqualDecode(255, 2, b'\xff\x01') - self.assertEqualDecode(256, 2, b'\x80\x02') - self.assertEqualDecode(0xFFFFFFFF, 5, b'\xff\xff\xff\xff\x0f') + self.assertEqualDecode(1, 1, b"\x01") + self.assertEqualDecode(2, 1, b"\x02") + self.assertEqualDecode(127, 1, b"\x7f") + self.assertEqualDecode(128, 2, b"\x80\x01") + self.assertEqualDecode(255, 2, b"\xff\x01") + self.assertEqualDecode(256, 2, b"\x80\x02") + self.assertEqualDecode(0xFFFFFFFF, 5, b"\xff\xff\xff\xff\x0f") def test_decode_with_trailing_bytes(self): - self.assertEqualDecode(1, 1, b'\x01abcdef') - self.assertEqualDecode(127, 1, b'\x7f\x01') - self.assertEqualDecode(128, 2, b'\x80\x01abcdef') - self.assertEqualDecode(255, 2, b'\xff\x01\xff') + self.assertEqualDecode(1, 1, b"\x01abcdef") + self.assertEqualDecode(127, 1, b"\x7f\x01") + self.assertEqualDecode(128, 2, b"\x80\x01abcdef") + self.assertEqualDecode(255, 2, b"\xff\x01\xff") diff --git a/breezy/bzr/tests/test_hashcache.py b/breezy/bzr/tests/test_hashcache.py index 133d9cf4e1..2107c161c3 100644 --- a/breezy/bzr/tests/test_hashcache.py +++ b/breezy/bzr/tests/test_hashcache.py @@ -34,69 +34,69 @@ class TestHashCache(TestCaseInTempDir): def make_hashcache(self): # make a dummy bzr directory just to hold the cache - os.mkdir('.bzr') - hc = 
hashcache.HashCache('.', '.bzr/stat-cache') + os.mkdir(".bzr") + hc = hashcache.HashCache(".", ".bzr/stat-cache") return hc def reopen_hashcache(self): - hc = hashcache.HashCache('.', '.bzr/stat-cache') + hc = hashcache.HashCache(".", ".bzr/stat-cache") hc.read() return hc def test_hashcache_initial_miss(self): """Get correct hash from an empty hashcache.""" hc = self.make_hashcache() - self.build_tree_contents([('foo', b'hello')]) - self.assertEqual(hc.get_sha1('foo'), - b'aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d') + self.build_tree_contents([("foo", b"hello")]) + self.assertEqual( + hc.get_sha1("foo"), b"aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d" + ) self.assertEqual(hc.miss_count, 1) self.assertEqual(hc.hit_count, 0) def test_hashcache_new_file(self): hc = self.make_hashcache() - self.build_tree_contents([('foo', b'goodbye')]) + self.build_tree_contents([("foo", b"goodbye")]) # now read without pausing; it may not be possible to cache it as its # so new - self.assertEqual(hc.get_sha1('foo'), sha1(b'goodbye')) + self.assertEqual(hc.get_sha1("foo"), sha1(b"goodbye")) def test_hashcache_nonexistent_file(self): hc = self.make_hashcache() - self.assertEqual(hc.get_sha1('no-name-yet'), None) + self.assertEqual(hc.get_sha1("no-name-yet"), None) def test_hashcache_replaced_file(self): hc = self.make_hashcache() - self.build_tree_contents([('foo', b'goodbye')]) - self.assertEqual(hc.get_sha1('foo'), sha1(b'goodbye')) - os.remove('foo') - self.assertEqual(hc.get_sha1('foo'), None) - self.build_tree_contents([('foo', b'new content')]) - self.assertEqual(hc.get_sha1('foo'), sha1(b'new content')) + self.build_tree_contents([("foo", b"goodbye")]) + self.assertEqual(hc.get_sha1("foo"), sha1(b"goodbye")) + os.remove("foo") + self.assertEqual(hc.get_sha1("foo"), None) + self.build_tree_contents([("foo", b"new content")]) + self.assertEqual(hc.get_sha1("foo"), sha1(b"new content")) def test_hashcache_not_file(self): hc = self.make_hashcache() - self.build_tree(['subdir/']) - self.assertEqual(hc.get_sha1('subdir'), None) + self.build_tree(["subdir/"]) + self.assertEqual(hc.get_sha1("subdir"), None) def test_hashcache_load(self): hc = self.make_hashcache() - self.build_tree_contents([('foo', b'contents')]) + self.build_tree_contents([("foo", b"contents")]) pause() - self.assertEqual(hc.get_sha1('foo'), sha1(b'contents')) + self.assertEqual(hc.get_sha1("foo"), sha1(b"contents")) hc.write() hc = self.reopen_hashcache() - self.assertEqual(hc.get_sha1('foo'), sha1(b'contents')) + self.assertEqual(hc.get_sha1("foo"), sha1(b"contents")) self.assertEqual(hc.hit_count, 1) def test_hammer_hashcache(self): hc = self.make_hashcache() for i in range(10000): - with open('foo', 'wb') as f: - last_content = b'%08x' % i + with open("foo", "wb") as f: + last_content = b"%08x" % i f.write(last_content) last_sha1 = sha1(last_content) - self.log("iteration %d: %r -> %r", - i, last_content, last_sha1) - got_sha1 = hc.get_sha1('foo') + self.log("iteration %d: %r -> %r", i, last_content, last_sha1) + got_sha1 = hc.get_sha1("foo") self.assertEqual(got_sha1, last_sha1) hc.write() hc = self.reopen_hashcache() @@ -105,9 +105,9 @@ def test_hashcache_raise(self): """Check that hashcache can raise BzrError.""" self.requireFeature(OsFifoFeature) hc = self.make_hashcache() - os.mkfifo('a') + os.mkfifo("a") # It's possible that the system supports fifos but the filesystem # can't. In that case we should skip at this point. But in fact # such combinations don't usually occur for the filesystem where # people test bzr. 
- self.assertRaises(OSError, hc.get_sha1, 'a') + self.assertRaises(OSError, hc.get_sha1, "a") diff --git a/breezy/bzr/tests/test_index.py b/breezy/bzr/tests/test_index.py index 8b616c5326..3b058c5fa6 100644 --- a/breezy/bzr/tests/test_index.py +++ b/breezy/bzr/tests/test_index.py @@ -21,47 +21,40 @@ class ErrorTests(tests.TestCase): - def test_bad_index_format_signature(self): error = _mod_index.BadIndexFormatSignature("foo", "bar") - self.assertEqual("foo is not an index of type bar.", - str(error)) + self.assertEqual("foo is not an index of type bar.", str(error)) def test_bad_index_data(self): error = _mod_index.BadIndexData("foo") - self.assertEqual("Error in data for index foo.", - str(error)) + self.assertEqual("Error in data for index foo.", str(error)) def test_bad_index_duplicate_key(self): error = _mod_index.BadIndexDuplicateKey("foo", "bar") - self.assertEqual("The key 'foo' is already in index 'bar'.", - str(error)) + self.assertEqual("The key 'foo' is already in index 'bar'.", str(error)) def test_bad_index_key(self): error = _mod_index.BadIndexKey("foo") - self.assertEqual("The key 'foo' is not a valid key.", - str(error)) + self.assertEqual("The key 'foo' is not a valid key.", str(error)) def test_bad_index_options(self): error = _mod_index.BadIndexOptions("foo") - self.assertEqual("Could not parse options for index foo.", - str(error)) + self.assertEqual("Could not parse options for index foo.", str(error)) def test_bad_index_value(self): error = _mod_index.BadIndexValue("foo") - self.assertEqual("The value 'foo' is not a valid value.", - str(error)) + self.assertEqual("The value 'foo' is not a valid value.", str(error)) class TestGraphIndexBuilder(tests.TestCaseWithMemoryTransport): - def test_build_index_empty(self): builder = _mod_index.GraphIndexBuilder() stream = builder.finish() contents = stream.read() self.assertEqual( b"Bazaar Graph Index 1\nnode_ref_lists=0\nkey_elements=1\nlen=0\n\n", - contents) + contents, + ) def test_build_index_empty_two_element_keys(self): builder = _mod_index.GraphIndexBuilder(key_elements=2) @@ -69,7 +62,8 @@ def test_build_index_empty_two_element_keys(self): contents = stream.read() self.assertEqual( b"Bazaar Graph Index 1\nnode_ref_lists=0\nkey_elements=2\nlen=0\n\n", - contents) + contents, + ) def test_build_index_one_reference_list_empty(self): builder = _mod_index.GraphIndexBuilder(reference_lists=1) @@ -77,7 +71,8 @@ def test_build_index_one_reference_list_empty(self): contents = stream.read() self.assertEqual( b"Bazaar Graph Index 1\nnode_ref_lists=1\nkey_elements=1\nlen=0\n\n", - contents) + contents, + ) def test_build_index_two_reference_list_empty(self): builder = _mod_index.GraphIndexBuilder(reference_lists=2) @@ -85,46 +80,55 @@ def test_build_index_two_reference_list_empty(self): contents = stream.read() self.assertEqual( b"Bazaar Graph Index 1\nnode_ref_lists=2\nkey_elements=1\nlen=0\n\n", - contents) + contents, + ) def test_build_index_one_node_no_refs(self): builder = _mod_index.GraphIndexBuilder() - builder.add_node((b'akey', ), b'data') + builder.add_node((b"akey",), b"data") stream = builder.finish() contents = stream.read() self.assertEqual( b"Bazaar Graph Index 1\nnode_ref_lists=0\nkey_elements=1\nlen=1\n" - b"akey\x00\x00\x00data\n\n", contents) + b"akey\x00\x00\x00data\n\n", + contents, + ) def test_build_index_one_node_no_refs_accepts_empty_reflist(self): builder = _mod_index.GraphIndexBuilder() - builder.add_node((b'akey', ), b'data', ()) + builder.add_node((b"akey",), b"data", ()) stream = builder.finish() 
contents = stream.read() self.assertEqual( b"Bazaar Graph Index 1\nnode_ref_lists=0\nkey_elements=1\nlen=1\n" - b"akey\x00\x00\x00data\n\n", contents) + b"akey\x00\x00\x00data\n\n", + contents, + ) def test_build_index_one_node_2_element_keys(self): # multipart keys are separated by \x00 - because they are fixed length, # not variable this does not cause any issues, and seems clearer to the # author. builder = _mod_index.GraphIndexBuilder(key_elements=2) - builder.add_node((b'akey', b'secondpart'), b'data') + builder.add_node((b"akey", b"secondpart"), b"data") stream = builder.finish() contents = stream.read() self.assertEqual( b"Bazaar Graph Index 1\nnode_ref_lists=0\nkey_elements=2\nlen=1\n" - b"akey\x00secondpart\x00\x00\x00data\n\n", contents) + b"akey\x00secondpart\x00\x00\x00data\n\n", + contents, + ) def test_add_node_empty_value(self): builder = _mod_index.GraphIndexBuilder() - builder.add_node((b'akey', ), b'') + builder.add_node((b"akey",), b"") stream = builder.finish() contents = stream.read() self.assertEqual( b"Bazaar Graph Index 1\nnode_ref_lists=0\nkey_elements=1\nlen=1\n" - b"akey\x00\x00\x00\n\n", contents) + b"akey\x00\x00\x00\n\n", + contents, + ) def test_build_index_nodes_sorted(self): # the highest sorted node comes first. @@ -132,9 +136,9 @@ def test_build_index_nodes_sorted(self): # use three to have a good chance of glitching dictionary hash # lookups etc. Insert in randomish order that is not correct # and not the reverse of the correct order. - builder.add_node((b'2002', ), b'data') - builder.add_node((b'2000', ), b'data') - builder.add_node((b'2001', ), b'data') + builder.add_node((b"2002",), b"data") + builder.add_node((b"2000",), b"data") + builder.add_node((b"2001",), b"data") stream = builder.finish() contents = stream.read() self.assertEqual( @@ -142,7 +146,9 @@ def test_build_index_nodes_sorted(self): b"2000\x00\x00\x00data\n" b"2001\x00\x00\x00data\n" b"2002\x00\x00\x00data\n" - b"\n", contents) + b"\n", + contents, + ) def test_build_index_2_element_key_nodes_sorted(self): # multiple element keys are sorted first-key, second-key. @@ -150,15 +156,15 @@ def test_build_index_2_element_key_nodes_sorted(self): # use three values of each key element, to have a good chance of # glitching dictionary hash lookups etc. Insert in randomish order that # is not correct and not the reverse of the correct order. 
- builder.add_node((b'2002', b'2002'), b'data') - builder.add_node((b'2002', b'2000'), b'data') - builder.add_node((b'2002', b'2001'), b'data') - builder.add_node((b'2000', b'2002'), b'data') - builder.add_node((b'2000', b'2000'), b'data') - builder.add_node((b'2000', b'2001'), b'data') - builder.add_node((b'2001', b'2002'), b'data') - builder.add_node((b'2001', b'2000'), b'data') - builder.add_node((b'2001', b'2001'), b'data') + builder.add_node((b"2002", b"2002"), b"data") + builder.add_node((b"2002", b"2000"), b"data") + builder.add_node((b"2002", b"2001"), b"data") + builder.add_node((b"2000", b"2002"), b"data") + builder.add_node((b"2000", b"2000"), b"data") + builder.add_node((b"2000", b"2001"), b"data") + builder.add_node((b"2001", b"2002"), b"data") + builder.add_node((b"2001", b"2000"), b"data") + builder.add_node((b"2001", b"2001"), b"data") stream = builder.finish() contents = stream.read() self.assertEqual( @@ -172,38 +178,45 @@ def test_build_index_2_element_key_nodes_sorted(self): b"2002\x002000\x00\x00\x00data\n" b"2002\x002001\x00\x00\x00data\n" b"2002\x002002\x00\x00\x00data\n" - b"\n", contents) + b"\n", + contents, + ) def test_build_index_reference_lists_are_included_one(self): builder = _mod_index.GraphIndexBuilder(reference_lists=1) - builder.add_node((b'key', ), b'data', ([], )) + builder.add_node((b"key",), b"data", ([],)) stream = builder.finish() contents = stream.read() self.assertEqual( b"Bazaar Graph Index 1\nnode_ref_lists=1\nkey_elements=1\nlen=1\n" b"key\x00\x00\x00data\n" - b"\n", contents) + b"\n", + contents, + ) def test_build_index_reference_lists_with_2_element_keys(self): - builder = _mod_index.GraphIndexBuilder( - reference_lists=1, key_elements=2) - builder.add_node((b'key', b'key2'), b'data', ([], )) + builder = _mod_index.GraphIndexBuilder(reference_lists=1, key_elements=2) + builder.add_node((b"key", b"key2"), b"data", ([],)) stream = builder.finish() contents = stream.read() self.assertEqual( b"Bazaar Graph Index 1\nnode_ref_lists=1\nkey_elements=2\nlen=1\n" b"key\x00key2\x00\x00\x00data\n" - b"\n", contents) + b"\n", + contents, + ) def test_build_index_reference_lists_are_included_two(self): builder = _mod_index.GraphIndexBuilder(reference_lists=2) - builder.add_node((b'key', ), b'data', ([], [])) + builder.add_node((b"key",), b"data", ([], [])) stream = builder.finish() contents = stream.read() self.assertEqual( b"Bazaar Graph Index 1\nnode_ref_lists=2\nkey_elements=1\nlen=1\n" b"key\x00\x00\t\x00data\n" - b"\n", contents) + b"\n", + contents, + ) def test_clear_cache(self): builder = _mod_index.GraphIndexBuilder(reference_lists=2) @@ -212,22 +225,23 @@ def test_clear_cache(self): def test_node_references_are_byte_offsets(self): builder = _mod_index.GraphIndexBuilder(reference_lists=1) - builder.add_node((b'reference', ), b'data', ([], )) - builder.add_node((b'key', ), b'data', ([(b'reference', )], )) + builder.add_node((b"reference",), b"data", ([],)) + builder.add_node((b"key",), b"data", ([(b"reference",)],)) stream = builder.finish() contents = stream.read() self.assertEqual( b"Bazaar Graph Index 1\nnode_ref_lists=1\nkey_elements=1\nlen=2\n" b"key\x00\x0072\x00data\n" b"reference\x00\x00\x00data\n" - b"\n", contents) + b"\n", + contents, + ) def test_node_references_are_cr_delimited(self): builder = _mod_index.GraphIndexBuilder(reference_lists=1) - builder.add_node((b'reference', ), b'data', ([], )) - builder.add_node((b'reference2', ), b'data', ([], )) - builder.add_node((b'key', ), b'data', - ([(b'reference', ), (b'reference2', )], )) + 
builder.add_node((b"reference",), b"data", ([],)) + builder.add_node((b"reference2",), b"data", ([],)) + builder.add_node((b"key",), b"data", ([(b"reference",), (b"reference2",)],)) stream = builder.finish() contents = stream.read() self.assertEqual( @@ -235,25 +249,27 @@ def test_node_references_are_cr_delimited(self): b"key\x00\x00077\r094\x00data\n" b"reference\x00\x00\x00data\n" b"reference2\x00\x00\x00data\n" - b"\n", contents) + b"\n", + contents, + ) def test_multiple_reference_lists_are_tab_delimited(self): builder = _mod_index.GraphIndexBuilder(reference_lists=2) - builder.add_node((b'keference', ), b'data', ([], [])) - builder.add_node((b'rey', ), b'data', - ([(b'keference', )], [(b'keference', )])) + builder.add_node((b"keference",), b"data", ([], [])) + builder.add_node((b"rey",), b"data", ([(b"keference",)], [(b"keference",)])) stream = builder.finish() contents = stream.read() self.assertEqual( b"Bazaar Graph Index 1\nnode_ref_lists=2\nkey_elements=1\nlen=2\n" b"keference\x00\x00\t\x00data\n" b"rey\x00\x0059\t59\x00data\n" - b"\n", contents) + b"\n", + contents, + ) def test_add_node_referencing_missing_key_makes_absent(self): builder = _mod_index.GraphIndexBuilder(reference_lists=1) - builder.add_node((b'rey', ), b'data', - ([(b'beference', ), (b'aeference2', )], )) + builder.add_node((b"rey",), b"data", ([(b"beference",), (b"aeference2",)],)) stream = builder.finish() contents = stream.read() self.assertEqual( @@ -261,13 +277,15 @@ def test_add_node_referencing_missing_key_makes_absent(self): b"aeference2\x00a\x00\x00\n" b"beference\x00a\x00\x00\n" b"rey\x00\x00074\r059\x00data\n" - b"\n", contents) + b"\n", + contents, + ) def test_node_references_three_digits(self): # test the node digit expands as needed. builder = _mod_index.GraphIndexBuilder(reference_lists=1) - references = [((b"%d" % val), ) for val in range(8, -1, -1)] - builder.add_node((b'2-key', ), b'', (references, )) + references = [((b"%d" % val),) for val in range(8, -1, -1)] + builder.add_node((b"2-key",), b"", (references,)) stream = builder.finish() contents = stream.read() self.assertEqualDiff( @@ -282,13 +300,15 @@ def test_node_references_three_digits(self): b"6\x00a\x00\x00\n" b"7\x00a\x00\x00\n" b"8\x00a\x00\x00\n" - b"\n", contents) + b"\n", + contents, + ) def test_absent_has_no_reference_overhead(self): # the offsets after an absent record should be correct when there are # >1 reference lists. 
builder = _mod_index.GraphIndexBuilder(reference_lists=2) - builder.add_node((b'parent', ), b'', ([(b'aail', ), (b'zther', )], [])) + builder.add_node((b"parent",), b"", ([(b"aail",), (b"zther",)], [])) stream = builder.finish() contents = stream.read() self.assertEqual( @@ -296,105 +316,160 @@ def test_absent_has_no_reference_overhead(self): b"aail\x00a\x00\x00\n" b"parent\x00\x0059\r84\t\x00\n" b"zther\x00a\x00\x00\n" - b"\n", contents) + b"\n", + contents, + ) def test_add_node_bad_key(self): builder = _mod_index.GraphIndexBuilder() - for bad_char in bytearray(b'\t\n\x0b\x0c\r\x00 '): - self.assertRaises(_mod_index.BadIndexKey, builder.add_node, - (b'a%skey' % bytes([bad_char]), ), b'data') - self.assertRaises(_mod_index.BadIndexKey, builder.add_node, - (), b'data') - self.assertRaises(_mod_index.BadIndexKey, builder.add_node, - b'not-a-tuple', b'data') + for bad_char in bytearray(b"\t\n\x0b\x0c\r\x00 "): + self.assertRaises( + _mod_index.BadIndexKey, + builder.add_node, + (b"a%skey" % bytes([bad_char]),), + b"data", + ) + self.assertRaises(_mod_index.BadIndexKey, builder.add_node, (), b"data") + self.assertRaises( + _mod_index.BadIndexKey, builder.add_node, b"not-a-tuple", b"data" + ) # not enough length - self.assertRaises(_mod_index.BadIndexKey, builder.add_node, - (), b'data') + self.assertRaises(_mod_index.BadIndexKey, builder.add_node, (), b"data") # too long - self.assertRaises(_mod_index.BadIndexKey, builder.add_node, - (b'primary', b'secondary'), b'data') + self.assertRaises( + _mod_index.BadIndexKey, + builder.add_node, + (b"primary", b"secondary"), + b"data", + ) # secondary key elements get checked too: builder = _mod_index.GraphIndexBuilder(key_elements=2) - for bad_char in bytearray(b'\t\n\x0b\x0c\r\x00 '): - self.assertRaises(_mod_index.BadIndexKey, builder.add_node, - (b'prefix', b'a%skey' % bytes([bad_char])), b'data') + for bad_char in bytearray(b"\t\n\x0b\x0c\r\x00 "): + self.assertRaises( + _mod_index.BadIndexKey, + builder.add_node, + (b"prefix", b"a%skey" % bytes([bad_char])), + b"data", + ) def test_add_node_bad_data(self): builder = _mod_index.GraphIndexBuilder() - self.assertRaises(_mod_index.BadIndexValue, builder.add_node, (b'akey', ), - b'data\naa') - self.assertRaises(_mod_index.BadIndexValue, builder.add_node, (b'akey', ), - b'data\x00aa') + self.assertRaises( + _mod_index.BadIndexValue, builder.add_node, (b"akey",), b"data\naa" + ) + self.assertRaises( + _mod_index.BadIndexValue, builder.add_node, (b"akey",), b"data\x00aa" + ) def test_add_node_bad_mismatched_ref_lists_length(self): builder = _mod_index.GraphIndexBuilder() - self.assertRaises(_mod_index.BadIndexValue, builder.add_node, (b'akey', ), - b'data aa', ([], )) + self.assertRaises( + _mod_index.BadIndexValue, builder.add_node, (b"akey",), b"data aa", ([],) + ) builder = _mod_index.GraphIndexBuilder(reference_lists=1) - self.assertRaises(_mod_index.BadIndexValue, builder.add_node, (b'akey', ), - b'data aa') - self.assertRaises(_mod_index.BadIndexValue, builder.add_node, (b'akey', ), - b'data aa', (), ) - self.assertRaises(_mod_index.BadIndexValue, builder.add_node, (b'akey', ), - b'data aa', ([], [])) + self.assertRaises( + _mod_index.BadIndexValue, builder.add_node, (b"akey",), b"data aa" + ) + self.assertRaises( + _mod_index.BadIndexValue, + builder.add_node, + (b"akey",), + b"data aa", + (), + ) + self.assertRaises( + _mod_index.BadIndexValue, builder.add_node, (b"akey",), b"data aa", ([], []) + ) builder = _mod_index.GraphIndexBuilder(reference_lists=2) - 
self.assertRaises(_mod_index.BadIndexValue, builder.add_node, (b'akey', ), - b'data aa') - self.assertRaises(_mod_index.BadIndexValue, builder.add_node, (b'akey', ), - b'data aa', ([], )) - self.assertRaises(_mod_index.BadIndexValue, builder.add_node, (b'akey', ), - b'data aa', ([], [], [])) + self.assertRaises( + _mod_index.BadIndexValue, builder.add_node, (b"akey",), b"data aa" + ) + self.assertRaises( + _mod_index.BadIndexValue, builder.add_node, (b"akey",), b"data aa", ([],) + ) + self.assertRaises( + _mod_index.BadIndexValue, + builder.add_node, + (b"akey",), + b"data aa", + ([], [], []), + ) def test_add_node_bad_key_in_reference_lists(self): # first list, first key - trivial builder = _mod_index.GraphIndexBuilder(reference_lists=1) - self.assertRaises(_mod_index.BadIndexKey, builder.add_node, (b'akey', ), - b'data aa', ([(b'a key', )], )) + self.assertRaises( + _mod_index.BadIndexKey, + builder.add_node, + (b"akey",), + b"data aa", + ([(b"a key",)],), + ) # references keys must be tuples too - self.assertRaises(_mod_index.BadIndexKey, builder.add_node, (b'akey', ), - b'data aa', (['not-a-tuple'], )) + self.assertRaises( + _mod_index.BadIndexKey, + builder.add_node, + (b"akey",), + b"data aa", + (["not-a-tuple"],), + ) # not enough length - self.assertRaises(_mod_index.BadIndexKey, builder.add_node, (b'akey', ), - b'data aa', ([()], )) + self.assertRaises( + _mod_index.BadIndexKey, builder.add_node, (b"akey",), b"data aa", ([()],) + ) # too long - self.assertRaises(_mod_index.BadIndexKey, builder.add_node, (b'akey', ), - b'data aa', ([(b'primary', b'secondary')], )) + self.assertRaises( + _mod_index.BadIndexKey, + builder.add_node, + (b"akey",), + b"data aa", + ([(b"primary", b"secondary")],), + ) # need to check more than the first key in the list - self.assertRaises(_mod_index.BadIndexKey, builder.add_node, (b'akey', ), - b'data aa', ([(b'agoodkey', ), (b'that is a bad key', )], )) + self.assertRaises( + _mod_index.BadIndexKey, + builder.add_node, + (b"akey",), + b"data aa", + ([(b"agoodkey",), (b"that is a bad key",)],), + ) # and if there is more than one list it should be getting checked # too builder = _mod_index.GraphIndexBuilder(reference_lists=2) - self.assertRaises(_mod_index.BadIndexKey, builder.add_node, (b'akey', ), - b'data aa', ([], ['a bad key'])) + self.assertRaises( + _mod_index.BadIndexKey, + builder.add_node, + (b"akey",), + b"data aa", + ([], ["a bad key"]), + ) def test_add_duplicate_key(self): builder = _mod_index.GraphIndexBuilder() - builder.add_node((b'key', ), b'data') - self.assertRaises(_mod_index.BadIndexDuplicateKey, - builder.add_node, (b'key', ), b'data') + builder.add_node((b"key",), b"data") + self.assertRaises( + _mod_index.BadIndexDuplicateKey, builder.add_node, (b"key",), b"data" + ) def test_add_duplicate_key_2_elements(self): builder = _mod_index.GraphIndexBuilder(key_elements=2) - builder.add_node((b'key', b'key'), b'data') - self.assertRaises(_mod_index.BadIndexDuplicateKey, builder.add_node, - (b'key', b'key'), b'data') + builder.add_node((b"key", b"key"), b"data") + self.assertRaises( + _mod_index.BadIndexDuplicateKey, builder.add_node, (b"key", b"key"), b"data" + ) def test_add_key_after_referencing_key(self): builder = _mod_index.GraphIndexBuilder(reference_lists=1) - builder.add_node((b'key', ), b'data', ([(b'reference', )], )) - builder.add_node((b'reference', ), b'data', ([],)) + builder.add_node((b"key",), b"data", ([(b"reference",)],)) + builder.add_node((b"reference",), b"data", ([],)) def 
test_add_key_after_referencing_key_2_elements(self): - builder = _mod_index.GraphIndexBuilder( - reference_lists=1, key_elements=2) - builder.add_node((b'k', b'ey'), b'data', - ([(b'reference', b'tokey')], )) - builder.add_node((b'reference', b'tokey'), b'data', ([],)) + builder = _mod_index.GraphIndexBuilder(reference_lists=1, key_elements=2) + builder.add_node((b"k", b"ey"), b"data", ([(b"reference", b"tokey")],)) + builder.add_node((b"reference", b"tokey"), b"data", ([],)) def test_set_optimize(self): - builder = _mod_index.GraphIndexBuilder( - reference_lists=1, key_elements=2) + builder = _mod_index.GraphIndexBuilder(reference_lists=1, key_elements=2) builder.set_optimize(for_size=True) self.assertTrue(builder._optimize_for_size) builder.set_optimize(for_size=False) @@ -402,47 +477,42 @@ def test_set_optimize(self): class TestGraphIndex(tests.TestCaseWithMemoryTransport): - def make_key(self, number): - return ((b'%d' % number) + b'X' * 100,) + return ((b"%d" % number) + b"X" * 100,) def make_value(self, number): - return (b'%d' % number) + b'Y' * 100 + return (b"%d" % number) + b"Y" * 100 def make_nodes(self, count=64): # generate a big enough index that we only read some of it on a typical # bisection lookup. nodes = [] for counter in range(count): - nodes.append( - (self.make_key(counter), self.make_value(counter), ())) + nodes.append((self.make_key(counter), self.make_value(counter), ())) return nodes def make_index(self, ref_lists=0, key_elements=1, nodes=None): if nodes is None: nodes = [] - builder = _mod_index.GraphIndexBuilder( - ref_lists, key_elements=key_elements) + builder = _mod_index.GraphIndexBuilder(ref_lists, key_elements=key_elements) for key, value, references in nodes: builder.add_node(key, value, references) stream = builder.finish() - trans = transport.get_transport_from_url('trace+' + self.get_url()) - size = trans.put_file('index', stream) - return _mod_index.GraphIndex(trans, 'index', size) + trans = transport.get_transport_from_url("trace+" + self.get_url()) + size = trans.put_file("index", stream) + return _mod_index.GraphIndex(trans, "index", size) - def make_index_with_offset(self, ref_lists=0, key_elements=1, nodes=None, - offset=0): + def make_index_with_offset(self, ref_lists=0, key_elements=1, nodes=None, offset=0): if nodes is None: nodes = [] - builder = _mod_index.GraphIndexBuilder( - ref_lists, key_elements=key_elements) + builder = _mod_index.GraphIndexBuilder(ref_lists, key_elements=key_elements) for key, value, references in nodes: builder.add_node(key, value, references) content = builder.finish().read() size = len(content) trans = self.get_transport() - trans.put_bytes('index', (b' ' * offset) + content) - return _mod_index.GraphIndex(trans, 'index', size, offset=offset) + trans.put_bytes("index", (b" " * offset) + content) + return _mod_index.GraphIndex(trans, "index", size, offset=offset) def test_clear_cache(self): index = self.make_index() @@ -452,8 +522,8 @@ def test_clear_cache(self): def test_open_bad_index_no_error(self): trans = self.get_transport() - trans.put_bytes('name', b"not an index\n") - _mod_index.GraphIndex(trans, 'name', 13) + trans.put_bytes("name", b"not an index\n") + _mod_index.GraphIndex(trans, "name", 13) def test_with_offset(self): nodes = self.make_nodes(200) @@ -486,10 +556,12 @@ def test_key_count_buffers(self): del index._transport._activity[:] self.assertEqual(2, index.key_count()) # We should have requested reading the header bytes - self.assertEqual([ - ('readv', 'index', [(0, 200)], True, index._size), + 
self.assertEqual( + [ + ("readv", "index", [(0, 200)], True, index._size), ], - index._transport._activity) + index._transport._activity, + ) # And that should have been enough to trigger reading the whole index # with buffering self.assertIsNot(None, index._nodes) @@ -500,18 +572,18 @@ def test_lookup_key_via_location_buffers(self): del index._transport._activity[:] # do a _lookup_keys_via_location call for the middle of the file, which # is what bisection uses. - result = index._lookup_keys_via_location( - [(index._size // 2, (b'missing', ))]) + result = index._lookup_keys_via_location([(index._size // 2, (b"missing",))]) # this should have asked for a readv request, with adjust_for_latency, # and two regions: the header, and half-way into the file. - self.assertEqual([ - ('readv', 'index', [(30, 30), (0, 200)], True, 60), + self.assertEqual( + [ + ("readv", "index", [(30, 30), (0, 200)], True, 60), ], - index._transport._activity) + index._transport._activity, + ) # and the result should be that the key cannot be present, because this # is a trivial index. - self.assertEqual([((index._size // 2, (b'missing', )), False)], - result) + self.assertEqual([((index._size // 2, (b"missing",)), False)], result) # And this should have caused the file to be fully buffered self.assertIsNot(None, index._nodes) self.assertEqual([], index._parsed_byte_map) @@ -529,43 +601,42 @@ def test_first_lookup_key_via_location(self): # do a _lookup_keys_via_location call for the middle of the file, which # is what bisection uses. start_lookup = index._size // 2 - result = index._lookup_keys_via_location( - [(start_lookup, (b'40missing', ))]) + result = index._lookup_keys_via_location([(start_lookup, (b"40missing",))]) # this should have asked for a readv request, with adjust_for_latency, # and two regions: the header, and half-way into the file. - self.assertEqual([ - ('readv', 'index', - [(start_lookup, 800), (0, 200)], True, index._size), + self.assertEqual( + [ + ("readv", "index", [(start_lookup, 800), (0, 200)], True, index._size), ], - index._transport._activity) + index._transport._activity, + ) # and the result should be that the key cannot be present, because this # is a trivial index. - self.assertEqual([((start_lookup, (b'40missing', )), False)], - result) + self.assertEqual([((start_lookup, (b"40missing",)), False)], result) # And this should not have caused the file to be fully buffered self.assertIs(None, index._nodes) # And the regions of the file that have been parsed should be in the # parsed_byte_map and the parsed_key_map self.assertEqual([(0, 4008), (5046, 8996)], index._parsed_byte_map) - self.assertEqual([((), self.make_key(26)), - (self.make_key(31), self.make_key(48))], - index._parsed_key_map) + self.assertEqual( + [((), self.make_key(26)), (self.make_key(31), self.make_key(48))], + index._parsed_key_map, + ) def test_parsing_non_adjacent_data_trims(self): index = self.make_index(nodes=self.make_nodes(64)) - result = index._lookup_keys_via_location( - [(index._size // 2, (b'40', ))]) + result = index._lookup_keys_via_location([(index._size // 2, (b"40",))]) # and the result should be that the key cannot be present, because key is # in the middle of the observed data from a 4K read - the smallest transport # will do today with this api. - self.assertEqual([((index._size // 2, (b'40', )), False)], - result) + self.assertEqual([((index._size // 2, (b"40",)), False)], result) # and we should have a parse map that includes the header and the # region that was parsed after trimming. 
self.assertEqual([(0, 4008), (5046, 8996)], index._parsed_byte_map) - self.assertEqual([((), self.make_key(26)), - (self.make_key(31), self.make_key(48))], - index._parsed_key_map) + self.assertEqual( + [((), self.make_key(26)), (self.make_key(31), self.make_key(48))], + index._parsed_key_map, + ) def test_parsing_data_handles_parsed_contained_regions(self): # the following patten creates a parsed region that is wholly within a @@ -582,24 +653,31 @@ def test_parsing_data_handles_parsed_contained_regions(self): # we except both to be found, and the parsed byte map to include the # locations of both keys. index = self.make_index(nodes=self.make_nodes(128)) - result = index._lookup_keys_via_location( - [(index._size // 2, (b'40', ))]) + result = index._lookup_keys_via_location([(index._size // 2, (b"40",))]) # and we should have a parse map that includes the header and the # region that was parsed after trimming. self.assertEqual([(0, 4045), (11759, 15707)], index._parsed_byte_map) - self.assertEqual([((), self.make_key(116)), - (self.make_key(35), self.make_key(51))], - index._parsed_key_map) + self.assertEqual( + [((), self.make_key(116)), (self.make_key(35), self.make_key(51))], + index._parsed_key_map, + ) # now ask for two keys, right before and after the parsed region result = index._lookup_keys_via_location( - [(11450, self.make_key(34)), (15707, self.make_key(52))]) - self.assertEqual([ - ((11450, self.make_key(34)), - (index, self.make_key(34), self.make_value(34))), - ((15707, self.make_key(52)), - (index, self.make_key(52), self.make_value(52))), + [(11450, self.make_key(34)), (15707, self.make_key(52))] + ) + self.assertEqual( + [ + ( + (11450, self.make_key(34)), + (index, self.make_key(34), self.make_value(34)), + ), + ( + (15707, self.make_key(52)), + (index, self.make_key(52), self.make_value(52)), + ), ], - result) + result, + ) self.assertEqual([(0, 4045), (9889, 17993)], index._parsed_byte_map) def test_lookup_missing_key_answers_without_io_when_map_permits(self): @@ -607,23 +685,21 @@ def test_lookup_missing_key_answers_without_io_when_map_permits(self): # bisection lookup. index = self.make_index(nodes=self.make_nodes(64)) # lookup the keys in the middle of the file - result = index._lookup_keys_via_location( - [(index._size // 2, (b'40', ))]) + result = index._lookup_keys_via_location([(index._size // 2, (b"40",))]) # check the parse map, this determines the test validity self.assertEqual([(0, 4008), (5046, 8996)], index._parsed_byte_map) - self.assertEqual([((), self.make_key(26)), - (self.make_key(31), self.make_key(48))], - index._parsed_key_map) + self.assertEqual( + [((), self.make_key(26)), (self.make_key(31), self.make_key(48))], + index._parsed_key_map, + ) # reset the transport log del index._transport._activity[:] # now looking up a key in the portion of the file already parsed should # not create a new transport request, and should return False (cannot # be in the index) - even when the byte location we ask for is outside # the parsed region - result = index._lookup_keys_via_location( - [(4000, (b'40', ))]) - self.assertEqual([((4000, (b'40', )), False)], - result) + result = index._lookup_keys_via_location([(4000, (b"40",))]) + self.assertEqual([((4000, (b"40",)), False)], result) self.assertEqual([], index._transport._activity) def test_lookup_present_key_answers_without_io_when_map_permits(self): @@ -631,13 +707,13 @@ def test_lookup_present_key_answers_without_io_when_map_permits(self): # bisection lookup. 
index = self.make_index(nodes=self.make_nodes(64)) # lookup the keys in the middle of the file - result = index._lookup_keys_via_location( - [(index._size // 2, (b'40', ))]) + result = index._lookup_keys_via_location([(index._size // 2, (b"40",))]) # check the parse map, this determines the test validity self.assertEqual([(0, 4008), (5046, 8996)], index._parsed_byte_map) - self.assertEqual([((), self.make_key(26)), - (self.make_key(31), self.make_key(48))], - index._parsed_key_map) + self.assertEqual( + [((), self.make_key(26)), (self.make_key(31), self.make_key(48))], + index._parsed_key_map, + ) # reset the transport log del index._transport._activity[:] # now looking up a key in the portion of the file already parsed should @@ -647,9 +723,14 @@ def test_lookup_present_key_answers_without_io_when_map_permits(self): # result = index._lookup_keys_via_location([(4000, self.make_key(40))]) self.assertEqual( - [((4000, self.make_key(40)), - (index, self.make_key(40), self.make_value(40)))], - result) + [ + ( + (4000, self.make_key(40)), + (index, self.make_key(40), self.make_value(40)), + ) + ], + result, + ) self.assertEqual([], index._transport._activity) def test_lookup_key_below_probed_area(self): @@ -658,15 +739,14 @@ def test_lookup_key_below_probed_area(self): index = self.make_index(nodes=self.make_nodes(64)) # ask for the key in the middle, but a key that is located in the # unparsed region before the middle. - result = index._lookup_keys_via_location( - [(index._size // 2, (b'30', ))]) + result = index._lookup_keys_via_location([(index._size // 2, (b"30",))]) # check the parse map, this determines the test validity self.assertEqual([(0, 4008), (5046, 8996)], index._parsed_byte_map) - self.assertEqual([((), self.make_key(26)), - (self.make_key(31), self.make_key(48))], - index._parsed_key_map) - self.assertEqual([((index._size // 2, (b'30', )), -1)], - result) + self.assertEqual( + [((), self.make_key(26)), (self.make_key(31), self.make_key(48))], + index._parsed_key_map, + ) + self.assertEqual([((index._size // 2, (b"30",)), -1)], result) def test_lookup_key_above_probed_area(self): # generate a big enough index that we only read some of it on a typical @@ -674,126 +754,169 @@ def test_lookup_key_above_probed_area(self): index = self.make_index(nodes=self.make_nodes(64)) # ask for the key in the middle, but a key that is located in the # unparsed region after the middle. - result = index._lookup_keys_via_location( - [(index._size // 2, (b'50', ))]) + result = index._lookup_keys_via_location([(index._size // 2, (b"50",))]) # check the parse map, this determines the test validity self.assertEqual([(0, 4008), (5046, 8996)], index._parsed_byte_map) - self.assertEqual([((), self.make_key(26)), - (self.make_key(31), self.make_key(48))], - index._parsed_key_map) - self.assertEqual([((index._size // 2, (b'50', )), +1)], - result) + self.assertEqual( + [((), self.make_key(26)), (self.make_key(31), self.make_key(48))], + index._parsed_key_map, + ) + self.assertEqual([((index._size // 2, (b"50",)), +1)], result) def test_lookup_key_resolves_references(self): # generate a big enough index that we only read some of it on a typical # bisection lookup. 
nodes = [] for counter in range(99): - nodes.append((self.make_key(counter), self.make_value(counter), - ((self.make_key(counter + 20),),))) + nodes.append( + ( + self.make_key(counter), + self.make_value(counter), + ((self.make_key(counter + 20),),), + ) + ) index = self.make_index(ref_lists=1, nodes=nodes) # lookup a key in the middle that does not exist, so that when we can # check that the referred-to-keys are not accessed automatically. index_size = index._size index_center = index_size // 2 - result = index._lookup_keys_via_location( - [(index_center, (b'40', ))]) + result = index._lookup_keys_via_location([(index_center, (b"40",))]) # check the parse map - only the start and middle should have been # parsed. self.assertEqual([(0, 4027), (10198, 14028)], index._parsed_byte_map) - self.assertEqual([((), self.make_key(17)), - (self.make_key(44), self.make_key(5))], - index._parsed_key_map) + self.assertEqual( + [((), self.make_key(17)), (self.make_key(44), self.make_key(5))], + index._parsed_key_map, + ) # and check the transport activity likewise. self.assertEqual( - [('readv', 'index', [(index_center, 800), (0, 200)], True, - index_size)], - index._transport._activity) + [("readv", "index", [(index_center, 800), (0, 200)], True, index_size)], + index._transport._activity, + ) # reset the transport log for testing the reference lookup del index._transport._activity[:] # now looking up a key in the portion of the file already parsed should # only perform IO to resolve its key references. result = index._lookup_keys_via_location([(11000, self.make_key(45))]) self.assertEqual( - [((11000, self.make_key(45)), - (index, self.make_key(45), self.make_value(45), - ((self.make_key(65),),)))], - result) - self.assertEqual([('readv', 'index', [(15093, 800)], True, index_size)], - index._transport._activity) + [ + ( + (11000, self.make_key(45)), + ( + index, + self.make_key(45), + self.make_value(45), + ((self.make_key(65),),), + ), + ) + ], + result, + ) + self.assertEqual( + [("readv", "index", [(15093, 800)], True, index_size)], + index._transport._activity, + ) def test_lookup_key_can_buffer_all(self): nodes = [] for counter in range(64): - nodes.append((self.make_key(counter), self.make_value(counter), - ((self.make_key(counter + 20),),))) + nodes.append( + ( + self.make_key(counter), + self.make_value(counter), + ((self.make_key(counter + 20),),), + ) + ) index = self.make_index(ref_lists=1, nodes=nodes) # lookup a key in the middle that does not exist, so that when we can # check that the referred-to-keys are not accessed automatically. index_size = index._size index_center = index_size // 2 - result = index._lookup_keys_via_location([(index_center, (b'40', ))]) + result = index._lookup_keys_via_location([(index_center, (b"40",))]) # check the parse map - only the start and middle should have been # parsed. self.assertEqual([(0, 3890), (6444, 10274)], index._parsed_byte_map) - self.assertEqual([((), self.make_key(25)), - (self.make_key(37), self.make_key(52))], - index._parsed_key_map) + self.assertEqual( + [((), self.make_key(25)), (self.make_key(37), self.make_key(52))], + index._parsed_key_map, + ) # and check the transport activity likewise. 
self.assertEqual( - [('readv', 'index', [(index_center, 800), (0, 200)], True, - index_size)], - index._transport._activity) + [("readv", "index", [(index_center, 800), (0, 200)], True, index_size)], + index._transport._activity, + ) # reset the transport log for testing the reference lookup del index._transport._activity[:] # now looking up a key in the portion of the file already parsed should # only perform IO to resolve its key references. result = index._lookup_keys_via_location([(7000, self.make_key(40))]) self.assertEqual( - [((7000, self.make_key(40)), - (index, self.make_key(40), self.make_value(40), - ((self.make_key(60),),)))], - result) + [ + ( + (7000, self.make_key(40)), + ( + index, + self.make_key(40), + self.make_value(40), + ((self.make_key(60),),), + ), + ) + ], + result, + ) # Resolving the references would have required more data read, and we # are already above the 50% threshold, so it triggered a _buffer_all - self.assertEqual([('get', 'index')], index._transport._activity) + self.assertEqual([("get", "index")], index._transport._activity) def test_iter_all_entries_empty(self): index = self.make_index() self.assertEqual([], list(index.iter_all_entries())) def test_iter_all_entries_simple(self): - index = self.make_index(nodes=[((b'name', ), b'data', ())]) - self.assertEqual([(index, (b'name', ), b'data')], - list(index.iter_all_entries())) + index = self.make_index(nodes=[((b"name",), b"data", ())]) + self.assertEqual([(index, (b"name",), b"data")], list(index.iter_all_entries())) def test_iter_all_entries_simple_2_elements(self): - index = self.make_index(key_elements=2, - nodes=[((b'name', b'surname'), b'data', ())]) - self.assertEqual([(index, (b'name', b'surname'), b'data')], - list(index.iter_all_entries())) + index = self.make_index( + key_elements=2, nodes=[((b"name", b"surname"), b"data", ())] + ) + self.assertEqual( + [(index, (b"name", b"surname"), b"data")], list(index.iter_all_entries()) + ) def test_iter_all_entries_references_resolved(self): - index = self.make_index(1, nodes=[ - ((b'name', ), b'data', ([(b'ref', )], )), - ((b'ref', ), b'refdata', ([], ))]) - self.assertEqual({(index, (b'name', ), b'data', (((b'ref',),),)), - (index, (b'ref', ), b'refdata', ((), ))}, - set(index.iter_all_entries())) + index = self.make_index( + 1, + nodes=[ + ((b"name",), b"data", ([(b"ref",)],)), + ((b"ref",), b"refdata", ([],)), + ], + ) + self.assertEqual( + { + (index, (b"name",), b"data", (((b"ref",),),)), + (index, (b"ref",), b"refdata", ((),)), + }, + set(index.iter_all_entries()), + ) def test_iter_entries_buffers_once(self): index = self.make_index(nodes=self.make_nodes(2)) # reset the transport log del index._transport._activity[:] - self.assertEqual({(index, self.make_key(1), self.make_value(1))}, - set(index.iter_entries([self.make_key(1)]))) + self.assertEqual( + {(index, self.make_key(1), self.make_value(1))}, + set(index.iter_entries([self.make_key(1)])), + ) # We should have requested reading the header bytes # But not needed any more than that because it would have triggered a # buffer all - self.assertEqual([ - ('readv', 'index', [(0, 200)], True, index._size), + self.assertEqual( + [ + ("readv", "index", [(0, 200)], True, index._size), ], - index._transport._activity) + index._transport._activity, + ) # And that should have been enough to trigger reading the whole index # with buffering self.assertIsNot(None, index._nodes) @@ -824,46 +947,80 @@ def test_iter_entries_buffers_by_bytes_read(self): self.assertIsNot(None, index._nodes) def 
test_iter_entries_references_resolved(self): - index = self.make_index(1, nodes=[ - ((b'name', ), b'data', ([(b'ref', ), (b'ref', )], )), - ((b'ref', ), b'refdata', ([], ))]) - self.assertEqual({(index, (b'name', ), b'data', (((b'ref',), (b'ref',)),)), - (index, (b'ref', ), b'refdata', ((), ))}, - set(index.iter_entries([(b'name',), (b'ref',)]))) + index = self.make_index( + 1, + nodes=[ + ((b"name",), b"data", ([(b"ref",), (b"ref",)],)), + ((b"ref",), b"refdata", ([],)), + ], + ) + self.assertEqual( + { + (index, (b"name",), b"data", (((b"ref",), (b"ref",)),)), + (index, (b"ref",), b"refdata", ((),)), + }, + set(index.iter_entries([(b"name",), (b"ref",)])), + ) def test_iter_entries_references_2_refs_resolved(self): - index = self.make_index(2, nodes=[ - ((b'name', ), b'data', ([(b'ref', )], [(b'ref', )])), - ((b'ref', ), b'refdata', ([], []))]) - self.assertEqual({(index, (b'name', ), b'data', (((b'ref',),), ((b'ref',),))), - (index, (b'ref', ), b'refdata', ((), ()))}, - set(index.iter_entries([(b'name',), (b'ref',)]))) + index = self.make_index( + 2, + nodes=[ + ((b"name",), b"data", ([(b"ref",)], [(b"ref",)])), + ((b"ref",), b"refdata", ([], [])), + ], + ) + self.assertEqual( + { + (index, (b"name",), b"data", (((b"ref",),), ((b"ref",),))), + (index, (b"ref",), b"refdata", ((), ())), + }, + set(index.iter_entries([(b"name",), (b"ref",)])), + ) def test_iteration_absent_skipped(self): - index = self.make_index(1, nodes=[ - ((b'name', ), b'data', ([(b'ref', )], ))]) - self.assertEqual({(index, (b'name', ), b'data', (((b'ref',),),))}, - set(index.iter_all_entries())) - self.assertEqual({(index, (b'name', ), b'data', (((b'ref',),),))}, - set(index.iter_entries([(b'name', )]))) - self.assertEqual([], list(index.iter_entries([(b'ref', )]))) + index = self.make_index(1, nodes=[((b"name",), b"data", ([(b"ref",)],))]) + self.assertEqual( + {(index, (b"name",), b"data", (((b"ref",),),))}, + set(index.iter_all_entries()), + ) + self.assertEqual( + {(index, (b"name",), b"data", (((b"ref",),),))}, + set(index.iter_entries([(b"name",)])), + ) + self.assertEqual([], list(index.iter_entries([(b"ref",)]))) def test_iteration_absent_skipped_2_element_keys(self): - index = self.make_index(1, key_elements=2, nodes=[ - ((b'name', b'fin'), b'data', ([(b'ref', b'erence')], ))]) - self.assertEqual([(index, (b'name', b'fin'), b'data', (((b'ref', b'erence'),),))], - list(index.iter_all_entries())) - self.assertEqual([(index, (b'name', b'fin'), b'data', (((b'ref', b'erence'),),))], - list(index.iter_entries([(b'name', b'fin')]))) - self.assertEqual([], list(index.iter_entries([(b'ref', b'erence')]))) + index = self.make_index( + 1, + key_elements=2, + nodes=[((b"name", b"fin"), b"data", ([(b"ref", b"erence")],))], + ) + self.assertEqual( + [(index, (b"name", b"fin"), b"data", (((b"ref", b"erence"),),))], + list(index.iter_all_entries()), + ) + self.assertEqual( + [(index, (b"name", b"fin"), b"data", (((b"ref", b"erence"),),))], + list(index.iter_entries([(b"name", b"fin")])), + ) + self.assertEqual([], list(index.iter_entries([(b"ref", b"erence")]))) def test_iter_all_keys(self): - index = self.make_index(1, nodes=[ - ((b'name', ), b'data', ([(b'ref', )], )), - ((b'ref', ), b'refdata', ([], ))]) - self.assertEqual({(index, (b'name', ), b'data', (((b'ref',),),)), - (index, (b'ref', ), b'refdata', ((), ))}, - set(index.iter_entries([(b'name', ), (b'ref', )]))) + index = self.make_index( + 1, + nodes=[ + ((b"name",), b"data", ([(b"ref",)],)), + ((b"ref",), b"refdata", ([],)), + ], + ) + self.assertEqual( + { + 
(index, (b"name",), b"data", (((b"ref",),),)), + (index, (b"ref",), b"refdata", ((),)), + }, + set(index.iter_entries([(b"name",), (b"ref",)])), + ) def test_iter_nothing_empty(self): index = self.make_index() @@ -871,99 +1028,144 @@ def test_iter_nothing_empty(self): def test_iter_missing_entry_empty(self): index = self.make_index() - self.assertEqual([], list(index.iter_entries([(b'a', )]))) + self.assertEqual([], list(index.iter_entries([(b"a",)]))) def test_iter_missing_entry_empty_no_size(self): idx = self.make_index() - idx = _mod_index.GraphIndex(idx._transport, 'index', None) - self.assertEqual([], list(idx.iter_entries([(b'a', )]))) + idx = _mod_index.GraphIndex(idx._transport, "index", None) + self.assertEqual([], list(idx.iter_entries([(b"a",)]))) def test_iter_key_prefix_1_element_key_None(self): index = self.make_index() - self.assertRaises(_mod_index.BadIndexKey, list, - index.iter_entries_prefix([(None, )])) + self.assertRaises( + _mod_index.BadIndexKey, list, index.iter_entries_prefix([(None,)]) + ) def test_iter_key_prefix_wrong_length(self): index = self.make_index() - self.assertRaises(_mod_index.BadIndexKey, list, - index.iter_entries_prefix([(b'foo', None)])) + self.assertRaises( + _mod_index.BadIndexKey, list, index.iter_entries_prefix([(b"foo", None)]) + ) index = self.make_index(key_elements=2) - self.assertRaises(_mod_index.BadIndexKey, list, - index.iter_entries_prefix([(b'foo', )])) - self.assertRaises(_mod_index.BadIndexKey, list, - index.iter_entries_prefix([(b'foo', None, None)])) + self.assertRaises( + _mod_index.BadIndexKey, list, index.iter_entries_prefix([(b"foo",)]) + ) + self.assertRaises( + _mod_index.BadIndexKey, + list, + index.iter_entries_prefix([(b"foo", None, None)]), + ) def test_iter_key_prefix_1_key_element_no_refs(self): - index = self.make_index(nodes=[ - ((b'name', ), b'data', ()), - ((b'ref', ), b'refdata', ())]) - self.assertEqual({(index, (b'name', ), b'data'), - (index, (b'ref', ), b'refdata')}, - set(index.iter_entries_prefix([(b'name', ), (b'ref', )]))) + index = self.make_index( + nodes=[((b"name",), b"data", ()), ((b"ref",), b"refdata", ())] + ) + self.assertEqual( + {(index, (b"name",), b"data"), (index, (b"ref",), b"refdata")}, + set(index.iter_entries_prefix([(b"name",), (b"ref",)])), + ) def test_iter_key_prefix_1_key_element_refs(self): - index = self.make_index(1, nodes=[ - ((b'name', ), b'data', ([(b'ref', )], )), - ((b'ref', ), b'refdata', ([], ))]) - self.assertEqual({(index, (b'name', ), b'data', (((b'ref',),),)), - (index, (b'ref', ), b'refdata', ((), ))}, - set(index.iter_entries_prefix([(b'name', ), (b'ref', )]))) + index = self.make_index( + 1, + nodes=[ + ((b"name",), b"data", ([(b"ref",)],)), + ((b"ref",), b"refdata", ([],)), + ], + ) + self.assertEqual( + { + (index, (b"name",), b"data", (((b"ref",),),)), + (index, (b"ref",), b"refdata", ((),)), + }, + set(index.iter_entries_prefix([(b"name",), (b"ref",)])), + ) def test_iter_key_prefix_2_key_element_no_refs(self): - index = self.make_index(key_elements=2, nodes=[ - ((b'name', b'fin1'), b'data', ()), - ((b'name', b'fin2'), b'beta', ()), - ((b'ref', b'erence'), b'refdata', ())]) - self.assertEqual({(index, (b'name', b'fin1'), b'data'), - (index, (b'ref', b'erence'), b'refdata')}, - set(index.iter_entries_prefix([(b'name', b'fin1'), (b'ref', b'erence')]))) - self.assertEqual({(index, (b'name', b'fin1'), b'data'), - (index, (b'name', b'fin2'), b'beta')}, - set(index.iter_entries_prefix([(b'name', None)]))) + index = self.make_index( + key_elements=2, + nodes=[ + 
((b"name", b"fin1"), b"data", ()), + ((b"name", b"fin2"), b"beta", ()), + ((b"ref", b"erence"), b"refdata", ()), + ], + ) + self.assertEqual( + { + (index, (b"name", b"fin1"), b"data"), + (index, (b"ref", b"erence"), b"refdata"), + }, + set(index.iter_entries_prefix([(b"name", b"fin1"), (b"ref", b"erence")])), + ) + self.assertEqual( + { + (index, (b"name", b"fin1"), b"data"), + (index, (b"name", b"fin2"), b"beta"), + }, + set(index.iter_entries_prefix([(b"name", None)])), + ) def test_iter_key_prefix_2_key_element_refs(self): - index = self.make_index(1, key_elements=2, nodes=[ - ((b'name', b'fin1'), b'data', ([(b'ref', b'erence')], )), - ((b'name', b'fin2'), b'beta', ([], )), - ((b'ref', b'erence'), b'refdata', ([], ))]) - self.assertEqual({(index, (b'name', b'fin1'), b'data', (((b'ref', b'erence'),),)), - (index, (b'ref', b'erence'), b'refdata', ((), ))}, - set(index.iter_entries_prefix([(b'name', b'fin1'), (b'ref', b'erence')]))) - self.assertEqual({(index, (b'name', b'fin1'), b'data', (((b'ref', b'erence'),),)), - (index, (b'name', b'fin2'), b'beta', ((), ))}, - set(index.iter_entries_prefix([(b'name', None)]))) + index = self.make_index( + 1, + key_elements=2, + nodes=[ + ((b"name", b"fin1"), b"data", ([(b"ref", b"erence")],)), + ((b"name", b"fin2"), b"beta", ([],)), + ((b"ref", b"erence"), b"refdata", ([],)), + ], + ) + self.assertEqual( + { + (index, (b"name", b"fin1"), b"data", (((b"ref", b"erence"),),)), + (index, (b"ref", b"erence"), b"refdata", ((),)), + }, + set(index.iter_entries_prefix([(b"name", b"fin1"), (b"ref", b"erence")])), + ) + self.assertEqual( + { + (index, (b"name", b"fin1"), b"data", (((b"ref", b"erence"),),)), + (index, (b"name", b"fin2"), b"beta", ((),)), + }, + set(index.iter_entries_prefix([(b"name", None)])), + ) def test_key_count_empty(self): index = self.make_index() self.assertEqual(0, index.key_count()) def test_key_count_one(self): - index = self.make_index(nodes=[((b'name', ), b'', ())]) + index = self.make_index(nodes=[((b"name",), b"", ())]) self.assertEqual(1, index.key_count()) def test_key_count_two(self): - index = self.make_index(nodes=[ - ((b'name', ), b'', ()), ((b'foo', ), b'', ())]) + index = self.make_index(nodes=[((b"name",), b"", ()), ((b"foo",), b"", ())]) self.assertEqual(2, index.key_count()) def test_read_and_parse_tracks_real_read_value(self): index = self.make_index(nodes=self.make_nodes(10)) del index._transport._activity[:] index._read_and_parse([(0, 200)]) - self.assertEqual([ - ('readv', 'index', [(0, 200)], True, index._size), + self.assertEqual( + [ + ("readv", "index", [(0, 200)], True, index._size), ], - index._transport._activity) + index._transport._activity, + ) # The readv expansion code will expand the initial request to 4096 # bytes, which is more than enough to read the entire index, and we # will track the fact that we read that many bytes. 
self.assertEqual(index._size, index._bytes_read) def test_read_and_parse_triggers_buffer_all(self): - index = self.make_index(key_elements=2, nodes=[ - ((b'name', b'fin1'), b'data', ()), - ((b'name', b'fin2'), b'beta', ()), - ((b'ref', b'erence'), b'refdata', ())]) + index = self.make_index( + key_elements=2, + nodes=[ + ((b"name", b"fin1"), b"data", ()), + ((b"name", b"fin2"), b"beta", ()), + ((b"ref", b"erence"), b"refdata", ()), + ], + ) self.assertGreater(index._size, 0) self.assertIs(None, index._nodes) index._read_and_parse([(0, index._size)]) @@ -971,33 +1173,33 @@ def test_read_and_parse_triggers_buffer_all(self): def test_validate_bad_index_errors(self): trans = self.get_transport() - trans.put_bytes('name', b"not an index\n") - idx = _mod_index.GraphIndex(trans, 'name', 13) + trans.put_bytes("name", b"not an index\n") + idx = _mod_index.GraphIndex(trans, "name", 13) self.assertRaises(_mod_index.BadIndexFormatSignature, idx.validate) def test_validate_bad_node_refs(self): idx = self.make_index(2) trans = self.get_transport() - content = trans.get_bytes('index') + content = trans.get_bytes("index") # change the options line to end with a rather than a parseable number - new_content = content[:-2] + b'a\n\n' - trans.put_bytes('index', new_content) + new_content = content[:-2] + b"a\n\n" + trans.put_bytes("index", new_content) self.assertRaises(_mod_index.BadIndexOptions, idx.validate) def test_validate_missing_end_line_empty(self): index = self.make_index(2) trans = self.get_transport() - content = trans.get_bytes('index') + content = trans.get_bytes("index") # truncate the last byte - trans.put_bytes('index', content[:-1]) + trans.put_bytes("index", content[:-1]) self.assertRaises(_mod_index.BadIndexData, index.validate) def test_validate_missing_end_line_nonempty(self): - index = self.make_index(2, nodes=[((b'key', ), b'', ([], []))]) + index = self.make_index(2, nodes=[((b"key",), b"", ([], []))]) trans = self.get_transport() - content = trans.get_bytes('index') + content = trans.get_bytes("index") # truncate the last byte - trans.put_bytes('index', content[:-1]) + trans.put_bytes("index", content[:-1]) self.assertRaises(_mod_index.BadIndexData, index.validate) def test_validate_empty(self): @@ -1005,7 +1207,7 @@ def test_validate_empty(self): index.validate() def test_validate_no_refs_content(self): - index = self.make_index(nodes=[((b'key', ), b'value', ())]) + index = self.make_index(nodes=[((b"key",), b"value", ())]) index.validate() # XXX: external_references tests are duplicated in test_btree_index. 
We @@ -1015,80 +1217,92 @@ def test_external_references_no_refs(self): self.assertRaises(ValueError, index.external_references, 0) def test_external_references_no_results(self): - index = self.make_index(ref_lists=1, nodes=[ - ((b'key',), b'value', ([],))]) + index = self.make_index(ref_lists=1, nodes=[((b"key",), b"value", ([],))]) self.assertEqual(set(), index.external_references(0)) def test_external_references_missing_ref(self): - missing_key = (b'missing',) - index = self.make_index(ref_lists=1, nodes=[ - ((b'key',), b'value', ([missing_key],))]) + missing_key = (b"missing",) + index = self.make_index( + ref_lists=1, nodes=[((b"key",), b"value", ([missing_key],))] + ) self.assertEqual({missing_key}, index.external_references(0)) def test_external_references_multiple_ref_lists(self): - missing_key = (b'missing',) - index = self.make_index(ref_lists=2, nodes=[ - ((b'key',), b'value', ([], [missing_key]))]) + missing_key = (b"missing",) + index = self.make_index( + ref_lists=2, nodes=[((b"key",), b"value", ([], [missing_key]))] + ) self.assertEqual(set(), index.external_references(0)) self.assertEqual({missing_key}, index.external_references(1)) def test_external_references_two_records(self): - index = self.make_index(ref_lists=1, nodes=[ - ((b'key-1',), b'value', ([(b'key-2',)],)), - ((b'key-2',), b'value', ([],)), - ]) + index = self.make_index( + ref_lists=1, + nodes=[ + ((b"key-1",), b"value", ([(b"key-2",)],)), + ((b"key-2",), b"value", ([],)), + ], + ) self.assertEqual(set(), index.external_references(0)) def test__find_ancestors(self): - key1 = (b'key-1',) - key2 = (b'key-2',) - index = self.make_index(ref_lists=1, key_elements=1, nodes=[ - (key1, b'value', ([key2],)), - (key2, b'value', ([],)), - ]) + key1 = (b"key-1",) + key2 = (b"key-2",) + index = self.make_index( + ref_lists=1, + key_elements=1, + nodes=[ + (key1, b"value", ([key2],)), + (key2, b"value", ([],)), + ], + ) parent_map = {} missing_keys = set() - search_keys = index._find_ancestors( - [key1], 0, parent_map, missing_keys) + search_keys = index._find_ancestors([key1], 0, parent_map, missing_keys) self.assertEqual({key1: (key2,)}, parent_map) self.assertEqual(set(), missing_keys) self.assertEqual({key2}, search_keys) - search_keys = index._find_ancestors(search_keys, 0, parent_map, - missing_keys) + search_keys = index._find_ancestors(search_keys, 0, parent_map, missing_keys) self.assertEqual({key1: (key2,), key2: ()}, parent_map) self.assertEqual(set(), missing_keys) self.assertEqual(set(), search_keys) def test__find_ancestors_w_missing(self): - key1 = (b'key-1',) - key2 = (b'key-2',) - key3 = (b'key-3',) - index = self.make_index(ref_lists=1, key_elements=1, nodes=[ - (key1, b'value', ([key2],)), - (key2, b'value', ([],)), - ]) + key1 = (b"key-1",) + key2 = (b"key-2",) + key3 = (b"key-3",) + index = self.make_index( + ref_lists=1, + key_elements=1, + nodes=[ + (key1, b"value", ([key2],)), + (key2, b"value", ([],)), + ], + ) parent_map = {} missing_keys = set() - search_keys = index._find_ancestors([key2, key3], 0, parent_map, - missing_keys) + search_keys = index._find_ancestors([key2, key3], 0, parent_map, missing_keys) self.assertEqual({key2: ()}, parent_map) self.assertEqual({key3}, missing_keys) self.assertEqual(set(), search_keys) def test__find_ancestors_dont_search_known(self): - key1 = (b'key-1',) - key2 = (b'key-2',) - key3 = (b'key-3',) - index = self.make_index(ref_lists=1, key_elements=1, nodes=[ - (key1, b'value', ([key2],)), - (key2, b'value', ([key3],)), - (key3, b'value', ([],)), - ]) + key1 = 
(b"key-1",) + key2 = (b"key-2",) + key3 = (b"key-3",) + index = self.make_index( + ref_lists=1, + key_elements=1, + nodes=[ + (key1, b"value", ([key2],)), + (key2, b"value", ([key3],)), + (key3, b"value", ([],)), + ], + ) # We already know about key2, so we won't try to search for key3 parent_map = {key2: (key3,)} missing_keys = set() - search_keys = index._find_ancestors([key1], 0, parent_map, - missing_keys) + search_keys = index._find_ancestors([key1], 0, parent_map, missing_keys) self.assertEqual({key1: (key2,), key2: (key3,)}, parent_map) self.assertEqual(set(), missing_keys) self.assertEqual(set(), search_keys) @@ -1097,19 +1311,17 @@ def test_supports_unlimited_cache(self): builder = _mod_index.GraphIndexBuilder(0, key_elements=1) stream = builder.finish() trans = self.get_transport() - size = trans.put_file('index', stream) + size = trans.put_file("index", stream) # It doesn't matter what unlimited_cache does here, just that it can be # passed - _mod_index.GraphIndex(trans, 'index', size, unlimited_cache=True) + _mod_index.GraphIndex(trans, "index", size, unlimited_cache=True) class TestCombinedGraphIndex(tests.TestCaseWithMemoryTransport): - def make_index(self, name, ref_lists=0, key_elements=1, nodes=None): if nodes is None: nodes = [] - builder = _mod_index.GraphIndexBuilder( - ref_lists, key_elements=key_elements) + builder = _mod_index.GraphIndexBuilder(ref_lists, key_elements=key_elements) for key, value, references in nodes: builder.add_node(key, value, references) stream = builder.finish() @@ -1128,12 +1340,10 @@ def make_combined_index_with_missing(self, missing=None): :return: (CombinedGraphIndex, reload_counter) """ if missing is None: - missing = ['1', '2'] - idx1 = self.make_index('1', nodes=[((b'1',), b'', ())]) - idx2 = self.make_index('2', nodes=[((b'2',), b'', ())]) - idx3 = self.make_index('3', nodes=[ - ((b'1',), b'', ()), - ((b'2',), b'', ())]) + missing = ["1", "2"] + idx1 = self.make_index("1", nodes=[((b"1",), b"", ())]) + idx2 = self.make_index("2", nodes=[((b"2",), b"", ())]) + idx3 = self.make_index("3", nodes=[((b"1",), b"", ()), ((b"2",), b"", ())]) # total_reloads, num_changed, num_unchanged reload_counter = [0, 0, 0] @@ -1147,6 +1357,7 @@ def reload(): reload_counter[1] += 1 idx._indices[:] = new_indices return True + idx = _mod_index.CombinedGraphIndex([idx1, idx2], reload_func=reload) trans = self.get_transport() for fname in missing: @@ -1155,21 +1366,19 @@ def reload(): def test_open_missing_index_no_error(self): trans = self.get_transport() - idx1 = _mod_index.GraphIndex(trans, 'missing', 100) + idx1 = _mod_index.GraphIndex(trans, "missing", 100) _mod_index.CombinedGraphIndex([idx1]) def test_add_index(self): idx = _mod_index.CombinedGraphIndex([]) - idx1 = self.make_index('name', 0, nodes=[((b'key', ), b'', ())]) + idx1 = self.make_index("name", 0, nodes=[((b"key",), b"", ())]) idx.insert_index(0, idx1) - self.assertEqual([(idx1, (b'key', ), b'')], - list(idx.iter_all_entries())) + self.assertEqual([(idx1, (b"key",), b"")], list(idx.iter_all_entries())) def test_clear_cache(self): log = [] class ClearCacheProxy: - def __init__(self, index): self._index = index @@ -1181,9 +1390,9 @@ def clear_cache(self): return self._index.clear_cache() idx = _mod_index.CombinedGraphIndex([]) - idx1 = self.make_index('name', 0, nodes=[((b'key', ), b'', ())]) + idx1 = self.make_index("name", 0, nodes=[((b"key",), b"", ())]) idx.insert_index(0, ClearCacheProxy(idx1)) - idx2 = self.make_index('name', 0, nodes=[((b'key', ), b'', ())]) + idx2 = 
self.make_index("name", 0, nodes=[((b"key",), b"", ())]) idx.insert_index(1, ClearCacheProxy(idx2)) # CombinedGraphIndex should call 'clear_cache()' on all children idx.clear_cache() @@ -1194,129 +1403,150 @@ def test_iter_all_entries_empty(self): self.assertEqual([], list(idx.iter_all_entries())) def test_iter_all_entries_children_empty(self): - idx1 = self.make_index('name') + idx1 = self.make_index("name") idx = _mod_index.CombinedGraphIndex([idx1]) self.assertEqual([], list(idx.iter_all_entries())) def test_iter_all_entries_simple(self): - idx1 = self.make_index('name', nodes=[((b'name', ), b'data', ())]) + idx1 = self.make_index("name", nodes=[((b"name",), b"data", ())]) idx = _mod_index.CombinedGraphIndex([idx1]) - self.assertEqual([(idx1, (b'name', ), b'data')], - list(idx.iter_all_entries())) + self.assertEqual([(idx1, (b"name",), b"data")], list(idx.iter_all_entries())) def test_iter_all_entries_two_indices(self): - idx1 = self.make_index('name1', nodes=[((b'name', ), b'data', ())]) - idx2 = self.make_index('name2', nodes=[((b'2', ), b'', ())]) + idx1 = self.make_index("name1", nodes=[((b"name",), b"data", ())]) + idx2 = self.make_index("name2", nodes=[((b"2",), b"", ())]) idx = _mod_index.CombinedGraphIndex([idx1, idx2]) - self.assertEqual([(idx1, (b'name', ), b'data'), - (idx2, (b'2', ), b'')], - list(idx.iter_all_entries())) + self.assertEqual( + [(idx1, (b"name",), b"data"), (idx2, (b"2",), b"")], + list(idx.iter_all_entries()), + ) def test_iter_entries_two_indices_dup_key(self): - idx1 = self.make_index('name1', nodes=[((b'name', ), b'data', ())]) - idx2 = self.make_index('name2', nodes=[((b'name', ), b'data', ())]) + idx1 = self.make_index("name1", nodes=[((b"name",), b"data", ())]) + idx2 = self.make_index("name2", nodes=[((b"name",), b"data", ())]) idx = _mod_index.CombinedGraphIndex([idx1, idx2]) - self.assertEqual([(idx1, (b'name', ), b'data')], - list(idx.iter_entries([(b'name', )]))) + self.assertEqual( + [(idx1, (b"name",), b"data")], list(idx.iter_entries([(b"name",)])) + ) def test_iter_all_entries_two_indices_dup_key(self): - idx1 = self.make_index('name1', nodes=[((b'name', ), b'data', ())]) - idx2 = self.make_index('name2', nodes=[((b'name', ), b'data', ())]) + idx1 = self.make_index("name1", nodes=[((b"name",), b"data", ())]) + idx2 = self.make_index("name2", nodes=[((b"name",), b"data", ())]) idx = _mod_index.CombinedGraphIndex([idx1, idx2]) - self.assertEqual([(idx1, (b'name', ), b'data')], - list(idx.iter_all_entries())) + self.assertEqual([(idx1, (b"name",), b"data")], list(idx.iter_all_entries())) def test_iter_key_prefix_2_key_element_refs(self): - idx1 = self.make_index('1', 1, key_elements=2, nodes=[ - ((b'name', b'fin1'), b'data', ([(b'ref', b'erence')], ))]) - idx2 = self.make_index('2', 1, key_elements=2, nodes=[ - ((b'name', b'fin2'), b'beta', ([], )), - ((b'ref', b'erence'), b'refdata', ([], ))]) + idx1 = self.make_index( + "1", + 1, + key_elements=2, + nodes=[((b"name", b"fin1"), b"data", ([(b"ref", b"erence")],))], + ) + idx2 = self.make_index( + "2", + 1, + key_elements=2, + nodes=[ + ((b"name", b"fin2"), b"beta", ([],)), + ((b"ref", b"erence"), b"refdata", ([],)), + ], + ) idx = _mod_index.CombinedGraphIndex([idx1, idx2]) - self.assertEqual({(idx1, (b'name', b'fin1'), b'data', - (((b'ref', b'erence'),),)), - (idx2, (b'ref', b'erence'), b'refdata', ((), ))}, - set(idx.iter_entries_prefix([(b'name', b'fin1'), - (b'ref', b'erence')]))) - self.assertEqual({(idx1, (b'name', b'fin1'), b'data', - (((b'ref', b'erence'),),)), - (idx2, (b'name', 
b'fin2'), b'beta', ((), ))}, - set(idx.iter_entries_prefix([(b'name', None)]))) + self.assertEqual( + { + (idx1, (b"name", b"fin1"), b"data", (((b"ref", b"erence"),),)), + (idx2, (b"ref", b"erence"), b"refdata", ((),)), + }, + set(idx.iter_entries_prefix([(b"name", b"fin1"), (b"ref", b"erence")])), + ) + self.assertEqual( + { + (idx1, (b"name", b"fin1"), b"data", (((b"ref", b"erence"),),)), + (idx2, (b"name", b"fin2"), b"beta", ((),)), + }, + set(idx.iter_entries_prefix([(b"name", None)])), + ) def test_iter_nothing_empty(self): idx = _mod_index.CombinedGraphIndex([]) self.assertEqual([], list(idx.iter_entries([]))) def test_iter_nothing_children_empty(self): - idx1 = self.make_index('name') + idx1 = self.make_index("name") idx = _mod_index.CombinedGraphIndex([idx1]) self.assertEqual([], list(idx.iter_entries([]))) def test_iter_all_keys(self): - idx1 = self.make_index('1', 1, nodes=[((b'name', ), b'data', - ([(b'ref', )], ))]) - idx2 = self.make_index( - '2', 1, nodes=[((b'ref', ), b'refdata', ((), ))]) + idx1 = self.make_index("1", 1, nodes=[((b"name",), b"data", ([(b"ref",)],))]) + idx2 = self.make_index("2", 1, nodes=[((b"ref",), b"refdata", ((),))]) idx = _mod_index.CombinedGraphIndex([idx1, idx2]) - self.assertEqual({(idx1, (b'name', ), b'data', (((b'ref', ), ), )), - (idx2, (b'ref', ), b'refdata', ((), ))}, - set(idx.iter_entries([(b'name', ), (b'ref', )]))) + self.assertEqual( + { + (idx1, (b"name",), b"data", (((b"ref",),),)), + (idx2, (b"ref",), b"refdata", ((),)), + }, + set(idx.iter_entries([(b"name",), (b"ref",)])), + ) def test_iter_all_keys_dup_entry(self): - idx1 = self.make_index('1', 1, nodes=[((b'name', ), b'data', - ([(b'ref', )], )), - ((b'ref', ), b'refdata', ([], ))]) - idx2 = self.make_index( - '2', 1, nodes=[((b'ref', ), b'refdata', ([], ))]) + idx1 = self.make_index( + "1", + 1, + nodes=[ + ((b"name",), b"data", ([(b"ref",)],)), + ((b"ref",), b"refdata", ([],)), + ], + ) + idx2 = self.make_index("2", 1, nodes=[((b"ref",), b"refdata", ([],))]) idx = _mod_index.CombinedGraphIndex([idx1, idx2]) - self.assertEqual({(idx1, (b'name', ), b'data', (((b'ref',),),)), - (idx1, (b'ref', ), b'refdata', ((), ))}, - set(idx.iter_entries([(b'name', ), (b'ref', )]))) + self.assertEqual( + { + (idx1, (b"name",), b"data", (((b"ref",),),)), + (idx1, (b"ref",), b"refdata", ((),)), + }, + set(idx.iter_entries([(b"name",), (b"ref",)])), + ) def test_iter_missing_entry_empty(self): idx = _mod_index.CombinedGraphIndex([]) - self.assertEqual([], list(idx.iter_entries([('a', )]))) + self.assertEqual([], list(idx.iter_entries([("a",)]))) def test_iter_missing_entry_one_index(self): - idx1 = self.make_index('1') + idx1 = self.make_index("1") idx = _mod_index.CombinedGraphIndex([idx1]) - self.assertEqual([], list(idx.iter_entries([(b'a', )]))) + self.assertEqual([], list(idx.iter_entries([(b"a",)]))) def test_iter_missing_entry_two_index(self): - idx1 = self.make_index('1') - idx2 = self.make_index('2') + idx1 = self.make_index("1") + idx2 = self.make_index("2") idx = _mod_index.CombinedGraphIndex([idx1, idx2]) - self.assertEqual([], list(idx.iter_entries([('a', )]))) + self.assertEqual([], list(idx.iter_entries([("a",)]))) def test_iter_entry_present_one_index_only(self): - idx1 = self.make_index('1', nodes=[((b'key', ), b'', ())]) - idx2 = self.make_index('2', nodes=[]) + idx1 = self.make_index("1", nodes=[((b"key",), b"", ())]) + idx2 = self.make_index("2", nodes=[]) idx = _mod_index.CombinedGraphIndex([idx1, idx2]) - self.assertEqual([(idx1, (b'key', ), b'')], - 
list(idx.iter_entries([(b'key', )]))) + self.assertEqual([(idx1, (b"key",), b"")], list(idx.iter_entries([(b"key",)]))) # and in the other direction idx = _mod_index.CombinedGraphIndex([idx2, idx1]) - self.assertEqual([(idx1, (b'key', ), b'')], - list(idx.iter_entries([(b'key', )]))) + self.assertEqual([(idx1, (b"key",), b"")], list(idx.iter_entries([(b"key",)]))) def test_key_count_empty(self): - idx1 = self.make_index('1', nodes=[]) - idx2 = self.make_index('2', nodes=[]) + idx1 = self.make_index("1", nodes=[]) + idx2 = self.make_index("2", nodes=[]) idx = _mod_index.CombinedGraphIndex([idx1, idx2]) self.assertEqual(0, idx.key_count()) def test_key_count_sums_index_keys(self): - idx1 = self.make_index('1', nodes=[ - ((b'1',), b'', ()), - ((b'2',), b'', ())]) - idx2 = self.make_index('2', nodes=[((b'1',), b'', ())]) + idx1 = self.make_index("1", nodes=[((b"1",), b"", ()), ((b"2",), b"", ())]) + idx2 = self.make_index("2", nodes=[((b"1",), b"", ())]) idx = _mod_index.CombinedGraphIndex([idx1, idx2]) self.assertEqual(3, idx.key_count()) def test_validate_bad_child_index_errors(self): trans = self.get_transport() - trans.put_bytes('name', b"not an index\n") - idx1 = _mod_index.GraphIndex(trans, 'name', 13) + trans.put_bytes("name", b"not an index\n") + idx1 = _mod_index.GraphIndex(trans, "name", 13) idx = _mod_index.CombinedGraphIndex([idx1]) self.assertRaises(_mod_index.BadIndexFormatSignature, idx.validate) @@ -1339,61 +1569,55 @@ def test_key_count_reloads_and_fails(self): # We have deleted all underlying indexes, so we will try to reload, but # still fail. This is mostly to test we don't get stuck in an infinite # loop trying to reload - idx, reload_counter = self.make_combined_index_with_missing( - ['1', '2', '3']) + idx, reload_counter = self.make_combined_index_with_missing(["1", "2", "3"]) self.assertRaises(transport.NoSuchFile, idx.key_count) self.assertEqual([2, 1, 1], reload_counter) def test_iter_entries_reloads(self): index, reload_counter = self.make_combined_index_with_missing() - result = list(index.iter_entries([(b'1',), (b'2',), (b'3',)])) + result = list(index.iter_entries([(b"1",), (b"2",), (b"3",)])) index3 = index._indices[0] - self.assertEqual({(index3, (b'1',), b''), (index3, (b'2',), b'')}, - set(result)) + self.assertEqual({(index3, (b"1",), b""), (index3, (b"2",), b"")}, set(result)) self.assertEqual([1, 1, 0], reload_counter) def test_iter_entries_reloads_midway(self): # The first index still looks present, so we get interrupted mid-way # through - index, reload_counter = self.make_combined_index_with_missing(['2']) + index, reload_counter = self.make_combined_index_with_missing(["2"]) index1, index2 = index._indices - result = list(index.iter_entries([(b'1',), (b'2',), (b'3',)])) + result = list(index.iter_entries([(b"1",), (b"2",), (b"3",)])) index3 = index._indices[0] # We had already yielded b'1', so we just go on to the next, we should # not yield b'1' twice. 
- self.assertEqual([(index1, (b'1',), b''), (index3, (b'2',), b'')], - result) + self.assertEqual([(index1, (b"1",), b""), (index3, (b"2",), b"")], result) self.assertEqual([1, 1, 0], reload_counter) def test_iter_entries_no_reload(self): index, reload_counter = self.make_combined_index_with_missing() index._reload_func = None # Without a _reload_func we just raise the exception - self.assertListRaises(transport.NoSuchFile, index.iter_entries, [('3',)]) + self.assertListRaises(transport.NoSuchFile, index.iter_entries, [("3",)]) def test_iter_entries_reloads_and_fails(self): - index, reload_counter = self.make_combined_index_with_missing( - ['1', '2', '3']) - self.assertListRaises(transport.NoSuchFile, index.iter_entries, [('3',)]) + index, reload_counter = self.make_combined_index_with_missing(["1", "2", "3"]) + self.assertListRaises(transport.NoSuchFile, index.iter_entries, [("3",)]) self.assertEqual([2, 1, 1], reload_counter) def test_iter_all_entries_reloads(self): index, reload_counter = self.make_combined_index_with_missing() result = list(index.iter_all_entries()) index3 = index._indices[0] - self.assertEqual({(index3, (b'1',), b''), (index3, (b'2',), b'')}, - set(result)) + self.assertEqual({(index3, (b"1",), b""), (index3, (b"2",), b"")}, set(result)) self.assertEqual([1, 1, 0], reload_counter) def test_iter_all_entries_reloads_midway(self): - index, reload_counter = self.make_combined_index_with_missing(['2']) + index, reload_counter = self.make_combined_index_with_missing(["2"]) index1, index2 = index._indices result = list(index.iter_all_entries()) index3 = index._indices[0] # We had already yielded '1', so we just go on to the next, we should # not yield '1' twice. - self.assertEqual([(index1, (b'1',), b''), (index3, (b'2',), b'')], - result) + self.assertEqual([(index1, (b"1",), b""), (index3, (b"2",), b"")], result) self.assertEqual([1, 1, 0], reload_counter) def test_iter_all_entries_no_reload(self): @@ -1402,38 +1626,38 @@ def test_iter_all_entries_no_reload(self): self.assertListRaises(transport.NoSuchFile, index.iter_all_entries) def test_iter_all_entries_reloads_and_fails(self): - index, reload_counter = self.make_combined_index_with_missing( - ['1', '2', '3']) + index, reload_counter = self.make_combined_index_with_missing(["1", "2", "3"]) self.assertListRaises(transport.NoSuchFile, index.iter_all_entries) def test_iter_entries_prefix_reloads(self): index, reload_counter = self.make_combined_index_with_missing() - result = list(index.iter_entries_prefix([(b'1',)])) + result = list(index.iter_entries_prefix([(b"1",)])) index3 = index._indices[0] - self.assertEqual([(index3, (b'1',), b'')], result) + self.assertEqual([(index3, (b"1",), b"")], result) self.assertEqual([1, 1, 0], reload_counter) def test_iter_entries_prefix_reloads_midway(self): - index, reload_counter = self.make_combined_index_with_missing(['2']) + index, reload_counter = self.make_combined_index_with_missing(["2"]) index1, index2 = index._indices - result = list(index.iter_entries_prefix([(b'1',)])) + result = list(index.iter_entries_prefix([(b"1",)])) index._indices[0] # We had already yielded b'1', so we just go on to the next, we should # not yield b'1' twice. 
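(Editorial aside, not part of the patch: make_combined_index_with_missing above drives the generic reload contract, where CombinedGraphIndex calls the supplied reload_func when a child index raises NoSuchFile and retries if it returns True. A minimal sketch of that wiring; the function and argument names are illustrative only.)

    from breezy.bzr import index as _mod_index  # import path assumed

    def combined_with_reload(initial_indices, refreshed_indices):
        # reload() is invoked by CombinedGraphIndex on NoSuchFile; returning
        # True means "the child list was refreshed, retry the operation",
        # while a falsy return re-raises the original error.
        def reload():
            combined._indices[:] = list(refreshed_indices)
            return True

        combined = _mod_index.CombinedGraphIndex(
            list(initial_indices), reload_func=reload)
        return combined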
- self.assertEqual([(index1, (b'1',), b'')], result) + self.assertEqual([(index1, (b"1",), b"")], result) self.assertEqual([1, 1, 0], reload_counter) def test_iter_entries_prefix_no_reload(self): index, reload_counter = self.make_combined_index_with_missing() index._reload_func = None - self.assertListRaises(transport.NoSuchFile, index.iter_entries_prefix, - [(b'1',)]) + self.assertListRaises( + transport.NoSuchFile, index.iter_entries_prefix, [(b"1",)] + ) def test_iter_entries_prefix_reloads_and_fails(self): - index, reload_counter = self.make_combined_index_with_missing( - ['1', '2', '3']) - self.assertListRaises(transport.NoSuchFile, index.iter_entries_prefix, - [(b'1',)]) + index, reload_counter = self.make_combined_index_with_missing(["1", "2", "3"]) + self.assertListRaises( + transport.NoSuchFile, index.iter_entries_prefix, [(b"1",)] + ) def make_index_with_simple_nodes(self, name, num_nodes=1): """Make an index named after 'name', with keys named after 'name' too. @@ -1441,42 +1665,44 @@ def make_index_with_simple_nodes(self, name, num_nodes=1): Nodes will have a value of '' and no references. """ nodes = [ - ((f'index-{name}-key-{n}'.encode('ascii'),), b'', ()) - for n in range(1, num_nodes + 1)] - return self.make_index(f'index-{name}', 0, nodes=nodes) + ((f"index-{name}-key-{n}".encode("ascii"),), b"", ()) + for n in range(1, num_nodes + 1) + ] + return self.make_index(f"index-{name}", 0, nodes=nodes) def test_reorder_after_iter_entries(self): # Four indices: [key1] in idx1, [key2,key3] in idx2, [] in idx3, # [key4] in idx4. idx = _mod_index.CombinedGraphIndex([]) - idx.insert_index(0, self.make_index_with_simple_nodes('1'), b'1') - idx.insert_index(1, self.make_index_with_simple_nodes('2'), b'2') - idx.insert_index(2, self.make_index_with_simple_nodes('3'), b'3') - idx.insert_index(3, self.make_index_with_simple_nodes('4'), b'4') + idx.insert_index(0, self.make_index_with_simple_nodes("1"), b"1") + idx.insert_index(1, self.make_index_with_simple_nodes("2"), b"2") + idx.insert_index(2, self.make_index_with_simple_nodes("3"), b"3") + idx.insert_index(3, self.make_index_with_simple_nodes("4"), b"4") idx1, idx2, idx3, idx4 = idx._indices # Query a key from idx4 and idx2. - self.assertLength(2, list(idx.iter_entries( - [(b'index-4-key-1',), (b'index-2-key-1',)]))) + self.assertLength( + 2, list(idx.iter_entries([(b"index-4-key-1",), (b"index-2-key-1",)])) + ) # Now idx2 and idx4 should be moved to the front (and idx1 should # still be before idx3). self.assertEqual([idx2, idx4, idx1, idx3], idx._indices) - self.assertEqual([b'2', b'4', b'1', b'3'], idx._index_names) + self.assertEqual([b"2", b"4", b"1", b"3"], idx._index_names) def test_reorder_propagates_to_siblings(self): # Two CombinedGraphIndex objects, with the same number of indicies with # matching names. 
cgi1 = _mod_index.CombinedGraphIndex([]) cgi2 = _mod_index.CombinedGraphIndex([]) - cgi1.insert_index(0, self.make_index_with_simple_nodes('1-1'), 'one') - cgi1.insert_index(1, self.make_index_with_simple_nodes('1-2'), 'two') - cgi2.insert_index(0, self.make_index_with_simple_nodes('2-1'), 'one') - cgi2.insert_index(1, self.make_index_with_simple_nodes('2-2'), 'two') + cgi1.insert_index(0, self.make_index_with_simple_nodes("1-1"), "one") + cgi1.insert_index(1, self.make_index_with_simple_nodes("1-2"), "two") + cgi2.insert_index(0, self.make_index_with_simple_nodes("2-1"), "one") + cgi2.insert_index(1, self.make_index_with_simple_nodes("2-2"), "two") index2_1, index2_2 = cgi2._indices cgi1.set_sibling_indices([cgi2]) # Trigger a reordering in cgi1. cgi2 will be reordered as well. - list(cgi1.iter_entries([(b'index-1-2-key-1',)])) + list(cgi1.iter_entries([(b"index-1-2-key-1",)])) self.assertEqual([index2_2, index2_1], cgi2._indices) - self.assertEqual(['two', 'one'], cgi2._index_names) + self.assertEqual(["two", "one"], cgi2._index_names) def test_validate_reloads(self): idx, reload_counter = self.make_combined_index_with_missing() @@ -1484,7 +1710,7 @@ def test_validate_reloads(self): self.assertEqual([1, 1, 0], reload_counter) def test_validate_reloads_midway(self): - idx, reload_counter = self.make_combined_index_with_missing(['2']) + idx, reload_counter = self.make_combined_index_with_missing(["2"]) idx.validate() def test_validate_no_reload(self): @@ -1493,23 +1719,30 @@ def test_validate_no_reload(self): self.assertRaises(transport.NoSuchFile, idx.validate) def test_validate_reloads_and_fails(self): - idx, reload_counter = self.make_combined_index_with_missing( - ['1', '2', '3']) + idx, reload_counter = self.make_combined_index_with_missing(["1", "2", "3"]) self.assertRaises(transport.NoSuchFile, idx.validate) def test_find_ancestors_across_indexes(self): - key1 = (b'key-1',) - key2 = (b'key-2',) - key3 = (b'key-3',) - key4 = (b'key-4',) - index1 = self.make_index('12', ref_lists=1, nodes=[ - (key1, b'value', ([],)), - (key2, b'value', ([key1],)), - ]) - index2 = self.make_index('34', ref_lists=1, nodes=[ - (key3, b'value', ([key2],)), - (key4, b'value', ([key3],)), - ]) + key1 = (b"key-1",) + key2 = (b"key-2",) + key3 = (b"key-3",) + key4 = (b"key-4",) + index1 = self.make_index( + "12", + ref_lists=1, + nodes=[ + (key1, b"value", ([],)), + (key2, b"value", ([key1],)), + ], + ) + index2 = self.make_index( + "34", + ref_lists=1, + nodes=[ + (key3, b"value", ([key2],)), + (key4, b"value", ([key3],)), + ], + ) c_index = _mod_index.CombinedGraphIndex([index1, index2]) parent_map, missing_keys = c_index.find_ancestry([key1], 0) self.assertEqual({key1: ()}, parent_map) @@ -1522,17 +1755,25 @@ def test_find_ancestors_across_indexes(self): self.assertEqual(set(), missing_keys) def test_find_ancestors_missing_keys(self): - key1 = (b'key-1',) - key2 = (b'key-2',) - key3 = (b'key-3',) - key4 = (b'key-4',) - index1 = self.make_index('12', ref_lists=1, nodes=[ - (key1, b'value', ([],)), - (key2, b'value', ([key1],)), - ]) - index2 = self.make_index('34', ref_lists=1, nodes=[ - (key3, b'value', ([key2],)), - ]) + key1 = (b"key-1",) + key2 = (b"key-2",) + key3 = (b"key-3",) + key4 = (b"key-4",) + index1 = self.make_index( + "12", + ref_lists=1, + nodes=[ + (key1, b"value", ([],)), + (key2, b"value", ([key1],)), + ], + ) + index2 = self.make_index( + "34", + ref_lists=1, + nodes=[ + (key3, b"value", ([key2],)), + ], + ) c_index = _mod_index.CombinedGraphIndex([index1, index2]) # Searching for a key 
which is actually not present at all should # eventually converge @@ -1542,146 +1783,207 @@ def test_find_ancestors_missing_keys(self): def test_find_ancestors_no_indexes(self): c_index = _mod_index.CombinedGraphIndex([]) - key1 = (b'key-1',) + key1 = (b"key-1",) parent_map, missing_keys = c_index.find_ancestry([key1], 0) self.assertEqual({}, parent_map) self.assertEqual({key1}, missing_keys) def test_find_ancestors_ghost_parent(self): - key1 = (b'key-1',) - key2 = (b'key-2',) - key3 = (b'key-3',) - key4 = (b'key-4',) - index1 = self.make_index('12', ref_lists=1, nodes=[ - (key1, b'value', ([],)), - (key2, b'value', ([key1],)), - ]) - index2 = self.make_index('34', ref_lists=1, nodes=[ - (key4, b'value', ([key2, key3],)), - ]) + key1 = (b"key-1",) + key2 = (b"key-2",) + key3 = (b"key-3",) + key4 = (b"key-4",) + index1 = self.make_index( + "12", + ref_lists=1, + nodes=[ + (key1, b"value", ([],)), + (key2, b"value", ([key1],)), + ], + ) + index2 = self.make_index( + "34", + ref_lists=1, + nodes=[ + (key4, b"value", ([key2, key3],)), + ], + ) c_index = _mod_index.CombinedGraphIndex([index1, index2]) # Searching for a key which is actually not present at all should # eventually converge parent_map, missing_keys = c_index.find_ancestry([key4], 0) - self.assertEqual({key4: (key2, key3), key2: (key1,), key1: ()}, - parent_map) + self.assertEqual({key4: (key2, key3), key2: (key1,), key1: ()}, parent_map) self.assertEqual({key3}, missing_keys) def test__find_ancestors_empty_index(self): - idx = self.make_index('test', ref_lists=1, key_elements=1, nodes=[]) + idx = self.make_index("test", ref_lists=1, key_elements=1, nodes=[]) parent_map = {} missing_keys = set() - search_keys = idx._find_ancestors([(b'one',), (b'two',)], 0, parent_map, - missing_keys) + search_keys = idx._find_ancestors( + [(b"one",), (b"two",)], 0, parent_map, missing_keys + ) self.assertEqual(set(), search_keys) self.assertEqual({}, parent_map) - self.assertEqual({(b'one',), (b'two',)}, missing_keys) + self.assertEqual({(b"one",), (b"two",)}, missing_keys) class TestInMemoryGraphIndex(tests.TestCaseWithMemoryTransport): - def make_index(self, ref_lists=0, key_elements=1, nodes=None): if nodes is None: nodes = [] - result = _mod_index.InMemoryGraphIndex( - ref_lists, key_elements=key_elements) + result = _mod_index.InMemoryGraphIndex(ref_lists, key_elements=key_elements) result.add_nodes(nodes) return result def test_add_nodes_no_refs(self): index = self.make_index(0) - index.add_nodes([((b'name', ), b'data')]) - index.add_nodes([((b'name2', ), b''), ((b'name3', ), b'')]) - self.assertEqual({ - (index, (b'name', ), b'data'), - (index, (b'name2', ), b''), - (index, (b'name3', ), b''), - }, set(index.iter_all_entries())) + index.add_nodes([((b"name",), b"data")]) + index.add_nodes([((b"name2",), b""), ((b"name3",), b"")]) + self.assertEqual( + { + (index, (b"name",), b"data"), + (index, (b"name2",), b""), + (index, (b"name3",), b""), + }, + set(index.iter_all_entries()), + ) def test_add_nodes(self): index = self.make_index(1) - index.add_nodes([((b'name', ), b'data', ([],))]) - index.add_nodes([((b'name2', ), b'', ([],)), - ((b'name3', ), b'', ([(b'r', )],))]) - self.assertEqual({ - (index, (b'name', ), b'data', ((),)), - (index, (b'name2', ), b'', ((),)), - (index, (b'name3', ), b'', (((b'r', ), ), )), - }, set(index.iter_all_entries())) + index.add_nodes([((b"name",), b"data", ([],))]) + index.add_nodes([((b"name2",), b"", ([],)), ((b"name3",), b"", ([(b"r",)],))]) + self.assertEqual( + { + (index, (b"name",), b"data", ((),)), + 
(index, (b"name2",), b"", ((),)), + (index, (b"name3",), b"", (((b"r",),),)), + }, + set(index.iter_all_entries()), + ) def test_iter_all_entries_empty(self): index = self.make_index() self.assertEqual([], list(index.iter_all_entries())) def test_iter_all_entries_simple(self): - index = self.make_index(nodes=[((b'name', ), b'data')]) - self.assertEqual([(index, (b'name', ), b'data')], - list(index.iter_all_entries())) + index = self.make_index(nodes=[((b"name",), b"data")]) + self.assertEqual([(index, (b"name",), b"data")], list(index.iter_all_entries())) def test_iter_all_entries_references(self): - index = self.make_index(1, nodes=[ - ((b'name', ), b'data', ([(b'ref', )], )), - ((b'ref', ), b'refdata', ([], ))]) - self.assertEqual({(index, (b'name', ), b'data', (((b'ref', ),),)), - (index, (b'ref', ), b'refdata', ((), ))}, - set(index.iter_all_entries())) + index = self.make_index( + 1, + nodes=[ + ((b"name",), b"data", ([(b"ref",)],)), + ((b"ref",), b"refdata", ([],)), + ], + ) + self.assertEqual( + { + (index, (b"name",), b"data", (((b"ref",),),)), + (index, (b"ref",), b"refdata", ((),)), + }, + set(index.iter_all_entries()), + ) def test_iteration_absent_skipped(self): - index = self.make_index(1, nodes=[ - ((b'name', ), b'data', ([(b'ref', )], ))]) - self.assertEqual({(index, (b'name', ), b'data', (((b'ref',),),))}, - set(index.iter_all_entries())) - self.assertEqual({(index, (b'name', ), b'data', (((b'ref',),),))}, - set(index.iter_entries([(b'name', )]))) - self.assertEqual([], list(index.iter_entries([(b'ref', )]))) + index = self.make_index(1, nodes=[((b"name",), b"data", ([(b"ref",)],))]) + self.assertEqual( + {(index, (b"name",), b"data", (((b"ref",),),))}, + set(index.iter_all_entries()), + ) + self.assertEqual( + {(index, (b"name",), b"data", (((b"ref",),),))}, + set(index.iter_entries([(b"name",)])), + ) + self.assertEqual([], list(index.iter_entries([(b"ref",)]))) def test_iter_all_keys(self): - index = self.make_index(1, nodes=[ - ((b'name', ), b'data', ([(b'ref', )], )), - ((b'ref', ), b'refdata', ([], ))]) - self.assertEqual({(index, (b'name', ), b'data', (((b'ref',),),)), - (index, (b'ref', ), b'refdata', ((), ))}, - set(index.iter_entries([(b'name', ), (b'ref', )]))) + index = self.make_index( + 1, + nodes=[ + ((b"name",), b"data", ([(b"ref",)],)), + ((b"ref",), b"refdata", ([],)), + ], + ) + self.assertEqual( + { + (index, (b"name",), b"data", (((b"ref",),),)), + (index, (b"ref",), b"refdata", ((),)), + }, + set(index.iter_entries([(b"name",), (b"ref",)])), + ) def test_iter_key_prefix_1_key_element_no_refs(self): - index = self.make_index(nodes=[ - ((b'name', ), b'data'), - ((b'ref', ), b'refdata')]) - self.assertEqual({(index, (b'name', ), b'data'), - (index, (b'ref', ), b'refdata')}, - set(index.iter_entries_prefix([(b'name', ), (b'ref', )]))) + index = self.make_index(nodes=[((b"name",), b"data"), ((b"ref",), b"refdata")]) + self.assertEqual( + {(index, (b"name",), b"data"), (index, (b"ref",), b"refdata")}, + set(index.iter_entries_prefix([(b"name",), (b"ref",)])), + ) def test_iter_key_prefix_1_key_element_refs(self): - index = self.make_index(1, nodes=[ - ((b'name', ), b'data', ([(b'ref', )], )), - ((b'ref', ), b'refdata', ([], ))]) - self.assertEqual({(index, (b'name', ), b'data', (((b'ref',),),)), - (index, (b'ref', ), b'refdata', ((), ))}, - set(index.iter_entries_prefix([(b'name', ), (b'ref', )]))) + index = self.make_index( + 1, + nodes=[ + ((b"name",), b"data", ([(b"ref",)],)), + ((b"ref",), b"refdata", ([],)), + ], + ) + self.assertEqual( + { + (index, 
(b"name",), b"data", (((b"ref",),),)), + (index, (b"ref",), b"refdata", ((),)), + }, + set(index.iter_entries_prefix([(b"name",), (b"ref",)])), + ) def test_iter_key_prefix_2_key_element_no_refs(self): - index = self.make_index(key_elements=2, nodes=[ - ((b'name', b'fin1'), b'data'), - ((b'name', b'fin2'), b'beta'), - ((b'ref', b'erence'), b'refdata')]) - self.assertEqual({(index, (b'name', b'fin1'), b'data'), - (index, (b'ref', b'erence'), b'refdata')}, - set(index.iter_entries_prefix([(b'name', b'fin1'), (b'ref', b'erence')]))) - self.assertEqual({(index, (b'name', b'fin1'), b'data'), - (index, (b'name', b'fin2'), b'beta')}, - set(index.iter_entries_prefix([(b'name', None)]))) + index = self.make_index( + key_elements=2, + nodes=[ + ((b"name", b"fin1"), b"data"), + ((b"name", b"fin2"), b"beta"), + ((b"ref", b"erence"), b"refdata"), + ], + ) + self.assertEqual( + { + (index, (b"name", b"fin1"), b"data"), + (index, (b"ref", b"erence"), b"refdata"), + }, + set(index.iter_entries_prefix([(b"name", b"fin1"), (b"ref", b"erence")])), + ) + self.assertEqual( + { + (index, (b"name", b"fin1"), b"data"), + (index, (b"name", b"fin2"), b"beta"), + }, + set(index.iter_entries_prefix([(b"name", None)])), + ) def test_iter_key_prefix_2_key_element_refs(self): - index = self.make_index(1, key_elements=2, nodes=[ - ((b'name', b'fin1'), b'data', ([(b'ref', b'erence')], )), - ((b'name', b'fin2'), b'beta', ([], )), - ((b'ref', b'erence'), b'refdata', ([], ))]) - self.assertEqual({(index, (b'name', b'fin1'), b'data', (((b'ref', b'erence'),),)), - (index, (b'ref', b'erence'), b'refdata', ((), ))}, - set(index.iter_entries_prefix([(b'name', b'fin1'), (b'ref', b'erence')]))) - self.assertEqual({(index, (b'name', b'fin1'), b'data', (((b'ref', b'erence'),),)), - (index, (b'name', b'fin2'), b'beta', ((), ))}, - set(index.iter_entries_prefix([(b'name', None)]))) + index = self.make_index( + 1, + key_elements=2, + nodes=[ + ((b"name", b"fin1"), b"data", ([(b"ref", b"erence")],)), + ((b"name", b"fin2"), b"beta", ([],)), + ((b"ref", b"erence"), b"refdata", ([],)), + ], + ) + self.assertEqual( + { + (index, (b"name", b"fin1"), b"data", (((b"ref", b"erence"),),)), + (index, (b"ref", b"erence"), b"refdata", ((),)), + }, + set(index.iter_entries_prefix([(b"name", b"fin1"), (b"ref", b"erence")])), + ) + self.assertEqual( + { + (index, (b"name", b"fin1"), b"data", (((b"ref", b"erence"),),)), + (index, (b"name", b"fin2"), b"beta", ((),)), + }, + set(index.iter_entries_prefix([(b"name", None)])), + ) def test_iter_nothing_empty(self): index = self.make_index() @@ -1689,18 +1991,18 @@ def test_iter_nothing_empty(self): def test_iter_missing_entry_empty(self): index = self.make_index() - self.assertEqual([], list(index.iter_entries([b'a']))) + self.assertEqual([], list(index.iter_entries([b"a"]))) def test_key_count_empty(self): index = self.make_index() self.assertEqual(0, index.key_count()) def test_key_count_one(self): - index = self.make_index(nodes=[((b'name', ), b'')]) + index = self.make_index(nodes=[((b"name",), b"")]) self.assertEqual(1, index.key_count()) def test_key_count_two(self): - index = self.make_index(nodes=[((b'name', ), b''), ((b'foo', ), b'')]) + index = self.make_index(nodes=[((b"name",), b""), ((b"foo",), b"")]) self.assertEqual(2, index.key_count()) def test_validate_empty(self): @@ -1708,107 +2010,165 @@ def test_validate_empty(self): index.validate() def test_validate_no_refs_content(self): - index = self.make_index(nodes=[((b'key', ), b'value')]) + index = self.make_index(nodes=[((b"key",), 
b"value")]) index.validate() class TestGraphIndexPrefixAdapter(tests.TestCaseWithMemoryTransport): - - def make_index(self, ref_lists=1, key_elements=2, nodes=None, - add_callback=False): + def make_index(self, ref_lists=1, key_elements=2, nodes=None, add_callback=False): if nodes is None: nodes = [] - result = _mod_index.InMemoryGraphIndex( - ref_lists, key_elements=key_elements) + result = _mod_index.InMemoryGraphIndex(ref_lists, key_elements=key_elements) result.add_nodes(nodes) if add_callback: add_nodes_callback = result.add_nodes else: add_nodes_callback = None adapter = _mod_index.GraphIndexPrefixAdapter( - result, (b'prefix', ), key_elements - 1, - add_nodes_callback=add_nodes_callback) + result, + (b"prefix",), + key_elements - 1, + add_nodes_callback=add_nodes_callback, + ) return result, adapter def test_add_node(self): index, adapter = self.make_index(add_callback=True) - adapter.add_node((b'key',), b'value', (((b'ref',),),)) - self.assertEqual({(index, (b'prefix', b'key'), b'value', - (((b'prefix', b'ref'),),))}, - set(index.iter_all_entries())) + adapter.add_node((b"key",), b"value", (((b"ref",),),)) + self.assertEqual( + {(index, (b"prefix", b"key"), b"value", (((b"prefix", b"ref"),),))}, + set(index.iter_all_entries()), + ) def test_add_nodes(self): index, adapter = self.make_index(add_callback=True) - adapter.add_nodes(( - ((b'key',), b'value', (((b'ref',),),)), - ((b'key2',), b'value2', ((),)), - )) - self.assertEqual({ - (index, (b'prefix', b'key2'), b'value2', ((),)), - (index, (b'prefix', b'key'), b'value', (((b'prefix', b'ref'),),)) + adapter.add_nodes( + ( + ((b"key",), b"value", (((b"ref",),),)), + ((b"key2",), b"value2", ((),)), + ) + ) + self.assertEqual( + { + (index, (b"prefix", b"key2"), b"value2", ((),)), + (index, (b"prefix", b"key"), b"value", (((b"prefix", b"ref"),),)), }, - set(index.iter_all_entries())) + set(index.iter_all_entries()), + ) def test_construct(self): idx = _mod_index.InMemoryGraphIndex() - _mod_index.GraphIndexPrefixAdapter(idx, (b'prefix', ), 1) + _mod_index.GraphIndexPrefixAdapter(idx, (b"prefix",), 1) def test_construct_with_callback(self): idx = _mod_index.InMemoryGraphIndex() - _mod_index.GraphIndexPrefixAdapter(idx, (b'prefix', ), 1, - idx.add_nodes) + _mod_index.GraphIndexPrefixAdapter(idx, (b"prefix",), 1, idx.add_nodes) def test_iter_all_entries_cross_prefix_map_errors(self): - index, adapter = self.make_index(nodes=[ - ((b'prefix', b'key1'), b'data1', (((b'prefixaltered', b'key2'),),))]) - self.assertRaises(_mod_index.BadIndexData, list, - adapter.iter_all_entries()) + index, adapter = self.make_index( + nodes=[((b"prefix", b"key1"), b"data1", (((b"prefixaltered", b"key2"),),))] + ) + self.assertRaises(_mod_index.BadIndexData, list, adapter.iter_all_entries()) def test_iter_all_entries(self): - index, adapter = self.make_index(nodes=[ - ((b'notprefix', b'key1'), b'data', ((), )), - ((b'prefix', b'key1'), b'data1', ((), )), - ((b'prefix', b'key2'), b'data2', (((b'prefix', b'key1'),),))]) - self.assertEqual({(index, (b'key1', ), b'data1', ((),)), - (index, (b'key2', ), b'data2', (((b'key1',),),))}, - set(adapter.iter_all_entries())) + index, adapter = self.make_index( + nodes=[ + ((b"notprefix", b"key1"), b"data", ((),)), + ((b"prefix", b"key1"), b"data1", ((),)), + ((b"prefix", b"key2"), b"data2", (((b"prefix", b"key1"),),)), + ] + ) + self.assertEqual( + { + (index, (b"key1",), b"data1", ((),)), + (index, (b"key2",), b"data2", (((b"key1",),),)), + }, + set(adapter.iter_all_entries()), + ) def test_iter_entries(self): - index, 
adapter = self.make_index(nodes=[ - ((b'notprefix', b'key1'), b'data', ((), )), - ((b'prefix', b'key1'), b'data1', ((), )), - ((b'prefix', b'key2'), b'data2', (((b'prefix', b'key1'),),))]) + index, adapter = self.make_index( + nodes=[ + ((b"notprefix", b"key1"), b"data", ((),)), + ((b"prefix", b"key1"), b"data1", ((),)), + ((b"prefix", b"key2"), b"data2", (((b"prefix", b"key1"),),)), + ] + ) # ask for many - get all - self.assertEqual({(index, (b'key1', ), b'data1', ((),)), - (index, (b'key2', ), b'data2', (((b'key1', ),),))}, - set(adapter.iter_entries([(b'key1', ), (b'key2', )]))) + self.assertEqual( + { + (index, (b"key1",), b"data1", ((),)), + (index, (b"key2",), b"data2", (((b"key1",),),)), + }, + set(adapter.iter_entries([(b"key1",), (b"key2",)])), + ) # ask for one, get one - self.assertEqual({(index, (b'key1', ), b'data1', ((),))}, - set(adapter.iter_entries([(b'key1', )]))) + self.assertEqual( + {(index, (b"key1",), b"data1", ((),))}, + set(adapter.iter_entries([(b"key1",)])), + ) # ask for missing, get none - self.assertEqual(set(), - set(adapter.iter_entries([(b'key3', )]))) + self.assertEqual(set(), set(adapter.iter_entries([(b"key3",)]))) def test_iter_entries_prefix(self): - index, adapter = self.make_index(key_elements=3, nodes=[ - ((b'notprefix', b'foo', b'key1'), b'data', ((), )), - ((b'prefix', b'prefix2', b'key1'), b'data1', ((), )), - ((b'prefix', b'prefix2', b'key2'), b'data2', (((b'prefix', b'prefix2', b'key1'),),))]) + index, adapter = self.make_index( + key_elements=3, + nodes=[ + ((b"notprefix", b"foo", b"key1"), b"data", ((),)), + ((b"prefix", b"prefix2", b"key1"), b"data1", ((),)), + ( + (b"prefix", b"prefix2", b"key2"), + b"data2", + (((b"prefix", b"prefix2", b"key1"),),), + ), + ], + ) # ask for a prefix, get the results for just that prefix, adjusted. 
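(Editorial aside, not part of the patch: the adapter behaviour exercised here, shown as a standalone sketch using only calls that appear in these tests; the import path is an assumption.)

    from breezy.bzr import index as _mod_index  # import path assumed

    backing = _mod_index.InMemoryGraphIndex(1, key_elements=2)
    adapter = _mod_index.GraphIndexPrefixAdapter(
        backing, (b"prefix",), 1, add_nodes_callback=backing.add_nodes)
    adapter.add_node((b"key",), b"value", (((b"ref",),),))

    # The backing index stores the prefix on the key and on every reference:
    #   (backing, (b"prefix", b"key"), b"value", (((b"prefix", b"ref"),),))
    print(list(backing.iter_all_entries()))

    # Reading back through the adapter strips the prefix again:
    #   (backing, (b"key",), b"value", (((b"ref",),),))
    print(list(adapter.iter_entries([(b"key",)])))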
- self.assertEqual({(index, (b'prefix2', b'key1', ), b'data1', ((),)), - (index, (b'prefix2', b'key2', ), b'data2', (((b'prefix2', b'key1', ),),))}, - set(adapter.iter_entries_prefix([(b'prefix2', None)]))) + self.assertEqual( + { + ( + index, + ( + b"prefix2", + b"key1", + ), + b"data1", + ((),), + ), + ( + index, + ( + b"prefix2", + b"key2", + ), + b"data2", + ( + ( + ( + b"prefix2", + b"key1", + ), + ), + ), + ), + }, + set(adapter.iter_entries_prefix([(b"prefix2", None)])), + ) def test_key_count_no_matching_keys(self): - index, adapter = self.make_index(nodes=[ - ((b'notprefix', b'key1'), b'data', ((), ))]) + index, adapter = self.make_index( + nodes=[((b"notprefix", b"key1"), b"data", ((),))] + ) self.assertEqual(0, adapter.key_count()) def test_key_count_some_keys(self): - index, adapter = self.make_index(nodes=[ - ((b'notprefix', b'key1'), b'data', ((), )), - ((b'prefix', b'key1'), b'data1', ((), )), - ((b'prefix', b'key2'), b'data2', (((b'prefix', b'key1'),),))]) + index, adapter = self.make_index( + nodes=[ + ((b"notprefix", b"key1"), b"data", ((),)), + ((b"prefix", b"key1"), b"data1", ((),)), + ((b"prefix", b"key2"), b"data2", (((b"prefix", b"key1"),),)), + ] + ) self.assertEqual(2, adapter.key_count()) def test_validate(self): @@ -1816,7 +2176,8 @@ def test_validate(self): calls = [] def validate(): - calls.append('called') + calls.append("called") + index.validate = validate adapter.validate() - self.assertEqual(['called'], calls) + self.assertEqual(["called"], calls) diff --git a/breezy/bzr/tests/test_inv.py b/breezy/bzr/tests/test_inv.py index 17d0d66dcf..d7e5dbf72b 100644 --- a/breezy/bzr/tests/test_inv.py +++ b/breezy/bzr/tests/test_inv.py @@ -41,8 +41,8 @@ def delta_application_scenarios(): scenarios = [ - ('Inventory', {'apply_delta': apply_inventory_Inventory}), - ] + ("Inventory", {"apply_delta": apply_inventory_Inventory}), + ] # Working tree basis delta application # Repository add_inv_by_delta. # Reduce form of the per_repository test logic - that logic needs to be @@ -50,9 +50,15 @@ def delta_application_scenarios(): # just creating trees. 
for _, format in repository.format_registry.iteritems(): if format.supports_full_versioned_files: - scenarios.append((str(format.__name__), { - 'apply_delta': apply_inventory_Repository_add_inventory_by_delta, - 'format': format})) + scenarios.append( + ( + str(format.__name__), + { + "apply_delta": apply_inventory_Repository_add_inventory_by_delta, + "format": format, + }, + ) + ) for getter in workingtree.format_registry._get_all_lazy(): try: format = getter() @@ -64,20 +70,24 @@ def delta_application_scenarios(): if not repo_fmt.supports_full_versioned_files: continue scenarios.append( - (str(format.__class__.__name__) + ".update_basis_by_delta", { - 'apply_delta': apply_inventory_WT_basis, - 'format': format})) + ( + str(format.__class__.__name__) + ".update_basis_by_delta", + {"apply_delta": apply_inventory_WT_basis, "format": format}, + ) + ) scenarios.append( - (str(format.__class__.__name__) + ".apply_inventory_delta", { - 'apply_delta': apply_inventory_WT, - 'format': format})) + ( + str(format.__class__.__name__) + ".apply_inventory_delta", + {"apply_delta": apply_inventory_WT, "format": format}, + ) + ) return scenarios def create_texts_for_inv(repo, inv): for _path, ie in inv.iter_entries(): - if getattr(ie, 'text_size', None): - lines = [b'a' * ie.text_size] + if getattr(ie, "text_size", None): + lines = [b"a" * ie.text_size] else: lines = [] repo.texts.add_lines((ie.file_id, ie.revision), [], lines) @@ -103,8 +113,7 @@ def apply_inventory_WT(self, basis, delta, invalid_delta=True): :param delta: The inventory delta to apply: :return: An inventory resulting from the application. """ - control = self.make_controldir( - 'tree', format=self.format._matchingcontroldir) + control = self.make_controldir("tree", format=self.format._matchingcontroldir) control.create_repository() control.create_branch() tree = self.format.initialize(control) @@ -125,26 +134,40 @@ def apply_inventory_WT(self, basis, delta, invalid_delta=True): def _create_repo_revisions(repo, basis, delta, invalid_delta): with repository.WriteGroup(repo): - rev = revision.Revision(b'basis', timestamp=0, timezone=None, - message="", committer="foo@example.com", - parent_ids=[], properties={}, inventory_sha1=None) - basis.revision_id = b'basis' + rev = revision.Revision( + b"basis", + timestamp=0, + timezone=None, + message="", + committer="foo@example.com", + parent_ids=[], + properties={}, + inventory_sha1=None, + ) + basis.revision_id = b"basis" create_texts_for_inv(repo, basis) - repo.add_revision(b'basis', rev, basis) + repo.add_revision(b"basis", rev, basis) if invalid_delta: # We don't want to apply the delta to the basis, because we expect # the delta is invalid. 
result_inv = basis - result_inv.revision_id = b'result' + result_inv.revision_id = b"result" target_entries = None else: - result_inv = basis.create_by_apply_delta(delta, b'result') + result_inv = basis.create_by_apply_delta(delta, b"result") create_texts_for_inv(repo, result_inv) target_entries = list(result_inv.iter_entries_by_dir()) - rev = revision.Revision(b'result', timestamp=0, timezone=None, - message="", committer="foo@example.com", - parent_ids=[], properties={}, inventory_sha1=None) - repo.add_revision(b'result', rev, result_inv) + rev = revision.Revision( + b"result", + timestamp=0, + timezone=None, + message="", + committer="foo@example.com", + parent_ids=[], + properties={}, + inventory_sha1=None, + ) + repo.add_revision(b"result", rev, result_inv) return target_entries @@ -157,17 +180,17 @@ def _get_basis_entries(tree): def _populate_different_tree(tree, basis, delta): """Put all entries into tree, but at a unique location.""" added_ids = set() - tree.add(['unique-dir'], ['directory'], [b'unique-dir-id']) + tree.add(["unique-dir"], ["directory"], [b"unique-dir-id"]) for _path, ie in basis.iter_entries_by_dir(): if ie.file_id in added_ids: continue # We want a unique path for each of these, we use the file-id - tree.add(['unique-dir/' + ie.file_id], [ie.kind], [ie.file_id]) + tree.add(["unique-dir/" + ie.file_id], [ie.kind], [ie.file_id]) added_ids.add(ie.file_id) for _old_path, _new_path, file_id, ie in delta: if file_id in added_ids: continue - tree.add(['unique-dir/' + file_id], [ie.kind], [file_id]) + tree.add(["unique-dir/" + file_id], [ie.kind], [file_id]) def apply_inventory_WT_basis(test, basis, delta, invalid_delta=True): @@ -183,25 +206,25 @@ def apply_inventory_WT_basis(test, basis, delta, invalid_delta=True): :param delta: The inventory delta to apply: :return: An inventory resulting from the application. """ - control = test.make_controldir( - 'tree', format=test.format._matchingcontroldir) + control = test.make_controldir("tree", format=test.format._matchingcontroldir) control.create_repository() control.create_branch() tree = test.format.initialize(control) tree.lock_write() try: - target_entries = _create_repo_revisions(tree.branch.repository, basis, - delta, invalid_delta) + target_entries = _create_repo_revisions( + tree.branch.repository, basis, delta, invalid_delta + ) # Set the basis state as the trees current state tree._write_inventory(basis) # This reads basis from the repo and puts it into the tree's local # cache, if it has one. - tree.set_parent_ids([b'basis']) + tree.set_parent_ids([b"basis"]) finally: tree.unlock() # Fresh lock, reads disk again. with tree.lock_write(): - tree.update_basis_by_delta(b'result', delta) + tree.update_basis_by_delta(b"result", delta) if not invalid_delta: tree._validate() # reload tree - ensure we get what was written. @@ -216,8 +239,9 @@ def apply_inventory_WT_basis(test, basis, delta, invalid_delta=True): return basis_inv -def apply_inventory_Repository_add_inventory_by_delta(self, basis, delta, - invalid_delta=True): +def apply_inventory_Repository_add_inventory_by_delta( + self, basis, delta, invalid_delta=True +): """Apply delta to basis and return the result. This inserts basis as a whole inventory and then uses @@ -228,33 +252,37 @@ def apply_inventory_Repository_add_inventory_by_delta(self, basis, delta, :return: An inventory resulting from the application. 
""" format = self.format() - control = self.make_controldir('tree', format=format._matchingcontroldir) + control = self.make_controldir("tree", format=format._matchingcontroldir) repo = format.initialize(control) with repo.lock_write(), repository.WriteGroup(repo): rev = revision.Revision( - b'basis', timestamp=0, timezone=None, message="", + b"basis", + timestamp=0, + timezone=None, + message="", committer="foo@example.com", - parent_ids=[], properties={}, inventory_sha1=None) - basis.revision_id = b'basis' + parent_ids=[], + properties={}, + inventory_sha1=None, + ) + basis.revision_id = b"basis" create_texts_for_inv(repo, basis) - repo.add_revision(b'basis', rev, basis) + repo.add_revision(b"basis", rev, basis) with repo.lock_write(), repository.WriteGroup(repo): - repo.add_inventory_by_delta( - b'basis', delta, b'result', [b'basis']) + repo.add_inventory_by_delta(b"basis", delta, b"result", [b"basis"]) # Fresh lock, reads disk again. repo = repo.controldir.open_repository() repo.lock_read() self.addCleanup(repo.unlock) - return repo.get_inventory(b'result') + return repo.get_inventory(b"result") class TestInventoryUpdates(TestCase): - def test_creation_from_root_id(self): # iff a root id is passed to the constructor, a root directory is made - inv = inventory.Inventory(root_id=b'tree-root') + inv = inventory.Inventory(root_id=b"tree-root") self.assertNotEqual(None, inv.root) - self.assertEqual(b'tree-root', inv.root.file_id) + self.assertEqual(b"tree-root", inv.root.file_id) def test_add_path_of_root(self): # if no root id is given at creation time, there is no root directory @@ -262,25 +290,25 @@ def test_add_path_of_root(self): self.assertIs(None, inv.root) # add a root entry by adding its path ie = inv.add_path("", "directory", b"my-root") - ie.revision = b'test-rev' + ie.revision = b"test-rev" self.assertEqual(b"my-root", ie.file_id) self.assertIs(ie, inv.root) def test_add_path(self): - inv = inventory.Inventory(root_id=b'tree_root') - ie = inv.add_path('hello', 'file', b'hello-id') - self.assertEqual(b'hello-id', ie.file_id) - self.assertEqual('file', ie.kind) + inv = inventory.Inventory(root_id=b"tree_root") + ie = inv.add_path("hello", "file", b"hello-id") + self.assertEqual(b"hello-id", ie.file_id) + self.assertEqual("file", ie.kind) def test_copy(self): """Make sure copy() works and creates a deep copy.""" - inv = inventory.Inventory(root_id=b'some-tree-root') - ie = inv.add_path('hello', 'file', b'hello-id') + inv = inventory.Inventory(root_id=b"some-tree-root") + ie = inv.add_path("hello", "file", b"hello-id") inv2 = inv.copy() - inv.rename_id(b'some-tree-root', b'some-new-root') - ie.name = 'file2' - self.assertEqual(b'some-tree-root', inv2.root.file_id) - self.assertEqual('hello', inv2.get_entry(b'hello-id').name) + inv.rename_id(b"some-tree-root", b"some-new-root") + ie.name = "file2" + self.assertEqual(b"some-tree-root", inv2.root.file_id) + self.assertEqual("hello", inv2.get_entry(b"hello-id").name) def test_copy_empty(self): """Make sure an empty inventory can be copied.""" @@ -290,36 +318,44 @@ def test_copy_empty(self): def test_copy_copies_root_revision(self): """Make sure the revision of the root gets copied.""" - inv = inventory.Inventory(root_id=b'someroot') - inv.root.revision = b'therev' + inv = inventory.Inventory(root_id=b"someroot") + inv.root.revision = b"therev" inv2 = inv.copy() - self.assertEqual(b'someroot', inv2.root.file_id) - self.assertEqual(b'therev', inv2.root.revision) + self.assertEqual(b"someroot", inv2.root.file_id) + 
self.assertEqual(b"therev", inv2.root.revision) def test_create_tree_reference(self): - inv = inventory.Inventory(b'tree-root-123') - inv.add(TreeReference( - b'nested-id', 'nested', parent_id=b'tree-root-123', - revision=b'rev', reference_revision=b'rev2')) + inv = inventory.Inventory(b"tree-root-123") + inv.add( + TreeReference( + b"nested-id", + "nested", + parent_id=b"tree-root-123", + revision=b"rev", + reference_revision=b"rev2", + ) + ) def test_error_encoding(self): - inv = inventory.Inventory(b'tree-root') - inv.add(InventoryFile(b'a-id', '\u1234', b'tree-root')) - e = self.assertRaises(errors.InconsistentDelta, inv.add, - InventoryFile(b'b-id', '\u1234', b'tree-root')) - self.assertContainsRe(str(e), '\\u1234') + inv = inventory.Inventory(b"tree-root") + inv.add(InventoryFile(b"a-id", "\u1234", b"tree-root")) + e = self.assertRaises( + errors.InconsistentDelta, + inv.add, + InventoryFile(b"b-id", "\u1234", b"tree-root"), + ) + self.assertContainsRe(str(e), "\\u1234") def test_add_recursive(self): - parent = InventoryDirectory(b'src-id', 'src', b'tree-root') - child = InventoryFile(b'hello-id', 'hello.c', b'src-id') - inv = inventory.Inventory(b'tree-root') + parent = InventoryDirectory(b"src-id", "src", b"tree-root") + child = InventoryFile(b"hello-id", "hello.c", b"src-id") + inv = inventory.Inventory(b"tree-root") inv.add(parent) inv.add(child) - self.assertEqual('src/hello.c', inv.id2path(b'hello-id')) + self.assertEqual("src/hello.c", inv.id2path(b"hello-id")) class TestDeltaApplication(TestCaseWithTransport): - scenarios = delta_application_scenarios() def get_empty_inventory(self, reference_inv=None): @@ -338,14 +374,14 @@ def get_empty_inventory(self, reference_inv=None): if reference_inv is not None: inv.root.revision = reference_inv.root.revision else: - inv.root.revision = b'basis' + inv.root.revision = b"basis" return inv - def make_file_ie(self, file_id=b'file-id', name='name', parent_id=None): + def make_file_ie(self, file_id=b"file-id", name="name", parent_id=None): ie_file = inventory.InventoryFile(file_id, name, parent_id) - ie_file.revision = b'result' + ie_file.revision = b"result" ie_file.text_size = 0 - ie_file.text_sha1 = b'' + ie_file.text_sha1 = b"" return ie_file def test_empty_delta(self): @@ -357,38 +393,36 @@ def test_empty_delta(self): def test_repeated_file_id(self): inv = self.get_empty_inventory() - file1 = inventory.InventoryFile(b'id', 'path1', inv.root.file_id) - file1.revision = b'result' + file1 = inventory.InventoryFile(b"id", "path1", inv.root.file_id) + file1.revision = b"result" file1.text_size = 0 file1.text_sha1 = b"" file2 = file1.copy() - file2.name = 'path2' + file2.name = "path2" delta = InventoryDelta( - [(None, 'path1', b'id', file1), - (None, 'path2', b'id', file2)]) - self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, - inv, delta) + [(None, "path1", b"id", file1), (None, "path2", b"id", file2)] + ) + self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, inv, delta) def test_repeated_new_path(self): inv = self.get_empty_inventory() - file1 = inventory.InventoryFile(b'id1', 'path', inv.root.file_id) - file1.revision = b'result' + file1 = inventory.InventoryFile(b"id1", "path", inv.root.file_id) + file1.revision = b"result" file1.text_size = 0 file1.text_sha1 = b"" - file2 = inventory.InventoryFile(b'id2', 'path', inv.root.file_id) - file2.revision = b'result' + file2 = inventory.InventoryFile(b"id2", "path", inv.root.file_id) + file2.revision = b"result" file2.text_size = 0 file2.text_sha1 = b"" delta 
= InventoryDelta( - [(None, 'path', b'id1', file1), - (None, 'path', b'id2', file2)]) - self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, - inv, delta) + [(None, "path", b"id1", file1), (None, "path", b"id2", file2)] + ) + self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, inv, delta) def test_repeated_old_path(self): inv = self.get_empty_inventory() - file1 = inventory.InventoryFile(b'id1', 'path', inv.root.file_id) - file1.revision = b'result' + file1 = inventory.InventoryFile(b"id1", "path", inv.root.file_id) + file1.revision = b"result" file1.text_size = 0 file1.text_sha1 = b"" # We can't *create* a source inventory with the same path, but @@ -397,95 +431,88 @@ def test_repeated_old_path(self): # And the path for one of the file ids doesn't match the source # location. Alternatively, we could have a repeated fileid, but that # is separately checked for. - file2 = inventory.InventoryFile(b'id2', 'path2', inv.root.file_id) - file2.revision = b'result' + file2 = inventory.InventoryFile(b"id2", "path2", inv.root.file_id) + file2.revision = b"result" file2.text_size = 0 file2.text_sha1 = b"" inv.add(file1) inv.add(file2) - delta = InventoryDelta([('path', None, b'id1', None), ('path', None, b'id2', None)]) - self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, - inv, delta) + delta = InventoryDelta( + [("path", None, b"id1", None), ("path", None, b"id2", None)] + ) + self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, inv, delta) def test_mismatched_id_entry_id(self): inv = self.get_empty_inventory() - file1 = inventory.InventoryFile(b'id1', 'path', inv.root.file_id) - file1.revision = b'result' + file1 = inventory.InventoryFile(b"id1", "path", inv.root.file_id) + file1.revision = b"result" file1.text_size = 0 file1.text_sha1 = b"" - delta = InventoryDelta([(None, 'path', b'id', file1)]) - self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, - inv, delta) + delta = InventoryDelta([(None, "path", b"id", file1)]) + self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, inv, delta) def test_mismatched_new_path_entry_None(self): inv = self.get_empty_inventory() - delta = InventoryDelta([(None, 'path', b'id', None)]) - self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, - inv, delta) + delta = InventoryDelta([(None, "path", b"id", None)]) + self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, inv, delta) def test_mismatched_new_path_None_entry(self): inv = self.get_empty_inventory() - file1 = inventory.InventoryFile(b'id1', 'path', inv.root.file_id) - file1.revision = b'result' + file1 = inventory.InventoryFile(b"id1", "path", inv.root.file_id) + file1.revision = b"result" file1.text_size = 0 file1.text_sha1 = b"" - delta = InventoryDelta([("path", None, b'id1', file1)]) - self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, - inv, delta) + delta = InventoryDelta([("path", None, b"id1", file1)]) + self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, inv, delta) def test_parent_is_not_directory(self): inv = self.get_empty_inventory() - file1 = inventory.InventoryFile(b'id1', 'path', inv.root.file_id) - file1.revision = b'result' + file1 = inventory.InventoryFile(b"id1", "path", inv.root.file_id) + file1.revision = b"result" file1.text_size = 0 file1.text_sha1 = b"" - file2 = inventory.InventoryFile(b'id2', 'path2', b'id1') - file2.revision = b'result' + file2 = inventory.InventoryFile(b"id2", "path2", b"id1") + file2.revision = b"result" 
file2.text_size = 0 file2.text_sha1 = b"" inv.add(file1) - delta = InventoryDelta([(None, 'path/path2', b'id2', file2)]) - self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, - inv, delta) + delta = InventoryDelta([(None, "path/path2", b"id2", file2)]) + self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, inv, delta) def test_parent_is_missing(self): inv = self.get_empty_inventory() - file2 = inventory.InventoryFile(b'id2', 'path2', b'missingparent') - file2.revision = b'result' + file2 = inventory.InventoryFile(b"id2", "path2", b"missingparent") + file2.revision = b"result" file2.text_size = 0 file2.text_sha1 = b"" - delta = InventoryDelta([(None, 'path/path2', b'id2', file2)]) - self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, - inv, delta) + delta = InventoryDelta([(None, "path/path2", b"id2", file2)]) + self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, inv, delta) def test_new_parent_path_has_wrong_id(self): inv = self.get_empty_inventory() - parent1 = inventory.InventoryDirectory(b'p-1', 'dir', inv.root.file_id) - parent1.revision = b'result' - parent2 = inventory.InventoryDirectory( - b'p-2', 'dir2', inv.root.file_id) - parent2.revision = b'result' - file1 = inventory.InventoryFile(b'id', 'path', b'p-2') - file1.revision = b'result' + parent1 = inventory.InventoryDirectory(b"p-1", "dir", inv.root.file_id) + parent1.revision = b"result" + parent2 = inventory.InventoryDirectory(b"p-2", "dir2", inv.root.file_id) + parent2.revision = b"result" + file1 = inventory.InventoryFile(b"id", "path", b"p-2") + file1.revision = b"result" file1.text_size = 0 file1.text_sha1 = b"" inv.add(parent1) inv.add(parent2) # This delta claims that file1 is at dir/path, but actually its at # dir2/path if you follow the inventory parent structure. - delta = InventoryDelta([(None, 'dir/path', b'id', file1)]) - self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, - inv, delta) + delta = InventoryDelta([(None, "dir/path", b"id", file1)]) + self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, inv, delta) def test_old_parent_path_is_wrong(self): inv = self.get_empty_inventory() - parent1 = inventory.InventoryDirectory(b'p-1', 'dir', inv.root.file_id) - parent1.revision = b'result' - parent2 = inventory.InventoryDirectory( - b'p-2', 'dir2', inv.root.file_id) - parent2.revision = b'result' - file1 = inventory.InventoryFile(b'id', 'path', b'p-2') - file1.revision = b'result' + parent1 = inventory.InventoryDirectory(b"p-1", "dir", inv.root.file_id) + parent1.revision = b"result" + parent2 = inventory.InventoryDirectory(b"p-2", "dir2", inv.root.file_id) + parent2.revision = b"result" + file1 = inventory.InventoryFile(b"id", "path", b"p-2") + file1.revision = b"result" file1.text_size = 0 file1.text_sha1 = b"" inv.add(parent1) @@ -493,23 +520,21 @@ def test_old_parent_path_is_wrong(self): inv.add(file1) # This delta claims that file1 was at dir/path, but actually it was at # dir2/path if you follow the inventory parent structure. 
- delta = InventoryDelta([('dir/path', None, b'id', None)]) - self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, - inv, delta) + delta = InventoryDelta([("dir/path", None, b"id", None)]) + self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, inv, delta) def test_old_parent_path_is_for_other_id(self): inv = self.get_empty_inventory() - parent1 = inventory.InventoryDirectory(b'p-1', 'dir', inv.root.file_id) - parent1.revision = b'result' - parent2 = inventory.InventoryDirectory( - b'p-2', 'dir2', inv.root.file_id) - parent2.revision = b'result' - file1 = inventory.InventoryFile(b'id', 'path', b'p-2') - file1.revision = b'result' + parent1 = inventory.InventoryDirectory(b"p-1", "dir", inv.root.file_id) + parent1.revision = b"result" + parent2 = inventory.InventoryDirectory(b"p-2", "dir2", inv.root.file_id) + parent2.revision = b"result" + file1 = inventory.InventoryFile(b"id", "path", b"p-2") + file1.revision = b"result" file1.text_size = 0 file1.text_sha1 = b"" - file2 = inventory.InventoryFile(b'id2', 'path', b'p-1') - file2.revision = b'result' + file2 = inventory.InventoryFile(b"id2", "path", b"p-1") + file2.revision = b"result" file2.text_size = 0 file2.text_sha1 = b"" inv.add(parent1) @@ -519,147 +544,138 @@ def test_old_parent_path_is_for_other_id(self): # This delta claims that file1 was at dir/path, but actually it was at # dir2/path if you follow the inventory parent structure. At dir/path # is another entry we should not delete. - delta = InventoryDelta([('dir/path', None, b'id', None)]) - self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, - inv, delta) + delta = InventoryDelta([("dir/path", None, b"id", None)]) + self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, inv, delta) def test_add_existing_id_new_path(self): inv = self.get_empty_inventory() - parent1 = inventory.InventoryDirectory( - b'p-1', 'dir1', inv.root.file_id) - parent1.revision = b'result' - parent2 = inventory.InventoryDirectory( - b'p-1', 'dir2', inv.root.file_id) - parent2.revision = b'result' + parent1 = inventory.InventoryDirectory(b"p-1", "dir1", inv.root.file_id) + parent1.revision = b"result" + parent2 = inventory.InventoryDirectory(b"p-1", "dir2", inv.root.file_id) + parent2.revision = b"result" inv.add(parent1) - delta = InventoryDelta([(None, 'dir2', b'p-1', parent2)]) - self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, - inv, delta) + delta = InventoryDelta([(None, "dir2", b"p-1", parent2)]) + self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, inv, delta) def test_add_new_id_existing_path(self): inv = self.get_empty_inventory() - parent1 = inventory.InventoryDirectory( - b'p-1', 'dir1', inv.root.file_id) - parent1.revision = b'result' - parent2 = inventory.InventoryDirectory( - b'p-2', 'dir1', inv.root.file_id) - parent2.revision = b'result' + parent1 = inventory.InventoryDirectory(b"p-1", "dir1", inv.root.file_id) + parent1.revision = b"result" + parent2 = inventory.InventoryDirectory(b"p-2", "dir1", inv.root.file_id) + parent2.revision = b"result" inv.add(parent1) - delta = InventoryDelta([(None, 'dir1', b'p-2', parent2)]) - self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, - inv, delta) + delta = InventoryDelta([(None, "dir1", b"p-2", parent2)]) + self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, inv, delta) def test_remove_dir_leaving_dangling_child(self): inv = self.get_empty_inventory() - dir1 = inventory.InventoryDirectory(b'p-1', 'dir1', inv.root.file_id) - 
dir1.revision = b'result' - dir2 = inventory.InventoryDirectory(b'p-2', 'child1', b'p-1') - dir2.revision = b'result' - dir3 = inventory.InventoryDirectory(b'p-3', 'child2', b'p-1') - dir3.revision = b'result' + dir1 = inventory.InventoryDirectory(b"p-1", "dir1", inv.root.file_id) + dir1.revision = b"result" + dir2 = inventory.InventoryDirectory(b"p-2", "child1", b"p-1") + dir2.revision = b"result" + dir3 = inventory.InventoryDirectory(b"p-3", "child2", b"p-1") + dir3.revision = b"result" inv.add(dir1) inv.add(dir2) inv.add(dir3) delta = InventoryDelta( - [('dir1', None, b'p-1', None), - ('dir1/child2', None, b'p-3', None)]) - self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, - inv, delta) + [("dir1", None, b"p-1", None), ("dir1/child2", None, b"p-3", None)] + ) + self.assertRaises(errors.InconsistentDelta, self.apply_delta, self, inv, delta) def test_add_file(self): inv = self.get_empty_inventory() - file1 = inventory.InventoryFile(b'file-id', 'path', inv.root.file_id) - file1.revision = b'result' + file1 = inventory.InventoryFile(b"file-id", "path", inv.root.file_id) + file1.revision = b"result" file1.text_size = 0 - file1.text_sha1 = b'' - delta = InventoryDelta([(None, 'path', b'file-id', file1)]) + file1.text_sha1 = b"" + delta = InventoryDelta([(None, "path", b"file-id", file1)]) res_inv = self.apply_delta(self, inv, delta, invalid_delta=False) - self.assertEqual(b'file-id', res_inv.get_entry(b'file-id').file_id) + self.assertEqual(b"file-id", res_inv.get_entry(b"file-id").file_id) def test_remove_file(self): inv = self.get_empty_inventory() - file1 = inventory.InventoryFile(b'file-id', 'path', inv.root.file_id) - file1.revision = b'result' + file1 = inventory.InventoryFile(b"file-id", "path", inv.root.file_id) + file1.revision = b"result" file1.text_size = 0 - file1.text_sha1 = b'' + file1.text_sha1 = b"" inv.add(file1) - delta = InventoryDelta([('path', None, b'file-id', None)]) + delta = InventoryDelta([("path", None, b"file-id", None)]) res_inv = self.apply_delta(self, inv, delta, invalid_delta=False) - self.assertEqual(None, res_inv.path2id('path')) - self.assertRaises(errors.NoSuchId, res_inv.id2path, b'file-id') + self.assertEqual(None, res_inv.path2id("path")) + self.assertRaises(errors.NoSuchId, res_inv.id2path, b"file-id") def test_rename_file(self): inv = self.get_empty_inventory() - file1 = self.make_file_ie(name='path', parent_id=inv.root.file_id) + file1 = self.make_file_ie(name="path", parent_id=inv.root.file_id) inv.add(file1) - file2 = self.make_file_ie(name='path2', parent_id=inv.root.file_id) - delta = InventoryDelta([('path', 'path2', b'file-id', file2)]) + file2 = self.make_file_ie(name="path2", parent_id=inv.root.file_id) + delta = InventoryDelta([("path", "path2", b"file-id", file2)]) res_inv = self.apply_delta(self, inv, delta, invalid_delta=False) - self.assertEqual(None, res_inv.path2id('path')) - self.assertEqual(b'file-id', res_inv.path2id('path2')) + self.assertEqual(None, res_inv.path2id("path")) + self.assertEqual(b"file-id", res_inv.path2id("path2")) def test_replaced_at_new_path(self): inv = self.get_empty_inventory() - file1 = self.make_file_ie(file_id=b'id1', parent_id=inv.root.file_id) + file1 = self.make_file_ie(file_id=b"id1", parent_id=inv.root.file_id) inv.add(file1) - file2 = self.make_file_ie(file_id=b'id2', parent_id=inv.root.file_id) + file2 = self.make_file_ie(file_id=b"id2", parent_id=inv.root.file_id) delta = InventoryDelta( - [('name', None, b'id1', None), - (None, 'name', b'id2', file2)]) + [("name", None, b"id1", 
None), (None, "name", b"id2", file2)] + ) res_inv = self.apply_delta(self, inv, delta, invalid_delta=False) - self.assertEqual(b'id2', res_inv.path2id('name')) + self.assertEqual(b"id2", res_inv.path2id("name")) def test_rename_dir(self): inv = self.get_empty_inventory() - dir1 = inventory.InventoryDirectory( - b'dir-id', 'dir1', inv.root.file_id) - dir1.revision = b'basis' - file1 = self.make_file_ie(parent_id=b'dir-id') + dir1 = inventory.InventoryDirectory(b"dir-id", "dir1", inv.root.file_id) + dir1.revision = b"basis" + file1 = self.make_file_ie(parent_id=b"dir-id") inv.add(dir1) inv.add(file1) - dir2 = inventory.InventoryDirectory( - b'dir-id', 'dir2', inv.root.file_id) - dir2.revision = b'result' - delta = InventoryDelta([('dir1', 'dir2', b'dir-id', dir2)]) + dir2 = inventory.InventoryDirectory(b"dir-id", "dir2", inv.root.file_id) + dir2.revision = b"result" + delta = InventoryDelta([("dir1", "dir2", b"dir-id", dir2)]) res_inv = self.apply_delta(self, inv, delta, invalid_delta=False) # The file should be accessible under the new path - self.assertEqual(b'file-id', res_inv.path2id('dir2/name')) + self.assertEqual(b"file-id", res_inv.path2id("dir2/name")) def test_renamed_dir_with_renamed_child(self): inv = self.get_empty_inventory() - dir1 = inventory.InventoryDirectory( - b'dir-id', 'dir1', inv.root.file_id) - dir1.revision = b'basis' - file1 = self.make_file_ie(b'file-id-1', 'name1', parent_id=b'dir-id') - file2 = self.make_file_ie(b'file-id-2', 'name2', parent_id=b'dir-id') + dir1 = inventory.InventoryDirectory(b"dir-id", "dir1", inv.root.file_id) + dir1.revision = b"basis" + file1 = self.make_file_ie(b"file-id-1", "name1", parent_id=b"dir-id") + file2 = self.make_file_ie(b"file-id-2", "name2", parent_id=b"dir-id") inv.add(dir1) inv.add(file1) inv.add(file2) - dir2 = inventory.InventoryDirectory( - b'dir-id', 'dir2', inv.root.file_id) - dir2.revision = b'result' - file2b = self.make_file_ie(b'file-id-2', 'name2', inv.root.file_id) + dir2 = inventory.InventoryDirectory(b"dir-id", "dir2", inv.root.file_id) + dir2.revision = b"result" + file2b = self.make_file_ie(b"file-id-2", "name2", inv.root.file_id) delta = InventoryDelta( - [('dir1', 'dir2', b'dir-id', dir2), - ('dir1/name2', 'name2', b'file-id-2', file2b)]) + [ + ("dir1", "dir2", b"dir-id", dir2), + ("dir1/name2", "name2", b"file-id-2", file2b), + ] + ) res_inv = self.apply_delta(self, inv, delta, invalid_delta=False) # The file should be accessible under the new path - self.assertEqual(b'file-id-1', res_inv.path2id('dir2/name1')) - self.assertEqual(None, res_inv.path2id('dir2/name2')) - self.assertEqual(b'file-id-2', res_inv.path2id('name2')) + self.assertEqual(b"file-id-1", res_inv.path2id("dir2/name1")) + self.assertEqual(None, res_inv.path2id("dir2/name2")) + self.assertEqual(b"file-id-2", res_inv.path2id("name2")) def test_is_root(self): """Ensure our root-checking code is accurate.""" - inv = inventory.Inventory(b'TREE_ROOT') - self.assertTrue(inv.is_root(b'TREE_ROOT')) - self.assertFalse(inv.is_root(b'booga')) - inv.rename_id(inv.root.file_id, b'booga') - self.assertFalse(inv.is_root(b'TREE_ROOT')) - self.assertTrue(inv.is_root(b'booga')) + inv = inventory.Inventory(b"TREE_ROOT") + self.assertTrue(inv.is_root(b"TREE_ROOT")) + self.assertFalse(inv.is_root(b"booga")) + inv.rename_id(inv.root.file_id, b"booga") + self.assertFalse(inv.is_root(b"TREE_ROOT")) + self.assertTrue(inv.is_root(b"booga")) # works properly even if no root is set inv.root = None - self.assertFalse(inv.is_root(b'TREE_ROOT')) - 
self.assertFalse(inv.is_root(b'booga')) + self.assertFalse(inv.is_root(b"TREE_ROOT")) + self.assertFalse(inv.is_root(b"booga")) def test_entries_for_empty_inventory(self): """Test that entries() will not fail for an empty inventory.""" @@ -668,41 +684,41 @@ def test_entries_for_empty_inventory(self): class TestInventoryEntry(TestCase): - def test_file_invalid_entry_name(self): - self.assertRaises(InvalidEntryName, inventory.InventoryFile, - b'123', 'a/hello.c', ROOT_ID) + self.assertRaises( + InvalidEntryName, inventory.InventoryFile, b"123", "a/hello.c", ROOT_ID + ) def test_file_backslash(self): - file = inventory.InventoryFile(b'123', 'h\\ello.c', ROOT_ID) - self.assertEqual(file.name, 'h\\ello.c') + file = inventory.InventoryFile(b"123", "h\\ello.c", ROOT_ID) + self.assertEqual(file.name, "h\\ello.c") def test_file_kind_character(self): - file = inventory.InventoryFile(b'123', 'hello.c', ROOT_ID) - self.assertEqual(file.kind_character(), '') + file = inventory.InventoryFile(b"123", "hello.c", ROOT_ID) + self.assertEqual(file.kind_character(), "") def test_dir_kind_character(self): - dir = inventory.InventoryDirectory(b'123', 'hello.c', ROOT_ID) - self.assertEqual(dir.kind_character(), '/') + dir = inventory.InventoryDirectory(b"123", "hello.c", ROOT_ID) + self.assertEqual(dir.kind_character(), "/") def test_link_kind_character(self): - dir = inventory.InventoryLink(b'123', 'hello.c', ROOT_ID) - self.assertEqual(dir.kind_character(), '@') + dir = inventory.InventoryLink(b"123", "hello.c", ROOT_ID) + self.assertEqual(dir.kind_character(), "@") def test_tree_ref_kind_character(self): - dir = TreeReference(b'123', 'hello.c', ROOT_ID) - self.assertEqual(dir.kind_character(), '+') + dir = TreeReference(b"123", "hello.c", ROOT_ID) + self.assertEqual(dir.kind_character(), "+") def test_dir_detect_changes(self): - left = inventory.InventoryDirectory(b'123', 'hello.c', ROOT_ID) - right = inventory.InventoryDirectory(b'123', 'hello.c', ROOT_ID) + left = inventory.InventoryDirectory(b"123", "hello.c", ROOT_ID) + right = inventory.InventoryDirectory(b"123", "hello.c", ROOT_ID) self.assertEqual((False, False), left.detect_changes(right)) self.assertEqual((False, False), right.detect_changes(left)) def test_file_detect_changes(self): - left = inventory.InventoryFile(b'123', 'hello.c', ROOT_ID) + left = inventory.InventoryFile(b"123", "hello.c", ROOT_ID) left.text_sha1 = b"123" - right = inventory.InventoryFile(b'123', 'hello.c', ROOT_ID) + right = inventory.InventoryFile(b"123", "hello.c", ROOT_ID) right.text_sha1 = b"123" self.assertEqual((False, False), left.detect_changes(right)) self.assertEqual((False, False), right.detect_changes(left)) @@ -714,48 +730,56 @@ def test_file_detect_changes(self): self.assertEqual((True, True), right.detect_changes(left)) def test_symlink_detect_changes(self): - left = inventory.InventoryLink(b'123', 'hello.c', ROOT_ID) - left.symlink_target = 'foo' - right = inventory.InventoryLink(b'123', 'hello.c', ROOT_ID) - right.symlink_target = 'foo' + left = inventory.InventoryLink(b"123", "hello.c", ROOT_ID) + left.symlink_target = "foo" + right = inventory.InventoryLink(b"123", "hello.c", ROOT_ID) + right.symlink_target = "foo" self.assertEqual((False, False), left.detect_changes(right)) self.assertEqual((False, False), right.detect_changes(left)) - left.symlink_target = 'different' + left.symlink_target = "different" self.assertEqual((True, False), left.detect_changes(right)) self.assertEqual((True, False), right.detect_changes(left)) def test_file_has_text(self): - file 
= inventory.InventoryFile(b'123', 'hello.c', ROOT_ID) + file = inventory.InventoryFile(b"123", "hello.c", ROOT_ID) self.assertTrue(file.has_text()) def test_directory_has_text(self): - dir = inventory.InventoryDirectory(b'123', 'hello.c', ROOT_ID) + dir = inventory.InventoryDirectory(b"123", "hello.c", ROOT_ID) self.assertFalse(dir.has_text()) def test_link_has_text(self): - link = inventory.InventoryLink(b'123', 'hello.c', ROOT_ID) + link = inventory.InventoryLink(b"123", "hello.c", ROOT_ID) self.assertFalse(link.has_text()) def test_make_entry(self): - self.assertIsInstance(inventory.make_entry("file", "name", ROOT_ID), - inventory.InventoryFile) - self.assertIsInstance(inventory.make_entry("symlink", "name", ROOT_ID), - inventory.InventoryLink) - self.assertIsInstance(inventory.make_entry("directory", "name", ROOT_ID), - inventory.InventoryDirectory) + self.assertIsInstance( + inventory.make_entry("file", "name", ROOT_ID), inventory.InventoryFile + ) + self.assertIsInstance( + inventory.make_entry("symlink", "name", ROOT_ID), inventory.InventoryLink + ) + self.assertIsInstance( + inventory.make_entry("directory", "name", ROOT_ID), + inventory.InventoryDirectory, + ) def test_make_entry_non_normalized(self): if osutils.normalizes_filenames(): - entry = inventory.make_entry("file", 'a\u030a', ROOT_ID) - self.assertEqual('\xe5', entry.name) + entry = inventory.make_entry("file", "a\u030a", ROOT_ID) + self.assertEqual("\xe5", entry.name) self.assertIsInstance(entry, inventory.InventoryFile) else: - self.assertRaises(errors.InvalidNormalization, - inventory.make_entry, 'file', 'a\u030a', ROOT_ID) + self.assertRaises( + errors.InvalidNormalization, + inventory.make_entry, + "file", + "a\u030a", + ROOT_ID, + ) class TestDescribeChanges(TestCase): - def test_describe_change(self): # we need to test the following change combinations: # rename @@ -766,46 +790,46 @@ def test_describe_change(self): # renamed/reparented and modified # change kind (perhaps can't be done yet?) # also, merged in combination with all of these? - old_a = InventoryFile(b'a-id', 'a_file', ROOT_ID) - old_a.text_sha1 = b'123132' + old_a = InventoryFile(b"a-id", "a_file", ROOT_ID) + old_a.text_sha1 = b"123132" old_a.text_size = 0 - new_a = InventoryFile(b'a-id', 'a_file', ROOT_ID) - new_a.text_sha1 = b'123132' + new_a = InventoryFile(b"a-id", "a_file", ROOT_ID) + new_a.text_sha1 = b"123132" new_a.text_size = 0 - self.assertChangeDescription('unchanged', old_a, new_a) + self.assertChangeDescription("unchanged", old_a, new_a) new_a.text_size = 10 - new_a.text_sha1 = b'abcabc' - self.assertChangeDescription('modified', old_a, new_a) + new_a.text_sha1 = b"abcabc" + self.assertChangeDescription("modified", old_a, new_a) - self.assertChangeDescription('added', None, new_a) - self.assertChangeDescription('removed', old_a, None) + self.assertChangeDescription("added", None, new_a) + self.assertChangeDescription("removed", old_a, None) # perhaps a bit questionable but seems like the most reasonable thing... 
- self.assertChangeDescription('unchanged', None, None) + self.assertChangeDescription("unchanged", None, None) # in this case it's both renamed and modified; show a rename and # modification: - new_a.name = 'newfilename' - self.assertChangeDescription('modified and renamed', old_a, new_a) + new_a.name = "newfilename" + self.assertChangeDescription("modified and renamed", old_a, new_a) # reparenting is 'renaming' new_a.name = old_a.name - new_a.parent_id = b'somedir-id' - self.assertChangeDescription('modified and renamed', old_a, new_a) + new_a.parent_id = b"somedir-id" + self.assertChangeDescription("modified and renamed", old_a, new_a) # reset the content values so its not modified new_a.text_size = old_a.text_size new_a.text_sha1 = old_a.text_sha1 new_a.name = old_a.name - new_a.name = 'newfilename' - self.assertChangeDescription('renamed', old_a, new_a) + new_a.name = "newfilename" + self.assertChangeDescription("renamed", old_a, new_a) # reparenting is 'renaming' new_a.name = old_a.name - new_a.parent_id = b'somedir-id' - self.assertChangeDescription('renamed', old_a, new_a) + new_a.parent_id = b"somedir-id" + self.assertChangeDescription("renamed", old_a, new_a) def assertChangeDescription(self, expected_change, old_ie, new_ie): change = InventoryEntry.describe_change(old_ie, new_ie) @@ -813,14 +837,13 @@ def assertChangeDescription(self, expected_change, old_ie, new_ie): class TestCHKInventory(tests.TestCaseWithMemoryTransport): - def get_chk_bytes(self): factory = groupcompress.make_pack_factory(True, True, 1) - trans = self.get_transport('') + trans = self.get_transport("") return factory(trans) def read_bytes(self, chk_bytes, key): - stream = chk_bytes.get_record_stream([key], 'unordered', True) + stream = chk_bytes.get_record_stream([key], "unordered", True) return next(stream).get_bytes_as("fulltext") def test_deserialise_gives_CHKInventory(self): @@ -837,7 +860,7 @@ def test_deserialise_gives_CHKInventory(self): self.assertEqual(inv.root.parent_id, new_inv.root.parent_id) self.assertEqual(inv.root.name, new_inv.root.name) self.assertEqual(b"rootrev", new_inv.root.revision) - self.assertEqual(b'plain', new_inv._search_key_name) + self.assertEqual(b"plain", new_inv._search_key_name) def test_deserialise_wrong_revid(self): inv = Inventory() @@ -846,8 +869,9 @@ def test_deserialise_wrong_revid(self): chk_bytes = self.get_chk_bytes() chk_inv = CHKInventory.from_inventory(chk_bytes, inv) lines = chk_inv.to_lines() - self.assertRaises(ValueError, CHKInventory.deserialise, chk_bytes, - lines, (b"revid2",)) + self.assertRaises( + ValueError, CHKInventory.deserialise, chk_bytes, lines, (b"revid2",) + ) def test_captures_rev_root_byid(self): inv = Inventory() @@ -856,16 +880,18 @@ def test_captures_rev_root_byid(self): chk_bytes = self.get_chk_bytes() chk_inv = CHKInventory.from_inventory(chk_bytes, inv) lines = chk_inv.to_lines() - self.assertEqual([ - b'chkinventory:\n', - b'revision_id: foo\n', - b'root_id: TREE_ROOT\n', - b'parent_id_basename_to_file_id: sha1:eb23f0ad4b07f48e88c76d4c94292be57fb2785f\n', - b'id_to_entry: sha1:debfe920f1f10e7929260f0534ac9a24d7aabbb4\n', - ], lines) - chk_inv = CHKInventory.deserialise( - chk_bytes, lines, (b'foo',)) - self.assertEqual(b'plain', chk_inv._search_key_name) + self.assertEqual( + [ + b"chkinventory:\n", + b"revision_id: foo\n", + b"root_id: TREE_ROOT\n", + b"parent_id_basename_to_file_id: sha1:eb23f0ad4b07f48e88c76d4c94292be57fb2785f\n", + b"id_to_entry: sha1:debfe920f1f10e7929260f0534ac9a24d7aabbb4\n", + ], + lines, + ) + chk_inv = 
CHKInventory.deserialise(chk_bytes, lines, (b"foo",)) + self.assertEqual(b"plain", chk_inv._search_key_name) def test_captures_parent_id_basename_index(self): inv = Inventory() @@ -874,36 +900,41 @@ def test_captures_parent_id_basename_index(self): chk_bytes = self.get_chk_bytes() chk_inv = CHKInventory.from_inventory(chk_bytes, inv) lines = chk_inv.to_lines() - self.assertEqual([ - b'chkinventory:\n', - b'revision_id: foo\n', - b'root_id: TREE_ROOT\n', - b'parent_id_basename_to_file_id: sha1:eb23f0ad4b07f48e88c76d4c94292be57fb2785f\n', - b'id_to_entry: sha1:debfe920f1f10e7929260f0534ac9a24d7aabbb4\n', - ], lines) - chk_inv = CHKInventory.deserialise( - chk_bytes, lines, (b'foo',)) - self.assertEqual(b'plain', chk_inv._search_key_name) + self.assertEqual( + [ + b"chkinventory:\n", + b"revision_id: foo\n", + b"root_id: TREE_ROOT\n", + b"parent_id_basename_to_file_id: sha1:eb23f0ad4b07f48e88c76d4c94292be57fb2785f\n", + b"id_to_entry: sha1:debfe920f1f10e7929260f0534ac9a24d7aabbb4\n", + ], + lines, + ) + chk_inv = CHKInventory.deserialise(chk_bytes, lines, (b"foo",)) + self.assertEqual(b"plain", chk_inv._search_key_name) def test_captures_search_key_name(self): inv = Inventory() inv.revision_id = b"foo" inv.root.revision = b"bar" chk_bytes = self.get_chk_bytes() - chk_inv = CHKInventory.from_inventory(chk_bytes, inv, - search_key_name=b'hash-16-way') + chk_inv = CHKInventory.from_inventory( + chk_bytes, inv, search_key_name=b"hash-16-way" + ) lines = chk_inv.to_lines() - self.assertEqual([ - b'chkinventory:\n', - b'search_key_name: hash-16-way\n', - b'root_id: TREE_ROOT\n', - b'parent_id_basename_to_file_id: sha1:eb23f0ad4b07f48e88c76d4c94292be57fb2785f\n', - b'revision_id: foo\n', - b'id_to_entry: sha1:debfe920f1f10e7929260f0534ac9a24d7aabbb4\n', - ], lines) - chk_inv = CHKInventory.deserialise( - chk_bytes, lines, (b'foo',)) - self.assertEqual(b'hash-16-way', chk_inv._search_key_name) + self.assertEqual( + [ + b"chkinventory:\n", + b"search_key_name: hash-16-way\n", + b"root_id: TREE_ROOT\n", + b"parent_id_basename_to_file_id: sha1:eb23f0ad4b07f48e88c76d4c94292be57fb2785f\n", + b"revision_id: foo\n", + b"id_to_entry: sha1:debfe920f1f10e7929260f0534ac9a24d7aabbb4\n", + ], + lines, + ) + chk_inv = CHKInventory.deserialise(chk_bytes, lines, (b"foo",)) + self.assertEqual(b"hash-16-way", chk_inv._search_key_name) def test_directory_children_on_demand(self): inv = Inventory() @@ -919,9 +950,9 @@ def test_directory_children_on_demand(self): lines = chk_inv.to_lines() new_inv = CHKInventory.deserialise(chk_bytes, lines, (b"revid",)) root_entry = new_inv.get_entry(inv.root.file_id) - self.assertEqual({'file'}, set(inv.get_children(root_entry.file_id))) + self.assertEqual({"file"}, set(inv.get_children(root_entry.file_id))) file_direct = new_inv.get_entry(b"fileid") - file_found = inv.get_children(root_entry.file_id)['file'] + file_found = inv.get_children(root_entry.file_id)["file"] self.assertEqual(file_direct.kind, file_found.kind) self.assertEqual(file_direct.file_id, file_found.file_id) self.assertEqual(file_direct.parent_id, file_found.parent_id) @@ -1003,7 +1034,7 @@ def test_get_entry(self): self.assertEqual(b"ffff", file_entry.text_sha1) self.assertEqual(1, file_entry.text_size) self.assertEqual(True, file_entry.executable) - self.assertRaises(errors.NoSuchId, new_inv.get_entry, 'missing') + self.assertRaises(errors.NoSuchId, new_inv.get_entry, "missing") def test_has_id_true(self): inv = Inventory() @@ -1016,7 +1047,7 @@ def test_has_id_true(self): inv.get_entry(b"fileid").text_size = 1 
chk_bytes = self.get_chk_bytes() chk_inv = CHKInventory.from_inventory(chk_bytes, inv) - self.assertTrue(chk_inv.has_id(b'fileid')) + self.assertTrue(chk_inv.has_id(b"fileid")) self.assertTrue(chk_inv.has_id(inv.root.file_id)) def test_has_id_not(self): @@ -1025,7 +1056,7 @@ def test_has_id_not(self): inv.root.revision = b"rootrev" chk_bytes = self.get_chk_bytes() chk_inv = CHKInventory.from_inventory(chk_bytes, inv) - self.assertFalse(chk_inv.has_id(b'fileid')) + self.assertFalse(chk_inv.has_id(b"fileid")) def test_id2path(self): inv = Inventory() @@ -1044,9 +1075,9 @@ def test_id2path(self): chk_inv = CHKInventory.from_inventory(chk_bytes, inv) lines = chk_inv.to_lines() new_inv = CHKInventory.deserialise(chk_bytes, lines, (b"revid",)) - self.assertEqual('', new_inv.id2path(inv.root.file_id)) - self.assertEqual('dir', new_inv.id2path(b'dirid')) - self.assertEqual('dir/file', new_inv.id2path(b'fileid')) + self.assertEqual("", new_inv.id2path(inv.root.file_id)) + self.assertEqual("dir", new_inv.id2path(b"dirid")) + self.assertEqual("dir/file", new_inv.id2path(b"fileid")) def test_path2id(self): inv = Inventory() @@ -1065,9 +1096,9 @@ def test_path2id(self): chk_inv = CHKInventory.from_inventory(chk_bytes, inv) lines = chk_inv.to_lines() new_inv = CHKInventory.deserialise(chk_bytes, lines, (b"revid",)) - self.assertEqual(inv.root.file_id, new_inv.path2id('')) - self.assertEqual(b'dirid', new_inv.path2id('dir')) - self.assertEqual(b'fileid', new_inv.path2id('dir/file')) + self.assertEqual(inv.root.file_id, new_inv.path2id("")) + self.assertEqual(b"dirid", new_inv.path2id("dir")) + self.assertEqual(b"fileid", new_inv.path2id("dir/file")) def test_create_by_apply_delta_sets_root(self): inv = Inventory() @@ -1079,8 +1110,9 @@ def test_create_by_apply_delta_sets_root(self): inv.revision_id = b"expectedid" inv.root.revision = b"myrootrev" reference_inv = CHKInventory.from_inventory(chk_bytes, inv) - delta = InventoryDelta([("", None, base_inv.root.file_id, None), - (None, "", b"myrootid", inv.root)]) + delta = InventoryDelta( + [("", None, base_inv.root.file_id, None), (None, "", b"myrootid", inv.root)] + ) new_inv = base_inv.create_by_apply_delta(delta, b"expectedid") self.assertEqual(reference_inv.root, new_inv.root) @@ -1105,8 +1137,10 @@ def test_create_by_apply_delta_empty_add_child(self): self.assertEqual(reference_inv.root_id, new_inv.root_id) reference_inv.id_to_entry._ensure_root() new_inv.id_to_entry._ensure_root() - self.assertEqual(reference_inv.id_to_entry._root_node._key, - new_inv.id_to_entry._root_node._key) + self.assertEqual( + reference_inv.id_to_entry._root_node._key, + new_inv.id_to_entry._root_node._key, + ) def test_create_by_apply_delta_empty_add_child_updates_parent_id(self): inv = Inventory() @@ -1131,10 +1165,14 @@ def test_create_by_apply_delta_empty_add_child_updates_parent_id(self): # new_inv should be the same as reference_inv. 
self.assertEqual(reference_inv.revision_id, new_inv.revision_id) self.assertEqual(reference_inv.root_id, new_inv.root_id) - self.assertEqual(reference_inv.id_to_entry._root_node._key, - new_inv.id_to_entry._root_node._key) - self.assertEqual(reference_inv.parent_id_basename_to_file_id._root_node._key, - new_inv.parent_id_basename_to_file_id._root_node._key) + self.assertEqual( + reference_inv.id_to_entry._root_node._key, + new_inv.id_to_entry._root_node._key, + ) + self.assertEqual( + reference_inv.parent_id_basename_to_file_id._root_node._key, + new_inv.parent_id_basename_to_file_id._root_node._key, + ) def test_iter_changes(self): # Low level bootstrapping smoke test; comprehensive generic tests via @@ -1163,11 +1201,21 @@ def test_iter_changes(self): chk_inv2 = CHKInventory.from_inventory(chk_bytes, inv2) lines = chk_inv2.to_lines() inv_2 = CHKInventory.deserialise(chk_bytes, lines, (b"revid2",)) - self.assertEqual([(b'fileid', ('file', 'file'), True, (True, True), - (b'TREE_ROOT', b'TREE_ROOT'), ('file', - 'file'), ('file', 'file'), - (False, True))], - list(inv_1.iter_changes(inv_2))) + self.assertEqual( + [ + ( + b"fileid", + ("file", "file"), + True, + (True, True), + (b"TREE_ROOT", b"TREE_ROOT"), + ("file", "file"), + ("file", "file"), + (False, True), + ) + ], + list(inv_1.iter_changes(inv_2)), + ) def test_parent_id_basename_to_file_id_index_enabled(self): inv = Inventory() @@ -1183,129 +1231,152 @@ def test_parent_id_basename_to_file_id_index_enabled(self): tmp_inv = CHKInventory.from_inventory(chk_bytes, inv) lines = tmp_inv.to_lines() chk_inv = CHKInventory.deserialise(chk_bytes, lines, (b"revid",)) - self.assertIsInstance( - chk_inv.parent_id_basename_to_file_id, chk_map.CHKMap) + self.assertIsInstance(chk_inv.parent_id_basename_to_file_id, chk_map.CHKMap) self.assertEqual( - {(b'', b''): b'TREE_ROOT', (b'TREE_ROOT', b'file'): b'fileid'}, - dict(chk_inv.parent_id_basename_to_file_id.iteritems())) + {(b"", b""): b"TREE_ROOT", (b"TREE_ROOT", b"file"): b"fileid"}, + dict(chk_inv.parent_id_basename_to_file_id.iteritems()), + ) def test_file_entry_to_bytes(self): CHKInventory(None) - ie = inventory.InventoryFile(b'file-id', 'filename', b'parent-id') + ie = inventory.InventoryFile(b"file-id", "filename", b"parent-id") ie.executable = True - ie.revision = b'file-rev-id' - ie.text_sha1 = b'abcdefgh' + ie.revision = b"file-rev-id" + ie.text_sha1 = b"abcdefgh" ie.text_size = 100 bytes = _chk_inventory_entry_to_bytes(ie) - self.assertEqual(b'file: file-id\nparent-id\nfilename\n' - b'file-rev-id\nabcdefgh\n100\nY', bytes) + self.assertEqual( + b"file: file-id\nparent-id\nfilename\n" b"file-rev-id\nabcdefgh\n100\nY", + bytes, + ) ie2 = _chk_inventory_bytes_to_entry(bytes) self.assertEqual(ie, ie2) self.assertIsInstance(ie2.name, str) - self.assertEqual((b'filename', b'file-id', b'file-rev-id'), - chk_inventory_bytes_to_utf8name_key(bytes)) + self.assertEqual( + (b"filename", b"file-id", b"file-rev-id"), + chk_inventory_bytes_to_utf8name_key(bytes), + ) def test_file2_entry_to_bytes(self): CHKInventory(None) # \u30a9 == 'omega' - ie = inventory.InventoryFile(b'file-id', '\u03a9name', b'parent-id') + ie = inventory.InventoryFile(b"file-id", "\u03a9name", b"parent-id") ie.executable = False - ie.revision = b'file-rev-id' - ie.text_sha1 = b'123456' + ie.revision = b"file-rev-id" + ie.text_sha1 = b"123456" ie.text_size = 25 bytes = _chk_inventory_entry_to_bytes(ie) - self.assertEqual(b'file: file-id\nparent-id\n\xce\xa9name\n' - b'file-rev-id\n123456\n25\nN', bytes) + self.assertEqual( + 
b"file: file-id\nparent-id\n\xce\xa9name\n" b"file-rev-id\n123456\n25\nN", + bytes, + ) ie2 = _chk_inventory_bytes_to_entry(bytes) self.assertEqual(ie, ie2) self.assertIsInstance(ie2.name, str) - self.assertEqual((b'\xce\xa9name', b'file-id', b'file-rev-id'), - chk_inventory_bytes_to_utf8name_key(bytes)) + self.assertEqual( + (b"\xce\xa9name", b"file-id", b"file-rev-id"), + chk_inventory_bytes_to_utf8name_key(bytes), + ) def test_dir_entry_to_bytes(self): CHKInventory(None) - ie = inventory.InventoryDirectory(b'dir-id', 'dirname', b'parent-id') - ie.revision = b'dir-rev-id' + ie = inventory.InventoryDirectory(b"dir-id", "dirname", b"parent-id") + ie.revision = b"dir-rev-id" bytes = _chk_inventory_entry_to_bytes(ie) - self.assertEqual(b'dir: dir-id\nparent-id\ndirname\ndir-rev-id', bytes) + self.assertEqual(b"dir: dir-id\nparent-id\ndirname\ndir-rev-id", bytes) ie2 = _chk_inventory_bytes_to_entry(bytes) self.assertEqual(ie, ie2) self.assertIsInstance(ie2.name, str) - self.assertEqual((b'dirname', b'dir-id', b'dir-rev-id'), - chk_inventory_bytes_to_utf8name_key(bytes)) + self.assertEqual( + (b"dirname", b"dir-id", b"dir-rev-id"), + chk_inventory_bytes_to_utf8name_key(bytes), + ) def test_dir2_entry_to_bytes(self): CHKInventory(None) - ie = inventory.InventoryDirectory(b'dir-id', 'dir\u03a9name', b'pid') - ie.revision = b'dir-rev-id' + ie = inventory.InventoryDirectory(b"dir-id", "dir\u03a9name", b"pid") + ie.revision = b"dir-rev-id" bytes = _chk_inventory_entry_to_bytes(ie) - self.assertEqual(b'dir: dir-id\npid\ndir\xce\xa9name\n' - b'dir-rev-id', bytes) + self.assertEqual(b"dir: dir-id\npid\ndir\xce\xa9name\n" b"dir-rev-id", bytes) ie2 = _chk_inventory_bytes_to_entry(bytes) self.assertEqual(ie, ie2) self.assertIsInstance(ie2.name, str) - self.assertEqual(b'pid', ie2.parent_id) - self.assertEqual((b'dir\xce\xa9name', b'dir-id', b'dir-rev-id'), - chk_inventory_bytes_to_utf8name_key(bytes)) + self.assertEqual(b"pid", ie2.parent_id) + self.assertEqual( + (b"dir\xce\xa9name", b"dir-id", b"dir-rev-id"), + chk_inventory_bytes_to_utf8name_key(bytes), + ) def test_symlink_entry_to_bytes(self): CHKInventory(None) - ie = inventory.InventoryLink(b'link-id', 'linkname', b'parent-id') - ie.revision = b'link-rev-id' - ie.symlink_target = 'target/path' + ie = inventory.InventoryLink(b"link-id", "linkname", b"parent-id") + ie.revision = b"link-rev-id" + ie.symlink_target = "target/path" bytes = _chk_inventory_entry_to_bytes(ie) - self.assertEqual(b'symlink: link-id\nparent-id\nlinkname\n' - b'link-rev-id\ntarget/path', bytes) + self.assertEqual( + b"symlink: link-id\nparent-id\nlinkname\n" b"link-rev-id\ntarget/path", + bytes, + ) ie2 = _chk_inventory_bytes_to_entry(bytes) self.assertEqual(ie, ie2) self.assertIsInstance(ie2.name, str) self.assertIsInstance(ie2.symlink_target, str) - self.assertEqual((b'linkname', b'link-id', b'link-rev-id'), - chk_inventory_bytes_to_utf8name_key(bytes)) + self.assertEqual( + (b"linkname", b"link-id", b"link-rev-id"), + chk_inventory_bytes_to_utf8name_key(bytes), + ) def test_symlink2_entry_to_bytes(self): CHKInventory(None) - ie = inventory.InventoryLink( - b'link-id', 'link\u03a9name', b'parent-id') - ie.revision = b'link-rev-id' - ie.symlink_target = 'target/\u03a9path' + ie = inventory.InventoryLink(b"link-id", "link\u03a9name", b"parent-id") + ie.revision = b"link-rev-id" + ie.symlink_target = "target/\u03a9path" bytes = _chk_inventory_entry_to_bytes(ie) - self.assertEqual(b'symlink: link-id\nparent-id\nlink\xce\xa9name\n' - b'link-rev-id\ntarget/\xce\xa9path', bytes) + 
self.assertEqual( + b"symlink: link-id\nparent-id\nlink\xce\xa9name\n" + b"link-rev-id\ntarget/\xce\xa9path", + bytes, + ) ie2 = _chk_inventory_bytes_to_entry(bytes) self.assertEqual(ie, ie2) self.assertIsInstance(ie2.name, str) self.assertIsInstance(ie2.symlink_target, str) - self.assertEqual((b'link\xce\xa9name', b'link-id', b'link-rev-id'), - chk_inventory_bytes_to_utf8name_key(bytes)) + self.assertEqual( + (b"link\xce\xa9name", b"link-id", b"link-rev-id"), + chk_inventory_bytes_to_utf8name_key(bytes), + ) def test_tree_reference_entry_to_bytes(self): CHKInventory(None) - ie = inventory.TreeReference(b'tree-root-id', 'tree\u03a9name', - b'parent-id') - ie.revision = b'tree-rev-id' - ie.reference_revision = b'ref-rev-id' + ie = inventory.TreeReference(b"tree-root-id", "tree\u03a9name", b"parent-id") + ie.revision = b"tree-rev-id" + ie.reference_revision = b"ref-rev-id" bytes = _chk_inventory_entry_to_bytes(ie) - self.assertEqual(b'tree: tree-root-id\nparent-id\ntree\xce\xa9name\n' - b'tree-rev-id\nref-rev-id', bytes) + self.assertEqual( + b"tree: tree-root-id\nparent-id\ntree\xce\xa9name\n" + b"tree-rev-id\nref-rev-id", + bytes, + ) ie2 = _chk_inventory_bytes_to_entry(bytes) self.assertEqual(ie, ie2) self.assertIsInstance(ie2.name, str) - self.assertEqual((b'tree\xce\xa9name', b'tree-root-id', b'tree-rev-id'), - chk_inventory_bytes_to_utf8name_key(bytes)) + self.assertEqual( + (b"tree\xce\xa9name", b"tree-root-id", b"tree-rev-id"), + chk_inventory_bytes_to_utf8name_key(bytes), + ) def make_basic_utf8_inventory(self): inv = Inventory() inv.revision_id = b"revid" inv.root.revision = b"rootrev" root_id = inv.root.file_id - inv.add(InventoryFile(b"fileid", 'f\xefle', root_id)) + inv.add(InventoryFile(b"fileid", "f\xefle", root_id)) inv.get_entry(b"fileid").revision = b"filerev" inv.get_entry(b"fileid").text_sha1 = b"ffff" inv.get_entry(b"fileid").text_size = 0 - inv.add(InventoryDirectory(b"dirid", 'dir-\N{EURO SIGN}', root_id)) + inv.add(InventoryDirectory(b"dirid", "dir-\N{EURO SIGN}", root_id)) inv.get_entry(b"dirid").revision = b"dirrev" - inv.add(InventoryFile(b"childid", 'ch\xefld', b"dirid")) + inv.add(InventoryFile(b"childid", "ch\xefld", b"dirid")) inv.get_entry(b"childid").revision = b"filerev" inv.get_entry(b"childid").text_sha1 = b"ffff" inv.get_entry(b"childid").text_size = 0 @@ -1321,12 +1392,18 @@ def test__preload_handles_utf8(self): new_inv._preload_cache() self.assertEqual( sorted([new_inv.root_id, b"fileid", b"dirid", b"childid"]), - sorted(new_inv._fileid_to_entry_cache.keys())) + sorted(new_inv._fileid_to_entry_cache.keys()), + ) ie_root = new_inv._fileid_to_entry_cache[new_inv.root_id] - self.assertEqual(['dir-\N{EURO SIGN}', 'f\xefle'], - [ie.name for ie in new_inv.iter_sorted_children(ie_root.file_id)]) - ie_dir = new_inv._fileid_to_entry_cache[b'dirid'] - self.assertEqual(['ch\xefld'], [ie.name for ie in new_inv.iter_sorted_children(ie_dir.file_id)]) + self.assertEqual( + ["dir-\N{EURO SIGN}", "f\xefle"], + [ie.name for ie in new_inv.iter_sorted_children(ie_root.file_id)], + ) + ie_dir = new_inv._fileid_to_entry_cache[b"dirid"] + self.assertEqual( + ["ch\xefld"], + [ie.name for ie in new_inv.iter_sorted_children(ie_dir.file_id)], + ) def test__preload_populates_cache(self): inv = Inventory() @@ -1354,74 +1431,90 @@ def test__preload_populates_cache(self): new_inv._preload_cache() self.assertEqual( sorted([root_id, b"fileid", b"dirid", b"childid"]), - sorted(new_inv._fileid_to_entry_cache.keys())) + sorted(new_inv._fileid_to_entry_cache.keys()), + ) 
self.assertTrue(new_inv._fully_cached) ie_root = new_inv._fileid_to_entry_cache[root_id] - self.assertEqual(['dir', 'file'], [ie.name for ie in new_inv.iter_sorted_children(ie_root.file_id)]) - ie_dir = new_inv._fileid_to_entry_cache[b'dirid'] - self.assertEqual(['child'], [ie.name for ie in new_inv.iter_sorted_children(ie_dir.file_id)]) + self.assertEqual( + ["dir", "file"], + [ie.name for ie in new_inv.iter_sorted_children(ie_root.file_id)], + ) + ie_dir = new_inv._fileid_to_entry_cache[b"dirid"] + self.assertEqual( + ["child"], [ie.name for ie in new_inv.iter_sorted_children(ie_dir.file_id)] + ) def test__preload_handles_partially_evaluated_inventory(self): new_inv = self.make_basic_utf8_inventory() ie = new_inv.get_entry(new_inv.root_id) - self.assertEqual(['dir-\N{EURO SIGN}', 'f\xefle'], - [c.name for c in new_inv.iter_sorted_children(ie.file_id)]) + self.assertEqual( + ["dir-\N{EURO SIGN}", "f\xefle"], + [c.name for c in new_inv.iter_sorted_children(ie.file_id)], + ) new_inv._preload_cache() # No change - self.assertEqual(['dir-\N{EURO SIGN}', 'f\xefle'], - [c.name for c in new_inv.iter_sorted_children(ie.file_id)]) - self.assertEqual(['ch\xefld'], - [c.name for c in new_inv.iter_sorted_children(b"dirid")]) + self.assertEqual( + ["dir-\N{EURO SIGN}", "f\xefle"], + [c.name for c in new_inv.iter_sorted_children(ie.file_id)], + ) + self.assertEqual( + ["ch\xefld"], [c.name for c in new_inv.iter_sorted_children(b"dirid")] + ) def test_filter_change_in_renamed_subfolder(self): - inv = Inventory(b'tree-root') - inv.root.revision = b'rootrev' - src_ie = inv.add_path('src', 'directory', b'src-id') - src_ie.revision = b'srcrev' - sub_ie = inv.add_path('src/sub/', 'directory', b'sub-id') - sub_ie.revision = b'subrev' - a_ie = inv.add_path('src/sub/a', 'file', b'a-id') - a_ie.revision = b'filerev' - a_ie.text_sha1 = osutils.sha_string(b'content\n') - a_ie.text_size = len(b'content\n') + inv = Inventory(b"tree-root") + inv.root.revision = b"rootrev" + src_ie = inv.add_path("src", "directory", b"src-id") + src_ie.revision = b"srcrev" + sub_ie = inv.add_path("src/sub/", "directory", b"sub-id") + sub_ie.revision = b"subrev" + a_ie = inv.add_path("src/sub/a", "file", b"a-id") + a_ie.revision = b"filerev" + a_ie.text_sha1 = osutils.sha_string(b"content\n") + a_ie.text_size = len(b"content\n") chk_bytes = self.get_chk_bytes() inv = CHKInventory.from_inventory(chk_bytes, inv) - inv = inv.create_by_apply_delta(InventoryDelta([ - ("src/sub/a", "src/sub/a", b"a-id", a_ie), - ("src", "src2", b"src-id", src_ie), - ]), b'new-rev-2') - new_inv = inv.filter([b'a-id', b'src-id']) - self.assertEqual([ - ('', b'tree-root'), - ('src', b'src-id'), - ('src/sub', b'sub-id'), - ('src/sub/a', b'a-id'), - ], [(path, ie.file_id) for path, ie in new_inv.iter_entries()]) + inv = inv.create_by_apply_delta( + InventoryDelta( + [ + ("src/sub/a", "src/sub/a", b"a-id", a_ie), + ("src", "src2", b"src-id", src_ie), + ] + ), + b"new-rev-2", + ) + new_inv = inv.filter([b"a-id", b"src-id"]) + self.assertEqual( + [ + ("", b"tree-root"), + ("src", b"src-id"), + ("src/sub", b"sub-id"), + ("src/sub/a", b"a-id"), + ], + [(path, ie.file_id) for path, ie in new_inv.iter_entries()], + ) class TestCHKInventoryExpand(tests.TestCaseWithMemoryTransport): - def get_chk_bytes(self): factory = groupcompress.make_pack_factory(True, True, 1) - trans = self.get_transport('') + trans = self.get_transport("") return factory(trans) def make_dir(self, inv, name, parent_id, revision): - ie = inv.make_entry('directory', name, parent_id, - 
name.encode('utf-8') + b'-id') + ie = inv.make_entry("directory", name, parent_id, name.encode("utf-8") + b"-id") ie.revision = revision inv.add(ie) - def make_file(self, inv, name, parent_id, revision, content=b'content\n'): - ie = inv.make_entry('file', name, parent_id, - name.encode('utf-8') + b'-id') + def make_file(self, inv, name, parent_id, revision, content=b"content\n"): + ie = inv.make_entry("file", name, parent_id, name.encode("utf-8") + b"-id") ie.text_sha1 = osutils.sha_string(content) ie.text_size = len(content) ie.revision = revision inv.add(ie) def make_simple_inventory(self): - inv = Inventory(b'TREE_ROOT') + inv = Inventory(b"TREE_ROOT") inv.revision_id = b"revid" inv.root.revision = b"rootrev" # / TREE_ROOT @@ -1433,39 +1526,40 @@ def make_simple_inventory(self): # dir2/ dir2-id # sub2-file1 sub2-file1-id # top top-id - self.make_dir(inv, 'dir1', b'TREE_ROOT', b'dirrev') - self.make_dir(inv, 'dir2', b'TREE_ROOT', b'dirrev') - self.make_dir(inv, 'sub-dir1', b'dir1-id', b'dirrev') - self.make_file(inv, 'top', b'TREE_ROOT', b'filerev') - self.make_file(inv, 'sub-file1', b'dir1-id', b'filerev') - self.make_file(inv, 'sub-file2', b'dir1-id', b'filerev') - self.make_file(inv, 'subsub-file1', b'sub-dir1-id', b'filerev') - self.make_file(inv, 'sub2-file1', b'dir2-id', b'filerev') + self.make_dir(inv, "dir1", b"TREE_ROOT", b"dirrev") + self.make_dir(inv, "dir2", b"TREE_ROOT", b"dirrev") + self.make_dir(inv, "sub-dir1", b"dir1-id", b"dirrev") + self.make_file(inv, "top", b"TREE_ROOT", b"filerev") + self.make_file(inv, "sub-file1", b"dir1-id", b"filerev") + self.make_file(inv, "sub-file2", b"dir1-id", b"filerev") + self.make_file(inv, "subsub-file1", b"sub-dir1-id", b"filerev") + self.make_file(inv, "sub2-file1", b"dir2-id", b"filerev") chk_bytes = self.get_chk_bytes() # use a small maximum_size to force internal paging structures - chk_inv = CHKInventory.from_inventory(chk_bytes, inv, - maximum_size=100, - search_key_name=b'hash-255-way') + chk_inv = CHKInventory.from_inventory( + chk_bytes, inv, maximum_size=100, search_key_name=b"hash-255-way" + ) lines = chk_inv.to_lines() return CHKInventory.deserialise(chk_bytes, lines, (b"revid",)) def assert_Getitems(self, expected_fileids, inv, file_ids): - self.assertEqual(sorted(expected_fileids), - sorted([ie.file_id for ie in inv._getitems(file_ids)])) + self.assertEqual( + sorted(expected_fileids), + sorted([ie.file_id for ie in inv._getitems(file_ids)]), + ) def assertExpand(self, all_ids, inv, file_ids): - (val_all_ids, - val_children) = inv._expand_fileids_to_parents_and_children(file_ids) + (val_all_ids, val_children) = inv._expand_fileids_to_parents_and_children( + file_ids + ) self.assertEqual(set(all_ids), val_all_ids) entries = inv._getitems(val_all_ids) expected_children = {} for entry in entries: s = expected_children.setdefault(entry.parent_id, []) s.append(entry.file_id) - val_children = { - k: sorted(v) for k, v in val_children.items()} - expected_children = { - k: sorted(v) for k, v in expected_children.items()} + val_children = {k: sorted(v) for k, v in val_children.items()} + expected_children = {k: sorted(v) for k, v in expected_children.items()} self.assertEqual(expected_children, val_children) def test_make_simple_inventory(self): @@ -1473,101 +1567,135 @@ def test_make_simple_inventory(self): layout = [] for path, entry in inv.iter_entries_by_dir(): layout.append((path, entry.file_id)) - self.assertEqual([ - ('', b'TREE_ROOT'), - ('dir1', b'dir1-id'), - ('dir2', b'dir2-id'), - ('top', b'top-id'), - 
('dir1/sub-dir1', b'sub-dir1-id'), - ('dir1/sub-file1', b'sub-file1-id'), - ('dir1/sub-file2', b'sub-file2-id'), - ('dir1/sub-dir1/subsub-file1', b'subsub-file1-id'), - ('dir2/sub2-file1', b'sub2-file1-id'), - ], layout) + self.assertEqual( + [ + ("", b"TREE_ROOT"), + ("dir1", b"dir1-id"), + ("dir2", b"dir2-id"), + ("top", b"top-id"), + ("dir1/sub-dir1", b"sub-dir1-id"), + ("dir1/sub-file1", b"sub-file1-id"), + ("dir1/sub-file2", b"sub-file2-id"), + ("dir1/sub-dir1/subsub-file1", b"subsub-file1-id"), + ("dir2/sub2-file1", b"sub2-file1-id"), + ], + layout, + ) def test__getitems(self): inv = self.make_simple_inventory() # Reading from disk - self.assert_Getitems([b'dir1-id'], inv, [b'dir1-id']) - self.assertIn(b'dir1-id', inv._fileid_to_entry_cache) - self.assertNotIn(b'sub-file2-id', inv._fileid_to_entry_cache) + self.assert_Getitems([b"dir1-id"], inv, [b"dir1-id"]) + self.assertIn(b"dir1-id", inv._fileid_to_entry_cache) + self.assertNotIn(b"sub-file2-id", inv._fileid_to_entry_cache) # From cache - self.assert_Getitems([b'dir1-id'], inv, [b'dir1-id']) + self.assert_Getitems([b"dir1-id"], inv, [b"dir1-id"]) # Mixed - self.assert_Getitems([b'dir1-id', b'sub-file2-id'], inv, - [b'dir1-id', b'sub-file2-id']) - self.assertIn(b'dir1-id', inv._fileid_to_entry_cache) - self.assertIn(b'sub-file2-id', inv._fileid_to_entry_cache) + self.assert_Getitems( + [b"dir1-id", b"sub-file2-id"], inv, [b"dir1-id", b"sub-file2-id"] + ) + self.assertIn(b"dir1-id", inv._fileid_to_entry_cache) + self.assertIn(b"sub-file2-id", inv._fileid_to_entry_cache) def test_single_file(self): inv = self.make_simple_inventory() - self.assertExpand([b'TREE_ROOT', b'top-id'], inv, [b'top-id']) + self.assertExpand([b"TREE_ROOT", b"top-id"], inv, [b"top-id"]) def test_get_all_parents(self): inv = self.make_simple_inventory() - self.assertExpand([b'TREE_ROOT', b'dir1-id', b'sub-dir1-id', - b'subsub-file1-id', - ], inv, [b'subsub-file1-id']) + self.assertExpand( + [ + b"TREE_ROOT", + b"dir1-id", + b"sub-dir1-id", + b"subsub-file1-id", + ], + inv, + [b"subsub-file1-id"], + ) def test_get_children(self): inv = self.make_simple_inventory() - self.assertExpand([b'TREE_ROOT', b'dir1-id', b'sub-dir1-id', - b'sub-file1-id', b'sub-file2-id', b'subsub-file1-id', - ], inv, [b'dir1-id']) + self.assertExpand( + [ + b"TREE_ROOT", + b"dir1-id", + b"sub-dir1-id", + b"sub-file1-id", + b"sub-file2-id", + b"subsub-file1-id", + ], + inv, + [b"dir1-id"], + ) def test_from_root(self): inv = self.make_simple_inventory() - self.assertExpand([b'TREE_ROOT', b'dir1-id', b'dir2-id', b'sub-dir1-id', - b'sub-file1-id', b'sub-file2-id', b'sub2-file1-id', - b'subsub-file1-id', b'top-id'], inv, [b'TREE_ROOT']) + self.assertExpand( + [ + b"TREE_ROOT", + b"dir1-id", + b"dir2-id", + b"sub-dir1-id", + b"sub-file1-id", + b"sub-file2-id", + b"sub2-file1-id", + b"subsub-file1-id", + b"top-id", + ], + inv, + [b"TREE_ROOT"], + ) def test_top_level_file(self): inv = self.make_simple_inventory() - self.assertExpand([b'TREE_ROOT', b'top-id'], inv, [b'top-id']) + self.assertExpand([b"TREE_ROOT", b"top-id"], inv, [b"top-id"]) def test_subsub_file(self): inv = self.make_simple_inventory() - self.assertExpand([b'TREE_ROOT', b'dir1-id', b'sub-dir1-id', - b'subsub-file1-id'], inv, [b'subsub-file1-id']) + self.assertExpand( + [b"TREE_ROOT", b"dir1-id", b"sub-dir1-id", b"subsub-file1-id"], + inv, + [b"subsub-file1-id"], + ) def test_sub_and_root(self): inv = self.make_simple_inventory() - self.assertExpand([b'TREE_ROOT', b'dir1-id', b'sub-dir1-id', b'top-id', - b'subsub-file1-id'], 
inv, [b'top-id', b'subsub-file1-id']) + self.assertExpand( + [b"TREE_ROOT", b"dir1-id", b"sub-dir1-id", b"top-id", b"subsub-file1-id"], + inv, + [b"top-id", b"subsub-file1-id"], + ) class TestMutableInventoryFromTree(TestCaseWithTransport): - def test_empty(self): - repository = self.make_repository('.') + repository = self.make_repository(".") tree = repository.revision_tree(revision.NULL_REVISION) inv = mutable_inventory_from_tree(tree) self.assertEqual(revision.NULL_REVISION, inv.revision_id) self.assertEqual(0, len(inv)) def test_some_files(self): - wt = self.make_branch_and_tree('.') - self.build_tree(['a']) - wt.add(['a'], ids=[b'thefileid']) + wt = self.make_branch_and_tree(".") + self.build_tree(["a"]) + wt.add(["a"], ids=[b"thefileid"]) revid = wt.commit("commit") tree = wt.branch.repository.revision_tree(revid) inv = mutable_inventory_from_tree(tree) self.assertEqual(revid, inv.revision_id) self.assertEqual(2, len(inv)) - self.assertEqual("a", inv.get_entry(b'thefileid').name) + self.assertEqual("a", inv.get_entry(b"thefileid").name) # The inventory should be mutable and independent of # the original tree - self.assertFalse(tree.root_inventory.get_entry( - b'thefileid').executable) - inv.get_entry(b'thefileid').executable = True - self.assertFalse(tree.root_inventory.get_entry( - b'thefileid').executable) + self.assertFalse(tree.root_inventory.get_entry(b"thefileid").executable) + inv.get_entry(b"thefileid").executable = True + self.assertFalse(tree.root_inventory.get_entry(b"thefileid").executable) class ErrorTests(TestCase): - def test_duplicate_file_id(self): - error = DuplicateFileId('a_file_id', 'foo') + error = DuplicateFileId("a_file_id", "foo") self.assertEqualDiff( - 'File id {a_file_id} already exists in inventory as foo', - str(error)) + "File id {a_file_id} already exists in inventory as foo", str(error) + ) diff --git a/breezy/bzr/tests/test_inventory_delta.py b/breezy/bzr/tests/test_inventory_delta.py index ef6a73316a..79555a502b 100644 --- a/breezy/bzr/tests/test_inventory_delta.py +++ b/breezy/bzr/tests/test_inventory_delta.py @@ -92,35 +92,37 @@ class TestDeserialization(TestCase): def test_parse_no_bytes(self): deserializer = inventory_delta.InventoryDeltaDeserializer() - err = self.assertRaises( - InventoryDeltaError, deserializer.parse_text_bytes, []) - self.assertContainsRe(str(err), 'inventory delta is empty') + err = self.assertRaises(InventoryDeltaError, deserializer.parse_text_bytes, []) + self.assertContainsRe(str(err), "inventory delta is empty") def test_parse_bad_format(self): deserializer = inventory_delta.InventoryDeltaDeserializer() - err = self.assertRaises(InventoryDeltaError, - deserializer.parse_text_bytes, [b'format: foo\n']) - self.assertContainsRe(str(err), 'unknown format') + err = self.assertRaises( + InventoryDeltaError, deserializer.parse_text_bytes, [b"format: foo\n"] + ) + self.assertContainsRe(str(err), "unknown format") def test_parse_no_parent(self): deserializer = inventory_delta.InventoryDeltaDeserializer() - err = self.assertRaises(InventoryDeltaError, - deserializer.parse_text_bytes, - [b'format: bzr inventory delta v1 (bzr 1.14)\n']) - self.assertContainsRe(str(err), 'missing parent: marker') + err = self.assertRaises( + InventoryDeltaError, + deserializer.parse_text_bytes, + [b"format: bzr inventory delta v1 (bzr 1.14)\n"], + ) + self.assertContainsRe(str(err), "missing parent: marker") def test_parse_no_version(self): deserializer = inventory_delta.InventoryDeltaDeserializer() - err = self.assertRaises(InventoryDeltaError, - 
deserializer.parse_text_bytes, - [b'format: bzr inventory delta v1 (bzr 1.14)\n', - b'parent: null:\n']) - self.assertContainsRe(str(err), 'missing version: marker') + err = self.assertRaises( + InventoryDeltaError, + deserializer.parse_text_bytes, + [b"format: bzr inventory delta v1 (bzr 1.14)\n", b"parent: null:\n"], + ) + self.assertContainsRe(str(err), "missing version: marker") def test_parse_duplicate_key_errors(self): deserializer = inventory_delta.InventoryDeltaDeserializer() - double_root_lines = \ - b"""format: bzr inventory delta v1 (bzr 1.14) + double_root_lines = b"""format: bzr inventory delta v1 (bzr 1.14) parent: null: version: null: versioned_root: true @@ -128,20 +130,30 @@ def test_parse_duplicate_key_errors(self): None\x00/\x00an-id\x00\x00a@e\xc3\xa5ample.com--2004\x00dir\x00\x00 None\x00/\x00an-id\x00\x00a@e\xc3\xa5ample.com--2004\x00dir\x00\x00 """ - err = self.assertRaises(InventoryDeltaError, - deserializer.parse_text_bytes, osutils.split_lines(double_root_lines)) - self.assertContainsRe(str(err), 'duplicate file id') + err = self.assertRaises( + InventoryDeltaError, + deserializer.parse_text_bytes, + osutils.split_lines(double_root_lines), + ) + self.assertContainsRe(str(err), "duplicate file id") def test_parse_versioned_root_only(self): deserializer = inventory_delta.InventoryDeltaDeserializer() - parse_result = deserializer.parse_text_bytes(osutils.split_lines(root_only_lines)) - expected_entry = inventory.make_entry( - 'directory', '', None, b'an-id') - expected_entry.revision = b'a@e\xc3\xa5ample.com--2004' + parse_result = deserializer.parse_text_bytes( + osutils.split_lines(root_only_lines) + ) + expected_entry = inventory.make_entry("directory", "", None, b"an-id") + expected_entry.revision = b"a@e\xc3\xa5ample.com--2004" self.assertEqual( - (b'null:', b'entry-version', True, True, - InventoryDelta([(None, '', b'an-id', expected_entry)])), - parse_result) + ( + b"null:", + b"entry-version", + True, + True, + InventoryDelta([(None, "", b"an-id", expected_entry)]), + ), + parse_result, + ) def test_parse_special_revid_not_valid_last_mod(self): deserializer = inventory_delta.InventoryDeltaDeserializer() @@ -153,9 +165,11 @@ def test_parse_special_revid_not_valid_last_mod(self): None\x00/\x00TREE_ROOT\x00\x00null:\x00dir\x00\x00 """ err = self.assertRaises( - InventoryDeltaError, deserializer.parse_text_bytes, - osutils.split_lines(root_only_lines)) - self.assertContainsRe(str(err), 'special revisionid found') + InventoryDeltaError, + deserializer.parse_text_bytes, + osutils.split_lines(root_only_lines), + ) + self.assertContainsRe(str(err), "special revisionid found") def test_parse_versioned_root_versioned_disabled(self): deserializer = inventory_delta.InventoryDeltaDeserializer() @@ -167,9 +181,11 @@ def test_parse_versioned_root_versioned_disabled(self): None\x00/\x00TREE_ROOT\x00\x00a@e\xc3\xa5ample.com--2004\x00dir\x00\x00 """ err = self.assertRaises( - InventoryDeltaError, deserializer.parse_text_bytes, - osutils.split_lines(root_only_lines)) - self.assertContainsRe(str(err), 'Versioned root found') + InventoryDeltaError, + deserializer.parse_text_bytes, + osutils.split_lines(root_only_lines), + ) + self.assertContainsRe(str(err), "Versioned root found") def test_parse_unique_root_id_root_versioned_disabled(self): deserializer = inventory_delta.InventoryDeltaDeserializer() @@ -180,34 +196,51 @@ def test_parse_unique_root_id_root_versioned_disabled(self): tree_references: true None\x00/\x00an-id\x00\x00parent-id\x00dir\x00\x00 """ - err = 
self.assertRaises(InventoryDeltaError, - deserializer.parse_text_bytes, osutils.split_lines(root_only_lines)) - self.assertContainsRe(str(err), 'Versioned root found') + err = self.assertRaises( + InventoryDeltaError, + deserializer.parse_text_bytes, + osutils.split_lines(root_only_lines), + ) + self.assertContainsRe(str(err), "Versioned root found") def test_parse_unversioned_root_versioning_enabled(self): deserializer = inventory_delta.InventoryDeltaDeserializer() parse_result = deserializer.parse_text_bytes( - osutils.split_lines(root_only_unversioned)) - expected_entry = inventory.make_entry( - 'directory', '', None, b'TREE_ROOT') - expected_entry.revision = b'entry-version' + osutils.split_lines(root_only_unversioned) + ) + expected_entry = inventory.make_entry("directory", "", None, b"TREE_ROOT") + expected_entry.revision = b"entry-version" self.assertEqual( - (b'null:', b'entry-version', False, False, - InventoryDelta([(None, '', b'TREE_ROOT', expected_entry)])), - parse_result) + ( + b"null:", + b"entry-version", + False, + False, + InventoryDelta([(None, "", b"TREE_ROOT", expected_entry)]), + ), + parse_result, + ) def test_parse_versioned_root_when_disabled(self): deserializer = inventory_delta.InventoryDeltaDeserializer( - allow_versioned_root=False) - err = self.assertRaises(inventory_delta.IncompatibleInventoryDelta, - deserializer.parse_text_bytes, osutils.split_lines(root_only_lines)) + allow_versioned_root=False + ) + err = self.assertRaises( + inventory_delta.IncompatibleInventoryDelta, + deserializer.parse_text_bytes, + osutils.split_lines(root_only_lines), + ) self.assertEqual("versioned_root not allowed", str(err)) def test_parse_tree_when_disabled(self): deserializer = inventory_delta.InventoryDeltaDeserializer( - allow_tree_references=False) - err = self.assertRaises(inventory_delta.IncompatibleInventoryDelta, - deserializer.parse_text_bytes, osutils.split_lines(reference_lines)) + allow_tree_references=False + ) + err = self.assertRaises( + inventory_delta.IncompatibleInventoryDelta, + deserializer.parse_text_bytes, + osutils.split_lines(reference_lines), + ) self.assertEqual("Tree reference not allowed", str(err)) def test_parse_tree_when_header_disallows(self): @@ -222,9 +255,12 @@ def test_parse_tree_when_header_disallows(self): tree_references: false None\x00/foo\x00id\x00TREE_ROOT\x00changed\x00tree\x00subtree-version """ - err = self.assertRaises(InventoryDeltaError, - deserializer.parse_text_bytes, osutils.split_lines(lines)) - self.assertContainsRe(str(err), 'Tree reference found') + err = self.assertRaises( + InventoryDeltaError, + deserializer.parse_text_bytes, + osutils.split_lines(lines), + ) + self.assertContainsRe(str(err), "Tree reference found") def test_parse_versioned_root_when_header_disallows(self): # A deserializer that allows tree_references to be set or unset. 
@@ -238,64 +274,75 @@ def test_parse_versioned_root_when_header_disallows(self): tree_references: false None\x00/\x00TREE_ROOT\x00\x00a@e\xc3\xa5ample.com--2004\x00dir """ - err = self.assertRaises(InventoryDeltaError, - deserializer.parse_text_bytes, osutils.split_lines(lines)) - self.assertContainsRe(str(err), 'Versioned root found') + err = self.assertRaises( + InventoryDeltaError, + deserializer.parse_text_bytes, + osutils.split_lines(lines), + ) + self.assertContainsRe(str(err), "Versioned root found") def test_parse_last_line_not_empty(self): """Newpath must start with / if it is not None.""" # Trim the trailing newline from a valid serialization lines = root_only_lines[:-1] deserializer = inventory_delta.InventoryDeltaDeserializer() - err = self.assertRaises(InventoryDeltaError, - deserializer.parse_text_bytes, osutils.split_lines(lines)) - self.assertContainsRe(str(err), 'last line not empty') + err = self.assertRaises( + InventoryDeltaError, + deserializer.parse_text_bytes, + osutils.split_lines(lines), + ) + self.assertContainsRe(str(err), "last line not empty") def test_parse_invalid_newpath(self): """Newpath must start with / if it is not None.""" lines = empty_lines lines += b"None\x00bad\x00TREE_ROOT\x00\x00version\x00dir\n" deserializer = inventory_delta.InventoryDeltaDeserializer() - err = self.assertRaises(InventoryDeltaError, - deserializer.parse_text_bytes, osutils.split_lines(lines)) - self.assertContainsRe(str(err), 'newpath invalid') + err = self.assertRaises( + InventoryDeltaError, + deserializer.parse_text_bytes, + osutils.split_lines(lines), + ) + self.assertContainsRe(str(err), "newpath invalid") def test_parse_invalid_oldpath(self): """Oldpath must start with / if it is not None.""" lines = root_only_lines lines += b"bad\x00/new\x00file-id\x00\x00version\x00dir\n" deserializer = inventory_delta.InventoryDeltaDeserializer() - err = self.assertRaises(InventoryDeltaError, - deserializer.parse_text_bytes, osutils.split_lines(lines)) - self.assertContainsRe(str(err), 'oldpath invalid') + err = self.assertRaises( + InventoryDeltaError, + deserializer.parse_text_bytes, + osutils.split_lines(lines), + ) + self.assertContainsRe(str(err), "oldpath invalid") def test_parse_new_file(self): """A new file is parsed correctly.""" lines = root_only_lines fake_sha = b"deadbeef" * 5 lines += ( - b"None\x00/new\x00file-id\x00an-id\x00version\x00file\x00123\x00" + - b"\x00" + fake_sha + b"\n") + b"None\x00/new\x00file-id\x00an-id\x00version\x00file\x00123\x00" + + b"\x00" + + fake_sha + + b"\n" + ) deserializer = inventory_delta.InventoryDeltaDeserializer() parse_result = deserializer.parse_text_bytes(osutils.split_lines(lines)) - expected_entry = inventory.make_entry( - 'file', 'new', b'an-id', b'file-id') - expected_entry.revision = b'version' + expected_entry = inventory.make_entry("file", "new", b"an-id", b"file-id") + expected_entry.revision = b"version" expected_entry.text_size = 123 expected_entry.text_sha1 = fake_sha delta = parse_result[4] - self.assertEqual( - (None, 'new', b'file-id', expected_entry), delta[-1]) + self.assertEqual((None, "new", b"file-id", expected_entry), delta[-1]) def test_parse_delete(self): lines = root_only_lines - lines += ( - b"/old-file\x00None\x00deleted-id\x00\x00null:\x00deleted\x00\x00\n") + lines += b"/old-file\x00None\x00deleted-id\x00\x00null:\x00deleted\x00\x00\n" deserializer = inventory_delta.InventoryDeltaDeserializer() parse_result = deserializer.parse_text_bytes(osutils.split_lines(lines)) delta = parse_result[4] - self.assertEqual( 
- ('old-file', None, b'deleted-id', None), delta[-1]) + self.assertEqual(("old-file", None, b"deleted-id", None), delta[-1]) class TestSerialization(TestCase): @@ -306,154 +353,210 @@ def test_empty_delta_to_lines(self): new_inv = Inventory(None) delta = new_inv._make_delta(old_inv) serializer = inventory_delta.InventoryDeltaSerializer( - versioned_root=True, tree_references=True) - self.assertEqual(BytesIO(empty_lines).readlines(), - serializer.delta_to_lines(NULL_REVISION, NULL_REVISION, delta)) + versioned_root=True, tree_references=True + ) + self.assertEqual( + BytesIO(empty_lines).readlines(), + serializer.delta_to_lines(NULL_REVISION, NULL_REVISION, delta), + ) def test_root_only_to_lines(self): old_inv = Inventory(None) new_inv = Inventory(None) - root = new_inv.make_entry('directory', '', None, b'an-id') - root.revision = b'a@e\xc3\xa5ample.com--2004' + root = new_inv.make_entry("directory", "", None, b"an-id") + root.revision = b"a@e\xc3\xa5ample.com--2004" new_inv.add(root) delta = new_inv._make_delta(old_inv) serializer = inventory_delta.InventoryDeltaSerializer( - versioned_root=True, tree_references=True) - self.assertEqual(BytesIO(root_only_lines).readlines(), - serializer.delta_to_lines(NULL_REVISION, b'entry-version', delta)) + versioned_root=True, tree_references=True + ) + self.assertEqual( + BytesIO(root_only_lines).readlines(), + serializer.delta_to_lines(NULL_REVISION, b"entry-version", delta), + ) def test_unversioned_root(self): old_inv = Inventory(None) new_inv = Inventory(None) - root = new_inv.make_entry('directory', '', None, b'TREE_ROOT') + root = new_inv.make_entry("directory", "", None, b"TREE_ROOT") # Implicit roots are considered modified in every revision. - root.revision = b'entry-version' + root.revision = b"entry-version" new_inv.add(root) delta = new_inv._make_delta(old_inv) serializer = inventory_delta.InventoryDeltaSerializer( - versioned_root=False, tree_references=False) + versioned_root=False, tree_references=False + ) serialized_lines = serializer.delta_to_lines( - NULL_REVISION, b'entry-version', delta) - self.assertEqual(BytesIO(root_only_unversioned).readlines(), - serialized_lines) + NULL_REVISION, b"entry-version", delta + ) + self.assertEqual(BytesIO(root_only_unversioned).readlines(), serialized_lines) deserializer = inventory_delta.InventoryDeltaDeserializer() self.assertEqual( - (NULL_REVISION, b'entry-version', False, False, delta), - deserializer.parse_text_bytes(serialized_lines)) + (NULL_REVISION, b"entry-version", False, False, delta), + deserializer.parse_text_bytes(serialized_lines), + ) def test_unversioned_non_root_errors(self): old_inv = Inventory(None) new_inv = Inventory(None) - root = new_inv.make_entry('directory', '', None, b'TREE_ROOT') - root.revision = b'a@e\xc3\xa5ample.com--2004' + root = new_inv.make_entry("directory", "", None, b"TREE_ROOT") + root.revision = b"a@e\xc3\xa5ample.com--2004" new_inv.add(root) - non_root = new_inv.make_entry('directory', 'foo', root.file_id, b'id') + non_root = new_inv.make_entry("directory", "foo", root.file_id, b"id") new_inv.add(non_root) delta = new_inv._make_delta(old_inv) serializer = inventory_delta.InventoryDeltaSerializer( - versioned_root=True, tree_references=True) - err = self.assertRaises(InventoryDeltaError, - serializer.delta_to_lines, NULL_REVISION, b'entry-version', delta) + versioned_root=True, tree_references=True + ) + err = self.assertRaises( + InventoryDeltaError, + serializer.delta_to_lines, + NULL_REVISION, + b"entry-version", + delta, + ) 
self.assertContainsRe(str(err), "^no version for fileid id$") def test_richroot_unversioned_root_errors(self): old_inv = Inventory(None) new_inv = Inventory(None) - root = new_inv.make_entry('directory', '', None, b'TREE_ROOT') + root = new_inv.make_entry("directory", "", None, b"TREE_ROOT") new_inv.add(root) delta = new_inv._make_delta(old_inv) serializer = inventory_delta.InventoryDeltaSerializer( - versioned_root=True, tree_references=True) - err = self.assertRaises(InventoryDeltaError, - serializer.delta_to_lines, NULL_REVISION, b'entry-version', delta) - self.assertContainsRe( - str(err), "no version for fileid TREE_ROOT$") + versioned_root=True, tree_references=True + ) + err = self.assertRaises( + InventoryDeltaError, + serializer.delta_to_lines, + NULL_REVISION, + b"entry-version", + delta, + ) + self.assertContainsRe(str(err), "no version for fileid TREE_ROOT$") def test_nonrichroot_versioned_root_errors(self): old_inv = Inventory(None) new_inv = Inventory(None) - root = new_inv.make_entry('directory', '', None, b'TREE_ROOT') - root.revision = b'a@e\xc3\xa5ample.com--2004' + root = new_inv.make_entry("directory", "", None, b"TREE_ROOT") + root.revision = b"a@e\xc3\xa5ample.com--2004" new_inv.add(root) delta = new_inv._make_delta(old_inv) serializer = inventory_delta.InventoryDeltaSerializer( - versioned_root=False, tree_references=True) - err = self.assertRaises(InventoryDeltaError, - serializer.delta_to_lines, NULL_REVISION, b'entry-version', delta) - self.assertContainsRe( - str(err), "^Version present for / in TREE_ROOT") + versioned_root=False, tree_references=True + ) + err = self.assertRaises( + InventoryDeltaError, + serializer.delta_to_lines, + NULL_REVISION, + b"entry-version", + delta, + ) + self.assertContainsRe(str(err), "^Version present for / in TREE_ROOT") def test_tree_reference_disabled(self): old_inv = Inventory(None) new_inv = Inventory(None) - root = new_inv.make_entry('directory', '', None, b'TREE_ROOT') - root.revision = b'a@e\xc3\xa5ample.com--2004' + root = new_inv.make_entry("directory", "", None, b"TREE_ROOT") + root.revision = b"a@e\xc3\xa5ample.com--2004" new_inv.add(root) - non_root = new_inv.make_entry( - 'tree-reference', 'foo', root.file_id, b'id') - non_root.revision = b'changed' - non_root.reference_revision = b'subtree-version' + non_root = new_inv.make_entry("tree-reference", "foo", root.file_id, b"id") + non_root.revision = b"changed" + non_root.reference_revision = b"subtree-version" new_inv.add(non_root) delta = new_inv._make_delta(old_inv) serializer = inventory_delta.InventoryDeltaSerializer( - versioned_root=True, tree_references=False) + versioned_root=True, tree_references=False + ) # we expect keyerror because there is little value wrapping this. # This test aims to prove that it errors more than how it errors. 
- err = self.assertRaises(KeyError, - serializer.delta_to_lines, NULL_REVISION, b'entry-version', delta) - self.assertEqual(('tree-reference',), err.args) + err = self.assertRaises( + KeyError, serializer.delta_to_lines, NULL_REVISION, b"entry-version", delta + ) + self.assertEqual(("tree-reference",), err.args) def test_tree_reference_enabled(self): old_inv = Inventory(None) new_inv = Inventory(None) - root = new_inv.make_entry('directory', '', None, b'TREE_ROOT') - root.revision = b'a@e\xc3\xa5ample.com--2004' + root = new_inv.make_entry("directory", "", None, b"TREE_ROOT") + root.revision = b"a@e\xc3\xa5ample.com--2004" new_inv.add(root) - non_root = new_inv.make_entry( - 'tree-reference', 'foo', root.file_id, b'id') - non_root.revision = b'changed' - non_root.reference_revision = b'subtree-version' + non_root = new_inv.make_entry("tree-reference", "foo", root.file_id, b"id") + non_root.revision = b"changed" + non_root.reference_revision = b"subtree-version" new_inv.add(non_root) delta = new_inv._make_delta(old_inv) serializer = inventory_delta.InventoryDeltaSerializer( - versioned_root=True, tree_references=True) - self.assertEqual(BytesIO(reference_lines).readlines(), - serializer.delta_to_lines(NULL_REVISION, b'entry-version', delta)) + versioned_root=True, tree_references=True + ) + self.assertEqual( + BytesIO(reference_lines).readlines(), + serializer.delta_to_lines(NULL_REVISION, b"entry-version", delta), + ) def test_to_inventory_root_id_versioned_not_permitted(self): - root_entry = inventory.make_entry('directory', '', None, b'TREE_ROOT') - root_entry.revision = b'some-version' - delta = InventoryDelta([(None, '', b'TREE_ROOT', root_entry)]) + root_entry = inventory.make_entry("directory", "", None, b"TREE_ROOT") + root_entry.revision = b"some-version" + delta = InventoryDelta([(None, "", b"TREE_ROOT", root_entry)]) serializer = inventory_delta.InventoryDeltaSerializer( - versioned_root=False, tree_references=True) + versioned_root=False, tree_references=True + ) self.assertRaises( - InventoryDeltaError, serializer.delta_to_lines, b'old-version', - b'new-version', delta) + InventoryDeltaError, + serializer.delta_to_lines, + b"old-version", + b"new-version", + delta, + ) def test_to_inventory_root_id_not_versioned(self): - delta = InventoryDelta([(None, '', b'an-id', inventory.make_entry( - 'directory', '', None, b'an-id'))]) + delta = InventoryDelta( + [ + ( + None, + "", + b"an-id", + inventory.make_entry("directory", "", None, b"an-id"), + ) + ] + ) serializer = inventory_delta.InventoryDeltaSerializer( - versioned_root=True, tree_references=True) + versioned_root=True, tree_references=True + ) self.assertRaises( - InventoryDeltaError, serializer.delta_to_lines, b'old-version', - b'new-version', delta) + InventoryDeltaError, + serializer.delta_to_lines, + b"old-version", + b"new-version", + delta, + ) def test_to_inventory_has_tree_not_meant_to(self): make_entry = inventory.make_entry - tree_ref = make_entry( - 'tree-reference', 'foo', b'changed-in', b'ref-id') - tree_ref.reference_revision = b'ref-revision' - delta = InventoryDelta([ - (None, '', b'an-id', - make_entry('directory', '', b'changed-in', b'an-id')), - (None, 'foo', b'ref-id', tree_ref) - # a file that followed the root move - ]) + tree_ref = make_entry("tree-reference", "foo", b"changed-in", b"ref-id") + tree_ref.reference_revision = b"ref-revision" + delta = InventoryDelta( + [ + ( + None, + "", + b"an-id", + make_entry("directory", "", b"changed-in", b"an-id"), + ), + (None, "foo", b"ref-id", tree_ref), + # a 
file that followed the root move + ] + ) serializer = inventory_delta.InventoryDeltaSerializer( - versioned_root=True, tree_references=True) - self.assertRaises(InventoryDeltaError, serializer.delta_to_lines, - b'old-version', b'new-version', delta) + versioned_root=True, tree_references=True + ) + self.assertRaises( + InventoryDeltaError, + serializer.delta_to_lines, + b"old-version", + b"new-version", + delta, + ) def test_to_inventory_torture(self): def make_entry(kind, name, parent_id, file_id, **attrs): @@ -461,50 +564,118 @@ def make_entry(kind, name, parent_id, file_id, **attrs): for name, value in attrs.items(): setattr(entry, name, value) return entry + # this delta is crafted to have all the following: # - deletes # - renamed roots # - deep dirs # - files moved after parent dir was renamed # - files with and without exec bit - delta = InventoryDelta([ - # new root: - (None, '', b'new-root-id', - make_entry('directory', '', None, b'new-root-id', - revision=b'changed-in')), - # an old root: - ('', 'old-root', b'TREE_ROOT', - make_entry('directory', 'subdir-now', b'new-root-id', - b'TREE_ROOT', revision=b'moved-root')), - # a file that followed the root move - ('under-old-root', 'old-root/under-old-root', b'moved-id', - make_entry('file', 'under-old-root', b'TREE_ROOT', - b'moved-id', revision=b'old-rev', executable=False, - text_size=30, text_sha1=b'some-sha')), - # a deleted path - ('old-file', None, b'deleted-id', None), - # a tree reference moved to the new root - ('ref', 'ref', b'ref-id', - make_entry('tree-reference', 'ref', b'new-root-id', b'ref-id', - reference_revision=b'tree-reference-id', - revision=b'new-rev')), - # a symlink now in a deep dir - ('dir/link', 'old-root/dir/link', b'link-id', - make_entry('symlink', 'link', b'deep-id', b'link-id', - symlink_target='target', revision=b'new-rev')), - # a deep dir - ('dir', 'old-root/dir', b'deep-id', - make_entry('directory', 'dir', b'TREE_ROOT', b'deep-id', - revision=b'new-rev')), - # a file with an exec bit set - (None, 'configure', b'exec-id', - make_entry('file', 'configure', b'new-root-id', b'exec-id', - executable=True, text_size=30, text_sha1=b'some-sha', - revision=b'old-rev')), - ]) + delta = InventoryDelta( + [ + # new root: + ( + None, + "", + b"new-root-id", + make_entry( + "directory", "", None, b"new-root-id", revision=b"changed-in" + ), + ), + # an old root: + ( + "", + "old-root", + b"TREE_ROOT", + make_entry( + "directory", + "subdir-now", + b"new-root-id", + b"TREE_ROOT", + revision=b"moved-root", + ), + ), + # a file that followed the root move + ( + "under-old-root", + "old-root/under-old-root", + b"moved-id", + make_entry( + "file", + "under-old-root", + b"TREE_ROOT", + b"moved-id", + revision=b"old-rev", + executable=False, + text_size=30, + text_sha1=b"some-sha", + ), + ), + # a deleted path + ("old-file", None, b"deleted-id", None), + # a tree reference moved to the new root + ( + "ref", + "ref", + b"ref-id", + make_entry( + "tree-reference", + "ref", + b"new-root-id", + b"ref-id", + reference_revision=b"tree-reference-id", + revision=b"new-rev", + ), + ), + # a symlink now in a deep dir + ( + "dir/link", + "old-root/dir/link", + b"link-id", + make_entry( + "symlink", + "link", + b"deep-id", + b"link-id", + symlink_target="target", + revision=b"new-rev", + ), + ), + # a deep dir + ( + "dir", + "old-root/dir", + b"deep-id", + make_entry( + "directory", + "dir", + b"TREE_ROOT", + b"deep-id", + revision=b"new-rev", + ), + ), + # a file with an exec bit set + ( + None, + "configure", + b"exec-id", + 
make_entry( + "file", + "configure", + b"new-root-id", + b"exec-id", + executable=True, + text_size=30, + text_sha1=b"some-sha", + revision=b"old-rev", + ), + ), + ] + ) serializer = inventory_delta.InventoryDeltaSerializer( - versioned_root=True, tree_references=True) - lines = serializer.delta_to_lines(NULL_REVISION, b'something', delta) + versioned_root=True, tree_references=True + ) + lines = serializer.delta_to_lines(NULL_REVISION, b"something", delta) expected = b"""format: bzr inventory delta v1 (bzr 1.14) parent: null: version: something @@ -519,7 +690,7 @@ def make_entry(kind, name, parent_id, file_id, **attrs): None\x00/\x00new-root-id\x00\x00changed-in\x00dir None\x00/configure\x00exec-id\x00new-root-id\x00old-rev\x00file\x0030\x00Y\x00some-sha """ - serialized = b''.join(lines) + serialized = b"".join(lines) self.assertIsInstance(serialized, bytes) self.assertEqual(expected, serialized) @@ -528,79 +699,90 @@ class TestContent(TestCase): """Test serialization of the content part of a line.""" def test_dir(self): - entry = inventory.make_entry('directory', 'a dir', b'parent') - self.assertEqual(b'dir', inventory_delta.serialize_inventory_entry(entry)) + entry = inventory.make_entry("directory", "a dir", b"parent") + self.assertEqual(b"dir", inventory_delta.serialize_inventory_entry(entry)) def test_file_0_short_sha(self): - file_entry = inventory.make_entry('file', 'a file', b'parent', b'file-id') - file_entry.text_sha1 = b'' + file_entry = inventory.make_entry("file", "a file", b"parent", b"file-id") + file_entry.text_sha1 = b"" file_entry.text_size = 0 - self.assertEqual(b'file\x000\x00\x00', - inventory_delta.serialize_inventory_entry(file_entry)) + self.assertEqual( + b"file\x000\x00\x00", inventory_delta.serialize_inventory_entry(file_entry) + ) def test_file_10_foo(self): - file_entry = inventory.make_entry('file', 'a file', b'parent', b'file-id') - file_entry.text_sha1 = b'foo' + file_entry = inventory.make_entry("file", "a file", b"parent", b"file-id") + file_entry.text_sha1 = b"foo" file_entry.text_size = 10 - self.assertEqual(b'file\x0010\x00\x00foo', - inventory_delta.serialize_inventory_entry(file_entry)) + self.assertEqual( + b"file\x0010\x00\x00foo", + inventory_delta.serialize_inventory_entry(file_entry), + ) def test_file_executable(self): - file_entry = inventory.make_entry('file', 'a file', b'parent', b'file-id') + file_entry = inventory.make_entry("file", "a file", b"parent", b"file-id") file_entry.executable = True - file_entry.text_sha1 = b'foo' + file_entry.text_sha1 = b"foo" file_entry.text_size = 10 - self.assertEqual(b'file\x0010\x00Y\x00foo', - inventory_delta.serialize_inventory_entry(file_entry)) + self.assertEqual( + b"file\x0010\x00Y\x00foo", + inventory_delta.serialize_inventory_entry(file_entry), + ) def test_file_without_size(self): - file_entry = inventory.make_entry('file', 'a file', b'parent', b'file-id') - file_entry.text_sha1 = b'foo' - self.assertRaises(InventoryDeltaError, - inventory_delta.serialize_inventory_entry, file_entry) + file_entry = inventory.make_entry("file", "a file", b"parent", b"file-id") + file_entry.text_sha1 = b"foo" + self.assertRaises( + InventoryDeltaError, inventory_delta.serialize_inventory_entry, file_entry + ) def test_file_without_sha1(self): - file_entry = inventory.make_entry('file', 'a file', b'parent', b'file-id') + file_entry = inventory.make_entry("file", "a file", b"parent", b"file-id") file_entry.text_size = 10 - self.assertRaises(InventoryDeltaError, - inventory_delta.serialize_inventory_entry, 
file_entry) + self.assertRaises( + InventoryDeltaError, inventory_delta.serialize_inventory_entry, file_entry + ) def test_link_empty_target(self): - entry = inventory.make_entry('symlink', 'a link', b'parent') - entry.symlink_target = '' - self.assertEqual(b'link\x00', - inventory_delta.serialize_inventory_entry(entry)) + entry = inventory.make_entry("symlink", "a link", b"parent") + entry.symlink_target = "" + self.assertEqual(b"link\x00", inventory_delta.serialize_inventory_entry(entry)) def test_link_unicode_target(self): - entry = inventory.make_entry('symlink', 'a link', b'parent') - entry.symlink_target = b' \xc3\xa5'.decode('utf8') - self.assertEqual(b'link\x00 \xc3\xa5', - inventory_delta.serialize_inventory_entry(entry)) + entry = inventory.make_entry("symlink", "a link", b"parent") + entry.symlink_target = b" \xc3\xa5".decode("utf8") + self.assertEqual( + b"link\x00 \xc3\xa5", inventory_delta.serialize_inventory_entry(entry) + ) def test_link_space_target(self): - entry = inventory.make_entry('symlink', 'a link', b'parent') - entry.symlink_target = ' ' - self.assertEqual(b'link\x00 ', - inventory_delta.serialize_inventory_entry(entry)) + entry = inventory.make_entry("symlink", "a link", b"parent") + entry.symlink_target = " " + self.assertEqual(b"link\x00 ", inventory_delta.serialize_inventory_entry(entry)) def test_link_no_target(self): - entry = inventory.make_entry('symlink', 'a link', b'parent') - self.assertRaises(InventoryDeltaError, - inventory_delta.serialize_inventory_entry, entry) + entry = inventory.make_entry("symlink", "a link", b"parent") + self.assertRaises( + InventoryDeltaError, inventory_delta.serialize_inventory_entry, entry + ) def test_reference_null(self): - entry = inventory.make_entry('tree-reference', 'a tree', b'parent') + entry = inventory.make_entry("tree-reference", "a tree", b"parent") entry.reference_revision = NULL_REVISION - self.assertEqual(b'tree\x00null:', - inventory_delta.serialize_inventory_entry(entry)) + self.assertEqual( + b"tree\x00null:", inventory_delta.serialize_inventory_entry(entry) + ) def test_reference_revision(self): - entry = inventory.make_entry('tree-reference', 'a tree', b'parent') - entry.reference_revision = b'foo@\xc3\xa5b-lah' - self.assertEqual(b'tree\x00foo@\xc3\xa5b-lah', - inventory_delta.serialize_inventory_entry(entry)) + entry = inventory.make_entry("tree-reference", "a tree", b"parent") + entry.reference_revision = b"foo@\xc3\xa5b-lah" + self.assertEqual( + b"tree\x00foo@\xc3\xa5b-lah", + inventory_delta.serialize_inventory_entry(entry), + ) def test_reference_no_reference(self): - entry = inventory.make_entry('tree-reference', 'a tree', b'parent') - self.assertRaises(InventoryDeltaError, - inventory_delta.serialize_inventory_entry, entry) + entry = inventory.make_entry("tree-reference", "a tree", b"parent") + self.assertRaises( + InventoryDeltaError, inventory_delta.serialize_inventory_entry, entry + ) diff --git a/breezy/bzr/tests/test_knit.py b/breezy/bzr/tests/test_knit.py index b01d8912c2..8866372778 100644 --- a/breezy/bzr/tests/test_knit.py +++ b/breezy/bzr/tests/test_knit.py @@ -57,39 +57,43 @@ ) compiled_knit_feature = features.ModuleAvailableFeature( - 'breezy.bzr._knit_load_data_pyx') + "breezy.bzr._knit_load_data_pyx" +) class ErrorTests(TestCase): - def test_knit_data_stream_incompatible(self): - error = KnitDataStreamIncompatible( - 'stream format', 'target format') - self.assertEqual('Cannot insert knit data stream of format ' - '"stream format" into knit of format ' - '"target format".', 
str(error)) + error = KnitDataStreamIncompatible("stream format", "target format") + self.assertEqual( + "Cannot insert knit data stream of format " + '"stream format" into knit of format ' + '"target format".', + str(error), + ) def test_knit_data_stream_unknown(self): - error = KnitDataStreamUnknown( - 'stream format') - self.assertEqual('Cannot parse knit data stream of format ' - '"stream format".', str(error)) + error = KnitDataStreamUnknown("stream format") + self.assertEqual( + "Cannot parse knit data stream of format " '"stream format".', str(error) + ) def test_knit_header_error(self): - error = KnitHeaderError('line foo\n', 'path/to/file') - self.assertEqual("Knit header error: 'line foo\\n' unexpected" - " for file \"path/to/file\".", str(error)) + error = KnitHeaderError("line foo\n", "path/to/file") + self.assertEqual( + "Knit header error: 'line foo\\n' unexpected" ' for file "path/to/file".', + str(error), + ) def test_knit_index_unknown_method(self): - error = KnitIndexUnknownMethod('http://host/foo.kndx', - ['bad', 'no-eol']) - self.assertEqual("Knit index http://host/foo.kndx does not have a" - " known method in options: ['bad', 'no-eol']", - str(error)) + error = KnitIndexUnknownMethod("http://host/foo.kndx", ["bad", "no-eol"]) + self.assertEqual( + "Knit index http://host/foo.kndx does not have a" + " known method in options: ['bad', 'no-eol']", + str(error), + ) class KnitContentTestsMixin: - def test_constructor(self): self._make_content([]) @@ -97,13 +101,11 @@ def test_text(self): content = self._make_content([]) self.assertEqual(content.text(), []) - content = self._make_content( - [(b"origin1", b"text1"), (b"origin2", b"text2")]) + content = self._make_content([(b"origin1", b"text1"), (b"origin2", b"text2")]) self.assertEqual(content.text(), [b"text1", b"text2"]) def test_copy(self): - content = self._make_content( - [(b"origin1", b"text1"), (b"origin2", b"text2")]) + content = self._make_content([(b"origin1", b"text1"), (b"origin2", b"text2")]) copy = content.copy() self.assertIsInstance(copy, content.__class__) self.assertEqual(copy.annotate(), content.annotate()) @@ -114,32 +116,32 @@ def assertDerivedBlocksEqual(self, source, target, noeol=False): target_lines = target.splitlines(True) def nl(line): - if noeol and not line.endswith('\n'): - return line + '\n' + if noeol and not line.endswith("\n"): + return line + "\n" else: return line - source_content = self._make_content( - [(None, nl(l)) for l in source_lines]) - target_content = self._make_content( - [(None, nl(l)) for l in target_lines]) + + source_content = self._make_content([(None, nl(l)) for l in source_lines]) + target_content = self._make_content([(None, nl(l)) for l in target_lines]) line_delta = source_content.line_delta(target_content) - delta_blocks = list(KnitContent.get_line_delta_blocks(line_delta, - source_lines, target_lines)) + delta_blocks = list( + KnitContent.get_line_delta_blocks(line_delta, source_lines, target_lines) + ) matcher = PatienceSequenceMatcher(None, source_lines, target_lines) matcher_blocks = list(matcher.get_matching_blocks()) self.assertEqual(matcher_blocks, delta_blocks) def test_get_line_delta_blocks(self): - self.assertDerivedBlocksEqual('a\nb\nc\n', 'q\nc\n') + self.assertDerivedBlocksEqual("a\nb\nc\n", "q\nc\n") self.assertDerivedBlocksEqual(TEXT_1, TEXT_1) self.assertDerivedBlocksEqual(TEXT_1, TEXT_1A) self.assertDerivedBlocksEqual(TEXT_1, TEXT_1B) self.assertDerivedBlocksEqual(TEXT_1B, TEXT_1A) self.assertDerivedBlocksEqual(TEXT_1A, TEXT_1B) - 
self.assertDerivedBlocksEqual(TEXT_1A, '') - self.assertDerivedBlocksEqual('', TEXT_1A) - self.assertDerivedBlocksEqual('', '') - self.assertDerivedBlocksEqual('a\nb\nc', 'a\nb\nc\nd') + self.assertDerivedBlocksEqual(TEXT_1A, "") + self.assertDerivedBlocksEqual("", TEXT_1A) + self.assertDerivedBlocksEqual("", "") + self.assertDerivedBlocksEqual("a\nb\nc", "a\nb\nc\nd") def test_get_line_delta_blocks_noeol(self): """Handle historical knit deltas safely. @@ -150,10 +152,10 @@ def test_get_line_delta_blocks_noeol(self): New knit deltas appear to always consider the last line to differ in this case. """ - self.assertDerivedBlocksEqual('a\nb\nc', 'a\nb\nc\nd\n', noeol=True) - self.assertDerivedBlocksEqual('a\nb\nc\nd\n', 'a\nb\nc', noeol=True) - self.assertDerivedBlocksEqual('a\nb\nc\n', 'a\nb\nc', noeol=True) - self.assertDerivedBlocksEqual('a\nb\nc', 'a\nb\nc\n', noeol=True) + self.assertDerivedBlocksEqual("a\nb\nc", "a\nb\nc\nd\n", noeol=True) + self.assertDerivedBlocksEqual("a\nb\nc\nd\n", "a\nb\nc", noeol=True) + self.assertDerivedBlocksEqual("a\nb\nc\n", "a\nb\nc", noeol=True) + self.assertDerivedBlocksEqual("a\nb\nc", "a\nb\nc\n", noeol=True) TEXT_1 = """\ @@ -202,25 +204,21 @@ def test_get_line_delta_blocks_noeol(self): class TestPlainKnitContent(TestCase, KnitContentTestsMixin): - def _make_content(self, lines): annotated_content = AnnotatedKnitContent(lines) - return PlainKnitContent(annotated_content.text(), 'bogus') + return PlainKnitContent(annotated_content.text(), "bogus") def test_annotate(self): content = self._make_content([]) self.assertEqual(content.annotate(), []) - content = self._make_content( - [("origin1", "text1"), ("origin2", "text2")]) - self.assertEqual(content.annotate(), - [("bogus", "text1"), ("bogus", "text2")]) + content = self._make_content([("origin1", "text1"), ("origin2", "text2")]) + self.assertEqual(content.annotate(), [("bogus", "text1"), ("bogus", "text2")]) def test_line_delta(self): content1 = self._make_content([("", "a"), ("", "b")]) content2 = self._make_content([("", "a"), ("", "a"), ("", "c")]) - self.assertEqual(content1.line_delta(content2), - [(1, 2, 2, ["a", "c"])]) + self.assertEqual(content1.line_delta(content2), [(1, 2, 2, ["a", "c"])]) def test_line_delta_iter(self): content1 = self._make_content([("", "a"), ("", "b")]) @@ -231,7 +229,6 @@ def test_line_delta_iter(self): class TestAnnotatedKnitContent(TestCase, KnitContentTestsMixin): - def _make_content(self, lines): return AnnotatedKnitContent(lines) @@ -239,16 +236,17 @@ def test_annotate(self): content = self._make_content([]) self.assertEqual(content.annotate(), []) - content = self._make_content( - [(b"origin1", b"text1"), (b"origin2", b"text2")]) - self.assertEqual(content.annotate(), - [(b"origin1", b"text1"), (b"origin2", b"text2")]) + content = self._make_content([(b"origin1", b"text1"), (b"origin2", b"text2")]) + self.assertEqual( + content.annotate(), [(b"origin1", b"text1"), (b"origin2", b"text2")] + ) def test_line_delta(self): content1 = self._make_content([("", "a"), ("", "b")]) content2 = self._make_content([("", "a"), ("", "a"), ("", "c")]) - self.assertEqual(content1.line_delta(content2), - [(1, 2, 2, [("", "a"), ("", "c")])]) + self.assertEqual( + content1.line_delta(content2), [(1, 2, 2, [("", "a"), ("", "c")])] + ) def test_line_delta_iter(self): content1 = self._make_content([("", "a"), ("", "b")]) @@ -259,12 +257,11 @@ def test_line_delta_iter(self): class MockTransport: - def __init__(self, file_lines=None): self.file_lines = file_lines self.calls = [] # We have 
no base directory for the MockTransport - self.base = '' + self.base = "" def get(self, filename): if self.file_lines is None: @@ -281,6 +278,7 @@ def readv(self, relpath, offsets): def __getattr__(self, name): def queue_call(*args, **kwargs): self.calls.append((name, args, kwargs)) + return queue_call @@ -308,30 +306,31 @@ class KnitRecordAccessTestsMixin: def test_add_raw_records(self): """add_raw_records adds records retrievable later.""" access = self.get_access() - memos = access.add_raw_records([(b'key', 10)], [b'1234567890']) - self.assertEqual([b'1234567890'], list(access.get_raw_records(memos))) + memos = access.add_raw_records([(b"key", 10)], [b"1234567890"]) + self.assertEqual([b"1234567890"], list(access.get_raw_records(memos))) def test_add_raw_record(self): """add_raw_record adds records retrievable later.""" access = self.get_access() - memos = access.add_raw_record(b'key', 10, [b'1234567890']) - self.assertEqual([b'1234567890'], list(access.get_raw_records([memos]))) + memos = access.add_raw_record(b"key", 10, [b"1234567890"]) + self.assertEqual([b"1234567890"], list(access.get_raw_records([memos]))) def test_add_several_raw_records(self): """add_raw_records with many records and read some back.""" access = self.get_access() - memos = access.add_raw_records([(b'key', 10), (b'key2', 2), (b'key3', 5)], - [b'12345678901234567']) - self.assertEqual([b'1234567890', b'12', b'34567'], - list(access.get_raw_records(memos))) - self.assertEqual([b'1234567890'], - list(access.get_raw_records(memos[0:1]))) - self.assertEqual([b'12'], - list(access.get_raw_records(memos[1:2]))) - self.assertEqual([b'34567'], - list(access.get_raw_records(memos[2:3]))) - self.assertEqual([b'1234567890', b'34567'], - list(access.get_raw_records(memos[0:1] + memos[2:3]))) + memos = access.add_raw_records( + [(b"key", 10), (b"key2", 2), (b"key3", 5)], [b"12345678901234567"] + ) + self.assertEqual( + [b"1234567890", b"12", b"34567"], list(access.get_raw_records(memos)) + ) + self.assertEqual([b"1234567890"], list(access.get_raw_records(memos[0:1]))) + self.assertEqual([b"12"], list(access.get_raw_records(memos[1:2]))) + self.assertEqual([b"34567"], list(access.get_raw_records(memos[2:3]))) + self.assertEqual( + [b"1234567890", b"34567"], + list(access.get_raw_records(memos[0:1] + memos[2:3])), + ) class TestKnitKnitAccess(TestCaseWithMemoryTransport, KnitRecordAccessTestsMixin): @@ -354,11 +353,12 @@ class TestPackKnitAccess(TestCaseWithMemoryTransport, KnitRecordAccessTestsMixin def get_access(self): return self._get_access()[0] - def _get_access(self, packname='packfile', index='FOO'): + def _get_access(self, packname="packfile", index="FOO"): transport = self.get_transport() def write_data(bytes): transport.append_bytes(packname, bytes) + writer = pack.ContainerWriter(write_data) writer.begin() access = pack_repo._DirectPackAccess({}) @@ -367,10 +367,10 @@ def write_data(bytes): def make_pack_file(self): """Create a pack file with 2 records.""" - access, writer = self._get_access(packname='packname', index='foo') + access, writer = self._get_access(packname="packname", index="foo") memos = [] - memos.extend(access.add_raw_records([(b'key1', 10)], [b'1234567890'])) - memos.extend(access.add_raw_records([(b'key2', 5)], [b'12345'])) + memos.extend(access.add_raw_records([(b"key1", 10)], [b"1234567890"])) + memos.extend(access.add_raw_records([(b"key2", 5)], [b"12345"])) writer.end() return memos @@ -378,18 +378,30 @@ def test_pack_collection_pack_retries(self): """An explicit pack of a pack collection 
succeeds even when a concurrent pack happens. """ - builder = self.make_branch_builder('.') + builder = self.make_branch_builder(".") builder.start_series() - builder.build_snapshot(None, [ - ('add', ('', b'root-id', 'directory', None)), - ('add', ('file', b'file-id', 'file', b'content\nrev 1\n')), - ], revision_id=b'rev-1') - builder.build_snapshot([b'rev-1'], [ - ('modify', ('file', b'content\nrev 2\n')), - ], revision_id=b'rev-2') - builder.build_snapshot([b'rev-2'], [ - ('modify', ('file', b'content\nrev 3\n')), - ], revision_id=b'rev-3') + builder.build_snapshot( + None, + [ + ("add", ("", b"root-id", "directory", None)), + ("add", ("file", b"file-id", "file", b"content\nrev 1\n")), + ], + revision_id=b"rev-1", + ) + builder.build_snapshot( + [b"rev-1"], + [ + ("modify", ("file", b"content\nrev 2\n")), + ], + revision_id=b"rev-2", + ) + builder.build_snapshot( + [b"rev-2"], + [ + ("modify", ("file", b"content\nrev 3\n")), + ], + revision_id=b"rev-3", + ) self.addCleanup(builder.finish_series) b = builder.get_branch() self.addCleanup(b.lock_write().unlock) @@ -411,18 +423,30 @@ def make_vf_for_retrying(self): :return: (versioned_file, reload_counter) versioned_file a KnitVersionedFiles using the packs for access """ - builder = self.make_branch_builder('.', format="1.9") + builder = self.make_branch_builder(".", format="1.9") builder.start_series() - builder.build_snapshot(None, [ - ('add', ('', b'root-id', 'directory', None)), - ('add', ('file', b'file-id', 'file', b'content\nrev 1\n')), - ], revision_id=b'rev-1') - builder.build_snapshot([b'rev-1'], [ - ('modify', ('file', b'content\nrev 2\n')), - ], revision_id=b'rev-2') - builder.build_snapshot([b'rev-2'], [ - ('modify', ('file', b'content\nrev 3\n')), - ], revision_id=b'rev-3') + builder.build_snapshot( + None, + [ + ("add", ("", b"root-id", "directory", None)), + ("add", ("file", b"file-id", "file", b"content\nrev 1\n")), + ], + revision_id=b"rev-1", + ) + builder.build_snapshot( + [b"rev-1"], + [ + ("modify", ("file", b"content\nrev 2\n")), + ], + revision_id=b"rev-2", + ) + builder.build_snapshot( + [b"rev-2"], + [ + ("modify", ("file", b"content\nrev 3\n")), + ], + revision_id=b"rev-3", + ) builder.finish_series() b = builder.get_branch() b.lock_write() @@ -433,7 +457,7 @@ def make_vf_for_retrying(self): collection = repo._pack_collection collection.ensure_loaded() orig_packs = collection.packs - packer = knitpack_repo.KnitPacker(collection, orig_packs, '.testpack') + packer = knitpack_repo.KnitPacker(collection, orig_packs, ".testpack") new_pack = packer.pack() # forget about the new pack collection.reset() @@ -455,6 +479,7 @@ def reload(): vf._access._indices.clear() vf._access._indices[new_index] = access_tuple return True + # Delete one of the pack files so the data will need to be reloaded. 
We # will delete the file with 'rev-2' in it trans, name = orig_packs[1].access_tuple() @@ -470,70 +495,79 @@ def make_reload_func(self, return_val=True): def reload(): reload_called[0] += 1 return return_val + return reload_called, reload def make_retry_exception(self): # We raise a real exception so that sys.exc_info() is properly # populated try: - raise _TestException('foobar') + raise _TestException("foobar") except _TestException: - retry_exc = pack_repo.RetryWithNewPacks(None, reload_occurred=False, - exc_info=sys.exc_info()) + retry_exc = pack_repo.RetryWithNewPacks( + None, reload_occurred=False, exc_info=sys.exc_info() + ) # GZ 2010-08-10: Cycle with exc_info affects 3 tests return retry_exc def test_read_from_several_packs(self): access, writer = self._get_access() memos = [] - memos.extend(access.add_raw_records([(b'key', 10)], [b'1234567890'])) + memos.extend(access.add_raw_records([(b"key", 10)], [b"1234567890"])) writer.end() - access, writer = self._get_access('pack2', 'FOOBAR') - memos.extend(access.add_raw_records([(b'key', 5)], [b'12345'])) + access, writer = self._get_access("pack2", "FOOBAR") + memos.extend(access.add_raw_records([(b"key", 5)], [b"12345"])) writer.end() - access, writer = self._get_access('pack3', 'BAZ') - memos.extend(access.add_raw_records([(b'key', 5)], [b'alpha'])) + access, writer = self._get_access("pack3", "BAZ") + memos.extend(access.add_raw_records([(b"key", 5)], [b"alpha"])) writer.end() transport = self.get_transport() - access = pack_repo._DirectPackAccess({"FOO": (transport, 'packfile'), - "FOOBAR": (transport, 'pack2'), - "BAZ": (transport, 'pack3')}) - self.assertEqual([b'1234567890', b'12345', b'alpha'], - list(access.get_raw_records(memos))) - self.assertEqual([b'1234567890'], - list(access.get_raw_records(memos[0:1]))) - self.assertEqual([b'12345'], - list(access.get_raw_records(memos[1:2]))) - self.assertEqual([b'alpha'], - list(access.get_raw_records(memos[2:3]))) - self.assertEqual([b'1234567890', b'alpha'], - list(access.get_raw_records(memos[0:1] + memos[2:3]))) + access = pack_repo._DirectPackAccess( + { + "FOO": (transport, "packfile"), + "FOOBAR": (transport, "pack2"), + "BAZ": (transport, "pack3"), + } + ) + self.assertEqual( + [b"1234567890", b"12345", b"alpha"], list(access.get_raw_records(memos)) + ) + self.assertEqual([b"1234567890"], list(access.get_raw_records(memos[0:1]))) + self.assertEqual([b"12345"], list(access.get_raw_records(memos[1:2]))) + self.assertEqual([b"alpha"], list(access.get_raw_records(memos[2:3]))) + self.assertEqual( + [b"1234567890", b"alpha"], + list(access.get_raw_records(memos[0:1] + memos[2:3])), + ) def test_set_writer(self): """The writer should be settable post construction.""" access = pack_repo._DirectPackAccess({}) transport = self.get_transport() - packname = 'packfile' - index = 'foo' + packname = "packfile" + index = "foo" def write_data(bytes): transport.append_bytes(packname, bytes) + writer = pack.ContainerWriter(write_data) writer.begin() access.set_writer(writer, index, (transport, packname)) - memos = access.add_raw_records([(b'key', 10)], [b'1234567890']) + memos = access.add_raw_records([(b"key", 10)], [b"1234567890"]) writer.end() - self.assertEqual([b'1234567890'], list(access.get_raw_records(memos))) + self.assertEqual([b"1234567890"], list(access.get_raw_records(memos))) def test_missing_index_raises_retry(self): memos = self.make_pack_file() transport = self.get_transport() reload_called, reload_func = self.make_reload_func() # Note that the index key has changed from 'foo' 
to 'bar' - access = pack_repo._DirectPackAccess({'bar': (transport, 'packname')}, - reload_func=reload_func) - e = self.assertListRaises(pack_repo.RetryWithNewPacks, - access.get_raw_records, memos) + access = pack_repo._DirectPackAccess( + {"bar": (transport, "packname")}, reload_func=reload_func + ) + e = self.assertListRaises( + pack_repo.RetryWithNewPacks, access.get_raw_records, memos + ) # Because a key was passed in which does not match our index list, we # assume that the listing was already reloaded self.assertTrue(e.reload_occurred) @@ -545,7 +579,7 @@ def test_missing_index_raises_key_error_with_no_reload(self): memos = self.make_pack_file() transport = self.get_transport() # Note that the index key has changed from 'foo' to 'bar' - access = pack_repo._DirectPackAccess({'bar': (transport, 'packname')}) + access = pack_repo._DirectPackAccess({"bar": (transport, "packname")}) self.assertListRaises(KeyError, access.get_raw_records, memos) def test_missing_file_raises_retry(self): @@ -554,66 +588,58 @@ def test_missing_file_raises_retry(self): reload_called, reload_func = self.make_reload_func() # Note that the 'filename' has been changed to 'different-packname' access = pack_repo._DirectPackAccess( - {'foo': (transport, 'different-packname')}, - reload_func=reload_func) - e = self.assertListRaises(pack_repo.RetryWithNewPacks, - access.get_raw_records, memos) + {"foo": (transport, "different-packname")}, reload_func=reload_func + ) + e = self.assertListRaises( + pack_repo.RetryWithNewPacks, access.get_raw_records, memos + ) # The file has gone missing, so we assume we need to reload self.assertFalse(e.reload_occurred) self.assertIsInstance(e.exc_info, tuple) self.assertIs(e.exc_info[0], _mod_transport.NoSuchFile) self.assertIsInstance(e.exc_info[1], _mod_transport.NoSuchFile) - self.assertEqual('different-packname', e.exc_info[1].path) + self.assertEqual("different-packname", e.exc_info[1].path) def test_missing_file_raises_no_such_file_with_no_reload(self): memos = self.make_pack_file() transport = self.get_transport() # Note that the 'filename' has been changed to 'different-packname' - access = pack_repo._DirectPackAccess( - {'foo': (transport, 'different-packname')}) - self.assertListRaises(_mod_transport.NoSuchFile, - access.get_raw_records, memos) + access = pack_repo._DirectPackAccess({"foo": (transport, "different-packname")}) + self.assertListRaises(_mod_transport.NoSuchFile, access.get_raw_records, memos) def test_failing_readv_raises_retry(self): memos = self.make_pack_file() transport = self.get_transport() - failing_transport = MockReadvFailingTransport( - [transport.get_bytes('packname')]) + failing_transport = MockReadvFailingTransport([transport.get_bytes("packname")]) reload_called, reload_func = self.make_reload_func() access = pack_repo._DirectPackAccess( - {'foo': (failing_transport, 'packname')}, - reload_func=reload_func) + {"foo": (failing_transport, "packname")}, reload_func=reload_func + ) # Asking for a single record will not trigger the Mock failure - self.assertEqual([b'1234567890'], - list(access.get_raw_records(memos[:1]))) - self.assertEqual([b'12345'], - list(access.get_raw_records(memos[1:2]))) + self.assertEqual([b"1234567890"], list(access.get_raw_records(memos[:1]))) + self.assertEqual([b"12345"], list(access.get_raw_records(memos[1:2]))) # A multiple offset readv() will fail mid-way through - e = self.assertListRaises(pack_repo.RetryWithNewPacks, - access.get_raw_records, memos) + e = self.assertListRaises( + pack_repo.RetryWithNewPacks, 
access.get_raw_records, memos + ) # The file has gone missing, so we assume we need to reload self.assertFalse(e.reload_occurred) self.assertIsInstance(e.exc_info, tuple) self.assertIs(e.exc_info[0], _mod_transport.NoSuchFile) self.assertIsInstance(e.exc_info[1], _mod_transport.NoSuchFile) - self.assertEqual('packname', e.exc_info[1].path) + self.assertEqual("packname", e.exc_info[1].path) def test_failing_readv_raises_no_such_file_with_no_reload(self): memos = self.make_pack_file() transport = self.get_transport() - failing_transport = MockReadvFailingTransport( - [transport.get_bytes('packname')]) + failing_transport = MockReadvFailingTransport([transport.get_bytes("packname")]) reload_called, reload_func = self.make_reload_func() - access = pack_repo._DirectPackAccess( - {'foo': (failing_transport, 'packname')}) + access = pack_repo._DirectPackAccess({"foo": (failing_transport, "packname")}) # Asking for a single record will not trigger the Mock failure - self.assertEqual([b'1234567890'], - list(access.get_raw_records(memos[:1]))) - self.assertEqual([b'12345'], - list(access.get_raw_records(memos[1:2]))) + self.assertEqual([b"1234567890"], list(access.get_raw_records(memos[:1]))) + self.assertEqual([b"12345"], list(access.get_raw_records(memos[1:2]))) # A multiple offset readv() will fail mid-way through - self.assertListRaises(_mod_transport.NoSuchFile, - access.get_raw_records, memos) + self.assertListRaises(_mod_transport.NoSuchFile, access.get_raw_records, memos) def test_reload_or_raise_no_reload(self): access = pack_repo._DirectPackAccess({}, reload_func=None) @@ -649,13 +675,13 @@ def test_annotate_retries(self): vf, reload_counter = self.make_vf_for_retrying() # It is a little bit bogus to annotate the Revision VF, but it works, # as we have ancestry stored there - key = (b'rev-3',) + key = (b"rev-3",) reload_lines = vf.annotate(key) self.assertEqual([1, 1, 0], reload_counter) plain_lines = vf.annotate(key) self.assertEqual([1, 1, 0], reload_counter) # No extra reloading if reload_lines != plain_lines: - self.fail('Annotation was not identical with reloading.') + self.fail("Annotation was not identical with reloading.") # Now delete the packs-in-use, which should trigger another reload, but # this time we just raise an exception because we can't recover for trans, name in vf._access._indices.values(): @@ -665,7 +691,7 @@ def test_annotate_retries(self): def test__get_record_map_retries(self): vf, reload_counter = self.make_vf_for_retrying() - keys = [(b'rev-1',), (b'rev-2',), (b'rev-3',)] + keys = [(b"rev-1",), (b"rev-2",), (b"rev-3",)] records = vf._get_record_map(keys) self.assertEqual(keys, sorted(records.keys())) self.assertEqual([1, 1, 0], reload_counter) @@ -678,26 +704,27 @@ def test__get_record_map_retries(self): def test_get_record_stream_retries(self): vf, reload_counter = self.make_vf_for_retrying() - keys = [(b'rev-1',), (b'rev-2',), (b'rev-3',)] - record_stream = vf.get_record_stream(keys, 'topological', False) + keys = [(b"rev-1",), (b"rev-2",), (b"rev-3",)] + record_stream = vf.get_record_stream(keys, "topological", False) record = next(record_stream) - self.assertEqual((b'rev-1',), record.key) + self.assertEqual((b"rev-1",), record.key) self.assertEqual([0, 0, 0], reload_counter) record = next(record_stream) - self.assertEqual((b'rev-2',), record.key) + self.assertEqual((b"rev-2",), record.key) self.assertEqual([1, 1, 0], reload_counter) record = next(record_stream) - self.assertEqual((b'rev-3',), record.key) + self.assertEqual((b"rev-3",), record.key) 
self.assertEqual([1, 1, 0], reload_counter) # Now delete all pack files, and see that we raise the right error for trans, name in vf._access._indices.values(): trans.delete(name) - self.assertListRaises(_mod_transport.NoSuchFile, - vf.get_record_stream, keys, 'topological', False) + self.assertListRaises( + _mod_transport.NoSuchFile, vf.get_record_stream, keys, "topological", False + ) def test_iter_lines_added_or_present_in_keys_retries(self): vf, reload_counter = self.make_vf_for_retrying() - keys = [(b'rev-1',), (b'rev-2',), (b'rev-3',)] + keys = [(b"rev-1",), (b"rev-2",), (b"rev-3",)] # Unfortunately, iter_lines_added_or_present_in_keys iterates the # result in random order (determined by the iteration order from a # set()), so we don't have any solid way to trigger whether data is @@ -715,222 +742,225 @@ def test_iter_lines_added_or_present_in_keys_retries(self): # Now delete all pack files, and see that we raise the right error for trans, name in vf._access._indices.values(): trans.delete(name) - self.assertListRaises(_mod_transport.NoSuchFile, - vf.iter_lines_added_or_present_in_keys, keys) + self.assertListRaises( + _mod_transport.NoSuchFile, vf.iter_lines_added_or_present_in_keys, keys + ) self.assertEqual([2, 1, 1], reload_counter) def test_get_record_stream_yields_disk_sorted_order(self): # if we get 'unordered' pick a semi-optimal order for reading. The # order should be grouped by pack file, and then by position in file - repo = self.make_repository('test', format='pack-0.92') + repo = self.make_repository("test", format="pack-0.92") repo.lock_write() self.addCleanup(repo.unlock) repo.start_write_group() vf = repo.texts - vf.add_lines((b'f-id', b'rev-5'), [(b'f-id', b'rev-4')], [b'lines\n']) - vf.add_lines((b'f-id', b'rev-1'), [], [b'lines\n']) - vf.add_lines((b'f-id', b'rev-2'), [(b'f-id', b'rev-1')], [b'lines\n']) + vf.add_lines((b"f-id", b"rev-5"), [(b"f-id", b"rev-4")], [b"lines\n"]) + vf.add_lines((b"f-id", b"rev-1"), [], [b"lines\n"]) + vf.add_lines((b"f-id", b"rev-2"), [(b"f-id", b"rev-1")], [b"lines\n"]) repo.commit_write_group() # We inserted them as rev-5, rev-1, rev-2, we should get them back in # the same order - stream = vf.get_record_stream([(b'f-id', b'rev-1'), (b'f-id', b'rev-5'), - (b'f-id', b'rev-2')], 'unordered', False) + stream = vf.get_record_stream( + [(b"f-id", b"rev-1"), (b"f-id", b"rev-5"), (b"f-id", b"rev-2")], + "unordered", + False, + ) keys = [r.key for r in stream] - self.assertEqual([(b'f-id', b'rev-5'), (b'f-id', b'rev-1'), - (b'f-id', b'rev-2')], keys) + self.assertEqual( + [(b"f-id", b"rev-5"), (b"f-id", b"rev-1"), (b"f-id", b"rev-2")], keys + ) repo.start_write_group() - vf.add_lines((b'f-id', b'rev-4'), [(b'f-id', b'rev-3')], [b'lines\n']) - vf.add_lines((b'f-id', b'rev-3'), [(b'f-id', b'rev-2')], [b'lines\n']) - vf.add_lines((b'f-id', b'rev-6'), [(b'f-id', b'rev-5')], [b'lines\n']) + vf.add_lines((b"f-id", b"rev-4"), [(b"f-id", b"rev-3")], [b"lines\n"]) + vf.add_lines((b"f-id", b"rev-3"), [(b"f-id", b"rev-2")], [b"lines\n"]) + vf.add_lines((b"f-id", b"rev-6"), [(b"f-id", b"rev-5")], [b"lines\n"]) repo.commit_write_group() # Request in random order, to make sure the output order isn't based on # the request - request_keys = {(b'f-id', b'rev-%d' % i) for i in range(1, 7)} - stream = vf.get_record_stream(request_keys, 'unordered', False) + request_keys = {(b"f-id", b"rev-%d" % i) for i in range(1, 7)} + stream = vf.get_record_stream(request_keys, "unordered", False) keys = [r.key for r in stream] # We want to get the keys back in disk order, 
but it doesn't matter # which pack we read from first. So this can come back in 2 orders - alt1 = [(b'f-id', b'rev-%d' % i) for i in [4, 3, 6, 5, 1, 2]] - alt2 = [(b'f-id', b'rev-%d' % i) for i in [5, 1, 2, 4, 3, 6]] + alt1 = [(b"f-id", b"rev-%d" % i) for i in [4, 3, 6, 5, 1, 2]] + alt2 = [(b"f-id", b"rev-%d" % i) for i in [5, 1, 2, 4, 3, 6]] if keys != alt1 and keys != alt2: - self.fail('Returned key order did not match either expected order.' - f' expected {alt1} or {alt2}, not {keys}') + self.fail( + "Returned key order did not match either expected order." + f" expected {alt1} or {alt2}, not {keys}" + ) class LowLevelKnitDataTests(TestCase): - def create_gz_content(self, text): sio = BytesIO() - with gzip.GzipFile(mode='wb', fileobj=sio) as gz_file: + with gzip.GzipFile(mode="wb", fileobj=sio) as gz_file: gz_file.write(text) return sio.getvalue() def make_multiple_records(self): """Create the content for multiple records.""" - sha1sum = osutils.sha_string(b'foo\nbar\n') + sha1sum = osutils.sha_string(b"foo\nbar\n") total_txt = [] - gz_txt = self.create_gz_content(b'version rev-id-1 2 %s\n' - b'foo\n' - b'bar\n' - b'end rev-id-1\n' - % (sha1sum,)) + gz_txt = self.create_gz_content( + b"version rev-id-1 2 %s\n" b"foo\n" b"bar\n" b"end rev-id-1\n" % (sha1sum,) + ) record_1 = (0, len(gz_txt), sha1sum) total_txt.append(gz_txt) - sha1sum = osutils.sha_string(b'baz\n') - gz_txt = self.create_gz_content(b'version rev-id-2 1 %s\n' - b'baz\n' - b'end rev-id-2\n' - % (sha1sum,)) + sha1sum = osutils.sha_string(b"baz\n") + gz_txt = self.create_gz_content( + b"version rev-id-2 1 %s\n" b"baz\n" b"end rev-id-2\n" % (sha1sum,) + ) record_2 = (record_1[1], len(gz_txt), sha1sum) total_txt.append(gz_txt) return total_txt, record_1, record_2 def test_valid_knit_data(self): - sha1sum = osutils.sha_string(b'foo\nbar\n') - gz_txt = self.create_gz_content(b'version rev-id-1 2 %s\n' - b'foo\n' - b'bar\n' - b'end rev-id-1\n' - % (sha1sum,)) + sha1sum = osutils.sha_string(b"foo\nbar\n") + gz_txt = self.create_gz_content( + b"version rev-id-1 2 %s\n" b"foo\n" b"bar\n" b"end rev-id-1\n" % (sha1sum,) + ) transport = MockTransport([gz_txt]) - access = _KnitKeyAccess(transport, ConstantMapper('filename')) + access = _KnitKeyAccess(transport, ConstantMapper("filename")) knit = KnitVersionedFiles(None, access) - records = [((b'rev-id-1',), ((b'rev-id-1',), 0, len(gz_txt)))] + records = [((b"rev-id-1",), ((b"rev-id-1",), 0, len(gz_txt)))] contents = list(knit._read_records_iter(records)) - self.assertEqual([((b'rev-id-1',), [b'foo\n', b'bar\n'], - b'4e48e2c9a3d2ca8a708cb0cc545700544efb5021')], contents) + self.assertEqual( + [ + ( + (b"rev-id-1",), + [b"foo\n", b"bar\n"], + b"4e48e2c9a3d2ca8a708cb0cc545700544efb5021", + ) + ], + contents, + ) raw_contents = list(knit._read_records_iter_raw(records)) - self.assertEqual([((b'rev-id-1',), gz_txt, sha1sum)], raw_contents) + self.assertEqual([((b"rev-id-1",), gz_txt, sha1sum)], raw_contents) def test_multiple_records_valid(self): total_txt, record_1, record_2 = self.make_multiple_records() - transport = MockTransport([b''.join(total_txt)]) - access = _KnitKeyAccess(transport, ConstantMapper('filename')) + transport = MockTransport([b"".join(total_txt)]) + access = _KnitKeyAccess(transport, ConstantMapper("filename")) knit = KnitVersionedFiles(None, access) - records = [((b'rev-id-1',), ((b'rev-id-1',), record_1[0], record_1[1])), - ((b'rev-id-2',), ((b'rev-id-2',), record_2[0], record_2[1]))] + records = [ + ((b"rev-id-1",), ((b"rev-id-1",), record_1[0], record_1[1])), + 
((b"rev-id-2",), ((b"rev-id-2",), record_2[0], record_2[1])), + ] contents = list(knit._read_records_iter(records)) - self.assertEqual([((b'rev-id-1',), [b'foo\n', b'bar\n'], record_1[2]), - ((b'rev-id-2',), [b'baz\n'], record_2[2])], - contents) + self.assertEqual( + [ + ((b"rev-id-1",), [b"foo\n", b"bar\n"], record_1[2]), + ((b"rev-id-2",), [b"baz\n"], record_2[2]), + ], + contents, + ) raw_contents = list(knit._read_records_iter_raw(records)) - self.assertEqual([((b'rev-id-1',), total_txt[0], record_1[2]), - ((b'rev-id-2',), total_txt[1], record_2[2])], - raw_contents) + self.assertEqual( + [ + ((b"rev-id-1",), total_txt[0], record_1[2]), + ((b"rev-id-2",), total_txt[1], record_2[2]), + ], + raw_contents, + ) def test_not_enough_lines(self): - sha1sum = osutils.sha_string(b'foo\n') + sha1sum = osutils.sha_string(b"foo\n") # record says 2 lines data says 1 - gz_txt = self.create_gz_content(b'version rev-id-1 2 %s\n' - b'foo\n' - b'end rev-id-1\n' - % (sha1sum,)) + gz_txt = self.create_gz_content( + b"version rev-id-1 2 %s\n" b"foo\n" b"end rev-id-1\n" % (sha1sum,) + ) transport = MockTransport([gz_txt]) - access = _KnitKeyAccess(transport, ConstantMapper('filename')) + access = _KnitKeyAccess(transport, ConstantMapper("filename")) knit = KnitVersionedFiles(None, access) - records = [((b'rev-id-1',), ((b'rev-id-1',), 0, len(gz_txt)))] - self.assertRaises(KnitCorrupt, list, - knit._read_records_iter(records)) + records = [((b"rev-id-1",), ((b"rev-id-1",), 0, len(gz_txt)))] + self.assertRaises(KnitCorrupt, list, knit._read_records_iter(records)) # read_records_iter_raw won't detect that sort of mismatch/corruption raw_contents = list(knit._read_records_iter_raw(records)) - self.assertEqual([((b'rev-id-1',), gz_txt, sha1sum)], raw_contents) + self.assertEqual([((b"rev-id-1",), gz_txt, sha1sum)], raw_contents) def test_too_many_lines(self): - sha1sum = osutils.sha_string(b'foo\nbar\n') + sha1sum = osutils.sha_string(b"foo\nbar\n") # record says 1 lines data says 2 - gz_txt = self.create_gz_content(b'version rev-id-1 1 %s\n' - b'foo\n' - b'bar\n' - b'end rev-id-1\n' - % (sha1sum,)) + gz_txt = self.create_gz_content( + b"version rev-id-1 1 %s\n" b"foo\n" b"bar\n" b"end rev-id-1\n" % (sha1sum,) + ) transport = MockTransport([gz_txt]) - access = _KnitKeyAccess(transport, ConstantMapper('filename')) + access = _KnitKeyAccess(transport, ConstantMapper("filename")) knit = KnitVersionedFiles(None, access) - records = [((b'rev-id-1',), ((b'rev-id-1',), 0, len(gz_txt)))] - self.assertRaises(KnitCorrupt, list, - knit._read_records_iter(records)) + records = [((b"rev-id-1",), ((b"rev-id-1",), 0, len(gz_txt)))] + self.assertRaises(KnitCorrupt, list, knit._read_records_iter(records)) # read_records_iter_raw won't detect that sort of mismatch/corruption raw_contents = list(knit._read_records_iter_raw(records)) - self.assertEqual([((b'rev-id-1',), gz_txt, sha1sum)], raw_contents) + self.assertEqual([((b"rev-id-1",), gz_txt, sha1sum)], raw_contents) def test_mismatched_version_id(self): - sha1sum = osutils.sha_string(b'foo\nbar\n') - gz_txt = self.create_gz_content(b'version rev-id-1 2 %s\n' - b'foo\n' - b'bar\n' - b'end rev-id-1\n' - % (sha1sum,)) + sha1sum = osutils.sha_string(b"foo\nbar\n") + gz_txt = self.create_gz_content( + b"version rev-id-1 2 %s\n" b"foo\n" b"bar\n" b"end rev-id-1\n" % (sha1sum,) + ) transport = MockTransport([gz_txt]) - access = _KnitKeyAccess(transport, ConstantMapper('filename')) + access = _KnitKeyAccess(transport, ConstantMapper("filename")) knit = KnitVersionedFiles(None, 
access) # We are asking for rev-id-2, but the data is rev-id-1 - records = [((b'rev-id-2',), ((b'rev-id-2',), 0, len(gz_txt)))] - self.assertRaises(KnitCorrupt, list, - knit._read_records_iter(records)) + records = [((b"rev-id-2",), ((b"rev-id-2",), 0, len(gz_txt)))] + self.assertRaises(KnitCorrupt, list, knit._read_records_iter(records)) # read_records_iter_raw detects mismatches in the header - self.assertRaises(KnitCorrupt, list, - knit._read_records_iter_raw(records)) + self.assertRaises(KnitCorrupt, list, knit._read_records_iter_raw(records)) def test_uncompressed_data(self): - sha1sum = osutils.sha_string(b'foo\nbar\n') - txt = (b'version rev-id-1 2 %s\n' - b'foo\n' - b'bar\n' - b'end rev-id-1\n' - % (sha1sum,)) + sha1sum = osutils.sha_string(b"foo\nbar\n") + txt = b"version rev-id-1 2 %s\n" b"foo\n" b"bar\n" b"end rev-id-1\n" % ( + sha1sum, + ) transport = MockTransport([txt]) - access = _KnitKeyAccess(transport, ConstantMapper('filename')) + access = _KnitKeyAccess(transport, ConstantMapper("filename")) knit = KnitVersionedFiles(None, access) - records = [((b'rev-id-1',), ((b'rev-id-1',), 0, len(txt)))] + records = [((b"rev-id-1",), ((b"rev-id-1",), 0, len(txt)))] # We don't have valid gzip data ==> corrupt - self.assertRaises(KnitCorrupt, list, - knit._read_records_iter(records)) + self.assertRaises(KnitCorrupt, list, knit._read_records_iter(records)) # read_records_iter_raw will notice the bad data - self.assertRaises(KnitCorrupt, list, - knit._read_records_iter_raw(records)) + self.assertRaises(KnitCorrupt, list, knit._read_records_iter_raw(records)) def test_corrupted_data(self): - sha1sum = osutils.sha_string(b'foo\nbar\n') - gz_txt = self.create_gz_content(b'version rev-id-1 2 %s\n' - b'foo\n' - b'bar\n' - b'end rev-id-1\n' - % (sha1sum,)) + sha1sum = osutils.sha_string(b"foo\nbar\n") + gz_txt = self.create_gz_content( + b"version rev-id-1 2 %s\n" b"foo\n" b"bar\n" b"end rev-id-1\n" % (sha1sum,) + ) # Change 2 bytes in the middle to \xff - gz_txt = gz_txt[:10] + b'\xff\xff' + gz_txt[12:] + gz_txt = gz_txt[:10] + b"\xff\xff" + gz_txt[12:] transport = MockTransport([gz_txt]) - access = _KnitKeyAccess(transport, ConstantMapper('filename')) + access = _KnitKeyAccess(transport, ConstantMapper("filename")) knit = KnitVersionedFiles(None, access) - records = [((b'rev-id-1',), ((b'rev-id-1',), 0, len(gz_txt)))] - self.assertRaises(KnitCorrupt, list, - knit._read_records_iter(records)) + records = [((b"rev-id-1",), ((b"rev-id-1",), 0, len(gz_txt)))] + self.assertRaises(KnitCorrupt, list, knit._read_records_iter(records)) # read_records_iter_raw will barf on bad gz data - self.assertRaises(KnitCorrupt, list, - knit._read_records_iter_raw(records)) + self.assertRaises(KnitCorrupt, list, knit._read_records_iter_raw(records)) class LowLevelKnitIndexTests(TestCase): - @property def _load_data(self): from .._knit_load_data_py import _load_data_py + return _load_data_py def get_knit_index(self, transport, name, mode): mapper = ConstantMapper(name) - self.overrideAttr(knit, '_load_data', self._load_data) + self.overrideAttr(knit, "_load_data", self._load_data) def allow_writes(): - return 'w' in mode + return "w" in mode + return _KndxIndex(transport, mapper, lambda: None, allow_writes, lambda: True) def test_create_file(self): @@ -939,126 +969,130 @@ def test_create_file(self): index.keys() call = transport.calls.pop(0) # call[1][1] is a BytesIO - we can't test it by simple equality. 
- self.assertEqual('put_file_non_atomic', call[0]) - self.assertEqual('filename.kndx', call[1][0]) + self.assertEqual("put_file_non_atomic", call[0]) + self.assertEqual("filename.kndx", call[1][0]) # With no history, _KndxIndex writes a new index: - self.assertEqual(_KndxIndex.HEADER, - call[1][1].getvalue()) - self.assertEqual({'create_parent_dir': True}, call[2]) + self.assertEqual(_KndxIndex.HEADER, call[1][1].getvalue()) + self.assertEqual({"create_parent_dir": True}, call[2]) def test_read_utf8_version_id(self): unicode_revision_id = "version-\N{CYRILLIC CAPITAL LETTER A}" - utf8_revision_id = unicode_revision_id.encode('utf-8') - transport = MockTransport([ - _KndxIndex.HEADER, - b'%s option 0 1 :' % (utf8_revision_id,) - ]) + utf8_revision_id = unicode_revision_id.encode("utf-8") + transport = MockTransport( + [_KndxIndex.HEADER, b"%s option 0 1 :" % (utf8_revision_id,)] + ) index = self.get_knit_index(transport, "filename", "r") # _KndxIndex is a private class, and deals in utf8 revision_ids, not # Unicode revision_ids. - self.assertEqual({(utf8_revision_id,): ()}, - index.get_parent_map(index.keys())) + self.assertEqual({(utf8_revision_id,): ()}, index.get_parent_map(index.keys())) self.assertNotIn((unicode_revision_id,), index.keys()) def test_read_utf8_parents(self): unicode_revision_id = "version-\N{CYRILLIC CAPITAL LETTER A}" - utf8_revision_id = unicode_revision_id.encode('utf-8') - transport = MockTransport([ - _KndxIndex.HEADER, - b"version option 0 1 .%s :" % (utf8_revision_id,) - ]) + utf8_revision_id = unicode_revision_id.encode("utf-8") + transport = MockTransport( + [_KndxIndex.HEADER, b"version option 0 1 .%s :" % (utf8_revision_id,)] + ) index = self.get_knit_index(transport, "filename", "r") - self.assertEqual({(b"version",): ((utf8_revision_id,),)}, - index.get_parent_map(index.keys())) + self.assertEqual( + {(b"version",): ((utf8_revision_id,),)}, index.get_parent_map(index.keys()) + ) def test_read_ignore_corrupted_lines(self): - transport = MockTransport([ - _KndxIndex.HEADER, - b"corrupted", - b"corrupted options 0 1 .b .c ", - b"version options 0 1 :" - ]) + transport = MockTransport( + [ + _KndxIndex.HEADER, + b"corrupted", + b"corrupted options 0 1 .b .c ", + b"version options 0 1 :", + ] + ) index = self.get_knit_index(transport, "filename", "r") self.assertEqual(1, len(index.keys())) self.assertEqual({(b"version",)}, index.keys()) def test_read_corrupted_header(self): - transport = MockTransport([b'not a bzr knit index header\n']) + transport = MockTransport([b"not a bzr knit index header\n"]) index = self.get_knit_index(transport, "filename", "r") self.assertRaises(KnitHeaderError, index.keys) def test_read_duplicate_entries(self): - transport = MockTransport([ - _KndxIndex.HEADER, - b"parent options 0 1 :", - b"version options1 0 1 0 :", - b"version options2 1 2 .other :", - b"version options3 3 4 0 .other :" - ]) + transport = MockTransport( + [ + _KndxIndex.HEADER, + b"parent options 0 1 :", + b"version options1 0 1 0 :", + b"version options2 1 2 .other :", + b"version options3 3 4 0 .other :", + ] + ) index = self.get_knit_index(transport, "filename", "r") self.assertEqual(2, len(index.keys())) # check that the index used is the first one written. (Specific # to KnitIndex style indices. 
self.assertEqual(b"1", index._dictionary_compress([(b"version",)])) - self.assertEqual(((b"version",), 3, 4), - index.get_position((b"version",))) + self.assertEqual(((b"version",), 3, 4), index.get_position((b"version",))) self.assertEqual([b"options3"], index.get_options((b"version",))) - self.assertEqual({(b"version",): ((b"parent",), (b"other",))}, - index.get_parent_map([(b"version",)])) + self.assertEqual( + {(b"version",): ((b"parent",), (b"other",))}, + index.get_parent_map([(b"version",)]), + ) def test_read_compressed_parents(self): - transport = MockTransport([ - _KndxIndex.HEADER, - b"a option 0 1 :", - b"b option 0 1 0 :", - b"c option 0 1 1 0 :", - ]) + transport = MockTransport( + [ + _KndxIndex.HEADER, + b"a option 0 1 :", + b"b option 0 1 0 :", + b"c option 0 1 1 0 :", + ] + ) index = self.get_knit_index(transport, "filename", "r") - self.assertEqual({(b"b",): ((b"a",),), (b"c",): ((b"b",), (b"a",))}, - index.get_parent_map([(b"b",), (b"c",)])) + self.assertEqual( + {(b"b",): ((b"a",),), (b"c",): ((b"b",), (b"a",))}, + index.get_parent_map([(b"b",), (b"c",)]), + ) def test_write_utf8_version_id(self): unicode_revision_id = "version-\N{CYRILLIC CAPITAL LETTER A}" - utf8_revision_id = unicode_revision_id.encode('utf-8') - transport = MockTransport([ - _KndxIndex.HEADER - ]) + utf8_revision_id = unicode_revision_id.encode("utf-8") + transport = MockTransport([_KndxIndex.HEADER]) index = self.get_knit_index(transport, "filename", "r") - index.add_records([ - ((utf8_revision_id,), [b"option"], ((utf8_revision_id,), 0, 1), [])]) + index.add_records( + [((utf8_revision_id,), [b"option"], ((utf8_revision_id,), 0, 1), [])] + ) call = transport.calls.pop(0) # call[1][1] is a BytesIO - we can't test it by simple equality. - self.assertEqual('put_file_non_atomic', call[0]) - self.assertEqual('filename.kndx', call[1][0]) + self.assertEqual("put_file_non_atomic", call[0]) + self.assertEqual("filename.kndx", call[1][0]) # With no history, _KndxIndex writes a new index: - self.assertEqual(_KndxIndex.HEADER + - b"\n%s option 0 1 :" % (utf8_revision_id,), - call[1][1].getvalue()) - self.assertEqual({'create_parent_dir': True}, call[2]) + self.assertEqual( + _KndxIndex.HEADER + b"\n%s option 0 1 :" % (utf8_revision_id,), + call[1][1].getvalue(), + ) + self.assertEqual({"create_parent_dir": True}, call[2]) def test_write_utf8_parents(self): unicode_revision_id = "version-\N{CYRILLIC CAPITAL LETTER A}" - utf8_revision_id = unicode_revision_id.encode('utf-8') - transport = MockTransport([ - _KndxIndex.HEADER - ]) + utf8_revision_id = unicode_revision_id.encode("utf-8") + transport = MockTransport([_KndxIndex.HEADER]) index = self.get_knit_index(transport, "filename", "r") - index.add_records([ - ((b"version",), [b"option"], ((b"version",), 0, 1), [(utf8_revision_id,)])]) + index.add_records( + [((b"version",), [b"option"], ((b"version",), 0, 1), [(utf8_revision_id,)])] + ) call = transport.calls.pop(0) # call[1][1] is a BytesIO - we can't test it by simple equality. 
- self.assertEqual('put_file_non_atomic', call[0]) - self.assertEqual('filename.kndx', call[1][0]) + self.assertEqual("put_file_non_atomic", call[0]) + self.assertEqual("filename.kndx", call[1][0]) # With no history, _KndxIndex writes a new index: - self.assertEqual(_KndxIndex.HEADER + - b"\nversion option 0 1 .%s :" % (utf8_revision_id,), - call[1][1].getvalue()) - self.assertEqual({'create_parent_dir': True}, call[2]) + self.assertEqual( + _KndxIndex.HEADER + b"\nversion option 0 1 .%s :" % (utf8_revision_id,), + call[1][1].getvalue(), + ) + self.assertEqual({"create_parent_dir": True}, call[2]) def test_keys(self): - transport = MockTransport([ - _KndxIndex.HEADER - ]) + transport = MockTransport([_KndxIndex.HEADER]) index = self.get_knit_index(transport, "filename", "r") self.assertEqual(set(), index.keys()) @@ -1076,47 +1110,48 @@ def add_a_b(self, index, random_id=None): kwargs = {} if random_id is not None: kwargs["random_id"] = random_id - index.add_records([ - ((b"a",), [b"option"], ((b"a",), 0, 1), [(b"b",)]), - ((b"a",), [b"opt"], ((b"a",), 1, 2), [(b"c",)]), - ((b"b",), [b"option"], ((b"b",), 2, 3), [(b"a",)]) - ], **kwargs) + index.add_records( + [ + ((b"a",), [b"option"], ((b"a",), 0, 1), [(b"b",)]), + ((b"a",), [b"opt"], ((b"a",), 1, 2), [(b"c",)]), + ((b"b",), [b"option"], ((b"b",), 2, 3), [(b"a",)]), + ], + **kwargs, + ) def assertIndexIsAB(self, index): - self.assertEqual({ - (b'a',): ((b'c',),), - (b'b',): ((b'a',),), + self.assertEqual( + { + (b"a",): ((b"c",),), + (b"b",): ((b"a",),), }, - index.get_parent_map(index.keys())) + index.get_parent_map(index.keys()), + ) self.assertEqual(((b"a",), 1, 2), index.get_position((b"a",))) self.assertEqual(((b"b",), 2, 3), index.get_position((b"b",))) self.assertEqual([b"opt"], index.get_options((b"a",))) def test_add_versions(self): - transport = MockTransport([ - _KndxIndex.HEADER - ]) + transport = MockTransport([_KndxIndex.HEADER]) index = self.get_knit_index(transport, "filename", "r") self.add_a_b(index) call = transport.calls.pop(0) # call[1][1] is a BytesIO - we can't test it by simple equality. - self.assertEqual('put_file_non_atomic', call[0]) - self.assertEqual('filename.kndx', call[1][0]) + self.assertEqual("put_file_non_atomic", call[0]) + self.assertEqual("filename.kndx", call[1][0]) # With no history, _KndxIndex writes a new index: self.assertEqual( - _KndxIndex.HEADER + - b"\na option 0 1 .b :" + _KndxIndex.HEADER + b"\na option 0 1 .b :" b"\na opt 1 2 .c :" b"\nb option 2 3 0 :", - call[1][1].getvalue()) - self.assertEqual({'create_parent_dir': True}, call[2]) + call[1][1].getvalue(), + ) + self.assertEqual({"create_parent_dir": True}, call[2]) self.assertIndexIsAB(index) def test_add_versions_random_id_is_accepted(self): - transport = MockTransport([ - _KndxIndex.HEADER - ]) + transport = MockTransport([_KndxIndex.HEADER]) index = self.get_knit_index(transport, "filename", "r") self.add_a_b(index, random_id=True) @@ -1134,62 +1169,61 @@ def test_delay_create_and_add_versions(self): # missing create it), then a second where we write the contents out. 
self.assertEqual(2, len(transport.calls)) call = transport.calls.pop(0) - self.assertEqual('put_file_non_atomic', call[0]) - self.assertEqual('filename.kndx', call[1][0]) + self.assertEqual("put_file_non_atomic", call[0]) + self.assertEqual("filename.kndx", call[1][0]) # With no history, _KndxIndex writes a new index: self.assertEqual(_KndxIndex.HEADER, call[1][1].getvalue()) - self.assertEqual({'create_parent_dir': True}, call[2]) + self.assertEqual({"create_parent_dir": True}, call[2]) call = transport.calls.pop(0) # call[1][1] is a BytesIO - we can't test it by simple equality. - self.assertEqual('put_file_non_atomic', call[0]) - self.assertEqual('filename.kndx', call[1][0]) + self.assertEqual("put_file_non_atomic", call[0]) + self.assertEqual("filename.kndx", call[1][0]) # With no history, _KndxIndex writes a new index: self.assertEqual( - _KndxIndex.HEADER + - b"\na option 0 1 .b :" + _KndxIndex.HEADER + b"\na option 0 1 .b :" b"\na opt 1 2 .c :" b"\nb option 2 3 0 :", - call[1][1].getvalue()) - self.assertEqual({'create_parent_dir': True}, call[2]) + call[1][1].getvalue(), + ) + self.assertEqual({"create_parent_dir": True}, call[2]) def assertTotalBuildSize(self, size, keys, positions): - self.assertEqual(size, - knit._get_total_build_size(None, keys, positions)) + self.assertEqual(size, knit._get_total_build_size(None, keys, positions)) def test__get_total_build_size(self): positions = { - (b'a',): (('fulltext', False), ((b'a',), 0, 100), None), - (b'b',): (('line-delta', False), ((b'b',), 100, 21), (b'a',)), - (b'c',): (('line-delta', False), ((b'c',), 121, 35), (b'b',)), - (b'd',): (('line-delta', False), ((b'd',), 156, 12), (b'b',)), - } - self.assertTotalBuildSize(100, [(b'a',)], positions) - self.assertTotalBuildSize(121, [(b'b',)], positions) + (b"a",): (("fulltext", False), ((b"a",), 0, 100), None), + (b"b",): (("line-delta", False), ((b"b",), 100, 21), (b"a",)), + (b"c",): (("line-delta", False), ((b"c",), 121, 35), (b"b",)), + (b"d",): (("line-delta", False), ((b"d",), 156, 12), (b"b",)), + } + self.assertTotalBuildSize(100, [(b"a",)], positions) + self.assertTotalBuildSize(121, [(b"b",)], positions) # c needs both a & b - self.assertTotalBuildSize(156, [(b'c',)], positions) + self.assertTotalBuildSize(156, [(b"c",)], positions) # we shouldn't count 'b' twice - self.assertTotalBuildSize(156, [(b'b',), (b'c',)], positions) - self.assertTotalBuildSize(133, [(b'd',)], positions) - self.assertTotalBuildSize(168, [(b'c',), (b'd',)], positions) + self.assertTotalBuildSize(156, [(b"b",), (b"c",)], positions) + self.assertTotalBuildSize(133, [(b"d",)], positions) + self.assertTotalBuildSize(168, [(b"c",), (b"d",)], positions) def test_get_position(self): - transport = MockTransport([ - _KndxIndex.HEADER, - b"a option 0 1 :", - b"b option 1 2 :" - ]) + transport = MockTransport( + [_KndxIndex.HEADER, b"a option 0 1 :", b"b option 1 2 :"] + ) index = self.get_knit_index(transport, "filename", "r") self.assertEqual(((b"a",), 0, 1), index.get_position((b"a",))) self.assertEqual(((b"b",), 1, 2), index.get_position((b"b",))) def test_get_method(self): - transport = MockTransport([ - _KndxIndex.HEADER, - b"a fulltext,unknown 0 1 :", - b"b unknown,line-delta 1 2 :", - b"c bad 3 4 :" - ]) + transport = MockTransport( + [ + _KndxIndex.HEADER, + b"a fulltext,unknown 0 1 :", + b"b unknown,line-delta 1 2 :", + b"c bad 3 4 :", + ] + ) index = self.get_knit_index(transport, "filename", "r") self.assertEqual("fulltext", index.get_method(b"a")) @@ -1197,130 +1231,147 @@ def test_get_method(self): 
self.assertRaises(knit.KnitIndexUnknownMethod, index.get_method, b"c") def test_get_options(self): - transport = MockTransport([ - _KndxIndex.HEADER, - b"a opt1 0 1 :", - b"b opt2,opt3 1 2 :" - ]) + transport = MockTransport( + [_KndxIndex.HEADER, b"a opt1 0 1 :", b"b opt2,opt3 1 2 :"] + ) index = self.get_knit_index(transport, "filename", "r") self.assertEqual([b"opt1"], index.get_options(b"a")) self.assertEqual([b"opt2", b"opt3"], index.get_options(b"b")) def test_get_parent_map(self): - transport = MockTransport([ - _KndxIndex.HEADER, - b"a option 0 1 :", - b"b option 1 2 0 .c :", - b"c option 1 2 1 0 .e :" - ]) + transport = MockTransport( + [ + _KndxIndex.HEADER, + b"a option 0 1 :", + b"b option 1 2 0 .c :", + b"c option 1 2 1 0 .e :", + ] + ) index = self.get_knit_index(transport, "filename", "r") - self.assertEqual({ - (b"a",): (), - (b"b",): ((b"a",), (b"c",)), - (b"c",): ((b"b",), (b"a",), (b"e",)), - }, index.get_parent_map(index.keys())) + self.assertEqual( + { + (b"a",): (), + (b"b",): ((b"a",), (b"c",)), + (b"c",): ((b"b",), (b"a",), (b"e",)), + }, + index.get_parent_map(index.keys()), + ) def test_impossible_parent(self): """Test we get KnitCorrupt if the parent couldn't possibly exist.""" - transport = MockTransport([ - _KndxIndex.HEADER, - b"a option 0 1 :", - b"b option 0 1 4 :" # We don't have a 4th record - ]) - index = self.get_knit_index(transport, 'filename', 'r') + transport = MockTransport( + [ + _KndxIndex.HEADER, + b"a option 0 1 :", + b"b option 0 1 4 :", # We don't have a 4th record + ] + ) + index = self.get_knit_index(transport, "filename", "r") self.assertRaises(KnitCorrupt, index.keys) def test_corrupted_parent(self): - transport = MockTransport([ - _KndxIndex.HEADER, - b"a option 0 1 :", - b"b option 0 1 :", - b"c option 0 1 1v :", # Can't have a parent of '1v' - ]) - index = self.get_knit_index(transport, 'filename', 'r') + transport = MockTransport( + [ + _KndxIndex.HEADER, + b"a option 0 1 :", + b"b option 0 1 :", + b"c option 0 1 1v :", # Can't have a parent of '1v' + ] + ) + index = self.get_knit_index(transport, "filename", "r") self.assertRaises(KnitCorrupt, index.keys) def test_corrupted_parent_in_list(self): - transport = MockTransport([ - _KndxIndex.HEADER, - b"a option 0 1 :", - b"b option 0 1 :", - b"c option 0 1 1 v :", # Can't have a parent of 'v' - ]) - index = self.get_knit_index(transport, 'filename', 'r') + transport = MockTransport( + [ + _KndxIndex.HEADER, + b"a option 0 1 :", + b"b option 0 1 :", + b"c option 0 1 1 v :", # Can't have a parent of 'v' + ] + ) + index = self.get_knit_index(transport, "filename", "r") self.assertRaises(KnitCorrupt, index.keys) def test_invalid_position(self): - transport = MockTransport([ - _KndxIndex.HEADER, - b"a option 1v 1 :", - ]) - index = self.get_knit_index(transport, 'filename', 'r') + transport = MockTransport( + [ + _KndxIndex.HEADER, + b"a option 1v 1 :", + ] + ) + index = self.get_knit_index(transport, "filename", "r") self.assertRaises(KnitCorrupt, index.keys) def test_invalid_size(self): - transport = MockTransport([ - _KndxIndex.HEADER, - b"a option 1 1v :", - ]) - index = self.get_knit_index(transport, 'filename', 'r') + transport = MockTransport( + [ + _KndxIndex.HEADER, + b"a option 1 1v :", + ] + ) + index = self.get_knit_index(transport, "filename", "r") self.assertRaises(KnitCorrupt, index.keys) def test_scan_unvalidated_index_not_implemented(self): transport = MockTransport() - index = self.get_knit_index(transport, 'filename', 'r') - self.assertRaises( - NotImplementedError, 
index.scan_unvalidated_index, - 'dummy graph_index') + index = self.get_knit_index(transport, "filename", "r") self.assertRaises( - NotImplementedError, index.get_missing_compression_parents) + NotImplementedError, index.scan_unvalidated_index, "dummy graph_index" + ) + self.assertRaises(NotImplementedError, index.get_missing_compression_parents) def test_short_line(self): - transport = MockTransport([ - _KndxIndex.HEADER, - b"a option 0 10 :", - b"b option 10 10 0", # This line isn't terminated, ignored - ]) + transport = MockTransport( + [ + _KndxIndex.HEADER, + b"a option 0 10 :", + b"b option 10 10 0", # This line isn't terminated, ignored + ] + ) index = self.get_knit_index(transport, "filename", "r") - self.assertEqual({(b'a',)}, index.keys()) + self.assertEqual({(b"a",)}, index.keys()) def test_skip_incomplete_record(self): # A line with bogus data should just be skipped - transport = MockTransport([ - _KndxIndex.HEADER, - b"a option 0 10 :", - b"b option 10 10 0", # This line isn't terminated, ignored - b"c option 20 10 0 :", # Properly terminated, and starts with '\n' - ]) + transport = MockTransport( + [ + _KndxIndex.HEADER, + b"a option 0 10 :", + b"b option 10 10 0", # This line isn't terminated, ignored + b"c option 20 10 0 :", # Properly terminated, and starts with '\n' + ] + ) index = self.get_knit_index(transport, "filename", "r") - self.assertEqual({(b'a',), (b'c',)}, index.keys()) + self.assertEqual({(b"a",), (b"c",)}, index.keys()) def test_trailing_characters(self): # A line with bogus data should just be skipped - transport = MockTransport([ - _KndxIndex.HEADER, - b"a option 0 10 :", - b"b option 10 10 0 :a", # This line has extra trailing characters - b"c option 20 10 0 :", # Properly terminated, and starts with '\n' - ]) + transport = MockTransport( + [ + _KndxIndex.HEADER, + b"a option 0 10 :", + b"b option 10 10 0 :a", # This line has extra trailing characters + b"c option 20 10 0 :", # Properly terminated, and starts with '\n' + ] + ) index = self.get_knit_index(transport, "filename", "r") - self.assertEqual({(b'a',), (b'c',)}, index.keys()) + self.assertEqual({(b"a",), (b"c",)}, index.keys()) class LowLevelKnitIndexTests_c(LowLevelKnitIndexTests): - _test_needs_features = [compiled_knit_feature] @property def _load_data(self): from .._knit_load_data_pyx import _load_data_c + return _load_data_c class Test_KnitAnnotator(TestCaseWithMemoryTransport): - def make_annotator(self): factory = knit.make_pack_factory(True, True, 1) vf = factory(self.get_transport()) @@ -1328,14 +1379,19 @@ def make_annotator(self): def test__expand_fulltext(self): ann = self.make_annotator() - rev_key = (b'rev-id',) + rev_key = (b"rev-id",) ann._num_compression_children[rev_key] = 1 - res = ann._expand_record(rev_key, ((b'parent-id',),), None, - [b'line1\n', b'line2\n'], ('fulltext', True)) + res = ann._expand_record( + rev_key, + ((b"parent-id",),), + None, + [b"line1\n", b"line2\n"], + ("fulltext", True), + ) # The content object and text lines should be cached appropriately - self.assertEqual([b'line1\n', b'line2'], res) + self.assertEqual([b"line1\n", b"line2"], res) content_obj = ann._content_objects[rev_key] - self.assertEqual([b'line1\n', b'line2\n'], content_obj._lines) + self.assertEqual([b"line1\n", b"line2\n"], content_obj._lines) self.assertEqual(res, content_obj.text()) self.assertEqual(res, ann._text_cache[rev_key]) @@ -1343,12 +1399,11 @@ def test__expand_delta_comp_parent_not_available(self): # Parent isn't available yet, so we return nothing, but queue up this # node for 
later processing ann = self.make_annotator() - rev_key = (b'rev-id',) - parent_key = (b'parent-id',) - record = [b'0,1,1\n', b'new-line\n'] - details = ('line-delta', False) - res = ann._expand_record(rev_key, (parent_key,), parent_key, - record, details) + rev_key = (b"rev-id",) + parent_key = (b"parent-id",) + record = [b"0,1,1\n", b"new-line\n"] + details = ("line-delta", False) + res = ann._expand_record(rev_key, (parent_key,), parent_key, record, details) self.assertEqual(None, res) self.assertIn(parent_key, ann._pending_deltas) pending = ann._pending_deltas[parent_key] @@ -1357,21 +1412,20 @@ def test__expand_delta_comp_parent_not_available(self): def test__expand_record_tracks_num_children(self): ann = self.make_annotator() - rev_key = (b'rev-id',) - rev2_key = (b'rev2-id',) - parent_key = (b'parent-id',) - record = [b'0,1,1\n', b'new-line\n'] - details = ('line-delta', False) + rev_key = (b"rev-id",) + rev2_key = (b"rev2-id",) + parent_key = (b"parent-id",) + record = [b"0,1,1\n", b"new-line\n"] + details = ("line-delta", False) ann._num_compression_children[parent_key] = 2 - ann._expand_record(parent_key, (), None, [b'line1\n', b'line2\n'], - ('fulltext', False)) - ann._expand_record(rev_key, (parent_key,), parent_key, - record, details) + ann._expand_record( + parent_key, (), None, [b"line1\n", b"line2\n"], ("fulltext", False) + ) + ann._expand_record(rev_key, (parent_key,), parent_key, record, details) self.assertEqual({parent_key: 1}, ann._num_compression_children) # Expanding the second child should remove the content object, and the # num_compression_children entry - ann._expand_record(rev2_key, (parent_key,), parent_key, - record, details) + ann._expand_record(rev2_key, (parent_key,), parent_key, record, details) self.assertNotIn(parent_key, ann._content_objects) self.assertEqual({}, ann._num_compression_children) # We should not cache the content_objects for rev2 and rev, because @@ -1380,68 +1434,73 @@ def test__expand_record_tracks_num_children(self): def test__expand_delta_records_blocks(self): ann = self.make_annotator() - rev_key = (b'rev-id',) - parent_key = (b'parent-id',) - record = [b'0,1,1\n', b'new-line\n'] - details = ('line-delta', True) + rev_key = (b"rev-id",) + parent_key = (b"parent-id",) + record = [b"0,1,1\n", b"new-line\n"] + details = ("line-delta", True) ann._num_compression_children[parent_key] = 2 - ann._expand_record(parent_key, (), None, - [b'line1\n', b'line2\n', b'line3\n'], - ('fulltext', False)) + ann._expand_record( + parent_key, + (), + None, + [b"line1\n", b"line2\n", b"line3\n"], + ("fulltext", False), + ) ann._expand_record(rev_key, (parent_key,), parent_key, record, details) - self.assertEqual({(rev_key, parent_key): [(1, 1, 1), (3, 3, 0)]}, - ann._matching_blocks) - rev2_key = (b'rev2-id',) - record = [b'0,1,1\n', b'new-line\n'] - details = ('line-delta', False) - ann._expand_record(rev2_key, (parent_key,), - parent_key, record, details) - self.assertEqual([(1, 1, 2), (3, 3, 0)], - ann._matching_blocks[(rev2_key, parent_key)]) + self.assertEqual( + {(rev_key, parent_key): [(1, 1, 1), (3, 3, 0)]}, ann._matching_blocks + ) + rev2_key = (b"rev2-id",) + record = [b"0,1,1\n", b"new-line\n"] + details = ("line-delta", False) + ann._expand_record(rev2_key, (parent_key,), parent_key, record, details) + self.assertEqual( + [(1, 1, 2), (3, 3, 0)], ann._matching_blocks[(rev2_key, parent_key)] + ) def test__get_parent_ann_uses_matching_blocks(self): ann = self.make_annotator() - rev_key = (b'rev-id',) - parent_key = (b'parent-id',) + rev_key = 
(b"rev-id",) + parent_key = (b"parent-id",) parent_ann = [(parent_key,)] * 3 block_key = (rev_key, parent_key) ann._annotations_cache[parent_key] = parent_ann ann._matching_blocks[block_key] = [(0, 1, 1), (3, 3, 0)] # We should not try to access any parent_lines content, because we know # we already have the matching blocks - par_ann, blocks = ann._get_parent_annotations_and_matches(rev_key, - [b'1\n', b'2\n', b'3\n'], parent_key) + par_ann, blocks = ann._get_parent_annotations_and_matches( + rev_key, [b"1\n", b"2\n", b"3\n"], parent_key + ) self.assertEqual(parent_ann, par_ann) self.assertEqual([(0, 1, 1), (3, 3, 0)], blocks) self.assertEqual({}, ann._matching_blocks) def test__process_pending(self): ann = self.make_annotator() - rev_key = (b'rev-id',) - p1_key = (b'p1-id',) - p2_key = (b'p2-id',) - record = [b'0,1,1\n', b'new-line\n'] - details = ('line-delta', False) - p1_record = [b'line1\n', b'line2\n'] + rev_key = (b"rev-id",) + p1_key = (b"p1-id",) + p2_key = (b"p2-id",) + record = [b"0,1,1\n", b"new-line\n"] + details = ("line-delta", False) + p1_record = [b"line1\n", b"line2\n"] ann._num_compression_children[p1_key] = 1 - res = ann._expand_record(rev_key, (p1_key, p2_key), p1_key, - record, details) + res = ann._expand_record(rev_key, (p1_key, p2_key), p1_key, record, details) self.assertEqual(None, res) # self.assertTrue(p1_key in ann._pending_deltas) self.assertEqual({}, ann._pending_annotation) # Now insert p1, and we should be able to expand the delta - res = ann._expand_record(p1_key, (), None, p1_record, - ('fulltext', False)) + res = ann._expand_record(p1_key, (), None, p1_record, ("fulltext", False)) self.assertEqual(p1_record, res) ann._annotations_cache[p1_key] = [(p1_key,)] * 2 res = ann._process_pending(p1_key) self.assertEqual([], res) self.assertNotIn(p1_key, ann._pending_deltas) self.assertIn(p2_key, ann._pending_annotation) - self.assertEqual({p2_key: [(rev_key, (p1_key, p2_key))]}, - ann._pending_annotation) + self.assertEqual( + {p2_key: [(rev_key, (p1_key, p2_key))]}, ann._pending_annotation + ) # Now fill in parent 2, and pending annotation should be satisfied - res = ann._expand_record(p2_key, (), None, [], ('fulltext', False)) + res = ann._expand_record(p2_key, (), None, [], ("fulltext", False)) ann._annotations_cache[p2_key] = [] res = ann._process_pending(p2_key) self.assertEqual([rev_key], res) @@ -1450,42 +1509,50 @@ def test__process_pending(self): def test_record_delta_removes_basis(self): ann = self.make_annotator() - ann._expand_record((b'parent-id',), (), None, - [b'line1\n', b'line2\n'], ('fulltext', False)) - ann._num_compression_children[b'parent-id'] = 2 + ann._expand_record( + (b"parent-id",), (), None, [b"line1\n", b"line2\n"], ("fulltext", False) + ) + ann._num_compression_children[b"parent-id"] = 2 def test_annotate_special_text(self): ann = self.make_annotator() vf = ann._vf - rev1_key = (b'rev-1',) - rev2_key = (b'rev-2',) - rev3_key = (b'rev-3',) - spec_key = (b'special:',) - vf.add_lines(rev1_key, [], [b'initial content\n']) - vf.add_lines(rev2_key, [rev1_key], [b'initial content\n', - b'common content\n', - b'content in 2\n']) - vf.add_lines(rev3_key, [rev1_key], [b'initial content\n', - b'common content\n', - b'content in 3\n']) - spec_text = (b'initial content\n' - b'common content\n' - b'content in 2\n' - b'content in 3\n') + rev1_key = (b"rev-1",) + rev2_key = (b"rev-2",) + rev3_key = (b"rev-3",) + spec_key = (b"special:",) + vf.add_lines(rev1_key, [], [b"initial content\n"]) + vf.add_lines( + rev2_key, + [rev1_key], + [b"initial 
content\n", b"common content\n", b"content in 2\n"], + ) + vf.add_lines( + rev3_key, + [rev1_key], + [b"initial content\n", b"common content\n", b"content in 3\n"], + ) + spec_text = ( + b"initial content\n" b"common content\n" b"content in 2\n" b"content in 3\n" + ) ann.add_special_text(spec_key, [rev2_key, rev3_key], spec_text) anns, lines = ann.annotate(spec_key) - self.assertEqual([(rev1_key,), - (rev2_key, rev3_key), - (rev2_key,), - (rev3_key,), - ], anns) - self.assertEqualDiff(spec_text, b''.join(lines)) + self.assertEqual( + [ + (rev1_key,), + (rev2_key, rev3_key), + (rev2_key,), + (rev3_key,), + ], + anns, + ) + self.assertEqualDiff(spec_text, b"".join(lines)) class KnitTests(TestCaseWithTransport): """Class containing knit test helper routines.""" - def make_test_knit(self, annotate=False, name='test'): + def make_test_knit(self, annotate=False, name="test"): mapper = ConstantMapper(name) return make_file_factory(annotate, mapper)(self.get_transport()) @@ -1499,57 +1566,68 @@ def test_sha_exception_has_text(self): target = self.make_test_knit(name="target") if not source._max_delta_chain: raise TestNotApplicable( - "cannot get delta-caused sha failures without deltas.") + "cannot get delta-caused sha failures without deltas." + ) # create a basis - basis = (b'basis',) - broken = (b'broken',) - source.add_lines(basis, (), [b'foo\n']) - source.add_lines(broken, (basis,), [b'foo\n', b'bar\n']) + basis = (b"basis",) + broken = (b"broken",) + source.add_lines(basis, (), [b"foo\n"]) + source.add_lines(broken, (basis,), [b"foo\n", b"bar\n"]) # Seed target with a bad basis text - target.add_lines(basis, (), [b'gam\n']) + target.add_lines(basis, (), [b"gam\n"]) target.insert_record_stream( - source.get_record_stream([broken], 'unordered', False)) - err = self.assertRaises(KnitCorrupt, - next(target.get_record_stream([broken], 'unordered', True - )).get_bytes_as, 'chunked') - self.assertEqual([b'gam\n', b'bar\n'], err.content) + source.get_record_stream([broken], "unordered", False) + ) + err = self.assertRaises( + KnitCorrupt, + next(target.get_record_stream([broken], "unordered", True)).get_bytes_as, + "chunked", + ) + self.assertEqual([b"gam\n", b"bar\n"], err.content) # Test for formatting with live data self.assertStartsWith(str(err), "Knit ") class TestKnitIndex(KnitTests): - def test_add_versions_dictionary_compresses(self): """Adding versions to the index should update the lookup dict.""" knit = self.make_test_knit() idx = knit._index - idx.add_records([((b'a-1',), [b'fulltext'], ((b'a-1',), 0, 0), [])]) - self.check_file_contents('test.kndx', - b'# bzr knit index 8\n' - b'\n' - b'a-1 fulltext 0 0 :' - ) - idx.add_records([ - ((b'a-2',), [b'fulltext'], ((b'a-2',), 0, 0), [(b'a-1',)]), - ((b'a-3',), [b'fulltext'], ((b'a-3',), 0, 0), [(b'a-2',)]), - ]) - self.check_file_contents('test.kndx', - b'# bzr knit index 8\n' - b'\n' - b'a-1 fulltext 0 0 :\n' - b'a-2 fulltext 0 0 0 :\n' - b'a-3 fulltext 0 0 1 :' - ) - self.assertEqual({(b'a-3',), (b'a-1',), (b'a-2',)}, idx.keys()) - self.assertEqual({ - (b'a-1',): (((b'a-1',), 0, 0), None, (), ('fulltext', False)), - (b'a-2',): (((b'a-2',), 0, 0), None, ((b'a-1',),), ('fulltext', False)), - (b'a-3',): (((b'a-3',), 0, 0), None, ((b'a-2',),), ('fulltext', False)), - }, idx.get_build_details(idx.keys())) - self.assertEqual({(b'a-1',): (), - (b'a-2',): ((b'a-1',),), - (b'a-3',): ((b'a-2',),), }, - idx.get_parent_map(idx.keys())) + idx.add_records([((b"a-1",), [b"fulltext"], ((b"a-1",), 0, 0), [])]) + self.check_file_contents( + "test.kndx", b"# 
bzr knit index 8\n" b"\n" b"a-1 fulltext 0 0 :" + ) + idx.add_records( + [ + ((b"a-2",), [b"fulltext"], ((b"a-2",), 0, 0), [(b"a-1",)]), + ((b"a-3",), [b"fulltext"], ((b"a-3",), 0, 0), [(b"a-2",)]), + ] + ) + self.check_file_contents( + "test.kndx", + b"# bzr knit index 8\n" + b"\n" + b"a-1 fulltext 0 0 :\n" + b"a-2 fulltext 0 0 0 :\n" + b"a-3 fulltext 0 0 1 :", + ) + self.assertEqual({(b"a-3",), (b"a-1",), (b"a-2",)}, idx.keys()) + self.assertEqual( + { + (b"a-1",): (((b"a-1",), 0, 0), None, (), ("fulltext", False)), + (b"a-2",): (((b"a-2",), 0, 0), None, ((b"a-1",),), ("fulltext", False)), + (b"a-3",): (((b"a-3",), 0, 0), None, ((b"a-2",),), ("fulltext", False)), + }, + idx.get_build_details(idx.keys()), + ) + self.assertEqual( + { + (b"a-1",): (), + (b"a-2",): ((b"a-1",),), + (b"a-3",): ((b"a-2",),), + }, + idx.get_parent_map(idx.keys()), + ) def test_add_versions_fails_clean(self): """If add_versions fails in the middle, it restores a pristine state. @@ -1565,24 +1643,25 @@ def test_add_versions_fails_clean(self): knit = self.make_test_knit() idx = knit._index - idx.add_records([((b'a-1',), [b'fulltext'], ((b'a-1',), 0, 0), [])]) + idx.add_records([((b"a-1",), [b"fulltext"], ((b"a-1",), 0, 0), [])]) class StopEarly(Exception): pass def generate_failure(): """Add some entries and then raise an exception.""" - yield ((b'a-2',), [b'fulltext'], (None, 0, 0), (b'a-1',)) - yield ((b'a-3',), [b'fulltext'], (None, 0, 0), (b'a-2',)) + yield ((b"a-2",), [b"fulltext"], (None, 0, 0), (b"a-1",)) + yield ((b"a-3",), [b"fulltext"], (None, 0, 0), (b"a-2",)) raise StopEarly() # Assert the pre-condition def assertA1Only(): - self.assertEqual({(b'a-1',)}, set(idx.keys())) + self.assertEqual({(b"a-1",)}, set(idx.keys())) self.assertEqual( - {(b'a-1',): (((b'a-1',), 0, 0), None, (), ('fulltext', False))}, - idx.get_build_details([(b'a-1',)])) - self.assertEqual({(b'a-1',): ()}, idx.get_parent_map(idx.keys())) + {(b"a-1",): (((b"a-1",), 0, 0), None, (), ("fulltext", False))}, + idx.get_build_details([(b"a-1",)]), + ) + self.assertEqual({(b"a-1",): ()}, idx.get_parent_map(idx.keys())) assertA1Only() self.assertRaises(StopEarly, idx.add_records, generate_failure()) @@ -1594,14 +1673,14 @@ def test_knit_index_ignores_empty_files(self): # could leave an empty .kndx file, which bzr would later claim was a # corrupted file since the header was not present. In reality, the file # just wasn't created, so it should be ignored. - t = _mod_transport.get_transport_from_path('.') - t.put_bytes('test.kndx', b'') + t = _mod_transport.get_transport_from_path(".") + t.put_bytes("test.kndx", b"") self.make_test_knit() def test_knit_index_checks_header(self): - t = _mod_transport.get_transport_from_path('.') - t.put_bytes('test.kndx', b'# not really a knit header\n\n') + t = _mod_transport.get_transport_from_path(".") + t.put_bytes("test.kndx", b"# not really a knit header\n\n") k = self.make_test_knit() self.assertRaises(KnitHeaderError, k.keys) @@ -1629,21 +1708,48 @@ def two_graph_index(self, deltas=False, catch_adds=False): # build a complex graph across several indices. 
if deltas: # delta compression inn the index - index1 = self.make_g_index('1', 2, [ - ((b'tip', ), b'N0 100', ([(b'parent', )], [], )), - ((b'tail', ), b'', ([], []))]) - index2 = self.make_g_index('2', 2, [ - ((b'parent', ), b' 100 78', - ([(b'tail', ), (b'ghost', )], [(b'tail', )])), - ((b'separate', ), b'', ([], []))]) + index1 = self.make_g_index( + "1", + 2, + [ + ( + (b"tip",), + b"N0 100", + ( + [(b"parent",)], + [], + ), + ), + ((b"tail",), b"", ([], [])), + ], + ) + index2 = self.make_g_index( + "2", + 2, + [ + ( + (b"parent",), + b" 100 78", + ([(b"tail",), (b"ghost",)], [(b"tail",)]), + ), + ((b"separate",), b"", ([], [])), + ], + ) else: # just blob location and graph in the index. - index1 = self.make_g_index('1', 1, [ - ((b'tip', ), b'N0 100', ([(b'parent', )], )), - ((b'tail', ), b'', ([], ))]) - index2 = self.make_g_index('2', 1, [ - ((b'parent', ), b' 100 78', ([(b'tail', ), (b'ghost', )], )), - ((b'separate', ), b'', ([], ))]) + index1 = self.make_g_index( + "1", + 1, + [((b"tip",), b"N0 100", ([(b"parent",)],)), ((b"tail",), b"", ([],))], + ) + index2 = self.make_g_index( + "2", + 1, + [ + ((b"parent",), b" 100 78", ([(b"tail",), (b"ghost",)],)), + ((b"separate",), b"", ([],)), + ], + ) combined_index = CombinedGraphIndex([index1, index2]) if catch_adds: self.combined_index = combined_index @@ -1651,125 +1757,172 @@ def two_graph_index(self, deltas=False, catch_adds=False): add_callback = self.catch_add else: add_callback = None - return _KnitGraphIndex(combined_index, lambda: True, deltas=deltas, - add_callback=add_callback) + return _KnitGraphIndex( + combined_index, lambda: True, deltas=deltas, add_callback=add_callback + ) def test_keys(self): index = self.two_graph_index() - self.assertEqual({(b'tail',), (b'tip',), (b'parent',), (b'separate',)}, - set(index.keys())) + self.assertEqual( + {(b"tail",), (b"tip",), (b"parent",), (b"separate",)}, set(index.keys()) + ) def test_get_position(self): index = self.two_graph_index() self.assertEqual( - (index._graph_index._indices[0], 0, 100), index.get_position((b'tip',))) + (index._graph_index._indices[0], 0, 100), index.get_position((b"tip",)) + ) self.assertEqual( - (index._graph_index._indices[1], 100, 78), index.get_position((b'parent',))) + (index._graph_index._indices[1], 100, 78), index.get_position((b"parent",)) + ) def test_get_method_deltas(self): index = self.two_graph_index(deltas=True) - self.assertEqual('fulltext', index.get_method((b'tip',))) - self.assertEqual('line-delta', index.get_method((b'parent',))) + self.assertEqual("fulltext", index.get_method((b"tip",))) + self.assertEqual("line-delta", index.get_method((b"parent",))) def test_get_method_no_deltas(self): # check that the parent-history lookup is ignored with deltas=False. index = self.two_graph_index(deltas=False) - self.assertEqual('fulltext', index.get_method((b'tip',))) - self.assertEqual('fulltext', index.get_method((b'parent',))) + self.assertEqual("fulltext", index.get_method((b"tip",))) + self.assertEqual("fulltext", index.get_method((b"parent",))) def test_get_options_deltas(self): index = self.two_graph_index(deltas=True) - self.assertEqual([b'fulltext', b'no-eol'], - index.get_options((b'tip',))) - self.assertEqual([b'line-delta'], index.get_options((b'parent',))) + self.assertEqual([b"fulltext", b"no-eol"], index.get_options((b"tip",))) + self.assertEqual([b"line-delta"], index.get_options((b"parent",))) def test_get_options_no_deltas(self): # check that the parent-history lookup is ignored with deltas=False. 
index = self.two_graph_index(deltas=False) - self.assertEqual([b'fulltext', b'no-eol'], - index.get_options((b'tip',))) - self.assertEqual([b'fulltext'], index.get_options((b'parent',))) + self.assertEqual([b"fulltext", b"no-eol"], index.get_options((b"tip",))) + self.assertEqual([b"fulltext"], index.get_options((b"parent",))) def test_get_parent_map(self): index = self.two_graph_index() - self.assertEqual({(b'parent',): ((b'tail',), (b'ghost',))}, - index.get_parent_map([(b'parent',), (b'ghost',)])) + self.assertEqual( + {(b"parent",): ((b"tail",), (b"ghost",))}, + index.get_parent_map([(b"parent",), (b"ghost",)]), + ) def catch_add(self, entries): self.caught_entries.append(entries) def test_add_no_callback_errors(self): index = self.two_graph_index() - self.assertRaises(errors.ReadOnlyError, index.add_records, - [((b'new',), b'fulltext,no-eol', (None, 50, 60), [b'separate'])]) + self.assertRaises( + errors.ReadOnlyError, + index.add_records, + [((b"new",), b"fulltext,no-eol", (None, 50, 60), [b"separate"])], + ) def test_add_version_smoke(self): index = self.two_graph_index(catch_adds=True) - index.add_records([((b'new',), b'fulltext,no-eol', (None, 50, 60), - [(b'separate',)])]) - self.assertEqual([[((b'new', ), b'N50 60', (((b'separate',),),))]], - self.caught_entries) + index.add_records( + [((b"new",), b"fulltext,no-eol", (None, 50, 60), [(b"separate",)])] + ) + self.assertEqual( + [[((b"new",), b"N50 60", (((b"separate",),),))]], self.caught_entries + ) def test_add_version_delta_not_delta_index(self): index = self.two_graph_index(catch_adds=True) - self.assertRaises(KnitCorrupt, index.add_records, - [((b'new',), b'no-eol,line-delta', (None, 0, 100), [(b'parent',)])]) + self.assertRaises( + KnitCorrupt, + index.add_records, + [((b"new",), b"no-eol,line-delta", (None, 0, 100), [(b"parent",)])], + ) self.assertEqual([], self.caught_entries) def test_add_version_same_dup(self): index = self.two_graph_index(catch_adds=True) # options can be spelt two different ways index.add_records( - [((b'tip',), b'fulltext,no-eol', (None, 0, 100), [(b'parent',)])]) + [((b"tip",), b"fulltext,no-eol", (None, 0, 100), [(b"parent",)])] + ) index.add_records( - [((b'tip',), b'no-eol,fulltext', (None, 0, 100), [(b'parent',)])]) + [((b"tip",), b"no-eol,fulltext", (None, 0, 100), [(b"parent",)])] + ) # position/length are ignored (because each pack could have fulltext or # delta, and be at a different position. 
- index.add_records([((b'tip',), b'fulltext,no-eol', (None, 50, 100), - [(b'parent',)])]) - index.add_records([((b'tip',), b'fulltext,no-eol', (None, 0, 1000), - [(b'parent',)])]) + index.add_records( + [((b"tip",), b"fulltext,no-eol", (None, 50, 100), [(b"parent",)])] + ) + index.add_records( + [((b"tip",), b"fulltext,no-eol", (None, 0, 1000), [(b"parent",)])] + ) # but neither should have added data: self.assertEqual([[], [], [], []], self.caught_entries) def test_add_version_different_dup(self): index = self.two_graph_index(deltas=True, catch_adds=True) # change options - self.assertRaises(KnitCorrupt, index.add_records, - [((b'tip',), b'line-delta', (None, 0, 100), [(b'parent',)])]) - self.assertRaises(KnitCorrupt, index.add_records, - [((b'tip',), b'fulltext', (None, 0, 100), [(b'parent',)])]) + self.assertRaises( + KnitCorrupt, + index.add_records, + [((b"tip",), b"line-delta", (None, 0, 100), [(b"parent",)])], + ) + self.assertRaises( + KnitCorrupt, + index.add_records, + [((b"tip",), b"fulltext", (None, 0, 100), [(b"parent",)])], + ) # parents - self.assertRaises(KnitCorrupt, index.add_records, - [((b'tip',), b'fulltext,no-eol', (None, 0, 100), [])]) + self.assertRaises( + KnitCorrupt, + index.add_records, + [((b"tip",), b"fulltext,no-eol", (None, 0, 100), [])], + ) self.assertEqual([], self.caught_entries) def test_add_versions_nodeltas(self): index = self.two_graph_index(catch_adds=True) - index.add_records([ - ((b'new',), b'fulltext,no-eol', (None, 50, 60), [(b'separate',)]), - ((b'new2',), b'fulltext', (None, 0, 6), [(b'new',)]), - ]) - self.assertEqual([((b'new', ), b'N50 60', (((b'separate',),),)), - ((b'new2', ), b' 0 6', (((b'new',),),))], - sorted(self.caught_entries[0])) + index.add_records( + [ + ((b"new",), b"fulltext,no-eol", (None, 50, 60), [(b"separate",)]), + ((b"new2",), b"fulltext", (None, 0, 6), [(b"new",)]), + ] + ) + self.assertEqual( + [ + ((b"new",), b"N50 60", (((b"separate",),),)), + ((b"new2",), b" 0 6", (((b"new",),),)), + ], + sorted(self.caught_entries[0]), + ) self.assertEqual(1, len(self.caught_entries)) def test_add_versions_deltas(self): index = self.two_graph_index(deltas=True, catch_adds=True) - index.add_records([ - ((b'new',), b'fulltext,no-eol', (None, 50, 60), [(b'separate',)]), - ((b'new2',), b'line-delta', (None, 0, 6), [(b'new',)]), - ]) - self.assertEqual([((b'new', ), b'N50 60', (((b'separate',),), ())), - ((b'new2', ), b' 0 6', (((b'new',),), ((b'new',),), ))], - sorted(self.caught_entries[0])) + index.add_records( + [ + ((b"new",), b"fulltext,no-eol", (None, 50, 60), [(b"separate",)]), + ((b"new2",), b"line-delta", (None, 0, 6), [(b"new",)]), + ] + ) + self.assertEqual( + [ + ((b"new",), b"N50 60", (((b"separate",),), ())), + ( + (b"new2",), + b" 0 6", + ( + ((b"new",),), + ((b"new",),), + ), + ), + ], + sorted(self.caught_entries[0]), + ) self.assertEqual(1, len(self.caught_entries)) def test_add_versions_delta_not_delta_index(self): index = self.two_graph_index(catch_adds=True) - self.assertRaises(KnitCorrupt, index.add_records, - [((b'new',), b'no-eol,line-delta', (None, 0, 100), [(b'parent',)])]) + self.assertRaises( + KnitCorrupt, + index.add_records, + [((b"new",), b"no-eol,line-delta", (None, 0, 100), [(b"parent",)])], + ) self.assertEqual([], self.caught_entries) def test_add_versions_random_id_accepted(self): @@ -1779,53 +1932,88 @@ def test_add_versions_random_id_accepted(self): def test_add_versions_same_dup(self): index = self.two_graph_index(catch_adds=True) # options can be spelt two different ways - 
index.add_records([((b'tip',), b'fulltext,no-eol', (None, 0, 100), - [(b'parent',)])]) - index.add_records([((b'tip',), b'no-eol,fulltext', (None, 0, 100), - [(b'parent',)])]) + index.add_records( + [((b"tip",), b"fulltext,no-eol", (None, 0, 100), [(b"parent",)])] + ) + index.add_records( + [((b"tip",), b"no-eol,fulltext", (None, 0, 100), [(b"parent",)])] + ) # position/length are ignored (because each pack could have fulltext or # delta, and be at a different position. - index.add_records([((b'tip',), b'fulltext,no-eol', (None, 50, 100), - [(b'parent',)])]) - index.add_records([((b'tip',), b'fulltext,no-eol', (None, 0, 1000), - [(b'parent',)])]) + index.add_records( + [((b"tip",), b"fulltext,no-eol", (None, 50, 100), [(b"parent",)])] + ) + index.add_records( + [((b"tip",), b"fulltext,no-eol", (None, 0, 1000), [(b"parent",)])] + ) # but neither should have added data. self.assertEqual([[], [], [], []], self.caught_entries) def test_add_versions_different_dup(self): index = self.two_graph_index(deltas=True, catch_adds=True) # change options - self.assertRaises(KnitCorrupt, index.add_records, - [((b'tip',), b'line-delta', (None, 0, 100), [(b'parent',)])]) - self.assertRaises(KnitCorrupt, index.add_records, - [((b'tip',), b'fulltext', (None, 0, 100), [(b'parent',)])]) + self.assertRaises( + KnitCorrupt, + index.add_records, + [((b"tip",), b"line-delta", (None, 0, 100), [(b"parent",)])], + ) + self.assertRaises( + KnitCorrupt, + index.add_records, + [((b"tip",), b"fulltext", (None, 0, 100), [(b"parent",)])], + ) # parents - self.assertRaises(KnitCorrupt, index.add_records, - [((b'tip',), b'fulltext,no-eol', (None, 0, 100), [])]) + self.assertRaises( + KnitCorrupt, + index.add_records, + [((b"tip",), b"fulltext,no-eol", (None, 0, 100), [])], + ) # change options in the second record - self.assertRaises(KnitCorrupt, index.add_records, - [((b'tip',), b'fulltext,no-eol', (None, 0, 100), [(b'parent',)]), - ((b'tip',), b'line-delta', (None, 0, 100), [(b'parent',)])]) + self.assertRaises( + KnitCorrupt, + index.add_records, + [ + ((b"tip",), b"fulltext,no-eol", (None, 0, 100), [(b"parent",)]), + ((b"tip",), b"line-delta", (None, 0, 100), [(b"parent",)]), + ], + ) self.assertEqual([], self.caught_entries) def make_g_index_missing_compression_parent(self): - graph_index = self.make_g_index('missing_comp', 2, - [((b'tip', ), b' 100 78', - ([(b'missing-parent', ), (b'ghost', )], [(b'missing-parent', )]))]) + graph_index = self.make_g_index( + "missing_comp", + 2, + [ + ( + (b"tip",), + b" 100 78", + ([(b"missing-parent",), (b"ghost",)], [(b"missing-parent",)]), + ) + ], + ) return graph_index def make_g_index_missing_parent(self): - graph_index = self.make_g_index('missing_parent', 2, - [((b'parent', ), b' 100 78', ([], [])), - ((b'tip', ), b' 100 78', - ([(b'parent', ), (b'missing-parent', )], [(b'parent', )])), - ]) + graph_index = self.make_g_index( + "missing_parent", + 2, + [ + ((b"parent",), b" 100 78", ([], [])), + ( + (b"tip",), + b" 100 78", + ([(b"parent",), (b"missing-parent",)], [(b"parent",)]), + ), + ], + ) return graph_index def make_g_index_no_external_refs(self): - graph_index = self.make_g_index('no_external_refs', 2, - [((b'rev', ), b' 100 78', - ([(b'parent', ), (b'ghost', )], []))]) + graph_index = self.make_g_index( + "no_external_refs", + 2, + [((b"rev",), b" 100 78", ([(b"parent",), (b"ghost",)], []))], + ) return graph_index def test_add_good_unvalidated_index(self): @@ -1844,30 +2032,42 @@ def test_add_missing_compression_parent_unvalidated_index(self): # examined, otherwise 
'ghost' would also be reported as a missing # parent. self.assertEqual( - frozenset([(b'missing-parent',)]), - index.get_missing_compression_parents()) + frozenset([(b"missing-parent",)]), index.get_missing_compression_parents() + ) def test_add_missing_noncompression_parent_unvalidated_index(self): unvalidated = self.make_g_index_missing_parent() combined = CombinedGraphIndex([unvalidated]) - index = _KnitGraphIndex(combined, lambda: True, deltas=True, - track_external_parent_refs=True) + index = _KnitGraphIndex( + combined, lambda: True, deltas=True, track_external_parent_refs=True + ) index.scan_unvalidated_index(unvalidated) - self.assertEqual( - frozenset([(b'missing-parent',)]), index.get_missing_parents()) + self.assertEqual(frozenset([(b"missing-parent",)]), index.get_missing_parents()) def test_track_external_parent_refs(self): - g_index = self.make_g_index('empty', 2, []) + g_index = self.make_g_index("empty", 2, []) combined = CombinedGraphIndex([g_index]) - index = _KnitGraphIndex(combined, lambda: True, deltas=True, - add_callback=self.catch_add, track_external_parent_refs=True) + index = _KnitGraphIndex( + combined, + lambda: True, + deltas=True, + add_callback=self.catch_add, + track_external_parent_refs=True, + ) self.caught_entries = [] - index.add_records([ - ((b'new-key',), b'fulltext,no-eol', (None, 50, 60), - [(b'parent-1',), (b'parent-2',)])]) + index.add_records( + [ + ( + (b"new-key",), + b"fulltext,no-eol", + (None, 50, 60), + [(b"parent-1",), (b"parent-2",)], + ) + ] + ) self.assertEqual( - frozenset([(b'parent-1',), (b'parent-2',)]), - index.get_missing_parents()) + frozenset([(b"parent-1",), (b"parent-2",)]), index.get_missing_parents() + ) def test_add_unvalidated_index_with_present_external_references(self): index = self.two_graph_index(deltas=True) @@ -1880,38 +2080,62 @@ def test_add_unvalidated_index_with_present_external_references(self): self.assertEqual(frozenset(), index.get_missing_compression_parents()) def make_new_missing_parent_g_index(self, name): - missing_parent = name.encode('ascii') + b'-missing-parent' - graph_index = self.make_g_index(name, 2, - [((name.encode('ascii') + b'tip', ), b' 100 78', - ([(missing_parent, ), (b'ghost', )], [(missing_parent, )]))]) + missing_parent = name.encode("ascii") + b"-missing-parent" + graph_index = self.make_g_index( + name, + 2, + [ + ( + (name.encode("ascii") + b"tip",), + b" 100 78", + ([(missing_parent,), (b"ghost",)], [(missing_parent,)]), + ) + ], + ) return graph_index def test_add_mulitiple_unvalidated_indices_with_missing_parents(self): - g_index_1 = self.make_new_missing_parent_g_index('one') - g_index_2 = self.make_new_missing_parent_g_index('two') + g_index_1 = self.make_new_missing_parent_g_index("one") + g_index_2 = self.make_new_missing_parent_g_index("two") combined = CombinedGraphIndex([g_index_1, g_index_2]) index = _KnitGraphIndex(combined, lambda: True, deltas=True) index.scan_unvalidated_index(g_index_1) index.scan_unvalidated_index(g_index_2) self.assertEqual( - frozenset([(b'one-missing-parent',), (b'two-missing-parent',)]), - index.get_missing_compression_parents()) + frozenset([(b"one-missing-parent",), (b"two-missing-parent",)]), + index.get_missing_compression_parents(), + ) def test_add_mulitiple_unvalidated_indices_with_mutual_dependencies(self): - graph_index_a = self.make_g_index('one', 2, - [((b'parent-one', ), b' 100 78', ([(b'non-compression-parent',)], [])), - ((b'child-of-two', ), b' 100 78', - ([(b'parent-two',)], [(b'parent-two',)]))]) - graph_index_b = 
self.make_g_index('two', 2, - [((b'parent-two', ), b' 100 78', ([(b'non-compression-parent',)], [])), - ((b'child-of-one', ), b' 100 78', - ([(b'parent-one',)], [(b'parent-one',)]))]) + graph_index_a = self.make_g_index( + "one", + 2, + [ + ((b"parent-one",), b" 100 78", ([(b"non-compression-parent",)], [])), + ( + (b"child-of-two",), + b" 100 78", + ([(b"parent-two",)], [(b"parent-two",)]), + ), + ], + ) + graph_index_b = self.make_g_index( + "two", + 2, + [ + ((b"parent-two",), b" 100 78", ([(b"non-compression-parent",)], [])), + ( + (b"child-of-one",), + b" 100 78", + ([(b"parent-one",)], [(b"parent-one",)]), + ), + ], + ) combined = CombinedGraphIndex([graph_index_a, graph_index_b]) index = _KnitGraphIndex(combined, lambda: True, deltas=True) index.scan_unvalidated_index(graph_index_a) index.scan_unvalidated_index(graph_index_b) - self.assertEqual( - frozenset([]), index.get_missing_compression_parents()) + self.assertEqual(frozenset([]), index.get_missing_compression_parents()) class TestNoParentsGraphIndexKnit(KnitTests): @@ -1929,17 +2153,22 @@ def make_g_index(self, name, ref_lists=0, nodes=None): return GraphIndex(trans, name, size) def test_add_good_unvalidated_index(self): - unvalidated = self.make_g_index('unvalidated') + unvalidated = self.make_g_index("unvalidated") combined = CombinedGraphIndex([unvalidated]) index = _KnitGraphIndex(combined, lambda: True, parents=False) index.scan_unvalidated_index(unvalidated) - self.assertEqual(frozenset(), - index.get_missing_compression_parents()) + self.assertEqual(frozenset(), index.get_missing_compression_parents()) def test_parents_deltas_incompatible(self): index = CombinedGraphIndex([]) - self.assertRaises(knit.KnitError, _KnitGraphIndex, lambda: True, - index, deltas=True, parents=False) + self.assertRaises( + knit.KnitError, + _KnitGraphIndex, + lambda: True, + index, + deltas=True, + parents=False, + ) def two_graph_index(self, catch_adds=False): """Build a two-graph index. @@ -1948,12 +2177,10 @@ def two_graph_index(self, catch_adds=False): lists and 'parent' set to a delta-compressed against tail. """ # put several versions in the index. 
- index1 = self.make_g_index('1', 0, [ - ((b'tip', ), b'N0 100'), - ((b'tail', ), b'')]) - index2 = self.make_g_index('2', 0, [ - ((b'parent', ), b' 100 78'), - ((b'separate', ), b'')]) + index1 = self.make_g_index("1", 0, [((b"tip",), b"N0 100"), ((b"tail",), b"")]) + index2 = self.make_g_index( + "2", 0, [((b"parent",), b" 100 78"), ((b"separate",), b"")] + ) combined_index = CombinedGraphIndex([index1, index2]) if catch_adds: self.combined_index = combined_index @@ -1961,108 +2188,134 @@ def two_graph_index(self, catch_adds=False): add_callback = self.catch_add else: add_callback = None - return _KnitGraphIndex(combined_index, lambda: True, parents=False, - add_callback=add_callback) + return _KnitGraphIndex( + combined_index, lambda: True, parents=False, add_callback=add_callback + ) def test_keys(self): index = self.two_graph_index() - self.assertEqual({(b'tail',), (b'tip',), (b'parent',), (b'separate',)}, - set(index.keys())) + self.assertEqual( + {(b"tail",), (b"tip",), (b"parent",), (b"separate",)}, set(index.keys()) + ) def test_get_position(self): index = self.two_graph_index() - self.assertEqual((index._graph_index._indices[0], 0, 100), - index.get_position((b'tip',))) - self.assertEqual((index._graph_index._indices[1], 100, 78), - index.get_position((b'parent',))) + self.assertEqual( + (index._graph_index._indices[0], 0, 100), index.get_position((b"tip",)) + ) + self.assertEqual( + (index._graph_index._indices[1], 100, 78), index.get_position((b"parent",)) + ) def test_get_method(self): index = self.two_graph_index() - self.assertEqual('fulltext', index.get_method((b'tip',))) - self.assertEqual([b'fulltext'], index.get_options((b'parent',))) + self.assertEqual("fulltext", index.get_method((b"tip",))) + self.assertEqual([b"fulltext"], index.get_options((b"parent",))) def test_get_options(self): index = self.two_graph_index() - self.assertEqual([b'fulltext', b'no-eol'], - index.get_options((b'tip',))) - self.assertEqual([b'fulltext'], index.get_options((b'parent',))) + self.assertEqual([b"fulltext", b"no-eol"], index.get_options((b"tip",))) + self.assertEqual([b"fulltext"], index.get_options((b"parent",))) def test_get_parent_map(self): index = self.two_graph_index() - self.assertEqual({(b'parent',): None}, - index.get_parent_map([(b'parent',), (b'ghost',)])) + self.assertEqual( + {(b"parent",): None}, index.get_parent_map([(b"parent",), (b"ghost",)]) + ) def catch_add(self, entries): self.caught_entries.append(entries) def test_add_no_callback_errors(self): index = self.two_graph_index() - self.assertRaises(errors.ReadOnlyError, index.add_records, - [((b'new',), b'fulltext,no-eol', (None, 50, 60), [(b'separate',)])]) + self.assertRaises( + errors.ReadOnlyError, + index.add_records, + [((b"new",), b"fulltext,no-eol", (None, 50, 60), [(b"separate",)])], + ) def test_add_version_smoke(self): index = self.two_graph_index(catch_adds=True) - index.add_records( - [((b'new',), b'fulltext,no-eol', (None, 50, 60), [])]) - self.assertEqual([[((b'new', ), b'N50 60')]], - self.caught_entries) + index.add_records([((b"new",), b"fulltext,no-eol", (None, 50, 60), [])]) + self.assertEqual([[((b"new",), b"N50 60")]], self.caught_entries) def test_add_version_delta_not_delta_index(self): index = self.two_graph_index(catch_adds=True) - self.assertRaises(KnitCorrupt, index.add_records, - [((b'new',), b'no-eol,line-delta', (None, 0, 100), [])]) + self.assertRaises( + KnitCorrupt, + index.add_records, + [((b"new",), b"no-eol,line-delta", (None, 0, 100), [])], + ) self.assertEqual([], 
self.caught_entries) def test_add_version_same_dup(self): index = self.two_graph_index(catch_adds=True) # options can be spelt two different ways - index.add_records( - [((b'tip',), b'fulltext,no-eol', (None, 0, 100), [])]) - index.add_records( - [((b'tip',), b'no-eol,fulltext', (None, 0, 100), [])]) + index.add_records([((b"tip",), b"fulltext,no-eol", (None, 0, 100), [])]) + index.add_records([((b"tip",), b"no-eol,fulltext", (None, 0, 100), [])]) # position/length are ignored (because each pack could have fulltext or # delta, and be at a different position. - index.add_records( - [((b'tip',), b'fulltext,no-eol', (None, 50, 100), [])]) - index.add_records( - [((b'tip',), b'fulltext,no-eol', (None, 0, 1000), [])]) + index.add_records([((b"tip",), b"fulltext,no-eol", (None, 50, 100), [])]) + index.add_records([((b"tip",), b"fulltext,no-eol", (None, 0, 1000), [])]) # but neither should have added data. self.assertEqual([[], [], [], []], self.caught_entries) def test_add_version_different_dup(self): index = self.two_graph_index(catch_adds=True) # change options - self.assertRaises(KnitCorrupt, index.add_records, - [((b'tip',), b'no-eol,line-delta', (None, 0, 100), [])]) - self.assertRaises(KnitCorrupt, index.add_records, - [((b'tip',), b'line-delta,no-eol', (None, 0, 100), [])]) - self.assertRaises(KnitCorrupt, index.add_records, - [((b'tip',), b'fulltext', (None, 0, 100), [])]) + self.assertRaises( + KnitCorrupt, + index.add_records, + [((b"tip",), b"no-eol,line-delta", (None, 0, 100), [])], + ) + self.assertRaises( + KnitCorrupt, + index.add_records, + [((b"tip",), b"line-delta,no-eol", (None, 0, 100), [])], + ) + self.assertRaises( + KnitCorrupt, + index.add_records, + [((b"tip",), b"fulltext", (None, 0, 100), [])], + ) # parents - self.assertRaises(KnitCorrupt, index.add_records, - [((b'tip',), b'fulltext,no-eol', (None, 0, 100), [(b'parent',)])]) + self.assertRaises( + KnitCorrupt, + index.add_records, + [((b"tip",), b"fulltext,no-eol", (None, 0, 100), [(b"parent",)])], + ) self.assertEqual([], self.caught_entries) def test_add_versions(self): index = self.two_graph_index(catch_adds=True) - index.add_records([ - ((b'new',), b'fulltext,no-eol', (None, 50, 60), []), - ((b'new2',), b'fulltext', (None, 0, 6), []), - ]) - self.assertEqual([((b'new', ), b'N50 60'), ((b'new2', ), b' 0 6')], - sorted(self.caught_entries[0])) + index.add_records( + [ + ((b"new",), b"fulltext,no-eol", (None, 50, 60), []), + ((b"new2",), b"fulltext", (None, 0, 6), []), + ] + ) + self.assertEqual( + [((b"new",), b"N50 60"), ((b"new2",), b" 0 6")], + sorted(self.caught_entries[0]), + ) self.assertEqual(1, len(self.caught_entries)) def test_add_versions_delta_not_delta_index(self): index = self.two_graph_index(catch_adds=True) - self.assertRaises(KnitCorrupt, index.add_records, - [((b'new',), b'no-eol,line-delta', (None, 0, 100), [(b'parent',)])]) + self.assertRaises( + KnitCorrupt, + index.add_records, + [((b"new",), b"no-eol,line-delta", (None, 0, 100), [(b"parent",)])], + ) self.assertEqual([], self.caught_entries) def test_add_versions_parents_not_parents_index(self): index = self.two_graph_index(catch_adds=True) - self.assertRaises(KnitCorrupt, index.add_records, - [((b'new',), b'no-eol,fulltext', (None, 0, 100), [(b'parent',)])]) + self.assertRaises( + KnitCorrupt, + index.add_records, + [((b"new",), b"no-eol,fulltext", (None, 0, 100), [(b"parent",)])], + ) self.assertEqual([], self.caught_entries) def test_add_versions_random_id_accepted(self): @@ -2072,64 +2325,79 @@ def 
test_add_versions_random_id_accepted(self): def test_add_versions_same_dup(self): index = self.two_graph_index(catch_adds=True) # options can be spelt two different ways - index.add_records( - [((b'tip',), b'fulltext,no-eol', (None, 0, 100), [])]) - index.add_records( - [((b'tip',), b'no-eol,fulltext', (None, 0, 100), [])]) + index.add_records([((b"tip",), b"fulltext,no-eol", (None, 0, 100), [])]) + index.add_records([((b"tip",), b"no-eol,fulltext", (None, 0, 100), [])]) # position/length are ignored (because each pack could have fulltext or # delta, and be at a different position. - index.add_records( - [((b'tip',), b'fulltext,no-eol', (None, 50, 100), [])]) - index.add_records( - [((b'tip',), b'fulltext,no-eol', (None, 0, 1000), [])]) + index.add_records([((b"tip",), b"fulltext,no-eol", (None, 50, 100), [])]) + index.add_records([((b"tip",), b"fulltext,no-eol", (None, 0, 1000), [])]) # but neither should have added data. self.assertEqual([[], [], [], []], self.caught_entries) def test_add_versions_different_dup(self): index = self.two_graph_index(catch_adds=True) # change options - self.assertRaises(KnitCorrupt, index.add_records, - [((b'tip',), b'no-eol,line-delta', (None, 0, 100), [])]) - self.assertRaises(KnitCorrupt, index.add_records, - [((b'tip',), b'line-delta,no-eol', (None, 0, 100), [])]) - self.assertRaises(KnitCorrupt, index.add_records, - [((b'tip',), b'fulltext', (None, 0, 100), [])]) + self.assertRaises( + KnitCorrupt, + index.add_records, + [((b"tip",), b"no-eol,line-delta", (None, 0, 100), [])], + ) + self.assertRaises( + KnitCorrupt, + index.add_records, + [((b"tip",), b"line-delta,no-eol", (None, 0, 100), [])], + ) + self.assertRaises( + KnitCorrupt, + index.add_records, + [((b"tip",), b"fulltext", (None, 0, 100), [])], + ) # parents - self.assertRaises(KnitCorrupt, index.add_records, - [((b'tip',), b'fulltext,no-eol', (None, 0, 100), [(b'parent',)])]) + self.assertRaises( + KnitCorrupt, + index.add_records, + [((b"tip",), b"fulltext,no-eol", (None, 0, 100), [(b"parent",)])], + ) # change options in the second record - self.assertRaises(KnitCorrupt, index.add_records, - [((b'tip',), b'fulltext,no-eol', (None, 0, 100), []), - ((b'tip',), b'no-eol,line-delta', (None, 0, 100), [])]) + self.assertRaises( + KnitCorrupt, + index.add_records, + [ + ((b"tip",), b"fulltext,no-eol", (None, 0, 100), []), + ((b"tip",), b"no-eol,line-delta", (None, 0, 100), []), + ], + ) self.assertEqual([], self.caught_entries) class TestKnitVersionedFiles(KnitTests): - - def assertGroupKeysForIo(self, exp_groups, keys, non_local_keys, - positions, _min_buffer_size=None): + def assertGroupKeysForIo( + self, exp_groups, keys, non_local_keys, positions, _min_buffer_size=None + ): kvf = self.make_test_knit() if _min_buffer_size is None: _min_buffer_size = knit._STREAM_MIN_BUFFER_SIZE - self.assertEqual(exp_groups, kvf._group_keys_for_io(keys, - non_local_keys, positions, - _min_buffer_size=_min_buffer_size)) + self.assertEqual( + exp_groups, + kvf._group_keys_for_io( + keys, non_local_keys, positions, _min_buffer_size=_min_buffer_size + ), + ) - def assertSplitByPrefix(self, expected_map, expected_prefix_order, - keys): + def assertSplitByPrefix(self, expected_map, expected_prefix_order, keys): split, prefix_order = KnitVersionedFiles._split_by_prefix(keys) self.assertEqual(expected_map, split) self.assertEqual(expected_prefix_order, prefix_order) def test__group_keys_for_io(self): - ft_detail = ('fulltext', False) - ld_detail = ('line-delta', False) - f_a = (b'f', b'a') - f_b = (b'f', b'b') - f_c = 
(b'f', b'c') - g_a = (b'g', b'a') - g_b = (b'g', b'b') - g_c = (b'g', b'c') + ft_detail = ("fulltext", False) + ld_detail = ("line-delta", False) + f_a = (b"f", b"a") + f_b = (b"f", b"b") + f_c = (b"f", b"c") + g_a = (b"g", b"a") + g_b = (b"g", b"b") + g_c = (b"g", b"c") positions = { f_a: (ft_detail, (f_a, 0, 100), None), f_b: (ld_detail, (f_b, 100, 21), f_a), @@ -2137,70 +2405,89 @@ def test__group_keys_for_io(self): g_a: (ft_detail, (g_a, 121, 35), None), g_b: (ld_detail, (g_b, 156, 12), g_a), g_c: (ld_detail, (g_c, 195, 13), g_a), - } - self.assertGroupKeysForIo([([f_a], set())], - [f_a], [], positions) - self.assertGroupKeysForIo([([f_a], {f_a})], - [f_a], [f_a], positions) - self.assertGroupKeysForIo([([f_a, f_b], set())], - [f_a, f_b], [], positions) - self.assertGroupKeysForIo([([f_a, f_b], {f_b})], - [f_a, f_b], [f_b], positions) - self.assertGroupKeysForIo([([f_a, f_b, g_a, g_b], set())], - [f_a, g_a, f_b, g_b], [], positions) - self.assertGroupKeysForIo([([f_a, f_b, g_a, g_b], set())], - [f_a, g_a, f_b, g_b], [], positions, - _min_buffer_size=150) - self.assertGroupKeysForIo([([f_a, f_b], set()), ([g_a, g_b], set())], - [f_a, g_a, f_b, g_b], [], positions, - _min_buffer_size=100) - self.assertGroupKeysForIo([([f_c], set()), ([g_b], set())], - [f_c, g_b], [], positions, - _min_buffer_size=125) - self.assertGroupKeysForIo([([g_b, f_c], set())], - [g_b, f_c], [], positions, - _min_buffer_size=125) + } + self.assertGroupKeysForIo([([f_a], set())], [f_a], [], positions) + self.assertGroupKeysForIo([([f_a], {f_a})], [f_a], [f_a], positions) + self.assertGroupKeysForIo([([f_a, f_b], set())], [f_a, f_b], [], positions) + self.assertGroupKeysForIo([([f_a, f_b], {f_b})], [f_a, f_b], [f_b], positions) + self.assertGroupKeysForIo( + [([f_a, f_b, g_a, g_b], set())], [f_a, g_a, f_b, g_b], [], positions + ) + self.assertGroupKeysForIo( + [([f_a, f_b, g_a, g_b], set())], + [f_a, g_a, f_b, g_b], + [], + positions, + _min_buffer_size=150, + ) + self.assertGroupKeysForIo( + [([f_a, f_b], set()), ([g_a, g_b], set())], + [f_a, g_a, f_b, g_b], + [], + positions, + _min_buffer_size=100, + ) + self.assertGroupKeysForIo( + [([f_c], set()), ([g_b], set())], + [f_c, g_b], + [], + positions, + _min_buffer_size=125, + ) + self.assertGroupKeysForIo( + [([g_b, f_c], set())], [g_b, f_c], [], positions, _min_buffer_size=125 + ) def test__split_by_prefix(self): - self.assertSplitByPrefix({b'f': [(b'f', b'a'), (b'f', b'b')], - b'g': [(b'g', b'b'), (b'g', b'a')], - }, [b'f', b'g'], - [(b'f', b'a'), (b'g', b'b'), - (b'g', b'a'), (b'f', b'b')]) - - self.assertSplitByPrefix({b'f': [(b'f', b'a'), (b'f', b'b')], - b'g': [(b'g', b'b'), (b'g', b'a')], - }, [b'f', b'g'], - [(b'f', b'a'), (b'f', b'b'), - (b'g', b'b'), (b'g', b'a')]) - - self.assertSplitByPrefix({b'f': [(b'f', b'a'), (b'f', b'b')], - b'g': [(b'g', b'b'), (b'g', b'a')], - }, [b'f', b'g'], - [(b'f', b'a'), (b'f', b'b'), - (b'g', b'b'), (b'g', b'a')]) - - self.assertSplitByPrefix({b'f': [(b'f', b'a'), (b'f', b'b')], - b'g': [(b'g', b'b'), (b'g', b'a')], - b'': [(b'a',), (b'b',)] - }, [b'f', b'g', b''], - [(b'f', b'a'), (b'g', b'b'), - (b'a',), (b'b',), - (b'g', b'a'), (b'f', b'b')]) + self.assertSplitByPrefix( + { + b"f": [(b"f", b"a"), (b"f", b"b")], + b"g": [(b"g", b"b"), (b"g", b"a")], + }, + [b"f", b"g"], + [(b"f", b"a"), (b"g", b"b"), (b"g", b"a"), (b"f", b"b")], + ) + + self.assertSplitByPrefix( + { + b"f": [(b"f", b"a"), (b"f", b"b")], + b"g": [(b"g", b"b"), (b"g", b"a")], + }, + [b"f", b"g"], + [(b"f", b"a"), (b"f", b"b"), (b"g", b"b"), (b"g", b"a")], 
+ ) + + self.assertSplitByPrefix( + { + b"f": [(b"f", b"a"), (b"f", b"b")], + b"g": [(b"g", b"b"), (b"g", b"a")], + }, + [b"f", b"g"], + [(b"f", b"a"), (b"f", b"b"), (b"g", b"b"), (b"g", b"a")], + ) + + self.assertSplitByPrefix( + { + b"f": [(b"f", b"a"), (b"f", b"b")], + b"g": [(b"g", b"b"), (b"g", b"a")], + b"": [(b"a",), (b"b",)], + }, + [b"f", b"g", b""], + [(b"f", b"a"), (b"g", b"b"), (b"a",), (b"b",), (b"g", b"a"), (b"f", b"b")], + ) class TestStacking(KnitTests): - def get_basis_and_test_knit(self): - basis = self.make_test_knit(name='basis') + basis = self.make_test_knit(name="basis") basis = RecordingVersionedFilesDecorator(basis) - test = self.make_test_knit(name='test') + test = self.make_test_knit(name="test") test.add_fallback_versioned_files(basis) return basis, test def test_add_fallback_versioned_files(self): - basis = self.make_test_knit(name='basis') - test = self.make_test_knit(name='test') + basis = self.make_test_knit(name="basis") + test = self.make_test_knit(name="test") # It must not error; other tests test that the fallback is referred to # when accessing data. test.add_fallback_versioned_files(basis) @@ -2208,49 +2495,52 @@ def test_add_fallback_versioned_files(self): def test_add_lines(self): # lines added to the test are not added to the basis basis, test = self.get_basis_and_test_knit() - key = (b'foo',) - key_basis = (b'bar',) - key_cross_border = (b'quux',) - key_delta = (b'zaphod',) - test.add_lines(key, (), [b'foo\n']) + key = (b"foo",) + key_basis = (b"bar",) + key_cross_border = (b"quux",) + key_delta = (b"zaphod",) + test.add_lines(key, (), [b"foo\n"]) self.assertEqual({}, basis.get_parent_map([key])) # lines added to the test that reference across the stack do a # fulltext. - basis.add_lines(key_basis, (), [b'foo\n']) + basis.add_lines(key_basis, (), [b"foo\n"]) basis.calls = [] - test.add_lines(key_cross_border, (key_basis,), [b'foo\n']) - self.assertEqual('fulltext', test._index.get_method(key_cross_border)) + test.add_lines(key_cross_border, (key_basis,), [b"foo\n"]) + self.assertEqual("fulltext", test._index.get_method(key_cross_border)) # we don't even need to look at the basis to see that this should be # stored as a fulltext self.assertEqual([], basis.calls) # Subsequent adds do delta. basis.calls = [] - test.add_lines(key_delta, (key_cross_border,), [b'foo\n']) - self.assertEqual('line-delta', test._index.get_method(key_delta)) + test.add_lines(key_delta, (key_cross_border,), [b"foo\n"]) + self.assertEqual("line-delta", test._index.get_method(key_delta)) self.assertEqual([], basis.calls) def test_annotate(self): # annotations from the test knit are answered without asking the basis basis, test = self.get_basis_and_test_knit() - key = (b'foo',) - key_basis = (b'bar',) - test.add_lines(key, (), [b'foo\n']) + key = (b"foo",) + key_basis = (b"bar",) + test.add_lines(key, (), [b"foo\n"]) details = test.annotate(key) - self.assertEqual([(key, b'foo\n')], details) + self.assertEqual([(key, b"foo\n")], details) self.assertEqual([], basis.calls) # But texts that are not in the test knit are looked for in the basis # directly. 
- basis.add_lines(key_basis, (), [b'foo\n', b'bar\n']) + basis.add_lines(key_basis, (), [b"foo\n", b"bar\n"]) basis.calls = [] details = test.annotate(key_basis) - self.assertEqual( - [(key_basis, b'foo\n'), (key_basis, b'bar\n')], details) + self.assertEqual([(key_basis, b"foo\n"), (key_basis, b"bar\n")], details) # Not optimised to date: # self.assertEqual([("annotate", key_basis)], basis.calls) - self.assertEqual([('get_parent_map', {key_basis}), - ('get_parent_map', {key_basis}), - ('get_record_stream', [key_basis], 'topological', True)], - basis.calls) + self.assertEqual( + [ + ("get_parent_map", {key_basis}), + ("get_parent_map", {key_basis}), + ("get_record_stream", [key_basis], "topological", True), + ], + basis.calls, + ) def test_check(self): # At the moment checking a stacked knit does implicitly check the @@ -2261,9 +2551,9 @@ def test_check(self): def test_get_parent_map(self): # parents in the test knit are answered without asking the basis basis, test = self.get_basis_and_test_knit() - key = (b'foo',) - key_basis = (b'bar',) - key_missing = (b'missing',) + key = (b"foo",) + key_basis = (b"bar",) + key_missing = (b"missing",) test.add_lines(key, (), []) parent_map = test.get_parent_map([key]) self.assertEqual({key: ()}, parent_map) @@ -2272,26 +2562,25 @@ def test_get_parent_map(self): basis.add_lines(key_basis, (), []) basis.calls = [] parent_map = test.get_parent_map([key, key_basis, key_missing]) - self.assertEqual({key: (), - key_basis: ()}, parent_map) - self.assertEqual([("get_parent_map", {key_basis, key_missing})], - basis.calls) + self.assertEqual({key: (), key_basis: ()}, parent_map) + self.assertEqual([("get_parent_map", {key_basis, key_missing})], basis.calls) def test_get_record_stream_unordered_fulltexts(self): # records from the test knit are answered without asking the basis: basis, test = self.get_basis_and_test_knit() - key = (b'foo',) - key_basis = (b'bar',) - key_missing = (b'missing',) - test.add_lines(key, (), [b'foo\n']) - records = list(test.get_record_stream([key], 'unordered', True)) + key = (b"foo",) + key_basis = (b"bar",) + key_missing = (b"missing",) + test.add_lines(key, (), [b"foo\n"]) + records = list(test.get_record_stream([key], "unordered", True)) self.assertEqual(1, len(records)) self.assertEqual([], basis.calls) # Missing (from test knit) objects are retrieved from the basis: - basis.add_lines(key_basis, (), [b'foo\n', b'bar\n']) + basis.add_lines(key_basis, (), [b"foo\n", b"bar\n"]) basis.calls = [] - records = list(test.get_record_stream([key_basis, key_missing], - 'unordered', True)) + records = list( + test.get_record_stream([key_basis, key_missing], "unordered", True) + ) self.assertEqual(2, len(records)) calls = list(basis.calls) for record in records: @@ -2299,47 +2588,62 @@ def test_get_record_stream_unordered_fulltexts(self): if record.key == key_missing: self.assertIsInstance(record, AbsentContentFactory) else: - reference = list(basis.get_record_stream([key_basis], - 'unordered', True))[0] + reference = list( + basis.get_record_stream([key_basis], "unordered", True) + )[0] self.assertEqual(reference.key, record.key) self.assertEqual(reference.sha1, record.sha1) self.assertEqual(reference.storage_kind, record.storage_kind) - self.assertEqual(reference.get_bytes_as(reference.storage_kind), - record.get_bytes_as(record.storage_kind)) - self.assertEqual(reference.get_bytes_as('fulltext'), - record.get_bytes_as('fulltext')) + self.assertEqual( + reference.get_bytes_as(reference.storage_kind), + 
record.get_bytes_as(record.storage_kind), + ) + self.assertEqual( + reference.get_bytes_as("fulltext"), record.get_bytes_as("fulltext") + ) # It's not strictly minimal, but it seems reasonable for now for it to # ask which fallbacks have which parents. - self.assertEqual([ - ("get_parent_map", {key_basis, key_missing}), - ("get_record_stream", [key_basis], 'unordered', True)], - calls) + self.assertEqual( + [ + ("get_parent_map", {key_basis, key_missing}), + ("get_record_stream", [key_basis], "unordered", True), + ], + calls, + ) def test_get_record_stream_ordered_fulltexts(self): # ordering is preserved down into the fallback store. basis, test = self.get_basis_and_test_knit() - key = (b'foo',) - key_basis = (b'bar',) - key_basis_2 = (b'quux',) - key_missing = (b'missing',) - test.add_lines(key, (key_basis,), [b'foo\n']) + key = (b"foo",) + key_basis = (b"bar",) + key_basis_2 = (b"quux",) + key_missing = (b"missing",) + test.add_lines(key, (key_basis,), [b"foo\n"]) # Missing (from test knit) objects are retrieved from the basis: - basis.add_lines(key_basis, (key_basis_2,), [b'foo\n', b'bar\n']) - basis.add_lines(key_basis_2, (), [b'quux\n']) + basis.add_lines(key_basis, (key_basis_2,), [b"foo\n", b"bar\n"]) + basis.add_lines(key_basis_2, (), [b"quux\n"]) basis.calls = [] # ask for in non-topological order - records = list(test.get_record_stream( - [key, key_basis, key_missing, key_basis_2], 'topological', True)) + records = list( + test.get_record_stream( + [key, key_basis, key_missing, key_basis_2], "topological", True + ) + ) self.assertEqual(4, len(records)) results = [] for record in records: - self.assertSubset([record.key], - (key_basis, key_missing, key_basis_2, key)) + self.assertSubset([record.key], (key_basis, key_missing, key_basis_2, key)) if record.key == key_missing: self.assertIsInstance(record, AbsentContentFactory) else: - results.append((record.key, record.sha1, record.storage_kind, - record.get_bytes_as('fulltext'))) + results.append( + ( + record.key, + record.sha1, + record.storage_kind, + record.get_bytes_as("fulltext"), + ) + ) calls = list(basis.calls) order = [record[0] for record in results] self.assertEqual([key_basis_2, key_basis, key], order) @@ -2348,44 +2652,46 @@ def test_get_record_stream_ordered_fulltexts(self): source = test else: source = basis - record = next(source.get_record_stream([result[0]], 'unordered', - True)) + record = next(source.get_record_stream([result[0]], "unordered", True)) self.assertEqual(record.key, result[0]) self.assertEqual(record.sha1, result[1]) # We used to check that the storage kind matched, but actually it # depends on whether it was sourced from the basis, or in a single # group, because asking for full texts returns proxy objects to a # _ContentMapGenerator object; so checking the kind is unneeded. - self.assertEqual(record.get_bytes_as('fulltext'), result[3]) + self.assertEqual(record.get_bytes_as("fulltext"), result[3]) # It's not strictly minimal, but it seems reasonable for now for it to # ask which fallbacks have which parents. self.assertEqual(2, len(calls)) self.assertEqual( - ("get_parent_map", {key_basis, key_basis_2, key_missing}), - calls[0]) + ("get_parent_map", {key_basis, key_basis_2, key_missing}), calls[0] + ) # topological is requested from the fallback, because that is what # was requested at the top level. 
self.assertIn( - calls[1], [ - ("get_record_stream", [key_basis_2, - key_basis], 'topological', True), - ("get_record_stream", [key_basis, key_basis_2], 'topological', True)]) + calls[1], + [ + ("get_record_stream", [key_basis_2, key_basis], "topological", True), + ("get_record_stream", [key_basis, key_basis_2], "topological", True), + ], + ) def test_get_record_stream_unordered_deltas(self): # records from the test knit are answered without asking the basis: basis, test = self.get_basis_and_test_knit() - key = (b'foo',) - key_basis = (b'bar',) - key_missing = (b'missing',) - test.add_lines(key, (), [b'foo\n']) - records = list(test.get_record_stream([key], 'unordered', False)) + key = (b"foo",) + key_basis = (b"bar",) + key_missing = (b"missing",) + test.add_lines(key, (), [b"foo\n"]) + records = list(test.get_record_stream([key], "unordered", False)) self.assertEqual(1, len(records)) self.assertEqual([], basis.calls) # Missing (from test knit) objects are retrieved from the basis: - basis.add_lines(key_basis, (), [b'foo\n', b'bar\n']) + basis.add_lines(key_basis, (), [b"foo\n", b"bar\n"]) basis.calls = [] - records = list(test.get_record_stream([key_basis, key_missing], - 'unordered', False)) + records = list( + test.get_record_stream([key_basis, key_missing], "unordered", False) + ) self.assertEqual(2, len(records)) calls = list(basis.calls) for record in records: @@ -2393,45 +2699,59 @@ def test_get_record_stream_unordered_deltas(self): if record.key == key_missing: self.assertIsInstance(record, AbsentContentFactory) else: - reference = list(basis.get_record_stream([key_basis], - 'unordered', False))[0] + reference = list( + basis.get_record_stream([key_basis], "unordered", False) + )[0] self.assertEqual(reference.key, record.key) self.assertEqual(reference.sha1, record.sha1) self.assertEqual(reference.storage_kind, record.storage_kind) - self.assertEqual(reference.get_bytes_as(reference.storage_kind), - record.get_bytes_as(record.storage_kind)) + self.assertEqual( + reference.get_bytes_as(reference.storage_kind), + record.get_bytes_as(record.storage_kind), + ) # It's not strictly minimal, but it seems reasonable for now for it to # ask which fallbacks have which parents. - self.assertEqual([ - ("get_parent_map", {key_basis, key_missing}), - ("get_record_stream", [key_basis], 'unordered', False)], - calls) + self.assertEqual( + [ + ("get_parent_map", {key_basis, key_missing}), + ("get_record_stream", [key_basis], "unordered", False), + ], + calls, + ) def test_get_record_stream_ordered_deltas(self): # ordering is preserved down into the fallback store. 
basis, test = self.get_basis_and_test_knit() - key = (b'foo',) - key_basis = (b'bar',) - key_basis_2 = (b'quux',) - key_missing = (b'missing',) - test.add_lines(key, (key_basis,), [b'foo\n']) + key = (b"foo",) + key_basis = (b"bar",) + key_basis_2 = (b"quux",) + key_missing = (b"missing",) + test.add_lines(key, (key_basis,), [b"foo\n"]) # Missing (from test knit) objects are retrieved from the basis: - basis.add_lines(key_basis, (key_basis_2,), [b'foo\n', b'bar\n']) - basis.add_lines(key_basis_2, (), [b'quux\n']) + basis.add_lines(key_basis, (key_basis_2,), [b"foo\n", b"bar\n"]) + basis.add_lines(key_basis_2, (), [b"quux\n"]) basis.calls = [] # ask for in non-topological order - records = list(test.get_record_stream( - [key, key_basis, key_missing, key_basis_2], 'topological', False)) + records = list( + test.get_record_stream( + [key, key_basis, key_missing, key_basis_2], "topological", False + ) + ) self.assertEqual(4, len(records)) results = [] for record in records: - self.assertSubset([record.key], - (key_basis, key_missing, key_basis_2, key)) + self.assertSubset([record.key], (key_basis, key_missing, key_basis_2, key)) if record.key == key_missing: self.assertIsInstance(record, AbsentContentFactory) else: - results.append((record.key, record.sha1, record.storage_kind, - record.get_bytes_as(record.storage_kind))) + results.append( + ( + record.key, + record.sha1, + record.storage_kind, + record.get_bytes_as(record.storage_kind), + ) + ) calls = list(basis.calls) order = [record[0] for record in results] self.assertEqual([key_basis_2, key_basis, key], order) @@ -2440,80 +2760,85 @@ def test_get_record_stream_ordered_deltas(self): source = test else: source = basis - record = next(source.get_record_stream([result[0]], 'unordered', - False)) + record = next(source.get_record_stream([result[0]], "unordered", False)) self.assertEqual(record.key, result[0]) self.assertEqual(record.sha1, result[1]) self.assertEqual(record.storage_kind, result[2]) - self.assertEqual(record.get_bytes_as( - record.storage_kind), result[3]) + self.assertEqual(record.get_bytes_as(record.storage_kind), result[3]) # It's not strictly minimal, but it seems reasonable for now for it to # ask which fallbacks have which parents. - self.assertEqual([ - ("get_parent_map", {key_basis, key_basis_2, key_missing}), - ("get_record_stream", [key_basis_2, key_basis], 'topological', False)], - calls) + self.assertEqual( + [ + ("get_parent_map", {key_basis, key_basis_2, key_missing}), + ("get_record_stream", [key_basis_2, key_basis], "topological", False), + ], + calls, + ) def test_get_sha1s(self): # sha1's in the test knit are answered without asking the basis basis, test = self.get_basis_and_test_knit() - key = (b'foo',) - key_basis = (b'bar',) - key_missing = (b'missing',) - test.add_lines(key, (), [b'foo\n']) - key_sha1sum = osutils.sha_string(b'foo\n') + key = (b"foo",) + key_basis = (b"bar",) + key_missing = (b"missing",) + test.add_lines(key, (), [b"foo\n"]) + key_sha1sum = osutils.sha_string(b"foo\n") sha1s = test.get_sha1s([key]) self.assertEqual({key: key_sha1sum}, sha1s) self.assertEqual([], basis.calls) # But texts that are not in the test knit are looked for in the basis # directly (rather than via text reconstruction) so that remote servers # etc don't have to answer with full content. 
- basis.add_lines(key_basis, (), [b'foo\n', b'bar\n']) - basis_sha1sum = osutils.sha_string(b'foo\nbar\n') + basis.add_lines(key_basis, (), [b"foo\n", b"bar\n"]) + basis_sha1sum = osutils.sha_string(b"foo\nbar\n") basis.calls = [] sha1s = test.get_sha1s([key, key_missing, key_basis]) - self.assertEqual({key: key_sha1sum, - key_basis: basis_sha1sum}, sha1s) - self.assertEqual([("get_sha1s", {key_basis, key_missing})], - basis.calls) + self.assertEqual({key: key_sha1sum, key_basis: basis_sha1sum}, sha1s) + self.assertEqual([("get_sha1s", {key_basis, key_missing})], basis.calls) def test_insert_record_stream(self): # records are inserted as normal; insert_record_stream builds on # add_lines, so a smoke test should be all that's needed: - key_basis = (b'bar',) - key_delta = (b'zaphod',) + key_basis = (b"bar",) + key_delta = (b"zaphod",) basis, test = self.get_basis_and_test_knit() - source = self.make_test_knit(name='source') - basis.add_lines(key_basis, (), [b'foo\n']) + source = self.make_test_knit(name="source") + basis.add_lines(key_basis, (), [b"foo\n"]) basis.calls = [] - source.add_lines(key_basis, (), [b'foo\n']) - source.add_lines(key_delta, (key_basis,), [b'bar\n']) - stream = source.get_record_stream([key_delta], 'unordered', False) + source.add_lines(key_basis, (), [b"foo\n"]) + source.add_lines(key_delta, (key_basis,), [b"bar\n"]) + stream = source.get_record_stream([key_delta], "unordered", False) test.insert_record_stream(stream) # XXX: this does somewhat too many calls in making sure of whether it # has to recreate the full text. - self.assertEqual([("get_parent_map", {key_basis}), - ('get_parent_map', {key_basis}), - ('get_record_stream', [key_basis], 'unordered', True)], - basis.calls) - self.assertEqual({key_delta: (key_basis,)}, - test.get_parent_map([key_delta])) - self.assertEqual(b'bar\n', next(test.get_record_stream([key_delta], - 'unordered', True)).get_bytes_as('fulltext')) + self.assertEqual( + [ + ("get_parent_map", {key_basis}), + ("get_parent_map", {key_basis}), + ("get_record_stream", [key_basis], "unordered", True), + ], + basis.calls, + ) + self.assertEqual({key_delta: (key_basis,)}, test.get_parent_map([key_delta])) + self.assertEqual( + b"bar\n", + next(test.get_record_stream([key_delta], "unordered", True)).get_bytes_as( + "fulltext" + ), + ) def test_iter_lines_added_or_present_in_keys(self): # Lines from the basis are returned, and lines for a given key are only # returned once. 
- key1 = (b'foo1',) - key2 = (b'foo2',) + key1 = (b"foo1",) + key2 = (b"foo2",) # all sources are asked for keys: basis, test = self.get_basis_and_test_knit() basis.add_lines(key1, (), [b"foo"]) basis.calls = [] lines = list(test.iter_lines_added_or_present_in_keys([key1])) self.assertEqual([(b"foo\n", key1)], lines) - self.assertEqual([("iter_lines_added_or_present_in_keys", {key1})], - basis.calls) + self.assertEqual([("iter_lines_added_or_present_in_keys", {key1})], basis.calls) # keys in both are not duplicated: test.add_lines(key2, (), [b"bar\n"]) basis.add_lines(key2, (), [b"bar\n"]) @@ -2523,8 +2848,8 @@ def test_iter_lines_added_or_present_in_keys(self): self.assertEqual([], basis.calls) def test_keys(self): - key1 = (b'foo1',) - key2 = (b'foo2',) + key1 = (b"foo1",) + key2 = (b"foo2",) # all sources are asked for keys: basis, test = self.get_basis_and_test_knit() keys = test.keys() @@ -2548,53 +2873,76 @@ def test_keys(self): def test_add_mpdiffs(self): # records are inserted as normal; add_mpdiff builds on # add_lines, so a smoke test should be all that's needed: - key_basis = (b'bar',) - key_delta = (b'zaphod',) + key_basis = (b"bar",) + key_delta = (b"zaphod",) basis, test = self.get_basis_and_test_knit() - source = self.make_test_knit(name='source') - basis.add_lines(key_basis, (), [b'foo\n']) + source = self.make_test_knit(name="source") + basis.add_lines(key_basis, (), [b"foo\n"]) basis.calls = [] - source.add_lines(key_basis, (), [b'foo\n']) - source.add_lines(key_delta, (key_basis,), [b'bar\n']) + source.add_lines(key_basis, (), [b"foo\n"]) + source.add_lines(key_delta, (key_basis,), [b"bar\n"]) diffs = source.make_mpdiffs([key_delta]) - test.add_mpdiffs([(key_delta, (key_basis,), - source.get_sha1s([key_delta])[key_delta], diffs[0])]) - self.assertEqual([("get_parent_map", {key_basis}), - ('get_record_stream', [key_basis], 'unordered', True), ], - basis.calls) - self.assertEqual({key_delta: (key_basis,)}, - test.get_parent_map([key_delta])) - self.assertEqual(b'bar\n', next(test.get_record_stream([key_delta], - 'unordered', True)).get_bytes_as('fulltext')) + test.add_mpdiffs( + [ + ( + key_delta, + (key_basis,), + source.get_sha1s([key_delta])[key_delta], + diffs[0], + ) + ] + ) + self.assertEqual( + [ + ("get_parent_map", {key_basis}), + ("get_record_stream", [key_basis], "unordered", True), + ], + basis.calls, + ) + self.assertEqual({key_delta: (key_basis,)}, test.get_parent_map([key_delta])) + self.assertEqual( + b"bar\n", + next(test.get_record_stream([key_delta], "unordered", True)).get_bytes_as( + "fulltext" + ), + ) def test_make_mpdiffs(self): # Generating an mpdiff across a stacking boundary should detect parent # texts regions. 
- key = (b'foo',) - key_left = (b'bar',) - key_right = (b'zaphod',) + key = (b"foo",) + key_left = (b"bar",) + key_right = (b"zaphod",) basis, test = self.get_basis_and_test_knit() - basis.add_lines(key_left, (), [b'bar\n']) - basis.add_lines(key_right, (), [b'zaphod\n']) + basis.add_lines(key_left, (), [b"bar\n"]) + basis.add_lines(key_right, (), [b"zaphod\n"]) basis.calls = [] - test.add_lines(key, (key_left, key_right), - [b'bar\n', b'foo\n', b'zaphod\n']) + test.add_lines(key, (key_left, key_right), [b"bar\n", b"foo\n", b"zaphod\n"]) diffs = test.make_mpdiffs([key]) - self.assertEqual([ - multiparent.MultiParent([multiparent.ParentText(0, 0, 0, 1), - multiparent.NewText([b'foo\n']), - multiparent.ParentText(1, 0, 2, 1)])], - diffs) + self.assertEqual( + [ + multiparent.MultiParent( + [ + multiparent.ParentText(0, 0, 0, 1), + multiparent.NewText([b"foo\n"]), + multiparent.ParentText(1, 0, 2, 1), + ] + ) + ], + diffs, + ) self.assertEqual(3, len(basis.calls)) - self.assertEqual([ - ("get_parent_map", {key_left, key_right}), - ("get_parent_map", {key_left, key_right}), + self.assertEqual( + [ + ("get_parent_map", {key_left, key_right}), + ("get_parent_map", {key_left, key_right}), ], - basis.calls[:-1]) + basis.calls[:-1], + ) last_call = basis.calls[-1] - self.assertEqual('get_record_stream', last_call[0]) + self.assertEqual("get_record_stream", last_call[0]) self.assertEqual({key_left, key_right}, set(last_call[1])) - self.assertEqual('topological', last_call[2]) + self.assertEqual("topological", last_call[2]) self.assertEqual(True, last_call[3]) @@ -2602,71 +2950,76 @@ class TestNetworkBehaviour(KnitTests): """Tests for getting data out of/into knits over the network.""" def test_include_delta_closure_generates_a_knit_delta_closure(self): - vf = self.make_test_knit(name='test') + vf = self.make_test_knit(name="test") # put in three texts, giving ft, delta, delta - vf.add_lines((b'base',), (), [b'base\n', b'content\n']) - vf.add_lines((b'd1',), ((b'base',),), [b'd1\n']) - vf.add_lines((b'd2',), ((b'd1',),), [b'd2\n']) + vf.add_lines((b"base",), (), [b"base\n", b"content\n"]) + vf.add_lines((b"d1",), ((b"base",),), [b"d1\n"]) + vf.add_lines((b"d2",), ((b"d1",),), [b"d2\n"]) # But heuristics could interfere, so check what happened: - self.assertEqual(['knit-ft-gz', 'knit-delta-gz', 'knit-delta-gz'], - [record.storage_kind for record in - vf.get_record_stream([(b'base',), (b'd1',), (b'd2',)], - 'topological', False)]) + self.assertEqual( + ["knit-ft-gz", "knit-delta-gz", "knit-delta-gz"], + [ + record.storage_kind + for record in vf.get_record_stream( + [(b"base",), (b"d1",), (b"d2",)], "topological", False + ) + ], + ) # generate a stream of just the deltas include_delta_closure=True, # serialise to the network, and check that we get a delta closure on the wire. - stream = vf.get_record_stream( - [(b'd1',), (b'd2',)], 'topological', True) + stream = vf.get_record_stream([(b"d1",), (b"d2",)], "topological", True) netb = [record.get_bytes_as(record.storage_kind) for record in stream] # The first bytes should be a memo from _ContentMapGenerator, and the # second bytes should be empty (because its a API proxy not something # for wire serialisation. 
- self.assertEqual(b'', netb[1]) + self.assertEqual(b"", netb[1]) bytes = netb[0] kind, line_end = network_bytes_to_kind_and_offset(bytes) - self.assertEqual('knit-delta-closure', kind) + self.assertEqual("knit-delta-closure", kind) class TestContentMapGenerator(KnitTests): """Tests for ContentMapGenerator.""" def test_get_record_stream_gives_records(self): - vf = self.make_test_knit(name='test') + vf = self.make_test_knit(name="test") # put in three texts, giving ft, delta, delta - vf.add_lines((b'base',), (), [b'base\n', b'content\n']) - vf.add_lines((b'd1',), ((b'base',),), [b'd1\n']) - vf.add_lines((b'd2',), ((b'd1',),), [b'd2\n']) - keys = [(b'd1',), (b'd2',)] - generator = _VFContentMapGenerator(vf, keys, - global_map=vf.get_parent_map(keys)) + vf.add_lines((b"base",), (), [b"base\n", b"content\n"]) + vf.add_lines((b"d1",), ((b"base",),), [b"d1\n"]) + vf.add_lines((b"d2",), ((b"d1",),), [b"d2\n"]) + keys = [(b"d1",), (b"d2",)] + generator = _VFContentMapGenerator(vf, keys, global_map=vf.get_parent_map(keys)) for record in generator.get_record_stream(): - if record.key == (b'd1',): - self.assertEqual(b'd1\n', record.get_bytes_as('fulltext')) + if record.key == (b"d1",): + self.assertEqual(b"d1\n", record.get_bytes_as("fulltext")) else: - self.assertEqual(b'd2\n', record.get_bytes_as('fulltext')) + self.assertEqual(b"d2\n", record.get_bytes_as("fulltext")) def test_get_record_stream_kinds_are_raw(self): - vf = self.make_test_knit(name='test') + vf = self.make_test_knit(name="test") # put in three texts, giving ft, delta, delta - vf.add_lines((b'base',), (), [b'base\n', b'content\n']) - vf.add_lines((b'd1',), ((b'base',),), [b'd1\n']) - vf.add_lines((b'd2',), ((b'd1',),), [b'd2\n']) - keys = [(b'base',), (b'd1',), (b'd2',)] - generator = _VFContentMapGenerator(vf, keys, - global_map=vf.get_parent_map(keys)) - kinds = {(b'base',): 'knit-delta-closure', - (b'd1',): 'knit-delta-closure-ref', - (b'd2',): 'knit-delta-closure-ref', - } + vf.add_lines((b"base",), (), [b"base\n", b"content\n"]) + vf.add_lines((b"d1",), ((b"base",),), [b"d1\n"]) + vf.add_lines((b"d2",), ((b"d1",),), [b"d2\n"]) + keys = [(b"base",), (b"d1",), (b"d2",)] + generator = _VFContentMapGenerator(vf, keys, global_map=vf.get_parent_map(keys)) + kinds = { + (b"base",): "knit-delta-closure", + (b"d1",): "knit-delta-closure-ref", + (b"d2",): "knit-delta-closure-ref", + } for record in generator.get_record_stream(): self.assertEqual(kinds[record.key], record.storage_kind) class TestErrors(TestCase): - def test_retry_with_new_packs(self): - fake_exc_info = ('{exc type}', '{exc value}', '{exc traceback}') + fake_exc_info = ("{exc type}", "{exc value}", "{exc traceback}") error = pack_repo.RetryWithNewPacks( - '{context}', reload_occurred=False, exc_info=fake_exc_info) + "{context}", reload_occurred=False, exc_info=fake_exc_info + ) self.assertEqual( - 'Pack files have changed, reload and retry. context: ' - '{context} {exc value}', str(error)) + "Pack files have changed, reload and retry. context: " + "{context} {exc value}", + str(error), + ) diff --git a/breezy/bzr/tests/test_lockable_files.py b/breezy/bzr/tests/test_lockable_files.py index 2ba338974a..2fc10580cf 100644 --- a/breezy/bzr/tests/test_lockable_files.py +++ b/breezy/bzr/tests/test_lockable_files.py @@ -34,32 +34,29 @@ # they use an old style of parameterization, but we want to remove this class # so won't modernize them now. 
- mbp 20080430 class _TestLockableFiles_mixin: - def test_transactions(self): - self.assertIs(self.lockable.get_transaction().__class__, - PassThroughTransaction) + self.assertIs(self.lockable.get_transaction().__class__, PassThroughTransaction) self.lockable.lock_read() try: - self.assertIs(self.lockable.get_transaction().__class__, - ReadOnlyTransaction) + self.assertIs( + self.lockable.get_transaction().__class__, ReadOnlyTransaction + ) finally: self.lockable.unlock() - self.assertIs(self.lockable.get_transaction().__class__, - PassThroughTransaction) + self.assertIs(self.lockable.get_transaction().__class__, PassThroughTransaction) self.lockable.lock_write() - self.assertIs(self.lockable.get_transaction().__class__, - WriteTransaction) + self.assertIs(self.lockable.get_transaction().__class__, WriteTransaction) # check that finish is called: - vf = DummyWeave('a') + vf = DummyWeave("a") self.lockable.get_transaction().register_dirty(vf) self.lockable.unlock() self.assertTrue(vf.finished) def test__escape(self): - self.assertEqual('%25', self.lockable._escape('%')) + self.assertEqual("%25", self.lockable._escape("%")) def test__escape_empty(self): - self.assertEqual('', self.lockable._escape('')) + self.assertEqual("", self.lockable._escape("")) def test_break_lock(self): # some locks are not breakable @@ -92,8 +89,9 @@ def test_lock_write_returns_None_refuses_token(self): # This test does not apply, because this lockable supports # tokens. raise TestNotApplicable(f"{self.lockable!r} uses tokens") - self.assertRaises(errors.TokenLockingNotSupported, - self.lockable.lock_write, token='token') + self.assertRaises( + errors.TokenLockingNotSupported, self.lockable.lock_write, token="token" + ) def test_lock_write_returns_token_when_given_token(self): token = self.lockable.lock_write() @@ -114,11 +112,12 @@ def test_lock_write_raises_on_token_mismatch(self): # This test does not apply, because this lockable refuses # tokens. return - different_token = token + b'xxx' + different_token = token + b"xxx" # Re-using the same lockable instance with a different token will # raise TokenMismatch. - self.assertRaises(errors.TokenMismatch, - self.lockable.lock_write, token=different_token) + self.assertRaises( + errors.TokenMismatch, self.lockable.lock_write, token=different_token + ) # A separate instance for the same lockable will also raise # TokenMismatch. # This detects the case where a caller claims to have a lock (via @@ -127,8 +126,9 @@ def test_lock_write_raises_on_token_mismatch(self): # external resource is probed, whereas the existing lock object # might cache. new_lockable = self.get_lockable() - self.assertRaises(errors.TokenMismatch, - new_lockable.lock_write, token=different_token) + self.assertRaises( + errors.TokenMismatch, new_lockable.lock_write, token=different_token + ) def test_lock_write_with_matching_token(self): # If the token matches, so no exception is raised by lock_write. @@ -174,8 +174,7 @@ def test_lock_write_with_token_fails_when_unlocked(self): # tokens. return - self.assertRaises(errors.TokenMismatch, - self.lockable.lock_write, token=token) + self.assertRaises(errors.TokenMismatch, self.lockable.lock_write, token=token) def test_lock_write_reenter_with_token(self): token = self.lockable.lock_write() @@ -263,14 +262,12 @@ def test_dont_leave_in_place(self): # This method of adapting tests to parameters is different to # the TestProviderAdapters used elsewhere, but seems simpler for this # case. 
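The token assertions made by _TestLockableFiles_mixin above reduce to a small contract: lock_write() hands back a token, lock_write(token=...) fails with TokenMismatch when the lock is not held or the token does not match, and a matching token re-enters the lock. Below is a minimal standalone sketch of that contract, assuming a token-supporting lockable such as the LockDir-backed ones; ToyLockable and the b"token-1" value are illustrative stand-ins, not Breezy's LockableFiles.

# Minimal sketch of the token-locking contract exercised by the mixin tests.
# ToyLockable is an illustrative stand-in, not Breezy's LockableFiles.
class TokenMismatch(Exception):
    pass


class ToyLockable:
    def __init__(self):
        self._held_token = None

    def lock_write(self, token=None):
        if self._held_token is None:
            if token is not None:
                # lock_write with a token fails when the lock is not held
                raise TokenMismatch(token)
            self._held_token = b"token-1"  # a real lock mints a unique token
            return self._held_token
        if token != self._held_token:
            # re-locking with a different token is refused
            raise TokenMismatch(token)
        return self._held_token  # a matching token re-enters the lock


lockable = ToyLockable()
token = lockable.lock_write()
assert lockable.lock_write(token=token) == token  # matching token: no exception
try:
    lockable.lock_write(token=token + b"xxx")  # mismatched token
except TokenMismatch:
    pass
else:
    raise AssertionError("expected TokenMismatch")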
-class TestLockableFiles_TransportLock(TestCaseInTempDir, - _TestLockableFiles_mixin): - +class TestLockableFiles_TransportLock(TestCaseInTempDir, _TestLockableFiles_mixin): def setUp(self): super().setUp() - t = transport.get_transport_from_path('.') - t.mkdir('.bzr') - self.sub_transport = t.clone('.bzr') + t = transport.get_transport_from_path(".") + t.mkdir(".bzr") + self.sub_transport = t.clone(".bzr") self.lockable = self.get_lockable() self.lockable.create_lock() @@ -284,16 +281,15 @@ def stop_server(self): pass def get_lockable(self): - return LockableFiles(self.sub_transport, 'my-lock', TransportLock) + return LockableFiles(self.sub_transport, "my-lock", TransportLock) -class TestLockableFiles_LockDir(TestCaseInTempDir, - _TestLockableFiles_mixin): +class TestLockableFiles_LockDir(TestCaseInTempDir, _TestLockableFiles_mixin): """LockableFile tests run with LockDir underneath.""" def setUp(self): super().setUp() - self.transport = transport.get_transport_from_path('.') + self.transport = transport.get_transport_from_path(".") self.lockable = self.get_lockable() # the lock creation here sets mode - test_permissions on branch # tests that implicitly, but it might be a good idea to factor @@ -302,29 +298,31 @@ def setUp(self): self.lockable.create_lock() def get_lockable(self): - return LockableFiles(self.transport, 'my-lock', lockdir.LockDir) + return LockableFiles(self.transport, "my-lock", lockdir.LockDir) def test_lock_created(self): - self.assertTrue(self.transport.has('my-lock')) + self.assertTrue(self.transport.has("my-lock")) self.lockable.lock_write() - self.assertTrue(self.transport.has('my-lock/held/info')) + self.assertTrue(self.transport.has("my-lock/held/info")) self.lockable.unlock() - self.assertFalse(self.transport.has('my-lock/held/info')) - self.assertTrue(self.transport.has('my-lock')) + self.assertFalse(self.transport.has("my-lock/held/info")) + self.assertTrue(self.transport.has("my-lock")) def test__file_modes(self): - self.transport.mkdir('readonly') - osutils.make_readonly('readonly') - lockable = LockableFiles(self.transport.clone('readonly'), 'test-lock', - lockdir.LockDir) + self.transport.mkdir("readonly") + osutils.make_readonly("readonly") + lockable = LockableFiles( + self.transport.clone("readonly"), "test-lock", lockdir.LockDir + ) # The directory mode should be read-write-execute for the current user self.assertEqual(0o0700, lockable._dir_mode & 0o0700) # Files should be read-write for the current user self.assertEqual(0o0600, lockable._file_mode & 0o0700) -class TestLockableFiles_RemoteLockDir(TestCaseWithSmartMedium, - _TestLockableFiles_mixin): +class TestLockableFiles_RemoteLockDir( + TestCaseWithSmartMedium, _TestLockableFiles_mixin +): """LockableFile tests run with RemoteLockDir on a branch.""" def setUp(self): @@ -334,13 +332,13 @@ def setUp(self): # to end behaviour, so stubbing out the backend and simulating would # defeat the purpose. We test the protocol implementation separately # in test_remote and test_smart as usual. 
- b = self.make_branch('foo') + b = self.make_branch("foo") self.addCleanup(b.controldir.transport.disconnect) - self.transport = transport.get_transport_from_path('.') + self.transport = transport.get_transport_from_path(".") self.lockable = self.get_lockable() def get_lockable(self): # getting a new lockable involves opening a new instance of the branch - branch = breezy.branch.Branch.open(self.get_url('foo')) + branch = breezy.branch.Branch.open(self.get_url("foo")) self.addCleanup(branch.controldir.transport.disconnect) return branch.control_files diff --git a/breezy/bzr/tests/test_matchers.py b/breezy/bzr/tests/test_matchers.py index 5b6b1614ea..84d8030d3e 100644 --- a/breezy/bzr/tests/test_matchers.py +++ b/breezy/bzr/tests/test_matchers.py @@ -24,7 +24,6 @@ class TestContainsNoVfsCalls(TestCase): - def _make_call(self, method, args): return CapturedCall(CallHookParams(method, args, None, None, None), 0) @@ -43,11 +42,17 @@ def test_ignores_unknown(self): self.assertIs(None, ContainsNoVfsCalls().match(calls)) def test_match(self): - calls = [self._make_call(b"append", [b"file"]), - self._make_call(b"Branch.get_config_file", [])] + calls = [ + self._make_call(b"append", [b"file"]), + self._make_call(b"Branch.get_config_file", []), + ] mismatch = ContainsNoVfsCalls().match(calls) self.assertIsNot(None, mismatch) self.assertEqual([calls[0].call], mismatch.vfs_calls) - self.assertIn(mismatch.describe(), [ - "no VFS calls expected, got: b'append'(b'file')", - "no VFS calls expected, got: append('file')"]) + self.assertIn( + mismatch.describe(), + [ + "no VFS calls expected, got: b'append'(b'file')", + "no VFS calls expected, got: append('file')", + ], + ) diff --git a/breezy/bzr/tests/test_pack.py b/breezy/bzr/tests/test_pack.py index f84bf2a5b2..8c9825d61c 100644 --- a/breezy/bzr/tests/test_pack.py +++ b/breezy/bzr/tests/test_pack.py @@ -31,47 +31,47 @@ def test_construct(self): def test_begin(self): serialiser = pack.ContainerSerialiser() - self.assertEqual(b'Bazaar pack format 1 (introduced in 0.18)\n', - serialiser.begin()) + self.assertEqual( + b"Bazaar pack format 1 (introduced in 0.18)\n", serialiser.begin() + ) def test_end(self): serialiser = pack.ContainerSerialiser() - self.assertEqual(b'E', serialiser.end()) + self.assertEqual(b"E", serialiser.end()) def test_bytes_record_no_name(self): serialiser = pack.ContainerSerialiser() - record = serialiser.bytes_record(b'bytes', []) - self.assertEqual(b'B5\n\nbytes', record) + record = serialiser.bytes_record(b"bytes", []) + self.assertEqual(b"B5\n\nbytes", record) def test_bytes_record_one_name_with_one_part(self): serialiser = pack.ContainerSerialiser() - record = serialiser.bytes_record(b'bytes', [(b'name',)]) - self.assertEqual(b'B5\nname\n\nbytes', record) + record = serialiser.bytes_record(b"bytes", [(b"name",)]) + self.assertEqual(b"B5\nname\n\nbytes", record) def test_bytes_record_one_name_with_two_parts(self): serialiser = pack.ContainerSerialiser() - record = serialiser.bytes_record(b'bytes', [(b'part1', b'part2')]) - self.assertEqual(b'B5\npart1\x00part2\n\nbytes', record) + record = serialiser.bytes_record(b"bytes", [(b"part1", b"part2")]) + self.assertEqual(b"B5\npart1\x00part2\n\nbytes", record) def test_bytes_record_two_names(self): serialiser = pack.ContainerSerialiser() - record = serialiser.bytes_record(b'bytes', [(b'name1',), (b'name2',)]) - self.assertEqual(b'B5\nname1\nname2\n\nbytes', record) + record = serialiser.bytes_record(b"bytes", [(b"name1",), (b"name2",)]) + self.assertEqual(b"B5\nname1\nname2\n\nbytes", 
record) def test_bytes_record_whitespace_in_name_part(self): serialiser = pack.ContainerSerialiser() self.assertRaises( - pack.InvalidRecordError, - serialiser.bytes_record, b'bytes', [(b'bad name',)]) + pack.InvalidRecordError, serialiser.bytes_record, b"bytes", [(b"bad name",)] + ) def test_bytes_record_header(self): serialiser = pack.ContainerSerialiser() - record = serialiser.bytes_header(32, [(b'name1',), (b'name2',)]) - self.assertEqual(b'B32\nname1\nname2\n\n', record) + record = serialiser.bytes_header(32, [(b"name1",), (b"name2",)]) + self.assertEqual(b"B32\nname1\nname2\n\n", record) class TestContainerWriter(tests.TestCase): - def setUp(self): super().setUp() self.output = BytesIO() @@ -94,7 +94,7 @@ def test_construct(self): def test_begin(self): """The begin() method writes the container format marker line.""" self.writer.begin() - self.assertOutput(b'Bazaar pack format 1 (introduced in 0.18)\n') + self.assertOutput(b"Bazaar pack format 1 (introduced in 0.18)\n") def test_zero_records_written_after_begin(self): """After begin is written, 0 records have been written.""" @@ -105,7 +105,7 @@ def test_end(self): """The end() method writes an End Marker record.""" self.writer.begin() self.writer.end() - self.assertOutput(b'Bazaar pack format 1 (introduced in 0.18)\nE') + self.assertOutput(b"Bazaar pack format 1 (introduced in 0.18)\nE") def test_empty_end_does_not_add_a_record_to_records_written(self): """The end() method does not count towards the records written.""" @@ -116,28 +116,28 @@ def test_empty_end_does_not_add_a_record_to_records_written(self): def test_non_empty_end_does_not_add_a_record_to_records_written(self): """The end() method does not count towards the records written.""" self.writer.begin() - self.writer.add_bytes_record([b'foo'], len(b'foo'), names=[]) + self.writer.add_bytes_record([b"foo"], len(b"foo"), names=[]) self.writer.end() self.assertEqual(1, self.writer.records_written) def test_add_bytes_record_no_name(self): """Add a bytes record with no name.""" self.writer.begin() - offset, length = self.writer.add_bytes_record([b'abc'], len(b'abc'), names=[]) + offset, length = self.writer.add_bytes_record([b"abc"], len(b"abc"), names=[]) self.assertEqual((42, 7), (offset, length)) - self.assertOutput( - b'Bazaar pack format 1 (introduced in 0.18)\nB3\n\nabc') + self.assertOutput(b"Bazaar pack format 1 (introduced in 0.18)\nB3\n\nabc") def test_add_bytes_record_one_name(self): """Add a bytes record with one name.""" self.writer.begin() offset, length = self.writer.add_bytes_record( - [b'abc'], len(b'abc'), names=[(b'name1', )]) + [b"abc"], len(b"abc"), names=[(b"name1",)] + ) self.assertEqual((42, 13), (offset, length)) self.assertOutput( - b'Bazaar pack format 1 (introduced in 0.18)\n' - b'B3\nname1\n\nabc') + b"Bazaar pack format 1 (introduced in 0.18)\n" b"B3\nname1\n\nabc" + ) def test_add_bytes_record_split_writes(self): """Write a large record which does multiple IOs.""" @@ -153,47 +153,52 @@ def record_writes(data): self.writer.begin() offset, length = self.writer.add_bytes_record( - [b'abcabc'], len(b'abcabc'), names=[(b'name1', )]) + [b"abcabc"], len(b"abcabc"), names=[(b"name1",)] + ) self.assertEqual((42, 16), (offset, length)) self.assertOutput( - b'Bazaar pack format 1 (introduced in 0.18)\n' - b'B6\nname1\n\nabcabc') + b"Bazaar pack format 1 (introduced in 0.18)\n" b"B6\nname1\n\nabcabc" + ) - self.assertEqual([ - b'Bazaar pack format 1 (introduced in 0.18)\n', - b'B6\nname1\n\n', - b'abcabc'], - writes) + self.assertEqual( + [ + b"Bazaar pack 
format 1 (introduced in 0.18)\n", + b"B6\nname1\n\n", + b"abcabc", + ], + writes, + ) def test_add_bytes_record_two_names(self): """Add a bytes record with two names.""" self.writer.begin() offset, length = self.writer.add_bytes_record( - [b'abc'], len(b'abc'), names=[(b'name1', ), (b'name2', )]) + [b"abc"], len(b"abc"), names=[(b"name1",), (b"name2",)] + ) self.assertEqual((42, 19), (offset, length)) self.assertOutput( - b'Bazaar pack format 1 (introduced in 0.18)\n' - b'B3\nname1\nname2\n\nabc') + b"Bazaar pack format 1 (introduced in 0.18)\n" b"B3\nname1\nname2\n\nabc" + ) def test_add_bytes_record_two_element_name(self): """Add a bytes record with a two-element name.""" self.writer.begin() offset, length = self.writer.add_bytes_record( - [b'abc'], len(b'abc'), names=[(b'name1', b'name2')]) + [b"abc"], len(b"abc"), names=[(b"name1", b"name2")] + ) self.assertEqual((42, 19), (offset, length)) self.assertOutput( - b'Bazaar pack format 1 (introduced in 0.18)\n' - b'B3\nname1\x00name2\n\nabc') + b"Bazaar pack format 1 (introduced in 0.18)\n" b"B3\nname1\x00name2\n\nabc" + ) def test_add_second_bytes_record_gets_higher_offset(self): self.writer.begin() - self.writer.add_bytes_record([b'a', b'bc'], len(b'abc'), names=[]) - offset, length = self.writer.add_bytes_record([b'abc'], len(b'abc'), names=[]) + self.writer.add_bytes_record([b"a", b"bc"], len(b"abc"), names=[]) + offset, length = self.writer.add_bytes_record([b"abc"], len(b"abc"), names=[]) self.assertEqual((49, 7), (offset, length)) self.assertOutput( - b'Bazaar pack format 1 (introduced in 0.18)\n' - b'B3\n\nabc' - b'B3\n\nabc') + b"Bazaar pack format 1 (introduced in 0.18)\n" b"B3\n\nabc" b"B3\n\nabc" + ) def test_add_bytes_record_invalid_name(self): """Adding a Bytes record with a name with whitespace in it raises @@ -202,14 +207,18 @@ def test_add_bytes_record_invalid_name(self): self.writer.begin() self.assertRaises( pack.InvalidRecordError, - self.writer.add_bytes_record, [b'abc'], len(b'abc'), names=[(b'bad name', )]) + self.writer.add_bytes_record, + [b"abc"], + len(b"abc"), + names=[(b"bad name",)], + ) def test_add_bytes_records_add_to_records_written(self): """Adding a Bytes record increments the records_written counter.""" self.writer.begin() - self.writer.add_bytes_record([b'foo'], len(b'foo'), names=[]) + self.writer.add_bytes_record([b"foo"], len(b"foo"), names=[]) self.assertEqual(1, self.writer.records_written) - self.writer.add_bytes_record([b'foo'], len(b'foo'), names=[]) + self.writer.add_bytes_record([b"foo"], len(b"foo"), names=[]) self.assertEqual(2, self.writer.records_written) @@ -236,33 +245,27 @@ def test_construct(self): def test_empty_container(self): """Read an empty container.""" - reader = self.get_reader_for( - b"Bazaar pack format 1 (introduced in 0.18)\nE") + reader = self.get_reader_for(b"Bazaar pack format 1 (introduced in 0.18)\nE") self.assertEqual([], list(reader.iter_records())) def test_unknown_format(self): """Unrecognised container formats raise UnknownContainerFormatError.""" reader = self.get_reader_for(b"unknown format\n") - self.assertRaises( - pack.UnknownContainerFormatError, reader.iter_records) + self.assertRaises(pack.UnknownContainerFormatError, reader.iter_records) def test_unexpected_end_of_container(self): """Containers that don't end with an End Marker record should cause UnexpectedEndOfContainerError to be raised. 
""" - reader = self.get_reader_for( - b"Bazaar pack format 1 (introduced in 0.18)\n") + reader = self.get_reader_for(b"Bazaar pack format 1 (introduced in 0.18)\n") iterator = reader.iter_records() - self.assertRaises( - pack.UnexpectedEndOfContainerError, next, iterator) + self.assertRaises(pack.UnexpectedEndOfContainerError, next, iterator) def test_unknown_record_type(self): """Unknown record types cause UnknownRecordTypeError to be raised.""" - reader = self.get_reader_for( - b"Bazaar pack format 1 (introduced in 0.18)\nX") + reader = self.get_reader_for(b"Bazaar pack format 1 (introduced in 0.18)\nX") iterator = reader.iter_records() - self.assertRaises( - pack.UnknownRecordTypeError, next, iterator) + self.assertRaises(pack.UnknownRecordTypeError, next, iterator) def test_container_with_one_unnamed_record(self): """Read a container with one Bytes record. @@ -272,24 +275,28 @@ def test_container_with_one_unnamed_record(self): ContainerReader's integration with BytesRecordReader is working. """ reader = self.get_reader_for( - b"Bazaar pack format 1 (introduced in 0.18)\nB5\n\naaaaaE") - expected_records = [([], b'aaaaa')] + b"Bazaar pack format 1 (introduced in 0.18)\nB5\n\naaaaaE" + ) + expected_records = [([], b"aaaaa")] self.assertEqual( expected_records, - [(names, read_bytes(None)) - for (names, read_bytes) in reader.iter_records()]) + [ + (names, read_bytes(None)) + for (names, read_bytes) in reader.iter_records() + ], + ) def test_validate_empty_container(self): """Validate does not raise an error for a container with no records.""" - reader = self.get_reader_for( - b"Bazaar pack format 1 (introduced in 0.18)\nE") + reader = self.get_reader_for(b"Bazaar pack format 1 (introduced in 0.18)\nE") # No exception raised reader.validate() def test_validate_non_empty_valid_container(self): """Validate does not raise an error for a container with a valid record.""" reader = self.get_reader_for( - b"Bazaar pack format 1 (introduced in 0.18)\nB3\nname\n\nabcE") + b"Bazaar pack format 1 (introduced in 0.18)\nB3\nname\n\nabcE" + ) # No exception raised reader.validate() @@ -299,21 +306,19 @@ def test_validate_bad_format(self): It may raise either UnexpectedEndOfContainerError or UnknownContainerFormatError, depending on exactly what the string is. """ - inputs = [ - b"", b"x", b"Bazaar pack format 1 (introduced in 0.18)", b"bad\n"] + inputs = [b"", b"x", b"Bazaar pack format 1 (introduced in 0.18)", b"bad\n"] for input in inputs: reader = self.get_reader_for(input) self.assertRaises( - (pack.UnexpectedEndOfContainerError, - pack.UnknownContainerFormatError), - reader.validate) + (pack.UnexpectedEndOfContainerError, pack.UnknownContainerFormatError), + reader.validate, + ) def test_validate_bad_record_marker(self): """Validate raises UnknownRecordTypeError for unrecognised record types. """ - reader = self.get_reader_for( - b"Bazaar pack format 1 (introduced in 0.18)\nX") + reader = self.get_reader_for(b"Bazaar pack format 1 (introduced in 0.18)\nX") self.assertRaises(pack.UnknownRecordTypeError, reader.validate) def test_validate_data_after_end_marker(self): @@ -321,19 +326,17 @@ def test_validate_data_after_end_marker(self): after the end of the container. 
""" reader = self.get_reader_for( - b"Bazaar pack format 1 (introduced in 0.18)\nEcrud") - self.assertRaises( - pack.ContainerHasExcessDataError, reader.validate) + b"Bazaar pack format 1 (introduced in 0.18)\nEcrud" + ) + self.assertRaises(pack.ContainerHasExcessDataError, reader.validate) def test_validate_no_end_marker(self): """Validate raises UnexpectedEndOfContainerError if there's no end of container marker, even if the container up to this point has been valid. """ - reader = self.get_reader_for( - b"Bazaar pack format 1 (introduced in 0.18)\n") - self.assertRaises( - pack.UnexpectedEndOfContainerError, reader.validate) + reader = self.get_reader_for(b"Bazaar pack format 1 (introduced in 0.18)\n") + self.assertRaises(pack.UnexpectedEndOfContainerError, reader.validate) def test_validate_duplicate_name(self): """Validate raises DuplicateRecordNameError if the same name occurs @@ -343,13 +346,15 @@ def test_validate_duplicate_name(self): b"Bazaar pack format 1 (introduced in 0.18)\n" b"B0\nname\n\n" b"B0\nname\n\n" - b"E") + b"E" + ) self.assertRaises(pack.DuplicateRecordNameError, reader.validate) def test_validate_undecodeable_name(self): """Names that aren't valid UTF-8 cause validate to fail.""" reader = self.get_reader_for( - b"Bazaar pack format 1 (introduced in 0.18)\nB0\n\xcc\n\nE") + b"Bazaar pack format 1 (introduced in 0.18)\nB0\n\xcc\n\nE" + ) self.assertRaises(pack.InvalidRecordError, reader.validate) @@ -374,7 +379,7 @@ def test_record_with_no_name(self): reader = self.get_reader_for(b"5\n\naaaaa") names, get_bytes = reader.read() self.assertEqual([], names) - self.assertEqual(b'aaaaa', get_bytes(None)) + self.assertEqual(b"aaaaa", get_bytes(None)) def test_record_with_one_name(self): """Reading a Bytes record with one name returns a list of just that @@ -382,22 +387,30 @@ def test_record_with_one_name(self): """ reader = self.get_reader_for(b"5\nname1\n\naaaaa") names, get_bytes = reader.read() - self.assertEqual([(b'name1', )], names) - self.assertEqual(b'aaaaa', get_bytes(None)) + self.assertEqual([(b"name1",)], names) + self.assertEqual(b"aaaaa", get_bytes(None)) def test_record_with_two_names(self): """Reading a Bytes record with two names returns a list of both names.""" reader = self.get_reader_for(b"5\nname1\nname2\n\naaaaa") names, get_bytes = reader.read() - self.assertEqual([(b'name1', ), (b'name2', )], names) - self.assertEqual(b'aaaaa', get_bytes(None)) + self.assertEqual([(b"name1",), (b"name2",)], names) + self.assertEqual(b"aaaaa", get_bytes(None)) def test_record_with_two_part_names(self): """Reading a Bytes record with a two_part name reads both.""" reader = self.get_reader_for(b"5\nname1\x00name2\n\naaaaa") names, get_bytes = reader.read() - self.assertEqual([(b'name1', b'name2', )], names) - self.assertEqual(b'aaaaa', get_bytes(None)) + self.assertEqual( + [ + ( + b"name1", + b"name2", + ) + ], + names, + ) + self.assertEqual(b"aaaaa", get_bytes(None)) def test_invalid_length(self): """If the length-prefix is not a number, parsing raises @@ -429,7 +442,8 @@ def test_early_eof(self): pass else: self.fail( - f"UnexpectedEndOfContainerError not raised when parsing {incomplete_record!r}") + f"UnexpectedEndOfContainerError not raised when parsing {incomplete_record!r}" + ) def test_initial_eof(self): """EOF before any bytes read at all.""" @@ -468,20 +482,17 @@ def test_validate_whitespace_in_name(self): def test_validate_interrupted_prelude(self): """EOF during reading a record's prelude causes validate to fail.""" reader = self.get_reader_for(b"") - 
self.assertRaises( - pack.UnexpectedEndOfContainerError, reader.validate) + self.assertRaises(pack.UnexpectedEndOfContainerError, reader.validate) def test_validate_interrupted_body(self): """EOF during reading a record's body causes validate to fail.""" reader = self.get_reader_for(b"1\n\n") - self.assertRaises( - pack.UnexpectedEndOfContainerError, reader.validate) + self.assertRaises(pack.UnexpectedEndOfContainerError, reader.validate) def test_validate_unparseable_length(self): """An unparseable record length causes validate to fail.""" reader = self.get_reader_for(b"\n\n") - self.assertRaises( - pack.InvalidRecordError, reader.validate) + self.assertRaises(pack.InvalidRecordError, reader.validate) def test_validate_undecodeable_name(self): """Names that aren't valid UTF-8 cause validate to fail.""" @@ -494,7 +505,7 @@ def test_read_max_length(self): """ reader = self.get_reader_for(b"6\n\nabcdef") names, get_bytes = reader.read() - self.assertEqual(b'abc', get_bytes(3)) + self.assertEqual(b"abc", get_bytes(3)) def test_read_no_max_length(self): """If the max_length passed to the callable returned by read is None, @@ -502,7 +513,7 @@ def test_read_no_max_length(self): """ reader = self.get_reader_for(b"6\n\nabcdef") names, get_bytes = reader.read() - self.assertEqual(b'abcdef', get_bytes(None)) + self.assertEqual(b"abcdef", get_bytes(None)) def test_repeated_read_calls(self): """Repeated calls to the callable returned from BytesRecordReader.read @@ -510,31 +521,30 @@ def test_repeated_read_calls(self): """ reader = self.get_reader_for(b"6\n\nabcdefB3\nnext-record\nXXX") names, get_bytes = reader.read() - self.assertEqual(b'abcdef', get_bytes(None)) - self.assertEqual(b'', get_bytes(None)) - self.assertEqual(b'', get_bytes(99)) + self.assertEqual(b"abcdef", get_bytes(None)) + self.assertEqual(b"", get_bytes(None)) + self.assertEqual(b"", get_bytes(99)) class TestMakeReadvReader(tests.TestCaseWithTransport): - def test_read_skipping_records(self): pack_data = BytesIO() writer = pack.ContainerWriter(pack_data.write) writer.begin() memos = [] - memos.append(writer.add_bytes_record([b'abc'], 3, names=[])) - memos.append(writer.add_bytes_record([b'def'], 3, names=[(b'name1', )])) - memos.append(writer.add_bytes_record([b'ghi'], 3, names=[(b'name2', )])) - memos.append(writer.add_bytes_record([b'jkl'], 3, names=[])) + memos.append(writer.add_bytes_record([b"abc"], 3, names=[])) + memos.append(writer.add_bytes_record([b"def"], 3, names=[(b"name1",)])) + memos.append(writer.add_bytes_record([b"ghi"], 3, names=[(b"name2",)])) + memos.append(writer.add_bytes_record([b"jkl"], 3, names=[])) writer.end() transport = self.get_transport() - transport.put_bytes('mypack', pack_data.getvalue()) + transport.put_bytes("mypack", pack_data.getvalue()) requested_records = [memos[0], memos[2]] - reader = pack.make_readv_reader(transport, 'mypack', requested_records) + reader = pack.make_readv_reader(transport, "mypack", requested_records) result = [] for names, reader_func in reader.iter_records(): result.append((names, reader_func(None))) - self.assertEqual([([], b'abc'), ([(b'name2', )], b'ghi')], result) + self.assertEqual([([], b"abc"), ([(b"name2",)], b"ghi")], result) class TestReadvFile(tests.TestCaseWithTransport): @@ -548,16 +558,15 @@ class TestReadvFile(tests.TestCaseWithTransport): def test_read_bytes(self): """Test reading of both single bytes and all bytes in a hunk.""" transport = self.get_transport() - transport.put_bytes('sample', b'0123456789') - f = pack.ReadVFile(transport.readv( - 
'sample', [(0, 1), (1, 2), (4, 1), (6, 2)])) + transport.put_bytes("sample", b"0123456789") + f = pack.ReadVFile(transport.readv("sample", [(0, 1), (1, 2), (4, 1), (6, 2)])) results = [] results.append(f.read(1)) results.append(f.read(2)) results.append(f.read(1)) results.append(f.read(1)) results.append(f.read(1)) - self.assertEqual([b'0', b'12', b'4', b'6', b'7'], results) + self.assertEqual([b"0", b"12", b"4", b"6", b"7"], results) def test_readline(self): """Test using readline() as ContainerReader does. @@ -565,24 +574,24 @@ def test_readline(self): This is always within a readv hunk, never across it. """ transport = self.get_transport() - transport.put_bytes('sample', b'0\n2\n4\n') - f = pack.ReadVFile(transport.readv('sample', [(0, 2), (2, 4)])) + transport.put_bytes("sample", b"0\n2\n4\n") + f = pack.ReadVFile(transport.readv("sample", [(0, 2), (2, 4)])) results = [] results.append(f.readline()) results.append(f.readline()) results.append(f.readline()) - self.assertEqual([b'0\n', b'2\n', b'4\n'], results) + self.assertEqual([b"0\n", b"2\n", b"4\n"], results) def test_readline_and_read(self): """Test exercising one byte reads, readline, and then read again.""" transport = self.get_transport() - transport.put_bytes('sample', b'0\n2\n4\n') - f = pack.ReadVFile(transport.readv('sample', [(0, 6)])) + transport.put_bytes("sample", b"0\n2\n4\n") + f = pack.ReadVFile(transport.readv("sample", [(0, 6)])) results = [] results.append(f.read(1)) results.append(f.readline()) results.append(f.read(4)) - self.assertEqual([b'0', b'\n', b'2\n4\n'], results) + self.assertEqual([b"0", b"\n", b"2\n4\n"], results) class PushParserTestCase(tests.TestCase): @@ -630,8 +639,9 @@ def test_multiple_records_at_once(self): parser = self.make_parser_expecting_record_type() parser.accept_bytes(b"B5\nname1\n\nbody1B5\nname2\n\nbody2") self.assertEqual( - [([(b'name1',)], b'body1'), ([(b'name2',)], b'body2')], - parser.read_pending_records()) + [([(b"name1",)], b"body1"), ([(b"name2",)], b"body2")], + parser.read_pending_records(), + ) def test_multiple_empty_records_at_once(self): """If multiple empty records worth of data are fed to the parser in one @@ -643,8 +653,8 @@ def test_multiple_empty_records_at_once(self): parser = self.make_parser_expecting_record_type() parser.accept_bytes(b"B0\nname1\n\nB0\nname2\n\n") self.assertEqual( - [([(b'name1',)], b''), ([(b'name2',)], b'')], - parser.read_pending_records()) + [([(b"name1",)], b""), ([(b"name2",)], b"")], parser.read_pending_records() + ) class TestContainerPushParserBytesParsing(PushParserTestCase): @@ -659,27 +669,25 @@ def test_record_with_no_name(self): """Reading a Bytes record with no name returns an empty list of names. """ - self.assertRecordParsing(([], b'aaaaa'), b"5\n\naaaaa") + self.assertRecordParsing(([], b"aaaaa"), b"5\n\naaaaa") def test_record_with_one_name(self): """Reading a Bytes record with one name returns a list of just that name. 
""" - self.assertRecordParsing( - ([(b'name1', )], b'aaaaa'), - b"5\nname1\n\naaaaa") + self.assertRecordParsing(([(b"name1",)], b"aaaaa"), b"5\nname1\n\naaaaa") def test_record_with_two_names(self): """Reading a Bytes record with two names returns a list of both names.""" self.assertRecordParsing( - ([(b'name1', ), (b'name2', )], b'aaaaa'), - b"5\nname1\nname2\n\naaaaa") + ([(b"name1",), (b"name2",)], b"aaaaa"), b"5\nname1\nname2\n\naaaaa" + ) def test_record_with_two_part_names(self): """Reading a Bytes record with a two_part name reads both.""" self.assertRecordParsing( - ([(b'name1', b'name2')], b'aaaaa'), - b"5\nname1\x00name2\n\naaaaa") + ([(b"name1", b"name2")], b"aaaaa"), b"5\nname1\x00name2\n\naaaaa" + ) def test_invalid_length(self): """If the length-prefix is not a number, parsing raises @@ -687,7 +695,8 @@ def test_invalid_length(self): """ parser = self.make_parser_expecting_bytes_record() self.assertRaises( - pack.InvalidRecordError, parser.accept_bytes, b"not a number\n") + pack.InvalidRecordError, parser.accept_bytes, b"not a number\n" + ) def test_incomplete_record(self): """If the bytes seen so far don't form a complete record, then there @@ -705,8 +714,7 @@ def test_accept_nothing(self): def assertInvalidRecord(self, data): """Assert that parsing the given bytes raises InvalidRecordError.""" parser = self.make_parser_expecting_bytes_record() - self.assertRaises( - pack.InvalidRecordError, parser.accept_bytes, data) + self.assertRaises(pack.InvalidRecordError, parser.accept_bytes, data) def test_read_invalid_name_whitespace(self): """Names must have no whitespace.""" @@ -723,49 +731,39 @@ def test_repeated_read_pending_records(self): """read_pending_records will not return the same record twice.""" parser = self.make_parser_expecting_bytes_record() parser.accept_bytes(b"6\n\nabcdef") - self.assertEqual([([], b'abcdef')], parser.read_pending_records()) + self.assertEqual([([], b"abcdef")], parser.read_pending_records()) self.assertEqual([], parser.read_pending_records()) class TestErrors(tests.TestCase): - def test_unknown_container_format(self): """Test the formatting of UnknownContainerFormatError.""" - e = pack.UnknownContainerFormatError('bad format string') - self.assertEqual( - "Unrecognised container format: 'bad format string'", - str(e)) + e = pack.UnknownContainerFormatError("bad format string") + self.assertEqual("Unrecognised container format: 'bad format string'", str(e)) def test_unexpected_end_of_container(self): """Test the formatting of UnexpectedEndOfContainerError.""" e = pack.UnexpectedEndOfContainerError() - self.assertEqual( - "Unexpected end of container stream", str(e)) + self.assertEqual("Unexpected end of container stream", str(e)) def test_unknown_record_type(self): """Test the formatting of UnknownRecordTypeError.""" e = pack.UnknownRecordTypeError("X") - self.assertEqual( - "Unknown record type: 'X'", - str(e)) + self.assertEqual("Unknown record type: 'X'", str(e)) def test_invalid_record(self): """Test the formatting of InvalidRecordError.""" e = pack.InvalidRecordError("xxx") - self.assertEqual( - "Invalid record: xxx", - str(e)) + self.assertEqual("Invalid record: xxx", str(e)) def test_container_has_excess_data(self): """Test the formatting of ContainerHasExcessDataError.""" e = pack.ContainerHasExcessDataError("excess bytes") - self.assertEqual( - "Container has data after end marker: 'excess bytes'", - str(e)) + self.assertEqual("Container has data after end marker: 'excess bytes'", str(e)) def test_duplicate_record_name_error(self): 
"""Test the formatting of DuplicateRecordNameError.""" e = pack.DuplicateRecordNameError(b"n\xc3\xa5me") self.assertEqual( - "Container has multiple records with the same name: n\xe5me", - str(e)) + "Container has multiple records with the same name: n\xe5me", str(e) + ) diff --git a/breezy/bzr/tests/test_read_bundle.py b/breezy/bzr/tests/test_read_bundle.py index 5ec8aaca16..1224786728 100644 --- a/breezy/bzr/tests/test_read_bundle.py +++ b/breezy/bzr/tests/test_read_bundle.py @@ -33,20 +33,20 @@ def create_bundle_file(test_case): - test_case.build_tree(['tree/', 'tree/a', 'tree/subdir/']) + test_case.build_tree(["tree/", "tree/a", "tree/subdir/"]) format = breezy.bzr.bzrdir.BzrDirFormat.get_default_format() - bzrdir = format.initialize('tree') + bzrdir = format.initialize("tree") repo = bzrdir.create_repository() branch = repo.controldir.create_branch() wt = branch.controldir.create_workingtree() - wt.add(['a', 'subdir/']) - wt.commit('new project', rev_id=b'commit-1') + wt.add(["a", "subdir/"]) + wt.commit("new project", rev_id=b"commit-1") out = BytesIO() - write_bundle(wt.branch.repository, wt.get_parent_ids()[0], b'null:', out) + write_bundle(wt.branch.repository, wt.get_parent_ids()[0], b"null:", out) out.seek(0) return out, wt @@ -58,20 +58,21 @@ class TestReadMergeableBundleFromURL(TestTransportImplementation): def setUp(self): super().setUp() - self.bundle_name = 'test_bundle' + self.bundle_name = "test_bundle" # read_mergeable_from_url will invoke get_transport which may *not* # respect self._transport (i.e. returns a transport that is different # from the one we want to test, so we must inject a correct transport # into possible_transports first). self.possible_transports = [self.get_transport(self.bundle_name)] - self.overrideEnv('BRZ_NO_SMART_VFS', None) + self.overrideEnv("BRZ_NO_SMART_VFS", None) self.create_test_bundle() def read_mergeable_from_url(self, url): return breezy.mergeable.read_mergeable_from_url( - url, possible_transports=self.possible_transports) + url, possible_transports=self.possible_transports + ) - def get_url(self, relpath=''): + def get_url(self, relpath=""): return breezy.urlutils.join(self._server.get_url(), relpath) def create_test_bundle(self): @@ -80,30 +81,31 @@ def create_test_bundle(self): self.build_tree_contents([(self.bundle_name, out.getvalue())]) else: self.get_transport().put_file(self.bundle_name, out) - self.log('Put to: %s', self.get_url(self.bundle_name)) + self.log("Put to: %s", self.get_url(self.bundle_name)) return wt def test_read_mergeable_from_url(self): - info = self.read_mergeable_from_url( - str(self.get_url(self.bundle_name))) + info = self.read_mergeable_from_url(str(self.get_url(self.bundle_name))) revision = info.real_revisions[-1] - self.assertEqual(b'commit-1', revision.revision_id) + self.assertEqual(b"commit-1", revision.revision_id) def test_read_fail(self): # Trying to read from a directory, or non-bundle file # should fail with NotABundle - self.assertRaises(errors.NotABundle, - self.read_mergeable_from_url, self.get_url('tree')) - self.assertRaises(errors.NotABundle, - self.read_mergeable_from_url, self.get_url('tree/a')) + self.assertRaises( + errors.NotABundle, self.read_mergeable_from_url, self.get_url("tree") + ) + self.assertRaises( + errors.NotABundle, self.read_mergeable_from_url, self.get_url("tree/a") + ) def test_read_mergeable_respects_possible_transports(self): - if not isinstance(self.get_transport(self.bundle_name), - breezy.transport.ConnectedTransport): + if not isinstance( + 
self.get_transport(self.bundle_name), breezy.transport.ConnectedTransport + ): # There is no point testing transport reuse for not connected # transports (the test will fail even). - raise tests.TestSkipped( - 'Need a ConnectedTransport to test transport reuse') + raise tests.TestSkipped("Need a ConnectedTransport to test transport reuse") url = str(self.get_url(self.bundle_name)) self.read_mergeable_from_url(url) self.assertEqual(1, len(self.possible_transports)) diff --git a/breezy/bzr/tests/test_remote.py b/breezy/bzr/tests/test_remote.py index eb9f5f6ef6..5d2b918d41 100644 --- a/breezy/bzr/tests/test_remote.py +++ b/breezy/bzr/tests/test_remote.py @@ -73,19 +73,19 @@ class BasicRemoteObjectTests(tests.TestCaseWithTransport): - scenarios = [ - ('HPSS-v2', - {'transport_server': - test_server.SmartTCPServer_for_testing_v2_only}), - ('HPSS-v3', - {'transport_server': test_server.SmartTCPServer_for_testing})] + ( + "HPSS-v2", + {"transport_server": test_server.SmartTCPServer_for_testing_v2_only}, + ), + ("HPSS-v3", {"transport_server": test_server.SmartTCPServer_for_testing}), + ] def setUp(self): super().setUp() self.transport = self.get_transport() # make a branch that can be opened over the smart transport - self.local_wt = BzrDir.create_standalone_workingtree('.') + self.local_wt = BzrDir.create_standalone_workingtree(".") self.addCleanup(self.transport.disconnect) def test_create_remote_bzrdir(self): @@ -101,9 +101,9 @@ def test_open_remote_branch(self): def test_remote_repository(self): b = BzrDir.open_from_transport(self.transport) repo = b.open_repository() - revid = '\xc823123123'.encode() + revid = "\xc823123123".encode() self.assertFalse(repo.has_revision(revid)) - self.local_wt.commit(message='test commit', rev_id=revid) + self.local_wt.commit(message="test commit", rev_id=revid) self.assertTrue(repo.has_revision(revid)) def test_find_correct_format(self): @@ -119,51 +119,49 @@ def test_open_detected_smart_format(self): def test_remote_branch_repr(self): b = BzrDir.open_from_transport(self.transport).open_branch() - self.assertStartsWith(str(b), 'RemoteBranch(') + self.assertStartsWith(str(b), "RemoteBranch(") def test_remote_bzrdir_repr(self): b = BzrDir.open_from_transport(self.transport) - self.assertStartsWith(str(b), 'RemoteBzrDir(') + self.assertStartsWith(str(b), "RemoteBzrDir(") def test_remote_branch_format_supports_stacking(self): t = self.transport - self.make_branch('unstackable', format='pack-0.92') - b = BzrDir.open_from_transport(t.clone('unstackable')).open_branch() + self.make_branch("unstackable", format="pack-0.92") + b = BzrDir.open_from_transport(t.clone("unstackable")).open_branch() self.assertFalse(b._format.supports_stacking()) - self.make_branch('stackable', format='1.9') - b = BzrDir.open_from_transport(t.clone('stackable')).open_branch() + self.make_branch("stackable", format="1.9") + b = BzrDir.open_from_transport(t.clone("stackable")).open_branch() self.assertTrue(b._format.supports_stacking()) def test_remote_repo_format_supports_external_references(self): t = self.transport - bd = self.make_controldir('unstackable', format='pack-0.92') + bd = self.make_controldir("unstackable", format="pack-0.92") r = bd.create_repository() self.assertFalse(r._format.supports_external_lookups) - r = BzrDir.open_from_transport( - t.clone('unstackable')).open_repository() + r = BzrDir.open_from_transport(t.clone("unstackable")).open_repository() self.assertFalse(r._format.supports_external_lookups) - bd = self.make_controldir('stackable', format='1.9') + bd = 
self.make_controldir("stackable", format="1.9") r = bd.create_repository() self.assertTrue(r._format.supports_external_lookups) - r = BzrDir.open_from_transport(t.clone('stackable')).open_repository() + r = BzrDir.open_from_transport(t.clone("stackable")).open_repository() self.assertTrue(r._format.supports_external_lookups) def test_remote_branch_set_append_revisions_only(self): # Make a format 1.9 branch, which supports append_revisions_only - branch = self.make_branch('branch', format='1.9') + branch = self.make_branch("branch", format="1.9") branch.set_append_revisions_only(True) config = branch.get_config_stack() - self.assertEqual( - True, config.get('append_revisions_only')) + self.assertEqual(True, config.get("append_revisions_only")) branch.set_append_revisions_only(False) config = branch.get_config_stack() - self.assertEqual( - False, config.get('append_revisions_only')) + self.assertEqual(False, config.get("append_revisions_only")) def test_remote_branch_set_append_revisions_only_upgrade_reqd(self): - branch = self.make_branch('branch', format='knit') + branch = self.make_branch("branch", format="knit") self.assertRaises( - errors.UpgradeRequired, branch.set_append_revisions_only, True) + errors.UpgradeRequired, branch.set_append_revisions_only, True + ) class FakeProtocol: @@ -192,7 +190,7 @@ def read_streamed_body(self): class FakeClient(_SmartClient): """Lookalike for _SmartClient allowing testing.""" - def __init__(self, fake_medium_base='fake base'): + def __init__(self, fake_medium_base="fake base"): """Create a FakeClient.""" self.responses = [] self._calls = [] @@ -204,39 +202,42 @@ def __init__(self, fake_medium_base='fake base'): self._expected_calls = None _SmartClient.__init__(self, FakeMedium(self._calls, fake_medium_base)) - def add_expected_call(self, call_name, call_args, response_type, - response_args, response_body=None): + def add_expected_call( + self, call_name, call_args, response_type, response_args, response_body=None + ): if self._expected_calls is None: self._expected_calls = [] self._expected_calls.append((call_name, call_args)) self.responses.append((response_type, response_args, response_body)) def add_success_response(self, *args): - self.responses.append((b'success', args, None)) + self.responses.append((b"success", args, None)) def add_success_response_with_body(self, body, *args): - self.responses.append((b'success', args, body)) + self.responses.append((b"success", args, body)) if self._expected_calls is not None: self._expected_calls.append(None) def add_error_response(self, *args): - self.responses.append((b'error', args)) + self.responses.append((b"error", args)) def add_unknown_method_response(self, verb): - self.responses.append((b'unknown', verb)) + self.responses.append((b"unknown", verb)) def finished_test(self): if self._expected_calls: - raise AssertionError(f"{self!r} finished but was still expecting {self._expected_calls[0]!r}") + raise AssertionError( + f"{self!r} finished but was still expecting {self._expected_calls[0]!r}" + ) def _get_next_response(self): try: response_tuple = self.responses.pop(0) except IndexError as e: raise AssertionError(f"{self!r} didn't expect any more calls") from e - if response_tuple[0] == b'unknown': + if response_tuple[0] == b"unknown": raise errors.UnknownSmartMethod(response_tuple[1]) - elif response_tuple[0] == b'error': + elif response_tuple[0] == b"error": raise errors.ErrorFromSmartServer(response_tuple[1]) return response_tuple @@ -247,36 +248,37 @@ def _check_call(self, method, args): try: 
next_call = self._expected_calls.pop(0) except IndexError as e: - raise AssertionError(f"{self!r} didn't expect any more calls " - f"but got {method!r}{args!r}") from e + raise AssertionError( + f"{self!r} didn't expect any more calls " f"but got {method!r}{args!r}" + ) from e if next_call is None: return if method != next_call[0] or args != next_call[1]: raise AssertionError( - f"{self!r} expected {next_call[0]!r}{next_call[1]!r} but got {method!r}{args!r}") + f"{self!r} expected {next_call[0]!r}{next_call[1]!r} but got {method!r}{args!r}" + ) def call(self, method, *args): self._check_call(method, args) - self._calls.append(('call', method, args)) + self._calls.append(("call", method, args)) return self._get_next_response()[1] def call_expecting_body(self, method, *args): self._check_call(method, args) - self._calls.append(('call_expecting_body', method, args)) + self._calls.append(("call_expecting_body", method, args)) result = self._get_next_response() self.expecting_body = True return result[1], FakeProtocol(result[2], self) def call_with_body_bytes(self, method, args, body): self._check_call(method, args) - self._calls.append(('call_with_body_bytes', method, args, body)) + self._calls.append(("call_with_body_bytes", method, args, body)) result = self._get_next_response() return result[1], FakeProtocol(result[2], self) def call_with_body_bytes_expecting_body(self, method, args, body): self._check_call(method, args) - self._calls.append(('call_with_body_bytes_expecting_body', method, - args, body)) + self._calls.append(("call_with_body_bytes_expecting_body", method, args, body)) result = self._get_next_response() self.expecting_body = True return result[1], FakeProtocol(result[2], self) @@ -286,8 +288,7 @@ def call_with_body_stream(self, args, stream): # that's what happens a real medium. stream = list(stream) self._check_call(args[0], args[1:]) - self._calls.append( - ('call_with_body_stream', args[0], args[1:], stream)) + self._calls.append(("call_with_body_stream", args[0], args[1:], stream)) result = self._get_next_response() # The second value returned from call_with_body_stream is supposed to # be a response_handler object, but so far no tests depend on that. 
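FakeClient above scripts the smart protocol for the remote tests: each add_expected_call() queues a (method, args) pair plus a canned response, the call* methods pop and check them in order, and finished_test() fails if anything scripted was never requested. A reduced standalone sketch of that pattern follows; ScriptedClient is an illustrative stand-in with a simplified signature, not FakeClient itself.

# Reduced sketch of the expected-call scripting used by FakeClient above.
# ScriptedClient is an illustrative stand-in, not Breezy's FakeClient.
class ScriptedClient:
    def __init__(self):
        self._expected = []   # (method, args) pairs, in call order
        self._responses = []  # canned response tuples, popped alongside

    def add_expected_call(self, method, args, response):
        self._expected.append((method, args))
        self._responses.append(response)

    def call(self, method, *args):
        expected = self._expected.pop(0)
        if (method, args) != expected:
            raise AssertionError(f"expected {expected!r}, got {(method, args)!r}")
        return self._responses.pop(0)

    def finished(self):
        # mirrors finished_test()/assertFinished(): every scripted call was made
        assert not self._expected, f"still expecting {self._expected[0]!r}"


client = ScriptedClient()
client.add_expected_call(b"BzrDir.has_workingtree", (b"quack/",), (b"yes",))
assert client.call(b"BzrDir.has_workingtree", b"quack/") == (b"yes",)
client.finished()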
@@ -296,37 +297,34 @@ def call_with_body_stream(self, args, stream): class FakeMedium(medium.SmartClientMedium): - def __init__(self, client_calls, base): medium.SmartClientMedium.__init__(self, base) self._client_calls = client_calls def disconnect(self): - self._client_calls.append(('disconnect medium',)) + self._client_calls.append(("disconnect medium",)) class TestVfsHas(tests.TestCase): - def test_unicode_path(self): - client = FakeClient('/') - client.add_success_response(b'yes',) - transport = RemoteTransport('bzr://localhost/', _client=client) - filename = '/hell\u00d8' + client = FakeClient("/") + client.add_success_response( + b"yes", + ) + transport = RemoteTransport("bzr://localhost/", _client=client) + filename = "/hell\u00d8" result = transport.has(filename) - self.assertEqual( - [('call', b'has', (filename.encode('utf-8'),))], - client._calls) + self.assertEqual([("call", b"has", (filename.encode("utf-8"),))], client._calls) self.assertTrue(result) class TestRemote(tests.TestCaseWithMemoryTransport): - def get_branch_format(self): - reference_bzrdir_format = controldir.format_registry.get('default')() + reference_bzrdir_format = controldir.format_registry.get("default")() return reference_bzrdir_format.get_branch_format() def get_repo_format(self): - reference_bzrdir_format = controldir.format_registry.get('default')() + reference_bzrdir_format = controldir.format_registry.get("default")() return reference_bzrdir_format.repository_format def assertFinished(self, fake_client): @@ -351,9 +349,8 @@ def test_remote_path_from_transport(self): """SmartClientMedium.remote_path_from_transport calculates a URL for the given transport relative to the root of the client base URL. """ - self.assertRemotePath('xyz/', 'bzr://host/path', 'bzr://host/xyz') - self.assertRemotePath( - 'path/xyz/', 'bzr://host/path', 'bzr://host/path/xyz') + self.assertRemotePath("xyz/", "bzr://host/path", "bzr://host/xyz") + self.assertRemotePath("path/xyz/", "bzr://host/path", "bzr://host/path/xyz") def assertRemotePathHTTP(self, expected, transport_base, relpath): """Assert that the result of @@ -372,11 +369,9 @@ def test_remote_path_from_transport_http(self): transports. They are just relative to the client base, not the root directory of the host. """ - for scheme in ['http:', 'https:', 'bzr+http:', 'bzr+https:']: - self.assertRemotePathHTTP( - '../xyz/', scheme + '//host/path', '../xyz/') - self.assertRemotePathHTTP( - 'xyz/', scheme + '//host/path', 'xyz/') + for scheme in ["http:", "https:", "bzr+http:", "bzr+https:"]: + self.assertRemotePathHTTP("../xyz/", scheme + "//host/path", "../xyz/") + self.assertRemotePathHTTP("xyz/", scheme + "//host/path", "xyz/") class Test_ClientMedium_remote_is_at_least(tests.TestCase): @@ -386,14 +381,14 @@ def test_initially_unlimited(self): """A fresh medium assumes that the remote side supports all versions. """ - client_medium = medium.SmartClientMedium('dummy base') + client_medium = medium.SmartClientMedium("dummy base") self.assertFalse(client_medium._is_remote_before((99, 99))) def test__remember_remote_is_before(self): """Calling _remember_remote_is_before ratchets down the known remote version. """ - client_medium = medium.SmartClientMedium('dummy base') + client_medium = medium.SmartClientMedium("dummy base") # Mark the remote side as being less than 1.6. The remote side may # still be 1.5. 
client_medium._remember_remote_is_before((1, 6)) @@ -404,58 +399,68 @@ def test__remember_remote_is_before(self): self.assertTrue(client_medium._is_remote_before((1, 5))) # If you call _remember_remote_is_before with a higher value it logs a # warning, and continues to remember the lower value. - self.assertNotContainsRe(self.get_log(), '_remember_remote_is_before') + self.assertNotContainsRe(self.get_log(), "_remember_remote_is_before") client_medium._remember_remote_is_before((1, 9)) - self.assertContainsRe(self.get_log(), '_remember_remote_is_before') + self.assertContainsRe(self.get_log(), "_remember_remote_is_before") self.assertTrue(client_medium._is_remote_before((1, 5))) class TestBzrDirCloningMetaDir(TestRemote): - def test_backwards_compat(self): self.setup_smart_server_with_call_log() - a_dir = self.make_controldir('.') + a_dir = self.make_controldir(".") self.reset_smart_call_log() - verb = b'BzrDir.cloning_metadir' + verb = b"BzrDir.cloning_metadir" self.disable_verb(verb) a_dir.cloning_metadir() - call_count = len([call for call in self.hpss_calls if - call.call.method == verb]) + call_count = len([call for call in self.hpss_calls if call.call.method == verb]) self.assertEqual(1, call_count) def test_branch_reference(self): - transport = self.get_transport('quack') - referenced = self.make_branch('referenced') + transport = self.get_transport("quack") + referenced = self.make_branch("referenced") expected = referenced.controldir.cloning_metadir() client = FakeClient(transport.base) - client.add_expected_call( - b'BzrDir.cloning_metadir', (b'quack/', b'False'), - b'error', (b'BranchReference',)), - client.add_expected_call( - b'BzrDir.open_branchV3', (b'quack/',), - b'success', (b'ref', self.get_url('referenced').encode('utf-8'))), - a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client) + ( + client.add_expected_call( + b"BzrDir.cloning_metadir", + (b"quack/", b"False"), + b"error", + (b"BranchReference",), + ), + ) + ( + client.add_expected_call( + b"BzrDir.open_branchV3", + (b"quack/",), + b"success", + (b"ref", self.get_url("referenced").encode("utf-8")), + ), + ) + a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), _client=client) result = a_controldir.cloning_metadir() # We should have got a control dir matching the referenced branch. 
self.assertEqual(bzrdir.BzrDirMetaFormat1, type(result)) - self.assertEqual(expected._repository_format, - result._repository_format) + self.assertEqual(expected._repository_format, result._repository_format) self.assertEqual(expected._branch_format, result._branch_format) self.assertFinished(client) def test_current_server(self): - transport = self.get_transport('.') - transport = transport.clone('quack') - self.make_controldir('quack') + transport = self.get_transport(".") + transport = transport.clone("quack") + self.make_controldir("quack") client = FakeClient(transport.base) - reference_bzrdir_format = controldir.format_registry.get('default')() + reference_bzrdir_format = controldir.format_registry.get("default")() control_name = reference_bzrdir_format.network_name() - client.add_expected_call( - b'BzrDir.cloning_metadir', (b'quack/', b'False'), - b'success', (control_name, b'', (b'branch', b''))), - a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client) + ( + client.add_expected_call( + b"BzrDir.cloning_metadir", + (b"quack/", b"False"), + b"success", + (control_name, b"", (b"branch", b"")), + ), + ) + a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), _client=client) result = a_controldir.cloning_metadir() # We should have got a reference control dir with default branch and # repository formats. @@ -466,33 +471,37 @@ def test_current_server(self): self.assertFinished(client) def test_unknown(self): - transport = self.get_transport('quack') - referenced = self.make_branch('referenced') + transport = self.get_transport("quack") + referenced = self.make_branch("referenced") referenced.controldir.cloning_metadir() client = FakeClient(transport.base) - client.add_expected_call( - b'BzrDir.cloning_metadir', (b'quack/', b'False'), - b'success', (b'unknown', b'unknown', (b'branch', b''))), - a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client) - self.assertRaises(errors.UnknownFormatError, - a_controldir.cloning_metadir) + ( + client.add_expected_call( + b"BzrDir.cloning_metadir", + (b"quack/", b"False"), + b"success", + (b"unknown", b"unknown", (b"branch", b"")), + ), + ) + a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), _client=client) + self.assertRaises(errors.UnknownFormatError, a_controldir.cloning_metadir) class TestBzrDirCheckoutMetaDir(TestRemote): - def test__get_checkout_format(self): transport = MemoryTransport() client = FakeClient(transport.base) - reference_bzrdir_format = controldir.format_registry.get('default')() + reference_bzrdir_format = controldir.format_registry.get("default")() control_name = reference_bzrdir_format.network_name() client.add_expected_call( - b'BzrDir.checkout_metadir', (b'quack/', ), - b'success', (control_name, b'', b'')) - transport.mkdir('quack') - transport = transport.clone('quack') - a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client) + b"BzrDir.checkout_metadir", + (b"quack/",), + b"success", + (control_name, b"", b""), + ) + transport.mkdir("quack") + transport = transport.clone("quack") + a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), _client=client) result = a_controldir.checkout_metadir() # We should have got a reference control dir with default branch and # repository formats. 
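A little earlier, Test_ClientMedium_remote_is_at_least describes how the client medium ratchets down its estimate of the server version. The standalone sketch below models that ratchet under the assumption, noted in the comments, that _is_remote_before means "known to be strictly older than"; it is not SmartClientMedium itself.

# Sketch of the remote-version ratchet described by
# Test_ClientMedium_remote_is_at_least; VersionRatchet is a stand-in, and the
# "strictly older than" reading of _is_remote_before is an assumption here.
class VersionRatchet:
    def __init__(self):
        self._before = None  # None: assume the remote supports everything

    def is_remote_before(self, version):
        # True only when the remote is known to be strictly older than `version`
        return self._before is not None and self._before <= version

    def remember_remote_is_before(self, version):
        # the bound only ever moves down; Breezy logs a warning and keeps the
        # lower value if a higher one is offered
        if self._before is None or version < self._before:
            self._before = version


client_medium = VersionRatchet()
assert not client_medium.is_remote_before((99, 99))  # fresh medium: assume all works
client_medium.remember_remote_is_before((1, 6))
assert client_medium.is_remote_before((1, 6))        # known older than 1.6...
assert not client_medium.is_remote_before((1, 5))    # ...but it may still be 1.5
client_medium.remember_remote_is_before((1, 9))      # higher value: keep lower bound
assert not client_medium.is_remote_before((1, 5))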
@@ -505,108 +514,122 @@ def test_unknown_format(self): transport = MemoryTransport() client = FakeClient(transport.base) client.add_expected_call( - b'BzrDir.checkout_metadir', (b'quack/',), - b'success', (b'dontknow', b'', b'')) - transport.mkdir('quack') - transport = transport.clone('quack') - a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client) - self.assertRaises(errors.UnknownFormatError, - a_controldir.checkout_metadir) + b"BzrDir.checkout_metadir", + (b"quack/",), + b"success", + (b"dontknow", b"", b""), + ) + transport.mkdir("quack") + transport = transport.clone("quack") + a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), _client=client) + self.assertRaises(errors.UnknownFormatError, a_controldir.checkout_metadir) self.assertFinished(client) class TestBzrDirGetBranches(TestRemote): - def test_get_branches(self): transport = MemoryTransport() client = FakeClient(transport.base) - reference_bzrdir_format = controldir.format_registry.get('default')() + reference_bzrdir_format = controldir.format_registry.get("default")() branch_name = reference_bzrdir_format.get_branch_format().network_name() client.add_success_response_with_body( - bencode.bencode({ - b"foo": (b"branch", branch_name), - b"": (b"branch", branch_name)}), b"success") + bencode.bencode( + {b"foo": (b"branch", branch_name), b"": (b"branch", branch_name)} + ), + b"success", + ) client.add_success_response( - b'ok', b'', b'no', b'no', b'no', - reference_bzrdir_format.repository_format.network_name()) - client.add_error_response(b'NotStacked') + b"ok", + b"", + b"no", + b"no", + b"no", + reference_bzrdir_format.repository_format.network_name(), + ) + client.add_error_response(b"NotStacked") client.add_success_response( - b'ok', b'', b'no', b'no', b'no', - reference_bzrdir_format.repository_format.network_name()) - client.add_error_response(b'NotStacked') - transport.mkdir('quack') - transport = transport.clone('quack') - a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client) + b"ok", + b"", + b"no", + b"no", + b"no", + reference_bzrdir_format.repository_format.network_name(), + ) + client.add_error_response(b"NotStacked") + transport.mkdir("quack") + transport = transport.clone("quack") + a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), _client=client) result = a_controldir.get_branches() self.assertEqual({"", "foo"}, set(result.keys())) self.assertEqual( - [('call_expecting_body', b'BzrDir.get_branches', (b'quack/',)), - ('call', b'BzrDir.find_repositoryV3', (b'quack/', )), - ('call', b'Branch.get_stacked_on_url', (b'quack/', )), - ('call', b'BzrDir.find_repositoryV3', (b'quack/', )), - ('call', b'Branch.get_stacked_on_url', (b'quack/', ))], - client._calls) + [ + ("call_expecting_body", b"BzrDir.get_branches", (b"quack/",)), + ("call", b"BzrDir.find_repositoryV3", (b"quack/",)), + ("call", b"Branch.get_stacked_on_url", (b"quack/",)), + ("call", b"BzrDir.find_repositoryV3", (b"quack/",)), + ("call", b"Branch.get_stacked_on_url", (b"quack/",)), + ], + client._calls, + ) class TestBzrDirDestroyBranch(TestRemote): - def test_destroy_default(self): - transport = self.get_transport('quack') - self.make_branch('referenced') + transport = self.get_transport("quack") + self.make_branch("referenced") client = FakeClient(transport.base) - client.add_expected_call( - b'BzrDir.destroy_branch', (b'quack/', ), - b'success', (b'ok',)), - a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client) + ( + client.add_expected_call( + 
b"BzrDir.destroy_branch", (b"quack/",), b"success", (b"ok",) + ), + ) + a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), _client=client) a_controldir.destroy_branch() self.assertFinished(client) class TestBzrDirHasWorkingTree(TestRemote): - def test_has_workingtree(self): - transport = self.get_transport('quack') + transport = self.get_transport("quack") client = FakeClient(transport.base) - client.add_expected_call( - b'BzrDir.has_workingtree', (b'quack/',), - b'success', (b'yes',)), - a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client) + ( + client.add_expected_call( + b"BzrDir.has_workingtree", (b"quack/",), b"success", (b"yes",) + ), + ) + a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), _client=client) self.assertTrue(a_controldir.has_workingtree()) self.assertFinished(client) def test_no_workingtree(self): - transport = self.get_transport('quack') + transport = self.get_transport("quack") client = FakeClient(transport.base) - client.add_expected_call( - b'BzrDir.has_workingtree', (b'quack/',), - b'success', (b'no',)), - a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client) + ( + client.add_expected_call( + b"BzrDir.has_workingtree", (b"quack/",), b"success", (b"no",) + ), + ) + a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), _client=client) self.assertFalse(a_controldir.has_workingtree()) self.assertFinished(client) class TestBzrDirDestroyRepository(TestRemote): - def test_destroy_repository(self): - transport = self.get_transport('quack') + transport = self.get_transport("quack") client = FakeClient(transport.base) - client.add_expected_call( - b'BzrDir.destroy_repository', (b'quack/',), - b'success', (b'ok',)), - a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client) + ( + client.add_expected_call( + b"BzrDir.destroy_repository", (b"quack/",), b"success", (b"ok",) + ), + ) + a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), _client=client) a_controldir.destroy_repository() self.assertFinished(client) class TestBzrDirOpen(TestRemote): - - def make_fake_client_and_transport(self, path='quack'): + def make_fake_client_and_transport(self, path="quack"): transport = MemoryTransport() transport.mkdir(path) transport = transport.clone(path) @@ -615,19 +638,25 @@ def make_fake_client_and_transport(self, path='quack'): def test_absent(self): client, transport = self.make_fake_client_and_transport() - client.add_expected_call( - b'BzrDir.open_2.1', (b'quack/',), b'success', (b'no',)) - self.assertRaises(errors.NotBranchError, RemoteBzrDir, transport, - RemoteBzrDirFormat(), _client=client, - _force_probe=True) + client.add_expected_call(b"BzrDir.open_2.1", (b"quack/",), b"success", (b"no",)) + self.assertRaises( + errors.NotBranchError, + RemoteBzrDir, + transport, + RemoteBzrDirFormat(), + _client=client, + _force_probe=True, + ) self.assertFinished(client) def test_present_without_workingtree(self): client, transport = self.make_fake_client_and_transport() client.add_expected_call( - b'BzrDir.open_2.1', (b'quack/',), b'success', (b'yes', b'no')) - bd = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client, _force_probe=True) + b"BzrDir.open_2.1", (b"quack/",), b"success", (b"yes", b"no") + ) + bd = RemoteBzrDir( + transport, RemoteBzrDirFormat(), _client=client, _force_probe=True + ) self.assertIsInstance(bd, RemoteBzrDir) self.assertFalse(bd.has_workingtree()) self.assertRaises(errors.NoWorkingTree, bd.open_workingtree) @@ -636,9 +665,11 @@ def 
test_present_without_workingtree(self): def test_present_with_workingtree(self): client, transport = self.make_fake_client_and_transport() client.add_expected_call( - b'BzrDir.open_2.1', (b'quack/',), b'success', (b'yes', b'yes')) - bd = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client, _force_probe=True) + b"BzrDir.open_2.1", (b"quack/",), b"success", (b"yes", b"yes") + ) + bd = RemoteBzrDir( + transport, RemoteBzrDirFormat(), _client=client, _force_probe=True + ) self.assertIsInstance(bd, RemoteBzrDir) self.assertTrue(bd.has_workingtree()) self.assertRaises(errors.NotLocalUrl, bd.open_workingtree) @@ -647,12 +678,12 @@ def test_present_with_workingtree(self): def test_backwards_compat(self): client, transport = self.make_fake_client_and_transport() client.add_expected_call( - b'BzrDir.open_2.1', (b'quack/',), b'unknown', - (b'BzrDir.open_2.1',)) - client.add_expected_call( - b'BzrDir.open', (b'quack/',), b'success', (b'yes',)) - bd = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client, _force_probe=True) + b"BzrDir.open_2.1", (b"quack/",), b"unknown", (b"BzrDir.open_2.1",) + ) + client.add_expected_call(b"BzrDir.open", (b"quack/",), b"success", (b"yes",)) + bd = RemoteBzrDir( + transport, RemoteBzrDirFormat(), _client=client, _force_probe=True + ) self.assertIsInstance(bd, RemoteBzrDir) self.assertFinished(client) @@ -669,30 +700,29 @@ def check_call(method, args): client._medium._remember_remote_is_before((1, 6)) client._check_call = orig_check_call client._check_call(method, args) + client._check_call = check_call client.add_expected_call( - b'BzrDir.open_2.1', (b'quack/',), b'unknown', - (b'BzrDir.open_2.1',)) - client.add_expected_call( - b'BzrDir.open', (b'quack/',), b'success', (b'yes',)) - bd = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client, _force_probe=True) + b"BzrDir.open_2.1", (b"quack/",), b"unknown", (b"BzrDir.open_2.1",) + ) + client.add_expected_call(b"BzrDir.open", (b"quack/",), b"success", (b"yes",)) + bd = RemoteBzrDir( + transport, RemoteBzrDirFormat(), _client=client, _force_probe=True + ) self.assertIsInstance(bd, RemoteBzrDir) self.assertFinished(client) class TestBzrDirOpenBranch(TestRemote): - def test_backwards_compat(self): self.setup_smart_server_with_call_log() - self.make_branch('.') - a_dir = BzrDir.open(self.get_url('.')) + self.make_branch(".") + a_dir = BzrDir.open(self.get_url(".")) self.reset_smart_call_log() - verb = b'BzrDir.open_branchV3' + verb = b"BzrDir.open_branchV3" self.disable_verb(verb) a_dir.open_branch() - call_count = len([call for call in self.hpss_calls if - call.call.method == verb]) + call_count = len([call for call in self.hpss_calls if call.call.method == verb]) self.assertEqual(1, call_count) def test_branch_present(self): @@ -700,20 +730,25 @@ def test_branch_present(self): network_name = reference_format.network_name() branch_network_name = self.get_branch_format().network_name() transport = MemoryTransport() - transport.mkdir('quack') - transport = transport.clone('quack') + transport.mkdir("quack") + transport = transport.clone("quack") client = FakeClient(transport.base) client.add_expected_call( - b'BzrDir.open_branchV3', (b'quack/',), - b'success', (b'branch', branch_network_name)) + b"BzrDir.open_branchV3", + (b"quack/",), + b"success", + (b"branch", branch_network_name), + ) client.add_expected_call( - b'BzrDir.find_repositoryV3', (b'quack/',), - b'success', (b'ok', b'', b'no', b'no', b'no', network_name)) + b"BzrDir.find_repositoryV3", + (b"quack/",), + b"success", + (b"ok", b"", 
b"no", b"no", b"no", network_name), + ) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'quack/',), - b'error', (b'NotStacked',)) - bzrdir = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client) + b"Branch.get_stacked_on_url", (b"quack/",), b"error", (b"NotStacked",) + ) + bzrdir = RemoteBzrDir(transport, RemoteBzrDirFormat(), _client=client) result = bzrdir.open_branch() self.assertIsInstance(result, RemoteBranch) self.assertEqual(bzrdir, result.controldir) @@ -721,16 +756,15 @@ def test_branch_present(self): def test_branch_missing(self): transport = MemoryTransport() - transport.mkdir('quack') - transport = transport.clone('quack') + transport.mkdir("quack") + transport = transport.clone("quack") client = FakeClient(transport.base) - client.add_error_response(b'nobranch') - bzrdir = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client) + client.add_error_response(b"nobranch") + bzrdir = RemoteBzrDir(transport, RemoteBzrDirFormat(), _client=client) self.assertRaises(errors.NotBranchError, bzrdir.open_branch) self.assertEqual( - [('call', b'BzrDir.open_branchV3', (b'quack/',))], - client._calls) + [("call", b"BzrDir.open_branchV3", (b"quack/",))], client._calls + ) def test__get_tree_branch(self): # _get_tree_branch is a form of open_branch, but it should only ask for @@ -740,11 +774,11 @@ def test__get_tree_branch(self): def open_branch(name=None, possible_transports=None): calls.append("Called") return "a-branch" + transport = MemoryTransport() # no requests on the network - catches other api calls being made. client = FakeClient(transport.base) - bzrdir = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client) + bzrdir = RemoteBzrDir(transport, RemoteBzrDirFormat(), _client=client) # patch the open_branch call to record that it was called. bzrdir.open_branch = open_branch self.assertEqual((None, "a-branch"), bzrdir._get_tree_branch()) @@ -754,50 +788,53 @@ def open_branch(name=None, possible_transports=None): def test_url_quoting_of_path(self): # Relpaths on the wire should not be URL-escaped. So "~" should be # transmitted as "~", not "%7E". 
- transport = RemoteTCPTransport('bzr://localhost/~hello/') + transport = RemoteTCPTransport("bzr://localhost/~hello/") client = FakeClient(transport.base) reference_format = self.get_repo_format() network_name = reference_format.network_name() branch_network_name = self.get_branch_format().network_name() client.add_expected_call( - b'BzrDir.open_branchV3', (b'~hello/',), - b'success', (b'branch', branch_network_name)) + b"BzrDir.open_branchV3", + (b"~hello/",), + b"success", + (b"branch", branch_network_name), + ) client.add_expected_call( - b'BzrDir.find_repositoryV3', (b'~hello/',), - b'success', (b'ok', b'', b'no', b'no', b'no', network_name)) + b"BzrDir.find_repositoryV3", + (b"~hello/",), + b"success", + (b"ok", b"", b"no", b"no", b"no", network_name), + ) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'~hello/',), - b'error', (b'NotStacked',)) - bzrdir = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client) + b"Branch.get_stacked_on_url", (b"~hello/",), b"error", (b"NotStacked",) + ) + bzrdir = RemoteBzrDir(transport, RemoteBzrDirFormat(), _client=client) bzrdir.open_branch() self.assertFinished(client) - def check_open_repository(self, rich_root, subtrees, - external_lookup=b'no'): + def check_open_repository(self, rich_root, subtrees, external_lookup=b"no"): reference_format = self.get_repo_format() network_name = reference_format.network_name() transport = MemoryTransport() - transport.mkdir('quack') - transport = transport.clone('quack') + transport.mkdir("quack") + transport = transport.clone("quack") if rich_root: - rich_response = b'yes' + rich_response = b"yes" else: - rich_response = b'no' + rich_response = b"no" if subtrees: - subtree_response = b'yes' + subtree_response = b"yes" else: - subtree_response = b'no' + subtree_response = b"no" client = FakeClient(transport.base) client.add_success_response( - b'ok', b'', rich_response, subtree_response, external_lookup, - network_name) - bzrdir = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client) + b"ok", b"", rich_response, subtree_response, external_lookup, network_name + ) + bzrdir = RemoteBzrDir(transport, RemoteBzrDirFormat(), _client=client) result = bzrdir.open_repository() self.assertEqual( - [('call', b'BzrDir.find_repositoryV3', (b'quack/',))], - client._calls) + [("call", b"BzrDir.find_repositoryV3", (b"quack/",))], client._calls + ) self.assertIsInstance(result, RemoteRepository) self.assertEqual(bzrdir, result.controldir) self.assertEqual(rich_root, result._format.rich_root_data) @@ -808,46 +845,50 @@ def test_open_repository_sets_format_attributes(self): self.check_open_repository(False, True) self.check_open_repository(True, False) self.check_open_repository(False, False) - self.check_open_repository(False, False, b'yes') + self.check_open_repository(False, False, b"yes") def test_old_server(self): """RemoteBzrDirFormat should fail to probe if the server version is too old. 
""" self.assertRaises( - errors.NotBranchError, - RemoteBzrProber.probe_transport, OldServerTransport()) + errors.NotBranchError, RemoteBzrProber.probe_transport, OldServerTransport() + ) class TestBzrDirCreateBranch(TestRemote): - def test_backwards_compat(self): self.setup_smart_server_with_call_log() - repo = self.make_repository('.') + repo = self.make_repository(".") self.reset_smart_call_log() - self.disable_verb(b'BzrDir.create_branch') + self.disable_verb(b"BzrDir.create_branch") repo.controldir.create_branch() create_branch_call_count = len( - [call for call in self.hpss_calls - if call.call.method == b'BzrDir.create_branch']) + [ + call + for call in self.hpss_calls + if call.call.method == b"BzrDir.create_branch" + ] + ) self.assertEqual(1, create_branch_call_count) def test_current_server(self): - transport = self.get_transport('.') - transport = transport.clone('quack') - self.make_repository('quack') + transport = self.get_transport(".") + transport = transport.clone("quack") + self.make_repository("quack") client = FakeClient(transport.base) - reference_bzrdir_format = controldir.format_registry.get('default')() + reference_bzrdir_format = controldir.format_registry.get("default")() reference_format = reference_bzrdir_format.get_branch_format() network_name = reference_format.network_name() reference_repo_fmt = reference_bzrdir_format.repository_format reference_repo_name = reference_repo_fmt.network_name() client.add_expected_call( - b'BzrDir.create_branch', (b'quack/', network_name), - b'success', (b'ok', network_name, b'', b'no', b'no', b'yes', - reference_repo_name)) - a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client) + b"BzrDir.create_branch", + (b"quack/", network_name), + b"success", + (b"ok", network_name, b"", b"no", b"no", b"yes", reference_repo_name), + ) + a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), _client=client) branch = a_controldir.create_branch() # We should have got a remote branch self.assertIsInstance(branch, remote.RemoteBranch) @@ -860,22 +901,23 @@ def test_already_open_repo_and_reused_medium(self): regardless of what the smart medium's base URL is. 
""" self.transport_server = test_server.SmartTCPServer_for_testing - transport = self.get_transport('.') - repo = self.make_repository('quack') + transport = self.get_transport(".") + repo = self.make_repository("quack") # Client's medium rooted a transport root (not at the bzrdir) client = FakeClient(transport.base) - transport = transport.clone('quack') - reference_bzrdir_format = controldir.format_registry.get('default')() + transport = transport.clone("quack") + reference_bzrdir_format = controldir.format_registry.get("default")() reference_format = reference_bzrdir_format.get_branch_format() network_name = reference_format.network_name() reference_repo_fmt = reference_bzrdir_format.repository_format reference_repo_name = reference_repo_fmt.network_name() client.add_expected_call( - b'BzrDir.create_branch', (b'extra/quack/', network_name), - b'success', (b'ok', network_name, b'', b'no', b'no', b'yes', - reference_repo_name)) - a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client) + b"BzrDir.create_branch", + (b"extra/quack/", network_name), + b"success", + (b"ok", network_name, b"", b"no", b"no", b"yes", reference_repo_name), + ) + a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), _client=client) branch = a_controldir.create_branch(repository=repo) # We should have got a remote branch self.assertIsInstance(branch, remote.RemoteBranch) @@ -885,32 +927,40 @@ def test_already_open_repo_and_reused_medium(self): class TestBzrDirCreateRepository(TestRemote): - def test_backwards_compat(self): self.setup_smart_server_with_call_log() - bzrdir = self.make_controldir('.') + bzrdir = self.make_controldir(".") self.reset_smart_call_log() - self.disable_verb(b'BzrDir.create_repository') + self.disable_verb(b"BzrDir.create_repository") bzrdir.create_repository() - create_repo_call_count = len([call for call in self.hpss_calls if - call.call.method == b'BzrDir.create_repository']) + create_repo_call_count = len( + [ + call + for call in self.hpss_calls + if call.call.method == b"BzrDir.create_repository" + ] + ) self.assertEqual(1, create_repo_call_count) def test_current_server(self): - transport = self.get_transport('.') - transport = transport.clone('quack') - self.make_controldir('quack') + transport = self.get_transport(".") + transport = transport.clone("quack") + self.make_controldir("quack") client = FakeClient(transport.base) - reference_bzrdir_format = controldir.format_registry.get('default')() + reference_bzrdir_format = controldir.format_registry.get("default")() reference_format = reference_bzrdir_format.repository_format network_name = reference_format.network_name() client.add_expected_call( - b'BzrDir.create_repository', (b'quack/', - b'Bazaar repository format 2a (needs bzr 1.16 or later)\n', - b'False'), - b'success', (b'ok', b'yes', b'yes', b'yes', network_name)) - a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client) + b"BzrDir.create_repository", + ( + b"quack/", + b"Bazaar repository format 2a (needs bzr 1.16 or later)\n", + b"False", + ), + b"success", + (b"ok", b"yes", b"yes", b"yes", network_name), + ) + a_controldir = RemoteBzrDir(transport, RemoteBzrDirFormat(), _client=client) repo = a_controldir.create_repository() # We should have got a remote repository self.assertIsInstance(repo, remote.RemoteRepository) @@ -923,98 +973,101 @@ def test_current_server(self): class TestBzrDirOpenRepository(TestRemote): - def test_backwards_compat_1_2_3(self): # fallback all the way to the first version. 
reference_format = self.get_repo_format() network_name = reference_format.network_name() - server_url = 'bzr://example.com/' + server_url = "bzr://example.com/" self.permit_url(server_url) client = FakeClient(server_url) - client.add_unknown_method_response(b'BzrDir.find_repositoryV3') - client.add_unknown_method_response(b'BzrDir.find_repositoryV2') - client.add_success_response(b'ok', b'', b'no', b'no') + client.add_unknown_method_response(b"BzrDir.find_repositoryV3") + client.add_unknown_method_response(b"BzrDir.find_repositoryV2") + client.add_success_response(b"ok", b"", b"no", b"no") # A real repository instance will be created to determine the network # name. client.add_success_response_with_body( - b"Bazaar-NG meta directory, format 1\n", b'ok') - client.add_success_response(b'stat', b'0', b'65535') + b"Bazaar-NG meta directory, format 1\n", b"ok" + ) + client.add_success_response(b"stat", b"0", b"65535") client.add_success_response_with_body( - reference_format.get_format_string(), b'ok') + reference_format.get_format_string(), b"ok" + ) # PackRepository wants to do a stat - client.add_success_response(b'stat', b'0', b'65535') - remote_transport = RemoteTransport(server_url + 'quack/', medium=False, - _client=client) - bzrdir = RemoteBzrDir(remote_transport, RemoteBzrDirFormat(), - _client=client) + client.add_success_response(b"stat", b"0", b"65535") + remote_transport = RemoteTransport( + server_url + "quack/", medium=False, _client=client + ) + bzrdir = RemoteBzrDir(remote_transport, RemoteBzrDirFormat(), _client=client) repo = bzrdir.open_repository() self.assertEqual( - [('call', b'BzrDir.find_repositoryV3', (b'quack/',)), - ('call', b'BzrDir.find_repositoryV2', (b'quack/',)), - ('call', b'BzrDir.find_repository', (b'quack/',)), - ('call_expecting_body', b'get', (b'/quack/.bzr/branch-format',)), - ('call', b'stat', (b'/quack/.bzr',)), - ('call_expecting_body', b'get', (b'/quack/.bzr/repository/format',)), - ('call', b'stat', (b'/quack/.bzr/repository',)), - ], - client._calls) + [ + ("call", b"BzrDir.find_repositoryV3", (b"quack/",)), + ("call", b"BzrDir.find_repositoryV2", (b"quack/",)), + ("call", b"BzrDir.find_repository", (b"quack/",)), + ("call_expecting_body", b"get", (b"/quack/.bzr/branch-format",)), + ("call", b"stat", (b"/quack/.bzr",)), + ("call_expecting_body", b"get", (b"/quack/.bzr/repository/format",)), + ("call", b"stat", (b"/quack/.bzr/repository",)), + ], + client._calls, + ) self.assertEqual(network_name, repo._format.network_name()) def test_backwards_compat_2(self): # fallback to find_repositoryV2 reference_format = self.get_repo_format() network_name = reference_format.network_name() - server_url = 'bzr://example.com/' + server_url = "bzr://example.com/" self.permit_url(server_url) client = FakeClient(server_url) - client.add_unknown_method_response(b'BzrDir.find_repositoryV3') - client.add_success_response(b'ok', b'', b'no', b'no', b'no') + client.add_unknown_method_response(b"BzrDir.find_repositoryV3") + client.add_success_response(b"ok", b"", b"no", b"no", b"no") # A real repository instance will be created to determine the network # name. 
client.add_success_response_with_body( - b"Bazaar-NG meta directory, format 1\n", b'ok') - client.add_success_response(b'stat', b'0', b'65535') + b"Bazaar-NG meta directory, format 1\n", b"ok" + ) + client.add_success_response(b"stat", b"0", b"65535") client.add_success_response_with_body( - reference_format.get_format_string(), b'ok') + reference_format.get_format_string(), b"ok" + ) # PackRepository wants to do a stat - client.add_success_response(b'stat', b'0', b'65535') - remote_transport = RemoteTransport(server_url + 'quack/', medium=False, - _client=client) - bzrdir = RemoteBzrDir(remote_transport, RemoteBzrDirFormat(), - _client=client) + client.add_success_response(b"stat", b"0", b"65535") + remote_transport = RemoteTransport( + server_url + "quack/", medium=False, _client=client + ) + bzrdir = RemoteBzrDir(remote_transport, RemoteBzrDirFormat(), _client=client) repo = bzrdir.open_repository() self.assertEqual( - [('call', b'BzrDir.find_repositoryV3', (b'quack/',)), - ('call', b'BzrDir.find_repositoryV2', (b'quack/',)), - ('call_expecting_body', b'get', (b'/quack/.bzr/branch-format',)), - ('call', b'stat', (b'/quack/.bzr',)), - ('call_expecting_body', b'get', - (b'/quack/.bzr/repository/format',)), - ('call', b'stat', (b'/quack/.bzr/repository',)), - ], - client._calls) + [ + ("call", b"BzrDir.find_repositoryV3", (b"quack/",)), + ("call", b"BzrDir.find_repositoryV2", (b"quack/",)), + ("call_expecting_body", b"get", (b"/quack/.bzr/branch-format",)), + ("call", b"stat", (b"/quack/.bzr",)), + ("call_expecting_body", b"get", (b"/quack/.bzr/repository/format",)), + ("call", b"stat", (b"/quack/.bzr/repository",)), + ], + client._calls, + ) self.assertEqual(network_name, repo._format.network_name()) def test_current_server(self): reference_format = self.get_repo_format() network_name = reference_format.network_name() transport = MemoryTransport() - transport.mkdir('quack') - transport = transport.clone('quack') + transport.mkdir("quack") + transport = transport.clone("quack") client = FakeClient(transport.base) - client.add_success_response( - b'ok', b'', b'no', b'no', b'no', network_name) - bzrdir = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client) + client.add_success_response(b"ok", b"", b"no", b"no", b"no", network_name) + bzrdir = RemoteBzrDir(transport, RemoteBzrDirFormat(), _client=client) repo = bzrdir.open_repository() self.assertEqual( - [('call', b'BzrDir.find_repositoryV3', (b'quack/',))], - client._calls) + [("call", b"BzrDir.find_repositoryV3", (b"quack/",))], client._calls + ) self.assertEqual(network_name, repo._format.network_name()) class TestBzrDirFormatInitializeEx(TestRemote): - def test_success(self): """Simple test for typical successful call.""" fmt = RemoteBzrDirFormat() @@ -1022,18 +1075,50 @@ def test_success(self): transport = self.get_transport() client = FakeClient(transport.base) client.add_expected_call( - b'BzrDirFormat.initialize_ex_1.16', - (default_format_name, b'path', b'False', b'False', b'False', b'', - b'', b'', b'', b'False'), - b'success', - (b'.', b'no', b'no', b'yes', b'repo fmt', b'repo bzrdir fmt', - b'bzrdir fmt', b'False', b'', b'', b'repo lock token')) + b"BzrDirFormat.initialize_ex_1.16", + ( + default_format_name, + b"path", + b"False", + b"False", + b"False", + b"", + b"", + b"", + b"", + b"False", + ), + b"success", + ( + b".", + b"no", + b"no", + b"yes", + b"repo fmt", + b"repo bzrdir fmt", + b"bzrdir fmt", + b"False", + b"", + b"", + b"repo lock token", + ), + ) # XXX: It would be better to call 
fmt.initialize_on_transport_ex, but # it's currently hard to test that without supplying a real remote # transport connected to a real server. fmt._initialize_on_transport_ex_rpc( - client, b'path', transport, False, False, False, None, None, None, - None, False) + client, + b"path", + transport, + False, + False, + False, + None, + None, + None, + None, + False, + ) self.assertFinished(client) def test_error(self): @@ -1045,30 +1130,55 @@ def test_error(self): transport = self.get_transport() client = FakeClient(transport.base) client.add_expected_call( - b'BzrDirFormat.initialize_ex_1.16', - (default_format_name, b'path', b'False', b'False', b'False', b'', - b'', b'', b'', b'False'), - b'error', - (b'PermissionDenied', b'path', b'extra info')) + b"BzrDirFormat.initialize_ex_1.16", + ( + default_format_name, + b"path", + b"False", + b"False", + b"False", + b"", + b"", + b"", + b"", + b"False", + ), + b"error", + (b"PermissionDenied", b"path", b"extra info"), + ) # XXX: It would be better to call fmt.initialize_on_transport_ex, but # it's currently hard to test that without supplying a real remote # transport connected to a real server. err = self.assertRaises( errors.PermissionDenied, - fmt._initialize_on_transport_ex_rpc, client, b'path', transport, - False, False, False, None, None, None, None, False) - self.assertEqual('path', err.path) - self.assertEqual(': extra info', err.extra) + fmt._initialize_on_transport_ex_rpc, + client, + b"path", + transport, + False, + False, + False, + None, + None, + None, + None, + False, + ) + self.assertEqual("path", err.path) + self.assertEqual(": extra info", err.extra) self.assertFinished(client) def test_error_from_real_server(self): """Integration test for error translation.""" - transport = self.make_smart_server('foo') - transport = transport.clone('no-such-path') + transport = self.make_smart_server("foo") + transport = transport.clone("no-such-path") fmt = RemoteBzrDirFormat() self.assertRaises( - _mod_transport.NoSuchFile, fmt.initialize_on_transport_ex, transport, - create_prefix=False) + _mod_transport.NoSuchFile, + fmt.initialize_on_transport_ex, + transport, + create_prefix=False, + ) class OldSmartClient: @@ -1077,10 +1187,9 @@ class OldSmartClient: """ def get_request(self): - input_file = BytesIO(b'ok\x011\n') + input_file = BytesIO(b"ok\x011\n") output_file = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input_file, output_file) + client_medium = medium.SmartSimplePipesClientMedium(input_file, output_file) return medium.SmartClientStreamMediumRequest(client_medium) def protocol_version(self): @@ -1093,31 +1202,28 @@ class OldServerTransport: """ def __init__(self): - self.base = 'fake:' + self.base = "fake:" def get_smart_client(self): return OldSmartClient() class RemoteBzrDirTestCase(TestRemote): - def make_remote_bzrdir(self, transport, client): """Make a RemotebzrDir using 'client' as the _client.""" - return RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client) + return RemoteBzrDir(transport, RemoteBzrDirFormat(), _client=client) class RemoteBranchTestCase(RemoteBzrDirTestCase): - def lock_remote_branch(self, branch): """Trick a RemoteBranch into thinking it is locked.""" - branch._lock_mode = 'w' + branch._lock_mode = "w" branch._lock_count = 2 - branch._lock_token = b'branch token' - branch._repo_lock_token = b'repo token' - branch.repository._lock_mode = 'w' + branch._lock_token = b"branch token" + branch._repo_lock_token = b"repo token" + branch.repository._lock_mode = "w" 
branch.repository._lock_count = 2 - branch.repository._lock_token = b'repo token' + branch.repository._lock_token = b"repo token" def make_remote_branch(self, transport, client): """Make a RemoteBranch using 'client' as its _SmartClient. @@ -1136,36 +1242,34 @@ def make_remote_branch(self, transport, client): class TestBranchBreakLock(RemoteBranchTestCase): - def test_break_lock(self): transport = MemoryTransport() client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'quack/',), - b'error', (b'NotStacked',)) + b"Branch.get_stacked_on_url", (b"quack/",), b"error", (b"NotStacked",) + ) client.add_expected_call( - b'Branch.break_lock', (b'quack/',), - b'success', (b'ok',)) - transport.mkdir('quack') - transport = transport.clone('quack') + b"Branch.break_lock", (b"quack/",), b"success", (b"ok",) + ) + transport.mkdir("quack") + transport = transport.clone("quack") branch = self.make_remote_branch(transport, client) branch.break_lock() self.assertFinished(client) class TestBranchGetPhysicalLockStatus(RemoteBranchTestCase): - def test_get_physical_lock_status_yes(self): transport = MemoryTransport() client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'quack/',), - b'error', (b'NotStacked',)) + b"Branch.get_stacked_on_url", (b"quack/",), b"error", (b"NotStacked",) + ) client.add_expected_call( - b'Branch.get_physical_lock_status', (b'quack/',), - b'success', (b'yes',)) - transport.mkdir('quack') - transport = transport.clone('quack') + b"Branch.get_physical_lock_status", (b"quack/",), b"success", (b"yes",) + ) + transport.mkdir("quack") + transport = transport.clone("quack") branch = self.make_remote_branch(transport, client) result = branch.get_physical_lock_status() self.assertFinished(client) @@ -1175,13 +1279,13 @@ def test_get_physical_lock_status_no(self): transport = MemoryTransport() client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'quack/',), - b'error', (b'NotStacked',)) + b"Branch.get_stacked_on_url", (b"quack/",), b"error", (b"NotStacked",) + ) client.add_expected_call( - b'Branch.get_physical_lock_status', (b'quack/',), - b'success', (b'no',)) - transport.mkdir('quack') - transport = transport.clone('quack') + b"Branch.get_physical_lock_status", (b"quack/",), b"success", (b"no",) + ) + transport.mkdir("quack") + transport = transport.clone("quack") branch = self.make_remote_branch(transport, client) result = branch.get_physical_lock_status() self.assertFinished(client) @@ -1189,19 +1293,16 @@ def test_get_physical_lock_status_no(self): class TestBranchGetParent(RemoteBranchTestCase): - def test_no_parent(self): # in an empty branch we decode the response properly transport = MemoryTransport() client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'quack/',), - b'error', (b'NotStacked',)) - client.add_expected_call( - b'Branch.get_parent', (b'quack/',), - b'success', (b'',)) - transport.mkdir('quack') - transport = transport.clone('quack') + b"Branch.get_stacked_on_url", (b"quack/",), b"error", (b"NotStacked",) + ) + client.add_expected_call(b"Branch.get_parent", (b"quack/",), b"success", (b"",)) + transport.mkdir("quack") + transport = transport.clone("quack") branch = self.make_remote_branch(transport, client) result = branch.get_parent() self.assertFinished(client) @@ -1211,51 +1312,50 @@ def test_parent_relative(self): transport = MemoryTransport() client = FakeClient(transport.base) 
client.add_expected_call( - b'Branch.get_stacked_on_url', (b'kwaak/',), - b'error', (b'NotStacked',)) + b"Branch.get_stacked_on_url", (b"kwaak/",), b"error", (b"NotStacked",) + ) client.add_expected_call( - b'Branch.get_parent', (b'kwaak/',), - b'success', (b'../foo/',)) - transport.mkdir('kwaak') - transport = transport.clone('kwaak') + b"Branch.get_parent", (b"kwaak/",), b"success", (b"../foo/",) + ) + transport.mkdir("kwaak") + transport = transport.clone("kwaak") branch = self.make_remote_branch(transport, client) result = branch.get_parent() - self.assertEqual(transport.clone('../foo').base, result) + self.assertEqual(transport.clone("../foo").base, result) def test_parent_absolute(self): transport = MemoryTransport() client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'kwaak/',), - b'error', (b'NotStacked',)) + b"Branch.get_stacked_on_url", (b"kwaak/",), b"error", (b"NotStacked",) + ) client.add_expected_call( - b'Branch.get_parent', (b'kwaak/',), - b'success', (b'http://foo/',)) - transport.mkdir('kwaak') - transport = transport.clone('kwaak') + b"Branch.get_parent", (b"kwaak/",), b"success", (b"http://foo/",) + ) + transport.mkdir("kwaak") + transport = transport.clone("kwaak") branch = self.make_remote_branch(transport, client) result = branch.get_parent() - self.assertEqual('http://foo/', result) + self.assertEqual("http://foo/", result) self.assertFinished(client) class TestBranchSetParentLocation(RemoteBranchTestCase): - def test_no_parent(self): # We call the verb when setting parent to None transport = MemoryTransport() client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'quack/',), - b'error', (b'NotStacked',)) + b"Branch.get_stacked_on_url", (b"quack/",), b"error", (b"NotStacked",) + ) client.add_expected_call( - b'Branch.set_parent_location', (b'quack/', b'b', b'r', b''), - b'success', ()) - transport.mkdir('quack') - transport = transport.clone('quack') + b"Branch.set_parent_location", (b"quack/", b"b", b"r", b""), b"success", () + ) + transport.mkdir("quack") + transport = transport.clone("quack") branch = self.make_remote_branch(transport, client) - branch._lock_token = b'b' - branch._repo_lock_token = b'r' + branch._lock_token = b"b" + branch._repo_lock_token = b"r" branch._set_parent_location(None) self.assertFinished(client) @@ -1263,53 +1363,54 @@ def test_parent(self): transport = MemoryTransport() client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'kwaak/',), - b'error', (b'NotStacked',)) + b"Branch.get_stacked_on_url", (b"kwaak/",), b"error", (b"NotStacked",) + ) client.add_expected_call( - b'Branch.set_parent_location', (b'kwaak/', b'b', b'r', b'foo'), - b'success', ()) - transport.mkdir('kwaak') - transport = transport.clone('kwaak') + b"Branch.set_parent_location", + (b"kwaak/", b"b", b"r", b"foo"), + b"success", + (), + ) + transport.mkdir("kwaak") + transport = transport.clone("kwaak") branch = self.make_remote_branch(transport, client) - branch._lock_token = b'b' - branch._repo_lock_token = b'r' - branch._set_parent_location('foo') + branch._lock_token = b"b" + branch._repo_lock_token = b"r" + branch._set_parent_location("foo") self.assertFinished(client) def test_backwards_compat(self): self.setup_smart_server_with_call_log() - branch = self.make_branch('.') + branch = self.make_branch(".") self.reset_smart_call_log() - verb = b'Branch.set_parent_location' + verb = b"Branch.set_parent_location" self.disable_verb(verb) - 
branch.set_parent('http://foo/') + branch.set_parent("http://foo/") self.assertLength(14, self.hpss_calls) class TestBranchGetTagsBytes(RemoteBranchTestCase): - def test_backwards_compat(self): self.setup_smart_server_with_call_log() - branch = self.make_branch('.') + branch = self.make_branch(".") self.reset_smart_call_log() - verb = b'Branch.get_tags_bytes' + verb = b"Branch.get_tags_bytes" self.disable_verb(verb) branch.tags.get_tag_dict() - call_count = len([call for call in self.hpss_calls if - call.call.method == verb]) + call_count = len([call for call in self.hpss_calls if call.call.method == verb]) self.assertEqual(1, call_count) def test_trivial(self): transport = MemoryTransport() client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'quack/',), - b'error', (b'NotStacked',)) + b"Branch.get_stacked_on_url", (b"quack/",), b"error", (b"NotStacked",) + ) client.add_expected_call( - b'Branch.get_tags_bytes', (b'quack/',), - b'success', (b'',)) - transport.mkdir('quack') - transport = transport.clone('quack') + b"Branch.get_tags_bytes", (b"quack/",), b"success", (b"",) + ) + transport.mkdir("quack") + transport = transport.clone("quack") branch = self.make_remote_branch(transport, client) result = branch.tags.get_tag_dict() self.assertFinished(client) @@ -1317,37 +1418,40 @@ def test_trivial(self): class TestBranchSetTagsBytes(RemoteBranchTestCase): - def test_trivial(self): transport = MemoryTransport() client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'quack/',), - b'error', (b'NotStacked',)) + b"Branch.get_stacked_on_url", (b"quack/",), b"error", (b"NotStacked",) + ) client.add_expected_call( - b'Branch.set_tags_bytes', (b'quack/', - b'branch token', b'repo token'), - b'success', ('',)) - transport.mkdir('quack') - transport = transport.clone('quack') + b"Branch.set_tags_bytes", + (b"quack/", b"branch token", b"repo token"), + b"success", + ("",), + ) + transport.mkdir("quack") + transport = transport.clone("quack") branch = self.make_remote_branch(transport, client) self.lock_remote_branch(branch) - branch._set_tags_bytes(b'tags bytes') + branch._set_tags_bytes(b"tags bytes") self.assertFinished(client) - self.assertEqual(b'tags bytes', client._calls[-1][-1]) + self.assertEqual(b"tags bytes", client._calls[-1][-1]) def test_backwards_compatible(self): transport = MemoryTransport() client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'quack/',), - b'error', (b'NotStacked',)) + b"Branch.get_stacked_on_url", (b"quack/",), b"error", (b"NotStacked",) + ) client.add_expected_call( - b'Branch.set_tags_bytes', (b'quack/', - b'branch token', b'repo token'), - b'unknown', (b'Branch.set_tags_bytes',)) - transport.mkdir('quack') - transport = transport.clone('quack') + b"Branch.set_tags_bytes", + (b"quack/", b"branch token", b"repo token"), + b"unknown", + (b"Branch.set_tags_bytes",), + ) + transport.mkdir("quack") + transport = transport.clone("quack") branch = self.make_remote_branch(transport, client) self.lock_remote_branch(branch) @@ -1356,139 +1460,157 @@ def __init__(self): self.calls = [] def _set_tags_bytes(self, bytes): - self.calls.append(('set_tags_bytes', bytes)) + self.calls.append(("set_tags_bytes", bytes)) + real_branch = StubRealBranch() branch._real_branch = real_branch - branch._set_tags_bytes(b'tags bytes') + branch._set_tags_bytes(b"tags bytes") # Call a second time, to exercise the 'remote version already inferred' # code path. 
- branch._set_tags_bytes(b'tags bytes') + branch._set_tags_bytes(b"tags bytes") self.assertFinished(client) - self.assertEqual( - [('set_tags_bytes', b'tags bytes')] * 2, real_branch.calls) + self.assertEqual([("set_tags_bytes", b"tags bytes")] * 2, real_branch.calls) class TestBranchHeadsToFetch(RemoteBranchTestCase): - def test_uses_last_revision_info_and_tags_by_default(self): transport = MemoryTransport() client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'quack/',), - b'error', (b'NotStacked',)) + b"Branch.get_stacked_on_url", (b"quack/",), b"error", (b"NotStacked",) + ) client.add_expected_call( - b'Branch.last_revision_info', (b'quack/',), - b'success', (b'ok', b'1', b'rev-tip')) + b"Branch.last_revision_info", + (b"quack/",), + b"success", + (b"ok", b"1", b"rev-tip"), + ) client.add_expected_call( - b'Branch.get_config_file', (b'quack/',), - b'success', (b'ok',), b'') - transport.mkdir('quack') - transport = transport.clone('quack') + b"Branch.get_config_file", (b"quack/",), b"success", (b"ok",), b"" + ) + transport.mkdir("quack") + transport = transport.clone("quack") branch = self.make_remote_branch(transport, client) result = branch.heads_to_fetch() self.assertFinished(client) - self.assertEqual(({b'rev-tip'}, set()), result) + self.assertEqual(({b"rev-tip"}, set()), result) def test_uses_last_revision_info_and_tags_when_set(self): transport = MemoryTransport() client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'quack/',), - b'error', (b'NotStacked',)) + b"Branch.get_stacked_on_url", (b"quack/",), b"error", (b"NotStacked",) + ) client.add_expected_call( - b'Branch.last_revision_info', (b'quack/',), - b'success', (b'ok', b'1', b'rev-tip')) + b"Branch.last_revision_info", + (b"quack/",), + b"success", + (b"ok", b"1", b"rev-tip"), + ) client.add_expected_call( - b'Branch.get_config_file', (b'quack/',), - b'success', (b'ok',), b'branch.fetch_tags = True') + b"Branch.get_config_file", + (b"quack/",), + b"success", + (b"ok",), + b"branch.fetch_tags = True", + ) # XXX: this will break if the default format's serialization of tags # changes, or if the RPC for fetching tags changes from get_tags_bytes. 
client.add_expected_call( - b'Branch.get_tags_bytes', (b'quack/',), - b'success', (b'd5:tag-17:rev-foo5:tag-27:rev-bare',)) - transport.mkdir('quack') - transport = transport.clone('quack') + b"Branch.get_tags_bytes", + (b"quack/",), + b"success", + (b"d5:tag-17:rev-foo5:tag-27:rev-bare",), + ) + transport.mkdir("quack") + transport = transport.clone("quack") branch = self.make_remote_branch(transport, client) result = branch.heads_to_fetch() self.assertFinished(client) - self.assertEqual( - ({b'rev-tip'}, {b'rev-foo', b'rev-bar'}), result) + self.assertEqual(({b"rev-tip"}, {b"rev-foo", b"rev-bar"}), result) def test_uses_rpc_for_formats_with_non_default_heads_to_fetch(self): transport = MemoryTransport() client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'quack/',), - b'error', (b'NotStacked',)) + b"Branch.get_stacked_on_url", (b"quack/",), b"error", (b"NotStacked",) + ) client.add_expected_call( - b'Branch.heads_to_fetch', (b'quack/',), - b'success', ([b'tip'], [b'tagged-1', b'tagged-2'])) - transport.mkdir('quack') - transport = transport.clone('quack') + b"Branch.heads_to_fetch", + (b"quack/",), + b"success", + ([b"tip"], [b"tagged-1", b"tagged-2"]), + ) + transport.mkdir("quack") + transport = transport.clone("quack") branch = self.make_remote_branch(transport, client) branch._format._use_default_local_heads_to_fetch = lambda: False result = branch.heads_to_fetch() self.assertFinished(client) - self.assertEqual(({b'tip'}, {b'tagged-1', b'tagged-2'}), result) + self.assertEqual(({b"tip"}, {b"tagged-1", b"tagged-2"}), result) def make_branch_with_tags(self): self.setup_smart_server_with_call_log() # Make a branch with a single revision. - builder = self.make_branch_builder('foo') + builder = self.make_branch_builder("foo") builder.start_series() - builder.build_snapshot(None, [ - ('add', ('', b'root-id', 'directory', ''))], - revision_id=b'tip') + builder.build_snapshot( + None, [("add", ("", b"root-id", "directory", ""))], revision_id=b"tip" + ) builder.finish_series() branch = builder.get_branch() # Add two tags to that branch - branch.tags.set_tag('tag-1', b'rev-1') - branch.tags.set_tag('tag-2', b'rev-2') + branch.tags.set_tag("tag-1", b"rev-1") + branch.tags.set_tag("tag-2", b"rev-2") return branch def test_backwards_compatible(self): br = self.make_branch_with_tags() - br.get_config_stack().set('branch.fetch_tags', True) + br.get_config_stack().set("branch.fetch_tags", True) self.addCleanup(br.lock_read().unlock) # Disable the heads_to_fetch verb - verb = b'Branch.heads_to_fetch' + verb = b"Branch.heads_to_fetch" self.disable_verb(verb) self.reset_smart_call_log() result = br.heads_to_fetch() - self.assertEqual(({b'tip'}, {b'rev-1', b'rev-2'}), result) + self.assertEqual(({b"tip"}, {b"rev-1", b"rev-2"}), result) self.assertEqual( - [b'Branch.last_revision_info', b'Branch.get_tags_bytes'], - [call.call.method for call in self.hpss_calls]) + [b"Branch.last_revision_info", b"Branch.get_tags_bytes"], + [call.call.method for call in self.hpss_calls], + ) def test_backwards_compatible_no_tags(self): br = self.make_branch_with_tags() - br.get_config_stack().set('branch.fetch_tags', False) + br.get_config_stack().set("branch.fetch_tags", False) self.addCleanup(br.lock_read().unlock) # Disable the heads_to_fetch verb - verb = b'Branch.heads_to_fetch' + verb = b"Branch.heads_to_fetch" self.disable_verb(verb) self.reset_smart_call_log() result = br.heads_to_fetch() - self.assertEqual(({b'tip'}, set()), result) + self.assertEqual(({b"tip"}, set()), 
result) self.assertEqual( - [b'Branch.last_revision_info'], - [call.call.method for call in self.hpss_calls]) + [b"Branch.last_revision_info"], + [call.call.method for call in self.hpss_calls], + ) class TestBranchLastRevisionInfo(RemoteBranchTestCase): - def test_empty_branch(self): # in an empty branch we decode the response properly transport = MemoryTransport() client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'quack/',), - b'error', (b'NotStacked',)) + b"Branch.get_stacked_on_url", (b"quack/",), b"error", (b"NotStacked",) + ) client.add_expected_call( - b'Branch.last_revision_info', (b'quack/',), - b'success', (b'ok', b'0', b'null:')) - transport.mkdir('quack') - transport = transport.clone('quack') + b"Branch.last_revision_info", + (b"quack/",), + b"success", + (b"ok", b"0", b"null:"), + ) + transport.mkdir("quack") + transport = transport.clone("quack") branch = self.make_remote_branch(transport, client) result = branch.last_revision_info() self.assertFinished(client) @@ -1496,17 +1618,17 @@ def test_empty_branch(self): def test_non_empty_branch(self): # in a non-empty branch we also decode the response properly - revid = '\xc8'.encode() + revid = "\xc8".encode() transport = MemoryTransport() client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'kwaak/',), - b'error', (b'NotStacked',)) + b"Branch.get_stacked_on_url", (b"kwaak/",), b"error", (b"NotStacked",) + ) client.add_expected_call( - b'Branch.last_revision_info', (b'kwaak/',), - b'success', (b'ok', b'2', revid)) - transport.mkdir('kwaak') - transport = transport.clone('kwaak') + b"Branch.last_revision_info", (b"kwaak/",), b"success", (b"ok", b"2", revid) + ) + transport.mkdir("kwaak") + transport = transport.clone("kwaak") branch = self.make_remote_branch(transport, client) result = branch.last_revision_info() self.assertEqual((2, revid), result) @@ -1521,90 +1643,129 @@ def test_get_stacked_on_invalid_url(self): # really isn't anything we can do to be 100% sure that the server # doesn't just open in - this test probably needs to be rewritten using # a spawn()ed server. - stacked_branch = self.make_branch('stacked', format='1.9') - self.make_branch('base', format='1.9') - vfs_url = self.get_vfs_only_url('base') + stacked_branch = self.make_branch("stacked", format="1.9") + self.make_branch("base", format="1.9") + vfs_url = self.get_vfs_only_url("base") stacked_branch.set_stacked_on_url(vfs_url) transport = stacked_branch.controldir.root_transport client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'stacked/',), - b'success', (b'ok', vfs_url.encode('utf-8'))) + b"Branch.get_stacked_on_url", + (b"stacked/",), + b"success", + (b"ok", vfs_url.encode("utf-8")), + ) # XXX: Multiple calls are bad, this second call documents what is # today. 
client.add_expected_call( - b'Branch.get_stacked_on_url', (b'stacked/',), - b'success', (b'ok', vfs_url.encode('utf-8'))) - bzrdir = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=client) + b"Branch.get_stacked_on_url", + (b"stacked/",), + b"success", + (b"ok", vfs_url.encode("utf-8")), + ) + bzrdir = RemoteBzrDir(transport, RemoteBzrDirFormat(), _client=client) repo_fmt = remote.RemoteRepositoryFormat() repo_fmt._custom_format = stacked_branch.repository._format - branch = RemoteBranch(bzrdir, RemoteRepository(bzrdir, repo_fmt), - _client=client) + branch = RemoteBranch( + bzrdir, RemoteRepository(bzrdir, repo_fmt), _client=client + ) result = branch.get_stacked_on_url() self.assertEqual(vfs_url, result) def test_backwards_compatible(self): # like with bzr1.6 with no Branch.get_stacked_on_url rpc - self.make_branch('base', format='1.6') - stacked_branch = self.make_branch('stacked', format='1.6') - stacked_branch.set_stacked_on_url('../base') + self.make_branch("base", format="1.6") + stacked_branch = self.make_branch("stacked", format="1.6") + stacked_branch.set_stacked_on_url("../base") client = FakeClient(self.get_url()) branch_network_name = self.get_branch_format().network_name() client.add_expected_call( - b'BzrDir.open_branchV3', (b'stacked/',), - b'success', (b'branch', branch_network_name)) + b"BzrDir.open_branchV3", + (b"stacked/",), + b"success", + (b"branch", branch_network_name), + ) client.add_expected_call( - b'BzrDir.find_repositoryV3', (b'stacked/',), - b'success', (b'ok', b'', b'no', b'no', b'yes', - stacked_branch.repository._format.network_name())) + b"BzrDir.find_repositoryV3", + (b"stacked/",), + b"success", + ( + b"ok", + b"", + b"no", + b"no", + b"yes", + stacked_branch.repository._format.network_name(), + ), + ) # called twice, once from constructor and then again by us client.add_expected_call( - b'Branch.get_stacked_on_url', (b'stacked/',), - b'unknown', (b'Branch.get_stacked_on_url',)) + b"Branch.get_stacked_on_url", + (b"stacked/",), + b"unknown", + (b"Branch.get_stacked_on_url",), + ) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'stacked/',), - b'unknown', (b'Branch.get_stacked_on_url',)) + b"Branch.get_stacked_on_url", + (b"stacked/",), + b"unknown", + (b"Branch.get_stacked_on_url",), + ) # this will also do vfs access, but that goes direct to the transport # and isn't seen by the FakeClient. 
- bzrdir = RemoteBzrDir(self.get_transport('stacked'), - RemoteBzrDirFormat(), _client=client) + bzrdir = RemoteBzrDir( + self.get_transport("stacked"), RemoteBzrDirFormat(), _client=client + ) branch = bzrdir.open_branch() result = branch.get_stacked_on_url() - self.assertEqual('../base', result) + self.assertEqual("../base", result) self.assertFinished(client) # it's in the fallback list both for the RemoteRepository and its vfs # repository self.assertEqual(1, len(branch.repository._fallback_repositories)) - self.assertEqual(1, - len(branch.repository._real_repository._fallback_repositories)) + self.assertEqual( + 1, len(branch.repository._real_repository._fallback_repositories) + ) def test_get_stacked_on_real_branch(self): - self.make_branch('base') - stacked_branch = self.make_branch('stacked') - stacked_branch.set_stacked_on_url('../base') + self.make_branch("base") + stacked_branch = self.make_branch("stacked") + stacked_branch.set_stacked_on_url("../base") reference_format = self.get_repo_format() network_name = reference_format.network_name() client = FakeClient(self.get_url()) branch_network_name = self.get_branch_format().network_name() client.add_expected_call( - b'BzrDir.open_branchV3', (b'stacked/',), - b'success', (b'branch', branch_network_name)) + b"BzrDir.open_branchV3", + (b"stacked/",), + b"success", + (b"branch", branch_network_name), + ) client.add_expected_call( - b'BzrDir.find_repositoryV3', (b'stacked/',), - b'success', (b'ok', b'', b'yes', b'no', b'yes', network_name)) + b"BzrDir.find_repositoryV3", + (b"stacked/",), + b"success", + (b"ok", b"", b"yes", b"no", b"yes", network_name), + ) # called twice, once from constructor and then again by us client.add_expected_call( - b'Branch.get_stacked_on_url', (b'stacked/',), - b'success', (b'ok', b'../base')) + b"Branch.get_stacked_on_url", + (b"stacked/",), + b"success", + (b"ok", b"../base"), + ) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'stacked/',), - b'success', (b'ok', b'../base')) - bzrdir = RemoteBzrDir(self.get_transport('stacked'), - RemoteBzrDirFormat(), _client=client) + b"Branch.get_stacked_on_url", + (b"stacked/",), + b"success", + (b"ok", b"../base"), + ) + bzrdir = RemoteBzrDir( + self.get_transport("stacked"), RemoteBzrDirFormat(), _client=client + ) branch = bzrdir.open_branch() result = branch.get_stacked_on_url() - self.assertEqual('../base', result) + self.assertEqual("../base", result) self.assertFinished(client) # it's in the fallback list both for the RemoteRepository. self.assertEqual(1, len(branch.repository._fallback_repositories)) @@ -1613,32 +1774,46 @@ def test_get_stacked_on_real_branch(self): class TestBranchSetLastRevision(RemoteBranchTestCase): - def test_set_empty(self): # _set_last_revision_info('null:') is translated to calling # Branch.set_last_revision(path, '') on the wire. 
transport = MemoryTransport() - transport.mkdir('branch') - transport = transport.clone('branch') + transport.mkdir("branch") + transport = transport.clone("branch") client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'branch/',), - b'error', (b'NotStacked',)) + b"Branch.get_stacked_on_url", (b"branch/",), b"error", (b"NotStacked",) + ) client.add_expected_call( - b'Branch.lock_write', (b'branch/', b'', b''), - b'success', (b'ok', b'branch token', b'repo token')) + b"Branch.lock_write", + (b"branch/", b"", b""), + b"success", + (b"ok", b"branch token", b"repo token"), + ) client.add_expected_call( - b'Branch.last_revision_info', - (b'branch/',), - b'success', (b'ok', b'0', b'null:')) + b"Branch.last_revision_info", + (b"branch/",), + b"success", + (b"ok", b"0", b"null:"), + ) client.add_expected_call( - b'Branch.set_last_revision', (b'branch/', - b'branch token', b'repo token', b'null:',), - b'success', (b'ok',)) + b"Branch.set_last_revision", + ( + b"branch/", + b"branch token", + b"repo token", + b"null:", + ), + b"success", + (b"ok",), + ) client.add_expected_call( - b'Branch.unlock', (b'branch/', b'branch token', b'repo token'), - b'success', (b'ok',)) + b"Branch.unlock", + (b"branch/", b"branch token", b"repo token"), + b"success", + (b"ok",), + ) branch = self.make_remote_branch(transport, client) branch.lock_write() result = branch._set_last_revision(NULL_REVISION) @@ -1650,71 +1825,100 @@ def test_set_nonempty(self): # set_last_revision_info(N, rev-idN) is translated to calling # Branch.set_last_revision(path, rev-idN) on the wire. transport = MemoryTransport() - transport.mkdir('branch') - transport = transport.clone('branch') + transport.mkdir("branch") + transport = transport.clone("branch") client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'branch/',), - b'error', (b'NotStacked',)) - client.add_expected_call( - b'Branch.lock_write', (b'branch/', b'', b''), - b'success', (b'ok', b'branch token', b'repo token')) + b"Branch.get_stacked_on_url", (b"branch/",), b"error", (b"NotStacked",) + ) client.add_expected_call( - b'Branch.last_revision_info', - (b'branch/',), - b'success', (b'ok', b'0', b'null:')) - lines = [b'rev-id2'] - encoded_body = bz2.compress(b'\n'.join(lines)) - client.add_success_response_with_body(encoded_body, b'ok') + b"Branch.lock_write", + (b"branch/", b"", b""), + b"success", + (b"ok", b"branch token", b"repo token"), + ) client.add_expected_call( - b'Branch.set_last_revision', (b'branch/', - b'branch token', b'repo token', b'rev-id2',), - b'success', (b'ok',)) + b"Branch.last_revision_info", + (b"branch/",), + b"success", + (b"ok", b"0", b"null:"), + ) + lines = [b"rev-id2"] + encoded_body = bz2.compress(b"\n".join(lines)) + client.add_success_response_with_body(encoded_body, b"ok") + client.add_expected_call( + b"Branch.set_last_revision", + ( + b"branch/", + b"branch token", + b"repo token", + b"rev-id2", + ), + b"success", + (b"ok",), + ) client.add_expected_call( - b'Branch.unlock', (b'branch/', b'branch token', b'repo token'), - b'success', (b'ok',)) + b"Branch.unlock", + (b"branch/", b"branch token", b"repo token"), + b"success", + (b"ok",), + ) branch = self.make_remote_branch(transport, client) # Lock the branch, reset the record of remote calls. 
branch.lock_write() - result = branch._set_last_revision(b'rev-id2') + result = branch._set_last_revision(b"rev-id2") branch.unlock() self.assertEqual(None, result) self.assertFinished(client) def test_no_such_revision(self): transport = MemoryTransport() - transport.mkdir('branch') - transport = transport.clone('branch') + transport.mkdir("branch") + transport = transport.clone("branch") # A response of 'NoSuchRevision' is translated into an exception. client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'branch/',), - b'error', (b'NotStacked',)) + b"Branch.get_stacked_on_url", (b"branch/",), b"error", (b"NotStacked",) + ) client.add_expected_call( - b'Branch.lock_write', (b'branch/', b'', b''), - b'success', (b'ok', b'branch token', b'repo token')) + b"Branch.lock_write", + (b"branch/", b"", b""), + b"success", + (b"ok", b"branch token", b"repo token"), + ) client.add_expected_call( - b'Branch.last_revision_info', - (b'branch/',), - b'success', (b'ok', b'0', b'null:')) + b"Branch.last_revision_info", + (b"branch/",), + b"success", + (b"ok", b"0", b"null:"), + ) # get_graph calls to construct the revision history, for the set_rh # hook - lines = [b'rev-id'] - encoded_body = bz2.compress(b'\n'.join(lines)) - client.add_success_response_with_body(encoded_body, b'ok') - client.add_expected_call( - b'Branch.set_last_revision', (b'branch/', - b'branch token', b'repo token', b'rev-id',), - b'error', (b'NoSuchRevision', b'rev-id')) + lines = [b"rev-id"] + encoded_body = bz2.compress(b"\n".join(lines)) + client.add_success_response_with_body(encoded_body, b"ok") + client.add_expected_call( + b"Branch.set_last_revision", + ( + b"branch/", + b"branch token", + b"repo token", + b"rev-id", + ), + b"error", + (b"NoSuchRevision", b"rev-id"), + ) client.add_expected_call( - b'Branch.unlock', (b'branch/', b'branch token', b'repo token'), - b'success', (b'ok',)) + b"Branch.unlock", + (b"branch/", b"branch token", b"repo token"), + b"success", + (b"ok",), + ) branch = self.make_remote_branch(transport, client) branch.lock_write() - self.assertRaises( - errors.NoSuchRevision, branch._set_last_revision, b'rev-id') + self.assertRaises(errors.NoSuchRevision, branch._set_last_revision, b"rev-id") branch.unlock() self.assertFinished(client) @@ -1723,38 +1927,53 @@ def test_tip_change_rejected(self): be raised. 
""" transport = MemoryTransport() - transport.mkdir('branch') - transport = transport.clone('branch') + transport.mkdir("branch") + transport = transport.clone("branch") client = FakeClient(transport.base) - rejection_msg_unicode = 'rejection message\N{INTERROBANG}' - rejection_msg_utf8 = rejection_msg_unicode.encode('utf8') - client.add_expected_call( - b'Branch.get_stacked_on_url', (b'branch/',), - b'error', (b'NotStacked',)) + rejection_msg_unicode = "rejection message\N{INTERROBANG}" + rejection_msg_utf8 = rejection_msg_unicode.encode("utf8") client.add_expected_call( - b'Branch.lock_write', (b'branch/', b'', b''), - b'success', (b'ok', b'branch token', b'repo token')) + b"Branch.get_stacked_on_url", (b"branch/",), b"error", (b"NotStacked",) + ) client.add_expected_call( - b'Branch.last_revision_info', - (b'branch/',), - b'success', (b'ok', b'0', b'null:')) - lines = [b'rev-id'] - encoded_body = bz2.compress(b'\n'.join(lines)) - client.add_success_response_with_body(encoded_body, b'ok') + b"Branch.lock_write", + (b"branch/", b"", b""), + b"success", + (b"ok", b"branch token", b"repo token"), + ) client.add_expected_call( - b'Branch.set_last_revision', (b'branch/', - b'branch token', b'repo token', b'rev-id',), - b'error', (b'TipChangeRejected', rejection_msg_utf8)) + b"Branch.last_revision_info", + (b"branch/",), + b"success", + (b"ok", b"0", b"null:"), + ) + lines = [b"rev-id"] + encoded_body = bz2.compress(b"\n".join(lines)) + client.add_success_response_with_body(encoded_body, b"ok") + client.add_expected_call( + b"Branch.set_last_revision", + ( + b"branch/", + b"branch token", + b"repo token", + b"rev-id", + ), + b"error", + (b"TipChangeRejected", rejection_msg_utf8), + ) client.add_expected_call( - b'Branch.unlock', (b'branch/', b'branch token', b'repo token'), - b'success', (b'ok',)) + b"Branch.unlock", + (b"branch/", b"branch token", b"repo token"), + b"success", + (b"ok",), + ) branch = self.make_remote_branch(transport, client) branch.lock_write() # The 'TipChangeRejected' error response triggered by calling # set_last_revision_info causes a TipChangeRejected exception. err = self.assertRaises( - errors.TipChangeRejected, - branch._set_last_revision, b'rev-id') + errors.TipChangeRejected, branch._set_last_revision, b"rev-id" + ) # The UTF-8 message from the response has been decoded into a unicode # object. self.assertIsInstance(err.msg, str) @@ -1764,52 +1983,62 @@ def test_tip_change_rejected(self): class TestBranchSetLastRevisionInfo(RemoteBranchTestCase): - def test_set_last_revision_info(self): # set_last_revision_info(num, b'rev-id') is translated to calling # Branch.set_last_revision_info(num, 'rev-id') on the wire. 
transport = MemoryTransport() - transport.mkdir('branch') - transport = transport.clone('branch') + transport.mkdir("branch") + transport = transport.clone("branch") client = FakeClient(transport.base) # get_stacked_on_url - client.add_error_response(b'NotStacked') + client.add_error_response(b"NotStacked") # lock_write - client.add_success_response(b'ok', b'branch token', b'repo token') + client.add_success_response(b"ok", b"branch token", b"repo token") # query the current revision - client.add_success_response(b'ok', b'0', b'null:') + client.add_success_response(b"ok", b"0", b"null:") # set_last_revision - client.add_success_response(b'ok') + client.add_success_response(b"ok") # unlock - client.add_success_response(b'ok') + client.add_success_response(b"ok") branch = self.make_remote_branch(transport, client) # Lock the branch, reset the record of remote calls. branch.lock_write() client._calls = [] - result = branch.set_last_revision_info(1234, b'a-revision-id') + result = branch.set_last_revision_info(1234, b"a-revision-id") self.assertEqual( - [('call', b'Branch.last_revision_info', (b'branch/',)), - ('call', b'Branch.set_last_revision_info', - (b'branch/', b'branch token', b'repo token', - b'1234', b'a-revision-id'))], - client._calls) + [ + ("call", b"Branch.last_revision_info", (b"branch/",)), + ( + "call", + b"Branch.set_last_revision_info", + ( + b"branch/", + b"branch token", + b"repo token", + b"1234", + b"a-revision-id", + ), + ), + ], + client._calls, + ) self.assertEqual(None, result) def test_no_such_revision(self): # A response of 'NoSuchRevision' is translated into an exception. transport = MemoryTransport() - transport.mkdir('branch') - transport = transport.clone('branch') + transport.mkdir("branch") + transport = transport.clone("branch") client = FakeClient(transport.base) # get_stacked_on_url - client.add_error_response(b'NotStacked') + client.add_error_response(b"NotStacked") # lock_write - client.add_success_response(b'ok', b'branch token', b'repo token') + client.add_success_response(b"ok", b"branch token", b"repo token") # set_last_revision - client.add_error_response(b'NoSuchRevision', b'revid') + client.add_error_response(b"NoSuchRevision", b"revid") # unlock - client.add_success_response(b'ok') + client.add_success_response(b"ok") branch = self.make_remote_branch(transport, client) # Lock the branch, reset the record of remote calls. @@ -1817,7 +2046,8 @@ def test_no_such_revision(self): client._calls = [] self.assertRaises( - errors.NoSuchRevision, branch.set_last_revision_info, 123, b'revid') + errors.NoSuchRevision, branch.set_last_revision_info, 123, b"revid" + ) branch.unlock() def test_backwards_compatibility(self): @@ -1833,20 +2063,30 @@ def test_backwards_compatibility(self): # First, set up our RemoteBranch with a FakeClient that raises # UnknownSmartMethod, and a StubRealBranch that logs how it is called. 
transport = MemoryTransport() - transport.mkdir('branch') - transport = transport.clone('branch') + transport.mkdir("branch") + transport = transport.clone("branch") client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'branch/',), - b'error', (b'NotStacked',)) + b"Branch.get_stacked_on_url", (b"branch/",), b"error", (b"NotStacked",) + ) client.add_expected_call( - b'Branch.last_revision_info', - (b'branch/',), - b'success', (b'ok', b'0', b'null:')) + b"Branch.last_revision_info", + (b"branch/",), + b"success", + (b"ok", b"0", b"null:"), + ) client.add_expected_call( - b'Branch.set_last_revision_info', - (b'branch/', b'branch token', b'repo token', b'1234', b'a-revision-id',), - b'unknown', b'Branch.set_last_revision_info') + b"Branch.set_last_revision_info", + ( + b"branch/", + b"branch token", + b"repo token", + b"1234", + b"a-revision-id", + ), + b"unknown", + b"Branch.set_last_revision_info", + ) branch = self.make_remote_branch(transport, client) @@ -1855,20 +2095,20 @@ def __init__(self): self.calls = [] def set_last_revision_info(self, revno, revision_id): - self.calls.append( - ('set_last_revision_info', revno, revision_id)) + self.calls.append(("set_last_revision_info", revno, revision_id)) def _clear_cached_state(self): pass + real_branch = StubRealBranch() branch._real_branch = real_branch self.lock_remote_branch(branch) # Call set_last_revision_info, and verify it behaved as expected. - branch.set_last_revision_info(1234, b'a-revision-id') + branch.set_last_revision_info(1234, b"a-revision-id") self.assertEqual( - [('set_last_revision_info', 1234, b'a-revision-id')], - real_branch.calls) + [("set_last_revision_info", 1234, b"a-revision-id")], real_branch.calls + ) self.assertFinished(client) def test_unexpected_error(self): @@ -1876,17 +2116,17 @@ def test_unexpected_error(self): # turned into an UnknownErrorFromSmartServer, which is presented as a # non-internal error to the user. transport = MemoryTransport() - transport.mkdir('branch') - transport = transport.clone('branch') + transport.mkdir("branch") + transport = transport.clone("branch") client = FakeClient(transport.base) # get_stacked_on_url - client.add_error_response(b'NotStacked') + client.add_error_response(b"NotStacked") # lock_write - client.add_success_response(b'ok', b'branch token', b'repo token') + client.add_success_response(b"ok", b"branch token", b"repo token") # set_last_revision - client.add_error_response(b'UnexpectedError') + client.add_error_response(b"UnexpectedError") # unlock - client.add_success_response(b'ok') + client.add_success_response(b"ok") branch = self.make_remote_branch(transport, client) # Lock the branch, reset the record of remote calls. @@ -1894,9 +2134,9 @@ def test_unexpected_error(self): client._calls = [] err = self.assertRaises( - UnknownErrorFromSmartServer, - branch.set_last_revision_info, 123, b'revid') - self.assertEqual((b'UnexpectedError',), err.error_tuple) + UnknownErrorFromSmartServer, branch.set_last_revision_info, 123, b"revid" + ) + self.assertEqual((b"UnexpectedError",), err.error_tuple) branch.unlock() def test_tip_change_rejected(self): @@ -1904,17 +2144,17 @@ def test_tip_change_rejected(self): be raised. 
""" transport = MemoryTransport() - transport.mkdir('branch') - transport = transport.clone('branch') + transport.mkdir("branch") + transport = transport.clone("branch") client = FakeClient(transport.base) # get_stacked_on_url - client.add_error_response(b'NotStacked') + client.add_error_response(b"NotStacked") # lock_write - client.add_success_response(b'ok', b'branch token', b'repo token') + client.add_success_response(b"ok", b"branch token", b"repo token") # set_last_revision - client.add_error_response(b'TipChangeRejected', b'rejection message') + client.add_error_response(b"TipChangeRejected", b"rejection message") # unlock - client.add_success_response(b'ok') + client.add_success_response(b"ok") branch = self.make_remote_branch(transport, client) # Lock the branch, reset the record of remote calls. @@ -1925,28 +2165,33 @@ def test_tip_change_rejected(self): # The 'TipChangeRejected' error response triggered by calling # set_last_revision_info causes a TipChangeRejected exception. err = self.assertRaises( - errors.TipChangeRejected, - branch.set_last_revision_info, 123, b'revid') - self.assertEqual('rejection message', err.msg) + errors.TipChangeRejected, branch.set_last_revision_info, 123, b"revid" + ) + self.assertEqual("rejection message", err.msg) class TestBranchGetSetConfig(RemoteBranchTestCase): - def test_get_branch_conf(self): # in an empty branch we decode the response properly client = FakeClient() client.add_expected_call( - b'Branch.get_stacked_on_url', (b'memory:///',), - b'error', (b'NotStacked',),) - client.add_success_response_with_body(b'# config file body', b'ok') + b"Branch.get_stacked_on_url", + (b"memory:///",), + b"error", + (b"NotStacked",), + ) + client.add_success_response_with_body(b"# config file body", b"ok") transport = MemoryTransport() branch = self.make_remote_branch(transport, client) config = branch.get_config() config.has_explicit_nickname() self.assertEqual( - [('call', b'Branch.get_stacked_on_url', (b'memory:///',)), - ('call_expecting_body', b'Branch.get_config_file', (b'memory:///',))], - client._calls) + [ + ("call", b"Branch.get_stacked_on_url", (b"memory:///",)), + ("call_expecting_body", b"Branch.get_config_file", (b"memory:///",)), + ], + client._calls, + ) def test_get_multi_line_branch_conf(self): # Make sure that multiple-line branch.conf files are supported @@ -1954,324 +2199,444 @@ def test_get_multi_line_branch_conf(self): # https://bugs.launchpad.net/bzr/+bug/354075 client = FakeClient() client.add_expected_call( - b'Branch.get_stacked_on_url', (b'memory:///',), - b'error', (b'NotStacked',),) - client.add_success_response_with_body(b'a = 1\nb = 2\nc = 3\n', b'ok') + b"Branch.get_stacked_on_url", + (b"memory:///",), + b"error", + (b"NotStacked",), + ) + client.add_success_response_with_body(b"a = 1\nb = 2\nc = 3\n", b"ok") transport = MemoryTransport() branch = self.make_remote_branch(transport, client) config = branch.get_config() - self.assertEqual('2', config.get_user_option('b')) + self.assertEqual("2", config.get_user_option("b")) def test_set_option(self): client = FakeClient() client.add_expected_call( - b'Branch.get_stacked_on_url', (b'memory:///',), - b'error', (b'NotStacked',),) + b"Branch.get_stacked_on_url", + (b"memory:///",), + b"error", + (b"NotStacked",), + ) client.add_expected_call( - b'Branch.lock_write', (b'memory:///', b'', b''), - b'success', (b'ok', b'branch token', b'repo token')) + b"Branch.lock_write", + (b"memory:///", b"", b""), + b"success", + (b"ok", b"branch token", b"repo token"), + ) 
client.add_expected_call( - b'Branch.set_config_option', (b'memory:///', b'branch token', - b'repo token', b'foo', b'bar', b''), - b'success', ()) + b"Branch.set_config_option", + (b"memory:///", b"branch token", b"repo token", b"foo", b"bar", b""), + b"success", + (), + ) client.add_expected_call( - b'Branch.unlock', (b'memory:///', b'branch token', b'repo token'), - b'success', (b'ok',)) + b"Branch.unlock", + (b"memory:///", b"branch token", b"repo token"), + b"success", + (b"ok",), + ) transport = MemoryTransport() branch = self.make_remote_branch(transport, client) branch.lock_write() config = branch._get_config() - config.set_option('foo', 'bar') + config.set_option("foo", "bar") branch.unlock() self.assertFinished(client) def test_set_option_with_dict(self): client = FakeClient() client.add_expected_call( - b'Branch.get_stacked_on_url', (b'memory:///',), - b'error', (b'NotStacked',),) - client.add_expected_call( - b'Branch.lock_write', (b'memory:///', b'', b''), - b'success', (b'ok', b'branch token', b'repo token')) - encoded_dict_value = b'd5:ascii1:a11:unicode \xe2\x8c\x9a3:\xe2\x80\xbde' + b"Branch.get_stacked_on_url", + (b"memory:///",), + b"error", + (b"NotStacked",), + ) client.add_expected_call( - b'Branch.set_config_option_dict', (b'memory:///', b'branch token', - b'repo token', encoded_dict_value, b'foo', b''), - b'success', ()) + b"Branch.lock_write", + (b"memory:///", b"", b""), + b"success", + (b"ok", b"branch token", b"repo token"), + ) + encoded_dict_value = b"d5:ascii1:a11:unicode \xe2\x8c\x9a3:\xe2\x80\xbde" + client.add_expected_call( + b"Branch.set_config_option_dict", + ( + b"memory:///", + b"branch token", + b"repo token", + encoded_dict_value, + b"foo", + b"", + ), + b"success", + (), + ) client.add_expected_call( - b'Branch.unlock', (b'memory:///', b'branch token', b'repo token'), - b'success', (b'ok',)) + b"Branch.unlock", + (b"memory:///", b"branch token", b"repo token"), + b"success", + (b"ok",), + ) transport = MemoryTransport() branch = self.make_remote_branch(transport, client) branch.lock_write() config = branch._get_config() - config.set_option( - {'ascii': 'a', 'unicode \N{WATCH}': '\N{INTERROBANG}'}, - 'foo') + config.set_option({"ascii": "a", "unicode \N{WATCH}": "\N{INTERROBANG}"}, "foo") branch.unlock() self.assertFinished(client) def test_set_option_with_bool(self): client = FakeClient() client.add_expected_call( - b'Branch.get_stacked_on_url', (b'memory:///',), - b'error', (b'NotStacked',),) + b"Branch.get_stacked_on_url", + (b"memory:///",), + b"error", + (b"NotStacked",), + ) client.add_expected_call( - b'Branch.lock_write', (b'memory:///', b'', b''), - b'success', (b'ok', b'branch token', b'repo token')) + b"Branch.lock_write", + (b"memory:///", b"", b""), + b"success", + (b"ok", b"branch token", b"repo token"), + ) client.add_expected_call( - b'Branch.set_config_option', (b'memory:///', b'branch token', - b'repo token', b'True', b'foo', b''), - b'success', ()) + b"Branch.set_config_option", + (b"memory:///", b"branch token", b"repo token", b"True", b"foo", b""), + b"success", + (), + ) client.add_expected_call( - b'Branch.unlock', (b'memory:///', b'branch token', b'repo token'), - b'success', (b'ok',)) + b"Branch.unlock", + (b"memory:///", b"branch token", b"repo token"), + b"success", + (b"ok",), + ) transport = MemoryTransport() branch = self.make_remote_branch(transport, client) branch.lock_write() config = branch._get_config() - config.set_option(True, 'foo') + config.set_option(True, "foo") branch.unlock() self.assertFinished(client) def 
test_backwards_compat_set_option(self): self.setup_smart_server_with_call_log() - branch = self.make_branch('.') - verb = b'Branch.set_config_option' + branch = self.make_branch(".") + verb = b"Branch.set_config_option" self.disable_verb(verb) branch.lock_write() self.addCleanup(branch.unlock) self.reset_smart_call_log() - branch._get_config().set_option('value', 'name') + branch._get_config().set_option("value", "name") self.assertLength(11, self.hpss_calls) - self.assertEqual('value', branch._get_config().get_option('name')) + self.assertEqual("value", branch._get_config().get_option("name")) def test_backwards_compat_set_option_with_dict(self): self.setup_smart_server_with_call_log() - branch = self.make_branch('.') - verb = b'Branch.set_config_option_dict' + branch = self.make_branch(".") + verb = b"Branch.set_config_option_dict" self.disable_verb(verb) branch.lock_write() self.addCleanup(branch.unlock) self.reset_smart_call_log() config = branch._get_config() - value_dict = {'ascii': 'a', 'unicode \N{WATCH}': '\N{INTERROBANG}'} - config.set_option(value_dict, 'name') + value_dict = {"ascii": "a", "unicode \N{WATCH}": "\N{INTERROBANG}"} + config.set_option(value_dict, "name") self.assertLength(11, self.hpss_calls) - self.assertEqual(value_dict, branch._get_config().get_option('name')) + self.assertEqual(value_dict, branch._get_config().get_option("name")) class TestBranchGetPutConfigStore(RemoteBranchTestCase): - def test_get_branch_conf(self): # in an empty branch we decode the response properly client = FakeClient() client.add_expected_call( - b'Branch.get_stacked_on_url', (b'memory:///',), - b'error', (b'NotStacked',),) - client.add_success_response_with_body(b'# config file body', b'ok') + b"Branch.get_stacked_on_url", + (b"memory:///",), + b"error", + (b"NotStacked",), + ) + client.add_success_response_with_body(b"# config file body", b"ok") transport = MemoryTransport() branch = self.make_remote_branch(transport, client) config = branch.get_config_stack() config.get("email") config.get("log_format") self.assertEqual( - [('call', b'Branch.get_stacked_on_url', (b'memory:///',)), - ('call_expecting_body', b'Branch.get_config_file', (b'memory:///',))], - client._calls) + [ + ("call", b"Branch.get_stacked_on_url", (b"memory:///",)), + ("call_expecting_body", b"Branch.get_config_file", (b"memory:///",)), + ], + client._calls, + ) def test_set_branch_conf(self): client = FakeClient() client.add_expected_call( - b'Branch.get_stacked_on_url', (b'memory:///',), - b'error', (b'NotStacked',),) + b"Branch.get_stacked_on_url", + (b"memory:///",), + b"error", + (b"NotStacked",), + ) client.add_expected_call( - b'Branch.lock_write', (b'memory:///', b'', b''), - b'success', (b'ok', b'branch token', b'repo token')) + b"Branch.lock_write", + (b"memory:///", b"", b""), + b"success", + (b"ok", b"branch token", b"repo token"), + ) client.add_expected_call( - b'Branch.get_config_file', (b'memory:///', ), - b'success', (b'ok', ), b"# line 1\n") + b"Branch.get_config_file", + (b"memory:///",), + b"success", + (b"ok",), + b"# line 1\n", + ) client.add_expected_call( - b'Branch.get_config_file', (b'memory:///', ), - b'success', (b'ok', ), b"# line 1\n") + b"Branch.get_config_file", + (b"memory:///",), + b"success", + (b"ok",), + b"# line 1\n", + ) client.add_expected_call( - b'Branch.put_config_file', (b'memory:///', b'branch token', - b'repo token'), - b'success', (b'ok',)) + b"Branch.put_config_file", + (b"memory:///", b"branch token", b"repo token"), + b"success", + (b"ok",), + ) 
client.add_expected_call( - b'Branch.unlock', (b'memory:///', b'branch token', b'repo token'), - b'success', (b'ok',)) + b"Branch.unlock", + (b"memory:///", b"branch token", b"repo token"), + b"success", + (b"ok",), + ) transport = MemoryTransport() branch = self.make_remote_branch(transport, client) branch.lock_write() config = branch.get_config_stack() - config.set('email', 'The Dude ') + config.set("email", "The Dude ") branch.unlock() self.assertFinished(client) self.assertEqual( - [('call', b'Branch.get_stacked_on_url', (b'memory:///',)), - ('call', b'Branch.lock_write', (b'memory:///', b'', b'')), - ('call_expecting_body', b'Branch.get_config_file', - (b'memory:///',)), - ('call_expecting_body', b'Branch.get_config_file', - (b'memory:///',)), - ('call_with_body_bytes_expecting_body', b'Branch.put_config_file', - (b'memory:///', b'branch token', b'repo token'), - b'# line 1\nemail = The Dude \n'), - ('call', b'Branch.unlock', - (b'memory:///', b'branch token', b'repo token'))], - client._calls) + [ + ("call", b"Branch.get_stacked_on_url", (b"memory:///",)), + ("call", b"Branch.lock_write", (b"memory:///", b"", b"")), + ("call_expecting_body", b"Branch.get_config_file", (b"memory:///",)), + ("call_expecting_body", b"Branch.get_config_file", (b"memory:///",)), + ( + "call_with_body_bytes_expecting_body", + b"Branch.put_config_file", + (b"memory:///", b"branch token", b"repo token"), + b"# line 1\nemail = The Dude \n", + ), + ( + "call", + b"Branch.unlock", + (b"memory:///", b"branch token", b"repo token"), + ), + ], + client._calls, + ) class TestBranchLockWrite(RemoteBranchTestCase): - def test_lock_write_unlockable(self): transport = MemoryTransport() client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'quack/',), - b'error', (b'NotStacked',),) + b"Branch.get_stacked_on_url", + (b"quack/",), + b"error", + (b"NotStacked",), + ) client.add_expected_call( - b'Branch.lock_write', (b'quack/', b'', b''), - b'error', (b'UnlockableTransport',)) - transport.mkdir('quack') - transport = transport.clone('quack') + b"Branch.lock_write", + (b"quack/", b"", b""), + b"error", + (b"UnlockableTransport",), + ) + transport.mkdir("quack") + transport = transport.clone("quack") branch = self.make_remote_branch(transport, client) self.assertRaises(errors.UnlockableTransport, branch.lock_write) self.assertFinished(client) class TestBranchRevisionIdToRevno(RemoteBranchTestCase): - def test_simple(self): transport = MemoryTransport() client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'quack/',), - b'error', (b'NotStacked',),) + b"Branch.get_stacked_on_url", + (b"quack/",), + b"error", + (b"NotStacked",), + ) client.add_expected_call( - b'Branch.revision_id_to_revno', (b'quack/', b'null:'), - b'success', (b'ok', b'0',),) + b"Branch.revision_id_to_revno", + (b"quack/", b"null:"), + b"success", + ( + b"ok", + b"0", + ), + ) client.add_expected_call( - b'Branch.revision_id_to_revno', (b'quack/', b'unknown'), - b'error', (b'NoSuchRevision', b'unknown',),) - transport.mkdir('quack') - transport = transport.clone('quack') + b"Branch.revision_id_to_revno", + (b"quack/", b"unknown"), + b"error", + ( + b"NoSuchRevision", + b"unknown", + ), + ) + transport.mkdir("quack") + transport = transport.clone("quack") branch = self.make_remote_branch(transport, client) - self.assertEqual(0, branch.revision_id_to_revno(b'null:')) - self.assertRaises(errors.NoSuchRevision, - branch.revision_id_to_revno, b'unknown') + self.assertEqual(0, 
branch.revision_id_to_revno(b"null:")) + self.assertRaises( + errors.NoSuchRevision, branch.revision_id_to_revno, b"unknown" + ) self.assertFinished(client) def test_dotted(self): transport = MemoryTransport() client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'quack/',), - b'error', (b'NotStacked',),) + b"Branch.get_stacked_on_url", + (b"quack/",), + b"error", + (b"NotStacked",), + ) client.add_expected_call( - b'Branch.revision_id_to_revno', (b'quack/', b'null:'), - b'success', (b'ok', b'0',),) + b"Branch.revision_id_to_revno", + (b"quack/", b"null:"), + b"success", + ( + b"ok", + b"0", + ), + ) client.add_expected_call( - b'Branch.revision_id_to_revno', (b'quack/', b'unknown'), - b'error', (b'NoSuchRevision', b'unknown',),) - transport.mkdir('quack') - transport = transport.clone('quack') + b"Branch.revision_id_to_revno", + (b"quack/", b"unknown"), + b"error", + ( + b"NoSuchRevision", + b"unknown", + ), + ) + transport.mkdir("quack") + transport = transport.clone("quack") branch = self.make_remote_branch(transport, client) - self.assertEqual((0, ), branch.revision_id_to_dotted_revno(b'null:')) - self.assertRaises(errors.NoSuchRevision, - branch.revision_id_to_dotted_revno, b'unknown') + self.assertEqual((0,), branch.revision_id_to_dotted_revno(b"null:")) + self.assertRaises( + errors.NoSuchRevision, branch.revision_id_to_dotted_revno, b"unknown" + ) self.assertFinished(client) def test_ghost_revid(self): transport = MemoryTransport() client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'quack/',), - b'error', (b'NotStacked',),) + b"Branch.get_stacked_on_url", + (b"quack/",), + b"error", + (b"NotStacked",), + ) # Some older versions of bzr/brz didn't explicitly return # GhostRevisionsHaveNoRevno client.add_expected_call( - b'Branch.revision_id_to_revno', (b'quack/', b'revid'), - b'error', (b'error', b'GhostRevisionsHaveNoRevno', - b'The reivison {revid} was not found because there was ' - b'a ghost at {ghost-revid}')) + b"Branch.revision_id_to_revno", + (b"quack/", b"revid"), + b"error", + ( + b"error", + b"GhostRevisionsHaveNoRevno", + b"The reivison {revid} was not found because there was " + b"a ghost at {ghost-revid}", + ), + ) client.add_expected_call( - b'Branch.revision_id_to_revno', (b'quack/', b'revid'), - b'error', (b'GhostRevisionsHaveNoRevno', b'revid', b'ghost-revid',)) - transport.mkdir('quack') - transport = transport.clone('quack') + b"Branch.revision_id_to_revno", + (b"quack/", b"revid"), + b"error", + ( + b"GhostRevisionsHaveNoRevno", + b"revid", + b"ghost-revid", + ), + ) + transport.mkdir("quack") + transport = transport.clone("quack") branch = self.make_remote_branch(transport, client) - self.assertRaises(errors.GhostRevisionsHaveNoRevno, - branch.revision_id_to_dotted_revno, b'revid') - self.assertRaises(errors.GhostRevisionsHaveNoRevno, - branch.revision_id_to_dotted_revno, b'revid') + self.assertRaises( + errors.GhostRevisionsHaveNoRevno, + branch.revision_id_to_dotted_revno, + b"revid", + ) + self.assertRaises( + errors.GhostRevisionsHaveNoRevno, + branch.revision_id_to_dotted_revno, + b"revid", + ) self.assertFinished(client) def test_dotted_no_smart_verb(self): self.setup_smart_server_with_call_log() - branch = self.make_branch('.') - self.disable_verb(b'Branch.revision_id_to_revno') + branch = self.make_branch(".") + self.disable_verb(b"Branch.revision_id_to_revno") self.reset_smart_call_log() - self.assertEqual((0, ), - branch.revision_id_to_dotted_revno(b'null:')) + 
self.assertEqual((0,), branch.revision_id_to_dotted_revno(b"null:")) self.assertLength(8, self.hpss_calls) class TestBzrDirGetSetConfig(RemoteBzrDirTestCase): - def test__get_config(self): client = FakeClient() - client.add_success_response_with_body(b'default_stack_on = /\n', b'ok') + client.add_success_response_with_body(b"default_stack_on = /\n", b"ok") transport = MemoryTransport() bzrdir = self.make_remote_bzrdir(transport, client) config = bzrdir.get_config() - self.assertEqual('/', config.get_default_stack_on()) + self.assertEqual("/", config.get_default_stack_on()) self.assertEqual( - [('call_expecting_body', b'BzrDir.get_config_file', - (b'memory:///',))], - client._calls) + [("call_expecting_body", b"BzrDir.get_config_file", (b"memory:///",))], + client._calls, + ) def test_set_option_uses_vfs(self): self.setup_smart_server_with_call_log() - bzrdir = self.make_controldir('.') + bzrdir = self.make_controldir(".") self.reset_smart_call_log() config = bzrdir.get_config() - config.set_default_stack_on('/') + config.set_default_stack_on("/") self.assertLength(4, self.hpss_calls) def test_backwards_compat_get_option(self): self.setup_smart_server_with_call_log() - bzrdir = self.make_controldir('.') - verb = b'BzrDir.get_config_file' + bzrdir = self.make_controldir(".") + verb = b"BzrDir.get_config_file" self.disable_verb(verb) self.reset_smart_call_log() - self.assertEqual(None, - bzrdir._get_config().get_option('default_stack_on')) + self.assertEqual(None, bzrdir._get_config().get_option("default_stack_on")) self.assertLength(4, self.hpss_calls) class TestTransportIsReadonly(tests.TestCase): - def test_true(self): client = FakeClient() - client.add_success_response(b'yes') - transport = RemoteTransport('bzr://example.com/', medium=False, - _client=client) + client.add_success_response(b"yes") + transport = RemoteTransport("bzr://example.com/", medium=False, _client=client) self.assertEqual(True, transport.is_readonly()) - self.assertEqual( - [('call', b'Transport.is_readonly', ())], - client._calls) + self.assertEqual([("call", b"Transport.is_readonly", ())], client._calls) def test_false(self): client = FakeClient() - client.add_success_response(b'no') - transport = RemoteTransport('bzr://example.com/', medium=False, - _client=client) + client.add_success_response(b"no") + transport = RemoteTransport("bzr://example.com/", medium=False, _client=client) self.assertEqual(False, transport.is_readonly()) - self.assertEqual( - [('call', b'Transport.is_readonly', ())], - client._calls) + self.assertEqual([("call", b"Transport.is_readonly", ())], client._calls) def test_error_from_old_server(self): """Bzr 0.15 and earlier servers don't recognise the is_readonly verb. @@ -2281,43 +2646,35 @@ def test_error_from_old_server(self): underlying filesystem could be readonly anyway). 
""" client = FakeClient() - client.add_unknown_method_response(b'Transport.is_readonly') - transport = RemoteTransport('bzr://example.com/', medium=False, - _client=client) + client.add_unknown_method_response(b"Transport.is_readonly") + transport = RemoteTransport("bzr://example.com/", medium=False, _client=client) self.assertEqual(False, transport.is_readonly()) - self.assertEqual( - [('call', b'Transport.is_readonly', ())], - client._calls) + self.assertEqual([("call", b"Transport.is_readonly", ())], client._calls) class TestTransportMkdir(tests.TestCase): - def test_permissiondenied(self): client = FakeClient() - client.add_error_response( - b'PermissionDenied', b'remote path', b'extra') - transport = RemoteTransport('bzr://example.com/', medium=False, - _client=client) - exc = self.assertRaises( - errors.PermissionDenied, transport.mkdir, 'client path') - expected_error = errors.PermissionDenied('/client path', 'extra') + client.add_error_response(b"PermissionDenied", b"remote path", b"extra") + transport = RemoteTransport("bzr://example.com/", medium=False, _client=client) + exc = self.assertRaises(errors.PermissionDenied, transport.mkdir, "client path") + expected_error = errors.PermissionDenied("/client path", "extra") self.assertEqual(expected_error, exc) class TestRemoteSSHTransportAuthentication(tests.TestCaseInTempDir): - def test_defaults_to_none(self): - t = RemoteSSHTransport('bzr+ssh://example.com') + t = RemoteSSHTransport("bzr+ssh://example.com") self.assertIs(None, t._get_credentials()[0]) def test_uses_authentication_config(self): conf = config.AuthenticationConfig() conf._get_config().update( - {'bzr+sshtest': {'scheme': 'ssh', 'user': 'bar', 'host': - 'example.com'}}) + {"bzr+sshtest": {"scheme": "ssh", "user": "bar", "host": "example.com"}} + ) conf._save() - t = RemoteSSHTransport('bzr+ssh://example.com') - self.assertEqual('bar', t._get_credentials()[0]) + t = RemoteSSHTransport("bzr+ssh://example.com") + self.assertEqual("bar", t._get_credentials()[0]) class TestRemoteRepository(TestRemote): @@ -2342,28 +2699,26 @@ def setup_fake_client_and_repository(self, transport_path): client = FakeClient(transport.base) transport = transport.clone(transport_path) # we do not want bzrdir to make any remote calls - bzrdir = RemoteBzrDir(transport, RemoteBzrDirFormat(), - _client=False) + bzrdir = RemoteBzrDir(transport, RemoteBzrDirFormat(), _client=False) repo = RemoteRepository(bzrdir, None, _client=client) return repo, client def remoted_description(format): - return 'Remote: ' + format.get_format_description() + return "Remote: " + format.get_format_description() class TestBranchFormat(tests.TestCase): - def test_get_format_description(self): remote_format = RemoteBranchFormat() real_format = branch.format_registry.get_default() remote_format._network_name = real_format.network_name() - self.assertEqual(remoted_description(real_format), - remote_format.get_format_description()) + self.assertEqual( + remoted_description(real_format), remote_format.get_format_description() + ) class TestRepositoryFormat(TestRemoteRepository): - def test_fast_delta(self): true_name = groupcompress_repo.RepositoryFormat2a().network_name() true_format = RemoteRepositoryFormat() @@ -2378,217 +2733,263 @@ def test_get_format_description(self): remote_repo_format = RemoteRepositoryFormat() real_format = repository.format_registry.get_default() remote_repo_format._network_name = real_format.network_name() - self.assertEqual(remoted_description(real_format), - 
remote_repo_format.get_format_description()) + self.assertEqual( + remoted_description(real_format), + remote_repo_format.get_format_description(), + ) class TestRepositoryAllRevisionIds(TestRemoteRepository): - def test_empty(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response_with_body(b'', b'ok') + client.add_success_response_with_body(b"", b"ok") self.assertEqual([], repo.all_revision_ids()) self.assertEqual( - [('call_expecting_body', b'Repository.all_revision_ids', - (b'quack/',))], - client._calls) + [("call_expecting_body", b"Repository.all_revision_ids", (b"quack/",))], + client._calls, + ) def test_with_some_content(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response_with_body( - b'rev1\nrev2\nanotherrev\n', b'ok') + client.add_success_response_with_body(b"rev1\nrev2\nanotherrev\n", b"ok") self.assertEqual( - {b"rev1", b"rev2", b"anotherrev"}, - set(repo.all_revision_ids())) + {b"rev1", b"rev2", b"anotherrev"}, set(repo.all_revision_ids()) + ) self.assertEqual( - [('call_expecting_body', b'Repository.all_revision_ids', - (b'quack/',))], - client._calls) + [("call_expecting_body", b"Repository.all_revision_ids", (b"quack/",))], + client._calls, + ) class TestRepositoryGatherStats(TestRemoteRepository): - def test_revid_none(self): # ('ok',), body with revisions and size - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response_with_body( - b'revisions: 2\nsize: 18\n', b'ok') + client.add_success_response_with_body(b"revisions: 2\nsize: 18\n", b"ok") result = repo.gather_stats(None) self.assertEqual( - [('call_expecting_body', b'Repository.gather_stats', - (b'quack/', b'', b'no'))], - client._calls) - self.assertEqual({'revisions': 2, 'size': 18}, result) + [ + ( + "call_expecting_body", + b"Repository.gather_stats", + (b"quack/", b"", b"no"), + ) + ], + client._calls, + ) + self.assertEqual({"revisions": 2, "size": 18}, result) def test_revid_no_committers(self): # ('ok',), body without committers - body = (b'firstrev: 123456.300 3600\n' - b'latestrev: 654231.400 0\n' - b'revisions: 2\n' - b'size: 18\n') - transport_path = 'quick' - revid = '\xc8'.encode() + body = ( + b"firstrev: 123456.300 3600\n" + b"latestrev: 654231.400 0\n" + b"revisions: 2\n" + b"size: 18\n" + ) + transport_path = "quick" + revid = "\xc8".encode() repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response_with_body(body, b'ok') + client.add_success_response_with_body(body, b"ok") result = repo.gather_stats(revid) self.assertEqual( - [('call_expecting_body', b'Repository.gather_stats', - (b'quick/', revid, b'no'))], - client._calls) - self.assertEqual({'revisions': 2, 'size': 18, - 'firstrev': (123456.300, 3600), - 'latestrev': (654231.400, 0), }, - result) + [ + ( + "call_expecting_body", + b"Repository.gather_stats", + (b"quick/", revid, b"no"), + ) + ], + client._calls, + ) + self.assertEqual( + { + "revisions": 2, + "size": 18, + "firstrev": (123456.300, 3600), + "latestrev": (654231.400, 0), + }, + result, + ) def test_revid_with_committers(self): # ('ok',), body with committers - body = (b'committers: 128\n' - b'firstrev: 123456.300 3600\n' - b'latestrev: 654231.400 0\n' - b'revisions: 2\n' - b'size: 18\n') - transport_path = 'buick' - revid = 
'\xc8'.encode() + body = ( + b"committers: 128\n" + b"firstrev: 123456.300 3600\n" + b"latestrev: 654231.400 0\n" + b"revisions: 2\n" + b"size: 18\n" + ) + transport_path = "buick" + revid = "\xc8".encode() repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response_with_body(body, b'ok') + client.add_success_response_with_body(body, b"ok") result = repo.gather_stats(revid, True) self.assertEqual( - [('call_expecting_body', b'Repository.gather_stats', - (b'buick/', revid, b'yes'))], - client._calls) - self.assertEqual({'revisions': 2, 'size': 18, - 'committers': 128, - 'firstrev': (123456.300, 3600), - 'latestrev': (654231.400, 0), }, - result) + [ + ( + "call_expecting_body", + b"Repository.gather_stats", + (b"buick/", revid, b"yes"), + ) + ], + client._calls, + ) + self.assertEqual( + { + "revisions": 2, + "size": 18, + "committers": 128, + "firstrev": (123456.300, 3600), + "latestrev": (654231.400, 0), + }, + result, + ) class TestRepositoryBreakLock(TestRemoteRepository): - def test_break_lock(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response(b'ok') + client.add_success_response(b"ok") repo.break_lock() self.assertEqual( - [('call', b'Repository.break_lock', (b'quack/',))], - client._calls) + [("call", b"Repository.break_lock", (b"quack/",))], client._calls + ) class TestRepositoryGetSerializerFormat(TestRemoteRepository): - def test_get_serializer_format(self): - transport_path = 'hill' + transport_path = "hill" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response(b'ok', b'7') - self.assertEqual(b'7', repo.get_serializer_format()) + client.add_success_response(b"ok", b"7") + self.assertEqual(b"7", repo.get_serializer_format()) self.assertEqual( - [('call', b'VersionedFileRepository.get_serializer_format', - (b'hill/', ))], - client._calls) + [("call", b"VersionedFileRepository.get_serializer_format", (b"hill/",))], + client._calls, + ) class TestRepositoryReconcile(TestRemoteRepository): - def test_reconcile(self): - transport_path = 'hill' + transport_path = "hill" repo, client = self.setup_fake_client_and_repository(transport_path) - body = (b"garbage_inventories: 2\n" - b"inconsistent_parents: 3\n") + body = b"garbage_inventories: 2\n" b"inconsistent_parents: 3\n" client.add_expected_call( - b'Repository.lock_write', (b'hill/', b''), - b'success', (b'ok', b'a token')) - client.add_success_response_with_body(body, b'ok') + b"Repository.lock_write", (b"hill/", b""), b"success", (b"ok", b"a token") + ) + client.add_success_response_with_body(body, b"ok") reconciler = repo.reconcile() self.assertEqual( - [('call', b'Repository.lock_write', (b'hill/', b'')), - ('call_expecting_body', b'Repository.reconcile', - (b'hill/', b'a token'))], - client._calls) + [ + ("call", b"Repository.lock_write", (b"hill/", b"")), + ( + "call_expecting_body", + b"Repository.reconcile", + (b"hill/", b"a token"), + ), + ], + client._calls, + ) self.assertEqual(2, reconciler.garbage_inventories) self.assertEqual(3, reconciler.inconsistent_parents) class TestRepositoryGetRevisionSignatureText(TestRemoteRepository): - def test_text(self): # ('ok',), body with signature text - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response_with_body( - b'THETEXT', b'ok') + client.add_success_response_with_body(b"THETEXT", b"ok") 
self.assertEqual(b"THETEXT", repo.get_signature_text(b"revid")) self.assertEqual( - [('call_expecting_body', b'Repository.get_revision_signature_text', - (b'quack/', b'revid'))], - client._calls) + [ + ( + "call_expecting_body", + b"Repository.get_revision_signature_text", + (b"quack/", b"revid"), + ) + ], + client._calls, + ) def test_no_signature(self): - transport_path = 'quick' + transport_path = "quick" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_error_response(b'nosuchrevision', b'unknown') - self.assertRaises(errors.NoSuchRevision, repo.get_signature_text, - b"unknown") + client.add_error_response(b"nosuchrevision", b"unknown") + self.assertRaises(errors.NoSuchRevision, repo.get_signature_text, b"unknown") self.assertEqual( - [('call_expecting_body', b'Repository.get_revision_signature_text', - (b'quick/', b'unknown'))], - client._calls) + [ + ( + "call_expecting_body", + b"Repository.get_revision_signature_text", + (b"quick/", b"unknown"), + ) + ], + client._calls, + ) class TestRepositoryGetGraph(TestRemoteRepository): - def test_get_graph(self): # get_graph returns a graph with a custom parents provider. - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) graph = repo.get_graph() self.assertNotEqual(graph._parents_provider, repo) class TestRepositoryAddSignatureText(TestRemoteRepository): - def test_add_signature_text(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) client.add_expected_call( - b'Repository.lock_write', (b'quack/', b''), - b'success', (b'ok', b'a token')) + b"Repository.lock_write", (b"quack/", b""), b"success", (b"ok", b"a token") + ) client.add_expected_call( - b'Repository.start_write_group', (b'quack/', b'a token'), - b'success', (b'ok', (b'token1', ))) + b"Repository.start_write_group", + (b"quack/", b"a token"), + b"success", + (b"ok", (b"token1",)), + ) client.add_expected_call( - b'Repository.add_signature_text', (b'quack/', b'a token', b'rev1', - b'token1'), - b'success', (b'ok', ), None) + b"Repository.add_signature_text", + (b"quack/", b"a token", b"rev1", b"token1"), + b"success", + (b"ok",), + None, + ) repo.lock_write() repo.start_write_group() - self.assertIs( - None, repo.add_signature_text(b"rev1", b"every bloody emperor")) + self.assertIs(None, repo.add_signature_text(b"rev1", b"every bloody emperor")) self.assertEqual( - ('call_with_body_bytes_expecting_body', - b'Repository.add_signature_text', - (b'quack/', b'a token', b'rev1', b'token1'), - b'every bloody emperor'), - client._calls[-1]) + ( + "call_with_body_bytes_expecting_body", + b"Repository.add_signature_text", + (b"quack/", b"a token", b"rev1", b"token1"), + b"every bloody emperor", + ), + client._calls[-1], + ) class TestRepositoryGetParentMap(TestRemoteRepository): - def test_get_parent_map_caching(self): # get_parent_map returns from cache until unlock() # setup a reponse with two revisions - r1 = '\u0e33'.encode() - r2 = '\u0dab'.encode() - lines = [b' '.join([r2, r1]), r1] - encoded_body = bz2.compress(b'\n'.join(lines)) + r1 = "\u0e33".encode() + r2 = "\u0dab".encode() + lines = [b" ".join([r2, r1]), r1] + encoded_body = bz2.compress(b"\n".join(lines)) - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response_with_body(encoded_body, b'ok') - client.add_success_response_with_body(encoded_body, b'ok') + 
client.add_success_response_with_body(encoded_body, b"ok") + client.add_success_response_with_body(encoded_body, b"ok") repo.lock_read() graph = repo.get_graph() parents = graph.get_parent_map([r2]) @@ -2599,11 +3000,16 @@ def test_get_parent_map_caching(self): parents = graph.get_parent_map([r1]) self.assertEqual({r1: (NULL_REVISION,)}, parents) self.assertEqual( - [('call_with_body_bytes_expecting_body', - b'Repository.get_parent_map', (b'quack/', - b'include-missing:', r2), - b'\n\n0')], - client._calls) + [ + ( + "call_with_body_bytes_expecting_body", + b"Repository.get_parent_map", + (b"quack/", b"include-missing:", r2), + b"\n\n0", + ) + ], + client._calls, + ) repo.unlock() # now we call again, and it should use the second response. repo.lock_read() @@ -2611,37 +3017,52 @@ def test_get_parent_map_caching(self): parents = graph.get_parent_map([r1]) self.assertEqual({r1: (NULL_REVISION,)}, parents) self.assertEqual( - [('call_with_body_bytes_expecting_body', - b'Repository.get_parent_map', (b'quack/', - b'include-missing:', r2), - b'\n\n0'), - ('call_with_body_bytes_expecting_body', - b'Repository.get_parent_map', (b'quack/', - b'include-missing:', r1), - b'\n\n0'), - ], - client._calls) + [ + ( + "call_with_body_bytes_expecting_body", + b"Repository.get_parent_map", + (b"quack/", b"include-missing:", r2), + b"\n\n0", + ), + ( + "call_with_body_bytes_expecting_body", + b"Repository.get_parent_map", + (b"quack/", b"include-missing:", r1), + b"\n\n0", + ), + ], + client._calls, + ) repo.unlock() def test_get_parent_map_reconnects_if_unknown_method(self): - transport_path = 'quack' - rev_id = b'revision-id' + transport_path = "quack" + rev_id = b"revision-id" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_unknown_method_response(b'Repository.get_parent_map') - client.add_success_response_with_body(rev_id, b'ok') + client.add_unknown_method_response(b"Repository.get_parent_map") + client.add_success_response_with_body(rev_id, b"ok") self.assertFalse(client._medium._is_remote_before((1, 2))) parents = repo.get_parent_map([rev_id]) self.assertEqual( - [('call_with_body_bytes_expecting_body', - b'Repository.get_parent_map', - (b'quack/', b'include-missing:', rev_id), b'\n\n0'), - ('disconnect medium',), - ('call_expecting_body', b'Repository.get_revision_graph', - (b'quack/', b''))], - client._calls) + [ + ( + "call_with_body_bytes_expecting_body", + b"Repository.get_parent_map", + (b"quack/", b"include-missing:", rev_id), + b"\n\n0", + ), + ("disconnect medium",), + ( + "call_expecting_body", + b"Repository.get_revision_graph", + (b"quack/", b""), + ), + ], + client._calls, + ) # The medium is now marked as being connected to an older server self.assertTrue(client._medium._is_remote_before((1, 2))) - self.assertEqual({rev_id: (b'null:',)}, parents) + self.assertEqual({rev_id: (b"null:",)}, parents) def test_get_parent_map_fallback_parentless_node(self): """get_parent_map falls back to get_revision_graph on old servers. 
The @@ -2654,62 +3075,68 @@ def test_get_parent_map_fallback_parentless_node(self): This is the test for https://bugs.launchpad.net/bzr/+bug/214894 """ - rev_id = b'revision-id' - transport_path = 'quack' + rev_id = b"revision-id" + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response_with_body(rev_id, b'ok') + client.add_success_response_with_body(rev_id, b"ok") client._medium._remember_remote_is_before((1, 2)) parents = repo.get_parent_map([rev_id]) self.assertEqual( - [('call_expecting_body', b'Repository.get_revision_graph', - (b'quack/', b''))], - client._calls) - self.assertEqual({rev_id: (b'null:',)}, parents) + [ + ( + "call_expecting_body", + b"Repository.get_revision_graph", + (b"quack/", b""), + ) + ], + client._calls, + ) + self.assertEqual({rev_id: (b"null:",)}, parents) def test_get_parent_map_unexpected_response(self): - repo, client = self.setup_fake_client_and_repository('path') - client.add_success_response(b'something unexpected!') + repo, client = self.setup_fake_client_and_repository("path") + client.add_success_response(b"something unexpected!") self.assertRaises( errors.UnexpectedSmartServerResponse, - repo.get_parent_map, [b'a-revision-id']) + repo.get_parent_map, + [b"a-revision-id"], + ) def test_get_parent_map_negative_caches_missing_keys(self): self.setup_smart_server_with_call_log() - repo = self.make_repository('foo') + repo = self.make_repository("foo") self.assertIsInstance(repo, RemoteRepository) repo.lock_read() self.addCleanup(repo.unlock) self.reset_smart_call_log() graph = repo.get_graph() - self.assertEqual( - {}, graph.get_parent_map([b'some-missing', b'other-missing'])) + self.assertEqual({}, graph.get_parent_map([b"some-missing", b"other-missing"])) self.assertLength(1, self.hpss_calls) # No call if we repeat this self.reset_smart_call_log() graph = repo.get_graph() - self.assertEqual( - {}, graph.get_parent_map([b'some-missing', b'other-missing'])) + self.assertEqual({}, graph.get_parent_map([b"some-missing", b"other-missing"])) self.assertLength(0, self.hpss_calls) # Asking for more unknown keys makes a request. self.reset_smart_call_log() graph = repo.get_graph() self.assertEqual( - {}, graph.get_parent_map([b'some-missing', b'other-missing', - b'more-missing'])) + {}, + graph.get_parent_map([b"some-missing", b"other-missing", b"more-missing"]), + ) self.assertLength(1, self.hpss_calls) def disableExtraResults(self): - self.overrideAttr(SmartServerRepositoryGetParentMap, - 'no_extra_results', True) + self.overrideAttr(SmartServerRepositoryGetParentMap, "no_extra_results", True) def test_null_cached_missing_and_stop_key(self): self.setup_smart_server_with_call_log() # Make a branch with a single revision. - builder = self.make_branch_builder('foo') + builder = self.make_branch_builder("foo") builder.start_series() - builder.build_snapshot(None, [ - ('add', ('', b'root-id', 'directory', ''))], - revision_id=b'first') + builder.build_snapshot( + None, [("add", ("", b"root-id", "directory", ""))], revision_id=b"first" + ) builder.finish_series() branch = builder.get_branch() repo = branch.repository @@ -2724,13 +3151,14 @@ def test_null_cached_missing_and_stop_key(self): # 'first' it will be a candidate for the stop_keys of subsequent # requests, and because b'null:' was queried but not returned it will # be cached as missing. 
- self.assertEqual({b'first': (b'null:',)}, - graph.get_parent_map([b'first', b'null:'])) + self.assertEqual( + {b"first": (b"null:",)}, graph.get_parent_map([b"first", b"null:"]) + ) # Now query for another key. This request will pass along a recipe of # start and stop keys describing the already cached results, and this # recipe's revision count must be correct (or else it will trigger an # error from the server). - self.assertEqual({}, graph.get_parent_map([b'another-key'])) + self.assertEqual({}, graph.get_parent_map([b"another-key"])) # This assertion guards against disableExtraResults silently failing to # work, thus invalidating the test. self.assertLength(2, self.hpss_calls) @@ -2739,15 +3167,14 @@ def test_get_parent_map_gets_ghosts_from_result(self): # asking for a revision should negatively cache close ghosts in its # ancestry. self.setup_smart_server_with_call_log() - tree = self.make_branch_and_memory_tree('foo') + tree = self.make_branch_and_memory_tree("foo") with tree.lock_write(): builder = treebuilder.TreeBuilder() builder.start_tree(tree) builder.build([]) builder.finish_tree() - tree.set_parent_ids([b'non-existant'], - allow_leftmost_as_ghost=True) - rev_id = tree.commit('') + tree.set_parent_ids([b"non-existant"], allow_leftmost_as_ghost=True) + rev_id = tree.commit("") tree.lock_read() self.addCleanup(tree.unlock) repo = tree.branch.repository @@ -2756,106 +3183,120 @@ def test_get_parent_map_gets_ghosts_from_result(self): repo.get_parent_map([rev_id]) self.reset_smart_call_log() # Now asking for rev_id's ghost parent should not make calls - self.assertEqual({}, repo.get_parent_map([b'non-existant'])) + self.assertEqual({}, repo.get_parent_map([b"non-existant"])) self.assertLength(0, self.hpss_calls) def test_exposes_get_cached_parent_map(self): """RemoteRepository exposes get_cached_parent_map from _unstacked_provider. 
""" - r1 = '\u0e33'.encode() - r2 = '\u0dab'.encode() - lines = [b' '.join([r2, r1]), r1] - encoded_body = bz2.compress(b'\n'.join(lines)) + r1 = "\u0e33".encode() + r2 = "\u0dab".encode() + lines = [b" ".join([r2, r1]), r1] + encoded_body = bz2.compress(b"\n".join(lines)) - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response_with_body(encoded_body, b'ok') + client.add_success_response_with_body(encoded_body, b"ok") repo.lock_read() # get_cached_parent_map should *not* trigger an RPC self.assertEqual({}, repo.get_cached_parent_map([r1])) self.assertEqual([], client._calls) self.assertEqual({r2: (r1,)}, repo.get_parent_map([r2])) - self.assertEqual({r1: (NULL_REVISION,)}, - repo.get_cached_parent_map([r1])) + self.assertEqual({r1: (NULL_REVISION,)}, repo.get_cached_parent_map([r1])) self.assertEqual( - [('call_with_body_bytes_expecting_body', - b'Repository.get_parent_map', (b'quack/', - b'include-missing:', r2), - b'\n\n0')], - client._calls) + [ + ( + "call_with_body_bytes_expecting_body", + b"Repository.get_parent_map", + (b"quack/", b"include-missing:", r2), + b"\n\n0", + ) + ], + client._calls, + ) repo.unlock() class TestGetParentMapAllowsNew(tests.TestCaseWithTransport): - def test_allows_new_revisions(self): """get_parent_map's results can be updated by commit.""" smart_server = test_server.SmartTCPServer_for_testing() self.start_server(smart_server) - self.make_branch('branch') - branch = Branch.open(smart_server.get_url() + '/branch') - tree = branch.create_checkout('tree', lightweight=True) + self.make_branch("branch") + branch = Branch.open(smart_server.get_url() + "/branch") + tree = branch.create_checkout("tree", lightweight=True) tree.lock_write() self.addCleanup(tree.unlock) graph = tree.branch.repository.get_graph() # This provides an opportunity for the missing rev-id to be cached. 
- self.assertEqual({}, graph.get_parent_map([b'rev1'])) - tree.commit('message', rev_id=b'rev1') + self.assertEqual({}, graph.get_parent_map([b"rev1"])) + tree.commit("message", rev_id=b"rev1") graph = tree.branch.repository.get_graph() - self.assertEqual({b'rev1': (b'null:',)}, - graph.get_parent_map([b'rev1'])) + self.assertEqual({b"rev1": (b"null:",)}, graph.get_parent_map([b"rev1"])) class TestRepositoryGetRevisions(TestRemoteRepository): - def test_hpss_missing_revision(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response_with_body( - b'', b'ok', b'10') - self.assertRaises(errors.NoSuchRevision, repo.get_revisions, - [b'somerev1', b'anotherrev2']) + client.add_success_response_with_body(b"", b"ok", b"10") + self.assertRaises( + errors.NoSuchRevision, repo.get_revisions, [b"somerev1", b"anotherrev2"] + ) self.assertEqual( - [('call_with_body_bytes_expecting_body', - b'Repository.iter_revisions', (b'quack/', ), - b"somerev1\nanotherrev2")], - client._calls) + [ + ( + "call_with_body_bytes_expecting_body", + b"Repository.iter_revisions", + (b"quack/",), + b"somerev1\nanotherrev2", + ) + ], + client._calls, + ) def test_hpss_get_single_revision(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) - somerev1 = Revision(b"somerev1", + somerev1 = Revision( + b"somerev1", committer="Joe Committer ", timestamp=1321828927, timezone=-60, inventory_sha1=b"691b39be74c67b1212a75fcb19c433aaed903c2b", parent_ids=[], message="Message", - properties={}) - body = zlib.compress(b''.join(revision_bencode_serializer.write_revision_to_lines( - somerev1))) + properties={}, + ) + body = zlib.compress( + b"".join(revision_bencode_serializer.write_revision_to_lines(somerev1)) + ) # Split up body into two bits to make sure the zlib compression object # gets data fed twice. 
- client.add_success_response_with_body( - [body[:10], body[10:]], b'ok', b'10') - revs = repo.get_revisions([b'somerev1']) + client.add_success_response_with_body([body[:10], body[10:]], b"ok", b"10") + revs = repo.get_revisions([b"somerev1"]) self.assertEqual(revs, [somerev1]) self.assertEqual( - [('call_with_body_bytes_expecting_body', - b'Repository.iter_revisions', - (b'quack/', ), b"somerev1")], - client._calls) + [ + ( + "call_with_body_bytes_expecting_body", + b"Repository.iter_revisions", + (b"quack/",), + b"somerev1", + ) + ], + client._calls, + ) class TestRepositoryGetRevisionGraph(TestRemoteRepository): - def test_null_revision(self): # a null revision has the predictable result {}, we should have no wire # traffic when calling it with this argument - transport_path = 'empty' + transport_path = "empty" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response(b'notused') + client.add_success_response(b"notused") # actual RemoteRepository.get_revision_graph is gone, but there's an # equivalent private method for testing result = repo._get_revision_graph(NULL_REVISION) @@ -2864,85 +3305,113 @@ def test_null_revision(self): def test_none_revision(self): # with none we want the entire graph - r1 = '\u0e33'.encode() - r2 = '\u0dab'.encode() - lines = [b' '.join([r2, r1]), r1] - encoded_body = b'\n'.join(lines) + r1 = "\u0e33".encode() + r2 = "\u0dab".encode() + lines = [b" ".join([r2, r1]), r1] + encoded_body = b"\n".join(lines) - transport_path = 'sinhala' + transport_path = "sinhala" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response_with_body(encoded_body, b'ok') + client.add_success_response_with_body(encoded_body, b"ok") # actual RemoteRepository.get_revision_graph is gone, but there's an # equivalent private method for testing result = repo._get_revision_graph(None) self.assertEqual( - [('call_expecting_body', b'Repository.get_revision_graph', - (b'sinhala/', b''))], - client._calls) - self.assertEqual({r1: (), r2: (r1, )}, result) + [ + ( + "call_expecting_body", + b"Repository.get_revision_graph", + (b"sinhala/", b""), + ) + ], + client._calls, + ) + self.assertEqual({r1: (), r2: (r1,)}, result) def test_specific_revision(self): # with a specific revision we want the graph for that # with none we want the entire graph - r11 = '\u0e33'.encode() - r12 = '\xc9'.encode() - r2 = '\u0dab'.encode() - lines = [b' '.join([r2, r11, r12]), r11, r12] - encoded_body = b'\n'.join(lines) + r11 = "\u0e33".encode() + r12 = "\xc9".encode() + r2 = "\u0dab".encode() + lines = [b" ".join([r2, r11, r12]), r11, r12] + encoded_body = b"\n".join(lines) - transport_path = 'sinhala' + transport_path = "sinhala" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response_with_body(encoded_body, b'ok') + client.add_success_response_with_body(encoded_body, b"ok") result = repo._get_revision_graph(r2) self.assertEqual( - [('call_expecting_body', b'Repository.get_revision_graph', - (b'sinhala/', r2))], - client._calls) - self.assertEqual({r11: (), r12: (), r2: (r11, r12), }, result) + [ + ( + "call_expecting_body", + b"Repository.get_revision_graph", + (b"sinhala/", r2), + ) + ], + client._calls, + ) + self.assertEqual( + { + r11: (), + r12: (), + r2: (r11, r12), + }, + result, + ) def test_no_such_revision(self): - revid = b'123' - transport_path = 'sinhala' + revid = b"123" + transport_path = "sinhala" repo, client = self.setup_fake_client_and_repository(transport_path) - 
client.add_error_response(b'nosuchrevision', revid) + client.add_error_response(b"nosuchrevision", revid) # also check that the right revision is reported in the error - self.assertRaises(errors.NoSuchRevision, - repo._get_revision_graph, revid) + self.assertRaises(errors.NoSuchRevision, repo._get_revision_graph, revid) self.assertEqual( - [('call_expecting_body', b'Repository.get_revision_graph', - (b'sinhala/', revid))], - client._calls) + [ + ( + "call_expecting_body", + b"Repository.get_revision_graph", + (b"sinhala/", revid), + ) + ], + client._calls, + ) def test_unexpected_error(self): - revid = b'123' - transport_path = 'sinhala' + revid = b"123" + transport_path = "sinhala" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_error_response(b'AnUnexpectedError') - e = self.assertRaises(UnknownErrorFromSmartServer, - repo._get_revision_graph, revid) - self.assertEqual((b'AnUnexpectedError',), e.error_tuple) + client.add_error_response(b"AnUnexpectedError") + e = self.assertRaises( + UnknownErrorFromSmartServer, repo._get_revision_graph, revid + ) + self.assertEqual((b"AnUnexpectedError",), e.error_tuple) class TestRepositoryGetRevIdForRevno(TestRemoteRepository): - def test_ok(self): - repo, client = self.setup_fake_client_and_repository('quack') + repo, client = self.setup_fake_client_and_repository("quack") client.add_expected_call( - b'Repository.get_rev_id_for_revno', (b'quack/', - 5, (42, b'rev-foo')), - b'success', (b'ok', b'rev-five')) - result = repo.get_rev_id_for_revno(5, (42, b'rev-foo')) - self.assertEqual((True, b'rev-five'), result) + b"Repository.get_rev_id_for_revno", + (b"quack/", 5, (42, b"rev-foo")), + b"success", + (b"ok", b"rev-five"), + ) + result = repo.get_rev_id_for_revno(5, (42, b"rev-foo")) + self.assertEqual((True, b"rev-five"), result) self.assertFinished(client) def test_history_incomplete(self): - repo, client = self.setup_fake_client_and_repository('quack') + repo, client = self.setup_fake_client_and_repository("quack") client.add_expected_call( - b'Repository.get_rev_id_for_revno', (b'quack/', - 5, (42, b'rev-foo')), - b'success', (b'history-incomplete', 10, b'rev-ten')) - result = repo.get_rev_id_for_revno(5, (42, b'rev-foo')) - self.assertEqual((False, (10, b'rev-ten')), result) + b"Repository.get_rev_id_for_revno", + (b"quack/", 5, (42, b"rev-foo")), + b"success", + (b"history-incomplete", 10, b"rev-ten"), + ) + result = repo.get_rev_id_for_revno(5, (42, b"rev-foo")) + self.assertEqual((False, (10, b"rev-ten")), result) self.assertFinished(client) def test_history_incomplete_with_fallback(self): @@ -2951,65 +3420,77 @@ def test_history_incomplete_with_fallback(self): """ # Make a repo with a fallback repo, both using a FakeClient. 
format = remote.response_tuple_to_repo_format( - (b'yes', b'no', b'yes', self.get_repo_format().network_name())) - repo, client = self.setup_fake_client_and_repository('quack') + (b"yes", b"no", b"yes", self.get_repo_format().network_name()) + ) + repo, client = self.setup_fake_client_and_repository("quack") repo._format = format - fallback_repo, ignored = self.setup_fake_client_and_repository( - 'fallback') + fallback_repo, ignored = self.setup_fake_client_and_repository("fallback") fallback_repo._client = client fallback_repo._format = format repo.add_fallback_repository(fallback_repo) # First the client should ask the primary repo client.add_expected_call( - b'Repository.get_rev_id_for_revno', (b'quack/', - 1, (42, b'rev-foo')), - b'success', (b'history-incomplete', 2, b'rev-two')) + b"Repository.get_rev_id_for_revno", + (b"quack/", 1, (42, b"rev-foo")), + b"success", + (b"history-incomplete", 2, b"rev-two"), + ) # Then it should ask the fallback, using revno/revid from the # history-incomplete response as the known revno/revid. client.add_expected_call( - b'Repository.get_rev_id_for_revno', ( - b'fallback/', 1, (2, b'rev-two')), - b'success', (b'ok', b'rev-one')) - result = repo.get_rev_id_for_revno(1, (42, b'rev-foo')) - self.assertEqual((True, b'rev-one'), result) + b"Repository.get_rev_id_for_revno", + (b"fallback/", 1, (2, b"rev-two")), + b"success", + (b"ok", b"rev-one"), + ) + result = repo.get_rev_id_for_revno(1, (42, b"rev-foo")) + self.assertEqual((True, b"rev-one"), result) self.assertFinished(client) def test_nosuchrevision(self): # 'nosuchrevision' is returned when the known-revid is not found in the # remote repo. The client translates that response to NoSuchRevision. - repo, client = self.setup_fake_client_and_repository('quack') + repo, client = self.setup_fake_client_and_repository("quack") client.add_expected_call( - b'Repository.get_rev_id_for_revno', (b'quack/', - 5, (42, b'rev-foo')), - b'error', (b'nosuchrevision', b'rev-foo')) + b"Repository.get_rev_id_for_revno", + (b"quack/", 5, (42, b"rev-foo")), + b"error", + (b"nosuchrevision", b"rev-foo"), + ) self.assertRaises( - errors.NoSuchRevision, - repo.get_rev_id_for_revno, 5, (42, b'rev-foo')) + errors.NoSuchRevision, repo.get_rev_id_for_revno, 5, (42, b"rev-foo") + ) self.assertFinished(client) def test_outofbounds(self): - repo, client = self.setup_fake_client_and_repository('quack') + repo, client = self.setup_fake_client_and_repository("quack") client.add_expected_call( - b'Repository.get_rev_id_for_revno', (b'quack/', - 43, (42, b'rev-foo')), - b'error', (b'revno-outofbounds', 43, 0, 42)) + b"Repository.get_rev_id_for_revno", + (b"quack/", 43, (42, b"rev-foo")), + b"error", + (b"revno-outofbounds", 43, 0, 42), + ) self.assertRaises( - errors.RevnoOutOfBounds, - repo.get_rev_id_for_revno, 43, (42, b'rev-foo')) + errors.RevnoOutOfBounds, repo.get_rev_id_for_revno, 43, (42, b"rev-foo") + ) self.assertFinished(client) def test_outofbounds_old(self): # Older versions of bzr didn't support RevnoOutOfBounds - repo, client = self.setup_fake_client_and_repository('quack') - client.add_expected_call( - b'Repository.get_rev_id_for_revno', (b'quack/', - 43, (42, b'rev-foo')), - b'error', ( - b'error', b'ValueError', - b'requested revno (43) is later than given known revno (42)')) + repo, client = self.setup_fake_client_and_repository("quack") + client.add_expected_call( + b"Repository.get_rev_id_for_revno", + (b"quack/", 43, (42, b"rev-foo")), + b"error", + ( + b"error", + b"ValueError", + b"requested revno (43) is later 
than given known revno (42)", + ), + ) self.assertRaises( - errors.RevnoOutOfBounds, - repo.get_rev_id_for_revno, 43, (42, b'rev-foo')) + errors.RevnoOutOfBounds, repo.get_rev_id_for_revno, 43, (42, b"rev-foo") + ) self.assertFinished(client) def test_branch_fallback_locking(self): @@ -3018,299 +3499,318 @@ def test_branch_fallback_locking(self): will be invoked, which will fail if the repo is unlocked. """ self.setup_smart_server_with_call_log() - tree = self.make_branch_and_memory_tree('.') + tree = self.make_branch_and_memory_tree(".") tree.lock_write() - tree.add('') - rev1 = tree.commit('First') - tree.commit('Second') + tree.add("") + rev1 = tree.commit("First") + tree.commit("Second") tree.unlock() branch = tree.branch self.assertFalse(branch.is_locked()) self.reset_smart_call_log() - verb = b'Repository.get_rev_id_for_revno' + verb = b"Repository.get_rev_id_for_revno" self.disable_verb(verb) self.assertEqual(rev1, branch.get_rev_id(1)) - self.assertLength(1, [call for call in self.hpss_calls if - call.call.method == verb]) + self.assertLength( + 1, [call for call in self.hpss_calls if call.call.method == verb] + ) class TestRepositoryHasSignatureForRevisionId(TestRemoteRepository): - def test_has_signature_for_revision_id(self): # ('yes', ) for Repository.has_signature_for_revision_id -> 'True'. - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response(b'yes') - result = repo.has_signature_for_revision_id(b'A') + client.add_success_response(b"yes") + result = repo.has_signature_for_revision_id(b"A") self.assertEqual( - [('call', b'Repository.has_signature_for_revision_id', - (b'quack/', b'A'))], - client._calls) + [("call", b"Repository.has_signature_for_revision_id", (b"quack/", b"A"))], + client._calls, + ) self.assertEqual(True, result) def test_is_not_shared(self): # ('no', ) for Repository.has_signature_for_revision_id -> 'False'. 
- transport_path = 'qwack' + transport_path = "qwack" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response(b'no') - result = repo.has_signature_for_revision_id(b'A') + client.add_success_response(b"no") + result = repo.has_signature_for_revision_id(b"A") self.assertEqual( - [('call', b'Repository.has_signature_for_revision_id', - (b'qwack/', b'A'))], - client._calls) + [("call", b"Repository.has_signature_for_revision_id", (b"qwack/", b"A"))], + client._calls, + ) self.assertEqual(False, result) class TestRepositoryPhysicalLockStatus(TestRemoteRepository): - def test_get_physical_lock_status_yes(self): - transport_path = 'qwack' + transport_path = "qwack" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response(b'yes') + client.add_success_response(b"yes") result = repo.get_physical_lock_status() self.assertEqual( - [('call', b'Repository.get_physical_lock_status', - (b'qwack/', ))], - client._calls) + [("call", b"Repository.get_physical_lock_status", (b"qwack/",))], + client._calls, + ) self.assertEqual(True, result) def test_get_physical_lock_status_no(self): - transport_path = 'qwack' + transport_path = "qwack" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response(b'no') + client.add_success_response(b"no") result = repo.get_physical_lock_status() self.assertEqual( - [('call', b'Repository.get_physical_lock_status', - (b'qwack/', ))], - client._calls) + [("call", b"Repository.get_physical_lock_status", (b"qwack/",))], + client._calls, + ) self.assertEqual(False, result) class TestRepositoryIsShared(TestRemoteRepository): - def test_is_shared(self): # ('yes', ) for Repository.is_shared -> 'True'. - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response(b'yes') + client.add_success_response(b"yes") result = repo.is_shared() self.assertEqual( - [('call', b'Repository.is_shared', (b'quack/',))], - client._calls) + [("call", b"Repository.is_shared", (b"quack/",))], client._calls + ) self.assertEqual(True, result) def test_is_not_shared(self): # ('no', ) for Repository.is_shared -> 'False'. - transport_path = 'qwack' + transport_path = "qwack" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response(b'no') + client.add_success_response(b"no") result = repo.is_shared() self.assertEqual( - [('call', b'Repository.is_shared', (b'qwack/',))], - client._calls) + [("call", b"Repository.is_shared", (b"qwack/",))], client._calls + ) self.assertEqual(False, result) class TestRepositoryMakeWorkingTrees(TestRemoteRepository): - def test_make_working_trees(self): # ('yes', ) for Repository.make_working_trees -> 'True'. - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response(b'yes') + client.add_success_response(b"yes") result = repo.make_working_trees() self.assertEqual( - [('call', b'Repository.make_working_trees', (b'quack/',))], - client._calls) + [("call", b"Repository.make_working_trees", (b"quack/",))], client._calls + ) self.assertEqual(True, result) def test_no_working_trees(self): # ('no', ) for Repository.make_working_trees -> 'False'. 
- transport_path = 'qwack' + transport_path = "qwack" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response(b'no') + client.add_success_response(b"no") result = repo.make_working_trees() self.assertEqual( - [('call', b'Repository.make_working_trees', (b'qwack/',))], - client._calls) + [("call", b"Repository.make_working_trees", (b"qwack/",))], client._calls + ) self.assertEqual(False, result) class TestRepositoryLockWrite(TestRemoteRepository): - def test_lock_write(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response(b'ok', b'a token') + client.add_success_response(b"ok", b"a token") token = repo.lock_write().repository_token self.assertEqual( - [('call', b'Repository.lock_write', (b'quack/', b''))], - client._calls) - self.assertEqual(b'a token', token) + [("call", b"Repository.lock_write", (b"quack/", b""))], client._calls + ) + self.assertEqual(b"a token", token) def test_lock_write_already_locked(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_error_response(b'LockContention') + client.add_error_response(b"LockContention") self.assertRaises(errors.LockContention, repo.lock_write) self.assertEqual( - [('call', b'Repository.lock_write', (b'quack/', b''))], - client._calls) + [("call", b"Repository.lock_write", (b"quack/", b""))], client._calls + ) def test_lock_write_unlockable(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_error_response(b'UnlockableTransport') + client.add_error_response(b"UnlockableTransport") self.assertRaises(errors.UnlockableTransport, repo.lock_write) self.assertEqual( - [('call', b'Repository.lock_write', (b'quack/', b''))], - client._calls) + [("call", b"Repository.lock_write", (b"quack/", b""))], client._calls + ) class TestRepositoryWriteGroups(TestRemoteRepository): - def test_start_write_group(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) client.add_expected_call( - b'Repository.lock_write', (b'quack/', b''), - b'success', (b'ok', b'a token')) + b"Repository.lock_write", (b"quack/", b""), b"success", (b"ok", b"a token") + ) client.add_expected_call( - b'Repository.start_write_group', (b'quack/', b'a token'), - b'success', (b'ok', (b'token1', ))) + b"Repository.start_write_group", + (b"quack/", b"a token"), + b"success", + (b"ok", (b"token1",)), + ) repo.lock_write() repo.start_write_group() def test_start_write_group_unsuspendable(self): # Some repositories do not support suspending write # groups. For those, fall back to the "real" repository. 
- transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) def stub_ensure_real(): - client._calls.append(('_ensure_real',)) + client._calls.append(("_ensure_real",)) repo._real_repository = _StubRealPackRepository(client._calls) + repo._ensure_real = stub_ensure_real client.add_expected_call( - b'Repository.lock_write', (b'quack/', b''), - b'success', (b'ok', b'a token')) + b"Repository.lock_write", (b"quack/", b""), b"success", (b"ok", b"a token") + ) client.add_expected_call( - b'Repository.start_write_group', (b'quack/', b'a token'), - b'error', (b'UnsuspendableWriteGroup',)) + b"Repository.start_write_group", + (b"quack/", b"a token"), + b"error", + (b"UnsuspendableWriteGroup",), + ) repo.lock_write() repo.start_write_group() - self.assertEqual(client._calls[-2:], [ - ('_ensure_real',), - ('start_write_group',)]) + self.assertEqual( + client._calls[-2:], [("_ensure_real",), ("start_write_group",)] + ) def test_commit_write_group(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) client.add_expected_call( - b'Repository.lock_write', (b'quack/', b''), - b'success', (b'ok', b'a token')) + b"Repository.lock_write", (b"quack/", b""), b"success", (b"ok", b"a token") + ) client.add_expected_call( - b'Repository.start_write_group', (b'quack/', b'a token'), - b'success', (b'ok', [b'token1'])) + b"Repository.start_write_group", + (b"quack/", b"a token"), + b"success", + (b"ok", [b"token1"]), + ) client.add_expected_call( - b'Repository.commit_write_group', (b'quack/', - b'a token', [b'token1']), - b'success', (b'ok',)) + b"Repository.commit_write_group", + (b"quack/", b"a token", [b"token1"]), + b"success", + (b"ok",), + ) repo.lock_write() repo.start_write_group() repo.commit_write_group() def test_abort_write_group(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) client.add_expected_call( - b'Repository.lock_write', (b'quack/', b''), - b'success', (b'ok', b'a token')) + b"Repository.lock_write", (b"quack/", b""), b"success", (b"ok", b"a token") + ) client.add_expected_call( - b'Repository.start_write_group', (b'quack/', b'a token'), - b'success', (b'ok', [b'token1'])) + b"Repository.start_write_group", + (b"quack/", b"a token"), + b"success", + (b"ok", [b"token1"]), + ) client.add_expected_call( - b'Repository.abort_write_group', (b'quack/', - b'a token', [b'token1']), - b'success', (b'ok',)) + b"Repository.abort_write_group", + (b"quack/", b"a token", [b"token1"]), + b"success", + (b"ok",), + ) repo.lock_write() repo.start_write_group() repo.abort_write_group(False) def test_suspend_write_group(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) self.assertEqual([], repo.suspend_write_group()) def test_resume_write_group(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) client.add_expected_call( - b'Repository.lock_write', (b'quack/', b''), - b'success', (b'ok', b'a token')) + b"Repository.lock_write", (b"quack/", b""), b"success", (b"ok", b"a token") + ) client.add_expected_call( - b'Repository.check_write_group', (b'quack/', - b'a token', [b'token1']), - b'success', (b'ok',)) + b"Repository.check_write_group", + (b"quack/", b"a token", [b"token1"]), + b"success", + (b"ok",), + ) repo.lock_write() - 
repo.resume_write_group(['token1']) + repo.resume_write_group(["token1"]) class TestRepositorySetMakeWorkingTrees(TestRemoteRepository): - def test_backwards_compat(self): self.setup_smart_server_with_call_log() - repo = self.make_repository('.') + repo = self.make_repository(".") self.reset_smart_call_log() - verb = b'Repository.set_make_working_trees' + verb = b"Repository.set_make_working_trees" self.disable_verb(verb) repo.set_make_working_trees(True) - call_count = len([call for call in self.hpss_calls if - call.call.method == verb]) + call_count = len([call for call in self.hpss_calls if call.call.method == verb]) self.assertEqual(1, call_count) def test_current(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) client.add_expected_call( - b'Repository.set_make_working_trees', (b'quack/', b'True'), - b'success', (b'ok',)) + b"Repository.set_make_working_trees", + (b"quack/", b"True"), + b"success", + (b"ok",), + ) client.add_expected_call( - b'Repository.set_make_working_trees', (b'quack/', b'False'), - b'success', (b'ok',)) + b"Repository.set_make_working_trees", + (b"quack/", b"False"), + b"success", + (b"ok",), + ) repo.set_make_working_trees(True) repo.set_make_working_trees(False) class TestRepositoryUnlock(TestRemoteRepository): - def test_unlock(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response(b'ok', b'a token') - client.add_success_response(b'ok') + client.add_success_response(b"ok", b"a token") + client.add_success_response(b"ok") repo.lock_write() repo.unlock() self.assertEqual( - [('call', b'Repository.lock_write', (b'quack/', b'')), - ('call', b'Repository.unlock', (b'quack/', b'a token'))], - client._calls) + [ + ("call", b"Repository.lock_write", (b"quack/", b"")), + ("call", b"Repository.unlock", (b"quack/", b"a token")), + ], + client._calls, + ) def test_unlock_wrong_token(self): # If somehow the token is wrong, unlock will raise TokenMismatch. - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response(b'ok', b'a token') - client.add_error_response(b'TokenMismatch') + client.add_success_response(b"ok", b"a token") + client.add_error_response(b"TokenMismatch") repo.lock_write() self.assertRaises(errors.TokenMismatch, repo.unlock) class TestRepositoryHasRevision(TestRemoteRepository): - def test_none(self): # repo.has_revision(None) should not cause any traffic. - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) # The null revision is always there, so has_revision(None) == True. 
@@ -3324,27 +3824,36 @@ class TestRepositoryIterFilesBytes(TestRemoteRepository): """Test Repository.iter_file_bytes.""" def test_single(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) client.add_expected_call( - b'Repository.iter_files_bytes', (b'quack/', ), - b'success', (b'ok',), iter([b"ok\x000", b"\n", zlib.compress(b"mydata" * 10)])) - for (identifier, byte_stream) in repo.iter_files_bytes([(b"somefile", - b"somerev", b"myid")]): + b"Repository.iter_files_bytes", + (b"quack/",), + b"success", + (b"ok",), + iter([b"ok\x000", b"\n", zlib.compress(b"mydata" * 10)]), + ) + for identifier, byte_stream in repo.iter_files_bytes( + [(b"somefile", b"somerev", b"myid")] + ): self.assertEqual(b"myid", identifier) self.assertEqual(b"".join(byte_stream), b"mydata" * 10) def test_missing(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) client.add_expected_call( - b'Repository.iter_files_bytes', - (b'quack/', ), - b'error', (b'RevisionNotPresent', b'somefile', b'somerev'), - iter([b"absent\0somefile\0somerev\n"])) - self.assertRaises(errors.RevisionNotPresent, list, - repo.iter_files_bytes( - [(b"somefile", b"somerev", b"myid")])) + b"Repository.iter_files_bytes", + (b"quack/",), + b"error", + (b"RevisionNotPresent", b"somefile", b"somerev"), + iter([b"absent\0somefile\0somerev\n"]), + ) + self.assertRaises( + errors.RevisionNotPresent, + list, + repo.iter_files_bytes([(b"somefile", b"somerev", b"myid")]), + ) class TestRepositoryInsertStreamBase(TestRemoteRepository): @@ -3375,55 +3884,70 @@ class TestRepositoryInsertStream(TestRepositoryInsertStreamBase): def setUp(self): super().setUp() - self.disable_verb(b'Repository.insert_stream_1.19') + self.disable_verb(b"Repository.insert_stream_1.19") def test_unlocked_repo(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) client.add_expected_call( - b'Repository.insert_stream_1.19', (b'quack/', b''), - b'unknown', (b'Repository.insert_stream_1.19',)) + b"Repository.insert_stream_1.19", + (b"quack/", b""), + b"unknown", + (b"Repository.insert_stream_1.19",), + ) client.add_expected_call( - b'Repository.insert_stream', (b'quack/', b''), - b'success', (b'ok',)) + b"Repository.insert_stream", (b"quack/", b""), b"success", (b"ok",) + ) client.add_expected_call( - b'Repository.insert_stream', (b'quack/', b''), - b'success', (b'ok',)) + b"Repository.insert_stream", (b"quack/", b""), b"success", (b"ok",) + ) self.checkInsertEmptyStream(repo, client) def test_locked_repo_with_no_lock_token(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) client.add_expected_call( - b'Repository.lock_write', (b'quack/', b''), - b'success', (b'ok', b'')) + b"Repository.lock_write", (b"quack/", b""), b"success", (b"ok", b"") + ) client.add_expected_call( - b'Repository.insert_stream_1.19', (b'quack/', b''), - b'unknown', (b'Repository.insert_stream_1.19',)) + b"Repository.insert_stream_1.19", + (b"quack/", b""), + b"unknown", + (b"Repository.insert_stream_1.19",), + ) client.add_expected_call( - b'Repository.insert_stream', (b'quack/', b''), - b'success', (b'ok',)) + b"Repository.insert_stream", (b"quack/", b""), b"success", (b"ok",) + ) client.add_expected_call( - b'Repository.insert_stream', (b'quack/', b''), - b'success', (b'ok',)) + 
b"Repository.insert_stream", (b"quack/", b""), b"success", (b"ok",) + ) repo.lock_write() self.checkInsertEmptyStream(repo, client) def test_locked_repo_with_lock_token(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) client.add_expected_call( - b'Repository.lock_write', (b'quack/', b''), - b'success', (b'ok', b'a token')) + b"Repository.lock_write", (b"quack/", b""), b"success", (b"ok", b"a token") + ) client.add_expected_call( - b'Repository.insert_stream_1.19', (b'quack/', b'', b'a token'), - b'unknown', (b'Repository.insert_stream_1.19',)) + b"Repository.insert_stream_1.19", + (b"quack/", b"", b"a token"), + b"unknown", + (b"Repository.insert_stream_1.19",), + ) client.add_expected_call( - b'Repository.insert_stream_locked', (b'quack/', b'', b'a token'), - b'success', (b'ok',)) + b"Repository.insert_stream_locked", + (b"quack/", b"", b"a token"), + b"success", + (b"ok",), + ) client.add_expected_call( - b'Repository.insert_stream_locked', (b'quack/', b'', b'a token'), - b'success', (b'ok',)) + b"Repository.insert_stream_locked", + (b"quack/", b"", b"a token"), + b"success", + (b"ok",), + ) repo.lock_write() self.checkInsertEmptyStream(repo, client) @@ -3433,17 +3957,20 @@ def test_stream_with_inventory_deltas(self): that verb will accept them. So when one is encountered the RemoteSink immediately stops using that verb and falls back to VFS insert_stream. """ - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) client.add_expected_call( - b'Repository.insert_stream_1.19', (b'quack/', b''), - b'unknown', (b'Repository.insert_stream_1.19',)) + b"Repository.insert_stream_1.19", + (b"quack/", b""), + b"unknown", + (b"Repository.insert_stream_1.19",), + ) client.add_expected_call( - b'Repository.insert_stream', (b'quack/', b''), - b'success', (b'ok',)) + b"Repository.insert_stream", (b"quack/", b""), b"success", (b"ok",) + ) client.add_expected_call( - b'Repository.insert_stream', (b'quack/', b''), - b'success', (b'ok',)) + b"Repository.insert_stream", (b"quack/", b""), b"success", (b"ok",) + ) # Create a fake real repository for insert_stream to fall back on, so # that we can directly see the records the RemoteSink passes to the # real sink. @@ -3455,8 +3982,10 @@ def __init__(self): def insert_stream(self, stream, src_format, resume_tokens): for substream_kind, substream in stream: self.records.append( - (substream_kind, [record.key for record in substream])) - return [b'fake tokens'], [b'fake missing keys'] + (substream_kind, [record.key for record in substream]) + ) + return [b"fake tokens"], [b"fake missing keys"] + fake_real_sink = FakeRealSink() class FakeRealRepository: @@ -3468,6 +3997,7 @@ def is_in_write_group(self): def refresh_data(self): return True + repo._real_repository = FakeRealRepository() sink = repo._get_sink() fmt = repository.format_registry.get_default() @@ -3476,13 +4006,14 @@ def refresh_data(self): # Every record from the first inventory delta should have been sent to # the VFS sink. expected_records = [ - ('inventory-deltas', [(b'rev2',), (b'rev3',)]), - ('texts', [(b'some-rev', b'some-file')])] + ("inventory-deltas", [(b"rev2",), (b"rev3",)]), + ("texts", [(b"some-rev", b"some-file")]), + ] self.assertEqual(expected_records, fake_real_sink.records) # The return values from the real sink's insert_stream are propagated # back to the original caller. 
- self.assertEqual([b'fake tokens'], resume_tokens) - self.assertEqual([b'fake missing keys'], missing_keys) + self.assertEqual([b"fake tokens"], resume_tokens) + self.assertEqual([b"fake missing keys"], missing_keys) self.assertFinished(client) def make_stream_with_inv_deltas(self, fmt): @@ -3496,113 +4027,137 @@ def make_stream_with_inv_deltas(self, fmt): * texts substream: (some-rev, some-file) """ # Define a stream using generators so that it isn't rewindable. - inv = inventory.Inventory(revision_id=b'rev1') - inv.root.revision = b'rev1' + inv = inventory.Inventory(revision_id=b"rev1") + inv.root.revision = b"rev1" def stream_with_inv_delta(): - yield ('inventories', inventories_substream()) - yield ('inventory-deltas', inventory_delta_substream()) - yield ('texts', [ - versionedfile.FulltextContentFactory( - (b'some-rev', b'some-file'), (), None, b'content')]) + yield ("inventories", inventories_substream()) + yield ("inventory-deltas", inventory_delta_substream()) + yield ( + "texts", + [ + versionedfile.FulltextContentFactory( + (b"some-rev", b"some-file"), (), None, b"content" + ) + ], + ) def inventories_substream(): # An empty inventory fulltext. This will be streamed normally. chunks = fmt._inventory_serializer.write_inventory_to_lines(inv) yield versionedfile.ChunkedContentFactory( - (b'rev1',), (), None, chunks, chunks_are_lines=True) + (b"rev1",), (), None, chunks, chunks_are_lines=True + ) def inventory_delta_substream(): # An inventory delta. This can't be streamed via this verb, so it # will trigger a fallback to VFS insert_stream. entry = inv.make_entry( - 'directory', 'newdir', inv.root.file_id, b'newdir-id') - entry.revision = b'ghost' - delta = inventory_delta.InventoryDelta([(None, 'newdir', b'newdir-id', entry)]) + "directory", "newdir", inv.root.file_id, b"newdir-id" + ) + entry.revision = b"ghost" + delta = inventory_delta.InventoryDelta( + [(None, "newdir", b"newdir-id", entry)] + ) serializer = inventory_delta.InventoryDeltaSerializer( - versioned_root=True, tree_references=False) - lines = serializer.delta_to_lines(b'rev1', b'rev2', delta) + versioned_root=True, tree_references=False + ) + lines = serializer.delta_to_lines(b"rev1", b"rev2", delta) yield versionedfile.ChunkedContentFactory( - (b'rev2',), ((b'rev1',)), None, lines) + (b"rev2",), ((b"rev1",)), None, lines + ) # Another delta. 
- lines = serializer.delta_to_lines(b'rev1', b'rev3', delta) + lines = serializer.delta_to_lines(b"rev1", b"rev3", delta) yield versionedfile.ChunkedContentFactory( - (b'rev3',), ((b'rev1',)), None, lines) + (b"rev3",), ((b"rev1",)), None, lines + ) + return stream_with_inv_delta() class TestRepositoryInsertStream_1_19(TestRepositoryInsertStreamBase): - def test_unlocked_repo(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) client.add_expected_call( - b'Repository.insert_stream_1.19', (b'quack/', b''), - b'success', (b'ok',)) + b"Repository.insert_stream_1.19", (b"quack/", b""), b"success", (b"ok",) + ) client.add_expected_call( - b'Repository.insert_stream_1.19', (b'quack/', b''), - b'success', (b'ok',)) + b"Repository.insert_stream_1.19", (b"quack/", b""), b"success", (b"ok",) + ) self.checkInsertEmptyStream(repo, client) def test_locked_repo_with_no_lock_token(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) client.add_expected_call( - b'Repository.lock_write', (b'quack/', b''), - b'success', (b'ok', b'')) + b"Repository.lock_write", (b"quack/", b""), b"success", (b"ok", b"") + ) client.add_expected_call( - b'Repository.insert_stream_1.19', (b'quack/', b''), - b'success', (b'ok',)) + b"Repository.insert_stream_1.19", (b"quack/", b""), b"success", (b"ok",) + ) client.add_expected_call( - b'Repository.insert_stream_1.19', (b'quack/', b''), - b'success', (b'ok',)) + b"Repository.insert_stream_1.19", (b"quack/", b""), b"success", (b"ok",) + ) repo.lock_write() self.checkInsertEmptyStream(repo, client) def test_locked_repo_with_lock_token(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) client.add_expected_call( - b'Repository.lock_write', (b'quack/', b''), - b'success', (b'ok', b'a token')) + b"Repository.lock_write", (b"quack/", b""), b"success", (b"ok", b"a token") + ) client.add_expected_call( - b'Repository.insert_stream_1.19', (b'quack/', b'', b'a token'), - b'success', (b'ok',)) + b"Repository.insert_stream_1.19", + (b"quack/", b"", b"a token"), + b"success", + (b"ok",), + ) client.add_expected_call( - b'Repository.insert_stream_1.19', (b'quack/', b'', b'a token'), - b'success', (b'ok',)) + b"Repository.insert_stream_1.19", + (b"quack/", b"", b"a token"), + b"success", + (b"ok",), + ) repo.lock_write() self.checkInsertEmptyStream(repo, client) class TestRepositoryTarball(TestRemoteRepository): - # This is a canned tarball reponse we can validate against tarball_content = base64.b64decode( - 'QlpoOTFBWSZTWdGkj3wAAWF/k8aQACBIB//A9+8cIX/v33AACEAYABAECEACNz' - 'JqsgJJFPTSnk1A3qh6mTQAAAANPUHkagkSTEkaA09QaNAAAGgAAAcwCYCZGAEY' - 'mJhMJghpiaYBUkKammSHqNMZQ0NABkNAeo0AGneAevnlwQoGzEzNVzaYxp/1Uk' - 'xXzA1CQX0BJMZZLcPBrluJir5SQyijWHYZ6ZUtVqqlYDdB2QoCwa9GyWwGYDMA' - 'OQYhkpLt/OKFnnlT8E0PmO8+ZNSo2WWqeCzGB5fBXZ3IvV7uNJVE7DYnWj6qwB' - 'k5DJDIrQ5OQHHIjkS9KqwG3mc3t+F1+iujb89ufyBNIKCgeZBWrl5cXxbMGoMs' - 'c9JuUkg5YsiVcaZJurc6KLi6yKOkgCUOlIlOpOoXyrTJjK8ZgbklReDdwGmFgt' - 'dkVsAIslSVCd4AtACSLbyhLHryfb14PKegrVDba+U8OL6KQtzdM5HLjAc8/p6n' - '0lgaWU8skgO7xupPTkyuwheSckejFLK5T4ZOo0Gda9viaIhpD1Qn7JqqlKAJqC' - 'QplPKp2nqBWAfwBGaOwVrz3y1T+UZZNismXHsb2Jq18T+VaD9k4P8DqE3g70qV' - 'JLurpnDI6VS5oqDDPVbtVjMxMxMg4rzQVipn2Bv1fVNK0iq3Gl0hhnnHKm/egy' - 'nWQ7QH/F3JFOFCQ0aSPfA=' - ) + "QlpoOTFBWSZTWdGkj3wAAWF/k8aQACBIB//A9+8cIX/v33AACEAYABAECEACNz" + 
"JqsgJJFPTSnk1A3qh6mTQAAAANPUHkagkSTEkaA09QaNAAAGgAAAcwCYCZGAEY" + "mJhMJghpiaYBUkKammSHqNMZQ0NABkNAeo0AGneAevnlwQoGzEzNVzaYxp/1Uk" + "xXzA1CQX0BJMZZLcPBrluJir5SQyijWHYZ6ZUtVqqlYDdB2QoCwa9GyWwGYDMA" + "OQYhkpLt/OKFnnlT8E0PmO8+ZNSo2WWqeCzGB5fBXZ3IvV7uNJVE7DYnWj6qwB" + "k5DJDIrQ5OQHHIjkS9KqwG3mc3t+F1+iujb89ufyBNIKCgeZBWrl5cXxbMGoMs" + "c9JuUkg5YsiVcaZJurc6KLi6yKOkgCUOlIlOpOoXyrTJjK8ZgbklReDdwGmFgt" + "dkVsAIslSVCd4AtACSLbyhLHryfb14PKegrVDba+U8OL6KQtzdM5HLjAc8/p6n" + "0lgaWU8skgO7xupPTkyuwheSckejFLK5T4ZOo0Gda9viaIhpD1Qn7JqqlKAJqC" + "QplPKp2nqBWAfwBGaOwVrz3y1T+UZZNismXHsb2Jq18T+VaD9k4P8DqE3g70qV" + "JLurpnDI6VS5oqDDPVbtVjMxMxMg4rzQVipn2Bv1fVNK0iq3Gl0hhnnHKm/egy" + "nWQ7QH/F3JFOFCQ0aSPfA=" + ) def test_repository_tarball(self): # Test that Repository.tarball generates the right operations - transport_path = 'repo' - expected_calls = [('call_expecting_body', b'Repository.tarball', - (b'repo/', b'bz2',),), - ] + transport_path = "repo" + expected_calls = [ + ( + "call_expecting_body", + b"Repository.tarball", + ( + b"repo/", + b"bz2", + ), + ), + ] repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_success_response_with_body(self.tarball_content, b'ok') + client.add_success_response_with_body(self.tarball_content, b"ok") # Now actually ask for the tarball - tarball_file = repo._get_tarball('bz2') + tarball_file = repo._get_tarball("bz2") try: self.assertEqual(expected_calls, client._calls) self.assertEqual(self.tarball_content, tarball_file.read()) @@ -3615,13 +4170,13 @@ class TestRemoteRepositoryCopyContent(tests.TestCaseWithTransport): def test_copy_content_remote_to_local(self): self.transport_server = test_server.SmartTCPServer_for_testing - src_repo = self.make_repository('repo1') - src_repo = repository.Repository.open(self.get_url('repo1')) + src_repo = self.make_repository("repo1") + src_repo = repository.Repository.open(self.get_url("repo1")) # At the moment the tarball-based copy_content_into can't write back # into a smart server. It would be good if it could upload the # tarball; once that works we'd have to create repositories of # different formats. -- mbp 20070410 - dest_url = self.get_vfs_only_url('repo2') + dest_url = self.get_vfs_only_url("repo2") dest_bzrdir = BzrDir.create(dest_url) dest_repo = dest_bzrdir.create_repository() self.assertNotIsInstance(dest_repo, RemoteRepository) @@ -3630,28 +4185,26 @@ def test_copy_content_remote_to_local(self): class _StubRealPackRepository: - def __init__(self, calls): self.calls = calls self._pack_collection = _StubPackCollection(calls) def start_write_group(self): - self.calls.append(('start_write_group',)) + self.calls.append(("start_write_group",)) def is_in_write_group(self): return False def refresh_data(self): - self.calls.append(('pack collection reload_pack_names',)) + self.calls.append(("pack collection reload_pack_names",)) class _StubPackCollection: - def __init__(self, calls): self.calls = calls def autopack(self): - self.calls.append(('pack collection autopack',)) + self.calls.append(("pack collection autopack",)) class TestRemotePackRepositoryAutoPack(TestRemoteRepository): @@ -3661,10 +4214,11 @@ def test_ok(self): """When the server returns 'ok' and there's no _real_repository, then nothing else happens: the autopack method is done. 
""" - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) client.add_expected_call( - b'PackRepository.autopack', (b'quack/',), b'success', (b'ok',)) + b"PackRepository.autopack", (b"quack/",), b"success", (b"ok",) + ) repo.autopack() self.assertFinished(client) @@ -3672,44 +4226,51 @@ def test_ok_with_real_repo(self): """When the server returns 'ok' and there is a _real_repository, then the _real_repository's reload_pack_name's method will be called. """ - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) client.add_expected_call( - b'PackRepository.autopack', (b'quack/',), - b'success', (b'ok',)) + b"PackRepository.autopack", (b"quack/",), b"success", (b"ok",) + ) repo._real_repository = _StubRealPackRepository(client._calls) repo.autopack() self.assertEqual( - [('call', b'PackRepository.autopack', (b'quack/',)), - ('pack collection reload_pack_names',)], - client._calls) + [ + ("call", b"PackRepository.autopack", (b"quack/",)), + ("pack collection reload_pack_names",), + ], + client._calls, + ) def test_backwards_compatibility(self): """If the server does not recognise the PackRepository.autopack verb, fallback to the real_repository's implementation. """ - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) - client.add_unknown_method_response(b'PackRepository.autopack') + client.add_unknown_method_response(b"PackRepository.autopack") def stub_ensure_real(): - client._calls.append(('_ensure_real',)) + client._calls.append(("_ensure_real",)) repo._real_repository = _StubRealPackRepository(client._calls) + repo._ensure_real = stub_ensure_real repo.autopack() self.assertEqual( - [('call', b'PackRepository.autopack', (b'quack/',)), - ('_ensure_real',), - ('pack collection autopack',)], - client._calls) + [ + ("call", b"PackRepository.autopack", (b"quack/",)), + ("_ensure_real",), + ("pack collection autopack",), + ], + client._calls, + ) def test_oom_error_reporting(self): """An out-of-memory condition on the server is reported clearly.""" - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) client.add_expected_call( - b'PackRepository.autopack', (b'quack/',), - b'error', (b'MemoryError',)) + b"PackRepository.autopack", (b"quack/",), b"error", (b"MemoryError",) + ) err = self.assertRaises(errors.BzrError, repo.autopack) self.assertContainsRe(str(err), "^remote server out of mem") @@ -3731,8 +4292,7 @@ def translateTuple(self, error_tuple, **context): # because _translate_error may need to re-raise it with a bare 'raise' # statement. 
server_error = errors.ErrorFromSmartServer(error_tuple) - translated_error = self.translateErrorFromSmartServer( - server_error, **context) + translated_error = self.translateErrorFromSmartServer(server_error, **context) return translated_error def translateErrorFromSmartServer(self, error_object, **context): @@ -3743,8 +4303,8 @@ def translateErrorFromSmartServer(self, error_object, **context): raise error_object except errors.ErrorFromSmartServer as server_error: translated_error = self.assertRaises( - errors.BzrError, remote._translate_error, server_error, - **context) + errors.BzrError, remote._translate_error, server_error, **context + ) return translated_error @@ -3761,121 +4321,125 @@ class TestErrorTranslationSuccess(TestErrorTranslationBase): """ def test_NoSuchRevision(self): - branch = self.make_branch('') - revid = b'revid' + branch = self.make_branch("") + revid = b"revid" translated_error = self.translateTuple( - (b'NoSuchRevision', revid), branch=branch) + (b"NoSuchRevision", revid), branch=branch + ) expected_error = errors.NoSuchRevision(branch, revid) self.assertEqual(expected_error, translated_error) def test_nosuchrevision(self): - repository = self.make_repository('') - revid = b'revid' + repository = self.make_repository("") + revid = b"revid" translated_error = self.translateTuple( - (b'nosuchrevision', revid), repository=repository) + (b"nosuchrevision", revid), repository=repository + ) expected_error = errors.NoSuchRevision(repository, revid) self.assertEqual(expected_error, translated_error) def test_nobranch(self): - bzrdir = self.make_controldir('') - translated_error = self.translateTuple((b'nobranch',), bzrdir=bzrdir) + bzrdir = self.make_controldir("") + translated_error = self.translateTuple((b"nobranch",), bzrdir=bzrdir) expected_error = errors.NotBranchError(path=bzrdir.root_transport.base) self.assertEqual(expected_error, translated_error) def test_nobranch_one_arg(self): - bzrdir = self.make_controldir('') + bzrdir = self.make_controldir("") translated_error = self.translateTuple( - (b'nobranch', b'extra detail'), bzrdir=bzrdir) + (b"nobranch", b"extra detail"), bzrdir=bzrdir + ) expected_error = errors.NotBranchError( - path=bzrdir.root_transport.base, - detail='extra detail') + path=bzrdir.root_transport.base, detail="extra detail" + ) self.assertEqual(expected_error, translated_error) def test_norepository(self): - bzrdir = self.make_controldir('') - translated_error = self.translateTuple((b'norepository',), - bzrdir=bzrdir) + bzrdir = self.make_controldir("") + translated_error = self.translateTuple((b"norepository",), bzrdir=bzrdir) expected_error = errors.NoRepositoryPresent(bzrdir) self.assertEqual(expected_error, translated_error) def test_LockContention(self): - translated_error = self.translateTuple((b'LockContention',)) - expected_error = errors.LockContention('(remote lock)') + translated_error = self.translateTuple((b"LockContention",)) + expected_error = errors.LockContention("(remote lock)") self.assertEqual(expected_error, translated_error) def test_UnlockableTransport(self): - bzrdir = self.make_controldir('') - translated_error = self.translateTuple( - (b'UnlockableTransport',), bzrdir=bzrdir) + bzrdir = self.make_controldir("") + translated_error = self.translateTuple((b"UnlockableTransport",), bzrdir=bzrdir) expected_error = errors.UnlockableTransport(bzrdir.root_transport) self.assertEqual(expected_error, translated_error) def test_LockFailed(self): - lock = 'str() of a server lock' - why = 'str() of why' + lock = "str() of a server 
lock" + why = "str() of why" translated_error = self.translateTuple( - (b'LockFailed', lock.encode('ascii'), why.encode('ascii'))) + (b"LockFailed", lock.encode("ascii"), why.encode("ascii")) + ) expected_error = errors.LockFailed(lock, why) self.assertEqual(expected_error, translated_error) def test_TokenMismatch(self): - token = 'a lock token' - translated_error = self.translateTuple( - (b'TokenMismatch',), token=token) - expected_error = errors.TokenMismatch(token, '(remote token)') + token = "a lock token" + translated_error = self.translateTuple((b"TokenMismatch",), token=token) + expected_error = errors.TokenMismatch(token, "(remote token)") self.assertEqual(expected_error, translated_error) def test_Diverged(self): - branch = self.make_branch('a') - other_branch = self.make_branch('b') + branch = self.make_branch("a") + other_branch = self.make_branch("b") translated_error = self.translateTuple( - (b'Diverged',), branch=branch, other_branch=other_branch) + (b"Diverged",), branch=branch, other_branch=other_branch + ) expected_error = errors.DivergedBranches(branch, other_branch) self.assertEqual(expected_error, translated_error) def test_NotStacked(self): - branch = self.make_branch('') - translated_error = self.translateTuple((b'NotStacked',), branch=branch) + branch = self.make_branch("") + translated_error = self.translateTuple((b"NotStacked",), branch=branch) expected_error = errors.NotStacked(branch) self.assertEqual(expected_error, translated_error) def test_ReadError_no_args(self): - path = 'a path' - translated_error = self.translateTuple((b'ReadError',), path=path) + path = "a path" + translated_error = self.translateTuple((b"ReadError",), path=path) expected_error = errors.ReadError(path) self.assertEqual(expected_error, translated_error) def test_ReadError(self): - path = 'a path' - translated_error = self.translateTuple( - (b'ReadError', path.encode('utf-8'))) + path = "a path" + translated_error = self.translateTuple((b"ReadError", path.encode("utf-8"))) expected_error = errors.ReadError(path) self.assertEqual(expected_error, translated_error) def test_IncompatibleRepositories(self): - translated_error = self.translateTuple((b'IncompatibleRepositories', - b"repo1", b"repo2", b"details here")) - expected_error = errors.IncompatibleRepositories("repo1", "repo2", - "details here") + translated_error = self.translateTuple( + (b"IncompatibleRepositories", b"repo1", b"repo2", b"details here") + ) + expected_error = errors.IncompatibleRepositories( + "repo1", "repo2", "details here" + ) self.assertEqual(expected_error, translated_error) def test_GhostRevisionsHaveNoRevno(self): - translated_error = self.translateTuple((b'GhostRevisionsHaveNoRevno', - b"revid1", b"revid2")) + translated_error = self.translateTuple( + (b"GhostRevisionsHaveNoRevno", b"revid1", b"revid2") + ) expected_error = errors.GhostRevisionsHaveNoRevno(b"revid1", b"revid2") self.assertEqual(expected_error, translated_error) def test_PermissionDenied_no_args(self): - path = 'a path' - translated_error = self.translateTuple((b'PermissionDenied',), - path=path) + path = "a path" + translated_error = self.translateTuple((b"PermissionDenied",), path=path) expected_error = errors.PermissionDenied(path) self.assertEqual(expected_error, translated_error) def test_PermissionDenied_one_arg(self): - path = 'a path' + path = "a path" translated_error = self.translateTuple( - (b'PermissionDenied', path.encode('utf-8'))) + (b"PermissionDenied", path.encode("utf-8")) + ) expected_error = errors.PermissionDenied(path) 
self.assertEqual(expected_error, translated_error) @@ -3883,18 +4447,20 @@ def test_PermissionDenied_one_arg_and_context(self): """Given a choice between a path from the local context and a path on the wire, _translate_error prefers the path from the local context. """ - local_path = 'local path' - remote_path = 'remote path' + local_path = "local path" + remote_path = "remote path" translated_error = self.translateTuple( - (b'PermissionDenied', remote_path.encode('utf-8')), path=local_path) + (b"PermissionDenied", remote_path.encode("utf-8")), path=local_path + ) expected_error = errors.PermissionDenied(local_path) self.assertEqual(expected_error, translated_error) def test_PermissionDenied_two_args(self): - path = 'a path' - extra = 'a string with extra info' + path = "a path" + extra = "a string with extra info" translated_error = self.translateTuple( - (b'PermissionDenied', path.encode('utf-8'), extra.encode('utf-8'))) + (b"PermissionDenied", path.encode("utf-8"), extra.encode("utf-8")) + ) expected_error = errors.PermissionDenied(path, extra) self.assertEqual(expected_error, translated_error) @@ -3902,31 +4468,31 @@ def test_PermissionDenied_two_args(self): def test_NoSuchFile_context_path(self): local_path = "local path" - translated_error = self.translateTuple((b'ReadError', b"remote path"), - path=local_path) + translated_error = self.translateTuple( + (b"ReadError", b"remote path"), path=local_path + ) expected_error = errors.ReadError(local_path) self.assertEqual(expected_error, translated_error) def test_NoSuchFile_without_context(self): remote_path = "remote path" translated_error = self.translateTuple( - (b'ReadError', remote_path.encode('utf-8'))) + (b"ReadError", remote_path.encode("utf-8")) + ) expected_error = errors.ReadError(remote_path) self.assertEqual(expected_error, translated_error) def test_ReadOnlyError(self): - translated_error = self.translateTuple((b'ReadOnlyError',)) + translated_error = self.translateTuple((b"ReadOnlyError",)) expected_error = errors.TransportNotPossible("readonly transport") self.assertEqual(expected_error, translated_error) def test_MemoryError(self): - translated_error = self.translateTuple((b'MemoryError',)) - self.assertStartsWith(str(translated_error), - "remote server out of memory") + translated_error = self.translateTuple((b"MemoryError",)) + self.assertStartsWith(str(translated_error), "remote server out of memory") def test_generic_IndexError_no_classname(self): - err = errors.ErrorFromSmartServer( - (b'error', b"list index out of range")) + err = errors.ErrorFromSmartServer((b"error", b"list index out of range")) translated_error = self.translateErrorFromSmartServer(err) expected_error = UnknownErrorFromSmartServer(err) self.assertEqual(expected_error, translated_error) @@ -3934,14 +4500,15 @@ def test_generic_IndexError_no_classname(self): # GZ 2011-03-02: TODO test generic non-ascii error string def test_generic_KeyError(self): - err = errors.ErrorFromSmartServer((b'error', b'KeyError', b"1")) + err = errors.ErrorFromSmartServer((b"error", b"KeyError", b"1")) translated_error = self.translateErrorFromSmartServer(err) expected_error = UnknownErrorFromSmartServer(err) self.assertEqual(expected_error, translated_error) def test_RevnoOutOfBounds(self): translated_error = self.translateTuple( - ((b'revno-outofbounds', 5, 0, 3)), path=b'path') + ((b"revno-outofbounds", 5, 0, 3)), path=b"path" + ) expected_error = errors.RevnoOutOfBounds(5, (0, 3)) self.assertEqual(expected_error, translated_error) @@ -3958,7 +4525,7 @@ def 
test_unrecognised_server_error(self): """If the error code from the server is not recognised, the original ErrorFromSmartServer is propagated unmodified. """ - error_tuple = (b'An unknown error tuple',) + error_tuple = (b"An unknown error tuple",) server_error = errors.ErrorFromSmartServer(error_tuple) translated_error = self.translateErrorFromSmartServer(server_error) expected_error = UnknownErrorFromSmartServer(server_error) @@ -3972,22 +4539,20 @@ def test_context_missing_a_key(self): # To translate a NoSuchRevision error _translate_error needs a 'branch' # in the context dict. So let's give it an empty context dict instead # to exercise its error recovery. - error_tuple = (b'NoSuchRevision', b'revid') + error_tuple = (b"NoSuchRevision", b"revid") server_error = errors.ErrorFromSmartServer(error_tuple) translated_error = self.translateErrorFromSmartServer(server_error) self.assertEqual(server_error, translated_error) # In addition to re-raising ErrorFromSmartServer, some debug info has # been muttered to the log file for developer to look at. - self.assertContainsRe( - self.get_log(), - "Missing key 'branch' in context") + self.assertContainsRe(self.get_log(), "Missing key 'branch' in context") def test_path_missing(self): """Some translations (PermissionDenied, ReadError) can determine the 'path' variable from either the wire or the local context. If neither has it, then an error is raised. """ - error_tuple = (b'ReadError',) + error_tuple = (b"ReadError",) server_error = errors.ErrorFromSmartServer(error_tuple) translated_error = self.translateErrorFromSmartServer(server_error) self.assertEqual(server_error, translated_error) @@ -4007,18 +4572,21 @@ def test_access_stacked_remote(self): # make a branch stacked on another repository containing an empty # revision, then open it over hpss - we should be able to see that # revision. 
- base_builder = self.make_branch_builder('base', format='1.9') + base_builder = self.make_branch_builder("base", format="1.9") base_builder.start_series() - base_revid = base_builder.build_snapshot(None, - [('add', ('', None, 'directory', None))], - 'message', revision_id=b'rev-id') + base_revid = base_builder.build_snapshot( + None, + [("add", ("", None, "directory", None))], + "message", + revision_id=b"rev-id", + ) base_builder.finish_series() - stacked_branch = self.make_branch('stacked', format='1.9') - stacked_branch.set_stacked_on_url('../base') + stacked_branch = self.make_branch("stacked", format="1.9") + stacked_branch.set_stacked_on_url("../base") # start a server looking at this smart_server = test_server.SmartTCPServer_for_testing() self.start_server(smart_server) - remote_bzrdir = BzrDir.open(smart_server.get_url() + '/stacked') + remote_bzrdir = BzrDir.open(smart_server.get_url() + "/stacked") # can get its branch and repository remote_branch = remote_bzrdir.open_branch() remote_repo = remote_branch.repository @@ -4027,27 +4595,26 @@ def test_access_stacked_remote(self): # it should have an appropriate fallback repository, which should also # be a RemoteRepository self.assertLength(1, remote_repo._fallback_repositories) - self.assertIsInstance(remote_repo._fallback_repositories[0], - RemoteRepository) + self.assertIsInstance( + remote_repo._fallback_repositories[0], RemoteRepository + ) # and it has the revision committed to the underlying repository; # these have varying implementations so we try several of them self.assertTrue(remote_repo.has_revisions([base_revid])) self.assertTrue(remote_repo.has_revision(base_revid)) - self.assertEqual(remote_repo.get_revision(base_revid).message, - 'message') + self.assertEqual(remote_repo.get_revision(base_revid).message, "message") finally: remote_repo.unlock() def prepare_stacked_remote_branch(self): """Get stacked_upon and stacked branches with content in each.""" self.setup_smart_server_with_call_log() - tree1 = self.make_branch_and_tree('tree1', format='1.9') - tree1.commit('rev1', rev_id=b'rev1') - tree2 = tree1.branch.controldir.sprout('tree2', stacked=True - ).open_workingtree() - local_tree = tree2.branch.create_checkout('local') - local_tree.commit('local changes make me feel good.') - branch2 = Branch.open(self.get_url('tree2')) + tree1 = self.make_branch_and_tree("tree1", format="1.9") + tree1.commit("rev1", rev_id=b"rev1") + tree2 = tree1.branch.controldir.sprout("tree2", stacked=True).open_workingtree() + local_tree = tree2.branch.create_checkout("local") + local_tree.commit("local changes make me feel good.") + branch2 = Branch.open(self.get_url("tree2")) branch2.lock_read() self.addCleanup(branch2.unlock) return tree1.branch, branch2 @@ -4056,18 +4623,18 @@ def test_stacked_get_parent_map(self): # the public implementation of get_parent_map obeys stacking _, branch = self.prepare_stacked_remote_branch() repo = branch.repository - self.assertEqual({b'rev1'}, set(repo.get_parent_map([b'rev1']))) + self.assertEqual({b"rev1"}, set(repo.get_parent_map([b"rev1"]))) def test_unstacked_get_parent_map(self): # _unstacked_provider.get_parent_map ignores stacking _, branch = self.prepare_stacked_remote_branch() provider = branch.repository._unstacked_provider - self.assertEqual(set(), set(provider.get_parent_map([b'rev1']))) + self.assertEqual(set(), set(provider.get_parent_map([b"rev1"]))) def fetch_stream_to_rev_order(self, stream): result = [] for kind, substream in stream: - if not kind == 'revisions': + if not kind == 
"revisions": list(substream) else: for content in substream: @@ -4095,8 +4662,7 @@ def get_ordered_revs(self, format, order, branch_factory=None): tip = stacked.last_revision() stacked.repository._ensure_real() graph = stacked.repository.get_graph() - revs = [r for (r, ps) in graph.iter_ancestry([tip]) - if r != NULL_REVISION] + revs = [r for (r, ps) in graph.iter_ancestry([tip]) if r != NULL_REVISION] revs.reverse() search = vf_search.PendingAncestryResult([tip], stacked.repository) self.reset_smart_call_log() @@ -4111,7 +4677,7 @@ def test_stacked_get_stream_unordered(self): # Repository._get_source.get_stream() from a stacked repository with # unordered yields the full data from both stacked and stacked upon # sources. - rev_ord, expected_revs = self.get_ordered_revs('1.9', 'unordered') + rev_ord, expected_revs = self.get_ordered_revs("1.9", "unordered") self.assertEqual(set(expected_revs), set(rev_ord)) # Getting unordered results should have made a streaming data request # from the server, then one from the backing branch. @@ -4122,16 +4688,17 @@ def test_stacked_on_stacked_get_stream_unordered(self): # is itself stacked yields the full data from all three sources. def make_stacked_stacked(): _, stacked = self.prepare_stacked_remote_branch() - tree = stacked.controldir.sprout('tree3', stacked=True - ).open_workingtree() - local_tree = tree.branch.create_checkout('local-tree3') - local_tree.commit('more local changes are better') - branch = Branch.open(self.get_url('tree3')) + tree = stacked.controldir.sprout("tree3", stacked=True).open_workingtree() + local_tree = tree.branch.create_checkout("local-tree3") + local_tree.commit("more local changes are better") + branch = Branch.open(self.get_url("tree3")) branch.lock_read() self.addCleanup(branch.unlock) return None, branch + rev_ord, expected_revs = self.get_ordered_revs( - '1.9', 'unordered', branch_factory=make_stacked_stacked) + "1.9", "unordered", branch_factory=make_stacked_stacked + ) self.assertEqual(set(expected_revs), set(rev_ord)) # Getting unordered results should have made a streaming data request # from the server, and one from each backing repo @@ -4141,7 +4708,7 @@ def test_stacked_get_stream_topological(self): # Repository._get_source.get_stream() from a stacked repository with # topological sorting yields the full data from both stacked and # stacked upon sources in topological order. - rev_ord, expected_revs = self.get_ordered_revs('knit', 'topological') + rev_ord, expected_revs = self.get_ordered_revs("knit", "topological") self.assertEqual(expected_revs, rev_ord) # Getting topological sort requires VFS calls still - one of which is # pushing up from the bound branch. @@ -4151,8 +4718,8 @@ def test_stacked_get_stream_groupcompress(self): # Repository._get_source.get_stream() from a stacked repository with # groupcompress sorting yields the full data from both stacked and # stacked upon sources in groupcompress order. - raise tests.TestSkipped('No groupcompress ordered format available') - rev_ord, expected_revs = self.get_ordered_revs('dev5', 'groupcompress') + raise tests.TestSkipped("No groupcompress ordered format available") + rev_ord, expected_revs = self.get_ordered_revs("dev5", "groupcompress") self.assertEqual(expected_revs, reversed(rev_ord)) # Getting unordered results should have made a streaming data request # from the backing branch, and one from the stacked on branch. 
@@ -4165,17 +4732,18 @@ def test_stacked_pull_more_than_stacking_has_bug_360791(self): # Need three branches: a trunk, a stacked branch, and a preexisting # branch pulling content from stacked and trunk. self.setup_smart_server_with_call_log() - trunk = self.make_branch_and_tree('trunk', format="1.9-rich-root") - trunk.commit('start') + trunk = self.make_branch_and_tree("trunk", format="1.9-rich-root") + trunk.commit("start") stacked_branch = trunk.branch.create_clone_on_transport( - self.get_transport('stacked'), stacked_on=trunk.branch.base) - local = self.make_branch('local', format='1.9-rich-root') - local.repository.fetch(stacked_branch.repository, - stacked_branch.last_revision()) + self.get_transport("stacked"), stacked_on=trunk.branch.base + ) + local = self.make_branch("local", format="1.9-rich-root") + local.repository.fetch( + stacked_branch.repository, stacked_branch.last_revision() + ) class TestRemoteBranchEffort(tests.TestCaseWithTransport): - def setUp(self): super().setUp() # Create a smart server that publishes whatever the backing VFS server @@ -4183,43 +4751,48 @@ def setUp(self): self.smart_server = test_server.SmartTCPServer_for_testing() self.start_server(self.smart_server, self.get_server()) # Log all HPSS calls into self.hpss_calls. - _SmartClient.hooks.install_named_hook( - 'call', self.capture_hpss_call, None) + _SmartClient.hooks.install_named_hook("call", self.capture_hpss_call, None) self.hpss_calls = [] def capture_hpss_call(self, params): self.hpss_calls.append(params.method) def test_copy_content_into_avoids_revision_history(self): - local = self.make_branch('local') - builder = self.make_branch_builder('remote') + local = self.make_branch("local") + builder = self.make_branch_builder("remote") builder.build_commit(message="Commit.") - remote_branch_url = self.smart_server.get_url() + 'remote' + remote_branch_url = self.smart_server.get_url() + "remote" remote_branch = bzrdir.BzrDir.open(remote_branch_url).open_branch() local.repository.fetch(remote_branch.repository) self.hpss_calls = [] remote_branch.copy_content_into(local) - self.assertNotIn(b'Branch.revision_history', self.hpss_calls) + self.assertNotIn(b"Branch.revision_history", self.hpss_calls) def test_fetch_everything_needs_just_one_call(self): - local = self.make_branch('local') - builder = self.make_branch_builder('remote') + local = self.make_branch("local") + builder = self.make_branch_builder("remote") builder.build_commit(message="Commit.") - remote_branch_url = self.smart_server.get_url() + 'remote' + remote_branch_url = self.smart_server.get_url() + "remote" remote_branch = bzrdir.BzrDir.open(remote_branch_url).open_branch() self.hpss_calls = [] local.repository.fetch( remote_branch.repository, - fetch_spec=vf_search.EverythingResult(remote_branch.repository)) - self.assertEqual([b'Repository.get_stream_1.19'], self.hpss_calls) + fetch_spec=vf_search.EverythingResult(remote_branch.repository), + ) + self.assertEqual([b"Repository.get_stream_1.19"], self.hpss_calls) def override_verb(self, verb_name, verb): request_handlers = request.request_handlers orig_verb = request_handlers.get(verb_name) orig_info = request_handlers.get_info(verb_name) request_handlers.register(verb_name, verb, override_existing=True) - self.addCleanup(request_handlers.register, verb_name, orig_verb, - override_existing=True, info=orig_info) + self.addCleanup( + request_handlers.register, + verb_name, + orig_verb, + override_existing=True, + info=orig_info, + ) def test_fetch_everything_backwards_compat(self): """Can 
fetch with EverythingResult even with pre 2.4 servers. @@ -4234,24 +4807,25 @@ class OldGetStreamVerb(SmartServerRepositoryGetStream_1_19): reject 'everything' searches the way 2.3 and earlier do. """ - def recreate_search(self, repository, search_bytes, - discard_excess=False): - verb_log.append(search_bytes.split(b'\n', 1)[0]) - if search_bytes == b'everything': - return (None, - request.FailedSmartServerResponse((b'BadSearch',))) - return super().recreate_search(repository, search_bytes, - discard_excess=discard_excess) - self.override_verb(b'Repository.get_stream_1.19', OldGetStreamVerb) - local = self.make_branch('local') - builder = self.make_branch_builder('remote') + def recreate_search(self, repository, search_bytes, discard_excess=False): + verb_log.append(search_bytes.split(b"\n", 1)[0]) + if search_bytes == b"everything": + return (None, request.FailedSmartServerResponse((b"BadSearch",))) + return super().recreate_search( + repository, search_bytes, discard_excess=discard_excess + ) + + self.override_verb(b"Repository.get_stream_1.19", OldGetStreamVerb) + local = self.make_branch("local") + builder = self.make_branch_builder("remote") builder.build_commit(message="Commit.") - remote_branch_url = self.smart_server.get_url() + 'remote' + remote_branch_url = self.smart_server.get_url() + "remote" remote_branch = bzrdir.BzrDir.open(remote_branch_url).open_branch() self.hpss_calls = [] local.repository.fetch( remote_branch.repository, - fetch_spec=vf_search.EverythingResult(remote_branch.repository)) + fetch_spec=vf_search.EverythingResult(remote_branch.repository), + ) # make sure the overridden verb was used self.assertLength(1, verb_log) # more than one HPSS call is needed, but because it's a VFS callback @@ -4259,8 +4833,7 @@ def recreate_search(self, repository, search_bytes, self.assertGreater(len(self.hpss_calls), 1) -class TestUpdateBoundBranchWithModifiedBoundLocation( - tests.TestCaseWithTransport): +class TestUpdateBoundBranchWithModifiedBoundLocation(tests.TestCaseWithTransport): """Ensure correct handling of bound_location modifications. 
This is tested against a smart server as http://pad.lv/786980 was about a @@ -4277,8 +4850,8 @@ def make_master_and_checkout(self, master_name, checkout_name): self.master = self.make_branch_and_tree(master_name) self.checkout = self.master.branch.create_checkout(checkout_name) # Modify the master branch so there is something to update - self.master.commit('add stuff') - self.last_revid = self.master.commit('even more stuff') + self.master.commit("add stuff") + self.last_revid = self.master.commit("even more stuff") self.bound_location = self.checkout.branch.get_bound_location() def assertUpdateSucceeds(self, new_location): @@ -4287,44 +4860,47 @@ def assertUpdateSucceeds(self, new_location): self.assertEqual(self.last_revid, self.checkout.last_revision()) def test_without_final_slash(self): - self.make_master_and_checkout('master', 'checkout') + self.make_master_and_checkout("master", "checkout") # For unclear reasons some users have a bound_location without a final # '/', simulate that by forcing such a value - self.assertEndsWith(self.bound_location, '/') - self.assertUpdateSucceeds(self.bound_location.rstrip('/')) + self.assertEndsWith(self.bound_location, "/") + self.assertUpdateSucceeds(self.bound_location.rstrip("/")) def test_plus_sign(self): - self.make_master_and_checkout('+master', 'checkout') - self.assertUpdateSucceeds(self.bound_location.replace('%2B', '+', 1)) + self.make_master_and_checkout("+master", "checkout") + self.assertUpdateSucceeds(self.bound_location.replace("%2B", "+", 1)) def test_tilda(self): # Embed ~ in the middle of the path just to avoid any $HOME # interpretation - self.make_master_and_checkout('mas~ter', 'checkout') - self.assertUpdateSucceeds(self.bound_location.replace('%2E', '~', 1)) + self.make_master_and_checkout("mas~ter", "checkout") + self.assertUpdateSucceeds(self.bound_location.replace("%2E", "~", 1)) class TestWithCustomErrorHandler(RemoteBranchTestCase): - def test_no_context(self): class OutOfCoffee(errors.BzrError): """A dummy exception for testing.""" def __init__(self, urgency): self.urgency = urgency - remote.no_context_error_translators.register(b"OutOfCoffee", - lambda err: OutOfCoffee(err.error_args[0])) + + remote.no_context_error_translators.register( + b"OutOfCoffee", lambda err: OutOfCoffee(err.error_args[0]) + ) transport = MemoryTransport() client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'quack/',), - b'error', (b'NotStacked',)) + b"Branch.get_stacked_on_url", (b"quack/",), b"error", (b"NotStacked",) + ) client.add_expected_call( - b'Branch.last_revision_info', - (b'quack/',), - b'error', (b'OutOfCoffee', b'low')) - transport.mkdir('quack') - transport = transport.clone('quack') + b"Branch.last_revision_info", + (b"quack/",), + b"error", + (b"OutOfCoffee", b"low"), + ) + transport.mkdir("quack") + transport = transport.clone("quack") branch = self.make_remote_branch(transport, client) self.assertRaises(OutOfCoffee, branch.last_revision_info) self.assertFinished(client) @@ -4334,56 +4910,62 @@ class OutOfTea(errors.BzrError): def __init__(self, branch, urgency): self.branch = branch self.urgency = urgency - remote.error_translators.register(b"OutOfTea", - lambda err, find, path: OutOfTea( - err.error_args[0].decode( - 'utf-8'), - find("branch"))) + + remote.error_translators.register( + b"OutOfTea", + lambda err, find, path: OutOfTea( + err.error_args[0].decode("utf-8"), find("branch") + ), + ) transport = MemoryTransport() client = FakeClient(transport.base) 
client.add_expected_call( - b'Branch.get_stacked_on_url', (b'quack/',), - b'error', (b'NotStacked',)) + b"Branch.get_stacked_on_url", (b"quack/",), b"error", (b"NotStacked",) + ) client.add_expected_call( - b'Branch.last_revision_info', - (b'quack/',), - b'error', (b'OutOfTea', b'low')) - transport.mkdir('quack') - transport = transport.clone('quack') + b"Branch.last_revision_info", (b"quack/",), b"error", (b"OutOfTea", b"low") + ) + transport.mkdir("quack") + transport = transport.clone("quack") branch = self.make_remote_branch(transport, client) self.assertRaises(OutOfTea, branch.last_revision_info) self.assertFinished(client) class TestRepositoryPack(TestRemoteRepository): - def test_pack(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) client.add_expected_call( - b'Repository.lock_write', (b'quack/', b''), - b'success', (b'ok', b'token')) + b"Repository.lock_write", (b"quack/", b""), b"success", (b"ok", b"token") + ) client.add_expected_call( - b'Repository.pack', (b'quack/', b'token', b'False'), - b'success', (b'ok',), ) + b"Repository.pack", + (b"quack/", b"token", b"False"), + b"success", + (b"ok",), + ) client.add_expected_call( - b'Repository.unlock', (b'quack/', b'token'), - b'success', (b'ok', )) + b"Repository.unlock", (b"quack/", b"token"), b"success", (b"ok",) + ) repo.pack() def test_pack_with_hint(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) client.add_expected_call( - b'Repository.lock_write', (b'quack/', b''), - b'success', (b'ok', b'token')) + b"Repository.lock_write", (b"quack/", b""), b"success", (b"ok", b"token") + ) client.add_expected_call( - b'Repository.pack', (b'quack/', b'token', b'False'), - b'success', (b'ok',), ) + b"Repository.pack", + (b"quack/", b"token", b"False"), + b"success", + (b"ok",), + ) client.add_expected_call( - b'Repository.unlock', (b'quack/', b'token', b'False'), - b'success', (b'ok', )) - repo.pack(['hinta', 'hintb']) + b"Repository.unlock", (b"quack/", b"token", b"False"), b"success", (b"ok",) + ) + repo.pack(["hinta", "hintb"]) class TestRepositoryIterInventories(TestRemoteRepository): @@ -4394,38 +4976,56 @@ def _serialize_inv_delta(self, old_name, new_name, delta): return b"".join(serializer.delta_to_lines(old_name, new_name, delta)) def test_single_empty(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) - fmt = controldir.format_registry.get('2a')().repository_format + fmt = controldir.format_registry.get("2a")().repository_format repo._format = fmt - stream = [('inventory-deltas', [ - versionedfile.FulltextContentFactory(b'somerevid', None, None, - self._serialize_inv_delta(b'null:', b'somerevid', inventory_delta.InventoryDelta([])))])] - client.add_expected_call( - b'VersionedFileRepository.get_inventories', ( - b'quack/', b'unordered'), - b'success', (b'ok', ), - _stream_to_byte_stream(stream, fmt)) + stream = [ + ( + "inventory-deltas", + [ + versionedfile.FulltextContentFactory( + b"somerevid", + None, + None, + self._serialize_inv_delta( + b"null:", b"somerevid", inventory_delta.InventoryDelta([]) + ), + ) + ], + ) + ] + client.add_expected_call( + b"VersionedFileRepository.get_inventories", + (b"quack/", b"unordered"), + b"success", + (b"ok",), + _stream_to_byte_stream(stream, fmt), + ) ret = list(repo.iter_inventories([b"somerevid"])) self.assertLength(1, ret) inv = ret[0] 
self.assertEqual(b"somerevid", inv.revision_id) def test_empty(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) ret = list(repo.iter_inventories([])) self.assertEqual(ret, []) def test_missing(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) client.add_expected_call( - b'VersionedFileRepository.get_inventories', ( - b'quack/', b'unordered'), - b'success', (b'ok', ), iter([])) - self.assertRaises(errors.NoSuchRevision, list, repo.iter_inventories( - [b"somerevid"])) + b"VersionedFileRepository.get_inventories", + (b"quack/", b"unordered"), + b"success", + (b"ok",), + iter([]), + ) + self.assertRaises( + errors.NoSuchRevision, list, repo.iter_inventories([b"somerevid"]) + ) class TestRepositoryRevisionTreeArchive(TestRemoteRepository): @@ -4436,35 +5036,50 @@ def _serialize_inv_delta(self, old_name, new_name, delta): return b"".join(serializer.delta_to_lines(old_name, new_name, delta)) def test_simple(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) - fmt = controldir.format_registry.get('2a')().repository_format + fmt = controldir.format_registry.get("2a")().repository_format repo._format = fmt - stream = [('inventory-deltas', [ - versionedfile.FulltextContentFactory(b'somerevid', None, None, - self._serialize_inv_delta(b'null:', b'somerevid', inventory_delta.InventoryDelta([])))])] - client.add_expected_call( - b'VersionedFileRepository.get_inventories', ( - b'quack/', b'unordered'), - b'success', (b'ok', ), - _stream_to_byte_stream(stream, fmt)) + stream = [ + ( + "inventory-deltas", + [ + versionedfile.FulltextContentFactory( + b"somerevid", + None, + None, + self._serialize_inv_delta( + b"null:", b"somerevid", inventory_delta.InventoryDelta([]) + ), + ) + ], + ) + ] + client.add_expected_call( + b"VersionedFileRepository.get_inventories", + (b"quack/", b"unordered"), + b"success", + (b"ok",), + _stream_to_byte_stream(stream, fmt), + ) f = BytesIO() - with tarfile.open(mode='w', fileobj=f) as tf: - info = tarfile.TarInfo('somefile') + with tarfile.open(mode="w", fileobj=f) as tf: + info = tarfile.TarInfo("somefile") info.mtime = 432432 - contents = b'some data' + contents = b"some data" info.type = tarfile.REGTYPE info.mode = 0o644 info.size = len(contents) tf.addfile(info, BytesIO(contents)) client.add_expected_call( - b'Repository.revision_archive', (b'quack/', - b'somerevid', b'tar', b'foo.tar', b'', b'', None), - b'success', (b'ok', ), - f.getvalue()) - tree = repo.revision_tree(b'somerevid') - self.assertEqual(f.getvalue(), b''.join( - tree.archive('tar', 'foo.tar'))) + b"Repository.revision_archive", + (b"quack/", b"somerevid", b"tar", b"foo.tar", b"", b"", None), + b"success", + (b"ok",), + f.getvalue(), + ) + tree = repo.revision_tree(b"somerevid") + self.assertEqual(f.getvalue(), b"".join(tree.archive("tar", "foo.tar"))) class TestRepositoryAnnotate(TestRemoteRepository): @@ -4475,58 +5090,73 @@ def _serialize_inv_delta(self, old_name, new_name, delta): return b"".join(serializer.delta_to_lines(old_name, new_name, delta)) def test_simple(self): - transport_path = 'quack' + transport_path = "quack" repo, client = self.setup_fake_client_and_repository(transport_path) - fmt = controldir.format_registry.get('2a')().repository_format + fmt = controldir.format_registry.get("2a")().repository_format repo._format = fmt stream = [ - 
('inventory-deltas', [ - versionedfile.FulltextContentFactory( - b'somerevid', None, None, - self._serialize_inv_delta(b'null:', b'somerevid', inventory_delta.InventoryDelta([])))])] - client.add_expected_call( - b'VersionedFileRepository.get_inventories', ( - b'quack/', b'unordered'), - b'success', (b'ok', ), - _stream_to_byte_stream(stream, fmt)) - client.add_expected_call( - b'Repository.annotate_file_revision', - (b'quack/', b'somerevid', b'filename', b'', b'current:'), - b'success', (b'ok', ), - bencode.bencode([[b'baserevid', b'line 1\n'], - [b'somerevid', b'line2\n']])) - tree = repo.revision_tree(b'somerevid') - self.assertEqual([ - (b'baserevid', b'line 1\n'), - (b'somerevid', b'line2\n')], - list(tree.annotate_iter('filename'))) + ( + "inventory-deltas", + [ + versionedfile.FulltextContentFactory( + b"somerevid", + None, + None, + self._serialize_inv_delta( + b"null:", b"somerevid", inventory_delta.InventoryDelta([]) + ), + ) + ], + ) + ] + client.add_expected_call( + b"VersionedFileRepository.get_inventories", + (b"quack/", b"unordered"), + b"success", + (b"ok",), + _stream_to_byte_stream(stream, fmt), + ) + client.add_expected_call( + b"Repository.annotate_file_revision", + (b"quack/", b"somerevid", b"filename", b"", b"current:"), + b"success", + (b"ok",), + bencode.bencode([[b"baserevid", b"line 1\n"], [b"somerevid", b"line2\n"]]), + ) + tree = repo.revision_tree(b"somerevid") + self.assertEqual( + [(b"baserevid", b"line 1\n"), (b"somerevid", b"line2\n")], + list(tree.annotate_iter("filename")), + ) class TestBranchGetAllReferenceInfo(RemoteBranchTestCase): - def test_get_all_reference_info(self): transport = MemoryTransport() client = FakeClient(transport.base) client.add_expected_call( - b'Branch.get_stacked_on_url', (b'quack/',), - b'error', (b'NotStacked',)) + b"Branch.get_stacked_on_url", (b"quack/",), b"error", (b"NotStacked",) + ) client.add_expected_call( - b'Branch.get_all_reference_info', (b'quack/',), - b'success', (b'ok',), bencode.bencode([ - (b'file-id', b'https://www.example.com/', b'')])) - transport.mkdir('quack') - transport = transport.clone('quack') + b"Branch.get_all_reference_info", + (b"quack/",), + b"success", + (b"ok",), + bencode.bencode([(b"file-id", b"https://www.example.com/", b"")]), + ) + transport.mkdir("quack") + transport = transport.clone("quack") branch = self.make_remote_branch(transport, client) result = branch._get_all_reference_info() self.assertFinished(client) - self.assertEqual({b'file-id': ('https://www.example.com/', None)}, result) + self.assertEqual({b"file-id": ("https://www.example.com/", None)}, result) class TestErrors(tests.TestCase): - def test_untranslateable_error_from_smart_server(self): - error_tuple = ('error', 'tuple') + error_tuple = ("error", "tuple") orig_err = errors.ErrorFromSmartServer(error_tuple) err = UnknownErrorFromSmartServer(orig_err) self.assertEqual( - "Server sent an unexpected error: ('error', 'tuple')", str(err)) + "Server sent an unexpected error: ('error', 'tuple')", str(err) + ) diff --git a/breezy/bzr/tests/test_repository.py b/breezy/bzr/tests/test_repository.py index f5bd2e4936..e7f5e83d10 100644 --- a/breezy/bzr/tests/test_repository.py +++ b/breezy/bzr/tests/test_repository.py @@ -59,10 +59,9 @@ class TestDefaultFormat(TestCase): - def test_get_set_default_format(self): - old_default = controldir.format_registry.get('default') - old_default_help = controldir.format_registry.get_help('default') + old_default = controldir.format_registry.get("default") + old_default_help = 
controldir.format_registry.get_help("default") private_default = old_default().repository_format.__class__ old_format = repository.format_registry.get_default() self.assertIsInstance(old_format, private_default) @@ -71,23 +70,26 @@ def make_sample_bzrdir(): my_bzrdir = bzrdir.BzrDirMetaFormat1() my_bzrdir.repository_format = SampleRepositoryFormat() return my_bzrdir - controldir.format_registry.remove('default') - controldir.format_registry.register('sample', make_sample_bzrdir, '') - controldir.format_registry.set_default('sample') + + controldir.format_registry.remove("default") + controldir.format_registry.register("sample", make_sample_bzrdir, "") + controldir.format_registry.set_default("sample") # creating a repository should now create an instrumented dir. try: # the default branch format is used by the meta dir format # which is not the default bzrdir format at this point - dir = bzrdir.BzrDirMetaFormat1().initialize('memory:///') + dir = bzrdir.BzrDirMetaFormat1().initialize("memory:///") result = dir.create_repository() - self.assertEqual(result, 'A bzr repository dir') + self.assertEqual(result, "A bzr repository dir") finally: - controldir.format_registry.remove('default') - controldir.format_registry.remove('sample') + controldir.format_registry.remove("default") + controldir.format_registry.remove("sample") controldir.format_registry.register( - 'default', old_default, old_default_help) - self.assertIsInstance(repository.format_registry.get_default(), - old_format.__class__) + "default", old_default, old_default_help + ) + self.assertIsInstance( + repository.format_registry.get_default(), old_format.__class__ + ) class SampleRepositoryFormat(bzrrepository.RepositoryFormatMetaDir): @@ -105,8 +107,8 @@ def get_format_string(cls): def initialize(self, a_controldir, shared=False): """Initialize a repository in a BzrDir.""" t = a_controldir.get_repository_transport(self) - t.put_bytes('format', self.get_format_string()) - return 'A bzr repository dir' + t.put_bytes("format", self.get_format_string()) + return "A bzr repository dir" def is_supported(self): return False @@ -134,51 +136,56 @@ def test_find_format(self): def check_format(format, url): dir = format._matchingcontroldir.initialize(url) format.initialize(dir) - found_format = bzrrepository.RepositoryFormatMetaDir.find_format( - dir) + found_format = bzrrepository.RepositoryFormatMetaDir.find_format(dir) self.assertIsInstance(found_format, format.__class__) + check_format(repository.format_registry.get_default(), "bar") def test_find_format_no_repository(self): dir = bzrdir.BzrDirMetaFormat1().initialize(self.get_url()) - self.assertRaises(errors.NoRepositoryPresent, - bzrrepository.RepositoryFormatMetaDir.find_format, - dir) + self.assertRaises( + errors.NoRepositoryPresent, + bzrrepository.RepositoryFormatMetaDir.find_format, + dir, + ) def test_from_string(self): self.assertIsInstance( - SampleRepositoryFormat.from_string( - b"Sample .bzr repository format."), - SampleRepositoryFormat) - self.assertRaises(AssertionError, - SampleRepositoryFormat.from_string, - b"Different .bzr repository format.") + SampleRepositoryFormat.from_string(b"Sample .bzr repository format."), + SampleRepositoryFormat, + ) + self.assertRaises( + AssertionError, + SampleRepositoryFormat.from_string, + b"Different .bzr repository format.", + ) def test_find_format_unknown_format(self): dir = bzrdir.BzrDirMetaFormat1().initialize(self.get_url()) SampleRepositoryFormat().initialize(dir) - self.assertRaises(UnknownFormatError, - 
bzrrepository.RepositoryFormatMetaDir.find_format, - dir) + self.assertRaises( + UnknownFormatError, bzrrepository.RepositoryFormatMetaDir.find_format, dir + ) def test_find_format_with_features(self): - tree = self.make_branch_and_tree('.', format='2a') + tree = self.make_branch_and_tree(".", format="2a") tree.branch.repository.update_feature_flags({b"name": b"necessity"}) found_format = bzrrepository.RepositoryFormatMetaDir.find_format( - tree.controldir) - self.assertIsInstance( - found_format, bzrrepository.RepositoryFormatMetaDir) + tree.controldir + ) + self.assertIsInstance(found_format, bzrrepository.RepositoryFormatMetaDir) self.assertEqual(found_format.features.get(b"name"), b"necessity") self.assertRaises( - bzrdir.MissingFeature, found_format.check_support_status, True) + bzrdir.MissingFeature, found_format.check_support_status, True + ) self.addCleanup( - bzrrepository.RepositoryFormatMetaDir.unregister_feature, b"name") + bzrrepository.RepositoryFormatMetaDir.unregister_feature, b"name" + ) bzrrepository.RepositoryFormatMetaDir.register_feature(b"name") found_format.check_support_status(True) class TestRepositoryFormatRegistry(TestCase): - def setUp(self): super().setUp() self.registry = repository.RepositoryFormatRegistry() @@ -186,11 +193,11 @@ def setUp(self): def test_register_unregister_format(self): format = SampleRepositoryFormat() self.registry.register(format) - self.assertEqual(format, self.registry.get( - b"Sample .bzr repository format.")) + self.assertEqual(format, self.registry.get(b"Sample .bzr repository format.")) self.registry.remove(format) - self.assertRaises(KeyError, self.registry.get, - b"Sample .bzr repository format.") + self.assertRaises( + KeyError, self.registry.get, b"Sample .bzr repository format." + ) def test_get_all(self): format = SampleRepositoryFormat() @@ -206,25 +213,25 @@ def test_register_extra(self): def test_register_extra_lazy(self): self.assertEqual([], self.registry._get_all()) - self.registry.register_extra_lazy(__name__, - "SampleExtraRepositoryFormat") + self.registry.register_extra_lazy(__name__, "SampleExtraRepositoryFormat") formats = self.registry._get_all() self.assertEqual(1, len(formats)) self.assertIsInstance(formats[0], SampleExtraRepositoryFormat) class TestFormatKnit1(TestCaseWithTransport): - def test_attribute__fetch_order(self): """Knits need topological data insertion.""" repo = self.make_repository( - '.', format=controldir.format_registry.get('knit')()) - self.assertEqual('topological', repo._format._fetch_order) + ".", format=controldir.format_registry.get("knit")() + ) + self.assertEqual("topological", repo._format._fetch_order) def test_attribute__fetch_uses_deltas(self): """Knits reuse deltas.""" repo = self.make_repository( - '.', format=controldir.format_registry.get('knit')()) + ".", format=controldir.format_registry.get("knit")() + ) self.assertEqual(True, repo._format._fetch_uses_deltas) def test_disk_layout(self): @@ -240,33 +247,32 @@ def test_disk_layout(self): # empty revision-store directory # empty weaves directory t = control.get_repository_transport(None) - with t.get('format') as f: - self.assertEqualDiff(b'Bazaar-NG Knit Repository Format 1', - f.read()) + with t.get("format") as f: + self.assertEqualDiff(b"Bazaar-NG Knit Repository Format 1", f.read()) # XXX: no locks left when unlocked at the moment # self.assertEqualDiff('', t.get('lock').read()) - self.assertTrue(S_ISDIR(t.stat('knits').st_mode)) + self.assertTrue(S_ISDIR(t.stat("knits").st_mode)) self.check_knits(t) # Check per-file 
knits. control.create_branch() tree = control.create_workingtree() - tree.add(['foo'], ['file'], [b'Nasty-IdC:']) - tree.put_file_bytes_non_atomic('foo', b'') - tree.commit('1st post', rev_id=b'foo') - self.assertHasKnit(t, 'knits/e8/%254easty-%2549d%2543%253a', - b'\nfoo fulltext 0 81 :') - - def assertHasKnit(self, t, knit_name, extra_content=b''): + tree.add(["foo"], ["file"], [b"Nasty-IdC:"]) + tree.put_file_bytes_non_atomic("foo", b"") + tree.commit("1st post", rev_id=b"foo") + self.assertHasKnit( + t, "knits/e8/%254easty-%2549d%2543%253a", b"\nfoo fulltext 0 81 :" + ) + + def assertHasKnit(self, t, knit_name, extra_content=b""): """Assert that knit_name exists on t.""" - with t.get(knit_name + '.kndx') as f: - self.assertEqualDiff(b'# bzr knit index 8\n' + extra_content, - f.read()) + with t.get(knit_name + ".kndx") as f: + self.assertEqualDiff(b"# bzr knit index 8\n" + extra_content, f.read()) def check_knits(self, t): """Check knit content for a repository.""" - self.assertHasKnit(t, 'inventory') - self.assertHasKnit(t, 'revisions') - self.assertHasKnit(t, 'signatures') + self.assertHasKnit(t, "inventory") + self.assertHasKnit(t, "revisions") + self.assertHasKnit(t, "signatures") def test_shared_disk_layout(self): control = bzrdir.BzrDirMetaFormat1().initialize(self.get_url()) @@ -279,20 +285,18 @@ def test_shared_disk_layout(self): # empty weaves directory # a 'shared-storage' marker file. t = control.get_repository_transport(None) - with t.get('format') as f: - self.assertEqualDiff(b'Bazaar-NG Knit Repository Format 1', - f.read()) + with t.get("format") as f: + self.assertEqualDiff(b"Bazaar-NG Knit Repository Format 1", f.read()) # XXX: no locks left when unlocked at the moment # self.assertEqualDiff('', t.get('lock').read()) - with t.get('shared-storage') as f: - self.assertEqualDiff(b'', f.read()) - self.assertTrue(S_ISDIR(t.stat('knits').st_mode)) + with t.get("shared-storage") as f: + self.assertEqualDiff(b"", f.read()) + self.assertTrue(S_ISDIR(t.stat("knits").st_mode)) self.check_knits(t) def test_shared_no_tree_disk_layout(self): control = bzrdir.BzrDirMetaFormat1().initialize(self.get_url()) - repo = knitrepo.RepositoryFormatKnit1().initialize( - control, shared=True) + repo = knitrepo.RepositoryFormatKnit1().initialize(control, shared=True) repo.set_make_working_trees(False) # we want: # format 'Bazaar-NG Knit Repository Format 1' @@ -302,18 +306,17 @@ def test_shared_no_tree_disk_layout(self): # empty weaves directory # a 'shared-storage' marker file. 
t = control.get_repository_transport(None) - with t.get('format') as f: - self.assertEqualDiff(b'Bazaar-NG Knit Repository Format 1', - f.read()) + with t.get("format") as f: + self.assertEqualDiff(b"Bazaar-NG Knit Repository Format 1", f.read()) # XXX: no locks left when unlocked at the moment # self.assertEqualDiff('', t.get('lock').read()) - with t.get('shared-storage') as f: - self.assertEqualDiff(b'', f.read()) - with t.get('no-working-trees') as f: - self.assertEqualDiff(b'', f.read()) + with t.get("shared-storage") as f: + self.assertEqualDiff(b"", f.read()) + with t.get("no-working-trees") as f: + self.assertEqualDiff(b"", f.read()) repo.set_make_working_trees(True) - self.assertFalse(t.has('no-working-trees')) - self.assertTrue(S_ISDIR(t.stat('knits').st_mode)) + self.assertFalse(t.has("no-working-trees")) + self.assertTrue(S_ISDIR(t.stat("knits").st_mode)) self.check_knits(t) def test_deserialise_sets_root_revision(self): @@ -324,28 +327,33 @@ def test_deserialise_sets_root_revision(self): is valid when the api is not being abused. """ repo = self.make_repository( - '.', format=controldir.format_registry.get('knit')()) + ".", format=controldir.format_registry.get("knit")() + ) inv_xml = b'\n\n' - inv = repo._deserialise_inventory(b'test-rev-id', [inv_xml]) - self.assertEqual(b'test-rev-id', inv.root.revision) + inv = repo._deserialise_inventory(b"test-rev-id", [inv_xml]) + self.assertEqual(b"test-rev-id", inv.root.revision) def test_deserialise_uses_global_revision_id(self): """If it is set, then we re-use the global revision id.""" repo = self.make_repository( - '.', format=controldir.format_registry.get('knit')()) - inv_xml = (b'\n' - b'\n') + ".", format=controldir.format_registry.get("knit")() + ) + inv_xml = ( + b'\n' b"\n" + ) # Arguably, the deserialise_inventory should detect a mismatch, and # raise an error, rather than silently using one revision_id over the # other. - self.assertRaises(AssertionError, repo._deserialise_inventory, - b'test-rev-id', [inv_xml]) - inv = repo._deserialise_inventory(b'other-rev-id', [inv_xml]) - self.assertEqual(b'other-rev-id', inv.root.revision) + self.assertRaises( + AssertionError, repo._deserialise_inventory, b"test-rev-id", [inv_xml] + ) + inv = repo._deserialise_inventory(b"other-rev-id", [inv_xml]) + self.assertEqual(b"other-rev-id", inv.root.revision) def test_supports_external_lookups(self): repo = self.make_repository( - '.', format=controldir.format_registry.get('knit')()) + ".", format=controldir.format_registry.get("knit")() + ) self.assertFalse(repo._format.supports_external_lookups) @@ -379,12 +387,12 @@ class InterDummy(repository.InterRepository): @staticmethod def is_compatible(repo_source, repo_target): """InterDummy is compatible with DummyRepository.""" - return (isinstance(repo_source, DummyRepository) and - isinstance(repo_target, DummyRepository)) + return isinstance(repo_source, DummyRepository) and isinstance( + repo_target, DummyRepository + ) class TestInterRepository(TestCaseWithTransport): - def test_get_default_inter_repository(self): # test that the InterRepository.get(repo_a, repo_b) probes # for a inter_repo class where is_compatible(repo_a, repo_b) returns @@ -409,8 +417,7 @@ def assertGetsDefaultInterRepository(self, repo_a, repo_b): no actual sane default in the presence of incompatible data models. 
""" inter_repo = repository.InterRepository.get(repo_a, repo_b) - self.assertEqual(vf_repository.InterSameDataRepository, - inter_repo.__class__) + self.assertEqual(vf_repository.InterSameDataRepository, inter_repo.__class__) self.assertEqual(repo_a, inter_repo.source) self.assertEqual(repo_b, inter_repo.target) @@ -424,22 +431,22 @@ def test_register_inter_repository_class(self): dummy_a._format = RepositoryFormat() dummy_b = DummyRepository() dummy_b._format = RepositoryFormat() - repo = self.make_repository('.') + repo = self.make_repository(".") # hack dummies to look like repo somewhat. dummy_a._revision_serializer = repo._revision_serializer dummy_a._inventory_serializer = repo._inventory_serializer - dummy_a._format.supports_tree_reference = ( - repo._format.supports_tree_reference) + dummy_a._format.supports_tree_reference = repo._format.supports_tree_reference dummy_a._format.rich_root_data = repo._format.rich_root_data dummy_a._format.supports_full_versioned_files = ( - repo._format.supports_full_versioned_files) + repo._format.supports_full_versioned_files + ) dummy_b._revision_serializer = repo._revision_serializer dummy_b._inventory_serializer = repo._inventory_serializer - dummy_b._format.supports_tree_reference = ( - repo._format.supports_tree_reference) + dummy_b._format.supports_tree_reference = repo._format.supports_tree_reference dummy_b._format.rich_root_data = repo._format.rich_root_data dummy_b._format.supports_full_versioned_files = ( - repo._format.supports_full_versioned_files) + repo._format.supports_full_versioned_files + ) repository.InterRepository.register_optimiser(InterDummy) try: # we should get the default for something InterDummy returns False @@ -459,33 +466,28 @@ def test_register_inter_repository_class(self): class TestRepositoryFormat1(knitrepo.RepositoryFormatKnit1): - @classmethod def get_format_string(cls): return b"Test Format 1" class TestRepositoryFormat2(knitrepo.RepositoryFormatKnit1): - @classmethod def get_format_string(cls): return b"Test Format 2" class TestRepositoryConverter(TestCaseWithTransport): - def test_convert_empty(self): source_format = TestRepositoryFormat1() target_format = TestRepositoryFormat2() repository.format_registry.register(source_format) - self.addCleanup(repository.format_registry.remove, - source_format) + self.addCleanup(repository.format_registry.remove, source_format) repository.format_registry.register(target_format) - self.addCleanup(repository.format_registry.remove, - target_format) + self.addCleanup(repository.format_registry.remove, target_format) t = self.get_transport() - t.mkdir('repository') - repo_dir = bzrdir.BzrDirMetaFormat1().initialize('repository') + t.mkdir("repository") + repo_dir = bzrdir.BzrDirMetaFormat1().initialize("repository") repo = TestRepositoryFormat1().initialize(repo_dir) converter = repository.CopyConverter(target_format) with breezy.ui.ui_factory.nested_progress_bar() as pb: @@ -495,59 +497,56 @@ def test_convert_empty(self): class TestRepositoryFormatKnit3(TestCaseWithTransport): - def test_attribute__fetch_order(self): """Knits need topological data insertion.""" format = bzrdir.BzrDirMetaFormat1() format.repository_format = knitrepo.RepositoryFormatKnit3() - repo = self.make_repository('.', format=format) - self.assertEqual('topological', repo._format._fetch_order) + repo = self.make_repository(".", format=format) + self.assertEqual("topological", repo._format._fetch_order) def test_attribute__fetch_uses_deltas(self): """Knits reuse deltas.""" format = 
bzrdir.BzrDirMetaFormat1() format.repository_format = knitrepo.RepositoryFormatKnit3() - repo = self.make_repository('.', format=format) + repo = self.make_repository(".", format=format) self.assertEqual(True, repo._format._fetch_uses_deltas) def test_convert(self): """Ensure the upgrade adds weaves for roots.""" format = bzrdir.BzrDirMetaFormat1() format.repository_format = knitrepo.RepositoryFormatKnit1() - tree = self.make_branch_and_tree('.', format) + tree = self.make_branch_and_tree(".", format) tree.commit("Dull commit", rev_id=b"dull") - revision_tree = tree.branch.repository.revision_tree(b'dull') + revision_tree = tree.branch.repository.revision_tree(b"dull") with revision_tree.lock_read(): - self.assertRaises( - transport.NoSuchFile, revision_tree.get_file_lines, '') + self.assertRaises(transport.NoSuchFile, revision_tree.get_file_lines, "") format = bzrdir.BzrDirMetaFormat1() format.repository_format = knitrepo.RepositoryFormatKnit3() - upgrade.Convert('.', format) - tree = workingtree.WorkingTree.open('.') - revision_tree = tree.branch.repository.revision_tree(b'dull') + upgrade.Convert(".", format) + tree = workingtree.WorkingTree.open(".") + revision_tree = tree.branch.repository.revision_tree(b"dull") with revision_tree.lock_read(): - revision_tree.get_file_lines('') - tree.commit("Another dull commit", rev_id=b'dull2') - revision_tree = tree.branch.repository.revision_tree(b'dull2') + revision_tree.get_file_lines("") + tree.commit("Another dull commit", rev_id=b"dull2") + revision_tree = tree.branch.repository.revision_tree(b"dull2") revision_tree.lock_read() self.addCleanup(revision_tree.unlock) - self.assertEqual(b'dull', revision_tree.get_file_revision('')) + self.assertEqual(b"dull", revision_tree.get_file_revision("")) def test_supports_external_lookups(self): format = bzrdir.BzrDirMetaFormat1() format.repository_format = knitrepo.RepositoryFormatKnit3() - repo = self.make_repository('.', format=format) + repo = self.make_repository(".", format=format) self.assertFalse(repo._format.supports_external_lookups) class Test2a(tests.TestCaseWithMemoryTransport): - def test_chk_bytes_uses_custom_btree_parser(self): - mt = self.make_branch_and_memory_tree('test', format='2a') + mt = self.make_branch_and_memory_tree("test", format="2a") mt.lock_write() self.addCleanup(mt.unlock) - mt.add([''], [b'root-id']) - mt.commit('first') + mt.add([""], [b"root-id"]) + mt.commit("first") index = mt.branch.repository.chk_bytes._index._graph_index._indices[0] self.assertEqual(btree_index._gcchk_factory, index._leaf_factory) # It should also work if we re-open the repo @@ -558,37 +557,51 @@ def test_chk_bytes_uses_custom_btree_parser(self): self.assertEqual(btree_index._gcchk_factory, index._leaf_factory) def test_fetch_combines_groups(self): - builder = self.make_branch_builder('source', format='2a') + builder = self.make_branch_builder("source", format="2a") builder.start_series() - builder.build_snapshot(None, [ - ('add', ('', b'root-id', 'directory', '')), - ('add', ('file', b'file-id', 'file', b'content\n'))], - revision_id=b'1') - builder.build_snapshot([b'1'], [ - ('modify', ('file', b'content-2\n'))], - revision_id=b'2') + builder.build_snapshot( + None, + [ + ("add", ("", b"root-id", "directory", "")), + ("add", ("file", b"file-id", "file", b"content\n")), + ], + revision_id=b"1", + ) + builder.build_snapshot( + [b"1"], [("modify", ("file", b"content-2\n"))], revision_id=b"2" + ) builder.finish_series() source = builder.get_branch() - target = self.make_repository('target', 
format='2a') + target = self.make_repository("target", format="2a") target.fetch(source.repository) target.lock_read() self.addCleanup(target.unlock) details = target.texts._index.get_build_details( - [(b'file-id', b'1',), (b'file-id', b'2',)]) - file_1_details = details[(b'file-id', b'1')] - file_2_details = details[(b'file-id', b'2')] + [ + ( + b"file-id", + b"1", + ), + ( + b"file-id", + b"2", + ), + ] + ) + file_1_details = details[(b"file-id", b"1")] + file_2_details = details[(b"file-id", b"2")] # The index, and what to read off disk, should be the same for both # versions of the file. self.assertEqual(file_1_details[0][:3], file_2_details[0][:3]) def test_format_pack_compresses_True(self): - repo = self.make_repository('repo', format='2a') + repo = self.make_repository("repo", format="2a") self.assertTrue(repo._format.pack_compresses) def test_inventories_use_chk_map_with_parent_base_dict(self): - tree = self.make_branch_and_memory_tree('repo', format="2a") + tree = self.make_branch_and_memory_tree("repo", format="2a") tree.lock_write() - tree.add([''], ids=[b'TREE_ROOT']) + tree.add([""], ids=[b"TREE_ROOT"]) revid = tree.commit("foo") tree.unlock() tree.lock_read() @@ -599,31 +612,32 @@ def test_inventories_use_chk_map_with_parent_base_dict(self): inv.id_to_entry._ensure_root() self.assertEqual(65536, inv.id_to_entry._root_node.maximum_size) self.assertEqual( - 65536, inv.parent_id_basename_to_file_id._root_node.maximum_size) + 65536, inv.parent_id_basename_to_file_id._root_node.maximum_size + ) def test_autopack_unchanged_chk_nodes(self): # at 20 unchanged commits, chk pages are packed that are split into # two groups such that the new pack being made doesn't have all its # pages in the source packs (though they are in the repository). # Use a memory backed repository, we don't need to hit disk for this - tree = self.make_branch_and_memory_tree('tree', format='2a') + tree = self.make_branch_and_memory_tree("tree", format="2a") tree.lock_write() self.addCleanup(tree.unlock) - tree.add([''], ids=[b'TREE_ROOT']) + tree.add([""], ids=[b"TREE_ROOT"]) for pos in range(20): tree.commit(str(pos)) def test_pack_with_hint(self): - tree = self.make_branch_and_memory_tree('tree', format='2a') + tree = self.make_branch_and_memory_tree("tree", format="2a") tree.lock_write() self.addCleanup(tree.unlock) - tree.add([''], ids=[b'TREE_ROOT']) + tree.add([""], ids=[b"TREE_ROOT"]) # 1 commit to leave untouched - tree.commit('1') + tree.commit("1") to_keep = tree.branch.repository._pack_collection.names() # 2 to combine - tree.commit('2') - tree.commit('3') + tree.commit("2") + tree.commit("3") all = tree.branch.repository._pack_collection.names() combine = list(set(all) - set(to_keep)) self.assertLength(3, all) @@ -636,85 +650,91 @@ def test_pack_with_hint(self): self.assertSubset(to_keep, final) def test_stream_source_to_gc(self): - source = self.make_repository('source', format='2a') - target = self.make_repository('target', format='2a') + source = self.make_repository("source", format="2a") + target = self.make_repository("target", format="2a") stream = source._get_source(target._format) self.assertIsInstance(stream, groupcompress_repo.GroupCHKStreamSource) def test_stream_source_to_non_gc(self): - source = self.make_repository('source', format='2a') - target = self.make_repository('target', format='rich-root-pack') + source = self.make_repository("source", format="2a") + target = self.make_repository("target", format="rich-root-pack") stream = source._get_source(target._format) # We don't want 
the child GroupCHKStreamSource self.assertIs(type(stream), vf_repository.StreamSource) def test_get_stream_for_missing_keys_includes_all_chk_refs(self): - source_builder = self.make_branch_builder('source', - format='2a') + source_builder = self.make_branch_builder("source", format="2a") # We have to build a fairly large tree, so that we are sure the chk # pages will have split into multiple pages. - entries = [('add', ('', b'a-root-id', 'directory', None))] - for i in 'abcdefghijklmnopqrstuvwxyz123456789': - for j in 'abcdefghijklmnopqrstuvwxyz123456789': + entries = [("add", ("", b"a-root-id", "directory", None))] + for i in "abcdefghijklmnopqrstuvwxyz123456789": + for j in "abcdefghijklmnopqrstuvwxyz123456789": fname = i + j - fid = fname.encode('utf-8') + b'-id' - content = b'content for %s\n' % (fname.encode('utf-8'),) - entries.append(('add', (fname, fid, 'file', content))) + fid = fname.encode("utf-8") + b"-id" + content = b"content for %s\n" % (fname.encode("utf-8"),) + entries.append(("add", (fname, fid, "file", content))) source_builder.start_series() - source_builder.build_snapshot(None, entries, revision_id=b'rev-1') + source_builder.build_snapshot(None, entries, revision_id=b"rev-1") # Now change a few of them, so we get a few new pages for the second # revision - source_builder.build_snapshot([b'rev-1'], [ - ('modify', ('aa', b'new content for aa-id\n')), - ('modify', ('cc', b'new content for cc-id\n')), - ('modify', ('zz', b'new content for zz-id\n')), - ], revision_id=b'rev-2') + source_builder.build_snapshot( + [b"rev-1"], + [ + ("modify", ("aa", b"new content for aa-id\n")), + ("modify", ("cc", b"new content for cc-id\n")), + ("modify", ("zz", b"new content for zz-id\n")), + ], + revision_id=b"rev-2", + ) source_builder.finish_series() source_branch = source_builder.get_branch() source_branch.lock_read() self.addCleanup(source_branch.unlock) - target = self.make_repository('target', format='2a') + target = self.make_repository("target", format="2a") source = source_branch.repository._get_source(target._format) self.assertIsInstance(source, groupcompress_repo.GroupCHKStreamSource) # On a regular pass, getting the inventories and chk pages for rev-2 # would only get the newly created chk pages - search = vf_search.SearchResult({b'rev-2'}, {b'rev-1'}, 1, - {b'rev-2'}) + search = vf_search.SearchResult({b"rev-2"}, {b"rev-1"}, 1, {b"rev-2"}) simple_chk_records = set() for vf_name, substream in source.get_stream(search): - if vf_name == 'chk_bytes': + if vf_name == "chk_bytes": for record in substream: simple_chk_records.add(record.key) else: for _ in substream: continue # 3 pages, the root (InternalNode), + 2 pages which actually changed - self.assertEqual({(b'sha1:91481f539e802c76542ea5e4c83ad416bf219f73',), - (b'sha1:4ff91971043668583985aec83f4f0ab10a907d3f',), - (b'sha1:81e7324507c5ca132eedaf2d8414ee4bb2226187',), - (b'sha1:b101b7da280596c71a4540e9a1eeba8045985ee0',)}, - set(simple_chk_records)) + self.assertEqual( + { + (b"sha1:91481f539e802c76542ea5e4c83ad416bf219f73",), + (b"sha1:4ff91971043668583985aec83f4f0ab10a907d3f",), + (b"sha1:81e7324507c5ca132eedaf2d8414ee4bb2226187",), + (b"sha1:b101b7da280596c71a4540e9a1eeba8045985ee0",), + }, + set(simple_chk_records), + ) # Now, when we do a similar call using 'get_stream_for_missing_keys' # we should get a much larger set of pages. 
- missing = [('inventories', b'rev-2')] + missing = [("inventories", b"rev-2")] full_chk_records = set() for vf_name, substream in source.get_stream_for_missing_keys(missing): - if vf_name == 'inventories': + if vf_name == "inventories": for record in substream: - self.assertEqual((b'rev-2',), record.key) - elif vf_name == 'chk_bytes': + self.assertEqual((b"rev-2",), record.key) + elif vf_name == "chk_bytes": for record in substream: full_chk_records.add(record.key) else: - self.fail(f'Should not be getting a stream of {vf_name}') + self.fail(f"Should not be getting a stream of {vf_name}") # We have 257 records now. This is because we have 1 root page, and 256 # leaf pages in a complete listing. self.assertEqual(257, len(full_chk_records)) self.assertSubset(simple_chk_records, full_chk_records) def test_inconsistency_fatal(self): - repo = self.make_repository('repo', format='2a') + repo = self.make_repository("repo", format="2a") self.assertTrue(repo.revisions._index._inconsistency_fatal) self.assertFalse(repo.texts._index._inconsistency_fatal) self.assertFalse(repo.inventories._index._inconsistency_fatal) @@ -723,69 +743,63 @@ def test_inconsistency_fatal(self): class TestKnitPackStreamSource(tests.TestCaseWithMemoryTransport): - def test_source_to_exact_pack_092(self): - source = self.make_repository('source', format='pack-0.92') - target = self.make_repository('target', format='pack-0.92') + source = self.make_repository("source", format="pack-0.92") + target = self.make_repository("target", format="pack-0.92") stream_source = source._get_source(target._format) - self.assertIsInstance( - stream_source, knitpack_repo.KnitPackStreamSource) + self.assertIsInstance(stream_source, knitpack_repo.KnitPackStreamSource) def test_source_to_exact_pack_rich_root_pack(self): - source = self.make_repository('source', format='rich-root-pack') - target = self.make_repository('target', format='rich-root-pack') + source = self.make_repository("source", format="rich-root-pack") + target = self.make_repository("target", format="rich-root-pack") stream_source = source._get_source(target._format) - self.assertIsInstance( - stream_source, knitpack_repo.KnitPackStreamSource) + self.assertIsInstance(stream_source, knitpack_repo.KnitPackStreamSource) def test_source_to_exact_pack_19(self): - source = self.make_repository('source', format='1.9') - target = self.make_repository('target', format='1.9') + source = self.make_repository("source", format="1.9") + target = self.make_repository("target", format="1.9") stream_source = source._get_source(target._format) - self.assertIsInstance( - stream_source, knitpack_repo.KnitPackStreamSource) + self.assertIsInstance(stream_source, knitpack_repo.KnitPackStreamSource) def test_source_to_exact_pack_19_rich_root(self): - source = self.make_repository('source', format='1.9-rich-root') - target = self.make_repository('target', format='1.9-rich-root') + source = self.make_repository("source", format="1.9-rich-root") + target = self.make_repository("target", format="1.9-rich-root") stream_source = source._get_source(target._format) - self.assertIsInstance( - stream_source, knitpack_repo.KnitPackStreamSource) + self.assertIsInstance(stream_source, knitpack_repo.KnitPackStreamSource) def test_source_to_remote_exact_pack_19(self): - trans = self.make_smart_server('target') + trans = self.make_smart_server("target") trans.ensure_base() - source = self.make_repository('source', format='1.9') - target = self.make_repository('target', format='1.9') + source = 
self.make_repository("source", format="1.9") + target = self.make_repository("target", format="1.9") target = repository.Repository.open(trans.base) stream_source = source._get_source(target._format) - self.assertIsInstance( - stream_source, knitpack_repo.KnitPackStreamSource) + self.assertIsInstance(stream_source, knitpack_repo.KnitPackStreamSource) def test_stream_source_to_non_exact(self): - source = self.make_repository('source', format='pack-0.92') - target = self.make_repository('target', format='1.9') + source = self.make_repository("source", format="pack-0.92") + target = self.make_repository("target", format="1.9") stream = source._get_source(target._format) self.assertIs(type(stream), vf_repository.StreamSource) def test_stream_source_to_non_exact_rich_root(self): - source = self.make_repository('source', format='1.9') - target = self.make_repository('target', format='1.9-rich-root') + source = self.make_repository("source", format="1.9") + target = self.make_repository("target", format="1.9-rich-root") stream = source._get_source(target._format) self.assertIs(type(stream), vf_repository.StreamSource) def test_source_to_remote_non_exact_pack_19(self): - trans = self.make_smart_server('target') + trans = self.make_smart_server("target") trans.ensure_base() - source = self.make_repository('source', format='1.9') - target = self.make_repository('target', format='1.6') + source = self.make_repository("source", format="1.9") + target = self.make_repository("target", format="1.6") target = repository.Repository.open(trans.base) stream_source = source._get_source(target._format) self.assertIs(type(stream_source), vf_repository.StreamSource) def test_stream_source_to_knit(self): - source = self.make_repository('source', format='pack-0.92') - target = self.make_repository('target', format='dirstate') + source = self.make_repository("source", format="pack-0.92") + target = self.make_repository("target", format="dirstate") stream = source._get_source(target._format) self.assertIs(type(stream), vf_repository.StreamSource) @@ -795,62 +809,62 @@ class TestDevelopment6FindParentIdsOfRevisions(TestCaseWithTransport): def setUp(self): super().setUp() - self.builder = self.make_branch_builder('source') + self.builder = self.make_branch_builder("source") self.builder.start_series() self.builder.build_snapshot( None, - [('add', ('', b'tree-root', 'directory', None))], - revision_id=b'initial') + [("add", ("", b"tree-root", "directory", None))], + revision_id=b"initial", + ) self.repo = self.builder.get_branch().repository self.addCleanup(self.builder.finish_series) def assertParentIds(self, expected_result, rev_set): self.assertEqual( sorted(expected_result), - sorted(self.repo._find_parent_ids_of_revisions(rev_set))) + sorted(self.repo._find_parent_ids_of_revisions(rev_set)), + ) def test_simple(self): - self.builder.build_snapshot(None, [], revision_id=b'revid1') - self.builder.build_snapshot([b'revid1'], [], revision_id=b'revid2') - rev_set = [b'revid2'] - self.assertParentIds([b'revid1'], rev_set) + self.builder.build_snapshot(None, [], revision_id=b"revid1") + self.builder.build_snapshot([b"revid1"], [], revision_id=b"revid2") + rev_set = [b"revid2"] + self.assertParentIds([b"revid1"], rev_set) def test_not_first_parent(self): - self.builder.build_snapshot(None, [], revision_id=b'revid1') - self.builder.build_snapshot([b'revid1'], [], revision_id=b'revid2') - self.builder.build_snapshot([b'revid2'], [], revision_id=b'revid3') - rev_set = [b'revid3', b'revid2'] - 
self.assertParentIds([b'revid1'], rev_set) + self.builder.build_snapshot(None, [], revision_id=b"revid1") + self.builder.build_snapshot([b"revid1"], [], revision_id=b"revid2") + self.builder.build_snapshot([b"revid2"], [], revision_id=b"revid3") + rev_set = [b"revid3", b"revid2"] + self.assertParentIds([b"revid1"], rev_set) def test_not_null(self): - rev_set = [b'initial'] + rev_set = [b"initial"] self.assertParentIds([], rev_set) def test_not_null_set(self): - self.builder.build_snapshot(None, [], revision_id=b'revid1') + self.builder.build_snapshot(None, [], revision_id=b"revid1") rev_set = [_mod_revision.NULL_REVISION] self.assertParentIds([], rev_set) def test_ghost(self): - self.builder.build_snapshot(None, [], revision_id=b'revid1') - rev_set = [b'ghost', b'revid1'] - self.assertParentIds([b'initial'], rev_set) + self.builder.build_snapshot(None, [], revision_id=b"revid1") + rev_set = [b"ghost", b"revid1"] + self.assertParentIds([b"initial"], rev_set) def test_ghost_parent(self): - self.builder.build_snapshot(None, [], revision_id=b'revid1') - self.builder.build_snapshot( - [b'revid1', b'ghost'], [], revision_id=b'revid2') - rev_set = [b'revid2', b'revid1'] - self.assertParentIds([b'ghost', b'initial'], rev_set) + self.builder.build_snapshot(None, [], revision_id=b"revid1") + self.builder.build_snapshot([b"revid1", b"ghost"], [], revision_id=b"revid2") + rev_set = [b"revid2", b"revid1"] + self.assertParentIds([b"ghost", b"initial"], rev_set) def test_righthand_parent(self): - self.builder.build_snapshot(None, [], revision_id=b'revid1') - self.builder.build_snapshot([b'revid1'], [], revision_id=b'revid2a') - self.builder.build_snapshot([b'revid1'], [], revision_id=b'revid2b') - self.builder.build_snapshot([b'revid2a', b'revid2b'], [], - revision_id=b'revid3') - rev_set = [b'revid3', b'revid2a'] - self.assertParentIds([b'revid1', b'revid2b'], rev_set) + self.builder.build_snapshot(None, [], revision_id=b"revid1") + self.builder.build_snapshot([b"revid1"], [], revision_id=b"revid2a") + self.builder.build_snapshot([b"revid1"], [], revision_id=b"revid2b") + self.builder.build_snapshot([b"revid2a", b"revid2b"], [], revision_id=b"revid3") + rev_set = [b"revid3", b"revid2a"] + self.assertParentIds([b"revid1", b"revid2b"], rev_set) class TestWithBrokenRepo(TestCaseWithTransport): @@ -860,7 +874,7 @@ def make_broken_repository(self): # XXX: This function is borrowed from Aaron's "Reconcile can fix bad # parent references" branch which is due to land in bzr.dev soon. Once # it does, this duplication should be removed. 
- repo = self.make_repository('broken-repo') + repo = self.make_repository("broken-repo") cleanups = [] try: repo.lock_write() @@ -868,43 +882,49 @@ def make_broken_repository(self): repo.start_write_group() cleanups.append(repo.commit_write_group) # make rev1a: A well-formed revision, containing 'file1' - inv = inventory.Inventory(revision_id=b'rev1a') - inv.root.revision = b'rev1a' - self.add_file(repo, inv, 'file1', b'rev1a', []) - repo.texts.add_lines((inv.root.file_id, b'rev1a'), [], []) - repo.add_inventory(b'rev1a', inv, []) + inv = inventory.Inventory(revision_id=b"rev1a") + inv.root.revision = b"rev1a" + self.add_file(repo, inv, "file1", b"rev1a", []) + repo.texts.add_lines((inv.root.file_id, b"rev1a"), [], []) + repo.add_inventory(b"rev1a", inv, []) revision = _mod_revision.Revision( - b'rev1a', properties={}, - committer='jrandom@example.com', timestamp=0, - inventory_sha1=None, timezone=0, message='foo', parent_ids=[]) - repo.add_revision(b'rev1a', revision, inv) + b"rev1a", + properties={}, + committer="jrandom@example.com", + timestamp=0, + inventory_sha1=None, + timezone=0, + message="foo", + parent_ids=[], + ) + repo.add_revision(b"rev1a", revision, inv) # make rev1b, which has no Revision, but has an Inventory, and # file1 - inv = inventory.Inventory(revision_id=b'rev1b') - inv.root.revision = b'rev1b' - self.add_file(repo, inv, 'file1', b'rev1b', []) - repo.add_inventory(b'rev1b', inv, []) + inv = inventory.Inventory(revision_id=b"rev1b") + inv.root.revision = b"rev1b" + self.add_file(repo, inv, "file1", b"rev1b", []) + repo.add_inventory(b"rev1b", inv, []) # make rev2, with file1 and file2 # file2 is sane # file1 has 'rev1b' as an ancestor, even though this is not # mentioned by 'rev1a', making it an unreferenced ancestor inv = inventory.Inventory() - self.add_file(repo, inv, 'file1', b'rev2', [b'rev1a', b'rev1b']) - self.add_file(repo, inv, 'file2', b'rev2', []) - self.add_revision(repo, b'rev2', inv, [b'rev1a']) + self.add_file(repo, inv, "file1", b"rev2", [b"rev1a", b"rev1b"]) + self.add_file(repo, inv, "file2", b"rev2", []) + self.add_revision(repo, b"rev2", inv, [b"rev1a"]) # make ghost revision rev1c inv = inventory.Inventory() - self.add_file(repo, inv, 'file2', b'rev1c', []) + self.add_file(repo, inv, "file2", b"rev1c", []) # make rev3 with file2 # file2 refers to 'rev1c', which is a ghost in this repository, so # file2 cannot have rev1c as its ancestor. 
inv = inventory.Inventory() - self.add_file(repo, inv, 'file2', b'rev3', [b'rev1c']) - self.add_revision(repo, b'rev3', inv, [b'rev1c']) + self.add_file(repo, inv, "file2", b"rev3", [b"rev1c"]) + self.add_revision(repo, b"rev3", inv, [b"rev1c"]) return repo finally: for cleanup in reversed(cleanups): @@ -917,14 +937,20 @@ def add_revision(self, repo, revision_id, inv, parent_ids): repo.add_inventory(revision_id, inv, parent_ids) revision = _mod_revision.Revision( revision_id, - committer='jrandom@example.com', timestamp=0, inventory_sha1=None, - timezone=0, message='foo', parent_ids=parent_ids, properties={}) + committer="jrandom@example.com", + timestamp=0, + inventory_sha1=None, + timezone=0, + message="foo", + parent_ids=parent_ids, + properties={}, + ) repo.add_revision(revision_id, revision, inv) def add_file(self, repo, inv, filename, revision, parents): - file_id = filename.encode('utf-8') + b'-id' - content = [b'line\n'] - entry = inventory.InventoryFile(file_id, filename, b'TREE_ROOT') + file_id = filename.encode("utf-8") + b"-id" + content = [b"line\n"] + entry = inventory.InventoryFile(file_id, filename, b"TREE_ROOT") entry.revision = revision entry.text_sha1 = osutils.sha_strings(content) entry.text_size = 0 @@ -938,7 +964,7 @@ def test_insert_from_broken_repo(self): corrupt the target repository. """ broken_repo = self.make_broken_repository() - empty_repo = self.make_repository('empty-repo') + empty_repo = self.make_repository("empty-repo") try: empty_repo.fetch(broken_repo) except (errors.RevisionNotPresent, errors.BzrCheckError): @@ -947,30 +973,32 @@ def test_insert_from_broken_repo(self): return empty_repo.lock_read() self.addCleanup(empty_repo.unlock) - text = next(empty_repo.texts.get_record_stream( - [(b'file2-id', b'rev3')], 'topological', True)) - self.assertEqual(b'line\n', text.get_bytes_as('fulltext')) + text = next( + empty_repo.texts.get_record_stream( + [(b"file2-id", b"rev3")], "topological", True + ) + ) + self.assertEqual(b"line\n", text.get_bytes_as("fulltext")) class TestRepositoryPackCollection(TestCaseWithTransport): - def get_format(self): - return controldir.format_registry.make_controldir('pack-0.92') + return controldir.format_registry.make_controldir("pack-0.92") def get_packs(self): format = self.get_format() - repo = self.make_repository('.', format=format) + repo = self.make_repository(".", format=format) return repo._pack_collection def make_packs_and_alt_repo(self, write_lock=False): """Create a pack repo with 3 packs, and access it via a second repo.""" - tree = self.make_branch_and_tree('.', format=self.get_format()) + tree = self.make_branch_and_tree(".", format=self.get_format()) tree.lock_write() self.addCleanup(tree.unlock) - rev1 = tree.commit('one') - rev2 = tree.commit('two') - rev3 = tree.commit('three') - r = repository.Repository.open('.') + rev1 = tree.commit("one") + rev2 = tree.commit("two") + rev3 = tree.commit("three") + r = repository.Repository.open(".") if write_lock: r.lock_write() else: @@ -982,28 +1010,30 @@ def make_packs_and_alt_repo(self, write_lock=False): def test__clear_obsolete_packs(self): packs = self.get_packs() - obsolete_pack_trans = packs.transport.clone('obsolete_packs') - obsolete_pack_trans.put_bytes('a-pack.pack', b'content\n') - obsolete_pack_trans.put_bytes('a-pack.rix', b'content\n') - obsolete_pack_trans.put_bytes('a-pack.iix', b'content\n') - obsolete_pack_trans.put_bytes('another-pack.pack', b'foo\n') - obsolete_pack_trans.put_bytes('not-a-pack.rix', b'foo\n') + obsolete_pack_trans = 
packs.transport.clone("obsolete_packs") + obsolete_pack_trans.put_bytes("a-pack.pack", b"content\n") + obsolete_pack_trans.put_bytes("a-pack.rix", b"content\n") + obsolete_pack_trans.put_bytes("a-pack.iix", b"content\n") + obsolete_pack_trans.put_bytes("another-pack.pack", b"foo\n") + obsolete_pack_trans.put_bytes("not-a-pack.rix", b"foo\n") res = packs._clear_obsolete_packs() - self.assertEqual(['a-pack', 'another-pack'], sorted(res)) - self.assertEqual([], obsolete_pack_trans.list_dir('.')) + self.assertEqual(["a-pack", "another-pack"], sorted(res)) + self.assertEqual([], obsolete_pack_trans.list_dir(".")) def test__clear_obsolete_packs_preserve(self): packs = self.get_packs() - obsolete_pack_trans = packs.transport.clone('obsolete_packs') - obsolete_pack_trans.put_bytes('a-pack.pack', b'content\n') - obsolete_pack_trans.put_bytes('a-pack.rix', b'content\n') - obsolete_pack_trans.put_bytes('a-pack.iix', b'content\n') - obsolete_pack_trans.put_bytes('another-pack.pack', b'foo\n') - obsolete_pack_trans.put_bytes('not-a-pack.rix', b'foo\n') - res = packs._clear_obsolete_packs(preserve={'a-pack'}) - self.assertEqual(['a-pack', 'another-pack'], sorted(res)) - self.assertEqual(['a-pack.iix', 'a-pack.pack', 'a-pack.rix'], - sorted(obsolete_pack_trans.list_dir('.'))) + obsolete_pack_trans = packs.transport.clone("obsolete_packs") + obsolete_pack_trans.put_bytes("a-pack.pack", b"content\n") + obsolete_pack_trans.put_bytes("a-pack.rix", b"content\n") + obsolete_pack_trans.put_bytes("a-pack.iix", b"content\n") + obsolete_pack_trans.put_bytes("another-pack.pack", b"foo\n") + obsolete_pack_trans.put_bytes("not-a-pack.rix", b"foo\n") + res = packs._clear_obsolete_packs(preserve={"a-pack"}) + self.assertEqual(["a-pack", "another-pack"], sorted(res)) + self.assertEqual( + ["a-pack.iix", "a-pack.pack", "a-pack.rix"], + sorted(obsolete_pack_trans.list_dir(".")), + ) def test__max_pack_count(self): """The maximum pack count is a function of the number of revisions.""" @@ -1032,8 +1062,7 @@ def test__max_pack_count(self): def test_repr(self): packs = self.get_packs() - self.assertContainsRe(repr(packs), - 'RepositoryPackCollection(.*Repository(.*))') + self.assertContainsRe(repr(packs), "RepositoryPackCollection(.*Repository(.*))") def test__obsolete_packs(self): tree, r, packs, revs = self.make_packs_and_alt_repo(write_lock=True) @@ -1043,23 +1072,30 @@ def test__obsolete_packs(self): packs._remove_pack_from_memory(pack) # Simulate a concurrent update by renaming the .pack file and one of # the indices - packs.transport.rename(f'packs/{names[0]}.pack', - f'obsolete_packs/{names[0]}.pack') - packs.transport.rename(f'indices/{names[0]}.iix', - f'obsolete_packs/{names[0]}.iix') + packs.transport.rename( + f"packs/{names[0]}.pack", f"obsolete_packs/{names[0]}.pack" + ) + packs.transport.rename( + f"indices/{names[0]}.iix", f"obsolete_packs/{names[0]}.iix" + ) # Now trigger the obsoletion, and ensure that all the remaining files # are still renamed packs._obsolete_packs([pack]) - self.assertEqual([n + '.pack' for n in names[1:]], - sorted(packs._pack_transport.list_dir('.'))) + self.assertEqual( + [n + ".pack" for n in names[1:]], + sorted(packs._pack_transport.list_dir(".")), + ) # names[0] should not be present in the index anymore - self.assertEqual(names[1:], - sorted({osutils.splitext(n)[0] for n in - packs._index_transport.list_dir('.')})) + self.assertEqual( + names[1:], + sorted( + {osutils.splitext(n)[0] for n in packs._index_transport.list_dir(".")} + ), + ) def 
test__obsolete_packs_missing_directory(self): tree, r, packs, revs = self.make_packs_and_alt_repo(write_lock=True) - r.control_transport.rmdir('obsolete_packs') + r.control_transport.rmdir("obsolete_packs") names = packs.names() pack = packs.get_pack_by_name(names[0]) # Schedule this one for removal @@ -1067,12 +1103,17 @@ def test__obsolete_packs_missing_directory(self): # Now trigger the obsoletion, and ensure that all the remaining files # are still renamed packs._obsolete_packs([pack]) - self.assertEqual([n + '.pack' for n in names[1:]], - sorted(packs._pack_transport.list_dir('.'))) + self.assertEqual( + [n + ".pack" for n in names[1:]], + sorted(packs._pack_transport.list_dir(".")), + ) # names[0] should not be present in the index anymore - self.assertEqual(names[1:], - sorted({osutils.splitext(n)[0] for n in - packs._index_transport.list_dir('.')})) + self.assertEqual( + names[1:], + sorted( + {osutils.splitext(n)[0] for n in packs._index_transport.list_dir(".")} + ), + ) def test_pack_distribution_zero(self): packs = self.get_packs() @@ -1080,29 +1121,19 @@ def test_pack_distribution_zero(self): def test_ensure_loaded_unlocked(self): packs = self.get_packs() - self.assertRaises(errors.ObjectNotLocked, - packs.ensure_loaded) + self.assertRaises(errors.ObjectNotLocked, packs.ensure_loaded) def test_pack_distribution_one_to_nine(self): packs = self.get_packs() - self.assertEqual([1], - packs.pack_distribution(1)) - self.assertEqual([1, 1], - packs.pack_distribution(2)) - self.assertEqual([1, 1, 1], - packs.pack_distribution(3)) - self.assertEqual([1, 1, 1, 1], - packs.pack_distribution(4)) - self.assertEqual([1, 1, 1, 1, 1], - packs.pack_distribution(5)) - self.assertEqual([1, 1, 1, 1, 1, 1], - packs.pack_distribution(6)) - self.assertEqual([1, 1, 1, 1, 1, 1, 1], - packs.pack_distribution(7)) - self.assertEqual([1, 1, 1, 1, 1, 1, 1, 1], - packs.pack_distribution(8)) - self.assertEqual([1, 1, 1, 1, 1, 1, 1, 1, 1], - packs.pack_distribution(9)) + self.assertEqual([1], packs.pack_distribution(1)) + self.assertEqual([1, 1], packs.pack_distribution(2)) + self.assertEqual([1, 1, 1], packs.pack_distribution(3)) + self.assertEqual([1, 1, 1, 1], packs.pack_distribution(4)) + self.assertEqual([1, 1, 1, 1, 1], packs.pack_distribution(5)) + self.assertEqual([1, 1, 1, 1, 1, 1], packs.pack_distribution(6)) + self.assertEqual([1, 1, 1, 1, 1, 1, 1], packs.pack_distribution(7)) + self.assertEqual([1, 1, 1, 1, 1, 1, 1, 1], packs.pack_distribution(8)) + self.assertEqual([1, 1, 1, 1, 1, 1, 1, 1, 1], packs.pack_distribution(9)) def test_pack_distribution_stable_at_boundaries(self): """When there are multi-rev packs the counts are stable.""" @@ -1125,7 +1156,8 @@ def test_plan_pack_operations_2009_revisions_skip_all_packs(self): existing_packs = [(2000, "big"), (9, "medium")] # rev count - 2009 -> 2x1000 + 9x1 pack_operations = packs.plan_autopack_combinations( - existing_packs, [1000, 1000, 1, 1, 1, 1, 1, 1, 1, 1, 1]) + existing_packs, [1000, 1000, 1, 1, 1, 1, 1, 1, 1, 1, 1] + ) self.assertEqual([], pack_operations) def test_plan_pack_operations_2010_revisions_skip_all_packs(self): @@ -1133,22 +1165,30 @@ def test_plan_pack_operations_2010_revisions_skip_all_packs(self): existing_packs = [(2000, "big"), (9, "medium"), (1, "single")] # rev count - 2010 -> 2x1000 + 1x10 pack_operations = packs.plan_autopack_combinations( - existing_packs, [1000, 1000, 10]) + existing_packs, [1000, 1000, 10] + ) self.assertEqual([], pack_operations) def test_plan_pack_operations_2010_combines_smallest_two(self): packs = 
self.get_packs() - existing_packs = [(1999, "big"), (9, "medium"), (1, "single2"), - (1, "single1")] + existing_packs = [(1999, "big"), (9, "medium"), (1, "single2"), (1, "single1")] # rev count - 2010 -> 2x1000 + 1x10 (3) pack_operations = packs.plan_autopack_combinations( - existing_packs, [1000, 1000, 10]) + existing_packs, [1000, 1000, 10] + ) self.assertEqual([[2, ["single2", "single1"]]], pack_operations) def test_plan_pack_operations_creates_a_single_op(self): packs = self.get_packs() - existing_packs = [(50, 'a'), (40, 'b'), (30, 'c'), (10, 'd'), - (10, 'e'), (6, 'f'), (4, 'g')] + existing_packs = [ + (50, "a"), + (40, "b"), + (30, "c"), + (10, "d"), + (10, "e"), + (6, "f"), + (4, "g"), + ] # rev count 150 -> 1x100 and 5x10 # The two size 10 packs do not need to be touched. The 50, 40, 30 would # be combined into a single 120 size pack, and the 6 & 4 would @@ -1156,13 +1196,12 @@ def test_plan_pack_operations_creates_a_single_op(self): # we save a pack file with no increased I/O by putting them into the # same file. distribution = packs.pack_distribution(150) - pack_operations = packs.plan_autopack_combinations(existing_packs, - distribution) - self.assertEqual([[130, ['a', 'b', 'c', 'f', 'g']]], pack_operations) + pack_operations = packs.plan_autopack_combinations(existing_packs, distribution) + self.assertEqual([[130, ["a", "b", "c", "f", "g"]]], pack_operations) def test_all_packs_none(self): format = self.get_format() - tree = self.make_branch_and_tree('.', format=format) + tree = self.make_branch_and_tree(".", format=format) tree.lock_read() self.addCleanup(tree.unlock) packs = tree.branch.repository._pack_collection @@ -1171,34 +1210,35 @@ def test_all_packs_none(self): def test_all_packs_one(self): format = self.get_format() - tree = self.make_branch_and_tree('.', format=format) - tree.commit('start') + tree = self.make_branch_and_tree(".", format=format) + tree.commit("start") tree.lock_read() self.addCleanup(tree.unlock) packs = tree.branch.repository._pack_collection packs.ensure_loaded() - self.assertEqual([ - packs.get_pack_by_name(packs.names()[0])], - packs.all_packs()) + self.assertEqual([packs.get_pack_by_name(packs.names()[0])], packs.all_packs()) def test_all_packs_two(self): format = self.get_format() - tree = self.make_branch_and_tree('.', format=format) - tree.commit('start') - tree.commit('continue') + tree = self.make_branch_and_tree(".", format=format) + tree.commit("start") + tree.commit("continue") tree.lock_read() self.addCleanup(tree.unlock) packs = tree.branch.repository._pack_collection packs.ensure_loaded() - self.assertEqual([ - packs.get_pack_by_name(packs.names()[0]), - packs.get_pack_by_name(packs.names()[1]), - ], packs.all_packs()) + self.assertEqual( + [ + packs.get_pack_by_name(packs.names()[0]), + packs.get_pack_by_name(packs.names()[1]), + ], + packs.all_packs(), + ) def test_get_pack_by_name(self): format = self.get_format() - tree = self.make_branch_and_tree('.', format=format) - tree.commit('start') + tree = self.make_branch_and_tree(".", format=format) + tree.commit("start") tree.lock_read() self.addCleanup(tree.unlock) packs = tree.branch.repository._pack_collection @@ -1208,14 +1248,16 @@ def test_get_pack_by_name(self): pack_1 = packs.get_pack_by_name(name) # the pack should be correctly initialised sizes = packs._names[name] - rev_index = GraphIndex(packs._index_transport, name + '.rix', sizes[0]) - inv_index = GraphIndex(packs._index_transport, name + '.iix', sizes[1]) - txt_index = GraphIndex(packs._index_transport, name + '.tix', 
sizes[2]) - sig_index = GraphIndex(packs._index_transport, name + '.six', sizes[3]) + rev_index = GraphIndex(packs._index_transport, name + ".rix", sizes[0]) + inv_index = GraphIndex(packs._index_transport, name + ".iix", sizes[1]) + txt_index = GraphIndex(packs._index_transport, name + ".tix", sizes[2]) + sig_index = GraphIndex(packs._index_transport, name + ".six", sizes[3]) self.assertEqual( pack_repo.ExistingPack( - packs._pack_transport, name, rev_index, inv_index, txt_index, - sig_index), pack_1) + packs._pack_transport, name, rev_index, inv_index, txt_index, sig_index + ), + pack_1, + ) # and the same instance should be returned on successive calls. self.assertIs(pack_1, packs.get_pack_by_name(name)) @@ -1223,7 +1265,7 @@ def test_reload_pack_names_new_entry(self): tree, r, packs, revs = self.make_packs_and_alt_repo() names = packs.names() # Add a new pack file into the repository - rev4 = tree.commit('four') + rev4 = tree.commit("four") new_names = tree.branch.repository._pack_collection.names() new_name = set(new_names).difference(names) self.assertEqual(1, len(new_name)) @@ -1259,8 +1301,13 @@ def test_reload_pack_names_preserves_pending(self): to_remove_name = next(iter(orig_names)) r.start_write_group() self.addCleanup(r.abort_write_group) - r.texts.insert_record_stream([versionedfile.FulltextContentFactory( - (b'text', b'rev'), (), None, b'content\n')]) + r.texts.insert_record_stream( + [ + versionedfile.FulltextContentFactory( + (b"text", b"rev"), (), None, b"content\n" + ) + ] + ) new_pack = packs._new_pack self.assertTrue(new_pack.data_inserted()) new_pack.finish() @@ -1274,8 +1321,7 @@ def test_reload_pack_names_preserves_pending(self): self.assertEqual(names, sorted([x[0] for x in all_nodes])) self.assertEqual(set(names) - set(orig_names), new_names) self.assertEqual({new_pack.name}, new_names) - self.assertEqual([to_remove_name], - sorted([x[0] for x in deleted_nodes])) + self.assertEqual([to_remove_name], sorted([x[0] for x in deleted_nodes])) packs.reload_pack_names() reloaded_names = packs.names() self.assertEqual(orig_at_load, packs._packs_at_load) @@ -1285,23 +1331,26 @@ def test_reload_pack_names_preserves_pending(self): self.assertEqual(names, sorted([x[0] for x in all_nodes])) self.assertEqual(set(names) - set(orig_names), new_names) self.assertEqual({new_pack.name}, new_names) - self.assertEqual([to_remove_name], - sorted([x[0] for x in deleted_nodes])) + self.assertEqual([to_remove_name], sorted([x[0] for x in deleted_nodes])) def test_autopack_obsoletes_new_pack(self): tree, r, packs, revs = self.make_packs_and_alt_repo(write_lock=True) packs._max_pack_count = lambda x: 1 packs.pack_distribution = lambda x: [10] r.start_write_group() - r.revisions.insert_record_stream([versionedfile.FulltextContentFactory( - (b'bogus-rev',), (), None, b'bogus-content\n')]) + r.revisions.insert_record_stream( + [ + versionedfile.FulltextContentFactory( + (b"bogus-rev",), (), None, b"bogus-content\n" + ) + ] + ) # This should trigger an autopack, which will combine everything into a # single pack file. 
r.commit_write_group() names = packs.names() self.assertEqual(1, len(names)) - self.assertEqual([names[0] + '.pack'], - packs._pack_transport.list_dir('.')) + self.assertEqual([names[0] + ".pack"], packs._pack_transport.list_dir(".")) def test_autopack_reloads_and_stops(self): tree, r, packs, revs = self.make_packs_and_alt_repo(write_lock=True) @@ -1313,13 +1362,13 @@ def test_autopack_reloads_and_stops(self): def _munged_execute_pack_ops(*args, **kwargs): tree.branch.repository.pack() return orig_execute(*args, **kwargs) + packs._execute_pack_operations = _munged_execute_pack_ops packs._max_pack_count = lambda x: 1 packs.pack_distribution = lambda x: [10] self.assertFalse(packs.autopack()) self.assertEqual(1, len(packs.names())) - self.assertEqual(tree.branch.repository._pack_collection.names(), - packs.names()) + self.assertEqual(tree.branch.repository._pack_collection.names(), packs.names()) def test__save_pack_names(self): tree, r, packs, revs = self.make_packs_and_alt_repo(write_lock=True) @@ -1327,10 +1376,10 @@ def test__save_pack_names(self): pack = packs.get_pack_by_name(names[0]) packs._remove_pack_from_memory(pack) packs._save_pack_names(obsolete_packs=[pack]) - cur_packs = packs._pack_transport.list_dir('.') - self.assertEqual([n + '.pack' for n in names[1:]], sorted(cur_packs)) + cur_packs = packs._pack_transport.list_dir(".") + self.assertEqual([n + ".pack" for n in names[1:]], sorted(cur_packs)) # obsolete_packs will also have stuff like .rix and .iix present. - obsolete_packs = packs.transport.list_dir('obsolete_packs') + obsolete_packs = packs.transport.list_dir("obsolete_packs") obsolete_names = {osutils.splitext(n)[0] for n in obsolete_packs} self.assertEqual([pack.name], sorted(obsolete_names)) @@ -1342,13 +1391,12 @@ def test__save_pack_names_already_obsoleted(self): # We are going to simulate a concurrent autopack by manually obsoleting # the pack directly. packs._obsolete_packs([pack]) - packs._save_pack_names(clear_obsolete_packs=True, - obsolete_packs=[pack]) - cur_packs = packs._pack_transport.list_dir('.') - self.assertEqual([n + '.pack' for n in names[1:]], sorted(cur_packs)) + packs._save_pack_names(clear_obsolete_packs=True, obsolete_packs=[pack]) + cur_packs = packs._pack_transport.list_dir(".") + self.assertEqual([n + ".pack" for n in names[1:]], sorted(cur_packs)) # Note that while we set clear_obsolete_packs=True, it should not # delete a pack file that we have also scheduled for obsoletion. - obsolete_packs = packs.transport.list_dir('obsolete_packs') + obsolete_packs = packs.transport.list_dir("obsolete_packs") obsolete_names = {osutils.splitext(n)[0] for n in obsolete_packs} self.assertEqual([pack.name], sorted(obsolete_names)) @@ -1357,7 +1405,7 @@ def test_pack_no_obsolete_packs_directory(self): not exist. """ tree, r, packs, revs = self.make_packs_and_alt_repo(write_lock=True) - r.control_transport.rmdir('obsolete_packs') + r.control_transport.rmdir("obsolete_packs") packs._clear_obsolete_packs() @@ -1377,57 +1425,58 @@ def assertCurrentlyNotEqual(self, left, right): self.assertNotEqual(right, left) def test___eq____ne__(self): - left = pack_repo.ExistingPack('', '', '', '', '', '') - right = pack_repo.ExistingPack('', '', '', '', '', '') + left = pack_repo.ExistingPack("", "", "", "", "", "") + right = pack_repo.ExistingPack("", "", "", "", "", "") self.assertCurrentlyEqual(left, right) # change all attributes and ensure equality changes as we do. 
- left.revision_index = 'a' + left.revision_index = "a" self.assertCurrentlyNotEqual(left, right) - right.revision_index = 'a' + right.revision_index = "a" self.assertCurrentlyEqual(left, right) - left.inventory_index = 'a' + left.inventory_index = "a" self.assertCurrentlyNotEqual(left, right) - right.inventory_index = 'a' + right.inventory_index = "a" self.assertCurrentlyEqual(left, right) - left.text_index = 'a' + left.text_index = "a" self.assertCurrentlyNotEqual(left, right) - right.text_index = 'a' + right.text_index = "a" self.assertCurrentlyEqual(left, right) - left.signature_index = 'a' + left.signature_index = "a" self.assertCurrentlyNotEqual(left, right) - right.signature_index = 'a' + right.signature_index = "a" self.assertCurrentlyEqual(left, right) - left.name = 'a' + left.name = "a" self.assertCurrentlyNotEqual(left, right) - right.name = 'a' + right.name = "a" self.assertCurrentlyEqual(left, right) - left.transport = 'a' + left.transport = "a" self.assertCurrentlyNotEqual(left, right) - right.transport = 'a' + right.transport = "a" self.assertCurrentlyEqual(left, right) def test_file_name(self): - pack = pack_repo.ExistingPack('', 'a_name', '', '', '', '') - self.assertEqual('a_name.pack', pack.file_name()) + pack = pack_repo.ExistingPack("", "a_name", "", "", "", "") + self.assertEqual("a_name.pack", pack.file_name()) class TestNewPack(TestCaseWithTransport): """Tests for pack_repo.NewPack.""" def test_new_instance_attributes(self): - upload_transport = self.get_transport('upload') - pack_transport = self.get_transport('pack') - index_transport = self.get_transport('index') - upload_transport.mkdir('.') + upload_transport = self.get_transport("upload") + pack_transport = self.get_transport("pack") + index_transport = self.get_transport("index") + upload_transport.mkdir(".") collection = pack_repo.RepositoryPackCollection( repo=None, - transport=self.get_transport('.'), + transport=self.get_transport("."), index_transport=index_transport, upload_transport=upload_transport, pack_transport=pack_transport, index_builder_class=BTreeBuilder, index_class=BTreeGraphIndex, - use_chk_index=False) + use_chk_index=False, + ) pack = pack_repo.NewPack(collection) self.addCleanup(pack.abort) # Make sure the write stream gets closed self.assertIsInstance(pack.revision_index, BTreeBuilder) @@ -1446,21 +1495,25 @@ class TestPacker(TestCaseWithTransport): """Tests for the packs repository Packer class.""" def test_pack_optimizes_pack_order(self): - builder = self.make_branch_builder('.', format="1.9") + builder = self.make_branch_builder(".", format="1.9") builder.start_series() - builder.build_snapshot(None, [ - ('add', ('', b'root-id', 'directory', None)), - ('add', ('f', b'f-id', 'file', b'content\n'))], - revision_id=b'A') - builder.build_snapshot([b'A'], - [('modify', ('f', b'new-content\n'))], - revision_id=b'B') - builder.build_snapshot([b'B'], - [('modify', ('f', b'third-content\n'))], - revision_id=b'C') - builder.build_snapshot([b'C'], - [('modify', ('f', b'fourth-content\n'))], - revision_id=b'D') + builder.build_snapshot( + None, + [ + ("add", ("", b"root-id", "directory", None)), + ("add", ("f", b"f-id", "file", b"content\n")), + ], + revision_id=b"A", + ) + builder.build_snapshot( + [b"A"], [("modify", ("f", b"new-content\n"))], revision_id=b"B" + ) + builder.build_snapshot( + [b"B"], [("modify", ("f", b"third-content\n"))], revision_id=b"C" + ) + builder.build_snapshot( + [b"C"], [("modify", ("f", b"fourth-content\n"))], revision_id=b"D" + ) b = builder.get_branch() 
b.lock_read() builder.finish_series() @@ -1469,9 +1522,9 @@ def test_pack_optimizes_pack_order(self): # Because of how they were built, they correspond to # ['D', 'C', 'B', 'A'] packs = b.repository._pack_collection.packs - packer = knitpack_repo.KnitPacker(b.repository._pack_collection, - packs, 'testing', - revision_ids=[b'B', b'C']) + packer = knitpack_repo.KnitPacker( + b.repository._pack_collection, packs, "testing", revision_ids=[b"B", b"C"] + ) # Now, when we are copying the B & C revisions, their pack files should # be moved to the front of the stack # The new ordering moves B & C to the front of the .packs attribute, @@ -1485,12 +1538,13 @@ class TestOptimisingPacker(TestCaseWithTransport): """Tests for the OptimisingPacker class.""" def get_pack_collection(self): - repo = self.make_repository('.') + repo = self.make_repository(".") return repo._pack_collection def test_open_pack_will_optimise(self): - packer = knitpack_repo.OptimisingKnitPacker(self.get_pack_collection(), - [], '.test') + packer = knitpack_repo.OptimisingKnitPacker( + self.get_pack_collection(), [], ".test" + ) new_pack = packer.open_pack() self.addCleanup(new_pack.abort) # ensure cleanup self.assertIsInstance(new_pack, pack_repo.NewPack) @@ -1501,20 +1555,23 @@ def test_open_pack_will_optimise(self): class TestGCCHKPacker(TestCaseWithTransport): - def make_abc_branch(self): - builder = self.make_branch_builder('source') + builder = self.make_branch_builder("source") builder.start_series() - builder.build_snapshot(None, [ - ('add', ('', b'root-id', 'directory', None)), - ('add', ('file', b'file-id', 'file', b'content\n')), - ], revision_id=b'A') - builder.build_snapshot([b'A'], [ - ('add', ('dir', b'dir-id', 'directory', None))], - revision_id=b'B') - builder.build_snapshot([b'B'], [ - ('modify', ('file', b'new content\n'))], - revision_id=b'C') + builder.build_snapshot( + None, + [ + ("add", ("", b"root-id", "directory", None)), + ("add", ("file", b"file-id", "file", b"content\n")), + ], + revision_id=b"A", + ) + builder.build_snapshot( + [b"A"], [("add", ("dir", b"dir-id", "directory", None))], revision_id=b"B" + ) + builder.build_snapshot( + [b"B"], [("modify", ("file", b"new content\n"))], revision_id=b"C" + ) builder.finish_series() return builder.get_branch() @@ -1530,13 +1587,11 @@ def make_branch_with_disjoint_inventory_and_revision(self): pack_name_with_rev_C_content) """ b_source = self.make_abc_branch() - b_base = b_source.controldir.sprout( - 'base', revision_id=b'A').open_branch() - b_stacked = b_base.controldir.sprout( - 'stacked', stacked=True).open_branch() + b_base = b_source.controldir.sprout("base", revision_id=b"A").open_branch() + b_stacked = b_base.controldir.sprout("stacked", stacked=True).open_branch() b_stacked.lock_write() self.addCleanup(b_stacked.unlock) - b_stacked.fetch(b_source, b'B') + b_stacked.fetch(b_source, b"B") # Now re-open the stacked repo directly (no fallbacks) so that we can # fill in the A rev. 
repo_not_stacked = b_stacked.controldir.open_repository() @@ -1544,47 +1599,50 @@ def make_branch_with_disjoint_inventory_and_revision(self): self.addCleanup(repo_not_stacked.unlock) # Now we should have a pack file with A's inventory, but not its # Revision - self.assertEqual([(b'A',), (b'B',)], - sorted(repo_not_stacked.inventories.keys())) - self.assertEqual([(b'B',)], - sorted(repo_not_stacked.revisions.keys())) + self.assertEqual( + [(b"A",), (b"B",)], sorted(repo_not_stacked.inventories.keys()) + ) + self.assertEqual([(b"B",)], sorted(repo_not_stacked.revisions.keys())) stacked_pack_names = repo_not_stacked._pack_collection.names() # We have a couple names here, figure out which has A's inventory for name in stacked_pack_names: pack = repo_not_stacked._pack_collection.get_pack_by_name(name) keys = [n[1] for n in pack.inventory_index.iter_all_entries()] - if (b'A',) in keys: + if (b"A",) in keys: inv_a_pack_name = name break else: - self.fail('Could not find pack containing A\'s inventory') - repo_not_stacked.fetch(b_source.repository, b'A') - self.assertEqual([(b'A',), (b'B',)], - sorted(repo_not_stacked.revisions.keys())) + self.fail("Could not find pack containing A's inventory") + repo_not_stacked.fetch(b_source.repository, b"A") + self.assertEqual([(b"A",), (b"B",)], sorted(repo_not_stacked.revisions.keys())) new_pack_names = set(repo_not_stacked._pack_collection.names()) rev_a_pack_names = new_pack_names.difference(stacked_pack_names) self.assertEqual(1, len(rev_a_pack_names)) rev_a_pack_name = list(rev_a_pack_names)[0] # Now fetch 'C', so we have a couple pack files to join - repo_not_stacked.fetch(b_source.repository, b'C') + repo_not_stacked.fetch(b_source.repository, b"C") rev_c_pack_names = set(repo_not_stacked._pack_collection.names()) rev_c_pack_names = rev_c_pack_names.difference(new_pack_names) self.assertEqual(1, len(rev_c_pack_names)) rev_c_pack_name = list(rev_c_pack_names)[0] - return (repo_not_stacked, rev_a_pack_name, inv_a_pack_name, - rev_c_pack_name) + return (repo_not_stacked, rev_a_pack_name, inv_a_pack_name, rev_c_pack_name) def test_pack_with_distant_inventories(self): # See https://bugs.launchpad.net/bzr/+bug/437003 # When repacking, it is possible to have an inventory in a different # pack file than the associated revision. An autopack can then come # along, and miss that inventory, and complain. - (repo, rev_a_pack_name, inv_a_pack_name, rev_c_pack_name - ) = self.make_branch_with_disjoint_inventory_and_revision() + ( + repo, + rev_a_pack_name, + inv_a_pack_name, + rev_c_pack_name, + ) = self.make_branch_with_disjoint_inventory_and_revision() a_pack = repo._pack_collection.get_pack_by_name(rev_a_pack_name) c_pack = repo._pack_collection.get_pack_by_name(rev_c_pack_name) - packer = groupcompress_repo.GCCHKPacker(repo._pack_collection, - [a_pack, c_pack], '.test-pack') + packer = groupcompress_repo.GCCHKPacker( + repo._pack_collection, [a_pack, c_pack], ".test-pack" + ) # This would raise ValueError in bug #437003, but should not raise an # error once fixed. packer.pack() @@ -1592,22 +1650,27 @@ def test_pack_with_distant_inventories(self): def test_pack_with_missing_inventory(self): # Similar to test_pack_with_missing_inventory, but this time, we force # the A inventory to actually be gone from the repository. 
- (repo, rev_a_pack_name, inv_a_pack_name, rev_c_pack_name - ) = self.make_branch_with_disjoint_inventory_and_revision() + ( + repo, + rev_a_pack_name, + inv_a_pack_name, + rev_c_pack_name, + ) = self.make_branch_with_disjoint_inventory_and_revision() inv_a_pack = repo._pack_collection.get_pack_by_name(inv_a_pack_name) repo._pack_collection._remove_pack_from_memory(inv_a_pack) - packer = groupcompress_repo.GCCHKPacker(repo._pack_collection, - repo._pack_collection.all_packs(), '.test-pack') + packer = groupcompress_repo.GCCHKPacker( + repo._pack_collection, repo._pack_collection.all_packs(), ".test-pack" + ) e = self.assertRaises(ValueError, packer.pack) packer.new_pack.abort() - self.assertContainsRe(str(e), - r"We are missing inventories for revisions: .*'A'") + self.assertContainsRe( + str(e), r"We are missing inventories for revisions: .*'A'" + ) class TestCrossFormatPacks(TestCaseWithTransport): - def log_pack(self, hint=None): - self.calls.append(('pack', hint)) + self.calls.append(("pack", hint)) self.orig_pack(hint=hint) if self.expect_hint: self.assertTrue(hint) @@ -1615,18 +1678,19 @@ def log_pack(self, hint=None): def run_stream(self, src_fmt, target_fmt, expect_pack_called): self.expect_hint = expect_pack_called self.calls = [] - source_tree = self.make_branch_and_tree('src', format=src_fmt) + source_tree = self.make_branch_and_tree("src", format=src_fmt) source_tree.lock_write() self.addCleanup(source_tree.unlock) - tip = source_tree.commit('foo') - target = self.make_repository('target', format=target_fmt) + tip = source_tree.commit("foo") + target = self.make_repository("target", format=target_fmt) target.lock_write() self.addCleanup(target.unlock) source = source_tree.branch.repository._get_source(target._format) self.orig_pack = target.pack self.overrideAttr(target, "pack", self.log_pack) search = target.search_missing_revision_ids( - source_tree.branch.repository, revision_ids=[tip]) + source_tree.branch.repository, revision_ids=[tip] + ) stream = source.get_stream(search) from_format = source_tree.branch.repository._format sink = target._get_sink() @@ -1639,11 +1703,11 @@ def run_stream(self, src_fmt, target_fmt, expect_pack_called): def run_fetch(self, src_fmt, target_fmt, expect_pack_called): self.expect_hint = expect_pack_called self.calls = [] - source_tree = self.make_branch_and_tree('src', format=src_fmt) + source_tree = self.make_branch_and_tree("src", format=src_fmt) source_tree.lock_write() self.addCleanup(source_tree.unlock) - source_tree.commit('foo') - target = self.make_repository('target', format=target_fmt) + source_tree.commit("foo") + target = self.make_repository("target", format=target_fmt) target.lock_write() self.addCleanup(target.unlock) source = source_tree.branch.repository @@ -1658,57 +1722,55 @@ def run_fetch(self, src_fmt, target_fmt, expect_pack_called): def test_sink_format_hint_no(self): # When the target format says packing makes no difference, pack is not # called. - self.run_stream('1.9', 'rich-root-pack', False) + self.run_stream("1.9", "rich-root-pack", False) def test_sink_format_hint_yes(self): # When the target format says packing makes a difference, pack is # called. - self.run_stream('1.9', '2a', True) + self.run_stream("1.9", "2a", True) def test_sink_format_same_no(self): # When the formats are the same, pack is not called. - self.run_stream('2a', '2a', False) + self.run_stream("2a", "2a", False) def test_IDS_format_hint_no(self): # When the target format says packing makes no difference, pack is not # called. 
- self.run_fetch('1.9', 'rich-root-pack', False) + self.run_fetch("1.9", "rich-root-pack", False) def test_IDS_format_hint_yes(self): # When the target format says packing makes a difference, pack is # called. - self.run_fetch('1.9', '2a', True) + self.run_fetch("1.9", "2a", True) def test_IDS_format_same_no(self): # When the formats are the same, pack is not called. - self.run_fetch('2a', '2a', False) + self.run_fetch("2a", "2a", False) class Test_LazyListJoin(tests.TestCase): - def test__repr__(self): - lazy = repository._LazyListJoin(['a'], ['b']) - self.assertEqual("breezy.repository._LazyListJoin((['a'], ['b']))", - repr(lazy)) + lazy = repository._LazyListJoin(["a"], ["b"]) + self.assertEqual("breezy.repository._LazyListJoin((['a'], ['b']))", repr(lazy)) class TestFeatures(tests.TestCaseWithTransport): - def test_open_with_present_feature(self): self.addCleanup( bzrrepository.RepositoryFormatMetaDir.unregister_feature, - b"makes-cheese-sandwich") - bzrrepository.RepositoryFormatMetaDir.register_feature( - b"makes-cheese-sandwich") - repo = self.make_repository('.') + b"makes-cheese-sandwich", + ) + bzrrepository.RepositoryFormatMetaDir.register_feature(b"makes-cheese-sandwich") + repo = self.make_repository(".") repo.lock_write() repo._format.features[b"makes-cheese-sandwich"] = b"required" repo._format.check_support_status(False) repo.unlock() def test_open_with_missing_required_feature(self): - repo = self.make_repository('.') + repo = self.make_repository(".") repo.lock_write() repo._format.features[b"makes-cheese-sandwich"] = b"required" - self.assertRaises(bzrdir.MissingFeature, - repo._format.check_support_status, False) + self.assertRaises( + bzrdir.MissingFeature, repo._format.check_support_status, False + ) diff --git a/breezy/bzr/tests/test_rio.py b/breezy/bzr/tests/test_rio.py index 19a791e635..50c361fdd1 100644 --- a/breezy/bzr/tests/test_rio.py +++ b/breezy/bzr/tests/test_rio.py @@ -37,61 +37,64 @@ def rio_file(stanzas): class TestRio(TestCase): - def test_stanza(self): """Construct rio stanza in memory.""" - s = _mod_rio.Stanza(number='42', name="fred") - self.assertIn('number', s) - self.assertNotIn('color', s) - self.assertNotIn('42', s) - self.assertEqual(list(s.iter_pairs()), - [('name', 'fred'), ('number', '42')]) - self.assertEqual(s.get('number'), '42') - self.assertEqual(s.get('name'), 'fred') + s = _mod_rio.Stanza(number="42", name="fred") + self.assertIn("number", s) + self.assertNotIn("color", s) + self.assertNotIn("42", s) + self.assertEqual(list(s.iter_pairs()), [("name", "fred"), ("number", "42")]) + self.assertEqual(s.get("number"), "42") + self.assertEqual(s.get("name"), "fred") def test_empty_value(self): """Serialize stanza with empty field.""" - s = _mod_rio.Stanza(empty='') - self.assertEqual(s.to_string(), - b"empty: \n") + s = _mod_rio.Stanza(empty="") + self.assertEqual(s.to_string(), b"empty: \n") def test_to_lines(self): """Write simple rio stanza to string.""" - s = _mod_rio.Stanza(number='42', name='fred') - self.assertEqual(list(s.to_lines()), - [b'name: fred\n', - b'number: 42\n']) + s = _mod_rio.Stanza(number="42", name="fred") + self.assertEqual(list(s.to_lines()), [b"name: fred\n", b"number: 42\n"]) def test_as_dict(self): """Convert rio Stanza to dictionary.""" - s = _mod_rio.Stanza(number='42', name='fred') + s = _mod_rio.Stanza(number="42", name="fred") sd = s.as_dict() - self.assertEqual(sd, {'number': '42', 'name': 'fred'}) + self.assertEqual(sd, {"number": "42", "name": "fred"}) def test_to_file(self): """Write rio to file.""" tmpf 
= TemporaryFile() - s = _mod_rio.Stanza(a_thing='something with "quotes like \\"this\\""', - number='42', name='fred') + s = _mod_rio.Stanza( + a_thing='something with "quotes like \\"this\\""', number="42", name="fred" + ) s.write(tmpf) tmpf.seek(0) - self.assertEqual(tmpf.read(), b'''\ + self.assertEqual( + tmpf.read(), + b"""\ a_thing: something with "quotes like \\"this\\"" name: fred number: 42 -''') +""", + ) def test_multiline_string(self): tmpf = TemporaryFile() s = _mod_rio.Stanza( - motto="war is peace\nfreedom is slavery\nignorance is strength") + motto="war is peace\nfreedom is slavery\nignorance is strength" + ) s.write(tmpf) tmpf.seek(0) - self.assertEqual(tmpf.read(), b'''\ + self.assertEqual( + tmpf.read(), + b"""\ motto: war is peace \tfreedom is slavery \tignorance is strength -''') +""", + ) tmpf.seek(0) s2 = _mod_rio.read_stanza(tmpf) self.assertEqual(s, s2) @@ -105,71 +108,88 @@ def test_read_stanza(self): committer: Martin Pool """.splitlines(True) s = _mod_rio.read_stanza(lines) - self.assertIn('revision', s) - self.assertEqual(s.get('revision'), 'mbp@sourcefrog.net-123-abc') - self.assertEqual(list(s.iter_pairs()), - [('revision', 'mbp@sourcefrog.net-123-abc'), - ('timestamp', '1130653962'), - ('timezone', '36000'), - ('committer', "Martin Pool ")]) + self.assertIn("revision", s) + self.assertEqual(s.get("revision"), "mbp@sourcefrog.net-123-abc") + self.assertEqual( + list(s.iter_pairs()), + [ + ("revision", "mbp@sourcefrog.net-123-abc"), + ("timestamp", "1130653962"), + ("timezone", "36000"), + ("committer", "Martin Pool "), + ], + ) self.assertEqual(len(s), 4) def test_repeated_field(self): """Repeated field in rio.""" s = _mod_rio.Stanza() - for k, v in [('a', '10'), ('b', '20'), ('a', '100'), ('b', '200'), - ('a', '1000'), ('b', '2000')]: + for k, v in [ + ("a", "10"), + ("b", "20"), + ("a", "100"), + ("b", "200"), + ("a", "1000"), + ("b", "2000"), + ]: s.add(k, v) s2 = _mod_rio.read_stanza(s.to_lines()) self.assertEqual(s, s2) - self.assertEqual(s.get_all('a'), ['10', '100', '1000']) - self.assertEqual(s.get_all('b'), ['20', '200', '2000']) + self.assertEqual(s.get_all("a"), ["10", "100", "1000"]) + self.assertEqual(s.get_all("b"), ["20", "200", "2000"]) def test_backslash(self): - s = _mod_rio.Stanza(q='\\') + s = _mod_rio.Stanza(q="\\") t = s.to_string() - self.assertEqual(t, b'q: \\\n') + self.assertEqual(t, b"q: \\\n") s2 = _mod_rio.read_stanza(s.to_lines()) self.assertEqual(s, s2) def test_blank_line(self): - s = _mod_rio.Stanza(none='', one='\n', two='\n\n') - self.assertEqual(s.to_string(), b"""\ + s = _mod_rio.Stanza(none="", one="\n", two="\n\n") + self.assertEqual( + s.to_string(), + b"""\ none:\x20 one:\x20 \t two:\x20 \t \t -""") +""", + ) s2 = _mod_rio.read_stanza(s.to_lines()) self.assertEqual(s, s2) def test_whitespace_value(self): - s = _mod_rio.Stanza(space=' ', tabs='\t\t\t', combo='\n\t\t\n') - self.assertEqual(s.to_string(), b"""\ + s = _mod_rio.Stanza(space=" ", tabs="\t\t\t", combo="\n\t\t\n") + self.assertEqual( + s.to_string(), + b"""\ combo:\x20 \t\t\t \t space:\x20\x20 tabs: \t\t\t -""") +""", + ) s2 = _mod_rio.read_stanza(s.to_lines()) self.assertEqual(s, s2) self.rio_file_stanzas([s]) def test_quoted(self): """Rio quoted string cases.""" - s = _mod_rio.Stanza(q1='"hello"', - q2=' "for', - q3='\n\n"for"\n', - q4='for\n"\nfor', - q5='\n', - q6='"', - q7='""', - q8='\\', - q9='\\"\\"', - ) + s = _mod_rio.Stanza( + q1='"hello"', + q2=' "for', + q3='\n\n"for"\n', + q4='for\n"\nfor', + q5="\n", + q6='"', + q7='""', + q8="\\", + 
q9='\\"\\"', + ) s2 = _mod_rio.read_stanza(s.to_lines()) self.assertEqual(s, s2) # apparent bug in read_stanza @@ -184,16 +204,17 @@ def test_read_empty(self): def test_read_nul_byte(self): """File consisting of a nul byte causes an error.""" - self.assertRaises(Exception, _mod_rio.read_stanza, [b'\0']) + self.assertRaises(Exception, _mod_rio.read_stanza, [b"\0"]) def test_read_nul_bytes(self): """File consisting of many nul bytes causes an error.""" - self.assertRaises(Exception, _mod_rio.read_stanza, [b'\0' * 100]) + self.assertRaises(Exception, _mod_rio.read_stanza, [b"\0" * 100]) def test_read_iter(self): """Read several stanzas from file.""" tmpf = TemporaryFile() - tmpf.write(b"""\ + tmpf.write( + b"""\ version_header: 1 name: foo @@ -201,19 +222,25 @@ def test_read_iter(self): name: bar val: 129319 -""") +""" + ) tmpf.seek(0) reader = _mod_rio.read_stanzas(tmpf) stuff = list(reader) - self.assertEqual(stuff, - [_mod_rio.Stanza(version_header='1'), - _mod_rio.Stanza(name="foo", val='123'), - _mod_rio.Stanza(name="bar", val='129319'), ]) + self.assertEqual( + stuff, + [ + _mod_rio.Stanza(version_header="1"), + _mod_rio.Stanza(name="foo", val="123"), + _mod_rio.Stanza(name="bar", val="129319"), + ], + ) def test_read_several(self): """Read several stanzas from file.""" tmpf = TemporaryFile() - tmpf.write(b"""\ + tmpf.write( + b"""\ version_header: 1 name: foo @@ -226,18 +253,18 @@ def test_read_several(self): name: bar val: 129319 -""") +""" + ) tmpf.seek(0) s = _mod_rio.read_stanza(tmpf) - self.assertEqual(s, _mod_rio.Stanza(version_header='1')) + self.assertEqual(s, _mod_rio.Stanza(version_header="1")) s = _mod_rio.read_stanza(tmpf) - self.assertEqual(s, _mod_rio.Stanza(name="foo", val='123')) + self.assertEqual(s, _mod_rio.Stanza(name="foo", val="123")) s = _mod_rio.read_stanza(tmpf) - self.assertEqual(s.get('name'), 'quoted') - self.assertEqual( - s.get('address'), ' "Willowglen"\n 42 Wallaby Way\n Sydney') + self.assertEqual(s.get("name"), "quoted") + self.assertEqual(s.get("address"), ' "Willowglen"\n 42 Wallaby Way\n Sydney') s = _mod_rio.read_stanza(tmpf) - self.assertEqual(s, _mod_rio.Stanza(name="bar", val='129319')) + self.assertEqual(s, _mod_rio.Stanza(name="bar", val="129319")) s = _mod_rio.read_stanza(tmpf) self.assertEqual(s, None) @@ -257,7 +284,8 @@ def rio_file_stanzas(self, stanzas): def test_tricky_quoted(self): tmpf = TemporaryFile() - tmpf.write(b'''\ + tmpf.write( + b'''\ s: "one" s:\x20 @@ -288,26 +316,28 @@ def test_tricky_quoted(self): s: both\\\" -''') +''' + ) tmpf.seek(0) - expected_vals = ['"one"', - '\n"one"\n', - '"', - '""', - '"""', - '\n', - '\\', - '\n\\\n\\\\\n', - 'word\\', - 'quote\"', - 'backslashes\\\\\\', - 'both\\\"', - ] + expected_vals = [ + '"one"', + '\n"one"\n', + '"', + '""', + '"""', + "\n", + "\\", + "\n\\\n\\\\\n", + "word\\", + 'quote"', + "backslashes\\\\\\", + 'both\\"', + ] for expected in expected_vals: stanza = _mod_rio.read_stanza(tmpf) self.rio_file_stanzas([stanza]) self.assertEqual(len(stanza), 1) - self.assertEqual(stanza.get('s'), expected) + self.assertEqual(stanza.get("s"), expected) def test_write_empty_stanza(self): """Write empty stanza.""" @@ -317,7 +347,7 @@ def test_write_empty_stanza(self): def test_rio_raises_type_error(self): """TypeError on adding invalid type to Stanza.""" s = _mod_rio.Stanza() - self.assertRaises(TypeError, s.add, 'foo', {}) + self.assertRaises(TypeError, s.add, "foo", {}) def test_rio_raises_type_error_key(self): """TypeError on adding invalid type to Stanza.""" @@ -325,12 +355,12 @@ def 
test_rio_raises_type_error_key(self): self.assertRaises(TypeError, s.add, 10, {}) def test_rio_surrogateescape(self): - raw_bytes = b'\xcb' - self.assertRaises(UnicodeDecodeError, raw_bytes.decode, 'utf-8') + raw_bytes = b"\xcb" + self.assertRaises(UnicodeDecodeError, raw_bytes.decode, "utf-8") try: - uni_data = raw_bytes.decode('utf-8', 'surrogateescape') + uni_data = raw_bytes.decode("utf-8", "surrogateescape") except LookupError: - self.skipTest('surrogateescape is not available on Python < 3') + self.skipTest("surrogateescape is not available on Python < 3") try: _mod_rio.Stanza(foo=uni_data) except TypeError: @@ -339,54 +369,49 @@ def test_rio_surrogateescape(self): self.fail() def test_rio_unicode(self): - uni_data = '\N{KATAKANA LETTER O}' + uni_data = "\N{KATAKANA LETTER O}" s = _mod_rio.Stanza(foo=uni_data) - self.assertEqual(s.get('foo'), uni_data) + self.assertEqual(s.get("foo"), uni_data) raw_lines = s.to_lines() - self.assertEqual(raw_lines, - [b'foo: ' + uni_data.encode('utf-8') + b'\n']) + self.assertEqual(raw_lines, [b"foo: " + uni_data.encode("utf-8") + b"\n"]) new_s = _mod_rio.read_stanza(raw_lines) - self.assertEqual(new_s.get('foo'), uni_data) + self.assertEqual(new_s.get("foo"), uni_data) def mail_munge(self, lines, dos_nl=True): new_lines = [] for line in lines: - line = re.sub(b' *\n', b'\n', line) + line = re.sub(b" *\n", b"\n", line) if dos_nl: - line = re.sub(b'([^\r])\n', b'\\1\r\n', line) + line = re.sub(b"([^\r])\n", b"\\1\r\n", line) new_lines.append(line) return new_lines def test_patch_rio(self): - stanza = _mod_rio.Stanza(data='#\n\r\\r ', space=' ' * 255, hash='#' * 255) + stanza = _mod_rio.Stanza(data="#\n\r\\r ", space=" " * 255, hash="#" * 255) lines = to_patch_lines(stanza) for line in lines: - self.assertContainsRe(line, b'^# ') + self.assertContainsRe(line, b"^# ") self.assertGreaterEqual(72, len(line)) for line in to_patch_lines(stanza, max_width=12): self.assertGreaterEqual(12, len(line)) - new_stanza = read_patch_stanza(self.mail_munge(lines, - dos_nl=False)) + new_stanza = read_patch_stanza(self.mail_munge(lines, dos_nl=False)) lines = self.mail_munge(lines) new_stanza = read_patch_stanza(lines) - self.assertEqual('#\n\r\\r ', new_stanza.get('data')) - self.assertEqual(' ' * 255, new_stanza.get('space')) - self.assertEqual('#' * 255, new_stanza.get('hash')) + self.assertEqual("#\n\r\\r ", new_stanza.get("data")) + self.assertEqual(" " * 255, new_stanza.get("space")) + self.assertEqual("#" * 255, new_stanza.get("hash")) def test_patch_rio_linebreaks(self): - stanza = _mod_rio.Stanza(breaktest='linebreak -/' * 30) + stanza = _mod_rio.Stanza(breaktest="linebreak -/" * 30) line1 = to_patch_lines(stanza, 71)[0] - self.assertContainsRe(line1, b'linebreak\\\\\n') - stanza = _mod_rio.Stanza(breaktest='linebreak-/' * 30) - self.assertContainsRe(to_patch_lines(stanza, 70)[0], - b'linebreak-\\\\\n') - stanza = _mod_rio.Stanza(breaktest='linebreak/' * 30) - self.assertContainsRe(to_patch_lines(stanza, 70)[0], - b'linebreak\\\\\n') + self.assertContainsRe(line1, b"linebreak\\\\\n") + stanza = _mod_rio.Stanza(breaktest="linebreak-/" * 30) + self.assertContainsRe(to_patch_lines(stanza, 70)[0], b"linebreak-\\\\\n") + stanza = _mod_rio.Stanza(breaktest="linebreak/" * 30) + self.assertContainsRe(to_patch_lines(stanza, 70)[0], b"linebreak\\\\\n") class TestValidTag(TestCase): - def test_ok(self): self.assertTrue(_mod_rio.valid_tag("foo")) @@ -414,7 +439,6 @@ def test_non_ascii_char(self): class TestReadUTF8Stanza(TestCase): - def assertReadStanza(self, result, 
line_iter): s = _mod_rio.read_stanza(line_iter) self.assertEqual(result, s) @@ -440,7 +464,8 @@ def test_simple(self): def test_multi_line(self): self.assertReadStanza( - _mod_rio.Stanza(foo="bar\nbla"), [b"foo: bar\n", b"\tbla\n"]) + _mod_rio.Stanza(foo="bar\nbla"), [b"foo: bar\n", b"\tbla\n"] + ) def test_repeated(self): s = _mod_rio.Stanza() @@ -459,9 +484,11 @@ def test_continuation_too_early(self): def test_large(self): value = b"bla" * 9000 - self.assertReadStanza(_mod_rio.Stanza(foo=value.decode()), - [b"foo: %s\n" % value]) + self.assertReadStanza( + _mod_rio.Stanza(foo=value.decode()), [b"foo: %s\n" % value] + ) def test_non_ascii_char(self): - self.assertReadStanza(_mod_rio.Stanza(foo="n\xe5me"), - ["foo: n\xe5me\n".encode()]) + self.assertReadStanza( + _mod_rio.Stanza(foo="n\xe5me"), ["foo: n\xe5me\n".encode()] + ) diff --git a/breezy/bzr/tests/test_serializer.py b/breezy/bzr/tests/test_serializer.py index e204c3c68f..753ec080d4 100644 --- a/breezy/bzr/tests/test_serializer.py +++ b/breezy/bzr/tests/test_serializer.py @@ -26,13 +26,19 @@ class TestSerializer(TestCase): """Test serializer.""" def test_registry(self): - self.assertIs(xml5.revision_serializer_v5, - serializer.revision_format_registry.get('5')) - self.assertIs(xml8.revision_serializer_v8, - serializer.revision_format_registry.get('8')) - self.assertIs(xml6.inventory_serializer_v6, - serializer.inventory_format_registry.get('6')) - self.assertIs(xml7.inventory_serializer_v7, - serializer.inventory_format_registry.get('7')) - self.assertIs(chk_serializer.inventory_chk_serializer_255_bigpage_9, - serializer.inventory_format_registry.get('9')) + self.assertIs( + xml5.revision_serializer_v5, serializer.revision_format_registry.get("5") + ) + self.assertIs( + xml8.revision_serializer_v8, serializer.revision_format_registry.get("8") + ) + self.assertIs( + xml6.inventory_serializer_v6, serializer.inventory_format_registry.get("6") + ) + self.assertIs( + xml7.inventory_serializer_v7, serializer.inventory_format_registry.get("7") + ) + self.assertIs( + chk_serializer.inventory_chk_serializer_255_bigpage_9, + serializer.inventory_format_registry.get("9"), + ) diff --git a/breezy/bzr/tests/test_smart.py b/breezy/bzr/tests/test_smart.py index d38f0ee504..7bbdbfe2b2 100644 --- a/breezy/bzr/tests/test_smart.py +++ b/breezy/bzr/tests/test_smart.py @@ -51,15 +51,22 @@ def load_tests(loader, standard_tests, pattern): """Multiply tests version and protocol consistency.""" # FindRepository tests. 
scenarios = [ - ("find_repository", { - "_request_class": smart_dir.SmartServerRequestFindRepositoryV1}), - ("find_repositoryV2", { - "_request_class": smart_dir.SmartServerRequestFindRepositoryV2}), - ("find_repositoryV3", { - "_request_class": smart_dir.SmartServerRequestFindRepositoryV3}), - ] + ( + "find_repository", + {"_request_class": smart_dir.SmartServerRequestFindRepositoryV1}, + ), + ( + "find_repositoryV2", + {"_request_class": smart_dir.SmartServerRequestFindRepositoryV2}, + ), + ( + "find_repositoryV3", + {"_request_class": smart_dir.SmartServerRequestFindRepositoryV3}, + ), + ] to_adapt, result = tests.split_suite_by_re( - standard_tests, "TestSmartServerRequestFindRepository") + standard_tests, "TestSmartServerRequestFindRepository" + ) v2_only, v1_and_2 = tests.split_suite_by_re(to_adapt, "_v2") tests.multiply_tests(v1_and_2, scenarios, result) # The first scenario is only applicable to v1 protocols, it is deleted @@ -69,7 +76,6 @@ def load_tests(loader, standard_tests, pattern): class TestCaseWithChrootedTransport(tests.TestCaseWithTransport): - def setUp(self): self.vfs_transport_factory = memory.MemoryServer super().setUp() @@ -87,7 +93,6 @@ def get_transport(self, relpath=None): class TestCaseWithSmartMedium(tests.TestCaseWithMemoryTransport): - def setUp(self): super().setUp() # We're allowed to set the transport class here, so that we don't use @@ -97,7 +102,7 @@ def setUp(self): self.overrideAttr(self, "transport_server", self.make_transport_server) def make_transport_server(self): - return test_server.SmartTCPServer_for_testing('-' + self.id()) + return test_server.SmartTCPServer_for_testing("-" + self.id()) def get_smart_medium(self): """Get a smart medium to use in tests.""" @@ -105,15 +110,19 @@ def get_smart_medium(self): class TestByteStreamToStream(tests.TestCase): - def test_repeated_substreams_same_kind_are_one_stream(self): # Make a stream - an iterable of bytestrings. 
stream = [ - ('text', [versionedfile.FulltextContentFactory((b'k1',), None, - None, b'foo')]), - ('text', [versionedfile.FulltextContentFactory((b'k2',), None, - None, b'bar')])] - fmt = controldir.format_registry.get('pack-0.92')().repository_format + ( + "text", + [versionedfile.FulltextContentFactory((b"k1",), None, None, b"foo")], + ), + ( + "text", + [versionedfile.FulltextContentFactory((b"k2",), None, None, b"bar")], + ), + ] + fmt = controldir.format_registry.get("pack-0.92")().repository_format bytes = smart_repo._stream_to_byte_stream(stream, fmt) streams = [] # Iterate the resulting iterable; checking that we get only one stream @@ -126,216 +135,229 @@ def test_repeated_substreams_same_kind_are_one_stream(self): class TestSmartServerResponse(tests.TestCase): - def test__eq__(self): - self.assertEqual(smart_req.SmartServerResponse((b'ok', )), - smart_req.SmartServerResponse((b'ok', ))) - self.assertEqual(smart_req.SmartServerResponse((b'ok', ), b'body'), - smart_req.SmartServerResponse((b'ok', ), b'body')) - self.assertNotEqual(smart_req.SmartServerResponse((b'ok', )), - smart_req.SmartServerResponse((b'notok', ))) - self.assertNotEqual(smart_req.SmartServerResponse((b'ok', ), b'body'), - smart_req.SmartServerResponse((b'ok', ))) - self.assertNotEqual(None, - smart_req.SmartServerResponse((b'ok', ))) + self.assertEqual( + smart_req.SmartServerResponse((b"ok",)), + smart_req.SmartServerResponse((b"ok",)), + ) + self.assertEqual( + smart_req.SmartServerResponse((b"ok",), b"body"), + smart_req.SmartServerResponse((b"ok",), b"body"), + ) + self.assertNotEqual( + smart_req.SmartServerResponse((b"ok",)), + smart_req.SmartServerResponse((b"notok",)), + ) + self.assertNotEqual( + smart_req.SmartServerResponse((b"ok",), b"body"), + smart_req.SmartServerResponse((b"ok",)), + ) + self.assertNotEqual(None, smart_req.SmartServerResponse((b"ok",))) def test__str__(self): """SmartServerResponses can be stringified.""" self.assertIn( - str(smart_req.SuccessfulSmartServerResponse((b'args',), b'body')), - ("", - "")) + str(smart_req.SuccessfulSmartServerResponse((b"args",), b"body")), + ( + "", + "", + ), + ) self.assertIn( - str(smart_req.FailedSmartServerResponse((b'args',), b'body')), - ("", - "")) + str(smart_req.FailedSmartServerResponse((b"args",), b"body")), + ( + "", + "", + ), + ) class TestSmartServerRequest(tests.TestCaseWithMemoryTransport): - def test_translate_client_path(self): transport = self.get_transport() - request = smart_req.SmartServerRequest(transport, 'foo/') - self.assertEqual('./', request.translate_client_path(b'foo/')) + request = smart_req.SmartServerRequest(transport, "foo/") + self.assertEqual("./", request.translate_client_path(b"foo/")) self.assertRaises( - urlutils.InvalidURLJoin, request.translate_client_path, b'foo/..') - self.assertRaises( - errors.PathNotChild, request.translate_client_path, b'/') - self.assertRaises( - errors.PathNotChild, request.translate_client_path, b'bar/') - self.assertEqual('./baz', request.translate_client_path(b'foo/baz')) - e_acute = '\N{LATIN SMALL LETTER E WITH ACUTE}' + urlutils.InvalidURLJoin, request.translate_client_path, b"foo/.." 
+ ) + self.assertRaises(errors.PathNotChild, request.translate_client_path, b"/") + self.assertRaises(errors.PathNotChild, request.translate_client_path, b"bar/") + self.assertEqual("./baz", request.translate_client_path(b"foo/baz")) + e_acute = "\N{LATIN SMALL LETTER E WITH ACUTE}" self.assertEqual( - './' + urlutils.escape(e_acute), - request.translate_client_path(b'foo/' + e_acute.encode('utf-8'))) + "./" + urlutils.escape(e_acute), + request.translate_client_path(b"foo/" + e_acute.encode("utf-8")), + ) def test_translate_client_path_vfs(self): """VfsRequests receive escaped paths rather than raw UTF-8.""" transport = self.get_transport() - request = vfs.VfsRequest(transport, 'foo/') - e_acute = '\N{LATIN SMALL LETTER E WITH ACUTE}' - escaped = urlutils.escape('foo/' + e_acute) + request = vfs.VfsRequest(transport, "foo/") + e_acute = "\N{LATIN SMALL LETTER E WITH ACUTE}" + escaped = urlutils.escape("foo/" + e_acute) self.assertEqual( - './' + urlutils.escape(e_acute), - request.translate_client_path(escaped.encode('ascii'))) + "./" + urlutils.escape(e_acute), + request.translate_client_path(escaped.encode("ascii")), + ) def test_transport_from_client_path(self): transport = self.get_transport() - request = smart_req.SmartServerRequest(transport, 'foo/') + request = smart_req.SmartServerRequest(transport, "foo/") self.assertEqual( - transport.base, - request.transport_from_client_path(b'foo/').base) + transport.base, request.transport_from_client_path(b"foo/").base + ) -class TestSmartServerBzrDirRequestCloningMetaDir( - tests.TestCaseWithMemoryTransport): +class TestSmartServerBzrDirRequestCloningMetaDir(tests.TestCaseWithMemoryTransport): """Tests for BzrDir.cloning_metadir.""" def test_cloning_metadir(self): """When there is a bzrdir present, the call succeeds.""" backing = self.get_transport() - dir = self.make_controldir('.') + dir = self.make_controldir(".") local_result = dir.cloning_metadir() request_class = smart_dir.SmartServerBzrDirRequestCloningMetaDir request = request_class(backing) expected = smart_req.SuccessfulSmartServerResponse( - (local_result.network_name(), - local_result.repository_format.network_name(), - (b'branch', local_result.get_branch_format().network_name()))) - self.assertEqual(expected, request.execute(b'', b'False')) + ( + local_result.network_name(), + local_result.repository_format.network_name(), + (b"branch", local_result.get_branch_format().network_name()), + ) + ) + self.assertEqual(expected, request.execute(b"", b"False")) def test_cloning_metadir_reference(self): """The request fails when bzrdir contains a branch reference.""" backing = self.get_transport() - referenced_branch = self.make_branch('referenced') - dir = self.make_controldir('.') + referenced_branch = self.make_branch("referenced") + dir = self.make_controldir(".") dir.cloning_metadir() _mod_bzrbranch.BranchReferenceFormat().initialize( - dir, target_branch=referenced_branch) + dir, target_branch=referenced_branch + ) _mod_bzrbranch.BranchReferenceFormat().get_reference(dir) # The server shouldn't try to follow the branch reference, so it's fine # if the referenced branch isn't reachable. 
- backing.rename('referenced', 'moved') + backing.rename("referenced", "moved") request_class = smart_dir.SmartServerBzrDirRequestCloningMetaDir request = request_class(backing) - expected = smart_req.FailedSmartServerResponse((b'BranchReference',)) - self.assertEqual(expected, request.execute(b'', b'False')) + expected = smart_req.FailedSmartServerResponse((b"BranchReference",)) + self.assertEqual(expected, request.execute(b"", b"False")) -class TestSmartServerBzrDirRequestCheckoutMetaDir( - tests.TestCaseWithMemoryTransport): +class TestSmartServerBzrDirRequestCheckoutMetaDir(tests.TestCaseWithMemoryTransport): """Tests for BzrDir.checkout_metadir.""" def test_checkout_metadir(self): backing = self.get_transport() - request = smart_dir.SmartServerBzrDirRequestCheckoutMetaDir( - backing) - self.make_branch('.', format='2a') - response = request.execute(b'') + request = smart_dir.SmartServerBzrDirRequestCheckoutMetaDir(backing) + self.make_branch(".", format="2a") + response = request.execute(b"") self.assertEqual( smart_req.SmartServerResponse( - (b'Bazaar-NG meta directory, format 1\n', - b'Bazaar repository format 2a (needs bzr 1.16 or later)\n', - b'Bazaar Branch Format 7 (needs bzr 1.6)\n')), - response) + ( + b"Bazaar-NG meta directory, format 1\n", + b"Bazaar repository format 2a (needs bzr 1.16 or later)\n", + b"Bazaar Branch Format 7 (needs bzr 1.6)\n", + ) + ), + response, + ) -class TestSmartServerBzrDirRequestDestroyBranch( - tests.TestCaseWithMemoryTransport): +class TestSmartServerBzrDirRequestDestroyBranch(tests.TestCaseWithMemoryTransport): """Tests for BzrDir.destroy_branch.""" def test_destroy_branch_default(self): """The default branch can be removed.""" backing = self.get_transport() - self.make_branch('.') + self.make_branch(".") request_class = smart_dir.SmartServerBzrDirRequestDestroyBranch request = request_class(backing) - expected = smart_req.SuccessfulSmartServerResponse((b'ok',)) - self.assertEqual(expected, request.execute(b'', None)) + expected = smart_req.SuccessfulSmartServerResponse((b"ok",)) + self.assertEqual(expected, request.execute(b"", None)) def test_destroy_branch_named(self): """A named branch can be removed.""" backing = self.get_transport() - dir = self.make_repository('.', format="development-colo").controldir + dir = self.make_repository(".", format="development-colo").controldir dir.create_branch(name="branchname") request_class = smart_dir.SmartServerBzrDirRequestDestroyBranch request = request_class(backing) - expected = smart_req.SuccessfulSmartServerResponse((b'ok',)) - self.assertEqual(expected, request.execute(b'', b"branchname")) + expected = smart_req.SuccessfulSmartServerResponse((b"ok",)) + self.assertEqual(expected, request.execute(b"", b"branchname")) def test_destroy_branch_missing(self): """An error is raised if the branch didn't exist.""" backing = self.get_transport() - self.make_controldir('.', format="development-colo") + self.make_controldir(".", format="development-colo") request_class = smart_dir.SmartServerBzrDirRequestDestroyBranch request = request_class(backing) - expected = smart_req.FailedSmartServerResponse((b'nobranch',), None) - self.assertEqual(expected, request.execute(b'', b"branchname")) + expected = smart_req.FailedSmartServerResponse((b"nobranch",), None) + self.assertEqual(expected, request.execute(b"", b"branchname")) -class TestSmartServerBzrDirRequestHasWorkingTree( - tests.TestCaseWithTransport): +class TestSmartServerBzrDirRequestHasWorkingTree(tests.TestCaseWithTransport): """Tests for 
BzrDir.has_workingtree.""" def test_has_workingtree_yes(self): """A working tree is present.""" backing = self.get_transport() - self.make_branch_and_tree('.') + self.make_branch_and_tree(".") request_class = smart_dir.SmartServerBzrDirRequestHasWorkingTree request = request_class(backing) - expected = smart_req.SuccessfulSmartServerResponse((b'yes',)) - self.assertEqual(expected, request.execute(b'')) + expected = smart_req.SuccessfulSmartServerResponse((b"yes",)) + self.assertEqual(expected, request.execute(b"")) def test_has_workingtree_no(self): """A working tree is missing.""" backing = self.get_transport() - self.make_controldir('.') + self.make_controldir(".") request_class = smart_dir.SmartServerBzrDirRequestHasWorkingTree request = request_class(backing) - expected = smart_req.SuccessfulSmartServerResponse((b'no',)) - self.assertEqual(expected, request.execute(b'')) + expected = smart_req.SuccessfulSmartServerResponse((b"no",)) + self.assertEqual(expected, request.execute(b"")) -class TestSmartServerBzrDirRequestDestroyRepository( - tests.TestCaseWithMemoryTransport): +class TestSmartServerBzrDirRequestDestroyRepository(tests.TestCaseWithMemoryTransport): """Tests for BzrDir.destroy_repository.""" def test_destroy_repository_default(self): """The repository can be removed.""" backing = self.get_transport() - self.make_repository('.') + self.make_repository(".") request_class = smart_dir.SmartServerBzrDirRequestDestroyRepository request = request_class(backing) - expected = smart_req.SuccessfulSmartServerResponse((b'ok',)) - self.assertEqual(expected, request.execute(b'')) + expected = smart_req.SuccessfulSmartServerResponse((b"ok",)) + self.assertEqual(expected, request.execute(b"")) def test_destroy_repository_missing(self): """An error is raised if the repository didn't exist.""" backing = self.get_transport() - self.make_controldir('.') + self.make_controldir(".") request_class = smart_dir.SmartServerBzrDirRequestDestroyRepository request = request_class(backing) - expected = smart_req.FailedSmartServerResponse( - (b'norepository',), None) - self.assertEqual(expected, request.execute(b'')) + expected = smart_req.FailedSmartServerResponse((b"norepository",), None) + self.assertEqual(expected, request.execute(b"")) -class TestSmartServerRequestCreateRepository( - tests.TestCaseWithMemoryTransport): +class TestSmartServerRequestCreateRepository(tests.TestCaseWithMemoryTransport): """Tests for BzrDir.create_repository.""" def test_makes_repository(self): """When there is a bzrdir present, the call succeeds.""" backing = self.get_transport() - self.make_controldir('.') + self.make_controldir(".") request_class = smart_dir.SmartServerRequestCreateRepository request = request_class(backing) - reference_bzrdir_format = controldir.format_registry.get('pack-0.92')() + reference_bzrdir_format = controldir.format_registry.get("pack-0.92")() reference_format = reference_bzrdir_format.repository_format network_name = reference_format.network_name() expected = smart_req.SuccessfulSmartServerResponse( - (b'ok', b'no', b'no', b'no', network_name)) - self.assertEqual(expected, request.execute(b'', network_name, b'True')) + (b"ok", b"no", b"no", b"no", network_name) + ) + self.assertEqual(expected, request.execute(b"", network_name, b"True")) class TestSmartServerRequestFindRepository(tests.TestCaseWithMemoryTransport): @@ -345,9 +367,10 @@ def test_no_repository(self): """If no repository is found, ('norepository', ) is returned.""" backing = self.get_transport() request = 
self._request_class(backing) - self.make_controldir('.') - self.assertEqual(smart_req.SmartServerResponse((b'norepository', )), - request.execute(b'')) + self.make_controldir(".") + self.assertEqual( + smart_req.SmartServerResponse((b"norepository",)), request.execute(b"") + ) def test_nonshared_repository(self): # nonshared repositorys only allow 'find' to return a handle when the @@ -356,154 +379,146 @@ def test_nonshared_repository(self): backing = self.get_transport() request = self._request_class(backing) result = self._make_repository_and_result() - self.assertEqual(result, request.execute(b'')) - self.make_controldir('subdir') - self.assertEqual(smart_req.SmartServerResponse((b'norepository', )), - request.execute(b'subdir')) + self.assertEqual(result, request.execute(b"")) + self.make_controldir("subdir") + self.assertEqual( + smart_req.SmartServerResponse((b"norepository",)), + request.execute(b"subdir"), + ) def _make_repository_and_result(self, shared=False, format=None): """Convenience function to setup a repository. :result: The SmartServerResponse to expect when opening it. """ - repo = self.make_repository('.', shared=shared, format=format) + repo = self.make_repository(".", shared=shared, format=format) if repo.supports_rich_root(): - rich_root = b'yes' + rich_root = b"yes" else: - rich_root = b'no' + rich_root = b"no" if repo._format.supports_tree_reference: - subtrees = b'yes' + subtrees = b"yes" else: - subtrees = b'no' + subtrees = b"no" if repo._format.supports_external_lookups: - external = b'yes' + external = b"yes" else: - external = b'no' - if (smart_dir.SmartServerRequestFindRepositoryV3 == - self._request_class): + external = b"no" + if smart_dir.SmartServerRequestFindRepositoryV3 == self._request_class: return smart_req.SuccessfulSmartServerResponse( - (b'ok', b'', rich_root, subtrees, external, - repo._format.network_name())) - elif (smart_dir.SmartServerRequestFindRepositoryV2 == - self._request_class): + (b"ok", b"", rich_root, subtrees, external, repo._format.network_name()) + ) + elif smart_dir.SmartServerRequestFindRepositoryV2 == self._request_class: # All tests so far are on formats, and for non-external # repositories. 
return smart_req.SuccessfulSmartServerResponse( - (b'ok', b'', rich_root, subtrees, external)) + (b"ok", b"", rich_root, subtrees, external) + ) else: return smart_req.SuccessfulSmartServerResponse( - (b'ok', b'', rich_root, subtrees)) + (b"ok", b"", rich_root, subtrees) + ) def test_shared_repository(self): """For a shared repository, we get 'ok', 'relpath-to-repo'.""" backing = self.get_transport() request = self._request_class(backing) result = self._make_repository_and_result(shared=True) - self.assertEqual(result, request.execute(b'')) - self.make_controldir('subdir') + self.assertEqual(result, request.execute(b"")) + self.make_controldir("subdir") result2 = smart_req.SmartServerResponse( - result.args[0:1] + (b'..', ) + result.args[2:]) - self.assertEqual(result2, - request.execute(b'subdir')) - self.make_controldir('subdir/deeper') + result.args[0:1] + (b"..",) + result.args[2:] + ) + self.assertEqual(result2, request.execute(b"subdir")) + self.make_controldir("subdir/deeper") result3 = smart_req.SmartServerResponse( - result.args[0:1] + (b'../..', ) + result.args[2:]) - self.assertEqual(result3, - request.execute(b'subdir/deeper')) + result.args[0:1] + (b"../..",) + result.args[2:] + ) + self.assertEqual(result3, request.execute(b"subdir/deeper")) def test_rich_root_and_subtree_encoding(self): """Test for the format attributes for rich root and subtree support.""" backing = self.get_transport() request = self._request_class(backing) - result = self._make_repository_and_result( - format='development-subtree') + result = self._make_repository_and_result(format="development-subtree") # check the test will be valid - self.assertEqual(b'yes', result.args[2]) - self.assertEqual(b'yes', result.args[3]) - self.assertEqual(result, request.execute(b'')) + self.assertEqual(b"yes", result.args[2]) + self.assertEqual(b"yes", result.args[3]) + self.assertEqual(result, request.execute(b"")) def test_supports_external_lookups_no_v2(self): """Test for the supports_external_lookups attribute.""" backing = self.get_transport() request = self._request_class(backing) - result = self._make_repository_and_result( - format='development-subtree') + result = self._make_repository_and_result(format="development-subtree") # check the test will be valid - self.assertEqual(b'yes', result.args[4]) - self.assertEqual(result, request.execute(b'')) + self.assertEqual(b"yes", result.args[4]) + self.assertEqual(result, request.execute(b"")) -class TestSmartServerBzrDirRequestGetConfigFile( - tests.TestCaseWithMemoryTransport): +class TestSmartServerBzrDirRequestGetConfigFile(tests.TestCaseWithMemoryTransport): """Tests for BzrDir.get_config_file.""" def test_present(self): backing = self.get_transport() - dir = self.make_controldir('.') + dir = self.make_controldir(".") dir.get_config().set_default_stack_on("/") local_result = dir._get_config()._get_config_file().read() request_class = smart_dir.SmartServerBzrDirRequestConfigFile request = request_class(backing) expected = smart_req.SuccessfulSmartServerResponse((), local_result) - self.assertEqual(expected, request.execute(b'')) + self.assertEqual(expected, request.execute(b"")) def test_missing(self): backing = self.get_transport() - self.make_controldir('.') + self.make_controldir(".") request_class = smart_dir.SmartServerBzrDirRequestConfigFile request = request_class(backing) - expected = smart_req.SuccessfulSmartServerResponse((), b'') - self.assertEqual(expected, request.execute(b'')) + expected = smart_req.SuccessfulSmartServerResponse((), b"") + 
self.assertEqual(expected, request.execute(b"")) -class TestSmartServerBzrDirRequestGetBranches( - tests.TestCaseWithMemoryTransport): +class TestSmartServerBzrDirRequestGetBranches(tests.TestCaseWithMemoryTransport): """Tests for BzrDir.get_branches.""" def test_simple(self): backing = self.get_transport() - branch = self.make_branch('.') + branch = self.make_branch(".") request_class = smart_dir.SmartServerBzrDirRequestGetBranches request = request_class(backing) local_result = bencode.bencode( - {b"": (b"branch", branch._format.network_name())}) - expected = smart_req.SuccessfulSmartServerResponse( - (b"success", ), local_result) - self.assertEqual(expected, request.execute(b'')) + {b"": (b"branch", branch._format.network_name())} + ) + expected = smart_req.SuccessfulSmartServerResponse((b"success",), local_result) + self.assertEqual(expected, request.execute(b"")) def test_ref(self): backing = self.get_transport() - dir = self.make_controldir('foo') - b = self.make_branch('bar') + dir = self.make_controldir("foo") + b = self.make_branch("bar") dir.set_branch_reference(b) request_class = smart_dir.SmartServerBzrDirRequestGetBranches request = request_class(backing) - local_result = bencode.bencode( - {b"": (b"ref", b'../bar/')}) - expected = smart_req.SuccessfulSmartServerResponse( - (b"success", ), local_result) - self.assertEqual(expected, request.execute(b'foo')) + local_result = bencode.bencode({b"": (b"ref", b"../bar/")}) + expected = smart_req.SuccessfulSmartServerResponse((b"success",), local_result) + self.assertEqual(expected, request.execute(b"foo")) def test_empty(self): backing = self.get_transport() - self.make_controldir('.') + self.make_controldir(".") request_class = smart_dir.SmartServerBzrDirRequestGetBranches request = request_class(backing) local_result = bencode.bencode({}) - expected = smart_req.SuccessfulSmartServerResponse( - (b'success',), local_result) - self.assertEqual(expected, request.execute(b'')) + expected = smart_req.SuccessfulSmartServerResponse((b"success",), local_result) + self.assertEqual(expected, request.execute(b"")) -class TestSmartServerRequestInitializeBzrDir( - tests.TestCaseWithMemoryTransport): - +class TestSmartServerRequestInitializeBzrDir(tests.TestCaseWithMemoryTransport): def test_empty_dir(self): """Initializing an empty dir should succeed and do it.""" backing = self.get_transport() request = smart_dir.SmartServerRequestInitializeBzrDir(backing) - self.assertEqual(smart_req.SmartServerResponse((b'ok', )), - request.execute(b'')) + self.assertEqual(smart_req.SmartServerResponse((b"ok",)), request.execute(b"")) made_dir = controldir.ControlDir.open_from_transport(backing) # no branch, tree or repository is expected with the current # default formart. 
@@ -515,20 +530,17 @@ def test_missing_dir(self): """Initializing a missing directory should fail like the bzrdir api.""" backing = self.get_transport() request = smart_dir.SmartServerRequestInitializeBzrDir(backing) - self.assertRaises(transport.NoSuchFile, - request.execute, b'subdir') + self.assertRaises(transport.NoSuchFile, request.execute, b"subdir") def test_initialized_dir(self): """Initializing an extant bzrdir should fail like the bzrdir api.""" backing = self.get_transport() request = smart_dir.SmartServerRequestInitializeBzrDir(backing) - self.make_controldir('subdir') - self.assertRaises(errors.AlreadyControlDirError, - request.execute, b'subdir') + self.make_controldir("subdir") + self.assertRaises(errors.AlreadyControlDirError, request.execute, b"subdir") -class TestSmartServerRequestBzrDirInitializeEx( - tests.TestCaseWithMemoryTransport): +class TestSmartServerRequestBzrDirInitializeEx(tests.TestCaseWithMemoryTransport): """Basic tests for BzrDir.initialize_ex_1.16 in the smart server. The main unit tests in test_bzrdir exercise the API comprehensively. @@ -537,13 +549,16 @@ class TestSmartServerRequestBzrDirInitializeEx( def test_empty_dir(self): """Initializing an empty dir should succeed and do it.""" backing = self.get_transport() - name = self.make_controldir('reference')._format.network_name() + name = self.make_controldir("reference")._format.network_name() request = smart_dir.SmartServerRequestBzrDirInitializeEx(backing) self.assertEqual( - smart_req.SmartServerResponse((b'', b'', b'', b'', b'', b'', name, - b'False', b'', b'', b'')), - request.execute(name, b'', b'True', b'False', b'False', b'', b'', - b'', b'', b'False')) + smart_req.SmartServerResponse( + (b"", b"", b"", b"", b"", b"", name, b"False", b"", b"", b"") + ), + request.execute( + name, b"", b"True", b"False", b"False", b"", b"", b"", b"", b"False" + ), + ) made_dir = controldir.ControlDir.open_from_transport(backing) # no branch, tree or repository is expected with the current # default format. 
@@ -554,405 +569,453 @@ def test_empty_dir(self): def test_missing_dir(self): """Initializing a missing directory should fail like the bzrdir api.""" backing = self.get_transport() - name = self.make_controldir('reference')._format.network_name() + name = self.make_controldir("reference")._format.network_name() request = smart_dir.SmartServerRequestBzrDirInitializeEx(backing) - self.assertRaises(transport.NoSuchFile, request.execute, name, - b'subdir/dir', b'False', b'False', b'False', b'', - b'', b'', b'', b'False') + self.assertRaises( + transport.NoSuchFile, + request.execute, + name, + b"subdir/dir", + b"False", + b"False", + b"False", + b"", + b"", + b"", + b"", + b"False", + ) def test_initialized_dir(self): """Initializing an extant directory should fail like the bzrdir api.""" backing = self.get_transport() - name = self.make_controldir('reference')._format.network_name() + name = self.make_controldir("reference")._format.network_name() request = smart_dir.SmartServerRequestBzrDirInitializeEx(backing) - self.make_controldir('subdir') - self.assertRaises(transport.FileExists, request.execute, name, b'subdir', - b'False', b'False', b'False', b'', b'', b'', b'', - b'False') + self.make_controldir("subdir") + self.assertRaises( + transport.FileExists, + request.execute, + name, + b"subdir", + b"False", + b"False", + b"False", + b"", + b"", + b"", + b"", + b"False", + ) class TestSmartServerRequestOpenBzrDir(tests.TestCaseWithMemoryTransport): - def test_no_directory(self): backing = self.get_transport() request = smart_dir.SmartServerRequestOpenBzrDir(backing) - self.assertEqual(smart_req.SmartServerResponse((b'no', )), - request.execute(b'does-not-exist')) + self.assertEqual( + smart_req.SmartServerResponse((b"no",)), request.execute(b"does-not-exist") + ) def test_empty_directory(self): backing = self.get_transport() - backing.mkdir('empty') + backing.mkdir("empty") request = smart_dir.SmartServerRequestOpenBzrDir(backing) - self.assertEqual(smart_req.SmartServerResponse((b'no', )), - request.execute(b'empty')) + self.assertEqual( + smart_req.SmartServerResponse((b"no",)), request.execute(b"empty") + ) def test_outside_root_client_path(self): backing = self.get_transport() request = smart_dir.SmartServerRequestOpenBzrDir( - backing, root_client_path='root') - self.assertEqual(smart_req.SmartServerResponse((b'no', )), - request.execute(b'not-root')) + backing, root_client_path="root" + ) + self.assertEqual( + smart_req.SmartServerResponse((b"no",)), request.execute(b"not-root") + ) class TestSmartServerRequestOpenBzrDir_2_1(tests.TestCaseWithMemoryTransport): - def test_no_directory(self): backing = self.get_transport() request = smart_dir.SmartServerRequestOpenBzrDir_2_1(backing) - self.assertEqual(smart_req.SmartServerResponse((b'no', )), - request.execute(b'does-not-exist')) + self.assertEqual( + smart_req.SmartServerResponse((b"no",)), request.execute(b"does-not-exist") + ) def test_empty_directory(self): backing = self.get_transport() - backing.mkdir('empty') + backing.mkdir("empty") request = smart_dir.SmartServerRequestOpenBzrDir_2_1(backing) - self.assertEqual(smart_req.SmartServerResponse((b'no', )), - request.execute(b'empty')) + self.assertEqual( + smart_req.SmartServerResponse((b"no",)), request.execute(b"empty") + ) def test_present_without_workingtree(self): backing = self.get_transport() request = smart_dir.SmartServerRequestOpenBzrDir_2_1(backing) - self.make_controldir('.') - self.assertEqual(smart_req.SmartServerResponse((b'yes', b'no')), - request.execute(b'')) + 
self.make_controldir(".") + self.assertEqual( + smart_req.SmartServerResponse((b"yes", b"no")), request.execute(b"") + ) def test_outside_root_client_path(self): backing = self.get_transport() request = smart_dir.SmartServerRequestOpenBzrDir_2_1( - backing, root_client_path='root') - self.assertEqual(smart_req.SmartServerResponse((b'no',)), - request.execute(b'not-root')) + backing, root_client_path="root" + ) + self.assertEqual( + smart_req.SmartServerResponse((b"no",)), request.execute(b"not-root") + ) class TestSmartServerRequestOpenBzrDir_2_1_disk(TestCaseWithChrootedTransport): - def test_present_with_workingtree(self): self.vfs_transport_factory = test_server.LocalURLServer backing = self.get_transport() request = smart_dir.SmartServerRequestOpenBzrDir_2_1(backing) - bd = self.make_controldir('.') + bd = self.make_controldir(".") bd.create_repository() bd.create_branch() bd.create_workingtree() - self.assertEqual(smart_req.SmartServerResponse((b'yes', b'yes')), - request.execute(b'')) + self.assertEqual( + smart_req.SmartServerResponse((b"yes", b"yes")), request.execute(b"") + ) class TestSmartServerRequestOpenBranch(TestCaseWithChrootedTransport): - def test_no_branch(self): """When there is no branch, ('nobranch', ) is returned.""" backing = self.get_transport() request = smart_dir.SmartServerRequestOpenBranch(backing) - self.make_controldir('.') - self.assertEqual(smart_req.SmartServerResponse((b'nobranch', )), - request.execute(b'')) + self.make_controldir(".") + self.assertEqual( + smart_req.SmartServerResponse((b"nobranch",)), request.execute(b"") + ) def test_branch(self): """When there is a branch, 'ok' is returned.""" backing = self.get_transport() request = smart_dir.SmartServerRequestOpenBranch(backing) - self.make_branch('.') - self.assertEqual(smart_req.SmartServerResponse((b'ok', b'')), - request.execute(b'')) + self.make_branch(".") + self.assertEqual( + smart_req.SmartServerResponse((b"ok", b"")), request.execute(b"") + ) def test_branch_reference(self): """When there is a branch reference, the reference URL is returned.""" self.vfs_transport_factory = test_server.LocalURLServer backing = self.get_transport() request = smart_dir.SmartServerRequestOpenBranch(backing) - branch = self.make_branch('branch') - checkout = branch.create_checkout('reference', lightweight=True) - reference_url = _mod_bzrbranch.BranchReferenceFormat().get_reference( - checkout.controldir).encode('utf-8') - self.assertFileEqual(reference_url, 'reference/.bzr/branch/location') - self.assertEqual(smart_req.SmartServerResponse((b'ok', reference_url)), - request.execute(b'reference')) + branch = self.make_branch("branch") + checkout = branch.create_checkout("reference", lightweight=True) + reference_url = ( + _mod_bzrbranch.BranchReferenceFormat() + .get_reference(checkout.controldir) + .encode("utf-8") + ) + self.assertFileEqual(reference_url, "reference/.bzr/branch/location") + self.assertEqual( + smart_req.SmartServerResponse((b"ok", reference_url)), + request.execute(b"reference"), + ) def test_notification_on_branch_from_repository(self): """When there is a repository, the error should return details.""" backing = self.get_transport() request = smart_dir.SmartServerRequestOpenBranch(backing) - self.make_repository('.') - self.assertEqual(smart_req.SmartServerResponse((b'nobranch',)), - request.execute(b'')) + self.make_repository(".") + self.assertEqual( + smart_req.SmartServerResponse((b"nobranch",)), request.execute(b"") + ) class 
TestSmartServerRequestOpenBranchV2(TestCaseWithChrootedTransport): - def test_no_branch(self): """When there is no branch, ('nobranch', ) is returned.""" backing = self.get_transport() - self.make_controldir('.') + self.make_controldir(".") request = smart_dir.SmartServerRequestOpenBranchV2(backing) - self.assertEqual(smart_req.SmartServerResponse((b'nobranch', )), - request.execute(b'')) + self.assertEqual( + smart_req.SmartServerResponse((b"nobranch",)), request.execute(b"") + ) def test_branch(self): """When there is a branch, 'ok' is returned.""" backing = self.get_transport() - expected = self.make_branch('.')._format.network_name() + expected = self.make_branch(".")._format.network_name() request = smart_dir.SmartServerRequestOpenBranchV2(backing) - self.assertEqual(smart_req.SuccessfulSmartServerResponse( - (b'branch', expected)), - request.execute(b'')) + self.assertEqual( + smart_req.SuccessfulSmartServerResponse((b"branch", expected)), + request.execute(b""), + ) def test_branch_reference(self): """When there is a branch reference, the reference URL is returned.""" self.vfs_transport_factory = test_server.LocalURLServer backing = self.get_transport() request = smart_dir.SmartServerRequestOpenBranchV2(backing) - branch = self.make_branch('branch') - checkout = branch.create_checkout('reference', lightweight=True) - reference_url = _mod_bzrbranch.BranchReferenceFormat().get_reference( - checkout.controldir).encode('utf-8') - self.assertFileEqual(reference_url, 'reference/.bzr/branch/location') - self.assertEqual(smart_req.SuccessfulSmartServerResponse( - (b'ref', reference_url)), - request.execute(b'reference')) + branch = self.make_branch("branch") + checkout = branch.create_checkout("reference", lightweight=True) + reference_url = ( + _mod_bzrbranch.BranchReferenceFormat() + .get_reference(checkout.controldir) + .encode("utf-8") + ) + self.assertFileEqual(reference_url, "reference/.bzr/branch/location") + self.assertEqual( + smart_req.SuccessfulSmartServerResponse((b"ref", reference_url)), + request.execute(b"reference"), + ) def test_stacked_branch(self): """Opening a stacked branch does not open the stacked-on branch.""" - trunk = self.make_branch('trunk') - feature = self.make_branch('feature') + trunk = self.make_branch("trunk") + feature = self.make_branch("feature") feature.set_stacked_on_url(trunk.base) opened_branches = [] _mod_branch.Branch.hooks.install_named_hook( - 'open', opened_branches.append, None) + "open", opened_branches.append, None + ) backing = self.get_transport() request = smart_dir.SmartServerRequestOpenBranchV2(backing) request.setup_jail() try: - response = request.execute(b'feature') + response = request.execute(b"feature") finally: request.teardown_jail() expected_format = feature._format.network_name() - self.assertEqual(smart_req.SuccessfulSmartServerResponse( - (b'branch', expected_format)), - response) + self.assertEqual( + smart_req.SuccessfulSmartServerResponse((b"branch", expected_format)), + response, + ) self.assertLength(1, opened_branches) def test_notification_on_branch_from_repository(self): """When there is a repository, the error should return details.""" backing = self.get_transport() request = smart_dir.SmartServerRequestOpenBranchV2(backing) - self.make_repository('.') - self.assertEqual(smart_req.SmartServerResponse((b'nobranch',)), - request.execute(b'')) + self.make_repository(".") + self.assertEqual( + smart_req.SmartServerResponse((b"nobranch",)), request.execute(b"") + ) class 
TestSmartServerRequestOpenBranchV3(TestCaseWithChrootedTransport): - def test_no_branch(self): """When there is no branch, ('nobranch', ) is returned.""" backing = self.get_transport() - self.make_controldir('.') + self.make_controldir(".") request = smart_dir.SmartServerRequestOpenBranchV3(backing) - self.assertEqual(smart_req.SmartServerResponse((b'nobranch',)), - request.execute(b'')) + self.assertEqual( + smart_req.SmartServerResponse((b"nobranch",)), request.execute(b"") + ) def test_branch(self): """When there is a branch, 'ok' is returned.""" backing = self.get_transport() - expected = self.make_branch('.')._format.network_name() + expected = self.make_branch(".")._format.network_name() request = smart_dir.SmartServerRequestOpenBranchV3(backing) - self.assertEqual(smart_req.SuccessfulSmartServerResponse( - (b'branch', expected)), - request.execute(b'')) + self.assertEqual( + smart_req.SuccessfulSmartServerResponse((b"branch", expected)), + request.execute(b""), + ) def test_branch_reference(self): """When there is a branch reference, the reference URL is returned.""" self.vfs_transport_factory = test_server.LocalURLServer backing = self.get_transport() request = smart_dir.SmartServerRequestOpenBranchV3(backing) - branch = self.make_branch('branch') - checkout = branch.create_checkout('reference', lightweight=True) - reference_url = _mod_bzrbranch.BranchReferenceFormat().get_reference( - checkout.controldir).encode('utf-8') - self.assertFileEqual(reference_url, 'reference/.bzr/branch/location') - self.assertEqual(smart_req.SuccessfulSmartServerResponse( - (b'ref', reference_url)), - request.execute(b'reference')) + branch = self.make_branch("branch") + checkout = branch.create_checkout("reference", lightweight=True) + reference_url = ( + _mod_bzrbranch.BranchReferenceFormat() + .get_reference(checkout.controldir) + .encode("utf-8") + ) + self.assertFileEqual(reference_url, "reference/.bzr/branch/location") + self.assertEqual( + smart_req.SuccessfulSmartServerResponse((b"ref", reference_url)), + request.execute(b"reference"), + ) def test_stacked_branch(self): """Opening a stacked branch does not open the stacked-on branch.""" - trunk = self.make_branch('trunk') - feature = self.make_branch('feature') + trunk = self.make_branch("trunk") + feature = self.make_branch("feature") feature.set_stacked_on_url(trunk.base) opened_branches = [] _mod_branch.Branch.hooks.install_named_hook( - 'open', opened_branches.append, None) + "open", opened_branches.append, None + ) backing = self.get_transport() request = smart_dir.SmartServerRequestOpenBranchV3(backing) request.setup_jail() try: - response = request.execute(b'feature') + response = request.execute(b"feature") finally: request.teardown_jail() expected_format = feature._format.network_name() - self.assertEqual(smart_req.SuccessfulSmartServerResponse( - (b'branch', expected_format)), - response) + self.assertEqual( + smart_req.SuccessfulSmartServerResponse((b"branch", expected_format)), + response, + ) self.assertLength(1, opened_branches) def test_notification_on_branch_from_repository(self): """When there is a repository, the error should return details.""" backing = self.get_transport() request = smart_dir.SmartServerRequestOpenBranchV3(backing) - self.make_repository('.') - self.assertEqual(smart_req.SmartServerResponse( - (b'nobranch', b'location is a repository')), - request.execute(b'')) + self.make_repository(".") + self.assertEqual( + smart_req.SmartServerResponse((b"nobranch", b"location is a repository")), + request.execute(b""), 
+ ) class TestSmartServerRequestRevisionHistory(tests.TestCaseWithMemoryTransport): - def test_empty(self): """For an empty branch, the body is empty.""" backing = self.get_transport() request = smart_branch.SmartServerRequestRevisionHistory(backing) - self.make_branch('.') - self.assertEqual(smart_req.SmartServerResponse((b'ok', ), b''), - request.execute(b'')) + self.make_branch(".") + self.assertEqual( + smart_req.SmartServerResponse((b"ok",), b""), request.execute(b"") + ) def test_not_empty(self): """For a non-empty branch, the body is empty.""" backing = self.get_transport() request = smart_branch.SmartServerRequestRevisionHistory(backing) - tree = self.make_branch_and_memory_tree('.') + tree = self.make_branch_and_memory_tree(".") tree.lock_write() - tree.add('') - r1 = tree.commit('1st commit') - r2 = tree.commit('2nd commit', rev_id='\xc8'.encode()) + tree.add("") + r1 = tree.commit("1st commit") + r2 = tree.commit("2nd commit", rev_id="\xc8".encode()) tree.unlock() self.assertEqual( - smart_req.SmartServerResponse((b'ok', ), (b'\x00'.join([r1, r2]))), - request.execute(b'')) + smart_req.SmartServerResponse((b"ok",), (b"\x00".join([r1, r2]))), + request.execute(b""), + ) class TestSmartServerBranchRequest(tests.TestCaseWithMemoryTransport): - def test_no_branch(self): """When there is a bzrdir and no branch, NotBranchError is raised.""" backing = self.get_transport() request = smart_branch.SmartServerBranchRequest(backing) - self.make_controldir('.') - self.assertRaises(errors.NotBranchError, - request.execute, b'') + self.make_controldir(".") + self.assertRaises(errors.NotBranchError, request.execute, b"") def test_branch_reference(self): """When there is a branch reference, NotBranchError is raised.""" backing = self.get_transport() request = smart_branch.SmartServerBranchRequest(backing) - branch = self.make_branch('branch') - branch.create_checkout('reference', lightweight=True) - self.assertRaises(errors.NotBranchError, - request.execute, b'checkout') + branch = self.make_branch("branch") + branch.create_checkout("reference", lightweight=True) + self.assertRaises(errors.NotBranchError, request.execute, b"checkout") -class TestSmartServerBranchRequestLastRevisionInfo( - tests.TestCaseWithMemoryTransport): - +class TestSmartServerBranchRequestLastRevisionInfo(tests.TestCaseWithMemoryTransport): def test_empty(self): """For an empty branch, the result is ('ok', '0', b'null:').""" backing = self.get_transport() - request = smart_branch.SmartServerBranchRequestLastRevisionInfo( - backing) - self.make_branch('.') + request = smart_branch.SmartServerBranchRequestLastRevisionInfo(backing) + self.make_branch(".") self.assertEqual( - smart_req.SmartServerResponse((b'ok', b'0', b'null:')), - request.execute(b'')) + smart_req.SmartServerResponse((b"ok", b"0", b"null:")), request.execute(b"") + ) def test_ghost(self): """For an empty branch, the result is ('ok', '0', b'null:').""" backing = self.get_transport() - request = smart_branch.SmartServerBranchRequestLastRevisionInfo( - backing) - branch = self.make_branch('.') + request = smart_branch.SmartServerBranchRequestLastRevisionInfo(backing) + branch = self.make_branch(".") def last_revision_info(): - raise errors.GhostRevisionsHaveNoRevno(b'revid1', b'revid2') - self.overrideAttr(branch, 'last_revision_info', last_revision_info) - self.assertRaises(errors.GhostRevisionsHaveNoRevno, - request.do_with_branch, branch) + raise errors.GhostRevisionsHaveNoRevno(b"revid1", b"revid2") + + self.overrideAttr(branch, "last_revision_info", 
last_revision_info) + self.assertRaises( + errors.GhostRevisionsHaveNoRevno, request.do_with_branch, branch + ) def test_not_empty(self): """For a non-empty branch, the result is ('ok', 'revno', 'revid').""" backing = self.get_transport() - request = smart_branch.SmartServerBranchRequestLastRevisionInfo( - backing) - tree = self.make_branch_and_memory_tree('.') + request = smart_branch.SmartServerBranchRequestLastRevisionInfo(backing) + tree = self.make_branch_and_memory_tree(".") tree.lock_write() - tree.add('') - rev_id_utf8 = '\xc8'.encode() - tree.commit('1st commit') - tree.commit('2nd commit', rev_id=rev_id_utf8) + tree.add("") + rev_id_utf8 = "\xc8".encode() + tree.commit("1st commit") + tree.commit("2nd commit", rev_id=rev_id_utf8) tree.unlock() self.assertEqual( - smart_req.SmartServerResponse((b'ok', b'2', rev_id_utf8)), - request.execute(b'')) - + smart_req.SmartServerResponse((b"ok", b"2", rev_id_utf8)), + request.execute(b""), + ) -class TestSmartServerBranchRequestRevisionIdToRevno( - tests.TestCaseWithMemoryTransport): +class TestSmartServerBranchRequestRevisionIdToRevno(tests.TestCaseWithMemoryTransport): def test_null(self): backing = self.get_transport() - request = smart_branch.SmartServerBranchRequestRevisionIdToRevno( - backing) - self.make_branch('.') - self.assertEqual(smart_req.SmartServerResponse((b'ok', b'0')), - request.execute(b'', b'null:')) + request = smart_branch.SmartServerBranchRequestRevisionIdToRevno(backing) + self.make_branch(".") + self.assertEqual( + smart_req.SmartServerResponse((b"ok", b"0")), request.execute(b"", b"null:") + ) def test_ghost_revision(self): backing = self.get_transport() - request = smart_branch.SmartServerBranchRequestRevisionIdToRevno( - backing) - branch = self.make_branch('.') + request = smart_branch.SmartServerBranchRequestRevisionIdToRevno(backing) + branch = self.make_branch(".") + def revision_id_to_dotted_revno(revid): - raise errors.GhostRevisionsHaveNoRevno(revid, b'ghost-revid') - self.overrideAttr(branch, 'revision_id_to_dotted_revno', revision_id_to_dotted_revno) + raise errors.GhostRevisionsHaveNoRevno(revid, b"ghost-revid") + + self.overrideAttr( + branch, "revision_id_to_dotted_revno", revision_id_to_dotted_revno + ) self.assertEqual( smart_req.FailedSmartServerResponse( - (b'GhostRevisionsHaveNoRevno', b'revid', b'ghost-revid')), - request.do_with_branch(branch, b'revid')) + (b"GhostRevisionsHaveNoRevno", b"revid", b"ghost-revid") + ), + request.do_with_branch(branch, b"revid"), + ) def test_simple(self): backing = self.get_transport() - request = smart_branch.SmartServerBranchRequestRevisionIdToRevno( - backing) - tree = self.make_branch_and_memory_tree('.') + request = smart_branch.SmartServerBranchRequestRevisionIdToRevno(backing) + tree = self.make_branch_and_memory_tree(".") tree.lock_write() - tree.add('') - r1 = tree.commit('1st commit') + tree.add("") + r1 = tree.commit("1st commit") tree.unlock() self.assertEqual( - smart_req.SmartServerResponse((b'ok', b'1')), - request.execute(b'', r1)) + smart_req.SmartServerResponse((b"ok", b"1")), request.execute(b"", r1) + ) def test_not_found(self): backing = self.get_transport() - request = smart_branch.SmartServerBranchRequestRevisionIdToRevno( - backing) - self.make_branch('.') + request = smart_branch.SmartServerBranchRequestRevisionIdToRevno(backing) + self.make_branch(".") self.assertEqual( - smart_req.FailedSmartServerResponse( - (b'NoSuchRevision', b'idontexist')), - request.execute(b'', b'idontexist')) - + 
smart_req.FailedSmartServerResponse((b"NoSuchRevision", b"idontexist")), + request.execute(b"", b"idontexist"), + ) -class TestSmartServerBranchRequestGetConfigFile( - tests.TestCaseWithMemoryTransport): +class TestSmartServerBranchRequestGetConfigFile(tests.TestCaseWithMemoryTransport): def test_default(self): """With no file, we get empty content.""" backing = self.get_transport() request = smart_branch.SmartServerBranchGetConfigFile(backing) - self.make_branch('.') + self.make_branch(".") # there should be no file by default - content = b'' - self.assertEqual(smart_req.SmartServerResponse((b'ok', ), content), - request.execute(b'')) + content = b"" + self.assertEqual( + smart_req.SmartServerResponse((b"ok",), content), request.execute(b"") + ) def test_with_content(self): # SmartServerBranchGetConfigFile should return the content from @@ -960,15 +1023,15 @@ def test_with_content(self): # may perform more complex processing. backing = self.get_transport() request = smart_branch.SmartServerBranchGetConfigFile(backing) - branch = self.make_branch('.') - branch._transport.put_bytes('branch.conf', b'foo bar baz') + branch = self.make_branch(".") + branch._transport.put_bytes("branch.conf", b"foo bar baz") self.assertEqual( - smart_req.SmartServerResponse((b'ok', ), b'foo bar baz'), - request.execute(b'')) + smart_req.SmartServerResponse((b"ok",), b"foo bar baz"), + request.execute(b""), + ) class TestLockedBranch(tests.TestCaseWithMemoryTransport): - def get_lock_tokens(self, branch): branch_token = branch.lock_write().token repo_token = branch.repository.lock_write().repository_token @@ -977,85 +1040,84 @@ def get_lock_tokens(self, branch): class TestSmartServerBranchRequestPutConfigFile(TestLockedBranch): - def test_with_content(self): backing = self.get_transport() request = smart_branch.SmartServerBranchPutConfigFile(backing) - branch = self.make_branch('.') + branch = self.make_branch(".") branch_token, repo_token = self.get_lock_tokens(branch) - self.assertIs(None, request.execute(b'', branch_token, repo_token)) + self.assertIs(None, request.execute(b"", branch_token, repo_token)) self.assertEqual( - smart_req.SmartServerResponse((b'ok', )), - request.do_body(b'foo bar baz')) + smart_req.SmartServerResponse((b"ok",)), request.do_body(b"foo bar baz") + ) self.assertEqual( - branch.control_transport.get_bytes('branch.conf'), - b'foo bar baz') + branch.control_transport.get_bytes("branch.conf"), b"foo bar baz" + ) branch.unlock() class TestSmartServerBranchRequestSetConfigOption(TestLockedBranch): - def test_value_name(self): - branch = self.make_branch('.') + branch = self.make_branch(".") request = smart_branch.SmartServerBranchRequestSetConfigOption( - branch.controldir.root_transport) + branch.controldir.root_transport + ) branch_token, repo_token = self.get_lock_tokens(branch) config = branch._get_config() - result = request.execute(b'', branch_token, repo_token, b'bar', b'foo', - b'') + result = request.execute(b"", branch_token, repo_token, b"bar", b"foo", b"") self.assertEqual(smart_req.SuccessfulSmartServerResponse(()), result) - self.assertEqual('bar', config.get_option('foo')) + self.assertEqual("bar", config.get_option("foo")) # Cleanup branch.unlock() def test_value_name_section(self): - branch = self.make_branch('.') + branch = self.make_branch(".") request = smart_branch.SmartServerBranchRequestSetConfigOption( - branch.controldir.root_transport) + branch.controldir.root_transport + ) branch_token, repo_token = self.get_lock_tokens(branch) config = branch._get_config() - 
result = request.execute(b'', branch_token, repo_token, b'bar', b'foo', - b'gam') + result = request.execute(b"", branch_token, repo_token, b"bar", b"foo", b"gam") self.assertEqual(smart_req.SuccessfulSmartServerResponse(()), result) - self.assertEqual('bar', config.get_option('foo', 'gam')) + self.assertEqual("bar", config.get_option("foo", "gam")) # Cleanup branch.unlock() class TestSmartServerBranchRequestSetConfigOptionDict(TestLockedBranch): - def setUp(self): TestLockedBranch.setUp(self) # A dict with non-ascii keys and values to exercise unicode # roundtripping. - self.encoded_value_dict = ( - b'd5:ascii1:a11:unicode \xe2\x8c\x9a3:\xe2\x80\xbde') - self.value_dict = { - 'ascii': 'a', 'unicode \N{WATCH}': '\N{INTERROBANG}'} + self.encoded_value_dict = b"d5:ascii1:a11:unicode \xe2\x8c\x9a3:\xe2\x80\xbde" + self.value_dict = {"ascii": "a", "unicode \N{WATCH}": "\N{INTERROBANG}"} def test_value_name(self): - branch = self.make_branch('.') + branch = self.make_branch(".") request = smart_branch.SmartServerBranchRequestSetConfigOptionDict( - branch.controldir.root_transport) + branch.controldir.root_transport + ) branch_token, repo_token = self.get_lock_tokens(branch) config = branch._get_config() - result = request.execute(b'', branch_token, repo_token, - self.encoded_value_dict, b'foo', b'') + result = request.execute( + b"", branch_token, repo_token, self.encoded_value_dict, b"foo", b"" + ) self.assertEqual(smart_req.SuccessfulSmartServerResponse(()), result) - self.assertEqual(self.value_dict, config.get_option('foo')) + self.assertEqual(self.value_dict, config.get_option("foo")) # Cleanup branch.unlock() def test_value_name_section(self): - branch = self.make_branch('.') + branch = self.make_branch(".") request = smart_branch.SmartServerBranchRequestSetConfigOptionDict( - branch.controldir.root_transport) + branch.controldir.root_transport + ) branch_token, repo_token = self.get_lock_tokens(branch) config = branch._get_config() - result = request.execute(b'', branch_token, repo_token, - self.encoded_value_dict, b'foo', b'gam') + result = request.execute( + b"", branch_token, repo_token, self.encoded_value_dict, b"foo", b"gam" + ) self.assertEqual(smart_req.SuccessfulSmartServerResponse(()), result) - self.assertEqual(self.value_dict, config.get_option('foo', 'gam')) + self.assertEqual(self.value_dict, config.get_option("foo", "gam")) # Cleanup branch.unlock() @@ -1065,29 +1127,31 @@ class TestSmartServerBranchRequestSetTagsBytes(TestLockedBranch): # methods] so only need to test straight forward cases. def test_set_bytes(self): - base_branch = self.make_branch('base') + base_branch = self.make_branch("base") tag_bytes = base_branch._get_tags_bytes() # get_lock_tokens takes out a lock. 
branch_token, repo_token = self.get_lock_tokens(base_branch) - request = smart_branch.SmartServerBranchSetTagsBytes( - self.get_transport()) - response = request.execute(b'base', branch_token, repo_token) + request = smart_branch.SmartServerBranchSetTagsBytes(self.get_transport()) + response = request.execute(b"base", branch_token, repo_token) self.assertEqual(None, response) response = request.do_chunk(tag_bytes) self.assertEqual(None, response) response = request.do_end() - self.assertEqual( - smart_req.SuccessfulSmartServerResponse(()), response) + self.assertEqual(smart_req.SuccessfulSmartServerResponse(()), response) base_branch.unlock() def test_lock_failed(self): - base_branch = self.make_branch('base') + base_branch = self.make_branch("base") base_branch.lock_write() tag_bytes = base_branch._get_tags_bytes() - request = smart_branch.SmartServerBranchSetTagsBytes( - self.get_transport()) - self.assertRaises(errors.TokenMismatch, request.execute, - b'base', b'wrong token', b'wrong token') + request = smart_branch.SmartServerBranchSetTagsBytes(self.get_transport()) + self.assertRaises( + errors.TokenMismatch, + request.execute, + b"base", + b"wrong token", + b"wrong token", + ) # The request handler will keep processing the message parts, so even # if the request fails immediately do_chunk and do_end are still # called. @@ -1103,7 +1167,7 @@ def setUp(self): super().setUp() backing_transport = self.get_transport() self.request = self.request_class(backing_transport) - self.tree = self.make_branch_and_memory_tree('.') + self.tree = self.make_branch_and_memory_tree(".") def lock_branch(self): return self.get_lock_tokens(self.tree.branch) @@ -1113,15 +1177,13 @@ def unlock_branch(self): def set_last_revision(self, revision_id, revno): branch_token, repo_token = self.lock_branch() - response = self._set_last_revision( - revision_id, revno, branch_token, repo_token) + response = self._set_last_revision(revision_id, revno, branch_token, repo_token) self.unlock_branch() return response def assertRequestSucceeds(self, revision_id, revno): response = self.set_last_revision(revision_id, revno) - self.assertEqual(smart_req.SuccessfulSmartServerResponse((b'ok',)), - response) + self.assertEqual(smart_req.SuccessfulSmartServerResponse((b"ok",)), response) class TestSetLastRevisionVerbMixin: @@ -1129,22 +1191,22 @@ class TestSetLastRevisionVerbMixin: def test_set_null_to_null(self): """An empty branch can have its last revision set to b'null:'.""" - self.assertRequestSucceeds(b'null:', 0) + self.assertRequestSucceeds(b"null:", 0) def test_NoSuchRevision(self): """If the revision_id is not present, the verb returns NoSuchRevision.""" - revision_id = b'non-existent revision' + revision_id = b"non-existent revision" self.assertEqual( - smart_req.FailedSmartServerResponse( - (b'NoSuchRevision', revision_id)), - self.set_last_revision(revision_id, 1)) + smart_req.FailedSmartServerResponse((b"NoSuchRevision", revision_id)), + self.set_last_revision(revision_id, 1), + ) def make_tree_with_two_commits(self): self.tree.lock_write() - self.tree.add('') - rev_id_utf8 = '\xc8'.encode() - self.tree.commit('1st commit', rev_id=rev_id_utf8) - self.tree.commit('2nd commit', rev_id=b'rev-2') + self.tree.add("") + rev_id_utf8 = "\xc8".encode() + self.tree.commit("1st commit", rev_id=rev_id_utf8) + self.tree.commit("2nd commit", rev_id=b"rev-2") self.tree.unlock() def test_branch_last_revision_info_is_updated(self): @@ -1154,64 +1216,64 @@ def test_branch_last_revision_info_is_updated(self): # Make a branch with an 
empty revision history, but two revisions in # its repository. self.make_tree_with_two_commits() - rev_id_utf8 = '\xc8'.encode() - self.tree.branch.set_last_revision_info(0, b'null:') - self.assertEqual( - (0, b'null:'), self.tree.branch.last_revision_info()) + rev_id_utf8 = "\xc8".encode() + self.tree.branch.set_last_revision_info(0, b"null:") + self.assertEqual((0, b"null:"), self.tree.branch.last_revision_info()) # We can update the branch to a revision that is present in the # repository. self.assertRequestSucceeds(rev_id_utf8, 1) - self.assertEqual( - (1, rev_id_utf8), self.tree.branch.last_revision_info()) + self.assertEqual((1, rev_id_utf8), self.tree.branch.last_revision_info()) def test_branch_last_revision_info_rewind(self): """A branch's tip can be set to a revision that is an ancestor of the current tip. """ self.make_tree_with_two_commits() - rev_id_utf8 = '\xc8'.encode() - self.assertEqual( - (2, b'rev-2'), self.tree.branch.last_revision_info()) + rev_id_utf8 = "\xc8".encode() + self.assertEqual((2, b"rev-2"), self.tree.branch.last_revision_info()) self.assertRequestSucceeds(rev_id_utf8, 1) - self.assertEqual( - (1, rev_id_utf8), self.tree.branch.last_revision_info()) + self.assertEqual((1, rev_id_utf8), self.tree.branch.last_revision_info()) def test_TipChangeRejected(self): """If a pre_change_branch_tip hook raises TipChangeRejected, the verb returns TipChangeRejected. """ - rejection_message = 'rejection message\N{INTERROBANG}' + rejection_message = "rejection message\N{INTERROBANG}" def hook_that_rejects(params): raise errors.TipChangeRejected(rejection_message) + _mod_branch.Branch.hooks.install_named_hook( - 'pre_change_branch_tip', hook_that_rejects, None) + "pre_change_branch_tip", hook_that_rejects, None + ) self.assertEqual( smart_req.FailedSmartServerResponse( - (b'TipChangeRejected', rejection_message.encode('utf-8'))), - self.set_last_revision(b'null:', 0)) + (b"TipChangeRejected", rejection_message.encode("utf-8")) + ), + self.set_last_revision(b"null:", 0), + ) class TestSmartServerBranchRequestSetLastRevision( - SetLastRevisionTestBase, TestSetLastRevisionVerbMixin): + SetLastRevisionTestBase, TestSetLastRevisionVerbMixin +): """Tests for Branch.set_last_revision verb.""" request_class = smart_branch.SmartServerBranchRequestSetLastRevision def _set_last_revision(self, revision_id, revno, branch_token, repo_token): - return self.request.execute( - b'', branch_token, repo_token, revision_id) + return self.request.execute(b"", branch_token, repo_token, revision_id) class TestSmartServerBranchRequestSetLastRevisionInfo( - SetLastRevisionTestBase, TestSetLastRevisionVerbMixin): + SetLastRevisionTestBase, TestSetLastRevisionVerbMixin +): """Tests for Branch.set_last_revision_info verb.""" request_class = smart_branch.SmartServerBranchRequestSetLastRevisionInfo def _set_last_revision(self, revision_id, revno, branch_token, repo_token): - return self.request.execute( - b'', branch_token, repo_token, revno, revision_id) + return self.request.execute(b"", branch_token, repo_token, revno, revision_id) def test_NoSuchRevision(self): """Branch.set_last_revision_info does not have to return @@ -1221,51 +1283,50 @@ def test_NoSuchRevision(self): class TestSmartServerBranchRequestSetLastRevisionEx( - SetLastRevisionTestBase, TestSetLastRevisionVerbMixin): + SetLastRevisionTestBase, TestSetLastRevisionVerbMixin +): """Tests for Branch.set_last_revision_ex verb.""" request_class = smart_branch.SmartServerBranchRequestSetLastRevisionEx def _set_last_revision(self, revision_id, 
revno, branch_token, repo_token): - return self.request.execute( - b'', branch_token, repo_token, revision_id, 0, 0) + return self.request.execute(b"", branch_token, repo_token, revision_id, 0, 0) def assertRequestSucceeds(self, revision_id, revno): response = self.set_last_revision(revision_id, revno) self.assertEqual( - smart_req.SuccessfulSmartServerResponse( - (b'ok', revno, revision_id)), - response) + smart_req.SuccessfulSmartServerResponse((b"ok", revno, revision_id)), + response, + ) def test_branch_last_revision_info_rewind(self): """A branch's tip can be set to a revision that is an ancestor of the current tip, but only if allow_overwrite_descendant is passed. """ self.make_tree_with_two_commits() - rev_id_utf8 = '\xc8'.encode() - self.assertEqual( - (2, b'rev-2'), self.tree.branch.last_revision_info()) + rev_id_utf8 = "\xc8".encode() + self.assertEqual((2, b"rev-2"), self.tree.branch.last_revision_info()) # If allow_overwrite_descendant flag is 0, then trying to set the tip # to an older revision ID has no effect. branch_token, repo_token = self.lock_branch() response = self.request.execute( - b'', branch_token, repo_token, rev_id_utf8, 0, 0) - self.assertEqual( - smart_req.SuccessfulSmartServerResponse((b'ok', 2, b'rev-2')), - response) + b"", branch_token, repo_token, rev_id_utf8, 0, 0 + ) self.assertEqual( - (2, b'rev-2'), self.tree.branch.last_revision_info()) + smart_req.SuccessfulSmartServerResponse((b"ok", 2, b"rev-2")), response + ) + self.assertEqual((2, b"rev-2"), self.tree.branch.last_revision_info()) # If allow_overwrite_descendant flag is 1, then setting the tip to an # ancestor works. response = self.request.execute( - b'', branch_token, repo_token, rev_id_utf8, 0, 1) + b"", branch_token, repo_token, rev_id_utf8, 0, 1 + ) self.assertEqual( - smart_req.SuccessfulSmartServerResponse((b'ok', 1, rev_id_utf8)), - response) + smart_req.SuccessfulSmartServerResponse((b"ok", 1, rev_id_utf8)), response + ) self.unlock_branch() - self.assertEqual( - (1, rev_id_utf8), self.tree.branch.last_revision_info()) + self.assertEqual((1, rev_id_utf8), self.tree.branch.last_revision_info()) def make_branch_with_divergent_history(self): """Make a branch with divergent history in its repo. @@ -1274,16 +1335,16 @@ def make_branch_with_divergent_history(self): 'child-1', which diverges from a common base revision. """ self.tree.lock_write() - self.tree.add('') - self.tree.commit('1st commit') + self.tree.add("") + self.tree.commit("1st commit") revno_1, revid_1 = self.tree.branch.last_revision_info() - self.tree.commit('2nd commit', rev_id=b'child-1') + self.tree.commit("2nd commit", rev_id=b"child-1") # Undo the second commit self.tree.branch.set_last_revision_info(revno_1, revid_1) self.tree.set_parent_ids([revid_1]) # Make a new second commit, child-2. child-2 has diverged from # child-1. - self.tree.commit('2nd commit', rev_id=b'child-2') + self.tree.commit("2nd commit", rev_id=b"child-2") self.tree.unlock() def test_not_allow_diverged(self): @@ -1292,10 +1353,11 @@ def test_not_allow_diverged(self): """ self.make_branch_with_divergent_history() self.assertEqual( - smart_req.FailedSmartServerResponse((b'Diverged',)), - self.set_last_revision(b'child-1', 2)) + smart_req.FailedSmartServerResponse((b"Diverged",)), + self.set_last_revision(b"child-1", 2), + ) # The branch tip was not changed. 
- self.assertEqual(b'child-2', self.tree.branch.last_revision()) + self.assertEqual(b"child-2", self.tree.branch.last_revision()) def test_allow_diverged(self): """If allow_diverged is passed, then setting a divergent history @@ -1303,67 +1365,63 @@ def test_allow_diverged(self): """ self.make_branch_with_divergent_history() branch_token, repo_token = self.lock_branch() - response = self.request.execute( - b'', branch_token, repo_token, b'child-1', 1, 0) + response = self.request.execute(b"", branch_token, repo_token, b"child-1", 1, 0) self.assertEqual( - smart_req.SuccessfulSmartServerResponse((b'ok', 2, b'child-1')), - response) + smart_req.SuccessfulSmartServerResponse((b"ok", 2, b"child-1")), response + ) self.unlock_branch() # The branch tip was changed. - self.assertEqual(b'child-1', self.tree.branch.last_revision()) + self.assertEqual(b"child-1", self.tree.branch.last_revision()) class TestSmartServerBranchBreakLock(tests.TestCaseWithMemoryTransport): - def test_lock_to_break(self): - base_branch = self.make_branch('base') - request = smart_branch.SmartServerBranchBreakLock( - self.get_transport()) + base_branch = self.make_branch("base") + request = smart_branch.SmartServerBranchBreakLock(self.get_transport()) base_branch.lock_write() self.assertEqual( - smart_req.SuccessfulSmartServerResponse((b'ok', ), None), - request.execute(b'base')) + smart_req.SuccessfulSmartServerResponse((b"ok",), None), + request.execute(b"base"), + ) def test_nothing_to_break(self): - self.make_branch('base') - request = smart_branch.SmartServerBranchBreakLock( - self.get_transport()) + self.make_branch("base") + request = smart_branch.SmartServerBranchBreakLock(self.get_transport()) self.assertEqual( - smart_req.SuccessfulSmartServerResponse((b'ok', ), None), - request.execute(b'base')) + smart_req.SuccessfulSmartServerResponse((b"ok",), None), + request.execute(b"base"), + ) class TestSmartServerBranchRequestGetParent(tests.TestCaseWithMemoryTransport): - def test_get_parent_none(self): - self.make_branch('base') + self.make_branch("base") request = smart_branch.SmartServerBranchGetParent(self.get_transport()) - response = request.execute(b'base') - self.assertEqual( - smart_req.SuccessfulSmartServerResponse((b'',)), response) + response = request.execute(b"base") + self.assertEqual(smart_req.SuccessfulSmartServerResponse((b"",)), response) def test_get_parent_something(self): - base_branch = self.make_branch('base') - base_branch.set_parent(self.get_url('foo')) + base_branch = self.make_branch("base") + base_branch.set_parent(self.get_url("foo")) request = smart_branch.SmartServerBranchGetParent(self.get_transport()) - response = request.execute(b'base') + response = request.execute(b"base") self.assertEqual( - smart_req.SuccessfulSmartServerResponse((b"../foo",)), - response) + smart_req.SuccessfulSmartServerResponse((b"../foo",)), response + ) class TestSmartServerBranchRequestSetParent(TestLockedBranch): - def test_set_parent_none(self): - branch = self.make_branch('base', format="1.9") + branch = self.make_branch("base", format="1.9") branch.lock_write() - branch._set_parent_location('foo') + branch._set_parent_location("foo") branch.unlock() request = smart_branch.SmartServerBranchRequestSetParentLocation( - self.get_transport()) + self.get_transport() + ) branch_token, repo_token = self.get_lock_tokens(branch) try: - response = request.execute(b'base', branch_token, repo_token, b'') + response = request.execute(b"base", branch_token, repo_token, b"") finally: branch.unlock() 
self.assertEqual(smart_req.SuccessfulSmartServerResponse(()), response) @@ -1372,81 +1430,76 @@ def test_set_parent_none(self): self.assertEqual(None, branch.get_parent()) def test_set_parent_something(self): - branch = self.make_branch('base', format="1.9") + branch = self.make_branch("base", format="1.9") request = smart_branch.SmartServerBranchRequestSetParentLocation( - self.get_transport()) + self.get_transport() + ) branch_token, repo_token = self.get_lock_tokens(branch) try: - response = request.execute(b'base', branch_token, repo_token, - b'http://bar/') + response = request.execute( + b"base", branch_token, repo_token, b"http://bar/" + ) finally: branch.unlock() self.assertEqual(smart_req.SuccessfulSmartServerResponse(()), response) refreshed = _mod_branch.Branch.open(branch.base) - self.assertEqual('http://bar/', refreshed.get_parent()) + self.assertEqual("http://bar/", refreshed.get_parent()) -class TestSmartServerBranchRequestGetTagsBytes( - tests.TestCaseWithMemoryTransport): +class TestSmartServerBranchRequestGetTagsBytes(tests.TestCaseWithMemoryTransport): # Only called when the branch format and tags match [yay factory # methods] so only need to test straight forward cases. def test_get_bytes(self): - self.make_branch('base') - request = smart_branch.SmartServerBranchGetTagsBytes( - self.get_transport()) - response = request.execute(b'base') - self.assertEqual( - smart_req.SuccessfulSmartServerResponse((b'',)), response) + self.make_branch("base") + request = smart_branch.SmartServerBranchGetTagsBytes(self.get_transport()) + response = request.execute(b"base") + self.assertEqual(smart_req.SuccessfulSmartServerResponse((b"",)), response) -class TestSmartServerBranchRequestGetStackedOnURL( - tests.TestCaseWithMemoryTransport): - +class TestSmartServerBranchRequestGetStackedOnURL(tests.TestCaseWithMemoryTransport): def test_get_stacked_on_url(self): - self.make_branch('base', format='1.6') - stacked_branch = self.make_branch('stacked', format='1.6') + self.make_branch("base", format="1.6") + stacked_branch = self.make_branch("stacked", format="1.6") # typically should be relative - stacked_branch.set_stacked_on_url('../base') + stacked_branch.set_stacked_on_url("../base") request = smart_branch.SmartServerBranchRequestGetStackedOnURL( - self.get_transport()) - response = request.execute(b'stacked') - self.assertEqual( - smart_req.SmartServerResponse((b'ok', b'../base')), - response) + self.get_transport() + ) + response = request.execute(b"stacked") + self.assertEqual(smart_req.SmartServerResponse((b"ok", b"../base")), response) class TestSmartServerBranchRequestLockWrite(TestLockedBranch): - def test_lock_write_on_unlocked_branch(self): backing = self.get_transport() request = smart_branch.SmartServerBranchRequestLockWrite(backing) - branch = self.make_branch('.', format='knit') + branch = self.make_branch(".", format="knit") repository = branch.repository - response = request.execute(b'') + response = request.execute(b"") branch_nonce = branch.control_files._lock.peek().nonce repository_nonce = repository.control_files._lock.peek().nonce - self.assertEqual(smart_req.SmartServerResponse( - (b'ok', branch_nonce, repository_nonce)), - response) + self.assertEqual( + smart_req.SmartServerResponse((b"ok", branch_nonce, repository_nonce)), + response, + ) # The branch (and associated repository) is now locked. Verify that # with a new branch object. 
new_branch = repository.controldir.open_branch() self.assertRaises(errors.LockContention, new_branch.lock_write) # Cleanup request = smart_branch.SmartServerBranchRequestUnlock(backing) - response = request.execute(b'', branch_nonce, repository_nonce) + response = request.execute(b"", branch_nonce, repository_nonce) def test_lock_write_on_locked_branch(self): backing = self.get_transport() request = smart_branch.SmartServerBranchRequestLockWrite(backing) - branch = self.make_branch('.') + branch = self.make_branch(".") branch_token = branch.lock_write().token branch.leave_lock_in_place() branch.unlock() - response = request.execute(b'') - self.assertEqual( - smart_req.SmartServerResponse((b'LockContention',)), response) + response = request.execute(b"") + self.assertEqual(smart_req.SmartServerResponse((b"LockContention",)), response) # Cleanup branch.lock_write(branch_token) branch.dont_leave_lock_in_place() @@ -1455,16 +1508,15 @@ def test_lock_write_on_locked_branch(self): def test_lock_write_with_tokens_on_locked_branch(self): backing = self.get_transport() request = smart_branch.SmartServerBranchRequestLockWrite(backing) - branch = self.make_branch('.', format='knit') + branch = self.make_branch(".", format="knit") branch_token, repo_token = self.get_lock_tokens(branch) branch.leave_lock_in_place() branch.repository.leave_lock_in_place() branch.unlock() - response = request.execute(b'', - branch_token, repo_token) + response = request.execute(b"", branch_token, repo_token) self.assertEqual( - smart_req.SmartServerResponse((b'ok', branch_token, repo_token)), - response) + smart_req.SmartServerResponse((b"ok", branch_token, repo_token)), response + ) # Cleanup branch.repository.lock_write(repo_token) branch.repository.dont_leave_lock_in_place() @@ -1476,15 +1528,13 @@ def test_lock_write_with_tokens_on_locked_branch(self): def test_lock_write_with_mismatched_tokens_on_locked_branch(self): backing = self.get_transport() request = smart_branch.SmartServerBranchRequestLockWrite(backing) - branch = self.make_branch('.', format='knit') + branch = self.make_branch(".", format="knit") branch_token, repo_token = self.get_lock_tokens(branch) branch.leave_lock_in_place() branch.repository.leave_lock_in_place() branch.unlock() - response = request.execute(b'', - branch_token + b'xxx', repo_token) - self.assertEqual( - smart_req.SmartServerResponse((b'TokenMismatch',)), response) + response = request.execute(b"", branch_token + b"xxx", repo_token) + self.assertEqual(smart_req.SmartServerResponse((b"TokenMismatch",)), response) # Cleanup branch.repository.lock_write(repo_token) branch.repository.dont_leave_lock_in_place() @@ -1496,14 +1546,13 @@ def test_lock_write_with_mismatched_tokens_on_locked_branch(self): def test_lock_write_on_locked_repo(self): backing = self.get_transport() request = smart_branch.SmartServerBranchRequestLockWrite(backing) - branch = self.make_branch('.', format='knit') + branch = self.make_branch(".", format="knit") repo = branch.repository repo_token = repo.lock_write().repository_token repo.leave_lock_in_place() repo.unlock() - response = request.execute(b'') - self.assertEqual( - smart_req.SmartServerResponse((b'LockContention',)), response) + response = request.execute(b"") + self.assertEqual(smart_req.SmartServerResponse((b"LockContention",)), response) # Cleanup repo.lock_write(repo_token) repo.dont_leave_lock_in_place() @@ -1512,46 +1561,40 @@ def test_lock_write_on_locked_repo(self): def test_lock_write_on_readonly_transport(self): backing = 
self.get_readonly_transport() request = smart_branch.SmartServerBranchRequestLockWrite(backing) - self.make_branch('.') - root = self.get_transport().clone('/') + self.make_branch(".") + root = self.get_transport().clone("/") path = urlutils.relative_url(root.base, self.get_transport().base) - response = request.execute(path.encode('utf-8')) + response = request.execute(path.encode("utf-8")) error_name, lock_str, why_str = response.args self.assertFalse(response.is_successful()) - self.assertEqual(b'LockFailed', error_name) + self.assertEqual(b"LockFailed", error_name) class TestSmartServerBranchRequestGetPhysicalLockStatus(TestLockedBranch): - def test_true(self): backing = self.get_transport() - request = smart_branch.SmartServerBranchRequestGetPhysicalLockStatus( - backing) - branch = self.make_branch('.') + request = smart_branch.SmartServerBranchRequestGetPhysicalLockStatus(backing) + branch = self.make_branch(".") branch_token, repo_token = self.get_lock_tokens(branch) self.assertEqual(True, branch.get_physical_lock_status()) - response = request.execute(b'') - self.assertEqual( - smart_req.SmartServerResponse((b'yes',)), response) + response = request.execute(b"") + self.assertEqual(smart_req.SmartServerResponse((b"yes",)), response) branch.unlock() def test_false(self): backing = self.get_transport() - request = smart_branch.SmartServerBranchRequestGetPhysicalLockStatus( - backing) - branch = self.make_branch('.') + request = smart_branch.SmartServerBranchRequestGetPhysicalLockStatus(backing) + branch = self.make_branch(".") self.assertEqual(False, branch.get_physical_lock_status()) - response = request.execute(b'') - self.assertEqual( - smart_req.SmartServerResponse((b'no',)), response) + response = request.execute(b"") + self.assertEqual(smart_req.SmartServerResponse((b"no",)), response) class TestSmartServerBranchRequestUnlock(TestLockedBranch): - def test_unlock_on_locked_branch_and_repo(self): backing = self.get_transport() request = smart_branch.SmartServerBranchRequestUnlock(backing) - branch = self.make_branch('.', format='knit') + branch = self.make_branch(".", format="knit") # Lock the branch branch_token, repo_token = self.get_lock_tokens(branch) # Unlock the branch (and repo) object, leaving the physical locks @@ -1559,10 +1602,8 @@ def test_unlock_on_locked_branch_and_repo(self): branch.leave_lock_in_place() branch.repository.leave_lock_in_place() branch.unlock() - response = request.execute(b'', - branch_token, repo_token) - self.assertEqual( - smart_req.SmartServerResponse((b'ok',)), response) + response = request.execute(b"", branch_token, repo_token) + self.assertEqual(smart_req.SmartServerResponse((b"ok",)), response) # The branch is now unlocked. Verify that with a new branch # object. 
new_branch = branch.controldir.open_branch() @@ -1572,25 +1613,22 @@ def test_unlock_on_locked_branch_and_repo(self): def test_unlock_on_unlocked_branch_unlocked_repo(self): backing = self.get_transport() request = smart_branch.SmartServerBranchRequestUnlock(backing) - self.make_branch('.', format='knit') - response = request.execute( - b'', b'branch token', b'repo token') - self.assertEqual( - smart_req.SmartServerResponse((b'TokenMismatch',)), response) + self.make_branch(".", format="knit") + response = request.execute(b"", b"branch token", b"repo token") + self.assertEqual(smart_req.SmartServerResponse((b"TokenMismatch",)), response) def test_unlock_on_unlocked_branch_locked_repo(self): backing = self.get_transport() request = smart_branch.SmartServerBranchRequestUnlock(backing) - branch = self.make_branch('.', format='knit') + branch = self.make_branch(".", format="knit") # Lock the repository. repo_token = branch.repository.lock_write().repository_token branch.repository.leave_lock_in_place() branch.repository.unlock() # Issue branch lock_write request on the unlocked branch (with locked # repo). - response = request.execute(b'', b'branch token', repo_token) - self.assertEqual( - smart_req.SmartServerResponse((b'TokenMismatch',)), response) + response = request.execute(b"", b"branch token", repo_token) + self.assertEqual(smart_req.SmartServerResponse((b"TokenMismatch",)), response) # Cleanup branch.repository.lock_write(repo_token) branch.repository.dont_leave_lock_in_place() @@ -1598,7 +1636,6 @@ def test_unlock_on_unlocked_branch_locked_repo(self): class TestSmartServerRepositoryRequest(tests.TestCaseWithMemoryTransport): - def test_no_repository(self): """Raise NoRepositoryPresent when there is a bzrdir and no repo.""" # we test this using a shared repository above the named path, @@ -1607,256 +1644,270 @@ def test_no_repository(self): # searching. 
backing = self.get_transport() request = smart_repo.SmartServerRepositoryRequest(backing) - self.make_repository('.', shared=True) - self.make_controldir('subdir') - self.assertRaises(errors.NoRepositoryPresent, - request.execute, b'subdir') - + self.make_repository(".", shared=True) + self.make_controldir("subdir") + self.assertRaises(errors.NoRepositoryPresent, request.execute, b"subdir") -class TestSmartServerRepositoryAddSignatureText( - tests.TestCaseWithMemoryTransport): +class TestSmartServerRepositoryAddSignatureText(tests.TestCaseWithMemoryTransport): def test_add_text(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryAddSignatureText(backing) - tree = self.make_branch_and_memory_tree('.') + tree = self.make_branch_and_memory_tree(".") write_token = tree.lock_write() self.addCleanup(tree.unlock) - tree.add('') - tree.commit("Message", rev_id=b'rev1') + tree.add("") + tree.commit("Message", rev_id=b"rev1") tree.branch.repository.start_write_group() write_group_tokens = tree.branch.repository.suspend_write_group() self.assertEqual( - None, request.execute( - b'', write_token, b'rev1', - *[token.encode('utf-8') for token in write_group_tokens])) - response = request.do_body(b'somesignature') + None, + request.execute( + b"", + write_token, + b"rev1", + *[token.encode("utf-8") for token in write_group_tokens], + ), + ) + response = request.do_body(b"somesignature") self.assertTrue(response.is_successful()) - self.assertEqual(response.args[0], b'ok') - write_group_tokens = [token.decode('utf-8') - for token in response.args[1:]] + self.assertEqual(response.args[0], b"ok") + write_group_tokens = [token.decode("utf-8") for token in response.args[1:]] tree.branch.repository.resume_write_group(write_group_tokens) tree.branch.repository.commit_write_group() tree.unlock() - self.assertEqual(b"somesignature", - tree.branch.repository.get_signature_text(b"rev1")) - + self.assertEqual( + b"somesignature", tree.branch.repository.get_signature_text(b"rev1") + ) -class TestSmartServerRepositoryAllRevisionIds( - tests.TestCaseWithMemoryTransport): +class TestSmartServerRepositoryAllRevisionIds(tests.TestCaseWithMemoryTransport): def test_empty(self): """An empty body should be returned for an empty repository.""" backing = self.get_transport() request = smart_repo.SmartServerRepositoryAllRevisionIds(backing) - self.make_repository('.') + self.make_repository(".") self.assertEqual( - smart_req.SuccessfulSmartServerResponse((b"ok", ), b""), - request.execute(b'')) + smart_req.SuccessfulSmartServerResponse((b"ok",), b""), request.execute(b"") + ) def test_some_revisions(self): """An empty body should be returned for an empty repository.""" backing = self.get_transport() request = smart_repo.SmartServerRepositoryAllRevisionIds(backing) - tree = self.make_branch_and_memory_tree('.') + tree = self.make_branch_and_memory_tree(".") tree.lock_write() - tree.add('') - tree.commit(rev_id=b'origineel', message="message") - tree.commit(rev_id=b'nog-een-revisie', message="message") + tree.add("") + tree.commit(rev_id=b"origineel", message="message") + tree.commit(rev_id=b"nog-een-revisie", message="message") tree.unlock() self.assertIn( - request.execute(b''), - [smart_req.SuccessfulSmartServerResponse( - (b"ok", ), b"origineel\nnog-een-revisie"), - smart_req.SuccessfulSmartServerResponse( - (b"ok", ), b"nog-een-revisie\norigineel")]) + request.execute(b""), + [ + smart_req.SuccessfulSmartServerResponse( + (b"ok",), b"origineel\nnog-een-revisie" + ), + 
smart_req.SuccessfulSmartServerResponse( + (b"ok",), b"nog-een-revisie\norigineel" + ), + ], + ) class TestSmartServerRepositoryBreakLock(tests.TestCaseWithMemoryTransport): - def test_lock_to_break(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryBreakLock(backing) - tree = self.make_branch_and_memory_tree('.') + tree = self.make_branch_and_memory_tree(".") tree.branch.repository.lock_write() self.assertEqual( - smart_req.SuccessfulSmartServerResponse((b'ok', ), None), - request.execute(b'')) + smart_req.SuccessfulSmartServerResponse((b"ok",), None), + request.execute(b""), + ) def test_nothing_to_break(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryBreakLock(backing) - self.make_branch_and_memory_tree('.') + self.make_branch_and_memory_tree(".") self.assertEqual( - smart_req.SuccessfulSmartServerResponse((b'ok', ), None), - request.execute(b'')) + smart_req.SuccessfulSmartServerResponse((b"ok",), None), + request.execute(b""), + ) class TestSmartServerRepositoryGetParentMap(tests.TestCaseWithMemoryTransport): - def test_trivial_bzipped(self): # This tests that the wire encoding is actually bzipped backing = self.get_transport() request = smart_repo.SmartServerRepositoryGetParentMap(backing) - self.make_branch_and_memory_tree('.') + self.make_branch_and_memory_tree(".") - self.assertEqual(None, - request.execute(b'', b'missing-id')) + self.assertEqual(None, request.execute(b"", b"missing-id")) # Note that it returns a body that is bzipped. self.assertEqual( - smart_req.SuccessfulSmartServerResponse( - (b'ok', ), bz2.compress(b'')), - request.do_body(b'\n\n0\n')) + smart_req.SuccessfulSmartServerResponse((b"ok",), bz2.compress(b"")), + request.do_body(b"\n\n0\n"), + ) def test_trivial_include_missing(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryGetParentMap(backing) - self.make_branch_and_memory_tree('.') + self.make_branch_and_memory_tree(".") - self.assertEqual( - None, request.execute(b'', b'missing-id', b'include-missing:')) + self.assertEqual(None, request.execute(b"", b"missing-id", b"include-missing:")) self.assertEqual( smart_req.SuccessfulSmartServerResponse( - (b'ok', ), bz2.compress(b'missing:missing-id')), - request.do_body(b'\n\n0\n')) - + (b"ok",), bz2.compress(b"missing:missing-id") + ), + request.do_body(b"\n\n0\n"), + ) -class TestSmartServerRepositoryGetRevisionGraph( - tests.TestCaseWithMemoryTransport): +class TestSmartServerRepositoryGetRevisionGraph(tests.TestCaseWithMemoryTransport): def test_none_argument(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryGetRevisionGraph(backing) - tree = self.make_branch_and_memory_tree('.') + tree = self.make_branch_and_memory_tree(".") tree.lock_write() - tree.add('') - r1 = tree.commit('1st commit') - r2 = tree.commit('2nd commit', rev_id='\xc8'.encode()) + tree.add("") + r1 = tree.commit("1st commit") + r2 = tree.commit("2nd commit", rev_id="\xc8".encode()) tree.unlock() # the lines of revision_id->revision_parent_list has no guaranteed # order coming out of a dict, so sort both our test and response - lines = sorted([b' '.join([r2, r1]), r1]) - response = request.execute(b'', b'') - response.body = b'\n'.join(sorted(response.body.split(b'\n'))) + lines = sorted([b" ".join([r2, r1]), r1]) + response = request.execute(b"", b"") + response.body = b"\n".join(sorted(response.body.split(b"\n"))) self.assertEqual( - smart_req.SmartServerResponse((b'ok', ), b'\n'.join(lines)), - response) + 
smart_req.SmartServerResponse((b"ok",), b"\n".join(lines)), response + ) def test_specific_revision_argument(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryGetRevisionGraph(backing) - tree = self.make_branch_and_memory_tree('.') + tree = self.make_branch_and_memory_tree(".") tree.lock_write() - tree.add('') - rev_id_utf8 = '\xc9'.encode() - tree.commit('1st commit', rev_id=rev_id_utf8) - tree.commit('2nd commit', rev_id='\xc8'.encode()) + tree.add("") + rev_id_utf8 = "\xc9".encode() + tree.commit("1st commit", rev_id=rev_id_utf8) + tree.commit("2nd commit", rev_id="\xc8".encode()) tree.unlock() - self.assertEqual(smart_req.SmartServerResponse((b'ok', ), rev_id_utf8), - request.execute(b'', rev_id_utf8)) + self.assertEqual( + smart_req.SmartServerResponse((b"ok",), rev_id_utf8), + request.execute(b"", rev_id_utf8), + ) def test_no_such_revision(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryGetRevisionGraph(backing) - tree = self.make_branch_and_memory_tree('.') + tree = self.make_branch_and_memory_tree(".") tree.lock_write() - tree.add('') - tree.commit('1st commit') + tree.add("") + tree.commit("1st commit") tree.unlock() # Note that it still returns body (of zero bytes). - self.assertEqual(smart_req.SmartServerResponse( - (b'nosuchrevision', b'missingrevision', ), b''), - request.execute(b'', b'missingrevision')) - + self.assertEqual( + smart_req.SmartServerResponse( + ( + b"nosuchrevision", + b"missingrevision", + ), + b"", + ), + request.execute(b"", b"missingrevision"), + ) -class TestSmartServerRepositoryGetRevIdForRevno( - tests.TestCaseWithMemoryTransport): +class TestSmartServerRepositoryGetRevIdForRevno(tests.TestCaseWithMemoryTransport): def test_revno_found(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryGetRevIdForRevno(backing) - tree = self.make_branch_and_memory_tree('.') + tree = self.make_branch_and_memory_tree(".") tree.lock_write() - tree.add('') - rev1_id_utf8 = '\xc8'.encode() - rev2_id_utf8 = '\xc9'.encode() - tree.commit('1st commit', rev_id=rev1_id_utf8) - tree.commit('2nd commit', rev_id=rev2_id_utf8) + tree.add("") + rev1_id_utf8 = "\xc8".encode() + rev2_id_utf8 = "\xc9".encode() + tree.commit("1st commit", rev_id=rev1_id_utf8) + tree.commit("2nd commit", rev_id=rev2_id_utf8) tree.unlock() - self.assertEqual(smart_req.SmartServerResponse((b'ok', rev1_id_utf8)), - request.execute(b'', 1, (2, rev2_id_utf8))) + self.assertEqual( + smart_req.SmartServerResponse((b"ok", rev1_id_utf8)), + request.execute(b"", 1, (2, rev2_id_utf8)), + ) def test_known_revid_missing(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryGetRevIdForRevno(backing) - self.make_repository('.') + self.make_repository(".") self.assertEqual( - smart_req.FailedSmartServerResponse((b'nosuchrevision', b'ghost')), - request.execute(b'', 1, (2, b'ghost'))) + smart_req.FailedSmartServerResponse((b"nosuchrevision", b"ghost")), + request.execute(b"", 1, (2, b"ghost")), + ) def test_history_incomplete(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryGetRevIdForRevno(backing) - parent = self.make_branch_and_memory_tree('parent', format='1.9') + parent = self.make_branch_and_memory_tree("parent", format="1.9") parent.lock_write() - parent.add([''], ids=[b'TREE_ROOT']) - parent.commit(message='first commit') - r2 = parent.commit(message='second commit') + parent.add([""], ids=[b"TREE_ROOT"]) + parent.commit(message="first commit") + r2 = 
parent.commit(message="second commit") parent.unlock() - local = self.make_branch_and_memory_tree('local', format='1.9') + local = self.make_branch_and_memory_tree("local", format="1.9") local.branch.pull(parent.branch) local.set_parent_ids([r2]) - r3 = local.commit(message='local commit') + r3 = local.commit(message="local commit") local.branch.create_clone_on_transport( - self.get_transport('stacked'), stacked_on=self.get_url('parent')) + self.get_transport("stacked"), stacked_on=self.get_url("parent") + ) self.assertEqual( - smart_req.SmartServerResponse((b'history-incomplete', 2, r2)), - request.execute(b'stacked', 1, (3, r3))) - + smart_req.SmartServerResponse((b"history-incomplete", 2, r2)), + request.execute(b"stacked", 1, (3, r3)), + ) -class TestSmartServerRepositoryIterRevisions( - tests.TestCaseWithMemoryTransport): +class TestSmartServerRepositoryIterRevisions(tests.TestCaseWithMemoryTransport): def test_basic(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryIterRevisions(backing) - tree = self.make_branch_and_memory_tree('.', format='2a') + tree = self.make_branch_and_memory_tree(".", format="2a") tree.lock_write() - tree.add('') - tree.commit('1st commit', rev_id=b"rev1") - tree.commit('2nd commit', rev_id=b"rev2") + tree.add("") + tree.commit("1st commit", rev_id=b"rev1") + tree.commit("2nd commit", rev_id=b"rev2") tree.unlock() - self.assertIs(None, request.execute(b'')) + self.assertIs(None, request.execute(b"")) response = request.do_body(b"rev1\nrev2") self.assertTrue(response.is_successful()) # Format 2a uses serializer format 10 self.assertEqual(response.args, (b"ok", b"10")) self.addCleanup(tree.branch.lock_read().unlock) - entries = [zlib.compress(record.get_bytes_as("fulltext")) for record in - tree.branch.repository.revisions.get_record_stream( - [(b"rev1", ), (b"rev2", )], "unordered", True)] + entries = [ + zlib.compress(record.get_bytes_as("fulltext")) + for record in tree.branch.repository.revisions.get_record_stream( + [(b"rev1",), (b"rev2",)], "unordered", True + ) + ] contents = b"".join(response.body_stream) self.assertIn( contents, - ( - b"".join([entries[0], entries[1]]), - b"".join([entries[1], entries[0]])) + (b"".join([entries[0], entries[1]]), b"".join([entries[1], entries[0]])), ) def test_missing(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryIterRevisions(backing) - self.make_branch_and_memory_tree('.', format='2a') + self.make_branch_and_memory_tree(".", format="2a") - self.assertIs(None, request.execute(b'')) + self.assertIs(None, request.execute(b"")) response = request.do_body(b"rev1\nrev2") self.assertTrue(response.is_successful()) # Format 2a uses serializer format 10 @@ -1867,324 +1918,344 @@ def test_missing(self): class GetStreamTestBase(tests.TestCaseWithMemoryTransport): - def make_two_commit_repo(self): - tree = self.make_branch_and_memory_tree('.') + tree = self.make_branch_and_memory_tree(".") tree.lock_write() - tree.add('') - r1 = tree.commit('1st commit') - r2 = tree.commit('2nd commit', rev_id='\xc8'.encode()) + tree.add("") + r1 = tree.commit("1st commit") + r2 = tree.commit("2nd commit", rev_id="\xc8".encode()) tree.unlock() repo = tree.branch.repository return repo, r1, r2 class TestSmartServerRepositoryGetStream(GetStreamTestBase): - def test_ancestry_of(self): """The search argument may be a 'ancestry-of' some heads'.""" backing = self.get_transport() request = smart_repo.SmartServerRepositoryGetStream(backing) repo, r1, r2 = self.make_two_commit_repo() - 
fetch_spec = [b'ancestry-of', r2] - lines = b'\n'.join(fetch_spec) - request.execute(b'', repo._format.network_name()) + fetch_spec = [b"ancestry-of", r2] + lines = b"\n".join(fetch_spec) + request.execute(b"", repo._format.network_name()) response = request.do_body(lines) - self.assertEqual((b'ok',), response.args) - stream_bytes = b''.join(response.body_stream) - self.assertStartsWith(stream_bytes, b'Bazaar pack format 1') + self.assertEqual((b"ok",), response.args) + stream_bytes = b"".join(response.body_stream) + self.assertStartsWith(stream_bytes, b"Bazaar pack format 1") def test_search(self): """The search argument may be a 'search' of some explicit keys.""" backing = self.get_transport() request = smart_repo.SmartServerRepositoryGetStream(backing) repo, r1, r2 = self.make_two_commit_repo() - fetch_spec = [b'search', r1 + b' ' + r2, b'null:', b'2'] - lines = b'\n'.join(fetch_spec) - request.execute(b'', repo._format.network_name()) + fetch_spec = [b"search", r1 + b" " + r2, b"null:", b"2"] + lines = b"\n".join(fetch_spec) + request.execute(b"", repo._format.network_name()) response = request.do_body(lines) - self.assertEqual((b'ok',), response.args) - stream_bytes = b''.join(response.body_stream) - self.assertStartsWith(stream_bytes, b'Bazaar pack format 1') + self.assertEqual((b"ok",), response.args) + stream_bytes = b"".join(response.body_stream) + self.assertStartsWith(stream_bytes, b"Bazaar pack format 1") def test_search_everything(self): """A search of 'everything' returns a stream.""" backing = self.get_transport() request = smart_repo.SmartServerRepositoryGetStream_1_19(backing) repo, r1, r2 = self.make_two_commit_repo() - serialised_fetch_spec = b'everything' - request.execute(b'', repo._format.network_name()) + serialised_fetch_spec = b"everything" + request.execute(b"", repo._format.network_name()) response = request.do_body(serialised_fetch_spec) - self.assertEqual((b'ok',), response.args) - stream_bytes = b''.join(response.body_stream) - self.assertStartsWith(stream_bytes, b'Bazaar pack format 1') + self.assertEqual((b"ok",), response.args) + stream_bytes = b"".join(response.body_stream) + self.assertStartsWith(stream_bytes, b"Bazaar pack format 1") class TestSmartServerRequestHasRevision(tests.TestCaseWithMemoryTransport): - def test_missing_revision(self): """For a missing revision, ('no', ) is returned.""" backing = self.get_transport() request = smart_repo.SmartServerRequestHasRevision(backing) - self.make_repository('.') - self.assertEqual(smart_req.SmartServerResponse((b'no', )), - request.execute(b'', b'revid')) + self.make_repository(".") + self.assertEqual( + smart_req.SmartServerResponse((b"no",)), request.execute(b"", b"revid") + ) def test_present_revision(self): """For a present revision, ('yes', ) is returned.""" backing = self.get_transport() request = smart_repo.SmartServerRequestHasRevision(backing) - tree = self.make_branch_and_memory_tree('.') + tree = self.make_branch_and_memory_tree(".") tree.lock_write() - tree.add('') - rev_id_utf8 = '\xc8abc'.encode() - tree.commit('a commit', rev_id=rev_id_utf8) + tree.add("") + rev_id_utf8 = "\xc8abc".encode() + tree.commit("a commit", rev_id=rev_id_utf8) tree.unlock() self.assertTrue(tree.branch.repository.has_revision(rev_id_utf8)) - self.assertEqual(smart_req.SmartServerResponse((b'yes', )), - request.execute(b'', rev_id_utf8)) + self.assertEqual( + smart_req.SmartServerResponse((b"yes",)), request.execute(b"", rev_id_utf8) + ) class TestSmartServerRepositoryIterFilesBytes(tests.TestCaseWithTransport): - def 
test_single(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryIterFilesBytes(backing) - t = self.make_branch_and_tree('.') + t = self.make_branch_and_tree(".") self.addCleanup(t.lock_write().unlock) self.build_tree_contents([("file", b"somecontents")]) t.add(["file"], ids=[b"thefileid"]) - t.commit(rev_id=b'somerev', message="add file") - self.assertIs(None, request.execute(b'')) + t.commit(rev_id=b"somerev", message="add file") + self.assertIs(None, request.execute(b"")) response = request.do_body(b"thefileid\0somerev\n") self.assertTrue(response.is_successful()) - self.assertEqual(response.args, (b"ok", )) - self.assertEqual(b"".join(response.body_stream), - b"ok\x000\n" + zlib.compress(b"somecontents")) + self.assertEqual(response.args, (b"ok",)) + self.assertEqual( + b"".join(response.body_stream), + b"ok\x000\n" + zlib.compress(b"somecontents"), + ) def test_missing(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryIterFilesBytes(backing) - t = self.make_branch_and_tree('.') + t = self.make_branch_and_tree(".") self.addCleanup(t.lock_write().unlock) - self.assertIs(None, request.execute(b'')) + self.assertIs(None, request.execute(b"")) response = request.do_body(b"thefileid\0revision\n") self.assertTrue(response.is_successful()) - self.assertEqual(response.args, (b"ok", )) - self.assertEqual(b"".join(response.body_stream), - b"absent\x00thefileid\x00revision\x000\n") + self.assertEqual(response.args, (b"ok",)) + self.assertEqual( + b"".join(response.body_stream), b"absent\x00thefileid\x00revision\x000\n" + ) class TestSmartServerRequestHasSignatureForRevisionId( - tests.TestCaseWithMemoryTransport): - + tests.TestCaseWithMemoryTransport +): def test_missing_revision(self): """For a missing revision, NoSuchRevision is returned.""" backing = self.get_transport() - request = smart_repo.SmartServerRequestHasSignatureForRevisionId( - backing) - self.make_repository('.') + request = smart_repo.SmartServerRequestHasSignatureForRevisionId(backing) + self.make_repository(".") self.assertEqual( - smart_req.FailedSmartServerResponse( - (b'nosuchrevision', b'revid'), None), - request.execute(b'', b'revid')) + smart_req.FailedSmartServerResponse((b"nosuchrevision", b"revid"), None), + request.execute(b"", b"revid"), + ) def test_missing_signature(self): """For a missing signature, ('no', ) is returned.""" backing = self.get_transport() - request = smart_repo.SmartServerRequestHasSignatureForRevisionId( - backing) - tree = self.make_branch_and_memory_tree('.') + request = smart_repo.SmartServerRequestHasSignatureForRevisionId(backing) + tree = self.make_branch_and_memory_tree(".") tree.lock_write() - tree.add('') - tree.commit('a commit', rev_id=b'A') + tree.add("") + tree.commit("a commit", rev_id=b"A") tree.unlock() - self.assertTrue(tree.branch.repository.has_revision(b'A')) - self.assertEqual(smart_req.SmartServerResponse((b'no', )), - request.execute(b'', b'A')) + self.assertTrue(tree.branch.repository.has_revision(b"A")) + self.assertEqual( + smart_req.SmartServerResponse((b"no",)), request.execute(b"", b"A") + ) def test_present_signature(self): """For a present signature, ('yes', ) is returned.""" backing = self.get_transport() - request = smart_repo.SmartServerRequestHasSignatureForRevisionId( - backing) + request = smart_repo.SmartServerRequestHasSignatureForRevisionId(backing) strategy = gpg.LoopbackGPGStrategy(None) - tree = self.make_branch_and_memory_tree('.') + tree = self.make_branch_and_memory_tree(".") tree.lock_write() - 
tree.add('') - tree.commit('a commit', rev_id=b'A') + tree.add("") + tree.commit("a commit", rev_id=b"A") tree.branch.repository.start_write_group() - tree.branch.repository.sign_revision(b'A', strategy) + tree.branch.repository.sign_revision(b"A", strategy) tree.branch.repository.commit_write_group() tree.unlock() - self.assertTrue(tree.branch.repository.has_revision(b'A')) - self.assertEqual(smart_req.SmartServerResponse((b'yes', )), - request.execute(b'', b'A')) + self.assertTrue(tree.branch.repository.has_revision(b"A")) + self.assertEqual( + smart_req.SmartServerResponse((b"yes",)), request.execute(b"", b"A") + ) class TestSmartServerRepositoryGatherStats(tests.TestCaseWithMemoryTransport): - def test_empty_revid(self): """With an empty revid, we get only size an number and revisions.""" backing = self.get_transport() request = smart_repo.SmartServerRepositoryGatherStats(backing) - repository = self.make_repository('.') + repository = self.make_repository(".") repository.gather_stats() - expected_body = b'revisions: 0\n' + expected_body = b"revisions: 0\n" self.assertEqual( - smart_req.SmartServerResponse((b'ok', ), expected_body), - request.execute(b'', b'', b'no')) + smart_req.SmartServerResponse((b"ok",), expected_body), + request.execute(b"", b"", b"no"), + ) def test_revid_with_committers(self): """For a revid we get more infos.""" backing = self.get_transport() - rev_id_utf8 = '\xc8abc'.encode() + rev_id_utf8 = "\xc8abc".encode() request = smart_repo.SmartServerRepositoryGatherStats(backing) - tree = self.make_branch_and_memory_tree('.') + tree = self.make_branch_and_memory_tree(".") tree.lock_write() - tree.add('') + tree.add("") # Let's build a predictable result - tree.commit('a commit', timestamp=123456.2, timezone=3600) - tree.commit('a commit', timestamp=654321.4, timezone=0, - rev_id=rev_id_utf8) + tree.commit("a commit", timestamp=123456.2, timezone=3600) + tree.commit("a commit", timestamp=654321.4, timezone=0, rev_id=rev_id_utf8) tree.unlock() tree.branch.repository.gather_stats() - expected_body = (b'firstrev: 123456.200 3600\n' - b'latestrev: 654321.400 0\n' - b'revisions: 2\n') + expected_body = ( + b"firstrev: 123456.200 3600\n" + b"latestrev: 654321.400 0\n" + b"revisions: 2\n" + ) self.assertEqual( - smart_req.SmartServerResponse((b'ok', ), expected_body), - request.execute(b'', rev_id_utf8, b'no')) + smart_req.SmartServerResponse((b"ok",), expected_body), + request.execute(b"", rev_id_utf8, b"no"), + ) def test_not_empty_repository_with_committers(self): """For a revid and requesting committers we get the whole thing.""" backing = self.get_transport() - rev_id_utf8 = '\xc8abc'.encode() + rev_id_utf8 = "\xc8abc".encode() request = smart_repo.SmartServerRepositoryGatherStats(backing) - tree = self.make_branch_and_memory_tree('.') + tree = self.make_branch_and_memory_tree(".") tree.lock_write() - tree.add('') + tree.add("") # Let's build a predictable result - tree.commit('a commit', timestamp=123456.2, timezone=3600, - committer='foo') - tree.commit('a commit', timestamp=654321.4, timezone=0, - committer='bar', rev_id=rev_id_utf8) + tree.commit("a commit", timestamp=123456.2, timezone=3600, committer="foo") + tree.commit( + "a commit", + timestamp=654321.4, + timezone=0, + committer="bar", + rev_id=rev_id_utf8, + ) tree.unlock() tree.branch.repository.gather_stats() - expected_body = (b'committers: 2\n' - b'firstrev: 123456.200 3600\n' - b'latestrev: 654321.400 0\n' - b'revisions: 2\n') + expected_body = ( + b"committers: 2\n" + b"firstrev: 123456.200 3600\n" + 
b"latestrev: 654321.400 0\n" + b"revisions: 2\n" + ) self.assertEqual( - smart_req.SmartServerResponse((b'ok', ), expected_body), - request.execute(b'', rev_id_utf8, b'yes')) + smart_req.SmartServerResponse((b"ok",), expected_body), + request.execute(b"", rev_id_utf8, b"yes"), + ) def test_unknown_revid(self): """An unknown revision id causes a 'nosuchrevision' error.""" backing = self.get_transport() request = smart_repo.SmartServerRepositoryGatherStats(backing) - self.make_repository('.') + self.make_repository(".") self.assertEqual( - smart_req.FailedSmartServerResponse( - (b'nosuchrevision', b'mia'), None), - request.execute(b'', b'mia', b'yes')) + smart_req.FailedSmartServerResponse((b"nosuchrevision", b"mia"), None), + request.execute(b"", b"mia", b"yes"), + ) class TestSmartServerRepositoryIsShared(tests.TestCaseWithMemoryTransport): - def test_is_shared(self): """For a shared repository, ('yes', ) is returned.""" backing = self.get_transport() request = smart_repo.SmartServerRepositoryIsShared(backing) - self.make_repository('.', shared=True) - self.assertEqual(smart_req.SmartServerResponse((b'yes', )), - request.execute(b'', )) + self.make_repository(".", shared=True) + self.assertEqual( + smart_req.SmartServerResponse((b"yes",)), + request.execute( + b"", + ), + ) def test_is_not_shared(self): """For a shared repository, ('no', ) is returned.""" backing = self.get_transport() request = smart_repo.SmartServerRepositoryIsShared(backing) - self.make_repository('.', shared=False) - self.assertEqual(smart_req.SmartServerResponse((b'no', )), - request.execute(b'', )) + self.make_repository(".", shared=False) + self.assertEqual( + smart_req.SmartServerResponse((b"no",)), + request.execute( + b"", + ), + ) class TestSmartServerRepositoryGetRevisionSignatureText( - tests.TestCaseWithMemoryTransport): - + tests.TestCaseWithMemoryTransport +): def test_get_signature(self): backing = self.get_transport() - request = smart_repo.SmartServerRepositoryGetRevisionSignatureText( - backing) - bb = self.make_branch_builder('.') - bb.build_commit(rev_id=b'A') + request = smart_repo.SmartServerRepositoryGetRevisionSignatureText(backing) + bb = self.make_branch_builder(".") + bb.build_commit(rev_id=b"A") repo = bb.get_branch().repository strategy = gpg.LoopbackGPGStrategy(None) self.addCleanup(repo.lock_write().unlock) repo.start_write_group() - repo.sign_revision(b'A', strategy) + repo.sign_revision(b"A", strategy) repo.commit_write_group() expected_body = ( - b'-----BEGIN PSEUDO-SIGNED CONTENT-----\n' + - Testament.from_revision(repo, b'A').as_short_text() + - b'-----END PSEUDO-SIGNED CONTENT-----\n') + b"-----BEGIN PSEUDO-SIGNED CONTENT-----\n" + + Testament.from_revision(repo, b"A").as_short_text() + + b"-----END PSEUDO-SIGNED CONTENT-----\n" + ) self.assertEqual( - smart_req.SmartServerResponse((b'ok', ), expected_body), - request.execute(b'', b'A')) - + smart_req.SmartServerResponse((b"ok",), expected_body), + request.execute(b"", b"A"), + ) -class TestSmartServerRepositoryMakeWorkingTrees( - tests.TestCaseWithMemoryTransport): +class TestSmartServerRepositoryMakeWorkingTrees(tests.TestCaseWithMemoryTransport): def test_make_working_trees(self): """For a repository with working trees, ('yes', ) is returned.""" backing = self.get_transport() request = smart_repo.SmartServerRepositoryMakeWorkingTrees(backing) - r = self.make_repository('.') + r = self.make_repository(".") r.set_make_working_trees(True) - self.assertEqual(smart_req.SmartServerResponse((b'yes', )), - request.execute(b'', )) + 
self.assertEqual( + smart_req.SmartServerResponse((b"yes",)), + request.execute( + b"", + ), + ) def test_is_not_shared(self): """For a repository with working trees, ('no', ) is returned.""" backing = self.get_transport() request = smart_repo.SmartServerRepositoryMakeWorkingTrees(backing) - r = self.make_repository('.') + r = self.make_repository(".") r.set_make_working_trees(False) - self.assertEqual(smart_req.SmartServerResponse((b'no', )), - request.execute(b'', )) + self.assertEqual( + smart_req.SmartServerResponse((b"no",)), + request.execute( + b"", + ), + ) class TestSmartServerRepositoryLockWrite(tests.TestCaseWithMemoryTransport): - def test_lock_write_on_unlocked_repo(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryLockWrite(backing) - repository = self.make_repository('.', format='knit') - response = request.execute(b'') + repository = self.make_repository(".", format="knit") + response = request.execute(b"") nonce = repository.control_files._lock.peek().nonce - self.assertEqual(smart_req.SmartServerResponse( - (b'ok', nonce)), response) + self.assertEqual(smart_req.SmartServerResponse((b"ok", nonce)), response) # The repository is now locked. Verify that with a new repository # object. new_repo = repository.controldir.open_repository() self.assertRaises(errors.LockContention, new_repo.lock_write) # Cleanup request = smart_repo.SmartServerRepositoryUnlock(backing) - response = request.execute(b'', nonce) + response = request.execute(b"", nonce) def test_lock_write_on_locked_repo(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryLockWrite(backing) - repository = self.make_repository('.', format='knit') + repository = self.make_repository(".", format="knit") repo_token = repository.lock_write().repository_token repository.leave_lock_in_place() repository.unlock() - response = request.execute(b'') - self.assertEqual( - smart_req.SmartServerResponse((b'LockContention',)), response) + response = request.execute(b"") + self.assertEqual(smart_req.SmartServerResponse((b"LockContention",)), response) # Cleanup repository.lock_write(repo_token) repository.dont_leave_lock_in_place() @@ -2193,72 +2264,65 @@ def test_lock_write_on_locked_repo(self): def test_lock_write_on_readonly_transport(self): backing = self.get_readonly_transport() request = smart_repo.SmartServerRepositoryLockWrite(backing) - self.make_repository('.', format='knit') - response = request.execute(b'') + self.make_repository(".", format="knit") + response = request.execute(b"") self.assertFalse(response.is_successful()) - self.assertEqual(b'LockFailed', response.args[0]) + self.assertEqual(b"LockFailed", response.args[0]) class TestInsertStreamBase(tests.TestCaseWithMemoryTransport): - def make_empty_byte_stream(self, repo): byte_stream = smart_repo._stream_to_byte_stream([], repo._format) - return b''.join(byte_stream) + return b"".join(byte_stream) class TestSmartServerRepositoryInsertStream(TestInsertStreamBase): - def test_insert_stream_empty(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryInsertStream(backing) - repository = self.make_repository('.') - response = request.execute(b'', b'') + repository = self.make_repository(".") + response = request.execute(b"", b"") self.assertEqual(None, response) response = request.do_chunk(self.make_empty_byte_stream(repository)) self.assertEqual(None, response) response = request.do_end() - self.assertEqual(smart_req.SmartServerResponse((b'ok', )), response) + 
self.assertEqual(smart_req.SmartServerResponse((b"ok",)), response) class TestSmartServerRepositoryInsertStreamLocked(TestInsertStreamBase): - def test_insert_stream_empty(self): backing = self.get_transport() - request = smart_repo.SmartServerRepositoryInsertStreamLocked( - backing) - repository = self.make_repository('.', format='knit') + request = smart_repo.SmartServerRepositoryInsertStreamLocked(backing) + repository = self.make_repository(".", format="knit") lock_token = repository.lock_write().repository_token - response = request.execute(b'', b'', lock_token) + response = request.execute(b"", b"", lock_token) self.assertEqual(None, response) response = request.do_chunk(self.make_empty_byte_stream(repository)) self.assertEqual(None, response) response = request.do_end() - self.assertEqual(smart_req.SmartServerResponse((b'ok', )), response) + self.assertEqual(smart_req.SmartServerResponse((b"ok",)), response) repository.unlock() def test_insert_stream_with_wrong_lock_token(self): backing = self.get_transport() - request = smart_repo.SmartServerRepositoryInsertStreamLocked( - backing) - repository = self.make_repository('.', format='knit') + request = smart_repo.SmartServerRepositoryInsertStreamLocked(backing) + repository = self.make_repository(".", format="knit") with repository.lock_write(): self.assertRaises( - errors.TokenMismatch, request.execute, b'', b'', - b'wrong-token') + errors.TokenMismatch, request.execute, b"", b"", b"wrong-token" + ) class TestSmartServerRepositoryUnlock(tests.TestCaseWithMemoryTransport): - def test_unlock_on_locked_repo(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryUnlock(backing) - repository = self.make_repository('.', format='knit') + repository = self.make_repository(".", format="knit") token = repository.lock_write().repository_token repository.leave_lock_in_place() repository.unlock() - response = request.execute(b'', token) - self.assertEqual( - smart_req.SmartServerResponse((b'ok',)), response) + response = request.execute(b"", token) + self.assertEqual(smart_req.SmartServerResponse((b"ok",)), response) # The repository is now unlocked. Verify that with a new repository # object. new_repo = repository.controldir.open_repository() @@ -2268,195 +2332,206 @@ def test_unlock_on_locked_repo(self): def test_unlock_on_unlocked_repo(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryUnlock(backing) - self.make_repository('.', format='knit') - response = request.execute(b'', b'some token') - self.assertEqual( - smart_req.SmartServerResponse((b'TokenMismatch',)), response) - + self.make_repository(".", format="knit") + response = request.execute(b"", b"some token") + self.assertEqual(smart_req.SmartServerResponse((b"TokenMismatch",)), response) -class TestSmartServerRepositoryGetPhysicalLockStatus( - tests.TestCaseWithTransport): +class TestSmartServerRepositoryGetPhysicalLockStatus(tests.TestCaseWithTransport): def test_with_write_lock(self): backing = self.get_transport() - repo = self.make_repository('.') + repo = self.make_repository(".") self.addCleanup(repo.lock_write().unlock) # lock_write() doesn't necessarily actually take a physical # lock out. 
if repo.get_physical_lock_status(): - expected = b'yes' + expected = b"yes" else: - expected = b'no' + expected = b"no" request_class = smart_repo.SmartServerRepositoryGetPhysicalLockStatus request = request_class(backing) - self.assertEqual(smart_req.SuccessfulSmartServerResponse((expected,)), - request.execute(b'', )) + self.assertEqual( + smart_req.SuccessfulSmartServerResponse((expected,)), + request.execute( + b"", + ), + ) def test_without_write_lock(self): backing = self.get_transport() - repo = self.make_repository('.') + repo = self.make_repository(".") self.assertEqual(False, repo.get_physical_lock_status()) request_class = smart_repo.SmartServerRepositoryGetPhysicalLockStatus request = request_class(backing) - self.assertEqual(smart_req.SuccessfulSmartServerResponse((b'no',)), - request.execute(b'', )) + self.assertEqual( + smart_req.SuccessfulSmartServerResponse((b"no",)), + request.execute( + b"", + ), + ) class TestSmartServerRepositoryReconcile(tests.TestCaseWithTransport): - def test_reconcile(self): backing = self.get_transport() - repo = self.make_repository('.') + repo = self.make_repository(".") token = repo.lock_write().repository_token self.addCleanup(repo.unlock) request_class = smart_repo.SmartServerRepositoryReconcile request = request_class(backing) - self.assertEqual(smart_req.SuccessfulSmartServerResponse( - (b'ok', ), - b'garbage_inventories: 0\n' - b'inconsistent_parents: 0\n'), - request.execute(b'', token)) + self.assertEqual( + smart_req.SuccessfulSmartServerResponse( + (b"ok",), b"garbage_inventories: 0\n" b"inconsistent_parents: 0\n" + ), + request.execute(b"", token), + ) class TestSmartServerIsReadonly(tests.TestCaseWithMemoryTransport): - def test_is_readonly_no(self): backing = self.get_transport() request = smart_req.SmartServerIsReadonly(backing) response = request.execute() - self.assertEqual( - smart_req.SmartServerResponse((b'no',)), response) + self.assertEqual(smart_req.SmartServerResponse((b"no",)), response) def test_is_readonly_yes(self): backing = self.get_readonly_transport() request = smart_req.SmartServerIsReadonly(backing) response = request.execute() - self.assertEqual( - smart_req.SmartServerResponse((b'yes',)), response) - + self.assertEqual(smart_req.SmartServerResponse((b"yes",)), response) -class TestSmartServerRepositorySetMakeWorkingTrees( - tests.TestCaseWithMemoryTransport): +class TestSmartServerRepositorySetMakeWorkingTrees(tests.TestCaseWithMemoryTransport): def test_set_false(self): backing = self.get_transport() - repo = self.make_repository('.', shared=True) + repo = self.make_repository(".", shared=True) repo.set_make_working_trees(True) request_class = smart_repo.SmartServerRepositorySetMakeWorkingTrees request = request_class(backing) - self.assertEqual(smart_req.SuccessfulSmartServerResponse((b'ok',)), - request.execute(b'', b'False')) + self.assertEqual( + smart_req.SuccessfulSmartServerResponse((b"ok",)), + request.execute(b"", b"False"), + ) repo = repo.controldir.open_repository() self.assertFalse(repo.make_working_trees()) def test_set_true(self): backing = self.get_transport() - repo = self.make_repository('.', shared=True) + repo = self.make_repository(".", shared=True) repo.set_make_working_trees(False) request_class = smart_repo.SmartServerRepositorySetMakeWorkingTrees request = request_class(backing) - self.assertEqual(smart_req.SuccessfulSmartServerResponse((b'ok',)), - request.execute(b'', b'True')) + self.assertEqual( + smart_req.SuccessfulSmartServerResponse((b"ok",)), + request.execute(b"", b"True"), + 
) repo = repo.controldir.open_repository() self.assertTrue(repo.make_working_trees()) -class TestSmartServerRepositoryGetSerializerFormat( - tests.TestCaseWithMemoryTransport): - +class TestSmartServerRepositoryGetSerializerFormat(tests.TestCaseWithMemoryTransport): def test_get_serializer_format(self): backing = self.get_transport() - self.make_repository('.', format='2a') + self.make_repository(".", format="2a") request_class = smart_repo.SmartServerRepositoryGetSerializerFormat request = request_class(backing) self.assertEqual( - smart_req.SuccessfulSmartServerResponse((b'ok', b'10')), - request.execute(b'')) - + smart_req.SuccessfulSmartServerResponse((b"ok", b"10")), + request.execute(b""), + ) -class TestSmartServerRepositoryWriteGroup( - tests.TestCaseWithMemoryTransport): +class TestSmartServerRepositoryWriteGroup(tests.TestCaseWithMemoryTransport): def test_start_write_group(self): backing = self.get_transport() - repo = self.make_repository('.') + repo = self.make_repository(".") lock_token = repo.lock_write().repository_token self.addCleanup(repo.unlock) request_class = smart_repo.SmartServerRepositoryStartWriteGroup request = request_class(backing) - self.assertEqual(smart_req.SuccessfulSmartServerResponse((b'ok', [])), - request.execute(b'', lock_token)) + self.assertEqual( + smart_req.SuccessfulSmartServerResponse((b"ok", [])), + request.execute(b"", lock_token), + ) def test_start_write_group_unsuspendable(self): backing = self.get_transport() - repo = self.make_repository('.', format='knit') + repo = self.make_repository(".", format="knit") lock_token = repo.lock_write().repository_token self.addCleanup(repo.unlock) request_class = smart_repo.SmartServerRepositoryStartWriteGroup request = request_class(backing) self.assertEqual( - smart_req.FailedSmartServerResponse((b'UnsuspendableWriteGroup',)), - request.execute(b'', lock_token)) + smart_req.FailedSmartServerResponse((b"UnsuspendableWriteGroup",)), + request.execute(b"", lock_token), + ) def test_commit_write_group(self): backing = self.get_transport() - repo = self.make_repository('.') + repo = self.make_repository(".") lock_token = repo.lock_write().repository_token self.addCleanup(repo.unlock) repo.start_write_group() tokens = repo.suspend_write_group() request_class = smart_repo.SmartServerRepositoryCommitWriteGroup request = request_class(backing) - self.assertEqual(smart_req.SuccessfulSmartServerResponse((b'ok',)), - request.execute(b'', lock_token, tokens)) + self.assertEqual( + smart_req.SuccessfulSmartServerResponse((b"ok",)), + request.execute(b"", lock_token, tokens), + ) def test_abort_write_group(self): backing = self.get_transport() - repo = self.make_repository('.') + repo = self.make_repository(".") lock_token = repo.lock_write().repository_token repo.start_write_group() tokens = repo.suspend_write_group() self.addCleanup(repo.unlock) request_class = smart_repo.SmartServerRepositoryAbortWriteGroup request = request_class(backing) - self.assertEqual(smart_req.SuccessfulSmartServerResponse((b'ok',)), - request.execute(b'', lock_token, tokens)) + self.assertEqual( + smart_req.SuccessfulSmartServerResponse((b"ok",)), + request.execute(b"", lock_token, tokens), + ) def test_check_write_group(self): backing = self.get_transport() - repo = self.make_repository('.') + repo = self.make_repository(".") lock_token = repo.lock_write().repository_token repo.start_write_group() tokens = repo.suspend_write_group() self.addCleanup(repo.unlock) request_class = smart_repo.SmartServerRepositoryCheckWriteGroup request = 
request_class(backing) - self.assertEqual(smart_req.SuccessfulSmartServerResponse((b'ok',)), - request.execute(b'', lock_token, tokens)) + self.assertEqual( + smart_req.SuccessfulSmartServerResponse((b"ok",)), + request.execute(b"", lock_token, tokens), + ) def test_check_write_group_invalid(self): backing = self.get_transport() - repo = self.make_repository('.') + repo = self.make_repository(".") lock_token = repo.lock_write().repository_token self.addCleanup(repo.unlock) request_class = smart_repo.SmartServerRepositoryCheckWriteGroup request = request_class(backing) - self.assertEqual(smart_req.FailedSmartServerResponse( - (b'UnresumableWriteGroup', [b'random'], - b'Malformed write group token')), - request.execute(b'', lock_token, [b"random"])) + self.assertEqual( + smart_req.FailedSmartServerResponse( + (b"UnresumableWriteGroup", [b"random"], b"Malformed write group token") + ), + request.execute(b"", lock_token, [b"random"]), + ) class TestSmartServerPackRepositoryAutopack(tests.TestCaseWithTransport): - - def make_repo_needing_autopacking(self, path='.'): + def make_repo_needing_autopacking(self, path="."): # Make a repo in need of autopacking. - tree = self.make_branch_and_tree('.', format='pack-0.92') + tree = self.make_branch_and_tree(".", format="pack-0.92") repo = tree.branch.repository # monkey-patch the pack collection to disable autopacking repo._pack_collection._max_pack_count = lambda count: count for x in range(10): - tree.commit(f'commit {x}') + tree.commit(f"commit {x}") self.assertEqual(10, len(repo._pack_collection.names())) del repo._pack_collection._max_pack_count return repo @@ -2466,49 +2541,47 @@ def test_autopack_needed(self): repo.lock_write() self.addCleanup(repo.unlock) backing = self.get_transport() - request = smart_packrepo.SmartServerPackRepositoryAutopack( - backing) - response = request.execute(b'') - self.assertEqual(smart_req.SmartServerResponse((b'ok',)), response) + request = smart_packrepo.SmartServerPackRepositoryAutopack(backing) + response = request.execute(b"") + self.assertEqual(smart_req.SmartServerResponse((b"ok",)), response) repo._pack_collection.reload_pack_names() self.assertEqual(1, len(repo._pack_collection.names())) def test_autopack_not_needed(self): - tree = self.make_branch_and_tree('.', format='pack-0.92') + tree = self.make_branch_and_tree(".", format="pack-0.92") repo = tree.branch.repository repo.lock_write() self.addCleanup(repo.unlock) for x in range(9): - tree.commit(f'commit {x}') + tree.commit(f"commit {x}") backing = self.get_transport() - request = smart_packrepo.SmartServerPackRepositoryAutopack( - backing) - response = request.execute(b'') - self.assertEqual(smart_req.SmartServerResponse((b'ok',)), response) + request = smart_packrepo.SmartServerPackRepositoryAutopack(backing) + response = request.execute(b"") + self.assertEqual(smart_req.SmartServerResponse((b"ok",)), response) repo._pack_collection.reload_pack_names() self.assertEqual(9, len(repo._pack_collection.names())) def test_autopack_on_nonpack_format(self): """A request to autopack a non-pack repo is a no-op.""" - self.make_repository('.', format='knit') + self.make_repository(".", format="knit") backing = self.get_transport() - request = smart_packrepo.SmartServerPackRepositoryAutopack( - backing) - response = request.execute(b'') - self.assertEqual(smart_req.SmartServerResponse((b'ok',)), response) + request = smart_packrepo.SmartServerPackRepositoryAutopack(backing) + response = request.execute(b"") + 
self.assertEqual(smart_req.SmartServerResponse((b"ok",)), response) class TestSmartServerVfsGet(tests.TestCaseWithMemoryTransport): - def test_unicode_path(self): """VFS requests expect unicode paths to be escaped.""" - filename = 'foo\N{INTERROBANG}' + filename = "foo\N{INTERROBANG}" filename_escaped = urlutils.escape(filename) backing = self.get_transport() request = vfs.GetRequest(backing) - backing.put_bytes_non_atomic(filename_escaped, b'contents') - self.assertEqual(smart_req.SmartServerResponse((b'ok', ), b'contents'), - request.execute(filename_escaped.encode('ascii'))) + backing.put_bytes_non_atomic(filename_escaped, b"contents") + self.assertEqual( + smart_req.SmartServerResponse((b"ok",), b"contents"), + request.execute(filename_escaped.encode("ascii")), + ) class TestHandlers(tests.TestCase): @@ -2522,131 +2595,216 @@ def test_all_registrations_exist(self): try: smart_req.request_handlers.get(key) except AttributeError as e: - raise AttributeError(f'failed to get {key}: {e}') from e + raise AttributeError(f"failed to get {key}: {e}") from e def assertHandlerEqual(self, verb, handler): self.assertEqual(smart_req.request_handlers.get(verb), handler) def test_registered_methods(self): """Test that known methods are registered to the correct object.""" - self.assertHandlerEqual(b'Branch.break_lock', - smart_branch.SmartServerBranchBreakLock) - self.assertHandlerEqual(b'Branch.get_config_file', - smart_branch.SmartServerBranchGetConfigFile) - self.assertHandlerEqual(b'Branch.put_config_file', - smart_branch.SmartServerBranchPutConfigFile) - self.assertHandlerEqual(b'Branch.get_parent', - smart_branch.SmartServerBranchGetParent) - self.assertHandlerEqual(b'Branch.get_physical_lock_status', - smart_branch.SmartServerBranchRequestGetPhysicalLockStatus) - self.assertHandlerEqual(b'Branch.get_tags_bytes', - smart_branch.SmartServerBranchGetTagsBytes) - self.assertHandlerEqual(b'Branch.lock_write', - smart_branch.SmartServerBranchRequestLockWrite) - self.assertHandlerEqual(b'Branch.last_revision_info', - smart_branch.SmartServerBranchRequestLastRevisionInfo) - self.assertHandlerEqual(b'Branch.revision_history', - smart_branch.SmartServerRequestRevisionHistory) - self.assertHandlerEqual(b'Branch.revision_id_to_revno', - smart_branch.SmartServerBranchRequestRevisionIdToRevno) - self.assertHandlerEqual(b'Branch.set_config_option', - smart_branch.SmartServerBranchRequestSetConfigOption) - self.assertHandlerEqual(b'Branch.set_last_revision', - smart_branch.SmartServerBranchRequestSetLastRevision) - self.assertHandlerEqual(b'Branch.set_last_revision_info', - smart_branch.SmartServerBranchRequestSetLastRevisionInfo) - self.assertHandlerEqual(b'Branch.set_last_revision_ex', - smart_branch.SmartServerBranchRequestSetLastRevisionEx) - self.assertHandlerEqual(b'Branch.set_parent_location', - smart_branch.SmartServerBranchRequestSetParentLocation) - self.assertHandlerEqual(b'Branch.unlock', - smart_branch.SmartServerBranchRequestUnlock) - self.assertHandlerEqual(b'BzrDir.destroy_branch', - smart_dir.SmartServerBzrDirRequestDestroyBranch) - self.assertHandlerEqual(b'BzrDir.find_repository', - smart_dir.SmartServerRequestFindRepositoryV1) - self.assertHandlerEqual(b'BzrDir.find_repositoryV2', - smart_dir.SmartServerRequestFindRepositoryV2) - self.assertHandlerEqual(b'BzrDirFormat.initialize', - smart_dir.SmartServerRequestInitializeBzrDir) - self.assertHandlerEqual(b'BzrDirFormat.initialize_ex_1.16', - smart_dir.SmartServerRequestBzrDirInitializeEx) - self.assertHandlerEqual(b'BzrDir.checkout_metadir', 
- smart_dir.SmartServerBzrDirRequestCheckoutMetaDir) - self.assertHandlerEqual(b'BzrDir.cloning_metadir', - smart_dir.SmartServerBzrDirRequestCloningMetaDir) - self.assertHandlerEqual(b'BzrDir.get_branches', - smart_dir.SmartServerBzrDirRequestGetBranches) - self.assertHandlerEqual(b'BzrDir.get_config_file', - smart_dir.SmartServerBzrDirRequestConfigFile) - self.assertHandlerEqual(b'BzrDir.open_branch', - smart_dir.SmartServerRequestOpenBranch) - self.assertHandlerEqual(b'BzrDir.open_branchV2', - smart_dir.SmartServerRequestOpenBranchV2) - self.assertHandlerEqual(b'BzrDir.open_branchV3', - smart_dir.SmartServerRequestOpenBranchV3) - self.assertHandlerEqual(b'PackRepository.autopack', - smart_packrepo.SmartServerPackRepositoryAutopack) - self.assertHandlerEqual(b'Repository.add_signature_text', - smart_repo.SmartServerRepositoryAddSignatureText) - self.assertHandlerEqual(b'Repository.all_revision_ids', - smart_repo.SmartServerRepositoryAllRevisionIds) - self.assertHandlerEqual(b'Repository.break_lock', - smart_repo.SmartServerRepositoryBreakLock) - self.assertHandlerEqual(b'Repository.gather_stats', - smart_repo.SmartServerRepositoryGatherStats) - self.assertHandlerEqual(b'Repository.get_parent_map', - smart_repo.SmartServerRepositoryGetParentMap) - self.assertHandlerEqual(b'Repository.get_physical_lock_status', - smart_repo.SmartServerRepositoryGetPhysicalLockStatus) - self.assertHandlerEqual(b'Repository.get_rev_id_for_revno', - smart_repo.SmartServerRepositoryGetRevIdForRevno) - self.assertHandlerEqual(b'Repository.get_revision_graph', - smart_repo.SmartServerRepositoryGetRevisionGraph) - self.assertHandlerEqual(b'Repository.get_revision_signature_text', - smart_repo.SmartServerRepositoryGetRevisionSignatureText) - self.assertHandlerEqual(b'Repository.get_stream', - smart_repo.SmartServerRepositoryGetStream) - self.assertHandlerEqual(b'Repository.get_stream_1.19', - smart_repo.SmartServerRepositoryGetStream_1_19) - self.assertHandlerEqual(b'Repository.iter_revisions', - smart_repo.SmartServerRepositoryIterRevisions) - self.assertHandlerEqual(b'Repository.has_revision', - smart_repo.SmartServerRequestHasRevision) - self.assertHandlerEqual(b'Repository.insert_stream', - smart_repo.SmartServerRepositoryInsertStream) - self.assertHandlerEqual(b'Repository.insert_stream_locked', - smart_repo.SmartServerRepositoryInsertStreamLocked) - self.assertHandlerEqual(b'Repository.is_shared', - smart_repo.SmartServerRepositoryIsShared) - self.assertHandlerEqual(b'Repository.iter_files_bytes', - smart_repo.SmartServerRepositoryIterFilesBytes) - self.assertHandlerEqual(b'Repository.lock_write', - smart_repo.SmartServerRepositoryLockWrite) - self.assertHandlerEqual(b'Repository.make_working_trees', - smart_repo.SmartServerRepositoryMakeWorkingTrees) - self.assertHandlerEqual(b'Repository.pack', - smart_repo.SmartServerRepositoryPack) - self.assertHandlerEqual(b'Repository.reconcile', - smart_repo.SmartServerRepositoryReconcile) - self.assertHandlerEqual(b'Repository.tarball', - smart_repo.SmartServerRepositoryTarball) - self.assertHandlerEqual(b'Repository.unlock', - smart_repo.SmartServerRepositoryUnlock) - self.assertHandlerEqual(b'Repository.start_write_group', - smart_repo.SmartServerRepositoryStartWriteGroup) - self.assertHandlerEqual(b'Repository.check_write_group', - smart_repo.SmartServerRepositoryCheckWriteGroup) - self.assertHandlerEqual(b'Repository.commit_write_group', - smart_repo.SmartServerRepositoryCommitWriteGroup) - self.assertHandlerEqual(b'Repository.abort_write_group', - 
smart_repo.SmartServerRepositoryAbortWriteGroup) - self.assertHandlerEqual(b'VersionedFileRepository.get_serializer_format', - smart_repo.SmartServerRepositoryGetSerializerFormat) - self.assertHandlerEqual(b'VersionedFileRepository.get_inventories', - smart_repo.SmartServerRepositoryGetInventories) - self.assertHandlerEqual(b'Transport.is_readonly', - smart_req.SmartServerIsReadonly) + self.assertHandlerEqual( + b"Branch.break_lock", smart_branch.SmartServerBranchBreakLock + ) + self.assertHandlerEqual( + b"Branch.get_config_file", smart_branch.SmartServerBranchGetConfigFile + ) + self.assertHandlerEqual( + b"Branch.put_config_file", smart_branch.SmartServerBranchPutConfigFile + ) + self.assertHandlerEqual( + b"Branch.get_parent", smart_branch.SmartServerBranchGetParent + ) + self.assertHandlerEqual( + b"Branch.get_physical_lock_status", + smart_branch.SmartServerBranchRequestGetPhysicalLockStatus, + ) + self.assertHandlerEqual( + b"Branch.get_tags_bytes", smart_branch.SmartServerBranchGetTagsBytes + ) + self.assertHandlerEqual( + b"Branch.lock_write", smart_branch.SmartServerBranchRequestLockWrite + ) + self.assertHandlerEqual( + b"Branch.last_revision_info", + smart_branch.SmartServerBranchRequestLastRevisionInfo, + ) + self.assertHandlerEqual( + b"Branch.revision_history", smart_branch.SmartServerRequestRevisionHistory + ) + self.assertHandlerEqual( + b"Branch.revision_id_to_revno", + smart_branch.SmartServerBranchRequestRevisionIdToRevno, + ) + self.assertHandlerEqual( + b"Branch.set_config_option", + smart_branch.SmartServerBranchRequestSetConfigOption, + ) + self.assertHandlerEqual( + b"Branch.set_last_revision", + smart_branch.SmartServerBranchRequestSetLastRevision, + ) + self.assertHandlerEqual( + b"Branch.set_last_revision_info", + smart_branch.SmartServerBranchRequestSetLastRevisionInfo, + ) + self.assertHandlerEqual( + b"Branch.set_last_revision_ex", + smart_branch.SmartServerBranchRequestSetLastRevisionEx, + ) + self.assertHandlerEqual( + b"Branch.set_parent_location", + smart_branch.SmartServerBranchRequestSetParentLocation, + ) + self.assertHandlerEqual( + b"Branch.unlock", smart_branch.SmartServerBranchRequestUnlock + ) + self.assertHandlerEqual( + b"BzrDir.destroy_branch", smart_dir.SmartServerBzrDirRequestDestroyBranch + ) + self.assertHandlerEqual( + b"BzrDir.find_repository", smart_dir.SmartServerRequestFindRepositoryV1 + ) + self.assertHandlerEqual( + b"BzrDir.find_repositoryV2", smart_dir.SmartServerRequestFindRepositoryV2 + ) + self.assertHandlerEqual( + b"BzrDirFormat.initialize", smart_dir.SmartServerRequestInitializeBzrDir + ) + self.assertHandlerEqual( + b"BzrDirFormat.initialize_ex_1.16", + smart_dir.SmartServerRequestBzrDirInitializeEx, + ) + self.assertHandlerEqual( + b"BzrDir.checkout_metadir", + smart_dir.SmartServerBzrDirRequestCheckoutMetaDir, + ) + self.assertHandlerEqual( + b"BzrDir.cloning_metadir", smart_dir.SmartServerBzrDirRequestCloningMetaDir + ) + self.assertHandlerEqual( + b"BzrDir.get_branches", smart_dir.SmartServerBzrDirRequestGetBranches + ) + self.assertHandlerEqual( + b"BzrDir.get_config_file", smart_dir.SmartServerBzrDirRequestConfigFile + ) + self.assertHandlerEqual( + b"BzrDir.open_branch", smart_dir.SmartServerRequestOpenBranch + ) + self.assertHandlerEqual( + b"BzrDir.open_branchV2", smart_dir.SmartServerRequestOpenBranchV2 + ) + self.assertHandlerEqual( + b"BzrDir.open_branchV3", smart_dir.SmartServerRequestOpenBranchV3 + ) + self.assertHandlerEqual( + b"PackRepository.autopack", smart_packrepo.SmartServerPackRepositoryAutopack + ) + 
self.assertHandlerEqual( + b"Repository.add_signature_text", + smart_repo.SmartServerRepositoryAddSignatureText, + ) + self.assertHandlerEqual( + b"Repository.all_revision_ids", + smart_repo.SmartServerRepositoryAllRevisionIds, + ) + self.assertHandlerEqual( + b"Repository.break_lock", smart_repo.SmartServerRepositoryBreakLock + ) + self.assertHandlerEqual( + b"Repository.gather_stats", smart_repo.SmartServerRepositoryGatherStats + ) + self.assertHandlerEqual( + b"Repository.get_parent_map", smart_repo.SmartServerRepositoryGetParentMap + ) + self.assertHandlerEqual( + b"Repository.get_physical_lock_status", + smart_repo.SmartServerRepositoryGetPhysicalLockStatus, + ) + self.assertHandlerEqual( + b"Repository.get_rev_id_for_revno", + smart_repo.SmartServerRepositoryGetRevIdForRevno, + ) + self.assertHandlerEqual( + b"Repository.get_revision_graph", + smart_repo.SmartServerRepositoryGetRevisionGraph, + ) + self.assertHandlerEqual( + b"Repository.get_revision_signature_text", + smart_repo.SmartServerRepositoryGetRevisionSignatureText, + ) + self.assertHandlerEqual( + b"Repository.get_stream", smart_repo.SmartServerRepositoryGetStream + ) + self.assertHandlerEqual( + b"Repository.get_stream_1.19", + smart_repo.SmartServerRepositoryGetStream_1_19, + ) + self.assertHandlerEqual( + b"Repository.iter_revisions", smart_repo.SmartServerRepositoryIterRevisions + ) + self.assertHandlerEqual( + b"Repository.has_revision", smart_repo.SmartServerRequestHasRevision + ) + self.assertHandlerEqual( + b"Repository.insert_stream", smart_repo.SmartServerRepositoryInsertStream + ) + self.assertHandlerEqual( + b"Repository.insert_stream_locked", + smart_repo.SmartServerRepositoryInsertStreamLocked, + ) + self.assertHandlerEqual( + b"Repository.is_shared", smart_repo.SmartServerRepositoryIsShared + ) + self.assertHandlerEqual( + b"Repository.iter_files_bytes", + smart_repo.SmartServerRepositoryIterFilesBytes, + ) + self.assertHandlerEqual( + b"Repository.lock_write", smart_repo.SmartServerRepositoryLockWrite + ) + self.assertHandlerEqual( + b"Repository.make_working_trees", + smart_repo.SmartServerRepositoryMakeWorkingTrees, + ) + self.assertHandlerEqual( + b"Repository.pack", smart_repo.SmartServerRepositoryPack + ) + self.assertHandlerEqual( + b"Repository.reconcile", smart_repo.SmartServerRepositoryReconcile + ) + self.assertHandlerEqual( + b"Repository.tarball", smart_repo.SmartServerRepositoryTarball + ) + self.assertHandlerEqual( + b"Repository.unlock", smart_repo.SmartServerRepositoryUnlock + ) + self.assertHandlerEqual( + b"Repository.start_write_group", + smart_repo.SmartServerRepositoryStartWriteGroup, + ) + self.assertHandlerEqual( + b"Repository.check_write_group", + smart_repo.SmartServerRepositoryCheckWriteGroup, + ) + self.assertHandlerEqual( + b"Repository.commit_write_group", + smart_repo.SmartServerRepositoryCommitWriteGroup, + ) + self.assertHandlerEqual( + b"Repository.abort_write_group", + smart_repo.SmartServerRepositoryAbortWriteGroup, + ) + self.assertHandlerEqual( + b"VersionedFileRepository.get_serializer_format", + smart_repo.SmartServerRepositoryGetSerializerFormat, + ) + self.assertHandlerEqual( + b"VersionedFileRepository.get_inventories", + smart_repo.SmartServerRepositoryGetInventories, + ) + self.assertHandlerEqual( + b"Transport.is_readonly", smart_req.SmartServerIsReadonly + ) class SmartTCPServerHookTests(tests.TestCaseWithMemoryTransport): @@ -2659,64 +2817,69 @@ def setUp(self): def test_run_server_started_hooks(self): """Test the server started hooks get fired properly.""" 
started_calls = [] - server.SmartTCPServer.hooks.install_named_hook('server_started', - lambda backing_urls, url: started_calls.append( - (backing_urls, url)), - None) + server.SmartTCPServer.hooks.install_named_hook( + "server_started", + lambda backing_urls, url: started_calls.append((backing_urls, url)), + None, + ) started_ex_calls = [] - server.SmartTCPServer.hooks.install_named_hook('server_started_ex', - lambda backing_urls, url: started_ex_calls.append( - (backing_urls, url)), - None) - self.server._sockname = ('example.com', 42) + server.SmartTCPServer.hooks.install_named_hook( + "server_started_ex", + lambda backing_urls, url: started_ex_calls.append((backing_urls, url)), + None, + ) + self.server._sockname = ("example.com", 42) self.server.run_server_started_hooks() - self.assertEqual(started_calls, - [([self.get_transport().base], 'bzr://example.com:42/')]) - self.assertEqual(started_ex_calls, - [([self.get_transport().base], self.server)]) + self.assertEqual( + started_calls, [([self.get_transport().base], "bzr://example.com:42/")] + ) + self.assertEqual(started_ex_calls, [([self.get_transport().base], self.server)]) def test_run_server_started_hooks_ipv6(self): """Test that socknames can contain 4-tuples.""" - self.server._sockname = ('::', 42, 0, 0) + self.server._sockname = ("::", 42, 0, 0) started_calls = [] - server.SmartTCPServer.hooks.install_named_hook('server_started', - lambda backing_urls, url: started_calls.append( - (backing_urls, url)), - None) + server.SmartTCPServer.hooks.install_named_hook( + "server_started", + lambda backing_urls, url: started_calls.append((backing_urls, url)), + None, + ) self.server.run_server_started_hooks() - self.assertEqual(started_calls, - [([self.get_transport().base], 'bzr://:::42/')]) + self.assertEqual(started_calls, [([self.get_transport().base], "bzr://:::42/")]) def test_run_server_stopped_hooks(self): """Test the server stopped hooks.""" - self.server._sockname = ('example.com', 42) + self.server._sockname = ("example.com", 42) stopped_calls = [] - server.SmartTCPServer.hooks.install_named_hook('server_stopped', - lambda backing_urls, url: stopped_calls.append( - (backing_urls, url)), - None) + server.SmartTCPServer.hooks.install_named_hook( + "server_stopped", + lambda backing_urls, url: stopped_calls.append((backing_urls, url)), + None, + ) self.server.run_server_stopped_hooks() - self.assertEqual(stopped_calls, - [([self.get_transport().base], 'bzr://example.com:42/')]) + self.assertEqual( + stopped_calls, [([self.get_transport().base], "bzr://example.com:42/")] + ) class TestSmartServerRepositoryPack(tests.TestCaseWithMemoryTransport): - def test_pack(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryPack(backing) - tree = self.make_branch_and_memory_tree('.') + tree = self.make_branch_and_memory_tree(".") repo_token = tree.branch.repository.lock_write().repository_token - self.assertIs(None, request.execute(b'', repo_token, False)) + self.assertIs(None, request.execute(b"", repo_token, False)) self.assertEqual( - smart_req.SuccessfulSmartServerResponse((b'ok', ), ), - request.do_body(b'')) + smart_req.SuccessfulSmartServerResponse( + (b"ok",), + ), + request.do_body(b""), + ) class TestSmartServerRepositoryGetInventories(tests.TestCaseWithTransport): - def _get_serialized_inventory_delta(self, repository, base_revid, revid): base_inv = repository.revision_tree(base_revid).root_inventory inv = repository.revision_tree(revid).root_inventory @@ -2727,112 +2890,123 @@ def 
_get_serialized_inventory_delta(self, repository, base_revid, revid): def test_single(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryGetInventories(backing) - t = self.make_branch_and_tree('.', format='2a') + t = self.make_branch_and_tree(".", format="2a") self.addCleanup(t.lock_write().unlock) self.build_tree_contents([("file", b"somecontents")]) t.add(["file"], ids=[b"thefileid"]) - t.commit(rev_id=b'somerev', message="add file") - self.assertIs(None, request.execute(b'', b'unordered')) + t.commit(rev_id=b"somerev", message="add file") + self.assertIs(None, request.execute(b"", b"unordered")) response = request.do_body(b"somerev\n") self.assertTrue(response.is_successful()) - self.assertEqual(response.args, (b"ok", )) - stream = [('inventory-deltas', [ - versionedfile.FulltextContentFactory(b'somerev', None, None, - self._get_serialized_inventory_delta( - t.branch.repository, b'null:', b'somerev'))])] - fmt = controldir.format_registry.get('2a')().repository_format + self.assertEqual(response.args, (b"ok",)) + stream = [ + ( + "inventory-deltas", + [ + versionedfile.FulltextContentFactory( + b"somerev", + None, + None, + self._get_serialized_inventory_delta( + t.branch.repository, b"null:", b"somerev" + ), + ) + ], + ) + ] + fmt = controldir.format_registry.get("2a")().repository_format self.assertEqual( b"".join(response.body_stream), - b"".join(smart_repo._stream_to_byte_stream(stream, fmt))) + b"".join(smart_repo._stream_to_byte_stream(stream, fmt)), + ) def test_empty(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryGetInventories(backing) - t = self.make_branch_and_tree('.', format='2a') + t = self.make_branch_and_tree(".", format="2a") self.addCleanup(t.lock_write().unlock) self.build_tree_contents([("file", b"somecontents")]) t.add(["file"], ids=[b"thefileid"]) - t.commit(rev_id=b'somerev', message="add file") - self.assertIs(None, request.execute(b'', b'unordered')) + t.commit(rev_id=b"somerev", message="add file") + self.assertIs(None, request.execute(b"", b"unordered")) response = request.do_body(b"") self.assertTrue(response.is_successful()) - self.assertEqual(response.args, (b"ok", )) - self.assertEqual(b"".join(response.body_stream), - b"Bazaar pack format 1 (introduced in 0.18)\nB54\n\nBazaar repository format 2a (needs bzr 1.16 or later)\nE") + self.assertEqual(response.args, (b"ok",)) + self.assertEqual( + b"".join(response.body_stream), + b"Bazaar pack format 1 (introduced in 0.18)\nB54\n\nBazaar repository format 2a (needs bzr 1.16 or later)\nE", + ) class TestSmartServerRepositoryGetStreamForMissingKeys(GetStreamTestBase): - def test_missing(self): """The search argument may be a 'ancestry-of' some heads'.""" backing = self.get_transport() - request = smart_repo.SmartServerRepositoryGetStreamForMissingKeys( - backing) + request = smart_repo.SmartServerRepositoryGetStreamForMissingKeys(backing) repo, r1, r2 = self.make_two_commit_repo() - request.execute(b'', repo._format.network_name()) - lines = b'inventories\t' + r1 + request.execute(b"", repo._format.network_name()) + lines = b"inventories\t" + r1 response = request.do_body(lines) - self.assertEqual((b'ok',), response.args) - stream_bytes = b''.join(response.body_stream) - self.assertStartsWith(stream_bytes, b'Bazaar pack format 1') + self.assertEqual((b"ok",), response.args) + stream_bytes = b"".join(response.body_stream) + self.assertStartsWith(stream_bytes, b"Bazaar pack format 1") def test_unknown_format(self): """The format may not be known by the 
remote server.""" backing = self.get_transport() - request = smart_repo.SmartServerRepositoryGetStreamForMissingKeys( - backing) + request = smart_repo.SmartServerRepositoryGetStreamForMissingKeys(backing) repo, r1, r2 = self.make_two_commit_repo() - request.execute(b'', b'yada yada yada') + request.execute(b"", b"yada yada yada") smart_req.FailedSmartServerResponse( - (b'UnknownFormat', b'repository', b'yada yada yada')) + (b"UnknownFormat", b"repository", b"yada yada yada") + ) class TestSmartServerRepositoryRevisionArchive(tests.TestCaseWithTransport): def test_get(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryRevisionArchive(backing) - t = self.make_branch_and_tree('.') + t = self.make_branch_and_tree(".") self.addCleanup(t.lock_write().unlock) self.build_tree_contents([("file", b"somecontents")]) t.add(["file"], ids=[b"thefileid"]) - t.commit(rev_id=b'somerev', message="add file") - response = request.execute(b'', b"somerev", b"tar", b"foo.tar", b"foo") + t.commit(rev_id=b"somerev", message="add file") + response = request.execute(b"", b"somerev", b"tar", b"foo.tar", b"foo") self.assertTrue(response.is_successful()) - self.assertEqual(response.args, (b"ok", )) + self.assertEqual(response.args, (b"ok",)) b = BytesIO(b"".join(response.body_stream)) - with tarfile.open(mode='r', fileobj=b) as tf: - self.assertEqual(['foo/file'], tf.getnames()) + with tarfile.open(mode="r", fileobj=b) as tf: + self.assertEqual(["foo/file"], tf.getnames()) class TestSmartServerRepositoryAnnotateFileRevision(tests.TestCaseWithTransport): - def test_get(self): backing = self.get_transport() request = smart_repo.SmartServerRepositoryAnnotateFileRevision(backing) - t = self.make_branch_and_tree('.') + t = self.make_branch_and_tree(".") self.addCleanup(t.lock_write().unlock) self.build_tree_contents([("file", b"somecontents\nmorecontents\n")]) t.add(["file"], ids=[b"thefileid"]) - t.commit(rev_id=b'somerev', message="add file") - response = request.execute(b'', b"somerev", b"file") + t.commit(rev_id=b"somerev", message="add file") + response = request.execute(b"", b"somerev", b"file") self.assertTrue(response.is_successful()) - self.assertEqual(response.args, (b"ok", )) + self.assertEqual(response.args, (b"ok",)) self.assertEqual( - [[b'somerev', b'somecontents\n'], [b'somerev', b'morecontents\n']], - bencode.bdecode(response.body)) + [[b"somerev", b"somecontents\n"], [b"somerev", b"morecontents\n"]], + bencode.bdecode(response.body), + ) class TestSmartServerBranchRequestGetAllReferenceInfo(TestLockedBranch): - def test_get_some(self): backing = self.get_transport() request = smart_branch.SmartServerBranchRequestGetAllReferenceInfo(backing) - branch = self.make_branch('.') - branch.set_reference_info(b'file-id', 'http://www.example.com/', 'some/path') - response = request.execute(b'') + branch = self.make_branch(".") + branch.set_reference_info(b"file-id", "http://www.example.com/", "some/path") + response = request.execute(b"") self.assertTrue(response.is_successful()) - self.assertEqual(response.args, (b"ok", )) + self.assertEqual(response.args, (b"ok",)) self.assertEqual( - [[b'file-id', b'http://www.example.com/', b'some/path']], - bencode.bdecode(response.body)) - + [[b"file-id", b"http://www.example.com/", b"some/path"]], + bencode.bdecode(response.body), + ) diff --git a/breezy/bzr/tests/test_smart_request.py b/breezy/bzr/tests/test_smart_request.py index f61d8d2c17..a92a638ffa 100644 --- a/breezy/bzr/tests/test_smart_request.py +++ 
b/breezy/bzr/tests/test_smart_request.py @@ -29,14 +29,14 @@ class NoBodyRequest(request.SmartServerRequest): """A request that does not implement do_body.""" def do(self): - return request.SuccessfulSmartServerResponse(('ok',)) + return request.SuccessfulSmartServerResponse(("ok",)) class DoErrorRequest(request.SmartServerRequest): """A request that raises an error from self.do().""" def do(self): - raise transport.NoSuchFile('xyzzy') + raise transport.NoSuchFile("xyzzy") class DoUnexpectedErrorRequest(request.SmartServerRequest): @@ -54,7 +54,7 @@ def do(self): pass def do_chunk(self, bytes): - raise transport.NoSuchFile('xyzzy') + raise transport.NoSuchFile("xyzzy") class EndErrorRequest(request.SmartServerRequest): @@ -69,11 +69,10 @@ def do_chunk(self, bytes): pass def do_end(self): - raise transport.NoSuchFile('xyzzy') + raise transport.NoSuchFile("xyzzy") class CheckJailRequest(request.SmartServerRequest): - def __init__(self, *args): request.SmartServerRequest.__init__(self, *args) self.jail_transports_log = [] @@ -89,51 +88,51 @@ def do_end(self): class TestErrors(TestCase): - def test_disabled_method(self): error = request.DisabledMethod("class name") self.assertEqualDiff( - "The smart server method 'class name' is disabled.", str(error)) + "The smart server method 'class name' is disabled.", str(error) + ) class TestSmartRequest(TestCase): - def test_request_class_without_do_body(self): """If a request has no body data, and the request's implementation does not override do_body, then no exception is raised. """ # Create a SmartServerRequestHandler with a SmartServerRequest subclass # that does not implement do_body. - handler = request.SmartServerRequestHandler( - None, {b'foo': NoBodyRequest}, '/') + handler = request.SmartServerRequestHandler(None, {b"foo": NoBodyRequest}, "/") # Emulate a request with no body (i.e. just args). - handler.args_received((b'foo',)) + handler.args_received((b"foo",)) handler.end_received() # Request done, no exception was raised. 
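# Editor's sketch (not part of the diff above): the test just shown drives the
# SmartServerRequestHandler dispatch cycle directly, so here is the same flow in
# one self-contained snippet. It only assumes the entry points the tests
# themselves exercise (args_received, end_received, the .response attribute and
# SuccessfulSmartServerResponse); the EchoRequest class and the b"echo" verb are
# hypothetical names introduced for illustration.
from breezy.bzr.smart import request


class EchoRequest(request.SmartServerRequest):
    """Hypothetical request: echo the verb's arguments back to the caller."""

    def do(self, *args):
        # do() receives the arguments that followed the verb; returning a
        # response here means the request needs no body, so do_body is never
        # required (the same property the NoBodyRequest test relies on).
        return request.SuccessfulSmartServerResponse((b"ok",) + args)


handler = request.SmartServerRequestHandler(None, {b"echo": EchoRequest}, "/")
handler.args_received((b"echo", b"hello"))  # dispatches to EchoRequest.do(b"hello")
handler.end_received()  # request had no body; nothing further happens
assert handler.response.args == (b"ok", b"hello")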
def test_only_request_code_is_jailed(self): - transport = 'dummy transport' + transport = "dummy transport" handler = request.SmartServerRequestHandler( - transport, {b'foo': CheckJailRequest}, '/') - handler.args_received((b'foo',)) + transport, {b"foo": CheckJailRequest}, "/" + ) + handler.args_received((b"foo",)) self.assertEqual(None, request.jail_info.transports) - handler.accept_body(b'bytes') + handler.accept_body(b"bytes") self.assertEqual(None, request.jail_info.transports) handler.end_received() self.assertEqual(None, request.jail_info.transports) - self.assertEqual( - [[transport]] * 3, handler._command.jail_transports_log) + self.assertEqual([[transport]] * 3, handler._command.jail_transports_log) def test_all_registered_requests_are_safety_qualified(self): unclassified_requests = [] - allowed_info = ('read', 'idem', 'mutate', 'semivfs', 'semi', 'stream') + allowed_info = ("read", "idem", "mutate", "semivfs", "semi", "stream") for key in request.request_handlers.keys(): info = request.request_handlers.get_info(key) if info is None or info not in allowed_info: unclassified_requests.append(key) if unclassified_requests: - self.fail('These requests were not categorized as safe/unsafe' - f' to retry: {unclassified_requests}') + self.fail( + "These requests were not categorized as safe/unsafe" + f" to retry: {unclassified_requests}" + ) class TestSmartRequestHandlerErrorTranslation(TestCase): @@ -145,40 +144,43 @@ def assertNoResponse(self, handler): self.assertEqual(None, handler.response) def assertResponseIsTranslatedError(self, handler): - expected_translation = (b'NoSuchFile', b'xyzzy') + expected_translation = (b"NoSuchFile", b"xyzzy") self.assertEqual( - request.FailedSmartServerResponse(expected_translation), - handler.response) + request.FailedSmartServerResponse(expected_translation), handler.response + ) def test_error_translation_from_args_received(self): - handler = request.SmartServerRequestHandler( - None, {b'foo': DoErrorRequest}, '/') - handler.args_received((b'foo',)) + handler = request.SmartServerRequestHandler(None, {b"foo": DoErrorRequest}, "/") + handler.args_received((b"foo",)) self.assertResponseIsTranslatedError(handler) def test_error_translation_from_chunk_received(self): handler = request.SmartServerRequestHandler( - None, {b'foo': ChunkErrorRequest}, '/') - handler.args_received((b'foo',)) + None, {b"foo": ChunkErrorRequest}, "/" + ) + handler.args_received((b"foo",)) self.assertNoResponse(handler) - handler.accept_body(b'bytes') + handler.accept_body(b"bytes") self.assertResponseIsTranslatedError(handler) def test_error_translation_from_end_received(self): handler = request.SmartServerRequestHandler( - None, {b'foo': EndErrorRequest}, '/') - handler.args_received((b'foo',)) + None, {b"foo": EndErrorRequest}, "/" + ) + handler.args_received((b"foo",)) self.assertNoResponse(handler) handler.end_received() self.assertResponseIsTranslatedError(handler) def test_unexpected_error_translation(self): handler = request.SmartServerRequestHandler( - None, {b'foo': DoUnexpectedErrorRequest}, '/') - handler.args_received((b'foo',)) + None, {b"foo": DoUnexpectedErrorRequest}, "/" + ) + handler.args_received((b"foo",)) self.assertEqual( - request.FailedSmartServerResponse((b'error', b'KeyError', b"1")), - handler.response) + request.FailedSmartServerResponse((b"error", b"KeyError", b"1")), + handler.response, + ) class TestRequestHanderErrorTranslation(TestCase): @@ -189,47 +191,52 @@ def assertTranslationEqual(self, expected_tuple, error): def test_NoSuchFile(self): 
self.assertTranslationEqual( - (b'NoSuchFile', b'path'), transport.NoSuchFile('path')) + (b"NoSuchFile", b"path"), transport.NoSuchFile("path") + ) def test_LockContention(self): # For now, LockContentions are always transmitted with no details. # Eventually they should include a relpath or url or something else to # identify which lock is busy. self.assertTranslationEqual( - (b'LockContention',), errors.LockContention('lock', 'msg')) + (b"LockContention",), errors.LockContention("lock", "msg") + ) def test_TokenMismatch(self): self.assertTranslationEqual( - (b'TokenMismatch', b'some-token', b'actual-token'), - errors.TokenMismatch(b'some-token', b'actual-token')) + (b"TokenMismatch", b"some-token", b"actual-token"), + errors.TokenMismatch(b"some-token", b"actual-token"), + ) def test_MemoryError(self): self.assertTranslationEqual((b"MemoryError",), MemoryError()) def test_GhostRevisionsHaveNoRevno(self): self.assertTranslationEqual( - (b"GhostRevisionsHaveNoRevno", b'revid1', b'revid2'), - errors.GhostRevisionsHaveNoRevno(b'revid1', b'revid2')) + (b"GhostRevisionsHaveNoRevno", b"revid1", b"revid2"), + errors.GhostRevisionsHaveNoRevno(b"revid1", b"revid2"), + ) def test_generic_Exception(self): - self.assertTranslationEqual((b'error', b'Exception', b""), - Exception()) + self.assertTranslationEqual((b"error", b"Exception", b""), Exception()) def test_generic_BzrError(self): - self.assertTranslationEqual((b'error', b'BzrError', b"some text"), - errors.BzrError(msg="some text")) + self.assertTranslationEqual( + (b"error", b"BzrError", b"some text"), errors.BzrError(msg="some text") + ) def test_generic_zlib_error(self): from zlib import error + msg = "Error -3 while decompressing data: incorrect data check" - self.assertTranslationEqual((b'error', b'zlib.error', msg.encode('utf-8')), - error(msg)) + self.assertTranslationEqual( + (b"error", b"zlib.error", msg.encode("utf-8")), error(msg) + ) class TestRequestJail(TestCaseWithMemoryTransport): - def test_jail(self): - transport = self.get_transport('blah') + transport = self.get_transport("blah") req = request.SmartServerRequest(transport) self.assertEqual(None, request.jail_info.transports) req.setup_jail() @@ -239,30 +246,33 @@ def test_jail(self): class TestJailHook(TestCaseWithMemoryTransport): - def setUp(self): super().setUp() def clear_jail_info(): request.jail_info.transports = None + self.addCleanup(clear_jail_info) def test_jail_hook(self): request.jail_info.transports = None _pre_open_hook = request._pre_open_hook # Any transport is fine if jail_info.transports is None - t = self.get_transport('foo') + t = self.get_transport("foo") _pre_open_hook(t) # A transport in jail_info.transports is allowed request.jail_info.transports = [t] _pre_open_hook(t) # A child of a transport in jail_info is allowed - _pre_open_hook(t.clone('child')) + _pre_open_hook(t.clone("child")) # A parent is not allowed - self.assertRaises(errors.JailBreak, _pre_open_hook, t.clone('..')) + self.assertRaises(errors.JailBreak, _pre_open_hook, t.clone("..")) # A completely unrelated transport is not allowed - self.assertRaises(errors.JailBreak, _pre_open_hook, - transport.get_transport_from_url('http://host/')) + self.assertRaises( + errors.JailBreak, + _pre_open_hook, + transport.get_transport_from_url("http://host/"), + ) def test_open_bzrdir_in_non_main_thread(self): """Opening a bzrdir in a non-main thread should work ok. 
@@ -271,14 +281,15 @@ def test_open_bzrdir_in_non_main_thread(self): breezy.bzr.smart.request._pre_open_hook, which uses a threading.local(), works in a newly created thread. """ - bzrdir = self.make_controldir('.') + bzrdir = self.make_controldir(".") transport = bzrdir.root_transport thread_result = [] def t(): BzrDir.open_from_transport(transport) - thread_result.append('ok') + thread_result.append("ok") + thread = threading.Thread(target=t) thread.start() thread.join() - self.assertEqual(['ok'], thread_result) + self.assertEqual(["ok"], thread_result) diff --git a/breezy/bzr/tests/test_smart_signals.py b/breezy/bzr/tests/test_smart_signals.py index 2e51767174..8d9f8e9dc1 100644 --- a/breezy/bzr/tests/test_smart_signals.py +++ b/breezy/bzr/tests/test_smart_signals.py @@ -26,11 +26,10 @@ # Windows doesn't define SIGHUP. And while we could just skip a lot of these # tests, we often don't actually care about interaction with 'signal', so we # can still run the tests for code coverage. -SIGHUP = getattr(signal, 'SIGHUP', 1) +SIGHUP = getattr(signal, "SIGHUP", 1) class TestSignalHandlers(tests.TestCase): - def setUp(self): super().setUp() # This allows us to mutate the signal handler callbacks, but leave it @@ -44,44 +43,46 @@ def setUp(self): def cleanup(): signals._on_sighup = None + self.addCleanup(cleanup) def test_registered_callback_gets_called(self): calls = [] def call_me(): - calls.append('called') - signals.register_on_hangup('myid', call_me) + calls.append("called") + + signals.register_on_hangup("myid", call_me) signals._sighup_handler(SIGHUP, None) - self.assertEqual(['called'], calls) - signals.unregister_on_hangup('myid') + self.assertEqual(["called"], calls) + signals.unregister_on_hangup("myid") def test_unregister_not_present(self): # We don't want unregister to fail, since it is generally run at times # that shouldn't interrupt other flow. 
- signals.unregister_on_hangup('no-such-id') + signals.unregister_on_hangup("no-such-id") log = self.get_log() - self.assertContainsRe( - log, 'Error occurred during unregister_on_hangup:') - self.assertContainsRe(log, '(?s)Traceback.*KeyError') + self.assertContainsRe(log, "Error occurred during unregister_on_hangup:") + self.assertContainsRe(log, "(?s)Traceback.*KeyError") def test_failing_callback(self): calls = [] def call_me(): - calls.append('called') + calls.append("called") def fail_me(): - raise RuntimeError('something bad happened') - signals.register_on_hangup('myid', call_me) - signals.register_on_hangup('otherid', fail_me) + raise RuntimeError("something bad happened") + + signals.register_on_hangup("myid", call_me) + signals.register_on_hangup("otherid", fail_me) # _sighup_handler should call both, even though it got an exception signals._sighup_handler(SIGHUP, None) - signals.unregister_on_hangup('myid') - signals.unregister_on_hangup('otherid') + signals.unregister_on_hangup("myid") + signals.unregister_on_hangup("otherid") log = self.get_log() - self.assertContainsRe(log, '(?s)Traceback.*RuntimeError') - self.assertEqual(['called'], calls) + self.assertContainsRe(log, "(?s)Traceback.*RuntimeError") + self.assertEqual(["called"], calls) def test_unregister_during_call(self): # _sighup_handler should handle if some callbacks actually remove @@ -89,13 +90,14 @@ def test_unregister_during_call(self): calls = [] def call_me_and_unregister(): - signals.unregister_on_hangup('myid') - calls.append('called_and_unregistered') + signals.unregister_on_hangup("myid") + calls.append("called_and_unregistered") def call_me(): - calls.append('called') - signals.register_on_hangup('myid', call_me_and_unregister) - signals.register_on_hangup('other', call_me) + calls.append("called") + + signals.register_on_hangup("myid", call_me_and_unregister) + signals.register_on_hangup("other", call_me) signals._sighup_handler(SIGHUP, None) def test_keyboard_interrupt_propagated(self): @@ -103,10 +105,10 @@ def test_keyboard_interrupt_propagated(self): # not suppress KeyboardInterrupt def call_me_and_raise(): raise KeyboardInterrupt() - signals.register_on_hangup('myid', call_me_and_raise) - self.assertRaises(KeyboardInterrupt, - signals._sighup_handler, SIGHUP, None) - signals.unregister_on_hangup('myid') + + signals.register_on_hangup("myid", call_me_and_raise) + self.assertRaises(KeyboardInterrupt, signals._sighup_handler, SIGHUP, None) + signals.unregister_on_hangup("myid") def test_weak_references(self): # TODO: This is probably a very-CPython-specific test @@ -114,13 +116,13 @@ def test_weak_references(self): # We overrideAttr during the test suite, so that we don't pollute the # original dict. However, we can test that what we override matches # what we are putting there. 
- self.assertIsInstance(signals._on_sighup, - weakref.WeakValueDictionary) + self.assertIsInstance(signals._on_sighup, weakref.WeakValueDictionary) calls = [] def call_me(): - calls.append('called') - signals.register_on_hangup('myid', call_me) + calls.append("called") + + signals.register_on_hangup("myid", call_me) del call_me # Non-CPython might want to do a gc.collect() here signals._sighup_handler(SIGHUP, None) @@ -133,19 +135,20 @@ def test_not_installed(self): calls = [] def call_me(): - calls.append('called') - signals.register_on_hangup('myid', calls) + calls.append("called") + + signals.register_on_hangup("myid", calls) signals._sighup_handler(SIGHUP, None) - signals.unregister_on_hangup('myid') + signals.unregister_on_hangup("myid") log = self.get_log() - self.assertEqual('', log) + self.assertEqual("", log) def test_install_sighup_handler(self): # install_sighup_handler should set up a signal handler for SIGHUP, as # well as the signals._on_sighup dict. signals._on_sighup = None orig = signals.install_sighup_handler() - if getattr(signal, 'SIGHUP', None) is not None: + if getattr(signal, "SIGHUP", None) is not None: cur = signal.getsignal(SIGHUP) self.assertEqual(signals._sighup_handler, cur) self.assertIsNot(None, signals._on_sighup) @@ -154,17 +157,16 @@ def test_install_sighup_handler(self): class TestInetServer(tests.TestCase): - def create_file_pipes(self): r, w = os.pipe() - rf = os.fdopen(r, 'rb') - wf = os.fdopen(w, 'wb') + rf = os.fdopen(r, "rb") + wf = os.fdopen(w, "wb") return rf, wf def test_inet_server_responds_to_sighup(self): - t = transport.get_transport('memory:///') - content = b'a' * 1024 * 1024 - t.put_bytes('bigfile', content) + t = transport.get_transport("memory:///") + content = b"a" * 1024 * 1024 + t.put_bytes("bigfile", content) factory = server.BzrServerFactory() # Override stdin/stdout so that we can inject our own handles client_read, server_write = self.create_file_pipes() @@ -179,14 +181,15 @@ def serving(): started.set() factory.smart_server.serve() stopped.set() + server_thread = threading.Thread(target=serving) server_thread.start() started.wait() - client_medium = medium.SmartSimplePipesClientMedium(client_read, - client_write, 'base') + client_medium = medium.SmartSimplePipesClientMedium( + client_read, client_write, "base" + ) client_client = client._SmartClient(client_medium) - resp, response_handler = client_client.call_expecting_body(b'get', - b'bigfile') + resp, response_handler = client_client.call_expecting_body(b"get", b"bigfile") signals._sighup_handler(SIGHUP, None) self.assertTrue(factory.smart_server.finished) # We can still finish reading the file content, but more than that, and diff --git a/breezy/bzr/tests/test_smart_transport.py b/breezy/bzr/tests/test_smart_transport.py index c852dfc8a0..baf54c6c19 100644 --- a/breezy/bzr/tests/test_smart_transport.py +++ b/breezy/bzr/tests/test_smart_transport.py @@ -48,8 +48,8 @@ def create_file_pipes(): r, w = os.pipe() # These must be opened without buffering, or we get undefined results - rf = os.fdopen(r, 'rb', 0) - wf = os.fdopen(w, 'wb', 0) + rf = os.fdopen(r, "rb", 0) + wf = os.fdopen(w, "wb", 0) return rf, wf @@ -59,7 +59,7 @@ def portable_socket_pair(): Unlike socket.socketpair, this should work on Windows. 
""" listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - listen_sock.bind(('127.0.0.1', 0)) + listen_sock.bind(("127.0.0.1", 0)) listen_sock.listen(1) client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) client_sock.connect(listen_sock.getsockname()) @@ -77,8 +77,7 @@ def __init__(self, read_from, write_to): self.calls = [] def connect_ssh(self, username, password, host, port, command): - self.calls.append(('connect_ssh', username, password, host, port, - command)) + self.calls.append(("connect_ssh", username, password, host, port, command)) return BytesIOSSHConnection(self) @@ -89,14 +88,12 @@ class FirstRejectedBytesIOSSHVendor(BytesIOSSHVendor): """ def __init__(self, read_from, write_to, fail_at_write=True): - super().__init__(read_from, - write_to) + super().__init__(read_from, write_to) self.fail_at_write = fail_at_write self._first = True def connect_ssh(self, username, password, host, port, command): - self.calls.append(('connect_ssh', username, password, host, port, - command)) + self.calls.append(("connect_ssh", username, password, host, port, command)) if self._first: self._first = False return ClosedSSHConnection(self) @@ -110,12 +107,12 @@ def __init__(self, vendor): self.vendor = vendor def close(self): - self.vendor.calls.append(('close', )) + self.vendor.calls.append(("close",)) self.vendor.read_from.close() self.vendor.write_to.close() def get_sock_or_pipes(self): - return 'pipes', (self.vendor.read_from, self.vendor.write_to) + return "pipes", (self.vendor.read_from, self.vendor.write_to) class ClosedSSHConnection(ssh.SSHConnection): @@ -125,7 +122,7 @@ def __init__(self, vendor): self.vendor = vendor def close(self): - self.vendor.calls.append(('close', )) + self.vendor.calls.append(("close",)) def get_sock_or_pipes(self): # We create matching pipes, and then close the ssh side @@ -138,7 +135,7 @@ def get_sock_or_pipes(self): ssh_read.close() else: bzr_write = self.vendor.write_to - return 'pipes', (bzr_read, bzr_write) + return "pipes", (bzr_read, bzr_write) class _InvalidHostnameFeature(features.Feature): @@ -153,7 +150,7 @@ class _InvalidHostnameFeature(features.Feature): def _probe(self): try: - socket.gethostbyname('non_existent.invalid') + socket.gethostbyname("non_existent.invalid") except socket.gaierror: # The host name failed to resolve. Good. return True @@ -161,7 +158,7 @@ def _probe(self): return False def feature_name(self): - return 'invalid hostname' + return "invalid hostname" InvalidHostnameFeature = _InvalidHostnameFeature() @@ -178,10 +175,10 @@ class SmartClientMediumTests(tests.TestCase): def make_loopsocket_and_medium(self): """Create a loopback socket for testing, and a medium aimed at it.""" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.bind(('127.0.0.1', 0)) + sock.bind(("127.0.0.1", 0)) sock.listen(1) port = sock.getsockname()[1] - client_medium = medium.SmartTCPClientMedium('127.0.0.1', port, 'base') + client_medium = medium.SmartTCPClientMedium("127.0.0.1", port, "base") return sock, client_medium def receive_bytes_on_server(self, sock, bytes): @@ -191,10 +188,12 @@ def receive_bytes_on_server(self, sock, bytes): :return: a Thread which is running to do the accept and recv. 
""" + def _receive_bytes_on_server(): connection, address = sock.accept() bytes.append(osutils.recv_all(connection, 3)) connection.close() + t = threading.Thread(target=_receive_bytes_on_server) t.start() return t @@ -220,8 +219,7 @@ def test_simple_pipes_client_get_concurrent_requests(self): # classes - as the sibling classes share this logic, they do not have # explicit tests for this. output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - None, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(None, output, "base") request = client_medium.get_request() request.finished_writing() request.finished_reading() @@ -232,66 +230,73 @@ def test_simple_pipes_client_get_concurrent_requests(self): def test_simple_pipes_client__accept_bytes_writes_to_writable(self): # accept_bytes writes to the writeable pipe. output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - None, output, 'base') - client_medium._accept_bytes(b'abc') - self.assertEqual(b'abc', output.getvalue()) + client_medium = medium.SmartSimplePipesClientMedium(None, output, "base") + client_medium._accept_bytes(b"abc") + self.assertEqual(b"abc", output.getvalue()) def test_simple_pipes__accept_bytes_subprocess_closed(self): # It is unfortunate that we have to use Popen for this. However, # os.pipe() does not behave the same as subprocess.Popen(). # On Windows, if you use os.pipe() and close the write side, # read.read() hangs. On Linux, read.read() returns the empty string. - p = subprocess.Popen([sys.executable, '-c', - 'import sys\n' - 'sys.stdout.write(sys.stdin.read(4))\n' - 'sys.stdout.close()\n'], - stdout=subprocess.PIPE, stdin=subprocess.PIPE, bufsize=0) - client_medium = medium.SmartSimplePipesClientMedium( - p.stdout, p.stdin, 'base') - client_medium._accept_bytes(b'abc\n') - self.assertEqual(b'abc', client_medium._read_bytes(3)) + p = subprocess.Popen( + [ + sys.executable, + "-c", + "import sys\n" + "sys.stdout.write(sys.stdin.read(4))\n" + "sys.stdout.close()\n", + ], + stdout=subprocess.PIPE, + stdin=subprocess.PIPE, + bufsize=0, + ) + client_medium = medium.SmartSimplePipesClientMedium(p.stdout, p.stdin, "base") + client_medium._accept_bytes(b"abc\n") + self.assertEqual(b"abc", client_medium._read_bytes(3)) p.wait() # While writing to the underlying pipe, # Windows py2.6.6 we get IOError(EINVAL) # Lucid py2.6.5, we get IOError(EPIPE) # In both cases, it should be wrapped to ConnectionReset - self.assertRaises(ConnectionResetError, - client_medium._accept_bytes, b'more') + self.assertRaises(ConnectionResetError, client_medium._accept_bytes, b"more") def test_simple_pipes__accept_bytes_pipe_closed(self): child_read, client_write = create_file_pipes() - client_medium = medium.SmartSimplePipesClientMedium( - None, client_write, 'base') - client_medium._accept_bytes(b'abc\n') - self.assertEqual(b'abc\n', child_read.read(4)) + client_medium = medium.SmartSimplePipesClientMedium(None, client_write, "base") + client_medium._accept_bytes(b"abc\n") + self.assertEqual(b"abc\n", child_read.read(4)) # While writing to the underlying pipe, # Windows py2.6.6 we get IOError(EINVAL) # Lucid py2.6.5, we get IOError(EPIPE) # In both cases, it should be wrapped to ConnectionReset child_read.close() - self.assertRaises(ConnectionResetError, - client_medium._accept_bytes, b'more') + self.assertRaises(ConnectionResetError, client_medium._accept_bytes, b"more") def test_simple_pipes__flush_pipe_closed(self): child_read, client_write = create_file_pipes() - client_medium = 
medium.SmartSimplePipesClientMedium( - None, client_write, 'base') - client_medium._accept_bytes(b'abc\n') + client_medium = medium.SmartSimplePipesClientMedium(None, client_write, "base") + client_medium._accept_bytes(b"abc\n") child_read.close() # Even though the pipe is closed, flush on the write side seems to be a # no-op, rather than a failure. client_medium._flush() def test_simple_pipes__flush_subprocess_closed(self): - p = subprocess.Popen([sys.executable, '-c', - 'import sys\n' - 'sys.stdout.write(sys.stdin.read(4))\n' - 'sys.stdout.close()\n'], - stdout=subprocess.PIPE, stdin=subprocess.PIPE, bufsize=0) - client_medium = medium.SmartSimplePipesClientMedium( - p.stdout, p.stdin, 'base') - client_medium._accept_bytes(b'abc\n') + p = subprocess.Popen( + [ + sys.executable, + "-c", + "import sys\n" + "sys.stdout.write(sys.stdin.read(4))\n" + "sys.stdout.close()\n", + ], + stdout=subprocess.PIPE, + stdin=subprocess.PIPE, + bufsize=0, + ) + client_medium = medium.SmartSimplePipesClientMedium(p.stdout, p.stdin, "base") + client_medium._accept_bytes(b"abc\n") p.wait() # Even though the child process is dead, flush seems to be a no-op. client_medium._flush() @@ -299,38 +304,44 @@ def test_simple_pipes__flush_subprocess_closed(self): def test_simple_pipes__read_bytes_pipe_closed(self): child_read, client_write = create_file_pipes() client_medium = medium.SmartSimplePipesClientMedium( - child_read, client_write, 'base') - client_medium._accept_bytes(b'abc\n') + child_read, client_write, "base" + ) + client_medium._accept_bytes(b"abc\n") client_write.close() - self.assertEqual(b'abc\n', client_medium._read_bytes(4)) - self.assertEqual(b'', client_medium._read_bytes(4)) + self.assertEqual(b"abc\n", client_medium._read_bytes(4)) + self.assertEqual(b"", client_medium._read_bytes(4)) def test_simple_pipes__read_bytes_subprocess_closed(self): - p = subprocess.Popen([sys.executable, '-c', - 'import sys\n' - 'if sys.platform == "win32":\n' - ' import msvcrt, os\n' - ' msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)\n' - ' msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)\n' - 'sys.stdout.write(sys.stdin.read(4))\n' - 'sys.stdout.close()\n'], - stdout=subprocess.PIPE, stdin=subprocess.PIPE, bufsize=0) - client_medium = medium.SmartSimplePipesClientMedium( - p.stdout, p.stdin, 'base') - client_medium._accept_bytes(b'abc\n') + p = subprocess.Popen( + [ + sys.executable, + "-c", + "import sys\n" + 'if sys.platform == "win32":\n' + " import msvcrt, os\n" + " msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)\n" + " msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)\n" + "sys.stdout.write(sys.stdin.read(4))\n" + "sys.stdout.close()\n", + ], + stdout=subprocess.PIPE, + stdin=subprocess.PIPE, + bufsize=0, + ) + client_medium = medium.SmartSimplePipesClientMedium(p.stdout, p.stdin, "base") + client_medium._accept_bytes(b"abc\n") p.wait() - self.assertEqual(b'abc\n', client_medium._read_bytes(4)) - self.assertEqual(b'', client_medium._read_bytes(4)) + self.assertEqual(b"abc\n", client_medium._read_bytes(4)) + self.assertEqual(b"", client_medium._read_bytes(4)) def test_simple_pipes_client_disconnect_does_nothing(self): # calling disconnect does nothing. input = BytesIO() output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") # send some bytes to ensure disconnecting after activity still does not # close. 
- client_medium._accept_bytes(b'abc') + client_medium._accept_bytes(b"abc") client_medium.disconnect() self.assertFalse(input.closed) self.assertFalse(output.closed) @@ -340,50 +351,50 @@ def test_simple_pipes_client_accept_bytes_after_disconnect(self): # accept_bytes writes to. input = BytesIO() output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') - client_medium._accept_bytes(b'abc') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") + client_medium._accept_bytes(b"abc") client_medium.disconnect() - client_medium._accept_bytes(b'abc') + client_medium._accept_bytes(b"abc") self.assertFalse(input.closed) self.assertFalse(output.closed) - self.assertEqual(b'abcabc', output.getvalue()) + self.assertEqual(b"abcabc", output.getvalue()) def test_simple_pipes_client_ignores_disconnect_when_not_connected(self): # Doing a disconnect on a new (and thus unconnected) SimplePipes medium # does nothing. - client_medium = medium.SmartSimplePipesClientMedium(None, None, 'base') + client_medium = medium.SmartSimplePipesClientMedium(None, None, "base") client_medium.disconnect() def test_simple_pipes_client_can_always_read(self): # SmartSimplePipesClientMedium is never disconnected, so read_bytes # always tries to read from the underlying pipe. - input = BytesIO(b'abcdef') - client_medium = medium.SmartSimplePipesClientMedium( - input, None, 'base') - self.assertEqual(b'abc', client_medium.read_bytes(3)) + input = BytesIO(b"abcdef") + client_medium = medium.SmartSimplePipesClientMedium(input, None, "base") + self.assertEqual(b"abc", client_medium.read_bytes(3)) client_medium.disconnect() - self.assertEqual(b'def', client_medium.read_bytes(3)) + self.assertEqual(b"def", client_medium.read_bytes(3)) def test_simple_pipes_client_supports__flush(self): # invoking _flush on a SimplePipesClient should flush the output # pipe. We test this by creating an output pipe that records # flush calls made to it. from io import BytesIO # get regular BytesIO + input = BytesIO() output = BytesIO() flush_calls = [] - def logging_flush(): flush_calls.append('flush') + def logging_flush(): + flush_calls.append("flush") + output.flush = logging_flush - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") # this call is here to ensure we only flush once, not on every # _accept_bytes call. - client_medium._accept_bytes(b'abc') + client_medium._accept_bytes(b"abc") client_medium._flush() client_medium.disconnect() - self.assertEqual(['flush'], flush_calls) + self.assertEqual(["flush"], flush_calls) def test_construct_smart_ssh_client_medium(self): # the SSH client medium takes: @@ -392,13 +403,12 @@ def test_construct_smart_ssh_client_medium(self): # we test this by creating a empty bound socket and constructing # a medium. sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.bind(('127.0.0.1', 0)) + sock.bind(("127.0.0.1", 0)) unopened_port = sock.getsockname()[1] # having vendor be invalid means that if it tries to connect via the # vendor it will blow up. 
- ssh_params = medium.SSHParams('127.0.0.1', unopened_port, None, None) - medium.SmartSSHClientMedium( - 'base', ssh_params, "not a vendor") + ssh_params = medium.SSHParams("127.0.0.1", unopened_port, None, None) + medium.SmartSSHClientMedium("base", ssh_params, "not a vendor") sock.close() def test_ssh_client_connects_on_first_use(self): @@ -407,14 +417,24 @@ def test_ssh_client_connects_on_first_use(self): output = BytesIO() vendor = BytesIOSSHVendor(BytesIO(), output) ssh_params = medium.SSHParams( - 'a hostname', 'a port', 'a username', 'a password', 'bzr') - client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor) - client_medium._accept_bytes(b'abc') - self.assertEqual(b'abc', output.getvalue()) - self.assertEqual([('connect_ssh', 'a username', 'a password', - 'a hostname', 'a port', - ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes'])], - vendor.calls) + "a hostname", "a port", "a username", "a password", "bzr" + ) + client_medium = medium.SmartSSHClientMedium("base", ssh_params, vendor) + client_medium._accept_bytes(b"abc") + self.assertEqual(b"abc", output.getvalue()) + self.assertEqual( + [ + ( + "connect_ssh", + "a username", + "a password", + "a hostname", + "a port", + ["bzr", "serve", "--inet", "--directory=/", "--allow-writes"], + ) + ], + vendor.calls, + ) def test_ssh_client_changes_command_when_bzr_remote_path_passed(self): # The only thing that initiates a connection from the medium is giving @@ -422,15 +442,24 @@ def test_ssh_client_changes_command_when_bzr_remote_path_passed(self): output = BytesIO() vendor = BytesIOSSHVendor(BytesIO(), output) ssh_params = medium.SSHParams( - 'a hostname', 'a port', 'a username', 'a password', - bzr_remote_path='fugly') - client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor) - client_medium._accept_bytes(b'abc') - self.assertEqual(b'abc', output.getvalue()) - self.assertEqual([('connect_ssh', 'a username', 'a password', - 'a hostname', 'a port', - ['fugly', 'serve', '--inet', '--directory=/', '--allow-writes'])], - vendor.calls) + "a hostname", "a port", "a username", "a password", bzr_remote_path="fugly" + ) + client_medium = medium.SmartSSHClientMedium("base", ssh_params, vendor) + client_medium._accept_bytes(b"abc") + self.assertEqual(b"abc", output.getvalue()) + self.assertEqual( + [ + ( + "connect_ssh", + "a username", + "a password", + "a hostname", + "a port", + ["fugly", "serve", "--inet", "--directory=/", "--allow-writes"], + ) + ], + vendor.calls, + ) def test_ssh_client_disconnect_does_so(self): # calling disconnect should disconnect both the read_from and write_to @@ -439,17 +468,26 @@ def test_ssh_client_disconnect_does_so(self): output = BytesIO() vendor = BytesIOSSHVendor(input, output) client_medium = medium.SmartSSHClientMedium( - 'base', medium.SSHParams('a hostname'), vendor) - client_medium._accept_bytes(b'abc') + "base", medium.SSHParams("a hostname"), vendor + ) + client_medium._accept_bytes(b"abc") client_medium.disconnect() self.assertTrue(input.closed) self.assertTrue(output.closed) - self.assertEqual([ - ('connect_ssh', None, None, 'a hostname', None, - ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']), - ('close', ), + self.assertEqual( + [ + ( + "connect_ssh", + None, + None, + "a hostname", + None, + ["bzr", "serve", "--inet", "--directory=/", "--allow-writes"], + ), + ("close",), ], - vendor.calls) + vendor.calls, + ) def test_ssh_client_disconnect_allows_reconnection(self): # calling disconnect on the client terminates the connection, but should 
@@ -460,8 +498,9 @@ def test_ssh_client_disconnect_allows_reconnection(self): output = BytesIO() vendor = BytesIOSSHVendor(input, output) client_medium = medium.SmartSSHClientMedium( - 'base', medium.SSHParams('a hostname'), vendor) - client_medium._accept_bytes(b'abc') + "base", medium.SSHParams("a hostname"), vendor + ) + client_medium._accept_bytes(b"abc") client_medium.disconnect() # the disconnect has closed output, so we need a new output for the # new connection to write to. @@ -469,89 +508,106 @@ def test_ssh_client_disconnect_allows_reconnection(self): output2 = BytesIO() vendor.read_from = input2 vendor.write_to = output2 - client_medium._accept_bytes(b'abc') + client_medium._accept_bytes(b"abc") client_medium.disconnect() self.assertTrue(input.closed) self.assertTrue(output.closed) self.assertTrue(input2.closed) self.assertTrue(output2.closed) - self.assertEqual([ - ('connect_ssh', None, None, 'a hostname', None, - ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']), - ('close', ), - ('connect_ssh', None, None, 'a hostname', None, - ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']), - ('close', ), + self.assertEqual( + [ + ( + "connect_ssh", + None, + None, + "a hostname", + None, + ["bzr", "serve", "--inet", "--directory=/", "--allow-writes"], + ), + ("close",), + ( + "connect_ssh", + None, + None, + "a hostname", + None, + ["bzr", "serve", "--inet", "--directory=/", "--allow-writes"], + ), + ("close",), ], - vendor.calls) + vendor.calls, + ) def test_ssh_client_repr(self): client_medium = medium.SmartSSHClientMedium( - 'base', medium.SSHParams("example.com", "4242", "username")) + "base", medium.SSHParams("example.com", "4242", "username") + ) self.assertEqual( "SmartSSHClientMedium(bzr+ssh://username@example.com:4242/)", - repr(client_medium)) + repr(client_medium), + ) def test_ssh_client_repr_no_port(self): client_medium = medium.SmartSSHClientMedium( - 'base', medium.SSHParams("example.com", None, "username")) + "base", medium.SSHParams("example.com", None, "username") + ) self.assertEqual( - "SmartSSHClientMedium(bzr+ssh://username@example.com/)", - repr(client_medium)) + "SmartSSHClientMedium(bzr+ssh://username@example.com/)", repr(client_medium) + ) def test_ssh_client_repr_no_username(self): client_medium = medium.SmartSSHClientMedium( - 'base', medium.SSHParams("example.com", None, None)) + "base", medium.SSHParams("example.com", None, None) + ) self.assertEqual( - "SmartSSHClientMedium(bzr+ssh://example.com/)", - repr(client_medium)) + "SmartSSHClientMedium(bzr+ssh://example.com/)", repr(client_medium) + ) def test_ssh_client_ignores_disconnect_when_not_connected(self): # Doing a disconnect on a new (and thus unconnected) SSH medium # does not fail. It's ok to disconnect an unconnected medium. - client_medium = medium.SmartSSHClientMedium( - 'base', medium.SSHParams(None)) + client_medium = medium.SmartSSHClientMedium("base", medium.SSHParams(None)) client_medium.disconnect() def test_ssh_client_raises_on_read_when_not_connected(self): # Doing a read on a new (and thus unconnected) SSH medium raises # MediumNotConnected. 
- client_medium = medium.SmartSSHClientMedium( - 'base', medium.SSHParams(None)) - self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, - 0) - self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, - 1) + client_medium = medium.SmartSSHClientMedium("base", medium.SSHParams(None)) + self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 0) + self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 1) def test_ssh_client_supports__flush(self): # invoking _flush on a SSHClientMedium should flush the output # pipe. We test this by creating an output pipe that records # flush calls made to it. from io import BytesIO # get regular BytesIO + input = BytesIO() output = BytesIO() flush_calls = [] - def logging_flush(): flush_calls.append('flush') + def logging_flush(): + flush_calls.append("flush") + output.flush = logging_flush vendor = BytesIOSSHVendor(input, output) client_medium = medium.SmartSSHClientMedium( - 'base', medium.SSHParams('a hostname'), vendor=vendor) + "base", medium.SSHParams("a hostname"), vendor=vendor + ) # this call is here to ensure we only flush once, not on every # _accept_bytes call. - client_medium._accept_bytes(b'abc') + client_medium._accept_bytes(b"abc") client_medium._flush() client_medium.disconnect() - self.assertEqual(['flush'], flush_calls) + self.assertEqual(["flush"], flush_calls) def test_construct_smart_tcp_client_medium(self): # the TCP client medium takes a host and a port. Constructing it won't # connect to anything. sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.bind(('127.0.0.1', 0)) + sock.bind(("127.0.0.1", 0)) unopened_port = sock.getsockname()[1] - medium.SmartTCPClientMedium( - '127.0.0.1', unopened_port, 'base') + medium.SmartTCPClientMedium("127.0.0.1", unopened_port, "base") sock.close() def test_tcp_client_connects_on_first_use(self): @@ -560,10 +616,10 @@ def test_tcp_client_connects_on_first_use(self): sock, medium = self.make_loopsocket_and_medium() bytes = [] t = self.receive_bytes_on_server(sock, bytes) - medium.accept_bytes(b'abc') + medium.accept_bytes(b"abc") t.join() sock.close() - self.assertEqual([b'abc'], bytes) + self.assertEqual([b"abc"], bytes) def test_tcp_client_disconnect_does_so(self): # calling disconnect on the client terminates the connection. @@ -572,11 +628,11 @@ def test_tcp_client_disconnect_does_so(self): sock, medium = self.make_loopsocket_and_medium() bytes = [] t = self.receive_bytes_on_server(sock, bytes) - medium.accept_bytes(b'ab') + medium.accept_bytes(b"ab") medium.disconnect() t.join() sock.close() - self.assertEqual([b'ab'], bytes) + self.assertEqual([b"ab"], bytes) # now disconnect again: this should not do anything, if disconnection # really did disconnect. medium.disconnect() @@ -591,10 +647,8 @@ def test_tcp_client_raises_on_read_when_not_connected(self): # Doing a read on a new (and thus unconnected) TCP medium raises # MediumNotConnected. client_medium = medium.SmartTCPClientMedium(None, None, None) - self.assertRaises(errors.MediumNotConnected, - client_medium.read_bytes, 0) - self.assertRaises(errors.MediumNotConnected, - client_medium.read_bytes, 1) + self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 0) + self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 1) def test_tcp_client_supports__flush(self): # invoking _flush on a TCPClientMedium should do something useful. 
@@ -604,13 +658,13 @@ def test_tcp_client_supports__flush(self): t = self.receive_bytes_on_server(sock, bytes) # try with nothing buffered medium._flush() - medium._accept_bytes(b'ab') + medium._accept_bytes(b"ab") # and with something sent. medium._flush() medium.disconnect() t.join() sock.close() - self.assertEqual([b'ab'], bytes) + self.assertEqual([b"ab"], bytes) # now disconnect again : this should not do anything, if disconnection # really did disconnect. medium.disconnect() @@ -618,9 +672,9 @@ def test_tcp_client_supports__flush(self): def test_tcp_client_host_unknown_connection_error(self): self.requireFeature(InvalidHostnameFeature) client_medium = medium.SmartTCPClientMedium( - 'non_existent.invalid', 4155, 'base') - self.assertRaises( - ConnectionError, client_medium._ensure_connection) + "non_existent.invalid", 4155, "base" + ) + self.assertRaises(ConnectionError, client_medium._ensure_connection) class TestSmartClientStreamMediumRequest(tests.TestCase): @@ -636,8 +690,7 @@ def test_accept_bytes_after_finished_writing_errors(self): # WritingCompleted to prevent bad assumptions on stream environments # breaking the needs of message-based environments. output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - None, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(None, output, "base") request = medium.SmartClientStreamMediumRequest(client_medium) request.finished_writing() self.assertRaises(errors.WritingCompleted, request.accept_bytes, None) @@ -648,21 +701,19 @@ def test_accept_bytes(self): # and checking that the pipes get the data. input = BytesIO() output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") request = medium.SmartClientStreamMediumRequest(client_medium) - request.accept_bytes(b'123') + request.accept_bytes(b"123") request.finished_writing() request.finished_reading() - self.assertEqual(b'', input.getvalue()) - self.assertEqual(b'123', output.getvalue()) + self.assertEqual(b"", input.getvalue()) + self.assertEqual(b"123", output.getvalue()) def test_construct_sets_stream_request(self): # constructing a SmartClientStreamMediumRequest on a StreamMedium sets # the current request to the new SmartClientStreamMediumRequest output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - None, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(None, output, "base") request = medium.SmartClientStreamMediumRequest(client_medium) self.assertIs(client_medium._current_request, request) @@ -670,18 +721,19 @@ def test_construct_while_another_request_active_throws(self): # constructing a SmartClientStreamMediumRequest on a StreamMedium with # a non-None _current_request raises TooManyConcurrentRequests. 
output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - None, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(None, output, "base") client_medium._current_request = "a" - self.assertRaises(medium.TooManyConcurrentRequests, - medium.SmartClientStreamMediumRequest, client_medium) + self.assertRaises( + medium.TooManyConcurrentRequests, + medium.SmartClientStreamMediumRequest, + client_medium, + ) def test_finished_read_clears_current_request(self): # calling finished_reading clears the current request from the requests # medium output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - None, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(None, output, "base") request = medium.SmartClientStreamMediumRequest(client_medium) request.finished_writing() request.finished_reading() @@ -690,8 +742,7 @@ def test_finished_read_clears_current_request(self): def test_finished_read_before_finished_write_errors(self): # calling finished_reading before calling finished_writing triggers a # WritingNotComplete error. - client_medium = medium.SmartSimplePipesClientMedium( - None, None, 'base') + client_medium = medium.SmartSimplePipesClientMedium(None, None, "base") request = medium.SmartClientStreamMediumRequest(client_medium) self.assertRaises(errors.WritingNotComplete, request.finished_reading) @@ -702,23 +753,22 @@ def test_read_bytes(self): # faulty implementation could poke at the pipe variables them selves, # but we trust that this will be caught as it will break the integration # smoke tests. - input = BytesIO(b'321') + input = BytesIO(b"321") output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") request = medium.SmartClientStreamMediumRequest(client_medium) request.finished_writing() - self.assertEqual(b'321', request.read_bytes(3)) + self.assertEqual(b"321", request.read_bytes(3)) request.finished_reading() - self.assertEqual(b'', input.read()) - self.assertEqual(b'', output.getvalue()) + self.assertEqual(b"", input.read()) + self.assertEqual(b"", output.getvalue()) def test_read_bytes_before_finished_write_errors(self): # calling read_bytes before calling finished_writing triggers a # WritingNotComplete error because the Smart protocol is designed to be # compatible with strict message based protocols like HTTP where the # request cannot be submitted until the writing has completed. - client_medium = medium.SmartSimplePipesClientMedium(None, None, 'base') + client_medium = medium.SmartSimplePipesClientMedium(None, None, "base") request = medium.SmartClientStreamMediumRequest(client_medium) self.assertRaises(errors.WritingNotComplete, request.read_bytes, None) @@ -727,8 +777,7 @@ def test_read_bytes_after_finished_reading_errors(self): # ReadingCompleted to prevent bad assumptions on stream environments # breaking the needs of message-based environments. 
output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - None, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(None, output, "base") request = medium.SmartClientStreamMediumRequest(client_medium) request.finished_writing() request.finished_reading() @@ -742,15 +791,14 @@ def test_reset(self): client_medium._socket = client_sock client_medium._connected = True client_medium.get_request() - self.assertRaises(medium.TooManyConcurrentRequests, - client_medium.get_request) + self.assertRaises(medium.TooManyConcurrentRequests, client_medium.get_request) client_medium.reset() # The stream should be reset, marked as disconnected, though ready for # us to make a new request self.assertFalse(client_medium._connected) self.assertIs(None, client_medium._socket) try: - self.assertEqual('', client_sock.recv(1)) + self.assertEqual("", client_sock.recv(1)) except OSError as e: if e.errno not in (errno.EBADF,): raise @@ -758,9 +806,8 @@ def test_reset(self): class RemoteTransportTests(test_smart.TestCaseWithSmartMedium): - def test_plausible_url(self): - self.assertTrue(self.get_url().startswith('bzr://')) + self.assertTrue(self.get_url().startswith("bzr://")) def test_probe_transport(self): t = self.get_transport() @@ -774,7 +821,6 @@ def test_get_medium_from_transport(self): class ErrorRaisingProtocol: - def __init__(self, exception): self.exception = exception @@ -783,18 +829,17 @@ def next_read_size(self): class SampleRequest: - def __init__(self, expected_bytes): - self.accepted_bytes = b'' + self.accepted_bytes = b"" self._finished_reading = False self.expected_bytes = expected_bytes - self.unused_data = b'' + self.unused_data = b"" def accept_bytes(self, bytes): self.accepted_bytes += bytes if self.accepted_bytes.startswith(self.expected_bytes): self._finished_reading = True - self.unused_data = self.accepted_bytes[len(self.expected_bytes):] + self.unused_data = self.accepted_bytes[len(self.expected_bytes) :] def next_read_size(self): if self._finished_reading: @@ -804,16 +849,15 @@ def next_read_size(self): class TestSmartServerStreamMedium(tests.TestCase): - def setUp(self): super().setUp() - self.overrideEnv('BRZ_NO_SMART_VFS', None) + self.overrideEnv("BRZ_NO_SMART_VFS", None) - def create_pipe_medium(self, to_server, from_server, transport, - timeout=4.0): + def create_pipe_medium(self, to_server, from_server, transport, timeout=4.0): """Create a new SmartServerPipeStreamMedium.""" - return medium.SmartServerPipeStreamMedium(to_server, from_server, - transport, timeout=timeout) + return medium.SmartServerPipeStreamMedium( + to_server, from_server, transport, timeout=timeout + ) def create_pipe_context(self, to_server_bytes, transport): """Create a SmartServerSocketStreamMedium. @@ -829,8 +873,9 @@ def create_pipe_context(self, to_server_bytes, transport): def create_socket_medium(self, server_sock, transport, timeout=4.0): """Initialize a new medium.SmartServerSocketStreamMedium.""" - return medium.SmartServerSocketStreamMedium(server_sock, transport, - timeout=timeout) + return medium.SmartServerSocketStreamMedium( + server_sock, transport, timeout=timeout + ) def create_socket_context(self, transport, timeout=4.0): """Create a new SmartServerSocketStreamMedium with default context. @@ -840,83 +885,81 @@ def create_socket_context(self, transport, timeout=4.0): It then returns the client_sock and the server. 
""" server_sock, client_sock = portable_socket_pair() - server = self.create_socket_medium(server_sock, transport, - timeout=timeout) + server = self.create_socket_medium(server_sock, transport, timeout=timeout) return server, client_sock def test_smart_query_version(self): """Feed a canned query version to a server.""" # wire-to-wire, using the whole stack - transport = local.LocalTransport(urlutils.local_path_to_url('/')) - server, from_server = self.create_pipe_context(b'hello\n', transport) - smart_protocol = protocol.SmartServerRequestProtocolOne(transport, - from_server.write) + transport = local.LocalTransport(urlutils.local_path_to_url("/")) + server, from_server = self.create_pipe_context(b"hello\n", transport) + smart_protocol = protocol.SmartServerRequestProtocolOne( + transport, from_server.write + ) server._serve_one_request(smart_protocol) - self.assertEqual(b'ok\0012\n', - from_server.getvalue()) + self.assertEqual(b"ok\0012\n", from_server.getvalue()) def test_response_to_canned_get(self): - transport = memory.MemoryTransport('memory:///') - transport.put_bytes('testfile', b'contents\nof\nfile\n') - server, from_server = self.create_pipe_context(b'get\001./testfile\n', - transport) - smart_protocol = protocol.SmartServerRequestProtocolOne(transport, - from_server.write) + transport = memory.MemoryTransport("memory:///") + transport.put_bytes("testfile", b"contents\nof\nfile\n") + server, from_server = self.create_pipe_context( + b"get\001./testfile\n", transport + ) + smart_protocol = protocol.SmartServerRequestProtocolOne( + transport, from_server.write + ) server._serve_one_request(smart_protocol) - self.assertEqual(b'ok\n' - b'17\n' - b'contents\nof\nfile\n' - b'done\n', - from_server.getvalue()) + self.assertEqual( + b"ok\n" b"17\n" b"contents\nof\nfile\n" b"done\n", from_server.getvalue() + ) def test_response_to_canned_get_of_utf8(self): # wire-to-wire, using the whole stack, with a UTF-8 filename. - transport = memory.MemoryTransport('memory:///') - utf8_filename = 'testfile\N{INTERROBANG}'.encode() + transport = memory.MemoryTransport("memory:///") + utf8_filename = "testfile\N{INTERROBANG}".encode() # VFS requests use filenames, not raw UTF-8. 
hpss_path = urlutils.quote_from_bytes(utf8_filename) - transport.put_bytes(hpss_path, b'contents\nof\nfile\n') + transport.put_bytes(hpss_path, b"contents\nof\nfile\n") server, from_server = self.create_pipe_context( - b'get\001' + hpss_path.encode('ascii') + b'\n', transport) - smart_protocol = protocol.SmartServerRequestProtocolOne(transport, - from_server.write) + b"get\001" + hpss_path.encode("ascii") + b"\n", transport + ) + smart_protocol = protocol.SmartServerRequestProtocolOne( + transport, from_server.write + ) server._serve_one_request(smart_protocol) - self.assertEqual(b'ok\n' - b'17\n' - b'contents\nof\nfile\n' - b'done\n', - from_server.getvalue()) + self.assertEqual( + b"ok\n" b"17\n" b"contents\nof\nfile\n" b"done\n", from_server.getvalue() + ) def test_pipe_like_stream_with_bulk_data(self): - sample_request_bytes = b'command\n9\nbulk datadone\n' - server, from_server = self.create_pipe_context( - sample_request_bytes, None) + sample_request_bytes = b"command\n9\nbulk datadone\n" + server, from_server = self.create_pipe_context(sample_request_bytes, None) sample_protocol = SampleRequest(expected_bytes=sample_request_bytes) server._serve_one_request(sample_protocol) - self.assertEqual(b'', from_server.getvalue()) + self.assertEqual(b"", from_server.getvalue()) self.assertEqual(sample_request_bytes, sample_protocol.accepted_bytes) self.assertFalse(server.finished) def test_socket_stream_with_bulk_data(self): - sample_request_bytes = b'command\n9\nbulk datadone\n' + sample_request_bytes = b"command\n9\nbulk datadone\n" server, client_sock = self.create_socket_context(None) sample_protocol = SampleRequest(expected_bytes=sample_request_bytes) client_sock.sendall(sample_request_bytes) server._serve_one_request(sample_protocol) server._disconnect_client() - self.assertEqual(b'', client_sock.recv(1)) + self.assertEqual(b"", client_sock.recv(1)) self.assertEqual(sample_request_bytes, sample_protocol.accepted_bytes) self.assertFalse(server.finished) def test_pipe_like_stream_shutdown_detection(self): - server, _ = self.create_pipe_context(b'', None) - server._serve_one_request(SampleRequest(b'x')) + server, _ = self.create_pipe_context(b"", None) + server._serve_one_request(SampleRequest(b"x")) self.assertTrue(server.finished) def test_socket_stream_shutdown_detection(self): server, client_sock = self.create_socket_context(None) client_sock.close() - server._serve_one_request(SampleRequest(b'x')) + server._serve_one_request(SampleRequest(b"x")) self.assertTrue(server.finished) def test_socket_stream_incomplete_request(self): @@ -928,19 +971,21 @@ def test_socket_stream_incomplete_request(self): implementation of _get_line in the server used to have a bug in that case. 
""" - incomplete_request_bytes = protocol.REQUEST_VERSION_TWO + b'hel' - rest_of_request_bytes = b'lo\n' - expected_response = ( - protocol.RESPONSE_VERSION_TWO + b'success\nok\x012\n') + incomplete_request_bytes = protocol.REQUEST_VERSION_TWO + b"hel" + rest_of_request_bytes = b"lo\n" + expected_response = protocol.RESPONSE_VERSION_TWO + b"success\nok\x012\n" server, client_sock = self.create_socket_context(None) client_sock.sendall(incomplete_request_bytes) server_protocol = server._build_protocol() client_sock.sendall(rest_of_request_bytes) server._serve_one_request(server_protocol) server._disconnect_client() - self.assertEqual(expected_response, osutils.recv_all(client_sock, 50), - "Not a version 2 response to 'hello' request.") - self.assertEqual(b'', client_sock.recv(1)) + self.assertEqual( + expected_response, + osutils.recv_all(client_sock, 50), + "Not a version 2 response to 'hello' request.", + ) + self.assertEqual(b"", client_sock.recv(1)) def test_pipe_stream_incomplete_request(self): """The medium should still construct the right protocol version even if @@ -951,17 +996,16 @@ def test_pipe_stream_incomplete_request(self): implementation of _get_line in the server used to have a bug in that case. """ - incomplete_request_bytes = protocol.REQUEST_VERSION_TWO + b'hel' - rest_of_request_bytes = b'lo\n' - expected_response = ( - protocol.RESPONSE_VERSION_TWO + b'success\nok\x012\n') + incomplete_request_bytes = protocol.REQUEST_VERSION_TWO + b"hel" + rest_of_request_bytes = b"lo\n" + expected_response = protocol.RESPONSE_VERSION_TWO + b"success\nok\x012\n" # Make a pair of pipes, to and from the server to_server, to_server_w = os.pipe() from_server_r, from_server = os.pipe() - to_server = os.fdopen(to_server, 'rb', 0) - to_server_w = os.fdopen(to_server_w, 'wb', 0) - from_server_r = os.fdopen(from_server_r, 'rb', 0) - from_server = os.fdopen(from_server, 'wb', 0) + to_server = os.fdopen(to_server, "rb", 0) + to_server_w = os.fdopen(to_server_w, "wb", 0) + from_server_r = os.fdopen(from_server_r, "rb", 0) + from_server = os.fdopen(from_server, "wb", 0) server = self.create_pipe_medium(to_server, from_server, None) # Like test_socket_stream_incomplete_request, write an incomplete # request (that does not end in '\n') and build a protocol from it. @@ -972,9 +1016,12 @@ def test_pipe_stream_incomplete_request(self): server._serve_one_request(server_protocol) to_server_w.close() from_server.close() - self.assertEqual(expected_response, from_server_r.read(), - "Not a version 2 response to 'hello' request.") - self.assertEqual(b'', from_server_r.read(1)) + self.assertEqual( + expected_response, + from_server_r.read(), + "Not a version 2 response to 'hello' request.", + ) + self.assertEqual(b"", from_server_r.read(1)) from_server_r.close() to_server.close() @@ -982,19 +1029,18 @@ def test_pipe_like_stream_with_two_requests(self): # If two requests are read in one go, then two calls to # _serve_one_request should still process both of them as if they had # been received separately. 
- sample_request_bytes = b'command\n' - server, from_server = self.create_pipe_context( - sample_request_bytes * 2, None) + sample_request_bytes = b"command\n" + server, from_server = self.create_pipe_context(sample_request_bytes * 2, None) first_protocol = SampleRequest(expected_bytes=sample_request_bytes) server._serve_one_request(first_protocol) self.assertEqual(0, first_protocol.next_read_size()) - self.assertEqual(b'', from_server.getvalue()) + self.assertEqual(b"", from_server.getvalue()) self.assertFalse(server.finished) # Make a new protocol, call _serve_one_request with it to collect the # second request. second_protocol = SampleRequest(expected_bytes=sample_request_bytes) server._serve_one_request(second_protocol) - self.assertEqual(b'', from_server.getvalue()) + self.assertEqual(b"", from_server.getvalue()) self.assertEqual(sample_request_bytes, second_protocol.accepted_bytes) self.assertFalse(server.finished) @@ -1002,7 +1048,7 @@ def test_socket_stream_with_two_requests(self): # If two requests are read in one go, then two calls to # _serve_one_request should still process both of them as if they had # been received separately. - sample_request_bytes = b'command\n' + sample_request_bytes = b"command\n" server, client_sock = self.create_socket_context(None) first_protocol = SampleRequest(expected_bytes=sample_request_bytes) # Put two whole requests on the wire. @@ -1017,50 +1063,49 @@ def test_socket_stream_with_two_requests(self): self.assertEqual(sample_request_bytes, second_protocol.accepted_bytes) self.assertFalse(server.finished) server._disconnect_client() - self.assertEqual(b'', client_sock.recv(1)) + self.assertEqual(b"", client_sock.recv(1)) def test_pipe_like_stream_error_handling(self): # Use plain python BytesIO so we can monkey-patch the close method to # not discard the contents. from io import BytesIO - to_server = BytesIO(b'') + + to_server = BytesIO(b"") from_server = BytesIO() self.closed = False def close(): self.closed = True + from_server.close = close - server = self.create_pipe_medium( - to_server, from_server, None) - fake_protocol = ErrorRaisingProtocol(Exception('boom')) + server = self.create_pipe_medium(to_server, from_server, None) + fake_protocol = ErrorRaisingProtocol(Exception("boom")) server._serve_one_request(fake_protocol) - self.assertEqual(b'', from_server.getvalue()) + self.assertEqual(b"", from_server.getvalue()) self.assertTrue(self.closed) self.assertTrue(server.finished) def test_socket_stream_error_handling(self): server, client_sock = self.create_socket_context(None) - fake_protocol = ErrorRaisingProtocol(Exception('boom')) + fake_protocol = ErrorRaisingProtocol(Exception("boom")) server._serve_one_request(fake_protocol) # recv should not block, because the other end of the socket has been # closed. 
- self.assertEqual(b'', client_sock.recv(1)) + self.assertEqual(b"", client_sock.recv(1)) self.assertTrue(server.finished) def test_pipe_like_stream_keyboard_interrupt_handling(self): - server, from_server = self.create_pipe_context(b'', None) - fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom')) - self.assertRaises( - KeyboardInterrupt, server._serve_one_request, fake_protocol) - self.assertEqual(b'', from_server.getvalue()) + server, from_server = self.create_pipe_context(b"", None) + fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt("boom")) + self.assertRaises(KeyboardInterrupt, server._serve_one_request, fake_protocol) + self.assertEqual(b"", from_server.getvalue()) def test_socket_stream_keyboard_interrupt_handling(self): server, client_sock = self.create_socket_context(None) - fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom')) - self.assertRaises( - KeyboardInterrupt, server._serve_one_request, fake_protocol) + fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt("boom")) + self.assertRaises(KeyboardInterrupt, server._serve_one_request, fake_protocol) server._disconnect_client() - self.assertEqual(b'', client_sock.recv(1)) + self.assertEqual(b"", client_sock.recv(1)) def build_protocol_pipe_like(self, bytes): server, _ = self.create_pipe_context(bytes, None) @@ -1076,41 +1121,39 @@ def assertProtocolOne(self, server_protocol): # Use assertIs because assertIsInstance will wrongly pass # SmartServerRequestProtocolTwo (because it subclasses # SmartServerRequestProtocolOne). - self.assertIs( - type(server_protocol), protocol.SmartServerRequestProtocolOne) + self.assertIs(type(server_protocol), protocol.SmartServerRequestProtocolOne) def assertProtocolTwo(self, server_protocol): - self.assertIsInstance( - server_protocol, protocol.SmartServerRequestProtocolTwo) + self.assertIsInstance(server_protocol, protocol.SmartServerRequestProtocolTwo) def test_pipe_like_build_protocol_empty_bytes(self): # Any empty request (i.e. no bytes) is detected as protocol version one. - server_protocol = self.build_protocol_pipe_like(b'') + server_protocol = self.build_protocol_pipe_like(b"") self.assertProtocolOne(server_protocol) def test_socket_like_build_protocol_empty_bytes(self): # Any empty request (i.e. no bytes) is detected as protocol version one. - server_protocol = self.build_protocol_socket(b'') + server_protocol = self.build_protocol_socket(b"") self.assertProtocolOne(server_protocol) def test_pipe_like_build_protocol_non_two(self): # A request that doesn't start with "bzr request 2\n" is version one. - server_protocol = self.build_protocol_pipe_like(b'abc\n') + server_protocol = self.build_protocol_pipe_like(b"abc\n") self.assertProtocolOne(server_protocol) def test_socket_build_protocol_non_two(self): # A request that doesn't start with "bzr request 2\n" is version one. - server_protocol = self.build_protocol_socket(b'abc\n') + server_protocol = self.build_protocol_socket(b"abc\n") self.assertProtocolOne(server_protocol) def test_pipe_like_build_protocol_two(self): # A request that starts with "bzr request 2\n" is version two. - server_protocol = self.build_protocol_pipe_like(b'bzr request 2\n') + server_protocol = self.build_protocol_pipe_like(b"bzr request 2\n") self.assertProtocolTwo(server_protocol) def test_socket_build_protocol_two(self): # A request that starts with "bzr request 2\n" is version two. 
- server_protocol = self.build_protocol_socket(b'bzr request 2\n') + server_protocol = self.build_protocol_socket(b"bzr request 2\n") self.assertProtocolTwo(server_protocol) def test__build_protocol_returns_if_stopping(self): @@ -1125,26 +1168,26 @@ def test_socket_set_timeout(self): self.assertEqual(1.23, server._client_timeout) def test_pipe_set_timeout(self): - server = self.create_pipe_medium(None, None, None, - timeout=1.23) + server = self.create_pipe_medium(None, None, None, timeout=1.23) self.assertEqual(1.23, server._client_timeout) def test_socket_wait_for_bytes_with_timeout_with_data(self): server, client_sock = self.create_socket_context(None) - client_sock.sendall(b'data\n') + client_sock.sendall(b"data\n") # This should not block or consume any actual content self.assertFalse(server._wait_for_bytes_with_timeout(0.1)) data = server.read_bytes(5) - self.assertEqual(b'data\n', data) + self.assertEqual(b"data\n", data) def test_socket_wait_for_bytes_with_timeout_no_data(self): server, client_sock = self.create_socket_context(None) # This should timeout quickly, reporting that there wasn't any data - self.assertRaises(errors.ConnectionTimeout, - server._wait_for_bytes_with_timeout, 0.01) + self.assertRaises( + errors.ConnectionTimeout, server._wait_for_bytes_with_timeout, 0.01 + ) client_sock.close() data = server.read_bytes(1) - self.assertEqual(b'', data) + self.assertEqual(b"", data) def test_socket_wait_for_bytes_with_timeout_closed(self): server, client_sock = self.create_socket_context(None) @@ -1156,7 +1199,7 @@ def test_socket_wait_for_bytes_with_timeout_closed(self): client_sock.close() self.assertFalse(server._wait_for_bytes_with_timeout(10)) data = server.read_bytes(1) - self.assertEqual(b'', data) + self.assertEqual(b"", data) def test_socket_wait_for_bytes_with_shutdown(self): server, client_sock = self.create_socket_context(None) @@ -1174,21 +1217,20 @@ def test_socket_serve_timeout_closes_socket(self): # This should timeout quickly, and then close the connection so that # client_sock recv doesn't block. server.serve() - self.assertEqual(b'', client_sock.recv(1)) + self.assertEqual(b"", client_sock.recv(1)) def test_pipe_wait_for_bytes_with_timeout_with_data(self): # We intentionally use a real pipe here, so that we can 'select' on it. # You can't select() on a BytesIO (r_server, w_client) = os.pipe() self.addCleanup(os.close, w_client) - with os.fdopen(r_server, 'rb') as rf_server: - server = self.create_pipe_medium( - rf_server, None, None) - os.write(w_client, b'data\n') + with os.fdopen(r_server, "rb") as rf_server: + server = self.create_pipe_medium(rf_server, None, None) + os.write(w_client, b"data\n") # This should not block or consume any actual content server._wait_for_bytes_with_timeout(0.1) data = server.read_bytes(5) - self.assertEqual(b'data\n', data) + self.assertEqual(b"data\n", data) def test_pipe_wait_for_bytes_with_timeout_no_data(self): # We intentionally use a real pipe here, so that we can 'select' on it. @@ -1196,21 +1238,21 @@ def test_pipe_wait_for_bytes_with_timeout_no_data(self): (r_server, w_client) = os.pipe() # We can't add an os.close cleanup here, because we need to control # when the file handle gets closed ourselves. 
- with os.fdopen(r_server, 'rb') as rf_server: - server = self.create_pipe_medium( - rf_server, None, None) - if sys.platform == 'win32': + with os.fdopen(r_server, "rb") as rf_server: + server = self.create_pipe_medium(rf_server, None, None) + if sys.platform == "win32": # Windows cannot select() on a pipe, so we just always return server._wait_for_bytes_with_timeout(0.01) else: - self.assertRaises(errors.ConnectionTimeout, - server._wait_for_bytes_with_timeout, 0.01) + self.assertRaises( + errors.ConnectionTimeout, server._wait_for_bytes_with_timeout, 0.01 + ) os.close(w_client) data = server.read_bytes(5) - self.assertEqual(b'', data) + self.assertEqual(b"", data) def test_pipe_wait_for_bytes_no_fileno(self): - server, _ = self.create_pipe_context(b'', None) + server, _ = self.create_pipe_context(b"", None) # Our file doesn't support polling, so we should always just return # 'you have data to consume. server._wait_for_bytes_with_timeout(0.01) @@ -1225,31 +1267,27 @@ class TestGetProtocolFactoryForBytes(tests.TestCase): def test_version_three(self): result = medium._get_protocol_factory_for_bytes( - b'bzr message 3 (bzr 1.6)\nextra bytes') + b"bzr message 3 (bzr 1.6)\nextra bytes" + ) protocol_factory, remainder = result - self.assertEqual( - protocol.build_server_protocol_three, protocol_factory) - self.assertEqual(b'extra bytes', remainder) + self.assertEqual(protocol.build_server_protocol_three, protocol_factory) + self.assertEqual(b"extra bytes", remainder) def test_version_two(self): - result = medium._get_protocol_factory_for_bytes( - b'bzr request 2\nextra bytes') + result = medium._get_protocol_factory_for_bytes(b"bzr request 2\nextra bytes") protocol_factory, remainder = result - self.assertEqual( - protocol.SmartServerRequestProtocolTwo, protocol_factory) - self.assertEqual(b'extra bytes', remainder) + self.assertEqual(protocol.SmartServerRequestProtocolTwo, protocol_factory) + self.assertEqual(b"extra bytes", remainder) def test_version_one(self): """Version one requests have no version markers.""" - result = medium._get_protocol_factory_for_bytes(b'anything\n') + result = medium._get_protocol_factory_for_bytes(b"anything\n") protocol_factory, remainder = result - self.assertEqual( - protocol.SmartServerRequestProtocolOne, protocol_factory) - self.assertEqual(b'anything\n', remainder) + self.assertEqual(protocol.SmartServerRequestProtocolOne, protocol_factory) + self.assertEqual(b"anything\n", remainder) class TestSmartTCPServer(tests.TestCase): - def make_server(self): """Create a SmartTCPServer that we can exercise. @@ -1262,13 +1300,12 @@ def make_server(self): :return: (server, server_thread) """ - t = _mod_transport.get_transport_from_url('memory:///') + t = _mod_transport.get_transport_from_url("memory:///") server = _mod_server.SmartTCPServer(t, client_timeout=4.0) server._ACCEPT_TIMEOUT = 0.1 # We don't use 'localhost' because that might be an IPv6 address. 
- server.start_server('127.0.0.1', 0) - server_thread = threading.Thread(target=server.serve, - args=(self.id(),)) + server.start_server("127.0.0.1", 0) + server_thread = threading.Thread(target=server.serve, args=(self.id(),)) server_thread.start() # Ensure this gets called at some point self.addCleanup(server._stop_gracefully) @@ -1309,8 +1346,8 @@ def connect_to_server_and_hangup(self, server): def say_hello(self, client_sock): """Send the 'hello' smart RPC, and expect the response.""" - client_sock.send(b'hello\n') - self.assertEqual(b'ok\x012\n', client_sock.recv(5)) + client_sock.send(b"hello\n") + self.assertEqual(b"ok\x012\n", client_sock.recv(5)) def shutdown_server_cleanly(self, server, server_thread): server._stop_gracefully() @@ -1321,10 +1358,10 @@ def shutdown_server_cleanly(self, server, server_thread): def test_get_error_unexpected(self): """Error reported by server with no specific representation.""" - self.overrideEnv('BRZ_NO_SMART_VFS', None) + self.overrideEnv("BRZ_NO_SMART_VFS", None) class FlakyTransport: - base = 'a_url' + base = "a_url" def external_url(self): return self.base @@ -1341,9 +1378,8 @@ def get_backing_transport(self, backing_transport_server): self.addCleanup(smart_server.stop_server) t = remote.RemoteTCPTransport(smart_server.get_url()) self.addCleanup(t.disconnect) - err = self.assertRaises(UnknownErrorFromSmartServer, - t.get, 'something') - self.assertContainsRe(str(err), 'some random exception') + err = self.assertRaises(UnknownErrorFromSmartServer, t.get, "something") + self.assertContainsRe(str(err), "some random exception") def test_propagates_timeout(self): server = _mod_server.SmartTCPServer(None, client_timeout=1.23) @@ -1354,7 +1390,7 @@ def test_propagates_timeout(self): def test_serve_conn_tracks_connections(self): server = _mod_server.SmartTCPServer(None, client_timeout=4.0) server_sock, client_sock = portable_socket_pair() - server.serve_conn(server_sock, f'-{self.id()}') + server.serve_conn(server_sock, f"-{self.id()}") self.assertEqual(1, len(server._active_connections)) # We still want to talk on the connection. Polling should indicate it # is still active. @@ -1426,18 +1462,17 @@ def test_graceful_shutdown_waits_for_clients_to_stop(self): # We need something big enough that it won't fit in a single recv. So # the server thread gets blocked writing content to the client until we # finish reading on the client. - server.backing_transport.put_bytes('bigfile', - b'a' * 1024 * 1024) + server.backing_transport.put_bytes("bigfile", b"a" * 1024 * 1024) client_sock = self.connect_to_server(server) self.say_hello(client_sock) _, server_side_thread = server._active_connections[0] # Start the RPC, but don't finish reading the response client_medium = medium.SmartClientAlreadyConnectedSocketMedium( - 'base', client_sock) + "base", client_sock + ) client_client = client._SmartClient(client_medium) - resp, response_handler = client_client.call_expecting_body(b'get', - b'bigfile') - self.assertEqual((b'ok',), resp) + resp, response_handler = client_client.call_expecting_body(b"get", b"bigfile") + self.assertEqual((b"ok",), resp) # Ask the server to stop gracefully, and wait for it. server._stop_gracefully() self.connect_to_server_and_hangup(server) @@ -1450,10 +1485,16 @@ def test_graceful_shutdown_waits_for_clients_to_stop(self): server_thread.join() self.assertTrue(server._fully_stopped.is_set()) log = self.get_log() - self.assertThat(log, DocTestMatches("""\ + self.assertThat( + log, + DocTestMatches( + """\ INFO Requested to stop gracefully ... 
Stopping SmartServerSocketStreamMedium(client=('127.0.0.1', ... -""", flags=doctest.ELLIPSIS | doctest.REPORT_UDIFF)) +""", + flags=doctest.ELLIPSIS | doctest.REPORT_UDIFF, + ), + ) def test_stop_gracefully_tells_handlers_to_stop(self): server, server_thread = self.make_server() @@ -1491,17 +1532,20 @@ def start_server(self, readonly=False, backing_transport=None): self.addCleanup(mem_server.stop_server) self.permit_url(mem_server.get_url()) self.backing_transport = _mod_transport.get_transport_from_url( - mem_server.get_url()) + mem_server.get_url() + ) else: self.backing_transport = backing_transport if readonly: self.real_backing_transport = self.backing_transport self.backing_transport = _mod_transport.get_transport_from_url( - "readonly+" + self.backing_transport.abspath('.')) - self.server = _mod_server.SmartTCPServer(self.backing_transport, - client_timeout=4.0) - self.server.start_server('127.0.0.1', 0) - self.server.start_background_thread('-' + self.id()) + "readonly+" + self.backing_transport.abspath(".") + ) + self.server = _mod_server.SmartTCPServer( + self.backing_transport, client_timeout=4.0 + ) + self.server.start_server("127.0.0.1", 0) + self.server.start_background_thread("-" + self.id()) self.transport = remote.RemoteTCPTransport(self.server.get_url()) self.addCleanup(self.stop_server) self.permit_url(self.server.get_url()) @@ -1512,33 +1556,32 @@ def stop_server(self): This must be re-entrant as some tests will call it explicitly in addition to the normal cleanup. """ - if getattr(self, 'transport', None): + if getattr(self, "transport", None): self.transport.disconnect() del self.transport - if getattr(self, 'server', None): + if getattr(self, "server", None): self.server.stop_background_thread() del self.server class TestServerSocketUsage(SmartTCPTests): - def test_server_start_stop(self): """It should be safe to stop the server with no requests.""" self.start_server() t = remote.RemoteTCPTransport(self.server.get_url()) self.stop_server() - self.assertRaises(ConnectionError, t.has, '.') + self.assertRaises(ConnectionError, t.has, ".") def test_server_closes_listening_sock_on_shutdown_after_request(self): """The server should close its listening socket when it's stopped.""" self.start_server() server_url = self.server.get_url() - self.transport.has('.') + self.transport.has(".") self.stop_server() # if the listening socket has closed, we should get a BADFD error # when connecting, rather than a hang. 
t = remote.RemoteTCPTransport(server_url) - self.assertRaises(ConnectionError, t.has, '.') + self.assertRaises(ConnectionError, t.has, ".") class WritableEndToEndTests(SmartTCPTests): @@ -1550,74 +1593,71 @@ def setUp(self): def test_start_tcp_server(self): url = self.server.get_url() - self.assertContainsRe(url, r'^bzr://127\.0\.0\.1:[0-9]{2,}/') + self.assertContainsRe(url, r"^bzr://127\.0\.0\.1:[0-9]{2,}/") def test_smart_transport_has(self): """Checking for file existence over smart.""" - self.overrideEnv('BRZ_NO_SMART_VFS', None) + self.overrideEnv("BRZ_NO_SMART_VFS", None) self.backing_transport.put_bytes("foo", b"contents of foo\n") self.assertTrue(self.transport.has("foo")) self.assertFalse(self.transport.has("non-foo")) def test_smart_transport_get(self): """Read back a file over smart.""" - self.overrideEnv('BRZ_NO_SMART_VFS', None) + self.overrideEnv("BRZ_NO_SMART_VFS", None) self.backing_transport.put_bytes("foo", b"contents\nof\nfoo\n") fp = self.transport.get("foo") - self.assertEqual(b'contents\nof\nfoo\n', fp.read()) + self.assertEqual(b"contents\nof\nfoo\n", fp.read()) def test_get_error_enoent(self): """Error reported from server getting nonexistent file.""" # The path in a raised NoSuchFile exception should be the precise path # asked for by the client. This gives meaningful and unsurprising errors # for users. - self.overrideEnv('BRZ_NO_SMART_VFS', None) + self.overrideEnv("BRZ_NO_SMART_VFS", None) err = self.assertRaises( - _mod_transport.NoSuchFile, self.transport.get, 'not%20a%20file') - self.assertSubset([err.path], ['not%20a%20file', './not%20a%20file']) + _mod_transport.NoSuchFile, self.transport.get, "not%20a%20file" + ) + self.assertSubset([err.path], ["not%20a%20file", "./not%20a%20file"]) def test_simple_clone_conn(self): """Test that cloning reuses the same connection.""" # we create a real connection not a loopback one, but it will use the # same server and pipes - conn2 = self.transport.clone('.') - self.assertIs(self.transport.get_smart_medium(), - conn2.get_smart_medium()) + conn2 = self.transport.clone(".") + self.assertIs(self.transport.get_smart_medium(), conn2.get_smart_medium()) def test__remote_path(self): - self.assertEqual(b'/foo/bar', - self.transport._remote_path('foo/bar')) + self.assertEqual(b"/foo/bar", self.transport._remote_path("foo/bar")) def test_clone_changes_base(self): """Cloning transport produces one with a new base location.""" - conn2 = self.transport.clone('subdir') - self.assertEqual(self.transport.base + 'subdir/', - conn2.base) + conn2 = self.transport.clone("subdir") + self.assertEqual(self.transport.base + "subdir/", conn2.base) def test_open_dir(self): """Test changing directory.""" - self.overrideEnv('BRZ_NO_SMART_VFS', None) + self.overrideEnv("BRZ_NO_SMART_VFS", None) transport = self.transport - self.backing_transport.mkdir('toffee') - self.backing_transport.mkdir('toffee/apple') - self.assertEqual(b'/toffee', transport._remote_path('toffee')) - toffee_trans = transport.clone('toffee') + self.backing_transport.mkdir("toffee") + self.backing_transport.mkdir("toffee/apple") + self.assertEqual(b"/toffee", transport._remote_path("toffee")) + toffee_trans = transport.clone("toffee") # Check that each transport has only the contents of its directory # directly visible. If state was being held in the wrong object, it's # conceivable that cloning a transport would alter the state of the # cloned-from transport. 
- self.assertTrue(transport.has('toffee')) - self.assertFalse(toffee_trans.has('toffee')) - self.assertFalse(transport.has('apple')) - self.assertTrue(toffee_trans.has('apple')) + self.assertTrue(transport.has("toffee")) + self.assertFalse(toffee_trans.has("toffee")) + self.assertFalse(transport.has("apple")) + self.assertTrue(toffee_trans.has("apple")) def test_open_bzrdir(self): """Open an existing bzrdir over smart transport.""" transport = self.transport t = self.backing_transport bzrdir.BzrDirFormat.get_default_format().initialize_on_transport(t) - controldir.ControlDir.open_containing_from_transport( - transport) + controldir.ControlDir.open_containing_from_transport(transport) class ReadOnlyEndToEndTests(SmartTCPTests): @@ -1625,29 +1665,28 @@ class ReadOnlyEndToEndTests(SmartTCPTests): def test_mkdir_error_readonly(self): """TransportNotPossible should be preserved from the backing transport.""" - self.overrideEnv('BRZ_NO_SMART_VFS', None) + self.overrideEnv("BRZ_NO_SMART_VFS", None) self.start_server(readonly=True) - self.assertRaises(errors.TransportNotPossible, self.transport.mkdir, - 'foo') + self.assertRaises(errors.TransportNotPossible, self.transport.mkdir, "foo") def test_rename_error_readonly(self): """TransportNotPossible should be preserved from the backing transport.""" - self.overrideEnv('BRZ_NO_SMART_VFS', None) + self.overrideEnv("BRZ_NO_SMART_VFS", None) self.start_server(readonly=True) - self.assertRaises(errors.TransportNotPossible, self.transport.rename, - 'foo', 'bar') + self.assertRaises( + errors.TransportNotPossible, self.transport.rename, "foo", "bar" + ) def test_open_write_stream_error_readonly(self): """TransportNotPossible should be preserved from the backing transport.""" - self.overrideEnv('BRZ_NO_SMART_VFS', None) + self.overrideEnv("BRZ_NO_SMART_VFS", None) self.start_server(readonly=True) self.assertRaises( - errors.TransportNotPossible, self.transport.open_write_stream, - 'foo') + errors.TransportNotPossible, self.transport.open_write_stream, "foo" + ) class TestServerHooks(SmartTCPTests): - def capture_server_call(self, backing_urls, public_url): """Record a server_started|stopped hook firing.""" self.hook_calls.append((backing_urls, public_url)) @@ -1655,45 +1694,56 @@ def capture_server_call(self, backing_urls, public_url): def test_server_started_hook_memory(self): """The server_started hook fires when the server is started.""" self.hook_calls = [] - _mod_server.SmartTCPServer.hooks.install_named_hook('server_started', - self.capture_server_call, None) + _mod_server.SmartTCPServer.hooks.install_named_hook( + "server_started", self.capture_server_call, None + ) self.start_server() # at this point, the server will be starting a thread up. # there is no indicator at the moment, so bodge it by doing a request. 
- self.transport.has('.') + self.transport.has(".") # The default test server uses MemoryTransport and that has no external # url: - self.assertEqual([([self.backing_transport.base], self.transport.base)], - self.hook_calls) + self.assertEqual( + [([self.backing_transport.base], self.transport.base)], self.hook_calls + ) def test_server_started_hook_file(self): """The server_started hook fires when the server is started.""" self.hook_calls = [] - _mod_server.SmartTCPServer.hooks.install_named_hook('server_started', - self.capture_server_call, None) - self.start_server( - backing_transport=_mod_transport.get_transport_from_path(".")) + _mod_server.SmartTCPServer.hooks.install_named_hook( + "server_started", self.capture_server_call, None + ) + self.start_server(backing_transport=_mod_transport.get_transport_from_path(".")) # at this point, the server will be starting a thread up. # there is no indicator at the moment, so bodge it by doing a request. - self.transport.has('.') + self.transport.has(".") # The default test server uses MemoryTransport and that has no external # url: - self.assertEqual([([ - self.backing_transport.base, self.backing_transport.external_url()], - self.transport.base)], - self.hook_calls) + self.assertEqual( + [ + ( + [ + self.backing_transport.base, + self.backing_transport.external_url(), + ], + self.transport.base, + ) + ], + self.hook_calls, + ) def test_server_stopped_hook_simple_memory(self): """The server_stopped hook fires when the server is stopped.""" self.hook_calls = [] - _mod_server.SmartTCPServer.hooks.install_named_hook('server_stopped', - self.capture_server_call, None) + _mod_server.SmartTCPServer.hooks.install_named_hook( + "server_stopped", self.capture_server_call, None + ) self.start_server() result = [([self.backing_transport.base], self.transport.base)] # check the stopping message isn't emitted up front. self.assertEqual([], self.hook_calls) # nor after a single message - self.transport.has('.') + self.transport.has(".") self.assertEqual([], self.hook_calls) # clean up the server self.stop_server() @@ -1703,22 +1753,27 @@ def test_server_stopped_hook_simple_memory(self): def test_server_stopped_hook_simple_file(self): """The server_stopped hook fires when the server is stopped.""" self.hook_calls = [] - _mod_server.SmartTCPServer.hooks.install_named_hook('server_stopped', - self.capture_server_call, None) - self.start_server( - backing_transport=_mod_transport.get_transport_from_path(".")) - result = [( - [self.backing_transport.base, self.backing_transport.external_url()], self.transport.base)] + _mod_server.SmartTCPServer.hooks.install_named_hook( + "server_stopped", self.capture_server_call, None + ) + self.start_server(backing_transport=_mod_transport.get_transport_from_path(".")) + result = [ + ( + [self.backing_transport.base, self.backing_transport.external_url()], + self.transport.base, + ) + ] # check the stopping message isn't emitted up front. self.assertEqual([], self.hook_calls) # nor after a single message - self.transport.has('.') + self.transport.has(".") self.assertEqual([], self.hook_calls) # clean up the server self.stop_server() # now it should have fired. self.assertEqual(result, self.hook_calls) + # TODO: test that when the server suffers an exception that it calls the # server-stopped hook. 
@@ -1732,20 +1787,21 @@ class SmartServerCommandTests(tests.TestCaseWithTransport): """ def test_hello(self): - cmd = _mod_request.HelloRequest(None, '/') + cmd = _mod_request.HelloRequest(None, "/") response = cmd.execute() - self.assertEqual((b'ok', b'2'), response.args) + self.assertEqual((b"ok", b"2"), response.args) self.assertEqual(None, response.body) def test_get_bundle(self): from breezy.bzr.bundle import serializer - wt = self.make_branch_and_tree('.') - self.build_tree_contents([('hello', b'hello world')]) - wt.add('hello') - rev_id = wt.commit('add hello') - cmd = _mod_request.GetBundleRequest(self.get_transport(), '/') - response = cmd.execute(b'.', rev_id) + wt = self.make_branch_and_tree(".") + self.build_tree_contents([("hello", b"hello world")]) + wt.add("hello") + rev_id = wt.commit("add hello") + + cmd = _mod_request.GetBundleRequest(self.get_transport(), "/") + response = cmd.execute(b".", rev_id) serializer.read_bundle(BytesIO(response.body)) self.assertEqual((), response.args) @@ -1755,41 +1811,43 @@ class SmartServerRequestHandlerTests(tests.TestCaseWithTransport): def setUp(self): super().setUp() - self.overrideEnv('BRZ_NO_SMART_VFS', None) + self.overrideEnv("BRZ_NO_SMART_VFS", None) def build_handler(self, transport): """Returns a handler for the commands in protocol version one.""" return _mod_request.SmartServerRequestHandler( - transport, _mod_request.request_handlers, '/') + transport, _mod_request.request_handlers, "/" + ) def test_construct_request_handler(self): """Constructing a request handler should be easy and set defaults.""" - handler = _mod_request.SmartServerRequestHandler(None, commands=None, - root_client_path='/') + handler = _mod_request.SmartServerRequestHandler( + None, commands=None, root_client_path="/" + ) self.assertFalse(handler.finished_reading) def test_hello(self): handler = self.build_handler(None) - handler.args_received((b'hello',)) - self.assertEqual((b'ok', b'2'), handler.response.args) + handler.args_received((b"hello",)) + self.assertEqual((b"ok", b"2"), handler.response.args) self.assertEqual(None, handler.response.body) def test_disable_vfs_handler_classes_via_environment(self): # VFS handler classes will raise an error from "execute" if # BRZ_NO_SMART_VFS is set. - handler = vfs.HasRequest(None, '/') + handler = vfs.HasRequest(None, "/") # set environment variable after construction to make sure it's # examined. - self.overrideEnv('BRZ_NO_SMART_VFS', '') + self.overrideEnv("BRZ_NO_SMART_VFS", "") self.assertRaises(_mod_request.DisabledMethod, handler.execute) def test_readonly_exception_becomes_transport_not_possible(self): """The response for a read-only error is ('ReadOnlyError').""" handler = self.build_handler(self.get_readonly_transport()) # send a mkdir for foo, with no explicit mode - should fail. 
- handler.args_received((b'mkdir', b'foo', b'')) + handler.args_received((b"mkdir", b"foo", b"")) # and the failure should be an explicit ReadOnlyError - self.assertEqual((b"ReadOnlyError", ), handler.response.args) + self.assertEqual((b"ReadOnlyError",), handler.response.args) # XXX: TODO: test that other TransportNotPossible errors are # presented as TransportNotPossible - not possible to do that # until I figure out how to trigger that relatively cleanly via @@ -1798,104 +1856,96 @@ def test_readonly_exception_becomes_transport_not_possible(self): def test_hello_has_finished_body_on_dispatch(self): """The 'hello' command should set finished_reading.""" handler = self.build_handler(None) - handler.args_received((b'hello',)) + handler.args_received((b"hello",)) self.assertTrue(handler.finished_reading) self.assertNotEqual(None, handler.response) def test_put_bytes_non_atomic(self): """'put_...' should set finished_reading after reading the bytes.""" handler = self.build_handler(self.get_transport()) - handler.args_received((b'put_non_atomic', b'a-file', b'', b'F', b'')) + handler.args_received((b"put_non_atomic", b"a-file", b"", b"F", b"")) self.assertFalse(handler.finished_reading) - handler.accept_body(b'1234') + handler.accept_body(b"1234") self.assertFalse(handler.finished_reading) - handler.accept_body(b'5678') + handler.accept_body(b"5678") handler.end_of_body() self.assertTrue(handler.finished_reading) - self.assertEqual((b'ok', ), handler.response.args) + self.assertEqual((b"ok",), handler.response.args) self.assertEqual(None, handler.response.body) def test_readv_accept_body(self): """'readv' should set finished_reading after reading offsets.""" # noqa: D403 - self.build_tree(['a-file']) + self.build_tree(["a-file"]) handler = self.build_handler(self.get_readonly_transport()) - handler.args_received((b'readv', b'a-file')) + handler.args_received((b"readv", b"a-file")) self.assertFalse(handler.finished_reading) - handler.accept_body(b'2,') + handler.accept_body(b"2,") self.assertFalse(handler.finished_reading) - handler.accept_body(b'3') + handler.accept_body(b"3") handler.end_of_body() self.assertTrue(handler.finished_reading) - self.assertEqual((b'readv', ), handler.response.args) + self.assertEqual((b"readv",), handler.response.args) # co - nte - nt of a-file is the file contents we are extracting from. - self.assertEqual(b'nte', handler.response.body) + self.assertEqual(b"nte", handler.response.body) def test_readv_short_read_response_contents(self): """'readv' when a short read occurs sets the response appropriately.""" # noqa: D403 - self.build_tree(['a-file']) + self.build_tree(["a-file"]) handler = self.build_handler(self.get_readonly_transport()) - handler.args_received((b'readv', b'a-file')) + handler.args_received((b"readv", b"a-file")) # read beyond the end of the file. 
- handler.accept_body(b'100,1') + handler.accept_body(b"100,1") handler.end_of_body() self.assertTrue(handler.finished_reading) - self.assertEqual((b'ShortReadvError', b'./a-file', b'100', b'1', b'0'), - handler.response.args) + self.assertEqual( + (b"ShortReadvError", b"./a-file", b"100", b"1", b"0"), handler.response.args + ) self.assertEqual(None, handler.response.body) class RemoteTransportRegistration(tests.TestCase): - def test_registration(self): - t = _mod_transport.get_transport_from_url('bzr+ssh://example.com/path') + t = _mod_transport.get_transport_from_url("bzr+ssh://example.com/path") self.assertIsInstance(t, remote.RemoteSSHTransport) - self.assertEqual('example.com', t._parsed_url.host) + self.assertEqual("example.com", t._parsed_url.host) def test_bzr_https(self): # https://bugs.launchpad.net/bzr/+bug/128456 - t = _mod_transport.get_transport_from_url( - 'bzr+https://example.com/path') + t = _mod_transport.get_transport_from_url("bzr+https://example.com/path") self.assertIsInstance(t, remote.RemoteHTTPTransport) - self.assertStartsWith( - t._http_transport.base, - 'https://') + self.assertStartsWith(t._http_transport.base, "https://") class TestRemoteTransport(tests.TestCase): - def test_use_connection_factory(self): # We want to be able to pass a client as a parameter to RemoteTransport. - input = BytesIO(b'ok\n3\nbardone\n') + input = BytesIO(b"ok\n3\nbardone\n") output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') - transport = remote.RemoteTransport( - 'bzr://localhost/', medium=client_medium) + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") + transport = remote.RemoteTransport("bzr://localhost/", medium=client_medium) # Disable version detection. client_medium._protocol_version = 1 # We want to make sure the client is used when the first remote # method is called. No data should have been sent, or read. self.assertEqual(0, input.tell()) - self.assertEqual(b'', output.getvalue()) + self.assertEqual(b"", output.getvalue()) # Now call a method that should result in one request: as the # transport makes its own protocol instances, we check on the wire. # XXX: TODO: give the transport a protocol factory, which can make # an instrumented protocol for us. - self.assertEqual(b'bar', transport.get_bytes('foo')) + self.assertEqual(b"bar", transport.get_bytes("foo")) # only the needed data should have been sent/received. 
self.assertEqual(13, input.tell()) - self.assertEqual(b'get\x01/foo\n', output.getvalue()) + self.assertEqual(b"get\x01/foo\n", output.getvalue()) def test__translate_error_readonly(self): """Sending a ReadOnlyError to _translate_error raises TransportNotPossible.""" - client_medium = medium.SmartSimplePipesClientMedium(None, None, 'base') - transport = remote.RemoteTransport( - 'bzr://localhost/', medium=client_medium) - err = errors.ErrorFromSmartServer((b"ReadOnlyError", )) - self.assertRaises(errors.TransportNotPossible, - transport._translate_error, err) + client_medium = medium.SmartSimplePipesClientMedium(None, None, "base") + transport = remote.RemoteTransport("bzr://localhost/", medium=client_medium) + err = errors.ErrorFromSmartServer((b"ReadOnlyError",)) + self.assertRaises(errors.TransportNotPossible, transport._translate_error, err) class TestSmartProtocol(tests.TestCase): @@ -1928,8 +1978,7 @@ def make_client_protocol_and_output(self, input_bytes=None): else: input = BytesIO(input_bytes) output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") request = client_medium.get_request() if self.client_protocol_class is not None: client_protocol = self.client_protocol_class(request) @@ -1940,9 +1989,9 @@ def make_client_protocol_and_output(self, input_bytes=None): requester = self.request_encoder(request) response_handler = message.ConventionalResponseHandler() response_protocol = self.response_decoder( - response_handler, expect_version_marker=True) - response_handler.setProtoAndMediumRequest( - response_protocol, request) + response_handler, expect_version_marker=True + ) + response_handler.setProtoAndMediumRequest(response_protocol, request) return requester, response_handler, output def make_client_protocol(self, input_bytes=None): @@ -1958,12 +2007,15 @@ def make_server_protocol(self): def setUp(self): super().setUp() self.response_marker = getattr( - self.client_protocol_class, 'response_marker', None) + self.client_protocol_class, "response_marker", None + ) self.request_marker = getattr( - self.client_protocol_class, 'request_marker', None) + self.client_protocol_class, "request_marker", None + ) - def assertOffsetSerialisation(self, expected_offsets, expected_serialised, - requester): + def assertOffsetSerialisation( + self, expected_offsets, expected_serialised, requester + ): """Check that smart (de)serialises offsets as expected. We check both serialisation and deserialisation at the same time @@ -1975,7 +2027,7 @@ def assertOffsetSerialisation(self, expected_offsets, expected_serialised, """ # XXX: '_deserialise_offsets' should be a method of the # SmartServerRequestProtocol in future. 
- readv_cmd = vfs.ReadvRequest(None, '/') + readv_cmd = vfs.ReadvRequest(None, "/") offsets = readv_cmd._deserialise_offsets(expected_serialised) self.assertEqual(expected_offsets, offsets) serialised = requester._serialise_offsets(offsets) @@ -1985,24 +2037,27 @@ def build_protocol_waiting_for_body(self): smart_protocol, out_stream = self.make_server_protocol() smart_protocol._has_dispatched = True smart_protocol.request = _mod_request.SmartServerRequestHandler( - None, _mod_request.request_handlers, '/') + None, _mod_request.request_handlers, "/" + ) # GZ 2010-08-10: Cycle with closure affects 4 tests class FakeCommand(_mod_request.SmartServerRequest): def do_body(self_cmd, body_bytes): # noqa: N805 self.end_received = True - self.assertEqual(b'abcdefg', body_bytes) - return _mod_request.SuccessfulSmartServerResponse((b'ok', )) + self.assertEqual(b"abcdefg", body_bytes) + return _mod_request.SuccessfulSmartServerResponse((b"ok",)) + smart_protocol.request._command = FakeCommand(None) # Call accept_bytes to make sure that internal state like _body_decoder # is initialised. This test should probably be given a clearer # interface to work with that will not cause this inconsistency. # -- Andrew Bennetts, 2006-09-28 - smart_protocol.accept_bytes(b'') + smart_protocol.accept_bytes(b"") return smart_protocol - def assertServerToClientEncoding(self, expected_bytes, expected_tuple, - input_tuples): + def assertServerToClientEncoding( + self, expected_bytes, expected_tuple, input_tuples + ): """Assert that each input_tuple serialises as expected_bytes, and the bytes deserialise as expected_tuple. """ @@ -2011,25 +2066,27 @@ def assertServerToClientEncoding(self, expected_bytes, expected_tuple, for input_tuple in input_tuples: server_protocol, server_output = self.make_server_protocol() server_protocol._send_response( - _mod_request.SuccessfulSmartServerResponse(input_tuple)) + _mod_request.SuccessfulSmartServerResponse(input_tuple) + ) self.assertEqual(expected_bytes, server_output.getvalue()) # check the decoding of the client smart_protocol from expected_bytes: requester, response_handler = self.make_client_protocol(expected_bytes) - requester.call(b'foo') - self.assertEqual( - expected_tuple, response_handler.read_response_tuple()) + requester.call(b"foo") + self.assertEqual(expected_tuple, response_handler.read_response_tuple()) class CommonSmartProtocolTestMixin: - def test_connection_closed_reporting(self): requester, response_handler = self.make_client_protocol() - requester.call(b'hello') - ex = self.assertRaises(ConnectionResetError, - response_handler.read_response_tuple) - self.assertEqual("Unexpected end of message. Please check connectivity " - "and permissions, and report a bug if problems persist.", - str(ex)) + requester.call(b"hello") + ex = self.assertRaises( + ConnectionResetError, response_handler.read_response_tuple + ) + self.assertEqual( + "Unexpected end of message. Please check connectivity " + "and permissions, and report a bug if problems persist.", + str(ex), + ) def test_server_offset_serialisation(self): r"""The Smart protocol serialises offsets as a comma and \n string. @@ -2039,16 +2096,17 @@ def test_server_offset_serialisation(self): one that should coalesce. 
""" requester, response_handler = self.make_client_protocol() - self.assertOffsetSerialisation([], b'', requester) - self.assertOffsetSerialisation([(1, 2)], b'1,2', requester) - self.assertOffsetSerialisation([(10, 40), (0, 5)], b'10,40\n0,5', - requester) - self.assertOffsetSerialisation([(1, 2), (3, 4), (100, 200)], - b'1,2\n3,4\n100,200', requester) + self.assertOffsetSerialisation([], b"", requester) + self.assertOffsetSerialisation([(1, 2)], b"1,2", requester) + self.assertOffsetSerialisation([(10, 40), (0, 5)], b"10,40\n0,5", requester) + self.assertOffsetSerialisation( + [(1, 2), (3, 4), (100, 200)], b"1,2\n3,4\n100,200", requester + ) class TestVersionOneFeaturesInProtocolOne( - TestSmartProtocol, CommonSmartProtocolTestMixin): + TestSmartProtocol, CommonSmartProtocolTestMixin +): """Tests for version one smart protocol features as implemeted by version one. """ @@ -2058,100 +2116,100 @@ class TestVersionOneFeaturesInProtocolOne( def test_construct_version_one_server_protocol(self): smart_protocol = protocol.SmartServerRequestProtocolOne(None, None) - self.assertEqual(b'', smart_protocol.unused_data) - self.assertEqual(b'', smart_protocol.in_buffer) + self.assertEqual(b"", smart_protocol.unused_data) + self.assertEqual(b"", smart_protocol.in_buffer) self.assertFalse(smart_protocol._has_dispatched) self.assertEqual(1, smart_protocol.next_read_size()) def test_construct_version_one_client_protocol(self): # we can construct a client protocol from a client medium request output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - None, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(None, output, "base") request = client_medium.get_request() protocol.SmartClientRequestProtocolOne(request) def test_accept_bytes_of_bad_request_to_protocol(self): out_stream = BytesIO() - smart_protocol = protocol.SmartServerRequestProtocolOne( - None, out_stream.write) - smart_protocol.accept_bytes(b'abc') - self.assertEqual(b'abc', smart_protocol.in_buffer) - smart_protocol.accept_bytes(b'\n') + smart_protocol = protocol.SmartServerRequestProtocolOne(None, out_stream.write) + smart_protocol.accept_bytes(b"abc") + self.assertEqual(b"abc", smart_protocol.in_buffer) + smart_protocol.accept_bytes(b"\n") self.assertEqual( b"error\x01Generic bzr smart protocol error: bad request 'abc'\n", - out_stream.getvalue()) + out_stream.getvalue(), + ) self.assertTrue(smart_protocol._has_dispatched) self.assertEqual(0, smart_protocol.next_read_size()) def test_accept_body_bytes_to_protocol(self): protocol = self.build_protocol_waiting_for_body() self.assertEqual(6, protocol.next_read_size()) - protocol.accept_bytes(b'7\nabc') + protocol.accept_bytes(b"7\nabc") self.assertEqual(9, protocol.next_read_size()) - protocol.accept_bytes(b'defgd') - protocol.accept_bytes(b'one\n') + protocol.accept_bytes(b"defgd") + protocol.accept_bytes(b"one\n") self.assertEqual(0, protocol.next_read_size()) self.assertTrue(self.end_received) def test_accept_request_and_body_all_at_once(self): - self.overrideEnv('BRZ_NO_SMART_VFS', None) + self.overrideEnv("BRZ_NO_SMART_VFS", None) mem_transport = memory.MemoryTransport() - mem_transport.put_bytes('foo', b'abcdefghij') + mem_transport.put_bytes("foo", b"abcdefghij") out_stream = BytesIO() - smart_protocol = protocol.SmartServerRequestProtocolOne(mem_transport, - out_stream.write) - smart_protocol.accept_bytes(b'readv\x01foo\n3\n3,3done\n') + smart_protocol = protocol.SmartServerRequestProtocolOne( + mem_transport, out_stream.write + ) + 
smart_protocol.accept_bytes(b"readv\x01foo\n3\n3,3done\n") self.assertEqual(0, smart_protocol.next_read_size()) - self.assertEqual(b'readv\n3\ndefdone\n', out_stream.getvalue()) - self.assertEqual(b'', smart_protocol.unused_data) - self.assertEqual(b'', smart_protocol.in_buffer) + self.assertEqual(b"readv\n3\ndefdone\n", out_stream.getvalue()) + self.assertEqual(b"", smart_protocol.unused_data) + self.assertEqual(b"", smart_protocol.in_buffer) def test_accept_excess_bytes_are_preserved(self): out_stream = BytesIO() - smart_protocol = protocol.SmartServerRequestProtocolOne( - None, out_stream.write) - smart_protocol.accept_bytes(b'hello\nhello\n') + smart_protocol = protocol.SmartServerRequestProtocolOne(None, out_stream.write) + smart_protocol.accept_bytes(b"hello\nhello\n") self.assertEqual(b"ok\x012\n", out_stream.getvalue()) self.assertEqual(b"hello\n", smart_protocol.unused_data) self.assertEqual(b"", smart_protocol.in_buffer) def test_accept_excess_bytes_after_body(self): protocol = self.build_protocol_waiting_for_body() - protocol.accept_bytes(b'7\nabcdefgdone\nX') + protocol.accept_bytes(b"7\nabcdefgdone\nX") self.assertTrue(self.end_received) self.assertEqual(b"X", protocol.unused_data) self.assertEqual(b"", protocol.in_buffer) - protocol.accept_bytes(b'Y') + protocol.accept_bytes(b"Y") self.assertEqual(b"XY", protocol.unused_data) self.assertEqual(b"", protocol.in_buffer) def test_accept_excess_bytes_after_dispatch(self): out_stream = BytesIO() - smart_protocol = protocol.SmartServerRequestProtocolOne( - None, out_stream.write) - smart_protocol.accept_bytes(b'hello\n') + smart_protocol = protocol.SmartServerRequestProtocolOne(None, out_stream.write) + smart_protocol.accept_bytes(b"hello\n") self.assertEqual(b"ok\x012\n", out_stream.getvalue()) - smart_protocol.accept_bytes(b'hel') + smart_protocol.accept_bytes(b"hel") self.assertEqual(b"hel", smart_protocol.unused_data) - smart_protocol.accept_bytes(b'lo\n') + smart_protocol.accept_bytes(b"lo\n") self.assertEqual(b"hello\n", smart_protocol.unused_data) self.assertEqual(b"", smart_protocol.in_buffer) def test__send_response_sets_finished_reading(self): - smart_protocol = protocol.SmartServerRequestProtocolOne( - None, lambda x: None) + smart_protocol = protocol.SmartServerRequestProtocolOne(None, lambda x: None) self.assertEqual(1, smart_protocol.next_read_size()) smart_protocol._send_response( - _mod_request.SuccessfulSmartServerResponse((b'x',))) + _mod_request.SuccessfulSmartServerResponse((b"x",)) + ) self.assertEqual(0, smart_protocol.next_read_size()) def test__send_response_errors_with_base_response(self): """Ensure that only the Successful/Failed subclasses are used.""" - smart_protocol = protocol.SmartServerRequestProtocolOne( - None, lambda x: None) - self.assertRaises(AttributeError, smart_protocol._send_response, - _mod_request.SmartServerResponse((b'x',))) + smart_protocol = protocol.SmartServerRequestProtocolOne(None, lambda x: None) + self.assertRaises( + AttributeError, + smart_protocol._send_response, + _mod_request.SmartServerResponse((b"x",)), + ) def test_query_version(self): """query_version on a SmartClientProtocolOne should return a number. @@ -2163,10 +2221,9 @@ def test_query_version(self): # accept_bytes(tuple_based_encoding_of_hello) and reads and parses the # response of tuple-encoded (ok, 1). Also, separately we should test # the error if the response is a non-understood version. 
- input = BytesIO(b'ok\x012\n') + input = BytesIO(b"ok\x012\n") output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") request = client_medium.get_request() smart_protocol = protocol.SmartClientRequestProtocolOne(request) self.assertEqual(2, smart_protocol.query_version()) @@ -2175,13 +2232,14 @@ def test_client_call_empty_response(self): # protocol.call() can get back an empty tuple as a response. This occurs # when the parsed line is an empty line, and results in a tuple with # one element - an empty string. - self.assertServerToClientEncoding(b'\n', (b'', ), [(), (b'', )]) + self.assertServerToClientEncoding(b"\n", (b"",), [(), (b"",)]) def test_client_call_three_element_response(self): # protocol.call() can get back tuples of other lengths. A three element # tuple should be unpacked as three strings. - self.assertServerToClientEncoding(b'a\x01b\x0134\n', (b'a', b'b', b'34'), - [(b'a', b'b', b'34')]) + self.assertServerToClientEncoding( + b"a\x01b\x0134\n", (b"a", b"b", b"34"), [(b"a", b"b", b"34")] + ) def test_client_call_with_body_bytes_uploads(self): # protocol.call_with_body_bytes should length-prefix the bytes onto the @@ -2189,11 +2247,10 @@ def test_client_call_with_body_bytes_uploads(self): expected_bytes = b"foo\n7\nabcdefgdone\n" input = BytesIO(b"\n") output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") request = client_medium.get_request() smart_protocol = protocol.SmartClientRequestProtocolOne(request) - smart_protocol.call_with_body_bytes((b'foo', ), b"abcdefg") + smart_protocol.call_with_body_bytes((b"foo",), b"abcdefg") self.assertEqual(expected_bytes, output.getvalue()) def test_client_call_with_body_readv_array(self): @@ -2202,37 +2259,30 @@ def test_client_call_with_body_readv_array(self): expected_bytes = b"foo\n7\n1,2\n5,6done\n" input = BytesIO(b"\n") output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") request = client_medium.get_request() smart_protocol = protocol.SmartClientRequestProtocolOne(request) - smart_protocol.call_with_body_readv_array((b'foo', ), [(1, 2), (5, 6)]) + smart_protocol.call_with_body_readv_array((b"foo",), [(1, 2), (5, 6)]) self.assertEqual(expected_bytes, output.getvalue()) - def _test_client_read_response_tuple_raises_UnknownSmartMethod(self, - server_bytes): + def _test_client_read_response_tuple_raises_UnknownSmartMethod(self, server_bytes): input = BytesIO(server_bytes) output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") request = client_medium.get_request() smart_protocol = protocol.SmartClientRequestProtocolOne(request) - smart_protocol.call(b'foo') - self.assertRaises( - errors.UnknownSmartMethod, smart_protocol.read_response_tuple) + smart_protocol.call(b"foo") + self.assertRaises(errors.UnknownSmartMethod, smart_protocol.read_response_tuple) # The request has been finished. There is no body to read, and # attempts to read one will fail. 
- self.assertRaises( - errors.ReadingCompleted, smart_protocol.read_body_bytes) + self.assertRaises(errors.ReadingCompleted, smart_protocol.read_body_bytes) def test_client_read_response_tuple_raises_UnknownSmartMethod(self): """read_response_tuple raises UnknownSmartMethod if the response says the server did not recognise the request. """ - server_bytes = ( - b"error\x01Generic bzr smart protocol error: bad request 'foo'\n") - self._test_client_read_response_tuple_raises_UnknownSmartMethod( - server_bytes) + server_bytes = b"error\x01Generic bzr smart protocol error: bad request 'foo'\n" + self._test_client_read_response_tuple_raises_UnknownSmartMethod(server_bytes) def test_client_read_response_tuple_raises_UnknownSmartMethod_0_11(self): """read_response_tuple also raises UnknownSmartMethod if the response @@ -2241,9 +2291,9 @@ def test_client_read_response_tuple_raises_UnknownSmartMethod_0_11(self): (bzr 0.11 sends a slightly different error message to later versions.) """ server_bytes = ( - b"error\x01Generic bzr smart protocol error: bad request u'foo'\n") - self._test_client_read_response_tuple_raises_UnknownSmartMethod( - server_bytes) + b"error\x01Generic bzr smart protocol error: bad request u'foo'\n" + ) + self._test_client_read_response_tuple_raises_UnknownSmartMethod(server_bytes) def test_client_read_body_bytes_all(self): # read_body_bytes should decode the body bytes from the wire into @@ -2252,11 +2302,10 @@ def test_client_read_body_bytes_all(self): server_bytes = b"ok\n7\n1234567done\n" input = BytesIO(server_bytes) output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") request = client_medium.get_request() smart_protocol = protocol.SmartClientRequestProtocolOne(request) - smart_protocol.call(b'foo') + smart_protocol.call(b"foo") smart_protocol.read_response_tuple(True) self.assertEqual(expected_bytes, smart_protocol.read_body_bytes()) @@ -2270,18 +2319,14 @@ def test_client_read_body_bytes_incremental(self): server_bytes = b"ok\n7\n1234567done\n" input = BytesIO(server_bytes) output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") request = client_medium.get_request() smart_protocol = protocol.SmartClientRequestProtocolOne(request) - smart_protocol.call(b'foo') + smart_protocol.call(b"foo") smart_protocol.read_response_tuple(True) - self.assertEqual(expected_bytes[0:2], - smart_protocol.read_body_bytes(2)) - self.assertEqual(expected_bytes[2:4], - smart_protocol.read_body_bytes(2)) - self.assertEqual(expected_bytes[4:6], - smart_protocol.read_body_bytes(2)) + self.assertEqual(expected_bytes[0:2], smart_protocol.read_body_bytes(2)) + self.assertEqual(expected_bytes[2:4], smart_protocol.read_body_bytes(2)) + self.assertEqual(expected_bytes[4:6], smart_protocol.read_body_bytes(2)) self.assertEqual(expected_bytes[6:7], smart_protocol.read_body_bytes()) def test_client_cancel_read_body_does_not_eat_body_bytes(self): @@ -2290,33 +2335,30 @@ def test_client_cancel_read_body_does_not_eat_body_bytes(self): server_bytes = b"ok\n7\n1234567done\n" input = BytesIO(server_bytes) output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") request = client_medium.get_request() smart_protocol = 
protocol.SmartClientRequestProtocolOne(request) - smart_protocol.call(b'foo') + smart_protocol.call(b"foo") smart_protocol.read_response_tuple(True) smart_protocol.cancel_read_body() self.assertEqual(3, input.tell()) - self.assertRaises( - errors.ReadingCompleted, smart_protocol.read_body_bytes) + self.assertRaises(errors.ReadingCompleted, smart_protocol.read_body_bytes) def test_client_read_body_bytes_interrupted_connection(self): server_bytes = b"ok\n999\nincomplete body" input = BytesIO(server_bytes) output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") request = client_medium.get_request() smart_protocol = self.client_protocol_class(request) - smart_protocol.call(b'foo') + smart_protocol.call(b"foo") smart_protocol.read_response_tuple(True) - self.assertRaises( - ConnectionResetError, smart_protocol.read_body_bytes) + self.assertRaises(ConnectionResetError, smart_protocol.read_body_bytes) class TestVersionOneFeaturesInProtocolTwo( - TestSmartProtocol, CommonSmartProtocolTestMixin): + TestSmartProtocol, CommonSmartProtocolTestMixin +): """Tests for version one smart protocol features as implemeted by version two. """ @@ -2326,106 +2368,107 @@ class TestVersionOneFeaturesInProtocolTwo( def test_construct_version_two_server_protocol(self): smart_protocol = protocol.SmartServerRequestProtocolTwo(None, None) - self.assertEqual(b'', smart_protocol.unused_data) - self.assertEqual(b'', smart_protocol.in_buffer) + self.assertEqual(b"", smart_protocol.unused_data) + self.assertEqual(b"", smart_protocol.in_buffer) self.assertFalse(smart_protocol._has_dispatched) self.assertEqual(1, smart_protocol.next_read_size()) def test_construct_version_two_client_protocol(self): # we can construct a client protocol from a client medium request output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - None, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(None, output, "base") request = client_medium.get_request() protocol.SmartClientRequestProtocolTwo(request) def test_accept_bytes_of_bad_request_to_protocol(self): out_stream = BytesIO() smart_protocol = self.server_protocol_class(None, out_stream.write) - smart_protocol.accept_bytes(b'abc') - self.assertEqual(b'abc', smart_protocol.in_buffer) - smart_protocol.accept_bytes(b'\n') + smart_protocol.accept_bytes(b"abc") + self.assertEqual(b"abc", smart_protocol.in_buffer) + smart_protocol.accept_bytes(b"\n") self.assertEqual( self.response_marker + b"failed\nerror\x01Generic bzr smart protocol error: bad request 'abc'\n", - out_stream.getvalue()) + out_stream.getvalue(), + ) self.assertTrue(smart_protocol._has_dispatched) self.assertEqual(0, smart_protocol.next_read_size()) def test_accept_body_bytes_to_protocol(self): protocol = self.build_protocol_waiting_for_body() self.assertEqual(6, protocol.next_read_size()) - protocol.accept_bytes(b'7\nabc') + protocol.accept_bytes(b"7\nabc") self.assertEqual(9, protocol.next_read_size()) - protocol.accept_bytes(b'defgd') - protocol.accept_bytes(b'one\n') + protocol.accept_bytes(b"defgd") + protocol.accept_bytes(b"one\n") self.assertEqual(0, protocol.next_read_size()) self.assertTrue(self.end_received) def test_accept_request_and_body_all_at_once(self): - self.overrideEnv('BRZ_NO_SMART_VFS', None) + self.overrideEnv("BRZ_NO_SMART_VFS", None) mem_transport = memory.MemoryTransport() - mem_transport.put_bytes('foo', b'abcdefghij') + mem_transport.put_bytes("foo", 
b"abcdefghij") out_stream = BytesIO() - smart_protocol = self.server_protocol_class( - mem_transport, out_stream.write) - smart_protocol.accept_bytes(b'readv\x01foo\n3\n3,3done\n') + smart_protocol = self.server_protocol_class(mem_transport, out_stream.write) + smart_protocol.accept_bytes(b"readv\x01foo\n3\n3,3done\n") self.assertEqual(0, smart_protocol.next_read_size()) - self.assertEqual(self.response_marker - + b'success\nreadv\n3\ndefdone\n', - out_stream.getvalue()) - self.assertEqual(b'', smart_protocol.unused_data) - self.assertEqual(b'', smart_protocol.in_buffer) + self.assertEqual( + self.response_marker + b"success\nreadv\n3\ndefdone\n", + out_stream.getvalue(), + ) + self.assertEqual(b"", smart_protocol.unused_data) + self.assertEqual(b"", smart_protocol.in_buffer) def test_accept_excess_bytes_are_preserved(self): out_stream = BytesIO() smart_protocol = self.server_protocol_class(None, out_stream.write) - smart_protocol.accept_bytes(b'hello\nhello\n') - self.assertEqual(self.response_marker + b"success\nok\x012\n", - out_stream.getvalue()) + smart_protocol.accept_bytes(b"hello\nhello\n") + self.assertEqual( + self.response_marker + b"success\nok\x012\n", out_stream.getvalue() + ) self.assertEqual(b"hello\n", smart_protocol.unused_data) self.assertEqual(b"", smart_protocol.in_buffer) def test_accept_excess_bytes_after_body(self): # The excess bytes look like the start of another request. server_protocol = self.build_protocol_waiting_for_body() - server_protocol.accept_bytes( - b'7\nabcdefgdone\n' + self.response_marker) + server_protocol.accept_bytes(b"7\nabcdefgdone\n" + self.response_marker) self.assertTrue(self.end_received) - self.assertEqual(self.response_marker, - server_protocol.unused_data) + self.assertEqual(self.response_marker, server_protocol.unused_data) self.assertEqual(b"", server_protocol.in_buffer) - server_protocol.accept_bytes(b'Y') - self.assertEqual(self.response_marker + b"Y", - server_protocol.unused_data) + server_protocol.accept_bytes(b"Y") + self.assertEqual(self.response_marker + b"Y", server_protocol.unused_data) self.assertEqual(b"", server_protocol.in_buffer) def test_accept_excess_bytes_after_dispatch(self): out_stream = BytesIO() smart_protocol = self.server_protocol_class(None, out_stream.write) - smart_protocol.accept_bytes(b'hello\n') - self.assertEqual(self.response_marker + b"success\nok\x012\n", - out_stream.getvalue()) - smart_protocol.accept_bytes(self.request_marker + b'hel') - self.assertEqual(self.request_marker + b"hel", - smart_protocol.unused_data) - smart_protocol.accept_bytes(b'lo\n') - self.assertEqual(self.request_marker + b"hello\n", - smart_protocol.unused_data) + smart_protocol.accept_bytes(b"hello\n") + self.assertEqual( + self.response_marker + b"success\nok\x012\n", out_stream.getvalue() + ) + smart_protocol.accept_bytes(self.request_marker + b"hel") + self.assertEqual(self.request_marker + b"hel", smart_protocol.unused_data) + smart_protocol.accept_bytes(b"lo\n") + self.assertEqual(self.request_marker + b"hello\n", smart_protocol.unused_data) self.assertEqual(b"", smart_protocol.in_buffer) def test__send_response_sets_finished_reading(self): smart_protocol = self.server_protocol_class(None, lambda x: None) self.assertEqual(1, smart_protocol.next_read_size()) smart_protocol._send_response( - _mod_request.SuccessfulSmartServerResponse((b'x',))) + _mod_request.SuccessfulSmartServerResponse((b"x",)) + ) self.assertEqual(0, smart_protocol.next_read_size()) def test__send_response_errors_with_base_response(self): """Ensure that 
only the Successful/Failed subclasses are used.""" smart_protocol = self.server_protocol_class(None, lambda x: None) - self.assertRaises(AttributeError, smart_protocol._send_response, - _mod_request.SmartServerResponse((b'x',))) + self.assertRaises( + AttributeError, + smart_protocol._send_response, + _mod_request.SmartServerResponse((b"x",)), + ) def test_query_version(self): """query_version on a SmartClientProtocolTwo should return a number. @@ -2437,10 +2480,9 @@ def test_query_version(self): # accept_bytes(tuple_based_encoding_of_hello) and reads and parses the # response of tuple-encoded (ok, 1). Also, separately we should test # the error if the response is a non-understood version. - input = BytesIO(self.response_marker + b'success\nok\x012\n') + input = BytesIO(self.response_marker + b"success\nok\x012\n") output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") request = client_medium.get_request() smart_protocol = self.client_protocol_class(request) self.assertEqual(2, smart_protocol.query_version()) @@ -2450,15 +2492,17 @@ def test_client_call_empty_response(self): # when the parsed line is an empty line, and results in a tuple with # one element - an empty string. self.assertServerToClientEncoding( - self.response_marker + b'success\n\n', (b'', ), [(), (b'', )]) + self.response_marker + b"success\n\n", (b"",), [(), (b"",)] + ) def test_client_call_three_element_response(self): # protocol.call() can get back tuples of other lengths. A three element # tuple should be unpacked as three strings. self.assertServerToClientEncoding( - self.response_marker + b'success\na\x01b\x0134\n', - (b'a', b'b', b'34'), - [(b'a', b'b', b'34')]) + self.response_marker + b"success\na\x01b\x0134\n", + (b"a", b"b", b"34"), + [(b"a", b"b", b"34")], + ) def test_client_call_with_body_bytes_uploads(self): # protocol.call_with_body_bytes should length-prefix the bytes onto the @@ -2466,11 +2510,10 @@ def test_client_call_with_body_bytes_uploads(self): expected_bytes = self.request_marker + b"foo\n7\nabcdefgdone\n" input = BytesIO(b"\n") output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") request = client_medium.get_request() smart_protocol = self.client_protocol_class(request) - smart_protocol.call_with_body_bytes((b'foo', ), b"abcdefg") + smart_protocol.call_with_body_bytes((b"foo",), b"abcdefg") self.assertEqual(expected_bytes, output.getvalue()) def test_client_call_with_body_readv_array(self): @@ -2479,26 +2522,23 @@ def test_client_call_with_body_readv_array(self): expected_bytes = self.request_marker + b"foo\n7\n1,2\n5,6done\n" input = BytesIO(b"\n") output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") request = client_medium.get_request() smart_protocol = self.client_protocol_class(request) - smart_protocol.call_with_body_readv_array((b'foo', ), [(1, 2), (5, 6)]) + smart_protocol.call_with_body_readv_array((b"foo",), [(1, 2), (5, 6)]) self.assertEqual(expected_bytes, output.getvalue()) def test_client_read_body_bytes_all(self): # read_body_bytes should decode the body bytes from the wire into # a response. 
expected_bytes = b"1234567" - server_bytes = (self.response_marker - + b"success\nok\n7\n1234567done\n") + server_bytes = self.response_marker + b"success\nok\n7\n1234567done\n" input = BytesIO(server_bytes) output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") request = client_medium.get_request() smart_protocol = self.client_protocol_class(request) - smart_protocol.call(b'foo') + smart_protocol.call(b"foo") smart_protocol.read_response_tuple(True) self.assertEqual(expected_bytes, smart_protocol.read_body_bytes()) @@ -2512,18 +2552,14 @@ def test_client_read_body_bytes_incremental(self): server_bytes = self.response_marker + b"success\nok\n7\n1234567done\n" input = BytesIO(server_bytes) output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") request = client_medium.get_request() smart_protocol = self.client_protocol_class(request) - smart_protocol.call(b'foo') + smart_protocol.call(b"foo") smart_protocol.read_response_tuple(True) - self.assertEqual(expected_bytes[0:2], - smart_protocol.read_body_bytes(2)) - self.assertEqual(expected_bytes[2:4], - smart_protocol.read_body_bytes(2)) - self.assertEqual(expected_bytes[4:6], - smart_protocol.read_body_bytes(2)) + self.assertEqual(expected_bytes[0:2], smart_protocol.read_body_bytes(2)) + self.assertEqual(expected_bytes[2:4], smart_protocol.read_body_bytes(2)) + self.assertEqual(expected_bytes[4:6], smart_protocol.read_body_bytes(2)) self.assertEqual(expected_bytes[6:7], smart_protocol.read_body_bytes()) def test_client_cancel_read_body_does_not_eat_body_bytes(self): @@ -2532,37 +2568,29 @@ def test_client_cancel_read_body_does_not_eat_body_bytes(self): server_bytes = self.response_marker + b"success\nok\n7\n1234567done\n" input = BytesIO(server_bytes) output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") request = client_medium.get_request() smart_protocol = self.client_protocol_class(request) - smart_protocol.call(b'foo') + smart_protocol.call(b"foo") smart_protocol.read_response_tuple(True) smart_protocol.cancel_read_body() - self.assertEqual(len(self.response_marker + b'success\nok\n'), - input.tell()) - self.assertRaises( - errors.ReadingCompleted, smart_protocol.read_body_bytes) + self.assertEqual(len(self.response_marker + b"success\nok\n"), input.tell()) + self.assertRaises(errors.ReadingCompleted, smart_protocol.read_body_bytes) def test_client_read_body_bytes_interrupted_connection(self): - server_bytes = (self.response_marker - + b"success\nok\n999\nincomplete body") + server_bytes = self.response_marker + b"success\nok\n999\nincomplete body" input = BytesIO(server_bytes) output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") request = client_medium.get_request() smart_protocol = self.client_protocol_class(request) - smart_protocol.call(b'foo') + smart_protocol.call(b"foo") smart_protocol.read_response_tuple(True) - self.assertRaises( - ConnectionResetError, smart_protocol.read_body_bytes) + self.assertRaises(ConnectionResetError, smart_protocol.read_body_bytes) class TestSmartProtocolTwoSpecificsMixin: - - def assertBodyStreamSerialisation(self, 
expected_serialisation, - body_stream): + def assertBodyStreamSerialisation(self, expected_serialisation, body_stream): """Assert that body_stream is serialised as expected_serialisation.""" out_stream = BytesIO() protocol._send_stream(body_stream, out_stream.write) @@ -2581,15 +2609,19 @@ def assertBodyStreamRoundTrips(self, body_stream): def test_body_stream_serialisation_empty(self): """A body_stream with no bytes can be serialised.""" - self.assertBodyStreamSerialisation(b'chunked\nEND\n', []) + self.assertBodyStreamSerialisation(b"chunked\nEND\n", []) self.assertBodyStreamRoundTrips([]) def test_body_stream_serialisation(self): - stream = [b'chunk one', b'chunk two', b'chunk three'] + stream = [b"chunk one", b"chunk two", b"chunk three"] self.assertBodyStreamSerialisation( - b'chunked\n' + b'9\nchunk one' + b'9\nchunk two' + b'b\nchunk three' - + b'END\n', - stream) + b"chunked\n" + + b"9\nchunk one" + + b"9\nchunk two" + + b"b\nchunk three" + + b"END\n", + stream, + ) self.assertBodyStreamRoundTrips(stream) def test_body_stream_with_empty_element_serialisation(self): @@ -2597,104 +2629,115 @@ def test_body_stream_with_empty_element_serialisation(self): The empty string can be transmitted like any other string. """ - stream = [b'', b'chunk'] + stream = [b"", b"chunk"] self.assertBodyStreamSerialisation( - b'chunked\n' + b'0\n' + b'5\nchunk' + b'END\n', stream) + b"chunked\n" + b"0\n" + b"5\nchunk" + b"END\n", stream + ) self.assertBodyStreamRoundTrips(stream) def test_body_stream_error_serialistion(self): - stream = [b'first chunk', - _mod_request.FailedSmartServerResponse( - (b'FailureName', b'failure arg'))] + stream = [ + b"first chunk", + _mod_request.FailedSmartServerResponse((b"FailureName", b"failure arg")), + ] expected_bytes = ( - b'chunked\n' + b'b\nfirst chunk' - + b'ERR\n' + b'b\nFailureName' + b'b\nfailure arg' - + b'END\n') + b"chunked\n" + + b"b\nfirst chunk" + + b"ERR\n" + + b"b\nFailureName" + + b"b\nfailure arg" + + b"END\n" + ) self.assertBodyStreamSerialisation(expected_bytes, stream) self.assertBodyStreamRoundTrips(stream) def test__send_response_includes_failure_marker(self): r"""FailedSmartServerResponse have 'failed\n' after the version.""" out_stream = BytesIO() - smart_protocol = protocol.SmartServerRequestProtocolTwo( - None, out_stream.write) - smart_protocol._send_response( - _mod_request.FailedSmartServerResponse((b'x',))) - self.assertEqual(protocol.RESPONSE_VERSION_TWO + b'failed\nx\n', - out_stream.getvalue()) + smart_protocol = protocol.SmartServerRequestProtocolTwo(None, out_stream.write) + smart_protocol._send_response(_mod_request.FailedSmartServerResponse((b"x",))) + self.assertEqual( + protocol.RESPONSE_VERSION_TWO + b"failed\nx\n", out_stream.getvalue() + ) def test__send_response_includes_success_marker(self): r"""SuccessfulSmartServerResponse have 'success\n' after the version.""" out_stream = BytesIO() - smart_protocol = protocol.SmartServerRequestProtocolTwo( - None, out_stream.write) + smart_protocol = protocol.SmartServerRequestProtocolTwo(None, out_stream.write) smart_protocol._send_response( - _mod_request.SuccessfulSmartServerResponse((b'x',))) - self.assertEqual(protocol.RESPONSE_VERSION_TWO + b'success\nx\n', - out_stream.getvalue()) + _mod_request.SuccessfulSmartServerResponse((b"x",)) + ) + self.assertEqual( + protocol.RESPONSE_VERSION_TWO + b"success\nx\n", out_stream.getvalue() + ) def test__send_response_with_body_stream_sets_finished_reading(self): - smart_protocol = protocol.SmartServerRequestProtocolTwo( - None, lambda x: 
None) + smart_protocol = protocol.SmartServerRequestProtocolTwo(None, lambda x: None) self.assertEqual(1, smart_protocol.next_read_size()) smart_protocol._send_response( - _mod_request.SuccessfulSmartServerResponse((b'x',), body_stream=[])) + _mod_request.SuccessfulSmartServerResponse((b"x",), body_stream=[]) + ) self.assertEqual(0, smart_protocol.next_read_size()) def test_streamed_body_bytes(self): - body_header = b'chunked\n' + body_header = b"chunked\n" two_body_chunks = b"4\n1234" + b"3\n567" body_terminator = b"END\n" - server_bytes = (protocol.RESPONSE_VERSION_TWO - + b"success\nok\n" + body_header + two_body_chunks - + body_terminator) + server_bytes = ( + protocol.RESPONSE_VERSION_TWO + + b"success\nok\n" + + body_header + + two_body_chunks + + body_terminator + ) input = BytesIO(server_bytes) output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") request = client_medium.get_request() smart_protocol = protocol.SmartClientRequestProtocolTwo(request) - smart_protocol.call(b'foo') + smart_protocol.call(b"foo") smart_protocol.read_response_tuple(True) stream = smart_protocol.read_streamed_body() - self.assertEqual([b'1234', b'567'], list(stream)) + self.assertEqual([b"1234", b"567"], list(stream)) def test_read_streamed_body_error(self): """When a stream is interrupted by an error...""" - body_header = b'chunked\n' - a_body_chunk = b'4\naaaa' - err_signal = b'ERR\n' - err_chunks = b'a\nerror arg1' + b'4\narg2' - finish = b'END\n' + body_header = b"chunked\n" + a_body_chunk = b"4\naaaa" + err_signal = b"ERR\n" + err_chunks = b"a\nerror arg1" + b"4\narg2" + finish = b"END\n" body = body_header + a_body_chunk + err_signal + err_chunks + finish - server_bytes = (protocol.RESPONSE_VERSION_TWO - + b"success\nok\n" + body) + server_bytes = protocol.RESPONSE_VERSION_TWO + b"success\nok\n" + body input = BytesIO(server_bytes) output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") smart_request = client_medium.get_request() smart_protocol = protocol.SmartClientRequestProtocolTwo(smart_request) - smart_protocol.call(b'foo') + smart_protocol.call(b"foo") smart_protocol.read_response_tuple(True) expected_chunks = [ - b'aaaa', - _mod_request.FailedSmartServerResponse((b'error arg1', b'arg2'))] + b"aaaa", + _mod_request.FailedSmartServerResponse((b"error arg1", b"arg2")), + ] stream = smart_protocol.read_streamed_body() self.assertEqual(expected_chunks, list(stream)) def test_streamed_body_bytes_interrupted_connection(self): - body_header = b'chunked\n' + body_header = b"chunked\n" incomplete_body_chunk = b"9999\nincomplete chunk" - server_bytes = (protocol.RESPONSE_VERSION_TWO - + b"success\nok\n" + body_header + incomplete_body_chunk) + server_bytes = ( + protocol.RESPONSE_VERSION_TWO + + b"success\nok\n" + + body_header + + incomplete_body_chunk + ) input = BytesIO(server_bytes) output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") request = client_medium.get_request() smart_protocol = protocol.SmartClientRequestProtocolTwo(request) - smart_protocol.call(b'foo') + smart_protocol.call(b"foo") smart_protocol.read_response_tuple(True) stream = smart_protocol.read_streamed_body() self.assertRaises(ConnectionResetError, next, stream) @@ 
-2703,11 +2746,10 @@ def test_client_read_response_tuple_sets_response_status(self): server_bytes = protocol.RESPONSE_VERSION_TWO + b"success\nok\n" input = BytesIO(server_bytes) output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") request = client_medium.get_request() smart_protocol = protocol.SmartClientRequestProtocolTwo(request) - smart_protocol.call(b'foo') + smart_protocol.call(b"foo") smart_protocol.read_response_tuple(False) self.assertEqual(True, smart_protocol.response_status) @@ -2718,24 +2760,23 @@ def test_client_read_response_tuple_raises_UnknownSmartMethod(self): server_bytes = ( protocol.RESPONSE_VERSION_TWO + b"failed\n" - + b"error\x01Generic bzr smart protocol error: bad request 'foo'\n") + + b"error\x01Generic bzr smart protocol error: bad request 'foo'\n" + ) input = BytesIO(server_bytes) output = BytesIO() - client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'base') + client_medium = medium.SmartSimplePipesClientMedium(input, output, "base") request = client_medium.get_request() smart_protocol = protocol.SmartClientRequestProtocolTwo(request) - smart_protocol.call(b'foo') - self.assertRaises( - errors.UnknownSmartMethod, smart_protocol.read_response_tuple) + smart_protocol.call(b"foo") + self.assertRaises(errors.UnknownSmartMethod, smart_protocol.read_response_tuple) # The request has been finished. There is no body to read, and # attempts to read one will fail. - self.assertRaises( - errors.ReadingCompleted, smart_protocol.read_body_bytes) + self.assertRaises(errors.ReadingCompleted, smart_protocol.read_body_bytes) class TestSmartProtocolTwoSpecifics( - TestSmartProtocol, TestSmartProtocolTwoSpecificsMixin): + TestSmartProtocol, TestSmartProtocolTwoSpecificsMixin +): """Tests for aspects of smart protocol version two that are unique to version two. @@ -2747,7 +2788,8 @@ class TestSmartProtocolTwoSpecifics( class TestVersionOneFeaturesInProtocolThree( - TestSmartProtocol, CommonSmartProtocolTestMixin): + TestSmartProtocol, CommonSmartProtocolTestMixin +): """Tests for version one smart protocol features as implemented by version three. 
""" @@ -2768,7 +2810,7 @@ def setUp(self): def test_construct_version_three_server_protocol(self): smart_protocol = protocol.ProtocolThreeDecoder(None) - self.assertEqual(b'', smart_protocol.unused_data) + self.assertEqual(b"", smart_protocol.unused_data) self.assertEqual([], smart_protocol._in_buffer_list) self.assertEqual(0, smart_protocol._in_buffer_len) self.assertFalse(smart_protocol._has_dispatched) @@ -2778,7 +2820,6 @@ def test_construct_version_three_server_protocol(self): class LoggingMessageHandler: - def __init__(self): self.event_log = [] @@ -2786,22 +2827,22 @@ def _log(self, *args): self.event_log.append(args) def headers_received(self, headers): - self._log('headers', headers) + self._log("headers", headers) def protocol_error(self, exception): - self._log('protocol_error', exception) + self._log("protocol_error", exception) def byte_part_received(self, byte): - self._log('byte', byte) + self._log("byte", byte) def bytes_part_received(self, bytes): - self._log('bytes', bytes) + self._log("bytes", bytes) def structure_part_received(self, structure): - self._log('structure', structure) + self._log("structure", structure) def end_received(self): - self._log('end') + self._log("end") class TestProtocolThree(TestSmartProtocol): @@ -2816,33 +2857,33 @@ def test_trivial_request(self): message parts. """ BytesIO() - headers = b'\0\0\0\x02de' # length-prefixed, bencoded empty dict - end = b'e' + headers = b"\0\0\0\x02de" # length-prefixed, bencoded empty dict + end = b"e" request_bytes = headers + end smart_protocol = self.server_protocol_class(LoggingMessageHandler()) smart_protocol.accept_bytes(request_bytes) self.assertEqual(0, smart_protocol.next_read_size()) - self.assertEqual(b'', smart_protocol.unused_data) + self.assertEqual(b"", smart_protocol.unused_data) def test_repeated_excess(self): """Repeated calls to accept_bytes after the message end has been parsed accumlates the bytes in the unused_data attribute. 
""" BytesIO() - headers = b'\0\0\0\x02de' # length-prefixed, bencoded empty dict - end = b'e' + headers = b"\0\0\0\x02de" # length-prefixed, bencoded empty dict + end = b"e" request_bytes = headers + end smart_protocol = self.server_protocol_class(LoggingMessageHandler()) smart_protocol.accept_bytes(request_bytes) - self.assertEqual(b'', smart_protocol.unused_data) - smart_protocol.accept_bytes(b'aaa') - self.assertEqual(b'aaa', smart_protocol.unused_data) - smart_protocol.accept_bytes(b'bbb') - self.assertEqual(b'aaabbb', smart_protocol.unused_data) + self.assertEqual(b"", smart_protocol.unused_data) + smart_protocol.accept_bytes(b"aaa") + self.assertEqual(b"aaa", smart_protocol.unused_data) + smart_protocol.accept_bytes(b"bbb") + self.assertEqual(b"aaabbb", smart_protocol.unused_data) self.assertEqual(0, smart_protocol.next_read_size()) def make_protocol_expecting_message_part(self): - headers = b'\0\0\0\x02de' # length-prefixed, bencoded empty dict + headers = b"\0\0\0\x02de" # length-prefixed, bencoded empty dict message_handler = LoggingMessageHandler() smart_protocol = self.server_protocol_class(message_handler) smart_protocol.accept_bytes(headers) @@ -2853,106 +2894,107 @@ def make_protocol_expecting_message_part(self): def test_decode_one_byte(self): """The protocol can decode a 'one byte' message part.""" smart_protocol, event_log = self.make_protocol_expecting_message_part() - smart_protocol.accept_bytes(b'ox') - self.assertEqual([('byte', b'x')], event_log) + smart_protocol.accept_bytes(b"ox") + self.assertEqual([("byte", b"x")], event_log) def test_decode_bytes(self): """The protocol can decode a 'bytes' message part.""" smart_protocol, event_log = self.make_protocol_expecting_message_part() smart_protocol.accept_bytes( - b'b' # message part kind - b'\0\0\0\x07' # length prefix - b'payload' # payload - ) - self.assertEqual([('bytes', b'payload')], event_log) + b"b" # message part kind + b"\0\0\0\x07" # length prefix + b"payload" # payload + ) + self.assertEqual([("bytes", b"payload")], event_log) def test_decode_structure(self): """The protocol can decode a 'structure' message part.""" smart_protocol, event_log = self.make_protocol_expecting_message_part() smart_protocol.accept_bytes( - b's' # message part kind - b'\0\0\0\x07' # length prefix - b'l3:ARGe' # ['ARG'] - ) - self.assertEqual([('structure', (b'ARG',))], event_log) + b"s" # message part kind + b"\0\0\0\x07" # length prefix + b"l3:ARGe" # ['ARG'] + ) + self.assertEqual([("structure", (b"ARG",))], event_log) def test_decode_multiple_bytes(self): """The protocol can decode a multiple 'bytes' message parts.""" smart_protocol, event_log = self.make_protocol_expecting_message_part() smart_protocol.accept_bytes( - b'b' # message part kind - b'\0\0\0\x05' # length prefix - b'first' # payload - b'b' # message part kind - b'\0\0\0\x06' - b'second' - ) - self.assertEqual( - [('bytes', b'first'), ('bytes', b'second')], event_log) + b"b" # message part kind + b"\0\0\0\x05" # length prefix + b"first" # payload + b"b" # message part kind + b"\0\0\0\x06" + b"second" + ) + self.assertEqual([("bytes", b"first"), ("bytes", b"second")], event_log) class TestConventionalResponseHandlerBodyStream(tests.TestCase): - def make_response_handler(self, response_bytes): from ..smart.message import ConventionalResponseHandler + response_handler = ConventionalResponseHandler() protocol_decoder = protocol.ProtocolThreeDecoder(response_handler) # put decoder in desired state (waiting for message parts) - protocol_decoder.state_accept = 
protocol_decoder._state_accept_expecting_message_part + protocol_decoder.state_accept = ( + protocol_decoder._state_accept_expecting_message_part + ) output = BytesIO() client_medium = medium.SmartSimplePipesClientMedium( - BytesIO(response_bytes), output, 'base') + BytesIO(response_bytes), output, "base" + ) medium_request = client_medium.get_request() medium_request.finished_writing() - response_handler.setProtoAndMediumRequest( - protocol_decoder, medium_request) + response_handler.setProtoAndMediumRequest(protocol_decoder, medium_request) return response_handler def test_interrupted_by_error(self): response_handler = self.make_response_handler(interrupted_body_stream) stream = response_handler.read_streamed_body() - self.assertEqual(b'aaa', next(stream)) - self.assertEqual(b'bbb', next(stream)) + self.assertEqual(b"aaa", next(stream)) + self.assertEqual(b"bbb", next(stream)) exc = self.assertRaises(errors.ErrorFromSmartServer, next, stream) - self.assertEqual((b'error', b'Exception', b'Boom!'), exc.error_tuple) + self.assertEqual((b"error", b"Exception", b"Boom!"), exc.error_tuple) def test_interrupted_by_connection_lost(self): interrupted_body_stream = ( - b'oS' # successful response - b's\0\0\0\x02le' # empty args - b'b\0\0\xff\xffincomplete chunk') + b"oS" # successful response + b"s\0\0\0\x02le" # empty args + b"b\0\0\xff\xffincomplete chunk" + ) response_handler = self.make_response_handler(interrupted_body_stream) stream = response_handler.read_streamed_body() self.assertRaises(ConnectionResetError, next, stream) def test_read_body_bytes_interrupted_by_connection_lost(self): interrupted_body_stream = ( - b'oS' # successful response - b's\0\0\0\x02le' # empty args - b'b\0\0\xff\xffincomplete chunk') + b"oS" # successful response + b"s\0\0\0\x02le" # empty args + b"b\0\0\xff\xffincomplete chunk" + ) response_handler = self.make_response_handler(interrupted_body_stream) - self.assertRaises( - ConnectionResetError, response_handler.read_body_bytes) + self.assertRaises(ConnectionResetError, response_handler.read_body_bytes) def test_multiple_bytes_parts(self): multiple_bytes_parts = ( - b'oS' # successful response - b's\0\0\0\x02le' # empty args - b'b\0\0\0\x0bSome bytes\n' # some bytes - b'b\0\0\0\x0aMore bytes' # more bytes - b'e' # message end - ) + b"oS" # successful response + b"s\0\0\0\x02le" # empty args + b"b\0\0\0\x0bSome bytes\n" # some bytes + b"b\0\0\0\x0aMore bytes" # more bytes + b"e" # message end + ) response_handler = self.make_response_handler(multiple_bytes_parts) - self.assertEqual( - b'Some bytes\nMore bytes', response_handler.read_body_bytes()) + self.assertEqual(b"Some bytes\nMore bytes", response_handler.read_body_bytes()) response_handler = self.make_response_handler(multiple_bytes_parts) self.assertEqual( - [b'Some bytes\n', b'More bytes'], - list(response_handler.read_streamed_body())) + [b"Some bytes\n", b"More bytes"], + list(response_handler.read_streamed_body()), + ) class FakeResponder: - response_sent = False def send_error(self, exc): @@ -2970,15 +3012,18 @@ def make_request_handler(self, request_bytes): doubles for the request_handler and the responder. 
""" from ..smart.message import ConventionalRequestHandler + request_handler = InstrumentedRequestHandler() request_handler.response = _mod_request.SuccessfulSmartServerResponse( - (b'arg', b'arg')) + (b"arg", b"arg") + ) responder = FakeResponder() - message_handler = ConventionalRequestHandler( - request_handler, responder) + message_handler = ConventionalRequestHandler(request_handler, responder) protocol_decoder = protocol.ProtocolThreeDecoder(message_handler) # put decoder in desired state (waiting for message parts) - protocol_decoder.state_accept = protocol_decoder._state_accept_expecting_message_part + protocol_decoder.state_accept = ( + protocol_decoder._state_accept_expecting_message_part + ) protocol_decoder.accept_bytes(request_bytes) return request_handler @@ -2987,31 +3032,33 @@ def test_multiple_bytes_parts(self): accept_body method. """ multiple_bytes_parts = ( - b's\0\0\0\x07l3:fooe' # args - b'b\0\0\0\x0bSome bytes\n' # some bytes - b'b\0\0\0\x0aMore bytes' # more bytes - b'e' # message end - ) + b"s\0\0\0\x07l3:fooe" # args + b"b\0\0\0\x0bSome bytes\n" # some bytes + b"b\0\0\0\x0aMore bytes" # more bytes + b"e" # message end + ) request_handler = self.make_request_handler(multiple_bytes_parts) accept_body_calls = [ - call_info[1] for call_info in request_handler.calls - if call_info[0] == 'accept_body'] - self.assertEqual( - [b'Some bytes\n', b'More bytes'], accept_body_calls) + call_info[1] + for call_info in request_handler.calls + if call_info[0] == "accept_body" + ] + self.assertEqual([b"Some bytes\n", b"More bytes"], accept_body_calls) def test_error_flag_after_body(self): body_then_error = ( - b's\0\0\0\x07l3:fooe' # request args - b'b\0\0\0\x0bSome bytes\n' # some bytes - b'b\0\0\0\x0aMore bytes' # more bytes - b'oE' # error flag - b's\0\0\0\x07l3:bare' # error args - b'e' # message end - ) + b"s\0\0\0\x07l3:fooe" # request args + b"b\0\0\0\x0bSome bytes\n" # some bytes + b"b\0\0\0\x0aMore bytes" # more bytes + b"oE" # error flag + b"s\0\0\0\x07l3:bare" # error args + b"e" # message end + ) request_handler = self.make_request_handler(body_then_error) self.assertEqual( - [('post_body_error_received', (b'bar',)), ('end_received',)], - request_handler.calls[-2:]) + [("post_body_error_received", (b"bar",)), ("end_received",)], + request_handler.calls[-2:], + ) class TestMessageHandlerErrors(tests.TestCase): @@ -3034,25 +3081,27 @@ def test_non_conventional_request(self): # fact it has that part twice, to trigger multiple errors. invalid_request = ( protocol.MESSAGE_VERSION_THREE # protocol version marker - + b'\0\0\0\x02de' # empty headers - + b'oX' + # a single byte part: 'X'. ConventionalRequestHandler will + + b"\0\0\0\x02de" # empty headers + + b"oX" # a single byte part: 'X'. ConventionalRequestHandler will + + # error at this part. - b'oX' + # and again. - b'e' # end of message - ) + b"oX" # and again. + + b"e" # end of message + ) to_server = BytesIO(invalid_request) from_server = BytesIO() - transport = memory.MemoryTransport('memory:///') + transport = memory.MemoryTransport("memory:///") server = medium.SmartServerPipeStreamMedium( - to_server, from_server, transport, timeout=4.0) + to_server, from_server, transport, timeout=4.0 + ) proto = server._build_protocol() server._serve_one_request(proto) # All the bytes have been read from the medium... - self.assertEqual(b'', to_server.read()) + self.assertEqual(b"", to_server.read()) # ...and the protocol decoder has consumed all the bytes, and has # finished reading. 
- self.assertEqual(b'', proto.unused_data) + self.assertEqual(b"", proto.unused_data) self.assertEqual(0, proto.next_read_size()) @@ -3064,28 +3113,27 @@ def __init__(self): self.finished_reading = False def no_body_received(self): - self.calls.append(('no_body_received',)) + self.calls.append(("no_body_received",)) def end_received(self): - self.calls.append(('end_received',)) + self.calls.append(("end_received",)) self.finished_reading = True def args_received(self, args): - self.calls.append(('args_received', args)) + self.calls.append(("args_received", args)) def accept_body(self, bytes): - self.calls.append(('accept_body', bytes)) + self.calls.append(("accept_body", bytes)) def end_of_body(self): - self.calls.append(('end_of_body',)) + self.calls.append(("end_of_body",)) self.finished_reading = True def post_body_error_received(self, error_args): - self.calls.append(('post_body_error_received', error_args)) + self.calls.append(("post_body_error_received", error_args)) class StubRequest: - def finished_reading(self): pass @@ -3110,21 +3158,22 @@ def test_trivial_response_decoding(self): """Smoke test for the simplest possible v3 response: empty headers, status byte, empty args, no body. """ - headers = b'\0\0\0\x02de' # length-prefixed, bencoded empty dict - response_status = b'oS' # success - args = b's\0\0\0\x02le' # length-prefixed, bencoded empty list - end = b'e' # end marker + headers = b"\0\0\0\x02de" # length-prefixed, bencoded empty dict + response_status = b"oS" # success + args = b"s\0\0\0\x02le" # length-prefixed, bencoded empty list + end = b"e" # end marker message_bytes = headers + response_status + args + end decoder, response_handler = self.make_logging_response_decoder() decoder.accept_bytes(message_bytes) # The protocol decoder has finished, and consumed all bytes self.assertEqual(0, decoder.next_read_size()) - self.assertEqual(b'', decoder.unused_data) + self.assertEqual(b"", decoder.unused_data) # The message handler has been invoked with all the parts of the # trivial response: empty headers, status byte, no args, end. self.assertEqual( - [('headers', {}), ('byte', b'S'), ('structure', ()), ('end',)], - response_handler.event_log) + [("headers", {}), ("byte", b"S"), ("structure", ()), ("end",)], + response_handler.event_log, + ) def test_incomplete_message(self): """A decoder will keep signalling that it needs more bytes via @@ -3132,11 +3181,11 @@ def test_incomplete_message(self): which state it is in. """ # Define a simple response that uses all possible message parts. - headers = b'\0\0\0\x02de' # length-prefixed, bencoded empty dict - response_status = b'oS' # success - args = b's\0\0\0\x02le' # length-prefixed, bencoded empty list - body = b'b\0\0\0\x04BODY' # a body: 'BODY' - end = b'e' # end marker + headers = b"\0\0\0\x02de" # length-prefixed, bencoded empty dict + response_status = b"oS" # success + args = b"s\0\0\0\x02le" # length-prefixed, bencoded empty list + body = b"b\0\0\0\x04BODY" # a body: 'BODY' + end = b"e" # end marker simple_response = headers + response_status + args + body + end # Feed the request to the decoder one byte at a time. decoder, response_handler = self.make_logging_response_decoder() @@ -3150,34 +3199,35 @@ def test_read_response_tuple_raises_UnknownSmartMethod(self): """read_response_tuple raises UnknownSmartMethod if the server replied with 'UnknownMethod'. 
""" - headers = b'\0\0\0\x02de' # length-prefixed, bencoded empty dict - response_status = b'oE' # error flag + headers = b"\0\0\0\x02de" # length-prefixed, bencoded empty dict + response_status = b"oE" # error flag # args: (b'UnknownMethod', 'method-name') - args = b's\0\0\0\x20l13:UnknownMethod11:method-namee' - end = b'e' # end marker + args = b"s\0\0\0\x20l13:UnknownMethod11:method-namee" + end = b"e" # end marker message_bytes = headers + response_status + args + end decoder, response_handler = self.make_conventional_response_decoder() decoder.accept_bytes(message_bytes) error = self.assertRaises( - errors.UnknownSmartMethod, response_handler.read_response_tuple) - self.assertEqual(b'method-name', error.verb) + errors.UnknownSmartMethod, response_handler.read_response_tuple + ) + self.assertEqual(b"method-name", error.verb) def test_read_response_tuple_error(self): """If the response has an error, it is raised as an exception.""" - headers = b'\0\0\0\x02de' # length-prefixed, bencoded empty dict - response_status = b'oE' # error - args = b's\0\0\0\x1al9:first arg10:second arge' # two args - end = b'e' # end marker + headers = b"\0\0\0\x02de" # length-prefixed, bencoded empty dict + response_status = b"oE" # error + args = b"s\0\0\0\x1al9:first arg10:second arge" # two args + end = b"e" # end marker message_bytes = headers + response_status + args + end decoder, response_handler = self.make_conventional_response_decoder() decoder.accept_bytes(message_bytes) error = self.assertRaises( - errors.ErrorFromSmartServer, response_handler.read_response_tuple) - self.assertEqual((b'first arg', b'second arg'), error.error_tuple) + errors.ErrorFromSmartServer, response_handler.read_response_tuple + ) + self.assertEqual((b"first arg", b"second arg"), error.error_tuple) class TestClientEncodingProtocolThree(TestSmartProtocol): - request_encoder = protocol.ProtocolThreeRequester response_decoder = protocol.ProtocolThreeDecoder server_protocol_class = protocol.ProtocolThreeDecoder # type: ignore @@ -3194,14 +3244,15 @@ def test_call_smoke_test(self): correct bytes for that invocation. """ requester, output = self.make_client_encoder_and_output() - requester.set_headers({b'header name': b'header value'}) - requester.call(b'one arg') + requester.set_headers({b"header name": b"header value"}) + requester.call(b"one arg") self.assertEqual( - b'bzr message 3 (bzr 1.6)\n' # protocol version - b'\x00\x00\x00\x1fd11:header name12:header valuee' # headers - b's\x00\x00\x00\x0bl7:one arge' # args - b'e', # end - output.getvalue()) + b"bzr message 3 (bzr 1.6)\n" # protocol version + b"\x00\x00\x00\x1fd11:header name12:header valuee" # headers + b"s\x00\x00\x00\x0bl7:one arge" # args + b"e", # end + output.getvalue(), + ) def test_call_with_body_bytes_smoke_test(self): """A smoke test for ProtocolThreeRequester.call_with_body_bytes. @@ -3210,32 +3261,31 @@ def test_call_with_body_bytes_smoke_test(self): call_with_body_bytes emits the correct bytes for that invocation. 
""" requester, output = self.make_client_encoder_and_output() - requester.set_headers({b'header name': b'header value'}) - requester.call_with_body_bytes((b'one arg',), b'body bytes') + requester.set_headers({b"header name": b"header value"}) + requester.call_with_body_bytes((b"one arg",), b"body bytes") self.assertEqual( - b'bzr message 3 (bzr 1.6)\n' # protocol version - b'\x00\x00\x00\x1fd11:header name12:header valuee' # headers - b's\x00\x00\x00\x0bl7:one arge' # args - b'b' # there is a prefixed body - b'\x00\x00\x00\nbody bytes' # the prefixed body - b'e', # end - output.getvalue()) + b"bzr message 3 (bzr 1.6)\n" # protocol version + b"\x00\x00\x00\x1fd11:header name12:header valuee" # headers + b"s\x00\x00\x00\x0bl7:one arge" # args + b"b" # there is a prefixed body + b"\x00\x00\x00\nbody bytes" # the prefixed body + b"e", # end + output.getvalue(), + ) def test_call_writes_just_once(self): """A bodyless request is written to the medium all at once.""" medium_request = StubMediumRequest() encoder = protocol.ProtocolThreeRequester(medium_request) - encoder.call(b'arg1', b'arg2', b'arg3') - self.assertEqual( - ['accept_bytes', 'finished_writing'], medium_request.calls) + encoder.call(b"arg1", b"arg2", b"arg3") + self.assertEqual(["accept_bytes", "finished_writing"], medium_request.calls) def test_call_with_body_bytes_writes_just_once(self): """A request with body bytes is written to the medium all at once.""" medium_request = StubMediumRequest() encoder = protocol.ProtocolThreeRequester(medium_request) - encoder.call_with_body_bytes((b'arg', b'arg'), b'body bytes') - self.assertEqual( - ['accept_bytes', 'finished_writing'], medium_request.calls) + encoder.call_with_body_bytes((b"arg", b"arg"), b"body bytes") + self.assertEqual(["accept_bytes", "finished_writing"], medium_request.calls) def test_call_with_body_stream_smoke_test(self): """A smoke test for ProtocolThreeRequester.call_with_body_stream. @@ -3244,31 +3294,33 @@ def test_call_with_body_stream_smoke_test(self): call_with_body_stream emits the correct bytes for that invocation. 
""" requester, output = self.make_client_encoder_and_output() - requester.set_headers({b'header name': b'header value'}) - stream = [b'chunk 1', b'chunk two'] - requester.call_with_body_stream((b'one arg',), stream) + requester.set_headers({b"header name": b"header value"}) + stream = [b"chunk 1", b"chunk two"] + requester.call_with_body_stream((b"one arg",), stream) self.assertEqual( - b'bzr message 3 (bzr 1.6)\n' # protocol version - b'\x00\x00\x00\x1fd11:header name12:header valuee' # headers - b's\x00\x00\x00\x0bl7:one arge' # args - b'b\x00\x00\x00\x07chunk 1' # a prefixed body chunk - b'b\x00\x00\x00\x09chunk two' # a prefixed body chunk - b'e', # end - output.getvalue()) + b"bzr message 3 (bzr 1.6)\n" # protocol version + b"\x00\x00\x00\x1fd11:header name12:header valuee" # headers + b"s\x00\x00\x00\x0bl7:one arge" # args + b"b\x00\x00\x00\x07chunk 1" # a prefixed body chunk + b"b\x00\x00\x00\x09chunk two" # a prefixed body chunk + b"e", # end + output.getvalue(), + ) def test_call_with_body_stream_empty_stream(self): """call_with_body_stream with an empty stream.""" requester, output = self.make_client_encoder_and_output() requester.set_headers({}) stream = [] - requester.call_with_body_stream((b'one arg',), stream) + requester.call_with_body_stream((b"one arg",), stream) self.assertEqual( - b'bzr message 3 (bzr 1.6)\n' # protocol version - b'\x00\x00\x00\x02de' # headers - b's\x00\x00\x00\x0bl7:one arge' # args + b"bzr message 3 (bzr 1.6)\n" # protocol version + b"\x00\x00\x00\x02de" # headers + b"s\x00\x00\x00\x0bl7:one arge" # args # no body chunks - b'e', # end - output.getvalue()) + b"e", # end + output.getvalue(), + ) def test_call_with_body_stream_error(self): """call_with_body_stream will abort the streamed body with an @@ -3280,21 +3332,27 @@ def test_call_with_body_stream_error(self): requester.set_headers({}) def stream_that_fails(): - yield b'aaa' - yield b'bbb' - raise Exception('Boom!') - self.assertRaises(Exception, requester.call_with_body_stream, - (b'one arg',), stream_that_fails()) + yield b"aaa" + yield b"bbb" + raise Exception("Boom!") + + self.assertRaises( + Exception, + requester.call_with_body_stream, + (b"one arg",), + stream_that_fails(), + ) self.assertEqual( - b'bzr message 3 (bzr 1.6)\n' # protocol version - b'\x00\x00\x00\x02de' # headers - b's\x00\x00\x00\x0bl7:one arge' # args - b'b\x00\x00\x00\x03aaa' # body - b'b\x00\x00\x00\x03bbb' # more body - b'oE' # error flag - b's\x00\x00\x00\x09l5:errore' # error args: ('error',) - b'e', # end - output.getvalue()) + b"bzr message 3 (bzr 1.6)\n" # protocol version + b"\x00\x00\x00\x02de" # headers + b"s\x00\x00\x00\x0bl7:one arge" # args + b"b\x00\x00\x00\x03aaa" # body + b"b\x00\x00\x00\x03bbb" # more body + b"oE" # error flag + b"s\x00\x00\x00\x09l5:errore" # error args: ('error',) + b"e", # end + output.getvalue(), + ) def test_records_start_of_body_stream(self): requester, output = self.make_client_encoder_and_output() @@ -3304,7 +3362,8 @@ def test_records_start_of_body_stream(self): def stream_checker(): self.assertTrue(requester.body_stream_started) in_stream[0] = True - yield b'content' + yield b"content" + flush_called = [] orig_flush = requester.flush @@ -3315,14 +3374,17 @@ def tracked_flush(): else: self.assertFalse(requester.body_stream_started) return orig_flush() + requester.flush = tracked_flush - requester.call_with_body_stream((b'one arg',), stream_checker()) + requester.call_with_body_stream((b"one arg",), stream_checker()) self.assertEqual( - b'bzr message 3 (bzr 1.6)\n' # protocol 
version - b'\x00\x00\x00\x02de' # headers - b's\x00\x00\x00\x0bl7:one arge' # args - b'b\x00\x00\x00\x07content' # body - b'e', output.getvalue()) + b"bzr message 3 (bzr 1.6)\n" # protocol version + b"\x00\x00\x00\x02de" # headers + b"s\x00\x00\x00\x0bl7:one arge" # args + b"b\x00\x00\x00\x07content" # body + b"e", + output.getvalue(), + ) self.assertEqual([False, True, True], flush_called) @@ -3333,29 +3395,28 @@ class StubMediumRequest: def __init__(self): self.calls = [] - self._medium = 'dummy medium' + self._medium = "dummy medium" def accept_bytes(self, bytes): - self.calls.append('accept_bytes') + self.calls.append("accept_bytes") def finished_writing(self): - self.calls.append('finished_writing') + self.calls.append("finished_writing") interrupted_body_stream = ( - b'oS' # status flag (success) - b's\x00\x00\x00\x08l4:argse' # args struct ('args,') - b'b\x00\x00\x00\x03aaa' # body part ('aaa') - b'b\x00\x00\x00\x03bbb' # body part ('bbb') - b'oE' # status flag (error) + b"oS" # status flag (success) + b"s\x00\x00\x00\x08l4:argse" # args struct ('args,') + b"b\x00\x00\x00\x03aaa" # body part ('aaa') + b"b\x00\x00\x00\x03bbb" # body part ('bbb') + b"oE" # status flag (error) # err struct ('error', 'Exception', 'Boom!') - b's\x00\x00\x00\x1bl5:error9:Exception5:Boom!e' - b'e' # EOM - ) + b"s\x00\x00\x00\x1bl5:error9:Exception5:Boom!e" + b"e" # EOM +) class TestResponseEncodingProtocolThree(tests.TestCase): - def make_response_encoder(self): out_stream = BytesIO() response_encoder = protocol.ProtocolThreeResponder(out_stream.write) @@ -3363,33 +3424,36 @@ def make_response_encoder(self): def test_send_error_unknown_method(self): encoder, out_stream = self.make_response_encoder() - encoder.send_error(errors.UnknownSmartMethod('method name')) + encoder.send_error(errors.UnknownSmartMethod("method name")) # Use assertEndsWith so that we don't compare the header, which varies # by breezy.__version__. 
self.assertEndsWith( out_stream.getvalue(), # error status - b'oE' + + b"oE" + # tuple: 'UnknownMethod', 'method name' - b's\x00\x00\x00\x20l13:UnknownMethod11:method namee' + b"s\x00\x00\x00\x20l13:UnknownMethod11:method namee" # end of message - b'e') + b"e", + ) def test_send_broken_body_stream(self): encoder, out_stream = self.make_response_encoder() encoder._headers = {} def stream_that_fails(): - yield b'aaa' - yield b'bbb' - raise Exception('Boom!') + yield b"aaa" + yield b"bbb" + raise Exception("Boom!") + response = _mod_request.SuccessfulSmartServerResponse( - (b'args',), body_stream=stream_that_fails()) + (b"args",), body_stream=stream_that_fails() + ) encoder.send_response(response) expected_response = ( - b'bzr message 3 (bzr 1.6)\n' # protocol marker - b'\x00\x00\x00\x02de' # headers dict (empty) - + interrupted_body_stream) + b"bzr message 3 (bzr 1.6)\n" # protocol marker + b"\x00\x00\x00\x02de" + interrupted_body_stream # headers dict (empty) + ) self.assertEqual(expected_response, out_stream.getvalue()) @@ -3408,17 +3472,19 @@ def setUp(self): def assertWriteCount(self, expected_count): # self.writes can be quite large; don't show the whole thing self.assertEqual( - expected_count, len(self.writes), - "Too many writes: %d, expected %d" % (len(self.writes), expected_count)) + expected_count, + len(self.writes), + "Too many writes: %d, expected %d" % (len(self.writes), expected_count), + ) def test_send_error_writes_just_once(self): """An error response is written to the medium all at once.""" - self.responder.send_error(Exception('An exception string.')) + self.responder.send_error(Exception("An exception string.")) self.assertWriteCount(1) def test_send_response_writes_just_once(self): """A normal response with no body is written to the medium all at once.""" - response = _mod_request.SuccessfulSmartServerResponse((b'arg', b'arg')) + response = _mod_request.SuccessfulSmartServerResponse((b"arg", b"arg")) self.responder.send_response(response) self.assertWriteCount(1) @@ -3427,7 +3493,8 @@ def test_send_response_with_body_writes_just_once(self): all at once. """ response = _mod_request.SuccessfulSmartServerResponse( - (b'arg', b'arg'), body=b'body bytes') + (b"arg", b"arg"), body=b"body bytes" + ) self.responder.send_response(response) self.assertWriteCount(1) @@ -3435,7 +3502,8 @@ def test_send_response_with_body_stream_buffers_writes(self): """A normal response with a stream body writes to the medium once.""" # Construct a response with stream with 2 chunks in it. 
response = _mod_request.SuccessfulSmartServerResponse( - (b'arg', b'arg'), body_stream=[b'chunk1', b'chunk2']) + (b"arg", b"arg"), body_stream=[b"chunk1", b"chunk2"] + ) self.responder.send_response(response) # Per the discussion in bug 590638 we flush once after the header and # then once after each chunk @@ -3462,23 +3530,24 @@ def assertCallDoesNotBreakMedium(self, method, args, body): input = BytesIO(b"\n") output = BytesIO() client_medium = medium.SmartSimplePipesClientMedium( - input, output, 'ignored base') + input, output, "ignored base" + ) smart_client = client._SmartClient(client_medium) - self.assertRaises(TypeError, - smart_client.call_with_body_bytes, method, args, body) + self.assertRaises( + TypeError, smart_client.call_with_body_bytes, method, args, body + ) self.assertEqual(b"", output.getvalue()) self.assertEqual(None, client_medium._current_request) def test_call_with_body_bytes_unicode_method(self): - self.assertCallDoesNotBreakMedium('method', (b'args',), b'body') + self.assertCallDoesNotBreakMedium("method", (b"args",), b"body") def test_call_with_body_bytes_unicode_args(self): - self.assertCallDoesNotBreakMedium(b'method', ('args',), b'body') - self.assertCallDoesNotBreakMedium( - b'method', (b'arg1', 'arg2'), b'body') + self.assertCallDoesNotBreakMedium(b"method", ("args",), b"body") + self.assertCallDoesNotBreakMedium(b"method", (b"arg1", "arg2"), b"body") def test_call_with_body_bytes_unicode_body(self): - self.assertCallDoesNotBreakMedium(b'method', (b'args',), 'body') + self.assertCallDoesNotBreakMedium(b"method", (b"args",), "body") class MockMedium(medium.SmartClientMedium): @@ -3498,12 +3567,11 @@ class MockMedium(medium.SmartClientMedium): """ def __init__(self): - super().__init__('dummy base') + super().__init__("dummy base") self._mock_request = _MockMediumRequest(self) self._expected_events = [] - def expect_request(self, request_bytes, response_bytes, - allow_partial_read=False): + def expect_request(self, request_bytes, response_bytes, allow_partial_read=False): """Expect 'request_bytes' to be sent, and reply with 'response_bytes'. No assumption is made about how many times accept_bytes should be @@ -3529,16 +3597,15 @@ def expect_request(self, request_bytes, response_bytes, is expected to disconnect without needing to read the complete response. Default is False. 
""" - self._expected_events.append(('send request', request_bytes)) + self._expected_events.append(("send request", request_bytes)) if allow_partial_read: - self._expected_events.append( - ('read response (partial)', response_bytes)) + self._expected_events.append(("read response (partial)", response_bytes)) else: - self._expected_events.append(('read response', response_bytes)) + self._expected_events.append(("read response", response_bytes)) def expect_disconnect(self): """Expect the client to call ``medium.disconnect()``.""" - self._expected_events.append('disconnect') + self._expected_events.append("disconnect") def _assertEvent(self, observed_event): """Raise AssertionError unless observed_event matches the next expected @@ -3551,17 +3618,20 @@ def _assertEvent(self, observed_event): expected_event = self._expected_events.pop(0) except IndexError as e: raise AssertionError( - f'Mock medium observed event {observed_event!r}, but no more events expected') from e - if expected_event[0] == 'read response (partial)': - if observed_event[0] != 'read response': + f"Mock medium observed event {observed_event!r}, but no more events expected" + ) from e + if expected_event[0] == "read response (partial)": + if observed_event[0] != "read response": raise AssertionError( - f'Mock medium observed event {observed_event!r}, but expected event {expected_event!r}') + f"Mock medium observed event {observed_event!r}, but expected event {expected_event!r}" + ) elif observed_event != expected_event: raise AssertionError( - f'Mock medium observed event {observed_event!r}, but expected event {expected_event!r}') + f"Mock medium observed event {observed_event!r}, but expected event {expected_event!r}" + ) if self._expected_events: next_event = self._expected_events[0] - if next_event[0].startswith('read response'): + if next_event[0].startswith("read response"): self._mock_request._response = next_event[1] def get_request(self): @@ -3569,10 +3639,9 @@ def get_request(self): def disconnect(self): if self._mock_request._read_bytes: - self._assertEvent( - ('read response', self._mock_request._read_bytes)) - self._mock_request._read_bytes = b'' - self._assertEvent('disconnect') + self._assertEvent(("read response", self._mock_request._read_bytes)) + self._mock_request._read_bytes = b"" + self._assertEvent("disconnect") class _MockMediumRequest: @@ -3580,20 +3649,20 @@ class _MockMediumRequest: def __init__(self, mock_medium): self._medium = mock_medium - self._written_bytes = b'' - self._read_bytes = b'' + self._written_bytes = b"" + self._read_bytes = b"" self._response = None def accept_bytes(self, bytes): self._written_bytes += bytes def finished_writing(self): - self._medium._assertEvent(('send request', self._written_bytes)) - self._written_bytes = b'' + self._medium._assertEvent(("send request", self._written_bytes)) + self._written_bytes = b"" def finished_reading(self): - self._medium._assertEvent(('read response', self._read_bytes)) - self._read_bytes = b'' + self._medium._assertEvent(("read response", self._read_bytes)) + self._read_bytes = b"" def read_bytes(self, size): resp = self._response @@ -3605,10 +3674,10 @@ def read_bytes(self, size): def read_line(self): resp = self._response try: - line, resp = resp.split(b'\n', 1) - line += b'\n' + line, resp = resp.split(b"\n", 1) + line += b"\n" except ValueError: - line, resp = resp, b'' + line, resp = resp, b"" self._response = resp self._read_bytes += line return line @@ -3625,15 +3694,15 @@ def test_version_three_server(self): """With a protocol 3 
server, only one request is needed.""" medium = MockMedium() smart_client = client._SmartClient(medium, headers={}) - message_start = protocol.MESSAGE_VERSION_THREE + b'\x00\x00\x00\x02de' + message_start = protocol.MESSAGE_VERSION_THREE + b"\x00\x00\x00\x02de" medium.expect_request( - message_start - + b's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee', - message_start + b's\0\0\0\x13l14:response valueee') - result = smart_client.call(b'method-name', b'arg 1', b'arg 2') + message_start + b"s\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee", + message_start + b"s\0\0\0\x13l14:response valueee", + ) + result = smart_client.call(b"method-name", b"arg 1", b"arg 2") # The call succeeded without raising any exceptions from the mock # medium, and the smart_client returns the response from the server. - self.assertEqual((b'response value',), result) + self.assertEqual((b"response value",), result) self.assertEqual([], medium._expected_events) # Also, the v3 works then the server should be assumed to support RPCs # introduced in 1.6. @@ -3651,29 +3720,32 @@ def test_version_two_server(self): # First the client should send a v3 request, but the server will reply # with a v2 error. medium.expect_request( - b'bzr message 3 (bzr 1.6)\n\x00\x00\x00\x02de' - + b's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee', - b'bzr response 2\nfailed\n\n') + b"bzr message 3 (bzr 1.6)\n\x00\x00\x00\x02de" + + b"s\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee", + b"bzr response 2\nfailed\n\n", + ) # So then the client should disconnect to reset the connection, because # the client needs to assume the server cannot read any further # requests off the original connection. medium.expect_disconnect() # The client should then retry the original request in v2 medium.expect_request( - b'bzr request 2\nmethod-name\x01arg 1\x01arg 2\n', - b'bzr response 2\nsuccess\nresponse value\n') - result = smart_client.call(b'method-name', b'arg 1', b'arg 2') + b"bzr request 2\nmethod-name\x01arg 1\x01arg 2\n", + b"bzr response 2\nsuccess\nresponse value\n", + ) + result = smart_client.call(b"method-name", b"arg 1", b"arg 2") # The smart_client object will return the result of the successful # query. - self.assertEqual((b'response value',), result) + self.assertEqual((b"response value",), result) # Now try another request, and this time the client will just use # protocol 2. (i.e. the autodetection won't be repeated) medium.expect_request( - b'bzr request 2\nanother-method\n', - b'bzr response 2\nsuccess\nanother response\n') - result = smart_client.call(b'another-method') - self.assertEqual((b'another response',), result) + b"bzr request 2\nanother-method\n", + b"bzr response 2\nsuccess\nanother response\n", + ) + result = smart_client.call(b"another-method") + self.assertEqual((b"another response",), result) self.assertEqual([], medium._expected_events) # Also, because v3 is not supported, the client medium should assume @@ -3686,20 +3758,25 @@ def test_unknown_version(self): """ medium = MockMedium() smart_client = client._SmartClient(medium, headers={}) - unknown_protocol_bytes = b'Unknown protocol!' + unknown_protocol_bytes = b"Unknown protocol!" # The client will try v3 and v2 before eventually giving up. 
medium.expect_request( - b'bzr message 3 (bzr 1.6)\n\x00\x00\x00\x02de' - + b's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee', - unknown_protocol_bytes) + b"bzr message 3 (bzr 1.6)\n\x00\x00\x00\x02de" + + b"s\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee", + unknown_protocol_bytes, + ) medium.expect_disconnect() medium.expect_request( - b'bzr request 2\nmethod-name\x01arg 1\x01arg 2\n', - unknown_protocol_bytes) + b"bzr request 2\nmethod-name\x01arg 1\x01arg 2\n", unknown_protocol_bytes + ) medium.expect_disconnect() self.assertRaises( errors.SmartProtocolError, - smart_client.call, b'method-name', b'arg 1', b'arg 2') + smart_client.call, + b"method-name", + b"arg 1", + b"arg 2", + ) self.assertEqual([], medium._expected_events) def test_first_response_is_error(self): @@ -3711,183 +3788,232 @@ def test_first_response_is_error(self): """ medium = MockMedium() smart_client = client._SmartClient(medium, headers={}) - message_start = protocol.MESSAGE_VERSION_THREE + b'\x00\x00\x00\x02de' + message_start = protocol.MESSAGE_VERSION_THREE + b"\x00\x00\x00\x02de" # Issue a request that gets an error reply in a non-default protocol # version. medium.expect_request( - message_start - + b's\x00\x00\x00\x10l11:method-nameee', - b'bzr response 2\nfailed\n\n') + message_start + b"s\x00\x00\x00\x10l11:method-nameee", + b"bzr response 2\nfailed\n\n", + ) medium.expect_disconnect() medium.expect_request( - b'bzr request 2\nmethod-name\n', - b'bzr response 2\nfailed\nFooBarError\n') + b"bzr request 2\nmethod-name\n", b"bzr response 2\nfailed\nFooBarError\n" + ) err = self.assertRaises( - errors.ErrorFromSmartServer, - smart_client.call, b'method-name') - self.assertEqual((b'FooBarError',), err.error_tuple) + errors.ErrorFromSmartServer, smart_client.call, b"method-name" + ) + self.assertEqual((b"FooBarError",), err.error_tuple) # Now the medium should have remembered the protocol version, so # subsequent requests will use the remembered version immediately. medium.expect_request( - b'bzr request 2\nmethod-name\n', - b'bzr response 2\nsuccess\nresponse value\n') - result = smart_client.call(b'method-name') - self.assertEqual((b'response value',), result) + b"bzr request 2\nmethod-name\n", + b"bzr response 2\nsuccess\nresponse value\n", + ) + result = smart_client.call(b"method-name") + self.assertEqual((b"response value",), result) self.assertEqual([], medium._expected_events) class Test_SmartClient(tests.TestCase): - def test_call_default_headers(self): """ProtocolThreeRequester.call by default sends a 'Software version' header. """ - smart_client = client._SmartClient('dummy medium') + smart_client = client._SmartClient("dummy medium") self.assertEqual( - breezy.__version__.encode('utf-8'), - smart_client._headers[b'Software version']) + breezy.__version__.encode("utf-8"), + smart_client._headers[b"Software version"], + ) # XXX: need a test that smart_client._headers is passed to the request # encoder. 
class Test_SmartClientRequest(tests.TestCase): - - def make_client_with_failing_medium(self, fail_at_write=True, response=b''): + def make_client_with_failing_medium(self, fail_at_write=True, response=b""): response_io = BytesIO(response) output = BytesIO() - vendor = FirstRejectedBytesIOSSHVendor(response_io, output, - fail_at_write=fail_at_write) - ssh_params = medium.SSHParams('a host', 'a port', 'a user', 'a pass') - client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor) + vendor = FirstRejectedBytesIOSSHVendor( + response_io, output, fail_at_write=fail_at_write + ) + ssh_params = medium.SSHParams("a host", "a port", "a user", "a pass") + client_medium = medium.SmartSSHClientMedium("base", ssh_params, vendor) smart_client = client._SmartClient(client_medium, headers={}) return output, vendor, smart_client def make_response(self, args, body=None, body_stream=None): response_io = BytesIO() - response = _mod_request.SuccessfulSmartServerResponse(args, body=body, - body_stream=body_stream) + response = _mod_request.SuccessfulSmartServerResponse( + args, body=body, body_stream=body_stream + ) responder = protocol.ProtocolThreeResponder(response_io.write) responder.send_response(response) return response_io.getvalue() def test__call_doesnt_retry_append(self): - response = self.make_response(('appended', b'8')) + response = self.make_response(("appended", b"8")) output, vendor, smart_client = self.make_client_with_failing_medium( - fail_at_write=False, response=response) - smart_request = client._SmartClientRequest(smart_client, b'append', - (b'foo', b''), body=b'content\n') + fail_at_write=False, response=response + ) + smart_request = client._SmartClientRequest( + smart_client, b"append", (b"foo", b""), body=b"content\n" + ) self.assertRaises(ConnectionResetError, smart_request._call, 3) def test__call_retries_get_bytes(self): - response = self.make_response((b'ok',), b'content\n') + response = self.make_response((b"ok",), b"content\n") output, vendor, smart_client = self.make_client_with_failing_medium( - fail_at_write=False, response=response) - smart_request = client._SmartClientRequest(smart_client, b'get', - (b'foo',)) + fail_at_write=False, response=response + ) + smart_request = client._SmartClientRequest(smart_client, b"get", (b"foo",)) response, response_handler = smart_request._call(3) - self.assertEqual((b'ok',), response) - self.assertEqual(b'content\n', response_handler.read_body_bytes()) + self.assertEqual((b"ok",), response) + self.assertEqual(b"content\n", response_handler.read_body_bytes()) def test__call_noretry_get_bytes(self): - debug.set_debug_flag('noretry') - response = self.make_response((b'ok',), b'content\n') + debug.set_debug_flag("noretry") + response = self.make_response((b"ok",), b"content\n") output, vendor, smart_client = self.make_client_with_failing_medium( - fail_at_write=False, response=response) - smart_request = client._SmartClientRequest(smart_client, b'get', - (b'foo',)) + fail_at_write=False, response=response + ) + smart_request = client._SmartClientRequest(smart_client, b"get", (b"foo",)) self.assertRaises(ConnectionResetError, smart_request._call, 3) def test__send_no_retry_pipes(self): client_read, server_write = create_file_pipes() server_read, client_write = create_file_pipes() - client_medium = medium.SmartSimplePipesClientMedium(client_read, - client_write, base='/') + client_medium = medium.SmartSimplePipesClientMedium( + client_read, client_write, base="/" + ) smart_client = client._SmartClient(client_medium) - 
smart_request = client._SmartClientRequest(smart_client, - b'hello', ()) + smart_request = client._SmartClientRequest(smart_client, b"hello", ()) # Close the server side server_read.close() encoder, response_handler = smart_request._construct_protocol(3) - self.assertRaises(ConnectionResetError, - smart_request._send_no_retry, encoder) + self.assertRaises(ConnectionResetError, smart_request._send_no_retry, encoder) def test__send_read_response_sockets(self): listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - listen_sock.bind(('127.0.0.1', 0)) + listen_sock.bind(("127.0.0.1", 0)) listen_sock.listen(1) host, port = listen_sock.getsockname() - client_medium = medium.SmartTCPClientMedium(host, port, '/') + client_medium = medium.SmartTCPClientMedium(host, port, "/") client_medium._ensure_connection() smart_client = client._SmartClient(client_medium) - smart_request = client._SmartClientRequest(smart_client, b'hello', ()) + smart_request = client._SmartClientRequest(smart_client, b"hello", ()) # Accept the connection, but don't actually talk to the client. server_sock, _ = listen_sock.accept() server_sock.close() # Sockets buffer and don't really notice that the server has closed the # connection until we try to read again. handler = smart_request._send(3) - self.assertRaises(ConnectionResetError, - handler.read_response_tuple, expect_body=False) + self.assertRaises( + ConnectionResetError, handler.read_response_tuple, expect_body=False + ) def test__send_retries_on_write(self): output, vendor, smart_client = self.make_client_with_failing_medium() - smart_request = client._SmartClientRequest(smart_client, b'hello', ()) + smart_request = client._SmartClientRequest(smart_client, b"hello", ()) smart_request._send(3) - self.assertEqual(b'bzr message 3 (bzr 1.6)\n' # protocol - b'\x00\x00\x00\x02de' # empty headers - b's\x00\x00\x00\tl5:helloee', - output.getvalue()) self.assertEqual( - [('connect_ssh', 'a user', 'a pass', 'a host', 'a port', - ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']), - ('close',), - ('connect_ssh', 'a user', 'a pass', 'a host', 'a port', - ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']), - ], - vendor.calls) + b"bzr message 3 (bzr 1.6)\n" # protocol + b"\x00\x00\x00\x02de" # empty headers + b"s\x00\x00\x00\tl5:helloee", + output.getvalue(), + ) + self.assertEqual( + [ + ( + "connect_ssh", + "a user", + "a pass", + "a host", + "a port", + ["bzr", "serve", "--inet", "--directory=/", "--allow-writes"], + ), + ("close",), + ( + "connect_ssh", + "a user", + "a pass", + "a host", + "a port", + ["bzr", "serve", "--inet", "--directory=/", "--allow-writes"], + ), + ], + vendor.calls, + ) def test__send_doesnt_retry_read_failure(self): output, vendor, smart_client = self.make_client_with_failing_medium( - fail_at_write=False) - smart_request = client._SmartClientRequest(smart_client, b'hello', ()) + fail_at_write=False + ) + smart_request = client._SmartClientRequest(smart_client, b"hello", ()) handler = smart_request._send(3) - self.assertEqual(b'bzr message 3 (bzr 1.6)\n' # protocol - b'\x00\x00\x00\x02de' # empty headers - b's\x00\x00\x00\tl5:helloee', - output.getvalue()) self.assertEqual( - [('connect_ssh', 'a user', 'a pass', 'a host', 'a port', - ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']), - ], - vendor.calls) + b"bzr message 3 (bzr 1.6)\n" # protocol + b"\x00\x00\x00\x02de" # empty headers + b"s\x00\x00\x00\tl5:helloee", + output.getvalue(), + ) + self.assertEqual( + [ + ( + "connect_ssh", + "a user", + "a 
pass", + "a host", + "a port", + ["bzr", "serve", "--inet", "--directory=/", "--allow-writes"], + ), + ], + vendor.calls, + ) self.assertRaises(ConnectionResetError, handler.read_response_tuple) def test__send_request_retries_body_stream_if_not_started(self): output, vendor, smart_client = self.make_client_with_failing_medium() - smart_request = client._SmartClientRequest(smart_client, b'hello', (), - body_stream=[b'a', b'b']) + smart_request = client._SmartClientRequest( + smart_client, b"hello", (), body_stream=[b"a", b"b"] + ) smart_request._send(3) # We connect, get disconnected, and notice before consuming the stream, # so we try again one time and succeed. self.assertEqual( - [('connect_ssh', 'a user', 'a pass', 'a host', 'a port', - ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']), - ('close',), - ('connect_ssh', 'a user', 'a pass', 'a host', 'a port', - ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']), - ], - vendor.calls) - self.assertEqual(b'bzr message 3 (bzr 1.6)\n' # protocol - b'\x00\x00\x00\x02de' # empty headers - b's\x00\x00\x00\tl5:helloe' - b'b\x00\x00\x00\x01a' - b'b\x00\x00\x00\x01b' - b'e', - output.getvalue()) + [ + ( + "connect_ssh", + "a user", + "a pass", + "a host", + "a port", + ["bzr", "serve", "--inet", "--directory=/", "--allow-writes"], + ), + ("close",), + ( + "connect_ssh", + "a user", + "a pass", + "a host", + "a port", + ["bzr", "serve", "--inet", "--directory=/", "--allow-writes"], + ), + ], + vendor.calls, + ) + self.assertEqual( + b"bzr message 3 (bzr 1.6)\n" # protocol + b"\x00\x00\x00\x02de" # empty headers + b"s\x00\x00\x00\tl5:helloe" + b"b\x00\x00\x00\x01a" + b"b\x00\x00\x00\x01b" + b"e", + output.getvalue(), + ) def test__send_request_stops_if_body_started(self): # We intentionally use the python BytesIO so that we can subclass it. from io import BytesIO + response = BytesIO() class FailAfterFirstWrite(BytesIO): @@ -3901,45 +4027,63 @@ def write(self, s): if self._first: self._first = False return BytesIO.write(self, s) - raise OSError(errno.EINVAL, 'invalid file handle') + raise OSError(errno.EINVAL, "invalid file handle") + output = FailAfterFirstWrite() - vendor = FirstRejectedBytesIOSSHVendor(response, output, - fail_at_write=False) - ssh_params = medium.SSHParams('a host', 'a port', 'a user', 'a pass') - client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor) + vendor = FirstRejectedBytesIOSSHVendor(response, output, fail_at_write=False) + ssh_params = medium.SSHParams("a host", "a port", "a user", "a pass") + client_medium = medium.SmartSSHClientMedium("base", ssh_params, vendor) smart_client = client._SmartClient(client_medium, headers={}) - smart_request = client._SmartClientRequest(smart_client, b'hello', (), - body_stream=[b'a', b'b']) + smart_request = client._SmartClientRequest( + smart_client, b"hello", (), body_stream=[b"a", b"b"] + ) self.assertRaises(ConnectionResetError, smart_request._send, 3) # We connect, and manage to get to the point that we start consuming # the body stream. The next write fails, so we just stop. 
self.assertEqual( - [('connect_ssh', 'a user', 'a pass', 'a host', 'a port', - ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']), - ('close',), - ], - vendor.calls) - self.assertEqual(b'bzr message 3 (bzr 1.6)\n' # protocol - b'\x00\x00\x00\x02de' # empty headers - b's\x00\x00\x00\tl5:helloe', - output.getvalue()) + [ + ( + "connect_ssh", + "a user", + "a pass", + "a host", + "a port", + ["bzr", "serve", "--inet", "--directory=/", "--allow-writes"], + ), + ("close",), + ], + vendor.calls, + ) + self.assertEqual( + b"bzr message 3 (bzr 1.6)\n" # protocol + b"\x00\x00\x00\x02de" # empty headers + b"s\x00\x00\x00\tl5:helloe", + output.getvalue(), + ) def test__send_disabled_retry(self): - debug.set_debug_flag('noretry') + debug.set_debug_flag("noretry") output, vendor, smart_client = self.make_client_with_failing_medium() - smart_request = client._SmartClientRequest(smart_client, b'hello', ()) + smart_request = client._SmartClientRequest(smart_client, b"hello", ()) self.assertRaises(ConnectionResetError, smart_request._send, 3) self.assertEqual( - [('connect_ssh', 'a user', 'a pass', 'a host', 'a port', - ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']), - ('close',), - ], - vendor.calls) + [ + ( + "connect_ssh", + "a user", + "a pass", + "a host", + "a port", + ["bzr", "serve", "--inet", "--directory=/", "--allow-writes"], + ), + ("close",), + ], + vendor.calls, + ) class LengthPrefixedBodyDecoder(tests.TestCase): - # XXX: TODO: make accept_reading_trailer invoke translate_response or # something similar to the ProtocolBase method. @@ -3947,62 +4091,62 @@ def test_construct(self): decoder = protocol.LengthPrefixedBodyDecoder() self.assertFalse(decoder.finished_reading) self.assertEqual(6, decoder.next_read_size()) - self.assertEqual(b'', decoder.read_pending_data()) - self.assertEqual(b'', decoder.unused_data) + self.assertEqual(b"", decoder.read_pending_data()) + self.assertEqual(b"", decoder.unused_data) def test_accept_bytes(self): decoder = protocol.LengthPrefixedBodyDecoder() - decoder.accept_bytes(b'') + decoder.accept_bytes(b"") self.assertFalse(decoder.finished_reading) self.assertEqual(6, decoder.next_read_size()) - self.assertEqual(b'', decoder.read_pending_data()) - self.assertEqual(b'', decoder.unused_data) - decoder.accept_bytes(b'7') + self.assertEqual(b"", decoder.read_pending_data()) + self.assertEqual(b"", decoder.unused_data) + decoder.accept_bytes(b"7") self.assertFalse(decoder.finished_reading) self.assertEqual(6, decoder.next_read_size()) - self.assertEqual(b'', decoder.read_pending_data()) - self.assertEqual(b'', decoder.unused_data) - decoder.accept_bytes(b'\na') + self.assertEqual(b"", decoder.read_pending_data()) + self.assertEqual(b"", decoder.unused_data) + decoder.accept_bytes(b"\na") self.assertFalse(decoder.finished_reading) self.assertEqual(11, decoder.next_read_size()) - self.assertEqual(b'a', decoder.read_pending_data()) - self.assertEqual(b'', decoder.unused_data) - decoder.accept_bytes(b'bcdefgd') + self.assertEqual(b"a", decoder.read_pending_data()) + self.assertEqual(b"", decoder.unused_data) + decoder.accept_bytes(b"bcdefgd") self.assertFalse(decoder.finished_reading) self.assertEqual(4, decoder.next_read_size()) - self.assertEqual(b'bcdefg', decoder.read_pending_data()) - self.assertEqual(b'', decoder.unused_data) - decoder.accept_bytes(b'one') + self.assertEqual(b"bcdefg", decoder.read_pending_data()) + self.assertEqual(b"", decoder.unused_data) + decoder.accept_bytes(b"one") self.assertFalse(decoder.finished_reading) 
self.assertEqual(1, decoder.next_read_size()) - self.assertEqual(b'', decoder.read_pending_data()) - self.assertEqual(b'', decoder.unused_data) - decoder.accept_bytes(b'\nblarg') + self.assertEqual(b"", decoder.read_pending_data()) + self.assertEqual(b"", decoder.unused_data) + decoder.accept_bytes(b"\nblarg") self.assertTrue(decoder.finished_reading) self.assertEqual(1, decoder.next_read_size()) - self.assertEqual(b'', decoder.read_pending_data()) - self.assertEqual(b'blarg', decoder.unused_data) + self.assertEqual(b"", decoder.read_pending_data()) + self.assertEqual(b"blarg", decoder.unused_data) def test_accept_bytes_all_at_once_with_excess(self): decoder = protocol.LengthPrefixedBodyDecoder() - decoder.accept_bytes(b'1\nadone\nunused') + decoder.accept_bytes(b"1\nadone\nunused") self.assertTrue(decoder.finished_reading) self.assertEqual(1, decoder.next_read_size()) - self.assertEqual(b'a', decoder.read_pending_data()) - self.assertEqual(b'unused', decoder.unused_data) + self.assertEqual(b"a", decoder.read_pending_data()) + self.assertEqual(b"unused", decoder.unused_data) def test_accept_bytes_exact_end_of_body(self): decoder = protocol.LengthPrefixedBodyDecoder() - decoder.accept_bytes(b'1\na') + decoder.accept_bytes(b"1\na") self.assertFalse(decoder.finished_reading) self.assertEqual(5, decoder.next_read_size()) - self.assertEqual(b'a', decoder.read_pending_data()) - self.assertEqual(b'', decoder.unused_data) - decoder.accept_bytes(b'done\n') + self.assertEqual(b"a", decoder.read_pending_data()) + self.assertEqual(b"", decoder.unused_data) + decoder.accept_bytes(b"done\n") self.assertTrue(decoder.finished_reading) self.assertEqual(1, decoder.next_read_size()) - self.assertEqual(b'', decoder.read_pending_data()) - self.assertEqual(b'', decoder.unused_data) + self.assertEqual(b"", decoder.read_pending_data()) + self.assertEqual(b"", decoder.unused_data) class TestChunkedBodyDecoder(tests.TestCase): @@ -4016,99 +4160,107 @@ def test_construct(self): self.assertFalse(decoder.finished_reading) self.assertEqual(8, decoder.next_read_size()) self.assertEqual(None, decoder.read_next_chunk()) - self.assertEqual(b'', decoder.unused_data) + self.assertEqual(b"", decoder.unused_data) def test_empty_content(self): r"""'chunked\nEND\n' is the complete encoding of a zero-length body.""" decoder = protocol.ChunkedBodyDecoder() - decoder.accept_bytes(b'chunked\n') - decoder.accept_bytes(b'END\n') + decoder.accept_bytes(b"chunked\n") + decoder.accept_bytes(b"END\n") self.assertTrue(decoder.finished_reading) self.assertEqual(None, decoder.read_next_chunk()) - self.assertEqual(b'', decoder.unused_data) + self.assertEqual(b"", decoder.unused_data) def test_one_chunk(self): """A body in a single chunk is decoded correctly.""" decoder = protocol.ChunkedBodyDecoder() - decoder.accept_bytes(b'chunked\n') - chunk_length = b'f\n' - chunk_content = b'123456789abcdef' - finish = b'END\n' + decoder.accept_bytes(b"chunked\n") + chunk_length = b"f\n" + chunk_content = b"123456789abcdef" + finish = b"END\n" decoder.accept_bytes(chunk_length + chunk_content + finish) self.assertTrue(decoder.finished_reading) self.assertEqual(chunk_content, decoder.read_next_chunk()) - self.assertEqual(b'', decoder.unused_data) + self.assertEqual(b"", decoder.unused_data) def test_incomplete_chunk(self): """When there are less bytes in the chunk than declared by the length, then we haven't finished reading yet. 
""" decoder = protocol.ChunkedBodyDecoder() - decoder.accept_bytes(b'chunked\n') - chunk_length = b'8\n' - three_bytes = b'123' + decoder.accept_bytes(b"chunked\n") + chunk_length = b"8\n" + three_bytes = b"123" decoder.accept_bytes(chunk_length + three_bytes) self.assertFalse(decoder.finished_reading) self.assertEqual( - 5 + 4, decoder.next_read_size(), + 5 + 4, + decoder.next_read_size(), "The next_read_size hint should be the number of missing bytes in " "this chunk plus 4 (the length of the end-of-body marker: " - "'END\\n')") + "'END\\n')", + ) self.assertEqual(None, decoder.read_next_chunk()) def test_incomplete_length(self): """A chunk length hasn't been read until a newline byte has been read.""" decoder = protocol.ChunkedBodyDecoder() - decoder.accept_bytes(b'chunked\n') - decoder.accept_bytes(b'9') + decoder.accept_bytes(b"chunked\n") + decoder.accept_bytes(b"9") self.assertEqual( - 1, decoder.next_read_size(), + 1, + decoder.next_read_size(), "The next_read_size hint should be 1, because we don't know the " - "length yet.") - decoder.accept_bytes(b'\n') + "length yet.", + ) + decoder.accept_bytes(b"\n") self.assertEqual( - 9 + 4, decoder.next_read_size(), + 9 + 4, + decoder.next_read_size(), "The next_read_size hint should be the length of the chunk plus 4 " - "(the length of the end-of-body marker: 'END\\n')") + "(the length of the end-of-body marker: 'END\\n')", + ) self.assertFalse(decoder.finished_reading) self.assertEqual(None, decoder.read_next_chunk()) def test_two_chunks(self): """Content from multiple chunks is concatenated.""" decoder = protocol.ChunkedBodyDecoder() - decoder.accept_bytes(b'chunked\n') - chunk_one = b'3\naaa' - chunk_two = b'5\nbbbbb' - finish = b'END\n' + decoder.accept_bytes(b"chunked\n") + chunk_one = b"3\naaa" + chunk_two = b"5\nbbbbb" + finish = b"END\n" decoder.accept_bytes(chunk_one + chunk_two + finish) self.assertTrue(decoder.finished_reading) - self.assertEqual(b'aaa', decoder.read_next_chunk()) - self.assertEqual(b'bbbbb', decoder.read_next_chunk()) + self.assertEqual(b"aaa", decoder.read_next_chunk()) + self.assertEqual(b"bbbbb", decoder.read_next_chunk()) self.assertEqual(None, decoder.read_next_chunk()) - self.assertEqual(b'', decoder.unused_data) + self.assertEqual(b"", decoder.unused_data) def test_excess_bytes(self): """Bytes after the chunked body are reported as unused bytes.""" decoder = protocol.ChunkedBodyDecoder() - decoder.accept_bytes(b'chunked\n') + decoder.accept_bytes(b"chunked\n") chunked_body = b"5\naaaaaEND\n" excess_bytes = b"excess bytes" decoder.accept_bytes(chunked_body + excess_bytes) self.assertTrue(decoder.finished_reading) - self.assertEqual(b'aaaaa', decoder.read_next_chunk()) + self.assertEqual(b"aaaaa", decoder.read_next_chunk()) self.assertEqual(excess_bytes, decoder.unused_data) self.assertEqual( - 1, decoder.next_read_size(), - "next_read_size hint should be 1 when finished_reading.") + 1, + decoder.next_read_size(), + "next_read_size hint should be 1 when finished_reading.", + ) def test_multidigit_length(self): """Lengths in the chunk prefixes can have multiple digits.""" decoder = protocol.ChunkedBodyDecoder() - decoder.accept_bytes(b'chunked\n') + decoder.accept_bytes(b"chunked\n") length = 0x123 - chunk_prefix = hex(length).encode('ascii') + b'\n' - chunk_bytes = b'z' * length - finish = b'END\n' + chunk_prefix = hex(length).encode("ascii") + b"\n" + chunk_bytes = b"z" * length + finish = b"END\n" decoder.accept_bytes(chunk_prefix + chunk_bytes + finish) self.assertTrue(decoder.finished_reading) 
self.assertEqual(chunk_bytes, decoder.read_next_chunk()) @@ -4122,41 +4274,40 @@ def test_byte_at_a_time(self): is called. """ decoder = protocol.ChunkedBodyDecoder() - decoder.accept_bytes(b'chunked\n') - chunk_length = b'f\n' - chunk_content = b'123456789abcdef' - finish = b'END\n' + decoder.accept_bytes(b"chunked\n") + chunk_length = b"f\n" + chunk_content = b"123456789abcdef" + finish = b"END\n" combined = chunk_length + chunk_content + finish for i in range(len(combined)): - decoder.accept_bytes(combined[i:i + 1]) + decoder.accept_bytes(combined[i : i + 1]) self.assertTrue(decoder.finished_reading) self.assertEqual(chunk_content, decoder.read_next_chunk()) - self.assertEqual(b'', decoder.unused_data) + self.assertEqual(b"", decoder.unused_data) def test_read_pending_data_resets(self): """read_pending_data does not return the same bytes twice.""" decoder = protocol.ChunkedBodyDecoder() - decoder.accept_bytes(b'chunked\n') - chunk_one = b'3\naaa' - chunk_two = b'3\nbbb' + decoder.accept_bytes(b"chunked\n") + chunk_one = b"3\naaa" + chunk_two = b"3\nbbb" decoder.accept_bytes(chunk_one) - self.assertEqual(b'aaa', decoder.read_next_chunk()) + self.assertEqual(b"aaa", decoder.read_next_chunk()) decoder.accept_bytes(chunk_two) - self.assertEqual(b'bbb', decoder.read_next_chunk()) + self.assertEqual(b"bbb", decoder.read_next_chunk()) self.assertEqual(None, decoder.read_next_chunk()) def test_decode_error(self): decoder = protocol.ChunkedBodyDecoder() - decoder.accept_bytes(b'chunked\n') - chunk_one = b'b\nfirst chunk' - error_signal = b'ERR\n' - error_chunks = b'5\npart1' + b'5\npart2' - finish = b'END\n' + decoder.accept_bytes(b"chunked\n") + chunk_one = b"b\nfirst chunk" + error_signal = b"ERR\n" + error_chunks = b"5\npart1" + b"5\npart2" + finish = b"END\n" decoder.accept_bytes(chunk_one + error_signal + error_chunks + finish) self.assertTrue(decoder.finished_reading) - self.assertEqual(b'first chunk', decoder.read_next_chunk()) - expected_failure = _mod_request.FailedSmartServerResponse( - (b'part1', b'part2')) + self.assertEqual(b"first chunk", decoder.read_next_chunk()) + expected_failure = _mod_request.FailedSmartServerResponse((b"part1", b"part2")) self.assertEqual(expected_failure, decoder.read_next_chunk()) def test_bad_header(self): @@ -4165,59 +4316,63 @@ def test_bad_header(self): """ decoder = protocol.ChunkedBodyDecoder() self.assertRaises( - errors.SmartProtocolError, decoder.accept_bytes, b'bad header\n') + errors.SmartProtocolError, decoder.accept_bytes, b"bad header\n" + ) class TestSuccessfulSmartServerResponse(tests.TestCase): - def test_construct_no_body(self): - response = _mod_request.SuccessfulSmartServerResponse((b'foo', b'bar')) - self.assertEqual((b'foo', b'bar'), response.args) + response = _mod_request.SuccessfulSmartServerResponse((b"foo", b"bar")) + self.assertEqual((b"foo", b"bar"), response.args) self.assertEqual(None, response.body) def test_construct_with_body(self): - response = _mod_request.SuccessfulSmartServerResponse((b'foo', b'bar'), - b'bytes') - self.assertEqual((b'foo', b'bar'), response.args) - self.assertEqual(b'bytes', response.body) + response = _mod_request.SuccessfulSmartServerResponse( + (b"foo", b"bar"), b"bytes" + ) + self.assertEqual((b"foo", b"bar"), response.args) + self.assertEqual(b"bytes", response.body) # repr(response) doesn't trigger exceptions. 
repr(response) def test_construct_with_body_stream(self): - bytes_iterable = [b'abc'] + bytes_iterable = [b"abc"] response = _mod_request.SuccessfulSmartServerResponse( - (b'foo', b'bar'), body_stream=bytes_iterable) - self.assertEqual((b'foo', b'bar'), response.args) + (b"foo", b"bar"), body_stream=bytes_iterable + ) + self.assertEqual((b"foo", b"bar"), response.args) self.assertEqual(bytes_iterable, response.body_stream) def test_construct_rejects_body_and_body_stream(self): """'body' and 'body_stream' are mutually exclusive.""" # noqa: D403 self.assertRaises( errors.BzrError, - _mod_request.SuccessfulSmartServerResponse, (), b'body', [b'stream']) + _mod_request.SuccessfulSmartServerResponse, + (), + b"body", + [b"stream"], + ) def test_is_successful(self): """is_successful should return True for SuccessfulSmartServerResponse.""" - response = _mod_request.SuccessfulSmartServerResponse((b'error',)) + response = _mod_request.SuccessfulSmartServerResponse((b"error",)) self.assertEqual(True, response.is_successful()) class TestFailedSmartServerResponse(tests.TestCase): - def test_construct(self): - response = _mod_request.FailedSmartServerResponse((b'foo', b'bar')) - self.assertEqual((b'foo', b'bar'), response.args) + response = _mod_request.FailedSmartServerResponse((b"foo", b"bar")) + self.assertEqual((b"foo", b"bar"), response.args) self.assertEqual(None, response.body) - response = _mod_request.FailedSmartServerResponse( - (b'foo', b'bar'), b'bytes') - self.assertEqual((b'foo', b'bar'), response.args) - self.assertEqual(b'bytes', response.body) + response = _mod_request.FailedSmartServerResponse((b"foo", b"bar"), b"bytes") + self.assertEqual((b"foo", b"bar"), response.args) + self.assertEqual(b"bytes", response.body) # repr(response) doesn't trigger exceptions. repr(response) def test_is_successful(self): """is_successful should return False for FailedSmartServerResponse.""" - response = _mod_request.FailedSmartServerResponse((b'error',)) + response = _mod_request.FailedSmartServerResponse((b"error",)) self.assertEqual(False, response.is_successful()) @@ -4232,24 +4387,22 @@ def send_http_smart_request(self, bytes): class HTTPTunnellingSmokeTest(tests.TestCase): - def setUp(self): super().setUp() # We use the VFS layer as part of HTTP tunnelling tests. - self.overrideEnv('BRZ_NO_SMART_VFS', None) + self.overrideEnv("BRZ_NO_SMART_VFS", None) def test_smart_http_medium_request_accept_bytes(self): medium = FakeHTTPMedium() request = urllib.SmartClientHTTPMediumRequest(medium) - request.accept_bytes(b'abc') - request.accept_bytes(b'def') + request.accept_bytes(b"abc") + request.accept_bytes(b"def") self.assertEqual(None, medium.written_request) request.finished_writing() - self.assertEqual(b'abcdef', medium.written_request) + self.assertEqual(b"abcdef", medium.written_request) class RemoteHTTPTransportTestCase(tests.TestCase): - def test_remote_path_after_clone_child(self): # If a user enters "bzr+http://host/foo", we want to sent all smart # requests for child URLs of that to the original URL. i.e., we want to @@ -4257,66 +4410,66 @@ def test_remote_path_after_clone_child(self): # "bzr+http://host/foo/.bzr/branch/.bzr/smart". So, a cloned # RemoteHTTPTransport remembers the initial URL, and adjusts the # relpaths it sends in smart requests accordingly. 
- base_transport = remote.RemoteHTTPTransport('bzr+http://host/path') - new_transport = base_transport.clone('child_dir') - self.assertEqual(base_transport._http_transport, - new_transport._http_transport) - self.assertEqual('child_dir/foo', new_transport._remote_path('foo')) + base_transport = remote.RemoteHTTPTransport("bzr+http://host/path") + new_transport = base_transport.clone("child_dir") + self.assertEqual(base_transport._http_transport, new_transport._http_transport) + self.assertEqual("child_dir/foo", new_transport._remote_path("foo")) self.assertEqual( - b'child_dir/', - new_transport._client.remote_path_from_transport(new_transport)) + b"child_dir/", + new_transport._client.remote_path_from_transport(new_transport), + ) def test_remote_path_unnormal_base(self): # If the transport's base isn't normalised, the _remote_path should # still be calculated correctly. - base_transport = remote.RemoteHTTPTransport('bzr+http://host/%7Ea/b') - self.assertEqual('c', base_transport._remote_path('c')) + base_transport = remote.RemoteHTTPTransport("bzr+http://host/%7Ea/b") + self.assertEqual("c", base_transport._remote_path("c")) def test_clone_unnormal_base(self): # If the transport's base isn't normalised, cloned transports should # still work correctly. - base_transport = remote.RemoteHTTPTransport('bzr+http://host/%7Ea/b') - new_transport = base_transport.clone('c') - self.assertEqual(base_transport.base + 'c/', new_transport.base) + base_transport = remote.RemoteHTTPTransport("bzr+http://host/%7Ea/b") + new_transport = base_transport.clone("c") + self.assertEqual(base_transport.base + "c/", new_transport.base) self.assertEqual( - b'c/', - new_transport._client.remote_path_from_transport(new_transport)) + b"c/", new_transport._client.remote_path_from_transport(new_transport) + ) def test__redirect_to(self): - t = remote.RemoteHTTPTransport('bzr+http://www.example.com/foo') - r = t._redirected_to('http://www.example.com/foo', - 'http://www.example.com/bar') + t = remote.RemoteHTTPTransport("bzr+http://www.example.com/foo") + r = t._redirected_to("http://www.example.com/foo", "http://www.example.com/bar") self.assertEqual(type(r), type(t)) def test__redirect_sibling_protocol(self): - t = remote.RemoteHTTPTransport('bzr+http://www.example.com/foo') - r = t._redirected_to('http://www.example.com/foo', - 'https://www.example.com/bar') + t = remote.RemoteHTTPTransport("bzr+http://www.example.com/foo") + r = t._redirected_to( + "http://www.example.com/foo", "https://www.example.com/bar" + ) self.assertEqual(type(r), type(t)) - self.assertStartsWith(r.base, 'bzr+https') + self.assertStartsWith(r.base, "bzr+https") def test__redirect_to_with_user(self): - t = remote.RemoteHTTPTransport('bzr+http://joe@www.example.com/foo') - r = t._redirected_to('http://www.example.com/foo', - 'http://www.example.com/bar') + t = remote.RemoteHTTPTransport("bzr+http://joe@www.example.com/foo") + r = t._redirected_to("http://www.example.com/foo", "http://www.example.com/bar") self.assertEqual(type(r), type(t)) - self.assertEqual('joe', t._parsed_url.user) + self.assertEqual("joe", t._parsed_url.user) self.assertEqual(t._parsed_url.user, r._parsed_url.user) def test_redirected_to_same_host_different_protocol(self): - t = remote.RemoteHTTPTransport('bzr+http://joe@www.example.com/foo') - r = t._redirected_to('http://www.example.com/foo', - 'bzr://www.example.com/foo') + t = remote.RemoteHTTPTransport("bzr+http://joe@www.example.com/foo") + r = t._redirected_to("http://www.example.com/foo", "bzr://www.example.com/foo") 
self.assertNotEqual(type(r), type(t)) class TestErrors(tests.TestCase): def test_too_many_concurrent_requests(self): error = medium.TooManyConcurrentRequests("a medium") - self.assertEqualDiff("The medium 'a medium' has reached its concurrent " - "request limit. Be sure to finish_writing and finish_reading on " - "the currently open request.", - str(error)) + self.assertEqualDiff( + "The medium 'a medium' has reached its concurrent " + "request limit. Be sure to finish_writing and finish_reading on " + "the currently open request.", + str(error), + ) def test_smart_message_handler_error(self): # Make an exc_info tuple. @@ -4327,9 +4480,8 @@ def test_smart_message_handler_error(self): # GZ 2010-11-08: Should not store exc_info in exception instances. try: self.assertStartsWith( - str(err), "The message handler raised an exception:\n") + str(err), "The message handler raised an exception:\n" + ) self.assertEndsWith(str(err), "Exception: example error\n") finally: del err - - diff --git a/breezy/bzr/tests/test_tag.py b/breezy/bzr/tests/test_tag.py index 0452ac79f1..51d120bf09 100644 --- a/breezy/bzr/tests/test_tag.py +++ b/breezy/bzr/tests/test_tag.py @@ -23,7 +23,6 @@ class TestTagSerialization(TestCase): - def test_tag_serialization(self): """Test the precise representation of tag dicts.""" # Don't change this after we commit to this format, as it checks @@ -32,8 +31,8 @@ def test_tag_serialization(self): # This release stores them in bencode as a dictionary from name to # target. store = BasicTags(branch=None) - td = {'stable': b'stable-revid', 'boring': b'boring-revid'} + td = {"stable": b"stable-revid", "boring": b"boring-revid"} packed = store._serialize_tag_dict(td) - expected = br'd6:boring12:boring-revid6:stable12:stable-revide' + expected = rb"d6:boring12:boring-revid6:stable12:stable-revide" self.assertEqualDiff(packed, expected) self.assertEqual(store._deserialize_tag_dict(packed), td) diff --git a/breezy/bzr/tests/test_testament.py b/breezy/bzr/tests/test_testament.py index 432e3a1908..5c562f373b 100644 --- a/breezy/bzr/tests/test_testament.py +++ b/breezy/bzr/tests/test_testament.py @@ -28,36 +28,43 @@ class TestamentSetup(TestCaseWithTransport): - def setUp(self): super().setUp() - self.wt = self.make_branch_and_tree('.') - self.wt.set_root_id(b'TREE_ROT') + self.wt = self.make_branch_and_tree(".") + self.wt.set_root_id(b"TREE_ROT") b = self.b = self.wt.branch b.nick = "test branch" - self.wt.commit(message='initial null commit', - committer='test@user', - timestamp=1129025423, # 'Tue Oct 11 20:10:23 2005' - timezone=0, - rev_id=b'test@user-1') - self.build_tree_contents([('hello', b'contents of hello file'), - ('src/', ), - ('src/foo.c', b'int main()\n{\n}\n')]) - self.wt.add(['hello', 'src', 'src/foo.c'], - ids=[b'hello-id', b'src-id', b'foo.c-id']) + self.wt.commit( + message="initial null commit", + committer="test@user", + timestamp=1129025423, # 'Tue Oct 11 20:10:23 2005' + timezone=0, + rev_id=b"test@user-1", + ) + self.build_tree_contents( + [ + ("hello", b"contents of hello file"), + ("src/",), + ("src/foo.c", b"int main()\n{\n}\n"), + ] + ) + self.wt.add( + ["hello", "src", "src/foo.c"], ids=[b"hello-id", b"src-id", b"foo.c-id"] + ) tt = self.wt.transform() - trans_id = tt.trans_id_tree_path('hello') + trans_id = tt.trans_id_tree_path("hello") tt.set_executability(True, trans_id) tt.apply() - self.wt.commit(message='add files and directories', - timestamp=1129025483, - timezone=36000, - rev_id=b'test@user-2', - committer='test@user') + self.wt.commit( + message="add 
files and directories", + timestamp=1129025483, + timezone=36000, + rev_id=b"test@user-2", + committer="test@user", + ) class TestamentTests(TestamentSetup): - def testament_class(self): return Testament @@ -69,100 +76,102 @@ def from_revision(self, repository, revision_id): def test_null_testament(self): """Testament for a revision with no contents.""" - t = self.from_revision(self.b.repository, b'test@user-1') + t = self.from_revision(self.b.repository, b"test@user-1") ass = self.assertTrue eq = self.assertEqual ass(isinstance(t, Testament)) - eq(t.revision_id, b'test@user-1') - eq(t.committer, 'test@user') + eq(t.revision_id, b"test@user-1") + eq(t.committer, "test@user") eq(t.timestamp, 1129025423) eq(t.timezone, 0) def test_testment_text_form(self): """Conversion of testament to canonical text form.""" - t = self.from_revision(self.b.repository, b'test@user-1') + t = self.from_revision(self.b.repository, b"test@user-1") text_form = t.as_text() - self.log(f'testament text form:\n{text_form}') - self.assertEqualDiff(text_form, self.expected('rev_1')) + self.log(f"testament text form:\n{text_form}") + self.assertEqualDiff(text_form, self.expected("rev_1")) short_text_form = t.as_short_text() - self.assertEqualDiff(short_text_form, self.expected('rev_1_short')) + self.assertEqualDiff(short_text_form, self.expected("rev_1_short")) def test_testament_with_contents(self): """Testament containing a file and a directory.""" - t = self.from_revision(self.b.repository, b'test@user-2') + t = self.from_revision(self.b.repository, b"test@user-2") text_form = t.as_text() - self.log(f'testament text form:\n{text_form}') - self.assertEqualDiff(text_form, self.expected('rev_2')) + self.log(f"testament text form:\n{text_form}") + self.assertEqualDiff(text_form, self.expected("rev_2")) actual_short = t.as_short_text() - self.assertEqualDiff(actual_short, self.expected('rev_2_short')) + self.assertEqualDiff(actual_short, self.expected("rev_2_short")) def test_testament_symlinks(self): """Testament containing symlink (where possible).""" self.requireFeature(SymlinkFeature(self.test_dir)) - os.symlink('wibble/linktarget', 'link') - self.wt.add(['link'], ids=[b'link-id']) - self.wt.commit(message='add symlink', - timestamp=1129025493, - timezone=36000, - rev_id=b'test@user-3', - committer='test@user') - t = self.from_revision(self.b.repository, b'test@user-3') - self.assertEqualDiff(t.as_text(), self.expected('rev_3')) + os.symlink("wibble/linktarget", "link") + self.wt.add(["link"], ids=[b"link-id"]) + self.wt.commit( + message="add symlink", + timestamp=1129025493, + timezone=36000, + rev_id=b"test@user-3", + committer="test@user", + ) + t = self.from_revision(self.b.repository, b"test@user-3") + self.assertEqualDiff(t.as_text(), self.expected("rev_3")) def test_testament_revprops(self): """Testament to revision with extra properties.""" - props = {'flavor': 'sour cherry\ncream cheese', - 'size': 'medium', - 'empty': '', - } - self.wt.commit(message='revision with properties', - timestamp=1129025493, - timezone=36000, - rev_id=b'test@user-3', - committer='test@user', - revprops=props) - t = self.from_revision(self.b.repository, b'test@user-3') - self.assertEqualDiff(t.as_text(), self.expected('rev_props')) + props = { + "flavor": "sour cherry\ncream cheese", + "size": "medium", + "empty": "", + } + self.wt.commit( + message="revision with properties", + timestamp=1129025493, + timezone=36000, + rev_id=b"test@user-3", + committer="test@user", + revprops=props, + ) + t = self.from_revision(self.b.repository, 
b"test@user-3") + self.assertEqualDiff(t.as_text(), self.expected("rev_props")) def test_testament_unicode_commit_message(self): self.wt.commit( - message='non-ascii commit \N{COPYRIGHT SIGN} me', + message="non-ascii commit \N{COPYRIGHT SIGN} me", timestamp=1129025493, timezone=36000, - rev_id=b'test@user-3', - committer='Erik B\xe5gfors ', - revprops={'uni': '\xb5'} - ) - t = self.from_revision(self.b.repository, b'test@user-3') + rev_id=b"test@user-3", + committer="Erik B\xe5gfors ", + revprops={"uni": "\xb5"}, + ) + t = self.from_revision(self.b.repository, b"test@user-3") self.assertEqualDiff( - self.expected('sample_unicode').encode('utf-8'), t.as_text()) + self.expected("sample_unicode").encode("utf-8"), t.as_text() + ) def test_from_tree(self): - tree = self.b.repository.revision_tree(b'test@user-2') + tree = self.b.repository.revision_tree(b"test@user-2") testament = self.testament_class().from_revision_tree(tree) text_1 = testament.as_short_text() - text_2 = self.from_revision(self.b.repository, - b'test@user-2').as_short_text() + text_2 = self.from_revision(self.b.repository, b"test@user-2").as_short_text() self.assertEqual(text_1, text_2) def test___init__(self): - revision = self.b.repository.get_revision(b'test@user-2') - tree = self.b.repository.revision_tree(b'test@user-2') + revision = self.b.repository.get_revision(b"test@user-2") + tree = self.b.repository.revision_tree(b"test@user-2") testament_1 = self.testament_class()(revision, tree) text_1 = testament_1.as_short_text() - text_2 = self.from_revision(self.b.repository, - b'test@user-2').as_short_text() + text_2 = self.from_revision(self.b.repository, b"test@user-2").as_short_text() self.assertEqual(text_1, text_2) class TestamentTestsStrict(TestamentTests): - def testament_class(self): return StrictTestament class TestamentTestsStrict2(TestamentTests): - def testament_class(self): return StrictTestament3 @@ -530,28 +539,31 @@ def testament_class(self): texts = { - Testament: {'rev_1': REV_1_TESTAMENT, - 'rev_1_short': REV_1_SHORT, - 'rev_2': REV_2_TESTAMENT, - 'rev_2_short': REV_2_SHORT, - 'rev_3': REV_3_TESTAMENT, - 'rev_props': REV_PROPS_TESTAMENT, - 'sample_unicode': SAMPLE_UNICODE_TESTAMENT, - }, - StrictTestament: {'rev_1': REV_1_STRICT_TESTAMENT, - 'rev_1_short': REV_1_SHORT_STRICT, - 'rev_2': REV_2_STRICT_TESTAMENT, - 'rev_2_short': REV_2_SHORT_STRICT, - 'rev_3': REV_3_TESTAMENT_STRICT, - 'rev_props': REV_PROPS_TESTAMENT_STRICT, - 'sample_unicode': SAMPLE_UNICODE_TESTAMENT_STRICT, - }, - StrictTestament3: {'rev_1': REV_1_STRICT_TESTAMENT3, - 'rev_1_short': REV_1_SHORT_STRICT3, - 'rev_2': REV_2_STRICT_TESTAMENT3, - 'rev_2_short': REV_2_SHORT_STRICT3, - 'rev_3': REV_3_TESTAMENT_STRICT3, - 'rev_props': REV_PROPS_TESTAMENT_STRICT3, - 'sample_unicode': SAMPLE_UNICODE_TESTAMENT_STRICT3, - }, + Testament: { + "rev_1": REV_1_TESTAMENT, + "rev_1_short": REV_1_SHORT, + "rev_2": REV_2_TESTAMENT, + "rev_2_short": REV_2_SHORT, + "rev_3": REV_3_TESTAMENT, + "rev_props": REV_PROPS_TESTAMENT, + "sample_unicode": SAMPLE_UNICODE_TESTAMENT, + }, + StrictTestament: { + "rev_1": REV_1_STRICT_TESTAMENT, + "rev_1_short": REV_1_SHORT_STRICT, + "rev_2": REV_2_STRICT_TESTAMENT, + "rev_2_short": REV_2_SHORT_STRICT, + "rev_3": REV_3_TESTAMENT_STRICT, + "rev_props": REV_PROPS_TESTAMENT_STRICT, + "sample_unicode": SAMPLE_UNICODE_TESTAMENT_STRICT, + }, + StrictTestament3: { + "rev_1": REV_1_STRICT_TESTAMENT3, + "rev_1_short": REV_1_SHORT_STRICT3, + "rev_2": REV_2_STRICT_TESTAMENT3, + "rev_2_short": REV_2_SHORT_STRICT3, + "rev_3": 
REV_3_TESTAMENT_STRICT3, + "rev_props": REV_PROPS_TESTAMENT_STRICT3, + "sample_unicode": SAMPLE_UNICODE_TESTAMENT_STRICT3, + }, } diff --git a/breezy/bzr/tests/test_transform.py b/breezy/bzr/tests/test_transform.py index 040b4dec1d..c85a7d4883 100644 --- a/breezy/bzr/tests/test_transform.py +++ b/breezy/bzr/tests/test_transform.py @@ -28,85 +28,88 @@ class TestInventoryAltered(TestCaseWithTransport): - def test_inventory_altered_unchanged(self): - tree = self.make_branch_and_tree('tree') - self.build_tree(['tree/foo']) - tree.add('foo', ids=b'foo-id') + tree = self.make_branch_and_tree("tree") + self.build_tree(["tree/foo"]) + tree.add("foo", ids=b"foo-id") with tree.preview_transform() as tt: self.assertEqual([], tt._inventory_altered()) def test_inventory_altered_changed_parent_id(self): - tree = self.make_branch_and_tree('tree') - self.build_tree(['tree/foo']) - tree.add('foo', ids=b'foo-id') + tree = self.make_branch_and_tree("tree") + self.build_tree(["tree/foo"]) + tree.add("foo", ids=b"foo-id") with tree.preview_transform() as tt: tt.unversion_file(tt.root) - tt.version_file(tt.root, file_id=b'new-id') - foo_trans_id = tt.trans_id_tree_path('foo') - foo_tuple = ('foo', foo_trans_id) - root_tuple = ('', tt.root) + tt.version_file(tt.root, file_id=b"new-id") + foo_trans_id = tt.trans_id_tree_path("foo") + foo_tuple = ("foo", foo_trans_id) + root_tuple = ("", tt.root) self.assertEqual([root_tuple, foo_tuple], tt._inventory_altered()) def test_inventory_altered_noop_changed_parent_id(self): - tree = self.make_branch_and_tree('tree') - self.build_tree(['tree/foo']) - tree.add('foo', ids=b'foo-id') + tree = self.make_branch_and_tree("tree") + self.build_tree(["tree/foo"]) + tree.add("foo", ids=b"foo-id") with tree.preview_transform() as tt: tt.unversion_file(tt.root) - tt.version_file(tt.root, file_id=tree.path2id('')) - tt.trans_id_tree_path('foo') + tt.version_file(tt.root, file_id=tree.path2id("")) + tt.trans_id_tree_path("foo") self.assertEqual([], tt._inventory_altered()) class TestBuildTree(TestCaseWithTransport): - def test_build_tree_with_symlinks(self): self.requireFeature(features.SymlinkFeature(self.test_dir)) - os.mkdir('a') - a = ControlDir.create_standalone_workingtree('a') - os.mkdir('a/foo') - with open('a/foo/bar', 'wb') as f: - f.write(b'contents') - os.symlink('a/foo/bar', 'a/foo/baz') - a.add(['foo', 'foo/bar', 'foo/baz']) - a.commit('initial commit') - b = ControlDir.create_standalone_workingtree('b') + os.mkdir("a") + a = ControlDir.create_standalone_workingtree("a") + os.mkdir("a/foo") + with open("a/foo/bar", "wb") as f: + f.write(b"contents") + os.symlink("a/foo/bar", "a/foo/baz") + a.add(["foo", "foo/bar", "foo/baz"]) + a.commit("initial commit") + b = ControlDir.create_standalone_workingtree("b") basis = a.basis_tree() basis.lock_read() self.addCleanup(basis.unlock) build_tree(basis, b) - self.assertTrue(os.path.isdir('b/foo')) - with open('b/foo/bar', 'rb') as f: + self.assertTrue(os.path.isdir("b/foo")) + with open("b/foo/bar", "rb") as f: self.assertEqual(f.read(), b"contents") - self.assertEqual(os.readlink('b/foo/baz'), 'a/foo/bar') + self.assertEqual(os.readlink("b/foo/baz"), "a/foo/bar") def test_build_with_references(self): - tree = self.make_branch_and_tree('source', - format='development-subtree') - subtree = self.make_branch_and_tree('source/subtree', - format='development-subtree') + tree = self.make_branch_and_tree("source", format="development-subtree") + subtree = self.make_branch_and_tree( + "source/subtree", format="development-subtree" + ) 
tree.add_reference(subtree) - tree.commit('a revision') - tree.branch.create_checkout('target') - self.assertPathExists('target') - self.assertPathExists('target/subtree') + tree.commit("a revision") + tree.branch.create_checkout("target") + self.assertPathExists("target") + self.assertPathExists("target/subtree") def test_file_conflict_handling(self): """Ensure that when building trees, conflict handling is done.""" - source = self.make_branch_and_tree('source') - target = self.make_branch_and_tree('target') - self.build_tree(['source/file', 'target/file']) - source.add('file', ids=b'new-file') - source.commit('added file') + source = self.make_branch_and_tree("source") + target = self.make_branch_and_tree("target") + self.build_tree(["source/file", "target/file"]) + source.add("file", ids=b"new-file") + source.commit("added file") build_tree(source.basis_tree(), target) self.assertEqual( - [DuplicateEntry('Moved existing file to', 'file.moved', - 'file', None, 'new-file')], - target.conflicts()) - target2 = self.make_branch_and_tree('target2') - with open('target2/file', 'wb') as target_file, \ - open('source/file', 'rb') as source_file: + [ + DuplicateEntry( + "Moved existing file to", "file.moved", "file", None, "new-file" + ) + ], + target.conflicts(), + ) + target2 = self.make_branch_and_tree("target2") + with open("target2/file", "wb") as target_file, open( + "source/file", "rb" + ) as source_file: target_file.write(source_file.read()) build_tree(source.basis_tree(), target2) self.assertEqual([], target2.conflicts()) @@ -114,155 +117,173 @@ def test_file_conflict_handling(self): def test_symlink_conflict_handling(self): """Ensure that when building trees, conflict handling is done.""" self.requireFeature(features.SymlinkFeature(self.test_dir)) - source = self.make_branch_and_tree('source') - os.symlink('foo', 'source/symlink') - source.add('symlink', ids=b'new-symlink') - source.commit('added file') - target = self.make_branch_and_tree('target') - os.symlink('bar', 'target/symlink') + source = self.make_branch_and_tree("source") + os.symlink("foo", "source/symlink") + source.add("symlink", ids=b"new-symlink") + source.commit("added file") + target = self.make_branch_and_tree("target") + os.symlink("bar", "target/symlink") build_tree(source.basis_tree(), target) self.assertEqual( - [DuplicateEntry('Moved existing file to', 'symlink.moved', - 'symlink', None, 'new-symlink')], - target.conflicts()) - target = self.make_branch_and_tree('target2') - os.symlink('foo', 'target2/symlink') + [ + DuplicateEntry( + "Moved existing file to", + "symlink.moved", + "symlink", + None, + "new-symlink", + ) + ], + target.conflicts(), + ) + target = self.make_branch_and_tree("target2") + os.symlink("foo", "target2/symlink") build_tree(source.basis_tree(), target) self.assertEqual([], target.conflicts()) def test_directory_conflict_handling(self): """Ensure that when building trees, conflict handling is done.""" - source = self.make_branch_and_tree('source') - target = self.make_branch_and_tree('target') - self.build_tree(['source/dir1/', 'source/dir1/file', 'target/dir1/']) - source.add(['dir1', 'dir1/file'], ids=[b'new-dir1', b'new-file']) - source.commit('added file') + source = self.make_branch_and_tree("source") + target = self.make_branch_and_tree("target") + self.build_tree(["source/dir1/", "source/dir1/file", "target/dir1/"]) + source.add(["dir1", "dir1/file"], ids=[b"new-dir1", b"new-file"]) + source.commit("added file") build_tree(source.basis_tree(), target) self.assertEqual([], 
target.conflicts()) - self.assertPathExists('target/dir1/file') + self.assertPathExists("target/dir1/file") # Ensure contents are merged - target = self.make_branch_and_tree('target2') - self.build_tree(['target2/dir1/', 'target2/dir1/file2']) + target = self.make_branch_and_tree("target2") + self.build_tree(["target2/dir1/", "target2/dir1/file2"]) build_tree(source.basis_tree(), target) self.assertEqual([], target.conflicts()) - self.assertPathExists('target2/dir1/file2') - self.assertPathExists('target2/dir1/file') + self.assertPathExists("target2/dir1/file2") + self.assertPathExists("target2/dir1/file") # Ensure new contents are suppressed for existing branches - target = self.make_branch_and_tree('target3') - self.make_branch('target3/dir1') - self.build_tree(['target3/dir1/file2']) + target = self.make_branch_and_tree("target3") + self.make_branch("target3/dir1") + self.build_tree(["target3/dir1/file2"]) build_tree(source.basis_tree(), target) - self.assertPathDoesNotExist('target3/dir1/file') - self.assertPathExists('target3/dir1/file2') - self.assertPathExists('target3/dir1.diverted/file') + self.assertPathDoesNotExist("target3/dir1/file") + self.assertPathExists("target3/dir1/file2") + self.assertPathExists("target3/dir1.diverted/file") self.assertEqual( - [DuplicateEntry('Diverted to', 'dir1.diverted', - 'dir1', 'new-dir1', None)], - target.conflicts()) + [DuplicateEntry("Diverted to", "dir1.diverted", "dir1", "new-dir1", None)], + target.conflicts(), + ) - target = self.make_branch_and_tree('target4') - self.build_tree(['target4/dir1/']) - self.make_branch('target4/dir1/file') + target = self.make_branch_and_tree("target4") + self.build_tree(["target4/dir1/"]) + self.make_branch("target4/dir1/file") build_tree(source.basis_tree(), target) - self.assertPathExists('target4/dir1/file') - self.assertEqual('directory', file_kind('target4/dir1/file')) - self.assertPathExists('target4/dir1/file.diverted') + self.assertPathExists("target4/dir1/file") + self.assertEqual("directory", file_kind("target4/dir1/file")) + self.assertPathExists("target4/dir1/file.diverted") self.assertEqual( - [DuplicateEntry('Diverted to', 'dir1/file.diverted', - 'dir1/file', 'new-file', None)], - target.conflicts()) + [ + DuplicateEntry( + "Diverted to", "dir1/file.diverted", "dir1/file", "new-file", None + ) + ], + target.conflicts(), + ) def test_mixed_conflict_handling(self): """Ensure that when building trees, conflict handling is done.""" - source = self.make_branch_and_tree('source') - target = self.make_branch_and_tree('target') - self.build_tree(['source/name', 'target/name/']) - source.add('name', ids=b'new-name') - source.commit('added file') + source = self.make_branch_and_tree("source") + target = self.make_branch_and_tree("target") + self.build_tree(["source/name", "target/name/"]) + source.add("name", ids=b"new-name") + source.commit("added file") build_tree(source.basis_tree(), target) self.assertEqual( - [DuplicateEntry('Moved existing file to', - 'name.moved', 'name', None, 'new-name')], - target.conflicts()) + [ + DuplicateEntry( + "Moved existing file to", "name.moved", "name", None, "new-name" + ) + ], + target.conflicts(), + ) def test_raises_in_populated(self): - source = self.make_branch_and_tree('source') - self.build_tree(['source/name']) - source.add('name') - source.commit('added name') - target = self.make_branch_and_tree('target') - self.build_tree(['target/name']) - target.add('name') - self.assertRaises(errors.WorkingTreeAlreadyPopulated, - build_tree, source.basis_tree(), 
target) + source = self.make_branch_and_tree("source") + self.build_tree(["source/name"]) + source.add("name") + source.commit("added name") + target = self.make_branch_and_tree("target") + self.build_tree(["target/name"]) + target.add("name") + self.assertRaises( + errors.WorkingTreeAlreadyPopulated, build_tree, source.basis_tree(), target + ) def test_build_tree_rename_count(self): - source = self.make_branch_and_tree('source') - self.build_tree(['source/file1', 'source/dir1/']) - source.add(['file1', 'dir1']) - source.commit('add1') - target1 = self.make_branch_and_tree('target1') + source = self.make_branch_and_tree("source") + self.build_tree(["source/file1", "source/dir1/"]) + source.add(["file1", "dir1"]) + source.commit("add1") + target1 = self.make_branch_and_tree("target1") transform_result = build_tree(source.basis_tree(), target1) self.assertEqual(2, transform_result.rename_count) - self.build_tree(['source/dir1/file2']) - source.add(['dir1/file2']) - source.commit('add3') - target2 = self.make_branch_and_tree('target2') + self.build_tree(["source/dir1/file2"]) + source.add(["dir1/file2"]) + source.commit("add3") + target2 = self.make_branch_and_tree("target2") transform_result = build_tree(source.basis_tree(), target2) # children of non-root directories should not be renamed self.assertEqual(2, transform_result.rename_count) def create_ab_tree(self): """Create a committed test tree with two files.""" - source = self.make_branch_and_tree('source') - self.build_tree_contents([('source/file1', b'A')]) - self.build_tree_contents([('source/file2', b'B')]) - source.add(['file1', 'file2'], ids=[b'file1-id', b'file2-id']) - source.commit('commit files') + source = self.make_branch_and_tree("source") + self.build_tree_contents([("source/file1", b"A")]) + self.build_tree_contents([("source/file2", b"B")]) + source.add(["file1", "file2"], ids=[b"file1-id", b"file2-id"]) + source.commit("commit files") source.lock_write() self.addCleanup(source.unlock) return source def test_build_tree_accelerator_tree(self): source = self.create_ab_tree() - self.build_tree_contents([('source/file2', b'C')]) + self.build_tree_contents([("source/file2", b"C")]) calls = [] real_source_get_file = source.get_file def get_file(path): calls.append(path) return real_source_get_file(path) + source.get_file = get_file - target = self.make_branch_and_tree('target') + target = self.make_branch_and_tree("target") revision_tree = source.basis_tree() revision_tree.lock_read() self.addCleanup(revision_tree.unlock) build_tree(revision_tree, target, source) - self.assertEqual(['file1'], calls) + self.assertEqual(["file1"], calls) target.lock_read() self.addCleanup(target.unlock) self.assertEqual([], list(target.iter_changes(revision_tree))) def test_build_tree_accelerator_tree_observes_sha1(self): source = self.create_ab_tree() - sha1 = osutils.sha_string(b'A') - target = self.make_branch_and_tree('target') + sha1 = osutils.sha_string(b"A") + target = self.make_branch_and_tree("target") target.lock_write() self.addCleanup(target.unlock) state = target.current_dirstate() state._cutoff_time = time.time() + 60 build_tree(source.basis_tree(), target, source) - entry = state._get_entry(0, path_utf8=b'file1') + entry = state._get_entry(0, path_utf8=b"file1") self.assertEqual(sha1, entry[1][0][1]) def test_build_tree_accelerator_tree_missing_file(self): source = self.create_ab_tree() - os.unlink('source/file1') - source.remove(['file2']) - target = self.make_branch_and_tree('target') + os.unlink("source/file1") + 
source.remove(["file2"]) + target = self.make_branch_and_tree("target") revision_tree = source.basis_tree() revision_tree.lock_read() self.addCleanup(revision_tree.unlock) @@ -273,23 +294,24 @@ def test_build_tree_accelerator_tree_missing_file(self): def test_build_tree_accelerator_wrong_kind(self): self.requireFeature(features.SymlinkFeature(self.test_dir)) - source = self.make_branch_and_tree('source') - self.build_tree_contents([('source/file1', b'')]) - self.build_tree_contents([('source/file2', b'')]) - source.add(['file1', 'file2'], ids=[b'file1-id', b'file2-id']) - source.commit('commit files') - os.unlink('source/file2') - self.build_tree_contents([('source/file2/', b'C')]) - os.unlink('source/file1') - os.symlink('file2', 'source/file1') + source = self.make_branch_and_tree("source") + self.build_tree_contents([("source/file1", b"")]) + self.build_tree_contents([("source/file2", b"")]) + source.add(["file1", "file2"], ids=[b"file1-id", b"file2-id"]) + source.commit("commit files") + os.unlink("source/file2") + self.build_tree_contents([("source/file2/", b"C")]) + os.unlink("source/file1") + os.symlink("file2", "source/file1") calls = [] real_source_get_file = source.get_file def get_file(path): calls.append(path) return real_source_get_file(path) + source.get_file = get_file - target = self.make_branch_and_tree('target') + target = self.make_branch_and_tree("target") revision_tree = source.basis_tree() revision_tree.lock_read() self.addCleanup(revision_tree.unlock) @@ -302,7 +324,7 @@ def get_file(path): def test_build_tree_hardlink(self): self.requireFeature(features.HardlinkFeature(self.test_dir)) source = self.create_ab_tree() - target = self.make_branch_and_tree('target') + target = self.make_branch_and_tree("target") revision_tree = source.basis_tree() revision_tree.lock_read() self.addCleanup(revision_tree.unlock) @@ -310,29 +332,29 @@ def test_build_tree_hardlink(self): target.lock_read() self.addCleanup(target.unlock) self.assertEqual([], list(target.iter_changes(revision_tree))) - source_stat = os.stat('source/file1') - target_stat = os.stat('target/file1') + source_stat = os.stat("source/file1") + target_stat = os.stat("target/file1") self.assertEqual(source_stat, target_stat) # Explicitly disallowing hardlinks should prevent them. 
- target2 = self.make_branch_and_tree('target2') + target2 = self.make_branch_and_tree("target2") build_tree(revision_tree, target2, source, hardlink=False) target2.lock_read() self.addCleanup(target2.unlock) self.assertEqual([], list(target2.iter_changes(revision_tree))) - source_stat = os.stat('source/file1') - target2_stat = os.stat('target2/file1') + source_stat = os.stat("source/file1") + target2_stat = os.stat("target2/file1") self.assertNotEqual(source_stat, target2_stat) def test_build_tree_accelerator_tree_moved(self): - source = self.make_branch_and_tree('source') - self.build_tree_contents([('source/file1', b'A')]) - source.add(['file1'], ids=[b'file1-id']) - source.commit('commit files') - source.rename_one('file1', 'file2') + source = self.make_branch_and_tree("source") + self.build_tree_contents([("source/file1", b"A")]) + source.add(["file1"], ids=[b"file1-id"]) + source.commit("commit files") + source.rename_one("file1", "file2") source.lock_read() self.addCleanup(source.unlock) - target = self.make_branch_and_tree('target') + target = self.make_branch_and_tree("target") revision_tree = source.basis_tree() revision_tree.lock_read() self.addCleanup(revision_tree.unlock) @@ -345,11 +367,11 @@ def test_build_tree_hardlinks_preserve_execute(self): self.requireFeature(features.HardlinkFeature(self.test_dir)) source = self.create_ab_tree() tt = source.transform() - trans_id = tt.trans_id_tree_path('file1') + trans_id = tt.trans_id_tree_path("file1") tt.set_executability(True, trans_id) tt.apply() - self.assertTrue(source.is_executable('file1')) - target = self.make_branch_and_tree('target') + self.assertTrue(source.is_executable("file1")) + target = self.make_branch_and_tree("target") revision_tree = source.basis_tree() revision_tree.lock_read() self.addCleanup(revision_tree.unlock) @@ -357,7 +379,7 @@ def test_build_tree_hardlinks_preserve_execute(self): target.lock_read() self.addCleanup(target.unlock) self.assertEqual([], list(target.iter_changes(revision_tree))) - self.assertTrue(source.is_executable('file1')) + self.assertTrue(source.is_executable("file1")) def install_rot13_content_filter(self, pattern): # We could use @@ -368,23 +390,26 @@ def install_rot13_content_filter(self, pattern): def restore_registry(): filters._reset_registry(original_registry) + self.addCleanup(restore_registry) def rot13(chunks, context=None): return [ - codecs.encode(chunk.decode('ascii'), 'rot13').encode('ascii') - for chunk in chunks] + codecs.encode(chunk.decode("ascii"), "rot13").encode("ascii") + for chunk in chunks + ] + rot13filter = filters.ContentFilter(rot13, rot13) - filters.filter_stacks_registry.register( - 'rot13', {'yes': [rot13filter]}.get) - os.mkdir(self.test_home_dir + '/.bazaar') - rules_filename = self.test_home_dir + '/.bazaar/rules' - with open(rules_filename, 'wb') as f: - f.write(b'[name %s]\nrot13=yes\n' % (pattern,)) + filters.filter_stacks_registry.register("rot13", {"yes": [rot13filter]}.get) + os.mkdir(self.test_home_dir + "/.bazaar") + rules_filename = self.test_home_dir + "/.bazaar/rules" + with open(rules_filename, "wb") as f: + f.write(b"[name %s]\nrot13=yes\n" % (pattern,)) def uninstall_rules(): os.remove(rules_filename) rules.reset_rules() + self.addCleanup(uninstall_rules) rules.reset_rules() @@ -394,9 +419,9 @@ def test_build_tree_content_filtered_files_are_not_hardlinked(self): if it can). 
""" self.requireFeature(features.HardlinkFeature(self.test_dir)) - self.install_rot13_content_filter(b'file1') + self.install_rot13_content_filter(b"file1") source = self.create_ab_tree() - target = self.make_branch_and_tree('target') + target = self.make_branch_and_tree("target") revision_tree = source.basis_tree() revision_tree.lock_read() self.addCleanup(revision_tree.unlock) @@ -404,36 +429,39 @@ def test_build_tree_content_filtered_files_are_not_hardlinked(self): target.lock_read() self.addCleanup(target.unlock) self.assertEqual([], list(target.iter_changes(revision_tree))) - source_stat = os.stat('source/file1') - target_stat = os.stat('target/file1') + source_stat = os.stat("source/file1") + target_stat = os.stat("target/file1") self.assertNotEqual(source_stat, target_stat) - source_stat = os.stat('source/file2') - target_stat = os.stat('target/file2') + source_stat = os.stat("source/file2") + target_stat = os.stat("target/file2") self.assertEqualStat(source_stat, target_stat) def test_case_insensitive_build_tree_inventory(self): - if (features.CaseInsensitiveFilesystemFeature.available() - or features.CaseInsCasePresFilenameFeature.available()): - raise UnavailableFeature('Fully case sensitive filesystem') - source = self.make_branch_and_tree('source') - self.build_tree(['source/file', 'source/FILE']) - source.add(['file', 'FILE'], ids=[b'lower-id', b'upper-id']) - source.commit('added files') + if ( + features.CaseInsensitiveFilesystemFeature.available() + or features.CaseInsCasePresFilenameFeature.available() + ): + raise UnavailableFeature("Fully case sensitive filesystem") + source = self.make_branch_and_tree("source") + self.build_tree(["source/file", "source/FILE"]) + source.add(["file", "FILE"], ids=[b"lower-id", b"upper-id"]) + source.commit("added files") # Don't try this at home, kids! # Force the tree to report that it is case insensitive - target = self.make_branch_and_tree('target') + target = self.make_branch_and_tree("target") target.case_sensitive = False build_tree(source.basis_tree(), target, source, delta_from_tree=True) - self.assertEqual('file.moved', target.id2path(b'lower-id')) - self.assertEqual('FILE', target.id2path(b'upper-id')) + self.assertEqual("file.moved", target.id2path(b"lower-id")) + self.assertEqual("FILE", target.id2path(b"upper-id")) def test_build_tree_observes_sha(self): - source = self.make_branch_and_tree('source') - self.build_tree(['source/file1', 'source/dir/', 'source/dir/file2']) - source.add(['file1', 'dir', 'dir/file2'], - ids=[b'file1-id', b'dir-id', b'file2-id']) - source.commit('new files') - target = self.make_branch_and_tree('target') + source = self.make_branch_and_tree("source") + self.build_tree(["source/file1", "source/dir/", "source/dir/file2"]) + source.add( + ["file1", "dir", "dir/file2"], ids=[b"file1-id", b"dir-id", b"file2-id"] + ) + source.commit("new files") + target = self.make_branch_and_tree("target") target.lock_write() self.addCleanup(target.unlock) # We make use of the fact that DirState caches its cutoff time. 
So we @@ -441,23 +469,23 @@ def test_build_tree_observes_sha(self): state = target.current_dirstate() state._cutoff_time = time.time() + 60 build_tree(source.basis_tree(), target) - entry1_sha = osutils.sha_file_by_name('source/file1') - entry2_sha = osutils.sha_file_by_name('source/dir/file2') + entry1_sha = osutils.sha_file_by_name("source/file1") + entry2_sha = osutils.sha_file_by_name("source/dir/file2") # entry[1] is the state information, entry[1][0] is the state of the # working tree, entry[1][0][1] is the sha value for the current working # tree - entry1 = state._get_entry(0, path_utf8=b'file1') + entry1 = state._get_entry(0, path_utf8=b"file1") self.assertEqual(entry1_sha, entry1[1][0][1]) # The 'size' field must also be set. self.assertEqual(25, entry1[1][0][2]) entry1_state = entry1[1][0] - entry2 = state._get_entry(0, path_utf8=b'dir/file2') + entry2 = state._get_entry(0, path_utf8=b"dir/file2") self.assertEqual(entry2_sha, entry2[1][0][1]) self.assertEqual(29, entry2[1][0][2]) entry2_state = entry2[1][0] # Now, make sure that we don't have to re-read the content. The # packed_stat should match exactly. - self.assertEqual(entry1_sha, target.get_file_sha1('file1')) - self.assertEqual(entry2_sha, target.get_file_sha1('dir/file2')) + self.assertEqual(entry1_sha, target.get_file_sha1("file1")) + self.assertEqual(entry2_sha, target.get_file_sha1("dir/file2")) self.assertEqual(entry1_state, entry1[1][0]) self.assertEqual(entry2_state, entry2[1][0]) diff --git a/breezy/bzr/tests/test_tuned_gzip.py b/breezy/bzr/tests/test_tuned_gzip.py index 36ad356e6b..f25c2834cc 100644 --- a/breezy/bzr/tests/test_tuned_gzip.py +++ b/breezy/bzr/tests/test_tuned_gzip.py @@ -24,26 +24,26 @@ class TestToGzip(tests.TestCase): - def assertToGzip(self, chunks): - raw_bytes = b''.join(chunks) + raw_bytes = b"".join(chunks) gzfromchunks = tuned_gzip.chunks_to_gzip(chunks) - decoded = gzip.GzipFile(fileobj=BytesIO(b''.join(gzfromchunks))).read() + decoded = gzip.GzipFile(fileobj=BytesIO(b"".join(gzfromchunks))).read() lraw, ldecoded = len(raw_bytes), len(decoded) - self.assertEqual(lraw, ldecoded, - 'Expecting data length %d, got %d' % (lraw, ldecoded)) + self.assertEqual( + lraw, ldecoded, "Expecting data length %d, got %d" % (lraw, ldecoded) + ) self.assertEqual(raw_bytes, decoded) def test_single_chunk(self): - self.assertToGzip([b'a modest chunk\nwith some various\nbits\n']) + self.assertToGzip([b"a modest chunk\nwith some various\nbits\n"]) def test_simple_text(self): - self.assertToGzip([b'some\n', b'strings\n', b'to\n', b'process\n']) + self.assertToGzip([b"some\n", b"strings\n", b"to\n", b"process\n"]) def test_large_chunks(self): - self.assertToGzip([b'a large string\n' * 1024]) - self.assertToGzip([b'a large string\n'] * 1024) + self.assertToGzip([b"a large string\n" * 1024]) + self.assertToGzip([b"a large string\n"] * 1024) def test_enormous_chunks(self): - self.assertToGzip([b'a large string\n' * 1024 * 256]) - self.assertToGzip([b'a large string\n'] * 1024 * 256) + self.assertToGzip([b"a large string\n" * 1024 * 256]) + self.assertToGzip([b"a large string\n"] * 1024 * 256) diff --git a/breezy/bzr/tests/test_versionedfile.py b/breezy/bzr/tests/test_versionedfile.py index c7c741caea..4f232c12a6 100644 --- a/breezy/bzr/tests/test_versionedfile.py +++ b/breezy/bzr/tests/test_versionedfile.py @@ -24,119 +24,123 @@ class Test_MPDiffGenerator(tests.TestCaseWithMemoryTransport): # Should this be a per vf test? 
def make_vf(self): - t = self.get_transport('') + t = self.get_transport("") factory = groupcompress.make_pack_factory(True, True, 1) return factory(t) def make_three_vf(self): vf = self.make_vf() - vf.add_lines((b'one',), (), [b'first\n']) - vf.add_lines((b'two',), [(b'one',)], [b'first\n', b'second\n']) - vf.add_lines((b'three',), [(b'one',), (b'two',)], - [b'first\n', b'second\n', b'third\n']) + vf.add_lines((b"one",), (), [b"first\n"]) + vf.add_lines((b"two",), [(b"one",)], [b"first\n", b"second\n"]) + vf.add_lines( + (b"three",), [(b"one",), (b"two",)], [b"first\n", b"second\n", b"third\n"] + ) return vf def test_finds_parents(self): vf = self.make_three_vf() - gen = versionedfile._MPDiffGenerator(vf, [(b'three',)]) + gen = versionedfile._MPDiffGenerator(vf, [(b"three",)]) needed_keys, refcount = gen._find_needed_keys() - self.assertEqual(sorted([(b'one',), (b'two',), (b'three',)]), - sorted(needed_keys)) - self.assertEqual({(b'one',): 1, (b'two',): 1}, refcount) + self.assertEqual( + sorted([(b"one",), (b"two",), (b"three",)]), sorted(needed_keys) + ) + self.assertEqual({(b"one",): 1, (b"two",): 1}, refcount) def test_ignores_ghost_parents(self): # If a parent is a ghost, it is just ignored vf = self.make_vf() - vf.add_lines((b'two',), [(b'one',)], [b'first\n', b'second\n']) - gen = versionedfile._MPDiffGenerator(vf, [(b'two',)]) + vf.add_lines((b"two",), [(b"one",)], [b"first\n", b"second\n"]) + gen = versionedfile._MPDiffGenerator(vf, [(b"two",)]) needed_keys, refcount = gen._find_needed_keys() - self.assertEqual(sorted([(b'two',)]), sorted(needed_keys)) + self.assertEqual(sorted([(b"two",)]), sorted(needed_keys)) # It is returned, but we don't really care as we won't extract it - self.assertEqual({(b'one',): 1}, refcount) - self.assertEqual([(b'one',)], sorted(gen.ghost_parents)) + self.assertEqual({(b"one",): 1}, refcount) + self.assertEqual([(b"one",)], sorted(gen.ghost_parents)) self.assertEqual([], sorted(gen.present_parents)) def test_raises_on_ghost_keys(self): # If the requested key is a ghost, then we have a problem vf = self.make_vf() - gen = versionedfile._MPDiffGenerator(vf, [(b'one',)]) - self.assertRaises(errors.RevisionNotPresent, - gen._find_needed_keys) + gen = versionedfile._MPDiffGenerator(vf, [(b"one",)]) + self.assertRaises(errors.RevisionNotPresent, gen._find_needed_keys) def test_refcount_multiple_children(self): vf = self.make_three_vf() - gen = versionedfile._MPDiffGenerator(vf, [(b'two',), (b'three',)]) + gen = versionedfile._MPDiffGenerator(vf, [(b"two",), (b"three",)]) needed_keys, refcount = gen._find_needed_keys() - self.assertEqual(sorted([(b'one',), (b'two',), (b'three',)]), - sorted(needed_keys)) - self.assertEqual({(b'one',): 2, (b'two',): 1}, refcount) - self.assertEqual([(b'one',)], sorted(gen.present_parents)) + self.assertEqual( + sorted([(b"one",), (b"two",), (b"three",)]), sorted(needed_keys) + ) + self.assertEqual({(b"one",): 2, (b"two",): 1}, refcount) + self.assertEqual([(b"one",)], sorted(gen.present_parents)) def test_process_contents(self): vf = self.make_three_vf() - gen = versionedfile._MPDiffGenerator(vf, [(b'two',), (b'three',)]) + gen = versionedfile._MPDiffGenerator(vf, [(b"two",), (b"three",)]) gen._find_needed_keys() - self.assertEqual({(b'two',): ((b'one',),), - (b'three',): ((b'one',), (b'two',))}, - gen.parent_map) - self.assertEqual({(b'one',): 2, (b'two',): 1}, gen.refcounts) - self.assertEqual(sorted([(b'one',), (b'two',), (b'three',)]), - sorted(gen.needed_keys)) - stream = vf.get_record_stream(gen.needed_keys, 
'topological', True) + self.assertEqual( + {(b"two",): ((b"one",),), (b"three",): ((b"one",), (b"two",))}, + gen.parent_map, + ) + self.assertEqual({(b"one",): 2, (b"two",): 1}, gen.refcounts) + self.assertEqual( + sorted([(b"one",), (b"two",), (b"three",)]), sorted(gen.needed_keys) + ) + stream = vf.get_record_stream(gen.needed_keys, "topological", True) record = next(stream) - self.assertEqual((b'one',), record.key) + self.assertEqual((b"one",), record.key) # one is not needed in the output, but it is needed by children. As # such, it should end up in the various caches - gen._process_one_record(record.key, record.get_bytes_as('chunked')) + gen._process_one_record(record.key, record.get_bytes_as("chunked")) # The chunks should be cached, the refcount untouched - self.assertEqual({(b'one',)}, set(gen.chunks)) - self.assertEqual({(b'one',): 2, (b'two',): 1}, gen.refcounts) + self.assertEqual({(b"one",)}, set(gen.chunks)) + self.assertEqual({(b"one",): 2, (b"two",): 1}, gen.refcounts) self.assertEqual(set(), set(gen.diffs)) # Next we get 'two', which is something we output, but also needed for # three record = next(stream) - self.assertEqual((b'two',), record.key) - gen._process_one_record(record.key, record.get_bytes_as('chunked')) + self.assertEqual((b"two",), record.key) + gen._process_one_record(record.key, record.get_bytes_as("chunked")) # Both are now cached, and the diff for two has been extracted, and # one's refcount has been updated. two has been removed from the # parent_map - self.assertEqual({(b'one',), (b'two',)}, set(gen.chunks)) - self.assertEqual({(b'one',): 1, (b'two',): 1}, gen.refcounts) - self.assertEqual({(b'two',)}, set(gen.diffs)) - self.assertEqual({(b'three',): ((b'one',), (b'two',))}, - gen.parent_map) + self.assertEqual({(b"one",), (b"two",)}, set(gen.chunks)) + self.assertEqual({(b"one",): 1, (b"two",): 1}, gen.refcounts) + self.assertEqual({(b"two",)}, set(gen.diffs)) + self.assertEqual({(b"three",): ((b"one",), (b"two",))}, gen.parent_map) # Finally 'three', which allows us to remove all parents from the # caches record = next(stream) - self.assertEqual((b'three',), record.key) - gen._process_one_record(record.key, record.get_bytes_as('chunked')) + self.assertEqual((b"three",), record.key) + gen._process_one_record(record.key, record.get_bytes_as("chunked")) # Both are now cached, and the diff for two has been extracted, and # one's refcount has been updated self.assertEqual(set(), set(gen.chunks)) self.assertEqual({}, gen.refcounts) - self.assertEqual({(b'two',), (b'three',)}, set(gen.diffs)) + self.assertEqual({(b"two",), (b"three",)}, set(gen.diffs)) def test_compute_diffs(self): vf = self.make_three_vf() # The content is in the order requested, even if it isn't topological - gen = versionedfile._MPDiffGenerator(vf, [(b'two',), (b'three',), - (b'one',)]) + gen = versionedfile._MPDiffGenerator(vf, [(b"two",), (b"three",), (b"one",)]) diffs = gen.compute_diffs() expected_diffs = [ - multiparent.MultiParent([multiparent.ParentText(0, 0, 0, 1), - multiparent.NewText([b'second\n'])]), - multiparent.MultiParent([multiparent.ParentText(1, 0, 0, 2), - multiparent.NewText([b'third\n'])]), - multiparent.MultiParent([multiparent.NewText([b'first\n'])]), - ] + multiparent.MultiParent( + [multiparent.ParentText(0, 0, 0, 1), multiparent.NewText([b"second\n"])] + ), + multiparent.MultiParent( + [multiparent.ParentText(1, 0, 0, 2), multiparent.NewText([b"third\n"])] + ), + multiparent.MultiParent([multiparent.NewText([b"first\n"])]), + ] self.assertEqual(expected_diffs, 
diffs) class ErrorTests(tests.TestCase): - def test_unavailable_representation(self): - error = versionedfile.UnavailableRepresentation( - ('key',), "mpdiff", "fulltext") - self.assertEqualDiff("The encoding 'mpdiff' is not available for key " - "('key',) which is encoded as 'fulltext'.", - str(error)) + error = versionedfile.UnavailableRepresentation(("key",), "mpdiff", "fulltext") + self.assertEqualDiff( + "The encoding 'mpdiff' is not available for key " + "('key',) which is encoded as 'fulltext'.", + str(error), + ) diff --git a/breezy/bzr/tests/test_vf_search.py b/breezy/bzr/tests/test_vf_search.py index faf7c8bd8f..b36888e6b1 100644 --- a/breezy/bzr/tests/test_vf_search.py +++ b/breezy/bzr/tests/test_vf_search.py @@ -31,11 +31,13 @@ # rev3 / # | / # rev4 -ancestry_1 = {b'rev1': [NULL_REVISION], - b'rev2a': [b'rev1'], - b'rev2b': [b'rev1'], - b'rev3': [b'rev2a'], - b'rev4': [b'rev3', b'rev2b']} +ancestry_1 = { + b"rev1": [NULL_REVISION], + b"rev2a": [b"rev1"], + b"rev2b": [b"rev1"], + b"rev3": [b"rev2a"], + b"rev4": [b"rev3", b"rev2b"], +} # Ancestry 2: # @@ -48,11 +50,13 @@ # rev3a # | # rev4a -ancestry_2 = {b'rev1a': [NULL_REVISION], - b'rev2a': [b'rev1a'], - b'rev1b': [NULL_REVISION], - b'rev3a': [b'rev2a'], - b'rev4a': [b'rev3a']} +ancestry_2 = { + b"rev1a": [NULL_REVISION], + b"rev2a": [b"rev1a"], + b"rev1b": [NULL_REVISION], + b"rev3a": [b"rev2a"], + b"rev4a": [b"rev3a"], +} # Extended history shortcut @@ -67,17 +71,17 @@ # d | # |\| # e f -extended_history_shortcut = {b'a': [NULL_REVISION], - b'b': [b'a'], - b'c': [b'b'], - b'd': [b'c'], - b'e': [b'd'], - b'f': [b'a', b'd'], - } +extended_history_shortcut = { + b"a": [NULL_REVISION], + b"b": [b"a"], + b"c": [b"b"], + b"d": [b"c"], + b"e": [b"d"], + b"f": [b"a", b"d"], +} class TestSearchResultRefine(tests.TestCase): - def make_graph(self, ancestors): return _mod_graph.Graph(_mod_graph.DictParentsProvider(ancestors)) @@ -85,20 +89,25 @@ def test_refine(self): # Used when pulling from a stacked repository, so test some revisions # being satisfied from the stacking branch. self.make_graph( - {b"tip": [b"mid"], b"mid": [b"base"], b"tag": [b"base"], - b"base": [NULL_REVISION], NULL_REVISION: []}) + { + b"tip": [b"mid"], + b"mid": [b"base"], + b"tag": [b"base"], + b"base": [NULL_REVISION], + NULL_REVISION: [], + } + ) result = vf_search.SearchResult( - {b'tip', b'tag'}, - {NULL_REVISION}, 4, {b'tip', b'mid', b'tag', b'base'}) - result = result.refine({b'tip'}, {b'mid'}) + {b"tip", b"tag"}, {NULL_REVISION}, 4, {b"tip", b"mid", b"tag", b"base"} + ) + result = result.refine({b"tip"}, {b"mid"}) recipe = result.get_recipe() # We should be starting from tag (original head) and mid (seen ref) - self.assertEqual({b'mid', b'tag'}, recipe[1]) + self.assertEqual({b"mid", b"tag"}, recipe[1]) # We should be stopping at NULL (original stop) and tip (seen head) - self.assertEqual({NULL_REVISION, b'tip'}, recipe[2]) + self.assertEqual({NULL_REVISION, b"tip"}, recipe[2]) self.assertEqual(3, recipe[3]) - result = result.refine({b'mid', b'tag', b'base'}, - {NULL_REVISION}) + result = result.refine({b"mid", b"tag", b"base"}, {NULL_REVISION}) recipe = result.get_recipe() # We should be starting from nothing (NULL was known as a cut point) self.assertEqual(set(), recipe[1]) @@ -106,83 +115,100 @@ def test_refine(self): # tag (seen head) and mid(seen mid-point head). We could come back and # define this as not including mid, for minimal results, but it is # still 'correct' to include mid, and simpler/easier. 
- self.assertEqual({NULL_REVISION, b'tip', b'tag', b'mid'}, recipe[2]) + self.assertEqual({NULL_REVISION, b"tip", b"tag", b"mid"}, recipe[2]) self.assertEqual(0, recipe[3]) self.assertTrue(result.is_empty()) class TestSearchResultFromParentMap(TestGraphBase): - - def assertSearchResult(self, start_keys, stop_keys, key_count, parent_map, - missing_keys=()): + def assertSearchResult( + self, start_keys, stop_keys, key_count, parent_map, missing_keys=() + ): (start, stop, count) = vf_search.search_result_from_parent_map( - parent_map, missing_keys) - self.assertEqual((sorted(start_keys), sorted(stop_keys), key_count), - (sorted(start), sorted(stop), count)) + parent_map, missing_keys + ) + self.assertEqual( + (sorted(start_keys), sorted(stop_keys), key_count), + (sorted(start), sorted(stop), count), + ) def test_no_parents(self): self.assertSearchResult([], [], 0, {}) self.assertSearchResult([], [], 0, None) def test_ancestry_1(self): - self.assertSearchResult([b'rev4'], [NULL_REVISION], len(ancestry_1), - ancestry_1) + self.assertSearchResult([b"rev4"], [NULL_REVISION], len(ancestry_1), ancestry_1) def test_ancestry_2(self): - self.assertSearchResult([b'rev1b', b'rev4a'], [NULL_REVISION], - len(ancestry_2), ancestry_2) - self.assertSearchResult([b'rev1b', b'rev4a'], [], - len(ancestry_2) + 1, ancestry_2, - missing_keys=[NULL_REVISION]) + self.assertSearchResult( + [b"rev1b", b"rev4a"], [NULL_REVISION], len(ancestry_2), ancestry_2 + ) + self.assertSearchResult( + [b"rev1b", b"rev4a"], + [], + len(ancestry_2) + 1, + ancestry_2, + missing_keys=[NULL_REVISION], + ) def test_partial_search(self): - parent_map = {k: extended_history_shortcut[k] - for k in [b'e', b'f']} - self.assertSearchResult([b'e', b'f'], [b'd', b'a'], 2, - parent_map) - parent_map.update((k, extended_history_shortcut[k]) - for k in [b'd', b'a']) - self.assertSearchResult([b'e', b'f'], [b'c', NULL_REVISION], 4, - parent_map) - parent_map[b'c'] = extended_history_shortcut[b'c'] - self.assertSearchResult([b'e', b'f'], [b'b'], 6, - parent_map, missing_keys=[NULL_REVISION]) - parent_map[b'b'] = extended_history_shortcut[b'b'] - self.assertSearchResult([b'e', b'f'], [], 7, - parent_map, missing_keys=[NULL_REVISION]) + parent_map = {k: extended_history_shortcut[k] for k in [b"e", b"f"]} + self.assertSearchResult([b"e", b"f"], [b"d", b"a"], 2, parent_map) + parent_map.update((k, extended_history_shortcut[k]) for k in [b"d", b"a"]) + self.assertSearchResult([b"e", b"f"], [b"c", NULL_REVISION], 4, parent_map) + parent_map[b"c"] = extended_history_shortcut[b"c"] + self.assertSearchResult( + [b"e", b"f"], [b"b"], 6, parent_map, missing_keys=[NULL_REVISION] + ) + parent_map[b"b"] = extended_history_shortcut[b"b"] + self.assertSearchResult( + [b"e", b"f"], [], 7, parent_map, missing_keys=[NULL_REVISION] + ) class TestLimitedSearchResultFromParentMap(TestGraphBase): - - def assertSearchResult(self, start_keys, stop_keys, key_count, parent_map, - missing_keys, tip_keys, depth): + def assertSearchResult( + self, + start_keys, + stop_keys, + key_count, + parent_map, + missing_keys, + tip_keys, + depth, + ): (start, stop, count) = vf_search.limited_search_result_from_parent_map( - parent_map, missing_keys, tip_keys, depth) - self.assertEqual((sorted(start_keys), sorted(stop_keys), key_count), - (sorted(start), sorted(stop), count)) + parent_map, missing_keys, tip_keys, depth + ) + self.assertEqual( + (sorted(start_keys), sorted(stop_keys), key_count), + (sorted(start), sorted(stop), count), + ) def test_empty_ancestry(self): - 
self.assertSearchResult([], [], 0, {}, (), [b'tip-rev-id'], 10) + self.assertSearchResult([], [], 0, {}, (), [b"tip-rev-id"], 10) def test_ancestry_1(self): - self.assertSearchResult([b'rev4'], [b'rev1'], 4, - ancestry_1, (), [b'rev1'], 10) - self.assertSearchResult([b'rev2a', b'rev2b'], [b'rev1'], 2, - ancestry_1, (), [b'rev1'], 1) + self.assertSearchResult([b"rev4"], [b"rev1"], 4, ancestry_1, (), [b"rev1"], 10) + self.assertSearchResult( + [b"rev2a", b"rev2b"], [b"rev1"], 2, ancestry_1, (), [b"rev1"], 1 + ) def test_multiple_heads(self): - self.assertSearchResult([b'e', b'f'], [b'a'], 5, - extended_history_shortcut, (), [b'a'], 10) + self.assertSearchResult( + [b"e", b"f"], [b"a"], 5, extended_history_shortcut, (), [b"a"], 10 + ) # Note that even though we only take 1 step back, we find 'f', which # means the described search will still find d and c. - self.assertSearchResult([b'f'], [b'a'], 4, - extended_history_shortcut, (), [b'a'], 1) - self.assertSearchResult([b'f'], [b'a'], 4, - extended_history_shortcut, (), [b'a'], 2) + self.assertSearchResult( + [b"f"], [b"a"], 4, extended_history_shortcut, (), [b"a"], 1 + ) + self.assertSearchResult( + [b"f"], [b"a"], 4, extended_history_shortcut, (), [b"a"], 2 + ) class TestPendingAncestryResultRefine(tests.TestCase): - def make_graph(self, ancestors): return _mod_graph.Graph(_mod_graph.DictParentsProvider(ancestors)) @@ -190,13 +216,18 @@ def test_refine(self): # Used when pulling from a stacked repository, so test some revisions # being satisfied from the stacking branch. self.make_graph( - {b"tip": [b"mid"], b"mid": [b"base"], b"tag": [b"base"], - b"base": [NULL_REVISION], NULL_REVISION: []}) - result = vf_search.PendingAncestryResult([b'tip', b'tag'], None) - result = result.refine({b'tip'}, {b'mid'}) - self.assertEqual({b'mid', b'tag'}, result.heads) - result = result.refine({b'mid', b'tag', b'base'}, - {NULL_REVISION}) + { + b"tip": [b"mid"], + b"mid": [b"base"], + b"tag": [b"base"], + b"base": [NULL_REVISION], + NULL_REVISION: [], + } + ) + result = vf_search.PendingAncestryResult([b"tip", b"tag"], None) + result = result.refine({b"tip"}, {b"mid"}) + self.assertEqual({b"mid", b"tag"}, result.heads) + result = result.refine({b"mid", b"tag", b"base"}, {NULL_REVISION}) self.assertEqual({NULL_REVISION}, result.heads) self.assertTrue(result.is_empty()) @@ -205,33 +236,32 @@ class TestPendingAncestryResultGetKeys(tests.TestCaseWithMemoryTransport): """Tests for breezy.graph.PendingAncestryResult.""" def test_get_keys(self): - builder = self.make_branch_builder('b') + builder = self.make_branch_builder("b") builder.start_series() - builder.build_snapshot(None, [ - ('add', ('', b'root-id', 'directory', ''))], - revision_id=b'rev-1') - builder.build_snapshot([b'rev-1'], [], revision_id=b'rev-2') + builder.build_snapshot( + None, [("add", ("", b"root-id", "directory", ""))], revision_id=b"rev-1" + ) + builder.build_snapshot([b"rev-1"], [], revision_id=b"rev-2") builder.finish_series() repo = builder.get_branch().repository repo.lock_read() self.addCleanup(repo.unlock) - result = vf_search.PendingAncestryResult([b'rev-2'], repo) - self.assertEqual({b'rev-1', b'rev-2'}, set(result.get_keys())) + result = vf_search.PendingAncestryResult([b"rev-2"], repo) + self.assertEqual({b"rev-1", b"rev-2"}, set(result.get_keys())) def test_get_keys_excludes_ghosts(self): - builder = self.make_branch_builder('b') + builder = self.make_branch_builder("b") builder.start_series() - builder.build_snapshot(None, [ - ('add', ('', b'root-id', 'directory', ''))], - 
revision_id=b'rev-1') - builder.build_snapshot([b'rev-1', b'ghost'], [], revision_id=b'rev-2') + builder.build_snapshot( + None, [("add", ("", b"root-id", "directory", ""))], revision_id=b"rev-1" + ) + builder.build_snapshot([b"rev-1", b"ghost"], [], revision_id=b"rev-2") builder.finish_series() repo = builder.get_branch().repository repo.lock_read() self.addCleanup(repo.unlock) - result = vf_search.PendingAncestryResult([b'rev-2'], repo) - self.assertEqual(sorted([b'rev-1', b'rev-2']), - sorted(result.get_keys())) + result = vf_search.PendingAncestryResult([b"rev-2"], repo) + self.assertEqual(sorted([b"rev-1", b"rev-2"]), sorted(result.get_keys())) def test_get_keys_excludes_null(self): # Make a 'graph' with an iter_ancestry that returns NULL_REVISION @@ -239,8 +269,9 @@ def test_get_keys_excludes_null(self): # ancestries. class StubGraph: def iter_ancestry(self, keys): - return [(NULL_REVISION, ()), (b'foo', (NULL_REVISION,))] - result = vf_search.PendingAncestryResult([b'rev-3'], None) + return [(NULL_REVISION, ()), (b"foo", (NULL_REVISION,))] + + result = vf_search.PendingAncestryResult([b"rev-3"], None) result_keys = result._get_keys(StubGraph()) # Only the non-null keys from the ancestry appear. - self.assertEqual({b'foo'}, set(result_keys)) + self.assertEqual({b"foo"}, set(result_keys)) diff --git a/breezy/bzr/tests/test_vfs_ratchet.py b/breezy/bzr/tests/test_vfs_ratchet.py index 7c266b9b83..d9b5d7b73c 100644 --- a/breezy/bzr/tests/test_vfs_ratchet.py +++ b/breezy/bzr/tests/test_vfs_ratchet.py @@ -24,18 +24,18 @@ class TestSmartServerCommit(TestCaseWithTransport): - def test_commit_to_lightweight(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('from') + t = self.make_branch_and_tree("from") for count in range(9): - t.commit(message='commit %d' % count) - out, err = self.run_bzr(['checkout', '--lightweight', self.get_url('from'), - 'target']) + t.commit(message="commit %d" % count) + out, err = self.run_bzr( + ["checkout", "--lightweight", self.get_url("from"), "target"] + ) self.reset_smart_call_log() - self.build_tree(['target/afile']) - self.run_bzr(['add', 'target/afile']) - out, err = self.run_bzr(['commit', '-m', 'do something', 'target']) + self.build_tree(["target/afile"]) + self.run_bzr(["add", "target/afile"]) + out, err = self.run_bzr(["commit", "-m", "do something", "target"]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -43,21 +43,23 @@ def test_commit_to_lightweight(self): # upwards without agreement from bzr's network support maintainers. 
self.assertLength(211, self.hpss_calls) self.assertLength(2, self.hpss_connections) - self.expectFailure("commit still uses VFS calls", - self.assertThat, self.hpss_calls, ContainsNoVfsCalls) + self.expectFailure( + "commit still uses VFS calls", + self.assertThat, + self.hpss_calls, + ContainsNoVfsCalls, + ) class TestSmartServerAnnotate(TestCaseWithTransport): - def test_simple_annotate(self): self.setup_smart_server_with_call_log() - wt = self.make_branch_and_tree('branch') - self.build_tree_contents([('branch/hello.txt', b'my helicopter\n')]) - wt.add(['hello.txt']) - wt.commit('commit', committer='test@user') + wt = self.make_branch_and_tree("branch") + self.build_tree_contents([("branch/hello.txt", b"my helicopter\n")]) + wt.add(["hello.txt"]) + wt.commit("commit", committer="test@user") self.reset_smart_call_log() - out, err = self.run_bzr(['annotate', "-d", self.get_url('branch'), - "hello.txt"]) + out, err = self.run_bzr(["annotate", "-d", self.get_url("branch"), "hello.txt"]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -69,15 +71,15 @@ def test_simple_annotate(self): class TestSmartServerBranching(TestCaseWithTransport): - def test_branch_from_trivial_branch_to_same_server_branch_acceptance(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('from') + t = self.make_branch_and_tree("from") for count in range(9): - t.commit(message='commit %d' % count) + t.commit(message="commit %d" % count) self.reset_smart_call_log() - out, err = self.run_bzr(['branch', self.get_url('from'), - self.get_url('target')]) + out, err = self.run_bzr( + ["branch", self.get_url("from"), self.get_url("target")] + ) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -87,16 +89,18 @@ def test_branch_from_trivial_branch_to_same_server_branch_acceptance(self): self.assertLength(34, self.hpss_calls) self.expectFailure( "branching to the same branch requires VFS access", - self.assertThat, self.hpss_calls, ContainsNoVfsCalls) + self.assertThat, + self.hpss_calls, + ContainsNoVfsCalls, + ) def test_branch_from_trivial_branch_streaming_acceptance(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('from') + t = self.make_branch_and_tree("from") for count in range(9): - t.commit(message='commit %d' % count) + t.commit(message="commit %d" % count) self.reset_smart_call_log() - out, err = self.run_bzr(['branch', self.get_url('from'), - 'local-target']) + out, err = self.run_bzr(["branch", self.get_url("from"), "local-target"]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. 
If rpc_count increases, more network roundtrips have @@ -108,18 +112,15 @@ def test_branch_from_trivial_branch_streaming_acceptance(self): def test_branch_from_trivial_stacked_branch_streaming_acceptance(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('trunk') + t = self.make_branch_and_tree("trunk") for count in range(8): - t.commit(message='commit %d' % count) - tree2 = t.branch.controldir.sprout('feature', stacked=True - ).open_workingtree() - local_tree = t.branch.controldir.sprout( - 'local-working').open_workingtree() - local_tree.commit('feature change') + t.commit(message="commit %d" % count) + tree2 = t.branch.controldir.sprout("feature", stacked=True).open_workingtree() + local_tree = t.branch.controldir.sprout("local-working").open_workingtree() + local_tree.commit("feature change") local_tree.branch.push(tree2.branch) self.reset_smart_call_log() - out, err = self.run_bzr(['branch', self.get_url('feature'), - 'local-target']) + out, err = self.run_bzr(["branch", self.get_url("feature"), "local-target"]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -131,15 +132,14 @@ def test_branch_from_trivial_stacked_branch_streaming_acceptance(self): def test_branch_from_branch_with_tags(self): self.setup_smart_server_with_call_log() - builder = self.make_branch_builder('source') - source, rev1, rev2 = fixtures.build_branch_with_non_ancestral_rev( - builder) - source.get_config_stack().set('branch.fetch_tags', True) - source.tags.set_tag('tag-a', rev2) - source.tags.set_tag('tag-missing', b'missing-rev') + builder = self.make_branch_builder("source") + source, rev1, rev2 = fixtures.build_branch_with_non_ancestral_rev(builder) + source.get_config_stack().set("branch.fetch_tags", True) + source.tags.set_tag("tag-a", rev2) + source.tags.set_tag("tag-missing", b"missing-rev") # Now source has a tag not in its ancestry. Make a branch from it. self.reset_smart_call_log() - out, err = self.run_bzr(['branch', self.get_url('source'), 'target']) + out, err = self.run_bzr(["branch", self.get_url("source"), "target"]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -151,36 +151,42 @@ def test_branch_from_branch_with_tags(self): def test_branch_to_stacked_from_trivial_branch_streaming_acceptance(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('from') + t = self.make_branch_and_tree("from") for count in range(9): - t.commit(message='commit %d' % count) + t.commit(message="commit %d" % count) self.reset_smart_call_log() - out, err = self.run_bzr(['branch', '--stacked', self.get_url('from'), - 'local-target']) + out, err = self.run_bzr( + ["branch", "--stacked", self.get_url("from"), "local-target"] + ) # XXX: the number of hpss calls for this case isn't deterministic yet, # so we can't easily assert about the number of calls. - #self.assertLength(XXX, self.hpss_calls) + # self.assertLength(XXX, self.hpss_calls) # We can assert that none of the calls were readv requests for rix # files, though (demonstrating that at least get_parent_map calls are # not using VFS RPCs). 
readvs_of_rix_files = [ - c for c in self.hpss_calls - if c.call.method == 'readv' and c.call.args[-1].endswith('.rix')] + c + for c in self.hpss_calls + if c.call.method == "readv" and c.call.args[-1].endswith(".rix") + ] self.assertLength(1, self.hpss_connections) self.assertLength(0, readvs_of_rix_files) - self.expectFailure("branching to stacked requires VFS access", - self.assertThat, self.hpss_calls, ContainsNoVfsCalls) + self.expectFailure( + "branching to stacked requires VFS access", + self.assertThat, + self.hpss_calls, + ContainsNoVfsCalls, + ) def test_branch_from_branch_with_ghosts(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('from') + t = self.make_branch_and_tree("from") for count in range(9): - t.commit(message='commit %d' % count) - t.set_parent_ids([t.last_revision(), b'ghost']) - t.commit(message='add commit with parent') + t.commit(message="commit %d" % count) + t.set_parent_ids([t.last_revision(), b"ghost"]) + t.commit(message="add commit with parent") self.reset_smart_call_log() - out, err = self.run_bzr(['branch', self.get_url('from'), - 'local-target']) + out, err = self.run_bzr(["branch", self.get_url("from"), "local-target"]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -192,14 +198,12 @@ def test_branch_from_branch_with_ghosts(self): class TestSmartServerBreakLock(TestCaseWithTransport): - def test_simple_branch_break_lock(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('branch') + t = self.make_branch_and_tree("branch") t.branch.lock_write() self.reset_smart_call_log() - out, err = self.run_bzr( - ['break-lock', '--force', self.get_url('branch')]) + out, err = self.run_bzr(["break-lock", "--force", self.get_url("branch")]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -211,15 +215,14 @@ def test_simple_branch_break_lock(self): class TestSmartServerCat(TestCaseWithTransport): - def test_simple_branch_cat(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('branch') - self.build_tree_contents([('branch/foo', b'thecontents')]) + t = self.make_branch_and_tree("branch") + self.build_tree_contents([("branch/foo", b"thecontents")]) t.add("foo") t.commit("message") self.reset_smart_call_log() - out, err = self.run_bzr(['cat', f"{self.get_url('branch')}/foo"]) + out, err = self.run_bzr(["cat", f"{self.get_url('branch')}/foo"]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -231,14 +234,13 @@ def test_simple_branch_cat(self): class TestSmartServerCheckout(TestCaseWithTransport): - def test_heavyweight_checkout(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('from') + t = self.make_branch_and_tree("from") for count in range(9): - t.commit(message='commit %d' % count) + t.commit(message="commit %d" % count) self.reset_smart_call_log() - out, err = self.run_bzr(['checkout', self.get_url('from'), 'target']) + out, err = self.run_bzr(["checkout", self.get_url("from"), "target"]) # This figure represent the amount of work to perform this use case. 
It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -250,12 +252,13 @@ def test_heavyweight_checkout(self): def test_lightweight_checkout(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('from') + t = self.make_branch_and_tree("from") for count in range(9): - t.commit(message='commit %d' % count) + t.commit(message="commit %d" % count) self.reset_smart_call_log() - out, err = self.run_bzr(['checkout', '--lightweight', self.get_url('from'), - 'target']) + out, err = self.run_bzr( + ["checkout", "--lightweight", self.get_url("from"), "target"] + ) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -266,12 +269,11 @@ def test_lightweight_checkout(self): class TestSmartServerConfig(TestCaseWithTransport): - def test_simple_branch_config(self): self.setup_smart_server_with_call_log() - self.make_branch_and_tree('branch') + self.make_branch_and_tree("branch") self.reset_smart_call_log() - out, err = self.run_bzr(['config', '-d', self.get_url('branch')]) + out, err = self.run_bzr(["config", "-d", self.get_url("branch")]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -283,15 +285,14 @@ def test_simple_branch_config(self): class TestSmartServerInfo(TestCaseWithTransport): - def test_simple_branch_info(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('branch') - self.build_tree_contents([('branch/foo', b'thecontents')]) + t = self.make_branch_and_tree("branch") + self.build_tree_contents([("branch/foo", b"thecontents")]) t.add("foo") t.commit("message") self.reset_smart_call_log() - out, err = self.run_bzr(['info', self.get_url('branch')]) + out, err = self.run_bzr(["info", self.get_url("branch")]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -303,12 +304,12 @@ def test_simple_branch_info(self): def test_verbose_branch_info(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('branch') - self.build_tree_contents([('branch/foo', b'thecontents')]) + t = self.make_branch_and_tree("branch") + self.build_tree_contents([("branch/foo", b"thecontents")]) t.add("foo") t.commit("message") self.reset_smart_call_log() - out, err = self.run_bzr(['info', '-v', self.get_url('branch')]) + out, err = self.run_bzr(["info", "-v", self.get_url("branch")]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. 
If rpc_count increases, more network roundtrips have @@ -320,16 +321,14 @@ def test_verbose_branch_info(self): class TestSmartServerExport(TestCaseWithTransport): - def test_simple_export(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('branch') - self.build_tree_contents([('branch/foo', b'thecontents')]) + t = self.make_branch_and_tree("branch") + self.build_tree_contents([("branch/foo", b"thecontents")]) t.add("foo") t.commit("message") self.reset_smart_call_log() - out, err = self.run_bzr( - ['export', "foo.tar.gz", self.get_url('branch')]) + out, err = self.run_bzr(["export", "foo.tar.gz", self.get_url("branch")]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -341,15 +340,14 @@ def test_simple_export(self): class TestSmartServerLog(TestCaseWithTransport): - def test_standard_log(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('branch') - self.build_tree_contents([('branch/foo', b'thecontents')]) + t = self.make_branch_and_tree("branch") + self.build_tree_contents([("branch/foo", b"thecontents")]) t.add("foo") t.commit("message") self.reset_smart_call_log() - out, err = self.run_bzr(['log', self.get_url('branch')]) + out, err = self.run_bzr(["log", self.get_url("branch")]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -361,12 +359,12 @@ def test_standard_log(self): def test_verbose_log(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('branch') - self.build_tree_contents([('branch/foo', b'thecontents')]) + t = self.make_branch_and_tree("branch") + self.build_tree_contents([("branch/foo", b"thecontents")]) t.add("foo") t.commit("message") self.reset_smart_call_log() - out, err = self.run_bzr(['log', '-v', self.get_url('branch')]) + out, err = self.run_bzr(["log", "-v", self.get_url("branch")]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -378,12 +376,12 @@ def test_verbose_log(self): def test_per_file(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('branch') - self.build_tree_contents([('branch/foo', b'thecontents')]) + t = self.make_branch_and_tree("branch") + self.build_tree_contents([("branch/foo", b"thecontents")]) t.add("foo") t.commit("message") self.reset_smart_call_log() - out, err = self.run_bzr(['log', '-v', self.get_url('branch') + "/foo"]) + out, err = self.run_bzr(["log", "-v", self.get_url("branch") + "/foo"]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. 
If rpc_count increases, more network roundtrips have @@ -395,15 +393,14 @@ def test_per_file(self): class TestSmartServerLs(TestCaseWithTransport): - def test_simple_ls(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('branch') - self.build_tree_contents([('branch/foo', b'thecontents')]) + t = self.make_branch_and_tree("branch") + self.build_tree_contents([("branch/foo", b"thecontents")]) t.add("foo") t.commit("message") self.reset_smart_call_log() - out, err = self.run_bzr(['ls', self.get_url('branch')]) + out, err = self.run_bzr(["ls", self.get_url("branch")]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -415,15 +412,14 @@ def test_simple_ls(self): class TestSmartServerPack(TestCaseWithTransport): - def test_simple_pack(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('branch') - self.build_tree_contents([('branch/foo', b'thecontents')]) + t = self.make_branch_and_tree("branch") + self.build_tree_contents([("branch/foo", b"thecontents")]) t.add("foo") t.commit("message") self.reset_smart_call_log() - out, err = self.run_bzr(['pack', self.get_url('branch')]) + out, err = self.run_bzr(["pack", self.get_url("branch")]) # This figure represent the amount of HPSS calls to perform this use # case. It is entirely ok to reduce this number if a test fails due to # rpc_count # being too low. If rpc_count increases, more network @@ -436,13 +432,12 @@ def test_simple_pack(self): class TestSmartServerPush(TestCaseWithTransport): - def test_push_smart_non_stacked_streaming_acceptance(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('from') - t.commit(allow_pointless=True, message='first commit') + t = self.make_branch_and_tree("from") + t.commit(allow_pointless=True, message="first commit") self.reset_smart_call_log() - self.run_bzr(['push', self.get_url('to-one')], working_dir='from') + self.run_bzr(["push", self.get_url("to-one")], working_dir="from") # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -454,13 +449,15 @@ def test_push_smart_non_stacked_streaming_acceptance(self): def test_push_smart_stacked_streaming_acceptance(self): self.setup_smart_server_with_call_log() - parent = self.make_branch_and_tree('parent', format='1.9') - parent.commit(message='first commit') - local = parent.controldir.sprout('local').open_workingtree() - local.commit(message='local commit') + parent = self.make_branch_and_tree("parent", format="1.9") + parent.commit(message="first commit") + local = parent.controldir.sprout("local").open_workingtree() + local.commit(message="local commit") self.reset_smart_call_log() - self.run_bzr(['push', '--stacked', '--stacked-on', '../parent', - self.get_url('public')], working_dir='local') + self.run_bzr( + ["push", "--stacked", "--stacked-on", "../parent", self.get_url("public")], + working_dir="local", + ) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. 
If rpc_count increases, more network roundtrips have @@ -469,16 +466,16 @@ def test_push_smart_stacked_streaming_acceptance(self): self.assertLength(15, self.hpss_calls) self.assertLength(1, self.hpss_connections) self.assertThat(self.hpss_calls, ContainsNoVfsCalls) - remote = branch.Branch.open('public') - self.assertEndsWith(remote.get_stacked_on_url(), '/parent') + remote = branch.Branch.open("public") + self.assertEndsWith(remote.get_stacked_on_url(), "/parent") def test_push_smart_tags_streaming_acceptance(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('from') - rev_id = t.commit(allow_pointless=True, message='first commit') - t.branch.tags.set_tag('new-tag', rev_id) + t = self.make_branch_and_tree("from") + rev_id = t.commit(allow_pointless=True, message="first commit") + t.branch.tags.set_tag("new-tag", rev_id) self.reset_smart_call_log() - self.run_bzr(['push', self.get_url('to-one')], working_dir='from') + self.run_bzr(["push", self.get_url("to-one")], working_dir="from") # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -490,13 +487,12 @@ def test_push_smart_tags_streaming_acceptance(self): def test_push_smart_incremental_acceptance(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('from') - t.commit(allow_pointless=True, message='first commit') - t.commit(allow_pointless=True, message='second commit') - self.run_bzr( - ['push', self.get_url('to-one'), '-r1'], working_dir='from') + t = self.make_branch_and_tree("from") + t.commit(allow_pointless=True, message="first commit") + t.commit(allow_pointless=True, message="second commit") + self.run_bzr(["push", self.get_url("to-one"), "-r1"], working_dir="from") self.reset_smart_call_log() - self.run_bzr(['push', self.get_url('to-one')], working_dir='from') + self.run_bzr(["push", self.get_url("to-one")], working_dir="from") # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -507,14 +503,12 @@ def test_push_smart_incremental_acceptance(self): self.assertThat(self.hpss_calls, ContainsNoVfsCalls) - class TestSmartServerReconcile(TestCaseWithTransport): - def test_simple_reconcile(self): self.setup_smart_server_with_call_log() - self.make_branch('branch') + self.make_branch("branch") self.reset_smart_call_log() - out, err = self.run_bzr(['reconcile', self.get_url('branch')]) + out, err = self.run_bzr(["reconcile", self.get_url("branch")]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. 
If rpc_count increases, more network roundtrips have @@ -526,15 +520,14 @@ def test_simple_reconcile(self): class TestSmartServerRevno(TestCaseWithTransport): - def test_simple_branch_revno(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('branch') - self.build_tree_contents([('branch/foo', b'thecontents')]) + t = self.make_branch_and_tree("branch") + self.build_tree_contents([("branch/foo", b"thecontents")]) t.add("foo") t.commit("message") self.reset_smart_call_log() - out, err = self.run_bzr(['revno', self.get_url('branch')]) + out, err = self.run_bzr(["revno", self.get_url("branch")]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -546,14 +539,15 @@ def test_simple_branch_revno(self): def test_simple_branch_revno_lookup(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('branch') - self.build_tree_contents([('branch/foo', b'thecontents')]) + t = self.make_branch_and_tree("branch") + self.build_tree_contents([("branch/foo", b"thecontents")]) t.add("foo") revid1 = t.commit("message") t.commit("message") self.reset_smart_call_log() - out, err = self.run_bzr(['revno', '-rrevid:' + revid1.decode('utf-8'), - self.get_url('branch')]) + out, err = self.run_bzr( + ["revno", "-rrevid:" + revid1.decode("utf-8"), self.get_url("branch")] + ) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -565,12 +559,11 @@ def test_simple_branch_revno_lookup(self): class TestSmartServerRemoveBranch(TestCaseWithTransport): - def test_simple_remove_branch(self): self.setup_smart_server_with_call_log() - self.make_branch('branch') + self.make_branch("branch") self.reset_smart_call_log() - out, err = self.run_bzr(['rmbranch', self.get_url('branch')]) + out, err = self.run_bzr(["rmbranch", self.get_url("branch")]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -582,19 +575,19 @@ def test_simple_remove_branch(self): class TestSmartServerSend(TestCaseWithTransport): - def test_send(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('branch') - self.build_tree_contents([('branch/foo', b'thecontents')]) + t = self.make_branch_and_tree("branch") + self.build_tree_contents([("branch/foo", b"thecontents")]) t.add("foo") t.commit("message") - local = t.controldir.sprout('local-branch').open_workingtree() - self.build_tree_contents([('branch/foo', b'thenewcontents')]) + local = t.controldir.sprout("local-branch").open_workingtree() + self.build_tree_contents([("branch/foo", b"thenewcontents")]) local.commit("anothermessage") self.reset_smart_call_log() out, err = self.run_bzr( - ['send', '-o', 'x.diff', self.get_url('branch')], working_dir='local-branch') + ["send", "-o", "x.diff", self.get_url("branch")], working_dir="local-branch" + ) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. 
If rpc_count increases, more network roundtrips have @@ -606,12 +599,11 @@ def test_send(self): class TestSmartServerInitRepository(TestCaseWithTransport): - def test_init_repo_smart_acceptance(self): # The amount of hpss calls made on init-shared-repo to a smart server # should be fixed. self.setup_smart_server_with_call_log() - self.run_bzr(['init-shared-repo', self.get_url('repo')]) + self.run_bzr(["init-shared-repo", self.get_url("repo")]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -623,7 +615,6 @@ def test_init_repo_smart_acceptance(self): class TestSmartServerSignMyCommits(TestCaseWithTransport): - def monkey_patch_gpg(self): """Monkey patch the gpg signing strategy to be a loopback. @@ -631,17 +622,17 @@ def monkey_patch_gpg(self): the original gpg strategy when done. """ # monkey patch gpg signing mechanism - self.overrideAttr(gpg, 'GPGStrategy', gpg.LoopbackGPGStrategy) + self.overrideAttr(gpg, "GPGStrategy", gpg.LoopbackGPGStrategy) def test_sign_single_commit(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('branch') - self.build_tree_contents([('branch/foo', b'thecontents')]) + t = self.make_branch_and_tree("branch") + self.build_tree_contents([("branch/foo", b"thecontents")]) t.add("foo") t.commit("message") self.reset_smart_call_log() self.monkey_patch_gpg() - out, err = self.run_bzr(['sign-my-commits', self.get_url('branch')]) + out, err = self.run_bzr(["sign-my-commits", self.get_url("branch")]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -653,16 +644,16 @@ def test_sign_single_commit(self): class TestSmartServerSwitch(TestCaseWithTransport): - def test_switch_lightweight(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('from') + t = self.make_branch_and_tree("from") for count in range(9): - t.commit(message='commit %d' % count) - out, err = self.run_bzr(['checkout', '--lightweight', self.get_url('from'), - 'target']) + t.commit(message="commit %d" % count) + out, err = self.run_bzr( + ["checkout", "--lightweight", self.get_url("from"), "target"] + ) self.reset_smart_call_log() - self.run_bzr(['switch', self.get_url('from')], working_dir='target') + self.run_bzr(["switch", self.get_url("from")], working_dir="target") # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -674,16 +665,14 @@ def test_switch_lightweight(self): class TestSmartServerTags(TestCaseWithTransport): - def test_set_tag(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('branch') - self.build_tree_contents([('branch/foo', b'thecontents')]) + t = self.make_branch_and_tree("branch") + self.build_tree_contents([("branch/foo", b"thecontents")]) t.add("foo") t.commit("message") self.reset_smart_call_log() - out, err = self.run_bzr( - ['tag', "-d", self.get_url('branch'), "tagname"]) + out, err = self.run_bzr(["tag", "-d", self.get_url("branch"), "tagname"]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. 
If rpc_count increases, more network roundtrips have @@ -695,14 +684,14 @@ def test_set_tag(self): def test_show_tags(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('branch') - self.build_tree_contents([('branch/foo', b'thecontents')]) + t = self.make_branch_and_tree("branch") + self.build_tree_contents([("branch/foo", b"thecontents")]) t.add("foo") t.commit("message") t.branch.tags.set_tag("sometag", b"rev1") t.branch.tags.set_tag("sometag", b"rev2") self.reset_smart_call_log() - out, err = self.run_bzr(['tags', "-d", self.get_url('branch')]) + out, err = self.run_bzr(["tags", "-d", self.get_url("branch")]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -714,14 +703,13 @@ def test_show_tags(self): class TestSmartServerUncommit(TestCaseWithTransport): - def test_uncommit(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('from') + t = self.make_branch_and_tree("from") for count in range(2): - t.commit(message='commit %d' % count) + t.commit(message="commit %d" % count) self.reset_smart_call_log() - out, err = self.run_bzr(['uncommit', '--force', self.get_url('from')]) + out, err = self.run_bzr(["uncommit", "--force", self.get_url("from")]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. If rpc_count increases, more network roundtrips have @@ -733,7 +721,6 @@ def test_uncommit(self): class TestSmartServerVerifySignatures(TestCaseWithTransport): - def monkey_patch_gpg(self): """Monkey patch the gpg signing strategy to be a loopback. @@ -741,19 +728,19 @@ def monkey_patch_gpg(self): the original gpg strategy when done. """ # monkey patch gpg signing mechanism - self.overrideAttr(gpg, 'GPGStrategy', gpg.LoopbackGPGStrategy) + self.overrideAttr(gpg, "GPGStrategy", gpg.LoopbackGPGStrategy) def test_verify_signatures(self): self.setup_smart_server_with_call_log() - t = self.make_branch_and_tree('branch') - self.build_tree_contents([('branch/foo', b'thecontents')]) + t = self.make_branch_and_tree("branch") + self.build_tree_contents([("branch/foo", b"thecontents")]) t.add("foo") t.commit("message") self.monkey_patch_gpg() - out, err = self.run_bzr(['sign-my-commits', self.get_url('branch')]) + out, err = self.run_bzr(["sign-my-commits", self.get_url("branch")]) self.reset_smart_call_log() - self.run_bzr('sign-my-commits') - self.run_bzr(['verify-signatures', self.get_url('branch')]) + self.run_bzr("sign-my-commits") + self.run_bzr(["verify-signatures", self.get_url("branch")]) # This figure represent the amount of work to perform this use case. It # is entirely ok to reduce this number if a test fails due to rpc_count # being too low. 
If rpc_count increases, more network roundtrips have diff --git a/breezy/bzr/tests/test_weave.py b/breezy/bzr/tests/test_weave.py index 4fe0fc19d7..b8f0fd4040 100644 --- a/breezy/bzr/tests/test_weave.py +++ b/breezy/bzr/tests/test_weave.py @@ -32,15 +32,14 @@ # texts for use in testing TEXT_0 = [b"Hello world"] -TEXT_1 = [b"Hello world", - b"A second line"] +TEXT_1 = [b"Hello world", b"A second line"] class TestBase(TestCase): - def check_read_write(self, k): """Check the weave k can be written & re-read.""" from tempfile import TemporaryFile + tf = TemporaryFile() write_weave(k, tf) @@ -49,15 +48,15 @@ def check_read_write(self, k): if k != k2: tf.seek(0) - self.log('serialized weave:') + self.log("serialized weave:") self.log(tf.read()) - self.log('') - self.log('parents: %s' % (k._parents == k2._parents)) - self.log(f' {k._parents!r}') - self.log(f' {k2._parents!r}') - self.log('') - self.fail('read/write check failed') + self.log("") + self.log("parents: %s" % (k._parents == k2._parents)) + self.log(f" {k._parents!r}") + self.log(f" {k2._parents!r}") + self.log("") + self.fail("read/write check failed") class WeaveContains(TestBase): @@ -65,24 +64,21 @@ class WeaveContains(TestBase): def runTest(self): k = Weave(get_scope=lambda: None) - self.assertNotIn(b'foo', k) - k.add_lines(b'foo', [], TEXT_1) - self.assertIn(b'foo', k) + self.assertNotIn(b"foo", k) + k.add_lines(b"foo", [], TEXT_1) + self.assertIn(b"foo", k) class Easy(TestBase): - def runTest(self): Weave() class AnnotateOne(TestBase): - def runTest(self): k = Weave() - k.add_lines(b'text0', [], TEXT_0) - self.assertEqual(k.annotate(b'text0'), - [(b'text0', TEXT_0[0])]) + k.add_lines(b"text0", [], TEXT_0) + self.assertEqual(k.annotate(b"text0"), [(b"text0", TEXT_0[0])]) class InvalidAdd(TestBase): @@ -91,11 +87,9 @@ class InvalidAdd(TestBase): def runTest(self): k = Weave() - self.assertRaises(errors.RevisionNotPresent, - k.add_lines, - b'text0', - [b'69'], - [b'new text!']) + self.assertRaises( + errors.RevisionNotPresent, k.add_lines, b"text0", [b"69"], [b"new text!"] + ) class RepeatedAdd(TestBase): @@ -103,27 +97,30 @@ class RepeatedAdd(TestBase): def test_duplicate_add(self): k = Weave() - idx = k.add_lines(b'text0', [], TEXT_0) - idx2 = k.add_lines(b'text0', [], TEXT_0) + idx = k.add_lines(b"text0", [], TEXT_0) + idx2 = k.add_lines(b"text0", [], TEXT_0) self.assertEqual(idx, idx2) class InvalidRepeatedAdd(TestBase): - def runTest(self): k = Weave() - k.add_lines(b'basis', [], TEXT_0) - k.add_lines(b'text0', [], TEXT_0) - self.assertRaises(errors.RevisionAlreadyPresent, - k.add_lines, - b'text0', - [], - [b'not the same text']) - self.assertRaises(errors.RevisionAlreadyPresent, - k.add_lines, - b'text0', - [b'basis'], # not the right parents - TEXT_0) + k.add_lines(b"basis", [], TEXT_0) + k.add_lines(b"text0", [], TEXT_0) + self.assertRaises( + errors.RevisionAlreadyPresent, + k.add_lines, + b"text0", + [], + [b"not the same text"], + ) + self.assertRaises( + errors.RevisionAlreadyPresent, + k.add_lines, + b"text0", + [b"basis"], # not the right parents + TEXT_0, + ) class InsertLines(TestBase): @@ -136,53 +133,54 @@ class InsertLines(TestBase): def runTest(self): k = Weave() - k.add_lines(b'text0', [], [b'line 1']) - k.add_lines(b'text1', [b'text0'], [b'line 1', b'line 2']) + k.add_lines(b"text0", [], [b"line 1"]) + k.add_lines(b"text1", [b"text0"], [b"line 1", b"line 2"]) - self.assertEqual(k.annotate(b'text0'), - [(b'text0', b'line 1')]) + self.assertEqual(k.annotate(b"text0"), [(b"text0", b"line 1")]) - 
self.assertEqual(k.get_lines(1), - [b'line 1', - b'line 2']) + self.assertEqual(k.get_lines(1), [b"line 1", b"line 2"]) - self.assertEqual(k.annotate(b'text1'), - [(b'text0', b'line 1'), - (b'text1', b'line 2')]) + self.assertEqual( + k.annotate(b"text1"), [(b"text0", b"line 1"), (b"text1", b"line 2")] + ) - k.add_lines(b'text2', [b'text0'], [b'line 1', b'diverged line']) + k.add_lines(b"text2", [b"text0"], [b"line 1", b"diverged line"]) - self.assertEqual(k.annotate(b'text2'), - [(b'text0', b'line 1'), - (b'text2', b'diverged line')]) + self.assertEqual( + k.annotate(b"text2"), [(b"text0", b"line 1"), (b"text2", b"diverged line")] + ) - text3 = [b'line 1', b'middle line', b'line 2'] - k.add_lines(b'text3', - [b'text0', b'text1'], - text3) + text3 = [b"line 1", b"middle line", b"line 2"] + k.add_lines(b"text3", [b"text0", b"text1"], text3) # self.log("changes to text3: " + pformat(list(k._delta(set([0, 1]), # text3)))) self.log("k._weave=" + pformat(k._weave)) - self.assertEqual(k.annotate(b'text3'), - [(b'text0', b'line 1'), - (b'text3', b'middle line'), - (b'text1', b'line 2')]) + self.assertEqual( + k.annotate(b"text3"), + [(b"text0", b"line 1"), (b"text3", b"middle line"), (b"text1", b"line 2")], + ) # now multiple insertions at different places k.add_lines( - b'text4', [b'text0', b'text1', b'text3'], - [b'line 1', b'aaa', b'middle line', b'bbb', b'line 2', b'ccc']) + b"text4", + [b"text0", b"text1", b"text3"], + [b"line 1", b"aaa", b"middle line", b"bbb", b"line 2", b"ccc"], + ) - self.assertEqual(k.annotate(b'text4'), - [(b'text0', b'line 1'), - (b'text4', b'aaa'), - (b'text3', b'middle line'), - (b'text4', b'bbb'), - (b'text1', b'line 2'), - (b'text4', b'ccc')]) + self.assertEqual( + k.annotate(b"text4"), + [ + (b"text0", b"line 1"), + (b"text4", b"aaa"), + (b"text3", b"middle line"), + (b"text4", b"bbb"), + (b"text1", b"line 2"), + (b"text4", b"ccc"), + ], + ) class DeleteLines(TestBase): @@ -194,27 +192,27 @@ class DeleteLines(TestBase): def runTest(self): k = Weave() - base_text = [b'one', b'two', b'three', b'four'] + base_text = [b"one", b"two", b"three", b"four"] - k.add_lines(b'text0', [], base_text) + k.add_lines(b"text0", [], base_text) - texts = [[b'one', b'two', b'three'], - [b'two', b'three', b'four'], - [b'one', b'four'], - [b'one', b'two', b'three', b'four'], - ] + texts = [ + [b"one", b"two", b"three"], + [b"two", b"three", b"four"], + [b"one", b"four"], + [b"one", b"two", b"three", b"four"], + ] i = 1 for t in texts: - k.add_lines(b'text%d' % i, [b'text0'], t) + k.add_lines(b"text%d" % i, [b"text0"], t) i += 1 - self.log('final weave:') - self.log('k._weave=' + pformat(k._weave)) + self.log("final weave:") + self.log("k._weave=" + pformat(k._weave)) for i in range(len(texts)): - self.assertEqual(k.get_lines(i + 1), - texts[i]) + self.assertEqual(k.get_lines(i + 1), texts[i]) class SuicideDelete(TestBase): @@ -223,22 +221,22 @@ class SuicideDelete(TestBase): def runTest(self): k = Weave() - k._parents = [(), - ] - k._weave = [(b'{', 0), - b'first line', - (b'[', 0), - b'deleted in 0', - (b']', 0), - (b'}', 0), - ] + k._parents = [ + (), + ] + k._weave = [ + (b"{", 0), + b"first line", + (b"[", 0), + b"deleted in 0", + (b"]", 0), + (b"}", 0), + ] # SKIPPED # Weave.get doesn't trap this anymore return - self.assertRaises(WeaveFormatError, - k.get_lines, - 0) + self.assertRaises(WeaveFormatError, k.get_lines, 0) class CannedDelete(TestBase): @@ -247,31 +245,40 @@ class CannedDelete(TestBase): def runTest(self): k = Weave() - k._parents = [(), - frozenset([0]), - ] - 
k._weave = [(b'{', 0), - b'first line', - (b'[', 1), - b'line to be deleted', - (b']', 1), - b'last line', - (b'}', 0), - ] + k._parents = [ + (), + frozenset([0]), + ] + k._weave = [ + (b"{", 0), + b"first line", + (b"[", 1), + b"line to be deleted", + (b"]", 1), + b"last line", + (b"}", 0), + ] k._sha1s = [ - sha_string(b'first lineline to be deletedlast line'), - sha_string(b'first linelast line')] + sha_string(b"first lineline to be deletedlast line"), + sha_string(b"first linelast line"), + ] - self.assertEqual(k.get_lines(0), - [b'first line', - b'line to be deleted', - b'last line', - ]) + self.assertEqual( + k.get_lines(0), + [ + b"first line", + b"line to be deleted", + b"last line", + ], + ) - self.assertEqual(k.get_lines(1), - [b'first line', - b'last line', - ]) + self.assertEqual( + k.get_lines(1), + [ + b"first line", + b"last line", + ], + ) class CannedReplacement(TestBase): @@ -280,35 +287,44 @@ class CannedReplacement(TestBase): def runTest(self): k = Weave() - k._parents = [frozenset(), - frozenset([0]), - ] - k._weave = [(b'{', 0), - b'first line', - (b'[', 1), - b'line to be deleted', - (b']', 1), - (b'{', 1), - b'replacement line', - (b'}', 1), - b'last line', - (b'}', 0), - ] + k._parents = [ + frozenset(), + frozenset([0]), + ] + k._weave = [ + (b"{", 0), + b"first line", + (b"[", 1), + b"line to be deleted", + (b"]", 1), + (b"{", 1), + b"replacement line", + (b"}", 1), + b"last line", + (b"}", 0), + ] k._sha1s = [ - sha_string(b'first lineline to be deletedlast line'), - sha_string(b'first linereplacement linelast line')] + sha_string(b"first lineline to be deletedlast line"), + sha_string(b"first linereplacement linelast line"), + ] - self.assertEqual(k.get_lines(0), - [b'first line', - b'line to be deleted', - b'last line', - ]) + self.assertEqual( + k.get_lines(0), + [ + b"first line", + b"line to be deleted", + b"last line", + ], + ) - self.assertEqual(k.get_lines(1), - [b'first line', - b'replacement line', - b'last line', - ]) + self.assertEqual( + k.get_lines(1), + [ + b"first line", + b"replacement line", + b"last line", + ], + ) class BadWeave(TestBase): @@ -317,28 +333,29 @@ class BadWeave(TestBase): def runTest(self): k = Weave() - k._parents = [frozenset(), - ] - k._weave = [b'bad line', - (b'{', 0), - b'foo {', - (b'{', 1), - b' added in version 1', - (b'{', 2), - b' added in v2', - (b'}', 2), - b' also from v1', - (b'}', 1), - b'}', - (b'}', 0)] + k._parents = [ + frozenset(), + ] + k._weave = [ + b"bad line", + (b"{", 0), + b"foo {", + (b"{", 1), + b" added in version 1", + (b"{", 2), + b" added in v2", + (b"}", 2), + b" also from v1", + (b"}", 1), + b"}", + (b"}", 0), + ] # SKIPPED # Weave.get doesn't trap this anymore return - self.assertRaises(WeaveFormatError, - k.get, - 0) + self.assertRaises(WeaveFormatError, k.get, 0) class BadInsert(TestBase): @@ -347,31 +364,30 @@ class BadInsert(TestBase): def runTest(self): k = Weave() - k._parents = [frozenset(), - frozenset([0]), - frozenset([0]), - frozenset([0, 1, 2]), - ] - k._weave = [(b'{', 0), - b'foo {', - (b'{', 1), - b' added in version 1', - (b'{', 1), - b' more in 1', - (b'}', 1), - (b'}', 1), - (b'}', 0)] + k._parents = [ + frozenset(), + frozenset([0]), + frozenset([0]), + frozenset([0, 1, 2]), + ] + k._weave = [ + (b"{", 0), + b"foo {", + (b"{", 1), + b" added in version 1", + (b"{", 1), + b" more in 1", + (b"}", 1), + (b"}", 1), + (b"}", 0), + ] # this is not currently enforced by get return - self.assertRaises(WeaveFormatError, - k.get, - 0) + self.assertRaises(WeaveFormatError, k.get, 0) 
- self.assertRaises(WeaveFormatError, - k.get, - 1) + self.assertRaises(WeaveFormatError, k.get, 1) class InsertNested(TestBase): @@ -380,52 +396,51 @@ class InsertNested(TestBase): def runTest(self): k = Weave() - k._parents = [frozenset(), - frozenset([0]), - frozenset([0]), - frozenset([0, 1, 2]), - ] - k._weave = [(b'{', 0), - b'foo {', - (b'{', 1), - b' added in version 1', - (b'{', 2), - b' added in v2', - (b'}', 2), - b' also from v1', - (b'}', 1), - b'}', - (b'}', 0)] + k._parents = [ + frozenset(), + frozenset([0]), + frozenset([0]), + frozenset([0, 1, 2]), + ] + k._weave = [ + (b"{", 0), + b"foo {", + (b"{", 1), + b" added in version 1", + (b"{", 2), + b" added in v2", + (b"}", 2), + b" also from v1", + (b"}", 1), + b"}", + (b"}", 0), + ] k._sha1s = [ - sha_string(b'foo {}'), - sha_string(b'foo { added in version 1 also from v1}'), - sha_string(b'foo { added in v2}'), - sha_string( - b'foo { added in version 1 added in v2 also from v1}') - ] - - self.assertEqual(k.get_lines(0), - [b'foo {', - b'}']) - - self.assertEqual(k.get_lines(1), - [b'foo {', - b' added in version 1', - b' also from v1', - b'}']) - - self.assertEqual(k.get_lines(2), - [b'foo {', - b' added in v2', - b'}']) - - self.assertEqual(k.get_lines(3), - [b'foo {', - b' added in version 1', - b' added in v2', - b' also from v1', - b'}']) + sha_string(b"foo {}"), + sha_string(b"foo { added in version 1 also from v1}"), + sha_string(b"foo { added in v2}"), + sha_string(b"foo { added in version 1 added in v2 also from v1}"), + ] + + self.assertEqual(k.get_lines(0), [b"foo {", b"}"]) + + self.assertEqual( + k.get_lines(1), [b"foo {", b" added in version 1", b" also from v1", b"}"] + ) + + self.assertEqual(k.get_lines(2), [b"foo {", b" added in v2", b"}"]) + + self.assertEqual( + k.get_lines(3), + [ + b"foo {", + b" added in version 1", + b" added in v2", + b" also from v1", + b"}", + ], + ) class DeleteLines2(TestBase): @@ -438,23 +453,17 @@ class DeleteLines2(TestBase): def runTest(self): k = Weave() - k.add_lines(b'text0', [], [b"line the first", - b"line 2", - b"line 3", - b"fine"]) + k.add_lines(b"text0", [], [b"line the first", b"line 2", b"line 3", b"fine"]) self.assertEqual(len(k.get_lines(0)), 4) - k.add_lines(b'text1', [b'text0'], [b"line the first", - b"fine"]) + k.add_lines(b"text1", [b"text0"], [b"line the first", b"fine"]) - self.assertEqual(k.get_lines(1), - [b"line the first", - b"fine"]) + self.assertEqual(k.get_lines(1), [b"line the first", b"fine"]) - self.assertEqual(k.annotate(b'text1'), - [(b'text0', b"line the first"), - (b'text0', b"fine")]) + self.assertEqual( + k.annotate(b"text1"), [(b"text0", b"line the first"), (b"text0", b"fine")] + ) class IncludeVersions(TestBase): @@ -471,22 +480,20 @@ def runTest(self): k = Weave() k._parents = [frozenset(), frozenset([0])] - k._weave = [(b'{', 0), - b"first line", - (b'}', 0), - (b'{', 1), - b"second line", - (b'}', 1)] + k._weave = [ + (b"{", 0), + b"first line", + (b"}", 0), + (b"{", 1), + b"second line", + (b"}", 1), + ] - k._sha1s = [sha_string(b'first line'), sha_string( - b'first linesecond line')] + k._sha1s = [sha_string(b"first line"), sha_string(b"first linesecond line")] - self.assertEqual(k.get_lines(1), - [b"first line", - b"second line"]) + self.assertEqual(k.get_lines(1), [b"first line", b"second line"]) - self.assertEqual(k.get_lines(0), - [b"first line"]) + self.assertEqual(k.get_lines(0), [b"first line"]) class DivergedIncludes(TestBase): @@ -496,54 +503,51 @@ def runTest(self): # FIXME make the weave, dont poke at it. 
k = Weave() - k._names = [b'0', b'1', b'2'] - k._name_map = {b'0': 0, b'1': 1, b'2': 2} - k._parents = [frozenset(), - frozenset([0]), - frozenset([0]), - ] - k._weave = [(b'{', 0), - b"first line", - (b'}', 0), - (b'{', 1), - b"second line", - (b'}', 1), - (b'{', 2), - b"alternative second line", - (b'}', 2), - ] + k._names = [b"0", b"1", b"2"] + k._name_map = {b"0": 0, b"1": 1, b"2": 2} + k._parents = [ + frozenset(), + frozenset([0]), + frozenset([0]), + ] + k._weave = [ + (b"{", 0), + b"first line", + (b"}", 0), + (b"{", 1), + b"second line", + (b"}", 1), + (b"{", 2), + b"alternative second line", + (b"}", 2), + ] k._sha1s = [ - sha_string(b'first line'), - sha_string(b'first linesecond line'), - sha_string(b'first linealternative second line')] + sha_string(b"first line"), + sha_string(b"first linesecond line"), + sha_string(b"first linealternative second line"), + ] - self.assertEqual(k.get_lines(0), - [b"first line"]) + self.assertEqual(k.get_lines(0), [b"first line"]) - self.assertEqual(k.get_lines(1), - [b"first line", - b"second line"]) + self.assertEqual(k.get_lines(1), [b"first line", b"second line"]) - self.assertEqual(k.get_lines(b'2'), - [b"first line", - b"alternative second line"]) + self.assertEqual(k.get_lines(b"2"), [b"first line", b"alternative second line"]) - self.assertEqual(set(k.get_ancestry([b'2'])), - {b'0', b'2'}) + self.assertEqual(set(k.get_ancestry([b"2"])), {b"0", b"2"}) class ReplaceLine(TestBase): def runTest(self): k = Weave() - text0 = [b'cheddar', b'stilton', b'gruyere'] - text1 = [b'cheddar', b'blue vein', b'neufchatel', b'chevre'] + text0 = [b"cheddar", b"stilton", b"gruyere"] + text1 = [b"cheddar", b"blue vein", b"neufchatel", b"chevre"] - k.add_lines(b'text0', [], text0) - k.add_lines(b'text1', [b'text0'], text1) + k.add_lines(b"text0", [], text0) + k.add_lines(b"text1", [b"text0"], text1) - self.log('k._weave=' + pformat(k._weave)) + self.log("k._weave=" + pformat(k._weave)) self.assertEqual(k.get_lines(0), text0) self.assertEqual(k.get_lines(1), text1) @@ -556,32 +560,36 @@ def runTest(self): k = Weave() texts = [ - [b'header'], - [b'header', b'', b'line from 1'], - [b'header', b'', b'line from 2', b'more from 2'], - [b'header', b'', b'line from 1', b'fixup line', b'line from 2'], - ] + [b"header"], + [b"header", b"", b"line from 1"], + [b"header", b"", b"line from 2", b"more from 2"], + [b"header", b"", b"line from 1", b"fixup line", b"line from 2"], + ] - k.add_lines(b'text0', [], texts[0]) - k.add_lines(b'text1', [b'text0'], texts[1]) - k.add_lines(b'text2', [b'text0'], texts[2]) - k.add_lines(b'merge', [b'text0', b'text1', b'text2'], texts[3]) + k.add_lines(b"text0", [], texts[0]) + k.add_lines(b"text1", [b"text0"], texts[1]) + k.add_lines(b"text2", [b"text0"], texts[2]) + k.add_lines(b"merge", [b"text0", b"text1", b"text2"], texts[3]) for i, t in enumerate(texts): self.assertEqual(k.get_lines(i), t) - self.assertEqual(k.annotate(b'merge'), - [(b'text0', b'header'), - (b'text1', b''), - (b'text1', b'line from 1'), - (b'merge', b'fixup line'), - (b'text2', b'line from 2'), - ]) + self.assertEqual( + k.annotate(b"merge"), + [ + (b"text0", b"header"), + (b"text1", b""), + (b"text1", b"line from 1"), + (b"merge", b"fixup line"), + (b"text2", b"line from 2"), + ], + ) - self.assertEqual(set(k.get_ancestry([b'merge'])), - {b'text0', b'text1', b'text2', b'merge'}) + self.assertEqual( + set(k.get_ancestry([b"merge"])), {b"text0", b"text1", b"text2", b"merge"} + ) - self.log('k._weave=' + pformat(k._weave)) + self.log("k._weave=" + pformat(k._weave)) 
self.check_read_write(k) @@ -598,15 +606,13 @@ def runTest(self): return # NOT RUN k = Weave() - k.add_lines([], [b'aaa', b'bbb']) - k.add_lines([0], [b'aaa', b'111', b'bbb']) - k.add_lines([1], [b'aaa', b'222', b'bbb']) + k.add_lines([], [b"aaa", b"bbb"]) + k.add_lines([0], [b"aaa", b"111", b"bbb"]) + k.add_lines([1], [b"aaa", b"222", b"bbb"]) k.merge([1, 2]) - self.assertEqual([[[b'aaa']], - [[b'111'], [b'222']], - [[b'bbb']]]) + self.assertEqual([[[b"aaa"]], [[b"111"], [b"222"]], [[b"bbb"]]]) class NonConflict(TestBase): @@ -619,9 +625,9 @@ def runTest(self): return # NOT RUN k = Weave() - k.add_lines([], [b'aaa', b'bbb']) - k.add_lines([0], [b'111', b'aaa', b'ccc', b'bbb']) - k.add_lines([1], [b'aaa', b'ccc', b'bbb', b'222']) + k.add_lines([], [b"aaa", b"bbb"]) + k.add_lines([0], [b"111", b"aaa", b"ccc", b"bbb"]) + k.add_lines([1], [b"aaa", b"ccc", b"bbb", b"222"]) class Khayyam(TestBase): @@ -633,12 +639,10 @@ def test_multi_line_merge(self): A Jug of Wine, a Loaf of Bread, -- and Thou Beside me singing in the Wilderness -- Oh, Wilderness were Paradise enow!""", - b"""A Book of Verses underneath the Bough, A Jug of Wine, a Loaf of Bread, -- and Thou Beside me singing in the Wilderness -- Oh, Wilderness were Paradise now!""", - b"""A Book of poems underneath the tree, A Jug of Wine, a Loaf of Bread, and Thou @@ -646,21 +650,20 @@ def test_multi_line_merge(self): Oh, Wilderness were Paradise now! -- O. Khayyam""", - b"""A Book of Verses underneath the Bough, A Jug of Wine, a Loaf of Bread, and Thou Beside me singing in the Wilderness -- Oh, Wilderness were Paradise now!""", - ] - texts = [[l.strip() for l in t.split(b'\n')] for t in rawtexts] + ] + texts = [[l.strip() for l in t.split(b"\n")] for t in rawtexts] k = Weave() parents = set() i = 0 for t in texts: - k.add_lines(b'text%d' % i, list(parents), t) - parents.add(b'text%d' % i) + k.add_lines(b"text%d" % i, list(parents), t) + parents.add(b"text%d" % i) i += 1 self.log("k._weave=" + pformat(k._weave)) @@ -672,15 +675,14 @@ def test_multi_line_merge(self): class JoinWeavesTests(TestBase): - def setUp(self): super().setUp() self.weave1 = Weave() - self.lines1 = [b'hello\n'] - self.lines3 = [b'hello\n', b'cruel\n', b'world\n'] - self.weave1.add_lines(b'v1', [], self.lines1) - self.weave1.add_lines(b'v2', [b'v1'], [b'hello\n', b'world\n']) - self.weave1.add_lines(b'v3', [b'v2'], self.lines3) + self.lines1 = [b"hello\n"] + self.lines3 = [b"hello\n", b"cruel\n", b"world\n"] + self.weave1.add_lines(b"v1", [], self.lines1) + self.weave1.add_lines(b"v2", [b"v1"], [b"hello\n", b"world\n"]) + self.weave1.add_lines(b"v3", [b"v2"], self.lines3) def test_written_detection(self): # Test detection of weave file corruption. @@ -690,8 +692,8 @@ def test_written_detection(self): # but it at least helps verify the data you get, is what you want. w = Weave() - w.add_lines(b'v1', [], [b'hello\n']) - w.add_lines(b'v2', [b'v1'], [b'hello\n', b'there\n']) + w.add_lines(b"v1", [], [b"hello\n"]) + w.add_lines(b"v2", [b"v1"], [b"hello\n", b"there\n"]) tmpf = BytesIO() write_weave(w, tmpf) @@ -699,54 +701,56 @@ def test_written_detection(self): # Because we are corrupting, we need to make sure we have the exact # text self.assertEqual( - b'# bzr weave file v5\n' - b'i\n1 f572d396fae9206628714fb2ce00f72e94f2258f\nn v1\n\n' - b'i 0\n1 90f265c6e75f1c8f9ab76dcf85528352c5f215ef\nn v2\n\n' - b'w\n{ 0\n. hello\n}\n{ 1\n. 
there\n}\nW\n', - tmpf.getvalue()) + b"# bzr weave file v5\n" + b"i\n1 f572d396fae9206628714fb2ce00f72e94f2258f\nn v1\n\n" + b"i 0\n1 90f265c6e75f1c8f9ab76dcf85528352c5f215ef\nn v2\n\n" + b"w\n{ 0\n. hello\n}\n{ 1\n. there\n}\nW\n", + tmpf.getvalue(), + ) # Change a single letter tmpf = BytesIO( - b'# bzr weave file v5\n' - b'i\n1 f572d396fae9206628714fb2ce00f72e94f2258f\nn v1\n\n' - b'i 0\n1 90f265c6e75f1c8f9ab76dcf85528352c5f215ef\nn v2\n\n' - b'w\n{ 0\n. hello\n}\n{ 1\n. There\n}\nW\n') + b"# bzr weave file v5\n" + b"i\n1 f572d396fae9206628714fb2ce00f72e94f2258f\nn v1\n\n" + b"i 0\n1 90f265c6e75f1c8f9ab76dcf85528352c5f215ef\nn v2\n\n" + b"w\n{ 0\n. hello\n}\n{ 1\n. There\n}\nW\n" + ) w = read_weave(tmpf) - self.assertEqual(b'hello\n', w.get_text(b'v1')) - self.assertRaises(WeaveInvalidChecksum, w.get_text, b'v2') - self.assertRaises(WeaveInvalidChecksum, w.get_lines, b'v2') + self.assertEqual(b"hello\n", w.get_text(b"v1")) + self.assertRaises(WeaveInvalidChecksum, w.get_text, b"v2") + self.assertRaises(WeaveInvalidChecksum, w.get_lines, b"v2") self.assertRaises(WeaveInvalidChecksum, w.check) # Change the sha checksum tmpf = BytesIO( - b'# bzr weave file v5\n' - b'i\n1 f572d396fae9206628714fb2ce00f72e94f2258f\nn v1\n\n' - b'i 0\n1 f0f265c6e75f1c8f9ab76dcf85528352c5f215ef\nn v2\n\n' - b'w\n{ 0\n. hello\n}\n{ 1\n. there\n}\nW\n') + b"# bzr weave file v5\n" + b"i\n1 f572d396fae9206628714fb2ce00f72e94f2258f\nn v1\n\n" + b"i 0\n1 f0f265c6e75f1c8f9ab76dcf85528352c5f215ef\nn v2\n\n" + b"w\n{ 0\n. hello\n}\n{ 1\n. there\n}\nW\n" + ) w = read_weave(tmpf) - self.assertEqual(b'hello\n', w.get_text(b'v1')) - self.assertRaises(WeaveInvalidChecksum, w.get_text, b'v2') - self.assertRaises(WeaveInvalidChecksum, w.get_lines, b'v2') + self.assertEqual(b"hello\n", w.get_text(b"v1")) + self.assertRaises(WeaveInvalidChecksum, w.get_text, b"v2") + self.assertRaises(WeaveInvalidChecksum, w.get_lines, b"v2") self.assertRaises(WeaveInvalidChecksum, w.check) class TestWeave(TestCase): - def test_allow_reserved_false(self): - w = Weave('name', allow_reserved=False) + w = Weave("name", allow_reserved=False) # Add lines is checked at the WeaveFile level, not at the Weave level - w.add_lines(b'name:', [], TEXT_1) + w.add_lines(b"name:", [], TEXT_1) # But get_lines is checked at this level - self.assertRaises(errors.ReservedId, w.get_lines, b'name:') + self.assertRaises(errors.ReservedId, w.get_lines, b"name:") def test_allow_reserved_true(self): - w = Weave('name', allow_reserved=True) - w.add_lines(b'name:', [], TEXT_1) - self.assertEqual(TEXT_1, w.get_lines(b'name:')) + w = Weave("name", allow_reserved=True) + w.add_lines(b"name:", [], TEXT_1) + self.assertEqual(TEXT_1, w.get_lines(b"name:")) class InstrumentedWeave(Weave): @@ -765,7 +769,7 @@ class TestNeedsReweave(TestCase): """Internal corner cases for when reweave is needed.""" def test_compatible_parents(self): - w1 = Weave('a') + w1 = Weave("a") my_parents = {1, 2, 3} # subsets are ok self.assertTrue(w1._compatible_parents(my_parents, {3})) @@ -780,7 +784,6 @@ def test_compatible_parents(self): class TestWeaveFile(TestCaseInTempDir): - def test_empty_file(self): - with open('empty.weave', 'wb+') as f: + with open("empty.weave", "wb+") as f: self.assertRaises(WeaveFormatError, read_weave, f) diff --git a/breezy/bzr/tests/test_workingtree.py b/breezy/bzr/tests/test_workingtree.py index 983cc3faf4..4506c450ab 100644 --- a/breezy/bzr/tests/test_workingtree.py +++ b/breezy/bzr/tests/test_workingtree.py @@ -21,10 +21,11 @@ class ErrorTests(TestCase): - def 
test_inventory_modified(self): error = InventoryModified("a tree to be repred") - self.assertEqualDiff("The current inventory for the tree 'a tree to " - "be repred' has been modified, so a clean inventory cannot be " - "read without data loss.", - str(error)) + self.assertEqualDiff( + "The current inventory for the tree 'a tree to " + "be repred' has been modified, so a clean inventory cannot be " + "read without data loss.", + str(error), + ) diff --git a/breezy/bzr/tests/test_workingtree_4.py b/breezy/bzr/tests/test_workingtree_4.py index 3217ed1501..05579f914a 100644 --- a/breezy/bzr/tests/test_workingtree_4.py +++ b/breezy/bzr/tests/test_workingtree_4.py @@ -40,13 +40,12 @@ def test_disk_layout(self): # format 'Bazaar Working Tree format 4' # stat-cache = ?? t = control.get_workingtree_transport(None) - with t.get('format') as f: - self.assertEqualDiff(b'Bazaar Working Tree Format 4 (bzr 0.15)\n', - f.read()) - self.assertFalse(t.has('inventory.basis')) + with t.get("format") as f: + self.assertEqualDiff(b"Bazaar Working Tree Format 4 (bzr 0.15)\n", f.read()) + self.assertFalse(t.has("inventory.basis")) # no last-revision file means 'None' or 'NULLREVISION' - self.assertFalse(t.has('last-revision')) - state = dirstate.DirState.on_file(t.local_abspath('dirstate')) + self.assertFalse(t.has("last-revision")) + state = dirstate.DirState.on_file(t.local_abspath("dirstate")) state.lock_read() try: self.assertEqual([], state.get_parent_ids()) @@ -74,27 +73,27 @@ def test_uses_lockdir(self): # one in common. t = self.get_transport() tree = self.make_workingtree() - self.assertIsDirectory('.bzr', t) - self.assertIsDirectory('.bzr/checkout', t) - self.assertIsDirectory('.bzr/checkout/lock', t) - our_lock = LockDir(t, '.bzr/checkout/lock') + self.assertIsDirectory(".bzr", t) + self.assertIsDirectory(".bzr/checkout", t) + self.assertIsDirectory(".bzr/checkout/lock", t) + our_lock = LockDir(t, ".bzr/checkout/lock") self.assertEqual(our_lock.peek(), None) tree.lock_write() self.assertTrue(our_lock.peek()) tree.unlock() self.assertEqual(our_lock.peek(), None) - def make_workingtree(self, relpath=''): + def make_workingtree(self, relpath=""): url = self.get_url(relpath) if relpath: - self.build_tree([relpath + '/']) + self.build_tree([relpath + "/"]) dir = bzrdir.BzrDirMetaFormat1().initialize(url) dir.create_repository() dir.create_branch() try: return workingtree_4.WorkingTreeFormat4().initialize(dir) except errors.NotLocalUrl as e: - raise TestSkipped('Not a local URL') from e + raise TestSkipped("Not a local URL") from e def test_dirstate_stores_all_parent_inventories(self): tree = self.make_workingtree() @@ -107,24 +106,28 @@ def test_dirstate_stores_all_parent_inventories(self): # revisions elsewhere and pull them across, doing by hand part of the # work that merge would do. - subtree = self.make_branch_and_tree('subdir') + subtree = self.make_branch_and_tree("subdir") # writelock the tree so its repository doesn't get readlocked by # the revision tree locks. This works around the bug where we dont # permit lock upgrading. 
subtree.lock_write() self.addCleanup(subtree.unlock) - self.build_tree(['subdir/file-a', ]) - subtree.add(['file-a'], ids=[b'id-a']) - rev1 = subtree.commit('commit in subdir') + self.build_tree( + [ + "subdir/file-a", + ] + ) + subtree.add(["file-a"], ids=[b"id-a"]) + rev1 = subtree.commit("commit in subdir") - subtree2 = subtree.controldir.sprout('subdir2').open_workingtree() - self.build_tree(['subdir2/file-b']) - subtree2.add(['file-b'], ids=[b'id-b']) - rev2 = subtree2.commit('commit in subdir2') + subtree2 = subtree.controldir.sprout("subdir2").open_workingtree() + self.build_tree(["subdir2/file-b"]) + subtree2.add(["file-b"], ids=[b"id-b"]) + rev2 = subtree2.commit("commit in subdir2") subtree.flush() - subtree3 = subtree.controldir.sprout('subdir3').open_workingtree() - rev3 = subtree3.commit('merge from subdir2') + subtree3 = subtree.controldir.sprout("subdir3").open_workingtree() + rev3 = subtree3.commit("merge from subdir2") repo = tree.branch.repository repo.fetch(subtree.branch.repository, rev1) @@ -140,10 +143,13 @@ def test_dirstate_stores_all_parent_inventories(self): # set the parents as if a merge had taken place. # this should cause the tree data to be folded into the # dirstate. - tree.set_parent_trees([ - (rev1, rev1_revtree), - (rev2, rev2_revtree), - (rev3, rev3_revtree), ]) + tree.set_parent_trees( + [ + (rev1, rev1_revtree), + (rev2, rev2_revtree), + (rev3, rev3_revtree), + ] + ) # create tree-sourced revision trees rev1_tree = tree.revision_tree(rev1) @@ -168,8 +174,8 @@ def test_dirstate_doesnt_read_parents_from_repo_when_setting(self): """ tree = self.make_workingtree() - subtree = self.make_branch_and_tree('subdir') - rev1 = subtree.commit('commit in subdir') + subtree = self.make_branch_and_tree("subdir") + rev1 = subtree.commit("commit in subdir") rev1_tree = subtree.basis_tree() rev1_tree.lock_read() self.addCleanup(rev1_tree.unlock) @@ -193,19 +199,19 @@ def test_dirstate_doesnt_read_from_repo_when_returning_cache_tree(self): """ tree = self.make_workingtree() - subtree = self.make_branch_and_tree('subdir') + subtree = self.make_branch_and_tree("subdir") # writelock the tree so its repository doesn't get readlocked by # the revision tree locks. This works around the bug where we dont # permit lock upgrading. 
subtree.lock_write() self.addCleanup(subtree.unlock) - rev1 = subtree.commit('commit in subdir') + rev1 = subtree.commit("commit in subdir") rev1_tree = subtree.basis_tree() rev1_tree.lock_read() # Trigger reading of inventory rev1_tree.root_inventory # noqa: B018 self.addCleanup(rev1_tree.unlock) - rev2 = subtree.commit('second commit in subdir', allow_pointless=True) + rev2 = subtree.commit("second commit in subdir", allow_pointless=True) rev2_tree = subtree.basis_tree() rev2_tree.lock_read() # Trigger reading of inventory @@ -234,8 +240,11 @@ def test_dirstate_doesnt_read_from_repo_when_returning_cache_tree(self): # returned trees self.assertTreesEqual(rev2_tree, result_rev2_tree) self.assertRaises( - errors.NoSuchRevisionInTree, self.assertTreesEqual, rev1_tree, - result_rev1_tree) + errors.NoSuchRevisionInTree, + self.assertTreesEqual, + rev1_tree, + result_rev1_tree, + ) def test_dirstate_doesnt_cache_non_parent_trees(self): """Getting parent trees from a dirstate tree does not read from the @@ -246,8 +255,8 @@ def test_dirstate_doesnt_cache_non_parent_trees(self): # make a tree that we can try for, which is able to be returned but # must not be - subtree = self.make_branch_and_tree('subdir') - rev1 = subtree.commit('commit in subdir') + subtree = self.make_branch_and_tree("subdir") + rev1 = subtree.commit("commit in subdir") tree.branch.pull(subtree.branch) # check it fails self.assertRaises(errors.NoSuchRevision, tree.revision_tree, rev1) @@ -255,35 +264,45 @@ def test_dirstate_doesnt_cache_non_parent_trees(self): def test_no_dirstate_outside_lock(self): # temporary test until the code is mature enough to test from outside. """Getting a dirstate object fails if there is no lock.""" + def lock_and_call_current_dirstate(tree, lock_method): getattr(tree, lock_method)() tree.current_dirstate() tree.unlock() + tree = self.make_workingtree() self.assertRaises(errors.ObjectNotLocked, tree.current_dirstate) - lock_and_call_current_dirstate(tree, 'lock_read') + lock_and_call_current_dirstate(tree, "lock_read") self.assertRaises(errors.ObjectNotLocked, tree.current_dirstate) - lock_and_call_current_dirstate(tree, 'lock_write') + lock_and_call_current_dirstate(tree, "lock_write") self.assertRaises(errors.ObjectNotLocked, tree.current_dirstate) - lock_and_call_current_dirstate(tree, 'lock_tree_write') + lock_and_call_current_dirstate(tree, "lock_tree_write") self.assertRaises(errors.ObjectNotLocked, tree.current_dirstate) def test_set_parent_trees_uses_update_basis_by_delta(self): - builder = self.make_branch_builder('source') + builder = self.make_branch_builder("source") builder.start_series() self.addCleanup(builder.finish_series) - builder.build_snapshot([], [ - ('add', ('', b'root-id', 'directory', None)), - ('add', ('a', b'a-id', 'file', b'content\n'))], - revision_id=b'A') - builder.build_snapshot([b'A'], [ - ('modify', ('a', b'new content\nfor a\n')), - ('add', ('b', b'b-id', 'file', b'b-content\n'))], - revision_id=b'B') - tree = self.make_workingtree('tree') + builder.build_snapshot( + [], + [ + ("add", ("", b"root-id", "directory", None)), + ("add", ("a", b"a-id", "file", b"content\n")), + ], + revision_id=b"A", + ) + builder.build_snapshot( + [b"A"], + [ + ("modify", ("a", b"new content\nfor a\n")), + ("add", ("b", b"b-id", "file", b"b-content\n")), + ], + revision_id=b"B", + ) + tree = self.make_workingtree("tree") source_branch = builder.get_branch() - tree.branch.repository.fetch(source_branch.repository, b'B') - tree.pull(source_branch, stop_revision=b'A') + 
tree.branch.repository.fetch(source_branch.repository, b"B") + tree.pull(source_branch, stop_revision=b"A") tree.lock_write() self.addCleanup(tree.unlock) state = tree.current_dirstate() @@ -293,48 +312,58 @@ def test_set_parent_trees_uses_update_basis_by_delta(self): def log_update_basis_by_delta(delta, new_revid): called.append(new_revid) return orig_update(delta, new_revid) + state.update_basis_by_delta = log_update_basis_by_delta basis = tree.basis_tree() - self.assertEqual(b'a-id', basis.path2id('a')) - self.assertFalse(basis.is_versioned('b')) + self.assertEqual(b"a-id", basis.path2id("a")) + self.assertFalse(basis.is_versioned("b")) def fail_set_parent_trees(trees, ghosts): - raise AssertionError('dirstate.set_parent_trees() was called') + raise AssertionError("dirstate.set_parent_trees() was called") + state.set_parent_trees = fail_set_parent_trees - tree.pull(source_branch, stop_revision=b'B') - self.assertEqual([b'B'], called) + tree.pull(source_branch, stop_revision=b"B") + self.assertEqual([b"B"], called) basis = tree.basis_tree() - self.assertEqual(b'a-id', basis.path2id('a')) - self.assertEqual(b'b-id', basis.path2id('b')) + self.assertEqual(b"a-id", basis.path2id("a")) + self.assertEqual(b"b-id", basis.path2id("b")) def test_set_parent_trees_handles_missing_basis(self): - builder = self.make_branch_builder('source') + builder = self.make_branch_builder("source") builder.start_series() self.addCleanup(builder.finish_series) - builder.build_snapshot([], [ - ('add', ('', b'root-id', 'directory', None)), - ('add', ('a', b'a-id', 'file', b'content\n'))], - revision_id=b'A') - builder.build_snapshot([b'A'], [ - ('modify', ('a', b'new content\nfor a\n')), - ('add', ('b', b'b-id', 'file', b'b-content\n'))], - revision_id=b'B') - builder.build_snapshot([b'A'], [ - ('add', ('c', b'c-id', 'file', b'c-content\n'))], - revision_id=b'C') - b_c = self.make_branch('branch_with_c') - b_c.pull(builder.get_branch(), stop_revision=b'C') - b_b = self.make_branch('branch_with_b') - b_b.pull(builder.get_branch(), stop_revision=b'B') + builder.build_snapshot( + [], + [ + ("add", ("", b"root-id", "directory", None)), + ("add", ("a", b"a-id", "file", b"content\n")), + ], + revision_id=b"A", + ) + builder.build_snapshot( + [b"A"], + [ + ("modify", ("a", b"new content\nfor a\n")), + ("add", ("b", b"b-id", "file", b"b-content\n")), + ], + revision_id=b"B", + ) + builder.build_snapshot( + [b"A"], [("add", ("c", b"c-id", "file", b"c-content\n"))], revision_id=b"C" + ) + b_c = self.make_branch("branch_with_c") + b_c.pull(builder.get_branch(), stop_revision=b"C") + b_b = self.make_branch("branch_with_b") + b_b.pull(builder.get_branch(), stop_revision=b"B") # This is reproducing some of what 'switch' does, just to isolate the # set_parent_trees() step. 
- wt = b_b.create_checkout('tree', lightweight=True) + wt = b_b.create_checkout("tree", lightweight=True) fmt = wt.controldir.find_branch_format() fmt.set_reference(wt.controldir, None, b_c) # Re-open with the new reference wt = wt.controldir.open_workingtree() - wt.set_parent_trees([(b'C', b_c.repository.revision_tree(b'C'))]) - self.assertFalse(wt.basis_tree().is_versioned('b')) + wt.set_parent_trees([(b"C", b_c.repository.revision_tree(b"C"))]) + self.assertFalse(wt.basis_tree().is_versioned("b")) def test_new_dirstate_on_new_lock(self): # until we have detection for when a dirstate can be reused, we @@ -347,28 +376,27 @@ def lock_and_compare_all_current_dirstate(tree, lock_method): self.assertNotIn(state, known_dirstates) known_dirstates.add(state) tree.unlock() + tree = self.make_workingtree() # lock twice with each type to prevent silly per-lock-type bugs. # each lock and compare looks for a unique state object. - lock_and_compare_all_current_dirstate(tree, 'lock_read') - lock_and_compare_all_current_dirstate(tree, 'lock_read') - lock_and_compare_all_current_dirstate(tree, 'lock_tree_write') - lock_and_compare_all_current_dirstate(tree, 'lock_tree_write') - lock_and_compare_all_current_dirstate(tree, 'lock_write') - lock_and_compare_all_current_dirstate(tree, 'lock_write') + lock_and_compare_all_current_dirstate(tree, "lock_read") + lock_and_compare_all_current_dirstate(tree, "lock_read") + lock_and_compare_all_current_dirstate(tree, "lock_tree_write") + lock_and_compare_all_current_dirstate(tree, "lock_tree_write") + lock_and_compare_all_current_dirstate(tree, "lock_write") + lock_and_compare_all_current_dirstate(tree, "lock_write") def test_constructing_invalid_interdirstate_raises(self): tree = self.make_workingtree() - rev_id = tree.commit('first post') - tree.commit('second post') + rev_id = tree.commit("first post") + tree.commit("second post") rev_tree = tree.branch.repository.revision_tree(rev_id) # Exception is not a great thing to raise, but this test is # very short, and code is used to sanity check other tests, so # a full error object is YAGNI. - self.assertRaises( - Exception, workingtree_4.InterDirStateTree, rev_tree, tree) - self.assertRaises( - Exception, workingtree_4.InterDirStateTree, tree, rev_tree) + self.assertRaises(Exception, workingtree_4.InterDirStateTree, rev_tree, tree) + self.assertRaises(Exception, workingtree_4.InterDirStateTree, tree, rev_tree) def test_revtree_to_revtree_not_interdirstate(self): # we should not get a dirstate optimiser for two repository sourced @@ -376,43 +404,31 @@ def test_revtree_to_revtree_not_interdirstate(self): # of all formats; though that could be written in the future it doesn't # seem well worth it. 
tree = self.make_workingtree() - rev_id = tree.commit('first post') - rev_id2 = tree.commit('second post') + rev_id = tree.commit("first post") + rev_id2 = tree.commit("second post") rev_tree = tree.branch.repository.revision_tree(rev_id) rev_tree2 = tree.branch.repository.revision_tree(rev_id2) optimiser = InterTree.get(rev_tree, rev_tree2) self.assertIsInstance(optimiser, InterTree) - self.assertNotIsInstance( - optimiser, - workingtree_4.InterDirStateTree - ) + self.assertNotIsInstance(optimiser, workingtree_4.InterDirStateTree) optimiser = InterTree.get(rev_tree2, rev_tree) self.assertIsInstance(optimiser, InterTree) - self.assertNotIsInstance( - optimiser, - workingtree_4.InterDirStateTree - ) + self.assertNotIsInstance(optimiser, workingtree_4.InterDirStateTree) def test_revtree_not_in_dirstate_to_dirstate_not_interdirstate(self): # we should not get a dirstate optimiser when the revision id for of # the source is not in the dirstate of the target. tree = self.make_workingtree() - rev_id = tree.commit('first post') - tree.commit('second post') + rev_id = tree.commit("first post") + tree.commit("second post") rev_tree = tree.branch.repository.revision_tree(rev_id) tree.lock_read() optimiser = InterTree.get(rev_tree, tree) self.assertIsInstance(optimiser, InterTree) - self.assertNotIsInstance( - optimiser, - workingtree_4.InterDirStateTree - ) + self.assertNotIsInstance(optimiser, workingtree_4.InterDirStateTree) optimiser = InterTree.get(tree, rev_tree) self.assertIsInstance(optimiser, InterTree) - self.assertNotIsInstance( - optimiser, - workingtree_4.InterDirStateTree - ) + self.assertNotIsInstance(optimiser, workingtree_4.InterDirStateTree) tree.unlock() def test_empty_basis_to_dirstate_tree(self): @@ -433,7 +449,7 @@ def test_nonempty_basis_to_dirstate_tree(self): # 'changes_from' from a non-null basis dirstate revision tree to a # WorkingTree4. tree = self.make_workingtree() - tree.commit('first post') + tree.commit("first post") tree.lock_read() basis_tree = tree.basis_tree() basis_tree.lock_read() @@ -460,7 +476,7 @@ def test_nonempty_basis_revtree_to_dirstate_tree(self): # 'changes_from' from a non-null repository based rev tree to a # WorkingTree4. tree = self.make_workingtree() - tree.commit('first post') + tree.commit("first post") tree.lock_read() basis_tree = tree.branch.repository.revision_tree(tree.last_revision()) basis_tree.lock_read() @@ -474,9 +490,9 @@ def test_tree_to_basis_in_other_tree(self): # the source revid is in the dirstate object of the target and # the dirstates are different. This is largely covered by testing # with repository revtrees, so is just for extra confidence. - tree = self.make_workingtree('a') - tree.commit('first post') - tree2 = self.make_workingtree('b') + tree = self.make_workingtree("a") + tree.commit("first post") + tree2 = self.make_workingtree("b") tree2.pull(tree.branch) basis_tree = tree.basis_tree() tree2.lock_read() @@ -489,12 +505,12 @@ def test_tree_to_basis_in_other_tree(self): def test_merged_revtree_to_tree(self): # we should get a InterDirStateTree when # the source tree is a merged tree present in the dirstate of target. 
- tree = self.make_workingtree('a') - tree.commit('first post') - tree.commit('tree 1 commit 2') - tree2 = self.make_workingtree('b') + tree = self.make_workingtree("a") + tree.commit("first post") + tree.commit("tree 1 commit 2") + tree2 = self.make_workingtree("b") tree2.pull(tree.branch) - tree2.commit('tree 2 commit 2') + tree2.commit("tree 2 commit 2") tree.merge_from_branch(tree2.branch) second_parent_tree = tree.revision_tree(tree.get_parent_ids()[1]) second_parent_tree.lock_read() @@ -505,40 +521,38 @@ def test_merged_revtree_to_tree(self): self.assertIsInstance(optimiser, workingtree_4.InterDirStateTree) def test_id2path(self): - tree = self.make_workingtree('tree') - self.build_tree(['tree/a', 'tree/b']) - tree.add(['a'], ids=[b'a-id']) - self.assertEqual('a', tree.id2path(b'a-id')) - self.assertRaises(errors.NoSuchId, tree.id2path, b'a') - tree.commit('a') - tree.add(['b'], ids=[b'b-id']) + tree = self.make_workingtree("tree") + self.build_tree(["tree/a", "tree/b"]) + tree.add(["a"], ids=[b"a-id"]) + self.assertEqual("a", tree.id2path(b"a-id")) + self.assertRaises(errors.NoSuchId, tree.id2path, b"a") + tree.commit("a") + tree.add(["b"], ids=[b"b-id"]) try: - new_path = 'b\u03bcrry' - tree.rename_one('a', new_path) + new_path = "b\u03bcrry" + tree.rename_one("a", new_path) except UnicodeEncodeError: # support running the test on non-unicode platforms - new_path = 'c' - tree.rename_one('a', new_path) - self.assertEqual(new_path, tree.id2path(b'a-id')) - tree.commit('b\xb5rry') + new_path = "c" + tree.rename_one("a", new_path) + self.assertEqual(new_path, tree.id2path(b"a-id")) + tree.commit("b\xb5rry") tree.unversion([new_path]) - self.assertRaises(errors.NoSuchId, tree.id2path, b'a-id') - self.assertEqual('b', tree.id2path(b'b-id')) - self.assertRaises(errors.NoSuchId, tree.id2path, b'c-id') + self.assertRaises(errors.NoSuchId, tree.id2path, b"a-id") + self.assertEqual("b", tree.id2path(b"b-id")) + self.assertRaises(errors.NoSuchId, tree.id2path, b"c-id") def test_unique_root_id_per_tree(self): # each time you initialize a new tree, it gets a different root id - format_name = 'development-subtree' - tree1 = self.make_branch_and_tree('tree1', - format=format_name) - tree2 = self.make_branch_and_tree('tree2', - format=format_name) - self.assertNotEqual(tree1.path2id(''), tree2.path2id('')) + format_name = "development-subtree" + tree1 = self.make_branch_and_tree("tree1", format=format_name) + tree2 = self.make_branch_and_tree("tree2", format=format_name) + self.assertNotEqual(tree1.path2id(""), tree2.path2id("")) # when you branch, it inherits the same root id - tree1.commit('first post') - tree3 = tree1.controldir.sprout('tree3').open_workingtree() - self.assertEqual(tree3.path2id(''), tree1.path2id('')) + tree1.commit("first post") + tree3 = tree1.controldir.sprout("tree3").open_workingtree() + self.assertEqual(tree3.path2id(""), tree1.path2id("")) def test_set_root_id(self): # similar to some code that fails in the dirstate-plus-subtree branch @@ -547,75 +561,94 @@ def test_set_root_id(self): def validate(): with wt.lock_read(): wt.current_dirstate()._validate() - wt = self.make_workingtree('tree') - wt.set_root_id(b'TREE-ROOTID') + + wt = self.make_workingtree("tree") + wt.set_root_id(b"TREE-ROOTID") validate() - wt.commit('somenthing') + wt.commit("somenthing") validate() # now switch and commit again - wt.set_root_id(b'tree-rootid') + wt.set_root_id(b"tree-rootid") validate() - wt.commit('again') + wt.commit("again") validate() def test_default_root_id(self): - tree = 
self.make_branch_and_tree('tag', format='dirstate-tags') - self.assertEqual(inventory.ROOT_ID, tree.path2id('')) - tree = self.make_branch_and_tree('subtree', - format='development-subtree') - self.assertNotEqual(inventory.ROOT_ID, tree.path2id('')) + tree = self.make_branch_and_tree("tag", format="dirstate-tags") + self.assertEqual(inventory.ROOT_ID, tree.path2id("")) + tree = self.make_branch_and_tree("subtree", format="development-subtree") + self.assertNotEqual(inventory.ROOT_ID, tree.path2id("")) def test_non_subtree_with_nested_trees(self): # prior to dirstate, st/diff/commit ignored nested trees. # dirstate, as opposed to development-subtree, should # behave the same way. - tree = self.make_branch_and_tree('.', format='dirstate') + tree = self.make_branch_and_tree(".", format="dirstate") self.assertFalse(tree.supports_tree_reference()) - self.build_tree(['dir/']) + self.build_tree(["dir/"]) # for testing easily. - tree.set_root_id(b'root') - tree.add(['dir'], ids=[b'dir-id']) - self.make_branch_and_tree('dir') + tree.set_root_id(b"root") + tree.add(["dir"], ids=[b"dir-id"]) + self.make_branch_and_tree("dir") # the most primitive operation: kind - self.assertEqual('directory', tree.kind('dir')) + self.assertEqual("directory", tree.kind("dir")) # a diff against the basis should give us a directory and the root (as # the root is new too). tree.lock_read() - expected = [(b'dir-id', - (None, 'dir'), - True, - (False, True), - (None, b'root'), - (None, 'dir'), - (None, 'directory'), - (None, False), False), - (b'root', (None, ''), True, (False, True), (None, None), - (None, ''), (None, 'directory'), (None, False), False)] + expected = [ + ( + b"dir-id", + (None, "dir"), + True, + (False, True), + (None, b"root"), + (None, "dir"), + (None, "directory"), + (None, False), + False, + ), + ( + b"root", + (None, ""), + True, + (False, True), + (None, None), + (None, ""), + (None, "directory"), + (None, False), + False, + ), + ] self.assertEqual( - expected, - list(tree.iter_changes(tree.basis_tree(), specific_files=['dir']))) + expected, list(tree.iter_changes(tree.basis_tree(), specific_files=["dir"])) + ) tree.unlock() # do a commit, we want to trigger the dirstate fast-path too - tree.commit('first post') + tree.commit("first post") # change the path for the subdir, which will trigger getting all # its data: - os.rename('dir', 'also-dir') + os.rename("dir", "also-dir") # now the diff will use the fast path tree.lock_read() - expected = [(b'dir-id', - ('dir', 'dir'), - True, - (True, True), - (b'root', b'root'), - ('dir', 'dir'), - ('directory', None), - (False, False), False)] + expected = [ + ( + b"dir-id", + ("dir", "dir"), + True, + (True, True), + (b"root", b"root"), + ("dir", "dir"), + ("directory", None), + (False, False), + False, + ) + ] self.assertEqual(expected, list(tree.iter_changes(tree.basis_tree()))) tree.unlock() def test_with_subtree_supports_tree_references(self): # development-subtree should support tree-references. 
- tree = self.make_branch_and_tree('.', format='development-subtree') + tree = self.make_branch_and_tree(".", format="development-subtree") self.assertTrue(tree.supports_tree_reference()) # having checked this is on, the tree interface, and intertree # interface tests, will proceed to test the subtree support of @@ -623,25 +656,28 @@ def test_with_subtree_supports_tree_references(self): def test_iter_changes_ignores_unversioned_dirs(self): """iter_changes should not descend into unversioned directories.""" - tree = self.make_branch_and_tree('.', format='dirstate') + tree = self.make_branch_and_tree(".", format="dirstate") # We have an unversioned directory at the root, a versioned one with # other versioned files and an unversioned directory, and another # versioned dir with nothing but an unversioned directory. - self.build_tree(['unversioned/', - 'unversioned/a', - 'unversioned/b/', - 'versioned/', - 'versioned/unversioned/', - 'versioned/unversioned/a', - 'versioned/unversioned/b/', - 'versioned2/', - 'versioned2/a', - 'versioned2/unversioned/', - 'versioned2/unversioned/a', - 'versioned2/unversioned/b/', - ]) - tree.add(['versioned', 'versioned2', 'versioned2/a']) - tree.commit('one', rev_id=b'rev-1') + self.build_tree( + [ + "unversioned/", + "unversioned/a", + "unversioned/b/", + "versioned/", + "versioned/unversioned/", + "versioned/unversioned/a", + "versioned/unversioned/b/", + "versioned2/", + "versioned2/a", + "versioned2/unversioned/", + "versioned2/unversioned/a", + "versioned2/unversioned/b/", + ] + ) + tree.add(["versioned", "versioned2", "versioned2/a"]) + tree.commit("one", rev_id=b"rev-1") # Trap osutils._walkdirs_utf8 to spy on what dirs have been accessed. returned = [] @@ -649,81 +685,93 @@ def walkdirs_spy(*args, **kwargs): for val in orig(*args, **kwargs): returned.append(val[0][0]) yield val - orig = self.overrideAttr(osutils, '_walkdirs_utf8', walkdirs_spy) + + orig = self.overrideAttr(osutils, "_walkdirs_utf8", walkdirs_spy) basis = tree.basis_tree() tree.lock_read() self.addCleanup(tree.unlock) basis.lock_read() self.addCleanup(basis.unlock) - changes = [c.path for c in - tree.iter_changes(basis, want_unversioned=True)] - self.assertEqual([(None, 'unversioned'), - (None, 'versioned/unversioned'), - (None, 'versioned2/unversioned'), - ], changes) - self.assertEqual([b'', b'versioned', b'versioned2'], returned) + changes = [c.path for c in tree.iter_changes(basis, want_unversioned=True)] + self.assertEqual( + [ + (None, "unversioned"), + (None, "versioned/unversioned"), + (None, "versioned2/unversioned"), + ], + changes, + ) + self.assertEqual([b"", b"versioned", b"versioned2"], returned) del returned[:] # reset changes = [c[1] for c in tree.iter_changes(basis)] self.assertEqual([], changes) - self.assertEqual([b'', b'versioned', b'versioned2'], returned) + self.assertEqual([b"", b"versioned", b"versioned2"], returned) def test_iter_changes_unversioned_error(self): """Check if a PathsNotVersionedError is correctly raised and the paths list contains all unversioned entries only. 
""" - tree = self.make_branch_and_tree('tree') - self.build_tree_contents([('tree/bar', b'')]) - tree.add(['bar'], ids=[b'bar-id']) + tree = self.make_branch_and_tree("tree") + self.build_tree_contents([("tree/bar", b"")]) + tree.add(["bar"], ids=[b"bar-id"]) tree.lock_read() self.addCleanup(tree.unlock) def tree_iter_changes(files): - return list(tree.iter_changes( - tree.basis_tree(), specific_files=files, - require_versioned=True)) - e = self.assertRaises(errors.PathsNotVersionedError, - tree_iter_changes, ['bar', 'foo']) - self.assertEqual(e.paths, ['foo']) + return list( + tree.iter_changes( + tree.basis_tree(), specific_files=files, require_versioned=True + ) + ) + + e = self.assertRaises( + errors.PathsNotVersionedError, tree_iter_changes, ["bar", "foo"] + ) + self.assertEqual(e.paths, ["foo"]) def test_iter_changes_unversioned_non_ascii(self): """Unversioned non-ascii paths should be reported as unicode.""" self.requireFeature(features.UnicodeFilenameFeature) - tree = self.make_branch_and_tree('.') - self.build_tree_contents([('f', b'')]) - tree.add(['f'], ids=[b'f-id']) + tree = self.make_branch_and_tree(".") + self.build_tree_contents([("f", b"")]) + tree.add(["f"], ids=[b"f-id"]) def tree_iter_changes(tree, files): - return list(tree.iter_changes( - tree.basis_tree(), specific_files=files, - require_versioned=True)) + return list( + tree.iter_changes( + tree.basis_tree(), specific_files=files, require_versioned=True + ) + ) + tree.lock_read() self.addCleanup(tree.unlock) - e = self.assertRaises(errors.PathsNotVersionedError, - tree_iter_changes, tree, ['\xa7', '\u03c0']) - self.assertEqual(set(e.paths), {'\xa7', '\u03c0'}) + e = self.assertRaises( + errors.PathsNotVersionedError, tree_iter_changes, tree, ["\xa7", "\u03c0"] + ) + self.assertEqual(set(e.paths), {"\xa7", "\u03c0"}) def get_tree_with_cachable_file_foo(self): - tree = self.make_branch_and_tree('.') + tree = self.make_branch_and_tree(".") tree.lock_write() self.addCleanup(tree.unlock) - self.build_tree_contents([('foo', b'a bit of content for foo\n')]) - tree.add(['foo'], ids=[b'foo-id']) + self.build_tree_contents([("foo", b"a bit of content for foo\n")]) + tree.add(["foo"], ids=[b"foo-id"]) tree.current_dirstate()._cutoff_time = time.time() + 60 return tree def test_commit_updates_hash_cache(self): tree = self.get_tree_with_cachable_file_foo() - tree.commit('a commit') + tree.commit("a commit") # tree's dirstate should now have a valid stat entry for foo. 
- entry = tree._get_entry(path='foo') - expected_sha1 = osutils.sha_file_by_name('foo') + entry = tree._get_entry(path="foo") + expected_sha1 = osutils.sha_file_by_name("foo") self.assertEqual(expected_sha1, entry[1][0][1]) - self.assertEqual(len('a bit of content for foo\n'), entry[1][0][2]) + self.assertEqual(len("a bit of content for foo\n"), entry[1][0][2]) def test_observed_sha1_cachable(self): tree = self.get_tree_with_cachable_file_foo() - expected_sha1 = osutils.sha_file_by_name('foo') + expected_sha1 = osutils.sha_file_by_name("foo") statvalue = os.lstat("foo") tree._observed_sha1("foo", (expected_sha1, statvalue)) entry = tree._get_entry(path="foo") @@ -741,27 +789,27 @@ def test_observed_sha1_cachable(self): self.assertEqual(statvalue.st_size, entry_state[2]) def test_observed_sha1_new_file(self): - tree = self.make_branch_and_tree('.') - self.build_tree(['foo']) - tree.add(['foo'], ids=[b'foo-id']) + tree = self.make_branch_and_tree(".") + self.build_tree(["foo"]) + tree.add(["foo"], ids=[b"foo-id"]) with tree.lock_read(): current_sha1 = tree._get_entry(path="foo")[1][0][1] with tree.lock_write(): tree._observed_sha1( - "foo", (osutils.sha_file_by_name('foo'), os.lstat("foo"))) + "foo", (osutils.sha_file_by_name("foo"), os.lstat("foo")) + ) # Must not have changed - self.assertEqual(current_sha1, - tree._get_entry(path="foo")[1][0][1]) + self.assertEqual(current_sha1, tree._get_entry(path="foo")[1][0][1]) def test_get_file_with_stat_id_only(self): # Explicit test to ensure we get a lstat value from WT4 trees. - tree = self.make_branch_and_tree('.') - self.build_tree(['foo']) - tree.add(['foo']) + tree = self.make_branch_and_tree(".") + self.build_tree(["foo"]) + tree.add(["foo"]) tree.lock_read() self.addCleanup(tree.unlock) - file_obj, statvalue = tree.get_file_with_stat('foo') - expected = os.lstat('foo') + file_obj, statvalue = tree.get_file_with_stat("foo") + expected = os.lstat("foo") self.assertEqualStat(expected, statvalue) self.assertEqual([b"contents of foo\n"], file_obj.readlines()) @@ -781,15 +829,19 @@ def test_invalid_rename(self): # Create a corrupted dirstate with tree.lock_write(): # We need a parent, or we always compare with NULL - tree.commit('init') + tree.commit("init") state = tree.current_dirstate() state._read_dirblocks_if_needed() # Now add in an invalid entry, a rename with a dangling pointer - state._dirblocks[1][1].append(((b'', b'foo', b'foo-id'), - [(b'f', b'', 0, False, b''), - (b'r', b'bar', 0, False, b'')])) - self.assertListRaises(dirstate.DirstateCorrupt, - tree.iter_changes, tree.basis_tree()) + state._dirblocks[1][1].append( + ( + (b"", b"foo", b"foo-id"), + [(b"f", b"", 0, False, b""), (b"r", b"bar", 0, False, b"")], + ) + ) + self.assertListRaises( + dirstate.DirstateCorrupt, tree.iter_changes, tree.basis_tree() + ) def get_simple_dirblocks(self, state): """Extract the simple information from the DirState. 
@@ -810,42 +862,53 @@ def get_simple_dirblocks(self, state): def test_update_basis_with_invalid_delta(self): """When given an invalid delta, it should abort, and not be saved.""" - self.build_tree(['dir/', 'dir/file']) + self.build_tree(["dir/", "dir/file"]) tree = self.create_wt4() tree.lock_write() self.addCleanup(tree.unlock) - tree.add(['dir', 'dir/file'], ids=[b'dir-id', b'file-id']) - first_revision_id = tree.commit('init') + tree.add(["dir", "dir/file"], ids=[b"dir-id", b"file-id"]) + first_revision_id = tree.commit("init") - root_id = tree.path2id('') + root_id = tree.path2id("") state = tree.current_dirstate() state._read_dirblocks_if_needed() - self.assertEqual([ - (b'', [((b'', b'', root_id), [b'd', b'd'])]), - (b'', [((b'', b'dir', b'dir-id'), [b'd', b'd'])]), - (b'dir', [((b'dir', b'file', b'file-id'), [b'f', b'f'])]), - ], self.get_simple_dirblocks(state)) - - tree.remove(['dir/file']) - self.assertEqual([ - (b'', [((b'', b'', root_id), [b'd', b'd'])]), - (b'', [((b'', b'dir', b'dir-id'), [b'd', b'd'])]), - (b'dir', [((b'dir', b'file', b'file-id'), [b'a', b'f'])]), - ], self.get_simple_dirblocks(state)) + self.assertEqual( + [ + (b"", [((b"", b"", root_id), [b"d", b"d"])]), + (b"", [((b"", b"dir", b"dir-id"), [b"d", b"d"])]), + (b"dir", [((b"dir", b"file", b"file-id"), [b"f", b"f"])]), + ], + self.get_simple_dirblocks(state), + ) + + tree.remove(["dir/file"]) + self.assertEqual( + [ + (b"", [((b"", b"", root_id), [b"d", b"d"])]), + (b"", [((b"", b"dir", b"dir-id"), [b"d", b"d"])]), + (b"dir", [((b"dir", b"file", b"file-id"), [b"a", b"f"])]), + ], + self.get_simple_dirblocks(state), + ) # Make sure the removal is written to disk tree.flush() # self.assertRaises(Exception, tree.update_basis_by_delta, - new_dir = inventory.InventoryDirectory(b'dir-id', 'new-dir', root_id) - new_dir.revision = b'new-revision-id' - new_file = inventory.InventoryFile(b'file-id', 'new-file', root_id) - new_file.revision = b'new-revision-id' + new_dir = inventory.InventoryDirectory(b"dir-id", "new-dir", root_id) + new_dir.revision = b"new-revision-id" + new_file = inventory.InventoryFile(b"file-id", "new-file", root_id) + new_file.revision = b"new-revision-id" self.assertRaises( errors.InconsistentDelta, - tree.update_basis_by_delta, b'new-revision-id', - InventoryDelta([('dir', 'new-dir', b'dir-id', new_dir), - ('dir/file', 'new-dir/new-file', b'file-id', new_file), - ])) + tree.update_basis_by_delta, + b"new-revision-id", + InventoryDelta( + [ + ("dir", "new-dir", b"dir-id", new_dir), + ("dir/file", "new-dir/new-file", b"file-id", new_file), + ] + ), + ) del state # Now when we re-read the file it should not have been modified @@ -854,28 +917,30 @@ def test_update_basis_with_invalid_delta(self): self.assertEqual(first_revision_id, tree.last_revision()) state = tree.current_dirstate() state._read_dirblocks_if_needed() - self.assertEqual([ - (b'', [((b'', b'', root_id), [b'd', b'd'])]), - (b'', [((b'', b'dir', b'dir-id'), [b'd', b'd'])]), - (b'dir', [((b'dir', b'file', b'file-id'), [b'a', b'f'])]), - ], self.get_simple_dirblocks(state)) + self.assertEqual( + [ + (b"", [((b"", b"", root_id), [b"d", b"d"])]), + (b"", [((b"", b"dir", b"dir-id"), [b"d", b"d"])]), + (b"dir", [((b"dir", b"file", b"file-id"), [b"a", b"f"])]), + ], + self.get_simple_dirblocks(state), + ) class TestInventoryCoherency(TestCaseWithTransport): - def test_inventory_is_synced_when_unversioning_a_dir(self): """Unversioning the root of a subtree unversions the entire subtree.""" - tree = self.make_branch_and_tree('.') - 
self.build_tree(['a/', 'a/b', 'c/']) - tree.add(['a', 'a/b', 'c'], ids=[b'a-id', b'b-id', b'c-id']) + tree = self.make_branch_and_tree(".") + self.build_tree(["a/", "a/b", "c/"]) + tree.add(["a", "a/b", "c"], ids=[b"a-id", b"b-id", b"c-id"]) # within a lock unversion should take effect tree.lock_write() self.addCleanup(tree.unlock) # Force access to the in memory inventory to trigger bug #494221: try # maintaining the in-memory inventory inv = tree.root_inventory - self.assertTrue(inv.has_id(b'a-id')) - self.assertTrue(inv.has_id(b'b-id')) - tree.unversion(['a', 'a/b']) - self.assertFalse(inv.has_id(b'a-id')) - self.assertFalse(inv.has_id(b'b-id')) + self.assertTrue(inv.has_id(b"a-id")) + self.assertTrue(inv.has_id(b"b-id")) + tree.unversion(["a", "a/b"]) + self.assertFalse(inv.has_id(b"a-id")) + self.assertFalse(inv.has_id(b"b-id")) diff --git a/breezy/bzr/tests/test_xml.py b/breezy/bzr/tests/test_xml.py index 0d9092a2bb..2961ea6221 100644 --- a/breezy/bzr/tests/test_xml.py +++ b/breezy/bzr/tests/test_xml.py @@ -175,7 +175,6 @@ """ - _inventory_utf8_v5 = b""" ") + eq(rev.committer, "Martin Pool ") eq(len(rev.parent_ids), 1) eq(rev.timezone, 36000) - eq(rev.parent_ids[0], - b"mbp@sourcefrog.net-20050905063503-43948f59fa127d92") + eq(rev.parent_ids[0], b"mbp@sourcefrog.net-20050905063503-43948f59fa127d92") def test_unpack_revision_5_utc(self): inp = BytesIO(_revision_v5_utc) rev = breezy.bzr.xml5.revision_serializer_v5.read_revision(inp) eq = self.assertEqual - eq(rev.committer, - "Martin Pool ") + eq(rev.committer, "Martin Pool ") eq(len(rev.parent_ids), 1) eq(rev.timezone, 0) - eq(rev.parent_ids[0], - b"mbp@sourcefrog.net-20050905063503-43948f59fa127d92") + eq(rev.parent_ids[0], b"mbp@sourcefrog.net-20050905063503-43948f59fa127d92") def test_unpack_inventory_5(self): """Unpack canned new-style inventory.""" @@ -234,44 +229,48 @@ def test_unpack_inventory_5(self): inv = breezy.bzr.xml5.inventory_serializer_v5.read_inventory(inp) eq = self.assertEqual eq(len(inv), 4) - ie = inv.get_entry(b'bar-20050824000535-6bc48cfad47ed134') - eq(ie.kind, 'file') - eq(ie.revision, b'mbp@foo-00') - eq(ie.name, 'bar') - eq(inv.get_entry(ie.parent_id).kind, 'directory') + ie = inv.get_entry(b"bar-20050824000535-6bc48cfad47ed134") + eq(ie.kind, "file") + eq(ie.revision, b"mbp@foo-00") + eq(ie.name, "bar") + eq(inv.get_entry(ie.parent_id).kind, "directory") def test_unpack_basis_inventory_5(self): """Unpack canned new-style inventory.""" inv = breezy.bzr.xml5.inventory_serializer_v5.read_inventory_from_lines( - breezy.osutils.split_lines(_basis_inv_v5)) + breezy.osutils.split_lines(_basis_inv_v5) + ) eq = self.assertEqual eq(len(inv), 4) - eq(inv.revision_id, - b'mbp@sourcefrog.net-20050905063503-43948f59fa127d92') - ie = inv.get_entry(b'bar-20050824000535-6bc48cfad47ed134') - eq(ie.kind, 'file') - eq(ie.revision, b'mbp@foo-00') - eq(ie.name, 'bar') - eq(inv.get_entry(ie.parent_id).kind, 'directory') + eq(inv.revision_id, b"mbp@sourcefrog.net-20050905063503-43948f59fa127d92") + ie = inv.get_entry(b"bar-20050824000535-6bc48cfad47ed134") + eq(ie.kind, "file") + eq(ie.revision, b"mbp@foo-00") + eq(ie.name, "bar") + eq(inv.get_entry(ie.parent_id).kind, "directory") def test_unpack_inventory_5a(self): inv = breezy.bzr.xml5.inventory_serializer_v5.read_inventory_from_lines( - breezy.osutils.split_lines(_inventory_v5a), revision_id=b'test-rev-id') - self.assertEqual(b'test-rev-id', inv.root.revision) + breezy.osutils.split_lines(_inventory_v5a), revision_id=b"test-rev-id" + ) + self.assertEqual(b"test-rev-id", 
inv.root.revision) def test_unpack_inventory_5b(self): inv = breezy.bzr.xml5.inventory_serializer_v5.read_inventory_from_lines( - breezy.osutils.split_lines(_inventory_v5b), revision_id=b'test-rev-id') - self.assertEqual(b'a-rev-id', inv.root.revision) + breezy.osutils.split_lines(_inventory_v5b), revision_id=b"test-rev-id" + ) + self.assertEqual(b"a-rev-id", inv.root.revision) def test_repack_inventory_5(self): inv = breezy.bzr.xml5.inventory_serializer_v5.read_inventory_from_lines( - breezy.osutils.split_lines(_committed_inv_v5)) + breezy.osutils.split_lines(_committed_inv_v5) + ) outp = BytesIO() breezy.bzr.xml5.inventory_serializer_v5.write_inventory(inv, outp) self.assertEqualDiff(_expected_inv_v5, outp.getvalue()) inv2 = breezy.bzr.xml5.inventory_serializer_v5.read_inventory_from_lines( - breezy.osutils.split_lines(outp.getvalue())) + breezy.osutils.split_lines(outp.getvalue()) + ) self.assertEqual(inv, inv2) def assertRoundTrips(self, xml_string): @@ -284,7 +283,8 @@ def assertRoundTrips(self, xml_string): outp.seek(0) self.assertEqual(outp.readlines(), lines) inv2 = breezy.bzr.xml5.inventory_serializer_v5.read_inventory( - BytesIO(outp.getvalue())) + BytesIO(outp.getvalue()) + ) self.assertEqual(inv, inv2) def tests_serialize_inventory_v5_with_root(self): @@ -294,9 +294,12 @@ def check_repack_revision(self, txt): """Check that repacking a revision yields the same information.""" inp = BytesIO(txt) rev = breezy.bzr.xml5.revision_serializer_v5.read_revision(inp) - outfile_contents = breezy.bzr.xml5.revision_serializer_v5.write_revision_to_string(rev) + outfile_contents = ( + breezy.bzr.xml5.revision_serializer_v5.write_revision_to_string(rev) + ) rev2 = breezy.bzr.xml5.revision_serializer_v5.read_revision( - BytesIO(outfile_contents)) + BytesIO(outfile_contents) + ) self.assertEqual(rev, rev2) def test_repack_revision_5(self): @@ -310,19 +313,25 @@ def test_pack_revision_5(self): """Pack revision to XML v5.""" # fixed 20051025, revisions should have final newline rev = breezy.bzr.xml5.revision_serializer_v5.read_revision_from_string( - _revision_v5) - outfile_contents = breezy.bzr.xml5.revision_serializer_v5.write_revision_to_string(rev) - self.assertEqual(outfile_contents[-1:], b'\n') + _revision_v5 + ) + outfile_contents = ( + breezy.bzr.xml5.revision_serializer_v5.write_revision_to_string(rev) + ) + self.assertEqual(outfile_contents[-1:], b"\n") self.assertEqualDiff( outfile_contents, - b''.join(breezy.bzr.xml5.revision_serializer_v5.write_revision_to_lines(rev))) + b"".join( + breezy.bzr.xml5.revision_serializer_v5.write_revision_to_lines(rev) + ), + ) self.assertEqualDiff(outfile_contents, _expected_rev_v5) def test_empty_property_value(self): """Create an empty property value check that it serializes correctly.""" s_v5 = breezy.bzr.xml5.revision_serializer_v5 rev = s_v5.read_revision_from_string(_revision_v5) - props = {'empty': '', 'one': 'one'} + props = {"empty": "", "one": "one"} rev = Revision( revision_id=rev.revision_id, timestamp=rev.timestamp, @@ -331,32 +340,35 @@ def test_empty_property_value(self): message=rev.message, parent_ids=rev.parent_ids, inventory_sha1=rev.inventory_sha1, - properties=props) - txt = b''.join(s_v5.write_revision_to_lines(rev)) + properties=props, + ) + txt = b"".join(s_v5.write_revision_to_lines(rev)) new_rev = s_v5.read_revision_from_string(txt) self.assertEqual(props, new_rev.properties) def get_sample_inventory(self): - inv = Inventory(b'tree-root-321', revision_id=b'rev_outer') - inv.add(inventory.InventoryFile(b'file-id', 'file', 
b'tree-root-321')) - inv.add(inventory.InventoryDirectory(b'dir-id', 'dir', - b'tree-root-321')) - inv.add(inventory.InventoryLink(b'link-id', 'link', b'tree-root-321')) - inv.get_entry(b'tree-root-321').revision = b'rev_outer' - inv.get_entry(b'dir-id').revision = b'rev_outer' - inv.get_entry(b'file-id').revision = b'rev_outer' - inv.get_entry(b'file-id').text_sha1 = b'A' - inv.get_entry(b'file-id').text_size = 1 - inv.get_entry(b'link-id').revision = b'rev_outer' - inv.get_entry(b'link-id').symlink_target = 'a' + inv = Inventory(b"tree-root-321", revision_id=b"rev_outer") + inv.add(inventory.InventoryFile(b"file-id", "file", b"tree-root-321")) + inv.add(inventory.InventoryDirectory(b"dir-id", "dir", b"tree-root-321")) + inv.add(inventory.InventoryLink(b"link-id", "link", b"tree-root-321")) + inv.get_entry(b"tree-root-321").revision = b"rev_outer" + inv.get_entry(b"dir-id").revision = b"rev_outer" + inv.get_entry(b"file-id").revision = b"rev_outer" + inv.get_entry(b"file-id").text_sha1 = b"A" + inv.get_entry(b"file-id").text_size = 1 + inv.get_entry(b"link-id").revision = b"rev_outer" + inv.get_entry(b"link-id").symlink_target = "a" return inv def test_roundtrip_inventory_v7(self): inv = self.get_sample_inventory() - inv.add(inventory.TreeReference(b'nested-id', 'nested', b'tree-root-321', - b'rev_outer', b'rev_inner')) + inv.add( + inventory.TreeReference( + b"nested-id", "nested", b"tree-root-321", b"rev_outer", b"rev_inner" + ) + ) lines = xml7.inventory_serializer_v7.write_inventory_to_lines(inv) - self.assertEqualDiff(_expected_inv_v7, b''.join(lines)) + self.assertEqualDiff(_expected_inv_v7, b"".join(lines)) inv2 = xml7.inventory_serializer_v7.read_inventory_from_lines(lines) self.assertEqual(5, len(inv2)) for _path, ie in inv.iter_entries(): @@ -365,7 +377,7 @@ def test_roundtrip_inventory_v7(self): def test_roundtrip_inventory_v6(self): inv = self.get_sample_inventory() lines = xml6.inventory_serializer_v6.write_inventory_to_lines(inv) - self.assertEqualDiff(_expected_inv_v6, b''.join(lines)) + self.assertEqualDiff(_expected_inv_v6, b"".join(lines)) inv2 = xml6.inventory_serializer_v6.read_inventory_from_lines(lines) self.assertEqual(4, len(inv2)) for _path, ie in inv.iter_entries(): @@ -375,32 +387,39 @@ def test_wrong_format_v7(self): """Can't accidentally open a file with wrong serializer.""" s_v6 = breezy.bzr.xml6.inventory_serializer_v6 s_v7 = xml7.inventory_serializer_v7 - self.assertRaises(serializer.UnexpectedInventoryFormat, - s_v7.read_inventory_from_lines, - breezy.osutils.split_lines(_expected_inv_v5)) - self.assertRaises(serializer.UnexpectedInventoryFormat, - s_v6.read_inventory_from_lines, - breezy.osutils.split_lines(_expected_inv_v7)) + self.assertRaises( + serializer.UnexpectedInventoryFormat, + s_v7.read_inventory_from_lines, + breezy.osutils.split_lines(_expected_inv_v5), + ) + self.assertRaises( + serializer.UnexpectedInventoryFormat, + s_v6.read_inventory_from_lines, + breezy.osutils.split_lines(_expected_inv_v7), + ) def test_tree_reference(self): s_v5 = breezy.bzr.xml5.inventory_serializer_v5 s_v6 = breezy.bzr.xml6.inventory_serializer_v6 s_v7 = xml7.inventory_serializer_v7 - inv = Inventory(b'tree-root-321', revision_id=b'rev-outer') - inv.root.revision = b'root-rev' - inv.add(inventory.TreeReference(b'nested-id', 'nested', b'tree-root-321', - b'rev-outer', b'rev-inner')) - self.assertRaises(serializer.UnsupportedInventoryKind, - s_v5.write_inventory_to_lines, inv) - self.assertRaises(serializer.UnsupportedInventoryKind, - s_v6.write_inventory_to_lines, 
inv) + inv = Inventory(b"tree-root-321", revision_id=b"rev-outer") + inv.root.revision = b"root-rev" + inv.add( + inventory.TreeReference( + b"nested-id", "nested", b"tree-root-321", b"rev-outer", b"rev-inner" + ) + ) + self.assertRaises( + serializer.UnsupportedInventoryKind, s_v5.write_inventory_to_lines, inv + ) + self.assertRaises( + serializer.UnsupportedInventoryKind, s_v6.write_inventory_to_lines, inv + ) lines = s_v7.write_inventory_to_chunks(inv) inv2 = s_v7.read_inventory_from_lines(lines) - self.assertEqual(b'tree-root-321', - inv2.get_entry(b'nested-id').parent_id) - self.assertEqual(b'rev-outer', inv2.get_entry(b'nested-id').revision) - self.assertEqual( - b'rev-inner', inv2.get_entry(b'nested-id').reference_revision) + self.assertEqual(b"tree-root-321", inv2.get_entry(b"nested-id").parent_id) + self.assertEqual(b"rev-outer", inv2.get_entry(b"nested-id").revision) + self.assertEqual(b"rev-inner", inv2.get_entry(b"nested-id").reference_revision) def test_roundtrip_inventory_v8(self): inv = self.get_sample_inventory() @@ -413,64 +432,69 @@ def test_roundtrip_inventory_v8(self): def test_inventory_text_v8(self): inv = self.get_sample_inventory() lines = xml8.inventory_serializer_v8.write_inventory_to_lines(inv) - self.assertEqualDiff(_expected_inv_v8, b''.join(lines)) + self.assertEqualDiff(_expected_inv_v8, b"".join(lines)) def test_revision_text_v5(self): """Pack revision to XML v7.""" rev = breezy.bzr.xml5.revision_serializer_v5.read_revision_from_string( - _expected_rev_v5) - serialized = breezy.bzr.xml5.revision_serializer_v5.write_revision_to_lines( - rev) - self.assertEqualDiff(b''.join(serialized), _expected_rev_v5) + _expected_rev_v5 + ) + serialized = breezy.bzr.xml5.revision_serializer_v5.write_revision_to_lines(rev) + self.assertEqualDiff(b"".join(serialized), _expected_rev_v5) def test_revision_text_v8(self): """Pack revision to XML v8.""" rev = breezy.bzr.xml8.revision_serializer_v8.read_revision_from_string( - _expected_rev_v8) - serialized = breezy.bzr.xml8.revision_serializer_v8.write_revision_to_lines( - rev) - self.assertEqualDiff(b''.join(serialized), _expected_rev_v8) + _expected_rev_v8 + ) + serialized = breezy.bzr.xml8.revision_serializer_v8.write_revision_to_lines(rev) + self.assertEqualDiff(b"".join(serialized), _expected_rev_v8) def test_revision_text_v8_complex(self): """Pack revision to XML v8.""" rev = breezy.bzr.xml8.revision_serializer_v8.read_revision_from_string( - _expected_rev_v8_complex) - serialized = breezy.bzr.xml8.revision_serializer_v8.write_revision_to_lines( - rev) - self.assertEqualDiff(b''.join(serialized), _expected_rev_v8_complex) + _expected_rev_v8_complex + ) + serialized = breezy.bzr.xml8.revision_serializer_v8.write_revision_to_lines(rev) + self.assertEqualDiff(b"".join(serialized), _expected_rev_v8_complex) def test_revision_ids_are_utf8(self): """Parsed revision_ids should all be utf-8 strings, not unicode.""" sr_v5 = breezy.bzr.xml5.revision_serializer_v5 si_v5 = breezy.bzr.xml5.inventory_serializer_v5 rev = sr_v5.read_revision_from_string(_revision_utf8_v5) - self.assertEqual(b'erik@b\xc3\xa5gfors-02', rev.revision_id) + self.assertEqual(b"erik@b\xc3\xa5gfors-02", rev.revision_id) self.assertIsInstance(rev.revision_id, bytes) - self.assertEqual([b'erik@b\xc3\xa5gfors-01'], rev.parent_ids) + self.assertEqual([b"erik@b\xc3\xa5gfors-01"], rev.parent_ids) for parent_id in rev.parent_ids: self.assertIsInstance(parent_id, bytes) - self.assertEqual('Include \xb5nicode characters\n', rev.message) + self.assertEqual("Include 
\xb5nicode characters\n", rev.message) self.assertIsInstance(rev.message, str) # ie.revision should either be None or a utf-8 revision id - inv = si_v5.read_inventory_from_lines(breezy.osutils.split_lines(_inventory_utf8_v5)) - rev_id_1 = 'erik@b\xe5gfors-01'.encode() - rev_id_2 = 'erik@b\xe5gfors-02'.encode() - fid_root = 'TRE\xe9_ROOT'.encode() - fid_bar1 = 'b\xe5r-01'.encode() - fid_sub = 's\xb5bdir-01'.encode() - fid_bar2 = 'b\xe5r-02'.encode() - expected = [('', fid_root, None, rev_id_2), - ('b\xe5r', fid_bar1, fid_root, rev_id_1), - ('s\xb5bdir', fid_sub, fid_root, rev_id_1), - ('s\xb5bdir/b\xe5r', fid_bar2, fid_sub, rev_id_2), - ] + inv = si_v5.read_inventory_from_lines( + breezy.osutils.split_lines(_inventory_utf8_v5) + ) + rev_id_1 = "erik@b\xe5gfors-01".encode() + rev_id_2 = "erik@b\xe5gfors-02".encode() + fid_root = "TRE\xe9_ROOT".encode() + fid_bar1 = "b\xe5r-01".encode() + fid_sub = "s\xb5bdir-01".encode() + fid_bar2 = "b\xe5r-02".encode() + expected = [ + ("", fid_root, None, rev_id_2), + ("b\xe5r", fid_bar1, fid_root, rev_id_1), + ("s\xb5bdir", fid_sub, fid_root, rev_id_1), + ("s\xb5bdir/b\xe5r", fid_bar2, fid_sub, rev_id_2), + ] self.assertEqual(rev_id_2, inv.revision_id) self.assertIsInstance(inv.revision_id, bytes) actual = list(inv.iter_entries_by_dir()) - for ((exp_path, exp_file_id, exp_parent_id, exp_rev_id), - (act_path, act_ie)) in zip(expected, actual): + for (exp_path, exp_file_id, exp_parent_id, exp_rev_id), ( + act_path, + act_ie, + ) in zip(expected, actual): self.assertEqual(exp_path, act_path) self.assertIsInstance(act_path, str) self.assertEqual(exp_file_id, act_ie.file_id) @@ -488,7 +512,9 @@ def test_serialization_error(self): s_v5 = breezy.bzr.xml5.inventory_serializer_v5 e = self.assertRaises( serializer.UnexpectedInventoryFormat, - s_v5.read_inventory_from_lines, [b"')) + self.assertEqual( + b"&'"<>", + breezy.bzr.xml_serializer.encode_and_escape("&'\"<>"), + ) def test_utf8_with_xml(self): # u'\xb5\xe5&\u062c' - utf8_str = b'\xc2\xb5\xc3\xa5&\xd8\xac' - self.assertEqual(b'µå&ج', - breezy.bzr.xml_serializer.encode_and_escape(utf8_str)) + utf8_str = b"\xc2\xb5\xc3\xa5&\xd8\xac" + self.assertEqual( + b"µå&ج", + breezy.bzr.xml_serializer.encode_and_escape(utf8_str), + ) def test_unicode(self): - uni_str = '\xb5\xe5&\u062c' - self.assertEqual(b'µå&ج', - breezy.bzr.xml_serializer.encode_and_escape(uni_str)) + uni_str = "\xb5\xe5&\u062c" + self.assertEqual( + b"µå&ج", + breezy.bzr.xml_serializer.encode_and_escape(uni_str), + ) class TestMisc(TestCase): - def test_unescape_xml(self): """We get some kind of error when malformed entities are passed.""" - self.assertRaises(KeyError, breezy.bzr.xml8._unescape_xml, b'foo&bar;') + self.assertRaises(KeyError, breezy.bzr.xml8._unescape_xml, b"foo&bar;") diff --git a/breezy/bzr/textinv.py b/breezy/bzr/textinv.py index 00cb34e856..fe33b4aa02 100644 --- a/breezy/bzr/textinv.py +++ b/breezy/bzr/textinv.py @@ -27,19 +27,23 @@ def escape(s): (Why not just use backslashes? Because then we couldn't parse lines just by splitting on spaces.) 
""" - return (s.replace('\\', r'\x5c') - .replace(' ', r'\x20') - .replace('\t', r'\x09') - .replace('\n', r'\x0a')) + return ( + s.replace("\\", r"\x5c") + .replace(" ", r"\x20") + .replace("\t", r"\x09") + .replace("\n", r"\x0a") + ) def unescape(s): - if s.find(' ') != -1: + if s.find(" ") != -1: raise AssertionError() - s = (s.replace(r'\x20', ' ') - .replace(r'\x09', '\t') - .replace(r'\x0a', '\n') - .replace(r'\x5c', '\\')) + s = ( + s.replace(r"\x20", " ") + .replace(r"\x09", "\t") + .replace(r"\x0a", "\n") + .replace(r"\x5c", "\\") + ) # TODO: What if there's anything else? @@ -53,15 +57,15 @@ def write_text_inventory(inv, outf): if inv.is_root(ie.file_id): continue - outf.write(ie.file_id + ' ') - outf.write(escape(ie.name) + ' ') - outf.write(ie.kind + ' ') - outf.write(ie.parent_id + ' ') + outf.write(ie.file_id + " ") + outf.write(escape(ie.name) + " ") + outf.write(ie.kind + " ") + outf.write(ie.parent_id + " ") - if ie.kind == 'file': + if ie.kind == "file": outf.write(ie.text_id) - outf.write(' ' + ie.text_sha1) - outf.write(' ' + str(ie.text_size)) + outf.write(" " + ie.text_sha1) + outf.write(" " + str(ie.text_size)) outf.write("\n") outf.write(END_MARK) @@ -74,13 +78,15 @@ def read_text_inventory(tf): inv = Inventory() for l in tf: - fields = l.split(' ') - if fields[0] == '#': + fields = l.split(" ") + if fields[0] == "#": break - {'file_id': fields[0], - 'name': unescape(fields[1]), - 'kind': fields[2], - 'parent_id': fields[3]} + { + "file_id": fields[0], + "name": unescape(fields[1]), + "kind": fields[2], + "parent_id": fields[3], + } # inv.add(ie) if l != END_MARK: diff --git a/breezy/bzr/transform.py b/breezy/bzr/transform.py index 8be49ffc7d..a181d87759 100644 --- a/breezy/bzr/transform.py +++ b/breezy/bzr/transform.py @@ -70,8 +70,7 @@ def _content_match(tree, entry, tree_path, kind, target_path): if entry.kind == "directory": return True if entry.kind == "file": - with open(target_path, 'rb') as f1, \ - tree.get_file(tree_path) as f2: + with open(target_path, "rb") as f1, tree.get_file(tree_path) as f2: if osutils.compare_files(f1, f2): return True elif entry.kind == "symlink": @@ -102,8 +101,8 @@ def __init__(self, tree, pb=None, case_sensitive: bool = True) -> None: # Mapping of new file_id -> trans_id self._r_new_id: Dict[bytes, str] = {} # The trans_id that will be used as the tree root - if tree.is_versioned(''): - self._new_root = self.trans_id_tree_path('') + if tree.is_versioned(""): + self._new_root = self.trans_id_tree_path("") else: self._new_root = None # Whether the target is case sensitive @@ -117,7 +116,7 @@ def finalize(self): """ if self._tree is None: return - for hook in MutableTree.hooks['post_transform']: + for hook in MutableTree.hooks["post_transform"]: hook(self._tree, self) self._tree.unlock() self._tree = None @@ -148,8 +147,9 @@ def adjust_root_path(self, name, parent): # force moving all children of root for child_id in self.iter_tree_children(old_root): if child_id != parent: - self.adjust_path(self.final_name(child_id), - self.final_parent(child_id), child_id) + self.adjust_path( + self.final_name(child_id), self.final_parent(child_id), child_id + ) file_id = self.final_file_id(child_id) if file_id is not None: self.unversion_file(child_id) @@ -158,7 +158,7 @@ def adjust_root_path(self, name, parent): # the physical root needs a new transaction id self._tree_path_ids.pop("") self._tree_id_paths.pop(old_root) - self._new_root = self.trans_id_tree_path('') + self._new_root = self.trans_id_tree_path("") if parent == old_root: parent = 
self._new_root self.adjust_path(name, parent, old_root) @@ -178,12 +178,11 @@ def fixup_new_roots(self): irrelevant. """ - new_roots = [k for k, v in self._new_parent.items() - if v == ROOT_PARENT] + new_roots = [k for k, v in self._new_parent.items() if v == ROOT_PARENT] if len(new_roots) < 1: return if len(new_roots) != 1: - raise ValueError('A tree cannot have two roots!') + raise ValueError("A tree cannot have two roots!") if self._new_root is None: self._new_root = new_roots[0] return @@ -199,8 +198,10 @@ def fixup_new_roots(self): self.unversion_file(old_new_root) # if, at this stage, root still has an old file_id, zap it so we can # stick a new one in. - if (self.tree_file_id(self._new_root) is not None - and self._new_root not in self._removed_id): + if ( + self.tree_file_id(self._new_root) is not None + and self._new_root not in self._removed_id + ): self.unversion_file(self._new_root) if file_id is not None: self.version_file(self._new_root, file_id=file_id) @@ -234,7 +235,7 @@ def trans_id_file_id(self, file_id): (this will likely lead to an unversioned parent conflict.). """ if file_id is None: - raise ValueError('None is not a valid file id') + raise ValueError("None is not a valid file id") if file_id in self._r_new_id and self._r_new_id[file_id] is not None: return self._r_new_id[file_id] else: @@ -273,8 +274,13 @@ def new_paths(self, filesystem_only=False): needs_rename = self._needs_rename.difference(stale_ids) id_sets = (needs_rename, self._new_executability) else: - id_sets = (self._new_name, self._new_parent, self._new_contents, - self._new_id, self._new_executability) + id_sets = ( + self._new_name, + self._new_parent, + self._new_contents, + self._new_id, + self._new_executability, + ) for id_set in id_sets: new_ids.update(id_set) return sorted(FinalPaths(self).get_paths(new_ids)) @@ -286,7 +292,7 @@ def tree_file_id(self, trans_id): return None # the file is old; the old id is still valid if self._new_root == trans_id: - return self._tree.path2id('') + return self._tree.path2id("") return self._tree.path2id(path) def final_is_versioned(self, trans_id): @@ -347,14 +353,15 @@ def _add_tree_children(self): removed. This is a necessary first step in detecting conflicts. """ parents = list(self.by_parent()) - parents.extend([t for t in self._removed_contents if - self.tree_kind(t) == 'directory']) + parents.extend( + [t for t in self._removed_contents if self.tree_kind(t) == "directory"] + ) for trans_id in self._removed_id: path = self.tree_path(trans_id) if path is not None: - if self._tree.stored_kind(path) == 'directory': + if self._tree.stored_kind(path) == "directory": parents.append(trans_id) - elif self.tree_kind(trans_id) == 'directory': + elif self.tree_kind(trans_id) == "directory": parents.append(trans_id) for parent_id in parents: @@ -387,7 +394,9 @@ def _has_named_child(self, name, parent_id, known_children): # Not known by the tree transform yet, check the filesystem return osutils.lexists(self._tree.abspath(child_path)) else: - raise AssertionError(f'child_id is missing: {name}, {parent_id}, {child_id}') + raise AssertionError( + f"child_id is missing: {name}, {parent_id}, {child_id}" + ) def _available_backup_name(self, name, target_id): """Find an available backup name. 
@@ -399,9 +408,8 @@ def _available_backup_name(self, name, target_id): """ known_children = self.by_parent().get(target_id, []) return osutils.available_backup_name( - name, - lambda base: self._has_named_child( - base, target_id, known_children)) + name, lambda base: self._has_named_child(base, target_id, known_children) + ) def _parent_loops(self): """No entry should be its own ancestor.""" @@ -415,7 +423,7 @@ def _parent_loops(self): except KeyError: break if parent_id == trans_id: - yield ('parent loop', trans_id) + yield ("parent loop", trans_id) if parent_id in seen: break @@ -428,7 +436,7 @@ def _unversioned_parents(self, by_parent): continue for child_id in children: if self.final_is_versioned(child_id): - yield ('unversioned parent', parent_id) + yield ("unversioned parent", parent_id) break def _improper_versioning(self): @@ -438,14 +446,14 @@ def _improper_versioning(self): """ for trans_id in self._new_id: kind = self.final_kind(trans_id) - if kind == 'symlink' and not self._tree.supports_symlinks(): + if kind == "symlink" and not self._tree.supports_symlinks(): # Ignore symlinks as they are not supported on this platform continue if kind is None: - yield ('versioning no contents', trans_id) + yield ("versioning no contents", trans_id) continue if not self._tree.versionable_kind(kind): - yield ('versioning bad kind', trans_id, kind) + yield ("versioning bad kind", trans_id, kind) def _executability_conflicts(self): """Check for bad executability changes. @@ -457,10 +465,10 @@ def _executability_conflicts(self): """ for trans_id in self._new_executability: if not self.final_is_versioned(trans_id): - yield ('unversioned executability', trans_id) + yield ("unversioned executability", trans_id) else: if self.final_kind(trans_id) != "file": - yield ('non-file executability', trans_id) + yield ("non-file executability", trans_id) def _overwrite_conflicts(self): """Check for overwrites (not permitted on Win32).""" @@ -468,7 +476,7 @@ def _overwrite_conflicts(self): if self.tree_kind(trans_id) is None: continue if trans_id not in self._removed_contents: - yield ('overwrite', trans_id, self.final_name(trans_id)) + yield ("overwrite", trans_id, self.final_name(trans_id)) def _duplicate_entries(self, by_parent): """No directory may have two entries with the same name.""" @@ -491,7 +499,7 @@ def _duplicate_entries(self, by_parent): if kind is None and not self.final_is_versioned(trans_id): continue if name == last_name: - yield ('duplicate', last_trans_id, trans_id, name) + yield ("duplicate", last_trans_id, trans_id, name) last_name = name last_trans_id = trans_id @@ -512,10 +520,10 @@ def _parent_type_conflicts(self, by_parent): kind = self.final_kind(parent_id) if kind is None: # The directory will be deleted - yield ('missing parent', parent_id) + yield ("missing parent", parent_id) elif kind != "directory": # Meh, we need a *directory* to put something in it - yield ('non-directory parent', parent_id) + yield ("non-directory parent", parent_id) def _set_executability(self, path, trans_id): """Set the executability of versioned files.""" @@ -543,8 +551,9 @@ def _new_entry(self, name, parent_id, file_id): self.version_file(trans_id, file_id=file_id) return trans_id - def new_file(self, name, parent_id, contents, file_id=None, - executable=None, sha1=None): + def new_file( + self, name, parent_id, contents, file_id=None, executable=None, sha1=None + ): """Convenience method to create files. name is the name of the file to create. 
@@ -663,8 +672,9 @@ def _from_file_data(self, from_trans_id, from_versioned, from_path): from_path = self._tree_id_paths.get(from_trans_id) if from_versioned: # get data from working tree if versioned - from_entry = next(self._tree.iter_entries_by_dir( - specific_files=[from_path]))[1] + from_entry = next( + self._tree.iter_entries_by_dir(specific_files=[from_path]) + )[1] from_name = from_entry.name from_parent = from_entry.parent_id else: @@ -680,8 +690,9 @@ def _from_file_data(self, from_trans_id, from_versioned, from_path): tree_parent = self.get_tree_parent(from_trans_id) from_parent = self.tree_file_id(tree_parent) if from_path is not None: - from_kind, from_executable, from_stats = \ - self._tree._comparison_data(from_entry, from_path) + from_kind, from_executable, from_stats = self._tree._comparison_data( + from_entry, from_path + ) else: from_kind = None from_executable = False @@ -742,33 +753,44 @@ def iter_changes(self): else: to_path = final_paths.get_path(to_trans_id) - from_name, from_parent, from_kind, from_executable = \ - self._from_file_data(from_trans_id, from_versioned, from_path) + from_name, from_parent, from_kind, from_executable = self._from_file_data( + from_trans_id, from_versioned, from_path + ) - to_name, to_parent, to_kind, to_executable = \ - self._to_file_data(to_trans_id, from_trans_id, from_executable) + to_name, to_parent, to_kind, to_executable = self._to_file_data( + to_trans_id, from_trans_id, from_executable + ) if from_kind != to_kind: modified = True - elif to_kind in ('file', 'symlink') and ( - to_trans_id != from_trans_id - or to_trans_id in self._new_contents): + elif to_kind in ("file", "symlink") and ( + to_trans_id != from_trans_id or to_trans_id in self._new_contents + ): modified = True - if (not modified and from_versioned == to_versioned - and from_parent == to_parent and from_name == to_name - and from_executable == to_executable): + if ( + not modified + and from_versioned == to_versioned + and from_parent == to_parent + and from_name == to_name + and from_executable == to_executable + ): continue results.append( inventorytree.InventoryTreeChange( - file_id, (from_path, to_path), modified, + file_id, + (from_path, to_path), + modified, (from_versioned, to_versioned), (from_parent, to_parent), (from_name, to_name), (from_kind, to_kind), - (from_executable, to_executable))) + (from_executable, to_executable), + ) + ) def path_key(c): - return (c.path[0] or '', c.path[1] or '') + return (c.path[0] or "", c.path[1] or "") + return iter(sorted(results, key=path_key)) def get_preview_tree(self): @@ -779,9 +801,19 @@ def get_preview_tree(self): """ raise NotImplementedError(self.get_preview) - def commit(self, branch, message, merge_parents=None, strict=False, - timestamp=None, timezone=None, committer=None, authors=None, - revprops=None, revision_id=None): + def commit( + self, + branch, + message, + merge_parents=None, + strict=False, + timestamp=None, + timezone=None, + committer=None, + authors=None, + revprops=None, + revision_id=None, + ): """Commit the result of this TreeTransform to a branch. :param branch: The branch to commit to. 
@@ -812,27 +844,30 @@ def commit(self, branch, message, merge_parents=None, strict=False, revno, last_rev_id = branch.last_revision_info() if last_rev_id == _mod_revision.NULL_REVISION: if merge_parents is not None: - raise ValueError('Cannot supply merge parents for first' - ' commit.') + raise ValueError("Cannot supply merge parents for first" " commit.") parent_ids = [] else: parent_ids = [last_rev_id] if merge_parents is not None: parent_ids.extend(merge_parents) if self._tree.get_revision_id() != last_rev_id: - raise ValueError('TreeTransform not based on branch basis: %s' % - self._tree.get_revision_id().decode('utf-8')) + raise ValueError( + "TreeTransform not based on branch basis: %s" + % self._tree.get_revision_id().decode("utf-8") + ) from .. import commit + revprops = commit.Commit.update_revprops(revprops, branch, authors) - builder = branch.get_commit_builder(parent_ids, - timestamp=timestamp, - timezone=timezone, - committer=committer, - revprops=revprops, - revision_id=revision_id) + builder = branch.get_commit_builder( + parent_ids, + timestamp=timestamp, + timezone=timezone, + committer=committer, + revprops=revprops, + revision_id=revision_id, + ) preview = self.get_preview_tree() - list(builder.record_iter_changes(preview, last_rev_id, - self.iter_changes())) + list(builder.record_iter_changes(preview, last_rev_id, self.iter_changes())) builder.finish_inventory() revision_id = builder.commit(message) branch.set_last_revision_info(revno + 1, revision_id) @@ -841,7 +876,7 @@ def commit(self, branch, message, merge_parents=None, strict=False, def _text_parent(self, trans_id): path = self.tree_path(trans_id) try: - if path is None or self._tree.kind(path) != 'file': + if path is None or self._tree.kind(path) != "file": return None except _mod_transport.NoSuchFile: return None @@ -867,50 +902,55 @@ def serialize(self, serializer): :param serializer: A Serialiser like pack.ContainerSerializer. 
""" import fastbencode as bencode - new_name = {k.encode('utf-8'): v.encode('utf-8') - for k, v in self._new_name.items()} - new_parent = {k.encode('utf-8'): v.encode('utf-8') - for k, v in self._new_parent.items()} - new_id = {k.encode('utf-8'): v - for k, v in self._new_id.items()} - new_executability = {k.encode('utf-8'): int(v) - for k, v in self._new_executability.items()} - tree_path_ids = {k.encode('utf-8'): v.encode('utf-8') - for k, v in self._tree_path_ids.items()} - non_present_ids = {k: v.encode('utf-8') - for k, v in self._non_present_ids.items()} - removed_contents = [trans_id.encode('utf-8') - for trans_id in self._removed_contents] - removed_id = [trans_id.encode('utf-8') - for trans_id in self._removed_id] + + new_name = { + k.encode("utf-8"): v.encode("utf-8") for k, v in self._new_name.items() + } + new_parent = { + k.encode("utf-8"): v.encode("utf-8") for k, v in self._new_parent.items() + } + new_id = {k.encode("utf-8"): v for k, v in self._new_id.items()} + new_executability = { + k.encode("utf-8"): int(v) for k, v in self._new_executability.items() + } + tree_path_ids = { + k.encode("utf-8"): v.encode("utf-8") for k, v in self._tree_path_ids.items() + } + non_present_ids = { + k: v.encode("utf-8") for k, v in self._non_present_ids.items() + } + removed_contents = [ + trans_id.encode("utf-8") for trans_id in self._removed_contents + ] + removed_id = [trans_id.encode("utf-8") for trans_id in self._removed_id] attribs = { - b'_id_number': self._id_number, - b'_new_name': new_name, - b'_new_parent': new_parent, - b'_new_executability': new_executability, - b'_new_id': new_id, - b'_tree_path_ids': tree_path_ids, - b'_removed_id': removed_id, - b'_removed_contents': removed_contents, - b'_non_present_ids': non_present_ids, - } - yield serializer.bytes_record(bencode.bencode(attribs), - ((b'attribs',),)) + b"_id_number": self._id_number, + b"_new_name": new_name, + b"_new_parent": new_parent, + b"_new_executability": new_executability, + b"_new_id": new_id, + b"_tree_path_ids": tree_path_ids, + b"_removed_id": removed_id, + b"_removed_contents": removed_contents, + b"_non_present_ids": non_present_ids, + } + yield serializer.bytes_record(bencode.bencode(attribs), ((b"attribs",),)) for trans_id, kind in sorted(self._new_contents.items()): - if kind == 'file': - with open(self._limbo_name(trans_id), 'rb') as cur_file: + if kind == "file": + with open(self._limbo_name(trans_id), "rb") as cur_file: lines = cur_file.readlines() parents = self._get_parents_lines(trans_id) mpdiff = multiparent.MultiParent.from_lines(lines, parents) - content = b''.join(mpdiff.to_patch()) - if kind == 'directory': - content = b'' - if kind == 'symlink': + content = b"".join(mpdiff.to_patch()) + if kind == "directory": + content = b"" + if kind == "symlink": content = self._read_symlink_target(trans_id) if not isinstance(content, bytes): - content = content.encode('utf-8') + content = content.encode("utf-8") yield serializer.bytes_record( - content, ((trans_id.encode('utf-8'), kind.encode('ascii')),)) + content, ((trans_id.encode("utf-8"), kind.encode("ascii")),) + ) def deserialize(self, records): """Deserialize a stored TreeTransform. @@ -919,45 +959,51 @@ def deserialize(self, records): pack.ContainerPushParser. 
""" import fastbencode as bencode + names, content = next(records) attribs = bencode.bdecode(content) - self._id_number = attribs[b'_id_number'] - self._new_name = {k.decode('utf-8'): v.decode('utf-8') - for k, v in attribs[b'_new_name'].items()} - self._new_parent = {k.decode('utf-8'): v.decode('utf-8') - for k, v in attribs[b'_new_parent'].items()} + self._id_number = attribs[b"_id_number"] + self._new_name = { + k.decode("utf-8"): v.decode("utf-8") + for k, v in attribs[b"_new_name"].items() + } + self._new_parent = { + k.decode("utf-8"): v.decode("utf-8") + for k, v in attribs[b"_new_parent"].items() + } self._new_executability = { - k.decode('utf-8'): bool(v) - for k, v in attribs[b'_new_executability'].items()} - self._new_id = {k.decode('utf-8'): v - for k, v in attribs[b'_new_id'].items()} + k.decode("utf-8"): bool(v) + for k, v in attribs[b"_new_executability"].items() + } + self._new_id = {k.decode("utf-8"): v for k, v in attribs[b"_new_id"].items()} self._r_new_id = {v: k for k, v in self._new_id.items()} self._tree_path_ids = {} self._tree_id_paths = {} - for bytepath, trans_id in attribs[b'_tree_path_ids'].items(): - path = bytepath.decode('utf-8') - trans_id = trans_id.decode('utf-8') + for bytepath, trans_id in attribs[b"_tree_path_ids"].items(): + path = bytepath.decode("utf-8") + trans_id = trans_id.decode("utf-8") self._tree_path_ids[path] = trans_id self._tree_id_paths[trans_id] = path - self._removed_id = {trans_id.decode('utf-8') - for trans_id in attribs[b'_removed_id']} + self._removed_id = { + trans_id.decode("utf-8") for trans_id in attribs[b"_removed_id"] + } self._removed_contents = { - trans_id.decode('utf-8') - for trans_id in attribs[b'_removed_contents']} + trans_id.decode("utf-8") for trans_id in attribs[b"_removed_contents"] + } self._non_present_ids = { - k: v.decode('utf-8') - for k, v in attribs[b'_non_present_ids'].items()} + k: v.decode("utf-8") for k, v in attribs[b"_non_present_ids"].items() + } for ((trans_id, kind),), content in records: - trans_id = trans_id.decode('utf-8') - kind = kind.decode('ascii') - if kind == 'file': + trans_id = trans_id.decode("utf-8") + kind = kind.decode("ascii") + if kind == "file": mpdiff = multiparent.MultiParent.from_patch(content) lines = mpdiff.to_lines(self._get_parents_texts(trans_id)) self.create_file(lines, trans_id) - if kind == 'directory': + if kind == "directory": self.create_directory(trans_id) - if kind == 'symlink': - self.create_symlink(content.decode('utf-8'), trans_id) + if kind == "symlink": + self.create_symlink(content.decode("utf-8"), trans_id) def create_file(self, contents, trans_id, mode_id=None, sha1=None): """Schedule creation of a new file. @@ -1019,41 +1065,50 @@ def cook_conflicts(self, raw_conflicts): content_conflict_file_ids = set() cooked_conflicts = list(iter_cook_conflicts(raw_conflicts, self)) for c in cooked_conflicts: - if c.typestring == 'contents conflict': + if c.typestring == "contents conflict": content_conflict_file_ids.add(c.file_id) # We want to get rid of path conflicts when a corresponding contents # conflict exists. This can occur when one branch deletes a file while # the other renames *and* modifies it. In this case, the content # conflict is enough. 
cooked_conflicts = [ - c for c in cooked_conflicts - if c.typestring != 'path conflict' or - c.file_id not in content_conflict_file_ids] + c + for c in cooked_conflicts + if c.typestring != "path conflict" + or c.file_id not in content_conflict_file_ids + ] return sorted(cooked_conflicts, key=Conflict.sort_key) def cook_path_conflict( - tt, fp, conflict_type, trans_id, file_id, this_parent, this_name, - other_parent, other_name): + tt, + fp, + conflict_type, + trans_id, + file_id, + this_parent, + this_name, + other_parent, + other_name, +): if this_parent is None or this_name is None: - this_path = '' + this_path = "" else: parent_path = fp.get_path(tt.trans_id_file_id(this_parent)) this_path = osutils.pathjoin(parent_path, this_name) if other_parent is None or other_name is None: - other_path = '' + other_path = "" else: try: parent_path = fp.get_path(tt.trans_id_file_id(other_parent)) except NoFinalPath: # The other entry was in a path that doesn't exist in our tree. # Put it in the root. - parent_path = '' + parent_path = "" other_path = osutils.pathjoin(parent_path, other_name) return Conflict.factory( - conflict_type, path=this_path, - conflict_path=other_path, - file_id=file_id) + conflict_type, path=this_path, conflict_path=other_path, file_id=file_id + ) def cook_content_conflict(tt, fp, conflict_type, trans_ids): @@ -1063,10 +1118,10 @@ def cook_content_conflict(tt, fp, conflict_type, trans_ids): # Ok we found the relevant file-id break path = fp.get_path(trans_id) - for suffix in ('.BASE', '.THIS', '.OTHER'): + for suffix in (".BASE", ".THIS", ".OTHER"): if path.endswith(suffix): # Here is the raw path - path = path[:-len(suffix)] + path = path[: -len(suffix)] break return Conflict.factory(conflict_type, path=path, file_id=file_id) @@ -1078,11 +1133,12 @@ def cook_text_conflict(tt, fp, conflict_type, trans_id): CONFLICT_COOKERS = { - 'path conflict': cook_path_conflict, - 'text conflict': cook_text_conflict, - 'contents conflict': cook_content_conflict, + "path conflict": cook_path_conflict, + "text conflict": cook_text_conflict, + "contents conflict": cook_content_conflict, } + def iter_cook_conflicts(raw_conflicts, tt): fp = FinalPaths(tt) for conflict in raw_conflicts: @@ -1095,16 +1151,20 @@ def iter_cook_conflicts(raw_conflicts, tt): modified_id = tt.final_file_id(conflict[2]) if len(conflict) == 3: yield Conflict.factory( - c_type, action=action, path=modified_path, file_id=modified_id) + c_type, action=action, path=modified_path, file_id=modified_id + ) else: conflicting_path = fp.get_path(conflict[3]) conflicting_id = tt.final_file_id(conflict[3]) yield Conflict.factory( - c_type, action=action, path=modified_path, + c_type, + action=action, + path=modified_path, file_id=modified_id, conflict_path=conflicting_path, - conflict_file_id=conflicting_id) + conflict_file_id=conflicting_id, + ) else: yield cooker(tt, fp, *conflict) @@ -1198,8 +1258,7 @@ def adjust_path(self, name, parent, trans_id): previous_parent = self._new_parent.get(trans_id) previous_name = self._new_name.get(trans_id) super().adjust_path(name, parent, trans_id) - if (trans_id in self._limbo_files - and trans_id not in self._needs_rename): + if trans_id in self._limbo_files and trans_id not in self._needs_rename: self._rename_in_limbo([trans_id]) if previous_parent != parent: self._limbo_children[previous_parent].remove(trans_id) @@ -1228,7 +1287,7 @@ def _rename_in_limbo(self, trans_ids): self._possibly_stale_limbo_files.remove(old_path) for descendant in self._limbo_descendants(trans_id): desc_path = 
self._limbo_files[descendant] - desc_path = new_path + desc_path[len(old_path):] + desc_path = new_path + desc_path[len(old_path) :] self._limbo_files[descendant] = desc_path def _limbo_descendants(self, trans_id): @@ -1256,8 +1315,8 @@ def create_file(self, contents, trans_id, mode_id=None, sha1=None): We can use it to prevent future sha1 computations. """ name = self._limbo_name(trans_id) - with open(name, 'wb') as f: - unique_add(self._new_contents, trans_id, 'file') + with open(name, "wb") as f: + unique_add(self._new_contents, trans_id, "file") f.writelines(contents) self._set_mtime(name) self._set_mode(trans_id, mode_id, S_ISREG) @@ -1287,7 +1346,7 @@ def create_hardlink(self, path, trans_id): except PermissionError as e: raise errors.HardLinkNotSupported(path) from e try: - unique_add(self._new_contents, trans_id, 'file') + unique_add(self._new_contents, trans_id, "file") except BaseException: # Clean up the file, it never got registered so # TreeTransform.finalize() won't clean it up. @@ -1300,7 +1359,7 @@ def create_directory(self, trans_id): See also new_directory. """ os.mkdir(self._limbo_name(trans_id)) - unique_add(self._new_contents, trans_id, 'directory') + unique_add(self._new_contents, trans_id, "directory") def create_symlink(self, target, trans_id): """Schedule creation of a new symbolic link. @@ -1315,12 +1374,11 @@ def create_symlink(self, target, trans_id): path = FinalPaths(self).get_path(trans_id) except KeyError: path = None - trace.warning( - f'Unable to create symlink "{path}" on this filesystem.') + trace.warning(f'Unable to create symlink "{path}" on this filesystem.') # We add symlink to _new_contents even if they are unsupported # and not created. These entries are subsequently used to avoid # conflicts on platforms that don't support symlink - unique_add(self._new_contents, trans_id, 'symlink') + unique_add(self._new_contents, trans_id, "symlink") def cancel_creation(self, trans_id): """Cancel the creation of new file contents.""" @@ -1338,7 +1396,7 @@ def cancel_creation(self, trans_id): def new_orphan(self, trans_id, parent_id): conf = self._tree.get_config_stack() - handle_orphan = conf.get('transform.orphan_policy') + handle_orphan = conf.get("transform.orphan_policy") handle_orphan(self, trans_id, parent_id) @@ -1416,18 +1474,16 @@ def __init__(self, tree, pb=None) -> None: """ tree.lock_tree_write() try: - limbodir = urlutils.local_path_from_url( - tree._transport.abspath('limbo')) + limbodir = urlutils.local_path_from_url(tree._transport.abspath("limbo")) try: - osutils.ensure_empty_directory_exists( - limbodir) + osutils.ensure_empty_directory_exists(limbodir) except errors.DirectoryNotEmpty as e: raise errors.ExistingLimbo(limbodir) from e deletiondir = urlutils.local_path_from_url( - tree._transport.abspath('pending-deletion')) + tree._transport.abspath("pending-deletion") + ) try: - osutils.ensure_empty_directory_exists( - deletiondir) + osutils.ensure_empty_directory_exists(deletiondir) except errors.DirectoryNotEmpty as e: raise errors.ExistingPendingDeletion(deletiondir) from e except BaseException: @@ -1438,8 +1494,7 @@ def __init__(self, tree, pb=None) -> None: self._realpaths: Dict[str, str] = {} # Cache of relpath results, to speed up canonical_path self._relpaths: Dict[str, str] = {} - DiskTreeTransform.__init__(self, tree, limbodir, pb, - tree.case_sensitive) + DiskTreeTransform.__init__(self, tree, limbodir, pb, tree.case_sensitive) self._deletiondir: str = deletiondir def canonical_path(self, path): @@ -1514,7 +1569,7 @@ def 
_generate_limbo_path(self, trans_id): # tree), choose a limbo name inside the parent, to reduce further # renames. use_direct_path = False - if self._new_contents.get(parent) == 'directory': + if self._new_contents.get(parent) == "directory": filename = self._new_name.get(trans_id) if filename is not None: if parent not in self._limbo_children: @@ -1525,12 +1580,15 @@ def _generate_limbo_path(self, trans_id): # already taken this pathname, i.e. if the name is unused, or # if it is already associated with this trans_id. elif self._case_sensitive_target: - if (self._limbo_children_names[parent].get(filename) - in (trans_id, None)): + if self._limbo_children_names[parent].get(filename) in ( + trans_id, + None, + ): use_direct_path = True else: - for l_filename, l_trans_id in ( - self._limbo_children_names[parent].items()): + for l_filename, l_trans_id in self._limbo_children_names[ + parent + ].items(): if l_trans_id == trans_id: continue if l_filename.lower() == filename.lower(): @@ -1567,14 +1625,15 @@ def _duplicate_ids(self): except errors.UnsupportedOperation: # it's okay for non-file-id trees to raise UnsupportedOperation. return [] - removed_tree_ids = {self.tree_file_id(trans_id) for trans_id in - self._removed_id} + removed_tree_ids = { + self.tree_file_id(trans_id) for trans_id in self._removed_id + } active_tree_ids = all_ids.difference(removed_tree_ids) for trans_id, file_id in self._new_id.items(): if file_id in active_tree_ids: path = self._tree.id2path(file_id) old_trans_id = self.trans_id_tree_path(path) - conflicts.append(('duplicate id', old_trans_id, trans_id)) + conflicts.append(("duplicate id", old_trans_id, trans_id)) return conflicts def find_raw_conflicts(self): @@ -1596,14 +1655,14 @@ def apply(self, no_conflicts=False, precomputed_delta=None, _mover=None): calculating one. :param _mover: Supply an alternate FileMover, for testing """ - for hook in MutableTree.hooks['pre_transform']: + for hook in MutableTree.hooks["pre_transform"]: hook(self._tree, self) if not no_conflicts: self._check_malformed() self.rename_count = 0 with ui.ui_factory.nested_progress_bar() as child_pb: if precomputed_delta is None: - child_pb.update(gettext('Apply phase'), 0, 2) + child_pb.update(gettext("Apply phase"), 0, 2) inventory_delta = self._generate_inventory_delta() offset = 1 else: @@ -1614,9 +1673,9 @@ def apply(self, no_conflicts=False, precomputed_delta=None, _mover=None): else: mover = _mover try: - child_pb.update(gettext('Apply phase'), 0 + offset, 2 + offset) + child_pb.update(gettext("Apply phase"), 0 + offset, 2 + offset) self._apply_removals(mover) - child_pb.update(gettext('Apply phase'), 1 + offset, 2 + offset) + child_pb.update(gettext("Apply phase"), 1 + offset, 2 + offset) modified_paths = self._apply_insertions(mover) except BaseException: mover.rollback() @@ -1624,7 +1683,7 @@ def apply(self, no_conflicts=False, precomputed_delta=None, _mover=None): else: mover.apply_deletions() if self.final_file_id(self.root) is None: - inventory_delta = [e for e in inventory_delta if e[0] != ''] + inventory_delta = [e for e in inventory_delta if e[0] != ""] self._tree.apply_inventory_delta(inventory_delta) self._apply_observed_sha1s() self._done = True @@ -1644,15 +1703,14 @@ def _apply_removals(self, mover): with ui.ui_factory.nested_progress_bar() as child_pb: for num, (path, trans_id) in enumerate(tree_paths): # do not attempt to move root into a subdirectory of itself. 
- if path == '': + if path == "": continue - child_pb.update(gettext('removing file'), num, len(tree_paths)) + child_pb.update(gettext("removing file"), num, len(tree_paths)) full_path = self._tree.abspath(path) if trans_id in self._removed_contents: delete_path = os.path.join(self._deletiondir, trans_id) mover.pre_delete(full_path, delete_path) - elif (trans_id in self._new_name or - trans_id in self._new_parent): + elif trans_id in self._new_name or trans_id in self._new_parent: try: mover.rename(full_path, self._limbo_name(trans_id)) except TransformRenameFailed as e: @@ -1676,8 +1734,7 @@ def _apply_insertions(self, mover): with ui.ui_factory.nested_progress_bar() as child_pb: for num, (path, trans_id) in enumerate(new_paths): if (num % 10) == 0: - child_pb.update(gettext('adding file'), - num, len(new_paths)) + child_pb.update(gettext("adding file"), num, len(new_paths)) full_path = self._tree.abspath(path) if trans_id in self._needs_rename: try: @@ -1691,8 +1748,7 @@ def _apply_insertions(self, mover): # TODO: if trans_id in self._observed_sha1s, we should # re-stat the final target, since ctime will be # updated by the change. - if (trans_id in self._new_contents - or self.path_changed(trans_id)): + if trans_id in self._new_contents or self.path_changed(trans_id): if trans_id in self._new_contents: modified_paths.append(full_path) if trans_id in self._new_executability: @@ -1753,10 +1809,15 @@ def _inventory_altered(self): """ changed_ids = set() # Find entries whose file_ids are new (or changed). - new_file_id = {t for t in self._new_id - if self._new_id[t] != self.tree_file_id(t)} - for id_set in [self._new_name, self._new_parent, new_file_id, - self._new_executability]: + new_file_id = { + t for t in self._new_id if self._new_id[t] != self.tree_file_id(t) + } + for id_set in [ + self._new_name, + self._new_parent, + new_file_id, + self._new_executability, + ]: changed_ids.update(id_set) # removing implies a kind change changed_kind = set(self._removed_contents) @@ -1765,8 +1826,9 @@ def _inventory_altered(self): # Ignore entries that are already known to have changed. 
changed_kind.difference_update(changed_ids) # to keep only the truly changed ones - changed_kind = (t for t in changed_kind - if self.tree_kind(t) != self.final_kind(t)) + changed_kind = ( + t for t in changed_kind if self.tree_kind(t) != self.final_kind(t) + ) # all kind changes will alter the inventory changed_ids.update(changed_kind) # To find entries with changed parent_ids, find parents which existed, @@ -1784,10 +1846,9 @@ def _generate_inventory_delta(self): with ui.ui_factory.nested_progress_bar() as child_pb: for num, trans_id in enumerate(self._removed_id): if (num % 10) == 0: - child_pb.update(gettext('removing file'), - num, total_entries) + child_pb.update(gettext("removing file"), num, total_entries) if trans_id == self._new_root: - file_id = self._tree.path2id('') + file_id = self._tree.path2id("") else: file_id = self.tree_file_id(trans_id) # File-id isn't really being deleted, just moved @@ -1795,12 +1856,14 @@ def _generate_inventory_delta(self): continue path = self._tree_id_paths[trans_id] inventory_delta.append((path, None, file_id, None)) - new_path_file_ids = {t: self.final_file_id(t) for p, t in - new_paths} + new_path_file_ids = {t: self.final_file_id(t) for p, t in new_paths} for num, (path, trans_id) in enumerate(new_paths): if (num % 10) == 0: - child_pb.update(gettext('adding file'), - num + len(self._removed_id), total_entries) + child_pb.update( + gettext("adding file"), + num + len(self._removed_id), + total_entries, + ) file_id = new_path_file_ids[trans_id] if file_id is None: continue @@ -1816,11 +1879,13 @@ def _generate_inventory_delta(self): file_id, self._new_name[trans_id], self.final_file_id(self._new_parent[trans_id]), - None, self._new_reference_revision[trans_id]) + None, + self._new_reference_revision[trans_id], + ) else: - new_entry = inventory.make_entry(kind, - self.final_name(trans_id), - parent_file_id, file_id) + new_entry = inventory.make_entry( + kind, self.final_name(trans_id), parent_file_id, file_id + ) try: old_path = self._tree.id2path(new_entry.file_id) except errors.NoSuchId: @@ -1828,8 +1893,7 @@ def _generate_inventory_delta(self): new_executability = self._new_executability.get(trans_id) if new_executability is not None: new_entry.executable = bool(new_executability) - inventory_delta.append( - (old_path, path, new_entry.file_id, new_entry)) + inventory_delta.append((old_path, path, new_entry.file_id, new_entry)) return inventory_delta @@ -1843,7 +1907,7 @@ class TransformPreview(InventoryTreeTransform): def __init__(self, tree, pb=None, case_sensitive=True): tree.lock_read() - limbodir = tempfile.mkdtemp(prefix='bzr-limbo-') + limbodir = tempfile.mkdtemp(prefix="bzr-limbo-") DiskTreeTransform.__init__(self, tree, limbodir, pb, case_sensitive) def canonical_path(self, path): @@ -1854,7 +1918,7 @@ def tree_kind(self, trans_id): if path is None: return None kind = self._tree.path_content_summary(path)[0] - if kind == 'missing': + if kind == "missing": kind = None return kind @@ -1891,7 +1955,8 @@ def __init__(self, transform): PreviewTree.__init__(self, transform) self._final_paths = FinalPaths(transform) self._iter_changes_cache = { - c.file_id: c for c in self._transform.iter_changes()} + c.file_id: c for c in self._transform.iter_changes() + } def supports_setting_file_ids(self): return True @@ -1907,14 +1972,14 @@ def supports_tree_reference(self): def _content_change(self, file_id): """Return True if the content of this file changed.""" changes = self._iter_changes_cache.get(file_id) - return (changes is not None and 
changes.changed_content) + return changes is not None and changes.changed_content def _get_file_revision(self, path, file_id, vf, tree_revision): parent_keys = [ (file_id, t.get_file_revision(t.id2path(file_id))) - for t in self._iter_parent_trees()] - vf.add_lines((file_id, tree_revision), parent_keys, - self.get_file_lines(path)) + for t in self._iter_parent_trees() + ] + vf.add_lines((file_id, tree_revision), parent_keys, self.get_file_lines(path)) repo = self._get_repository() base_vf = repo.texts if base_vf not in vf.fallback_versionedfiles: @@ -1927,7 +1992,7 @@ def _stat_limbo_file(self, trans_id): def _comparison_data(self, entry, path): kind, size, executable, link_or_sha1 = self.path_content_summary(path) - if kind == 'missing': + if kind == "missing": kind = None executable = False else: @@ -1942,8 +2007,9 @@ def root_inventory(self): def all_file_ids(self): tree_ids = set(self._transform._tree.all_file_ids()) - tree_ids.difference_update(self._transform.tree_file_id(t) - for t in self._transform._removed_id) + tree_ids.difference_update( + self._transform.tree_file_id(t) for t in self._transform._removed_id + ) tree_ids.update(self._transform._new_id.values()) return tree_ids @@ -1951,12 +2017,12 @@ def all_versioned_paths(self): tree_paths = set(self._transform._tree.all_versioned_paths()) tree_paths.difference_update( - self._transform.trans_id_tree_path(t) - for t in self._transform._removed_id) + self._transform.trans_id_tree_path(t) for t in self._transform._removed_id + ) tree_paths.update( - self._final_paths._determine_path(t) - for t in self._transform._new_id) + self._final_paths._determine_path(t) for t in self._transform._new_id + ) return tree_paths @@ -1967,7 +2033,7 @@ def path2id(self, path): path = osutils.pathjoin(*path) return self._transform.final_file_id(self._path2trans_id(path)) - def id2path(self, file_id, recurse='down'): + def id2path(self, file_id, recurse="down"): trans_id = self._transform.trans_id_file_id(file_id) try: return self._final_paths._determine_path(trans_id) @@ -1975,8 +2041,10 @@ def id2path(self, file_id, recurse='down'): raise errors.NoSuchId(self, file_id) from e def extras(self): - possible_extras = {self._transform.trans_id_tree_path(p) for p - in self._transform._tree.extras()} + possible_extras = { + self._transform.trans_id_tree_path(p) + for p in self._transform._tree.extras() + } possible_extras.update(self._transform._new_contents) possible_extras.update(self._transform._removed_id) for trans_id in possible_extras: @@ -1989,17 +2057,19 @@ def _make_inv_entries(self, ordered_entries, specific_files=None): if file_id is None: continue parent_file_id = self._transform.final_file_id(parent_trans_id) - if (specific_files is not None - and self._final_paths.get_path(trans_id) not in specific_files): + if ( + specific_files is not None + and self._final_paths.get_path(trans_id) not in specific_files + ): continue kind = self._transform.final_kind(trans_id) if kind is None: kind = self._transform._tree.stored_kind( - self._transform._tree.id2path(file_id)) + self._transform._tree.id2path(file_id) + ) new_entry = inventory.make_entry( - kind, - self._transform.final_name(trans_id), - parent_file_id, file_id) + kind, self._transform.final_name(trans_id), parent_file_id, file_id + ) yield new_entry, trans_id def _list_files_by_dir(self): @@ -2019,15 +2089,16 @@ def iter_child_entries(self, path): trans_id = self._path2trans_id(path) if trans_id is None: raise _mod_transport.NoSuchFile(path) - todo = [(child_trans_id, trans_id) for 
child_trans_id in - self._all_children(trans_id)] + todo = [ + (child_trans_id, trans_id) + for child_trans_id in self._all_children(trans_id) + ] for entry, _trans_id in self._make_inv_entries(todo): yield entry def iter_entries_by_dir(self, specific_files=None, recurse_nested=False): if recurse_nested: - raise NotImplementedError( - 'follow tree references not yet supported') + raise NotImplementedError("follow tree references not yet supported") # This may not be a maximally efficient implementation, but it is # reasonably straightforward. An implementation that grafts the @@ -2035,8 +2106,7 @@ def iter_entries_by_dir(self, specific_files=None, recurse_nested=False): # might be more efficient, but requires tricky inferences about stack # position. ordered_ids = self._list_files_by_dir() - for entry, trans_id in self._make_inv_entries(ordered_ids, - specific_files): + for entry, trans_id in self._make_inv_entries(ordered_ids, specific_files): yield self._final_paths.get_path(trans_id), entry def _iter_entries_for_dir(self, dir_path): @@ -2051,38 +2121,39 @@ def _iter_entries_for_dir(self, dir_path): path_entries.sort() return path_entries - def list_files(self, include_root=False, from_dir=None, recursive=True, - recurse_nested=False): + def list_files( + self, include_root=False, from_dir=None, recursive=True, recurse_nested=False + ): """See WorkingTree.list_files.""" if recurse_nested: - raise NotImplementedError( - 'follow tree references not yet supported') + raise NotImplementedError("follow tree references not yet supported") # XXX This should behave like WorkingTree.list_files, but is really # more like RevisionTree.list_files. - if from_dir == '.': + if from_dir == ".": from_dir = None if recursive: prefix = None if from_dir: - prefix = from_dir + '/' + prefix = from_dir + "/" entries = self.iter_entries_by_dir() for path, entry in entries: - if entry.name == '' and not include_root: + if entry.name == "" and not include_root: continue if prefix: if not path.startswith(prefix): continue - path = path[len(prefix):] - yield path, 'V', entry.kind, entry + path = path[len(prefix) :] + yield path, "V", entry.kind, entry else: if from_dir is None and include_root is True: root_entry = inventory.make_entry( - 'directory', '', None, self.path2id('')) - yield '', 'V', 'directory', root_entry - entries = self._iter_entries_for_dir(from_dir or '') + "directory", "", None, self.path2id("") + ) + yield "", "V", "directory", root_entry + entries = self._iter_entries_for_dir(from_dir or "") for path, entry in entries: - yield path, 'V', entry.kind, entry + yield path, "V", entry.kind, entry def get_file_mtime(self, path): """See Tree.get_file_mtime.""" @@ -2091,7 +2162,8 @@ def get_file_mtime(self, path): raise _mod_transport.NoSuchFile(path) if not self._content_change(file_id): return self._transform._tree.get_file_mtime( - self._transform._tree.id2path(file_id)) + self._transform._tree.id2path(file_id) + ) trans_id = self._path2trans_id(path) return self._stat_limbo_file(trans_id).st_mtime @@ -2102,16 +2174,16 @@ def path_content_summary(self, path): kind = tt._new_contents.get(trans_id) if kind is None: if tree_path is None or trans_id in tt._removed_contents: - return 'missing', None, None, None + return "missing", None, None, None summary = tt._tree.path_content_summary(tree_path) kind, size, executable, link_or_sha1 = summary else: link_or_sha1 = None limbo_name = tt._limbo_name(trans_id) if trans_id in tt._new_reference_revision: - kind = 'tree-reference' + kind = "tree-reference" 
link_or_sha1 = tt._new_reference_revision - if kind == 'file': + if kind == "file": statval = os.lstat(limbo_name) size = statval.st_size if not tt._limbo_supports_executable(): @@ -2121,36 +2193,47 @@ def path_content_summary(self, path): else: size = None executable = None - if kind == 'symlink': + if kind == "symlink": link_or_sha1 = os.readlink(limbo_name) if not isinstance(link_or_sha1, str): link_or_sha1 = os.fsdecode(link_or_sha1) executable = tt._new_executability.get(trans_id, executable) return kind, size, executable, link_or_sha1 - def iter_changes(self, from_tree, include_unchanged=False, - specific_files=None, pb=None, extra_trees=None, - require_versioned=True, want_unversioned=False): + def iter_changes( + self, + from_tree, + include_unchanged=False, + specific_files=None, + pb=None, + extra_trees=None, + require_versioned=True, + want_unversioned=False, + ): """See InterTree.iter_changes. This has a fast path that is only used when the from_tree matches the transform tree, and no fancy options are supplied. """ - if (from_tree is not self._transform._tree or include_unchanged - or specific_files or want_unversioned): + if ( + from_tree is not self._transform._tree + or include_unchanged + or specific_files + or want_unversioned + ): return tree.InterTree.get(from_tree, self).iter_changes( include_unchanged=include_unchanged, specific_files=specific_files, pb=pb, extra_trees=extra_trees, require_versioned=require_versioned, - want_unversioned=want_unversioned) + want_unversioned=want_unversioned, + ) if want_unversioned: - raise ValueError('want_unversioned is not supported') + raise ValueError("want_unversioned is not supported") return self._transform.iter_changes() - def annotate_iter(self, path, - default_revision=_mod_revision.CURRENT_REVISION): + def annotate_iter(self, path, default_revision=_mod_revision.CURRENT_REVISION): file_id = self.path2id(path) changes = self._iter_changes_cache.get(file_id) if changes is None: @@ -2161,13 +2244,14 @@ def annotate_iter(self, path, else: if changes.kind[1] is None: return None - if changes.kind[0] == 'file' and changes.versioned[0]: + if changes.kind[0] == "file" and changes.versioned[0]: old_path = changes.path[0] else: old_path = None if old_path is not None: old_annotation = self._transform._tree.annotate_iter( - old_path, default_revision=default_revision) + old_path, default_revision=default_revision + ) else: old_annotation = [] if changes is None: @@ -2184,17 +2268,17 @@ def annotate_iter(self, path, # 'default_revision' # It would be nice to be able to use the new Annotator based # approach, as well. 
- return annotate.reannotate([old_annotation], - self.get_file_lines(path), - default_revision) + return annotate.reannotate( + [old_annotation], self.get_file_lines(path), default_revision + ) - def walkdirs(self, prefix=''): + def walkdirs(self, prefix=""): pending = [self._transform.root] while len(pending) > 0: parent_id = pending.pop() children = [] subdirs = [] - prefix = prefix.rstrip('/') + prefix = prefix.rstrip("/") parent_path = self._final_paths.get_path(parent_id) self._transform.final_file_id(parent_id) for child_id in self._all_children(parent_id): @@ -2205,18 +2289,17 @@ def walkdirs(self, prefix=''): if kind is not None: versioned_kind = kind else: - kind = 'unknown' - versioned_kind = self._transform._tree.stored_kind( - path_from_root) - if versioned_kind == 'directory': + kind = "unknown" + versioned_kind = self._transform._tree.stored_kind(path_from_root) + if versioned_kind == "directory": subdirs.append(child_id) - children.append((path_from_root, basename, kind, None, - versioned_kind)) + children.append((path_from_root, basename, kind, None, versioned_kind)) children.sort() if parent_path.startswith(prefix): yield parent_path, children - pending.extend(sorted(subdirs, key=self._final_paths.get_path, - reverse=True)) + pending.extend( + sorted(subdirs, key=self._final_paths.get_path, reverse=True) + ) def get_symlink_target(self, path): """See Tree.get_symlink_target.""" @@ -2234,11 +2317,10 @@ def get_file(self, path): return self._transform._tree.get_file(path) trans_id = self._path2trans_id(path) name = self._transform._limbo_name(trans_id) - return open(name, 'rb') + return open(name, "rb") -def build_tree(tree, wt, accelerator_tree=None, hardlink=False, - delta_from_tree=False): +def build_tree(tree, wt, accelerator_tree=None, hardlink=False, delta_from_tree=False): """Create working tree for a branch, using a TreeTransform. This function should be used on empty trees, having a tree root at most. 
@@ -2268,15 +2350,14 @@ def build_tree(tree, wt, accelerator_tree=None, hardlink=False, exit_stack.enter_context(tree.lock_read()) if accelerator_tree is not None: exit_stack.enter_context(accelerator_tree.lock_read()) - return _build_tree(tree, wt, accelerator_tree, hardlink, - delta_from_tree) + return _build_tree(tree, wt, accelerator_tree, hardlink, delta_from_tree) def resolve_checkout(tt, conflicts, divert): new_conflicts = set() for c_type, conflict in ((c[0], c) for c in conflicts): # Anything but a 'duplicate' would indicate programmer error - if c_type != 'duplicate': + if c_type != "duplicate": raise AssertionError(c_type) # Now figure out which is new and which is old if tt.new_contents(conflict[1]): @@ -2290,15 +2371,13 @@ def resolve_checkout(tt, conflicts, divert): # resolved final_parent = tt.final_parent(old_file) if new_file in divert: - new_name = tt.final_name(old_file) + '.diverted' + new_name = tt.final_name(old_file) + ".diverted" tt.adjust_path(new_name, final_parent, new_file) - new_conflicts.add((c_type, 'Diverted to', - new_file, old_file)) + new_conflicts.add((c_type, "Diverted to", new_file, old_file)) else: - new_name = tt.final_name(old_file) + '.moved' + new_name = tt.final_name(old_file) + ".moved" tt.adjust_path(new_name, final_parent, old_file) - new_conflicts.add((c_type, 'Moved existing file to', - old_file, new_file)) + new_conflicts.add((c_type, "Moved existing file to", old_file, new_file)) return new_conflicts @@ -2310,7 +2389,7 @@ def _build_tree(tree, wt, accelerator_tree, hardlink, delta_from_tree): file_trans_id = {} top_pb = ui.ui_factory.nested_progress_bar() pp = ProgressPhase("Build phase", 2, top_pb) - if tree.path2id('') is not None: + if tree.path2id("") is not None: # This is kind of a hack: we should be altering the root # as part of the regular tree shape diff logic. # The conditional test here is to avoid doing an @@ -2318,14 +2397,14 @@ def _build_tree(tree, wt, accelerator_tree, hardlink, delta_from_tree): # is set within the tree, nor setting the root and thus # marking the tree as dirty, because we use two different # idioms here: tree interfaces and inventory interfaces. 
- if wt.path2id('') != tree.path2id(''): - wt.set_root_id(tree.path2id('')) + if wt.path2id("") != tree.path2id(""): + wt.set_root_id(tree.path2id("")) wt.flush() tt = wt.transform() divert = set() try: pp.next_phase() - file_trans_id[find_previous_path(wt, tree, '')] = tt.trans_id_tree_path('') + file_trans_id[find_previous_path(wt, tree, "")] = tt.trans_id_tree_path("") with ui.ui_factory.nested_progress_bar() as pb: deferred_contents = [] num = 0 @@ -2343,10 +2422,8 @@ def _build_tree(tree, wt, accelerator_tree, hardlink, delta_from_tree): existing_files = set() for _dir, files in wt.walkdirs(): existing_files.update(f[0] for f in files) - for num, (tree_path, entry) in \ - enumerate(tree.iter_entries_by_dir()): - pb.update(gettext("Building tree"), num - - len(deferred_contents), total) + for num, (tree_path, entry) in enumerate(tree.iter_entries_by_dir()): + pb.update(gettext("Building tree"), num - len(deferred_contents), total) if entry.parent_id is None: continue reparent = False @@ -2363,14 +2440,14 @@ def _build_tree(tree, wt, accelerator_tree, hardlink, delta_from_tree): pass else: divert.add(tree_path) - if (tree_path not in divert - and _content_match( - tree, entry, tree_path, kind, target_path)): + if tree_path not in divert and _content_match( + tree, entry, tree_path, kind, target_path + ): tt.delete_contents(tt.trans_id_tree_path(tree_path)) - if kind == 'directory': + if kind == "directory": reparent = True parent_id = file_trans_id[osutils.dirname(tree_path)] - if entry.kind == 'file': + if entry.kind == "file": # We *almost* replicate new_by_entry, so that we can defer # getting the file text, and get them all at once. trans_id = tt.create_path(entry.name, parent_id) @@ -2383,19 +2460,22 @@ def _build_tree(tree, wt, accelerator_tree, hardlink, delta_from_tree): deferred_contents.append((tree_path, trans_data)) else: file_trans_id[tree_path] = new_by_entry( - tree_path, tt, entry, parent_id, tree) + tree_path, tt, entry, parent_id, tree + ) if reparent: new_trans_id = file_trans_id[tree_path] old_parent = tt.trans_id_tree_path(tree_path) _reparent_children(tt, old_parent, new_trans_id) offset = num + 1 - len(deferred_contents) - _create_files(tt, tree, deferred_contents, pb, offset, - accelerator_tree, hardlink) + _create_files( + tt, tree, deferred_contents, pb, offset, accelerator_tree, hardlink + ) pp.next_phase() divert_trans = {file_trans_id[f] for f in divert} def resolver(t, c): return resolve_checkout(t, c, divert_trans) + raw_conflicts = resolve_conflicts(tt, pass_func=resolver) if len(raw_conflicts) > 0: precomputed_delta = None @@ -2406,16 +2486,14 @@ def resolver(t, c): wt.add_conflicts(conflicts) except errors.UnsupportedOperation: pass - result = tt.apply(no_conflicts=True, - precomputed_delta=precomputed_delta) + result = tt.apply(no_conflicts=True, precomputed_delta=precomputed_delta) finally: tt.finalize() top_pb.finished() return result -def _create_files(tt, tree, desired_files, pb, offset, accelerator_tree, - hardlink): +def _create_files(tt, tree, desired_files, pb, offset, accelerator_tree, hardlink): total = len(desired_files) + offset wt = tt._tree if accelerator_tree is None: @@ -2423,39 +2501,47 @@ def _create_files(tt, tree, desired_files, pb, offset, accelerator_tree, else: iter = accelerator_tree.iter_changes(tree, include_unchanged=True) unchanged = [ - change.path for change in iter - if not (change.changed_content or change.executable[0] != change.executable[1])] + change.path + for change in iter + if not ( + change.changed_content or 
change.executable[0] != change.executable[1] + ) + ] if accelerator_tree.supports_content_filtering(): - unchanged = [(tp, ap) for (tp, ap) in unchanged - if not next(accelerator_tree.iter_search_rules([ap]))] + unchanged = [ + (tp, ap) + for (tp, ap) in unchanged + if not next(accelerator_tree.iter_search_rules([ap])) + ] unchanged = dict(unchanged) new_desired_files = [] count = 0 for _unused_tree_path, (trans_id, tree_path, text_sha1) in desired_files: accelerator_path = unchanged.get(tree_path) if accelerator_path is None: - new_desired_files.append((tree_path, - (trans_id, tree_path, text_sha1))) + new_desired_files.append((tree_path, (trans_id, tree_path, text_sha1))) continue - pb.update(gettext('Adding file contents'), count + offset, total) + pb.update(gettext("Adding file contents"), count + offset, total) if hardlink: - tt.create_hardlink(accelerator_tree.abspath(accelerator_path), - trans_id) + tt.create_hardlink(accelerator_tree.abspath(accelerator_path), trans_id) else: with accelerator_tree.get_file(accelerator_path) as f: chunks = osutils.file_iterator(f) if wt.supports_content_filtering(): filters = wt._content_filter_stack(tree_path) - chunks = filtered_output_bytes(chunks, filters, - ContentFilterContext(tree_path, tree)) + chunks = filtered_output_bytes( + chunks, filters, ContentFilterContext(tree_path, tree) + ) tt.create_file(chunks, trans_id, sha1=text_sha1) count += 1 offset += count for count, ((trans_id, tree_path, text_sha1), contents) in enumerate( - tree.iter_files_bytes(new_desired_files)): + tree.iter_files_bytes(new_desired_files) + ): if wt.supports_content_filtering(): filters = wt._content_filter_stack(tree_path) - contents = filtered_output_bytes(contents, filters, - ContentFilterContext(tree_path, tree)) + contents = filtered_output_bytes( + contents, filters, ContentFilterContext(tree_path, tree) + ) tt.create_file(contents, trans_id, sha1=text_sha1) - pb.update(gettext('Adding file contents'), count + offset, total) + pb.update(gettext("Adding file contents"), count + offset, total) diff --git a/breezy/bzr/tuned_gzip.py b/breezy/bzr/tuned_gzip.py index 3384e0a619..2d4b11ff51 100644 --- a/breezy/bzr/tuned_gzip.py +++ b/breezy/bzr/tuned_gzip.py @@ -38,29 +38,34 @@ def LOWU32(i): return i & 0xFFFFFFFF -def chunks_to_gzip(chunks, factory=zlib.compressobj, - level=zlib.Z_DEFAULT_COMPRESSION, method=zlib.DEFLATED, - width=-zlib.MAX_WBITS, mem=zlib.DEF_MEM_LEVEL, - crc32=zlib.crc32): +def chunks_to_gzip( + chunks, + factory=zlib.compressobj, + level=zlib.Z_DEFAULT_COMPRESSION, + method=zlib.DEFLATED, + width=-zlib.MAX_WBITS, + mem=zlib.DEF_MEM_LEVEL, + crc32=zlib.crc32, +): """Create a gzip file containing chunks and return its content. :param chunks: An iterable of strings. Each string can have arbitrary layout. 
""" result = [ - b'\037\213' # self.fileobj.write('\037\213') # magic header - b'\010' # self.fileobj.write('\010') # compression method - # fname = self.filename[:-3] - # flags = 0 - # if fname: - # flags = FNAME - b'\x00' # self.fileobj.write(chr(flags)) - b'\0\0\0\0' # write32u(self.fileobj, long(time.time())) - b'\002' # self.fileobj.write('\002') - b'\377' # self.fileobj.write('\377') - # if fname: - b'' # self.fileobj.write(fname + '\000') - ] + b"\037\213" # self.fileobj.write('\037\213') # magic header + b"\010" # self.fileobj.write('\010') # compression method + # fname = self.filename[:-3] + # flags = 0 + # if fname: + # flags = FNAME + b"\x00" # self.fileobj.write(chr(flags)) + b"\0\0\0\0" # write32u(self.fileobj, long(time.time())) + b"\002" # self.fileobj.write('\002') + b"\377" # self.fileobj.write('\377') + # if fname: + b"" # self.fileobj.write(fname + '\000') + ] # using a compressobj avoids a small header and trailer that the compress() # utility function adds. compress = factory(level, method, width, mem, 0) diff --git a/breezy/bzr/versionedfile.py b/breezy/bzr/versionedfile.py index b858dac4c2..53415944f6 100644 --- a/breezy/bzr/versionedfile.py +++ b/breezy/bzr/versionedfile.py @@ -26,13 +26,16 @@ from ..lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ import fastbencode as bencode from breezy import ( multiparent, ) -""") +""", +) from .. import errors, osutils, revision, urlutils from .. import graph as _mod_graph from .. import transport as _mod_transport @@ -41,25 +44,42 @@ from . import index adapter_registry = Registry[Tuple[str, str], Any, None]() -adapter_registry.register_lazy(('knit-annotated-delta-gz', 'knit-delta-gz'), - 'breezy.bzr.knit', 'DeltaAnnotatedToUnannotated') -adapter_registry.register_lazy(('knit-annotated-ft-gz', 'knit-ft-gz'), - 'breezy.bzr.knit', 'FTAnnotatedToUnannotated') -for target_storage_kind in ('fulltext', 'chunked', 'lines'): - adapter_registry.register_lazy(('knit-delta-gz', target_storage_kind), 'breezy.bzr.knit', - 'DeltaPlainToFullText') - adapter_registry.register_lazy(('knit-ft-gz', target_storage_kind), 'breezy.bzr.knit', - 'FTPlainToFullText') - adapter_registry.register_lazy(('knit-annotated-ft-gz', target_storage_kind), - 'breezy.bzr.knit', 'FTAnnotatedToFullText') - adapter_registry.register_lazy(('knit-annotated-delta-gz', target_storage_kind), - 'breezy.bzr.knit', 'DeltaAnnotatedToFullText') +adapter_registry.register_lazy( + ("knit-annotated-delta-gz", "knit-delta-gz"), + "breezy.bzr.knit", + "DeltaAnnotatedToUnannotated", +) +adapter_registry.register_lazy( + ("knit-annotated-ft-gz", "knit-ft-gz"), + "breezy.bzr.knit", + "FTAnnotatedToUnannotated", +) +for target_storage_kind in ("fulltext", "chunked", "lines"): + adapter_registry.register_lazy( + ("knit-delta-gz", target_storage_kind), + "breezy.bzr.knit", + "DeltaPlainToFullText", + ) + adapter_registry.register_lazy( + ("knit-ft-gz", target_storage_kind), "breezy.bzr.knit", "FTPlainToFullText" + ) + adapter_registry.register_lazy( + ("knit-annotated-ft-gz", target_storage_kind), + "breezy.bzr.knit", + "FTAnnotatedToFullText", + ) + adapter_registry.register_lazy( + ("knit-annotated-delta-gz", target_storage_kind), + "breezy.bzr.knit", + "DeltaAnnotatedToFullText", + ) class UnavailableRepresentation(errors.InternalBzrError): - - _fmt = ("The encoding '%(wanted)s' is not available for key %(key)s which " - "is encoded as '%(native)s'.") + _fmt = ( + "The encoding '%(wanted)s' is not available for key %(key)s which " + "is 
encoded as '%(native)s'." + ) def __init__(self, key, wanted, native): errors.InternalBzrError.__init__(self) @@ -69,7 +89,6 @@ def __init__(self, key, wanted, native): class ExistingContent(errors.BzrError): - _fmt = "The content being inserted is already present." @@ -121,33 +140,32 @@ def __init__(self, key, parents, sha1, chunks, chunks_are_lines=None) -> None: """Create a ContentFactory.""" self.sha1 = sha1 self.size: int = sum(map(len, chunks)) - self.storage_kind: str = 'chunked' + self.storage_kind: str = "chunked" self.key = key self.parents = parents self._chunks = chunks self._chunks_are_lines = chunks_are_lines def get_bytes_as(self, storage_kind): - if storage_kind == 'chunked': + if storage_kind == "chunked": return self._chunks - elif storage_kind == 'fulltext': - return b''.join(self._chunks) - elif storage_kind == 'lines': + elif storage_kind == "fulltext": + return b"".join(self._chunks) + elif storage_kind == "lines": if self._chunks_are_lines: return self._chunks return list(osutils.chunks_to_lines(self._chunks)) - raise UnavailableRepresentation(self.key, storage_kind, - self.storage_kind) + raise UnavailableRepresentation(self.key, storage_kind, self.storage_kind) def iter_bytes_as(self, storage_kind): - if storage_kind == 'chunked': + if storage_kind == "chunked": return iter(self._chunks) - elif storage_kind == 'lines': + elif storage_kind == "lines": if self._chunks_are_lines: return iter(self._chunks) return osutils.chunks_to_lines_iter(iter(self._chunks)) - raise UnavailableRepresentation(self.key, storage_kind, - self.storage_kind) + raise UnavailableRepresentation(self.key, storage_kind, self.storage_kind) + class FulltextContentFactory(ContentFactory): """Static data content factory. @@ -169,7 +187,7 @@ def __init__(self, key, parents, sha1, text): """Create a ContentFactory.""" self.sha1 = sha1 self.size = len(text) - self.storage_kind = 'fulltext' + self.storage_kind = "fulltext" self.key = key self.parents = parents if not isinstance(text, bytes): @@ -179,20 +197,18 @@ def __init__(self, key, parents, sha1, text): def get_bytes_as(self, storage_kind): if storage_kind == self.storage_kind: return self._text - elif storage_kind == 'chunked': + elif storage_kind == "chunked": return [self._text] - elif storage_kind == 'lines': + elif storage_kind == "lines": return osutils.split_lines(self._text) - raise UnavailableRepresentation(self.key, storage_kind, - self.storage_kind) + raise UnavailableRepresentation(self.key, storage_kind, self.storage_kind) def iter_bytes_as(self, storage_kind): - if storage_kind == 'chunked': + if storage_kind == "chunked": return iter([self._text]) - elif storage_kind == 'lines': + elif storage_kind == "lines": return iter(osutils.split_lines(self._text)) - raise UnavailableRepresentation(self.key, storage_kind, - self.storage_kind) + raise UnavailableRepresentation(self.key, storage_kind, self.storage_kind) class FileContentFactory(ContentFactory): @@ -202,7 +218,7 @@ def __init__(self, key, parents, fileobj, sha1=None, size=None): self.key = key self.parents = parents self.file = fileobj - self.storage_kind = 'file' + self.storage_kind = "file" self.sha1 = sha1 self.size = size self._needs_reset = False @@ -211,25 +227,23 @@ def get_bytes_as(self, storage_kind): if self._needs_reset: self.file.seek(0) self._needs_reset = True - if storage_kind == 'fulltext': + if storage_kind == "fulltext": return self.file.read() - elif storage_kind == 'chunked': + elif storage_kind == "chunked": return list(osutils.file_iterator(self.file)) - 
elif storage_kind == 'lines': + elif storage_kind == "lines": return list(self.file.readlines()) - raise UnavailableRepresentation(self.key, storage_kind, - self.storage_kind) + raise UnavailableRepresentation(self.key, storage_kind, self.storage_kind) def iter_bytes_as(self, storage_kind): if self._needs_reset: self.file.seek(0) self._needs_reset = True - if storage_kind == 'chunked': + if storage_kind == "chunked": return osutils.file_iterator(self.file) - elif storage_kind == 'lines': + elif storage_kind == "lines": return self.file - raise UnavailableRepresentation(self.key, storage_kind, - self.storage_kind) + raise UnavailableRepresentation(self.key, storage_kind, self.storage_kind) class AbsentContentFactory(ContentFactory): @@ -246,19 +260,23 @@ def __init__(self, key): """Create a ContentFactory.""" self.sha1 = None self.size = None - self.storage_kind = 'absent' + self.storage_kind = "absent" self.key = key self.parents = None def get_bytes_as(self, storage_kind): - raise ValueError(f'A request was made for key: {self.key}, but that' - ' content is not available, and the calling' - ' code does not handle if it is missing.') + raise ValueError( + f"A request was made for key: {self.key}, but that" + " content is not available, and the calling" + " code does not handle if it is missing." + ) def iter_bytes_as(self, storage_kind): - raise ValueError(f'A request was made for key: {self.key}, but that' - ' content is not available, and the calling' - ' code does not handle if it is missing.') + raise ValueError( + f"A request was made for key: {self.key}, but that" + " content is not available, and the calling" + " code does not handle if it is missing." + ) class AdapterFactory(ContentFactory): @@ -272,7 +290,7 @@ def __init__(self, key, parents, adapted): def __getattr__(self, attr): """Return a member from the adapted object.""" - if attr in ('key', 'parents'): + if attr in ("key", "parents"): return self.__dict__[attr] else: return getattr(self._adapted, attr) @@ -281,7 +299,7 @@ def __getattr__(self, attr): def filter_absent(record_stream): """Adapt a record stream to remove absent records.""" for record in record_stream: - if record.storage_kind != 'absent': + if record.storage_kind != "absent": yield record @@ -356,12 +374,12 @@ def _compute_diff(self, key, parent_lines, lines): # It was meant to extract the left-parent diff without # having to recompute it for Knit content (pack-0.92, # etc). 
That seems to have regressed somewhere - left_parent_blocks = self.vf._extract_blocks(key, - parent_lines[0], lines) + left_parent_blocks = self.vf._extract_blocks(key, parent_lines[0], lines) else: left_parent_blocks = None - diff = multiparent.MultiParent.from_lines(lines, - parent_lines, left_parent_blocks) + diff = multiparent.MultiParent.from_lines( + lines, parent_lines, left_parent_blocks + ) self.diffs[key] = diff def _process_one_record(self, key, this_chunks): @@ -404,12 +422,10 @@ def _process_one_record(self, key, this_chunks): def _extract_diffs(self): needed_keys, refcounts = self._find_needed_keys() - for record in self.vf.get_record_stream(needed_keys, - 'topological', True): - if record.storage_kind == 'absent': + for record in self.vf.get_record_stream(needed_keys, "topological", True): + if record.storage_kind == "absent": raise errors.RevisionNotPresent(record.key, self.vf) - self._process_one_record(record.key, - record.get_bytes_as('chunked')) + self._process_one_record(record.key, record.get_bytes_as("chunked")) def compute_diffs(self): self._extract_diffs() @@ -469,9 +485,17 @@ def insert_record_stream(self, stream): """ raise NotImplementedError - def add_lines(self, version_id, parents, lines, parent_texts=None, - left_matching_blocks=None, nostore_sha=None, random_id=False, - check_content=True): + def add_lines( + self, + version_id, + parents, + lines, + parent_texts=None, + left_matching_blocks=None, + nostore_sha=None, + random_id=False, + check_content=True, + ): r"""Add a single text on top of the versioned file. Must raise RevisionAlreadyPresent if the new version is @@ -510,27 +534,69 @@ def add_lines(self, version_id, parents, lines, parent_texts=None, back to future add_lines calls in the parent_texts dictionary. """ self._check_write_ok() - return self._add_lines(version_id, parents, lines, parent_texts, - left_matching_blocks, nostore_sha, random_id, check_content) - - def _add_lines(self, version_id, parents, lines, parent_texts, - left_matching_blocks, nostore_sha, random_id, check_content): + return self._add_lines( + version_id, + parents, + lines, + parent_texts, + left_matching_blocks, + nostore_sha, + random_id, + check_content, + ) + + def _add_lines( + self, + version_id, + parents, + lines, + parent_texts, + left_matching_blocks, + nostore_sha, + random_id, + check_content, + ): """Helper to do the class specific add_lines.""" raise NotImplementedError(self.add_lines) - def add_lines_with_ghosts(self, version_id, parents, lines, - parent_texts=None, nostore_sha=None, random_id=False, - check_content=True, left_matching_blocks=None): + def add_lines_with_ghosts( + self, + version_id, + parents, + lines, + parent_texts=None, + nostore_sha=None, + random_id=False, + check_content=True, + left_matching_blocks=None, + ): """Add lines to the versioned file, allowing ghosts to be present. This takes the same parameters as add_lines and returns the same. 
""" self._check_write_ok() - return self._add_lines_with_ghosts(version_id, parents, lines, - parent_texts, nostore_sha, random_id, check_content, left_matching_blocks) - - def _add_lines_with_ghosts(self, version_id, parents, lines, parent_texts, - nostore_sha, random_id, check_content, left_matching_blocks): + return self._add_lines_with_ghosts( + version_id, + parents, + lines, + parent_texts, + nostore_sha, + random_id, + check_content, + left_matching_blocks, + ) + + def _add_lines_with_ghosts( + self, + version_id, + parents, + lines, + parent_texts, + nostore_sha, + random_id, + check_content, + left_matching_blocks, + ): """Helper to do class specific add_lines_with_ghosts.""" raise NotImplementedError(self.add_lines_with_ghosts) @@ -547,7 +613,7 @@ def _check_lines_not_unicode(self, lines): def _check_lines_are_lines(self, lines): """Check that the lines really are full lines without inline EOL.""" for line in lines: - if b'\n' in line[:-1]: + if b"\n" in line[:-1]: raise errors.BzrBadParameterContainsNewline("lines") def get_format_signature(self): @@ -573,14 +639,14 @@ def make_mpdiffs(self, version_ids): raise errors.RevisionNotPresent(version_id, self) from e # We need to filter out ghosts, because we can't diff against them. knit_versions = set(self.get_parent_map(knit_versions)) - lines = dict(zip(knit_versions, - self._get_lf_split_line_list(knit_versions))) + lines = dict(zip(knit_versions, self._get_lf_split_line_list(knit_versions))) diffs = [] for version_id in version_ids: target = lines[version_id] try: - parents = [lines[p] for p in parent_map[version_id] if p in - knit_versions] + parents = [ + lines[p] for p in parent_map[version_id] if p in knit_versions + ] except KeyError as e: # I don't know how this could ever trigger. # parent_map[version_id] was already triggered in the previous @@ -588,12 +654,14 @@ def make_mpdiffs(self, version_ids): # so we again won't have a KeyError. 
raise errors.RevisionNotPresent(version_id, self) from e if len(parents) > 0: - left_parent_blocks = self._extract_blocks(version_id, - parents[0], target) + left_parent_blocks = self._extract_blocks( + version_id, parents[0], target + ) else: left_parent_blocks = None - diffs.append(multiparent.MultiParent.from_lines(target, parents, - left_parent_blocks)) + diffs.append( + multiparent.MultiParent.from_lines(target, parents, left_parent_blocks) + ) return diffs def _extract_blocks(self, version_id, source, target): @@ -614,29 +682,41 @@ def add_mpdiffs(self, records): mpvf.add_diff(mpdiff, version, parent_ids) needed_parents = set() for _version, parent_ids, _expected_sha1, _mpdiff in records: - needed_parents.update(p for p in parent_ids - if not mpvf.has_version(p)) + needed_parents.update(p for p in parent_ids if not mpvf.has_version(p)) present_parents = set(self.get_parent_map(needed_parents)) - for parent_id, lines in zip(present_parents, - self._get_lf_split_line_list(present_parents)): + for parent_id, lines in zip( + present_parents, self._get_lf_split_line_list(present_parents) + ): mpvf.add_version(lines, parent_id, []) for (version, parent_ids, _expected_sha1, mpdiff), lines in zip( - records, mpvf.get_line_list(versions)): + records, mpvf.get_line_list(versions) + ): if len(parent_ids) == 1: - left_matching_blocks = list(mpdiff.get_matching_blocks(0, - mpvf.get_diff(parent_ids[0]).num_lines())) + left_matching_blocks = list( + mpdiff.get_matching_blocks( + 0, mpvf.get_diff(parent_ids[0]).num_lines() + ) + ) else: left_matching_blocks = None try: - _, _, version_text = self.add_lines_with_ghosts(version, - parent_ids, lines, vf_parents, - left_matching_blocks=left_matching_blocks) + _, _, version_text = self.add_lines_with_ghosts( + version, + parent_ids, + lines, + vf_parents, + left_matching_blocks=left_matching_blocks, + ) except NotImplementedError: # The vf can't handle ghosts, so add lines normally, which will # (reasonably) fail if there are ghosts in the data. - _, _, version_text = self.add_lines(version, - parent_ids, lines, vf_parents, - left_matching_blocks=left_matching_blocks) + _, _, version_text = self.add_lines( + version, + parent_ids, + lines, + vf_parents, + left_matching_blocks=left_matching_blocks, + ) vf_parents[version] = version_text sha1s = self.get_sha1s(versions) for version, _parent_ids, expected_sha1, _mpdiff in records: @@ -649,7 +729,8 @@ def get_text(self, version_id): Raises RevisionNotPresent if version is not present in file history. """ - return b''.join(self.get_lines(version_id)) + return b"".join(self.get_lines(version_id)) + get_string = get_text def get_texts(self, version_ids): @@ -658,7 +739,7 @@ def get_texts(self, version_ids): Raises RevisionNotPresent if version is not present in file history. """ - return [b''.join(self.get_lines(v)) for v in version_ids] + return [b"".join(self.get_lines(v)) for v in version_ids] def get_lines(self, version_id): """Return version contents as a sequence of lines. @@ -722,8 +803,7 @@ def annotate(self, version_id): """ raise NotImplementedError(self.annotate) - def iter_lines_added_or_present_in_versions(self, version_ids=None, - pb=None): + def iter_lines_added_or_present_in_versions(self, version_ids=None, pb=None): r"""Iterate over the lines in the versioned file from version_ids. This may return lines from other versions. 
Each item the returned @@ -766,8 +846,9 @@ def plan_merge(self, ver_a, ver_b, base=None): """ raise NotImplementedError(VersionedFile.plan_merge) - def weave_merge(self, plan, a_marker=TextMerge.A_MARKER, - b_marker=TextMerge.B_MARKER): + def weave_merge( + self, plan, a_marker=TextMerge.A_MARKER, b_marker=TextMerge.B_MARKER + ): return PlanWeaveMerge(plan, a_marker, b_marker).merge_lines()[0] @@ -788,22 +869,69 @@ def __init__(self, backing_vf): self._backing_vf = backing_vf self.calls = [] - def add_lines(self, key, parents, lines, parent_texts=None, - left_matching_blocks=None, nostore_sha=None, random_id=False, - check_content=True): - self.calls.append(("add_lines", key, parents, lines, parent_texts, - left_matching_blocks, nostore_sha, random_id, check_content)) - return self._backing_vf.add_lines(key, parents, lines, parent_texts, - left_matching_blocks, nostore_sha, random_id, check_content) - - def add_content(self, factory, parent_texts=None, - left_matching_blocks=None, nostore_sha=None, random_id=False, - check_content=True): - self.calls.append(("add_content", factory, parent_texts, - left_matching_blocks, nostore_sha, random_id, check_content)) + def add_lines( + self, + key, + parents, + lines, + parent_texts=None, + left_matching_blocks=None, + nostore_sha=None, + random_id=False, + check_content=True, + ): + self.calls.append( + ( + "add_lines", + key, + parents, + lines, + parent_texts, + left_matching_blocks, + nostore_sha, + random_id, + check_content, + ) + ) + return self._backing_vf.add_lines( + key, + parents, + lines, + parent_texts, + left_matching_blocks, + nostore_sha, + random_id, + check_content, + ) + + def add_content( + self, + factory, + parent_texts=None, + left_matching_blocks=None, + nostore_sha=None, + random_id=False, + check_content=True, + ): + self.calls.append( + ( + "add_content", + factory, + parent_texts, + left_matching_blocks, + nostore_sha, + random_id, + check_content, + ) + ) return self._backing_vf.add_content( - factory, parent_texts, left_matching_blocks, nostore_sha, - random_id, check_content) + factory, + parent_texts, + left_matching_blocks, + nostore_sha, + random_id, + check_content, + ) def check(self): self._backing_vf.check() @@ -813,10 +941,12 @@ def get_parent_map(self, keys): return self._backing_vf.get_parent_map(keys) def get_record_stream(self, keys, sort_order, include_delta_closure): - self.calls.append(("get_record_stream", list(keys), sort_order, - include_delta_closure)) - return self._backing_vf.get_record_stream(keys, sort_order, - include_delta_closure) + self.calls.append( + ("get_record_stream", list(keys), sort_order, include_delta_closure) + ) + return self._backing_vf.get_record_stream( + keys, sort_order, include_delta_closure + ) def get_sha1s(self, keys): self.calls.append(("get_sha1s", copy(keys))) @@ -852,19 +982,24 @@ def __init__(self, backing_vf, key_priority): self._key_priority = key_priority def get_record_stream(self, keys, sort_order, include_delta_closure): - self.calls.append(("get_record_stream", list(keys), sort_order, - include_delta_closure)) - if sort_order == 'unordered': + self.calls.append( + ("get_record_stream", list(keys), sort_order, include_delta_closure) + ) + if sort_order == "unordered": + def sort_key(key): return (self._key_priority.get(key, 0), key) + # Use a defined order by asking for the keys one-by-one from the # backing_vf for key in sorted(keys, key=sort_key): - yield from self._backing_vf.get_record_stream([key], - 'unordered', include_delta_closure) + yield from 
self._backing_vf.get_record_stream( + [key], "unordered", include_delta_closure + ) else: - yield from self._backing_vf.get_record_stream(keys, sort_order, - include_delta_closure) + yield from self._backing_vf.get_record_stream( + keys, sort_order, include_delta_closure + ) class KeyMapper: @@ -925,11 +1060,11 @@ class PrefixMapper(URLEscapeMapper): def _map(self, key): """See KeyMapper.map().""" - return key[0].decode('utf-8') + return key[0].decode("utf-8") def _unmap(self, partition_id): """See KeyMapper.unmap().""" - return (partition_id.encode('utf-8'),) + return (partition_id.encode("utf-8"),) class HashPrefixMapper(URLEscapeMapper): @@ -949,7 +1084,7 @@ def _escape(self, prefix): def _unmap(self, partition_id): """See KeyMapper.unmap().""" - return (self._unescape(osutils.basename(partition_id)).encode('utf-8'),) + return (self._unescape(osutils.basename(partition_id)).encode("utf-8"),) def _unescape(self, basename): """No unescaping needed for HashPrefixMapper.""" @@ -974,9 +1109,8 @@ def _escape(self, prefix): # @ does not get escaped. This is because it is a valid # filesystem character we use all the time, and it looks # a lot better than seeing %40 all the time. - r = [(c in self._safe) and chr(c) or (f'%{c:02x}') - for c in bytearray(prefix)] - return ''.join(r).encode('ascii') + r = [(c in self._safe) and chr(c) or (f"%{c:02x}") for c in bytearray(prefix)] + return "".join(r).encode("ascii") def _unescape(self, basename): """Escaped names are easily unescaped by urlutils.""" @@ -990,9 +1124,12 @@ def make_versioned_files_factory(versioned_file_factory, mapper): ThunkedVersionedFiles on a transport, using mapper to access individual versioned files, and versioned_file_factory to create each individual file. """ + def factory(transport): - return ThunkedVersionedFiles(transport, versioned_file_factory, mapper, - lambda: True) + return ThunkedVersionedFiles( + transport, versioned_file_factory, mapper, lambda: True + ) + return factory @@ -1019,9 +1156,17 @@ class VersionedFiles: one. They may in turn each have further fallbacks. """ - def add_lines(self, key, parents, lines, parent_texts=None, - left_matching_blocks=None, nostore_sha=None, random_id=False, - check_content=True): + def add_lines( + self, + key, + parents, + lines, + parent_texts=None, + left_matching_blocks=None, + nostore_sha=None, + random_id=False, + check_content=True, + ): r"""Add a text to the store. :param key: The key tuple of the text to add. If the last element is @@ -1058,9 +1203,15 @@ def add_lines(self, key, parents, lines, parent_texts=None, """ raise NotImplementedError(self.add_lines) - def add_content(self, factory, parent_texts=None, - left_matching_blocks=None, nostore_sha=None, random_id=False, - check_content=True): + def add_content( + self, + factory, + parent_texts=None, + left_matching_blocks=None, + nostore_sha=None, + random_id=False, + check_content=True, + ): """Add a text to the store from a chunk iterable. :param key: The key tuple of the text to add. If the last element is @@ -1104,25 +1255,31 @@ def add_mpdiffs(self, records): mpvf.add_diff(mpdiff, version, parent_ids) needed_parents = set() for _version, parent_ids, _expected_sha1, _mpdiff in records: - needed_parents.update(p for p in parent_ids - if not mpvf.has_version(p)) + needed_parents.update(p for p in parent_ids if not mpvf.has_version(p)) # It seems likely that adding all the present parents as fulltexts can # easily exhaust memory. 
- for record in self.get_record_stream(needed_parents, 'unordered', - True): - if record.storage_kind == 'absent': + for record in self.get_record_stream(needed_parents, "unordered", True): + if record.storage_kind == "absent": continue - mpvf.add_version(record.get_bytes_as('lines'), record.key, []) + mpvf.add_version(record.get_bytes_as("lines"), record.key, []) for (key, parent_keys, expected_sha1, mpdiff), lines in zip( - records, mpvf.get_line_list(versions)): + records, mpvf.get_line_list(versions) + ): if len(parent_keys) == 1: - left_matching_blocks = list(mpdiff.get_matching_blocks(0, - mpvf.get_diff(parent_keys[0]).num_lines())) + left_matching_blocks = list( + mpdiff.get_matching_blocks( + 0, mpvf.get_diff(parent_keys[0]).num_lines() + ) + ) else: left_matching_blocks = None - version_sha1, _, version_text = self.add_lines(key, - parent_keys, lines, vf_parents, - left_matching_blocks=left_matching_blocks) + version_sha1, _, version_text = self.add_lines( + key, + parent_keys, + lines, + vf_parents, + left_matching_blocks=left_matching_blocks, + ) if version_sha1 != expected_sha1: raise errors.VersionedFileInvalidChecksum(version) vf_parents[key] = version_text @@ -1167,7 +1324,7 @@ def _check_lines_not_unicode(self, lines): def _check_lines_are_lines(self, lines): """Check that the lines really are full lines without inline EOL.""" for line in lines: - if b'\n' in line[:-1]: + if b"\n" in line[:-1]: raise errors.BzrBadParameterContainsNewline("lines") def get_known_graph_ancestry(self, keys): @@ -1178,8 +1335,7 @@ def get_known_graph_ancestry(self, keys): while pending: this_parent_map = self.get_parent_map(pending) parent_map.update(this_parent_map) - pending = set(itertools.chain.from_iterable( - this_parent_map.values())) + pending = set(itertools.chain.from_iterable(this_parent_map.values())) pending.difference_update(parent_map) kg = _mod_graph.KnownGraph(parent_map) return kg @@ -1274,6 +1430,7 @@ def make_mpdiffs(self, keys): def get_annotator(self): from ..annotate import Annotator + return Annotator(self) missing_keys = index._missing_keys_from_parent_map @@ -1312,21 +1469,38 @@ def __init__(self, transport, file_factory, mapper, is_locked): self._mapper = mapper self._is_locked = is_locked - def add_content(self, factory, parent_texts=None, - left_matching_blocks=None, nostore_sha=None, random_id=False): + def add_content( + self, + factory, + parent_texts=None, + left_matching_blocks=None, + nostore_sha=None, + random_id=False, + ): """See VersionedFiles.add_content().""" - lines = factory.get_bytes_as('lines') + lines = factory.get_bytes_as("lines") return self.add_lines( - factory.key, factory.parents, lines, + factory.key, + factory.parents, + lines, parent_texts=parent_texts, left_matching_blocks=left_matching_blocks, nostore_sha=nostore_sha, random_id=random_id, - check_content=True) - - def add_lines(self, key, parents, lines, parent_texts=None, - left_matching_blocks=None, nostore_sha=None, random_id=False, - check_content=True): + check_content=True, + ) + + def add_lines( + self, + key, + parents, + lines, + parent_texts=None, + left_matching_blocks=None, + nostore_sha=None, + random_id=False, + check_content=True, + ): """See VersionedFiles.add_lines().""" path = self._mapper.map(key) version_id = key[-1] @@ -1334,32 +1508,52 @@ def add_lines(self, key, parents, lines, parent_texts=None, vf = self._get_vf(path) try: try: - return vf.add_lines_with_ghosts(version_id, parents, lines, - parent_texts=parent_texts, - left_matching_blocks=left_matching_blocks, - 
nostore_sha=nostore_sha, random_id=random_id, - check_content=check_content) + return vf.add_lines_with_ghosts( + version_id, + parents, + lines, + parent_texts=parent_texts, + left_matching_blocks=left_matching_blocks, + nostore_sha=nostore_sha, + random_id=random_id, + check_content=check_content, + ) except NotImplementedError: - return vf.add_lines(version_id, parents, lines, - parent_texts=parent_texts, - left_matching_blocks=left_matching_blocks, - nostore_sha=nostore_sha, random_id=random_id, - check_content=check_content) + return vf.add_lines( + version_id, + parents, + lines, + parent_texts=parent_texts, + left_matching_blocks=left_matching_blocks, + nostore_sha=nostore_sha, + random_id=random_id, + check_content=check_content, + ) except _mod_transport.NoSuchFile: # parent directory may be missing, try again. self._transport.mkdir(osutils.dirname(path)) try: - return vf.add_lines_with_ghosts(version_id, parents, lines, - parent_texts=parent_texts, - left_matching_blocks=left_matching_blocks, - nostore_sha=nostore_sha, random_id=random_id, - check_content=check_content) + return vf.add_lines_with_ghosts( + version_id, + parents, + lines, + parent_texts=parent_texts, + left_matching_blocks=left_matching_blocks, + nostore_sha=nostore_sha, + random_id=random_id, + check_content=check_content, + ) except NotImplementedError: - return vf.add_lines(version_id, parents, lines, - parent_texts=parent_texts, - left_matching_blocks=left_matching_blocks, - nostore_sha=nostore_sha, random_id=random_id, - check_content=check_content) + return vf.add_lines( + version_id, + parents, + lines, + parent_texts=parent_texts, + left_matching_blocks=left_matching_blocks, + nostore_sha=nostore_sha, + random_id=random_id, + check_content=check_content, + ) def annotate(self, key): """Return a list of (version-key, line) tuples for the text of key. @@ -1383,7 +1577,7 @@ def check(self, progress_bar=None, keys=None): for _prefix, vf in self._iter_all_components(): vf.check() if keys is not None: - return self.get_record_stream(keys, 'unordered', True) + return self.get_record_stream(keys, "unordered", True) def get_parent_map(self, keys): """Get a map of the parents of keys. 
@@ -1400,14 +1594,16 @@ def get_parent_map(self, keys): parent_map = vf.get_parent_map(suffixes) for key, parents in parent_map.items(): result[prefix + (key,)] = tuple( - prefix + (parent,) for parent in parents) + prefix + (parent,) for parent in parents + ) return result def _get_vf(self, path): if not self._is_locked(): raise errors.ObjectNotLocked(self) - return self._file_factory(path, self._transport, create=True, - get_scope=lambda: None) + return self._file_factory( + path, self._transport, create=True, get_scope=lambda: None + ) def _partition_keys(self, keys): """Turn keys into a dict of prefix:suffix_list.""" @@ -1439,11 +1635,11 @@ def get_record_stream(self, keys, ordering, include_delta_closure): keys = sorted(keys) for prefix, suffixes, vf in self._iter_keys_vf(keys): suffixes = [(suffix,) for suffix in suffixes] - for record in vf.get_record_stream(suffixes, ordering, - include_delta_closure): + for record in vf.get_record_stream( + suffixes, ordering, include_delta_closure + ): if record.parents is not None: - record.parents = tuple( - prefix + parent for parent in record.parents) + record.parents = tuple(prefix + parent for parent in record.parents) record.key = prefix + record.key yield record @@ -1524,7 +1720,6 @@ def keys(self): class VersionedFilesWithFallbacks(VersionedFiles): - def without_fallbacks(self): """Return a clone of this object without any fallbacks configured.""" raise NotImplementedError(self.without_fallbacks) @@ -1542,8 +1737,7 @@ def get_known_graph_ancestry(self, keys): for fallback in self._transitive_fallbacks(): if not missing_keys: break - (f_parent_map, f_missing_keys) = fallback._index.find_ancestry( - missing_keys) + (f_parent_map, f_missing_keys) = fallback._index.find_ancestry(missing_keys) parent_map.update(f_parent_map) missing_keys = f_missing_keys kg = _mod_graph.KnownGraph(parent_map) @@ -1581,28 +1775,31 @@ def __init__(self, file_id): def plan_merge(self, ver_a, ver_b, base=None): """See VersionedFile.plan_merge.""" from ..merge import _PlanMerge + if base is None: return _PlanMerge(ver_a, ver_b, self, (self._file_id,)).plan_merge() - old_plan = list(_PlanMerge(ver_a, base, self, - (self._file_id,)).plan_merge()) - new_plan = list(_PlanMerge(ver_a, ver_b, self, - (self._file_id,)).plan_merge()) + old_plan = list(_PlanMerge(ver_a, base, self, (self._file_id,)).plan_merge()) + new_plan = list(_PlanMerge(ver_a, ver_b, self, (self._file_id,)).plan_merge()) return _PlanMerge._subtract_plans(old_plan, new_plan) def plan_lca_merge(self, ver_a, ver_b, base=None): from ..merge import _PlanLCAMerge + graph = _mod_graph.Graph(self) new_plan = _PlanLCAMerge( - ver_a, ver_b, self, (self._file_id,), graph).plan_merge() + ver_a, ver_b, self, (self._file_id,), graph + ).plan_merge() if base is None: return new_plan old_plan = _PlanLCAMerge( - ver_a, base, self, (self._file_id,), graph).plan_merge() + ver_a, base, self, (self._file_id,), graph + ).plan_merge() return _PlanLCAMerge._subtract_plans(list(old_plan), list(new_plan)) def add_content(self, factory): return self.add_lines( - factory.key, factory.parents, factory.get_bytes_as('lines')) + factory.key, factory.parents, factory.get_bytes_as("lines") + ) def add_lines(self, key, parents, lines): """See VersionedFiles.add_lines. 
@@ -1613,11 +1810,11 @@ def add_lines(self, key, parents, lines): if not isinstance(key, tuple): raise TypeError(key) if not revision.is_reserved_id(key[-1]): - raise ValueError('Only reserved ids may be used') + raise ValueError("Only reserved ids may be used") if parents is None: - raise ValueError('Parents may not be None') + raise ValueError("Parents may not be None") if lines is None: - raise ValueError('Lines may not be None') + raise ValueError("Lines may not be None") self._parents[key] = tuple(parents) self._lines[key] = lines @@ -1629,12 +1826,11 @@ def get_record_stream(self, keys, ordering, include_delta_closure): parents = self._parents[key] pending.remove(key) yield ChunkedContentFactory( - key, parents, None, lines, - chunks_are_lines=True) + key, parents, None, lines, chunks_are_lines=True + ) for versionedfile in self.fallback_versionedfiles: - for record in versionedfile.get_record_stream( - pending, 'unordered', True): - if record.storage_kind == 'absent': + for record in versionedfile.get_record_stream(pending, "unordered", True): + if record.storage_kind == "absent": continue else: pending.remove(record.key) @@ -1657,8 +1853,8 @@ def get_parent_map(self, keys): result[revision.NULL_REVISION] = () self._providers = self._providers[:1] + self.fallback_versionedfiles result.update( - _mod_graph.StackedParentsProvider( - self._providers).get_parent_map(keys)) + _mod_graph.StackedParentsProvider(self._providers).get_parent_map(keys) + ) for key, parents in result.items(): if parents == (): result[key] = (revision.NULL_REVISION,) @@ -1672,8 +1868,7 @@ class PlanWeaveMerge(TextMerge): Most callers will want to use WeaveMerge instead. """ - def __init__(self, plan, a_marker=TextMerge.A_MARKER, - b_marker=TextMerge.B_MARKER): + def __init__(self, plan, a_marker=TextMerge.A_MARKER, b_marker=TextMerge.B_MARKER): TextMerge.__init__(self, a_marker, b_marker) self.plan = list(plan) @@ -1699,41 +1894,40 @@ def outstanding_struct(): # to be possible places to resynchronize. However, assuming agreement # on killed-both lines may be too aggressive. 
-- mbp 20060324 for state, line in self.plan: - if state == 'unchanged': + if state == "unchanged": # resync and flush queued conflicts changes if any yield from outstanding_struct() lines_a = [] lines_b = [] ch_a = ch_b = False - if state == 'unchanged': + if state == "unchanged": if line: yield ([line],) - elif state == 'killed-a': + elif state == "killed-a": ch_a = True lines_b.append(line) - elif state == 'killed-b': + elif state == "killed-b": ch_b = True lines_a.append(line) - elif state == 'new-a': + elif state == "new-a": ch_a = True lines_a.append(line) - elif state == 'new-b': + elif state == "new-b": ch_b = True lines_b.append(line) - elif state == 'conflicted-a': + elif state == "conflicted-a": ch_b = ch_a = True lines_a.append(line) - elif state == 'conflicted-b': + elif state == "conflicted-b": ch_b = ch_a = True lines_b.append(line) - elif state == 'killed-both': + elif state == "killed-both": # This counts as a change, even though there is no associated # line ch_b = ch_a = True else: - if state not in ('irrelevant', 'ghost-a', 'ghost-b', - 'killed-base'): + if state not in ("irrelevant", "ghost-a", "ghost-b", "killed-base"): raise AssertionError(state) yield from outstanding_struct() @@ -1741,15 +1935,21 @@ def base_from_plan(self): """Construct a BASE file from the plan text.""" base_lines = [] for state, line in self.plan: - if state in ('killed-a', 'killed-b', 'killed-both', 'unchanged'): + if state in ("killed-a", "killed-b", "killed-both", "unchanged"): # If unchanged, then this line is straight from base. If a or b # or both killed the line, then it *used* to be in base. base_lines.append(line) else: - if state not in ('killed-base', 'irrelevant', - 'ghost-a', 'ghost-b', - 'new-a', 'new-b', - 'conflicted-a', 'conflicted-b'): + if state not in ( + "killed-base", + "irrelevant", + "ghost-a", + "ghost-b", + "new-a", + "new-b", + "conflicted-a", + "conflicted-b", + ): # killed-base, irrelevant means it doesn't apply # ghost-a/ghost-b are harder to say for sure, but they # aren't in the 'inc_c' which means they aren't in the @@ -1788,15 +1988,21 @@ def base_from_plan(self): # It seems that having the line 2 times is better than # having it omitted. (Easier to manually delete than notice # it needs to be added.) 
- raise AssertionError(f'Unknown state: {state}') + raise AssertionError(f"Unknown state: {state}") return base_lines class WeaveMerge(PlanWeaveMerge): """Weave merge that takes a VersionedFile and two versions as its input.""" - def __init__(self, versionedfile, ver_a, ver_b, - a_marker=PlanWeaveMerge.A_MARKER, b_marker=PlanWeaveMerge.B_MARKER): + def __init__( + self, + versionedfile, + ver_a, + ver_b, + a_marker=PlanWeaveMerge.A_MARKER, + b_marker=PlanWeaveMerge.B_MARKER, + ): plan = versionedfile.plan_merge(ver_a, ver_b) PlanWeaveMerge.__init__(self, plan, a_marker, b_marker) @@ -1858,8 +2064,12 @@ def get_record_stream(self, keys, ordering, include_delta_closure): if not isinstance(lines, list): raise AssertionError yield ChunkedContentFactory( - (k,), None, sha1=osutils.sha_strings(lines), - chunks=lines, chunks_are_lines=True) + (k,), + None, + sha1=osutils.sha_strings(lines), + chunks=lines, + chunks_are_lines=True, + ) else: yield AbsentContentFactory((k,)) @@ -1880,9 +2090,17 @@ class NoDupeAddLinesDecorator: def __init__(self, store): self._store = store - def add_lines(self, key, parents, lines, parent_texts=None, - left_matching_blocks=None, nostore_sha=None, random_id=False, - check_content=True): + def add_lines( + self, + key, + parents, + lines, + parent_texts=None, + left_matching_blocks=None, + nostore_sha=None, + random_id=False, + check_content=True, + ): """See VersionedFiles.add_lines. This implementation may return None as the third element of the return @@ -1891,7 +2109,8 @@ def add_lines(self, key, parents, lines, parent_texts=None, if nostore_sha: raise NotImplementedError( "NoDupeAddLinesDecorator.add_lines does not implement the " - "nostore_sha behaviour.") + "nostore_sha behaviour." + ) if key[-1] is None: sha1 = osutils.sha_strings(lines) key = (b"sha1:" + sha1,) @@ -1902,11 +2121,16 @@ def add_lines(self, key, parents, lines, parent_texts=None, if sha1 is None: sha1 = osutils.sha_strings(lines) return sha1, sum(map(len, lines)), None - return self._store.add_lines(key, parents, lines, - parent_texts=parent_texts, - left_matching_blocks=left_matching_blocks, - nostore_sha=nostore_sha, random_id=random_id, - check_content=check_content) + return self._store.add_lines( + key, + parents, + lines, + parent_texts=parent_texts, + left_matching_blocks=left_matching_blocks, + nostore_sha=nostore_sha, + random_id=random_id, + check_content=check_content, + ) def __getattr__(self, name): return getattr(self._store, name) @@ -1918,8 +2142,8 @@ def network_bytes_to_kind_and_offset(network_bytes): :param network_bytes: The bytes of a record. :return: A tuple (storage_kind, offset_of_remaining_bytes) """ - line_end = network_bytes.find(b'\n') - storage_kind = network_bytes[:line_end].decode('ascii') + line_end = network_bytes.find(b"\n") + storage_kind = network_bytes[:line_end].decode("ascii") return storage_kind, line_end + 1 @@ -1934,16 +2158,17 @@ def __init__(self, bytes_iterator): record.get_bytes_as(record.storage_kind) call. """ from . 
import groupcompress, knit + self._bytes_iterator = bytes_iterator self._kind_factory = { - 'fulltext': fulltext_network_to_record, - 'groupcompress-block': groupcompress.network_block_to_records, - 'knit-ft-gz': knit.knit_network_to_record, - 'knit-delta-gz': knit.knit_network_to_record, - 'knit-annotated-ft-gz': knit.knit_network_to_record, - 'knit-annotated-delta-gz': knit.knit_network_to_record, - 'knit-delta-closure': knit.knit_delta_closure_to_records, - } + "fulltext": fulltext_network_to_record, + "groupcompress-block": groupcompress.network_block_to_records, + "knit-ft-gz": knit.knit_network_to_record, + "knit-delta-gz": knit.knit_network_to_record, + "knit-annotated-ft-gz": knit.knit_network_to_record, + "knit-annotated-delta-gz": knit.knit_network_to_record, + "knit-delta-closure": knit.knit_delta_closure_to_records, + } def read(self): """Read the stream. @@ -1952,34 +2177,36 @@ def read(self): """ for bytes in self._bytes_iterator: storage_kind, line_end = network_bytes_to_kind_and_offset(bytes) - yield from self._kind_factory[storage_kind]( - storage_kind, bytes, line_end) + yield from self._kind_factory[storage_kind](storage_kind, bytes, line_end) def fulltext_network_to_record(kind, bytes, line_end): """Convert a network fulltext record to record.""" - meta_len, = struct.unpack('!L', bytes[line_end:line_end + 4]) - record_meta = bytes[line_end + 4:line_end + 4 + meta_len] + (meta_len,) = struct.unpack("!L", bytes[line_end : line_end + 4]) + record_meta = bytes[line_end + 4 : line_end + 4 + meta_len] key, parents = bencode.bdecode_as_tuple(record_meta) - if parents == b'nil': + if parents == b"nil": parents = None - fulltext = bytes[line_end + 4 + meta_len:] + fulltext = bytes[line_end + 4 + meta_len :] return [FulltextContentFactory(key, parents, None, fulltext)] def _length_prefix(bytes): - return struct.pack('!L', len(bytes)) + return struct.pack("!L", len(bytes)) def record_to_fulltext_bytes(record): if record.parents is None: - parents = b'nil' + parents = b"nil" else: parents = tuple([tuple(p) for p in record.parents]) record_meta = bencode.bencode((record.key, parents)) - record_content = record.get_bytes_as('fulltext') + record_content = record.get_bytes_as("fulltext") return b"fulltext\n%s%s%s" % ( - _length_prefix(record_meta), record_meta, record_content) + _length_prefix(record_meta), + record_meta, + record_content, + ) def sort_groupcompress(parent_map): @@ -1998,7 +2225,7 @@ def sort_groupcompress(parent_map): for item in parent_map.items(): key = item[0] if isinstance(key, bytes) or len(key) == 1: - prefix = b'' + prefix = b"" else: prefix = key[0] try: @@ -2013,7 +2240,6 @@ def sort_groupcompress(parent_map): class _KeyRefs: - def __init__(self, track_new_keys=False): # dict mapping 'key' to 'set of keys referring to that key' self.refs = {} diff --git a/breezy/bzr/vf_repository.py b/breezy/bzr/vf_repository.py index 1d536edc13..930db8da0b 100644 --- a/breezy/bzr/vf_repository.py +++ b/breezy/bzr/vf_repository.py @@ -20,7 +20,9 @@ from ..lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ import itertools from breezy import ( @@ -47,7 +49,8 @@ from breezy.i18n import gettext from breezy.bzr.testament import Testament -""") +""", +) from .. import debug, errors, osutils from ..decorators import only_raises @@ -78,7 +81,7 @@ class VersionedFileRepositoryFormat(RepositoryFormat): # What order should fetch operations request streams in? # The default is unordered as that is the cheapest for an origin to # provide. 
- _fetch_order = 'unordered' + _fetch_order = "unordered" # Does this repository format use deltas that can be fetched as-deltas ? # (E.g. knits, where the knit deltas can be transplanted intact. # We default to False, which will ensure that enough data to get @@ -89,12 +92,30 @@ class VersionedFileRepositoryFormat(RepositoryFormat): class VersionedFileCommitBuilder(CommitBuilder): """Commit builder implementation for versioned files based repositories.""" - def __init__(self, repository, parents, config_stack, timestamp=None, - timezone=None, committer=None, revprops=None, - revision_id=None, lossy=False, owns_transaction=True): - super().__init__(repository, - parents, config_stack, timestamp, timezone, committer, revprops, - revision_id, lossy) + def __init__( + self, + repository, + parents, + config_stack, + timestamp=None, + timezone=None, + committer=None, + revprops=None, + revision_id=None, + lossy=False, + owns_transaction=True, + ): + super().__init__( + repository, + parents, + config_stack, + timestamp, + timezone, + committer, + revprops, + revision_id, + lossy, + ) try: basis_id = self.parents[0] except IndexError: @@ -128,34 +149,34 @@ def _ensure_fallback_inventories(self): if not self.repository._fallback_repositories: return if not self.repository._format.supports_chks: - raise errors.BzrError("Cannot commit directly to a stacked branch" - " in pre-2a formats. See " - "https://bugs.launchpad.net/bzr/+bug/375013 for details.") + raise errors.BzrError( + "Cannot commit directly to a stacked branch" + " in pre-2a formats. See " + "https://bugs.launchpad.net/bzr/+bug/375013 for details." + ) # This is a stacked repo, we need to make sure we have the parent # inventories for the parents. parent_keys = [(p,) for p in self.parents] - parent_map = self.repository.inventories._index.get_parent_map( - parent_keys) - missing_parent_keys = {pk for pk in parent_keys - if pk not in parent_map} + parent_map = self.repository.inventories._index.get_parent_map(parent_keys) + missing_parent_keys = {pk for pk in parent_keys if pk not in parent_map} fallback_repos = list(reversed(self.repository._fallback_repositories)) - missing_keys = [('inventories', pk[0]) - for pk in missing_parent_keys] + missing_keys = [("inventories", pk[0]) for pk in missing_parent_keys] while missing_keys and fallback_repos: fallback_repo = fallback_repos.pop() source = fallback_repo._get_source(self.repository._format) sink = self.repository._get_sink() missing_keys = sink.insert_missing_keys(source, missing_keys) if missing_keys: - raise errors.BzrError('Unable to fill in parent inventories for a' - ' stacked branch') + raise errors.BzrError( + "Unable to fill in parent inventories for a" " stacked branch" + ) def commit(self, message): """Make the actual commit. :return: The revision id of the recorded revision. 
""" - self._validate_unicode_text(message, 'commit message') + self._validate_unicode_text(message, "commit message") rev = _mod_revision.Revision( timestamp=self._timestamp, timezone=self._timezone, @@ -164,24 +185,29 @@ def commit(self, message): inventory_sha1=self.inv_sha1, revision_id=self._new_revision_id, parent_ids=self.parents, - properties=self._revprops) - create_signatures = self._config_stack.get('create_signatures') + properties=self._revprops, + ) + create_signatures = self._config_stack.get("create_signatures") if create_signatures in ( - _mod_config.SIGN_ALWAYS, _mod_config.SIGN_WHEN_POSSIBLE): + _mod_config.SIGN_ALWAYS, + _mod_config.SIGN_WHEN_POSSIBLE, + ): testament = Testament(rev, self.revision_tree()) plaintext = testament.as_short_text() try: self.repository.store_revision_signature( - gpg.GPGStrategy(self._config_stack), plaintext, - self._new_revision_id) + gpg.GPGStrategy(self._config_stack), + plaintext, + self._new_revision_id, + ) except gpg.GpgNotInstalled as e: if create_signatures == _mod_config.SIGN_WHEN_POSSIBLE: - note('skipping commit signature: %s', e) + note("skipping commit signature: %s", e) else: raise except gpg.SigningFailed as e: if create_signatures == _mod_config.SIGN_WHEN_POSSIBLE: - note('commit signature failed: %s', e) + note("commit signature failed: %s", e) else: raise self.repository._add_revision(rev) @@ -205,10 +231,10 @@ def revision_tree(self): memory. """ if self._new_inventory is None: - self._new_inventory = self.repository.get_inventory( - self._new_revision_id) - return inventorytree.InventoryRevisionTree(self.repository, - self._new_inventory, self._new_revision_id) + self._new_inventory = self.repository.get_inventory(self._new_revision_id) + return inventorytree.InventoryRevisionTree( + self.repository, self._new_inventory, self._new_revision_id + ) def finish_inventory(self): """Tell the builder that the inventory is finished. @@ -220,8 +246,11 @@ def finish_inventory(self): # inventory. basis_id = self.basis_delta_revision self.inv_sha1, self._new_inventory = self.repository.add_inventory_by_delta( - basis_id, InventoryDelta(self._basis_delta), self._new_revision_id, - self.parents) + basis_id, + InventoryDelta(self._basis_delta), + self._new_revision_id, + self.parents, + ) return self._new_revision_id def _gen_revision_id(self): @@ -240,10 +269,9 @@ def _require_root_change(self, tree): return if len(self.parents) == 0: raise errors.RootMissing() - entry = entry_factory['directory'](tree.path2id(''), '', - None) + entry = entry_factory["directory"](tree.path2id(""), "", None) entry.revision = self._new_revision_id - self._basis_delta.append(('', '', entry.file_id, entry)) + self._basis_delta.append(("", "", entry.file_id, entry)) def _get_delta(self, ie, basis_inv, path): """Get a delta against the basis inventory for ie.""" @@ -278,8 +306,9 @@ def get_basis_delta(self): """ return InventoryDelta(self._basis_delta) - def record_iter_changes(self, tree, basis_revision_id, iter_changes, - _entry_factory=entry_factory): + def record_iter_changes( + self, tree, basis_revision_id, iter_changes, _entry_factory=entry_factory + ): """Record a new tree via iter_changes. :param tree: The tree to obtain text contents from for changed objects. 
@@ -319,19 +348,18 @@ def record_iter_changes(self, tree, basis_revision_id, iter_changes, if not revtrees: basis_revision_id = _mod_revision.NULL_REVISION ghost_basis = True - revtrees.append(self.repository.revision_tree( - _mod_revision.NULL_REVISION)) + revtrees.append( + self.repository.revision_tree(_mod_revision.NULL_REVISION) + ) # The basis inventory from a repository if revtrees: basis_tree = revtrees[0] else: - basis_tree = self.repository.revision_tree( - _mod_revision.NULL_REVISION) + basis_tree = self.repository.revision_tree(_mod_revision.NULL_REVISION) basis_inv = basis_tree.root_inventory if len(self.parents) > 0: if basis_revision_id != self.parents[0] and not ghost_basis: - raise Exception( - "arbitrary basis parents not yet supported with merges") + raise Exception("arbitrary basis parents not yet supported with merges") for revtree in revtrees[1:]: for change in revtree.root_inventory._make_delta(basis_inv): if change[1] is None: @@ -344,21 +372,20 @@ def record_iter_changes(self, tree, basis_revision_id, iter_changes, # basis revid basis_entry.revision, # new tree revid - change[3].revision] + change[3].revision, + ] parent_entries[change[2]] = { # basis parent basis_entry.revision: basis_entry, # this parent change[3].revision: change[3], - } + } else: merged_ids[change[2]] = [change[3].revision] - parent_entries[change[2]] = { - change[3].revision: change[3]} + parent_entries[change[2]] = {change[3].revision: change[3]} else: merged_ids[change[2]].append(change[3].revision) - parent_entries[change[2] - ][change[3].revision] = change[3] + parent_entries[change[2]][change[3].revision] = change[3] else: merged_ids = {} # Setup the changes from the tree: @@ -370,8 +397,10 @@ def record_iter_changes(self, tree, basis_revision_id, iter_changes, head_candidate = [basis_inv.get_entry(change.file_id).revision] else: head_candidate = [] - changes[change.file_id] = change, merged_ids.get( - change.file_id, head_candidate) + changes[change.file_id] = ( + change, + merged_ids.get(change.file_id, head_candidate), + ) unchanged_merged = set(merged_ids) - set(changes) # Extend the changes dict with synthetic changes to record merges of # texts. @@ -398,11 +427,13 @@ def record_iter_changes(self, tree, basis_revision_id, iter_changes, change = InventoryTreeChange( file_id, (basis_inv.id2path(file_id), tree.id2path(file_id)), - False, (True, True), + False, + (True, True), (basis_entry.parent_id, basis_entry.parent_id), (basis_entry.name, basis_entry.name), (basis_entry.kind, basis_entry.kind), - (basis_entry.executable, basis_entry.executable)) + (basis_entry.executable, basis_entry.executable), + ) changes[file_id] = (change, merged_ids[file_id]) # changes contains tuples with the change and a set of inventory # candidates for the file. @@ -422,8 +453,9 @@ def record_iter_changes(self, tree, basis_revision_id, iter_changes, # - record the change with the content from tree kind = change.kind[1] file_id = change.file_id - entry = _entry_factory[kind](file_id, change.name[1], - change.parent_id[1]) + entry = _entry_factory[kind]( + file_id, change.name[1], change.parent_id[1] + ) head_set = self._heads(change.file_id, set(head_candidates)) heads = [] # Preserve ordering. @@ -450,9 +482,11 @@ def record_iter_changes(self, tree, basis_revision_id, iter_changes, # we need to check the content against the source of the # merge to determine if it was changed after the merge # or carried over. 
- if (parent_entry.kind != entry.kind + if ( + parent_entry.kind != entry.kind or parent_entry.parent_id != entry.parent_id - or parent_entry.name != entry.name): + or parent_entry.name != entry.name + ): # Metadata common to all entries has changed # against per-file parent carry_over_possible = False @@ -464,7 +498,7 @@ def record_iter_changes(self, tree, basis_revision_id, iter_changes, # Cannot be a carry-over situation carry_over_possible = False # Populate the entry in the delta - if kind == 'file': + if kind == "file": # XXX: There is still a small race here: If someone reverts # the content of a file after iter_changes examines and # decides it has changed, we will unconditionally record a @@ -475,8 +509,10 @@ def record_iter_changes(self, tree, basis_revision_id, iter_changes, entry.executable = True else: entry.executable = False - if (carry_over_possible - and parent_entry.executable == entry.executable): + if ( + carry_over_possible + and parent_entry.executable == entry.executable + ): # Check the file length, content hash after reading # the file. nostore_sha = parent_entry.text_sha1 @@ -485,8 +521,12 @@ def record_iter_changes(self, tree, basis_revision_id, iter_changes, file_obj, stat_value = tree.get_file_with_stat(change.path[1]) try: entry.text_sha1, entry.text_size = self._add_file_to_weave( - file_id, file_obj, heads, nostore_sha, - size=(stat_value.st_size if stat_value else None)) + file_id, + file_obj, + heads, + nostore_sha, + size=(stat_value.st_size if stat_value else None), + ) yield change.path[1], (entry.text_sha1, stat_value) except versionedfile.ExistingContent: # No content change against a carry_over parent @@ -496,46 +536,50 @@ def record_iter_changes(self, tree, basis_revision_id, iter_changes, entry.text_sha1 = parent_entry.text_sha1 finally: file_obj.close() - elif kind == 'symlink': + elif kind == "symlink": # Wants a path hint? - entry.symlink_target = tree.get_symlink_target( - change.path[1]) - if (carry_over_possible and - parent_entry.symlink_target == - entry.symlink_target): + entry.symlink_target = tree.get_symlink_target(change.path[1]) + if ( + carry_over_possible + and parent_entry.symlink_target == entry.symlink_target + ): carried_over = True else: self._add_file_to_weave( - change.file_id, BytesIO(), heads, None, size=0) - elif kind == 'directory': + change.file_id, BytesIO(), heads, None, size=0 + ) + elif kind == "directory": if carry_over_possible: carried_over = True else: # Nothing to set on the entry. # XXX: split into the Root and nonRoot versions. - if change.path[1] != '' or self.repository.supports_rich_root(): + if change.path[1] != "" or self.repository.supports_rich_root(): self._add_file_to_weave( - change.file_id, BytesIO(), heads, None, size=0) - elif kind == 'tree-reference': + change.file_id, BytesIO(), heads, None, size=0 + ) + elif kind == "tree-reference": if not self.repository._format.supports_tree_reference: # This isn't quite sane as an error, but we shouldn't # ever see this code path in practice: tree's don't # permit references when the repo doesn't support tree # references. 
raise errors.UnsupportedOperation( - tree.add_reference, self.repository) - reference_revision = tree.get_reference_revision( - change.path[1]) + tree.add_reference, self.repository + ) + reference_revision = tree.get_reference_revision(change.path[1]) entry.reference_revision = reference_revision - if (carry_over_possible - and parent_entry.reference_revision == - reference_revision): + if ( + carry_over_possible + and parent_entry.reference_revision == reference_revision + ): carried_over = True else: self._add_file_to_weave( - change.file_id, BytesIO(), heads, None, size=0) + change.file_id, BytesIO(), heads, None, size=0 + ) else: - raise AssertionError(f'unknown kind {kind!r}') + raise AssertionError(f"unknown kind {kind!r}") if not carried_over: entry.revision = modified_rev else: @@ -544,12 +588,13 @@ def record_iter_changes(self, tree, basis_revision_id, iter_changes, entry = None new_path = change.path[1] inv_delta.append((change.path[0], new_path, change.file_id, entry)) - if new_path == '': + if new_path == "": seen_root = True # The initial commit adds a root directory, but this in itself is not # a worthwhile commit. - if ((len(inv_delta) > 0 and basis_revision_id != _mod_revision.NULL_REVISION) - or (len(inv_delta) > 1 and basis_revision_id == _mod_revision.NULL_REVISION)): + if ( + len(inv_delta) > 0 and basis_revision_id != _mod_revision.NULL_REVISION + ) or (len(inv_delta) > 1 and basis_revision_id == _mod_revision.NULL_REVISION): # This should perhaps be guarded by a check that the basis we # commit against is the basis for the commit and if not do a delta # against the basis. @@ -563,8 +608,11 @@ def _add_file_to_weave(self, file_id, fileobj, parents, nostore_sha, size): parent_keys = tuple([(file_id, parent) for parent in parents]) return self.repository.texts.add_content( versionedfile.FileContentFactory( - (file_id, self._new_revision_id), parent_keys, fileobj, size=size), - nostore_sha=nostore_sha, random_id=self.random_revid)[0:2] + (file_id, self._new_revision_id), parent_keys, fileobj, size=size + ), + nostore_sha=nostore_sha, + random_id=self.random_revid, + )[0:2] class VersionedFileRepository(Repository): @@ -671,7 +719,8 @@ def add_inventory(self, revision_id, inv, parents): if not (inv.revision_id is None or inv.revision_id == revision_id): raise AssertionError( "Mismatch between inventory revision" - f" id and insertion revid ({inv.revision_id!r}, {revision_id!r})") + f" id and insertion revid ({inv.revision_id!r}, {revision_id!r})" + ) if inv.root is None: raise errors.RootMissing() return self._add_inventory_checked(revision_id, inv, parents) @@ -684,11 +733,19 @@ def _add_inventory_checked(self, revision_id, inv, parents): :seealso: add_inventory, for the contract. """ inv_lines = self._inventory_serializer.write_inventory_to_lines(inv) - return self._inventory_add_lines(revision_id, parents, - inv_lines, check_content=False) - - def add_inventory_by_delta(self, basis_revision_id, delta, new_revision_id, - parents, basis_inv=None, propagate_caches=False): + return self._inventory_add_lines( + revision_id, parents, inv_lines, check_content=False + ) + + def add_inventory_by_delta( + self, + basis_revision_id, + delta, + new_revision_id, + parents, + basis_inv=None, + propagate_caches=False, + ): """Add a new inventory expressed as a delta against another revision. 
See the inventory developers documentation for the theory behind @@ -728,15 +785,14 @@ def add_inventory_by_delta(self, basis_revision_id, delta, new_revision_id, basis_inv = basis_tree.root_inventory basis_inv.apply_delta(delta) basis_inv.revision_id = new_revision_id - return (self.add_inventory(new_revision_id, basis_inv, parents), - basis_inv) + return (self.add_inventory(new_revision_id, basis_inv, parents), basis_inv) - def _inventory_add_lines(self, revision_id, parents, lines, - check_content=True): + def _inventory_add_lines(self, revision_id, parents, lines, check_content=True): """Store lines in inv_vf and return the sha1 of the inventory.""" parents = [(parent,) for parent in parents] - result = self.inventories.add_lines((revision_id,), parents, lines, - check_content=check_content)[0] + result = self.inventories.add_lines( + (revision_id,), parents, lines, check_content=check_content + )[0] self.inventories._access.flush() return result @@ -754,12 +810,12 @@ def add_revision(self, revision_id, rev, inv=None): # check inventory present if not self.inventories.get_parent_map([(revision_id,)]): if inv is None: - raise errors.WeaveRevisionNotPresent(revision_id, - self.inventories) + raise errors.WeaveRevisionNotPresent(revision_id, self.inventories) else: # yes, this is not suitable for adding with ghosts. - rev.inventory_sha1 = self.add_inventory(revision_id, inv, - rev.parent_ids) + rev.inventory_sha1 = self.add_inventory( + revision_id, inv, rev.parent_ids + ) else: key = (revision_id,) rev.inventory_sha1 = self.inventories.get_sha1s([key])[key] @@ -783,67 +839,73 @@ def _check_inventories(self, checker): def _do_check_inventories(self, checker, bar): """Helper for _check_inventories.""" - keys = {'chk_bytes': set(), 'inventories': set(), 'texts': set()} - kinds = ['chk_bytes', 'texts'] + keys = {"chk_bytes": set(), "inventories": set(), "texts": set()} + kinds = ["chk_bytes", "texts"] len(checker.pending_keys) bar.update(gettext("inventories"), 0, 2) current_keys = checker.pending_keys checker.pending_keys = {} # Accumulate current checks. for key in current_keys: - if key[0] != 'inventories' and key[0] not in kinds: - checker._report_items.append(f'unknown key type {key!r}') + if key[0] != "inventories" and key[0] not in kinds: + checker._report_items.append(f"unknown key type {key!r}") keys[key[0]].add(key[1:]) - if keys['inventories']: + if keys["inventories"]: # NB: output order *should* be roughly sorted - topo or # inverse topo depending on repository - either way decent # to just delta against. However, pre-CHK formats didn't # try to optimise inventory layout on disk. As such the # pre-CHK code path does not use inventory deltas. 
last_object = None - for record in self.inventories.check(keys=keys['inventories']): - if record.storage_kind == 'absent': - checker._report_items.append( - f'Missing inventory {{{record.key}}}') + for record in self.inventories.check(keys=keys["inventories"]): + if record.storage_kind == "absent": + checker._report_items.append(f"Missing inventory {{{record.key}}}") else: - last_object = self._check_record('inventories', record, - checker, last_object, - current_keys[('inventories',) + record.key]) - del keys['inventories'] + last_object = self._check_record( + "inventories", + record, + checker, + last_object, + current_keys[("inventories",) + record.key], + ) + del keys["inventories"] else: return bar.update(gettext("texts"), 1) - while (checker.pending_keys or keys['chk_bytes'] or - keys['texts']): + while checker.pending_keys or keys["chk_bytes"] or keys["texts"]: # Something to check. current_keys = checker.pending_keys checker.pending_keys = {} # Accumulate current checks. for key in current_keys: if key[0] not in kinds: - checker._report_items.append( - f'unknown key type {key!r}') + checker._report_items.append(f"unknown key type {key!r}") keys[key[0]].add(key[1:]) # Check the outermost kind only - inventories || chk_bytes || texts for kind in kinds: if keys[kind]: last_object = None for record in getattr(self, kind).check(keys=keys[kind]): - if record.storage_kind == 'absent': + if record.storage_kind == "absent": checker._report_items.append( - f'Missing {kind} {{{record.key}}}') + f"Missing {kind} {{{record.key}}}" + ) else: - last_object = self._check_record(kind, record, - checker, last_object, current_keys[(kind,) + record.key]) + last_object = self._check_record( + kind, + record, + checker, + last_object, + current_keys[(kind,) + record.key], + ) keys[kind] = set() break def _check_record(self, kind, record, checker, last_object, item_data): """Check a single text from this repository.""" - if kind == 'inventories': + if kind == "inventories": rev_id = record.key[0] - inv = self._deserialise_inventory( - rev_id, record.get_bytes_as('lines')) + inv = self._deserialise_inventory(rev_id, record.get_bytes_as("lines")) if last_object is not None: delta = inv._make_delta(last_object) for _old_path, _path, _file_id, ie in delta: @@ -855,26 +917,27 @@ def _check_record(self, kind, record, checker, last_object, item_data): ie.check(checker, rev_id, inv) if self._format.fast_deltas: return inv - elif kind == 'chk_bytes': + elif kind == "chk_bytes": # No code written to check chk_bytes for this repo format. checker._report_items.append( - f'unsupported key type chk_bytes for {record.key}') - elif kind == 'texts': + f"unsupported key type chk_bytes for {record.key}" + ) + elif kind == "texts": self._check_text(record, checker, item_data) else: - checker._report_items.append( - f'unknown key type {kind} for {record.key}') + checker._report_items.append(f"unknown key type {kind} for {record.key}") def _check_text(self, record, checker, item_data): """Check a single text.""" # Check it is extractable. # TODO: check length. 
- chunks = record.get_bytes_as('chunked') + chunks = record.get_bytes_as("chunked") sha1 = osutils.sha_strings(chunks) sum(map(len, chunks)) if item_data and sha1 != item_data[1]: checker._report_items.append( - f'sha1 mismatch: {record.key} has sha1 {sha1} expected {item_data[1]} referenced by {item_data[2]}') + f"sha1 mismatch: {record.key} has sha1 {sha1} expected {item_data[1]} referenced by {item_data[2]}" + ) def _eliminate_revisions_not_present(self, revision_ids): """Check every revision id in revision_ids to see if we have it. @@ -897,8 +960,7 @@ def __init__(self, _format, a_controldir, control_files): # In the future we will have a single api for all stores for # getting file texts, inventories and revisions, then # this construct will accept instances of those things. - super().__init__(_format, a_controldir, - control_files) + super().__init__(_format, a_controldir, control_files) self._transport = control_files._transport self.base = self._transport.base # for tests @@ -911,8 +973,9 @@ def __init__(self, _format, a_controldir, control_files): # rather copying them? self._safe_to_return_from_cache = False - def fetch(self, source, revision_id=None, find_ghosts=False, - fetch_spec=None, lossy=False): + def fetch( + self, source, revision_id=None, find_ghosts=False, fetch_spec=None, lossy=False + ): """Fetch the content required to construct revision_id from source. If revision_id is None and fetch_spec is None, then all content is @@ -934,31 +997,31 @@ def fetch(self, source, revision_id=None, find_ghosts=False, revision_id. """ if fetch_spec is not None and revision_id is not None: - raise AssertionError( - "fetch_spec and revision_id are mutually exclusive.") + raise AssertionError("fetch_spec and revision_id are mutually exclusive.") if self.is_in_write_group(): - raise errors.InternalBzrError( - "May not fetch while in a write group.") + raise errors.InternalBzrError("May not fetch while in a write group.") # fast path same-url fetch operations # TODO: lift out to somewhere common with RemoteRepository # - if (self.has_same_location(source) and - fetch_spec is None and - self._has_same_fallbacks(source)): + if ( + self.has_same_location(source) + and fetch_spec is None + and self._has_same_fallbacks(source) + ): # check that last_revision is in 'from' and then return a # no-operation. - if (revision_id is not None - and not _mod_revision.is_null(revision_id)): + if revision_id is not None and not _mod_revision.is_null(revision_id): self.get_revision(revision_id) return FetchResult(0) inter = InterRepository.get(source, self) - if (fetch_spec is not None - and not getattr(inter, "supports_fetch_spec", False)): - raise errors.UnsupportedOperation( - f"fetch_spec not supported for {inter!r}") - return inter.fetch(revision_id=revision_id, - find_ghosts=find_ghosts, fetch_spec=fetch_spec, - lossy=lossy) + if fetch_spec is not None and not getattr(inter, "supports_fetch_spec", False): + raise errors.UnsupportedOperation(f"fetch_spec not supported for {inter!r}") + return inter.fetch( + revision_id=revision_id, + find_ghosts=find_ghosts, + fetch_spec=fetch_spec, + lossy=lossy, + ) def gather_stats(self, revid=None, committers=None): """See Repository.gather_stats().""" @@ -970,13 +1033,22 @@ def gather_stats(self, revid=None, committers=None): # XXX: do we want to __define len__() ? # Maybe the versionedfiles object should provide a different # method to get the number of keys. 
- result['revisions'] = len(self.revisions.keys()) + result["revisions"] = len(self.revisions.keys()) # result['size'] = t return result - def get_commit_builder(self, branch, parents, config_stack, timestamp=None, - timezone=None, committer=None, revprops=None, - revision_id=None, lossy=False): + def get_commit_builder( + self, + branch, + parents, + config_stack, + timestamp=None, + timezone=None, + committer=None, + revprops=None, + revision_id=None, + lossy=False, + ): """Obtain a CommitBuilder for this repository. :param branch: Branch to commit to. @@ -991,14 +1063,24 @@ def get_commit_builder(self, branch, parents, config_stack, timestamp=None, represented, when pushing to a foreign VCS """ if self._fallback_repositories and not self._format.supports_chks: - raise errors.BzrError("Cannot commit directly to a stacked branch" - " in pre-2a formats. See " - "https://bugs.launchpad.net/bzr/+bug/375013 for details.") + raise errors.BzrError( + "Cannot commit directly to a stacked branch" + " in pre-2a formats. See " + "https://bugs.launchpad.net/bzr/+bug/375013 for details." + ) in_transaction = self.is_in_write_group() result = self._commit_builder_class( - self, parents, config_stack, - timestamp, timezone, committer, revprops, revision_id, - lossy, owns_transaction=not in_transaction) + self, + parents, + config_stack, + timestamp, + timezone, + committer, + revprops, + revision_id, + lossy, + owns_transaction=not in_transaction, + ) if not in_transaction: self.start_write_group() return result @@ -1017,7 +1099,7 @@ def get_missing_parent_inventories(self, check_for_missing_texts=True): # This is only an issue for stacked repositories return set() if not self.is_in_write_group(): - raise AssertionError('not in a write group') + raise AssertionError("not in a write group") # XXX: We assume that every added revision already has its # corresponding inventory, so we only check for parent inventories that @@ -1026,13 +1108,14 @@ def get_missing_parent_inventories(self, check_for_missing_texts=True): parents.discard(_mod_revision.NULL_REVISION) unstacked_inventories = self.inventories._index present_inventories = unstacked_inventories.get_parent_map( - key[-1:] for key in parents) + key[-1:] for key in parents + ) parents.difference_update(present_inventories) if len(parents) == 0: # No missing parent inventories. return set() if not check_for_missing_texts: - return {('inventories', rev_id) for (rev_id,) in parents} + return {("inventories", rev_id) for (rev_id,) in parents} # Ok, now we have a list of missing inventories. But these only matter # if the inventories that reference them are missing some texts they # appear to introduce. @@ -1045,8 +1128,7 @@ def get_missing_parent_inventories(self, check_for_missing_texts=True): file_ids = self.fileids_altered_by_revision_ids(referrers) missing_texts = set() for file_id, version_ids in file_ids.items(): - missing_texts.update( - (file_id, version_id) for version_id in version_ids) + missing_texts.update((file_id, version_id) for version_id in version_ids) present_texts = self.texts.get_parent_map(missing_texts) missing_texts.difference_update(present_texts) if not missing_texts: @@ -1055,7 +1137,7 @@ def get_missing_parent_inventories(self, check_for_missing_texts=True): return set() # Alternatively the text versions could be returned as the missing # keys, but this is likely to be less data. 
- missing_keys = {('inventories', rev_id) for (rev_id,) in parents} + missing_keys = {("inventories", rev_id) for (rev_id,) in parents} return missing_keys def has_revisions(self, revision_ids): @@ -1066,7 +1148,8 @@ def has_revisions(self, revision_ids): """ with self.lock_read(): parent_map = self.revisions.get_parent_map( - [(rev_id,) for rev_id in revision_ids]) + [(rev_id,) for rev_id in revision_ids] + ) result = set() if _mod_revision.NULL_REVISION in revision_ids: result.add(_mod_revision.NULL_REVISION) @@ -1097,16 +1180,15 @@ def iter_revisions(self, revision_ids): with self.lock_read(): for rev_id in revision_ids: if not rev_id or not isinstance(rev_id, bytes): - raise errors.InvalidRevisionId( - revision_id=rev_id, branch=self) + raise errors.InvalidRevisionId(revision_id=rev_id, branch=self) keys = [(key,) for key in revision_ids] - stream = self.revisions.get_record_stream(keys, 'unordered', True) + stream = self.revisions.get_record_stream(keys, "unordered", True) for record in stream: revid = record.key[0] - if record.storage_kind == 'absent': + if record.storage_kind == "absent": yield (revid, None) else: - text = record.get_bytes_as('fulltext') + text = record.get_bytes_as("fulltext") rev = self._revision_serializer.read_revision_from_string(text) yield (revid, rev) @@ -1117,13 +1199,13 @@ def add_signature_text(self, revision_id, signature): :param signature: Signature text. """ with self.lock_write(): - self.signatures.add_lines((revision_id,), (), - osutils.split_lines(signature)) + self.signatures.add_lines( + (revision_id,), (), osutils.split_lines(signature) + ) def sign_revision(self, revision_id, gpg_strategy): with self.lock_write(): - testament = Testament.from_revision( - self, revision_id) + testament = Testament.from_revision(self, revision_id) plaintext = testament.as_short_text() self.store_revision_signature(gpg_strategy, plaintext, revision_id) @@ -1145,8 +1227,7 @@ def verify_revision_signature(self, revision_id, gpg_strategy): return gpg.SIGNATURE_NOT_SIGNED, None signature = self.get_signature_text(revision_id) - testament = Testament.from_revision( - self, revision_id) + testament = Testament.from_revision(self, revision_id) (status, key, signed_plaintext) = gpg_strategy.verify(signature) if testament.as_short_text() != signed_plaintext: @@ -1165,7 +1246,8 @@ def find_text_key_references(self): w = self.inventories with ui.ui_factory.nested_progress_bar() as pb: return self._inventory_serializer._find_text_key_references( - w.iter_lines_added_or_present_in_keys(revision_keys, pb=pb)) + w.iter_lines_added_or_present_in_keys(revision_keys, pb=pb) + ) def _inventory_xml_lines_for_keys(self, keys): """Get a line iterator of the sort needed for finding references. @@ -1178,15 +1260,14 @@ def _inventory_xml_lines_for_keys(self, keys): :return: An iterator over (inventory line, revid) for the fulltexts of all of the xml inventories specified by revision_keys. """ - stream = self.inventories.get_record_stream(keys, 'unordered', True) + stream = self.inventories.get_record_stream(keys, "unordered", True) for record in stream: - if record.storage_kind != 'absent': + if record.storage_kind != "absent": revid = record.key[-1] - for line in record.get_bytes_as('lines'): + for line in record.get_bytes_as("lines"): yield line, revid - def _find_file_ids_from_xml_inventory_lines(self, line_iterator, - revision_keys): + def _find_file_ids_from_xml_inventory_lines(self, line_iterator, revision_keys): """Helper routine for fileids_altered_by_revision_ids.
This performs the translation of xml lines to revision ids. @@ -1202,8 +1283,11 @@ def _find_file_ids_from_xml_inventory_lines(self, line_iterator, """ seen = set(self._inventory_serializer._find_text_key_references(line_iterator)) parent_keys = self._find_parent_keys_of_revisions(revision_keys) - parent_seen = set(self._inventory_serializer._find_text_key_references( - self._inventory_xml_lines_for_keys(parent_keys))) + parent_seen = set( + self._inventory_serializer._find_text_key_references( + self._inventory_xml_lines_for_keys(parent_keys) + ) + ) new_keys = seen - parent_seen result = {} setdefault = result.setdefault @@ -1219,8 +1303,7 @@ def _find_parent_keys_of_revisions(self, revision_keys): revision_keys """ parent_map = self.revisions.get_parent_map(revision_keys) - parent_keys = set(itertools.chain.from_iterable( - parent_map.values())) + parent_keys = set(itertools.chain.from_iterable(parent_map.values())) parent_keys.difference_update(revision_keys) parent_keys.discard(_mod_revision.NULL_REVISION) return parent_keys @@ -1238,9 +1321,8 @@ def fileids_altered_by_revision_ids(self, revision_ids, _inv_weave=None): selected_keys = {(revid,) for revid in revision_ids} w = _inv_weave or self.inventories return self._find_file_ids_from_xml_inventory_lines( - w.iter_lines_added_or_present_in_keys( - selected_keys, pb=None), - selected_keys) + w.iter_lines_added_or_present_in_keys(selected_keys, pb=None), selected_keys + ) def iter_files_bytes(self, desired_files): """Iterate through file versions. @@ -1264,13 +1346,12 @@ def iter_files_bytes(self, desired_files): text_keys = {} for file_id, revision_id, callable_data in desired_files: text_keys[(file_id, revision_id)] = callable_data - for record in self.texts.get_record_stream(text_keys, 'unordered', True): - if record.storage_kind == 'absent': + for record in self.texts.get_record_stream(text_keys, "unordered", True): + if record.storage_kind == "absent": raise errors.RevisionNotPresent(record.key[1], record.key[0]) - yield text_keys[record.key], record.iter_bytes_as('chunked') + yield text_keys[record.key], record.iter_bytes_as("chunked") - def _generate_text_key_index(self, text_key_references=None, - ancestors=None): + def _generate_text_key_index(self, text_key_references=None, ancestors=None): """Generate a new text key index for the repository. This is an expensive function that will take considerable time to run. 
@@ -1286,8 +1367,7 @@ def _generate_text_key_index(self, text_key_references=None, if text_key_references is None: text_key_references = self.find_text_key_references() with ui.ui_factory.nested_progress_bar() as pb: - return self._do_generate_text_key_index(ancestors, - text_key_references, pb) + return self._do_generate_text_key_index(ancestors, text_key_references, pb) def _do_generate_text_key_index(self, ancestors, text_key_references, pb): """Helper for _generate_text_key_index to avoid deep nesting.""" @@ -1319,25 +1399,23 @@ def _do_generate_text_key_index(self, ancestors, text_key_references, pb): batch_size = 10 # should be ~150MB on a 55K path tree batch_count = len(revision_order) // batch_size + 1 processed_texts = 0 - pb.update(gettext("Calculating text parents"), - processed_texts, text_count) + pb.update(gettext("Calculating text parents"), processed_texts, text_count) for offset in range(batch_count): - to_query = revision_order[offset * batch_size:(offset + 1) - * batch_size] + to_query = revision_order[offset * batch_size : (offset + 1) * batch_size] if not to_query: break for revision_id in to_query: parent_ids = ancestors[revision_id] for text_key in revision_keys[revision_id]: - pb.update(gettext("Calculating text parents"), - processed_texts) + pb.update(gettext("Calculating text parents"), processed_texts) processed_texts += 1 candidate_parents = [] for parent_id in parent_ids: parent_text_key = (text_key[0], parent_id) try: - check_parent = parent_text_key not in \ - revision_keys[parent_id] + check_parent = ( + parent_text_key not in revision_keys[parent_id] + ) except KeyError: # the parent parent_id is a ghost: check_parent = False @@ -1350,21 +1428,18 @@ def _do_generate_text_key_index(self, ancestors, text_key_references, pb): try: inv = inventory_cache[parent_id] except KeyError: - inv = self.revision_tree( - parent_id).root_inventory + inv = self.revision_tree(parent_id).root_inventory inventory_cache[parent_id] = inv try: parent_entry = inv.get_entry(text_key[0]) except (KeyError, errors.NoSuchId): parent_entry = None if parent_entry is not None: - parent_text_key = ( - text_key[0], parent_entry.revision) + parent_text_key = (text_key[0], parent_entry.revision) else: parent_text_key = None if parent_text_key is not None: - candidate_parents.append( - text_key_cache[parent_text_key]) + candidate_parents.append(text_key_cache[parent_text_key]) parent_heads = text_graph.heads(candidate_parents) new_parents = list(parent_heads) new_parents.sort(key=lambda x: candidate_parents.index(x)) @@ -1417,8 +1492,9 @@ def _find_non_file_keys_to_fetch(self, revision_ids): # XXX: Note ATM no callers actually pay attention to this return # instead they just use the list of revision ids and ignore # missing sigs. Consider removing this work entirely - revisions_with_signatures = set(self.signatures.get_parent_map( - [(r,) for r in revision_ids])) + revisions_with_signatures = set( + self.signatures.get_parent_map([(r,) for r in revision_ids]) + ) revisions_with_signatures = {r for (r,) in revisions_with_signatures} revisions_with_signatures.intersection_update(revision_ids) yield ("signatures", None, revisions_with_signatures) @@ -1444,9 +1520,8 @@ def iter_inventories(self, revision_ids, ordering=None): buffering if necessary). :return: An iterator of inventories. 
""" - if ((None in revision_ids) or - (_mod_revision.NULL_REVISION in revision_ids)): - raise ValueError('cannot get null revision inventory') + if (None in revision_ids) or (_mod_revision.NULL_REVISION in revision_ids): + raise ValueError("cannot get null revision inventory") for inv, revid in self._iter_inventories(revision_ids, ordering): if inv is None: raise errors.NoSuchRevision(self, revid) @@ -1464,7 +1539,7 @@ def _iter_inventories(self, revision_ids, ordering): def _iter_inventory_xmls(self, revision_ids, ordering): if ordering is None: order_as_requested = True - ordering = 'unordered' + ordering = "unordered" else: order_as_requested = False keys = [(revision_id,) for revision_id in revision_ids] @@ -1476,8 +1551,8 @@ def _iter_inventory_xmls(self, revision_ids, ordering): stream = self.inventories.get_record_stream(keys, ordering, True) text_lines = {} for record in stream: - if record.storage_kind != 'absent': - lines = record.get_bytes_as('lines') + if record.storage_kind != "absent": + lines = record.get_bytes_as("lines") if order_as_requested: text_lines[record.key] = lines else: @@ -1504,11 +1579,15 @@ def _deserialise_inventory(self, revision_id, xml): :param xml: A serialised inventory. """ result = self._inventory_serializer.read_inventory_from_lines( - xml, revision_id, entry_cache=self._inventory_entry_cache, - return_from_cache=self._safe_to_return_from_cache) + xml, + revision_id, + entry_cache=self._inventory_entry_cache, + return_from_cache=self._safe_to_return_from_cache, + ) if result.revision_id != revision_id: - raise AssertionError('revision id mismatch {} != {}'.format( - result.revision_id, revision_id)) + raise AssertionError( + f"revision id mismatch {result.revision_id} != {revision_id}" + ) return result def get_serializer_format(self): @@ -1517,7 +1596,7 @@ def get_serializer_format(self): def _get_inventory_xml(self, revision_id): """Get serialized inventory as a string.""" with self.lock_read(): - texts = self._iter_inventory_xmls([revision_id], 'unordered') + texts = self._iter_inventory_xmls([revision_id], "unordered") lines, revision_id = next(texts) if lines is None: raise errors.NoSuchRevision(self, revision_id) @@ -1531,8 +1610,9 @@ def revision_tree(self, revision_id): # TODO: refactor this to use an existing revision object # so we don't need to read it in twice. 
if revision_id == _mod_revision.NULL_REVISION: - return inventorytree.InventoryRevisionTree(self, - Inventory(root_id=None), _mod_revision.NULL_REVISION) + return inventorytree.InventoryRevisionTree( + self, Inventory(root_id=None), _mod_revision.NULL_REVISION + ) else: with self.lock_read(): inv = self.get_inventory(revision_id) @@ -1558,14 +1638,16 @@ def get_parent_map(self, revision_ids): if revision_id == _mod_revision.NULL_REVISION: result[revision_id] = () elif revision_id is None: - raise ValueError('get_parent_map(None) is not valid') + raise ValueError("get_parent_map(None) is not valid") else: query_keys.append((revision_id,)) for (revision_id,), parent_keys in ( - self.revisions.get_parent_map(query_keys)).items(): + self.revisions.get_parent_map(query_keys) + ).items(): if parent_keys: - result[revision_id] = tuple([parent_revid - for (parent_revid,) in parent_keys]) + result[revision_id] = tuple( + [parent_revid for (parent_revid,) in parent_keys] + ) else: result[revision_id] = (_mod_revision.NULL_REVISION,) return result @@ -1575,8 +1657,7 @@ def get_known_graph_ancestry(self, revision_ids): st = static_tuple.StaticTuple revision_keys = [st(r_id).intern() for r_id in revision_ids] with self.lock_read(): - known_graph = self.revisions.get_known_graph_ancestry( - revision_keys) + known_graph = self.revisions.get_known_graph_ancestry(revision_keys) return graph.GraphThunkIdsToKeys(known_graph) def get_file_graph(self): @@ -1586,17 +1667,20 @@ def get_file_graph(self): def revision_ids_to_search_result(self, result_set): """Convert a set of revision ids to a graph SearchResult.""" - result_parents = set(itertools.chain.from_iterable( - self.get_graph().get_parent_map(result_set).values())) + result_parents = set( + itertools.chain.from_iterable( + self.get_graph().get_parent_map(result_set).values() + ) + ) included_keys = result_set.intersection(result_parents) start_keys = result_set.difference(included_keys) exclude_keys = result_parents.difference(result_set) - result = vf_search.SearchResult(start_keys, exclude_keys, - len(result_set), result_set) + result = vf_search.SearchResult( + start_keys, exclude_keys, len(result_set), result_set + ) return result - def _get_versioned_file_checker(self, text_key_references=None, - ancestors=None): + def _get_versioned_file_checker(self, text_key_references=None, ancestors=None): """Return an object suitable for checking versioned files. :param text_key_references: if non-None, an already built @@ -1608,27 +1692,28 @@ def _get_versioned_file_checker(self, text_key_references=None, self.get_graph().get_parent_map(self.all_revision_ids()) if already available. 
""" - return _VersionedFileChecker(self, - text_key_references=text_key_references, ancestors=ancestors) + return _VersionedFileChecker( + self, text_key_references=text_key_references, ancestors=ancestors + ) def has_signature_for_revision_id(self, revision_id): """Query for a revision signature for revision_id in the repository.""" with self.lock_read(): if not self.has_revision(revision_id): raise errors.NoSuchRevision(self, revision_id) - sig_present = (1 == len( - self.signatures.get_parent_map([(revision_id,)]))) + sig_present = 1 == len(self.signatures.get_parent_map([(revision_id,)])) return sig_present def get_signature_text(self, revision_id): """Return the text for a signature.""" with self.lock_read(): - stream = self.signatures.get_record_stream([(revision_id,)], - 'unordered', True) + stream = self.signatures.get_record_stream( + [(revision_id,)], "unordered", True + ) record = next(stream) - if record.storage_kind == 'absent': + if record.storage_kind == "absent": raise errors.NoSuchRevision(self, revision_id) - return record.get_bytes_as('fulltext') + return record.get_bytes_as("fulltext") def _check(self, revision_ids, callback_refs, check_repo): with self.lock_read(): @@ -1654,18 +1739,17 @@ def _find_inconsistent_revision_parents(self, revisions_iterator=None): if revision is None: pass parent_map = vf.get_parent_map([(revid,)]) - parents_according_to_index = tuple(parent[-1] for parent in - parent_map[(revid,)]) + parents_according_to_index = tuple( + parent[-1] for parent in parent_map[(revid,)] + ) parents_according_to_revision = tuple(revision.parent_ids) if parents_according_to_index != parents_according_to_revision: - yield (revid, parents_according_to_index, - parents_according_to_revision) + yield (revid, parents_according_to_index, parents_according_to_revision) def _check_for_inconsistent_revision_parents(self): inconsistencies = list(self._find_inconsistent_revision_parents()) if inconsistencies: - raise errors.BzrCheckError( - "Revision knit has inconsistent parents.") + raise errors.BzrCheckError("Revision knit has inconsistent parents.") def _get_sink(self): """Return a sink for streaming into this repository.""" @@ -1678,22 +1762,22 @@ def _get_source(self, to_format): def reconcile(self, other=None, thorough=False): """Reconcile this repository.""" from .reconcile import VersionedFileRepoReconciler + with self.lock_write(): reconciler = VersionedFileRepoReconciler(self, thorough=thorough) return reconciler.reconcile() -class MetaDirVersionedFileRepository(MetaDirRepository, - VersionedFileRepository): +class MetaDirVersionedFileRepository(MetaDirRepository, VersionedFileRepository): """Repositories in a meta-dir, that work via versioned file objects.""" def __init__(self, _format, a_controldir, control_files): - super().__init__(_format, a_controldir, - control_files) + super().__init__(_format, a_controldir, control_files) -class MetaDirVersionedFileRepositoryFormat(RepositoryFormatMetaDir, - VersionedFileRepositoryFormat): +class MetaDirVersionedFileRepositoryFormat( + RepositoryFormatMetaDir, VersionedFileRepositoryFormat +): """Base class for repository formats using versioned files in metadirs.""" @@ -1717,8 +1801,7 @@ def insert_missing_keys(self, source, missing_keys): :return: keys still missing """ stream = source.get_stream_for_missing_keys(missing_keys) - return self.insert_stream_without_locking(stream, - self.target_repo._format) + return self.insert_stream_without_locking(stream, self.target_repo._format) def insert_stream(self, stream, 
src_format, resume_tokens): """Insert a stream's content into the target repository. @@ -1737,8 +1820,9 @@ def insert_stream(self, stream, src_format, resume_tokens): is_resume = False try: # locked_insert_stream performs a commit|suspend. - missing_keys = self.insert_stream_without_locking(stream, - src_format, is_resume) + missing_keys = self.insert_stream_without_locking( + stream, src_format, is_resume + ) if missing_keys: # suspend the write group and tell the caller what we are # missing. We know we can suspend or else we would not have @@ -1748,17 +1832,18 @@ def insert_stream(self, stream, src_format, resume_tokens): return write_group_tokens, missing_keys hint = self.target_repo.commit_write_group() dest_format = self.target_repo._format - if ((dest_format._revision_serializer != src_format._revision_serializer or - dest_format._inventory_serializer != src_format._inventory_serializer) - and self.target_repo._format.pack_compresses): + if ( + dest_format._revision_serializer != src_format._revision_serializer + or dest_format._inventory_serializer + != src_format._inventory_serializer + ) and self.target_repo._format.pack_compresses: self.target_repo.pack(hint=hint) return [], set() except: self.target_repo.abort_write_group(suppress_errors=True) raise - def insert_stream_without_locking(self, stream, src_format, - is_resume=False): + def insert_stream_without_locking(self, stream, src_format, is_resume=False): """Insert a stream's content into the target repository. This assumes that you already have a locked repository and an active @@ -1774,11 +1859,13 @@ def insert_stream_without_locking(self, stream, src_format, if not self.target_repo.is_write_locked(): raise errors.ObjectNotLocked(self) if not self.target_repo.is_in_write_group(): - raise errors.BzrError('you must already be in a write group') + raise errors.BzrError("you must already be in a write group") dest_format = self.target_repo._format new_pack = None - if (src_format._revision_serializer == dest_format._revision_serializer - and src_format._inventory_serializer == dest_format._inventory_serializer): + if ( + src_format._revision_serializer == dest_format._revision_serializer + and src_format._inventory_serializer == dest_format._inventory_serializer + ): # If serializers match and the target is a pack repository, set the # write cache size on the new pack.
This avoids poor performance # on transports where append is unbuffered (such as @@ -1797,54 +1884,63 @@ def insert_stream_without_locking(self, stream, src_format, else: new_pack.set_write_cache_size(1024 * 1024) for substream_type, substream in stream: - if debug.debug_flag_enabled('stream'): - mutter('inserting substream: %s', substream_type) - if substream_type == 'texts': + if debug.debug_flag_enabled("stream"): + mutter("inserting substream: %s", substream_type) + if substream_type == "texts": self.target_repo.texts.insert_record_stream(substream) - elif substream_type == 'inventories': - if src_format._inventory_serializer == dest_format._inventory_serializer: - self.target_repo.inventories.insert_record_stream( - substream) + elif substream_type == "inventories": + if ( + src_format._inventory_serializer + == dest_format._inventory_serializer + ): + self.target_repo.inventories.insert_record_stream(substream) else: self._extract_and_insert_inventories( - substream, src_format._inventory_serializer) - elif substream_type == 'inventory-deltas': + substream, src_format._inventory_serializer + ) + elif substream_type == "inventory-deltas": self._extract_and_insert_inventory_deltas( - substream, src_format._inventory_serializer) - elif substream_type == 'chk_bytes': + substream, src_format._inventory_serializer + ) + elif substream_type == "chk_bytes": # XXX: This doesn't support conversions, as it assumes the # conversion was done in the fetch code. self.target_repo.chk_bytes.insert_record_stream(substream) - elif substream_type == 'revisions': + elif substream_type == "revisions": # This may fallback to extract-and-insert more often than # required if the serializers are different only in terms of # the inventory, since we also need to update the .inventory_sha1 field - if (src_format._revision_serializer == dest_format._revision_serializer - and src_format._inventory_serializer == dest_format._inventory_serializer): + if ( + src_format._revision_serializer == dest_format._revision_serializer + and src_format._inventory_serializer + == dest_format._inventory_serializer + ): self.target_repo.revisions.insert_record_stream(substream) else: - self._extract_and_insert_revisions(substream, - src_format._revision_serializer) - elif substream_type == 'signatures': + self._extract_and_insert_revisions( + substream, src_format._revision_serializer + ) + elif substream_type == "signatures": self.target_repo.signatures.insert_record_stream(substream) else: - raise AssertionError(f'kaboom! {substream_type}') + raise AssertionError(f"kaboom! {substream_type}") # Done inserting data, and the missing_keys calculations will try to # read back from the inserted data, so flush the writes to the new pack # (if this is pack format). 
if new_pack is not None: - new_pack._write_data(b'', flush=True) + new_pack._write_data(b"", flush=True) # Find all the new revisions (including ones from resume_tokens) missing_keys = self.target_repo.get_missing_parent_inventories( - check_for_missing_texts=is_resume) + check_for_missing_texts=is_resume + ) try: for prefix, versioned_file in ( - ('texts', self.target_repo.texts), - ('inventories', self.target_repo.inventories), - ('revisions', self.target_repo.revisions), - ('signatures', self.target_repo.signatures), - ('chk_bytes', self.target_repo.chk_bytes), - ): + ("texts", self.target_repo.texts), + ("inventories", self.target_repo.inventories), + ("revisions", self.target_repo.revisions), + ("signatures", self.target_repo.signatures), + ("chk_bytes", self.target_repo.chk_bytes), + ): if versioned_file is None: continue # TODO: key is often going to be a StaticTuple object @@ -1854,8 +1950,10 @@ def insert_stream_without_locking(self, stream, src_format, # pass in either a tuple or a StaticTuple as the second # object, so instead we could have: # StaticTuple(prefix) + key here... - missing_keys.update((prefix,) + key for key in - versioned_file.get_missing_compression_parent_keys()) + missing_keys.update( + (prefix,) + key + for key in versioned_file.get_missing_compression_parent_keys() + ) except NotImplementedError: # cannot even attempt suspending, and missing would have failed # during stream insertion. @@ -1865,11 +1963,10 @@ def insert_stream_without_locking(self, stream, src_format, def _extract_and_insert_inventory_deltas(self, substream, serializer): for record in substream: # Insert the delta directly - inventory_delta_bytes = record.get_bytes_as('lines') + inventory_delta_bytes = record.get_bytes_as("lines") deserialiser = inventory_delta.InventoryDeltaDeserializer() try: - parse_result = deserialiser.parse_text_bytes( - inventory_delta_bytes) + parse_result = deserialiser.parse_text_bytes(inventory_delta_bytes) except inventory_delta.IncompatibleInventoryDelta as err: mutter("Incompatible delta: %s", err.msg) raise errors.IncompatibleRevision(self.target_repo._format) from err @@ -1878,10 +1975,10 @@ def _extract_and_insert_inventory_deltas(self, substream, serializer): revision_id = new_id parents = [key[0] for key in record.parents] self.target_repo.add_inventory_by_delta( - basis_id, inv_delta, revision_id, parents) + basis_id, inv_delta, revision_id, parents + ) - def _extract_and_insert_inventories(self, substream, serializer, - parse_delta=None): + def _extract_and_insert_inventories(self, substream, serializer, parse_delta=None): """Generate a new inventory versionedfile in target, converting data. The inventory is retrieved from the source, (deserializing it), and @@ -1890,7 +1987,7 @@ def _extract_and_insert_inventories(self, substream, serializer, for record in substream: # It's not a delta, so it must be a fulltext in the source # serializer's format. 
- lines = record.get_bytes_as('lines') + lines = record.get_bytes_as("lines") revision_id = record.key[0] inv = serializer.read_inventory_from_lines(lines, revision_id) parents = [key[0] for key in record.parents] @@ -1901,11 +1998,11 @@ def _extract_and_insert_inventories(self, substream, serializer, def _extract_and_insert_revisions(self, substream, serializer): for record in substream: - bytes = record.get_bytes_as('fulltext') + bytes = record.get_bytes_as("fulltext") revision_id = record.key[0] rev = serializer.read_revision_from_string(bytes) if rev.revision_id != revision_id: - raise AssertionError(f'wtf: {rev} != {revision_id}') + raise AssertionError(f"wtf: {rev} != {revision_id}") self.target_repo.add_revision(revision_id, rev) def finished(self): @@ -1921,6 +2018,7 @@ def __init__(self, from_repository, to_format): self.from_repository = from_repository self.to_format = to_format from .recordcounter import RecordCounter + self._record_counter = RecordCounter() def delta_on_metadata(self): @@ -1930,9 +2028,11 @@ def delta_on_metadata(self): """ src_format = self.from_repository._format dest_format = self.to_format - return (self.to_format._fetch_uses_deltas - and src_format._revision_serializer == dest_format._revision_serializer - and src_format._inventory_serializer == dest_format._inventory_serializer) + return ( + self.to_format._fetch_uses_deltas + and src_format._revision_serializer == dest_format._revision_serializer + and src_format._inventory_serializer == dest_format._inventory_serializer + ) def _fetch_revision_texts(self, revs): # fetch signatures first and then the revision texts @@ -1940,19 +2040,19 @@ def _fetch_revision_texts(self, revs): from_sf = self.from_repository.signatures # A missing signature is just skipped. keys = [(rev_id,) for rev_id in revs] - signatures = versionedfile.filter_absent(from_sf.get_record_stream( - keys, - self.to_format._fetch_order, - not self.to_format._fetch_uses_deltas)) + signatures = versionedfile.filter_absent( + from_sf.get_record_stream( + keys, self.to_format._fetch_order, not self.to_format._fetch_uses_deltas + ) + ) # If a revision has a delta, this is actually expanded inside the # insert_record_stream code now, which is an alternate fix for # bug #261339 from_rf = self.from_repository.revisions revisions = from_rf.get_record_stream( - keys, - self.to_format._fetch_order, - not self.delta_on_metadata()) - return [('signatures', signatures), ('revisions', revisions)] + keys, self.to_format._fetch_order, not self.delta_on_metadata() + ) + return [("signatures", signatures), ("revisions", revisions)] def _generate_root_texts(self, revs): """This will be called by get_stream between fetching weave texts and @@ -1960,12 +2060,13 @@ def _generate_root_texts(self, revs): """ if self._rich_root_upgrade(): return _mod_fetch.Inter1and2Helper( - self.from_repository).generate_root_texts(revs) + self.from_repository + ).generate_root_texts(revs) else: return [] def get_stream(self, search): - phase = 'file' + phase = "file" revs = search.get_keys() graph = self.from_repository.get_graph() revs = tsort.topo_sort(graph.get_parent_map(revs)) @@ -1977,14 +2078,18 @@ def get_stream(self, search): # Make a new progress bar for this phase if knit_kind == "file": # Accumulate file texts - text_keys.extend([(file_id, revision) for revision in - revisions]) + text_keys.extend([(file_id, revision) for revision in revisions]) elif knit_kind == "inventory": # Now copy the file texts. 
from_texts = self.from_repository.texts - yield ('texts', from_texts.get_record_stream( - text_keys, self.to_format._fetch_order, - not self.to_format._fetch_uses_deltas)) + yield ( + "texts", + from_texts.get_record_stream( + text_keys, + self.to_format._fetch_order, + not self.to_format._fetch_uses_deltas, + ), + ) # Cause an error if a text occurs after we have done the # copy. text_keys = None @@ -2011,29 +2116,31 @@ def get_stream_for_missing_keys(self, missing_keys): # translating (because translation means we don't send # unreconstructable deltas ever). keys = {} - keys['texts'] = set() - keys['revisions'] = set() - keys['inventories'] = set() - keys['chk_bytes'] = set() - keys['signatures'] = set() + keys["texts"] = set() + keys["revisions"] = set() + keys["inventories"] = set() + keys["chk_bytes"] = set() + keys["signatures"] = set() for key in missing_keys: keys[key[0]].add(key[1:]) - if len(keys['revisions']): + if len(keys["revisions"]): # If we allowed copying revisions at this point, we could end up # copying a revision without copying its required texts: a # violation of the requirements for repository integrity. raise AssertionError( - f"cannot copy revisions to fill in missing deltas {keys['revisions']}") + f"cannot copy revisions to fill in missing deltas {keys['revisions']}" + ) for substream_kind, keys in keys.items(): # noqa: B020 vf = getattr(self.from_repository, substream_kind) if vf is None and keys: raise AssertionError( "cannot fill in keys for a versioned file we don't" - f" have: {substream_kind} needs {keys}") + f" have: {substream_kind} needs {keys}" + ) if not keys: # No need to stream something we don't have continue - if substream_kind == 'inventories': + if substream_kind == "inventories": # Some missing keys are genuinely ghosts, filter those out. present = self.from_repository.inventories.get_parent_map(keys) revs = [key[0] for key in present] @@ -2050,44 +2157,53 @@ def get_stream_for_missing_keys(self, missing_keys): # records. The Sink is responsible for doing another check to # ensure that ghosts don't introduce missing data for future # fetches. 
- stream = versionedfile.filter_absent(vf.get_record_stream(keys, - self.to_format._fetch_order, True)) + stream = versionedfile.filter_absent( + vf.get_record_stream(keys, self.to_format._fetch_order, True) + ) yield substream_kind, stream def inventory_fetch_order(self): if self._rich_root_upgrade(): - return 'topological' + return "topological" else: return self.to_format._fetch_order def _rich_root_upgrade(self): - return (not self.from_repository._format.rich_root_data - and self.to_format.rich_root_data) + return ( + not self.from_repository._format.rich_root_data + and self.to_format.rich_root_data + ) def _get_inventory_stream(self, revision_ids, missing=False): from_format = self.from_repository._format - if (from_format.supports_chks and self.to_format.supports_chks - and from_format.network_name() == self.to_format.network_name()): - raise AssertionError( - "this case should be handled by GroupCHKStreamSource") - elif debug.debug_flag_enabled('forceinvdeltas'): - return self._get_convertable_inventory_stream(revision_ids, - delta_versus_null=missing) + if ( + from_format.supports_chks + and self.to_format.supports_chks + and from_format.network_name() == self.to_format.network_name() + ): + raise AssertionError("this case should be handled by GroupCHKStreamSource") + elif debug.debug_flag_enabled("forceinvdeltas"): + return self._get_convertable_inventory_stream( + revision_ids, delta_versus_null=missing + ) elif from_format.network_name() == self.to_format.network_name(): # Same format. - return self._get_simple_inventory_stream(revision_ids, - missing=missing) - elif (not from_format.supports_chks and not self.to_format.supports_chks and - from_format._revision_serializer == self.to_format._revision_serializer and - from_format._inventory_serializer == self.to_format._inventory_serializer): + return self._get_simple_inventory_stream(revision_ids, missing=missing) + elif ( + not from_format.supports_chks + and not self.to_format.supports_chks + and from_format._revision_serializer == self.to_format._revision_serializer + and from_format._inventory_serializer + == self.to_format._inventory_serializer + ): # Essentially the same format. - return self._get_simple_inventory_stream(revision_ids, - missing=missing) + return self._get_simple_inventory_stream(revision_ids, missing=missing) else: # Any time we switch serializations, we want to use an # inventory-delta based approach. 
- return self._get_convertable_inventory_stream(revision_ids, - delta_versus_null=missing) + return self._get_convertable_inventory_stream( + revision_ids, delta_versus_null=missing + ) def _get_simple_inventory_stream(self, revision_ids, missing=False): # NB: This currently reopens the inventory weave in source; @@ -2097,12 +2213,16 @@ def _get_simple_inventory_stream(self, revision_ids, missing=False): delta_closure = True else: delta_closure = not self.delta_on_metadata() - yield ('inventories', from_weave.get_record_stream( - [(rev_id,) for rev_id in revision_ids], - self.inventory_fetch_order(), delta_closure)) - - def _get_convertable_inventory_stream(self, revision_ids, - delta_versus_null=False): + yield ( + "inventories", + from_weave.get_record_stream( + [(rev_id,) for rev_id in revision_ids], + self.inventory_fetch_order(), + delta_closure, + ), + ) + + def _get_convertable_inventory_stream(self, revision_ids, delta_versus_null=False): # The two formats are sufficiently different that there is no fast # path, so we need to send just inventorydeltas, which any # sufficiently modern client can insert into any repository. @@ -2110,9 +2230,12 @@ def _get_convertable_inventory_stream(self, revision_ids, # convert on the target, so we need to put bytes-on-the-wire that can # be converted. That means inventory deltas (if the remote is <1.19, # RemoteStreamSink will fallback to VFS to insert the deltas). - yield ('inventory-deltas', - self._stream_invs_as_deltas(revision_ids, - delta_versus_null=delta_versus_null)) + yield ( + "inventory-deltas", + self._stream_invs_as_deltas( + revision_ids, delta_versus_null=delta_versus_null + ), + ) def _stream_invs_as_deltas(self, revision_ids, delta_versus_null=False): """Return a stream of inventory-deltas for the given rev ids. @@ -2128,19 +2251,20 @@ def _stream_invs_as_deltas(self, revision_ids, delta_versus_null=False): parent_map = from_repo.inventories.get_parent_map(revision_keys) # XXX: possibly repos could implement a more efficient iter_inv_deltas # method... - inventories = self.from_repository.iter_inventories( - revision_ids, 'topological') + inventories = self.from_repository.iter_inventories(revision_ids, "topological") format = from_repo._format invs_sent_so_far = {_mod_revision.NULL_REVISION} inventory_cache = lru_cache.LRUCache(50) null_inventory = from_repo.revision_tree( - _mod_revision.NULL_REVISION).root_inventory + _mod_revision.NULL_REVISION + ).root_inventory # XXX: ideally the rich-root/tree-refs flags would be per-revision, not # per-repo (e.g. 
streaming a non-rich-root revision out of a rich-root # repo back into a non-rich-root repo ought to be allowed) serializer = inventory_delta.InventoryDeltaSerializer( versioned_root=format.rich_root_data, - tree_references=format.supports_tree_reference) + tree_references=format.supports_tree_reference, + ) for inv in inventories: key = (inv.revision_id,) parent_keys = parent_map.get(key, ()) @@ -2162,8 +2286,7 @@ def _stream_invs_as_deltas(self, revision_ids, delta_versus_null=False): if parent_inv is None: parent_inv = from_repo.get_inventory(parent_id) candidate_delta = inv._make_delta(parent_inv) - if (delta is None - or len(delta) > len(candidate_delta)): + if delta is None or len(delta) > len(candidate_delta): delta = candidate_delta basis_id = parent_id if delta is None: @@ -2175,15 +2298,16 @@ def _stream_invs_as_deltas(self, revision_ids, delta_versus_null=False): inventory_cache[inv.revision_id] = inv delta_serialized = serializer.delta_to_lines(basis_id, key[-1], delta) yield versionedfile.ChunkedContentFactory( - key, parent_keys, None, delta_serialized, chunks_are_lines=True) + key, parent_keys, None, delta_serialized, chunks_are_lines=True + ) class _VersionedFileChecker: - def __init__(self, repository, text_key_references=None, ancestors=None): self.repository = repository self.text_index = self.repository._generate_text_key_index( - text_key_references=text_key_references, ancestors=ancestors) + text_key_references=text_key_references, ancestors=ancestors + ) def calculate_file_version_parents(self, text_key): """Calculate the correct parents for a file version according to @@ -2223,14 +2347,13 @@ def _check_file_version_parents(self, texts, progress_bar): self.file_ids = {file_id for file_id, _ in self.text_index} # text keys is now grouped by file_id n_versions = len(self.text_index) - progress_bar.update(gettext('loading text store'), 0, n_versions) + progress_bar.update(gettext("loading text store"), 0, n_versions) parent_map = self.repository.texts.get_parent_map(self.text_index) # On unlistable transports this could well be empty/error... text_keys = self.repository.texts.keys() unused_keys = frozenset(text_keys) - set(self.text_index) for num, key in enumerate(self.text_index): - progress_bar.update( - gettext('checking text graph'), num, n_versions) + progress_bar.update(gettext("checking text graph"), num, n_versions) correct_parents = self.calculate_file_version_parents(key) try: knit_parents = parent_map[key] @@ -2243,13 +2366,11 @@ def _check_file_version_parents(self, texts, progress_bar): class InterVersionedFileRepository(InterRepository): - _walk_to_common_revisions_batch_size = 50 supports_fetch_spec = True - def fetch(self, revision_id=None, find_ghosts=False, - fetch_spec=None, lossy=False): + def fetch(self, revision_id=None, find_ghosts=False, fetch_spec=None, lossy=False): """Fetch the content required to construct revision_id. The content is copied from self.source to self.target. 
@@ -2262,22 +2383,27 @@ def fetch(self, revision_id=None, find_ghosts=False, raise errors.LossyPushToSameVCS(self.source, self.target) if self.target._format.experimental: ui.ui_factory.show_user_warning( - 'experimental_format_fetch', + "experimental_format_fetch", from_format=self.source._format, - to_format=self.target._format) + to_format=self.target._format, + ) from .fetch import RepoFetcher # See asking for a warning here if self.source._format.network_name() != self.target._format.network_name(): ui.ui_factory.show_user_warning( - 'cross_format_fetch', from_format=self.source._format, - to_format=self.target._format) + "cross_format_fetch", + from_format=self.source._format, + to_format=self.target._format, + ) with self.lock_write(): - RepoFetcher(to_repository=self.target, - from_repository=self.source, - last_revision=revision_id, - fetch_spec=fetch_spec, - find_ghosts=find_ghosts) + RepoFetcher( + to_repository=self.target, + from_repository=self.source, + last_revision=revision_id, + fetch_spec=fetch_spec, + find_ghosts=find_ghosts, + ) return FetchResult() def _walk_to_common_revisions(self, revision_ids, if_present_ids=None): @@ -2325,8 +2451,7 @@ def _walk_to_common_revisions(self, revision_ids, if_present_ids=None): if ghosts_to_check: # One of the caller's revision_ids is a ghost in both the # source and the target. - raise errors.NoSuchRevision( - self.source, ghosts_to_check.pop()) + raise errors.NoSuchRevision(self.source, ghosts_to_check.pop()) missing_revs.update(next_revs - have_revs) # Because we may have walked past the original stop point, make # sure everything is stopped @@ -2335,12 +2460,13 @@ def _walk_to_common_revisions(self, revision_ids, if_present_ids=None): if searcher_exhausted: break (started_keys, excludes, included_keys) = searcher.get_state() - return vf_search.SearchResult(started_keys, excludes, - len(included_keys), included_keys) + return vf_search.SearchResult( + started_keys, excludes, len(included_keys), included_keys + ) - def search_missing_revision_ids(self, - find_ghosts=True, revision_ids=None, if_present_ids=None, - limit=None): + def search_missing_revision_ids( + self, find_ghosts=True, revision_ids=None, if_present_ids=None, limit=None + ): """Return the revision ids that source has that target does not. :param revision_ids: return revision ids included by these @@ -2356,10 +2482,12 @@ def search_missing_revision_ids(self, """ with self.lock_read(): # stop searching at found target revisions. - if not find_ghosts and (revision_ids is not None or if_present_ids is - not None): - result = self._walk_to_common_revisions(revision_ids, - if_present_ids=if_present_ids) + if not find_ghosts and ( + revision_ids is not None or if_present_ids is not None + ): + result = self._walk_to_common_revisions( + revision_ids, if_present_ids=if_present_ids + ) if limit is None: return result result_set = result.get_keys() @@ -2367,7 +2495,8 @@ def search_missing_revision_ids(self, # generic, possibly worst case, slow code path. 
target_ids = set(self.target.all_revision_ids()) source_ids = self._present_source_revisions_for( - revision_ids, if_present_ids) + revision_ids, if_present_ids + ) result_set = set(source_ids).difference(target_ids) if limit is not None: topo_ordered = self.source.get_graph().iter_topo_order(result_set) @@ -2398,10 +2527,11 @@ def _present_source_revisions_for(self, revision_ids, if_present_ids=None): if missing: raise errors.NoSuchRevision(self.source, missing.pop()) found_ids = all_wanted_ids.intersection(present_revs) - source_ids = [rev_id for (rev_id, parents) in - graph.iter_ancestry(found_ids) - if rev_id != _mod_revision.NULL_REVISION and - parents is not None] + source_ids = [ + rev_id + for (rev_id, parents) in graph.iter_ancestry(found_ids) + if rev_id != _mod_revision.NULL_REVISION and parents is not None + ] else: source_ids = self.source.all_revision_ids() return set(source_ids) @@ -2413,12 +2543,13 @@ def _get_repo_format_to_test(self): @classmethod def is_compatible(cls, source, target): # The default implementation is compatible with everything - return (source._format.supports_full_versioned_files - and target._format.supports_full_versioned_files) + return ( + source._format.supports_full_versioned_files + and target._format.supports_full_versioned_files + ) class InterDifferingSerializer(InterVersionedFileRepository): - @classmethod def _get_repo_format_to_test(cls): return None @@ -2432,24 +2563,26 @@ def is_compatible(source, target): # This is redundant with format.check_conversion_target(), however that # raises an exception, and we just want to say "False" as in we won't # support converting between these formats. - if debug.debug_flag_enabled('IDS_never'): + if debug.debug_flag_enabled("IDS_never"): return False if source.supports_rich_root() and not target.supports_rich_root(): return False - if (source._format.supports_tree_reference and - not target._format.supports_tree_reference): + if ( + source._format.supports_tree_reference + and not target._format.supports_tree_reference + ): return False if target._fallback_repositories and target._format.supports_chks: # IDS doesn't know how to copy CHKs for the parent inventories it # adds to stacked repos. return False - if debug.debug_flag_enabled('IDS_always'): + if debug.debug_flag_enabled("IDS_always"): return True # Only use this code path for local source and target. IDS does far # too much IO (both bandwidth and roundtrips) over a network. - if not source.controldir.transport.base.startswith('file:///'): + if not source.controldir.transport.base.startswith("file:///"): return False - if not target.controldir.transport.base.startswith('file:///'): + if not target.controldir.transport.base.startswith("file:///"): return False return True @@ -2490,7 +2623,7 @@ def _get_delta_for_revision(self, tree, parent_ids, possible_trees): # Rich roots are handled elsewhere... continue kind = new_entry.kind - if kind != 'directory' and kind != 'file': + if kind != "directory" and kind != "file": # No text record associated with this inventory entry. continue # This is a directory or file that has changed somehow. @@ -2508,24 +2641,23 @@ def _fetch_parent_invs_for_stacking(self, parent_map, cache): source may not have _fallback_repositories even though it is stacked.)
""" - parent_revs = set(itertools.chain.from_iterable( - parent_map.values())) + parent_revs = set(itertools.chain.from_iterable(parent_map.values())) present_parents = self.source.get_parent_map(parent_revs) absent_parents = parent_revs.difference(present_parents) parent_invs_keys_for_stacking = self.source.inventories.get_parent_map( - (rev_id,) for rev_id in absent_parents) + (rev_id,) for rev_id in absent_parents + ) parent_inv_ids = [key[-1] for key in parent_invs_keys_for_stacking] for parent_tree in self.source.revision_trees(parent_inv_ids): current_revision_id = parent_tree.get_revision_id() - parents_parents_keys = parent_invs_keys_for_stacking[ - (current_revision_id,)] + parents_parents_keys = parent_invs_keys_for_stacking[(current_revision_id,)] parents_parents = [key[-1] for key in parents_parents_keys] basis_id = _mod_revision.NULL_REVISION basis_tree = self.source.revision_tree(basis_id) - delta = parent_tree.root_inventory._make_delta( - basis_tree.root_inventory) + delta = parent_tree.root_inventory._make_delta(basis_tree.root_inventory) self.target.add_inventory_by_delta( - basis_id, delta, current_revision_id, parents_parents) + basis_id, delta, current_revision_id, parents_parents + ) cache[current_revision_id] = parent_tree def _fetch_batch(self, revision_ids, basis_id, cache): @@ -2559,14 +2691,15 @@ def _fetch_batch(self, revision_ids, basis_id, cache): # There either aren't any parents, or the parents are ghosts, # so just use the last converted tree. possible_trees.append((basis_id, cache[basis_id])) - basis_id, delta = self._get_delta_for_revision(tree, parent_ids, - possible_trees) + basis_id, delta = self._get_delta_for_revision( + tree, parent_ids, possible_trees + ) revision = self.source.get_revision(current_revision_id) - pending_deltas.append((basis_id, delta, - current_revision_id, revision.parent_ids)) + pending_deltas.append( + (basis_id, delta, current_revision_id, revision.parent_ids) + ) if self._converting_to_rich_root: - self._revision_id_to_root_id[current_revision_id] = \ - tree.path2id('') + self._revision_id_to_root_id[current_revision_id] = tree.path2id("") # Determine which texts are in present in this revision but not in # any of the available parents. texts_possibly_new_in_tree = set() @@ -2606,12 +2739,19 @@ def _fetch_batch(self, revision_ids, basis_id, cache): to_texts = self.target.texts if root_keys_to_create: root_stream = _mod_fetch._new_root_data_stream( - root_keys_to_create, self._revision_id_to_root_id, parent_map, - self.source) + root_keys_to_create, + self._revision_id_to_root_id, + parent_map, + self.source, + ) to_texts.insert_record_stream(root_stream) - to_texts.insert_record_stream(from_texts.get_record_stream( - text_keys, self.target._format._fetch_order, - not self.target._format._fetch_uses_deltas)) + to_texts.insert_record_stream( + from_texts.get_record_stream( + text_keys, + self.target._format._fetch_order, + not self.target._format._fetch_uses_deltas, + ) + ) # insert inventory deltas for delta in pending_deltas: self.target.add_inventory_by_delta(*delta) @@ -2640,17 +2780,17 @@ def _fetch_batch(self, revision_ids, basis_id, cache): # There either aren't any parents, or the parents are # ghosts, so just use the last converted tree. 
possible_trees.append((basis_id, cache[basis_id])) - basis_id, delta = self._get_delta_for_revision(parent_tree, - parents_parents, possible_trees) + basis_id, delta = self._get_delta_for_revision( + parent_tree, parents_parents, possible_trees + ) self.target.add_inventory_by_delta( - basis_id, delta, current_revision_id, parents_parents) + basis_id, delta, current_revision_id, parents_parents + ) # insert signatures and revisions for revision in pending_revisions: try: - signature = self.source.get_signature_text( - revision.revision_id) - self.target.add_signature_text(revision.revision_id, - signature) + signature = self.source.get_signature_text(revision.revision_id) + self.target.add_signature_text(revision.revision_id, signature) except errors.NoSuchRevision: pass self.target.add_revision(revision.revision_id, revision) @@ -2674,9 +2814,8 @@ def _fetch_all_revisions(self, revision_ids, pb): for offset in range(0, len(revision_ids), batch_size): self.target.start_write_group() try: - pb.update(gettext('Transferring revisions'), offset, - len(revision_ids)) - batch = revision_ids[offset:offset + batch_size] + pb.update(gettext("Transferring revisions"), offset, len(revision_ids)) + batch = revision_ids[offset : offset + batch_size] basis_id = self._fetch_batch(batch, basis_id, cache) except: self.source._safe_to_return_from_cache = False @@ -2688,11 +2827,11 @@ def _fetch_all_revisions(self, revision_ids, pb): hints.extend(hint) if hints and self.target._format.pack_compresses: self.target.pack(hint=hints) - pb.update(gettext('Transferring revisions'), len(revision_ids), - len(revision_ids)) + pb.update( + gettext("Transferring revisions"), len(revision_ids), len(revision_ids) + ) - def fetch(self, revision_id=None, find_ghosts=False, - fetch_spec=None, lossy=False): + def fetch(self, revision_id=None, find_ghosts=False, fetch_spec=None, lossy=False): """See InterRepository.fetch().""" if lossy: raise errors.LossyPushToSameVCS(self.source, self.target) @@ -2701,20 +2840,23 @@ def fetch(self, revision_id=None, find_ghosts=False, else: revision_ids = None if self.source._format.experimental: - ui.ui_factory.show_user_warning('experimental_format_fetch', - from_format=self.source._format, - to_format=self.target._format) - if (not self.source.supports_rich_root() and - self.target.supports_rich_root()): + ui.ui_factory.show_user_warning( + "experimental_format_fetch", + from_format=self.source._format, + to_format=self.target._format, + ) + if not self.source.supports_rich_root() and self.target.supports_rich_root(): self._converting_to_rich_root = True self._revision_id_to_root_id = {} else: self._converting_to_rich_root = False # See asking for a warning here if self.source._format.network_name() != self.target._format.network_name(): - ui.ui_factory.show_user_warning('cross_format_fetch', - from_format=self.source._format, - to_format=self.target._format) + ui.ui_factory.show_user_warning( + "cross_format_fetch", + from_format=self.source._format, + to_format=self.target._format, + ) with self.lock_write(): if revision_ids is None: if revision_id: @@ -2722,12 +2864,15 @@ def fetch(self, revision_id=None, find_ghosts=False, else: search_revision_ids = None revision_ids = self.target.search_missing_revision_ids( - self.source, revision_ids=search_revision_ids, - find_ghosts=find_ghosts).get_keys() + self.source, + revision_ids=search_revision_ids, + find_ghosts=find_ghosts, + ).get_keys() if not revision_ids: return FetchResult(0) revision_ids = tsort.topo_sort( - 
self.source.get_graph().get_parent_map(revision_ids)) + self.source.get_graph().get_parent_map(revision_ids) + ) if not revision_ids: return FetchResult(0) # Walk though all revisions; get inventory deltas, copy referenced @@ -2774,6 +2919,7 @@ def _get_repo_format_to_test(self): non-subtree, so we test this with the richest repository format. """ from breezy.bzr import knitrepo + return knitrepo.RepositoryFormatKnit3() @staticmethod @@ -2781,7 +2927,8 @@ def is_compatible(source, target): return ( InterRepository._same_model(source, target) and source._format.supports_full_versioned_files - and target._format.supports_full_versioned_files) + and target._format.supports_full_versioned_files + ) InterRepository.register_optimiser(InterVersionedFileRepository) @@ -2798,15 +2945,14 @@ def install_revisions(repository, iterable, num_revisions=None, pb=None): with WriteGroup(repository): inventory_cache = lru_cache.LRUCache(10) for n, (revision, revision_tree, signature) in enumerate(iterable): - _install_revision(repository, revision, revision_tree, signature, - inventory_cache) + _install_revision( + repository, revision, revision_tree, signature, inventory_cache + ) if pb is not None: - pb.update(gettext('Transferring revisions'), - n + 1, num_revisions) + pb.update(gettext("Transferring revisions"), n + 1, num_revisions) -def _install_revision(repository, rev, revision_tree, signature, - inventory_cache): +def _install_revision(repository, rev, revision_tree, signature, inventory_cache): """Install all revision data into a repository.""" present_parents = [] parent_trees = {} @@ -2815,8 +2961,7 @@ def _install_revision(repository, rev, revision_tree, signature, present_parents.append(p_id) parent_trees[p_id] = repository.revision_tree(p_id) else: - parent_trees[p_id] = repository.revision_tree( - _mod_revision.NULL_REVISION) + parent_trees[p_id] = repository.revision_tree(_mod_revision.NULL_REVISION) # FIXME: Support nested trees inv = revision_tree.root_inventory @@ -2863,8 +3008,9 @@ def _install_revision(repository, rev, revision_tree, signature, repository.add_inventory(rev.revision_id, inv, present_parents) else: delta = inv._make_delta(basis_inv) - repository.add_inventory_by_delta(rev.parent_ids[0], delta, - rev.revision_id, present_parents) + repository.add_inventory_by_delta( + rev.parent_ids[0], delta, rev.revision_id, present_parents + ) else: repository.add_inventory(rev.revision_id, inv, present_parents) except errors.RevisionAlreadyPresent: diff --git a/breezy/bzr/vf_search.py b/breezy/bzr/vf_search.py index 6c35300e82..2938c3885e 100644 --- a/breezy/bzr/vf_search.py +++ b/breezy/bzr/vf_search.py @@ -58,21 +58,26 @@ def __init__(self, start_keys, exclude_keys, key_count, keys): a SearchResult from a smart server, in which case the keys list is not necessarily immediately available. 
""" - self._recipe = ('search', start_keys, exclude_keys, key_count) + self._recipe = ("search", start_keys, exclude_keys, key_count) self._keys = frozenset(keys) def __repr__(self): kind, start_keys, exclude_keys, key_count = self._recipe if len(start_keys) > 5: - start_keys_repr = repr(list(start_keys)[:5])[:-1] + ', ...]' + start_keys_repr = repr(list(start_keys)[:5])[:-1] + ", ...]" else: start_keys_repr = repr(start_keys) if len(exclude_keys) > 5: - exclude_keys_repr = repr(list(exclude_keys)[:5])[:-1] + ', ...]' + exclude_keys_repr = repr(list(exclude_keys)[:5])[:-1] + ", ...]" else: exclude_keys_repr = repr(exclude_keys) - return '<%s %s:(%s, %s, %d)>' % (self.__class__.__name__, - kind, start_keys_repr, exclude_keys_repr, key_count) + return "<%s %s:(%s, %s, %d)>" % ( + self.__class__.__name__, + kind, + start_keys_repr, + exclude_keys_repr, + key_count, + ) def get_recipe(self): """Return a recipe that can be used to replay this search. @@ -98,11 +103,13 @@ def get_recipe(self): return self._recipe def get_network_struct(self): - start_keys = b' '.join(self._recipe[1]) - stop_keys = b' '.join(self._recipe[2]) - count = str(self._recipe[3]).encode('ascii') - return (self._recipe[0].encode('ascii'), - b'\n'.join((start_keys, stop_keys, count))) + start_keys = b" ".join(self._recipe[1]) + stop_keys = b" ".join(self._recipe[2]) + count = str(self._recipe[3]).encode("ascii") + return ( + self._recipe[0].encode("ascii"), + b"\n".join((start_keys, stop_keys, count)), + ) def get_keys(self): """Return the keys found in this search. @@ -161,10 +168,10 @@ def __init__(self, heads, repo): def __repr__(self): if len(self.heads) > 5: heads_repr = repr(list(self.heads)[:5])[:-1] - heads_repr += ', <%d more>...]' % (len(self.heads) - 5,) + heads_repr += ", <%d more>...]" % (len(self.heads) - 5,) else: heads_repr = repr(self.heads) - return f'<{self.__class__.__name__} heads:{heads_repr} repo:{self.repo!r}>' + return f"<{self.__class__.__name__} heads:{heads_repr} repo:{self.repo!r}>" def get_recipe(self): """Return a recipe that can be used to replay this search. @@ -177,10 +184,10 @@ def get_recipe(self): To recreate this result, create a PendingAncestryResult with the start_keys_set. """ - return ('proxy-search', self.heads, set(), -1) + return ("proxy-search", self.heads, set(), -1) def get_network_struct(self): - parts = [b'ancestry-of'] + parts = [b"ancestry-of"] parts.extend(self.heads) return parts @@ -194,8 +201,11 @@ def get_keys(self): def _get_keys(self, graph): NULL_REVISION = revision.NULL_REVISION - keys = [key for (key, parents) in graph.iter_ancestry(self.heads) - if key != NULL_REVISION and parents is not None] + keys = [ + key + for (key, parents) in graph.iter_ancestry(self.heads) + if key != NULL_REVISION and parents is not None + ] return keys def is_empty(self): @@ -230,21 +240,23 @@ def __init__(self, repo): self._repo = repo def __repr__(self): - return f'{self.__class__.__name__}({self._repo!r})' + return f"{self.__class__.__name__}({self._repo!r})" def get_recipe(self): raise NotImplementedError(self.get_recipe) def get_network_struct(self): - return (b'everything',) + return (b"everything",) def get_keys(self): - if debug.debug_flag_enabled('evil'): + if debug.debug_flag_enabled("evil"): from . import remote + if isinstance(self._repo, remote.RemoteRepository): # warn developers (not users) not to do this trace.mutter_callsite( - 2, "EverythingResult(RemoteRepository).get_keys() is slow.") + 2, "EverythingResult(RemoteRepository).get_keys() is slow." 
+ ) return self._repo.all_revision_ids() def is_empty(self): @@ -270,14 +282,22 @@ def __init__(self, to_repo, from_repo, find_ghosts=False): def execute(self): return self.to_repo.search_missing_revision_ids( - self.from_repo, find_ghosts=self.find_ghosts) + self.from_repo, find_ghosts=self.find_ghosts + ) class NotInOtherForRevs(AbstractSearch): """Find all revisions missing in one repo for a some specific heads.""" - def __init__(self, to_repo, from_repo, required_ids, if_present_ids=None, - find_ghosts=False, limit=None): + def __init__( + self, + to_repo, + from_repo, + required_ids, + if_present_ids=None, + find_ghosts=False, + limit=None, + ): """Constructor. :param required_ids: revision IDs of heads that must be found, or else @@ -298,25 +318,35 @@ def __init__(self, to_repo, from_repo, required_ids, if_present_ids=None, def __repr__(self): if len(self.required_ids) > 5: - reqd_revs_repr = repr(list(self.required_ids)[:5])[:-1] + ', ...]' + reqd_revs_repr = repr(list(self.required_ids)[:5])[:-1] + ", ...]" else: reqd_revs_repr = repr(self.required_ids) if self.if_present_ids and len(self.if_present_ids) > 5: - ifp_revs_repr = repr(list(self.if_present_ids)[:5])[:-1] + ', ...]' + ifp_revs_repr = repr(list(self.if_present_ids)[:5])[:-1] + ", ...]" else: ifp_revs_repr = repr(self.if_present_ids) - return ("<{} from:{!r} to:{!r} find_ghosts:{!r} req'd:{!r} if-present:{!r}" - "limit:{!r}>").format( - self.__class__.__name__, self.from_repo, self.to_repo, - self.find_ghosts, reqd_revs_repr, ifp_revs_repr, - self.limit) + return ( + "<{} from:{!r} to:{!r} find_ghosts:{!r} req'd:{!r} if-present:{!r}" + "limit:{!r}>" + ).format( + self.__class__.__name__, + self.from_repo, + self.to_repo, + self.find_ghosts, + reqd_revs_repr, + ifp_revs_repr, + self.limit, + ) def execute(self): return self.to_repo.search_missing_revision_ids( - self.from_repo, revision_ids=self.required_ids, - if_present_ids=self.if_present_ids, find_ghosts=self.find_ghosts, - limit=self.limit) + self.from_repo, + revision_ids=self.required_ids, + if_present_ids=self.if_present_ids, + find_ghosts=self.find_ghosts, + limit=self.limit, + ) def search_result_from_parent_map(parent_map, missing_keys): @@ -333,8 +363,10 @@ def search_result_from_parent_map(parent_map, missing_keys): # stop either. stop_keys.difference_update(missing_keys) key_count = len(parent_map) - if (revision.NULL_REVISION in result_parents - and revision.NULL_REVISION in missing_keys): + if ( + revision.NULL_REVISION in result_parents + and revision.NULL_REVISION in missing_keys + ): # If we pruned NULL_REVISION from the stop_keys because it's also # in our cache of "missing" keys we need to increment our key count # by 1, because the reconsitituted SearchResult on the server will @@ -413,8 +445,7 @@ def _find_possible_heads(parent_map, tip_keys, depth): return heads -def limited_search_result_from_parent_map(parent_map, missing_keys, tip_keys, - depth): +def limited_search_result_from_parent_map(parent_map, missing_keys, tip_keys, depth): """Transform a parent_map that is searching 'tip_keys' into an approximate SearchResult. 
diff --git a/breezy/bzr/weave.py b/breezy/bzr/weave.py index 75562e6ee0..c404192fee 100755 --- a/breezy/bzr/weave.py +++ b/breezy/bzr/weave.py @@ -90,7 +90,6 @@ class WeaveError(errors.BzrError): - _fmt = "Error in processing weave: %(msg)s" def __init__(self, msg=None): @@ -99,18 +98,15 @@ def __init__(self, msg=None): class WeaveRevisionAlreadyPresent(WeaveError): - _fmt = "Revision {%(revision_id)s} already present in %(weave)s" def __init__(self, revision_id, weave): - WeaveError.__init__(self) self.revision_id = revision_id self.weave = weave class WeaveRevisionNotPresent(WeaveError): - _fmt = "Revision {%(revision_id)s} not present in %(weave)s" def __init__(self, revision_id, weave): @@ -120,7 +116,6 @@ def __init__(self, revision_id, weave): class WeaveFormatError(WeaveError): - _fmt = "Weave invariant violated: %(what)s" def __init__(self, what): @@ -129,19 +124,18 @@ def __init__(self, what): class WeaveParentMismatch(WeaveError): - _fmt = "Parents are mismatched between two revisions. %(msg)s" class WeaveInvalidChecksum(WeaveError): - _fmt = "Text did not match its checksum: %(msg)s" class WeaveTextDiffers(WeaveError): - - _fmt = ("Weaves differ on text content. Revision:" - " {%(revision_id)s}, %(weave_a)s, %(weave_b)s") + _fmt = ( + "Weaves differ on text content. Revision:" + " {%(revision_id)s}, %(weave_a)s, %(weave_b)s" + ) def __init__(self, revision_id, weave_a, weave_b): WeaveError.__init__(self) @@ -163,22 +157,22 @@ def __init__(self, version, weave): self.key = (version,) parents = weave.get_parent_map([version])[version] self.parents = tuple((parent,) for parent in parents) - self.storage_kind = 'fulltext' + self.storage_kind = "fulltext" self._weave = weave def get_bytes_as(self, storage_kind): - if storage_kind == 'fulltext': + if storage_kind == "fulltext": return self._weave.get_text(self.key[-1]) - elif storage_kind in ('chunked', 'lines'): + elif storage_kind in ("chunked", "lines"): return self._weave.get_lines(self.key[-1]) else: - raise UnavailableRepresentation(self.key, storage_kind, 'fulltext') + raise UnavailableRepresentation(self.key, storage_kind, "fulltext") def iter_bytes_as(self, storage_kind): - if storage_kind in ('chunked', 'lines'): + if storage_kind in ("chunked", "lines"): return iter(self._weave.get_lines(self.key[-1])) else: - raise UnavailableRepresentation(self.key, storage_kind, 'fulltext') + raise UnavailableRepresentation(self.key, storage_kind, "fulltext") class Weave(VersionedFile): @@ -270,11 +264,25 @@ class Weave(VersionedFile): Set by read_weave. """ - __slots__ = ['_weave', '_parents', '_sha1s', '_names', '_name_map', - '_weave_name', '_matcher', '_allow_reserved'] - - def __init__(self, weave_name=None, access_mode='w', matcher=None, - get_scope=None, allow_reserved=False): + __slots__ = [ + "_weave", + "_parents", + "_sha1s", + "_names", + "_name_map", + "_weave_name", + "_matcher", + "_allow_reserved", + ] + + def __init__( + self, + weave_name=None, + access_mode="w", + matcher=None, + get_scope=None, + allow_reserved=False, + ): """Create a weave. :param get_scope: A callable that returns an opaque object to be used @@ -293,8 +301,10 @@ def __init__(self, weave_name=None, access_mode='w', matcher=None, else: self._matcher = matcher if get_scope is None: + def get_scope(): return None + self._get_scope = get_scope self._scope = get_scope() self._access_mode = access_mode @@ -307,7 +317,7 @@ def _check_write_ok(self): """Is the versioned file marked as 'finished' ? 
Raise if it is.""" if self._get_scope() != self._scope: raise errors.OutSideTransaction() - if self._access_mode != 'w': + if self._access_mode != "w": raise errors.ReadOnlyObjectDirtiedError(self) def copy(self): @@ -327,9 +337,11 @@ def copy(self): def __eq__(self, other): if not isinstance(other, Weave): return False - return self._parents == other._parents \ - and self._weave == other._weave \ + return ( + self._parents == other._parents + and self._weave == other._weave and self._sha1s == other._sha1s + ) def __ne__(self, other): return not self.__eq__(other) @@ -352,7 +364,7 @@ def versions(self): def has_version(self, version_id): """See VersionedFile.has_version.""" - return (version_id in self._name_map) + return version_id in self._name_map __contains__ = has_version @@ -370,13 +382,14 @@ def get_record_stream(self, versions, ordering, include_delta_closure): valid until the iterator is advanced. """ from .. import tsort + versions = [version[-1] for version in versions] - if ordering == 'topological': + if ordering == "topological": parents = self.get_parent_map(versions) new_versions = tsort.topo_sort(parents) new_versions.extend(set(versions).difference(set(parents))) versions = new_versions - elif ordering == 'groupcompress': + elif ordering == "groupcompress": parents = self.get_parent_map(versions) new_versions = sort_groupcompress(parents) new_versions.extend(set(versions).difference(set(parents))) @@ -396,8 +409,8 @@ def get_parent_map(self, version_ids): else: try: parents = tuple( - map(self._idx_to_name, - self._parents[self._lookup(version_id)])) + map(self._idx_to_name, self._parents[self._lookup(version_id)]) + ) except RevisionNotPresent: continue result[version_id] = parents @@ -416,23 +429,21 @@ def insert_record_stream(self, stream): adapters = {} for record in stream: # Raise an error when a record is missing. - if record.storage_kind == 'absent': + if record.storage_kind == "absent": raise RevisionNotPresent([record.key[0]], self) # adapt to non-tuple interface parents = [parent[0] for parent in record.parents] - if record.storage_kind in ('fulltext', 'chunked', 'lines'): - self.add_lines( - record.key[0], parents, - record.get_bytes_as('lines')) + if record.storage_kind in ("fulltext", "chunked", "lines"): + self.add_lines(record.key[0], parents, record.get_bytes_as("lines")) else: - adapter_key = record.storage_kind, 'lines' + adapter_key = record.storage_kind, "lines" try: adapter = adapters[adapter_key] except KeyError: adapter_factory = adapter_registry.get(adapter_key) adapter = adapter_factory(self) adapters[adapter_key] = adapter - lines = adapter.get_bytes(record, 'lines') + lines = adapter.get_bytes(record, "lines") try: self.add_lines(record.key[0], parents, lines) except RevisionAlreadyPresent: @@ -444,17 +455,25 @@ def _check_repeated_add(self, name, parents, text, sha1): If it is, return the (old) index; otherwise raise an exception. 
""" idx = self._lookup(name) - if sorted(self._parents[idx]) != sorted(parents) \ - or sha1 != self._sha1s[idx]: + if sorted(self._parents[idx]) != sorted(parents) or sha1 != self._sha1s[idx]: raise RevisionAlreadyPresent(name, self._weave_name) return idx - def _add_lines(self, version_id, parents, lines, parent_texts, - left_matching_blocks, nostore_sha, random_id, - check_content): + def _add_lines( + self, + version_id, + parents, + lines, + parent_texts, + left_matching_blocks, + nostore_sha, + random_id, + check_content, + ): """See VersionedFile.add_lines.""" - idx = self._add(version_id, lines, list(map(self._lookup, parents)), - nostore_sha=nostore_sha) + idx = self._add( + version_id, lines, list(map(self._lookup, parents)), nostore_sha=nostore_sha + ) return sha_strings(lines), sum(map(len, lines)), idx def _add(self, version_id, lines, parents, sha1=None, nostore_sha=None): @@ -504,9 +523,9 @@ def _add(self, version_id, lines, parents, sha1=None, nostore_sha=None): # even more specially, if we're adding an empty text we # need do nothing at all. if lines: - self._weave.append((b'{', new_version)) + self._weave.append((b"{", new_version)) self._weave.extend(lines) - self._weave.append((b'}', None)) + self._weave.append((b"}", None)) return new_version if len(parents) == 1: @@ -517,7 +536,6 @@ def _add(self, version_id, lines, parents, sha1=None, nostore_sha=None): ancestors = self._inclusions(parents) - # basis a list of (origin, lineno, line) basis_lineno = [] basis_lines = [] @@ -551,15 +569,15 @@ def _add(self, version_id, lines, parents, sha1=None, nostore_sha=None): # i1,i2 are given in offsets within basis_lines; we need to map # them back to offsets within the entire weave print 'raw match', # tag, i1, i2, j1, j2 - if tag == 'equal': + if tag == "equal": continue i1 = basis_lineno[i1] i2 = basis_lineno[i2] # the deletion and insertion are handled separately. # first delete the region. if i1 != i2: - self._weave.insert(i1 + offset, (b'[', new_version)) - self._weave.insert(i2 + offset + 1, (b']', new_version)) + self._weave.insert(i1 + offset, (b"[", new_version)) + self._weave.insert(i2 + offset + 1, (b"]", new_version)) offset += 2 if j1 != j2: @@ -567,9 +585,7 @@ def _add(self, version_id, lines, parents, sha1=None, nostore_sha=None): # i2; we want to insert after this region to make sure # we don't destroy ourselves i = i2 + offset - self._weave[i:i] = ([(b'{', new_version)] + - lines[j1:j2] + - [(b'}', None)]) + self._weave[i:i] = [(b"{", new_version)] + lines[j1:j2] + [(b"}", None)] offset += 2 + (j2 - j1) return new_version @@ -613,21 +629,21 @@ def annotate(self, version_id): The index indicates when the line originated in the weave. 
""" incls = [self._lookup(version_id)] - return [(self._idx_to_name(origin), text) for origin, lineno, text in - self._extract(incls)] + return [ + (self._idx_to_name(origin), text) + for origin, lineno, text in self._extract(incls) + ] - def iter_lines_added_or_present_in_versions(self, version_ids=None, - pb=None): + def iter_lines_added_or_present_in_versions(self, version_ids=None, pb=None): """See VersionedFile.iter_lines_added_or_present_in_versions().""" if version_ids is None: version_ids = self.versions() version_ids = set(version_ids) - for _lineno, inserted, _deletes, line in self._walk_internal( - version_ids): + for _lineno, inserted, _deletes, line in self._walk_internal(version_ids): if inserted not in version_ids: continue - if not line.endswith(b'\n'): - yield line + b'\n', inserted + if not line.endswith(b"\n"): + yield line + b"\n", inserted else: yield line, inserted @@ -636,31 +652,31 @@ def _walk_internal(self, version_ids=None): istack = [] dset = set() - lineno = 0 # line of weave, 0-based + lineno = 0 # line of weave, 0-based for l in self._weave: if l.__class__ == tuple: c, v = l - if c == b'{': + if c == b"{": istack.append(self._names[v]) - elif c == b'}': + elif c == b"}": istack.pop() - elif c == b'[': + elif c == b"[": dset.add(self._names[v]) - elif c == b']': + elif c == b"]": dset.remove(self._names[v]) else: - raise WeaveFormatError(f'unexpected instruction {v!r}') + raise WeaveFormatError(f"unexpected instruction {v!r}") else: yield lineno, istack[-1], frozenset(dset), l lineno += 1 if istack: - raise WeaveFormatError("unclosed insertion blocks " - "at end of weave: %s" % istack) - if dset: raise WeaveFormatError( - f"unclosed deletion blocks at end of weave: {dset}") + "unclosed insertion blocks " "at end of weave: %s" % istack + ) + if dset: + raise WeaveFormatError(f"unclosed deletion blocks at end of weave: {dset}") def plan_merge(self, ver_a, ver_b): """Return pseudo-annotation indicating how the two versions merge. @@ -674,38 +690,37 @@ def plan_merge(self, ver_a, ver_b): inc_b = self.get_ancestry([ver_b]) inc_c = inc_a & inc_b - for _lineno, insert, deleteset, line in self._walk_internal( - [ver_a, ver_b]): + for _lineno, insert, deleteset, line in self._walk_internal([ver_a, ver_b]): if deleteset & inc_c: # killed in parent; can't be in either a or b # not relevant to our work - yield 'killed-base', line + yield "killed-base", line elif insert in inc_c: # was inserted in base killed_a = bool(deleteset & inc_a) killed_b = bool(deleteset & inc_b) if killed_a and killed_b: - yield 'killed-both', line + yield "killed-both", line elif killed_a: - yield 'killed-a', line + yield "killed-a", line elif killed_b: - yield 'killed-b', line + yield "killed-b", line else: - yield 'unchanged', line + yield "unchanged", line elif insert in inc_a: if deleteset & inc_a: - yield 'ghost-a', line + yield "ghost-a", line else: # new in A; not in B - yield 'new-a', line + yield "new-a", line elif insert in inc_b: if deleteset & inc_b: - yield 'ghost-b', line + yield "ghost-b", line else: - yield 'new-b', line + yield "new-b", line else: # not in either revision - yield 'irrelevant', line + yield "irrelevant", line def _extract(self, versions): """Yield annotation of lines in included set. 
@@ -726,7 +741,7 @@ def _extract(self, versions): iset = set() dset = set() - lineno = 0 # line of weave, 0-based + lineno = 0 # line of weave, 0-based isactive = None @@ -760,32 +775,31 @@ def _extract(self, versions): if l.__class__ == tuple: c, v = l isactive = None - if c == b'{': + if c == b"{": istack.append(v) iset.add(v) - elif c == b'}': + elif c == b"}": iset.remove(istack.pop()) - elif c == b'[': + elif c == b"[": if v in included: dset.add(v) - elif c == b']': + elif c == b"]": if v in included: dset.remove(v) else: raise AssertionError() else: if isactive is None: - isactive = (not dset) and istack and ( - istack[-1] in included) + isactive = (not dset) and istack and (istack[-1] in included) if isactive: result.append((istack[-1], lineno, l)) lineno += 1 if istack: - raise WeaveFormatError("unclosed insertion blocks " - "at end of weave: %s" % istack) - if dset: raise WeaveFormatError( - f"unclosed deletion blocks at end of weave: {dset}") + "unclosed insertion blocks " "at end of weave: %s" % istack + ) + if dset: + raise WeaveFormatError(f"unclosed deletion blocks at end of weave: {dset}") return result def _maybe_lookup(self, name_or_index): @@ -803,14 +817,15 @@ def _maybe_lookup(self, name_or_index): def get_lines(self, version_id): """See VersionedFile.get_lines().""" int_index = self._maybe_lookup(version_id) - result = [line for (origin, lineno, line) - in self._extract([int_index])] + result = [line for (origin, lineno, line) in self._extract([int_index])] expected_sha1 = self._sha1s[int_index] measured_sha1 = sha_strings(result) if measured_sha1 != expected_sha1: raise WeaveInvalidChecksum( - 'file {}, revision {}, expected: {}, measured {}'.format(self._weave_name, version_id, - expected_sha1, measured_sha1)) + "file {}, revision {}, expected: {}, measured {}".format( + self._weave_name, version_id, expected_sha1, measured_sha1 + ) + ) return result def get_sha1s(self, version_ids): @@ -837,7 +852,8 @@ def check(self, progress_bar=None): if inclusions[-1] >= version: raise WeaveFormatError( "invalid included version %d for index %d" - % (inclusions[-1], version)) + % (inclusions[-1], version) + ) # try extracting all versions; parallel extraction is used nv = self.num_versions() @@ -855,16 +871,15 @@ def check(self, progress_bar=None): new_inc.update(inclusions[self._idx_to_name(p)]) if new_inc != self.get_ancestry(name): - raise AssertionError( - f'failed {new_inc} != {self.get_ancestry(name)}') + raise AssertionError(f"failed {new_inc} != {self.get_ancestry(name)}") inclusions[name] = new_inc nlines = len(self._weave) - update_text = 'checking weave' + update_text = "checking weave" if self._weave_name: short_name = os.path.basename(self._weave_name) - update_text = f'checking {short_name}' + update_text = f"checking {short_name}" update_text = update_text[:25] for lineno, insert, deleteset, line in self._walk_internal(): @@ -875,8 +890,7 @@ def check(self, progress_bar=None): # The active inclusion must be an ancestor, # and no ancestors must have deleted this line, # because we don't support resurrection. 
- if ((insert in name_inclusions) and - not (deleteset & name_inclusions)): + if (insert in name_inclusions) and not (deleteset & name_inclusions): sha1s[name].update(line) for i in range(nv): @@ -886,7 +900,8 @@ def check(self, progress_bar=None): if hd != expected: raise WeaveInvalidChecksum( f"mismatched sha1 for version {version}: " - f"got {hd}, expected {expected}") + f"got {hd}, expected {expected}" + ) # TODO: check insertions are properly nested, that there are # no lines outside of insertion blocks, that deletions are @@ -899,7 +914,9 @@ def _imported_parents(self, other, other_idx): parent_name = other._names[parent_idx] if parent_name not in self._name_map: # should not be possible - raise WeaveError(f"missing parent {{{parent_name}}} of {{{other._name_map[other_idx]}}} in {self!r}") + raise WeaveError( + f"missing parent {{{parent_name}}} of {{{other._name_map[other_idx]}}} in {self!r}" + ) new_parents.append(self._name_map[parent_name]) return new_parents @@ -926,10 +943,10 @@ def _check_version_consistent(self, other, other_idx, name): n2 = {other._names[i] for i in other_parents} if not self._compatible_parents(n1, n2): raise WeaveParentMismatch( - "inconsistent parents " - f"for version {{{name}}}: {n1} vs {n2}") + "inconsistent parents " f"for version {{{name}}}: {n1} vs {n2}" + ) else: - return True # ok! + return True # ok! else: return False @@ -946,23 +963,29 @@ def _reweave(self, other, pb, msg): def _copy_weave_content(self, otherweave): """Adsorb the content from otherweave.""" for attr in self.__slots__: - if attr != '_weave_name': + if attr != "_weave_name": setattr(self, attr, copy(getattr(otherweave, attr))) class WeaveFile(Weave): """A WeaveFile represents a Weave on disk and writes on change.""" - WEAVE_SUFFIX = '.weave' - - def __init__(self, name, transport, filemode=None, create=False, - access_mode='w', get_scope=None): + WEAVE_SUFFIX = ".weave" + + def __init__( + self, + name, + transport, + filemode=None, + create=False, + access_mode="w", + get_scope=None, + ): """Create a WeaveFile. :param create: If not True, only open an existing knit. """ - super().__init__(name, access_mode, get_scope=get_scope, - allow_reserved=False) + super().__init__(name, access_mode, get_scope=get_scope, allow_reserved=False) self._transport = transport self._filemode = filemode try: @@ -974,14 +997,29 @@ def __init__(self, name, transport, filemode=None, create=False, # new file, save it self._save() - def _add_lines(self, version_id, parents, lines, parent_texts, - left_matching_blocks, nostore_sha, random_id, - check_content): + def _add_lines( + self, + version_id, + parents, + lines, + parent_texts, + left_matching_blocks, + nostore_sha, + random_id, + check_content, + ): """Add a version and save the weave.""" self.check_not_reserved_id(version_id) result = super()._add_lines( - version_id, parents, lines, parent_texts, left_matching_blocks, - nostore_sha, random_id, check_content) + version_id, + parents, + lines, + parent_texts, + left_matching_blocks, + nostore_sha, + random_id, + check_content, + ) self._save() return result @@ -1033,6 +1071,7 @@ def _reweave(wa, wb, pb=None, msg=None): :param msg: An optional message for the progress """ from .. 
import tsort + wr = Weave() # first determine combined parents of all versions # map from version name -> all parent names @@ -1042,7 +1081,7 @@ def _reweave(wa, wb, pb=None, msg=None): mutter("order to reweave: %r", order) if pb and not msg: - msg = 'reweave' + msg = "reweave" for idx, name in enumerate(order): if pb: @@ -1052,12 +1091,16 @@ def _reweave(wa, wb, pb=None, msg=None): if name in wb._name_map: lines_b = wb.get_lines(name) if lines != lines_b: - mutter('Weaves differ on content. rev_id {%s}', name) - mutter('weaves: %s, %s', wa._weave_name, wb._weave_name) + mutter("Weaves differ on content. rev_id {%s}", name) + mutter("weaves: %s, %s", wa._weave_name, wb._weave_name) import difflib - lines = list(difflib.unified_diff(lines, lines_b, - wa._weave_name, wb._weave_name)) - mutter('lines:\n%s', ''.join(lines)) + + lines = list( + difflib.unified_diff( + lines, lines_b, wa._weave_name, wb._weave_name + ) + ) + mutter("lines:\n%s", "".join(lines)) raise WeaveTextDiffers(name, wa, wb) else: lines = wb.get_lines(name) diff --git a/breezy/bzr/weavefile.py b/breezy/bzr/weavefile.py index 86640c43cb..6c8a900731 100644 --- a/breezy/bzr/weavefile.py +++ b/breezy/bzr/weavefile.py @@ -38,7 +38,7 @@ # an iterator returning the weave lines... We don't really need to # deserialize it into memory. -FORMAT_1 = b'# bzr weave file v5\n' +FORMAT_1 = b"# bzr weave file v5\n" def write_weave(weave, f, format=None): @@ -56,38 +56,39 @@ def write_weave_v5(weave, f): if included: # mininc = weave.minimal_parents(version) mininc = included - f.write(b'i ') - f.write(b' '.join(b'%d' % i for i in mininc)) - f.write(b'\n') + f.write(b"i ") + f.write(b" ".join(b"%d" % i for i in mininc)) + f.write(b"\n") else: - f.write(b'i\n') - f.write(b'1 ' + weave._sha1s[version] + b'\n') - f.write(b'n ' + weave._names[version] + b'\n') - f.write(b'\n') + f.write(b"i\n") + f.write(b"1 " + weave._sha1s[version] + b"\n") + f.write(b"n " + weave._names[version] + b"\n") + f.write(b"\n") - f.write(b'w\n') + f.write(b"w\n") for l in weave._weave: if isinstance(l, tuple): - if l[0] == b'}': - f.write(b'}\n') + if l[0] == b"}": + f.write(b"}\n") else: - f.write(l[0] + b' %d\n' % l[1]) + f.write(l[0] + b" %d\n" % l[1]) else: # text line if not l: - f.write(b', \n') - elif l.endswith(b'\n'): - f.write(b'. ' + l) + f.write(b", \n") + elif l.endswith(b"\n"): + f.write(b". " + l) else: - f.write(b', ' + l + b'\n') + f.write(b", " + l + b"\n") - f.write(b'W\n') + f.write(b"W\n") def read_weave(f): # FIXME: detect the weave type and dispatch from .weave import Weave - w = Weave(getattr(f, 'name', None)) + + w = Weave(getattr(f, "name", None)) _read_weave_v5(f, w) return w @@ -121,10 +122,10 @@ def _read_weave_v5(f, w): try: l = next(lines) except StopIteration as err: - raise WeaveFormatError('invalid weave file: no header') from err + raise WeaveFormatError("invalid weave file: no header") from err if l != FORMAT_1: - raise WeaveFormatError(f'invalid weave file header: {l!r}') + raise WeaveFormatError(f"invalid weave file header: {l!r}") ver = 0 # read weave header. 
@@ -132,10 +133,10 @@ def _read_weave_v5(f, w): try: l = next(lines) except StopIteration as err: - raise WeaveFormatError('unexpected end of file') from err - if l[0:1] == b'i': + raise WeaveFormatError("unexpected end of file") from err + if l[0:1] == b"i": if len(l) > 2: - w._parents.append(list(map(int, l[2:].split(b' ')))) + w._parents.append(list(map(int, l[2:].split(b" ")))) else: w._parents.append([]) l = next(lines)[:-1] @@ -146,25 +147,25 @@ def _read_weave_v5(f, w): w._name_map[name] = ver l = next(lines) ver += 1 - elif l == b'w\n': + elif l == b"w\n": break else: - raise WeaveFormatError(f'unexpected line {l!r}') + raise WeaveFormatError(f"unexpected line {l!r}") # read weave body while True: try: l = next(lines) except StopIteration as err: - raise WeaveFormatError('unexpected end of file') from err - if l == b'W\n': + raise WeaveFormatError("unexpected end of file") from err + if l == b"W\n": break - elif b'. ' == l[0:2]: + elif b". " == l[0:2]: w._weave.append(l[2:]) # include newline - elif b', ' == l[0:2]: - w._weave.append(l[2:-1]) # exclude newline - elif l == b'}\n': - w._weave.append((b'}', None)) + elif b", " == l[0:2]: + w._weave.append(l[2:-1]) # exclude newline + elif l == b"}\n": + w._weave.append((b"}", None)) else: - w._weave.append((l[0:1], int(l[2:].decode('ascii')))) + w._weave.append((l[0:1], int(l[2:].decode("ascii")))) return w diff --git a/breezy/bzr/workingtree.py b/breezy/bzr/workingtree.py index d8c268f866..f5fdf1e2b2 100644 --- a/breezy/bzr/workingtree.py +++ b/breezy/bzr/workingtree.py @@ -46,7 +46,9 @@ # is guaranteed to be registered. from . import bzrdir -lazy_import.lazy_import(globals(), """ +lazy_import.lazy_import( + globals(), + """ from breezy import ( cache_utf8, conflicts as _mod_conflicts, @@ -60,7 +62,8 @@ inventory, serializer, ) -""") +""", +) from .. import errors, osutils from .. import revision as _mod_revision @@ -86,13 +89,14 @@ # impossible as there is no clear relationship between the working tree format # and the conflict list file format. CONFLICT_HEADER_1 = b"BZR conflict list format 1" -ERROR_PATH_NOT_FOUND = 3 # WindowsError errno code, equivalent to ENOENT +ERROR_PATH_NOT_FOUND = 3 # WindowsError errno code, equivalent to ENOENT class InventoryModified(errors.InternalBzrError): - - _fmt = ("The current inventory for the tree %(tree)r has been modified," - " so a clean inventory cannot be read without data loss.") + _fmt = ( + "The current inventory for the tree %(tree)r has been modified," + " so a clean inventory cannot be read without data loss." + ) def __init__(self, tree): self.tree = tree @@ -108,21 +112,28 @@ class InventoryWorkingTree(WorkingTree, MutableInventoryTree): not listed in the Inventory and vice versa. """ - def __init__(self, basedir='.', - branch=None, - _inventory=None, - _control_files=None, - _internal=False, - _format=None, - _controldir=None): + def __init__( + self, + basedir=".", + branch=None, + _inventory=None, + _control_files=None, + _internal=False, + _format=None, + _controldir=None, + ): """Construct a InventoryWorkingTree instance. This is not a public API. :param branch: A branch to override probing for the branch. 
""" super().__init__( - basedir=basedir, branch=branch, - _transport=_control_files._transport, _internal=_internal, - _format=_format, _controldir=_controldir) + basedir=basedir, + branch=branch, + _transport=_control_files._transport, + _internal=_internal, + _format=_format, + _controldir=_controldir, + ) self._control_files = _control_files self._detect_case_handling() @@ -163,15 +174,14 @@ def _detect_case_handling(self): def transform(self, pb=None): from .transform import InventoryTreeTransform + return InventoryTreeTransform(self, pb=pb) def _setup_directory_is_tree_reference(self): if self._branch.repository._format.supports_tree_reference: - self._directory_is_tree_reference = \ - self._directory_may_be_tree_reference + self._directory_is_tree_reference = self._directory_may_be_tree_reference else: - self._directory_is_tree_reference = \ - self._directory_is_never_tree_reference + self._directory_is_tree_reference = self._directory_is_never_tree_reference def _directory_is_never_tree_reference(self, relpath): return False @@ -193,11 +203,12 @@ def _directory_may_be_tree_reference(self, relpath): def _serialize(self, inventory, out_file): from .xml5 import inventory_serializer_v5 - inventory_serializer_v5.write_inventory( - self._inventory, out_file, working=True) + + inventory_serializer_v5.write_inventory(self._inventory, out_file, working=True) def _deserialize(self, in_file): from .xml5 import inventory_serializer_v5 + return inventory_serializer_v5.read_inventory(in_file) def break_lock(self): @@ -283,18 +294,19 @@ def set_inventory(self, new_inventory_list): InventoryFile, InventoryLink, ) + with self.lock_tree_write(): - inv = Inventory(self.path2id('')) + inv = Inventory(self.path2id("")) for path, file_id, parent, kind in new_inventory_list: name = os.path.basename(path) if name == "": continue # fixme, there should be a factory function inv,add_?? 
- if kind == 'directory': + if kind == "directory": inv.add(InventoryDirectory(file_id, name, parent)) - elif kind == 'file': + elif kind == "file": inv.add(InventoryFile(file_id, name, parent)) - elif kind == 'symlink': + elif kind == "symlink": inv.add(InventoryLink(file_id, name, parent)) else: raise errors.BzrError(f"unknown kind {kind!r}") @@ -303,14 +315,13 @@ def set_inventory(self, new_inventory_list): def _write_basis_inventory(self, xml): """Write the basis inventory XML to the basis-inventory file.""" path = self._basis_inventory_name() - sio = BytesIO(b''.join(xml)) - self._transport.put_file(path, sio, - mode=self.controldir._get_file_mode()) + sio = BytesIO(b"".join(xml)) + self._transport.put_file(path, sio, mode=self.controldir._get_file_mode()) def _reset_data(self): """Reset transient data that cannot be revalidated.""" self._inventory_is_modified = False - with self._transport.get('inventory') as f: + with self._transport.get("inventory") as f: result = self._deserialize(f) self._set_inventory(result, dirty=False) @@ -319,6 +330,7 @@ def store_uncommitted(self): with self.lock_write(): target_tree = self.basis_tree() from ..shelf import ShelfCreator + shelf_creator = ShelfCreator(self, target_tree) try: if not shelf_creator.shelve_all(): @@ -327,8 +339,7 @@ def store_uncommitted(self): shelf_creator.transform() finally: shelf_creator.finalize() - note('Uncommitted changes stored in branch "%s".', - self.branch.nick) + note('Uncommitted changes stored in branch "%s".', self.branch.nick) def restore_uncommitted(self): """Restore uncommitted changes from the branch into the tree.""" @@ -347,6 +358,7 @@ def restore_uncommitted(self): def get_shelf_manager(self): """Return the ShelfManager for this WorkingTree.""" from ..shelf import ShelfManager + return ShelfManager(self, self._transport) def set_root_id(self, file_id): @@ -354,8 +366,7 @@ def set_root_id(self, file_id): with self.lock_tree_write(): # for compatibility if file_id is None: - raise ValueError( - 'WorkingTree.set_root_id with fileid=None') + raise ValueError("WorkingTree.set_root_id with fileid=None") self._set_root_id(file_id) def _set_root_id(self, file_id): @@ -373,8 +384,7 @@ def _set_root_id(self, file_id): # unlinkit from the byid index inv.change_root_id(file_id) - def remove(self, files, verbose=False, to_file=None, keep_files=True, - force=False): + def remove(self, files, verbose=False, to_file=None, keep_files=True, force=False): """Remove nominated files from the working tree metadata. :files: File paths relative to the basedir. @@ -408,7 +418,6 @@ def recurse_directory_to_add_files(directory): files_to_backup.append(relpath) with self.lock_tree_write(): - for filename in files: # Get file name into canonical form. 
abspath = self.abspath(filename) @@ -429,21 +438,26 @@ def recurse_directory_to_add_files(directory): # Bail out if we are going to delete files we shouldn't if not keep_files and not force: for change in self.iter_changes( - self.basis_tree(), include_unchanged=True, - require_versioned=False, want_unversioned=True, - specific_files=files): + self.basis_tree(), + include_unchanged=True, + require_versioned=False, + want_unversioned=True, + specific_files=files, + ): if change.versioned[0] is False: # The record is unknown or newly added files_to_backup.append(change.path[1]) - elif (change.changed_content and (change.kind[1] is not None) - and osutils.is_inside_any(files, change.path[1])): + elif ( + change.changed_content + and (change.kind[1] is not None) + and osutils.is_inside_any(files, change.path[1]) + ): # Versioned and changed, but not deleted, and still # in one of the dirs to be deleted. files_to_backup.append(change.path[1]) def backup(file_to_backup): - backup_name = self.controldir._available_backup_name( - file_to_backup) + backup_name = self.controldir._available_backup_name(file_to_backup) osutils.rename(abs_path, self.abspath(backup_name)) return f"removed {file_to_backup} (but kept a copy: {backup_name})" @@ -459,14 +473,13 @@ def backup(file_to_backup): # having removed it, it must be either ignored or # unknown if self.is_ignored(f): - new_status = 'I' + new_status = "I" else: - new_status = '?' + new_status = "?" # XXX: Really should be a more abstract reporter # interface kind_ch = osutils.kind_marker(self.kind(f)) - to_file.write( - new_status + ' ' + f + kind_ch + '\n') + to_file.write(new_status + " " + f + kind_ch + "\n") # Unversion file inv_delta.append((f, None, fid, None)) message = f"removed {f}" @@ -474,8 +487,7 @@ def backup(file_to_backup): if not keep_files: abs_path = self.abspath(f) if osutils.lexists(abs_path): - if (osutils.isdir(abs_path) - and len(os.listdir(abs_path)) > 0): + if osutils.isdir(abs_path) and len(os.listdir(abs_path)) > 0: if force: osutils.rmtree(abs_path) message = f"deleted {f}" @@ -512,8 +524,9 @@ def set_parent_trees(self, parents_list, allow_leftmost_as_ghost=False): _mod_revision.check_not_reserved_id(revision_id) with self.lock_tree_write(): - self._check_parents_for_ghosts(parent_ids, - allow_leftmost_as_ghost=allow_leftmost_as_ghost) + self._check_parents_for_ghosts( + parent_ids, allow_leftmost_as_ghost=allow_leftmost_as_ghost + ) parent_ids = self._filter_parent_ids_by_ancestry(parent_ids) @@ -530,8 +543,7 @@ def set_parent_trees(self, parents_list, allow_leftmost_as_ghost=False): self._cache_basis_inventory(leftmost_parent_id) else: inv = leftmost_parent_tree.root_inventory - xml = self._create_basis_xml_from_inventory( - leftmost_parent_id, inv) + xml = self._create_basis_xml_from_inventory(leftmost_parent_id, inv) self._write_basis_inventory(xml) self._set_merges_from_parent_ids(parent_ids) @@ -553,46 +565,47 @@ def _cache_basis_inventory(self, new_revision): # contain a '"'. 
lines = self.branch.repository._get_inventory_xml(new_revision) firstline = lines[0] - if (b'revision_id="' not in firstline - or b'format="7"' not in firstline): + if b'revision_id="' not in firstline or b'format="7"' not in firstline: inv = self.branch.repository._inventory_serializer.read_inventory_from_lines( - lines, new_revision) + lines, new_revision + ) lines = self._create_basis_xml_from_inventory(new_revision, inv) self._write_basis_inventory(lines) except (errors.NoSuchRevision, errors.RevisionNotPresent): pass def _basis_inventory_name(self): - return 'basis-inventory-cache' + return "basis-inventory-cache" def _create_basis_xml_from_inventory(self, revision_id, inventory): """Create the text that will be saved in basis-inventory.""" inventory.revision_id = revision_id from .xml7 import inventory_serializer_v7 + return inventory_serializer_v7.write_inventory_to_lines(inventory) def set_conflicts(self, conflicts): conflict_list = _mod_bzr_conflicts.ConflictList(conflicts) with self.lock_tree_write(): - self._put_rio('conflicts', conflict_list.to_stanzas(), - CONFLICT_HEADER_1) + self._put_rio("conflicts", conflict_list.to_stanzas(), CONFLICT_HEADER_1) def add_conflicts(self, new_conflicts): with self.lock_tree_write(): conflict_set = set(self.conflicts()) conflict_set.update(set(new_conflicts)) self.set_conflicts( - sorted(conflict_set, key=_mod_bzr_conflicts.Conflict.sort_key)) + sorted(conflict_set, key=_mod_bzr_conflicts.Conflict.sort_key) + ) def conflicts(self): with self.lock_read(): try: - confile = self._transport.get('conflicts') + confile = self._transport.get("conflicts") except _mod_transport.NoSuchFile: return _mod_bzr_conflicts.ConflictList() try: try: - if next(confile) != CONFLICT_HEADER_1 + b'\n': + if next(confile) != CONFLICT_HEADER_1 + b"\n": raise errors.ConflictFormatError() except StopIteration as err: raise errors.ConflictFormatError() from err @@ -606,7 +619,7 @@ def get_ignore_list(self): Cached in the Tree object after the first call. """ - ignoreset = getattr(self, '_ignoreset', None) + ignoreset = getattr(self, "_ignoreset", None) if ignoreset is not None: return ignoreset @@ -639,9 +652,8 @@ def is_ignored(self, filename): be ignored, otherwise None. So this can simply be used as a boolean if desired. 
""" - if getattr(self, '_ignoreglobster', None) is None: - self._ignoreglobster = globbing.ExceptionGlobster( - self.get_ignore_list()) + if getattr(self, "_ignoreglobster", None) is None: + self._ignoreglobster = globbing.ExceptionGlobster(self.get_ignore_list()) return self._ignoreglobster.match(filename) def read_basis_inventory(self): @@ -662,7 +674,7 @@ def read_working_inventory(self): with self.lock_read(): if self._inventory_is_modified: raise InventoryModified(self) - with self._transport.get('inventory') as f: + with self._transport.get("inventory") as f: result = self._deserialize(f) self._set_inventory(result, dirty=False) return result @@ -704,10 +716,9 @@ def _check(self, references): with self.lock_read(): tree_basis = self.basis_tree() with tree_basis.lock_read(): - repo_basis = references[('trees', self.last_revision())] + repo_basis = references[("trees", self.last_revision())] if len(list(repo_basis.iter_changes(tree_basis))) > 0: - raise errors.BzrCheckError( - "Mismatched basis inventory content.") + raise errors.BzrCheckError("Mismatched basis inventory content.") self._validate() def check_state(self): @@ -717,7 +728,7 @@ def check_state(self): refs = {} for ref in check_refs: kind, value = ref - if kind == 'trees': + if kind == "trees": refs[ref] = self.branch.repository.revision_tree(value) self._check(refs) @@ -731,8 +742,7 @@ def reset_state(self, revision_ids=None): if revision_ids is None: revision_ids = self.get_parent_ids() if not revision_ids: - rt = self.branch.repository.revision_tree( - _mod_revision.NULL_REVISION) + rt = self.branch.repository.revision_tree(_mod_revision.NULL_REVISION) else: rt = self.branch.repository.revision_tree(revision_ids[0]) self._write_inventory(rt.root_inventory) @@ -741,13 +751,14 @@ def reset_state(self, revision_ids=None): def flush(self): """Write the in memory inventory to disk.""" # TODO: Maybe this should only write on dirty ? - if self._control_files._lock_mode != 'w': + if self._control_files._lock_mode != "w": raise errors.NotWriteLocked(self) sio = BytesIO() self._serialize(self._inventory, sio) sio.seek(0) - self._transport.put_file('inventory', sio, - mode=self.controldir._get_file_mode()) + self._transport.put_file( + "inventory", sio, mode=self.controldir._get_file_mode() + ) self._inventory_is_modified = False def get_file_mtime(self, path): @@ -757,26 +768,27 @@ def get_file_mtime(self, path): except FileNotFoundError as err: raise _mod_transport.NoSuchFile(path) from err - def path_content_summary(self, path, _lstat=os.lstat, - _mapper=osutils.file_kind_from_stat_mode): + def path_content_summary( + self, path, _lstat=os.lstat, _mapper=osutils.file_kind_from_stat_mode + ): """See Tree.path_content_summary.""" abspath = self.abspath(path) try: stat_result = _lstat(abspath) except FileNotFoundError: - return ('missing', None, None, None) + return ("missing", None, None, None) kind = _mapper(stat_result.st_mode) - if kind == 'file': + if kind == "file": return self._file_content_summary(path, stat_result) - elif kind == 'directory': + elif kind == "directory": # perhaps it looks like a plain directory, but it's really a # reference. 
if self._directory_is_tree_reference(path): - kind = 'tree-reference' + kind = "tree-reference" return kind, None, None, None - elif kind == 'symlink': + elif kind == "symlink": target = osutils.readlink(abspath) - return ('symlink', None, None, target) + return ("symlink", None, None, target) else: return (kind, None, None, None) @@ -784,8 +796,7 @@ def _file_content_summary(self, path, stat_result): size = stat_result.st_size executable = self._is_executable_from_path_and_stat(path, stat_result) # try for a stat cache lookup - return ('file', size, executable, self._sha_from_stat( - path, stat_result)) + return ("file", size, executable, self._sha_from_stat(path, stat_result)) def _is_executable_from_path_and_stat_from_basis(self, path, stat_result): try: @@ -809,11 +820,9 @@ def is_executable(self, path): def _is_executable_from_path_and_stat(self, path, stat_result): if not self._supports_executable(): - return self._is_executable_from_path_and_stat_from_basis( - path, stat_result) + return self._is_executable_from_path_and_stat_from_basis(path, stat_result) else: - return self._is_executable_from_path_and_stat_from_stat( - path, stat_result) + return self._is_executable_from_path_and_stat_from_stat(path, stat_result) def _add(self, files, kinds, ids): """See MutableTree._add.""" @@ -837,7 +846,7 @@ def mkdir(self, path, file_id=None): file_id = generate_ids.gen_file_id(os.path.basename(path)) with self.lock_write(): os.mkdir(self.abspath(path)) - self.add([path], ['directory'], ids=[file_id]) + self.add([path], ["directory"], ids=[file_id]) return file_id def revision_tree(self, revision_id): @@ -849,20 +858,21 @@ def revision_tree(self, revision_id): pass else: from .xml7 import inventory_serializer_v7 + try: inv = inventory_serializer_v7.read_inventory_from_lines(xml_lines) # dont use the repository revision_tree api because we want # to supply the inventory. if inv.revision_id == revision_id: return InventoryRevisionTree( - self.branch.repository, inv, revision_id) + self.branch.repository, inv, revision_id + ) except serializer.BadInventoryFormat: pass # raise if there was no inventory, or if we read the wrong inventory. raise errors.NoSuchRevisionInTree(self, revision_id) - def annotate_iter(self, path, - default_revision=_mod_revision.CURRENT_REVISION): + def annotate_iter(self, path, default_revision=_mod_revision.CURRENT_REVISION): """See Tree.annotate_iter. This implementation will use the basis tree implementation if possible. @@ -881,15 +891,13 @@ def annotate_iter(self, path, try: parent_tree = self.revision_tree(parent_id) except errors.NoSuchRevisionInTree: - parent_tree = self.branch.repository.revision_tree( - parent_id) + parent_tree = self.branch.repository.revision_tree(parent_id) with parent_tree.lock_read(): - try: kind = parent_tree.kind(path) except _mod_transport.NoSuchFile: continue - if kind != 'file': + if kind != "file": # Note: this is slightly unnecessary, because symlinks # and directories have a "text" which is the empty # text, and we know that won't mess up annotations. 
But @@ -898,7 +906,8 @@ def annotate_iter(self, path, parent_path = parent_tree.id2path(file_id) parent_text_key = ( file_id, - parent_tree.get_file_revision(parent_path)) + parent_tree.get_file_revision(parent_path), + ) if parent_text_key not in maybe_file_parent_keys: maybe_file_parent_keys.append(parent_text_key) graph = self.branch.repository.get_file_graph() @@ -913,15 +922,17 @@ def annotate_iter(self, path, text = self.get_file_text(path) this_key = (file_id, default_revision) annotator.add_special_text(this_key, file_parent_keys, text) - annotations = [(key[-1], line) - for key, line in annotator.annotate_flat(this_key)] + annotations = [ + (key[-1], line) for key, line in annotator.annotate_flat(this_key) + ] return annotations def _put_rio(self, filename, stanzas, header): self._must_be_locked() my_file = osutils.IterableFile(_mod_rio.rio_iter(stanzas, header)) - self._transport.put_file(filename, my_file, - mode=self.controldir._get_file_mode()) + self._transport.put_file( + filename, my_file, mode=self.controldir._get_file_mode() + ) def set_merge_modified(self, modified_hashes): def iter_stanzas(): @@ -929,11 +940,12 @@ def iter_stanzas(): file_id = self.path2id(path) if file_id is None: continue - yield _mod_rio.Stanza(file_id=file_id.decode('utf8'), - hash=sha1.decode('ascii')) + yield _mod_rio.Stanza( + file_id=file_id.decode("utf8"), hash=sha1.decode("ascii") + ) + with self.lock_tree_write(): - self._put_rio('merge-hashes', iter_stanzas(), - MERGE_MODIFIED_HEADER_1) + self._put_rio("merge-hashes", iter_stanzas(), MERGE_MODIFIED_HEADER_1) def merge_modified(self): """Return a dictionary of files modified by a merge. @@ -947,13 +959,13 @@ def merge_modified(self): """ with self.lock_read(): try: - hashfile = self._transport.get('merge-hashes') + hashfile = self._transport.get("merge-hashes") except _mod_transport.NoSuchFile: return {} try: merge_hashes = {} try: - if next(hashfile) != MERGE_MODIFIED_HEADER_1 + b'\n': + if next(hashfile) != MERGE_MODIFIED_HEADER_1 + b"\n": raise errors.MergeModifiedFormatError() except StopIteration as err: raise errors.MergeModifiedFormatError() from err @@ -965,7 +977,7 @@ def merge_modified(self): path = self.id2path(file_id) except errors.NoSuchId: continue - text_hash = s.get("hash").encode('ascii') + text_hash = s.get("hash").encode("ascii") if text_hash == self.get_file_sha1(path): merge_hashes[path] = text_hash return merge_hashes @@ -974,24 +986,29 @@ def merge_modified(self): def subsume(self, other_tree): from .inventory import InventoryDirectory + def add_children(inventory, other_inventory, entry): for child_entry in other_inventory.get_children(entry.file_id).values(): inventory._byid[child_entry.file_id] = child_entry - if child_entry.kind == 'directory': + if child_entry.kind == "directory": add_children(inventory, other_inventory, child_entry) + with self.lock_write(): - if other_tree.path2id('') == self.path2id(''): - raise errors.BadSubsumeSource(self, other_tree, - 'Trees have the same root') + if other_tree.path2id("") == self.path2id(""): + raise errors.BadSubsumeSource( + self, other_tree, "Trees have the same root" + ) try: other_tree_path = self.relpath(other_tree.basedir) except errors.PathNotChild as err: raise errors.BadSubsumeSource( - self, other_tree, 'Tree is not contained by the other') from err + self, other_tree, "Tree is not contained by the other" + ) from err new_root_parent = self.path2id(osutils.dirname(other_tree_path)) if new_root_parent is None: raise errors.BadSubsumeSource( - self, other_tree, 
'Parent directory is not versioned.') + self, other_tree, "Parent directory is not versioned." + ) # We need to ensure that the result of a fetch will have a # versionedfile for the other_tree root, and only fetching into # RepositoryKnit2 guarantees that. @@ -1001,7 +1018,8 @@ def add_children(inventory, other_inventory, entry): other_root = InventoryDirectory( other_tree.root_inventory.root.file_id, osutils.basename(other_tree_path), - new_root_parent) + new_root_parent, + ) self.root_inventory.add(other_root) add_children(self.root_inventory, other_tree.root_inventory, other_root) self._write_inventory(self.root_inventory) @@ -1018,6 +1036,7 @@ def extract(self, sub_path, format=None): A new branch will be created, relative to the path for this tree. """ from .inventory import InventoryDirectory + def mkdirs(path): segments = osutils.splitpath(path) transport = self.branch.controldir.root_transport @@ -1063,8 +1082,9 @@ def mkdirs(path): wt._write_inventory(child_inv) return wt - def list_files(self, include_root=False, from_dir=None, recursive=True, - recurse_nested=False): + def list_files( + self, include_root=False, from_dir=None, recursive=True, recurse_nested=False + ): """List all files as (path, class, kind, id, entry). Lists, but does not descend into unversioned directories. @@ -1078,21 +1098,20 @@ def list_files(self, include_root=False, from_dir=None, recursive=True, with contextlib.ExitStack() as exit_stack: exit_stack.enter_context(self.lock_read()) if from_dir is None and include_root is True: - yield ('', 'V', 'directory', self.root_inventory.root) + yield ("", "V", "directory", self.root_inventory.root) # Convert these into local objects to save lookup times pathjoin = osutils.pathjoin # transport.base ends in a slash, we want the piece # between the last two slashes - transport_base_dir = self.controldir.transport.base.rsplit( - '/', 2)[1] + transport_base_dir = self.controldir.transport.base.rsplit("/", 2)[1] fk_entries = { - 'directory': TreeDirectory, - 'file': TreeFile, - 'symlink': TreeLink, - 'tree-reference': TreeReference, - } + "directory": TreeDirectory, + "file": TreeFile, + "symlink": TreeLink, + "tree-reference": TreeReference, + } # directory file_id, relative path, absolute path, reverse sorted # children @@ -1111,10 +1130,15 @@ def list_files(self, include_root=False, from_dir=None, recursive=True, # use a deque and popleft to keep them sorted, or if we use a plain # list and just reverse() them. 
children = deque(children) - stack = [(from_inv, from_dir_id, '', from_dir_abspath, children)] + stack = [(from_inv, from_dir_id, "", from_dir_abspath, children)] while stack: - (inv, from_dir_id, from_dir_relpath, from_dir_abspath, - children) = stack[-1] + ( + inv, + from_dir_id, + from_dir_relpath, + from_dir_abspath, + children, + ) = stack[-1] while children: f = children.popleft() @@ -1130,20 +1154,20 @@ def list_files(self, include_root=False, from_dir=None, recursive=True, # end in a slash and 'f' doesn't begin with one, we can do # a string op, rather than the checks of pathjoin(), all # relative paths will have an extra slash at the beginning - fp = from_dir_relpath + '/' + f + fp = from_dir_relpath + "/" + f # absolute path - fap = from_dir_abspath + '/' + f + fap = from_dir_abspath + "/" + f dir_ie = inv.get_entry(from_dir_id) - if dir_ie.kind == 'directory': + if dir_ie.kind == "directory": f_ie = inv.get_child(dir_ie.file_id, f) else: f_ie = None if f_ie: - c = 'V' + c = "V" elif self.is_ignored(fp[1:]): - c = 'I' + c = "I" else: # we may not have found this file, because of a unicode # issue, or because the directory was actually a @@ -1151,32 +1175,32 @@ def list_files(self, include_root=False, from_dir=None, recursive=True, f_norm, can_access = osutils.normalized_filename(f) if f == f_norm or not can_access: # No change, so treat this file normally - c = '?' + c = "?" else: # this file can be accessed by a normalized path # check again if it is versioned # these lines are repeated here for performance f = f_norm - fp = from_dir_relpath + '/' + f - fap = from_dir_abspath + '/' + f + fp = from_dir_relpath + "/" + f + fap = from_dir_abspath + "/" + f f_ie = inv.get_child(from_dir_id, f) if f_ie: - c = 'V' + c = "V" elif self.is_ignored(fp[1:]): - c = 'I' + c = "I" else: - c = '?' + c = "?" 
fk = file_kind(fap) - if fk == 'directory' and self._directory_is_tree_reference(f): + if fk == "directory" and self._directory_is_tree_reference(f): if not recurse_nested: - fk = 'tree-reference' + fk = "tree-reference" else: subtree = self.get_nested_tree(f) exit_stack.enter_context(subtree.lock_read()) inv = subtree.root_inventory f_ie = inv.get_entry(f_ie.file_id) - fk = 'directory' + fk = "directory" # make a last minute entry if f_ie: @@ -1188,7 +1212,7 @@ def list_files(self, include_root=False, from_dir=None, recursive=True, yield fp[1:], c, fk, TreeEntry() continue - if fk != 'directory': + if fk != "directory": continue # But do this child first if recursing down @@ -1242,7 +1266,7 @@ def move(self, from_paths, to_dir=None, after=False): # check for deprecated use of signature if to_dir is None: - raise TypeError('You must supply a target directory') + raise TypeError("You must supply a target directory") # check destination directory if isinstance(from_paths, str): raise ValueError() @@ -1250,27 +1274,32 @@ def move(self, from_paths, to_dir=None, after=False): to_abs = self.abspath(to_dir) if not osutils.isdir(to_abs): raise errors.BzrMoveFailedError( - '', to_dir, errors.NotADirectory(to_abs)) + "", to_dir, errors.NotADirectory(to_abs) + ) if not self.has_filename(to_dir): raise errors.BzrMoveFailedError( - '', to_dir, errors.NotInWorkingDirectory(to_dir)) + "", to_dir, errors.NotInWorkingDirectory(to_dir) + ) to_inv, to_dir_id = self._path2inv_file_id(to_dir) if to_dir_id is None: raise errors.BzrMoveFailedError( - '', to_dir, errors.NotVersionedError(path=to_dir)) + "", to_dir, errors.NotVersionedError(path=to_dir) + ) to_dir_ie = to_inv.get_entry(to_dir_id) - if to_dir_ie.kind != 'directory': + if to_dir_ie.kind != "directory": raise errors.BzrMoveFailedError( - '', to_dir, errors.NotADirectory(to_abs)) + "", to_dir, errors.NotADirectory(to_abs) + ) # create rename entries and tuples for from_rel in from_paths: from_tail = osutils.splitpath(from_rel)[-1] from_inv, from_id = self._path2inv_file_id(from_rel) if from_id is None: - raise errors.BzrMoveFailedError(from_rel, to_dir, - errors.NotVersionedError(path=from_rel)) + raise errors.BzrMoveFailedError( + from_rel, to_dir, errors.NotVersionedError(path=from_rel) + ) from_entry = from_inv.get_entry(from_id) from_parent_id = from_entry.parent_id @@ -1280,8 +1309,10 @@ def move(self, from_paths, to_dir=None, after=False): from_id=from_id, from_tail=from_tail, from_parent_id=from_parent_id, - to_rel=to_rel, to_tail=from_tail, - to_parent_id=to_dir_id) + to_rel=to_rel, + to_tail=from_tail, + to_parent_id=to_dir_id, + ) rename_entries.append(rename_entry) rename_tuples.append((from_rel, to_rel)) @@ -1307,7 +1338,7 @@ def iter_child_entries(self, path): inv, ie = self._path2inv_ie(path) if inv is None: raise _mod_transport.NoSuchFile(path) - if ie.kind != 'directory': + if ie.kind != "directory": raise errors.NotADirectory(path) return inv.iter_sorted_children(ie.file_id) @@ -1348,8 +1379,8 @@ def rename_one(self, from_rel, to_rel, after=False): basis_from_inv, from_id = basis_tree._path2inv_file_id(from_rel) if from_id is None: raise errors.BzrRenameFailedError( - from_rel, to_rel, - errors.NotVersionedError(path=from_rel)) + from_rel, to_rel, errors.NotVersionedError(path=from_rel) + ) try: from_entry = from_inv.get_entry(from_id) except errors.NoSuchId: @@ -1367,8 +1398,10 @@ def rename_one(self, from_rel, to_rel, after=False): from_id=from_id, from_tail=from_tail, from_parent_id=from_parent_id, - to_rel=to_rel, to_tail=to_tail, 
- to_parent_id=to_dir_id) + to_rel=to_rel, + to_tail=to_tail, + to_parent_id=to_dir_id, + ) rename_entries.append(rename_entry) # determine which move mode to use. checks also for movability @@ -1378,24 +1411,40 @@ def rename_one(self, from_rel, to_rel, after=False): # is versioned if to_dir_id is None: raise errors.BzrMoveFailedError( - from_rel, to_rel, errors.NotVersionedError(path=to_dir)) + from_rel, to_rel, errors.NotVersionedError(path=to_dir) + ) # all checks done. now we can continue with our actual work - mutter('rename_one:\n' - ' from_id {%s}\n' - ' from_rel: %r\n' - ' to_rel: %r\n' - ' to_dir %r\n' - ' to_dir_id {%s}\n', - from_id, from_rel, to_rel, to_dir, to_dir_id) + mutter( + "rename_one:\n" + " from_id {%s}\n" + " from_rel: %r\n" + " to_rel: %r\n" + " to_dir %r\n" + " to_dir_id {%s}\n", + from_id, + from_rel, + to_rel, + to_dir, + to_dir_id, + ) self._move(rename_entries) self._write_inventory(to_inv) class _RenameEntry: - def __init__(self, from_rel, from_id, from_tail, from_parent_id, - to_rel, to_tail, to_parent_id, only_change_inv=False, - change_id=False): + def __init__( + self, + from_rel, + from_id, + from_tail, + from_parent_id, + to_rel, + to_tail, + to_parent_id, + only_change_inv=False, + change_id=False, + ): self.from_rel = from_rel self.from_id = from_id self.from_tail = from_tail @@ -1426,7 +1475,8 @@ def _determine_mv_mode(self, rename_entries, after=False): # check the inventory for source and destination if from_id is None: raise errors.BzrMoveFailedError( - from_rel, to_rel, errors.NotVersionedError(path=from_rel)) + from_rel, to_rel, errors.NotVersionedError(path=from_rel) + ) if to_id is not None: allowed = False # allow it with --after but only if dest is newly added @@ -1440,34 +1490,39 @@ def _determine_mv_mode(self, rename_entries, after=False): allowed = True if not allowed: raise errors.BzrMoveFailedError( - from_rel, to_rel, - errors.AlreadyVersionedError(path=to_rel)) + from_rel, to_rel, errors.AlreadyVersionedError(path=to_rel) + ) # try to determine the mode for rename (only change inv or change # inv and file system) if after: if not self.has_filename(to_rel): raise errors.BzrMoveFailedError( - from_rel, to_rel, + from_rel, + to_rel, _mod_transport.NoSuchFile( - path=to_rel, - extra="New file has not been created yet")) + path=to_rel, extra="New file has not been created yet" + ), + ) only_change_inv = True elif not self.has_filename(from_rel) and self.has_filename(to_rel): only_change_inv = True elif self.has_filename(from_rel) and not self.has_filename(to_rel): only_change_inv = False - elif (not self.case_sensitive and - from_rel.lower() == to_rel.lower() and - self.has_filename(from_rel)): + elif ( + not self.case_sensitive + and from_rel.lower() == to_rel.lower() + and self.has_filename(from_rel) + ): only_change_inv = False else: # something is wrong, so lets determine what exactly - if not self.has_filename(from_rel) and \ - not self.has_filename(to_rel): + if not self.has_filename(from_rel) and not self.has_filename(to_rel): raise errors.BzrRenameFailedError( - from_rel, to_rel, - errors.PathsDoNotExist(paths=(from_rel, to_rel))) + from_rel, + to_rel, + errors.PathsDoNotExist(paths=(from_rel, to_rel)), + ) else: raise errors.RenameFailedFilesExist(from_rel, to_rel) rename_entry.only_change_inv = only_change_inv @@ -1493,32 +1548,44 @@ def _rollback_move(self, moved): """Try to rollback a previous move in case of an filesystem error.""" for entry in moved: try: - self._move_entry(WorkingTree._RenameEntry( - entry.to_rel, 
entry.from_id, - entry.to_tail, entry.to_parent_id, entry.from_rel, - entry.from_tail, entry.from_parent_id, - entry.only_change_inv)) + self._move_entry( + WorkingTree._RenameEntry( + entry.to_rel, + entry.from_id, + entry.to_tail, + entry.to_parent_id, + entry.from_rel, + entry.from_tail, + entry.from_parent_id, + entry.only_change_inv, + ) + ) except errors.BzrMoveFailedError as e: raise errors.BzrMoveFailedError( - '', '', "Rollback failed." + "", + "", + "Rollback failed." " The working tree is in an inconsistent state." " Please consider doing a 'bzr revert'." - " Error message is: %s" % e) from e + " Error message is: %s" % e, + ) from e def _move_entry(self, entry): inv = self.root_inventory from_rel_abs = self.abspath(entry.from_rel) to_rel_abs = self.abspath(entry.to_rel) if from_rel_abs == to_rel_abs: - raise errors.BzrMoveFailedError(entry.from_rel, entry.to_rel, - "Source and target are identical.") + raise errors.BzrMoveFailedError( + entry.from_rel, entry.to_rel, "Source and target are identical." + ) if not entry.only_change_inv: try: osutils.rename(from_rel_abs, to_rel_abs) except OSError as e: raise errors.BzrMoveFailedError( - entry.from_rel, entry.to_rel, e[1]) from e + entry.from_rel, entry.to_rel, e[1] + ) from e if entry.change_id: to_id = inv.path2id(entry.to_rel) inv.remove_recursive_id(to_id) @@ -1570,7 +1637,7 @@ def extras(self): """ # TODO: Work from given directory downwards for path, dir_entry in self.iter_entries_by_dir(): - if dir_entry.kind != 'directory': + if dir_entry.kind != "directory": continue # mutter("search for unknowns in %r", path) dirabs = self.abspath(path) @@ -1587,8 +1654,7 @@ def extras(self): if self.controldir.is_control_filename(subf): continue if subf not in versioned_children: - (subf_norm, - can_access) = osutils.normalized_filename(subf) + (subf_norm, can_access) = osutils.normalized_filename(subf) if subf_norm != subf and can_access: if subf_norm not in versioned_children: fl.append(subf_norm) @@ -1615,7 +1681,7 @@ def walkdirs(self, prefix=""): depending on the tree implementation. """ disk_top = self.abspath(prefix) - if disk_top.endswith('/'): + if disk_top.endswith("/"): disk_top = disk_top[:-1] top_strip_len = len(disk_top) + 1 inventory_iterator = self._walkdirs(prefix) @@ -1634,23 +1700,30 @@ def walkdirs(self, prefix=""): inv_finished = True while not inv_finished or not disk_finished: if current_disk: - ((cur_disk_dir_relpath, cur_disk_dir_path_from_top), - cur_disk_dir_content) = current_disk + ( + (cur_disk_dir_relpath, cur_disk_dir_path_from_top), + cur_disk_dir_content, + ) = current_disk else: - ((cur_disk_dir_relpath, cur_disk_dir_path_from_top), - cur_disk_dir_content) = ((None, None), None) + ( + (cur_disk_dir_relpath, cur_disk_dir_path_from_top), + cur_disk_dir_content, + ) = ((None, None), None) if not disk_finished: # strip out .bzr dirs - if (cur_disk_dir_path_from_top[top_strip_len:] == '' - and len(cur_disk_dir_content) > 0): + if ( + cur_disk_dir_path_from_top[top_strip_len:] == "" + and len(cur_disk_dir_content) > 0 + ): # osutils.walkdirs can be made nicer - # yield the path-from-prefix rather than the pathjoined # value. 
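Note: _rollback_move above replays every completed _RenameEntry with source and target swapped, so a failure part-way through a multi-file move leaves the tree where it started. The same idea in miniature, using plain os.rename (illustrative only, not the Breezy code):

    import os

    def move_all(pairs):
        """Rename (src, dst) pairs; if one fails, undo the ones already done."""
        done = []
        try:
            for src, dst in pairs:
                os.rename(src, dst)
                done.append((src, dst))
        except OSError:
            for src, dst in reversed(done):
                os.rename(dst, src)    # best-effort rollback of completed renames
            raise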
- bzrdir_loc = bisect_left(cur_disk_dir_content, - ('.bzr', '.bzr')) - if (bzrdir_loc < len(cur_disk_dir_content) and - self.controldir.is_control_filename( - cur_disk_dir_content[bzrdir_loc][0])): + bzrdir_loc = bisect_left(cur_disk_dir_content, (".bzr", ".bzr")) + if bzrdir_loc < len( + cur_disk_dir_content + ) and self.controldir.is_control_filename( + cur_disk_dir_content[bzrdir_loc][0] + ): # we dont yield the contents of, or, .bzr itself. del cur_disk_dir_content[bzrdir_loc] if inv_finished: @@ -1660,14 +1733,16 @@ def walkdirs(self, prefix=""): # everything is missing direction = -1 else: - direction = ((current_inv[0][0] > cur_disk_dir_relpath) - - (current_inv[0][0] < cur_disk_dir_relpath)) + direction = (current_inv[0][0] > cur_disk_dir_relpath) - ( + current_inv[0][0] < cur_disk_dir_relpath + ) if direction > 0: # disk is before inventory - unknown - dirblock = [(relpath, basename, kind, stat, None) for - relpath, basename, kind, stat, top_path in - cur_disk_dir_content] + dirblock = [ + (relpath, basename, kind, stat, None) + for relpath, basename, kind, stat, top_path in cur_disk_dir_content + ] yield cur_disk_dir_relpath, dirblock try: current_disk = next(disk_iterator) @@ -1675,9 +1750,10 @@ def walkdirs(self, prefix=""): disk_finished = True elif direction < 0: # inventory is before disk - missing. - dirblock = [(relpath, basename, 'unknown', None, kind) - for relpath, basename, dkind, stat, fileid, kind in - current_inv[1]] + dirblock = [ + (relpath, basename, "unknown", None, kind) + for relpath, basename, dkind, stat, fileid, kind in current_inv[1] + ] yield current_inv[0][0], dirblock try: current_inv = next(inventory_iterator) @@ -1687,28 +1763,50 @@ def walkdirs(self, prefix=""): # versioned present directory # merge the inventory and disk data together dirblock = [] - for _relpath, subiterator in itertools.groupby(sorted( + for _relpath, subiterator in itertools.groupby( + sorted( current_inv[1] + cur_disk_dir_content, - key=operator.itemgetter(0)), operator.itemgetter(1)): + key=operator.itemgetter(0), + ), + operator.itemgetter(1), + ): path_elements = list(subiterator) if len(path_elements) == 2: inv_row, disk_row = path_elements # versioned, present file - dirblock.append((inv_row[0], - inv_row[1], disk_row[2], - disk_row[3], inv_row[5])) + dirblock.append( + ( + inv_row[0], + inv_row[1], + disk_row[2], + disk_row[3], + inv_row[5], + ) + ) elif len(path_elements[0]) == 5: # unknown disk file dirblock.append( - (path_elements[0][0], path_elements[0][1], - path_elements[0][2], path_elements[0][3], None)) + ( + path_elements[0][0], + path_elements[0][1], + path_elements[0][2], + path_elements[0][3], + None, + ) + ) elif len(path_elements[0]) == 6: # versioned, absent file. dirblock.append( - (path_elements[0][0], path_elements[0][1], - 'unknown', None, path_elements[0][5])) + ( + path_elements[0][0], + path_elements[0][1], + "unknown", + None, + path_elements[0][5], + ) + ) else: - raise NotImplementedError('unreachable code') + raise NotImplementedError("unreachable code") yield current_inv[0][0], dirblock try: current_inv = next(inventory_iterator) @@ -1729,29 +1827,36 @@ def _walkdirs(self, prefix=""): [(file1_path, file1_name, file1_kind, None, file1_id, file1_kind), ... 
]) """ - _directory = 'directory' + _directory = "directory" # get the root in the inventory inv, top_id = self._path2inv_file_id(prefix) if top_id is None: pending = [] else: - pending = [(prefix, '', _directory, None, top_id, None)] + pending = [(prefix, "", _directory, None, top_id, None)] while pending: dirblock = [] currentdir = pending.pop() # 0 - relpath, 1- basename, 2- kind, 3- stat, 4-id, 5-kind top_id = currentdir[4] if currentdir[0]: - relroot = currentdir[0] + '/' + relroot = currentdir[0] + "/" else: relroot = "" # FIXME: stash the node in pending entry = inv.get_entry(top_id) - if entry.kind == 'directory': + if entry.kind == "directory": for child in inv.iter_sorted_children(entry.file_id): - dirblock.append((relroot + child.name, child.name, child.kind, None, - child.file_id, child.kind - )) + dirblock.append( + ( + relroot + child.name, + child.name, + child.kind, + None, + child.file_id, + child.kind, + ) + ) yield (currentdir[0], entry.file_id), dirblock # push the user specified dirs from dirblock for dir in reversed(dirblock): @@ -1766,8 +1871,7 @@ def update_feature_flags(self, updated_flags): """ with self.lock_write(): self._format._update_feature_flags(updated_flags) - self.control_transport.put_bytes( - 'format', self._format.as_string()) + self.control_transport.put_bytes("format", self._format.as_string()) def _check_for_tree_references(self, iterator, recurse_nested, specific_files=None): """See if directories have become tree-references.""" @@ -1776,21 +1880,23 @@ def _check_for_tree_references(self, iterator, recurse_nested, specific_files=No if ie.parent_id in blocked_parent_ids: # This entry was pruned because one of its parents became a # TreeReference. If this is a directory, mark it as blocked. - if ie.kind == 'directory': + if ie.kind == "directory": blocked_parent_ids.add(ie.file_id) continue - if (ie.kind == 'directory' and ie.parent_id is not None and - self._directory_is_tree_reference(path)): - + if ( + ie.kind == "directory" + and ie.parent_id is not None + and self._directory_is_tree_reference(path) + ): # This InventoryDirectory needs to be a TreeReference ie = inventory.TreeReference(ie.file_id, ie.name, ie.parent_id) blocked_parent_ids.add(ie.file_id) - if ie.kind == 'tree-reference' and recurse_nested: + if ie.kind == "tree-reference" and recurse_nested: subtree = self.get_nested_tree(path) for subpath, ie in subtree.iter_entries_by_dir( - recurse_nested=recurse_nested, - specific_files=specific_files): + recurse_nested=recurse_nested, specific_files=specific_files + ): if subpath: full_subpath = osutils.pathjoin(path, subpath) else: @@ -1804,10 +1910,11 @@ def iter_entries_by_dir(self, specific_files=None, recurse_nested=False): # The only trick here is that if we supports_tree_reference then we # need to detect if a directory becomes a tree-reference. iterator = super(WorkingTree, self).iter_entries_by_dir( - specific_files=specific_files, recurse_nested=recurse_nested) + specific_files=specific_files, recurse_nested=recurse_nested + ) return self._check_for_tree_references( - iterator, recurse_nested=recurse_nested, - specific_files=specific_files) + iterator, recurse_nested=recurse_nested, specific_files=specific_files + ) def get_canonical_paths(self, paths): """Look up canonical paths for multiple items. 
@@ -1819,18 +1926,19 @@ def get_canonical_paths(self, paths): """ with self.lock_read(): if not self.case_sensitive: + def normalize(x): return x.lower() - elif sys.platform == 'darwin': + elif sys.platform == "darwin": import unicodedata def normalize(x): - return unicodedata.normalize('NFC', x) + return unicodedata.normalize("NFC", x) else: normalize = None for path in paths: if normalize is None or self.is_versioned(path): - yield path.strip('/') + yield path.strip("/") else: yield get_canonical_path(self, path, normalize) @@ -1848,8 +1956,8 @@ def set_reference_info(self, tree_path, branch_location): def reference_parent(self, path, branch=None, possible_transports=None): return self.branch.reference_parent( - self.path2id(path), - path, possible_transports=possible_transports) + self.path2id(path), path, possible_transports=possible_transports + ) def has_changes(self, _from_tree=None): """Quickly check that the tree contains at least one commitable change. @@ -1888,8 +1996,8 @@ def has_changes(self, _from_tree=None): # working copy as compared to the repository. # Also, exclude root as mention in the above fast path. changes = filter( - lambda c: c[6][0] != 'symlink' and c[4] != (None, None), - changes) + lambda c: c[6][0] != "symlink" and c[4] != (None, None), changes + ) try: next(iter(changes)) except StopIteration: @@ -1898,8 +2006,14 @@ def has_changes(self, _from_tree=None): _marker = object() - def update(self, change_reporter=None, possible_transports=None, - revision=None, old_tip=_marker, show_base=False): + def update( + self, + change_reporter=None, + possible_transports=None, + revision=None, + old_tip=_marker, + show_base=False, + ): """Update a working tree along its branch. This will update the branch if its bound too, which means we have @@ -1933,7 +2047,7 @@ def update(self, change_reporter=None, possible_transports=None, """ if self.branch.get_bound_location() is not None: self.lock_write() - update_branch = (old_tip is self._marker) + update_branch = old_tip is self._marker else: self.lock_tree_write() update_branch = False @@ -1947,8 +2061,9 @@ def update(self, change_reporter=None, possible_transports=None, finally: self.unlock() - def _update_tree(self, old_tip=None, change_reporter=None, revision=None, - show_base=False): + def _update_tree( + self, old_tip=None, change_reporter=None, revision=None, show_base=False + ): """Update a tree to the master branch. 
:param old_tip: if supplied, the previous tip revision the branch, @@ -1980,10 +2095,14 @@ def _update_tree(self, old_tip=None, change_reporter=None, revision=None, # merge those changes in first base_tree = self.basis_tree() other_tree = self.branch.repository.revision_tree(old_tip) - nb_conflicts = merge.merge_inner(self.branch, other_tree, - base_tree, this_tree=self, - change_reporter=change_reporter, - show_base=show_base) + nb_conflicts = merge.merge_inner( + self.branch, + other_tree, + base_tree, + this_tree=self, + change_reporter=change_reporter, + show_base=show_base, + ) if nb_conflicts: self.add_parent_tree((old_tip, other_tree)) return len(nb_conflicts) @@ -1992,24 +2111,29 @@ def _update_tree(self, old_tip=None, change_reporter=None, revision=None, # the working tree is up to date with the branch # we can merge the specified revision from master to_tree = self.branch.repository.revision_tree(revision) - to_root_id = to_tree.path2id('') + to_root_id = to_tree.path2id("") basis = self.basis_tree() with basis.lock_read(): - if (basis.path2id('') is None or basis.path2id('') != to_root_id): + if basis.path2id("") is None or basis.path2id("") != to_root_id: self.set_root_id(to_root_id) self.flush() # determine the branch point graph = self.branch.repository.get_graph() - base_rev_id = graph.find_unique_lca(self.branch.last_revision(), - last_rev) + base_rev_id = graph.find_unique_lca( + self.branch.last_revision(), last_rev + ) base_tree = self.branch.repository.revision_tree(base_rev_id) - nb_conflicts = merge.merge_inner(self.branch, to_tree, base_tree, - this_tree=self, - change_reporter=change_reporter, - show_base=show_base) + nb_conflicts = merge.merge_inner( + self.branch, + to_tree, + base_tree, + this_tree=self, + change_reporter=change_reporter, + show_base=show_base, + ) self.set_last_revision(revision) # TODO - dedup parents list with things merged by pull ? # reuse the tree we've updated to to set the basis: @@ -2022,24 +2146,40 @@ def _update_tree(self, old_tip=None, change_reporter=None, revision=None, # will not, but also does not need them when setting parents. 
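Note: _update_tree is the workhorse behind WorkingTree.update: it merges the branch's new tip into the working tree with merge_inner and returns the number of conflicts. A hedged usage sketch, assuming an existing checkout whose branch has something to bring in (the path is illustrative):

    from breezy.workingtree import WorkingTree

    wt = WorkingTree.open(".")              # an existing working tree / checkout
    conflicts = wt.update(show_base=False)  # returns the conflict count
    if conflicts:
        print("update finished with %d conflicts" % conflicts)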
for parent in merges: parent_trees.append( - (parent, self.branch.repository.revision_tree(parent))) + (parent, self.branch.repository.revision_tree(parent)) + ) if not _mod_revision.is_null(old_tip): parent_trees.append( - (old_tip, self.branch.repository.revision_tree(old_tip))) + (old_tip, self.branch.repository.revision_tree(old_tip)) + ) self.set_parent_trees(parent_trees) last_rev = parent_trees[0][0] return len(nb_conflicts) - def pull(self, source, overwrite=False, stop_revision=None, - change_reporter=None, possible_transports=None, local=False, - show_base=False, tag_selector=None): + def pull( + self, + source, + overwrite=False, + stop_revision=None, + change_reporter=None, + possible_transports=None, + local=False, + show_base=False, + tag_selector=None, + ): from ..merge import merge_inner + with self.lock_write(), source.lock_read(): old_revision_info = self.branch.last_revision_info() basis_tree = self.basis_tree() - count = self.branch.pull(source, overwrite=overwrite, stop_revision=stop_revision, - possible_transports=possible_transports, - local=local, tag_selector=tag_selector) + count = self.branch.pull( + source, + overwrite=overwrite, + stop_revision=stop_revision, + possible_transports=possible_transports, + local=local, + tag_selector=tag_selector, + ) new_revision_info = self.branch.last_revision_info() if new_revision_info != old_revision_info: repository = self.branch.repository @@ -2056,9 +2196,10 @@ def pull(self, source, overwrite=False, stop_revision=None, basis_tree, this_tree=self, change_reporter=change_reporter, - show_base=show_base) - basis_root_id = basis_tree.path2id('') - new_root_id = new_basis_tree.path2id('') + show_base=show_base, + ) + basis_root_id = basis_tree.path2id("") + new_root_id = new_basis_tree.path2id("") if new_root_id is not None and basis_root_id != new_root_id: self.set_root_id(new_root_id) # TODO - dedup parents list with things merged by pull ? @@ -2066,25 +2207,25 @@ def pull(self, source, overwrite=False, stop_revision=None, # tree data. parent_trees = [] if self.branch.last_revision() != _mod_revision.NULL_REVISION: - parent_trees.append( - (self.branch.last_revision(), new_basis_tree)) + parent_trees.append((self.branch.last_revision(), new_basis_tree)) # we have to pull the merge trees out again, because # merge_inner has set the ids. - this corner is not yet # layered well enough to prevent double handling. # XXX TODO: Fix the double handling: telling the tree about # the already known parent data is wasteful. 
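Note: pull, shown above, first lets the branch pull new revisions and then, if the tip moved, runs merge_inner between the old and new basis trees so the working files follow; it returns whatever the underlying branch pull returned. A minimal usage sketch (the source location is illustrative):

    from breezy.branch import Branch
    from breezy.workingtree import WorkingTree

    wt = WorkingTree.open(".")
    source = Branch.open("../upstream")      # hypothetical sibling branch
    result = wt.pull(source, overwrite=False)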
merges = self.get_parent_ids()[1:] - parent_trees.extend([ - (parent, repository.revision_tree(parent)) for - parent in merges]) + parent_trees.extend( + [(parent, repository.revision_tree(parent)) for parent in merges] + ) self.set_parent_trees(parent_trees) return count def copy_content_into(self, tree, revision_id=None): """Copy the current content and user files of this tree into tree.""" from ..merge import transform_tree + with self.lock_read(): - tree.set_root_id(self.path2id('')) + tree.set_root_id(self.path2id("")) if revision_id is None: transform_tree(tree, self) else: @@ -2093,8 +2234,7 @@ def copy_content_into(self, tree, revision_id=None): try: other_tree = self.revision_tree(revision_id) except errors.NoSuchRevision: - other_tree = self.branch.repository.revision_tree( - revision_id) + other_tree = self.branch.repository.revision_tree(revision_id) transform_tree(tree, other_tree) if revision_id == _mod_revision.NULL_REVISION: @@ -2107,7 +2247,7 @@ def copy_content_into(self, tree, revision_id=None): class WorkingTreeFormatMetaDir(bzrdir.BzrFormat, WorkingTreeFormat): """Base class for working trees that live in bzr meta directories.""" - ignore_filename = '.bzrignore' + ignore_filename = ".bzrignore" supports_setting_file_ids = True """If this format allows setting the file id.""" @@ -2129,14 +2269,20 @@ def find_format_string(klass, controldir): def find_format(klass, controldir): """Return the format for the working tree object in controldir.""" format_string = klass.find_format_string(controldir) - return klass._find_format(format_registry, 'working tree', - format_string) + return klass._find_format(format_registry, "working tree", format_string) - def check_support_status(self, allow_unsupported, recommend_upgrade=True, - basedir=None): + def check_support_status( + self, allow_unsupported, recommend_upgrade=True, basedir=None + ): WorkingTreeFormat.check_support_status( - self, allow_unsupported=allow_unsupported, - recommend_upgrade=recommend_upgrade, basedir=basedir) + self, + allow_unsupported=allow_unsupported, + recommend_upgrade=recommend_upgrade, + basedir=basedir, + ) bzrdir.BzrFormat.check_support_status( - self, allow_unsupported=allow_unsupported, - recommend_upgrade=recommend_upgrade, basedir=basedir) + self, + allow_unsupported=allow_unsupported, + recommend_upgrade=recommend_upgrade, + basedir=basedir, + ) diff --git a/breezy/bzr/workingtree_3.py b/breezy/bzr/workingtree_3.py index 56119143c7..cda7e61af1 100644 --- a/breezy/bzr/workingtree_3.py +++ b/breezy/bzr/workingtree_3.py @@ -29,8 +29,7 @@ class PreDirStateWorkingTree(InventoryWorkingTree): - - def __init__(self, basedir='.', *args, **kwargs): + def __init__(self, basedir=".", *args, **kwargs): super().__init__(basedir, *args, **kwargs) # update the whole cache up front and write to disk if anything changed; # in the future we might want to do this more selectively @@ -39,10 +38,13 @@ def __init__(self, basedir='.', *args, **kwargs): # cache file, and have the parser take the most recent entry for a # given path only. 
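Note: the constructor below wires up a HashCache stored in the tree's stat-cache file: SHA1s are remembered per path together with the file's stat fingerprint, so content is only re-hashed when the file appears to have changed, and the cache is rewritten when it becomes dirty. A toy version of the idea (illustrative; not Breezy's HashCache):

    import hashlib
    import os

    class StatSha1Cache:
        """Cache sha1s by path, invalidated when the (size, mtime) fingerprint changes."""

        def __init__(self):
            self._cache = {}    # path -> ((size, mtime_ns), hexdigest)

        def sha1(self, path):
            st = os.stat(path)
            fingerprint = (st.st_size, st.st_mtime_ns)
            hit = self._cache.get(path)
            if hit is not None and hit[0] == fingerprint:
                return hit[1]
            with open(path, "rb") as f:
                digest = hashlib.sha1(f.read()).hexdigest()
            self._cache[path] = (fingerprint, digest)
            return digest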
wt_trans = self.controldir.get_workingtree_transport(None) - cache_filename = wt_trans.local_abspath('stat-cache') - self._hashcache = hashcache.HashCache(basedir, cache_filename, - self.controldir._get_file_mode(), - self._content_filter_stack_provider()) + cache_filename = wt_trans.local_abspath("stat-cache") + self._hashcache = hashcache.HashCache( + basedir, + cache_filename, + self.controldir._get_file_mode(), + self._content_filter_stack_provider(), + ) hc = self._hashcache hc.read() # is this scan needed ? it makes things kinda slow. @@ -61,8 +63,11 @@ def _write_hashcache_if_dirty(self): # TODO: jam 20061219 Should this be a warning? A single line # warning might be sufficient to let the user know what # is going on. - trace.mutter('Could not write hashcache for %s\nError: %s', - self._hashcache.cache_file_name(), e.filename) + trace.mutter( + "Could not write hashcache for %s\nError: %s", + self._hashcache.cache_file_name(), + e.filename, + ) def get_file_sha1(self, path, stat_value=None): with self.lock_read(): @@ -86,7 +91,7 @@ def _last_revision(self): """See Mutable.last_revision.""" with self.lock_read(): try: - return self._transport.get_bytes('last-revision') + return self._transport.get_bytes("last-revision") except _mod_transport.NoSuchFile: return _mod_revision.NULL_REVISION @@ -94,18 +99,19 @@ def _change_last_revision(self, revision_id): """See WorkingTree._change_last_revision.""" if revision_id is None or revision_id == _mod_revision.NULL_REVISION: try: - self._transport.delete('last-revision') + self._transport.delete("last-revision") except _mod_transport.NoSuchFile: pass return False else: - self._transport.put_bytes('last-revision', revision_id, - mode=self.controldir._get_file_mode()) + self._transport.put_bytes( + "last-revision", revision_id, mode=self.controldir._get_file_mode() + ) return True def _get_check_refs(self): """Return the references needed to perform a check of this tree.""" - return [('trees', self.last_revision())] + return [("trees", self.last_revision())] def unlock(self): if self._control_files._lock_count == 1: @@ -158,10 +164,16 @@ def __get_matchingcontroldir(self): def _open_control_files(self, a_controldir): transport = a_controldir.get_workingtree_transport(None) - return LockableFiles(transport, 'lock', LockDir) - - def initialize(self, a_controldir, revision_id=None, from_branch=None, - accelerator_tree=None, hardlink=False): + return LockableFiles(transport, "lock", LockDir) + + def initialize( + self, + a_controldir, + revision_id=None, + from_branch=None, + accelerator_tree=None, + hardlink=False, + ): """See WorkingTreeFormat.initialize(). :param revision_id: if supplied, create a working tree at a different @@ -179,8 +191,9 @@ def initialize(self, a_controldir, revision_id=None, from_branch=None, control_files = self._open_control_files(a_controldir) control_files.create_lock() control_files.lock_write() - transport.put_bytes('format', self.as_string(), - mode=a_controldir._get_file_mode()) + transport.put_bytes( + "format", self.as_string(), mode=a_controldir._get_file_mode() + ) if from_branch is not None: branch = from_branch else: @@ -193,25 +206,27 @@ def initialize(self, a_controldir, revision_id=None, from_branch=None, # are maintaining compatibility with older clients. 
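Note: _last_revision and _change_last_revision above treat a missing last-revision control file as the null revision, and setting the null revision as deleting the file. The same guard pattern with ordinary files (illustrative only; plain open/unlink stand in for the control transport):

    import os

    NULL_REVISION = b"null:"     # breezy's marker for "no revision"

    def read_last_revision(path):
        try:
            with open(path, "rb") as f:
                return f.read()
        except FileNotFoundError:
            return NULL_REVISION

    def write_last_revision(path, revision_id):
        if revision_id is None or revision_id == NULL_REVISION:
            try:
                os.unlink(path)
            except FileNotFoundError:
                pass
        else:
            with open(path, "wb") as f:
                f.write(revision_id)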
# inv = Inventory(root_id=gen_root_id()) inv = self._initial_inventory() - wt = self._tree_class(a_controldir.root_transport.local_abspath('.'), - branch, - inv, - _internal=True, - _format=self, - _controldir=a_controldir, - _control_files=control_files) + wt = self._tree_class( + a_controldir.root_transport.local_abspath("."), + branch, + inv, + _internal=True, + _format=self, + _controldir=a_controldir, + _control_files=control_files, + ) wt.lock_tree_write() try: basis_tree = branch.repository.revision_tree(revision_id) # only set an explicit root id if there is one to set. - if basis_tree.path2id('') is not None: - wt.set_root_id(basis_tree.path2id('')) + if basis_tree.path2id("") is not None: + wt.set_root_id(basis_tree.path2id("")) if revision_id == _mod_revision.NULL_REVISION: wt.set_parent_trees([]) else: wt.set_parent_trees([(revision_id, basis_tree)]) bzr_transform.build_tree(basis_tree, wt) - for hook in MutableTree.hooks['post_build_tree']: + for hook in MutableTree.hooks["post_build_tree"]: hook(wt) finally: # Unlock in this order so that the unlock-triggers-flush in @@ -243,8 +258,10 @@ def _open(self, a_controldir, control_files): :param a_controldir: the dir for the tree. :param control_files: the control files for the tree. """ - return self._tree_class(a_controldir.root_transport.local_abspath('.'), - _internal=True, - _format=self, - _controldir=a_controldir, - _control_files=control_files) + return self._tree_class( + a_controldir.root_transport.local_abspath("."), + _internal=True, + _format=self, + _controldir=a_controldir, + _control_files=control_files, + ) diff --git a/breezy/bzr/workingtree_4.py b/breezy/bzr/workingtree_4.py index e8abe6b69b..9e7fca13c1 100644 --- a/breezy/bzr/workingtree_4.py +++ b/breezy/bzr/workingtree_4.py @@ -28,7 +28,9 @@ from ..lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ import contextlib import stat @@ -43,7 +45,8 @@ generate_ids, transform as bzr_transform, ) -""") +""", +) from .. import cache_utf8, debug, errors, osutils, trace from .. import revision as _mod_revision @@ -70,12 +73,9 @@ class DirStateWorkingTree(InventoryWorkingTree): - - def __init__(self, basedir, - branch, - _control_files=None, - _format=None, - _controldir=None): + def __init__( + self, basedir, branch, _control_files=None, _format=None, _controldir=None + ): """Construct a WorkingTree for basedir. If the branch is not supplied, it is opened automatically. @@ -107,31 +107,31 @@ def __init__(self, basedir, # --- allow tests to select the dirstate iter_changes implementation self._iter_changes = dirstate._process_entry self._repo_supports_tree_reference = getattr( - self._branch.repository._format, "supports_tree_reference", - False) + self._branch.repository._format, "supports_tree_reference", False + ) def _add(self, files, kinds, ids): """See MutableTree._add.""" with self.lock_tree_write(): state = self.current_dirstate() for f, file_id, kind in zip(files, ids, kinds): - f = f.strip('/') + f = f.strip("/") if self.path2id(f): # special case tree root handling. - if f == b'' and self.path2id(f) == ROOT_ID: - state.set_path_id(b'', generate_ids.gen_file_id(f)) + if f == b"" and self.path2id(f) == ROOT_ID: + state.set_path_id(b"", generate_ids.gen_file_id(f)) continue if file_id is None: file_id = generate_ids.gen_file_id(f) # deliberately add the file with no cached stat or sha1 # - on the first access it will be gathered, and we can # always change this once tests are all passing. 
- state.add(f, file_id, kind, None, b'') + state.add(f, file_id, kind, None, b"") self._make_dirty(reset_inventory=True) def _get_check_refs(self): """Return the references needed to perform a check of this tree.""" - return [('trees', self.last_revision())] + return [("trees", self.last_revision())] def _make_dirty(self, reset_inventory): """Make the tree state dirty. @@ -152,20 +152,21 @@ def add_reference(self, sub_tree): try: sub_tree_path = self.relpath(sub_tree.basedir) except errors.PathNotChild as e: - raise BadReferenceTarget(self, sub_tree, - 'Target not inside tree.') from e - sub_tree_id = sub_tree.path2id('') - if sub_tree_id == self.path2id(''): - raise BadReferenceTarget(self, sub_tree, - 'Trees have the same root id.') + raise BadReferenceTarget( + self, sub_tree, "Target not inside tree." + ) from e + sub_tree_id = sub_tree.path2id("") + if sub_tree_id == self.path2id(""): + raise BadReferenceTarget(self, sub_tree, "Trees have the same root id.") try: self.id2path(sub_tree_id) except errors.NoSuchId: pass else: raise BadReferenceTarget( - self, sub_tree, 'Root id already present in tree') - self._add([sub_tree_path], ['tree-reference'], [sub_tree_id]) + self, sub_tree, "Root id already present in tree" + ) + self._add([sub_tree_path], ["tree-reference"], [sub_tree_id]) def break_lock(self): """Break a lock if one is present from another instance. @@ -205,21 +206,23 @@ def break_lock(self): self.branch.break_lock() def _comparison_data(self, entry, path): - kind, executable, stat_value = \ - WorkingTree._comparison_data(self, entry, path) + kind, executable, stat_value = WorkingTree._comparison_data(self, entry, path) # it looks like a plain directory, but it's really a reference -- see # also kind() - if (self._repo_supports_tree_reference and kind == 'directory' - and entry is not None and entry.kind == 'tree-reference'): - kind = 'tree-reference' + if ( + self._repo_supports_tree_reference + and kind == "directory" + and entry is not None + and entry.kind == "tree-reference" + ): + kind = "tree-reference" return kind, executable, stat_value def commit(self, message=None, revprops=None, *args, **kwargs): with self.lock_write(): # mark the tree as dirty post commit - commit # can change the current versioned list by doing deletes. - result = WorkingTree.commit(self, message, revprops, *args, - **kwargs) + result = WorkingTree.commit(self, message, revprops, *args, **kwargs) self._make_dirty(reset_inventory=True) return result @@ -241,11 +244,15 @@ def _current_dirstate(self): """ if self._dirstate is not None: return self._dirstate - local_path = self.controldir.get_workingtree_transport( - None).local_abspath('dirstate') + local_path = self.controldir.get_workingtree_transport(None).local_abspath( + "dirstate" + ) self._dirstate = dirstate.DirState.on_file( - local_path, self._sha1_provider(), self._worth_saving_limit(), - self._supports_executable()) + local_path, + self._sha1_provider(), + self._worth_saving_limit(), + self._supports_executable(), + ) return self._dirstate def _sha1_provider(self): @@ -266,7 +273,7 @@ def _worth_saving_limit(self): :return: an integer. -1 means never save. """ conf = self.get_config_stack() - return conf.get('bzr.workingtree.worth_saving_limit') + return conf.get("bzr.workingtree.worth_saving_limit") def filter_unversioned_files(self, paths): """Filter out paths that are versioned. 
@@ -281,16 +288,17 @@ def filter_unversioned_files(self, paths): state = self.current_dirstate() # TODO we want a paths_to_dirblocks helper I think for path in paths: - dirname, basename = os.path.split(path.encode('utf8')) + dirname, basename = os.path.split(path.encode("utf8")) _, _, _, path_is_versioned = state._get_block_entry_index( - dirname, basename, 0) + dirname, basename, 0 + ) if not path_is_versioned: result.add(path) return result def flush(self): """Write all cached data to disk.""" - if self._control_files._lock_mode != 'w': + if self._control_files._lock_mode != "w": raise errors.NotWriteLocked(self) self.current_dirstate().save() self._inventory = None @@ -313,9 +321,9 @@ def _generate_inventory(self) -> None: # import pdb;pdb.set_trace() state = self.current_dirstate() state._read_dirblocks_if_needed() - root_key, current_entry = self._get_entry(path='') + root_key, current_entry = self._get_entry(path="") current_id = root_key[2] - if not (current_entry[0][0] == b'd'): # directory + if not (current_entry[0][0] == b"d"): # directory raise AssertionError(current_entry) inv = Inventory(root_id=current_id) # Turn some things into local variables @@ -324,7 +332,7 @@ def _generate_inventory(self) -> None: utf8_decode = cache_utf8._utf8_decode # we could do this straight out of the dirstate; it might be fast # and should be profiled - RBC 20070216 - parent_ies: Dict[bytes, InventoryEntry] = {b'': inv.root} + parent_ies: Dict[bytes, InventoryEntry] = {b"": inv.root} for block in state._dirblocks[1:]: # skip the root dirname = block[0] try: @@ -334,37 +342,39 @@ def _generate_inventory(self) -> None: continue for key, entry in block[1]: minikind, link_or_sha1, size, executable, stat = entry[0] - if minikind in (b'a', b'r'): # absent, relocated + if minikind in (b"a", b"r"): # absent, relocated # a parent tree only entry continue name = key[1] name_unicode = utf8_decode(name)[0] file_id = key[2] kind = minikind_to_kind[minikind] - inv_entry = factory[kind](file_id, name_unicode, - parent_ie.file_id) - if kind == 'file': + inv_entry = factory[kind](file_id, name_unicode, parent_ie.file_id) + if kind == "file": # This is only needed on win32, where this is the only way # we know the executable bit. inv_entry.executable = bool(executable) # not strictly needed: working tree # inv_entry.text_size = size # inv_entry.text_sha1 = sha1 - elif kind == 'directory': + elif kind == "directory": # add this entry to the parent map. 
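Note: _generate_inventory rebuilds an Inventory straight from the dirstate rows; the first byte of each row's details (the "minikind") selects the entry kind. For reference, the codes that appear in these hunks, collected as a plain mapping (inferred from this diff; the authoritative table lives with the dirstate code):

    # minikind byte -> meaning, as used in the dirstate rows above
    MINIKIND = {
        b"f": "file",
        b"d": "directory",
        b"t": "tree-reference",
        b"a": "absent",       # the entry exists only in another tree of the dirstate
        b"r": "relocated",    # the row points at the path where the entry really lives
    }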
- parent_ies[(dirname + b'/' + name).strip(b'/')] = inv_entry - elif kind == 'tree-reference': + parent_ies[(dirname + b"/" + name).strip(b"/")] = inv_entry + elif kind == "tree-reference": inv_entry.reference_revision = link_or_sha1 or None - elif kind != 'symlink': + elif kind != "symlink": raise AssertionError(f"unknown kind {kind!r}") try: inv.add(inv_entry) except DuplicateFileId as err: raise AssertionError( - f'file_id {file_id} already in' - f' inventory as {inv.get_entry(file_id)}') from err + f"file_id {file_id} already in" + f" inventory as {inv.get_entry(file_id)}" + ) from err except errors.InconsistentDelta as err: - raise AssertionError(f'name {name_unicode!r} already in parent') from err + raise AssertionError( + f"name {name_unicode!r} already in parent" + ) from err self._inventory = inv def _get_entry(self, file_id=None, path=None): @@ -379,10 +389,10 @@ def _get_entry(self, file_id=None, path=None): :return: The dirstate row tuple for path/file_id, or (None, None) """ if file_id is None and path is None: - raise errors.BzrError('must supply file_id or path') + raise errors.BzrError("must supply file_id or path") state = self.current_dirstate() if path is not None: - path = path.encode('utf8') + path = path.encode("utf8") return state._get_entry(0, fileid_utf8=file_id, path_utf8=path) def get_file_sha1(self, path, stat_value=None): @@ -391,7 +401,7 @@ def get_file_sha1(self, path, stat_value=None): if entry[0] is None: raise NoSuchFile(self, path) if path is None: - path = pathjoin(entry[0][0], entry[0][1]).decode('utf8') + path = pathjoin(entry[0][0], entry[0][1]).decode("utf8") file_abspath = self.abspath(path) state = self.current_dirstate() @@ -400,9 +410,10 @@ def get_file_sha1(self, path, stat_value=None): stat_value = osutils.lstat(file_abspath) except FileNotFoundError: return None - link_or_sha1 = dirstate.update_entry(state, entry, file_abspath, - stat_value=stat_value) - if entry[1][0][0] == b'f': + link_or_sha1 = dirstate.update_entry( + state, entry, file_abspath, stat_value=stat_value + ) + if entry[1][0][0] == b"f": if link_or_sha1 is None: file_obj, statvalue = self.get_file_with_stat(path) try: @@ -417,17 +428,17 @@ def get_file_sha1(self, path, stat_value=None): def _get_root_inventory(self): """Get the inventory for the tree. This is only valid within a lock.""" - if debug.debug_flag_enabled('evil'): + if debug.debug_flag_enabled("evil"): trace.mutter_callsite( - 2, "accessing .inventory forces a size of tree translation.") + 2, "accessing .inventory forces a size of tree translation." + ) if self._inventory is not None: return self._inventory self._must_be_locked() self._generate_inventory() return self._inventory - root_inventory = property(_get_root_inventory, - doc="Root inventory of this tree") + root_inventory = property(_get_root_inventory, doc="Root inventory of this tree") def get_parent_ids(self): """See Tree.get_parent_ids. 
@@ -453,25 +464,27 @@ def get_nested_tree(self, path): except errors.NotBranchError as err: raise MissingNestedTree(path) from err - def id2path(self, file_id, recurse='down'): + def id2path(self, file_id, recurse="down"): """Convert a file-id to a path.""" with self.lock_read(): self.current_dirstate() entry = self._get_entry(file_id=file_id) if entry == (None, None): - if recurse == 'down': - if debug.debug_flag_enabled('evil'): - trace.mutter_callsite( - 2, "Tree.id2path scans all nested trees.") + if recurse == "down": + if debug.debug_flag_enabled("evil"): + trace.mutter_callsite(2, "Tree.id2path scans all nested trees.") for nested_path in self.iter_references(): nested_tree = self.get_nested_tree(nested_path) try: return osutils.pathjoin( - nested_path, nested_tree.id2path(file_id)) + nested_path, nested_tree.id2path(file_id) + ) except errors.NoSuchId: pass raise errors.NoSuchId(tree=self, file_id=file_id) - return osutils.pathjoin(entry[0][0], entry[0][1]).decode('utf-8', 'surrogateescape') + return osutils.pathjoin(entry[0][0], entry[0][1]).decode( + "utf-8", "surrogateescape" + ) def _is_executable_from_path_and_stat_from_basis(self, path, stat_result): entry = self._get_entry(path=path) @@ -499,15 +512,16 @@ def all_file_ids(self): self._must_be_locked() result = set() for key, tree_details in self.current_dirstate()._iter_entries(): - if tree_details[0][0] in (b'a', b'r'): # relocated + if tree_details[0][0] in (b"a", b"r"): # relocated continue result.add(key[2]) return result def all_versioned_paths(self): self._must_be_locked() - return {path for path, entry in - self.root_inventory.iter_entries(recursive=True)} + return { + path for path, entry in self.root_inventory.iter_entries(recursive=True) + } def __iter__(self): """Iterate through file_ids for this tree. @@ -518,11 +532,12 @@ def __iter__(self): with self.lock_read(): result = [] for key, tree_details in self.current_dirstate()._iter_entries(): - if tree_details[0][0] in (b'a', b'r'): # absent, relocated + if tree_details[0][0] in (b"a", b"r"): # absent, relocated # not relevant to the working tree continue - path = pathjoin(self.basedir, key[0].decode( - 'utf8'), key[1].decode('utf8')) + path = pathjoin( + self.basedir, key[0].decode("utf8"), key[1].decode("utf8") + ) if osutils.lexists(path): result.append(key[2]) return iter(result) @@ -534,15 +549,15 @@ def iter_references(self): return with self.lock_read(): for key, tree_details in self.current_dirstate()._iter_entries(): - if tree_details[0][0] in (b'a', b'r'): # absent, relocated + if tree_details[0][0] in (b"a", b"r"): # absent, relocated # not relevant to the working tree continue if not key[1]: # the root is not a reference. continue - relpath = pathjoin(key[0].decode('utf8'), key[1].decode('utf8')) + relpath = pathjoin(key[0].decode("utf8"), key[1].decode("utf8")) try: - if self.kind(relpath) == 'tree-reference': + if self.kind(relpath) == "tree-reference": yield relpath except NoSuchFile: # path is missing on disk. 
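Note: iter_references yields the relative paths of tree-reference entries that still exist on disk, and get_nested_tree (seen earlier in this file) opens the working tree stored at such a path. A hedged usage sketch, assuming a tree whose format and repository support tree references:

    from breezy.workingtree import WorkingTree

    wt = WorkingTree.open(".")
    for subtree_path in wt.iter_references():
        nested = wt.get_nested_tree(subtree_path)
        print(subtree_path, nested.last_revision())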
@@ -557,12 +572,12 @@ def _observed_sha1(self, path, sha_and_stat): def kind(self, relpath): abspath = self.abspath(relpath) kind = file_kind(abspath) - if (self._repo_supports_tree_reference and kind == 'directory'): + if self._repo_supports_tree_reference and kind == "directory": with self.lock_read(): entry = self._get_entry(path=relpath) if entry[1] is not None: - if entry[1][0][0] == b't': - kind = 'tree-reference' + if entry[1][0][0] == b"t": + kind = "tree-reference" return kind def _last_revision(self): @@ -589,8 +604,8 @@ def lock_read(self): # set our support for tree references from the repository in # use. self._repo_supports_tree_reference = getattr( - self.branch.repository._format, "supports_tree_reference", - False) + self.branch.repository._format, "supports_tree_reference", False + ) except BaseException: self._control_files.unlock() raise @@ -610,8 +625,8 @@ def _lock_self_write(self): # set our support for tree references from the repository in # use. self._repo_supports_tree_reference = getattr( - self.branch.repository._format, "supports_tree_reference", - False) + self.branch.repository._format, "supports_tree_reference", False + ) except BaseException: self._control_files.unlock() raise @@ -645,29 +660,36 @@ def move(self, from_paths, to_dir, after=False): state = self.current_dirstate() if isinstance(from_paths, (str, bytes)): raise ValueError() - to_dir_utf8 = to_dir.encode('utf8') + to_dir_utf8 = to_dir.encode("utf8") to_entry_dirname, to_basename = os.path.split(to_dir_utf8) # check destination directory # get the details for it - (to_entry_block_index, to_entry_entry_index, dir_present, - entry_present) = state._get_block_entry_index( - to_entry_dirname, to_basename, 0) + ( + to_entry_block_index, + to_entry_entry_index, + dir_present, + entry_present, + ) = state._get_block_entry_index(to_entry_dirname, to_basename, 0) if not entry_present: raise errors.BzrMoveFailedError( - '', to_dir, errors.NotVersionedError(to_dir)) + "", to_dir, errors.NotVersionedError(to_dir) + ) to_entry = state._dirblocks[to_entry_block_index][1][to_entry_entry_index] # get a handle on the block itself. to_block_index = state._ensure_block( - to_entry_block_index, to_entry_entry_index, to_dir_utf8) + to_entry_block_index, to_entry_entry_index, to_dir_utf8 + ) to_block = state._dirblocks[to_block_index] to_abs = self.abspath(to_dir) if not isdir(to_abs): - raise errors.BzrMoveFailedError('', to_dir, - errors.NotADirectory(to_abs)) + raise errors.BzrMoveFailedError( + "", to_dir, errors.NotADirectory(to_abs) + ) - if to_entry[1][0][0] != b'd': - raise errors.BzrMoveFailedError('', to_dir, - errors.NotADirectory(to_abs)) + if to_entry[1][0][0] != b"d": + raise errors.BzrMoveFailedError( + "", to_dir, errors.NotADirectory(to_abs) + ) if self._inventory is not None: update_inventory = True @@ -680,9 +702,18 @@ def move(self, from_paths, to_dir, after=False): # missing those added here, but there's also no test coverage for this. 
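Note: the dirstate move that follows records an undo callback on a contextlib.ExitStack after each mutation (state update or filesystem rename); if anything raises, closing the stack replays the callbacks newest-first, otherwise the callbacks are simply dropped. The idiom in isolation (illustrative, not the Breezy code):

    import contextlib
    import os

    def rename_all(pairs):
        """Rename (src, dst) pairs; on failure, undo the completed renames in reverse order."""
        rollbacks = contextlib.ExitStack()
        try:
            for src, dst in pairs:
                os.rename(src, dst)
                rollbacks.callback(os.rename, dst, src)   # how to undo this step
        except OSError:
            rollbacks.close()      # runs the recorded callbacks, newest first
            raise
        else:
            rollbacks.pop_all()    # success: keep the renames, forget the undo steps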
rollbacks = contextlib.ExitStack() - def move_one(old_entry, from_path_utf8, minikind, executable, - fingerprint, packed_stat, size, - to_block, to_key, to_path_utf8): + def move_one( + old_entry, + from_path_utf8, + minikind, + executable, + fingerprint, + packed_stat, + size, + to_block, + to_key, + to_path_utf8, + ): state._make_absent(old_entry) from_key = old_entry[0] rollbacks.callback( @@ -693,29 +724,31 @@ def move_one(old_entry, from_path_utf8, minikind, executable, fingerprint=fingerprint, packed_stat=packed_stat, size=size, - path_utf8=from_path_utf8) - state.update_minimal(to_key, - minikind, - executable=executable, - fingerprint=fingerprint, - packed_stat=packed_stat, - size=size, - path_utf8=to_path_utf8) - added_entry_index, _ = state._find_entry_index( - to_key, to_block[1]) + path_utf8=from_path_utf8, + ) + state.update_minimal( + to_key, + minikind, + executable=executable, + fingerprint=fingerprint, + packed_stat=packed_stat, + size=size, + path_utf8=to_path_utf8, + ) + added_entry_index, _ = state._find_entry_index(to_key, to_block[1]) new_entry = to_block[1][added_entry_index] rollbacks.callback(state._make_absent, new_entry) for from_rel in from_paths: # from_rel is 'pathinroot/foo/bar' - from_rel_utf8 = from_rel.encode('utf8') + from_rel_utf8 = from_rel.encode("utf8") from_dirname, from_tail = osutils.split(from_rel) from_dirname, from_tail_utf8 = osutils.split(from_rel_utf8) from_entry = self._get_entry(path=from_rel) if from_entry == (None, None): raise errors.BzrMoveFailedError( - from_rel, to_dir, - errors.NotVersionedError(path=from_rel)) + from_rel, to_dir, errors.NotVersionedError(path=from_rel) + ) from_id = from_entry[0][2] to_rel = pathjoin(to_dir, from_tail) @@ -723,11 +756,13 @@ def move_one(old_entry, from_path_utf8, minikind, executable, item_to_entry = self._get_entry(path=to_rel) if item_to_entry != (None, None): raise errors.BzrMoveFailedError( - from_rel, to_rel, "Target is already versioned.") + from_rel, to_rel, "Target is already versioned." + ) if from_rel == to_rel: raise errors.BzrMoveFailedError( - from_rel, to_rel, "Source and target are identical.") + from_rel, to_rel, "Source and target are identical." 
+ ) from_missing = not self.has_filename(from_rel) to_missing = not self.has_filename(to_rel) @@ -738,15 +773,19 @@ def move_one(old_entry, from_path_utf8, minikind, executable, if to_missing: if not move_file: raise errors.BzrMoveFailedError( - from_rel, to_rel, + from_rel, + to_rel, NoSuchFile( - path=to_rel, - extra="New file has not been created yet")) + path=to_rel, extra="New file has not been created yet" + ), + ) elif from_missing: # neither path exists raise errors.BzrRenameFailedError( - from_rel, to_rel, - errors.PathsDoNotExist(paths=(from_rel, to_rel))) + from_rel, + to_rel, + errors.PathsDoNotExist(paths=(from_rel, to_rel)), + ) else: if from_missing: # implicitly just update our path mapping move_file = False @@ -761,8 +800,7 @@ def move_one(old_entry, from_path_utf8, minikind, executable, osutils.rename(from_rel_abs, to_rel_abs) except OSError as e: raise errors.BzrMoveFailedError(from_rel, to_rel, e[1]) from e - rollbacks.callback( - osutils.rename, to_rel_abs, from_rel_abs) + rollbacks.callback(osutils.rename, to_rel_abs, from_rel_abs) try: # perform the rename in the inventory next if needed: its easy # to rollback @@ -772,38 +810,46 @@ def move_one(old_entry, from_path_utf8, minikind, executable, current_parent = from_entry.parent_id inv.rename(from_id, to_dir_id, from_tail) rollbacks.callback( - inv.rename, from_id, current_parent, from_tail) + inv.rename, from_id, current_parent, from_tail + ) # finally do the rename in the dirstate, which is a little # tricky to rollback, but least likely to need it. - old_block_index, old_entry_index, dir_present, file_present = \ - state._get_block_entry_index( - from_dirname, from_tail_utf8, 0) + ( + old_block_index, + old_entry_index, + dir_present, + file_present, + ) = state._get_block_entry_index(from_dirname, from_tail_utf8, 0) old_block = state._dirblocks[old_block_index][1] old_entry = old_block[old_entry_index] from_key, old_entry_details = old_entry cur_details = old_entry_details[0] # remove the old row - to_key = ((to_block[0],) + from_key[1:3]) + to_key = (to_block[0],) + from_key[1:3] minikind = cur_details[0] - move_one(old_entry, from_path_utf8=from_rel_utf8, - minikind=minikind, - executable=cur_details[3], - fingerprint=cur_details[1], - packed_stat=cur_details[4], - size=cur_details[2], - to_block=to_block, - to_key=to_key, - to_path_utf8=to_rel_utf8) - - if minikind == b'd': + move_one( + old_entry, + from_path_utf8=from_rel_utf8, + minikind=minikind, + executable=cur_details[3], + fingerprint=cur_details[1], + packed_stat=cur_details[4], + size=cur_details[2], + to_block=to_block, + to_key=to_key, + to_path_utf8=to_rel_utf8, + ) + + if minikind == b"d": + def update_dirblock(from_dir, to_key, to_dir_utf8): """Recursively update all entries in this dirblock.""" - if from_dir == b'': - raise AssertionError( - "renaming root not supported") - from_key = (from_dir, '') - from_block_idx, present = \ - state._find_block_index_from_key(from_key) + if from_dir == b"": + raise AssertionError("renaming root not supported") + from_key = (from_dir, "") + from_block_idx, present = state._find_block_index_from_key( + from_key + ) if not present: # This is the old record, if it isn't present, # then there is theoretically nothing to @@ -811,11 +857,15 @@ def update_dirblock(from_dir, to_key, to_dir_utf8): # lazy loading, but we don't do that yet) return from_block = state._dirblocks[from_block_idx] - to_block_index, to_entry_index, _, _ = \ - state._get_block_entry_index( - to_key[0], to_key[1], 0) + ( + to_block_index, + 
to_entry_index, + _, + _, + ) = state._get_block_entry_index(to_key[0], to_key[1], 0) to_block_index = state._ensure_block( - to_block_index, to_entry_index, to_dir_utf8) + to_block_index, to_entry_index, to_dir_utf8 + ) to_block = state._dirblocks[to_block_index] # Grab a copy since move_one may update the list. @@ -823,33 +873,39 @@ def update_dirblock(from_dir, to_key, to_dir_utf8): if not (entry[0][0] == from_dir): raise AssertionError() cur_details = entry[1][0] - to_key = ( - to_dir_utf8, entry[0][1], entry[0][2]) + to_key = (to_dir_utf8, entry[0][1], entry[0][2]) from_path_utf8 = osutils.pathjoin( - entry[0][0], entry[0][1]) + entry[0][0], entry[0][1] + ) to_path_utf8 = osutils.pathjoin( - to_dir_utf8, entry[0][1]) + to_dir_utf8, entry[0][1] + ) minikind = cur_details[0] - if minikind in (b'a', b'r'): + if minikind in (b"a", b"r"): # Deleted children of a renamed directory # Do not need to be updated. Children that # have been renamed out of this directory # should also not be updated continue - move_one(entry, from_path_utf8=from_path_utf8, - minikind=minikind, - executable=cur_details[3], - fingerprint=cur_details[1], - packed_stat=cur_details[4], - size=cur_details[2], - to_block=to_block, - to_key=to_key, - to_path_utf8=to_path_utf8) - if minikind == b'd': + move_one( + entry, + from_path_utf8=from_path_utf8, + minikind=minikind, + executable=cur_details[3], + fingerprint=cur_details[1], + packed_stat=cur_details[4], + size=cur_details[2], + to_block=to_block, + to_key=to_key, + to_path_utf8=to_path_utf8, + ) + if minikind == b"d": # We need to move all the children of this # entry - update_dirblock(from_path_utf8, to_key, - to_path_utf8) + update_dirblock( + from_path_utf8, to_key, to_path_utf8 + ) + update_dirblock(from_rel_utf8, to_key, to_rel_utf8) except BaseException: rollbacks.close() @@ -875,7 +931,7 @@ def path2id(self, path): if path == []: path = [""] path = osutils.pathjoin(*path) - path = path.strip('/') + path = path.strip("/") entry = self._get_entry(path=path) if entry == (None, None): nested_tree, subpath = self.get_containing_nested_tree(path) @@ -896,34 +952,35 @@ def paths2ids(self, paths, trees=None, require_versioned=True): return None parents = self.get_parent_ids() for tree in trees: - if not (isinstance(tree, DirStateRevisionTree) and - tree._revision_id in parents): - return super().paths2ids( - paths, trees, require_versioned) - search_indexes = [ - 0] + [1 + parents.index(tree._revision_id) for tree in trees] + if not ( + isinstance(tree, DirStateRevisionTree) and tree._revision_id in parents + ): + return super().paths2ids(paths, trees, require_versioned) + search_indexes = [0] + [1 + parents.index(tree._revision_id) for tree in trees] paths_utf8 = set() for path in paths: - paths_utf8.add(path.encode('utf8')) + paths_utf8.add(path.encode("utf8")) # -- get the state object and prepare it. 
state = self.current_dirstate() - if False and (state._dirblock_state == dirstate.DirState.NOT_IN_MEMORY - and b'' not in paths): + if False and ( + state._dirblock_state == dirstate.DirState.NOT_IN_MEMORY + and b"" not in paths + ): paths2ids = self._paths2ids_using_bisect else: paths2ids = self._paths2ids_in_memory - return paths2ids(paths_utf8, search_indexes, - require_versioned=require_versioned) + return paths2ids( + paths_utf8, search_indexes, require_versioned=require_versioned + ) - def _paths2ids_in_memory(self, paths, search_indexes, - require_versioned=True): + def _paths2ids_in_memory(self, paths, search_indexes, require_versioned=True): state = self.current_dirstate() state._read_dirblocks_if_needed() def _entries_for_path(path): """Return a list with all the entries that match path for all ids.""" dirname, basename = os.path.split(path) - key = (dirname, basename, b'') + key = (dirname, basename, b"") block_index, present = state._find_block_index_from_key(key) if not present: # the block which should contain path is absent. @@ -933,11 +990,11 @@ def _entries_for_path(path): entry_index, _ = state._find_entry_index(key, block) # we may need to look at multiple entries at this path: walk while # the paths match. - while (entry_index < len(block) and - block[entry_index][0][0:2] == key[0:2]): + while entry_index < len(block) and block[entry_index][0][0:2] == key[0:2]: result.append(block[entry_index]) entry_index += 1 return result + if require_versioned: # -- check all supplied paths are versioned in a search tree. -- all_versioned = True @@ -952,7 +1009,7 @@ def _entries_for_path(path): for entry in path_entries: # for each tree. for index in search_indexes: - if entry[1][index][0] != b'a': # absent + if entry[1][index][0] != b"a": # absent found_versioned = True # all good: found a versioned cell break @@ -962,12 +1019,11 @@ def _entries_for_path(path): all_versioned = False break if not all_versioned: - raise errors.PathsNotVersionedError( - [p.decode('utf-8') for p in paths]) + raise errors.PathsNotVersionedError([p.decode("utf-8") for p in paths]) # -- remove redundancy in supplied paths to prevent over-scanning -- search_paths = { - p.encode('utf-8') - for p in osutils.minimum_path_selection(paths)} + p.encode("utf-8") for p in osutils.minimum_path_selection(paths) + } # sketch: # for all search_indexs in each path at or under each element of # search_paths, if the detail is relocated: add the id, and add the @@ -984,12 +1040,12 @@ def _process_entry(entry): nothing. Otherwise add the id to found_ids. 
""" for index in search_indexes: - if entry[1][index][0] == b'r': # relocated - if not osutils.is_inside_any(searched_paths, - entry[1][index][1]): + if entry[1][index][0] == b"r": # relocated + if not osutils.is_inside_any(searched_paths, entry[1][index][1]): search_paths.add(entry[1][index][1]) - elif entry[1][index][0] != b'a': # absent + elif entry[1][index][0] != b"a": # absent found_ids.add(entry[0][2]) + while search_paths: current_root = search_paths.pop() searched_paths.add(current_root) @@ -1001,17 +1057,17 @@ def _process_entry(entry): continue for entry in root_entries: _process_entry(entry) - initial_key = (current_root, b'', b'') + initial_key = (current_root, b"", b"") block_index, _ = state._find_block_index_from_key(initial_key) - while (block_index < len(state._dirblocks) and - osutils.is_inside(current_root, state._dirblocks[block_index][0])): + while block_index < len(state._dirblocks) and osutils.is_inside( + current_root, state._dirblocks[block_index][0] + ): for entry in state._dirblocks[block_index][1]: _process_entry(entry) block_index += 1 return found_ids - def _paths2ids_using_bisect(self, paths, search_indexes, - require_versioned=True): + def _paths2ids_using_bisect(self, paths, search_indexes, require_versioned=True): state = self.current_dirstate() found_ids = set() @@ -1023,11 +1079,12 @@ def _paths2ids_using_bisect(self, paths, search_indexes, for dir_name in split_paths: if dir_name not in found_dir_names: raise errors.PathsNotVersionedError( - [p.decode('utf-8') for p in paths]) + [p.decode("utf-8") for p in paths] + ) for dir_name_id, trees_info in found.items(): for index in search_indexes: - if trees_info[index][0] not in (b'r', b'a'): + if trees_info[index][0] not in (b"r", b"a"): found_ids.add(dir_name_id[2]) return found_ids @@ -1051,8 +1108,11 @@ def revision_tree(self, revision_id): if revision_id in dirstate.get_ghosts(): raise errors.NoSuchRevisionInTree(self, revision_id) return DirStateRevisionTree( - dirstate, revision_id, self.branch.repository, - get_transport_from_path(self.basedir)) + dirstate, + revision_id, + self.branch.repository, + get_transport_from_path(self.basedir), + ) def set_last_revision(self, new_revision): """Change the last revision in the working tree.""" @@ -1062,12 +1122,13 @@ def set_last_revision(self, new_revision): if len(parents) >= 2: raise AssertionError( "setting the last parent to none with a pending merge " - "is unsupported.") + "is unsupported." + ) self.set_parent_ids([]) else: self.set_parent_ids( - [new_revision] + parents[1:], - allow_leftmost_as_ghost=True) + [new_revision] + parents[1:], allow_leftmost_as_ghost=True + ) def set_parent_ids(self, revision_ids, allow_leftmost_as_ghost=False): """Set the parent ids to revision_ids. @@ -1094,7 +1155,8 @@ def set_parent_ids(self, revision_ids, allow_leftmost_as_ghost=False): revtree = None trees.append((revision_id, revtree)) self.set_parent_trees( - trees, allow_leftmost_as_ghost=allow_leftmost_as_ghost) + trees, allow_leftmost_as_ghost=allow_leftmost_as_ghost + ) def set_parent_trees(self, parents_list, allow_leftmost_as_ghost=False): """Set the parents of the working tree. 
@@ -1129,17 +1191,24 @@ def set_parent_trees(self, parents_list, allow_leftmost_as_ghost=False): if tree is not None: real_trees.append((rev_id, tree)) else: - real_trees.append((rev_id, - self.branch.repository.revision_tree( - _mod_revision.NULL_REVISION))) + real_trees.append( + ( + rev_id, + self.branch.repository.revision_tree( + _mod_revision.NULL_REVISION + ), + ) + ) ghosts.append(rev_id) accepted_revisions.add(rev_id) updated = False - if (len(real_trees) == 1 + if ( + len(real_trees) == 1 and not ghosts and self.branch.repository._format.fast_deltas and isinstance(real_trees[0][1], InventoryRevisionTree) - and self.get_parent_ids()): + and self.get_parent_ids() + ): rev_id, rev_tree = real_trees[0] basis_id = self.get_parent_ids()[0] # There are times when basis_tree won't be in @@ -1152,7 +1221,8 @@ def set_parent_trees(self, parents_list, allow_leftmost_as_ghost=False): pass else: delta = rev_tree.root_inventory._make_delta( - basis_tree.root_inventory) + basis_tree.root_inventory + ) dirstate.update_basis_by_delta(delta, rev_id) updated = True if not updated: @@ -1162,7 +1232,7 @@ def set_parent_trees(self, parents_list, allow_leftmost_as_ghost=False): def _set_root_id(self, file_id): """See WorkingTree.set_root_id.""" state = self.current_dirstate() - state.set_path_id(b'', file_id) + state.set_path_id(b"", file_id) if state._dirblock_state == dirstate.DirState.IN_MEMORY_MODIFIED: self._make_dirty(reset_inventory=True) @@ -1187,7 +1257,7 @@ def unlock(self): # eventually we should do signature checking during read locks for # dirstate updates. - if self._control_files._lock_mode == 'w': + if self._control_files._lock_mode == "w": if self._dirty: self.flush() if self._dirstate is not None: @@ -1233,12 +1303,15 @@ def unversion(self, paths): # walk the state marking unversioned things as absent. # if there are any un-unversioned ids at the end, raise for key, details in state._dirblocks[0][1]: - if (details[0][0] not in (b'a', b'r') and # absent or relocated - key[2] in ids_to_unversion): + if ( + details[0][0] not in (b"a", b"r") # absent or relocated + and key[2] in ids_to_unversion + ): # I haven't written the code to unversion / yet - it should # be supported. raise errors.BzrError( - 'Unversioning the / is not currently supported') + "Unversioning the / is not currently supported" + ) block_index = 0 while block_index < len(state._dirblocks): # process one directory at a time. @@ -1246,9 +1319,9 @@ def unversion(self, paths): # first check: is the path one to remove - it or its children delete_block = False for path in paths_to_unversion: - if (block[0].startswith(path) and - (len(block[0]) == len(path) or - block[0][len(path)] == '/')): + if block[0].startswith(path) and ( + len(block[0]) == len(path) or block[0][len(path)] == "/" + ): # this entire block should be deleted - its the block for a # path to unversion; or the child of one delete_block = True @@ -1261,7 +1334,7 @@ def unversion(self, paths): entry_index = 0 while entry_index < len(block[1]): entry = block[1][entry_index] - if entry[1][0][0] in (b'a', b'r'): + if entry[1][0][0] in (b"a", b"r"): # don't remove absent or renamed entries entry_index += 1 else: @@ -1277,15 +1350,17 @@ def unversion(self, paths): entry_index = 0 while entry_index < len(block[1]): entry = block[1][entry_index] - if (entry[1][0][0] in (b'a', b'r') or # absent, relocated + if ( + entry[1][0][0] in (b"a", b"r") # absent, relocated + or # ^ some parent row. 
- entry[0][2] not in ids_to_unversion): + entry[0][2] not in ids_to_unversion + ): # ^ not an id to unversion entry_index += 1 continue - if entry[1][0][0] == b'd': - paths_to_unversion.add( - pathjoin(entry[0][0], entry[0][1])) + if entry[1][0][0] == b"d": + paths_to_unversion.add(pathjoin(entry[0][0], entry[0][1])) if not state._make_absent(entry): entry_index += 1 # we have unversioned this id @@ -1304,8 +1379,7 @@ def rename_one(self, from_rel, to_rel, after=False): """See WorkingTree.rename_one.""" with self.lock_tree_write(): self.flush() - super().rename_one( - from_rel, to_rel, after) + super().rename_one(from_rel, to_rel, after) def apply_inventory_delta(self, changes): """See MutableTree.apply_inventory_delta.""" @@ -1327,8 +1401,10 @@ def _validate(self): def _write_inventory(self, inv): """Write inventory as the current inventory.""" if self._dirty: - raise AssertionError("attempting to write an inventory when the " - "dirstate is dirty will lose pending changes") + raise AssertionError( + "attempting to write an inventory when the " + "dirstate is dirty will lose pending changes" + ) with self.lock_tree_write(): had_inventory = self._inventory is not None # Setting self._inventory = None forces the dirstate to regenerate the @@ -1355,11 +1431,16 @@ def reset_state(self, revision_ids=None): revision_ids = self.get_parent_ids() if not revision_ids: base_tree = self.branch.repository.revision_tree( - _mod_revision.NULL_REVISION) + _mod_revision.NULL_REVISION + ) trees = [] else: - trees = list(zip(revision_ids, - self.branch.repository.revision_trees(revision_ids))) + trees = list( + zip( + revision_ids, + self.branch.repository.revision_trees(revision_ids), + ) + ) base_tree = trees[0][1] state = self.current_dirstate() # We don't support ghosts yet @@ -1367,21 +1448,22 @@ def reset_state(self, revision_ids=None): class ContentFilterAwareSHA1Provider(dirstate.SHA1Provider): - def __init__(self, tree): self.tree = tree def sha1(self, abspath): """See dirstate.SHA1Provider.sha1().""" filters = self.tree._content_filter_stack( - self.tree.relpath(osutils.safe_unicode(abspath))) + self.tree.relpath(osutils.safe_unicode(abspath)) + ) return _mod_filters.internal_size_sha_file_byname(abspath, filters)[1] def stat_and_sha1(self, abspath): """See dirstate.SHA1Provider.stat_and_sha1().""" filters = self.tree._content_filter_stack( - self.tree.relpath(osutils.safe_unicode(abspath))) - with open(abspath, 'rb', 65000) as file_obj: + self.tree.relpath(osutils.safe_unicode(abspath)) + ) + with open(abspath, "rb", 65000) as file_obj: statvalue = os.fstat(file_obj.fileno()) if filters: file_obj, size = _mod_filters.filtered_input_file(file_obj, filters) @@ -1413,7 +1495,7 @@ def _file_content_summary(self, path, stat_result): # can't trust it for content-filtered trees. We just return None. 
dirstate_sha1 = self._dirstate.sha1_from_stat(path, stat_result) executable = self._is_executable_from_path_and_stat(path, stat_result) - return ('file', None, executable, dirstate_sha1) + return ("file", None, executable, dirstate_sha1) class WorkingTree4(DirStateWorkingTree): @@ -1454,21 +1536,25 @@ def _make_views(self): class DirStateWorkingTreeFormat(WorkingTreeFormatMetaDir): - missing_parent_conflicts = True supports_versioned_directories = True _lock_class = LockDir - _lock_file_name = 'lock' + _lock_file_name = "lock" def _open_control_files(self, a_controldir): transport = a_controldir.get_workingtree_transport(None) - return LockableFiles(transport, self._lock_file_name, - self._lock_class) - - def initialize(self, a_controldir, revision_id=None, from_branch=None, - accelerator_tree=None, hardlink=False): + return LockableFiles(transport, self._lock_file_name, self._lock_class) + + def initialize( + self, + a_controldir, + revision_id=None, + from_branch=None, + accelerator_tree=None, + hardlink=False, + ): """See WorkingTreeFormat.initialize(). :param revision_id: allows creating a working tree at a different @@ -1483,29 +1569,32 @@ def initialize(self, a_controldir, revision_id=None, from_branch=None, These trees get an initial random root id, if their repository supports rich root data, TREE_ROOT otherwise. """ - a_controldir.transport.local_abspath('.') + a_controldir.transport.local_abspath(".") transport = a_controldir.get_workingtree_transport(self) control_files = self._open_control_files(a_controldir) control_files.create_lock() control_files.lock_write() - transport.put_bytes('format', self.as_string(), - mode=a_controldir._get_file_mode()) + transport.put_bytes( + "format", self.as_string(), mode=a_controldir._get_file_mode() + ) if from_branch is not None: branch = from_branch else: branch = a_controldir.open_branch() if revision_id is None: revision_id = branch.last_revision() - local_path = transport.local_abspath('dirstate') + local_path = transport.local_abspath("dirstate") # write out new dirstate (must exist when we create the tree) state = dirstate.DirState.initialize(local_path) state.unlock() del state - wt = self._tree_class(a_controldir.root_transport.local_abspath('.'), - branch, - _format=self, - _controldir=a_controldir, - _control_files=control_files) + wt = self._tree_class( + a_controldir.root_transport.local_abspath("."), + branch, + _format=self, + _controldir=a_controldir, + _control_files=control_files, + ) wt._new_tree() wt.lock_tree_write() try: @@ -1536,7 +1625,7 @@ def initialize(self, a_controldir, revision_id=None, from_branch=None, wt.flush() # if the basis has a root id we have to use that; otherwise we # use a new random one - basis_root_id = basis.path2id('') + basis_root_id = basis.path2id("") if basis_root_id is not None: wt._set_root_id(basis_root_id) wt.flush() @@ -1551,10 +1640,13 @@ def initialize(self, a_controldir, revision_id=None, from_branch=None, # because wt4.apply_inventory_delta does not mutate the input # inventory entries. bzr_transform.build_tree( - basis, wt, accelerator_tree, + basis, + wt, + accelerator_tree, hardlink=hardlink, - delta_from_tree=delta_from_tree) - for hook in MutableTree.hooks['post_build_tree']: + delta_from_tree=delta_from_tree, + ) + for hook in MutableTree.hooks["post_build_tree"]: hook(wt) finally: control_files.unlock() @@ -1579,7 +1671,7 @@ def open(self, a_controldir, _found=False): if not _found: # we are being called directly and must probe. 
raise NotImplementedError - a_controldir.transport.local_abspath('.') + a_controldir.transport.local_abspath(".") wt = self._open(a_controldir, self._open_control_files(a_controldir)) return wt @@ -1589,11 +1681,13 @@ def _open(self, a_controldir, control_files): :param a_controldir: the dir for the tree. :param control_files: the control files for the tree. """ - return self._tree_class(a_controldir.root_transport.local_abspath('.'), - branch=a_controldir.open_branch(), - _format=self, - _controldir=a_controldir, - _control_files=control_files) + return self._tree_class( + a_controldir.root_transport.local_abspath("."), + branch=a_controldir.open_branch(), + _format=self, + _controldir=a_controldir, + _control_files=control_files, + ) def __get_matchingcontroldir(self): return self._get_matchingcontroldir() @@ -1601,8 +1695,7 @@ def __get_matchingcontroldir(self): def _get_matchingcontroldir(self): """Overrideable method to get a bzrdir for testing.""" # please test against something that will let us do tree references - return controldir.format_registry.make_controldir( - 'development-subtree') + return controldir.format_registry.make_controldir("development-subtree") _matchingcontroldir = property(__get_matchingcontroldir) @@ -1671,8 +1764,7 @@ def get_format_description(self): def _init_custom_control_files(self, wt): """Subclasses with custom control files should override this method.""" - wt._transport.put_bytes('views', b'', - mode=wt.controldir._get_file_mode()) + wt._transport.put_bytes("views", b"", mode=wt.controldir._get_file_mode()) def supports_content_filtering(self): return True @@ -1684,7 +1776,7 @@ def _get_matchingcontroldir(self): """Overrideable method to get a bzrdir for testing.""" # We use 'development-subtree' instead of '2a', because we have a # few tests that want to test tree references - return controldir.format_registry.make_controldir('development-subtree') + return controldir.format_registry.make_controldir("development-subtree") class DirStateRevisionTree(InventoryTree): @@ -1703,14 +1795,13 @@ def __init__(self, dirstate, revision_id, repository, nested_tree_transport): self._dirstate_locked = False self._nested_tree_transport = nested_tree_transport self._repo_supports_tree_reference = getattr( - repository._format, "supports_tree_reference", - False) + repository._format, "supports_tree_reference", False + ) def __repr__(self): return f"<{self.__class__.__name__} of {self._revision_id} in {self._dirstate}>" - def annotate_iter(self, path, - default_revision=_mod_revision.CURRENT_REVISION): + def annotate_iter(self, path, default_revision=_mod_revision.CURRENT_REVISION): """See Tree.annotate_iter.""" file_id = self.path2id(path) text_key = (file_id, self.get_file_revision(path)) @@ -1723,7 +1814,7 @@ def iter_child_entries(self, path): if inv is None: raise NoSuchFile(path) ie = inv.get_entry(inv_file_id) - if ie.kind != 'directory': + if ie.kind != "directory": raise errors.NotADirectory(path) return inv.iter_sorted_children(inv_file_id) @@ -1751,25 +1842,26 @@ def filter_unversioned_files(self, paths): pred = self.has_filename return {p for p in paths if not pred(p)} - def id2path(self, file_id, recurse='down'): + def id2path(self, file_id, recurse="down"): """Convert a file-id to a path.""" with self.lock_read(): entry = self._get_entry(file_id=file_id) if entry == (None, None): - if recurse == 'down': - if debug.debug_flag_enabled('evil'): - trace.mutter_callsite( - 2, "Tree.id2path scans all nested trees.") + if recurse == "down": + if 
debug.debug_flag_enabled("evil"): + trace.mutter_callsite(2, "Tree.id2path scans all nested trees.") for nested_path in self.iter_references(): nested_tree = self.get_nested_tree(nested_path) try: - return osutils.pathjoin(nested_path, nested_tree.id2path(file_id)) + return osutils.pathjoin( + nested_path, nested_tree.id2path(file_id) + ) except errors.NoSuchId: pass raise errors.NoSuchId(tree=self, file_id=file_id) path_utf8 = osutils.pathjoin(entry[0][0], entry[0][1]) - return path_utf8.decode('utf8') + return path_utf8.decode("utf8") def get_nested_tree(self, path): with self.lock_read(): @@ -1779,16 +1871,18 @@ def get_nested_tree(self, path): def _get_nested_tree(self, path, file_id, reference_revision): try: branch = _mod_branch.Branch.open_from_transport( - self._nested_tree_transport.clone(path)) + self._nested_tree_transport.clone(path) + ) except errors.NotBranchError as e: raise MissingNestedTree(path) from e try: revtree = branch.repository.revision_tree(reference_revision) except errors.NoSuchRevision as e: raise MissingNestedTree(path) from e - if file_id is not None and revtree.path2id('') != file_id: - raise AssertionError('mismatching file id: {!r} != {!r}'.format( - revtree.path2id(''), file_id)) + if file_id is not None and revtree.path2id("") != file_id: + raise AssertionError( + "mismatching file id: {!r} != {!r}".format(revtree.path2id(""), file_id) + ) return revtree def iter_references(self): @@ -1815,15 +1909,18 @@ def _get_entry(self, file_id=None, path=None): :return: The dirstate row tuple for path/file_id, or (None, None) """ if file_id is None and path is None: - raise errors.BzrError('must supply file_id or path') + raise errors.BzrError("must supply file_id or path") if path is not None: - path = path.encode('utf8') + path = path.encode("utf8") try: parent_index = self._get_parent_index() except ValueError as err: - raise errors.NoSuchRevisionInTree(self._dirstate, self._revision_id) from err - return self._dirstate._get_entry(parent_index, fileid_utf8=file_id, - path_utf8=path) + raise errors.NoSuchRevisionInTree( + self._dirstate, self._revision_id + ) from err + return self._dirstate._get_entry( + parent_index, fileid_utf8=file_id, path_utf8=path + ) def _generate_inventory(self): """Create and set self.inventory from the dirstate object. @@ -1835,21 +1932,22 @@ def _generate_inventory(self): """ if not self._locked: raise AssertionError( - 'cannot generate inventory of an unlocked ' - 'dirstate revision tree') + "cannot generate inventory of an unlocked " "dirstate revision tree" + ) # separate call for profiling - makes it clear where the costs are. self._dirstate._read_dirblocks_if_needed() if self._revision_id not in self._dirstate.get_parent_ids(): raise AssertionError( - 'parent {} has disappeared from {}'.format( - self._revision_id, self._dirstate.get_parent_ids())) + "parent {} has disappeared from {}".format( + self._revision_id, self._dirstate.get_parent_ids() + ) + ) parent_index = self._dirstate.get_parent_ids().index(self._revision_id) + 1 # This is identical now to the WorkingTree _generate_inventory except # for the tree index use. 
- root_key, current_entry = self._dirstate._get_entry( - parent_index, path_utf8=b'') + root_key, current_entry = self._dirstate._get_entry(parent_index, path_utf8=b"") current_id = root_key[2] - if current_entry[parent_index][0] != b'd': + if current_entry[parent_index][0] != b"d": raise AssertionError() inv = Inventory(root_id=current_id, revision_id=self._revision_id) if current_entry[parent_index][4] == b"": @@ -1862,7 +1960,7 @@ def _generate_inventory(self): utf8_decode = cache_utf8._utf8_decode # we could do this straight out of the dirstate; it might be fast # and should be profiled - RBC 20070216 - parent_ies = {b'': inv.root} + parent_ies = {b"": inv.root} for block in self._dirstate._dirblocks[1:]: # skip root dirname = block[0] try: @@ -1871,39 +1969,41 @@ def _generate_inventory(self): # all the paths in this block are not versioned in this tree continue for key, entry in block[1]: - (minikind, fingerprint, size, executable, - revid) = entry[parent_index] - if minikind in (b'a', b'r'): # absent, relocated + (minikind, fingerprint, size, executable, revid) = entry[parent_index] + if minikind in (b"a", b"r"): # absent, relocated # not this tree continue name = key[1] name_unicode = utf8_decode(name)[0] file_id = key[2] kind = minikind_to_kind[minikind] - inv_entry = factory[kind](file_id, name_unicode, - parent_ie.file_id) + inv_entry = factory[kind](file_id, name_unicode, parent_ie.file_id) inv_entry.revision = revid - if kind == 'file': + if kind == "file": inv_entry.executable = bool(executable) inv_entry.text_size = size inv_entry.text_sha1 = fingerprint - elif kind == 'directory': - parent_ies[(dirname + b'/' + name).strip(b'/')] = inv_entry - elif kind == 'symlink': + elif kind == "directory": + parent_ies[(dirname + b"/" + name).strip(b"/")] = inv_entry + elif kind == "symlink": inv_entry.symlink_target = utf8_decode(fingerprint)[0] - elif kind == 'tree-reference': + elif kind == "tree-reference": inv_entry.reference_revision = fingerprint or None else: raise AssertionError( - f"cannot convert entry {entry!r} into an InventoryEntry") + f"cannot convert entry {entry!r} into an InventoryEntry" + ) try: inv.add(inv_entry) except DuplicateFileId as err: raise AssertionError( - f'file_id {file_id} already in' - f' inventory as {inv.get_entry(file_id)}') from err + f"file_id {file_id} already in" + f" inventory as {inv.get_entry(file_id)}" + ) from err except errors.InconsistentDelta as err: - raise AssertionError(f'name {name_unicode!r} already in parent') from err + raise AssertionError( + f"name {name_unicode!r} already in parent" + ) from err self._inventory = inv def get_file_mtime(self, path): @@ -1913,7 +2013,7 @@ def get_file_mtime(self, path): """ # Make sure the file exists entry = self._get_entry(path=path) - if entry == (None, None): # do we raise? + if entry == (None, None): # do we raise? 
nested_tree, subpath = self.get_containing_nested_tree(path) if nested_tree is not None: return nested_tree.get_file_mtime(subpath) @@ -1930,7 +2030,7 @@ def get_file_sha1(self, path, stat_value=None): entry = self._get_entry(path=path) parent_index = self._get_parent_index() parent_details = entry[1][parent_index] - if parent_details[0] == b'f': + if parent_details[0] == b"f": return parent_details[1] return None @@ -1951,14 +2051,14 @@ def get_file_text(self, path): content = None for _, content_iter in self.iter_files_bytes([(path, None)]): if content is not None: - raise AssertionError('iter_files_bytes returned' - ' too many entries') + raise AssertionError("iter_files_bytes returned" " too many entries") # For each entry returned by iter_files_bytes, we must consume the # content_iter before we step the files iterator. - content = b''.join(content_iter) + content = b"".join(content_iter) if content is None: - raise AssertionError('iter_files_bytes did not return' - ' the requested data') + raise AssertionError( + "iter_files_bytes did not return" " the requested data" + ) return content def get_reference_revision(self, path): @@ -1976,8 +2076,9 @@ def iter_files_bytes(self, desired_files): entry = self._get_entry(path=path) if entry == (None, None): raise NoSuchFile(path) - repo_desired_files.append((entry[0][2], entry[1][parent_index][4], - identifier)) + repo_desired_files.append( + (entry[0][2], entry[1][parent_index][4], identifier) + ) return self._repository.iter_files_bytes(repo_desired_files) def get_symlink_target(self, path): @@ -1985,11 +2086,11 @@ def get_symlink_target(self, path): if entry is None: raise NoSuchFile(tree=self, path=path) parent_index = self._get_parent_index() - if entry[1][parent_index][0] != b'l': + if entry[1][parent_index][0] != b"l": return None else: target = entry[1][parent_index][1] - target = target.decode('utf8') + target = target.decode("utf8") return target def get_revision_id(self): @@ -2003,8 +2104,7 @@ def _get_root_inventory(self): self._generate_inventory() return self._inventory - root_inventory = property(_get_root_inventory, - doc="Inventory of this Tree") + root_inventory = property(_get_root_inventory, doc="Inventory of this Tree") def get_parent_ids(self): """The parents of a tree in the dirstate are not cached.""" @@ -2028,12 +2128,12 @@ def path_content_summary(self, path): """See Tree.path_content_summary.""" inv, inv_file_id = self._path2inv_file_id(path) if inv_file_id is None: - return ('missing', None, None, None) + return ("missing", None, None, None) entry = inv.get_entry(inv_file_id) kind = entry.kind - if kind == 'file': + if kind == "file": return (kind, entry.text_size, entry.executable, entry.text_sha1) - elif kind == 'symlink': + elif kind == "symlink": return (kind, None, None, entry.symlink_target) else: return (kind, None, None, None) @@ -2050,8 +2150,9 @@ def is_executable(self, path): def is_locked(self): return self._locked - def list_files(self, include_root=False, from_dir=None, recursive=True, - recurse_nested=False): + def list_files( + self, include_root=False, from_dir=None, recursive=True, recurse_nested=False + ): # The only files returned by this are those from the version if from_dir is None: from_dir_id = None @@ -2061,25 +2162,30 @@ def list_files(self, include_root=False, from_dir=None, recursive=True, if from_dir_id is None: # Directory not versioned return iter([]) + def iter_entries(inv): entries = inv.iter_entries(from_dir=from_dir_id, recursive=recursive) if inv.root is not None and not 
include_root and from_dir is None: # skip the root for compatibility with the current apis. next(entries) for path, entry in entries: - if entry.kind == 'tree-reference' and recurse_nested: + if entry.kind == "tree-reference" and recurse_nested: subtree = self._get_nested_tree( - path, entry.file_id, entry.reference_revision) + path, entry.file_id, entry.reference_revision + ) for subpath, status, kind, entry in subtree.list_files( - include_root=True, recursive=recursive, - recurse_nested=recurse_nested): + include_root=True, + recursive=recursive, + recurse_nested=recurse_nested, + ): if subpath: full_subpath = osutils.pathjoin(path, subpath) else: full_subpath = path yield full_subpath, status, kind, entry else: - yield path, 'V', entry.kind, entry + yield path, "V", entry.kind, entry + return iter_entries(inv) def lock_read(self): @@ -2137,7 +2243,7 @@ def walkdirs(self, prefix=""): # This should be cleaned up to use the much faster Dirstate code # So for now, we just build up the parent inventory, and extract # it the same way RevisionTree does. - _directory = 'directory' + _directory = "directory" inv = self._get_root_inventory() top_id = inv.path2id(prefix) if top_id is None: @@ -2149,7 +2255,7 @@ def walkdirs(self, prefix=""): relpath, file_id = pending.pop() # 0 - relpath, 1- file-id if relpath: - relroot = relpath + '/' + relroot = relpath + "/" else: relroot = "" # FIXME: stash the node in pending @@ -2182,7 +2288,7 @@ def __init__(self, source, target): @staticmethod def make_source_parent_tree(source, target): """Change the source tree into a parent of the target.""" - revid = source.commit('record tree') + revid = source.commit("record tree") target.branch.fetch(source.branch, revid) target.set_parent_ids([revid]) return target.basis_tree(), target @@ -2194,11 +2300,12 @@ def make_source_parent_tree_python_dirstate(klass, test_case, source, target): return result @classmethod - def make_source_parent_tree_compiled_dirstate(klass, test_case, source, - target): + def make_source_parent_tree_compiled_dirstate(klass, test_case, source, target): from .tests.test__dirstate_helpers import compiled_dirstate_helpers_feature + test_case.requireFeature(compiled_dirstate_helpers_feature) from ._dirstate_helpers_pyx import ProcessEntryC + result = klass.make_source_parent_tree(source, target) result[1]._iter_changes = ProcessEntryC return result @@ -2212,9 +2319,15 @@ def _test_mutable_trees_to_test_trees(klass, test_case, source, target): # specific flavours. raise NotImplementedError - def iter_changes(self, include_unchanged=False, - specific_files=None, pb=None, extra_trees=None, - require_versioned=True, want_unversioned=False): + def iter_changes( + self, + include_unchanged=False, + specific_files=None, + pb=None, + extra_trees=None, + require_versioned=True, + want_unversioned=False, + ): """Return the changes from source to target. :return: An iterator that yields tuples. See InterTree.iter_changes @@ -2237,17 +2350,25 @@ def iter_changes(self, include_unchanged=False, if extra_trees is None: extra_trees = [] # TODO: handle extra trees in the dirstate. 
- if (extra_trees or specific_files == []): + if extra_trees or specific_files == []: # we can't fast-path these cases (yet) return super().iter_changes( - include_unchanged, specific_files, pb, extra_trees, - require_versioned, want_unversioned=want_unversioned) + include_unchanged, + specific_files, + pb, + extra_trees, + require_versioned, + want_unversioned=want_unversioned, + ) parent_ids = self.target.get_parent_ids() - if not (self.source._revision_id in parent_ids - or self.source._revision_id == _mod_revision.NULL_REVISION): + if not ( + self.source._revision_id in parent_ids + or self.source._revision_id == _mod_revision.NULL_REVISION + ): raise AssertionError( f"revision {{{self.source._revision_id}}} is not stored in {{{self.target}}}, but {self.iter_changes} " - "can only be used for trees stored in the dirstate") + "can only be used for trees stored in the dirstate" + ) target_index = 0 if self.source._revision_id == _mod_revision.NULL_REVISION: source_index = None @@ -2256,12 +2377,14 @@ def iter_changes(self, include_unchanged=False, if self.source._revision_id not in parent_ids: raise AssertionError( "Failure: source._revision_id: {} not in target.parent_ids({})".format( - self.source._revision_id, parent_ids)) + self.source._revision_id, parent_ids + ) + ) source_index = 1 + parent_ids.index(self.source._revision_id) indices = (source_index, target_index) if specific_files is None: - specific_files = {''} + specific_files = {""} # -- get the state object and prepare it. state = self.target.current_dirstate() @@ -2270,7 +2393,7 @@ def iter_changes(self, include_unchanged=False, # -- check all supplied paths are versioned in a search tree. -- not_versioned = [] for path in specific_files: - path_entries = state._entries_for_path(path.encode('utf-8')) + path_entries = state._entries_for_path(path.encode("utf-8")) if not path_entries: # this specified path is not present at all: error not_versioned.append(path) @@ -2280,7 +2403,7 @@ def iter_changes(self, include_unchanged=False, for entry in path_entries: # for each tree. for index in indices: - if entry[1][index][0] != b'a': # absent + if entry[1][index][0] != b"a": # absent found_versioned = True # all good: found a versioned cell break @@ -2297,12 +2420,18 @@ def iter_changes(self, include_unchanged=False, for path in osutils.minimum_path_selection(specific_files): # Note, if there are many specific files, using cache_utf8 # would be good here. - search_specific_files_utf8.add(path.encode('utf8')) + search_specific_files_utf8.add(path.encode("utf8")) iter_changes = self.target._iter_changes( - include_unchanged, self.target._supports_executable(), - search_specific_files_utf8, state, source_index, target_index, - want_unversioned, self.target) + include_unchanged, + self.target._supports_executable(), + search_specific_files_utf8, + state, + source_index, + target_index, + want_unversioned, + self.target, + ) return iter_changes.iter_changes() @staticmethod @@ -2311,12 +2440,13 @@ def is_compatible(source, target): if not isinstance(target, DirStateWorkingTree): return False # the source must be a revtree or dirstate rev tree. 
- if not isinstance(source, - (revisiontree.RevisionTree, DirStateRevisionTree)): + if not isinstance(source, (revisiontree.RevisionTree, DirStateRevisionTree)): return False # the source revid must be in the target dirstate - if not (source._revision_id == _mod_revision.NULL_REVISION or - source._revision_id in target.get_parent_ids()): + if not ( + source._revision_id == _mod_revision.NULL_REVISION + or source._revision_id in target.get_parent_ids() + ): # TODO: what about ghosts? it may well need to # check for them explicitly. return False @@ -2347,8 +2477,9 @@ def convert(self, tree): def create_dirstate_data(self, tree): """Create the dirstate based data for tree.""" - local_path = tree.controldir.get_workingtree_transport( - None).local_abspath('dirstate') + local_path = tree.controldir.get_workingtree_transport(None).local_abspath( + "dirstate" + ) state = dirstate.DirState.from_tree(tree, local_path) state.save() state.unlock() @@ -2356,8 +2487,13 @@ def create_dirstate_data(self, tree): def remove_xml_files(self, tree): """Remove the oldformat 3 data.""" transport = tree.controldir.get_workingtree_transport(None) - for path in ['basis-inventory-cache', 'inventory', 'last-revision', - 'pending-merges', 'stat-cache']: + for path in [ + "basis-inventory-cache", + "inventory", + "last-revision", + "pending-merges", + "stat-cache", + ]: try: transport.delete(path) except NoSuchFile: @@ -2366,9 +2502,11 @@ def remove_xml_files(self, tree): def update_format(self, tree): """Change the format marker.""" - tree._transport.put_bytes('format', - self.target_format.as_string(), - mode=tree.controldir._get_file_mode()) + tree._transport.put_bytes( + "format", + self.target_format.as_string(), + mode=tree.controldir._get_file_mode(), + ) class Converter4to5: @@ -2389,9 +2527,11 @@ def convert(self, tree): def update_format(self, tree): """Change the format marker.""" - tree._transport.put_bytes('format', - self.target_format.as_string(), - mode=tree.controldir._get_file_mode()) + tree._transport.put_bytes( + "format", + self.target_format.as_string(), + mode=tree.controldir._get_file_mode(), + ) class Converter4or5to6: @@ -2413,11 +2553,12 @@ def convert(self, tree): def init_custom_control_files(self, tree): """Initialize custom control files.""" - tree._transport.put_bytes('views', b'', - mode=tree.controldir._get_file_mode()) + tree._transport.put_bytes("views", b"", mode=tree.controldir._get_file_mode()) def update_format(self, tree): """Change the format marker.""" - tree._transport.put_bytes('format', - self.target_format.as_string(), - mode=tree.controldir._get_file_mode()) + tree._transport.put_bytes( + "format", + self.target_format.as_string(), + mode=tree.controldir._get_file_mode(), + ) diff --git a/breezy/bzr/xml5.py b/breezy/bzr/xml5.py index 2516312aab..508fb21122 100644 --- a/breezy/bzr/xml5.py +++ b/breezy/bzr/xml5.py @@ -26,22 +26,24 @@ class InventorySerializer_v5(xml6.InventorySerializer_v6): Packs objects into XML and vice versa. 
""" - format_num = b'5' + + format_num = b"5" root_id = inventory.ROOT_ID - def _unpack_inventory(self, elt, revision_id, entry_cache=None, - return_from_cache=False): + def _unpack_inventory( + self, elt, revision_id, entry_cache=None, return_from_cache=False + ): """Construct from XML Element.""" - root_id = elt.get('file_id') or inventory.ROOT_ID + root_id = elt.get("file_id") or inventory.ROOT_ID root_id = get_utf8_or_ascii(root_id) - format = elt.get('format') + format = elt.get("format") if format is not None: - if format != '5': + if format != "5": raise errors.BzrError(f"invalid format version {format!r} on inventory") - data_revision_id = elt.get('revision_id') + data_revision_id = elt.get("revision_id") if data_revision_id is not None: - revision_id = data_revision_id.encode('utf-8') + revision_id = data_revision_id.encode("utf-8") inv = inventory.Inventory(root_id, revision_id=revision_id) # Optimizations tested # baseline w/entry cache 2.85s @@ -52,23 +54,32 @@ def _unpack_inventory(self, elt, revision_id, entry_cache=None, byid = inv._byid children = inv._children for e in elt: - ie = unpack_inventory_entry(e, entry_cache=entry_cache, - return_from_cache=return_from_cache, root_id=root_id) + ie = unpack_inventory_entry( + e, + entry_cache=entry_cache, + return_from_cache=return_from_cache, + root_id=root_id, + ) try: parent = byid[ie.parent_id] except KeyError as err: - raise errors.BzrError(f"parent_id {{{ie.parent_id}}} not in inventory") from err + raise errors.BzrError( + f"parent_id {{{ie.parent_id}}} not in inventory" + ) from err if ie.file_id in byid: raise inventory.DuplicateFileId(ie.file_id, byid[ie.file_id]) siblings = children[parent.file_id] if ie.name in siblings: raise errors.BzrError( "{} is already versioned".format( - osutils.pathjoin( - inv.id2path(ie.parent_id), ie.name).encode('utf-8'))) + osutils.pathjoin(inv.id2path(ie.parent_id), ie.name).encode( + "utf-8" + ) + ) + ) siblings[ie.name] = ie byid[ie.file_id] = ie - if ie.kind == 'directory': + if ie.kind == "directory": children[ie.file_id] = {} if revision_id is not None: inv.root.revision = revision_id @@ -87,11 +98,15 @@ def _check_revisions(self, inv): def _append_inventory_root(self, append, inv): """Append the inventory root to output.""" if inv.root.file_id not in (None, inventory.ROOT_ID): - fileid = b''.join([b' file_id="', encode_and_escape(inv.root.file_id), b'"']) + fileid = b"".join( + [b' file_id="', encode_and_escape(inv.root.file_id), b'"'] + ) else: fileid = b"" if inv.revision_id is not None: - revid = b''.join([b' revision_id="', encode_and_escape(inv.revision_id), b'"']) + revid = b"".join( + [b' revision_id="', encode_and_escape(inv.revision_id), b'"'] + ) else: revid = b"" append(b'\n' % (fileid, revid)) diff --git a/breezy/bzr/xml6.py b/breezy/bzr/xml6.py index c3ae990ded..d799898a40 100644 --- a/breezy/bzr/xml6.py +++ b/breezy/bzr/xml6.py @@ -25,7 +25,7 @@ class InventorySerializer_v6(xml8.InventorySerializer_v8): converted from format 5 or 7 without updating the sha1. 
""" - format_num = b'6' + format_num = b"6" inventory_serializer_v6 = InventorySerializer_v6() diff --git a/breezy/bzr/xml7.py b/breezy/bzr/xml7.py index bfcc96bcef..14636ab61f 100644 --- a/breezy/bzr/xml7.py +++ b/breezy/bzr/xml7.py @@ -22,8 +22,8 @@ class InventorySerializer_v7(xml6.InventorySerializer_v6): # this format is used by BzrBranch6 - supported_kinds = {'file', 'directory', 'symlink', 'tree-reference'} - format_num = b'7' + supported_kinds = {"file", "directory", "symlink", "tree-reference"} + format_num = b"7" inventory_serializer_v7 = InventorySerializer_v7() diff --git a/breezy/bzr/xml8.py b/breezy/bzr/xml8.py index 91a92b6de2..61a13580e8 100644 --- a/breezy/bzr/xml8.py +++ b/breezy/bzr/xml8.py @@ -28,11 +28,11 @@ ) _xml_unescape_map = { - b'apos': b"'", - b'quot': b'"', - b'amp': b'&', - b'lt': b'<', - b'gt': b'>' + b"apos": b"'", + b"quot": b'"', + b"amp": b"&", + b"lt": b"<", + b"gt": b">", } @@ -41,12 +41,12 @@ def _unescaper(match, _map=_xml_unescape_map): try: return _map[code] except KeyError: - if not code.startswith(b'#'): + if not code.startswith(b"#"): raise - return chr(int(code[1:])).encode('utf8') + return chr(int(code[1:])).encode("utf8") -_unescape_re = lazy_regex.lazy_compile(b'\\&([^;]*);') +_unescape_re = lazy_regex.lazy_compile(b"\\&([^;]*);") def _unescape_xml(data): @@ -67,15 +67,14 @@ class InventorySerializer_v8(XMLInventorySerializer): # This format supports the altered-by hack that reads file ids directly out # of the versionedfile, without doing XML parsing. - supported_kinds = {'file', 'directory', 'symlink'} - format_num = b'8' + supported_kinds = {"file", "directory", "symlink"} + format_num = b"8" # The search regex used by xml based repositories to determine what things # where changed in a single commit. _file_ids_altered_regex = lazy_regex.lazy_compile( - b'file_id="(?P[^"]+)"' - b'.* revision="(?P[^"]+)"' - ) + b'file_id="(?P[^"]+)"' b'.* revision="(?P[^"]+)"' + ) def _check_revisions(self, inv): """Extension point for subclasses to check during serialisation. 
@@ -112,8 +111,11 @@ def _check_cache_size(self, inv_size, entry_cache): recommended_min_cache_size = inv_size * 1.5 if entry_cache.cache_size() < recommended_min_cache_size: recommended_cache_size = inv_size * 2 - trace.mutter('Resizing the inventory entry cache from %d to %d', - entry_cache.cache_size(), recommended_cache_size) + trace.mutter( + "Resizing the inventory entry cache from %d to %d", + entry_cache.cache_size(), + recommended_cache_size, + ) entry_cache.resize(recommended_cache_size) def write_inventory_to_lines(self, inv): @@ -136,8 +138,9 @@ def write_inventory(self, inv, f, working=False): output = [] append = output.append self._append_inventory_root(append, inv) - serialize_inventory_flat(inv, append, - self.root_id, self.supported_kinds, working) + serialize_inventory_flat( + inv, append, self.root_id, self.supported_kinds, working + ) if f is not None: f.writelines(output) # Just to keep the cache from growing without bounds @@ -148,27 +151,32 @@ def write_inventory(self, inv, f, working=False): def _append_inventory_root(self, append, inv): """Append the inventory root to output.""" if inv.revision_id is not None: - revid1 = b''.join( - [b' revision_id="', encode_and_escape(inv.revision_id), b'"']) + revid1 = b"".join( + [b' revision_id="', encode_and_escape(inv.revision_id), b'"'] + ) else: revid1 = b"" - append(b'\n' % ( - self.format_num, revid1)) - append(b'\n' % ( - encode_and_escape(inv.root.file_id), - encode_and_escape(inv.root.name), - encode_and_escape(inv.root.revision))) + append(b'\n' % (self.format_num, revid1)) + append( + b'\n' + % ( + encode_and_escape(inv.root.file_id), + encode_and_escape(inv.root.name), + encode_and_escape(inv.root.revision), + ) + ) def _unpack_entry(self, elt, entry_cache=None, return_from_cache=False): # This is here because it's overridden by xml7 - return unpack_inventory_entry(elt, entry_cache, - return_from_cache) + return unpack_inventory_entry(elt, entry_cache, return_from_cache) - def _unpack_inventory(self, elt, revision_id=None, entry_cache=None, - return_from_cache=False): + def _unpack_inventory( + self, elt, revision_id=None, entry_cache=None, return_from_cache=False + ): """Construct from XML Element.""" - inv = unpack_inventory_flat(elt, self.format_num, self._unpack_entry, - entry_cache, return_from_cache) + inv = unpack_inventory_flat( + elt, self.format_num, self._unpack_entry, entry_cache, return_from_cache + ) self._check_cache_size(len(inv), entry_cache) return inv @@ -188,7 +196,8 @@ def _find_text_key_references(self, line_iterator): raise AssertionError( "_find_text_key_references only " "supported for branches which store inventory as unnested xml" - ", not on %r" % self) + ", not on %r" % self + ) result = {} # this code needs to read every new line in every inventory for the @@ -216,7 +225,7 @@ def _find_text_key_references(self, line_iterator): continue # One call to match.group() returning multiple items is quite a # bit faster than 2 calls to match.group() each returning 1 - file_id, revision_id = match.group('file_id', 'revision_id') + file_id, revision_id = match.group("file_id", "revision_id") # Inlining the cache lookups helps a lot when you make 170,000 # lines and 350k ids, versus 8.4 unique ids. 
diff --git a/breezy/bzr/xml_serializer.py b/breezy/bzr/xml_serializer.py index f916e44402..da519e3f25 100644 --- a/breezy/bzr/xml_serializer.py +++ b/breezy/bzr/xml_serializer.py @@ -44,7 +44,7 @@ def _unpack_revision(self, element): raise NotImplementedError(self._unpack_revision) def write_revision_to_string(self, rev): - return b''.join(self.write_revision_to_lines(rev)) + return b"".join(self.write_revision_to_lines(rev)) def read_revision(self, f): return self._unpack_revision(self._read_element(f)) @@ -59,8 +59,9 @@ def _read_element(self, f): class XMLInventorySerializer(serializer.InventorySerializer): """Abstract XML object serialize/deserialize.""" - def read_inventory_from_lines(self, lines, revision_id=None, - entry_cache=None, return_from_cache=False): + def read_inventory_from_lines( + self, lines, revision_id=None, entry_cache=None, return_from_cache=False + ): """Read xml_string into an inventory object. :param chunks: The xml to read. @@ -80,20 +81,28 @@ def read_inventory_from_lines(self, lines, revision_id=None, make some operations significantly faster. """ try: - return self._unpack_inventory(fromstringlist(lines), revision_id, - entry_cache=entry_cache, - return_from_cache=return_from_cache) + return self._unpack_inventory( + fromstringlist(lines), + revision_id, + entry_cache=entry_cache, + return_from_cache=return_from_cache, + ) except ParseError as e: raise serializer.UnexpectedInventoryFormat(str(e)) from e - def _unpack_inventory(self, element, revision_id: Optional[bytes] = None, entry_cache=None, return_from_cache=False): + def _unpack_inventory( + self, + element, + revision_id: Optional[bytes] = None, + entry_cache=None, + return_from_cache=False, + ): raise NotImplementedError(self._unpack_inventory) def read_inventory(self, f, revision_id=None): try: try: - return self._unpack_inventory(self._read_element(f), - revision_id=None) + return self._unpack_inventory(self._read_element(f), revision_id=None) finally: f.close() except ParseError as e: @@ -119,7 +128,7 @@ def get_utf8_or_ascii(a_str): # not meant as a generic function for all cases. Because it is possible for # an 8-bit string to not be ascii or valid utf8. 
if a_str.__class__ is str: - return a_str.encode('utf-8') + return a_str.encode("utf-8") else: return a_str @@ -127,10 +136,12 @@ def get_utf8_or_ascii(a_str): from .._bzr_rs import encode_and_escape, escape_invalid_chars # noqa: F401 -def unpack_inventory_entry(elt, entry_cache=None, return_from_cache=False, root_id=None): +def unpack_inventory_entry( + elt, entry_cache=None, return_from_cache=False, root_id=None +): elt_get = elt.get - file_id = elt_get('file_id') - revision = elt_get('revision') + file_id = elt_get("file_id") + revision = elt_get("revision") # Check and see if we have already unpacked this exact entry # Some timings for "repo.revision_trees(last_100_revs)" # bzr mysql @@ -172,52 +183,47 @@ def unpack_inventory_entry(elt, entry_cache=None, return_from_cache=False, root_ else: # Only copying directory entries drops us 2.85s => 2.35s if return_from_cache: - if cached_ie.kind == 'directory': + if cached_ie.kind == "directory": return cached_ie.copy() return cached_ie return cached_ie.copy() kind = elt.tag if not inventory.InventoryEntry.versionable_kind(kind): - raise AssertionError(f'unsupported entry kind {kind}') + raise AssertionError(f"unsupported entry kind {kind}") file_id = get_utf8_or_ascii(file_id) if revision is not None: revision = get_utf8_or_ascii(revision) - parent_id = elt_get('parent_id') + parent_id = elt_get("parent_id") if parent_id is not None: parent_id = get_utf8_or_ascii(parent_id) else: parent_id = root_id - if kind == 'directory': - ie = inventory.InventoryDirectory(file_id, - elt_get('name'), - parent_id) - elif kind == 'file': - ie = inventory.InventoryFile(file_id, - elt_get('name'), - parent_id) - text_sha1 = elt_get('text_sha1') + if kind == "directory": + ie = inventory.InventoryDirectory(file_id, elt_get("name"), parent_id) + elif kind == "file": + ie = inventory.InventoryFile(file_id, elt_get("name"), parent_id) + text_sha1 = elt_get("text_sha1") if text_sha1 is not None: - ie.text_sha1 = text_sha1.encode('ascii') - if elt_get('executable') == 'yes': + ie.text_sha1 = text_sha1.encode("ascii") + if elt_get("executable") == "yes": ie.executable = True - v = elt_get('text_size') + v = elt_get("text_size") ie.text_size = v and int(v) - elif kind == 'symlink': - ie = inventory.InventoryLink(file_id, - elt_get('name'), - parent_id) - ie.symlink_target = elt_get('symlink_target') - elif kind == 'tree-reference': - file_id = get_utf8_or_ascii(elt.attrib['file_id']) - name = elt.attrib['name'] - parent_id = get_utf8_or_ascii(elt.attrib['parent_id']) - revision = get_utf8_or_ascii(elt.get('revision')) - reference_revision = get_utf8_or_ascii(elt.get('reference_revision')) - ie = inventory.TreeReference(file_id, name, parent_id, revision, - reference_revision) + elif kind == "symlink": + ie = inventory.InventoryLink(file_id, elt_get("name"), parent_id) + ie.symlink_target = elt_get("symlink_target") + elif kind == "tree-reference": + file_id = get_utf8_or_ascii(elt.attrib["file_id"]) + name = elt.attrib["name"] + parent_id = get_utf8_or_ascii(elt.attrib["parent_id"]) + revision = get_utf8_or_ascii(elt.get("revision")) + reference_revision = get_utf8_or_ascii(elt.get("reference_revision")) + ie = inventory.TreeReference( + file_id, name, parent_id, revision, reference_revision + ) else: raise serializer.UnsupportedInventoryKind(kind) ie.revision = revision @@ -231,8 +237,9 @@ def unpack_inventory_entry(elt, entry_cache=None, return_from_cache=False, root_ return ie -def unpack_inventory_flat(elt, format_num, unpack_entry, - entry_cache=None, 
return_from_cache=False): +def unpack_inventory_flat( + elt, format_num, unpack_entry, entry_cache=None, return_from_cache=False +): """Unpack a flat XML inventory. :param elt: XML element for the inventory @@ -242,15 +249,14 @@ def unpack_inventory_flat(elt, format_num, unpack_entry, :raise UnexpectedInventoryFormat: When unexpected elements or data is encountered """ - if elt.tag != 'inventory': - raise serializer.UnexpectedInventoryFormat(f'Root tag is {elt.tag!r}') - format = elt.get('format') - if ((format is None and format_num is not None) or - format.encode() != format_num): - raise serializer.UnexpectedInventoryFormat(f'Invalid format version {format!r}') - revision_id = elt.get('revision_id') + if elt.tag != "inventory": + raise serializer.UnexpectedInventoryFormat(f"Root tag is {elt.tag!r}") + format = elt.get("format") + if (format is None and format_num is not None) or format.encode() != format_num: + raise serializer.UnexpectedInventoryFormat(f"Invalid format version {format!r}") + revision_id = elt.get("revision_id") if revision_id is not None: - revision_id = revision_id.encode('utf-8') + revision_id = revision_id.encode("utf-8") inv = inventory.Inventory(root_id=None, revision_id=revision_id) for e in elt: ie = unpack_entry(e, entry_cache, return_from_cache) @@ -271,69 +277,107 @@ def serialize_inventory_flat(inv, append, root_id, supported_kinds, working): root_path, root_ie = next(entries) for _path, ie in entries: if ie.parent_id != root_id: - parent_str = b''.join( - [b' parent_id="', encode_and_escape(ie.parent_id), b'"']) + parent_str = b"".join( + [b' parent_id="', encode_and_escape(ie.parent_id), b'"'] + ) else: - parent_str = b'' - if ie.kind == 'file': + parent_str = b"" + if ie.kind == "file": if ie.executable: executable = b' executable="yes"' else: - executable = b'' + executable = b"" if not working: - append(b'\n' % ( - executable, encode_and_escape(ie.file_id), - encode_and_escape(ie.name), parent_str, - encode_and_escape(ie.revision), ie.text_sha1, - ie.text_size)) + append( + b'\n' + % ( + executable, + encode_and_escape(ie.file_id), + encode_and_escape(ie.name), + parent_str, + encode_and_escape(ie.revision), + ie.text_sha1, + ie.text_size, + ) + ) else: - append(b'\n' % ( - executable, encode_and_escape(ie.file_id), - encode_and_escape(ie.name), parent_str)) - elif ie.kind == 'directory': + append( + b'\n' + % ( + executable, + encode_and_escape(ie.file_id), + encode_and_escape(ie.name), + parent_str, + ) + ) + elif ie.kind == "directory": if not working: - append(b'\n' % ( - encode_and_escape(ie.file_id), - encode_and_escape(ie.name), - parent_str, - encode_and_escape(ie.revision))) + append( + b'\n" + % ( + encode_and_escape(ie.file_id), + encode_and_escape(ie.name), + parent_str, + encode_and_escape(ie.revision), + ) + ) else: - append(b'\n' % ( - encode_and_escape(ie.file_id), - encode_and_escape(ie.name), - parent_str)) - elif ie.kind == 'symlink': + append( + b'\n' + % ( + encode_and_escape(ie.file_id), + encode_and_escape(ie.name), + parent_str, + ) + ) + elif ie.kind == "symlink": if not working: - append(b'\n' % ( - encode_and_escape(ie.file_id), - encode_and_escape(ie.name), - parent_str, - encode_and_escape(ie.revision), - encode_and_escape(ie.symlink_target))) + append( + b'\n' + % ( + encode_and_escape(ie.file_id), + encode_and_escape(ie.name), + parent_str, + encode_and_escape(ie.revision), + encode_and_escape(ie.symlink_target), + ) + ) else: - append(b'\n' % ( - encode_and_escape(ie.file_id), - encode_and_escape(ie.name), - parent_str)) 
- elif ie.kind == 'tree-reference': + append( + b'\n' + % ( + encode_and_escape(ie.file_id), + encode_and_escape(ie.name), + parent_str, + ) + ) + elif ie.kind == "tree-reference": if ie.kind not in supported_kinds: raise serializer.UnsupportedInventoryKind(ie.kind) if not working: - append(b'\n' % ( - encode_and_escape(ie.file_id), - encode_and_escape(ie.name), - parent_str, - encode_and_escape(ie.revision), - encode_and_escape(ie.reference_revision))) + append( + b'\n' + % ( + encode_and_escape(ie.file_id), + encode_and_escape(ie.name), + parent_str, + encode_and_escape(ie.revision), + encode_and_escape(ie.reference_revision), + ) + ) else: - append(b'\n' % ( - encode_and_escape(ie.file_id), - encode_and_escape(ie.name), - parent_str)) + append( + b'\n' + % ( + encode_and_escape(ie.file_id), + encode_and_escape(ie.name), + parent_str, + ) + ) else: raise serializer.UnsupportedInventoryKind(ie.kind) - append(b'\n') + append(b"\n") diff --git a/breezy/cache_utf8.py b/breezy/cache_utf8.py index a5d11f8f01..b203307950 100644 --- a/breezy/cache_utf8.py +++ b/breezy/cache_utf8.py @@ -43,10 +43,12 @@ def _utf8_decode_with_None(bytestring, _utf8_decode=_utf8_decode): _utf8_to_unicode_map: Dict[bytes, str] = {} -def encode(unicode_str, - _uni_to_utf8=_unicode_to_utf8_map, - _utf8_to_uni=_utf8_to_unicode_map, - _utf8_encode=_utf8_encode): +def encode( + unicode_str, + _uni_to_utf8=_unicode_to_utf8_map, + _utf8_to_uni=_utf8_to_unicode_map, + _utf8_encode=_utf8_encode, +): """Take this unicode revision id, and get a unicode version.""" # If the key is in the cache try/KeyError is 50% faster than # val = dict.get(key), if val is None: @@ -64,10 +66,12 @@ def encode(unicode_str, return utf8_str -def decode(utf8_str, - _uni_to_utf8=_unicode_to_utf8_map, - _utf8_to_uni=_utf8_to_unicode_map, - _utf8_decode=_utf8_decode): +def decode( + utf8_str, + _uni_to_utf8=_unicode_to_utf8_map, + _utf8_to_uni=_utf8_to_unicode_map, + _utf8_decode=_utf8_decode, +): """Take a utf8 revision id, and decode it, but cache the result.""" try: return _utf8_to_uni[utf8_str] @@ -100,15 +104,15 @@ def get_cached_utf8(utf8_str): return encode(decode(utf8_str)) -def get_cached_ascii(ascii_str, - _uni_to_utf8=_unicode_to_utf8_map, - _utf8_to_uni=_utf8_to_unicode_map): +def get_cached_ascii( + ascii_str, _uni_to_utf8=_unicode_to_utf8_map, _utf8_to_uni=_utf8_to_unicode_map +): """This is a string which is identical in utf-8 and unicode.""" # We don't need to do any encoding, but we want _utf8_to_uni to return a # real Unicode string. Unicode and plain strings of this type will have the # same hash, so we can just use it as the key in _uni_to_utf8, but we need # the return value to be different in _utf8_to_uni - uni_str = ascii_str.decode('ascii') + uni_str = ascii_str.decode("ascii") ascii_str = _uni_to_utf8.setdefault(uni_str, ascii_str) _utf8_to_uni.setdefault(ascii_str, uni_str) return ascii_str diff --git a/breezy/cethread.py b/breezy/cethread.py index 5dadc749fa..f3233b5f34 100644 --- a/breezy/cethread.py +++ b/breezy/cethread.py @@ -34,7 +34,7 @@ def __init__(self, *args, **kwargs): # blocked. The main example is a calling thread that want to wait for # the called thread to be in a given state before continuing. 
try: - sync_event = kwargs.pop('sync_event') + sync_event = kwargs.pop("sync_event") except KeyError: # If the caller didn't pass a specific event, create our own sync_event = threading.Event() @@ -98,7 +98,12 @@ def switch_and_set(self, new): finally: self.lock.release() - def set_ignored_exceptions(self, ignored: Union[Callable[[Exception], bool], None, List[Type[Exception]], Type[Exception]]): + def set_ignored_exceptions( + self, + ignored: Union[ + Callable[[Exception], bool], None, List[Type[Exception]], Type[Exception] + ], + ): """Declare which exceptions will be ignored. :param ignored: Can be either: @@ -141,8 +146,9 @@ def join(self, timeout=None): if self.exception is not None: exc_class, exc_value, exc_tb = self.exception self.exception = None # The exception should be raised only once - if (self.ignored_exceptions is None - or not self.ignored_exceptions(exc_value)): + if self.ignored_exceptions is None or not self.ignored_exceptions( + exc_value + ): # Raise non ignored exceptions raise exc_value diff --git a/breezy/check.py b/breezy/check.py index 5665d5d2ed..280716abb1 100644 --- a/breezy/check.py +++ b/breezy/check.py @@ -93,8 +93,12 @@ def check_dwim(path, verbose, do_branch=False, do_repo=False, do_tree=False): an exception raised at the end of the process. """ try: - base_tree, branch, repo, relpath = \ - ControlDir.open_containing_tree_branch_or_repository(path) + ( + base_tree, + branch, + repo, + relpath, + ) = ControlDir.open_containing_tree_branch_or_repository(path) except errors.NotBranchError: base_tree = branch = repo = None @@ -133,10 +137,8 @@ def check_dwim(path, verbose, do_branch=False, do_repo=False, do_tree=False): note(gettext("No working tree found at specified location.")) if do_repo or do_branch or do_tree: if do_repo: - note(gettext("Checking repository at '%s'.") - % (repo.user_url,)) - result = repo.check(None, callback_refs=needed_refs, - check_repo=do_repo) + note(gettext("Checking repository at '%s'.") % (repo.user_url,)) + result = repo.check(None, callback_refs=needed_refs, check_repo=do_repo) result.report_results(verbose) else: if do_tree: diff --git a/breezy/chunk_writer.py b/breezy/chunk_writer.py index 8ddabd4381..0d75006431 100644 --- a/breezy/chunk_writer.py +++ b/breezy/chunk_writer.py @@ -94,7 +94,9 @@ class ChunkWriter: _repack_opts_for_speed = (0, 8) _repack_opts_for_size = (20, 0) - def __init__(self, chunk_size: int, reserved: int = 0, optimize_for_size: bool = False) -> None: + def __init__( + self, chunk_size: int, reserved: int = 0, optimize_for_size: bool = False + ) -> None: """Create a ChunkWriter to write chunk_size chunks. 
:param chunk_size: The total byte count to emit at the end of the @@ -139,9 +141,10 @@ def finish(self) -> Tuple[List[bytes], Optional[bytes], int]: self.bytes_out_len += len(out) if self.bytes_out_len > self.chunk_size: - raise AssertionError('Somehow we ended up with too much' - ' compressed data, %d > %d' - % (self.bytes_out_len, self.chunk_size)) + raise AssertionError( + "Somehow we ended up with too much" + " compressed data, %d > %d" % (self.bytes_out_len, self.chunk_size) + ) nulls_needed = self.chunk_size - self.bytes_out_len if nulls_needed: self.bytes_list.append(b"\x00" * nulls_needed) @@ -160,7 +163,9 @@ def set_optimize(self, for_size: bool = True) -> None: opts = ChunkWriter._repack_opts_for_speed self._max_repack, self._max_zsync = opts - def _recompress_all_bytes_in(self, extra_bytes: Optional[bytes] = None) -> Tuple[List[bytes], int, "zlib._Compress"]: + def _recompress_all_bytes_in( + self, extra_bytes: Optional[bytes] = None + ) -> Tuple[List[bytes], int, "zlib._Compress"]: """Recompress the current bytes_in, and optionally more. :param extra_bytes: Optional, if supplied we will add it with @@ -210,7 +215,7 @@ def write(self, bytes: bytes, reserved: bool = False) -> bool: # room to spare, assuming no compression. next_unflushed = self.unflushed_in_bytes + len(bytes) remaining_capacity = capacity - self.bytes_out_len - 10 - if (next_unflushed < remaining_capacity): + if next_unflushed < remaining_capacity: # looks like it will fit out = comp.compress(bytes) if out: @@ -252,15 +257,13 @@ def write(self, bytes: bytes, reserved: bool = False) -> bool: # We are over budget, try to squeeze this in without any # Z_SYNC_FLUSH calls self.num_repack += 1 - (bytes_out, this_len, - compressor) = self._recompress_all_bytes_in(bytes) + (bytes_out, this_len, compressor) = self._recompress_all_bytes_in(bytes) if self.num_repack >= self._max_repack: # When we get *to* _max_repack, bump over so that the # earlier > _max_repack will be triggered. 
self.num_repack += 1 if this_len + 10 > capacity: - (bytes_out, this_len, - compressor) = self._recompress_all_bytes_in() + (bytes_out, this_len, compressor) = self._recompress_all_bytes_in() self.compressor = compressor # Force us to not allow more data self.num_repack = self._max_repack + 1 diff --git a/breezy/clean_tree.py b/breezy/clean_tree.py index a6203b653c..6f0fdc3f0a 100644 --- a/breezy/clean_tree.py +++ b/breezy/clean_tree.py @@ -26,8 +26,13 @@ def is_detritus(subp): """Return True if the supplied path is detritus, False otherwise.""" - return subp.endswith('.THIS') or subp.endswith('.BASE') or\ - subp.endswith('.OTHER') or subp.endswith('~') or subp.endswith('.tmp') + return ( + subp.endswith(".THIS") + or subp.endswith(".BASE") + or subp.endswith(".OTHER") + or subp.endswith("~") + or subp.endswith(".tmp") + ) def iter_deletables(tree, unknown=False, ignored=False, detritus=False): @@ -44,23 +49,30 @@ def iter_deletables(tree, unknown=False, ignored=False, detritus=False): yield tree.abspath(subp), subp -def clean_tree(directory, unknown=False, ignored=False, detritus=False, - dry_run=False, no_prompt=False): +def clean_tree( + directory, + unknown=False, + ignored=False, + detritus=False, + dry_run=False, + no_prompt=False, +): """Remove files in the specified classes from the tree.""" tree = WorkingTree.open_containing(directory)[0] with tree.lock_read(): - deletables = list(iter_deletables(tree, unknown=unknown, - ignored=ignored, detritus=detritus)) + deletables = list( + iter_deletables(tree, unknown=unknown, ignored=ignored, detritus=detritus) + ) deletables = _filter_out_nested_controldirs(deletables) if len(deletables) == 0: - note(gettext('Nothing to delete.')) + note(gettext("Nothing to delete.")) return 0 if not no_prompt: for _path, subp in deletables: ui.ui_factory.note(subp) - prompt = gettext('Are you sure you wish to delete these') + prompt = gettext("Are you sure you wish to delete these") if not ui.ui_factory.get_boolean(prompt): - ui.ui_factory.note(gettext('Canceled')) + ui.ui_factory.note(gettext("Canceled")) return 0 delete_items(deletables, dry_run=dry_run) @@ -88,13 +100,15 @@ def _filter_out_nested_controldirs(deletables): def delete_items(deletables, dry_run=False): """Delete files in the deletables iterable.""" + def onerror(function, path, excinfo): """Show warning for errors seen by rmtree.""" # Handle only permission error while removing files. # Other errors are re-raised. 
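(Illustrative aside, not part of the patch: the clean_tree.py onerror helper around this hunk tolerates only permission errors raised while deleting and re-raises everything else. A self-contained sketch of the same shutil.rmtree idiom, simplified to check only the exception type and run against a throwaway temporary directory, might look like this.)

import shutil
import sys
import tempfile

def onerror(function, path, excinfo):
    # Swallow permission errors while deleting; re-raise anything else
    # so genuine problems stay visible.
    if not isinstance(excinfo[1], PermissionError):
        raise
    sys.stderr.write("unable to remove %s\n" % path)

tmp = tempfile.mkdtemp()              # throwaway directory for the sketch
shutil.rmtree(tmp, onerror=onerror)   # onerror fires only if deletion fails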
if function is not os.remove or not isinstance(excinfo[1], PermissionError): raise - ui.ui_factory.show_warning(gettext('unable to remove %s') % path) + ui.ui_factory.show_warning(gettext("unable to remove %s") % path) + has_deleted = False for path, subp in deletables: if not has_deleted: @@ -106,12 +120,12 @@ def onerror(function, path, excinfo): else: try: os.unlink(path) - note(' ' + subp) + note(" " + subp) except PermissionError as e: - ui.ui_factory.show_warning(gettext( - 'unable to remove "{0}": {1}.').format( - path, e.strerror)) + ui.ui_factory.show_warning( + gettext('unable to remove "{0}": {1}.').format(path, e.strerror) + ) else: - note(' ' + subp) + note(" " + subp) if not has_deleted: note(gettext("No files deleted.")) diff --git a/breezy/cmd_test_script.py b/breezy/cmd_test_script.py index 3a2a43b6a0..e0d21630ff 100644 --- a/breezy/cmd_test_script.py +++ b/breezy/cmd_test_script.py @@ -29,11 +29,10 @@ class cmd_test_script(commands.Command): """Run a shell-like test from a file.""" hidden = True - takes_args = ['infile'] + takes_args = ["infile"] takes_options = [ - option.Option('null-output', - help='Null command outputs match any output.'), - ] + option.Option("null-output", help="Null command outputs match any output."), + ] @commands.display_command def run(self, infile, null_output=False): @@ -46,15 +45,13 @@ def run(self, infile, null_output=False): script = f.read() class Test(TestCaseWithTransportAndScript): - script = None # Set before running def test_it(self): - self.run_script(script, - null_output_matches_anything=null_output) + self.run_script(script, null_output_matches_anything=null_output) runner = tests.TextTestRunner(stream=self.outf) - test = Test('test_it') + test = Test("test_it") test.path = os.path.realpath(infile) res = runner.run(test) return len(res.errors) + len(res.failures) diff --git a/breezy/cmd_version_info.py b/breezy/cmd_version_info.py index dfb6bee548..20f5a78fb4 100644 --- a/breezy/cmd_version_info.py +++ b/breezy/cmd_version_info.py @@ -18,13 +18,16 @@ from .lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ from breezy import ( branch, workingtree, ) from breezy.i18n import gettext -""") +""", +) from . import errors from .commands import Command @@ -38,13 +41,16 @@ def _parse_version_info_format(format): cannot be found, generates a useful error exception. """ from . import version_info_formats + try: return version_info_formats.get_builder(format) except KeyError as err: formats = version_info_formats.get_builder_formats() raise errors.CommandError( - gettext('No known version info format {0}.' - ' Supported types are: {1}').format(format, formats)) from err + gettext( + "No known version info format {0}." 
" Supported types are: {1}" + ).format(format, formats) + ) from err class cmd_version_info(Command): @@ -71,40 +77,51 @@ class cmd_version_info(Command): otherwise 1 """ - takes_options = [RegistryOption('format', - 'Select the output format.', - value_switches=True, - lazy_registry=('breezy.version_info_formats', - 'format_registry')), - Option('all', help='Include all possible information.'), - Option('check-clean', help='Check if tree is clean.'), - Option('include-history', - help='Include the revision-history.'), - Option('include-file-revisions', - help='Include the last revision for each file.'), - Option('template', type=str, - help='Template for the output.'), - 'revision', - ] - takes_args = ['location?'] - - encoding_type = 'replace' - - def run(self, location=None, format=None, - all=False, check_clean=False, include_history=False, - include_file_revisions=False, template=None, - revision=None): - + takes_options = [ + RegistryOption( + "format", + "Select the output format.", + value_switches=True, + lazy_registry=("breezy.version_info_formats", "format_registry"), + ), + Option("all", help="Include all possible information."), + Option("check-clean", help="Check if tree is clean."), + Option("include-history", help="Include the revision-history."), + Option( + "include-file-revisions", help="Include the last revision for each file." + ), + Option("template", type=str, help="Template for the output."), + "revision", + ] + takes_args = ["location?"] + + encoding_type = "replace" + + def run( + self, + location=None, + format=None, + all=False, + check_clean=False, + include_history=False, + include_file_revisions=False, + template=None, + revision=None, + ): if revision and len(revision) > 1: raise errors.CommandError( - gettext('brz version-info --revision takes exactly' - ' one revision specifier')) + gettext( + "brz version-info --revision takes exactly" + " one revision specifier" + ) + ) if location is None: - location = '.' + location = "." if format is None: from . 
import version_info_formats + format = version_info_formats.format_registry.get() try: @@ -122,7 +139,7 @@ def run(self, location=None, format=None, if template: include_history = True include_file_revisions = True - if '{clean}' in template: + if "{clean}" in template: check_clean = True if revision is not None: @@ -130,9 +147,13 @@ def run(self, location=None, format=None, else: revision_id = None - builder = format(b, working_tree=wt, - check_for_clean=check_clean, - include_revision_history=include_history, - include_file_revisions=include_file_revisions, - template=template, revision_id=revision_id) + builder = format( + b, + working_tree=wt, + check_for_clean=check_clean, + include_revision_history=include_history, + include_file_revisions=include_file_revisions, + template=template, + revision_id=revision_id, + ) builder.generate(self.outf) diff --git a/breezy/cmdline.py b/breezy/cmdline.py index 4481f5d906..642c499742 100644 --- a/breezy/cmdline.py +++ b/breezy/cmdline.py @@ -23,7 +23,7 @@ import re from typing import List, Optional, Tuple -_whitespace_match = re.compile('\\s', re.UNICODE).match +_whitespace_match = re.compile("\\s", re.UNICODE).match class _PushbackSequence: @@ -56,7 +56,7 @@ def process(self, next_char, context): elif next_char in context.allowed_quote_chars: context.quoted = True return _Quotes(next_char, self) - elif next_char == '\\': + elif next_char == "\\": return _Backslash(self) else: context.token.append(next_char) @@ -69,10 +69,10 @@ def __init__(self, quote_char, exit_state): self.exit_state = exit_state def process(self, next_char, context): - if next_char == '\\': + if next_char == "\\": return _Backslash(self) elif next_char == self.quote_char: - context.token.append('') + context.token.append("") return self.exit_state else: context.token.append(next_char) @@ -86,12 +86,12 @@ def __init__(self, exit_state): self.count = 1 def process(self, next_char, context): - if next_char == '\\': + if next_char == "\\": self.count += 1 return self elif next_char in context.allowed_quote_chars: # 2N backslashes followed by a quote are N backslashes - context.token.append('\\' * (self.count // 2)) + context.token.append("\\" * (self.count // 2)) # 2N+1 backslashes follwed by a quote are N backslashes followed by # the quote which should not be processed as the start or end of # the quoted arg @@ -106,7 +106,7 @@ def process(self, next_char, context): else: # N backslashes not followed by a quote are just N backslashes if self.count > 0: - context.token.append('\\' * self.count) + context.token.append("\\" * self.count) self.count = 0 # let exit_state handle next_char context.seq.pushback(next_char) @@ -114,7 +114,7 @@ def process(self, next_char, context): def finish(self, context): if self.count > 0: - context.token.append('\\' * self.count) + context.token.append("\\" * self.count) class _Word: @@ -123,7 +123,7 @@ def process(self, next_char, context): return None elif next_char in context.allowed_quote_chars: return _Quotes(next_char, self) - elif next_char == '\\': + elif next_char == "\\": return _Backslash(self) else: context.token.append(next_char) @@ -156,10 +156,10 @@ def _get_token(self) -> Tuple[bool, Optional[str]]: state = state.process(next_char, self) if state is None: break - if state is not None and hasattr(state, 'finish'): + if state is not None and hasattr(state, "finish"): state.finish(self) - result: Optional[str] = ''.join(self.token) - if not self.quoted and result == '': + result: Optional[str] = "".join(self.token) + if not self.quoted 
and result == "": result = None return self.quoted, result diff --git a/breezy/colordiff.py b/breezy/colordiff.py index 74ebfd66a8..0d8dfd2db8 100644 --- a/breezy/colordiff.py +++ b/breezy/colordiff.py @@ -30,10 +30,10 @@ hunk_from_header, ) -GLOBAL_COLORDIFFRC = '/etc/colordiffrc' +GLOBAL_COLORDIFFRC = "/etc/colordiffrc" -class LineParser: +class LineParser: def parse_line(self, line): if line.startswith(b"@"): return hunk_from_header(line) @@ -52,7 +52,7 @@ def read_colordiffrc(path): with open(path) as f: for line in f.readlines(): try: - key, val = line.split('=') + key, val = line.split("=") except ValueError: continue @@ -69,24 +69,23 @@ def read_colordiffrc(path): class DiffWriter: - def __init__(self, target, check_style=False): self.target = target self.lp = LineParser() self.chunks = [] self.colors = { - 'metaline': 'darkyellow', - 'plain': 'darkwhite', - 'newtext': 'darkblue', - 'oldtext': 'darkred', - 'diffstuff': 'darkgreen', - 'trailingspace': 'yellow', - 'leadingtabs': 'magenta', - 'longline': 'cyan', + "metaline": "darkyellow", + "plain": "darkwhite", + "newtext": "darkblue", + "oldtext": "darkred", + "diffstuff": "darkgreen", + "trailingspace": "yellow", + "leadingtabs": "magenta", + "longline": "cyan", } if GLOBAL_COLORDIFFRC is not None: self._read_colordiffrc(GLOBAL_COLORDIFFRC) - self._read_colordiffrc(expanduser('~/.colordiffrc')) + self._read_colordiffrc(expanduser("~/.colordiffrc")) self.added_leading_tabs = 0 self.added_trailing_whitespace = 0 self.spurious_whitespace = 0 @@ -106,17 +105,22 @@ def colorstring(self, type, item, bad_ws_match): color = self.colors[type] if color is not None: if self.check_style and bad_ws_match: - #highlight were needed - item.contents = ''.join( + # highlight were needed + item.contents = "".join( terminal.colorstring(txt, color, bcol) for txt, bcol in ( - (bad_ws_match.group(1).expandtabs(), - self.colors['leadingtabs']), - (bad_ws_match.group(2)[0:self.max_line_len], None), - (bad_ws_match.group(2)[self.max_line_len:], - self.colors['longline']), - (bad_ws_match.group(3), self.colors['trailingspace']) - )) + bad_ws_match.group(4) + ( + bad_ws_match.group(1).expandtabs(), + self.colors["leadingtabs"], + ), + (bad_ws_match.group(2)[0 : self.max_line_len], None), + ( + bad_ws_match.group(2)[self.max_line_len :], + self.colors["longline"], + ), + (bad_ws_match.group(3), self.colors["trailingspace"]), + ) + ) + bad_ws_match.group(4) if not isinstance(item, bytes): item = item.as_bytes() string = terminal.colorstring(item, color) @@ -125,9 +129,9 @@ def colorstring(self, type, item, bad_ws_match): self.target.write(string) def write(self, text): - newstuff = text.split(b'\n') + newstuff = text.split(b"\n") for newchunk in newstuff[:-1]: - self._writeline(b''.join(self.chunks + [newchunk, b'\n'])) + self._writeline(b"".join(self.chunks + [newchunk, b"\n"])) self.chunks = [] self.chunks = [newstuff[-1]] @@ -139,11 +143,10 @@ def _writeline(self, line): item = self.lp.parse_line(line) bad_ws_match = None if isinstance(item, Hunk): - line_class = 'diffstuff' + line_class = "diffstuff" self._analyse_old_new() elif isinstance(item, HunkLine): - bad_ws_match = re.match(br'^([\t]*)(.*?)([\t ]*)(\r?\n)$', - item.contents) + bad_ws_match = re.match(rb"^([\t]*)(.*?)([\t ]*)(\r?\n)$", item.contents) has_leading_tabs = bool(bad_ws_match.group(1)) has_trailing_whitespace = bool(bad_ws_match.group(3)) if isinstance(item, InsertLine): @@ -151,21 +154,22 @@ def _writeline(self, line): self.added_leading_tabs += 1 if has_trailing_whitespace: 
self.added_trailing_whitespace += 1 - if (len(bad_ws_match.group(2)) > self.max_line_len and - not item.contents.startswith(b'++ ')): + if len( + bad_ws_match.group(2) + ) > self.max_line_len and not item.contents.startswith(b"++ "): self.long_lines += 1 - line_class = 'newtext' + line_class = "newtext" self._new_lines.append(item) elif isinstance(item, RemoveLine): - line_class = 'oldtext' + line_class = "oldtext" self._old_lines.append(item) else: - line_class = 'plain' - elif isinstance(item, bytes) and item.startswith(b'==='): - line_class = 'metaline' + line_class = "plain" + elif isinstance(item, bytes) and item.startswith(b"==="): + line_class = "metaline" self._analyse_old_new() else: - line_class = 'plain' + line_class = "plain" self._analyse_old_new() self.colorstring(line_class, item, bad_ws_match) @@ -193,5 +197,5 @@ def _analyse_old_new(self): raise AssertionError if no_ws_matched > ws_matched: self.spurious_whitespace += no_ws_matched - ws_matched - self.target.write('^ Spurious whitespace change above.\n') + self.target.write("^ Spurious whitespace change above.\n") self._old_lines, self._new_lines = ([], []) diff --git a/breezy/commands.py b/breezy/commands.py index 61198cc316..479adfe242 100644 --- a/breezy/commands.py +++ b/breezy/commands.py @@ -31,14 +31,17 @@ from . import i18n, option, trace from .lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ import breezy from breezy import ( cmdline, ui, ) -""") +""", +) from . import debug, errors, registry from .hooks import Hooks @@ -46,21 +49,21 @@ class CommandAvailableInPlugin(Exception): - internal_error = False def __init__(self, cmd_name, plugin_metadata, provider): - self.plugin_metadata = plugin_metadata self.cmd_name = cmd_name self.provider = provider def __str__(self): - - _fmt = ('"{}" is not a standard brz command. \n' - 'However, the following official plugin provides this command: {}\n' - 'You can install it by going to: {}'.format(self.cmd_name, self.plugin_metadata['name'], - self.plugin_metadata['url'])) + _fmt = ( + '"{}" is not a standard brz command. \n' + "However, the following official plugin provides this command: {}\n" + "You can install it by going to: {}".format( + self.cmd_name, self.plugin_metadata["name"], self.plugin_metadata["url"] + ) + ) return _fmt @@ -126,13 +129,16 @@ def register(self, cmd, decorate=False): pass info = CommandInfo.from_command(cmd) try: - registry.Registry.register(self, k_unsquished, cmd, - override_existing=decorate, info=info) + registry.Registry.register( + self, k_unsquished, cmd, override_existing=decorate, info=info + ) except KeyError: - trace.warning(f'Two plugins defined the same command: {k!r}') - trace.warning(f'Not loading the one in {sys.modules[cmd.__module__]!r}') - trace.warning('Previously this command was registered from %r' % - sys.modules[previous.__module__]) + trace.warning(f"Two plugins defined the same command: {k!r}") + trace.warning(f"Not loading the one in {sys.modules[cmd.__module__]!r}") + trace.warning( + "Previously this command was registered from %r" + % sys.modules[previous.__module__] + ) for a in cmd.aliases: self._alias_dict[a] = k_unsquished return previous @@ -146,8 +152,9 @@ def register_lazy(self, command_name, aliases, module_name): module_name: The module that the command lives in. 
""" key = self._get_name(command_name) - registry.Registry.register_lazy(self, key, module_name, command_name, - info=CommandInfo(aliases)) + registry.Registry.register_lazy( + self, key, module_name, command_name, info=CommandInfo(aliases) + ) for a in aliases: self._alias_dict[a] = key @@ -167,11 +174,11 @@ def register_command(cmd, decorate=False): def _squish_command_name(cmd): - return 'cmd_' + cmd.replace('-', '_') + return "cmd_" + cmd.replace("-", "_") def _unsquish_command_name(cmd): - return cmd[4:].replace('_', '-') + return cmd[4:].replace("_", "-") def _register_builtin_commands(): @@ -179,6 +186,7 @@ def _register_builtin_commands(): # only load once return import breezy.builtins + for cmd_class in _scan_module_for_commands(breezy.builtins): builtin_command_registry.register(cmd_class) breezy.builtins._register_lazy_builtins() @@ -206,11 +214,12 @@ def _list_bzr_commands(names): def all_command_names(): """Return a set of all command names.""" names = set() - for hook in Command.hooks['list_commands']: + for hook in Command.hooks["list_commands"]: names = hook(names) if names is None: raise AssertionError( - f'hook {Command.hooks.get_hook_name(hook)} returned None') + f"hook {Command.hooks.get_hook_name(hook)} returned None" + ) return names @@ -231,8 +240,8 @@ def plugin_command_names(): # Overrides for common mispellings that heuristics get wrong _GUESS_OVERRIDES = { - 'ic': {'ci': 0}, # heuristic finds nick - } + "ic": {"ci": 0}, # heuristic finds nick +} def guess_command(cmd_name): @@ -251,18 +260,19 @@ def guess_command(cmd_name): # candidate: modified levenshtein distance against cmd_name. costs = {} import patiencediff + for name in sorted(names): matcher = patiencediff.PatienceSequenceMatcher(None, cmd_name, name) distance = 0.0 opcodes = matcher.get_opcodes() for opcode, l1, l2, r1, r2 in opcodes: - if opcode == 'delete': + if opcode == "delete": distance += l2 - l1 - elif opcode == 'replace': + elif opcode == "replace": distance += max(l2 - l1, r2 - l1) - elif opcode == 'insert': + elif opcode == "insert": distance += r2 - r1 - elif opcode == 'equal': + elif opcode == "equal": # Score equal ranges lower, making similar commands of equal # length closer than arbitrary same length commands. distance -= 0.1 * (l2 - l1) @@ -277,8 +287,7 @@ def guess_command(cmd_name): return candidate -def get_cmd_object( - cmd_name: str, plugins_override: bool = True) -> "Command": +def get_cmd_object(cmd_name: str, plugins_override: bool = True) -> "Command": """Return the command object for a command. plugins_override @@ -292,14 +301,16 @@ def get_cmd_object( if candidate is not None: raise errors.CommandError( i18n.gettext('unknown command "%s". Perhaps you meant "%s"') - % (cmd_name, candidate)) from err - raise errors.CommandError(i18n.gettext('unknown command "%s"') - % cmd_name) from err + % (cmd_name, candidate) + ) from err + raise errors.CommandError( + i18n.gettext('unknown command "%s"') % cmd_name + ) from err def _get_cmd_object( - cmd_name: str, plugins_override: bool = True, - check_missing: bool = True) -> "Command": + cmd_name: str, plugins_override: bool = True, check_missing: bool = True +) -> "Command": """Get a command object. Args: @@ -320,14 +331,14 @@ def _get_cmd_object( # In the future, we may actually support Unicode command names. 
cmd: Optional[Command] = None # Get a command - for hook in Command.hooks['get_command']: + for hook in Command.hooks["get_command"]: cmd = hook(cmd, cmd_name) if cmd is not None and not plugins_override and not cmd.plugin_name(): # We've found a non-plugin command, don't permit it to be # overridden. break if cmd is None and check_missing: - for hook in Command.hooks['get_missing_command']: + for hook in Command.hooks["get_missing_command"]: cmd = hook(cmd_name) if cmd is not None: break @@ -335,9 +346,9 @@ def _get_cmd_object( # No command found. raise KeyError # Allow plugins to extend commands - for hook in Command.hooks['extend_command']: + for hook in Command.hooks["extend_command"]: hook(cmd) - if getattr(cmd, 'invoked_as', None) is None: + if getattr(cmd, "invoked_as", None) is None: cmd.invoked_as = cmd_name return cmd @@ -393,6 +404,7 @@ def _get_external_command(cmd_or_none, cmd_name): if cmd_or_none is not None: return cmd_or_none from .externalcommand import ExternalCommand + cmd_obj = ExternalCommand.find_command(cmd_name) if cmd_obj: return cmd_obj @@ -476,10 +488,11 @@ class Command: class Foo(Command): __doc__ = "My help goes here" """ + aliases: List[str] = [] takes_args: List[str] = [] takes_options: List[Union[str, option.Option]] = [] - encoding_type: str = 'strict' + encoding_type: str = "strict" invoked_as: Optional[str] = None l10n: bool = True _see_also: List[str] @@ -524,21 +537,26 @@ def _usage(self): Only describes arguments, not options. """ - s = 'brz ' + self.name() + ' ' + s = "brz " + self.name() + " " for aname in self.takes_args: aname = aname.upper() - if aname[-1] in ['$', '+']: - aname = aname[:-1] + '...' - elif aname[-1] == '?': - aname = '[' + aname[:-1] + ']' - elif aname[-1] == '*': - aname = '[' + aname[:-1] + '...]' - s += aname + ' ' - s = s[:-1] # remove last space + if aname[-1] in ["$", "+"]: + aname = aname[:-1] + "..." + elif aname[-1] == "?": + aname = "[" + aname[:-1] + "]" + elif aname[-1] == "*": + aname = "[" + aname[:-1] + "...]" + s += aname + " " + s = s[:-1] # remove last space return s - def get_help_text(self, additional_see_also=None, plain=True, - see_also_as_links=False, verbose=True): + def get_help_text( + self, + additional_see_also=None, + plain=True, + see_also_as_links=False, + verbose=True, + ): """Return a text string with help for this command. Args: @@ -570,27 +588,26 @@ def get_help_text(self, additional_see_also=None, plain=True, purpose, sections, order = self._get_help_parts(doc) # If a custom usage section was provided, use it - if 'Usage' in sections: - usage = sections.pop('Usage') + if "Usage" in sections: + usage = sections.pop("Usage") else: usage = self._usage() # The header is the purpose and usage result = "" - result += i18n.gettext(':Purpose: %s\n') % (purpose,) - if usage.find('\n') >= 0: - result += i18n.gettext(':Usage:\n%s\n') % (usage,) + result += i18n.gettext(":Purpose: %s\n") % (purpose,) + if usage.find("\n") >= 0: + result += i18n.gettext(":Usage:\n%s\n") % (usage,) else: - result += i18n.gettext(':Usage: %s\n') % (usage,) - result += '\n' + result += i18n.gettext(":Usage: %s\n") % (usage,) + result += "\n" # Add the options # # XXX: optparse implicitly rewraps the help, and not always perfectly, # so we get . 
-- mbp # 20090319 - parser = option.get_optparser( - [v for k, v in sorted(self.options().items())]) + parser = option.get_optparser([v for k, v in sorted(self.options().items())]) options = parser.format_option_help() # FIXME: According to the spec, ReST option lists actually don't # support options like --1.14 so that causes syntax errors (in Sphinx @@ -598,37 +615,39 @@ def get_help_text(self, additional_see_also=None, plain=True, # break, we trap on that and then format that block of 'format' options # as a literal block. We use the most recent format still listed so we # don't have to do that too often -- vila 20110514 - if not plain and options.find(' --1.14 ') != -1: - options = options.replace(' format:\n', ' format::\n\n', 1) - if options.startswith('Options:'): - result += i18n.gettext(':Options:%s') % (options[len('options:'):],) + if not plain and options.find(" --1.14 ") != -1: + options = options.replace(" format:\n", " format::\n\n", 1) + if options.startswith("Options:"): + result += i18n.gettext(":Options:%s") % (options[len("options:") :],) else: result += options - result += '\n' + result += "\n" if verbose: # Add the description, indenting it 2 spaces # to match the indentation of the options if None in sections: text = sections.pop(None) - text = '\n '.join(text.splitlines()) - result += i18n.gettext(':Description:\n %s\n\n') % (text,) + text = "\n ".join(text.splitlines()) + result += i18n.gettext(":Description:\n %s\n\n") % (text,) # Add the custom sections (e.g. Examples). Note that there's no need # to indent these as they must be indented already in the source. if sections: for label in order: if label in sections: - result += f':{label}:\n{sections[label]}\n' - result += '\n' + result += f":{label}:\n{sections[label]}\n" + result += "\n" else: - result += (i18n.gettext("See brz help %s for more details and examples.\n\n") - % self.name()) + result += ( + i18n.gettext("See brz help %s for more details and examples.\n\n") + % self.name() + ) # Add the aliases, source (plug-in) and see also links, if any if self.aliases: - result += i18n.gettext(':Aliases: ') - result += ', '.join(self.aliases) + '\n' + result += i18n.gettext(":Aliases: ") + result += ", ".join(self.aliases) + "\n" plugin_name = self.plugin_name() if plugin_name is not None: result += i18n.gettext(':From: plugin "%s"\n') % plugin_name @@ -637,21 +656,23 @@ def get_help_text(self, additional_see_also=None, plain=True, if not plain and see_also_as_links: see_also_links = [] for item in see_also: - if item == 'topics': + if item == "topics": # topics doesn't have an independent section # so don't create a real link see_also_links.append(item) else: # Use a Sphinx link for this entry link_text = i18n.gettext(":doc:`{0} <{1}-help>`").format( - item, item) + item, item + ) see_also_links.append(link_text) see_also = see_also_links - result += i18n.gettext(':See also: %s') % ', '.join(see_also) + '\n' + result += i18n.gettext(":See also: %s") % ", ".join(see_also) + "\n" # If this will be rendered as plain text, convert it if plain: import breezy.help_topics + result = breezy.help_topics.help_as_plain_text(result) return result @@ -667,10 +688,11 @@ def _get_help_parts(text): All text found outside a named section is assigned to the default section which is given the key of None. 
""" + def save_section(sections, order, label, section): if len(section) > 0: if label in sections: - sections[label] += '\n' + section + sections[label] += "\n" + section else: order.append(label) sections[label] = section @@ -679,18 +701,17 @@ def save_section(sections, order, label, section): summary = lines.pop(0) sections = {} order = [] - label, section = None, '' + label, section = None, "" for line in lines: - if line.startswith(':') and line.endswith(':') and len(line) > 2: + if line.startswith(":") and line.endswith(":") and len(line) > 2: save_section(sections, order, label, section) - label, section = line[1:-1], '' - elif (label is not None and len(line) > 1 and - not line[0].isspace()): + label, section = line[1:-1], "" + elif label is not None and len(line) > 1 and not line[0].isspace(): save_section(sections, order, label, section) label, section = None, line else: if len(section) > 0: - section += '\n' + line + section += "\n" + line else: section = line save_section(sections, order, label, section) @@ -712,7 +733,7 @@ def get_see_also(self, additional_terms=None): Returns: A list of help topics. """ - see_also = set(getattr(self, '_see_also', [])) + see_also = set(getattr(self, "_see_also", [])) if additional_terms: see_also.update(additional_terms) return sorted(see_also) @@ -734,8 +755,7 @@ def options(self): def _setup_outf(self): """Return a file linked to stdout, which has proper encoding.""" - self.outf = ui.ui_factory.make_output_stream( - encoding_type=self.encoding_type) + self.outf = ui.ui_factory.make_output_stream(encoding_type=self.encoding_type) def run_argv_aliases(self, argv, alias_argv=None): """Parse the command line and run with extra aliases in alias_argv.""" @@ -743,26 +763,26 @@ def run_argv_aliases(self, argv, alias_argv=None): self._setup_outf() # Process the standard options - if 'help' in opts: # e.g. brz add --help + if "help" in opts: # e.g. brz add --help self.outf.write(self.get_help_text()) return 0 - if 'usage' in opts: # e.g. brz add --usage + if "usage" in opts: # e.g. brz add --usage self.outf.write(self.get_help_text(verbose=False)) return 0 trace.set_verbosity_level(option._verbosity_level) - if 'verbose' in self.supported_std_options: - opts['verbose'] = trace.is_verbose() - elif 'verbose' in opts: - del opts['verbose'] - if 'quiet' in self.supported_std_options: - opts['quiet'] = trace.is_quiet() - elif 'quiet' in opts: - del opts['quiet'] + if "verbose" in self.supported_std_options: + opts["verbose"] = trace.is_verbose() + elif "verbose" in opts: + del opts["verbose"] + if "quiet" in self.supported_std_options: + opts["quiet"] = trace.is_quiet() + elif "quiet" in opts: + del opts["quiet"] # mix arguments and options into one dictionary cmdargs = _match_argform(self.name(), self.takes_args, args) cmdopts = {} for k, v in opts.items(): - cmdopts[k.replace('-', '_')] = v + cmdopts[k.replace("-", "_")] = v all_cmd_args = cmdargs.copy() all_cmd_args.update(cmdopts) @@ -774,7 +794,8 @@ def run_argv_aliases(self, argv, alias_argv=None): # inherit state. Before we reset it, log any activity, so that it # gets properly tracked. 
ui.ui_factory.log_transport_activity( - display=(debug.debug_flag_enabled('bytes'))) + display=(debug.debug_flag_enabled("bytes")) + ) trace.set_verbosity_level(0) def _setup_run(self): @@ -789,14 +810,15 @@ def _setup_run(self): class_run = self.run def run(*args, **kwargs): - for hook in Command.hooks['pre_command']: + for hook in Command.hooks["pre_command"]: hook(self) try: with contextlib.ExitStack() as self._exit_stack: return class_run(*args, **kwargs) finally: - for hook in Command.hooks['post_command']: + for hook in Command.hooks["post_command"]: hook(self) + self.run = run def run(self): # type: ignore @@ -820,11 +842,12 @@ def run(self): # type: ignore def run(self, files=None): pass """ - raise NotImplementedError(f'no implementation of command {self.name()!r}') + raise NotImplementedError(f"no implementation of command {self.name()!r}") def help(self): """Return help message for this class.""" from inspect import getdoc + if self.__doc__ is Command.__doc__: return None return getdoc(self) @@ -863,40 +886,48 @@ def __init__(self): """ Hooks.__init__(self, "breezy.commands", "Command.hooks") self.add_hook( - 'extend_command', + "extend_command", "Called after creating a command object to allow modifications " "such as adding or removing options, docs etc. Called with the " - "new breezy.commands.Command object.", (1, 13)) + "new breezy.commands.Command object.", + (1, 13), + ) self.add_hook( - 'get_command', + "get_command", "Called when creating a single command. Called with " "(cmd_or_none, command_name). get_command should either return " "the cmd_or_none parameter, or a replacement Command object that " "should be used for the command. Note that the Command.hooks " "hooks are core infrastructure. Many users will prefer to use " "breezy.commands.register_command or plugin_cmds.register_lazy.", - (1, 17)) + (1, 17), + ) self.add_hook( - 'get_missing_command', + "get_missing_command", "Called when creating a single command if no command could be " "found. Called with (command_name). get_missing_command should " "either return None, or a Command object to be used for the " - "command.", (1, 17)) + "command.", + (1, 17), + ) self.add_hook( - 'list_commands', + "list_commands", "Called when enumerating commands. Called with a set of " "cmd_name strings for all the commands found so far. This set " " is safe to mutate - e.g. to remove a command. " "list_commands should return the updated set of command names.", - (1, 17)) + (1, 17), + ) self.add_hook( - 'pre_command', - "Called prior to executing a command. Called with the command " - "object.", (2, 6)) + "pre_command", + "Called prior to executing a command. Called with the command " "object.", + (2, 6), + ) self.add_hook( - 'post_command', - "Called after executing a command. Called with the command " - "object.", (2, 6)) + "post_command", + "Called after executing a command. Called with the command " "object.", + (2, 6), + ) Command.hooks = CommandHooks() # type: ignore @@ -911,8 +942,7 @@ def parse_args(command, argv, alias_argv=None): they take, and which commands will accept them. """ # TODO: make it a method of the Command? 
- parser = option.get_optparser( - [v for k, v in sorted(command.options().items())]) + parser = option.get_optparser([v for k, v in sorted(command.options().items())]) if alias_argv is not None: args = alias_argv + argv else: @@ -924,10 +954,14 @@ def parse_args(command, argv, alias_argv=None): options, args = parser.parse_args(args) except UnicodeEncodeError as err: raise errors.CommandError( - i18n.gettext('Only ASCII permitted in option names')) from err + i18n.gettext("Only ASCII permitted in option names") + ) from err - opts = {k: v for k, v in options.__dict__.items() if - v is not option.OptionParser.DEFAULT_VALUE} + opts = { + k: v + for k, v in options.__dict__.items() + if v is not option.OptionParser.DEFAULT_VALUE + } return args, opts @@ -937,29 +971,33 @@ def _match_argform(cmd, takes_args, args): # step through args and takes_args, allowing appropriate 0-many matches for ap in takes_args: argname = ap[:-1] - if ap[-1] == '?': + if ap[-1] == "?": if args: argdict[argname] = args.pop(0) - elif ap[-1] == '*': # all remaining arguments + elif ap[-1] == "*": # all remaining arguments if args: - argdict[argname + '_list'] = args[:] + argdict[argname + "_list"] = args[:] args = [] else: - argdict[argname + '_list'] = None - elif ap[-1] == '+': + argdict[argname + "_list"] = None + elif ap[-1] == "+": if not args: - raise errors.CommandError(i18n.gettext( - "command {0!r} needs one or more {1}").format( - cmd, argname.upper())) + raise errors.CommandError( + i18n.gettext("command {0!r} needs one or more {1}").format( + cmd, argname.upper() + ) + ) else: - argdict[argname + '_list'] = args[:] + argdict[argname + "_list"] = args[:] args = [] - elif ap[-1] == '$': # all but one + elif ap[-1] == "$": # all but one if len(args) < 2: raise errors.CommandError( i18n.gettext("command {0!r} needs one or more {1}").format( - cmd, argname.upper())) - argdict[argname + '_list'] = args[:-1] + cmd, argname.upper() + ) + ) + argdict[argname + "_list"] = args[:-1] args[:-1] = [] else: # just a plain arg @@ -967,23 +1005,26 @@ def _match_argform(cmd, takes_args, args): if not args: raise errors.CommandError( i18n.gettext("command {0!r} requires argument {1}").format( - cmd, argname.upper())) + cmd, argname.upper() + ) + ) else: argdict[argname] = args.pop(0) if args: - raise errors.CommandError(i18n.gettext( - "extra argument to command {0}: {1}").format( - cmd, args[0])) + raise errors.CommandError( + i18n.gettext("extra argument to command {0}: {1}").format(cmd, args[0]) + ) return argdict def apply_coveraged(the_callable, *args, **kwargs): import coverage + cov = coverage.Coverage() config_file = cov.config.config_file - os.environ['COVERAGE_PROCESS_START'] = config_file + os.environ["COVERAGE_PROCESS_START"] = config_file cov.start() try: return exception_to_return_code(the_callable, *args, **kwargs) @@ -997,17 +1038,20 @@ def apply_profiled(the_callable, *args, **kwargs): import hotshot import hotshot.stats + pffileno, pfname = tempfile.mkstemp() try: prof = hotshot.Profile(pfname) try: - ret = prof.runcall(exception_to_return_code, the_callable, *args, - **kwargs) or 0 + ret = ( + prof.runcall(exception_to_return_code, the_callable, *args, **kwargs) + or 0 + ) finally: prof.close() stats = hotshot.stats.load(pfname) stats.strip_dirs() - stats.sort_stats('cum') # 'time' + stats.sort_stats("cum") # 'time' # XXX: Might like to write to stderr or the trace file instead but # print_stats seems hardcoded to stdout stats.print_stats(20) @@ -1031,17 +1075,18 @@ def 
exception_to_return_code(the_callable, *args, **kwargs): # specially here, but hopefully they're handled ok by the logger now exc_info = sys.exc_info() exitcode = trace.report_exception(exc_info, sys.stderr) - if os.environ.get('BRZ_PDB'): - print('**** entering debugger') + if os.environ.get("BRZ_PDB"): + print("**** entering debugger") import pdb + pdb.post_mortem(exc_info[2]) return exitcode def apply_lsprofiled(filename, the_callable, *args, **kwargs): from .lsprof import profile - ret, stats = profile(exception_to_return_code, the_callable, - *args, **kwargs) + + ret, stats = profile(exception_to_return_code, the_callable, *args, **kwargs) stats.sort() if filename is None: stats.pprint() @@ -1063,6 +1108,7 @@ def get_alias(cmd, config=None): """ if config is None: import breezy.config + config = breezy.config.GlobalConfig() alias = config.get_alias(cmd) if alias: @@ -1119,8 +1165,11 @@ def run_bzr(argv, load_plugins=load_plugins, disable_plugins=disable_plugins): argv = _specified_or_unicode_argv(argv) trace.mutter("brz arguments: %r", argv) - opt_lsprof = opt_profile = opt_no_plugins = opt_builtin = \ - opt_coverage = opt_no_l10n = opt_no_aliases = False + opt_lsprof = ( + opt_profile + ) = ( + opt_no_plugins + ) = opt_builtin = opt_coverage = opt_no_l10n = opt_no_aliases = False opt_lsprof_file = None # --no-plugins is handled specially at a very early stage. We need @@ -1132,32 +1181,32 @@ def run_bzr(argv, load_plugins=load_plugins, disable_plugins=disable_plugins): override_config = [] while i < len(argv): a = argv[i] - if a == '--profile': + if a == "--profile": opt_profile = True - elif a == '--lsprof': + elif a == "--lsprof": opt_lsprof = True - elif a == '--lsprof-file': + elif a == "--lsprof-file": opt_lsprof = True opt_lsprof_file = argv[i + 1] i += 1 - elif a == '--no-plugins': + elif a == "--no-plugins": opt_no_plugins = True - elif a == '--no-aliases': + elif a == "--no-aliases": opt_no_aliases = True - elif a == '--no-l10n': + elif a == "--no-l10n": opt_no_l10n = True - elif a == '--builtin': + elif a == "--builtin": opt_builtin = True - elif a == '--concurrency': - os.environ['BRZ_CONCURRENCY'] = argv[i + 1] + elif a == "--concurrency": + os.environ["BRZ_CONCURRENCY"] = argv[i + 1] i += 1 - elif a == '--coverage': + elif a == "--coverage": opt_coverage = True - elif a == '--profile-imports': + elif a == "--profile-imports": pass # already handled in startup script Bug #588277 - elif a.startswith('-D'): + elif a.startswith("-D"): debug.set_debug_flag(a[2:]) - elif a.startswith('-O'): + elif a.startswith("-O"): override_config.append(a[2:]) else: argv_copy.append(a) @@ -1170,19 +1219,20 @@ def run_bzr(argv, load_plugins=load_plugins, disable_plugins=disable_plugins): if not opt_no_plugins: from breezy import config + c = config.GlobalConfig() - warn_load_problems = not c.suppress_warning('plugin_load_failure') + warn_load_problems = not c.suppress_warning("plugin_load_failure") load_plugins(warn_load_problems=warn_load_problems) else: disable_plugins() argv = argv_copy - if (not argv): - get_cmd_object('help').run_argv_aliases([]) + if not argv: + get_cmd_object("help").run_argv_aliases([]) return 0 - if argv[0] == '--version': - get_cmd_object('version').run_argv_aliases([]) + if argv[0] == "--version": + get_cmd_object("version").run_argv_aliases([]) return 0 alias_argv = None @@ -1206,13 +1256,11 @@ def run_bzr(argv, load_plugins=load_plugins, disable_plugins=disable_plugins): option._verbosity_level = 0 if opt_lsprof: if opt_coverage: - trace.warning( - '--coverage 
ignored, because --lsprof is in use.') + trace.warning("--coverage ignored, because --lsprof is in use.") ret = apply_lsprofiled(opt_lsprof_file, run, *run_argv) elif opt_profile: if opt_coverage: - trace.warning( - '--coverage ignored, because --profile is in use.') + trace.warning("--coverage ignored, because --profile is in use.") ret = apply_profiled(run, *run_argv) elif opt_coverage: ret = apply_coveraged(run, *run_argv) @@ -1223,8 +1271,8 @@ def run_bzr(argv, load_plugins=load_plugins, disable_plugins=disable_plugins): # reset, in case we may do other commands later within the same # process. Commands that want to execute sub-commands must propagate # --verbose in their own way. - if debug.debug_flag_enabled('memory'): - trace.debug_memory('Process status after command:', short=False) + if debug.debug_flag_enabled("memory"): + trace.debug_memory("Process status after command:", short=False) option._verbosity_level = saved_verbosity_level # Reset the overrides cmdline_overrides._reset() @@ -1232,6 +1280,7 @@ def run_bzr(argv, load_plugins=load_plugins, disable_plugins=disable_plugins): def display_command(func): """Decorator that suppresses pipe/interrupt errors.""" + def ignore_pipe(*args, **kwargs): try: result = func(*args, **kwargs) @@ -1239,15 +1288,17 @@ def ignore_pipe(*args, **kwargs): return result except OSError as e: import errno - if getattr(e, 'errno', None) is None: + + if getattr(e, "errno", None) is None: raise if e.errno != errno.EPIPE: # Win32 raises IOError with errno=0 on a broken pipe - if sys.platform != 'win32' or (e.errno not in (0, errno.EINVAL)): + if sys.platform != "win32" or (e.errno not in (0, errno.EINVAL)): raise pass except KeyboardInterrupt: pass + return ignore_pipe @@ -1255,17 +1306,19 @@ def install_bzr_command_hooks(): """Install the hooks to supply bzr's own commands.""" if _list_bzr_commands in Command.hooks["list_commands"]: return - Command.hooks.install_named_hook("list_commands", _list_bzr_commands, - "bzr commands") - Command.hooks.install_named_hook("get_command", _get_bzr_command, - "bzr commands") - Command.hooks.install_named_hook("get_command", _get_plugin_command, - "bzr plugin commands") - Command.hooks.install_named_hook("get_command", _get_external_command, - "bzr external command lookup") - Command.hooks.install_named_hook("get_missing_command", - _try_plugin_provider, - "bzr plugin-provider-db check") + Command.hooks.install_named_hook( + "list_commands", _list_bzr_commands, "bzr commands" + ) + Command.hooks.install_named_hook("get_command", _get_bzr_command, "bzr commands") + Command.hooks.install_named_hook( + "get_command", _get_plugin_command, "bzr plugin commands" + ) + Command.hooks.install_named_hook( + "get_command", _get_external_command, "bzr external command lookup" + ) + Command.hooks.install_named_hook( + "get_missing_command", _try_plugin_provider, "bzr plugin-provider-db check" + ) def _specified_or_unicode_argv(argv): @@ -1278,7 +1331,7 @@ def _specified_or_unicode_argv(argv): # ensure all arguments are unicode strings for a in argv: if not isinstance(a, str): - raise ValueError(f'not native str or unicode: {a!r}') + raise ValueError(f"not native str or unicode: {a!r}") new_argv.append(a) except (ValueError, UnicodeDecodeError) as err: raise errors.BzrError("argv should be list of unicode strings.") from err @@ -1330,8 +1383,7 @@ def run_bzr_catch_user_errors(argv): try: return run_bzr(argv) except Exception as e: - if (isinstance(e, (OSError, IOError)) - or not getattr(e, 'internal_error', True)): + if 
isinstance(e, (OSError, IOError)) or not getattr(e, "internal_error", True): trace.report_exception(sys.exc_info(), sys.stderr) return 3 else: @@ -1342,7 +1394,7 @@ class HelpCommandIndex: """A index for bzr help that returns commands.""" def __init__(self): - self.prefix = 'commands/' + self.prefix = "commands/" def get_topics(self, topic): """Search for topic amongst commands. @@ -1355,7 +1407,7 @@ def get_topics(self, topic): Command entry. """ if topic and topic.startswith(self.prefix): - topic = topic[len(self.prefix):] + topic = topic[len(self.prefix) :] try: cmd = _get_cmd_object(topic, check_missing=False) except KeyError: diff --git a/breezy/commit.py b/breezy/commit.py index 21f04c0f2c..a50aa8775c 100644 --- a/breezy/commit.py +++ b/breezy/commit.py @@ -62,17 +62,14 @@ class PointlessCommit(BzrError): - _fmt = "No changes to commit" class CannotCommitSelectedFileMerge(BzrError): - - _fmt = 'Selected-file commit of merges is not supported yet:'\ - ' files %(files_str)s' + _fmt = "Selected-file commit of merges is not supported yet:" " files %(files_str)s" def __init__(self, files): - files_str = ', '.join(files) + files_str = ", ".join(files) BzrError.__init__(self, files=files, files_str=files_str) @@ -84,11 +81,13 @@ def filter_excluded(iter_changes, exclude): :return: iter_changes function """ for change in iter_changes: - new_excluded = (change.path[1] is not None and - is_inside_any(exclude, change.path[1])) + new_excluded = change.path[1] is not None and is_inside_any( + exclude, change.path[1] + ) - old_excluded = (change.path[0] is not None and - is_inside_any(exclude, change.path[0])) + old_excluded = change.path[0] is not None and is_inside_any( + exclude, change.path[0] + ) if old_excluded and new_excluded: continue @@ -126,7 +125,6 @@ def is_verbose(self): class ReportCommitToLog(NullCommitReporter): - def _note(self, format, *args): """Output a message. @@ -135,35 +133,35 @@ def _note(self, format, *args): note(format, *args) def snapshot_change(self, change, path): - if path == '' and change in (gettext('added'), gettext('modified')): + if path == "" and change in (gettext("added"), gettext("modified")): return self._note("%s %s", change, path) def started(self, revno, rev_id, location): self._note( - gettext('Committing to: %s'), - unescape_for_display(location, 'utf-8')) + gettext("Committing to: %s"), unescape_for_display(location, "utf-8") + ) def completed(self, revno, rev_id): if revno is not None: - self._note(gettext('Committed revision %d.'), revno) + self._note(gettext("Committed revision %d."), revno) # self._note goes to the console too; so while we want to log the # rev_id, we can't trivially only log it. (See bug 526425). Long # term we should rearrange the reporting structure, but for now # we just mutter seperately. We mutter the revid and revno together # so that concurrent bzr invocations won't lead to confusion. 
- mutter('Committed revid %s as revno %d.', rev_id, revno) + mutter("Committed revid %s as revno %d.", rev_id, revno) else: - self._note(gettext('Committed revid %s.'), rev_id) + self._note(gettext("Committed revid %s."), rev_id) def deleted(self, path): - self._note(gettext('deleted %s'), path) + self._note(gettext("deleted %s"), path) def missing(self, path): - self._note(gettext('missing %s'), path) + self._note(gettext("missing %s"), path) def renamed(self, change, old_path, new_path): - self._note('%s %s => %s', change, old_path, new_path) + self._note("%s %s => %s", change, old_path, new_path) def is_verbose(self): return True @@ -182,9 +180,7 @@ class Commit: working inventory. """ - def __init__(self, - reporter=None, - config_stack=None): + def __init__(self, reporter=None, config_stack=None): """Create a Commit object. :param reporter: the default reporter to use or None to decide later @@ -193,49 +189,55 @@ def __init__(self, self.config_stack = config_stack @staticmethod - def update_revprops(revprops, branch, authors=None, - local=False, possible_master_transports=None): + def update_revprops( + revprops, branch, authors=None, local=False, possible_master_transports=None + ): if revprops is None: revprops = {} if possible_master_transports is None: possible_master_transports = [] - if ('branch-nick' not in revprops and - branch.repository._format.supports_storing_branch_nick): - revprops['branch-nick'] = branch._get_nick( - local, - possible_master_transports) + if ( + "branch-nick" not in revprops + and branch.repository._format.supports_storing_branch_nick + ): + revprops["branch-nick"] = branch._get_nick( + local, possible_master_transports + ) if authors is not None: - if 'author' in revprops or 'authors' in revprops: + if "author" in revprops or "authors" in revprops: # XXX: maybe we should just accept one of them? - raise AssertionError('author property given twice') + raise AssertionError("author property given twice") if authors: for individual in authors: - if '\n' in individual: - raise AssertionError('\\n is not a valid character ' - 'in an author identity') - revprops['authors'] = '\n'.join(authors) + if "\n" in individual: + raise AssertionError( + "\\n is not a valid character " "in an author identity" + ) + revprops["authors"] = "\n".join(authors) return revprops - def commit(self, - message=None, - timestamp=None, - timezone=None, - committer=None, - specific_files=None, - rev_id=None, - allow_pointless=True, - strict=False, - verbose=False, - revprops=None, - working_tree=None, - local=False, - reporter=None, - config=None, - message_callback=None, - recursive='down', - exclude=None, - possible_master_transports=None, - lossy=False): + def commit( + self, + message=None, + timestamp=None, + timezone=None, + committer=None, + specific_files=None, + rev_id=None, + allow_pointless=True, + strict=False, + verbose=False, + revprops=None, + working_tree=None, + local=False, + reporter=None, + config=None, + message_callback=None, + recursive="down", + exclude=None, + possible_master_transports=None, + lossy=False, + ): """Commit working copy as a new revision. :param message: the commit message (it or message_callback is required) @@ -276,14 +278,14 @@ def commit(self, self.revprops = revprops or {} # XXX: Can be set on __init__ or passed in - this is a bit ugly. 
self.config_stack = config or self.config_stack - mutter('preparing to commit') + mutter("preparing to commit") if working_tree is None: raise BzrError("working_tree must be passed into commit().") else: self.work_tree = working_tree self.branch = self.work_tree.branch - if getattr(self.work_tree, 'requires_rich_root', lambda: False)(): + if getattr(self.work_tree, "requires_rich_root", lambda: False)(): if not self.branch.repository.supports_rich_root(): raise errors.RootNotRich() if message_callback is None: @@ -294,14 +296,15 @@ def commit(self, def message_callback(x): return message else: - raise BzrError("The message or message_callback keyword" - " parameter is required for commit().") + raise BzrError( + "The message or message_callback keyword" + " parameter is required for commit()." + ) self.bound_branch = None self.any_entries_deleted = False if exclude is not None: - self.exclude = sorted( - minimum_path_selection(exclude)) + self.exclude = sorted(minimum_path_selection(exclude)) else: self.exclude = [] self.local = local @@ -311,8 +314,7 @@ def message_callback(x): # self.specific_files is None to indicate no filter, or any iterable to # indicate a filter - [] means no files at all, as per iter_changes. if specific_files is not None: - self.specific_files = sorted( - minimum_path_selection(specific_files)) + self.specific_files = sorted(minimum_path_selection(specific_files)) else: self.specific_files = None @@ -378,14 +380,22 @@ def message_callback(x): self._set_progress_stage("Collecting changes", counter=True) self._lossy = lossy self.builder = self.branch.get_commit_builder( - self.parents, self.config_stack, timestamp, timezone, committer, - self.revprops, rev_id, lossy=lossy) + self.parents, + self.config_stack, + timestamp, + timezone, + committer, + self.revprops, + rev_id, + lossy=lossy, + ) if self.builder.updates_branch and self.bound_branch: self.builder.abort() raise AssertionError( "bound branches not supported for commit builders " - "that update the branch") + "that update the branch" + ) try: # find the location being committed to @@ -428,7 +438,8 @@ def message_callback(x): self.work_tree.unversion(self.deleted_paths) self._set_progress_stage("Updating the working tree") self.work_tree.update_basis_by_delta( - self.rev_id, self.builder.get_basis_delta()) + self.rev_id, self.builder.get_basis_delta() + ) self.reporter.completed(new_revno, self.rev_id) self._process_post_hooks(old_revno, new_revno) return self.rev_id @@ -453,8 +464,12 @@ def _update_branches(self, old_revno, old_revid, new_revno): self._set_progress_stage("Uploading data to master branch") # 'commit' to the master first so a timeout here causes the # local branch to be out of date - (new_revno, self.rev_id) = self.master_branch.import_last_revision_info_and_tags( - self.branch, new_revno, self.rev_id, lossy=self._lossy) + ( + new_revno, + self.rev_id, + ) = self.master_branch.import_last_revision_info_and_tags( + self.branch, new_revno, self.rev_id, lossy=self._lossy + ) if self._lossy: self.branch.fetch(self.master_branch, self.rev_id) @@ -476,11 +491,15 @@ def _update_branches(self, old_revno, old_revid, new_revno): if self.bound_branch: self._set_progress_stage("Merging tags to master branch") tag_updates, tag_conflicts = self.branch.tags.merge_to( - self.master_branch.tags) + self.master_branch.tags + ) if tag_conflicts: - warning_lines = [' ' + name for name, _, _ in tag_conflicts] - note(gettext("Conflicting tags in bound branch:\n{}").format( - "\n".join(warning_lines))) + 
warning_lines = [" " + name for name, _, _ in tag_conflicts] + note( + gettext("Conflicting tags in bound branch:\n{}").format( + "\n".join(warning_lines) + ) + ) def _select_reporter(self): """Select the CommitReporter to use.""" @@ -510,7 +529,8 @@ def _check_bound_branch(self, stack, possible_master_transports=None): if not self.local: self.master_branch = self.branch.get_master_branch( - possible_master_transports) + possible_master_transports + ) if not self.master_branch: # make this branch the reference branch for out of date checks. @@ -521,7 +541,8 @@ def _check_bound_branch(self, stack, possible_master_transports=None): master_bound_location = self.master_branch.get_bound_location() if master_bound_location: raise errors.CommitToDoubleBoundBranch( - self.branch, self.master_branch, master_bound_location) + self.branch, self.master_branch, master_bound_location + ) # TODO: jam 20051230 We could automatically push local # commits to the remote branch if they would fit. @@ -532,8 +553,7 @@ def _check_bound_branch(self, stack, possible_master_transports=None): master_revid = self.master_branch.last_revision() local_revid = self.branch.last_revision() if local_revid != master_revid: - raise errors.BoundBranchOutOfDate(self.branch, - self.master_branch) + raise errors.BoundBranchOutOfDate(self.branch, self.master_branch) # Now things are ready to change the master branch # so grab the lock @@ -554,8 +574,9 @@ def _check_out_of_date_tree(self): # - in a checkout scenario the tree may have no # parents but the branch may do. first_tree_parent = breezy.revision.NULL_REVISION - if (self.master_branch._format.stores_revno() or - self.config_stack.get('calculate_revnos')): + if self.master_branch._format.stores_revno() or self.config_stack.get( + "calculate_revnos" + ): try: old_revno, master_last = self.master_branch.last_revision_info() except errors.UnsupportedOperation: @@ -567,8 +588,9 @@ def _check_out_of_date_tree(self): if master_last != first_tree_parent: if master_last != breezy.revision.NULL_REVISION: raise errors.OutOfDateTree(self.work_tree) - if (old_revno is not None and - self.branch.repository.has_revision(first_tree_parent)): + if old_revno is not None and self.branch.repository.has_revision( + first_tree_parent + ): new_revno = old_revno + 1 else: # ghost parents never appear in revision history. @@ -586,15 +608,15 @@ def _process_post_hooks(self, old_revno, new_revno): self._set_progress_stage("Running post_commit hooks") # old style commit hooks - should be deprecated ? 
(obsoleted in # 0.15^H^H^H^H 2.5.0) - post_commit = self.config_stack.get('post_commit') + post_commit = self.config_stack.get("post_commit") if post_commit is not None: - hooks = post_commit.split(' ') + hooks = post_commit.split(" ") # this would be nicer with twisted.python.reflect.namedAny for hook in hooks: - eval(hook + '(branch, rev_id)', # noqa: S307 - {'branch': self.branch, - 'breezy': breezy, - 'rev_id': self.rev_id}) + eval( # noqa: S307 + hook + "(branch, rev_id)", + {"branch": self.branch, "breezy": breezy, "rev_id": self.rev_id}, + ) # process new style post commit hooks self._process_hooks("post_commit", old_revno, new_revno) @@ -619,8 +641,7 @@ def _process_hooks(self, hook_name, old_revno, new_revno): if hook_name == "pre_commit": future_tree = self.builder.revision_tree() - tree_delta = future_tree.changes_from(self.basis_tree, - include_root=True) + tree_delta = future_tree.changes_from(self.basis_tree, include_root=True) for hook in Branch.hooks[hook_name]: # show the running hook in the progress bar. As hooks may @@ -628,17 +649,32 @@ def _process_hooks(self, hook_name, old_revno, new_revno): # the user) this is still showing progress, not showing overall # actions - its up to each plugin to show a UI if it want's to # (such as 'Emailing diff to foo@example.com'). - self.pb_stage_name = f"Running {hook_name} hooks [{Branch.hooks.get_hook_name(hook)}]" + self.pb_stage_name = ( + f"Running {hook_name} hooks [{Branch.hooks.get_hook_name(hook)}]" + ) self._emit_progress() - if debug.debug_flag_enabled('hooks'): + if debug.debug_flag_enabled("hooks"): mutter("Invoking commit hook: %r", hook) if hook_name == "post_commit": - hook(hook_local, hook_master, old_revno, old_revid, new_revno, - self.rev_id) + hook( + hook_local, + hook_master, + old_revno, + old_revid, + new_revno, + self.rev_id, + ) elif hook_name == "pre_commit": - hook(hook_local, hook_master, - old_revno, old_revid, new_revno, self.rev_id, - tree_delta, future_tree) + hook( + hook_local, + hook_master, + old_revno, + old_revid, + new_revno, + self.rev_id, + tree_delta, + future_tree, + ) def _update_builder_with_changes(self): """Update the commit builder with the data about what has changed.""" @@ -647,12 +683,14 @@ def _update_builder_with_changes(self): self._check_strict() iter_changes = self.work_tree.iter_changes( - self.basis_tree, specific_files=specific_files) + self.basis_tree, specific_files=specific_files + ) if self.exclude: iter_changes = filter_excluded(iter_changes, self.exclude) iter_changes = self._filter_iter_changes(iter_changes) for path, fs_hash in self.builder.record_iter_changes( - self.work_tree, self.basis_revid, iter_changes): + self.work_tree, self.basis_revid, iter_changes + ): self.work_tree._observed_sha1(path, fs_hash) def _filter_iter_changes(self, iter_changes): @@ -679,17 +717,22 @@ def _filter_iter_changes(self, iter_changes): # 'missing' path if report_changes: reporter.missing(new_path) - if change.kind[0] == 'symlink' and not self.work_tree.supports_symlinks(): - trace.warning(f'Ignoring "{change.path[0]}" as symlinks are not ' - 'supported on this filesystem.') + if ( + change.kind[0] == "symlink" + and not self.work_tree.supports_symlinks() + ): + trace.warning( + f'Ignoring "{change.path[0]}" as symlinks are not ' + "supported on this filesystem." 
+ ) continue deleted_paths.append(change.path[1]) # Reset the new path (None) and new versioned flag (False) change = change.discard_new() new_path = change.path[1] versioned = False - elif kind == 'tree-reference': - if self.recursive == 'down': + elif kind == "tree-reference": + if self.recursive == "down": self._commit_nested_tree(change.path[1]) if change.versioned[0] or change.versioned[1]: yield change @@ -697,17 +740,17 @@ def _filter_iter_changes(self, iter_changes): if new_path is None: reporter.deleted(old_path) elif old_path is None: - reporter.snapshot_change(gettext('added'), new_path) + reporter.snapshot_change(gettext("added"), new_path) elif old_path != new_path: - reporter.renamed(gettext('renamed'), - old_path, new_path) + reporter.renamed(gettext("renamed"), old_path, new_path) else: - if (new_path - or self.work_tree.branch.repository._format.rich_root_data): + if ( + new_path + or self.work_tree.branch.repository._format.rich_root_data + ): # Don't report on changes to '' in non rich root # repositories. - reporter.snapshot_change( - gettext('modified'), new_path) + reporter.snapshot_change(gettext("modified"), new_path) self._next_progress_entry() # Unversion files that were found to be deleted self.deleted_paths = deleted_paths @@ -731,19 +774,24 @@ def _commit_nested_tree(self, path): # finally implement the explicit-caches approach design # a while back - RBC 20070306. if sub_tree.branch.repository.has_same_location( - self.work_tree.branch.repository): - sub_tree.branch.repository = \ - self.work_tree.branch.repository + self.work_tree.branch.repository + ): + sub_tree.branch.repository = self.work_tree.branch.repository try: - return sub_tree.commit(message=None, revprops=self.revprops, - recursive=self.recursive, - message_callback=self.message_callback, - timestamp=self.timestamp, - timezone=self.timezone, - committer=self.committer, - allow_pointless=self.allow_pointless, - strict=self.strict, verbose=self.verbose, - local=self.local, reporter=self.reporter) + return sub_tree.commit( + message=None, + revprops=self.revprops, + recursive=self.recursive, + message_callback=self.message_callback, + timestamp=self.timestamp, + timezone=self.timezone, + committer=self.committer, + allow_pointless=self.allow_pointless, + strict=self.strict, + verbose=self.verbose, + local=self.local, + reporter=self.reporter, + ) except PointlessCommit: return self.work_tree.get_reference_revision(path) @@ -764,8 +812,9 @@ def _next_progress_entry(self): def _emit_progress(self): if self.pb_entries_count is not None: - text = gettext("{0} [{1}] - Stage").format(self.pb_stage_name, - self.pb_entries_count) + text = gettext("{0} [{1}] - Stage").format( + self.pb_stage_name, self.pb_entries_count + ) else: - text = gettext("%s - Stage") % (self.pb_stage_name, ) + text = gettext("%s - Stage") % (self.pb_stage_name,) self.pb.update(text, self.pb_stage_count, self.pb_stage_total) diff --git a/breezy/commit_signature_commands.py b/breezy/commit_signature_commands.py index a24e74a07b..7a5d0d5ef9 100644 --- a/breezy/commit_signature_commands.py +++ b/breezy/commit_signature_commands.py @@ -37,15 +37,17 @@ class cmd_sign_my_commits(Command): # repository takes_options = [ - Option('dry-run', - help='Don\'t actually sign anything, just print' - ' the revisions that would be signed.'), - ] - takes_args = ['location?', 'committer?'] + Option( + "dry-run", + help="Don't actually sign anything, just print" + " the revisions that would be signed.", + ), + ] + takes_args = ["location?", 
"committer?"] def run(self, location=None, committer=None, dry_run=False): if location is None: - bzrdir = controldir.ControlDir.open_containing('.')[0] + bzrdir = controldir.ControlDir.open_containing(".")[0] else: # Passed in locations should be exact bzrdir = controldir.ControlDir.open(location) @@ -54,15 +56,14 @@ def run(self, location=None, committer=None, dry_run=False): branch_config = branch.get_config_stack() if committer is None: - committer = branch_config.get('email') + committer = branch_config.get("email") gpg_strategy = gpg.GPGStrategy(branch_config) count = 0 with repo.lock_write(): graph = repo.get_graph() with _mod_repository.WriteGroup(repo): - for rev_id, parents in graph.iter_ancestry( - [branch.last_revision()]): + for rev_id, parents in graph.iter_ancestry([branch.last_revision()]): if _mod_revision.is_null(rev_id): continue if parents is None: @@ -80,8 +81,8 @@ def run(self, location=None, committer=None, dry_run=False): if not dry_run: repo.sign_revision(rev_id, gpg_strategy) self.outf.write( - ngettext('Signed %d revision.\n', 'Signed %d revisions.\n', - count) % count) + ngettext("Signed %d revision.\n", "Signed %d revisions.\n", count) % count + ) class cmd_verify_signatures(Command): @@ -91,18 +92,19 @@ class cmd_verify_signatures(Command): """ takes_options = [ - Option('acceptable-keys', - help='Comma separated list of GPG key patterns which are' - ' acceptable for verification.', - short_name='k', - type=str,), - 'revision', - 'verbose', - ] - takes_args = ['location?'] - - def run(self, acceptable_keys=None, revision=None, verbose=None, - location='.'): + Option( + "acceptable-keys", + help="Comma separated list of GPG key patterns which are" + " acceptable for verification.", + short_name="k", + type=str, + ), + "revision", + "verbose", + ] + takes_args = ["location?"] + + def run(self, acceptable_keys=None, revision=None, verbose=None, location="."): bzrdir = controldir.ControlDir.open_containing(location)[0] branch = bzrdir.open_branch() repo = branch.repository @@ -131,16 +133,17 @@ def write_verbose(string): to_revno = branch.revno() if from_revno is None or to_revno is None: raise errors.CommandError( - gettext('Cannot verify a range of non-revision-history' - ' revisions')) + gettext( + "Cannot verify a range of non-revision-history" " revisions" + ) + ) for revno in range(from_revno, to_revno + 1): revisions.append(branch.get_rev_id(revno)) else: # all revisions by default including merges graph = repo.get_graph() revisions = [] - for rev_id, parents in graph.iter_ancestry( - [branch.last_revision()]): + for rev_id, parents in graph.iter_ancestry([branch.last_revision()]): if _mod_revision.is_null(rev_id): continue if parents is None: @@ -148,7 +151,8 @@ def write_verbose(string): continue revisions.append(rev_id) count, result, all_verifiable = gpg.bulk_verify_signatures( - repo, revisions, gpg_strategy) + repo, revisions, gpg_strategy + ) if all_verifiable: write(gettext("All commits signed with verifiable keys")) if verbose: diff --git a/breezy/config.py b/breezy/config.py index 7da96ff795..6f9fc0e58f 100644 --- a/breezy/config.py +++ b/breezy/config.py @@ -87,7 +87,9 @@ from .lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ import re from breezy import ( @@ -100,7 +102,8 @@ win32utils, ) from breezy.i18n import gettext -""") +""", +) from . 
import ( bedding, commands, @@ -132,15 +135,15 @@ _policy_name = { POLICY_NONE: None, - POLICY_NORECURSE: 'norecurse', - POLICY_APPENDPATH: 'appendpath', - } + POLICY_NORECURSE: "norecurse", + POLICY_APPENDPATH: "appendpath", +} _policy_value = { None: POLICY_NONE, - 'none': POLICY_NONE, - 'norecurse': POLICY_NORECURSE, - 'appendpath': POLICY_APPENDPATH, - } + "none": POLICY_NONE, + "norecurse": POLICY_NORECURSE, + "appendpath": POLICY_APPENDPATH, +} STORE_LOCATION = POLICY_NONE @@ -151,16 +154,14 @@ class OptionExpansionLoop(errors.BzrError): - _fmt = 'Loop involving %(refs)r while expanding "%(string)s".' def __init__(self, string, refs): self.string = string - self.refs = '->'.join(refs) + self.refs = "->".join(refs) class ExpandingUnknownOption(errors.BzrError): - _fmt = 'Option "%(name)s" is not defined while expanding "%(string)s".' def __init__(self, name, string): @@ -169,7 +170,6 @@ def __init__(self, name, string): class IllegalOptionName(errors.BzrError): - _fmt = 'Option "%(name)s" is not allowed.' def __init__(self, name): @@ -177,7 +177,6 @@ def __init__(self, name): class ConfigContentError(errors.BzrError): - _fmt = "Config file %(filename)s is not UTF-8 encoded\n" def __init__(self, filename): @@ -185,25 +184,21 @@ def __init__(self, filename): class ParseConfigError(errors.BzrError): - _fmt = "Error(s) parsing config file %(filename)s:\n%(errors)s" def __init__(self, errors, filename): self.filename = filename - self.errors = '\n'.join(e.msg for e in errors) + self.errors = "\n".join(e.msg for e in errors) class ConfigOptionValueError(errors.BzrError): - - _fmt = ('Bad value "%(value)s" for option "%(name)s".\n' - 'See ``brz help %(name)s``') + _fmt = 'Bad value "%(value)s" for option "%(name)s".\n' "See ``brz help %(name)s``" def __init__(self, name, value): errors.BzrError.__init__(self, name=name, value=value) class NoEmailInUsername(errors.BzrError): - _fmt = "%(username)r does not seem to contain a reasonable email address" def __init__(self, username): @@ -211,24 +206,21 @@ def __init__(self, username): class NoSuchConfig(errors.BzrError): - - _fmt = ('The "%(config_id)s" configuration does not exist.') + _fmt = 'The "%(config_id)s" configuration does not exist.' def __init__(self, config_id): errors.BzrError.__init__(self, config_id=config_id) class NoSuchConfigOption(errors.BzrError): - - _fmt = ('The "%(option_name)s" configuration option does not exist.') + _fmt = 'The "%(option_name)s" configuration option does not exist.' def __init__(self, option_name): errors.BzrError.__init__(self, option_name=option_name) class NoSuchAlias(errors.BzrError): - - _fmt = ('The alias "%(alias_name)s" does not exist.') + _fmt = 'The alias "%(alias_name)s" does not exist.' 
def __init__(self, alias_name): errors.BzrError.__init__(self, alias_name=alias_name) @@ -236,24 +228,24 @@ def __init__(self, alias_name): def signature_policy_from_unicode(signature_string): """Convert a string to a signing policy.""" - if signature_string.lower() == 'check-available': + if signature_string.lower() == "check-available": return CHECK_IF_POSSIBLE - if signature_string.lower() == 'ignore': + if signature_string.lower() == "ignore": return CHECK_NEVER - if signature_string.lower() == 'require': + if signature_string.lower() == "require": return CHECK_ALWAYS raise ValueError(f"Invalid signatures policy '{signature_string}'") def signing_policy_from_unicode(signature_string): """Convert a string to a signing policy.""" - if signature_string.lower() == 'when-required': + if signature_string.lower() == "when-required": return SIGN_WHEN_REQUIRED - if signature_string.lower() == 'never': + if signature_string.lower() == "never": return SIGN_NEVER - if signature_string.lower() == 'always': + if signature_string.lower() == "always": return SIGN_ALWAYS - if signature_string.lower() == 'when-possible': + if signature_string.lower() == "when-possible": return SIGN_WHEN_POSSIBLE raise ValueError(f"Invalid signing policy '{signature_string}'") @@ -268,14 +260,12 @@ def _has_triplequote_bug(): class ConfigObj(configobj.ConfigObj): - def __init__(self, infile=None, **kwargs): # We define our own interpolation mechanism calling it option expansion - super().__init__(infile=infile, - interpolation=False, - **kwargs) + super().__init__(infile=infile, interpolation=False, **kwargs) if _has_triplequote_bug(): + def _get_triple_quote(self, value): quot = super()._get_triple_quote(value) if quot == configobj.tdquot: @@ -307,16 +297,16 @@ def config_id(self): def get_change_editor(self, old_tree, new_tree): from breezy import diff + cmd = self._get_change_editor() if cmd is None: return None - cmd = cmd.replace('@old_path', '{old_path}') - cmd = cmd.replace('@new_path', '{new_path}') + cmd = cmd.replace("@old_path", "{old_path}") + cmd = cmd.replace("@new_path", "{new_path}") cmd = cmdline.split(cmd) - if '{old_path}' not in cmd: - cmd.extend(['{old_path}', '{new_path}']) - return diff.DiffFromTool.from_string(cmd, old_tree, new_tree, - sys.stdout) + if "{old_path}" not in cmd: + cmd.extend(["{old_path}", "{new_path}"]) + return diff.DiffFromTool.from_string(cmd, old_tree, new_tree, sys.stdout) def _get_signature_checking(self): """Template method to override signature checking policy.""" @@ -388,7 +378,7 @@ def _expand_options_in_string(self, string, env=None, _ref_stack=None): # We want to match the most embedded reference first (i.e. for # '{{foo}}' we will get '{foo}', # for '{bar{baz}}' we will get '{baz}' - self.option_ref_re = re.compile('({[^{}]+})') + self.option_ref_re = re.compile("({[^{}]+})") result = string # We need to iterate until no more refs appear ({{foo}} will need two # iterations for example). @@ -431,7 +421,7 @@ def _expand_options_in_string(self, string, env=None, _ref_stack=None): # expanded value is a list. 
return self._expand_options_in_list(chunks, env, _ref_stack) else: - result = ''.join(chunks) + result = "".join(chunks) return result def _expand_option(self, name, env, _ref_stack): @@ -471,11 +461,13 @@ def get_user_option(self, option_name, expand=True): if isinstance(value, list): value = self._expand_options_in_list(value) elif isinstance(value, dict): - trace.warning(f'Cannot expand "{option_name}":' - ' Dicts do not support option expansion') + trace.warning( + f'Cannot expand "{option_name}":' + " Dicts do not support option expansion" + ) else: value = self._expand_options_in_string(value) - for hook in OldConfigHooks['get']: + for hook in OldConfigHooks["get"]: hook(self, option_name, value) return value @@ -497,8 +489,7 @@ def get_user_option_as_bool(self, option_name, expand=None, default=None): val = ui.bool_from_string(s) if val is None: # The value can't be interpreted as a boolean - trace.warning('Value "%s" is not a boolean for "%s"', - s, option_name) + trace.warning('Value "%s" is not a boolean for "%s"', s, option_name) return val def get_user_option_as_list(self, option_name, expand=None): @@ -550,7 +541,7 @@ def username(self): $EMAIL is examined. If no username can be found, NoWhoami exception is raised. """ - v = os.environ.get('BRZ_EMAIL') or os.environ.get('BZR_EMAIL') + v = os.environ.get("BRZ_EMAIL") or os.environ.get("BZR_EMAIL") if v: return v v = self._get_user_id() @@ -572,11 +563,11 @@ def _get_nickname(self): def get_bzr_remote_path(self): try: - return os.environ['BZR_REMOTE_PATH'] + return os.environ["BZR_REMOTE_PATH"] except KeyError: path = self.get_user_option("bzr_remote_path") if path is None: - path = 'bzr' + path = "bzr" return path def suppress_warning(self, warning): @@ -588,7 +579,7 @@ def suppress_warning(self, warning): Returns: True if the warning should be suppressed, False otherwise. """ - warnings = self.get_user_option_as_list('suppress_warnings') + warnings = self.get_user_option_as_list("suppress_warnings") if warnings is None or warning not in warnings: return False else: @@ -596,11 +587,11 @@ def suppress_warning(self, warning): def get_merge_tools(self): tools = {} - for (oname, _value, _section, _conf_id, _parser) in self._get_options(): - if oname.startswith('bzr.mergetool.'): - tool_name = oname[len('bzr.mergetool.'):] + for oname, _value, _section, _conf_id, _parser in self._get_options(): + if oname.startswith("bzr.mergetool."): + tool_name = oname[len("bzr.mergetool.") :] tools[tool_name] = self.get_user_option(oname, False) - trace.mutter(f'loaded merge tools: {tools!r}') + trace.mutter(f"loaded merge tools: {tools!r}") return tools def find_merge_tool(self, name): @@ -610,9 +601,9 @@ def find_merge_tool(self, name): # be found in the known_merge_tools if it's not found in the config. # This should be done through the proposed config defaults mechanism # when it becomes available in the future. - command_line = (self.get_user_option(f'bzr.mergetool.{name}', - expand=False) or - known_merge_tools.get(name, None)) + command_line = self.get_user_option( + f"bzr.mergetool.{name}", expand=False + ) or known_merge_tools.get(name, None) return command_line @@ -625,28 +616,36 @@ def __init__(self): These are all empty initially, because by default nothing should get notified. """ - super().__init__('breezy.config', 'ConfigHooks') - self.add_hook('load', - 'Invoked when a config store is loaded.' - ' The signature is (store).', - (2, 4)) - self.add_hook('save', - 'Invoked when a config store is saved.' 
- ' The signature is (store).', - (2, 4)) + super().__init__("breezy.config", "ConfigHooks") + self.add_hook( + "load", + "Invoked when a config store is loaded." " The signature is (store).", + (2, 4), + ) + self.add_hook( + "save", + "Invoked when a config store is saved." " The signature is (store).", + (2, 4), + ) # The hooks for config options - self.add_hook('get', - 'Invoked when a config option is read.' - ' The signature is (stack, name, value).', - (2, 4)) - self.add_hook('set', - 'Invoked when a config option is set.' - ' The signature is (stack, name, value).', - (2, 4)) - self.add_hook('remove', - 'Invoked when a config option is removed.' - ' The signature is (stack, name).', - (2, 4)) + self.add_hook( + "get", + "Invoked when a config option is read." + " The signature is (stack, name, value).", + (2, 4), + ) + self.add_hook( + "set", + "Invoked when a config option is set." + " The signature is (stack, name, value).", + (2, 4), + ) + self.add_hook( + "remove", + "Invoked when a config option is removed." + " The signature is (stack, name).", + (2, 4), + ) ConfigHooks = _ConfigHooks() @@ -661,29 +660,36 @@ def __init__(self): These are all empty initially, because by default nothing should get notified. """ - super().__init__( - 'breezy.config', 'OldConfigHooks') - self.add_hook('load', - 'Invoked when a config store is loaded.' - ' The signature is (config).', - (2, 4)) - self.add_hook('save', - 'Invoked when a config store is saved.' - ' The signature is (config).', - (2, 4)) + super().__init__("breezy.config", "OldConfigHooks") + self.add_hook( + "load", + "Invoked when a config store is loaded." " The signature is (config).", + (2, 4), + ) + self.add_hook( + "save", + "Invoked when a config store is saved." " The signature is (config).", + (2, 4), + ) # The hooks for config options - self.add_hook('get', - 'Invoked when a config option is read.' - ' The signature is (config, name, value).', - (2, 4)) - self.add_hook('set', - 'Invoked when a config option is set.' - ' The signature is (config, name, value).', - (2, 4)) - self.add_hook('remove', - 'Invoked when a config option is removed.' - ' The signature is (config, name).', - (2, 4)) + self.add_hook( + "get", + "Invoked when a config option is read." + " The signature is (config, name, value).", + (2, 4), + ) + self.add_hook( + "set", + "Invoked when a config option is set." + " The signature is (config, name, value).", + (2, 4), + ) + self.add_hook( + "remove", + "Invoked when a config option is removed." + " The signature is (config, name).", + (2, 4), + ) OldConfigHooks = _OldConfigHooks() @@ -720,7 +726,7 @@ def from_string(cls, str_or_unicode, file_name=None, save=False): def _create_from_string(self, str_or_unicode, save): if isinstance(str_or_unicode, str): - str_or_unicode = str_or_unicode.encode('utf-8') + str_or_unicode = str_or_unicode.encode("utf-8") self._content = BytesIO(str_or_unicode) # Some tests use in-memory configs, some other always need the config # file to exist on disk. 
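The hunk above also reformats Config._expand_options_in_string, whose comment explains that the "({[^{}]+})" pattern deliberately matches the most deeply embedded reference first ('{{foo}}' gives '{foo}', '{bar{baz}}' gives '{baz}'). A minimal standalone sketch of that behaviour (standard library only, not part of this patch; the variable name option_ref_re is reused here purely for illustration):

# Minimal standalone sketch (not part of the patch): shows why the
# option-expansion regex matches the innermost reference first, as the
# comment in _expand_options_in_string describes.
import re

option_ref_re = re.compile("({[^{}]+})")  # same pattern as in the hunk above

print(option_ref_re.findall("{{foo}}"))     # ['{foo}']  - inner reference wins
print(option_ref_re.findall("{bar{baz}}"))  # ['{baz}']  - expansion works outwards

Because only the innermost braces can match, expansion has to be repeated until no references remain, which is why the surrounding code iterates.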
@@ -733,28 +739,28 @@ def _get_parser(self): if self._content is not None: co_input = self._content elif self.file_name is None: - raise AssertionError('We have no content to create the config') + raise AssertionError("We have no content to create the config") else: co_input = self.file_name try: - self._parser = ConfigObj(co_input, encoding='utf-8') + self._parser = ConfigObj(co_input, encoding="utf-8") except configobj.ConfigObjError as e: raise ParseConfigError(e.errors, e.config.filename) from e except UnicodeDecodeError as e: raise ConfigContentError(self.file_name) from e # Make sure self.reload() will use the right file name self._parser.filename = self.file_name - for hook in OldConfigHooks['load']: + for hook in OldConfigHooks["load"]: hook(self) return self._parser def reload(self): """Reload the config file from disk.""" if self.file_name is None: - raise AssertionError('We need a file name to reload the config') + raise AssertionError("We need a file name to reload the config") if self._parser is not None: self._parser.reload() - for hook in ConfigHooks['load']: + for hook in ConfigHooks["load"]: hook(self) def _get_matching_sections(self): @@ -765,7 +771,7 @@ def _get_matching_sections(self): """ section = self._get_section() if section is not None: - return [(section, '')] + return [(section, "")] else: return [] @@ -809,7 +815,7 @@ def _get_options(self, sections=None): if sections is None: parser = self._get_parser() sections = [] - for (section_name, _) in self._get_matching_sections(): + for section_name, _ in self._get_matching_sections(): try: section = parser[section_name] except KeyError: @@ -819,37 +825,36 @@ def _get_options(self, sections=None): continue sections.append((section_name, section)) config_id = self.config_id() - for (section_name, section) in sections: - for (name, value) in section.iteritems(): - yield (name, parser._quote(value), section_name, - config_id, parser) + for section_name, section in sections: + for name, value in section.iteritems(): + yield (name, parser._quote(value), section_name, config_id, parser) def _get_option_policy(self, section, option_name): """Return the policy for the given (section, option_name) pair.""" return POLICY_NONE def _get_change_editor(self): - return self.get_user_option('change_editor', expand=False) + return self.get_user_option("change_editor", expand=False) def _get_signature_checking(self): """See Config._get_signature_checking.""" - policy = self._get_user_option('check_signatures') + policy = self._get_user_option("check_signatures") if policy: return signature_policy_from_unicode(policy) def _get_signing_policy(self): """See Config._get_signing_policy.""" - policy = self._get_user_option('create_signatures') + policy = self._get_user_option("create_signatures") if policy: return signing_policy_from_unicode(policy) def _get_user_id(self): """Get the user id from the 'email' key in the current section.""" - return self._get_user_option('email') + return self._get_user_option("email") def _get_user_option(self, option_name): """See Config._get_user_option.""" - for (section, extra_path) in self._get_matching_sections(): + for section, extra_path in self._get_matching_sections(): try: value = self._get_parser().get_value(section, option_name) except KeyError: @@ -868,35 +873,34 @@ def _get_user_option(self, option_name): value = urlutils.join(value, extra_path) return value else: - raise AssertionError(f'Unexpected config policy {policy!r}') + raise AssertionError(f"Unexpected config policy {policy!r}") else: 
return None def _log_format(self): """See Config.log_format.""" - return self._get_user_option('log_format') + return self._get_user_option("log_format") def _validate_signatures_in_log(self): """See Config.validate_signatures_in_log.""" - return self._get_user_option('validate_signatures_in_log') + return self._get_user_option("validate_signatures_in_log") def _acceptable_keys(self): """See Config.acceptable_keys.""" - return self._get_user_option('acceptable_keys') + return self._get_user_option("acceptable_keys") def _post_commit(self): """See Config.post_commit.""" - return self._get_user_option('post_commit') + return self._get_user_option("post_commit") def _get_alias(self, value): try: - return self._get_parser().get_value("ALIASES", - value) + return self._get_parser().get_value("ALIASES", value) except KeyError: pass def _get_nickname(self): - return self.get_user_option('nickname') + return self.get_user_option("nickname") def remove_user_option(self, option_name, section_name=None): """Remove a user option and save the configuration file. @@ -917,19 +921,20 @@ def remove_user_option(self, option_name, section_name=None): except KeyError as e: raise NoSuchConfigOption(option_name) from e self._write_config_file() - for hook in OldConfigHooks['remove']: + for hook in OldConfigHooks["remove"]: hook(self, option_name) def _write_config_file(self): if self.file_name is None: - raise AssertionError('We cannot save, self.file_name is None') + raise AssertionError("We cannot save, self.file_name is None") from . import atomicfile + conf_dir = os.path.dirname(self.file_name) bedding.ensure_config_dir_exists(conf_dir) with atomicfile.AtomicFile(self.file_name) as atomic_file: self._get_parser().write(atomic_file) osutils.copy_ownership_from_path(self.file_name) - for hook in OldConfigHooks['save']: + for hook in OldConfigHooks["save"]: hook(self) @@ -958,7 +963,7 @@ class LockableConfig(IniBasedConfig): update made by another writer. 
""" - lock_name = 'lock' + lock_name = "lock" def __init__(self, file_name): super().__init__(file_name=file_name) @@ -997,8 +1002,7 @@ def break_lock(self): def remove_user_option(self, option_name, section_name=None): with self.lock_write(): - super().remove_user_option( - option_name, section_name) + super().remove_user_option(option_name, section_name) def _write_config_file(self): if self._lock is None or not self._lock.is_held: @@ -1015,7 +1019,7 @@ def __init__(self): super().__init__(file_name=bedding.config_path()) def config_id(self): - return 'breezy' + return "breezy" @classmethod def from_string(cls, str_or_unicode, save=False): @@ -1033,25 +1037,25 @@ def from_string(cls, str_or_unicode, save=False): def set_user_option(self, option, value): """Save option and its value in the configuration.""" with self.lock_write(): - self._set_option(option, value, 'DEFAULT') + self._set_option(option, value, "DEFAULT") def get_aliases(self): """Return the aliases section.""" - if 'ALIASES' in self._get_parser(): - return self._get_parser()['ALIASES'] + if "ALIASES" in self._get_parser(): + return self._get_parser()["ALIASES"] else: return {} def set_alias(self, alias_name, alias_command): """Save the alias in the configuration.""" with self.lock_write(): - self._set_option(alias_name, alias_command, 'ALIASES') + self._set_option(alias_name, alias_command, "ALIASES") def unset_alias(self, alias_name): """Unset an existing alias.""" with self.lock_write(): self.reload() - aliases = self._get_parser().get('ALIASES') + aliases = self._get_parser().get("ALIASES") if not aliases or alias_name not in aliases: raise NoSuchAlias(alias_name) del aliases[alias_name] @@ -1061,7 +1065,7 @@ def _set_option(self, option, value, section): self.reload() self._get_parser().setdefault(section, {})[option] = value self._write_config_file() - for hook in OldConfigHooks['set']: + for hook in OldConfigHooks["set"]: hook(self, option, value) def _get_sections(self, name=None): @@ -1069,23 +1073,22 @@ def _get_sections(self, name=None): parser = self._get_parser() # We don't give access to options defined outside of any section, we # used the DEFAULT section by... default. - if name in (None, 'DEFAULT'): + if name in (None, "DEFAULT"): # This could happen for an empty file where the DEFAULT section # doesn't exist yet. So we force DEFAULT when yielding - name = 'DEFAULT' - if 'DEFAULT' not in parser: - parser['DEFAULT'] = {} + name = "DEFAULT" + if "DEFAULT" not in parser: + parser["DEFAULT"] = {} yield (name, parser[name], self.config_id()) def remove_user_option(self, option_name, section_name=None): if section_name is None: # We need to force the default section. - section_name = 'DEFAULT' + section_name = "DEFAULT" with self.lock_write(): # We need to avoid the LockableConfig implementation or we'll lock # twice - super(LockableConfig, self).remove_user_option( - option_name, section_name) + super(LockableConfig, self).remove_user_option(option_name, section_name) def _iter_for_location_by_parts(sections, location): @@ -1105,7 +1108,8 @@ def _iter_for_location_by_parts(sections, location): section names themselves can be in either form. 
""" import fnmatch - location_parts = location.rstrip('/').split('/') + + location_parts = location.rstrip("/").split("/") for section in sections: # location is a local path if possible, so we need to convert 'file://' @@ -1117,11 +1121,11 @@ def _iter_for_location_by_parts(sections, location): # FIXME: This still raises an issue if a user defines both file:///path # *and* /path. Should we raise an error in this case -- vila 20110505 - if section.startswith('file://'): + if section.startswith("file://"): section_path = urlutils.local_path_from_url(section) else: section_path = section - section_parts = section_path.rstrip('/').split('/') + section_parts = section_path.rstrip("/").split("/") matched = True if len(section_parts) > len(location_parts): @@ -1137,7 +1141,7 @@ def _iter_for_location_by_parts(sections, location): if not matched: continue # build the path difference between the section and the location - extra_path = '/'.join(location_parts[len(section_parts):]) + extra_path = "/".join(location_parts[len(section_parts) :]) yield section, extra_path, len(section_parts) @@ -1145,17 +1149,16 @@ class LocationConfig(LockableConfig): """A configuration object that gives the policy for a location.""" def __init__(self, location): - super().__init__( - file_name=bedding.locations_config_path()) + super().__init__(file_name=bedding.locations_config_path()) # local file locations are looked up by local path, rather than # by file url. This is because the config file is a user # file, and we would rather not expose the user to file urls. - if location.startswith('file://'): + if location.startswith("file://"): location = urlutils.local_path_from_url(location) self.location = location def config_id(self): - return 'locations' + return "locations" @classmethod def from_string(cls, str_or_unicode, location, save=False): @@ -1177,12 +1180,13 @@ def _get_matching_sections(self): matches = sorted( _iter_for_location_by_parts(self._get_parser(), self.location), key=lambda match: (match[2], match[0]), - reverse=True) - for (section, extra_path, _length) in matches: + reverse=True, + ) + for section, extra_path, _length in matches: yield section, extra_path # should we stop looking for parent configs here? 
try: - if self._get_parser()[section].as_bool('ignore_parents'): + if self._get_parser()[section].as_bool("ignore_parents"): break except KeyError: pass @@ -1199,13 +1203,13 @@ def _get_option_policy(self, section, option_name): """Return the policy for the given (section, option_name) pair.""" # check for the old 'recurse=False' flag try: - recurse = self._get_parser()[section].as_bool('recurse') + recurse = self._get_parser()[section].as_bool("recurse") except KeyError: recurse = True if not recurse: return POLICY_NORECURSE - policy_key = option_name + ':policy' + policy_key = option_name + ":policy" try: policy_name = self._get_parser()[section][policy_key] except KeyError: @@ -1215,7 +1219,7 @@ def _get_option_policy(self, section, option_name): def _set_option_policy(self, section, option_name, option_policy): """Set the policy for the given option name in the given section.""" - policy_key = option_name + ':policy' + policy_key = option_name + ":policy" policy_name = _policy_name[option_policy] if policy_name is not None: self._get_parser()[section][policy_key] = policy_name @@ -1225,25 +1229,27 @@ def _set_option_policy(self, section, option_name, option_policy): def set_user_option(self, option, value, store=STORE_LOCATION): """Save option and its value in the configuration.""" - if store not in [STORE_LOCATION, - STORE_LOCATION_NORECURSE, - STORE_LOCATION_APPENDPATH]: - raise ValueError(f'bad storage policy {store!r} for {option!r}') + if store not in [ + STORE_LOCATION, + STORE_LOCATION_NORECURSE, + STORE_LOCATION_APPENDPATH, + ]: + raise ValueError(f"bad storage policy {store!r} for {option!r}") with self.lock_write(): self.reload() location = self.location - if location.endswith('/'): + if location.endswith("/"): location = location[:-1] parser = self._get_parser() if location not in parser and location + "/" not in parser: parser[location] = {} - elif location + '/' in parser: - location = location + '/' + elif location + "/" in parser: + location = location + "/" parser[location][option] = value # the allowed values of store match the config policies self._set_option_policy(location, option, store) self._write_config_file() - for hook in OldConfigHooks['set']: + for hook in OldConfigHooks["set"]: hook(self, option, value) @@ -1256,12 +1262,14 @@ def __init__(self, branch): self._branch_data_config = None self._global_config = None self.branch = branch - self.option_sources = (self._get_location_config, - self._get_branch_data_config, - self._get_global_config) + self.option_sources = ( + self._get_location_config, + self._get_branch_data_config, + self._get_global_config, + ) def config_id(self): - return 'branch' + return "branch" def _get_branch_data_config(self): if self._branch_data_config is None: @@ -1272,7 +1280,7 @@ def _get_branch_data_config(self): def _get_location_config(self): if self._location_config is None: if self.branch.base is None: - self.branch.base = 'memory://' + self.branch.base = "memory://" self._location_config = LocationConfig(self.branch.base) return self._location_config @@ -1314,18 +1322,18 @@ def _get_user_id(self): e.g. "John Hacker " This is looked up in the email controlfile for the branch. 
""" - return self._get_best_value('_get_user_id') + return self._get_best_value("_get_user_id") def _get_change_editor(self): - return self._get_best_value('_get_change_editor') + return self._get_best_value("_get_change_editor") def _get_signature_checking(self): """See Config._get_signature_checking.""" - return self._get_best_value('_get_signature_checking') + return self._get_best_value("_get_signature_checking") def _get_signing_policy(self): """See Config._get_signing_policy.""" - return self._get_best_value('_get_signing_policy') + return self._get_best_value("_get_signing_policy") def _get_user_option(self, option_name): """See Config._get_user_option.""" @@ -1346,19 +1354,23 @@ def _get_options(self, sections=None): # Then the branch options branch_config = self._get_branch_data_config() if sections is None: - sections = [('DEFAULT', branch_config._get_parser())] + sections = [("DEFAULT", branch_config._get_parser())] # FIXME: We shouldn't have to duplicate the code in IniBasedConfig but # Config itself has no notion of sections :( -- vila 20101001 config_id = self.config_id() - for (section_name, section) in sections: - for (name, value) in section.iteritems(): - yield (name, value, section_name, - config_id, branch_config._get_parser()) + for section_name, section in sections: + for name, value in section.iteritems(): + yield ( + name, + value, + section_name, + config_id, + branch_config._get_parser(), + ) # Then the global options yield from self._get_global_config()._get_options() - def set_user_option(self, name, value, store=STORE_BRANCH, - warn_masked=False): + def set_user_option(self, name, value, store=STORE_BRANCH, warn_masked=False): if store == STORE_BRANCH: self._get_branch_data_config().set_option(value, name) elif store == STORE_GLOBAL: @@ -1370,22 +1382,28 @@ def set_user_option(self, name, value, store=STORE_BRANCH, if store in (STORE_GLOBAL, STORE_BRANCH): mask_value = self._get_location_config().get_user_option(name) if mask_value is not None: - trace.warning('Value "%s" is masked by "%s" from' - ' locations.conf', value, mask_value) + trace.warning( + 'Value "%s" is masked by "%s" from' " locations.conf", + value, + mask_value, + ) else: if store == STORE_GLOBAL: branch_config = self._get_branch_data_config() mask_value = branch_config.get_user_option(name) if mask_value is not None: - trace.warning('Value "%s" is masked by "%s" from' - ' branch.conf', value, mask_value) + trace.warning( + 'Value "%s" is masked by "%s" from' " branch.conf", + value, + mask_value, + ) def remove_user_option(self, option_name, section_name=None): self._get_branch_data_config().remove_option(option_name, section_name) def _post_commit(self): """See Config.post_commit.""" - return self._get_safe_value('_post_commit') + return self._get_safe_value("_post_commit") def _get_nickname(self): value = self._get_explicit_nickname() @@ -1393,36 +1411,36 @@ def _get_nickname(self): return value if self.branch.name: return self.branch.name - return urlutils.unescape(self.branch.base.split('/')[-2]) + return urlutils.unescape(self.branch.base.split("/")[-2]) def has_explicit_nickname(self): """Return true if a nickname has been explicitly assigned.""" return self._get_explicit_nickname() is not None def _get_explicit_nickname(self): - return self._get_best_value('_get_nickname') + return self._get_best_value("_get_nickname") def _log_format(self): """See Config.log_format.""" - return self._get_best_value('_log_format') + return self._get_best_value("_log_format") def 
_validate_signatures_in_log(self): """See Config.validate_signatures_in_log.""" - return self._get_best_value('_validate_signatures_in_log') + return self._get_best_value("_validate_signatures_in_log") def _acceptable_keys(self): """See Config.acceptable_keys.""" - return self._get_best_value('_acceptable_keys') + return self._get_best_value("_acceptable_keys") -_username_re = lazy_regex.lazy_compile(r'(.*?)\s*?') +_username_re = lazy_regex.lazy_compile(r"(.*?)\s*?") def parse_username(username): """Parse e-mail username and return a (name, address) tuple.""" match = _username_re.match(username) if match is None: - return (username, '') + return (username, "") return (match.group(1), match.group(2)) @@ -1505,7 +1523,7 @@ def _get_config(self): # Note: the encoding below declares that the file itself is utf-8 # encoded, but the values in the ConfigObj are always Unicode. - self._config = ConfigObj(self._input, encoding='utf-8') + self._config = ConfigObj(self._input, encoding="utf-8") except configobj.ConfigObjError as e: raise ParseConfigError(e.errors, e.config.filename) from e except UnicodeError as e: @@ -1515,23 +1533,34 @@ def _get_config(self): def _check_permissions(self): """Check permission of auth file are user read/write able only.""" import stat + try: st = os.stat(self._filename) except FileNotFoundError: return except OSError as e: - trace.mutter('Unable to stat %r: %r', self._filename, e) + trace.mutter("Unable to stat %r: %r", self._filename, e) return mode = stat.S_IMODE(st.st_mode) - if ((stat.S_IXOTH | stat.S_IWOTH | stat.S_IROTH | stat.S_IXGRP - | stat.S_IWGRP | stat.S_IRGRP) & mode): + if ( + stat.S_IXOTH + | stat.S_IWOTH + | stat.S_IROTH + | stat.S_IXGRP + | stat.S_IWGRP + | stat.S_IRGRP + ) & mode: # Only warn once - if (self._filename not in _authentication_config_permission_errors and - not GlobalConfig().suppress_warning( - 'insecure_permissions')): - trace.warning("The file '%s' has insecure " - "file permissions. Saved passwords may be accessible " - "by other users.", self._filename) + if ( + self._filename not in _authentication_config_permission_errors + and not GlobalConfig().suppress_warning("insecure_permissions") + ): + trace.warning( + "The file '%s' has insecure " + "file permissions. Saved passwords may be accessible " + "by other users.", + self._filename, + ) _authentication_config_permission_errors.add(self._filename) def _save(self): @@ -1540,7 +1569,7 @@ def _save(self): bedding.ensure_config_dir_exists(conf_dir) fd = os.open(self._filename, os.O_RDWR | os.O_CREAT, 0o600) try: - f = os.fdopen(fd, 'wb') + f = os.fdopen(fd, "wb") self._get_config().write(f) finally: f.close() @@ -1555,8 +1584,9 @@ def _set_option(self, section_name, option_name, value): section[option_name] = value self._save() - def get_credentials(self, scheme, host, port=None, user=None, path=None, - realm=None): + def get_credentials( + self, scheme, host, port=None, user=None, path=None, realm=None + ): """Returns the matching credentials from authentication.conf file. 
Args: @@ -1589,36 +1619,37 @@ def get_credentials(self, scheme, host, port=None, user=None, path=None, raise ValueError(f"{auth_def_name} defined outside a section") a_scheme, a_host, a_user, a_path = map( - auth_def.get, ['scheme', 'host', 'user', 'path']) + auth_def.get, ["scheme", "host", "user", "path"] + ) try: - a_port = auth_def.as_int('port') + a_port = auth_def.as_int("port") except KeyError: a_port = None except ValueError as e: raise ValueError(f"'port' not numeric in {auth_def_name}") from e try: - a_verify_certificates = auth_def.as_bool('verify_certificates') + a_verify_certificates = auth_def.as_bool("verify_certificates") except KeyError: a_verify_certificates = True except ValueError as e: raise ValueError( - f"'verify_certificates' not boolean in {auth_def_name}") from e + f"'verify_certificates' not boolean in {auth_def_name}" + ) from e # Attempt matching if a_scheme is not None and scheme != a_scheme: continue if a_host is not None: - if not (host == a_host or - (a_host.startswith('.') and host.endswith(a_host))): + if not ( + host == a_host or (a_host.startswith(".") and host.endswith(a_host)) + ): continue if a_port is not None and port != a_port: continue - if (a_path is not None and path is not None and - not path.startswith(a_path)): + if a_path is not None and path is not None and not path.startswith(a_path): continue - if (a_user is not None and user is not None and - a_user != user): + if a_user is not None and user is not None and a_user != user: # Never contradict the caller about the user to be used continue if a_user is None: @@ -1626,19 +1657,20 @@ def get_credentials(self, scheme, host, port=None, user=None, path=None, continue # Prepare a credentials dictionary with additional keys # for the credential providers - credentials = {"name": auth_def_name, - "user": a_user, - "scheme": a_scheme, - "host": host, - "port": port, - "path": path, - "realm": realm, - "password": auth_def.get('password', None), - "verify_certificates": a_verify_certificates} + credentials = { + "name": auth_def_name, + "user": a_user, + "scheme": a_scheme, + "host": host, + "port": port, + "path": path, + "realm": realm, + "password": auth_def.get("password", None), + "verify_certificates": a_verify_certificates, + } # Decode the password in the credentials (or get one) - self.decode_password(credentials, - auth_def.get('password_encoding', None)) - if debug.debug_flag_enabled('auth'): + self.decode_password(credentials, auth_def.get("password_encoding", None)) + if debug.debug_flag_enabled("auth"): trace.mutter("Using authentication section: %r", auth_def_name) break @@ -1646,13 +1678,23 @@ def get_credentials(self, scheme, host, port=None, user=None, path=None, # No credentials were found in authentication.conf, try the fallback # credentials stores. credentials = credential_store_registry.get_fallback_credentials( - scheme, host, port, user, path, realm) + scheme, host, port, user, path, realm + ) return credentials - def set_credentials(self, name, host, user, scheme=None, password=None, - port=None, path=None, verify_certificates=None, - realm=None): + def set_credentials( + self, + name, + host, + user, + scheme=None, + password=None, + port=None, + path=None, + verify_certificates=None, + realm=None, + ): """Set authentication credentials for a host. 
Any existing credentials with matching scheme, host, port and path @@ -1669,22 +1711,22 @@ def set_credentials(self, name, host, user, scheme=None, password=None, verify_certificates: On https, verify server certificates if True. realm: The http authentication realm (optional). """ - values = {'host': host, 'user': user} + values = {"host": host, "user": user} if password is not None: - values['password'] = password + values["password"] = password if scheme is not None: - values['scheme'] = scheme + values["scheme"] = scheme if port is not None: - values['port'] = '%d' % port + values["port"] = "%d" % port if path is not None: - values['path'] = path + values["path"] = path if verify_certificates is not None: - values['verify_certificates'] = str(verify_certificates) + values["verify_certificates"] = str(verify_certificates) if realm is not None: - values['realm'] = realm + values["realm"] = realm config = self._get_config() for section, existing_values in config.iteritems(): - for key in ('scheme', 'host', 'port', 'path', 'realm'): + for key in ("scheme", "host", "port", "path", "realm"): if existing_values.get(key) != values.get(key): break else: @@ -1692,8 +1734,17 @@ def set_credentials(self, name, host, user, scheme=None, password=None, config.update({name: values}) self._save() - def get_user(self, scheme, host, port=None, realm=None, path=None, - prompt=None, ask=False, default=None): + def get_user( + self, + scheme, + host, + port=None, + realm=None, + path=None, + prompt=None, + ask=False, + default=None, + ): """Get a user from authentication file. Args: @@ -1709,20 +1760,21 @@ def get_user(self, scheme, host, port=None, realm=None, path=None, Returns: The found user. """ - credentials = self.get_credentials(scheme, host, port, user=None, - path=path, realm=realm) + credentials = self.get_credentials( + scheme, host, port, user=None, path=path, realm=realm + ) if credentials is not None: - user = credentials['user'] + user = credentials["user"] else: user = None if user is None: if ask: if prompt is None: # Create a default prompt suitable for most cases - prompt = f'{scheme.upper()}' + ' %(host)s username' + prompt = f"{scheme.upper()}" + " %(host)s username" # Special handling for optional fields in the prompt if port is not None: - prompt_host = '%s:%d' % (host, port) + prompt_host = "%s:%d" % (host, port) else: prompt_host = host user = ui.ui_factory.get_username(prompt, host=prompt_host) @@ -1730,8 +1782,9 @@ def get_user(self, scheme, host, port=None, realm=None, path=None, user = default return user - def get_password(self, scheme, host, user, port=None, - realm=None, path=None, prompt=None): + def get_password( + self, scheme, host, user, port=None, realm=None, path=None, prompt=None + ): """Get a password from authentication file or prompt the user for one. Args: @@ -1745,14 +1798,14 @@ def get_password(self, scheme, host, user, port=None, Returns: The found password or the one entered by the user. 
""" - credentials = self.get_credentials(scheme, host, port, user, path, - realm) + credentials = self.get_credentials(scheme, host, port, user, path, realm) if credentials is not None: - password = credentials['password'] - if password is not None and scheme == 'ssh': - trace.warning('password ignored in section [%s],' - ' use an ssh agent instead' - % credentials['name']) + password = credentials["password"] + if password is not None and scheme == "ssh": + trace.warning( + "password ignored in section [%s]," + " use an ssh agent instead" % credentials["name"] + ) password = None else: password = None @@ -1760,22 +1813,21 @@ def get_password(self, scheme, host, user, port=None, if password is None: if prompt is None: # Create a default prompt suitable for most cases - prompt = (f'{scheme.upper()}' + ' %(user)s@%(host)s password') + prompt = f"{scheme.upper()}" + " %(user)s@%(host)s password" # Special handling for optional fields in the prompt if port is not None: - prompt_host = '%s:%d' % (host, port) + prompt_host = "%s:%d" % (host, port) else: prompt_host = host - password = ui.ui_factory.get_password(prompt, - host=prompt_host, user=user) + password = ui.ui_factory.get_password(prompt, host=prompt_host, user=user) return password def decode_password(self, credentials, encoding): try: cs = credential_store_registry.get_credential_store(encoding) except KeyError as e: - raise ValueError(f'{encoding!r} is not a known password_encoding') from e - credentials['password'] = cs.decode_password(credentials) + raise ValueError(f"{encoding!r} is not a known password_encoding") from e + credentials["password"] = cs.decode_password(credentials) return credentials @@ -1804,8 +1856,9 @@ def is_fallback(self, name): """Check if the named credentials store should be used as fallback.""" return self.get_info(name) - def get_fallback_credentials(self, scheme, host, port=None, user=None, - path=None, realm=None): + def get_fallback_credentials( + self, scheme, host, port=None, user=None, path=None, realm=None + ): """Request credentials from all fallback credentials stores. The first credentials store that can provide credentials wins. @@ -1815,15 +1868,13 @@ def get_fallback_credentials(self, scheme, host, port=None, user=None, if not self.is_fallback(name): continue cs = self.get_credential_store(name) - credentials = cs.get_credentials(scheme, host, port, user, - path, realm) + credentials = cs.get_credentials(scheme, host, port, user, path, realm) if credentials is not None: # We found some credentials break return credentials - def register(self, key, obj, help=None, override_existing=False, - fallback=False): + def register(self, key, obj, help=None, override_existing=False, fallback=False): """Register a new object to a name. Args: @@ -1839,12 +1890,19 @@ def register(self, key, obj, help=None, override_existing=False, fallback: Whether this credential store should be used as fallback. """ - return super().register(key, obj, help, info=fallback, - override_existing=override_existing) - - def register_lazy(self, key, module_name, member_name, - help=None, override_existing=False, - fallback=False): + return super().register( + key, obj, help, info=fallback, override_existing=override_existing + ) + + def register_lazy( + self, + key, + module_name, + member_name, + help=None, + override_existing=False, + fallback=False, + ): """Register a new credential store to be loaded on request. Args: @@ -1860,8 +1918,13 @@ def register_lazy(self, key, module_name, member_name, used as fallback. 
""" return super().register_lazy( - key, module_name, member_name, help, - info=fallback, override_existing=override_existing) + key, + module_name, + member_name, + help, + info=fallback, + override_existing=override_existing, + ) credential_store_registry = CredentialStoreRegistry() @@ -1874,8 +1937,9 @@ def decode_password(self, credentials): """Returns a clear text password for the provided credentials.""" raise NotImplementedError(self.decode_password) - def get_credentials(self, scheme, host, port=None, user=None, path=None, - realm=None): + def get_credentials( + self, scheme, host, port=None, user=None, path=None, realm=None + ): """Return the matching credentials from this credential store. This method is only called on fallback credential stores. @@ -1888,12 +1952,13 @@ class PlainTextCredentialStore(CredentialStore): def decode_password(self, credentials): """See CredentialStore.decode_password.""" - return credentials['password'] + return credentials["password"] -credential_store_registry.register('plain', PlainTextCredentialStore, - help=PlainTextCredentialStore.__doc__) -credential_store_registry.default_key = 'plain' +credential_store_registry.register( + "plain", PlainTextCredentialStore, help=PlainTextCredentialStore.__doc__ +) +credential_store_registry.default_key = "plain" class Base64CredentialStore(CredentialStore): @@ -1904,15 +1969,16 @@ def decode_password(self, credentials): # GZ 2012-07-28: Will raise binascii.Error if password is not base64, # should probably propogate as something more useful. import base64 - return base64.standard_b64decode(credentials['password']) + return base64.standard_b64decode(credentials["password"]) -credential_store_registry.register('base64', Base64CredentialStore, - help=Base64CredentialStore.__doc__) +credential_store_registry.register( + "base64", Base64CredentialStore, help=Base64CredentialStore.__doc__ +) -class BzrDirConfig: +class BzrDirConfig: def __init__(self, bzrdir): self._bzrdir = bzrdir self._config = bzrdir._get_config() @@ -1928,9 +1994,9 @@ def set_default_stack_on(self, value): if self._config is None: raise errors.BzrError(f"Cannot set configuration in {self._bzrdir}") if value is None: - self._config.set_option('', 'default_stack_on') + self._config.set_option("", "default_stack_on") else: - self._config.set_option(value, 'default_stack_on') + self._config.set_option(value, "default_stack_on") def get_default_stack_on(self): """Return the default stacking location. 
@@ -1942,8 +2008,8 @@ def get_default_stack_on(self): """ if self._config is None: return None - value = self._config.get_option('default_stack_on') - if value == '': + value = self._config.get_option("default_stack_on") + if value == "": value = None return value @@ -1978,7 +2044,7 @@ def get_option(self, name, section=None, default=None): except KeyError: return default value = section_obj.get(name, default) - for hook in OldConfigHooks['get']: + for hook in OldConfigHooks["get"]: hook(self, name, value) return value @@ -1995,7 +2061,7 @@ def set_option(self, value, name, section=None): configobj[name] = value else: configobj.setdefault(section, {})[name] = value - for hook in OldConfigHooks['set']: + for hook in OldConfigHooks["set"]: hook(self, name, value) self._set_configobj(configobj) @@ -2005,25 +2071,25 @@ def remove_option(self, option_name, section_name=None): del configobj[option_name] else: del configobj[section_name][option_name] - for hook in OldConfigHooks['remove']: + for hook in OldConfigHooks["remove"]: hook(self, option_name) self._set_configobj(configobj) def _get_config_file(self): try: f = BytesIO(self._transport.get_bytes(self._filename)) - for hook in OldConfigHooks['load']: + for hook in OldConfigHooks["load"]: hook(self) return f except transport.NoSuchFile: return BytesIO() except errors.PermissionDenied: trace.warning( - "Permission denied while trying to open " - "configuration file %s.", + "Permission denied while trying to open " "configuration file %s.", urlutils.unescape_for_display( - urlutils.join(self._transport.base, self._filename), - "utf-8")) + urlutils.join(self._transport.base, self._filename), "utf-8" + ), + ) return BytesIO() def _external_url(self): @@ -2033,7 +2099,7 @@ def _get_configobj(self): f = self._get_config_file() try: try: - conf = ConfigObj(f, encoding='utf-8') + conf = ConfigObj(f, encoding="utf-8") except configobj.ConfigObjError as e: raise ParseConfigError(e.errors, self._external_url()) from e except UnicodeDecodeError as e: @@ -2047,7 +2113,7 @@ def _set_configobj(self, configobj): configobj.write(out_file) out_file.seek(0) self._transport.put_file(self._filename, out_file) - for hook in OldConfigHooks['save']: + for hook in OldConfigHooks["save"]: hook(self) @@ -2061,9 +2127,17 @@ class Option: encoutered, in which config files it can be stored. """ - def __init__(self, name, override_from_env=None, - default=None, default_from_env=None, - help=None, from_unicode=None, invalid=None, unquote=True): + def __init__( + self, + name, + override_from_env=None, + default=None, + default_from_env=None, + help=None, + from_unicode=None, + invalid=None, + unquote=True, + ): """Build an option definition. 
Args: @@ -2113,22 +2187,21 @@ def __init__(self, name, override_from_env=None, elif isinstance(default, list): # Only the empty list is supported if default: - raise AssertionError( - 'Only empty lists are supported as default values') - self.default = ',' + raise AssertionError("Only empty lists are supported as default values") + self.default = "," elif isinstance(default, (bytes, str, bool, int, float)): # Rely on python to convert strings, booleans and integers - self.default = f'{default}' + self.default = f"{default}" elif callable(default): self.default = default else: # other python objects are not expected - raise AssertionError(f'{default!r} is not supported as a default value') + raise AssertionError(f"{default!r} is not supported as a default value") self.default_from_env = default_from_env self._help = help self.from_unicode = from_unicode self.unquote = unquote - if invalid and invalid not in ('warning', 'error'): + if invalid and invalid not in ("warning", "error"): raise AssertionError(f"{invalid} not supported for 'invalid'") self.invalid = invalid @@ -2149,10 +2222,11 @@ def convert_from_unicode(self, store, unicode_value): converted = None if converted is None and self.invalid is not None: # The conversion failed - if self.invalid == 'warning': - trace.warning('Value "%s" is not valid for "%s"', - unicode_value, self.name) - elif self.invalid == 'error': + if self.invalid == "warning": + trace.warning( + 'Value "%s" is not valid for "%s"', unicode_value, self.name + ) + elif self.invalid == "error": raise ConfigOptionValueError(self.name, unicode_value) return converted @@ -2182,7 +2256,8 @@ def get_default(self): value = self.default() if not isinstance(value, str): raise AssertionError( - f"Callable default value for '{self.name}' should be unicode") + f"Callable default value for '{self.name}' should be unicode" + ) else: value = self.default return value @@ -2193,6 +2268,7 @@ def get_help_topic(self): def get_help_text(self, additional_see_also=None, plain=True): result = self.help from breezy import help_topics + result += help_topics._format_see_also(additional_see_also) if plain: result = help_topics.help_as_plain_text(result) @@ -2201,6 +2277,7 @@ def get_help_text(self, additional_see_also=None, plain=True): # Predefined converters to get proper values from store + def bool_from_store(unicode_str): return ui.bool_from_string(unicode_str) @@ -2222,7 +2299,7 @@ def int_SI_from_store(unicode_str): Returns: Integer, expanded to its base-10 value if a proper SI unit is found, None otherwise. 
""" - regexp = "^(\\d+)(([" + ''.join(_unit_suffixes) + "])b?)?$" + regexp = "^(\\d+)(([" + "".join(_unit_suffixes) + "])b?)?$" p = re.compile(regexp, re.IGNORECASE) m = p.match(unicode_str) val = None @@ -2233,8 +2310,7 @@ def int_SI_from_store(unicode_str): try: coeff = _unit_suffixes[unit.upper()] except KeyError as e: - raise ValueError( - gettext('{0} is not an SI unit.').format(unit)) from e + raise ValueError(gettext("{0} is not an SI unit.").format(unit)) from e val *= coeff return val @@ -2246,22 +2322,28 @@ def float_from_store(unicode_str): # Use an empty dict to initialize an empty configobj avoiding all parsing and # encoding checks _list_converter_config = configobj.ConfigObj( - {}, encoding='utf-8', list_values=True, interpolation=False) + {}, encoding="utf-8", list_values=True, interpolation=False +) class ListOption(Option): - - def __init__(self, name, default=None, default_from_env=None, - help=None, invalid=None): + def __init__( + self, name, default=None, default_from_env=None, help=None, invalid=None + ): """A list Option definition. This overrides the base class so the conversion from a unicode string can take quoting into account. """ super().__init__( - name, default=default, default_from_env=default_from_env, - from_unicode=self.from_unicode, help=help, - invalid=invalid, unquote=False) + name, + default=default, + default_from_env=default_from_env, + from_unicode=self.from_unicode, + help=help, + invalid=invalid, + unquote=False, + ) def from_unicode(self, unicode_str): if not isinstance(unicode_str, str): @@ -2271,7 +2353,7 @@ def from_unicode(self, unicode_str): # properly quoted. _list_converter_config.reset() _list_converter_config._parse([f"list={unicode_str}"]) - maybe_list = _list_converter_config['list'] + maybe_list = _list_converter_config["list"] if isinstance(maybe_list, str): if maybe_list: # A single value, most probably the user forgot (or didn't care @@ -2289,18 +2371,21 @@ def from_unicode(self, unicode_str): class RegistryOption(Option): """Option for a choice from a registry.""" - def __init__(self, name, registry, default_from_env=None, - help=None, invalid=None): + def __init__(self, name, registry, default_from_env=None, help=None, invalid=None): """A registry based Option definition. This overrides the base class so the conversion from a unicode string can take quoting into account. """ super().__init__( - name, default=lambda: registry.default_key, + name, + default=lambda: registry.default_key, default_from_env=default_from_env, - from_unicode=self.from_unicode, help=help, - invalid=invalid, unquote=False) + from_unicode=self.from_unicode, + help=help, + invalid=invalid, + unquote=False, + ) self.registry = registry def from_unicode(self, unicode_str): @@ -2311,7 +2396,8 @@ def from_unicode(self, unicode_str): except KeyError as e: raise ValueError( f"Invalid value {unicode_str} for {self.name}." - "See help for a list of possible values.") from e + "See help for a list of possible values." + ) from e @property def help(self): @@ -2321,7 +2407,7 @@ def help(self): return "".join(ret) -_option_ref_re = lazy_regex.lazy_compile('({[^\\d\\W](?:\\.\\w|-\\w|\\w)*})') +_option_ref_re = lazy_regex.lazy_compile("({[^\\d\\W](?:\\.\\w|-\\w|\\w)*})") """Describes an expandable option reference. We want to match the most embedded reference first. @@ -2352,7 +2438,7 @@ def _check_option_name(self, option_name): Args: option_name: The name to validate. 
""" - if _option_ref_re.match('{%s}' % option_name) is None: + if _option_ref_re.match("{%s}" % option_name) is None: raise IllegalOptionName(option_name) def register(self, option): @@ -2362,8 +2448,7 @@ def register(self, option): option: The option to register. Its name is used as the key. """ self._check_option_name(option.name) - super().register(option.name, option, - help=option.help) + super().register(option.name, option, help=option.help) def register_lazy(self, key, module_name, member_name): """Register a new option to be loaded on request. @@ -2378,8 +2463,7 @@ def register_lazy(self, key, module_name, member_name): None, get() will return the module itself. """ self._check_option_name(key) - super().register_lazy(key, - module_name, member_name) + super().register_lazy(key, module_name, member_name) def get_help(self, key=None): """Get the help text associated with the given key.""" @@ -2396,63 +2480,91 @@ def get_help(self, key=None): # Registered options in lexicographical order option_registry.register( - Option('append_revisions_only', - default=None, from_unicode=bool_from_store, invalid='warning', - help='''\ + Option( + "append_revisions_only", + default=None, + from_unicode=bool_from_store, + invalid="warning", + help="""\ Whether to only append revisions to the mainline. If this is set to true, then it is not possible to change the existing mainline of the branch. -''')) +""", + ) +) option_registry.register( - ListOption('acceptable_keys', - default=None, - help="""\ + ListOption( + "acceptable_keys", + default=None, + help="""\ List of GPG key patterns which are acceptable for verification. -""")) +""", + ) +) option_registry.register( - Option('add.maximum_file_size', - default='20MB', from_unicode=int_SI_from_store, - help="""\ + Option( + "add.maximum_file_size", + default="20MB", + from_unicode=int_SI_from_store, + help="""\ Size above which files should be added manually. Files below this size are added automatically when using ``bzr add`` without arguments. A negative value means disable the size check. -""")) +""", + ) +) option_registry.register( - Option('bound', - default=None, from_unicode=bool_from_store, - help="""\ + Option( + "bound", + default=None, + from_unicode=bool_from_store, + help="""\ Is the branch bound to ``bound_location``. If set to "True", the branch should act as a checkout, and push each commit to the bound_location. This option is normally set by ``bind``/``unbind``. See also: bound_location. -""")) +""", + ) +) option_registry.register( - Option('bound_location', - default=None, - help="""\ + Option( + "bound_location", + default=None, + help="""\ The location that commits should go to when acting as a checkout. This option is normally set by ``bind``. See also: bound. -""")) +""", + ) +) option_registry.register( - Option('branch.fetch_tags', default=False, from_unicode=bool_from_store, - help="""\ + Option( + "branch.fetch_tags", + default=False, + from_unicode=bool_from_store, + help="""\ Whether revisions associated with tags should be fetched. 
-""")) +""", + ) +) option_registry.register_lazy( - 'transform.orphan_policy', 'breezy.transform', 'opt_transform_orphan') + "transform.orphan_policy", "breezy.transform", "opt_transform_orphan" +) option_registry.register( - Option('bzr.workingtree.worth_saving_limit', default=10, - from_unicode=int_from_store, invalid='warning', - help='''\ + Option( + "bzr.workingtree.worth_saving_limit", + default=10, + from_unicode=int_from_store, + invalid="warning", + help="""\ How many changes before saving the dirstate. -1 means that we will never rewrite the dirstate file for only @@ -2460,30 +2572,42 @@ def get_help(self, key=None): the dirstate file if a file is added/removed/renamed/etc. This flag only affects the behavior of updating the dirstate file after we notice that a file has been touched. -''')) +""", + ) +) option_registry.register( - Option('bugtracker', default=None, - help='''\ + Option( + "bugtracker", + default=None, + help="""\ Default bug tracker to use. This bug tracker will be used for example when marking bugs as fixed using ``bzr commit --fixes``, if no explicit bug tracker was specified. -''')) +""", + ) +) option_registry.register( - Option('calculate_revnos', default=True, - from_unicode=bool_from_store, - help='''\ + Option( + "calculate_revnos", + default=True, + from_unicode=bool_from_store, + help="""\ Calculate revision numbers if they are not known. Always show revision numbers, even for branch formats that don't store them natively (such as Git). Calculating the revision number requires traversing the left hand ancestry of the branch and can be slow on very large branches. -''')) +""", + ) +) option_registry.register( - Option('check_signatures', default=CHECK_IF_POSSIBLE, - from_unicode=signature_policy_from_unicode, - help='''\ + Option( + "check_signatures", + default=CHECK_IF_POSSIBLE, + from_unicode=signature_policy_from_unicode, + help="""\ GPG checking policy. Possible values: require, ignore, check-available (default) @@ -2491,60 +2615,86 @@ def get_help(self, key=None): this option will control whether bzr will require good gpg signatures, ignore them, or check them if they are present. -''')) +""", + ) +) option_registry.register( - Option('child_submit_format', - help='''The preferred format of submissions to this branch.''')) + Option( + "child_submit_format", + help="""The preferred format of submissions to this branch.""", + ) +) option_registry.register( - Option('child_submit_to', - help='''Where submissions to this branch are mailed to.''')) + Option( + "child_submit_to", help="""Where submissions to this branch are mailed to.""" + ) +) option_registry.register( - Option('create_signatures', default=SIGN_WHEN_REQUIRED, - from_unicode=signing_policy_from_unicode, - help='''\ + Option( + "create_signatures", + default=SIGN_WHEN_REQUIRED, + from_unicode=signing_policy_from_unicode, + help="""\ GPG Signing policy. Possible values: always, never, when-required (default), when-possible This option controls whether bzr will always create gpg signatures or not on commits. -''')) +""", + ) +) option_registry.register( - Option('dirstate.fdatasync', default=True, - from_unicode=bool_from_store, - help='''\ + Option( + "dirstate.fdatasync", + default=True, + from_unicode=bool_from_store, + help="""\ Flush dirstate changes onto physical disk? If true (default), working tree metadata changes are flushed through the OS buffers to physical disk. This is somewhat slower, but means data should not be lost if the machine crashes. See also repository.fdatasync. 
-''')) +""", + ) +) option_registry.register( - ListOption('debug_flags', default=[], - help='Debug flags to activate.')) + ListOption("debug_flags", default=[], help="Debug flags to activate.") +) option_registry.register( - Option('default_format', default='2a', - help='Format used when creating branches.')) + Option("default_format", default="2a", help="Format used when creating branches.") +) option_registry.register( - Option('editor', - help='The command called to launch an editor to enter a message.')) + Option("editor", help="The command called to launch an editor to enter a message.") +) option_registry.register( - Option('email', override_from_env=['BRZ_EMAIL', 'BZR_EMAIL'], - default=bedding.default_email, help='The users identity')) + Option( + "email", + override_from_env=["BRZ_EMAIL", "BZR_EMAIL"], + default=bedding.default_email, + help="The users identity", + ) +) option_registry.register( - Option('gpg_signing_key', - default=None, - help="""\ + Option( + "gpg_signing_key", + default=None, + help="""\ GPG key to use for signing. This defaults to the first key associated with the users email. -""")) +""", + ) +) option_registry.register( - Option('language', - help='Language to translate messages into.')) + Option("language", help="Language to translate messages into.") +) option_registry.register( - Option('locks.steal_dead', default=True, from_unicode=bool_from_store, - help='''\ + Option( + "locks.steal_dead", + default=True, + from_unicode=bool_from_store, + help="""\ Steal locks that appears to be dead. If set to True, bzr will check if a lock is supposed to be held by an @@ -2553,138 +2703,191 @@ def get_help(self, key=None): will automatically break the stale lock, and create a new lock for this process. Otherwise, bzr will prompt as normal to break the lock. -''')) +""", + ) +) option_registry.register( - Option('log_format', default='long', - help='''\ + Option( + "log_format", + default="long", + help="""\ Log format to use when displaying revisions. Standard log formats are ``long``, ``short`` and ``line``. Additional formats may be provided by plugins. -''')) -option_registry.register_lazy('mail_client', 'breezy.mail_client', - 'opt_mail_client') +""", + ) +) +option_registry.register_lazy("mail_client", "breezy.mail_client", "opt_mail_client") option_registry.register( - Option('output_encoding', - help='Unicode encoding for output' - ' (terminal encoding if not specified).')) + Option( + "output_encoding", + help="Unicode encoding for output" " (terminal encoding if not specified).", + ) +) option_registry.register( - Option('parent_location', - default=None, - help="""\ + Option( + "parent_location", + default=None, + help="""\ The location of the default branch for pull or merge. This option is normally set when creating a branch, the first ``pull`` or by ``pull --remember``. -""")) +""", + ) +) option_registry.register( - Option('post_commit', default=None, - help='''\ + Option( + "post_commit", + default=None, + help="""\ Post commit functions. An ordered list of python functions to call, separated by spaces. Each function takes branch, rev_id as parameters. -''')) -option_registry.register_lazy('progress_bar', 'breezy.ui.text', - 'opt_progress_bar') +""", + ) +) +option_registry.register_lazy("progress_bar", "breezy.ui.text", "opt_progress_bar") option_registry.register( - Option('public_branch', - default=None, - help="""\ + Option( + "public_branch", + default=None, + help="""\ A publically-accessible version of this branch. 
This implies that the branch setting this option is not publically-accessible. Used and set by ``bzr send``. -""")) +""", + ) +) option_registry.register( - Option('push_location', - default=None, - help="""\ + Option( + "push_location", + default=None, + help="""\ The location of the default branch for push. This option is normally set by the first ``push`` or ``push --remember``. -""")) +""", + ) +) option_registry.register( - Option('push_strict', default=None, - from_unicode=bool_from_store, - help='''\ + Option( + "push_strict", + default=None, + from_unicode=bool_from_store, + help="""\ The default value for ``push --strict``. If present, defines the ``--strict`` option default value for checking uncommitted changes before sending a merge directive. -''')) +""", + ) +) option_registry.register( - Option('repository.fdatasync', default=True, - from_unicode=bool_from_store, - help='''\ + Option( + "repository.fdatasync", + default=True, + from_unicode=bool_from_store, + help="""\ Flush repository changes onto physical disk? If true (default), repository changes are flushed through the OS buffers to physical disk. This is somewhat slower, but means data should not be lost if the machine crashes. See also dirstate.fdatasync. -''')) -option_registry.register_lazy('smtp_server', - 'breezy.smtp_connection', 'smtp_server') -option_registry.register_lazy('smtp_password', - 'breezy.smtp_connection', 'smtp_password') -option_registry.register_lazy('smtp_username', - 'breezy.smtp_connection', 'smtp_username') +""", + ) +) +option_registry.register_lazy("smtp_server", "breezy.smtp_connection", "smtp_server") +option_registry.register_lazy( + "smtp_password", "breezy.smtp_connection", "smtp_password" +) +option_registry.register_lazy( + "smtp_username", "breezy.smtp_connection", "smtp_username" +) option_registry.register( - Option('selftest.timeout', - default='1200', - from_unicode=int_from_store, - help='Abort selftest if one test takes longer than this many seconds', - )) + Option( + "selftest.timeout", + default="1200", + from_unicode=int_from_store, + help="Abort selftest if one test takes longer than this many seconds", + ) +) option_registry.register( - Option('send_strict', default=None, - from_unicode=bool_from_store, - help='''\ + Option( + "send_strict", + default=None, + from_unicode=bool_from_store, + help="""\ The default value for ``send --strict``. If present, defines the ``--strict`` option default value for checking uncommitted changes before sending a bundle. -''')) +""", + ) +) option_registry.register( - Option('serve.client_timeout', - default=300.0, from_unicode=float_from_store, - help="If we wait for a new request from a client for more than" - " X seconds, consider the client idle, and hangup.")) + Option( + "serve.client_timeout", + default=300.0, + from_unicode=float_from_store, + help="If we wait for a new request from a client for more than" + " X seconds, consider the client idle, and hangup.", + ) +) option_registry.register( - Option('ssh', - default=None, override_from_env=['BRZ_SSH'], - help='SSH vendor to use.')) + Option( + "ssh", default=None, override_from_env=["BRZ_SSH"], help="SSH vendor to use." 
+ ) +) option_registry.register( - Option('stacked_on_location', - default=None, - help="""The location where this branch is stacked on.""")) + Option( + "stacked_on_location", + default=None, + help="""The location where this branch is stacked on.""", + ) +) option_registry.register( - Option('submit_branch', - default=None, - help="""\ + Option( + "submit_branch", + default=None, + help="""\ The branch you intend to submit your current work to. This is automatically set by ``bzr send`` and ``bzr merge``, and is also used by the ``submit:`` revision spec. -""")) +""", + ) +) option_registry.register( - Option('submit_to', - help='''Where submissions from this branch are mailed to.''')) + Option("submit_to", help="""Where submissions from this branch are mailed to.""") +) option_registry.register( - ListOption('suppress_warnings', - default=[], - help="List of warning classes to suppress.")) + ListOption( + "suppress_warnings", default=[], help="List of warning classes to suppress." + ) +) option_registry.register( - Option('validate_signatures_in_log', default=False, - from_unicode=bool_from_store, invalid='warning', - help='''Whether to validate signatures in brz log.''')) -option_registry.register_lazy('ssl.ca_certs', - 'breezy.transport.http', 'opt_ssl_ca_certs') + Option( + "validate_signatures_in_log", + default=False, + from_unicode=bool_from_store, + invalid="warning", + help="""Whether to validate signatures in brz log.""", + ) +) +option_registry.register_lazy( + "ssl.ca_certs", "breezy.transport.http", "opt_ssl_ca_certs" +) -option_registry.register_lazy('ssl.cert_reqs', - 'breezy.transport.http', 'opt_ssl_cert_reqs') +option_registry.register_lazy( + "ssl.cert_reqs", "breezy.transport.http", "opt_ssl_cert_reqs" +) class Section: @@ -2764,18 +2967,21 @@ def apply_changes(self, dirty, store): # to be used as a sharing mechanism). if expected != reloaded: if actual is _DeletedOption: - actual = '' + actual = "" if reloaded is _NewlyCreatedOption: - reloaded = '' + reloaded = "" if expected is _NewlyCreatedOption: - expected = '' + expected = "" # Someone changed the value since we get it from the persistent # storage. - trace.warning(gettext( - "Option {} in section {} of {} was changed" - " from {} to {}. The {} value will be saved.").format( - k, self.id, store.external_url(), expected, - reloaded, actual)) + trace.warning( + gettext( + "Option {} in section {} of {} was changed" + " from {} to {}. The {} value will be saved." + ).format( + k, self.id, store.external_url(), expected, reloaded, actual + ) + ) # No need to keep track of these changes self.reset_changes() @@ -2904,7 +3110,7 @@ def __init__(self, opts=None): if opts is None: opts = {} self.options = {} - self.id = 'cmdline' + self.id = "cmdline" def _reset(self): # The dict should be cleared but not replaced so it can be shared. 
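As a brief aside to the reformatted option registrations above: a minimal sketch of how these breezy.config APIs are typically used, relying only on names defined in this module (Option, option_registry, bool_from_store, GlobalStack). The "myplugin.enabled" option name is hypothetical and is not something this patch introduces.

    from breezy.config import GlobalStack, Option, bool_from_store, option_registry

    # Register a boolean option, in the same style as the registrations above.
    option_registry.register(
        Option(
            "myplugin.enabled",  # hypothetical option name, for illustration only
            default=False,
            from_unicode=bool_from_store,
            help="Whether the hypothetical plugin is enabled.",
        )
    )

    # Read it back through the global configuration stack (this consults the
    # user's real breezy.conf via GlobalStore, so treat it as illustrative).
    stack = GlobalStack()
    value = stack.get("myplugin.enabled")
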
@@ -2915,17 +3121,18 @@ def _from_cmdline(self, overrides): self._reset() for over in overrides: try: - name, value = over.split('=', 1) + name, value = over.split("=", 1) except ValueError as e: raise errors.CommandError( gettext("Invalid '%s', should be of the form 'name=value'") - % (over,)) from e + % (over,) + ) from e self.options[name] = value def external_url(self): # Not an url but it makes debugging easier and is never needed # otherwise - return 'cmdline' + return "cmdline" def get_sections(self): yield self, self.readonly_section_class(None, self.options) @@ -2976,7 +3183,7 @@ def load(self): return content = self._load_content() self._load_from_string(content) - for hook in ConfigHooks['load']: + for hook in ConfigHooks["load"]: hook(self) def _load_from_string(self, bytes): @@ -2986,12 +3193,11 @@ def _load_from_string(self, bytes): bytes: A string representing the file content. """ if self.is_loaded(): - raise AssertionError(f'Already loaded: {self._config_obj!r}') + raise AssertionError(f"Already loaded: {self._config_obj!r}") co_input = BytesIO(bytes) try: # The config files are always stored utf8-encoded - self._config_obj = ConfigObj(co_input, encoding='utf-8', - list_values=False) + self._config_obj = ConfigObj(co_input, encoding="utf-8", list_values=False) except configobj.ConfigObjError as e: self._config_obj = None raise ParseConfigError(e.errors, self.external_url()) from e @@ -3017,7 +3223,7 @@ def save(self): out = BytesIO() self._config_obj.write(out) self._save_content(out.getvalue()) - for hook in ConfigHooks['save']: + for hook in ConfigHooks["save"]: hook(self) def get_sections(self) -> Iterable[Tuple[Store, Section]]: @@ -3035,9 +3241,7 @@ def get_sections(self) -> Iterable[Tuple[Store, Section]]: if cobj.scalars: yield self, self.readonly_section_class(None, cobj) for section_name in cobj.sections: - yield (self, - self.readonly_section_class(section_name, - cobj[section_name])) + yield (self, self.readonly_section_class(section_name, cobj[section_name])) def get_mutable_section(self, section_id=None): # We need a loaded store @@ -3045,7 +3249,7 @@ def get_mutable_section(self, section_id=None): self.load() except transport.NoSuchFile: # The file doesn't exist, let's pretend it was empty - self._load_from_string(b'') + self._load_from_string(b"") if section_id in self.dirty_sections: # We already created a mutable section for this id return self.dirty_sections[section_id] @@ -3078,7 +3282,7 @@ def external_url(self): # it's better to provide something than raising a NotImplementedError. # All daughter classes are supposed to provide an implementation # anyway. - return 'In-Process Store, no URL' + return "In-Process Store, no URL" class TransportIniFileStore(IniFileStore): @@ -3104,8 +3308,10 @@ def _load_content(self): try: return self.transport.get_bytes(self.file_name) except errors.PermissionDenied: - trace.warning("Permission denied while trying to load " - "configuration store %s.", self.external_url()) + trace.warning( + "Permission denied while trying to load " "configuration store %s.", + self.external_url(), + ) raise def _save_content(self, content): @@ -3118,7 +3324,8 @@ def external_url(self): # expose a path here but rather a config ID and its associated # object . 
return urlutils.join( - self.transport.external_url(), urlutils.escape(self.file_name)) + self.transport.external_url(), urlutils.escape(self.file_name) + ) # Note that LockableConfigObjStore inherits from ConfigObjStore because we need @@ -3138,7 +3345,7 @@ def __init__(self, transport, file_name, lock_dir_name=None): file_name: The config file basename in the transport directory. """ if lock_dir_name is None: - lock_dir_name = 'lock' + lock_dir_name = "lock" self.lock_dir_name = lock_dir_name super().__init__(transport, file_name) self._lock = lockdir.LockDir(self.transport, self.lock_dir_name) @@ -3174,6 +3381,7 @@ def save_without_locking(self): # 'user_defaults' as opposed to 'user_overrides', 'system_defaults' # (/etc/bzr/bazaar.conf) and 'system_overrides' ? -- vila 2011-04-05 + # FIXME: Moreover, we shouldn't need classes for these stores either, factory # functions or a registry will make it easier and clearer for tests, focusing # on the relevant parts of the API that needs testing -- vila 20110503 (based @@ -3187,9 +3395,10 @@ class GlobalStore(LockableIniFileStore): def __init__(self, possible_transports=None): path, kind = bedding._config_dir() t = transport.get_transport_from_path( - path, possible_transports=possible_transports) - super().__init__(t, kind + '.conf') - self.id = 'breezy' + path, possible_transports=possible_transports + ) + super().__init__(t, kind + ".conf") + self.id = "breezy" class LocationStore(LockableIniFileStore): @@ -3200,9 +3409,10 @@ class LocationStore(LockableIniFileStore): def __init__(self, possible_transports=None): t = transport.get_transport_from_path( - bedding.config_dir(), possible_transports=possible_transports) - super().__init__(t, 'locations.conf') - self.id = 'locations' + bedding.config_dir(), possible_transports=possible_transports + ) + super().__init__(t, "locations.conf") + self.id = "locations" class BranchStore(TransportIniFileStore): @@ -3212,19 +3422,15 @@ class BranchStore(TransportIniFileStore): """ def __init__(self, branch): - super().__init__(branch.control_transport, - 'branch.conf') + super().__init__(branch.control_transport, "branch.conf") self.branch = branch - self.id = 'branch' + self.id = "branch" class ControlStore(LockableIniFileStore): - def __init__(self, bzrdir): - super().__init__(bzrdir.transport, - 'control.conf', - lock_dir_name='branch_lock') - self.id = 'control' + super().__init__(bzrdir.transport, "control.conf", lock_dir_name="branch_lock") + self.id = "control" class SectionMatcher: @@ -3259,7 +3465,6 @@ def match(self, section): class NameMatcher(SectionMatcher): - def __init__(self, store, section_id): super().__init__(store) self.section_id = section_id @@ -3269,20 +3474,21 @@ def match(self, section): class LocationSection(Section): - def __init__(self, section, extra_path, branch_name=None): super().__init__(section.id, section.options) self.extra_path = extra_path if branch_name is None: - branch_name = '' - self.locals = {'relpath': extra_path, - 'basename': urlutils.basename(extra_path), - 'branchname': branch_name} + branch_name = "" + self.locals = { + "relpath": extra_path, + "basename": urlutils.basename(extra_path), + "branchname": branch_name, + } def get(self, name, default=None, expand=True): value = super().get(name, default) if value is not None and expand: - policy_name = self.get(name + ':policy', None) + policy_name = self.get(name + ":policy", None) policy = _policy_value.get(policy_name, POLICY_NONE) if policy == POLICY_APPENDPATH: value = urlutils.join(value, 
self.extra_path) @@ -3298,7 +3504,7 @@ def get(self, name, default=None, expand=True): chunks.append(self.locals[ref]) else: chunks.append(chunk) - value = ''.join(chunks) + value = "".join(chunks) return value @@ -3315,7 +3521,7 @@ class StartingPathMatcher(SectionMatcher): def __init__(self, store, location): super().__init__(store) - if location.startswith('file://'): + if location.startswith("file://"): location = urlutils.local_path_from_url(location) self.location = location @@ -3329,7 +3535,8 @@ def get_sections(self): the most specific ones can be found first. """ import fnmatch - location_parts = self.location.rstrip('/').split('/') + + location_parts = self.location.rstrip("/").split("/") store = self.store # Later sections are more specific, they should be returned first for _, section in reversed(list(store.get_sections())): @@ -3338,26 +3545,26 @@ def get_sections(self): yield store, LocationSection(section, self.location) continue section_path = section.id - if section_path.startswith('file://'): + if section_path.startswith("file://"): # the location is already a local path or URL, convert the # section id to the same format section_path = urlutils.local_path_from_url(section_path) - if (self.location.startswith(section_path) or - fnmatch.fnmatch(self.location, section_path)): - section_parts = section_path.rstrip('/').split('/') - extra_path = '/'.join(location_parts[len(section_parts):]) + if self.location.startswith(section_path) or fnmatch.fnmatch( + self.location, section_path + ): + section_parts = section_path.rstrip("/").split("/") + extra_path = "/".join(location_parts[len(section_parts) :]) yield store, LocationSection(section, extra_path) class LocationMatcher(SectionMatcher): - def __init__(self, store, location): super().__init__(store) url, params = urlutils.split_segment_parameters(location) - if location.startswith('file://'): + if location.startswith("file://"): location = urlutils.local_path_from_url(location) self.location = location - branch_name = params.get('branch') + branch_name = params.get("branch") if branch_name is None: self.branch_name = urlutils.basename(self.location) else: @@ -3379,12 +3586,14 @@ def _get_matching_sections(self): # Unfortunately _iter_for_location_by_parts deals with section names so # we have to resync. filtered_sections = _iter_for_location_by_parts( - [s.id for s in all_sections], self.location) + [s.id for s in all_sections], self.location + ) iter_all_sections = iter(all_sections) matching_sections = [] if no_name_section is not None: matching_sections.append( - (0, LocationSection(no_name_section, self.location))) + (0, LocationSection(no_name_section, self.location)) + ) for section_id, extra_path, length in filtered_sections: # a section id is unique for a given store so it's safe to take the # first matching section while iterating. 
Also, all filtered @@ -3393,8 +3602,7 @@ def _get_matching_sections(self): while True: section = next(iter_all_sections) if section_id == section.id: - section = LocationSection(section, extra_path, - self.branch_name) + section = LocationSection(section, extra_path, self.branch_name) matching_sections.append((length, section)) break return matching_sections @@ -3402,13 +3610,15 @@ def _get_matching_sections(self): def get_sections(self): # Override the default implementation as we want to change the order # We want the longest (aka more specific) locations first - sections = sorted(self._get_matching_sections(), - key=lambda match: (match[0], match[1].id), - reverse=True) + sections = sorted( + self._get_matching_sections(), + key=lambda match: (match[0], match[1].id), + reverse=True, + ) # Sections mentioning 'ignore_parents' restrict the selection for _, section in sections: # FIXME: We really want to use as_bool below -- vila 2011-04-07 - ignore = section.get('ignore_parents', None) + ignore = section.get("ignore_parents", None) if ignore is not None: ignore = ui.bool_from_string(ignore) if ignore: @@ -3491,8 +3701,10 @@ def expand_and_convert(val): if isinstance(val, str): val = self._expand_options_in_string(val) else: - trace.warning(f'Cannot expand "{name}":' - f' {type(val)} does not support option expansion') + trace.warning( + f'Cannot expand "{name}":' + f" {type(val)} does not support option expansion" + ) if opt is None: val = found_store.unquote(val) elif convert: @@ -3515,7 +3727,7 @@ def expand_and_convert(val): # If the option is registered, it may provide a default value value = opt.get_default() value = expand_and_convert(value) - for hook in ConfigHooks['get']: + for hook in ConfigHooks["get"]: hook(self, name, value) return value @@ -3571,7 +3783,7 @@ def _expand_options_in_string(self, string, env=None, _refs=None): raise ExpandingUnknownOption(name, string) chunks.append(value) _refs.pop() - result = ''.join(chunks) + result = "".join(chunks) return result def _expand_option(self, name, env, _refs): @@ -3600,14 +3812,14 @@ def set(self, name, value): """Set a new value for the option.""" store, section = self._get_mutable_section() section.set(name, store.quote(value)) - for hook in ConfigHooks['set']: + for hook in ConfigHooks["set"]: hook(self, name, value) def remove(self, name): """Remove an existing option.""" _, section = self._get_mutable_section() section.remove(name) - for hook in ConfigHooks['remove']: + for hook in ConfigHooks["remove"]: hook(self, name) def __repr__(self): @@ -3643,10 +3855,12 @@ def get_shared_store(self, store, state=None): def save_config_changes(): for _k, store in stores.items(): store.save_changes() + if not _shared_stores_at_exit_installed: # FIXME: Ugly hack waiting for library_state to always be # available. 
-- vila 20120731 import atexit + atexit.register(save_config_changes) _shared_stores_at_exit_installed = True else: @@ -3679,8 +3893,7 @@ def __init__(self, content=None): store = IniFileStore() if content is not None: store._load_from_string(content) - super().__init__( - [store.get_sections], store) + super().__init__([store.get_sections], store) class _CompatibleStack(Stack): @@ -3730,9 +3943,10 @@ class GlobalStack(Stack): def __init__(self): gstore = self.get_shared_store(GlobalStore()) super().__init__( - [self._get_overrides, - NameMatcher(gstore, 'DEFAULT').get_sections], - gstore, mutable_section_id='DEFAULT') + [self._get_overrides, NameMatcher(gstore, "DEFAULT").get_sections], + gstore, + mutable_section_id="DEFAULT", + ) class LocationStack(Stack): @@ -3759,14 +3973,18 @@ def __init__(self, location): location: A URL prefix to """ lstore = self.get_shared_store(LocationStore()) - if location.startswith('file://'): + if location.startswith("file://"): location = urlutils.local_path_from_url(location) gstore = self.get_shared_store(GlobalStore()) super().__init__( - [self._get_overrides, - LocationMatcher(lstore, location).get_sections, - NameMatcher(gstore, 'DEFAULT').get_sections], - lstore, mutable_section_id=location) + [ + self._get_overrides, + LocationMatcher(lstore, location).get_sections, + NameMatcher(gstore, "DEFAULT").get_sections, + ], + lstore, + mutable_section_id=location, + ) class BranchStack(Stack): @@ -3793,11 +4011,14 @@ def __init__(self, branch): bstore = branch._get_config_store() gstore = self.get_shared_store(GlobalStore()) super().__init__( - [self._get_overrides, - LocationMatcher(lstore, branch.base).get_sections, - NameMatcher(bstore, None).get_sections, - NameMatcher(gstore, 'DEFAULT').get_sections], - bstore) + [ + self._get_overrides, + LocationMatcher(lstore, branch.base).get_sections, + NameMatcher(bstore, None).get_sections, + NameMatcher(gstore, "DEFAULT").get_sections, + ], + bstore, + ) self.branch = branch def lock_write(self, token=None): @@ -3828,9 +4049,7 @@ class RemoteControlStack(Stack): def __init__(self, bzrdir): cstore = bzrdir._get_config_store() - super().__init__( - [NameMatcher(cstore, None).get_sections], - cstore) + super().__init__([NameMatcher(cstore, None).get_sections], cstore) self.controldir = bzrdir @@ -3843,9 +4062,7 @@ class BranchOnlyStack(Stack): def __init__(self, branch): bstore = branch._get_config_store() - super().__init__( - [NameMatcher(bstore, None).get_sections], - bstore) + super().__init__([NameMatcher(bstore, None).get_sections], bstore) self.branch = branch def lock_write(self, token=None): @@ -3887,45 +4104,48 @@ class cmd_config(commands.Command): Removing a value is achieved by using --remove NAME. 
""" - takes_args = ['name?'] + takes_args = ["name?"] takes_options = [ - 'directory', + "directory", # FIXME: This should be a registry option so that plugins can register # their own config files (or not) and will also address # http://pad.lv/788991 -- vila 20101115 - CommandOption('scope', help='Reduce the scope to the specified' - ' configuration file.', - type=str), - CommandOption('all', - help='Display all the defined values for the matching options.', - ), - CommandOption('remove', help='Remove the option from' - ' the configuration file.'), - ] - - _see_also = ['configuration'] + CommandOption( + "scope", + help="Reduce the scope to the specified" " configuration file.", + type=str, + ), + CommandOption( + "all", + help="Display all the defined values for the matching options.", + ), + CommandOption( + "remove", help="Remove the option from" " the configuration file." + ), + ] + + _see_also = ["configuration"] @commands.display_command - def run(self, name=None, all=False, directory=None, scope=None, - remove=False): + def run(self, name=None, all=False, directory=None, scope=None, remove=False): from .directory_service import directories + if directory is None: - directory = '.' + directory = "." directory = directories.dereference(directory) directory = urlutils.normalize_url(directory) if remove and all: - raise errors.BzrError( - '--all and --remove are mutually exclusive.') + raise errors.BzrError("--all and --remove are mutually exclusive.") elif remove: # Delete the option in the given scope self._remove_config_option(name, directory, scope) elif name is None: # Defaults to all options - self._show_matching_options('.*', directory, scope) + self._show_matching_options(".*", directory, scope) else: try: - name, value = name.split('=', 1) + name, value = name.split("=", 1) except ValueError: # Display the option(s) value(s) if all: @@ -3934,8 +4154,7 @@ def run(self, name=None, all=False, directory=None, scope=None, self._show_value(name, directory, scope) else: if all: - raise errors.BzrError( - 'Only one option can be set.') + raise errors.BzrError("Only one option can be set.") # Set the option value self._set_config_option(name, value, directory, scope) @@ -3952,23 +4171,23 @@ def _get_stack(self, directory, scope=None, write_access=False): # reduced to the plugin-specific store), related to # http://pad.lv/788991 -- vila 2011-11-15 if scope is not None: - if scope == 'breezy': + if scope == "breezy": return GlobalStack() - elif scope == 'locations': + elif scope == "locations": return LocationStack(directory) - elif scope == 'branch': - (_, br, _) = ( - controldir.ControlDir.open_containing_tree_or_branch( - directory)) + elif scope == "branch": + (_, br, _) = controldir.ControlDir.open_containing_tree_or_branch( + directory + ) if write_access: self.add_cleanup(br.lock_write().unlock) return br.get_config_stack() raise NoSuchConfig(scope) else: try: - (_, br, _) = ( - controldir.ControlDir.open_containing_tree_or_branch( - directory)) + (_, br, _) = controldir.ControlDir.open_containing_tree_or_branch( + directory + ) if write_access: self.add_cleanup(br.lock_write().unlock) return br.get_config_stack() @@ -3976,7 +4195,7 @@ def _get_stack(self, directory, scope=None, write_access=False): return LocationStack(directory) def _quote_multiline(self, value): - if '\n' in value: + if "\n" in value: value = '"""' + value + '"""' return value @@ -3986,7 +4205,7 @@ def _show_value(self, name, directory, scope): if value is not None: # Quote the value appropriately value = 
self._quote_multiline(value) - self.outf.write(f'{value}\n') + self.outf.write(f"{value}\n") else: raise NoSuchConfigOption(name) @@ -4004,18 +4223,18 @@ def _show_matching_options(self, name, directory, scope): if name.search(oname): if cur_store_id != store.id: # Explain where the options are defined - self.outf.write(f'{store.id}:\n') + self.outf.write(f"{store.id}:\n") cur_store_id = store.id cur_section = None - if (section.id is not None and cur_section != section.id): + if section.id is not None and cur_section != section.id: # Display the section id as it appears in the store # (None doesn't appear by definition) - self.outf.write(f' [{section.id}]\n') + self.outf.write(f" [{section.id}]\n") cur_section = section.id value = section.get(oname, expand=False) # Quote the value appropriately value = self._quote_multiline(value) - self.outf.write(f' {oname} = {value}\n') + self.outf.write(f" {oname} = {value}\n") def _set_config_option(self, name, value, directory, scope): conf = self._get_stack(directory, scope, write_access=True) @@ -4025,8 +4244,7 @@ def _set_config_option(self, name, value, directory, scope): def _remove_config_option(self, name, directory, scope): if name is None: - raise errors.CommandError( - '--remove expects an option to remove.') + raise errors.CommandError("--remove expects an option to remove.") conf = self._get_stack(directory, scope, write_access=True) try: conf.remove(name) diff --git a/breezy/conflicts.py b/breezy/conflicts.py index 140a591357..a8623d1cb8 100644 --- a/breezy/conflicts.py +++ b/breezy/conflicts.py @@ -21,13 +21,16 @@ from .lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ from breezy import ( workingtree, ) from breezy.i18n import gettext, ngettext -""") +""", +) from . import commands, errors, option, osutils, registry, trace @@ -46,46 +49,50 @@ class cmd_conflicts(commands.Command): Use brz resolve when you have fixed a problem. """ takes_options = [ - 'directory', - option.Option('text', - help='List paths of files with text conflicts.'), - ] - _see_also = ['resolve', 'conflict-types'] + "directory", + option.Option("text", help="List paths of files with text conflicts."), + ] + _see_also = ["resolve", "conflict-types"] - def run(self, text=False, directory='.'): + def run(self, text=False, directory="."): wt = workingtree.WorkingTree.open_containing(directory)[0] for conflict in wt.conflicts(): if text: - if conflict.typestring != 'text conflict': + if conflict.typestring != "text conflict": continue - self.outf.write(conflict.path + '\n') + self.outf.write(conflict.path + "\n") else: - self.outf.write(str(conflict) + '\n') + self.outf.write(str(conflict) + "\n") resolve_action_registry = registry.Registry[str, str, None]() resolve_action_registry.register( - 'auto', 'auto', 'Detect whether conflict has been resolved by user.') -resolve_action_registry.register( - 'done', 'done', 'Marks the conflict as resolved.') + "auto", "auto", "Detect whether conflict has been resolved by user." 
+) +resolve_action_registry.register("done", "done", "Marks the conflict as resolved.") resolve_action_registry.register( - 'take-this', 'take_this', - 'Resolve the conflict preserving the version in the working tree.') + "take-this", + "take_this", + "Resolve the conflict preserving the version in the working tree.", +) resolve_action_registry.register( - 'take-other', 'take_other', - 'Resolve the conflict taking the merged version into account.') -resolve_action_registry.default_key = 'done' + "take-other", + "take_other", + "Resolve the conflict taking the merged version into account.", +) +resolve_action_registry.default_key = "done" class ResolveActionOption(option.RegistryOption): - def __init__(self): super().__init__( - 'action', 'How to resolve the conflict.', + "action", + "How to resolve the conflict.", value_switches=True, - registry=resolve_action_registry) + registry=resolve_action_registry, + ) class cmd_resolve(commands.Command): @@ -100,56 +107,65 @@ class cmd_resolve(commands.Command): text conflicts as fixed, "brz resolve FILE" to mark a specific conflict as resolved, or "brz resolve --all" to mark all conflicts as resolved. """ - aliases = ['resolved'] - takes_args = ['file*'] + aliases = ["resolved"] + takes_args = ["file*"] takes_options = [ - 'directory', - option.Option('all', help='Resolve all conflicts in this tree.'), + "directory", + option.Option("all", help="Resolve all conflicts in this tree."), ResolveActionOption(), - ] - _see_also = ['conflicts'] + ] + _see_also = ["conflicts"] def run(self, file_list=None, all=False, action=None, directory=None): if all: if file_list: - raise errors.CommandError(gettext("If --all is specified," - " no FILE may be provided")) + raise errors.CommandError( + gettext("If --all is specified," " no FILE may be provided") + ) if directory is None: - directory = '.' + directory = "." tree = workingtree.WorkingTree.open_containing(directory)[0] if action is None: - action = 'done' + action = "done" else: tree, file_list = workingtree.WorkingTree.open_containing_paths( - file_list, directory) + file_list, directory + ) if action is None: if file_list is None: - action = 'auto' + action = "auto" else: - action = 'done' + action = "done" before, after = resolve(tree, file_list, action=action) # GZ 2012-07-27: Should unify UI below now that auto is less magical. - if action == 'auto' and file_list is None: + if action == "auto" and file_list is None: if after > 0: trace.note( - ngettext('%d conflict auto-resolved.', - '%d conflicts auto-resolved.', before - after), - before - after) - trace.note(gettext('Remaining conflicts:')) + ngettext( + "%d conflict auto-resolved.", + "%d conflicts auto-resolved.", + before - after, + ), + before - after, + ) + trace.note(gettext("Remaining conflicts:")) for conflict in tree.conflicts(): trace.note(str(conflict)) return 1 else: - trace.note(gettext('All conflicts resolved.')) + trace.note(gettext("All conflicts resolved.")) return 0 else: - trace.note(ngettext('{0} conflict resolved, {1} remaining', - '{0} conflicts resolved, {1} remaining', - before - after).format(before - after, after)) + trace.note( + ngettext( + "{0} conflict resolved, {1} remaining", + "{0} conflicts resolved, {1} remaining", + before - after, + ).format(before - after, after) + ) -def resolve(tree, paths=None, ignore_misses=False, recursive=False, - action='done'): +def resolve(tree, paths=None, ignore_misses=False, recursive=False, action="done"): """Resolve some or all of the conflicts in a working tree. 
:param paths: If None, resolve all conflicts. Otherwise, select only @@ -172,7 +188,8 @@ def resolve(tree, paths=None, ignore_misses=False, recursive=False, to_process = tree_conflicts else: new_conflicts, to_process = tree_conflicts.select_conflicts( - tree, paths, ignore_misses, recursive) + tree, paths, ignore_misses, recursive + ) for conflict in to_process: try: conflict.do(action, tree) @@ -263,8 +280,7 @@ def remove_files(self, tree): continue conflict.cleanup(tree) - def select_conflicts(self, tree, paths, ignore_misses=False, - recurse=False): + def select_conflicts(self, tree, paths, ignore_misses=False, recurse=False): """Select the conflicts associated with paths in a tree. :return: a pair of ConflictLists: (not_selected, selected) diff --git a/breezy/controldir.py b/breezy/controldir.py index e80f34ae3c..ec5797aef6 100644 --- a/breezy/controldir.py +++ b/breezy/controldir.py @@ -28,7 +28,9 @@ from .lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ import textwrap from breezy import ( @@ -38,7 +40,8 @@ ) from breezy.i18n import gettext -""") +""", +) from . import errors, hooks, registry, trace from . import revision as _mod_revision @@ -51,7 +54,6 @@ class MustHaveWorkingTree(errors.BzrError): - _fmt = "Branching '%(url)s'(%(format)s) must create a working tree." def __init__(self, format, url): @@ -59,7 +61,6 @@ def __init__(self, format, url): class BranchReferenceLoop(errors.BzrError): - _fmt = "Can not create branch reference that points at branch itself." def __init__(self, branch): @@ -67,8 +68,7 @@ def __init__(self, branch): class NoColocatedBranchSupport(errors.BzrError): - - _fmt = ("%(controldir)r does not support co-located branches.") + _fmt = "%(controldir)r does not support co-located branches." def __init__(self, controldir): self.controldir = controldir @@ -206,9 +206,12 @@ def destroy_repository(self) -> None: """Destroy the repository in this ControlDir.""" raise NotImplementedError(self.destroy_repository) - def create_branch(self, name: Optional[str] = None, - repository: Optional["Repository"] = None, - append_revisions_only: Optional[bool] = None) -> "Branch": + def create_branch( + self, + name: Optional[str] = None, + repository: Optional["Repository"] = None, + append_revisions_only: Optional[bool] = None, + ) -> "Branch": """Create a branch in this ControlDir. Args: @@ -234,8 +237,9 @@ def destroy_branch(self, name: Optional[str] = None) -> None: """ raise NotImplementedError(self.destroy_branch) - def create_workingtree(self, revision_id=None, from_branch=None, - accelerator_tree=None, hardlink=False) -> "WorkingTree": + def create_workingtree( + self, revision_id=None, from_branch=None, accelerator_tree=None, hardlink=False + ) -> "WorkingTree": """Create a working tree at this ControlDir. Args: @@ -306,8 +310,13 @@ def set_branch_reference(self, target_branch, name=None): """ raise NotImplementedError(self.set_branch_reference) - def open_branch(self, name=None, unsupported=False, - ignore_fallbacks=False, possible_transports=None) -> "Branch": + def open_branch( + self, + name=None, + unsupported=False, + ignore_fallbacks=False, + possible_transports=None, + ) -> "Branch": """Open the branch object at this ControlDir if one is present. 
Args: @@ -340,8 +349,9 @@ def find_repository(self) -> "Repository": """ raise NotImplementedError(self.find_repository) - def open_workingtree(self, unsupported=False, - recommend_upgrade=True, from_branch=None) -> "WorkingTree": + def open_workingtree( + self, unsupported=False, recommend_upgrade=True, from_branch=None + ) -> "WorkingTree": """Open the workingtree object at this ControlDir if one is present. Args: @@ -417,11 +427,20 @@ def checkout_metadir(self): """ return self.cloning_metadir() - def sprout(self, url, revision_id=None, force_new_repo=False, - recurse='down', possible_transports=None, - accelerator_tree=None, hardlink=False, stacked=False, - source_branch=None, create_tree_if_local=True, - lossy=False): + def sprout( + self, + url, + revision_id=None, + force_new_repo=False, + recurse="down", + possible_transports=None, + accelerator_tree=None, + hardlink=False, + stacked=False, + source_branch=None, + create_tree_if_local=True, + lossy=False, + ): """Create a copy of this controldir prepared for use as a new line of development. @@ -448,9 +467,17 @@ def sprout(self, url, revision_id=None, force_new_repo=False, """ raise NotImplementedError(self.sprout) - def push_branch(self, source, revision_id=None, overwrite=False, - remember=False, create_prefix=False, lossy=False, - tag_selector=None, name=None): + def push_branch( + self, + source, + revision_id=None, + overwrite=False, + remember=False, + create_prefix=False, + lossy=False, + tag_selector=None, + name=None, + ): """Push the source branch into this ControlDir.""" from .push import PushResult @@ -477,8 +504,12 @@ def push_branch(self, source, revision_id=None, overwrite=False, revision_id = source.last_revision() repository_to.fetch(source.repository, revision_id=revision_id) br_to = source.sprout( - self, revision_id=revision_id, lossy=lossy, - tag_selector=tag_selector, name=name) + self, + revision_id=revision_id, + lossy=lossy, + tag_selector=tag_selector, + name=name, + ) if source.get_push_location() is None or remember: # FIXME: Should be done only if we succeed ? 
-- vila 2012-01-18 source.set_push_location(br_to.base) @@ -498,32 +529,46 @@ def push_branch(self, source, revision_id=None, overwrite=False, tree_to = self.open_workingtree() except errors.NotLocalUrl: push_result.branch_push_result = source.push( - br_to, overwrite=overwrite, stop_revision=revision_id, lossy=lossy, - tag_selector=tag_selector) + br_to, + overwrite=overwrite, + stop_revision=revision_id, + lossy=lossy, + tag_selector=tag_selector, + ) push_result.workingtree_updated = False except errors.NoWorkingTree: push_result.branch_push_result = source.push( - br_to, overwrite=overwrite, stop_revision=revision_id, - lossy=lossy, tag_selector=tag_selector) + br_to, + overwrite=overwrite, + stop_revision=revision_id, + lossy=lossy, + tag_selector=tag_selector, + ) push_result.workingtree_updated = None # Not applicable else: if br_to.name == tree_to.branch.name: with tree_to.lock_write(): push_result.branch_push_result = source.push( - tree_to.branch, overwrite=overwrite, + tree_to.branch, + overwrite=overwrite, stop_revision=revision_id, - lossy=lossy, tag_selector=tag_selector) + lossy=lossy, + tag_selector=tag_selector, + ) tree_to.update() push_result.workingtree_updated = True else: push_result.branch_push_result = source.push( - br_to, overwrite=overwrite, stop_revision=revision_id, - lossy=lossy, tag_selector=tag_selector) + br_to, + overwrite=overwrite, + stop_revision=revision_id, + lossy=lossy, + tag_selector=tag_selector, + ) push_result.workingtree_updated = None # Not applicable push_result.old_revno = push_result.branch_push_result.old_revno push_result.old_revid = push_result.branch_push_result.old_revid - push_result.target_branch = \ - push_result.branch_push_result.target_branch + push_result.target_branch = push_result.branch_push_result.target_branch return push_result def _get_tree_branch(self, name=None): @@ -557,8 +602,14 @@ def check_conversion_target(self, target_format): """Check that a controldir as a whole can be converted to a new format.""" raise NotImplementedError(self.check_conversion_target) - def clone(self, url, revision_id=None, force_new_repo=False, - preserve_stacking=False, tag_selector=None): + def clone( + self, + url, + revision_id=None, + force_new_repo=False, + preserve_stacking=False, + tag_selector=None, + ): """Clone this controldir and its contents to url verbatim. Args: @@ -572,16 +623,26 @@ def clone(self, url, revision_id=None, force_new_repo=False, preserve_stacking: When cloning a stacked branch, stack the new branch on top of the other branch's stacked-on branch. """ - return self.clone_on_transport(_mod_transport.get_transport(url), - revision_id=revision_id, - force_new_repo=force_new_repo, - preserve_stacking=preserve_stacking, - tag_selector=tag_selector) - - def clone_on_transport(self, transport, revision_id=None, - force_new_repo=False, preserve_stacking=False, stacked_on=None, - create_prefix=False, use_existing_dir=True, no_tree=False, - tag_selector=None): + return self.clone_on_transport( + _mod_transport.get_transport(url), + revision_id=revision_id, + force_new_repo=force_new_repo, + preserve_stacking=preserve_stacking, + tag_selector=tag_selector, + ) + + def clone_on_transport( + self, + transport, + revision_id=None, + force_new_repo=False, + preserve_stacking=False, + stacked_on=None, + create_prefix=False, + use_existing_dir=True, + no_tree=False, + tag_selector=None, + ): """Clone this controldir and its contents to transport verbatim. 
Args: @@ -621,9 +682,12 @@ def find_controldirs(klass, transport, evaluate=None, list_current=None): a generator of found bzrdirs, or whatever evaluate returns. """ if list_current is None: + def list_current(transport): - return transport.list_dir('') + return transport.list_dir("") + if evaluate is None: + def evaluate(controldir): return True, controldir @@ -633,8 +697,11 @@ def evaluate(controldir): recurse = True try: controldir = klass.open_from_transport(current_transport) - except (errors.NotBranchError, errors.PermissionDenied, - errors.UnknownFormatError): + except ( + errors.NotBranchError, + errors.PermissionDenied, + errors.UnknownFormatError, + ): pass else: recurse, value = evaluate(controldir) @@ -658,6 +725,7 @@ def find_branches(klass, transport): To list all the branches that use a particular Repository, see Repository.find_branches """ + def evaluate(controldir): try: repository = controldir.open_repository() @@ -666,9 +734,9 @@ def evaluate(controldir): else: return False, ([], repository) return True, (controldir.list_branches(), None) + ret = [] - for branches, repo in klass.find_controldirs( - transport, evaluate=evaluate): + for branches, repo in klass.find_controldirs(transport, evaluate=evaluate): if repo is not None: ret.extend(repo.find_branches()) if branches is not None: @@ -677,7 +745,8 @@ def evaluate(controldir): @classmethod def create_branch_and_repo( - klass, base, force_new_repo=False, format=None) -> "Branch": + klass, base, force_new_repo=False, format=None + ) -> "Branch": """Create a new ControlDir, Branch and Repository at the url 'base'. This will use the current default ControlDirFormat unless one is @@ -699,9 +768,14 @@ def create_branch_and_repo( return cast("Branch", controldir.create_branch()) @classmethod - def create_branch_convenience(klass, base, force_new_repo=False, - force_new_tree=None, format=None, - possible_transports=None): + def create_branch_convenience( + klass, + base, + force_new_repo=False, + force_new_tree=None, + format=None, + possible_transports=None, + ): """Create a new ControlDir, Branch and Repository at the url 'base'. 
This is a convenience function - it will use an existing repository @@ -738,8 +812,7 @@ def create_branch_convenience(klass, base, force_new_repo=False, controldir = klass.create(base, format, possible_transports) repo = controldir._find_or_create_repository(force_new_repo) result = controldir.create_branch() - if force_new_tree or (repo.make_working_trees() and - force_new_tree is None): + if force_new_tree or (repo.make_working_trees() and force_new_tree is None): try: controldir.create_workingtree() except errors.NotLocalUrl: @@ -764,11 +837,12 @@ def create_standalone_workingtree(klass, base, format=None) -> "WorkingTree": """ t = _mod_transport.get_transport(base) from breezy.transport import local + if not isinstance(t, local.LocalTransport): raise errors.NotLocalUrl(base) - controldir = klass.create_branch_and_repo(base, - force_new_repo=True, - format=format).controldir + controldir = klass.create_branch_and_repo( + base, force_new_repo=True, format=format + ).controldir return controldir.create_workingtree() @classmethod @@ -777,47 +851,51 @@ def open_unsupported(klass, base): return klass.open(base, _unsupported=True) @classmethod - def open(klass, base, possible_transports=None, probers=None, - _unsupported=False) -> "ControlDir": + def open( + klass, base, possible_transports=None, probers=None, _unsupported=False + ) -> "ControlDir": """Open an existing controldir, rooted at 'base' (url). Args: _unsupported: a private parameter to the ControlDir class. """ t = _mod_transport.get_transport(base, possible_transports) - return klass.open_from_transport(t, probers=probers, - _unsupported=_unsupported) + return klass.open_from_transport(t, probers=probers, _unsupported=_unsupported) @classmethod - def open_from_transport(klass, transport: _mod_transport.Transport, - _unsupported=False, probers=None) -> "ControlDir": + def open_from_transport( + klass, transport: _mod_transport.Transport, _unsupported=False, probers=None + ) -> "ControlDir": """Open a controldir within a particular directory. Args: transport: Transport containing the controldir. _unsupported: private. """ - for hook in klass.hooks['pre_open']: + for hook in klass.hooks["pre_open"]: hook(transport) # Keep initial base since 'transport' may be modified while following # the redirections. 
base = transport.base def find_format(transport): - return transport, ControlDirFormat.find_format(transport, - probers=probers) + return transport, ControlDirFormat.find_format(transport, probers=probers) def redirected(transport, e, redirection_notice): redirected_transport = transport._redirected_to(e.source, e.target) if redirected_transport is None: raise errors.NotBranchError(base) - trace.note(gettext('{0} is{1} redirected to {2}').format( - transport.base, e.permanently, redirected_transport.base)) + trace.note( + gettext("{0} is{1} redirected to {2}").format( + transport.base, e.permanently, redirected_transport.base + ) + ) return redirected_transport try: transport, format = _mod_transport.do_catching_redirections( - find_format, transport, redirected) + find_format, transport, redirected + ) except errors.TooManyRedirections as e: raise errors.NotBranchError(base) from e @@ -862,7 +940,7 @@ def open_containing_from_transport(klass, a_transport, probers=None): except errors.PermissionDenied: pass try: - new_t = a_transport.clone('..') + new_t = a_transport.clone("..") except urlutils.InvalidURLJoin as e: # reached the root, whatever that may be raise errors.NotBranchError(path=url) from e @@ -884,8 +962,7 @@ def open_tree_or_branch(klass, location, name=None): return controldir._get_tree_branch(name=name) @classmethod - def open_containing_tree_or_branch(klass, location, - possible_transports=None): + def open_containing_tree_or_branch(klass, location, possible_transports=None): """Return the branch and working tree contained by a location. Returns (tree, branch, relpath). @@ -894,8 +971,9 @@ def open_containing_tree_or_branch(klass, location, raised relpath is the portion of the path that is contained by the branch. """ - controldir, relpath = klass.open_containing(location, - possible_transports=possible_transports) + controldir, relpath = klass.open_containing( + location, possible_transports=possible_transports + ) tree, branch = controldir._get_tree_branch() return tree, branch, relpath @@ -920,7 +998,7 @@ def open_containing_tree_branch_or_repository(klass, location): try: repo = controldir.find_repository() return None, None, repo, relpath - except (errors.NoRepositoryPresent) as e: + except errors.NoRepositoryPresent as e: raise errors.NotBranchError(location) from e return tree, branch, branch.repository, relpath @@ -935,8 +1013,10 @@ def create(klass, base, format=None, possible_transports=None): can be reused to share a remote connection. """ if klass is not ControlDir: - raise AssertionError("ControlDir.create always creates the" - "default format, not one of %r" % klass) + raise AssertionError( + "ControlDir.create always creates the" + "default format, not one of %r" % klass + ) t = _mod_transport.get_transport(base, possible_transports) t.ensure_base() if format is None: @@ -950,14 +1030,19 @@ class ControlDirHooks(hooks.Hooks): def __init__(self): """Create the default hooks.""" hooks.Hooks.__init__(self, "breezy.controldir", "ControlDir.hooks") - self.add_hook('pre_open', - "Invoked before attempting to open a ControlDir with the transport " - "that the open will use.", (1, 14)) - self.add_hook('post_repo_init', - "Invoked after a repository has been initialized. 
" - "post_repo_init is called with a " - "breezy.controldir.RepoInitHookParams.", - (2, 2)) + self.add_hook( + "pre_open", + "Invoked before attempting to open a ControlDir with the transport " + "that the open will use.", + (1, 14), + ) + self.add_hook( + "post_repo_init", + "Invoked after a repository has been initialized. " + "post_repo_init is called with a " + "breezy.controldir.RepoInitHookParams.", + (2, 2), + ) # install the default hooks @@ -982,8 +1067,9 @@ def is_supported(self): """ return True - def check_support_status(self, allow_unsupported, recommend_upgrade=True, - basedir=None): + def check_support_status( + self, allow_unsupported, recommend_upgrade=True, basedir=None + ): """Give an error or warning on old formats. Args: @@ -999,15 +1085,16 @@ def check_support_status(self, allow_unsupported, recommend_upgrade=True, # see open_downlevel to open legacy branches. raise errors.UnsupportedFormatError(format=self) if recommend_upgrade and self.upgrade_recommended: - ui.ui_factory.recommend_upgrade( - self.get_format_description(), basedir) + ui.ui_factory.recommend_upgrade(self.get_format_description(), basedir) @classmethod def get_format_string(cls): raise NotImplementedError(cls.get_format_string) -class ControlComponentFormatRegistry(registry.FormatRegistry[ControlComponentFormat, None]): +class ControlComponentFormatRegistry( + registry.FormatRegistry[ControlComponentFormat, None] +): """A registry for control components (branch, workingtree, repository).""" def __init__(self, other_registry=None): @@ -1016,13 +1103,11 @@ def __init__(self, other_registry=None): def register(self, format): """Register a new format.""" - super().register( - format.get_format_string(), format) + super().register(format.get_format_string(), format) def remove(self, format): """Remove a registered format.""" - super().remove( - format.get_format_string()) + super().remove(format.get_format_string()) def register_extra(self, format): """Register a format that can not be used in a metadir. @@ -1038,8 +1123,7 @@ def remove_extra(self, format): def register_extra_lazy(self, module_name, member_name): """Register a format lazily.""" - self._extra_formats.append( - registry._LazyObjectGetter(module_name, member_name)) + self._extra_formats.append(registry._LazyObjectGetter(module_name, member_name)) def _get_extra(self): """Return getters for extra formats, not usable in meta directories.""" @@ -1166,8 +1250,9 @@ def is_initializable(self): """Whether new control directories of this format can be initialized.""" return self.is_supported() - def check_support_status(self, allow_unsupported, recommend_upgrade=True, - basedir=None): + def check_support_status( + self, allow_unsupported, recommend_upgrade=True, basedir=None + ): """Give an error or warning on old formats. Args: @@ -1183,12 +1268,10 @@ def check_support_status(self, allow_unsupported, recommend_upgrade=True, # see open_downlevel to open legacy branches. 
raise errors.UnsupportedFormatError(format=self) if recommend_upgrade and self.upgrade_recommended: - ui.ui_factory.recommend_upgrade( - self.get_format_description(), basedir) + ui.ui_factory.recommend_upgrade(self.get_format_description(), basedir) def same_model(self, target_format): - return (self.repository_format.rich_root_data == - target_format.rich_root_data) + return self.repository_format.rich_root_data == target_format.rich_root_data @classmethod def register_prober(klass, prober: Type["Prober"]): @@ -1217,14 +1300,16 @@ def known_formats(klass): return result @classmethod - def find_format(klass, transport: _mod_transport.Transport, - probers: Optional[List[Type["Prober"]]] = None - ) -> "ControlDirFormat": + def find_format( + klass, + transport: _mod_transport.Transport, + probers: Optional[List[Type["Prober"]]] = None, + ) -> "ControlDirFormat": """Return the format present at transport.""" if probers is None: probers = sorted( - klass.all_probers(), - key=lambda prober: prober.priority(transport)) + klass.all_probers(), key=lambda prober: prober.priority(transport) + ) for prober_kls in probers: prober = prober_kls() try: @@ -1245,16 +1330,26 @@ def initialize(self, url: str, possible_transports=None): instead of this method. """ return self.initialize_on_transport( - _mod_transport.get_transport(url, possible_transports)) + _mod_transport.get_transport(url, possible_transports) + ) def initialize_on_transport(self, transport: _mod_transport.Transport): """Initialize a new controldir in the base directory of a Transport.""" raise NotImplementedError(self.initialize_on_transport) - def initialize_on_transport_ex(self, transport: _mod_transport.Transport, use_existing_dir: bool = False, - create_prefix: bool = False, force_new_repo: bool = False, stacked_on=None, - stack_on_pwd=None, repo_format_name=None, make_working_trees=None, - shared_repo=False, vfs_only=False): + def initialize_on_transport_ex( + self, + transport: _mod_transport.Transport, + use_existing_dir: bool = False, + create_prefix: bool = False, + force_new_repo: bool = False, + stacked_on=None, + stack_on_pwd=None, + repo_format_name=None, + make_working_trees=None, + shared_repo=False, + vfs_only=False, + ): """Create this format on transport. The directory to initialize will be created. @@ -1387,7 +1482,6 @@ def priority(klass, transport: _mod_transport.Transport) -> int: class ControlDirFormatInfo: - def __init__(self, native, deprecated, hidden, experimental): self.deprecated = deprecated self.native = native @@ -1407,8 +1501,16 @@ def __init__(self): self._registration_order = [] super().__init__() - def register(self, key, factory, help, native=True, deprecated=False, - hidden=False, experimental=False): + def register( + self, + key, + factory, + help, + native=True, + deprecated=False, + hidden=False, + experimental=False, + ): """Register a ControlDirFormat factory. The factory must be a callable that takes one parameter: the key. @@ -1417,8 +1519,13 @@ def register(self, key, factory, help, native=True, deprecated=False, This function mainly exists to prevent the info object from being supplied directly. 
""" - registry.Registry.register(self, key, factory, help, - ControlDirFormatInfo(native, deprecated, hidden, experimental)) + registry.Registry.register( + self, + key, + factory, + help, + ControlDirFormatInfo(native, deprecated, hidden, experimental), + ) self._registration_order.append(key) def register_alias(self, key, target, hidden=False): @@ -1430,15 +1537,37 @@ def register_alias(self, key, target, hidden=False): hidden: Whether the alias is hidden """ info = self.get_info(target) - registry.Registry.register_alias(self, key, target, - ControlDirFormatInfo( - native=info.native, deprecated=info.deprecated, - hidden=hidden, experimental=info.experimental)) - - def register_lazy(self, key, module_name, member_name, help, native=True, - deprecated=False, hidden=False, experimental=False): - registry.Registry.register_lazy(self, key, module_name, member_name, - help, ControlDirFormatInfo(native, deprecated, hidden, experimental)) + registry.Registry.register_alias( + self, + key, + target, + ControlDirFormatInfo( + native=info.native, + deprecated=info.deprecated, + hidden=hidden, + experimental=info.experimental, + ), + ) + + def register_lazy( + self, + key, + module_name, + member_name, + help, + native=True, + deprecated=False, + hidden=False, + experimental=False, + ): + registry.Registry.register_lazy( + self, + key, + module_name, + member_name, + help, + ControlDirFormatInfo(native, deprecated, hidden, experimental), + ) self._registration_order.append(key) def set_default(self, key): @@ -1446,7 +1575,7 @@ def set_default(self, key): This method must be called once and only once. """ - self.register_alias('default', key) + self.register_alias("default", key) def set_default_repository(self, key): """Set the FormatRegistry default and Repository default. @@ -1454,10 +1583,10 @@ def set_default_repository(self, key): This is a transitional method while Repository.set_default_format is deprecated. 
""" - if 'default' in self: - self.remove('default') + if "default" in self: + self.remove("default") self.set_default(key) - self.get('default')() + self.get("default")() def make_controldir(self, key): return self.get(key)() @@ -1465,10 +1594,10 @@ def make_controldir(self, key): def help_topic(self, topic): output = "" default_realkey = None - default_help = self.get_help('default') + default_help = self.get_help("default") help_pairs = [] for key in self._registration_order: - if key == 'default': + if key == "default": continue help = self.get_help(key) if help == default_help: @@ -1478,14 +1607,21 @@ def help_topic(self, topic): def wrapped(key, help, info): if info.native: - help = '(native) ' + help - return ':{}:\n{}\n\n'.format(key, - textwrap.fill(help, initial_indent=' ', - subsequent_indent=' ', - break_long_words=False)) + help = "(native) " + help + return ":{}:\n{}\n\n".format( + key, + textwrap.fill( + help, + initial_indent=" ", + subsequent_indent=" ", + break_long_words=False, + ), + ) + if default_realkey is not None: - output += wrapped(default_realkey, f'(default) {default_help}', - self.get_info('default')) + output += wrapped( + default_realkey, f"(default) {default_help}", self.get_info("default") + ) deprecated_pairs = [] experimental_pairs = [] for key, help in help_pairs: @@ -1506,20 +1642,17 @@ def wrapped(key, help, info): info = self.get_info(key) other_output += wrapped(key, help, info) else: - other_output += \ - "No experimental formats are available.\n\n" + other_output += "No experimental formats are available.\n\n" if len(deprecated_pairs) > 0: other_output += "\nDeprecated formats are shown below.\n\n" for key, help in deprecated_pairs: info = self.get_info(key) other_output += wrapped(key, help, info) else: - other_output += \ - "\nNo deprecated formats are available.\n\n" - other_output += \ - "\nSee :doc:`formats-help` for more about storage formats." + other_output += "\nNo deprecated formats are available.\n\n" + other_output += "\nSee :doc:`formats-help` for more about storage formats." - if topic == 'other-formats': + if topic == "other-formats": return other_output else: return output @@ -1604,15 +1737,17 @@ def configure_branch(self, branch): stack_on = self._stack_on else: try: - stack_on = urlutils.rebase_url(self._stack_on, - self._stack_on_pwd, - branch.user_url) + stack_on = urlutils.rebase_url( + self._stack_on, self._stack_on_pwd, branch.user_url + ) except urlutils.InvalidRebaseURLs: stack_on = self._get_full_stack_on() try: branch.set_stacked_on_url(stack_on) - except (_mod_branch.UnstackableBranchFormat, - errors.UnstackableRepositoryFormat): + except ( + _mod_branch.UnstackableBranchFormat, + errors.UnstackableRepositoryFormat, + ): if self._require_stacking: raise @@ -1636,7 +1771,8 @@ def _add_fallback(self, repository, possible_transports=None): return try: stacked_dir = ControlDir.open( - stack_on, possible_transports=possible_transports) + stack_on, possible_transports=possible_transports + ) except errors.JailBreak: # We keep the stacking details, but we are in the server code so # actually stacking is not needed. @@ -1653,8 +1789,9 @@ def _add_fallback(self, repository, possible_transports=None): else: self._require_stacking = True - def acquire_repository(self, make_working_trees=None, shared=False, - possible_transports=None): + def acquire_repository( + self, make_working_trees=None, shared=False, possible_transports=None + ): """Acquire a repository for this controlrdir. 
Implementations may create a new repository or use a pre-exising @@ -1667,8 +1804,7 @@ def acquire_repository(self, make_working_trees=None, shared=False, Returns: A repository, is_new_flag (True if the repository was created). """ - raise NotImplementedError( - RepositoryAcquisitionPolicy.acquire_repository) + raise NotImplementedError(RepositoryAcquisitionPolicy.acquire_repository) # Please register new formats after old formats so that formats diff --git a/breezy/counted_lock.py b/breezy/counted_lock.py index 2beb380260..e6cc79dc61 100644 --- a/breezy/counted_lock.py +++ b/breezy/counted_lock.py @@ -69,7 +69,7 @@ def lock_read(self): else: self._real_lock.lock_read() self._lock_count = 1 - self._lock_mode = 'r' + self._lock_mode = "r" def lock_write(self, token=None): """Acquire the lock in write mode. @@ -84,10 +84,10 @@ def lock_write(self, token=None): """ if self._lock_count == 0: self._token = self._real_lock.lock_write(token=token) - self._lock_mode = 'w' + self._lock_mode = "w" self._lock_count += 1 return self._token - elif self._lock_mode != 'w': + elif self._lock_mode != "w": raise errors.ReadOnlyError(self) else: self._real_lock.validate_token(token) diff --git a/breezy/crash.py b/breezy/crash.py index 2d51d8731c..fe72d2f4cb 100644 --- a/breezy/crash.py +++ b/breezy/crash.py @@ -55,8 +55,9 @@ def report_bug(exc_info, stderr): - if (debug.debug_flag_enabled('no_apport')) or \ - os.environ.get('APPORT_DISABLE', None): + if (debug.debug_flag_enabled("no_apport")) or os.environ.get( + "APPORT_DISABLE", None + ): return report_bug_legacy(exc_info, stderr) try: if report_bug_to_apport(exc_info, stderr): @@ -75,25 +76,35 @@ def report_bug(exc_info, stderr): def report_bug_legacy(exc_info, err_file): """Report a bug by just printing a message to the user.""" trace.print_exception(exc_info, err_file) - err_file.write('\n') + err_file.write("\n") import textwrap def print_wrapped(l): - err_file.write(textwrap.fill( - l, width=78, subsequent_indent=' ') + '\n') - print_wrapped('brz {} on python {} ({})\n'.format(breezy.__version__, - breezy._format_version_tuple(sys.version_info), - platform.platform(aliased=1))) - print_wrapped(f'arguments: {sys.argv!r}\n') - print_wrapped(textwrap.fill( - 'plugins: ' + plugin.format_concise_plugin_list(), - width=78, - subsequent_indent=' ', - ) + '\n') + err_file.write(textwrap.fill(l, width=78, subsequent_indent=" ") + "\n") + + print_wrapped( + "brz {} on python {} ({})\n".format( + breezy.__version__, + breezy._format_version_tuple(sys.version_info), + platform.platform(aliased=1), + ) + ) + print_wrapped(f"arguments: {sys.argv!r}\n") + print_wrapped( + textwrap.fill( + "plugins: " + plugin.format_concise_plugin_list(), + width=78, + subsequent_indent=" ", + ) + + "\n" + ) print_wrapped( - 'encoding: {!r}, fsenc: {!r}, lang: {!r}\n'.format( - osutils.get_user_encoding(), sys.getfilesystemencoding(), - os.environ.get('LANG'))) + "encoding: {!r}, fsenc: {!r}, lang: {!r}\n".format( + osutils.get_user_encoding(), + sys.getfilesystemencoding(), + os.environ.get("LANG"), + ) + ) # We used to show all the plugins here, but it's too verbose. err_file.write( "\n" @@ -101,7 +112,7 @@ def print_wrapped(l): " bug in Breezy. 
You can help us fix it by filing a bug report at\n" " https://bugs.launchpad.net/brz/+filebug\n" " including this traceback and a description of the problem.\n" - ) + ) def report_bug_to_apport(exc_info, stderr): @@ -120,16 +131,16 @@ def report_bug_to_apport(exc_info, stderr): crash_filename = _write_apport_report_to_file(exc_info) if crash_filename is None: - stderr.write("\n" - "apport is set to ignore crashes in this version of brz.\n" - ) + stderr.write("\n" "apport is set to ignore crashes in this version of brz.\n") else: trace.print_exception(exc_info, stderr) - stderr.write("\n" - "You can report this problem to Breezy's developers by running\n" - " apport-bug %s\n" - "if a bug-reporting window does not automatically appear.\n" - % (crash_filename)) + stderr.write( + "\n" + "You can report this problem to Breezy's developers by running\n" + " apport-bug %s\n" + "if a bug-reporting window does not automatically appear.\n" + % (crash_filename) + ) # XXX: on Windows, Mac, and other platforms where we might have the # apport libraries but not have an apport always running, we could # synchronously file now @@ -149,42 +160,42 @@ def _write_apport_report_to_file(exc_info): pr.add_proc_info() # It also adds ProcMaps which for us is rarely useful and mostly noise, so # let's remove it. - del pr['ProcMaps'] + del pr["ProcMaps"] pr.add_user_info() # Package and SourcePackage are needed so that apport will report about even # non-packaged versions of brz; also this reports on their packaged # dependencies which is useful. - pr['SourcePackage'] = 'brz' - pr['Package'] = 'brz' - - pr['CommandLine'] = pprint.pformat(sys.argv) - pr['BrzVersion'] = breezy.__version__ - pr['PythonVersion'] = breezy._format_version_tuple(sys.version_info) - pr['Platform'] = platform.platform(aliased=1) - pr['UserEncoding'] = osutils.get_user_encoding() - pr['FileSystemEncoding'] = sys.getfilesystemencoding() - pr['Locale'] = os.environ.get('LANG', 'C') - pr['BrzPlugins'] = _format_plugin_list() - pr['PythonLoadedModules'] = _format_module_list() - pr['BrzDebugFlags'] = pprint.pformat(debug.debug_flags) + pr["SourcePackage"] = "brz" + pr["Package"] = "brz" + + pr["CommandLine"] = pprint.pformat(sys.argv) + pr["BrzVersion"] = breezy.__version__ + pr["PythonVersion"] = breezy._format_version_tuple(sys.version_info) + pr["Platform"] = platform.platform(aliased=1) + pr["UserEncoding"] = osutils.get_user_encoding() + pr["FileSystemEncoding"] = sys.getfilesystemencoding() + pr["Locale"] = os.environ.get("LANG", "C") + pr["BrzPlugins"] = _format_plugin_list() + pr["PythonLoadedModules"] = _format_module_list() + pr["BrzDebugFlags"] = pprint.pformat(debug.debug_flags) # actually we'd rather file directly against the upstream product, but # apport does seem to count on there being one in there; we might need to # redirect it elsewhere anyhow - pr['SourcePackage'] = 'brz' - pr['Package'] = 'brz' + pr["SourcePackage"] = "brz" + pr["Package"] = "brz" # tell apport to file directly against the brz package using # # # XXX: unfortunately apport may crash later if the crashdb definition # file isn't present - pr['CrashDb'] = 'brz' + pr["CrashDb"] = "brz" tb_file = StringIO() traceback.print_exception(exc_type, exc_object, exc_tb, file=tb_file) - pr['Traceback'] = tb_file.getvalue() + pr["Traceback"] = tb_file.getvalue() _attach_log_tail(pr) @@ -220,11 +231,11 @@ def _attach_log_tail(pr): try: brz_log = open(trace.get_brz_log_filename()) except OSError as e: - pr['BrzLogTail'] = repr(e) + pr["BrzLogTail"] = repr(e) return try: lines 
= brz_log.readlines() - pr['BrzLogTail'] = ''.join(lines[-40:]) + pr["BrzLogTail"] = "".join(lines[-40:]) finally: brz_log.close() @@ -236,26 +247,22 @@ def _open_crash_file(): # Windows or if it's manually configured it might need to be created, # and then it should be private os.makedirs(crash_dir, mode=0o600) - date_string = time.strftime('%Y-%m-%dT%H:%M', time.gmtime()) + date_string = time.strftime("%Y-%m-%dT%H:%M", time.gmtime()) # XXX: getuid doesn't work on win32, but the crash directory is per-user - if sys.platform == 'win32': - user_part = '' + if sys.platform == "win32": + user_part = "" else: - user_part = '.%d' % os.getuid() - filename = osutils.pathjoin( - crash_dir, - f'brz{user_part}.{date_string}.crash') + user_part = ".%d" % os.getuid() + filename = osutils.pathjoin(crash_dir, f"brz{user_part}.{date_string}.crash") # be careful here that people can't play tmp-type symlink mischief in the # world-writable directory return filename, os.fdopen( - os.open(filename, - os.O_WRONLY | os.O_CREAT | os.O_EXCL, - 0o600), - 'wb') + os.open(filename, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600), "wb" + ) def _format_plugin_list(): - return ''.join(plugin.describe_plugins(show_paths=True)) + return "".join(plugin.describe_plugins(show_paths=True)) def _format_module_list(): diff --git a/breezy/debug.py b/breezy/debug.py index 74c5597046..9b1452e617 100644 --- a/breezy/debug.py +++ b/breezy/debug.py @@ -47,7 +47,7 @@ def set_debug_flags_from_config(): from breezy import config c = config.GlobalStack() - for f in c.get('debug_flags'): + for f in c.get("debug_flags"): set_debug_flag(f) @@ -68,5 +68,7 @@ def set_trace(): """ import pdb import sys - pdb.Pdb(stdin=sys.__stdin__, stdout=sys.__stdout__ - ).set_trace(sys._getframe().f_back) + + pdb.Pdb(stdin=sys.__stdin__, stdout=sys.__stdout__).set_trace( + sys._getframe().f_back + ) diff --git a/breezy/decorators.py b/breezy/decorators.py index a50219f398..34b05a4b3c 100644 --- a/breezy/decorators.py +++ b/breezy/decorators.py @@ -27,6 +27,7 @@ def only_raises(*errors): def unlock(self): # etc """ + def decorator(unbound): def wrapped(*args, **kwargs): try: @@ -34,11 +35,13 @@ def wrapped(*args, **kwargs): except errors: raise except BaseException: - trace.mutter('Error suppressed by only_raises:') + trace.mutter("Error suppressed by only_raises:") trace.log_exception_quietly() + wrapped.__doc__ = unbound.__doc__ wrapped.__name__ = unbound.__name__ return wrapped + return decorator @@ -87,12 +90,11 @@ def cachedproperty(attrname_or_fn): return _CachedPropertyForAttr(attrname) else: fn = attrname_or_fn - attrname = f'_{fn.__name__}_cached_value' + attrname = f"_{fn.__name__}_cached_value" return _CachedProperty(attrname, fn) class _CachedPropertyForAttr: - def __init__(self, attrname): self.attrname = attrname @@ -101,7 +103,6 @@ def __call__(self, fn): class _CachedProperty: - def __init__(self, attrname, fn): self.fn = fn self.attrname = attrname diff --git a/breezy/delta.py b/breezy/delta.py index 751693fdce..7f948ce160 100644 --- a/breezy/delta.py +++ b/breezy/delta.py @@ -63,73 +63,101 @@ def __init__(self): def __eq__(self, other): if not isinstance(other, TreeDelta): return False - return self.added == other.added \ - and self.removed == other.removed \ - and self.renamed == other.renamed \ - and self.copied == other.copied \ - and self.modified == other.modified \ - and self.unchanged == other.unchanged \ - and self.kind_changed == other.kind_changed \ + return ( + self.added == other.added + and self.removed == other.removed + and 
self.renamed == other.renamed + and self.copied == other.copied + and self.modified == other.modified + and self.unchanged == other.unchanged + and self.kind_changed == other.kind_changed and self.unversioned == other.unversioned + ) def __ne__(self, other): return not (self == other) def __repr__(self): - return "TreeDelta(added={!r}, removed={!r}, renamed={!r}," \ - " copied={!r}, kind_changed={!r}, modified={!r}, unchanged={!r}," \ + return ( + "TreeDelta(added={!r}, removed={!r}, renamed={!r}," + " copied={!r}, kind_changed={!r}, modified={!r}, unchanged={!r}," " unversioned={!r})".format( - self.added, self.removed, self.renamed, self.copied, - self.kind_changed, self.modified, self.unchanged, - self.unversioned) + self.added, + self.removed, + self.renamed, + self.copied, + self.kind_changed, + self.modified, + self.unchanged, + self.unversioned, + ) + ) def has_changed(self): - return bool(self.modified - or self.added - or self.removed - or self.renamed - or self.copied - or self.kind_changed) - - def get_changes_as_text(self, show_ids=False, show_unchanged=False, - short_status=False): + return bool( + self.modified + or self.added + or self.removed + or self.renamed + or self.copied + or self.kind_changed + ) + + def get_changes_as_text( + self, show_ids=False, show_unchanged=False, short_status=False + ): output = StringIO() report_delta(output, self, short_status, show_ids, show_unchanged) return output.getvalue() -def _compare_trees(old_tree, new_tree, want_unchanged, specific_files, - include_root, extra_trees=None, - require_versioned=False, want_unversioned=False): +def _compare_trees( + old_tree, + new_tree, + want_unchanged, + specific_files, + include_root, + extra_trees=None, + require_versioned=False, + want_unversioned=False, +): """Worker function that implements Tree.changes_from.""" delta = TreeDelta() # mutter('start compare_trees') for change in new_tree.iter_changes( - old_tree, want_unchanged, specific_files, extra_trees=extra_trees, - require_versioned=require_versioned, - want_unversioned=want_unversioned): + old_tree, + want_unchanged, + specific_files, + extra_trees=extra_trees, + require_versioned=require_versioned, + want_unversioned=want_unversioned, + ): if change.versioned == (False, False): delta.unversioned.append(change) continue if not include_root and (None, None) == change.parent_id: continue fully_present = tuple( - (change.versioned[x] and change.kind[x] is not None) - for x in range(2)) + (change.versioned[x] and change.kind[x] is not None) for x in range(2) + ) if fully_present[0] != fully_present[1]: if fully_present[1] is True: delta.added.append(change) else: - if change.kind[0] == 'symlink' and not new_tree.supports_symlinks(): + if change.kind[0] == "symlink" and not new_tree.supports_symlinks(): trace.warning( f'Ignoring "{change.path[0]}" as symlinks ' - 'are not supported on this filesystem.') + "are not supported on this filesystem." 
+ ) else: delta.removed.append(change) elif fully_present[0] is False: delta.missing.append(change) - elif change.name[0] != change.name[1] or change.parent_id[0] != change.parent_id[1]: + elif ( + change.name[0] != change.name[1] + or change.parent_id[0] != change.parent_id[1] + ): # If the name changes, or the parent_id changes, we have a rename or copy # (if we move a parent, that doesn't count as a rename for the # file) @@ -168,9 +196,15 @@ def change_key(change): class _ChangeReporter: """Report changes between two trees.""" - def __init__(self, output=None, suppress_root_add=True, - output_file=None, unversioned_filter=None, view_info=None, - classify=True): + def __init__( + self, + output=None, + suppress_root_add=True, + output_file=None, + unversioned_filter=None, + view_info=None, + classify=True, + ): """Constructor. :param output: a function with the signature of trace.note, i.e. @@ -189,43 +223,48 @@ def __init__(self, output=None, suppress_root_add=True, """ if output_file is not None: if output is not None: - raise BzrError('Cannot specify both output and output_file') + raise BzrError("Cannot specify both output and output_file") def output(fmt, *args): - output_file.write((fmt % args) + '\n') + output_file.write((fmt % args) + "\n") + self.output = output if self.output is None: from . import trace + self.output = trace.note self.suppress_root_add = suppress_root_add - self.modified_map = {'kind changed': 'K', - 'unchanged': ' ', - 'created': 'N', - 'modified': 'M', - 'deleted': 'D', - 'missing': '!', - } - self.versioned_map = {'added': '+', # versioned target - 'unchanged': ' ', # versioned in both - 'removed': '-', # versioned in source - 'unversioned': '?', # versioned in neither - } + self.modified_map = { + "kind changed": "K", + "unchanged": " ", + "created": "N", + "modified": "M", + "deleted": "D", + "missing": "!", + } + self.versioned_map = { + "added": "+", # versioned target + "unchanged": " ", # versioned in both + "removed": "-", # versioned in source + "unversioned": "?", # versioned in neither + } self.unversioned_filter = unversioned_filter if classify: self.kind_marker = osutils.kind_marker else: - self.kind_marker = lambda kind: '' + self.kind_marker = lambda kind: "" if view_info is None: self.view_name = None self.view_files = [] else: self.view_name = view_info[0] self.view_files = view_info[1] - self.output("Operating on whole tree but only reporting on " - f"'{self.view_name}' view.") + self.output( + "Operating on whole tree but only reporting on " + f"'{self.view_name}' view." + ) - def report(self, paths, versioned, renamed, copied, modified, exe_change, - kind): + def report(self, paths, versioned, renamed, copied, modified, exe_change, kind): """Report one change to a file. :param path: The old and new paths as generated by Tree.iter_changes. @@ -241,24 +280,24 @@ def report(self, paths, versioned, renamed, copied, modified, exe_change, """ if trace.is_quiet(): return - if paths[1] == '' and versioned == 'added' and self.suppress_root_add: + if paths[1] == "" and versioned == "added" and self.suppress_root_add: return - if self.view_files and not osutils.is_inside_any(self.view_files, - paths[1]): + if self.view_files and not osutils.is_inside_any(self.view_files, paths[1]): return - if versioned == 'unversioned': + if versioned == "unversioned": # skip ignored unversioned files if needed. if self.unversioned_filter is not None: if self.unversioned_filter(paths[1]): return # dont show a content change in the output. 
- modified = 'unchanged' + modified = "unchanged" # we show both paths in the following situations: # the file versioning is unchanged AND # ( the path is different OR # the kind is different) - if (versioned == 'unchanged' and - (renamed or copied or modified == 'kind changed')): + if versioned == "unchanged" and ( + renamed or copied or modified == "kind changed" + ): if renamed or copied: # on a rename or copy, we show old and new old_path, path = paths @@ -271,7 +310,7 @@ def report(self, paths, versioned, renamed, copied, modified, exe_change, if kind[0] is not None: old_path += self.kind_marker(kind[0]) old_path += " => " - elif versioned == 'removed': + elif versioned == "removed": # not present in target old_path = "" path = paths[0] @@ -285,17 +324,18 @@ def report(self, paths, versioned, renamed, copied, modified, exe_change, else: rename = self.versioned_map[versioned] # we show the old kind on the new path when the content is deleted. - if modified == 'deleted': + if modified == "deleted": path += self.kind_marker(kind[0]) # otherwise we always show the current kind when there is one elif kind[1] is not None: path += self.kind_marker(kind[1]) if exe_change: - exe = '*' + exe = "*" else: - exe = ' ' - self.output("%s%s%s %s%s", rename, self.modified_map[modified], exe, - old_path, path) + exe = " " + self.output( + "%s%s%s %s%s", rename, self.modified_map[modified], exe, old_path, path + ) def report_changes(change_iterator, reporter): @@ -309,11 +349,11 @@ def report_changes(change_iterator, reporter): :param reporter: The _ChangeReporter that will report the changes. """ versioned_change_map = { - (True, True): 'unchanged', - (True, False): 'removed', - (False, True): 'added', - (False, False): 'unversioned', - } + (True, True): "unchanged", + (True, False): "removed", + (False, True): "added", + (False, False): "unversioned", + } def path_key(change): if change.path[0] is not None: @@ -321,6 +361,7 @@ def path_key(change): else: path = change.path[1] return osutils.splitpath(path) + for change in sorted(change_iterator, key=path_key): exe_change = False # files are "renamed" if they are moved or if name changes, as long @@ -349,14 +390,29 @@ def path_key(change): else: modified = "unchanged" if change.kind[1] == "file": - exe_change = (change.executable[0] != change.executable[1]) + exe_change = change.executable[0] != change.executable[1] versioned_change = versioned_change_map[change.versioned] - reporter.report(change.path, versioned_change, renamed, copied, modified, - exe_change, change.kind) - - -def report_delta(to_file, delta, short_status=False, show_ids=False, - show_unchanged=False, indent='', predicate=None, classify=True): + reporter.report( + change.path, + versioned_change, + renamed, + copied, + modified, + exe_change, + change.kind, + ) + + +def report_delta( + to_file, + delta, + short_status=False, + show_ids=False, + show_unchanged=False, + indent="", + predicate=None, + classify=True, +): """Output this delta in status-like form to to_file. :param to_file: A file-like object where the output is displayed. 
@@ -381,48 +437,56 @@ def report_delta(to_file, delta, short_status=False, show_ids=False, def decorate_path(path, kind, meta_modified=None): if not classify: return path - if kind == 'directory': - path += '/' - elif kind == 'symlink': - path += '@' + if kind == "directory": + path += "/" + elif kind == "symlink": + path += "@" if meta_modified: - path += '*' + path += "*" return path def show_more_renamed(item): dec_new_path = decorate_path(item.path[1], item.kind[1], item.meta_modified()) - to_file.write(f' => {dec_new_path}') + to_file.write(f" => {dec_new_path}") if item.changed_content or item.meta_modified(): - extra_modified.append(InventoryTreeChange( - item.file_id, (item.path[1], item.path[1]), - item.changed_content, - item.versioned, - (item.parent_id[1], item.parent_id[1]), - (item.name[1], item.name[1]), - (item.kind[1], item.kind[1]), - item.executable)) + extra_modified.append( + InventoryTreeChange( + item.file_id, + (item.path[1], item.path[1]), + item.changed_content, + item.versioned, + (item.parent_id[1], item.parent_id[1]), + (item.name[1], item.name[1]), + (item.kind[1], item.kind[1]), + item.executable, + ) + ) def show_more_kind_changed(item): - to_file.write(f' ({item.kind[0]} => {item.kind[1]})') + to_file.write(f" ({item.kind[0]} => {item.kind[1]})") - def show_path(path, kind, meta_modified, - default_format, with_file_id_format): + def show_path(path, kind, meta_modified, default_format, with_file_id_format): dec_path = decorate_path(path, kind, meta_modified) if show_ids: to_file.write(with_file_id_format % dec_path) else: to_file.write(default_format % dec_path) - def show_list(files, long_status_name, short_status_letter, - default_format='%s', with_file_id_format='%-30s', - show_more=None): + def show_list( + files, + long_status_name, + short_status_letter, + default_format="%s", + with_file_id_format="%-30s", + show_more=None, + ): if files: header_shown = False if short_status: prefix = short_status_letter else: - prefix = '' - prefix = indent + prefix + ' ' + prefix = "" + prefix = indent + prefix + " " for item in files: if item.path[0] is None: @@ -434,30 +498,49 @@ def show_list(files, long_status_name, short_status_letter, if predicate is not None and not predicate(path): continue if not header_shown and not short_status: - to_file.write(indent + long_status_name + ':\n') + to_file.write(indent + long_status_name + ":\n") header_shown = True to_file.write(prefix) - show_path(path, kind, item.meta_modified(), - default_format, with_file_id_format) + show_path( + path, + kind, + item.meta_modified(), + default_format, + with_file_id_format, + ) if show_more is not None: show_more(item) - if show_ids and getattr(item, 'file_id', None): + if show_ids and getattr(item, "file_id", None): to_file.write(f" {item.file_id.decode('utf-8')}") - to_file.write('\n') + to_file.write("\n") - show_list(delta.removed, 'removed', 'D') - show_list(delta.added, 'added', 'A') - show_list(delta.missing, 'missing', '!') + show_list(delta.removed, "removed", "D") + show_list(delta.added, "added", "A") + show_list(delta.missing, "missing", "!") extra_modified = [] - show_list(delta.renamed, 'renamed', 'R', with_file_id_format='%s', - show_more=show_more_renamed) - show_list(delta.copied, 'copied', 'C', with_file_id_format='%s', - show_more=show_more_renamed) - show_list(delta.kind_changed, 'kind changed', 'K', - with_file_id_format='%s', - show_more=show_more_kind_changed) - show_list(delta.modified + extra_modified, 'modified', 'M') + show_list( + delta.renamed, + 
"renamed", + "R", + with_file_id_format="%s", + show_more=show_more_renamed, + ) + show_list( + delta.copied, + "copied", + "C", + with_file_id_format="%s", + show_more=show_more_renamed, + ) + show_list( + delta.kind_changed, + "kind changed", + "K", + with_file_id_format="%s", + show_more=show_more_kind_changed, + ) + show_list(delta.modified + extra_modified, "modified", "M") if show_unchanged: - show_list(delta.unchanged, 'unchanged', 'S') + show_list(delta.unchanged, "unchanged", "S") - show_list(delta.unversioned, 'unknown', ' ') + show_list(delta.unversioned, "unknown", " ") diff --git a/breezy/diff.py b/breezy/diff.py index 50d9525d66..4fb0ad6825 100644 --- a/breezy/diff.py +++ b/breezy/diff.py @@ -23,7 +23,9 @@ from .lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ import subprocess from breezy import ( @@ -35,7 +37,8 @@ from breezy.workingtree import WorkingTree from breezy.i18n import gettext -""") +""", +) from . import errors, osutils from . import transport as _mod_transport @@ -60,9 +63,17 @@ def __init__(self, matching_blocks): self.opcodes = None -def internal_diff(old_label, oldlines, new_label, newlines, to_file, - allow_binary=False, sequence_matcher=None, - path_encoding='utf8', context_lines=DEFAULT_CONTEXT_AMOUNT): +def internal_diff( + old_label, + oldlines, + new_label, + newlines, + to_file, + allow_binary=False, + sequence_matcher=None, + path_encoding="utf8", + context_lines=DEFAULT_CONTEXT_AMOUNT, +): # FIXME: difflib is wrong if there is no trailing newline. # The syntax used by patch seems to be "\ No newline at # end of file" following the last diff line from that @@ -79,12 +90,16 @@ def internal_diff(old_label, oldlines, new_label, newlines, to_file, if sequence_matcher is None: import patiencediff + sequence_matcher = patiencediff.PatienceSequenceMatcher ud = unified_diff_bytes( - oldlines, newlines, - fromfile=old_label.encode(path_encoding, 'replace'), - tofile=new_label.encode(path_encoding, 'replace'), - n=context_lines, sequencematcher=sequence_matcher) + oldlines, + newlines, + fromfile=old_label.encode(path_encoding, "replace"), + tofile=new_label.encode(path_encoding, "replace"), + n=context_lines, + sequencematcher=sequence_matcher, + ) ud = list(ud) if len(ud) == 0: # Identical contents, nothing to do @@ -92,19 +107,28 @@ def internal_diff(old_label, oldlines, new_label, newlines, to_file, # work-around for difflib being too smart for its own good # if /dev/null is "1,0", patch won't recognize it as /dev/null if not oldlines: - ud[2] = ud[2].replace(b'-1,0', b'-0,0') + ud[2] = ud[2].replace(b"-1,0", b"-0,0") elif not newlines: - ud[2] = ud[2].replace(b'+1,0', b'+0,0') + ud[2] = ud[2].replace(b"+1,0", b"+0,0") for line in ud: to_file.write(line) - if not line.endswith(b'\n'): + if not line.endswith(b"\n"): to_file.write(b"\n\\ No newline at end of file\n") - to_file.write(b'\n') - - -def unified_diff_bytes(a, b, fromfile=b'', tofile=b'', fromfiledate=b'', - tofiledate=b'', n=3, lineterm=b'\n', sequencematcher=None): + to_file.write(b"\n") + + +def unified_diff_bytes( + a, + b, + fromfile=b"", + tofile=b"", + fromfiledate=b"", + tofiledate=b"", + n=3, + lineterm=b"\n", + sequencematcher=None, +): r"""Compare two sequences of lines; generate the delta as a unified diff. 
Unified diffs are a compact way of showing line changes and a few @@ -145,29 +169,29 @@ def unified_diff_bytes(a, b, fromfile=b'', tofile=b'', fromfiledate=b'', sequencematcher = difflib.SequenceMatcher if fromfiledate: - fromfiledate = b'\t' + bytes(fromfiledate) + fromfiledate = b"\t" + bytes(fromfiledate) if tofiledate: - tofiledate = b'\t' + bytes(tofiledate) + tofiledate = b"\t" + bytes(tofiledate) started = False for group in sequencematcher(None, a, b).get_grouped_opcodes(n): if not started: - yield b'--- %s%s%s' % (fromfile, fromfiledate, lineterm) - yield b'+++ %s%s%s' % (tofile, tofiledate, lineterm) + yield b"--- %s%s%s" % (fromfile, fromfiledate, lineterm) + yield b"+++ %s%s%s" % (tofile, tofiledate, lineterm) started = True i1, i2, j1, j2 = group[0][1], group[-1][2], group[0][3], group[-1][4] yield b"@@ -%d,%d +%d,%d @@%s" % (i1 + 1, i2 - i1, j1 + 1, j2 - j1, lineterm) for tag, i1, i2, j1, j2 in group: - if tag == 'equal': + if tag == "equal": for line in a[i1:i2]: - yield b' ' + line + yield b" " + line continue - if tag == 'replace' or tag == 'delete': + if tag == "replace" or tag == "delete": for line in a[i1:i2]: - yield b'-' + line - if tag == 'replace' or tag == 'insert': + yield b"-" + line + if tag == "replace" or tag == "insert": for line in b[j1:j2]: - yield b'+' + line + yield b"+" + line def _spawn_external_diff(diffcmd, capture_errors=True): @@ -182,23 +206,25 @@ def _spawn_external_diff(diffcmd, capture_errors=True): if capture_errors: # construct minimal environment env = {} - path = os.environ.get('PATH') + path = os.environ.get("PATH") if path is not None: - env['PATH'] = path - env['LANGUAGE'] = 'C' # on win32 only LANGUAGE has effect - env['LANG'] = 'C' - env['LC_ALL'] = 'C' + env["PATH"] = path + env["LANGUAGE"] = "C" # on win32 only LANGUAGE has effect + env["LANG"] = "C" + env["LC_ALL"] = "C" stderr = subprocess.PIPE else: env = None stderr = None try: - pipe = subprocess.Popen(diffcmd, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=stderr, - env=env) + pipe = subprocess.Popen( + diffcmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=stderr, + env=env, + ) except FileNotFoundError as e: raise errors.NoDiff(str(e)) from e @@ -206,15 +232,27 @@ def _spawn_external_diff(diffcmd, capture_errors=True): # diff style options as of GNU diff v3.2 -style_option_list = ['-c', '-C', '--context', - '-e', '--ed', - '-f', '--forward-ed', - '-q', '--brief', - '--normal', - '-n', '--rcs', - '-u', '-U', '--unified', - '-y', '--side-by-side', - '-D', '--ifdef'] +style_option_list = [ + "-c", + "-C", + "--context", + "-e", + "--ed", + "-f", + "--forward-ed", + "-q", + "--brief", + "--normal", + "-n", + "--rcs", + "-u", + "-U", + "--unified", + "-y", + "--side-by-side", + "-D", + "--ifdef", +] def default_style_unified(diff_opts): @@ -237,22 +275,21 @@ def default_style_unified(diff_opts): continue break else: - diff_opts.append('-u') + diff_opts.append("-u") return diff_opts -def external_diff(old_label, oldlines, new_label, newlines, to_file, - diff_opts): +def external_diff(old_label, oldlines, new_label, newlines, to_file, diff_opts): """Display a diff by calling out to the external diff program.""" import tempfile # make sure our own output is properly ordered before the diff to_file.flush() - oldtmp_fd, old_abspath = tempfile.mkstemp(prefix='brz-diff-old-') - newtmp_fd, new_abspath = tempfile.mkstemp(prefix='brz-diff-new-') - oldtmpf = os.fdopen(oldtmp_fd, 'wb') - newtmpf = os.fdopen(newtmp_fd, 'wb') + oldtmp_fd, old_abspath = 
tempfile.mkstemp(prefix="brz-diff-old-") + newtmp_fd, new_abspath = tempfile.mkstemp(prefix="brz-diff-new-") + oldtmpf = os.fdopen(oldtmp_fd, "wb") + newtmpf = os.fdopen(newtmp_fd, "wb") try: # TODO: perhaps a special case for comparing to or from the empty @@ -270,18 +307,21 @@ def external_diff(old_label, oldlines, new_label, newlines, to_file, if not diff_opts: diff_opts = [] - if sys.platform == 'win32': + if sys.platform == "win32": # Popen doesn't do the proper encoding for external commands # Since we are dealing with an ANSI api, use mbcs encoding - old_label = old_label.encode('mbcs') - new_label = new_label.encode('mbcs') - diffcmd = ['diff', - '--label', old_label, - old_abspath, - '--label', new_label, - new_abspath, - '--binary', - ] + old_label = old_label.encode("mbcs") + new_label = new_label.encode("mbcs") + diffcmd = [ + "diff", + "--label", + old_label, + old_abspath, + "--label", + new_label, + new_abspath, + "--binary", + ] diff_opts = default_style_unified(diff_opts) @@ -293,7 +333,7 @@ def external_diff(old_label, oldlines, new_label, newlines, to_file, rc = pipe.returncode # internal_diff() adds a trailing newline, add one here for consistency - out += b'\n' + out += b"\n" if rc == 2: # 'diff' gives retcode == 2 for all sorts of errors # one of those is 'Binary files differ'. @@ -306,19 +346,21 @@ def external_diff(old_label, oldlines, new_label, newlines, to_file, out, err = pipe.communicate() # Write out the new i18n diff response - to_file.write(out + b'\n') + to_file.write(out + b"\n") if pipe.returncode != 2: raise errors.BzrError( - 'external diff failed with exit code 2' - ' when run with LANG=C and LC_ALL=C,' - f' but not when run natively: {diffcmd!r}') + "external diff failed with exit code 2" + " when run with LANG=C and LC_ALL=C," + f" but not when run natively: {diffcmd!r}" + ) - first_line = lang_c_out.split(b'\n', 1)[0] + first_line = lang_c_out.split(b"\n", 1)[0] # Starting with diffutils 2.8.4 the word "binary" was dropped. - m = re.match(b'^(binary )?files.*differ$', first_line, re.I) + m = re.match(b"^(binary )?files.*differ$", first_line, re.I) if m is None: - raise errors.BzrError('external diff failed with exit code 2;' - f' command: {diffcmd!r}') + raise errors.BzrError( + "external diff failed with exit code 2;" f" command: {diffcmd!r}" + ) else: # Binary files differ, just return return @@ -329,14 +371,16 @@ def external_diff(old_label, oldlines, new_label, newlines, to_file, if rc not in (0, 1): # returns 1 if files differ; that's OK if rc < 0: - msg = 'signal %d' % (-rc) + msg = "signal %d" % (-rc) else: - msg = 'exit code %d' % rc + msg = "exit code %d" % rc - raise errors.BzrError(f'external diff failed with {msg}; command: {diffcmd!r}') + raise errors.BzrError( + f"external diff failed with {msg}; command: {diffcmd!r}" + ) finally: - oldtmpf.close() # and delete + oldtmpf.close() # and delete newtmpf.close() def cleanup(path): @@ -348,14 +392,15 @@ def cleanup(path): except FileNotFoundError: pass except OSError as e: - warning('Failed to delete temporary file: %s %s', path, e) + warning("Failed to delete temporary file: %s %s", path, e) cleanup(old_abspath) cleanup(new_abspath) def get_trees_and_branches_to_diff_locked( - path_list, revision_specs, old_url, new_url, exit_stack, apply_view=True): + path_list, revision_specs, old_url, new_url, exit_stack, apply_view=True +): """Get the trees and specific files to diff given a list of paths. 
This method works out the trees to be diff'ed and the files of @@ -403,7 +448,7 @@ def get_trees_and_branches_to_diff_locked( consider_relpath = True if path_list is None or len(path_list) == 0: # If no path is given, the current working tree is used - default_location = '.' + default_location = "." consider_relpath = False elif old_url is not None and new_url is not None: other_paths = path_list @@ -422,10 +467,13 @@ def lock_tree_or_branch(wt, br): specific_files = [] if old_url is None: old_url = default_location - working_tree, branch, relpath = \ - controldir.ControlDir.open_containing_tree_or_branch(old_url) + ( + working_tree, + branch, + relpath, + ) = controldir.ControlDir.open_containing_tree_or_branch(old_url) lock_tree_or_branch(working_tree, branch) - if consider_relpath and relpath != '': + if consider_relpath and relpath != "": if working_tree is not None and apply_view: views.check_path_in_view(working_tree, relpath) specific_files.append(relpath) @@ -436,27 +484,30 @@ def lock_tree_or_branch(wt, br): if new_url is None: new_url = default_location if new_url != old_url: - working_tree, branch, relpath = \ - controldir.ControlDir.open_containing_tree_or_branch(new_url) + ( + working_tree, + branch, + relpath, + ) = controldir.ControlDir.open_containing_tree_or_branch(new_url) lock_tree_or_branch(working_tree, branch) - if consider_relpath and relpath != '': + if consider_relpath and relpath != "": if working_tree is not None and apply_view: views.check_path_in_view(working_tree, relpath) specific_files.append(relpath) - new_tree = _get_tree_to_diff(new_revision_spec, working_tree, branch, - basis_is_default=working_tree is None) + new_tree = _get_tree_to_diff( + new_revision_spec, working_tree, branch, basis_is_default=working_tree is None + ) new_branch = branch # Get the specific files (all files is None, no files is []) if make_paths_wt_relative and working_tree is not None: other_paths = working_tree.safe_relpath_files( - other_paths, - apply_view=apply_view) + other_paths, apply_view=apply_view + ) specific_files.extend(other_paths) if len(specific_files) == 0: specific_files = None - if (working_tree is not None and working_tree.supports_views() and - apply_view): + if working_tree is not None and working_tree.supports_views() and apply_view: view_files = working_tree.views.lookup_view() if view_files: specific_files = view_files @@ -467,8 +518,7 @@ def lock_tree_or_branch(wt, br): extra_trees = None if working_tree is not None and working_tree not in (old_tree, new_tree): extra_trees = (working_tree,) - return (old_tree, new_tree, old_branch, new_branch, - specific_files, extra_trees) + return (old_tree, new_tree, old_branch, new_branch, specific_files, extra_trees) def _get_tree_to_diff(spec, tree=None, branch=None, basis_is_default=True): @@ -485,14 +535,20 @@ def _get_tree_to_diff(spec, tree=None, branch=None, basis_is_default=True): return spec.as_tree(branch) -def show_diff_trees(old_tree, new_tree, to_file, specific_files=None, - external_diff_options=None, - old_label: str = 'a/', new_label: str = 'b/', - extra_trees=None, - path_encoding: str = 'utf8', - using: Optional[str] = None, - format_cls=None, - context=DEFAULT_CONTEXT_AMOUNT): +def show_diff_trees( + old_tree, + new_tree, + to_file, + specific_files=None, + external_diff_options=None, + old_label: str = "a/", + new_label: str = "b/", + extra_trees=None, + path_encoding: str = "utf8", + using: Optional[str] = None, + format_cls=None, + context=DEFAULT_CONTEXT_AMOUNT, +): """Show in text form the 
changes from one tree to another. :param to_file: The output stream. @@ -515,11 +571,17 @@ def show_diff_trees(old_tree, new_tree, to_file, specific_files=None, for tree in extra_trees: exit_stack.enter_context(tree.lock_read()) exit_stack.enter_context(new_tree.lock_read()) - differ = format_cls.from_trees_options(old_tree, new_tree, to_file, - path_encoding, - external_diff_options, - old_label, new_label, using, - context_lines=context) + differ = format_cls.from_trees_options( + old_tree, + new_tree, + to_file, + path_encoding, + external_diff_options, + old_label, + new_label, + using, + context_lines=context, + ) return differ.show_diff(specific_files, extra_trees) @@ -535,7 +597,13 @@ def _patch_header_date(tree, path): def get_executable_change(old_is_x, new_is_x): descr = {True: b"+x", False: b"-x", None: b"??"} if old_is_x != new_is_x: - return [b"%s to %s" % (descr[old_is_x], descr[new_is_x],)] + return [ + b"%s to %s" + % ( + descr[old_is_x], + descr[new_is_x], + ) + ] else: return [] @@ -544,13 +612,13 @@ class DiffPath: """Base type for command object that compare files.""" # The type or contents of the file were unsuitable for diffing - CANNOT_DIFF = 'CANNOT_DIFF' + CANNOT_DIFF = "CANNOT_DIFF" # The file has changed in a semantic way - CHANGED = 'CHANGED' + CHANGED = "CHANGED" # The file content may have changed, but there is no semantic change - UNCHANGED = 'UNCHANGED' + UNCHANGED = "UNCHANGED" - def __init__(self, old_tree, new_tree, to_file, path_encoding='utf-8'): + def __init__(self, old_tree, new_tree, to_file, path_encoding="utf-8"): """Constructor. :param old_tree: The tree to show as the old tree in the comparison @@ -568,8 +636,12 @@ def finish(self): @classmethod def from_diff_tree(klass, diff_tree): - return klass(diff_tree.old_tree, diff_tree.new_tree, - diff_tree.to_file, diff_tree.path_encoding) + return klass( + diff_tree.old_tree, + diff_tree.new_tree, + diff_tree.to_file, + diff_tree.path_encoding, + ) @staticmethod def _diff_many(differs, old_path, new_path, old_kind, new_kind): @@ -608,42 +680,37 @@ def diff(self, old_path, new_path, old_kind, new_kind): """ if None in (old_kind, new_kind): return DiffPath.CANNOT_DIFF - result = DiffPath._diff_many( - self.differs, old_path, new_path, old_kind, None) + result = DiffPath._diff_many(self.differs, old_path, new_path, old_kind, None) if result is DiffPath.CANNOT_DIFF: return result - return DiffPath._diff_many( - self.differs, old_path, new_path, None, new_kind) + return DiffPath._diff_many(self.differs, old_path, new_path, None, new_kind) class DiffTreeReference(DiffPath): - def diff(self, old_path, new_path, old_kind, new_kind): """Perform comparison between two tree references. (dummy).""" - if 'tree-reference' not in (old_kind, new_kind): + if "tree-reference" not in (old_kind, new_kind): return self.CANNOT_DIFF - if old_kind not in ('tree-reference', None): + if old_kind not in ("tree-reference", None): return self.CANNOT_DIFF - if new_kind not in ('tree-reference', None): + if new_kind not in ("tree-reference", None): return self.CANNOT_DIFF return self.CHANGED class DiffDirectory(DiffPath): - def diff(self, old_path, new_path, old_kind, new_kind): """Perform comparison between two directories. 
(dummy).""" - if 'directory' not in (old_kind, new_kind): + if "directory" not in (old_kind, new_kind): return self.CANNOT_DIFF - if old_kind not in ('directory', None): + if old_kind not in ("directory", None): return self.CANNOT_DIFF - if new_kind not in ('directory', None): + if new_kind not in ("directory", None): return self.CANNOT_DIFF return self.CHANGED class DiffSymlink(DiffPath): - def diff(self, old_path, new_path, old_kind, new_kind): """Perform comparison between two symlinks. @@ -652,15 +719,15 @@ def diff(self, old_path, new_path, old_kind, new_kind): :param old_kind: Old file-kind of the file :param new_kind: New file-kind of the file """ - if 'symlink' not in (old_kind, new_kind): + if "symlink" not in (old_kind, new_kind): return self.CANNOT_DIFF - if old_kind == 'symlink': + if old_kind == "symlink": old_target = self.old_tree.get_symlink_target(old_path) elif old_kind is None: old_target = None else: return self.CANNOT_DIFF - if new_kind == 'symlink': + if new_kind == "symlink": new_target = self.new_tree.get_symlink_target(new_path) elif new_kind is None: new_target = None @@ -670,27 +737,42 @@ def diff(self, old_path, new_path, old_kind, new_kind): def diff_symlink(self, old_target, new_target): if old_target is None: - self.to_file.write(b'=== target is \'%s\'\n' % - new_target.encode(self.path_encoding, 'replace')) + self.to_file.write( + b"=== target is '%s'\n" + % new_target.encode(self.path_encoding, "replace") + ) elif new_target is None: - self.to_file.write(b'=== target was \'%s\'\n' % - old_target.encode(self.path_encoding, 'replace')) + self.to_file.write( + b"=== target was '%s'\n" + % old_target.encode(self.path_encoding, "replace") + ) else: - self.to_file.write(b'=== target changed \'%s\' => \'%s\'\n' % - (old_target.encode(self.path_encoding, 'replace'), - new_target.encode(self.path_encoding, 'replace'))) + self.to_file.write( + b"=== target changed '%s' => '%s'\n" + % ( + old_target.encode(self.path_encoding, "replace"), + new_target.encode(self.path_encoding, "replace"), + ) + ) return self.CHANGED class DiffText(DiffPath): - # GNU Patch uses the epoch date to detect files that are being added # or removed in a diff. 
- EPOCH_DATE = '1970-01-01 00:00:00 +0000' - - def __init__(self, old_tree, new_tree, to_file, path_encoding='utf-8', - old_label='', new_label='', text_differ=internal_diff, - context_lines=DEFAULT_CONTEXT_AMOUNT): + EPOCH_DATE = "1970-01-01 00:00:00 +0000" + + def __init__( + self, + old_tree, + new_tree, + to_file, + path_encoding="utf-8", + old_label="", + new_label="", + text_differ=internal_diff, + context_lines=DEFAULT_CONTEXT_AMOUNT, + ): DiffPath.__init__(self, old_tree, new_tree, to_file, path_encoding) self.text_differ = text_differ self.old_label = old_label @@ -706,22 +788,22 @@ def diff(self, old_path, new_path, old_kind, new_kind): :param old_kind: Old file-kind of the file :param new_kind: New file-kind of the file """ - if 'file' not in (old_kind, new_kind): + if "file" not in (old_kind, new_kind): return self.CANNOT_DIFF - if old_kind == 'file': + if old_kind == "file": old_date = _patch_header_date(self.old_tree, old_path) elif old_kind is None: old_date = self.EPOCH_DATE else: return self.CANNOT_DIFF - if new_kind == 'file': + if new_kind == "file": new_date = _patch_header_date(self.new_tree, new_path) elif new_kind is None: new_date = self.EPOCH_DATE else: return self.CANNOT_DIFF - from_label = f'{self.old_label}{old_path or new_path}\t{old_date}' - to_label = f'{self.new_label}{new_path or old_path}\t{new_date}' + from_label = f"{self.old_label}{old_path or new_path}\t{old_date}" + to_label = f"{self.new_label}{new_path or old_path}\t{new_date}" return self.diff_text(old_path, new_path, from_label, to_label) def diff_text(self, from_path, to_path, from_label, to_label): @@ -733,6 +815,7 @@ def diff_text(self, from_path, to_path, from_label, to_label): to a different file from from_path. If None, the file is not present in the to tree. 
""" + def _get_text(tree, path): if path is None: return [] @@ -740,37 +823,58 @@ def _get_text(tree, path): return tree.get_file_lines(path) except _mod_transport.NoSuchFile: return [] + try: from_text = _get_text(self.old_tree, from_path) to_text = _get_text(self.new_tree, to_path) - self.text_differ(from_label, from_text, to_label, to_text, - self.to_file, path_encoding=self.path_encoding, - context_lines=self.context_lines) + self.text_differ( + from_label, + from_text, + to_label, + to_text, + self.to_file, + path_encoding=self.path_encoding, + context_lines=self.context_lines, + ) except errors.BinaryFile: self.to_file.write( - ("Binary files {}{} and {}{} differ\n".format(self.old_label, from_path or to_path, - self.new_label, to_path or from_path) - ).encode(self.path_encoding, 'replace')) + ( + "Binary files {}{} and {}{} differ\n".format( + self.old_label, + from_path or to_path, + self.new_label, + to_path or from_path, + ) + ).encode(self.path_encoding, "replace") + ) return self.CHANGED class DiffFromTool(DiffPath): - - def __init__(self, command_template: Union[str, List[str]], - old_tree: Tree, new_tree: Tree, to_file, - path_encoding='utf-8'): + def __init__( + self, + command_template: Union[str, List[str]], + old_tree: Tree, + new_tree: Tree, + to_file, + path_encoding="utf-8", + ): import tempfile + DiffPath.__init__(self, old_tree, new_tree, to_file, path_encoding) self.command_template = command_template - self._root = tempfile.mkdtemp(prefix='brz-diff-') + self._root = tempfile.mkdtemp(prefix="brz-diff-") @classmethod - def from_string(klass, - command_template: Union[str, List[str]], - old_tree: Tree, new_tree: Tree, to_file, - path_encoding: str = 'utf-8'): - return klass(command_template, old_tree, new_tree, to_file, - path_encoding) + def from_string( + klass, + command_template: Union[str, List[str]], + old_tree: Tree, + new_tree: Tree, + to_file, + path_encoding: str = "utf-8", + ): + return klass(command_template, old_tree, new_tree, to_file, path_encoding) @classmethod def make_from_diff_tree(klass, command_string, external_diff_options=None): @@ -778,21 +882,25 @@ def from_diff_tree(diff_tree): full_command_string = [command_string] if external_diff_options is not None: full_command_string.extend(external_diff_options.split()) - return klass.from_string(full_command_string, diff_tree.old_tree, - diff_tree.new_tree, diff_tree.to_file) + return klass.from_string( + full_command_string, + diff_tree.old_tree, + diff_tree.new_tree, + diff_tree.to_file, + ) + return from_diff_tree def _get_command(self, old_path, new_path): - my_map = {'old_path': old_path, 'new_path': new_path} - command = [t.format(**my_map) for t in - self.command_template] + my_map = {"old_path": old_path, "new_path": new_path} + command = [t.format(**my_map) for t in self.command_template] if command == self.command_template: command += [old_path, new_path] - if sys.platform == 'win32': # Popen doesn't accept unicode on win32 + if sys.platform == "win32": # Popen doesn't accept unicode on win32 command_encoded = [] for c in command: if isinstance(c, str): - command_encoded.append(c.encode('mbcs')) + command_encoded.append(c.encode("mbcs")) else: command_encoded.append(c) return command_encoded @@ -802,8 +910,7 @@ def _get_command(self, old_path, new_path): def _execute(self, old_path, new_path): command = self._get_command(old_path, new_path) try: - proc = subprocess.Popen(command, stdout=subprocess.PIPE, - cwd=self._root) + proc = subprocess.Popen(command, stdout=subprocess.PIPE, 
cwd=self._root) except FileNotFoundError as e: raise errors.ExecutableMissing(command[0]) from e self.to_file.write(proc.stdout.read()) @@ -811,11 +918,12 @@ def _execute(self, old_path, new_path): return proc.wait() def _try_symlink_root(self, tree, prefix): - if (getattr(tree, 'abspath', None) is None or - not osutils.supports_symlinks(self._root)): + if getattr(tree, "abspath", None) is None or not osutils.supports_symlinks( + self._root + ): return False try: - os.symlink(tree.abspath(''), osutils.pathjoin(self._root, prefix)) + os.symlink(tree.abspath(""), osutils.pathjoin(self._root, prefix)) except FileExistsError: pass return True @@ -823,12 +931,12 @@ def _try_symlink_root(self, tree, prefix): @staticmethod def _fenc(): """Returns safe encoding for passing file path to diff tool.""" - if sys.platform == 'win32': - return 'mbcs' + if sys.platform == "win32": + return "mbcs" else: # Don't fallback to 'utf-8' because subprocess may not be able to # handle utf-8 correctly when locale is not utf-8. - return sys.getfilesystemencoding() or 'ascii' + return sys.getfilesystemencoding() or "ascii" def _is_safepath(self, path): """Return true if `path` may be able to pass to subprocess.""" @@ -845,12 +953,11 @@ def _safe_filename(self, prefix, relpath): fenc = self._fenc() # encoded_str.replace('?', '_') may break multibyte char. # So we should encode, decode, then replace(u'?', u'_') - relpath_tmp = relpath.encode(fenc, 'replace').decode(fenc, 'replace') - relpath_tmp = relpath_tmp.replace('?', '_') + relpath_tmp = relpath.encode(fenc, "replace").decode(fenc, "replace") + relpath_tmp = relpath_tmp.replace("?", "_") return osutils.pathjoin(self._root, prefix, relpath_tmp) - def _write_file(self, relpath, tree, prefix, force_temp=False, - allow_write=False): + def _write_file(self, relpath, tree, prefix, force_temp=False, allow_write=False): if not force_temp and isinstance(tree, WorkingTree): full_path = tree.abspath(relpath) if self._is_safepath(full_path): @@ -864,8 +971,7 @@ def _write_file(self, relpath, tree, prefix, force_temp=False, os.makedirs(parent_dir) except FileExistsError: pass - with tree.get_file(relpath) as source, \ - open(full_path, 'wb') as target: + with tree.get_file(relpath) as source, open(full_path, "wb") as target: osutils.pumpfile(source, target) try: mtime = tree.get_file_mtime(relpath) @@ -877,13 +983,13 @@ def _write_file(self, relpath, tree, prefix, force_temp=False, osutils.make_readonly(full_path) return full_path - def _prepare_files(self, old_path, new_path, force_temp=False, - allow_write_new=False): - old_disk_path = self._write_file( - old_path, self.old_tree, 'old', force_temp) + def _prepare_files( + self, old_path, new_path, force_temp=False, allow_write_new=False + ): + old_disk_path = self._write_file(old_path, self.old_tree, "old", force_temp) new_disk_path = self._write_file( - new_path, self.new_tree, 'new', force_temp, - allow_write=allow_write_new) + new_path, self.new_tree, "new", force_temp, allow_write=allow_write_new + ) return old_disk_path, new_disk_path def finish(self): @@ -892,14 +998,15 @@ def finish(self): except FileNotFoundError: pass except OSError as e: - mutter(f"The temporary directory \"{self._root}\" was not " - f"cleanly removed: {e}.") + mutter( + f'The temporary directory "{self._root}" was not ' + f"cleanly removed: {e}." 
+ ) def diff(self, old_path, new_path, old_kind, new_kind): - if (old_kind, new_kind) != ('file', 'file'): + if (old_kind, new_kind) != ("file", "file"): return DiffPath.CANNOT_DIFF - (old_disk_path, new_disk_path) = self._prepare_files( - old_path, new_path) + (old_disk_path, new_disk_path) = self._prepare_files(old_path, new_path) self._execute(old_disk_path, new_disk_path) def edit_file(self, old_path, new_path): @@ -911,10 +1018,11 @@ def edit_file(self, old_path, new_path): :return: The new contents of the file. """ old_abs_path, new_abs_path = self._prepare_files( - old_path, new_path, allow_write_new=True, force_temp=True) + old_path, new_path, allow_write_new=True, force_temp=True + ) command = self._get_command(old_abs_path, new_abs_path) subprocess.call(command, cwd=self._root) - with open(new_abs_path, 'rb') as new_file: + with open(new_abs_path, "rb") as new_file: return new_file.read() @@ -932,12 +1040,21 @@ class DiffTree: # list of factories that can provide instances of DiffPath objects # may be extended by plugins. - diff_factories = [DiffSymlink.from_diff_tree, - DiffDirectory.from_diff_tree, - DiffTreeReference.from_diff_tree] - - def __init__(self, old_tree, new_tree, to_file, path_encoding='utf-8', - diff_text=None, extra_factories=None): + diff_factories = [ + DiffSymlink.from_diff_tree, + DiffDirectory.from_diff_tree, + DiffTreeReference.from_diff_tree, + ] + + def __init__( + self, + old_tree, + new_tree, + to_file, + path_encoding="utf-8", + diff_text=None, + extra_factories=None, + ): """Constructor. :param old_tree: Tree to show as old in the comparison @@ -950,8 +1067,9 @@ def __init__(self, old_tree, new_tree, to_file, path_encoding='utf-8', DiffPaths """ if diff_text is None: - diff_text = DiffText(old_tree, new_tree, to_file, path_encoding, - '', '', internal_diff) + diff_text = DiffText( + old_tree, new_tree, to_file, path_encoding, "", "", internal_diff + ) self.old_tree = old_tree self.new_tree = new_tree self.to_file = to_file @@ -963,9 +1081,18 @@ def __init__(self, old_tree, new_tree, to_file, path_encoding='utf-8', self.differs.extend([diff_text, DiffKindChange.from_diff_tree(self)]) @classmethod - def from_trees_options(klass, old_tree, new_tree, to_file, - path_encoding, external_diff_options, old_label, - new_label, using, context_lines): + def from_trees_options( + klass, + old_tree, + new_tree, + to_file, + path_encoding, + external_diff_options, + old_label, + new_label, + using, + context_lines, + ): """Factory for producing a DiffTree. Designed to accept options used by show_diff_trees. @@ -981,24 +1108,42 @@ def from_trees_options(klass, old_tree, new_tree, to_file, :param using: Commandline to use to invoke an external diff tool """ if using is not None: - extra_factories = [DiffFromTool.make_from_diff_tree( - using, external_diff_options)] + extra_factories = [ + DiffFromTool.make_from_diff_tree(using, external_diff_options) + ] else: extra_factories = [] if external_diff_options: opts = external_diff_options.split() - def diff_file(olab, olines, nlab, nlines, to_file, path_encoding=None, context_lines=None): + def diff_file( + olab, + olines, + nlab, + nlines, + to_file, + path_encoding=None, + context_lines=None, + ): """:param path_encoding: not used but required to match the signature of internal_diff. 
""" external_diff(olab, olines, nlab, nlines, to_file, opts) else: diff_file = internal_diff - diff_text = DiffText(old_tree, new_tree, to_file, path_encoding, - old_label, new_label, diff_file, context_lines=context_lines) - return klass(old_tree, new_tree, to_file, path_encoding, diff_text, - extra_factories) + diff_text = DiffText( + old_tree, + new_tree, + to_file, + path_encoding, + old_label, + new_label, + diff_file, + context_lines=context_lines, + ) + return klass( + old_tree, new_tree, to_file, path_encoding, diff_text, extra_factories + ) def show_diff(self, specific_files, extra_trees=None): """Write tree diff to self.to_file. @@ -1015,10 +1160,12 @@ def show_diff(self, specific_files, extra_trees=None): def _show_diff(self, specific_files, extra_trees): # TODO: Generation of pseudo-diffs for added/deleted files could # be usefully made into a much faster special case. - iterator = self.new_tree.iter_changes(self.old_tree, - specific_files=specific_files, - extra_trees=extra_trees, - require_versioned=True) + iterator = self.new_tree.iter_changes( + self.old_tree, + specific_files=specific_files, + extra_trees=extra_trees, + require_versioned=True, + ) has_changes = 0 def changes_key(change): @@ -1031,49 +1178,68 @@ def changes_key(change): def get_encoded_path(path): if path is not None: return path.encode(self.path_encoding, "replace") + for change in sorted(iterator, key=changes_key): # The root does not get diffed, and items with no known kind (that # is, missing) in both trees are skipped as well. - if (not change.path[0] and not change.path[1]) or change.kind == (None, None): + if (not change.path[0] and not change.path[1]) or change.kind == ( + None, + None, + ): continue - if change.kind[0] == 'symlink' and not self.new_tree.supports_symlinks(): + if change.kind[0] == "symlink" and not self.new_tree.supports_symlinks(): warning( f'Ignoring "{change.path[0]}" as symlinks are not ' - 'supported on this filesystem.') + "supported on this filesystem." 
+ ) continue oldpath, newpath = change.path oldpath_encoded = get_encoded_path(oldpath) newpath_encoded = get_encoded_path(newpath) - old_present = (change.kind[0] is not None and change.versioned[0]) - new_present = (change.kind[1] is not None and change.versioned[1]) + old_present = change.kind[0] is not None and change.versioned[0] + new_present = change.kind[1] is not None and change.versioned[1] executable = change.executable kind = change.kind renamed = change.renamed properties_changed = [] properties_changed.extend( - get_executable_change(executable[0], executable[1])) + get_executable_change(executable[0], executable[1]) + ) if properties_changed: prop_str = b" (properties changed: %s)" % ( - b", ".join(properties_changed),) + b", ".join(properties_changed), + ) else: prop_str = b"" if (old_present, new_present) == (True, False): - self.to_file.write(b"=== removed %s '%s'\n" % - (kind[0].encode('ascii'), oldpath_encoded)) + self.to_file.write( + b"=== removed %s '%s'\n" + % (kind[0].encode("ascii"), oldpath_encoded) + ) elif (old_present, new_present) == (False, True): - self.to_file.write(b"=== added %s '%s'\n" % - (kind[1].encode('ascii'), newpath_encoded)) + self.to_file.write( + b"=== added %s '%s'\n" % (kind[1].encode("ascii"), newpath_encoded) + ) elif renamed: - self.to_file.write(b"=== renamed %s '%s' => '%s'%s\n" % - (kind[0].encode('ascii'), oldpath_encoded, newpath_encoded, prop_str)) + self.to_file.write( + b"=== renamed %s '%s' => '%s'%s\n" + % ( + kind[0].encode("ascii"), + oldpath_encoded, + newpath_encoded, + prop_str, + ) + ) else: # if it was produced by iter_changes, it must be # modified *somehow*, either content or execute bit. - self.to_file.write(b"=== modified %s '%s'%s\n" % (kind[0].encode('ascii'), - newpath_encoded, prop_str)) + self.to_file.write( + b"=== modified %s '%s'%s\n" + % (kind[0].encode("ascii"), newpath_encoded, prop_str) + ) if change.changed_content: self._diff(oldpath, newpath, kind[0], kind[1]) has_changes = 1 @@ -1099,7 +1265,8 @@ def diff(self, old_path, new_path): def _diff(self, old_path, new_path, old_kind, new_kind): result = DiffPath._diff_many( - self.differs, old_path, new_path, old_kind, new_kind) + self.differs, old_path, new_path, old_kind, new_kind + ) if result is DiffPath.CANNOT_DIFF: error_path = new_path if error_path is None: @@ -1108,4 +1275,4 @@ def _diff(self, old_path, new_path, old_kind, new_kind): format_registry = Registry[str, Type[DiffTree], None]() -format_registry.register('default', DiffTree) +format_registry.register("default", DiffTree) diff --git a/breezy/directory_service.py b/breezy/directory_service.py index d628ce64b4..8ff7c9d3e9 100644 --- a/breezy/directory_service.py +++ b/breezy/directory_service.py @@ -26,12 +26,15 @@ from . import errors, registry from .lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ from breezy import ( controldir as _mod_controldir, urlutils, ) -""") +""", +) class DirectoryLookupFailure(errors.BzrError): @@ -39,7 +42,6 @@ class DirectoryLookupFailure(errors.BzrError): class InvalidLocationAlias(DirectoryLookupFailure): - _fmt = '"%(alias_name)s" is not a valid location alias.' def __init__(self, alias_name): @@ -47,8 +49,7 @@ def __init__(self, alias_name): class UnsetLocationAlias(DirectoryLookupFailure): - - _fmt = 'No %(alias_name)s location assigned.' + _fmt = "No %(alias_name)s location assigned." 
def __init__(self, alias_name): DirectoryLookupFailure.__init__(self, alias_name=alias_name[1:]) @@ -114,23 +115,37 @@ class AliasDirectory(Directory): supported. On error, a subclass of DirectoryLookupFailure will be raised. """ - branch_aliases = registry.Registry[str, Callable[[_mod_branch.Branch], Optional[str]], None]() - branch_aliases.register('parent', lambda b: b.get_parent(), - help="The parent of this branch.") - branch_aliases.register('submit', lambda b: b.get_submit_branch(), - help="The submit branch for this branch.") - branch_aliases.register('public', lambda b: b.get_public_branch(), - help="The public location of this branch.") - branch_aliases.register('bound', lambda b: b.get_bound_location(), - help="The branch this branch is bound to, for bound branches.") - branch_aliases.register('push', lambda b: b.get_push_location(), - help="The saved location used for `brz push` with no arguments.") - branch_aliases.register('this', lambda b: b.base, - help="This branch.") + branch_aliases = registry.Registry[ + str, Callable[[_mod_branch.Branch], Optional[str]], None + ]() + branch_aliases.register( + "parent", lambda b: b.get_parent(), help="The parent of this branch." + ) + branch_aliases.register( + "submit", + lambda b: b.get_submit_branch(), + help="The submit branch for this branch.", + ) + branch_aliases.register( + "public", + lambda b: b.get_public_branch(), + help="The public location of this branch.", + ) + branch_aliases.register( + "bound", + lambda b: b.get_bound_location(), + help="The branch this branch is bound to, for bound branches.", + ) + branch_aliases.register( + "push", + lambda b: b.get_push_location(), + help="The saved location used for `brz push` with no arguments.", + ) + branch_aliases.register("this", lambda b: b.base, help="This branch.") def look_up(self, name, url, purpose=None): - branch = _mod_branch.Branch.open_containing('.')[0] - parts = url.split('/', 1) + branch = _mod_branch.Branch.open_containing(".")[0] + parts = url.split("/", 1) if len(parts) == 2: name, extra = parts else: @@ -170,8 +185,7 @@ def help_text(cls, topic): """ % "".join(alias_lines) -directories.register(':', AliasDirectory, - 'Easy access to remembered branch locations') +directories.register(":", AliasDirectory, "Easy access to remembered branch locations") class ColocatedDirectory(Directory): @@ -182,10 +196,10 @@ class ColocatedDirectory(Directory): """ def look_up(self, name, url, purpose=None): - dir = _mod_controldir.ControlDir.open_containing('.')[0] + dir = _mod_controldir.ControlDir.open_containing(".")[0] return urlutils.join_segment_parameters( - dir.user_url, {"branch": urlutils.escape(name)}) + dir.user_url, {"branch": urlutils.escape(name)} + ) -directories.register('co:', ColocatedDirectory, - 'Easy access to colocated branches') +directories.register("co:", ColocatedDirectory, "Easy access to colocated branches") diff --git a/breezy/dirty_tracker.py b/breezy/dirty_tracker.py index 92c0c4585e..ff1ec1630a 100644 --- a/breezy/dirty_tracker.py +++ b/breezy/dirty_tracker.py @@ -41,13 +41,11 @@ ) - class TooManyOpenFiles(Exception): """Too many open files.""" class _Process(ProcessEvent): # type: ignore - paths: Set[str] created: Set[str] diff --git a/breezy/doc/__init__.py b/breezy/doc/__init__.py index 43c5663ebd..66685eef36 100644 --- a/breezy/doc/__init__.py +++ b/breezy/doc/__init__.py @@ -27,8 +27,8 @@ def load_tests(loader, basic_tests, pattern): suite.addTests(basic_tests) testmod_names = [ - 'breezy.doc.api', - ] + "breezy.doc.api", + ] # add the 
tests for the sub modules suite.addTests(loader.loadTestsFromModuleNames(testmod_names)) diff --git a/breezy/doc/api/__init__.py b/breezy/doc/api/__init__.py index 82b01bb147..e29fbbe2dc 100644 --- a/breezy/doc/api/__init__.py +++ b/breezy/doc/api/__init__.py @@ -31,7 +31,7 @@ def make_new_test_id(test): - new_id = f'{__name__}.DocFileTest({test.id()})' + new_id = f"{__name__}.DocFileTest({test.id()})" return lambda: new_id @@ -42,11 +42,13 @@ def load_tests(loader, basic_tests, pattern): candidates = os.listdir(dir_) else: candidates = [] - scripts = [candidate for candidate in candidates - if candidate.endswith('.txt')] + scripts = [candidate for candidate in candidates if candidate.endswith(".txt")] # since this module doesn't define tests, we ignore basic_tests - suite = doctest.DocFileSuite(*scripts, setUp=tests.isolated_doctest_setUp, - tearDown=tests.isolated_doctest_tearDown) + suite = doctest.DocFileSuite( + *scripts, + setUp=tests.isolated_doctest_setUp, + tearDown=tests.isolated_doctest_tearDown, + ) # DocFileCase reduces the test id to the base name of the tested file, we # want the module to appears there. for t in tests.iter_suite_tests(suite): diff --git a/breezy/doc_generate/__init__.py b/breezy/doc_generate/__init__.py index e7b3db4969..f2ba20116c 100644 --- a/breezy/doc_generate/__init__.py +++ b/breezy/doc_generate/__init__.py @@ -23,7 +23,7 @@ def get_module(target): mod_name = f"breezy.doc_generate.autodoc_{target}" mod = __import__(mod_name) - components = mod_name.split('.') + components = mod_name.split(".") for comp in components[1:]: mod = getattr(mod, comp) return mod @@ -35,7 +35,6 @@ def get_autodoc_datetime(): :return: A `datetime` object """ try: - return datetime.datetime.utcfromtimestamp( - int(os.environ['SOURCE_DATE_EPOCH'])) + return datetime.datetime.utcfromtimestamp(int(os.environ["SOURCE_DATE_EPOCH"])) except (KeyError, ValueError): return datetime.datetime.utcnow() diff --git a/breezy/doc_generate/autodoc_bash_completion.py b/breezy/doc_generate/autodoc_bash_completion.py index 29eb075d36..24573dffac 100644 --- a/breezy/doc_generate/autodoc_bash_completion.py +++ b/breezy/doc_generate/autodoc_bash_completion.py @@ -28,12 +28,12 @@ def get_filename(options): def infogen(options, outfile): d = get_autodoc_datetime() - params = \ - {"brzcmd": options.brz_name, - "datestamp": d.strftime("%Y-%m-%d"), - "timestamp": d.strftime("%Y-%m-%d %H:%M:%S +0000"), - "version": breezy.__version__, - } + params = { + "brzcmd": options.brz_name, + "datestamp": d.strftime("%Y-%m-%d"), + "timestamp": d.strftime("%Y-%m-%d %H:%M:%S +0000"), + "version": breezy.__version__, + } outfile.write(preamble % params) diff --git a/breezy/doc_generate/autodoc_man.py b/breezy/doc_generate/autodoc_man.py index daf8f62384..4e0d673416 100644 --- a/breezy/doc_generate/autodoc_man.py +++ b/breezy/doc_generate/autodoc_man.py @@ -44,12 +44,12 @@ def get_filename(options): def infogen(options, outfile): """Assembles a man page.""" d = get_autodoc_datetime() - params = \ - {"brzcmd": options.brz_name, - "datestamp": d.strftime("%Y-%m-%d"), - "timestamp": d.strftime("%Y-%m-%d %H:%M:%S +0000"), - "version": breezy.__version__, - } + params = { + "brzcmd": options.brz_name, + "datestamp": d.strftime("%Y-%m-%d"), + "timestamp": d.strftime("%Y-%m-%d %H:%M:%S +0000"), + "version": breezy.__version__, + } outfile.write(man_preamble % params) outfile.write(man_escape(man_head % params)) outfile.write(man_escape(getcommand_list(params))) @@ -72,8 +72,10 @@ def command_name_list(): command_names = 
breezy.commands.builtin_command_names() for cmdname in breezy.commands.plugin_command_names(): cmd_object = breezy.commands.get_cmd_object(cmdname) - if (PLUGINS_TO_DOCUMENT is None or - cmd_object.plugin_name() in PLUGINS_TO_DOCUMENT): + if ( + PLUGINS_TO_DOCUMENT is None + or cmd_object.plugin_name() in PLUGINS_TO_DOCUMENT + ): command_names.append(cmdname) command_names.sort() return command_names @@ -89,7 +91,7 @@ def getcommand_list(params): continue cmd_help = cmd_object.help() if cmd_help: - firstline = cmd_help.split('\n', 1)[0] + firstline = cmd_help.split("\n", 1)[0] usage = cmd_object._usage() tmp = f'.TP\n.B "{usage}"\n{firstline}\n' output = output + tmp @@ -116,7 +118,7 @@ def getcommand_help(params): def format_command(params, cmd): """Provides long help for each public command.""" - subsection_header = f'.SS "{cmd._usage()}\"\n' + subsection_header = f'.SS "{cmd._usage()}"\n' doc = f"{cmd.__doc__}\n" doc = breezy.help_topics.help_as_plain_text(cmd.help()) @@ -133,50 +135,55 @@ def format_command(params, cmd): for name, short_name, argname, help in option.iter_switches(): if option.is_hidden(name): continue - l = ' --' + name + l = " --" + name if argname is not None: - l += ' ' + argname + l += " " + argname if short_name: - l += ', -' + short_name - l += (30 - len(l)) * ' ' + (help or '') - wrapped = textwrap.fill(l, initial_indent='', - subsequent_indent=30 * ' ', - break_long_words=False, - ) - option_str += wrapped + '\n' + l += ", -" + short_name + l += (30 - len(l)) * " " + (help or "") + wrapped = textwrap.fill( + l, + initial_indent="", + subsequent_indent=30 * " ", + break_long_words=False, + ) + option_str += wrapped + "\n" aliases_str = "" if cmd.aliases: if len(cmd.aliases) > 1: - aliases_str += '\nAliases: ' + aliases_str += "\nAliases: " else: - aliases_str += '\nAlias: ' - aliases_str += ', '.join(cmd.aliases) - aliases_str += '\n' + aliases_str += "\nAlias: " + aliases_str += ", ".join(cmd.aliases) + aliases_str += "\n" see_also_str = "" see_also = cmd.get_see_also() if see_also: - see_also_str += '\nSee also: ' - see_also_str += ', '.join(see_also) - see_also_str += '\n' + see_also_str += "\nSee also: " + see_also_str += ", ".join(see_also) + see_also_str += "\n" - return subsection_header + option_str + aliases_str + see_also_str + "\n" + doc + "\n" + return ( + subsection_header + option_str + aliases_str + see_also_str + "\n" + doc + "\n" + ) def format_alias(params, alias, cmd_name): - help = f'.SS "brz {alias}\"\n' + help = f'.SS "brz {alias}"\n' help += f'Alias for "{cmd_name}", see "brz {cmd_name}".\n' return help def environment_variables(): - yield ".SH \"ENVIRONMENT\"\n" + yield '.SH "ENVIRONMENT"\n' from breezy.help_topics import known_env_variables + for k, desc in known_env_variables(): yield ".TP\n" - yield f".I \"{k}\"\n" + yield f'.I "{k}"\n' yield man_escape(desc) + "\n" diff --git a/breezy/doc_generate/autodoc_rstx.py b/breezy/doc_generate/autodoc_rstx.py index f9c8fd080d..d966849af1 100644 --- a/breezy/doc_generate/autodoc_rstx.py +++ b/breezy/doc_generate/autodoc_rstx.py @@ -38,13 +38,13 @@ def infogen(options, outfile): """Create manual in RSTX format.""" t = time.time() tt = time.gmtime(t) - params = \ - {"brzcmd": options.brz_name, - "datestamp": time.strftime("%Y-%m-%d", tt), - "timestamp": time.strftime("%Y-%m-%d %H:%M:%S +0000", tt), - "version": breezy.__version__, - } - nominated_filename = getattr(options, 'filename', None) + params = { + "brzcmd": options.brz_name, + "datestamp": time.strftime("%Y-%m-%d", tt), + "timestamp": 
time.strftime("%Y-%m-%d %H:%M:%S +0000", tt), + "version": breezy.__version__, + } + nominated_filename = getattr(options, "filename", None) if nominated_filename is None: topic_dir = None else: @@ -58,18 +58,20 @@ def infogen(options, outfile): def _get_body(params, topic_dir): """Build the manual content.""" from breezy.help_topics import SECT_CONCEPT, SECT_LIST + registry = breezy.help_topics.topic_registry result = [] - result.append(_get_section(registry, SECT_CONCEPT, "Concepts", - output_dir=topic_dir)) - result.append(_get_section(registry, SECT_LIST, "Lists", - output_dir=topic_dir)) + result.append( + _get_section(registry, SECT_CONCEPT, "Concepts", output_dir=topic_dir) + ) + result.append(_get_section(registry, SECT_LIST, "Lists", output_dir=topic_dir)) result.append(_get_commands_section(registry, output_dir=topic_dir)) return "\n".join(result) -def _get_section(registry, section, title, hdg_level1="#", hdg_level2="=", - output_dir=None): +def _get_section( + registry, section, title, hdg_level1="#", hdg_level2="=", output_dir=None +): """Build the manual part from topics matching that section. If output_dir is not None, topics are dumped into text files there @@ -98,8 +100,9 @@ def _get_section(registry, section, title, hdg_level1="#", hdg_level2="=", return "\n" + "\n".join(lines) + "\n" -def _get_commands_section(registry, title="Commands", hdg_level1="#", - hdg_level2="=", output_dir=None): +def _get_commands_section( + registry, title="Commands", hdg_level1="#", hdg_level2="=", output_dir=None +): """Build the commands reference section of the manual.""" file_per_topic = output_dir is not None lines = [title, hdg_level1 * len(title), ""] @@ -129,7 +132,7 @@ def _dump_text(output_dir, topic, text): topic_id = f"{topic}-help" filename = breezy.osutils.pathjoin(output_dir, topic_id + ".txt") with open(filename, "wb") as f: - f.write(text.encode('utf-8')) + f.write(text.encode("utf-8")) return topic_id diff --git a/breezy/doc_generate/conf.py b/breezy/doc_generate/conf.py index f841ae0f5e..27b50d7285 100644 --- a/breezy/doc_generate/conf.py +++ b/breezy/doc_generate/conf.py @@ -21,28 +21,28 @@ # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. extensions = [ - 'sphinx.ext.ifconfig', - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', - 'sphinx_epytext', + "sphinx.ext.ifconfig", + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx_epytext", # 'sphinxcontrib.napoleon', # TODO: for Google docstrings - ] +] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.txt' +source_suffix = ".txt" # The encoding of source files. -#source_encoding = 'utf-8' +# source_encoding = 'utf-8' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = 'Breezy' -copyright = '2009-2011 Canonical Ltd, 2017-2018 Breezy Developers' +project = "Breezy" +copyright = "2009-2011 Canonical Ltd, 2017-2018 Breezy Developers" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -51,53 +51,53 @@ # The short X.Y version. import breezy -version = '.'.join(str(p) for p in breezy.version_info[:2]) +version = ".".join(str(p) for p in breezy.version_info[:2]) # The full version, including alpha/beta/rc tags. 
release = breezy.version_string # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. -#language = None +# language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of documents that shouldn't be included in the build. -#unused_docs = [] +# unused_docs = [] # List of directories, relative to source directory, that shouldn't be searched # for source files. -exclude_trees = ['_build'] +exclude_trees = ["_build"] # The reST default role (used for this markup: `text`) to use for all documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # -- Options for HTML output --------------------------------------------------- # The theme to use for HTML and HTML Help pages. Major themes that come with # Sphinx are currently 'default' and 'sphinxdoc'. -html_theme = 'agogo' +html_theme = "agogo" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -107,7 +107,6 @@ # So we stick with the default left placement to cater for users stuck # on those browsers. # 'rightsidebar': True, - # Non-document areas: header (relbar), footer, sidebar, etc. # Some useful colours here: # * blue: darkblue, mediumblue, darkslateblue, cornflowerblue, royalblue, @@ -117,26 +116,25 @@ #'sidebarlinkcolor': "midnightblue", #'relbarbgcolor': "darkblue", #'footerbgcolor': "lightslategray", - # Text, heading and code colouring #'codebgcolor': "lightyellow", #'codetextcolor': "firebrick", #'linkcolor': "mediumblue", - } +} # Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] +# html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +# html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 @@ -147,22 +145,22 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". 
-html_static_path = ['_static'] +html_static_path = ["_static"] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. html_use_modindex = False @@ -171,7 +169,7 @@ html_use_index = False # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. html_show_sourcelink = True @@ -179,22 +177,22 @@ # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # If nonempty, this is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = '' +# html_file_suffix = '' # Output file base name for HTML help builder. -htmlhelp_basename = 'brz-docs' +htmlhelp_basename = "brz-docs" # -- Options for LaTeX output -------------------------------------------------- # The paper size ('letter' or 'a4'). -#latex_paper_size = 'letter' +# latex_paper_size = 'letter' # The font size ('10pt', '11pt' or '12pt'). -#latex_font_size = '10pt' +# latex_font_size = '10pt' # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). @@ -206,26 +204,26 @@ # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # Additional stuff for the LaTeX preamble. -#latex_preamble = '' +# latex_preamble = '' # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. 
-#latex_use_modindex = True +# latex_use_modindex = True # -- Bazaar-specific configuration --------------------------------------------- # Authors of the documents -brz_team = 'Breezy Developers' +brz_team = "Breezy Developers" intersphinx_mapping = { - 'python': ('https://docs.python.org/3', None), - 'testtools': ('https://testtools.readthedocs.io/en/latest', None), - 'configobj': ('https://configobj.readthedocs.io/en/latest', None), - 'dulwich': ('https://dulwich.readthedocs.io/en/latest', None), + "python": ("https://docs.python.org/3", None), + "testtools": ("https://testtools.readthedocs.io/en/latest", None), + "configobj": ("https://configobj.readthedocs.io/en/latest", None), + "dulwich": ("https://dulwich.readthedocs.io/en/latest", None), } diff --git a/breezy/email_message.py b/breezy/email_message.py index c99b677249..33c60ce0da 100644 --- a/breezy/email_message.py +++ b/breezy/email_message.py @@ -65,12 +65,12 @@ def __init__(self, from_address, to_address, subject, body=None): for addr in to_address: to_addresses.append(self.address_to_encoded_header(addr)) - self._headers['To'] = ', '.join(to_addresses) - self._headers['From'] = self.address_to_encoded_header(from_address) - self._headers['Subject'] = Header(safe_unicode(subject)) - self._headers['User-Agent'] = f'Bazaar ({_breezy_version})' + self._headers["To"] = ", ".join(to_addresses) + self._headers["From"] = self.address_to_encoded_header(from_address) + self._headers["Subject"] = Header(safe_unicode(subject)) + self._headers["User-Agent"] = f"Bazaar ({_breezy_version})" - def add_inline_attachment(self, body, filename=None, mime_subtype='plain'): + def add_inline_attachment(self, body, filename=None, mime_subtype="plain"): """Add an inline attachment to the message. :param body: A text to attach. Can be an unicode string or a byte @@ -87,7 +87,7 @@ def add_inline_attachment(self, body, filename=None, mime_subtype='plain'): # add_inline_attachment() has been called, so the message will be a # MIMEMultipart; add the provided body, if any, as the first attachment if self._body is not None: - self._parts.append((self._body, None, 'plain')) + self._parts.append((self._body, None, "plain")) self._body = None self._parts.append((body, filename, mime_subtype)) @@ -114,11 +114,11 @@ def as_string(self, boundary=None): payload = MIMEText(body, mime_subtype, encoding) if filename is not None: - content_type = payload['Content-Type'] + content_type = payload["Content-Type"] content_type += f'; name="{filename}"' - payload.replace_header('Content-Type', content_type) + payload.replace_header("Content-Type", content_type) - payload['Content-Disposition'] = 'inline' + payload["Content-Disposition"] = "inline" msgobj.attach(payload) # sort headers here to ease testing @@ -145,8 +145,16 @@ def __setitem__(self, header, value): return self._headers.__setitem__(header, value) @staticmethod - def send(config, from_address, to_address, subject, body, attachment=None, - attachment_filename=None, attachment_mime_subtype='plain'): + def send( + config, + from_address, + to_address, + subject, + body, + attachment=None, + attachment_filename=None, + attachment_mime_subtype="plain", + ): """Create an email message and send it with SMTPConnection. :param config: config object to pass to SMTPConnection constructor. 
@@ -156,8 +164,9 @@ def send(config, from_address, to_address, subject, body, attachment=None, """ msg = EmailMessage(from_address, to_address, subject, body) if attachment is not None: - msg.add_inline_attachment(attachment, attachment_filename, - attachment_mime_subtype) + msg.add_inline_attachment( + attachment, attachment_filename, attachment_mime_subtype + ) SMTPConnection(config).send_email(msg) @staticmethod @@ -175,8 +184,7 @@ def address_to_encoded_header(address): if not user: return email else: - return formataddr((str(Header(safe_unicode(user))), - email)) + return formataddr((str(Header(safe_unicode(user))), email)) @staticmethod def string_with_encoding(string_): @@ -193,16 +201,16 @@ def string_with_encoding(string_): # and decode() whether the body is actually ascii-only. if isinstance(string_, str): try: - return (string_.encode('ascii'), 'ascii') + return (string_.encode("ascii"), "ascii") except UnicodeEncodeError: - return (string_.encode('utf-8'), 'utf-8') + return (string_.encode("utf-8"), "utf-8") else: try: - string_.decode('ascii') - return (string_, 'ascii') + string_.decode("ascii") + return (string_, "ascii") except UnicodeDecodeError: try: - string_.decode('utf-8') - return (string_, 'utf-8') + string_.decode("utf-8") + return (string_, "utf-8") except UnicodeDecodeError: - return (string_, '8-bit') + return (string_, "8-bit") diff --git a/breezy/errors.py b/breezy/errors.py index 27d2a6e4d2..ab6ac1b92b 100644 --- a/breezy/errors.py +++ b/breezy/errors.py @@ -86,7 +86,7 @@ def __init__(self, msg=None, **kwds): setattr(self, key, value) def _format(self): - s = getattr(self, '_preformatted_string', None) + s = getattr(self, "_preformatted_string", None) if s is not None: # contains a preformatted message return s @@ -101,21 +101,21 @@ def _format(self): return s except Exception as e: err = e - return 'Unprintable exception {}: dict={!r}, fmt={!r}, error={!r}'.format(self.__class__.__name__, - self.__dict__, - getattr(self, '_fmt', None), - err) + return "Unprintable exception {}: dict={!r}, fmt={!r}, error={!r}".format( + self.__class__.__name__, self.__dict__, getattr(self, "_fmt", None), err + ) __str__ = _format def __repr__(self): - return f'{self.__class__.__name__}({str(self)})' + return f"{self.__class__.__name__}({str(self)})" def _get_format_string(self): """Return format string for this exception or None.""" - fmt = getattr(self, '_fmt', None) + fmt = getattr(self, "_fmt", None) if fmt is not None: from .i18n import gettext + return gettext(fmt) # _fmt strings should be ascii def __eq__(self, other): @@ -146,7 +146,6 @@ def __init__(self, branch): class BzrCheckError(InternalBzrError): - _fmt = "Internal check failed: %(msg)s" def __init__(self, msg): @@ -155,9 +154,10 @@ def __init__(self, msg): class IncompatibleVersion(BzrError): - - _fmt = 'API %(api)s is not compatible; one of versions %(wanted)r '\ - 'is required, but current version is %(current)r.' + _fmt = ( + "API %(api)s is not compatible; one of versions %(wanted)r " + "is required, but current version is %(current)r." + ) def __init__(self, api, wanted, current): self.api = api @@ -166,16 +166,13 @@ def __init__(self, api, wanted, current): class InProcessTransport(BzrError): - - _fmt = "The transport '%(transport)s' is only accessible within this " \ - "process." + _fmt = "The transport '%(transport)s' is only accessible within this " "process." 
def __init__(self, transport): self.transport = transport class InvalidRevisionNumber(BzrError): - _fmt = "Invalid revision number %(revno)s" def __init__(self, revno): @@ -184,7 +181,6 @@ def __init__(self, revno): class InvalidRevisionId(BzrError): - _fmt = "Invalid revision-id {%(revision_id)s} in %(branch)s" def __init__(self, revision_id, branch): @@ -195,7 +191,6 @@ def __init__(self, revision_id, branch): class ReservedId(BzrError): - _fmt = "Reserved revision-id {%(revision_id)s}" def __init__(self, revision_id): @@ -203,23 +198,23 @@ def __init__(self, revision_id): class RootMissing(InternalBzrError): - - _fmt = ("The root entry of a tree must be the first entry supplied to " - "the commit builder.") + _fmt = ( + "The root entry of a tree must be the first entry supplied to " + "the commit builder." + ) class NoPublicBranch(BzrError): - _fmt = 'There is no public branch set for "%(branch_url)s".' def __init__(self, branch): from . import urlutils - public_location = urlutils.unescape_for_display(branch.base, 'ascii') + + public_location = urlutils.unescape_for_display(branch.base, "ascii") BzrError.__init__(self, branch_url=public_location) class NoSuchId(BzrError): - _fmt = 'The file id "%(file_id)s" is not present in the tree %(tree)s.' def __init__(self, tree, file_id): @@ -229,12 +224,10 @@ def __init__(self, tree, file_id): class NotStacked(BranchError): - _fmt = "The branch '%(branch)s' is not stacked." class NoWorkingTree(BzrError): - _fmt = 'No WorkingTree exists for "%(base)s".' def __init__(self, base): @@ -243,7 +236,6 @@ def __init__(self, base): class NotLocalUrl(BzrError): - _fmt = "%(url)s is not a local path." def __init__(self, url): @@ -251,7 +243,6 @@ def __init__(self, url): class WorkingTreeAlreadyPopulated(InternalBzrError): - _fmt = 'Working tree already populated in "%(base)s"' def __init__(self, base): @@ -259,10 +250,11 @@ def __init__(self, base): class NoWhoami(BzrError): - - _fmt = ('Unable to determine your name.\n' - "Please, set your name with the 'whoami' command.\n" - 'E.g. brz whoami "Your Name "') + _fmt = ( + "Unable to determine your name.\n" + "Please, set your name with the 'whoami' command.\n" + 'E.g. brz whoami "Your Name "' + ) class CommandError(BzrError): @@ -282,7 +274,6 @@ class CommandError(BzrError): class NotWriteLocked(BzrError): - _fmt = """%(not_locked)r is not write locked but needs to be.""" def __init__(self, not_locked): @@ -290,7 +281,6 @@ def __init__(self, not_locked): class StrictCommitFailed(BzrError): - _fmt = "Commit refused because there are unknown files in the tree" @@ -301,77 +291,72 @@ class StrictCommitFailed(BzrError): # differentiates between 'transport has failed' and 'operation on a transport # has failed.' class PathError(BzrError): - _fmt = "Generic path error: %(path)r%(extra)s)" def __init__(self, path, extra=None): BzrError.__init__(self) self.path = path if extra: - self.extra = ': ' + str(extra) + self.extra = ": " + str(extra) else: - self.extra = '' + self.extra = "" class RenameFailedFilesExist(BzrError): """Used when renaming and both source and dest exist.""" - _fmt = ("Could not rename %(source)s => %(dest)s because both files exist." - " (Use --after to tell brz about a rename that has already" - " happened)%(extra)s") + _fmt = ( + "Could not rename %(source)s => %(dest)s because both files exist." 
+ " (Use --after to tell brz about a rename that has already" + " happened)%(extra)s" + ) def __init__(self, source, dest, extra=None): BzrError.__init__(self) self.source = str(source) self.dest = str(dest) if extra: - self.extra = ' ' + str(extra) + self.extra = " " + str(extra) else: - self.extra = '' + self.extra = "" class NotADirectory(PathError): - _fmt = '"%(path)s" is not a directory %(extra)s' class NotInWorkingDirectory(PathError): - _fmt = '"%(path)s" is not in the working directory %(extra)s' class DirectoryNotEmpty(PathError): - _fmt = 'Directory not empty: "%(path)s"%(extra)s' class HardLinkNotSupported(PathError): - _fmt = 'Hard-linking "%(path)s" is not supported' class ReadingCompleted(InternalBzrError): - - _fmt = ("The MediumRequest '%(request)s' has already had finish_reading " - "called upon it - the request has been completed and no more " - "data may be read.") + _fmt = ( + "The MediumRequest '%(request)s' has already had finish_reading " + "called upon it - the request has been completed and no more " + "data may be read." + ) def __init__(self, request): self.request = request class ResourceBusy(PathError): - _fmt = 'Device or resource busy: "%(path)s"%(extra)s' class PermissionDenied(PathError): - _fmt = 'Permission denied: "%(path)s"%(extra)s' class UnstackableLocationError(BzrError): - _fmt = "The branch '%(branch_url)s' cannot be stacked on '%(target_url)s'." def __init__(self, branch_url, target_url): @@ -381,9 +366,10 @@ def __init__(self, branch_url, target_url): class UnstackableRepositoryFormat(BzrError): - - _fmt = ("The repository '%(url)s'(%(format)s) is not a stackable format. " - "You will need to upgrade the repository to permit branch stacking.") + _fmt = ( + "The repository '%(url)s'(%(format)s) is not a stackable format. " + "You will need to upgrade the repository to permit branch stacking." + ) def __init__(self, format, url): BzrError.__init__(self) @@ -392,14 +378,14 @@ def __init__(self, format, url): class ReadError(PathError): - _fmt = """Error reading from %(path)r%(extra)r.""" class ShortReadvError(PathError): - - _fmt = ('readv() read %(actual)s bytes rather than %(length)s bytes' - ' at %(offset)s for "%(path)s"%(extra)s') + _fmt = ( + "readv() read %(actual)s bytes rather than %(length)s bytes" + ' at %(offset)s for "%(path)s"%(extra)s' + ) internal_error = True @@ -411,7 +397,6 @@ def __init__(self, path, offset, length, actual, extra=None): class PathNotChild(PathError): - _fmt = 'Path "%(path)s" is not a child of path "%(base)s"%(extra)s' internal_error = False @@ -421,13 +406,12 @@ def __init__(self, path, base, extra=None): self.path = path self.base = base if extra: - self.extra = ': ' + str(extra) + self.extra = ": " + str(extra) else: - self.extra = '' + self.extra = "" class InvalidNormalization(PathError): - _fmt = 'Path "%(path)s" is not unicode normalized' @@ -435,20 +419,20 @@ class InvalidNormalization(PathError): # the exception object is a bit undesirable. # TODO: Probably this behavior of should be a common superclass class NotBranchError(PathError): - _fmt = 'Not a branch: "%(path)s"%(detail)s.' def __init__(self, path, detail=None, controldir=None): from . 
import urlutils - path = urlutils.unescape_for_display(path, 'ascii') + + path = urlutils.unescape_for_display(path, "ascii") if detail is not None: - detail = ': ' + detail + detail = ": " + detail self.detail = detail self.controldir = controldir PathError.__init__(self, path=path) def __repr__(self): - return f'<{self.__class__.__name__} {self.__dict__!r}>' + return f"<{self.__class__.__name__} {self.__dict__!r}>" def _get_format_string(self): # GZ 2017-06-08: Not the best place to lazy fill detail in. @@ -461,7 +445,7 @@ def _get_detail(self): try: self.controldir.open_repository() except NoRepositoryPresent: - return '' + return "" except Exception as e: # Just ignore unexpected errors. Raising arbitrary errors # during str(err) can provoke strange bugs. Concretely @@ -470,33 +454,30 @@ def _get_detail(self): # trying to str() that error. All this error really cares # about that there's no working repository there, and if # open_repository() fails, there probably isn't. - return ': ' + e.__class__.__name__ + return ": " + e.__class__.__name__ else: - return ': location is a repository' - return '' + return ": location is a repository" + return "" class NoSubmitBranch(PathError): - _fmt = 'No submit branch available for branch "%(path)s"' def __init__(self, branch): from . import urlutils - self.path = urlutils.unescape_for_display(branch.base, 'ascii') + self.path = urlutils.unescape_for_display(branch.base, "ascii") -class AlreadyControlDirError(PathError): +class AlreadyControlDirError(PathError): _fmt = 'A control directory already exists: "%(path)s".' class AlreadyBranchError(PathError): - _fmt = 'Already a branch: "%(path)s".' class InvalidBranchName(PathError): - _fmt = "Invalid branch name: %(name)s" def __init__(self, name): @@ -505,20 +486,16 @@ def __init__(self, name): class ParentBranchExists(AlreadyBranchError): - _fmt = 'Parent branch already exists: "%(path)s".' class BranchExistsWithoutWorkingTree(PathError): - _fmt = 'Directory contains a branch, but no working tree \ (use brz checkout if you wish to build a working tree): "%(path)s"' class InaccessibleParent(PathError): - - _fmt = ('Parent not accessible given base "%(base)s" and' - ' relative path "%(path)s"') + _fmt = 'Parent not accessible given base "%(base)s" and' ' relative path "%(path)s"' def __init__(self, path, base): PathError.__init__(self, path) @@ -526,38 +503,32 @@ def __init__(self, path, base): class NoRepositoryPresent(BzrError): - _fmt = 'No repository present: "%(path)s"' def __init__(self, controldir): BzrError.__init__(self) - self.path = controldir.transport.clone('..').base + self.path = controldir.transport.clone("..").base class UnsupportedFormatError(BzrError): - _fmt = "Unsupported branch format: %(format)s\nPlease run 'brz upgrade'" - class UnsupportedVcs(UnsupportedFormatError): - vcs: str _fmt = "Unsupported version control system: %(vcs)s" class UnknownFormatError(BzrError): - _fmt = "Unknown %(kind)s format: %(format)r" - def __init__(self, format, kind='branch'): + def __init__(self, format, kind="branch"): self.kind = kind self.format = format class IncompatibleFormat(BzrError): - _fmt = "Format %(format)s is not compatible with .bzr version %(controldir)s." 
def __init__(self, format, controldir_format): @@ -567,7 +538,6 @@ def __init__(self, format, controldir_format): class ParseFormatError(BzrError): - _fmt = "Parse error on line %(lineno)d of %(format)s format: %(line)s" def __init__(self, format, lineno, line, text): @@ -586,10 +556,7 @@ class IncompatibleRepositories(BzrError): repository the client hasn't opened. """ - _fmt = "%(target)s\n" \ - "is not compatible with\n" \ - "%(source)s\n" \ - "%(details)s" + _fmt = "%(target)s\n" "is not compatible with\n" "%(source)s\n" "%(details)s" def __init__(self, source, target, details=None): if details is None: @@ -598,7 +565,6 @@ def __init__(self, source, target, details=None): class IncompatibleRevision(BzrError): - _fmt = "Revision is not compatible with %(repo_format)s" def __init__(self, repo_format): @@ -622,7 +588,7 @@ def __init__(self, path, context_info=None): BzrError.__init__(self) self.path = path if context_info is None: - self.context_info = '' + self.context_info = "" else: self.context_info = context_info + ". " @@ -643,7 +609,7 @@ def __init__(self, path, context_info=None): BzrError.__init__(self) self.path = path if context_info is None: - self.context_info = '' + self.context_info = "" else: self.context_info = context_info + ". " @@ -655,13 +621,13 @@ class PathsNotVersionedError(BzrError): def __init__(self, paths): from .osutils import quotefn + BzrError.__init__(self) self.paths = paths - self.paths_as_string = ' '.join([quotefn(p) for p in paths]) + self.paths_as_string = " ".join([quotefn(p) for p in paths]) class PathsDoNotExist(BzrError): - _fmt = "Path(s) do not exist: %(paths_as_string)s%(extra)s" # used when reporting that paths are neither versioned nor in the working @@ -670,17 +636,17 @@ class PathsDoNotExist(BzrError): def __init__(self, paths, extra=None): # circular import from .osutils import quotefn + BzrError.__init__(self) self.paths = paths - self.paths_as_string = ' '.join([quotefn(p) for p in paths]) + self.paths_as_string = " ".join([quotefn(p) for p in paths]) if extra: - self.extra = ': ' + str(extra) + self.extra = ": " + str(extra) else: - self.extra = '' + self.extra = "" class BadFileKindError(BzrError): - _fmt = 'Cannot operate on "%(filename)s" of unsupported kind "%(kind)s"' def __init__(self, filename, kind): @@ -688,12 +654,10 @@ def __init__(self, filename, kind): class ForbiddenControlFileError(BzrError): - _fmt = 'Cannot operate on "%(filename)s" because it is a control file' class LockError(InternalBzrError): - _fmt = "Lock error: %(msg)s" # All exceptions from the lock/unlock functions should be from @@ -706,7 +670,6 @@ def __init__(self, msg): class LockActive(LockError): - _fmt = "The lock for '%(lock_description)s' is in use and cannot be broken." internal_error = False @@ -716,7 +679,6 @@ def __init__(self, lock_description): class CommitNotPossible(LockError): - _fmt = "A commit was attempted but we do not have a write lock open." def __init__(self): @@ -724,7 +686,6 @@ def __init__(self): class AlreadyCommitted(LockError): - _fmt = "A rollback was requested, but is not able to be accomplished." 
def __init__(self): @@ -732,7 +693,6 @@ def __init__(self): class ReadOnlyError(LockError): - _fmt = "A write attempt was made in a read only transaction on %(obj)s" # TODO: There should also be an error indicating that you need a write @@ -743,25 +703,24 @@ def __init__(self, obj): class LockFailed(LockError): - internal_error = False _fmt = "Cannot lock %(lock)s: %(why)s" def __init__(self, lock, why): - LockError.__init__(self, '') + LockError.__init__(self, "") self.lock = lock self.why = why class OutSideTransaction(BzrError): - - _fmt = ("A transaction related operation was attempted after" - " the transaction finished.") + _fmt = ( + "A transaction related operation was attempted after" + " the transaction finished." + ) class ObjectNotLocked(LockError): - _fmt = "%(obj)r is not locked" # this can indicate that any particular object is not locked; see also @@ -772,7 +731,6 @@ def __init__(self, obj): class ReadOnlyObjectDirtiedError(ReadOnlyError): - _fmt = "Cannot change object %(obj)r in read only transaction" def __init__(self, obj): @@ -780,7 +738,6 @@ def __init__(self, obj): class UnlockableTransport(LockError): - internal_error = False _fmt = "Cannot lock: transport is read only: %(transport)s" @@ -790,20 +747,17 @@ def __init__(self, transport): class LockContention(LockError): - _fmt = 'Could not acquire lock "%(lock)s": %(msg)s' internal_error = False - def __init__(self, lock, msg=''): + def __init__(self, lock, msg=""): self.lock = lock self.msg = msg class LockBroken(LockError): - - _fmt = ("Lock was broken while still open: %(lock)s" - " - check storage consistency!") + _fmt = "Lock was broken while still open: %(lock)s" " - check storage consistency!" internal_error = False @@ -812,9 +766,10 @@ def __init__(self, lock): class LockBreakMismatch(LockError): - - _fmt = ("Lock was released and re-acquired before being broken:" - " %(lock)s: held by %(holder)r, wanted to break %(target)r") + _fmt = ( + "Lock was released and re-acquired before being broken:" + " %(lock)s: held by %(holder)r, wanted to break %(target)r" + ) internal_error = False @@ -825,9 +780,10 @@ def __init__(self, lock, holder, target): class LockCorrupt(LockError): - - _fmt = ("Lock is apparently held, but corrupted: %(corruption_info)s\n" - "Use 'brz break-lock' to clear it") + _fmt = ( + "Lock is apparently held, but corrupted: %(corruption_info)s\n" + "Use 'brz break-lock' to clear it" + ) internal_error = False @@ -837,7 +793,6 @@ def __init__(self, corruption_info, file_data=None): class LockNotHeld(LockError): - _fmt = "Lock not held: %(lock)s" internal_error = False @@ -847,7 +802,6 @@ def __init__(self, lock): class TokenLockingNotSupported(LockError): - _fmt = "The object %(obj)s does not support token specifying a token when locking." def __init__(self, obj): @@ -855,7 +809,6 @@ def __init__(self, obj): class TokenMismatch(LockBroken): - _fmt = "The lock token %(given_token)r does not match lock token %(lock_token)r." internal_error = True @@ -866,12 +819,10 @@ def __init__(self, given_token, lock_token): class UpgradeReadonly(BzrError): - _fmt = "Upgrade URL cannot work with readonly URLs." class UpToDateFormat(BzrError): - _fmt = "The branch format %(format)s is already at the most recent format." 
def __init__(self, format): @@ -880,7 +831,6 @@ def __init__(self, format): class NoSuchRevision(InternalBzrError): - revision: bytes _fmt = "%(branch)s has no revision %(revision)s" @@ -891,12 +841,10 @@ def __init__(self, branch, revision): class RangeInChangeOption(BzrError): - _fmt = "Option --change does not accept revision ranges" class NoSuchRevisionSpec(BzrError): - _fmt = "No namespace registered for string: %(spec)r" def __init__(self, spec): @@ -915,22 +863,25 @@ def __init__(self, tree, revision_id): class AppendRevisionsOnlyViolation(BzrError): - - _fmt = ('Operation denied because it would change the main history,' - ' which is not permitted by the append_revisions_only setting on' - ' branch "%(location)s".') + _fmt = ( + "Operation denied because it would change the main history," + " which is not permitted by the append_revisions_only setting on" + ' branch "%(location)s".' + ) def __init__(self, location): import breezy.urlutils as urlutils - location = urlutils.unescape_for_display(location, 'ascii') + + location = urlutils.unescape_for_display(location, "ascii") BzrError.__init__(self, location=location) class DivergedBranches(BzrError): - - _fmt = ("These branches have diverged." - " Use the missing command to see how.\n" - "Use the merge command to reconcile them.") + _fmt = ( + "These branches have diverged." + " Use the missing command to see how.\n" + "Use the merge command to reconcile them." + ) def __init__(self, branch1, branch2): self.branch1 = branch1 @@ -938,7 +889,6 @@ def __init__(self, branch1, branch2): class NotLefthandHistory(InternalBzrError): - _fmt = "Supplied history does not follow left-hand parents" def __init__(self, history): @@ -946,19 +896,16 @@ def __init__(self, history): class UnrelatedBranches(BzrError): - - _fmt = ("Branches have no common ancestor, and" - " no merge base revision was specified.") + _fmt = ( + "Branches have no common ancestor, and" " no merge base revision was specified." + ) class CannotReverseCherrypick(BzrError): - - _fmt = ('Selected merge cannot perform reverse cherrypicks. Try merge3' - ' or diff3.') + _fmt = "Selected merge cannot perform reverse cherrypicks. Try merge3" " or diff3." class NoCommonAncestor(BzrError): - _fmt = "Revisions have no common ancestor: %(revision_a)s %(revision_b)s" def __init__(self, revision_a, revision_b): @@ -967,56 +914,54 @@ def __init__(self, revision_a, revision_b): class NoCommonRoot(BzrError): - - _fmt = ("Revisions are not derived from the same root: " - "%(revision_a)s %(revision_b)s.") + _fmt = ( + "Revisions are not derived from the same root: " + "%(revision_a)s %(revision_b)s." + ) def __init__(self, revision_a, revision_b): BzrError.__init__(self, revision_a=revision_a, revision_b=revision_b) class NotAncestor(BzrError): - _fmt = "Revision %(rev_id)s is not an ancestor of %(not_ancestor_id)s" def __init__(self, rev_id, not_ancestor_id): - BzrError.__init__(self, rev_id=rev_id, - not_ancestor_id=not_ancestor_id) + BzrError.__init__(self, rev_id=rev_id, not_ancestor_id=not_ancestor_id) class NoCommits(BranchError): - _fmt = "Branch %(branch)s has no commits." 
class UnlistableStore(BzrError): - def __init__(self, store): BzrError.__init__(self, f"Store {store} is not listable") class UnlistableBranch(BzrError): - def __init__(self, br): BzrError.__init__(self, f"Stores for branch {br} are not listable") class BoundBranchOutOfDate(BzrError): - - _fmt = ("Bound branch %(branch)s is out of date with master branch" - " %(master)s.%(extra_help)s") + _fmt = ( + "Bound branch %(branch)s is out of date with master branch" + " %(master)s.%(extra_help)s" + ) def __init__(self, branch, master): BzrError.__init__(self) self.branch = branch self.master = master - self.extra_help = '' + self.extra_help = "" class CommitToDoubleBoundBranch(BzrError): - - _fmt = ("Cannot commit to branch %(branch)s." - " It is bound to %(master)s, which is bound to %(remote)s.") + _fmt = ( + "Cannot commit to branch %(branch)s." + " It is bound to %(master)s, which is bound to %(remote)s." + ) def __init__(self, branch, master, remote): BzrError.__init__(self) @@ -1026,7 +971,6 @@ def __init__(self, branch, master, remote): class OverwriteBoundBranch(BzrError): - _fmt = "Cannot pull --overwrite to a branch which is bound %(branch)s" def __init__(self, branch): @@ -1035,9 +979,10 @@ def __init__(self, branch): class BoundBranchConnectionFailure(BzrError): - - _fmt = ("Unable to connect to target of bound branch %(branch)s" - " => %(target)s: %(error)s") + _fmt = ( + "Unable to connect to target of bound branch %(branch)s" + " => %(target)s: %(error)s" + ) def __init__(self, branch, target, error): BzrError.__init__(self) @@ -1047,12 +992,10 @@ def __init__(self, branch, target, error): class VersionedFileError(BzrError): - _fmt = "Versioned file error" class RevisionNotPresent(VersionedFileError): - _fmt = 'Revision {%(revision_id)s} not present in "%(file_id)s".' def __init__(self, revision_id, file_id): @@ -1062,7 +1005,6 @@ def __init__(self, revision_id, file_id): class RevisionAlreadyPresent(VersionedFileError): - _fmt = 'Revision {%(revision_id)s} already present in "%(file_id)s".' def __init__(self, revision_id, file_id): @@ -1072,12 +1014,10 @@ def __init__(self, revision_id, file_id): class VersionedFileInvalidChecksum(VersionedFileError): - _fmt = "Text did not match its checksum: %(msg)s" class NoSuchExportFormat(BzrError): - _fmt = "Export format %(format)r not supported" def __init__(self, format): @@ -1086,23 +1026,21 @@ def __init__(self, format): class TransportError(BzrError): - _fmt = "Transport error: %(msg)s %(orig_error)s" def __init__(self, msg=None, orig_error=None): if msg is None and orig_error is not None: msg = str(orig_error) if orig_error is None: - orig_error = '' + orig_error = "" if msg is None: - msg = '' + msg = "" self.msg = msg self.orig_error = orig_error BzrError.__init__(self) class SmartProtocolError(TransportError): - _fmt = "Generic bzr smart protocol error: %(details)s" def __init__(self, details): @@ -1110,7 +1048,6 @@ def __init__(self, details): class UnexpectedProtocolVersionMarker(TransportError): - _fmt = "Received bad protocol version marker: %(marker)r" def __init__(self, marker): @@ -1118,7 +1055,6 @@ def __init__(self, marker): class UnknownSmartMethod(InternalBzrError): - _fmt = "The server does not recognise the '%(verb)s' request." 
def __init__(self, verb): @@ -1127,35 +1063,31 @@ def __init__(self, verb): # A set of semi-meaningful errors which can be thrown class TransportNotPossible(TransportError): - _fmt = "Transport operation not possible: %(msg)s %(orig_error)s" class SocketConnectionError(ConnectionError): - def __init__(self, host, port=None, msg=None, orig_error=None): if msg is None: - msg = 'Failed to connect to' + msg = "Failed to connect to" if orig_error is None: - orig_error = '' + orig_error = "" else: - orig_error = '; ' + str(orig_error) + orig_error = "; " + str(orig_error) self.host = host if port is None: - port = '' + port = "" else: - port = f':{port}' + port = f":{port}" self.port = port ConnectionError.__init__(self, f"{msg} {host}{port}{orig_error}") class ConnectionTimeout(ConnectionError): - _fmt = "Connection Timeout: %(msg)s%(orig_error)s" class InvalidRange(TransportError): - _fmt = "Invalid range access in %(path)s at %(offset)s: %(msg)s" def __init__(self, path, offset, msg=None): @@ -1165,38 +1097,34 @@ def __init__(self, path, offset, msg=None): class InvalidHttpResponse(TransportError): - _fmt = "Invalid http response for %(path)s: %(msg)s%(orig_error)s" def __init__(self, path, msg, orig_error=None, headers=None): self.path = path if orig_error is None: - orig_error = '' + orig_error = "" else: # This is reached for obscure and unusual errors so we want to # preserve as much info as possible to ease debug. - orig_error = f': {orig_error!r}' + orig_error = f": {orig_error!r}" self.headers = headers TransportError.__init__(self, msg, orig_error=orig_error) class UnexpectedHttpStatus(InvalidHttpResponse): - _fmt = "Unexpected HTTP status %(code)d for %(path)s: %(extra)s" def __init__(self, path, code, extra=None, headers=None): self.path = path self.code = code - self.extra = extra or '' - full_msg = 'status code %d unexpected' % code + self.extra = extra or "" + full_msg = "status code %d unexpected" % code if extra is not None: - full_msg += ': ' + extra - InvalidHttpResponse.__init__( - self, path, full_msg, headers=headers) + full_msg += ": " + extra + InvalidHttpResponse.__init__(self, path, full_msg, headers=headers) class BadHttpRequest(UnexpectedHttpStatus): - _fmt = "Bad http request for %(path)s: %(reason)s" def __init__(self, path, reason): @@ -1206,7 +1134,6 @@ def __init__(self, path, reason): class InvalidHttpRange(InvalidHttpResponse): - _fmt = "Invalid http range %(range)r for %(path)s: %(msg)s" def __init__(self, path, range, msg): @@ -1228,7 +1155,6 @@ def __init__(self, path, msg): class InvalidHttpContentType(InvalidHttpResponse): - _fmt = 'Invalid http Content-type "%(ctype)s" for %(path)s: %(msg)s' def __init__(self, path, ctype, msg): @@ -1237,31 +1163,27 @@ def __init__(self, path, ctype, msg): class RedirectRequested(TransportError): - - _fmt = '%(source)s is%(permanently)s redirected to %(target)s' + _fmt = "%(source)s is%(permanently)s redirected to %(target)s" def __init__(self, source, target, is_permanent=False): self.source = source self.target = target if is_permanent: - self.permanently = ' permanently' + self.permanently = " permanently" else: - self.permanently = '' + self.permanently = "" TransportError.__init__(self) class TooManyRedirections(TransportError): - _fmt = "Too many redirections" class ConflictsInTree(BzrError): - _fmt = "Working tree has conflicts." 
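# Editor's note: illustrative sketch, not part of this patch. The error classes in
# this file rely on BzrError saving constructor keyword arguments as instance
# attributes and interpolating them into the class-level _fmt template when the
# exception is rendered. Assuming that standard behaviour, a hypothetical subclass
# behaves like this:
from breezy.errors import BzrError

class ExampleThingMissing(BzrError):  # hypothetical class, for illustration only
    _fmt = 'Missing thing "%(path)s": %(reason)s'

print(str(ExampleThingMissing(path="foo", reason="not found")))
# -> Missing thing "foo": not found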
class DependencyNotPresent(BzrError): - _fmt = 'Unable to import library "%(library)s": %(error)s' def __init__(self, library, error): @@ -1269,17 +1191,17 @@ def __init__(self, library, error): class WorkingTreeNotRevision(BzrError): - - _fmt = ("The working tree for %(basedir)s has changed since" - " the last commit, but weave merge requires that it be" - " unchanged") + _fmt = ( + "The working tree for %(basedir)s has changed since" + " the last commit, but weave merge requires that it be" + " unchanged" + ) def __init__(self, tree): BzrError.__init__(self, basedir=tree.basedir) class GraphCycleError(BzrError): - _fmt = "Cycle in graph %(graph)r" def __init__(self, graph): @@ -1288,26 +1210,27 @@ def __init__(self, graph): class WritingCompleted(InternalBzrError): - - _fmt = ("The MediumRequest '%(request)s' has already had finish_writing " - "called upon it - accept bytes may not be called anymore.") + _fmt = ( + "The MediumRequest '%(request)s' has already had finish_writing " + "called upon it - accept bytes may not be called anymore." + ) def __init__(self, request): self.request = request class WritingNotComplete(InternalBzrError): - - _fmt = ("The MediumRequest '%(request)s' has not has finish_writing " - "called upon it - until the write phase is complete no " - "data may be read.") + _fmt = ( + "The MediumRequest '%(request)s' has not has finish_writing " + "called upon it - until the write phase is complete no " + "data may be read." + ) def __init__(self, request): self.request = request class NotConflicted(BzrError): - _fmt = "File %(filename)s is not conflicted." def __init__(self, filename): @@ -1316,7 +1239,6 @@ def __init__(self, filename): class MediumNotConnected(InternalBzrError): - _fmt = """The medium '%(medium)s' is not connected.""" def __init__(self, medium): @@ -1324,12 +1246,10 @@ def __init__(self, medium): class MustUseDecorated(Exception): - _fmt = "A decorating function has requested its original command be used." class NoBundleFound(BzrError): - _fmt = 'No bundle was found in "%(filename)s".' def __init__(self, filename): @@ -1338,7 +1258,6 @@ def __init__(self, filename): class BundleNotSupported(BzrError): - _fmt = "Unable to handle bundle version %(version)s: %(msg)s" def __init__(self, version, msg): @@ -1348,9 +1267,7 @@ def __init__(self, version, msg): class MissingText(BzrError): - - _fmt = ("Branch %(base)s is missing revision" - " %(text_revision)s of %(file_id)s") + _fmt = "Branch %(base)s is missing revision" " %(text_revision)s of %(file_id)s" def __init__(self, branch, text_revision, file_id): BzrError.__init__(self) @@ -1361,12 +1278,10 @@ def __init__(self, branch, text_revision, file_id): class DuplicateKey(BzrError): - _fmt = "Key %(key)s is already present in map" class DuplicateHelpPrefix(BzrError): - _fmt = "The prefix %(prefix)s is in the help search path twice." def __init__(self, prefix): @@ -1374,7 +1289,6 @@ def __init__(self, prefix): class BzrBadParameter(InternalBzrError): - _fmt = "Bad parameter: %(param)r" # This exception should never be thrown, but it is a base class for all @@ -1386,34 +1300,34 @@ def __init__(self, param): class BzrBadParameterNotUnicode(BzrBadParameter): - _fmt = "Parameter %(param)s is neither unicode nor utf8." 
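# Editor's note: illustrative aside, not part of this patch. Many _fmt strings in the
# surrounding hunks are rewritten by the formatter as adjacent string literals; Python
# concatenates those at compile time, so the resulting message text is unchanged:
template = (
    "Could not move %(from_path)s%(operator)s %(to_path)s" "%(_has_extra)s%(extra)s"
)
assert template == (
    "Could not move %(from_path)s%(operator)s %(to_path)s%(_has_extra)s%(extra)s"
)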
class BzrMoveFailedError(BzrError): + _fmt = ( + "Could not move %(from_path)s%(operator)s %(to_path)s" "%(_has_extra)s%(extra)s" + ) - _fmt = ("Could not move %(from_path)s%(operator)s %(to_path)s" - "%(_has_extra)s%(extra)s") - - def __init__(self, from_path='', to_path='', extra=None): + def __init__(self, from_path="", to_path="", extra=None): from .osutils import splitpath + BzrError.__init__(self) if extra: - self.extra, self._has_extra = extra, ': ' + self.extra, self._has_extra = extra, ": " else: - self.extra = self._has_extra = '' + self.extra = self._has_extra = "" has_from = len(from_path) > 0 has_to = len(to_path) > 0 if has_from: self.from_path = splitpath(from_path)[-1] else: - self.from_path = '' + self.from_path = "" if has_to: self.to_path = splitpath(to_path)[-1] else: - self.to_path = '' + self.to_path = "" self.operator = "" if has_from and has_to: @@ -1427,45 +1341,39 @@ def __init__(self, from_path='', to_path='', extra=None): class BzrRenameFailedError(BzrMoveFailedError): - - _fmt = ("Could not rename %(from_path)s%(operator)s %(to_path)s" - "%(_has_extra)s%(extra)s") + _fmt = ( + "Could not rename %(from_path)s%(operator)s %(to_path)s" + "%(_has_extra)s%(extra)s" + ) def __init__(self, from_path, to_path, extra=None): BzrMoveFailedError.__init__(self, from_path, to_path, extra) class BzrBadParameterNotString(BzrBadParameter): - _fmt = "Parameter %(param)s is not a string or unicode string." class BzrBadParameterMissing(BzrBadParameter): - _fmt = "Parameter %(param)s is required but not present." class BzrBadParameterUnicode(BzrBadParameter): - - _fmt = ("Parameter %(param)s is unicode but" - " only byte-strings are permitted.") + _fmt = "Parameter %(param)s is unicode but" " only byte-strings are permitted." class BzrBadParameterContainsNewline(BzrBadParameter): - _fmt = "Parameter %(param)s contains a newline." class ParamikoNotPresent(DependencyNotPresent): - _fmt = "Unable to import paramiko (required for sftp support): %(error)s" def __init__(self, error): - DependencyNotPresent.__init__(self, 'paramiko', error) + DependencyNotPresent.__init__(self, "paramiko", error) class UninitializableFormat(BzrError): - _fmt = "Format %(format)s cannot be initialised by this version of brz." def __init__(self, format): @@ -1474,19 +1382,19 @@ def __init__(self, format): class BadConversionTarget(BzrError): - - _fmt = "Cannot convert from format %(from_format)s to format %(format)s." \ + _fmt = ( + "Cannot convert from format %(from_format)s to format %(format)s." " %(problem)s" + ) def __init__(self, problem, format, from_format=None): BzrError.__init__(self) self.problem = problem self.format = format - self.from_format = from_format or '(unspecified)' + self.from_format = from_format or "(unspecified)" class NoDiffFound(BzrError): - _fmt = 'Could not find an appropriate Differ for file "%(path)s"' def __init__(self, path): @@ -1494,7 +1402,6 @@ def __init__(self, path): class ExecutableMissing(BzrError): - _fmt = "%(exe_name)s could not be found on this machine" def __init__(self, exe_name): @@ -1502,7 +1409,6 @@ def __init__(self, exe_name): class NoDiff(BzrError): - _fmt = "Diff is not installed on this machine: %(msg)s" def __init__(self, msg): @@ -1510,12 +1416,10 @@ def __init__(self, msg): class NoDiff3(BzrError): - _fmt = "Diff3 is not installed on this machine." class ExistingLimbo(BzrError): - _fmt = """This tree contains left-over files from a failed operation. 
Please examine %(limbo_dir)s to see if it contains any files you wish to keep, and delete it when you are done.""" @@ -1526,7 +1430,6 @@ def __init__(self, limbo_dir): class ExistingPendingDeletion(BzrError): - _fmt = """This tree contains left-over files from a failed operation. Please examine %(pending_deletion)s to see if it contains any files you wish to keep, and delete it when you are done.""" @@ -1536,56 +1439,52 @@ def __init__(self, pending_deletion): class ImmortalPendingDeletion(BzrError): - - _fmt = ("Unable to delete transform temporary directory " - "%(pending_deletion)s. Please examine %(pending_deletion)s to see if it " - "contains any files you wish to keep, and delete it when you are done.") + _fmt = ( + "Unable to delete transform temporary directory " + "%(pending_deletion)s. Please examine %(pending_deletion)s to see if it " + "contains any files you wish to keep, and delete it when you are done." + ) def __init__(self, pending_deletion): BzrError.__init__(self, pending_deletion=pending_deletion) class OutOfDateTree(BzrError): - _fmt = "Working tree is out of date, please run 'brz update'.%(more)s" def __init__(self, tree, more=None): if more is None: - more = '' + more = "" else: - more = ' ' + more + more = " " + more BzrError.__init__(self) self.tree = tree self.more = more class PublicBranchOutOfDate(BzrError): - - _fmt = 'Public branch "%(public_location)s" lacks revision '\ - '"%(revstring)s".' + _fmt = 'Public branch "%(public_location)s" lacks revision ' '"%(revstring)s".' def __init__(self, public_location, revstring): import breezy.urlutils as urlutils - public_location = urlutils.unescape_for_display(public_location, - 'ascii') - BzrError.__init__(self, public_location=public_location, - revstring=revstring) + public_location = urlutils.unescape_for_display(public_location, "ascii") + BzrError.__init__(self, public_location=public_location, revstring=revstring) -class MergeModifiedFormatError(BzrError): +class MergeModifiedFormatError(BzrError): _fmt = "Error in merge modified format" class ConflictFormatError(BzrError): - _fmt = "Format error in conflict listings" class CorruptRepository(BzrError): - - _fmt = ("An error has been detected in the repository %(repo_path)s.\n" - "Please run brz reconcile on this repository.") + _fmt = ( + "An error has been detected in the repository %(repo_path)s.\n" + "Please run brz reconcile on this repository." + ) def __init__(self, repo): BzrError.__init__(self) @@ -1595,8 +1494,10 @@ def __init__(self, repo): class InconsistentDelta(BzrError): """Used when we get a delta that is not valid.""" - _fmt = ("An inconsistent delta was supplied involving %(path)r," - " %(file_id)r\nreason: %(reason)s") + _fmt = ( + "An inconsistent delta was supplied involving %(path)r," + " %(file_id)r\nreason: %(reason)s" + ) def __init__(self, path, file_id, reason): BzrError.__init__(self) @@ -1608,8 +1509,7 @@ def __init__(self, path, file_id, reason): class InconsistentDeltaDelta(InconsistentDelta): """Used when we get a delta that is not valid.""" - _fmt = ("An inconsistent delta was supplied: %(delta)r" - "\nreason: %(reason)s") + _fmt = "An inconsistent delta was supplied: %(delta)r" "\nreason: %(reason)s" def __init__(self, delta, reason): BzrError.__init__(self) @@ -1618,7 +1518,6 @@ def __init__(self, delta, reason): class UpgradeRequired(BzrError): - _fmt = "To use this feature you must upgrade your branch at %(path)s." 
def __init__(self, path): @@ -1627,25 +1526,22 @@ def __init__(self, path): class RepositoryUpgradeRequired(UpgradeRequired): - _fmt = "To use this feature you must upgrade your repository at %(path)s." class RichRootUpgradeRequired(UpgradeRequired): - - _fmt = ("To use this feature you must upgrade your branch at %(path)s to" - " a format which supports rich roots.") + _fmt = ( + "To use this feature you must upgrade your branch at %(path)s to" + " a format which supports rich roots." + ) class LocalRequiresBoundBranch(BzrError): - _fmt = "Cannot perform local-only commits on unbound branches." class UnsupportedOperation(BzrError): - - _fmt = ("The method %(mname)s is not supported on" - " objects of type %(tname)s.") + _fmt = "The method %(mname)s is not supported on" " objects of type %(tname)s." def __init__(self, method, method_self): self.method = method @@ -1654,8 +1550,7 @@ def __init__(self, method, method_self): class FetchLimitUnsupported(UnsupportedOperation): - - fmt = ("InterBranch %(interbranch)r does not support fetching limits.") + fmt = "InterBranch %(interbranch)r does not support fetching limits." def __init__(self, interbranch): BzrError.__init__(self, interbranch=interbranch) @@ -1675,7 +1570,6 @@ def __init__(self, format): class GhostTagsNotSupported(BzrError): - _fmt = "Ghost tags not supported by format %(format)r." def __init__(self, format): @@ -1683,12 +1577,10 @@ def __init__(self, format): class BinaryFile(BzrError): - _fmt = "File is binary but should be text." class IllegalPath(BzrError): - _fmt = "The path %(path)s is not permitted on this platform" def __init__(self, path): @@ -1697,7 +1589,6 @@ def __init__(self, path): class TestamentMismatch(BzrError): - _fmt = """Testament did not match expected value. For revision_id {%(revision_id)s}, expected {%(expected)s}, measured {%(measured)s}""" @@ -1709,7 +1600,6 @@ def __init__(self, revision_id, expected, measured): class NotABundle(BzrError): - _fmt = "Not a bzr revision-bundle: %(text)r" def __init__(self, text): @@ -1718,7 +1608,6 @@ def __init__(self, text): class BadBundle(BzrError): - _fmt = "Bad bzr revision-bundle: %(text)r" def __init__(self, text): @@ -1727,22 +1616,18 @@ def __init__(self, text): class MalformedHeader(BadBundle): - _fmt = "Malformed bzr revision-bundle header: %(text)r" class MalformedPatches(BadBundle): - _fmt = "Malformed patches in bzr revision-bundle: %(text)r" class MalformedFooter(BadBundle): - _fmt = "Malformed footer in bzr revision-bundle: %(text)r" class UnsupportedEOLMarker(BadBundle): - _fmt = "End of line marker was not \\n in bzr revision-bundle" def __init__(self): @@ -1752,7 +1637,6 @@ def __init__(self): class IncompatibleBundleFormat(BzrError): - _fmt = "Bundle format %(bundle_format)s is incompatible with %(other)s" def __init__(self, bundle_format, other): @@ -1762,12 +1646,10 @@ def __init__(self, bundle_format, other): class RootNotRich(BzrError): - _fmt = """This operation requires rich root data storage""" class NoSmartMedium(InternalBzrError): - _fmt = "The transport '%(transport)s' cannot tunnel the smart protocol." def __init__(self, transport): @@ -1775,7 +1657,6 @@ def __init__(self, transport): class UnknownSSH(BzrError): - _fmt = "Unrecognised value for BRZ_SSH environment variable: %(vendor)s" def __init__(self, vendor): @@ -1784,16 +1665,19 @@ def __init__(self, vendor): class SSHVendorNotFound(BzrError): - - _fmt = ("Don't know how to handle SSH connections." 
- " Please set BRZ_SSH environment variable.") + _fmt = ( + "Don't know how to handle SSH connections." + " Please set BRZ_SSH environment variable." + ) class GhostRevisionsHaveNoRevno(BzrError): """When searching for revnos, if we encounter a ghost, we are stuck.""" - _fmt = ("Could not determine revno for {%(revision_id)s} because" - " its ancestry shows a ghost at {%(ghost_revision_id)s}") + _fmt = ( + "Could not determine revno for {%(revision_id)s} because" + " its ancestry shows a ghost at {%(ghost_revision_id)s}" + ) def __init__(self, revision_id, ghost_revision_id): self.revision_id = revision_id @@ -1801,7 +1685,6 @@ def __init__(self, revision_id, ghost_revision_id): class GhostRevisionUnusableHere(BzrError): - _fmt = "Ghost revision {%(revision_id)s} cannot be used here." def __init__(self, revision_id): @@ -1819,8 +1702,9 @@ def __init__(self, firstline): class NoMergeSource(BzrError): """Raise if no merge source was specified for a merge directive.""" - _fmt = "A merge directive must provide either a bundle or a public"\ - " branch location." + _fmt = ( + "A merge directive must provide either a bundle or a public" " branch location." + ) class PatchVerificationFailed(BzrError): @@ -1842,10 +1726,12 @@ def __init__(self, patch_type): class TargetNotBranch(BzrError): """A merge directive's target branch is required, but isn't a branch.""" - _fmt = ("Your branch does not have all of the revisions required in " - "order to merge this merge directive and the target " - "location specified in the merge directive is not a branch: " - "%(location)s.") + _fmt = ( + "Your branch does not have all of the revisions required in " + "order to merge this merge directive and the target " + "location specified in the merge directive is not a branch: " + "%(location)s." + ) def __init__(self, location): BzrError.__init__(self) @@ -1853,7 +1739,6 @@ def __init__(self, location): class BadSubsumeSource(BzrError): - _fmt = "Can't subsume %(other_tree)s into %(tree)s. %(reason)s" def __init__(self, tree, other_tree, reason): @@ -1863,7 +1748,6 @@ def __init__(self, tree, other_tree, reason): class SubsumeTargetNeedsUpgrade(BzrError): - _fmt = """Subsume target %(other_tree)s needs to be upgraded.""" def __init__(self, other_tree): @@ -1871,7 +1755,6 @@ def __init__(self, other_tree): class NoSuchTag(BzrError): - _fmt = "No such tag: %(tag_name)s" def __init__(self, tag_name): @@ -1879,9 +1762,10 @@ def __init__(self, tag_name): class TagsNotSupported(BzrError): - - _fmt = ("Tags not supported by %(branch)s;" - " you may be able to use 'brz upgrade %(branch_url)s'.") + _fmt = ( + "Tags not supported by %(branch)s;" + " you may be able to use 'brz upgrade %(branch_url)s'." + ) def __init__(self, branch): self.branch = branch @@ -1889,7 +1773,6 @@ def __init__(self, branch): class TagAlreadyExists(BzrError): - _fmt = "Tag %(tag_name)s already exists." 
def __init__(self, tag_name): @@ -1897,7 +1780,6 @@ def __init__(self, tag_name): class UnexpectedSmartServerResponse(BzrError): - _fmt = "Could not understand response from smart server: %(response_tuple)r" def __init__(self, response_tuple): @@ -1924,7 +1806,6 @@ def __init__(self, error_tuple): class RepositoryDataStreamError(BzrError): - _fmt = "Corrupt or incompatible data stream: %(reason)s" def __init__(self, reason): @@ -1932,65 +1813,66 @@ def __init__(self, reason): class UncommittedChanges(BzrError): - - _fmt = ('Working tree "%(display_url)s" has uncommitted changes' - ' (See brz status).%(more)s') + _fmt = ( + 'Working tree "%(display_url)s" has uncommitted changes' + " (See brz status).%(more)s" + ) def __init__(self, tree, more=None): if more is None: - more = '' + more = "" else: - more = ' ' + more + more = " " + more import breezy.urlutils as urlutils + user_url = getattr(tree, "user_url", None) if user_url is None: display_url = str(tree) else: - display_url = urlutils.unescape_for_display(user_url, 'ascii') + display_url = urlutils.unescape_for_display(user_url, "ascii") BzrError.__init__(self, tree=tree, display_url=display_url, more=more) class StoringUncommittedNotSupported(BzrError): - - _fmt = ('Branch "%(display_url)s" does not support storing uncommitted' - ' changes.') + _fmt = 'Branch "%(display_url)s" does not support storing uncommitted' " changes." def __init__(self, branch): import breezy.urlutils as urlutils + user_url = getattr(branch, "user_url", None) if user_url is None: display_url = str(branch) else: - display_url = urlutils.unescape_for_display(user_url, 'ascii') + display_url = urlutils.unescape_for_display(user_url, "ascii") BzrError.__init__(self, branch=branch, display_url=display_url) class ShelvedChanges(UncommittedChanges): - - _fmt = ('Working tree "%(display_url)s" has shelved changes' - ' (See brz shelve --list).%(more)s') + _fmt = ( + 'Working tree "%(display_url)s" has shelved changes' + " (See brz shelve --list).%(more)s" + ) class UnableEncodePath(BzrError): - - _fmt = ('Unable to encode %(kind)s path %(path)r in ' - 'user encoding %(user_encoding)s') + _fmt = ( + "Unable to encode %(kind)s path %(path)r in " "user encoding %(user_encoding)s" + ) def __init__(self, path, kind): from .osutils import get_user_encoding + self.path = path self.kind = kind self.user_encoding = get_user_encoding() class CannotBindAddress(BzrError): - _fmt = 'Cannot bind address "%(host)s:%(port)i": %(orig_error)s.' def __init__(self, host, port, orig_error): # nb: in python2.4 socket.error doesn't have a useful repr - BzrError.__init__(self, host=host, port=port, - orig_error=repr(orig_error.args)) + BzrError.__init__(self, host=host, port=port, orig_error=repr(orig_error.args)) class TipChangeRejected(BzrError): @@ -2005,7 +1887,6 @@ def __init__(self, msg): class JailBreak(BzrError): - _fmt = "An attempt to access a url outside the server jail was made: '%(url)s'." def __init__(self, url): @@ -2013,14 +1894,14 @@ def __init__(self, url): class UserAbort(BzrError): - - _fmt = 'The user aborted the operation.' + _fmt = "The user aborted the operation." 
class UnresumableWriteGroup(BzrError): - - _fmt = ("Repository %(repository)s cannot resume write group " - "%(write_groups)r: %(reason)s") + _fmt = ( + "Repository %(repository)s cannot resume write group " + "%(write_groups)r: %(reason)s" + ) internal_error = True @@ -2031,8 +1912,7 @@ def __init__(self, repository, write_groups, reason): class UnsuspendableWriteGroup(BzrError): - - _fmt = ("Repository %(repository)s cannot suspend a write group.") + _fmt = "Repository %(repository)s cannot suspend a write group." internal_error = True @@ -2041,9 +1921,10 @@ def __init__(self, repository): class LossyPushToSameVCS(BzrError): - - _fmt = ("Lossy push not possible between %(source_branch)r and " - "%(target_branch)r that are in the same VCS.") + _fmt = ( + "Lossy push not possible between %(source_branch)r and " + "%(target_branch)r that are in the same VCS." + ) internal_error = True @@ -2053,9 +1934,10 @@ def __init__(self, source_branch, target_branch): class NoRoundtrippingSupport(BzrError): - - _fmt = ("Roundtripping is not supported between %(source_branch)r and " - "%(target_branch)r.") + _fmt = ( + "Roundtripping is not supported between %(source_branch)r and " + "%(target_branch)r." + ) internal_error = True @@ -2065,18 +1947,20 @@ def __init__(self, source_branch, target_branch): class RecursiveBind(BzrError): - - _fmt = ('Branch "%(branch_url)s" appears to be bound to itself. ' - 'Please use `brz unbind` to fix.') + _fmt = ( + 'Branch "%(branch_url)s" appears to be bound to itself. ' + "Please use `brz unbind` to fix." + ) def __init__(self, branch_url): self.branch_url = branch_url class UnsupportedKindChange(BzrError): - - _fmt = ("Kind change from %(from_kind)s to %(to_kind)s for " - "%(path)s not supported by format %(format)r") + _fmt = ( + "Kind change from %(from_kind)s to %(to_kind)s for " + "%(path)s not supported by format %(format)r" + ) def __init__(self, path, from_kind, to_kind, format): self.path = path @@ -2086,16 +1970,19 @@ def __init__(self, path, from_kind, to_kind, format): class ChangesAlreadyStored(CommandError): - - _fmt = ('Cannot store uncommitted changes because this branch already' - ' stores uncommitted changes.') + _fmt = ( + "Cannot store uncommitted changes because this branch already" + " stores uncommitted changes." + ) class RevnoOutOfBounds(InternalBzrError): - - _fmt = ("The requested revision number %(revno)d is outside of the " - "expected boundaries (%(minimum)d <= %(maximum)d).") + _fmt = ( + "The requested revision number %(revno)d is outside of the " + "expected boundaries (%(minimum)d <= %(maximum)d)." + ) def __init__(self, revno, bounds): InternalBzrError.__init__( - self, revno=revno, minimum=bounds[0], maximum=bounds[1]) + self, revno=revno, minimum=bounds[0], maximum=bounds[1] + ) diff --git a/breezy/export.py b/breezy/export.py index 673350a80c..8d74123ab8 100644 --- a/breezy/export.py +++ b/breezy/export.py @@ -23,9 +23,16 @@ from . import archive, errors, osutils, trace -def export(tree, dest, format=None, root=None, subdir=None, - per_file_timestamps=False, fileobj=None, - recurse_nested=False): +def export( + tree, + dest, + format=None, + root=None, + subdir=None, + per_file_timestamps=False, + fileobj=None, + recurse_nested=False, +): """Export the given Tree to the specific destination. 
Args: @@ -58,10 +65,11 @@ def export(tree, dest, format=None, root=None, subdir=None, if not per_file_timestamps: force_mtime = time.time() - if getattr(tree, '_repository', None): + if getattr(tree, "_repository", None): try: force_mtime = tree._repository.get_revision( - tree.get_revision_id()).timestamp + tree.get_revision_id() + ).timestamp except errors.NoSuchRevision: pass except errors.UnsupportedOperation: @@ -69,35 +77,40 @@ def export(tree, dest, format=None, root=None, subdir=None, else: force_mtime = None - trace.mutter('export version %r', tree) + trace.mutter("export version %r", tree) - if format == 'dir': + if format == "dir": # TODO(jelmer): If the tree is remote (e.g. HPSS, Git Remote), # then we should stream a tar file and unpack that on the fly. with tree.lock_read(): - for _unused in dir_exporter_generator(tree, dest, root, subdir, - force_mtime, - recurse_nested=recurse_nested): + for _unused in dir_exporter_generator( + tree, dest, root, subdir, force_mtime, recurse_nested=recurse_nested + ): pass return with tree.lock_read(): - chunks = tree.archive(format, dest, root=root, - subdir=subdir, force_mtime=force_mtime, - recurse_nested=recurse_nested) - if dest == '-': + chunks = tree.archive( + format, + dest, + root=root, + subdir=subdir, + force_mtime=force_mtime, + recurse_nested=recurse_nested, + ) + if dest == "-": for chunk in chunks: - getattr(sys.stdout, 'buffer', sys.stdout).write(chunk) + getattr(sys.stdout, "buffer", sys.stdout).write(chunk) elif fileobj is not None: for chunk in chunks: fileobj.write(chunk) else: - with open(dest, 'wb') as f: + with open(dest, "wb") as f: for chunk in chunks: f.write(chunk) -def guess_format(filename, default='dir'): +def guess_format(filename, default="dir"): """Guess the export format based on a file name. :param filename: Filename to guess from @@ -113,13 +126,13 @@ def guess_format(filename, default='dir'): def get_root_name(dest): """Get just the root name for an export.""" global _exporter_extensions - if dest == '-': + if dest == "-": # Exporting to -/foo doesn't make sense so use relative paths. 
- return '' + return "" dest = os.path.basename(dest) for ext in archive.format_registry.extensions: if dest.endswith(ext): - return dest[:-len(ext)] + return dest[: -len(ext)] return dest @@ -132,24 +145,24 @@ def _export_iter_entries(tree, subdir, skip_special=True, recurse_nested=False): :return: iterator over tuples with final path, tree path and inventory entry for each entry to export """ - if subdir == '': + if subdir == "": subdir = None if subdir is not None: - subdir = subdir.rstrip('/') + subdir = subdir.rstrip("/") entries = tree.iter_entries_by_dir(recurse_nested=recurse_nested) for path, entry in entries: - if path == '': + if path == "": continue if skip_special and tree.is_special_path(path): continue if path == subdir: - if entry.kind == 'directory': + if entry.kind == "directory": continue final_path = entry.name elif subdir is not None: - if path.startswith(subdir + '/'): - final_path = path[len(subdir) + 1:] + if path.startswith(subdir + "/"): + final_path = path[len(subdir) + 1 :] else: continue else: @@ -160,9 +173,9 @@ def _export_iter_entries(tree, subdir, skip_special=True, recurse_nested=False): yield final_path, path, entry -def dir_exporter_generator(tree, dest, root, subdir=None, - force_mtime=None, fileobj=None, - recurse_nested=False): +def dir_exporter_generator( + tree, dest, root, subdir=None, force_mtime=None, fileobj=None, recurse_nested=False +): """Return a generator that exports this tree to a new directory. `dest` should either not exist or should be empty. If it does not exist it @@ -176,8 +189,7 @@ def dir_exporter_generator(tree, dest, root, subdir=None, except FileExistsError as e: # check if directory empty if os.listdir(dest) != []: - raise errors.BzrError( - "Can't export tree to non-empty directory.") from e + raise errors.BzrError("Can't export tree to non-empty directory.") from e # Iterate everything, building up the files we will want to export, and # creating the directories and symlinks that we need. # This tracks (None, (destination_path, executable)) @@ -185,8 +197,7 @@ def dir_exporter_generator(tree, dest, root, subdir=None, # Note in the case of revision trees, this does trigger a double inventory # lookup, hopefully it isn't too expensive. 
to_fetch = [] - for dp, tp, ie in _export_iter_entries( - tree, subdir, recurse_nested=recurse_nested): + for dp, tp, ie in _export_iter_entries(tree, subdir, recurse_nested=recurse_nested): fullpath = osutils.pathjoin(dest, dp) if ie.kind == "file": to_fetch.append((tp, (dp, tp, None))) @@ -198,21 +209,24 @@ def dir_exporter_generator(tree, dest, root, subdir=None, os.symlink(symlink_target, fullpath) except OSError as e: raise errors.BzrError( - f"Failed to create symlink {fullpath!r} -> {symlink_target!r}, error: {e}") from e + f"Failed to create symlink {fullpath!r} -> {symlink_target!r}, error: {e}" + ) from e else: - raise errors.BzrError(f"don't know how to export {{{tp}}} of kind {ie.kind!r}") + raise errors.BzrError( + f"don't know how to export {{{tp}}} of kind {ie.kind!r}" + ) yield # The data returned here can be in any order, but we've already created all # the directories - flags = os.O_CREAT | os.O_TRUNC | os.O_WRONLY | getattr(os, 'O_BINARY', 0) + flags = os.O_CREAT | os.O_TRUNC | os.O_WRONLY | getattr(os, "O_BINARY", 0) for (relpath, treepath, _unused_none), chunks in tree.iter_files_bytes(to_fetch): fullpath = osutils.pathjoin(dest, relpath) # We set the mode and let the umask sort out the file info mode = 0o666 if tree.is_executable(treepath): mode = 0o777 - with os.fdopen(os.open(fullpath, flags, mode), 'wb') as out: + with os.fdopen(os.open(fullpath, flags, mode), "wb") as out: out.writelines(chunks) if force_mtime is not None: mtime = force_mtime diff --git a/breezy/export_pot.py b/breezy/export_pot.py index 5712740ec4..346754300e 100644 --- a/breezy/export_pot.py +++ b/breezy/export_pot.py @@ -40,33 +40,35 @@ def _escape(s): - s = (s.replace('\\', '\\\\') - .replace('\n', '\\n') - .replace('\r', '\\r') - .replace('\t', '\\t') - .replace('"', '\\"') - ) + s = ( + s.replace("\\", "\\\\") + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\t", "\\t") + .replace('"', '\\"') + ) return s def _normalize(s): # This converts the various Python string types into a format that # is appropriate for .po files, namely much closer to C style. - lines = s.split('\n') + lines = s.split("\n") if len(lines) == 1: s = '"' + _escape(s) + '"' else: if not lines[-1]: del lines[-1] - lines[-1] = lines[-1] + '\n' + lines[-1] = lines[-1] + "\n" lineterm = '\\n"\n"' s = '""\n"' + lineterm.join(map(_escape, lines)) + '"' return s -def _parse_source(source_text, filename=''): +def _parse_source(source_text, filename=""): """Get object to lineno mappings from given source_text.""" import ast + cls_to_lineno = {} str_to_lineno = {} for node in ast.walk(ast.parse(source_text, filename)): @@ -79,7 +81,9 @@ def _parse_source(source_text, filename=''): # string terminates on. It's more useful to have the line the # string begins on. Unfortunately, counting back newlines is # only an approximation as the AST is ignorant of escaping. 
- str_to_lineno[node.s] = node.lineno - (0 if sys.version_info >= (3, 8) else node.s.count('\n')) + str_to_lineno[node.s] = node.lineno - ( + 0 if sys.version_info >= (3, 8) else node.s.count("\n") + ) return cls_to_lineno, str_to_lineno @@ -98,8 +102,12 @@ def from_module(cls, module): sourcepath = inspect.getsourcefile(module) # TODO: fix this to do the right thing rather than rely on cwd relpath = os.path.relpath(sourcepath) - return cls(relpath, - _source_info=_parse_source("".join(inspect.findsource(module)[0]), module.__file__)) + return cls( + relpath, + _source_info=_parse_source( + "".join(inspect.findsource(module)[0]), module.__file__ + ), + ) def from_class(self, cls): """Get new context with same details but lineno of class in source.""" @@ -108,8 +116,9 @@ def from_class(self, cls): except (AttributeError, KeyError): mutter("Definition of %r not found in %r", cls, self.path) return self - return self.__class__(self.path, lineno, - (self._cls_to_lineno, self._str_to_lineno)) + return self.__class__( + self.path, lineno, (self._cls_to_lineno, self._str_to_lineno) + ) def from_string(self, string): """Get new context with same details but lineno of string in source.""" @@ -118,8 +127,9 @@ def from_string(self, string): except (AttributeError, KeyError): mutter("String %r not found in %r", string[:20], self.path) return self - return self.__class__(self.path, lineno, - (self._cls_to_lineno, self._str_to_lineno)) + return self.__class__( + self.path, lineno, (self._cls_to_lineno, self._str_to_lineno) + ) class _PotExporter: @@ -139,12 +149,11 @@ def poentry(self, path, lineno, s, comment=None): return self._msgids.add(s) if comment is None: - comment = '' + comment = "" else: comment = f"# {comment}\n" mutter("Exporting msg %r at line %d in %r", s[:20], lineno, path) - line = ( - f"#: {path}:{lineno}\n{comment}msgid {_normalize(s)}\nmsgstr \"\"\n\n") + line = f'#: {path}:{lineno}\n{comment}msgid {_normalize(s)}\nmsgstr ""\n\n' self.outf.write(line) def poentry_in_context(self, context, string, comment=None): @@ -153,12 +162,12 @@ def poentry_in_context(self, context, string, comment=None): def poentry_per_paragraph(self, path, lineno, msgid, include=None): # TODO: How to split long help? - paragraphs = msgid.split('\n\n') + paragraphs = msgid.split("\n\n") if include is not None: paragraphs = filter(include, paragraphs) for p in paragraphs: self.poentry(path, lineno, p) - lineno += p.count('\n') + 2 + lineno += p.count("\n") + 2 def get_context(self, obj): module = inspect.getmodule(obj) @@ -173,20 +182,18 @@ def get_context(self, obj): def _write_option(exporter, context, opt, note): - if getattr(opt, 'hidden', False): + if getattr(opt, "hidden", False): return optname = opt.name - if getattr(opt, 'title', None): - exporter.poentry_in_context(context, opt.title, - f"title of {optname!r} {note}") + if getattr(opt, "title", None): + exporter.poentry_in_context(context, opt.title, f"title of {optname!r} {note}") for name, _, _, helptxt in opt.iter_switches(): if name != optname: if opt.is_hidden(name): continue name = "=".join([optname, name]) if helptxt: - exporter.poentry_in_context(context, helptxt, - f"help of {name!r} {note}") + exporter.poentry_in_context(context, helptxt, f"help of {name!r} {note}") def _standard_options(exporter): @@ -214,11 +221,10 @@ def _write_command_help(exporter, cmd): def exclude_usage(p): # ':Usage:' has special meaning in help topics. # This is usage example of command and should not be translated. 
- if p.splitlines()[0] != ':Usage:': + if p.splitlines()[0] != ":Usage:": return True - exporter.poentry_per_paragraph(dcontext.path, dcontext.lineno, doc, - exclude_usage) + exporter.poentry_per_paragraph(dcontext.path, dcontext.lineno, doc, exclude_usage) _command_options(exporter, context, cmd) @@ -241,10 +247,10 @@ def _command_helps(exporter, plugin_name=None): plugins = _mod_plugin.plugins() if plugin_name is not None and plugin_name not in plugins: - raise errors.BzrError(gettext('Plugin {} is not loaded').format(plugin_name)) + raise errors.BzrError(gettext("Plugin {} is not loaded").format(plugin_name)) core_plugins = { - name for name in plugins - if plugins[name].path().startswith(breezy.__path__[0])} + name for name in plugins if plugins[name].path().startswith(breezy.__path__[0]) + } # plugins for cmd_name in _mod_commands.plugin_command_names(): command = _mod_commands.get_cmd_object(cmd_name, False) @@ -258,8 +264,11 @@ def _command_helps(exporter, plugin_name=None): # skip non-core plugins # TODO: Support extracting from third party plugins. continue - note(gettext("Exporting messages from plugin command: {0} in {1}").format( - cmd_name, command.plugin_name())) + note( + gettext("Exporting messages from plugin command: {0} in {1}").format( + cmd_name, command.plugin_name() + ) + ) _write_command_help(exporter, command) @@ -289,16 +298,15 @@ def _help_topics(exporter): doc = topic_registry.get(key) if isinstance(doc, str): exporter.poentry_per_paragraph( - 'dummy/help_topics/' + key + '/detail.txt', - 1, doc) + "dummy/help_topics/" + key + "/detail.txt", 1, doc + ) elif callable(doc): # help topics from files exporter.poentry_per_paragraph( - 'en/help_topics/' + key + '.txt', - 1, doc(key)) + "en/help_topics/" + key + ".txt", 1, doc(key) + ) summary = topic_registry.get_summary(key) if summary is not None: - exporter.poentry('dummy/help_topics/' + key + '/summary.txt', - 1, summary) + exporter.poentry("dummy/help_topics/" + key + "/summary.txt", 1, summary) def export_pot(outf, plugin=None, include_duplicates=False): diff --git a/breezy/externalcommand.py b/breezy/externalcommand.py index 4f2089db72..4f18ad5a22 100644 --- a/breezy/externalcommand.py +++ b/breezy/externalcommand.py @@ -30,7 +30,8 @@ class ExternalCommand(Command): @classmethod def find_command(cls, cmd): import os.path - bzrpath = os.environ.get('BZRPATH', '') + + bzrpath = os.environ.get("BZRPATH", "") for dir in bzrpath.split(os.pathsep): # Empty directories are not real paths @@ -51,12 +52,12 @@ def name(self): return os.path.basename(self.path) def run(self, *args, **kwargs): - raise NotImplementedError(f'should not be called on {self!r}') + raise NotImplementedError(f"should not be called on {self!r}") def run_argv_aliases(self, argv, alias_argv=None): return os.spawnv(os.P_WAIT, self.path, [self.path] + argv) # noqa: S606 def help(self): - m = f'external command from {self.path}\n\n' - pipe = os.popen(f'{self.path} --help') # noqa: S605 + m = f"external command from {self.path}\n\n" + pipe = os.popen(f"{self.path} --help") # noqa: S605 return m + pipe.read() diff --git a/breezy/fetch_ghosts.py b/breezy/fetch_ghosts.py index 52ba2deaaf..c59b820fe0 100644 --- a/breezy/fetch_ghosts.py +++ b/breezy/fetch_ghosts.py @@ -22,15 +22,13 @@ class GhostFetcher: - @classmethod def from_cmdline(cls, other): - this_branch = Branch.open_containing('.')[0] + this_branch = Branch.open_containing(".")[0] if other is None: other = this_branch.get_parent() if other is None: - raise CommandError('No branch specified and no 
location' - ' saved.') + raise CommandError("No branch specified and no location" " saved.") else: note("Using saved location %s.", other) other_branch = Branch.open_containing(other)[0] diff --git a/breezy/fifo_cache.py b/breezy/fifo_cache.py index daf06a951b..a08886cfbe 100644 --- a/breezy/fifo_cache.py +++ b/breezy/fifo_cache.py @@ -29,9 +29,10 @@ def __init__(self, max_cache: int = 100, after_cleanup_count=None) -> None: if after_cleanup_count is None: self._after_cleanup_count = self._max_cache * 8 // 10 else: - self._after_cleanup_count = min(after_cleanup_count, - self._max_cache) - self._cleanup: Dict[Any, Callable[[], None]] = {} # map to cleanup functions when items are removed + self._after_cleanup_count = min(after_cleanup_count, self._max_cache) + self._cleanup: Dict[ + Any, Callable[[], None] + ] = {} # map to cleanup functions when items are removed self._queue: Deque[Any] = deque() # Track when things are accessed def __setitem__(self, key, value): @@ -79,8 +80,10 @@ def cleanup(self): while len(self) > self._after_cleanup_count: self._remove_oldest() if len(self._queue) != len(self): - raise AssertionError('The length of the queue should always equal' - f' the length of the dict. {len(self._queue)} != {len(self)}') + raise AssertionError( + "The length of the queue should always equal" + f" the length of the dict. {len(self._queue)} != {len(self)}" + ) def clear(self): """Clear out all of the cache.""" @@ -153,7 +156,7 @@ def update(self, *args, **kwargs): for key, val in args[0]: self.add(key, val) elif len(args) > 1: - raise TypeError(f'update expected at most 1 argument, got {len(args)}') + raise TypeError(f"update expected at most 1 argument, got {len(args)}") if kwargs: for key in kwargs: self.add(key, kwargs[key]) @@ -166,8 +169,9 @@ class FIFOSizeCache(FIFOCache): it restricts the cache to be cleaned based on the size of the data. """ - def __init__(self, max_size=1024 * 1024, after_cleanup_size=None, - compute_size=None): + def __init__( + self, max_size=1024 * 1024, after_cleanup_size=None, compute_size=None + ): """Create a new FIFOSizeCache. :param max_size: The max number of bytes to store before we start diff --git a/breezy/filter_tree.py b/breezy/filter_tree.py index 7148e8dc57..f447465bc1 100644 --- a/breezy/filter_tree.py +++ b/breezy/filter_tree.py @@ -43,7 +43,7 @@ def get_file_text(self, path): filters = self.filter_stack_callback(path) context = ContentFilterContext(path, self) contents = filtered_output_bytes(chunks, filters, context) - content = b''.join(contents) + content = b"".join(contents) return content def get_file(self, path): @@ -62,7 +62,8 @@ def iter_entries_by_dir(self, specific_files=None, recurse_nested=False): # updated to a narrower interface that only provides things guaranteed # cheaply available across all trees. -- mbp 20110705 return self.backing_tree.iter_entries_by_dir( - specific_files=specific_files, recurse_nested=recurse_nested) + specific_files=specific_files, recurse_nested=recurse_nested + ) def lock_read(self): return self.backing_tree.lock_read() diff --git a/breezy/filters/__init__.py b/breezy/filters/__init__.py index d025115454..c30c26badb 100644 --- a/breezy/filters/__init__.py +++ b/breezy/filters/__init__.py @@ -48,7 +48,6 @@ class ContentFilter: - def __init__(self, reader, writer): """Create a filter that converts content while reading and writing. 
@@ -97,8 +96,7 @@ def revision_id(self): """Id of revision that last changed this file.""" if self._revision_id is None: if self._tree is not None: - self._revision_id = self._tree.get_file_revision( - self._relpath) + self._revision_id = self._tree.get_file_revision(self._relpath) return self._revision_id def revision(self): @@ -106,7 +104,7 @@ def revision(self): if self._revision is None: rev_id = self.revision_id() if rev_id is not None: - repo = getattr(self._tree, '_repository', None) + repo = getattr(self._tree, "_repository", None) if repo is None: repo = self._tree.branch.repository self._revision = repo.get_revision(rev_id) @@ -126,7 +124,7 @@ def filtered_input_file(f, filters) -> tuple[BytesIO, int]: for filter in filters: if filter.reader is not None: chunks = filter.reader(chunks) - text = b''.join(chunks) + text = b"".join(chunks) return BytesIO(text), len(text) @@ -155,14 +153,13 @@ def internal_size_sha_file_byname(name, filters): name: path to file filters: the stack of filters to apply """ - with open(name, 'rb', 65000) as f: + with open(name, "rb", 65000) as f: if filters: f, size = filtered_input_file(f, filters) return osutils.size_sha_file(f) class FilteredStat: - def __init__(self, base, st_size=None): self.st_mode = base.st_mode self.st_size = st_size or base.st_size @@ -171,7 +168,9 @@ def __init__(self, base, st_size=None): # The registry of filter stacks indexed by name. -filter_stacks_registry = registry.Registry[str, Callable[[str], List[ContentFilter]], None]() +filter_stacks_registry = registry.Registry[ + str, Callable[[str], List[ContentFilter]], None +]() # Cache of preferences -> stack @@ -241,4 +240,4 @@ def _reset_registry(value=None): return original -filter_stacks_registry.register_lazy('eol', 'breezy.filters.eol', 'eol_lookup') +filter_stacks_registry.register_lazy("eol", "breezy.filters.eol", "eol_lookup") diff --git a/breezy/filters/eol.py b/breezy/filters/eol.py index 2647ac94aa..39bef2d5ae 100644 --- a/breezy/filters/eol.py +++ b/breezy/filters/eol.py @@ -26,43 +26,40 @@ from ..filters import ContentFilter # Real Unix newline - \n without \r before it -_UNIX_NL_RE = re.compile(br'(? 
Tuple[ - TagUpdates, Set[TagConflict]]: - if self.source.branch.repository.has_same_location(self.target.branch.repository): + def merge( + self, + overwrite: bool = False, + ignore_master: bool = False, + selector: Optional[TagSelector] = None, + ) -> Tuple[TagUpdates, Set[TagConflict]]: + if self.source.branch.repository.has_same_location( + self.target.branch.repository + ): return {}, set() updates = {} conflicts = [] @@ -123,38 +127,44 @@ def merge(self, overwrite: bool = False, ignore_master: bool = False, def get_changed_refs(old_refs): ret = dict(old_refs) - for ref_name, tag_name, peeled, unpeeled in ( - source_tag_refs.iteritems()): + for ref_name, tag_name, peeled, unpeeled in source_tag_refs.iteritems(): if selector and not selector(tag_name): continue if old_refs.get(ref_name) == unpeeled: pass elif overwrite or ref_name not in old_refs: ret[ref_name] = unpeeled - updates[tag_name] = self.target.branch.repository.lookup_foreign_revision_id( - peeled) + updates[ + tag_name + ] = self.target.branch.repository.lookup_foreign_revision_id(peeled) ref_to_tag_map[ref_name] = tag_name self.target.branch._tag_refs = None else: conflicts.append( - (tag_name, - self.source.branch.repository.lookup_foreign_revision_id(peeled), - self.target.branch.repository.lookup_foreign_revision_id( - old_refs[ref_name]))) + ( + tag_name, + self.source.branch.repository.lookup_foreign_revision_id( + peeled + ), + self.target.branch.repository.lookup_foreign_revision_id( + old_refs[ref_name] + ), + ) + ) return ret + result = self.target.branch.repository.controldir.send_pack( - get_changed_refs, lambda have, want: []) + get_changed_refs, lambda have, want: [] + ) if result is not None and not isinstance(result, dict): for ref, error in result.ref_status.items(): if error: - warning('unable to update ref %s: %s', - ref, error) + warning("unable to update ref %s: %s", ref, error) del updates[ref_to_tag_map[ref]] return updates, set(conflicts) class InterTagsFromGitToLocalGit(InterTags): - @classmethod def is_compatible(klass, source, target): if not isinstance(source, GitTags): @@ -166,7 +176,9 @@ def is_compatible(klass, source, target): return True def merge(self, overwrite=False, ignore_master=False, selector=None): - if self.source.branch.repository.has_same_location(self.target.branch.repository): + if self.source.branch.repository.has_same_location( + self.target.branch.repository + ): return {}, [] conflicts = [] @@ -182,38 +194,34 @@ def merge(self, overwrite=False, ignore_master=False, selector=None): pass elif overwrite or ref_name not in target_repo._git.refs: try: - updates[tag_name] = ( - target_repo.lookup_foreign_revision_id(peeled)) + updates[tag_name] = target_repo.lookup_foreign_revision_id(peeled) except KeyError: - trace.warning('%s does not point to a valid object', - tag_name) + trace.warning("%s does not point to a valid object", tag_name) continue except NotCommitError: - trace.warning('%s points to a non-commit object', - tag_name) + trace.warning("%s points to a non-commit object", tag_name) continue target_repo._git.refs[ref_name] = unpeeled or peeled self.target.branch._tag_refs = None else: try: - source_revid = self.source.branch.repository.lookup_foreign_revision_id( - peeled) + source_revid = ( + self.source.branch.repository.lookup_foreign_revision_id(peeled) + ) target_revid = target_repo.lookup_foreign_revision_id( - target_repo._git.refs[ref_name]) + target_repo._git.refs[ref_name] + ) except KeyError: - trace.warning('%s does not point to a valid object', - ref_name) 
+ trace.warning("%s does not point to a valid object", ref_name) continue except NotCommitError: - trace.warning('%s points to a non-commit object', - tag_name) + trace.warning("%s points to a non-commit object", tag_name) continue conflicts.append((tag_name, source_revid, target_revid)) return updates, set(conflicts) class InterTagsFromGitToNonGit(InterTags): - @classmethod def is_compatible(klass, source: Tags, target: Tags): if not isinstance(source, GitTags): @@ -233,19 +241,28 @@ def merge(self, overwrite=False, ignore_master=False, selector=None): if master is not None: es.enter_context(master.lock_write()) updates, conflicts = self._merge_to( - self.target, source_tag_refs, overwrite=overwrite, - selector=selector) + self.target, source_tag_refs, overwrite=overwrite, selector=selector + ) if master is not None: extra_updates, extra_conflicts = self._merge_to( - master.tags, overwrite=overwrite, + master.tags, + overwrite=overwrite, source_tag_refs=source_tag_refs, - ignore_master=ignore_master, selector=selector) + ignore_master=ignore_master, + selector=selector, + ) updates.update(extra_updates) conflicts.update(extra_conflicts) return updates, conflicts - def _merge_to(self, to_tags, source_tag_refs, overwrite=False, - selector=None, ignore_master=False): + def _merge_to( + self, + to_tags, + source_tag_refs, + overwrite=False, + selector=None, + ignore_master=False, + ): unpeeled_map = defaultdict(set) conflicts = [] updates = {} @@ -288,8 +305,7 @@ def __init__(self, branch): def get_tag_dict(self): ret = {} - for (_ref_name, tag_name, peeled, _unpeeled) in ( - self.branch.get_tag_refs()): + for _ref_name, tag_name, peeled, _unpeeled in self.branch.get_tag_refs(): try: bzr_revid = self.branch.lookup_foreign_revision_id(peeled) except NotCommitError: @@ -346,7 +362,6 @@ def delete_tag(self, name): class GitBranchFormat(branch.BranchFormat): - def network_name(self): return b"git" @@ -364,6 +379,7 @@ def tags_are_versioned(self): def get_foreign_tests_branch_factory(self): from .tests.test_branch import ForeignTestsBranchFactory + return ForeignTestsBranchFactory() def make_tags(self, branch): @@ -373,12 +389,14 @@ def make_tags(self, branch): pass if getattr(branch.repository, "_git", None) is None: from .remote import RemoteGitTagDict + return RemoteGitTagDict(branch) else: return LocalGitTagDict(branch) - def initialize(self, a_controldir, name=None, repository=None, - append_revisions_only=None): + def initialize( + self, a_controldir, name=None, repository=None, append_revisions_only=None + ): raise NotImplementedError(self.initialize) def get_reference(self, controldir, name=None): @@ -395,23 +413,27 @@ def stores_revno(self): class LocalGitBranchFormat(GitBranchFormat): - def get_format_description(self): - return 'Local Git Branch' + return "Local Git Branch" @property def _matchingcontroldir(self): from .dir import LocalGitControlDirFormat + return LocalGitControlDirFormat() - def initialize(self, a_controldir, name=None, repository=None, - append_revisions_only=None): + def initialize( + self, a_controldir, name=None, repository=None, append_revisions_only=None + ): from .dir import LocalGitDir + if not isinstance(a_controldir, LocalGitDir): raise errors.IncompatibleFormat(self, a_controldir._format) return a_controldir.create_branch( - repository=repository, name=name, - append_revisions_only=append_revisions_only) + repository=repository, + name=name, + append_revisions_only=append_revisions_only, + ) class GitBranch(ForeignBranch): @@ -436,8 +458,8 @@ def 
__init__(self, controldir, repository, ref: bytes, format): raise TypeError(f"ref is invalid: {ref!r}") self.ref = ref self._head = None - self._user_transport = controldir.user_transport.clone('.') - self._control_transport = controldir.control_transport.clone('.') + self._user_transport = controldir.user_transport.clone(".") + self._control_transport = controldir.control_transport.clone(".") self._tag_refs = None params: Dict[str, str] = {} try: @@ -445,10 +467,10 @@ def __init__(self, controldir, repository, ref: bytes, format): except ValueError: self.name = None if self.ref is not None: - params = {"ref": urlutils.escape(self.ref, safe='')} + params = {"ref": urlutils.escape(self.ref, safe="")} else: if self.name: - params = {"branch": urlutils.escape(self.name, safe='')} + params = {"branch": urlutils.escape(self.name, safe="")} for k, v in params.items(): self._user_transport.set_segment_parameter(k, v) self._control_transport.set_segment_parameter(k, v) @@ -475,10 +497,12 @@ def get_child_submit_format(self): def get_config(self): from .config import GitBranchConfig + return GitBranchConfig(self) def get_config_stack(self): from .config import GitBranchStack + return GitBranchStack(self) def _get_nick(self, local=False, possible_master_transports=None): @@ -486,23 +510,22 @@ def _get_nick(self, local=False, possible_master_transports=None): :return: Branch nick """ - if getattr(self.repository, '_git', None): + if getattr(self.repository, "_git", None): cs = self.repository._git.get_config_stack() try: - return cs.get( - (b"branch", self.name.encode('utf-8')), - b"nick").decode("utf-8") + return cs.get((b"branch", self.name.encode("utf-8")), b"nick").decode( + "utf-8" + ) except KeyError: pass return self.name or "HEAD" def _set_nick(self, nick): cf = self.repository._git.get_config() - cf.set((b"branch", self.name.encode('utf-8')), - b"nick", nick.encode("utf-8")) + cf.set((b"branch", self.name.encode("utf-8")), b"nick", nick.encode("utf-8")) f = BytesIO() cf.write_to_file(f) - self.repository._git._put_named_file('config', f.getvalue()) + self.repository._git._put_named_file("config", f.getvalue()) nick = property(_get_nick, _set_nick) @@ -512,8 +535,7 @@ def __repr__(self): def set_last_revision(self, revid): raise NotImplementedError(self.set_last_revision) - def generate_revision_history(self, revid, last_rev=None, - other_branch=None): + def generate_revision_history(self, revid, last_rev=None, other_branch=None): if last_rev is not None: graph = self.repository.get_graph() if not graph.is_ancestor(last_rev, revid): @@ -526,12 +548,12 @@ def lock_write(self, token=None): if token is not None: raise errors.TokenLockingNotSupported(self) if self._lock_mode: - if self._lock_mode == 'r': + if self._lock_mode == "r": raise errors.ReadOnlyError(self) self._lock_count += 1 else: self._lock_ref() - self._lock_mode = 'w' + self._lock_mode = "w" self._lock_count = 1 self.repository.lock_write() return lock.LogicalLockResult(self.unlock) @@ -552,21 +574,21 @@ def _get_push_origin(self, cs): The exact behaviour is documented in the git-config(1) manpage. 
""" try: - return cs.get((b'branch', self.name.encode('utf-8')), b'pushRemote') + return cs.get((b"branch", self.name.encode("utf-8")), b"pushRemote") except KeyError: try: - return cs.get((b'branch', ), b'remote') + return cs.get((b"branch",), b"remote") except KeyError: try: - return cs.get((b'branch', self.name.encode('utf-8')), b'remote') + return cs.get((b"branch", self.name.encode("utf-8")), b"remote") except KeyError: - return b'origin' + return b"origin" def _get_origin(self, cs): try: - return cs.get((b'branch', self.name.encode('utf-8')), b'remote') + return cs.get((b"branch", self.name.encode("utf-8")), b"remote") except KeyError: - return b'origin' + return b"origin" def _get_related_push_branch(self, cs): remote = self._get_push_origin(cs) @@ -575,7 +597,7 @@ def _get_related_push_branch(self, cs): except KeyError: return None - return git_url_to_bzr_url(location.decode('utf-8'), ref=self.ref) + return git_url_to_bzr_url(location.decode("utf-8"), ref=self.ref) def _get_related_merge_branch(self, cs): remote = self._get_origin(cs) @@ -587,9 +609,9 @@ def _get_related_merge_branch(self, cs): try: ref = cs.get((b"branch", remote), b"merge") except KeyError: - ref = b'HEAD' + ref = b"HEAD" - return git_url_to_bzr_url(location.decode('utf-8'), ref=ref) + return git_url_to_bzr_url(location.decode("utf-8"), ref=ref) def _get_parent_location(self): """See Branch.get_parent().""" @@ -603,16 +625,21 @@ def set_parent(self, location): target_url, branch, ref = bzr_url_to_git_url(location) location = urlutils.relative_url(this_url, target_url) cs.set((b"remote", remote), b"url", location) - cs.set((b"remote", remote), b'fetch', - b'+refs/heads/*:refs/remotes/%s/*' % remote) + cs.set( + (b"remote", remote), b"fetch", b"+refs/heads/*:refs/remotes/%s/*" % remote + ) if self.name: if branch: - cs.set((b"branch", self.name.encode()), b"merge", branch_name_to_ref(branch)) + cs.set( + (b"branch", self.name.encode()), + b"merge", + branch_name_to_ref(branch), + ) elif ref: cs.set((b"branch", self.name.encode()), b"merge", ref) else: # TODO(jelmer): Maybe unset rather than setting to HEAD? 
- cs.set((b"branch", self.name.encode()), b"merge", b'HEAD') + cs.set((b"branch", self.name.encode()), b"merge", b"HEAD") self.repository._write_git_config(cs) def break_lock(self): @@ -620,11 +647,11 @@ def break_lock(self): def lock_read(self): if self._lock_mode: - if self._lock_mode not in ('r', 'w'): + if self._lock_mode not in ("r", "w"): raise ValueError(self._lock_mode) self._lock_count += 1 else: - self._lock_mode = 'r' + self._lock_mode = "r" self._lock_count = 1 self.repository.lock_read() return lock.LogicalLockResult(self.unlock) @@ -633,7 +660,7 @@ def peek_lock_mode(self): return self._lock_mode def is_locked(self): - return (self._lock_mode is not None) + return self._lock_mode is not None def _lock_ref(self): pass @@ -648,7 +675,7 @@ def unlock(self): try: self._lock_count -= 1 if self._lock_count == 0: - if self._lock_mode == 'w': + if self._lock_mode == "w": self._unlock_ref() self._lock_mode = None self._clear_cached_state() @@ -665,22 +692,24 @@ def last_revision(self): return revision.NULL_REVISION return self.lookup_foreign_revision_id(self.head) - def _basic_push(self, target, overwrite=False, stop_revision=None, - tag_selector=None): + def _basic_push( + self, target, overwrite=False, stop_revision=None, tag_selector=None + ): return branch.InterBranch.get(self, target)._basic_push( - overwrite, stop_revision, tag_selector=tag_selector) + overwrite, stop_revision, tag_selector=tag_selector + ) def lookup_foreign_revision_id(self, foreign_revid): try: - return self.repository.lookup_foreign_revision_id(foreign_revid, - self.mapping) + return self.repository.lookup_foreign_revision_id( + foreign_revid, self.mapping + ) except KeyError: # Let's try.. return self.mapping.revision_id_foreign_to_bzr(foreign_revid) def lookup_bzr_revision_id(self, revid): - return self.repository.lookup_bzr_revision_id( - revid, mapping=self.mapping) + return self.repository.lookup_bzr_revision_id(revid, mapping=self.mapping) def get_unshelver(self, tree): raise errors.StoringUncommittedNotSupported(self) @@ -703,8 +732,7 @@ def get_tag_refs(self): self._tag_refs = list(self._iter_tag_refs()) return self._tag_refs - def import_last_revision_info_and_tags(self, source, revno, revid, - lossy=False): + def import_last_revision_info_and_tags(self, source, revno, revid, lossy=False): """Set the last revision info, importing from another repo if necessary. 
This is used by the bound branch code to upload a revision to @@ -720,7 +748,8 @@ def import_last_revision_info_and_tags(self, source, revno, revid, (should only be different from the arguments when lossy=True) """ push_result = source.push( - self, stop_revision=revid, lossy=lossy, _stop_revno=revno) + self, stop_revision=revid, lossy=lossy, _stop_revno=revno + ) return (push_result.new_revno, push_result.new_revid) def reconcile(self, thorough=True): @@ -735,11 +764,16 @@ class LocalGitBranch(GitBranch): """A local Git branch.""" def __init__(self, controldir, repository, ref): - super().__init__(controldir, repository, ref, - LocalGitBranchFormat()) - - def create_checkout(self, to_location, revision_id=None, lightweight=False, - accelerator_tree=None, hardlink=False): + super().__init__(controldir, repository, ref, LocalGitBranchFormat()) + + def create_checkout( + self, + to_location, + revision_id=None, + lightweight=False, + accelerator_tree=None, + hardlink=False, + ): t = transport.get_transport(to_location) t.ensure_base() format = self._get_checkout_format(lightweight=lightweight) @@ -754,7 +788,8 @@ def create_checkout(self, to_location, revision_id=None, lightweight=False, checkout_branch.pull(self, stop_revision=revision_id) from_branch = None return checkout.create_workingtree( - revision_id, from_branch=from_branch, hardlink=hardlink) + revision_id, from_branch=from_branch, hardlink=hardlink + ) def _lock_ref(self): self._ref_lock = self.repository._git.refs.lock_ref(self.ref) @@ -771,8 +806,9 @@ def _gen_revision_history(self): last_revid = self.last_revision() graph = self.repository.get_graph() try: - ret = list(graph.iter_lefthand_ancestry( - last_revid, (revision.NULL_REVISION, ))) + ret = list( + graph.iter_lefthand_ancestry(last_revid, (revision.NULL_REVISION,)) + ) except errors.RevisionNotPresent as e: raise errors.GhostRevisionsHaveNoRevno(last_revid, e.revision_id) from e ret.reverse() @@ -789,7 +825,8 @@ def _read_last_revision_info(self): graph = self.repository.get_graph() try: revno = graph.find_distance_to_null( - last_revid, [(revision.NULL_REVISION, 0)]) + last_revid, [(revision.NULL_REVISION, 0)] + ) except errors.GhostRevisionsHaveNoRevno: revno = None return revno, last_revid @@ -804,8 +841,7 @@ def set_last_revision(self, revid): if revid == NULL_REVISION: newhead = None else: - (newhead, self.mapping) = self.repository.lookup_bzr_revision_id( - revid) + (newhead, self.mapping) = self.repository.lookup_bzr_revision_id(revid) if self.mapping is None: raise AssertionError self._set_head(newhead) @@ -824,7 +860,7 @@ def _set_head(self, value): def get_push_location(self): """See Branch.get_push_location.""" - push_loc = self.get_config_stack().get('push_location') + push_loc = self.get_config_stack().get("push_location") if push_loc is not None: return push_loc cs = self.repository._git.get_config_stack() @@ -832,8 +868,9 @@ def get_push_location(self): def set_push_location(self, location): """See Branch.set_push_location.""" - self.get_config().set_user_option('push_location', location, - store=config.STORE_LOCATION) + self.get_config().set_user_option( + "push_location", location, store=config.STORE_LOCATION + ) def supports_tags(self): return True @@ -862,8 +899,8 @@ def _iter_tag_refs(self): def create_memorytree(self): from .memorytree import GitMemoryTree - return GitMemoryTree(self, self.repository._git.object_store, - self.head) + + return GitMemoryTree(self, self.repository._git.object_store, self.head) def _quick_lookup_revno(local_branch, 
remote_branch, revid): @@ -878,8 +915,7 @@ def _quick_lookup_revno(local_branch, remote_branch, revid): except errors.NoSuchRevision: graph = local_branch.repository.get_graph() try: - return graph.find_distance_to_null( - revid, [(revision.NULL_REVISION, 0)]) + return graph.find_distance_to_null(revid, [(revision.NULL_REVISION, 0)]) except errors.GhostRevisionsHaveNoRevno: if not _calculate_revnos(remote_branch): return None @@ -889,7 +925,6 @@ def _quick_lookup_revno(local_branch, remote_branch, revid): class GitBranchPullResult(branch.PullResult): - def __init__(self): super().__init__() self.new_git_head = None @@ -899,17 +934,18 @@ def __init__(self): def report(self, to_file): if not is_quiet(): if self.old_revid == self.new_revid: - to_file.write('No revisions to pull.\n') + to_file.write("No revisions to pull.\n") elif self.new_git_head is not None: - to_file.write('Now on revision %d (git sha: %s).\n' % - (self.new_revno, self.new_git_head)) + to_file.write( + "Now on revision %d (git sha: %s).\n" + % (self.new_revno, self.new_git_head) + ) else: - to_file.write('Now on revision %d.\n' % (self.new_revno,)) + to_file.write("Now on revision %d.\n" % (self.new_revno,)) self._show_tag_conficts(to_file) def _lookup_revno(self, revid): - return _quick_lookup_revno(self.target_branch, self.source_branch, - revid) + return _quick_lookup_revno(self.target_branch, self.source_branch, revid) def _get_old_revno(self): if self._old_revno is not None: @@ -933,10 +969,8 @@ def _set_new_revno(self, revno): class GitBranchPushResult(branch.BranchPushResult): - def _lookup_revno(self, revid): - return _quick_lookup_revno(self.source_branch, self.target_branch, - revid) + return _quick_lookup_revno(self.source_branch, self.target_branch, revid) @property def old_revno(self): @@ -962,14 +996,15 @@ def _get_branch_formats_to_test(): except AttributeError: default_format = branch.BranchFormat._default_format from .remote import RemoteGitBranchFormat + return [ (RemoteGitBranchFormat(), default_format), - (LocalGitBranchFormat(), default_format)] + (LocalGitBranchFormat(), default_format), + ] @classmethod def _get_interrepo(self, source, target): - return _mod_repository.InterRepository.get( - source.repository, target.repository) + return _mod_repository.InterRepository.get(source.repository, target.repository) @classmethod def is_compatible(cls, source, target): @@ -978,22 +1013,24 @@ def is_compatible(cls, source, target): if isinstance(target, GitBranch): # InterLocalGitRemoteGitBranch or InterToGitBranch should be used return False - if (getattr(cls._get_interrepo(source, target), "fetch_objects", None) - is None): + if getattr(cls._get_interrepo(source, target), "fetch_objects", None) is None: # fetch_objects is necessary for this to work return False return True def fetch(self, stop_revision=None, fetch_tags=None, limit=None, lossy=False): self.fetch_objects( - stop_revision, fetch_tags=fetch_tags, limit=limit, lossy=lossy) + stop_revision, fetch_tags=fetch_tags, limit=limit, lossy=lossy + ) return _mod_repository.FetchResult() - def fetch_objects(self, stop_revision, fetch_tags, limit=None, lossy=False, tag_selector=None): + def fetch_objects( + self, stop_revision, fetch_tags, limit=None, lossy=False, tag_selector=None + ): interrepo = self._get_interrepo(self.source, self.target) if fetch_tags is None: c = self.source.get_config_stack() - fetch_tags = c.get('branch.fetch_tags') + fetch_tags = c.get("branch.fetch_tags") def determine_wants(heads): if stop_revision is None: @@ -1002,26 +1039,26 
@@ def determine_wants(heads): except KeyError: self._last_revid = revision.NULL_REVISION else: - self._last_revid = self.source.lookup_foreign_revision_id( - head) + self._last_revid = self.source.lookup_foreign_revision_id(head) else: self._last_revid = stop_revision real = interrepo.get_determine_wants_revids( - [self._last_revid], include_tags=fetch_tags, tag_selector=tag_selector) + [self._last_revid], include_tags=fetch_tags, tag_selector=tag_selector + ) return real(heads) + pack_hint, head, refs = interrepo.fetch_objects( - determine_wants, self.source.mapping, limit=limit, - lossy=lossy) - if (pack_hint is not None and - self.target.repository._format.pack_compresses): + determine_wants, self.source.mapping, limit=limit, lossy=lossy + ) + if pack_hint is not None and self.target.repository._format.pack_compresses: self.target.repository.pack(hint=pack_hint) return head, refs def _update_revisions(self, stop_revision=None, overwrite=False, tag_selector=None): - head, refs = self.fetch_objects(stop_revision, fetch_tags=None, tag_selector=tag_selector) - _update_tip( - self.source, self.target, - self._last_revid, overwrite) + head, refs = self.fetch_objects( + stop_revision, fetch_tags=None, tag_selector=tag_selector + ) + _update_tip(self.source, self.target, self._last_revid, overwrite) return head, refs def update_references(self, revid=None): @@ -1029,17 +1066,25 @@ def update_references(self, revid=None): revid = self.target.last_revision() tree = self.target.repository.revision_tree(revid) try: - with tree.get_file('.gitmodules') as f: - for path, url, _section in parse_submodules( - GitConfigFile.from_file(f)): + with tree.get_file(".gitmodules") as f: + for path, url, _section in parse_submodules(GitConfigFile.from_file(f)): self.target.set_reference_info( - tree.path2id(decode_git_path(path)), url.decode('utf-8'), - decode_git_path(path)) + tree.path2id(decode_git_path(path)), + url.decode("utf-8"), + decode_git_path(path), + ) except transport.NoSuchFile: pass - def _basic_pull(self, stop_revision, overwrite, run_hooks, - _override_hook_target, _hook_master, tag_selector=None): + def _basic_pull( + self, + stop_revision, + overwrite, + run_hooks, + _override_hook_target, + _hook_master, + tag_selector=None, + ): if overwrite is True: overwrite = {"history", "tags"} elif not overwrite: @@ -1053,19 +1098,20 @@ def _basic_pull(self, stop_revision, overwrite, run_hooks, with self.target.lock_write(), self.source.lock_read(): # We assume that during 'pull' the target repository is closer than # the source one. 
- (result.old_revno, result.old_revid) = \ - self.target.last_revision_info() + (result.old_revno, result.old_revid) = self.target.last_revision_info() result.new_git_head, remote_refs = self._update_revisions( - stop_revision, overwrite=("history" in overwrite), - tag_selector=tag_selector) + stop_revision, + overwrite=("history" in overwrite), + tag_selector=tag_selector, + ) tags_ret = self.source.tags.merge_to( - self.target.tags, ("tags" in overwrite), ignore_master=True) + self.target.tags, ("tags" in overwrite), ignore_master=True + ) if isinstance(tags_ret, tuple): result.tag_updates, result.tag_conflicts = tags_ret else: result.tag_conflicts = tags_ret - (result.new_revno, result.new_revid) = \ - self.target.last_revision_info() + (result.new_revno, result.new_revid) = self.target.last_revision_info() self.update_references(revid=result.new_revid) if _hook_master: result.master_branch = _hook_master @@ -1074,13 +1120,21 @@ def _basic_pull(self, stop_revision, overwrite, run_hooks, result.master_branch = result.target_branch result.local_branch = None if run_hooks: - for hook in branch.Branch.hooks['post_pull']: + for hook in branch.Branch.hooks["post_pull"]: hook(result) return result - def pull(self, overwrite=False, stop_revision=None, - possible_transports=None, _hook_master=None, run_hooks=True, - _override_hook_target=None, local=False, tag_selector=None): + def pull( + self, + overwrite=False, + stop_revision=None, + possible_transports=None, + _hook_master=None, + run_hooks=True, + _override_hook_target=None, + local=False, + tag_selector=None, + ): """See Branch.pull. :param _hook_master: Private parameter - set the branch to @@ -1104,7 +1158,7 @@ def pull(self, overwrite=False, stop_revision=None, normalized = urlutils.normalize_url(bound_location) try: relpath = self.source.user_transport.relpath(normalized) - source_is_master = (relpath == '') + source_is_master = relpath == "" except (errors.PathNotChild, urlutils.InvalidURL): source_is_master = False if not local and bound_location and not source_is_master: @@ -1113,14 +1167,21 @@ def pull(self, overwrite=False, stop_revision=None, es.enter_context(master_branch.lock_write()) # pull from source into master. 
master_branch.pull( - self.source, overwrite=overwrite, - stop_revision=stop_revision, run_hooks=False) + self.source, + overwrite=overwrite, + stop_revision=stop_revision, + run_hooks=False, + ) else: master_branch = None - return self._basic_pull(stop_revision, overwrite, run_hooks, - _override_hook_target, - _hook_master=master_branch, - tag_selector=tag_selector) + return self._basic_pull( + stop_revision, + overwrite, + run_hooks, + _override_hook_target, + _hook_master=master_branch, + tag_selector=tag_selector, + ) def _basic_push(self, overwrite, stop_revision, tag_selector=None): if overwrite is True: @@ -1132,11 +1193,14 @@ def _basic_push(self, overwrite, stop_revision, tag_selector=None): result.target_branch = self.target result.old_revno, result.old_revid = self.target.last_revision_info() result.new_git_head, remote_refs = self._update_revisions( - stop_revision, overwrite=("history" in overwrite), - tag_selector=tag_selector) + stop_revision, overwrite=("history" in overwrite), tag_selector=tag_selector + ) tags_ret = self.source.tags.merge_to( - self.target.tags, "tags" in overwrite, ignore_master=True, - selector=tag_selector) + self.target.tags, + "tags" in overwrite, + ignore_master=True, + selector=tag_selector, + ) (result.tag_updates, result.tag_conflicts) = tags_ret result.new_revno, result.new_revid = self.target.last_revision_info() self.update_references(revid=result.new_revid) @@ -1156,17 +1220,20 @@ class InterLocalGitRemoteGitBranch(InterGitBranch): @staticmethod def _get_branch_formats_to_test(): from .remote import RemoteGitBranchFormat - return [ - (LocalGitBranchFormat(), RemoteGitBranchFormat())] + + return [(LocalGitBranchFormat(), RemoteGitBranchFormat())] @classmethod def is_compatible(self, source, target): from .remote import RemoteGitBranch - return (isinstance(source, LocalGitBranch) and - isinstance(target, RemoteGitBranch)) + + return isinstance(source, LocalGitBranch) and isinstance( + target, RemoteGitBranch + ) def _basic_push(self, overwrite, stop_revision, tag_selector=None): from .remote import parse_git_error + result = GitBranchPushResult() result.source_branch = self.source result.target_branch = self.target @@ -1178,36 +1245,36 @@ def get_changed_refs(old_refs): if old_ref is None: result.old_revid = revision.NULL_REVISION else: - result.old_revid = self.target.lookup_foreign_revision_id( - old_ref) - new_ref = self.source.repository.lookup_bzr_revision_id( - stop_revision)[0] + result.old_revid = self.target.lookup_foreign_revision_id(old_ref) + new_ref = self.source.repository.lookup_bzr_revision_id(stop_revision)[0] if not overwrite: if remote_divergence( - old_ref, new_ref, - self.source.repository._git.object_store): + old_ref, new_ref, self.source.repository._git.object_store + ): raise errors.DivergedBranches(self.source, self.target) refs = {self.target.ref: new_ref} result.new_revid = stop_revision - for name, sha in ( - self.source.repository._git.refs.as_dict(b"refs/tags").items()): - if tag_selector and not tag_selector(name.decode('utf-8')): + for name, sha in self.source.repository._git.refs.as_dict( + b"refs/tags" + ).items(): + if tag_selector and not tag_selector(name.decode("utf-8")): continue if sha not in self.source.repository._git: - trace.mutter('Ignoring missing SHA: %s', sha) + trace.mutter("Ignoring missing SHA: %s", sha) continue - refs[tag_name_to_ref(name.decode('utf-8'))] = sha + refs[tag_name_to_ref(name.decode("utf-8"))] = sha return refs + dw_result = self.target.repository.send_pack( - get_changed_refs, 
- self.source.repository._git.generate_pack_data) + get_changed_refs, self.source.repository._git.generate_pack_data + ) if dw_result is not None and not isinstance(dw_result, dict): error = dw_result.ref_status.get(self.target.ref) if error: raise parse_git_error(self.target.user_url, error) for ref, error in dw_result.ref_status.items(): if error: - trace.warning('unable to open ref %s: %s', ref, error) + trace.warning("unable to open ref %s: %s", ref, error) return result @@ -1217,28 +1284,32 @@ class InterGitLocalGitBranch(InterGitBranch): @staticmethod def _get_branch_formats_to_test(): from .remote import RemoteGitBranchFormat + return [ (RemoteGitBranchFormat(), LocalGitBranchFormat()), - (LocalGitBranchFormat(), LocalGitBranchFormat())] + (LocalGitBranchFormat(), LocalGitBranchFormat()), + ] @classmethod def is_compatible(self, source, target): - return (isinstance(source, GitBranch) and - isinstance(target, LocalGitBranch)) + return isinstance(source, GitBranch) and isinstance(target, LocalGitBranch) def fetch(self, stop_revision=None, fetch_tags=None, limit=None, lossy=False): if lossy: raise errors.LossyPushToSameVCS( - source_branch=self.source, target_branch=self.target) + source_branch=self.source, target_branch=self.target + ) interrepo = _mod_repository.InterRepository.get( - self.source.repository, self.target.repository) + self.source.repository, self.target.repository + ) if stop_revision is None: stop_revision = self.source.last_revision() if fetch_tags is None: c = self.source.get_config_stack() - fetch_tags = c.get('branch.fetch_tags') + fetch_tags = c.get("branch.fetch_tags") determine_wants = interrepo.get_determine_wants_revids( - [stop_revision], include_tags=fetch_tags) + [stop_revision], include_tags=fetch_tags + ) interrepo.fetch_objects(determine_wants, limit=limit) return _mod_repository.FetchResult() @@ -1252,14 +1323,10 @@ def _basic_push(self, overwrite=False, stop_revision=None, tag_selector=None): result.target_branch = self.target result.old_revid = self.target.last_revision() refs, stop_revision = self.update_refs(stop_revision) - _update_tip( - self.source, self.target, - stop_revision, - "history" in overwrite) + _update_tip(self.source, self.target, stop_revision, "history" in overwrite) tags_ret = self.source.tags.merge_to( - self.target.tags, - overwrite=("tags" in overwrite), - selector=tag_selector) + self.target.tags, overwrite=("tags" in overwrite), selector=tag_selector + ) if isinstance(tags_ret, tuple): (result.tag_updates, result.tag_conflicts) = tags_ret else: @@ -1269,12 +1336,15 @@ def _basic_push(self, overwrite=False, stop_revision=None, tag_selector=None): def update_refs(self, stop_revision=None): interrepo = _mod_repository.InterRepository.get( - self.source.repository, self.target.repository) + self.source.repository, self.target.repository + ) c = self.source.get_config_stack() - fetch_tags = c.get('branch.fetch_tags') + fetch_tags = c.get("branch.fetch_tags") if stop_revision is None: - result = interrepo.fetch(branches=[self.source.ref], include_tags=fetch_tags) + result = interrepo.fetch( + branches=[self.source.ref], include_tags=fetch_tags + ) try: head = result.refs[self.source.ref] except KeyError: @@ -1282,13 +1352,18 @@ def update_refs(self, stop_revision=None): else: stop_revision = self.target.lookup_foreign_revision_id(head) else: - result = interrepo.fetch( - revision_id=stop_revision, include_tags=fetch_tags) + result = interrepo.fetch(revision_id=stop_revision, include_tags=fetch_tags) return result.refs, 
stop_revision - def pull(self, stop_revision=None, overwrite=False, - possible_transports=None, run_hooks=True, local=False, - tag_selector=None): + def pull( + self, + stop_revision=None, + overwrite=False, + possible_transports=None, + run_hooks=True, + local=False, + tag_selector=None, + ): # This type of branch can't be bound. if local: raise errors.LocalRequiresBoundBranch() @@ -1303,13 +1378,10 @@ def pull(self, stop_revision=None, overwrite=False, with self.target.lock_write(), self.source.lock_read(): result.old_revid = self.target.last_revision() refs, stop_revision = self.update_refs(stop_revision) - _update_tip( - self.source, self.target, - stop_revision, - "history" in overwrite) + _update_tip(self.source, self.target, stop_revision, "history" in overwrite) tags_ret = self.source.tags.merge_to( - self.target.tags, overwrite=("tags" in overwrite), - selector=tag_selector) + self.target.tags, overwrite=("tags" in overwrite), selector=tag_selector + ) if isinstance(tags_ret, tuple): (result.tag_updates, result.tag_conflicts) = tags_ret else: @@ -1318,14 +1390,13 @@ def pull(self, stop_revision=None, overwrite=False, result.local_branch = None result.master_branch = result.target_branch if run_hooks: - for hook in branch.Branch.hooks['post_pull']: + for hook in branch.Branch.hooks["post_pull"]: hook(result) return result def _update_pure_git_refs(result, new_refs, overwrite, tag_selector, old_refs): - mutter("updating refs. old refs: %r, new refs: %r", - old_refs, new_refs) + mutter("updating refs. old refs: %r, new refs: %r", old_refs, new_refs) result.tag_updates = {} result.tag_conflicts = [] ret = {} @@ -1344,6 +1415,7 @@ def ref_equals(refs, name, git_sha, revid): # has the bzr revid, then this will cause us to show a tag as updated # that hasn't actually been updated. 
return False + # FIXME: Check for diverged branches for ref, (git_sha, revid) in new_refs.items(): if ref_equals(ret, ref, git_sha, revid): @@ -1372,21 +1444,20 @@ def ref_equals(refs, name, git_sha, revid): except ValueError: pass else: - result.tag_conflicts.append( - (name, revid, ret[name][1])) + result.tag_conflicts.append((name, revid, ret[name][1])) else: ret[ref] = (git_sha, revid) return ret - class InterToGitBranch(branch.GenericInterBranch): """InterBranch implementation that pulls into a Git branch.""" def __init__(self, source, target): super().__init__(source, target) - self.interrepo = _mod_repository.InterRepository.get(source.repository, - target.repository) + self.interrepo = _mod_repository.InterRepository.get( + source.repository, target.repository + ) @staticmethod def _get_branch_formats_to_test(): @@ -1395,17 +1466,17 @@ def _get_branch_formats_to_test(): except AttributeError: default_format = branch.BranchFormat._default_format from .remote import RemoteGitBranchFormat + return [ (default_format, LocalGitBranchFormat()), - (default_format, RemoteGitBranchFormat())] + (default_format, RemoteGitBranchFormat()), + ] @classmethod def is_compatible(self, source, target): - return (not isinstance(source, GitBranch) and - isinstance(target, GitBranch)) + return not isinstance(source, GitBranch) and isinstance(target, GitBranch) - def _get_new_refs(self, stop_revision=None, fetch_tags=None, - stop_revno=None): + def _get_new_refs(self, stop_revision=None, fetch_tags=None, stop_revno=None): if not self.source.is_locked(): raise errors.ObjectNotLocked(self.source) if stop_revision is None: @@ -1421,21 +1492,19 @@ def _get_new_refs(self, stop_revision=None, fetch_tags=None, refs = {main_ref: (None, stop_revision)} if fetch_tags is None: c = self.source.get_config_stack() - fetch_tags = c.get('branch.fetch_tags') + fetch_tags = c.get("branch.fetch_tags") for name, revid in self.source.tags.get_tag_dict().items(): if self.source.repository.has_revision(revid): ref = tag_name_to_ref(name) if not check_ref_format(ref): - warning("skipping tag with invalid characters %s (%s)", - name, ref) + warning("skipping tag with invalid characters %s (%s)", name, ref) continue if fetch_tags: # FIXME: Skip tags that are not in the ancestry refs[ref] = (None, revid) return refs, main_ref, (stop_revno, stop_revision) - def fetch(self, stop_revision=None, fetch_tags=None, lossy=False, - limit=None): + def fetch(self, stop_revision=None, fetch_tags=None, lossy=False, limit=None): if stop_revision is None: stop_revision = self.source.last_revision() ret = [] @@ -1443,15 +1512,19 @@ def fetch(self, stop_revision=None, fetch_tags=None, lossy=False, for _k, v in self.source.tags.get_tag_dict().items(): ret.append((None, v)) ret.append((None, stop_revision)) - if getattr(self.interrepo, 'fetch_revs', None): + if getattr(self.interrepo, "fetch_revs", None): try: revidmap = self.interrepo.fetch_revs(ret, lossy=lossy, limit=limit) except NoPushSupport as err: raise errors.NoRoundtrippingSupport(self.source, self.target) from err - return _mod_repository.FetchResult(revidmap={ - old_revid: new_revid - for (old_revid, (new_sha, new_revid)) in revidmap.items()}) + return _mod_repository.FetchResult( + revidmap={ + old_revid: new_revid + for (old_revid, (new_sha, new_revid)) in revidmap.items() + } + ) else: + def determine_wants(refs): wants = [] for git_sha, revid in ret: @@ -1460,42 +1533,58 @@ def determine_wants(refs): wants.append(git_sha) return wants - self.interrepo.fetch_objects( - determine_wants, 
lossy=lossy, limit=limit) + self.interrepo.fetch_objects(determine_wants, lossy=lossy, limit=limit) return _mod_repository.FetchResult() - def pull(self, overwrite=False, stop_revision=None, local=False, - possible_transports=None, run_hooks=True, _stop_revno=None, - tag_selector=None): + def pull( + self, + overwrite=False, + stop_revision=None, + local=False, + possible_transports=None, + run_hooks=True, + _stop_revno=None, + tag_selector=None, + ): result = GitBranchPullResult() result.source_branch = self.source result.target_branch = self.target with self.source.lock_read(), self.target.lock_write(): new_refs, main_ref, stop_revinfo = self._get_new_refs( - stop_revision, stop_revno=_stop_revno) + stop_revision, stop_revno=_stop_revno + ) - update_refs = partial(_update_pure_git_refs, result, new_refs, overwrite, tag_selector) + update_refs = partial( + _update_pure_git_refs, result, new_refs, overwrite, tag_selector + ) try: - result.revidmap, old_refs, new_refs = ( - self.interrepo.fetch_refs(update_refs, lossy=False)) + result.revidmap, old_refs, new_refs = self.interrepo.fetch_refs( + update_refs, lossy=False + ) except NoPushSupport as err: raise errors.NoRoundtrippingSupport(self.source, self.target) from err (old_sha1, result.old_revid) = old_refs.get( - main_ref, (ZERO_SHA, NULL_REVISION)) + main_ref, (ZERO_SHA, NULL_REVISION) + ) if result.old_revid is None: - result.old_revid = self.target.lookup_foreign_revision_id( - old_sha1) + result.old_revid = self.target.lookup_foreign_revision_id(old_sha1) result.new_revid = new_refs[main_ref][1] result.local_branch = None result.master_branch = self.target if run_hooks: - for hook in branch.Branch.hooks['post_pull']: + for hook in branch.Branch.hooks["post_pull"]: hook(result) return result - def push(self, overwrite=False, stop_revision=None, lossy=False, - _override_hook_source_branch=None, _stop_revno=None, - tag_selector=None): + def push( + self, + overwrite=False, + stop_revision=None, + lossy=False, + _override_hook_source_branch=None, + _stop_revno=None, + tag_selector=None, + ): result = GitBranchPushResult() result.source_branch = self.source result.target_branch = self.target @@ -1503,24 +1592,26 @@ def push(self, overwrite=False, stop_revision=None, lossy=False, result.master_branch = result.target_branch with self.source.lock_read(), self.target.lock_write(): new_refs, main_ref, stop_revinfo = self._get_new_refs( - stop_revision, stop_revno=_stop_revno) + stop_revision, stop_revno=_stop_revno + ) - update_refs = partial(_update_pure_git_refs, result, new_refs, overwrite, tag_selector) + update_refs = partial( + _update_pure_git_refs, result, new_refs, overwrite, tag_selector + ) try: - result.revidmap, old_refs, new_refs = ( - self.interrepo.fetch_refs( - update_refs, lossy=lossy, overwrite=overwrite)) + result.revidmap, old_refs, new_refs = self.interrepo.fetch_refs( + update_refs, lossy=lossy, overwrite=overwrite + ) except NoPushSupport as err: raise errors.NoRoundtrippingSupport(self.source, self.target) from err (old_sha1, result.old_revid) = old_refs.get( - main_ref, (ZERO_SHA, NULL_REVISION)) + main_ref, (ZERO_SHA, NULL_REVISION) + ) if lossy or result.old_revid is None: - result.old_revid = self.target.lookup_foreign_revision_id( - old_sha1) + result.old_revid = self.target.lookup_foreign_revision_id(old_sha1) result.new_revid = new_refs[main_ref][1] - (result.new_original_revno, - result.new_original_revid) = stop_revinfo - for hook in branch.Branch.hooks['post_push']: + (result.new_original_revno, 
result.new_original_revid) = stop_revinfo + for hook in branch.Branch.hooks["post_push"]: hook(result) return result diff --git a/breezy/git/cache.py b/breezy/git/cache.py index 7c07830c5e..205eb52b7d 100644 --- a/breezy/git/cache.py +++ b/breezy/git/cache.py @@ -141,7 +141,7 @@ def open(self, transport): def initialize(self, transport): """Create a new instance of this cache format at transport.""" - transport.put_bytes('format', self.get_format_string()) + transport.put_bytes("format", self.get_format_string()) @classmethod def from_transport(self, transport): @@ -151,10 +151,10 @@ def from_transport(self, transport): :return: A BzrGitCache instance """ try: - format_name = transport.get_bytes('format') + format_name = transport.get_bytes("format") format = formats.get(format_name) except NoSuchFile: - format = formats.get('default') + format = formats.get("default") format.initialize(transport) return format.open(transport) @@ -170,22 +170,21 @@ def from_repository(cls, repository): :return: A `BzrGitCache` """ from ..transport.local import LocalTransport + repo_transport = getattr(repository, "_transport", None) - if (repo_transport is not None - and isinstance(repo_transport, LocalTransport)): + if repo_transport is not None and isinstance(repo_transport, LocalTransport): # Even if we don't write to this repo, we should be able # to update its cache. try: - repo_transport = remove_readonly_transport_decorator( - repo_transport) + repo_transport = remove_readonly_transport_decorator(repo_transport) except bzr_errors.ReadOnlyError: transport = None else: try: - repo_transport.mkdir('git') + repo_transport.mkdir("git") except FileExists: pass - transport = repo_transport.clone('git') + transport = repo_transport.clone("git") else: transport = None if transport is None: @@ -242,7 +241,7 @@ def add_object(self, obj, bzr_key_data, path): if isinstance(obj, tuple): (type_name, hexsha) = obj else: - type_name = obj.type_name.decode('ascii') + type_name = obj.type_name.decode("ascii") hexsha = obj.id if not isinstance(hexsha, bytes): raise TypeError(hexsha) @@ -257,7 +256,8 @@ def add_object(self, obj, bzr_key_data, path): if bzr_key_data is not None: key = type_data = bzr_key_data self.cache.idmap._by_fileid.setdefault(type_data[1], {})[ - type_data[0]] = hexsha + type_data[0] + ] = hexsha else: raise AssertionError entry = (type_name, type_data) @@ -293,7 +293,7 @@ def lookup_commit(self, revid): def revids(self): for _key, entries in self._by_sha.items(): - for (type, type_data) in entries.values(): + for type, type_data in entries.values(): if type == "commit": yield type_data[0] @@ -302,7 +302,6 @@ def sha1s(self): class SqliteCacheUpdater(CacheUpdater): - def __init__(self, cache, rev): self.cache = cache self.db = self.cache.idmap.db @@ -315,7 +314,7 @@ def add_object(self, obj, bzr_key_data, path): if isinstance(obj, tuple): (type_name, hexsha) = obj else: - type_name = obj.type_name.decode('ascii') + type_name = obj.type_name.decode("ascii") hexsha = obj.id if not isinstance(hexsha, bytes): raise TypeError(hexsha) @@ -337,16 +336,16 @@ def finish(self): if self._commit is None: raise AssertionError("No commit object added") self.db.executemany( - "replace into trees (sha1, fileid, revid) values (?, ?, ?)", - self._trees) + "replace into trees (sha1, fileid, revid) values (?, ?, ?)", self._trees + ) self.db.executemany( - "replace into blobs (sha1, fileid, revid) values (?, ?, ?)", - self._blobs) + "replace into blobs (sha1, fileid, revid) values (?, ?, ?)", self._blobs + ) 
self.db.execute( "replace into commits (sha1, revid, tree_sha, testament3_sha1) " "values (?, ?, ?, ?)", - (self._commit.id, self.revid, self._commit.tree, - self._testament3_sha1)) + (self._commit.id, self.revid, self._commit.tree, self._testament3_sha1), + ) return self._commit @@ -355,9 +354,8 @@ def SqliteBzrGitCache(p): class SqliteGitCacheFormat(BzrGitCacheFormat): - def get_format_string(self): - return b'bzr-git sha map version 1 using sqlite\n' + return b"bzr-git sha map version 1 using sqlite\n" def open(self, transport): try: @@ -372,6 +370,7 @@ class SqliteGitShaMap(GitShaMap): def __init__(self, path=None): import sqlite3 + self.path = path if path is None: self.db = sqlite3.connect(":memory:") @@ -380,7 +379,8 @@ def __init__(self, path=None): mapdbs()[path] = sqlite3.connect(path) self.db = mapdbs()[path] self.db.text_factory = str - self.db.executescript(""" + self.db.executescript( + """ create table if not exists commits( sha1 text not null check(length(sha1) == 40), revid text not null, @@ -404,10 +404,10 @@ def __init__(self, path=None): create unique index if not exists trees_sha1 on trees(sha1); create unique index if not exists trees_fileid_revid on trees( fileid, revid); -""") +""" + ) try: - self.db.executescript( - "ALTER TABLE commits ADD testament3_sha1 TEXT;") + self.db.executescript("ALTER TABLE commits ADD testament3_sha1 TEXT;") except sqlite3.OperationalError: pass # Column already exists. @@ -415,8 +415,7 @@ def __repr__(self): return f"{self.__class__.__name__}({self.path!r})" def lookup_commit(self, revid): - cursor = self.db.execute("select sha1 from commits where revid = ?", - (revid,)) + cursor = self.db.execute("select sha1 from commits where revid = ?", (revid,)) row = cursor.fetchone() if row is not None: return row[0] @@ -427,16 +426,16 @@ def commit_write_group(self): def lookup_blob_id(self, fileid, revision): row = self.db.execute( - "select sha1 from blobs where fileid = ? and revid = ?", - (fileid, revision)).fetchone() + "select sha1 from blobs where fileid = ? and revid = ?", (fileid, revision) + ).fetchone() if row is not None: return row[0] raise KeyError(fileid) def lookup_tree_id(self, fileid, revision): row = self.db.execute( - "select sha1 from trees where fileid = ? and revid = ?", - (fileid, revision)).fetchone() + "select sha1 from trees where fileid = ? 
and revid = ?", (fileid, revision) + ).fetchone() if row is not None: return row[0] raise KeyError(fileid) @@ -452,8 +451,9 @@ def lookup_git_sha(self, sha): """ found = False cursor = self.db.execute( - "select revid, tree_sha, testament3_sha1 from commits where " - "sha1 = ?", (sha,)) + "select revid, tree_sha, testament3_sha1 from commits where " "sha1 = ?", + (sha,), + ) for row in cursor.fetchall(): found = True if row[2] is not None: @@ -462,12 +462,14 @@ def lookup_git_sha(self, sha): verifiers = {} yield ("commit", (row[0], row[1], verifiers)) cursor = self.db.execute( - "select fileid, revid from blobs where sha1 = ?", (sha,)) + "select fileid, revid from blobs where sha1 = ?", (sha,) + ) for row in cursor.fetchall(): found = True yield ("blob", row) cursor = self.db.execute( - "select fileid, revid from trees where sha1 = ?", (sha,)) + "select fileid, revid from trees where sha1 = ?", (sha,) + ) for row in cursor.fetchall(): found = True yield ("tree", row) @@ -482,7 +484,7 @@ def sha1s(self): """List the SHA1s.""" for table in ("blobs", "commits", "trees"): for (sha,) in self.db.execute(f"select sha1 from {table}"): # noqa: S608 - yield sha.encode('ascii') + yield sha.encode("ascii") class TdbCacheUpdater(CacheUpdater): @@ -501,7 +503,7 @@ def add_object(self, obj, bzr_key_data, path): (type_name, hexsha) = obj sha = hex_to_sha(hexsha) else: - type_name = obj.type_name.decode('ascii') + type_name = obj.type_name.decode("ascii") sha = obj.sha().digest() if type_name == "commit": self.db[b"commit\0" + self.revid] = b"\0".join((sha, obj.tree)) @@ -516,8 +518,7 @@ def add_object(self, obj, bzr_key_data, path): elif type_name == "blob": if bzr_key_data is None: return - self.db[b"\0".join( - (b"blob", bzr_key_data[0], bzr_key_data[1]))] = sha + self.db[b"\0".join((b"blob", bzr_key_data[0], bzr_key_data[1]))] = sha type_data = bzr_key_data elif type_name == "tree": if bzr_key_data is None: @@ -525,14 +526,14 @@ def add_object(self, obj, bzr_key_data, path): type_data = bzr_key_data else: raise AssertionError - entry = b"\0".join((type_name.encode('ascii'), ) + type_data) + b"\n" + entry = b"\0".join((type_name.encode("ascii"),) + type_data) + b"\n" key = b"git\0" + sha try: oldval = self.db[key] except KeyError: self.db[key] = entry else: - if not oldval.endswith(b'\n'): + if not oldval.endswith(b"\n"): self.db[key] = b"".join([oldval, b"\n", entry]) else: self.db[key] = b"".join([oldval, entry]) @@ -551,7 +552,7 @@ class TdbGitCacheFormat(BzrGitCacheFormat): """Cache format for tdb-based caches.""" def get_format_string(self): - return b'bzr-git sha map version 3 using tdb\n' + return b"bzr-git sha map version 3 using tdb\n" def open(self, transport): try: @@ -563,7 +564,8 @@ def open(self, transport): except ImportError as err: raise ImportError( "Unable to open existing bzr-git cache because 'tdb' is not " - "installed.") from err + "installed." 
+ ) from err class TdbGitShaMap(GitShaMap): @@ -582,23 +584,27 @@ class TdbGitShaMap(GitShaMap): def __init__(self, path=None): import tdb + self.path = path if path is None: self.db = {} else: if path not in mapdbs(): - mapdbs()[path] = tdb.Tdb(path, self.TDB_HASH_SIZE, tdb.DEFAULT, - os.O_RDWR | os.O_CREAT) + mapdbs()[path] = tdb.Tdb( + path, self.TDB_HASH_SIZE, tdb.DEFAULT, os.O_RDWR | os.O_CREAT + ) self.db = mapdbs()[path] try: if int(self.db[b"version"]) not in (2, 3): trace.warning( "SHA Map is incompatible (%s -> %d), rebuilding database.", - self.db[b"version"], self.TDB_MAP_VERSION) + self.db[b"version"], + self.TDB_MAP_VERSION, + ) self.db.clear() except KeyError: pass - self.db[b"version"] = b'%d' % self.TDB_MAP_VERSION + self.db[b"version"] = b"%d" % self.TDB_MAP_VERSION def start_write_group(self): """Start writing changes.""" @@ -638,13 +644,12 @@ def lookup_git_sha(self, sha): value = self.db[b"git\0" + sha] for data in value.splitlines(): data = data.split(b"\0") - type_name = data[0].decode('ascii') + type_name = data[0].decode("ascii") if type_name == "commit": if len(data) == 3: yield (type_name, (data[1], data[2], {})) else: - yield (type_name, (data[1], data[2], - {"testament3-sha1": data[3]})) + yield (type_name, (data[1], data[2], {"testament3-sha1": data[3]})) elif type_name in ("tree", "blob"): yield (type_name, tuple(data[1:])) else: @@ -674,25 +679,27 @@ def sha1s(self): class VersionedFilesContentCache(ContentCache): - def __init__(self, vf): self._vf = vf def add(self, obj): self._vf.insert_record_stream( - [versionedfile.ChunkedContentFactory( - (obj.id,), [], None, obj.as_legacy_object_chunks())]) + [ + versionedfile.ChunkedContentFactory( + (obj.id,), [], None, obj.as_legacy_object_chunks() + ) + ] + ) def __getitem__(self, sha): - stream = self._vf.get_record_stream([(sha,)], 'unordered', True) + stream = self._vf.get_record_stream([(sha,)], "unordered", True) entry = next(stream) - if entry.storage_kind == 'absent': + if entry.storage_kind == "absent": raise KeyError(sha) - return ShaFile._parse_legacy_object(entry.get_bytes_as('fulltext')) + return ShaFile._parse_legacy_object(entry.get_bytes_as("fulltext")) class IndexCacheUpdater(CacheUpdater): - def __init__(self, cache, rev): self.cache = cache self.revid = rev.revision_id @@ -704,20 +711,23 @@ def add_object(self, obj, bzr_key_data, path): if isinstance(obj, tuple): (type_name, hexsha) = obj else: - type_name = obj.type_name.decode('ascii') + type_name = obj.type_name.decode("ascii") hexsha = obj.id if type_name == "commit": self._commit = obj if not isinstance(bzr_key_data, dict): raise TypeError(bzr_key_data) - self.cache.idmap._add_git_sha(hexsha, b"commit", - (self.revid, obj.tree, bzr_key_data)) - self.cache.idmap._add_node((b"commit", self.revid, b"X"), - b" ".join((hexsha, obj.tree))) + self.cache.idmap._add_git_sha( + hexsha, b"commit", (self.revid, obj.tree, bzr_key_data) + ) + self.cache.idmap._add_node( + (b"commit", self.revid, b"X"), b" ".join((hexsha, obj.tree)) + ) elif type_name == "blob": self.cache.idmap._add_git_sha(hexsha, b"blob", bzr_key_data) - self.cache.idmap._add_node((b"blob", bzr_key_data[0], - bzr_key_data[1]), hexsha) + self.cache.idmap._add_node( + (b"blob", bzr_key_data[0], bzr_key_data[1]), hexsha + ) elif type_name == "tree": self.cache.idmap._add_git_sha(hexsha, b"tree", bzr_key_data) else: @@ -728,23 +738,22 @@ def finish(self): class IndexBzrGitCache(BzrGitCache): - def __init__(self, transport=None): - shamap = IndexGitShaMap(transport.clone('index')) + shamap = 
IndexGitShaMap(transport.clone("index")) super().__init__(shamap, IndexCacheUpdater) class IndexGitCacheFormat(BzrGitCacheFormat): - def get_format_string(self): - return b'bzr-git sha map with git object cache version 1\n' + return b"bzr-git sha map with git object cache version 1\n" def initialize(self, transport): super().initialize(transport) - transport.mkdir('index') - transport.mkdir('objects') + transport.mkdir("index") + transport.mkdir("objects") from .transportgit import TransportObjectStore - TransportObjectStore.init(transport.clone('objects')) + + TransportObjectStore.init(transport.clone("objects")) def open(self, transport): return IndexBzrGitCache(transport) @@ -775,7 +784,8 @@ def __init__(self, transport=None): if not name.endswith(".rix"): continue x = _mod_btree_index.BTreeGraphIndex( - self._transport, name, self._transport.stat(name).st_size) + self._transport, name, self._transport.stat(name).st_size + ) self._index.insert_index(0, x) @classmethod @@ -783,10 +793,10 @@ def from_repository(cls, repository): transport = getattr(repository, "_transport", None) if transport is not None: try: - transport.mkdir('git') + transport.mkdir("git") except FileExists: pass - return cls(transport.clone('git')) + return cls(transport.clone("git")) return cls(get_transport_from_path(get_cache_dir())) def __repr__(self): @@ -797,29 +807,29 @@ def __repr__(self): def repack(self): if self._builder is not None: - raise bzr_errors.BzrError('builder already open') + raise bzr_errors.BzrError("builder already open") self.start_write_group() self._builder.add_nodes( - (key, value) - for (_, key, value) in self._index.iter_all_entries()) + (key, value) for (_, key, value) in self._index.iter_all_entries() + ) to_remove = [] - for name in self._transport.list_dir('.'): - if name.endswith('.rix'): + for name in self._transport.list_dir("."): + if name.endswith(".rix"): to_remove.append(name) self.commit_write_group() del self._index.indices[1:] for name in to_remove: - self._transport.rename(name, name + '.old') + self._transport.rename(name, name + ".old") def start_write_group(self): if self._builder is not None: - raise bzr_errors.BzrError('builder already open') + raise bzr_errors.BzrError("builder already open") self._builder = _mod_btree_index.BTreeBuilder(0, key_elements=3) self._name = hashlib.sha1() # noqa: S324 def commit_write_group(self): if self._builder is None: - raise bzr_errors.BzrError('builder not open') + raise bzr_errors.BzrError("builder not open") stream = self._builder.finish() name = self._name.hexdigest() + ".rix" size = self._transport.put_file(name, stream) @@ -830,7 +840,7 @@ def commit_write_group(self): def abort_write_group(self): if self._builder is None: - raise bzr_errors.BzrError('builder not open') + raise bzr_errors.BzrError("builder not open") self._builder = None self._name = None @@ -901,7 +911,7 @@ def lookup_git_sha(self, sha): verifiers = {} yield ("commit", (data[1], data[2], verifiers)) else: - yield (data[0].decode('ascii'), tuple(data[1:])) + yield (data[0].decode("ascii"), tuple(data[1:])) def revids(self): """List the revision ids known.""" @@ -912,7 +922,8 @@ def missing_revisions(self, revids): """Return set of all the revisions that are not present.""" missing_revids = set(revids) for _, key, _value in self._index.iter_entries( - (b"commit", revid, b"X") for revid in revids): + (b"commit", revid, b"X") for revid in revids + ): missing_revids.remove(key[1]) return missing_revids @@ -923,14 +934,11 @@ def sha1s(self): formats = 
registry.Registry[str, BzrGitCacheFormat, None]() -formats.register(TdbGitCacheFormat().get_format_string(), - TdbGitCacheFormat()) -formats.register(SqliteGitCacheFormat().get_format_string(), - SqliteGitCacheFormat()) -formats.register(IndexGitCacheFormat().get_format_string(), - IndexGitCacheFormat()) +formats.register(TdbGitCacheFormat().get_format_string(), TdbGitCacheFormat()) +formats.register(SqliteGitCacheFormat().get_format_string(), SqliteGitCacheFormat()) +formats.register(IndexGitCacheFormat().get_format_string(), IndexGitCacheFormat()) # In the future, this will become the default: -formats.register('default', IndexGitCacheFormat()) +formats.register("default", IndexGitCacheFormat()) def migrate_ancient_formats(repo_transport): diff --git a/breezy/git/commands.py b/breezy/git/commands.py index ecbdcbfac0..cf44a55060 100644 --- a/breezy/git/commands.py +++ b/breezy/git/commands.py @@ -34,20 +34,21 @@ class cmd_git_import(Command): takes_args = ["src_location", "dest_location?"] takes_options = [ - Option('colocated', help='Create colocated branches.'), - RegistryOption('dest-format', - help='Specify a format for this branch. ' - 'See "help formats" for a full list.', - lazy_registry=('breezy.controldir', 'format_registry'), - converter=lambda name: controldir.format_registry.make_controldir( - name), - value_switches=True, - title="Branch format", - ), - ] + Option("colocated", help="Create colocated branches."), + RegistryOption( + "dest-format", + help="Specify a format for this branch. " + 'See "help formats" for a full list.', + lazy_registry=("breezy.controldir", "format_registry"), + converter=lambda name: controldir.format_registry.make_controldir(name), + value_switches=True, + title="Branch format", + ), + ] def _get_colocated_branch(self, target_controldir, name): from ..errors import NotBranchError + try: return target_controldir.open_branch(name=name) except NotBranchError: @@ -56,12 +57,14 @@ def _get_colocated_branch(self, target_controldir, name): def _get_nested_branch(self, dest_transport, dest_format, name): from ..controldir import ControlDir from ..errors import NotBranchError + head_transport = dest_transport.clone(name) try: head_controldir = ControlDir.open_from_transport(head_transport) except NotBranchError: head_controldir = dest_format.initialize_on_transport_ex( - head_transport, create_prefix=True)[1] + head_transport, create_prefix=True + )[1] try: return head_controldir.open_branch() except NotBranchError: @@ -81,7 +84,7 @@ def run(self, src_location, dest_location=None, colocated=False, dest_format=Non from .repository import GitRepository if dest_format is None: - dest_format = controldir.format_registry.make_controldir('default') + dest_format = controldir.format_registry.make_controldir("default") if dest_location is None: dest_location = os.path.basename(src_location.rstrip("/\\")) @@ -90,21 +93,20 @@ def run(self, src_location, dest_location=None, colocated=False, dest_format=Non source_repo = Repository.open(src_location) if not isinstance(source_repo, GitRepository): - raise CommandError( - gettext("%r is not a git repository") % src_location) + raise CommandError(gettext("%r is not a git repository") % src_location) try: target_controldir = ControlDir.open_from_transport(dest_transport) except NotBranchError: target_controldir = dest_format.initialize_on_transport_ex( - dest_transport, shared_repo=True)[1] + dest_transport, shared_repo=True + )[1] try: target_repo = target_controldir.find_repository() except NoRepositoryPresent: 
target_repo = target_controldir.create_repository(shared=True) if not target_repo.supports_rich_root(): - raise CommandError( - gettext("Target repository doesn't support rich roots")) + raise CommandError(gettext("Target repository doesn't support rich roots")) interrepo = InterRepository.get(source_repo, target_repo) mapping = source_repo.get_mapping() @@ -117,29 +119,35 @@ def run(self, src_location, dest_location=None, colocated=False, dest_format=Non # Not a branch, ignore continue pb.update(gettext("creating branches"), i, len(result.refs)) - if (getattr(target_controldir._format, "colocated_branches", - False) and colocated): + if ( + getattr(target_controldir._format, "colocated_branches", False) + and colocated + ): if name == "HEAD": branch_name = None head_branch = self._get_colocated_branch( - target_controldir, branch_name) + target_controldir, branch_name + ) else: head_branch = self._get_nested_branch( - dest_transport, dest_format, branch_name) + dest_transport, dest_format, branch_name + ) revid = mapping.revision_id_foreign_to_bzr(sha) - source_branch = LocalGitBranch( - source_repo.controldir, source_repo, sha) + source_branch = LocalGitBranch(source_repo.controldir, source_repo, sha) if head_branch.last_revision() != revid: head_branch.generate_revision_history(revid) source_branch.tags.merge_to(head_branch.tags) if not head_branch.get_parent(): url = urlutils.join_segment_parameters( - source_branch.base, - {"branch": urlutils.escape(branch_name)}) + source_branch.base, {"branch": urlutils.escape(branch_name)} + ) head_branch.set_parent(url) - trace.note(gettext( - "Use 'bzr checkout' to create a working tree in " - "the newly created branches.")) + trace.note( + gettext( + "Use 'bzr checkout' to create a working tree in " + "the newly created branches." 
+ ) + ) class cmd_git_object(Command): @@ -153,11 +161,11 @@ class cmd_git_object(Command): aliases = ["git-objects", "git-cat"] takes_args = ["sha1?"] - takes_options = [Option('directory', - short_name='d', - help='Location of repository.', type=str), - Option('pretty', help='Pretty-print objects.')] - encoding_type = 'exact' + takes_options = [ + Option("directory", short_name="d", help="Location of repository.", type=str), + Option("pretty", help="Pretty-print objects."), + ] + encoding_type = "exact" @display_command def run(self, sha1=None, directory=".", pretty=False): @@ -165,16 +173,16 @@ def run(self, sha1=None, directory=".", pretty=False): from ..errors import CommandError from ..i18n import gettext from .object_store import get_object_store + controldir, _ = ControlDir.open_containing(directory) repo = controldir.find_repository() object_store = get_object_store(repo) with object_store.lock_read(): if sha1 is not None: try: - obj = object_store[sha1.encode('ascii')] + obj = object_store[sha1.encode("ascii")] except KeyError as err: - raise CommandError( - gettext("Object not found: %s") % sha1) from err + raise CommandError(gettext("Object not found: %s") % sha1) from err if pretty: text = obj.as_pretty_string() else: @@ -197,6 +205,7 @@ def run(self, location="."): from ..controldir import ControlDir from .object_store import get_object_store from .refs import get_refs_container + controldir, _ = ControlDir.open_containing(location) repo = controldir.find_repository() object_store = get_object_store(repo) @@ -213,10 +222,9 @@ class cmd_git_apply(Command): """ takes_options = [ - Option('signoff', short_name='s', help='Add a Signed-off-by line.'), - Option('force', - help='Apply patches even if tree has uncommitted changes.') - ] + Option("signoff", short_name="s", help="Add a Signed-off-by line."), + Option("force", help="Apply patches even if tree has uncommitted changes."), + ] takes_args = ["patches*"] def _apply_patch(self, wt, f, signoff): @@ -229,22 +237,25 @@ def _apply_patch(self, wt, f, signoff): from dulwich.patch import git_am_patch_split from ..workingtree import patch_tree + (c, diff, version) = git_am_patch_split(f) # FIXME: Cope with git-specific bits in patch # FIXME: Add new files to working tree from io import BytesIO + b = BytesIO() patch_tree(wt, [diff], strip=1, out=b) - self.outf.write(b.getvalue().decode('utf-8', 'replace')) - message = c.message.decode('utf-8') + self.outf.write(b.getvalue().decode("utf-8", "replace")) + message = c.message.decode("utf-8") if signoff: signed_off_by = wt.branch.get_config().username() message += f"Signed-off-by: {signed_off_by}\n" - wt.commit(authors=[c.author.decode('utf-8')], message=message) + wt.commit(authors=[c.author.decode("utf-8")], message=message) def run(self, patches_list=None, signoff=False, force=False): from ..errors import UncommittedChanges from ..workingtree import WorkingTree + if patches_list is None: patches_list = [] @@ -260,12 +271,12 @@ def run(self, patches_list=None, signoff=False, force=False): class cmd_git_push_pristine_tar_deltas(Command): """Push pristine tar deltas to a git repository.""" - takes_options = [Option('directory', - short_name='d', - help='Location of repository.', type=str)] - takes_args = ['target', 'package'] + takes_options = [ + Option("directory", short_name="d", help="Location of repository.", type=str) + ] + takes_args = ["target", "package"] - def run(self, target, package, directory='.'): + def run(self, target, package, directory="."): from ..branch import Branch 
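The cmd_git_apply hunk above feeds each patch file through dulwich's git_am_patch_split before applying the diff with patch_tree and committing under the original author. A minimal sketch of that decomposition, using a toy format-patch message (the author, subject and file names are illustrative, not taken from this change):

from io import BytesIO

from dulwich.patch import git_am_patch_split

raw = BytesIO(
    b"From: Jane Example <jane@example.com>\n"
    b"Subject: [PATCH] Fix a typo\n"
    b"\n"
    b"---\n"
    b" README | 2 +-\n"
    b" 1 file changed, 1 insertion(+), 1 deletion(-)\n"
    b"\n"
    b"diff --git a/README b/README\n"
)
commit, diff, version = git_am_patch_split(raw)
# commit.author and commit.message are bytes; the diff body is what
# _apply_patch passes to patch_tree(wt, [diff], strip=1) before committing
# with the same author.
print(commit.author, commit.message, len(diff))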
from ..errors import CommandError, NoSuchRevision from ..repository import Repository @@ -276,9 +287,10 @@ def run(self, target, package, directory='.'): revision_pristine_tar_data, store_git_pristine_tar_data, ) + source = Branch.open_containing(directory)[0] target_bzr = Repository.open(target) - target = getattr(target_bzr, '_git', None) + target = getattr(target_bzr, "_git", None) if target is None: raise CommandError("Target not a git repository") git_store = get_object_store(source.repository) @@ -294,17 +306,19 @@ def run(self, target, package, directory='.'): except KeyError: continue gitid = git_store._lookup_revision_sha1(revid) - if (not (name.startswith('upstream/') or - name.startswith('upstream-'))): + if not (name.startswith("upstream/") or name.startswith("upstream-")): warning( - "Unexpected pristine tar revision tagged %s. " - "Ignoring.", name) + "Unexpected pristine tar revision tagged %s. " "Ignoring.", name + ) continue - upstream_version = name[len("upstream/"):] - filename = f'{package}_{upstream_version}.orig.tar.{kind}' + upstream_version = name[len("upstream/") :] + filename = f"{package}_{upstream_version}.orig.tar.{kind}" if gitid not in target: warning( "base git id %s for %s missing in target repository", - gitid, filename) - store_git_pristine_tar_data(target, encode_git_path(filename), - delta, gitid) + gitid, + filename, + ) + store_git_pristine_tar_data( + target, encode_git_path(filename), delta, gitid + ) diff --git a/breezy/git/commit.py b/breezy/git/commit.py index 6d1c72c1a5..229579f8d7 100644 --- a/breezy/git/commit.py +++ b/breezy/git/commit.py @@ -62,14 +62,22 @@ def record_iter_changes(self, workingtree, basis_revid, iter_changes): else: file_id = None if change.path[1]: - parent_id_new = self._mapping.generate_file_id(osutils.dirname(change.path[1])) + parent_id_new = self._mapping.generate_file_id( + osutils.dirname(change.path[1]) + ) else: parent_id_new = None if change.kind[1] in ("directory",): self._inv_delta.append( - (change.path[0], change.path[1], file_id, - entry_factory[change.kind[1]]( - file_id, change.name[1], parent_id_new))) + ( + change.path[0], + change.path[1], + file_id, + entry_factory[change.kind[1]]( + file_id, change.name[1], parent_id_new + ), + ) + ) if change.kind[0] in ("file", "symlink"): self._blobs[encode_git_path(change.path[0])] = None self._any_changes = True @@ -151,24 +159,24 @@ def finish_inventory(self): self._blobs = dict(self._blobs.items()) def _iterblobs(self): - return ((path, sha, mode) for (path, (mode, sha)) - in self._blobs.items()) + return ((path, sha, mode) for (path, (mode, sha)) in self._blobs.items()) def commit(self, message): - self._validate_unicode_text(message, 'commit message') + self._validate_unicode_text(message, "commit message") c = Commit() - c.parents = [self.repository.lookup_bzr_revision_id( - revid)[0] for revid in self.parents] + c.parents = [ + self.repository.lookup_bzr_revision_id(revid)[0] for revid in self.parents + ] c.tree = commit_tree(self.store, self._iterblobs()) - encoding = self._revprops.pop('git-explicit-encoding', 'utf-8') - c.encoding = encoding.encode('ascii') + encoding = self._revprops.pop("git-explicit-encoding", "utf-8") + c.encoding = encoding.encode("ascii") c.committer = fix_person_identifier(self._committer.encode(encoding)) pseudoheaders = [] try: - author = self._revprops.pop('author') + author = self._revprops.pop("author") except KeyError: try: - authors = self._revprops.pop('authors').splitlines() + authors = 
self._revprops.pop("authors").splitlines() except KeyError: author = self._committer else: @@ -178,10 +186,11 @@ def commit(self, message): author = authors[0] for coauthor in authors[1:]: pseudoheaders.append( - b'Co-authored-by: %s' - % fix_person_identifier(coauthor.encode(encoding))) + b"Co-authored-by: %s" + % fix_person_identifier(coauthor.encode(encoding)) + ) c.author = fix_person_identifier(author.encode(encoding)) - bugstext = self._revprops.pop('bugs', None) + bugstext = self._revprops.pop("bugs", None) if bugstext is not None: for url, status in bugtracker.decode_bug_urls(bugstext.splitlines()): if status == bugtracker.FIXED: @@ -201,20 +210,22 @@ def commit(self, message): if not c.message.endswith(b"\n"): c.message += b"\n" c.message += b"\n" + b"".join([line + b"\n" for line in pseudoheaders]) - create_signatures = self._config_stack.get('create_signatures') - if (create_signatures in ( - _mod_config.SIGN_ALWAYS, _mod_config.SIGN_WHEN_POSSIBLE)): + create_signatures = self._config_stack.get("create_signatures") + if create_signatures in ( + _mod_config.SIGN_ALWAYS, + _mod_config.SIGN_WHEN_POSSIBLE, + ): strategy = gpg.GPGStrategy(self._config_stack) try: c.gpgsig = strategy.sign(c.as_raw_string(), gpg.MODE_DETACH) except gpg.GpgNotInstalled as e: if create_signatures == _mod_config.SIGN_WHEN_POSSIBLE: - trace.note('skipping commit signature: %s', e) + trace.note("skipping commit signature: %s", e) else: raise except gpg.SigningFailed as e: if create_signatures == _mod_config.SIGN_WHEN_POSSIBLE: - trace.note('commit signature failed: %s', e) + trace.note("commit signature failed: %s", e) else: raise self.store.add_object(c) diff --git a/breezy/git/config.py b/breezy/git/config.py index 2407682d9a..17a755c4fb 100644 --- a/breezy/git/config.py +++ b/breezy/git/config.py @@ -30,16 +30,17 @@ def __init__(self, branch): def __repr__(self): return f"<{self.__class__.__name__} of {self.branch!r}>" - def set_user_option(self, name, value, store=config.STORE_BRANCH, - warn_masked=False): + def set_user_option( + self, name, value, store=config.STORE_BRANCH, warn_masked=False + ): """Force local to True.""" config.BranchConfig.set_user_option( - self, name, value, store=config.STORE_LOCATION, - warn_masked=warn_masked) + self, name, value, store=config.STORE_LOCATION, warn_masked=warn_masked + ) def _get_user_id(self): # TODO: Read from ~/.gitconfig - return self._get_best_value('_get_user_id') + return self._get_best_value("_get_user_id") class GitConfigSectionDefault(config.Section): @@ -51,19 +52,19 @@ def __init__(self, id, config): self._config = config def get(self, name, default=None, expand=True): - if name == 'email': + if name == "email": try: - email = self._config.get((b'user', ), b'email') + email = self._config.get((b"user",), b"email") except KeyError: return None try: - name = self._config.get((b'user', ), b'name') + name = self._config.get((b"user",), b"name") except KeyError: return email.decode() - return f'{name.decode()} <{email.decode()}>' - if name == 'gpg_signing_key': + return f"{name.decode()} <{email.decode()}>" + if name == "gpg_signing_key": try: - key = self._config.get((b'user', ), b'signingkey') + key = self._config.get((b"user",), b"signingkey") except KeyError: return None return key.decode() @@ -71,17 +72,17 @@ def get(self, name, default=None, expand=True): def iter_option_names(self): try: - self._config.get((b'user', ), b'email') + self._config.get((b"user",), b"email") except KeyError: pass else: - yield 'email' + yield "email" try: - 
self._config.get((b'user', ), b'signingkey') + self._config.get((b"user",), b"signingkey") except KeyError: pass else: - yield 'gpg_signing_key' + yield "gpg_signing_key" class GitConfigStore(config.Store): @@ -93,8 +94,8 @@ def __init__(self, id, config): def get_sections(self): return [ - (self, GitConfigSectionDefault('default', self._config)), - ] + (self, GitConfigSectionDefault("default", self._config)), + ] class GitBranchStack(config._CompatibleStack): @@ -107,9 +108,9 @@ def __init__(self, branch): section_getters.append(loc_matcher.get_sections) # FIXME: This should also be looking in .git/config for # local git branches. - git = getattr(branch.repository, '_git', None) + git = getattr(branch.repository, "_git", None) if git: - cstore = GitConfigStore('branch', git.get_config()) + cstore = GitConfigStore("branch", git.get_config()) section_getters.append(cstore.get_sections) gstore = config.GlobalStore() section_getters.append(gstore.get_sections) @@ -117,5 +118,7 @@ def __init__(self, branch): section_getters, # All modifications go to the corresponding section in # locations.conf - lstore, branch.base) + lstore, + branch.base, + ) self.branch = branch diff --git a/breezy/git/dir.py b/breezy/git/dir.py index 39c254b607..946e8aefdb 100644 --- a/breezy/git/dir.py +++ b/breezy/git/dir.py @@ -44,7 +44,6 @@ class GitDirConfig: - def get_default_stack_on(self): return None @@ -53,7 +52,6 @@ def set_default_stack_on(self, value): class GitControlDirFormat(ControlDirFormat): - colocated_branches = True fixed_components = True @@ -70,8 +68,9 @@ def network_name(self): class UseExistingRepository(RepositoryAcquisitionPolicy): """A policy of reusing an existing repository.""" - def __init__(self, repository, stack_on=None, stack_on_pwd=None, - require_stacking=False): + def __init__( + self, repository, stack_on=None, stack_on_pwd=None, require_stacking=False + ): """Constructor. :param repository: The repository to use. @@ -79,12 +78,12 @@ def __init__(self, repository, stack_on=None, stack_on_pwd=None, :param stack_on_pwd: If stack_on is relative, the location it is relative to. """ - super().__init__( - stack_on, stack_on_pwd, require_stacking) + super().__init__(stack_on, stack_on_pwd, require_stacking) self._repository = repository - def acquire_repository(self, make_working_trees=None, shared=False, - possible_transports=None): + def acquire_repository( + self, make_working_trees=None, shared=False, possible_transports=None + ): """Implementation of RepositoryAcquisitionPolicy.acquire_repository. Returns an existing repository to use. 
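The config.py hunks above read user.name and user.email from the underlying dulwich config and fold them into a single Bazaar-style identity. A small sketch of that lookup against an in-memory dulwich ConfigFile (the values are made up for illustration):

from dulwich.config import ConfigFile

cfg = ConfigFile()
cfg.set((b"user",), b"name", b"Jane Example")       # [user] name = Jane Example
cfg.set((b"user",), b"email", b"jane@example.com")  # [user] email = ...

# GitConfigSectionDefault.get("email") returns "Name <email>" when both keys
# are present, and just the address when only user.email is set.
name = cfg.get((b"user",), b"name")
email = cfg.get((b"user",), b"email")
print(f"{name.decode()} <{email.decode()}>")  # Jane Example <jane@example.com>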
@@ -122,9 +121,11 @@ def _get_selected_ref(self, branch, ref=None): return ref if branch is not None: from .refs import branch_name_to_ref + return branch_name_to_ref(branch) segment_parameters = getattr( - self.user_transport, "get_segment_parameters", lambda: {})() + self.user_transport, "get_segment_parameters", lambda: {} + )() ref = segment_parameters.get("ref") if ref is not None: return urlutils.unquote_to_bytes(ref) @@ -132,6 +133,7 @@ def _get_selected_ref(self, branch, ref=None): branch = self._get_selected_branch() if branch is not None: from .refs import branch_name_to_ref + return branch_name_to_ref(branch) return b"HEAD" @@ -141,13 +143,23 @@ def get_config(self): def _available_backup_name(self, base): return osutils.available_backup_name(base, self.root_transport.has) - def sprout(self, url, revision_id=None, force_new_repo=False, - recurse='down', possible_transports=None, - accelerator_tree=None, hardlink=False, stacked=False, - source_branch=None, create_tree_if_local=True): + def sprout( + self, + url, + revision_id=None, + force_new_repo=False, + recurse="down", + possible_transports=None, + accelerator_tree=None, + hardlink=False, + stacked=False, + source_branch=None, + create_tree_if_local=True, + ): from ..repository import InterRepository from ..transport import get_transport from ..transport.local import LocalTransport + target_transport = get_transport(url, possible_transports) target_transport.ensure_base() cloning_format = self.cloning_metadir() @@ -164,29 +176,35 @@ def sprout(self, url, revision_id=None, force_new_repo=False, except brz_errors.NoRepositoryPresent: result_repo = result.create_repository() if stacked: - raise _mod_branch.UnstackableBranchFormat( - self._format, self.user_url) + raise _mod_branch.UnstackableBranchFormat(self._format, self.user_url) interrepo = InterRepository.get(source_repository, result_repo) if revision_id is not None: determine_wants = interrepo.get_determine_wants_revids( - [revision_id], include_tags=True) + [revision_id], include_tags=True + ) else: determine_wants = interrepo.determine_wants_all - interrepo.fetch_objects(determine_wants=determine_wants, - mapping=source_branch.mapping) + interrepo.fetch_objects( + determine_wants=determine_wants, mapping=source_branch.mapping + ) result_branch = source_branch.sprout( - result, revision_id=revision_id, repository=result_repo) - if (create_tree_if_local and - result.open_branch(name="").name == result_branch.name and - isinstance(target_transport, LocalTransport) and - (result_repo is None or result_repo.make_working_trees())): + result, revision_id=revision_id, repository=result_repo + ) + if ( + create_tree_if_local + and result.open_branch(name="").name == result_branch.name + and isinstance(target_transport, LocalTransport) + and (result_repo is None or result_repo.make_working_trees()) + ): wt = result.create_workingtree( accelerator_tree=accelerator_tree, - hardlink=hardlink, from_branch=result_branch) + hardlink=hardlink, + from_branch=result_branch, + ) else: wt = None - if recurse == 'down': + if recurse == "down": with contextlib.ExitStack() as stack: basis = None if wt is not None: @@ -208,55 +226,79 @@ def sprout(self, url, revision_id=None, force_new_repo=False, continue remote_url = urlutils.join(self.user_url, sublocation) try: - subbranch = _mod_branch.Branch.open(remote_url, possible_transports=possible_transports) + subbranch = _mod_branch.Branch.open( + remote_url, possible_transports=possible_transports + ) except brz_errors.NotBranchError as e: 
trace.warning( - 'Unable to clone submodule %s from %s: %s', - path, remote_url, e) + "Unable to clone submodule %s from %s: %s", + path, + remote_url, + e, + ) continue subbranch.controldir.sprout( - target, basis.get_reference_revision(path), - force_new_repo=force_new_repo, recurse=recurse, - stacked=stacked) - if getattr(result_repo, '_git', None): + target, + basis.get_reference_revision(path), + force_new_repo=force_new_repo, + recurse=recurse, + stacked=stacked, + ) + if getattr(result_repo, "_git", None): # Don't leak resources: # TODO(jelmer): This shouldn't be git-specific, and possibly # just use read locks. result_repo._git.object_store.close() return result - def clone_on_transport(self, transport, revision_id=None, - force_new_repo=False, preserve_stacking=False, - stacked_on=None, create_prefix=False, - use_existing_dir=True, no_tree=False, - tag_selector=None): + def clone_on_transport( + self, + transport, + revision_id=None, + force_new_repo=False, + preserve_stacking=False, + stacked_on=None, + create_prefix=False, + use_existing_dir=True, + no_tree=False, + tag_selector=None, + ): """See ControlDir.clone_on_transport.""" from ..repository import InterRepository from ..transport.local import LocalTransport from .mapping import default_mapping from .refs import is_peeled + if no_tree: format = BareLocalGitControlDirFormat() else: format = LocalGitControlDirFormat() if stacked_on is not None: raise _mod_branch.UnstackableBranchFormat(format, self.user_url) - (target_repo, target_controldir, stacking, - repo_policy) = format.initialize_on_transport_ex( - transport, use_existing_dir=use_existing_dir, + ( + target_repo, + target_controldir, + stacking, + repo_policy, + ) = format.initialize_on_transport_ex( + transport, + use_existing_dir=use_existing_dir, create_prefix=create_prefix, - force_new_repo=force_new_repo) + force_new_repo=force_new_repo, + ) target_repo = target_controldir.find_repository() target_git_repo = target_repo._git source_repo = self.find_repository() interrepo = InterRepository.get(source_repo, target_repo) if revision_id is not None: determine_wants = interrepo.get_determine_wants_revids( - [revision_id], include_tags=True, tag_selector=tag_selector) + [revision_id], include_tags=True, tag_selector=tag_selector + ) else: determine_wants = interrepo.determine_wants_all - (pack_hint, _, refs) = interrepo.fetch_objects(determine_wants, - mapping=default_mapping) + (pack_hint, _, refs) = interrepo.fetch_objects( + determine_wants, mapping=default_mapping + ) for name, val in refs.items(): if is_peeled(name): continue @@ -299,8 +341,13 @@ def get_refs_container(self): """Retrieve the refs container.""" raise NotImplementedError(self.get_refs_container) - def determine_repository_policy(self, force_new_repo=False, stack_on=None, - stack_on_pwd=None, require_stacking=False): + def determine_repository_policy( + self, + force_new_repo=False, + stack_on=None, + stack_on_pwd=None, + require_stacking=False, + ): """Return an object representing a policy to use. 
This controls whether a new repository is created, and the format of @@ -318,6 +365,7 @@ def determine_repository_policy(self, force_new_repo=False, stack_on=None, def branch_names(self): from .refs import ref_to_branch_name + ret = [] for ref in self.get_refs_container().keys(): try: @@ -332,6 +380,7 @@ def branch_names(self): def get_branches(self): from .refs import ref_to_branch_name + ret = {} for ref in self.get_refs_container().keys(): try: @@ -347,9 +396,17 @@ def get_branches(self): def list_branches(self): return list(self.get_branches().values()) - def push_branch(self, source, revision_id=None, overwrite=False, - remember=False, create_prefix=False, lossy=False, - name=None, tag_selector=None): + def push_branch( + self, + source, + revision_id=None, + overwrite=False, + remember=False, + create_prefix=False, + lossy=False, + name=None, + tag_selector=None, + ): """Push the source branch into this ControlDir.""" push_result = GitPushResult() push_result.workingtree_updated = None @@ -357,12 +414,17 @@ def push_branch(self, source, revision_id=None, overwrite=False, push_result.source_branch = source push_result.stacked_on = None from .branch import GitBranch + if isinstance(source, GitBranch) and lossy: raise brz_errors.LossyPushToSameVCS(source.controldir, self) target = self.open_branch(name, nascent_ok=True) push_result.branch_push_result = source.push( - target, overwrite=overwrite, stop_revision=revision_id, - lossy=lossy, tag_selector=tag_selector) + target, + overwrite=overwrite, + stop_revision=revision_id, + lossy=lossy, + tag_selector=tag_selector, + ) push_result.new_revid = push_result.branch_push_result.new_revid push_result.old_revid = push_result.branch_push_result.old_revid try: @@ -373,7 +435,8 @@ def push_branch(self, source, revision_id=None, overwrite=False, if self.open_branch(name="").name == target.name: wt._update_git_tree( old_revision=push_result.old_revid, - new_revision=push_result.new_revid) + new_revision=push_result.new_revid, + ) push_result.workingtree_updated = True else: push_result.workingtree_updated = False @@ -395,15 +458,18 @@ def _known_formats(self): @property def repository_format(self): from .repository import GitRepositoryFormat + return GitRepositoryFormat() @property def workingtree_format(self): from .workingtree import GitWorkingTreeFormat + return GitWorkingTreeFormat() def get_branch_format(self): from .branch import LocalGitBranchFormat + return LocalGitBranchFormat() def open(self, transport, _found=None): @@ -412,18 +478,20 @@ def open(self, transport, _found=None): def _open(transport): try: - return TransportRepo(transport, self.bare, - refs_text=getattr(self, "_refs_text", None)) + return TransportRepo( + transport, self.bare, refs_text=getattr(self, "_refs_text", None) + ) except ValueError as e: - if e.args == ('Expected file to start with \'gitdir: \'', ): + if e.args == ("Expected file to start with 'gitdir: '",): raise brz_errors.NotBranchError(path=transport.base) from e raise def redirected(transport, e, redirection_notice): trace.note(redirection_notice) return transport._redirected_to(e.source, e.target) + gitrepo = do_catching_redirections(_open, transport, redirected) - if not _found and not gitrepo._controltransport.has('objects'): + if not _found and not gitrepo._controltransport.has("objects"): raise brz_errors.NotBranchError(path=transport.base) return LocalGitDir(transport, gitrepo, self) @@ -432,28 +500,36 @@ def get_format_description(self): def initialize_on_transport(self, transport): from 
.transportgit import TransportRepo + git_repo = TransportRepo.init(transport, bare=self.bare) return LocalGitDir(transport, git_repo, self) - def initialize_on_transport_ex(self, transport, use_existing_dir=False, - create_prefix=False, force_new_repo=False, - stacked_on=None, - stack_on_pwd=None, repo_format_name=None, - make_working_trees=None, - shared_repo=False, vfs_only=False): + def initialize_on_transport_ex( + self, + transport, + use_existing_dir=False, + create_prefix=False, + force_new_repo=False, + stacked_on=None, + stack_on_pwd=None, + repo_format_name=None, + make_working_trees=None, + shared_repo=False, + vfs_only=False, + ): if shared_repo: raise brz_errors.SharedRepositoriesUnsupported(self) def make_directory(transport): - transport.mkdir('.') + transport.mkdir(".") return transport def redirected(transport, e, redirection_notice): trace.note(redirection_notice) return transport._redirected_to(e.source, e.target) + try: - transport = do_catching_redirections( - make_directory, transport, redirected) + transport = do_catching_redirections(make_directory, transport, redirected) except FileExists: if not use_existing_dir: raise @@ -469,8 +545,7 @@ def redirected(transport, e, redirection_notice): else: result_repo = None repository_policy = None - return (result_repo, controldir, False, - repository_policy) + return (result_repo, controldir, False, repository_policy) def is_supported(self): return True @@ -483,13 +558,14 @@ def supports_transport(self, transport): return external_url.startswith("file:") def is_control_filename(self, filename): - return (filename == '.git' - or filename.startswith('.git/') - or filename.startswith('.git\\')) + return ( + filename == ".git" + or filename.startswith(".git/") + or filename.startswith(".git\\") + ) class BareLocalGitControlDirFormat(LocalGitControlDirFormat): - bare = True supports_workingtrees = False @@ -505,6 +581,7 @@ class LocalGitDir(GitDir): def _get_gitrepository_class(self): from .repository import LocalGitRepository + return LocalGitRepository def __repr__(self): @@ -528,7 +605,7 @@ def __init__(self, transport, gitrepo, format): if gitrepo.bare: self.transport = transport else: - self.transport = transport.clone('.git') + self.transport = transport.clone(".git") self._mode_check_done = None def _get_symref(self, ref): @@ -546,31 +623,33 @@ def set_branch_reference(self, target_branch, name=None): self._git.refs.set_symbolic_ref(ref, target_branch.ref) else: try: - target_path = ( - target_branch.controldir.control_transport.local_abspath( - '.')) + target_path = target_branch.controldir.control_transport.local_abspath( + "." + ) except brz_errors.NotLocalUrl as err: raise brz_errors.IncompatibleFormat( - target_branch._format, self._format) from err + target_branch._format, self._format + ) from err # TODO(jelmer): Do some consistency checking across branches.. - self.control_transport.put_bytes( - 'commondir', encode_git_path(target_path)) + self.control_transport.put_bytes("commondir", encode_git_path(target_path)) # TODO(jelmer): Urgh, avoid mucking about with internals. 
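The is_control_filename change above keeps the same rule in the new style: ".git" itself and anything under ".git/" (or ".git\" on Windows paths) is a control file, while look-alikes such as ".gitignore" remain ordinary versioned files. A standalone restatement of that predicate:

def is_control_filename(filename):
    # Mirrors LocalGitControlDirFormat.is_control_filename from the hunk above.
    return (
        filename == ".git"
        or filename.startswith(".git/")
        or filename.startswith(".git\\")
    )

assert is_control_filename(".git")
assert is_control_filename(".git/config")
assert is_control_filename(".git\\config")
assert not is_control_filename(".gitignore")
assert not is_control_filename("src/main.py")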
self._git._commontransport = ( - target_branch.repository._git._commontransport.clone()) + target_branch.repository._git._commontransport.clone() + ) self._git.object_store = TransportObjectStore( - self._git._commontransport.clone(OBJECTDIR)) + self._git._commontransport.clone(OBJECTDIR) + ) self._git.refs.transport = self._git._commontransport - target_ref_chain, unused_sha = ( - target_branch.controldir._git.refs.follow(target_branch.ref)) + target_ref_chain, unused_sha = target_branch.controldir._git.refs.follow( + target_branch.ref + ) for target_ref in target_ref_chain: - if target_ref == b'HEAD': + if target_ref == b"HEAD": continue break else: # Can't create a reference to something that is not a in a repository. - raise brz_errors.IncompatibleFormat( - self.set_branch_reference, self) + raise brz_errors.IncompatibleFormat(self.set_branch_reference, self) self._git.refs.set_symbolic_ref(ref, target_ref) def get_branch_reference(self, name=None): @@ -581,28 +660,33 @@ def get_branch_reference(self, name=None): raise BranchReferenceLoop(self) from err if target_ref is not None: from .refs import ref_to_branch_name + try: branch_name = ref_to_branch_name(target_ref) except ValueError: - params = {'ref': urlutils.quote( - target_ref.decode('utf-8'), '')} + params = {"ref": urlutils.quote(target_ref.decode("utf-8"), "")} else: - if branch_name != '': - params = {'branch': urlutils.quote(branch_name, '')} + if branch_name != "": + params = {"branch": urlutils.quote(branch_name, "")} else: params = {} try: - commondir = self.control_transport.get_bytes('commondir') + commondir = self.control_transport.get_bytes("commondir") except NoSuchFile: - base_url = self.user_url.rstrip('/') + base_url = self.user_url.rstrip("/") else: - base_url = urlutils.local_path_to_url( # noqa: B005 - decode_git_path(commondir)).rstrip('/.git/') + '/' + base_url = ( + urlutils.local_path_to_url( # noqa: B005 + decode_git_path(commondir) + ).rstrip("/.git/") + + "/" + ) return urlutils.join_segment_parameters(base_url, params) return None def find_branch_format(self, name=None): from .branch import LocalGitBranchFormat + return LocalGitBranchFormat() def get_branch_transport(self, branch_format, name=None): @@ -626,20 +710,27 @@ def get_workingtree_transport(self, format): return self.transport raise brz_errors.IncompatibleFormat(format, self._format) - def open_branch(self, name=None, unsupported=False, ignore_fallbacks=None, - ref=None, possible_transports=None, nascent_ok=False): + def open_branch( + self, + name=None, + unsupported=False, + ignore_fallbacks=None, + ref=None, + possible_transports=None, + nascent_ok=False, + ): """'create' a branch for this dir.""" # noqa: D403 repo = self.find_repository() from .branch import LocalGitBranch + ref = self._get_selected_ref(name, ref) if not nascent_ok and ref not in self._git.refs: - raise brz_errors.NotBranchError( - self.root_transport.base, controldir=self) + raise brz_errors.NotBranchError(self.root_transport.base, controldir=self) try: ref_chain, unused_sha = self._git.refs.follow(ref) except SymrefLoop as err: raise BranchReferenceLoop(self) from err - if ref_chain[-1] == b'HEAD': + if ref_chain[-1] == b"HEAD": controldir = self else: controldir = self._find_commondir() @@ -647,15 +738,15 @@ def open_branch(self, name=None, unsupported=False, ignore_fallbacks=None, def destroy_branch(self, name=None): refname = self._get_selected_ref(name) - if refname == b'HEAD': + if refname == b"HEAD": # HEAD can't be removed - raise brz_errors.UnsupportedOperation( 
- self.destroy_branch, self) + raise brz_errors.UnsupportedOperation(self.destroy_branch, self) try: del self._git.refs[refname] except KeyError as err: raise brz_errors.NotBranchError( - self.root_transport.base, controldir=self) from err + self.root_transport.base, controldir=self + ) from err def destroy_repository(self): raise brz_errors.UnsupportedOperation(self.destroy_repository, self) @@ -664,15 +755,14 @@ def destroy_workingtree(self): raise brz_errors.UnsupportedOperation(self.destroy_workingtree, self) def destroy_workingtree_metadata(self): - raise brz_errors.UnsupportedOperation( - self.destroy_workingtree_metadata, self) + raise brz_errors.UnsupportedOperation(self.destroy_workingtree_metadata, self) def needs_format_conversion(self, format=None): return not isinstance(self._format, format.__class__) def open_repository(self): """'open' a repository for this dir.""" # noqa: D403 - if self.control_transport.has('commondir'): + if self.control_transport.has("commondir"): raise brz_errors.NoRepositoryPresent(self) return self._gitrepository_class(self) @@ -683,30 +773,32 @@ def open_workingtree(self, recommend_upgrade=True, unsupported=False): if not self._git.bare: repo = self.find_repository() from .workingtree import GitWorkingTree - branch = self.open_branch(ref=b'HEAD', nascent_ok=True) + + branch = self.open_branch(ref=b"HEAD", nascent_ok=True) return GitWorkingTree(self, repo, branch) - loc = urlutils.unescape_for_display(self.root_transport.base, 'ascii') + loc = urlutils.unescape_for_display(self.root_transport.base, "ascii") raise brz_errors.NoWorkingTree(loc) def create_repository(self, shared=False): from .repository import GitRepositoryFormat + if shared: - raise brz_errors.IncompatibleFormat( - GitRepositoryFormat(), self._format) + raise brz_errors.IncompatibleFormat(GitRepositoryFormat(), self._format) return self.find_repository() - def create_branch(self, name=None, repository=None, - append_revisions_only=None, ref=None): + def create_branch( + self, name=None, repository=None, append_revisions_only=None, ref=None + ): refname = self._get_selected_ref(name, ref) - if refname != b'HEAD' and refname in self._git.refs: + if refname != b"HEAD" and refname in self._git.refs: raise brz_errors.AlreadyBranchError(self.user_url) repo = self.open_repository() if refname in self._git.refs: - ref_chain, unused_sha = self._git.refs.follow( - self._get_selected_ref(None)) - if ref_chain[0] == b'HEAD': + ref_chain, unused_sha = self._git.refs.follow(self._get_selected_ref(None)) + if ref_chain[0] == b"HEAD": refname = ref_chain[1] from .branch import LocalGitBranch + branch = LocalGitBranch(self, repo, refname) if append_revisions_only: branch.set_append_revisions_only(append_revisions_only) @@ -715,24 +807,27 @@ def create_branch(self, name=None, repository=None, def backup_bzrdir(self): if not self._git.bare: self.root_transport.copy_tree(".git", ".git.backup") - return (self.root_transport.abspath(".git"), - self.root_transport.abspath(".git.backup")) + return ( + self.root_transport.abspath(".git"), + self.root_transport.abspath(".git.backup"), + ) else: basename = urlutils.basename(self.root_transport.base) - parent = self.root_transport.clone('..') + parent = self.root_transport.clone("..") parent.copy_tree(basename, basename + ".backup") - def create_workingtree(self, revision_id=None, from_branch=None, - accelerator_tree=None, hardlink=False): + def create_workingtree( + self, revision_id=None, from_branch=None, accelerator_tree=None, hardlink=False + ): if 
self._git.bare: - raise brz_errors.UnsupportedOperation( - self.create_workingtree, self) + raise brz_errors.UnsupportedOperation(self.create_workingtree, self) if from_branch is None: from_branch = self.open_branch(nascent_ok=True) if revision_id is None: revision_id = from_branch.last_revision() repo = self.find_repository() from .workingtree import GitWorkingTree + wt = GitWorkingTree(self, repo, from_branch) wt.set_last_revision(revision_id) wt._build_checkout_with_index() @@ -753,7 +848,7 @@ def _find_creation_modes(self): return self._mode_check_done = True try: - st = self.transport.stat('.') + st = self.transport.stat(".") except brz_errors.TransportNotPossible: self._dir_mode = None self._file_mode = None @@ -762,7 +857,7 @@ def _find_creation_modes(self): # directories and files are read-write for this user. This is # mostly a workaround for filesystems which lie about being able to # write to a directory (cygwin & win32) - if (st.st_mode & 0o7777 == 0o0000): + if st.st_mode & 0o7777 == 0o0000: # FTP allows stat but does not return dir/file modes self._dir_mode = None self._file_mode = None @@ -791,10 +886,9 @@ def get_peeled(self, ref): def _find_commondir(self): try: - commondir = self.control_transport.get_bytes('commondir') + commondir = self.control_transport.get_bytes("commondir") except NoSuchFile: return self else: - commondir = os.fsdecode(commondir.rstrip(b'/.git/')) - return ControlDir.open_from_transport( - get_transport_from_path(commondir)) + commondir = os.fsdecode(commondir.rstrip(b"/.git/")) + return ControlDir.open_from_transport(get_transport_from_path(commondir)) diff --git a/breezy/git/directory.py b/breezy/git/directory.py index efe064199a..41b201b189 100644 --- a/breezy/git/directory.py +++ b/breezy/git/directory.py @@ -19,11 +19,10 @@ from .. import transport -transport.register_urlparse_netloc_protocol('github') +transport.register_urlparse_netloc_protocol("github") class GitHubDirectory: - def look_up(self, name, url, purpose=None): """See DirectoryService.look_up.""" return "git+ssh://git@github.com/" + name diff --git a/breezy/git/errors.py b/breezy/git/errors.py index b97678d1a9..cf41d81448 100644 --- a/breezy/git/errors.py +++ b/breezy/git/errors.py @@ -26,8 +26,10 @@ class BzrGitError(brz_errors.BzrError): class NoPushSupport(brz_errors.BzrError): - _fmt = ("Push is not yet supported from %(source)r to %(target)r " - "using %(mapping)r for %(revision_id)r. Try dpush instead.") + _fmt = ( + "Push is not yet supported from %(source)r to %(target)r " + "using %(mapping)r for %(revision_id)r. Try dpush instead." + ) def __init__(self, source, target, mapping, revision_id=None): self.source = source diff --git a/breezy/git/fetch.py b/breezy/git/fetch.py index df15f213ad..d3fbb291d5 100644 --- a/breezy/git/fetch.py +++ b/breezy/git/fetch.py @@ -48,10 +48,21 @@ from .object_store import LRUTreeCache, _tree_to_objects -def import_git_blob(texts, mapping, path, name, hexshas, - base_bzr_tree, parent_id, revision_id, - parent_bzr_trees, lookup_object, modes, store_updater, - lookup_file_id): +def import_git_blob( + texts, + mapping, + path, + name, + hexshas, + base_bzr_tree, + parent_id, + revision_id, + parent_bzr_trees, + lookup_object, + modes, + store_updater, + lookup_file_id, +): """Import a git blob object into a bzr repository. 
:param texts: VersionedFiles to add to @@ -101,16 +112,23 @@ def import_git_blob(texts, mapping, path, name, hexshas, for ptree in parent_bzr_trees: intertree = InterTree.get(ptree, base_bzr_tree) try: - ppath = intertree.find_source_paths(decoded_path, recurse='none') + ppath = intertree.find_source_paths(decoded_path, recurse="none") except NoSuchFile: continue if ppath is None: continue pkind = ptree.kind(ppath) - if (pkind == ie.kind and - ((pkind == "symlink" and ptree.get_symlink_target(ppath) == ie.symlink_target) or - (pkind == "file" and ptree.get_file_sha1(ppath) == ie.text_sha1 and - ptree.is_executable(ppath) == ie.executable))): + if pkind == ie.kind and ( + ( + pkind == "symlink" + and ptree.get_symlink_target(ppath) == ie.symlink_target + ) + or ( + pkind == "file" + and ptree.get_file_sha1(ppath) == ie.text_sha1 + and ptree.is_executable(ppath) == ie.executable + ) + ): # found a revision in one of the parents to use ie.revision = ptree.get_file_revision(ppath) break @@ -122,20 +140,33 @@ def import_git_blob(texts, mapping, path, name, hexshas, ie.revision = revision_id if ie.revision is None: raise ValueError("no file revision set") - if ie.kind == 'symlink': + if ie.kind == "symlink": chunks = [] else: chunks = blob.chunked - texts.insert_record_stream([ - ChunkedContentFactory((file_id, ie.revision), - tuple(parent_keys), getattr(ie, 'text_sha1', None), chunks)]) + texts.insert_record_stream( + [ + ChunkedContentFactory( + (file_id, ie.revision), + tuple(parent_keys), + getattr(ie, "text_sha1", None), + chunks, + ) + ] + ) invdelta = [] if base_hexsha is not None: old_path = decoded_path # Renames are not supported yet if stat.S_ISDIR(base_mode): - invdelta.extend(remove_disappeared_children( - base_bzr_tree, old_path, lookup_object(base_hexsha), [], - lookup_object)) + invdelta.extend( + remove_disappeared_children( + base_bzr_tree, + old_path, + lookup_object(base_hexsha), + [], + lookup_object, + ) + ) else: old_path = None invdelta.append((old_path, decoded_path, file_id, ie)) @@ -145,15 +176,28 @@ def import_git_blob(texts, mapping, path, name, hexshas, class SubmodulesRequireSubtrees(BzrError): - _fmt = ("The repository you are fetching from contains submodules, " - "which require a Bazaar format that supports tree references.") + _fmt = ( + "The repository you are fetching from contains submodules, " + "which require a Bazaar format that supports tree references." 
+ ) internal = False -def import_git_submodule(texts, mapping, path, name, hexshas, - base_bzr_tree, parent_id, revision_id, - parent_bzr_trees, lookup_object, - modes, store_updater, lookup_file_id): +def import_git_submodule( + texts, + mapping, + path, + name, + hexshas, + base_bzr_tree, + parent_id, + revision_id, + parent_bzr_trees, + lookup_object, + modes, + store_updater, + lookup_file_id, +): """Import a git submodule.""" (base_hexsha, hexsha) = hexshas (base_mode, mode) = modes @@ -167,20 +211,28 @@ def import_git_submodule(texts, mapping, path, name, hexshas, if base_hexsha is not None: old_path = path # Renames are not supported yet if stat.S_ISDIR(base_mode): - invdelta.extend(remove_disappeared_children( - base_bzr_tree, old_path, lookup_object(base_hexsha), [], - lookup_object)) + invdelta.extend( + remove_disappeared_children( + base_bzr_tree, + old_path, + lookup_object(base_hexsha), + [], + lookup_object, + ) + ) else: old_path = None ie.reference_revision = mapping.revision_id_foreign_to_bzr(hexsha) - texts.insert_record_stream([ - ChunkedContentFactory((file_id, ie.revision), (), None, [])]) + texts.insert_record_stream( + [ChunkedContentFactory((file_id, ie.revision), (), None, [])] + ) invdelta.append((old_path, path, file_id, ie)) return invdelta, {} -def remove_disappeared_children(base_bzr_tree, path, base_tree, - existing_children, lookup_object): +def remove_disappeared_children( + base_bzr_tree, path, base_tree, existing_children, lookup_object +): """Generate an inventory delta for removed children. :param base_bzr_tree: Base bzr tree against which to generate the @@ -203,16 +255,30 @@ def remove_disappeared_children(base_bzr_tree, path, base_tree, raise TypeError(file_id) ret.append((c_path, None, file_id, None)) if stat.S_ISDIR(mode): - ret.extend(remove_disappeared_children( - base_bzr_tree, c_path, lookup_object(hexsha), [], - lookup_object)) + ret.extend( + remove_disappeared_children( + base_bzr_tree, c_path, lookup_object(hexsha), [], lookup_object + ) + ) return ret -def import_git_tree(texts, mapping, path, name, hexshas, - base_bzr_tree, parent_id, revision_id, parent_bzr_trees, - lookup_object, modes, store_updater, - lookup_file_id, allow_submodules=False): +def import_git_tree( + texts, + mapping, + path, + name, + hexshas, + base_bzr_tree, + parent_id, + revision_id, + parent_bzr_trees, + lookup_object, + modes, + store_updater, + lookup_file_id, + allow_submodules=False, +): """Import a git tree object into a bzr repository. 
:param texts: VersionedFiles object to add to @@ -246,8 +312,9 @@ def import_git_tree(texts, mapping, path, name, hexshas, if base_tree is None or type(base_tree) is not Tree: ie.revision = revision_id invdelta.append((old_path, new_path, ie.file_id, ie)) - texts.insert_record_stream([ - ChunkedContentFactory((ie.file_id, ie.revision), (), None, [])]) + texts.insert_record_stream( + [ChunkedContentFactory((ie.file_id, ie.revision), (), None, [])] + ) # Remember for next time existing_children = set() child_modes = {} @@ -265,101 +332,153 @@ def import_git_tree(texts, mapping, path, name, hexshas, child_base_mode = 0 if stat.S_ISDIR(child_mode): subinvdelta, grandchildmodes = import_git_tree( - texts, mapping, child_path, name, - (child_base_hexsha, child_hexsha), base_bzr_tree, file_id, - revision_id, parent_bzr_trees, lookup_object, - (child_base_mode, child_mode), store_updater, lookup_file_id, - allow_submodules=allow_submodules) + texts, + mapping, + child_path, + name, + (child_base_hexsha, child_hexsha), + base_bzr_tree, + file_id, + revision_id, + parent_bzr_trees, + lookup_object, + (child_base_mode, child_mode), + store_updater, + lookup_file_id, + allow_submodules=allow_submodules, + ) elif S_ISGITLINK(child_mode): # submodule if not allow_submodules: raise SubmodulesRequireSubtrees() subinvdelta, grandchildmodes = import_git_submodule( - texts, mapping, child_path, name, + texts, + mapping, + child_path, + name, (child_base_hexsha, child_hexsha), - base_bzr_tree, file_id, revision_id, parent_bzr_trees, - lookup_object, (child_base_mode, child_mode), store_updater, - lookup_file_id) + base_bzr_tree, + file_id, + revision_id, + parent_bzr_trees, + lookup_object, + (child_base_mode, child_mode), + store_updater, + lookup_file_id, + ) else: if not mapping.is_special_file(name): subinvdelta = import_git_blob( - texts, mapping, child_path, name, - (child_base_hexsha, child_hexsha), base_bzr_tree, file_id, - revision_id, parent_bzr_trees, lookup_object, - (child_base_mode, child_mode), store_updater, - lookup_file_id) + texts, + mapping, + child_path, + name, + (child_base_hexsha, child_hexsha), + base_bzr_tree, + file_id, + revision_id, + parent_bzr_trees, + lookup_object, + (child_base_mode, child_mode), + store_updater, + lookup_file_id, + ) else: subinvdelta = [] grandchildmodes = {} child_modes.update(grandchildmodes) invdelta.extend(subinvdelta) - if child_mode not in (stat.S_IFDIR, DEFAULT_FILE_MODE, - stat.S_IFLNK, DEFAULT_FILE_MODE | 0o111, - S_IFGITLINK): + if child_mode not in ( + stat.S_IFDIR, + DEFAULT_FILE_MODE, + stat.S_IFLNK, + DEFAULT_FILE_MODE | 0o111, + S_IFGITLINK, + ): child_modes[child_path] = child_mode # Remove any children that have disappeared if base_tree is not None and type(base_tree) is Tree: - invdelta.extend(remove_disappeared_children( - base_bzr_tree, old_path, base_tree, existing_children, - lookup_object)) + invdelta.extend( + remove_disappeared_children( + base_bzr_tree, old_path, base_tree, existing_children, lookup_object + ) + ) store_updater.add_object(tree, (file_id, revision_id), path) return invdelta, child_modes -def verify_commit_reconstruction(target_git_object_retriever, lookup_object, - o, rev, ret_tree, parent_trees, mapping, - unusual_modes, verifiers): +def verify_commit_reconstruction( + target_git_object_retriever, + lookup_object, + o, + rev, + ret_tree, + parent_trees, + mapping, + unusual_modes, + verifiers, +): new_unusual_modes = mapping.export_unusual_file_modes(rev) if new_unusual_modes != unusual_modes: - raise 
AssertionError("unusual modes don't match: {!r} != {!r}".format( - unusual_modes, new_unusual_modes)) + raise AssertionError( + f"unusual modes don't match: {unusual_modes!r} != {new_unusual_modes!r}" + ) # Verify that we can reconstruct the commit properly - rec_o = target_git_object_retriever._reconstruct_commit(rev, o.tree, True, - verifiers) + rec_o = target_git_object_retriever._reconstruct_commit( + rev, o.tree, True, verifiers + ) if rec_o != o: raise AssertionError(f"Reconstructed commit differs: {rec_o!r} != {o!r}") diff = [] new_objs = {} for path, obj, _ie in _tree_to_objects( - ret_tree, parent_trees, target_git_object_retriever._cache.idmap, - unusual_modes, mapping.BZR_DUMMY_FILE): + ret_tree, + parent_trees, + target_git_object_retriever._cache.idmap, + unusual_modes, + mapping.BZR_DUMMY_FILE, + ): old_obj_id = tree_lookup_path(lookup_object, o.tree, path)[1] new_objs[path] = obj if obj.id != old_obj_id: diff.append((path, lookup_object(old_obj_id), obj)) - for (path, old_obj, new_obj) in diff: - while (old_obj.type_name == "tree" - and new_obj.type_name == "tree" - and sorted(old_obj) == sorted(new_obj)): + for path, old_obj, new_obj in diff: + while ( + old_obj.type_name == "tree" + and new_obj.type_name == "tree" + and sorted(old_obj) == sorted(new_obj) + ): for name in old_obj: if old_obj[name][0] != new_obj[name][0]: raise AssertionError( - f"Modes for {path} differ: {old_obj[name][0]:o} != {new_obj[name][0]:o}") + f"Modes for {path} differ: {old_obj[name][0]:o} != {new_obj[name][0]:o}" + ) if old_obj[name][1] != new_obj[name][1]: # Found a differing child, delve deeper path = posixpath.join(path, name) old_obj = lookup_object(old_obj[name][1]) new_obj = new_objs[path] break - raise AssertionError( - f"objects differ for {path}: {old_obj!r} != {new_obj!r}") + raise AssertionError(f"objects differ for {path}: {old_obj!r} != {new_obj!r}") def ensure_inventories_in_repo(repo, trees): real_inv_vf = repo.inventories.without_fallbacks() for t in trees: revid = t.get_revision_id() - if not real_inv_vf.get_parent_map([(revid, )]): + if not real_inv_vf.get_parent_map([(revid,)]): repo.add_inventory(revid, t.root_inventory, t.get_parent_ids()) -def import_git_commit(repo, mapping, head, lookup_object, - target_git_object_retriever, trees_cache, strict): +def import_git_commit( + repo, mapping, head, lookup_object, target_git_object_retriever, trees_cache, strict +): o = lookup_object(head) # Note that this uses mapping.revision_id_foreign_to_bzr. If the parents # were bzr roundtripped revisions they would be specified in the # roundtrip data. 
rev, roundtrip_revid, verifiers = mapping.import_commit( - o, mapping.revision_id_foreign_to_bzr, strict) + o, mapping.revision_id_foreign_to_bzr, strict + ) if roundtrip_revid is not None: original_revid = rev.revision_id rev.revision_id = roundtrip_revid @@ -378,11 +497,21 @@ def import_git_commit(repo, mapping, head, lookup_object, base_mode = stat.S_IFDIR store_updater = target_git_object_retriever._get_updater(rev) inv_delta, unusual_modes = import_git_tree( - repo.texts, mapping, b"", b"", (base_tree, o.tree), base_bzr_tree, - None, rev.revision_id, parent_trees, lookup_object, - (base_mode, stat.S_IFDIR), store_updater, + repo.texts, + mapping, + b"", + b"", + (base_tree, o.tree), + base_bzr_tree, + None, + rev.revision_id, + parent_trees, + lookup_object, + (base_mode, stat.S_IFDIR), + store_updater, mapping.generate_file_id, - allow_submodules=repo._format.supports_tree_reference) + allow_submodules=repo._format.supports_tree_reference, + ) if unusual_modes != {}: for path, mode in unusual_modes.iteritems(): warn_unusual_mode(rev.foreign_revid, path, mode) @@ -396,21 +525,24 @@ def import_git_commit(repo, mapping, head, lookup_object, base_bzr_inventory = base_bzr_tree.root_inventory inv_delta = InventoryDelta(inv_delta) rev.inventory_sha1, inv = repo.add_inventory_by_delta( - basis_id, inv_delta, rev.revision_id, rev.parent_ids, - base_bzr_inventory) + basis_id, inv_delta, rev.revision_id, rev.parent_ids, base_bzr_inventory + ) ret_tree = InventoryRevisionTree(repo, inv, rev.revision_id) # Check verifiers if verifiers and roundtrip_revid is not None: testament = StrictTestament3(rev, ret_tree) calculated_verifiers = {"testament3-sha1": testament.as_sha1()} if calculated_verifiers != verifiers: - trace.mutter("Testament SHA1 %r for %r did not match %r.", - calculated_verifiers["testament3-sha1"], - rev.revision_id, verifiers["testament3-sha1"]) + trace.mutter( + "Testament SHA1 %r for %r did not match %r.", + calculated_verifiers["testament3-sha1"], + rev.revision_id, + verifiers["testament3-sha1"], + ) rev.revision_id = original_revid rev.inventory_sha1, inv = repo.add_inventory_by_delta( - basis_id, inv_delta, rev.revision_id, rev.parent_ids, - base_bzr_tree) + basis_id, inv_delta, rev.revision_id, rev.parent_ids, base_bzr_tree + ) ret_tree = InventoryRevisionTree(repo, inv, rev.revision_id) else: calculated_verifiers = {} @@ -418,15 +550,23 @@ def import_git_commit(repo, mapping, head, lookup_object, store_updater.finish() trees_cache.add(ret_tree) repo.add_revision(rev.revision_id, rev) - if debug.debug_flag_enabled('verify'): + if debug.debug_flag_enabled("verify"): verify_commit_reconstruction( - target_git_object_retriever, lookup_object, o, rev, ret_tree, - parent_trees, mapping, unusual_modes, verifiers) + target_git_object_retriever, + lookup_object, + o, + rev, + ret_tree, + parent_trees, + mapping, + unusual_modes, + verifiers, + ) -def import_git_objects(repo, mapping, object_iter, - target_git_object_retriever, heads, pb=None, - limit=None): +def import_git_objects( + repo, mapping, object_iter, target_git_object_retriever, heads, pb=None, limit=None +): """Import a set of git objects into a bzr repository. :param repo: Target Bazaar repository @@ -434,11 +574,13 @@ def import_git_objects(repo, mapping, object_iter, :param object_iter: Iterator over Git objects. 
:return: Tuple with pack hints and last imported revision id """ + def lookup_object(sha): try: return object_iter[sha] except KeyError: return target_git_object_retriever[sha] + graph = [] checked = set() heads = list(set(heads)) @@ -458,10 +600,11 @@ def lookup_object(sha): continue if isinstance(o, Commit): rev, roundtrip_revid, verifiers = mapping.import_commit( - o, mapping.revision_id_foreign_to_bzr, strict=True) - if (repo.has_revision(rev.revision_id) - or (roundtrip_revid and - repo.has_revision(roundtrip_revid))): + o, mapping.revision_id_foreign_to_bzr, strict=True + ) + if repo.has_revision(rev.revision_id) or ( + roundtrip_revid and repo.has_revision(roundtrip_revid) + ): continue graph.append((o.id, o.parents)) heads.extend([p for p in o.parents if p not in checked]) @@ -485,14 +628,18 @@ def lookup_object(sha): try: repo.start_write_group() try: - for i, head in enumerate( - revision_ids[offset:offset + batch_size]): + for i, head in enumerate(revision_ids[offset : offset + batch_size]): if pb is not None: - pb.update("fetching revisions", offset + i, - len(revision_ids)) - import_git_commit(repo, mapping, head, lookup_object, - target_git_object_retriever, trees_cache, - strict=True) + pb.update("fetching revisions", offset + i, len(revision_ids)) + import_git_commit( + repo, + mapping, + head, + lookup_object, + target_git_object_retriever, + trees_cache, + strict=True, + ) last_imported = head except BaseException: repo.abort_write_group() @@ -510,7 +657,6 @@ def lookup_object(sha): class DetermineWantsRecorder: - def __init__(self, actual): self.actual = actual self.wants = [] diff --git a/breezy/git/filegraph.py b/breezy/git/filegraph.py index dd71b4e007..11eabff5a5 100644 --- a/breezy/git/filegraph.py +++ b/breezy/git/filegraph.py @@ -29,7 +29,6 @@ class GitFileLastChangeScanner: - def __init__(self, repository): self.repository = repository self.store = self.repository._git.object_store @@ -42,7 +41,8 @@ def find_last_change_revision(self, path, commit_id): commit = store[commit_id] try: target_mode, target_sha = tree_lookup_path( - store.__getitem__, commit.tree, path) + store.__getitem__, commit.tree, path + ) except SubmoduleEncountered as e: revid = self.repository.lookup_foreign_revision_id(commit_id) revtree = self.repository.revision_tree(revid) @@ -51,7 +51,7 @@ def find_last_change_revision(self, path, commit_id): path = posixpath.relpath(path, e.path) else: break - if path == b'': + if path == b"": target_mode = stat.S_IFDIR if target_mode is None: raise AssertionError(f"sha {target_sha!r} for {path!r} in {commit_id!r}") @@ -60,18 +60,20 @@ def find_last_change_revision(self, path, commit_id): for parent_id in commit.parents: try: parent_commit = store[parent_id] - mode, sha = tree_lookup_path(store.__getitem__, - parent_commit.tree, path) + mode, sha = tree_lookup_path( + store.__getitem__, parent_commit.tree, path + ) except (KeyError, NotTreeError): continue else: parent_commits.append(parent_commit) - if path == b'': + if path == b"": mode = stat.S_IFDIR # Candidate found iff, mode or text changed, # or is a directory that didn't previously exist. 
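find_last_change_revision below asks dulwich for the (mode, sha) of a path in each parent commit's tree and stops once they differ, the same tree_lookup_path call used in verify_commit_reconstruction above. A self-contained sketch of that lookup on an in-memory object store (file names and contents are invented):

from dulwich.object_store import MemoryObjectStore, tree_lookup_path
from dulwich.objects import Blob, Tree

store = MemoryObjectStore()
blob = Blob.from_string(b"hello\n")
subtree = Tree()
subtree.add(b"file.txt", 0o100644, blob.id)
root = Tree()
root.add(b"dir", 0o040000, subtree.id)
for obj in (blob, subtree, root):
    store.add_object(obj)

mode, sha = tree_lookup_path(store.__getitem__, root.id, b"dir/file.txt")
assert (mode, sha) == (0o100644, blob.id)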
if mode != target_mode or ( - not stat.S_ISDIR(target_mode) and sha != target_sha): + not stat.S_ISDIR(target_mode) and sha != target_sha + ): return (store, path, commit.id) if parent_commits == []: break @@ -80,15 +82,14 @@ def find_last_change_revision(self, path, commit_id): class GitFileParentProvider: - def __init__(self, change_scanner): self.change_scanner = change_scanner self.store = self.change_scanner.repository._git.object_store def _get_parents(self, file_id, text_revision): - commit_id, mapping = ( - self.change_scanner.repository.lookup_bzr_revision_id( - text_revision)) + commit_id, mapping = self.change_scanner.repository.lookup_bzr_revision_id( + text_revision + ) try: path = encode_git_path(mapping.parse_file_id(file_id)) except ValueError as err: @@ -96,17 +97,21 @@ def _get_parents(self, file_id, text_revision): text_parents = [] for commit_parent in self.store[commit_id].parents: try: - (store, path, text_parent) = ( - self.change_scanner.find_last_change_revision( - path, commit_parent)) + ( + store, + path, + text_parent, + ) = self.change_scanner.find_last_change_revision(path, commit_parent) except KeyError: continue if text_parent not in text_parents: text_parents.append(text_parent) - return tuple([ - (file_id, - self.change_scanner.repository.lookup_foreign_revision_id(p)) - for p in text_parents]) + return tuple( + [ + (file_id, self.change_scanner.repository.lookup_foreign_revision_id(p)) + for p in text_parents + ] + ) def get_parent_map(self, keys): ret = {} diff --git a/breezy/git/git_remote_helper.py b/breezy/git/git_remote_helper.py index 886ca12fed..1d374209ee 100644 --- a/breezy/git/git_remote_helper.py +++ b/breezy/git/git_remote_helper.py @@ -55,10 +55,9 @@ def fetch(outf, wants, shortname, remote_dir, local_dir): local_repo = local_dir.find_repository() inter = InterRepository.get(remote_repo, local_repo) revs = [] - for (sha1, _ref) in wants: + for sha1, _ref in wants: revs.append((sha1, None)) - if (isinstance(remote_repo, GitRepository) and - isinstance(local_repo, GitRepository)): + if isinstance(remote_repo, GitRepository) and isinstance(local_repo, GitRepository): lossy = False else: lossy = True @@ -67,7 +66,7 @@ def fetch(outf, wants, shortname, remote_dir, local_dir): def push(outf, wants, shortname, remote_dir, local_dir): - for (src_ref, dest_ref) in wants: + for src_ref, dest_ref in wants: local_branch = local_dir.open_branch(ref=src_ref) dest_branch_name = ref_to_branch_name(dest_ref) if dest_branch_name == "master": @@ -125,15 +124,23 @@ def cmd_push(self, outf, argv): def cmd_import(self, outf, argv): if "fastimport" in CAPABILITIES: raise Exception("install fastimport for 'import' command support") - ref = argv[1].encode('utf-8') + ref = argv[1].encode("utf-8") dest_branch_name = ref_to_branch_name(ref) if dest_branch_name == "master": dest_branch_name = None remote_branch = self.remote_dir.open_branch(name=dest_branch_name) exporter = fastexporter.BzrFastExporter( - remote_branch, outf=outf, ref=ref, checkpoint=None, - import_marks_file=None, export_marks_file=None, revision=None, - verbose=None, plain_format=True, rewrite_tags=False) + remote_branch, + outf=outf, + ref=ref, + checkpoint=None, + import_marks_file=None, + export_marks_file=None, + revision=None, + verbose=None, + plain_format=True, + rewrite_tags=False, + ) exporter.run() commands = { @@ -143,7 +150,7 @@ def cmd_import(self, outf, argv): "fetch": cmd_fetch, "push": cmd_push, "import": cmd_import, - } + } def process(self, inf, outf): while True: @@ -156,11 +163,9 @@ 
def process_line(self, l, outf): argv = l.strip().split() if argv == []: if self.batchcmd == "fetch": - fetch(outf, self.wants, self.shortname, - self.remote_dir, self.local_dir) + fetch(outf, self.wants, self.shortname, self.remote_dir, self.local_dir) elif self.batchcmd == "push": - push(outf, self.wants, self.shortname, - self.remote_dir, self.local_dir) + push(outf, self.wants, self.shortname, self.remote_dir, self.local_dir) elif self.batchcmd is None: return else: diff --git a/breezy/git/hg.py b/breezy/git/hg.py index f9e973c608..9496a6d127 100644 --- a/breezy/git/hg.py +++ b/breezy/git/hg.py @@ -28,8 +28,8 @@ def format_hg_metadata(renames, branch, extra): :param extra: Dictionary with extra data :return: Tail for commit message """ - extra_message = '' - if branch != 'default': + extra_message = "" + if branch != "default": extra_message += "branch : " + branch + "\n" if renames: @@ -37,12 +37,10 @@ def format_hg_metadata(renames, branch, extra): extra_message += "rename : " + oldfile + " => " + newfile + "\n" for key, value in extra.iteritems(): - if key in ('author', 'committer', 'encoding', 'message', 'branch', - 'hg-git'): + if key in ("author", "committer", "encoding", "message", "branch", "hg-git"): continue else: - extra_message += "extra : " + key + \ - " : " + urllib.parse.quote(value) + "\n" + extra_message += "extra : " + key + " : " + urllib.parse.quote(value) + "\n" if extra_message: return "\n--HG--\n" + extra_message @@ -65,15 +63,15 @@ def extract_hg_metadata(message): message, meta = split lines = meta.split("\n") for line in lines: - if line == '': + if line == "": continue command, data = line.split(" : ", 1) - if command == 'rename': + if command == "rename": before, after = data.split(" => ", 1) renames[after] = before - elif command == 'branch': + elif command == "branch": branch = data - elif command == 'extra': + elif command == "extra": before, after = data.split(" : ", 1) extra[before] = urllib.parse.unquote(after) else: diff --git a/breezy/git/interrepo.py b/breezy/git/interrepo.py index aba8480ba7..b88b056698 100644 --- a/breezy/git/interrepo.py +++ b/breezy/git/interrepo.py @@ -78,8 +78,11 @@ def copy_content(self, revision_id=None): self.fetch(revision_id=revision_id, find_ghosts=False) def fetch_refs( - self, update_refs: Callable[[Dict[bytes, ObjectID]], Dict[bytes, ObjectID]], - lossy: bool, overwrite: bool = False) -> Tuple[RevidMap, Dict[bytes, ObjectID]]: + self, + update_refs: Callable[[Dict[bytes, ObjectID]], Dict[bytes, ObjectID]], + lossy: bool, + overwrite: bool = False, + ) -> Tuple[RevidMap, Dict[bytes, ObjectID]]: """Fetch possibly roundtripped revisions into the target repository and update refs. 
@@ -91,9 +94,9 @@ def fetch_refs( """ raise NotImplementedError(self.fetch_refs) - def search_missing_revision_ids(self, - find_ghosts=True, revision_ids=None, - if_present_ids=None, limit=None): + def search_missing_revision_ids( + self, find_ghosts=True, revision_ids=None, if_present_ids=None, limit=None + ): if limit is not None: raise FetchLimitUnsupported(self) git_shas = [] @@ -115,21 +118,28 @@ def search_missing_revision_ids(self, self.source_store, include=git_shas, exclude=[ - sha for sha in self.target.controldir.get_refs_container().as_dict().values() - if sha != ZERO_SHA]) + sha + for sha in self.target.controldir.get_refs_container() + .as_dict() + .values() + if sha != ZERO_SHA + ], + ) missing_revids = set() for entry in walker: - for (kind, type_data) in self.source_store.lookup_git_sha( - entry.commit.id): + for kind, type_data in self.source_store.lookup_git_sha( + entry.commit.id + ): if kind == "commit": missing_revids.add(type_data[0]) return self.source.revision_ids_to_search_result(missing_revids) def _warn_slow(self): - if not config.GlobalConfig().suppress_warning('slow_intervcs_push'): + if not config.GlobalConfig().suppress_warning("slow_intervcs_push"): trace.warning( - 'Pushing from a Bazaar to a Git repository. ' - 'For better performance, push into a Bazaar repository.') + "Pushing from a Bazaar to a Git repository. " + "For better performance, push into a Bazaar repository." + ) class InterToLocalGitRepository(InterToGitRepository): @@ -144,7 +154,7 @@ def __init__(self, source, target): def _commit_needs_fetching(self, sha_id): try: - return (sha_id not in self.target_store) + return sha_id not in self.target_store except NoSuchRevision: # Ghost, can't push return False @@ -170,13 +180,17 @@ def missing_revisions(self, stop_revisions): """ revid_sha_map = {} stop_revids = [] - for (sha1, revid) in stop_revisions: + for sha1, revid in stop_revisions: if sha1 is not None and revid is not None: revid_sha_map[revid] = sha1 stop_revids.append(revid) elif sha1 is not None: if self._commit_needs_fetching(sha1): - for (_kind, (revid, _tree_sha, _verifiers)) in self.source_store.lookup_git_sha(sha1): + for _kind, ( + revid, + _tree_sha, + _verifiers, + ) in self.source_store.lookup_git_sha(sha1): revid_sha_map[revid] = sha1 stop_revids.append(revid) else: @@ -190,8 +204,9 @@ def missing_revisions(self, stop_revisions): new_stop_revids = [] for revid in stop_revids: sha1 = revid_sha_map.get(revid) - if (revid not in missing and - self._revision_needs_fetching(sha1, revid)): + if revid not in missing and self._revision_needs_fetching( + sha1, revid + ): missing.add(revid) new_stop_revids.append(revid) stop_revids = set() @@ -217,10 +232,8 @@ def _get_target_either_refs(self) -> EitherRefDict: revid = None if v and not v.startswith(SYMREF): try: - for (kind, type_data) in self.source_store.lookup_git_sha( - v): - if kind == "commit" and self.source.has_revision( - type_data[0]): + for kind, type_data in self.source_store.lookup_git_sha(v): + if kind == "commit" and self.source.has_revision(type_data[0]): revid = type_data[0] break except KeyError: @@ -235,10 +248,13 @@ def fetch_refs(self, update_refs, lossy, overwrite: bool = False): old_refs = self._get_target_either_refs() new_refs = update_refs(old_refs) revidmap = self.fetch_revs( - [(git_sha, bzr_revid) - for (git_sha, bzr_revid) in new_refs.values() - if git_sha is None or not git_sha.startswith(SYMREF)], - lossy=lossy) + [ + (git_sha, bzr_revid) + for (git_sha, bzr_revid) in new_refs.values() + if git_sha is 
None or not git_sha.startswith(SYMREF) + ], + lossy=lossy, + ) for name, (gitid, revid) in new_refs.items(): if gitid is None: try: @@ -246,8 +262,7 @@ def fetch_refs(self, update_refs, lossy, overwrite: bool = False): except KeyError: gitid = self.source_store._lookup_revision_sha1(revid) if gitid.startswith(SYMREF): - self.target_refs.set_symbolic_ref( - name, gitid[len(SYMREF):]) + self.target_refs.set_symbolic_ref(name, gitid[len(SYMREF) :]) else: try: old_git_id = old_refs[name][0] @@ -255,27 +270,35 @@ def fetch_refs(self, update_refs, lossy, overwrite: bool = False): self.target_refs.add_if_new(name, gitid) else: self.target_refs.set_if_equals(name, old_git_id, gitid) - result_refs[name] = (gitid, revid if not lossy else self.mapping.revision_id_foreign_to_bzr(gitid)) + result_refs[name] = ( + gitid, + revid + if not lossy + else self.mapping.revision_id_foreign_to_bzr(gitid), + ) return revidmap, old_refs, result_refs def fetch_revs(self, revs, lossy: bool, limit: Optional[int] = None) -> RevidMap: if not lossy and not self.mapping.roundtripping: for _git_sha, bzr_revid in revs: - if (bzr_revid is not None and - needs_roundtripping(self.source, bzr_revid)): - raise NoPushSupport(self.source, self.target, self.mapping, - bzr_revid) + if bzr_revid is not None and needs_roundtripping( + self.source, bzr_revid + ): + raise NoPushSupport( + self.source, self.target, self.mapping, bzr_revid + ) with self.source_store.lock_read(): todo = list(self.missing_revisions(revs))[:limit] revidmap = {} with ui.ui_factory.nested_progress_bar() as pb: object_generator = MissingObjectsIterator( - self.source_store, self.source, pb) - for (old_revid, git_sha) in object_generator.import_revisions( - todo, lossy=lossy): + self.source_store, self.source, pb + ) + for old_revid, git_sha in object_generator.import_revisions( + todo, lossy=lossy + ): if lossy: - new_revid = self.mapping.revision_id_foreign_to_bzr( - git_sha) + new_revid = self.mapping.revision_id_foreign_to_bzr(git_sha) else: new_revid = old_revid try: @@ -286,8 +309,9 @@ def fetch_revs(self, revs, lossy: bool, limit: Optional[int] = None) -> RevidMap self.target_store.add_objects(object_generator) return revidmap - def fetch(self, revision_id=None, find_ghosts: bool = False, - lossy=False, fetch_spec=None) -> FetchResult: + def fetch( + self, revision_id=None, find_ghosts: bool = False, lossy=False, fetch_spec=None + ) -> FetchResult: if revision_id is not None: stop_revisions = [(None, revision_id)] elif fetch_spec is not None: @@ -295,11 +319,9 @@ def fetch(self, revision_id=None, find_ghosts: bool = False, if recipe[0] in ("search", "proxy-search"): stop_revisions = [(None, revid) for revid in recipe[1]] else: - raise AssertionError( - f"Unsupported search result type {recipe[0]}") + raise AssertionError(f"Unsupported search result type {recipe[0]}") else: - stop_revisions = [(None, revid) - for revid in self.source.all_revision_ids()] + stop_revisions = [(None, revid) for revid in self.source.all_revision_ids()] self._warn_slow() try: revidmap = self.fetch_revs(stop_revisions, lossy=lossy) @@ -310,12 +332,12 @@ def fetch(self, revision_id=None, find_ghosts: bool = False, @staticmethod def is_compatible(source, target): """Be compatible with GitRepository.""" - return (not isinstance(source, GitRepository) and - isinstance(target, LocalGitRepository)) + return not isinstance(source, GitRepository) and isinstance( + target, LocalGitRepository + ) class InterToRemoteGitRepository(InterToGitRepository): - target: RemoteGitRepository def 
fetch_refs(self, update_refs, lossy, overwrite: bool = False): @@ -328,28 +350,26 @@ def fetch_refs(self, update_refs, lossy, overwrite: bool = False): def git_update_refs(old_refs): ret = {} - self.old_refs = { - k: (v, None) for (k, v) in old_refs.items()} + self.old_refs = {k: (v, None) for (k, v) in old_refs.items()} new_refs = update_refs(self.old_refs) for name, (gitid, revid) in new_refs.items(): if gitid is None: git_sha = self.source_store._lookup_revision_sha1(revid) - gitid = unpeel_map.re_unpeel_tag( - git_sha, old_refs.get(name)) + gitid = unpeel_map.re_unpeel_tag(git_sha, old_refs.get(name)) if not overwrite: - if remote_divergence( - old_refs.get(name), gitid, self.source_store): + if remote_divergence(old_refs.get(name), gitid, self.source_store): raise DivergedBranches(self.source, self.target) ret[name] = gitid return ret + self._warn_slow() with self.source_store.lock_read(): result = self.target.send_pack( - git_update_refs, self.source_store.generate_lossy_pack_data) + git_update_refs, self.source_store.generate_lossy_pack_data + ) for ref, error in result.ref_status.items(): if error: - raise RemoteGitError( - f'unable to update ref {ref!r}: {error}') + raise RemoteGitError(f"unable to update ref {ref!r}: {error}") new_refs = result.refs # FIXME: revidmap? return revidmap, self.old_refs, new_refs @@ -357,12 +377,12 @@ def git_update_refs(old_refs): @staticmethod def is_compatible(source, target): """Be compatible with GitRepository.""" - return (not isinstance(source, GitRepository) and - isinstance(target, RemoteGitRepository)) + return not isinstance(source, GitRepository) and isinstance( + target, RemoteGitRepository + ) class GitSearchResult(AbstractSearchResult): - def __init__(self, start, exclude, keys): self._start = start self._exclude = exclude @@ -372,11 +392,10 @@ def get_keys(self): return self._keys def get_recipe(self): - return ('search', self._start, self._exclude, len(self._keys)) + return ("search", self._start, self._exclude, len(self._keys)) class InterFromGitRepository(InterRepository): - _matching_repo_format = GitRepositoryFormat() def _target_has_shas(self, shas): @@ -389,7 +408,7 @@ def determine_wants(refs): unpeel_lookup = {} for k, v in refs.items(): if k.endswith(PEELED_TAG_SUFFIX): - unpeel_lookup[v] = refs[k[:-len(PEELED_TAG_SUFFIX)]] + unpeel_lookup[v] = refs[k[: -len(PEELED_TAG_SUFFIX)]] potential = {unpeel_lookup.get(w, w) for w in wants} if include_tags: for k, sha in refs.items(): @@ -405,6 +424,7 @@ def determine_wants(refs): continue potential.add(sha) return list(potential - self._target_has_shas(potential)) + return determine_wants def determine_wants_all(self, refs): @@ -418,9 +438,9 @@ def copy_content(self, revision_id=None): """See InterRepository.copy_content.""" self.fetch(revision_id, find_ghosts=False) - def search_missing_revision_ids(self, - find_ghosts=True, revision_ids=None, - if_present_ids=None, limit=None): + def search_missing_revision_ids( + self, find_ghosts=True, revision_ids=None, if_present_ids=None, limit=None + ): if limit is not None: raise FetchLimitUnsupported(self) if revision_ids is None and if_present_ids is None: @@ -435,8 +455,11 @@ def search_missing_revision_ids(self, if if_present_ids is not None: todo.update(if_present_ids) result_set = todo.difference(self.target.all_revision_ids()) - result_parents = set(itertools.chain.from_iterable( - self.source.get_graph().get_parent_map(result_set).values())) + result_parents = set( + itertools.chain.from_iterable( + 
self.source.get_graph().get_parent_map(result_set).values() + ) + ) included_keys = result_set.intersection(result_parents) start_keys = result_set.difference(included_keys) exclude_keys = result_parents.difference(result_set) @@ -470,10 +493,11 @@ def determine_wants_all(self, refs): return list(potential - self._target_has_shas(potential)) def _warn_slow(self): - if not config.GlobalConfig().suppress_warning('slow_intervcs_push'): + if not config.GlobalConfig().suppress_warning("slow_intervcs_push"): trace.warning( - 'Fetching from Git to Bazaar repository. ' - 'For better performance, fetch into a Git repository.') + "Fetching from Git to Bazaar repository. " + "For better performance, fetch into a Git repository." + ) def fetch_objects(self, determine_wants, mapping, limit=None, lossy=False): """Fetch objects from a remote server. @@ -494,10 +518,18 @@ def get_determine_wants_revids(self, revids, include_tags=False, tag_selector=No git_sha, mapping = self.source.lookup_bzr_revision_id(revid) wants.add(git_sha) return self.get_determine_wants_heads( - wants, include_tags=include_tags, tag_selector=tag_selector) - - def fetch(self, revision_id=None, find_ghosts=False, - mapping=None, fetch_spec=None, include_tags=False, lossy=False): + wants, include_tags=include_tags, tag_selector=tag_selector + ) + + def fetch( + self, + revision_id=None, + find_ghosts=False, + mapping=None, + fetch_spec=None, + include_tags=False, + lossy=False, + ): if mapping is None: mapping = self.source.get_mapping() if revision_id is not None: @@ -513,12 +545,14 @@ def fetch(self, revision_id=None, find_ghosts=False, if interesting_heads is not None: determine_wants = self.get_determine_wants_revids( - interesting_heads, include_tags=include_tags) + interesting_heads, include_tags=include_tags + ) else: determine_wants = self.determine_wants_all (pack_hint, _, remote_refs) = self.fetch_objects( - determine_wants, mapping, lossy=lossy) + determine_wants, mapping, lossy=lossy + ) if pack_hint is not None and self.target._format.pack_compresses: self.target.pack(hint=pack_hint) result = FetchResult() @@ -548,17 +582,24 @@ def fetch_objects(self, determine_wants, mapping, limit=None, lossy=False): heads = self.get_target_heads() graph_walker = ObjectStoreGraphWalker( [store._lookup_revision_sha1(head) for head in heads], - lambda sha: store[sha].parents) + lambda sha: store[sha].parents, + ) wants_recorder = DetermineWantsRecorder(determine_wants) with ui.ui_factory.nested_progress_bar() as pb: objects_iter = self.source.fetch_objects( - wants_recorder, graph_walker, store.get_raw) - trace.mutter("Importing %d new revisions", - len(wants_recorder.wants)) + wants_recorder, graph_walker, store.get_raw + ) + trace.mutter("Importing %d new revisions", len(wants_recorder.wants)) (pack_hint, last_rev) = import_git_objects( - self.target, mapping, objects_iter, store, - wants_recorder.wants, pb, limit) + self.target, + mapping, + objects_iter, + store, + wants_recorder.wants, + pb, + limit, + ) return (pack_hint, last_rev, wants_recorder.remote_refs) @staticmethod @@ -590,8 +631,14 @@ def fetch_objects(self, determine_wants, mapping, limit=None, lossy=False): target_git_object_retriever.lock_write() try: (pack_hint, last_rev) = import_git_objects( - self.target, mapping, self.source._git.object_store, - target_git_object_retriever, wants, pb, limit) + self.target, + mapping, + self.source._git.object_store, + target_git_object_retriever, + wants, + pb, + limit, + ) return (pack_hint, last_rev, remote_refs) finally: 
target_git_object_retriever.unlock() @@ -622,15 +669,16 @@ def _get_target_either_refs(self): ret[name] = (sha1, self.target.lookup_foreign_revision_id(sha1)) return ret - def fetch_refs(self, update_refs, lossy: bool = False, overwrite: bool = False) -> Tuple[RevidMap, EitherRefDict, EitherRefDict]: + def fetch_refs( + self, update_refs, lossy: bool = False, overwrite: bool = False + ) -> Tuple[RevidMap, EitherRefDict, EitherRefDict]: if lossy: raise LossyPushToSameVCS(self.source, self.target) old_refs = self._get_target_either_refs() ref_changes = {} def determine_wants(heads): - old_refs = {k: (v, None) - for (k, v) in heads.items()} + old_refs = {k: (v, None) for (k, v) in heads.items()} new_refs = update_refs(old_refs) ret = [] for name, (sha1, bzr_revid) in list(new_refs.items()): @@ -640,6 +688,7 @@ def determine_wants(heads): ret.append(sha1) ref_changes.update(new_refs) return ret + self.fetch_objects(determine_wants) for k, (git_sha, _bzr_revid) in ref_changes.items(): self.target._git.refs[k] = git_sha # type: ignore @@ -650,12 +699,18 @@ def fetch_objects(self, determine_wants, limit=None, mapping=None, lossy=False): raise NotImplementedError(self.fetch_objects) def _target_has_shas(self, shas): - return { - sha for sha in shas if sha in self.target._git.object_store} - - def fetch(self, revision_id=None, find_ghosts=False, - fetch_spec=None, branches=None, limit=None, - include_tags=False, lossy=False): + return {sha for sha in shas if sha in self.target._git.object_store} + + def fetch( + self, + revision_id=None, + find_ghosts=False, + fetch_spec=None, + branches=None, + limit=None, + include_tags=False, + lossy=False, + ): if lossy: raise LossyPushToSameVCS(self.source, self.target) if revision_id is not None: @@ -665,17 +720,18 @@ def fetch(self, revision_id=None, find_ghosts=False, if recipe[0] in ("search", "proxy-search"): heads = recipe[1] else: - raise AssertionError( - f"Unsupported search result type {recipe[0]}") + raise AssertionError(f"Unsupported search result type {recipe[0]}") args = heads if branches is not None: determine_wants = self.get_determine_wants_branches( - branches, include_tags=include_tags) + branches, include_tags=include_tags + ) elif fetch_spec is None and revision_id is None: determine_wants = self.determine_wants_all else: determine_wants = self.get_determine_wants_revids( - args, include_tags=include_tags) + args, include_tags=include_tags + ) wants_recorder = DetermineWantsRecorder(determine_wants) self.fetch_objects(wants_recorder, limit=limit) result = FetchResult() @@ -689,7 +745,9 @@ def get_determine_wants_revids(self, revids, include_tags=False, tag_selector=No continue git_sha, mapping = self.source.lookup_bzr_revision_id(revid) wants.add(git_sha) - return self.get_determine_wants_heads(wants, include_tags=include_tags, tag_selector=tag_selector) + return self.get_determine_wants_heads( + wants, include_tags=include_tags, tag_selector=tag_selector + ) def get_determine_wants_branches(self, branches, include_tags=False): def determine_wants(refs): @@ -704,54 +762,61 @@ def determine_wants(refs): if name in branches or (include_tags and is_tag(name)): ret.append(value) return ret + return determine_wants def determine_wants_all(self, refs): potential = { - v for k, v in refs.items() - if not v == ZERO_SHA and not k.endswith(PEELED_TAG_SUFFIX)} + v + for k, v in refs.items() + if not v == ZERO_SHA and not k.endswith(PEELED_TAG_SUFFIX) + } return list(potential - self._target_has_shas(potential)) class 
InterLocalGitLocalGitRepository(InterGitGitRepository): - source: LocalGitRepository target: LocalGitRepository - def fetch_objects(self, determine_wants, limit=None, mapping=None, lossy: bool = False): + def fetch_objects( + self, determine_wants, limit=None, mapping=None, lossy: bool = False + ): if limit is not None: raise FetchLimitUnsupported(self) if lossy: raise LossyPushToSameVCS(self.source, self.target) from .remote import DefaultProgressReporter + with ui.ui_factory.nested_progress_bar() as pb: progress = DefaultProgressReporter(pb).progress refs = self.source._git.fetch( - self.target._git, determine_wants, - progress=progress) + self.target._git, determine_wants, progress=progress + ) return (None, None, refs) @staticmethod def is_compatible(source, target): """Be compatible with GitRepository.""" - return (isinstance(source, LocalGitRepository) and - isinstance(target, LocalGitRepository)) + return isinstance(source, LocalGitRepository) and isinstance( + target, LocalGitRepository + ) class InterRemoteGitLocalGitRepository(InterGitGitRepository): - def fetch_objects(self, determine_wants, limit=None, mapping=None): from tempfile import SpooledTemporaryFile + if limit is not None: raise FetchLimitUnsupported(self) graphwalker = self.target._git.get_graph_walker() - if (CAPABILITY_THIN_PACK in - self.source.controldir._client._fetch_capabilities): + if CAPABILITY_THIN_PACK in self.source.controldir._client._fetch_capabilities: # TODO(jelmer): Avoid reading entire file into memory and # only processing it after the whole file has been fetched. f = SpooledTemporaryFile( - max_size=PACK_SPOOL_FILE_MAX_SIZE, prefix='incoming-', - dir=getattr(self.target._git.object_store, 'path', None)) + max_size=PACK_SPOOL_FILE_MAX_SIZE, + prefix="incoming-", + dir=getattr(self.target._git.object_store, "path", None), + ) def commit(): if f.tell(): @@ -764,7 +829,8 @@ def abort(): f, commit, abort = self.target._git.object_store.add_pack() try: refs = self.source.controldir.fetch_pack( - determine_wants, graphwalker, f.write) + determine_wants, graphwalker, f.write + ) commit() return (None, None, refs) except BaseException: @@ -774,13 +840,12 @@ def abort(): @staticmethod def is_compatible(source, target): """Be compatible with GitRepository.""" - return (isinstance(source, RemoteGitRepository) and - isinstance(target, LocalGitRepository)) - + return isinstance(source, RemoteGitRepository) and isinstance( + target, LocalGitRepository + ) class InterLocalGitRemoteGitRepository(InterToGitRepository): - def fetch_refs(self, update_refs, lossy=False, overwrite=False): """Import the gist of the ancestry of a particular revision.""" if lossy: @@ -788,24 +853,24 @@ def fetch_refs(self, update_refs, lossy=False, overwrite=False): def git_update_refs(old_refs): ret = {} - self.old_refs = { - k: (v, None) for (k, v) in old_refs.items()} + self.old_refs = {k: (v, None) for (k, v) in old_refs.items()} new_refs = update_refs(self.old_refs) for name, (gitid, revid) in new_refs.items(): if gitid is None: gitid = self.source_store._lookup_revision_sha1(revid) if not overwrite: - if remote_divergence( - old_refs.get(name), gitid, self.source_store): + if remote_divergence(old_refs.get(name), gitid, self.source_store): raise DivergedBranches(self.source, self.target) ret[name] = gitid return ret + new_refs = self.target.send_pack( - git_update_refs, - self.source._git.generate_pack_data) + git_update_refs, self.source._git.generate_pack_data + ) return None, self.old_refs, new_refs @staticmethod def 
is_compatible(source, target): - return (isinstance(source, LocalGitRepository) and - isinstance(target, RemoteGitRepository)) + return isinstance(source, LocalGitRepository) and isinstance( + target, RemoteGitRepository + ) diff --git a/breezy/git/mapping.py b/breezy/git/mapping.py index 31ecec7736..dee3454a29 100644 --- a/breezy/git/mapping.py +++ b/breezy/git/mapping.py @@ -44,7 +44,7 @@ HG_EXTRA_TOPIC = b"topic" HG_EXTRA_REWRITE_NOISE = b"_rewrite_noise" -FILE_ID_PREFIX = b'git:' +FILE_ID_PREFIX = b"git:" # Always the same. ROOT_ID = b"TREE_ROOT" @@ -77,9 +77,9 @@ def __init__(self, encoding): def escape_file_id(file_id): - file_id = file_id.replace(b'_', b'__') - file_id = file_id.replace(b' ', b'_s') - file_id = file_id.replace(b'\x0c', b'_c') + file_id = file_id.replace(b"_", b"__") + file_id = file_id.replace(b" ", b"_s") + file_id = file_id.replace(b"\x0c", b"_c") return file_id @@ -87,14 +87,14 @@ def unescape_file_id(file_id): ret = bytearray() i = 0 while i < len(file_id): - if file_id[i:i + 1] != b'_': + if file_id[i : i + 1] != b"_": ret.append(file_id[i]) else: - if file_id[i + 1:i + 2] == b'_': + if file_id[i + 1 : i + 2] == b"_": ret.append(b"_"[0]) - elif file_id[i + 1:i + 2] == b's': + elif file_id[i + 1 : i + 2] == b"s": ret.append(b" "[0]) - elif file_id[i + 1:i + 2] == b'c': + elif file_id[i + 1 : i + 2] == b"c": ret.append(b"\x0c"[0]) else: raise ValueError(f"unknown escape character {file_id[i + 1:i + 2]}") @@ -121,44 +121,53 @@ def fix_person_identifier(text): def decode_git_path(path): """Take a git path and decode it.""" - return path.decode('utf-8', 'surrogateescape') + return path.decode("utf-8", "surrogateescape") def encode_git_path(path): """Take a regular path and encode it for git.""" - return path.encode('utf-8', 'surrogateescape') + return path.encode("utf-8", "surrogateescape") def warn_escaped(commit, num_escaped): - trace.warning("Escaped %d XML-invalid characters in %s. Will be unable " - "to regenerate the SHA map.", num_escaped, commit) + trace.warning( + "Escaped %d XML-invalid characters in %s. Will be unable " + "to regenerate the SHA map.", + num_escaped, + commit, + ) def warn_unusual_mode(commit, path, mode): - trace.mutter("Unusual file mode %o for %s in %s. Storing as revision " - "property. ", mode, path, commit) + trace.mutter( + "Unusual file mode %o for %s in %s. Storing as revision " "property. 
", + mode, + path, + commit, + ) class BzrGitMapping(foreign.VcsMapping): """Class that maps between Git and Bazaar semantics.""" + experimental = False BZR_DUMMY_FILE: Optional[str] = None def is_special_file(self, filename): - return (filename in (self.BZR_DUMMY_FILE, )) + return filename in (self.BZR_DUMMY_FILE,) def __init__(self): super().__init__(foreign_vcs_git) def __eq__(self, other): - return (type(self) == type(other) - and self.revid_prefix == other.revid_prefix) + return type(self) == type(other) and self.revid_prefix == other.revid_prefix @classmethod def revision_id_foreign_to_bzr(cls, git_rev_id): """Convert a git revision id handle to a Bazaar revision id.""" from dulwich.protocol import ZERO_SHA + if git_rev_id == ZERO_SHA: return NULL_REVISION return b"%s:%s" % (cls.revid_prefix, git_rev_id) @@ -168,7 +177,7 @@ def revision_id_bzr_to_foreign(cls, bzr_rev_id): """Convert a Bazaar revision id to a git revision id handle.""" if not bzr_rev_id.startswith(b"%s:" % cls.revid_prefix): raise errors.InvalidRevisionId(bzr_rev_id, cls) - return bzr_rev_id[len(cls.revid_prefix) + 1:], cls() + return bzr_rev_id[len(cls.revid_prefix) + 1 :], cls() def generate_file_id(self, path): # Git paths are just bytestrings @@ -184,17 +193,19 @@ def parse_file_id(self, file_id): return "" if not file_id.startswith(FILE_ID_PREFIX): raise ValueError - return decode_git_path(unescape_file_id(file_id[len(FILE_ID_PREFIX):])) + return decode_git_path(unescape_file_id(file_id[len(FILE_ID_PREFIX) :])) def import_unusual_file_modes(self, rev, unusual_file_modes): if unusual_file_modes: - ret = [(path, unusual_file_modes[path]) - for path in sorted(unusual_file_modes.keys())] - rev.properties['file-modes'] = bencode.bencode(ret) + ret = [ + (path, unusual_file_modes[path]) + for path in sorted(unusual_file_modes.keys()) + ] + rev.properties["file-modes"] = bencode.bencode(ret) def export_unusual_file_modes(self, rev): try: - file_modes = rev.properties['file-modes'] + file_modes = rev.properties["file-modes"] except KeyError: return {} else: @@ -211,16 +222,16 @@ def _generate_git_svn_metadata(self, rev, encoding): def _generate_hg_message_tail(self, rev): extra = {} renames = [] - branch = 'default' + branch = "default" for name in rev.properties: - if name == 'hg:extra:branch': - branch = rev.properties['hg:extra:branch'] - elif name.startswith('hg:extra'): - extra[name[len('hg:extra:'):]] = base64.b64decode( - rev.properties[name]) - elif name == 'hg:renames': - renames = bencode.bdecode(base64.b64decode( - rev.properties['hg:renames'])) + if name == "hg:extra:branch": + branch = rev.properties["hg:extra:branch"] + elif name.startswith("hg:extra"): + extra[name[len("hg:extra:") :]] = base64.b64decode(rev.properties[name]) + elif name == "hg:renames": + renames = bencode.bdecode( + base64.b64decode(rev.properties["hg:renames"]) + ) # TODO: Export other properties as 'bzr:' extras? 
ret = format_hg_metadata(renames, branch, extra) if not isinstance(ret, bytes): @@ -229,11 +240,12 @@ def _generate_hg_message_tail(self, rev): def _extract_git_svn_metadata(self, properties, message): lines = message.split("\n") - if not (lines[-1] == "" and len(lines) >= 2 and - lines[-2].startswith("git-svn-id:")): + if not ( + lines[-1] == "" and len(lines) >= 2 and lines[-2].startswith("git-svn-id:") + ): return message git_svn_id = lines[-2].split(": ", 1)[1] - properties['git-svn-id'] = git_svn_id + properties["git-svn-id"] = git_svn_id (url, rev, uuid) = parse_git_svn_id(git_svn_id) # FIXME: Convert this to converted-from property somehow.. return "\n".join(lines[:-2]) @@ -241,12 +253,13 @@ def _extract_git_svn_metadata(self, properties, message): def _extract_hg_metadata(self, properties, message): (message, renames, branch, extra) = extract_hg_metadata(message) if branch is not None: - properties['hg:extra:branch'] = branch + properties["hg:extra:branch"] = branch for name, value in extra.items(): - properties['hg:extra:' + name] = base64.b64encode(value) + properties["hg:extra:" + name] = base64.b64encode(value) if renames: - properties['hg:renames'] = base64.b64encode(bencode.bencode( - [(new, old) for (old, new) in renames.items()])) + properties["hg:renames"] = base64.b64encode( + bencode.bencode([(new, old) for (old, new) in renames.items()]) + ) return message def _extract_bzr_metadata(self, properties, message): @@ -266,8 +279,7 @@ def _encode_commit_message(self, rev, message, encoding): else: return message.encode(encoding) - def export_commit(self, rev, tree_sha, parent_lookup, lossy, - verifiers): + def export_commit(self, rev, tree_sha, parent_lookup, lossy, verifiers): """Turn a Bazaar revision in to a Git commit. :param tree_sha: Tree sha for the commit @@ -278,6 +290,7 @@ def export_commit(self, rev, tree_sha, parent_lookup, lossy, :return dulwich.objects.Commit represent the revision: """ from dulwich.objects import Commit, Tag + commit = Commit() commit.tree = tree_sha if not lossy: @@ -299,47 +312,42 @@ def export_commit(self, rev, tree_sha, parent_lookup, lossy, parents.append(git_p) commit.parents = parents try: - encoding = rev.properties['git-explicit-encoding'] + encoding = rev.properties["git-explicit-encoding"] except KeyError: - encoding = rev.properties.get('git-implicit-encoding', 'utf-8') + encoding = rev.properties.get("git-implicit-encoding", "utf-8") try: - commit.encoding = rev.properties['git-explicit-encoding'].encode( - 'ascii') + commit.encoding = rev.properties["git-explicit-encoding"].encode("ascii") except KeyError: pass - commit.committer = fix_person_identifier(rev.committer.encode( - encoding)) + commit.committer = fix_person_identifier(rev.committer.encode(encoding)) first_author = rev.get_apparent_authors()[0] - if ',' in first_author and first_author.count('>') > 1: - first_author = first_author.split(',')[0] - commit.author = fix_person_identifier( - first_author.encode(encoding)) + if "," in first_author and first_author.count(">") > 1: + first_author = first_author.split(",")[0] + commit.author = fix_person_identifier(first_author.encode(encoding)) # TODO(jelmer): Don't use this hack. 
- long = getattr(__builtins__, 'long', int) + long = getattr(__builtins__, "long", int) commit.commit_time = long(rev.timestamp) - if 'author-timestamp' in rev.properties: - commit.author_time = long(rev.properties['author-timestamp']) + if "author-timestamp" in rev.properties: + commit.author_time = long(rev.properties["author-timestamp"]) else: commit.author_time = commit.commit_time - commit._commit_timezone_neg_utc = ( - "commit-timezone-neg-utc" in rev.properties) + commit._commit_timezone_neg_utc = "commit-timezone-neg-utc" in rev.properties commit.commit_timezone = rev.timezone - commit._author_timezone_neg_utc = ( - "author-timezone-neg-utc" in rev.properties) - if 'author-timezone' in rev.properties: - commit.author_timezone = int(rev.properties['author-timezone']) + commit._author_timezone_neg_utc = "author-timezone-neg-utc" in rev.properties + if "author-timezone" in rev.properties: + commit.author_timezone = int(rev.properties["author-timezone"]) else: commit.author_timezone = commit.commit_timezone - if 'git-gpg-signature' in rev.properties: - commit.gpgsig = rev.properties['git-gpg-signature'].encode( - 'utf-8', 'surrogateescape') - if 'git-missing-message' in rev.properties: - if commit.message != '': - raise AssertionError('git-missing-message set but message is not empty') + if "git-gpg-signature" in rev.properties: + commit.gpgsig = rev.properties["git-gpg-signature"].encode( + "utf-8", "surrogateescape" + ) + if "git-missing-message" in rev.properties: + if commit.message != "": + raise AssertionError("git-missing-message set but message is not empty") commit.message = None else: - commit.message = self._encode_commit_message(rev, rev.message, - encoding) + commit.message = self._encode_commit_message(rev, rev.message, encoding) if not isinstance(commit.message, bytes): raise TypeError(commit.message) if metadata is not None: @@ -348,50 +356,61 @@ def export_commit(self, rev, tree_sha, parent_lookup, lossy, except errors.InvalidRevisionId: metadata.revision_id = rev.revision_id mapping_properties = { - 'author', 'author-timezone', 'author-timezone-neg-utc', - 'commit-timezone-neg-utc', 'git-implicit-encoding', - 'git-gpg-signature', 'git-explicit-encoding', - 'author-timestamp', 'file-modes'} + "author", + "author-timezone", + "author-timezone-neg-utc", + "commit-timezone-neg-utc", + "git-implicit-encoding", + "git-gpg-signature", + "git-explicit-encoding", + "author-timestamp", + "file-modes", + } for k, v in rev.properties.items(): if k not in mapping_properties: metadata.properties[k] = v if not lossy and metadata: if self.roundtripping: - commit.message = inject_bzr_metadata(commit.message, metadata, - encoding) + commit.message = inject_bzr_metadata(commit.message, metadata, encoding) else: - raise NoPushSupport( - None, None, self, revision_id=rev.revision_id) + raise NoPushSupport(None, None, self, revision_id=rev.revision_id) if not isinstance(commit.message, bytes): raise TypeError(commit.message) i = 0 - propname = 'git-mergetag-0' + propname = "git-mergetag-0" while propname in rev.properties: commit.mergetag.append( - Tag.from_string(rev.properties[propname].encode('utf-8', 'surrogateescape'))) + Tag.from_string( + rev.properties[propname].encode("utf-8", "surrogateescape") + ) + ) i += 1 - propname = 'git-mergetag-%d' % i + propname = "git-mergetag-%d" % i try: extra = commit._extra except AttributeError: extra = commit.extra - if 'git-extra' in rev.properties: - for l in rev.properties['git-extra'].splitlines(): - (k, v) = l.split(' ', 1) + if "git-extra" 
in rev.properties: + for l in rev.properties["git-extra"].splitlines(): + (k, v) = l.split(" ", 1) extra.append( - (k.encode('utf-8', 'surrogateescape'), - v.encode('utf-8', 'surrogateescape'))) + ( + k.encode("utf-8", "surrogateescape"), + v.encode("utf-8", "surrogateescape"), + ) + ) return commit def get_revision_id(self, commit): if commit.encoding: - encoding = commit.encoding.decode('ascii') + encoding = commit.encoding.decode("ascii") else: - encoding = 'utf-8' + encoding = "utf-8" if commit.message is not None: try: message, metadata = self._decode_commit_message( - None, commit.message, encoding) + None, commit.message, encoding + ) except UnicodeDecodeError: pass else: @@ -420,42 +439,44 @@ def decode_using_encoding(properties, commit, encoding): raise UnknownCommitEncoding(encoding) from err try: if commit.committer != commit.author: - properties['author'] = commit.author.decode(encoding) + properties["author"] = commit.author.decode(encoding) except LookupError as err: raise UnknownCommitEncoding(encoding) from err message, git_metadata = self._decode_commit_message( - properties, commit.message, encoding) + properties, commit.message, encoding + ) if commit.encoding is not None: - properties['git-explicit-encoding'] = commit.encoding.decode( - 'ascii') - if commit.encoding is not None and commit.encoding != b'false': - decode_using_encoding(properties, commit, commit.encoding.decode('ascii')) + properties["git-explicit-encoding"] = commit.encoding.decode("ascii") + if commit.encoding is not None and commit.encoding != b"false": + decode_using_encoding(properties, commit, commit.encoding.decode("ascii")) else: - for encoding in ('utf-8', 'latin1'): + for encoding in ("utf-8", "latin1"): try: decode_using_encoding(properties, commit, encoding) except UnicodeDecodeError: pass else: - if encoding != 'utf-8': - properties['git-implicit-encoding'] = encoding + if encoding != "utf-8": + properties["git-implicit-encoding"] = encoding break if commit.commit_time != commit.author_time: - properties['author-timestamp'] = str(commit.author_time) + properties["author-timestamp"] = str(commit.author_time) if commit.commit_timezone != commit.author_timezone: - properties['author-timezone'] = "%d" % commit.author_timezone + properties["author-timezone"] = "%d" % commit.author_timezone if commit._author_timezone_neg_utc: - properties['author-timezone-neg-utc'] = "" + properties["author-timezone-neg-utc"] = "" if commit._commit_timezone_neg_utc: - properties['commit-timezone-neg-utc'] = "" + properties["commit-timezone-neg-utc"] = "" if commit.gpgsig: - properties['git-gpg-signature'] = commit.gpgsig.decode( - 'utf-8', 'surrogateescape') + properties["git-gpg-signature"] = commit.gpgsig.decode( + "utf-8", "surrogateescape" + ) if commit.mergetag: for i, tag in enumerate(commit.mergetag): - properties['git-mergetag-%d' % i] = tag.as_raw_string().decode( - 'utf-8', 'surrogateescape') + properties["git-mergetag-%d" % i] = tag.as_raw_string().decode( + "utf-8", "surrogateescape" + ) timestamp = commit.commit_time timezone = commit.commit_timezone parent_ids = None @@ -486,58 +507,75 @@ def decode_using_encoding(properties, commit, encoding): for k, v in extra: if k == HG_RENAME_SOURCE: extra_lines.append( - k.decode('utf-8', 'surrogateescape') + ' ' + - v.decode('utf-8', 'surrogateescape') + '\n') + k.decode("utf-8", "surrogateescape") + + " " + + v.decode("utf-8", "surrogateescape") + + "\n" + ) elif k == HG_EXTRA: - hgk, hgv = v.split(b':', 1) - if hgk not in (HG_EXTRA_AMEND_SOURCE, 
HG_EXTRA_REBASE_SOURCE, - HG_EXTRA_ABSORB_SOURCE, HG_EXTRA_INTERMEDIATE_SOURCE, - HG_EXTRA_SOURCE, HG_EXTRA_TOPIC, - HG_EXTRA_REWRITE_NOISE) and strict: + hgk, hgv = v.split(b":", 1) + if ( + hgk + not in ( + HG_EXTRA_AMEND_SOURCE, + HG_EXTRA_REBASE_SOURCE, + HG_EXTRA_ABSORB_SOURCE, + HG_EXTRA_INTERMEDIATE_SOURCE, + HG_EXTRA_SOURCE, + HG_EXTRA_TOPIC, + HG_EXTRA_REWRITE_NOISE, + ) + and strict + ): raise UnknownMercurialCommitExtra(commit, [hgk]) extra_lines.append( - k.decode('utf-8', 'surrogateescape') + ' ' + - v.decode('utf-8', 'surrogateescape') + '\n') + k.decode("utf-8", "surrogateescape") + + " " + + v.decode("utf-8", "surrogateescape") + + "\n" + ) else: unknown_extra_fields.append(k) if unknown_extra_fields and strict: raise UnknownCommitExtra( - commit, - [f.decode('ascii', 'replace') for f in unknown_extra_fields]) + commit, [f.decode("ascii", "replace") for f in unknown_extra_fields] + ) if extra_lines: - properties['git-extra'] = ''.join(extra_lines) + properties["git-extra"] = "".join(extra_lines) if message is None: - properties['git-missing-message'] = 'true' - message = '' + properties["git-missing-message"] = "true" + message = "" rev = ForeignRevision( - foreign_revid=commit.id, mapping=self, + foreign_revid=commit.id, + mapping=self, revision_id=self.revision_id_foreign_to_bzr(commit.id), properties=properties, parent_ids=parent_ids, timestamp=timestamp, timezone=timezone, committer=committer, - message=message,) + message=message, + ) rev.git_metadata = git_metadata return rev, roundtrip_revid, verifiers class BzrGitMappingv1(BzrGitMapping): - revid_prefix = b'git-v1' + revid_prefix = b"git-v1" experimental = False def __str__(self): - return self.revid_prefix.decode('utf-8') + return self.revid_prefix.decode("utf-8") class BzrGitMappingExperimental(BzrGitMappingv1): - revid_prefix = b'git-experimental' + revid_prefix = b"git-experimental" experimental = True roundtripping = False - BZR_DUMMY_FILE = '.bzrdummy' + BZR_DUMMY_FILE = ".bzrdummy" def _decode_commit_message(self, properties, message, encoding): message = self._extract_hg_metadata(properties, message) @@ -556,8 +594,9 @@ def _encode_commit_message(self, rev, message, encoding): def import_commit(self, commit, lookup_parent_revid, strict=True): rev, roundtrip_revid, verifiers = super().import_commit( - commit, lookup_parent_revid, strict) - rev.properties['converted_revision'] = f"git {commit.id}\n" + commit, lookup_parent_revid, strict + ) + rev.properties["converted_revision"] = f"git {commit.id}\n" return rev, roundtrip_revid, verifiers @@ -567,6 +606,7 @@ class GitMappingRegistry(VcsMappingRegistry): def revision_id_bzr_to_foreign(self, bzr_revid): if bzr_revid == NULL_REVISION: from dulwich.protocol import ZERO_SHA + return ZERO_SHA, None if not bzr_revid.startswith(b"git-"): raise errors.InvalidRevisionId(bzr_revid, None) @@ -578,17 +618,17 @@ def revision_id_bzr_to_foreign(self, bzr_revid): mapping_registry = GitMappingRegistry() -mapping_registry.register_lazy(b'git-v1', __name__, - "BzrGitMappingv1") -mapping_registry.register_lazy(b'git-experimental', - __name__, "BzrGitMappingExperimental") +mapping_registry.register_lazy(b"git-v1", __name__, "BzrGitMappingv1") +mapping_registry.register_lazy( + b"git-experimental", __name__, "BzrGitMappingExperimental" +) # Uncomment the next line to enable the experimental bzr-git mappings. # This will make sure all bzr metadata is pushed into git, allowing for # full roundtripping later. # NOTE: THIS IS EXPERIMENTAL. 
IT MAY EAT YOUR DATA OR CORRUPT # YOUR BZR OR GIT REPOSITORIES. USE WITH CARE. # mapping_registry.set_default('git-experimental') -mapping_registry.set_default(b'git-v1') +mapping_registry.set_default(b"git-v1") class ForeignGit(ForeignVcs): @@ -597,11 +637,13 @@ class ForeignGit(ForeignVcs): @property def branch_format(self): from .branch import LocalGitBranchFormat + return LocalGitBranchFormat() @property def repository_format(self): from .repository import GitRepositoryFormat + return GitRepositoryFormat() def __init__(self): @@ -614,7 +656,7 @@ def serialize_foreign_revid(self, foreign_revid): @classmethod def show_foreign_revid(cls, foreign_revid): - return {"git commit": foreign_revid.decode('utf-8')} + return {"git commit": foreign_revid.decode("utf-8")} foreign_vcs_git = ForeignGit() @@ -623,6 +665,7 @@ def show_foreign_revid(cls, foreign_revid): def symlink_to_blob(symlink_target): from dulwich.objects import Blob + blob = Blob() if isinstance(symlink_target, str): symlink_target = encode_git_path(symlink_target) @@ -641,38 +684,43 @@ def mode_kind(mode): return None entry_kind = (mode & 0o700000) / 0o100000 if entry_kind == 0: - return 'directory' + return "directory" elif entry_kind == 1: file_kind = (mode & 0o70000) / 0o10000 if file_kind == 0: - return 'file' + return "file" elif file_kind == 2: - return 'symlink' + return "symlink" elif file_kind == 6: - return 'tree-reference' + return "tree-reference" else: raise AssertionError( - "Unknown file kind %d, perms=%o." % (file_kind, mode,)) + "Unknown file kind %d, perms=%o." + % ( + file_kind, + mode, + ) + ) else: - raise AssertionError( - f"Unknown kind, perms={mode!r}.") + raise AssertionError(f"Unknown kind, perms={mode!r}.") def object_mode(kind, executable): - if kind == 'directory': + if kind == "directory": return stat.S_IFDIR - elif kind == 'symlink': + elif kind == "symlink": mode = stat.S_IFLNK if executable: mode |= 0o111 return mode - elif kind == 'file': + elif kind == "file": mode = stat.S_IFREG | 0o644 if executable: mode |= 0o111 return mode - elif kind == 'tree-reference': + elif kind == "tree-reference": from dulwich.objects import S_IFGITLINK + return S_IFGITLINK else: raise AssertionError @@ -680,13 +728,12 @@ def object_mode(kind, executable): def entry_mode(entry): """Determine the git file mode for an inventory entry.""" - return object_mode(entry.kind, getattr(entry, 'executable', False)) + return object_mode(entry.kind, getattr(entry, "executable", False)) def extract_unusual_modes(rev): try: - foreign_revid, mapping = mapping_registry.parse_revision_id( - rev.revision_id) + foreign_revid, mapping = mapping_registry.parse_revision_id(rev.revision_id) except errors.InvalidRevisionId: return {} else: diff --git a/breezy/git/memorytree.py b/breezy/git/memorytree.py index e48acac335..143fcf89eb 100644 --- a/breezy/git/memorytree.py +++ b/breezy/git/memorytree.py @@ -89,9 +89,8 @@ def _populate_from_branch(self): self._file_transport.mkdir(subpath) trees.append((subpath, self.store[sha])) elif stat.S_ISREG(mode): - self._file_transport.put_bytes( - subpath, self.store[sha].data) - self._index_add_entry(subpath, 'file') + self._file_transport.put_bytes(subpath, self.store[sha].data) + self._index_add_entry(subpath, "file") else: raise NotImplementedError(self._populate_from_branch) @@ -161,7 +160,8 @@ def unlock(self): def _lstat(self, path): mem_stat = self._file_transport.stat(path) stat_val = os.stat_result( - (mem_stat.st_mode, 0, 0, 0, 0, 0, mem_stat.st_size, 0, 0, 0)) + (mem_stat.st_mode, 0, 0, 0, 0, 
0, mem_stat.st_size, 0, 0, 0) + ) return stat_val def _live_entry(self, path): @@ -171,11 +171,12 @@ def _live_entry(self, path): return None elif stat.S_ISLNK(stat_val.st_mode): blob = Blob.from_string( - encode_git_path(self._file_transport.readlink(path))) + encode_git_path(self._file_transport.readlink(path)) + ) elif stat.S_ISREG(stat_val.st_mode): blob = Blob.from_string(self._file_transport.get_bytes(path)) else: - raise AssertionError('unknown type %d' % stat_val.st_mode) + raise AssertionError("unknown type %d" % stat_val.st_mode) return index_entry_from_stat(stat_val, blob.id, mode=stat_val.st_mode) def get_file_with_stat(self, path): @@ -204,8 +205,7 @@ def last_revision(self): with self.lock_read(): if self.branch.head is None: return _mod_revision.NULL_REVISION - return self.branch.repository.lookup_foreign_revision_id( - self.branch.head) + return self.branch.repository.lookup_foreign_revision_id(self.branch.head) def basis_tree(self): """See Tree.basis_tree().""" @@ -231,11 +231,12 @@ def set_parent_ids(self, parent_ids, allow_leftmost_as_ghost=False): else: self._parent_ids = parent_ids self.branch.head = self.branch.repository.lookup_bzr_revision_id( - parent_ids[0])[0] + parent_ids[0] + )[0] def mkdir(self, path, file_id=None): """See MutableTree.mkdir().""" - self.add(path, 'directory') + self.add(path, "directory") self._file_transport.mkdir(path) def _rename_one(self, from_rel, to_rel): diff --git a/breezy/git/object_store.py b/breezy/git/object_store.py index a93e16db63..ac155a7299 100644 --- a/breezy/git/object_store.py +++ b/breezy/git/object_store.py @@ -41,7 +41,7 @@ ) from .unpeel_map import UnpeelMap -BANNED_FILENAMES = ['.git'] +BANNED_FILENAMES = [".git"] def get_object_store(repo, mapping=None): @@ -58,15 +58,17 @@ def get_object_store(repo, mapping=None): class LRUTreeCache: - def __init__(self, repository): def approx_tree_size(tree): # Very rough estimate, 250 per inventory entry return len(tree.root_inventory) * 250 + self.repository = repository self._cache = lru_cache.LRUSizeCache( - max_size=MAX_TREE_CACHE_SIZE, after_cleanup_size=None, - compute_size=approx_tree_size) + max_size=MAX_TREE_CACHE_SIZE, + after_cleanup_size=None, + compute_size=approx_tree_size, + ) def revision_tree(self, revid): try: @@ -88,7 +90,9 @@ def iter_revision_trees(self, revids): if tree.get_revision_id() != revid: raise AssertionError( "revision id did not match: {} != {}".format( - tree.get_revision_id(), revid)) + tree.get_revision_id(), revid + ) + ) trees[revid] = tree for tree in self.repository.revision_trees(todo): trees[tree.get_revision_id()] = tree @@ -135,18 +139,20 @@ def _check_expected_sha(expected_sha, object): if expected_sha is None: return if len(expected_sha) == 40: - if expected_sha != object.sha().hexdigest().encode('ascii'): + if expected_sha != object.sha().hexdigest().encode("ascii"): raise AssertionError(f"Invalid sha for {object!r}: {expected_sha}") elif len(expected_sha) == 20: if expected_sha != object.sha().digest(): - raise AssertionError("Invalid sha for {!r}: {}".format( - object, sha_to_hex(expected_sha))) + raise AssertionError( + f"Invalid sha for {object!r}: {sha_to_hex(expected_sha)}" + ) else: raise AssertionError(f"Unknown length {len(expected_sha)} for {expected_sha!r}") -def directory_to_tree(path, children, lookup_ie_sha1, unusual_modes, - empty_file_name, allow_empty=False): +def directory_to_tree( + path, children, lookup_ie_sha1, unusual_modes, empty_file_name, allow_empty=False +): """Create a Git Tree object from a Bazaar 
directory. :param path: directory path @@ -177,8 +183,9 @@ def directory_to_tree(path, children, lookup_ie_sha1, unusual_modes, return tree -def _tree_to_objects(tree, parent_trees, idmap, unusual_modes, - dummy_file_name=None, add_cache_entry=None): +def _tree_to_objects( + tree, parent_trees, idmap, unusual_modes, dummy_file_name=None, add_cache_entry=None +): """Iterate over the objects that were introduced in a revision. :param idmap: id map @@ -205,15 +212,11 @@ def find_unchanged_parent_ie(path, kind, other, parent_trees): if ppath is not None: pkind = ptree.kind(ppath) if kind == "file": - if (pkind == "file" and - ptree.get_file_sha1(ppath) == other): - return ( - ptree.path2id(ppath), ptree.get_file_revision(ppath)) + if pkind == "file" and ptree.get_file_sha1(ppath) == other: + return (ptree.path2id(ppath), ptree.get_file_revision(ppath)) if kind == "symlink": - if (pkind == "symlink" and - ptree.get_symlink_target(ppath) == other): - return ( - ptree.path2id(ppath), ptree.get_file_revision(ppath)) + if pkind == "symlink" and ptree.get_symlink_target(ppath) == other: + return (ptree.path2id(ppath), ptree.get_file_revision(ppath)) raise KeyError # Find all the changed blobs @@ -225,15 +228,15 @@ def find_unchanged_parent_ie(path, kind, other, parent_trees): blob_id = None try: (pfile_id, prevision) = find_unchanged_parent_ie( - change.path[1], change.kind[1], sha1, other_parent_trees) + change.path[1], change.kind[1], sha1, other_parent_trees + ) except KeyError: pass else: # It existed in one of the parents, with the same contents. # So no need to yield any new git objects. try: - blob_id = idmap.lookup_blob_id( - pfile_id, prevision) + blob_id = idmap.lookup_blob_id(pfile_id, prevision) except KeyError: if not change.changed_content: # no-change merge ? 
@@ -248,24 +251,33 @@ def find_unchanged_parent_ie(path, kind, other, parent_trees): if add_cache_entry is not None: add_cache_entry( ("blob", blob_id), - (change.file_id, tree.get_file_revision(change.path[1])), change.path[1]) + (change.file_id, tree.get_file_revision(change.path[1])), + change.path[1], + ) elif change.kind[1] == "symlink": target = tree.get_symlink_target(change.path[1]) blob = symlink_to_blob(target) shamap[change.path[1]] = blob.id if add_cache_entry is not None: add_cache_entry( - blob, (change.file_id, tree.get_file_revision(change.path[1])), change.path[1]) + blob, + (change.file_id, tree.get_file_revision(change.path[1])), + change.path[1], + ) try: find_unchanged_parent_ie( - change.path[1], change.kind[1], target, other_parent_trees) + change.path[1], change.kind[1], target, other_parent_trees + ) except KeyError: if change.changed_content: - yield (change.path[1], blob, - (change.file_id, tree.get_file_revision(change.path[1]))) + yield ( + change.path[1], + blob, + (change.file_id, tree.get_file_revision(change.path[1])), + ) elif change.kind[1] is None: shamap[change.path[1]] = None - elif change.kind[1] != 'directory': + elif change.kind[1] != "directory": raise AssertionError(change.kind[1]) for p in change.path: if p is None: @@ -274,7 +286,8 @@ def find_unchanged_parent_ie(path, kind, other, parent_trees): # Fetch contents of the blobs that were changed for (path, file_id), chunks in tree.iter_files_bytes( - [(path, (path, file_id)) for (path, file_id) in new_blobs]): + [(path, (path, file_id)) for (path, file_id) in new_blobs] + ): obj = Blob() obj.chunked = list(chunks) if add_cache_entry is not None: @@ -292,7 +305,7 @@ def find_unchanged_parent_ie(path, kind, other, parent_trees): dirty_dirs.add(parent) if dirty_dirs: - dirty_dirs.add('') + dirty_dirs.add("") def ie_to_hexsha(path, ie): try: @@ -324,8 +337,13 @@ def ie_to_hexsha(path, ie): # Not all cache backends store the tree information, # calculate again from scratch ret = directory_to_tree( - path, tree.iter_child_entries(path), ie_to_hexsha, unusual_modes, - dummy_file_name, ie.parent_id is None) + path, + tree.iter_child_entries(path), + ie_to_hexsha, + unusual_modes, + dummy_file_name, + ie.parent_id is None, + ) if ret is None: return ret return ret.id @@ -336,12 +354,17 @@ def ie_to_hexsha(path, ie): if not tree.has_filename(path): continue - if tree.kind(path) != 'directory': + if tree.kind(path) != "directory": continue obj = directory_to_tree( - path, tree.iter_child_entries(path), ie_to_hexsha, unusual_modes, - dummy_file_name, path == '') + path, + tree.iter_child_entries(path), + ie_to_hexsha, + unusual_modes, + dummy_file_name, + path == "", + ) if obj is not None: file_id = tree.path2id(path) @@ -378,8 +401,7 @@ def _update_sha_map(self, stop_revision=None): raise errors.LockNotHeld(self) if self._map_updated: return - if (stop_revision is not None and - not self._missing_revisions([stop_revision])): + if stop_revision is not None and not self._missing_revisions([stop_revision]): return graph = self.repository.get_graph() if stop_revision is None: @@ -405,9 +427,8 @@ def _update_sha_map(self, stop_revision=None): self.start_write_group() try: with ui.ui_factory.nested_progress_bar() as pb: - for i, revid in enumerate(graph.iter_topo_order( - missing_revids)): - trace.mutter('processing %r', revid) + for i, revid in enumerate(graph.iter_topo_order(missing_revids)): + trace.mutter("processing %r", revid) pb.update("updating git map", i, len(missing_revids)) 
self._update_sha_map_revision(revid) if stop_revision is None: @@ -431,13 +452,16 @@ def _reconstruct_commit(self, rev, tree_sha, lossy, verifiers): :param verifiers: Verifiers for the commits :return: Commit object """ + def parent_lookup(revid): try: return self._lookup_revision_sha1(revid) except errors.NoSuchRevision: return None - return self.mapping.export_commit(rev, tree_sha, parent_lookup, - lossy, verifiers) + + return self.mapping.export_commit( + rev, tree_sha, parent_lookup, lossy, verifiers + ) def _revision_to_objects(self, rev, tree, lossy, add_cache_entry=None): """Convert a revision to a set of git objects. @@ -449,11 +473,17 @@ def _revision_to_objects(self, rev, tree, lossy, add_cache_entry=None): unusual_modes = extract_unusual_modes(rev) present_parents = self.repository.has_revisions(rev.parent_ids) parent_trees = self.tree_cache.revision_trees( - [p for p in rev.parent_ids if p in present_parents]) + [p for p in rev.parent_ids if p in present_parents] + ) root_tree = None for path, obj, bzr_key_data in _tree_to_objects( - tree, parent_trees, self._cache.idmap, unusual_modes, - self.mapping.BZR_DUMMY_FILE, add_cache_entry): + tree, + parent_trees, + self._cache.idmap, + unusual_modes, + self.mapping.BZR_DUMMY_FILE, + add_cache_entry, + ): if path == "": root_tree = obj root_key_data = bzr_key_data @@ -467,7 +497,7 @@ def _revision_to_objects(self, rev, tree, lossy, add_cache_entry=None): else: base_sha1 = self._lookup_revision_sha1(rev.parent_ids[0]) root_tree = self[self[base_sha1].tree] - root_key_data = (tree.path2id(''), tree.get_revision_id()) + root_key_data = (tree.path2id(""), tree.get_revision_id()) if add_cache_entry is not None: add_cache_entry(root_tree, root_key_data, "") yield "", root_tree @@ -476,11 +506,11 @@ def _revision_to_objects(self, rev, tree, lossy, add_cache_entry=None): verifiers = {"testament3-sha1": testament3.as_sha1()} else: verifiers = {} - commit_obj = self._reconstruct_commit(rev, root_tree.id, - lossy=lossy, verifiers=verifiers) + commit_obj = self._reconstruct_commit( + rev, root_tree.id, lossy=lossy, verifiers=verifiers + ) try: - foreign_revid, mapping = mapping_registry.parse_revision_id( - rev.revision_id) + foreign_revid, mapping = mapping_registry.parse_revision_id(rev.revision_id) except errors.InvalidRevisionId: pass else: @@ -499,14 +529,24 @@ def _update_sha_map_revision(self, revid): updater = self._get_updater(rev) # FIXME JRV 2011-12-15: Shouldn't we try both values for lossy ? for _path, obj in self._revision_to_objects( - rev, tree, lossy=(not self.mapping.roundtripping), - add_cache_entry=updater.add_object): + rev, + tree, + lossy=(not self.mapping.roundtripping), + add_cache_entry=updater.add_object, + ): if isinstance(obj, Commit): commit_obj = obj commit_obj = updater.finish() return commit_obj.id - def iter_unpacked_subset(self, shas, *, include_comp=False, allow_missing: bool = False, convert_ofs_delta: bool = True) -> Iterator[ShaFile]: + def iter_unpacked_subset( + self, + shas, + *, + include_comp=False, + allow_missing: bool = False, + convert_ofs_delta: bool = True, + ) -> Iterator[ShaFile]: # We don't store unpacked objects, so... 
if not allow_missing and shas: raise KeyError(shas.pop()) @@ -518,8 +558,7 @@ def _reconstruct_blobs(self, keys): :param fileid: File id of the text :param revision: Revision of the text """ - stream = self.repository.iter_files_bytes( - (key[0], key[1], key) for key in keys) + stream = self.repository.iter_files_bytes((key[0], key[1], key) for key in keys) for (file_id, revision, expected_sha), chunks in stream: blob = Blob() blob.chunked = list(chunks) @@ -527,49 +566,57 @@ def _reconstruct_blobs(self, keys): # Perhaps it's a symlink ? tree = self.tree_cache.revision_tree(revision) path = tree.id2path(file_id) - if tree.kind(path) == 'symlink': + if tree.kind(path) == "symlink": blob = symlink_to_blob(tree.get_symlink_target(path)) _check_expected_sha(expected_sha, blob) yield blob - def _reconstruct_tree(self, fileid, revid, bzr_tree, unusual_modes, - expected_sha=None): + def _reconstruct_tree( + self, fileid, revid, bzr_tree, unusual_modes, expected_sha=None + ): """Return a Git Tree object from a file id and a revision stored in bzr. :param fileid: fileid in the tree. :param revision: Revision of the tree. """ + def get_ie_sha1(path, entry): if entry.kind == "directory": try: - return self._cache.idmap.lookup_tree_id(entry.file_id, - revid) + return self._cache.idmap.lookup_tree_id(entry.file_id, revid) except (NotImplementedError, KeyError): obj = self._reconstruct_tree( - entry.file_id, revid, bzr_tree, unusual_modes) + entry.file_id, revid, bzr_tree, unusual_modes + ) if obj is None: return None else: return obj.id elif entry.kind in ("file", "symlink"): try: - return self._cache.idmap.lookup_blob_id(entry.file_id, - entry.revision) + return self._cache.idmap.lookup_blob_id( + entry.file_id, entry.revision + ) except KeyError: # no-change merge? 
- return next(self._reconstruct_blobs( - [(entry.file_id, entry.revision, None)])).id - elif entry.kind == 'tree-reference': + return next( + self._reconstruct_blobs([(entry.file_id, entry.revision, None)]) + ).id + elif entry.kind == "tree-reference": # FIXME: Make sure the file id is the root id return self._lookup_revision_sha1(entry.reference_revision) else: raise AssertionError(f"unknown entry kind '{entry.kind}'") + path = bzr_tree.id2path(fileid) tree = directory_to_tree( path, bzr_tree.iter_child_entries(path), - get_ie_sha1, unusual_modes, self.mapping.BZR_DUMMY_FILE, - bzr_tree.path2id('') == fileid) + get_ie_sha1, + unusual_modes, + self.mapping.BZR_DUMMY_FILE, + bzr_tree.path2id("") == fileid, + ) if tree is not None: _check_expected_sha(expected_sha, tree) return tree @@ -608,7 +655,7 @@ def get_raw(self, sha): def __contains__(self, sha): # See if sha is in map try: - for (type, type_data) in self.lookup_git_sha(sha): + for type, type_data in self.lookup_git_sha(sha): if type == "commit": if self.repository.has_revision(type_data[0]): return True @@ -626,19 +673,19 @@ def __contains__(self, sha): return False def lock_read(self): - self._locked = 'r' + self._locked = "r" self._map_updated = False self.repository.lock_read() return LogicalLockResult(self.unlock) def lock_write(self): - self._locked = 'r' + self._locked = "r" self._map_updated = False self.repository.lock_write() return LogicalLockResult(self.unlock) def is_locked(self): - return (self._locked is not None) + return self._locked is not None def unlock(self): self._locked = None @@ -668,7 +715,7 @@ def lookup_git_sha(self, sha): def __getitem__(self, sha): with self.repository.lock_read(): - for (kind, type_data) in self.lookup_git_sha(sha): + for kind, type_data in self.lookup_git_sha(sha): # convert object to git object if kind == "commit": (revid, tree_sha, verifiers) = type_data @@ -677,15 +724,24 @@ def __getitem__(self, sha): except errors.NoSuchRevision as err: if revid == NULL_REVISION: raise AssertionError( - "should not try to look up NULL_REVISION") from err - trace.mutter('entry for %s %s in shamap: %r, but not ' - 'found in repository', kind, sha, type_data) + "should not try to look up NULL_REVISION" + ) from err + trace.mutter( + "entry for %s %s in shamap: %r, but not " + "found in repository", + kind, + sha, + type_data, + ) raise KeyError(sha) from err # FIXME: the type data should say whether conversion was # lossless commit = self._reconstruct_commit( - rev, tree_sha, lossy=(not self.mapping.roundtripping), - verifiers=verifiers) + rev, + tree_sha, + lossy=(not self.mapping.roundtripping), + verifiers=verifiers, + ) _check_expected_sha(sha, commit) return commit elif kind == "blob": @@ -699,13 +755,18 @@ def __getitem__(self, sha): rev = self.repository.get_revision(revid) except errors.NoSuchRevision as err: trace.mutter( - 'entry for %s %s in shamap: %r, but not found in ' - 'repository', kind, sha, type_data) + "entry for %s %s in shamap: %r, but not found in " + "repository", + kind, + sha, + type_data, + ) raise KeyError(sha) from err unusual_modes = extract_unusual_modes(rev) try: return self._reconstruct_tree( - fileid, revid, tree, unusual_modes, expected_sha=sha) + fileid, revid, tree, unusual_modes, expected_sha=sha + ) except errors.NoSuchRevision as err: raise KeyError(sha) from err else: @@ -713,17 +774,33 @@ def __getitem__(self, sha): else: raise KeyError(sha) - def generate_lossy_pack_data(self, have, want, shallow=None, - progress=None, - get_tagged=None, ofs_delta=False): - 
object_ids = list(self.find_missing_objects(have, want, progress=progress, - shallow=shallow, get_tagged=get_tagged, - lossy=True)) - return pack_objects_to_data([ - (self[oid], path) for (oid, (type_num, path)) in object_ids]) - - def find_missing_objects(self, haves, wants, shallow=None, progress=None, - get_tagged=None, lossy: bool = False, ofs_delta=False) -> Iterator[Tuple[ObjectID, Tuple[int, str]]]: + def generate_lossy_pack_data( + self, have, want, shallow=None, progress=None, get_tagged=None, ofs_delta=False + ): + object_ids = list( + self.find_missing_objects( + have, + want, + progress=progress, + shallow=shallow, + get_tagged=get_tagged, + lossy=True, + ) + ) + return pack_objects_to_data( + [(self[oid], path) for (oid, (type_num, path)) in object_ids] + ) + + def find_missing_objects( + self, + haves, + wants, + shallow=None, + progress=None, + get_tagged=None, + lossy: bool = False, + ofs_delta=False, + ) -> Iterator[Tuple[ObjectID, Tuple[int, str]]]: """Iterate over the contents of a pack file. :param haves: List of SHA1s of objects that should not be sent @@ -734,7 +811,7 @@ def find_missing_objects(self, haves, wants, shallow=None, progress=None, for commit_sha in haves: commit_sha = self.unpeel_map.peel_tag(commit_sha, commit_sha) try: - for (type, type_data) in ret[commit_sha]: + for type, type_data in ret[commit_sha]: if type != "commit": raise AssertionError(f"Type was {type}, not commit") processed.add(type_data[0]) @@ -745,7 +822,7 @@ def find_missing_objects(self, haves, wants, shallow=None, progress=None, if commit_sha in haves: continue try: - for (type, type_data) in ret[commit_sha]: + for type, type_data in ret[commit_sha]: if type != "commit": raise AssertionError(f"Type was {type}, not commit") pending.add(type_data[0]) @@ -754,7 +831,7 @@ def find_missing_objects(self, haves, wants, shallow=None, progress=None, shallows = set() for commit_sha in shallow or set(): try: - for (type, type_data) in ret[commit_sha]: + for type, type_data in ret[commit_sha]: if type != "commit": raise AssertionError(f"Type was {type}, not commit") shallows.add(type_data[0]) @@ -773,8 +850,7 @@ def find_missing_objects(self, haves, wants, shallow=None, progress=None, except errors.NoSuchRevision: continue tree = self.tree_cache.revision_tree(revid) - for path, obj in self._revision_to_objects( - rev, tree, lossy=lossy): + for path, obj in self._revision_to_objects(rev, tree, lossy=lossy): if obj.id not in seen: yield (obj.id, (obj.type_num, path)) seen.add(obj.id) @@ -782,11 +858,13 @@ def find_missing_objects(self, haves, wants, shallow=None, progress=None, def add_thin_pack(self): import os import tempfile + fd, path = tempfile.mkstemp(suffix=".pack") - f = os.fdopen(fd, 'wb') + f = os.fdopen(fd, "wb") def commit(): from .fetch import import_git_objects + os.fsync(fd) f.close() if os.path.getsize(path) == 0: @@ -798,14 +876,18 @@ def commit(): with self.repository.lock_write(): self.repository.start_write_group() try: - import_git_objects(self.repository, self.mapping, - p.iterobjects(get_raw=self.get_raw), - self.object_store) + import_git_objects( + self.repository, + self.mapping, + p.iterobjects(get_raw=self.get_raw), + self.object_store, + ) except BaseException: self.repository.abort_write_group() raise else: self.repository.commit_write_group() + return f, commit # The pack isn't kept around anyway, so no point diff --git a/breezy/git/pristine_tar.py b/breezy/git/pristine_tar.py index 8fa025471d..03987ead10 100644 --- a/breezy/git/pristine_tar.py +++ 
b/breezy/git/pristine_tar.py @@ -30,15 +30,15 @@ def revision_pristine_tar_data(rev): """Export the pristine tar data from a revision.""" - if 'deb-pristine-delta' in rev.properties: - uuencoded = rev.properties['deb-pristine-delta'] - kind = 'gz' - elif 'deb-pristine-delta-bz2' in rev.properties: - uuencoded = rev.properties['deb-pristine-delta-bz2'] - kind = 'bz2' - elif 'deb-pristine-delta-xz' in rev.properties: - uuencoded = rev.properties['deb-pristine-delta-xz'] - kind = 'xz' + if "deb-pristine-delta" in rev.properties: + uuencoded = rev.properties["deb-pristine-delta"] + kind = "gz" + elif "deb-pristine-delta-bz2" in rev.properties: + uuencoded = rev.properties["deb-pristine-delta-bz2"] + kind = "bz2" + elif "deb-pristine-delta-xz" in rev.properties: + uuencoded = rev.properties["deb-pristine-delta-xz"] + kind = "xz" else: raise KeyError(rev.revision_id) @@ -65,12 +65,10 @@ def read_git_pristine_tar_data(repo, filename): tree = get_pristine_tar_tree(repo) delta = tree[filename + b".delta"][1] gitid = tree[filename + b".id"][1] - return (repo.object_store[delta].data, - repo.object_store[gitid].data) + return (repo.object_store[delta].data, repo.object_store[gitid].data) -def store_git_pristine_tar_data(repo, filename, delta, gitid, - message=None, **kwargs): +def store_git_pristine_tar_data(repo, filename, delta, gitid, message=None, **kwargs): """Add pristine tar data to a Git repository. :param repo: Git repository to add data to @@ -82,9 +80,7 @@ def store_git_pristine_tar_data(repo, filename, delta, gitid, delta_name = filename + b".delta" id_ob = Blob.from_string(gitid) id_name = filename + b".id" - objects = [ - (delta_ob, delta_name), - (id_ob, id_name)] + objects = [(delta_ob, delta_name), (id_ob, id_name)] tree = get_pristine_tar_tree(repo) tree.add(delta_name, stat.S_IFREG | 0o644, delta_ob.id) tree.add(id_name, stat.S_IFREG | 0o644, id_ob.id) @@ -95,6 +91,7 @@ def store_git_pristine_tar_data(repo, filename, delta, gitid, objects.append((tree, "")) repo.object_store.add_objects(objects) if message is None: - message = b'pristine-tar data for %s' % filename - return repo.do_commit(ref=b'refs/heads/pristine-tar', tree=tree.id, - message=message, **kwargs) + message = b"pristine-tar data for %s" % filename + return repo.do_commit( + ref=b"refs/heads/pristine-tar", tree=tree.id, message=message, **kwargs + ) diff --git a/breezy/git/push.py b/breezy/git/push.py index b1221f92e0..e65fc2d069 100644 --- a/breezy/git/push.py +++ b/breezy/git/push.py @@ -21,12 +21,11 @@ class GitPushResult(PushResult): - def _lookup_revno(self, revid): from .branch import _quick_lookup_revno + try: - return _quick_lookup_revno(self.source_branch, self.target_branch, - revid) + return _quick_lookup_revno(self.source_branch, self.target_branch, revid) except GitSmartRemoteNotSupported: return None @@ -70,8 +69,7 @@ def import_revision(self, revid, lossy): tree = self._object_store.tree_cache.revision_tree(revid) rev = self.source.get_revision(revid) commit = None - for path, obj in self._object_store._revision_to_objects( - rev, tree, lossy): + for path, obj in self._object_store._revision_to_objects(rev, tree, lossy): if obj.type_name == b"commit": commit = obj self._pending.append((obj, path)) @@ -87,7 +85,6 @@ def __iter__(self): class ObjectStoreParentsProvider: - def __init__(self, store): self._store = store @@ -113,5 +110,6 @@ def remote_divergence(old_sha, new_sha, store): if not isinstance(new_sha, bytes): raise TypeError(new_sha) from ..graph import Graph + graph = 
Graph(ObjectStoreParentsProvider(store)) return not graph.is_ancestor(old_sha, new_sha) diff --git a/breezy/git/refs.py b/breezy/git/refs.py index 359234a31d..f5cea79607 100644 --- a/breezy/git/refs.py +++ b/breezy/git/refs.py @@ -71,18 +71,17 @@ def ref_to_branch_name(ref): if ref is None: return ref if ref.startswith(LOCAL_BRANCH_PREFIX): - return ref[len(LOCAL_BRANCH_PREFIX):].decode('utf-8') + return ref[len(LOCAL_BRANCH_PREFIX) :].decode("utf-8") raise ValueError(f"unable to map ref {ref} back to branch name") def ref_to_tag_name(ref): if ref.startswith(LOCAL_TAG_PREFIX): - return ref[len(LOCAL_TAG_PREFIX):].decode("utf-8") + return ref[len(LOCAL_TAG_PREFIX) :].decode("utf-8") raise ValueError(f"unable to map ref {ref} back to tag name") class BazaarRefsContainer(RefsContainer): - def __init__(self, dir, object_store): self.dir = dir self.object_store = object_store @@ -95,8 +94,8 @@ def set_symbolic_ref(self, name, other): pass # FIXME: Switch default branch else: raise NotImplementedError( - "Symbolic references not supported for anything other than " - "HEAD") + "Symbolic references not supported for anything other than " "HEAD" + ) def _get_revid_by_tag_name(self, tag_name): for branch in self.dir.list_branches(): diff --git a/breezy/git/remote.py b/breezy/git/remote.py index 3bcbc93d07..b6d167e62b 100644 --- a/breezy/git/remote.py +++ b/breezy/git/remote.py @@ -78,16 +78,14 @@ from .repository import GitRepository, GitRepositoryFormat # urlparse only supports a limited number of schemes by default -register_urlparse_netloc_protocol('git') -register_urlparse_netloc_protocol('git+ssh') +register_urlparse_netloc_protocol("git") +register_urlparse_netloc_protocol("git+ssh") class GitPushResult(PushResult): - def _lookup_revno(self, revid): try: - return _quick_lookup_revno(self.source_branch, self.target_branch, - revid) + return _quick_lookup_revno(self.source_branch, self.target_branch, revid) except GitSmartRemoteNotSupported: return None @@ -116,23 +114,22 @@ def split_git_url(url): path = urlparse.unquote(parsed_url.path) if path.startswith("/~"): path = path[1:] - return ((parsed_url.hostname or '', parsed_url.port, parsed_url.username, path)) + return (parsed_url.hostname or "", parsed_url.port, parsed_url.username, path) class RemoteGitError(BzrError): - _fmt = "Remote server error: %(msg)s" class ProtectedBranchHookDeclined(BzrError): - _fmt = "Protected branch hook declined" class HeadUpdateFailed(BzrError): - - _fmt = ("Unable to update remote HEAD branch. To update the master " - "branch, specify the URL %(base_url)s,branch=master.") + _fmt = ( + "Unable to update remote HEAD branch. To update the master " + "branch, specify the URL %(base_url)s,branch=master." + ) def __init__(self, base_url): super().__init__() @@ -146,46 +143,49 @@ def parse_git_error(url, message): :param message: Message sent by the remote git server """ message = str(message).strip() - if (message.startswith("Could not find Repository ") - or message == 'Repository not found.' - or (message.startswith('Repository ') and - message.endswith(' not found.'))): + if ( + message.startswith("Could not find Repository ") + or message == "Repository not found." 
+ or (message.startswith("Repository ") and message.endswith(" not found.")) + ): return NotBranchError(url, message) if message == "HEAD failed to update": base_url = urlutils.strip_segment_parameters(url) return HeadUpdateFailed(base_url) - if message.startswith('access denied or repository not exported:'): - extra, path = message.split(':', 1) + if message.startswith("access denied or repository not exported:"): + extra, path = message.split(":", 1) return PermissionDenied(path.strip(), extra) - if message.endswith('You are not allowed to push code to this project.'): + if message.endswith("You are not allowed to push code to this project."): return PermissionDenied(url, message) - if message.endswith(' does not appear to be a git repository'): + if message.endswith(" does not appear to be a git repository"): return NotBranchError(url, message) - if message == 'A repository for this project does not exist yet.': + if message == "A repository for this project does not exist yet.": return NotBranchError(url, message) - if message == 'pre-receive hook declined': + if message == "pre-receive hook declined": return PermissionDenied(url, message) - if re.match('(.+) is not a valid repository name', - message.splitlines()[0]): + if re.match("(.+) is not a valid repository name", message.splitlines()[0]): return NotBranchError(url, message) if message == ( - 'GitLab: You are not allowed to push code to protected branches ' - 'on this project.'): + "GitLab: You are not allowed to push code to protected branches " + "on this project." + ): return PermissionDenied(url, message) - m = re.match(r'Permission to ([^ ]+) denied to ([^ ]+)\.', message) + m = re.match(r"Permission to ([^ ]+) denied to ([^ ]+)\.", message) if m: - return PermissionDenied(m.group(1), f'denied to {m.group(2)}') - if message == 'Host key verification failed.': - return TransportError('Host key verification failed') - if message == '[Errno 104] Connection reset by peer': + return PermissionDenied(m.group(1), f"denied to {m.group(2)}") + if message == "Host key verification failed.": + return TransportError("Host key verification failed") + if message == "[Errno 104] Connection reset by peer": return ConnectionResetError(message) - if message == 'The remote server unexpectedly closed the connection.': + if message == "The remote server unexpectedly closed the connection.": return TransportError(message) - m = re.match(r'unexpected http resp ([0-9]+) for (.*)', message) + m = re.match(r"unexpected http resp ([0-9]+) for (.*)", message) if m: # TODO(jelmer): Have dulwich raise an exception and look at that instead? 
- return UnexpectedHttpStatus(path=m.group(2), code=int(m.group(1)), extra=message) - if message == 'protected branch hook declined': + return UnexpectedHttpStatus( + path=m.group(2), code=int(m.group(1)), extra=message + ) + if message == "protected branch hook declined": return ProtectedBranchHookDeclined(msg=message) # Don't know, just return it to the user as-is return RemoteGitError(message) @@ -197,32 +197,32 @@ def parse_git_hangup(url, e): :param url: URL of the remote repository :param e: A HangupException """ - stderr_lines = getattr(e, 'stderr_lines', None) + stderr_lines = getattr(e, "stderr_lines", None) if not stderr_lines: - return ConnectionResetError('Connection closed early') - if all(line.startswith(b'remote: ') for line in stderr_lines): - stderr_lines = [ - line[len(b'remote: '):] for line in stderr_lines] + return ConnectionResetError("Connection closed early") + if all(line.startswith(b"remote: ") for line in stderr_lines): + stderr_lines = [line[len(b"remote: ") :] for line in stderr_lines] interesting_lines = [ - line for line in stderr_lines - if line and line.replace(b'=', b'')] + line for line in stderr_lines if line and line.replace(b"=", b"") + ] if len(interesting_lines) == 1: interesting_line = interesting_lines[0] - return parse_git_error( - url, interesting_line.decode('utf-8', 'surrogateescape')) - return RemoteGitError( - b'\n'.join(stderr_lines).decode('utf-8', 'surrogateescape')) + return parse_git_error(url, interesting_line.decode("utf-8", "surrogateescape")) + return RemoteGitError(b"\n".join(stderr_lines).decode("utf-8", "surrogateescape")) class GitSmartTransport(Transport): - def __init__(self, url, _client=None): Transport.__init__(self, url) - (self._host, self._port, self._username, self._path) = \ - split_git_url(url) - if debug.debug_flag_enabled('transport'): - trace.mutter('host: %r, user: %r, port: %r, path: %r', - self._host, self._username, self._port, self._path) + (self._host, self._port, self._username, self._path) = split_git_url(url) + if debug.debug_flag_enabled("transport"): + trace.mutter( + "host: %r, user: %r, port: %r, path: %r", + self._host, + self._username, + self._port, + self._path, + ) self._client = _client self._stripped_path = self._path.rsplit(",", 1)[0] @@ -255,23 +255,22 @@ def clone(self, offset=None): class TCPGitSmartTransport(GitSmartTransport): - - _scheme = 'git' + _scheme = "git" def _get_client(self): if self._client is not None: ret = self._client self._client = None return ret - if self._host == '': + if self._host == "": # return dulwich.client.LocalGitClient() return dulwich.client.SubprocessGitClient() return dulwich.client.TCPGitClient( - self._host, self._port, report_activity=self._report_activity) + self._host, self._port, report_activity=self._report_activity + ) class SSHSocketWrapper: - def __init__(self, sock): self.sock = sock @@ -286,17 +285,17 @@ def can_read(self): class DulwichSSHVendor(dulwich.client.SSHVendor): - def __init__(self): from ..transport import ssh + self.bzr_ssh_vendor = ssh._get_ssh_vendor() def run_command(self, host, command, username=None, port=None): connection = self.bzr_ssh_vendor.connect_ssh( - username=username, password=None, port=port, host=host, - command=command) + username=username, password=None, port=port, host=host, command=command + ) (kind, io_object) = connection.get_sock_or_pipes() - if kind == 'socket': + if kind == "socket": return SSHSocketWrapper(io_object) else: raise AssertionError(f"Unknown io object kind {kind!r}'") @@ -306,8 +305,7 @@ def 
run_command(self, host, command, username=None, port=None): class SSHGitSmartTransport(GitSmartTransport): - - _scheme = 'git+ssh' + _scheme = "git+ssh" def _get_path(self): path = self._stripped_path @@ -322,34 +320,36 @@ def _get_client(self): return ret location_config = config.LocationConfig(self.base) client = dulwich.client.SSHGitClient( - self._host, self._port, self._username, - report_activity=self._report_activity) + self._host, + self._port, + self._username, + report_activity=self._report_activity, + ) # Set up alternate pack program paths - upload_pack = location_config.get_user_option('git_upload_pack') + upload_pack = location_config.get_user_option("git_upload_pack") if upload_pack: client.alternative_paths["upload-pack"] = upload_pack - receive_pack = location_config.get_user_option('git_receive_pack') + receive_pack = location_config.get_user_option("git_receive_pack") if receive_pack: client.alternative_paths["receive-pack"] = receive_pack return client class RemoteGitBranchFormat(GitBranchFormat): - def get_format_description(self): - return 'Remote Git Branch' + return "Remote Git Branch" @property def _matchingcontroldir(self): return RemoteGitControlDirFormat() - def initialize(self, a_controldir, name=None, repository=None, - append_revisions_only=None): + def initialize( + self, a_controldir, name=None, repository=None, append_revisions_only=None + ): raise UninitializableFormat(self) class DefaultProgressReporter: - _GIT_PROGRESS_PARTIAL_RE = re.compile(r"(.*?): +(\d+)% \((\d+)/(\d+)\)") _GIT_PROGRESS_TOTAL_RE = re.compile(r"(.*?): (\d+)") @@ -359,11 +359,11 @@ def __init__(self, pb): def progress(self, text): text = text.rstrip(b"\r\n") - text = text.decode('utf-8', 'surrogateescape') - if text.lower().startswith('error: '): - error = text[len('error: '):] + text = text.decode("utf-8", "surrogateescape") + if text.lower().startswith("error: "): + error = text[len("error: ") :] self.errors.append(error) - trace.show_error('git: %s', error) + trace.show_error("git: %s", error) else: trace.mutter("git: %s", text) g = self._GIT_PROGRESS_PARTIAL_RE.match(text) @@ -379,11 +379,10 @@ def progress(self, text): trace.note("%s", text) -_LOCK_REF_ERROR_MATCHER = re.compile('cannot lock ref \'(.*)\': (.*)') +_LOCK_REF_ERROR_MATCHER = re.compile("cannot lock ref '(.*)': (.*)") class RemoteGitDir(GitDir): - def __init__(self, transport, format, client, client_path): self._format = format self.root_transport = transport @@ -398,27 +397,42 @@ def __init__(self, transport, format, client, client_path): def _gitrepository_class(self): return RemoteGitRepository - def archive(self, format, committish, write_data, progress=None, - write_error=None, subdirs=None, prefix=None, recurse_nested=False): + def archive( + self, + format, + committish, + write_data, + progress=None, + write_error=None, + subdirs=None, + prefix=None, + recurse_nested=False, + ): if recurse_nested: - raise NotImplementedError('recurse_nested is not yet supported') + raise NotImplementedError("recurse_nested is not yet supported") if progress is None: pb = ui.ui_factory.nested_progress_bar() progress = DefaultProgressReporter(pb).progress else: pb = None + def progress_wrapper(message): - if message.startswith(b"fatal: Unknown archive format \'"): - format = message.strip()[len(b"fatal: Unknown archive format '"):-1] - raise errors.NoSuchExportFormat(format.decode('ascii')) + if message.startswith(b"fatal: Unknown archive format '"): + format = message.strip()[len(b"fatal: Unknown archive format '") : -1] + 
raise errors.NoSuchExportFormat(format.decode("ascii")) return progress(message) + try: self._client.archive( - self._client_path, committish, write_data, progress_wrapper, + self._client_path, + committish, + write_data, + progress_wrapper, write_error, - format=(format.encode('ascii') if format else None), + format=(format.encode("ascii") if format else None), subdirs=subdirs, - prefix=(encode_git_path(prefix) if prefix else None)) + prefix=(encode_git_path(prefix) if prefix else None), + ) except HangupException as e: raise parse_git_hangup(self.transport.external_url(), e) from e except GitProtocolError as e: @@ -427,8 +441,7 @@ def progress_wrapper(message): if pb is not None: pb.finished() - def fetch_pack(self, determine_wants, graph_walker, pack_data, - progress=None): + def fetch_pack(self, determine_wants, graph_walker, pack_data, progress=None): if progress is None: pb = ui.ui_factory.nested_progress_bar() progress = DefaultProgressReporter(pb).progress @@ -436,12 +449,11 @@ def fetch_pack(self, determine_wants, graph_walker, pack_data, pb = None try: result = self._client.fetch_pack( - self._client_path, determine_wants, graph_walker, pack_data, - progress) + self._client_path, determine_wants, graph_walker, pack_data, progress + ) if result.refs is None: result.refs = {} - self._refs = remote_refs_dict_to_container( - result.refs, result.symrefs) + self._refs = remote_refs_dict_to_container(result.refs, result.symrefs) return result except HangupException as e: raise parse_git_hangup(self.transport.external_url(), e) from e @@ -464,10 +476,14 @@ def get_changed_refs_wrapper(remote_refs): if self._refs is not None: update_refs_container(self._refs, remote_refs) return get_changed_refs(remote_refs) + try: result = self._client.send_pack( - self._client_path, get_changed_refs_wrapper, - generate_pack_data, progress) + self._client_path, + get_changed_refs_wrapper, + generate_pack_data, + progress, + ) for ref, msg in list(result.ref_status.items()): if msg: result.ref_status[ref] = RemoteGitError(msg=msg) @@ -475,7 +491,9 @@ def get_changed_refs_wrapper(remote_refs): for error in progress_reporter.errors: m = _LOCK_REF_ERROR_MATCHER.match(error) if m: - result.ref_status[m.group(1)] = LockContention(m.group(1), m.group(2)) + result.ref_status[m.group(1)] = LockContention( + m.group(1), m.group(2) + ) return result except HangupException as e: raise parse_git_hangup(self.transport.external_url(), e) from e @@ -485,14 +503,14 @@ def get_changed_refs_wrapper(remote_refs): if pb is not None: pb.finished() - def create_branch(self, name=None, repository=None, - append_revisions_only=None, ref=None): + def create_branch( + self, name=None, repository=None, append_revisions_only=None, ref=None + ): refname = self._get_selected_ref(name, ref) - if refname != b'HEAD' and refname in self.get_refs_container(): + if refname != b"HEAD" and refname in self.get_refs_container(): raise AlreadyBranchError(self.user_url) - ref_chain, sha = self.get_refs_container().follow( - self._get_selected_ref(name)) - if ref_chain and ref_chain[0] == b'HEAD' and len(ref_chain) > 1: + ref_chain, sha = self.get_refs_container().follow(self._get_selected_ref(name)) + if ref_chain and ref_chain[0] == b"HEAD" and len(ref_chain) > 1: refname = ref_chain[1] repo = self.open_repository() return RemoteGitBranch(self, repo, refname, sha) @@ -509,6 +527,7 @@ def get_changed_refs(old_refs): def generate_pack_data(have, want, ofs_delta=False, progress=None): return pack_objects_to_data([]) + result = 
self.send_pack(get_changed_refs, generate_pack_data) error = result.ref_status.get(refname) if error: @@ -543,20 +562,27 @@ def get_branch_reference(self, name=None): return None target_ref = ref_chain[1] from .refs import ref_to_branch_name + try: branch_name = ref_to_branch_name(target_ref) except ValueError: - params = {'ref': urlutils.quote(target_ref.decode('utf-8'), '')} + params = {"ref": urlutils.quote(target_ref.decode("utf-8"), "")} else: - if branch_name != '': - params = {'branch': urlutils.quote(branch_name, '')} + if branch_name != "": + params = {"branch": urlutils.quote(branch_name, "")} else: params = {} - return urlutils.join_segment_parameters(self.user_url.rstrip('/'), params) + return urlutils.join_segment_parameters(self.user_url.rstrip("/"), params) - def open_branch(self, name=None, unsupported=False, - ignore_fallbacks=False, ref=None, possible_transports=None, - nascent_ok=False): + def open_branch( + self, + name=None, + unsupported=False, + ignore_fallbacks=False, + ref=None, + possible_transports=None, + nascent_ok=False, + ): repo = self.open_repository() ref = self._get_selected_ref(name, ref) try: @@ -566,8 +592,7 @@ def open_branch(self, name=None, unsupported=False, except NotGitRepository as err: raise NotBranchError(self.root_transport.base, controldir=self) from err if not nascent_ok and sha is None: - raise NotBranchError( - self.root_transport.base, controldir=self) + raise NotBranchError(self.root_transport.base, controldir=self) return RemoteGitBranch(self, repo, ref_chain[-1], sha) def open_workingtree(self, recommend_upgrade=False): @@ -582,16 +607,23 @@ def get_peeled(self, name): def get_refs_container(self): if self._refs is not None: return self._refs - result = self.fetch_pack(lambda x: None, None, - lambda x: None, - lambda x: trace.mutter(f"git: {x}")) - self._refs = remote_refs_dict_to_container( - result.refs, result.symrefs) + result = self.fetch_pack( + lambda x: None, None, lambda x: None, lambda x: trace.mutter(f"git: {x}") + ) + self._refs = remote_refs_dict_to_container(result.refs, result.symrefs) return self._refs - def push_branch(self, source, revision_id=None, overwrite=False, - remember=False, create_prefix=False, lossy=False, - name=None, tag_selector=None): + def push_branch( + self, + source, + revision_id=None, + overwrite=False, + remember=False, + create_prefix=False, + lossy=False, + name=None, + tag_selector=None, + ): """Push the source branch into this ControlDir.""" if revision_id is None: # No revision supplied by the user, default to the branch @@ -622,7 +654,8 @@ def push_branch(self, source, revision_id=None, overwrite=False, if isinstance(source, GitBranch) and lossy: raise errors.LossyPushToSameVCS(source.controldir, self) source_store = get_object_store(source.repository) - fetch_tags = source.get_config_stack().get('branch.fetch_tags') + fetch_tags = source.get_config_stack().get("branch.fetch_tags") + def get_changed_refs(remote_refs): if self._refs is not None: update_refs_container(self._refs, remote_refs) @@ -636,12 +669,14 @@ def get_changed_refs(remote_refs): new_sha = repo.lookup_bzr_revision_id(revision_id)[0] except errors.NoSuchRevision as err: raise errors.NoRoundtrippingSupport( - source, self.open_branch(name=name, nascent_ok=True)) from err + source, self.open_branch(name=name, nascent_ok=True) + ) from err old_sha = remote_refs.get(actual_refname) if not overwrite: if remote_divergence(old_sha, new_sha, source_store): raise DivergedBranches( - source, self.open_branch(name, nascent_ok=True)) + 
source, self.open_branch(name, nascent_ok=True) + ) ret[actual_refname] = new_sha if fetch_tags: for tagname, revid in source.tags.get_tag_dict().items(): @@ -663,25 +698,36 @@ def get_changed_refs(remote_refs): continue ret[tag_name_to_ref(tagname)] = new_sha return ret + with source_store.lock_read(): - def generate_pack_data(have, want, progress=None, - ofs_delta=True): - git_repo = getattr(source.repository, '_git', None) + + def generate_pack_data(have, want, progress=None, ofs_delta=True): + git_repo = getattr(source.repository, "_git", None) if git_repo: shallow = git_repo.get_shallow() else: shallow = None if lossy: return source_store.generate_lossy_pack_data( - have, want, shallow=shallow, - progress=progress, ofs_delta=ofs_delta) + have, + want, + shallow=shallow, + progress=progress, + ofs_delta=ofs_delta, + ) elif shallow: return source_store.generate_pack_data( - have, want, shallow=shallow, - progress=progress, ofs_delta=ofs_delta) + have, + want, + shallow=shallow, + progress=progress, + ofs_delta=ofs_delta, + ) else: return source_store.generate_pack_data( - have, want, progress=progress, ofs_delta=ofs_delta) + have, want, progress=progress, ofs_delta=ofs_delta + ) + dw_result = self.send_pack(get_changed_refs, generate_pack_data) new_refs = dw_result.refs error = dw_result.ref_status.get(actual_refname) @@ -689,10 +735,10 @@ def generate_pack_data(have, want, progress=None, raise error for ref, error in dw_result.ref_status.items(): if error: - trace.warning('unable to open ref %s: %s', - ref, error) + trace.warning("unable to open ref %s: %s", ref, error) push_result.new_revid = repo.lookup_foreign_revision_id( - new_refs[actual_refname]) + new_refs[actual_refname] + ) if old_sha is not None: push_result.old_revid = repo.lookup_foreign_revision_id(old_sha) else: @@ -703,15 +749,14 @@ def generate_pack_data(have, want, progress=None, if old_sha is not None: push_result.branch_push_result = GitBranchPushResult() push_result.branch_push_result.source_branch = source - push_result.branch_push_result.target_branch = ( - push_result.target_branch) + push_result.branch_push_result.target_branch = push_result.target_branch push_result.branch_push_result.local_branch = None - push_result.branch_push_result.master_branch = ( - push_result.target_branch) + push_result.branch_push_result.master_branch = push_result.target_branch push_result.branch_push_result.old_revid = push_result.old_revid push_result.branch_push_result.new_revid = push_result.new_revid push_result.branch_push_result.new_original_revid = ( - push_result.new_original_revid) + push_result.new_original_revid + ) if source.get_push_location() is None or remember: source.set_push_location(push_result.target_branch.base) return push_result @@ -722,23 +767,22 @@ def _find_commondir(self): class EmptyObjectStoreIterator(dict): - def iterobjects(self): return [] class TemporaryPackIterator(Pack): - def __init__(self, path, resolve_ext_ref): - super().__init__( - path, resolve_ext_ref=resolve_ext_ref) + super().__init__(path, resolve_ext_ref=resolve_ext_ref) self._idx_load = lambda: self._idx_load_or_generate(self._idx_path) def _idx_load_or_generate(self, path): if not os.path.exists(path): with ui.ui_factory.nested_progress_bar() as pb: + def report_progress(cur, total): pb.update("generating index", cur, total) + self.data.create_index(path, progress=report_progress) return load_pack_index(path) @@ -752,7 +796,6 @@ def __del__(self): class BzrGitHttpClient(dulwich.client.HttpGitClient): - def __init__(self, transport, 
*args, **kwargs): self.transport = transport url = urlutils.URL.from_string(transport.external_url()) @@ -774,8 +817,7 @@ def archive( ): raise GitSmartRemoteNotSupported(self.archive, self) - def _http_request(self, url, headers=None, data=None, - allow_compression=False): + def _http_request(self, url, headers=None, data=None, allow_compression=False): """Perform HTTP request. :param url: Request URL. @@ -787,25 +829,27 @@ def _http_request(self, url, headers=None, data=None, method for the response data. """ if is_github_url(url): - headers['User-agent'] = user_agent_for_github() + headers["User-agent"] = user_agent_for_github() headers["Pragma"] = "no-cache" response = self.transport.request( - ('GET' if data is None else 'POST'), + ("GET" if data is None else "POST"), url, body=data, - headers=headers, retries=8) + headers=headers, + retries=8, + ) if response.status == 404: raise NotGitRepository() elif response.status != 200: - raise GitProtocolError("unexpected http resp %d for %s" % - (response.status, url)) + raise GitProtocolError( + "unexpected http resp %d for %s" % (response.status, url) + ) read = response.read class WrapResponse: - def __init__(self, response): self._response = response self.status = response.status @@ -855,7 +899,7 @@ def open(self, transport, _found=None): client = transport._get_client() elif split_url.scheme in ("http", "https"): client = BzrGitHttpClient(transport) - elif split_url.scheme in ('file', ): + elif split_url.scheme in ("file",): client = dulwich.client.LocalGitClient() else: raise NotBranchError(transport.base) @@ -874,15 +918,24 @@ def supports_transport(self, transport): external_url = transport.external_url() except InProcessTransport as err: raise NotBranchError(path=transport.base) from err - return (external_url.startswith("http:") - or external_url.startswith("https:") - or external_url.startswith("git+") - or external_url.startswith("git:")) + return ( + external_url.startswith("http:") + or external_url.startswith("https:") + or external_url.startswith("git+") + or external_url.startswith("git:") + ) class GitRemoteRevisionTree(RevisionTree): - - def archive(self, format, name, root=None, subdir=None, force_mtime=None, recurse_nested=False): + def archive( + self, + format, + name, + root=None, + subdir=None, + force_mtime=None, + recurse_nested=False, + ): """Create an archive of this tree. :param format: Format name (e.g. 'tar') @@ -893,16 +946,19 @@ def archive(self, format, name, root=None, subdir=None, force_mtime=None, recurs """ if recurse_nested: # TODO(jelmer): Parse .gitmodules from archive afterwards? - raise NotImplementedError('recurse_nested is not yet supported') - commit = self._repository.lookup_bzr_revision_id( - self.get_revision_id())[0] + raise NotImplementedError("recurse_nested is not yet supported") + commit = self._repository.lookup_bzr_revision_id(self.get_revision_id())[0] from tempfile import SpooledTemporaryFile - f = SpooledTemporaryFile(max_size=PACK_SPOOL_FILE_MAX_SIZE, prefix='incoming-') + + f = SpooledTemporaryFile(max_size=PACK_SPOOL_FILE_MAX_SIZE, prefix="incoming-") # git-upload-archive(1) generaly only supports refs. So let's see if we # can find one. 
reverse_refs = { - v: k for (k, v) in - self._repository.controldir.get_refs_container().as_dict().items()} + v: k + for (k, v) in self._repository.controldir.get_refs_container() + .as_dict() + .items() + } try: committish = reverse_refs[commit] except KeyError: @@ -910,9 +966,12 @@ def archive(self, format, name, root=None, subdir=None, force_mtime=None, recurs # Let's hope for the best. committish = commit self._repository.archive( - format, committish, f.write, + format, + committish, + f.write, subdirs=([subdir] if subdir else None), - prefix=(root + '/') if root else '') + prefix=(root + "/") if root else "", + ) f.seek(0) return osutils.file_iterator(f) @@ -930,7 +989,6 @@ def list_files(self, include_root=False, from_dir=None, recursive=True): class RemoteGitRepository(GitRepository): - supports_random_access = False @property @@ -943,26 +1001,29 @@ def get_parent_map(self, revids): def archive(self, *args, **kwargs): return self.controldir.archive(*args, **kwargs) - def fetch_pack(self, determine_wants, graph_walker, pack_data, - progress=None): + def fetch_pack(self, determine_wants, graph_walker, pack_data, progress=None): return self.controldir.fetch_pack( - determine_wants, graph_walker, pack_data, progress) + determine_wants, graph_walker, pack_data, progress + ) def send_pack(self, get_changed_refs, generate_pack_data): return self.controldir.send_pack(get_changed_refs, generate_pack_data) - def fetch_objects(self, determine_wants, graph_walker, resolve_ext_ref, - progress=None): + def fetch_objects( + self, determine_wants, graph_walker, resolve_ext_ref, progress=None + ): import tempfile + fd, path = tempfile.mkstemp(suffix=".pack") try: - self.fetch_pack(determine_wants, graph_walker, - lambda x: os.write(fd, x), progress) + self.fetch_pack( + determine_wants, graph_walker, lambda x: os.write(fd, x), progress + ) finally: os.close(fd) if os.path.getsize(path) == 0: return EmptyObjectStoreIterator() - return TemporaryPackIterator(path[:-len(".pack")], resolve_ext_ref) + return TemporaryPackIterator(path[: -len(".pack")], resolve_ext_ref) def lookup_bzr_revision_id(self, bzr_revid, mapping=None): # This won't work for any round-tripped bzr revisions, but it's a @@ -990,7 +1051,6 @@ def has_revisions(self, revids): class RemoteGitTagDict(GitTags): - def set_tag(self, name, revid): sha = self.branch.lookup_bzr_revision_id(revid)[0] self._set_ref(name, sha) @@ -1010,19 +1070,17 @@ def get_changed_refs(old_refs): def generate_pack_data(have, want, ofs_delta=False, progress=None): return pack_objects_to_data([]) - result = self.repository.send_pack( - get_changed_refs, generate_pack_data) + + result = self.repository.send_pack(get_changed_refs, generate_pack_data) error = result.ref_status.get(ref) if error: raise error class RemoteGitBranch(GitBranch): - def __init__(self, controldir, repository, ref, sha): self._sha = sha - super().__init__(controldir, repository, ref, - RemoteGitBranchFormat()) + super().__init__(controldir, repository, ref, RemoteGitBranchFormat()) def last_revision_info(self): raise GitSmartRemoteNotSupported(self.last_revision_info, self) @@ -1083,15 +1141,16 @@ def _iter_tag_refs(self): def set_last_revision_info(self, revno, revid): self.generate_revision_history(revid) - def generate_revision_history(self, revision_id, last_rev=None, - other_branch=None): + def generate_revision_history(self, revision_id, last_rev=None, other_branch=None): sha = self.lookup_bzr_revision_id(revision_id)[0] + def get_changed_refs(old_refs): return {self.ref: sha} + def 
generate_pack_data(have, want, ofs_delta=False, progress=None): return pack_objects_to_data([]) - result = self.repository.send_pack( - get_changed_refs, generate_pack_data) + + result = self.repository.send_pack(get_changed_refs, generate_pack_data) error = result.ref_status.get(self.ref) if error: raise error diff --git a/breezy/git/repository.py b/breezy/git/repository.py index 123d2863fb..c07ce0eabf 100644 --- a/breezy/git/repository.py +++ b/breezy/git/repository.py @@ -34,7 +34,6 @@ class GitCheck(check.Check): - def __init__(self, repository, check_repo=True): self.repository = repository self.check_repo = check_repo @@ -45,13 +44,12 @@ def __init__(self, repository, check_repo=True): def check(self, callback_refs=None, check_repo=True): if callback_refs is None: callback_refs = {} - with self.repository.lock_read(), \ - ui.ui_factory.nested_progress_bar() as self.progress: + with self.repository.lock_read(), ui.ui_factory.nested_progress_bar() as self.progress: shas = set(self.repository._git.object_store) self.object_count = len(shas) # TODO(jelmer): Check more things for i, sha in enumerate(shas): - self.progress.update('checking objects', i, self.object_count) + self.progress.update("checking objects", i, self.object_count) o = self.repository._git.object_store[sha] try: o.check() @@ -59,28 +57,32 @@ def check(self, callback_refs=None, check_repo=True): self.problems.append((sha, e)) def _report_repo_results(self, verbose): - trace.note('checked repository {} format {}'.format( - self.repository.user_url, - self.repository._format)) - trace.note('%6d objects', self.object_count) + trace.note( + "checked repository {} format {}".format( + self.repository.user_url, self.repository._format + ) + ) + trace.note("%6d objects", self.object_count) for sha, problem in self.problems: - trace.note('%s: %s', sha, problem) + trace.note("%s: %s", sha, problem) def report_results(self, verbose): if self.check_repo: self._report_repo_results(verbose) -for optimiser in ['InterRemoteGitNonGitRepository', - 'InterLocalGitNonGitRepository', - 'InterLocalGitLocalGitRepository', - 'InterLocalGitRemoteGitRepository', - 'InterRemoteGitLocalGitRepository', - 'InterToLocalGitRepository', - 'InterToRemoteGitRepository', - ]: +for optimiser in [ + "InterRemoteGitNonGitRepository", + "InterLocalGitNonGitRepository", + "InterLocalGitLocalGitRepository", + "InterLocalGitRemoteGitRepository", + "InterRemoteGitLocalGitRepository", + "InterToLocalGitRepository", + "InterToRemoteGitRepository", +]: repository.InterRepository.register_lazy_optimiser( - 'breezy.git.interrepo', optimiser) + "breezy.git.interrepo", optimiser + ) class GitRepository(ForeignRepository): @@ -91,15 +93,15 @@ class GitRepository(ForeignRepository): def __init__(self, gitdir): self._transport = gitdir.root_transport - super().__init__(GitRepositoryFormat(), - gitdir, control_files=None) + super().__init__(GitRepositoryFormat(), gitdir, control_files=None) self.base = gitdir.root_transport.base self._lock_mode = None self._lock_count = 0 def add_fallback_repository(self, basis_url): - raise errors.UnstackableRepositoryFormat(self._format, - self.control_transport.base) + raise errors.UnstackableRepositoryFormat( + self._format, self.control_transport.base + ) def is_shared(self): return False @@ -110,11 +112,11 @@ def get_physical_lock_status(self): def lock_write(self): """See Branch.lock_write().""" if self._lock_mode: - if self._lock_mode != 'w': + if self._lock_mode != "w": raise errors.ReadOnlyError(self) self._lock_count += 1 else: - 
self._lock_mode = 'w' + self._lock_mode = "w" self._lock_count = 1 self._transaction = transactions.WriteTransaction() return repository.RepositoryWriteLockResult(self.unlock, None) @@ -130,11 +132,11 @@ def leave_lock_in_place(self): def lock_read(self): if self._lock_mode: - if self._lock_mode not in ('r', 'w'): + if self._lock_mode not in ("r", "w"): raise AssertionError self._lock_count += 1 else: - self._lock_mode = 'r' + self._lock_mode = "r" self._lock_count = 1 self._transaction = transactions.ReadOnlyTransaction() return lock.LogicalLockResult(self.unlock) @@ -143,27 +145,28 @@ def lock_read(self): def unlock(self): if self._lock_count == 0: raise errors.LockNotHeld(self) - if self._lock_count == 1 and self._lock_mode == 'w': + if self._lock_count == 1 and self._lock_mode == "w": if self._write_group is not None: self.abort_write_group() self._lock_count -= 1 self._lock_mode = None raise errors.BzrError( - 'Must end write groups before releasing write locks.') + "Must end write groups before releasing write locks." + ) self._lock_count -= 1 if self._lock_count == 0: self._lock_mode = None transaction = self._transaction self._transaction = None transaction.finish() - if hasattr(self, '_git'): + if hasattr(self, "_git"): self._git.close() def is_write_locked(self): - return (self._lock_mode == 'w') + return self._lock_mode == "w" def is_locked(self): - return (self._lock_mode is not None) + return self._lock_mode is not None def get_transaction(self): """See Repository.get_transaction().""" @@ -175,6 +178,7 @@ def get_transaction(self): def reconcile(self, other=None, thorough=False): """Reconcile this repository.""" from ..reconcile import ReconcileResult + ret = ReconcileResult() ret.aborted = False return ret @@ -207,9 +211,18 @@ def __init__(self, gitdir): self._file_change_scanner = GitFileLastChangeScanner(self) self._transaction = None - def get_commit_builder(self, branch, parents, config, timestamp=None, - timezone=None, committer=None, revprops=None, - revision_id=None, lossy=False): + def get_commit_builder( + self, + branch, + parents, + config, + timestamp=None, + timezone=None, + committer=None, + revprops=None, + revision_id=None, + lossy=False, + ): """Obtain a CommitBuilder for this repository. :param branch: Branch to commit to. @@ -224,20 +237,28 @@ def get_commit_builder(self, branch, parents, config, timestamp=None, represented, when pushing to a foreign VCS """ from .commit import GitCommitBuilder + builder = GitCommitBuilder( - self, parents, config, timestamp, timezone, committer, revprops, - revision_id, lossy) + self, + parents, + config, + timestamp, + timezone, + committer, + revprops, + revision_id, + lossy, + ) self.start_write_group() return builder def _write_git_config(self, cs): f = BytesIO() cs.write_to_file(f) - self._git._put_named_file('config', f.getvalue()) + self._git._put_named_file("config", f.getvalue()) def get_file_graph(self): - return _mod_graph.Graph(GitFileParentProvider( - self._file_change_scanner)) + return _mod_graph.Graph(GitFileParentProvider(self._file_change_scanner)) def iter_files_bytes(self, desired_files): """Iterate through file versions. 
@@ -259,9 +280,8 @@ def iter_files_bytes(self, desired_files): triples """ per_revision = {} - for (file_id, revision_id, identifier) in desired_files: - per_revision.setdefault(revision_id, []).append( - (file_id, identifier)) + for file_id, revision_id, identifier in desired_files: + per_revision.setdefault(revision_id, []).append((file_id, identifier)) for revid, files in per_revision.items(): try: (commit_id, mapping) = self.lookup_bzr_revision_id(revid) @@ -279,8 +299,10 @@ def iter_files_bytes(self, desired_files): raise errors.RevisionNotPresent((fileid, revid), self) from err try: mode, item_id = tree_lookup_path( - self._git.object_store.__getitem__, root_tree, - encode_git_path(path)) + self._git.object_store.__getitem__, + root_tree, + encode_git_path(path), + ) obj = self._git.object_store[item_id] except KeyError as err: raise errors.RevisionNotPresent((fileid, revid), self) from err @@ -294,14 +316,13 @@ def iter_files_bytes(self, desired_files): def gather_stats(self, revid=None, committers=None): """See Repository.gather_stats().""" - result = super().gather_stats( - revid, committers) + result = super().gather_stats(revid, committers) revs = [] for sha in self._git.object_store: o = self._git.object_store[sha] if o.type_name == b"commit": revs.append(o.id) - result['revisions'] = len(revs) + result["revisions"] = len(revs) return result def _iter_revision_ids(self): @@ -345,8 +366,7 @@ def _get_parent_map_no_fallbacks(self, revids): def get_parent_map(self, revids, no_alternates=False): parent_map = {} for revision_id in revids: - parents = self._get_parents( - revision_id, no_alternates=no_alternates) + parents = self._get_parents(revision_id, no_alternates=no_alternates) if revision_id == _mod_revision.NULL_REVISION: parent_map[revision_id] = () continue @@ -436,6 +456,7 @@ def verify_revision_signature(self, revision_id, gpg_strategy): :return: gpg.SIGNATURE_VALID or a failed SIGNATURE_ value """ from breezy import gpg + with self.lock_read(): git_commit_id, mapping = self.lookup_bzr_revision_id(revision_id) try: @@ -450,7 +471,8 @@ def verify_revision_signature(self, revision_id, gpg_strategy): without_sig.gpgsig = None (result, key, plain_text) = gpg_strategy.verify( - without_sig.as_raw_string(), commit.gpgsig) + without_sig.as_raw_string(), commit.gpgsig + ) return (result, key) def lookup_bzr_revision_id(self, bzr_revid, mapping=None): @@ -462,8 +484,7 @@ def lookup_bzr_revision_id(self, bzr_revid, mapping=None): details """ try: - (git_sha, mapping) = mapping_registry.revision_id_bzr_to_foreign( - bzr_revid) + (git_sha, mapping) = mapping_registry.revision_id_bzr_to_foreign(bzr_revid) except errors.InvalidRevisionId as err: raise errors.NoSuchRevision(self, bzr_revid) from err else: @@ -478,7 +499,8 @@ def get_revision(self, revision_id): except KeyError as err: raise errors.NoSuchRevision(self, revision_id) from err revision, roundtrip_revid, verifiers = mapping.import_commit( - commit, self.lookup_foreign_revision_id, strict=False) + commit, self.lookup_foreign_revision_id, strict=False + ) if revision is None: raise AssertionError # FIXME: check verifiers ? 
@@ -494,7 +516,7 @@ def has_revision(self, revision_id): git_commit_id, mapping = self.lookup_bzr_revision_id(revision_id) except errors.NoSuchRevision: return False - return (git_commit_id in self._git) + return git_commit_id in self._git def has_revisions(self, revision_ids): """See Repository.has_revisions.""" @@ -517,14 +539,14 @@ def revision_trees(self, revids): def revision_tree(self, revision_id): """See Repository.revision_tree.""" if revision_id is None: - raise ValueError(f'invalid revision id {revision_id}') + raise ValueError(f"invalid revision id {revision_id}") return GitRevisionTree(self, revision_id) def set_make_working_trees(self, trees): raise errors.UnsupportedOperation(self.set_make_working_trees, self) def make_working_trees(self): - return not self._git.get_config().get_boolean(("core", ), "bare") + return not self._git.get_config().get_boolean(("core",), "bare") class GitRepositoryFormat(repository.RepositoryFormat): @@ -554,6 +576,7 @@ class GitRepositoryFormat(repository.RepositoryFormat): @property def _matchingcontroldir(self): from .dir import LocalGitControlDirFormat + return LocalGitControlDirFormat() def get_format_description(self): @@ -561,6 +584,7 @@ def get_format_description(self): def initialize(self, controldir, shared=False, _internal=False): from .dir import GitDir + if not isinstance(controldir, GitDir): raise errors.UninitializableFormat(self) return controldir.open_repository() @@ -570,6 +594,7 @@ def check_conversion_target(self, target_repo_format): def get_foreign_tests_repository_factory(self): from .tests.test_repository import ForeignTestsRepositoryFactory + return ForeignTestsRepositoryFactory() def network_name(self): @@ -579,11 +604,21 @@ def network_name(self): def get_extra_interrepo_test_combinations(): from ..bzr.groupcompress_repo import RepositoryFormat2a from . import interrepo + return [ - (interrepo.InterLocalGitNonGitRepository, - GitRepositoryFormat(), RepositoryFormat2a()), - (interrepo.InterLocalGitLocalGitRepository, - GitRepositoryFormat(), GitRepositoryFormat()), - (interrepo.InterToLocalGitRepository, - RepositoryFormat2a(), GitRepositoryFormat()), - ] + ( + interrepo.InterLocalGitNonGitRepository, + GitRepositoryFormat(), + RepositoryFormat2a(), + ), + ( + interrepo.InterLocalGitLocalGitRepository, + GitRepositoryFormat(), + GitRepositoryFormat(), + ), + ( + interrepo.InterToLocalGitRepository, + RepositoryFormat2a(), + GitRepositoryFormat(), + ), + ] diff --git a/breezy/git/revspec.py b/breezy/git/revspec.py index b7128f4083..218a633c56 100644 --- a/breezy/git/revspec.py +++ b/breezy/git/revspec.py @@ -49,15 +49,18 @@ class RevisionSpec_git(RevisionSpec): imported into Bazaar repositories. 
""" - prefix = 'git:' + prefix = "git:" wants_revision_history = False def _lookup_git_sha1(self, branch, sha1): from .errors import GitSmartRemoteNotSupported from .mapping import default_mapping - bzr_revid = getattr(branch.repository, "lookup_foreign_revision_id", - default_mapping.revision_id_foreign_to_bzr)(sha1) + bzr_revid = getattr( + branch.repository, + "lookup_foreign_revision_id", + default_mapping.revision_id_foreign_to_bzr, + )(sha1) try: if branch.repository.has_revision(bzr_revid): return bzr_revid @@ -75,8 +78,13 @@ def __nonzero__(self): def _find_short_git_sha1(self, branch, sha1): from .mapping import ForeignGit, mapping_registry - parse_revid = getattr(branch.repository, "lookup_bzr_revision_id", - mapping_registry.parse_revision_id) + + parse_revid = getattr( + branch.repository, + "lookup_bzr_revision_id", + mapping_registry.parse_revision_id, + ) + def matches_revid(revid): if revid == NULL_REVISION: return False @@ -87,6 +95,7 @@ def matches_revid(revid): if not isinstance(mapping.vcs, ForeignGit): return False return foreign_revid.startswith(sha1) + with branch.repository.lock_read(): graph = branch.repository.get_graph() last_revid = branch.last_revision() @@ -98,12 +107,12 @@ def matches_revid(revid): raise InvalidRevisionSpec(self.user_spec, branch) def _as_revision_id(self, context_branch): - loc = self.spec.find(':') - git_sha1 = self.spec[loc + 1:].encode("utf-8") - if (len(git_sha1) > 40 or len(git_sha1) < 4 or - not valid_git_sha1(git_sha1)): + loc = self.spec.find(":") + git_sha1 = self.spec[loc + 1 :].encode("utf-8") + if len(git_sha1) > 40 or len(git_sha1) < 4 or not valid_git_sha1(git_sha1): raise InvalidRevisionSpec(self.user_spec, context_branch) from . import lazy_check_versions + lazy_check_versions() if len(git_sha1) == 40: return self._lookup_git_sha1(context_branch, git_sha1) diff --git a/breezy/git/roundtrip.py b/breezy/git/roundtrip.py index 00c964d36b..a5ee870048 100644 --- a/breezy/git/roundtrip.py +++ b/breezy/git/roundtrip.py @@ -67,8 +67,7 @@ def __init__(self) -> None: self.verifiers: Dict[str, Any] = {} def __nonzero__(self) -> bool: - return bool(self.revision_id or self.properties or - self.explicit_parent_ids) + return bool(self.revision_id or self.properties or self.explicit_parent_ids) class TreeSupplement: @@ -92,7 +91,7 @@ def parse_roundtripping_metadata(text): elif key == b"testament3-sha1": ret.verifiers[b"testament3-sha1"] = value.strip() elif key.startswith(b"property-"): - name = key[len(b"property-"):] + name = key[len(b"property-") :] if name not in ret.properties: ret.properties[name] = value[1:].rstrip(b"\n") else: @@ -112,14 +111,12 @@ def generate_roundtripping_metadata(metadata, encoding): if metadata.revision_id: lines.append(b"revision-id: %s\n" % metadata.revision_id) if metadata.explicit_parent_ids: - lines.append(b"parent-ids: %s\n" % - b" ".join(metadata.explicit_parent_ids)) + lines.append(b"parent-ids: %s\n" % b" ".join(metadata.explicit_parent_ids)) for key in sorted(metadata.properties.keys()): for l in metadata.properties[key].split(b"\n"): lines.append(b"property-%s: %s\n" % (key, l)) if b"testament3-sha1" in metadata.verifiers: - lines.append(b"testament3-sha1: %s\n" % - metadata.verifiers[b"testament3-sha1"]) + lines.append(b"testament3-sha1: %s\n" % metadata.verifiers[b"testament3-sha1"]) return b"".join(lines) diff --git a/breezy/git/send.py b/breezy/git/send.py index 49fe395924..b07598ad90 100644 --- a/breezy/git/send.py +++ b/breezy/git/send.py @@ -35,8 +35,7 @@ from .mapping import object_mode from 
.object_store import get_object_store -version_tail = "Breezy %s, dulwich %d.%d.%d" % ( - (brz_version, ) + dulwich_version[:3]) +version_tail = "Breezy %s, dulwich %d.%d.%d" % ((brz_version,) + dulwich_version[:3]) class GitDiffTree(_mod_diff.DiffTree): @@ -44,9 +43,13 @@ class GitDiffTree(_mod_diff.DiffTree): def _show_diff(self, specific_files, extra_trees): from dulwich.patch import write_blob_diff + iterator = self.new_tree.iter_changes( - self.old_tree, specific_files=specific_files, - extra_trees=extra_trees, require_versioned=True) + self.old_tree, + specific_files=specific_files, + extra_trees=extra_trees, + require_versioned=True, + ) has_changes = 0 def get_encoded_path(path): @@ -64,29 +67,45 @@ def get_blob(present, tree, path): return Blob.from_string(f.read()) else: return None + trees = (self.old_tree, self.new_tree) for change in iterator: # The root does not get diffed, and items with no known kind (that # is, missing) in both trees are skipped as well. if change.parent_id == (None, None) or change.kind == (None, None): continue - path_encoded = (get_encoded_path(change.path[0]), - get_encoded_path(change.path[1])) - present = ((change.kind[0] not in (None, 'directory')), - (change.kind[1] not in (None, 'directory'))) + path_encoded = ( + get_encoded_path(change.path[0]), + get_encoded_path(change.path[1]), + ) + present = ( + (change.kind[0] not in (None, "directory")), + (change.kind[1] not in (None, "directory")), + ) if not present[0] and not present[1]: continue - contents = (get_blob(present[0], trees[0], change.path[0]), - get_blob(present[1], trees[1], change.path[1])) - renamed = (change.parent_id[0], change.name[0]) != (change.parent_id[1], change.name[1]) - mode = (get_file_mode(trees[0], path_encoded[0], - change.kind[0], change.executable[0]), - get_file_mode(trees[1], path_encoded[1], - change.kind[1], change.executable[1])) - write_blob_diff(self.to_file, - (path_encoded[0], mode[0], contents[0]), - (path_encoded[1], mode[1], contents[1])) - has_changes |= (change.changed_content or renamed) + contents = ( + get_blob(present[0], trees[0], change.path[0]), + get_blob(present[1], trees[1], change.path[1]), + ) + renamed = (change.parent_id[0], change.name[0]) != ( + change.parent_id[1], + change.name[1], + ) + mode = ( + get_file_mode( + trees[0], path_encoded[0], change.kind[0], change.executable[0] + ), + get_file_mode( + trees[1], path_encoded[1], change.kind[1], change.executable[1] + ), + ) + write_blob_diff( + self.to_file, + (path_encoded[0], mode[0], contents[0]), + (path_encoded[1], mode[1], contents[1]), + ) + has_changes |= change.changed_content or renamed return has_changes @@ -95,16 +114,31 @@ def generate_patch_filename(num, summary): class GitMergeDirective(BaseMergeDirective): - multiple_output_files = True - def __init__(self, revision_id, testament_sha1, time, timezone, - target_branch, source_branch=None, message=None, - patches=None, local_target_branch=None): + def __init__( + self, + revision_id, + testament_sha1, + time, + timezone, + target_branch, + source_branch=None, + message=None, + patches=None, + local_target_branch=None, + ): super().__init__( - revision_id=revision_id, testament_sha1=testament_sha1, time=time, - timezone=timezone, target_branch=target_branch, patch=None, - source_branch=source_branch, message=message, bundle=None) + revision_id=revision_id, + testament_sha1=testament_sha1, + time=time, + timezone=timezone, + target_branch=target_branch, + patch=None, + source_branch=source_branch, + message=message, + 
bundle=None, + ) self.patches = patches def to_lines(self): @@ -114,13 +148,20 @@ def to_files(self): return ((summary, patch.splitlines(True)) for (summary, patch) in self.patches) @classmethod - def _generate_commit(cls, repository, revision_id, num, total, - context=_mod_diff.DEFAULT_CONTEXT_AMOUNT): + def _generate_commit( + cls, + repository, + revision_id, + num, + total, + context=_mod_diff.DEFAULT_CONTEXT_AMOUNT, + ): s = BytesIO() store = get_object_store(repository) with store.lock_read(): commit = store[repository.lookup_bzr_revision_id(revision_id)[0]] from dulwich.patch import get_summary, write_commit_patch + try: lhs_parent = repository.get_revision(revision_id).parent_ids[0] except IndexError: @@ -129,18 +170,33 @@ def _generate_commit(cls, repository, revision_id, num, total, tree_2 = repository.revision_tree(revision_id) contents = BytesIO() differ = GitDiffTree.from_trees_options( - tree_1, tree_2, contents, 'utf8', None, 'a/', 'b/', None, - context_lines=context) + tree_1, + tree_2, + contents, + "utf8", + None, + "a/", + "b/", + None, + context_lines=context, + ) differ.show_diff(None, None) - write_commit_patch(s, commit, contents.getvalue(), (num, total), - version_tail) + write_commit_patch(s, commit, contents.getvalue(), (num, total), version_tail) summary = generate_patch_filename(num, get_summary(commit)) return summary, s.getvalue() @classmethod - def from_objects(cls, repository, revision_id, time, timezone, - target_branch, local_target_branch=None, - public_branch=None, message=None): + def from_objects( + cls, + repository, + revision_id, + time, + timezone, + target_branch, + local_target_branch=None, + public_branch=None, + message=None, + ): patches = [] submit_branch = _mod_branch.Branch.open(target_branch) with submit_branch.lock_read(): @@ -150,23 +206,41 @@ def from_objects(cls, repository, revision_id, time, timezone, todo = graph.find_difference(submit_revision_id, revision_id)[1] total = len(todo) for i, revid in enumerate(graph.iter_topo_order(todo)): - patches.append(cls._generate_commit(repository, revid, i + 1, - total)) - return cls(revision_id, None, time, timezone, - target_branch=target_branch, source_branch=public_branch, - message=message, patches=patches) - - -def send_git(branch, revision_id, submit_branch, public_branch, no_patch, - no_bundle, message, base_revision_id, local_target_branch=None): + patches.append(cls._generate_commit(repository, revid, i + 1, total)) + return cls( + revision_id, + None, + time, + timezone, + target_branch=target_branch, + source_branch=public_branch, + message=message, + patches=patches, + ) + + +def send_git( + branch, + revision_id, + submit_branch, + public_branch, + no_patch, + no_bundle, + message, + base_revision_id, + local_target_branch=None, +): if no_patch: - raise errors.CommandError( - "no patch not supported for git-am style patches") + raise errors.CommandError("no patch not supported for git-am style patches") if no_bundle: - raise errors.CommandError( - "no bundle not supported for git-am style patches") + raise errors.CommandError("no bundle not supported for git-am style patches") return GitMergeDirective.from_objects( - repository=branch.repository, revision_id=revision_id, time=time.time(), - timezone=osutils.local_time_offset(), target_branch=submit_branch, - public_branch=public_branch, message=message, - local_target_branch=local_target_branch) + repository=branch.repository, + revision_id=revision_id, + time=time.time(), + timezone=osutils.local_time_offset(), + 
target_branch=submit_branch, + public_branch=public_branch, + message=message, + local_target_branch=local_target_branch, + ) diff --git a/breezy/git/server.py b/breezy/git/server.py index 39b58f2d89..823d44cd1a 100644 --- a/breezy/git/server.py +++ b/breezy/git/server.py @@ -45,12 +45,11 @@ def __init__(self, transport): def open_repository(self, path): # FIXME: More secure path sanitization transport = self.transport.clone(decode_git_path(path).lstrip("/")) - trace.mutter('client opens %r: %r', path, transport) + trace.mutter("client opens %r: %r", path, transport) return BzrBackendRepo(transport, self.mapping) class BzrBackendRepo(BackendRepo): - def __init__(self, transport, mapping): self.mapping = mapping self.repo_dir = ControlDir.open_from_transport(transport) @@ -68,38 +67,44 @@ def get_peeled(self, name): return cached return peel_sha(self.object_store, self.refs[name])[1].id - def find_missing_objects(self, determine_wants, graph_walker, progress, - get_tagged=None): + def find_missing_objects( + self, determine_wants, graph_walker, progress, get_tagged=None + ): """Yield git objects to send to client.""" with self.object_store.lock_read(): wants = determine_wants(self.get_refs()) have = self.object_store.find_common_revisions(graph_walker) if wants is None: return - shallows = getattr(graph_walker, 'shallow', frozenset()) + shallows = getattr(graph_walker, "shallow", frozenset()) if isinstance(self.object_store, BazaarObjectStore): return self.object_store.find_missing_objects( - have, wants, shallow=shallows, - progress=progress, get_tagged=get_tagged, lossy=True) + have, + wants, + shallow=shallows, + progress=progress, + get_tagged=get_tagged, + lossy=True, + ) else: return MissingObjectFinder( - self.object_store, - have, wants, shallow=shallows, progress=progress) + self.object_store, have, wants, shallow=shallows, progress=progress + ) class BzrTCPGitServer(TCPGitServer): - def handle_error(self, request, client_address): trace.log_exception_quietly() - trace.warning('Exception happened during processing of request ' - 'from %s', client_address) + trace.warning( + "Exception happened during processing of request " "from %s", client_address + ) def serve_git(transport, host=None, port=None, inet=False, timeout=None): backend = BzrBackend(transport) if host is None: - host = 'localhost' + host = "localhost" if port: server = BzrTCPGitServer(backend, host, port) else: @@ -109,8 +114,9 @@ def serve_git(transport, host=None, port=None, inet=False, timeout=None): def git_http_hook(branch, method, path): from dulwich.web import DEFAULT_HANDLERS, HTTPGitApplication, HTTPGitRequest + handler = None - for (smethod, spath) in HTTPGitApplication.services: + for smethod, spath in HTTPGitApplication.services: if smethod != method: continue mat = spath.search(path) @@ -122,9 +128,11 @@ def git_http_hook(branch, method, path): backend = BzrBackend(branch.user_transport) def git_call(environ, start_response): - req = HTTPGitRequest(environ, start_response, dumb=False, - handlers=DEFAULT_HANDLERS) + req = HTTPGitRequest( + environ, start_response, dumb=False, handlers=DEFAULT_HANDLERS + ) return handler(req, backend, mat) + return git_call @@ -140,9 +148,11 @@ def serve_command(handler_cls, backend, inf=sys.stdin, outf=sys.stdout): :param outf: File-like object to write to, defaults to standard output. :return: Exit code for use with sys.exit. 0 on success, 1 on failure. 
""" + def send_fn(data): outf.write(data) outf.flush() + proto = Protocol(inf.read, send_fn) handler = handler_cls(backend, ["/"], proto) # FIXME: Catch exceptions and write a single-line summary to outf. @@ -152,15 +162,13 @@ def send_fn(data): def serve_git_receive_pack(transport, host=None, port=None, inet=False): if not inet: - raise errors.CommandError( - "git-receive-pack only works in inetd mode") + raise errors.CommandError("git-receive-pack only works in inetd mode") backend = BzrBackend(transport) sys.exit(serve_command(ReceivePackHandler, backend=backend)) def serve_git_upload_pack(transport, host=None, port=None, inet=False): if not inet: - raise errors.CommandError( - "git-receive-pack only works in inetd mode") + raise errors.CommandError("git-receive-pack only works in inetd mode") backend = BzrBackend(transport) sys.exit(serve_command(UploadPackHandler, backend=backend)) diff --git a/breezy/git/tests/__init__.py b/breezy/git/tests/__init__.py index b6ce4ef644..f88bebbdcd 100644 --- a/breezy/git/tests/__init__.py +++ b/breezy/git/tests/__init__.py @@ -32,7 +32,6 @@ class _DulwichFeature(Feature): - def _probe(self): try: import_dulwich() @@ -41,15 +40,14 @@ def _probe(self): return True def feature_name(self): - return 'dulwich' + return "dulwich" DulwichFeature = _DulwichFeature() -FastimportFeature = ModuleAvailableFeature('fastimport') +FastimportFeature = ModuleAvailableFeature("fastimport") class GitBranchBuilder: - def __init__(self, stream=None): if not FastimportFeature.available(): raise tests.UnavailableFeature(FastimportFeature) @@ -60,7 +58,7 @@ def __init__(self, stream=None): else: self.stream = stream self._counter = 0 - self._branch = b'refs/heads/master' + self._branch = b"refs/heads/master" def set_branch(self, branch): """Set the branch we are committing.""" @@ -75,53 +73,64 @@ def _writelines(self, lines): def _create_blob(self, content): self._counter += 1 from fastimport.commands import BlobCommand - blob = BlobCommand(b'%d' % self._counter, content) + + blob = BlobCommand(b"%d" % self._counter, content) self._write(bytes(blob) + b"\n") return self._counter def set_symlink(self, path, content): """Create or update symlink at a given path.""" mark = self._create_blob(self._encode_path(content)) - mode = b'120000' - self.commit_info.append(b'M %s :%d %s\n' - % (mode, mark, self._encode_path(path))) + mode = b"120000" + self.commit_info.append( + b"M %s :%d %s\n" % (mode, mark, self._encode_path(path)) + ) def set_submodule(self, path, commit_sha): """Create or update submodule at a given path.""" - mode = b'160000' + mode = b"160000" self.commit_info.append( - b'M %s %s %s\n' % (mode, commit_sha, self._encode_path(path))) + b"M %s %s %s\n" % (mode, commit_sha, self._encode_path(path)) + ) def set_file(self, path, content, executable): """Create or update content at a given path.""" mark = self._create_blob(content) if executable: - mode = b'100755' + mode = b"100755" else: - mode = b'100644' - self.commit_info.append(b'M %s :%d %s\n' - % (mode, mark, self._encode_path(path))) + mode = b"100644" + self.commit_info.append( + b"M %s :%d %s\n" % (mode, mark, self._encode_path(path)) + ) def delete_entry(self, path): """This will delete files or symlinks at the given location.""" - self.commit_info.append(b'D %s\n' % (self._encode_path(path),)) + self.commit_info.append(b"D %s\n" % (self._encode_path(path),)) @staticmethod def _encode_path(path): if isinstance(path, bytes): return path - if '\n' in path or path[0] == '"': - path = path.replace('\\', '\\\\') - 
path = path.replace('\n', '\\n') + if "\n" in path or path[0] == '"': + path = path.replace("\\", "\\\\") + path = path.replace("\n", "\\n") path = path.replace('"', '\\"') path = '"' + path + '"' - return path.encode('utf-8') + return path.encode("utf-8") # TODO: Author # TODO: Author timestamp+timezone - def commit(self, committer, message, timestamp=None, - timezone=b'+0000', author=None, - merge=None, base=None): + def commit( + self, + committer, + message, + timestamp=None, + timezone=b"+0000", + author=None, + merge=None, + base=None, + ): """Commit the new content. :param committer: The name and address for the committer @@ -137,25 +146,24 @@ def commit(self, committer, message, timestamp=None, commit. """ self._counter += 1 - mark = b'%d' % (self._counter,) + mark = b"%d" % (self._counter,) if timestamp is None: timestamp = int(time.time()) - self._write(b'commit %s\n' % (self._branch,)) - self._write(b'mark :%s\n' % (mark,)) - self._write(b'committer %s %ld %s\n' - % (committer, timestamp, timezone)) + self._write(b"commit %s\n" % (self._branch,)) + self._write(b"mark :%s\n" % (mark,)) + self._write(b"committer %s %ld %s\n" % (committer, timestamp, timezone)) if not isinstance(message, bytes): - message = message.encode('UTF-8') - self._write(b'data %d\n' % (len(message),)) + message = message.encode("UTF-8") + self._write(b"data %d\n" % (len(message),)) self._write(message) - self._write(b'\n') + self._write(b"\n") if base is not None: - self._write(b'from :%s\n' % (base,)) + self._write(b"from :%s\n" % (base,)) if merge is not None: for m in merge: - self._write(b'merge :%s\n' % (m,)) + self._write(b"merge :%s\n" % (m,)) self._writelines(self.commit_info) - self._write(b'\n') + self._write(b"\n") self.commit_info = [] return mark @@ -167,24 +175,25 @@ def reset(self, ref=None, mark=None): """ if ref is None: ref = self._branch - self._write(b'reset %s\n' % (ref,)) + self._write(b"reset %s\n" % (ref,)) if mark is not None: - self._write(b'from :%s\n' % mark) - self._write(b'\n') + self._write(b"from :%s\n" % mark) + self._write(b"\n") def finish(self): """We are finished building, close the stream, get the id mapping.""" self.stream.seek(0) if self.orig_stream is None: from dulwich.repo import Repo + r = Repo(".") from dulwich.fastexport import GitImportProcessor + importer = GitImportProcessor(r) return importer.import_stream(self.stream) class MissingFeature(tests.TestCase): - def test_dulwich(self): self.requireFeature(DulwichFeature) @@ -194,40 +203,43 @@ def load_tests(loader, basic_tests, pattern): # add the tests for this module suite.addTests(basic_tests) - prefix = __name__ + '.' + prefix = __name__ + "." 
if not DulwichFeature.available(): suite.addTests(loader.loadTestsFromTestCase(MissingFeature)) return suite testmod_names = [ - 'test_blackbox', - 'test_builder', - 'test_branch', - 'test_cache', - 'test_dir', - 'test_fetch', - 'test_git_remote_helper', - 'test_mapping', - 'test_memorytree', - 'test_object_store', - 'test_pristine_tar', - 'test_push', - 'test_remote', - 'test_repository', - 'test_refs', - 'test_revspec', - 'test_roundtrip', - 'test_server', - 'test_transform', - 'test_transportgit', - 'test_tree', - 'test_unpeel_map', - 'test_urls', - 'test_workingtree', - ] + "test_blackbox", + "test_builder", + "test_branch", + "test_cache", + "test_dir", + "test_fetch", + "test_git_remote_helper", + "test_mapping", + "test_memorytree", + "test_object_store", + "test_pristine_tar", + "test_push", + "test_remote", + "test_repository", + "test_refs", + "test_revspec", + "test_roundtrip", + "test_server", + "test_transform", + "test_transportgit", + "test_tree", + "test_unpeel_map", + "test_urls", + "test_workingtree", + ] # add the tests for the sub modules - suite.addTests(loader.loadTestsFromModuleNames( - [prefix + module_name for module_name in testmod_names])) + suite.addTests( + loader.loadTestsFromModuleNames( + [prefix + module_name for module_name in testmod_names] + ) + ) return suite diff --git a/breezy/git/tests/test_blackbox.py b/breezy/git/tests/test_blackbox.py index 81bb616854..b78ff537d1 100644 --- a/breezy/git/tests/test_blackbox.py +++ b/breezy/git/tests/test_blackbox.py @@ -30,105 +30,110 @@ class TestGitBlackBox(ExternalBase): - def simple_commit(self): # Create a git repository with a revision. repo = GitRepo.init(self.test_dir) builder = tests.GitBranchBuilder() - builder.set_file('a', b'text for a\n', False) - r1 = builder.commit(b'Joe Foo ', '') + builder.set_file("a", b"text for a\n", False) + r1 = builder.commit(b"Joe Foo ", "") return repo, builder.finish()[r1] def test_add(self): GitRepo.init(self.test_dir) dir = ControlDir.open(self.test_dir) dir.create_branch() - self.build_tree(['a', 'b']) - output, error = self.run_bzr(['add', 'a']) - self.assertEqual('adding a\n', output) - self.assertEqual('', error) - output, error = self.run_bzr( - ['add', '--file-ids-from=../othertree', 'b']) - self.assertEqual('adding b\n', output) + self.build_tree(["a", "b"]) + output, error = self.run_bzr(["add", "a"]) + self.assertEqual("adding a\n", output) + self.assertEqual("", error) + output, error = self.run_bzr(["add", "--file-ids-from=../othertree", "b"]) + self.assertEqual("adding b\n", output) self.assertEqual( - 'Ignoring --file-ids-from, since the tree does not support ' - 'setting file ids.\n', error) + "Ignoring --file-ids-from, since the tree does not support " + "setting file ids.\n", + error, + ) def test_nick(self): GitRepo.init(self.test_dir) dir = ControlDir.open(self.test_dir) dir.create_branch() - output, error = self.run_bzr(['nick']) + output, error = self.run_bzr(["nick"]) self.assertEqual("master\n", output) def test_branches(self): self.simple_commit() - output, error = self.run_bzr(['branches']) + output, error = self.run_bzr(["branches"]) self.assertEqual("* master\n", output) def test_info(self): self.simple_commit() - output, error = self.run_bzr(['info']) - self.assertEqual(error, '') + output, error = self.run_bzr(["info"]) + self.assertEqual(error, "") self.assertEqual( output, - 'Standalone tree (format: git)\n' - 'Location:\n' - ' light checkout root: .\n' - ' checkout of co-located branch: master\n') + "Standalone tree (format: git)\n" + 
"Location:\n" + " light checkout root: .\n" + " checkout of co-located branch: master\n", + ) def test_ignore(self): self.simple_commit() - output, error = self.run_bzr(['ignore', 'foo']) - self.assertEqual(error, '') - self.assertEqual(output, '') + output, error = self.run_bzr(["ignore", "foo"]) + self.assertEqual(error, "") + self.assertEqual(output, "") self.assertFileEqual("foo\n", ".gitignore") def test_cat_revision(self): self.simple_commit() - output, error = self.run_bzr(['cat-revision', '-r-1'], retcode=3) + output, error = self.run_bzr(["cat-revision", "-r-1"], retcode=3) self.assertContainsRe( error, - 'brz: ERROR: Repository .* does not support access to raw ' - 'revision texts') - self.assertEqual(output, '') + "brz: ERROR: Repository .* does not support access to raw " + "revision texts", + ) + self.assertEqual(output, "") def test_branch(self): os.mkdir("gitbranch") GitRepo.init(os.path.join(self.test_dir, "gitbranch")) - os.chdir('gitbranch') + os.chdir("gitbranch") builder = tests.GitBranchBuilder() - builder.set_file(b'a', b'text for a\n', False) - builder.commit(b'Joe Foo ', b'') + builder.set_file(b"a", b"text for a\n", False) + builder.commit(b"Joe Foo ", b"") builder.finish() - os.chdir('..') + os.chdir("..") - output, error = self.run_bzr(['branch', 'gitbranch', 'bzrbranch']) + output, error = self.run_bzr(["branch", "gitbranch", "bzrbranch"]) errlines = error.splitlines(False) self.assertTrue( - 'Branched 1 revision(s).' in errlines or - 'Branched 1 revision.' in errlines, errlines) + "Branched 1 revision(s)." in errlines or "Branched 1 revision." in errlines, + errlines, + ) def test_checkout(self): os.mkdir("gitbranch") GitRepo.init(os.path.join(self.test_dir, "gitbranch")) - os.chdir('gitbranch') + os.chdir("gitbranch") builder = tests.GitBranchBuilder() - builder.set_file(b'a', b'text for a\n', False) - builder.commit(b'Joe Foo ', b'') + builder.set_file(b"a", b"text for a\n", False) + builder.commit(b"Joe Foo ", b"") builder.finish() - os.chdir('..') + os.chdir("..") - output, error = self.run_bzr(['checkout', 'gitbranch', 'bzrbranch']) - self.assertEqual(error, - 'Fetching from Git to Bazaar repository. ' - 'For better performance, fetch into a Git repository.\n') - self.assertEqual(output, '') + output, error = self.run_bzr(["checkout", "gitbranch", "bzrbranch"]) + self.assertEqual( + error, + "Fetching from Git to Bazaar repository. 
" + "For better performance, fetch into a Git repository.\n", + ) + self.assertEqual(output, "") def test_branch_ls(self): self.simple_commit() - output, error = self.run_bzr(['ls', '-r-1']) - self.assertEqual(error, '') + output, error = self.run_bzr(["ls", "-r-1"]) + self.assertEqual(error, "") self.assertEqual(output, "a\n") def test_init(self): @@ -137,8 +142,8 @@ def test_init(self): def test_info_verbose(self): self.simple_commit() - output, error = self.run_bzr(['info', '-v']) - self.assertEqual(error, '') + output, error = self.run_bzr(["info", "-v"]) + self.assertEqual(error, "") self.assertIn("Standalone tree (format: git)", output) self.assertIn("control: Local Git Repository", output) self.assertIn("branch: Local Git Branch", output) @@ -149,87 +154,92 @@ def test_push_roundtripping(self): self.with_roundtripping() os.mkdir("bla") GitRepo.init(os.path.join(self.test_dir, "bla")) - self.run_bzr(['init', 'foo']) - self.run_bzr(['commit', '--unchanged', '-m', 'bla', 'foo']) + self.run_bzr(["init", "foo"]) + self.run_bzr(["commit", "--unchanged", "-m", "bla", "foo"]) # when roundtripping is supported - output, error = self.run_bzr(['push', '-d', 'foo', 'bla']) + output, error = self.run_bzr(["push", "-d", "foo", "bla"]) self.assertEqual(b"", output) self.assertTrue(error.endswith(b"Created new branch.\n")) def test_push_without_calculate_revnos(self): - self.run_bzr(['init', '--git', 'bla']) - self.run_bzr(['init', '--git', 'foo']) - self.run_bzr(['commit', '--unchanged', '-m', 'bla', 'foo']) + self.run_bzr(["init", "--git", "bla"]) + self.run_bzr(["init", "--git", "foo"]) + self.run_bzr(["commit", "--unchanged", "-m", "bla", "foo"]) output, error = self.run_bzr( - ['push', '-Ocalculate_revnos=no', '-d', 'foo', 'bla']) + ["push", "-Ocalculate_revnos=no", "-d", "foo", "bla"] + ) self.assertEqual("", output) - self.assertContainsRe( - error, - 'Pushed up to revision id git(.*).\n') + self.assertContainsRe(error, "Pushed up to revision id git(.*).\n") def test_merge(self): - self.run_bzr(['init', '--git', 'orig']) - self.build_tree_contents([('orig/a', 'orig contents\n')]) - self.run_bzr(['add', 'orig/a']) - self.run_bzr(['commit', '-m', 'add orig', 'orig']) - self.run_bzr(['clone', 'orig', 'other']) - self.build_tree_contents([('other/a', 'new contents\n')]) - self.run_bzr(['commit', '-m', 'modify', 'other']) - self.build_tree_contents([('orig/b', 'more\n')]) - self.run_bzr(['add', 'orig/b']) - self.build_tree_contents([('orig/a', 'new contents\n')]) - self.run_bzr(['commit', '-m', 'more', 'orig']) - self.run_bzr(['merge', '-d', 'orig', 'other']) + self.run_bzr(["init", "--git", "orig"]) + self.build_tree_contents([("orig/a", "orig contents\n")]) + self.run_bzr(["add", "orig/a"]) + self.run_bzr(["commit", "-m", "add orig", "orig"]) + self.run_bzr(["clone", "orig", "other"]) + self.build_tree_contents([("other/a", "new contents\n")]) + self.run_bzr(["commit", "-m", "modify", "other"]) + self.build_tree_contents([("orig/b", "more\n")]) + self.run_bzr(["add", "orig/b"]) + self.build_tree_contents([("orig/a", "new contents\n")]) + self.run_bzr(["commit", "-m", "more", "orig"]) + self.run_bzr(["merge", "-d", "orig", "other"]) def test_push_lossy_non_mainline(self): - self.run_bzr(['init', '--git', 'bla']) - self.run_bzr(['init', 'foo']) - self.run_bzr(['commit', '--unchanged', '-m', 'bla', 'foo']) - self.run_bzr(['branch', 'foo', 'foo1']) - self.run_bzr(['commit', '--unchanged', '-m', 'bla', 'foo1']) - self.run_bzr(['commit', '--unchanged', '-m', 'bla', 'foo']) - self.run_bzr(['merge', 
'-d', 'foo', 'foo1']) - self.run_bzr(['commit', '--unchanged', '-m', 'merge', 'foo']) - output, error = self.run_bzr(['push', '--lossy', '-r1.1.1', '-d', 'foo', 'bla']) + self.run_bzr(["init", "--git", "bla"]) + self.run_bzr(["init", "foo"]) + self.run_bzr(["commit", "--unchanged", "-m", "bla", "foo"]) + self.run_bzr(["branch", "foo", "foo1"]) + self.run_bzr(["commit", "--unchanged", "-m", "bla", "foo1"]) + self.run_bzr(["commit", "--unchanged", "-m", "bla", "foo"]) + self.run_bzr(["merge", "-d", "foo", "foo1"]) + self.run_bzr(["commit", "--unchanged", "-m", "merge", "foo"]) + output, error = self.run_bzr(["push", "--lossy", "-r1.1.1", "-d", "foo", "bla"]) self.assertEqual("", output) self.assertEqual( - 'Pushing from a Bazaar to a Git repository. For better ' - 'performance, push into a Bazaar repository.\n' - 'All changes applied successfully.\n' - 'Pushed up to revision 2.\n', error) + "Pushing from a Bazaar to a Git repository. For better " + "performance, push into a Bazaar repository.\n" + "All changes applied successfully.\n" + "Pushed up to revision 2.\n", + error, + ) def test_push_lossy_non_mainline_incremental(self): - self.run_bzr(['init', '--git', 'bla']) - self.run_bzr(['init', 'foo']) - self.run_bzr(['commit', '--unchanged', '-m', 'bla', 'foo']) - self.run_bzr(['commit', '--unchanged', '-m', 'bla', 'foo']) - output, error = self.run_bzr(['push', '--lossy', '-d', 'foo', 'bla']) + self.run_bzr(["init", "--git", "bla"]) + self.run_bzr(["init", "foo"]) + self.run_bzr(["commit", "--unchanged", "-m", "bla", "foo"]) + self.run_bzr(["commit", "--unchanged", "-m", "bla", "foo"]) + output, error = self.run_bzr(["push", "--lossy", "-d", "foo", "bla"]) self.assertEqual("", output) self.assertEqual( - 'Pushing from a Bazaar to a Git repository. For better ' - 'performance, push into a Bazaar repository.\n' - 'All changes applied successfully.\n' - 'Pushed up to revision 2.\n', error) - self.run_bzr(['commit', '--unchanged', '-m', 'bla', 'foo']) - output, error = self.run_bzr(['push', '--lossy', '-d', 'foo', 'bla']) + "Pushing from a Bazaar to a Git repository. For better " + "performance, push into a Bazaar repository.\n" + "All changes applied successfully.\n" + "Pushed up to revision 2.\n", + error, + ) + self.run_bzr(["commit", "--unchanged", "-m", "bla", "foo"]) + output, error = self.run_bzr(["push", "--lossy", "-d", "foo", "bla"]) self.assertEqual("", output) self.assertEqual( - 'Pushing from a Bazaar to a Git repository. For better ' - 'performance, push into a Bazaar repository.\n' - 'All changes applied successfully.\n' - 'Pushed up to revision 3.\n', error) + "Pushing from a Bazaar to a Git repository. For better " + "performance, push into a Bazaar repository.\n" + "All changes applied successfully.\n" + "Pushed up to revision 3.\n", + error, + ) def test_log(self): # Smoke test for "bzr log" in a git repository. self.simple_commit() # Check that bzr log does not fail and includes the revision. - output, error = self.run_bzr(['log']) - self.assertEqual(error, '') + output, error = self.run_bzr(["log"]) + self.assertEqual(error, "") self.assertIn( - '', + "", output, - f"Commit message was not found in output:\n{output}" + f"Commit message was not found in output:\n{output}", ) def test_log_verbose(self): @@ -237,55 +247,57 @@ def test_log_verbose(self): self.simple_commit() # Check that bzr log does not fail and includes the revision. 
- output, error = self.run_bzr(['log', '-v']) - self.assertContainsRe(output, 'revno: 1') + output, error = self.run_bzr(["log", "-v"]) + self.assertContainsRe(output, "revno: 1") def test_log_without_revno(self): # Smoke test for "bzr log -v" in a git repository. self.simple_commit() # Check that bzr log does not fail and includes the revision. - output, error = self.run_bzr(['log', '-Ocalculate_revnos=no']) - self.assertNotContainsRe(output, 'revno: 1') + output, error = self.run_bzr(["log", "-Ocalculate_revnos=no"]) + self.assertNotContainsRe(output, "revno: 1") def test_commit_without_revno(self): GitRepo.init(self.test_dir) output, error = self.run_bzr( - ['commit', '-Ocalculate_revnos=yes', '--unchanged', '-m', 'one']) - self.assertContainsRe(error, 'Committed revision 1.') + ["commit", "-Ocalculate_revnos=yes", "--unchanged", "-m", "one"] + ) + self.assertContainsRe(error, "Committed revision 1.") output, error = self.run_bzr( - ['commit', '-Ocalculate_revnos=no', '--unchanged', '-m', 'two']) - self.assertNotContainsRe(error, 'Committed revision 2.') - self.assertContainsRe(error, 'Committed revid .*.') + ["commit", "-Ocalculate_revnos=no", "--unchanged", "-m", "two"] + ) + self.assertNotContainsRe(error, "Committed revision 2.") + self.assertContainsRe(error, "Committed revid .*.") def test_log_file(self): # Smoke test for "bzr log" in a git repository. GitRepo.init(self.test_dir) builder = tests.GitBranchBuilder() - builder.set_file('a', b'text for a\n', False) - r1 = builder.commit(b'Joe Foo ', 'First') - builder.set_file('a', b'text 3a for a\n', False) - r2a = builder.commit(b'Joe Foo ', 'Second a', base=r1) - builder.set_file('a', b'text 3b for a\n', False) - r2b = builder.commit(b'Joe Foo ', 'Second b', base=r1) - builder.set_file('a', b'text 4 for a\n', False) - builder.commit(b'Joe Foo ', 'Third', merge=[r2a], base=r2b) + builder.set_file("a", b"text for a\n", False) + r1 = builder.commit(b"Joe Foo ", "First") + builder.set_file("a", b"text 3a for a\n", False) + r2a = builder.commit(b"Joe Foo ", "Second a", base=r1) + builder.set_file("a", b"text 3b for a\n", False) + r2b = builder.commit(b"Joe Foo ", "Second b", base=r1) + builder.set_file("a", b"text 4 for a\n", False) + builder.commit(b"Joe Foo ", "Third", merge=[r2a], base=r2b) builder.finish() # Check that bzr log does not fail and includes the revision. 
- output, error = self.run_bzr(['log', '-n2', 'a']) - self.assertEqual(error, '') - self.assertIn('Second a', output) - self.assertIn('Second b', output) - self.assertIn('First', output) - self.assertIn('Third', output) + output, error = self.run_bzr(["log", "-n2", "a"]) + self.assertEqual(error, "") + self.assertIn("Second a", output) + self.assertIn("Second b", output) + self.assertIn("First", output) + self.assertIn("Third", output) def test_tags(self): git_repo, commit_sha1 = self.simple_commit() git_repo.refs[b"refs/tags/foo"] = commit_sha1 - output, error = self.run_bzr(['tags']) - self.assertEqual(error, '') + output, error = self.run_bzr(["tags"]) + self.assertEqual(error, "") self.assertEqual(output, "foo 1\n") def test_tag(self): @@ -295,74 +307,96 @@ def test_tag(self): # bzr <= 2.2 emits this message in the output stream # bzr => 2.3 emits this message in the error stream - self.assertEqual(error + output, 'Created tag bar.\n') + self.assertEqual(error + output, "Created tag bar.\n") def test_init_repo(self): output, error = self.run_bzr(["init", "--format=git", "bla.git"]) - self.assertEqual(error, '') - self.assertEqual(output, 'Created a standalone tree (format: git)\n') + self.assertEqual(error, "") + self.assertEqual(output, "Created a standalone tree (format: git)\n") def test_diff_format(self): - tree = self.make_branch_and_tree('.') - self.build_tree(['a']) - tree.add(['a']) - output, error = self.run_bzr(['diff', '--color=never', '--format=git'], retcode=1) - self.assertEqual(error, '') + tree = self.make_branch_and_tree(".") + self.build_tree(["a"]) + tree.add(["a"]) + output, error = self.run_bzr( + ["diff", "--color=never", "--format=git"], retcode=1 + ) + self.assertEqual(error, "") # Some older versions of Dulwich (< 0.19.12) formatted diffs slightly # differently. 
from dulwich import __version__ as dulwich_version + if dulwich_version < (0, 19, 12): - self.assertEqual(output, - 'diff --git /dev/null b/a\n' - 'old mode 0\n' - 'new mode 100644\n' - 'index 0000000..c197bd8 100644\n' - '--- /dev/null\n' - '+++ b/a\n' - '@@ -0,0 +1 @@\n' - '+contents of a\n') + self.assertEqual( + output, + "diff --git /dev/null b/a\n" + "old mode 0\n" + "new mode 100644\n" + "index 0000000..c197bd8 100644\n" + "--- /dev/null\n" + "+++ b/a\n" + "@@ -0,0 +1 @@\n" + "+contents of a\n", + ) else: - self.assertEqual(output, - 'diff --git a/a b/a\n' - 'old file mode 0\n' - 'new file mode 100644\n' - 'index 0000000..c197bd8 100644\n' - '--- /dev/null\n' - '+++ b/a\n' - '@@ -0,0 +1 @@\n' - '+contents of a\n') + self.assertEqual( + output, + "diff --git a/a b/a\n" + "old file mode 0\n" + "new file mode 100644\n" + "index 0000000..c197bd8 100644\n" + "--- /dev/null\n" + "+++ b/a\n" + "@@ -0,0 +1 @@\n" + "+contents of a\n", + ) def test_git_import_uncolocated(self): r = GitRepo.init("a", mkdir=True) self.build_tree(["a/file"]) r.stage("file") - r.do_commit(ref=b"refs/heads/abranch", - committer=b"Joe ", message=b"Dummy") - r.do_commit(ref=b"refs/heads/bbranch", - committer=b"Joe ", message=b"Dummy") + r.do_commit( + ref=b"refs/heads/abranch", + committer=b"Joe ", + message=b"Dummy", + ) + r.do_commit( + ref=b"refs/heads/bbranch", + committer=b"Joe ", + message=b"Dummy", + ) self.run_bzr(["git-import", "a", "b"]) - self.assertEqual( - {".bzr", "abranch", "bbranch"}, set(os.listdir("b"))) + self.assertEqual({".bzr", "abranch", "bbranch"}, set(os.listdir("b"))) def test_git_import(self): r = GitRepo.init("a", mkdir=True) self.build_tree(["a/file"]) r.stage("file") - r.do_commit(ref=b"refs/heads/abranch", - committer=b"Joe ", message=b"Dummy") - r.do_commit(ref=b"refs/heads/bbranch", - committer=b"Joe ", message=b"Dummy") + r.do_commit( + ref=b"refs/heads/abranch", + committer=b"Joe ", + message=b"Dummy", + ) + r.do_commit( + ref=b"refs/heads/bbranch", + committer=b"Joe ", + message=b"Dummy", + ) self.run_bzr(["git-import", "--colocated", "a", "b"]) self.assertEqual({".bzr"}, set(os.listdir("b"))) - self.assertEqual({"abranch", "bbranch"}, - set(ControlDir.open("b").branch_names())) + self.assertEqual( + {"abranch", "bbranch"}, set(ControlDir.open("b").branch_names()) + ) def test_git_import_incremental(self): r = GitRepo.init("a", mkdir=True) self.build_tree(["a/file"]) r.stage("file") - r.do_commit(ref=b"refs/heads/abranch", - committer=b"Joe ", message=b"Dummy") + r.do_commit( + ref=b"refs/heads/abranch", + committer=b"Joe ", + message=b"Dummy", + ) self.run_bzr(["git-import", "--colocated", "a", "b"]) self.run_bzr(["git-import", "--colocated", "a", "b"]) self.assertEqual({".bzr"}, set(os.listdir("b"))) @@ -373,49 +407,64 @@ def test_git_import_tags(self): r = GitRepo.init("a", mkdir=True) self.build_tree(["a/file"]) r.stage("file") - cid = r.do_commit(ref=b"refs/heads/abranch", - committer=b"Joe ", message=b"Dummy") + cid = r.do_commit( + ref=b"refs/heads/abranch", + committer=b"Joe ", + message=b"Dummy", + ) r[b"refs/tags/atag"] = cid self.run_bzr(["git-import", "--colocated", "a", "b"]) self.assertEqual({".bzr"}, set(os.listdir("b"))) b = ControlDir.open("b") self.assertEqual(["abranch"], b.branch_names()) - self.assertEqual(["atag"], - list(b.open_branch("abranch").tags.get_tag_dict().keys())) + self.assertEqual( + ["atag"], list(b.open_branch("abranch").tags.get_tag_dict().keys()) + ) def test_git_import_colo(self): r = GitRepo.init("a", mkdir=True) self.build_tree(["a/file"]) 
r.stage("file") - r.do_commit(ref=b"refs/heads/abranch", - committer=b"Joe ", message=b"Dummy") - r.do_commit(ref=b"refs/heads/bbranch", - committer=b"Joe ", message=b"Dummy") + r.do_commit( + ref=b"refs/heads/abranch", + committer=b"Joe ", + message=b"Dummy", + ) + r.do_commit( + ref=b"refs/heads/bbranch", + committer=b"Joe ", + message=b"Dummy", + ) self.make_controldir("b", format="development-colo") self.run_bzr(["git-import", "--colocated", "a", "b"]) self.assertEqual( {b.name for b in ControlDir.open("b").list_branches()}, - {"abranch", "bbranch"}) + {"abranch", "bbranch"}, + ) def test_git_refs_from_git(self): r = GitRepo.init("a", mkdir=True) self.build_tree(["a/file"]) r.stage("file") - cid = r.do_commit(ref=b"refs/heads/abranch", - committer=b"Joe ", message=b"Dummy") + cid = r.do_commit( + ref=b"refs/heads/abranch", + committer=b"Joe ", + message=b"Dummy", + ) r[b"refs/tags/atag"] = cid (stdout, stderr) = self.run_bzr(["git-refs", "a"]) self.assertEqual(stderr, "") - self.assertEqual(stdout, - 'refs/heads/abranch -> ' + cid.decode('ascii') + '\n' - 'refs/tags/atag -> ' + cid.decode('ascii') + '\n') + self.assertEqual( + stdout, + "refs/heads/abranch -> " + cid.decode("ascii") + "\n" + "refs/tags/atag -> " + cid.decode("ascii") + "\n", + ) def test_git_refs_from_bzr(self): - tree = self.make_branch_and_tree('a') + tree = self.make_branch_and_tree("a") self.build_tree(["a/file"]) tree.add(["file"]) - revid = tree.commit( - committer=b"Joe ", message=b"Dummy") + revid = tree.commit(committer=b"Joe ", message=b"Dummy") tree.branch.tags.set_tag("atag", revid) (stdout, stderr) = self.run_bzr(["git-refs", "a"]) self.assertEqual(stderr, "") @@ -429,153 +478,184 @@ def test_check(self): r.do_commit(b"message", committer=b"Somebody ") out, err = self.run_bzr(["check", "gitr"]) self.maxDiff = None - self.assertEqual(out, '') - self.assertTrue(err.endswith, '3 objects\n') + self.assertEqual(out, "") + self.assertTrue(err.endswith, "3 objects\n") def test_local_whoami(self): GitRepo.init("gitr", mkdir=True) - self.build_tree_contents([('gitr/.git/config', """\ + self.build_tree_contents( + [ + ( + "gitr/.git/config", + """\ [user] email = some@example.com name = Test User -""")]) +""", + ) + ] + ) out, err = self.run_bzr(["whoami", "-d", "gitr"]) self.assertEqual(out, "Test User \n") self.assertEqual(err, "") - self.build_tree_contents([('gitr/.git/config', """\ + self.build_tree_contents( + [ + ( + "gitr/.git/config", + """\ [user] email = some@example.com -""")]) +""", + ) + ] + ) out, err = self.run_bzr(["whoami", "-d", "gitr"]) self.assertEqual(out, "some@example.com\n") self.assertEqual(err, "") def test_local_signing_key(self): GitRepo.init("gitr", mkdir=True) - self.build_tree_contents([('gitr/.git/config', """\ + self.build_tree_contents( + [ + ( + "gitr/.git/config", + """\ [user] email = some@example.com name = Test User signingkey = D729A457 -""")]) +""", + ) + ] + ) out, err = self.run_bzr(["config", "-d", "gitr", "gpg_signing_key"]) self.assertEqual(out, "D729A457\n") self.assertEqual(err, "") class ShallowTests(ExternalBase): - def setUp(self): super().setUp() # Smoke test for "bzr log" in a git repository with shallow depth. 
- self.repo = GitRepo.init('gitr', mkdir=True) + self.repo = GitRepo.init("gitr", mkdir=True) self.build_tree_contents([("gitr/foo", b"hello from git")]) self.repo.stage("foo") self.repo.do_commit( - b"message", committer=b"Somebody ", + b"message", + committer=b"Somebody ", author=b"Somebody ", - commit_timestamp=1526330165, commit_timezone=0, - author_timestamp=1526330165, author_timezone=0, - merge_heads=[b'aa' * 20]) + commit_timestamp=1526330165, + commit_timezone=0, + author_timestamp=1526330165, + author_timezone=0, + merge_heads=[b"aa" * 20], + ) def test_log_shallow(self): # Check that bzr log does not fail and includes the revision. - output, error = self.run_bzr(['log', 'gitr'], retcode=3) + output, error = self.run_bzr(["log", "gitr"], retcode=3) + self.assertEqual(error, "brz: ERROR: Further revision history missing.\n") self.assertEqual( - error, 'brz: ERROR: Further revision history missing.\n') - self.assertEqual(output, - '------------------------------------------------------------\n' - 'revision-id: git-v1:' + self.repo.head().decode('ascii') + '\n' - 'git commit: ' + self.repo.head().decode('ascii') + '\n' - 'committer: Somebody \n' - 'timestamp: Mon 2018-05-14 20:36:05 +0000\n' - 'message:\n' - ' message\n') + output, + "------------------------------------------------------------\n" + "revision-id: git-v1:" + self.repo.head().decode("ascii") + "\n" + "git commit: " + self.repo.head().decode("ascii") + "\n" + "committer: Somebody \n" + "timestamp: Mon 2018-05-14 20:36:05 +0000\n" + "message:\n" + " message\n", + ) def test_version_info_rio(self): - output, error = self.run_bzr(['version-info', '--rio', 'gitr']) - self.assertEqual(error, '') - self.assertNotIn('revno:', output) + output, error = self.run_bzr(["version-info", "--rio", "gitr"]) + self.assertEqual(error, "") + self.assertNotIn("revno:", output) def test_version_info_python(self): - output, error = self.run_bzr(['version-info', '--python', 'gitr']) - self.assertEqual(error, '') - self.assertNotIn('revno:', output) + output, error = self.run_bzr(["version-info", "--python", "gitr"]) + self.assertEqual(error, "") + self.assertNotIn("revno:", output) def test_version_info_custom_with_revno(self): output, error = self.run_bzr( - ['version-info', '--custom', - '--template=VERSION_INFO r{revno})\n', 'gitr'], retcode=3) - self.assertEqual( - error, 'brz: ERROR: Variable {revno} is not available.\n') - self.assertEqual(output, 'VERSION_INFO r') + ["version-info", "--custom", "--template=VERSION_INFO r{revno})\n", "gitr"], + retcode=3, + ) + self.assertEqual(error, "brz: ERROR: Variable {revno} is not available.\n") + self.assertEqual(output, "VERSION_INFO r") def test_version_info_custom_without_revno(self): output, error = self.run_bzr( - ['version-info', '--custom', '--template=VERSION_INFO \n', - 'gitr']) - self.assertEqual(error, '') - self.assertEqual(output, 'VERSION_INFO \n') + ["version-info", "--custom", "--template=VERSION_INFO \n", "gitr"] + ) + self.assertEqual(error, "") + self.assertEqual(output, "VERSION_INFO \n") class SwitchTests(ExternalBase): - def test_switch_branch(self): # Create a git repository with a revision. 
repo = GitRepo.init(self.test_dir) builder = tests.GitBranchBuilder() - builder.set_branch(b'refs/heads/oldbranch') - builder.set_file('a', b'text for a\n', False) - builder.commit(b'Joe Foo ', '') - builder.set_branch(b'refs/heads/newbranch') + builder.set_branch(b"refs/heads/oldbranch") + builder.set_file("a", b"text for a\n", False) + builder.commit(b"Joe Foo ", "") + builder.set_branch(b"refs/heads/newbranch") builder.reset() - builder.set_file('a', b'text for new a\n', False) - builder.commit(b'Joe Foo ', '') + builder.set_file("a", b"text for new a\n", False) + builder.commit(b"Joe Foo ", "") builder.finish() - repo.refs.set_symbolic_ref(b'HEAD', b'refs/heads/newbranch') + repo.refs.set_symbolic_ref(b"HEAD", b"refs/heads/newbranch") repo.reset_index() - output, error = self.run_bzr('switch oldbranch') - self.assertEqual(output, '') - self.assertTrue(error.startswith('Updated to revision 1.\n'), error) + output, error = self.run_bzr("switch oldbranch") + self.assertEqual(output, "") + self.assertTrue(error.startswith("Updated to revision 1.\n"), error) - self.assertFileEqual("text for a\n", 'a') - tree = WorkingTree.open('.') + self.assertFileEqual("text for a\n", "a") + tree = WorkingTree.open(".") with tree.lock_read(): basis_tree = tree.basis_tree() with basis_tree.lock_read(): self.assertEqual([], list(tree.iter_changes(basis_tree))) def test_branch_with_nested_trees(self): - orig = self.make_branch_and_tree('source', format='git') - subtree = self.make_branch_and_tree('source/subtree', format='git') - self.build_tree(['source/subtree/a']) - self.build_tree_contents([('source/.gitmodules', f"""[submodule "subtree"] + orig = self.make_branch_and_tree("source", format="git") + subtree = self.make_branch_and_tree("source/subtree", format="git") + self.build_tree(["source/subtree/a"]) + self.build_tree_contents( + [ + ( + "source/.gitmodules", + f"""[submodule "subtree"] path = subtree url = {subtree.user_url} -""")]) - subtree.add(['a']) - subtree.commit('add subtree contents') +""", + ) + ] + ) + subtree.add(["a"]) + subtree.commit("add subtree contents") orig.add_reference(subtree) - orig.add(['.gitmodules']) - orig.commit('add subtree') + orig.add([".gitmodules"]) + orig.commit("add subtree") - self.run_bzr('branch source target') + self.run_bzr("branch source target") - target = WorkingTree.open('target') - target_subtree = WorkingTree.open('target/subtree') + target = WorkingTree.open("target") + target_subtree = WorkingTree.open("target/subtree") self.assertTreesEqual(orig, target) self.assertTreesEqual(subtree, target_subtree) class SwitchScriptTests(TestCaseWithTransportAndScript): - def test_switch_preserves(self): # See https://bugs.launchpad.net/brz/+bug/1820606 - self.run_script(""" + self.run_script( + """ $ brz init --git r Created a standalone tree (format: git) $ cd r @@ -592,89 +672,86 @@ def test_switch_preserves(self): 2>Switched to branch other $ cat file.txt entered on master branch -""") +""" + ) class GrepTests(ExternalBase): - def test_simple_grep(self): - tree = self.make_branch_and_tree('.', format='git') - self.build_tree_contents([('a', 'text for a\n')]) - tree.add(['a']) - output, error = self.run_bzr('grep --color=never text') - self.assertEqual(output, 'a:text for a\n') - self.assertEqual(error, '') + tree = self.make_branch_and_tree(".", format="git") + self.build_tree_contents([("a", "text for a\n")]) + tree.add(["a"]) + output, error = self.run_bzr("grep --color=never text") + self.assertEqual(output, "a:text for a\n") + self.assertEqual(error, "") 
class ReconcileTests(ExternalBase): - def test_simple_reconcile(self): - tree = self.make_branch_and_tree('.', format='git') - self.build_tree_contents([('a', 'text for a\n')]) - tree.add(['a']) - output, error = self.run_bzr('reconcile') + tree = self.make_branch_and_tree(".", format="git") + self.build_tree_contents([("a", "text for a\n")]) + tree.add(["a"]) + output, error = self.run_bzr("reconcile") self.assertContainsRe( output, - 'Reconciling branch file://.*\n' - 'Reconciling repository file://.*\n' - 'Reconciliation complete.\n') - self.assertEqual(error, '') + "Reconciling branch file://.*\n" + "Reconciling repository file://.*\n" + "Reconciliation complete.\n", + ) + self.assertEqual(error, "") class StatusTests(ExternalBase): - def test_empty_dir(self): - tree = self.make_branch_and_tree('.', format='git') - self.build_tree(['a/', 'a/foo']) - self.build_tree_contents([('.gitignore', 'foo\n')]) - tree.add(['.gitignore']) - tree.commit('add ignore') - output, error = self.run_bzr('st') - self.assertEqual(output, '') - self.assertEqual(error, '') + tree = self.make_branch_and_tree(".", format="git") + self.build_tree(["a/", "a/foo"]) + self.build_tree_contents([(".gitignore", "foo\n")]) + tree.add([".gitignore"]) + tree.commit("add ignore") + output, error = self.run_bzr("st") + self.assertEqual(output, "") + self.assertEqual(error, "") class StatsTests(ExternalBase): - def test_simple_stats(self): - self.requireFeature(PluginLoadedFeature('stats')) - tree = self.make_branch_and_tree('.', format='git') - self.build_tree_contents([('a', 'text for a\n')]) - tree.add(['a']) - tree.commit('a commit', committer='Somebody ') - output, error = self.run_bzr('stats') - self.assertEqual(output, ' 1 Somebody \n') + self.requireFeature(PluginLoadedFeature("stats")) + tree = self.make_branch_and_tree(".", format="git") + self.build_tree_contents([("a", "text for a\n")]) + tree.add(["a"]) + tree.commit("a commit", committer="Somebody ") + output, error = self.run_bzr("stats") + self.assertEqual(output, " 1 Somebody \n") class GitObjectsTests(ExternalBase): - def run_simple(self, format): - tree = self.make_branch_and_tree('.', format=format) - self.build_tree(['a/', 'a/foo']) - tree.add(['a']) - tree.commit('add a') - output, error = self.run_bzr('git-objects') + tree = self.make_branch_and_tree(".", format=format) + self.build_tree(["a/", "a/foo"]) + tree.add(["a"]) + tree.commit("add a") + output, error = self.run_bzr("git-objects") shas = list(output.splitlines()) self.assertEqual([40, 40], [len(s) for s in shas]) - self.assertEqual(error, '') + self.assertEqual(error, "") - output, error = self.run_bzr(f'git-object {shas[0]}') - self.assertEqual('', error) + output, error = self.run_bzr(f"git-object {shas[0]}") + self.assertEqual("", error) def test_in_native(self): - self.run_simple(format='git') + self.run_simple(format="git") def test_in_bzr(self): - self.run_simple(format='2a') + self.run_simple(format="2a") class GitApplyTests(ExternalBase): - def test_apply(self): - self.make_branch_and_tree('.') + self.make_branch_and_tree(".") - with open('foo.patch', 'w') as f: - f.write("""\ + with open("foo.patch", "w") as f: + f.write( + """\ From bdefb25fab801e6af0a70e965f60cb48f2b759fa Mon Sep 17 00:00:00 2001 From: Dmitry Bogatov Date: Fri, 8 Feb 2019 23:28:30 +0000 @@ -694,9 +771,7 @@ def test_apply(self): +Update standards version, no changes needed. 
+Certainty: certain +Fixed-Lintian-Tags: out-of-date-standards-version -""") - output, error = self.run_bzr('git-apply foo.patch') - self.assertContainsRe( - error, - 'Committing to: .*\n' - 'Committed revision 1.\n') +""" + ) + output, error = self.run_bzr("git-apply foo.patch") + self.assertContainsRe(error, "Committing to: .*\n" "Committed revision 1.\n") diff --git a/breezy/git/tests/test_branch.py b/breezy/git/tests/test_branch.py index c751b1044e..3924a3ef90 100644 --- a/breezy/git/tests/test_branch.py +++ b/breezy/git/tests/test_branch.py @@ -34,62 +34,64 @@ class TestGitBranch(tests.TestCaseInTempDir): - def test_open_by_ref(self): - GitRepo.init('.') + GitRepo.init(".") url = "{},ref={}".format( urlutils.local_path_to_url(self.test_dir), - urlutils.quote("refs/remotes/origin/unstable", safe='') - ) + urlutils.quote("refs/remotes/origin/unstable", safe=""), + ) d = ControlDir.open(url) b = d.create_branch() self.assertEqual(b.ref, b"refs/remotes/origin/unstable") def test_open_existing(self): - GitRepo.init('.') - d = ControlDir.open('.') + GitRepo.init(".") + d = ControlDir.open(".") thebranch = d.create_branch() self.assertIsInstance(thebranch, branch.GitBranch) def test_repr(self): - GitRepo.init('.') - d = ControlDir.open('.') + GitRepo.init(".") + d = ControlDir.open(".") thebranch = d.create_branch() self.assertEqual( f"", - repr(thebranch)) + repr(thebranch), + ) def test_last_revision_is_null(self): - GitRepo.init('.') - thedir = ControlDir.open('.') + GitRepo.init(".") + thedir = ControlDir.open(".") thebranch = thedir.create_branch() self.assertEqual(revision.NULL_REVISION, thebranch.last_revision()) - self.assertEqual((0, revision.NULL_REVISION), - thebranch.last_revision_info()) + self.assertEqual((0, revision.NULL_REVISION), thebranch.last_revision_info()) def simple_commit_a(self): - r = GitRepo.init('.') - self.build_tree(['a']) + r = GitRepo.init(".") + self.build_tree(["a"]) r.stage(["a"]) return r.do_commit(b"a", committer=b"Somebody ") def test_last_revision_is_valid(self): head = self.simple_commit_a() - thebranch = Branch.open('.') - self.assertEqual(default_mapping.revision_id_foreign_to_bzr(head), - thebranch.last_revision()) + thebranch = Branch.open(".") + self.assertEqual( + default_mapping.revision_id_foreign_to_bzr(head), thebranch.last_revision() + ) def test_last_revision_info(self): self.simple_commit_a() - self.build_tree(['b']) + self.build_tree(["b"]) r = GitRepo(".") self.addCleanup(r.close) r.stage("b") revb = r.do_commit(b"b", committer=b"Somebody ") - thebranch = Branch.open('.') - self.assertEqual((2, default_mapping.revision_id_foreign_to_bzr( - revb)), thebranch.last_revision_info()) + thebranch = Branch.open(".") + self.assertEqual( + (2, default_mapping.revision_id_foreign_to_bzr(revb)), + thebranch.last_revision_info(), + ) def test_tag_annotated(self): reva = self.simple_commit_a() @@ -103,23 +105,26 @@ def test_tag_annotated(self): r = GitRepo(".") self.addCleanup(r.close) r.object_store.add_object(o) - r[b'refs/tags/foo'] = o.id - thebranch = Branch.open('.') - self.assertEqual({"foo": default_mapping.revision_id_foreign_to_bzr(reva)}, - thebranch.tags.get_tag_dict()) + r[b"refs/tags/foo"] = o.id + thebranch = Branch.open(".") + self.assertEqual( + {"foo": default_mapping.revision_id_foreign_to_bzr(reva)}, + thebranch.tags.get_tag_dict(), + ) def test_tag(self): reva = self.simple_commit_a() r = GitRepo(".") self.addCleanup(r.close) r.refs[b"refs/tags/foo"] = reva - thebranch = Branch.open('.') - self.assertEqual({"foo": 
default_mapping.revision_id_foreign_to_bzr(reva)}, - thebranch.tags.get_tag_dict()) + thebranch = Branch.open(".") + self.assertEqual( + {"foo": default_mapping.revision_id_foreign_to_bzr(reva)}, + thebranch.tags.get_tag_dict(), + ) class TestWithGitBranch(tests.TestCaseWithTransport): - def setUp(self): tests.TestCaseWithTransport.setUp(self) dulwich.repo.Repo.create(self.test_dir) @@ -130,22 +135,19 @@ def test_get_parent(self): self.assertIs(None, self.git_branch.get_parent()) def test_get_stacked_on_url(self): - self.assertRaises(UnstackableBranchFormat, - self.git_branch.get_stacked_on_url) + self.assertRaises(UnstackableBranchFormat, self.git_branch.get_stacked_on_url) def test_get_physical_lock_status(self): self.assertFalse(self.git_branch.get_physical_lock_status()) class TestLocalGitBranchFormat(tests.TestCase): - def setUp(self): super().setUp() self.format = branch.LocalGitBranchFormat() def test_get_format_description(self): - self.assertEqual("Local Git Branch", - self.format.get_format_description()) + self.assertEqual("Local Git Branch", self.format.get_format_description()) def test_get_network_name(self): self.assertEqual(b"git", self.format.network_name()) @@ -155,11 +157,10 @@ def test_supports_tags(self): class BranchTests(tests.TestCaseInTempDir): - def make_onerev_branch(self): os.mkdir("d") os.chdir("d") - GitRepo.init('.') + GitRepo.init(".") bb = tests.GitBranchBuilder() bb.set_file("foobar", b"foo\nbar\n", False) mark = bb.commit(b"Somebody ", b"mymsg") @@ -170,7 +171,7 @@ def make_onerev_branch(self): def make_tworev_branch(self): os.mkdir("d") os.chdir("d") - GitRepo.init('.') + GitRepo.init(".") bb = tests.GitBranchBuilder() bb.set_file("foobar", b"foo\nbar\n", False) mark1 = bb.commit(b"Somebody ", b"mymsg") @@ -210,26 +211,32 @@ def test_sprouted_ghost_tags(self): r.refs[b"refs/tags/lala"] = b"aa" * 20 oldrepo = Repository.open(path) oldrepo.get_mapping().revision_id_foreign_to_bzr(gitsha) - warnings, newbranch = self.callCatchWarnings( - self.clone_git_branch, path, "f") + warnings, newbranch = self.callCatchWarnings(self.clone_git_branch, path, "f") self.assertEqual({}, newbranch.tags.get_tag_dict()) # Dulwich raises a UserWarning for tags with invalid target - self.assertIn(('ref refs/tags/lala points at non-present sha ' + ("aa" * 20), ), [w.args for w in warnings]) + self.assertIn( + ("ref refs/tags/lala points at non-present sha " + ("aa" * 20),), + [w.args for w in warnings], + ) def test_interbranch_pull_submodule(self): path = "d" os.mkdir(path) os.chdir(path) - GitRepo.init('.') + GitRepo.init(".") bb = tests.GitBranchBuilder() bb.set_file("foobar", b"foo\nbar\n", False) mark1 = bb.commit(b"Somebody ", b"mymsg") - bb.set_submodule("core", b'102ee7206ebc4227bec8ac02450972e6738f4a33') - bb.set_file('.gitmodules', b"""\ + bb.set_submodule("core", b"102ee7206ebc4227bec8ac02450972e6738f4a33") + bb.set_file( + ".gitmodules", + b"""\ [submodule "core"] path = core url = https://github.com/phhusson/QuasselC.git -""", False) +""", + False, + ) mark2 = bb.commit(b"Somebody ", b"mymsg") marks = bb.finish() os.chdir("..") @@ -237,19 +244,20 @@ def test_interbranch_pull_submodule(self): gitsha2 = marks[mark2] oldrepo = Repository.open(path) revid2 = oldrepo.get_mapping().revision_id_foreign_to_bzr(gitsha2) - newbranch = self.make_branch('g') + newbranch = self.make_branch("g") inter_branch = InterBranch.get(Branch.open(path), newbranch) inter_branch.pull() self.assertEqual(revid2, newbranch.last_revision()) self.assertEqual( - 
('https://github.com/phhusson/QuasselC.git', 'core'), - newbranch.get_reference_info(newbranch.basis_tree().path2id('core'))) + ("https://github.com/phhusson/QuasselC.git", "core"), + newbranch.get_reference_info(newbranch.basis_tree().path2id("core")), + ) def test_interbranch_pull(self): path, (gitsha1, gitsha2) = self.make_tworev_branch() oldrepo = Repository.open(path) revid2 = oldrepo.get_mapping().revision_id_foreign_to_bzr(gitsha2) - newbranch = self.make_branch('g') + newbranch = self.make_branch("g") inter_branch = InterBranch.get(Branch.open(path), newbranch) inter_branch.pull() self.assertEqual(revid2, newbranch.last_revision()) @@ -258,7 +266,7 @@ def test_interbranch_pull_noop(self): path, (gitsha1, gitsha2) = self.make_tworev_branch() oldrepo = Repository.open(path) revid2 = oldrepo.get_mapping().revision_id_foreign_to_bzr(gitsha2) - newbranch = self.make_branch('g') + newbranch = self.make_branch("g") inter_branch = InterBranch.get(Branch.open(path), newbranch) inter_branch.pull() # This is basically "assertNotRaises" @@ -269,7 +277,7 @@ def test_interbranch_pull_stop_revision(self): path, (gitsha1, gitsha2) = self.make_tworev_branch() oldrepo = Repository.open(path) revid1 = oldrepo.get_mapping().revision_id_foreign_to_bzr(gitsha1) - newbranch = self.make_branch('g') + newbranch = self.make_branch("g") inter_branch = InterBranch.get(Branch.open(path), newbranch) inter_branch.pull(stop_revision=revid1) self.assertEqual(revid1, newbranch.last_revision()) @@ -282,7 +290,7 @@ def test_interbranch_pull_with_tags(self): oldrepo = Repository.open(path) revid1 = oldrepo.get_mapping().revision_id_foreign_to_bzr(gitsha1) revid2 = oldrepo.get_mapping().revision_id_foreign_to_bzr(gitsha2) - newbranch = self.make_branch('g') + newbranch = self.make_branch("g") source_branch = Branch.open(path) source_branch.get_config().set_user_option("branch.fetch_tags", True) inter_branch = InterBranch.get(source_branch, newbranch) @@ -292,23 +300,21 @@ def test_interbranch_pull_with_tags(self): def test_bzr_branch_bound_to_git(self): path, (gitsha1, gitsha2) = self.make_tworev_branch() - wt = Branch.open(path).create_checkout('co') - self.build_tree_contents([('co/foobar', b'blah')]) + wt = Branch.open(path).create_checkout("co") + self.build_tree_contents([("co/foobar", b"blah")]) self.assertRaises( - errors.NoRoundtrippingSupport, wt.commit, - 'commit from bound branch.') - revid = wt.commit('commit from bound branch.', lossy=True) + errors.NoRoundtrippingSupport, wt.commit, "commit from bound branch." 
+ ) + revid = wt.commit("commit from bound branch.", lossy=True) self.assertEqual(revid, wt.branch.last_revision()) - self.assertEqual( - revid, - wt.branch.get_master_branch().last_revision()) + self.assertEqual(revid, wt.branch.get_master_branch().last_revision()) def test_interbranch_pull_older(self): path, (gitsha1, gitsha2) = self.make_tworev_branch() oldrepo = Repository.open(path) revid1 = oldrepo.get_mapping().revision_id_foreign_to_bzr(gitsha1) revid2 = oldrepo.get_mapping().revision_id_foreign_to_bzr(gitsha2) - newbranch = self.make_branch('g') + newbranch = self.make_branch("g") inter_branch = InterBranch.get(Branch.open(path), newbranch) inter_branch.pull(stop_revision=revid2) inter_branch.pull(stop_revision=revid1) @@ -316,7 +322,6 @@ def test_interbranch_pull_older(self): class ForeignTestsBranchFactory: - def make_empty_branch(self, transport): d = LocalGitControlDirFormat().initialize_on_transport(transport) return d.create_branch() diff --git a/breezy/git/tests/test_builder.py b/breezy/git/tests/test_builder.py index 6123b6b255..c9e53d435c 100644 --- a/breezy/git/tests/test_builder.py +++ b/breezy/git/tests/test_builder.py @@ -24,235 +24,244 @@ class TestGitBranchBuilder(tests.TestCase): - def test__create_blob(self): stream = BytesIO() builder = tests.GitBranchBuilder(stream) - self.assertEqual(1, builder._create_blob(b'foo\nbar\n')) - self.assertEqualDiff(b'blob\nmark :1\ndata 8\nfoo\nbar\n\n', - stream.getvalue()) + self.assertEqual(1, builder._create_blob(b"foo\nbar\n")) + self.assertEqualDiff(b"blob\nmark :1\ndata 8\nfoo\nbar\n\n", stream.getvalue()) def test_set_file(self): stream = BytesIO() builder = tests.GitBranchBuilder(stream) - builder.set_file('foobar', b'foo\nbar\n', False) - self.assertEqualDiff(b'blob\nmark :1\ndata 8\nfoo\nbar\n\n', - stream.getvalue()) - self.assertEqual([b'M 100644 :1 foobar\n'], builder.commit_info) + builder.set_file("foobar", b"foo\nbar\n", False) + self.assertEqualDiff(b"blob\nmark :1\ndata 8\nfoo\nbar\n\n", stream.getvalue()) + self.assertEqual([b"M 100644 :1 foobar\n"], builder.commit_info) def test_set_file_unicode(self): stream = BytesIO() builder = tests.GitBranchBuilder(stream) - builder.set_file('f\xb5/bar', b'contents\nbar\n', False) - self.assertEqualDiff(b'blob\nmark :1\ndata 13\ncontents\nbar\n\n', - stream.getvalue()) - self.assertEqual([b'M 100644 :1 f\xc2\xb5/bar\n'], builder.commit_info) + builder.set_file("f\xb5/bar", b"contents\nbar\n", False) + self.assertEqualDiff( + b"blob\nmark :1\ndata 13\ncontents\nbar\n\n", stream.getvalue() + ) + self.assertEqual([b"M 100644 :1 f\xc2\xb5/bar\n"], builder.commit_info) def test_set_file_newline(self): stream = BytesIO() builder = tests.GitBranchBuilder(stream) - builder.set_file('foo\nbar', b'contents\nbar\n', False) - self.assertEqualDiff(b'blob\nmark :1\ndata 13\ncontents\nbar\n\n', - stream.getvalue()) + builder.set_file("foo\nbar", b"contents\nbar\n", False) + self.assertEqualDiff( + b"blob\nmark :1\ndata 13\ncontents\nbar\n\n", stream.getvalue() + ) self.assertEqual([b'M 100644 :1 "foo\\nbar"\n'], builder.commit_info) def test_set_file_executable(self): stream = BytesIO() builder = tests.GitBranchBuilder(stream) - builder.set_file('f\xb5/bar', b'contents\nbar\n', True) - self.assertEqualDiff(b'blob\nmark :1\ndata 13\ncontents\nbar\n\n', - stream.getvalue()) - self.assertEqual([b'M 100755 :1 f\xc2\xb5/bar\n'], builder.commit_info) + builder.set_file("f\xb5/bar", b"contents\nbar\n", True) + self.assertEqualDiff( + b"blob\nmark :1\ndata 13\ncontents\nbar\n\n", 
stream.getvalue() + ) + self.assertEqual([b"M 100755 :1 f\xc2\xb5/bar\n"], builder.commit_info) def test_set_symlink(self): stream = BytesIO() builder = tests.GitBranchBuilder(stream) - builder.set_symlink('f\xb5/bar', b'link/contents') - self.assertEqualDiff(b'blob\nmark :1\ndata 13\nlink/contents\n', - stream.getvalue()) - self.assertEqual([b'M 120000 :1 f\xc2\xb5/bar\n'], builder.commit_info) + builder.set_symlink("f\xb5/bar", b"link/contents") + self.assertEqualDiff( + b"blob\nmark :1\ndata 13\nlink/contents\n", stream.getvalue() + ) + self.assertEqual([b"M 120000 :1 f\xc2\xb5/bar\n"], builder.commit_info) def test_set_symlink_newline(self): stream = BytesIO() builder = tests.GitBranchBuilder(stream) - builder.set_symlink('foo\nbar', 'link/contents') - self.assertEqualDiff(b'blob\nmark :1\ndata 13\nlink/contents\n', - stream.getvalue()) + builder.set_symlink("foo\nbar", "link/contents") + self.assertEqualDiff( + b"blob\nmark :1\ndata 13\nlink/contents\n", stream.getvalue() + ) self.assertEqual([b'M 120000 :1 "foo\\nbar"\n'], builder.commit_info) def test_delete_entry(self): stream = BytesIO() builder = tests.GitBranchBuilder(stream) - builder.delete_entry('path/to/f\xb5') - self.assertEqual([b'D path/to/f\xc2\xb5\n'], builder.commit_info) + builder.delete_entry("path/to/f\xb5") + self.assertEqual([b"D path/to/f\xc2\xb5\n"], builder.commit_info) def test_delete_entry_newline(self): stream = BytesIO() builder = tests.GitBranchBuilder(stream) - builder.delete_entry('path/to/foo\nbar') + builder.delete_entry("path/to/foo\nbar") self.assertEqual([b'D "path/to/foo\\nbar"\n'], builder.commit_info) def test_encode_path(self): encode = tests.GitBranchBuilder._encode_path # Unicode is encoded to utf-8 - self.assertEqual(encode('f\xb5'), b'f\xc2\xb5') + self.assertEqual(encode("f\xb5"), b"f\xc2\xb5") # The name must be quoted if it starts by a double quote or contains a # newline. self.assertEqual(encode('"foo'), b'"\\"foo"') - self.assertEqual(encode('fo\no'), b'"fo\\no"') + self.assertEqual(encode("fo\no"), b'"fo\\no"') # When the name is quoted, all backslash and quote chars must be # escaped. - self.assertEqual(encode('fo\\o\nbar'), b'"fo\\\\o\\nbar"') + self.assertEqual(encode("fo\\o\nbar"), b'"fo\\\\o\\nbar"') self.assertEqual(encode('fo"o"\nbar'), b'"fo\\"o\\"\\nbar"') # Other control chars, such as \r, need not be escaped. 
- self.assertEqual(encode('foo\r\nbar'), b'"foo\r\\nbar"') + self.assertEqual(encode("foo\r\nbar"), b'"foo\r\\nbar"') def test_add_and_commit(self): stream = BytesIO() builder = tests.GitBranchBuilder(stream) - builder.set_file('f\xb5/bar', b'contents\nbar\n', False) - self.assertEqual(b'2', builder.commit(b'Joe Foo ', - 'committing f\xb5/bar', - timestamp=1194586400, - timezone=b'+0100')) - self.assertEqualDiff(b'blob\nmark :1\ndata 13\ncontents\nbar\n\n' - b'commit refs/heads/master\n' - b'mark :2\n' - b'committer Joe Foo 1194586400 +0100\n' - b'data 18\n' - b'committing f\xc2\xb5/bar' - b'\n' - b'M 100644 :1 f\xc2\xb5/bar\n' - b'\n', - stream.getvalue()) + builder.set_file("f\xb5/bar", b"contents\nbar\n", False) + self.assertEqual( + b"2", + builder.commit( + b"Joe Foo ", + "committing f\xb5/bar", + timestamp=1194586400, + timezone=b"+0100", + ), + ) + self.assertEqualDiff( + b"blob\nmark :1\ndata 13\ncontents\nbar\n\n" + b"commit refs/heads/master\n" + b"mark :2\n" + b"committer Joe Foo 1194586400 +0100\n" + b"data 18\n" + b"committing f\xc2\xb5/bar" + b"\n" + b"M 100644 :1 f\xc2\xb5/bar\n" + b"\n", + stream.getvalue(), + ) def test_commit_base(self): stream = BytesIO() builder = tests.GitBranchBuilder(stream) - builder.set_file('foo', b'contents\nfoo\n', False) - r1 = builder.commit(b'Joe Foo ', 'first', - timestamp=1194586400) - builder.commit(b'Joe Foo ', 'second', - timestamp=1194586405) - builder.commit(b'Joe Foo ', 'third', - timestamp=1194586410, - base=r1) + builder.set_file("foo", b"contents\nfoo\n", False) + r1 = builder.commit(b"Joe Foo ", "first", timestamp=1194586400) + builder.commit(b"Joe Foo ", "second", timestamp=1194586405) + builder.commit(b"Joe Foo ", "third", timestamp=1194586410, base=r1) - self.assertEqualDiff(b'blob\nmark :1\ndata 13\ncontents\nfoo\n\n' - b'commit refs/heads/master\n' - b'mark :2\n' - b'committer Joe Foo 1194586400 +0000\n' - b'data 5\n' - b'first' - b'\n' - b'M 100644 :1 foo\n' - b'\n' - b'commit refs/heads/master\n' - b'mark :3\n' - b'committer Joe Foo 1194586405 +0000\n' - b'data 6\n' - b'second' - b'\n' - b'\n' - b'commit refs/heads/master\n' - b'mark :4\n' - b'committer Joe Foo 1194586410 +0000\n' - b'data 5\n' - b'third' - b'\n' - b'from :2\n' - b'\n', stream.getvalue()) + self.assertEqualDiff( + b"blob\nmark :1\ndata 13\ncontents\nfoo\n\n" + b"commit refs/heads/master\n" + b"mark :2\n" + b"committer Joe Foo 1194586400 +0000\n" + b"data 5\n" + b"first" + b"\n" + b"M 100644 :1 foo\n" + b"\n" + b"commit refs/heads/master\n" + b"mark :3\n" + b"committer Joe Foo 1194586405 +0000\n" + b"data 6\n" + b"second" + b"\n" + b"\n" + b"commit refs/heads/master\n" + b"mark :4\n" + b"committer Joe Foo 1194586410 +0000\n" + b"data 5\n" + b"third" + b"\n" + b"from :2\n" + b"\n", + stream.getvalue(), + ) def test_commit_merge(self): stream = BytesIO() builder = tests.GitBranchBuilder(stream) - builder.set_file('foo', b'contents\nfoo\n', False) - r1 = builder.commit(b'Joe Foo ', 'first', - timestamp=1194586400) - r2 = builder.commit(b'Joe Foo ', 'second', - timestamp=1194586405) - builder.commit(b'Joe Foo ', 'third', - timestamp=1194586410, - base=r1) - builder.commit(b'Joe Foo ', 'Merge', - timestamp=1194586415, - merge=[r2]) + builder.set_file("foo", b"contents\nfoo\n", False) + r1 = builder.commit(b"Joe Foo ", "first", timestamp=1194586400) + r2 = builder.commit(b"Joe Foo ", "second", timestamp=1194586405) + builder.commit(b"Joe Foo ", "third", timestamp=1194586410, base=r1) + builder.commit( + b"Joe Foo ", "Merge", timestamp=1194586415, merge=[r2] + ) - 
self.assertEqualDiff(b'blob\nmark :1\ndata 13\ncontents\nfoo\n\n' - b'commit refs/heads/master\n' - b'mark :2\n' - b'committer Joe Foo 1194586400 +0000\n' - b'data 5\n' - b'first' - b'\n' - b'M 100644 :1 foo\n' - b'\n' - b'commit refs/heads/master\n' - b'mark :3\n' - b'committer Joe Foo 1194586405 +0000\n' - b'data 6\n' - b'second' - b'\n' - b'\n' - b'commit refs/heads/master\n' - b'mark :4\n' - b'committer Joe Foo 1194586410 +0000\n' - b'data 5\n' - b'third' - b'\n' - b'from :2\n' - b'\n' - b'commit refs/heads/master\n' - b'mark :5\n' - b'committer Joe Foo 1194586415 +0000\n' - b'data 5\n' - b'Merge' - b'\n' - b'merge :3\n' - b'\n', stream.getvalue()) + self.assertEqualDiff( + b"blob\nmark :1\ndata 13\ncontents\nfoo\n\n" + b"commit refs/heads/master\n" + b"mark :2\n" + b"committer Joe Foo 1194586400 +0000\n" + b"data 5\n" + b"first" + b"\n" + b"M 100644 :1 foo\n" + b"\n" + b"commit refs/heads/master\n" + b"mark :3\n" + b"committer Joe Foo 1194586405 +0000\n" + b"data 6\n" + b"second" + b"\n" + b"\n" + b"commit refs/heads/master\n" + b"mark :4\n" + b"committer Joe Foo 1194586410 +0000\n" + b"data 5\n" + b"third" + b"\n" + b"from :2\n" + b"\n" + b"commit refs/heads/master\n" + b"mark :5\n" + b"committer Joe Foo 1194586415 +0000\n" + b"data 5\n" + b"Merge" + b"\n" + b"merge :3\n" + b"\n", + stream.getvalue(), + ) def test_auto_timestamp(self): stream = BytesIO() builder = tests.GitBranchBuilder(stream) - builder.commit(b'Joe Foo ', 'message') - self.assertContainsRe(stream.getvalue(), - br'committer Joe Foo \d+ \+0000') + builder.commit(b"Joe Foo ", "message") + self.assertContainsRe( + stream.getvalue(), rb"committer Joe Foo \d+ \+0000" + ) def test_reset(self): stream = BytesIO() builder = tests.GitBranchBuilder(stream) builder.reset() - self.assertEqualDiff(b'reset refs/heads/master\n\n', stream.getvalue()) + self.assertEqualDiff(b"reset refs/heads/master\n\n", stream.getvalue()) def test_reset_named_ref(self): stream = BytesIO() builder = tests.GitBranchBuilder(stream) - builder.reset(b'refs/heads/branch') - self.assertEqualDiff(b'reset refs/heads/branch\n\n', stream.getvalue()) + builder.reset(b"refs/heads/branch") + self.assertEqualDiff(b"reset refs/heads/branch\n\n", stream.getvalue()) def test_reset_revision(self): stream = BytesIO() builder = tests.GitBranchBuilder(stream) - builder.reset(mark=b'123') + builder.reset(mark=b"123") self.assertEqualDiff( - b'reset refs/heads/master\n' - b'from :123\n' - b'\n', stream.getvalue()) + b"reset refs/heads/master\n" b"from :123\n" b"\n", stream.getvalue() + ) class TestGitBranchBuilderReal(tests.TestCaseInTempDir): - def test_create_real_branch(self): GitRepo.init(".") builder = tests.GitBranchBuilder() - builder.set_file('foo', b'contents\nfoo\n', False) - builder.commit(b'Joe Foo ', 'first', - timestamp=1194586400) + builder.set_file("foo", b"contents\nfoo\n", False) + builder.commit(b"Joe Foo ", "first", timestamp=1194586400) mapping = builder.finish() - self.assertEqual({b'1': b'44411e8e9202177dd19b6599d7a7991059fa3cb4', - b'2': b'b0b62e674f67306fddcf72fa888c3b56df100d64', - }, mapping) + self.assertEqual( + { + b"1": b"44411e8e9202177dd19b6599d7a7991059fa3cb4", + b"2": b"b0b62e674f67306fddcf72fa888c3b56df100d64", + }, + mapping, + ) diff --git a/breezy/git/tests/test_cache.py b/breezy/git/tests/test_cache.py index b4d0ca3c30..2d67c663c0 100644 --- a/breezy/git/tests/test_cache.py +++ b/breezy/git/tests/test_cache.py @@ -34,7 +34,6 @@ class TestGitShaMap: - def _get_test_commit(self): c = Commit() c.committer = b"Jelmer " @@ -49,63 
+48,118 @@ def _get_test_commit(self): def test_commit(self): self.map.start_write_group() - updater = self.cache.get_updater(Revision(b"myrevid", parent_ids=[], message='', committer='', timezone=0, timestamp=0, properties={}, inventory_sha1=None)) + updater = self.cache.get_updater( + Revision( + b"myrevid", + parent_ids=[], + message="", + committer="", + timezone=0, + timestamp=0, + properties={}, + inventory_sha1=None, + ) + ) c = self._get_test_commit() - updater.add_object(c, { - "testament3-sha1": b"cc9462f7f8263ef5adf8eff2fb936bb36b504cba"}, - None) + updater.add_object( + c, {"testament3-sha1": b"cc9462f7f8263ef5adf8eff2fb936bb36b504cba"}, None + ) updater.finish() self.map.commit_write_group() self.assertEqual( - [("commit", (b"myrevid", - b"cc9462f7f8263ef5adfbeff2fb936bb36b504cba", - {"testament3-sha1": b"cc9462f7f8263ef5adf8eff2fb936bb36b504cba"}, - ))], - list(self.map.lookup_git_sha(c.id))) + [ + ( + "commit", + ( + b"myrevid", + b"cc9462f7f8263ef5adfbeff2fb936bb36b504cba", + { + "testament3-sha1": b"cc9462f7f8263ef5adf8eff2fb936bb36b504cba" + }, + ), + ) + ], + list(self.map.lookup_git_sha(c.id)), + ) self.assertEqual(c.id, self.map.lookup_commit(b"myrevid")) def test_lookup_notfound(self): - self.assertRaises(KeyError, list, - self.map.lookup_git_sha(b"5686645d49063c73d35436192dfc9a160c672301")) + self.assertRaises( + KeyError, + list, + self.map.lookup_git_sha(b"5686645d49063c73d35436192dfc9a160c672301"), + ) def test_blob(self): self.map.start_write_group() - updater = self.cache.get_updater(Revision(b"myrevid", parent_ids=[], message='', committer='', timezone=0, timestamp=0, properties={}, inventory_sha1=None)) - updater.add_object(self._get_test_commit(), { - "testament3-sha1": b"Test"}, None) + updater = self.cache.get_updater( + Revision( + b"myrevid", + parent_ids=[], + message="", + committer="", + timezone=0, + timestamp=0, + properties={}, + inventory_sha1=None, + ) + ) + updater.add_object(self._get_test_commit(), {"testament3-sha1": b"Test"}, None) b = Blob() b.data = b"TEH BLOB" updater.add_object(b, (b"myfileid", b"myrevid"), None) updater.finish() self.map.commit_write_group() self.assertEqual( - [("blob", (b"myfileid", b"myrevid"))], - list(self.map.lookup_git_sha(b.id))) - self.assertEqual(b.id, - self.map.lookup_blob_id(b"myfileid", b"myrevid")) + [("blob", (b"myfileid", b"myrevid"))], list(self.map.lookup_git_sha(b.id)) + ) + self.assertEqual(b.id, self.map.lookup_blob_id(b"myfileid", b"myrevid")) def test_tree(self): self.map.start_write_group() - updater = self.cache.get_updater(Revision(b"somerevid", parent_ids=[], message='', committer='', timezone=0, timestamp=0, properties={}, inventory_sha1=None)) - updater.add_object(self._get_test_commit(), { - "testament3-sha1": b"mytestamentsha"}, None) + updater = self.cache.get_updater( + Revision( + b"somerevid", + parent_ids=[], + message="", + committer="", + timezone=0, + timestamp=0, + properties={}, + inventory_sha1=None, + ) + ) + updater.add_object( + self._get_test_commit(), {"testament3-sha1": b"mytestamentsha"}, None + ) t = Tree() t.add(b"somename", stat.S_IFREG, Blob().id) updater.add_object(t, (b"fileid", b"myrevid"), b"") updater.finish() self.map.commit_write_group() - self.assertEqual([("tree", (b"fileid", b"myrevid"))], - list(self.map.lookup_git_sha(t.id))) + self.assertEqual( + [("tree", (b"fileid", b"myrevid"))], list(self.map.lookup_git_sha(t.id)) + ) # It's possible for a backend to not implement lookup_tree try: - self.assertEqual(t.id, - self.map.lookup_tree_id(b"fileid", 
b"myrevid")) + self.assertEqual(t.id, self.map.lookup_tree_id(b"fileid", b"myrevid")) except NotImplementedError: pass def test_revids(self): self.map.start_write_group() - updater = self.cache.get_updater(Revision(b"myrevid", parent_ids=[], message='', committer='', timezone=0, timestamp=0, properties={}, inventory_sha1=None)) + updater = self.cache.get_updater( + Revision( + b"myrevid", + parent_ids=[], + message="", + committer="", + timezone=0, + timestamp=0, + properties={}, + inventory_sha1=None, + ) + ) c = self._get_test_commit() updater.add_object(c, {"testament3-sha1": b"mtestament"}, None) updater.finish() @@ -114,17 +168,29 @@ def test_revids(self): def test_missing_revisions(self): self.map.start_write_group() - updater = self.cache.get_updater(Revision(b"myrevid", parent_ids=[], message='', committer='', timezone=0, timestamp=0, properties={}, inventory_sha1=None)) + updater = self.cache.get_updater( + Revision( + b"myrevid", + parent_ids=[], + message="", + committer="", + timezone=0, + timestamp=0, + properties={}, + inventory_sha1=None, + ) + ) c = self._get_test_commit() updater.add_object(c, {"testament3-sha1": b"testament"}, None) updater.finish() self.map.commit_write_group() - self.assertEqual({b"lala", b"bla"}, - set(self.map.missing_revisions([b"myrevid", b"lala", b"bla"]))) + self.assertEqual( + {b"lala", b"bla"}, + set(self.map.missing_revisions([b"myrevid", b"lala", b"bla"])), + ) class DictGitShaMapTests(TestCase, TestGitShaMap): - def setUp(self): TestCase.setUp(self) self.cache = DictBzrGitCache() @@ -132,26 +198,23 @@ def setUp(self): class SqliteGitShaMapTests(TestCaseInTempDir, TestGitShaMap): - def setUp(self): TestCaseInTempDir.setUp(self) - self.cache = SqliteBzrGitCache(os.path.join(self.test_dir, 'foo.db')) + self.cache = SqliteBzrGitCache(os.path.join(self.test_dir, "foo.db")) self.map = self.cache.idmap class TdbGitShaMapTests(TestCaseInTempDir, TestGitShaMap): - def setUp(self): TestCaseInTempDir.setUp(self) try: - self.cache = TdbBzrGitCache(os.path.join(self.test_dir, 'foo.tdb')) + self.cache = TdbBzrGitCache(os.path.join(self.test_dir, "foo.tdb")) except ModuleNotFoundError as err: raise UnavailableFeature("Missing tdb") from err self.map = self.cache.idmap class IndexGitShaMapTests(TestCaseInTempDir, TestGitShaMap): - def setUp(self): TestCaseInTempDir.setUp(self) transport = get_transport(self.test_dir) diff --git a/breezy/git/tests/test_dir.py b/breezy/git/tests/test_dir.py index 0989c9f00f..32ff3a8a36 100644 --- a/breezy/git/tests/test_dir.py +++ b/breezy/git/tests/test_dir.py @@ -27,85 +27,84 @@ class TestGitDir(tests.TestCaseInTempDir): - def test_get_head_branch_reference(self): GitRepo.init(".") - gd = controldir.ControlDir.open('.') + gd = controldir.ControlDir.open(".") self.assertEqual( f"{urlutils.local_path_to_url(os.path.abspath('.'))},branch=master", - gd.get_branch_reference()) + gd.get_branch_reference(), + ) def test_get_reference_loop(self): r = GitRepo.init(".") - r.refs.set_symbolic_ref(b'refs/heads/loop', b'refs/heads/loop') + r.refs.set_symbolic_ref(b"refs/heads/loop", b"refs/heads/loop") - gd = controldir.ControlDir.open('.') + gd = controldir.ControlDir.open(".") self.assertRaises( - controldir.BranchReferenceLoop, - gd.get_branch_reference, name='loop') + controldir.BranchReferenceLoop, gd.get_branch_reference, name="loop" + ) def test_open_reference_loop(self): r = GitRepo.init(".") - r.refs.set_symbolic_ref(b'refs/heads/loop', b'refs/heads/loop') + r.refs.set_symbolic_ref(b"refs/heads/loop", b"refs/heads/loop") - gd = 
controldir.ControlDir.open('.') - self.assertRaises( - controldir.BranchReferenceLoop, - gd.open_branch, name='loop') + gd = controldir.ControlDir.open(".") + self.assertRaises(controldir.BranchReferenceLoop, gd.open_branch, name="loop") def test_open_existing(self): GitRepo.init(".") - gd = controldir.ControlDir.open('.') + gd = controldir.ControlDir.open(".") self.assertIsInstance(gd, dir.LocalGitDir) def test_open_ref_parent(self): r = GitRepo.init(".") - r.do_commit(message=b"message", ref=b'refs/heads/foo/bar') - gd = controldir.ControlDir.open('.') - self.assertRaises(errors.NotBranchError, gd.open_branch, 'foo') + r.do_commit(message=b"message", ref=b"refs/heads/foo/bar") + gd = controldir.ControlDir.open(".") + self.assertRaises(errors.NotBranchError, gd.open_branch, "foo") def test_open_workingtree(self): r = GitRepo.init(".") r.do_commit(message=b"message") - gd = controldir.ControlDir.open('.') + gd = controldir.ControlDir.open(".") wt = gd.open_workingtree() self.assertIsInstance(wt, workingtree.GitWorkingTree) def test_open_workingtree_bare(self): GitRepo.init_bare(".") - gd = controldir.ControlDir.open('.') + gd = controldir.ControlDir.open(".") self.assertRaises(errors.NoWorkingTree, gd.open_workingtree) def test_git_file(self): gitrepo = GitRepo.init("blah", mkdir=True) - self.build_tree_contents( - [('foo/', ), ('foo/.git', b'gitdir: ../blah/.git\n')]) + self.build_tree_contents([("foo/",), ("foo/.git", b"gitdir: ../blah/.git\n")]) - gd = controldir.ControlDir.open('foo') - self.assertEqual(gd.control_url.rstrip('/'), - urlutils.local_path_to_url(os.path.abspath(gitrepo.controldir()))) + gd = controldir.ControlDir.open("foo") + self.assertEqual( + gd.control_url.rstrip("/"), + urlutils.local_path_to_url(os.path.abspath(gitrepo.controldir())), + ) def test_shared_repository(self): - t = get_transport('.') + t = get_transport(".") self.assertRaises( errors.SharedRepositoriesUnsupported, - dir.LocalGitControlDirFormat().initialize_on_transport_ex, t, - shared_repo=True) + dir.LocalGitControlDirFormat().initialize_on_transport_ex, + t, + shared_repo=True, + ) class TestGitDirFormat(tests.TestCase): - def setUp(self): super().setUp() self.format = dir.LocalGitControlDirFormat() def test_get_format_description(self): - self.assertEqual("Local Git Repository", - self.format.get_format_description()) + self.assertEqual("Local Git Repository", self.format.get_format_description()) def test_eq(self): format2 = dir.LocalGitControlDirFormat() diff --git a/breezy/git/tests/test_fetch.py b/breezy/git/tests/test_fetch.py index d1f6b28627..5ba75c0fee 100644 --- a/breezy/git/tests/test_fetch.py +++ b/breezy/git/tests/test_fetch.py @@ -36,7 +36,6 @@ class RepositoryFetchTests: - def make_git_repo(self, path): os.mkdir(path) return GitRepo.init(os.path.abspath(path)) @@ -95,8 +94,7 @@ def test_incremental(self): self.assertEqual([revid1], newrepo.all_revision_ids()) revid2 = oldrepo.get_mapping().revision_id_foreign_to_bzr(gitsha2) newrepo.fetch(oldrepo, revision_id=revid2) - self.assertEqual({revid1, revid2}, - set(newrepo.all_revision_ids())) + self.assertEqual({revid1, revid2}, set(newrepo.all_revision_ids())) def test_dir_becomes_symlink(self): self.make_git_repo("d") @@ -238,11 +236,11 @@ def test_into_stacked_on(self): tree.branch.repository.check() self.addCleanup(tree.lock_read().unlock) self.assertEqual( - {(revid2,)}, - tree.branch.repository.revisions.without_fallbacks().keys()) + {(revid2,)}, tree.branch.repository.revisions.without_fallbacks().keys() + ) self.assertEqual( - 
{revid1, revid2}, - set(tree.branch.repository.all_revision_ids())) + {revid1, revid2}, set(tree.branch.repository.all_revision_ids()) + ) def test_non_ascii_characters(self): self.make_git_repo("d") @@ -283,13 +281,11 @@ def test_tagged_tree(self): class LocalRepositoryFetchTests(RepositoryFetchTests, TestCaseWithTransport): - def open_git_repo(self, path): return Repository.open(path) class DummyStoreUpdater: - def add_object(self, obj, ie, path): pass @@ -298,48 +294,77 @@ def finish(self): class ImportObjects(TestCaseWithTransport): - def setUp(self): super().setUp() self._mapping = BzrGitMappingv1() factory = knit.make_file_factory(True, versionedfile.PrefixMapper()) - self._texts = factory(self.get_transport('texts')) + self._texts = factory(self.get_transport("texts")) def test_import_blob_missing_in_one_parent(self): - builder = self.make_branch_builder('br') + builder = self.make_branch_builder("br") builder.start_series() - rev_root = builder.build_snapshot(None, [ - ('add', ('', b'rootid', 'directory', ''))]) - rev1 = builder.build_snapshot([rev_root], [ - ('add', ('bla', self._mapping.generate_file_id('bla'), 'file', b'content'))]) + rev_root = builder.build_snapshot( + None, [("add", ("", b"rootid", "directory", ""))] + ) + rev1 = builder.build_snapshot( + [rev_root], + [ + ( + "add", + ("bla", self._mapping.generate_file_id("bla"), "file", b"content"), + ) + ], + ) rev2 = builder.build_snapshot([rev_root], []) builder.finish_series() branch = builder.get_branch() blob = Blob.from_string(b"bar") objs = {"blobname": blob} - import_git_blob(self._texts, self._mapping, b"bla", b"bla", - (None, "blobname"), - branch.repository.revision_tree( - rev1), b'rootid', b"somerevid", - [branch.repository.revision_tree(r) for r in [ - rev1, rev2]], - objs.__getitem__, - (None, DEFAULT_FILE_MODE), DummyStoreUpdater(), - self._mapping.generate_file_id) - self.assertEqual({(b'git:bla', b'somerevid')}, self._texts.keys()) + import_git_blob( + self._texts, + self._mapping, + b"bla", + b"bla", + (None, "blobname"), + branch.repository.revision_tree(rev1), + b"rootid", + b"somerevid", + [branch.repository.revision_tree(r) for r in [rev1, rev2]], + objs.__getitem__, + (None, DEFAULT_FILE_MODE), + DummyStoreUpdater(), + self._mapping.generate_file_id, + ) + self.assertEqual({(b"git:bla", b"somerevid")}, self._texts.keys()) def test_import_blob_simple(self): blob = Blob.from_string(b"bar") objs = {"blobname": blob} - ret = import_git_blob(self._texts, self._mapping, b"bla", b"bla", - (None, "blobname"), - None, b"parentid", b"somerevid", [], objs.__getitem__, - (None, DEFAULT_FILE_MODE), DummyStoreUpdater(), - self._mapping.generate_file_id) - self.assertEqual({(b'git:bla', b'somerevid')}, self._texts.keys()) - self.assertEqual(next(self._texts.get_record_stream([(b'git:bla', b'somerevid')], - "unordered", True)).get_bytes_as("fulltext"), b"bar") + ret = import_git_blob( + self._texts, + self._mapping, + b"bla", + b"bla", + (None, "blobname"), + None, + b"parentid", + b"somerevid", + [], + objs.__getitem__, + (None, DEFAULT_FILE_MODE), + DummyStoreUpdater(), + self._mapping.generate_file_id, + ) + self.assertEqual({(b"git:bla", b"somerevid")}, self._texts.keys()) + self.assertEqual( + next( + self._texts.get_record_stream( + [(b"git:bla", b"somerevid")], "unordered", True + ) + ).get_bytes_as("fulltext"), + b"bar", + ) self.assertEqual(1, len(ret)) self.assertEqual(None, ret[0][0]) self.assertEqual("bla", ret[0][1]) @@ -351,15 +376,23 @@ def test_import_blob_simple(self): def 
test_import_tree_empty_root(self): tree = Tree() - ret, child_modes = import_git_tree(self._texts, self._mapping, b"", b"", - (None, tree.id), None, - None, b"somerevid", [], { - tree.id: tree}.__getitem__, - (None, stat.S_IFDIR), DummyStoreUpdater(), - self._mapping.generate_file_id) + ret, child_modes = import_git_tree( + self._texts, + self._mapping, + b"", + b"", + (None, tree.id), + None, + None, + b"somerevid", + [], + {tree.id: tree}.__getitem__, + (None, stat.S_IFDIR), + DummyStoreUpdater(), + self._mapping.generate_file_id, + ) self.assertEqual(child_modes, {}) - self.assertEqual( - {(b"TREE_ROOT", b'somerevid')}, self._texts.keys()) + self.assertEqual({(b"TREE_ROOT", b"somerevid")}, self._texts.keys()) self.assertEqual(1, len(ret)) self.assertEqual(None, ret[0][0]) self.assertEqual("", ret[0][1]) @@ -371,13 +404,23 @@ def test_import_tree_empty_root(self): def test_import_tree_empty(self): tree = Tree() - ret, child_modes = import_git_tree(self._texts, self._mapping, b"bla", b"bla", - (None, tree.id), None, None, b"somerevid", [], - {tree.id: tree}.__getitem__, - (None, stat.S_IFDIR), DummyStoreUpdater(), - self._mapping.generate_file_id) + ret, child_modes = import_git_tree( + self._texts, + self._mapping, + b"bla", + b"bla", + (None, tree.id), + None, + None, + b"somerevid", + [], + {tree.id: tree}.__getitem__, + (None, stat.S_IFDIR), + DummyStoreUpdater(), + self._mapping.generate_file_id, + ) self.assertEqual(child_modes, {}) - self.assertEqual({(b"git:bla", b'somerevid')}, self._texts.keys()) + self.assertEqual({(b"git:bla", b"somerevid")}, self._texts.keys()) self.assertEqual(1, len(ret)) self.assertEqual(None, ret[0][0]) self.assertEqual("bla", ret[0][1]) @@ -391,11 +434,21 @@ def test_import_tree_with_file(self): tree = Tree() tree.add(b"foo", stat.S_IFREG | 0o644, blob.id) objects = {blob.id: blob, tree.id: tree} - ret, child_modes = import_git_tree(self._texts, self._mapping, b"bla", b"bla", - (None, tree.id), None, None, b"somerevid", [], - objects.__getitem__, (None, stat.S_IFDIR), DummyStoreUpdater( - ), - self._mapping.generate_file_id) + ret, child_modes = import_git_tree( + self._texts, + self._mapping, + b"bla", + b"bla", + (None, tree.id), + None, + None, + b"somerevid", + [], + objects.__getitem__, + (None, stat.S_IFDIR), + DummyStoreUpdater(), + self._mapping.generate_file_id, + ) self.assertEqual(child_modes, {}) self.assertEqual(2, len(ret)) self.assertEqual(None, ret[0][0]) @@ -416,12 +469,21 @@ def test_import_tree_with_unusual_mode_file(self): tree = Tree() tree.add(b"foo", stat.S_IFREG | 0o664, blob.id) objects = {blob.id: blob, tree.id: tree} - ret, child_modes = import_git_tree(self._texts, self._mapping, - b"bla", b"bla", (None, tree.id), None, None, b"somerevid", [ - ], - objects.__getitem__, (None, stat.S_IFDIR), DummyStoreUpdater( - ), - self._mapping.generate_file_id) + ret, child_modes = import_git_tree( + self._texts, + self._mapping, + b"bla", + b"bla", + (None, tree.id), + None, + None, + b"somerevid", + [], + objects.__getitem__, + (None, stat.S_IFDIR), + DummyStoreUpdater(), + self._mapping.generate_file_id, + ) self.assertEqual(child_modes, {b"bla/foo": stat.S_IFREG | 0o664}) def test_import_tree_with_file_exe(self): @@ -429,11 +491,21 @@ def test_import_tree_with_file_exe(self): tree = Tree() tree.add(b"foo", 0o100755, blob.id) objects = {blob.id: blob, tree.id: tree} - ret, child_modes = import_git_tree(self._texts, self._mapping, b"", b"", - (None, tree.id), None, None, b"somerevid", [], - objects.__getitem__, (None, stat.S_IFDIR), 
DummyStoreUpdater( - ), - self._mapping.generate_file_id) + ret, child_modes = import_git_tree( + self._texts, + self._mapping, + b"", + b"", + (None, tree.id), + None, + None, + b"somerevid", + [], + objects.__getitem__, + (None, stat.S_IFDIR), + DummyStoreUpdater(), + self._mapping.generate_file_id, + ) self.assertEqual(child_modes, {}) self.assertEqual(2, len(ret)) self.assertEqual(None, ret[0][0]) @@ -455,17 +527,26 @@ def test_directory_converted_to_submodule(self): tree = Tree() tree.add(b"bar", 0o160000, blob.id) objects = {tree.id: tree} - ret, child_modes = import_git_submodule(self._texts, self._mapping, b"foo", b"foo", - (tree.id, othertree.id), base_inv, base_inv.root.file_id, b"somerevid", [ - ], - objects.__getitem__, (stat.S_IFDIR | 0o755, S_IFGITLINK), DummyStoreUpdater( - ), - self._mapping.generate_file_id) + ret, child_modes = import_git_submodule( + self._texts, + self._mapping, + b"foo", + b"foo", + (tree.id, othertree.id), + base_inv, + base_inv.root.file_id, + b"somerevid", + [], + objects.__getitem__, + (stat.S_IFDIR | 0o755, S_IFGITLINK), + DummyStoreUpdater(), + self._mapping.generate_file_id, + ) self.assertEqual(child_modes, {}) self.assertEqual(2, len(ret)) + self.assertEqual(ret[0], ("foo/bar", None, base_inv.path2id("foo/bar"), None)) self.assertEqual( - ret[0], ("foo/bar", None, base_inv.path2id("foo/bar"), None)) - self.assertEqual(ret[1][:3], ("foo", "foo", - self._mapping.generate_file_id("foo"))) + ret[1][:3], ("foo", "foo", self._mapping.generate_file_id("foo")) + ) ie = ret[1][3] self.assertEqual(ie.kind, "tree-reference") diff --git a/breezy/git/tests/test_git_remote_helper.py b/breezy/git/tests/test_git_remote_helper.py index b4cc7cfe7d..1575216e49 100644 --- a/breezy/git/tests/test_git_remote_helper.py +++ b/breezy/git/tests/test_git_remote_helper.py @@ -39,31 +39,29 @@ def map_to_git_sha1(dir, bzr_revid): git_remote_bzr_path = os.path.abspath( - os.path.join(os.path.dirname(__file__), '..', 'git-remote-bzr')) + os.path.join(os.path.dirname(__file__), "..", "git-remote-bzr") +) git_remote_bzr_feature = PathFeature(git_remote_bzr_path) class OpenLocalDirTests(TestCaseWithTransport): - def test_from_env_dir(self): - self.make_branch_and_tree('bla', format='git') - self.overrideEnv('GIT_DIR', os.path.join(self.test_dir, 'bla', '.git')) + self.make_branch_and_tree("bla", format="git") + self.overrideEnv("GIT_DIR", os.path.join(self.test_dir, "bla", ".git")) open_local_dir() def test_from_dir(self): - self.make_branch_and_tree('.', format='git') + self.make_branch_and_tree(".", format="git") open_local_dir() class FetchTests(TestCaseWithTransport): - def setUp(self): super().setUp() - self.local_dir = self.make_branch_and_tree( - 'local', format='git').controldir - self.remote_tree = self.make_branch_and_tree('remote') + self.local_dir = self.make_branch_and_tree("local", format="git").controldir + self.remote_tree = self.make_branch_and_tree("remote") self.remote_dir = self.remote_tree.controldir - self.shortname = 'bzr' + self.shortname = "bzr" def fetch(self, wants): outf = BytesIO() @@ -75,52 +73,53 @@ def test_no_wants(self): self.assertEqual(b"\n", r) def test_simple(self): - self.build_tree(['remote/foo']) + self.build_tree(["remote/foo"]) self.remote_tree.add("foo") revid = self.remote_tree.commit("msg") git_sha1 = map_to_git_sha1(self.remote_dir, revid) - out = self.fetch([(git_sha1, 'HEAD')]) + out = self.fetch([(git_sha1, "HEAD")]) self.assertEqual(out, b"\n") - r = Repo('local') + r = Repo("local") self.assertIn(git_sha1, r.object_store) 
self.assertEqual({}, r.get_refs()) class ExecuteRemoteHelperTests(TestCaseWithTransport): - def test_run(self): self.requireFeature(git_remote_bzr_feature) - local_dir = self.make_branch_and_tree('local', format='git').controldir - local_path = local_dir.control_transport.local_abspath('.') - remote_tree = self.make_branch_and_tree('remote') + local_dir = self.make_branch_and_tree("local", format="git").controldir + local_path = local_dir.control_transport.local_abspath(".") + remote_tree = self.make_branch_and_tree("remote") remote_dir = remote_tree.controldir env = dict(os.environ) - env['GIT_DIR'] = local_path - env['PYTHONPATH'] = ':'.join(sys.path) + env["GIT_DIR"] = local_path + env["PYTHONPATH"] = ":".join(sys.path) p = subprocess.Popen( [sys.executable, git_remote_bzr_path, local_path, remote_dir.user_url], - stdin=subprocess.PIPE, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, env=env) - (out, err) = p.communicate(b'capabilities\n') + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + ) + (out, err) = p.communicate(b"capabilities\n") lines = out.splitlines() - self.assertIn(b'push', lines, f"no 'push' in {lines!r}, error: {err!r}") + self.assertIn(b"push", lines, f"no 'push' in {lines!r}, error: {err!r}") self.assertEqual( b"git-remote-bzr is experimental and has not been optimized " b"for performance. Use 'brz fast-export' and 'git fast-import' " - b"for large repositories.\n", err) + b"for large repositories.\n", + err, + ) class RemoteHelperTests(TestCaseWithTransport): - def setUp(self): super().setUp() - self.local_dir = self.make_branch_and_tree( - 'local', format='git').controldir - self.remote_tree = self.make_branch_and_tree('remote') + self.local_dir = self.make_branch_and_tree("local", format="git").controldir + self.remote_tree = self.make_branch_and_tree("remote") self.remote_dir = self.remote_tree.controldir - self.shortname = 'bzr' - self.helper = RemoteHelper( - self.local_dir, self.shortname, self.remote_dir) + self.shortname = "bzr" + self.helper = RemoteHelper(self.local_dir, self.shortname, self.remote_dir) def test_capabilities(self): f = BytesIO() @@ -137,26 +136,29 @@ def test_option(self): def test_list_basic(self): f = BytesIO() self.helper.cmd_list(f, []) - self.assertEqual( - b'\n', - f.getvalue()) + self.assertEqual(b"\n", f.getvalue()) def test_import(self): self.requireFeature(FastimportFeature) self.build_tree_contents([("remote/afile", b"somecontent")]) self.remote_tree.add(["afile"]) - self.remote_tree.commit(b"A commit message", timestamp=1330445983, - timezone=0, committer=b'Somebody ') + self.remote_tree.commit( + b"A commit message", + timestamp=1330445983, + timezone=0, + committer=b"Somebody ", + ) f = BytesIO() self.helper.cmd_import(f, ["import", "refs/heads/master"]) self.assertEqual( - b'reset refs/heads/master\n' - b'commit refs/heads/master\n' - b'mark :1\n' - b'committer Somebody 1330445983 +0000\n' - b'data 16\n' - b'A commit message\n' - b'M 644 inline afile\n' - b'data 11\n' - b'somecontent\n', - f.getvalue()) + b"reset refs/heads/master\n" + b"commit refs/heads/master\n" + b"mark :1\n" + b"committer Somebody 1330445983 +0000\n" + b"data 16\n" + b"A commit message\n" + b"M 644 inline afile\n" + b"data 11\n" + b"somecontent\n", + f.getvalue(), + ) diff --git a/breezy/git/tests/test_mapping.py b/breezy/git/tests/test_mapping.py index b557b4e545..168bd3624b 100644 --- a/breezy/git/tests/test_mapping.py +++ b/breezy/git/tests/test_mapping.py @@ -33,19 +33,21 @@ class 
TestRevidConversionV1(tests.TestCase): - def test_simple_git_to_bzr_revision_id(self): - self.assertEqual(b"git-v1:" - b"c6a4d8f1fa4ac650748e647c4b1b368f589a7356", - BzrGitMappingv1().revision_id_foreign_to_bzr( - b"c6a4d8f1fa4ac650748e647c4b1b368f589a7356")) + self.assertEqual( + b"git-v1:" b"c6a4d8f1fa4ac650748e647c4b1b368f589a7356", + BzrGitMappingv1().revision_id_foreign_to_bzr( + b"c6a4d8f1fa4ac650748e647c4b1b368f589a7356" + ), + ) def test_simple_bzr_to_git_revision_id(self): - self.assertEqual((b"c6a4d8f1fa4ac650748e647c4b1b368f589a7356", - BzrGitMappingv1()), - BzrGitMappingv1().revision_id_bzr_to_foreign( - b"git-v1:" - b"c6a4d8f1fa4ac650748e647c4b1b368f589a7356")) + self.assertEqual( + (b"c6a4d8f1fa4ac650748e647c4b1b368f589a7356", BzrGitMappingv1()), + BzrGitMappingv1().revision_id_bzr_to_foreign( + b"git-v1:" b"c6a4d8f1fa4ac650748e647c4b1b368f589a7356" + ), + ) def test_is_control_file(self): mapping = BzrGitMappingv1() @@ -61,7 +63,6 @@ def test_generate_file_id(self): class FileidTests(tests.TestCase): - def test_escape_space(self): self.assertEqual(b"bla_s", escape_file_id(b"bla ")) @@ -85,7 +86,6 @@ def test_unescape_underscore_space(self): class TestImportCommit(tests.TestCase): - def test_commit(self): c = Commit() c.tree = b"cc9462f7f8263ef5adfbeff2fb936bb36b504cba" @@ -98,16 +98,17 @@ def test_commit(self): c.author = b"Author" mapping = BzrGitMappingv1() rev, roundtrip_revid, verifiers = mapping.import_commit( - c, mapping.revision_id_foreign_to_bzr) + c, mapping.revision_id_foreign_to_bzr + ) self.assertEqual(None, roundtrip_revid) self.assertEqual({}, verifiers) self.assertEqual("Some message", rev.message) self.assertEqual("Committer", rev.committer) - self.assertEqual("Author", rev.properties['author']) + self.assertEqual("Author", rev.properties["author"]) self.assertEqual(300, rev.timezone) self.assertEqual([], rev.parent_ids) - self.assertEqual("5", rev.properties['author-timestamp']) - self.assertEqual("180", rev.properties['author-timezone']) + self.assertEqual("5", rev.properties["author-timestamp"]) + self.assertEqual("180", rev.properties["author-timezone"]) self.assertEqual(b"git-v1:" + c.id, rev.revision_id) def test_missing_message(self): @@ -123,11 +124,12 @@ def test_missing_message(self): try: c.id # noqa: B018 except TypeError: # old version of dulwich - self.skipTest('dulwich too old') + self.skipTest("dulwich too old") mapping = BzrGitMappingv1() rev, roundtrip_revid, verifiers = mapping.import_commit( - c, mapping.revision_id_foreign_to_bzr) - self.assertEqual(rev.message, '') + c, mapping.revision_id_foreign_to_bzr + ) + self.assertEqual(rev.message, "") def test_unknown_encoding(self): c = Commit() @@ -142,8 +144,11 @@ def test_unknown_encoding(self): c.encoding = b"Unknown" mapping = BzrGitMappingv1() e = self.assertRaises( - UnknownCommitEncoding, mapping.import_commit, - c, mapping.revision_id_foreign_to_bzr) + UnknownCommitEncoding, + mapping.import_commit, + c, + mapping.revision_id_foreign_to_bzr, + ) self.assertEqual(e.encoding, "Unknown") def test_explicit_encoding(self): @@ -159,10 +164,11 @@ def test_explicit_encoding(self): c.encoding = b"iso8859-1" mapping = BzrGitMappingv1() rev, roundtrip_revid, verifiers = mapping.import_commit( - c, mapping.revision_id_foreign_to_bzr) + c, mapping.revision_id_foreign_to_bzr + ) self.assertEqual(None, roundtrip_revid) self.assertEqual({}, verifiers) - self.assertEqual("Authér", rev.properties['author']) + self.assertEqual("Authér", rev.properties["author"]) self.assertEqual("iso8859-1", 
rev.properties["git-explicit-encoding"]) self.assertNotIn("git-implicit-encoding", rev.properties) @@ -179,10 +185,11 @@ def test_explicit_encoding_false(self): c.encoding = b"false" mapping = BzrGitMappingv1() rev, roundtrip_revid, verifiers = mapping.import_commit( - c, mapping.revision_id_foreign_to_bzr) + c, mapping.revision_id_foreign_to_bzr + ) self.assertEqual(None, roundtrip_revid) self.assertEqual({}, verifiers) - self.assertEqual("Authér", rev.properties['author']) + self.assertEqual("Authér", rev.properties["author"]) self.assertEqual("false", rev.properties["git-explicit-encoding"]) self.assertNotIn("git-implicit-encoding", rev.properties) @@ -198,10 +205,11 @@ def test_implicit_encoding_fallback(self): c.author = "Authér".encode("latin1") mapping = BzrGitMappingv1() rev, roundtrip_revid, verifiers = mapping.import_commit( - c, mapping.revision_id_foreign_to_bzr) + c, mapping.revision_id_foreign_to_bzr + ) self.assertEqual(None, roundtrip_revid) self.assertEqual({}, verifiers) - self.assertEqual("Authér", rev.properties['author']) + self.assertEqual("Authér", rev.properties["author"]) self.assertEqual("latin1", rev.properties["git-implicit-encoding"]) self.assertNotIn("git-explicit-encoding", rev.properties) @@ -217,10 +225,11 @@ def test_implicit_encoding_utf8(self): c.author = "Authér".encode() mapping = BzrGitMappingv1() rev, roundtrip_revid, verifiers = mapping.import_commit( - c, mapping.revision_id_foreign_to_bzr) + c, mapping.revision_id_foreign_to_bzr + ) self.assertEqual(None, roundtrip_revid) self.assertEqual({}, verifiers) - self.assertEqual("Authér", rev.properties['author']) + self.assertEqual("Authér", rev.properties["author"]) self.assertNotIn("git-explicit-encoding", rev.properties) self.assertNotIn("git-implicit-encoding", rev.properties) @@ -236,8 +245,12 @@ def test_unknown_extra(self): c.author = b"Author" c._extra.append((b"iamextra", b"foo")) mapping = BzrGitMappingv1() - self.assertRaises(UnknownCommitExtra, mapping.import_commit, c, - mapping.revision_id_foreign_to_bzr) + self.assertRaises( + UnknownCommitExtra, + mapping.import_commit, + c, + mapping.revision_id_foreign_to_bzr, + ) mapping.import_commit(c, mapping.revision_id_foreign_to_bzr, strict=False) def test_mergetag(self): @@ -250,18 +263,23 @@ def test_mergetag(self): c.commit_timezone = 60 * 5 c.author_timezone = 60 * 3 c.author = b"Author" - tag = make_object(Tag, - tagger=b'Jelmer Vernooij ', - name=b'0.1', message=None, - object=( - Blob, b'd80c186a03f423a81b39df39dc87fd269736ca86'), - tag_time=423423423, tag_timezone=0) + tag = make_object( + Tag, + tagger=b"Jelmer Vernooij ", + name=b"0.1", + message=None, + object=(Blob, b"d80c186a03f423a81b39df39dc87fd269736ca86"), + tag_time=423423423, + tag_timezone=0, + ) c.mergetag = [tag] mapping = BzrGitMappingv1() rev, roundtrip_revid, verifiers = mapping.import_commit( - c, mapping.revision_id_foreign_to_bzr) + c, mapping.revision_id_foreign_to_bzr + ) self.assertEqual( - rev.properties['git-mergetag-0'].encode('utf-8'), tag.as_raw_string()) + rev.properties["git-mergetag-0"].encode("utf-8"), tag.as_raw_string() + ) def test_unknown_hg_fields(self): c = Commit() @@ -277,12 +295,14 @@ def test_unknown_hg_fields(self): mapping = BzrGitMappingv1() self.assertRaises( UnknownMercurialCommitExtra, - mapping.import_commit, c, mapping.revision_id_foreign_to_bzr) - mapping.import_commit( - c, mapping.revision_id_foreign_to_bzr, strict=False) + mapping.import_commit, + c, + mapping.revision_id_foreign_to_bzr, + ) + mapping.import_commit(c, 
mapping.revision_id_foreign_to_bzr, strict=False) self.assertEqual( - mapping.revision_id_foreign_to_bzr(c.id), - mapping.get_revision_id(c)) + mapping.revision_id_foreign_to_bzr(c.id), mapping.get_revision_id(c) + ) def test_invalid_utf8(self): c = Commit() @@ -296,12 +316,11 @@ def test_invalid_utf8(self): c.author = b"Author" mapping = BzrGitMappingv1() self.assertEqual( - mapping.revision_id_foreign_to_bzr(c.id), - mapping.get_revision_id(c)) + mapping.revision_id_foreign_to_bzr(c.id), mapping.get_revision_id(c) + ) class RoundtripRevisionsFromBazaar(tests.TestCase): - def setUp(self): super().setUp() self.mapping = BzrGitMappingv1() @@ -309,12 +328,15 @@ def setUp(self): self._lookup_parent = self._parent_map.__getitem__ def assertRoundtripRevision(self, orig_rev): - commit = self.mapping.export_commit(orig_rev, b"mysha", - self._lookup_parent, True, b"testamentsha") + commit = self.mapping.export_commit( + orig_rev, b"mysha", self._lookup_parent, True, b"testamentsha" + ) rev, roundtrip_revid, verifiers = self.mapping.import_commit( - commit, self.mapping.revision_id_foreign_to_bzr, strict=True) - self.assertEqual(rev.revision_id, - self.mapping.revision_id_foreign_to_bzr(commit.id)) + commit, self.mapping.revision_id_foreign_to_bzr, strict=True + ) + self.assertEqual( + rev.revision_id, self.mapping.revision_id_foreign_to_bzr(commit.id) + ) if self.mapping.roundtripping: self.assertEqual({"testament3-sha1": b"testamentsha"}, verifiers) self.assertEqual(orig_rev.revision_id, roundtrip_revid) @@ -329,78 +351,95 @@ def assertRoundtripRevision(self, orig_rev): return commit def test_simple_commit(self): - r = Revision(self.mapping.revision_id_foreign_to_bzr( - b"edf99e6c56495c620f20d5dacff9859ff7119261"), + r = Revision( + self.mapping.revision_id_foreign_to_bzr( + b"edf99e6c56495c620f20d5dacff9859ff7119261" + ), message="MyCommitMessage", parent_ids=[], committer="Jelmer Vernooij ", timestamp=453543543, timezone=0, - properties={}, inventory_sha1=None) + properties={}, + inventory_sha1=None, + ) self.assertRoundtripRevision(r) def test_revision_id(self): - r = Revision(b"myrevid", + r = Revision( + b"myrevid", message="MyCommitMessage", parent_ids=[], committer="Jelmer Vernooij ", timestamp=453543543, timezone=0, - properties={}, inventory_sha1=None) + properties={}, + inventory_sha1=None, + ) self.assertRoundtripRevision(r) def test_ghost_parent(self): - r = Revision(b"myrevid", + r = Revision( + b"myrevid", message="MyCommitMessage", parent_ids=[b"iamaghost"], committer="Jelmer Vernooij ", timestamp=453543543, timezone=0, - properties={}, inventory_sha1=None) + properties={}, + inventory_sha1=None, + ) self.assertRoundtripRevision(r) def test_custom_property(self): - r = Revision(b"myrevid", + r = Revision( + b"myrevid", message="MyCommitMessage", parent_ids=[], properties={"fool": "bar"}, committer="Jelmer Vernooij ", timestamp=453543543, timezone=0, - inventory_sha1=None) + inventory_sha1=None, + ) self.assertRoundtripRevision(r) def test_multiple_authors(self): - r = Revision(b"myrevid", + r = Revision( + b"myrevid", message="MyCommitMessage", parent_ids=[], properties={ - "authors": - "Jelmer Vernooij \n" - "Alex "}, + "authors": "Jelmer Vernooij \n" + "Alex " + }, committer="Jelmer Vernooij ", timestamp=453543543, - timezone=0, inventory_sha1=None) + timezone=0, + inventory_sha1=None, + ) c = self.assertRoundtripRevision(r) - self.assertEqual(c.author, b'Jelmer Vernooij ') + self.assertEqual(c.author, b"Jelmer Vernooij ") def test_multiple_authors_comma(self): - r = 
Revision(b"myrevid", + r = Revision( + b"myrevid", message="MyCommitMessage", parent_ids=[], properties={ - "authors": - "Jelmer Vernooij , " - "Alex "}, + "authors": "Jelmer Vernooij , " + "Alex " + }, committer="Jelmer Vernooij ", timestamp=453543543, - timezone=0, inventory_sha1=None) + timezone=0, + inventory_sha1=None, + ) c = self.assertRoundtripRevision(r) - self.assertEqual(c.author, b'Jelmer Vernooij ') + self.assertEqual(c.author, b"Jelmer Vernooij ") class RoundtripRevisionsFromGit(tests.TestCase): - def setUp(self): super().setUp() self.mapping = BzrGitMappingv1() @@ -413,9 +452,9 @@ def assertRoundtripBlob(self, blob): def assertRoundtripCommit(self, commit1): rev, roundtrip_revid, verifiers = self.mapping.import_commit( - commit1, self.mapping.revision_id_foreign_to_bzr, strict=True) - commit2 = self.mapping.export_commit(rev, "12341212121212", None, - True, None) + commit1, self.mapping.revision_id_foreign_to_bzr, strict=True + ) + commit2 = self.mapping.export_commit(rev, "12341212121212", None, True, None) self.assertEqual(commit1.committer, commit2.committer) self.assertEqual(commit1.commit_time, commit2.commit_time) self.assertEqual(commit1.commit_timezone, commit2.commit_timezone) @@ -467,7 +506,7 @@ def test_commit_encoding(self): c.tree = b"cc9462f7f8263ef5adfbeff2fb936bb36b504cba" c.message = b"Some message" c.committer = b"Committer " - c.encoding = b'iso8859-1' + c.encoding = b"iso8859-1" c.commit_time = 4 c.commit_timezone = -60 * 3 c.author_time = 5 @@ -498,33 +537,38 @@ def test_commit_mergetag(self): c.author_time = 5 c.author_timezone = 60 * 2 c.author = b"Author " - tag = make_object(Tag, - tagger=b'Jelmer Vernooij ', - name=b'0.1', message=None, - object=( - Blob, b'd80c186a03f423a81b39df39dc87fd269736ca86'), - tag_time=423423423, tag_timezone=0) + tag = make_object( + Tag, + tagger=b"Jelmer Vernooij ", + name=b"0.1", + message=None, + object=(Blob, b"d80c186a03f423a81b39df39dc87fd269736ca86"), + tag_time=423423423, + tag_timezone=0, + ) c.mergetag = [tag] self.assertRoundtripCommit(c) class FixPersonIdentifierTests(tests.TestCase): - def test_valid(self): - self.assertEqual(b"foo ", - fix_person_identifier(b"foo ")) - self.assertEqual(b"bar@blah.nl ", - fix_person_identifier(b"bar@blah.nl")) + self.assertEqual( + b"foo ", fix_person_identifier(b"foo ") + ) + self.assertEqual( + b"bar@blah.nl ", fix_person_identifier(b"bar@blah.nl") + ) def test_fix(self): self.assertEqual( b"person ", - fix_person_identifier(b"somebody >")) + fix_person_identifier(b"somebody >"), + ) self.assertEqual( - b"person ", - fix_person_identifier(b"person")) + b"person ", fix_person_identifier(b"person") + ) self.assertEqual( - b'Rohan Garg ', - fix_person_identifier(b'Rohan Garg bar@blah.nl<") + b"Rohan Garg ", + fix_person_identifier(b"Rohan Garg bar@blah.nl<") diff --git a/breezy/git/tests/test_memorytree.py b/breezy/git/tests/test_memorytree.py index b9e3b731b8..ad3afc6d5f 100644 --- a/breezy/git/tests/test_memorytree.py +++ b/breezy/git/tests/test_memorytree.py @@ -23,16 +23,15 @@ class TestMemoryTree(TestCaseWithTransport): - - def make_branch(self, path, format='git'): + def make_branch(self, path, format="git"): return super().make_branch(path, format=format) - def make_branch_and_tree(self, path, format='git'): + def make_branch_and_tree(self, path, format="git"): return super().make_branch_and_tree(path, format=format) def test_create_on_branch(self): """Creating a mutable tree on a trivial branch works.""" - branch = self.make_branch('branch') + branch = 
self.make_branch("branch") tree = branch.create_memorytree() self.assertEqual(branch.controldir, tree.controldir) self.assertEqual(branch, tree.branch) @@ -40,26 +39,25 @@ def test_create_on_branch(self): def test_create_on_branch_with_content(self): """Creating a mutable tree on a non-trivial branch works.""" - wt = self.make_branch_and_tree('sometree') - self.build_tree(['sometree/foo']) - wt.add(['foo']) - rev_id = wt.commit('first post') + wt = self.make_branch_and_tree("sometree") + self.build_tree(["sometree/foo"]) + wt.add(["foo"]) + rev_id = wt.commit("first post") tree = wt.branch.create_memorytree() with tree.lock_read(): self.assertEqual([rev_id], tree.get_parent_ids()) - self.assertEqual(b'contents of sometree/foo\n', - tree.get_file('foo').read()) + self.assertEqual(b"contents of sometree/foo\n", tree.get_file("foo").read()) def test_lock_tree_write(self): """Check we can lock_tree_write and unlock MemoryTrees.""" - branch = self.make_branch('branch') + branch = self.make_branch("branch") tree = branch.create_memorytree() tree.lock_tree_write() tree.unlock() def test_lock_tree_write_after_read_fails(self): """Check that we error when trying to upgrade a read lock to write.""" - branch = self.make_branch('branch') + branch = self.make_branch("branch") tree = branch.create_memorytree() tree.lock_read() self.assertRaises(errors.ReadOnlyError, tree.lock_tree_write) @@ -67,57 +65,57 @@ def test_lock_tree_write_after_read_fails(self): def test_lock_write(self): """Check we can lock_write and unlock MemoryTrees.""" - branch = self.make_branch('branch') + branch = self.make_branch("branch") tree = branch.create_memorytree() tree.lock_write() tree.unlock() def test_lock_write_after_read_fails(self): """Check that we error when trying to upgrade a read lock to write.""" - branch = self.make_branch('branch') + branch = self.make_branch("branch") tree = branch.create_memorytree() tree.lock_read() self.assertRaises(errors.ReadOnlyError, tree.lock_write) tree.unlock() def test_add_with_kind(self): - branch = self.make_branch('branch') + branch = self.make_branch("branch") tree = branch.create_memorytree() tree.lock_write() - tree.add(['', 'afile', 'adir'], ['directory', 'file', 'directory']) - self.assertTrue(tree.is_versioned('afile')) - self.assertFalse(tree.is_versioned('adir')) - self.assertFalse(tree.has_filename('afile')) - self.assertFalse(tree.has_filename('adir')) + tree.add(["", "afile", "adir"], ["directory", "file", "directory"]) + self.assertTrue(tree.is_versioned("afile")) + self.assertFalse(tree.is_versioned("adir")) + self.assertFalse(tree.has_filename("afile")) + self.assertFalse(tree.has_filename("adir")) tree.unlock() def test_put_new_file(self): - branch = self.make_branch('branch') + branch = self.make_branch("branch") tree = branch.create_memorytree() with tree.lock_write(): - tree.add(['', 'foo'], kinds=['directory', 'file']) - tree.put_file_bytes_non_atomic('foo', b'barshoom') - self.assertEqual(b'barshoom', tree.get_file('foo').read()) + tree.add(["", "foo"], kinds=["directory", "file"]) + tree.put_file_bytes_non_atomic("foo", b"barshoom") + self.assertEqual(b"barshoom", tree.get_file("foo").read()) def test_put_existing_file(self): - branch = self.make_branch('branch') + branch = self.make_branch("branch") tree = branch.create_memorytree() with tree.lock_write(): - tree.add(['', 'foo'], kinds=['directory', 'file']) - tree.put_file_bytes_non_atomic('foo', b'first-content') - tree.put_file_bytes_non_atomic('foo', b'barshoom') - self.assertEqual(b'barshoom', 
tree.get_file('foo').read()) + tree.add(["", "foo"], kinds=["directory", "file"]) + tree.put_file_bytes_non_atomic("foo", b"first-content") + tree.put_file_bytes_non_atomic("foo", b"barshoom") + self.assertEqual(b"barshoom", tree.get_file("foo").read()) def test_add_in_subdir(self): - branch = self.make_branch('branch') + branch = self.make_branch("branch") tree = branch.create_memorytree() with tree.lock_write(): - tree.add([''], ['directory']) - tree.mkdir('adir') - tree.put_file_bytes_non_atomic('adir/afile', b'barshoom') - tree.add(['adir/afile'], ['file']) - self.assertTrue(tree.is_versioned('adir/afile')) - self.assertTrue(tree.is_versioned('adir')) + tree.add([""], ["directory"]) + tree.mkdir("adir") + tree.put_file_bytes_non_atomic("adir/afile", b"barshoom") + tree.add(["adir/afile"], ["file"]) + self.assertTrue(tree.is_versioned("adir/afile")) + self.assertTrue(tree.is_versioned("adir")) def test_commit_trivial(self): """Smoke test for commit on a MemoryTree. @@ -125,72 +123,69 @@ def test_commit_trivial(self): Becamse of commits design and layering, if this works, all commit logic should work quite reliably. """ - branch = self.make_branch('branch') + branch = self.make_branch("branch") tree = branch.create_memorytree() with tree.lock_write(): - tree.add(['', 'foo'], kinds=['directory', 'file']) - tree.put_file_bytes_non_atomic('foo', b'barshoom') - revision_id = tree.commit('message baby') + tree.add(["", "foo"], kinds=["directory", "file"]) + tree.put_file_bytes_non_atomic("foo", b"barshoom") + revision_id = tree.commit("message baby") # the parents list for the tree should have changed. self.assertEqual([revision_id], tree.get_parent_ids()) # and we should have a revision that is accessible outside the tree lock revtree = tree.branch.repository.revision_tree(revision_id) with revtree.lock_read(): - self.assertEqual(b'barshoom', revtree.get_file('foo').read()) + self.assertEqual(b"barshoom", revtree.get_file("foo").read()) def test_unversion(self): """Some test for unversion of a memory tree.""" - branch = self.make_branch('branch') + branch = self.make_branch("branch") tree = branch.create_memorytree() with tree.lock_write(): - tree.add(['', 'foo'], kinds=['directory', 'file']) - tree.unversion(['foo']) - self.assertFalse(tree.is_versioned('foo')) - self.assertFalse(tree.has_filename('foo')) + tree.add(["", "foo"], kinds=["directory", "file"]) + tree.unversion(["foo"]) + self.assertFalse(tree.is_versioned("foo")) + self.assertFalse(tree.has_filename("foo")) def test_last_revision(self): """There should be a last revision method we can call.""" - tree = self.make_branch_and_memory_tree('branch') + tree = self.make_branch_and_memory_tree("branch") with tree.lock_write(): - tree.add('') - rev_id = tree.commit('first post') + tree.add("") + rev_id = tree.commit("first post") self.assertEqual(rev_id, tree.last_revision()) def test_rename_file(self): - tree = self.make_branch_and_memory_tree('branch') + tree = self.make_branch_and_memory_tree("branch") tree.lock_write() self.addCleanup(tree.unlock) - tree.add(['', 'foo'], ['directory', 'file'], ids=[b'root-id', b'foo-id']) - tree.put_file_bytes_non_atomic('foo', b'content\n') - tree.commit('one', rev_id=b'rev-one') - tree.rename_one('foo', 'bar') - self.assertEqual('bar', tree.id2path(b'foo-id')) - self.assertEqual(b'content\n', tree._file_transport.get_bytes('bar')) - self.assertRaises(NoSuchFile, - tree._file_transport.get_bytes, 'foo') - tree.commit('two', rev_id=b'rev-two') - self.assertEqual(b'content\n', 
tree._file_transport.get_bytes('bar')) - self.assertRaises(NoSuchFile, - tree._file_transport.get_bytes, 'foo') - - rev_tree2 = tree.branch.repository.revision_tree(b'rev-two') - self.assertEqual('bar', rev_tree2.id2path(b'foo-id')) - self.assertEqual(b'content\n', rev_tree2.get_file_text('bar')) + tree.add(["", "foo"], ["directory", "file"], ids=[b"root-id", b"foo-id"]) + tree.put_file_bytes_non_atomic("foo", b"content\n") + tree.commit("one", rev_id=b"rev-one") + tree.rename_one("foo", "bar") + self.assertEqual("bar", tree.id2path(b"foo-id")) + self.assertEqual(b"content\n", tree._file_transport.get_bytes("bar")) + self.assertRaises(NoSuchFile, tree._file_transport.get_bytes, "foo") + tree.commit("two", rev_id=b"rev-two") + self.assertEqual(b"content\n", tree._file_transport.get_bytes("bar")) + self.assertRaises(NoSuchFile, tree._file_transport.get_bytes, "foo") + + rev_tree2 = tree.branch.repository.revision_tree(b"rev-two") + self.assertEqual("bar", rev_tree2.id2path(b"foo-id")) + self.assertEqual(b"content\n", rev_tree2.get_file_text("bar")) def test_rename_file_to_subdir(self): - tree = self.make_branch_and_memory_tree('branch') + tree = self.make_branch_and_memory_tree("branch") tree.lock_write() self.addCleanup(tree.unlock) - tree.add('') - tree.mkdir('subdir', b'subdir-id') - tree.add('foo', 'file', b'foo-id') - tree.put_file_bytes_non_atomic('foo', b'content\n') - tree.commit('one', rev_id=b'rev-one') - - tree.rename_one('foo', 'subdir/bar') - self.assertEqual('subdir/bar', tree.id2path(b'foo-id')) - self.assertEqual(b'content\n', - tree._file_transport.get_bytes('subdir/bar')) - tree.commit('two', rev_id=b'rev-two') - rev_tree2 = tree.branch.repository.revision_tree(b'rev-two') - self.assertEqual('subdir/bar', rev_tree2.id2path(b'foo-id')) + tree.add("") + tree.mkdir("subdir", b"subdir-id") + tree.add("foo", "file", b"foo-id") + tree.put_file_bytes_non_atomic("foo", b"content\n") + tree.commit("one", rev_id=b"rev-one") + + tree.rename_one("foo", "subdir/bar") + self.assertEqual("subdir/bar", tree.id2path(b"foo-id")) + self.assertEqual(b"content\n", tree._file_transport.get_bytes("subdir/bar")) + tree.commit("two", rev_id=b"rev-two") + rev_tree2 = tree.branch.repository.revision_tree(b"rev-two") + self.assertEqual("subdir/bar", rev_tree2.id2path(b"foo-id")) diff --git a/breezy/git/tests/test_object_store.py b/breezy/git/tests/test_object_store.py index d0bc3811ef..b4150dd7c7 100644 --- a/breezy/git/tests/test_object_store.py +++ b/breezy/git/tests/test_object_store.py @@ -40,7 +40,6 @@ class ExpectedShaTests(TestCase): - def setUp(self): super().setUp() self.obj = Blob() @@ -50,47 +49,42 @@ def test_none(self): _check_expected_sha(None, self.obj) def test_hex(self): - _check_expected_sha( - self.obj.sha().hexdigest().encode('ascii'), self.obj) - self.assertRaises(AssertionError, _check_expected_sha, - b"0" * 40, self.obj) + _check_expected_sha(self.obj.sha().hexdigest().encode("ascii"), self.obj) + self.assertRaises(AssertionError, _check_expected_sha, b"0" * 40, self.obj) def test_binary(self): _check_expected_sha(self.obj.sha().digest(), self.obj) - self.assertRaises(AssertionError, _check_expected_sha, - b"x" * 20, self.obj) + self.assertRaises(AssertionError, _check_expected_sha, b"x" * 20, self.obj) class FindMissingBzrRevidsTests(TestCase): - def _find_missing(self, ancestry, want, have): return _find_missing_bzr_revids( - Graph(DictParentsProvider(ancestry)), - set(want), set(have)) + Graph(DictParentsProvider(ancestry)), set(want), set(have) + ) def test_simple(self): 
self.assertEqual(set(), self._find_missing({}, [], [])) def test_up_to_date(self): - self.assertEqual(set(), - self._find_missing({"a": ["b"]}, ["a"], ["a"])) + self.assertEqual(set(), self._find_missing({"a": ["b"]}, ["a"], ["a"])) def test_one_missing(self): - self.assertEqual({"a"}, - self._find_missing({"a": ["b"]}, ["a"], ["b"])) + self.assertEqual({"a"}, self._find_missing({"a": ["b"]}, ["a"], ["b"])) def test_two_missing(self): - self.assertEqual({"a", "b"}, - self._find_missing({"a": ["b"], "b": ["c"]}, ["a"], ["c"])) + self.assertEqual( + {"a", "b"}, self._find_missing({"a": ["b"], "b": ["c"]}, ["a"], ["c"]) + ) def test_two_missing_history(self): - self.assertEqual({"a", "b"}, - self._find_missing({"a": ["b"], "b": ["c"], "c": ["d"]}, - ["a"], ["c"])) + self.assertEqual( + {"a", "b"}, + self._find_missing({"a": ["b"], "b": ["c"], "c": ["d"]}, ["a"], ["c"]), + ) class LRUTreeCacheTests(TestCaseWithTransport): - def setUp(self): super().setUp() self.branch = self.make_branch(".") @@ -99,32 +93,32 @@ def setUp(self): self.cache = LRUTreeCache(self.branch.repository) def test_get_not_present(self): - self.assertRaises(NoSuchRevision, self.cache.revision_tree, - "unknown") + self.assertRaises(NoSuchRevision, self.cache.revision_tree, "unknown") def test_revision_trees(self): - self.assertRaises(NoSuchRevision, self.cache.revision_trees, - ["unknown", "la"]) + self.assertRaises(NoSuchRevision, self.cache.revision_trees, ["unknown", "la"]) def test_iter_revision_trees(self): - self.assertRaises(NoSuchRevision, self.cache.iter_revision_trees, - ["unknown", "la"]) + self.assertRaises( + NoSuchRevision, self.cache.iter_revision_trees, ["unknown", "la"] + ) def test_get(self): bb = BranchBuilder(branch=self.branch) bb.start_series() - revid = bb.build_snapshot(None, - [('add', ('', None, 'directory', None)), - ('add', ('foo', b'foo-id', - 'file', b'a\nb\nc\nd\ne\n')), - ]) + revid = bb.build_snapshot( + None, + [ + ("add", ("", None, "directory", None)), + ("add", ("foo", b"foo-id", "file", b"a\nb\nc\nd\ne\n")), + ], + ) bb.finish_series() tree = self.cache.revision_tree(revid) self.assertEqual(revid, tree.get_revision_id()) class BazaarObjectStoreTests(TestCaseWithTransport): - def setUp(self): super().setUp() self.branch = self.make_branch(".") @@ -134,16 +128,19 @@ def test_get_blob(self): self.branch.lock_write() self.addCleanup(self.branch.unlock) b = Blob() - b.data = b'a\nb\nc\nd\ne\n' + b.data = b"a\nb\nc\nd\ne\n" self.store.lock_read() self.addCleanup(self.store.unlock) self.assertRaises(KeyError, self.store.__getitem__, b.id) bb = BranchBuilder(branch=self.branch) bb.start_series() - bb.build_snapshot(None, - [('add', ('', None, 'directory', None)), - ('add', ('foo', b'foo-id', 'file', b'a\nb\nc\nd\ne\n')), - ]) + bb.build_snapshot( + None, + [ + ("add", ("", None, "directory", None)), + ("add", ("foo", b"foo-id", "file", b"a\nb\nc\nd\ne\n")), + ], + ) bb.finish_series() # read locks cache self.assertRaises(KeyError, self.store.__getitem__, b.id) @@ -154,19 +151,17 @@ def test_get_blob(self): def test_directory_converted_to_symlink(self): self.requireFeature(SymlinkFeature(self.test_dir)) b = Blob() - b.data = b'trgt' + b.data = b"trgt" self.store.lock_read() self.addCleanup(self.store.unlock) self.assertRaises(KeyError, self.store.__getitem__, b.id) tree = self.branch.controldir.create_workingtree() - self.build_tree_contents([ - ('foo/', ), - ('foo/bar', b'a\nb\nc\nd\ne\n')]) - tree.add(['foo', 'foo/bar']) - tree.commit('commit 1') - shutil.rmtree('foo') - os.symlink('trgt', 
'foo') - tree.commit('commit 2') + self.build_tree_contents([("foo/",), ("foo/bar", b"a\nb\nc\nd\ne\n")]) + tree.add(["foo", "foo/bar"]) + tree.commit("commit 1") + shutil.rmtree("foo") + os.symlink("trgt", "foo") + tree.commit("commit 2") # read locks cache self.assertRaises(KeyError, self.store.__getitem__, b.id) self.store.unlock() @@ -177,16 +172,19 @@ def test_get_raw(self): self.branch.lock_write() self.addCleanup(self.branch.unlock) b = Blob() - b.data = b'a\nb\nc\nd\ne\n' + b.data = b"a\nb\nc\nd\ne\n" self.store.lock_read() self.addCleanup(self.store.unlock) self.assertRaises(KeyError, self.store.get_raw, b.id) bb = BranchBuilder(branch=self.branch) bb.start_series() - bb.build_snapshot(None, - [('add', ('', None, 'directory', None)), - ('add', ('foo', b'foo-id', 'file', b'a\nb\nc\nd\ne\n')), - ]) + bb.build_snapshot( + None, + [ + ("add", ("", None, "directory", None)), + ("add", ("foo", b"foo-id", "file", b"a\nb\nc\nd\ne\n")), + ], + ) bb.finish_series() # read locks cache self.assertRaises(KeyError, self.store.get_raw, b.id) @@ -198,16 +196,19 @@ def test_contains(self): self.branch.lock_write() self.addCleanup(self.branch.unlock) b = Blob() - b.data = b'a\nb\nc\nd\ne\n' + b.data = b"a\nb\nc\nd\ne\n" self.store.lock_read() self.addCleanup(self.store.unlock) self.assertNotIn(b.id, self.store) bb = BranchBuilder(branch=self.branch) bb.start_series() - bb.build_snapshot(None, - [('add', ('', None, 'directory', None)), - ('add', ('foo', b'foo-id', 'file', b'a\nb\nc\nd\ne\n')), - ]) + bb.build_snapshot( + None, + [ + ("add", ("", None, "directory", None)), + ("add", ("foo", b"foo-id", "file", b"a\nb\nc\nd\ne\n")), + ], + ) bb.finish_series() # read locks cache self.assertNotIn(b.id, self.store) @@ -217,58 +218,56 @@ def test_contains(self): class TreeToObjectsTests(TestCaseWithTransport): - def setUp(self): super().setUp() self.idmap = DictGitShaMap() def test_no_changes(self): - tree = self.make_branch_and_tree('.') + tree = self.make_branch_and_tree(".") self.addCleanup(tree.lock_read().unlock) entries = list(_tree_to_objects(tree, [tree], self.idmap, {})) self.assertEqual([], entries) def test_with_gitdir(self): - tree = self.make_branch_and_tree('.') - self.build_tree(['.git', 'foo']) - tree.add(['.git', 'foo']) - revid = tree.commit('commit') + tree = self.make_branch_and_tree(".") + self.build_tree([".git", "foo"]) + tree.add([".git", "foo"]) + revid = tree.commit("commit") revtree = tree.branch.repository.revision_tree(revid) self.addCleanup(revtree.lock_read().unlock) entries = list(_tree_to_objects(revtree, [], self.idmap, {})) - self.assertEqual(['foo', ''], [p[0] for p in entries]) + self.assertEqual(["foo", ""], [p[0] for p in entries]) def test_merge(self): - basis_tree = self.make_branch_and_tree('base') - self.build_tree(['base/foo/']) - basis_tree.add(['foo']) - basis_rev = basis_tree.commit('foo') + basis_tree = self.make_branch_and_tree("base") + self.build_tree(["base/foo/"]) + basis_tree.add(["foo"]) + basis_rev = basis_tree.commit("foo") basis_revtree = basis_tree.branch.repository.revision_tree(basis_rev) - tree_a = self.make_branch_and_tree('a') + tree_a = self.make_branch_and_tree("a") tree_a.pull(basis_tree.branch) - self.build_tree(['a/foo/file1']) - self.build_tree(['a/foo/subdir-a/']) - os.symlink('.', 'a/foo/subdir-a/symlink') - tree_a.add(['foo/subdir-a', 'foo/subdir-a/symlink']) + self.build_tree(["a/foo/file1"]) + self.build_tree(["a/foo/subdir-a/"]) + os.symlink(".", "a/foo/subdir-a/symlink") + tree_a.add(["foo/subdir-a", "foo/subdir-a/symlink"]) - 
tree_a.add(['foo/file1']) - rev_a = tree_a.commit('commit a') + tree_a.add(["foo/file1"]) + rev_a = tree_a.commit("commit a") revtree_a = tree_a.branch.repository.revision_tree(rev_a) with revtree_a.lock_read(): - entries = list(_tree_to_objects(revtree_a, [basis_revtree], - self.idmap, {})) + entries = list(_tree_to_objects(revtree_a, [basis_revtree], self.idmap, {})) objects = {path: obj for (path, obj, key) in entries} - subdir_a = objects['foo/subdir-a'] + subdir_a = objects["foo/subdir-a"] - tree_b = self.make_branch_and_tree('b') + tree_b = self.make_branch_and_tree("b") tree_b.pull(basis_tree.branch) - self.build_tree(['b/foo/subdir/']) - os.symlink('.', 'b/foo/subdir/symlink') - tree_b.add(['foo/subdir', 'foo/subdir/symlink']) - rev_b = tree_b.commit('commit b') + self.build_tree(["b/foo/subdir/"]) + os.symlink(".", "b/foo/subdir/symlink") + tree_b.add(["foo/subdir", "foo/subdir/symlink"]) + rev_b = tree_b.commit("commit b") revtree_b = tree_b.branch.repository.revision_tree(rev_b) self.addCleanup(revtree_b.lock_read().unlock) @@ -277,62 +276,70 @@ def test_merge(self): with tree_a.lock_write(): tree_a.merge_from_branch(tree_b.branch) - tree_a.commit('merge') + tree_a.commit("merge") revtree_merge = tree_a.branch.basis_tree() self.addCleanup(revtree_merge.lock_read().unlock) - entries = list(_tree_to_objects( - revtree_merge, - [tree_a.branch.repository.revision_tree(r) - for r in revtree_merge.get_parent_ids()], - self.idmap, {})) + entries = list( + _tree_to_objects( + revtree_merge, + [ + tree_a.branch.repository.revision_tree(r) + for r in revtree_merge.get_parent_ids() + ], + self.idmap, + {}, + ) + ) objects = {path: obj for (path, obj, key) in entries} - self.assertEqual({'', 'foo', 'foo/subdir'}, set(objects)) - self.assertEqual( - (stat.S_IFDIR, subdir_a.id), objects['foo'][b'subdir-a']) + self.assertEqual({"", "foo", "foo/subdir"}, set(objects)) + self.assertEqual((stat.S_IFDIR, subdir_a.id), objects["foo"][b"subdir-a"]) class DirectoryToTreeTests(TestCase): - def test_empty(self): - t = directory_to_tree('', [], None, {}, None, allow_empty=False) + t = directory_to_tree("", [], None, {}, None, allow_empty=False) self.assertEqual(None, t) def test_empty_dir(self): - child_ie = InventoryDirectory(b'bar', 'bar', b'bar') - t = directory_to_tree('', [child_ie], lambda p, x: None, {}, None, - allow_empty=False) + child_ie = InventoryDirectory(b"bar", "bar", b"bar") + t = directory_to_tree( + "", [child_ie], lambda p, x: None, {}, None, allow_empty=False + ) self.assertEqual(None, t) def test_empty_dir_dummy_files(self): - child_ie = InventoryDirectory(b'bar', 'bar', b'bar') - t = directory_to_tree('', [child_ie], lambda p, x: None, {}, ".mydummy", - allow_empty=False) + child_ie = InventoryDirectory(b"bar", "bar", b"bar") + t = directory_to_tree( + "", [child_ie], lambda p, x: None, {}, ".mydummy", allow_empty=False + ) self.assertIn(".mydummy", t) def test_empty_root(self): - child_ie = InventoryDirectory(b'bar', 'bar', b'bar') - t = directory_to_tree('', [child_ie], lambda p, x: None, {}, None, - allow_empty=True) + child_ie = InventoryDirectory(b"bar", "bar", b"bar") + t = directory_to_tree( + "", [child_ie], lambda p, x: None, {}, None, allow_empty=True + ) self.assertEqual(Tree(), t) def test_with_file(self): - child_ie = InventoryFile(b'bar', 'bar', b'bar') + child_ie = InventoryFile(b"bar", "bar", b"bar") b = Blob.from_string(b"bla") - t1 = directory_to_tree('', [child_ie], lambda p, x: b.id, {}, None, - allow_empty=False) + t1 = directory_to_tree( + "", [child_ie], 
lambda p, x: b.id, {}, None, allow_empty=False + ) t2 = Tree() t2.add(b"bar", 0o100644, b.id) self.assertEqual(t1, t2) def test_with_gitdir(self): - child_ie = InventoryFile(b'bar', 'bar', b'bar') - git_file_ie = InventoryFile(b'gitid', '.git', b'bar') + child_ie = InventoryFile(b"bar", "bar", b"bar") + git_file_ie = InventoryFile(b"gitid", ".git", b"bar") b = Blob.from_string(b"bla") - t1 = directory_to_tree('', [child_ie, git_file_ie], - lambda p, x: b.id, {}, None, - allow_empty=False) + t1 = directory_to_tree( + "", [child_ie, git_file_ie], lambda p, x: b.id, {}, None, allow_empty=False + ) t2 = Tree() t2.add(b"bar", 0o100644, b.id) self.assertEqual(t1, t2) diff --git a/breezy/git/tests/test_pristine_tar.py b/breezy/git/tests/test_pristine_tar.py index 260592eae2..4a2bd89449 100644 --- a/breezy/git/tests/test_pristine_tar.py +++ b/breezy/git/tests/test_pristine_tar.py @@ -33,23 +33,39 @@ class RevisionPristineTarDataTests(TestCase): - def test_pristine_tar_delta_unknown(self): - rev = Revision(b"myrevid", properties={}, parent_ids=[], timestamp=0, timezone=0, committer="", inventory_sha1=None, message="") - self.assertRaises(KeyError, - revision_pristine_tar_data, rev) + rev = Revision( + b"myrevid", + properties={}, + parent_ids=[], + timestamp=0, + timezone=0, + committer="", + inventory_sha1=None, + message="", + ) + self.assertRaises(KeyError, revision_pristine_tar_data, rev) def test_pristine_tar_delta_gz(self): - rev = Revision(b"myrevid", properties={"deb-pristine-delta": standard_b64encode(b"bla").decode('ascii')}, parent_ids=[], timestamp=0, timezone=0, committer="", inventory_sha1=None, message="") + rev = Revision( + b"myrevid", + properties={ + "deb-pristine-delta": standard_b64encode(b"bla").decode("ascii") + }, + parent_ids=[], + timestamp=0, + timezone=0, + committer="", + inventory_sha1=None, + message="", + ) self.assertEqual((b"bla", "gz"), revision_pristine_tar_data(rev)) class ReadPristineTarData(TestCase): - def test_read_pristine_tar_data_no_branch(self): r = GitMemoryRepo() - self.assertRaises(KeyError, read_git_pristine_tar_data, - r, b"foo") + self.assertRaises(KeyError, read_git_pristine_tar_data, r, b"foo") def test_read_pristine_tar_data_no_file(self): r = GitMemoryRepo() @@ -58,10 +74,8 @@ def test_read_pristine_tar_data_no_file(self): r.object_store.add_object(b) t.add(b"README", stat.S_IFREG | 0o644, b.id) r.object_store.add_object(t) - r.do_commit(b"Add README", tree=t.id, - ref=b'refs/heads/pristine-tar') - self.assertRaises(KeyError, read_git_pristine_tar_data, - r, b"foo") + r.do_commit(b"Add README", tree=t.id, ref=b"refs/heads/pristine-tar") + self.assertRaises(KeyError, read_git_pristine_tar_data, r, b"foo") def test_read_pristine_tar_data(self): r = GitMemoryRepo() @@ -73,25 +87,25 @@ def test_read_pristine_tar_data(self): t.add(b"foo.delta", stat.S_IFREG | 0o644, delta.id) t.add(b"foo.id", stat.S_IFREG | 0o644, idfile.id) r.object_store.add_object(t) - r.do_commit(b"pristine tar delta for foo", tree=t.id, - ref=b'refs/heads/pristine-tar') + r.do_commit( + b"pristine tar delta for foo", tree=t.id, ref=b"refs/heads/pristine-tar" + ) self.assertEqual( - (b"some yummy data", b"someid"), - read_git_pristine_tar_data(r, b'foo')) + (b"some yummy data", b"someid"), read_git_pristine_tar_data(r, b"foo") + ) class StoreGitPristineTarData(TestCase): - def test_store_new(self): r = GitMemoryRepo() cid = store_git_pristine_tar_data(r, b"foo", b"mydelta", b"myid") tree = get_pristine_tar_tree(r) self.assertEqual( (stat.S_IFREG | 0o644, 
b"7b02de8ac4162e64f402c43487d8a40a505482e1"), - tree[b"README"]) + tree[b"README"], + ) self.assertEqual(r[cid].tree, tree.id) self.assertEqual(r[tree[b"foo.delta"][1]].data, b"mydelta") self.assertEqual(r[tree[b"foo.id"][1]].data, b"myid") - self.assertEqual((b"mydelta", b"myid"), - read_git_pristine_tar_data(r, b"foo")) + self.assertEqual((b"mydelta", b"myid"), read_git_pristine_tar_data(r, b"foo")) diff --git a/breezy/git/tests/test_push.py b/breezy/git/tests/test_push.py index 418736dd83..3ce464f305 100644 --- a/breezy/git/tests/test_push.py +++ b/breezy/git/tests/test_push.py @@ -24,11 +24,11 @@ class InterToGitRepositoryTests(TestCaseWithTransport): - def setUp(self): super().setUp() - self.git_repo = self.make_repository("git", - format=format_registry.make_controldir("git")) + self.git_repo = self.make_repository( + "git", format=format_registry.make_controldir("git") + ) self.bzr_repo = self.make_repository("bzr", shared=True) def _get_interrepo(self, mapping=None): @@ -48,17 +48,15 @@ def test_pointless_fetch_refs_old_mapping(self): def test_pointless_fetch_refs(self): interrepo = self._get_interrepo(mapping=BzrGitMappingExperimental()) - revidmap, old_refs, new_refs = interrepo.fetch_refs( - lambda x: {}, lossy=False) - self.assertEqual(old_refs, {b'HEAD': ( - b'ref: refs/heads/master', None)}) + revidmap, old_refs, new_refs = interrepo.fetch_refs(lambda x: {}, lossy=False) + self.assertEqual(old_refs, {b"HEAD": (b"ref: refs/heads/master", None)}) self.assertEqual(new_refs, {}) def test_pointless_lossy_fetch_refs(self): - revidmap, old_refs, new_refs = self._get_interrepo( - ).fetch_refs(lambda x: {}, lossy=True) - self.assertEqual(old_refs, {b'HEAD': ( - b'ref: refs/heads/master', None)}) + revidmap, old_refs, new_refs = self._get_interrepo().fetch_refs( + lambda x: {}, lossy=True + ) + self.assertEqual(old_refs, {b"HEAD": (b"ref: refs/heads/master", None)}) self.assertEqual(new_refs, {}) self.assertEqual(revidmap, {}) @@ -72,8 +70,7 @@ def test_missing_revisions_unknown_stop_rev(self): interrepo = self._get_interrepo() interrepo.source_store.lock_read() self.addCleanup(interrepo.source_store.unlock) - self.assertEqual([], - list(interrepo.missing_revisions([(None, b"unknown")]))) + self.assertEqual([], list(interrepo.missing_revisions([(None, b"unknown")]))) def test_odd_rename(self): # Add initial revision to bzr branch. @@ -92,6 +89,7 @@ def test_odd_rename(self): # Push bzr branch to git branch. 
def decide(x): return {b"refs/heads/master": (None, last_revid)} + interrepo = self._get_interrepo() revidmap, old_refs, new_refs = interrepo.fetch_refs(decide, lossy=True) gitid = revidmap[last_revid][0] diff --git a/breezy/git/tests/test_refs.py b/breezy/git/tests/test_refs.py index c06d8875d1..f3db219ef6 100644 --- a/breezy/git/tests/test_refs.py +++ b/breezy/git/tests/test_refs.py @@ -23,7 +23,6 @@ class BranchNameRefConversionTests(tests.TestCase): - def test_head(self): self.assertEqual("", ref_to_branch_name(b"HEAD")) self.assertEqual(b"HEAD", branch_name_to_ref("")) @@ -37,43 +36,44 @@ def test_branch(self): class BazaarRefsContainerTests(tests.TestCaseWithTransport): - def test_empty(self): - tree = self.make_branch_and_tree('.') + tree = self.make_branch_and_tree(".") store = BazaarObjectStore(tree.branch.repository) refs = BazaarRefsContainer(tree.controldir, store) self.assertEqual(refs.as_dict(), {}) def test_some_commit(self): - tree = self.make_branch_and_tree('.') - revid = tree.commit('somechange') + tree = self.make_branch_and_tree(".") + revid = tree.commit("somechange") store = BazaarObjectStore(tree.branch.repository) refs = BazaarRefsContainer(tree.controldir, store) - self.assertEqual( - refs.as_dict(), - {b'HEAD': store._lookup_revision_sha1(revid)}) + self.assertEqual(refs.as_dict(), {b"HEAD": store._lookup_revision_sha1(revid)}) def test_some_tag(self): - tree = self.make_branch_and_tree('.') - revid = tree.commit('somechange') - tree.branch.tags.set_tag('sometag', revid) + tree = self.make_branch_and_tree(".") + revid = tree.commit("somechange") + tree.branch.tags.set_tag("sometag", revid) store = BazaarObjectStore(tree.branch.repository) refs = BazaarRefsContainer(tree.controldir, store) self.assertEqual( refs.as_dict(), - {b'HEAD': store._lookup_revision_sha1(revid), - b'refs/tags/sometag': store._lookup_revision_sha1(revid), - }) + { + b"HEAD": store._lookup_revision_sha1(revid), + b"refs/tags/sometag": store._lookup_revision_sha1(revid), + }, + ) def test_some_branch(self): - tree = self.make_branch_and_tree('.') - revid = tree.commit('somechange') - otherbranch = tree.controldir.create_branch(name='otherbranch') + tree = self.make_branch_and_tree(".") + revid = tree.commit("somechange") + otherbranch = tree.controldir.create_branch(name="otherbranch") otherbranch.generate_revision_history(revid) store = BazaarObjectStore(tree.branch.repository) refs = BazaarRefsContainer(tree.controldir, store) self.assertEqual( refs.as_dict(), - {b'HEAD': store._lookup_revision_sha1(revid), - b'refs/heads/otherbranch': store._lookup_revision_sha1(revid), - }) + { + b"HEAD": store._lookup_revision_sha1(revid), + b"refs/heads/otherbranch": store._lookup_revision_sha1(revid), + }, + ) diff --git a/breezy/git/tests/test_remote.py b/breezy/git/tests/test_remote.py index 898f0d01ae..8cdffa80c8 100644 --- a/breezy/git/tests/test_remote.py +++ b/breezy/git/tests/test_remote.py @@ -55,51 +55,42 @@ class SplitUrlTests(TestCase): - def test_simple(self): - self.assertEqual(("foo", None, None, "/bar"), - split_git_url("git://foo/bar")) + self.assertEqual(("foo", None, None, "/bar"), split_git_url("git://foo/bar")) def test_port(self): - self.assertEqual(("foo", 343, None, "/bar"), - split_git_url("git://foo:343/bar")) + self.assertEqual(("foo", 343, None, "/bar"), split_git_url("git://foo:343/bar")) def test_username(self): - self.assertEqual(("foo", None, "la", "/bar"), - split_git_url("git://la@foo/bar")) + self.assertEqual(("foo", None, "la", "/bar"), 
split_git_url("git://la@foo/bar")) def test_username_password(self): self.assertEqual( - ("foo", None, "la", "/bar"), - split_git_url("git://la:passwd@foo/bar")) + ("foo", None, "la", "/bar"), split_git_url("git://la:passwd@foo/bar") + ) def test_nopath(self): - self.assertEqual(("foo", None, None, "/"), - split_git_url("git://foo/")) + self.assertEqual(("foo", None, None, "/"), split_git_url("git://foo/")) def test_slashpath(self): - self.assertEqual(("foo", None, None, "//bar"), - split_git_url("git://foo//bar")) + self.assertEqual(("foo", None, None, "//bar"), split_git_url("git://foo//bar")) def test_homedir(self): - self.assertEqual(("foo", None, None, "~bar"), - split_git_url("git://foo/~bar")) + self.assertEqual(("foo", None, None, "~bar"), split_git_url("git://foo/~bar")) def test_file(self): - self.assertEqual( - ("", None, None, "/bar"), - split_git_url("file:///bar")) + self.assertEqual(("", None, None, "/bar"), split_git_url("file:///bar")) class ParseGitErrorTests(TestCase): - def test_unknown(self): e = parse_git_error("url", "foo") self.assertIsInstance(e, RemoteGitError) def test_connection_closed(self): e = parse_git_error( - "url", "The remote server unexpectedly closed the connection.") + "url", "The remote server unexpectedly closed the connection." + ) self.assertIsInstance(e, TransportError) def test_notbrancherror(self): @@ -116,7 +107,9 @@ def test_notbrancherror_github(self): def test_notbrancherror_normal(self): e = parse_git_error( - "url", "fatal: '/srv/git/lintian-brush' does not appear to be a git repository") + "url", + "fatal: '/srv/git/lintian-brush' does not appear to be a git repository", + ) self.assertIsInstance(e, NotBranchError) def test_head_update(self): @@ -125,151 +118,181 @@ def test_head_update(self): def test_permission_dnied(self): e = parse_git_error( - "url", - "access denied or repository not exported: /debian/altermime.git") + "url", "access denied or repository not exported: /debian/altermime.git" + ) self.assertIsInstance(e, PermissionDenied) def test_permission_denied_gitlab(self): e = parse_git_error( - "url", - 'GitLab: You are not allowed to push code to this project.\n') + "url", "GitLab: You are not allowed to push code to this project.\n" + ) self.assertIsInstance(e, PermissionDenied) def test_permission_denied_github(self): e = parse_git_error( - "url", - 'Permission to porridge/gaduhistory.git denied to jelmer.') + "url", "Permission to porridge/gaduhistory.git denied to jelmer." 
+ ) self.assertIsInstance(e, PermissionDenied) - self.assertEqual(e.path, 'porridge/gaduhistory.git') - self.assertEqual(e.extra, ': denied to jelmer') + self.assertEqual(e.path, "porridge/gaduhistory.git") + self.assertEqual(e.extra, ": denied to jelmer") def test_pre_receive_hook_declined(self): - e = parse_git_error( - "url", - 'pre-receive hook declined') + e = parse_git_error("url", "pre-receive hook declined") self.assertIsInstance(e, PermissionDenied) self.assertEqual(e.path, "url") - self.assertEqual(e.extra, ': pre-receive hook declined') + self.assertEqual(e.extra, ": pre-receive hook declined") def test_invalid_repo_name(self): e = parse_git_error( "url", """Gregwar/fatcat/tree/debian is not a valid repository name Email support@github.com for help -""") +""", + ) self.assertIsInstance(e, NotBranchError) def test_invalid_git_error(self): self.assertEqual( PermissionDenied( - 'url', - 'GitLab: You are not allowed to push code to protected ' - 'branches on this project.'), + "url", + "GitLab: You are not allowed to push code to protected " + "branches on this project.", + ), parse_git_error( - 'url', + "url", RemoteGitError( - 'GitLab: You are not allowed to push code to ' - 'protected branches on this project.'))) + "GitLab: You are not allowed to push code to " + "protected branches on this project." + ), + ), + ) def test_protected_branch(self): self.assertEqual( - ProtectedBranchHookDeclined( - msg='protected branch hook declined'), - parse_git_error( - 'url', - RemoteGitError( - 'protected branch hook declined'))) + ProtectedBranchHookDeclined(msg="protected branch hook declined"), + parse_git_error("url", RemoteGitError("protected branch hook declined")), + ) def test_host_key_verification(self): self.assertEqual( - TransportError('Host key verification failed'), - parse_git_error( - 'url', - RemoteGitError( - 'Host key verification failed.'))) + TransportError("Host key verification failed"), + parse_git_error("url", RemoteGitError("Host key verification failed.")), + ) def test_connection_reset_by_peer(self): got = parse_git_error( - 'url', - RemoteGitError( - '[Errno 104] Connection reset by peer')) + "url", RemoteGitError("[Errno 104] Connection reset by peer") + ) self.assertIsInstance(got, ConnectionResetError) - self.assertEqual('[Errno 104] Connection reset by peer', got.args[0]) + self.assertEqual("[Errno 104] Connection reset by peer", got.args[0]) def test_http_unexpected(self): self.assertEqual( UnexpectedHttpStatus( - 'https://example.com/bigint.git/git-upload-pack', - 403, extra=('unexpected http resp 403 for ' - 'https://example.com/bigint.git/git-upload-pack')), + "https://example.com/bigint.git/git-upload-pack", + 403, + extra=( + "unexpected http resp 403 for " + "https://example.com/bigint.git/git-upload-pack" + ), + ), parse_git_error( - 'url', + "url", RemoteGitError( - 'unexpected http resp 403 for ' - 'https://example.com/bigint.git/git-upload-pack'))) + "unexpected http resp 403 for " + "https://example.com/bigint.git/git-upload-pack" + ), + ), + ) class ParseHangupTests(TestCase): - def setUp(self): super().setUp() try: - HangupException([b'foo']) + HangupException([b"foo"]) except TypeError: - self.skipTest('dulwich version too old') + self.skipTest("dulwich version too old") def test_not_set(self): self.assertIsInstance( - parse_git_hangup('http://', HangupException()), ConnectionResetError) + parse_git_hangup("http://", HangupException()), ConnectionResetError + ) def test_single_line(self): self.assertEqual( - RemoteGitError('foo bar'), - 
parse_git_hangup('http://', HangupException([b'foo bar']))) + RemoteGitError("foo bar"), + parse_git_hangup("http://", HangupException([b"foo bar"])), + ) def test_multi_lines(self): self.assertEqual( - RemoteGitError('foo bar\nbla bla'), - parse_git_hangup( - 'http://', HangupException([b'foo bar', b'bla bla']))) + RemoteGitError("foo bar\nbla bla"), + parse_git_hangup("http://", HangupException([b"foo bar", b"bla bla"])), + ) def test_filter_boring(self): self.assertEqual( - RemoteGitError('foo bar'), parse_git_hangup('http://', HangupException( - [b'=======', b'foo bar', b'======']))) + RemoteGitError("foo bar"), + parse_git_hangup( + "http://", HangupException([b"=======", b"foo bar", b"======"]) + ), + ) self.assertEqual( - RemoteGitError('foo bar'), parse_git_hangup('http://', HangupException( - [b'remote: =======', b'remote: foo bar', b'remote: ======']))) + RemoteGitError("foo bar"), + parse_git_hangup( + "http://", + HangupException( + [b"remote: =======", b"remote: foo bar", b"remote: ======"] + ), + ), + ) def test_permission_denied(self): self.assertEqual( - PermissionDenied('http://', 'You are not allowed to push code to this project.'), + PermissionDenied( + "http://", "You are not allowed to push code to this project." + ), parse_git_hangup( - 'http://', + "http://", HangupException( - [b'=======', - b'You are not allowed to push code to this project.', b'', b'======']))) + [ + b"=======", + b"You are not allowed to push code to this project.", + b"", + b"======", + ] + ), + ), + ) def test_notbrancherror_yet(self): self.assertEqual( - NotBranchError('http://', 'A repository for this project does not exist yet.'), + NotBranchError( + "http://", "A repository for this project does not exist yet." + ), parse_git_hangup( - 'http://', + "http://", HangupException( - [b'=======', - b'', - b'A repository for this project does not exist yet.', b'', b'======']))) + [ + b"=======", + b"", + b"A repository for this project does not exist yet.", + b"", + b"======", + ] + ), + ), + ) class TestRemoteGitBranchFormat(TestCase): - def setUp(self): super().setUp() self.format = RemoteGitBranchFormat() def test_get_format_description(self): - self.assertEqual("Remote Git Branch", - self.format.get_format_description()) + self.assertEqual("Remote Git Branch", self.format.get_format_description()) def test_get_network_name(self): self.assertEqual(b"git", self.format.network_name()) @@ -279,252 +302,270 @@ def test_supports_tags(self): class TestRemoteGitBranch(TestCaseWithTransport): - - _test_needs_features = [ExecutableFeature('git')] + _test_needs_features = [ExecutableFeature("git")] def setUp(self): TestCaseWithTransport.setUp(self) - self.remote_real = GitRepo.init('remote', mkdir=True) - self.remote_url = f'git://{os.path.abspath(self.remote_real.path)}/' + self.remote_real = GitRepo.init("remote", mkdir=True) + self.remote_url = f"git://{os.path.abspath(self.remote_real.path)}/" self.permit_url(self.remote_url) def test_set_last_revision_info(self): c1 = self.remote_real.do_commit( - message=b'message 1', - committer=b'committer ', - author=b'author ', - ref=b'refs/heads/newbranch') + message=b"message 1", + committer=b"committer ", + author=b"author ", + ref=b"refs/heads/newbranch", + ) c2 = self.remote_real.do_commit( - message=b'message 2', - committer=b'committer ', - author=b'author ', - ref=b'refs/heads/newbranch') + message=b"message 2", + committer=b"committer ", + author=b"author ", + ref=b"refs/heads/newbranch", + ) remote = ControlDir.open(self.remote_url) - newbranch = 
remote.open_branch('newbranch') - self.assertEqual(newbranch.lookup_foreign_revision_id(c2), - newbranch.last_revision()) - newbranch.set_last_revision_info( - 1, newbranch.lookup_foreign_revision_id(c1)) - self.assertEqual(c1, self.remote_real.refs[b'refs/heads/newbranch']) - self.assertEqual(newbranch.last_revision(), - newbranch.lookup_foreign_revision_id(c1)) + newbranch = remote.open_branch("newbranch") + self.assertEqual( + newbranch.lookup_foreign_revision_id(c2), newbranch.last_revision() + ) + newbranch.set_last_revision_info(1, newbranch.lookup_foreign_revision_id(c1)) + self.assertEqual(c1, self.remote_real.refs[b"refs/heads/newbranch"]) + self.assertEqual( + newbranch.last_revision(), newbranch.lookup_foreign_revision_id(c1) + ) class FetchFromRemoteTestBase: - - _test_needs_features = [ExecutableFeature('git')] + _test_needs_features = [ExecutableFeature("git")] _to_format: str def setUp(self): TestCaseWithTransport.setUp(self) - self.remote_real = GitRepo.init('remote', mkdir=True) - self.remote_url = f'git://{os.path.abspath(self.remote_real.path)}/' + self.remote_real = GitRepo.init("remote", mkdir=True) + self.remote_url = f"git://{os.path.abspath(self.remote_real.path)}/" self.permit_url(self.remote_url) def test_sprout_simple(self): self.remote_real.do_commit( - message=b'message', - committer=b'committer ', - author=b'author ') + message=b"message", + committer=b"committer ", + author=b"author ", + ) remote = ControlDir.open(self.remote_url) - self.make_controldir('local', format=self._to_format) - local = remote.sprout('local') + self.make_controldir("local", format=self._to_format) + local = remote.sprout("local") self.assertEqual( - default_mapping.revision_id_foreign_to_bzr( - self.remote_real.head()), - local.open_branch().last_revision()) + default_mapping.revision_id_foreign_to_bzr(self.remote_real.head()), + local.open_branch().last_revision(), + ) def test_sprout_submodule_invalid(self): - self.sub_real = GitRepo.init('sub', mkdir=True) + self.sub_real = GitRepo.init("sub", mkdir=True) self.sub_real.do_commit( - message=b'message in sub', - committer=b'committer ', - author=b'author ') - - self.sub_real.clone('remote/nested') - self.remote_real.stage('nested') - self.permit_url(urljoin(self.remote_url, '../sub')) - self.assertIn(b'nested', self.remote_real.open_index()) + message=b"message in sub", + committer=b"committer ", + author=b"author ", + ) + + self.sub_real.clone("remote/nested") + self.remote_real.stage("nested") + self.permit_url(urljoin(self.remote_url, "../sub")) + self.assertIn(b"nested", self.remote_real.open_index()) self.remote_real.do_commit( - message=b'message', - committer=b'committer ', - author=b'author ') + message=b"message", + committer=b"committer ", + author=b"author ", + ) remote = ControlDir.open(self.remote_url) - self.make_controldir('local', format=self._to_format) - local = remote.sprout('local') + self.make_controldir("local", format=self._to_format) + local = remote.sprout("local") self.assertEqual( - default_mapping.revision_id_foreign_to_bzr( - self.remote_real.head()), - local.open_branch().last_revision()) + default_mapping.revision_id_foreign_to_bzr(self.remote_real.head()), + local.open_branch().last_revision(), + ) self.assertRaises( - MissingNestedTree, - local.open_workingtree().get_nested_tree, 'nested') + MissingNestedTree, local.open_workingtree().get_nested_tree, "nested" + ) def test_sprout_submodule_relative(self): - self.sub_real = GitRepo.init('sub', mkdir=True) + self.sub_real = GitRepo.init("sub", 
mkdir=True) self.sub_real.do_commit( - message=b'message in sub', - committer=b'committer ', - author=b'author ') - - with open('remote/.gitmodules', 'w') as f: - f.write(""" + message=b"message in sub", + committer=b"committer ", + author=b"author ", + ) + + with open("remote/.gitmodules", "w") as f: + f.write( + """ [submodule "lala"] \tpath = nested \turl = ../sub/.git -""") - self.remote_real.stage('.gitmodules') - self.sub_real.clone('remote/nested') - self.remote_real.stage('nested') - self.permit_url(urljoin(self.remote_url, '../sub')) - self.assertIn(b'nested', self.remote_real.open_index()) +""" + ) + self.remote_real.stage(".gitmodules") + self.sub_real.clone("remote/nested") + self.remote_real.stage("nested") + self.permit_url(urljoin(self.remote_url, "../sub")) + self.assertIn(b"nested", self.remote_real.open_index()) self.remote_real.do_commit( - message=b'message', - committer=b'committer ', - author=b'author ') + message=b"message", + committer=b"committer ", + author=b"author ", + ) remote = ControlDir.open(self.remote_url) - self.make_controldir('local', format=self._to_format) - local = remote.sprout('local') + self.make_controldir("local", format=self._to_format) + local = remote.sprout("local") self.assertEqual( - default_mapping.revision_id_foreign_to_bzr( - self.remote_real.head()), - local.open_branch().last_revision()) + default_mapping.revision_id_foreign_to_bzr(self.remote_real.head()), + local.open_branch().last_revision(), + ) self.assertEqual( - default_mapping.revision_id_foreign_to_bzr( - self.sub_real.head()), - local.open_workingtree().get_nested_tree('nested').last_revision()) + default_mapping.revision_id_foreign_to_bzr(self.sub_real.head()), + local.open_workingtree().get_nested_tree("nested").last_revision(), + ) def test_sprout_with_tags(self): c1 = self.remote_real.do_commit( - message=b'message', - committer=b'committer ', - author=b'author ') + message=b"message", + committer=b"committer ", + author=b"author ", + ) c2 = self.remote_real.do_commit( - message=b'another commit', - committer=b'committer ', - author=b'author ', - ref=b'refs/tags/another') - self.remote_real.refs[b'refs/tags/blah'] = self.remote_real.head() + message=b"another commit", + committer=b"committer ", + author=b"author ", + ref=b"refs/tags/another", + ) + self.remote_real.refs[b"refs/tags/blah"] = self.remote_real.head() remote = ControlDir.open(self.remote_url) - self.make_controldir('local', format=self._to_format) - local = remote.sprout('local') + self.make_controldir("local", format=self._to_format) + local = remote.sprout("local") local_branch = local.open_branch() self.assertEqual( - default_mapping.revision_id_foreign_to_bzr(c1), - local_branch.last_revision()) + default_mapping.revision_id_foreign_to_bzr(c1), local_branch.last_revision() + ) self.assertEqual( - {'blah': local_branch.last_revision(), - 'another': default_mapping.revision_id_foreign_to_bzr(c2)}, - local_branch.tags.get_tag_dict()) + { + "blah": local_branch.last_revision(), + "another": default_mapping.revision_id_foreign_to_bzr(c2), + }, + local_branch.tags.get_tag_dict(), + ) def test_sprout_with_annotated_tag(self): c1 = self.remote_real.do_commit( - message=b'message', - committer=b'committer ', - author=b'author ') + message=b"message", + committer=b"committer ", + author=b"author ", + ) c2 = self.remote_real.do_commit( - message=b'another commit', - committer=b'committer ', - author=b'author ', - ref=b'refs/heads/another') + message=b"another commit", + committer=b"committer ", + author=b"author ", 
+ ref=b"refs/heads/another", + ) porcelain.tag_create( self.remote_real, tag=b"blah", - author=b'author ', + author=b"author ", objectish=c2, tag_time=int(time.time()), tag_timezone=0, annotated=True, - message=b"Annotated tag") + message=b"Annotated tag", + ) remote = ControlDir.open(self.remote_url) - self.make_controldir('local', format=self._to_format) + self.make_controldir("local", format=self._to_format) local = remote.sprout( - 'local', revision_id=default_mapping.revision_id_foreign_to_bzr(c1)) + "local", revision_id=default_mapping.revision_id_foreign_to_bzr(c1) + ) local_branch = local.open_branch() self.assertEqual( - default_mapping.revision_id_foreign_to_bzr(c1), - local_branch.last_revision()) + default_mapping.revision_id_foreign_to_bzr(c1), local_branch.last_revision() + ) self.assertEqual( - {'blah': default_mapping.revision_id_foreign_to_bzr(c2)}, - local_branch.tags.get_tag_dict()) + {"blah": default_mapping.revision_id_foreign_to_bzr(c2)}, + local_branch.tags.get_tag_dict(), + ) def test_sprout_with_annotated_tag_unreferenced(self): c1 = self.remote_real.do_commit( - message=b'message', - committer=b'committer ', - author=b'author ') + message=b"message", + committer=b"committer ", + author=b"author ", + ) self.remote_real.do_commit( - message=b'another commit', - committer=b'committer ', - author=b'author ') + message=b"another commit", + committer=b"committer ", + author=b"author ", + ) porcelain.tag_create( self.remote_real, tag=b"blah", - author=b'author ', + author=b"author ", objectish=c1, tag_time=int(time.time()), tag_timezone=0, annotated=True, - message=b"Annotated tag") + message=b"Annotated tag", + ) remote = ControlDir.open(self.remote_url) - self.make_controldir('local', format=self._to_format) + self.make_controldir("local", format=self._to_format) local = remote.sprout( - 'local', - revision_id=default_mapping.revision_id_foreign_to_bzr(c1)) + "local", revision_id=default_mapping.revision_id_foreign_to_bzr(c1) + ) local_branch = local.open_branch() self.assertEqual( - default_mapping.revision_id_foreign_to_bzr(c1), - local_branch.last_revision()) + default_mapping.revision_id_foreign_to_bzr(c1), local_branch.last_revision() + ) self.assertEqual( - {'blah': default_mapping.revision_id_foreign_to_bzr(c1)}, - local_branch.tags.get_tag_dict()) + {"blah": default_mapping.revision_id_foreign_to_bzr(c1)}, + local_branch.tags.get_tag_dict(), + ) class FetchFromRemoteToBzrTests(FetchFromRemoteTestBase, TestCaseWithTransport): - - _to_format = '2a' + _to_format = "2a" class FetchFromRemoteToGitTests(FetchFromRemoteTestBase, TestCaseWithTransport): - - _to_format = 'git' + _to_format = "git" class PushToRemoteBase: - - _test_needs_features = [ExecutableFeature('git')] + _test_needs_features = [ExecutableFeature("git")] _from_format: str def setUp(self): TestCaseWithTransport.setUp(self) - self.remote_real = GitRepo.init('remote', mkdir=True) - self.remote_url = f'git://{os.path.abspath(self.remote_real.path)}/' + self.remote_real = GitRepo.init("remote", mkdir=True) + self.remote_url = f"git://{os.path.abspath(self.remote_real.path)}/" self.permit_url(self.remote_url) def test_push_branch_new(self): remote = ControlDir.open(self.remote_url) - wt = self.make_branch_and_tree('local', format=self._from_format) - self.build_tree(['local/blah']) - wt.add(['blah']) - wt.commit('blah') + wt = self.make_branch_and_tree("local", format=self._from_format) + self.build_tree(["local/blah"]) + wt.add(["blah"]) + wt.commit("blah") - if self._from_format == 'git': - result = 
remote.push_branch(wt.branch, name='newbranch') + if self._from_format == "git": + result = remote.push_branch(wt.branch, name="newbranch") else: - result = remote.push_branch( - wt.branch, lossy=True, name='newbranch') + result = remote.push_branch(wt.branch, lossy=True, name="newbranch") self.assertEqual(0, result.old_revno) - if self._from_format == 'git': + if self._from_format == "git": self.assertEqual(1, result.new_revno) else: self.assertIs(None, result.new_revno) @@ -532,33 +573,36 @@ def test_push_branch_new(self): result.report(BytesIO()) self.assertEqual( - {b'refs/heads/newbranch': self.remote_real.refs[b'refs/heads/newbranch'], - }, - self.remote_real.get_refs()) + { + b"refs/heads/newbranch": self.remote_real.refs[b"refs/heads/newbranch"], + }, + self.remote_real.get_refs(), + ) def test_push_branch_symref(self): cfg = self.remote_real.get_config() - cfg.set((b'core', ), b'bare', True) + cfg.set((b"core",), b"bare", True) cfg.write_to_path() - self.remote_real.refs.set_symbolic_ref(b'HEAD', b'refs/heads/master') + self.remote_real.refs.set_symbolic_ref(b"HEAD", b"refs/heads/master") self.remote_real.do_commit( - message=b'message', - committer=b'committer ', - author=b'author ', - ref=b'refs/heads/master') + message=b"message", + committer=b"committer ", + author=b"author ", + ref=b"refs/heads/master", + ) remote = ControlDir.open(self.remote_url) - wt = self.make_branch_and_tree('local', format=self._from_format) - self.build_tree(['local/blah']) - wt.add(['blah']) - wt.commit('blah') + wt = self.make_branch_and_tree("local", format=self._from_format) + self.build_tree(["local/blah"]) + wt.add(["blah"]) + wt.commit("blah") - if self._from_format == 'git': + if self._from_format == "git": result = remote.push_branch(wt.branch, overwrite=True) else: result = remote.push_branch(wt.branch, lossy=True, overwrite=True) self.assertEqual(None, result.old_revno) - if self._from_format == 'git': + if self._from_format == "git": self.assertEqual(1, result.new_revno) else: self.assertIs(None, result.new_revno) @@ -567,38 +611,44 @@ def test_push_branch_symref(self): self.assertEqual( { - b'HEAD': self.remote_real.refs[b'refs/heads/master'], - b'refs/heads/master': self.remote_real.refs[b'refs/heads/master'], + b"HEAD": self.remote_real.refs[b"refs/heads/master"], + b"refs/heads/master": self.remote_real.refs[b"refs/heads/master"], }, - self.remote_real.get_refs()) + self.remote_real.get_refs(), + ) def test_push_branch_new_with_tags(self): remote = ControlDir.open(self.remote_url) - builder = self.make_branch_builder('local', format=self._from_format) + builder = self.make_branch_builder("local", format=self._from_format) builder.start_series() - rev_1 = builder.build_snapshot(None, [ - ('add', ('', None, 'directory', '')), - ('add', ('filename', None, 'file', b'content'))]) + rev_1 = builder.build_snapshot( + None, + [ + ("add", ("", None, "directory", "")), + ("add", ("filename", None, "file", b"content")), + ], + ) rev_2 = builder.build_snapshot( - [rev_1], [('modify', ('filename', b'new-content\n'))]) + [rev_1], [("modify", ("filename", b"new-content\n"))] + ) builder.build_snapshot( - [rev_1], [('modify', ('filename', b'new-new-content\n'))]) + [rev_1], [("modify", ("filename", b"new-new-content\n"))] + ) builder.finish_series() branch = builder.get_branch() try: - branch.tags.set_tag('atag', rev_2) + branch.tags.set_tag("atag", rev_2) except TagsNotSupported as err: - raise TestNotApplicable('source format does not support tags') from err + raise TestNotApplicable("source format 
does not support tags") from err - branch.get_config_stack().set('branch.fetch_tags', True) - if self._from_format == 'git': - result = remote.push_branch(branch, name='newbranch') + branch.get_config_stack().set("branch.fetch_tags", True) + if self._from_format == "git": + result = remote.push_branch(branch, name="newbranch") else: - result = remote.push_branch( - branch, lossy=True, name='newbranch') + result = remote.push_branch(branch, lossy=True, name="newbranch") self.assertEqual(0, result.old_revno) - if self._from_format == 'git': + if self._from_format == "git": self.assertEqual(2, result.new_revno) else: self.assertIs(None, result.new_revno) @@ -606,30 +656,31 @@ def test_push_branch_new_with_tags(self): result.report(BytesIO()) self.assertEqual( - {b'refs/heads/newbranch', b'refs/tags/atag'}, - set(self.remote_real.get_refs().keys())) + {b"refs/heads/newbranch", b"refs/tags/atag"}, + set(self.remote_real.get_refs().keys()), + ) def test_push(self): self.remote_real.do_commit( - message=b'message', - committer=b'committer ', - author=b'author ') + message=b"message", + committer=b"committer ", + author=b"author ", + ) remote = ControlDir.open(self.remote_url) - self.make_controldir('local', format=self._from_format) - local = remote.sprout('local') - self.build_tree(['local/blah']) + self.make_controldir("local", format=self._from_format) + local = remote.sprout("local") + self.build_tree(["local/blah"]) wt = local.open_workingtree() - wt.add(['blah']) - revid = wt.commit('blah') - wt.branch.tags.set_tag('sometag', revid) - wt.branch.get_config_stack().set('branch.fetch_tags', True) + wt.add(["blah"]) + revid = wt.commit("blah") + wt.branch.tags.set_tag("sometag", revid) + wt.branch.get_config_stack().set("branch.fetch_tags", True) - if self._from_format == 'git': - result = wt.branch.push(remote.create_branch('newbranch')) + if self._from_format == "git": + result = wt.branch.push(remote.create_branch("newbranch")) else: - result = wt.branch.push( - remote.create_branch('newbranch'), lossy=True) + result = wt.branch.push(remote.create_branch("newbranch"), lossy=True) self.assertEqual(0, result.old_revno) self.assertEqual(2, result.new_revno) @@ -637,275 +688,300 @@ def test_push(self): result.report(BytesIO()) self.assertEqual( - {b'refs/heads/master': self.remote_real.head(), - b'HEAD': self.remote_real.head(), - b'refs/heads/newbranch': self.remote_real.refs[b'refs/heads/newbranch'], - b'refs/tags/sometag': self.remote_real.refs[b'refs/heads/newbranch'], - }, - self.remote_real.get_refs()) + { + b"refs/heads/master": self.remote_real.head(), + b"HEAD": self.remote_real.head(), + b"refs/heads/newbranch": self.remote_real.refs[b"refs/heads/newbranch"], + b"refs/tags/sometag": self.remote_real.refs[b"refs/heads/newbranch"], + }, + self.remote_real.get_refs(), + ) def test_push_diverged(self): c1 = self.remote_real.do_commit( - message=b'message', - committer=b'committer ', - author=b'author ', - ref=b'refs/heads/newbranch') + message=b"message", + committer=b"committer ", + author=b"author ", + ref=b"refs/heads/newbranch", + ) remote = ControlDir.open(self.remote_url) - wt = self.make_branch_and_tree('local', format=self._from_format) - self.build_tree(['local/blah']) - wt.add(['blah']) - wt.commit('blah') + wt = self.make_branch_and_tree("local", format=self._from_format) + self.build_tree(["local/blah"]) + wt.add(["blah"]) + wt.commit("blah") - newbranch = remote.open_branch('newbranch') - if self._from_format == 'git': + newbranch = remote.open_branch("newbranch") + if 
self._from_format == "git": self.assertRaises(DivergedBranches, wt.branch.push, newbranch) else: - self.assertRaises(DivergedBranches, wt.branch.push, - newbranch, lossy=True) + self.assertRaises(DivergedBranches, wt.branch.push, newbranch, lossy=True) - self.assertEqual( - {b'refs/heads/newbranch': c1}, - self.remote_real.get_refs()) + self.assertEqual({b"refs/heads/newbranch": c1}, self.remote_real.get_refs()) - if self._from_format == 'git': + if self._from_format == "git": wt.branch.push(newbranch, overwrite=True) else: wt.branch.push(newbranch, lossy=True, overwrite=True) - self.assertNotEqual(c1, self.remote_real.refs[b'refs/heads/newbranch']) + self.assertNotEqual(c1, self.remote_real.refs[b"refs/heads/newbranch"]) class PushToRemoteFromBzrTests(PushToRemoteBase, TestCaseWithTransport): - - _from_format = '2a' + _from_format = "2a" class PushToRemoteFromGitTests(PushToRemoteBase, TestCaseWithTransport): - - _from_format = 'git' + _from_format = "git" class RemoteControlDirTests(TestCaseWithTransport): - - _test_needs_features = [ExecutableFeature('git')] + _test_needs_features = [ExecutableFeature("git")] def setUp(self): TestCaseWithTransport.setUp(self) - self.remote_real = GitRepo.init('remote', mkdir=True) - self.remote_url = f'git://{os.path.abspath(self.remote_real.path)}/' + self.remote_real = GitRepo.init("remote", mkdir=True) + self.remote_url = f"git://{os.path.abspath(self.remote_real.path)}/" self.permit_url(self.remote_url) def test_remove_branch(self): self.remote_real.do_commit( - message=b'message', - committer=b'committer ', - author=b'author ') + message=b"message", + committer=b"committer ", + author=b"author ", + ) self.remote_real.do_commit( - message=b'another commit', - committer=b'committer ', - author=b'author ', - ref=b'refs/heads/blah') + message=b"another commit", + committer=b"committer ", + author=b"author ", + ref=b"refs/heads/blah", + ) remote = ControlDir.open(self.remote_url) - remote.destroy_branch(name='blah') + remote.destroy_branch(name="blah") self.assertEqual( self.remote_real.get_refs(), - {b'refs/heads/master': self.remote_real.head(), - b'HEAD': self.remote_real.head(), - }) + { + b"refs/heads/master": self.remote_real.head(), + b"HEAD": self.remote_real.head(), + }, + ) def test_list_branches(self): self.remote_real.do_commit( - message=b'message', - committer=b'committer ', - author=b'author ') + message=b"message", + committer=b"committer ", + author=b"author ", + ) self.remote_real.do_commit( - message=b'another commit', - committer=b'committer ', - author=b'author ', - ref=b'refs/heads/blah') + message=b"another commit", + committer=b"committer ", + author=b"author ", + ref=b"refs/heads/blah", + ) remote = ControlDir.open(self.remote_url) - self.assertEqual( - {'master', 'blah'}, - {b.name for b in remote.list_branches()}) + self.assertEqual({"master", "blah"}, {b.name for b in remote.list_branches()}) def test_get_branches(self): self.remote_real.do_commit( - message=b'message', - committer=b'committer ', - author=b'author ') + message=b"message", + committer=b"committer ", + author=b"author ", + ) self.remote_real.do_commit( - message=b'another commit', - committer=b'committer ', - author=b'author ', - ref=b'refs/heads/blah') + message=b"another commit", + committer=b"committer ", + author=b"author ", + ref=b"refs/heads/blah", + ) remote = ControlDir.open(self.remote_url) self.assertEqual( - {'': 'master', 'blah': 'blah', 'master': 'master'}, - {n: b.name for (n, b) in remote.get_branches().items()}) - self.assertEqual( - {'', 
'blah', 'master'}, set(remote.branch_names())) + {"": "master", "blah": "blah", "master": "master"}, + {n: b.name for (n, b) in remote.get_branches().items()}, + ) + self.assertEqual({"", "blah", "master"}, set(remote.branch_names())) def test_remove_tag(self): self.remote_real.do_commit( - message=b'message', - committer=b'committer ', - author=b'author ') + message=b"message", + committer=b"committer ", + author=b"author ", + ) self.remote_real.do_commit( - message=b'another commit', - committer=b'committer ', - author=b'author ', - ref=b'refs/tags/blah') + message=b"another commit", + committer=b"committer ", + author=b"author ", + ref=b"refs/tags/blah", + ) remote = ControlDir.open(self.remote_url) remote_branch = remote.open_branch() - remote_branch.tags.delete_tag('blah') - self.assertRaises(NoSuchTag, remote_branch.tags.delete_tag, 'blah') + remote_branch.tags.delete_tag("blah") + self.assertRaises(NoSuchTag, remote_branch.tags.delete_tag, "blah") self.assertEqual( self.remote_real.get_refs(), - {b'refs/heads/master': self.remote_real.head(), - b'HEAD': self.remote_real.head(), - }) + { + b"refs/heads/master": self.remote_real.head(), + b"HEAD": self.remote_real.head(), + }, + ) def test_set_tag(self): c1 = self.remote_real.do_commit( - message=b'message', - committer=b'committer ', - author=b'author ') + message=b"message", + committer=b"committer ", + author=b"author ", + ) self.remote_real.do_commit( - message=b'another commit', - committer=b'committer ', - author=b'author ') + message=b"another commit", + committer=b"committer ", + author=b"author ", + ) remote = ControlDir.open(self.remote_url) remote.open_branch().tags.set_tag( - 'blah', default_mapping.revision_id_foreign_to_bzr(c1)) + "blah", default_mapping.revision_id_foreign_to_bzr(c1) + ) self.assertEqual( self.remote_real.get_refs(), - {b'refs/heads/master': self.remote_real.head(), - b'refs/tags/blah': c1, - b'HEAD': self.remote_real.head(), - }) + { + b"refs/heads/master": self.remote_real.head(), + b"refs/tags/blah": c1, + b"HEAD": self.remote_real.head(), + }, + ) def test_annotated_tag(self): self.remote_real.do_commit( - message=b'message', - committer=b'committer ', - author=b'author ') + message=b"message", + committer=b"committer ", + author=b"author ", + ) c2 = self.remote_real.do_commit( - message=b'another commit', - committer=b'committer ', - author=b'author ') + message=b"another commit", + committer=b"committer ", + author=b"author ", + ) porcelain.tag_create( self.remote_real, tag=b"blah", - author=b'author ', + author=b"author ", objectish=c2, tag_time=int(time.time()), tag_timezone=0, annotated=True, - message=b"Annotated tag") + message=b"Annotated tag", + ) remote = ControlDir.open(self.remote_url) remote_branch = remote.open_branch() - self.assertEqual({ - 'blah': default_mapping.revision_id_foreign_to_bzr(c2)}, - remote_branch.tags.get_tag_dict()) + self.assertEqual( + {"blah": default_mapping.revision_id_foreign_to_bzr(c2)}, + remote_branch.tags.get_tag_dict(), + ) def test_get_branch_reference(self): self.remote_real.do_commit( - message=b'message', - committer=b'committer ', - author=b'author ') + message=b"message", + committer=b"committer ", + author=b"author ", + ) self.remote_real.do_commit( - message=b'another commit', - committer=b'committer ', - author=b'author ') + message=b"another commit", + committer=b"committer ", + author=b"author ", + ) remote = ControlDir.open(self.remote_url) self.assertEqual( - remote.user_url.rstrip('/') + ',branch=master', remote.get_branch_reference('')) - 
self.assertEqual(None, remote.get_branch_reference('master')) + remote.user_url.rstrip("/") + ",branch=master", + remote.get_branch_reference(""), + ) + self.assertEqual(None, remote.get_branch_reference("master")) def test_get_branch_nick(self): self.remote_real.do_commit( - message=b'message', - committer=b'committer ', - author=b'author ') + message=b"message", + committer=b"committer ", + author=b"author ", + ) remote = ControlDir.open(self.remote_url) - self.assertEqual('master', remote.open_branch().nick) + self.assertEqual("master", remote.open_branch().nick) class GitUrlAndPathFromTransportTests(TestCase): - def test_file(self): - split_url = _git_url_and_path_from_transport('file:///home/blah') - self.assertEqual(split_url.scheme, 'file') - self.assertEqual(split_url.path, '/home/blah') + split_url = _git_url_and_path_from_transport("file:///home/blah") + self.assertEqual(split_url.scheme, "file") + self.assertEqual(split_url.path, "/home/blah") def test_file_segment_params(self): - split_url = _git_url_and_path_from_transport('file:///home/blah,branch=master') - self.assertEqual(split_url.scheme, 'file') - self.assertEqual(split_url.path, '/home/blah') + split_url = _git_url_and_path_from_transport("file:///home/blah,branch=master") + self.assertEqual(split_url.scheme, "file") + self.assertEqual(split_url.path, "/home/blah") def test_git_smart(self): split_url = _git_url_and_path_from_transport( - 'git://github.com/dulwich/dulwich,branch=master') - self.assertEqual(split_url.scheme, 'git') - self.assertEqual(split_url.path, '/dulwich/dulwich') + "git://github.com/dulwich/dulwich,branch=master" + ) + self.assertEqual(split_url.scheme, "git") + self.assertEqual(split_url.path, "/dulwich/dulwich") def test_https(self): split_url = _git_url_and_path_from_transport( - 'https://github.com/dulwich/dulwich') - self.assertEqual(split_url.scheme, 'https') - self.assertEqual(split_url.path, '/dulwich/dulwich') + "https://github.com/dulwich/dulwich" + ) + self.assertEqual(split_url.scheme, "https") + self.assertEqual(split_url.path, "/dulwich/dulwich") def test_https_segment_params(self): split_url = _git_url_and_path_from_transport( - 'https://github.com/dulwich/dulwich,branch=master') - self.assertEqual(split_url.scheme, 'https') - self.assertEqual(split_url.path, '/dulwich/dulwich') + "https://github.com/dulwich/dulwich,branch=master" + ) + self.assertEqual(split_url.scheme, "https") + self.assertEqual(split_url.path, "/dulwich/dulwich") class RemoteRevisionTreeTests(TestCaseWithTransport): - - _test_needs_features = [ExecutableFeature('git')] + _test_needs_features = [ExecutableFeature("git")] def setUp(self): TestCaseWithTransport.setUp(self) - self.remote_real = GitRepo.init('remote', mkdir=True) - self.remote_url = f'git://{os.path.abspath(self.remote_real.path)}/' + self.remote_real = GitRepo.init("remote", mkdir=True) + self.remote_url = f"git://{os.path.abspath(self.remote_real.path)}/" self.permit_url(self.remote_url) self.remote_real.do_commit( - message=b'message', - committer=b'committer ', - author=b'author ') + message=b"message", + committer=b"committer ", + author=b"author ", + ) def test_open(self): br = Branch.open(self.remote_url) t = br.basis_tree() self.assertIsInstance(t, GitRemoteRevisionTree) - self.assertRaises(GitSmartRemoteNotSupported, t.is_versioned, 'la') - self.assertRaises(GitSmartRemoteNotSupported, t.has_filename, 'la') - self.assertRaises(GitSmartRemoteNotSupported, t.get_file_text, 'la') - self.assertRaises(GitSmartRemoteNotSupported, t.list_files, 'la') 
+ self.assertRaises(GitSmartRemoteNotSupported, t.is_versioned, "la") + self.assertRaises(GitSmartRemoteNotSupported, t.has_filename, "la") + self.assertRaises(GitSmartRemoteNotSupported, t.get_file_text, "la") + self.assertRaises(GitSmartRemoteNotSupported, t.list_files, "la") def test_archive(self): br = Branch.open(self.remote_url) t = br.basis_tree() - chunks = list(t.archive('tgz', 'foo.tar.gz')) - with gzip.GzipFile(fileobj=BytesIO(b''.join(chunks))) as g: - self.assertEqual('', g.name) + chunks = list(t.archive("tgz", "foo.tar.gz")) + with gzip.GzipFile(fileobj=BytesIO(b"".join(chunks))) as g: + self.assertEqual("", g.name) def test_archive_unsupported(self): # archive is not supported over HTTP, so simulate that br = Branch.open(self.remote_url) t = br.basis_tree() + def raise_unsupp(*args, **kwargs): raise GitSmartRemoteNotSupported(raise_unsupp, None) - self.overrideAttr(t._repository.controldir._client, 'archive', raise_unsupp) - self.assertRaises(GitSmartRemoteNotSupported, t.archive, 'tgz', 'foo.tar.gz') + + self.overrideAttr(t._repository.controldir._client, "archive", raise_unsupp) + self.assertRaises(GitSmartRemoteNotSupported, t.archive, "tgz", "foo.tar.gz") diff --git a/breezy/git/tests/test_repository.py b/breezy/git/tests/test_repository.py index 5ffaa48540..8d85ac734a 100644 --- a/breezy/git/tests/test_repository.py +++ b/breezy/git/tests/test_repository.py @@ -35,38 +35,38 @@ class TestGitRepositoryFeatures(tests.TestCaseInTempDir): def _do_commit(self): builder = tests.GitBranchBuilder() - builder.set_file(b'a', b'text for a\n', False) - commit_handle = builder.commit(b'Joe Foo ', b'message') + builder.set_file(b"a", b"text for a\n", False) + commit_handle = builder.commit(b"Joe Foo ", b"message") mapping = builder.finish() return mapping[commit_handle] def test_open_existing(self): GitRepo.init(self.test_dir) - repo = Repository.open('.') + repo = Repository.open(".") self.assertIsInstance(repo, repository.GitRepository) def test_has_git_repo(self): GitRepo.init(self.test_dir) - repo = Repository.open('.') + repo = Repository.open(".") self.assertIsInstance(repo._git, dulwich.repo.BaseRepo) def test_has_revision(self): GitRepo.init(self.test_dir) commit_id = self._do_commit() - repo = Repository.open('.') - self.assertFalse(repo.has_revision(b'foobar')) + repo = Repository.open(".") + self.assertFalse(repo.has_revision(b"foobar")) revid = default_mapping.revision_id_foreign_to_bzr(commit_id) self.assertTrue(repo.has_revision(revid)) def test_has_revisions(self): GitRepo.init(self.test_dir) commit_id = self._do_commit() - repo = Repository.open('.') - self.assertEqual(set(), repo.has_revisions([b'foobar'])) + repo = Repository.open(".") + self.assertEqual(set(), repo.has_revisions([b"foobar"])) revid = default_mapping.revision_id_foreign_to_bzr(commit_id) - self.assertEqual({revid}, repo.has_revisions([b'foobar', revid])) + self.assertEqual({revid}, repo.has_revisions([b"foobar", revid])) def test_get_revision(self): # GitRepository.get_revision gives a Revision object. @@ -77,37 +77,38 @@ def test_get_revision(self): # Get the corresponding Revision object. 
revid = default_mapping.revision_id_foreign_to_bzr(commit_id) - repo = Repository.open('.') + repo = Repository.open(".") rev = repo.get_revision(revid) self.assertIsInstance(rev, revision.Revision) def test_get_revision_unknown(self): GitRepo.init(self.test_dir) - repo = Repository.open('.') + repo = Repository.open(".") self.assertRaises(errors.NoSuchRevision, repo.get_revision, b"bla") def simple_commit(self): # Create a git repository with some interesting files in a revision. GitRepo.init(self.test_dir) builder = tests.GitBranchBuilder() - builder.set_file(b'data', b'text\n', False) - builder.set_file(b'executable', b'content', True) - builder.set_symlink(b'link', b'broken') - builder.set_file(b'subdir/subfile', b'subdir text\n', False) - commit_handle = builder.commit(b'Joe Foo ', b'message', - timestamp=1205433193) + builder.set_file(b"data", b"text\n", False) + builder.set_file(b"executable", b"content", True) + builder.set_symlink(b"link", b"broken") + builder.set_file(b"subdir/subfile", b"subdir text\n", False) + commit_handle = builder.commit( + b"Joe Foo ", b"message", timestamp=1205433193 + ) mapping = builder.finish() return mapping[commit_handle] def test_pack(self): self.simple_commit() - repo = Repository.open('.') + repo = Repository.open(".") repo.pack() def test_unlock_closes(self): self.simple_commit() - repo = Repository.open('.') + repo = Repository.open(".") repo.pack() with repo.lock_read(): repo.all_revision_ids() @@ -117,18 +118,17 @@ def test_unlock_closes(self): def test_revision_tree(self): commit_id = self.simple_commit() revid = default_mapping.revision_id_foreign_to_bzr(commit_id) - repo = Repository.open('.') + repo = Repository.open(".") tree = repo.revision_tree(revid) self.assertEqual(tree.get_revision_id(), revid) self.assertEqual(b"text\n", tree.get_file_text("data")) class TestGitRepository(tests.TestCaseWithTransport): - def _do_commit(self): builder = tests.GitBranchBuilder() - builder.set_file(b'a', b'text for a\n', False) - commit_handle = builder.commit(b'Joe Foo ', b'message') + builder.set_file(b"a", b"text for a\n", False) + commit_handle = builder.commit(b"Joe Foo ", b"message") mapping = builder.finish() return mapping[commit_handle] @@ -143,11 +143,15 @@ def test_supports_rich_root(self): def test_get_signature_text(self): self.assertRaises( - errors.NoSuchRevision, self.git_repo.get_signature_text, revision.NULL_REVISION) + errors.NoSuchRevision, + self.git_repo.get_signature_text, + revision.NULL_REVISION, + ) def test_has_signature_for_revision_id(self): - self.assertEqual(False, self.git_repo.has_signature_for_revision_id( - revision.NULL_REVISION)) + self.assertEqual( + False, self.git_repo.has_signature_for_revision_id(revision.NULL_REVISION) + ) def test_all_revision_ids_none(self): self.assertEqual([], self.git_repo.all_revision_ids()) @@ -156,17 +160,21 @@ def test_get_known_graph_ancestry(self): cid = self._do_commit() revid = default_mapping.revision_id_foreign_to_bzr(cid) g = self.git_repo.get_known_graph_ancestry([revid]) - self.assertEqual(frozenset([revid]), - g.heads([revid])) - self.assertEqual([(revid, 0, (1,), True)], - [(n.key, n.merge_depth, n.revno, n.end_of_merge) - for n in g.merge_sort(revid)]) + self.assertEqual(frozenset([revid]), g.heads([revid])) + self.assertEqual( + [(revid, 0, (1,), True)], + [ + (n.key, n.merge_depth, n.revno, n.end_of_merge) + for n in g.merge_sort(revid) + ], + ) def test_all_revision_ids(self): commit_id = self._do_commit() self.assertEqual( 
[default_mapping.revision_id_foreign_to_bzr(commit_id)], - self.git_repo.all_revision_ids()) + self.git_repo.all_revision_ids(), + ) def assertIsNullInventory(self, inv): self.assertEqual(inv.root, None) @@ -180,108 +188,119 @@ def test_revision_tree_none(self): self.assertEqual(tree.get_revision_id(), revision.NULL_REVISION) def test_get_parent_map_null(self): - self.assertEqual({revision.NULL_REVISION: ()}, - self.git_repo.get_parent_map([revision.NULL_REVISION])) + self.assertEqual( + {revision.NULL_REVISION: ()}, + self.git_repo.get_parent_map([revision.NULL_REVISION]), + ) class SigningGitRepository(tests.TestCaseWithTransport): - def test_signed_commit(self): import breezy.gpg + oldstrategy = breezy.gpg.GPGStrategy - wt = self.make_branch_and_tree('.', format='git') + wt = self.make_branch_and_tree(".", format="git") branch = wt.branch revid = wt.commit("base", allow_pointless=True) - self.assertFalse( - branch.repository.has_signature_for_revision_id(revid)) + self.assertFalse(branch.repository.has_signature_for_revision_id(revid)) try: breezy.gpg.GPGStrategy = breezy.gpg.LoopbackGPGStrategy - conf = config.MemoryStack(b''' + conf = config.MemoryStack( + b""" create_signatures=always -''') - revid2 = wt.commit(config=conf, message="base", - allow_pointless=True) +""" + ) + revid2 = wt.commit(config=conf, message="base", allow_pointless=True) def sign(text): return breezy.gpg.LoopbackGPGStrategy(None).sign(text) - self.assertIsInstance( - branch.repository.get_signature_text(revid2), bytes) + + self.assertIsInstance(branch.repository.get_signature_text(revid2), bytes) finally: breezy.gpg.GPGStrategy = oldstrategy class RevpropsRepository(tests.TestCaseWithTransport): - def test_author(self): - wt = self.make_branch_and_tree('.', format='git') + wt = self.make_branch_and_tree(".", format="git") revid = wt.commit( - "base", allow_pointless=True, - revprops={'author': 'Joe Example '}) + "base", + allow_pointless=True, + revprops={"author": "Joe Example "}, + ) wt.branch.repository.get_revision(revid) - r = dulwich.repo.Repo('.') - self.assertEqual(b'Joe Example ', r[r.head()].author) + r = dulwich.repo.Repo(".") + self.assertEqual(b"Joe Example ", r[r.head()].author) def test_authors_single_author(self): - wt = self.make_branch_and_tree('.', format='git') + wt = self.make_branch_and_tree(".", format="git") revid = wt.commit( - "base", allow_pointless=True, - revprops={'authors': 'Joe Example '}) + "base", + allow_pointless=True, + revprops={"authors": "Joe Example "}, + ) wt.branch.repository.get_revision(revid) - r = dulwich.repo.Repo('.') - self.assertEqual(b'Joe Example ', r[r.head()].author) + r = dulwich.repo.Repo(".") + self.assertEqual(b"Joe Example ", r[r.head()].author) def test_multiple_authors(self): - wt = self.make_branch_and_tree('.', format='git') + wt = self.make_branch_and_tree(".", format="git") self.assertRaises( - Exception, wt.commit, "base", allow_pointless=True, - revprops={'authors': 'Joe Example \n' - 'Jane Doe '}) + Exception, + wt.commit, + "base", + allow_pointless=True, + revprops={ + "authors": "Joe Example \n" + "Jane Doe " + }, + ) def test_bugs(self): - wt = self.make_branch_and_tree('.', format='git') + wt = self.make_branch_and_tree(".", format="git") revid = wt.commit( - "base", allow_pointless=True, - revprops={ - 'bugs': 'https://github.com/jelmer/dulwich/issues/123 fixed\n' - }) + "base", + allow_pointless=True, + revprops={"bugs": "https://github.com/jelmer/dulwich/issues/123 fixed\n"}, + ) wt.branch.repository.get_revision(revid) - r = 
dulwich.repo.Repo('.') + r = dulwich.repo.Repo(".") self.assertEqual( - b'base\n\n' - b'Fixes: https://github.com/jelmer/dulwich/issues/123\n', - r[r.head()].message) + b"base\n\n" b"Fixes: https://github.com/jelmer/dulwich/issues/123\n", + r[r.head()].message, + ) def test_authors(self): - wt = self.make_branch_and_tree('.', format='git') + wt = self.make_branch_and_tree(".", format="git") revid = wt.commit( - "base", allow_pointless=True, + "base", + allow_pointless=True, revprops={ - 'authors': ( - 'Jelmer Vernooij \n' - 'Martin Packman \n'), - }) + "authors": ( + "Jelmer Vernooij \n" + "Martin Packman \n" + ), + }, + ) wt.branch.repository.get_revision(revid) - r = dulwich.repo.Repo('.') + r = dulwich.repo.Repo(".") + self.assertEqual(r[r.head()].author, b"Jelmer Vernooij ") self.assertEqual( - r[r.head()].author, b'Jelmer Vernooij ') - self.assertEqual( - b'base\n\nCo-authored-by: Martin Packman \n', - r[r.head()].message) + b"base\n\nCo-authored-by: Martin Packman \n", + r[r.head()].message, + ) class GitRepositoryFormat(tests.TestCase): - def setUp(self): super().setUp() self.format = repository.GitRepositoryFormat() def test_get_format_description(self): - self.assertEqual("Git Repository", - self.format.get_format_description()) + self.assertEqual("Git Repository", self.format.get_format_description()) class RevisionGistImportTests(tests.TestCaseWithTransport): - def setUp(self): tests.TestCaseWithTransport.setUp(self) self.git_path = os.path.join(self.test_dir, "git") @@ -291,14 +310,11 @@ def setUp(self): self.bzr_tree = self.make_branch_and_tree("bzr") def get_inter(self): - return InterRepository.get(self.bzr_tree.branch.repository, - self.git_repo) + return InterRepository.get(self.bzr_tree.branch.repository, self.git_repo) def object_iter(self): - store = BazaarObjectStore( - self.bzr_tree.branch.repository, default_mapping) - store_iterator = MissingObjectsIterator( - store, self.bzr_tree.branch.repository) + store = BazaarObjectStore(self.bzr_tree.branch.repository, default_mapping) + store_iterator = MissingObjectsIterator(store, self.bzr_tree.branch.repository) return store, store_iterator def import_rev(self, revid, parent_lookup=None): @@ -313,15 +329,24 @@ def import_rev(self, revid, parent_lookup=None): store._cache.idmap.commit_write_group() def test_pointless(self): - revid = self.bzr_tree.commit("pointless", timestamp=1205433193, - timezone=0, committer="Jelmer Vernooij ") - self.assertEqual(b"2caa8094a5b794961cd9bf582e3e2bb090db0b14", - self.import_rev(revid)) - self.assertEqual(b"2caa8094a5b794961cd9bf582e3e2bb090db0b14", - self.import_rev(revid)) + revid = self.bzr_tree.commit( + "pointless", + timestamp=1205433193, + timezone=0, + committer="Jelmer Vernooij ", + ) + self.assertEqual( + b"2caa8094a5b794961cd9bf582e3e2bb090db0b14", self.import_rev(revid) + ) + self.assertEqual( + b"2caa8094a5b794961cd9bf582e3e2bb090db0b14", self.import_rev(revid) + ) class ForeignTestsRepositoryFactory: - def make_repository(self, transport): - return dir.LocalGitControlDirFormat().initialize_on_transport(transport).open_repository() + return ( + dir.LocalGitControlDirFormat() + .initialize_on_transport(transport) + .open_repository() + ) diff --git a/breezy/git/tests/test_revspec.py b/breezy/git/tests/test_revspec.py index d2abd1055f..3839ca2de6 100644 --- a/breezy/git/tests/test_revspec.py +++ b/breezy/git/tests/test_revspec.py @@ -21,7 +21,6 @@ class Sha1ValidTests(TestCase): - def test_invalid(self): self.assertFalse(valid_git_sha1(b"git-v1:abcde")) diff --git 
a/breezy/git/tests/test_roundtrip.py b/breezy/git/tests/test_roundtrip.py index a793821c18..894c345401 100644 --- a/breezy/git/tests/test_roundtrip.py +++ b/breezy/git/tests/test_roundtrip.py @@ -28,7 +28,6 @@ class RoundtripTests(TestCase): - def test_revid(self): md = parse_roundtripping_metadata(b"revision-id: foo\n") self.assertEqual(b"foo", md.revision_id) @@ -43,52 +42,56 @@ def test_properties(self): class FormatTests(TestCase): - def test_revid(self): metadata = CommitSupplement() metadata.revision_id = b"bla" - self.assertEqual(b"revision-id: bla\n", - generate_roundtripping_metadata(metadata, "utf-8")) + self.assertEqual( + b"revision-id: bla\n", generate_roundtripping_metadata(metadata, "utf-8") + ) def test_parent_ids(self): metadata = CommitSupplement() metadata.explicit_parent_ids = (b"foo", b"bar") - self.assertEqual(b"parent-ids: foo bar\n", - generate_roundtripping_metadata(metadata, "utf-8")) + self.assertEqual( + b"parent-ids: foo bar\n", generate_roundtripping_metadata(metadata, "utf-8") + ) def test_properties(self): metadata = CommitSupplement() metadata.properties = {b"foo": b"bar"} - self.assertEqual(b"property-foo: bar\n", - generate_roundtripping_metadata(metadata, "utf-8")) + self.assertEqual( + b"property-foo: bar\n", generate_roundtripping_metadata(metadata, "utf-8") + ) def test_empty(self): metadata = CommitSupplement() - self.assertEqual(b"", - generate_roundtripping_metadata(metadata, "utf-8")) + self.assertEqual(b"", generate_roundtripping_metadata(metadata, "utf-8")) class ExtractMetadataTests(TestCase): - def test_roundtrip(self): - (msg, metadata) = extract_bzr_metadata(b"""Foo + (msg, metadata) = extract_bzr_metadata( + b"""Foo --BZR-- revision-id: foo -""") +""" + ) self.assertEqual(b"Foo", msg) self.assertEqual(b"foo", metadata.revision_id) class GenerateMetadataTests(TestCase): - def test_roundtrip(self): metadata = CommitSupplement() metadata.revision_id = b"myrevid" msg = inject_bzr_metadata(b"Foo", metadata, "utf-8") - self.assertEqual(b"""Foo + self.assertEqual( + b"""Foo --BZR-- revision-id: myrevid -""", msg) +""", + msg, + ) def test_no_metadata(self): metadata = CommitSupplement() diff --git a/breezy/git/tests/test_server.py b/breezy/git/tests/test_server.py index c6b38cb874..dbcbad9071 100644 --- a/breezy/git/tests/test_server.py +++ b/breezy/git/tests/test_server.py @@ -27,17 +27,15 @@ class TestPresent(TestCase): - def test_present(self): # Just test that the server is registered. 
- transport_server_registry.get('git') + transport_server_registry.get("git") class GitServerTestCase(TestCaseWithTransport): - def start_server(self, t): backend = BzrBackend(t) - server = BzrTCPGitServer(backend, 'localhost', port=0) + server = BzrTCPGitServer(backend, "localhost", port=0) self.addCleanup(server.shutdown) threading.Thread(target=server.serve).start() self._server = server @@ -46,48 +44,43 @@ def start_server(self, t): class TestPlainFetch(GitServerTestCase): - def test_fetch_from_native_git(self): - wt = self.make_branch_and_tree('t', format='git') - self.build_tree(['t/foo']) - wt.add('foo') + wt = self.make_branch_and_tree("t", format="git") + self.build_tree(["t/foo"]) + wt.add("foo") revid = wt.commit(message="some data") wt.branch.tags.set_tag("atag", revid) - t = self.get_transport('t') + t = self.get_transport("t") port = self.start_server(t) - c = TCPGitClient('localhost', port=port) - gitrepo = Repo.init('gitrepo', mkdir=True) - result = c.fetch('/', gitrepo) + c = TCPGitClient("localhost", port=port) + gitrepo = Repo.init("gitrepo", mkdir=True) + result = c.fetch("/", gitrepo) self.assertEqual( - set(result.refs.keys()), - {b"refs/tags/atag", b'refs/heads/master', b"HEAD"}) + set(result.refs.keys()), {b"refs/tags/atag", b"refs/heads/master", b"HEAD"} + ) def test_fetch_nothing(self): - wt = self.make_branch_and_tree('t') - self.build_tree(['t/foo']) - wt.add('foo') + wt = self.make_branch_and_tree("t") + self.build_tree(["t/foo"]) + wt.add("foo") revid = wt.commit(message="some data") wt.branch.tags.set_tag("atag", revid) - t = self.get_transport('t') + t = self.get_transport("t") port = self.start_server(t) - c = TCPGitClient('localhost', port=port) - gitrepo = Repo.init('gitrepo', mkdir=True) - result = c.fetch('/', gitrepo, determine_wants=lambda x: []) - self.assertEqual( - set(result.refs.keys()), - {b"refs/tags/atag", b"HEAD"}) + c = TCPGitClient("localhost", port=port) + gitrepo = Repo.init("gitrepo", mkdir=True) + result = c.fetch("/", gitrepo, determine_wants=lambda x: []) + self.assertEqual(set(result.refs.keys()), {b"refs/tags/atag", b"HEAD"}) def test_fetch_from_non_git(self): - wt = self.make_branch_and_tree('t', format='bzr') - self.build_tree(['t/foo']) - wt.add('foo') + wt = self.make_branch_and_tree("t", format="bzr") + self.build_tree(["t/foo"]) + wt.add("foo") revid = wt.commit(message="some data") wt.branch.tags.set_tag("atag", revid) - t = self.get_transport('t') + t = self.get_transport("t") port = self.start_server(t) - c = TCPGitClient('localhost', port=port) - gitrepo = Repo.init('gitrepo', mkdir=True) - result = c.fetch('/', gitrepo) - self.assertEqual( - set(result.refs.keys()), - {b"refs/tags/atag", b"HEAD"}) + c = TCPGitClient("localhost", port=port) + gitrepo = Repo.init("gitrepo", mkdir=True) + result = c.fetch("/", gitrepo) + self.assertEqual(set(result.refs.keys()), {b"refs/tags/atag", b"HEAD"}) diff --git a/breezy/git/tests/test_transform.py b/breezy/git/tests/test_transform.py index 9d78c319ad..ea062c6feb 100644 --- a/breezy/git/tests/test_transform.py +++ b/breezy/git/tests/test_transform.py @@ -23,30 +23,28 @@ class GitTransformTests(TestCaseWithTransport): - def test_directory_exists(self): - tree = self.make_branch_and_tree('.', format='git') + tree = self.make_branch_and_tree(".", format="git") tt = tree.transform() - dir1 = tt.new_directory('dir', ROOT_PARENT) - tt.new_file('name1', dir1, [b'content1']) - dir2 = tt.new_directory('dir', ROOT_PARENT) - tt.new_file('name2', dir2, [b'content2']) - raw_conflicts = 
resolve_conflicts( - tt, None, lambda t, c: conflict_pass(t, c)) + dir1 = tt.new_directory("dir", ROOT_PARENT) + tt.new_file("name1", dir1, [b"content1"]) + dir2 = tt.new_directory("dir", ROOT_PARENT) + tt.new_file("name2", dir2, [b"content2"]) + raw_conflicts = resolve_conflicts(tt, None, lambda t, c: conflict_pass(t, c)) conflicts = tt.cook_conflicts(raw_conflicts) self.assertEqual([], list(conflicts)) tt.apply() - self.assertEqual({'name1', 'name2'}, set(os.listdir('dir'))) + self.assertEqual({"name1", "name2"}, set(os.listdir("dir"))) def test_revert_does_not_remove(self): - tree = self.make_branch_and_tree('.', format='git') + tree = self.make_branch_and_tree(".", format="git") tt = tree.transform() - dir1 = tt.new_directory('dir', ROOT_PARENT) - tid = tt.new_file('name1', dir1, [b'content1']) + dir1 = tt.new_directory("dir", ROOT_PARENT) + tid = tt.new_file("name1", dir1, [b"content1"]) tt.version_file(tid) tt.apply() - tree.commit('start') - with open('dir/name1', 'wb') as f: - f.write(b'new content2') + tree.commit("start") + with open("dir/name1", "wb") as f: + f.write(b"new content2") revert(tree, tree.basis_tree()) self.assertEqual([], list(tree.iter_changes(tree.basis_tree()))) diff --git a/breezy/git/tests/test_transportgit.py b/breezy/git/tests/test_transportgit.py index ddfbb49d27..31e8f2f63a 100644 --- a/breezy/git/tests/test_transportgit.py +++ b/breezy/git/tests/test_transportgit.py @@ -25,7 +25,6 @@ class TransportObjectStoreTests(PackBasedObjectStoreTests, TestCaseWithTransport): - def setUp(self): TestCaseWithTransport.setUp(self) self.store = TransportObjectStore.init(self.get_transport()) @@ -40,12 +39,13 @@ def test_prefers_pack_listdir(self): self.store.pack_loose_objects() self.assertEqual(1, len(self.store.packs), self.store.packs) packname = list(self.store.packs)[0].name() - self.assertEqual({f"pack-{packname.decode('ascii')}"}, - set(self.store._pack_names())) - self.store.transport.put_bytes_non_atomic('info/packs', - b'P foo-pack.pack\n') - self.assertEqual({f"pack-{packname.decode('ascii')}"}, - set(self.store._pack_names())) + self.assertEqual( + {f"pack-{packname.decode('ascii')}"}, set(self.store._pack_names()) + ) + self.store.transport.put_bytes_non_atomic("info/packs", b"P foo-pack.pack\n") + self.assertEqual( + {f"pack-{packname.decode('ascii')}"}, set(self.store._pack_names()) + ) def test_remembers_packs(self): self.store.add_object(make_object(Blob, data=b"data")) @@ -65,8 +65,8 @@ def test_remembers_packs(self): # FIXME: Unfortunately RefsContainerTests requires on a specific set of refs existing. 
-class TransportRefContainerTests(TestCaseWithTransport): +class TransportRefContainerTests(TestCaseWithTransport): def setUp(self): TestCaseWithTransport.setUp(self) self._refs = TransportRefsContainer(self.get_transport()) @@ -75,9 +75,12 @@ def test_packed_refs_missing(self): self.assertEqual({}, self._refs.get_packed_refs()) def test_packed_refs(self): - self.get_transport().put_bytes_non_atomic('packed-refs', - b'# pack-refs with: peeled fully-peeled sorted \n' - b'2001b954f1ec392f84f7cec2f2f96a76ed6ba4ee refs/heads/master') + self.get_transport().put_bytes_non_atomic( + "packed-refs", + b"# pack-refs with: peeled fully-peeled sorted \n" + b"2001b954f1ec392f84f7cec2f2f96a76ed6ba4ee refs/heads/master", + ) self.assertEqual( - {b'refs/heads/master': b'2001b954f1ec392f84f7cec2f2f96a76ed6ba4ee'}, - self._refs.get_packed_refs()) + {b"refs/heads/master": b"2001b954f1ec392f84f7cec2f2f96a76ed6ba4ee"}, + self._refs.get_packed_refs(), + ) diff --git a/breezy/git/tests/test_tree.py b/breezy/git/tests/test_tree.py index 0086a19e9a..163b7bdfb4 100644 --- a/breezy/git/tests/test_tree.py +++ b/breezy/git/tests/test_tree.py @@ -34,355 +34,670 @@ class ChangesFromGitChangesTests(TestCase): - def setUp(self): super().setUp() self.maxDiff = None self.mapping = default_mapping def transform( - self, changes, specific_files=None, include_unchanged=False, - source_extras=None, target_extras=None): - return list(changes_from_git_changes( - changes, self.mapping, specific_files=specific_files, - include_unchanged=include_unchanged, source_extras=source_extras, - target_extras=target_extras)) + self, + changes, + specific_files=None, + include_unchanged=False, + source_extras=None, + target_extras=None, + ): + return list( + changes_from_git_changes( + changes, + self.mapping, + specific_files=specific_files, + include_unchanged=include_unchanged, + source_extras=source_extras, + target_extras=target_extras, + ) + ) def test_empty(self): self.assertEqual([], self.transform([])) def test_modified(self): - a = Blob.from_string(b'a') - b = Blob.from_string(b'b') - self.assertEqual([ - TreeChange( - b'git:a', ('a', 'a'), True, (True, True), - (b'TREE_ROOT', b'TREE_ROOT'), ('a', 'a'), ('file', 'file'), - (False, False), False) - ], self.transform([ - ('modify', - (b'a', stat.S_IFREG | 0o644, a), (b'a', stat.S_IFREG | 0o644, b))])) + a = Blob.from_string(b"a") + b = Blob.from_string(b"b") + self.assertEqual( + [ + TreeChange( + b"git:a", + ("a", "a"), + True, + (True, True), + (b"TREE_ROOT", b"TREE_ROOT"), + ("a", "a"), + ("file", "file"), + (False, False), + False, + ) + ], + self.transform( + [ + ( + "modify", + (b"a", stat.S_IFREG | 0o644, a), + (b"a", stat.S_IFREG | 0o644, b), + ) + ] + ), + ) def test_kind_changed(self): - a = Blob.from_string(b'a') - b = Blob.from_string(b'target') - self.assertEqual([ - TreeChange( - b'git:a', ('a', 'a'), True, (True, True), - (b'TREE_ROOT', b'TREE_ROOT'), ('a', 'a'), ('file', 'symlink'), - (False, False), False) - ], self.transform([ - ('modify', - (b'a', stat.S_IFREG | 0o644, a), (b'a', stat.S_IFLNK, b))])) + a = Blob.from_string(b"a") + b = Blob.from_string(b"target") + self.assertEqual( + [ + TreeChange( + b"git:a", + ("a", "a"), + True, + (True, True), + (b"TREE_ROOT", b"TREE_ROOT"), + ("a", "a"), + ("file", "symlink"), + (False, False), + False, + ) + ], + self.transform( + [("modify", (b"a", stat.S_IFREG | 0o644, a), (b"a", stat.S_IFLNK, b))] + ), + ) def test_rename_no_changes(self): - a = Blob.from_string(b'a') - self.assertEqual([ - TreeChange( - b'git:old', 
('old', 'a'), False, (True, True), - (b'TREE_ROOT', b'TREE_ROOT'), ('old', 'a'), - ('file', 'file'), (False, False), False) - ], self.transform([ - ('rename', - (b'old', stat.S_IFREG | 0o644, a), (b'a', stat.S_IFREG | 0o644, a))])) + a = Blob.from_string(b"a") + self.assertEqual( + [ + TreeChange( + b"git:old", + ("old", "a"), + False, + (True, True), + (b"TREE_ROOT", b"TREE_ROOT"), + ("old", "a"), + ("file", "file"), + (False, False), + False, + ) + ], + self.transform( + [ + ( + "rename", + (b"old", stat.S_IFREG | 0o644, a), + (b"a", stat.S_IFREG | 0o644, a), + ) + ] + ), + ) def test_rename_and_modify(self): - a = Blob.from_string(b'a') - b = Blob.from_string(b'b') - self.assertEqual([ - TreeChange( - b'git:a', ('a', 'b'), True, (True, True), - (b'TREE_ROOT', b'TREE_ROOT'), ('a', 'b'), - ('file', 'file'), (False, False), False) - ], self.transform([ - ('rename', - (b'a', stat.S_IFREG | 0o644, a), (b'b', stat.S_IFREG | 0o644, b))])) + a = Blob.from_string(b"a") + b = Blob.from_string(b"b") + self.assertEqual( + [ + TreeChange( + b"git:a", + ("a", "b"), + True, + (True, True), + (b"TREE_ROOT", b"TREE_ROOT"), + ("a", "b"), + ("file", "file"), + (False, False), + False, + ) + ], + self.transform( + [ + ( + "rename", + (b"a", stat.S_IFREG | 0o644, a), + (b"b", stat.S_IFREG | 0o644, b), + ) + ] + ), + ) def test_copy_no_changes(self): - a = Blob.from_string(b'a') - self.assertEqual([ - TreeChange( - b'git:a', ('old', 'a'), False, (True, True), - (b'TREE_ROOT', b'TREE_ROOT'), ('old', 'a'), - ('file', 'file'), (False, False), True) - ], self.transform([ - ('copy', - (b'old', stat.S_IFREG | 0o644, a), (b'a', stat.S_IFREG | 0o644, a))])) + a = Blob.from_string(b"a") + self.assertEqual( + [ + TreeChange( + b"git:a", + ("old", "a"), + False, + (True, True), + (b"TREE_ROOT", b"TREE_ROOT"), + ("old", "a"), + ("file", "file"), + (False, False), + True, + ) + ], + self.transform( + [ + ( + "copy", + (b"old", stat.S_IFREG | 0o644, a), + (b"a", stat.S_IFREG | 0o644, a), + ) + ] + ), + ) def test_copy_and_modify(self): - a = Blob.from_string(b'a') - b = Blob.from_string(b'b') - self.assertEqual([ - TreeChange( - b'git:b', ('a', 'b'), True, (True, True), - (b'TREE_ROOT', b'TREE_ROOT'), ('a', 'b'), - ('file', 'file'), (False, False), True) - ], self.transform([ - ('copy', - (b'a', stat.S_IFREG | 0o644, a), (b'b', stat.S_IFREG | 0o644, b))])) + a = Blob.from_string(b"a") + b = Blob.from_string(b"b") + self.assertEqual( + [ + TreeChange( + b"git:b", + ("a", "b"), + True, + (True, True), + (b"TREE_ROOT", b"TREE_ROOT"), + ("a", "b"), + ("file", "file"), + (False, False), + True, + ) + ], + self.transform( + [ + ( + "copy", + (b"a", stat.S_IFREG | 0o644, a), + (b"b", stat.S_IFREG | 0o644, b), + ) + ] + ), + ) def test_add(self): - b = Blob.from_string(b'b') - self.assertEqual([ - TreeChange( - b'git:a', (None, 'a'), True, (False, True), (None, b'TREE_ROOT'), - (None, 'a'), (None, 'file'), (None, False), False) - ], self.transform([ - ('add', - (None, None, None), (b'a', stat.S_IFREG | 0o644, b))])) + b = Blob.from_string(b"b") + self.assertEqual( + [ + TreeChange( + b"git:a", + (None, "a"), + True, + (False, True), + (None, b"TREE_ROOT"), + (None, "a"), + (None, "file"), + (None, False), + False, + ) + ], + self.transform( + [("add", (None, None, None), (b"a", stat.S_IFREG | 0o644, b))] + ), + ) def test_delete(self): - b = Blob.from_string(b'b') - self.assertEqual([ - TreeChange( - b'git:a', ('a', None), True, (True, False), (b'TREE_ROOT', None), - ('a', None), ('file', None), (False, None), False) - ], 
self.transform([ - ('remove', - (b'a', stat.S_IFREG | 0o644, b), (None, None, None))])) + b = Blob.from_string(b"b") + self.assertEqual( + [ + TreeChange( + b"git:a", + ("a", None), + True, + (True, False), + (b"TREE_ROOT", None), + ("a", None), + ("file", None), + (False, None), + False, + ) + ], + self.transform( + [("remove", (b"a", stat.S_IFREG | 0o644, b), (None, None, None))] + ), + ) def test_unchanged(self): - b = Blob.from_string(b'b') - self.assertEqual([ - TreeChange( - b'git:a', ('a', 'a'), False, (True, True), - (b'TREE_ROOT', b'TREE_ROOT'), ('a', 'a'), ('file', 'file'), - (False, False), False) - ], self.transform([ - ('unchanged', - (b'a', stat.S_IFREG | 0o644, b), (b'a', stat.S_IFREG | 0o644, b))], - include_unchanged=True)) - self.assertEqual([], self.transform([ - ('unchanged', - (b'a', stat.S_IFREG | 0o644, b), (b'a', stat.S_IFREG | 0o644, b))], - include_unchanged=False)) + b = Blob.from_string(b"b") + self.assertEqual( + [ + TreeChange( + b"git:a", + ("a", "a"), + False, + (True, True), + (b"TREE_ROOT", b"TREE_ROOT"), + ("a", "a"), + ("file", "file"), + (False, False), + False, + ) + ], + self.transform( + [ + ( + "unchanged", + (b"a", stat.S_IFREG | 0o644, b), + (b"a", stat.S_IFREG | 0o644, b), + ) + ], + include_unchanged=True, + ), + ) + self.assertEqual( + [], + self.transform( + [ + ( + "unchanged", + (b"a", stat.S_IFREG | 0o644, b), + (b"a", stat.S_IFREG | 0o644, b), + ) + ], + include_unchanged=False, + ), + ) def test_unversioned(self): - b = Blob.from_string(b'b') - self.assertEqual([ - TreeChange( - None, (None, 'a'), True, (False, False), - (None, b'TREE_ROOT'), (None, 'a'), (None, 'file'), - (None, False), False) - ], self.transform([ - ('add', - (None, None, None), (b'a', stat.S_IFREG | 0o644, b))], - target_extras={b'a'})) - self.assertEqual([ - TreeChange( - None, ('a', 'a'), False, (False, False), - (b'TREE_ROOT', b'TREE_ROOT'), ('a', 'a'), ('file', 'file'), - (False, False), False) - ], self.transform([ - ('add', - (b'a', stat.S_IFREG | 0o644, b), (b'a', stat.S_IFREG | 0o644, b))], - source_extras={b'a'}, - target_extras={b'a'})) + b = Blob.from_string(b"b") + self.assertEqual( + [ + TreeChange( + None, + (None, "a"), + True, + (False, False), + (None, b"TREE_ROOT"), + (None, "a"), + (None, "file"), + (None, False), + False, + ) + ], + self.transform( + [("add", (None, None, None), (b"a", stat.S_IFREG | 0o644, b))], + target_extras={b"a"}, + ), + ) + self.assertEqual( + [ + TreeChange( + None, + ("a", "a"), + False, + (False, False), + (b"TREE_ROOT", b"TREE_ROOT"), + ("a", "a"), + ("file", "file"), + (False, False), + False, + ) + ], + self.transform( + [ + ( + "add", + (b"a", stat.S_IFREG | 0o644, b), + (b"a", stat.S_IFREG | 0o644, b), + ) + ], + source_extras={b"a"}, + target_extras={b"a"}, + ), + ) class DeltaFromGitChangesTests(TestCase): - def setUp(self): super().setUp() self.maxDiff = None self.mapping = default_mapping def transform( - self, changes, specific_files=None, - require_versioned=False, include_root=False, - source_extras=None, target_extras=None): + self, + changes, + specific_files=None, + require_versioned=False, + include_root=False, + source_extras=None, + target_extras=None, + ): return tree_delta_from_git_changes( - changes, (self.mapping, self.mapping), + changes, + (self.mapping, self.mapping), specific_files=specific_files, - require_versioned=require_versioned, include_root=include_root, - source_extras=source_extras, target_extras=target_extras) + require_versioned=require_versioned, + include_root=include_root, + 
source_extras=source_extras, + target_extras=target_extras, + ) def test_empty(self): self.assertEqual(TreeDelta(), self.transform([])) def test_modified(self): - a = Blob.from_string(b'a') - b = Blob.from_string(b'b') - delta = self.transform([ - ('modify', - (b'a', stat.S_IFREG | 0o644, a), (b'a', stat.S_IFREG | 0o644, b))]) + a = Blob.from_string(b"a") + b = Blob.from_string(b"b") + delta = self.transform( + [ + ( + "modify", + (b"a", stat.S_IFREG | 0o644, a), + (b"a", stat.S_IFREG | 0o644, b), + ) + ] + ) expected_delta = TreeDelta() - expected_delta.modified.append(TreeChange( - b'git:a', ('a', 'a'), True, (True, True), - (b'TREE_ROOT', b'TREE_ROOT'), ('a', 'a'), ('file', 'file'), - (False, False), False)) + expected_delta.modified.append( + TreeChange( + b"git:a", + ("a", "a"), + True, + (True, True), + (b"TREE_ROOT", b"TREE_ROOT"), + ("a", "a"), + ("file", "file"), + (False, False), + False, + ) + ) self.assertEqual(expected_delta, delta) def test_rename_no_changes(self): - a = Blob.from_string(b'a') - delta = self.transform([ - ('rename', - (b'old', stat.S_IFREG | 0o644, a), (b'a', stat.S_IFREG | 0o644, a))]) + a = Blob.from_string(b"a") + delta = self.transform( + [ + ( + "rename", + (b"old", stat.S_IFREG | 0o644, a), + (b"a", stat.S_IFREG | 0o644, a), + ) + ] + ) expected_delta = TreeDelta() expected_delta.renamed.append( TreeChange( - b'git:old', ('old', 'a'), False, (True, True), - (b'TREE_ROOT', b'TREE_ROOT'), ('old', 'a'), - ('file', 'file'), (False, False), False)) + b"git:old", + ("old", "a"), + False, + (True, True), + (b"TREE_ROOT", b"TREE_ROOT"), + ("old", "a"), + ("file", "file"), + (False, False), + False, + ) + ) self.assertEqual(expected_delta, delta) def test_rename_and_modify(self): - a = Blob.from_string(b'a') - b = Blob.from_string(b'b') - delta = self.transform([ - ('rename', - (b'a', stat.S_IFREG | 0o644, a), (b'b', stat.S_IFREG | 0o644, b))]) + a = Blob.from_string(b"a") + b = Blob.from_string(b"b") + delta = self.transform( + [ + ( + "rename", + (b"a", stat.S_IFREG | 0o644, a), + (b"b", stat.S_IFREG | 0o644, b), + ) + ] + ) expected_delta = TreeDelta() expected_delta.renamed.append( TreeChange( - b'git:a', ('a', 'b'), True, (True, True), - (b'TREE_ROOT', b'TREE_ROOT'), ('a', 'b'), - ('file', 'file'), (False, False), False)) + b"git:a", + ("a", "b"), + True, + (True, True), + (b"TREE_ROOT", b"TREE_ROOT"), + ("a", "b"), + ("file", "file"), + (False, False), + False, + ) + ) self.assertEqual(delta, expected_delta) def test_copy_no_changes(self): - a = Blob.from_string(b'a') - delta = self.transform([ - ('copy', - (b'old', stat.S_IFREG | 0o644, a), (b'a', stat.S_IFREG | 0o644, a))]) + a = Blob.from_string(b"a") + delta = self.transform( + [ + ( + "copy", + (b"old", stat.S_IFREG | 0o644, a), + (b"a", stat.S_IFREG | 0o644, a), + ) + ] + ) expected_delta = TreeDelta() - expected_delta.copied.append(TreeChange( - b'git:a', ('old', 'a'), False, (True, True), - (b'TREE_ROOT', b'TREE_ROOT'), ('old', 'a'), - ('file', 'file'), (False, False), True)) + expected_delta.copied.append( + TreeChange( + b"git:a", + ("old", "a"), + False, + (True, True), + (b"TREE_ROOT", b"TREE_ROOT"), + ("old", "a"), + ("file", "file"), + (False, False), + True, + ) + ) self.assertEqual(expected_delta, delta) def test_copy_and_modify(self): - a = Blob.from_string(b'a') - b = Blob.from_string(b'b') - delta = self.transform([ - ('copy', - (b'a', stat.S_IFREG | 0o644, a), - (b'b', stat.S_IFREG | 0o644, b))]) + a = Blob.from_string(b"a") + b = Blob.from_string(b"b") + delta = self.transform( + 
[("copy", (b"a", stat.S_IFREG | 0o644, a), (b"b", stat.S_IFREG | 0o644, b))] + ) expected_delta = TreeDelta() - expected_delta.copied.append(TreeChange( - b'git:b', ('a', 'b'), True, (True, True), - (b'TREE_ROOT', b'TREE_ROOT'), ('a', 'b'), - ('file', 'file'), (False, False), True)) + expected_delta.copied.append( + TreeChange( + b"git:b", + ("a", "b"), + True, + (True, True), + (b"TREE_ROOT", b"TREE_ROOT"), + ("a", "b"), + ("file", "file"), + (False, False), + True, + ) + ) self.assertEqual(expected_delta, delta) def test_add(self): - b = Blob.from_string(b'b') - delta = self.transform([ - ('add', - (None, None, None), (b'a', stat.S_IFREG | 0o644, b))]) + b = Blob.from_string(b"b") + delta = self.transform( + [("add", (None, None, None), (b"a", stat.S_IFREG | 0o644, b))] + ) expected_delta = TreeDelta() - expected_delta.added.append(TreeChange( - b'git:a', (None, 'a'), True, (False, True), (None, b'TREE_ROOT'), - (None, 'a'), (None, 'file'), (None, False), False)) + expected_delta.added.append( + TreeChange( + b"git:a", + (None, "a"), + True, + (False, True), + (None, b"TREE_ROOT"), + (None, "a"), + (None, "file"), + (None, False), + False, + ) + ) self.assertEqual(delta, expected_delta) def test_delete(self): - b = Blob.from_string(b'b') - delta = self.transform([ - ('remove', - (b'a', stat.S_IFREG | 0o644, b), (None, None, None))]) + b = Blob.from_string(b"b") + delta = self.transform( + [("remove", (b"a", stat.S_IFREG | 0o644, b), (None, None, None))] + ) expected_delta = TreeDelta() - expected_delta.removed.append(TreeChange( - b'git:a', ('a', None), True, (True, False), (b'TREE_ROOT', None), - ('a', None), ('file', None), (False, None), False)) + expected_delta.removed.append( + TreeChange( + b"git:a", + ("a", None), + True, + (True, False), + (b"TREE_ROOT", None), + ("a", None), + ("file", None), + (False, None), + False, + ) + ) self.assertEqual(delta, expected_delta) def test_unchanged(self): - b = Blob.from_string(b'b') - self.transform([ - ('unchanged', - (b'a', stat.S_IFREG | 0o644, b), (b'a', stat.S_IFREG | 0o644, b))]) + b = Blob.from_string(b"b") + self.transform( + [ + ( + "unchanged", + (b"a", stat.S_IFREG | 0o644, b), + (b"a", stat.S_IFREG | 0o644, b), + ) + ] + ) expected_delta = TreeDelta() - expected_delta.unchanged.append(TreeChange( - b'git:a', ('a', 'a'), False, (True, True), - (b'TREE_ROOT', b'TREE_ROOT'), ('a', 'a'), ('file', 'file'), - (False, False), False)) + expected_delta.unchanged.append( + TreeChange( + b"git:a", + ("a", "a"), + False, + (True, True), + (b"TREE_ROOT", b"TREE_ROOT"), + ("a", "a"), + ("file", "file"), + (False, False), + False, + ) + ) def test_unversioned(self): - b = Blob.from_string(b'b') - delta = self.transform([ - ('add', - (None, None, None), (b'a', stat.S_IFREG | 0o644, b))], - target_extras={b'a'}) + b = Blob.from_string(b"b") + delta = self.transform( + [("add", (None, None, None), (b"a", stat.S_IFREG | 0o644, b))], + target_extras={b"a"}, + ) expected_delta = TreeDelta() expected_delta.unversioned.append( TreeChange( - None, (None, 'a'), True, (False, False), - (None, b'TREE_ROOT'), (None, 'a'), (None, 'file'), - (None, False), False)) + None, + (None, "a"), + True, + (False, False), + (None, b"TREE_ROOT"), + (None, "a"), + (None, "file"), + (None, False), + False, + ) + ) self.assertEqual(delta, expected_delta) - delta = self.transform([ - ('add', - (b'a', stat.S_IFREG | 0o644, b), (b'a', stat.S_IFREG | 0o644, b))], - source_extras={b'a'}, - target_extras={b'a'}) + delta = self.transform( + [("add", (b"a", stat.S_IFREG | 0o644, 
b), (b"a", stat.S_IFREG | 0o644, b))], + source_extras={b"a"}, + target_extras={b"a"}, + ) expected_delta = TreeDelta() - expected_delta.unversioned.append(TreeChange( - None, ('a', 'a'), False, (False, False), - (b'TREE_ROOT', b'TREE_ROOT'), ('a', 'a'), ('file', 'file'), - (False, False), False)) + expected_delta.unversioned.append( + TreeChange( + None, + ("a", "a"), + False, + (False, False), + (b"TREE_ROOT", b"TREE_ROOT"), + ("a", "a"), + ("file", "file"), + (False, False), + False, + ) + ) self.assertEqual(delta, expected_delta) def test_kind_change(self): - a = Blob.from_string(b'a') - b = Blob.from_string(b'target') - delta = self.transform([ - ('modify', - (b'a', stat.S_IFREG | 0o644, a), (b'a', stat.S_IFLNK, b))]) + a = Blob.from_string(b"a") + b = Blob.from_string(b"target") + delta = self.transform( + [("modify", (b"a", stat.S_IFREG | 0o644, a), (b"a", stat.S_IFLNK, b))] + ) expected_delta = TreeDelta() - expected_delta.kind_changed.append(TreeChange( - b'git:a', ('a', 'a'), True, (True, True), - (b'TREE_ROOT', b'TREE_ROOT'), ('a', 'a'), ('file', 'symlink'), - (False, False), False)) + expected_delta.kind_changed.append( + TreeChange( + b"git:a", + ("a", "a"), + True, + (True, True), + (b"TREE_ROOT", b"TREE_ROOT"), + ("a", "a"), + ("file", "symlink"), + (False, False), + False, + ) + ) self.assertEqual(expected_delta, delta) class FindRelatedPathsAcrossTrees(TestCaseWithTransport): - def test_none(self): - self.make_branch_and_tree('t1', format='git') - wt = WorkingTree.open('t1') + self.make_branch_and_tree("t1", format="git") + wt = WorkingTree.open("t1") self.assertIs(None, wt.find_related_paths_across_trees(None)) def test_empty(self): - self.make_branch_and_tree('t1', format='git') - wt = WorkingTree.open('t1') + self.make_branch_and_tree("t1", format="git") + wt = WorkingTree.open("t1") self.assertEqual([], list(wt.find_related_paths_across_trees([]))) def test_directory(self): - self.make_branch_and_tree('t1', format='git') - wt = WorkingTree.open('t1') - self.build_tree(['t1/dir/', 't1/dir/file']) - wt.add(['dir', 'dir/file']) - self.assertEqual(['dir/file'], list(wt.find_related_paths_across_trees(['dir/file']))) - self.assertEqual(['dir'], list(wt.find_related_paths_across_trees(['dir']))) + self.make_branch_and_tree("t1", format="git") + wt = WorkingTree.open("t1") + self.build_tree(["t1/dir/", "t1/dir/file"]) + wt.add(["dir", "dir/file"]) + self.assertEqual( + ["dir/file"], list(wt.find_related_paths_across_trees(["dir/file"])) + ) + self.assertEqual(["dir"], list(wt.find_related_paths_across_trees(["dir"]))) def test_empty_directory(self): - self.make_branch_and_tree('t1', format='git') - wt = WorkingTree.open('t1') - self.build_tree(['t1/dir/']) - wt.add(['dir']) - self.assertEqual(['dir'], list(wt.find_related_paths_across_trees(['dir']))) - self.assertRaises(PathsNotVersionedError, wt.find_related_paths_across_trees, ['dir/file']) + self.make_branch_and_tree("t1", format="git") + wt = WorkingTree.open("t1") + self.build_tree(["t1/dir/"]) + wt.add(["dir"]) + self.assertEqual(["dir"], list(wt.find_related_paths_across_trees(["dir"]))) + self.assertRaises( + PathsNotVersionedError, wt.find_related_paths_across_trees, ["dir/file"] + ) def test_missing(self): - self.make_branch_and_tree('t1', format='git') - wt = WorkingTree.open('t1') - self.assertRaises(PathsNotVersionedError, wt.find_related_paths_across_trees, ['file']) + self.make_branch_and_tree("t1", format="git") + wt = WorkingTree.open("t1") + self.assertRaises( + PathsNotVersionedError, 
wt.find_related_paths_across_trees, ["file"] + ) def test_not_versioned(self): - self.make_branch_and_tree('t1', format='git') - self.make_branch_and_tree('t2', format='git') - wt1 = WorkingTree.open('t1') - wt2 = WorkingTree.open('t2') - self.build_tree(['t1/file']) - self.build_tree(['t2/file']) - self.assertRaises(PathsNotVersionedError, wt1.find_related_paths_across_trees, ['file'], [wt2]) + self.make_branch_and_tree("t1", format="git") + self.make_branch_and_tree("t2", format="git") + wt1 = WorkingTree.open("t1") + wt2 = WorkingTree.open("t2") + self.build_tree(["t1/file"]) + self.build_tree(["t2/file"]) + self.assertRaises( + PathsNotVersionedError, wt1.find_related_paths_across_trees, ["file"], [wt2] + ) def test_single(self): - self.make_branch_and_tree('t1', format='git') - wt = WorkingTree.open('t1') - self.build_tree(['t1/file']) - wt.add('file') - self.assertEqual(['file'], list(wt.find_related_paths_across_trees(['file']))) + self.make_branch_and_tree("t1", format="git") + wt = WorkingTree.open("t1") + self.build_tree(["t1/file"]) + wt.add("file") + self.assertEqual(["file"], list(wt.find_related_paths_across_trees(["file"]))) diff --git a/breezy/git/tests/test_unpeel_map.py b/breezy/git/tests/test_unpeel_map.py index 2af3f5cc8c..7e3fbd2595 100644 --- a/breezy/git/tests/test_unpeel_map.py +++ b/breezy/git/tests/test_unpeel_map.py @@ -23,7 +23,6 @@ class TestUnpeelMap(TestCaseWithTransport): - def test_new(self): m = UnpeelMap() self.assertIs(None, m.peel_tag("ab" * 20)) @@ -31,16 +30,21 @@ def test_new(self): def test_load(self): f = BytesIO( b"unpeel map version 1\n" - b"0123456789012345678901234567890123456789: aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\n") + b"0123456789012345678901234567890123456789: aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\n" + ) m = UnpeelMap() m.load(f) - self.assertEqual(b"0123456789012345678901234567890123456789", - m.peel_tag(b"aa" * 20)) + self.assertEqual( + b"0123456789012345678901234567890123456789", m.peel_tag(b"aa" * 20) + ) def test_update(self): m = UnpeelMap() - m.update({ - b"0123456789012345678901234567890123456789": {b"aa" * 20}, - }) - self.assertEqual(b"0123456789012345678901234567890123456789", - m.peel_tag(b"aa" * 20)) + m.update( + { + b"0123456789012345678901234567890123456789": {b"aa" * 20}, + } + ) + self.assertEqual( + b"0123456789012345678901234567890123456789", m.peel_tag(b"aa" * 20) + ) diff --git a/breezy/git/tests/test_urls.py b/breezy/git/tests/test_urls.py index 679ad4aeae..a29b88146a 100644 --- a/breezy/git/tests/test_urls.py +++ b/breezy/git/tests/test_urls.py @@ -21,49 +21,49 @@ class TestConvertURL(TestCase): - def test_simple(self): + self.assertEqual(git_url_to_bzr_url("foo:bar/path"), "git+ssh://foo/bar/path") self.assertEqual( - git_url_to_bzr_url('foo:bar/path'), - 'git+ssh://foo/bar/path') - self.assertEqual( - git_url_to_bzr_url( - 'user@foo:bar/path'), - ('git+ssh://user@foo/bar/path')) + git_url_to_bzr_url("user@foo:bar/path"), ("git+ssh://user@foo/bar/path") + ) def test_regular(self): self.assertEqual( - git_url_to_bzr_url( - 'git+ssh://user@foo/bar/path'), - ('git+ssh://user@foo/bar/path')) + git_url_to_bzr_url("git+ssh://user@foo/bar/path"), + ("git+ssh://user@foo/bar/path"), + ) def test_just_ssh(self): self.assertEqual( - git_url_to_bzr_url( - 'ssh://user@foo/bar/path'), - ('git+ssh://user@foo/bar/path')) + git_url_to_bzr_url("ssh://user@foo/bar/path"), + ("git+ssh://user@foo/bar/path"), + ) def test_path(self): - self.assertEqual(git_url_to_bzr_url('/bar/path'), ('/bar/path')) + 
self.assertEqual(git_url_to_bzr_url("/bar/path"), ("/bar/path")) def test_with_ref(self): self.assertEqual( - git_url_to_bzr_url('foo:bar/path', ref=b'HEAD'), - 'git+ssh://foo/bar/path') + git_url_to_bzr_url("foo:bar/path", ref=b"HEAD"), "git+ssh://foo/bar/path" + ) self.assertEqual( - git_url_to_bzr_url('foo:bar/path', ref=b'refs/heads/blah'), - 'git+ssh://foo/bar/path,branch=blah') + git_url_to_bzr_url("foo:bar/path", ref=b"refs/heads/blah"), + "git+ssh://foo/bar/path,branch=blah", + ) self.assertEqual( - git_url_to_bzr_url('foo:bar/path', ref=b'refs/tags/blah'), - 'git+ssh://foo/bar/path,ref=refs%2Ftags%2Fblah') + git_url_to_bzr_url("foo:bar/path", ref=b"refs/tags/blah"), + "git+ssh://foo/bar/path,ref=refs%2Ftags%2Fblah", + ) def test_with_branch(self): self.assertEqual( - git_url_to_bzr_url('foo:bar/path', branch=''), - 'git+ssh://foo/bar/path') + git_url_to_bzr_url("foo:bar/path", branch=""), "git+ssh://foo/bar/path" + ) self.assertEqual( - git_url_to_bzr_url('foo:bar/path', branch='foo/blah'), - 'git+ssh://foo/bar/path,branch=foo%2Fblah') + git_url_to_bzr_url("foo:bar/path", branch="foo/blah"), + "git+ssh://foo/bar/path,branch=foo%2Fblah", + ) self.assertEqual( - git_url_to_bzr_url('foo:bar/path', branch='blah'), - 'git+ssh://foo/bar/path,branch=blah') + git_url_to_bzr_url("foo:bar/path", branch="blah"), + "git+ssh://foo/bar/path,branch=blah", + ) diff --git a/breezy/git/tests/test_workingtree.py b/breezy/git/tests/test_workingtree.py index 2e9259d207..87c6b32c61 100644 --- a/breezy/git/tests/test_workingtree.py +++ b/breezy/git/tests/test_workingtree.py @@ -35,132 +35,151 @@ from ..tree import tree_delta_from_git_changes -def changes_between_git_tree_and_working_copy(source_store, from_tree_sha, target, - want_unchanged=False, - want_unversioned=False, - rename_detector=None, - include_trees=True): +def changes_between_git_tree_and_working_copy( + source_store, + from_tree_sha, + target, + want_unchanged=False, + want_unversioned=False, + rename_detector=None, + include_trees=True, +): """Determine the changes between a git tree and a working tree with index.""" to_tree_sha, extras = target.git_snapshot(want_unversioned=want_unversioned) store = OverlayObjectStore([source_store, target.store]) return tree_changes( - store, from_tree_sha, to_tree_sha, include_trees=include_trees, + store, + from_tree_sha, + to_tree_sha, + include_trees=include_trees, rename_detector=rename_detector, - want_unchanged=want_unchanged, change_type_same=True), extras + want_unchanged=want_unchanged, + change_type_same=True, + ), extras class GitWorkingTreeTests(TestCaseWithTransport): - def setUp(self): super().setUp() - self.tree = self.make_branch_and_tree('.', format="git") + self.tree = self.make_branch_and_tree(".", format="git") def test_conflict_list(self): - self.assertIsInstance( - self.tree.conflicts(), - _mod_conflicts.ConflictList) + self.assertIsInstance(self.tree.conflicts(), _mod_conflicts.ConflictList) def test_add_conflict(self): - self.build_tree(['conflicted']) - self.tree.add(['conflicted']) + self.build_tree(["conflicted"]) + self.tree.add(["conflicted"]) with self.tree.lock_tree_write(): - self.tree.index[b'conflicted'] = ConflictedIndexEntry(this=self.tree.index[b'conflicted']) + self.tree.index[b"conflicted"] = ConflictedIndexEntry( + this=self.tree.index[b"conflicted"] + ) self.tree._index_dirty = True conflicts = self.tree.conflicts() self.assertEqual(1, len(conflicts)) def test_revert_empty(self): - self.build_tree(['a']) - self.tree.add(['a']) - 
self.assertTrue(self.tree.is_versioned('a')) - self.tree.revert(['a']) - self.assertFalse(self.tree.is_versioned('a')) + self.build_tree(["a"]) + self.tree.add(["a"]) + self.assertTrue(self.tree.is_versioned("a")) + self.tree.revert(["a"]) + self.assertFalse(self.tree.is_versioned("a")) def test_is_ignored_directory(self): - self.assertFalse(self.tree.is_ignored('a')) - self.build_tree(['a/']) - self.assertFalse(self.tree.is_ignored('a')) - self.build_tree_contents([('.gitignore', 'a\n')]) + self.assertFalse(self.tree.is_ignored("a")) + self.build_tree(["a/"]) + self.assertFalse(self.tree.is_ignored("a")) + self.build_tree_contents([(".gitignore", "a\n")]) self.tree._ignoremanager = None - self.assertTrue(self.tree.is_ignored('a')) - self.build_tree_contents([('.gitignore', 'a/\n')]) + self.assertTrue(self.tree.is_ignored("a")) + self.build_tree_contents([(".gitignore", "a/\n")]) self.tree._ignoremanager = None - self.assertTrue(self.tree.is_ignored('a')) + self.assertTrue(self.tree.is_ignored("a")) def test_add_submodule_dir(self): - subtree = self.make_branch_and_tree('asub', format='git') - subtree.commit('Empty commit') - self.tree.add(['asub']) + subtree = self.make_branch_and_tree("asub", format="git") + subtree.commit("Empty commit") + self.tree.add(["asub"]) with self.tree.lock_read(): - entry = self.tree.index[b'asub'] + entry = self.tree.index[b"asub"] self.assertEqual(entry.mode, S_IFGITLINK) self.assertEqual([], list(subtree.unknowns())) def test_add_submodule_file(self): - os.mkdir('.git/modules') - self.make_branch('.git/modules/asub', format='git-bare') - os.mkdir('asub') - with open('asub/.git', 'w') as f: - f.write('gitdir: ../.git/modules/asub\n') - subtree = _mod_workingtree.WorkingTree.open('asub') - subtree.commit('Empty commit') - self.tree.add(['asub']) + os.mkdir(".git/modules") + self.make_branch(".git/modules/asub", format="git-bare") + os.mkdir("asub") + with open("asub/.git", "w") as f: + f.write("gitdir: ../.git/modules/asub\n") + subtree = _mod_workingtree.WorkingTree.open("asub") + subtree.commit("Empty commit") + self.tree.add(["asub"]) with self.tree.lock_read(): - entry = self.tree.index[b'asub'] + entry = self.tree.index[b"asub"] self.assertEqual(entry.mode, S_IFGITLINK) self.assertEqual([], list(subtree.unknowns())) class GitWorkingTreeFileTests(TestCaseWithTransport): - def setUp(self): super().setUp() - self.tree = self.make_branch_and_tree('actual', format="git") + self.tree = self.make_branch_and_tree("actual", format="git") self.build_tree_contents( - [('linked/',), ('linked/.git', 'gitdir: ../actual/.git')]) - self.wt = _mod_workingtree.WorkingTree.open('linked') + [("linked/",), ("linked/.git", "gitdir: ../actual/.git")] + ) + self.wt = _mod_workingtree.WorkingTree.open("linked") def test_add(self): - self.build_tree(['linked/somefile']) + self.build_tree(["linked/somefile"]) self.wt.add(["somefile"]) self.wt.commit("Add somefile") class TreeDeltaFromGitChangesTests(TestCase): - def test_empty(self): delta = TreeDelta() changes = [] self.assertEqual( delta, - tree_delta_from_git_changes(changes, (default_mapping, default_mapping))) + tree_delta_from_git_changes(changes, (default_mapping, default_mapping)), + ) def test_missing(self): delta = TreeDelta() delta.removed.append( TreeChange( - b'git:a', ('a', 'a'), False, (True, True), - (b'TREE_ROOT', b'TREE_ROOT'), ('a', 'a'), ('file', None), - (True, False))) + b"git:a", + ("a", "a"), + False, + (True, True), + (b"TREE_ROOT", b"TREE_ROOT"), + ("a", "a"), + ("file", None), + (True, False), + ) + ) 
changes = [ - ('remove', - (b'a', stat.S_IFREG | 0o755, b'a' * 40), - (b'a', 0, b'a' * 40))] + ("remove", (b"a", stat.S_IFREG | 0o755, b"a" * 40), (b"a", 0, b"a" * 40)) + ] self.assertEqual( delta, - tree_delta_from_git_changes(changes, (default_mapping, default_mapping))) + tree_delta_from_git_changes(changes, (default_mapping, default_mapping)), + ) class ChangesBetweenGitTreeAndWorkingCopyTests(TestCaseWithTransport): - def setUp(self): super().setUp() - self.wt = self.make_branch_and_tree('.', format='git') + self.wt = self.make_branch_and_tree(".", format="git") self.store = self.wt.branch.repository._git.object_store - def expectDelta(self, expected_changes, - expected_extras=None, want_unversioned=False, - tree_id=None, rename_detector=None): + def expectDelta( + self, + expected_changes, + expected_extras=None, + want_unversioned=False, + tree_id=None, + rename_detector=None, + ): if tree_id is None: try: tree_id = self.store[self.wt.branch.repository._git.head()].tree @@ -168,32 +187,38 @@ def expectDelta(self, expected_changes, tree_id = None with self.wt.lock_read(): changes, extras = changes_between_git_tree_and_working_copy( - self.store, tree_id, self.wt, want_unversioned=want_unversioned, - rename_detector=rename_detector) + self.store, + tree_id, + self.wt, + want_unversioned=want_unversioned, + rename_detector=rename_detector, + ) self.assertEqual(expected_changes, list(changes)) if expected_extras is None: expected_extras = set() self.assertEqual(set(expected_extras), set(extras)) def test_empty(self): - self.expectDelta( - [('add', (None, None, None), (b'', stat.S_IFDIR, Tree().id))]) + self.expectDelta([("add", (None, None, None), (b"", stat.S_IFDIR, Tree().id))]) def test_added_file(self): - self.build_tree(['a']) - self.wt.add(['a']) - a = Blob.from_string(b'contents of a\n') + self.build_tree(["a"]) + self.wt.add(["a"]) + a = Blob.from_string(b"contents of a\n") t = Tree() t.add(b"a", stat.S_IFREG | 0o644, a.id) self.expectDelta( - [('add', (None, None, None), (b'', stat.S_IFDIR, t.id)), - ('add', (None, None, None), (b'a', stat.S_IFREG | 0o644, a.id))]) + [ + ("add", (None, None, None), (b"", stat.S_IFDIR, t.id)), + ("add", (None, None, None), (b"a", stat.S_IFREG | 0o644, a.id)), + ] + ) def test_renamed_file(self): - self.build_tree(['a']) - self.wt.add(['a']) - self.wt.rename_one('a', 'b') - a = Blob.from_string(b'contents of a\n') + self.build_tree(["a"]) + self.wt.add(["a"]) + self.wt.rename_one("a", "b") + a = Blob.from_string(b"contents of a\n") self.store.add_object(a) oldt = Tree() oldt.add(b"a", stat.S_IFREG | 0o644, a.id) @@ -202,22 +227,36 @@ def test_renamed_file(self): newt.add(b"b", stat.S_IFREG | 0o644, a.id) self.store.add_object(newt) self.expectDelta( - [('modify', (b'', stat.S_IFDIR, oldt.id), (b'', stat.S_IFDIR, newt.id)), - ('delete', (b'a', stat.S_IFREG | 0o644, a.id), (None, None, None)), - ('add', (None, None, None), (b'b', stat.S_IFREG | 0o644, a.id)), - ], - tree_id=oldt.id) + [ + ("modify", (b"", stat.S_IFDIR, oldt.id), (b"", stat.S_IFDIR, newt.id)), + ("delete", (b"a", stat.S_IFREG | 0o644, a.id), (None, None, None)), + ("add", (None, None, None), (b"b", stat.S_IFREG | 0o644, a.id)), + ], + tree_id=oldt.id, + ) if dulwich_version >= (0, 19, 15): self.expectDelta( - [('modify', (b'', stat.S_IFDIR, oldt.id), (b'', stat.S_IFDIR, newt.id)), - ('rename', (b'a', stat.S_IFREG | 0o644, a.id), (b'b', stat.S_IFREG | 0o644, a.id))], - tree_id=oldt.id, rename_detector=RenameDetector(self.store)) + [ + ( + "modify", + (b"", stat.S_IFDIR, oldt.id), + 
(b"", stat.S_IFDIR, newt.id), + ), + ( + "rename", + (b"a", stat.S_IFREG | 0o644, a.id), + (b"b", stat.S_IFREG | 0o644, a.id), + ), + ], + tree_id=oldt.id, + rename_detector=RenameDetector(self.store), + ) def test_copied_file(self): - self.build_tree(['a']) - self.wt.add(['a']) - self.wt.copy_one('a', 'b') - a = Blob.from_string(b'contents of a\n') + self.build_tree(["a"]) + self.wt.add(["a"]) + self.wt.copy_one("a", "b") + a = Blob.from_string(b"contents of a\n") self.store.add_object(a) oldt = Tree() oldt.add(b"a", stat.S_IFREG | 0o644, a.id) @@ -227,116 +266,147 @@ def test_copied_file(self): newt.add(b"b", stat.S_IFREG | 0o644, a.id) self.store.add_object(newt) self.expectDelta( - [('modify', (b'', stat.S_IFDIR, oldt.id), (b'', stat.S_IFDIR, newt.id)), - ('add', (None, None, None), (b'b', stat.S_IFREG | 0o644, a.id)), - ], - tree_id=oldt.id) + [ + ("modify", (b"", stat.S_IFDIR, oldt.id), (b"", stat.S_IFDIR, newt.id)), + ("add", (None, None, None), (b"b", stat.S_IFREG | 0o644, a.id)), + ], + tree_id=oldt.id, + ) if dulwich_version >= (0, 19, 15): self.expectDelta( - [('modify', (b'', stat.S_IFDIR, oldt.id), (b'', stat.S_IFDIR, newt.id)), - ('copy', (b'a', stat.S_IFREG | 0o644, a.id), (b'b', stat.S_IFREG | 0o644, a.id))], - tree_id=oldt.id, rename_detector=RenameDetector(self.store, find_copies_harder=True)) + [ + ( + "modify", + (b"", stat.S_IFDIR, oldt.id), + (b"", stat.S_IFDIR, newt.id), + ), + ( + "copy", + (b"a", stat.S_IFREG | 0o644, a.id), + (b"b", stat.S_IFREG | 0o644, a.id), + ), + ], + tree_id=oldt.id, + rename_detector=RenameDetector(self.store, find_copies_harder=True), + ) self.expectDelta( - [('modify', (b'', stat.S_IFDIR, oldt.id), (b'', stat.S_IFDIR, newt.id)), - ('add', (None, None, None), (b'b', stat.S_IFREG | 0o644, a.id)), - ], - tree_id=oldt.id, rename_detector=RenameDetector(self.store, find_copies_harder=False)) + [ + ( + "modify", + (b"", stat.S_IFDIR, oldt.id), + (b"", stat.S_IFDIR, newt.id), + ), + ("add", (None, None, None), (b"b", stat.S_IFREG | 0o644, a.id)), + ], + tree_id=oldt.id, + rename_detector=RenameDetector(self.store, find_copies_harder=False), + ) def test_added_unknown_file(self): - self.build_tree(['a']) + self.build_tree(["a"]) t = Tree() - self.expectDelta( - [('add', (None, None, None), (b'', stat.S_IFDIR, t.id))]) - a = Blob.from_string(b'contents of a\n') + self.expectDelta([("add", (None, None, None), (b"", stat.S_IFDIR, t.id))]) + a = Blob.from_string(b"contents of a\n") t = Tree() t.add(b"a", stat.S_IFREG | 0o644, a.id) self.expectDelta( - [('add', (None, None, None), (b'', stat.S_IFDIR, t.id)), - ('add', (None, None, None), (b'a', stat.S_IFREG | 0o644, a.id))], - [b'a'], - want_unversioned=True) + [ + ("add", (None, None, None), (b"", stat.S_IFDIR, t.id)), + ("add", (None, None, None), (b"a", stat.S_IFREG | 0o644, a.id)), + ], + [b"a"], + want_unversioned=True, + ) def test_missing_added_file(self): - self.build_tree(['a']) - self.wt.add(['a']) - os.unlink('a') - Blob.from_string(b'contents of a\n') + self.build_tree(["a"]) + self.wt.add(["a"]) + os.unlink("a") + Blob.from_string(b"contents of a\n") t = Tree() t.add(b"a", 0, ZERO_SHA) self.expectDelta( - [('add', (None, None, None), (b'', stat.S_IFDIR, t.id)), - ('add', (None, None, None), (b'a', 0, ZERO_SHA))], - []) + [ + ("add", (None, None, None), (b"", stat.S_IFDIR, t.id)), + ("add", (None, None, None), (b"a", 0, ZERO_SHA)), + ], + [], + ) def test_missing_versioned_file(self): - self.build_tree(['a']) - self.wt.add(['a']) - self.wt.commit('') - os.unlink('a') - a = 
Blob.from_string(b'contents of a\n') + self.build_tree(["a"]) + self.wt.add(["a"]) + self.wt.commit("") + os.unlink("a") + a = Blob.from_string(b"contents of a\n") oldt = Tree() oldt.add(b"a", stat.S_IFREG | 0o644, a.id) newt = Tree() newt.add(b"a", 0, ZERO_SHA) self.expectDelta( - [('modify', - (b'', stat.S_IFDIR, oldt.id), - (b'', stat.S_IFDIR, newt.id)), - ('modify', - (b'a', stat.S_IFREG | 0o644, a.id), - (b'a', 0, ZERO_SHA))]) + [ + ("modify", (b"", stat.S_IFDIR, oldt.id), (b"", stat.S_IFDIR, newt.id)), + ("modify", (b"a", stat.S_IFREG | 0o644, a.id), (b"a", 0, ZERO_SHA)), + ] + ) def test_versioned_replace_by_dir(self): - self.build_tree(['a']) - self.wt.add(['a']) - self.wt.commit('') - os.unlink('a') - os.mkdir('a') - olda = Blob.from_string(b'contents of a\n') + self.build_tree(["a"]) + self.wt.add(["a"]) + self.wt.commit("") + os.unlink("a") + os.mkdir("a") + olda = Blob.from_string(b"contents of a\n") oldt = Tree() oldt.add(b"a", stat.S_IFREG | 0o644, olda.id) newt = Tree() newa = Tree() newt.add(b"a", stat.S_IFDIR, newa.id) - self.expectDelta([ - ('modify', - (b'', stat.S_IFDIR, oldt.id), - (b'', stat.S_IFDIR, newt.id)), - ('modify', - (b'a', stat.S_IFREG | 0o644, olda.id), - (b'a', stat.S_IFDIR, newa.id)) - ], want_unversioned=False) - self.expectDelta([ - ('modify', - (b'', stat.S_IFDIR, oldt.id), - (b'', stat.S_IFDIR, newt.id)), - ('modify', - (b'a', stat.S_IFREG | 0o644, olda.id), - (b'a', stat.S_IFDIR, newa.id)), - ], want_unversioned=True) + self.expectDelta( + [ + ("modify", (b"", stat.S_IFDIR, oldt.id), (b"", stat.S_IFDIR, newt.id)), + ( + "modify", + (b"a", stat.S_IFREG | 0o644, olda.id), + (b"a", stat.S_IFDIR, newa.id), + ), + ], + want_unversioned=False, + ) + self.expectDelta( + [ + ("modify", (b"", stat.S_IFDIR, oldt.id), (b"", stat.S_IFDIR, newt.id)), + ( + "modify", + (b"a", stat.S_IFREG | 0o644, olda.id), + (b"a", stat.S_IFDIR, newa.id), + ), + ], + want_unversioned=True, + ) def test_extra(self): - self.build_tree(['a']) - newa = Blob.from_string(b'contents of a\n') + self.build_tree(["a"]) + newa = Blob.from_string(b"contents of a\n") newt = Tree() newt.add(b"a", stat.S_IFREG | 0o644, newa.id) - self.expectDelta([ - ('add', - (None, None, None), - (b'', stat.S_IFDIR, newt.id)), - ('add', - (None, None, None), - (b'a', stat.S_IFREG | 0o644, newa.id)), - ], [b'a'], want_unversioned=True) + self.expectDelta( + [ + ("add", (None, None, None), (b"", stat.S_IFDIR, newt.id)), + ("add", (None, None, None), (b"a", stat.S_IFREG | 0o644, newa.id)), + ], + [b"a"], + want_unversioned=True, + ) def test_submodule(self): - self.subtree = self.make_branch_and_tree('a', format="git") - a = Blob.from_string(b'irrelevant\n') - self.build_tree_contents([('a/.git/HEAD', a.id)]) + self.subtree = self.make_branch_and_tree("a", format="git") + a = Blob.from_string(b"irrelevant\n") + self.build_tree_contents([("a/.git/HEAD", a.id)]) with self.wt.lock_tree_write(): - (index, index_path) = self.wt._lookup_index(b'a') - index[b'a'] = IndexEntry(0, 0, 0, 0, S_IFGITLINK, 0, 0, 0, a.id) + (index, index_path) = self.wt._lookup_index(b"a") + index[b"a"] = IndexEntry(0, 0, 0, 0, S_IFGITLINK, 0, 0, 0, a.id) self.wt._index_dirty = True t = Tree() t.add(b"a", S_IFGITLINK, a.id) @@ -344,12 +414,12 @@ def test_submodule(self): self.expectDelta([], tree_id=t.id) def test_submodule_not_checked_out(self): - a = Blob.from_string(b'irrelevant\n') + a = Blob.from_string(b"irrelevant\n") with self.wt.lock_tree_write(): - (index, index_path) = self.wt._lookup_index(b'a') - index[b'a'] = IndexEntry(0, 0, 0, 
0, S_IFGITLINK, 0, 0, 0, a.id) + (index, index_path) = self.wt._lookup_index(b"a") + index[b"a"] = IndexEntry(0, 0, 0, 0, S_IFGITLINK, 0, 0, 0, a.id) self.wt._index_dirty = True - os.mkdir(self.wt.abspath('a')) + os.mkdir(self.wt.abspath("a")) t = Tree() t.add(b"a", S_IFGITLINK, a.id) self.store.add_object(t) diff --git a/breezy/git/transform.py b/breezy/git/transform.py index abb62b8512..80ac8304d4 100644 --- a/breezy/git/transform.py +++ b/breezy/git/transform.py @@ -73,7 +73,7 @@ def __init__(self, tree, pb=None, case_sensitive=True): # Set of versioned trans ids self._versioned = set() # The trans_id that will be used as the tree root - self.root = self.trans_id_tree_path('') + self.root = self.trans_id_tree_path("") # Whether the target is case sensitive self._case_sensitive_target = case_sensitive self._symlink_target = {} @@ -90,7 +90,7 @@ def finalize(self): """ if self._tree is None: return - for hook in MutableTree.hooks['post_transform']: + for hook in MutableTree.hooks["post_transform"]: hook(self._tree, self) self._tree.unlock() self._tree = None @@ -117,12 +117,11 @@ def fixup_new_roots(self): irrelevant. """ - new_roots = [k for k, v in self._new_parent.items() - if v == ROOT_PARENT] + new_roots = [k for k, v in self._new_parent.items() if v == ROOT_PARENT] if len(new_roots) < 1: return if len(new_roots) != 1: - raise ValueError('A tree cannot have two roots!') + raise ValueError("A tree cannot have two roots!") old_new_root = new_roots[0] # unversion the new root's directory. if old_new_root in self._versioned: @@ -159,7 +158,7 @@ def trans_id_file_id(self, file_id): (this will likely lead to an unversioned parent conflict.). """ if file_id is None: - raise ValueError('None is not a valid file id') + raise ValueError("None is not a valid file id") path = self.mapping.parse_file_id(file_id) return self.trans_id_tree_path(path) @@ -188,8 +187,13 @@ def new_paths(self, filesystem_only=False): needs_rename = self._needs_rename.difference(stale_ids) id_sets = (needs_rename, self._new_executability) else: - id_sets = (self._new_name, self._new_parent, self._new_contents, - self._versioned, self._new_executability) + id_sets = ( + self._new_name, + self._new_parent, + self._new_contents, + self._versioned, + self._new_executability, + ) for id_set in id_sets: new_ids.update(id_set) return sorted(FinalPaths(self).get_paths(new_ids)) @@ -233,17 +237,18 @@ def _add_tree_children(self): removed. This is a necessary first step in detecting conflicts. 
""" parents = list(self.by_parent()) - parents.extend([t for t in self._removed_contents if - self.tree_kind(t) == 'directory']) + parents.extend( + [t for t in self._removed_contents if self.tree_kind(t) == "directory"] + ) for trans_id in self._removed_id: path = self.tree_path(trans_id) if path is not None: try: - if self._tree.stored_kind(path) == 'directory': + if self._tree.stored_kind(path) == "directory": parents.append(trans_id) except _mod_transport.NoSuchFile: pass - elif self.tree_kind(trans_id) == 'directory': + elif self.tree_kind(trans_id) == "directory": parents.append(trans_id) for parent_id in parents: @@ -276,7 +281,9 @@ def _has_named_child(self, name, parent_id, known_children): # Not known by the tree transform yet, check the filesystem return osutils.lexists(self._tree.abspath(child_path)) else: - raise AssertionError(f'child_id is missing: {name}, {parent_id}, {child_id}') + raise AssertionError( + f"child_id is missing: {name}, {parent_id}, {child_id}" + ) def _available_backup_name(self, name, target_id): """Find an available backup name. @@ -288,9 +295,8 @@ def _available_backup_name(self, name, target_id): """ known_children = self.by_parent().get(target_id, []) return osutils.available_backup_name( - name, - lambda base: self._has_named_child( - base, target_id, known_children)) + name, lambda base: self._has_named_child(base, target_id, known_children) + ) def _parent_loops(self): """No entry should be its own ancestor.""" @@ -304,7 +310,7 @@ def _parent_loops(self): except KeyError: break if parent_id == trans_id: - yield ('parent loop', trans_id) + yield ("parent loop", trans_id) if parent_id in seen: break @@ -315,14 +321,14 @@ def _improper_versioning(self): """ for trans_id in self._versioned: kind = self.final_kind(trans_id) - if kind == 'symlink' and not self._tree.supports_symlinks(): + if kind == "symlink" and not self._tree.supports_symlinks(): # Ignore symlinks as they are not supported on this platform continue if kind is None: - yield ('versioning no contents', trans_id) + yield ("versioning no contents", trans_id) continue if not self._tree.versionable_kind(kind): - yield ('versioning bad kind', trans_id, kind) + yield ("versioning bad kind", trans_id, kind) def _executability_conflicts(self): """Check for bad executability changes. 
@@ -334,10 +340,10 @@ def _executability_conflicts(self): """ for trans_id in self._new_executability: if not self.final_is_versioned(trans_id): - yield ('unversioned executability', trans_id) + yield ("unversioned executability", trans_id) else: if self.final_kind(trans_id) != "file": - yield ('non-file executability', trans_id) + yield ("non-file executability", trans_id) def _overwrite_conflicts(self): """Check for overwrites (not permitted on Win32).""" @@ -345,7 +351,7 @@ def _overwrite_conflicts(self): if self.tree_kind(trans_id) is None: continue if trans_id not in self._removed_contents: - yield ('overwrite', trans_id, self.final_name(trans_id)) + yield ("overwrite", trans_id, self.final_name(trans_id)) def _duplicate_entries(self, by_parent): """No directory may have two entries with the same name.""" @@ -368,7 +374,7 @@ def _duplicate_entries(self, by_parent): if kind is None and not self.final_is_versioned(trans_id): continue if name == last_name: - yield ('duplicate', last_trans_id, trans_id, name) + yield ("duplicate", last_trans_id, trans_id, name) last_name = name last_trans_id = trans_id @@ -389,10 +395,10 @@ def _parent_type_conflicts(self, by_parent): kind = self.final_kind(parent_id) if kind is None: # The directory will be deleted - yield ('missing parent', parent_id) + yield ("missing parent", parent_id) elif kind != "directory": # Meh, we need a *directory* to put something in it - yield ('non-directory parent', parent_id) + yield ("non-directory parent", parent_id) def _set_executability(self, path, trans_id): """Set the executability of versioned files.""" @@ -420,8 +426,9 @@ def _new_entry(self, name, parent_id, file_id): self.version_file(trans_id, file_id=file_id) return trans_id - def new_file(self, name, parent_id, contents, file_id=None, - executable=None, sha1=None): + def new_file( + self, name, parent_id, contents, file_id=None, executable=None, sha1=None + ): """Convenience method to create files. name is the name of the file to create. 
@@ -546,8 +553,9 @@ def iter_changes(self, want_unversioned=False): if from_versioned: # get data from working tree if versioned - from_entry = next(self._tree.iter_entries_by_dir( - specific_files=[from_path]))[1] + from_entry = next( + self._tree.iter_entries_by_dir(specific_files=[from_path]) + )[1] from_name = from_entry.name else: from_entry = None @@ -559,8 +567,9 @@ def iter_changes(self, want_unversioned=False): # splitting stuff from_name = os.path.basename(from_path) if from_path is not None: - from_kind, from_executable, from_stats = \ - self._tree._comparison_data(from_entry, from_path) + from_kind, from_executable, from_stats = self._tree._comparison_data( + from_entry, from_path + ) else: from_kind = None from_executable = False @@ -574,26 +583,32 @@ def iter_changes(self, want_unversioned=False): if from_versioned and from_kind != to_kind: modified = True - elif to_kind in ('file', 'symlink') and ( - trans_id in self._new_contents): + elif to_kind in ("file", "symlink") and (trans_id in self._new_contents): modified = True - if (not modified and from_versioned == to_versioned + if ( + not modified + and from_versioned == to_versioned and from_path == to_path and from_name == to_name - and from_executable == to_executable): + and from_executable == to_executable + ): continue if (from_path, to_path) == (None, None): continue results.append( TreeChange( - (from_path, to_path), modified, + (from_path, to_path), + modified, (from_versioned, to_versioned), (from_name, to_name), (from_kind, to_kind), - (from_executable, to_executable))) + (from_executable, to_executable), + ) + ) def path_key(c): - return (c.path[0] or '', c.path[1] or '') + return (c.path[0] or "", c.path[1] or "") + return iter(sorted(results, key=path_key)) def get_preview_tree(self): @@ -604,9 +619,19 @@ def get_preview_tree(self): """ return GitPreviewTree(self) - def commit(self, branch, message, merge_parents=None, strict=False, - timestamp=None, timezone=None, committer=None, authors=None, - revprops=None, revision_id=None): + def commit( + self, + branch, + message, + merge_parents=None, + strict=False, + timestamp=None, + timezone=None, + committer=None, + authors=None, + revprops=None, + revision_id=None, + ): """Commit the result of this TreeTransform to a branch. :param branch: The branch to commit to. @@ -637,27 +662,30 @@ def commit(self, branch, message, merge_parents=None, strict=False, revno, last_rev_id = branch.last_revision_info() if last_rev_id == _mod_revision.NULL_REVISION: if merge_parents is not None: - raise ValueError('Cannot supply merge parents for first' - ' commit.') + raise ValueError("Cannot supply merge parents for first" " commit.") parent_ids = [] else: parent_ids = [last_rev_id] if merge_parents is not None: parent_ids.extend(merge_parents) if self._tree.get_revision_id() != last_rev_id: - raise ValueError('TreeTransform not based on branch basis: %s' % - self._tree.get_revision_id().decode('utf-8')) + raise ValueError( + "TreeTransform not based on branch basis: %s" + % self._tree.get_revision_id().decode("utf-8") + ) from .. 
import commit + revprops = commit.Commit.update_revprops(revprops, branch, authors) - builder = branch.get_commit_builder(parent_ids, - timestamp=timestamp, - timezone=timezone, - committer=committer, - revprops=revprops, - revision_id=revision_id) + builder = branch.get_commit_builder( + parent_ids, + timestamp=timestamp, + timezone=timezone, + committer=committer, + revprops=revprops, + revision_id=revision_id, + ) preview = self.get_preview_tree() - list(builder.record_iter_changes(preview, last_rev_id, - self.iter_changes())) + list(builder.record_iter_changes(preview, last_rev_id, self.iter_changes())) builder.finish_inventory() revision_id = builder.commit(message) branch.set_last_revision_info(revno + 1, revision_id) @@ -666,7 +694,7 @@ def commit(self, branch, message, merge_parents=None, strict=False, def _text_parent(self, trans_id): path = self.tree_path(trans_id) try: - if path is None or self._tree.kind(path) != 'file': + if path is None or self._tree.kind(path) != "file": return None except _mod_transport.NoSuchFile: return None @@ -748,25 +776,26 @@ def cook_conflicts(self, raw_conflicts): return fp = FinalPaths(self) from .workingtree import ContentsConflict, TextConflict + for c in raw_conflicts: - if c[0] == 'text conflict': + if c[0] == "text conflict": yield TextConflict(fp.get_path(c[1])) - elif c[0] == 'contents conflict': + elif c[0] == "contents conflict": yield ContentsConflict(fp.get_path(c[1][0])) - elif c[0] == 'duplicate': + elif c[0] == "duplicate": yield TextConflict(fp.get_path(c[2])) - elif c[0] == 'missing parent': + elif c[0] == "missing parent": pass - elif c[0] == 'non-directory parent': + elif c[0] == "non-directory parent": yield TextConflict(fp.get_path(c[2])) - elif c[0] == 'deleting parent': + elif c[0] == "deleting parent": # TODO(jelmer): This should not make it to here yield TextConflict(fp.get_path(c[2])) - elif c[0] == 'parent loop': + elif c[0] == "parent loop": # TODO(jelmer): This should not make it to here yield TextConflict(fp.get_path(c[2])) else: - raise AssertionError(f'unknown conflict {c[0]}') + raise AssertionError(f"unknown conflict {c[0]}") class DiskTreeTransform(TreeTransformBase): @@ -856,8 +885,7 @@ def adjust_path(self, name, parent, trans_id): previous_parent = self._new_parent.get(trans_id) previous_name = self._new_name.get(trans_id) super().adjust_path(name, parent, trans_id) - if (trans_id in self._limbo_files - and trans_id not in self._needs_rename): + if trans_id in self._limbo_files and trans_id not in self._needs_rename: self._rename_in_limbo([trans_id]) if previous_parent != parent: self._limbo_children[previous_parent].remove(trans_id) @@ -886,7 +914,7 @@ def _rename_in_limbo(self, trans_ids): self._possibly_stale_limbo_files.remove(old_path) for descendant in self._limbo_descendants(trans_id): desc_path = self._limbo_files[descendant] - desc_path = new_path + desc_path[len(old_path):] + desc_path = new_path + desc_path[len(old_path) :] self._limbo_files[descendant] = desc_path def _limbo_descendants(self, trans_id): @@ -914,8 +942,8 @@ def create_file(self, contents, trans_id, mode_id=None, sha1=None): We can use it to prevent future sha1 computations. 
""" name = self._limbo_name(trans_id) - with open(name, 'wb') as f: - unique_add(self._new_contents, trans_id, 'file') + with open(name, "wb") as f: + unique_add(self._new_contents, trans_id, "file") f.writelines(contents) self._set_mtime(name) self._set_mode(trans_id, mode_id, S_ISREG) @@ -945,7 +973,7 @@ def create_hardlink(self, path, trans_id): except PermissionError as err: raise errors.HardLinkNotSupported(path) from err try: - unique_add(self._new_contents, trans_id, 'file') + unique_add(self._new_contents, trans_id, "file") except BaseException: # Clean up the file, it never got registered so # TreeTransform.finalize() won't clean it up. @@ -958,7 +986,7 @@ def create_directory(self, trans_id): See also new_directory. """ os.mkdir(self._limbo_name(trans_id)) - unique_add(self._new_contents, trans_id, 'directory') + unique_add(self._new_contents, trans_id, "directory") def create_symlink(self, target, trans_id): """Schedule creation of a new symbolic link. @@ -973,13 +1001,12 @@ def create_symlink(self, target, trans_id): path = FinalPaths(self).get_path(trans_id) except KeyError: path = None - trace.warning( - f'Unable to create symlink "{path}" on this filesystem.') + trace.warning(f'Unable to create symlink "{path}" on this filesystem.') self._symlink_target[trans_id] = target # We add symlink to _new_contents even if they are unsupported # and not created. These entries are subsequently used to avoid # conflicts on platforms that don't support symlink - unique_add(self._new_contents, trans_id, 'symlink') + unique_add(self._new_contents, trans_id, "symlink") def create_tree_reference(self, reference_revision, trans_id): """Schedule creation of a new symbolic link. @@ -989,7 +1016,7 @@ def create_tree_reference(self, reference_revision, trans_id): """ os.mkdir(self._limbo_name(trans_id)) unique_add(self._new_reference_revision, trans_id, reference_revision) - unique_add(self._new_contents, trans_id, 'tree-reference') + unique_add(self._new_contents, trans_id, "tree-reference") def cancel_creation(self, trans_id): """Cancel the creation of new file contents.""" @@ -1007,7 +1034,7 @@ def cancel_creation(self, trans_id): def new_orphan(self, trans_id, parent_id): conf = self._tree.get_config_stack() - handle_orphan = conf.get('transform.orphan_policy') + handle_orphan = conf.get("transform.orphan_policy") handle_orphan(self, trans_id, parent_id) def final_entry(self, trans_id): @@ -1021,20 +1048,26 @@ def final_entry(self, trans_id): name = self.final_name(trans_id) file_id = self._tree.mapping.generate_file_id(tree_path) parent_id = self._tree.mapping.generate_file_id(os.path.dirname(tree_path)) - if kind == 'directory': + if kind == "directory": return GitTreeDirectory( - file_id, self.final_name(trans_id), parent_id=parent_id), is_versioned + file_id, self.final_name(trans_id), parent_id=parent_id + ), is_versioned executable = mode_is_executable(st.st_mode) object_mode(kind, executable) blob = blob_from_path_and_stat(encode_git_path(path), st) - if kind == 'symlink': + if kind == "symlink": return GitTreeSymlink( - file_id, name, parent_id, - decode_git_path(blob.data)), is_versioned - elif kind == 'file': + file_id, name, parent_id, decode_git_path(blob.data) + ), is_versioned + elif kind == "file": return GitTreeFile( - file_id, name, executable=executable, parent_id=parent_id, - git_sha1=blob.id, text_size=len(blob.data)), is_versioned + file_id, + name, + executable=executable, + parent_id=parent_id, + git_sha1=blob.id, + text_size=len(blob.data), + ), is_versioned else: 
raise AssertionError(kind) elif trans_id in self._removed_contents: @@ -1044,21 +1077,23 @@ def final_entry(self, trans_id): if orig_path is None: return None, None file_id = self._tree.mapping.generate_file_id(tree_path) - if tree_path == '': + if tree_path == "": parent_id = None else: - parent_id = self._tree.mapping.generate_file_id(os.path.dirname(tree_path)) + parent_id = self._tree.mapping.generate_file_id( + os.path.dirname(tree_path) + ) try: - ie = next(self._tree.iter_entries_by_dir( - specific_files=[orig_path]))[1] + ie = next(self._tree.iter_entries_by_dir(specific_files=[orig_path]))[1] ie.file_id = file_id ie.parent_id = parent_id return ie, is_versioned except StopIteration: try: - if self.tree_kind(trans_id) == 'directory': + if self.tree_kind(trans_id) == "directory": return GitTreeDirectory( - file_id, self.final_name(trans_id), parent_id=parent_id), is_versioned + file_id, self.final_name(trans_id), parent_id=parent_id + ), is_versioned except NotADirectoryError: pass return None, None @@ -1068,7 +1103,7 @@ def final_git_entry(self, trans_id): path = self._limbo_name(trans_id) st = os.lstat(path) kind = mode_kind(st.st_mode) - if kind == 'directory': + if kind == "directory": return None, None executable = mode_is_executable(st.st_mode) mode = object_mode(kind, executable) @@ -1080,11 +1115,11 @@ def final_git_entry(self, trans_id): kind = self._tree.kind(orig_path) executable = self._tree.is_executable(orig_path) mode = object_mode(kind, executable) - if kind == 'symlink': + if kind == "symlink": contents = self._tree.get_symlink_target(orig_path) - elif kind == 'file': + elif kind == "file": contents = self._tree.get_file_text(orig_path) - elif kind == 'directory': + elif kind == "directory": return None, None else: raise AssertionError(kind) @@ -1164,18 +1199,16 @@ def __init__(self, tree, pb=None): """ tree.lock_tree_write() try: - limbodir = urlutils.local_path_from_url( - tree._transport.abspath('limbo')) + limbodir = urlutils.local_path_from_url(tree._transport.abspath("limbo")) try: - osutils.ensure_empty_directory_exists( - limbodir) + osutils.ensure_empty_directory_exists(limbodir) except errors.DirectoryNotEmpty as err: raise errors.ExistingLimbo(limbodir) from err deletiondir = urlutils.local_path_from_url( - tree._transport.abspath('pending-deletion')) + tree._transport.abspath("pending-deletion") + ) try: - osutils.ensure_empty_directory_exists( - deletiondir) + osutils.ensure_empty_directory_exists(deletiondir) except errors.DirectoryNotEmpty as err: raise errors.ExistingPendingDeletion(deletiondir) from err except BaseException: @@ -1186,8 +1219,7 @@ def __init__(self, tree, pb=None): self._realpaths = {} # Cache of relpath results, to speed up canonical_path self._relpaths = {} - DiskTreeTransform.__init__(self, tree, limbodir, pb, - tree.case_sensitive) + DiskTreeTransform.__init__(self, tree, limbodir, pb, tree.case_sensitive) self._deletiondir = deletiondir def canonical_path(self, path): @@ -1203,7 +1235,7 @@ def canonical_path(self, path): abs = osutils.pathjoin(dirname, basename) if dirname in self._relpaths: relpath = osutils.pathjoin(self._relpaths[dirname], basename) - relpath = relpath.rstrip('/\\') + relpath = relpath.rstrip("/\\") else: relpath = self._tree.relpath(abs) self._relpaths[abs] = relpath @@ -1277,7 +1309,7 @@ def _generate_limbo_path(self, trans_id): # tree), choose a limbo name inside the parent, to reduce further # renames. 
use_direct_path = False - if self._new_contents.get(parent) == 'directory': + if self._new_contents.get(parent) == "directory": filename = self._new_name.get(trans_id) if filename is not None: if parent not in self._limbo_children: @@ -1288,12 +1320,15 @@ def _generate_limbo_path(self, trans_id): # already taken this pathname, i.e. if the name is unused, or # if it is already associated with this trans_id. elif self._case_sensitive_target: - if (self._limbo_children_names[parent].get(filename) - in (trans_id, None)): + if self._limbo_children_names[parent].get(filename) in ( + trans_id, + None, + ): use_direct_path = True else: - for l_filename, l_trans_id in ( - self._limbo_children_names[parent].items()): + for l_filename, l_trans_id in self._limbo_children_names[ + parent + ].items(): if l_trans_id == trans_id: continue if l_filename.lower() == filename.lower(): @@ -1325,13 +1360,13 @@ def apply(self, no_conflicts=False, _mover=None): conflicts, so no check is made. :param _mover: Supply an alternate FileMover, for testing """ - for hook in MutableTree.hooks['pre_transform']: + for hook in MutableTree.hooks["pre_transform"]: hook(self._tree, self) if not no_conflicts: self._check_malformed() self.rename_count = 0 with ui.ui_factory.nested_progress_bar() as child_pb: - child_pb.update(gettext('Apply phase'), 0, 2) + child_pb.update(gettext("Apply phase"), 0, 2) index_changes = self._generate_index_changes() offset = 1 if _mover is None: @@ -1339,9 +1374,9 @@ def apply(self, no_conflicts=False, _mover=None): else: mover = _mover try: - child_pb.update(gettext('Apply phase'), 0 + offset, 2 + offset) + child_pb.update(gettext("Apply phase"), 0 + offset, 2 + offset) self._apply_removals(mover) - child_pb.update(gettext('Apply phase'), 1 + offset, 2 + offset) + child_pb.update(gettext("Apply phase"), 1 + offset, 2 + offset) modified_paths = self._apply_insertions(mover) except BaseException: mover.rollback() @@ -1366,15 +1401,14 @@ def _apply_removals(self, mover): with ui.ui_factory.nested_progress_bar() as child_pb: for num, (path, trans_id) in enumerate(tree_paths): # do not attempt to move root into a subdirectory of itself. - if path == '': + if path == "": continue - child_pb.update(gettext('removing file'), num, len(tree_paths)) + child_pb.update(gettext("removing file"), num, len(tree_paths)) full_path = self._tree.abspath(path) if trans_id in self._removed_contents: delete_path = os.path.join(self._deletiondir, trans_id) mover.pre_delete(full_path, delete_path) - elif (trans_id in self._new_name or - trans_id in self._new_parent): + elif trans_id in self._new_name or trans_id in self._new_parent: try: mover.rename(full_path, self._limbo_name(trans_id)) except TransformRenameFailed as e: @@ -1398,8 +1432,7 @@ def _apply_insertions(self, mover): with ui.ui_factory.nested_progress_bar() as child_pb: for num, (path, trans_id) in enumerate(new_paths): if (num % 10) == 0: - child_pb.update(gettext('adding file'), - num, len(new_paths)) + child_pb.update(gettext("adding file"), num, len(new_paths)) full_path = self._tree.abspath(path) if trans_id in self._needs_rename: try: @@ -1413,8 +1446,7 @@ def _apply_insertions(self, mover): # TODO: if trans_id in self._observed_sha1s, we should # re-stat the final target, since ctime will be # updated by the change. 
- if (trans_id in self._new_contents - or self.path_changed(trans_id)): + if trans_id in self._new_contents or self.path_changed(trans_id): if trans_id in self._new_contents: modified_paths.append(full_path) if trans_id in self._new_executability: @@ -1424,21 +1456,30 @@ def _apply_insertions(self, mover): st = osutils.lstat(full_path) self._observed_sha1s[trans_id] = (o_sha1, st) if trans_id in self._new_reference_revision: - for (submodule_path, _submodule_url, _submodule_name) in self._tree._submodule_config(): + for ( + submodule_path, + _submodule_url, + _submodule_name, + ) in self._tree._submodule_config(): if decode_git_path(submodule_path) == path: break else: - trace.warning( - 'unable to find submodule for path %s', path) + trace.warning("unable to find submodule for path %s", path) continue submodule_transport = self._tree.controldir.control_transport.clone( - os.path.join('modules', submodule_name.decode('utf-8'))) + os.path.join("modules", submodule_name.decode("utf-8")) + ) submodule_transport.create_prefix() from .dir import BareLocalGitControlDirFormat - BareLocalGitControlDirFormat().initialize_on_transport(submodule_transport) - with open(os.path.join(full_path, '.git'), 'w') as f: - submodule_abspath = submodule_transport.local_abspath('.') - f.write(f'gitdir: {os.path.relpath(submodule_abspath, full_path)}\n') + + BareLocalGitControlDirFormat().initialize_on_transport( + submodule_transport + ) + with open(os.path.join(full_path, ".git"), "w") as f: + submodule_abspath = submodule_transport.local_abspath(".") + f.write( + f"gitdir: {os.path.relpath(submodule_abspath, full_path)}\n" + ) for _path, trans_id in new_paths: # new_paths includes stuff like workingtree conflicts. Only the # stuff in new_contents actually comes from limbo. @@ -1453,8 +1494,12 @@ def _generate_index_changes(self): removed_id.update(self._removed_contents) changes = {} changed_ids = set() - for id_set in [self._new_name, self._new_parent, - self._new_executability, self._new_contents]: + for id_set in [ + self._new_name, + self._new_parent, + self._new_executability, + self._new_contents, + ]: changed_ids.update(id_set) for id_set in [self._new_name, self._new_parent]: removed_id.update(id_set) @@ -1463,11 +1508,12 @@ def _generate_index_changes(self): # Ignore entries that are already known to have changed. 
changed_kind.difference_update(changed_ids) # to keep only the truly changed ones - changed_kind = (t for t in changed_kind - if self.tree_kind(t) != self.final_kind(t)) + changed_kind = ( + t for t in changed_kind if self.tree_kind(t) != self.final_kind(t) + ) changed_ids.update(changed_kind) for t in changed_kind: - if self.final_kind(t) == 'directory': + if self.final_kind(t) == "directory": removed_id.add(t) changed_ids.remove(t) new_paths = sorted(FinalPaths(self).get_paths(changed_ids)) @@ -1475,8 +1521,7 @@ def _generate_index_changes(self): with ui.ui_factory.nested_progress_bar() as child_pb: for num, trans_id in enumerate(removed_id): if (num % 10) == 0: - child_pb.update(gettext('removing file'), - num, total_entries) + child_pb.update(gettext("removing file"), num, total_entries) try: path = self._tree_id_paths[trans_id] except KeyError: @@ -1484,8 +1529,9 @@ def _generate_index_changes(self): changes[path] = (None, None, None, None) for num, (path, trans_id) in enumerate(new_paths): if (num % 10) == 0: - child_pb.update(gettext('adding file'), - num + len(removed_id), total_entries) + child_pb.update( + gettext("adding file"), num + len(removed_id), total_entries + ) kind = self.final_kind(trans_id) if kind is None: @@ -1497,7 +1543,11 @@ def _generate_index_changes(self): reference_revision = self._new_reference_revision.get(trans_id) symlink_target = self._symlink_target.get(trans_id) changes[path] = ( - kind, executability, reference_revision, symlink_target) + kind, + executability, + reference_revision, + symlink_target, + ) return [(p, k, e, rr, st) for (p, (k, e, rr, st)) in changes.items()] @@ -1511,7 +1561,7 @@ class GitTransformPreview(GitTreeTransform): def __init__(self, tree, pb=None, case_sensitive=True): tree.lock_read() - limbodir = tempfile.mkdtemp(prefix='git-limbo-') + limbodir = tempfile.mkdtemp(prefix="git-limbo-") DiskTreeTransform.__init__(self, tree, limbodir, pb, case_sensitive) def canonical_path(self, path): @@ -1522,7 +1572,7 @@ def tree_kind(self, trans_id): if path is None: return None kind = self._tree.path_content_summary(path)[0] - if kind == 'missing': + if kind == "missing": kind = None return kind @@ -1572,13 +1622,13 @@ def supports_symlinks(self): def _supports_executable(self): return self._transform._limbo_supports_executable() - def walkdirs(self, prefix=''): + def walkdirs(self, prefix=""): pending = [self._transform.root] while len(pending) > 0: parent_id = pending.pop() children = [] subdirs = [] - prefix = prefix.rstrip('/') + prefix = prefix.rstrip("/") parent_path = self._final_paths.get_path(parent_id) for child_id in self._all_children(parent_id): path_from_root = self._final_paths.get_path(child_id) @@ -1587,22 +1637,28 @@ def walkdirs(self, prefix=''): if kind is not None: versioned_kind = kind else: - kind = 'unknown' - versioned_kind = self._transform._tree.stored_kind( - path_from_root) - if versioned_kind == 'directory': + kind = "unknown" + versioned_kind = self._transform._tree.stored_kind(path_from_root) + if versioned_kind == "directory": subdirs.append(child_id) - children.append((path_from_root, basename, kind, None, - versioned_kind)) + children.append((path_from_root, basename, kind, None, versioned_kind)) children.sort() if parent_path.startswith(prefix): yield parent_path, children - pending.extend(sorted(subdirs, key=self._final_paths.get_path, - reverse=True)) - - def iter_changes(self, from_tree, include_unchanged=False, - specific_files=None, pb=None, extra_trees=None, - require_versioned=True, 
want_unversioned=False): + pending.extend( + sorted(subdirs, key=self._final_paths.get_path, reverse=True) + ) + + def iter_changes( + self, + from_tree, + include_unchanged=False, + specific_files=None, + pb=None, + extra_trees=None, + require_versioned=True, + want_unversioned=False, + ): """See InterTree.iter_changes. This has a fast path that is only used when the from_tree matches @@ -1614,7 +1670,8 @@ def iter_changes(self, from_tree, include_unchanged=False, pb=pb, extra_trees=extra_trees, require_versioned=require_versioned, - want_unversioned=want_unversioned) + want_unversioned=want_unversioned, + ) def get_file(self, path): """See Tree.get_file.""" @@ -1623,7 +1680,7 @@ def get_file(self, path): raise _mod_transport.NoSuchFile(path) if trans_id in self._transform._new_contents: name = self._transform._limbo_name(trans_id) - return open(name, 'rb') + return open(name, "rb") if trans_id in self._transform._removed_contents: raise _mod_transport.NoSuchFile(path) orig_path = self._transform.tree_path(trans_id) @@ -1647,7 +1704,8 @@ def annotate_iter(self, path, default_revision=_mod_revision.CURRENT_REVISION): orig_path = self._transform.tree_path(trans_id) if orig_path is not None: old_annotation = self._transform._tree.annotate_iter( - orig_path, default_revision=default_revision) + orig_path, default_revision=default_revision + ) else: old_annotation = [] try: @@ -1683,8 +1741,10 @@ def get_file_lines(self, path): return osutils.split_lines(self.get_file_text(path)) def extras(self): - possible_extras = {self._transform.trans_id_tree_path(p) for p - in self._transform._tree.extras()} + possible_extras = { + self._transform.trans_id_tree_path(p) + for p in self._transform._tree.extras() + } possible_extras.update(self._transform._new_contents) possible_extras.update(self._transform._removed_id) for trans_id in possible_extras: @@ -1698,15 +1758,15 @@ def path_content_summary(self, path): kind = tt._new_contents.get(trans_id) if kind is None: if tree_path is None or trans_id in tt._removed_contents: - return 'missing', None, None, None + return "missing", None, None, None summary = tt._tree.path_content_summary(tree_path) kind, size, executable, link_or_sha1 = summary else: link_or_sha1 = None limbo_name = tt._limbo_name(trans_id) if trans_id in tt._new_reference_revision: - kind = 'tree-reference' - if kind == 'file': + kind = "tree-reference" + if kind == "file": statval = os.lstat(limbo_name) size = statval.st_size if not tt._limbo_supports_executable(): @@ -1716,7 +1776,7 @@ def path_content_summary(self, path): else: size = None executable = None - if kind == 'symlink': + if kind == "symlink": link_or_sha1 = os.readlink(limbo_name) if not isinstance(link_or_sha1, str): link_or_sha1 = os.fsdecode(link_or_sha1) @@ -1730,7 +1790,8 @@ def get_file_mtime(self, path): raise _mod_transport.NoSuchFile(path) if trans_id not in self._transform._new_contents: return self._transform._tree.get_file_mtime( - self._transform.tree_path(trans_id)) + self._transform.tree_path(trans_id) + ) name = self._transform._limbo_name(trans_id) statval = os.lstat(name) return statval.st_mtime @@ -1749,8 +1810,7 @@ def is_versioned(self, path): def iter_entries_by_dir(self, specific_files=None, recurse_nested=False): if recurse_nested: - raise NotImplementedError( - 'follow tree references not yet supported') + raise NotImplementedError("follow tree references not yet supported") # This may not be a maximally efficient implementation, but it is # reasonably straightforward. 
An implementation that grafts the @@ -1761,7 +1821,7 @@ def iter_entries_by_dir(self, specific_files=None, recurse_nested=False): entry, is_versioned = self._transform.final_entry(trans_id) if entry is None: continue - if not is_versioned and entry.kind != 'directory': + if not is_versioned and entry.kind != "directory": continue if specific_files is not None and path not in specific_files: continue diff --git a/breezy/git/transportgit.py b/breezy/git/transportgit.py index f1d60884d0..44f8332c6a 100644 --- a/breezy/git/transportgit.py +++ b/breezy/git/transportgit.py @@ -73,7 +73,6 @@ class _RemoteGitFile: - def __init__(self, transport, filename, mode, bufsize, mask): self.transport = transport self.filename = filename @@ -81,6 +80,7 @@ def __init__(self, transport, filename, mode, bufsize, mask): self.bufsize = bufsize self.mask = mask import tempfile + self._file = tempfile.SpooledTemporaryFile(max_size=1024 * 1024) self._closed = False for method in _GitFile.PROXY_METHODS: @@ -171,8 +171,7 @@ def allkeys(self): else: keys.add(b"HEAD") try: - iter_files = list(self.transport.clone( - "refs").iter_files_recursive()) + iter_files = list(self.transport.clone("refs").iter_files_recursive()) for filename in iter_files: unquoted_filename = urlutils.unquote_to_bytes(filename) refname = osutils.pathjoin(b"refs", unquoted_filename) @@ -203,8 +202,7 @@ def get_packed_refs(self): return {} try: first_line = next(iter(f)).rstrip() - if (first_line.startswith(b"# pack-refs") and b" peeled" in - first_line): + if first_line.startswith(b"# pack-refs") and b" peeled" in first_line: for sha, name, peeled in read_packed_refs_with_peeled(f): self._packed_refs[name] = sha if peeled: @@ -246,7 +244,7 @@ def read_loose_ref(self, name): exist. :raises IOError: if any other error occurs """ - if name == b'HEAD': + if name == b"HEAD": transport = self.worktree_transport else: transport = self.transport @@ -292,13 +290,12 @@ def set_symbolic_ref(self, name, other): """ self._check_refname(name) self._check_refname(other) - if name != b'HEAD': + if name != b"HEAD": transport = self.transport self._ensure_dir_exists(urlutils.quote_from_bytes(name)) else: transport = self.worktree_transport - transport.put_bytes(urlutils.quote_from_bytes( - name), SYMREF + other + b'\n') + transport.put_bytes(urlutils.quote_from_bytes(name), SYMREF + other + b"\n") def set_if_equals(self, name, old_ref, new_ref): """Set a refname to new_ref only if it currently equals old_ref. 
@@ -318,13 +315,12 @@ def set_if_equals(self, name, old_ref, new_ref): realname = realnames[-1] except (KeyError, IndexError, SymrefLoop): realname = name - if realname == b'HEAD': + if realname == b"HEAD": transport = self.worktree_transport else: transport = self.transport self._ensure_dir_exists(urlutils.quote_from_bytes(realname)) - transport.put_bytes(urlutils.quote_from_bytes( - realname), new_ref + b"\n") + transport.put_bytes(urlutils.quote_from_bytes(realname), new_ref + b"\n") return True def add_if_new(self, name, ref): @@ -345,7 +341,7 @@ def add_if_new(self, name, ref): except (KeyError, IndexError): realname = name self._check_refname(realname) - if realname == b'HEAD': + if realname == b"HEAD": transport = self.worktree_transport else: transport = self.transport @@ -366,7 +362,7 @@ def remove_if_equals(self, name, old_ref): """ self._check_refname(name) # may only be packed - if name == b'HEAD': + if name == b"HEAD": transport = self.worktree_transport else: transport = self.transport @@ -402,20 +398,20 @@ def lock_ref(self, name): self._ensure_dir_exists(urlutils.quote_from_bytes(name)) lockname = urlutils.quote_from_bytes(name + b".lock") try: - transport.local_abspath( - urlutils.quote_from_bytes(name)) + transport.local_abspath(urlutils.quote_from_bytes(name)) except NotLocalUrl as err: # This is racy, but what can we do? if transport.has(lockname): raise LockContention(name) from err - transport.put_bytes(lockname, b'Locked by brz-git') + transport.put_bytes(lockname, b"Locked by brz-git") return LogicalLockResult(lambda: transport.delete(lockname)) else: try: - gf = TransportGitFile(transport, urlutils.quote_from_bytes(name), 'wb') + gf = TransportGitFile(transport, urlutils.quote_from_bytes(name), "wb") except FileLocked as e: raise LockContention(name, e) from e else: + def unlock(): try: transport.delete(lockname) @@ -424,6 +420,7 @@ def unlock(): # GitFile.abort doesn't care if the lock has already # disappeared gf.abort() + return LogicalLockResult(unlock) @@ -440,11 +437,10 @@ def read_gitfile(f): cs = f.read() if not cs.startswith(b"gitdir: "): raise ValueError("Expected file to start with 'gitdir: '") - return cs[len(b"gitdir: "):].rstrip(b"\n") + return cs[len(b"gitdir: ") :].rstrip(b"\n") class TransportRepo(BaseRepo): - def __init__(self, transport, bare, refs_text=None) -> None: self.transport = transport self.bare = bare @@ -455,46 +451,51 @@ def __init__(self, transport, bare, refs_text=None) -> None: if self.bare: self._controltransport = self.transport else: - self._controltransport = self.transport.clone('.git') + self._controltransport = self.transport.clone(".git") else: self._controltransport = self.transport.clone( - urlutils.quote_from_bytes(path)) + urlutils.quote_from_bytes(path) + ) commondir = self.get_named_file(COMMONDIR) if commondir is not None: with commondir: commondir = os.path.join( self.controldir(), - commondir.read().rstrip(b"\r\n").decode( - sys.getfilesystemencoding())) - self._commontransport = \ - _mod_transport.get_transport_from_path(commondir) + commondir.read() + .rstrip(b"\r\n") + .decode(sys.getfilesystemencoding()), + ) + self._commontransport = _mod_transport.get_transport_from_path( + commondir + ) else: self._commontransport = self._controltransport config = self.get_config() object_store = TransportObjectStore.from_config( - self._commontransport.clone(OBJECTDIR), - config) + self._commontransport.clone(OBJECTDIR), config + ) refs_container: RefsContainer if refs_text is not None: refs_container = 
InfoRefsContainer(BytesIO(refs_text)) try: - head = TransportRefsContainer( - self._commontransport).read_loose_ref(b"HEAD") + head = TransportRefsContainer(self._commontransport).read_loose_ref( + b"HEAD" + ) except KeyError: pass else: refs_container._refs[b"HEAD"] = head else: refs_container = TransportRefsContainer( - self._commontransport, self._controltransport) - super().__init__(object_store, - refs_container) + self._commontransport, self._controltransport + ) + super().__init__(object_store, refs_container) def controldir(self): - return self._controltransport.local_abspath('.') + return self._controltransport.local_abspath(".") def commondir(self): - return self._commontransport.local_abspath('.') + return self._commontransport.local_abspath(".") def close(self): """Close any files opened by this repository.""" @@ -502,11 +503,11 @@ def close(self): @property def path(self): - return self.transport.local_abspath('.') + return self.transport.local_abspath(".") def _determine_file_mode(self): # Be consistent with bzr - if sys.platform == 'win32': + if sys.platform == "win32": return False return True @@ -528,7 +529,7 @@ def get_named_file(self, path): :return: An open file object, or None if the file does not exist. """ try: - return self._controltransport.get(path.lstrip('/')) + return self._controltransport.get(path.lstrip("/")) except NoSuchFile: return None @@ -542,6 +543,7 @@ def index_path(self): def open_index(self): """Open the index for this repository.""" from dulwich.index import Index + if not self.has_index(): raise NoIndexPresent() return Index(self.index_path()) @@ -554,14 +556,16 @@ def has_index(self): def get_config(self): from dulwich.config import ConfigFile + try: - with self._controltransport.get('config') as f: + with self._controltransport.get("config") as f: return ConfigFile.from_file(f) except NoSuchFile: return ConfigFile() def get_config_stack(self): from dulwich.config import StackedConfig + backends = [] p = self.get_config() if p is not None: @@ -604,8 +608,9 @@ def init(cls, transport, bare=False): class TransportObjectStore(PackBasedObjectStore): """Git-style object store that exists on disk.""" - def __init__(self, transport, - loose_compression_level=-1, pack_compression_level=-1): + def __init__( + self, transport, loose_compression_level=-1, pack_compression_level=-1 + ): """Open an object store. 
:param transport: Transport to open data from @@ -620,18 +625,21 @@ def __init__(self, transport, @classmethod def from_config(cls, path, config): try: - default_compression_level = int(config.get( - (b'core', ), b'compression').decode()) + default_compression_level = int( + config.get((b"core",), b"compression").decode() + ) except KeyError: default_compression_level = -1 try: - loose_compression_level = int(config.get( - (b'core', ), b'looseCompression').decode()) + loose_compression_level = int( + config.get((b"core",), b"looseCompression").decode() + ) except KeyError: loose_compression_level = default_compression_level try: - pack_compression_level = int(config.get( - (b'core', ), 'packCompression').decode()) + pack_compression_level = int( + config.get((b"core",), "packCompression").decode() + ) except KeyError: pack_compression_level = default_compression_level return cls(path, loose_compression_level, pack_compression_level) @@ -680,12 +688,9 @@ def _update_pack_cache(self): size = self.pack_transport.stat(pack_name).st_size except TransportNotPossible: size = None - pd = PackData( - pack_name, self.pack_transport.get(pack_name), - size=size) + pd = PackData(pack_name, self.pack_transport.get(pack_name), size=size) idxname = basename + ".idx" - idx = load_pack_index_file( - idxname, self.pack_transport.get(idxname)) + idx = load_pack_index_file(idxname, self.pack_transport.get(idxname)) pack = Pack.from_objects(pd, idx) pack._basename = basename self._pack_cache[basename] = pack @@ -708,15 +713,17 @@ def _pack_names(self): pack_files.append(os.path.splitext(name)[0]) except TransportNotPossible: try: - f = self.transport.get('info/packs') + f = self.transport.get("info/packs") except NoSuchFile: - warning('No info/packs on remote host;' - 'run \'git update-server-info\' on remote.') + warning( + "No info/packs on remote host;" + "run 'git update-server-info' on remote." + ) else: with f: pack_files = [ - os.path.splitext(name)[0] - for name in read_packs_file(f)] + os.path.splitext(name)[0] for name in read_packs_file(f) + ] except NoSuchFile: pass return pack_files @@ -730,7 +737,7 @@ def _remove_pack(self, pack): pass def _iter_loose_objects(self): - for base in self.transport.list_dir('.'): + for base in self.transport.list_dir("."): if len(base) != 2: continue for rest in self.transport.list_dir(base): @@ -768,7 +775,8 @@ def add_object(self, obj): # the compression_level parameter. 
if self.loose_compression_level not in (-1, None): raw_string = obj.as_legacy_object( - compression_level=self.loose_compression_level) + compression_level=self.loose_compression_level + ) else: raw_string = obj.as_legacy_object() self.transport.put_bytes(path, raw_string) @@ -776,7 +784,7 @@ def add_object(self, obj): @classmethod def init(cls, transport): try: - transport.mkdir('info') + transport.mkdir("info") except FileExists: pass try: @@ -799,12 +807,18 @@ def _complete_pack(self, f, path, num_objects, indexer, progress=None): entries = [] for i, entry in enumerate(indexer): if progress is not None: - progress(("generating index: %d/%d\r" % (i, num_objects)).encode('ascii')) + progress( + ("generating index: %d/%d\r" % (i, num_objects)).encode("ascii") + ) entries.append(entry) pack_sha, extra_entries = extend_pack( - f, indexer.ext_refs(), get_raw=self.get_raw, compression_level=self.pack_compression_level, - progress=progress) + f, + indexer.ext_refs(), + get_raw=self.get_raw, + compression_level=self.pack_compression_level, + progress=progress, + ) f.flush() try: fileno = f.fileno() @@ -817,7 +831,9 @@ def _complete_pack(self, f, path, num_objects, indexer, progress=None): # Move the pack in. entries.sort() - pack_base_name = "pack-" + iter_sha1(entry[0] for entry in entries).decode('ascii') + pack_base_name = "pack-" + iter_sha1(entry[0] for entry in entries).decode( + "ascii" + ) for pack in self.packs: if osutils.basename(pack._basename) == pack_base_name: @@ -835,19 +851,29 @@ def _complete_pack(self, f, path, num_objects, indexer, progress=None): if path: f.close() self.pack_transport.ensure_base() - os.rename(path, self.transport.local_abspath(osutils.pathjoin(PACKDIR, target_pack_name))) + os.rename( + path, + self.transport.local_abspath( + osutils.pathjoin(PACKDIR, target_pack_name) + ), + ) else: f.seek(0) self.pack_transport.put_file(target_pack_name, f, mode=PACK_MODE) # Write the index. - with TransportGitFile(self.pack_transport, target_pack_index, "wb", mask=PACK_MODE) as index_file: + with TransportGitFile( + self.pack_transport, target_pack_index, "wb", mask=PACK_MODE + ) as index_file: write_pack_index(index_file, entries, pack_sha) # Add the pack to the store and return it. 
final_pack = Pack.from_objects( PackData(target_pack_name, self.pack_transport.get(target_pack_name)), - load_pack_index_file(target_pack_index, self.pack_transport.get(target_pack_index))) + load_pack_index_file( + target_pack_index, self.pack_transport.get(target_pack_index) + ), + ) final_pack._basename = pack_base_name final_pack.check_length_and_checksum() @@ -872,7 +898,7 @@ def add_thin_pack(self, read_all, read_some, progress=None): import tempfile try: - dir = self.transport.local_abspath('.') + dir = self.transport.local_abspath(".") except NotLocalUrl: f = tempfile.SpooledTemporaryFile(prefix="tmp_pack_") path = None @@ -903,7 +929,7 @@ def add_pack(self): import tempfile try: - dir = self.transport.local_abspath('.') + dir = self.transport.local_abspath(".") except NotLocalUrl: f = tempfile.SpooledTemporaryFile(prefix="tmp_pack_") path = None @@ -915,7 +941,9 @@ def commit(): if f.tell() > 0: f.seek(0) with PackData(path, f) as pd: - indexer = PackIndexer.for_pack_data(pd, resolve_ext_ref=self.get_raw) + indexer = PackIndexer.for_pack_data( + pd, resolve_ext_ref=self.get_raw + ) return self._complete_pack(f, path, len(pd), indexer) else: abort() diff --git a/breezy/git/tree.py b/breezy/git/tree.py index ff887a21ab..d981806af3 100644 --- a/breezy/git/tree.py +++ b/breezy/git/tree.py @@ -59,8 +59,7 @@ class GitTreeDirectory(_mod_tree.TreeDirectory): - - __slots__ = ['file_id', 'name', 'parent_id', 'git_sha1'] + __slots__ = ["file_id", "name", "parent_id", "git_sha1"] def __init__(self, file_id, name, parent_id, git_sha1=None): self.file_id = file_id @@ -70,35 +69,35 @@ def __init__(self, file_id, name, parent_id, git_sha1=None): @property def kind(self): - return 'directory' + return "directory" @property def executable(self): return False def copy(self): - return self.__class__( - self.file_id, self.name, self.parent_id) + return self.__class__(self.file_id, self.name, self.parent_id) def __repr__(self): return "{}(file_id={!r}, name={!r}, parent_id={!r})".format( - self.__class__.__name__, self.file_id, self.name, - self.parent_id) + self.__class__.__name__, self.file_id, self.name, self.parent_id + ) def __eq__(self, other): - return (self.kind == other.kind and - self.file_id == other.file_id and - self.name == other.name and - self.parent_id == other.parent_id) + return ( + self.kind == other.kind + and self.file_id == other.file_id + and self.name == other.name + and self.parent_id == other.parent_id + ) class GitTreeFile(_mod_tree.TreeFile): + __slots__ = ["file_id", "name", "parent_id", "text_size", "executable", "git_sha1"] - __slots__ = ['file_id', 'name', 'parent_id', 'text_size', - 'executable', 'git_sha1'] - - def __init__(self, file_id, name, parent_id, text_size=None, - git_sha1=None, executable=None): + def __init__( + self, file_id, name, parent_id, text_size=None, git_sha1=None, executable=None + ): self.file_id = file_id self.name = name self.parent_id = parent_id @@ -108,26 +107,35 @@ def __init__(self, file_id, name, parent_id, text_size=None, @property def kind(self): - return 'file' + return "file" def __eq__(self, other): - return (self.kind == other.kind and - self.file_id == other.file_id and - self.name == other.name and - self.parent_id == other.parent_id and - self.git_sha1 == other.git_sha1 and - self.text_size == other.text_size and - self.executable == other.executable) + return ( + self.kind == other.kind + and self.file_id == other.file_id + and self.name == other.name + and self.parent_id == other.parent_id + and self.git_sha1 == other.git_sha1 
+ and self.text_size == other.text_size + and self.executable == other.executable + ) def __repr__(self): - return ("{}(file_id={!r}, name={!r}, parent_id={!r}, text_size={!r}, " - "git_sha1={!r}, executable={!r})").format( - type(self).__name__, self.file_id, self.name, self.parent_id, - self.text_size, self.git_sha1, self.executable) + return ( + "{}(file_id={!r}, name={!r}, parent_id={!r}, text_size={!r}, " + "git_sha1={!r}, executable={!r})" + ).format( + type(self).__name__, + self.file_id, + self.name, + self.parent_id, + self.text_size, + self.git_sha1, + self.executable, + ) def copy(self): - ret = self.__class__( - self.file_id, self.name, self.parent_id) + ret = self.__class__(self.file_id, self.name, self.parent_id) ret.git_sha1 = self.git_sha1 ret.text_size = self.text_size ret.executable = self.executable @@ -135,11 +143,9 @@ def copy(self): class GitTreeSymlink(_mod_tree.TreeLink): + __slots__ = ["file_id", "name", "parent_id", "symlink_target", "git_sha1"] - __slots__ = ['file_id', 'name', 'parent_id', 'symlink_target', 'git_sha1'] - - def __init__(self, file_id, name, parent_id, - symlink_target=None, git_sha1=None): + def __init__(self, file_id, name, parent_id, symlink_target=None, git_sha1=None): self.file_id = file_id self.name = name self.parent_id = parent_id @@ -148,7 +154,7 @@ def __init__(self, file_id, name, parent_id, @property def kind(self): - return 'symlink' + return "symlink" @property def executable(self): @@ -159,28 +165,37 @@ def text_size(self): return None def __repr__(self): - return "{}(file_id={!r}, name={!r}, parent_id={!r}, symlink_target={!r})".format( - type(self).__name__, self.file_id, self.name, self.parent_id, - self.symlink_target) + return ( + "{}(file_id={!r}, name={!r}, parent_id={!r}, symlink_target={!r})".format( + type(self).__name__, + self.file_id, + self.name, + self.parent_id, + self.symlink_target, + ) + ) def __eq__(self, other): - return (self.kind == other.kind and - self.file_id == other.file_id and - self.name == other.name and - self.parent_id == other.parent_id and - self.symlink_target == other.symlink_target) + return ( + self.kind == other.kind + and self.file_id == other.file_id + and self.name == other.name + and self.parent_id == other.parent_id + and self.symlink_target == other.symlink_target + ) def copy(self): return self.__class__( - self.file_id, self.name, self.parent_id, - self.symlink_target) + self.file_id, self.name, self.parent_id, self.symlink_target + ) class GitTreeSubmodule(_mod_tree.TreeReference): + __slots__ = ["file_id", "name", "parent_id", "reference_revision", "git_sha1"] - __slots__ = ['file_id', 'name', 'parent_id', 'reference_revision', 'git_sha1'] - - def __init__(self, file_id, name, parent_id, reference_revision=None, git_sha1=None): + def __init__( + self, file_id, name, parent_id, reference_revision=None, git_sha1=None + ): self.file_id = file_id self.name = name self.parent_id = parent_id @@ -193,33 +208,40 @@ def executable(self): @property def kind(self): - return 'tree-reference' + return "tree-reference" def __repr__(self): - return ("{}(file_id={!r}, name={!r}, parent_id={!r}, " - "reference_revision={!r})").format( - type(self).__name__, self.file_id, self.name, self.parent_id, - self.reference_revision) + return ( + "{}(file_id={!r}, name={!r}, parent_id={!r}, " "reference_revision={!r})" + ).format( + type(self).__name__, + self.file_id, + self.name, + self.parent_id, + self.reference_revision, + ) def __eq__(self, other): - return (self.kind == other.kind and - self.file_id == 
other.file_id and - self.name == other.name and - self.parent_id == other.parent_id and - self.reference_revision == other.reference_revision) + return ( + self.kind == other.kind + and self.file_id == other.file_id + and self.name == other.name + and self.parent_id == other.parent_id + and self.reference_revision == other.reference_revision + ) def copy(self): return self.__class__( - self.file_id, self.name, self.parent_id, - self.reference_revision) + self.file_id, self.name, self.parent_id, self.reference_revision + ) entry_factory = { - 'directory': GitTreeDirectory, - 'file': GitTreeFile, - 'symlink': GitTreeSymlink, - 'tree-reference': GitTreeSubmodule, - } + "directory": GitTreeDirectory, + "file": GitTreeFile, + "symlink": GitTreeSymlink, + "tree-reference": GitTreeSubmodule, +} def ensure_normalized_path(path): @@ -239,14 +261,13 @@ def ensure_normalized_path(path): class GitTree(_mod_tree.Tree): - supports_file_ids = False store: BaseObjectStore @classmethod def is_special_path(cls, path): - return path.startswith('.git') + return path.startswith(".git") def supports_symlinks(self): return True @@ -267,10 +288,12 @@ def git_snapshot(self, want_unversioned=False): def preview_transform(self, pb=None): from .transform import GitTransformPreview + return GitTransformPreview(self, pb=pb) - def find_related_paths_across_trees(self, paths, trees=None, - require_versioned=True): + def find_related_paths_across_trees( + self, paths, trees=None, require_versioned=True + ): if trees is None: trees = [] if paths is None: @@ -282,7 +305,7 @@ def include(t, p): # Include directories, since they may exist but just be # empty try: - if t.kind(p) == 'directory': + if t.kind(p) == "directory": return True except _mod_transport.NoSuchFile: return False @@ -304,7 +327,7 @@ def include(t, p): def _submodule_config(self): if self._submodules is None: try: - with self.get_file('.gitmodules') as f: + with self.get_file(".gitmodules") as f: config = GitConfigFile.from_file(f) self._submodules = list(parse_submodules(config)) except _mod_transport.NoSuchFile: @@ -312,17 +335,16 @@ def _submodule_config(self): return self._submodules def _submodule_info(self): - return {path: (url, section) - for path, url, section in self._submodule_config()} + return {path: (url, section) for path, url, section in self._submodule_config()} def reference_parent(self, path): from ..branch import Branch + (url, section) = self._submodule_info()[encode_git_path(path)] - return Branch.open(url.decode('utf-8')) + return Branch.open(url.decode("utf-8")) class RemoteNestedTree(MissingNestedTree): - _fmt = "Unable to access remote nested tree at %(path)s" @@ -336,8 +358,7 @@ def __init__(self, repository, revision_id): self.store = repository._git.object_store if not isinstance(revision_id, bytes): raise TypeError(revision_id) - self.commit_id, self.mapping = repository.lookup_bzr_revision_id( - revision_id) + self.commit_id, self.mapping = repository.lookup_bzr_revision_id(revision_id) if revision_id == NULL_REVISION: self.tree = None self.mapping = default_mapping @@ -360,30 +381,38 @@ def _get_submodule_repository(self, relpath): nested_repo_transport = None else: nested_repo_transport = self._repository.controldir.control_transport.clone( - posixpath.join('modules', decode_git_path(section))) - if not nested_repo_transport.has('.'): + posixpath.join("modules", decode_git_path(section)) + ) + if not nested_repo_transport.has("."): nested_url = urlutils.join( - self._repository.controldir.user_url, decode_git_path(url)) + 
self._repository.controldir.user_url, decode_git_path(url) + ) nested_repo_transport = get_transport(nested_url) if nested_repo_transport is None: nested_repo_transport = self._repository.controldir.user_transport.clone( - decode_git_path(relpath)) + decode_git_path(relpath) + ) else: nested_repo_transport = self._repository.controldir.control_transport.clone( - posixpath.join('modules', decode_git_path(section))) - if not nested_repo_transport.has('.'): - nested_repo_transport = self._repository.controldir.user_transport.clone( - posixpath.join(decode_git_path(section), '.git')) + posixpath.join("modules", decode_git_path(section)) + ) + if not nested_repo_transport.has("."): + nested_repo_transport = ( + self._repository.controldir.user_transport.clone( + posixpath.join(decode_git_path(section), ".git") + ) + ) try: nested_controldir = _mod_controldir.ControlDir.open_from_transport( - nested_repo_transport) + nested_repo_transport + ) except errors.NotBranchError as e: raise MissingNestedTree(decode_git_path(relpath)) from e return nested_controldir.find_repository() def _get_submodule_store(self, relpath): repo = self._get_submodule_repository(relpath) - if not hasattr(repo, '_git'): + if not hasattr(repo, "_git"): raise RemoteNestedTree(relpath) return repo._git.object_store @@ -401,7 +430,8 @@ def get_file_revision(self, path): if self.commit_id == ZERO_SHA: return NULL_REVISION (store, unused_path, commit_id) = change_scanner.find_last_change_revision( - encode_git_path(path), self.commit_id) + encode_git_path(path), self.commit_id + ) return self.mapping.revision_id_foreign_to_bzr(commit_id) def get_file_mtime(self, path): @@ -410,7 +440,8 @@ def get_file_mtime(self, path): return NULL_REVISION try: (store, unused_path, commit_id) = change_scanner.find_last_change_revision( - encode_git_path(path), self.commit_id) + encode_git_path(path), self.commit_id + ) except KeyError as err: raise _mod_transport.NoSuchFile(path) from err commit = store[commit_id] @@ -427,8 +458,8 @@ def path2id(self, path): return self.mapping.generate_file_id(osutils.safe_unicode(path)) def all_versioned_paths(self): - ret = {''} - todo = [(self.store, b'', self.tree)] + ret = {""} + todo = [(self.store, b"", self.tree)] while todo: (store, path, tree_id) = todo.pop() if tree_id is None: @@ -446,7 +477,7 @@ def _lookup_path(self, path): raise _mod_transport.NoSuchFile(path) encoded_path = encode_git_path(path) - parts = encoded_path.split(b'/') + parts = encoded_path.split(b"/") hexsha = self.tree store = self.store mode = None @@ -461,7 +492,7 @@ def _lookup_path(self, path): except KeyError as err: raise _mod_transport.NoSuchFile(path) from err if S_ISGITLINK(mode) and i != len(parts) - 1: - store = self._get_submodule_store(b'/'.join(parts[:i + 1])) + store = self._get_submodule_store(b"/".join(parts[: i + 1])) hexsha = store[hexsha].tree return (store, mode, hexsha) @@ -487,11 +518,12 @@ def has_filename(self, path): else: return True - def list_files(self, include_root=False, from_dir=None, recursive=True, - recurse_nested=False): + def list_files( + self, include_root=False, from_dir=None, recursive=True, recurse_nested=False + ): if self.tree is None: return - if from_dir is None or from_dir == '.': + if from_dir is None or from_dir == ".": from_dir = "" (store, mode, hexsha) = self._lookup_path(from_dir) if mode is None: # Root @@ -499,18 +531,23 @@ def list_files(self, include_root=False, from_dir=None, recursive=True, else: parent_path = posixpath.dirname(from_dir) parent_id = 
self.mapping.generate_file_id(parent_path) - if mode_kind(mode) == 'directory': + if mode_kind(mode) == "directory": root_ie = self._get_dir_ie(encode_git_path(from_dir), parent_id) else: root_ie = self._get_file_ie( - store, encode_git_path(from_dir), - posixpath.basename(from_dir), mode, hexsha) + store, + encode_git_path(from_dir), + posixpath.basename(from_dir), + mode, + hexsha, + ) if include_root: yield (from_dir, "V", root_ie.kind, root_ie) todo = [] - if root_ie.kind == 'directory': - todo.append((store, encode_git_path(from_dir), - b"", hexsha, root_ie.file_id)) + if root_ie.kind == "directory": + todo.append( + (store, encode_git_path(from_dir), b"", hexsha, root_ie.file_id) + ) while todo: (store, path, relpath, hexsha, parent_id) = todo.pop() tree = store[hexsha] @@ -527,14 +564,17 @@ def list_files(self, include_root=False, from_dir=None, recursive=True, ie = self._get_dir_ie(child_path, parent_id) if recursive: todo.append( - (store, child_path, child_relpath, hexsha, - ie.file_id)) + (store, child_path, child_relpath, hexsha, ie.file_id) + ) else: ie = self._get_file_ie( - store, child_path, name, mode, hexsha, parent_id) + store, child_path, name, mode, hexsha, parent_id + ) yield (decode_git_path(child_relpath), "V", ie.kind, ie) - def _get_file_ie(self, store, path: str, name: str, mode: int, hexsha: bytes, parent_id): + def _get_file_ie( + self, store, path: str, name: str, mode: int, hexsha: bytes, parent_id + ): if not isinstance(path, bytes): raise TypeError(path) if not isinstance(name, bytes): @@ -544,11 +584,10 @@ def _get_file_ie(self, store, path: str, name: str, mode: int, hexsha: bytes, pa name = decode_git_path(name) file_id = self.mapping.generate_file_id(path) ie = entry_factory[kind](file_id, name, parent_id, git_sha1=hexsha) - if kind == 'symlink': + if kind == "symlink": ie.symlink_target = decode_git_path(store[hexsha].data) - elif kind == 'tree-reference': - ie.reference_revision = self.mapping.revision_id_foreign_to_bzr( - hexsha) + elif kind == "tree-reference": + ie.reference_revision = self.mapping.revision_id_foreign_to_bzr(hexsha) else: ie.git_sha1 = hexsha ie.text_size = None @@ -576,20 +615,17 @@ def iter_child_entries(self, path: str): if stat.S_ISDIR(mode): yield self._get_dir_ie(child_path, file_id) else: - yield self._get_file_ie(store, child_path, name, mode, hexsha, - file_id) + yield self._get_file_ie(store, child_path, name, mode, hexsha, file_id) - def iter_entries_by_dir(self, specific_files=None, - recurse_nested=False): + def iter_entries_by_dir(self, specific_files=None, recurse_nested=False): if self.tree is None: return if specific_files is not None: if specific_files in ([""], []): specific_files = None else: - specific_files = {encode_git_path(p) - for p in specific_files} - todo = deque([(self.store, b"", self.tree, self.path2id(''))]) + specific_files = {encode_git_path(p) for p in specific_files} + todo = deque([(self.store, b"", self.tree, self.path2id(""))]) if specific_files is None or "" in specific_files: yield "", self._get_dir_ie(b"", None) while todo: @@ -612,26 +648,36 @@ def iter_entries_by_dir(self, specific_files=None, else: substore = store if stat.S_ISDIR(mode): - if (specific_files is None or - any(p for p in specific_files if p.startswith( - child_path))): + if specific_files is None or any( + p for p in specific_files if p.startswith(child_path) + ): extradirs.append( - (substore, child_path, hexsha, - self.path2id(child_path_decoded))) + ( + substore, + child_path, + hexsha, + 
self.path2id(child_path_decoded), + ) + ) if specific_files is None or child_path in specific_files: if stat.S_ISDIR(mode): - yield (child_path_decoded, - self._get_dir_ie(child_path, parent_id)) + yield ( + child_path_decoded, + self._get_dir_ie(child_path, parent_id), + ) else: - yield (child_path_decoded, - self._get_file_ie(substore, child_path, name, mode, - hexsha, parent_id)) + yield ( + child_path_decoded, + self._get_file_ie( + substore, child_path, name, mode, hexsha, parent_id + ), + ) todo.extendleft(reversed(extradirs)) def iter_references(self): if self.supports_tree_reference(): for path, entry in self.iter_entries_by_dir(): - if entry.kind == 'tree-reference': + if entry.kind == "tree-reference": yield path def get_revision_id(self): @@ -695,27 +741,24 @@ def path_content_summary(self, path): try: (store, mode, hexsha) = self._lookup_path(path) except _mod_transport.NoSuchFile: - return ('missing', None, None, None) + return ("missing", None, None, None) kind = mode_kind(mode) - if kind == 'file': + if kind == "file": executable = mode_is_executable(mode) contents = store[hexsha].data - return (kind, len(contents), executable, - osutils.sha_string(contents)) - elif kind == 'symlink': + return (kind, len(contents), executable, osutils.sha_string(contents)) + elif kind == "symlink": return (kind, None, None, decode_git_path(store[hexsha].data)) - elif kind == 'tree-reference': + elif kind == "tree-reference": nested_repo = self._get_submodule_repository(encode_git_path(path)) - return (kind, None, None, - nested_repo.lookup_foreign_revision_id(hexsha)) + return (kind, None, None, nested_repo.lookup_foreign_revision_id(hexsha)) else: return (kind, None, None, None) def _iter_tree_contents(self, include_trees=False): if self.tree is None: return iter([]) - return iter_tree_contents( - self.store, self.tree, include_trees=include_trees) + return iter_tree_contents(self.store, self.tree, include_trees=include_trees) def annotate_iter(self, path, default_revision=CURRENT_REVISION): """Return an iterator of revision_id, line tuples. 
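Illustrative aside, not part of the patch: a minimal sketch of how the read-only GitRevisionTree API reworked in the hunks above is typically consumed, assuming an environment where breezy's git support is registered (as it is under the brz command). The branch location is hypothetical.

    from breezy.branch import Branch

    branch = Branch.open("https://example.com/some-repo.git")   # hypothetical URL
    revtree = branch.repository.revision_tree(branch.last_revision())
    with revtree.lock_read():
        for path, entry in revtree.iter_entries_by_dir():
            # entry is one of the GitTree* entry classes defined earlier in this file
            print(path, entry.kind)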
@@ -731,11 +774,14 @@ def annotate_iter(self, path, default_revision=CURRENT_REVISION): # Now we have the parents of this content from ..annotate import Annotator from .annotate import AnnotateProvider - annotator = Annotator(AnnotateProvider( - self._repository._file_change_scanner)) + + annotator = Annotator( + AnnotateProvider(self._repository._file_change_scanner) + ) this_key = (path, self.get_file_revision(path)) - annotations = [(key[-1], line) - for key, line in annotator.annotate_flat(this_key)] + annotations = [ + (key[-1], line) for key, line in annotator.annotate_flat(this_key) + ] return annotations def _get_rules_searcher(self, default_searcher): @@ -743,8 +789,7 @@ def _get_rules_searcher(self, default_searcher): def walkdirs(self, prefix=""): (store, mode, hexsha) = self._lookup_path(prefix) - todo = deque( - [(store, encode_git_path(prefix), hexsha)]) + todo = deque([(store, encode_git_path(prefix), hexsha)]) while todo: store, path, tree_sha = todo.popleft() path_decoded = decode_git_path(path) @@ -757,16 +802,26 @@ def walkdirs(self, prefix=""): if stat.S_ISDIR(mode): todo.append((store, child_path, hexsha)) children.append( - (decode_git_path(child_path), decode_git_path(name), - mode_kind(mode), None, - mode_kind(mode))) + ( + decode_git_path(child_path), + decode_git_path(name), + mode_kind(mode), + None, + mode_kind(mode), + ) + ) yield path_decoded, children -def tree_delta_from_git_changes(changes, mappings, - specific_files=None, - require_versioned=False, include_root=False, - source_extras=None, target_extras=None): +def tree_delta_from_git_changes( + changes, + mappings, + specific_files=None, + require_versioned=False, + include_root=False, + source_extras=None, + target_extras=None, +): """Create a TreeDelta from two git trees. 
source and target are iterators over tuples with: @@ -779,12 +834,12 @@ def tree_delta_from_git_changes(changes, mappings, source_extras = set() ret = delta.TreeDelta() added = [] - for (change_type, old, new) in changes: + for change_type, old, new in changes: (oldpath, oldmode, oldsha) = old (newpath, newmode, newsha) = new - if newpath == b'' and not include_root: + if newpath == b"" and not include_root: continue - copied = (change_type == 'copy') + copied = change_type == "copy" if oldpath is not None: oldpath_decoded = decode_git_path(oldpath) else: @@ -793,13 +848,17 @@ def tree_delta_from_git_changes(changes, mappings, newpath_decoded = decode_git_path(newpath) else: newpath_decoded = None - if not (specific_files is None or - (oldpath is not None and - osutils.is_inside_or_parent_of_any( - specific_files, oldpath_decoded)) or - (newpath is not None and - osutils.is_inside_or_parent_of_any( - specific_files, newpath_decoded))): + if not ( + specific_files is None + or ( + oldpath is not None + and osutils.is_inside_or_parent_of_any(specific_files, oldpath_decoded) + ) + or ( + newpath is not None + and osutils.is_inside_or_parent_of_any(specific_files, newpath_decoded) + ) + ): continue if oldpath is None: @@ -809,16 +868,16 @@ def tree_delta_from_git_changes(changes, mappings, oldparent = None oldversioned = False else: - oldversioned = (oldpath not in source_extras) + oldversioned = oldpath not in source_extras if oldmode: oldexe = mode_is_executable(oldmode) oldkind = mode_kind(oldmode) else: oldexe = False oldkind = None - if oldpath == b'': + if oldpath == b"": oldparent = None - oldname = '' + oldname = "" else: (oldparentpath, oldname) = osutils.split(oldpath_decoded) oldparent = old_mapping.generate_file_id(oldparentpath) @@ -829,16 +888,16 @@ def tree_delta_from_git_changes(changes, mappings, newparent = None newversioned = False else: - newversioned = (newpath not in target_extras) + newversioned = newpath not in target_extras if newmode: newexe = mode_is_executable(newmode) newkind = mode_kind(newmode) else: newexe = False newkind = None - if newpath_decoded == '': + if newpath_decoded == "": newparent = None - newname = '' + newname = "" else: newparentpath, newname = osutils.split(newpath_decoded) newparent = new_mapping.generate_file_id(newparentpath) @@ -855,25 +914,30 @@ def tree_delta_from_git_changes(changes, mappings, if oldpath is None and newpath is None: continue change = InventoryTreeChange( - fileid, (oldpath_decoded, newpath_decoded), (oldsha != newsha), + fileid, + (oldpath_decoded, newpath_decoded), + (oldsha != newsha), (oldversioned, newversioned), - (oldparent, newparent), (oldname, newname), - (oldkind, newkind), (oldexe, newexe), - copied=copied) - if newpath is not None and not newversioned and newkind != 'directory': + (oldparent, newparent), + (oldname, newname), + (oldkind, newkind), + (oldexe, newexe), + copied=copied, + ) + if newpath is not None and not newversioned and newkind != "directory": change.file_id = None ret.unversioned.append(change) - elif change_type == 'add': + elif change_type == "add": added.append((newpath, newkind, newsha)) elif newpath is None or newmode == 0: ret.removed.append(change) - elif change_type == 'delete': + elif change_type == "delete": ret.removed.append(change) - elif change_type == 'copy': + elif change_type == "copy": if stat.S_ISDIR(oldmode) and stat.S_ISDIR(newmode): continue ret.copied.append(change) - elif change_type == 'rename': + elif change_type == "rename": if stat.S_ISDIR(oldmode) and 
stat.S_ISDIR(newmode): continue ret.renamed.append(change) @@ -886,32 +950,43 @@ def tree_delta_from_git_changes(changes, mappings, else: ret.unchanged.append(change) - implicit_dirs = {''} + implicit_dirs = {""} for path, kind, _sha in added: - if kind == 'directory' or path in target_extras: + if kind == "directory" or path in target_extras: continue implicit_dirs.update(osutils.parent_directories(path)) for path, kind, _sha in added: path_decoded = decode_git_path(path) - if kind == 'directory' and path_decoded not in implicit_dirs: + if kind == "directory" and path_decoded not in implicit_dirs: continue parent_path, basename = osutils.split(path_decoded) parent_id = new_mapping.generate_file_id(parent_path) file_id = new_mapping.generate_file_id(path_decoded) ret.added.append( InventoryTreeChange( - file_id, (None, path_decoded), True, + file_id, + (None, path_decoded), + True, (False, True), (None, parent_id), - (None, basename), (None, kind), (None, False))) + (None, basename), + (None, kind), + (None, False), + ) + ) return ret -def changes_from_git_changes(changes, mapping, specific_files=None, - include_unchanged=False, source_extras=None, - target_extras=None): +def changes_from_git_changes( + changes, + mapping, + specific_files=None, + include_unchanged=False, + source_extras=None, + target_extras=None, +): """Create a iter_changes-like generator from a git stream. source and target are iterators over tuples with: @@ -921,8 +996,8 @@ def changes_from_git_changes(changes, mapping, specific_files=None, target_extras = set() if source_extras is None: source_extras = set() - for (change_type, old, new) in changes: - if change_type == 'unchanged' and not include_unchanged: + for change_type, old, new in changes: + if change_type == "unchanged" and not include_unchanged: continue (oldpath, oldmode, oldsha) = old (newpath, newmode, newsha) = new @@ -934,13 +1009,17 @@ def changes_from_git_changes(changes, mapping, specific_files=None, newpath_decoded = decode_git_path(newpath) else: newpath_decoded = None - if not (specific_files is None or - (oldpath_decoded is not None and - osutils.is_inside_or_parent_of_any( - specific_files, oldpath_decoded)) or - (newpath_decoded is not None and - osutils.is_inside_or_parent_of_any( - specific_files, newpath_decoded))): + if not ( + specific_files is None + or ( + oldpath_decoded is not None + and osutils.is_inside_or_parent_of_any(specific_files, oldpath_decoded) + ) + or ( + newpath_decoded is not None + and osutils.is_inside_or_parent_of_any(specific_files, newpath_decoded) + ) + ): continue if oldpath is not None and mapping.is_special_file(oldpath): continue @@ -953,16 +1032,16 @@ def changes_from_git_changes(changes, mapping, specific_files=None, oldparent = None oldversioned = False else: - oldversioned = (oldpath not in source_extras) + oldversioned = oldpath not in source_extras if oldmode: oldexe = mode_is_executable(oldmode) oldkind = mode_kind(oldmode) else: oldexe = False oldkind = None - if oldpath_decoded == '': + if oldpath_decoded == "": oldparent = None - oldname = '' + oldname = "" else: (oldparentpath, oldname) = osutils.split(oldpath_decoded) oldparent = mapping.generate_file_id(oldparentpath) @@ -973,40 +1052,47 @@ def changes_from_git_changes(changes, mapping, specific_files=None, newparent = None newversioned = False else: - newversioned = (newpath not in target_extras) + newversioned = newpath not in target_extras if newmode: newexe = mode_is_executable(newmode) newkind = mode_kind(newmode) else: newexe = False newkind = 
None - if newpath_decoded == '': + if newpath_decoded == "": newparent = None - newname = '' + newname = "" else: newparentpath, newname = osutils.split(newpath_decoded) newparent = mapping.generate_file_id(newparentpath) - if (not include_unchanged and - oldkind == 'directory' and newkind == 'directory' and - oldpath_decoded == newpath_decoded): + if ( + not include_unchanged + and oldkind == "directory" + and newkind == "directory" + and oldpath_decoded == newpath_decoded + ): continue - if oldversioned and change_type != 'copy': + if oldversioned and change_type != "copy": fileid = mapping.generate_file_id(oldpath_decoded) elif newversioned: fileid = mapping.generate_file_id(newpath_decoded) else: fileid = None - if oldkind == 'directory' and newkind == 'directory': + if oldkind == "directory" and newkind == "directory": modified = False else: modified = (oldsha != newsha) or (oldmode != newmode) yield InventoryTreeChange( - fileid, (oldpath_decoded, newpath_decoded), + fileid, + (oldpath_decoded, newpath_decoded), modified, (oldversioned, newversioned), - (oldparent, newparent), (oldname, newname), - (oldkind, newkind), (oldexe, newexe), - copied=(change_type == 'copy')) + (oldparent, newparent), + (oldname, newname), + (oldkind, newkind), + (oldexe, newexe), + copied=(change_type == "copy"), + ) class InterGitTrees(_mod_tree.InterTree): @@ -1022,33 +1108,48 @@ def __init__(self, source: GitTree, target: GitTree) -> None: if self.source.store == self.target.store: self.store = self.source.store else: - self.store = OverlayObjectStore( - [self.source.store, self.target.store]) + self.store = OverlayObjectStore([self.source.store, self.target.store]) self.rename_detector = RenameDetector(self.store) @classmethod def is_compatible(cls, source, target): return isinstance(source, GitTree) and isinstance(target, GitTree) - def compare(self, want_unchanged=False, specific_files=None, - extra_trees=None, require_versioned=False, include_root=False, - want_unversioned=False): + def compare( + self, + want_unchanged=False, + specific_files=None, + extra_trees=None, + require_versioned=False, + include_root=False, + want_unversioned=False, + ): with self.lock_read(): changes, source_extras, target_extras = self._iter_git_changes( want_unchanged=want_unchanged, require_versioned=require_versioned, specific_files=specific_files, extra_trees=extra_trees, - want_unversioned=want_unversioned) + want_unversioned=want_unversioned, + ) return tree_delta_from_git_changes( - changes, (self.source.mapping, self.target.mapping), + changes, + (self.source.mapping, self.target.mapping), specific_files=specific_files, include_root=include_root, - source_extras=source_extras, target_extras=target_extras) - - def iter_changes(self, include_unchanged=False, specific_files=None, - pb=None, extra_trees=None, require_versioned=True, - want_unversioned=False): + source_extras=source_extras, + target_extras=target_extras, + ) + + def iter_changes( + self, + include_unchanged=False, + specific_files=None, + pb=None, + extra_trees=None, + require_versioned=True, + want_unversioned=False, + ): if extra_trees is None: extra_trees = [] with self.lock_read(): @@ -1057,51 +1158,65 @@ def iter_changes(self, include_unchanged=False, specific_files=None, require_versioned=require_versioned, specific_files=specific_files, extra_trees=extra_trees, - want_unversioned=want_unversioned) + want_unversioned=want_unversioned, + ) return changes_from_git_changes( - changes, self.target.mapping, + changes, + self.target.mapping, 
specific_files=specific_files, include_unchanged=include_unchanged, source_extras=source_extras, - target_extras=target_extras) - - def _iter_git_changes(self, want_unchanged=False, specific_files=None, - require_versioned=False, extra_trees=None, - want_unversioned=False, include_trees=True): + target_extras=target_extras, + ) + + def _iter_git_changes( + self, + want_unchanged=False, + specific_files=None, + require_versioned=False, + extra_trees=None, + want_unversioned=False, + include_trees=True, + ): trees = [self.source] if extra_trees is not None: trees.extend(extra_trees) if specific_files is not None: specific_files = self.target.find_related_paths_across_trees( - specific_files, trees, - require_versioned=require_versioned) + specific_files, trees, require_versioned=require_versioned + ) # TODO(jelmer): Restrict to specific_files, for performance reasons. with self.lock_read(): from_tree_sha, from_extras = self.source.git_snapshot( - want_unversioned=want_unversioned) + want_unversioned=want_unversioned + ) to_tree_sha, to_extras = self.target.git_snapshot( - want_unversioned=want_unversioned) + want_unversioned=want_unversioned + ) changes = tree_changes( - self.store, from_tree_sha, to_tree_sha, + self.store, + from_tree_sha, + to_tree_sha, include_trees=include_trees, rename_detector=self.rename_detector, - want_unchanged=want_unchanged, change_type_same=True) + want_unchanged=want_unchanged, + change_type_same=True, + ) return changes, from_extras, to_extras - def find_target_path(self, path, recurse='none'): + def find_target_path(self, path, recurse="none"): ret = self.find_target_paths([path], recurse=recurse) return ret[path] - def find_source_path(self, path, recurse='none'): + def find_source_path(self, path, recurse="none"): ret = self.find_source_paths([path], recurse=recurse) return ret[path] - def find_target_paths(self, paths, recurse='none'): + def find_target_paths(self, paths, recurse="none"): paths = set(paths) ret = {} - changes = self._iter_git_changes( - specific_files=paths, include_trees=False)[0] - for (_change_type, old, new) in changes: + changes = self._iter_git_changes(specific_files=paths, include_trees=False)[0] + for _change_type, old, new in changes: if old[0] is None: continue oldpath = decode_git_path(old[0]) @@ -1118,12 +1233,11 @@ def find_target_paths(self, paths, recurse='none'): raise _mod_transport.NoSuchFile(path) return ret - def find_source_paths(self, paths, recurse='none'): + def find_source_paths(self, paths, recurse="none"): paths = set(paths) ret = {} - changes = self._iter_git_changes( - specific_files=paths, include_trees=False)[0] - for (_change_type, old, new) in changes: + changes = self._iter_git_changes(specific_files=paths, include_trees=False)[0] + for _change_type, old, new in changes: if new[0] is None: continue newpath = decode_git_path(new[0]) @@ -1145,7 +1259,6 @@ def find_source_paths(self, paths, recurse='none'): class MutableGitIndexTree(mutabletree.MutableTree, GitTree): - store: BaseObjectStore def __init__(self): @@ -1160,9 +1273,9 @@ def git_snapshot(self, want_unversioned=False): def is_versioned(self, path): with self.lock_read(): - path = encode_git_path(path.rstrip('/')) + path = encode_git_path(path.rstrip("/")) (index, subpath) = self._lookup_index(path) - return (subpath in index or self._has_dir(path)) + return subpath in index or self._has_dir(path) def _has_dir(self, path): if not isinstance(path, bytes): @@ -1191,10 +1304,9 @@ def _ensure_versioned_dir(self, dirname): def path2id(self, path): with 
self.lock_read(): - path = path.rstrip('/') - if self.is_versioned(path.rstrip('/')): - return self.mapping.generate_file_id( - osutils.safe_unicode(path)) + path = path.rstrip("/") + if self.is_versioned(path.rstrip("/")): + return self.mapping.generate_file_id(osutils.safe_unicode(path)) return None def add(self, files, kinds=None): @@ -1220,7 +1332,7 @@ def add(self, files, kinds=None): if kinds is not None: kinds = [kinds] - files = [path.strip('/') for path in files] + files = [path.strip("/") for path in files] if kinds is None: kinds = [None] * len(files) @@ -1236,7 +1348,7 @@ def add(self, files, kinds=None): # caring about the instantaneous file kind within a uncommmitted tree # self._gather_kinds(files, kinds) - for (path, kind) in zip(files, kinds): + for path, kind in zip(files, kinds): path, can_access = osutils.normalized_filename(path) if not can_access: raise errors.InvalidNormalization(path) @@ -1260,9 +1372,9 @@ def _lookup_index(self, encoded_path): index = self.index remaining_path = encoded_path while True: - parts = remaining_path.split(b'/') + parts = remaining_path.split(b"/") for i in range(1, len(parts)): - basepath = b'/'.join(parts[:i]) + basepath = b"/".join(parts[:i]) try: value = index[basepath] except KeyError: @@ -1270,7 +1382,7 @@ def _lookup_index(self, encoded_path): else: if S_ISGITLINK(value.mode): index = self._get_submodule_index(basepath) - remaining_path = b'/'.join(parts[i:]) + remaining_path = b"/".join(parts[i:]) break else: return index, remaining_path @@ -1284,11 +1396,9 @@ def _index_del_entry(self, index, path): self._index_dirty = True def _apply_index_changes(self, changes): - for (path, kind, _executability, reference_revision, - symlink_target) in changes: - if kind is None or kind == 'directory': - (index, subpath) = self._lookup_index( - encode_git_path(path)) + for path, kind, _executability, reference_revision, symlink_target in changes: + if kind is None or kind == "directory": + (index, subpath) = self._lookup_index(encode_git_path(path)) try: self._index_del_entry(index, subpath) except KeyError: @@ -1297,14 +1407,16 @@ def _apply_index_changes(self, changes): self._versioned_dirs = None else: self._index_add_entry( - path, kind, + path, + kind, reference_revision=reference_revision, - symlink_target=symlink_target) + symlink_target=symlink_target, + ) self.flush() def _index_add_entry( - self, path, kind, reference_revision=None, - symlink_target=None): + self, path, kind, reference_revision=None, symlink_target=None + ): if kind == "directory": # Git indexes don't contain directories return @@ -1317,7 +1429,8 @@ def _index_add_entry( # index file = BytesIO() stat_val = os.stat_result( - (stat.S_IFREG | 0o644, 0, 0, 0, 0, 0, 0, 0, 0, 0)) + (stat.S_IFREG | 0o644, 0, 0, 0, 0, 0, 0, 0, 0, 0) + ) with file: blob.set_raw_string(file.read()) # Add object to the repository if it didn't exist yet @@ -1331,8 +1444,7 @@ def _index_add_entry( except OSError: # TODO: Rather than come up with something here, use the # old index - stat_val = os.stat_result( - (stat.S_IFLNK, 0, 0, 0, 0, 0, 0, 0, 0, 0)) + stat_val = os.stat_result((stat.S_IFLNK, 0, 0, 0, 0, 0, 0, 0, 0, 0)) if symlink_target is None: symlink_target = self.get_symlink_target(path) blob.set_raw_string(encode_git_path(symlink_target)) @@ -1342,8 +1454,7 @@ def _index_add_entry( hexsha = blob.id elif kind == "tree-reference": if reference_revision is not None: - hexsha = self.branch.lookup_bzr_revision_id( - reference_revision)[0] + hexsha = 
self.branch.lookup_bzr_revision_id(reference_revision)[0] else: hexsha = self._read_submodule_head(path) if hexsha is None: @@ -1351,17 +1462,16 @@ def _index_add_entry( try: stat_val = self._lstat(path) except OSError: - stat_val = os.stat_result( - (S_IFGITLINK, 0, 0, 0, 0, 0, 0, 0, 0, 0)) - stat_val = os.stat_result((S_IFGITLINK, ) + stat_val[1:]) + stat_val = os.stat_result((S_IFGITLINK, 0, 0, 0, 0, 0, 0, 0, 0, 0)) + stat_val = os.stat_result((S_IFGITLINK,) + stat_val[1:]) else: raise AssertionError(f"unknown kind '{kind}'") # Add an entry to the index or update the existing entry ensure_normalized_path(path) encoded_path = encode_git_path(path) - if b'\r' in encoded_path or b'\n' in encoded_path: + if b"\r" in encoded_path or b"\n" in encoded_path: # TODO(jelmer): Why do we need to do this? - trace.mutter('ignoring path with invalid newline in it: %r', path) + trace.mutter("ignoring path with invalid newline in it: %r", path) return (index, index_path) = self._lookup_index(encoded_path) index[index_path] = index_entry_from_stat(stat_val, hexsha) @@ -1373,8 +1483,7 @@ def iter_git_objects(self): for p, entry in self._recurse_index_entries(): yield p, entry.sha, entry.mode - def _recurse_index_entries(self, index=None, basepath=b"", - recurse_nested=False): + def _recurse_index_entries(self, index=None, basepath=b"", recurse_nested=False): # Iterate over all index entries with self.lock_read(): if index is None: @@ -1389,13 +1498,12 @@ def _recurse_index_entries(self, index=None, basepath=b"", if S_ISGITLINK(mode) and recurse_nested: subindex = self._get_submodule_index(path) yield from self._recurse_index_entries( - index=subindex, basepath=path, - recurse_nested=recurse_nested) + index=subindex, basepath=path, recurse_nested=recurse_nested + ) else: yield (posixpath.join(basepath, path), value) - def iter_entries_by_dir(self, specific_files=None, - recurse_nested=False): + def iter_entries_by_dir(self, specific_files=None, recurse_nested=False): with self.lock_read(): if specific_files is not None: specific_files = set(specific_files) @@ -1407,7 +1515,8 @@ def iter_entries_by_dir(self, specific_files=None, ret[("", "")] = root_ie dir_ids = {"": root_ie.file_id} for path, value in self._recurse_index_entries( - recurse_nested=recurse_nested): + recurse_nested=recurse_nested + ): if self.mapping.is_special_file(path): continue path = decode_git_path(path) @@ -1419,8 +1528,9 @@ def iter_entries_by_dir(self, specific_files=None, except _mod_transport.NoSuchFile: continue if specific_files is None: - for (dir_path, dir_ie) in self._add_missing_parent_ids( - parent, dir_ids): + for dir_path, dir_ie in self._add_missing_parent_ids( + parent, dir_ids + ): ret[(posixpath.dirname(dir_path), dir_path)] = dir_ie file_ie.parent_id = self.path2id(parent) ret[(posixpath.dirname(path), path)] = file_ie @@ -1430,21 +1540,27 @@ def iter_entries_by_dir(self, specific_files=None, key = (posixpath.dirname(path), path) if key not in ret and self.is_versioned(path): ret[key] = self._get_dir_ie(path, self.path2id(key[0])) - for ((_, path), ie) in sorted(ret.items()): + for (_, path), ie in sorted(ret.items()): yield path, ie def iter_references(self): if self.supports_tree_reference(): # TODO(jelmer): Implement a more efficient version of this for path, entry in self.iter_entries_by_dir(): - if entry.kind == 'tree-reference': + if entry.kind == "tree-reference": yield path def _get_dir_ie(self, path: str, parent_id) -> GitTreeDirectory: file_id = self.path2id(path) return GitTreeDirectory(file_id, 
posixpath.basename(path).strip("/"), parent_id) - def _get_file_ie(self, name: str, path: str, value: Union[IndexEntry, ConflictedIndexEntry], parent_id) -> Union[GitTreeSymlink, GitTreeDirectory, GitTreeFile, GitTreeSubmodule]: + def _get_file_ie( + self, + name: str, + path: str, + value: Union[IndexEntry, ConflictedIndexEntry], + parent_id, + ) -> Union[GitTreeSymlink, GitTreeDirectory, GitTreeFile, GitTreeSubmodule]: if not isinstance(name, str): raise TypeError(name) if not isinstance(path, str): @@ -1466,11 +1582,11 @@ def _get_file_ie(self, name: str, path: str, value: Union[IndexEntry, Conflicted raise TypeError(file_id) kind = mode_kind(mode) ie = entry_factory[kind](file_id, name, parent_id, git_sha1=sha) - if kind == 'symlink': + if kind == "symlink": ie.symlink_target = self.get_symlink_target(path) - elif kind == 'tree-reference': + elif kind == "tree-reference": ie.reference_revision = self.get_reference_revision(path) - elif kind == 'directory': + elif kind == "directory": pass else: ie.git_sha1 = sha @@ -1478,7 +1594,9 @@ def _get_file_ie(self, name: str, path: str, value: Union[IndexEntry, Conflicted ie.executable = bool(stat.S_ISREG(mode) and stat.S_IEXEC & mode) return ie - def _add_missing_parent_ids(self, path: str, dir_ids) -> List[Tuple[str, GitTreeDirectory]]: + def _add_missing_parent_ids( + self, path: str, dir_ids + ) -> List[Tuple[str, GitTreeDirectory]]: if path in dir_ids: return [] parent = posixpath.dirname(path).strip("/") @@ -1527,14 +1645,13 @@ def flush(self): def update_basis_by_delta(self, revid, delta): # TODO(jelmer): This shouldn't be called, it's inventory specific. - for (old_path, new_path, _file_id, ie) in delta: + for old_path, new_path, _file_id, ie in delta: if old_path is not None: - (index, old_subpath) = self._lookup_index( - encode_git_path(old_path)) + (index, old_subpath) = self._lookup_index(encode_git_path(old_path)) if old_subpath in index: self._index_del_entry(index, old_subpath) self._versioned_dirs = None - if new_path is not None and ie.kind != 'directory': + if new_path is not None and ie.kind != "directory": self._index_add_entry(new_path, ie.kind) self.flush() self._set_merges_from_parent_ids([]) @@ -1544,8 +1661,9 @@ def move(self, from_paths, to_dir=None, after=None): with self.lock_tree_write(): to_abs = self.abspath(to_dir) if not os.path.isdir(to_abs): - raise errors.BzrMoveFailedError('', to_dir, - errors.NotADirectory(to_abs)) + raise errors.BzrMoveFailedError( + "", to_dir, errors.NotADirectory(to_abs) + ) for from_rel in from_paths: from_tail = os.path.split(from_rel)[-1] @@ -1565,16 +1683,19 @@ def rename_one(self, from_rel, to_rel, after=None): if not after: # Perhaps it's already moved? 
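    # Illustrative aside, not part of the patch: when "after" is not passed in,
    # the expression just below treats the rename as having already happened on
    # disk if the source path is gone, the target path exists, and the target is
    # not versioned yet -- roughly the situation "brz mv --after OLD NEW" is meant
    # for, after the file has been moved outside of brz.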
after = ( - not self.has_filename(from_rel) and - self.has_filename(to_rel) and - not self.is_versioned(to_rel)) + not self.has_filename(from_rel) + and self.has_filename(to_rel) + and not self.is_versioned(to_rel) + ) if after: if not self.has_filename(to_rel): raise errors.BzrMoveFailedError( - from_rel, to_rel, _mod_transport.NoSuchFile(to_rel)) + from_rel, to_rel, _mod_transport.NoSuchFile(to_rel) + ) if self.basis_tree().is_versioned(to_rel): raise errors.BzrMoveFailedError( - from_rel, to_rel, errors.AlreadyVersionedError(to_rel)) + from_rel, to_rel, errors.AlreadyVersionedError(to_rel) + ) kind = self.kind(to_rel) else: @@ -1585,36 +1706,39 @@ def rename_one(self, from_rel, to_rel, after=None): else: exc_type = errors.BzrMoveFailedError if self.is_versioned(to_rel): - raise exc_type(from_rel, to_rel, - errors.AlreadyVersionedError(to_rel)) + raise exc_type( + from_rel, to_rel, errors.AlreadyVersionedError(to_rel) + ) if not self.has_filename(from_rel): raise errors.BzrMoveFailedError( - from_rel, to_rel, _mod_transport.NoSuchFile(from_rel)) + from_rel, to_rel, _mod_transport.NoSuchFile(from_rel) + ) kind = self.kind(from_rel) - if not self.is_versioned(from_rel) and kind != 'directory': - raise exc_type(from_rel, to_rel, - errors.NotVersionedError(from_rel)) + if not self.is_versioned(from_rel) and kind != "directory": + raise exc_type(from_rel, to_rel, errors.NotVersionedError(from_rel)) if self.has_filename(to_rel): raise errors.RenameFailedFilesExist( - from_rel, to_rel, _mod_transport.FileExists(to_rel)) + from_rel, to_rel, _mod_transport.FileExists(to_rel) + ) kind = self.kind(from_rel) - if not after and kind != 'directory': + if not after and kind != "directory": (index, from_subpath) = self._lookup_index(from_path) if from_subpath not in index: # It's not a file raise errors.BzrMoveFailedError( - from_rel, to_rel, - errors.NotVersionedError(path=from_rel)) + from_rel, to_rel, errors.NotVersionedError(path=from_rel) + ) if not after: try: self._rename_one(from_rel, to_rel) except FileNotFoundError as err: raise errors.BzrMoveFailedError( - from_rel, to_rel, _mod_transport.NoSuchFile(to_rel)) from err - if kind != 'directory': + from_rel, to_rel, _mod_transport.NoSuchFile(to_rel) + ) from err + if kind != "directory": (index, from_index_path) = self._lookup_index(from_path) try: self._index_del_entry(index, from_path) @@ -1622,18 +1746,24 @@ def rename_one(self, from_rel, to_rel, after=None): pass self._index_add_entry(to_rel, kind) else: - todo = [(p, i) for (p, i) in self._recurse_index_entries() - if p.startswith(from_path + b'/')] + todo = [ + (p, i) + for (p, i) in self._recurse_index_entries() + if p.startswith(from_path + b"/") + ] for child_path, child_value in todo: (child_to_index, child_to_index_path) = self._lookup_index( - posixpath.join(to_path, posixpath.relpath(child_path, from_path))) + posixpath.join( + to_path, posixpath.relpath(child_path, from_path) + ) + ) child_to_index[child_to_index_path] = child_value # TODO(jelmer): Mark individual index as dirty self._index_dirty = True (child_from_index, child_from_index_path) = self._lookup_index( - child_path) - self._index_del_entry( - child_from_index, child_from_index_path) + child_path + ) + self._index_del_entry(child_from_index, child_from_index_path) self._versioned_dirs = None self.flush() @@ -1644,29 +1774,28 @@ def path_content_summary(self, path): stat_result = self._lstat(path) except FileNotFoundError: # no file. 
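    # Illustrative aside, not part of the patch: path_content_summary() returns
    # 4-tuples of (kind, size, executable, sha_or_link_target); made-up examples
    # of the shapes produced by the branches below:
    #   ("file", 1234, False, "<sha from the stat cache, possibly None>")
    #   ("directory", None, None, None)          # or "tree-reference" for submodules
    #   ("symlink", None, None, "target/path")
    #   ("missing", None, None, None)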
- return ('missing', None, None, None) + return ("missing", None, None, None) kind = mode_kind(stat_result.st_mode) - if kind == 'file': + if kind == "file": size = stat_result.st_size executable = self._is_executable_from_path_and_stat(path, stat_result) # try for a stat cache lookup - return ('file', size, executable, self._sha_from_stat( - path, stat_result)) - elif kind == 'directory': + return ("file", size, executable, self._sha_from_stat(path, stat_result)) + elif kind == "directory": # perhaps it looks like a plain directory, but it's really a # reference. if self._directory_is_tree_reference(path): - kind = 'tree-reference' + kind = "tree-reference" return kind, None, None, None - elif kind == 'symlink': + elif kind == "symlink": target = osutils.readlink(self.abspath(path)) - return ('symlink', None, None, target) + return ("symlink", None, None, target) else: return (kind, None, None, None) def stored_kind(self, relpath): - if relpath == '': - return 'directory' + if relpath == "": + return "directory" (index, index_path) = self._lookup_index(encode_git_path(relpath)) if index is None: return None @@ -1674,19 +1803,18 @@ def stored_kind(self, relpath): mode = index[index_path].mode except KeyError: for p in index: - if osutils.is_inside( - decode_git_path(index_path), decode_git_path(p)): - return 'directory' + if osutils.is_inside(decode_git_path(index_path), decode_git_path(p)): + return "directory" return None else: return mode_kind(mode) def kind(self, relpath): kind = file_kind(self.abspath(relpath)) - if kind == 'directory': + if kind == "directory": if self._directory_is_tree_reference(relpath): - return 'tree-reference' - return 'directory' + return "tree-reference" + return "directory" else: return kind @@ -1695,6 +1823,7 @@ def _live_entry(self, relpath): def transform(self, pb=None): from .transform import GitTreeTransform + return GitTreeTransform(self, pb=pb) def has_changes(self, _from_tree=None): @@ -1716,7 +1845,7 @@ def has_changes(self, _from_tree=None): # Fast path for has_changes. try: change = next(changes) - if change.path[1] == '': + if change.path[1] == "": next(changes) return True except StopIteration: @@ -1733,8 +1862,8 @@ def has_changes(self, _from_tree=None): # working copy as compared to the repository. # Also, exclude root as mention in the above fast path. 
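    # Illustrative aside, not part of the patch: in the filter below each change c
    # is an InventoryTreeChange; c[4] is its (old, new) parent_id pair and c[6] its
    # (old, new) kind pair, so the predicate drops the tree root (no parent on
    # either side) and entries whose old kind is a symlink, as the comments above
    # describe.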
changes = filter( - lambda c: c[6][0] != 'symlink' and c[4] != (None, None), - changes) + lambda c: c[6][0] != "symlink" and c[4] != (None, None), changes + ) try: next(iter(changes)) except StopIteration: @@ -1742,7 +1871,9 @@ def has_changes(self, _from_tree=None): return True -def snapshot_workingtree(target: MutableGitIndexTree, want_unversioned: bool = False) -> Tuple[ObjectID, Set[bytes]]: +def snapshot_workingtree( + target: MutableGitIndexTree, want_unversioned: bool = False +) -> Tuple[ObjectID, Set[bytes]]: """Snapshot a working tree into a tree object.""" extras = set() blobs = {} @@ -1751,7 +1882,7 @@ def snapshot_workingtree(target: MutableGitIndexTree, want_unversioned: bool = F dirified = [] trust_executable = target._supports_executable() # type: ignore for path, index_entry in target._recurse_index_entries(): - index_entry = getattr(index_entry, 'this', index_entry) + index_entry = getattr(index_entry, "this", index_entry) try: live_entry = target._live_entry(path) except FileNotFoundError: @@ -1798,12 +1929,12 @@ def snapshot_workingtree(target: MutableGitIndexTree, want_unversioned: bool = F if stat.S_ISDIR(st.st_mode): obj = Tree() elif stat.S_ISREG(st.st_mode) or stat.S_ISLNK(st.st_mode): - obj = blob_from_path_and_stat( - os.fsencode(target.abspath(extra)), st) # type: ignore + obj = blob_from_path_and_stat(os.fsencode(target.abspath(extra)), st) # type: ignore else: continue target.store.add_object(obj) blobs[np] = (obj.id, cleanup_mode(st.st_mode)) extras.add(np) return commit_tree( - target.store, dirified + [(p, s, m) for (p, (s, m)) in blobs.items()]), extras + target.store, dirified + [(p, s, m) for (p, (s, m)) in blobs.items()] + ), extras diff --git a/breezy/git/unpeel_map.py b/breezy/git/unpeel_map.py index 5e19a44db1..99b9407c35 100644 --- a/breezy/git/unpeel_map.py +++ b/breezy/git/unpeel_map.py @@ -44,8 +44,7 @@ def update(self, m): def load(self, f): firstline = f.readline() if firstline != b"unpeel map version 1\n": - raise AssertionError( - f"invalid format for unpeel map: {firstline!r}") + raise AssertionError(f"invalid format for unpeel map: {firstline!r}") for l in f.readlines(): (k, v) = l.split(b":", 1) k = k.strip() diff --git a/breezy/git/urls.py b/breezy/git/urls.py index 3fe0fa48ff..68563dbe85 100644 --- a/breezy/git/urls.py +++ b/breezy/git/urls.py @@ -23,35 +23,35 @@ from .._git_rs import bzr_url_to_git_url # noqa: F401 from .refs import ref_to_branch_name -KNOWN_GIT_SCHEMES = ['git+ssh', 'git', 'http', 'https', 'ftp', 'ssh'] +KNOWN_GIT_SCHEMES = ["git+ssh", "git", "http", "https", "ftp", "ssh"] SCHEME_REPLACEMENT = { - 'ssh': 'git+ssh', - } + "ssh": "git+ssh", +} def git_url_to_bzr_url(location, branch=None, ref=None): if branch is not None and ref is not None: - raise ValueError('only specify one of branch or ref') + raise ValueError("only specify one of branch or ref") url = urlutils.URL.from_string(location) - if (url.scheme not in KNOWN_GIT_SCHEMES and - not url.scheme.startswith('chroot-')): + if url.scheme not in KNOWN_GIT_SCHEMES and not url.scheme.startswith("chroot-"): try: (username, host, path) = parse_rsync_url(location) except ValueError: return location else: url = urlutils.URL( - scheme='git+ssh', + scheme="git+ssh", quoted_user=(urlutils.quote(username) if username else None), quoted_password=None, quoted_host=urlutils.quote(host), port=None, - quoted_path=urlutils.quote(path, safe="/~")) + quoted_path=urlutils.quote(path, safe="/~"), + ) location = str(url) elif url.scheme in SCHEME_REPLACEMENT: url.scheme = 
SCHEME_REPLACEMENT[url.scheme] location = str(url) - if ref == b'HEAD': + if ref == b"HEAD": ref = branch = None if ref: try: @@ -63,8 +63,8 @@ def git_url_to_bzr_url(location, branch=None, ref=None): if ref or branch: params = {} if ref: - params['ref'] = urlutils.quote_from_bytes(ref, safe='') + params["ref"] = urlutils.quote_from_bytes(ref, safe="") if branch: - params['branch'] = urlutils.escape(branch, safe='') + params["branch"] = urlutils.escape(branch, safe="") location = urlutils.join_segment_parameters(location, params) return location diff --git a/breezy/git/workingtree.py b/breezy/git/workingtree.py index 64d571e6d2..ebf0b43495 100644 --- a/breezy/git/workingtree.py +++ b/breezy/git/workingtree.py @@ -56,7 +56,7 @@ from .mapping import decode_git_path, encode_git_path, mode_kind from .tree import MutableGitIndexTree -CONFLICT_SUFFIXES = ['.BASE', '.OTHER', '.THIS'] +CONFLICT_SUFFIXES = [".BASE", ".OTHER", ".THIS"] # TODO: There should be a base revid attribute to better inform the user about @@ -66,15 +66,15 @@ class TextConflict(_mod_conflicts.Conflict): has_files = True - typestring = 'text conflict' + typestring = "text conflict" - _conflict_re = re.compile(b'^(<{7}|={7}|>{7})') + _conflict_re = re.compile(b"^(<{7}|={7}|>{7})") def __init__(self, path): super().__init__(path) def associated_filenames(self): - return [self.path + suffix for suffix in ('.BASE', '.OTHER', '.THIS')] + return [self.path + suffix for suffix in (".BASE", ".OTHER", ".THIS")] def _resolve(self, tt, winner_suffix): """Resolve the conflict by copying one of .THIS or .OTHER into file. @@ -90,14 +90,12 @@ def _resolve(self, tt, winner_suffix): # item will exist after the conflict has been resolved anyway. item_tid = tt.trans_id_tree_path(self.path) item_parent_tid = tt.get_tree_parent(item_tid) - winner_path = self.path + '.' + winner_suffix + winner_path = self.path + "." + winner_suffix winner_tid = tt.trans_id_tree_path(winner_path) winner_parent_tid = tt.get_tree_parent(winner_tid) # Switch the paths to preserve the content - tt.adjust_path(osutils.basename(self.path), - winner_parent_tid, winner_tid) - tt.adjust_path(osutils.basename(winner_path), - item_parent_tid, item_tid) + tt.adjust_path(osutils.basename(self.path), winner_parent_tid, winner_tid) + tt.adjust_path(osutils.basename(winner_path), item_parent_tid, item_tid) tt.unversion_file(item_tid) tt.version_file(winner_tid) tt.apply() @@ -109,7 +107,7 @@ def action_auto(self, tree): kind = tree.kind(self.path) except _mod_transport.NoSuchFile: return - if kind != 'file': + if kind != "file": raise NotImplementedError("Conflict is not a file") conflict_markers_in_line = self._conflict_re.search with tree.get_file(self.path) as f: @@ -122,10 +120,10 @@ def _resolve_with_cleanups(self, tree, *args, **kwargs): self._resolve(tt, *args, **kwargs) def action_take_this(self, tree): - self._resolve_with_cleanups(tree, 'THIS') + self._resolve_with_cleanups(tree, "THIS") def action_take_other(self, tree): - self._resolve_with_cleanups(tree, 'OTHER') + self._resolve_with_cleanups(tree, "OTHER") def do(self, action, tree): """Apply the specified action to the conflict. @@ -134,9 +132,9 @@ def do(self, action, tree): :param tree: The tree passed as a parameter to the method. """ - meth = getattr(self, f'action_{action}', None) + meth = getattr(self, f"action_{action}", None) if meth is None: - raise NotImplementedError(self.__class__.__name__ + '.' + action) + raise NotImplementedError(self.__class__.__name__ + "." 
+ action) meth(tree) def action_done(self, tree): @@ -162,15 +160,15 @@ def to_index_entry(self, tree): """Convert the conflict to a Git index entry.""" encoded_path = encode_git_path(tree.abspath(self.path)) try: - base = index_entry_from_path(encoded_path + b'.BASE') + base = index_entry_from_path(encoded_path + b".BASE") except FileNotFoundError: base = None try: - other = index_entry_from_path(encoded_path + b'.OTHER') + other = index_entry_from_path(encoded_path + b".OTHER") except FileNotFoundError: other = None try: - this = index_entry_from_path(encoded_path + b'.THIS') + this = index_entry_from_path(encoded_path + b".THIS") except FileNotFoundError: this = None return ConflictedIndexEntry(this=this, other=other, ancestor=base) @@ -181,15 +179,15 @@ class ContentsConflict(_mod_conflicts.Conflict): has_files = True - typestring = 'contents conflict' + typestring = "contents conflict" - format = 'Contents conflict in %(path)s' + format = "Contents conflict in %(path)s" def __init__(self, path, conflict_path=None): - for suffix in ('.BASE', '.THIS', '.OTHER'): + for suffix in (".BASE", ".THIS", ".OTHER"): if path.endswith(suffix): # Here is the raw path - path = path[:-len(suffix)] + path = path[: -len(suffix)] break _mod_conflicts.Conflict.__init__(self, path) self.conflict_path = conflict_path @@ -198,7 +196,7 @@ def _revision_tree(self, tree, revid): return tree.branch.repository.revision_tree(revid) def associated_filenames(self): - return [self.path + suffix for suffix in ('.BASE', '.OTHER', '.THIS')] + return [self.path + suffix for suffix in (".BASE", ".OTHER", ".THIS")] def _resolve(self, tt, suffix_to_remove): """Resolve the conflict. @@ -213,7 +211,8 @@ def _resolve(self, tt, suffix_to_remove): # Delete 'item.THIS' or 'item.OTHER' depending on # suffix_to_remove tt.delete_contents( - tt.trans_id_tree_path(self.path + '.' + suffix_to_remove)) + tt.trans_id_tree_path(self.path + "." 
+ suffix_to_remove) + ) except _mod_transport.NoSuchFile: # There are valid cases where 'item.suffix_to_remove' either # never existed or was already deleted (including the case @@ -242,10 +241,10 @@ def _resolve_with_cleanups(self, tree, *args, **kwargs): self._resolve(tt, *args, **kwargs) def action_take_this(self, tree): - self._resolve_with_cleanups(tree, 'OTHER') + self._resolve_with_cleanups(tree, "OTHER") def action_take_other(self, tree): - self._resolve_with_cleanups(tree, 'THIS') + self._resolve_with_cleanups(tree, "THIS") @classmethod def from_index_entry(cls, entry): @@ -259,15 +258,15 @@ def to_index_entry(self, tree): """Convert the conflict to a Git index entry.""" encoded_path = encode_git_path(tree.abspath(self.path)) try: - base = index_entry_from_path(encoded_path + b'.BASE') + base = index_entry_from_path(encoded_path + b".BASE") except FileNotFoundError: base = None try: - other = index_entry_from_path(encoded_path + b'.OTHER') + other = index_entry_from_path(encoded_path + b".OTHER") except FileNotFoundError: other = None try: - this = index_entry_from_path(encoded_path + b'.THIS') + this = index_entry_from_path(encoded_path + b".THIS") except FileNotFoundError: this = None return ConflictedIndexEntry(this=this, other=other, ancestor=base) @@ -278,7 +277,7 @@ class GitWorkingTree(MutableGitIndexTree, workingtree.WorkingTree): def __init__(self, controldir, repo, branch): MutableGitIndexTree.__init__(self) - basedir = controldir.root_transport.local_abspath('.') + basedir = controldir.root_transport.local_abspath(".") self.basedir = osutils.realpath(basedir) self.controldir = controldir self.repository = repo @@ -301,7 +300,7 @@ def supports_rename_tracking(self): return False def _read_index(self): - self.index = Index(self.control_transport.local_abspath('index')) + self.index = Index(self.control_transport.local_abspath("index")) self._index_dirty = False def _get_submodule_index(self, relpath): @@ -312,17 +311,20 @@ def _get_submodule_index(self, relpath): except KeyError: submodule_transport = self.user_transport.clone(decode_git_path(relpath)) try: - submodule_dir = self._format._matchingcontroldir.open(submodule_transport) + submodule_dir = self._format._matchingcontroldir.open( + submodule_transport + ) except errors.NotBranchError as e: raise tree.MissingNestedTree(relpath) from e else: submodule_transport = self.control_transport.clone( - posixpath.join('modules', decode_git_path(info[1]))) + posixpath.join("modules", decode_git_path(info[1])) + ) try: submodule_dir = BareLocalGitControlDirFormat().open(submodule_transport) except errors.NotBranchError as e: raise tree.MissingNestedTree(relpath) from e - return Index(submodule_dir.control_transport.local_abspath('index')) + return Index(submodule_dir.control_transport.local_abspath("index")) def lock_read(self): """Lock the repository for read operations. @@ -330,7 +332,7 @@ def lock_read(self): :return: A breezy.lock.LogicalLockResult. 
""" if not self._lock_mode: - self._lock_mode = 'r' + self._lock_mode = "r" self._lock_count = 1 self._read_index() else: @@ -340,15 +342,16 @@ def lock_read(self): def _lock_write_tree(self): if not self._lock_mode: - self._lock_mode = 'w' + self._lock_mode = "w" self._lock_count = 1 try: self._index_file = GitFile( - self.control_transport.local_abspath('index'), 'wb') + self.control_transport.local_abspath("index"), "wb" + ) except FileLocked as err: - raise errors.LockContention('index') from err + raise errors.LockContention("index") from err self._read_index() - elif self._lock_mode == 'r': + elif self._lock_mode == "r": raise errors.ReadOnlyError(self) else: self._lock_count += 1 @@ -379,7 +382,7 @@ def get_physical_lock_status(self): def break_lock(self): try: - self.control_transport.delete('index.lock') + self.control_transport.delete("index.lock") except _mod_transport.NoSuchFile: pass self.branch.break_lock() @@ -429,17 +432,18 @@ def set_parent_trees(self, parents_list, allow_leftmost_as_ghost=False): def _set_merges_from_parent_ids(self, rhs_parent_ids): try: - merges = [self.branch.lookup_bzr_revision_id( - revid)[0] for revid in rhs_parent_ids] + merges = [ + self.branch.lookup_bzr_revision_id(revid)[0] for revid in rhs_parent_ids + ] except errors.NoSuchRevision as e: raise errors.GhostRevisionUnusableHere(e.revision) from e if merges: self.control_transport.put_bytes( - 'MERGE_HEAD', b'\n'.join(merges), - mode=self.controldir._get_file_mode()) + "MERGE_HEAD", b"\n".join(merges), mode=self.controldir._get_file_mode() + ) else: try: - self.control_transport.delete('MERGE_HEAD') + self.control_transport.delete("MERGE_HEAD") except _mod_transport.NoSuchFile: pass @@ -457,7 +461,8 @@ def set_parent_ids(self, revision_ids, allow_leftmost_as_ghost=False): """ with self.lock_tree_write(): self._check_parents_for_ghosts( - revision_ids, allow_leftmost_as_ghost=allow_leftmost_as_ghost) + revision_ids, allow_leftmost_as_ghost=allow_leftmost_as_ghost + ) for revision_id in revision_ids: _mod_revision.check_not_reserved_id(revision_id) @@ -482,22 +487,20 @@ def get_parent_ids(self): else: parents = [last_rev] try: - merges_bytes = self.control_transport.get_bytes('MERGE_HEAD') + merges_bytes = self.control_transport.get_bytes("MERGE_HEAD") except _mod_transport.NoSuchFile: pass else: for l in osutils.split_lines(merges_bytes): - revision_id = l.rstrip(b'\n') - parents.append( - self.branch.lookup_foreign_revision_id(revision_id)) + revision_id = l.rstrip(b"\n") + parents.append(self.branch.lookup_foreign_revision_id(revision_id)) return parents def check_state(self): """Check that the working state is/isn't valid.""" pass - def remove(self, files, verbose=False, to_file=None, keep_files=True, - force=False): + def remove(self, files, verbose=False, to_file=None, keep_files=True, force=False): """Remove nominated files from the working tree metadata. :param files: File paths relative to the basedir. 
@@ -513,8 +516,7 @@ def remove(self, files, verbose=False, to_file=None, keep_files=True, def backup(file_to_backup): abs_path = self.abspath(file_to_backup) - backup_name = self.controldir._available_backup_name( - file_to_backup) + backup_name = self.controldir._available_backup_name(file_to_backup) osutils.rename(abs_path, self.abspath(backup_name)) return f"removed {file_to_backup} (but kept a copy: {backup_name})" @@ -559,24 +561,32 @@ def recurse_directory_to_add_files(directory): # Bail out if we are going to delete files we shouldn't if not keep_files and not force: for change in self.iter_changes( - self.basis_tree(), include_unchanged=True, - require_versioned=False, want_unversioned=True, - specific_files=files): + self.basis_tree(), + include_unchanged=True, + require_versioned=False, + want_unversioned=True, + specific_files=files, + ): if change.versioned[0] is False: # The record is unknown or newly added files_to_backup.append(change.path[1]) files_to_backup.extend( - osutils.parent_directories(change.path[1])) - elif (change.changed_content and (change.kind[1] is not None) - and osutils.is_inside_any(files, change.path[1])): + osutils.parent_directories(change.path[1]) + ) + elif ( + change.changed_content + and (change.kind[1] is not None) + and osutils.is_inside_any(files, change.path[1]) + ): # Versioned and changed, but not deleted, and still # in one of the dirs to be deleted. files_to_backup.append(change.path[1]) files_to_backup.extend( - osutils.parent_directories(change.path[1])) + osutils.parent_directories(change.path[1]) + ) for f in files: - if f == '': + if f == "": continue try: @@ -588,11 +598,11 @@ def recurse_directory_to_add_files(directory): if verbose: # having removed it, it must be either ignored or unknown if self.is_ignored(f): - new_status = 'I' + new_status = "I" else: - new_status = '?' + new_status = "?" 
kind_ch = osutils.kind_marker(kind) - to_file.write(new_status + ' ' + f + kind_ch + '\n') + to_file.write(new_status + " " + f + kind_ch + "\n") if kind is None: message = f"{f} does not exist" else: @@ -600,7 +610,7 @@ def recurse_directory_to_add_files(directory): if f in files_to_backup and not force: message = backup(f) else: - if kind == 'directory': + if kind == "directory": osutils.rmtree(abs_path) else: osutils.delete_any(abs_path) @@ -616,7 +626,7 @@ def recurse_directory_to_add_files(directory): def smart_add(self, file_list, recurse=True, action=None, save=True): if not file_list: - file_list = ['.'] + file_list = ["."] # expand any symlinks in the directory part, while leaving the # filename alone @@ -633,7 +643,7 @@ def smart_add(self, file_list, recurse=True, action=None, save=True): user_dirs = [] def call_action(filepath, kind): - if filepath == '': + if filepath == "": return if action is not None: parent_path = posixpath.dirname(filepath) @@ -644,8 +654,7 @@ def call_action(filepath, kind): raise workingtree.SettingFileIdUnsupported() with self.lock_tree_write(): - for filepath in osutils.canonical_relpaths( - self.basedir, file_list): + for filepath in osutils.canonical_relpaths(self.basedir, file_list): filepath, can_access = osutils.normalized_filename(filepath) if not can_access: raise errors.InvalidNormalization(filepath) @@ -653,8 +662,7 @@ def call_action(filepath, kind): abspath = self.abspath(filepath) kind = file_kind(abspath) if kind in ("file", "symlink"): - (index, subpath) = self._lookup_index( - encode_git_path(filepath)) + (index, subpath) = self._lookup_index(encode_git_path(filepath)) if subpath in index: # Already present continue @@ -663,8 +671,7 @@ def call_action(filepath, kind): self._index_add_entry(filepath, kind) added.append(filepath) elif kind == "directory": - (index, subpath) = self._lookup_index( - encode_git_path(filepath)) + (index, subpath) = self._lookup_index(encode_git_path(filepath)) if subpath not in index: call_action(filepath, kind) if recurse: @@ -673,10 +680,9 @@ def call_action(filepath, kind): raise errors.BadFileKindError(filename=abspath, kind=kind) for user_dir in user_dirs: abs_user_dir = self.abspath(user_dir) - if user_dir != '': + if user_dir != "": try: - transport = _mod_transport.get_transport_from_path( - abs_user_dir) + transport = _mod_transport.get_transport_from_path(abs_user_dir) _mod_controldir.ControlDirFormat.find_format(transport) subtree = True except errors.NotBranchError: @@ -686,13 +692,14 @@ def call_action(filepath, kind): else: subtree = False if subtree: - trace.warning('skipping nested tree %r', abs_user_dir) + trace.warning("skipping nested tree %r", abs_user_dir) continue for name in os.listdir(abs_user_dir): subp = os.path.join(user_dir, name) - if (self.is_control_filename(subp) or - self.mapping.is_special_file(subp)): + if self.is_control_filename(subp) or self.mapping.is_special_file( + subp + ): continue ignore_glob = self.is_ignored(subp) if ignore_glob is not None: @@ -703,8 +710,7 @@ def call_action(filepath, kind): if kind == "directory": user_dirs.append(subp) else: - (index, subpath) = self._lookup_index( - encode_git_path(subp)) + (index, subpath) = self._lookup_index(encode_git_path(subp)) if subpath in index: # Already present continue @@ -719,15 +725,16 @@ def call_action(filepath, kind): def has_filename(self, filename): return osutils.lexists(self.abspath(filename)) - def _iter_files_recursive(self, from_dir=None, include_dirs=False, - recurse_nested=False): + def 
_iter_files_recursive( + self, from_dir=None, include_dirs=False, recurse_nested=False + ): if from_dir is None: from_dir = "" if not isinstance(from_dir, str): raise TypeError(from_dir) encoded_from_dir = os.fsencode(self.abspath(from_dir)) - for (dirpath, dirnames, filenames) in os.walk(encoded_from_dir): - dir_relpath = dirpath[len(self.basedir):].strip(b"/") + for dirpath, dirnames, filenames in os.walk(encoded_from_dir): + dir_relpath = dirpath[len(self.basedir) :].strip(b"/") if self.controldir.is_control_filename(os.fsdecode(dir_relpath)): continue for name in list(dirnames): @@ -735,7 +742,9 @@ def _iter_files_recursive(self, from_dir=None, include_dirs=False, dirnames.remove(name) continue relpath = os.path.join(dir_relpath, name) - if not recurse_nested and self._directory_is_tree_reference(os.fsdecode(relpath)): + if not recurse_nested and self._directory_is_tree_reference( + os.fsdecode(relpath) + ): dirnames.remove(name) if include_dirs: yield os.fsdecode(relpath) @@ -756,7 +765,8 @@ def extras(self): """Yield all unversioned files in this WorkingTree.""" with self.lock_read(): index_paths = { - decode_git_path(p) for p, sha, mode in self.iter_git_objects()} + decode_git_path(p) for p, sha, mode in self.iter_git_objects() + } all_paths = set(self._iter_files_recursive(include_dirs=False)) return iter(all_paths - index_paths) @@ -770,17 +780,17 @@ def _gather_kinds(self, files, kinds): kind = file_kind(fullpath) except FileNotFoundError as err: raise _mod_transport.NoSuchFile(fullpath) from err - if f != '' and self._directory_is_tree_reference(f): - kind = 'tree-reference' + if f != "" and self._directory_is_tree_reference(f): + kind = "tree-reference" kinds[pos] = kind def flush(self): - if self._lock_mode != 'w': + if self._lock_mode != "w": raise errors.NotWriteLocked(self) # TODO(jelmer): This shouldn't be writing in-place, but index.lock is # already in use and GitFile doesn't allow overriding the lock file # name :( - f = open(self.control_transport.local_abspath('index'), 'wb') + f = open(self.control_transport.local_abspath("index"), "wb") # Note that _flush will close the file self._flush(f) @@ -808,22 +818,22 @@ def is_ignored(self, filename): be ignored, otherwise None. So this can simply be used as a boolean if desired. 
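
The is_ignored hunk that continues below checks Breezy's runtime and user ignore globs first and only then consults the gitignore-style ignore manager, appending "/" for directories and stripping any leading "/". A rough sketch of that precedence follows, with hypothetical matcher objects (breezy_globster exposing match(), git_ignores exposing find_matching()); it is not the actual method.

def is_ignored_sketch(filename, breezy_globster, git_ignores, is_dir=False):
    """Return the matching ignore pattern, or None if nothing matches."""
    match = breezy_globster.match(filename)
    if match is not None:
        return match                 # Breezy globs win outright
    if is_dir:
        filename += "/"              # gitignore matches directories with a trailing slash
    filename = filename.lstrip("/")
    patterns = list(git_ignores.find_matching(filename))
    if not patterns:
        return None
    return bytes(patterns[-1])       # the last matching gitignore pattern wins
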
""" - if getattr(self, '_global_ignoreglobster', None) is None: + if getattr(self, "_global_ignoreglobster", None) is None: from breezy import ignores + ignore_globs = set() ignore_globs.update(ignores.get_runtime_ignores()) ignore_globs.update(ignores.get_user_ignores()) - self._global_ignoreglobster = globbing.ExceptionGlobster( - ignore_globs) + self._global_ignoreglobster = globbing.ExceptionGlobster(ignore_globs) match = self._global_ignoreglobster.match(filename) if match is not None: return match try: - if self.kind(filename) == 'directory': - filename += '/' + if self.kind(filename) == "directory": + filename += "/" except _mod_transport.NoSuchFile: pass - filename = filename.lstrip('/') + filename = filename.lstrip("/") ignore_manager = self._get_ignore_manager() ps = list(ignore_manager.find_matching(filename)) if not ps: @@ -833,7 +843,7 @@ def is_ignored(self, filename): return bytes(ps[-1]) def _get_ignore_manager(self): - ignoremanager = getattr(self, '_ignoremanager', None) + ignoremanager = getattr(self, "_ignoremanager", None) if ignoremanager is not None: return ignoremanager @@ -898,7 +908,7 @@ def stored_kind(self, path): if self._has_dir(encoded_path): return "directory" raise _mod_transport.NoSuchFile(path) from err - entry = getattr(entry, 'this', entry) + entry = getattr(entry, "this", entry) return mode_kind(entry.mode) def _lstat(self, path): @@ -924,18 +934,20 @@ def _is_executable_from_path_and_stat(self, path, stat_result): if self._supports_executable(): return self._is_executable_from_path_and_stat_from_stat(path, stat_result) else: - return self._is_executable_from_path_and_stat_from_basis( - path, stat_result) + return self._is_executable_from_path_and_stat_from_basis(path, stat_result) - def list_files(self, include_root=False, from_dir=None, recursive=True, - recurse_nested=False): - if from_dir is None or from_dir == '.': + def list_files( + self, include_root=False, from_dir=None, recursive=True, recurse_nested=False + ): + if from_dir is None or from_dir == ".": from_dir = "" dir_ids = {} - fk_entries = {'directory': tree.TreeDirectory, - 'file': tree.TreeFile, - 'symlink': tree.TreeLink, - 'tree-reference': tree.TreeReference} + fk_entries = { + "directory": tree.TreeDirectory, + "file": tree.TreeFile, + "symlink": tree.TreeLink, + "tree-reference": tree.TreeReference, + } with self.lock_read(): root_ie = self._get_dir_ie("", None) if include_root and not from_dir: @@ -944,17 +956,19 @@ def list_files(self, include_root=False, from_dir=None, recursive=True, if recursive: path_iterator = sorted( self._iter_files_recursive( - from_dir, include_dirs=True, - recurse_nested=recurse_nested)) + from_dir, include_dirs=True, recurse_nested=recurse_nested + ) + ) else: encoded_from_dir = os.fsencode(self.abspath(from_dir)) path_iterator = sorted( - [os.path.join(from_dir, os.fsdecode(name)) - for name in os.listdir(encoded_from_dir) - if not self.controldir.is_control_filename( - os.fsdecode(name)) and - not self.mapping.is_special_file( - os.fsdecode(name))]) + [ + os.path.join(from_dir, os.fsdecode(name)) + for name in os.listdir(encoded_from_dir) + if not self.controldir.is_control_filename(os.fsdecode(name)) + and not self.mapping.is_special_file(os.fsdecode(name)) + ] + ) for path in path_iterator: encoded_path = encode_git_path(path) (index, index_path) = self._lookup_index(encoded_path) @@ -964,15 +978,13 @@ def list_files(self, include_root=False, from_dir=None, recursive=True, value = None kind = self.kind(path) parent, name = posixpath.split(path) - 
for _dir_path, _dir_ie in self._add_missing_parent_ids( - parent, dir_ids): + for _dir_path, _dir_ie in self._add_missing_parent_ids(parent, dir_ids): pass - if kind == 'tree-reference' and recurse_nested: + if kind == "tree-reference" and recurse_nested: ie = self._get_dir_ie(path, self.path2id(path)) - yield (posixpath.relpath(path, from_dir), 'V', 'directory', - ie) + yield (posixpath.relpath(path, from_dir), "V", "directory", ie) continue - if kind == 'directory': + if kind == "directory": if path != from_dir: if self._has_dir(encoded_path): ie = self._get_dir_ie(path, self.path2id(path)) @@ -983,8 +995,7 @@ def list_files(self, include_root=False, from_dir=None, recursive=True, else: status = "?" ie = fk_entries[kind]() - yield (posixpath.relpath(path, from_dir), status, kind, - ie) + yield (posixpath.relpath(path, from_dir), status, kind, ie) continue if value is not None: ie = self._get_file_ie(name, path, value, dir_ids[parent]) @@ -995,8 +1006,12 @@ def list_files(self, include_root=False, from_dir=None, recursive=True, except KeyError: # unsupported kind continue - yield (posixpath.relpath(path, from_dir), - ("I" if self.is_ignored(path) else "?"), kind, ie) + yield ( + posixpath.relpath(path, from_dir), + ("I" if self.is_ignored(path) else "?"), + kind, + ie, + ) def all_versioned_paths(self): with self.lock_read(): @@ -1026,16 +1041,16 @@ def iter_child_entries(self, path): continue found_any = True subpath = posixpath.relpath(decoded_item_path, path) - if '/' in subpath: - dirname = subpath.split('/', 1)[0] - file_ie = self._get_dir_ie( - posixpath.join(path, dirname), parent_id) + if "/" in subpath: + dirname = subpath.split("/", 1)[0] + file_ie = self._get_dir_ie(posixpath.join(path, dirname), parent_id) else: (unused_parent, name) = posixpath.split(decoded_item_path) file_ie = self._get_file_ie( - name, decoded_item_path, value, parent_id) + name, decoded_item_path, value, parent_id + ) yield file_ie - if not found_any and path != '': + if not found_any and path != "": raise _mod_transport.NoSuchFile(path) def conflicts(self): @@ -1051,7 +1066,7 @@ def set_conflicts(self, conflicts): for conflict in conflicts: if not isinstance(conflict, (TextConflict, ContentsConflict)): raise errors.UnsupportedOperation(self.set_conflicts, self) - if conflict.typestring in ('text conflict', 'contents conflict'): + if conflict.typestring in ("text conflict", "contents conflict"): by_path[encode_git_path(conflict.path)] = conflict else: raise errors.UnsupportedOperation(self.set_conflicts, self) @@ -1081,15 +1096,21 @@ def add_conflicts(self, new_conflicts): if not isinstance(conflict, (TextConflict, ContentsConflict)): raise errors.UnsupportedOperation(self.set_conflicts, self) - if conflict.typestring in ('text conflict', 'contents conflict'): + if conflict.typestring in ("text conflict", "contents conflict"): self._index_dirty = True try: entry = conflict.to_index_entry(self) - if entry.this is None and entry.ancestor is None and entry.other is None: + if ( + entry.this is None + and entry.ancestor is None + and entry.other is None + ): raise AssertionError self.index[encode_git_path(conflict.path)] = entry except KeyError as err: - raise errors.UnsupportedOperation(self.add_conflicts, self) from err + raise errors.UnsupportedOperation( + self.add_conflicts, self + ) from err else: raise errors.UnsupportedOperation(self.add_conflicts, self) @@ -1109,8 +1130,9 @@ def walkdirs(self, prefix=""): """ import operator from bisect import bisect_left + disk_top = self.abspath(prefix) - if 
disk_top.endswith('/'): + if disk_top.endswith("/"): disk_top = disk_top[:-1] top_strip_len = len(disk_top) + 1 inventory_iterator = self._walkdirs(prefix) @@ -1129,23 +1151,30 @@ def walkdirs(self, prefix=""): inv_finished = True while not inv_finished or not disk_finished: if current_disk: - ((cur_disk_dir_relpath, cur_disk_dir_path_from_top), - cur_disk_dir_content) = current_disk + ( + (cur_disk_dir_relpath, cur_disk_dir_path_from_top), + cur_disk_dir_content, + ) = current_disk else: - ((cur_disk_dir_relpath, cur_disk_dir_path_from_top), - cur_disk_dir_content) = ((None, None), None) + ( + (cur_disk_dir_relpath, cur_disk_dir_path_from_top), + cur_disk_dir_content, + ) = ((None, None), None) if not disk_finished: # strip out .bzr dirs - if (cur_disk_dir_path_from_top[top_strip_len:] == '' - and len(cur_disk_dir_content) > 0): + if ( + cur_disk_dir_path_from_top[top_strip_len:] == "" + and len(cur_disk_dir_content) > 0 + ): # osutils.walkdirs can be made nicer - # yield the path-from-prefix rather than the pathjoined # value. - bzrdir_loc = bisect_left(cur_disk_dir_content, - ('.git', '.git')) - if (bzrdir_loc < len(cur_disk_dir_content) and - self.controldir.is_control_filename( - cur_disk_dir_content[bzrdir_loc][0])): + bzrdir_loc = bisect_left(cur_disk_dir_content, (".git", ".git")) + if bzrdir_loc < len( + cur_disk_dir_content + ) and self.controldir.is_control_filename( + cur_disk_dir_content[bzrdir_loc][0] + ): # we dont yield the contents of, or, .bzr itself. del cur_disk_dir_content[bzrdir_loc] if inv_finished: @@ -1155,13 +1184,15 @@ def walkdirs(self, prefix=""): # everything is missing direction = -1 else: - direction = ((current_inv[0][0] > cur_disk_dir_relpath) - - (current_inv[0][0] < cur_disk_dir_relpath)) + direction = (current_inv[0][0] > cur_disk_dir_relpath) - ( + current_inv[0][0] < cur_disk_dir_relpath + ) if direction > 0: # disk is before inventory - unknown - dirblock = [(relpath, basename, kind, stat, None) for - relpath, basename, kind, stat, top_path in - cur_disk_dir_content] + dirblock = [ + (relpath, basename, kind, stat, None) + for relpath, basename, kind, stat, top_path in cur_disk_dir_content + ] yield cur_disk_dir_relpath, dirblock try: current_disk = next(disk_iterator) @@ -1169,9 +1200,10 @@ def walkdirs(self, prefix=""): disk_finished = True elif direction < 0: # inventory is before disk - missing. 
- dirblock = [(relpath, basename, 'unknown', None, kind) - for relpath, basename, dkind, stat, fileid, kind in - current_inv[1]] + dirblock = [ + (relpath, basename, "unknown", None, kind) + for relpath, basename, dkind, stat, fileid, kind in current_inv[1] + ] yield current_inv[0][0], dirblock try: current_inv = next(inventory_iterator) @@ -1181,30 +1213,50 @@ def walkdirs(self, prefix=""): # versioned present directory # merge the inventory and disk data together dirblock = [] - for _relpath, subiterator in itertools.groupby(sorted( + for _relpath, subiterator in itertools.groupby( + sorted( current_inv[1] + cur_disk_dir_content, - key=operator.itemgetter(0)), operator.itemgetter(1)): + key=operator.itemgetter(0), + ), + operator.itemgetter(1), + ): path_elements = list(subiterator) if len(path_elements) == 2: inv_row, disk_row = path_elements # versioned, present file - dirblock.append((inv_row[0], - inv_row[1], disk_row[2], - disk_row[3], inv_row[5])) + dirblock.append( + ( + inv_row[0], + inv_row[1], + disk_row[2], + disk_row[3], + inv_row[5], + ) + ) elif len(path_elements[0]) == 5: # unknown disk file dirblock.append( - (path_elements[0][0], path_elements[0][1], - path_elements[0][2], path_elements[0][3], - None)) + ( + path_elements[0][0], + path_elements[0][1], + path_elements[0][2], + path_elements[0][3], + None, + ) + ) elif len(path_elements[0]) == 6: # versioned, absent file. dirblock.append( - (path_elements[0][0], path_elements[0][1], - 'unknown', None, - path_elements[0][5])) + ( + path_elements[0][0], + path_elements[0][1], + "unknown", + None, + path_elements[0][5], + ) + ) else: - raise NotImplementedError('unreachable code') + raise NotImplementedError("unreachable code") yield current_inv[0][0], dirblock try: current_inv = next(inventory_iterator) @@ -1221,22 +1273,28 @@ def _walkdirs(self, prefix=""): prefix = encode_git_path(prefix) per_dir = defaultdict(set) if prefix == b"": - per_dir[('', self.path2id(''))] = set() + per_dir[("", self.path2id(""))] = set() def add_entry(path, kind): - if path == b'' or not path.startswith(prefix): + if path == b"" or not path.startswith(prefix): return (dirname, child_name) = posixpath.split(path) - add_entry(dirname, 'directory') + add_entry(dirname, "directory") dirname = decode_git_path(dirname) dir_file_id = self.path2id(dirname) if not isinstance(value, (tuple, IndexEntry)): raise ValueError(value) per_dir[(dirname, dir_file_id)].add( - (decode_git_path(path), decode_git_path(child_name), - kind, None, - self.path2id(decode_git_path(path)), - kind)) + ( + decode_git_path(path), + decode_git_path(child_name), + kind, + None, + self.path2id(decode_git_path(path)), + kind, + ) + ) + with self.lock_read(): for path, value in self.index.iteritems(): if self.mapping.is_special_file(path): @@ -1252,8 +1310,7 @@ def get_shelf_manager(self): def store_uncommitted(self): raise errors.StoringUncommittedNotSupported(self) - def annotate_iter(self, path, - default_revision=_mod_revision.CURRENT_REVISION): + def annotate_iter(self, path, default_revision=_mod_revision.CURRENT_REVISION): """See Tree.annotate_iter. This implementation will use the basis tree implementation if possible. 
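
The walkdirs hunks just above interleave a sorted inventory iterator with a sorted disk iterator by computing a cmp-style direction value, (a > b) - (a < b). A small self-contained sketch of that merge pattern over two sorted path lists follows; it is illustrative only, since the real method yields richer dirblock tuples.

def cmp_paths(a, b):
    # -1, 0 or 1, mirroring the (a > b) - (a < b) idiom in the hunk above.
    return (a > b) - (a < b)

def merge_sorted_paths(inventory_paths, disk_paths):
    """Yield (path, versioned, on_disk) for two sorted path lists."""
    i = j = 0
    while i < len(inventory_paths) or j < len(disk_paths):
        if i >= len(inventory_paths):
            direction = 1                  # only disk entries left
        elif j >= len(disk_paths):
            direction = -1                 # only inventory entries left
        else:
            direction = cmp_paths(inventory_paths[i], disk_paths[j])
        if direction > 0:                  # disk is before inventory: unknown
            yield disk_paths[j], False, True
            j += 1
        elif direction < 0:                # inventory is before disk: missing
            yield inventory_paths[i], True, False
            i += 1
        else:                              # versioned and present
            yield inventory_paths[i], True, True
            i += 1
            j += 1
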
@@ -1269,8 +1326,7 @@ def annotate_iter(self, path, try: parent_tree = self.revision_tree(parent_id) except errors.NoSuchRevisionInTree: - parent_tree = self.branch.repository.revision_tree( - parent_id) + parent_tree = self.branch.repository.revision_tree(parent_id) with parent_tree.lock_read(): # TODO(jelmer): Use rename/copy tracker to find path name # in parent @@ -1279,7 +1335,7 @@ def annotate_iter(self, path, kind = parent_tree.kind(parent_path) except _mod_transport.NoSuchFile: continue - if kind != 'file': + if kind != "file": # Note: this is slightly unnecessary, because symlinks # and directories have a "text" which is the empty # text, and we know that won't mess up annotations. But @@ -1287,17 +1343,21 @@ def annotate_iter(self, path, continue parent_text_key = ( parent_path, - parent_tree.get_file_revision(parent_path)) + parent_tree.get_file_revision(parent_path), + ) if parent_text_key not in maybe_file_parent_keys: maybe_file_parent_keys.append(parent_text_key) # Now we have the parents of this content from ..annotate import Annotator from .annotate import AnnotateProvider + annotate_provider = AnnotateProvider( - self.branch.repository._file_change_scanner) + self.branch.repository._file_change_scanner + ) annotator = Annotator(annotate_provider) from ..graph import Graph + graph = Graph(annotate_provider) heads = graph.heads(maybe_file_parent_keys) file_parent_keys = [] @@ -1308,8 +1368,9 @@ def annotate_iter(self, path, text = self.get_file_text(path) this_key = (path, default_revision) annotator.add_special_text(this_key, file_parent_keys, text) - annotations = [(key[-1], line) - for key, line in annotator.annotate_flat(this_key)] + annotations = [ + (key[-1], line) for key, line in annotator.annotate_flat(this_key) + ] return annotations def _rename_one(self, from_rel, to_rel): @@ -1317,13 +1378,12 @@ def _rename_one(self, from_rel, to_rel): def _build_checkout_with_index(self): build_index_from_tree( - self.user_transport.local_abspath('.'), + self.user_transport.local_abspath("."), self.control_transport.local_abspath("index"), self.store, - None - if self.branch.head is None - else self.store[self.branch.head].tree, - honor_filemode=self._supports_executable()) + None if self.branch.head is None else self.store[self.branch.head].tree, + honor_filemode=self._supports_executable(), + ) def reset_state(self, revision_ids=None): """Reset the state of the working tree. @@ -1338,7 +1398,8 @@ def reset_state(self, revision_ids=None): self._index_dirty = True if self.branch.head is not None: for entry in iter_tree_contents( - self.store, self.store[self.branch.head].tree): + self.store, self.store[self.branch.head].tree + ): if not validate_path(entry.path): continue @@ -1347,24 +1408,36 @@ def reset_state(self, revision_ids=None): else: # Let's at least try to use the working tree file: try: - st = self._lstat(self.abspath( - decode_git_path(entry.path))) + st = self._lstat(self.abspath(decode_git_path(entry.path))) except OSError: # But if it doesn't exist, we'll make something up. 
obj = self.store[entry.sha] - st = os.stat_result((entry.mode, 0, 0, 0, - 0, 0, len( - obj.as_raw_string()), 0, - 0, 0)) + st = os.stat_result( + ( + entry.mode, + 0, + 0, + 0, + 0, + 0, + len(obj.as_raw_string()), + 0, + 0, + 0, + ) + ) (index, subpath) = self._lookup_index(entry.path) - index[subpath] = index_entry_from_stat(st, entry.sha, mode=entry.mode) + index[subpath] = index_entry_from_stat( + st, entry.sha, mode=entry.mode + ) def _update_git_tree( - self, old_revision, new_revision, change_reporter=None, - show_base=False): + self, old_revision, new_revision, change_reporter=None, show_base=False + ): basis_tree = self.revision_tree(old_revision) if new_revision != old_revision: from ..merge import merge_inner + with basis_tree.lock_read(): new_basis_tree = self.branch.basis_tree() merge_inner( @@ -1373,22 +1446,36 @@ def _update_git_tree( basis_tree, this_tree=self, change_reporter=change_reporter, - show_base=show_base) - - def pull(self, source, overwrite=False, stop_revision=None, - change_reporter=None, possible_transports=None, local=False, - show_base=False, tag_selector=None): + show_base=show_base, + ) + + def pull( + self, + source, + overwrite=False, + stop_revision=None, + change_reporter=None, + possible_transports=None, + local=False, + show_base=False, + tag_selector=None, + ): with self.lock_write(), source.lock_read(): old_revision = self.branch.last_revision() - count = self.branch.pull(source, overwrite=overwrite, - stop_revision=stop_revision, - possible_transports=possible_transports, - local=local, tag_selector=tag_selector) + count = self.branch.pull( + source, + overwrite=overwrite, + stop_revision=stop_revision, + possible_transports=possible_transports, + local=local, + tag_selector=tag_selector, + ) self._update_git_tree( old_revision=old_revision, new_revision=self.branch.last_revision(), change_reporter=change_reporter, - show_base=show_base) + show_base=show_base, + ) return count def add_reference(self, sub_tree): @@ -1401,12 +1488,13 @@ def add_reference(self, sub_tree): sub_tree_path = self.relpath(sub_tree.basedir) except errors.PathNotChild as err: raise BadReferenceTarget( - self, sub_tree, 'Target not inside tree.') from err + self, sub_tree, "Target not inside tree." + ) from err path, can_access = osutils.normalized_filename(sub_tree_path) if not can_access: raise errors.InvalidNormalization(path) - self._index_add_entry(sub_tree_path, 'tree-reference') + self._index_add_entry(sub_tree_path, "tree-reference") def _read_submodule_head(self, path): return read_submodule_head(self.abspath(path)) @@ -1414,8 +1502,7 @@ def _read_submodule_head(self, path): def get_reference_revision(self, path): hexsha = self._read_submodule_head(path) if hexsha is None: - (index, subpath) = self._lookup_index( - encode_git_path(path)) + (index, subpath) = self._lookup_index(encode_git_path(path)) if subpath is None: raise _mod_transport.NoSuchFile(path) hexsha = index[subpath].sha @@ -1437,6 +1524,7 @@ def extract(self, sub_path, format=None): A new branch will be created, relative to the path for this tree. """ + def mkdirs(path): segments = osutils.splitpath(path) transport = self.branch.controldir.root_transport @@ -1485,6 +1573,7 @@ def _get_check_refs(self): def copy_content_into(self, tree, revision_id=None): """Copy the current content and user files of this tree into tree.""" from .. 
import merge + with self.lock_read(): if revision_id is None: merge.transform_tree(tree, self) @@ -1494,8 +1583,7 @@ def copy_content_into(self, tree, revision_id=None): try: other_tree = self.revision_tree(revision_id) except errors.NoSuchRevision: - other_tree = self.branch.repository.revision_tree( - revision_id) + other_tree = self.branch.repository.revision_tree(revision_id) merge.transform_tree(tree, other_tree) if revision_id == _mod_revision.NULL_REVISION: @@ -1509,7 +1597,9 @@ def reference_parent(self, path, possible_transports=None): if remote_url is None: trace.warning("Unable to find submodule info for %s", path) return None - return _mod_branch.Branch.open(remote_url, possible_transports=possible_transports) + return _mod_branch.Branch.open( + remote_url, possible_transports=possible_transports + ) def get_reference_info(self, path): submodule_info = self._submodule_info() @@ -1519,12 +1609,12 @@ def get_reference_info(self, path): return decode_git_path(info[0]) def set_reference_info(self, tree_path, branch_location): - path = self.abspath('.gitmodules') + path = self.abspath(".gitmodules") try: config = GitConfigFile.from_path(path) except FileNotFoundError: config = GitConfigFile() - section = (b'submodule', encode_git_path(tree_path)) + section = (b"submodule", encode_git_path(tree_path)) if branch_location is None: try: del config[section] @@ -1532,21 +1622,23 @@ def set_reference_info(self, tree_path, branch_location): pass else: branch_location = urlutils.join( - urlutils.strip_segment_parameters(self.branch.user_url), - branch_location) - config.set( - section, - b'path', encode_git_path(tree_path)) - config.set( - section, - b'url', branch_location.encode('utf-8')) + urlutils.strip_segment_parameters(self.branch.user_url), branch_location + ) + config.set(section, b"path", encode_git_path(tree_path)) + config.set(section, b"url", branch_location.encode("utf-8")) config.write_to_path(path) - self.add('.gitmodules') + self.add(".gitmodules") _marker = object() - def update(self, change_reporter=None, possible_transports=None, - revision=None, old_tip=_marker, show_base=False): + def update( + self, + change_reporter=None, + possible_transports=None, + revision=None, + old_tip=_marker, + show_base=False, + ): """Update a working tree along its branch. This will update the branch if its bound too, which means we have @@ -1580,7 +1672,7 @@ def update(self, change_reporter=None, possible_transports=None, """ if self.branch.get_bound_location() is not None: self.lock_write() - update_branch = (old_tip is self._marker) + update_branch = old_tip is self._marker else: self.lock_tree_write() update_branch = False @@ -1594,8 +1686,9 @@ def update(self, change_reporter=None, possible_transports=None, finally: self.unlock() - def _update_tree(self, old_tip=None, change_reporter=None, revision=None, - show_base=False): + def _update_tree( + self, old_tip=None, change_reporter=None, revision=None, show_base=False + ): """Update a tree to the master branch. 
:param old_tip: if supplied, the previous tip revision the branch, @@ -1613,6 +1706,7 @@ def _update_tree(self, old_tip=None, change_reporter=None, revision=None, # with self.lock_tree_write(): from ..merge import merge_inner + nb_conflicts = [] try: last_rev = self.get_parent_ids()[0] @@ -1628,10 +1722,14 @@ def _update_tree(self, old_tip=None, change_reporter=None, revision=None, # merge those changes in first base_tree = self.basis_tree() other_tree = self.branch.repository.revision_tree(old_tip) - nb_conflicts = merge_inner(self.branch, other_tree, - base_tree, this_tree=self, - change_reporter=change_reporter, - show_base=show_base) + nb_conflicts = merge_inner( + self.branch, + other_tree, + base_tree, + this_tree=self, + change_reporter=change_reporter, + show_base=show_base, + ) if nb_conflicts: self.add_parent_tree((old_tip, other_tree)) return len(nb_conflicts) @@ -1641,14 +1739,19 @@ def _update_tree(self, old_tip=None, change_reporter=None, revision=None, # determine the branch point graph = self.branch.repository.get_graph() - base_rev_id = graph.find_unique_lca(self.branch.last_revision(), - last_rev) + base_rev_id = graph.find_unique_lca( + self.branch.last_revision(), last_rev + ) base_tree = self.branch.repository.revision_tree(base_rev_id) - nb_conflicts = merge_inner(self.branch, to_tree, base_tree, - this_tree=self, - change_reporter=change_reporter, - show_base=show_base) + nb_conflicts = merge_inner( + self.branch, + to_tree, + base_tree, + this_tree=self, + change_reporter=change_reporter, + show_base=show_base, + ) self.set_last_revision(revision) # TODO - dedup parents list with things merged by pull ? # reuse the tree we've updated to to set the basis: @@ -1661,17 +1764,18 @@ def _update_tree(self, old_tip=None, change_reporter=None, revision=None, # will not, but also does not need them when setting parents. 
for parent in merges: parent_trees.append( - (parent, self.branch.repository.revision_tree(parent))) + (parent, self.branch.repository.revision_tree(parent)) + ) if not _mod_revision.is_null(old_tip): parent_trees.append( - (old_tip, self.branch.repository.revision_tree(old_tip))) + (old_tip, self.branch.repository.revision_tree(old_tip)) + ) self.set_parent_trees(parent_trees) last_rev = parent_trees[0][0] return len(nb_conflicts) class GitWorkingTreeFormat(workingtree.WorkingTreeFormat): - _tree_class = GitWorkingTree supports_versioned_directories = False @@ -1693,21 +1797,27 @@ class GitWorkingTreeFormat(workingtree.WorkingTreeFormat): @property def _matchingcontroldir(self): from .dir import LocalGitControlDirFormat + return LocalGitControlDirFormat() def get_format_description(self): return "Git Working Tree" - def initialize(self, a_controldir, revision_id=None, from_branch=None, - accelerator_tree=None, hardlink=False): + def initialize( + self, + a_controldir, + revision_id=None, + from_branch=None, + accelerator_tree=None, + hardlink=False, + ): """See WorkingTreeFormat.initialize().""" if not isinstance(a_controldir, LocalGitDir): raise errors.IncompatibleFormat(self, a_controldir) branch = a_controldir.open_branch(nascent_ok=True) if revision_id is not None: branch.set_last_revision(revision_id) - wt = GitWorkingTree( - a_controldir, a_controldir.open_repository(), branch) - for hook in MutableTree.hooks['post_build_tree']: + wt = GitWorkingTree(a_controldir, a_controldir.open_repository(), branch) + for hook in MutableTree.hooks["post_build_tree"]: hook(wt) return wt diff --git a/breezy/globbing.py b/breezy/globbing.py index 0d26be2c58..e4cc9f7b7c 100644 --- a/breezy/globbing.py +++ b/breezy/globbing.py @@ -28,25 +28,28 @@ Replacer = _globbing_rs.Replacer _sub_named = Replacer() -_sub_named.add(r'\[:digit:\]', r'\d') -_sub_named.add(r'\[:space:\]', r'\s') -_sub_named.add(r'\[:alnum:\]', r'\w') -_sub_named.add(r'\[:ascii:\]', r'\0-\x7f') -_sub_named.add(r'\[:blank:\]', r' \t') -_sub_named.add(r'\[:cntrl:\]', r'\0-\x1f\x7f-\x9f') +_sub_named.add(r"\[:digit:\]", r"\d") +_sub_named.add(r"\[:space:\]", r"\s") +_sub_named.add(r"\[:alnum:\]", r"\w") +_sub_named.add(r"\[:ascii:\]", r"\0-\x7f") +_sub_named.add(r"\[:blank:\]", r" \t") +_sub_named.add(r"\[:cntrl:\]", r"\0-\x1f\x7f-\x9f") def _sub_group(m): - if m[1] in ('!', '^'): - return '[^' + _sub_named(m[2:-1]) + ']' - return '[' + _sub_named(m[1:-1]) + ']' + if m[1] in ("!", "^"): + return "[^" + _sub_named(m[2:-1]) + "]" + return "[" + _sub_named(m[1:-1]) + "]" def _invalid_regex(repl): def _(m): - warning(f"'{m}' not allowed within a regular expression. " - f"Replacing with '{repl}'") + warning( + f"'{m}' not allowed within a regular expression. " + f"Replacing with '{repl}'" + ) return repl + return _ @@ -57,39 +60,39 @@ def _trailing_backslashes_regex(m): one on the end that would escape the brackets we wrap the RE in. """ if (len(m) % 2) != 0: - warning("Regular expressions cannot end with an odd number of '\\'. " - "Dropping the final '\\'.") + warning( + "Regular expressions cannot end with an odd number of '\\'. " + "Dropping the final '\\'." 
+ ) return m[:-1] return m _sub_re = Replacer() -_sub_re.add('^RE:', '') -_sub_re.add('\\((?!\\?)', '(?:') -_sub_re.add('\\(\\?P<.*>', _invalid_regex('(?:')) -_sub_re.add('\\(\\?P=[^)]*\\)', _invalid_regex('')) -_sub_re.add(r'\\+$', _trailing_backslashes_regex) +_sub_re.add("^RE:", "") +_sub_re.add("\\((?!\\?)", "(?:") +_sub_re.add("\\(\\?P<.*>", _invalid_regex("(?:")) +_sub_re.add("\\(\\?P=[^)]*\\)", _invalid_regex("")) +_sub_re.add(r"\\+$", _trailing_backslashes_regex) _sub_fullpath = Replacer() -_sub_fullpath.add(r'^RE:.*', _sub_re) # RE: is a regex -_sub_fullpath.add(r'\[\^?\]?(?:[^\]\[]|\[:[^\]]+:\])+\]', - _sub_group) # char group -_sub_fullpath.add(r'(?:(?<=/)|^)(?:\.?/)+', '') # canonicalize path -_sub_fullpath.add(r'\\.', r'\&') # keep anything backslashed -_sub_fullpath.add(r'[(){}|^$+.]', r'\\&') # escape specials -_sub_fullpath.add(r'(?:(?<=/)|^)\*\*+/', r'(?:.*/)?') # **/ after ^ or / -_sub_fullpath.add(r'\*+', r'[^/]*') # * elsewhere -_sub_fullpath.add(r'\?', r'[^/]') # ? everywhere +_sub_fullpath.add(r"^RE:.*", _sub_re) # RE: is a regex +_sub_fullpath.add(r"\[\^?\]?(?:[^\]\[]|\[:[^\]]+:\])+\]", _sub_group) # char group +_sub_fullpath.add(r"(?:(?<=/)|^)(?:\.?/)+", "") # canonicalize path +_sub_fullpath.add(r"\\.", r"\&") # keep anything backslashed +_sub_fullpath.add(r"[(){}|^$+.]", r"\\&") # escape specials +_sub_fullpath.add(r"(?:(?<=/)|^)\*\*+/", r"(?:.*/)?") # **/ after ^ or / +_sub_fullpath.add(r"\*+", r"[^/]*") # * elsewhere +_sub_fullpath.add(r"\?", r"[^/]") # ? everywhere _sub_basename = Replacer() -_sub_basename.add(r'\[\^?\]?(?:[^\]\[]|\[:[^\]]+:\])+\]', - _sub_group) # char group -_sub_basename.add(r'\\.', r'\&') # keep anything backslashed -_sub_basename.add(r'[(){}|^$+.]', r'\\&') # escape specials -_sub_basename.add(r'\*+', r'.*') # * everywhere -_sub_basename.add(r'\?', r'.') # ? everywhere +_sub_basename.add(r"\[\^?\]?(?:[^\]\[]|\[:[^\]]+:\])+\]", _sub_group) # char group +_sub_basename.add(r"\\.", r"\&") # keep anything backslashed +_sub_basename.add(r"[(){}|^$+.]", r"\\&") # escape specials +_sub_basename.add(r"\*+", r".*") # * everywhere +_sub_basename.add(r"\?", r".") # ? everywhere def _sub_extension(pattern): @@ -121,6 +124,7 @@ class Globster: so are matched first, then the basename patterns, then the fullpath patterns. """ + # We want to _add_patterns in a specific order (as per type_list below) # starting with the shortest and going to the longest. 
# As some Python version don't support ordered dicts the list below is @@ -130,16 +134,10 @@ class Globster: pattern_info = { "extension": { "translator": _sub_extension, - "prefix": r'(?:.*/)?(?!.*/)(?:.*\.)' - }, - "basename": { - "translator": _sub_basename, - "prefix": r'(?:.*/)?(?!.*/)' - }, - "fullpath": { - "translator": _sub_fullpath, - "prefix": r'' + "prefix": r"(?:.*/)?(?!.*/)(?:.*\.)", }, + "basename": {"translator": _sub_basename, "prefix": r"(?:.*/)?(?!.*/)"}, + "fullpath": {"translator": _sub_fullpath, "prefix": r""}, } def __init__(self, patterns): @@ -154,19 +152,17 @@ def __init__(self, patterns): pattern_lists[Globster.identify(pat)].append(pat) pi = Globster.pattern_info for t in Globster.pattern_types: - self._add_patterns(pattern_lists[t], pi[t]["translator"], - pi[t]["prefix"]) + self._add_patterns(pattern_lists[t], pi[t]["translator"], pi[t]["prefix"]) - def _add_patterns(self, patterns, translator, prefix=''): + def _add_patterns(self, patterns, translator, prefix=""): while patterns: - grouped_rules = [ - f'({translator(pat)})' for pat in patterns[:99]] + grouped_rules = [f"({translator(pat)})" for pat in patterns[:99]] joined_rule = f"{prefix}(?:{'|'.join(grouped_rules)})$" # Explicitly use lazy_compile here, because we count on its # nicer error reporting. - self._regex_patterns.append(( - lazy_regex.lazy_compile(joined_rule, re.UNICODE), - patterns[:99])) + self._regex_patterns.append( + (lazy_regex.lazy_compile(joined_rule, re.UNICODE), patterns[:99]) + ) patterns = patterns[99:] def match(self, filename): @@ -183,15 +179,13 @@ def match(self, filename): # We can't show the default e.msg to the user as thats for # the combined pattern we sent to regex. Instead we indicate to # the user that an ignore file needs fixing. - mutter('Invalid pattern found in regex: %s.', e.msg) - e.msg = ( - "File ~/.config/breezy/ignore or " - ".bzrignore contains error(s).") - bad_patterns = '' + mutter("Invalid pattern found in regex: %s.", e.msg) + e.msg = "File ~/.config/breezy/ignore or " ".bzrignore contains error(s)." + bad_patterns = "" for _, patterns in self._regex_patterns: for p in patterns: if not Globster.is_pattern_valid(p): - bad_patterns += f'\n {p}' + bad_patterns += f"\n {p}" e.msg += bad_patterns raise e return None @@ -204,9 +198,9 @@ def identify(pattern): Identify if a pattern is fullpath, basename or extension and returns the appropriate type. 
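
The identify() hunk continuing below classifies a normalized pattern as fullpath (an explicit "RE:" prefix or any "/"), extension ("*." prefix) or basename (everything else). Restated here as a tiny standalone function for reference; same rules, illustrative name, not the Globster implementation itself.

def identify_pattern(pattern):
    # Classification rules as shown in Globster.identify below.
    if pattern.startswith("RE:") or "/" in pattern:
        return "fullpath"
    elif pattern.startswith("*."):
        return "extension"
    else:
        return "basename"

# identify_pattern("RE:^debian/")  -> "fullpath"
# identify_pattern("*.pyc")        -> "extension"
# identify_pattern("tags")         -> "basename"
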
""" - if pattern.startswith('RE:') or '/' in pattern: + if pattern.startswith("RE:") or "/" in pattern: return "fullpath" - elif pattern.startswith('*.'): + elif pattern.startswith("*."): return "extension" else: return "basename" @@ -220,9 +214,8 @@ def is_pattern_valid(pattern): see: globbing.normalize_pattern """ result = True - translator = Globster.pattern_info[Globster.identify( - pattern)]["translator"] - tpattern = f'({translator(pattern)})' + translator = Globster.pattern_info[Globster.identify(pattern)]["translator"] + tpattern = f"({translator(pattern)})" try: re_obj = lazy_regex.lazy_compile(tpattern, re.UNICODE) re_obj.search("") # force compile @@ -245,9 +238,9 @@ class ExceptionGlobster: def __init__(self, patterns): ignores = [[], [], []] for p in patterns: - if p.startswith('!!'): + if p.startswith("!!"): ignores[2].append(p[2:]) - elif p.startswith('!'): + elif p.startswith("!"): ignores[1].append(p[1:]) else: ignores[0].append(p) @@ -280,8 +273,11 @@ def __init__(self, patterns): for pat in patterns: pat = normalize_pattern(pat) t = Globster.identify(pat) - self._add_patterns([pat], Globster.pattern_info[t]["translator"], - Globster.pattern_info[t]["prefix"]) + self._add_patterns( + [pat], + Globster.pattern_info[t]["translator"], + Globster.pattern_info[t]["prefix"], + ) normalize_pattern = _globbing_rs.normalize_pattern diff --git a/breezy/gpg.py b/breezy/gpg.py index b769f5ed71..de34e5caa6 100644 --- a/breezy/gpg.py +++ b/breezy/gpg.py @@ -36,16 +36,16 @@ class GpgNotInstalled(errors.DependencyNotPresent): - - _fmt = ('python-gpg is not installed, it is needed to create or ' - 'verify signatures. %(error)s') + _fmt = ( + "python-gpg is not installed, it is needed to create or " + "verify signatures. %(error)s" + ) def __init__(self, error): - errors.DependencyNotPresent.__init__(self, 'gpg', error) + errors.DependencyNotPresent.__init__(self, "gpg", error) class SigningFailed(errors.BzrError): - _fmt = 'Failed to GPG sign data: "%(error)s"' def __init__(self, error): @@ -53,15 +53,13 @@ def __init__(self, error): class SignatureVerificationFailed(errors.BzrError): - _fmt = 'Failed to verify GPG signature data with error "%(error)s"' def __init__(self, error): errors.BzrError.__init__(self, error=error) -def bulk_verify_signatures(repository, revids, strategy, - process_events_callback=None): +def bulk_verify_signatures(repository, revids, strategy, process_events_callback=None): """Do verifications on a set of revisions. 
:param repository: repository object @@ -74,18 +72,20 @@ def bulk_verify_signatures(repository, revids, strategy, result list for each revision, boolean True if all results are verified successfully """ - count = {SIGNATURE_VALID: 0, - SIGNATURE_KEY_MISSING: 0, - SIGNATURE_NOT_VALID: 0, - SIGNATURE_NOT_SIGNED: 0, - SIGNATURE_EXPIRED: 0} + count = { + SIGNATURE_VALID: 0, + SIGNATURE_KEY_MISSING: 0, + SIGNATURE_NOT_VALID: 0, + SIGNATURE_NOT_SIGNED: 0, + SIGNATURE_EXPIRED: 0, + } result = [] all_verifiable = True total = len(revids) with ui.ui_factory.nested_progress_bar() as pb: for i, (rev_id, verification_result, uid) in enumerate( - repository.verify_revision_signatures( - revids, strategy)): + repository.verify_revision_signatures(revids, strategy) + ): pb.update("verifying signatures", i, total) result.append([rev_id, verification_result, uid]) count[verification_result] += 1 @@ -107,11 +107,13 @@ def __init__(self, ignored): """Real strategies take a configuration.""" def sign(self, content, mode): - raise SigningFailed('Signing is disabled.') + raise SigningFailed("Signing is disabled.") def verify(self, signed_data, signature=None): - raise SignatureVerificationFailed('Signature verification is \ -disabled.') + raise SignatureVerificationFailed( + "Signature verification is \ +disabled." + ) def set_acceptable_keys(self, command_line_input): pass @@ -130,14 +132,17 @@ def __init__(self, ignored): """Real strategies take a configuration.""" def sign(self, content, mode): - return (b"-----BEGIN PSEUDO-SIGNED CONTENT-----\n" + content - + b"-----END PSEUDO-SIGNED CONTENT-----\n") + return ( + b"-----BEGIN PSEUDO-SIGNED CONTENT-----\n" + + content + + b"-----END PSEUDO-SIGNED CONTENT-----\n" + ) def verify(self, signed_data, signature=None): plain_text = signed_data.replace( - b"-----BEGIN PSEUDO-SIGNED CONTENT-----\n", b"") - plain_text = plain_text.replace( - b"-----END PSEUDO-SIGNED CONTENT-----\n", b"") + b"-----BEGIN PSEUDO-SIGNED CONTENT-----\n", b"" + ) + plain_text = plain_text.replace(b"-----END PSEUDO-SIGNED CONTENT-----\n", b"") return SIGNATURE_VALID, None, plain_text def set_acceptable_keys(self, command_line_input): @@ -152,16 +157,15 @@ def set_acceptable_keys(self, command_line_input): def _set_gpg_tty(): - tty = os.environ.get('TTY') + tty = os.environ.get("TTY") if tty is not None: - os.environ['GPG_TTY'] = tty - trace.mutter('setting GPG_TTY=%s', tty) + os.environ["GPG_TTY"] = tty + trace.mutter("setting GPG_TTY=%s", tty) else: # This is not quite worthy of a warning, because some people # don't need GPG_TTY to be set. But it is worthy of a big mark # in brz.log, so that people can debug it if it happens to them - trace.mutter('** Env var TTY empty, cannot set GPG_TTY.' - ' Is TTY exported?') + trace.mutter("** Env var TTY empty, cannot set GPG_TTY." 
" Is TTY exported?") class GPGStrategy: @@ -173,6 +177,7 @@ def __init__(self, config_stack): self._config_stack = config_stack try: import gpg + self.context = gpg.Context() self.context.armor = True self.context.signers = self._get_signing_keys() @@ -181,8 +186,9 @@ def __init__(self, config_stack): def _get_signing_keys(self): import gpg - keyname = self._config_stack.get('gpg_signing_key') - if keyname == 'default': + + keyname = self._config_stack.get("gpg_signing_key") + if keyname == "default": # Leave things to gpg return [] @@ -195,9 +201,8 @@ def _get_signing_keys(self): if keyname is None: # not setting gpg_signing_key at all means we should # use the user email address - keyname = config.extract_email_address( - self._config_stack.get('email')) - if keyname == 'default': + keyname = config.extract_email_address(self._config_stack.get("email")) + if keyname == "default": return [] possible_keys = self.context.keylist(keyname, secret=True) try: @@ -213,6 +218,7 @@ def verify_signatures_available(): """ try: import gpg # noqa: F401 + return True except ModuleNotFoundError: return False @@ -222,19 +228,22 @@ def sign(self, content, mode): import gpg except ModuleNotFoundError as err: raise GpgNotInstalled( - 'Set create_signatures=no to disable creating signatures.') from err + "Set create_signatures=no to disable creating signatures." + ) from err if isinstance(content, str): - raise errors.BzrBadParameterUnicode('content') + raise errors.BzrBadParameterUnicode("content") plain_text = gpg.Data(content) try: output, result = self.context.sign( - plain_text, mode={ + plain_text, + mode={ MODE_DETACH: gpg.constants.sig.mode.DETACH, MODE_CLEAR: gpg.constants.sig.mode.CLEAR, MODE_NORMAL: gpg.constants.sig.mode.NORMAL, - }[mode]) + }[mode], + ) except gpg.errors.GPGMEError as error: raise SigningFailed(str(error)) from error except gpg.errors.InvalidSigners as error: @@ -254,7 +263,8 @@ def verify(self, signed_data, signature=None): import gpg except ModuleNotFoundError as err: raise GpgNotInstalled( - 'Set check_signatures=ignore to disable verifying signatures.') from err + "Set check_signatures=ignore to disable verifying signatures." + ) from err signed_data = gpg.Data(signed_data) if signature: @@ -264,8 +274,11 @@ def verify(self, signed_data, signature=None): except gpg.errors.BadSignatures as error: fingerprint = error.result.signatures[0].fpr if error.result.signatures[0].summary & gpg.constants.SIGSUM_KEY_EXPIRED: - expires = self.context.get_key( - error.result.signatures[0].fpr).subkeys[0].expires + expires = ( + self.context.get_key(error.result.signatures[0].fpr) + .subkeys[0] + .expires + ) if expires > error.result.signatures[0].timestamp: # The expired key was not expired at time of signing. # test_verify_expired_but_valid() @@ -277,8 +290,7 @@ def verify(self, signed_data, signature=None): # GPG does not know this key. 
# test_verify_unknown_key() - if (error.result.signatures[0].summary & - gpg.constants.SIGSUM_KEY_MISSING): + if error.result.signatures[0].summary & gpg.constants.SIGSUM_KEY_MISSING: return SIGNATURE_KEY_MISSING, fingerprint[-8:], None return SIGNATURE_NOT_VALID, None, None @@ -303,10 +315,10 @@ def verify(self, signed_data, signature=None): key = self.context.get_key(fingerprint) name = key.uids[0].name if isinstance(name, bytes): - name = name.decode('utf-8') + name = name.decode("utf-8") email = key.uids[0].email if isinstance(email, bytes): - email = email.decode('utf-8') + email = email.decode("utf-8") return (SIGNATURE_VALID, name + " <" + email + ">", plain_output) # Sigsum_red indicates a problem, unfortunatly I have not been able # to write any tests which actually set this. @@ -314,8 +326,7 @@ def verify(self, signed_data, signature=None): return SIGNATURE_NOT_VALID, None, plain_output # Summary isn't set if sig is valid but key is untrusted but if user # has explicity set the key as acceptable we can validate it. - if (result.signatures[0].summary == 0 and - self.acceptable_keys is not None): + if result.signatures[0].summary == 0 and self.acceptable_keys is not None: if fingerprint in self.acceptable_keys: # test_verify_untrusted_but_accepted() return SIGNATURE_VALID, None, plain_output @@ -324,8 +335,7 @@ def verify(self, signed_data, signature=None): return SIGNATURE_NOT_VALID, None, plain_output # Other error types such as revoked keys should (I think) be caught by # SIGSUM_RED so anything else means something is buggy. - raise SignatureVerificationFailed( - "Unknown GnuPG key verification result") + raise SignatureVerificationFailed("Unknown GnuPG key verification result") def set_acceptable_keys(self, command_line_input): """Set the acceptable keys for verifying with this GPGStrategy. 
@@ -335,11 +345,11 @@ def set_acceptable_keys(self, command_line_input): :return: nothing """ patterns = None - acceptable_keys_config = self._config_stack.get('acceptable_keys') + acceptable_keys_config = self._config_stack.get("acceptable_keys") if acceptable_keys_config is not None: patterns = acceptable_keys_config if command_line_input is not None: # command line overrides config - patterns = command_line_input.split(',') + patterns = command_line_input.split(",") if patterns: self.acceptable_keys = [] @@ -351,47 +361,46 @@ def set_acceptable_keys(self, command_line_input): self.acceptable_keys.append(key.subkeys[0].fpr) trace.mutter("Added acceptable key: " + key.subkeys[0].fpr) if not found_key: - trace.note(gettext( - "No GnuPG key results for pattern: {0}" - ).format(pattern)) + trace.note( + gettext("No GnuPG key results for pattern: {0}").format(pattern) + ) def valid_commits_message(count): """Returns message for number of commits.""" - return gettext("{0} commits with valid signatures").format( - count[SIGNATURE_VALID]) + return gettext("{0} commits with valid signatures").format(count[SIGNATURE_VALID]) def unknown_key_message(count): """Returns message for number of commits.""" - return ngettext("{0} commit with unknown key", - "{0} commits with unknown keys", - count[SIGNATURE_KEY_MISSING]).format( - count[SIGNATURE_KEY_MISSING]) + return ngettext( + "{0} commit with unknown key", + "{0} commits with unknown keys", + count[SIGNATURE_KEY_MISSING], + ).format(count[SIGNATURE_KEY_MISSING]) def commit_not_valid_message(count): """Returns message for number of commits.""" - return ngettext("{0} commit not valid", - "{0} commits not valid", - count[SIGNATURE_NOT_VALID]).format( - count[SIGNATURE_NOT_VALID]) + return ngettext( + "{0} commit not valid", "{0} commits not valid", count[SIGNATURE_NOT_VALID] + ).format(count[SIGNATURE_NOT_VALID]) def commit_not_signed_message(count): """Returns message for number of commits.""" - return ngettext("{0} commit not signed", - "{0} commits not signed", - count[SIGNATURE_NOT_SIGNED]).format( - count[SIGNATURE_NOT_SIGNED]) + return ngettext( + "{0} commit not signed", "{0} commits not signed", count[SIGNATURE_NOT_SIGNED] + ).format(count[SIGNATURE_NOT_SIGNED]) def expired_commit_message(count): """Returns message for number of commits.""" - return ngettext("{0} commit with key now expired", - "{0} commits with key now expired", - count[SIGNATURE_EXPIRED]).format( - count[SIGNATURE_EXPIRED]) + return ngettext( + "{0} commit with key now expired", + "{0} commits with key now expired", + count[SIGNATURE_EXPIRED], + ).format(count[SIGNATURE_EXPIRED]) def verbose_expired_key_message(result, repo) -> List[str]: @@ -401,17 +410,19 @@ def verbose_expired_key_message(result, repo) -> List[str]: for rev_id, validity, fingerprint in result: if validity == SIGNATURE_EXPIRED: revision = repo.get_revision(rev_id) - authors = ', '.join(revision.get_apparent_authors()) + authors = ", ".join(revision.get_apparent_authors()) signers.setdefault(fingerprint, 0) signers[fingerprint] += 1 fingerprint_to_authors[fingerprint] = authors ret: List[str] = [] for fingerprint, number in signers.items(): ret.append( - ngettext("{0} commit by author {1} with key {2} now expired", - "{0} commits by author {1} with key {2} now expired", - number).format( - number, fingerprint_to_authors[fingerprint], fingerprint)) + ngettext( + "{0} commit by author {1} with key {2} now expired", + "{0} commits by author {1} with key {2} now expired", + number, + ).format(number, 
fingerprint_to_authors[fingerprint], fingerprint) + ) return ret @@ -424,9 +435,11 @@ def verbose_valid_message(result) -> List[str]: signers[uid] += 1 ret: List[str] = [] for uid, number in signers.items(): - ret.append(ngettext("{0} signed {1} commit", - "{0} signed {1} commits", - number).format(uid, number)) + ret.append( + ngettext("{0} signed {1} commit", "{0} signed {1} commits", number).format( + uid, number + ) + ) return ret @@ -436,14 +449,16 @@ def verbose_not_valid_message(result, repo) -> List[str]: for rev_id, validity, _empty in result: if validity == SIGNATURE_NOT_VALID: revision = repo.get_revision(rev_id) - authors = ', '.join(revision.get_apparent_authors()) + authors = ", ".join(revision.get_apparent_authors()) signers.setdefault(authors, 0) signers[authors] += 1 ret: List[str] = [] for authors, number in signers.items(): - ret.append(ngettext("{0} commit by author {1}", - "{0} commits by author {1}", - number).format(number, authors)) + ret.append( + ngettext( + "{0} commit by author {1}", "{0} commits by author {1}", number + ).format(number, authors) + ) return ret @@ -453,14 +468,16 @@ def verbose_not_signed_message(result, repo) -> List[str]: for rev_id, validity, _empty in result: if validity == SIGNATURE_NOT_SIGNED: revision = repo.get_revision(rev_id) - authors = ', '.join(revision.get_apparent_authors()) + authors = ", ".join(revision.get_apparent_authors()) signers.setdefault(authors, 0) signers[authors] += 1 ret: List[str] = [] for authors, number in signers.items(): - ret.append(ngettext("{0} commit by author {1}", - "{0} commits by author {1}", - number).format(number, authors)) + ret.append( + ngettext( + "{0} commit by author {1}", "{0} commits by author {1}", number + ).format(number, authors) + ) return ret @@ -473,7 +490,11 @@ def verbose_missing_key_message(result) -> List[str]: signers[fingerprint] += 1 ret: List[str] = [] for fingerprint, number in list(signers.items()): - ret.append(ngettext("Unknown key {0} signed {1} commit", - "Unknown key {0} signed {1} commits", - number).format(fingerprint, number)) + ret.append( + ngettext( + "Unknown key {0} signed {1} commit", + "Unknown key {0} signed {1} commits", + number, + ).format(fingerprint, number) + ) return ret diff --git a/breezy/graph.py b/breezy/graph.py index d6312de0ee..11c50d274a 100644 --- a/breezy/graph.py +++ b/breezy/graph.py @@ -49,7 +49,7 @@ def __init__(self, ancestry): self.ancestry = ancestry def __repr__(self): - return f'DictParentsProvider({self.ancestry!r})' + return f"DictParentsProvider({self.ancestry!r})" # Note: DictParentsProvider does not implement get_cached_parent_map # Arguably, the data is clearly cached in memory. However, this class @@ -94,8 +94,7 @@ def get_parent_map(self, keys): # (either local indexes, or remote RPCs), so CPU overhead should be # minimal. 
for parents_provider in self._parent_providers: - get_cached = getattr(parents_provider, 'get_cached_parent_map', - None) + get_cached = getattr(parents_provider, "get_cached_parent_map", None) if get_cached is None: continue new_found = get_cached(remaining) @@ -151,7 +150,7 @@ def __repr__(self): def enable_cache(self, cache_misses=True): """Enable cache.""" if self._cache is not None: - raise AssertionError('Cache enabled when already enabled.') + raise AssertionError("Cache enabled when already enabled.") self._cache = {} self._cache_misses = cache_misses self.missing_keys = set() @@ -243,14 +242,14 @@ def __init__(self, parents_provider): conforming to the behavior of StackedParentsProvider.get_parent_map. """ - if getattr(parents_provider, 'get_parents', None) is not None: + if getattr(parents_provider, "get_parents", None) is not None: self.get_parents = parents_provider.get_parents - if getattr(parents_provider, 'get_parent_map', None) is not None: + if getattr(parents_provider, "get_parent_map", None) is not None: self.get_parent_map = parents_provider.get_parent_map self._parents_provider = parents_provider def __repr__(self): - return f'Graph({self._parents_provider!r})' + return f"Graph({self._parents_provider!r})" def find_lca(self, *revisions): """Determine the lowest common ancestors of the provided revisions. @@ -295,7 +294,8 @@ def find_lca(self, *revisions): def find_difference(self, left_revision, right_revision): """Determine the graph difference between two revisions.""" border, common, searchers = self._find_border_ancestors( - [left_revision, right_revision]) + [left_revision, right_revision] + ) self._search_for_extra_common(common, searchers) left = searchers[0].seen right = searchers[1].seen @@ -303,8 +303,9 @@ def find_difference(self, left_revision, right_revision): def find_descendants(self, old_key, new_key): """Find descendants of old_key that are ancestors of new_key.""" - child_map = self.get_child_map(self._find_descendant_ancestors( - old_key, new_key)) + child_map = self.get_child_map( + self._find_descendant_ancestors(old_key, new_key) + ) graph = Graph(DictParentsProvider(child_map)) searcher = graph._make_breadth_first_searcher([old_key]) list(searcher) @@ -365,8 +366,7 @@ def find_distance_to_null(self, target_revision_id, known_revision_ids): parent_map = self.get_parent_map(to_search) parents = parent_map.get(cur_tip, None) if not parents: # An empty list or None is a ghost - raise errors.GhostRevisionsHaveNoRevno(target_revision_id, - cur_tip) + raise errors.GhostRevisionsHaveNoRevno(target_revision_id, cur_tip) cur_tip = parents[0] next_known_tips = [] for revision_id in searching_known_tips: @@ -401,7 +401,8 @@ def find_lefthand_distances(self, keys): for key in keys: try: known_revnos.append( - (key, self.find_distance_to_null(key, known_revnos))) + (key, self.find_distance_to_null(key, known_revnos)) + ) except errors.GhostRevisionsHaveNoRevno: ghosts.append(key) for key in ghosts: @@ -445,22 +446,27 @@ def find_unique_ancestors(self, unique_revision, common_revisions): # 8) Search is done when all common searchers have completed. 
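The eight numbered steps above describe find_unique_ancestors(). As an illustrative sketch (not part of the patch), the Graph API reformatted in this file can be exercised against a purely in-memory ancestry through DictParentsProvider; the revision ids below are invented and the "expected" results are what the algorithms above imply:

from breezy.graph import DictParentsProvider, Graph
from breezy.revision import NULL_REVISION

# left and right both descend from base; merged ties them back together
graph = Graph(DictParentsProvider({
    b"base": (NULL_REVISION,),
    b"left": (b"base",),
    b"right": (b"base",),
    b"merged": (b"left", b"right"),
}))

graph.heads([b"left", b"merged"])                 # expected: {b"merged"}
graph.find_lca(b"left", b"right")                 # expected: {b"base"}
graph.find_unique_ancestors(b"left", [b"right"])  # expected: {b"left"}
graph.find_difference(b"left", b"right")          # expected: ({b"left"}, {b"right"})

Everything the searcher machinery below does is ultimately driven through get_parent_map(), which is why a plain dict is enough to stand in for a repository here.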
unique_searcher, common_searcher = self._find_initial_unique_nodes( - [unique_revision], common_revisions) + [unique_revision], common_revisions + ) unique_nodes = unique_searcher.seen.difference(common_searcher.seen) if not unique_nodes: return unique_nodes - (all_unique_searcher, - unique_tip_searchers) = self._make_unique_searchers( - unique_nodes, unique_searcher, common_searcher) + (all_unique_searcher, unique_tip_searchers) = self._make_unique_searchers( + unique_nodes, unique_searcher, common_searcher + ) - self._refine_unique_nodes(unique_searcher, all_unique_searcher, - unique_tip_searchers, common_searcher) + self._refine_unique_nodes( + unique_searcher, all_unique_searcher, unique_tip_searchers, common_searcher + ) true_unique_nodes = unique_nodes.difference(common_searcher.seen) - if debug.debug_flag_enabled('graph'): - trace.mutter('Found %d truly unique nodes out of %d', - len(true_unique_nodes), len(unique_nodes)) + if debug.debug_flag_enabled("graph"): + trace.mutter( + "Found %d truly unique nodes out of %d", + len(true_unique_nodes), + len(unique_nodes), + ) return true_unique_nodes def _find_initial_unique_nodes(self, unique_revisions, common_revisions): @@ -485,25 +491,24 @@ def _find_initial_unique_nodes(self, unique_revisions, common_revisions): # Check if either searcher encounters new nodes seen by the other # side. unique_are_common_nodes = next_unique_nodes.intersection( - common_searcher.seen) + common_searcher.seen + ) unique_are_common_nodes.update( - next_common_nodes.intersection(unique_searcher.seen)) + next_common_nodes.intersection(unique_searcher.seen) + ) if unique_are_common_nodes: - ancestors = unique_searcher.find_seen_ancestors( - unique_are_common_nodes) + ancestors = unique_searcher.find_seen_ancestors(unique_are_common_nodes) # TODO: This is a bit overboard, we only really care about # the ancestors of the tips because the rest we # already know. This is *correct* but causes us to # search too much ancestry. - ancestors.update( - common_searcher.find_seen_ancestors(ancestors)) + ancestors.update(common_searcher.find_seen_ancestors(ancestors)) unique_searcher.stop_searching_any(ancestors) common_searcher.start_searching(ancestors) return unique_searcher, common_searcher - def _make_unique_searchers(self, unique_nodes, unique_searcher, - common_searcher): + def _make_unique_searchers(self, unique_nodes, unique_searcher, common_searcher): """Create a searcher for all the unique search tips (step 4). As a side effect, the common_searcher will stop searching any nodes @@ -512,18 +517,19 @@ def _make_unique_searchers(self, unique_nodes, unique_searcher, :return: (all_unique_searcher, unique_tip_searchers) """ unique_tips = self._remove_simple_descendants( - unique_nodes, self.get_parent_map(unique_nodes)) + unique_nodes, self.get_parent_map(unique_nodes) + ) if len(unique_tips) == 1: unique_tip_searchers = [] - ancestor_all_unique = unique_searcher.find_seen_ancestors( - unique_tips) + ancestor_all_unique = unique_searcher.find_seen_ancestors(unique_tips) else: unique_tip_searchers = [] for tip in unique_tips: revs_to_search = unique_searcher.find_seen_ancestors([tip]) revs_to_search.update( - common_searcher.find_seen_ancestors(revs_to_search)) + common_searcher.find_seen_ancestors(revs_to_search) + ) searcher = self._make_breadth_first_searcher(revs_to_search) # We don't care about the starting nodes. 
searcher._label = tip @@ -536,10 +542,10 @@ def _make_unique_searchers(self, unique_nodes, unique_searcher, ancestor_all_unique = set(searcher.seen) else: ancestor_all_unique = ancestor_all_unique.intersection( - searcher.seen) + searcher.seen + ) # Collapse all the common nodes into a single searcher - all_unique_searcher = self._make_breadth_first_searcher( - ancestor_all_unique) + all_unique_searcher = self._make_breadth_first_searcher(ancestor_all_unique) if ancestor_all_unique: # We've seen these nodes in all the searchers, so we'll just go to # the next @@ -548,24 +554,32 @@ def _make_unique_searchers(self, unique_nodes, unique_searcher, # Stop any search tips that are already known as ancestors of the # unique nodes stopped_common = common_searcher.stop_searching_any( - common_searcher.find_seen_ancestors(ancestor_all_unique)) + common_searcher.find_seen_ancestors(ancestor_all_unique) + ) total_stopped = 0 for searcher in unique_tip_searchers: - total_stopped += len(searcher.stop_searching_any( - searcher.find_seen_ancestors(ancestor_all_unique))) - if debug.debug_flag_enabled('graph'): - trace.mutter('For %d unique nodes, created %d + 1 unique searchers' - ' (%d stopped search tips, %d common ancestors' - ' (%d stopped common)', - len(unique_nodes), len(unique_tip_searchers), - total_stopped, len(ancestor_all_unique), - len(stopped_common)) + total_stopped += len( + searcher.stop_searching_any( + searcher.find_seen_ancestors(ancestor_all_unique) + ) + ) + if debug.debug_flag_enabled("graph"): + trace.mutter( + "For %d unique nodes, created %d + 1 unique searchers" + " (%d stopped search tips, %d common ancestors" + " (%d stopped common)", + len(unique_nodes), + len(unique_tip_searchers), + total_stopped, + len(ancestor_all_unique), + len(stopped_common), + ) return all_unique_searcher, unique_tip_searchers - def _step_unique_and_common_searchers(self, common_searcher, - unique_tip_searchers, - unique_searcher): + def _step_unique_and_common_searchers( + self, common_searcher, unique_tip_searchers, unique_searcher + ): """Step all the searchers.""" newly_seen_common = set(common_searcher.step()) newly_seen_unique = set() @@ -581,9 +595,13 @@ def _step_unique_and_common_searchers(self, common_searcher, newly_seen_unique.update(next) return newly_seen_common, newly_seen_unique - def _find_nodes_common_to_all_unique(self, unique_tip_searchers, - all_unique_searcher, - newly_seen_unique, step_all_unique): + def _find_nodes_common_to_all_unique( + self, + unique_tip_searchers, + all_unique_searcher, + newly_seen_unique, + step_all_unique, + ): """Find nodes that are common to all unique_tip_searchers. If it is time, step the all_unique_searcher, and add its nodes to the @@ -592,8 +610,7 @@ def _find_nodes_common_to_all_unique(self, unique_tip_searchers, common_to_all_unique_nodes = newly_seen_unique.copy() for searcher in unique_tip_searchers: common_to_all_unique_nodes.intersection_update(searcher.seen) - common_to_all_unique_nodes.intersection_update( - all_unique_searcher.seen) + common_to_all_unique_nodes.intersection_update(all_unique_searcher.seen) # Step all-unique less frequently than the other searchers. # In the common case, we don't need to spider out far here, so # avoid doing extra work. 
@@ -601,16 +618,21 @@ def _find_nodes_common_to_all_unique(self, unique_tip_searchers, tstart = osutils.perf_counter() nodes = all_unique_searcher.step() common_to_all_unique_nodes.update(nodes) - if debug.debug_flag_enabled('graph'): + if debug.debug_flag_enabled("graph"): tdelta = osutils.perf_counter() - tstart - trace.mutter('all_unique_searcher step() took %.3fs' - 'for %d nodes (%d total), iteration: %s', - tdelta, len(nodes), len(all_unique_searcher.seen), - all_unique_searcher._iterations) + trace.mutter( + "all_unique_searcher step() took %.3fs" + "for %d nodes (%d total), iteration: %s", + tdelta, + len(nodes), + len(all_unique_searcher.seen), + all_unique_searcher._iterations, + ) return common_to_all_unique_nodes - def _collapse_unique_searchers(self, unique_tip_searchers, - common_to_all_unique_nodes): + def _collapse_unique_searchers( + self, unique_tip_searchers, common_to_all_unique_nodes + ): """Combine searchers that are searching the same tips. When two searchers are searching the same tips, we can stop one of the @@ -626,12 +648,14 @@ def _collapse_unique_searchers(self, unique_tip_searchers, stopped = searcher.stop_searching_any(common_to_all_unique_nodes) will_search_set = frozenset(searcher._next_query) if not will_search_set: - if debug.debug_flag_enabled('graph'): - trace.mutter('Unique searcher %s was stopped.' - ' (%s iterations) %d nodes stopped', - searcher._label, - searcher._iterations, - len(stopped)) + if debug.debug_flag_enabled("graph"): + trace.mutter( + "Unique searcher %s was stopped." + " (%s iterations) %d nodes stopped", + searcher._label, + searcher._iterations, + len(stopped), + ) elif will_search_set not in unique_search_tips: # This searcher is searching a unique set of nodes, let it unique_search_tips[will_search_set] = [searcher] @@ -652,18 +676,25 @@ def _collapse_unique_searchers(self, unique_tip_searchers, next_searcher = searchers[0] for searcher in searchers[1:]: next_searcher.seen.intersection_update(searcher.seen) - if debug.debug_flag_enabled('graph'): - trace.mutter('Combining %d searchers into a single' - ' searcher searching %d nodes with' - ' %d ancestry', - len(searchers), - len(next_searcher._next_query), - len(next_searcher.seen)) + if debug.debug_flag_enabled("graph"): + trace.mutter( + "Combining %d searchers into a single" + " searcher searching %d nodes with" + " %d ancestry", + len(searchers), + len(next_searcher._next_query), + len(next_searcher.seen), + ) next_unique_searchers.append(next_searcher) return next_unique_searchers - def _refine_unique_nodes(self, unique_searcher, all_unique_searcher, - unique_tip_searchers, common_searcher): + def _refine_unique_nodes( + self, + unique_searcher, + all_unique_searcher, + unique_tip_searchers, + common_searcher, + ): """Steps 5-8 of find_unique_ancestors. 
This function returns when common_searcher has stopped searching for @@ -674,25 +705,33 @@ def _refine_unique_nodes(self, unique_searcher, all_unique_searcher, step_all_unique_counter = 0 # While we still have common nodes to search while common_searcher._next_query: - (newly_seen_common, - newly_seen_unique) = self._step_unique_and_common_searchers( - common_searcher, unique_tip_searchers, unique_searcher) + ( + newly_seen_common, + newly_seen_unique, + ) = self._step_unique_and_common_searchers( + common_searcher, unique_tip_searchers, unique_searcher + ) # These nodes are common ancestors of all unique nodes common_to_all_unique_nodes = self._find_nodes_common_to_all_unique( - unique_tip_searchers, all_unique_searcher, newly_seen_unique, - step_all_unique_counter == 0) - step_all_unique_counter = ((step_all_unique_counter + 1) - % STEP_UNIQUE_SEARCHER_EVERY) + unique_tip_searchers, + all_unique_searcher, + newly_seen_unique, + step_all_unique_counter == 0, + ) + step_all_unique_counter = ( + step_all_unique_counter + 1 + ) % STEP_UNIQUE_SEARCHER_EVERY if newly_seen_common: # If a 'common' node is an ancestor of all unique searchers, we # can stop searching it. common_searcher.stop_searching_any( - all_unique_searcher.seen.intersection(newly_seen_common)) + all_unique_searcher.seen.intersection(newly_seen_common) + ) if common_to_all_unique_nodes: common_to_all_unique_nodes.update( - common_searcher.find_seen_ancestors( - common_to_all_unique_nodes)) + common_searcher.find_seen_ancestors(common_to_all_unique_nodes) + ) # The all_unique searcher can start searching the common nodes # but everyone else can stop. # This is the sort of thing where we would like to not have it @@ -704,17 +743,19 @@ def _refine_unique_nodes(self, unique_searcher, all_unique_searcher, common_searcher.stop_searching_any(common_to_all_unique_nodes) next_unique_searchers = self._collapse_unique_searchers( - unique_tip_searchers, common_to_all_unique_nodes) + unique_tip_searchers, common_to_all_unique_nodes + ) if len(unique_tip_searchers) != len(next_unique_searchers): - if debug.debug_flag_enabled('graph'): - trace.mutter('Collapsed %d unique searchers => %d' - ' at %s iterations', - len(unique_tip_searchers), - len(next_unique_searchers), - all_unique_searcher._iterations) + if debug.debug_flag_enabled("graph"): + trace.mutter( + "Collapsed %d unique searchers => %d" " at %s iterations", + len(unique_tip_searchers), + len(next_unique_searchers), + all_unique_searcher._iterations, + ) unique_tip_searchers = next_unique_searchers - def get_parent_map(self, revisions): # type: ignore + def get_parent_map(self, revisions): # type: ignore """Get a map of key:parent_list for revisions. This implementation delegates to get_parents, for old parent_providers @@ -746,8 +787,7 @@ def _find_border_ancestors(self, revisions): if None in revisions: raise errors.InvalidRevisionId(None, self) common_ancestors = set() - searchers = [self._make_breadth_first_searcher([r]) - for r in revisions] + searchers = [self._make_breadth_first_searcher([r]) for r in revisions] border_ancestors = set() while True: @@ -793,11 +833,13 @@ def _find_border_ancestors(self, revisions): nodes = unique_search_sets.pop() uncommon_nodes = nodes.difference(common_ancestors) if uncommon_nodes: - raise AssertionError("Somehow we ended up converging" - " without actually marking them as" - " in common." 
- f"\nStart_nodes: {revisions}" - f"\nuncommon_nodes: {uncommon_nodes}") + raise AssertionError( + "Somehow we ended up converging" + " without actually marking them as" + " in common." + f"\nStart_nodes: {revisions}" + f"\nuncommon_nodes: {uncommon_nodes}" + ) break return border_ancestors, common_ancestors, searchers @@ -824,8 +866,7 @@ def heads(self, keys): return {_mod_revision.NULL_REVISION} if len(candidate_heads) < 2: return candidate_heads - searchers = {c: self._make_breadth_first_searcher([c]) - for c in candidate_heads} + searchers = {c: self._make_breadth_first_searcher([c]) for c in candidate_heads} active_searchers = dict(searchers) # skip over the actual candidate for each searcher for searcher in active_searchers.values(): @@ -884,8 +925,7 @@ def heads(self, keys): # so we can stop searching it, and any seen ancestors new_common.add(ancestor) for searcher in searchers.values(): - seen_ancestors =\ - searcher.find_seen_ancestors([ancestor]) + seen_ancestors = searcher.find_seen_ancestors([ancestor]) searcher.stop_searching_any(seen_ancestors) common_walker.start_searching(new_common) return candidate_heads @@ -957,8 +997,7 @@ def find_lefthand_merger(self, merged_key, tip_key): return last_candidate last_candidate = candidate - def find_unique_lca(self, left_revision, right_revision, - count_steps=False): + def find_unique_lca(self, left_revision, right_revision, count_steps=False): """Find a unique LCA. Find lowest common ancestors. If there is no unique common @@ -1024,6 +1063,7 @@ def get_parents(key): return self._parents_provider.get_parent_map([key])[key] except KeyError as err: raise errors.RevisionNotPresent(next_key, self) from err + while True: if next_key in stop_keys: return @@ -1042,6 +1082,7 @@ def iter_topo_order(self, revisions): visible in the supplied list of revisions. """ from breezy import tsort + pm = self.get_parent_map(revisions) sorter = tsort.TopoSorter(pm) return sorter.iter_topo_order() @@ -1054,7 +1095,8 @@ def is_ancestor(self, candidate_ancestor, candidate_descendant): relationship between N revisions. """ return {candidate_descendant} == self.heads( - [candidate_ancestor, candidate_descendant]) + [candidate_ancestor, candidate_descendant] + ) def is_between(self, revid, lower_bound_revid, upper_bound_revid): """Determine whether a revision is between two others. @@ -1062,10 +1104,9 @@ def is_between(self, revid, lower_bound_revid, upper_bound_revid): returns true if and only if: lower_bound_revid <= revid <= upper_bound_revid """ - return ((upper_bound_revid is None or - self.is_ancestor(revid, upper_bound_revid)) and - (lower_bound_revid is None or - self.is_ancestor(lower_bound_revid, revid))) + return ( + upper_bound_revid is None or self.is_ancestor(revid, upper_bound_revid) + ) and (lower_bound_revid is None or self.is_ancestor(lower_bound_revid, revid)) def _search_for_extra_common(self, common, searchers): """Make sure that unique nodes are genuinely unique. @@ -1098,8 +1139,7 @@ def _search_for_extra_common(self, common, searchers): # TODO: We need a way to remove unique_searchers when they overlap with # other unique searchers. 
if len(searchers) != 2: - raise NotImplementedError( - "Algorithm not yet implemented for > 2 searchers") + raise NotImplementedError("Algorithm not yet implemented for > 2 searchers") common_searchers = searchers left_searcher = searchers[0] right_searcher = searchers[1] @@ -1107,8 +1147,7 @@ def _search_for_extra_common(self, common, searchers): if not unique: # No unique nodes, nothing to do return total_unique = len(unique) - unique = self._remove_simple_descendants(unique, - self.get_parent_map(unique)) + unique = self._remove_simple_descendants(unique, self.get_parent_map(unique)) simple_unique = len(unique) unique_searchers = [] @@ -1134,11 +1173,13 @@ def _search_for_extra_common(self, common, searchers): if ancestor_all_unique is None: ancestor_all_unique = set(searcher.seen) else: - ancestor_all_unique = ancestor_all_unique.intersection( - searcher.seen) + ancestor_all_unique = ancestor_all_unique.intersection(searcher.seen) - trace.mutter('Started %d unique searchers for %d unique revisions', - simple_unique, total_unique) + trace.mutter( + "Started %d unique searchers for %d unique revisions", + simple_unique, + total_unique, + ) while True: # If we have no more nodes we have nothing to do newly_seen_common = set() @@ -1161,7 +1202,8 @@ def _search_for_extra_common(self, common, searchers): # Make sure all searchers are on the same page for searcher in common_searchers: newly_seen_common.update( - searcher.find_seen_ancestors(newly_seen_common)) + searcher.find_seen_ancestors(newly_seen_common) + ) # We start searching the whole ancestry. It is a bit wasteful, # though. We really just want to mark all of these nodes as # 'seen' and then start just the tips. However, it requires a @@ -1173,7 +1215,8 @@ def _search_for_extra_common(self, common, searchers): # If a 'common' node is an ancestor of all unique searchers, we # can stop searching it. 
stop_searching_common = ancestor_all_unique.intersection( - newly_seen_common) + newly_seen_common + ) if stop_searching_common: for searcher in common_searchers: searcher.stop_searching_any(stop_searching_common) @@ -1181,12 +1224,14 @@ def _search_for_extra_common(self, common, searchers): # We found some ancestors that are common for searcher in unique_searchers: new_common_unique.update( - searcher.find_seen_ancestors(new_common_unique)) + searcher.find_seen_ancestors(new_common_unique) + ) # Since these are common, we can grab another set of ancestors # that we have seen for searcher in common_searchers: new_common_unique.update( - searcher.find_seen_ancestors(new_common_unique)) + searcher.find_seen_ancestors(new_common_unique) + ) # We can tell all of the unique searchers to start at these # nodes, and tell all of the common searchers to *stop* @@ -1330,7 +1375,7 @@ def __init__(self, revisions, parents_provider): self._started_keys = set(self._next_query) self._stopped_keys = set() self._parents_provider = parents_provider - self._returning = 'next_with_ghosts' + self._returning = "next_with_ghosts" self._current_present = set() self._current_ghosts = set() self._current_parents = {} @@ -1340,16 +1385,19 @@ def __repr__(self): prefix = "searching" else: prefix = "starting" - search = f'{prefix}={list(self._next_query)!r}' - return ('_BreadthFirstSearcher(iterations=%d, %s,' - ' seen=%r)' % (self._iterations, search, list(self.seen))) + search = f"{prefix}={list(self._next_query)!r}" + return "_BreadthFirstSearcher(iterations=%d, %s," " seen=%r)" % ( + self._iterations, + search, + list(self.seen), + ) def get_state(self): """Get the current state of this searcher. :return: Tuple with started keys, excludes and included keys """ - if self._returning == 'next': + if self._returning == "next": # We have to know the current nodes children to be able to list the # exclude keys for them. However, while we could have a second # look-ahead result buffer and shuffle things around, this method @@ -1385,9 +1433,9 @@ def __next__(self): :return: A set of revision_ids. """ - if self._returning != 'next': + if self._returning != "next": # switch to returning the query, not the results. - self._returning = 'next' + self._returning = "next" self._iterations += 1 else: self._advance() @@ -1410,9 +1458,9 @@ def next_with_ghosts(self): :return: A tuple with (present ancestors, ghost ancestors) sets. """ - if self._returning != 'next_with_ghosts': + if self._returning != "next_with_ghosts": # switch to returning the results, not the current query. - self._returning = 'next_with_ghosts' + self._returning = "next_with_ghosts" self._advance() if len(self._next_query) == 0: raise StopIteration() @@ -1481,7 +1529,7 @@ def find_seen_ancestors(self, revisions): pending = set(revisions).intersection(all_seen) seen_ancestors = set(pending) - if self._returning == 'next': + if self._returning == "next": # self.seen contains what nodes have been returned, not what nodes # have been queried. We don't want to probe for nodes that haven't # been searched yet. 
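The simpler Graph helpers reformatted earlier in this file (iter_topo_order, is_between, is_ancestor) compose in the same way. A small illustrative sketch, again over an invented in-memory ancestry and not part of the patch:

from breezy.graph import DictParentsProvider, Graph
from breezy.revision import NULL_REVISION

graph = Graph(DictParentsProvider({
    b"a": (NULL_REVISION,),
    b"b": (b"a",),
    b"c": (b"b",),
}))

# topological order puts parents before children, so roughly [b"a", b"b", b"c"]
list(graph.iter_topo_order([b"c", b"a", b"b"]))
graph.is_between(b"b", b"a", b"c")   # expected: True (a <= b <= c in ancestry terms)
graph.is_ancestor(b"a", b"c")        # expected: True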
@@ -1497,8 +1545,7 @@ def find_seen_ancestors(self, revisions): # a ghost for parent_ids in parent_map.values(): all_parents.extend(parent_ids) - next_pending = all_seen.intersection( - all_parents).difference(seen_ancestors) + next_pending = all_seen.intersection(all_parents).difference(seen_ancestors) seen_ancestors.update(next_pending) next_pending.difference_update(not_searched_yet) pending = next_pending @@ -1521,13 +1568,14 @@ def stop_searching_any(self, revisions): # if not revisions: # return set() revisions = frozenset(revisions) - if self._returning == 'next': + if self._returning == "next": stopped = self._next_query.intersection(revisions) self._next_query = self._next_query.difference(revisions) else: stopped_present = self._current_present.intersection(revisions) stopped = stopped_present.union( - self._current_ghosts.intersection(revisions)) + self._current_ghosts.intersection(revisions) + ) self._current_present.difference_update(stopped) self._current_ghosts.difference_update(stopped) # stopping 'x' should stop returning parents of 'x', but @@ -1566,7 +1614,7 @@ def start_searching(self, revisions): revisions = frozenset(revisions) self._started_keys.update(revisions) new_revisions = revisions.difference(self.seen) - if self._returning == 'next': + if self._returning == "next": self._next_query.update(new_revisions) self.seen.update(new_revisions) else: diff --git a/breezy/grep.py b/breezy/grep.py index 1bb1dd1906..58886d1dcd 100644 --- a/breezy/grep.py +++ b/breezy/grep.py @@ -19,11 +19,14 @@ from .lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ from breezy.terminal import color_string, FG -""") +""", +) from . import controldir, errors, osutils, trace from . import revision as _mod_revision from .revisionspec import RevisionSpec, RevisionSpec_revid, RevisionSpec_revno @@ -42,6 +45,7 @@ class GrepOptions: some other params (like outf) to processing functions. This makes it easier to add more options as grep evolves. """ + verbose = False ignore_case = False no_recursive = False @@ -86,9 +90,10 @@ def _linear_view_revisions(branch, start_rev_id, end_rev_id): repo = branch.repository graph = repo.get_graph() for revision_id in graph.iter_lefthand_ancestry( - end_rev_id, (_mod_revision.NULL_REVISION, )): + end_rev_id, (_mod_revision.NULL_REVISION,) + ): revno = branch.revision_id_to_dotted_revno(revision_id) - revno_str = '.'.join(str(n) for n in revno) + revno_str = ".".join(str(n) for n in revno) if revision_id == start_rev_id: yield revision_id, revno_str, 0 break @@ -98,8 +103,7 @@ def _linear_view_revisions(branch, start_rev_id, end_rev_id): # NOTE: _graph_view_revisions is copied from # breezy.log._graph_view_revisions. # This should probably be a common public API -def _graph_view_revisions(branch, start_rev_id, end_rev_id, - rebase_initial_depths=True): +def _graph_view_revisions(branch, start_rev_id, end_rev_id, rebase_initial_depths=True): """Calculate revisions to view including merges, newest to oldest. 
:param branch: the branch @@ -111,20 +115,20 @@ def _graph_view_revisions(branch, start_rev_id, end_rev_id, """ # requires that start is older than end view_revisions = branch.iter_merge_sorted_revisions( - start_revision_id=end_rev_id, stop_revision_id=start_rev_id, - stop_rule="with-merges") + start_revision_id=end_rev_id, + stop_revision_id=start_rev_id, + stop_rule="with-merges", + ) if not rebase_initial_depths: - for (rev_id, merge_depth, revno, _end_of_merge - ) in view_revisions: - yield rev_id, '.'.join(map(str, revno)), merge_depth + for rev_id, merge_depth, revno, _end_of_merge in view_revisions: + yield rev_id, ".".join(map(str, revno)), merge_depth else: # We're following a development line starting at a merged revision. # We need to adjust depths down by the initial depth until we find # a depth less than it. Then we use that depth as the adjustment. # If and when we reach the mainline, depth adjustment ends. depth_adjustment = None - for (rev_id, merge_depth, revno, _end_of_merge - ) in view_revisions: + for rev_id, merge_depth, revno, _end_of_merge in view_revisions: if depth_adjustment is None: depth_adjustment = merge_depth if depth_adjustment: @@ -135,7 +139,7 @@ def _graph_view_revisions(branch, start_rev_id, end_rev_id, # though. depth_adjustment = merge_depth merge_depth -= depth_adjustment - yield rev_id, '.'.join(map(str, revno)), merge_depth + yield rev_id, ".".join(map(str, revno)), merge_depth def compile_pattern(pattern, flags=0): @@ -165,8 +169,7 @@ def __init__(self, opts): self.get_writer = self._get_writer_fixed_highlighted else: flags = opts.patternc.flags - self._sub = re.compile( - opts.pattern.join(("((?:", ")+)")), flags).sub + self._sub = re.compile(opts.pattern.join(("((?:", ")+)")), flags).sub self._highlight = color_string("\\1", FG.BOLD_RED) self.get_writer = self._get_writer_regexp_highlighted else: @@ -182,6 +185,7 @@ def _line_writer(line): def _line_writer_color(line): write(FG.BOLD_MAGENTA + line + FG.NONE + eol_marker) + if self.opts.show_color: return _line_writer_color else: @@ -198,6 +202,7 @@ def _line_writer(line): def _line_writer_color(line): write(FG.BOLD_BLUE + line + FG.NONE + eol_marker) + if self.opts.show_color: return _line_writer_color else: @@ -211,6 +216,7 @@ def _get_writer_plain(self): def _line_writer(line): write(line + eol_marker) + return _line_writer def _get_writer_regexp_highlighted(self): @@ -221,6 +227,7 @@ def _get_writer_regexp_highlighted(self): def _line_writer_regexp_highlighted(line): """Write formatted line with matched pattern highlighted.""" return _line_writer(line=sub(highlight, line)) + return _line_writer_regexp_highlighted def _get_writer_fixed_highlighted(self): @@ -231,23 +238,26 @@ def _get_writer_fixed_highlighted(self): def _line_writer_fixed_highlighted(line): """Write formatted line with string searched for highlighted.""" return _line_writer(line=line.replace(old, new)) + return _line_writer_fixed_highlighted def grep_diff(opts): - wt, branch, relpath = \ - controldir.ControlDir.open_containing_tree_or_branch('.') + wt, branch, relpath = controldir.ControlDir.open_containing_tree_or_branch(".") from breezy import diff + with branch.lock_read(): if opts.revision: start_rev = opts.revision[0] else: # if no revision is sepcified for diff grep we grep all changesets. 
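For orientation, a sketch (not part of the patch) of how the revision-walking helper defined above is typically driven. It assumes a branch exists in the current directory; Branch.open, last_revision and get_rev_id are the usual breezy.branch APIs rather than anything this diff introduces:

from breezy.branch import Branch
from breezy.grep import _linear_view_revisions

branch = Branch.open(".")
with branch.lock_read():
    end_revid = branch.last_revision()
    start_revid = branch.get_rev_id(1)   # mainline revno 1
    # yields (revision_id, dotted revno string, depth), newest to oldest
    for revid, revno_str, depth in _linear_view_revisions(branch, start_revid, end_revid):
        print(revno_str, revid.decode("utf-8"))

grep_diff() and versioned_grep() below pick between this helper and the slower _graph_view_revisions() depending on whether the whole requested range stays on the mainline (the grep_mainline check).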
- opts.revision = [RevisionSpec.from_string('revno:1'), - RevisionSpec.from_string('last:1')] + opts.revision = [ + RevisionSpec.from_string("revno:1"), + RevisionSpec.from_string("last:1"), + ] start_rev = opts.revision[0] start_revid = start_rev.as_revision_id(branch) - if start_revid == b'null:': + if start_revid == b"null:": return srevno_tuple = branch.revision_id_to_dotted_revno(start_revid) if len(opts.revision) == 2: @@ -257,8 +267,9 @@ def grep_diff(opts): end_revno, end_revid = branch.last_revision_info() erevno_tuple = branch.revision_id_to_dotted_revno(end_revid) - grep_mainline = (_rev_on_mainline(srevno_tuple) - and _rev_on_mainline(erevno_tuple)) + grep_mainline = _rev_on_mainline(srevno_tuple) and _rev_on_mainline( + erevno_tuple + ) # ensure that we go in reverse order if srevno_tuple > erevno_tuple: @@ -270,22 +281,21 @@ def grep_diff(opts): # with _linear_view_revisions. If all revs are to be grepped we # use the slower _graph_view_revisions if opts.levels == 1 and grep_mainline: - given_revs = _linear_view_revisions( - branch, start_revid, end_revid) + given_revs = _linear_view_revisions(branch, start_revid, end_revid) else: - given_revs = _graph_view_revisions( - branch, start_revid, end_revid) + given_revs = _graph_view_revisions(branch, start_revid, end_revid) else: # We do an optimization below. For grepping a specific revison # We don't need to call _graph_view_revisions which is slow. # We create the start_rev_tuple for only that specific revision. # _graph_view_revisions is used only for revision range. - start_revno = '.'.join(map(str, srevno_tuple)) + start_revno = ".".join(map(str, srevno_tuple)) start_rev_tuple = (start_revid, start_revno, 0) given_revs = [start_rev_tuple] repo = branch.repository diff_pattern = re.compile( - b"^[+\\-].*(" + opts.pattern.encode(_user_encoding) + b")") + b"^[+\\-].*(" + opts.pattern.encode(_user_encoding) + b")" + ) file_pattern = re.compile(b"=== (modified|added|removed) file '.*'") outputter = _GrepDiffOutputter(opts) writeline = outputter.get_writer() @@ -297,8 +307,7 @@ def grep_diff(opts): # with level=1 show only top level continue - rev_spec = RevisionSpec_revid.from_string( - "revid:" + revid.decode('utf-8')) + rev_spec = RevisionSpec_revid.from_string("revid:" + revid.decode("utf-8")) new_rev = repo.get_revision(revid) new_tree = rev_spec.as_tree(branch) if len(new_rev.parent_ids) == 0: @@ -307,8 +316,7 @@ def grep_diff(opts): ancestor_id = new_rev.parent_ids[0] old_tree = repo.revision_tree(ancestor_id) s = BytesIO() - diff.show_diff_trees(old_tree, new_tree, s, - old_label='', new_label='') + diff.show_diff_trees(old_tree, new_tree, s, old_label="", new_label="") display_revno = True display_file = False file_header = None @@ -323,15 +331,15 @@ def grep_diff(opts): display_revno = False if display_file: writefileheader( - f" {file_header.decode(file_encoding, 'replace')}") + f" {file_header.decode(file_encoding, 'replace')}" + ) display_file = False - line = line.decode(file_encoding, 'replace') + line = line.decode(file_encoding, "replace") writeline(f" {line}") def versioned_grep(opts): - wt, branch, relpath = \ - controldir.ControlDir.open_containing_tree_or_branch('.') + wt, branch, relpath = controldir.ControlDir.open_containing_tree_or_branch(".") with branch.lock_read(): start_rev = opts.revision[0] start_revid = start_rev.as_revision_id(branch) @@ -347,8 +355,9 @@ def versioned_grep(opts): end_revno, end_revid = branch.last_revision_info() erevno_tuple = branch.revision_id_to_dotted_revno(end_revid) - 
grep_mainline = (_rev_on_mainline(srevno_tuple) - and _rev_on_mainline(erevno_tuple)) + grep_mainline = _rev_on_mainline(srevno_tuple) and _rev_on_mainline( + erevno_tuple + ) # ensure that we go in reverse order if srevno_tuple > erevno_tuple: @@ -360,17 +369,15 @@ def versioned_grep(opts): # with _linear_view_revisions. If all revs are to be grepped we # use the slower _graph_view_revisions if opts.levels == 1 and grep_mainline: - given_revs = _linear_view_revisions( - branch, start_revid, end_revid) + given_revs = _linear_view_revisions(branch, start_revid, end_revid) else: - given_revs = _graph_view_revisions( - branch, start_revid, end_revid) + given_revs = _graph_view_revisions(branch, start_revid, end_revid) else: # We do an optimization below. For grepping a specific revison # We don't need to call _graph_view_revisions which is slow. # We create the start_rev_tuple for only that specific revision. # _graph_view_revisions is used only for revision range. - start_revno = '.'.join(map(str, srevno_tuple)) + start_revno = ".".join(map(str, srevno_tuple)) start_rev_tuple = (start_revid, start_revno, 0) given_revs = [start_rev_tuple] @@ -382,8 +389,7 @@ def versioned_grep(opts): # with level=1 show only top level continue - rev = RevisionSpec_revid.from_string( - "revid:" + revid.decode('utf-8')) + rev = RevisionSpec_revid.from_string("revid:" + revid.decode("utf-8")) tree = rev.as_tree(branch) for path in opts.path_list: tree_path = osutils.pathjoin(relpath, path) @@ -395,18 +401,18 @@ def versioned_grep(opts): path_prefix = path dir_grep(tree, path, relpath, opts, revno, path_prefix) else: - versioned_file_grep( - tree, tree_path, '.', path, opts, revno) + versioned_file_grep(tree, tree_path, ".", path, opts, revno) def workingtree_grep(opts): revno = opts.print_revno = None # for working tree set revno to None - tree, branch, relpath = \ - controldir.ControlDir.open_containing_tree_or_branch('.') + tree, branch, relpath = controldir.ControlDir.open_containing_tree_or_branch(".") if not tree: - msg = ('Cannot search working tree. Working tree not found.\n' - 'To search for specific revision in history use the -r option.') + msg = ( + "Cannot search working tree. Working tree not found.\n" + "To search for specific revision in history use the -r option." + ) raise errors.CommandError(msg) # GZ 2010-06-02: Shouldn't be smuggling this on opts, but easy for now @@ -418,7 +424,7 @@ def workingtree_grep(opts): path_prefix = path dir_grep(tree, path, relpath, opts, revno, path_prefix) else: - with open(path, 'rb') as f: + with open(path, "rb") as f: _file_grep(f.read(), path, opts, revno) @@ -433,7 +439,7 @@ def _skip_file(include, exclude, path): def dir_grep(tree, path, relpath, opts, revno, path_prefix): # setup relpath to open files relative to cwd if relpath: - osutils.pathjoin('..', relpath) + osutils.pathjoin("..", relpath) from_dir = osutils.pathjoin(relpath, path) if opts.from_root: @@ -447,13 +453,13 @@ def dir_grep(tree, path, relpath, opts, revno, path_prefix): # for a good reason, otherwise cache might want purging. 
outputter = opts.outputter for fp, fc, fkind, _entry in tree.list_files( - include_root=False, from_dir=from_dir, recursive=opts.recursive): - + include_root=False, from_dir=from_dir, recursive=opts.recursive + ): if _skip_file(opts.include, opts.exclude, fp): continue - if fc == 'V' and fkind == 'file': - tree_path = osutils.pathjoin(from_dir if from_dir else '', fp) + if fc == "V" and fkind == "file": + tree_path = osutils.pathjoin(from_dir if from_dir else "", fp) if revno is not None: # If old result is valid, print results immediately. # Otherwise, add file info to to_grep so that the @@ -469,23 +475,29 @@ def dir_grep(tree, path, relpath, opts, revno, path_prefix): else: # we are grepping working tree. if from_dir is None: - from_dir = '.' + from_dir = "." path_for_file = osutils.pathjoin(tree.basedir, from_dir, fp) if opts.files_with_matches or opts.files_without_match: # Optimize for wtree list-only as we don't need to read the # entire file - with open(path_for_file, 'rb', buffering=4096) as file: + with open(path_for_file, "rb", buffering=4096) as file: _file_grep_list_only_wtree(file, fp, opts, path_prefix) else: - with open(path_for_file, 'rb') as f: + with open(path_for_file, "rb") as f: _file_grep(f.read(), fp, opts, revno, path_prefix) if revno is not None: # grep versioned files for (path, tree_path), chunks in tree.iter_files_bytes(to_grep): path = _make_display_path(relpath, path) - _file_grep(b''.join(chunks), path, opts, revno, path_prefix, - tree.get_file_revision(tree_path)) + _file_grep( + b"".join(chunks), + path, + opts, + revno, + path_prefix, + tree.get_file_revision(tree_path), + ) def _make_display_path(relpath, path): @@ -498,8 +510,8 @@ def _make_display_path(relpath, path): # update path so to display it w.r.t cwd # handle windows slash separator path = osutils.normpath(osutils.pathjoin(relpath, path)) - path = path.replace('\\', '/') - path = path.replace(relpath + '/', '', 1) + path = path.replace("\\", "/") + path = path.replace(relpath + "/", "", 1) return path @@ -512,6 +524,7 @@ def versioned_file_grep(tree, tree_path, relpath, path, opts, revno, path_prefix def _path_in_glob_list(path, glob_list) -> bool: from fnmatch import fnmatch + for glob in glob_list: if fnmatch(path, glob): return True @@ -520,7 +533,7 @@ def _path_in_glob_list(path, glob_list) -> bool: def _file_grep_list_only_wtree(file, path, opts, path_prefix=None): # test and skip binary files - if b'\x00' in file.read(1024): + if b"\x00" in file.read(1024): if opts.verbose: trace.warning("Binary file '%s' skipped.", path) return @@ -529,7 +542,7 @@ def _file_grep_list_only_wtree(file, path, opts, path_prefix=None): found = False if opts.fixed_string: - pattern = opts.pattern.encode(_user_encoding, 'replace') + pattern = opts.pattern.encode(_user_encoding, "replace") for line in file: if pattern in line: found = True @@ -540,9 +553,8 @@ def _file_grep_list_only_wtree(file, path, opts, path_prefix=None): found = True break - if (opts.files_with_matches and found) or \ - (opts.files_without_match and not found): - if path_prefix and path_prefix != '.': + if (opts.files_with_matches and found) or (opts.files_without_match and not found): + if path_prefix and path_prefix != ".": # user has passed a dir arg, show that as result prefix path = osutils.pathjoin(path_prefix, path) opts.outputter.get_writer(path, None, None)() @@ -577,14 +589,13 @@ def __init__(self, opts, use_cache=False): self.get_writer = self._get_writer_fixed_highlighted else: flags = opts.patternc.flags - self._sub = 
re.compile( - opts.pattern.join(("((?:", ")+)")), flags).sub + self._sub = re.compile(opts.pattern.join(("((?:", ")+)")), flags).sub self._highlight = color_string("\\1", FG.BOLD_RED) self.get_writer = self._get_writer_regexp_highlighted path_start = FG.MAGENTA path_end = FG.NONE - sep = color_string(':', FG.BOLD_CYAN) - rev_sep = color_string('~', FG.BOLD_YELLOW) + sep = color_string(":", FG.BOLD_CYAN) + rev_sep = color_string("~", FG.BOLD_YELLOW) else: self.get_writer = self._get_writer_plain path_start = path_end = "" @@ -621,11 +632,13 @@ def _line_cache_and_writer(**kwargs): end = per_line % kwargs add_to_cache(end) write(start + end) + return _line_cache_and_writer def _line_writer(**kwargs): """Write formatted line from arguments given by underlying opts.""" write(start + per_line % kwargs) + return _line_writer def write_cached_lines(self, cache_id, revno): @@ -644,6 +657,7 @@ def _get_writer_regexp_highlighted(self, path, revno, cache_id): def _line_writer_regexp_highlighted(line, **kwargs): """Write formatted line with matched pattern highlighted.""" return _line_writer(line=sub(highlight, line), **kwargs) + return _line_writer_regexp_highlighted def _get_writer_fixed_highlighted(self, path, revno, cache_id): @@ -654,17 +668,18 @@ def _get_writer_fixed_highlighted(self, path, revno, cache_id): def _line_writer_fixed_highlighted(line, **kwargs): """Write formatted line with string searched for highlighted.""" return _line_writer(line=line.replace(old, new), **kwargs) + return _line_writer_fixed_highlighted def _file_grep(file_text, path, opts, revno, path_prefix=None, cache_id=None): # test and skip binary files - if b'\x00' in file_text[:1024]: + if b"\x00" in file_text[:1024]: if opts.verbose: trace.warning("Binary file '%s' skipped.", path) return - if path_prefix and path_prefix != '.': + if path_prefix and path_prefix != ".": # user has passed a dir arg, show that as result prefix path = osutils.pathjoin(path_prefix, path) @@ -672,7 +687,7 @@ def _file_grep(file_text, path, opts, revno, path_prefix=None, cache_id=None): # the user encoding, but we have to guess something and it # is a reasonable default without a better mechanism. 
file_encoding = _user_encoding - pattern = opts.pattern.encode(_user_encoding, 'replace') + pattern = opts.pattern.encode(_user_encoding, "replace") writeline = opts.outputter.get_writer(path, revno, cache_id) @@ -690,8 +705,9 @@ def _file_grep(file_text, path, opts, revno, path_prefix=None, cache_id=None): break else: found = False - if (opts.files_with_matches and found) or \ - (opts.files_without_match and not found): + if (opts.files_with_matches and found) or ( + opts.files_without_match and not found + ): writeline() elif opts.fixed_string: # Fast path for no match, search through the entire file at once rather @@ -706,12 +722,12 @@ def _file_grep(file_text, path, opts, revno, path_prefix=None, cache_id=None): if opts.line_number: for index, line in enumerate(file_text.splitlines()): if pattern in line: - line = line.decode(file_encoding, 'replace') + line = line.decode(file_encoding, "replace") writeline(lineno=index + start, line=line) else: for line in file_text.splitlines(): if pattern in line: - line = line.decode(file_encoding, 'replace') + line = line.decode(file_encoding, "replace") writeline(line=line) else: # Fast path on no match, the re module avoids bad behaviour in most @@ -734,10 +750,10 @@ def _file_grep(file_text, path, opts, revno, path_prefix=None, cache_id=None): if opts.line_number: for index, line in enumerate(file_text.splitlines()): if search(line): - line = line.decode(file_encoding, 'replace') + line = line.decode(file_encoding, "replace") writeline(lineno=index + start, line=line) else: for line in file_text.splitlines(): if search(line): - line = line.decode(file_encoding, 'replace') + line = line.decode(file_encoding, "replace") writeline(line=line) diff --git a/breezy/help.py b/breezy/help.py index 316c7f59c0..af9a3b3c77 100644 --- a/breezy/help.py +++ b/breezy/help.py @@ -26,9 +26,10 @@ class NoHelpTopic(errors.BzrError): - - _fmt = ("No help could be found for '%(topic)s'. " - "Please use 'brz help topics' to obtain a list of topics.") + _fmt = ( + "No help could be found for '%(topic)s'. " + "Please use 'brz help topics' to obtain a list of topics." 
+ ) def __init__(self, topic): self.topic = topic @@ -46,7 +47,7 @@ def help(topic=None, outfile=None): topics = indices.search(topic) shadowed_terms = [] for index, topic_obj in topics[1:]: - shadowed_terms.append(f'{index.prefix}{topic_obj.get_help_topic()}') + shadowed_terms.append(f"{index.prefix}{topic_obj.get_help_topic()}") source = topics[0][1] outfile.write(source.get_help_text(shadowed_terms)) except NoHelpTopic: @@ -61,13 +62,13 @@ def help_commands(outfile=None): """List all commands.""" if outfile is None: outfile = ui.ui_factory.make_output_stream() - outfile.write(_help_commands_to_text('commands')) + outfile.write(_help_commands_to_text("commands")) def _help_commands_to_text(topic): """Generate the help text for the list of commands.""" out = [] - if topic == 'hidden-commands': + if topic == "hidden-commands": hidden = True else: hidden = False @@ -75,7 +76,7 @@ def _help_commands_to_text(topic): commands = ((n, _mod_commands.get_cmd_object(n)) for n in names) shown_commands = [(n, o) for n, o in commands if o.hidden == hidden] max_name = max(len(n) for n, o in shown_commands) - indent = ' ' * (max_name + 1) + indent = " " * (max_name + 1) width = osutils.terminal_width() if width is None: width = osutils.default_terminal_width @@ -85,33 +86,36 @@ def _help_commands_to_text(topic): for cmd_name, cmd_object in sorted(shown_commands): plugin_name = cmd_object.plugin_name() if plugin_name is None: - plugin_name = '' + plugin_name = "" else: - plugin_name = f' [{plugin_name}]' + plugin_name = f" [{plugin_name}]" cmd_help = cmd_object.help() if cmd_help: - firstline = cmd_help.split('\n', 1)[0] + firstline = cmd_help.split("\n", 1)[0] else: - firstline = '' - helpstring = '%-*s %s%s' % (max_name, cmd_name, firstline, plugin_name) + firstline = "" + helpstring = "%-*s %s%s" % (max_name, cmd_name, firstline, plugin_name) lines = utextwrap.wrap( - helpstring, subsequent_indent=indent, - width=width, - break_long_words=False) + helpstring, subsequent_indent=indent, width=width, break_long_words=False + ) for line in lines: - out.append(line + '\n') - return ''.join(out) + out.append(line + "\n") + return "".join(out) -help_topics.topic_registry.register("commands", - _help_commands_to_text, - "Basic help for all commands", - help_topics.SECT_HIDDEN) -help_topics.topic_registry.register("hidden-commands", - _help_commands_to_text, - "All hidden commands", - help_topics.SECT_HIDDEN) +help_topics.topic_registry.register( + "commands", + _help_commands_to_text, + "Basic help for all commands", + help_topics.SECT_HIDDEN, +) +help_topics.topic_registry.register( + "hidden-commands", + _help_commands_to_text, + "All hidden commands", + help_topics.SECT_HIDDEN, +) class HelpIndices: @@ -135,7 +139,7 @@ def __init__(self): _mod_commands.HelpCommandIndex(), plugin.PluginsHelpIndex(), help_topics.ConfigOptionHelpIndex(), - ] + ] def _check_prefix_uniqueness(self): """Ensure that the index collection is able to differentiate safely.""" @@ -156,8 +160,7 @@ def search(self, topic): self._check_prefix_uniqueness() result = [] for index in self.search_path: - result.extend([(index, _topic) - for _topic in index.get_topics(topic)]) + result.extend([(index, _topic) for _topic in index.get_topics(topic)]) if not result: raise NoHelpTopic(topic) else: diff --git a/breezy/help_topics/__init__.py b/breezy/help_topics/__init__.py index e8d259b78d..ca0d7c8c86 100644 --- a/breezy/help_topics/__init__.py +++ b/breezy/help_topics/__init__.py @@ -33,7 +33,7 @@ rendering on the screen naturally. 
""" -__all__ = ['help_as_plain_text', '_format_see_also'] +__all__ = ["help_as_plain_text", "_format_see_also"] from breezy import config from breezy._cmd_rs import format_see_also as _format_see_also @@ -56,6 +56,7 @@ # ---------------------------------------------------- + def _help_on_topics(dummy): """Write out the help for topics to outfile.""" topics = topic_registry.keys() @@ -65,7 +66,7 @@ def _help_on_topics(dummy): for topic in topics: summary = topic_registry.get_summary(topic) out.append("%-*s %s\n" % (lmax, topic, summary)) - return ''.join(out) + return "".join(out) def _help_on_revisionspec(name): @@ -104,14 +105,14 @@ def _help_on_revisionspec(name): 3649, but not 3647. The keywords used as revision selection methods are the following: -""") +""" + ) details = [] details.append("\nIn addition, plugins can provide other keywords.") - details.append( - "\nA detailed description of each keyword is given below.\n") + details.append("\nA detailed description of each keyword is given below.\n") # The help text is indented 4 spaces - this re cleans that up below - re.compile(r'^ ', re.MULTILINE) + re.compile(r"^ ", re.MULTILINE) for _prefix, i in breezy.revisionspec.revspec_registry.iteritems(): doc = i.help_txt if doc == breezy.revisionspec.RevisionSpec.help_txt: @@ -121,8 +122,8 @@ def _help_on_revisionspec(name): # Extract out the top line summary from the body and # clean-up the unwanted whitespace summary, doc = doc.split("\n", 1) - #doc = indent_re.sub('', doc) - while (doc[-2:] == '\n\n' or doc[-1:] == ' '): + # doc = indent_re.sub('', doc) + while doc[-2:] == "\n\n" or doc[-1:] == " ": doc = doc[:-1] # Note: The leading : here are HACKs to get reStructuredText @@ -130,7 +131,7 @@ def _help_on_revisionspec(name): out.append(f":{i.prefix}\n\t{summary}") details.append(f":{i.prefix}\n{doc}") - return '\n'.join(out + details) + return "\n".join(out + details) def _help_on_transport(name): @@ -139,14 +140,13 @@ def _help_on_transport(name): from breezy.transport import transport_list_registry def add_string(proto, help, maxl, prefix_width=20): - help_lines = textwrap.wrap(help, maxl - prefix_width, - break_long_words=False) - line_with_indent = '\n' + ' ' * prefix_width + help_lines = textwrap.wrap(help, maxl - prefix_width, break_long_words=False) + line_with_indent = "\n" + " " * prefix_width help_text = line_with_indent.join(help_lines) return "%-20s%s\n" % (proto, help_text) def key_func(a): - return a[:a.rfind("://")] + return a[: a.rfind("://")] protl = [] decl = [] @@ -161,13 +161,10 @@ def key_func(a): else: decl.append(add_string(proto, shorthelp, 79)) - out = "URL Identifiers\n\n" + \ - "Supported URL prefixes::\n\n " + \ - ' '.join(protl) + out = "URL Identifiers\n\n" + "Supported URL prefixes::\n\n " + " ".join(protl) if len(decl): - out += "\nSupported modifiers::\n\n " + \ - ' '.join(decl) + out += "\nSupported modifiers::\n\n " + " ".join(decl) out += """\ \nBreezy supports all of the standard parts within the URL:: @@ -196,35 +193,46 @@ def key_func(a): # Register help topics -topic_registry.register("revisionspec", _help_on_revisionspec, - "Explain how to use --revision") -topic_registry.register('topics', _help_on_topics, "Topics list", SECT_HIDDEN) +topic_registry.register( + "revisionspec", _help_on_revisionspec, "Explain how to use --revision" +) +topic_registry.register("topics", _help_on_topics, "Topics list", SECT_HIDDEN) def get_current_formats_topic(topic): from breezy import controldir - return "Current Storage Formats\n\n" + \ - 
controldir.format_registry.help_topic(topic) + + return "Current Storage Formats\n\n" + controldir.format_registry.help_topic(topic) def get_other_formats_topic(topic): from breezy import controldir - return "Other Storage Formats\n\n" + \ - controldir.format_registry.help_topic(topic) - -topic_registry.register('current-formats', get_current_formats_topic, - 'Current storage formats') -topic_registry.register('other-formats', get_other_formats_topic, - 'Experimental and deprecated storage formats') -topic_registry.register('urlspec', _help_on_transport, - "Supported transport protocols") - -topic_registry.register_lazy('hooks', 'breezy.hooks', 'hooks_help_text', - 'Points at which custom processing can be added') -topic_registry.register_lazy('location-alias', 'breezy.directory_service', - 'AliasDirectory.help_text', - 'Aliases for remembered locations') + return "Other Storage Formats\n\n" + controldir.format_registry.help_topic(topic) + + +topic_registry.register( + "current-formats", get_current_formats_topic, "Current storage formats" +) +topic_registry.register( + "other-formats", + get_other_formats_topic, + "Experimental and deprecated storage formats", +) +topic_registry.register("urlspec", _help_on_transport, "Supported transport protocols") + +topic_registry.register_lazy( + "hooks", + "breezy.hooks", + "hooks_help_text", + "Points at which custom processing can be added", +) +topic_registry.register_lazy( + "location-alias", + "breezy.directory_service", + "AliasDirectory.help_text", + "Aliases for remembered locations", +) # Register concept topics. @@ -232,11 +240,12 @@ def get_other_formats_topic(topic): # future or implement them via loading content from files. In the meantime, # please keep them concise. + class HelpTopicIndex: """A index for brz help that returns topics.""" def __init__(self): - self.prefix = '' + self.prefix = "" def get_topics(self, topic): """Search for topic in the HelpTopicRegistry. @@ -246,7 +255,7 @@ def get_topics(self, topic): RegisteredTopic entry. """ if topic is None: - topic = 'basic' + topic = "basic" topic = topic_registry.get(topic) if topic: return [topic] @@ -258,7 +267,7 @@ class ConfigOptionHelpIndex: """A help index that returns help topics for config options.""" def __init__(self): - self.prefix = 'configuration/' + self.prefix = "configuration/" def get_topics(self, topic): """Search for topic in the registered config options. @@ -270,7 +279,7 @@ def get_topics(self, topic): if topic is None: return [] elif topic.startswith(self.prefix): - topic = topic[len(self.prefix):] + topic = topic[len(self.prefix) :] if topic in config.option_registry: return [config.option_registry.get(topic)] else: diff --git a/breezy/hooks.py b/breezy/hooks.py index 8bb71a3739..97c73f4421 100644 --- a/breezy/hooks.py +++ b/breezy/hooks.py @@ -23,17 +23,19 @@ from . import errors, registry from .lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ from breezy import ( _format_version_tuple, pyutils, ) from breezy.i18n import gettext -""") +""", +) class UnknownHook(errors.BzrError): - _fmt = "The %(type)s hook '%(hook)s' is unknown in this version of breezy." 
def __init__(self, hook_type, hook_name): @@ -48,10 +50,14 @@ class KnownHooksRegistry(registry.Registry[str, "Hooks", None]): # module where the specific hooks are defined # callable to get the empty specific Hooks for that attribute - def register_lazy_hook(self, hook_module_name, hook_member_name, - hook_factory_member_name): - self.register_lazy((hook_module_name, hook_member_name), - hook_module_name, hook_factory_member_name) + def register_lazy_hook( + self, hook_module_name, hook_member_name, hook_factory_member_name + ): + self.register_lazy( + (hook_module_name, hook_member_name), + hook_module_name, + hook_factory_member_name, + ) def iter_parent_objects(self): """Yield (hook_key, (parent_object, attr)) tuples for every registered @@ -77,27 +83,29 @@ def key_to_parent_and_attribute(self, key): _builtin_known_hooks = ( - ('breezy.branch', 'Branch.hooks', 'BranchHooks'), - ('breezy.controldir', 'ControlDir.hooks', 'ControlDirHooks'), - ('breezy.commands', 'Command.hooks', 'CommandHooks'), - ('breezy.config', 'ConfigHooks', '_ConfigHooks'), - ('breezy.info', 'hooks', 'InfoHooks'), - ('breezy.lock', 'Lock.hooks', 'LockHooks'), - ('breezy.merge', 'Merger.hooks', 'MergeHooks'), - ('breezy.msgeditor', 'hooks', 'MessageEditorHooks'), - ('breezy.mutabletree', 'MutableTree.hooks', 'MutableTreeHooks'), - ('breezy.bzr.smart.client', '_SmartClient.hooks', 'SmartClientHooks'), - ('breezy.bzr.smart.server', 'SmartTCPServer.hooks', 'SmartServerHooks'), - ('breezy.status', 'hooks', 'StatusHooks'), - ('breezy.transport', 'Transport.hooks', 'TransportHooks'), - ('breezy.version_info_formats.format_rio', 'RioVersionInfoBuilder.hooks', - 'RioVersionInfoBuilderHooks'), - ('breezy.merge_directive', 'BaseMergeDirective.hooks', - 'MergeDirectiveHooks'), - ) + ("breezy.branch", "Branch.hooks", "BranchHooks"), + ("breezy.controldir", "ControlDir.hooks", "ControlDirHooks"), + ("breezy.commands", "Command.hooks", "CommandHooks"), + ("breezy.config", "ConfigHooks", "_ConfigHooks"), + ("breezy.info", "hooks", "InfoHooks"), + ("breezy.lock", "Lock.hooks", "LockHooks"), + ("breezy.merge", "Merger.hooks", "MergeHooks"), + ("breezy.msgeditor", "hooks", "MessageEditorHooks"), + ("breezy.mutabletree", "MutableTree.hooks", "MutableTreeHooks"), + ("breezy.bzr.smart.client", "_SmartClient.hooks", "SmartClientHooks"), + ("breezy.bzr.smart.server", "SmartTCPServer.hooks", "SmartServerHooks"), + ("breezy.status", "hooks", "StatusHooks"), + ("breezy.transport", "Transport.hooks", "TransportHooks"), + ( + "breezy.version_info_formats.format_rio", + "RioVersionInfoBuilder.hooks", + "RioVersionInfoBuilderHooks", + ), + ("breezy.merge_directive", "BaseMergeDirective.hooks", "MergeDirectiveHooks"), +) known_hooks = KnownHooksRegistry() -for (_hook_module, _hook_attribute, _hook_class) in _builtin_known_hooks: +for _hook_module, _hook_attribute, _hook_class in _builtin_known_hooks: known_hooks.register_lazy_hook(_hook_module, _hook_attribute, _hook_class) del _builtin_known_hooks, _hook_module, _hook_attribute, _hook_class @@ -146,11 +154,17 @@ def add_hook(self, name, doc, introduced, deprecated=None): raise errors.DuplicateKey(name) if self._module: callbacks = _lazy_hooks.setdefault( - (self._module, self._member_name, name), []) + (self._module, self._member_name, name), [] + ) else: callbacks = None - hookpoint = HookPoint(name=name, doc=doc, introduced=introduced, - deprecated=deprecated, callbacks=callbacks) + hookpoint = HookPoint( + name=name, + doc=doc, + introduced=introduced, + deprecated=deprecated, + 
callbacks=callbacks, + ) self[name] = hookpoint def docs(self): @@ -179,15 +193,16 @@ def get_hook_name(self, a_callable): """ name = self._callable_names.get(a_callable, None) if name is None and a_callable is not None: - name = self._lazy_callable_names.get((a_callable.__module__, - a_callable.__name__), - None) + name = self._lazy_callable_names.get( + (a_callable.__module__, a_callable.__name__), None + ) if name is None: - return 'No hook name' + return "No hook name" return name - def install_named_hook_lazy(self, hook_name, callable_module, - callable_member, name): + def install_named_hook_lazy( + self, hook_name, callable_module, callable_member, name + ): """Install a_callable in to the hook hook_name lazily, and label it. :param hook_name: A hook name. See the __init__ method for the complete @@ -205,8 +220,9 @@ def install_named_hook_lazy(self, hook_name, callable_module, try: hook_lazy = hook.hook_lazy except AttributeError as err: - raise errors.UnsupportedOperation(self.install_named_hook_lazy, - self) from err + raise errors.UnsupportedOperation( + self.install_named_hook_lazy, self + ) from err else: hook_lazy(callable_module, callable_member, name) if name is not None: @@ -257,8 +273,7 @@ def name_hook(self, a_callable, name): self._callable_names[a_callable] = name def name_hook_lazy(self, callable_module, callable_member, callable_name): - self._lazy_callable_names[(callable_module, callable_member)] = \ - callable_name + self._lazy_callable_names[(callable_module, callable_member)] = callable_name class HookPoint: @@ -300,26 +315,26 @@ def docs(self): :return: A string terminated in \n. """ import textwrap + strings = [] strings.append(self.name) - strings.append('~' * len(self.name)) - strings.append('') + strings.append("~" * len(self.name)) + strings.append("") if self.introduced: introduced_string = _format_version_tuple(self.introduced) else: - introduced_string = 'unknown' - strings.append(gettext('Introduced in: %s') % introduced_string) + introduced_string = "unknown" + strings.append(gettext("Introduced in: %s") % introduced_string) if self.deprecated: deprecated_string = _format_version_tuple(self.deprecated) - strings.append(gettext('Deprecated in: %s') % deprecated_string) - strings.append('') - strings.extend(textwrap.wrap(self.__doc__, - break_long_words=False)) - strings.append('') - return '\n'.join(strings) + strings.append(gettext("Deprecated in: %s") % deprecated_string) + strings.append("") + strings.extend(textwrap.wrap(self.__doc__, break_long_words=False)) + strings.append("") + return "\n".join(strings) def __eq__(self, other): - return (isinstance(other, type(self)) and other.__dict__ == self.__dict__) + return isinstance(other, type(self)) and other.__dict__ == self.__dict__ def hook_lazy(self, callback_module, callback_member, callback_label): """Lazily register a callback to be called when this HookPoint fires. @@ -330,8 +345,7 @@ def hook_lazy(self, callback_module, callback_member, callback_label): :param callback_label: A label to show in the UI while this callback is processing. 
""" - obj_getter = registry._LazyObjectGetter(callback_module, - callback_member) + obj_getter = registry._LazyObjectGetter(callback_module, callback_member) self._callbacks.append((obj_getter, callback_label)) def hook(self, callback, callback_label): @@ -371,7 +385,7 @@ def __repr__(self): strings.append(self.name) strings.append("), callbacks=[") callbacks = self._callbacks - for (callback, callback_name) in callbacks: + for callback, callback_name in callbacks: strings.append(repr(callback.get_obj())) strings.append("(") strings.append(callback_name) @@ -379,11 +393,10 @@ def __repr__(self): if len(callbacks) == 1: strings[-1] = ")" strings.append("]>") - return ''.join(strings) + return "".join(strings) -_help_prefix = \ - """ +_help_prefix = """ Hooks ===== @@ -421,7 +434,7 @@ def hooks_help_text(topic): for hook_key in sorted(known_hooks.keys()): hooks = known_hooks_key_to_object(hook_key) segments.append(hooks.docs()) - return '\n'.join(segments) + return "\n".join(segments) # Lazily registered hooks. Maps (module, name, hook_name) tuples @@ -429,8 +442,9 @@ def hooks_help_text(topic): _lazy_hooks: Dict[Tuple[str, str, str], List[Tuple[registry._ObjectGetter, str]]] = {} -def install_lazy_named_hook(hookpoints_module, hookpoints_name, hook_name, - a_callable, name): +def install_lazy_named_hook( + hookpoints_module, hookpoints_name, hook_name, a_callable, name +): """Install a callable in to a hook lazily, and label it name. :param hookpoints_module: Module name of the hook points. diff --git a/breezy/i18n.py b/breezy/i18n.py index 7115bd7ef4..34ec97c5c7 100644 --- a/breezy/i18n.py +++ b/breezy/i18n.py @@ -69,6 +69,8 @@ def N_(msg): _installed = False + + def install(lang=None): """Enables gettext translations in brz.""" global _installed @@ -86,19 +88,20 @@ def _get_locale_dir(): :param base: plugins can specify their own local directory """ base = os.path.dirname(__file__) - dirpath = os.path.realpath(os.path.join(base, 'locale')) + dirpath = os.path.realpath(os.path.join(base, "locale")) if os.path.exists(dirpath): return dirpath return os.path.join(sys.prefix, "share", "locale") def _check_win32_locale(): - for i in ('LANGUAGE', 'LC_ALL', 'LC_MESSAGES', 'LANG'): + for i in ("LANGUAGE", "LC_ALL", "LC_MESSAGES", "LANG"): if os.environ.get(i): break else: lang = None import locale + try: import ctypes except ImportError: @@ -113,22 +116,23 @@ def _check_win32_locale(): else: lcid = [lcid_user] lang = [locale.windows_locale.get(i) for i in lcid] - lang = ':'.join([i for i in lang if i]) + lang = ":".join([i for i in lang if i]) # set lang code for gettext if lang: - os.environ['LANGUAGE'] = lang + os.environ["LANGUAGE"] = lang def _get_current_locale(): - if not os.environ.get('LANGUAGE'): + if not os.environ.get("LANGUAGE"): from . 
import config - lang = config.GlobalStack().get('language') + + lang = config.GlobalStack().get("language") if lang: - os.environ['LANGUAGE'] = lang + os.environ["LANGUAGE"] = lang return lang - if sys.platform == 'win32': + if sys.platform == "win32": _check_win32_locale() - for i in ('LANGUAGE', 'LC_ALL', 'LC_MESSAGES', 'LANG'): + for i in ("LANGUAGE", "LC_ALL", "LC_MESSAGES", "LANG"): lang = os.environ.get(i) if lang: return lang @@ -136,7 +140,6 @@ def _get_current_locale(): class Domain: - def __init__(self, domain): self.domain = domain diff --git a/breezy/identitymap.py b/breezy/identitymap.py index 4266829ed2..43d9bacba1 100644 --- a/breezy/identitymap.py +++ b/breezy/identitymap.py @@ -33,7 +33,7 @@ class IdentityMap: def add_weave(self, id, weave): """Add weave to the map with a given id.""" if self._weave_key(id) in self._map: - raise errors.BzrError(f'weave {id} already in the identity map') + raise errors.BzrError(f"weave {id} already in the identity map") self._map[self._weave_key(id)] = weave self._reverse_map[weave] = self._weave_key(id) @@ -49,7 +49,7 @@ def __init__(self) -> None: def remove_object(self, an_object: object): """Remove object from map.""" if isinstance(an_object, list): - raise KeyError(f'{an_object!r} not in identity map') + raise KeyError(f"{an_object!r} not in identity map") else: self._map.pop(self._reverse_map[an_object]) self._reverse_map.pop(an_object) diff --git a/breezy/ignores.py b/breezy/ignores.py index 5c32c33ad7..0098f9b68d 100644 --- a/breezy/ignores.py +++ b/breezy/ignores.py @@ -26,16 +26,16 @@ # this ignore list, if it does not exist # please keep these sorted (in C locale order) to aid merging USER_DEFAULTS = [ - '*.a', - '*.o', - '*.py[co]', - '*.so', - '*.sw[nop]', - '*~', - '.#*', - '[#]*#', - '__pycache__', - 'bzr-orphans', + "*.a", + "*.o", + "*.py[co]", + "*.so", + "*.sw[nop]", + "*~", + ".#*", + "[#]*#", + "__pycache__", + "bzr-orphans", ] @@ -47,29 +47,31 @@ def parse_ignore_file(f: BinaryIO) -> Set[str]: errors. """ from .globbing import normalize_pattern + ignored = set() ignore_file = f.read() try: # Try and parse whole ignore file at once. - unicode_lines = ignore_file.decode('utf8').split('\n') + unicode_lines = ignore_file.decode("utf8").split("\n") except UnicodeDecodeError: # Otherwise go though line by line and pick out the 'good' # decodable lines - lines = ignore_file.split(b'\n') + lines = ignore_file.split(b"\n") unicode_lines = [] for line_number, line in enumerate(lines): try: - unicode_lines.append(line.decode('utf-8')) + unicode_lines.append(line.decode("utf-8")) except UnicodeDecodeError: # report error about line (idx+1) trace.warning( - '.bzrignore: On Line #%d, malformed utf8 character. ' - 'Ignoring line.' % (line_number + 1)) + ".bzrignore: On Line #%d, malformed utf8 character. " + "Ignoring line." 
% (line_number + 1) + ) # Append each line to ignore list if it's not a comment line for uline in unicode_lines: - uline = uline.rstrip('\r\n') - if not uline or uline.startswith('#'): + uline = uline.rstrip("\r\n") + if not uline or uline.startswith("#"): continue ignored.add(normalize_pattern(uline)) return ignored @@ -80,7 +82,7 @@ def get_user_ignores(): path = bedding.user_ignore_config_path() patterns = set(USER_DEFAULTS) try: - f = open(path, 'rb') + f = open(path, "rb") except FileNotFoundError: # Create the ignore file, and just return the default # We want to ignore if we can't write to the file @@ -110,9 +112,9 @@ def _set_user_ignores(patterns: Iterable[str]) -> None: bedding.ensure_config_dir_exists() # Create an empty file - with open(ignore_path, 'wb') as f: + with open(ignore_path, "wb") as f: for pattern in patterns: - f.write(pattern.encode('utf8') + b'\n') + f.write(pattern.encode("utf8") + b"\n") def add_unique_user_ignores(new_ignores: Set[str]): @@ -122,6 +124,7 @@ def add_unique_user_ignores(new_ignores: Set[str]): :return: The list of ignores that were added """ from .globbing import normalize_pattern + ignored = get_user_ignores() to_add: list[str] = [] for ignore in new_ignores: @@ -133,9 +136,9 @@ def add_unique_user_ignores(new_ignores: Set[str]): if not to_add: return [] - with open(bedding.user_ignore_config_path(), 'ab') as f: + with open(bedding.user_ignore_config_path(), "ab") as f: for pattern in to_add: - f.write(pattern.encode('utf8') + b'\n') + f.write(pattern.encode("utf8") + b"\n") return to_add @@ -176,12 +179,12 @@ def tree_ignores_add_patterns(tree, name_pattern_list): # read in the existing ignores set ifn = tree.abspath(tree._format.ignore_filename) if tree.has_filename(ifn): - with open(ifn, 'rb') as f: + with open(ifn, "rb") as f: file_contents = f.read() - if file_contents.find(b'\r\n') != -1: - newline = b'\r\n' + if file_contents.find(b"\r\n") != -1: + newline = b"\r\n" else: - newline = b'\n' + newline = b"\n" else: file_contents = b"" newline = os.linesep.encode() @@ -192,14 +195,14 @@ def tree_ignores_add_patterns(tree, name_pattern_list): from .atomicfile import AtomicFile # write out the updated ignores set - with AtomicFile(ifn, 'wb') as f: + with AtomicFile(ifn, "wb") as f: # write the original contents, preserving original line endings f.write(file_contents) - if len(file_contents) > 0 and not file_contents.endswith(b'\n'): + if len(file_contents) > 0 and not file_contents.endswith(b"\n"): f.write(newline) for pattern in name_pattern_list: if pattern not in ignores: - f.write(pattern.encode('utf-8')) + f.write(pattern.encode("utf-8")) f.write(newline) if not tree.is_versioned(tree._format.ignore_filename): diff --git a/breezy/info.py b/breezy/info.py index 887e394a9a..5acceaf34c 100644 --- a/breezy/info.py +++ b/breezy/info.py @@ -14,7 +14,7 @@ # along with this program; if not, write to the Free Software # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA -__all__ = ['show_bzrdir_info'] +__all__ = ["show_bzrdir_info"] import sys import time @@ -28,17 +28,16 @@ from .missing import find_unmerged -def plural(n, base='', pl=None): +def plural(n, base="", pl=None): if n == 1: return base elif pl is not None: return pl else: - return 's' + return "s" class LocationList: - def __init__(self, base_path): self.locs = [] self.base_path = base_path @@ -61,10 +60,10 @@ def add_path(self, label, path): except errors.PathNotChild: pass else: - if path == '': - path = '.' 
- if path != '/': - path = path.rstrip('/') + if path == "": + path = "." + if path != "/": + path = path.rstrip("/") self.locs.append((label, path)) def get_lines(self): @@ -72,8 +71,7 @@ def get_lines(self): return [" %*s: %s\n" % (max_len, l, u) for l, u in self.locs] -def gather_location_info(repository=None, branch=None, working=None, - control=None): +def gather_location_info(repository=None, branch=None, working=None, control=None): locs = {} if branch is not None: branch_path = branch.user_url @@ -85,66 +83,74 @@ def gather_location_info(repository=None, branch=None, working=None, master_path = None try: if control is not None and control.get_branch_reference(): - locs['checkout of branch'] = control.get_branch_reference() + locs["checkout of branch"] = control.get_branch_reference() except NotBranchError: pass if working: working_path = working.user_url if working_path != branch_path: - locs['light checkout root'] = working_path + locs["light checkout root"] = working_path if master_path != branch_path: if repository.is_shared(): - locs['repository checkout root'] = branch_path + locs["repository checkout root"] = branch_path else: - locs['checkout root'] = branch_path + locs["checkout root"] = branch_path if working_path != master_path: - (master_path_base, params) = urlutils.split_segment_parameters( - master_path) + (master_path_base, params) = urlutils.split_segment_parameters(master_path) if working_path == master_path_base: - locs['checkout of co-located branch'] = params['branch'] - elif 'branch' in params: - locs['checkout of branch'] = "{}, branch {}".format( - master_path_base, params['branch']) + locs["checkout of co-located branch"] = params["branch"] + elif "branch" in params: + locs["checkout of branch"] = "{}, branch {}".format( + master_path_base, params["branch"] + ) else: - locs['checkout of branch'] = master_path + locs["checkout of branch"] = master_path elif repository.is_shared(): - locs['repository branch'] = branch_path + locs["repository branch"] = branch_path elif branch_path is not None: # standalone - locs['branch root'] = branch_path + locs["branch root"] = branch_path else: working_path = None if repository is not None and repository.is_shared(): # lightweight checkout of branch in shared repository if branch_path is not None: - locs['repository branch'] = branch_path + locs["repository branch"] = branch_path elif branch_path is not None: # standalone - locs['branch root'] = branch_path + locs["branch root"] = branch_path elif repository is not None: - locs['repository'] = repository.user_url + locs["repository"] = repository.user_url elif control is not None: - locs['control directory'] = control.user_url + locs["control directory"] = control.user_url else: # Really, at least a control directory should be # passed in for this method to be useful. 
pass if master_path != branch_path: - locs['bound to branch'] = master_path + locs["bound to branch"] = master_path if repository is not None and repository.is_shared(): # lightweight checkout of branch in shared repository - locs['shared repository'] = repository.user_url - order = ['control directory', 'light checkout root', - 'repository checkout root', 'checkout root', - 'checkout of branch', 'checkout of co-located branch', - 'shared repository', 'repository', 'repository branch', - 'branch root', 'bound to branch'] + locs["shared repository"] = repository.user_url + order = [ + "control directory", + "light checkout root", + "repository checkout root", + "checkout root", + "checkout of branch", + "checkout of co-located branch", + "shared repository", + "repository", + "repository branch", + "branch root", + "bound to branch", + ] return [(n, locs[n]) for n in order if n in locs] def _show_location_info(locs, outfile): """Show known locations for working, branch and repository.""" - outfile.write('Location:\n') + outfile.write("Location:\n") path_list = LocationList(osutils.getcwd()) for name, loc in locs: path_list.add_url(name, loc) @@ -153,14 +159,17 @@ def _show_location_info(locs, outfile): def _gather_related_branches(branch): locs = LocationList(osutils.getcwd()) - locs.add_url('public branch', branch.get_public_branch()) - locs.add_url('push branch', branch.get_push_location()) - locs.add_url('parent branch', branch.get_parent()) - locs.add_url('submit branch', branch.get_submit_branch()) + locs.add_url("public branch", branch.get_public_branch()) + locs.add_url("push branch", branch.get_push_location()) + locs.add_url("parent branch", branch.get_parent()) + locs.add_url("submit branch", branch.get_submit_branch()) try: - locs.add_url('stacked on', branch.get_stacked_on_url()) - except (_mod_branch.UnstackableBranchFormat, errors.UnstackableRepositoryFormat, - errors.NotStacked): + locs.add_url("stacked on", branch.get_stacked_on_url()) + except ( + _mod_branch.UnstackableBranchFormat, + errors.UnstackableRepositoryFormat, + errors.NotStacked, + ): pass return locs @@ -169,61 +178,65 @@ def _show_related_info(branch, outfile): """Show parent and push location of branch.""" locs = _gather_related_branches(branch) if len(locs.locs) > 0: - outfile.write('\n') - outfile.write('Related branches:\n') + outfile.write("\n") + outfile.write("Related branches:\n") outfile.writelines(locs.get_lines()) def _show_control_dir_info(control, outfile): """Show control dir information.""" if control._format.colocated_branches: - outfile.write('\n') - outfile.write('Control directory:\n') - outfile.write(f' {len(control.list_branches())} branches\n') + outfile.write("\n") + outfile.write("Control directory:\n") + outfile.write(f" {len(control.list_branches())} branches\n") -def _show_format_info(control=None, repository=None, branch=None, - working=None, outfile=None): +def _show_format_info( + control=None, repository=None, branch=None, working=None, outfile=None +): """Show known formats for control, working, branch and repository.""" - outfile.write('\n') - outfile.write('Format:\n') + outfile.write("\n") + outfile.write("Format:\n") if control: - outfile.write(f' control: {control._format.get_format_description()}\n') + outfile.write(f" control: {control._format.get_format_description()}\n") if working: - outfile.write(f' working tree: {working._format.get_format_description()}\n') + outfile.write(f" working tree: {working._format.get_format_description()}\n") if branch: - outfile.write(f' 
branch: {branch._format.get_format_description()}\n') + outfile.write(f" branch: {branch._format.get_format_description()}\n") if repository: - outfile.write(' repository: %s\n' % - repository._format.get_format_description()) + outfile.write( + " repository: %s\n" % repository._format.get_format_description() + ) -def _show_locking_info(repository=None, branch=None, working=None, - outfile=None): +def _show_locking_info(repository=None, branch=None, working=None, outfile=None): """Show locking status of working, branch and repository.""" - if (repository and repository.get_physical_lock_status() or - (branch and branch.get_physical_lock_status()) or - (working and working.get_physical_lock_status())): - outfile.write('\n') - outfile.write('Lock status:\n') + if ( + repository + and repository.get_physical_lock_status() + or (branch and branch.get_physical_lock_status()) + or (working and working.get_physical_lock_status()) + ): + outfile.write("\n") + outfile.write("Lock status:\n") if working: if working.get_physical_lock_status(): - status = 'locked' + status = "locked" else: - status = 'unlocked' - outfile.write(f' working tree: {status}\n') + status = "unlocked" + outfile.write(f" working tree: {status}\n") if branch: if branch.get_physical_lock_status(): - status = 'locked' + status = "locked" else: - status = 'unlocked' - outfile.write(f' branch: {status}\n') + status = "unlocked" + outfile.write(f" branch: {status}\n") if repository: if repository.get_physical_lock_status(): - status = 'locked' + status = "locked" else: - status = 'unlocked' - outfile.write(f' repository: {status}\n') + status = "unlocked" + outfile.write(f" repository: {status}\n") def _show_missing_revisions_branch(branch, outfile): @@ -233,10 +246,11 @@ def _show_missing_revisions_branch(branch, outfile): if master: local_extra, remote_extra = find_unmerged(branch, master) if remote_extra: - outfile.write('\n') - outfile.write(('Branch is out of date: missing %d ' - 'revision%s.\n') % (len(remote_extra), - plural(len(remote_extra)))) + outfile.write("\n") + outfile.write( + ("Branch is out of date: missing %d " "revision%s.\n") + % (len(remote_extra), plural(len(remote_extra))) + ) def _show_missing_revisions_working(working, outfile): @@ -254,9 +268,11 @@ def _show_missing_revisions_working(working, outfile): if branch_revno and tree_last_id != branch_last_revision: tree_last_revno = branch.revision_id_to_revno(tree_last_id) missing_count = branch_revno - tree_last_revno - outfile.write('\n') - outfile.write(('Working tree is out of date: missing %d ' - 'revision%s.\n') % (missing_count, plural(missing_count))) + outfile.write("\n") + outfile.write( + ("Working tree is out of date: missing %d " "revision%s.\n") + % (missing_count, plural(missing_count)) + ) def _show_working_stats(working, outfile): @@ -264,14 +280,14 @@ def _show_working_stats(working, outfile): basis = working.basis_tree() delta = working.changes_from(basis, want_unchanged=True) - outfile.write('\n') - outfile.write('In the working tree:\n') - outfile.write(' %8s unchanged\n' % len(delta.unchanged)) - outfile.write(f' {len(delta.modified):8} modified\n') - outfile.write(f' {len(delta.added):8} added\n') - outfile.write(f' {len(delta.removed):8} removed\n') - outfile.write(f' {len(delta.renamed):8} renamed\n') - outfile.write(f' {len(delta.copied):8} copied\n') + outfile.write("\n") + outfile.write("In the working tree:\n") + outfile.write(" %8s unchanged\n" % len(delta.unchanged)) + outfile.write(f" {len(delta.modified):8} modified\n") + 
outfile.write(f" {len(delta.added):8} added\n") + outfile.write(f" {len(delta.removed):8} removed\n") + outfile.write(f" {len(delta.renamed):8} renamed\n") + outfile.write(f" {len(delta.copied):8} copied\n") ignore_cnt = unknown_cnt = 0 for path in working.extras(): @@ -279,15 +295,17 @@ def _show_working_stats(working, outfile): ignore_cnt += 1 else: unknown_cnt += 1 - outfile.write(' %8d unknown\n' % unknown_cnt) - outfile.write(' %8d ignored\n' % ignore_cnt) + outfile.write(" %8d unknown\n" % unknown_cnt) + outfile.write(" %8d ignored\n" % ignore_cnt) dir_cnt = 0 for path, entry in working.iter_entries_by_dir(): - if entry.kind == 'directory' and path != '': + if entry.kind == "directory" and path != "": dir_cnt += 1 - outfile.write(' %8d versioned %s\n' % (dir_cnt, - plural(dir_cnt, 'subdirectory', 'subdirectories'))) + outfile.write( + " %8d versioned %s\n" + % (dir_cnt, plural(dir_cnt, "subdirectory", "subdirectories")) + ) def _show_branch_stats(branch, verbose, outfile): @@ -296,47 +314,49 @@ def _show_branch_stats(branch, verbose, outfile): revno, head = branch.last_revision_info() except errors.UnsupportedOperation: return {} - outfile.write('\n') - outfile.write('Branch history:\n') - outfile.write(' %8d revision%s\n' % (revno, plural(revno))) + outfile.write("\n") + outfile.write("Branch history:\n") + outfile.write(" %8d revision%s\n" % (revno, plural(revno))) stats = branch.repository.gather_stats(head, committers=verbose) if verbose: - committers = stats['committers'] - outfile.write(' %8d committer%s\n' % (committers, - plural(committers))) + committers = stats["committers"] + outfile.write(" %8d committer%s\n" % (committers, plural(committers))) if revno: - timestamp, timezone = stats['firstrev'] + timestamp, timezone = stats["firstrev"] age = int((time.time() - timestamp) / 3600 / 24) - outfile.write(' %8d day%s old\n' % (age, plural(age))) - outfile.write(' first revision: %s\n' % - osutils.format_date(timestamp, timezone)) - timestamp, timezone = stats['latestrev'] - outfile.write(' latest revision: %s\n' % - osutils.format_date(timestamp, timezone)) + outfile.write(" %8d day%s old\n" % (age, plural(age))) + outfile.write( + " first revision: %s\n" % osutils.format_date(timestamp, timezone) + ) + timestamp, timezone = stats["latestrev"] + outfile.write( + " latest revision: %s\n" % osutils.format_date(timestamp, timezone) + ) return stats def _show_repository_info(repository, outfile): """Show settings of a repository.""" if repository.make_working_trees(): - outfile.write('\n') - outfile.write('Create working tree for new branches inside ' - 'the repository.\n') + outfile.write("\n") + outfile.write( + "Create working tree for new branches inside " "the repository.\n" + ) def _show_repository_stats(repository, stats, outfile): """Show statistics about a repository.""" f = StringIO() - if 'revisions' in stats: - revisions = stats['revisions'] - f.write(' %8d revision%s\n' % (revisions, plural(revisions))) - if 'size' in stats: - f.write(' %8d KiB\n' % (stats['size'] / 1024)) - for hook in hooks['repository']: + if "revisions" in stats: + revisions = stats["revisions"] + f.write(" %8d revision%s\n" % (revisions, plural(revisions))) + if "size" in stats: + f.write(" %8d KiB\n" % (stats["size"] / 1024)) + for hook in hooks["repository"]: hook(repository, stats, f) if f.getvalue() != "": - outfile.write('\n') - outfile.write('Repository:\n') + outfile.write("\n") + outfile.write("Repository:\n") outfile.write(f.getvalue()) @@ -345,8 +365,7 @@ def 
show_bzrdir_info(a_controldir, verbose=False, outfile=None): if outfile is None: outfile = sys.stdout try: - tree = a_controldir.open_workingtree( - recommend_upgrade=False) + tree = a_controldir.open_workingtree(recommend_upgrade=False) except (NoWorkingTree, NotLocalUrl, NotBranchError): tree = None try: @@ -371,15 +390,15 @@ def show_bzrdir_info(a_controldir, verbose=False, outfile=None): if lockable is not None: lockable.lock_read() try: - show_component_info(a_controldir, repository, branch, tree, verbose, - outfile) + show_component_info(a_controldir, repository, branch, tree, verbose, outfile) finally: if lockable is not None: lockable.unlock() -def show_component_info(control, repository, branch=None, working=None, - verbose=1, outfile=None): +def show_component_info( + control, repository, branch=None, working=None, verbose=1, outfile=None +): """Write info about all bzrdir components to stdout.""" if outfile is None: outfile = sys.stdout @@ -391,9 +410,11 @@ def show_component_info(control, repository, branch=None, working=None, format = describe_format(control, repository, branch, working) outfile.write(f"{layout} (format: {format})\n") _show_location_info( - gather_location_info(control=control, repository=repository, - branch=branch, working=working), - outfile) + gather_location_info( + control=control, repository=repository, branch=branch, working=working + ), + outfile, + ) if branch is not None: _show_related_info(branch, outfile) if verbose == 0: @@ -435,19 +456,19 @@ def describe_layout(repository=None, branch=None, tree=None, control=None): if branch_reference is not None: return "Dangling branch reference" if repository is None: - return 'Empty control directory' + return "Empty control directory" if branch is None and tree is None: if repository.is_shared(): - phrase = 'Shared repository' + phrase = "Shared repository" else: - phrase = 'Unshared repository' + phrase = "Unshared repository" extra = [] if repository.make_working_trees(): - extra.append('trees') + extra.append("trees") if len(control.branch_names()) > 0: - extra.append('colocated branches') + extra.append("colocated branches") if extra: - phrase += ' with ' + " and ".join(extra) + phrase += " with " + " and ".join(extra) return phrase else: if repository.is_shared(): @@ -461,13 +482,15 @@ def describe_layout(repository=None, branch=None, tree=None, control=None): if branch is None and tree is not None: phrase = "branchless tree" else: - if (tree is not None and tree.controldir.control_url != - branch.controldir.control_url): - independence = '' + if ( + tree is not None + and tree.controldir.control_url != branch.controldir.control_url + ): + independence = "" phrase = "Lightweight checkout" elif branch.get_bound_location() is not None: - if independence == 'Standalone ': - independence = '' + if independence == "Standalone ": + independence = "" if tree is None: phrase = "Bound branch" else: @@ -486,8 +509,7 @@ def describe_format(control, repository, branch, tree): If no matching candidate is found, "unnamed" is returned. 
""" candidates = [] - if (branch is not None and tree is not None and - branch.user_url != tree.user_url): + if branch is not None and tree is not None and branch.user_url != tree.user_url: branch = None repository = None non_aliases = set(controldir.format_registry.keys()) @@ -495,29 +517,27 @@ def describe_format(control, repository, branch, tree): for key in non_aliases: format = controldir.format_registry.make_controldir(key) if isinstance(format, bzrdir.BzrDirMetaFormat1): - if (tree and format.workingtree_format != - tree._format): + if tree and format.workingtree_format != tree._format: continue - if (branch and format.get_branch_format() != - branch._format): + if branch and format.get_branch_format() != branch._format: continue - if (repository and format.repository_format != - repository._format): + if repository and format.repository_format != repository._format: continue if format.__class__ is not control._format.__class__: continue candidates.append(key) if len(candidates) == 0: - return 'unnamed' + return "unnamed" candidates.sort() - new_candidates = [c for c in candidates if not - controldir.format_registry.get_info(c).hidden] + new_candidates = [ + c for c in candidates if not controldir.format_registry.get_info(c).hidden + ] if len(new_candidates) > 0: # If there are any non-hidden formats that match, only return those to # avoid listing hidden formats except when only a hidden format will # do. candidates = new_candidates - return ' or '.join(candidates) + return " or ".join(candidates) class InfoHooks(_mod_hooks.Hooks): @@ -526,10 +546,12 @@ class InfoHooks(_mod_hooks.Hooks): def __init__(self): super().__init__("breezy.info", "hooks") self.add_hook( - 'repository', + "repository", "Invoked when displaying the statistics for a repository. " "repository is called with a statistics dictionary as returned " - "by the repository and a file-like object to write to.", (1, 15)) + "by the repository and a file-like object to write to.", + (1, 15), + ) hooks = InfoHooks() diff --git a/breezy/inter.py b/breezy/inter.py index b84a5abeff..dcdeb2174c 100644 --- a/breezy/inter.py +++ b/breezy/inter.py @@ -24,16 +24,17 @@ class NoCompatibleInter(BzrError): - - _fmt = ('No compatible object available for operations from %(source)r ' - 'to %(target)r.') + _fmt = ( + "No compatible object available for operations from %(source)r " + "to %(target)r." + ) def __init__(self, source, target): self.source = source self.target = target -T = TypeVar('T') +T = TypeVar("T") class InterObject(Generic[T]): diff --git a/breezy/lazy_import.py b/breezy/lazy_import.py index 242b323977..638333a976 100644 --- a/breezy/lazy_import.py +++ b/breezy/lazy_import.py @@ -47,9 +47,9 @@ class ImportNameCollision(InternalBzrError): - - _fmt = ("Tried to import an object to the same name as" - " an existing object. %(name)s") + _fmt = ( + "Tried to import an object to the same name as" " an existing object. 
%(name)s" + ) def __init__(self, name): BzrError.__init__(self) @@ -57,22 +57,19 @@ def __init__(self, name): class IllegalUseOfScopeReplacer(InternalBzrError): - - _fmt = ("ScopeReplacer object %(name)r was used incorrectly:" - " %(msg)s%(extra)s") + _fmt = "ScopeReplacer object %(name)r was used incorrectly:" " %(msg)s%(extra)s" def __init__(self, name, msg, extra=None): BzrError.__init__(self) self.name = name self.msg = msg if extra: - self.extra = ': ' + str(extra) + self.extra = ": " + str(extra) else: - self.extra = '' + self.extra = "" class InvalidImportLine(InternalBzrError): - _fmt = "Not a valid import statement: %(msg)\n%(text)s" def __init__(self, text, msg): @@ -88,7 +85,7 @@ class ScopeReplacer: needed. """ - __slots__ = ('_scope', '_factory', '_name', '_real_obj') + __slots__ = ("_scope", "_factory", "_name", "_real_obj") # If you to do x = y, setting this to False will disallow access to # members from the second variable (i.e. x). This should normally @@ -105,53 +102,57 @@ def __init__(self, scope, factory, name): It will be passed (self, scope, name) :param name: The variable name in the given scope. """ - object.__setattr__(self, '_scope', scope) - object.__setattr__(self, '_factory', factory) - object.__setattr__(self, '_name', name) - object.__setattr__(self, '_real_obj', None) + object.__setattr__(self, "_scope", scope) + object.__setattr__(self, "_factory", factory) + object.__setattr__(self, "_name", name) + object.__setattr__(self, "_real_obj", None) scope[name] = self def _resolve(self): """Return the real object for which this is a placeholder.""" - name = object.__getattribute__(self, '_name') - real_obj = object.__getattribute__(self, '_real_obj') + name = object.__getattribute__(self, "_name") + real_obj = object.__getattribute__(self, "_real_obj") if real_obj is None: # No obj generated previously, so generate from factory and scope. - factory = object.__getattribute__(self, '_factory') - scope = object.__getattribute__(self, '_scope') + factory = object.__getattribute__(self, "_factory") + scope = object.__getattribute__(self, "_scope") obj = factory(self, scope, name) if obj is self: raise IllegalUseOfScopeReplacer( - name, msg="Object tried" - " to replace itself, check it's not using its own scope.") + name, + msg="Object tried" + " to replace itself, check it's not using its own scope.", + ) # Check if another thread has jumped in while obj was generated. - real_obj = object.__getattribute__(self, '_real_obj') + real_obj = object.__getattribute__(self, "_real_obj") if real_obj is None: # Still no prexisting obj, so go ahead and assign to scope and # return. There is still a small window here where races will # not be detected, but safest to avoid additional locking. - object.__setattr__(self, '_real_obj', obj) + object.__setattr__(self, "_real_obj", obj) scope[name] = obj return obj # Raise if proxying is disabled as obj has already been generated. 
if not ScopeReplacer._should_proxy: raise IllegalUseOfScopeReplacer( - name, msg="Object already replaced, did you assign it" - " to another variable?") + name, + msg="Object already replaced, did you assign it" + " to another variable?", + ) return real_obj def __getattribute__(self, attr): - obj = object.__getattribute__(self, '_resolve')() + obj = object.__getattribute__(self, "_resolve")() return getattr(obj, attr) def __setattr__(self, attr, value): - obj = object.__getattribute__(self, '_resolve')() + obj = object.__getattribute__(self, "_resolve")() return setattr(obj, attr, value) def __call__(self, *args, **kwargs): - obj = object.__getattribute__(self, '_resolve')() + obj = object.__getattribute__(self, "_resolve")() return obj(*args, **kwargs) @@ -186,7 +187,7 @@ class ImportReplacer(ScopeReplacer): # We can't just use 'isinstance(obj, ImportReplacer)', because that # accesses .__class__, which goes through __getattribute__, and triggers # the replacement. - __slots__ = ('_import_replacer_children', '_member', '_module_path') + __slots__ = ("_import_replacer_children", "_member", "_module_path") def __init__(self, scope, name, module_path, member=None, children=None): """Upon request import 'module_path' as the name 'module_name'. @@ -222,23 +223,22 @@ def __init__(self, scope, name, module_path, member=None, children=None): if children is None: children = {} if (member is not None) and children: - raise ValueError('Cannot supply both a member and children') + raise ValueError("Cannot supply both a member and children") - object.__setattr__(self, '_import_replacer_children', children) - object.__setattr__(self, '_member', member) - object.__setattr__(self, '_module_path', module_path) + object.__setattr__(self, "_import_replacer_children", children) + object.__setattr__(self, "_member", member) + object.__setattr__(self, "_module_path", module_path) # Indirecting through __class__ so that children can # override _import (especially our instrumented version) - cls = object.__getattribute__(self, '__class__') - ScopeReplacer.__init__(self, scope=scope, name=name, - factory=cls._import) + cls = object.__getattribute__(self, "__class__") + ScopeReplacer.__init__(self, scope=scope, name=name, factory=cls._import) def _import(self, scope, name): - children = object.__getattribute__(self, '_import_replacer_children') - member = object.__getattribute__(self, '_member') - module_path = object.__getattribute__(self, '_module_path') - name = '.'.join(module_path) + children = object.__getattribute__(self, "_import_replacer_children") + member = object.__getattribute__(self, "_member") + module_path = object.__getattribute__(self, "_module_path") + name = ".".join(module_path) if member is not None: module = _builtin_import(name, scope, scope, [member], level=0) return getattr(module, member) @@ -248,14 +248,17 @@ def _import(self, scope, name): module = getattr(module, path) # Prepare the children to be imported - for child_name, (child_path, child_member, grandchildren) in \ - children.items(): + for child_name, (child_path, child_member, grandchildren) in children.items(): # Using self.__class__, so that children get children classes # instantiated. 
(This helps with instrumented tests) - cls = object.__getattribute__(self, '__class__') - cls(module.__dict__, name=child_name, - module_path=child_path, member=child_member, - children=grandchildren) + cls = object.__getattribute__(self, "__class__") + cls( + module.__dict__, + name=child_name, + module_path=child_path, + member=child_member, + children=grandchildren, + ) return module @@ -268,7 +271,7 @@ class ImportProcessor: # For now, it should be supporting a superset of python import # syntax which is all we really care about. - __slots__ = ['imports', '_lazy_import_class'] + __slots__ = ["imports", "_lazy_import_class"] def __init__(self, lazy_import_class=None) -> None: self.imports: Dict[str, Any] = {} @@ -289,19 +292,19 @@ def lazy_import(self, scope, text): def _convert_imports(self, scope): # Now convert the map into a set of imports for name, info in self.imports.items(): - self._lazy_import_class(scope, name=name, module_path=info[0], - member=info[1], children=info[2]) + self._lazy_import_class( + scope, name=name, module_path=info[0], member=info[1], children=info[2] + ) def _build_map(self, text): """Take a string describing imports, and build up the internal map.""" for line in self._canonicalize_import_text(text): - if line.startswith('import '): + if line.startswith("import "): self._convert_import_str(line) - elif line.startswith('from '): + elif line.startswith("from "): self._convert_from_str(line) else: - raise InvalidImportLine( - line, "doesn't start with 'import ' or 'from '") + raise InvalidImportLine(line, "doesn't start with 'import ' or 'from '") def _convert_import_str(self, import_str): """This converts a import string into an import map. @@ -310,21 +313,21 @@ def _convert_import_str(self, import_str): :param import_str: The import string to process """ - if not import_str.startswith('import '): - raise ValueError(f'bad import string {import_str!r}') - import_str = import_str[len('import '):] + if not import_str.startswith("import "): + raise ValueError(f"bad import string {import_str!r}") + import_str = import_str[len("import ") :] - for path in import_str.split(','): + for path in import_str.split(","): path = path.strip() if not path: continue - as_hunks = path.split(' as ') + as_hunks = path.split(" as ") if len(as_hunks) == 2: # We have 'as' so this is a different style of import # 'import foo.bar.baz as bing' creates a local variable # named 'bing' which points to 'foo.bar.baz' name = as_hunks[1].strip() - module_path = as_hunks[0].strip().split('.') + module_path = as_hunks[0].strip().split(".") if name in self.imports: raise ImportNameCollision(name) if not module_path[0]: @@ -333,7 +336,7 @@ def _convert_import_str(self, import_str): self.imports[name] = (module_path, None, {}) else: # Now we need to handle - module_path = path.split('.') + module_path = path.split(".") name = module_path[0] if not name: raise ImportError(path) @@ -360,22 +363,22 @@ def _convert_from_str(self, from_str): :param from_str: The import string to process """ - if not from_str.startswith('from '): - raise ValueError(f'bad from/import {from_str!r}') - from_str = from_str[len('from '):] + if not from_str.startswith("from "): + raise ValueError(f"bad from/import {from_str!r}") + from_str = from_str[len("from ") :] - from_module, import_list = from_str.split(' import ') + from_module, import_list = from_str.split(" import ") - from_module_path = from_module.split('.') + from_module_path = from_module.split(".") if not from_module_path[0]: raise ImportError(from_module) - 
for path in import_list.split(','): + for path in import_list.split(","): path = path.strip() if not path: continue - as_hunks = path.split(' as ') + as_hunks = path.split(" as ") if len(as_hunks) == 2: # We have 'as' so this is a different style of import # 'import foo.bar.baz as bing' creates a local variable @@ -397,27 +400,27 @@ def _canonicalize_import_text(self, text): out = [] cur = None - for line in text.split('\n'): + for line in text.split("\n"): line = line.strip() - loc = line.find('#') + loc = line.find("#") if loc != -1: line = line[:loc].strip() if not line: continue if cur is not None: - if line.endswith(')'): - out.append(cur + ' ' + line[:-1]) + if line.endswith(")"): + out.append(cur + " " + line[:-1]) cur = None else: - cur += ' ' + line + cur += " " + line else: - if '(' in line and ')' not in line: - cur = line.replace('(', '') + if "(" in line and ")" not in line: + cur = line.replace("(", "") else: - out.append(line.replace('(', '').replace(')', '')) + out.append(line.replace("(", "").replace(")", "")) if cur is not None: - raise InvalidImportLine(cur, 'Unmatched parenthesis') + raise InvalidImportLine(cur, "Unmatched parenthesis") return out diff --git a/breezy/lazy_regex.py b/breezy/lazy_regex.py index a292d8da0a..dd5fd185ed 100644 --- a/breezy/lazy_regex.py +++ b/breezy/lazy_regex.py @@ -30,8 +30,7 @@ class InvalidPattern(errors.BzrError): - - _fmt = ('Invalid pattern(s) found. %(msg)s') + _fmt = "Invalid pattern(s) found. %(msg)s" def __init__(self, msg): self.msg = msg @@ -43,14 +42,25 @@ class LazyRegex: # These are the parameters on a real _sre.SRE_Pattern object, which we # will map to local members so that we don't have the proxy overhead. _regex_attributes_to_copy = [ - '__copy__', '__deepcopy__', 'findall', 'finditer', 'match', - 'scanner', 'search', 'split', 'sub', 'subn' - ] + "__copy__", + "__deepcopy__", + "findall", + "finditer", + "match", + "scanner", + "search", + "split", + "sub", + "subn", + ] # We use slots to keep the overhead low. But we need a slot entry for # all of the attributes we will copy - __slots__ = ['_real_regex', '_regex_args', '_regex_kwargs', - ] + _regex_attributes_to_copy + __slots__ = [ + "_real_regex", + "_regex_args", + "_regex_kwargs", + ] + _regex_attributes_to_copy def __init__(self, args, kwargs): """Create a new proxy object, passing in the args to pass to re.compile. @@ -64,8 +74,9 @@ def __init__(self, args, kwargs): def _compile_and_collapse(self): """Actually compile the requested regex.""" - self._real_regex = self._real_re_compile(*self._regex_args, - **self._regex_kwargs) + self._real_regex = self._real_re_compile( + *self._regex_args, **self._regex_kwargs + ) for attr in self._regex_attributes_to_copy: setattr(self, attr, getattr(self._real_regex, attr)) @@ -83,13 +94,13 @@ def __getstate__(self): return { "args": self._regex_args, "kwargs": self._regex_kwargs, - } + } def __setstate__(self, dict): """Restore from a pickled state.""" self._real_regex = None - self._regex_args = dict['args'] - self._regex_kwargs = dict['kwargs'] + self._regex_args = dict["args"] + self._regex_kwargs = dict["kwargs"] def __getattr__(self, attr): """Return a member from the proxied regex object. 
diff --git a/breezy/library_state.py b/breezy/library_state.py index 32ca99b246..54fe1a5fcb 100644 --- a/breezy/library_state.py +++ b/breezy/library_state.py @@ -17,8 +17,8 @@ """The core state needed to make use of bzr is managed here.""" __all__ = [ - 'BzrLibraryState', - ] + "BzrLibraryState", +] import contextlib @@ -27,7 +27,9 @@ from .lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ from breezy import ( config, osutils, @@ -35,7 +37,8 @@ trace, ui, ) -""") +""", +) class BzrLibraryState: @@ -92,9 +95,10 @@ def _start(self): # TestRunBzrSubprocess may fail. self.exit_stack = contextlib.ExitStack() - if breezy.version_info[3] == 'final': + if breezy.version_info[3] == "final": self.exit_stack.callback( - symbol_versioning.suppress_deprecation_warnings(override=True)) + symbol_versioning.suppress_deprecation_warnings(override=True) + ) self._trace.__enter__() diff --git a/breezy/location.py b/breezy/location.py index 8fd7b35070..db092fa1e5 100644 --- a/breezy/location.py +++ b/breezy/location.py @@ -27,13 +27,17 @@ class LocationHooks(Hooks): def __init__(self): Hooks.__init__(self, "breezy.location", "hooks") self.add_hook( - 'rewrite_url', + "rewrite_url", "Possibly rewrite a URL. Called with a URL to rewrite and the " - "purpose of the URL.", (3, 0)) + "purpose of the URL.", + (3, 0), + ) self.add_hook( - 'rewrite_location', + "rewrite_location", "Possibly rewrite a location. Called with a location string to " - "rewrite and the purpose of the URL.", (3, 2)) + "rewrite and the purpose of the URL.", + (3, 2), + ) hooks = LocationHooks() @@ -58,25 +62,27 @@ def location_to_url(location, purpose=None): if not isinstance(location, str): raise AssertionError("location not a byte or unicode string") - for hook in hooks['rewrite_location']: + for hook in hooks["rewrite_location"]: location = hook(location, purpose=purpose) - if location.startswith(':pserver:') or location.startswith(':extssh:'): + if location.startswith(":pserver:") or location.startswith(":extssh:"): return cvs_to_url(location) from .directory_service import directories + location = directories.dereference(location, purpose) # Catch any URLs which are passing Unicode rather than ASCII try: - location = location.encode('ascii') + location = location.encode("ascii") except UnicodeError as err: if urlutils.is_url(location): raise urlutils.InvalidURL( - path=location, extra='URLs must be properly escaped') from err + path=location, extra="URLs must be properly escaped" + ) from err location = urlutils.local_path_to_url(location) else: - location = location.decode('ascii') + location = location.decode("ascii") if location.startswith("file:") and not location.startswith("file://"): return urlutils.join(urlutils.local_path_to_url("."), location[5:]) @@ -91,7 +97,7 @@ def location_to_url(location, purpose=None): if not urlutils.is_url(location): return urlutils.local_path_to_url(location) - for hook in hooks['rewrite_url']: + for hook in hooks["rewrite_url"]: location = hook(location, purpose=purpose) return location diff --git a/breezy/lock.py b/breezy/lock.py index 24279d6c5c..1c98ee9822 100644 --- a/breezy/lock.py +++ b/breezy/lock.py @@ -44,33 +44,36 @@ have_fcntl = True + def ReadLock(path): - return _transport_rs.ReadLock(path, 'strict_locks' in debug.debug_flags) + return _transport_rs.ReadLock(path, "strict_locks" in debug.debug_flags) -def WriteLock(path): - return _transport_rs.WriteLock(path, 'strict_locks' in debug.debug_flags) +def WriteLock(path): + return 
_transport_rs.WriteLock(path, "strict_locks" in debug.debug_flags) LockToken = bytes class LockHooks(Hooks): - def __init__(self): Hooks.__init__(self, "breezy.lock", "Lock.hooks") self.add_hook( - 'lock_acquired', - "Called with a breezy.lock.LockResult when a physical lock is " - "acquired.", (1, 8)) + "lock_acquired", + "Called with a breezy.lock.LockResult when a physical lock is " "acquired.", + (1, 8), + ) self.add_hook( - 'lock_released', - "Called with a breezy.lock.LockResult when a physical lock is " - "released.", (1, 8)) + "lock_released", + "Called with a breezy.lock.LockResult when a physical lock is " "released.", + (1, 8), + ) self.add_hook( - 'lock_broken', - "Called with a breezy.lock.LockResult when a physical lock is " - "broken.", (1, 15)) + "lock_broken", + "Called with a breezy.lock.LockResult when a physical lock is " "broken.", + (1, 15), + ) class Lock: @@ -81,24 +84,34 @@ class Lock: hooks = LockHooks() - def __init__(self, transport: Transport, path: str, file_modebits: int, - dir_modebits: int) -> None: ... + def __init__( + self, transport: Transport, path: str, file_modebits: int, dir_modebits: int + ) -> None: + ... - def create(self, mode: int): ... + def create(self, mode: int): + ... - def break_lock(self) -> None: ... + def break_lock(self) -> None: + ... - def leave_in_place(self) -> None: ... + def leave_in_place(self) -> None: + ... - def dont_leave_in_place(self) -> None: ... + def dont_leave_in_place(self) -> None: + ... - def validate_token(self, token: Optional[LockToken]) -> None: ... + def validate_token(self, token: Optional[LockToken]) -> None: + ... - def lock_write(self, token: Optional[LockToken]) -> Optional[LockToken]: ... + def lock_write(self, token: Optional[LockToken]) -> Optional[LockToken]: + ... - def lock_read(self) -> None: ... + def lock_read(self) -> None: + ... - def unlock(self) -> None: ... + def unlock(self) -> None: + ... def peek(self) -> LockToken: raise NotImplementedError(self.peek) @@ -116,7 +129,7 @@ def __eq__(self, other): return self.lock_url == other.lock_url and self.details == other.details def __repr__(self): - return f'{self.__class__.__name__}({self.lock_url}, {self.details})' + return f"{self.__class__.__name__}({self.lock_url}, {self.details})" class LogicalLockResult: @@ -155,9 +168,8 @@ def cant_unlock_not_held(locked_object): # block, so it's useful to have the option not to generate a new error # here. You can use -Werror to make it fatal. It should possibly also # raise LockNotHeld. 
- if debug.debug_flag_enabled('unlock'): - warnings.warn(f"{locked_object!r} is already unlocked", - stacklevel=3) + if debug.debug_flag_enabled("unlock"): + warnings.warn(f"{locked_object!r} is already unlocked", stacklevel=3) else: raise errors.LockNotHeld(locked_object) @@ -174,12 +186,12 @@ class _RelockDebugMixin: _prev_lock = None def _note_lock(self, lock_type): - if debug.debug_flag_enabled('relock') and self._prev_lock == lock_type: - if lock_type == 'r': - type_name = 'read' + if debug.debug_flag_enabled("relock") and self._prev_lock == lock_type: + if lock_type == "r": + type_name = "read" else: - type_name = 'write' - trace.note(gettext('{0!r} was {1} locked again'), self, type_name) + type_name = "write" + trace.note(gettext("{0!r} was {1} locked again"), self, type_name) self._prev_lock = lock_type @@ -192,4 +204,4 @@ def write_locked(lockable): lockable.unlock() -_lock_classes = [('default', WriteLock, ReadLock)] +_lock_classes = [("default", WriteLock, ReadLock)] diff --git a/breezy/lockdir.py b/breezy/lockdir.py index b4ff2e5bfa..b3761787eb 100644 --- a/breezy/lockdir.py +++ b/breezy/lockdir.py @@ -151,10 +151,16 @@ class LockDir(lock.Lock): """Write-lock guarding access to data.""" - __INFO_NAME = '/info' - - def __init__(self, transport, path, file_modebits=0o644, - dir_modebits=0o755, extra_holder_info=None): + __INFO_NAME = "/info" + + def __init__( + self, + transport, + path, + file_modebits=0o644, + dir_modebits=0o755, + extra_holder_info=None, + ): """Create a new LockDir object. The LockDir is initially unlocked - this just creates the object. @@ -173,7 +179,7 @@ def __init__(self, transport, path, file_modebits=0o644, self._lock_held = False self._locked_via_token = False self._fake_read_lock = False - self._held_dir = path + '/held' + self._held_dir = path + "/held" self._held_info_path = self._held_dir + self.__INFO_NAME self._file_modebits = file_modebits self._dir_modebits = dir_modebits @@ -182,7 +188,7 @@ def __init__(self, transport, path, file_modebits=0o644, self._warned_about_lock_holder = None def __repr__(self): - return f'{self.__class__.__name__}({self.transport.base}{self.path})' + return f"{self.__class__.__name__}({self.transport.base}{self.path})" is_held = property(lambda self: self._lock_held) @@ -224,8 +230,13 @@ def _attempt_lock(self): try: self.transport.rename(tmpname, self._held_dir) break - except (errors.TransportError, PathError, DirectoryNotEmpty, - FileExists, ResourceBusy) as e: + except ( + errors.TransportError, + PathError, + DirectoryNotEmpty, + FileExists, + ResourceBusy, + ) as e: self._trace("... contention, %s", e) other_holder = self.peek() self._trace(f"other holder is {other_holder!r}") @@ -252,15 +263,14 @@ def _attempt_lock(self): info = self.peek() self._trace("after locking, info=%r", info) if info is None: - raise LockFailed(self, "lock was renamed into place, but " - "now is missing!") + raise LockFailed( + self, "lock was renamed into place, but " "now is missing!" + ) if info.nonce != self.nonce: - self._trace("rename succeeded, " - "but lock is still held by someone else") + self._trace("rename succeeded, " "but lock is still held by someone else") raise LockContention(self) self._lock_held = True - self._trace("... lock succeeded after %dms", - (time.time() - start_time) * 1000) + self._trace("... 
lock succeeded after %dms", (time.time() - start_time) * 1000) return self.nonce def _handle_lock_contention(self, other_holder): @@ -276,13 +286,14 @@ def _handle_lock_contention(self, other_holder): it might be None if the lock can be seen to be held but the info can't be read. """ - if (other_holder is not None): + if other_holder is not None: if other_holder.is_lock_holder_known_dead(): - if self.get_config().get('locks.steal_dead'): + if self.get_config().get("locks.steal_dead"): ui.ui_factory.show_user_warning( - 'locks_steal_dead', + "locks_steal_dead", lock_url=urlutils.join(self.transport.base, self.path), - other_holder_info=str(other_holder)) + other_holder_info=str(other_holder), + ) self.force_break(other_holder) self._trace("stole lock from dead holder") return @@ -302,7 +313,7 @@ def _remove_pending_dir(self, tmpname): note(gettext("error removing pending lock: %s"), e) def _create_pending_dir(self): - tmpname = f'{self.path}/{rand_chars(10)}.tmp' + tmpname = f"{self.path}/{rand_chars(10)}.tmp" try: self.transport.mkdir(tmpname) except NoSuchFile: @@ -319,8 +330,7 @@ def _create_pending_dir(self): # directory so we don't have to worry about files existing there. # We'll rename the whole directory into place to get atomic # properties - self.transport.put_bytes_non_atomic(tmpname + self.__INFO_NAME, - info.to_bytes()) + self.transport.put_bytes_non_atomic(tmpname + self.__INFO_NAME, info.to_bytes()) return tmpname @only_raises(LockNotHeld, LockBroken) @@ -340,7 +350,7 @@ def unlock(self): # whole tree start_time = time.time() self._trace("unlocking") - tmpname = f'{self.path}/releasing.{rand_chars(20)}.tmp' + tmpname = f"{self.path}/releasing.{rand_chars(20)}.tmp" # gotta own it to unlock self.confirm() self.transport.rename(self._held_dir, tmpname) @@ -353,14 +363,15 @@ def unlock(self): # another locker within the 'held' directory. do a slower # deletion where we list the directory and remove everything # within it. - self._trace("doing recursive deletion of non-empty directory " - "%s", tmpname) + self._trace( + "doing recursive deletion of non-empty directory " "%s", tmpname + ) self.transport.delete_tree(tmpname) - self._trace("... unlock succeeded after %dms", - (time.time() - start_time) * 1000) - result = lock.LockResult(self.transport.abspath(self.path), - old_nonce) - for hook in self.hooks['lock_released']: + self._trace( + "... unlock succeeded after %dms", (time.time() - start_time) * 1000 + ) + result = lock.LockResult(self.transport.abspath(self.path), old_nonce) + for hook in self.hooks["lock_released"]: hook(result) def break_lock(self): @@ -384,11 +395,11 @@ def break_lock(self): if holder_info is not None: if ui.ui_factory.confirm_action( "Break %(lock_info)s", - 'breezy.lockdir.break', - {'lock_info': str(holder_info)}): + "breezy.lockdir.break", + {"lock_info": str(holder_info)}, + ): result = self.force_break(holder_info) - ui.ui_factory.show_message( - f"Broke lock {result.lock_url}") + ui.ui_factory.show_message(f"Broke lock {result.lock_url}") def force_break(self, dead_holder_info): """Release a lock held by another process. 
@@ -418,7 +429,7 @@ def force_break(self, dead_holder_info): return if current_info != dead_holder_info: raise LockBreakMismatch(self, current_info, dead_holder_info) - tmpname = f'{self.path}/broken.{rand_chars(20)}.tmp' + tmpname = f"{self.path}/broken.{rand_chars(20)}.tmp" self.transport.rename(self._held_dir, tmpname) # check that we actually broke the right lock, not someone else; # there's a small race window between checking it and doing the @@ -429,9 +440,8 @@ def force_break(self, dead_holder_info): raise LockBreakMismatch(self, broken_info, dead_holder_info) self.transport.delete(broken_info_path) self.transport.rmdir(tmpname) - result = lock.LockResult(self.transport.abspath(self.path), - current_info.nonce) - for hook in self.hooks['lock_broken']: + result = lock.LockResult(self.transport.abspath(self.path), current_info.nonce) + for hook in self.hooks["lock_broken"]: hook(result) return result @@ -448,7 +458,7 @@ def force_break_corrupt(self, corrupt_info_content): # XXX: this copes with unparseable info files, but what about missing # info files? Or missing lock dirs? self._check_not_locked() - tmpname = f'{self.path}/broken.{rand_chars(20)}.tmp' + tmpname = f"{self.path}/broken.{rand_chars(20)}.tmp" self.transport.rename(self._held_dir, tmpname) # check that we actually broke the right lock, not someone else; # there's a small race window between checking it and doing the @@ -460,7 +470,7 @@ def force_break_corrupt(self, corrupt_info_content): self.transport.delete(broken_info_path) self.transport.rmdir(tmpname) result = lock.LockResult(self.transport.abspath(self.path)) - for hook in self.hooks['lock_broken']: + for hook in self.hooks["lock_broken"]: hook(result) def _check_not_locked(self): @@ -493,8 +503,7 @@ def _read_info_file(self, path): peek() reads the info file of the lock holder, if any. """ - return LockHeldInfo.from_info_file_bytes( - self.transport.get_bytes(path)) + return LockHeldInfo.from_info_file_bytes(self.transport.get_bytes(path)) def peek(self): """Check if the lock is held by anyone. @@ -525,9 +534,8 @@ def attempt_lock(self): if self._fake_read_lock: raise LockContention(self) result = self._attempt_lock() - hook_result = lock.LockResult(self.transport.abspath(self.path), - self.nonce) - for hook in self.hooks['lock_acquired']: + hook_result = lock.LockResult(self.transport.abspath(self.path), self.nonce) + for hook in self.hooks["lock_acquired"]: hook(hook_result) return result @@ -536,10 +544,10 @@ def lock_url_for_display(self): # As local lock urls are correct we display them. # We avoid displaying remote lock urls. 
lock_url = self.transport.abspath(self.path) - if lock_url.startswith('file://'): - lock_url = lock_url.split('.bzr/')[0] + if lock_url.startswith("file://"): + lock_url = lock_url.split(".bzr/")[0] else: - lock_url = '' + lock_url = "" return lock_url def wait_lock(self, timeout=None, poll=None, max_attempts=None): @@ -588,20 +596,22 @@ def wait_lock(self, timeout=None, poll=None, max_attempts=None): new_info = self.peek() if new_info is not None and new_info != last_info: if last_info is None: - start = gettext('Unable to obtain') + start = gettext("Unable to obtain") else: - start = gettext('Lock owner changed for') + start = gettext("Lock owner changed for") last_info = new_info - msg = gettext('{0} lock {1} {2}.').format(start, lock_url, - new_info) + msg = gettext("{0} lock {1} {2}.").format(start, lock_url, new_info) if deadline_str is None: - deadline_str = time.strftime('%H:%M:%S', - time.localtime(deadline)) + deadline_str = time.strftime("%H:%M:%S", time.localtime(deadline)) if timeout > 0: - msg += '\n' + gettext( - 'Will continue to try until %s, unless ' - 'you press Ctrl-C.') % deadline_str - msg += '\n' + gettext('See "brz help break-lock" for more.') + msg += ( + "\n" + + gettext( + "Will continue to try until %s, unless " "you press Ctrl-C." + ) + % deadline_str + ) + msg += "\n" + gettext('See "brz help break-lock" for more.') self._report_function(msg) if (max_attempts is not None) and (attempt_count >= max_attempts): self._trace("exceeded %d attempts") @@ -614,7 +624,7 @@ def wait_lock(self, timeout=None, poll=None, max_attempts=None): # this block is applicable only for local # lock contention self._trace("timeout after waiting %ss", timeout) - raise LockContention('(local)', lock_url) + raise LockContention("(local)", lock_url) def leave_in_place(self): self._locked_via_token = True @@ -677,7 +687,7 @@ def validate_token(self, token): self._trace("revalidated by token %r", token) def _trace(self, format, *args): - if not debug.debug_flag_enabled('lock'): + if not debug.debug_flag_enabled("lock"): return mutter(str(self) + ": " + (format % args)) diff --git a/breezy/log.py b/breezy/log.py index 20068887cc..0c226d4de7 100644 --- a/breezy/log.py +++ b/breezy/log.py @@ -57,7 +57,9 @@ from .lazy_import import lazy_import -lazy_import(globals(), """ +lazy_import( + globals(), + """ from breezy import ( config, @@ -67,7 +69,8 @@ lazy_regex, ) from breezy.i18n import gettext, ngettext -""") +""", +) from . import errors, registry, revisionspec, trace from . import revision as _mod_revision @@ -116,7 +119,7 @@ def find_touching_revisions(repository, last_revision, last_tree, last_path): this_verifier = this_tree.get_file_verifier(this_path) else: this_verifier = this_tree.get_file_verifier(this_path) - if (this_verifier != last_verifier): + if this_verifier != last_verifier: yield revno, revision_id, "modified " + this_path last_verifier = this_verifier @@ -127,15 +130,17 @@ def find_touching_revisions(repository, last_revision, last_tree, last_path): revno -= 1 -def show_log(branch, - lf, - verbose=False, - direction='reverse', - start_revision=None, - end_revision=None, - limit=None, - show_diff=False, - match=None): +def show_log( + branch, + lf, + verbose=False, + direction="reverse", + start_revision=None, + end_revision=None, + limit=None, + show_diff=False, + match=None, +): """Write out human-readable log of commits to this branch. This function is being retained for backwards compatibility but @@ -163,11 +168,11 @@ def show_log(branch, properties. 
""" if verbose: - delta_type = 'full' + delta_type = "full" else: delta_type = None if show_diff: - diff_type = 'full' + diff_type = "full" else: diff_type = None @@ -189,30 +194,43 @@ def show_log(branch, # Build the request and execute it rqst = make_log_request_dict( direction=direction, - start_revision=start_revision, end_revision=end_revision, - limit=limit, delta_type=delta_type, diff_type=diff_type) + start_revision=start_revision, + end_revision=end_revision, + limit=limit, + delta_type=delta_type, + diff_type=diff_type, + ) Logger(branch, rqst).show(lf) # Note: This needs to be kept in sync with the defaults in # make_log_request_dict() below _DEFAULT_REQUEST_PARAMS = { - 'direction': 'reverse', - 'levels': None, - 'generate_tags': True, - 'exclude_common_ancestry': False, - '_match_using_deltas': True, - } - - -def make_log_request_dict(direction='reverse', specific_files=None, - start_revision=None, end_revision=None, limit=None, - message_search=None, levels=None, generate_tags=True, - delta_type=None, - diff_type=None, _match_using_deltas=True, - exclude_common_ancestry=False, match=None, - signature=False, omit_merges=False, - ): + "direction": "reverse", + "levels": None, + "generate_tags": True, + "exclude_common_ancestry": False, + "_match_using_deltas": True, +} + + +def make_log_request_dict( + direction="reverse", + specific_files=None, + start_revision=None, + end_revision=None, + limit=None, + message_search=None, + levels=None, + generate_tags=True, + delta_type=None, + diff_type=None, + _match_using_deltas=True, + exclude_common_ancestry=False, + match=None, + signature=False, + omit_merges=False, +): """Convenience function for making a logging request dictionary. Using this function may make code slightly safer by ensuring @@ -274,28 +292,28 @@ def make_log_request_dict(direction='reverse', specific_files=None, # Take care of old style message_search parameter if message_search: if match: - if 'message' in match: - match['message'].append(message_search) + if "message" in match: + match["message"].append(message_search) else: - match['message'] = [message_search] + match["message"] = [message_search] else: - match = {'message': [message_search]} + match = {"message": [message_search]} return { - 'direction': direction, - 'specific_files': specific_files, - 'start_revision': start_revision, - 'end_revision': end_revision, - 'limit': limit, - 'levels': levels, - 'generate_tags': generate_tags, - 'delta_type': delta_type, - 'diff_type': diff_type, - 'exclude_common_ancestry': exclude_common_ancestry, - 'signature': signature, - 'match': match, - 'omit_merges': omit_merges, + "direction": direction, + "specific_files": specific_files, + "start_revision": start_revision, + "end_revision": end_revision, + "limit": limit, + "levels": levels, + "generate_tags": generate_tags, + "delta_type": delta_type, + "diff_type": diff_type, + "exclude_common_ancestry": exclude_common_ancestry, + "signature": signature, + "match": match, + "omit_merges": omit_merges, # Add 'private' attributes for features that may be deprecated - '_match_using_deltas': _match_using_deltas, + "_match_using_deltas": _match_using_deltas, } @@ -361,10 +379,10 @@ def show(self, lf): warn(f"not a LogFormatter instance: {lf!r}", stacklevel=1) with self.branch.lock_read(): - if getattr(lf, 'begin_log', None): + if getattr(lf, "begin_log", None): lf.begin_log() self._show_body(lf) - if getattr(lf, 'end_log', None): + if getattr(lf, "end_log", None): lf.end_log() def _show_body(self, lf): @@ -375,18 +393,18 
@@ def _show_body(self, lf): # Tweak the LogRequest based on what the LogFormatter can handle. # (There's no point generating stuff if the formatter can't display it.) rqst = self.rqst - if rqst['levels'] is None or lf.get_levels() > rqst['levels']: + if rqst["levels"] is None or lf.get_levels() > rqst["levels"]: # user didn't specify levels, use whatever the LF can handle: - rqst['levels'] = lf.get_levels() + rqst["levels"] = lf.get_levels() - if not getattr(lf, 'supports_tags', False): - rqst['generate_tags'] = False - if not getattr(lf, 'supports_delta', False): - rqst['delta_type'] = None - if not getattr(lf, 'supports_diff', False): - rqst['diff_type'] = None - if not getattr(lf, 'supports_signatures', False): - rqst['signature'] = False + if not getattr(lf, "supports_tags", False): + rqst["generate_tags"] = False + if not getattr(lf, "supports_delta", False): + rqst["delta_type"] = None + if not getattr(lf, "supports_diff", False): + rqst["diff_type"] = None + if not getattr(lf, "supports_signatures", False): + rqst["signature"] = False # Find and print the interesting revisions generator = self._generator_factory(self.branch, rqst) @@ -395,7 +413,8 @@ def _show_body(self, lf): lf.log_revision(lr) except errors.GhostRevisionUnusableHere as e: raise errors.CommandError( - gettext('Further revision history missing.')) from e + gettext("Further revision history missing.") + ) from e lf.show_advice() def _generator_factory(self, branch, rqst): @@ -407,43 +426,71 @@ def _generator_factory(self, branch, rqst): def _log_revision_iterator_using_per_file_graph( - branch, delta_type, match, levels, path, start_rev_id, end_rev_id, - direction, exclude_common_ancestry): + branch, + delta_type, + match, + levels, + path, + start_rev_id, + end_rev_id, + direction, + exclude_common_ancestry, +): # Get the base revisions, filtering by the revision range. # Note that we always generate the merge revisions because # filter_revisions_touching_path() requires them ... 
view_revisions = _calc_view_revisions( - branch, start_rev_id, end_rev_id, - direction, generate_merge_revisions=True, - exclude_common_ancestry=exclude_common_ancestry) + branch, + start_rev_id, + end_rev_id, + direction, + generate_merge_revisions=True, + exclude_common_ancestry=exclude_common_ancestry, + ) if not isinstance(view_revisions, list): view_revisions = list(view_revisions) view_revisions = _filter_revisions_touching_path( - branch, path, view_revisions, - include_merges=levels != 1) - return make_log_rev_iterator( - branch, view_revisions, delta_type, match) + branch, path, view_revisions, include_merges=levels != 1 + ) + return make_log_rev_iterator(branch, view_revisions, delta_type, match) def _log_revision_iterator_using_delta_matching( - branch, delta_type, match, levels, specific_files, start_rev_id, end_rev_id, - direction, exclude_common_ancestry, limit): + branch, + delta_type, + match, + levels, + specific_files, + start_rev_id, + end_rev_id, + direction, + exclude_common_ancestry, + limit, +): # Get the base revisions, filtering by the revision range generate_merge_revisions = levels != 1 delayed_graph_generation = not specific_files and ( - limit or start_rev_id or end_rev_id) + limit or start_rev_id or end_rev_id + ) view_revisions = _calc_view_revisions( - branch, start_rev_id, end_rev_id, + branch, + start_rev_id, + end_rev_id, direction, generate_merge_revisions=generate_merge_revisions, delayed_graph_generation=delayed_graph_generation, - exclude_common_ancestry=exclude_common_ancestry) + exclude_common_ancestry=exclude_common_ancestry, + ) # Apply the other filters - return make_log_rev_iterator(branch, view_revisions, - delta_type, match, - files=specific_files, - direction=direction) + return make_log_rev_iterator( + branch, + view_revisions, + delta_type, + match, + files=specific_files, + direction=direction, + ) def _format_diff(branch, rev, diff_type, files=None): @@ -461,14 +508,21 @@ def _format_diff(branch, rev, diff_type, files=None): ancestor_id = rev.parent_ids[0] tree_1 = repo.revision_tree(ancestor_id) tree_2 = repo.revision_tree(rev.revision_id) - if diff_type == 'partial' and files is not None: + if diff_type == "partial" and files is not None: specific_files = files else: specific_files = None s = BytesIO() path_encoding = get_diff_header_encoding() - diff.show_diff_trees(tree_1, tree_2, s, specific_files, old_label='', - new_label='', path_encoding=path_encoding) + diff.show_diff_trees( + tree_1, + tree_2, + s, + specific_files, + old_label="", + new_label="", + path_encoding=path_encoding, + ) return s.getvalue() @@ -480,12 +534,24 @@ class _DefaultLogGenerator(LogGenerator): """The default generator of log revisions.""" def __init__( - self, branch, levels=None, limit=None, diff_type=None, - delta_type=None, show_signature=None, omit_merges=None, - generate_tags=None, specific_files=None, match=None, - start_revision=None, end_revision=None, direction=None, - exclude_common_ancestry=None, _match_using_deltas=None, - signature=None): + self, + branch, + levels=None, + limit=None, + diff_type=None, + delta_type=None, + show_signature=None, + omit_merges=None, + generate_tags=None, + specific_files=None, + match=None, + start_revision=None, + end_revision=None, + direction=None, + exclude_common_ancestry=None, + _match_using_deltas=None, + signature=None, + ): self.branch = branch self.levels = levels self.limit = limit @@ -515,8 +581,11 @@ def iter_log_revisions(self): for revs in revision_iterator: for (rev_id, revno, merge_depth), rev, 
delta in revs: # 0 levels means show everything; merge_depth counts from 0 - if (self.levels != 0 and merge_depth is not None and - merge_depth >= self.levels): + if ( + self.levels != 0 + and merge_depth is not None + and merge_depth >= self.levels + ): continue if self.omit_merges and len(rev.parent_ids) > 1: continue @@ -526,15 +595,21 @@ def iter_log_revisions(self): diff = None else: diff = _format_diff( - self.branch, rev, self.diff_type, - self.specific_files) + self.branch, rev, self.diff_type, self.specific_files + ) if self.show_signature: signature = format_signature_validity(rev_id, self.branch) else: signature = None yield LogRevision( - rev, revno, merge_depth, delta, - self.rev_tag_dict.get(rev_id), diff, signature) + rev, + revno, + merge_depth, + delta, + self.rev_tag_dict.get(rev_id), + diff, + signature, + ) if self.limit: log_count += 1 if log_count >= self.limit: @@ -547,7 +622,8 @@ def _create_log_revision_iterator(self): delta). """ start_rev_id, end_rev_id = _get_revision_limits( - self.branch, self.start_revision, self.end_revision) + self.branch, self.start_revision, self.end_revision + ) if self._match_using_deltas: return _log_revision_iterator_using_delta_matching( self.branch, @@ -555,10 +631,12 @@ def _create_log_revision_iterator(self): match=self.match, levels=self.levels, specific_files=self.specific_files, - start_rev_id=start_rev_id, end_rev_id=end_rev_id, + start_rev_id=start_rev_id, + end_rev_id=end_rev_id, direction=self.direction, exclude_common_ancestry=self.exclude_common_ancestry, - limit=self.limit) + limit=self.limit, + ) else: # We're using the per-file-graph algorithm. This scales really # well but only makes sense if there is a single file and it's @@ -567,67 +645,84 @@ def _create_log_revision_iterator(self): if file_count != 1: raise errors.BzrError( "illegal LogRequest: must match-using-deltas " - "when logging %d files" % file_count) + "when logging %d files" % file_count + ) return _log_revision_iterator_using_per_file_graph( self.branch, delta_type=self.delta_type, match=self.match, levels=self.levels, path=self.specific_files[0], - start_rev_id=start_rev_id, end_rev_id=end_rev_id, + start_rev_id=start_rev_id, + end_rev_id=end_rev_id, direction=self.direction, - exclude_common_ancestry=self.exclude_common_ancestry - ) - - -def _calc_view_revisions(branch, start_rev_id, end_rev_id, direction, - generate_merge_revisions, - delayed_graph_generation=False, - exclude_common_ancestry=False, - ): + exclude_common_ancestry=self.exclude_common_ancestry, + ) + + +def _calc_view_revisions( + branch, + start_rev_id, + end_rev_id, + direction, + generate_merge_revisions, + delayed_graph_generation=False, + exclude_common_ancestry=False, +): """Calculate the revisions to view. :return: An iterator of (revision_id, dotted_revno, merge_depth) tuples OR a list of the same tuples. 
""" - if (exclude_common_ancestry and start_rev_id == end_rev_id): - raise errors.CommandError(gettext( - '--exclude-common-ancestry requires two different revisions')) - if direction not in ('reverse', 'forward'): - raise ValueError(gettext('invalid direction %r') % direction) + if exclude_common_ancestry and start_rev_id == end_rev_id: + raise errors.CommandError( + gettext("--exclude-common-ancestry requires two different revisions") + ) + if direction not in ("reverse", "forward"): + raise ValueError(gettext("invalid direction %r") % direction) br_rev_id = branch.last_revision() if br_rev_id == _mod_revision.NULL_REVISION: return [] - if (end_rev_id and start_rev_id == end_rev_id - and (not generate_merge_revisions - or not _has_merges(branch, end_rev_id))): + if ( + end_rev_id + and start_rev_id == end_rev_id + and (not generate_merge_revisions or not _has_merges(branch, end_rev_id)) + ): # If a single revision is requested, check we can handle it - return _generate_one_revision(branch, end_rev_id, br_rev_id, - branch.revno()) + return _generate_one_revision(branch, end_rev_id, br_rev_id, branch.revno()) if not generate_merge_revisions: try: # If we only want to see linear revisions, we can iterate ... iter_revs = _linear_view_revisions( - branch, start_rev_id, end_rev_id, - exclude_common_ancestry=exclude_common_ancestry) + branch, + start_rev_id, + end_rev_id, + exclude_common_ancestry=exclude_common_ancestry, + ) # If a start limit was given and it's not obviously an # ancestor of the end limit, check it before outputting anything - if (direction == 'forward' - or (start_rev_id and not _is_obvious_ancestor( - branch, start_rev_id, end_rev_id))): + if direction == "forward" or ( + start_rev_id + and not _is_obvious_ancestor(branch, start_rev_id, end_rev_id) + ): iter_revs = list(iter_revs) - if direction == 'forward': + if direction == "forward": iter_revs = reversed(iter_revs) return iter_revs except _StartNotLinearAncestor: # Switch to the slower implementation that may be able to find a # non-obvious ancestor out of the left-hand history. pass - iter_revs = _generate_all_revisions(branch, start_rev_id, end_rev_id, - direction, delayed_graph_generation, - exclude_common_ancestry) - if direction == 'forward': + iter_revs = _generate_all_revisions( + branch, + start_rev_id, + end_rev_id, + direction, + delayed_graph_generation, + exclude_common_ancestry, + ) + if direction == "forward": iter_revs = _rebase_merge_depth(reverse_by_depth(list(iter_revs))) return iter_revs @@ -641,9 +736,14 @@ def _generate_one_revision(branch, rev_id, br_rev_id, br_revno): return [(rev_id, revno_str, 0)] -def _generate_all_revisions(branch, start_rev_id, end_rev_id, direction, - delayed_graph_generation, - exclude_common_ancestry=False): +def _generate_all_revisions( + branch, + start_rev_id, + end_rev_id, + direction, + delayed_graph_generation, + exclude_common_ancestry=False, +): # On large trees, generating the merge graph can take 30-60 seconds # so we delay doing it until a merge is detected, incrementally # returning initial (non-merge) revisions while we can. @@ -655,7 +755,8 @@ def _generate_all_revisions(branch, start_rev_id, end_rev_id, direction, if delayed_graph_generation: try: for rev_id, revno, depth in _linear_view_revisions( - branch, start_rev_id, end_rev_id, exclude_common_ancestry): + branch, start_rev_id, end_rev_id, exclude_common_ancestry + ): if _has_merges(branch, rev_id): # The end_rev_id can be nested down somewhere. We need an # explicit ancestry check. 
There is an ambiguity here as we @@ -667,8 +768,9 @@ def _generate_all_revisions(branch, start_rev_id, end_rev_id, direction, # revisions have _mod_revision.NULL_REVISION as an ancestor # -- vila 20100319 graph = branch.repository.get_graph() - if (start_rev_id is not None - and not graph.is_ancestor(start_rev_id, end_rev_id)): + if start_rev_id is not None and not graph.is_ancestor( + start_rev_id, end_rev_id + ): raise _StartNotLinearAncestor() # Since we collected the revisions so far, we need to # adjust end_rev_id. @@ -682,8 +784,9 @@ def _generate_all_revisions(branch, start_rev_id, end_rev_id, direction, except _StartNotLinearAncestor as e: # A merge was never detected so the lower revision limit can't # be nested down somewhere - raise errors.CommandError(gettext('Start revision not found in' - ' history of end revision.')) from e + raise errors.CommandError( + gettext("Start revision not found in" " history of end revision.") + ) from e # We exit the loop above because we encounter a revision with merges, from # this revision, we need to switch to _graph_view_revisions. @@ -693,11 +796,16 @@ def _generate_all_revisions(branch, start_rev_id, end_rev_id, direction, # shown naturally, i.e. just like it is for linear logging. We can easily # make forward the exact opposite display, but showing the merge revisions # indented at the end seems slightly nicer in that case. - view_revisions = itertools.chain(iter(initial_revisions), - _graph_view_revisions(branch, start_rev_id, end_rev_id, - rebase_initial_depths=( - direction == 'reverse'), - exclude_common_ancestry=exclude_common_ancestry)) + view_revisions = itertools.chain( + iter(initial_revisions), + _graph_view_revisions( + branch, + start_rev_id, + end_rev_id, + rebase_initial_depths=(direction == "reverse"), + exclude_common_ancestry=exclude_common_ancestry, + ), + ) return view_revisions @@ -719,7 +827,7 @@ def _compute_revno_str(branch, rev_id): # The revision must be outside of this branch return None else: - return '.'.join(str(n) for n in revno) + return ".".join(str(n) for n in revno) def _is_obvious_ancestor(branch, start_rev_id, end_rev_id): @@ -734,8 +842,11 @@ def _is_obvious_ancestor(branch, start_rev_id, end_rev_id): if len(start_dotted) == 1 and len(end_dotted) == 1: # both on mainline return start_dotted[0] <= end_dotted[0] - elif (len(start_dotted) == 3 and len(end_dotted) == 3 and - start_dotted[0:1] == end_dotted[0:1]): + elif ( + len(start_dotted) == 3 + and len(end_dotted) == 3 + and start_dotted[0:1] == end_dotted[0:1] + ): # both on same development line return start_dotted[2] <= end_dotted[2] else: @@ -746,8 +857,9 @@ def _is_obvious_ancestor(branch, start_rev_id, end_rev_id): return True -def _linear_view_revisions(branch, start_rev_id, end_rev_id, - exclude_common_ancestry=False): +def _linear_view_revisions( + branch, start_rev_id, end_rev_id, exclude_common_ancestry=False +): """Calculate a sequence of revisions to view, newest to oldest. 
:param start_rev_id: the lower revision-id @@ -762,8 +874,9 @@ def _linear_view_revisions(branch, start_rev_id, end_rev_id, repo = branch.repository graph = repo.get_graph() if start_rev_id is None and end_rev_id is None: - if branch._format.stores_revno() or \ - config.GlobalStack().get('calculate_revnos'): + if branch._format.stores_revno() or config.GlobalStack().get( + "calculate_revnos" + ): try: br_revno, br_rev_id = branch.last_revision_info() except errors.GhostRevisionsHaveNoRevno: @@ -775,8 +888,9 @@ def _linear_view_revisions(branch, start_rev_id, end_rev_id, br_rev_id = branch.last_revision() cur_revno = None - graph_iter = graph.iter_lefthand_ancestry(br_rev_id, - (_mod_revision.NULL_REVISION,)) + graph_iter = graph.iter_lefthand_ancestry( + br_rev_id, (_mod_revision.NULL_REVISION,) + ) while True: try: revision_id = next(graph_iter) @@ -795,8 +909,9 @@ def _linear_view_revisions(branch, start_rev_id, end_rev_id, if end_rev_id is None: end_rev_id = br_rev_id found_start = start_rev_id is None - graph_iter = graph.iter_lefthand_ancestry(end_rev_id, - (_mod_revision.NULL_REVISION,)) + graph_iter = graph.iter_lefthand_ancestry( + end_rev_id, (_mod_revision.NULL_REVISION,) + ) while True: try: revision_id = next(graph_iter) @@ -819,9 +934,13 @@ def _linear_view_revisions(branch, start_rev_id, end_rev_id, raise _StartNotLinearAncestor() -def _graph_view_revisions(branch, start_rev_id, end_rev_id, - rebase_initial_depths=True, - exclude_common_ancestry=False): +def _graph_view_revisions( + branch, + start_rev_id, + end_rev_id, + rebase_initial_depths=True, + exclude_common_ancestry=False, +): """Calculate revisions to view including merges, newest to oldest. :param branch: the branch @@ -832,24 +951,22 @@ def _graph_view_revisions(branch, start_rev_id, end_rev_id, :return: An iterator of (revision_id, dotted_revno, merge_depth) tuples. """ if exclude_common_ancestry: - stop_rule = 'with-merges-without-common-ancestry' + stop_rule = "with-merges-without-common-ancestry" else: - stop_rule = 'with-merges' + stop_rule = "with-merges" view_revisions = branch.iter_merge_sorted_revisions( - start_revision_id=end_rev_id, stop_revision_id=start_rev_id, - stop_rule=stop_rule) + start_revision_id=end_rev_id, stop_revision_id=start_rev_id, stop_rule=stop_rule + ) if not rebase_initial_depths: - for (rev_id, merge_depth, revno, _end_of_merge - ) in view_revisions: - yield rev_id, '.'.join(map(str, revno)), merge_depth + for rev_id, merge_depth, revno, _end_of_merge in view_revisions: + yield rev_id, ".".join(map(str, revno)), merge_depth else: # We're following a development line starting at a merged revision. # We need to adjust depths down by the initial depth until we find # a depth less than it. Then we use that depth as the adjustment. # If and when we reach the mainline, depth adjustment ends. depth_adjustment = None - for (rev_id, merge_depth, revno, _end_of_merge - ) in view_revisions: + for rev_id, merge_depth, revno, _end_of_merge in view_revisions: if depth_adjustment is None: depth_adjustment = merge_depth if depth_adjustment: @@ -860,7 +977,7 @@ def _graph_view_revisions(branch, start_rev_id, end_rev_id, # though. 
depth_adjustment = merge_depth merge_depth -= depth_adjustment - yield rev_id, '.'.join(map(str, revno)), merge_depth + yield rev_id, ".".join(map(str, revno)), merge_depth def _rebase_merge_depth(view_revisions): @@ -869,13 +986,13 @@ def _rebase_merge_depth(view_revisions): if view_revisions and view_revisions[0][2] and view_revisions[-1][2]: min_depth = min([d for r, n, d in view_revisions]) if min_depth != 0: - view_revisions = [(r, n, d - min_depth) - for r, n, d in view_revisions] + view_revisions = [(r, n, d - min_depth) for r, n, d in view_revisions] return view_revisions -def make_log_rev_iterator(branch, view_revisions, generate_delta, search, - files=None, direction='reverse'): +def make_log_rev_iterator( + branch, view_revisions, generate_delta, search, files=None, direction="reverse" +): """Create a revision iterator for log. :param branch: The branch being logged. @@ -897,20 +1014,21 @@ def make_log_rev_iterator(branch, view_revisions, generate_delta, search, nones = [None] * len(view_revisions) log_rev_iterator = iter([list(zip(view_revisions, nones, nones))]) else: + def _convert(): for view in view_revisions: yield (view, None, None) + log_rev_iterator = iter([_convert()]) for adapter in log_adapters: # It would be nicer if log adapters were first class objects # with custom parameters. This will do for now. IGC 20090127 if adapter == _make_delta_filter: log_rev_iterator = adapter( - branch, generate_delta, search, log_rev_iterator, files, - direction) + branch, generate_delta, search, log_rev_iterator, files, direction + ) else: - log_rev_iterator = adapter( - branch, generate_delta, search, log_rev_iterator) + log_rev_iterator = adapter(branch, generate_delta, search, log_rev_iterator) return log_rev_iterator @@ -931,8 +1049,10 @@ def _make_search_filter(branch, generate_delta, match, log_rev_iterator): if not match: return log_rev_iterator # Use lazy_compile so mapping to InvalidPattern error occurs. - searchRE = [(k, [lazy_regex.lazy_compile(x, re.IGNORECASE) for x in v]) - for k, v in match.items()] + searchRE = [ + (k, [lazy_regex.lazy_compile(x, re.IGNORECASE) for x in v]) + for k, v in match.items() + ] return _filter_re(searchRE, log_rev_iterator) @@ -945,13 +1065,12 @@ def _filter_re(search_re, log_rev_iterator): def _match_filter(search_re, rev): strings = { - 'message': (rev.message,), - 'committer': (rev.committer,), - 'author': (rev.get_apparent_authors()), - 'bugs': list(_mod_revision.iter_bugs(rev)) - } - strings[''] = [item for inner_list in strings.values() - for item in inner_list] + "message": (rev.message,), + "committer": (rev.committer,), + "author": (rev.get_apparent_authors()), + "bugs": list(_mod_revision.iter_bugs(rev)), + } + strings[""] = [item for inner_list in strings.values() for item in inner_list] for k, v in search_re: if k in strings and not _match_any_filter(strings[k], v): return False @@ -962,8 +1081,9 @@ def _match_any_filter(strings, res): return any(r.search(s) for r in res for s in strings) -def _make_delta_filter(branch, generate_delta, search, log_rev_iterator, - files=None, direction='reverse'): +def _make_delta_filter( + branch, generate_delta, search, log_rev_iterator, files=None, direction="reverse" +): """Add revision deltas to a log iterator if needed. :param branch: The branch being logged. 
@@ -980,12 +1100,12 @@ def _make_delta_filter(branch, generate_delta, search, log_rev_iterator, """ if not generate_delta and not files: return log_rev_iterator - return _generate_deltas(branch.repository, log_rev_iterator, - generate_delta, files, direction) + return _generate_deltas( + branch.repository, log_rev_iterator, generate_delta, files, direction + ) -def _generate_deltas(repository, log_rev_iterator, delta_type, files, - direction): +def _generate_deltas(repository, log_rev_iterator, delta_type, files, direction): """Create deltas for each batch of revisions in log_rev_iterator. If we're only generating deltas for the sake of filtering against @@ -996,10 +1116,10 @@ def _generate_deltas(repository, log_rev_iterator, delta_type, files, check_files = files is not None and len(files) > 0 if check_files: file_set = set(files) - if direction == 'reverse': - stop_on = 'add' + if direction == "reverse": + stop_on = "add" else: - stop_on = 'remove' + stop_on = "remove" else: file_set = None for revs in log_rev_iterator: @@ -1009,13 +1129,12 @@ def _generate_deltas(repository, log_rev_iterator, delta_type, files, return revisions = [rev[1] for rev in revs] new_revs = [] - if delta_type == 'full' and not check_files: + if delta_type == "full" and not check_files: deltas = repository.get_revision_deltas(revisions) for rev, delta in zip(revs, deltas): new_revs.append((rev[0], rev[1], delta)) else: - deltas = repository.get_revision_deltas( - revisions, specific_files=file_set) + deltas = repository.get_revision_deltas(revisions, specific_files=file_set) for rev, delta in zip(revs, deltas): if check_files: if delta is None or not delta.has_changed(): @@ -1024,7 +1143,7 @@ def _generate_deltas(repository, log_rev_iterator, delta_type, files, _update_files(delta, file_set, stop_on) if delta_type is None: delta = None - elif delta_type == 'full': + elif delta_type == "full": # If the file matches all the time, rebuilding # a full delta like this in addition to a partial # one could be slow. 
However, it's likely that @@ -1045,7 +1164,7 @@ def _update_files(delta, files, stop_on): :param stop_on: either 'add' or 'remove' - take files out of the files set once their add or remove entry is detected respectively """ - if stop_on == 'add': + if stop_on == "add": for item in delta.added: if item.path[1] in files: files.remove(item.path[1]) @@ -1053,12 +1172,12 @@ def _update_files(delta, files, stop_on): if item.path[1] in files: files.remove(item.path[1]) files.add(item.path[0]) - if item.kind[1] == 'directory': + if item.kind[1] == "directory": for path in list(files): if is_inside(item.path[1], path): files.remove(path) - files.add(item.path[0] + path[len(item.path[1]):]) - elif stop_on == 'delete': + files.add(item.path[0] + path[len(item.path[1]) :]) + elif stop_on == "delete": for item in delta.removed: if item.path[0] in files: files.remove(item.path[0]) @@ -1066,11 +1185,11 @@ def _update_files(delta, files, stop_on): if item.path[0] in files: files.remove(item.path[0]) files.add(item.path[1]) - if item.kind[0] == 'directory': + if item.kind[0] == "directory": for path in list(files): if is_inside(item.path[0], path): files.remove(path) - files.add(item.path[1] + path[len(item.path[0]):]) + files.add(item.path[1] + path[len(item.path[0]) :]) def _make_revision_objects(branch, generate_delta, search, log_rev_iterator): @@ -1144,13 +1263,15 @@ def _get_revision_limits(branch, start_revision, end_revision): end_revno = end_revision.revno if branch.last_revision() != _mod_revision.NULL_REVISION: - if (start_rev_id == _mod_revision.NULL_REVISION - or end_rev_id == _mod_revision.NULL_REVISION): - raise errors.CommandError( - gettext('Logging revision 0 is invalid.')) + if ( + start_rev_id == _mod_revision.NULL_REVISION + or end_rev_id == _mod_revision.NULL_REVISION + ): + raise errors.CommandError(gettext("Logging revision 0 is invalid.")) if end_revno is not None and start_revno > end_revno: raise errors.CommandError( - gettext("Start revision must be older than the end revision.")) + gettext("Start revision must be older than the end revision.") + ) return (start_rev_id, end_rev_id) @@ -1203,12 +1324,14 @@ def _get_mainline_revs(branch, start_revision, end_revision): branch.check_real_revno(end_revision) end_revno = end_revision - if ((start_rev_id == _mod_revision.NULL_REVISION) - or (end_rev_id == _mod_revision.NULL_REVISION)): - raise errors.CommandError(gettext('Logging revision 0 is invalid.')) + if (start_rev_id == _mod_revision.NULL_REVISION) or ( + end_rev_id == _mod_revision.NULL_REVISION + ): + raise errors.CommandError(gettext("Logging revision 0 is invalid.")) if start_revno > end_revno: - raise errors.CommandError(gettext("Start revision must be older " - "than the end revision.")) + raise errors.CommandError( + gettext("Start revision must be older " "than the end revision.") + ) if end_revno < start_revno: return None, None, None, None @@ -1217,7 +1340,8 @@ def _get_mainline_revs(branch, start_revision, end_revision): mainline_revs = [] graph = branch.repository.get_graph() for revision_id in graph.iter_lefthand_ancestry( - branch_last_revision, (_mod_revision.NULL_REVISION,)): + branch_last_revision, (_mod_revision.NULL_REVISION,) + ): if cur_revno < start_revno: # We have gone far enough, but we always add 1 more revision rev_nos[revision_id] = cur_revno @@ -1237,8 +1361,7 @@ def _get_mainline_revs(branch, start_revision, end_revision): return mainline_revs, rev_nos, start_rev_id, end_rev_id -def _filter_revisions_touching_path(branch, path, view_revisions, - 
include_merges=True): +def _filter_revisions_touching_path(branch, path, view_revisions, include_merges=True): r"""Return the list of revision ids which touch a given path. The function filters view_revisions and returns a subset. @@ -1294,10 +1417,9 @@ def _filter_revisions_touching_path(branch, path, view_revisions, modified_text_revisions = set() chunk_size = 1000 for start in range(0, len(text_keys), chunk_size): - next_keys = text_keys[start:start + chunk_size] + next_keys = text_keys[start : start + chunk_size] # Only keep the revision_id portion of the key - modified_text_revisions.update( - [k[1] for k in get_parent_map(next_keys)]) + modified_text_revisions.update([k[1] for k in get_parent_map(next_keys)]) del text_keys, next_keys result = [] @@ -1309,7 +1431,7 @@ def _filter_revisions_touching_path(branch, path, view_revisions, if depth == len(current_merge_stack): current_merge_stack.append(info) else: - del current_merge_stack[depth + 1:] + del current_merge_stack[depth + 1 :] current_merge_stack[-1] = info if rev_id in modified_text_revisions: @@ -1362,8 +1484,16 @@ class LogRevision: logging options and the log formatter capabilities. """ - def __init__(self, rev=None, revno=None, merge_depth=0, delta=None, - tags=None, diff=None, signature=None): + def __init__( + self, + rev=None, + revno=None, + merge_depth=0, + delta=None, + tags=None, + diff=None, + signature=None, + ): self.rev = rev if revno is None: self.revno = None @@ -1417,11 +1547,20 @@ def my_show_properties(properties_dict): # code that returns a dict {'name':'value'} of the properties # to be shown """ + preferred_levels = 0 - def __init__(self, to_file, show_ids=False, show_timezone='original', - delta_format=None, levels=None, show_advice=False, - to_exact_file=None, author_list_handler=None): + def __init__( + self, + to_file, + show_ids=False, + show_timezone="original", + delta_format=None, + levels=None, + show_advice=False, + to_exact_file=None, + author_list_handler=None, + ): """Create a LogFormatter. :param to_file: the file to output to @@ -1448,7 +1587,7 @@ def __init__(self, to_file, show_ids=False, show_timezone='original', # XXX: somewhat hacky; this assumes it's a codec writer; it's # better for code that expects to get diffs to pass in the exact # file stream - self.to_exact_file = getattr(to_file, 'stream', to_file) + self.to_exact_file = getattr(to_file, "stream", to_file) self.show_ids = show_ids self.show_timezone = show_timezone if delta_format is None: @@ -1462,7 +1601,7 @@ def __init__(self, to_file, show_ids=False, show_timezone='original', def get_levels(self): """Get the number of levels to display or 0 for all.""" - if getattr(self, 'supports_merge_revisions', False): + if getattr(self, "supports_merge_revisions", False): if self.levels is None or self.levels == -1: self.levels = self.preferred_levels else: @@ -1474,7 +1613,7 @@ def log_revision(self, revision): :param revision: The LogRevision to be logged. 
""" - raise NotImplementedError('not implemented in abstract base') + raise NotImplementedError("not implemented in abstract base") def show_advice(self): """Output user advice, if any, when the log is completed.""" @@ -1482,12 +1621,11 @@ def show_advice(self): advice_sep = self.get_advice_separator() if advice_sep: self.to_file.write(advice_sep) - self.to_file.write( - "Use --include-merged or -n0 to see merged revisions.\n") + self.to_file.write("Use --include-merged or -n0 to see merged revisions.\n") def get_advice_separator(self): """Get the text separating the log from the closing advice.""" - return '' + return "" def short_committer(self, rev): name, address = config.parse_username(rev.committer) @@ -1496,7 +1634,7 @@ def short_committer(self, rev): return address def short_author(self, rev): - return self.authors(rev, 'first', short=True, sep=', ') + return self.authors(rev, "first", short=True, sep=", ") def authors(self, rev, who, short=False, sep=None): """Generate list of authors, taking --authors option into account. @@ -1536,9 +1674,9 @@ def merge_marker(self, revision): """Get the merge marker to include in the output or '' if none.""" if len(revision.rev.parent_ids) > 1: self._merge_count += 1 - return ' [merge]' + return " [merge]" else: - return '' + return "" def show_properties(self, revision, indent): """Displays the custom properties returned by each registered handler. @@ -1572,7 +1710,8 @@ def _foreign_info_properties(self, rev): # Revision comes directly from a foreign repository if isinstance(rev, foreign.ForeignRevision): return self._format_properties( - rev.mapping.vcs.show_foreign_revid(rev.foreign_revid)) + rev.mapping.vcs.show_foreign_revid(rev.foreign_revid) + ) # Imported foreign revision revision ids always contain : if b":" not in rev.revision_id: @@ -1580,32 +1719,31 @@ def _foreign_info_properties(self, rev): # Revision was once imported from a foreign repository try: - foreign_revid, mapping = \ - foreign.foreign_vcs_registry.parse_revision_id(rev.revision_id) + foreign_revid, mapping = foreign.foreign_vcs_registry.parse_revision_id( + rev.revision_id + ) except errors.InvalidRevisionId: return [] - return self._format_properties( - mapping.vcs.show_foreign_revid(foreign_revid)) + return self._format_properties(mapping.vcs.show_foreign_revid(foreign_revid)) def _format_properties(self, properties): lines = [] for key, value in properties.items(): - lines.append(key + ': ' + value) + lines.append(key + ": " + value) return lines def show_diff(self, to_file, diff, indent): encoding = get_terminal_encoding() - for l in diff.rstrip().split(b'\n'): - to_file.write(indent + l.decode(encoding, 'ignore') + '\n') + for l in diff.rstrip().split(b"\n"): + to_file.write(indent + l.decode(encoding, "ignore") + "\n") # Separator between revisions in long format -_LONG_SEP = '-' * 60 +_LONG_SEP = "-" * 60 class LongLogFormatter(LogFormatter): - supports_merge_revisions = True preferred_levels = 1 supports_delta = True @@ -1615,28 +1753,34 @@ class LongLogFormatter(LogFormatter): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if self.show_timezone == 'original': + if self.show_timezone == "original": self.date_string = self._date_string_original_timezone else: self.date_string = self._date_string_with_timezone def _date_string_with_timezone(self, rev): try: - return format_date(rev.timestamp, rev.timezone or 0, - timezone=self.show_timezone) + return format_date( + rev.timestamp, rev.timezone or 0, timezone=self.show_timezone + ) except 
UnsupportedTimezoneFormat as e: - raise errors.CommandError(gettext('Unsupported timezone format "{}", options are "utc", "original", "local".').format(self.show_timezone)) from e + raise errors.CommandError( + gettext( + 'Unsupported timezone format "{}", options are "utc", "original", "local".' + ).format(self.show_timezone) + ) from e def _date_string_original_timezone(self, rev): - return format_date_with_offset_in_original_timezone(rev.timestamp, - rev.timezone or 0) + return format_date_with_offset_in_original_timezone( + rev.timestamp, rev.timezone or 0 + ) def log_revision(self, revision): """Log a revision, either merged or not.""" - indent = ' ' * revision.merge_depth + indent = " " * revision.merge_depth lines = [_LONG_SEP] if revision.revno is not None: - lines.append(f'revno: {revision.revno}{self.merge_marker(revision)}') + lines.append(f"revno: {revision.revno}{self.merge_marker(revision)}") if revision.tags: lines.append(f"tags: {', '.join(sorted(revision.tags))}") if self.show_ids or revision.revno is None: @@ -1647,41 +1791,51 @@ def log_revision(self, revision): lines.extend(self.custom_properties(revision.rev)) committer = revision.rev.committer - authors = self.authors(revision.rev, 'all') + authors = self.authors(revision.rev, "all") if authors != [committer]: lines.append(f"author: {', '.join(authors)}") - lines.append(f'committer: {committer}') + lines.append(f"committer: {committer}") - branch_nick = revision.rev.properties.get('branch-nick', None) + branch_nick = revision.rev.properties.get("branch-nick", None) if branch_nick is not None: - lines.append(f'branch nick: {branch_nick}') + lines.append(f"branch nick: {branch_nick}") try: - lines.append(f'timestamp: {self.date_string(revision.rev)}') + lines.append(f"timestamp: {self.date_string(revision.rev)}") except UnsupportedTimezoneFormat as e: - raise errors.CommandError(gettext('Unsupported timezone format "{}", options are "utc", "original", "local".').format(self.show_timezone)) from e + raise errors.CommandError( + gettext( + 'Unsupported timezone format "{}", options are "utc", "original", "local".' 
+ ).format(self.show_timezone) + ) from e if revision.signature is not None: - lines.append('signature: ' + revision.signature) + lines.append("signature: " + revision.signature) - lines.append('message:') + lines.append("message:") if not revision.rev.message: - lines.append(' (no message)') + lines.append(" (no message)") else: - message = revision.rev.message.rstrip('\r\n') - for l in message.split('\n'): - lines.append(f' {l}') + message = revision.rev.message.rstrip("\r\n") + for l in message.split("\n"): + lines.append(f" {l}") # Dump the output, appending the delta and diff if requested to_file = self.to_file - to_file.write("{}{}\n".format(indent, ('\n' + indent).join(lines))) + to_file.write("{}{}\n".format(indent, ("\n" + indent).join(lines))) if revision.delta is not None: # Use the standard status output to display changes from .delta import report_delta - report_delta(to_file, revision.delta, short_status=False, - show_ids=self.show_ids, indent=indent) + + report_delta( + to_file, + revision.delta, + short_status=False, + show_ids=self.show_ids, + indent=indent, + ) if revision.diff is not None: - to_file.write(indent + 'diff:\n') + to_file.write(indent + "diff:\n") to_file.flush() # Note: we explicitly don't indent the diff (relative to the # revision information) so that the output can be fed to patch -p0 @@ -1690,11 +1844,10 @@ def log_revision(self, revision): def get_advice_separator(self): """Get the text separating the log from the closing advice.""" - return '-' * 60 + '\n' + return "-" * 60 + "\n" class ShortLogFormatter(LogFormatter): - supports_merge_revisions = True preferred_levels = 1 supports_delta = True @@ -1713,53 +1866,71 @@ def log_revision(self, revision): # as we might be starting from a dotted revno in the first column # and we want subsequent mainline revisions to line up. depth = revision.merge_depth - indent = ' ' * depth + indent = " " * depth revno_width = self.revno_width_by_depth.get(depth) if revno_width is None: - if revision.revno is None or revision.revno.find('.') == -1: + if revision.revno is None or revision.revno.find(".") == -1: # mainline revno, e.g. 12345 revno_width = 5 else: # dotted revno, e.g. 
12345.10.55 revno_width = 11 self.revno_width_by_depth[depth] = revno_width - offset = ' ' * (revno_width + 1) + offset = " " * (revno_width + 1) to_file = self.to_file - tags = '' + tags = "" if revision.tags: - tags = ' {%s}' % (', '.join(sorted(revision.tags))) - to_file.write(indent + "%*s %s\t%s%s%s\n" % (revno_width, - revision.revno or "", self.short_author( - revision.rev), - format_date(revision.rev.timestamp, - revision.rev.timezone or 0, - self.show_timezone, date_fmt="%Y-%m-%d", - show_offset=False), - tags, self.merge_marker(revision))) + tags = " {%s}" % (", ".join(sorted(revision.tags))) + to_file.write( + indent + + "%*s %s\t%s%s%s\n" + % ( + revno_width, + revision.revno or "", + self.short_author(revision.rev), + format_date( + revision.rev.timestamp, + revision.rev.timezone or 0, + self.show_timezone, + date_fmt="%Y-%m-%d", + show_offset=False, + ), + tags, + self.merge_marker(revision), + ) + ) self.show_properties(revision.rev, indent + offset) if self.show_ids or revision.revno is None: - to_file.write(indent + offset + 'revision-id:{}\n'.format(revision.rev.revision_id.decode('utf-8'))) + to_file.write( + indent + + offset + + "revision-id:{}\n".format(revision.rev.revision_id.decode("utf-8")) + ) if not revision.rev.message: - to_file.write(indent + offset + '(no message)\n') + to_file.write(indent + offset + "(no message)\n") else: - message = revision.rev.message.rstrip('\r\n') - for l in message.split('\n'): - to_file.write(indent + offset + f'{l}\n') + message = revision.rev.message.rstrip("\r\n") + for l in message.split("\n"): + to_file.write(indent + offset + f"{l}\n") if revision.delta is not None: # Use the standard status output to display changes from .delta import report_delta - report_delta(to_file, revision.delta, - short_status=self.delta_format == 1, - show_ids=self.show_ids, indent=indent + offset) + + report_delta( + to_file, + revision.delta, + short_status=self.delta_format == 1, + show_ids=self.show_ids, + indent=indent + offset, + ) if revision.diff is not None: - self.show_diff(self.to_exact_file, revision.diff, ' ') - to_file.write('\n') + self.show_diff(self.to_exact_file, revision.diff, " ") + to_file.write("\n") class LineLogFormatter(LogFormatter): - supports_merge_revisions = True preferred_levels = 1 supports_tags = True @@ -1775,26 +1946,33 @@ def __init__(self, *args, **kwargs): def truncate(self, str, max_len): if max_len is None or len(str) <= max_len: return str - return str[:max_len - 3] + '...' + return str[: max_len - 3] + "..." def date_string(self, rev): - return format_date(rev.timestamp, rev.timezone or 0, - self.show_timezone, date_fmt="%Y-%m-%d", - show_offset=False) + return format_date( + rev.timestamp, + rev.timezone or 0, + self.show_timezone, + date_fmt="%Y-%m-%d", + show_offset=False, + ) def message(self, rev): if not rev.message: - return '(no message)' + return "(no message)" else: return rev.message def log_revision(self, revision): - indent = ' ' * revision.merge_depth - self.to_file.write(self.log_string(revision.revno, revision.rev, - self._max_chars, revision.tags, indent)) - self.to_file.write('\n') - - def log_string(self, revno, rev, max_chars, tags=None, prefix=''): + indent = " " * revision.merge_depth + self.to_file.write( + self.log_string( + revision.revno, revision.rev, self._max_chars, revision.tags, indent + ) + ) + self.to_file.write("\n") + + def log_string(self, revno, rev, max_chars, tags=None, prefix=""): """Format log info into one string. Truncate tail of string. 
:param revno: revision number or None. @@ -1810,22 +1988,20 @@ def log_string(self, revno, rev, max_chars, tags=None, prefix=''): # show revno only when is not None out.append(f"{revno}:") if max_chars is not None: - out.append(self.truncate( - self.short_author(rev), (max_chars + 3) // 4)) + out.append(self.truncate(self.short_author(rev), (max_chars + 3) // 4)) else: out.append(self.short_author(rev)) out.append(self.date_string(rev)) if len(rev.parent_ids) > 1: - out.append('[merge]') + out.append("[merge]") if tags: - tag_str = '{%s}' % (', '.join(sorted(tags))) + tag_str = "{%s}" % (", ".join(sorted(tags))) out.append(tag_str) out.append(rev.get_summary()) - return self.truncate(prefix + " ".join(out).rstrip('\n'), max_chars) + return self.truncate(prefix + " ".join(out).rstrip("\n"), max_chars) class GnuChangelogLogFormatter(LogFormatter): - supports_merge_revisions = True supports_delta = True @@ -1833,34 +2009,38 @@ def log_revision(self, revision): """Log a revision, either merged or not.""" to_file = self.to_file - date_str = format_date(revision.rev.timestamp, - revision.rev.timezone or 0, - self.show_timezone, - date_fmt='%Y-%m-%d', - show_offset=False) - committer_str = self.authors(revision.rev, 'first', sep=', ') - committer_str = committer_str.replace(' <', ' <') - to_file.write(f'{date_str} {committer_str}\n\n') + date_str = format_date( + revision.rev.timestamp, + revision.rev.timezone or 0, + self.show_timezone, + date_fmt="%Y-%m-%d", + show_offset=False, + ) + committer_str = self.authors(revision.rev, "first", sep=", ") + committer_str = committer_str.replace(" <", " <") + to_file.write(f"{date_str} {committer_str}\n\n") if revision.delta is not None and revision.delta.has_changed(): - for c in revision.delta.added + revision.delta.removed + revision.delta.modified: + for c in ( + revision.delta.added + revision.delta.removed + revision.delta.modified + ): if c.path[0] is None: path = c.path[1] else: path = c.path[0] - to_file.write(f'\t* {path}:\n') + to_file.write(f"\t* {path}:\n") for c in revision.delta.renamed + revision.delta.copied: # For renamed files, show both the old and the new path - to_file.write(f'\t* {c.path[0]}:\n\t* {c.path[1]}:\n') - to_file.write('\n') + to_file.write(f"\t* {c.path[0]}:\n\t* {c.path[1]}:\n") + to_file.write("\n") if not revision.rev.message: - to_file.write('\tNo commit message\n') + to_file.write("\tNo commit message\n") else: - message = revision.rev.message.rstrip('\r\n') - for l in message.split('\n'): - to_file.write(f'\t{l.lstrip()}\n') - to_file.write('\n') + message = revision.rev.message.rstrip("\r\n") + for l in message.split("\n"): + to_file.write(f"\t{l.lstrip()}\n") + to_file.write("\n") def line_log(rev, max_chars): @@ -1881,20 +2061,22 @@ def make_formatter(self, name, *args, **kwargs): def get_default(self, branch): c = branch.get_config_stack() - return self.get(c.get('log_format')) + return self.get(c.get("log_format")) log_formatter_registry = LogFormatterRegistry() -log_formatter_registry.register('short', ShortLogFormatter, - 'Moderately short log format.') -log_formatter_registry.register('long', LongLogFormatter, - 'Detailed log format.') -log_formatter_registry.register('line', LineLogFormatter, - 'Log format with one line per revision.') -log_formatter_registry.register('gnu-changelog', GnuChangelogLogFormatter, - 'Format used by GNU ChangeLog files.') +log_formatter_registry.register( + "short", ShortLogFormatter, "Moderately short log format." 
+) +log_formatter_registry.register("long", LongLogFormatter, "Detailed log format.") +log_formatter_registry.register( + "line", LineLogFormatter, "Log format with one line per revision." +) +log_formatter_registry.register( + "gnu-changelog", GnuChangelogLogFormatter, "Format used by GNU ChangeLog files." +) def register_formatter(name, formatter): @@ -1910,8 +2092,7 @@ def log_formatter(name, *args, **kwargs): try: return log_formatter_registry.make_formatter(name, *args, **kwargs) except KeyError as e: - raise errors.CommandError( - gettext("unknown log formatter: %r") % name) from e + raise errors.CommandError(gettext("unknown log formatter: %r") % name) from e def author_list_all(rev): @@ -1930,20 +2111,18 @@ def author_list_committer(rev): return [rev.committer] -author_list_registry = registry.Registry[str, Callable[[_mod_revision.Revision], List[str]], None]() +author_list_registry = registry.Registry[ + str, Callable[[_mod_revision.Revision], List[str]], None +]() -author_list_registry.register('all', author_list_all, - 'All authors') +author_list_registry.register("all", author_list_all, "All authors") -author_list_registry.register('first', author_list_first, - 'The first author') +author_list_registry.register("first", author_list_first, "The first author") -author_list_registry.register('committer', author_list_committer, - 'The committer') +author_list_registry.register("committer", author_list_committer, "The committer") -def show_changed_revisions(branch, old_rh, new_rh, to_file=None, - log_format='long'): +def show_changed_revisions(branch, old_rh, new_rh, to_file=None, log_format="long"): """Show the change in revision history comparing the old revision history to the new one. :param branch: The branch where the revisions exist @@ -1952,46 +2131,46 @@ def show_changed_revisions(branch, old_rh, new_rh, to_file=None, :param to_file: A file to write the results to. If None, stdout will be used """ if to_file is None: - to_file = codecs.getwriter(get_terminal_encoding())(sys.stdout, - errors='replace') - lf = log_formatter(log_format, - show_ids=False, - to_file=to_file, - show_timezone='original') + to_file = codecs.getwriter(get_terminal_encoding())( + sys.stdout, errors="replace" + ) + lf = log_formatter( + log_format, show_ids=False, to_file=to_file, show_timezone="original" + ) # This is the first index which is different between # old and new base_idx = None for i in range(max(len(new_rh), len(old_rh))): - if (len(new_rh) <= i - or len(old_rh) <= i - or new_rh[i] != old_rh[i]): + if len(new_rh) <= i or len(old_rh) <= i or new_rh[i] != old_rh[i]: base_idx = i break if base_idx is None: - to_file.write('Nothing seems to have changed\n') + to_file.write("Nothing seems to have changed\n") return # TODO: It might be nice to do something like show_log # and show the merged entries. 
But since this is the # removed revisions, it shouldn't be as important if base_idx < len(old_rh): - to_file.write('*' * 60) - to_file.write('\nRemoved Revisions:\n') + to_file.write("*" * 60) + to_file.write("\nRemoved Revisions:\n") for i in range(base_idx, len(old_rh)): rev = branch.repository.get_revision(old_rh[i]) lr = LogRevision(rev, i + 1, 0, None) lf.log_revision(lr) - to_file.write('*' * 60) - to_file.write('\n\n') + to_file.write("*" * 60) + to_file.write("\n\n") if base_idx < len(new_rh): - to_file.write('Added Revisions:\n') - show_log(branch, - lf, - verbose=False, - direction='forward', - start_revision=base_idx + 1, - end_revision=len(new_rh)) + to_file.write("Added Revisions:\n") + show_log( + branch, + lf, + verbose=False, + direction="forward", + start_revision=base_idx + 1, + end_revision=len(new_rh), + ) def get_history_change(old_revision_id, new_revision_id, repository): @@ -2039,8 +2218,8 @@ def get_history_change(old_revision_id, new_revision_id, repository): new_history.reverse() old_history.reverse() if stop_revision is not None: - new_history = new_history[new_history.index(stop_revision) + 1:] - old_history = old_history[old_history.index(stop_revision) + 1:] + new_history = new_history[new_history.index(stop_revision) + 1 :] + old_history = old_history[old_history.index(stop_revision) + 1 :] return old_history, new_history @@ -2053,26 +2232,27 @@ def show_branch_change(branch, output, old_revno, old_revision_id): :param old_revision_id: The revision_id of the old tip. """ new_revno, new_revision_id = branch.last_revision_info() - old_history, new_history = get_history_change(old_revision_id, - new_revision_id, - branch.repository) + old_history, new_history = get_history_change( + old_revision_id, new_revision_id, branch.repository + ) if old_history == [] and new_history == []: - output.write('Nothing seems to have changed\n') + output.write("Nothing seems to have changed\n") return log_format = log_formatter_registry.get_default(branch) - lf = log_format(show_ids=False, to_file=output, show_timezone='original') + lf = log_format(show_ids=False, to_file=output, show_timezone="original") if old_history != []: - output.write('*' * 60) - output.write('\nRemoved Revisions:\n') + output.write("*" * 60) + output.write("\nRemoved Revisions:\n") show_flat_log(branch.repository, old_history, old_revno, lf) - output.write('*' * 60) - output.write('\n\n') + output.write("*" * 60) + output.write("\n\n") if new_history != []: - output.write('Added Revisions:\n') + output.write("Added Revisions:\n") start_revno = new_revno - len(new_history) + 1 - show_log(branch, lf, verbose=False, direction='forward', - start_revision=start_revno) + show_log( + branch, lf, verbose=False, direction="forward", start_revision=start_revno + ) def show_flat_log(repository, history, last_revno, lf): @@ -2107,8 +2287,8 @@ def _get_info_for_log_files(revisionspec_list, file_list, exit_stack): branch will be read-locked. 
""" from .builtins import _get_revision_range - tree, b, path = controldir.ControlDir.open_containing_tree_or_branch( - file_list[0]) + + tree, b, path = controldir.ControlDir.open_containing_tree_or_branch(file_list[0]) exit_stack.enter_context(b.lock_read()) # XXX: It's damn messy converting a list of paths to relative paths when # those paths might be deleted ones, they might be on a case-insensitive @@ -2123,9 +2303,8 @@ def _get_info_for_log_files(revisionspec_list, file_list, exit_stack): else: relpaths = [path] + file_list[1:] info_list = [] - start_rev_info, end_rev_info = _get_revision_range(revisionspec_list, b, - "log") - if relpaths in ([], ['']): + start_rev_info, end_rev_info = _get_revision_range(revisionspec_list, b, "log") + if relpaths in ([], [""]): return b, [], start_rev_info, end_rev_info if start_rev_info is None and end_rev_info is None: if tree is None: @@ -2187,7 +2366,9 @@ def _get_kind_for_file(tree, path): return None -properties_handler_registry = registry.Registry[str, Callable[[Dict[str, str]], Dict[str, str]], None]() +properties_handler_registry = registry.Registry[ + str, Callable[[Dict[str, str]], Dict[str, str]], None +]() # Use the properties handlers to print out bug information if available @@ -2196,23 +2377,23 @@ def _bugs_properties_handler(revision): fixed_bug_urls = [] related_bug_urls = [] for bug_url, status in _mod_revision.iter_bugs(revision): - if status == 'fixed': + if status == "fixed": fixed_bug_urls.append(bug_url) - elif status == 'related': + elif status == "related": related_bug_urls.append(bug_url) ret = {} if fixed_bug_urls: - text = ngettext('fixes bug', 'fixes bugs', len(fixed_bug_urls)) - ret[text] = ' '.join(fixed_bug_urls) + text = ngettext("fixes bug", "fixes bugs", len(fixed_bug_urls)) + ret[text] = " ".join(fixed_bug_urls) if related_bug_urls: - text = ngettext('related bug', 'related bugs', - len(related_bug_urls)) - ret[text] = ' '.join(related_bug_urls) + text = ngettext("related bug", "related bugs", len(related_bug_urls)) + ret[text] = " ".join(related_bug_urls) return ret -properties_handler_registry.register('bugs_properties_handler', - _bugs_properties_handler) +properties_handler_registry.register( + "bugs_properties_handler", _bugs_properties_handler +) # adapters which revision ids to log are filtered. 
When log is called, the @@ -2229,5 +2410,5 @@ def _bugs_properties_handler(revision): # filter on log messages _make_search_filter, # generate deltas for things we will show - _make_delta_filter - ] + _make_delta_filter, +] diff --git a/breezy/lru_cache.py b/breezy/lru_cache.py index d16a16f701..e3e9a6a327 100644 --- a/breezy/lru_cache.py +++ b/breezy/lru_cache.py @@ -24,7 +24,7 @@ class _LRUNode: """This maintains the linked-list which is the lru internals.""" - __slots__ = ('prev', 'next_key', 'key', 'value') + __slots__ = ("prev", "next_key", "key", "value") def __init__(self, key, value): self.prev = None @@ -37,8 +37,9 @@ def __repr__(self): prev_key = None else: prev_key = self.prev.key - return '{}({!r} n:{!r} p:{!r})'.format(self.__class__.__name__, self.key, - self.next_key, prev_key) + return "{}({!r} n:{!r} p:{!r})".format( + self.__class__.__name__, self.key, self.next_key, prev_key + ) class LRUCache: @@ -92,7 +93,7 @@ def __len__(self): def __setitem__(self, key, value): """Add a new value to the cache.""" if key is _null_key: - raise ValueError('cannot use _null_key as a key') + raise ValueError("cannot use _null_key as a key") if key in self._cache: node = self._cache[key] node.value = value @@ -201,16 +202,14 @@ def clear(self): def resize(self, max_cache, after_cleanup_count=None): """Change the number of entries that will be cached.""" - self._update_max_cache(max_cache, - after_cleanup_count=after_cleanup_count) + self._update_max_cache(max_cache, after_cleanup_count=after_cleanup_count) def _update_max_cache(self, max_cache, after_cleanup_count=None): self._max_cache = max_cache if after_cleanup_count is None: self._after_cleanup_count = self._max_cache * 8 // 10 else: - self._after_cleanup_count = min(after_cleanup_count, - self._max_cache) + self._after_cleanup_count = min(after_cleanup_count, self._max_cache) self.cleanup() @@ -224,8 +223,9 @@ class LRUSizeCache(LRUCache): defaults to len() if not supplied. """ - def __init__(self, max_size=1024 * 1024, after_cleanup_size=None, - compute_size=None): + def __init__( + self, max_size=1024 * 1024, after_cleanup_size=None, compute_size=None + ): """Create a new LRUSizeCache. :param max_size: The max number of bytes to store before we start @@ -249,16 +249,21 @@ def __init__(self, max_size=1024 * 1024, after_cleanup_size=None, def __setitem__(self, key, value): """Add a new value to the cache.""" if key is _null_key: - raise ValueError('cannot use _null_key as a key') + raise ValueError("cannot use _null_key as a key") node = self._cache.get(key, None) value_len = self._compute_size(value) if value_len >= self._after_cleanup_size: # The new value is 'too big to fit', as it would fill up/overflow # the cache all by itself - trace.mutter('Adding the key %r to an LRUSizeCache failed.' - ' value %d is too big to fit in a the cache' - ' with size %d %d', key, value_len, - self._after_cleanup_size, self._max_size) + trace.mutter( + "Adding the key %r to an LRUSizeCache failed." + " value %d is too big to fit in a the cache" + " with size %d %d", + key, + value_len, + self._after_cleanup_size, + self._max_size, + ) if node is not None: # We won't be replacing the old node, so just remove it self._remove_node(node) diff --git a/breezy/lsprof.py b/breezy/lsprof.py index f3fc8a351a..2017e09f37 100644 --- a/breezy/lsprof.py +++ b/breezy/lsprof.py @@ -16,7 +16,7 @@ from . 
import errors -__all__ = ['profile', 'Stats'] +__all__ = ["profile", "Stats"] def profile(f, *args, **kwds): @@ -73,8 +73,7 @@ def start(self): """ self._g_threadmap = {} self.p = Profiler() - permitted = self.__class__.profiler_lock.acquire( - self.__class__.profiler_block) + permitted = self.__class__.profiler_lock.acquire(self.__class__.profiler_block) if not permitted: raise errors.InternalBzrError(msg="Already profiling something") try: @@ -133,7 +132,7 @@ def sort(self, crit="inlinetime", reverse=True): :param crit: the data attribute used as the sort key. """ - if crit not in profiler_entry.__dict__ or crit == 'code': + if crit not in profiler_entry.__dict__ or crit == "code": raise ValueError(f"Can't sort by {crit}") key_func = operator.attrgetter(crit) @@ -158,16 +157,39 @@ def pprint(self, top=None, file=None): d = d[:top] cols = "% 12s %12s %11.4f %11.4f %s\n" hcols = "% 12s %12s %12s %12s %s\n" - file.write(hcols % ("CallCount", "Recursive", "Total(ms)", - "Inline(ms)", "module:lineno(function)")) + file.write( + hcols + % ( + "CallCount", + "Recursive", + "Total(ms)", + "Inline(ms)", + "module:lineno(function)", + ) + ) for e in d: - file.write(cols % (e.callcount, e.reccallcount, e.totaltime, - e.inlinetime, label(e.code))) + file.write( + cols + % ( + e.callcount, + e.reccallcount, + e.totaltime, + e.inlinetime, + label(e.code), + ) + ) if e.calls: for se in e.calls: - file.write(cols % (f"+{se.callcount}", se.reccallcount, - se.totaltime, se.inlinetime, - f"+{label(se.code)}")) + file.write( + cols + % ( + f"+{se.callcount}", + se.reccallcount, + se.totaltime, + se.inlinetime, + f"+{label(se.code)}", + ) + ) def freeze(self): """Replace all references to code objects with string @@ -203,20 +225,20 @@ def save(self, filename, format=None): """ if format is None: basename = os.path.basename(filename) - if basename.startswith('callgrind.out'): + if basename.startswith("callgrind.out"): format = "callgrind" else: ext = os.path.splitext(filename)[1] if len(ext) > 1: format = ext[1:] - with open(filename, 'wb') as outfile: + with open(filename, "wb") as outfile: if format == "callgrind": # The callgrind format states it is 'ASCII based': # # But includes filenames so lets ignore and use UTF-8. 
- self.calltree(codecs.getwriter('utf-8')(outfile)) + self.calltree(codecs.getwriter("utf-8")(outfile)) elif format == "txt": - self.pprint(file=codecs.getwriter('utf-8')(outfile)) + self.pprint(file=codecs.getwriter("utf-8")(outfile)) else: self.freeze() pickle.dump(self, outfile, 2) @@ -239,7 +261,7 @@ def __init__(self, data): def output(self, out_file): self.out_file = out_file - out_file.write('events: Ticks\n') + out_file.write("events: Ticks\n") self._print_summary() for entry in self.data: self._entry(entry) @@ -249,21 +271,21 @@ def _print_summary(self): for entry in self.data: totaltime = int(entry.totaltime * 1000) max_cost = max(max_cost, totaltime) - self.out_file.write('summary: %d\n' % (max_cost,)) + self.out_file.write("summary: %d\n" % (max_cost,)) def _entry(self, entry): out_file = self.out_file code = entry.code inlinetime = int(entry.inlinetime * 1000) if isinstance(code, str): - out_file.write('fi=~\n') + out_file.write("fi=~\n") else: - out_file.write(f'fi={code.co_filename}\n') - out_file.write(f'fn={label(code, True)}\n') + out_file.write(f"fi={code.co_filename}\n") + out_file.write(f"fn={label(code, True)}\n") if isinstance(code, str): - out_file.write(f'0 {inlinetime}\n') + out_file.write(f"0 {inlinetime}\n") else: - out_file.write('%d %d\n' % (code.co_firstlineno, inlinetime)) + out_file.write("%d %d\n" % (code.co_firstlineno, inlinetime)) # recursive calls are counted in entry.calls if entry.calls: calls = entry.calls @@ -275,22 +297,21 @@ def _entry(self, entry): lineno = code.co_firstlineno for subentry in calls: self._subentry(lineno, subentry) - out_file.write('\n') + out_file.write("\n") def _subentry(self, lineno, subentry): out_file = self.out_file code = subentry.code totaltime = int(subentry.totaltime * 1000) if isinstance(code, str): - out_file.write('cfi=~\n') - out_file.write(f'cfn={label(code, True)}\n') - out_file.write('calls=%d 0\n' % (subentry.callcount,)) + out_file.write("cfi=~\n") + out_file.write(f"cfn={label(code, True)}\n") + out_file.write("calls=%d 0\n" % (subentry.callcount,)) else: - out_file.write(f'cfi={code.co_filename}\n') - out_file.write(f'cfn={label(code, True)}\n') - out_file.write('calls=%d %d\n' % ( - subentry.callcount, code.co_firstlineno)) - out_file.write('%d %d\n' % (lineno, totaltime)) + out_file.write(f"cfi={code.co_filename}\n") + out_file.write(f"cfn={label(code, True)}\n") + out_file.write("calls=%d %d\n" % (subentry.callcount, code.co_firstlineno)) + out_file.write("%d %d\n" % (lineno, totaltime)) _fn2mod: Dict[str, object] = {} @@ -305,7 +326,7 @@ def label(code, calltree=False): for k, v in sys.modules.items(): if v is None: continue - if getattr(v, '__file__', None) is None: + if getattr(v, "__file__", None) is None: continue if not isinstance(v.__file__, str): continue @@ -313,11 +334,11 @@ def label(code, calltree=False): mname = _fn2mod[code.co_filename] = k break else: - mname = _fn2mod[code.co_filename] = f'<{code.co_filename}>' + mname = _fn2mod[code.co_filename] = f"<{code.co_filename}>" if calltree: - return '%s %s:%d' % (code.co_name, mname, code.co_firstlineno) + return "%s %s:%d" % (code.co_name, mname, code.co_firstlineno) else: - return '%s:%d(%s)' % (mname, code.co_firstlineno, code.co_name) + return "%s:%d(%s)" % (mname, code.co_firstlineno, code.co_name) def main(): @@ -326,10 +347,11 @@ def main(): sys.stderr.write("usage: lsprof.py