Skip to content

Commit

Permalink
refactor(middleware): Refactor internals of CSPMiddleware so that it'…
Browse files Browse the repository at this point in the history
…s easier to extend existing logic without copy/pasting it into subclass
  • Loading branch information
crbunney committed Jul 11, 2024
1 parent ed0b7a4 commit c5237c1
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 42 deletions.
41 changes: 11 additions & 30 deletions csp/contrib/rate_limiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

from django.conf import settings

from csp.middleware import CSPMiddleware
from csp.utils import build_policy
from csp.middleware import CSPMiddleware, PolicyParts

if TYPE_CHECKING:
from django.http import HttpRequest, HttpResponseBase
Expand All @@ -16,38 +15,20 @@ class RateLimitedCSPMiddleware(CSPMiddleware):
"""A CSP middleware that rate-limits the number of violation reports sent
to report-uri by excluding it from some requests."""

def build_policy(self, request: HttpRequest, response: HttpResponseBase) -> str:
config = getattr(response, "_csp_config", None)
update = getattr(response, "_csp_update", None)
replace = getattr(response, "_csp_replace", {})
nonce = getattr(request, "_csp_nonce", None)

policy = getattr(settings, "CONTENT_SECURITY_POLICY", None)

if policy is None:
return ""

report_percentage = policy.get("REPORT_PERCENTAGE", 100)
include_report_uri = random.randint(0, 100) < report_percentage
if not include_report_uri:
replace["report-uri"] = None

return build_policy(config=config, update=update, replace=replace, nonce=nonce)

def build_policy_ro(self, request: HttpRequest, response: HttpResponseBase) -> str:
config = getattr(response, "_csp_config_ro", None)
update = getattr(response, "_csp_update_ro", None)
replace = getattr(response, "_csp_replace_ro", {})
nonce = getattr(request, "_csp_nonce", None)

policy = getattr(settings, "CONTENT_SECURITY_POLICY_REPORT_ONLY", None)
def get_policy_parts(self, request: HttpRequest, response: HttpResponseBase, report_only: bool = False) -> PolicyParts:
policy_parts = super().get_policy_parts(request, response, report_only)

csp_setting_name = "CONTENT_SECURITY_POLICY_REPORT_ONLY" if report_only else "CONTENT_SECURITY_POLICY"
policy = getattr(settings, csp_setting_name, None)
if policy is None:
return ""
return policy_parts

report_percentage = policy.get("REPORT_PERCENTAGE", 100)
include_report_uri = random.randint(0, 100) < report_percentage
if not include_report_uri:
replace["report-uri"] = None
if policy_parts.replace is None:
policy_parts.replace = {"report-uri": None}
else:
policy_parts.replace["report-uri"] = None

return build_policy(config=config, update=update, replace=replace, nonce=nonce, report_only=True)
return policy_parts
48 changes: 36 additions & 12 deletions csp/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import base64
import http.client as http_client
import os
import warnings
from dataclasses import asdict, dataclass
from functools import partial
from typing import TYPE_CHECKING

Expand All @@ -11,12 +13,21 @@
from django.utils.functional import SimpleLazyObject

from csp.constants import HEADER, HEADER_REPORT_ONLY
from csp.utils import build_policy
from csp.utils import _DIRECTIVES, build_policy

if TYPE_CHECKING:
from django.http import HttpRequest, HttpResponseBase


@dataclass
class PolicyParts:
# A dataclass is used rather than a namedtuple so that the attributes are mutable
config: _DIRECTIVES = None
update: _DIRECTIVES = None
replace: _DIRECTIVES = None
nonce: str | None = None


class CSPMiddleware(MiddlewareMixin):
"""
Implements the Content-Security-Policy response header, which
Expand All @@ -25,6 +36,7 @@ class CSPMiddleware(MiddlewareMixin):
See http://www.w3.org/TR/CSP/
Can be customised by subclassing and extending the get_policy_parts method.
"""

def _make_nonce(self, request: HttpRequest) -> str:
Expand All @@ -49,7 +61,8 @@ def process_response(self, request: HttpRequest, response: HttpResponseBase) ->
if response.status_code in exempted_debug_codes and settings.DEBUG:
return response

csp = self.build_policy(request, response)
policy_parts = self.get_policy_parts(request=request, response=response)
csp = build_policy(**asdict(policy_parts))
if csp:
# Only set header if not already set and not an excluded prefix and not exempted.
is_not_exempt = getattr(response, "_csp_exempt", False) is False
Expand All @@ -60,7 +73,8 @@ def process_response(self, request: HttpRequest, response: HttpResponseBase) ->
if no_header and is_not_exempt and is_not_excluded:
response[HEADER] = csp

csp_ro = self.build_policy_ro(request, response)
policy_parts_ro = self.get_policy_parts(request=request, response=response, report_only=True)
csp_ro = build_policy(**asdict(policy_parts_ro), report_only=True)
if csp_ro:
# Only set header if not already set and not an excluded prefix and not exempted.
is_not_exempt = getattr(response, "_csp_exempt_ro", False) is False
Expand All @@ -74,15 +88,25 @@ def process_response(self, request: HttpRequest, response: HttpResponseBase) ->
return response

def build_policy(self, request: HttpRequest, response: HttpResponseBase) -> str:
config = getattr(response, "_csp_config", None)
update = getattr(response, "_csp_update", None)
replace = getattr(response, "_csp_replace", None)
nonce = getattr(request, "_csp_nonce", None)
return build_policy(config=config, update=update, replace=replace, nonce=nonce)
warnings.warn("deprecated in favor of get_policy_parts", DeprecationWarning)
policy_parts = self.get_policy_parts(request=request, response=response, report_only=False)
return build_policy(**asdict(policy_parts))

def build_policy_ro(self, request: HttpRequest, response: HttpResponseBase) -> str:
config = getattr(response, "_csp_config_ro", None)
update = getattr(response, "_csp_update_ro", None)
replace = getattr(response, "_csp_replace_ro", None)
warnings.warn("deprecated in favor of get_policy_parts", DeprecationWarning)
policy_parts_ro = self.get_policy_parts(request=request, response=response, report_only=True)
return build_policy(**asdict(policy_parts_ro), report_only=True)

def get_policy_parts(self, request: HttpRequest, response: HttpResponseBase, report_only: bool = False) -> PolicyParts:
if report_only:
config = getattr(response, "_csp_config_ro", None)
update = getattr(response, "_csp_update_ro", None)
replace = getattr(response, "_csp_replace_ro", None)
else:
config = getattr(response, "_csp_config", None)
update = getattr(response, "_csp_update", None)
replace = getattr(response, "_csp_replace", None)

nonce = getattr(request, "_csp_nonce", None)
return build_policy(config=config, update=update, replace=replace, nonce=nonce, report_only=True)

return PolicyParts(config, update, replace, nonce)

0 comments on commit c5237c1

Please sign in to comment.