1use crate::extract::Request;
2use crate::extract::{rejection::*, FromRequest, RawForm};
3use axum_core::response::{IntoResponse, Response};
4use axum_core::RequestExt;
5use http::header::CONTENT_TYPE;
6use http::StatusCode;
7use serde::de::DeserializeOwned;
8use serde::Serialize;
9
/// URL encoded extractor and response.
///
/// # As extractor
///
/// `Form<T>` implements [`FromRequest`]. It reads the raw form data through [`RawForm`]
/// and deserializes it into any `T: DeserializeOwned` using `serde_urlencoded`. Because it
/// takes ownership of the whole request, it must be the last extractor in a handler's
/// argument list.
///
/// When deserialization fails, the rejection depends on the request method:
///
/// - `GET` and `HEAD` requests are rejected with [`FailedToDeserializeForm`]
///   (`400 Bad Request`).
/// - All other methods are rejected with [`FailedToDeserializeFormBody`]
///   (`422 Unprocessable Entity`).
///
/// In both cases the message includes the path of the offending field. Failures while
/// reading the raw form, such as an unexpected `Content-Type` header, are forwarded as the
/// corresponding [`FormRejection`] variant.
///
/// # As response
///
/// `Form<T>` implements `IntoResponse` for any `T: Serialize`. The value is encoded with
/// `serde_urlencoded` and sent with `Content-Type: application/x-www-form-urlencoded`. If
/// encoding fails, a `500 Internal Server Error` response is produced instead, with the
/// serializer's error message as its body.
#[cfg_attr(docsrs, doc(cfg(feature = "form")))]
#[derive(Debug, Clone, Copy, Default)]
#[must_use]
pub struct Form<T>(pub T);
73
74impl<T, S> FromRequest<S> for Form<T>
75where
76    T: DeserializeOwned,
77    S: Send + Sync,
78{
79    type Rejection = FormRejection;
80
81    async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
82        let is_get_or_head =
83            req.method() == http::Method::GET || req.method() == http::Method::HEAD;
84
85        match req.extract().await {
86            Ok(RawForm(bytes)) => {
87                let deserializer =
88                    serde_urlencoded::Deserializer::new(form_urlencoded::parse(&bytes));
89                let value = serde_path_to_error::deserialize(deserializer).map_err(
90                    |err| -> FormRejection {
91                        if is_get_or_head {
92                            FailedToDeserializeForm::from_err(err).into()
93                        } else {
94                            FailedToDeserializeFormBody::from_err(err).into()
95                        }
96                    },
97                )?;
98                Ok(Form(value))
99            }
100            Err(RawFormRejection::BytesRejection(r)) => Err(FormRejection::BytesRejection(r)),
101            Err(RawFormRejection::InvalidFormContentType(r)) => {
102                Err(FormRejection::InvalidFormContentType(r))
103            }
104        }
105    }
106}
107
108impl<T> IntoResponse for Form<T>
109where
110    T: Serialize,
111{
112    fn into_response(self) -> Response {
113        fn make_response(ser_result: Result<String, serde_urlencoded::ser::Error>) -> Response {
115            match ser_result {
116                Ok(body) => (
117                    [(CONTENT_TYPE, mime::APPLICATION_WWW_FORM_URLENCODED.as_ref())],
118                    body,
119                )
120                    .into_response(),
121                Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(),
122            }
123        }
124
125        make_response(serde_urlencoded::to_string(&self.0))
126    }
127}
// Lets `Form<T>` be used like the value it wraps. The `axum_core::__impl_deref!` helper is
// expected to generate `Deref`/`DerefMut` impls targeting `T` (expansion not visible here).
axum_core::__impl_deref!(Form);
129
#[cfg(test)]
mod tests {
    use crate::{
        routing::{on, MethodFilter},
        test_helpers::TestClient,
        Router,
    };

