Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CLI flags in testing pipeline #22788

Merged
merged 5 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions ivy_tests/test_ivy/helpers/test_parameter_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@
from . import globals as test_globals
from .pipeline_helper import BackendHandler

from dataclasses import dataclass
from hypothesis.strategies import SearchStrategy


@dataclass
class DynamicFlag:
strategy: SearchStrategy


@st.composite
def _gradient_strategy(draw):
Expand All @@ -27,17 +35,17 @@ def _as_varaible_strategy(draw):
return draw(st.lists(st.booleans(), min_size=1, max_size=1))


BuiltNativeArrayStrategy = st.lists(st.booleans(), min_size=1, max_size=1)
BuiltAsVariableStrategy = _as_varaible_strategy()
BuiltContainerStrategy = st.lists(st.booleans(), min_size=1, max_size=1)
BuiltInstanceStrategy = st.booleans()
BuiltInplaceStrategy = st.just(False)
BuiltGradientStrategy = _gradient_strategy()
BuiltWithOutStrategy = st.booleans()
BuiltCompileStrategy = st.just(False)
BuiltFrontendArrayStrategy = st.booleans()
BuiltTranspileStrategy = st.just(False)
BuiltPrecisionModeStrategy = st.booleans()
BuiltNativeArrayStrategy = DynamicFlag(st.lists(st.booleans(), min_size=1, max_size=1))
BuiltAsVariableStrategy = DynamicFlag(_as_varaible_strategy())
BuiltContainerStrategy = DynamicFlag(st.lists(st.booleans(), min_size=1, max_size=1))
BuiltInstanceStrategy = DynamicFlag(st.booleans())
BuiltInplaceStrategy = DynamicFlag(st.just(False))
BuiltGradientStrategy = DynamicFlag(_gradient_strategy())
BuiltWithOutStrategy = DynamicFlag(st.booleans())
BuiltCompileStrategy = DynamicFlag(st.booleans())
BuiltFrontendArrayStrategy = DynamicFlag(st.booleans())
BuiltTranspileStrategy = DynamicFlag(st.just(False))
BuiltPrecisionModeStrategy = DynamicFlag(st.booleans())


