Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Spec Verification #25

Merged
merged 82 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from 76 commits
Commits
Show all changes
82 commits
Select commit Hold shift + click to select a range
2c12803
Push branch to Github
Mgluhovskoi Aug 5, 2024
c0f308c
Amend
Mgluhovskoi Aug 5, 2024
300ac05
Remove int64 from some decorators
Mgluhovskoi Aug 5, 2024
a83a201
Change input_values to dataclass, remove type checking if None, futur…
Mgluhovskoi Aug 6, 2024
9e4ca01
Update tripy/tripy/dtype_info.py
Mgluhovskoi Aug 6, 2024
328a34f
Push branch to Github
Mgluhovskoi Aug 5, 2024
ea21c85
Remove int64 from some decorators
Mgluhovskoi Aug 5, 2024
28b578c
Change input_values to dataclass, remove type checking if None, futur…
Mgluhovskoi Aug 6, 2024
6e19d4d
Update tripy/tripy/frontend/trace/ops/cast.py
Mgluhovskoi Aug 6, 2024
30abbe0
Merge remote-tracking branch 'origin/main' into dev-mgluhovskoi-spec-…
Mgluhovskoi Aug 6, 2024
7ab69ad
Merge remote-tracking branch 'origin/main' into dev-mgluhovskoi-spec-…
Mgluhovskoi Aug 6, 2024
3786b09
Fix rebase issues
Mgluhovskoi Aug 6, 2024
f845384
Move default_constraints to test_dtype_constraints
Mgluhovskoi Aug 6, 2024
6a9c8d1
Update guides
Mgluhovskoi Aug 6, 2024
cf21bde
Refactor and fix documentation compilation
Mgluhovskoi Aug 6, 2024
b731849
Remove unused imports
Mgluhovskoi Aug 6, 2024
68c13a8
Update tripy/docs/README.md
Mgluhovskoi Aug 7, 2024
956bd24
Add logic for self parameter and add documentation for tensor initial…
Mgluhovskoi Aug 7, 2024
b000b07
Merge branch 'dev-mgluhovskoi-spec-verification' of github.com:NVIDIA…
Mgluhovskoi Aug 7, 2024
463fec3
Update documentation generation and remove function names for arange …
Mgluhovskoi Aug 7, 2024
38cd48e
Merge remote-tracking branch 'origin/main' into dev-mgluhovskoi-spec-…
Mgluhovskoi Aug 7, 2024
a9071fa
Remove any unnecessary mentions of dtype in docstrings
Mgluhovskoi Aug 7, 2024
7d9bb10
Merge remote-tracking branch 'origin/main' into dev-mgluhovskoi-spec-…
Mgluhovskoi Aug 7, 2024
218a057
Fix arange verification
Mgluhovskoi Aug 7, 2024
6a4fb76
Convert TYPE_VERIFICATION to hold namedtuples
Mgluhovskoi Aug 7, 2024
1889463
Remove support for double assigned functions (__r<op>__)
Mgluhovskoi Aug 7, 2024
49d0f87
Revert "Remove support for double assigned functions (__r<op>__)"
Mgluhovskoi Aug 7, 2024
8bb4a3d
Merge remote-tracking branch 'origin/main' into dev-mgluhovskoi-spec-…
Mgluhovskoi Aug 8, 2024
5de9c64
Merge remote-tracking branch 'origin/main' into dev-mgluhovskoi-spec-…
Mgluhovskoi Aug 8, 2024
5ae4562
Add flat_ir testcase for gather op dim >= 0
Mgluhovskoi Aug 8, 2024
e9b48b4
Add integration testcases for gather
Mgluhovskoi Aug 8, 2024
d76c4f7
Revert "Add integration testcases for gather"
Mgluhovskoi Aug 8, 2024
5ed3738
Revert "Add flat_ir testcase for gather op dim >= 0"
Mgluhovskoi Aug 8, 2024
a285d83
Merge remote-tracking branch 'origin/main' into dev-mgluhovskoi-spec-…
Mgluhovskoi Aug 8, 2024
50794a4
Add some comments and add some error checking
Mgluhovskoi Aug 8, 2024
78c61ff
Simplify conf.py logic
Mgluhovskoi Aug 9, 2024
7a6da5b
Update tripy/docs/conf.py
Mgluhovskoi Aug 9, 2024
ef6d8a0
Remove unused conditional from conf.py
Mgluhovskoi Aug 9, 2024
0be3ed6
Update op_doc_guide.md
Mgluhovskoi Aug 9, 2024
e8a3850
Move some verification documentation to function docstrings
Mgluhovskoi Aug 9, 2024
c05e3d9
Rename dtype_info.py to constraints.py
Mgluhovskoi Aug 9, 2024
714af43
Remove all constraint types except for init
Mgluhovskoi Aug 9, 2024
483ec40
Merge remote-tracking branch 'origin/main' into dev-mgluhovskoi-spec-…
Mgluhovskoi Aug 9, 2024
4f3586a
Simplify conditionals
Mgluhovskoi Aug 9, 2024
4e4cdcc
Update tripy/tests/spec_verification/test_dtype_constraints.py
Mgluhovskoi Aug 9, 2024
5d4b77a
Use ExitStack
Mgluhovskoi Aug 9, 2024
f0c84df
Update comment
Mgluhovskoi Aug 9, 2024
c3d25de
Move most of TYPE_VERIFICATION logic to test_dtype_constraints.py
Mgluhovskoi Aug 9, 2024
f7a9811
Add support for exceptions to dtype rules
Mgluhovskoi Aug 9, 2024
59cdac6
Merge remote-tracking branch 'origin/main' into dev-mgluhovskoi-spec-…
Mgluhovskoi Aug 9, 2024
00c8037
Merge remote-tracking branch 'origin/main' into dev-mgluhovskoi-spec-…
Mgluhovskoi Aug 12, 2024
330591f
Update type exceptions style
Mgluhovskoi Aug 12, 2024
eabc276
Update variable name
Mgluhovskoi Aug 12, 2024
315aac7
Improve conf.py dtype exception text building
Mgluhovskoi Aug 12, 2024
ecb3d4a
Merge
Mgluhovskoi Aug 13, 2024
137e055
Update tripy/docs/conf.py
Mgluhovskoi Aug 13, 2024
b3270dc
Change guide file name
Mgluhovskoi Aug 13, 2024
e03632a
Merge remote-tracking branch 'origin/main' into dev-mgluhovskoi-spec-…
Mgluhovskoi Aug 13, 2024
5491974
Merge branch 'dev-mgluhovskoi-spec-verification' of github.com:NVIDIA…
Mgluhovskoi Aug 13, 2024
6305ce7
Merge remote-tracking branch 'origin/main' into dev-mgluhovskoi-spec-…
Mgluhovskoi Aug 14, 2024
4bc25c6
Merge remote-tracking branch 'origin/main' into dev-mgluhovskoi-spec-…
Mgluhovskoi Aug 16, 2024
467e476
Initialize objects all within object_builder.py
Mgluhovskoi Aug 16, 2024
2e53173
Merge remote-tracking branch 'origin/main' into dev-mgluhovskoi-spec-…
Mgluhovskoi Aug 16, 2024
66bd8bb
Rework some logic and experiment with L0 time
Mgluhovskoi Aug 16, 2024
3daa71a
Moving negative tests to L1
Mgluhovskoi Aug 19, 2024
1405fbb
Merge remote-tracking branch 'origin/main' into dev-mgluhovskoi-spec-…
Mgluhovskoi Aug 19, 2024
13e51c5
Amend
Mgluhovskoi Aug 19, 2024
5fce676
Remove xfail since it was increasing test timing by almost 2x
Mgluhovskoi Aug 19, 2024
e7b21de
Merge remote-tracking branch 'origin/main' into dev-mgluhovskoi-spec-…
Mgluhovskoi Aug 26, 2024
643d140
Add missing imports
Mgluhovskoi Aug 26, 2024
af02128
Merge remote-tracking branch 'origin/main' into dev-mgluhovskoi-spec-…
Mgluhovskoi Aug 26, 2024
d7bb914
Fix bad rebase
Mgluhovskoi Aug 26, 2024
1074a57
Amendment
Mgluhovskoi Aug 26, 2024
b7362e3
Amendment
Mgluhovskoi Aug 27, 2024
54bd34c
Test L0 timings
Mgluhovskoi Aug 27, 2024
f4f54ed
Test L0
Mgluhovskoi Aug 27, 2024
506bb73
Move all tests to L1
Mgluhovskoi Aug 27, 2024
96bbe66
Update test_dtype_constraints.py comment
Mgluhovskoi Aug 27, 2024
3516e7e
Merge remote-tracking branch 'origin/main' into dev-mgluhovskoi-spec-…
Mgluhovskoi Aug 27, 2024
7f87b86
Merge branch 'dev-mgluhovskoi-spec-verification' of github.com:NVIDIA…
Mgluhovskoi Aug 27, 2024
e6202b8
Update dtypes after rebase
Mgluhovskoi Aug 27, 2024
f72567b
Merge remote-tracking branch 'origin/main' into dev-mgluhovskoi-spec-…
Mgluhovskoi Aug 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tripy/docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ To view the documentation, you can open `build/docs/index.html` in a browser.
The `export.public_api()` decorator allows you to specify metadata for documentation
generation, such as where in the documentation hierarchy the API should be documented.

