diff --git a/engine/apps/email/tests/test_inbound_email.py b/engine/apps/email/tests/test_inbound_email.py index 252b52920..808fbdfac 100644 --- a/engine/apps/email/tests/test_inbound_email.py +++ b/engine/apps/email/tests/test_inbound_email.py @@ -14,6 +14,7 @@ from cryptography.x509 import CertificateBuilder, NameOID from django.conf import settings from django.urls import reverse +from requests import RequestException from rest_framework import status from rest_framework.test import APIClient @@ -604,6 +605,77 @@ def test_amazon_ses_validated_fail_wrong_signature( mock_requests_get.assert_called_once_with(SIGNING_CERT_URL, timeout=5) +@patch("requests.get", side_effect=RequestException) +@pytest.mark.django_db +def test_amazon_ses_validated_fail_cant_download_certificate( + _, settings, make_organization, make_alert_receive_channel +): + settings.INBOUND_EMAIL_ESP = "amazon_ses_validated,mailgun" + settings.INBOUND_EMAIL_DOMAIN = "inbound.example.com" + settings.INBOUND_EMAIL_WEBHOOK_SECRET = "secret" + settings.INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN = AMAZON_SNS_TOPIC_ARN + + organization = make_organization() + make_alert_receive_channel( + organization, + integration=AlertReceiveChannel.INTEGRATION_INBOUND_EMAIL, + token="test-token", + ) + + sns_payload, sns_headers = _sns_inbound_email_payload_and_headers( + sender_email=SENDER_EMAIL, + to_email=TO_EMAIL, + subject=SUBJECT, + message=MESSAGE, + ) + + client = APIClient() + with pytest.raises(RequestException): + client.post( + reverse("integrations:inbound_email_webhook"), + data=sns_payload, + headers=sns_headers, + format="json", + ) + + +@patch("requests.get", return_value=Mock(content=CERTIFICATE)) +@pytest.mark.django_db +def test_amazon_ses_validated_caches_certificate( + mock_requests_get, settings, make_organization, make_alert_receive_channel +): + settings.INBOUND_EMAIL_ESP = "amazon_ses_validated,mailgun" + settings.INBOUND_EMAIL_DOMAIN = "inbound.example.com" + settings.INBOUND_EMAIL_WEBHOOK_SECRET = "secret" + settings.INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN = AMAZON_SNS_TOPIC_ARN + + organization = make_organization() + make_alert_receive_channel( + organization, + integration=AlertReceiveChannel.INTEGRATION_INBOUND_EMAIL, + token="test-token", + ) + + sns_payload, sns_headers = _sns_inbound_email_payload_and_headers( + sender_email=SENDER_EMAIL, + to_email=TO_EMAIL, + subject=SUBJECT, + message=MESSAGE, + ) + + client = APIClient() + for _ in range(2): + response = client.post( + reverse("integrations:inbound_email_webhook"), + data=sns_payload, + headers=sns_headers, + format="json", + ) + assert response.status_code == status.HTTP_200_OK + + mock_requests_get.assert_called_once_with(SIGNING_CERT_URL, timeout=5) + + @patch.object(create_alert, "delay") @pytest.mark.django_db def test_mailgun_pass(create_alert_mock, settings, make_organization, make_alert_receive_channel): diff --git a/engine/apps/email/validate_amazon_sns_message.py b/engine/apps/email/validate_amazon_sns_message.py index f3d2aec48..e08256525 100644 --- a/engine/apps/email/validate_amazon_sns_message.py +++ b/engine/apps/email/validate_amazon_sns_message.py @@ -9,6 +9,7 @@ from cryptography.hazmat.primitives.hashes import SHA1, SHA256 from cryptography.x509 import NameOID, load_pem_x509_certificate from django.conf import settings +from django.core.cache import cache logger = logging.getLogger(__name__) @@ -67,13 +68,7 @@ def validate_amazon_sns_message(message: dict) -> bool: return False # Fetch the certificate - try: - response = requests.get(signing_cert_url, timeout=5) - response.raise_for_status() - certificate_bytes = response.content - except requests.RequestException as e: - logger.warning("Failed to fetch the certificate from %s: %s", signing_cert_url, e) - return False + certificate_bytes = fetch_certificate(signing_cert_url) # Verify the certificate issuer certificate = load_pem_x509_certificate(certificate_bytes) @@ -97,3 +92,17 @@ def validate_amazon_sns_message(message: dict) -> bool: return False return True + + +def fetch_certificate(certificate_url: str) -> bytes: + cache_key = f"aws_sns_cert_{certificate_url}" + cached_certificate = cache.get(cache_key) + if cached_certificate: + return cached_certificate + + response = requests.get(certificate_url, timeout=5) + response.raise_for_status() + certificate = response.content + + cache.set(cache_key, certificate, timeout=60 * 60) # Cache for 1 hour + return certificate