Skip to content

Commit

Permalink
Use sys.executable to resolve python executable.
Browse files Browse the repository at this point in the history
This ensures the python we pass to subprocess.run is the same one
as the thing that's running.
  • Loading branch information
tfogal committed Dec 14, 2024
1 parent 8074484 commit 0474620
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions thunder/tests/test_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import itertools
import os
import subprocess
import sys
import torch
import torch.fx
import torch.nn as nn
Expand Down Expand Up @@ -819,9 +820,13 @@ def func(x):
s2 = f"{tmp_path}/graph1_thunder_0.py"
assert os.path.exists(s1)
assert os.path.exists(s2)
cmd = "pytest" if use_pytest_benchmark else "python"
result1 = subprocess.run([cmd, s1], shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
result2 = subprocess.run([cmd, s2], shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
cmd = [sys.executable]
if use_pytest_benchmark:
cmd = cmd + ["-m", "pytest"]
cmd1 = cmd + [s1]
cmd2 = cmd + [s2]
result1 = subprocess.run(cmd1, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
result2 = subprocess.run(cmd2, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)

assert result1.returncode == 0, f"Reproducer {s1} failed: {result1}"
assert result2.returncode == 0, f"Reproducer {s2} failed: {result2}"
Expand Down Expand Up @@ -852,8 +857,11 @@ def forward(self, x):

s1 = f"{tmp_path}/graph0_thunder_0.py"
assert os.path.exists(s1)
cmd = "pytest" if use_pytest_benchmark else "python"
result1 = subprocess.run([cmd, s1], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
cmd = [sys.executable]
if use_pytest_benchmark:
cmd = cmd + ["-m", "pytest"]
cmd1 = cmd + [s1]
result1 = subprocess.run(cmd1, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
assert result1.returncode == 0, f"Reproducer {s1} failed: {result1}"


Expand Down Expand Up @@ -915,7 +923,9 @@ def check(file_name, cmd):
s1 = f"{tmp_path}/graph0_thunder_0.py"
s2 = f"{tmp_path}/graph0_thunder_2.py"
s3 = f"{tmp_path}/graph0_thunder_4.py"
cmd = "pytest" if use_pytest_benchmark else "python"
cmd = [sys.executable]
if use_pytest_benchmark:
cmd = cmd + ["-m", "pytest"]
for fname in [s1, s2, s3]:
check(fname, cmd)

Expand Down

0 comments on commit 0474620

Please sign in to comment.