From d2a342eff6935e8ddc3394412525a4ab73d55310 Mon Sep 17 00:00:00 2001 From: Amit Kumar Date: Tue, 12 Sep 2023 20:23:20 +0530 Subject: [PATCH] insert lines between packages --- lib/python/pyflyby/_imports2s.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/lib/python/pyflyby/_imports2s.py b/lib/python/pyflyby/_imports2s.py index 4d20cb1d..c8961d39 100644 --- a/lib/python/pyflyby/_imports2s.py +++ b/lib/python/pyflyby/_imports2s.py @@ -1,6 +1,8 @@ # pyflyby/_imports2s.py. # Copyright (C) 2011-2018 Karl Chen. # License: MIT http://opensource.org/licenses/MIT +from collections import defaultdict + import isort from pyflyby._autoimp import scan_for_import_issues from pyflyby._file import FileText, Filename @@ -538,7 +540,33 @@ def sort_imports(codeblock): :param codeblock: :return: codeblock """ - return isort.code(codeblock) + sorted_imports = isort.code(str(codeblock), force_sort_within_sections=True) + # Step 1: Split the input string into a list of lines + lines = sorted_imports.strip().split('\n') + + # Step 2: Identify groups of imports and keep track of their line numbers + pkg_lines = defaultdict(list) + line_pkg_dict = {} + for i, line in enumerate(lines): + match = re.match(r'(from (\w+)|import (\w+))', line) + current_pkg = match.groups()[1:3] + current_pkg = current_pkg[0] if current_pkg[0] is not None else current_pkg[1] + + pkg_lines[current_pkg].append(i) + line_pkg_dict[i] = current_pkg + + # Step 3: Create the output list of lines with blank lines around groups with more than one import + output_lines = [] + for i, line in enumerate(lines): + if i > 0 and line_pkg_dict[i] != line_pkg_dict[i-1] and len(pkg_lines[line_pkg_dict[i]]) > 1: + output_lines.append('') + output_lines.append(line) + if i < len(lines) - 1 and line_pkg_dict[i] != line_pkg_dict[i+1] and len(pkg_lines[line_pkg_dict[i]]) > 1: + output_lines.append('') + + # Step 4: Join the lines to create the output string + sorted_output_str = '\n'.join(output_lines) + return PythonBlock(sorted_output_str) def transform_imports(codeblock, transformations, params=None):