Spec Verification (#25)
1) Create a standard for doc-string dtypes
2) Automatically verify doc-strings' dtypes
    - negative-test any dtypes that are not supported
3) Integrate verification into the test pipeline (L1 for now)
4) Add a README file explaining how to use the verifier/decorator

Side task:

Add support for several dtypes within cast.
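A minimal sketch of what the cast side task enables (assuming `tp.float32` and `tp.int32` are among the supported dtypes; this example is illustrative, not taken from the commit):

```python
import tripy as tp

# Illustrative only: cast a float tensor to an integer dtype.
x = tp.ones(shape=(3, 2), dtype=tp.float32)
y = tp.cast(x, dtype=tp.int32)
```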

---------

Signed-off-by: Mgluhovskoi <[email protected]>
Co-authored-by: pranavm-nvidia <[email protected]>
Co-authored-by: Parth Chadha <[email protected]>
3 people authored Aug 28, 2024
1 parent 35473f3 commit 5c412aa
Showing 43 changed files with 907 additions and 344 deletions.
3 changes: 3 additions & 0 deletions tripy/docs/README.md
@@ -18,6 +18,9 @@ To view the documentation, you can open `build/docs/index.html` in a browser.
The `export.public_api()` decorator allows you to specify metadata for documentation
generation, such as where in the documentation hierarchy the API should be documented.

The `constraints.dtype_info()` decorator verifies the data types a function claims to support and generates
corresponding documentation. For more information, see [this guide](../tests/spec_verification/README.md).

The `generate_rsts.py` script uses this information to automatically generate a directory
structure and populate it with `.rst` files.

63 changes: 38 additions & 25 deletions tripy/docs/conf.py
Expand Up @@ -27,7 +27,7 @@
from tests import helper

import tripy as tp
from tripy.dtype_info import TYPE_VERIFICATION
from tripy.constraints import TYPE_VERIFICATION, FUNC_W_DOC_VERIF

PARAM_PAT = re.compile(":param .*?:")

@@ -161,12 +161,10 @@ def process_docstring(app, what, name, obj, options, lines):
elif param.kind == inspect.Parameter.VAR_POSITIONAL:
pname = "*" + pname

if pname == "self":
# Don't want a type annotation for the self parameter.
if pname != "self" or obj.__qualname__ in FUNC_W_DOC_VERIF:
assert (
param.annotation == signature.empty
), f"Avoid using type annotations for the `self` parameter since this will corrupt the rendered documentation!"
else:
pname in documented_args
), f"Missing documentation for parameter: '{pname}' in: '{obj}'. Please ensure you've included this in the `Args:` section. Note: Documented parameters were: {documented_args} {doc}"
assert (
pname in documented_args
), f"Missing documentation for parameter: '{pname}' in: '{obj}'. Please ensure you've included this in the `Args:` section. Note: Documented parameters were: {documented_args}"
@@ -179,6 +177,10 @@
assert not inspect.ismodule(
param.annotation
), f"Type annotation cannot be a module, but got: '{param.annotation}' for parameter: '{pname}' in: '{obj}'. Please specify a type!"
else:
assert (
param.annotation == signature.empty
), f"Avoid using type annotations for the `self` parameter since this will corrupt the rendered documentation! Note: Documented parameters were: {documented_args} {doc}"

assert signature.return_annotation != signature.empty, (
f"Missing return type annotation for: '{obj}'. "
@@ -190,39 +192,50 @@
":returns:" in doc
), f"For: {obj}, return value is not documented. Please ensure you've included a `Returns:` section"

# new docstring logic:
# first figure out if we should it is the new docstring
if name.split(".")[-1] in TYPE_VERIFICATION.keys():
cleaned_name = name.split(".")[-1]
# New docstring logic:
# First figure out if object is using the @constraints.dtype_info decorator.
unqual_name = name.split(".")[-1]
if unqual_name in TYPE_VERIFICATION.keys():
add_text_index = -1
for index, block in enumerate(blocks):
if re.search(r".. code-block::", block):
type_dict = TYPE_VERIFICATION[cleaned_name][1]["types"]
type_dict = TYPE_VERIFICATION[unqual_name].dtypes
blocks.insert(index, "Type Constraints:")
index += 1
# Add the dtype constraint name and the dtypes that correlate.
for type_name, dt in type_dict.items():
blocks.insert(index, f" - {type_name}: " + ", ".join(dt))
blocks.insert(
index,
f" - **{type_name}**: :class:`" + "`, :class:`".join(set(dt)) + "`",
)
index += 1
blocks.insert(index, "\n")
if TYPE_VERIFICATION[unqual_name].dtype_exceptions != []:
# Add the dtype exceptions.
index += 1
blocks.insert(index, "**Unsupported Type Combinations**:")
dtype_exception_text = []
for exception_dict in TYPE_VERIFICATION[unqual_name].dtype_exceptions:
dtype_exception_text.append(
", ".join([f"{key}: :class:`{val}`" for key, val in exception_dict.items()])
)
dtype_exception_text = "; ".join(dtype_exception_text) + "\n"
index += 1
blocks.insert(index, dtype_exception_text)
break
if re.search(r":param \w+: ", block):
add_text_index = re.search(r":param \w+: ", block).span()[1]
param_name = re.match(r":param (\w+): ", block).group(1)
blocks[index] = (
block[0:add_text_index]
+ "[dtype="
+ TYPE_VERIFICATION[cleaned_name][2][param_name]
+ "] "
+ block[add_text_index:]
)
# Add dtype constraint to start of each parameter description.
if TYPE_VERIFICATION[unqual_name].dtype_constraints.get(param_name, None):
add_text_index = re.search(r":param \w+: ", block).span()[1]
blocks[index] = (
f"{block[0:add_text_index]}[dtype=\ **{TYPE_VERIFICATION[unqual_name].dtype_constraints[param_name]}**\ ] {block[add_text_index:]}"
)
if re.search(r":returns:", block):
add_text_index = re.search(r":returns:", block).span()[1] + 1
# Add dtype constraint to start of returns description.
blocks[index] = (
block[0:add_text_index]
+ "[dtype="
+ list(TYPE_VERIFICATION[cleaned_name][1]["returns"].values())[0]["dtype"]
+ "] "
+ block[add_text_index:]
f"{block[0:add_text_index]}[dtype=\ **{TYPE_VERIFICATION[unqual_name].return_dtype}**\ ] {block[add_text_index:]}"
)

seen_classes.add(name)
1 change: 0 additions & 1 deletion tripy/tests/common/test_utils.py
@@ -26,7 +26,6 @@
import tripy.common.datatype

from tests import helper
from tripy.common.datatype import DATA_TYPES
from tripy.common.exception import TripyException
from tripy.common.utils import (
convert_list_to_array,
1 change: 0 additions & 1 deletion tripy/tests/frontend/trace/ops/test_where.py
@@ -15,7 +15,6 @@
# limitations under the License.
#

import pytest
import re
import tripy as tp
from tests import helper
2 changes: 0 additions & 2 deletions tripy/tests/integration/test_conv_transpose.py
@@ -18,12 +18,10 @@
from collections.abc import Sequence
from dataclasses import dataclass

import cupy as cp
import pytest
import torch

import tripy as tp
from tests import helper

DTYPES = [
(torch.float16, tp.float16),
3 changes: 1 addition & 2 deletions tripy/tests/integration/test_quantize.py
@@ -15,8 +15,7 @@
# limitations under the License.
#

import cupy as cp
import numpy as np

import pytest
import re
import torch
11 changes: 11 additions & 0 deletions tripy/tests/spec_verification/README.md
@@ -0,0 +1,11 @@
# Introduction

Spec verification ensures that the datatypes listed in each operator's documentation are accurate.

# How to Verify an Operation

To run the verification program on an operation, add the `@constraints.dtype_info` decorator to the operation. The decorator's inputs tell the verifier which constraints apply to the operation's inputs and which datatypes to verify.
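A minimal sketch of what applying the decorator might look like (the argument names `dtype_variables` and `dtype_constraints` are assumptions, not the verified signature; see `tripy/constraints.py` for the real API):

```python
from tripy import constraints

# Hypothetical sketch only -- argument names and format are assumptions.
@constraints.dtype_info(
    dtype_variables={"T1": ["float32", "float16"]},  # named group of dtypes to verify
    dtype_constraints={"input": "T1"},               # tie the `input` parameter to that group
)
def my_op(input: "tripy.Tensor") -> "tripy.Tensor": ...
```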

To learn more about how to use `@constraints.dtype_info`, check out `tripy/constraints.py`, `tests/spec_verification/test_dtype_constraints.py`, and `tests/spec_verification/object_builders.py`.

Once the decorator is set up, verification test cases will run automatically alongside the regular test cases. To run only the verifier, execute `pytest -s -v` within the `tests/spec_verification` folder.
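Conceptually, the negative tests check that an operation rejects dtypes it does not document. A simplified sketch of that idea (not the actual test code; `eval()` is assumed to force execution):

```python
import pytest
import tripy as tp

def check_unsupported_dtype(op, dtype):
    # Simplified sketch: calling an op with an undocumented dtype is expected to fail.
    x = tp.ones(shape=(3, 2), dtype=dtype)
    with pytest.raises(Exception):
        op(x).eval()
```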
134 changes: 111 additions & 23 deletions tripy/tests/spec_verification/object_builders.py
@@ -16,40 +16,128 @@
#

import tripy as tp
import math

from typing import Union, Optional, get_origin, get_args, ForwardRef, List
from tripy.common import datatype
import inspect

def tensor_builder(func_obj, input_values, namespace):
shape = input_values.get("shape", None)
if not shape:
shape = (3, 2)
return tp.ones(dtype=namespace[input_values["dtype"]], shape=shape)

def tensor_builder(init, dtype, namespace):
if init is None:
return tp.ones(dtype=namespace[dtype], shape=(3, 2))
elif not isinstance(init, tp.Tensor):
assert dtype == None
return init
return tp.cast(init, dtype=namespace[dtype])

def shape_tensor_builder(func_obj, input_values, namespace):
follow_tensor = input_values.get("follow_tensor", None)
return (math.prod((namespace[follow_tensor]).shape.tolist()),)


def dtype_builder(func_obj, input_values, namespace):
dtype = input_values.get("dtype", None)
def dtype_builder(init, dtype, namespace):
return namespace[dtype]


def int_builder(func_obj, input_values, namespace):
return input_values.get("value", None)
def tensor_list_builder(init, dtype, namespace):
if init is None:
return [tp.ones(shape=(3, 2), dtype=namespace[dtype]) for _ in range(2)]
else:
return [tp.cast(tens, dtype=namespace[dtype]) for tens in init]


def device_builder(init, dtype, namespace):
if init is None:
return tp.device("gpu")
return init


def default_builder(init, dtype, namespace):
return init


find_func = {
"Tensor": tensor_builder,
"shape_tensor": shape_tensor_builder,
"dtype": dtype_builder,
"int": int_builder,
"tripy.Tensor": tensor_builder,
"tripy.Shape": tensor_builder,
"tripy.dtype": dtype_builder,
datatype.dtype: dtype_builder,
List[Union["tripy.Tensor"]]: tensor_list_builder,
"tripy.device": device_builder,
}
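For example, assuming `namespace` maps dtype names to tripy dtype objects (which is how the builders above index it), a lookup and call would look roughly like this (illustrative only, not part of the module):

```python
namespace = {"float32": tp.float32}
builder = find_func.get("tripy.Tensor", default_builder)
tensor = builder(None, "float32", namespace)  # -> tp.ones(dtype=tp.float32, shape=(3, 2))
```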

"""
default_constraints_all: This dictionary helps set specific constraints and values for parameters. These constraints correspond to the type hint of each parameter.
Some type have default values, so you might not need to pass other_constraints for every operation.
If there is no default, you must specify an initialization value, or the testcase may fail.
The dictionary's keys must be the name of the function that they are constraining and the value must be what the parameter should be initialized to.
Here is the list of parameter types that have defaults or work differently from other types:
- tensor - default: tp.ones(shape=(3,2)). If init is passed then value must be in the form of a list. Example: "scale": tp.Tensor([1,1,1]) or "scale": tp.ones((3,3))
- dtype - default: no default. Dtype parameters will be set using dtype_constraints input so using default_constraints_all will not change anything.
- list/sequence of tensors - default: [tp.ones((3,2)),tp.ones((3,2))]. Example: "dim": [tp.ones((2,4)),tp.ones((1,2))].
This will create a list/sequence of tensors of size count and each tensor will follow the init and shape value similar to tensor parameters.
- device - default: tp.device("gpu"). Example: {"device": tp.device("cpu")}.
All other types do not have defaults and must be passed to the verifier using default_constraints_all.
"""
default_constraints_all = {
"__rtruediv__": {"self": 1},
"__rsub__": {"self": 1},
"__radd__": {"self": 1},
"__rpow__": {"self": 1},
"__rmul__": {"self": 1},
"softmax": {"dim": 1},
"concatenate": {"dim": 0},
"expand": {"sizes": tp.Tensor([3, 4]), "input": tp.ones((3, 1))},
"full": {"shape": tp.Tensor([3]), "value": 1},
"full_like": {"value": 1},
"flip": {"dim": 1},
"gather": {"dim": 0, "index": tp.Tensor([1])},
"iota": {"shape": tp.Tensor([3])},
"__matmul__": {"self": tp.ones((2, 3))},
"transpose": {"dim0": 0, "dim1": 1},
"permute": {"perm": [1, 0]},
"quantize": {"scale": tp.Tensor([1, 1, 1]), "dim": 0},
"sum": {"dim": 0},
"all": {"dim": 0},
"any": {"dim": 0},
"max": {"dim": 0},
"prod": {"dim": 0},
"mean": {"dim": 0},
"var": {"dim": 0},
"argmax": {"dim": 0},
"argmin": {"dim": 0},
"reshape": {"shape": tp.Tensor([6])},
"squeeze": {"input": tp.ones((3, 1)), "dims": (1)},
"__getitem__": {"index": 2},
"split": {"indices_or_sections": 2},
"unsqueeze": {"dim": 1},
"masked_fill": {"value": 1},
"ones": {"shape": tp.Tensor([3, 2])},
"zeros": {"shape": tp.Tensor([3, 2])},
"arange": {"start": 0, "stop": 5},
}
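For instance, constraining a new operation's non-dtype arguments just means adding an entry keyed by the function name (the operation and parameters below are hypothetical, for illustration only):

```python
# Hypothetical entry -- "topk" and its parameters are illustrative, not part of this commit.
default_constraints_all["topk"] = {"k": 2, "dim": 0}
```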


def create_obj(func_obj, param_name, input_desc, namespace):
param_type = list(input_desc.keys())[0]
create_obj_func = find_func[param_type]
namespace[param_name] = create_obj_func(func_obj, input_desc[param_type], namespace)
return namespace[param_name]
def create_obj(func_obj, func_name, param_name, param_dtype, namespace):
# If type is an optional or union get the first type.
# Get names and type hints for each param.
func_sig = inspect.signature(func_obj)
param_dict = func_sig.parameters
param_type_annot = param_dict[param_name]
init = None
# Check if there is a value in default_constraints_all for func_name and param_name and use it.
default_constraints = default_constraints_all.get(func_name, None)
if default_constraints != None:
other_constraint = default_constraints.get(param_name, None)
if other_constraint is not None:
init = other_constraint
# If parameter had a default then use it otherwise skip.
if init is None and param_type_annot.default is not param_type_annot.empty:
# Checking if not equal to None since default can be 0 or similar.
if param_type_annot.default != None:
init = param_type_annot.default
param_type = param_type_annot.annotation
while get_origin(param_type) in [Union, Optional]:
param_type = get_args(param_type)[0]
# ForwardRef refers to any case where type hint is a string.
if isinstance(param_type, ForwardRef):
param_type = param_type.__forward_arg__
create_obj_func = find_func.get(param_type, default_builder)
if create_obj_func:
namespace[param_name] = create_obj_func(init, param_dtype, namespace)
return namespace[param_name]