From b1becd0dfa04e0b39f7ca18846a38de92cc48baf Mon Sep 17 00:00:00 2001
From: Arham Khan
Date: Fri, 14 Jul 2023 14:24:40 -0500
Subject: [PATCH] more api compat fixes (#64)

Co-authored-by: Arham Khan
---
 cpp_ext/TorchTensor.cpp                       | 44 ++++++++++++++++++-
 cpp_ext/TorchTensor.pybinds.cpp               |  7 +--
 pi/mlir/utils.py                              |  6 +--
 ...ate_torch_mlir_bindings_from_torch_json.py |  1 -
 tests/torch_mlir/xfail.py                     |  5 +++
 5 files changed, 54 insertions(+), 9 deletions(-)

diff --git a/cpp_ext/TorchTensor.cpp b/cpp_ext/TorchTensor.cpp
index 341602b..4767ee4 100644
--- a/cpp_ext/TorchTensor.cpp
+++ b/cpp_ext/TorchTensor.cpp
@@ -410,11 +410,51 @@ void PyAnyTorchTensorValue::bindDerived(ClassTy &c) {
       [](const PyAnyTorchTensorValue &self, const py::args &dims,
          DefaultingPyLocation &loc,
          const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue {
-        return permute(self, PyAnyTorchListOfTorchIntValue(dims), loc.get(),
-                       ip.get());
+        PyAnyTorchListOfTorchIntValue dims_ =
+            py::isinstance<PyAnyTorchListOfTorchIntValue>(dims[0])
+                ? dims[0].cast<PyAnyTorchListOfTorchIntValue>()
+                : PyAnyTorchListOfTorchIntValue(dims);
+        return permute(self, dims_, loc.get(), ip.get());
       },
       py::kw_only(), "loc"_a = py::none(), "ip"_a = py::none());
 
+  // aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)
+  c.def(
+      "add",
+      [](const PyAnyTorchTensorValue &self, const PyAnyTorchTensorValue &other,
+         const PyAnyTorchScalarValue &alpha, DefaultingPyLocation &loc,
+         const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue {
+        return add(self, other, alpha, loc.get(), ip.get());
+      },
+      "other"_a, "alpha"_a = 1, py::kw_only(), "loc"_a = py::none(),
+      "ip"_a = py::none());
+
+  // aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)
+  c.def(
+      "norm",
+      [](const PyAnyTorchTensorValue &self,
+         const PyAnyTorchOptionalScalarValue &p,
+         const PyAnyTorchListOfTorchIntValue &dim,
+         const PyTorch_BoolValue &keepdim, DefaultingPyLocation &loc,
+         const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue {
+        return norm(self, p, dim, keepdim, loc.get(), ip.get());
+      },
+      "p"_a = py::none(), "dim"_a, "keepdim"_a = false, py::kw_only(),
+      "loc"_a = py::none(), "ip"_a = py::none());
+
+  // aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor); overload promoting a single int dim to a one-element list
+  c.def(
+      "norm",
+      [](const PyAnyTorchTensorValue &self,
+         const PyAnyTorchOptionalScalarValue &p, const PyTorch_IntValue &dim,
+         const PyTorch_BoolValue &keepdim, DefaultingPyLocation &loc,
+         const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue {
+        auto dims = PyAnyTorchListOfTorchIntValue(py::make_tuple(dim));
+        return norm(self, p, dims, keepdim, loc.get(), ip.get());
+      },
+      "p"_a = py::none(), "dim"_a, "keepdim"_a = false, py::kw_only(),
+      "loc"_a = py::none(), "ip"_a = py::none());
+
 #include "TorchTensor.pybinds.cpp"
 }
diff --git a/cpp_ext/TorchTensor.pybinds.cpp b/cpp_ext/TorchTensor.pybinds.cpp
index caaf217..3cdb040 100644
--- a/cpp_ext/TorchTensor.pybinds.cpp
+++ b/cpp_ext/TorchTensor.pybinds.cpp
@@ -283,9 +283,6 @@ c.def("_nested_tensor_strides", [](PyAnyTorchTensorValue& self, py::args args, p
 // _nnz(self) -> _int
 c.def("_nnz", [](PyAnyTorchTensorValue& self, py::args args, py::kwargs kwargs) { throw NotImplementedError("NotImplementedError: _nnz with signature _nnz(self) -> _int"); });
 
-// _sparse_mask_projection(self, mask: Tensor) -> Tensor
-c.def("_sparse_mask_projection", [](PyAnyTorchTensorValue& self, py::args args, py::kwargs kwargs) { throw NotImplementedError("NotImplementedError: _sparse_mask_projection with signature _sparse_mask_projection(self, mask: Tensor) -> Tensor"); });
-
 // _to_dense(self, dtype: Optional[_dtype]=None, masked_grad: Optional[_bool]=None) -> Tensor
 c.def("_to_dense", [](PyAnyTorchTensorValue& self, py::args args, py::kwargs kwargs) { throw NotImplementedError("NotImplementedError: _to_dense with signature _to_dense(self, dtype: Optional[_dtype]=None, masked_grad: Optional[_bool]=None) -> Tensor"); });
@@ -1752,6 +1749,10 @@ c.def("ormqr", [](PyAnyTorchTensorValue& self, py::args args, py::kwargs kwargs)
 // outer(self, vec2: Tensor) -> Tensor
 c.def("outer", [](PyAnyTorchTensorValue& self, py::args args, py::kwargs kwargs) { throw NotImplementedError("NotImplementedError: outer with signature outer(self, vec2: Tensor) -> Tensor"); });
 
+// @overload permute(self, dims: _size) -> Tensor
+// aten::permute : (Tensor, int[]) -> (Tensor)
+c.def("permute", [](const PyAnyTorchTensorValue &self, const PyAnyTorchListOfTorchIntValue &dims, DefaultingPyLocation &loc, const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue { return permute(self, dims, loc.get(), ip.get()); }, "dims"_a, py::kw_only(), "loc"_a = py::none(), "ip"_a = py::none());
+
 // pin_memory(self, device: Optional[Union[_device, str, None]]=None) -> Tensor
 c.def("pin_memory", [](PyAnyTorchTensorValue& self, py::args args, py::kwargs kwargs) { throw NotImplementedError("NotImplementedError: pin_memory with signature pin_memory(self, device: Optional[Union[_device, str, None]]=None) -> Tensor"); });
diff --git a/pi/mlir/utils.py b/pi/mlir/utils.py
index c81060c..f8aaf1e 100644
--- a/pi/mlir/utils.py
+++ b/pi/mlir/utils.py
@@ -6,7 +6,7 @@
 import functools
 import inspect
 import warnings
-from enum import Enum
+from enum import IntEnum
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
@@ -188,7 +188,7 @@ def standard_normal(*args, **kwargs):
 LongTensor = functools.partial(_np_wrapper, factory=np.array, dtype=dtype.int64)
 
 
-class layout(Enum):
+class layout(IntEnum):
     strided = 1
     sparse_coo = 2
     sparse_csr = 3
@@ -198,7 +198,7 @@ class layout(Enum):
     _mkldnn = 7
 
 
-class memory_format(Enum):
+class memory_format(IntEnum):
     contiguous_format = 0
     preserve_format = 1
     channels_last = 2
diff --git a/scripts/generate_stuff/generate_torch_mlir_bindings_from_torch_json.py b/scripts/generate_stuff/generate_torch_mlir_bindings_from_torch_json.py
index 015dc5d..e3988f2 100644
--- a/scripts/generate_stuff/generate_torch_mlir_bindings_from_torch_json.py
+++ b/scripts/generate_stuff/generate_torch_mlir_bindings_from_torch_json.py
@@ -63,7 +63,6 @@ def get_clean_name(name):
     "chunk(self, chunks: _int, dim: _int=0) -> List[Tensor]",
     "__getitem__(self, indices: Union[None, _int, slice, Tensor, List, Tuple]) -> Tensor",
     "double(self) -> Tensor",
-    "@overload permute(self, dims: _size) -> Tensor",
 }
 
 TORCH_OPS_IMPL_CPP = "TorchOps.impls.cpp"
diff --git a/tests/torch_mlir/xfail.py b/tests/torch_mlir/xfail.py
index b5687cb..7fa99d5 100644
--- a/tests/torch_mlir/xfail.py
+++ b/tests/torch_mlir/xfail.py
@@ -17,6 +17,8 @@
 }
 
 PI_XFAIL_SET = {
+
+    # In these tests torch-mlir spuriously initializes tensors as double precision and then truncates them to single precision; we initialize as single precision directly, which produces an IR diff.
     "ElementwiseGeFloatScalarModule_basic",
     "ArangeStartNegativeStepFloatModule_basic",
     "ArangeStartStepFloatModule_basic",
@@ -33,4 +35,7 @@
     "ThresholdBackward3dFloatModule_basic",
     "TypePromotionAlphaWiderModule_basic",
     "TypePromotionSameCategoryZeroRankWider_basic",
+
+    # An IR difference due to an additional pass in torch-mlir; the results are functionally the same.
+    "NormalizeModule_basic",
 }
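
Reviewer note: the tensor-method overloads added above are meant to match the eager PyTorch API surface. A rough sketch of the Python call patterns they accept follows; `t` and `other` are illustrative placeholders for PyAnyTorchTensorValue instances created under an active MLIR context and insertion point, not values you can construct this way:

    # Assumed: t and other are PyAnyTorchTensorValue instances built inside
    # an MLIR context/insertion point (placeholders for this sketch).
    y = t.permute(2, 0, 1)    # varargs ints are wrapped into a torch int list
    y = t.permute([2, 0, 1])  # an int-list first argument is cast and passed through
    y = t.add(other, alpha=2)             # aten::add.Tensor; alpha defaults to 1
    y = t.norm(2, dim=[0, 1], keepdim=True)  # aten::norm.ScalarOpt_dim with int[] dim
    y = t.norm(2, dim=1)      # a single int dim is promoted to a one-element list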
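
The Enum -> IntEnum switch in pi/mlir/utils.py is behavioral, not cosmetic: IntEnum members are genuine ints, so layout and memory_format values can be handed straight to code that expects plain integers (as the bindings presumably do when these flags are lowered). A minimal illustration of the difference, independent of this repo:

    from enum import Enum, IntEnum

    class LayoutEnum(Enum):
        strided = 1

    class LayoutInt(IntEnum):
        strided = 1

    LayoutEnum.strided == 1   # False: plain Enum members never compare equal to ints
    LayoutInt.strided == 1    # True: IntEnum members are ints
    int(LayoutInt.strided)    # 1; int(LayoutEnum.strided) raises TypeError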