axum/extract/
multipart.rs

//! Extractor that parses `multipart/form-data` requests, commonly used for file uploads.
//!
//! See [`Multipart`] for more details.
4
5use super::{FromRequest, Request};
6use crate::body::Bytes;
7use axum_core::{
8    __composite_rejection as composite_rejection, __define_rejection as define_rejection,
9    response::{IntoResponse, Response},
10    RequestExt,
11};
12use futures_util::stream::Stream;
13use http::{
14    header::{HeaderMap, CONTENT_TYPE},
15    StatusCode,
16};
17use std::{
18    error::Error,
19    fmt,
20    pin::Pin,
21    task::{Context, Poll},
22};
23
24/// Extractor that parses `multipart/form-data` requests (commonly used with file uploads).
25///
26/// ⚠️ Since extracting multipart form data from the request requires consuming the body, the
27/// `Multipart` extractor must be *last* if there are multiple extractors in a handler.
28/// See ["the order of extractors"][order-of-extractors]
29///
30/// [order-of-extractors]: crate::extract#the-order-of-extractors
31///
32/// # Example
33///
34/// ```rust,no_run
35/// use axum::{
36///     extract::Multipart,
37///     routing::post,
38///     Router,
39/// };
40/// use futures_util::stream::StreamExt;
41///
42/// async fn upload(mut multipart: Multipart) {
43///     while let Some(mut field) = multipart.next_field().await.unwrap() {
44///         let name = field.name().unwrap().to_string();
45///         let data = field.bytes().await.unwrap();
46///
47///         println!("Length of `{}` is {} bytes", name, data.len());
48///     }
49/// }
50///
51/// let app = Router::new().route("/upload", post(upload));
52/// # let _: Router = app;
53/// ```
54///
55/// # Large Files
56///
57/// For security reasons, by default, `Multipart` limits the request body size to 2MB.
58/// See [`DefaultBodyLimit`][default-body-limit] for how to configure this limit.
59///
60/// [default-body-limit]: crate::extract::DefaultBodyLimit
61#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
62#[derive(Debug)]
63pub struct Multipart {
64    inner: multer::Multipart<'static>,
65}
66
67impl<S> FromRequest<S> for Multipart
68where
69    S: Send + Sync,
70{
71    type Rejection = MultipartRejection;
72
73    async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
74        let boundary = parse_boundary(req.headers()).ok_or(InvalidBoundary)?;
75        let stream = req.with_limited_body().into_body();
76        let multipart = multer::Multipart::new(stream.into_data_stream(), boundary);
77        Ok(Self { inner: multipart })
78    }
79}
80
81impl Multipart {
82    /// Yields the next [`Field`] if available.
83    pub async fn next_field(&mut self) -> Result<Option<Field<'_>>, MultipartError> {
84        let field = self
85            .inner
86            .next_field()
87            .await
88            .map_err(MultipartError::from_multer)?;
89
90        if let Some(field) = field {
91            Ok(Some(Field {
92                inner: field,
93                _multipart: self,
94            }))
95        } else {
96            Ok(None)
97        }
98    }
99}
100
/// A single field in a multipart stream.
///
/// Values of this type are produced by [`Multipart::next_field`]. The field data can be
/// read in one go with [`Field::bytes`] or [`Field::text`], or incrementally with
/// [`Field::chunk`] or the [`Stream`] implementation.
#[derive(Debug)]
pub struct Field<'a> {
    // The underlying `multer` field that every accessor delegates to.
    inner: multer::Field<'static>,
    // multer requires there to only be one live `multer::Field` at any point. This enforces that
    // statically, which multer does not do, it returns an error instead.
    _multipart: &'a mut Multipart,
}
109
110impl Stream for Field<'_> {
111    type Item = Result<Bytes, MultipartError>;
112
113    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
114        Pin::new(&mut self.inner)
115            .poll_next(cx)
116            .map_err(MultipartError::from_multer)
117    }
118}
119
120impl Field<'_> {
121    /// The field name found in the
122    /// [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition)
123    /// header.
124    pub fn name(&self) -> Option<&str> {
125        self.inner.name()
126    }
127
128    /// The file name found in the
129    /// [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition)
130    /// header.
131    pub fn file_name(&self) -> Option<&str> {
132        self.inner.file_name()
133    }
134
135    /// Get the [content type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type) of the field.
136    pub fn content_type(&self) -> Option<&str> {
137        self.inner.content_type().map(|m| m.as_ref())
138    }
139
140    /// Get a map of headers as [`HeaderMap`].
141    pub fn headers(&self) -> &HeaderMap {
142        self.inner.headers()
143    }
144
145    /// Get the full data of the field as [`Bytes`].
146    pub async fn bytes(self) -> Result<Bytes, MultipartError> {
147        self.inner
148            .bytes()
149            .await
150            .map_err(MultipartError::from_multer)
151    }
152
153    /// Get the full field data as text.
154    pub async fn text(self) -> Result<String, MultipartError> {
155        self.inner.text().await.map_err(MultipartError::from_multer)
156    }
157
158    /// Stream a chunk of the field data.
159    ///
160    /// When the field data has been exhausted, this will return [`None`].
161    ///
162    /// Note this does the same thing as `Field`'s [`Stream`] implementation.
163    ///
164    /// # Example
165    ///
166    /// ```
167    /// use axum::{
168    ///    extract::Multipart,
169    ///    routing::post,
170    ///    response::IntoResponse,
171    ///    http::StatusCode,
172    ///    Router,
173    /// };
174    ///
175    /// async fn upload(mut multipart: Multipart) -> Result<(), (StatusCode, String)> {
176    ///     while let Some(mut field) = multipart
177    ///         .next_field()
178    ///         .await
179    ///         .map_err(|err| (StatusCode::BAD_REQUEST, err.to_string()))?
180    ///     {
181    ///         while let Some(chunk) = field
182    ///             .chunk()
183    ///             .await
184    ///             .map_err(|err| (StatusCode::BAD_REQUEST, err.to_string()))?
185    ///         {
186    ///             println!("received {} bytes", chunk.len());
187    ///         }
188    ///     }
189    ///
190    ///     Ok(())
191    /// }
192    ///
193    /// let app = Router::new().route("/upload", post(upload));
194    /// # let _: Router = app;
195    /// ```
196    pub async fn chunk(&mut self) -> Result<Option<Bytes>, MultipartError> {
197        self.inner
198            .chunk()
199            .await
200            .map_err(MultipartError::from_multer)
201    }
202}
203
/// Errors associated with parsing `multipart/form-data` requests.
///
/// Wraps the error reported by `multer`. The wrapped error is available through
/// [`std::error::Error::source`], and its message through [`MultipartError::body_text`].
#[derive(Debug)]
pub struct MultipartError {
    // The `multer` error that caused the failure.
    source: multer::Error,
}
209
210impl MultipartError {
211    fn from_multer(multer: multer::Error) -> Self {
212        Self { source: multer }
213    }
214
215    /// Get the response body text used for this rejection.
216    pub fn body_text(&self) -> String {
217        self.source.to_string()
218    }
219
220    /// Get the status code used for this rejection.
221    pub fn status(&self) -> http::StatusCode {
222        status_code_from_multer_error(&self.source)
223    }
224}
225
226fn status_code_from_multer_error(err: &multer::Error) -> StatusCode {
227    match err {
228        multer::Error::UnknownField { .. }
229        | multer::Error::IncompleteFieldData { .. }
230        | multer::Error::IncompleteHeaders
231        | multer::Error::ReadHeaderFailed(..)
232        | multer::Error::DecodeHeaderName { .. }
233        | multer::Error::DecodeContentType(..)
234        | multer::Error::NoBoundary
235        | multer::Error::DecodeHeaderValue { .. }
236        | multer::Error::NoMultipart
237        | multer::Error::IncompleteStream => StatusCode::BAD_REQUEST,
238        multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => {
239            StatusCode::PAYLOAD_TOO_LARGE
240        }
241        multer::Error::StreamReadFailed(err) => {
242            if let Some(err) = err.downcast_ref::<multer::Error>() {
243                return status_code_from_multer_error(err);
244            }
245
246            if err
247                .downcast_ref::<crate::Error>()
248                .and_then(|err| err.source())
249                .and_then(|err| err.downcast_ref::<http_body_util::LengthLimitError>())
250                .is_some()
251            {
252                return StatusCode::PAYLOAD_TOO_LARGE;
253            }
254
255            StatusCode::INTERNAL_SERVER_ERROR
256        }
257        _ => StatusCode::INTERNAL_SERVER_ERROR,
258    }
259}
260
261impl fmt::Display for MultipartError {
262    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
263        write!(f, "Error parsing `multipart/form-data` request")
264    }
265}
266
267impl std::error::Error for MultipartError {
268    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
269        Some(&self.source)
270    }
271}
272
273impl IntoResponse for MultipartError {
274    fn into_response(self) -> Response {
275        let body = self.body_text();
276        axum_core::__log_rejection!(
277            rejection_type = Self,
278            body_text = body,
279            status = self.status(),
280        );
281        (self.status(), body).into_response()
282    }
283}
284
285fn parse_boundary(headers: &HeaderMap) -> Option<String> {
286    let content_type = headers.get(CONTENT_TYPE)?.to_str().ok()?;
287    multer::parse_boundary(content_type).ok()
288}
289
composite_rejection! {
    /// Rejection used for [`Multipart`].
    ///
    /// Contains one variant for each way the [`Multipart`] extractor can fail.
    /// Failures that occur later, while reading fields, are reported as `MultipartError`
    /// rather than through this type.
    pub enum MultipartRejection {
        InvalidBoundary,
    }
}
298
define_rejection! {
    #[status = BAD_REQUEST]
    #[body = "Invalid `boundary` for `multipart/form-data` request"]
    /// Rejection type used if the `boundary` in a `multipart/form-data` request is
    /// missing or invalid.
    ///
    /// This is produced when the `Content-Type` header is absent, cannot be read as a
    /// string, or `multer` cannot parse a boundary from it.
    pub struct InvalidBoundary;
}
306
#[cfg(test)]
mod tests {
    use axum_core::extract::DefaultBodyLimit;

