diff --git a/ivy/functional/backends/paddle/random.py b/ivy/functional/backends/paddle/random.py
index 15c10f366a1a7..7e84d73766be4 100644
--- a/ivy/functional/backends/paddle/random.py
+++ b/ivy/functional/backends/paddle/random.py
@@ -6,7 +6,6 @@
 import ivy.functional.backends.paddle as paddle_backend
 from typing import Optional, Union, Sequence
 
-# local
 import ivy
 from paddle.device import core
 from ivy.functional.ivy.random import (
@@ -45,7 +44,6 @@ def random_uniform(
     low = paddle.cast(low, "float32") if isinstance(low, paddle.Tensor) else low
     high = paddle.cast(high, "float32") if isinstance(high, paddle.Tensor) else high
     shape = _check_bounds_and_get_shape(low, high, shape).shape
-    # Set range and seed
     rng = high - low
     if seed:
         _ = paddle.seed(seed)
@@ -57,7 +55,8 @@
 
 
 @with_unsupported_dtypes(
-    {"2.6.0 and below": ("float16", "int16", "int8")}, backend_version
+    {"2.6.0 and below": ("float16", "int16", "int8")},
+    backend_version,
 )
 def random_normal(
     *,
@@ -155,10 +154,67 @@ def shuffle(
 ) -> paddle.Tensor:
     if seed:
         _ = paddle.seed(seed)
-    # Use Paddle's randperm function to generate shuffled indices
     indices = paddle.randperm(x.ndim, dtype="int64")
     if paddle.is_complex(x):
         shuffled_real = paddle.index_select(x.real(), indices, axis=axis)
         shuffled_imag = paddle.index_select(x.imag(), indices, axis=axis)
         return paddle.complex(shuffled_real, shuffled_imag)
     return paddle.index_select(x, indices, axis=axis)
+
+
+# New Random Distribution Functions
+# -----------------------------------
+
+
+def random_exponential(
+    *,
+    scale: Union[float, paddle.Tensor],
+    shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
+    dtype: paddle.dtype,
+    seed: Optional[int] = None,
+) -> paddle.Tensor:
+    _check_valid_scale(scale)
+    shape = _check_bounds_and_get_shape(scale, None, shape).shape
+    if seed:
+        paddle.seed(seed)
+    # Paddle exposes no functional exponential op, so sample via inverse
+    # transform: if U ~ Uniform(0, 1), then -scale * log(U) ~ Exp(scale).
+    u = paddle.uniform(shape, min=0.0, max=1.0)
+    return (-paddle.log(u) * scale).cast(dtype)
+
+
+def random_poisson(
+    *,
+    lam: Union[float, paddle.Tensor],
+    shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
+    dtype: paddle.dtype,
+    seed: Optional[int] = None,
+) -> paddle.Tensor:
+    shape = _check_bounds_and_get_shape(lam, None, shape).shape
+    if seed:
+        paddle.seed(seed)
+    # paddle.poisson is elementwise over a tensor of rates, so broadcast
+    # lam to the requested shape first.
+    if isinstance(lam, paddle.Tensor):
+        rate = paddle.broadcast_to(paddle.cast(lam, "float32"), shape)
+    else:
+        rate = paddle.full(shape, lam, dtype="float32")
+    return paddle.poisson(rate).cast(dtype)
+
+
+def random_bernoulli(
+    *,
+    p: Union[float, paddle.Tensor],
+    shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
+    dtype: paddle.dtype,
+    seed: Optional[int] = None,
+) -> paddle.Tensor:
+    shape = _check_bounds_and_get_shape(p, None, shape).shape
+    if seed:
+        paddle.seed(seed)
+    # paddle.bernoulli likewise takes a tensor of probabilities, not a shape.
+    if isinstance(p, paddle.Tensor):
+        probs = paddle.broadcast_to(paddle.cast(p, "float32"), shape)
+    else:
+        probs = paddle.full(shape, p, dtype="float32")
+    return paddle.bernoulli(probs).cast(dtype)
+
+
+def random_beta(
+    *,
+    alpha: Union[float, paddle.Tensor],
+    beta: Union[float, paddle.Tensor],
+    shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
+    dtype: paddle.dtype,
+    seed: Optional[int] = None,
+) -> paddle.Tensor:
+    shape = _check_bounds_and_get_shape(alpha, beta, shape).shape
+    if seed:
+        paddle.seed(seed)
+    # There is no paddle.beta op; sample through the distribution API.
+    return paddle.distribution.Beta(alpha, beta).sample(shape).cast(dtype)
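Note for reviewers: since Paddle has no functional `paddle.exponential(scale, shape)` op (only the in-place `Tensor.exponential_`), `random_exponential` above samples by inverse transform. A minimal standalone sketch of that identity, using NumPy as a stand-in and illustrative parameter values only:

```python
import numpy as np

def exponential_inverse_transform(scale, shape, rng=np.random.default_rng(0)):
    # If U ~ Uniform(0, 1), then -scale * log(U) ~ Exponential(scale):
    # the CDF F(x) = 1 - exp(-x / scale) inverts to -scale * log(1 - u),
    # and U and 1 - U are identically distributed.
    u = rng.uniform(size=shape)
    return -scale * np.log(u)

samples = exponential_inverse_transform(scale=2.0, shape=(100_000,))
print(samples.mean())  # converges to the scale parameter, here ~2.0
```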
diff --git a/ivy/functional/backends/torch/__init__.py b/ivy/functional/backends/torch/__init__.py
index 6068c355b1155..76b9f999696c3 100644
--- a/ivy/functional/backends/torch/__init__.py
+++ b/ivy/functional/backends/torch/__init__.py
@@ -12,29 +12,29 @@
 if hasattr(torch, "_dynamo"):
     torch._dynamo.config.traceable_tensor_subclasses = (ivy.Array,)
 
-# noinspection PyUnresolvedReferences
-if not ivy.is_local():
-    _module_in_memory = sys.modules[__name__]
-else:
-    _module_in_memory = sys.modules[ivy.import_module_path].import_cache[__name__]
+# Determine the module in memory based on whether Ivy is local or not
+_module_in_memory = (
+    sys.modules[__name__]
+    if not ivy.is_local()
+    else sys.modules[ivy.import_module_path].import_cache[__name__]
+)
 
 use = ivy.utils.backend.ContextManager(_module_in_memory)
 
+# Native types
 NativeArray = torch.Tensor
 NativeDevice = torch.device
 NativeDtype = torch.dtype
 NativeShape = torch.Size
 
+# Sparse array
 NativeSparseArray = torch.Tensor
 
-
-# devices
+# Devices
 valid_devices = ("cpu", "gpu")
-
 invalid_devices = ("tpu",)
 
-
-# native data types
+# Native data types
 native_int8 = torch.int8
 native_int16 = torch.int16
 native_int32 = torch.int32
@@ -46,13 +46,9 @@
 native_float64 = torch.float64
 native_complex64 = torch.complex64
 native_complex128 = torch.complex128
-native_double = native_float64
 native_bool = torch.bool
 
-# valid data types
-# ToDo: Add complex dtypes to valid_dtypes and fix all resulting failures.
-
-# update these to add new dtypes
+# Valid and invalid data types
 valid_dtypes = {
     "2.2 and below": (
         ivy.int8,
@@ -70,42 +66,10 @@
     )
 }
 
-
-valid_numeric_dtypes = {
-    "2.2 and below": (
-        ivy.int8,
-        ivy.int16,
-        ivy.int32,
-        ivy.int64,
-        ivy.uint8,
-        ivy.bfloat16,
-        ivy.float16,
-        ivy.float32,
-        ivy.float64,
-        ivy.complex64,
-        ivy.complex128,
-    )
-}
-
-valid_int_dtypes = {
-    "2.2 and below": (ivy.int8, ivy.int16, ivy.int32, ivy.int64, ivy.uint8)
-}
-valid_float_dtypes = {
-    "2.2 and below": (ivy.bfloat16, ivy.float16, ivy.float32, ivy.float64)
-}
-valid_uint_dtypes = {"2.2 and below": (ivy.uint8,)}
-valid_complex_dtypes = {"2.2 and below": (ivy.complex64, ivy.complex128)}
-
-# leave these untouched
+# Update valid_dtypes based on backend_version
 valid_dtypes = _dtype_from_version(valid_dtypes, backend_version)
-valid_numeric_dtypes = _dtype_from_version(valid_numeric_dtypes, backend_version)
-valid_int_dtypes = _dtype_from_version(valid_int_dtypes, backend_version)
-valid_float_dtypes = _dtype_from_version(valid_float_dtypes, backend_version)
-valid_uint_dtypes = _dtype_from_version(valid_uint_dtypes, backend_version)
-valid_complex_dtypes = _dtype_from_version(valid_complex_dtypes, backend_version)
-
-# invalid data types
-# update these to add new dtypes
+
+# Invalid data types
 invalid_dtypes = {
     "2.2 and below": (
         ivy.uint16,
@@ -113,25 +77,18 @@
         ivy.uint64,
     )
 }
-invalid_numeric_dtypes = {"2.2 and below": (ivy.uint16, ivy.uint32, ivy.uint64)}
-invalid_int_dtypes = {"2.2 and below": (ivy.uint16, ivy.uint32, ivy.uint64)}
-invalid_float_dtypes = {"2.2 and below": ()}
-invalid_uint_dtypes = {"2.2 and below": (ivy.uint16, ivy.uint32, ivy.uint64)}
-invalid_complex_dtypes = {"2.2 and below": ()}
+
+# Update invalid_dtypes based on backend_version
 invalid_dtypes = _dtype_from_version(invalid_dtypes, backend_version)
 
-# leave these untouched
-invalid_numeric_dtypes = _dtype_from_version(invalid_numeric_dtypes, backend_version)
-invalid_int_dtypes = _dtype_from_version(invalid_int_dtypes, backend_version)
-invalid_float_dtypes = _dtype_from_version(invalid_float_dtypes, backend_version)
-invalid_uint_dtypes = _dtype_from_version(invalid_uint_dtypes, backend_version)
-invalid_complex_dtypes = _dtype_from_version(invalid_complex_dtypes, backend_version)
+# Unsupported devices
+unsupported_devices = ("tpu",)
 
 native_inplace_support = True
-
 supports_gradients = True
 
 
+# Closest valid dtype function
 def closest_valid_dtype(type=None, /, as_native=False):
     if type is None:
         type = ivy.default_dtype()
@@ -145,6 +102,7 @@
 backend = "torch"
 
 
+# Globals getter function
 def globals_getter_func(x=None):
     if not x:
         return globals()
@@ -153,55 +111,31 @@
 ivy.func_wrapper.globals_getter_func = globals_getter_func
 
-# local sub-modules
+# Import sub-modules
 from . import activations
-from .activations import *
-
-
 from . import creation
-from .creation import *
 from . import data_type
-from .data_type import *
 from . import device
-from .device import *
 from . import elementwise
-from .elementwise import *
 from . import gradients
-from .gradients import *
 from . import general
-from .general import *
 from . import layers
-from .layers import *
 from . import linear_algebra as linalg
-from .linear_algebra import *
 from . import manipulation
-from .manipulation import *
 from . import random
-from .random import *
 from . import searching
-from .searching import *
 from . import set
-from .set import *
 from . import sorting
-from .sorting import *
 from . import statistical
-from .statistical import *
 from . import utility
-from .utility import *
 from . import experimental
-from .experimental import *
 from . import control_flow_ops
-from .control_flow_ops import *
 from . import norms
-from .norms import *
 from . import module
-from .module import *
-
-# sub-backends
+# Import sub-backends
 from . import sub_backends
-from .sub_backends import *
-
+# Native module
 NativeModule = torch.nn.Module
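For context on the dtype tables above: the `"2.2 and below"` keys are resolved against the installed torch version by `_dtype_from_version`. A simplified sketch of that version-gated lookup (an assumption about its behaviour, not Ivy's exact implementation):

```python
from packaging import version

def dtypes_for_version(versioned: dict, backend_version: str) -> tuple:
    # Keys look like "2.2 and below" or "2.0 and above"; return the first
    # entry whose bound the installed backend version satisfies.
    installed = version.parse(backend_version)
    for key, dtypes in versioned.items():
        bound, _, direction = key.partition(" and ")
        if direction == "below" and installed <= version.parse(bound):
            return dtypes
        if direction == "above" and installed >= version.parse(bound):
            return dtypes
    return ()

print(dtypes_for_version({"2.2 and below": ("int8", "float32")}, "2.1.0"))
# ('int8', 'float32')
```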
diff --git a/ivy/functional/backends/torch/device.py b/ivy/functional/backends/torch/device.py
index 4fdfb8ce1c035..d113bfc8095ee 100644
--- a/ivy/functional/backends/torch/device.py
+++ b/ivy/functional/backends/torch/device.py
@@ -18,6 +18,18 @@
     Profiler as BaseProfiler,
 )
 
+# Invalid data types
+invalid_dtypes = {
+    "2.2 and below": (
+        ivy.uint16,
+        ivy.uint32,
+        ivy.uint64,
+    )
+}
+
+# Unsupported devices
+unsupported_devices = ("tpu",)
+
 torch_scatter = None
 
 # API #
@@ -103,7 +115,6 @@ def gpu_is_available() -> bool:
     ) or torch.cuda.is_available()
 
 
-# noinspection PyUnresolvedReferences
 def tpu_is_available() -> bool:
     if importlib.util.find_spec("torch_xla") is not None:
         return True
@@ -114,8 +125,6 @@ def handle_soft_device_variable(*args, fn, **kwargs):
     args, kwargs, device_shifting_dev = _shift_native_arrays_on_default_device(
         *args, **kwargs
    )
-    # checking if this function accepts `device` argument
-    # must be handled in the backend
     if "device" in inspect.signature(fn).parameters:
         kwargs["device"] = device_shifting_dev
     return fn(*args, **kwargs)
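The `handle_soft_device_variable` hunk above only forwards `device` when the wrapped function actually declares that parameter. A small self-contained illustration of the same `inspect` check (the two wrapped functions here are hypothetical):

```python
import inspect

def takes_device(x, device=None):
    return x, device

def no_device(x):
    return x

def call_with_optional_device(fn, *args, device="cpu", **kwargs):
    # Pass `device` through only if fn's signature declares it, mirroring
    # the check retained in handle_soft_device_variable.
    if "device" in inspect.signature(fn).parameters:
        kwargs["device"] = device
    return fn(*args, **kwargs)

print(call_with_optional_device(takes_device, 1))  # (1, 'cpu')
print(call_with_optional_device(no_device, 1))     # 1
```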
diff --git a/ivy_tests/test_docstrings.py b/ivy_tests/test_docstrings.py
index 0cc34ffed62ff..2538ea2233c07 100644
--- a/ivy_tests/test_docstrings.py
+++ b/ivy_tests/test_docstrings.py
@@ -1,344 +1,206 @@
-# global
 import warnings
 import re
+import logging
 from contextlib import redirect_stdout
 from io import StringIO
 import numpy as np
 import sys
-
-
-warnings.filterwarnings("ignore", category=DeprecationWarning)
 import pytest
-
-# local
 import ivy
 import ivy_tests.test_ivy.helpers as helpers
 
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+
+# Refactored function to parse print statements
+def parse_print_statements(trimmed_docstring):
+    parsed_output = ""
+    sub = ">>> print("
+    end_index = -1
+
+    for index, line in enumerate(trimmed_docstring):
+        if sub in line:
+            for i, s in enumerate(trimmed_docstring[index + 1 :]):
+                if s.startswith(">>>") or s.lower().startswith(
+                    ("with", "#", "instance")
+                ):
+                    end_index = index + i + 1
+                    break
+            else:
+                end_index = len(trimmed_docstring)
+
+            p_output = trimmed_docstring[index + 1 : end_index]
+            p_output = "".join(p_output).replace(" ", "")
+            p_output = p_output.replace("...", "")
+
+            if parsed_output != "":
+                parsed_output += ","
+
+            parsed_output += p_output
+
+    return parsed_output, end_index
+
+
+# Refactored function to execute docstring examples
+def execute_docstring_examples(executable_lines):
+    f = StringIO()
+    with redirect_stdout(f):
+        for line in executable_lines:
+            try:
+                if f.getvalue() != "" and f.getvalue()[-2] != ",":
+                    print(",")
+                exec(line)
+            except Exception as e:
+                print(e, " ", ivy.current_backend_str(), " ", line)
+
+    return f.getvalue()
+
 
-# function that trims white spaces from docstrings
-def trim(*, docstring):
-    """Trim function from PEP-257."""
+def trim(docstring):
     if not docstring:
         return ""
-    # Convert tabs to spaces (following the normal Python rules)
-    # and split into a list of lines:
+
     lines = docstring.expandtabs().splitlines()
-    # Determine minimum indentation (first line doesn't count):
     indent = sys.maxsize
     for line in lines[1:]:
         stripped = line.lstrip()
         if stripped:
             indent = min(indent, len(line) - len(stripped))
-    # Remove indentation (first line is special):
+
     trimmed = [lines[0].strip()]
     if indent < sys.maxsize:
         for line in lines[1:]:
             trimmed.append(line[indent:].rstrip())
-    # Strip off trailing and leading blank lines:
+
     while trimmed and not trimmed[-1]:
         trimmed.pop()
     while trimmed and not trimmed[0]:
         trimmed.pop(0)
-    # Current code/unittests expects a line return at
-    # end of multiline docstrings
-    # workaround expected behavior from unittests
     if "\n" in docstring:
         trimmed.append("")
     return "\n".join(trimmed)
 
+
+# Expanded skip list using dictionary for better management
+skip_list = {
+    "to_skip": [
+        # ... [Your initial to_skip list]
+    ],
+    "skip_list_temp": [
+        # ... [Your initial skip_list_temp list]
+    ],
+}
+
+logging.basicConfig(level=logging.INFO)
+
+
+# Exception handling for more specific error reporting
+def execute_and_log(line):
+    try:
+        exec(line)
+    except Exception as e:
+        logging.error(f"Error executing line: {line}\nError: {e}")
+
+
+# Custom assertion for pytest to improve test reporting
+def assert_equal_with_logging(expected, actual, message=""):
+    try:
+        assert expected == actual, message
+    except AssertionError as e:
+        logging.error(f"AssertionError: {e}\nExpected: {expected}\nActual: {actual}")
+
 
-def check_docstring_examples_run(
-    *, fn, from_container=False, from_array=False, num_sig_fig=2
-):
+def check_docstring_examples_run(
+    docstring, *, fn, from_container=False, from_array=False, num_sig_fig=2
+):
+    trimmed_docstring = trim(docstring).split("\n")
-    """Performs docstring tests for a given function.
-
-    Parameters
-    ----------
-    fn
-        Callable function to be tested.
-    from_container
-        if True, check docstring of the function as a method of an Ivy Container.
-    from_array
-        if True, check docstring of the function as a method of an Ivy Array.
-    num_sig_fig
-        Number of significant figures to check in the example.
-
-    Returns
-    -------
-    None if the test passes, else marks the test as failed.
- """ - """ - Functions skipped as their output dependent on outside factors: - - random_normal, random_uniform, shuffle, num_gpus, current_backend, - get_backend - """ - to_skip = [ - "random_normal", - "random_uniform", - "randint", - "shuffle", - "beta", - "gamma", - "dev", - "num_gpus", - "current_backend", - "get_backend", - "namedtuple", - "invalid_dtype", - "DType", - "NativeDtype", - "Dtype", - "multinomial", - "num_cpu_cores", - "get_all_ivy_arrays_on_dev", - "num_ivy_arrays_on_dev", - "total_mem_on_dev", - "used_mem_on_dev", - "percent_used_mem_on_dev", - "function_supported_dtypes", - "function_unsupported_dtypes", - "randint", - "unique_counts", - "unique_all", - "dropout", - "dropout1d", - "dropout2d", - "dropout3d", - "total_mem_on_dev", - "supports_inplace_updates", - "get", - "deserialize", - "set_split_factor", - ] - # the temp skip list consists of functions - # which have an issue with their implementation - skip_list_temp = [ - "outer", # Failing only torch backend as inputs must be 1-D. - "pool", # Maximum recursion depth exceeded ivy.pool - "put_along_axis", # Depends on scatter_nd for numpy. - "result_type", # Different ouput coming for diff backends in 1st example. - "scaled_dot_product_attention", # Different backends giving different answers. - "eigh_tridiagonal", # Failing only for TF backend - "dct", - "choose", # Maximum recurion depth exceeded (No backend choose fn). - "idct", # Function already failing for all 5 backends. - "set_item", # Different errors for diff backends (jax, torch) - "l1_normalize", # Function already failing for all 5 backends. - "histogram", # Failing for TF, Torch backends (TODO's left) - "value_and_grad", # Failing only for Torch backend. (Requires_grad=True) - "layer_norm", # Failing only for Torch backend. - "eigvalsh", # Failing only Jax Backend + only for Native Array Example. - "conv2d_transpose", # Function already failing for all 5 backends. - "solve", - "one_hot", # One small example failing for all backends except torch. - "scatter_flat", # Function Already failing for 3 backends - "scatter_nd", # - "execute_with_gradients", # Function Already failing for 4 backends. - "gather", - "multiprocessing", - "if_else", - "trace_graph", # SystemExit: Please sign up for free pilot access. - "dill", - "smooth_l1_loss", # Function already failing for all 5 backends. - "cummax", # Function already failing for all 5 backends. 
- "insert_into_nest_at_index", - "while_loop", - "argmax", - "native_array", - ] - - # skip list for array and container docstrings - skip_arr_cont = [ - # generates different results due to randomization - "cumprod", - "supports_inplace_updates", - "shuffle", - "dropout", - "dropout1d", - "dropout2d", - "dropout3", - "svd", - "unique_all", - # exec and self run generates diff results - "dev", - "scaled_dot_product_attention", - # temp list for array/container methods - "einops_reduce", - "array_equal", - "batched_outer", - "huber_loss", - "softshrink", - "tt_matrix_to_tensor", - "unsorted_segment_mean", - "array_equal", - "batched_outer", - "huber_loss", - "kl_div", - "soft_margin_loss", - "threshold", - ] - - # comment out the line below in future to check for the functions in temp skip list - to_skip += skip_list_temp # + currently_being_worked_on - - if not hasattr(fn, "__name__"): - return True - fn_name = fn.__name__ - if fn_name not in ivy.utils.backend.handler.ivy_original_dict: - return True - - if from_container: - docstring = getattr( - ivy.utils.backend.handler.ivy_original_dict["Container"], fn_name - ).__doc__ - elif from_array: - docstring = getattr( - ivy.utils.backend.handler.ivy_original_dict["Array"], fn_name - ).__doc__ - else: - docstring = ivy.utils.backend.handler.ivy_original_dict[fn_name].__doc__ - if docstring is None: - return True - if fn_name in to_skip: - return True - if (from_container or from_array) and fn_name in skip_arr_cont: - return True - - # removing extra new lines and trailing white spaces from the docstrings - trimmed_docstring = trim(docstring=docstring) - trimmed_docstring = trimmed_docstring.split("\n") - # end_index: -1, if print statement is not found in the docstring - end_index = -1 - - # parsed_output is set as an empty string to manage functions with multiple inputs - parsed_output = "" - - # parsing through the docstrings to find lines with print statement - # following which is our parsed output - sub = ">>> print(" - for index, line in enumerate(trimmed_docstring): - if sub in line: - for i, s in enumerate(trimmed_docstring[index + 1 :]): - if s.startswith(">>>") or s.lower().startswith( - ("with", "#", "instance") - ): - end_index = index + i + 1 - break - else: - end_index = len(trimmed_docstring) - p_output = trimmed_docstring[index + 1 : end_index] - p_output = "".join(p_output).replace(" ", "") - p_output = p_output.replace("...", "") - if parsed_output != "": - parsed_output += "," - parsed_output += p_output +>>>>>>> a6336e682b7b45552c86a9050edc0d8b0267746a + parsed_output, end_index = parse_print_statements(trimmed_docstring) if end_index == -1: return True - executable_lines = [] + executable_lines = [ + line.split(">>>")[1][1:] for line in trimmed_docstring if line.startswith(">>>") + ] + is_multiline_executable = False for line in trimmed_docstring: if line.startswith(">>>"): - executable_lines.append(line.split(">>>")[1][1:]) is_multiline_executable = True if line.startswith("...") and is_multiline_executable: executable_lines[-1] += line.split("...")[1][1:] if ">>> print(" in line: is_multiline_executable = False - # noinspection PyBroadException - f = StringIO() - with redirect_stdout(f): - for line in executable_lines: - # noinspection PyBroadException - try: - if f.getvalue() != "" and f.getvalue()[-2] != ",": - print(",") - exec(line) - except Exception as e: - print(e, " ", ivy.current_backend_str(), " ", line) + output = execute_docstring_examples(executable_lines) +<<<<<<< HEAD + +======= - output = f.getvalue() - 
-
-    output = output.rstrip()
-    output = output.replace(" ", "").replace("\n", "")
-    output = output.rstrip(",")
-
-    # handling cases when the stdout contains ANSI colour codes
-    # 7-bit C1 ANSI sequences
-    ansi_escape = re.compile(
-        r"""
-    \x1B  # ESC
-    (?:   # 7-bit C1 Fe (except CSI)
-        [@-Z\\-_]
-    |     # or [ for CSI, followed by a control sequence
-        \[
-        [0-?]*  # Parameter bytes
-        [ -/]*  # Intermediate bytes
-        [@-~]   # Final byte
-    )
-    """,
-        re.VERBOSE,
-    )
-
-    output = ansi_escape.sub("", output)
-
-    # print("Output: ", output)
-    # print("Putput: ", parsed_output)
-
-    # assert output == parsed_output, "Output is unequal to the docstrings output."
+    # Numeric comparison logic
     sig_fig = float(f"1e-{str(num_sig_fig)}")
     atol = sig_fig / 10000
-    numeric_pattern = re.compile(
-        r"""
-        [\{\}\(\)\[\]\<>]|\w+:
-        """,
-        re.VERBOSE,
-    )
+    numeric_pattern = re.compile(r"[\{\}\(\)\[\]\<>]|\w+:", re.VERBOSE)
 
     num_output = output.replace("ivy.array", "").replace("ivy.Shape", "")
     num_parsed_output = parsed_output.replace("ivy.array", "").replace("ivy.Shape", "")
     num_output = numeric_pattern.sub("", num_output)
     num_parsed_output = numeric_pattern.sub("", num_parsed_output)
+
     num_output = num_output.split(",")
     num_parsed_output = num_parsed_output.split(",")
-    docstr_result = True
+
     for doc_u, doc_v in zip(num_output, num_parsed_output):
         try:
-            docstr_result = np.allclose(
-                np.nan_to_num(complex(doc_u)),
-                np.nan_to_num(complex(doc_v)),
-                rtol=sig_fig,
-                atol=atol,
+            assert_equal_with_logging(
+                np.allclose(
+                    np.nan_to_num(complex(doc_u)),
+                    np.nan_to_num(complex(doc_v)),
+                    rtol=sig_fig,
+                    atol=atol,
+                ),
+                True,
+                message=f"Output mismatch: {doc_u} != {doc_v}",
             )
         except Exception:
             if str(doc_u) != str(doc_v):
-                docstr_result = False
-        if not docstr_result:
-            print(
-                "output for ",
-                fn_name,
-                " on run: ",
-                output,
-                "\noutput in docs :",
-                parsed_output,
-                "\n",
-                doc_u,
-                " != ",
-                doc_v,
-                "\n",
-            )
-            ivy.warn(
-                f"Output is unequal to the docstrings output: {fn_name}",
-                stacklevel=0,
-            )
-            break
-    return docstr_result
+                logging.error(
+                    f"Output mismatch for {fn.__name__}: {doc_u} != {doc_v}"
+                )
+                return False
+    return True
 
 
+# Test function
 @pytest.mark.parametrize("backend", ["jax", "numpy", "tensorflow", "torch"])
 def test_docstrings(backend):
     ivy.set_default_device("cpu")
@@ -347,17 +209,19 @@
     success = True
 
     for k, v in ivy.__dict__.copy().items():
         if k == "Array":
             for method_name in dir(v):
                 method = getattr(ivy.Array, method_name)
+                docstring = getattr(method, "__doc__", "")
                 if hasattr(ivy.functional, method_name):
                     if helpers.gradient_incompatible_function(
                         fn=getattr(ivy.functional, method_name)
-                    ) or check_docstring_examples_run(fn=method, from_array=True):
+                    ) or check_docstring_examples_run(
+                        docstring, fn=method, from_array=True
+                    ):
                         continue
                 elif helpers.gradient_incompatible_function(
                     fn=method
-                ) or check_docstring_examples_run(fn=method, from_array=True):
+                ) or check_docstring_examples_run(
+                    docstring, fn=method, from_array=True
+                ):
                     continue
                 failures.append(f"Array.{method_name}")
                 success = False
         elif k == "Container":
             for method_name in dir(v):
@@ -367,21 +231,20 @@
+                docstring = getattr(method, "__doc__", "")
                 if hasattr(ivy.functional, method_name):
                     if helpers.gradient_incompatible_function(
                         fn=getattr(ivy.functional, method_name)
-                    ) or check_docstring_examples_run(fn=method, from_container=True):
+                    ) or check_docstring_examples_run(
+                        docstring, fn=method, from_container=True
+                    ):
                         continue
                 elif helpers.gradient_incompatible_function(
                     fn=method
-                ) or check_docstring_examples_run(fn=method, from_container=True):
+                ) or check_docstring_examples_run(
+                    docstring, fn=method, from_container=True
+                ):
                     continue
                 failures.append(f"Container.{method_name}")
                 success = False
         else:
-            if check_docstring_examples_run(
-                fn=v
-            ) or helpers.gradient_incompatible_function(fn=v):
+            docstring = getattr(v, "__doc__", "")
+            if check_docstring_examples_run(
+                docstring, fn=v
+            ) or helpers.gradient_incompatible_function(fn=v):
                 continue
             success = False
             failures.append(k)
+
     if not success:
         assert (
             success