Skip to content

Commit

Permalink
Implemented testing for copy (#19683)
Browse files Browse the repository at this point in the history
Co-authored-by: @AnnaTz
  • Loading branch information
KevinUli authored Oct 19, 2023
1 parent c9f38f1 commit 98d5315
Show file tree
Hide file tree
Showing 21 changed files with 129 additions and 0 deletions.
54 changes: 54 additions & 0 deletions ivy_tests/test_ivy/helpers/function_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,23 @@ def target_fn(instance, *args, **kwargs):
"the array in out argument does not contain same value as the"
" returned"
)
if test_flags.with_copy:
array_fn = ivy_backend.is_array
if "copy" in list(inspect.signature(target_fn).parameters.keys()):
kwargs["copy"] = True
first_array = ivy_backend.func_wrapper._get_first_array(
*args, array_fn=array_fn, **kwargs
)
ret_, ret_np_flat_ = get_ret_and_flattened_np_array(
fw,
target_fn,
*args,
test_trace=test_flags.test_trace,
precision_mode=test_flags.precision_mode,
**kwargs,
)
assert not np.may_share_memory(first_array, ret_)

ret_device = None
if isinstance(ret_from_target, ivy_backend.Array): # TODO use str for now
ret_device = ivy_backend.dev(ret_from_target)
Expand Down Expand Up @@ -451,6 +468,10 @@ def test_function(
"""
_switch_backend_context(test_flags.test_trace or test_flags.transpile)
ground_truth_backend = test_flags.ground_truth_backend

if test_flags.with_copy is True:
test_flags.with_out = False

if mod_backend[backend_to_test]:
# multiprocessing
proc, input_queue, output_queue = mod_backend[backend_to_test]
Expand Down Expand Up @@ -743,6 +764,10 @@ def test_frontend_function(
not test_flags.with_out or not test_flags.inplace
), "only one of with_out or with_inplace can be set as True"

if test_flags.with_copy is True:
test_flags.with_out = False
test_flags.inplace = False

# split the arguments into their positional and keyword components
args_np, kwargs_np = kwargs_to_args_n_kwargs(
num_positional_args=test_flags.num_positional_args, kwargs=all_as_kwargs_np
Expand Down Expand Up @@ -874,6 +899,35 @@ def test_frontend_function(
):
assert ret.ivy_array.data is out.ivy_array.data
assert ret is out
elif test_flags.with_copy:
assert _is_frontend_array(ret)

if "copy" in list(inspect.signature(frontend_fn).parameters.keys()):
copy_kwargs["copy"] = True
first_array = ivy_backend.func_wrapper._get_first_array(
*copy_args,
array_fn=(
_is_frontend_array
if test_flags.generate_frontend_arrays
else ivy_backend.is_array
),
**copy_kwargs,
)
ret_ = get_frontend_ret(
backend_to_test,
frontend_fn,
*copy_args,
test_trace=test_flags.test_trace,
frontend_array_function=(
create_frontend_array if test_flags.test_trace else None
),
precision_mode=test_flags.precision_mode,
**copy_kwargs,
)
if _is_frontend_array(first_array):
first_array = first_array.ivy_array
ret_ = ret_.ivy_array
assert not np.may_share_memory(first_array, ret_)
elif test_flags.inplace:
assert not isinstance(ret, tuple)

Expand Down
13 changes: 13 additions & 0 deletions ivy_tests/test_ivy/helpers/test_parameter_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def _as_varaible_strategy(draw):
BuiltInplaceStrategy = DynamicFlag(st.just(False))
BuiltGradientStrategy = DynamicFlag(_gradient_strategy())
BuiltWithOutStrategy = DynamicFlag(st.booleans())
BuiltWithCopyStrategy = DynamicFlag(st.just(False))
BuiltCompileStrategy = DynamicFlag(st.just(False))
BuiltTraceStrategy = DynamicFlag(st.just(False))
BuiltFrontendArrayStrategy = DynamicFlag(st.booleans())
BuiltTranspileStrategy = DynamicFlag(st.just(False))
Expand All @@ -55,6 +57,7 @@ def _as_varaible_strategy(draw):
"instance_method": "BuiltInstanceStrategy",
"test_gradients": "BuiltGradientStrategy",
"with_out": "BuiltWithOutStrategy",
"with_copy": "BuiltWithCopyStrategy",
"inplace": "BuiltInplace",
"test_trace": "BuiltTraceStrategy",
"transpile": "BuiltTranspileStrategy",
Expand Down Expand Up @@ -86,6 +89,7 @@ def __init__(
ground_truth_backend,
num_positional_args,
with_out,
with_copy,
instance_method,
as_variable,
native_arrays,
Expand All @@ -98,6 +102,7 @@ def __init__(
self.ground_truth_backend = ground_truth_backend
self.num_positional_args = num_positional_args
self.with_out = with_out
self.with_copy = with_copy
self.instance_method = instance_method
self.native_arrays = native_arrays
self.container = container
Expand Down Expand Up @@ -126,6 +131,7 @@ def __str__(self):
f"ground_truth_backend={self.ground_truth_backend}"
f"num_positional_args={self.num_positional_args}. "
f"with_out={self.with_out}. "
f"with_copy={self.with_copy}. "
f"instance_method={self.instance_method}. "
f"native_arrays={self.native_arrays}. "
f"container={self.container}. "
Expand All @@ -148,6 +154,7 @@ def function_flags(
num_positional_args,
instance_method,
with_out,
with_copy,
test_gradients,
test_trace,
transpile,
Expand All @@ -162,6 +169,7 @@ def function_flags(
ground_truth_backend=ground_truth_backend,
num_positional_args=num_positional_args,
with_out=with_out,
with_copy=with_copy,
instance_method=instance_method,
test_gradients=test_gradients,
test_trace=test_trace,
Expand All @@ -179,6 +187,7 @@ def __init__(
self,
num_positional_args,
with_out,
with_copy,
inplace,
as_variable,
native_arrays,
Expand All @@ -189,6 +198,7 @@ def __init__(
):
self.num_positional_args = num_positional_args
self.with_out = with_out
self.with_copy = with_copy
self.inplace = inplace
self.native_arrays = native_arrays
self.as_variable = as_variable
Expand All @@ -213,6 +223,7 @@ def __str__(self):
return (
f"num_positional_args={self.num_positional_args}. "
f"with_out={self.with_out}. "
f"with_copy={self.with_copy}. "
f"inplace={self.inplace}. "
f"native_arrays={self.native_arrays}. "
f"as_variable={self.as_variable}. "
Expand All @@ -232,6 +243,7 @@ def frontend_function_flags(
*,
num_positional_args,
with_out,
with_copy,
inplace,
as_variable,
native_arrays,
Expand All @@ -245,6 +257,7 @@ def frontend_function_flags(
FrontendFunctionTestFlags,
num_positional_args=num_positional_args,
with_out=with_out,
with_copy=with_copy,
inplace=inplace,
as_variable=as_variable,
native_arrays=native_arrays,
Expand Down
13 changes: 13 additions & 0 deletions ivy_tests/test_ivy/helpers/testing_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
BuiltGradientStrategy,
BuiltContainerStrategy,
BuiltWithOutStrategy,
BuiltWithCopyStrategy,
BuiltInplaceStrategy,
BuiltTraceStrategy,
BuiltFrontendArrayStrategy,
Expand Down Expand Up @@ -335,6 +336,7 @@ def handle_test(
number_positional_args=None,
test_instance_method=BuiltInstanceStrategy,
test_with_out=BuiltWithOutStrategy,
test_with_copy=BuiltWithCopyStrategy,
test_gradients=BuiltGradientStrategy,
test_trace=BuiltTraceStrategy,
transpile=BuiltTranspileStrategy,
Expand Down Expand Up @@ -368,6 +370,10 @@ def handle_test(
A search strategy that generates a boolean to test the function with an `out`
parameter
test_with_copy
A search strategy that generates a boolean to test the function with an `copy`
parameter
test_gradients
A search strategy that generates a boolean to test the function with arrays as
gradients
Expand Down Expand Up @@ -408,6 +414,7 @@ def handle_test(
num_positional_args=number_positional_args,
instance_method=_get_runtime_flag_value(test_instance_method),
with_out=_get_runtime_flag_value(test_with_out),
with_copy=_get_runtime_flag_value(test_with_copy),
test_gradients=_get_runtime_flag_value(test_gradients),
test_trace=_get_runtime_flag_value(test_trace),
transpile=_get_runtime_flag_value(transpile),
Expand Down Expand Up @@ -472,6 +479,7 @@ def handle_frontend_test(
aliases: List[str] = None,
number_positional_args=None,
test_with_out=BuiltWithOutStrategy,
test_with_copy=BuiltWithCopyStrategy,
test_inplace=BuiltInplaceStrategy,
as_variable_flags=BuiltAsVariableStrategy,
native_array_flags=BuiltNativeArrayStrategy,
Expand Down Expand Up @@ -505,6 +513,10 @@ def handle_frontend_test(
A search strategy that generates a boolean to test the function with an `out`
parameter
test_with_copy
A search strategy that generates a boolean to test the function with an `copy`
parameter
precision_mode
A search strategy that generates a boolean to switch between two different
precision modes supported by numpy and (torch, jax) and test the function
Expand Down Expand Up @@ -539,6 +551,7 @@ def handle_frontend_test(
test_flags = pf.frontend_function_flags(
num_positional_args=number_positional_args,
with_out=_get_runtime_flag_value(test_with_out),
with_copy=_get_runtime_flag_value(test_with_copy),
inplace=_get_runtime_flag_value(test_inplace),
as_variable=_get_runtime_flag_value(as_variable_flags),
native_arrays=_get_runtime_flag_value(native_array_flags),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def test_jax_arange(
copy=st.booleans(),
ndmin=helpers.ints(min_value=0, max_value=9),
test_with_out=st.just(True),
test_with_copy=st.just(True),
)
def test_jax_array(
*,
Expand Down Expand Up @@ -276,6 +277,7 @@ def test_jax_compress(
max_dim_size=5,
),
test_with_out=st.just(False),
test_with_copy=st.just(True),
)
def test_jax_copy(
dtype_and_a,
Expand Down Expand Up @@ -825,6 +827,7 @@ def test_jax_logspace(
sparse=st.booleans(),
indexing=st.sampled_from(["xy", "ij"]),
test_with_out=st.just(False),
test_with_copy=st.just(True),
)
def test_jax_meshgrid(
dtype_and_arrays,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2271,6 +2271,7 @@ def test_jax_multiply(
posinf=st.floats(min_value=5e100, max_value=5e100),
neginf=st.floats(min_value=-5e100, max_value=-5e100),
test_with_out=st.just(False),
test_with_copy=st.just(True),
)
def test_jax_nan_to_num(
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
ndmin=st.integers(min_value=0, max_value=5),
copy=st.booleans(),
test_with_out=st.just(False),
test_with_copy=st.just(True),
)
def test_mindspore_array(
dtype_and_a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
max_dim_size=5,
),
test_with_out=st.just(False),
test_with_copy=st.just(True),
)
def test_numpy_array(
dtype_and_a,
Expand Down Expand Up @@ -85,6 +86,7 @@ def test_numpy_asarray(
max_dim_size=5,
),
test_with_out=st.just(False),
test_with_copy=st.just(True),
)
def test_numpy_copy(
dtype_and_a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def test_numpy_logspace(
sparse=st.booleans(),
indexing=st.sampled_from(["xy", "ij"]),
test_with_out=st.just(False),
test_with_copy=st.just(True),
)
def test_numpy_meshgrid(
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,7 @@ def test_numpy_lcm(
nan=st.floats(min_value=0, max_value=10),
copy=st.booleans(),
test_with_out=st.just(False),
test_with_copy=st.just(True),
)
def test_numpy_nan_to_num(
dtype_and_x,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def test_paddle_assign(
@handle_frontend_test(
fn_tree="paddle.clone",
dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")),
test_with_copy=st.just(True),
)
def test_paddle_clone(
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid"),
),
test_with_copy=st.just(True),
)
def test_sklearn_as_float_array(
dtype_and_x,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,7 @@ def test_tensorflow_gather_nd(
available_dtypes=helpers.get_dtypes("numeric"),
),
test_with_out=st.just(False),
test_with_copy=st.just(True),
)
def test_tensorflow_identity(
dtype_and_x,
Expand Down Expand Up @@ -1078,6 +1079,7 @@ def test_tensorflow_identity(
available_dtypes=helpers.get_dtypes("valid"), max_num_dims=5
),
test_with_out=st.just(False),
test_with_copy=st.just(True),
)
def test_tensorflow_identity_n(
dtype_and_x,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ def test_torch_as_tensor(
available_dtypes=helpers.get_dtypes("numeric")
),
dtype=helpers.get_dtypes("numeric", full=False),
test_with_copy=st.just(True),
)
def test_torch_asarray(
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,7 @@ def test_torch_cartesian_prod(
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid"),
),
test_with_copy=st.just(True),
)
def test_torch_clone(
*,
Expand Down Expand Up @@ -1024,6 +1025,7 @@ def test_torch_flatten(
shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"),
force_tuple=True,
),
test_with_copy=st.just(True),
)
def test_torch_flip(
*,
Expand Down Expand Up @@ -1055,6 +1057,7 @@ def test_torch_flip(
available_dtypes=helpers.get_dtypes("float"),
shape=helpers.get_shape(min_num_dims=2),
),
test_with_copy=st.just(True),
)
def test_torch_fliplr(
*,
Expand Down Expand Up @@ -1084,6 +1087,7 @@ def test_torch_fliplr(
available_dtypes=helpers.get_dtypes("float"),
shape=helpers.get_shape(min_num_dims=1),
),
test_with_copy=st.just(True),
)
def test_torch_flipud(
*,
Expand Down
Loading

0 comments on commit 98d5315

Please sign in to comment.