Skip to content

Commit

Permalink
refactor: BI-5948 move body signature middlewares to dl_api_commons (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
MCPN authored Nov 29, 2024
1 parent 3c418f8 commit 62ce5bf
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 67 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from aiohttp import web
from aiohttp.typedefs import Handler

from dl_api_commons.aio.typing import AIOHTTPMiddleware
from dl_api_commons.crypto import get_hmac_hex_digest


def body_signature_validation_middleware(hmac_key: bytes, header: str) -> AIOHTTPMiddleware:
@web.middleware
async def actual_middleware(request: web.Request, handler: Handler) -> web.StreamResponse:
if not hmac_key: # do not consider an empty hmac key as valid.
raise Exception("body_signature_validation_middleware: no hmac_key.")

if request.method in ("HEAD", "OPTIONS", "GET"):
return await handler(request)

body_bytes = await request.read()
expected_signature = get_hmac_hex_digest(body_bytes, secret_key=hmac_key)
signature_str_from_header = request.headers.get(header)

if expected_signature != signature_str_from_header:
raise web.HTTPForbidden(reason="Invalid signature")

return await handler(request)

return actual_middleware
6 changes: 6 additions & 0 deletions lib/dl_api_commons/dl_api_commons/crypto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import hashlib
import hmac


def get_hmac_hex_digest(target: bytes, secret_key: bytes) -> str:
return hmac.new(key=secret_key, msg=target, digestmod=hashlib.sha256).hexdigest()
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import attr
import flask
from werkzeug.exceptions import Forbidden

from dl_api_commons.crypto import get_hmac_hex_digest


@attr.s
class BodySignatureValidator:
hmac_key: bytes = attr.ib()
header: str = attr.ib()

def validate_request_body(self) -> None:
if flask.request.method in ("HEAD", "OPTIONS", "GET"): # no body to validate.
return

# For import-test reasons, can't verify this when getting it;
# but allowing requests when the key is empty is too dangerous.
if not self.hmac_key:
raise Exception("validate_request_body: no hmac_key")

body_bytes = flask.request.get_data()
expected_signature = get_hmac_hex_digest(body_bytes, secret_key=self.hmac_key)
signature_str_from_header = flask.request.headers.get(self.header)

if expected_signature != signature_str_from_header:
raise Forbidden("Invalid signature")

def set_up(self, app: flask.Flask) -> None:
app.before_request(self.validate_request_body)
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from aiohttp.client_exceptions import ServerTimeoutError
import attr

from dl_api_commons.crypto import get_hmac_hex_digest
from dl_api_commons.headers import (
HEADER_LOGGING_CONTEXT,
INTERNAL_HEADER_PROFILING_STACK,
Expand Down Expand Up @@ -64,7 +65,6 @@
ResponseTypes,
)
from dl_core.connection_executors.qe_serializer import dba_actions as dba_actions
from dl_core.connection_executors.remote_query_executor.crypto import get_hmac_hex_digest
from dl_core.connection_models.conn_options import ConnectOptions
from dl_core.enums import RQEEventType
from dl_dashsql.typed_query.primitives import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@

import aiodns
from aiohttp import web
from aiohttp.typedefs import Handler

from dl_api_commons.aio.middlewares.body_signature import body_signature_validation_middleware
from dl_api_commons.aio.middlewares.request_bootstrap import RequestBootstrap
from dl_api_commons.aio.middlewares.request_id import RequestId
from dl_api_commons.aio.server_header import ServerHeader
from dl_api_commons.aio.typing import AIOHTTPMiddleware
from dl_app_tools.profiling_base import GenericProfiler
from dl_configs.env_var_definitions import (
jaeger_service_name_env_aware,
Expand Down Expand Up @@ -52,7 +51,6 @@
DEFAULT_CHUNK_SIZE,
SUPPORTED_ADAPTER_CLS,
)
from dl_core.connection_executors.remote_query_executor.crypto import get_hmac_hex_digest
from dl_core.connection_executors.remote_query_executor.error_handler_rqe import RQEErrorHandler
from dl_core.connection_executors.remote_query_executor.settings import RQESettings
from dl_core.enums import RQEEventType
Expand Down Expand Up @@ -282,27 +280,6 @@ async def post(self) -> Union[web.Response, web.StreamResponse]:
raise NotImplementedError(f"Action {action} is not implemented in QE")


def body_signature_validation_middleware(hmac_key: bytes) -> AIOHTTPMiddleware:
@web.middleware
async def actual_middleware(request: web.Request, handler: Handler) -> web.StreamResponse:
if not hmac_key: # do not consider an empty hmac key as valid.
raise Exception("body_signature_validation_middleware: no hmac_key.")

if request.method in ("HEAD", "OPTIONS", "GET"):
return await handler(request)

body_bytes = await request.read()
expected_signature = get_hmac_hex_digest(body_bytes, secret_key=hmac_key)
signature_str_from_header = request.headers.get(HEADER_BODY_SIGNATURE)

if expected_signature != signature_str_from_header:
raise web.HTTPForbidden(reason="Invalid signature")

return await handler(request)

return actual_middleware


def create_async_qe_app(hmac_key: bytes, forbid_private_addr: bool = False) -> web.Application:
req_id_service = RequestId(
header_name=HEADER_REQUEST_ID,
Expand All @@ -318,7 +295,7 @@ def create_async_qe_app(hmac_key: bytes, forbid_private_addr: bool = False) -> w
error_handler=error_handler,
).middleware,
# TODO FIX: Add profiling middleware.
body_signature_validation_middleware(hmac_key=hmac_key),
body_signature_validation_middleware(hmac_key=hmac_key, header=HEADER_BODY_SIGNATURE),
]
)
app.on_response_prepare.append(req_id_service.on_response_prepare)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,12 @@
Union,
)

