From 006258fb7784a1ca5186e8882853ba7edcc6efdc Mon Sep 17 00:00:00 2001 From: Ingvar Stepanyan Date: Thu, 23 Feb 2023 17:27:11 +0000 Subject: [PATCH] Fix handling of missing Accept header Workaround for https://github.com/tokio-rs/axum/issues/1781. --- src/axum.rs | 42 +++++++++++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/src/axum.rs b/src/axum.rs index f283581..444e4e2 100644 --- a/src/axum.rs +++ b/src/axum.rs @@ -3,29 +3,43 @@ use crate::params::OpaqueParams; use crate::response::OpaqueResponse; use crate::transaction::server_handler; use crate::Devices; -use axum::extract::Path; +use async_trait::async_trait; +use axum::extract::{FromRequest, Path}; +use axum::headers::{Header, HeaderName, HeaderValue}; use axum::http::Method; use axum::routing::{on, MethodFilter}; -use axum::{Form, Router, TypedHeader}; +use axum::{Form, Router}; use futures::StreamExt; use mediatype::MediaTypeList; use std::sync::Arc; // A hack until TypedHeader supports Accept natively. +#[derive(Default)] struct AcceptsImageBytes { accepts: bool, } -impl axum::headers::Header for AcceptsImageBytes { - fn name() -> &'static axum::headers::HeaderName { - static ACCEPT: axum::headers::HeaderName = axum::headers::HeaderName::from_static("accept"); +#[async_trait] +impl FromRequest for AcceptsImageBytes { + type Rejection = std::convert::Infallible; + + async fn from_request( + req: &mut axum::extract::RequestParts, + ) -> Result { + Ok(Self::decode(&mut req.headers().get_all("accept").into_iter()).unwrap_or_default()) + } +} + +impl Header for AcceptsImageBytes { + fn name() -> &'static HeaderName { + static ACCEPT: HeaderName = HeaderName::from_static("accept"); &ACCEPT } fn decode<'value, I>(values: &mut I) -> Result where Self: Sized, - I: Iterator, + I: Iterator, { let mut accepts = false; for value in values { @@ -47,14 +61,12 @@ impl axum::headers::Header for AcceptsImageBytes { Ok(Self { accepts }) } - fn encode>(&self, values: &mut E) { - values.extend(std::iter::once(axum::http::HeaderValue::from_static( - if self.accepts { - "application/imagebytes" - } else { - "*/*" - }, - ))); + fn encode>(&self, values: &mut E) { + values.extend(std::iter::once(HeaderValue::from_static(if self.accepts { + "application/imagebytes" + } else { + "*/*" + }))); } } @@ -99,7 +111,7 @@ impl Devices { usize, String, )>, - TypedHeader(accepts_image_bytes): TypedHeader, + accepts_image_bytes: AcceptsImageBytes, Form(params): Form| { let is_mut = method == Method::PUT;