Skip to content

Commit

Permalink
pythongh-109590 - Fix shutil.which to more closely match CMD on win32
Browse files Browse the repository at this point in the history
An extensionless file will on be attempted if '.' is in PATHEXT

Fix up the tests to make better use of the bytes testing of shutil.which
  • Loading branch information
csm10495 committed Sep 24, 2023
1 parent 3d88145 commit 41cb299
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 13 deletions.
12 changes: 9 additions & 3 deletions Lib/shutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -1551,11 +1551,17 @@ def which(cmd, mode=os.F_OK | os.X_OK, path=None):
pathext_source = os.getenv("PATHEXT") or _WIN_DEFAULT_PATHEXT
pathext = [ext for ext in pathext_source.split(os.pathsep) if ext]

dot = '.'
if use_bytes:
pathext = [os.fsencode(ext) for ext in pathext]

# Always try checking the originally given cmd, if it doesn't match, try pathext
files = [cmd] + [cmd + ext for ext in pathext]
dot = b'.'

# Attempt to match CMD behavior:
# Only try the given cmd if it has an extension (therefore has a dot)
# or a dot is a pathext in PATHEXT.
# Otherwise use PATHEXT to formulate paths to check.
files = (([cmd] if (dot in cmd or dot in pathext) else []) +
[cmd + ext for ext in pathext])
else:
# On other platforms you don't have things like PATHEXT to tell you
# what file suffixes are executable, so just pass on cmd as-is.
Expand Down
69 changes: 59 additions & 10 deletions Lib/test/test_shutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -2067,6 +2067,14 @@ def setUp(self):
self.curdir = os.curdir
self.ext = ".EXE"

def to_text_type(self, s):
'''
In this class we're testing with str, so convert s to a str
'''
if isinstance(s, bytes):
return s.decode()
return s

def test_basic(self):
# Given an EXE in a directory, it should be returned.
rv = shutil.which(self.file, path=self.dir)
Expand Down Expand Up @@ -2254,9 +2262,9 @@ def test_empty_path_no_PATH(self):

@unittest.skipUnless(sys.platform == "win32", 'test specific to Windows')
def test_pathext(self):
ext = ".xyz"
ext = self.to_text_type(".xyz")
temp_filexyz = tempfile.NamedTemporaryFile(dir=self.temp_dir,
prefix="Tmp2", suffix=ext)
prefix=self.to_text_type("Tmp2"), suffix=ext)
os.chmod(temp_filexyz.name, stat.S_IXUSR)
self.addCleanup(temp_filexyz.close)

Expand All @@ -2265,38 +2273,39 @@ def test_pathext(self):
program = os.path.splitext(program)[0]

with os_helper.EnvironmentVarGuard() as env:
env['PATHEXT'] = ext
env['PATHEXT'] = ext if isinstance(ext, str) else ext.decode()
rv = shutil.which(program, path=self.temp_dir)
self.assertEqual(rv, temp_filexyz.name)

# Issue 40592: See https://bugs.python.org/issue40592
@unittest.skipUnless(sys.platform == "win32", 'test specific to Windows')
def test_pathext_with_empty_str(self):
ext = ".xyz"
ext = self.to_text_type(".xyz")
temp_filexyz = tempfile.NamedTemporaryFile(dir=self.temp_dir,
prefix="Tmp2", suffix=ext)
prefix=self.to_text_type("Tmp2"), suffix=ext)
self.addCleanup(temp_filexyz.close)

# strip path and extension
program = os.path.basename(temp_filexyz.name)
program = os.path.splitext(program)[0]

with os_helper.EnvironmentVarGuard() as env:
env['PATHEXT'] = f"{ext};" # note the ;
env['PATHEXT'] = f"{ext if isinstance(ext, str) else ext.decode()};" # note the ;
rv = shutil.which(program, path=self.temp_dir)
self.assertEqual(rv, temp_filexyz.name)

# See GH-75586
@unittest.skipUnless(sys.platform == "win32", 'test specific to Windows')
def test_pathext_applied_on_files_in_path(self):
with os_helper.EnvironmentVarGuard() as env:
env["PATH"] = self.temp_dir
env["PATH"] = self.temp_dir if isinstance(self.temp_dir, str) else self.temp_dir.decode()
env["PATHEXT"] = ".test"

test_path = pathlib.Path(self.temp_dir) / "test_program.test"
test_path.touch(mode=0o755)
test_path = os.path.join(self.temp_dir, self.to_text_type("test_program.test"))
open(test_path, 'w').close()
os.chmod(test_path, 0o755)

self.assertEqual(shutil.which("test_program"), str(test_path))
self.assertEqual(shutil.which(self.to_text_type("test_program")), test_path)

# See GH-75586
@unittest.skipUnless(sys.platform == "win32", 'test specific to Windows')
Expand All @@ -2312,16 +2321,56 @@ def test_win_path_needs_curdir(self):
self.assertFalse(shutil._win_path_needs_curdir('dontcare', os.X_OK))
need_curdir_mock.assert_called_once_with('dontcare')

# See GH-109590
@unittest.skipUnless(sys.platform == "win32", 'test specific to Windows')
def test_extensionless_file_resolution_no_dot_in_pathext(self):
with os_helper.EnvironmentVarGuard() as env:
env['PATHEXT'] = ".test;"
env['PATH'] = self.temp_dir if isinstance(self.temp_dir, str) else self.temp_dir.decode()

extensionless_file_in_path = os.path.join(self.temp_dir, self.to_text_type("file"))
open(extensionless_file_in_path, 'w').close()

extensioned_file_in_path = os.path.join(self.temp_dir, self.to_text_type("file.test"))
open(extensioned_file_in_path, 'w').close()


self.assertEqual(shutil.which(self.to_text_type('file'), os.F_OK), extensioned_file_in_path)

# See GH-109590
@unittest.skipUnless(sys.platform == "win32", 'test specific to Windows')
def test_extensionless_file_resolution_dot_in_pathext(self):
with os_helper.EnvironmentVarGuard() as env:
env['PATHEXT'] = ".test;.;"
env['PATH'] = self.temp_dir if isinstance(self.temp_dir, str) else self.temp_dir.decode()

extensionless_file_in_path = os.path.join(self.temp_dir, self.to_text_type("file"))
open(extensionless_file_in_path, 'w').close()

extensioned_file_in_path = os.path.join(self.temp_dir, self.to_text_type("file.test"))
open(extensioned_file_in_path, 'w').close()

self.assertEqual(shutil.which(self.to_text_type('file')), extensionless_file_in_path)


class TestWhichBytes(TestWhich):
def setUp(self):
TestWhich.setUp(self)
self.dir = os.fsencode(self.dir)
self.file = os.fsencode(self.file)
self.temp_file.name = os.fsencode(self.temp_file.name)
self.temp_dir = os.fsencode(self.temp_dir)
self.curdir = os.fsencode(self.curdir)
self.ext = os.fsencode(self.ext)

def to_text_type(self, s):
'''
In this class we're testing with bytes, so convert s to a bytes
'''
if isinstance(s, str):
return s.encode()
return s


class TestMove(BaseTest, unittest.TestCase):

Expand Down

0 comments on commit 41cb299

Please sign in to comment.