import attr
from flask import current_app
import flask.views
from werkzeug.exceptions import (
Forbidden,
HTTPException,
)
from werkzeug.exceptions import HTTPException

from dl_api_commons.flask.middlewares.aio_event_loop_middleware import AIOEventLoopMiddleware
from dl_api_commons.flask.middlewares.body_signature import BodySignatureValidator
from dl_api_commons.flask.middlewares.context_var_middleware import ContextVarMiddleware
from dl_api_commons.flask.middlewares.logging_context import RequestLoggingContextControllerMiddleWare
from dl_api_commons.flask.middlewares.request_id import RequestIDService
Expand All @@ -49,7 +46,6 @@
DEFAULT_CHUNK_SIZE,
SUPPORTED_ADAPTER_CLS,
)
from dl_core.connection_executors.remote_query_executor.crypto import get_hmac_hex_digest
from dl_core.connection_executors.remote_query_executor.settings import RQESettings
from dl_core.enums import RQEEventType
from dl_core.exc import SourceTimeout
Expand Down Expand Up @@ -262,30 +258,6 @@ def dispatch_request(self) -> flask.Response:
raise NotImplementedError(f"Action {action} is not implemented in QE")


@attr.s
class BodySignatureValidator:
hmac_key: bytes = attr.ib()

def validate_request_body(self) -> None:
if flask.request.method in ("HEAD", "OPTIONS", "GET"): # no body to validate.
return

# For import-test reasons, can't verify this when getting it;
# but allowing requests when the key is empty is too dangerous.
if not self.hmac_key:
raise Exception("validate_request_body: no hmac_key")

body_bytes = flask.request.get_data()
expected_signature = get_hmac_hex_digest(body_bytes, secret_key=self.hmac_key)
signature_str_from_header = flask.request.headers.get(HEADER_BODY_SIGNATURE)

if expected_signature != signature_str_from_header:
raise Forbidden("Invalid signature")

def set_up(self, app: flask.Flask) -> None:
app.before_request(self.validate_request_body)


def ping_view() -> flask.Response:
return flask.jsonify(dict(result="PONG"))

Expand Down Expand Up @@ -343,6 +315,7 @@ def create_sync_app() -> flask.Flask:
profiling_middleware.set_up(app, accept_outer_stages=True)
BodySignatureValidator(
hmac_key=hmac_key.encode(),
header=HEADER_BODY_SIGNATURE,
).set_up(app)

app.config["forbid_private_addr"] = settings.FORBID_PRIVATE_ADDRESSES
Expand Down

This file was deleted.

0 comments on commit 62ce5bf

Please sign in to comment.