The `constraints.dtype_info()` decorator verifies the data types a function claims to support and generates
corresponding documentation. For more information, see [this guide](../tests/spec_verification/README.md).

The `generate_rsts.py` script uses this information to automatically generate a directory
structure and populate it with `.rst` files.

Expand Down
63 changes: 38 additions & 25 deletions tripy/docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

import tripy as tp
from tests import helper
from tripy.dtype_info import TYPE_VERIFICATION
from tripy.constraints import TYPE_VERIFICATION, FUNC_W_DOC_VERIF


PARAM_PAT = re.compile(":param .*?:")
Expand Down Expand Up @@ -149,12 +149,10 @@ def process_docstring(app, what, name, obj, options, lines):
elif param.kind == inspect.Parameter.VAR_POSITIONAL:
pname = "*" + pname

if pname == "self":
# Don't want a type annotation for the self parameter.
if pname != "self" or obj.__qualname__ in FUNC_W_DOC_VERIF:
assert (
param.annotation == signature.empty
), f"Avoid using type annotations for the `self` parameter since this will corrupt the rendered documentation!"
else:
pname in documented_args
), f"Missing documentation for parameter: '{pname}' in: '{obj}'. Please ensure you've included this in the `Args:` section. Note: Documented parameters were: {documented_args} {doc}"
assert (
pname in documented_args
), f"Missing documentation for parameter: '{pname}' in: '{obj}'. Please ensure you've included this in the `Args:` section. Note: Documented parameters were: {documented_args}"
Expand All @@ -167,6 +165,10 @@ def process_docstring(app, what, name, obj, options, lines):
assert not inspect.ismodule(
param.annotation
), f"Type annotation cannot be a module, but got: '{param.annotation}' for parameter: '{pname}' in: '{obj}'. Please specify a type!"
else:
assert (
param.annotation == signature.empty
), f"Avoid using type annotations for the `self` parameter since this will corrupt the rendered documentation! Note: Documented parameters were: {documented_args} {doc}"

