axum/extract/
multipart.rs1use super::{FromRequest, Request};
6use crate::body::Bytes;
7use axum_core::{
8 __composite_rejection as composite_rejection, __define_rejection as define_rejection,
9 response::{IntoResponse, Response},
10 RequestExt,
11};
12use futures_util::stream::Stream;
13use http::{
14 header::{HeaderMap, CONTENT_TYPE},
15 StatusCode,
16};
17use std::{
18 error::Error,
19 fmt,
20 pin::Pin,
21 task::{Context, Poll},
22};
23
24#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
62#[derive(Debug)]
63pub struct Multipart {
64 inner: multer::Multipart<'static>,
65}
66
67impl<S> FromRequest<S> for Multipart
68where
69 S: Send + Sync,
70{
71 type Rejection = MultipartRejection;
72
73 async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
74 let boundary = parse_boundary(req.headers()).ok_or(InvalidBoundary)?;
75 let stream = req.with_limited_body().into_body();
76 let multipart = multer::Multipart::new(stream.into_data_stream(), boundary);
77 Ok(Self { inner: multipart })
78 }
79}
80
81impl Multipart {
82 pub async fn next_field(&mut self) -> Result<Option<Field<'_>>, MultipartError> {
84 let field = self
85 .inner
86 .next_field()
87 .await
88 .map_err(MultipartError::from_multer)?;
89
90 if let Some(field) = field {
91 Ok(Some(Field {
92 inner: field,
93 _multipart: self,
94 }))
95 } else {
96 Ok(None)
97 }
98 }
99}
100
101#[derive(Debug)]
103pub struct Field<'a> {
104 inner: multer::Field<'static>,
105 _multipart: &'a mut Multipart,
108}
109
110impl Stream for Field<'_> {
111 type Item = Result<Bytes, MultipartError>;
112
113 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
114 Pin::new(&mut self.inner)
115 .poll_next(cx)
116 .map_err(MultipartError::from_multer)
117 }
118}
119
120impl Field<'_> {
121 pub fn name(&self) -> Option<&str> {
125 self.inner.name()
126 }
127
128 pub fn file_name(&self) -> Option<&str> {
132 self.inner.file_name()
133 }
134
135 pub fn content_type(&self) -> Option<&str> {
137 self.inner.content_type().map(|m| m.as_ref())
138 }
139
140 pub fn headers(&self) -> &HeaderMap {
142 self.inner.headers()
143 }
144
145 pub async fn bytes(self) -> Result<Bytes, MultipartError> {
147 self.inner
148 .bytes()
149 .await
150 .map_err(MultipartError::from_multer)
151 }
152
153 pub async fn text(self) -> Result<String, MultipartError> {
155 self.inner.text().await.map_err(MultipartError::from_multer)
156 }
157
158 pub async fn chunk(&mut self) -> Result<Option<Bytes>, MultipartError> {
197 self.inner
198 .chunk()
199 .await
200 .map_err(MultipartError::from_multer)
201 }
202}
203
204#[derive(Debug)]
206pub struct MultipartError {
207 source: multer::Error,
208}
209
210impl MultipartError {
211 fn from_multer(multer: multer::Error) -> Self {
212 Self { source: multer }
213 }
214
215 pub fn body_text(&self) -> String {
217 self.source.to_string()
218 }
219
220 pub fn status(&self) -> http::StatusCode {
222 status_code_from_multer_error(&self.source)
223 }
224}
225
226fn status_code_from_multer_error(err: &multer::Error) -> StatusCode {
227 match err {
228 multer::Error::UnknownField { .. }
229 | multer::Error::IncompleteFieldData { .. }
230 | multer::Error::IncompleteHeaders
231 | multer::Error::ReadHeaderFailed(..)
232 | multer::Error::DecodeHeaderName { .. }
233 | multer::Error::DecodeContentType(..)
234 | multer::Error::NoBoundary
235 | multer::Error::DecodeHeaderValue { .. }
236 | multer::Error::NoMultipart
237 | multer::Error::IncompleteStream => StatusCode::BAD_REQUEST,
238 multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => {
239 StatusCode::PAYLOAD_TOO_LARGE
240 }
241 multer::Error::StreamReadFailed(err) => {
242 if let Some(err) = err.downcast_ref::<multer::Error>() {
243 return status_code_from_multer_error(err);
244 }
245
246 if err
247 .downcast_ref::<crate::Error>()
248 .and_then(|err| err.source())
249 .and_then(|err| err.downcast_ref::<http_body_util::LengthLimitError>())
250 .is_some()
251 {
252 return StatusCode::PAYLOAD_TOO_LARGE;
253 }
254
255 StatusCode::INTERNAL_SERVER_ERROR
256 }
257 _ => StatusCode::INTERNAL_SERVER_ERROR,
258 }
259}
260
261impl fmt::Display for MultipartError {
262 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
263 write!(f, "Error parsing `multipart/form-data` request")
264 }
265}
266
267impl std::error::Error for MultipartError {
268 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
269 Some(&self.source)
270 }
271}
272
273impl IntoResponse for MultipartError {
274 fn into_response(self) -> Response {
275 let body = self.body_text();
276 axum_core::__log_rejection!(
277 rejection_type = Self,
278 body_text = body,
279 status = self.status(),
280 );
281 (self.status(), body).into_response()
282 }
283}
284
285fn parse_boundary(headers: &HeaderMap) -> Option<String> {
286 let content_type = headers.get(CONTENT_TYPE)?.to_str().ok()?;
287 multer::parse_boundary(content_type).ok()
288}
289
290composite_rejection! {
291 pub enum MultipartRejection {
295 InvalidBoundary,
296 }
297}
298
299define_rejection! {
300 #[status = BAD_REQUEST]
301 #[body = "Invalid `boundary` for `multipart/form-data` request"]
302 pub struct InvalidBoundary;
305}
306
#[cfg(test)]
mod tests {
    use axum_core::extract::DefaultBodyLimit;

