Skip to content

Commit

Permalink
Merge pull request #289 from krassowski/fix-and-test-black-tidy-imports
Browse files Browse the repository at this point in the history
Fix tidy-imports failure when black is used and `pyproject.toml` is missing
  • Loading branch information
dshivashankar1994 authored Jan 30, 2024
2 parents c25497f + 13db297 commit 58ff86f
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 5 deletions.
6 changes: 3 additions & 3 deletions lib/python/pyflyby/_importstmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ def read_black_config():
"""
value = find_pyproject_toml('.')
pyproject_path = find_pyproject_toml('.')

raw_config = parse_pyproject_toml(value)
raw_config = parse_pyproject_toml(pyproject_path) if pyproject_path else {}

config = {}
for key in [
Expand All @@ -46,7 +46,7 @@ def read_black_config():
config["target_version"] = set(target_version)
else:
raise ValueError(
f"Invalid config for black = {target_version!r} in {value}"
f"Invalid config for black = {target_version!r} in {pyproject_path}"
)
return config

Expand Down
69 changes: 67 additions & 2 deletions tests/test_importstmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
# http://creativecommons.org/publicdomain/zero/1.0/


from pytest import raises
from unittest.mock import patch

from pyflyby._flags import CompilerFlags
from pyflyby._importstmt import Import, ImportSplit, ImportStatement

from pyflyby._format import FormatParams
from pyflyby._importstmt import (Import, ImportSplit, ImportStatement,
read_black_config)

def test_Import_from_parts_1():
imp = Import.from_parts(".foo.bar", "bar")
Expand Down Expand Up @@ -118,6 +121,13 @@ def test_Import_replace_2():
assert result == Import('from xx import yy as bb')


@patch("pyflyby._importstmt.read_black_config", lambda: {"line_length": 20})
def test_Import_black_line_length():
stmt = Import("from a123456789 import b123456789")
result = stmt.pretty_print(params=FormatParams(use_black=True))
assert result == "from a123456789 import (\n b123456789,\n)\n"


def test_ImportStatement_1():
stmt = ImportStatement("import foo . bar")
assert stmt.fromname == None
Expand Down Expand Up @@ -206,3 +216,58 @@ def test_ImportStatement_eqne_2():
assert not (stmt1a != stmt1b)
assert (stmt1a != stmt2 )
assert not (stmt1a == stmt2 )


@patch("pyflyby._importstmt.find_pyproject_toml", lambda root: None)
def test_ImportStatement_pretty_print_black_no_config():
# running should not error out when no pyproject.toml file is found
stmt = ImportStatement("from a import b")
result = stmt.pretty_print(params=FormatParams(use_black=True))
assert isinstance(result, str)


@patch("pyflyby._importstmt.find_pyproject_toml", lambda root: None)
def test_read_black_config_no_config():
# reading black config should work when no pyproject.toml file is found
config = read_black_config()
assert config == {}


@patch("pyflyby._importstmt.find_pyproject_toml", lambda root: "pyproject.toml")
@patch(
"pyflyby._importstmt.parse_pyproject_toml",
lambda path: {
"line_length": 80,
"skip_magic_trailing_comma": True,
"skip_string_normalization": False,
"skip_source_first_line": True
}
)
def test_read_black_config_extracts_config_subset():
config = read_black_config()
# should copy the desired black options
assert config["line_length"] == 80
assert config["skip_magic_trailing_comma"] == True
assert config["skip_string_normalization"] == False
# should not copy anything else
assert "skip_source_first_line" not in config


@patch("pyflyby._importstmt.find_pyproject_toml", lambda root: "pyproject.toml")
@patch("pyflyby._importstmt.parse_pyproject_toml", lambda path: {"target_version": ["py310", "py311"]})
def test_read_black_config_target_version_list():
config = read_black_config()
assert config["target_version"] == {"py310", "py311"}


@patch("pyflyby._importstmt.find_pyproject_toml", lambda root: "pyproject.toml")
@patch("pyflyby._importstmt.parse_pyproject_toml", lambda path: {"target_version": "py311"})
def test_read_black_config_target_version_str():
config = read_black_config()
assert config["target_version"] == "py311"

@patch("pyflyby._importstmt.find_pyproject_toml", lambda root: "pyproject.toml")
@patch("pyflyby._importstmt.parse_pyproject_toml", lambda path: {"target_version": object()})
def test_read_black_config_target_version_other():
with raises(ValueError, match="Invalid config for black"):
read_black_config()

0 comments on commit 58ff86f

Please sign in to comment.