1use axum_core::extract::{FromRequest, FromRequestParts, Request};
2use futures_util::future::BoxFuture;
3use std::{
4    any::type_name,
5    convert::Infallible,
6    fmt,
7    future::Future,
8    marker::PhantomData,
9    pin::Pin,
10    task::{Context, Poll},
11};
12use tower::util::BoxCloneSyncService;
13use tower_layer::Layer;
14use tower_service::Service;
15
16use crate::{
17    response::{IntoResponse, Response},
18    util::MapIntoResponse,
19};
20
21pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, (), T> {
115    from_fn_with_state((), f)
116}
117
118pub fn from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T> {
165    FromFnLayer {
166        f,
167        state,
168        _extractor: PhantomData,
169    }
170}
171
/// A [`Layer`] built from an async middleware function.
///
/// Created with [`from_fn`] or [`from_fn_with_state`]; applying it to a
/// service produces a [`FromFn`].
#[must_use]
pub struct FromFnLayer<F, S, T> {
    /// State handed to the middleware's extractors.
    state: S,
    /// The middleware function.
    f: F,
    /// Type-level record of the extractor tuple; `fn() -> T` keeps the marker
    /// free of ownership of a `T`.
    _extractor: PhantomData<fn() -> T>,
}
183
184impl<F, S, T> Clone for FromFnLayer<F, S, T>
185where
186    F: Clone,
187    S: Clone,
188{
189    fn clone(&self) -> Self {
190        Self {
191            f: self.f.clone(),
192            state: self.state.clone(),
193            _extractor: self._extractor,
194        }
195    }
196}
197
198impl<S, I, F, T> Layer<I> for FromFnLayer<F, S, T>
199where
200    F: Clone,
201    S: Clone,
202{
203    type Service = FromFn<F, S, I, T>;
204
205    fn layer(&self, inner: I) -> Self::Service {
206        FromFn {
207            f: self.f.clone(),
208            state: self.state.clone(),
209            inner,
210            _extractor: PhantomData,
211        }
212    }
213}
214
215impl<F, S, T> fmt::Debug for FromFnLayer<F, S, T>
216where
217    S: fmt::Debug,
218{
219    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
220        f.debug_struct("FromFnLayer")
221            .field("f", &format_args!("{}", type_name::<F>()))
223            .field("state", &self.state)
224            .finish()
225    }
226}
227
/// The middleware service produced by applying a [`FromFnLayer`].
///
/// Holds the middleware function `f`, the wrapped service `inner`, and the
/// `state` handed to extractors. `T` records the extractor tuple at the type
/// level only.
pub struct FromFn<F, S, I, T> {
    /// State handed to the middleware's extractors.
    state: S,
    /// The middleware function.
    f: F,
    /// The wrapped service, reached through [`Next`].
    inner: I,
    _extractor: PhantomData<fn() -> T>,
}
237
238impl<F, S, I, T> Clone for FromFn<F, S, I, T>
239where
240    F: Clone,
241    I: Clone,
242    S: Clone,
243{
244    fn clone(&self) -> Self {
245        Self {
246            f: self.f.clone(),
247            inner: self.inner.clone(),
248            state: self.state.clone(),
249            _extractor: self._extractor,
250        }
251    }
252}
253
// Implements `Service<Request>` for `FromFn` whose extractor tuple is
// `($($ty,)* $last,)`: each `$ty` is a `FromRequestParts` extractor run on the
// request head, and `$last` is a `FromRequest` extractor that receives the
// whole request (including the body).
macro_rules! impl_service {
    (
        [$($ty:ident),*], $last:ident
    ) => {
        // `non_snake_case`: extracted values are bound to locals named after
        // the type parameters (`let $ty = ...`).
        // `unused_mut`: when the `$ty` list is empty, `parts` is never
        // borrowed mutably.
        #[allow(non_snake_case, unused_mut)]
        impl<F, Fut, Out, S, I, $($ty,)* $last> Service<Request> for FromFn<F, S, I, ($($ty,)* $last,)>
        where
            F: FnMut($($ty,)* $last, Next) -> Fut + Clone + Send + 'static,
            $( $ty: FromRequestParts<S> + Send, )*
            $last: FromRequest<S> + Send,
            Fut: Future<Output = Out> + Send + 'static,
            Out: IntoResponse + 'static,
            I: Service<Request, Error = Infallible>
                + Clone
                + Send
                + Sync
                + 'static,
            I::Response: IntoResponse,
            I::Future: Send + 'static,
            S: Clone + Send + Sync + 'static,
        {
            type Response = Response;
            type Error = Infallible;
            type Future = ResponseFuture;

            // Readiness is delegated to the wrapped service.
            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                self.inner.poll_ready(cx)
            }

            fn call(&mut self, req: Request) -> Self::Future {
                // Move out the instance `poll_ready` drove to readiness and
                // leave a fresh clone in its place; only the ready instance
                // is handed to `Next`.
                let not_ready_inner = self.inner.clone();
                let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);

                let mut f = self.f.clone();
                let state = self.state.clone();
                let (mut parts, body) = req.into_parts();

                let future = Box::pin(async move {
                    // Run each head extractor in order. A rejection is turned
                    // into the response and returned before the middleware
                    // function (and so the inner service) is ever called.
                    $(
                        let $ty = match $ty::from_request_parts(&mut parts, &state).await {
                            Ok(value) => value,
                            Err(rejection) => return rejection.into_response(),
                        };
                    )*

                    // Reassemble the request so the last extractor can take
                    // ownership of it, body included.
                    let req = Request::from_parts(parts, body);

                    let $last = match $last::from_request(req, &state).await {
                        Ok(value) => value,
                        Err(rejection) => return rejection.into_response(),
                    };

                    // `MapIntoResponse` adapts `I::Response` to `Response` so
                    // the ready service fits `Next`'s boxed, type-erased field.
                    let inner = BoxCloneSyncService::new(MapIntoResponse::new(ready_inner));
                    let next = Next { inner };

                    f($($ty,)* $last, next).await.into_response()
                });

                ResponseFuture {
                    inner: future
                }
            }
        }
    };
}