assert signature.return_annotation != signature.empty, (
f"Missing return type annotation for: '{obj}'. "
Expand All @@ -178,39 +180,50 @@ def process_docstring(app, what, name, obj, options, lines):
":returns:" in doc
), f"For: {obj}, return value is not documented. Please ensure you've included a `Returns:` section"

# new docstring logic:
# first figure out if we should it is the new docstring
if name.split(".")[-1] in TYPE_VERIFICATION.keys():
cleaned_name = name.split(".")[-1]
# New docstring logic:
# First figure out if object is using the @constraints.dtype_info decorator.
unqual_name = name.split(".")[-1]
if unqual_name in TYPE_VERIFICATION.keys():
add_text_index = -1
for index, block in enumerate(blocks):
if re.search(r".. code-block::", block):
type_dict = TYPE_VERIFICATION[cleaned_name][1]["types"]
type_dict = TYPE_VERIFICATION[unqual_name].dtypes
blocks.insert(index, "Type Constraints:")
index += 1
# Add the dtype constraint name and the dtypes that correlate.
for type_name, dt in type_dict.items():
blocks.insert(index, f" - {type_name}: " + ", ".join(dt))
blocks.insert(
index,
f" - **{type_name}**: :class:`" + "`, :class:`".join(set(dt)) + "`",
)
index += 1
blocks.insert(index, "\n")
if TYPE_VERIFICATION[unqual_name].dtype_exceptions != []:
# Add the dtype exceptions.
index += 1
blocks.insert(index, "**Unsupported Type Combinations**:")
dtype_exception_text = []
for exception_dict in TYPE_VERIFICATION[unqual_name].dtype_exceptions:
dtype_exception_text.append(
", ".join([f"{key}: :class:`{val}`" for key, val in exception_dict.items()])
)
dtype_exception_text = "; ".join(dtype_exception_text) + "\n"
index += 1
blocks.insert(index, dtype_exception_text)
break
if re.search(r":param \w+: ", block):
add_text_index = re.search(r":param \w+: ", block).span()[1]
param_name = re.match(r":param (\w+): ", block).group(1)
blocks[index] = (
block[0:add_text_index]
+ "[dtype="
+ TYPE_VERIFICATION[cleaned_name][2][param_name]
+ "] "
+ block[add_text_index:]
)
# Add dtype constraint to start of each parameter description.
if TYPE_VERIFICATION[unqual_name].dtype_constraints.get(param_name, None):
add_text_index = re.search(r":param \w+: ", block).span()[1]
blocks[index] = (
f"{block[0:add_text_index]}[dtype=\ **{TYPE_VERIFICATION[unqual_name].dtype_constraints[param_name]}**\ ] {block[add_text_index:]}"
)
if re.search(r":returns:", block):
add_text_index = re.search(r":returns:", block).span()[1] + 1
# Add dtype constraint to start of returns description.
blocks[index] = (
block[0:add_text_index]
+ "[dtype="
+ list(TYPE_VERIFICATION[cleaned_name][1]["returns"].values())[0]["dtype"]
+ "] "
+ block[add_text_index:]
f"{block[0:add_text_index]}[dtype=\ **{TYPE_VERIFICATION[unqual_name].return_dtype}**\ ] {block[add_text_index:]}"
)