    use super::*;
    use crate::{routing::post, test_helpers::*, Router};

    /// The file name, content type (including its `charset` parameter), extra headers and
    /// payload of an uploaded part must all be visible on the extracted field.
    #[crate::test]
    async fn content_type_with_encoding() {
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();
        const FILE_NAME: &str = "index.html";
        const CONTENT_TYPE: &str = "text/html; charset=utf-8";

        async fn handle(mut multipart: Multipart) -> impl IntoResponse {
            let first = multipart.next_field().await.unwrap();
            let field = first.unwrap();

            assert_eq!(field.file_name().unwrap(), FILE_NAME);
            assert_eq!(field.content_type().unwrap(), CONTENT_TYPE);
            assert_eq!(field.headers()["foo"], "bar");
            assert_eq!(field.bytes().await.unwrap(), BYTES);

            // Only a single part was sent, so the form must now be exhausted.
            let after_last = multipart.next_field().await.unwrap();
            assert!(after_last.is_none());
        }

        let client = TestClient::new(Router::new().route("/", post(handle)));

        let extra_headers = reqwest::header::HeaderMap::from_iter([(
            reqwest::header::HeaderName::from_static("foo"),
            reqwest::header::HeaderValue::from_static("bar"),
        )]);
        let part = reqwest::multipart::Part::bytes(BYTES)
            .file_name(FILE_NAME)
            .mime_str(CONTENT_TYPE)
            .unwrap()
            .headers(extra_headers);
        let form = reqwest::multipart::Form::new().part("file", part);

        client.post("/").multipart(form).await;
    }

    // No need for this to be a #[test], we just want to make sure it compiles
    fn _multipart_from_request_limited() {
        async fn handler(_: Multipart) {}

        let limit = tower_http::limit::RequestBodyLimitLayer::new(1024);
        let _app: Router = Router::new().route("/", post(handler)).layer(limit);
    }

    /// A body larger than the configured `DefaultBodyLimit` must be answered with
    /// `413 Payload Too Large`.
    #[crate::test]
    async fn body_too_large() {
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();

        async fn handle(mut multipart: Multipart) -> Result<(), MultipartError> {
            while let Some(part) = multipart.next_field().await? {
                part.bytes().await?;
            }
            Ok(())
        }

        // The limit is one byte smaller than the payload alone.
        let body_limit = DefaultBodyLimit::max(BYTES.len() - 1);
        let app = Router::new().route("/", post(handle)).layer(body_limit);

        let client = TestClient::new(app);

        let part = reqwest::multipart::Part::bytes(BYTES);
        let form = reqwest::multipart::Form::new().part("file", part);

        let res = client.post("/").multipart(form).await;
        assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }
}