diff --git a/tripy/tests/spec_verification/test_dtype_constraints.py b/tripy/tests/spec_verification/test_dtype_constraints.py index 0f5c41429..1bef0b863 100755 --- a/tripy/tests/spec_verification/test_dtype_constraints.py +++ b/tripy/tests/spec_verification/test_dtype_constraints.py @@ -17,14 +17,14 @@ import inspect -from typing import List +from typing import List, Union, Optional, get_origin, get_args, ForwardRef, get_type_hints from tripy.common.datatype import DATA_TYPES import itertools import pytest from tests.spec_verification.object_builders import create_obj -from tripy.constraints import TYPE_VERIFICATION, RETURN_VALUE +from tripy.constraints import TYPE_VERIFICATION, RETURN_VALUE, FUNC_W_DOC_VERIF import tripy as tp -from contextlib import ExitStack +import tests.helper def _method_handler(func_name, kwargs, func_obj, api_call_locals): @@ -181,3 +181,77 @@ def test_neg_dtype_constraints(test_data): api_call_locals, namespace = _run_dtype_constraints_subtest(test_data) if isinstance(api_call_locals[RETURN_VALUE], tp.Tensor): assert api_call_locals[RETURN_VALUE].dtype == namespace[return_dtype] + + +def get_all_possible_verif_ops(): + qualnames = set() + tripy_interfaces = tests.helper.get_all_tripy_interfaces() + for obj in tripy_interfaces: + if not obj.__doc__: + continue + + blocks = [ + (block.code()) + for block in tests.helper.consolidate_code_blocks(obj.__doc__) + if isinstance(block, tests.helper.DocstringCodeBlock) + ] + if blocks is None: + continue + + if isinstance(obj, property): + continue + + func_sig = inspect.signature(func_obj) + param_dict = func_sig.parameters + contains_tensor_input = False + for type_hint in param_dict.values(): + type_hint = type_hint.annotation + while get_origin(type_hint) in [Union, Optional, list] and not contains_tensor_input: + type_hint = get_args(type_hint)[0] + # ForwardRef refers to any case where type hint is a string. + if isinstance(type_hint, ForwardRef): + type_hint = type_hint.__forward_arg__ + if type_hint == "tripy.Tensor": + print(type_hint) + contains_tensor_input = True + + if not contains_tensor_input: + continue + + qualnames.add(obj.__qualname__) + + return qualnames + + +print(get_all_possible_verif_ops()) +print(len(get_all_possible_verif_ops())) + +operations = get_all_possible_verif_ops() +# add any function that you do not want to be verified: +func_exceptions = [ + "plugin", + "dequantize", + "default", + "dtype", + "function", + "type", + "tolist", + "md5", + "integer", + "volume", + "save", + "floating", + "load", + "device", +] + + +# Check if there are any operations that are not included (Currently does not test any ____ functions) +@pytest.mark.parametrize("func_qualname", operations, ids=lambda val: f"is_{val}_verified") +def test_all_ops_verified(func_qualname): + if func_qualname not in func_exceptions: + assert ( + func_qualname in FUNC_W_DOC_VERIF + ), f"function {func_qualname}'s data types have not been verified. Please add data type verification by following the guide within tripy/tests/spec_verification or exclude it from this test." + else: + pytest.skip("Data type constraints are not required for this API")