    use super::*;
    use crate::{routing::post, test_helpers::*, Router};

    #[crate::test]
    async fn content_type_with_encoding() {
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();
        const FILE_NAME: &str = "index.html";
        const CONTENT_TYPE: &str = "text/html; charset=utf-8";

        async fn handle(mut multipart: Multipart) -> impl IntoResponse {
            let field = multipart.next_field().await.unwrap().unwrap();

            assert_eq!(field.file_name().unwrap(), FILE_NAME);
            assert_eq!(field.content_type().unwrap(), CONTENT_TYPE);
            assert_eq!(field.headers()["foo"], "bar");
            assert_eq!(field.bytes().await.unwrap(), BYTES);

            assert!(multipart.next_field().await.unwrap().is_none());
        }

        let app = Router::new().route("/", post(handle));

        let client = TestClient::new(app);

        let form = reqwest::multipart::Form::new().part(
            "file",
            reqwest::multipart::Part::bytes(BYTES)
                .file_name(FILE_NAME)
                .mime_str(CONTENT_TYPE)
                .unwrap()
                .headers(reqwest::header::HeaderMap::from_iter([(
                    reqwest::header::HeaderName::from_static("foo"),
                    reqwest::header::HeaderValue::from_static("bar"),
                )])),
        );

        // The handler performs the real assertions; a 200 confirms it ran to completion.
        let res = client.post("/").multipart(form).await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    /// Compile-only check: a `Multipart` handler can be routed behind tower-http's
    /// `RequestBodyLimitLayer`. Never called.
    fn _multipart_from_request_limited() {
        async fn handler(_: Multipart) {}
        let _app: Router = Router::new()
            .route("/", post(handler))
            .layer(tower_http::limit::RequestBodyLimitLayer::new(1024));
    }

    #[crate::test]
    async fn body_too_large() {
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();

        async fn handle(mut multipart: Multipart) -> Result<(), MultipartError> {
            while let Some(field) = multipart.next_field().await? {
                field.bytes().await?;
            }
            Ok(())
        }

        let app = Router::new()
            .route("/", post(handle))
            .layer(DefaultBodyLimit::max(BYTES.len() - 1));

        let client = TestClient::new(app);

        let form =
            reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES));

        let res = client.post("/").multipart(form).await;
        assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }
}