Skip to content

Commit

Permalink
more api compat fixes (#64)
Browse files Browse the repository at this point in the history
Co-authored-by: Arham Khan <[email protected]>
  • Loading branch information
123epsilon and Arham Khan authored Jul 14, 2023
1 parent fa1c3a2 commit b1becd0
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 9 deletions.
44 changes: 42 additions & 2 deletions cpp_ext/TorchTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,11 +410,51 @@ void PyAnyTorchTensorValue::bindDerived(ClassTy &c) {
[](const PyAnyTorchTensorValue &self, const py::args &dims,
DefaultingPyLocation &loc,
const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue {
return permute(self, PyAnyTorchListOfTorchIntValue(dims), loc.get(),
ip.get());
PyAnyTorchListOfTorchIntValue dims_ =
py::isinstance<py::list>(dims[0])
? dims[0].cast<py::list>()
: PyAnyTorchListOfTorchIntValue(dims);
return permute(self, dims_, loc.get(), ip.get());
},
py::kw_only(), "loc"_a = py::none(), "ip"_a = py::none());

// aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)
c.def(
"add",
[](const PyAnyTorchTensorValue &self, const PyAnyTorchTensorValue &other,
const PyAnyTorchScalarValue &alpha, DefaultingPyLocation &loc,
const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue {
return add(self, other, alpha, loc.get(), ip.get());
},
"other"_a, "alpha"_a = 1, py::kw_only(), "loc"_a = py::none(),
"ip"_a = py::none());

// aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)
c.def(
"norm",
[](const PyAnyTorchTensorValue &self,
const PyAnyTorchOptionalScalarValue &p,
const PyAnyTorchListOfTorchIntValue &dim,
const PyTorch_BoolValue &keepdim, DefaultingPyLocation &loc,
const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue {
return norm(self, p, dim, keepdim, loc.get(), ip.get());
},
"p"_a = py::none(), "dim"_a, "keepdim"_a = false, py::kw_only(),
"loc"_a = py::none(), "ip"_a = py::none());

// aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int, bool) -> (Tensor)
c.def(
"norm",
[](const PyAnyTorchTensorValue &self,
const PyAnyTorchOptionalScalarValue &p, const PyTorch_IntValue &dim,
const PyTorch_BoolValue &keepdim, DefaultingPyLocation &loc,
const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue {
auto dims = PyAnyTorchListOfTorchIntValue(py::make_tuple(dim));
return norm(self, p, dims, keepdim, loc.get(), ip.get());
},
"p"_a = py::none(), "dim"_a, "keepdim"_a = false, py::kw_only(),
"loc"_a = py::none(), "ip"_a = py::none());

#include "TorchTensor.pybinds.cpp"
}

Expand Down
7 changes: 4 additions & 3 deletions cpp_ext/TorchTensor.pybinds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,6 @@ c.def("_nested_tensor_strides", [](PyAnyTorchTensorValue& self, py::args args, p
// _nnz(self) -> _int
c.def("_nnz", [](PyAnyTorchTensorValue& self, py::args args, py::kwargs kwargs) { throw NotImplementedError("NotImplementedError: _nnz with signature _nnz(self) -> _int"); });

// _sparse_mask_projection(self, mask: Tensor) -> Tensor
c.def("_sparse_mask_projection", [](PyAnyTorchTensorValue& self, py::args args, py::kwargs kwargs) { throw NotImplementedError("NotImplementedError: _sparse_mask_projection with signature _sparse_mask_projection(self, mask: Tensor) -> Tensor"); });

// _to_dense(self, dtype: Optional[_dtype]=None, masked_grad: Optional[_bool]=None) -> Tensor
c.def("_to_dense", [](PyAnyTorchTensorValue& self, py::args args, py::kwargs kwargs) { throw NotImplementedError("NotImplementedError: _to_dense with signature _to_dense(self, dtype: Optional[_dtype]=None, masked_grad: Optional[_bool]=None) -> Tensor"); });

Expand Down Expand Up @@ -1752,6 +1749,10 @@ c.def("ormqr", [](PyAnyTorchTensorValue& self, py::args args, py::kwargs kwargs)
// outer(self, vec2: Tensor) -> Tensor
c.def("outer", [](PyAnyTorchTensorValue& self, py::args args, py::kwargs kwargs) { throw NotImplementedError("NotImplementedError: outer with signature outer(self, vec2: Tensor) -> Tensor"); });

// @overload permute(self, dims: _size) -> Tensor
// aten::permute : (Tensor, int[]) -> (Tensor)
c.def("permute", [](const PyAnyTorchTensorValue &self, const PyAnyTorchListOfTorchIntValue &dims, DefaultingPyLocation &loc, const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue { return permute(self, dims, loc.get(), ip.get()); }, "dims"_a, py::kw_only(), "loc"_a = py::none(), "ip"_a = py::none());

// pin_memory(self, device: Optional[Union[_device, str, None]]=None) -> Tensor
c.def("pin_memory", [](PyAnyTorchTensorValue& self, py::args args, py::kwargs kwargs) { throw NotImplementedError("NotImplementedError: pin_memory with signature pin_memory(self, device: Optional[Union[_device, str, None]]=None) -> Tensor"); });

Expand Down
6 changes: 3 additions & 3 deletions pi/mlir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import functools
import inspect
import warnings
from enum import Enum
from enum import IntEnum
from typing import List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -188,7 +188,7 @@ def standard_normal(*args, **kwargs):
LongTensor = functools.partial(_np_wrapper, factory=np.array, dtype=dtype.int64)


class layout(Enum):
class layout(IntEnum):
strided = 1
sparse_coo = 2
sparse_csr = 3
Expand All @@ -198,7 +198,7 @@ class layout(Enum):
_mkldnn = 7


class memory_format(Enum):
class memory_format(IntEnum):
contiguous_format = 0
preserve_format = 1
channels_last = 2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def get_clean_name(name):
"chunk(self, chunks: _int, dim: _int=0) -> List[Tensor]",
"__getitem__(self, indices: Union[None, _int, slice, Tensor, List, Tuple]) -> Tensor",
"double(self) -> Tensor",
"@overload permute(self, dims: _size) -> Tensor",
}

TORCH_OPS_IMPL_CPP = "TorchOps.impls.cpp"
Expand Down
5 changes: 5 additions & 0 deletions tests/torch_mlir/xfail.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
}

PI_XFAIL_SET = {

# In these, torch-mlir spuriously initializes tensors as double precision and truncates to floating point, we simply initialize as single-precision causing an IR diff
"ElementwiseGeFloatScalarModule_basic",
"ArangeStartNegativeStepFloatModule_basic",
"ArangeStartStepFloatModule_basic",
Expand All @@ -33,4 +35,7 @@
"ThresholdBackward3dFloatModule_basic",
"TypePromotionAlphaWiderModule_basic",
"TypePromotionSameCategoryZeroRankWider_basic",

# An IR difference due to an additional pass in torch-mlir, but functionally the same
"NormalizeModule_basic"
}

0 comments on commit b1becd0

Please sign in to comment.