From 2b1f9cdbfeace556553294fa5ad73b0921c598ef Mon Sep 17 00:00:00 2001 From: Amit Kumar Date: Tue, 12 Sep 2023 20:32:14 +0530 Subject: [PATCH] fix is not matching pattern --- bin/tidy-imports | 7 ++++--- lib/python/pyflyby/_imports2s.py | 14 +++++++------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/bin/tidy-imports b/bin/tidy-imports index 6bed81ae..09ce2bef 100755 --- a/bin/tidy-imports +++ b/bin/tidy-imports @@ -146,8 +146,6 @@ def main(): options, args = parse_args( _add_opts_and_defaults, import_format_params=True, modify_action_params=True, defaults=default_config) def modify(x): - if options.canonicalize: - x = canonicalize_imports(x, params=options.params) if options.transformations: x = transform_imports(x, options.transformations, params=options.params) @@ -159,7 +157,10 @@ def main(): remove_unused=options.remove_unused, add_mandatory=options.add_mandatory, ) - return sort_imports(x) + sx = sort_imports(x) + if options.canonicalize: + sx = canonicalize_imports(x, params=options.params) + return sx process_actions(args, options.actions, modify) diff --git a/lib/python/pyflyby/_imports2s.py b/lib/python/pyflyby/_imports2s.py index c8961d39..8d42f08d 100644 --- a/lib/python/pyflyby/_imports2s.py +++ b/lib/python/pyflyby/_imports2s.py @@ -549,19 +549,19 @@ def sort_imports(codeblock): line_pkg_dict = {} for i, line in enumerate(lines): match = re.match(r'(from (\w+)|import (\w+))', line) - current_pkg = match.groups()[1:3] - current_pkg = current_pkg[0] if current_pkg[0] is not None else current_pkg[1] - - pkg_lines[current_pkg].append(i) - line_pkg_dict[i] = current_pkg + if match: + current_pkg = match.groups()[1:3] + current_pkg = current_pkg[0] if current_pkg[0] is not None else current_pkg[1] + pkg_lines[current_pkg].append(i) + line_pkg_dict[i] = current_pkg # Step 3: Create the output list of lines with blank lines around groups with more than one import output_lines = [] for i, line in enumerate(lines): - if i > 0 and line_pkg_dict[i] != line_pkg_dict[i-1] and len(pkg_lines[line_pkg_dict[i]]) > 1: + if i > 0 and line_pkg_dict.get(i) != line_pkg_dict.get(i-1) and len(pkg_lines[line_pkg_dict.get(i)]) > 1: output_lines.append('') output_lines.append(line) - if i < len(lines) - 1 and line_pkg_dict[i] != line_pkg_dict[i+1] and len(pkg_lines[line_pkg_dict[i]]) > 1: + if i < len(lines) - 1 and line_pkg_dict.get(i) != line_pkg_dict.get(i+1) and len(pkg_lines[line_pkg_dict.get(i)]) > 1: output_lines.append('') # Step 4: Join the lines to create the output string