From 74e03af0899bc4b0972440d4d494d42da8c4daa7 Mon Sep 17 00:00:00 2001 From: Andrew Coover Date: Mon, 16 Sep 2019 16:30:32 -0700 Subject: [PATCH] cache fixture results per sequence, clearing the cache at the end --- fuzz_lightyear/datastore.py | 34 +++++++++++++------ fuzz_lightyear/runner.py | 2 ++ .../supplements/factory_integration_test.py | 30 ++++++++++++++++ .../supplements/factory_supplements_test.py | 4 --- 4 files changed, 55 insertions(+), 15 deletions(-) diff --git a/fuzz_lightyear/datastore.py b/fuzz_lightyear/datastore.py index cb3e7fe..39070fb 100644 --- a/fuzz_lightyear/datastore.py +++ b/fuzz_lightyear/datastore.py @@ -46,6 +46,12 @@ def get_non_vulnerable_operations() -> Dict[str, Optional[str]]: return {} +def clear_cache(): + """ Clear the cached values for fixture functions """ + for value in get_user_defined_mapping().values(): + value._fuzz_cache = None + + def inject_user_defined_variables(func: Callable) -> Callable: """ This decorator allows the use of user defined variables in functions. @@ -62,6 +68,9 @@ def inject_user_defined_variables(func: Callable) -> Callable: @wraps(func) def wrapped(*args, **kwargs) -> Any: + if getattr(func, '_fuzz_cache', None) is not None: + return func._fuzz_cache # type: ignore + expected_args = _get_injectable_variables(func) type_annotations = inspect.getfullargspec(func).annotations @@ -72,19 +81,22 @@ def wrapped(*args, **kwargs) -> Any: # two values for the same argument. continue - if arg_name not in kwargs and arg_name in mapping: - value = mapping[arg_name]() - if ( - arg_name in type_annotations and - not isinstance(type_annotations[arg_name], type(List)) - ): - # If type annotations are used, use that to cast - # values for input. - value = type_annotations[arg_name](value) + if arg_name not in mapping: + raise TypeError + + value = mapping[arg_name]() + if ( + arg_name in type_annotations and + not isinstance(type_annotations[arg_name], type(List)) + ): + # If type annotations are used, use that to cast + # values for input. + value = type_annotations[arg_name](value) - kwargs[arg_name] = value + kwargs[arg_name] = value - return func(*args, **kwargs) + func._fuzz_cache = func(*args, **kwargs) # type: ignore + return func._fuzz_cache # type: ignore return wrapped diff --git a/fuzz_lightyear/runner.py b/fuzz_lightyear/runner.py index da93b79..6abb4fb 100644 --- a/fuzz_lightyear/runner.py +++ b/fuzz_lightyear/runner.py @@ -1,5 +1,6 @@ from typing import List +from .datastore import clear_cache from .request import FuzzingRequest from .response import ResponseSequence @@ -15,4 +16,5 @@ def run_sequence( # Then, check for vulnerabilities. responses.analyze_requests(sequence) + clear_cache() return responses diff --git a/tests/integration/supplements/factory_integration_test.py b/tests/integration/supplements/factory_integration_test.py index 8361f33..64a07f5 100644 --- a/tests/integration/supplements/factory_integration_test.py +++ b/tests/integration/supplements/factory_integration_test.py @@ -16,3 +16,33 @@ def factory(): assert request.fuzzed_input['string'] == '1' assert request.fuzzed_input['integer'] == 1 + + +def test_session_fixtures(mock_client): + count = 0 + + def nested_function(): + nonlocal count + count += 1 + return count + + def child_a(nested): + return nested + + def child_b(nested): + return nested + + def function(a, b): + assert a == b + return 'does_not_matter' + + fuzz_lightyear.register_factory('nested')(nested_function) + fuzz_lightyear.register_factory('a')(child_a) + fuzz_lightyear.register_factory('b')(child_b) + fuzz_lightyear.register_factory('string')(function) + + request = FuzzingRequest( + operation_id='get_expect_primitives', + tag='types', + ) + request.send() diff --git a/tests/unit/supplements/factory_supplements_test.py b/tests/unit/supplements/factory_supplements_test.py index 4c0d3d7..08764fc 100644 --- a/tests/unit/supplements/factory_supplements_test.py +++ b/tests/unit/supplements/factory_supplements_test.py @@ -54,10 +54,6 @@ def setup(self): def test_uses_default(self): assert get_user_defined_mapping()['caller']() == 2 - def test_uses_provided_value_over_default(self): - assert get_user_defined_mapping()['caller'](dependency=2) == 3 - assert get_user_defined_mapping()['caller'](3) == 4 - def test_throws_error_when_no_default(self): def foobar(no_default): pass