Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into fix-floordiv-arange
Browse files Browse the repository at this point in the history
  • Loading branch information
parthchadha committed Sep 16, 2024
2 parents 6d7e764 + 251102e commit d5ac9eb
Show file tree
Hide file tree
Showing 17 changed files with 289 additions and 126 deletions.
2 changes: 1 addition & 1 deletion tripy/docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ This directory includes all the source files for the public API documentation.
You can build the documentation locally in the development container by running:
```bash
python3 docs/generate_rsts.py
sphinx-build build/doc_sources build/docs -c docs/ -j auto -W
sphinx-build build/doc_sources build/docs -c docs/ -j 4 -W
```
To view the documentation, you can open `build/docs/index.html` in a browser.

Expand Down
36 changes: 19 additions & 17 deletions tripy/docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,14 +195,19 @@ def process_docstring(app, what, name, obj, options, lines):
if unqual_name in TYPE_VERIFICATION:
add_text_index = -1
for index, block in enumerate(blocks):

def insert_block(text):
nonlocal index

blocks.insert(index, text)
index += 1

if re.search(r".. code-block::", block):
type_dict = TYPE_VERIFICATION[unqual_name].dtypes
blocks.insert(index, "Type Constraints:")
index += 1
insert_block("TYPE CONSTRAINTS:")
# Add the dtype constraint name and the dtypes that correlate.
for type_name, dt in type_dict.items():
blocks.insert(
index,
insert_block(
f" - **{type_name}**: :class:`"
+ "`, :class:`".join(
sorted(
Expand All @@ -215,20 +220,17 @@ def process_docstring(app, what, name, obj, options, lines):
)
+ "`",
)
index += 1
blocks.insert(index, "\n")
if TYPE_VERIFICATION[unqual_name].dtype_exceptions != []:
insert_block("\n")

if TYPE_VERIFICATION[unqual_name].dtype_exceptions:
# Add the dtype exceptions.
index += 1
blocks.insert(index, "**Unsupported Type Combinations**:")
dtype_exception_text = []
insert_block("UNSUPPORTED TYPE COMBINATIONS:")
for exception_dict in TYPE_VERIFICATION[unqual_name].dtype_exceptions:
dtype_exception_text.append(
", ".join([f"{key}: :class:`{val}`" for key, val in exception_dict.items()])
insert_block(
" - "
+ ", ".join([f"**{key}**\ =\ :class:`{val}`" for key, val in exception_dict.items()]),
)
dtype_exception_text = "; ".join(dtype_exception_text) + "\n"
index += 1
blocks.insert(index, dtype_exception_text)
insert_block("\n")
break

if re.search(r":param \w+: ", block):
Expand All @@ -237,14 +239,14 @@ def process_docstring(app, what, name, obj, options, lines):
if TYPE_VERIFICATION[unqual_name].dtype_constraints.get(param_name, None):
add_text_index = re.search(r":param \w+: ", block).span()[1]
blocks[index] = (
f"{block[0:add_text_index]}[dtype=\ **{TYPE_VERIFICATION[unqual_name].dtype_constraints[param_name]}**\ ] {block[add_text_index:]}"
f"{block[0:add_text_index]}[*dtype=*\ **{TYPE_VERIFICATION[unqual_name].dtype_constraints[param_name]}**\ ] {block[add_text_index:]}"
)

if TYPE_VERIFICATION[unqual_name].return_dtype is not None and re.search(r":returns:", block):
add_text_index = re.search(r":returns:", block).span()[1] + 1
# Add dtype constraint to start of returns description.
blocks[index] = (
f"{block[0:add_text_index]}[dtype=\ **{TYPE_VERIFICATION[unqual_name].return_dtype}**\ ] {block[add_text_index:]}"
f"{block[0:add_text_index]}[*dtype=*\ **{TYPE_VERIFICATION[unqual_name].return_dtype}**\ ] {block[add_text_index:]}"
)

seen_classes.add(name)
Expand Down
24 changes: 24 additions & 0 deletions tripy/tests/frontend/ops/test_cumsum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tests import helper
import tripy as tp


class TestCumsum:
def test_invalid_dim_fails(self):
a = tp.ones((2, 2))
with helper.raises(tp.TripyException, "Dimension argument is out of bounds."):
tp.cumsum(a, dim=4)
40 changes: 40 additions & 0 deletions tripy/tests/integration/test_cumsum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

import tripy as tp


class TestCumsum:
@pytest.mark.parametrize(
"data,dim,expected",
[
([0, 1, 2, 3], 0, [0, 1, 3, 6]),
# Negative dim:
([[2, 3], [4, 5]], -2, [[2, 3], [6, 8]]),
# Non-innermost dim:
([[2, 3], [4, 5]], 0, [[2, 3], [6, 8]]),
# >2D (can potentially find transposition bugs)
([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], 0, [[[1, 2], [3, 4]], [[6, 8], [10, 12]]]),
],
)
def test_cumsum(self, data, dim, expected):
inp = tp.Tensor(data, dtype=tp.float32)

out = tp.cumsum(inp, dim=dim)

expected = tp.Tensor(expected, dtype=tp.float32)
assert tp.allclose(out, expected)
assert out.shape == expected.shape
69 changes: 38 additions & 31 deletions tripy/tests/spec_verification/object_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,43 +81,19 @@ def default_builder(init, dtype, namespace):
All other types do not have defaults and must be passed to the verifier using default_constraints_all.
"""
default_constraints_all = {
"__rtruediv__": {"self": 1},
"__rsub__": {"self": 1},
"__getitem__": {"index": 2},
"__matmul__": {"self": tp.ones((2, 3))},
"__radd__": {"self": 1},
"__rpow__": {"self": 1},
"__rmul__": {"self": 1},
"softmax": {"dim": 1},
"concatenate": {"dim": 0},
"expand": {"sizes": tp.Tensor([3, 4]), "input": tp.ones((3, 1))},
"full": {"shape": tp.Tensor([3]), "value": 1},
"full_like": {"value": 1},
"flip": {"dim": 1},
"gather": {"dim": 0, "index": tp.Tensor([1])},
"iota": {"shape": tp.Tensor([4])},
"__matmul__": {"self": tp.ones((2, 3))},
"transpose": {"dim0": 0, "dim1": 1},
"permute": {"perm": [1, 0]},
"quantize": {"scale": tp.Tensor([1, 1, 1]), "dim": 0},
"dequantize": {"scale": tp.Tensor([1, 1, 1]), "dim": 0},
"sum": {"dim": 0},
"__rpow__": {"self": 1},
"__rsub__": {"self": 1},
"__rtruediv__": {"self": 1},
"all": {"dim": 0},
"any": {"dim": 0},
"max": {"dim": 0},
"prod": {"dim": 0},
"mean": {"dim": 0},
"var": {"dim": 0},
"arange": {"start": 0, "stop": 5},
"argmax": {"dim": 0},
"argmin": {"dim": 0},
"reshape": {"shape": tp.Tensor([6])},
"squeeze": {"input": tp.ones((3, 1)), "dims": (1)},
"__getitem__": {"index": 2},
"split": {"indices_or_sections": 2},
"unsqueeze": {"dim": 1},
"masked_fill": {"value": 1},
"ones": {"shape": tp.Tensor([3, 2])},
"zeros": {"shape": tp.Tensor([3, 2])},
"arange": {"start": 0, "stop": 5},
"repeat": {"repeats": 2, "dim": 0},
"concatenate": {"dim": 0},
"convolution": {
"input": tp.ones((1, 3, 5, 5)),
"weight": tp.ones((1, 3, 3, 3)),
Expand All @@ -127,6 +103,31 @@ def default_builder(init, dtype, namespace):
"lhs_dilation": [1, 1],
"rhs_dilation": [1, 1],
},
"cumsum": {"dim": 0},
"dequantize": {"scale": tp.Tensor([1, 1, 1]), "dim": 0},
"expand": {"sizes": tp.Tensor([3, 4]), "input": tp.ones((3, 1))},
"flip": {"dim": 1},
"full_like": {"value": 1},
"full": {"shape": tp.Tensor([3]), "value": 1},
"gather": {"dim": 0, "index": tp.Tensor([1])},
"iota": {"shape": tp.Tensor([4])},
"masked_fill": {"value": 1},
"max": {"dim": 0},
"mean": {"dim": 0},
"ones": {"shape": tp.Tensor([3, 2])},
"permute": {"perm": [1, 0]},
"prod": {"dim": 0},
"quantize": {"scale": tp.Tensor([1, 1, 1]), "dim": 0},
"repeat": {"repeats": 2, "dim": 0},
"reshape": {"shape": tp.Tensor([6])},
"softmax": {"dim": 1},
"split": {"indices_or_sections": 2},
"squeeze": {"input": tp.ones((3, 1)), "dims": (1)},
"sum": {"dim": 0},
"transpose": {"dim0": 0, "dim1": 1},
"unsqueeze": {"dim": 1},
"var": {"dim": 0},
"zeros": {"shape": tp.Tensor([3, 2])},
}


Expand All @@ -137,24 +138,30 @@ def create_obj(func_obj, func_name, param_name, param_dtype, namespace):
param_dict = func_sig.parameters
param_type_annot = param_dict[param_name]
init = None

# Check if there is a value in default_constraints_all for func_name and param_name and use it.
default_constraints = default_constraints_all.get(func_name, None)
if default_constraints != None:
other_constraint = default_constraints.get(param_name, None)
if other_constraint is not None:
init = other_constraint

# If parameter had a default then use it otherwise skip.
if init is None and param_type_annot.default is not param_type_annot.empty:
# Checking if not equal to None since default can be 0 or similar.
if param_type_annot.default != None:
init = param_type_annot.default

param_type = param_type_annot.annotation
while get_origin(param_type) in [Union, Optional]:
param_type = get_args(param_type)[0]
# ForwardRef refers to any case where type hint is a string.
if isinstance(param_type, ForwardRef):
param_type = param_type.__forward_arg__

create_obj_func = find_func.get(param_type, default_builder)
if create_obj_func:
namespace[param_name] = create_obj_func(init, param_dtype, namespace)
return namespace[param_name]

assert False, f"Could not create parameter: {param_name}"
1 change: 1 addition & 0 deletions tripy/tripy/flat_ir/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from tripy.flat_ir.ops.dot import DotOp
from tripy.flat_ir.ops.exponential import ExpOp
from tripy.flat_ir.ops.flip import FlipOp
from tripy.flat_ir.ops.floor import FloorOp
from tripy.flat_ir.ops.gather import DynamicGatherOp
from tripy.flat_ir.ops.get_dimension_size import GetDimensionSizeOp
from tripy.flat_ir.ops.iota import DynamicIotaOp
Expand Down
28 changes: 28 additions & 0 deletions tripy/tripy/flat_ir/ops/floor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from dataclasses import dataclass

from mlir_tensorrt.compiler.dialects import stablehlo

from tripy.flat_ir.ops.base import BaseFlatIROp


@dataclass(repr=False)
class FloorOp(BaseFlatIROp):
def to_mlir(self, operands):
return [stablehlo.floor(*operands)]
4 changes: 0 additions & 4 deletions tripy/tripy/frontend/ops/allclose.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,6 @@ def allclose(a: "tripy.Tensor", b: "tripy.Tensor", rtol: float = 1e-05, atol: fl
"""
from tripy.frontend.trace.ops.unary_elementwise import abs
from tripy.frontend.trace.ops.reduce import all
from tripy.common.datatype import int64

if a.dtype == int64:
raise_error("Known issue with i64. Allclose currently does not work with int64 inputs. Issue #116")

compare = abs(a - b) <= (atol + rtol * abs(b))
return bool(all(compare))
Loading

0 comments on commit d5ac9eb

Please sign in to comment.