1use super::{rejection::*, FromRequestParts};
2use http::{request::Parts, Uri};
3use serde::de::DeserializeOwned;
4
/// Extractor that deserializes the request's query string into some type `T`.
///
/// Extraction requires `T: serde::de::DeserializeOwned`. A URI with no query
/// component is treated as having an empty query string, so a `T` whose fields
/// are all `Option`s still extracts successfully.
///
/// If the query string cannot be deserialized into `T`, extraction fails with a
/// `QueryRejection`, which is turned into a `400 Bad Request` response.
#[cfg_attr(docsrs, doc(cfg(feature = "query")))]
#[derive(Clone, Copy, Debug, Default)]
pub struct Query<T>(pub T);
52
53impl<T, S> FromRequestParts<S> for Query<T>
54where
55    T: DeserializeOwned,
56    S: Send + Sync,
57{
58    type Rejection = QueryRejection;
59
60    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
61        Self::try_from_uri(&parts.uri)
62    }
63}
64
65impl<T> Query<T>
66where
67    T: DeserializeOwned,
68{
69    pub fn try_from_uri(value: &Uri) -> Result<Self, QueryRejection> {
89        let query = value.query().unwrap_or_default();
90        let deserializer =
91            serde_urlencoded::Deserializer::new(form_urlencoded::parse(query.as_bytes()));
92        let params = serde_path_to_error::deserialize(deserializer)
93            .map_err(FailedToDeserializeQueryString::from_err)?;
94        Ok(Query(params))
95    }
96}
97
// Lets a `Query<T>` be dereferenced to the wrapped `T`, so the inner value's
// fields can be accessed directly (e.g. `result.foo` in `test_try_from_uri`).
axum_core::__impl_deref!(Query);
99
#[cfg(test)]
mod tests {
    use crate::{routing::get, test_helpers::TestClient, Router};

    use super::*;
    use axum_core::{body::Body, extract::FromRequest};
    use http::{Request, StatusCode};
    use serde::Deserialize;
    use std::fmt::Debug;

    /// Runs the `Query` extractor against a request for `uri` and asserts that the
    /// extracted value equals `value`.
    async fn check<T>(uri: impl AsRef<str>, value: T)
    where
        T: DeserializeOwned + PartialEq + Debug,
    {
        let req = Request::builder()
            .uri(uri.as_ref())
            .body(Body::empty())
            .unwrap();
        assert_eq!(Query::<T>::from_request(req, &()).await.unwrap().0, value);
    }

    #[crate::test]
    async fn test_query() {
        #[derive(Debug, PartialEq, Deserialize)]
        struct Pagination {
            size: Option<u64>,
            page: Option<u64>,
        }

        // No query string at all: every optional field is `None`.
        check(
            "http://example.com/test",
            Pagination {
                size: None,
                page: None,
            },
        )
        .await;

        check(
            "http://example.com/test?size=10",
            Pagination {
                size: Some(10),
                page: None,
            },
        )
        .await;

        check(
            "http://example.com/test?size=10&page=20",
            Pagination {
                size: Some(10),
                page: Some(20),
            },
        )
        .await;
    }

    #[crate::test]
    async fn correct_rejection_status_code() {
        #[derive(Deserialize)]
        #[allow(dead_code)]
        struct Params {
            n: i32,
        }

        async fn handler(_: Query<Params>) {}

        let app = Router::new().route("/", get(handler));
        let client = TestClient::new(app);

        let res = client.get("/?n=hi").await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        assert_eq!(
            res.text().await,
            "Failed to deserialize query string: n: invalid digit found in string"
        );
    }

    #[test]
    fn test_try_from_uri() {
        #[derive(Deserialize)]
        struct TestQueryParams {
            foo: String,
            bar: u32,
        }
        let uri: Uri = "http://example.com/path?foo=hello&bar=42".parse().unwrap();
        let result: Query<TestQueryParams> = Query::try_from_uri(&uri).unwrap();
        assert_eq!(result.foo, String::from("hello"));
        assert_eq!(result.bar, 42);
    }

    #[test]
    fn test_try_from_uri_with_invalid_query() {
        // Field names must match the query keys. With mismatched names (such as
        // `_foo`/`_bar`) deserialization fails on a *missing field* instead, and the
        // invalid `bar` value would never be exercised.
        #[derive(Deserialize)]
        #[allow(dead_code)] // fields are only populated by the deserializer
        struct TestQueryParams {
            foo: String,
            bar: u32,
        }
        let uri: Uri = "http://example.com/path?foo=hello&bar=invalid"
            .parse()
            .unwrap();
        let Err(rejection) = Query::<TestQueryParams>::try_from_uri(&uri) else {
            panic!("`bar=invalid` should fail to deserialize into a `u32`");
        };

        // Pin that the failure is attributed to `bar`, not to some other field.
        assert_eq!(
            rejection.body_text(),
            "Failed to deserialize query string: bar: invalid digit found in string"
        );
    }
}