Skip to content

Commit

Permalink
Merge pull request #14 from Yelp/cache-fixtures-per-request-sequence
Browse files Browse the repository at this point in the history
cache fixture results per sequence, clearing the cache at the end
  • Loading branch information
acoover authored Sep 20, 2019
2 parents d85edc8 + 74e03af commit 8640b6d
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 15 deletions.
34 changes: 23 additions & 11 deletions fuzz_lightyear/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ def get_non_vulnerable_operations() -> Dict[str, Optional[str]]:
return {}


def clear_cache():
""" Clear the cached values for fixture functions """
for value in get_user_defined_mapping().values():
value._fuzz_cache = None


def inject_user_defined_variables(func: Callable) -> Callable:
"""
This decorator allows the use of user defined variables in functions.
Expand All @@ -62,6 +68,9 @@ def inject_user_defined_variables(func: Callable) -> Callable:

@wraps(func)
def wrapped(*args, **kwargs) -> Any:
if getattr(func, '_fuzz_cache', None) is not None:
return func._fuzz_cache # type: ignore

expected_args = _get_injectable_variables(func)
type_annotations = inspect.getfullargspec(func).annotations

Expand All @@ -72,19 +81,22 @@ def wrapped(*args, **kwargs) -> Any:
# two values for the same argument.
continue

if arg_name not in kwargs and arg_name in mapping:
value = mapping[arg_name]()
if (
arg_name in type_annotations and
not isinstance(type_annotations[arg_name], type(List))
):
# If type annotations are used, use that to cast
# values for input.
value = type_annotations[arg_name](value)
if arg_name not in mapping:
raise TypeError

value = mapping[arg_name]()
if (
arg_name in type_annotations and
not isinstance(type_annotations[arg_name], type(List))
):
# If type annotations are used, use that to cast
# values for input.
value = type_annotations[arg_name](value)

kwargs[arg_name] = value
kwargs[arg_name] = value

return func(*args, **kwargs)
func._fuzz_cache = func(*args, **kwargs) # type: ignore
return func._fuzz_cache # type: ignore

return wrapped

Expand Down
2 changes: 2 additions & 0 deletions fuzz_lightyear/runner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List

from .datastore import clear_cache
from .request import FuzzingRequest
from .response import ResponseSequence

Expand All @@ -15,4 +16,5 @@ def run_sequence(

# Then, check for vulnerabilities.
responses.analyze_requests(sequence)
clear_cache()
return responses
30 changes: 30 additions & 0 deletions tests/integration/supplements/factory_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,33 @@ def factory():

assert request.fuzzed_input['string'] == '1'
assert request.fuzzed_input['integer'] == 1


def test_session_fixtures(mock_client):
count = 0

def nested_function():
nonlocal count
count += 1
return count

def child_a(nested):
return nested

def child_b(nested):
return nested

def function(a, b):
assert a == b
return 'does_not_matter'

fuzz_lightyear.register_factory('nested')(nested_function)
fuzz_lightyear.register_factory('a')(child_a)
fuzz_lightyear.register_factory('b')(child_b)
fuzz_lightyear.register_factory('string')(function)

request = FuzzingRequest(
operation_id='get_expect_primitives',
tag='types',
)
request.send()
4 changes: 0 additions & 4 deletions tests/unit/supplements/factory_supplements_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,6 @@ def setup(self):
def test_uses_default(self):
assert get_user_defined_mapping()['caller']() == 2

def test_uses_provided_value_over_default(self):
assert get_user_defined_mapping()['caller'](dependency=2) == 3
assert get_user_defined_mapping()['caller'](3) == 4

def test_throws_error_when_no_default(self):
def foobar(no_default):
pass
Expand Down

0 comments on commit 8640b6d

Please sign in to comment.