Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add lock to CachingCryptoMaterialsManager #700

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 43 additions & 39 deletions src/aws_encryption_sdk/materials_managers/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""Caching crypto material manager."""
import logging
import uuid
from threading import RLock

import attr
import six
Expand Down Expand Up @@ -109,6 +110,8 @@ def __attrs_post_init__(self):
if self.partition_name is None:
self.partition_name = to_bytes(str(uuid.uuid4()))

self._cache_lock = RLock()

def _cache_entry_has_encrypted_too_many_bytes(self, entry):
"""Determines if a cache entry has exceeded the max allowed bytes encrypted.

Expand Down Expand Up @@ -188,32 +191,33 @@ def get_encryption_materials(self, request):
)
cache_key = build_encryption_materials_cache_key(partition=self.partition_name, request=inner_request)

# Attempt to retrieve from cache
try:
cache_entry = self.cache.get_encryption_materials(
cache_key=cache_key, plaintext_length=request.plaintext_length
)
except CacheKeyError:
pass
else:
if self._cache_entry_has_exceeded_limits(cache_entry):
self.cache.remove(cache_entry)
with self._cache_lock:
# Attempt to retrieve from cache
try:
cache_entry = self.cache.get_encryption_materials(
cache_key=cache_key, plaintext_length=request.plaintext_length
)
except CacheKeyError:
pass
else:
return cache_entry.value

# Nothing found in cache: try the material manager
new_result = self.backing_materials_manager.get_encryption_materials(inner_request)

if not new_result.algorithm.safe_to_cache() or request.plaintext_length >= self.max_bytes_encrypted:
return new_result

# Add results into cache
self.cache.put_encryption_materials(
cache_key=cache_key,
encryption_materials=new_result,
plaintext_length=request.plaintext_length,
entry_hints=CryptoMaterialsCacheEntryHints(lifetime=self.max_age),
)
if self._cache_entry_has_exceeded_limits(cache_entry):
self.cache.remove(cache_entry)
else:
return cache_entry.value

# Nothing found in cache: try the material manager
new_result = self.backing_materials_manager.get_encryption_materials(inner_request)

if not new_result.algorithm.safe_to_cache() or request.plaintext_length >= self.max_bytes_encrypted:
return new_result

# Add results into cache
self.cache.put_encryption_materials(
cache_key=cache_key,
encryption_materials=new_result,
plaintext_length=request.plaintext_length,
entry_hints=CryptoMaterialsCacheEntryHints(lifetime=self.max_age),
)
return new_result

def decrypt_materials(self, request):
Expand All @@ -225,21 +229,21 @@ def decrypt_materials(self, request):
:rtype: aws_encryption_sdk.materials_managers.DecryptionMaterials
"""
cache_key = build_decryption_materials_cache_key(partition=self.partition_name, request=request)

# Attempt to retrieve from cache
try:
cache_entry = self.cache.get_decryption_materials(cache_key)
except CacheKeyError:
pass
else:
if self._cache_entry_is_too_old(cache_entry):
self.cache.remove(cache_entry)
with self._cache_lock:
# Attempt to retrieve from cache
try:
cache_entry = self.cache.get_decryption_materials(cache_key)
except CacheKeyError:
pass
else:
return cache_entry.value
if self._cache_entry_is_too_old(cache_entry):
self.cache.remove(cache_entry)
else:
return cache_entry.value

# Nothing found in cache: try the material manager
new_result = self.backing_materials_manager.decrypt_materials(request)
# Nothing found in cache: try the material manager
new_result = self.backing_materials_manager.decrypt_materials(request)

# Add results into cache
self.cache.put_decryption_materials(cache_key=cache_key, decryption_materials=new_result)
# Add results into cache
self.cache.put_decryption_materials(cache_key=cache_key, decryption_materials=new_result)
return new_result
34 changes: 34 additions & 0 deletions test/unit/test_material_managers_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# SPDX-License-Identifier: Apache-2.0
"""Unit test suite for CachingCryptoMaterialsManager"""

import concurrent.futures

import pytest
from mock import MagicMock, sentinel
from pytest_mock import mocker # noqa pylint: disable=unused-import
Expand Down Expand Up @@ -371,6 +373,26 @@ def test_get_encryption_materials_cache_miss_algorithm_not_safe_to_cache(
assert test is ccmm.backing_materials_manager.get_encryption_materials.return_value


def test_get_encryption_materials_cache_thread_safe(
patch_encryption_materials_request,
patch_should_cache_encryption_request,
patch_cache_entry_has_exceeded_limits,
patch_build_encryption_materials_cache_key,
):
patch_cache_entry_has_exceeded_limits.return_value = False
mock_request = fake_encryption_request()
mock_request.plaintext_length = 10
ccmm = build_ccmm()
ccmm.cache.get_encryption_materials.side_effect = [CacheKeyError, MagicMock(), MagicMock()]
ccmm.backing_materials_manager.get_encryption_materials.return_value.algorithm.safe_to_cache.return_value = True

arguments = [mock_request, mock_request, mock_request]
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
results = [item for item in executor.map(ccmm.get_encryption_materials, arguments)]

assert ccmm.backing_materials_manager.get_encryption_materials.call_count == 1


@pytest.fixture
def patch_build_decryption_materials_cache_key(mocker):
mocker.patch.object(aws_encryption_sdk.materials_managers.caching, "build_decryption_materials_cache_key")
Expand Down Expand Up @@ -428,3 +450,15 @@ def test_decrypt_materials_cache_miss(patch_build_decryption_materials_cache_key
assert not patch_cache_entry_is_too_old.called
assert not ccmm.cache.remove.called
assert test is ccmm.backing_materials_manager.decrypt_materials.return_value


def test_decrypt_materials_cache_thread_safe(patch_build_decryption_materials_cache_key, patch_cache_entry_is_too_old):
patch_cache_entry_is_too_old.return_value = False
ccmm = build_ccmm()
ccmm.cache.get_decryption_materials.side_effect = [CacheKeyError, MagicMock(), MagicMock()]

arguments = [sentinel.request, sentinel.request, sentinel.request]
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
results = [item for item in executor.map(ccmm.decrypt_materials, arguments)]

assert ccmm.backing_materials_manager.decrypt_materials.call_count == 1