// `all_the_tuples!` is defined elsewhere in the crate; presumably it invokes
// `impl_service!` once per supported number of extractors.
all_the_tuples!(impl_service);
321
322impl<F, S, I, T> fmt::Debug for FromFn<F, S, I, T>
323where
324    S: fmt::Debug,
325    I: fmt::Debug,
326{
327    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328        f.debug_struct("FromFnLayer")
329            .field("f", &format_args!("{}", type_name::<F>()))
330            .field("inner", &self.inner)
331            .field("state", &self.state)
332            .finish()
333    }
334}
335
336#[derive(Debug, Clone)]
338pub struct Next {
339    inner: BoxCloneSyncService<Request, Response, Infallible>,
340}
341
342impl Next {
343    pub async fn run(mut self, req: Request) -> Response {
345        match self.inner.call(req).await {
346            Ok(res) => res,
347            Err(err) => match err {},
348        }
349    }
350}
351
352impl Service<Request> for Next {
353    type Response = Response;
354    type Error = Infallible;
355    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
356
357    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
358        self.inner.poll_ready(cx)
359    }
360
361    fn call(&mut self, req: Request) -> Self::Future {
362        self.inner.call(req)
363    }
364}
365
366pub struct ResponseFuture {
368    inner: BoxFuture<'static, Response>,
369}
370
371impl Future for ResponseFuture {
372    type Output = Result<Response, Infallible>;
373
374    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
375        self.inner.as_mut().poll(cx).map(Ok)
376    }
377}
378
379impl fmt::Debug for ResponseFuture {
380    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
381        f.debug_struct("ResponseFuture").finish()
382    }
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{body::Body, routing::get, Router};
    use http::{HeaderMap, StatusCode};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    // A `from_fn` middleware can modify the request before the handler sees it.
    #[crate::test]
    async fn basic() {
        // Middleware: tag the request with a header, then continue the stack.
        async fn insert_header(mut req: Request, next: Next) -> impl IntoResponse {
            let tag: http::HeaderValue = "ok".parse().unwrap();
            req.headers_mut().insert("x-axum-test", tag);
            next.run(req).await
        }

        // Handler: echo back the header the middleware inserted.
        async fn handle(headers: HeaderMap) -> String {
            String::from(headers["x-axum-test"].to_str().unwrap())
        }

        let app = Router::new()
            .route("/", get(handle))
            .layer(from_fn(insert_header));

        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let collected = response.collect().await.unwrap();
        assert_eq!(&collected.to_bytes()[..], b"ok");
    }
}