seen_classes.add(name)
Expand Down
1 change: 0 additions & 1 deletion tripy/tests/common/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import numpy as np
import pytest
import torch
from textwrap import dedent

import mlir_tensorrt.runtime.api as runtime
import tripy as tp
Expand Down
1 change: 0 additions & 1 deletion tripy/tests/common/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import tripy.common.datatype

from tests import helper
from tripy.common.datatype import DATA_TYPES
from tripy.common.exception import TripyException
from tripy.common.utils import (
convert_frontend_dtype_to_tripy_dtype,
Expand Down
1 change: 0 additions & 1 deletion tripy/tests/frontend/trace/ops/test_where.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# limitations under the License.
#

import pytest
import re
import tripy as tp
from tests import helper
Expand Down
2 changes: 0 additions & 2 deletions tripy/tests/integration/test_conv_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@
from collections.abc import Sequence
from dataclasses import dataclass

import cupy as cp
import pytest
import torch

import tripy as tp
from tests import helper

DTYPES = [
(torch.float16, tp.float16),
Expand Down
3 changes: 1 addition & 2 deletions tripy/tests/integration/test_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
# limitations under the License.
#

import cupy as cp
import numpy as np

import pytest
import re
import torch
Expand Down
11 changes: 11 additions & 0 deletions tripy/tests/spec_verification/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Introduction

Spec verification is designed to ensure that the datatypes documented in the operator documentation are accurate.

# How to Verify an Operation

To run the verification program on an operation, add the decorator `@constraints.dtype_info` to the operation. The inputs to the decorator will help the verifier determine the constraints on the inputs and the datatypes to verify.

To learn more about how to use `@constraints.dtype_info` check out `tripy/constraints.py`, `tests/spec_verification/test_dtype_constraints.py`, and `tests/spec_verification/object_builders.py`.

After the decorator is set up, it will automatically run verification test cases alongside the regular test cases. If you only want to run the verifier, execute `pytest -s -v` within the tests/spec_verification folder.
134 changes: 111 additions & 23 deletions tripy/tests/spec_verification/object_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,40 +16,128 @@
#

import tripy as tp
import math

from typing import Union, Optional, get_origin, get_args, ForwardRef, List
from tripy.common import datatype
import inspect

def tensor_builder(func_obj, input_values, namespace):
shape = input_values.get("shape", None)
if not shape:
shape = (3, 2)
return tp.ones(dtype=namespace[input_values["dtype"]], shape=shape)

def tensor_builder(init, dtype, namespace):
if init is None:
return tp.ones(dtype=namespace[dtype], shape=(3, 2))
elif not isinstance(init, tp.Tensor):
assert dtype == None
return init
return tp.cast(init, dtype=namespace[dtype])

def shape_tensor_builder(func_obj, input_values, namespace):
follow_tensor = input_values.get("follow_tensor", None)
return (math.prod((namespace[follow_tensor]).shape.data().data()),)


def dtype_builder(func_obj, input_values, namespace):
dtype = input_values.get("dtype", None)
def dtype_builder(init, dtype, namespace):
return namespace[dtype]


def int_builder(func_obj, input_values, namespace):
return input_values.get("value", None)
def tensor_list_builder(init, dtype, namespace):
if init is None:
return [tp.ones(shape=(3, 2), dtype=namespace[dtype]) for _ in range(2)]
else:
return [tp.cast(tens, dtype=namespace[dtype]) for tens in init]


def device_builder(init, dtype, namespace):
if init is None:
return tp.device("gpu")
return init


def default_builder(init, dtype, namespace):
return init


find_func = {
"Tensor": tensor_builder,
"shape_tensor": shape_tensor_builder,
"dtype": dtype_builder,
"int": int_builder,
"tripy.Tensor": tensor_builder,
"tripy.Shape": tensor_builder,
"tripy.dtype": dtype_builder,
datatype.dtype: dtype_builder,
List[Union["tripy.Tensor"]]: tensor_list_builder,
"tripy.device": device_builder,
}

"""
default_constraints_all: This dictionary helps set specific constraints and values for parameters. These constraints correspond to the type hint of each parameter.
Some type have default values, so you might not need to pass other_constraints for every operation.
If there is no default, you must specify an initialization value, or the testcase may fail.
The dictionary's keys must be the name of the function that they are constraining and the value must be what the parameter should be initialized to.
Here is the list of parameter types that have defaults or work differently from other types:
- tensor - default: tp.ones(shape=(3,2)). If init is passed then value must be in the form of a list. Example: "scale": tp.Tensor([1,1,1]) or "scale": tp.ones((3,3))
- dtype - default: no default. Dtype parameters will be set using dtype_constraints input so using default_constraints_all will not change anything.
- list/sequence of tensors - default: [tp.ones((3,2)),tp.ones((3,2))]. Example: "dim": [tp.ones((2,4)),tp.ones((1,2))].
This will create a list/sequence of tensors of size count and each tensor will follow the init and shape value similar to tensor parameters.
- device - default: tp.device("gpu"). Example: {"device": tp.device("cpu")}.
All other types do not have defaults and must be passed to the verifier using default_constraints_all.
"""
default_constraints_all = {
"__rtruediv__": {"self": 1},
"__rsub__": {"self": 1},
"__radd__": {"self": 1},
"__rpow__": {"self": 1},
"__rmul__": {"self": 1},
"softmax": {"dim": 1},
"concatenate": {"dim": 0},
"expand": {"sizes": tp.Tensor([3, 4]), "input": tp.ones((3, 1))},
"full": {"shape": tp.Tensor([3]), "value": 1},
"full_like": {"value": 1},
"flip": {"dim": 1},
"gather": {"dim": 0, "index": tp.Tensor([1])},
"iota": {"shape": tp.Tensor([3])},
"__matmul__": {"self": tp.ones((2, 3))},
"transpose": {"dim0": 0, "dim1": 1},
"permute": {"perm": [1, 0]},
"quantize": {"scale": tp.Tensor([1, 1, 1]), "dim": 0},
"sum": {"dim": 0},
"all": {"dim": 0},
"any": {"dim": 0},
"max": {"dim": 0},
"prod": {"dim": 0},
"mean": {"dim": 0},
"var": {"dim": 0},
"argmax": {"dim": 0},
"argmin": {"dim": 0},
"reshape": {"shape": tp.Tensor([6])},
"squeeze": {"input": tp.ones((3, 1)), "dims": (1)},
"__getitem__": {"index": 2},
"split": {"indices_or_sections": 2},
"unsqueeze": {"dim": 1},
"masked_fill": {"value": 1},
"ones": {"shape": tp.Tensor([3, 2])},
"zeros": {"shape": tp.Tensor([3, 2])},
"arange": {"start": 0, "stop": 5},
}


def create_obj(func_obj, param_name, input_desc, namespace):
param_type = list(input_desc.keys())[0]
create_obj_func = find_func[param_type]
namespace[param_name] = create_obj_func(func_obj, input_desc[param_type], namespace)
return namespace[param_name]
def create_obj(func_obj, func_name, param_name, param_dtype, namespace):
# If type is an optional or union get the first type.
# Get names and type hints for each param.
func_sig = inspect.signature(func_obj)
param_dict = func_sig.parameters
param_type_annot = param_dict[param_name]
init = None
# Check if there is a value in default_constraints_all for func_name and param_name and use it.
default_constraints = default_constraints_all.get(func_name, None)
if default_constraints != None:
other_constraint = default_constraints.get(param_name, None)
if other_constraint is not None:
init = other_constraint
# If parameter had a default then use it otherwise skip.
if init is None and param_type_annot.default is not param_type_annot.empty:
# Checking if not equal to None since default can be 0 or similar.
if param_type_annot.default != None:
init = param_type_annot.default
param_type = param_type_annot.annotation
while get_origin(param_type) in [Union, Optional]:
param_type = get_args(param_type)[0]
# ForwardRef refers to any case where type hint is a string.
if isinstance(param_type, ForwardRef):
param_type = param_type.__forward_arg__
create_obj_func = find_func.get(param_type, default_builder)
if create_obj_func:
namespace[param_name] = create_obj_func(init, param_dtype, namespace)
return namespace[param_name]
Loading