From 7e00f52a0dfa942cd62db063a10ababea08ede94 Mon Sep 17 00:00:00 2001 From: Sai-Suraj-27 Date: Wed, 27 Sep 2023 21:35:02 +0530 Subject: [PATCH] refactor: Refactored `test_assertions.py` file. (#23584) --- .../test_ivy/test_misc/test_assertions.py | 485 +++++++++--------- 1 file changed, 242 insertions(+), 243 deletions(-) diff --git a/ivy_tests/test_ivy/test_misc/test_assertions.py b/ivy_tests/test_ivy/test_misc/test_assertions.py index e3c9bf96dbac8..b6e4c948c1dea 100644 --- a/ivy_tests/test_ivy/test_misc/test_assertions.py +++ b/ivy_tests/test_ivy/test_misc/test_assertions.py @@ -40,15 +40,15 @@ def test_check_all(results): filename = "except_out.txt" orig_stdout = sys.stdout - f = open(filename, "w") - sys.stdout = f - lines = "" - try: - check_all(results) - except Exception as e: - print(e) - sys.stdout = orig_stdout - f.close() + + with open(filename, "w") as f: + sys.stdout = f + lines = "" + try: + check_all(results) + except Exception as e: + print(e) + sys.stdout = orig_stdout with open(filename) as f: lines += f.read() @@ -81,23 +81,23 @@ def test_check_all(results): def test_check_all_or_any_fn(args, fn, type, limit): filename = "except_out.txt" orig_stdout = sys.stdout - f = open(filename, "w") - sys.stdout = f - lines = "" - try: - check_all_or_any_fn(*args, fn=fn, type=type, limit=limit) - local_vars = {**locals()} - except Exception as e: - local_vars = {**locals()} - print(e) - sys.stdout = orig_stdout - f.close() + + with open(filename, "w") as f: + sys.stdout = f + lines = "" + try: + check_all_or_any_fn(*args, fn=fn, type=type, limit=limit) + local_vars = {**locals()} + except Exception as e: + local_vars = {**locals()} + print(e) + sys.stdout = orig_stdout with open(filename) as f: lines += f.read() if type in ["all", "any"]: - if "e" in local_vars.keys(): + if "e" in local_vars: assert "args must exist according to" in lines.strip() else: assert not lines.strip() @@ -116,15 +116,15 @@ def test_check_all_or_any_fn(args, fn, type, limit): def test_check_any(results): filename = "except_out.txt" orig_stdout = sys.stdout - f = open(filename, "w") - sys.stdout = f - lines = "" - try: - check_any(results) - except Exception as e: - print(e) - sys.stdout = orig_stdout - f.close() + + with open(filename, "w") as f: + sys.stdout = f + lines = "" + try: + check_any(results) + except Exception as e: + print(e) + sys.stdout = orig_stdout with open(filename) as f: lines += f.read() @@ -171,26 +171,26 @@ def test_check_dev_correct_formatting(device): def test_check_dimensions(x): filename = "except_out.txt" orig_stdout = sys.stdout - f = open(filename, "w") - sys.stdout = f - lines = "" - try: - check_dimensions(x) - local_vars = {**locals()} - except Exception as e: - local_vars = {**locals()} - print(e) - - sys.stdout = orig_stdout - f.close() + + with open(filename, "w") as f: + sys.stdout = f + lines = "" + try: + check_dimensions(x) + local_vars = {**locals()} + except Exception as e: + local_vars = {**locals()} + print(e) + + sys.stdout = orig_stdout with open(filename) as f: lines += f.read() - if "e" in local_vars.keys(): + if "e" in local_vars: assert "greater than one dimension" in lines.strip() - if "e" not in local_vars.keys(): + if "e" not in local_vars: assert not lines.strip() with contextlib.suppress(FileNotFoundError): @@ -209,15 +209,15 @@ def test_check_dimensions(x): def test_check_elem_in_list(elem, list, inverse): filename = "except_out.txt" orig_stdout = sys.stdout - f = open(filename, "w") - sys.stdout = f - lines = "" - try: - check_elem_in_list(elem, list, inverse) - except Exception as e: - print(e) - sys.stdout = orig_stdout - f.close() + + with open(filename, "w") as f: + sys.stdout = f + lines = "" + try: + check_elem_in_list(elem, list, inverse) + except Exception as e: + print(e) + sys.stdout = orig_stdout with open(filename) as f: lines += f.read() @@ -232,7 +232,6 @@ def test_check_elem_in_list(elem, list, inverse): if inverse: if elem not in list: assert not lines.strip() - if elem in list: assert "must not be one" in lines.strip() @@ -252,15 +251,15 @@ def test_check_elem_in_list(elem, list, inverse): def test_check_equal(x1, x2, inverse): filename = "except_out.txt" orig_stdout = sys.stdout - f = open(filename, "w") - sys.stdout = f - lines = "" - try: - check_equal(x1, x2, inverse) - except Exception as e: - print(e) - sys.stdout = orig_stdout - f.close() + + with open(filename, "w") as f: + sys.stdout = f + lines = "" + try: + check_equal(x1, x2, inverse) + except Exception as e: + print(e) + sys.stdout = orig_stdout with open(filename) as f: lines += f.read() @@ -290,15 +289,15 @@ def test_check_equal(x1, x2, inverse): def test_check_exists(x, inverse): filename = "except_out.txt" orig_stdout = sys.stdout - f = open(filename, "w") - sys.stdout = f - lines = "" - try: - check_exists(x, inverse) - except Exception as e: - print(e) - sys.stdout = orig_stdout - f.close() + + with open(filename, "w") as f: + sys.stdout = f + lines = "" + try: + check_exists(x, inverse) + except Exception as e: + print(e) + sys.stdout = orig_stdout with open(filename) as f: lines += f.read() @@ -333,15 +332,15 @@ def test_check_exists(x, inverse): def test_check_false(expression): filename = "except_out.txt" orig_stdout = sys.stdout - f = open(filename, "w") - sys.stdout = f - lines = "" - try: - check_false(expression) - except Exception as e: - print(e) - sys.stdout = orig_stdout - f.close() + + with open(filename, "w") as f: + sys.stdout = f + lines = "" + try: + check_false(expression) + except Exception as e: + print(e) + sys.stdout = orig_stdout with open(filename) as f: lines += f.read() @@ -373,29 +372,29 @@ def test_check_false(expression): def test_check_gather_input_valid(params, indices, axis, batch_dims): filename = "except_out.txt" orig_stdout = sys.stdout - f = open(filename, "w") - sys.stdout = f - lines = "" - try: - check_gather_input_valid(params, indices, axis, batch_dims) - local_vars = {**locals()} - except Exception as e: - local_vars = {**locals()} - print(e) - - sys.stdout = orig_stdout - f.close() + + with open(filename, "w") as f: + sys.stdout = f + lines = "" + try: + check_gather_input_valid(params, indices, axis, batch_dims) + local_vars = {**locals()} + except Exception as e: + local_vars = {**locals()} + print(e) + + sys.stdout = orig_stdout with open(filename) as f: lines += f.read() - if "e" in local_vars.keys(): + if "e" in local_vars: assert ( "must be less than or equal" in lines.strip() or "batch dimensions must match in" in lines.strip() ) - if "e" not in local_vars.keys(): + if "e" not in local_vars: assert not lines.strip() with contextlib.suppress(FileNotFoundError): @@ -419,23 +418,23 @@ def test_check_gather_input_valid(params, indices, axis, batch_dims): def test_check_gather_nd_input_valid(params, indices, batch_dims): filename = "except_out.txt" orig_stdout = sys.stdout - f = open(filename, "w") - sys.stdout = f - lines = "" - try: - check_gather_nd_input_valid(params, indices, batch_dims) - local_vars = {**locals()} - except Exception as e: - local_vars = {**locals()} - print(e) - - sys.stdout = orig_stdout - f.close() + + with open(filename, "w") as f: + sys.stdout = f + lines = "" + try: + check_gather_nd_input_valid(params, indices, batch_dims) + local_vars = {**locals()} + except Exception as e: + local_vars = {**locals()} + print(e) + + sys.stdout = orig_stdout with open(filename) as f: lines += f.read() - if "e" in local_vars.keys(): + if "e" in local_vars: assert ( "less than rank(`params`)" in lines.strip() or "less than rank(`indices`)" in lines.strip() @@ -443,7 +442,7 @@ def test_check_gather_nd_input_valid(params, indices, batch_dims): or "index innermost dimension length must be <=" in lines.strip() ) - if "e" not in local_vars.keys(): + if "e" not in local_vars: assert not lines.strip() with contextlib.suppress(FileNotFoundError): @@ -462,15 +461,15 @@ def test_check_gather_nd_input_valid(params, indices, batch_dims): def test_check_greater(x1, x2, allow_equal): filename = "except_out.txt" orig_stdout = sys.stdout - f = open(filename, "w") - sys.stdout = f - lines = "" - try: - check_greater(x1, x2, allow_equal) - except Exception as e: - print(e) - sys.stdout = orig_stdout - f.close() + + with open(filename, "w") as f: + sys.stdout = f + lines = "" + try: + check_greater(x1, x2, allow_equal) + except Exception as e: + print(e) + sys.stdout = orig_stdout with open(filename) as f: lines += f.read() @@ -502,26 +501,26 @@ def test_check_greater(x1, x2, allow_equal): def test_check_inplace_sizes_valid(var, data): filename = "except_out.txt" orig_stdout = sys.stdout - f = open(filename, "w") - sys.stdout = f - lines = "" - try: - check_inplace_sizes_valid(var, data) - local_vars = {**locals()} - except Exception as e: - local_vars = {**locals()} - print(e) - - sys.stdout = orig_stdout - f.close() + + with open(filename, "w") as f: + sys.stdout = f + lines = "" + try: + check_inplace_sizes_valid(var, data) + local_vars = {**locals()} + except Exception as e: + local_vars = {**locals()} + print(e) + + sys.stdout = orig_stdout with open(filename) as f: lines += f.read() - if "e" in local_vars.keys(): + if "e" in local_vars: assert "Could not output values of shape" in lines.strip() - if "e" not in local_vars.keys(): + if "e" not in local_vars: assert not lines.strip() with contextlib.suppress(FileNotFoundError): @@ -535,15 +534,15 @@ def test_check_inplace_sizes_valid(var, data): def test_check_isinstance(x, allowed_types): filename = "except_out.txt" orig_stdout = sys.stdout - f = open(filename, "w") - sys.stdout = f - lines = "" - try: - check_isinstance(x, allowed_types) - except Exception as e: - print(e) - sys.stdout = orig_stdout - f.close() + + with open(filename, "w") as f: + sys.stdout = f + lines = "" + try: + check_isinstance(x, allowed_types) + except Exception as e: + print(e) + sys.stdout = orig_stdout with open(filename) as f: lines += f.read() @@ -576,26 +575,26 @@ def test_check_isinstance(x, allowed_types): def test_check_jax_x64_flag(dtype): filename = "except_out.txt" orig_stdout = sys.stdout - f = open(filename, "w") - sys.stdout = f - lines = "" - try: - _check_jax_x64_flag(dtype) - local_vars = {**locals()} - except Exception as e: - local_vars = {**locals()} - print(e) - - sys.stdout = orig_stdout - f.close() + + with open(filename, "w") as f: + sys.stdout = f + lines = "" + try: + _check_jax_x64_flag(dtype) + local_vars = {**locals()} + except Exception as e: + local_vars = {**locals()} + print(e) + + sys.stdout = orig_stdout with open(filename) as f: lines += f.read() - if "e" in local_vars.keys(): + if "e" in local_vars: assert "output not supported while jax_enable_x64" in lines.strip() - if "e" not in local_vars.keys(): + if "e" not in local_vars: assert not lines.strip() with contextlib.suppress(FileNotFoundError): @@ -616,26 +615,26 @@ def test_check_jax_x64_flag(dtype): def test_check_kernel_padding_size(kernel_size, padding_size): filename = "except_out.txt" orig_stdout = sys.stdout - f = open(filename, "w") - sys.stdout = f - lines = "" - try: - check_kernel_padding_size(kernel_size, padding_size) - local_vars = {**locals()} - except Exception as e: - local_vars = {**locals()} - print(e) - - sys.stdout = orig_stdout - f.close() + + with open(filename, "w") as f: + sys.stdout = f + lines = "" + try: + check_kernel_padding_size(kernel_size, padding_size) + local_vars = {**locals()} + except Exception as e: + local_vars = {**locals()} + print(e) + + sys.stdout = orig_stdout with open(filename) as f: lines += f.read() - if "e" in local_vars.keys(): + if "e" in local_vars: assert "less than or equal to half" in lines.strip() - if "e" not in local_vars.keys(): + if "e" not in local_vars: assert not lines.strip() with contextlib.suppress(FileNotFoundError): @@ -654,15 +653,15 @@ def test_check_kernel_padding_size(kernel_size, padding_size): def test_check_less(x1, x2, allow_equal): filename = "except_out.txt" orig_stdout = sys.stdout - f = open(filename, "w") - sys.stdout = f - lines = "" - try: - check_less(x1, x2, allow_equal) - except Exception as e: - print(e) - sys.stdout = orig_stdout - f.close() + + with open(filename, "w") as f: + sys.stdout = f + lines = "" + try: + check_less(x1, x2, allow_equal) + except Exception as e: + print(e) + sys.stdout = orig_stdout with open(filename) as f: lines += f.read() @@ -692,26 +691,26 @@ def test_check_less(x1, x2, allow_equal): def test_check_same_dtype(x1, x2): filename = "except_out.txt" orig_stdout = sys.stdout - f = open(filename, "w") - sys.stdout = f - lines = "" - try: - check_same_dtype(x1, x2) - local_vars = {**locals()} - except Exception as e: - local_vars = {**locals()} - print(e) - - sys.stdout = orig_stdout - f.close() + + with open(filename, "w") as f: + sys.stdout = f + lines = "" + try: + check_same_dtype(x1, x2) + local_vars = {**locals()} + except Exception as e: + local_vars = {**locals()} + print(e) + + sys.stdout = orig_stdout with open(filename) as f: lines += f.read() - if "e" in local_vars.keys(): + if "e" in local_vars: assert "same dtype" in lines.strip() - if "e" not in local_vars.keys(): + if "e" not in local_vars: assert not lines.strip() with contextlib.suppress(FileNotFoundError): @@ -731,26 +730,26 @@ def test_check_same_dtype(x1, x2): def test_check_shape(x1, x2): filename = "except_out.txt" orig_stdout = sys.stdout - f = open(filename, "w") - sys.stdout = f - lines = "" - try: - check_shape(x1, x2) - local_vars = {**locals()} - except Exception as e: - local_vars = {**locals()} - print(e) - - sys.stdout = orig_stdout - f.close() + + with open(filename, "w") as f: + sys.stdout = f + lines = "" + try: + check_shape(x1, x2) + local_vars = {**locals()} + except Exception as e: + local_vars = {**locals()} + print(e) + + sys.stdout = orig_stdout with open(filename) as f: lines += f.read() - if "e" in local_vars.keys(): + if "e" in local_vars: assert "same shape" in lines.strip() - if "e" not in local_vars.keys(): + if "e" not in local_vars: assert not lines.strip() with contextlib.suppress(FileNotFoundError): @@ -771,26 +770,26 @@ def test_check_shape(x1, x2): def test_check_shapes_broadcastable(var, data): filename = "except_out.txt" orig_stdout = sys.stdout - f = open(filename, "w") - sys.stdout = f - lines = "" - try: - check_shapes_broadcastable(var, data) - local_vars = {**locals()} - except Exception as e: - local_vars = {**locals()} - print(e) - - sys.stdout = orig_stdout - f.close() + + with open(filename, "w") as f: + sys.stdout = f + lines = "" + try: + check_shapes_broadcastable(var, data) + local_vars = {**locals()} + except Exception as e: + local_vars = {**locals()} + print(e) + + sys.stdout = orig_stdout with open(filename) as f: lines += f.read() - if "e" in local_vars.keys(): + if "e" in local_vars: assert "Could not broadcast shape" in lines.strip() - if "e" not in local_vars.keys(): + if "e" not in local_vars: assert not lines.strip() with contextlib.suppress(FileNotFoundError): @@ -809,15 +808,15 @@ def test_check_shapes_broadcastable(var, data): def test_check_true(expression): filename = "except_out.txt" orig_stdout = sys.stdout - f = open(filename, "w") - sys.stdout = f - lines = "" - try: - check_true(expression) - except Exception as e: - print(e) - sys.stdout = orig_stdout - f.close() + + with open(filename, "w") as f: + sys.stdout = f + lines = "" + try: + check_true(expression) + except Exception as e: + print(e) + sys.stdout = orig_stdout with open(filename) as f: lines += f.read() @@ -856,23 +855,23 @@ def test_check_true(expression): def test_check_unsorted_segment_min_valid_params(data, segment_ids, num_segments): filename = "except_out.txt" orig_stdout = sys.stdout - f = open(filename, "w") - sys.stdout = f - lines = "" - try: - check_unsorted_segment_min_valid_params(data, segment_ids, num_segments) - local_vars = {**locals()} - except Exception as e: - local_vars = {**locals()} - print(e) - - sys.stdout = orig_stdout - f.close() + + with open(filename, "w") as f: + sys.stdout = f + lines = "" + try: + check_unsorted_segment_min_valid_params(data, segment_ids, num_segments) + local_vars = {**locals()} + except Exception as e: + local_vars = {**locals()} + print(e) + + sys.stdout = orig_stdout with open(filename) as f: lines += f.read() - if "e" in local_vars.keys(): + if "e" in local_vars: assert ( "num_segments must be of integer type" in lines.strip() or "segment_ids must have an integer dtype" in lines.strip() @@ -881,7 +880,7 @@ def test_check_unsorted_segment_min_valid_params(data, segment_ids, num_segments or "num_segments must be positive" in lines.strip() ) - if "e" not in local_vars.keys(): + if "e" not in local_vars: assert not lines.strip() with contextlib.suppress(FileNotFoundError):