diff --git a/bin/tidy-imports b/bin/tidy-imports index 7639e4b5..7737a80d 100755 --- a/bin/tidy-imports +++ b/bin/tidy-imports @@ -27,10 +27,10 @@ import sys import os from pyflyby._cmdline import hfmt, parse_args, process_actions -from pyflyby._imports2s import (canonicalize_imports, - fix_unused_and_missing_imports, - replace_star_imports, - transform_imports) +from pyflyby._imports2s import (canonicalize_imports, + fix_unused_and_missing_imports, + replace_star_imports, + transform_imports, sort_imports) from pyflyby._log import logger import toml @@ -145,20 +145,23 @@ def main(): parser.set_defaults(**default_config) options, args = parse_args( _add_opts_and_defaults, import_format_params=True, modify_action_params=True, defaults=default_config) + def modify(x): - if options.canonicalize: - x = canonicalize_imports(x, params=options.params) if options.transformations: x = transform_imports(x, options.transformations, params=options.params) if options.replace_star_imports: x = replace_star_imports(x, params=options.params) - return fix_unused_and_missing_imports( + x = fix_unused_and_missing_imports( x, params=options.params, add_missing=options.add_missing, remove_unused=options.remove_unused, add_mandatory=options.add_mandatory, ) + sorted_imports = sort_imports(x) + if options.canonicalize: + sorted_imports = canonicalize_imports(sorted_imports, params=options.params) + return sorted_imports process_actions(args, options.actions, modify) diff --git a/lib/python/pyflyby/_imports2s.py b/lib/python/pyflyby/_imports2s.py index 7a554f08..5a33879f 100644 --- a/lib/python/pyflyby/_imports2s.py +++ b/lib/python/pyflyby/_imports2s.py @@ -1,9 +1,9 @@ # pyflyby/_imports2s.py. # Copyright (C) 2011-2018 Karl Chen. # License: MIT http://opensource.org/licenses/MIT +from collections import defaultdict - - +import isort from pyflyby._autoimp import scan_for_import_issues from pyflyby._file import FileText, Filename from pyflyby._flags import CompilerFlags @@ -534,6 +534,48 @@ def replace_star_imports(codeblock, params=None): return transformer.output(params=params) +def sort_imports(codeblock): + """ + Sort imports for better grouping. + :param codeblock: + :return: codeblock + """ + sorted_imports = isort.code( + str(codeblock), + # To sort all the import in lexicographic order + force_sort_within_sections=True, + # This is done below + lines_between_sections=0, + lines_after_imports=1 + ) + # Step 1: Split the input string into a list of lines + lines = sorted_imports.split('\n') + + # Step 2: Identify groups of imports and keep track of their line numbers + pkg_lines = defaultdict(list) + line_pkg_dict = {} + for i, line in enumerate(lines): + match = re.match(r'(from (\w+)|import (\w+))', line) + if match: + current_pkg = match.groups()[1:3] + current_pkg = current_pkg[0] if current_pkg[0] is not None else current_pkg[1] + pkg_lines[current_pkg].append(i) + line_pkg_dict[i] = current_pkg + + # Step 3: Create the output list of lines with blank lines around groups with more than one import + output_lines = [] + for i, line in enumerate(lines): + if i > 0 and line_pkg_dict.get(i) != line_pkg_dict.get(i-1) and len(pkg_lines[line_pkg_dict.get(i)]) > 1: + output_lines.append('') + output_lines.append(line) + if i < len(lines) - 1 and line_pkg_dict.get(i) != line_pkg_dict.get(i+1) and len(pkg_lines[line_pkg_dict.get(i)]) > 1: + output_lines.append('') + + # Step 4: Join the lines to create the output string + sorted_output_str = '\n'.join(output_lines) + return PythonBlock(sorted_output_str) + + def transform_imports(codeblock, transformations, params=None): """ Transform imports as specified by ``transformations``. diff --git a/setup.py b/setup.py index 01db9c90..14218edd 100755 --- a/setup.py +++ b/setup.py @@ -221,7 +221,7 @@ def make_distribution(self): "License :: OSI Approved :: MIT License", "Programming Language :: Python", ], - install_requires=["six", "toml", "pathlib ; python_version<'3'"], + install_requires=["six", "toml", "isort", "pathlib ; python_version<'3'"], python_requires=">3.0, !=3.0.*, !=3.1.*, !=3.2.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*, <4", tests_require=['pexpect>=3.3', 'pytest', 'epydoc', 'rlipython', 'requests'], cmdclass = { diff --git a/tests/test_cmdline.py b/tests/test_cmdline.py index e7c549b2..23c0af88 100644 --- a/tests/test_cmdline.py +++ b/tests/test_cmdline.py @@ -121,6 +121,7 @@ def foo(): foo() + os + sys import a import c + a, c ''').lstrip() assert result == expected_result @@ -138,6 +139,7 @@ def test_tidy_imports_no_add_no_remove_1(): import a import b import c + a, c, os, sys ''').strip() assert result == expected @@ -409,6 +411,7 @@ def test_tidy_imports_query_no_change_1(): input = dedent(''' from __future__ import absolute_import, division import x1 + x1 ''') with tempfile.NamedTemporaryFile(suffix=".py", mode='w+') as f: @@ -446,6 +449,7 @@ def test_tidy_imports_query_y_1(): expected = dedent(""" from __future__ import absolute_import, division import x1 + x1 """) assert output == expected @@ -681,3 +685,61 @@ def test_tidy_imports_symlinks_bad_argument(): assert b"error: --symlinks must be one of" in proc_output assert output == input assert symlink_output == input + + +def test_tidy_imports_sorting(): + with tempfile.NamedTemporaryFile(suffix=".py", mode='w+') as f: + f.write(dedent(""" + import numpy + + from pkg1.mod1 import foo + from pkg1.mod2 import bar + from pkg2 import baz + import yy + + from pkg1.mod1 import foo2 + from pkg1.mod3 import quux + from pkg2 import baar + import sympy + import zz + + + zz.foo() + bar() + quux() + foo2() + yy.f() + bar() + foo() + numpy.arange() + baz + baar + sympy + """).lstrip()) + f.flush() + result = pipe([BIN_DIR+"/tidy-imports", f.name]) + expected = dedent(""" + import numpy + + from pkg1.mod1 import foo, foo2 + from pkg1.mod2 import bar + from pkg1.mod3 import quux + + from pkg2 import baar, baz + import sympy + import yy + import zz + + zz.foo() + bar() + quux() + foo2() + yy.f() + bar() + foo() + numpy.arange() + baz + baar + sympy + """).strip().format(f=f) + assert result == expected