    use super::*;
    use axum_core::body::Body;
    use http::{Method, Request};
    use mime::APPLICATION_WWW_FORM_URLENCODED;
    use serde::{Deserialize, Serialize};
    use std::fmt::Debug;

    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct Pagination {
        size: Option<u64>,
        page: Option<u64>,
    }

    // Extracts `Form<T>` from a bodiless request to `uri` (builder default method, `GET`)
    // and asserts the result equals `value`.
    async fn check_query<T: DeserializeOwned + PartialEq + Debug>(uri: impl AsRef<str>, value: T) {
        let req = Request::builder()
            .uri(uri.as_ref())
            .body(Body::empty())
            .unwrap();
        assert_eq!(Form::<T>::from_request(req, &()).await.unwrap().0, value);
    }

    // Round-trips `value`: URL-encodes it into a form-typed POST body, extracts it back
    // via `Form<T>`, and asserts the two are equal.
    async fn check_body<T: Serialize + DeserializeOwned + PartialEq + Debug>(value: T) {
        let req = Request::builder()
            .uri("http://example.com/test")
            .method(Method::POST)
            .header(CONTENT_TYPE, APPLICATION_WWW_FORM_URLENCODED.as_ref())
            .body(Body::from(serde_urlencoded::to_string(&value).unwrap()))
            .unwrap();
        assert_eq!(Form::<T>::from_request(req, &()).await.unwrap().0, value);
    }

    #[crate::test]
    async fn test_form_query() {
        check_query(
            "http://example.com/test",
            Pagination {
                size: None,
                page: None,
            },
        )
        .await;

        check_query(
            "http://example.com/test?size=10",
            Pagination {
                size: Some(10),
                page: None,
            },
        )
        .await;

        check_query(
            "http://example.com/test?size=10&page=20",
            Pagination {
                size: Some(10),
                page: Some(20),
            },
        )
        .await;
    }

    #[crate::test]
    async fn test_form_body() {
        check_body(Pagination {
            size: None,
            page: None,
        })
        .await;

        check_body(Pagination {
            size: Some(10),
            page: None,
        })
        .await;

        check_body(Pagination {
            size: Some(10),
            page: Some(20),
        })
        .await;
    }

    // A POST whose `Content-Type` is JSON must be rejected before any deserialization.
    #[crate::test]
    async fn test_incorrect_content_type() {
        let req = Request::builder()
            .uri("http://example.com/test")
            .method(Method::POST)
            .header(CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
            .body(Body::from(
                serde_urlencoded::to_string(&Pagination {
                    size: Some(10),
                    page: None,
                })
                .unwrap(),
            ))
            .unwrap();
        assert!(matches!(
            Form::<Pagination>::from_request(req, &())
                .await
                .unwrap_err(),
            FormRejection::InvalidFormContentType(InvalidFormContentType)
        ));
    }

    // Deserialization failures map to 400 for GET and to 422 for POST.
    #[crate::test]
    async fn deserialize_error_status_codes() {
        // `a` is only deserialized (to trigger the error), never read, hence `dead_code`.
        #[allow(dead_code)]
        #[derive(Deserialize)]
        struct Payload {
            a: i32,
        }

        let app = Router::new().route(
            "/",
            on(
                MethodFilter::GET.or(MethodFilter::POST),
                |_: Form<Payload>| async {},
            ),
        );

        let client = TestClient::new(app);

        // GET: the error surfaces as `FailedToDeserializeForm`.
        let res = client.get("/?a=false").await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        assert_eq!(
            res.text().await,
            "Failed to deserialize form: a: invalid digit found in string"
        );

        // POST: the error surfaces as `FailedToDeserializeFormBody`.
        let res = client
            .post("/")
            .header(CONTENT_TYPE, APPLICATION_WWW_FORM_URLENCODED.as_ref())
            .body("a=false")
            .await;
        assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY);
        assert_eq!(
            res.text().await,
            "Failed to deserialize form body: a: invalid digit found in string"
        );
    }
}