diff --git a/fuzz_lightyear/datastore.py b/fuzz_lightyear/datastore.py index ce84e84..1f68908 100644 --- a/fuzz_lightyear/datastore.py +++ b/fuzz_lightyear/datastore.py @@ -19,6 +19,12 @@ def get_user_defined_mapping() -> Dict: return {} +def clear_cache(): + """ Clear the cached values for fixture functions """ + for value in get_user_defined_mapping().values(): + value._fuzz_cache = None + + def inject_user_defined_variables(func: Callable) -> Callable: """ This decorator allows the use of user defined variables in functions. @@ -35,6 +41,9 @@ def inject_user_defined_variables(func: Callable) -> Callable: @wraps(func) def wrapped(*args, **kwargs) -> Any: + if getattr(func, '_fuzz_cache', None) is not None: + return func._fuzz_cache # type: ignore + expected_args = _get_injectable_variables(func) type_annotations = inspect.getfullargspec(func).annotations @@ -45,19 +54,22 @@ def wrapped(*args, **kwargs) -> Any: # two values for the same argument. continue - if arg_name not in kwargs and arg_name in mapping: - value = mapping[arg_name]() - if ( - arg_name in type_annotations and - not isinstance(type_annotations[arg_name], type(List)) - ): - # If type annotations are used, use that to cast - # values for input. - value = type_annotations[arg_name](value) + if arg_name not in mapping: + raise TypeError + + value = mapping[arg_name]() + if ( + arg_name in type_annotations and + not isinstance(type_annotations[arg_name], type(List)) + ): + # If type annotations are used, use that to cast + # values for input. + value = type_annotations[arg_name](value) - kwargs[arg_name] = value + kwargs[arg_name] = value - return func(*args, **kwargs) + func._fuzz_cache = func(*args, **kwargs) # type: ignore + return func._fuzz_cache # type: ignore return wrapped diff --git a/fuzz_lightyear/runner.py b/fuzz_lightyear/runner.py index da93b79..62712e8 100644 --- a/fuzz_lightyear/runner.py +++ b/fuzz_lightyear/runner.py @@ -1,5 +1,6 @@ from typing import List +from .datastore import clear_cache from .request import FuzzingRequest from .response import ResponseSequence @@ -12,6 +13,7 @@ def run_sequence( for request in sequence: response = request.send() responses.add_response(response) + clear_cache() # Then, check for vulnerabilities. responses.analyze_requests(sequence) diff --git a/tests/integration/supplements/factory_integration_test.py b/tests/integration/supplements/factory_integration_test.py index 8361f33..64a07f5 100644 --- a/tests/integration/supplements/factory_integration_test.py +++ b/tests/integration/supplements/factory_integration_test.py @@ -16,3 +16,33 @@ def factory(): assert request.fuzzed_input['string'] == '1' assert request.fuzzed_input['integer'] == 1 + + +def test_session_fixtures(mock_client): + count = 0 + + def nested_function(): + nonlocal count + count += 1 + return count + + def child_a(nested): + return nested + + def child_b(nested): + return nested + + def function(a, b): + assert a == b + return 'does_not_matter' + + fuzz_lightyear.register_factory('nested')(nested_function) + fuzz_lightyear.register_factory('a')(child_a) + fuzz_lightyear.register_factory('b')(child_b) + fuzz_lightyear.register_factory('string')(function) + + request = FuzzingRequest( + operation_id='get_expect_primitives', + tag='types', + ) + request.send() diff --git a/tests/unit/supplements/factory_supplements_test.py b/tests/unit/supplements/factory_supplements_test.py index 4c0d3d7..08764fc 100644 --- a/tests/unit/supplements/factory_supplements_test.py +++ b/tests/unit/supplements/factory_supplements_test.py @@ -54,10 +54,6 @@ def setup(self): def test_uses_default(self): assert get_user_defined_mapping()['caller']() == 2 - def test_uses_provided_value_over_default(self): - assert get_user_defined_mapping()['caller'](dependency=2) == 3 - assert get_user_defined_mapping()['caller'](3) == 4 - def test_throws_error_when_no_default(self): def foobar(no_default): pass