1use crate::response::{IntoResponse, Response};
2use axum_core::extract::FromRequestParts;
3use futures_util::future::BoxFuture;
4use http::Request;
5use std::{
6    any::type_name,
7    convert::Infallible,
8    fmt,
9    future::Future,
10    marker::PhantomData,
11    pin::Pin,
12    task::{Context, Poll},
13};
14use tower_layer::Layer;
15use tower_service::Service;
16
17pub fn map_response<F, T>(f: F) -> MapResponseLayer<F, (), T> {
100    map_response_with_state((), f)
101}
102
103pub fn map_response_with_state<F, S, T>(state: S, f: F) -> MapResponseLayer<F, S, T> {
142    MapResponseLayer {
143        f,
144        state,
145        _extractor: PhantomData,
146    }
147}
148
/// A [`Layer`] built from an async function that transforms responses.
///
/// Created with [`map_response`] or [`map_response_with_state`].
#[must_use]
pub struct MapResponseLayer<F, S, T> {
    // The response-mapping function; cloned into each service this layer builds.
    f: F,
    // State handed to extractors; cloned into each service this layer builds.
    state: S,
    // Records the extractor tuple type `T` without storing a value. Using
    // `fn() -> T` rather than `T` means the layer does not act as if it owns a
    // `T`, so auto traits such as `Send`/`Sync` do not depend on `T`.
    _extractor: PhantomData<fn() -> T>,
}
158
159impl<F, S, T> Clone for MapResponseLayer<F, S, T>
160where
161    F: Clone,
162    S: Clone,
163{
164    fn clone(&self) -> Self {
165        Self {
166            f: self.f.clone(),
167            state: self.state.clone(),
168            _extractor: self._extractor,
169        }
170    }
171}
172
173impl<S, I, F, T> Layer<I> for MapResponseLayer<F, S, T>
174where
175    F: Clone,
176    S: Clone,
177{
178    type Service = MapResponse<F, S, I, T>;
179
180    fn layer(&self, inner: I) -> Self::Service {
181        MapResponse {
182            f: self.f.clone(),
183            state: self.state.clone(),
184            inner,
185            _extractor: PhantomData,
186        }
187    }
188}
189
190impl<F, S, T> fmt::Debug for MapResponseLayer<F, S, T>
191where
192    S: fmt::Debug,
193{
194    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195        f.debug_struct("MapResponseLayer")
196            .field("f", &format_args!("{}", type_name::<F>()))
198            .field("state", &self.state)
199            .finish()
200    }
201}
202
/// A middleware [`Service`] created from an async function that transforms responses.
///
/// Produced by applying a [`MapResponseLayer`] to an inner service.
pub struct MapResponse<F, S, I, T> {
    // Called with the extracted values and the inner service's response.
    f: F,
    // The wrapped service whose response is transformed.
    inner: I,
    // State handed to each extractor's `from_request_parts` call.
    state: S,
    // Records the extractor tuple type `T` without storing a value.
    _extractor: PhantomData<fn() -> T>,
}
212
213impl<F, S, I, T> Clone for MapResponse<F, S, I, T>
214where
215    F: Clone,
216    I: Clone,
217    S: Clone,
218{
219    fn clone(&self) -> Self {
220        Self {
221            f: self.f.clone(),
222            inner: self.inner.clone(),
223            state: self.state.clone(),
224            _extractor: self._extractor,
225        }
226    }
227}
228
// Generates `Service<Request<B>>` for `MapResponse<F, S, I, (T1, ..., Tn)>`.
// Each `Ti` is an extractor implementing `FromRequestParts<S>`, and `F` is
// called as `f(t1, ..., tn, response)`. It is invoked once per tuple arity.
macro_rules! impl_service {
    (
        $($ty:ident),*
    ) => {
        // `non_snake_case`: each extracted value is bound to a local named after
        // its type parameter (`let T1 = ...`). `unused_mut`: with zero
        // extractors, `parts` is never borrowed mutably.
        #[allow(non_snake_case, unused_mut)]
        impl<F, Fut, S, I, B, ResBody, $($ty,)*> Service<Request<B>> for MapResponse<F, S, I, ($($ty,)*)>
        where
            F: FnMut($($ty,)* Response<ResBody>) -> Fut + Clone + Send + 'static,
            $( $ty: FromRequestParts<S> + Send, )*
            Fut: Future + Send + 'static,
            Fut::Output: IntoResponse + Send + 'static,
            I: Service<Request<B>, Response = Response<ResBody>, Error = Infallible>
                + Clone
                + Send
                + 'static,
            I::Future: Send + 'static,
            B: Send + 'static,
            ResBody: Send + 'static,
            S: Clone + Send + Sync + 'static,
        {
            type Response = Response;
            type Error = Infallible;
            type Future = ResponseFuture;

            // Readiness is delegated entirely to the inner service.
            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                self.inner.poll_ready(cx)
            }


            fn call(&mut self, req: Request<B>) -> Self::Future {
                // Assuming the caller honoured the `Service` contract, `self.inner`
                // is the instance that `poll_ready` just reported ready. Move that
                // instance into the future and leave a fresh clone in its place.
                let not_ready_inner = self.inner.clone();
                let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);

                let mut f = self.f.clone();
                // Underscore-prefixed because with zero extractors it is never read.
                let _state = self.state.clone();
                let (mut parts, body) = req.into_parts();

                let future = Box::pin(async move {
                    // Run the extractors in order against the request parts. The first
                    // rejection short-circuits: it becomes the response, and neither
                    // the inner service nor `f` is called.
                    $(
                        let $ty = match $ty::from_request_parts(&mut parts, &_state).await {
                            Ok(value) => value,
                            Err(rejection) => return rejection.into_response(),
                        };
                    )*

                    // Reassemble the request. Extractors received `&mut parts`, so
                    // any changes they made are carried through.
                    let req = Request::from_parts(parts, body);

                    match ready_inner.call(req).await {
                        Ok(res) => {
                            // Hand the extracted values and the inner response to `f`.
                            f($($ty,)* res).await.into_response()
                        }
                        // The inner error type is `Infallible`, which has no values,
                        // so this arm can never run and the empty match type-checks.
                        Err(err) => match err {}
                    }
                });

                ResponseFuture {
                    inner: future
                }
            }
        }
    };
}
291
// Implement `Service` for `MapResponse` with zero to sixteen extractors.
// Each extractor-tuple arity needs its own impl.
impl_service!();
impl_service!(T1);
impl_service!(T1, T2);
impl_service!(T1, T2, T3);
impl_service!(T1, T2, T3, T4);
impl_service!(T1, T2, T3, T4, T5);
impl_service!(T1, T2, T3, T4, T5, T6);
impl_service!(T1, T2, T3, T4, T5, T6, T7);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
309
310impl<F, S, I, T> fmt::Debug for MapResponse<F, S, I, T>
311where
312    S: fmt::Debug,
313    I: fmt::Debug,
314{
315    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
316        f.debug_struct("MapResponse")
317            .field("f", &format_args!("{}", type_name::<F>()))
318            .field("inner", &self.inner)
319            .field("state", &self.state)
320            .finish()
321    }
322}
323
/// Response future for [`MapResponse`].
pub struct ResponseFuture {
    // Boxed future that runs the extractors, then the inner service, then the
    // mapping function, and yields the final response.
    inner: BoxFuture<'static, Response>,
}
328
329impl Future for ResponseFuture {
330    type Output = Result<Response, Infallible>;
331
332    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
333        self.inner.as_mut().poll(cx).map(Ok)
334    }
335}
336
337impl fmt::Debug for ResponseFuture {
338    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
339        f.debug_struct("ResponseFuture").finish()
340    }
341}
342
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use crate::{test_helpers::TestClient, Router};

    // A plain async fn that takes only the response (no extractors) can be
    // used as a layer. The header it inserts shows up on the response.
    #[crate::test]
    async fn works() {
        async fn tag_response<B>(mut response: Response<B>) -> Response<B> {
            response
                .headers_mut()
                .insert("x-foo", "foo".parse().unwrap());
            response
        }

        let app = Router::new().layer(map_response(tag_response));
        let client = TestClient::new(app);

        let response = client.get("/").await;
        assert_eq!(response.headers()["x-foo"], "foo");
    }
}