flags_mapping = {
Expand All @@ -61,7 +69,7 @@ def build_flag(key: str, value: bool):
assert (
flags_mapping[key] in globals().keys()
), f"{flags_mapping[key]} is not a valid flag variable."
globals()[flags_mapping[key]] = value
globals()[flags_mapping[key]].strategy = value


# Strategy Helpers #
Expand Down
73 changes: 39 additions & 34 deletions ivy_tests/test_ivy/helpers/testing_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from . import test_globals as t_globals
from .pipeline_helper import BackendHandler
from ivy_tests.test_ivy.helpers.test_parameter_flags import (
DynamicFlag,
BuiltInstanceStrategy,
BuiltAsVariableStrategy,
BuiltNativeArrayStrategy,
Expand Down Expand Up @@ -51,6 +52,10 @@
)


def _get_runtime_flag_value(flag):
return flag.strategy if isinstance(flag, DynamicFlag) else flag


@st.composite
def num_positional_args_method(draw, *, method):
"""
Expand Down Expand Up @@ -395,14 +400,14 @@ def handle_test(
possible_arguments["test_flags"] = pf.function_flags(
ground_truth_backend=st.just(ground_truth_backend),
num_positional_args=number_positional_args,
instance_method=test_instance_method,
with_out=test_with_out,
test_gradients=test_gradients,
test_compile=test_compile,
as_variable=as_variable_flags,
native_arrays=native_array_flags,
container_flags=container_flags,
precision_mode=precision_mode,
instance_method=_get_runtime_flag_value(test_instance_method),
with_out=_get_runtime_flag_value(test_with_out),
test_gradients=_get_runtime_flag_value(test_gradients),
test_compile=_get_runtime_flag_value(test_compile),
as_variable=_get_runtime_flag_value(as_variable_flags),
native_arrays=_get_runtime_flag_value(native_array_flags),
container_flags=_get_runtime_flag_value(container_flags),
precision_mode=_get_runtime_flag_value(precision_mode),
)

def test_wrapper(test_fn):
Expand Down Expand Up @@ -526,14 +531,14 @@ def handle_frontend_test(
# Generate the test flags strategy
test_flags = pf.frontend_function_flags(
num_positional_args=number_positional_args,
with_out=test_with_out,
inplace=test_inplace,
as_variable=as_variable_flags,
native_arrays=native_array_flags,
test_compile=test_compile,
generate_frontend_arrays=generate_frontend_arrays,
transpile=transpile,
precision_mode=precision_mode,
with_out=_get_runtime_flag_value(test_with_out),
inplace=_get_runtime_flag_value(test_inplace),
as_variable=_get_runtime_flag_value(as_variable_flags),
native_arrays=_get_runtime_flag_value(native_array_flags),
test_compile=_get_runtime_flag_value(test_compile),
generate_frontend_arrays=_get_runtime_flag_value(generate_frontend_arrays),
transpile=_get_runtime_flag_value(transpile),
precision_mode=_get_runtime_flag_value(precision_mode),
)

def test_wrapper(test_fn):
Expand Down Expand Up @@ -635,9 +640,9 @@ def handle_method(
is_hypothesis_test = len(_given_kwargs) != 0
possible_arguments = {
"ground_truth_backend": st.just(ground_truth_backend),
"test_gradients": test_gradients,
"test_compile": test_compile,
"precision_mode": precision_mode,
"test_gradients": _get_runtime_flag_value(test_gradients),
"test_compile": _get_runtime_flag_value(test_compile),
"precision_mode": _get_runtime_flag_value(precision_mode),
}

if is_hypothesis_test and is_method_tree_provided:
Expand All @@ -650,9 +655,9 @@ def handle_method(

possible_arguments["init_flags"] = pf.init_method_flags(
num_positional_args=init_num_positional_args,
as_variable=init_as_variable_flags,
native_arrays=init_native_arrays,
precision_mode=precision_mode,
as_variable=_get_runtime_flag_value(init_as_variable_flags),
native_arrays=_get_runtime_flag_value(init_native_arrays),
precision_mode=_get_runtime_flag_value(precision_mode),
)

if method_num_positional_args is None:
Expand All @@ -662,10 +667,10 @@ def handle_method(

possible_arguments["method_flags"] = pf.method_flags(
num_positional_args=method_num_positional_args,
as_variable=method_as_variable_flags,
native_arrays=method_native_arrays,
container_flags=method_container_flags,
precision_mode=precision_mode,
as_variable=_get_runtime_flag_value(method_as_variable_flags),
native_arrays=_get_runtime_flag_value(method_native_arrays),
container_flags=_get_runtime_flag_value(method_container_flags),
precision_mode=_get_runtime_flag_value(precision_mode),
)

def test_wrapper(test_fn):
Expand Down Expand Up @@ -783,18 +788,18 @@ def test_wrapper(test_fn):
param_names = inspect.signature(test_fn).parameters.keys()
init_flags = pf.frontend_method_flags(
num_positional_args=init_num_positional_args,
as_variable=init_as_variable_flags,
native_arrays=init_native_arrays,
test_compile=test_compile,
precision_mode=precision_mode,
as_variable=_get_runtime_flag_value(init_as_variable_flags),
native_arrays=_get_runtime_flag_value(init_native_arrays),
test_compile=_get_runtime_flag_value(test_compile),
precision_mode=_get_runtime_flag_value(precision_mode),
)

method_flags = pf.frontend_method_flags(
num_positional_args=method_num_positional_args,
as_variable=method_as_variable_flags,
native_arrays=method_native_arrays,
test_compile=test_compile,
precision_mode=precision_mode,
as_variable=_get_runtime_flag_value(method_as_variable_flags),
native_arrays=_get_runtime_flag_value(method_native_arrays),
test_compile=_get_runtime_flag_value(test_compile),
precision_mode=_get_runtime_flag_value(precision_mode),
)
ivy_init_modules = str(ivy_init_module)
framework_init_modules = str(framework_init_module)
Expand Down
Loading