diff --git a/lib/python/pyflyby/_importdb.py b/lib/python/pyflyby/_importdb.py index 4df4186e..818121f2 100644 --- a/lib/python/pyflyby/_importdb.py +++ b/lib/python/pyflyby/_importdb.py @@ -284,7 +284,7 @@ def get_default(cls, target_filename: Union[Filename, str], /): target_path = Path(target_filename).resolve() parents: List[Path] - if not target_path.is_dir(): + if target_path.is_dir(): parents = [target_path] else: parents = [] diff --git a/tests/test_importdb.py b/tests/test_importdb.py index 80bdb8a4..438f6ac4 100644 --- a/tests/test_importdb.py +++ b/tests/test_importdb.py @@ -6,6 +6,7 @@ import os +import sys from shutil import rmtree from tempfile import NamedTemporaryFile, mkdtemp from textwrap import dedent @@ -15,8 +16,31 @@ from pyflyby._importstmt import Import from pyflyby._util import EnvVarCtx +from contextlib import contextmanager +if sys.version_info > (3, 11): + from contextlib import chdir + +else: + + @contextmanager + def chdir(path): + old = os.getcwd() + try: + os.chdir(path) + yield + finally: + os.chdir(old) + + +def test_importDB_root(): + """ + See #362 + """ + with chdir("/"): + ImportDB.get_default(None) + def test_ImportDB_from_code_1(): db = ImportDB('from aa.bb import cc as dd, ee') expected_known = ImportSet(['from aa.bb import cc as dd, ee'])