Skip to content

Commit

Permalink
changes requested in review
Browse files Browse the repository at this point in the history
  • Loading branch information
Dave Liddell committed Feb 1, 2024
1 parent 957f604 commit 08c0ae4
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
4 changes: 4 additions & 0 deletions python/torch_mlir/tools/import_onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
# Make a temp dir for all the temp files we'll be generating as a side
# effect of infering shapes. For now, the only file is a new .onnx holding
# the revised model with shapes.
#
# TODO: If the program temp_dir is None, we should be using an ephemeral
# temp directory instead of a hard-coded path in order to avoid data races
# by default.
input_dir = os.path.dirname(os.path.abspath(args.input_file))
temp_dir = (
Path(input_dir if args.temp_dir is None else args.temp_dir)
Expand Down
19 changes: 9 additions & 10 deletions test/python/onnx_importer/command_line_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import onnx

from torch_mlir.tools.import_onnx import __main__

# For ONNX models

import numpy
Expand Down Expand Up @@ -102,10 +104,9 @@ def run_model_intern(self, onnx_model: onnx.ModelProto, model_name: str):
model_file = run_path / f"{model_name}-i.onnx"
mlir_file = run_path / f"{model_name}-i.torch.mlir"
onnx.save(onnx_model, model_file)
p = subprocess.run([
sys.executable, "-m", "torch_mlir.tools.import_onnx", model_file,
"-o", mlir_file])
self.assertEqual(p.returncode, 0)
args = __main__.parse_arguments([
str(model_file), "-o", str(mlir_file)])
__main__.main(args)

def run_model_extern(self, onnx_model: onnx.ModelProto, model_name: str):
run_path = self.get_run_path(model_name)
Expand All @@ -122,12 +123,10 @@ def run_model_extern(self, onnx_model: onnx.ModelProto, model_name: str):
onnx.save(onnx_model, model_file)
temp_dir = run_path / "temp"
temp_dir.mkdir(exist_ok=True)
p = subprocess.run([
sys.executable, "-m", "torch_mlir.tools.import_onnx",
model_file, "-o", mlir_file, "--keep-temps", "--temp-dir",
temp_dir, "--data-dir", run_path
])
self.assertEqual(p.returncode, 0)
args = __main__.parse_arguments([
str(model_file), "-o", str(mlir_file), "--keep-temps", "--temp-dir",
str(temp_dir), "--data-dir", str(run_path)])
__main__.main(args)

def test_all(self):
for model_func in ALL_MODELS:
Expand Down

0 comments on commit 08c0ae4

Please sign in to comment.