tower_http/request_id.rs
1//! Set and propagate request ids.
2//!
3//! # Example
4//!
5//! ```
6//! use http::{Request, Response, header::HeaderName};
7//! use tower::{Service, ServiceExt, ServiceBuilder};
8//! use tower_http::request_id::{
9//! SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId,
10//! };
11//! use http_body_util::Full;
12//! use bytes::Bytes;
13//! use std::sync::{Arc, atomic::{AtomicU64, Ordering}};
14//!
15//! # #[tokio::main]
16//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
17//! # let handler = tower::service_fn(|request: Request<Full<Bytes>>| async move {
18//! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
19//! # });
20//! #
21//! // A `MakeRequestId` that increments an atomic counter
22//! #[derive(Clone, Default)]
23//! struct MyMakeRequestId {
24//! counter: Arc<AtomicU64>,
25//! }
26//!
27//! impl MakeRequestId for MyMakeRequestId {
28//! fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId> {
29//! let request_id = self.counter
30//! .fetch_add(1, Ordering::SeqCst)
31//! .to_string()
32//! .parse()
33//! .unwrap();
34//!
35//! Some(RequestId::new(request_id))
36//! }
37//! }
38//!
39//! let x_request_id = HeaderName::from_static("x-request-id");
40//!
41//! let mut svc = ServiceBuilder::new()
42//! // set `x-request-id` header on all requests
43//! .layer(SetRequestIdLayer::new(
44//! x_request_id.clone(),
45//! MyMakeRequestId::default(),
46//! ))
47//! // propagate `x-request-id` headers from request to response
48//! .layer(PropagateRequestIdLayer::new(x_request_id))
49//! .service(handler);
50//!
51//! let request = Request::new(Full::default());
52//! let response = svc.ready().await?.call(request).await?;
53//!
54//! assert_eq!(response.headers()["x-request-id"], "0");
55//! #
56//! # Ok(())
57//! # }
58//! ```
59//!
60//! Additional convenience methods are available on [`ServiceBuilderExt`]:
61//!
62//! ```
63//! use tower_http::ServiceBuilderExt;
64//! # use http::{Request, Response, header::HeaderName};
65//! # use tower::{Service, ServiceExt, ServiceBuilder};
66//! # use tower_http::request_id::{
67//! # SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId,
68//! # };
69//! # use bytes::Bytes;
70//! # use http_body_util::Full;
71//! # use std::sync::{Arc, atomic::{AtomicU64, Ordering}};
72//! # #[tokio::main]
73//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
74//! # let handler = tower::service_fn(|request: Request<Full<Bytes>>| async move {
75//! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
76//! # });
77//! # #[derive(Clone, Default)]
78//! # struct MyMakeRequestId {
79//! # counter: Arc<AtomicU64>,
80//! # }
81//! # impl MakeRequestId for MyMakeRequestId {
82//! # fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId> {
83//! # let request_id = self.counter
84//! # .fetch_add(1, Ordering::SeqCst)
85//! # .to_string()
86//! # .parse()
87//! # .unwrap();
88//! # Some(RequestId::new(request_id))
89//! # }
90//! # }
91//!
92//! let mut svc = ServiceBuilder::new()
93//! .set_x_request_id(MyMakeRequestId::default())
94//! .propagate_x_request_id()
95//! .service(handler);
96//!
97//! let request = Request::new(Full::default());
98//! let response = svc.ready().await?.call(request).await?;
99//!
100//! assert_eq!(response.headers()["x-request-id"], "0");
101//! #
102//! # Ok(())
103//! # }
104//! ```
105//!
106//! See [`SetRequestId`] and [`PropagateRequestId`] for more details.
107//!
108//! # Using `Trace`
109//!
//! To have request ids show up correctly in logs produced by [`Trace`], you must apply the layers
//! in this order:
112//!
113//! ```
114//! use tower_http::{
115//! ServiceBuilderExt,
116//! trace::{TraceLayer, DefaultMakeSpan, DefaultOnResponse},
117//! };
118//! # use http::{Request, Response, header::HeaderName};
119//! # use tower::{Service, ServiceExt, ServiceBuilder};
120//! # use tower_http::request_id::{
121//! # SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId,
122//! # };
123//! # use http_body_util::Full;
124//! # use bytes::Bytes;
125//! # use std::sync::{Arc, atomic::{AtomicU64, Ordering}};
126//! # #[tokio::main]
127//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
128//! # let handler = tower::service_fn(|request: Request<Full<Bytes>>| async move {
129//! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
130//! # });
131//! # #[derive(Clone, Default)]
132//! # struct MyMakeRequestId {
133//! # counter: Arc<AtomicU64>,
134//! # }
135//! # impl MakeRequestId for MyMakeRequestId {
136//! # fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId> {
137//! # let request_id = self.counter
138//! # .fetch_add(1, Ordering::SeqCst)
139//! # .to_string()
140//! # .parse()
141//! # .unwrap();
142//! # Some(RequestId::new(request_id))
143//! # }
144//! # }
145//!
146//! let svc = ServiceBuilder::new()
147//! // make sure to set request ids before the request reaches `TraceLayer`
148//! .set_x_request_id(MyMakeRequestId::default())
149//! // log requests and responses
150//! .layer(
151//! TraceLayer::new_for_http()
152//! .make_span_with(DefaultMakeSpan::new().include_headers(true))
153//! .on_response(DefaultOnResponse::new().include_headers(true))
154//! )
155//! // propagate the header to the response before the response reaches `TraceLayer`
156//! .propagate_x_request_id()
157//! .service(handler);
158//! #
159//! # Ok(())
160//! # }
161//! ```
162//!
163//! # Doesn't override existing headers
164//!
//! [`SetRequestId`] and [`PropagateRequestId`] won't override request ids that are already present
//! on requests or responses. Among other things, this allows other middleware to conditionally set
//! request ids and use the middleware in this module as a fallback.
168//!
169//! [`ServiceBuilderExt`]: crate::ServiceBuilderExt
170//! [`Uuid`]: https://crates.io/crates/uuid
171//! [`Trace`]: crate::trace::Trace
172
173use http::{
174 header::{HeaderName, HeaderValue},
175 Request, Response,
176};
177use pin_project_lite::pin_project;
178use std::task::{ready, Context, Poll};
179use std::{future::Future, pin::Pin};
180use tower_layer::Layer;
181use tower_service::Service;
182use uuid::Uuid;
183
/// The `x-request-id` header name.
///
/// This is the header used by the `x_request_id` convenience constructors in this module.
pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
185
/// Trait for producing [`RequestId`]s.
///
/// Used by [`SetRequestId`], which only calls
/// [`make_request_id`](MakeRequestId::make_request_id) when the incoming request does not already
/// carry the configured request id header.
pub trait MakeRequestId {
    /// Try and produce a [`RequestId`] from the request.
    ///
    /// Returning `None` means no id is attached: [`SetRequestId`] then inserts neither the header
    /// nor the [`RequestId`] extension, and forwards the request unchanged.
    fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId>;
}
193
194/// An identifier for a request.
195#[derive(Debug, Clone)]
196pub struct RequestId(HeaderValue);
197
198impl RequestId {
199 /// Create a new `RequestId` from a [`HeaderValue`].
200 pub fn new(header_value: HeaderValue) -> Self {
201 Self(header_value)
202 }
203
204 /// Gets a reference to the underlying [`HeaderValue`].
205 pub fn header_value(&self) -> &HeaderValue {
206 &self.0
207 }
208
209 /// Consumes `self`, returning the underlying [`HeaderValue`].
210 pub fn into_header_value(self) -> HeaderValue {
211 self.0
212 }
213}
214
215impl From<HeaderValue> for RequestId {
216 fn from(value: HeaderValue) -> Self {
217 Self::new(value)
218 }
219}
220
221/// Set request id headers and extensions on requests.
222///
223/// This layer applies the [`SetRequestId`] middleware.
224///
225/// See the [module docs](self) and [`SetRequestId`] for more details.
226#[derive(Debug, Clone)]
227pub struct SetRequestIdLayer<M> {
228 header_name: HeaderName,
229 make_request_id: M,
230}
231
232impl<M> SetRequestIdLayer<M> {
233 /// Create a new `SetRequestIdLayer`.
234 pub fn new(header_name: HeaderName, make_request_id: M) -> Self
235 where
236 M: MakeRequestId,
237 {
238 SetRequestIdLayer {
239 header_name,
240 make_request_id,
241 }
242 }
243
244 /// Create a new `SetRequestIdLayer` that uses `x-request-id` as the header name.
245 pub fn x_request_id(make_request_id: M) -> Self
246 where
247 M: MakeRequestId,
248 {
249 SetRequestIdLayer::new(X_REQUEST_ID, make_request_id)
250 }
251}
252
253impl<S, M> Layer<S> for SetRequestIdLayer<M>
254where
255 M: Clone + MakeRequestId,
256{
257 type Service = SetRequestId<S, M>;
258
259 fn layer(&self, inner: S) -> Self::Service {
260 SetRequestId::new(
261 inner,
262 self.header_name.clone(),
263 self.make_request_id.clone(),
264 )
265 }
266}
267
268/// Set request id headers and extensions on requests.
269///
270/// See the [module docs](self) for an example.
271///
272/// If [`MakeRequestId::make_request_id`] returns `Some(_)` and the request doesn't already have a
273/// header with the same name, then the header will be inserted.
274///
275/// Additionally [`RequestId`] will be inserted into [`Request::extensions`] so other
276/// services can access it.
277#[derive(Debug, Clone)]
278pub struct SetRequestId<S, M> {
279 inner: S,
280 header_name: HeaderName,
281 make_request_id: M,
282}
283
284impl<S, M> SetRequestId<S, M> {
285 /// Create a new `SetRequestId`.
286 pub fn new(inner: S, header_name: HeaderName, make_request_id: M) -> Self
287 where
288 M: MakeRequestId,
289 {
290 Self {
291 inner,
292 header_name,
293 make_request_id,
294 }
295 }
296
297 /// Create a new `SetRequestId` that uses `x-request-id` as the header name.
298 pub fn x_request_id(inner: S, make_request_id: M) -> Self
299 where
300 M: MakeRequestId,
301 {
302 Self::new(inner, X_REQUEST_ID, make_request_id)
303 }
304
305 define_inner_service_accessors!();
306
307 /// Returns a new [`Layer`] that wraps services with a `SetRequestId` middleware.
308 pub fn layer(header_name: HeaderName, make_request_id: M) -> SetRequestIdLayer<M>
309 where
310 M: MakeRequestId,
311 {
312 SetRequestIdLayer::new(header_name, make_request_id)
313 }
314}
315
316impl<S, M, ReqBody, ResBody> Service<Request<ReqBody>> for SetRequestId<S, M>
317where
318 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
319 M: MakeRequestId,
320{
321 type Response = S::Response;
322 type Error = S::Error;
323 type Future = S::Future;
324
325 #[inline]
326 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
327 self.inner.poll_ready(cx)
328 }
329
330 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
331 if let Some(request_id) = req.headers().get(&self.header_name) {
332 if req.extensions().get::<RequestId>().is_none() {
333 let request_id = request_id.clone();
334 req.extensions_mut().insert(RequestId::new(request_id));
335 }
336 } else if let Some(request_id) = self.make_request_id.make_request_id(&req) {
337 req.extensions_mut().insert(request_id.clone());
338 req.headers_mut()
339 .insert(self.header_name.clone(), request_id.0);
340 }
341
342 self.inner.call(req)
343 }
344}
345
346/// Propagate request ids from requests to responses.
347///
348/// This layer applies the [`PropagateRequestId`] middleware.
349///
350/// See the [module docs](self) and [`PropagateRequestId`] for more details.
351#[derive(Debug, Clone)]
352pub struct PropagateRequestIdLayer {
353 header_name: HeaderName,
354}
355
356impl PropagateRequestIdLayer {
357 /// Create a new `PropagateRequestIdLayer`.
358 pub fn new(header_name: HeaderName) -> Self {
359 PropagateRequestIdLayer { header_name }
360 }
361
362 /// Create a new `PropagateRequestIdLayer` that uses `x-request-id` as the header name.
363 pub fn x_request_id() -> Self {
364 Self::new(X_REQUEST_ID)
365 }
366}
367
368impl<S> Layer<S> for PropagateRequestIdLayer {
369 type Service = PropagateRequestId<S>;
370
371 fn layer(&self, inner: S) -> Self::Service {
372 PropagateRequestId::new(inner, self.header_name.clone())
373 }
374}
375
376/// Propagate request ids from requests to responses.
377///
378/// See the [module docs](self) for an example.
379///
380/// If the request contains a matching header that header will be applied to responses. If a
381/// [`RequestId`] extension is also present it will be propagated as well.
382#[derive(Debug, Clone)]
383pub struct PropagateRequestId<S> {
384 inner: S,
385 header_name: HeaderName,
386}
387
388impl<S> PropagateRequestId<S> {
389 /// Create a new `PropagateRequestId`.
390 pub fn new(inner: S, header_name: HeaderName) -> Self {
391 Self { inner, header_name }
392 }
393
394 /// Create a new `PropagateRequestId` that uses `x-request-id` as the header name.
395 pub fn x_request_id(inner: S) -> Self {
396 Self::new(inner, X_REQUEST_ID)
397 }
398
399 define_inner_service_accessors!();
400
401 /// Returns a new [`Layer`] that wraps services with a `PropagateRequestId` middleware.
402 pub fn layer(header_name: HeaderName) -> PropagateRequestIdLayer {
403 PropagateRequestIdLayer::new(header_name)
404 }
405}
406
407impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for PropagateRequestId<S>
408where
409 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
410{
411 type Response = S::Response;
412 type Error = S::Error;
413 type Future = PropagateRequestIdResponseFuture<S::Future>;
414
415 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
416 self.inner.poll_ready(cx)
417 }
418
419 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
420 let request_id = req
421 .headers()
422 .get(&self.header_name)
423 .cloned()
424 .map(RequestId::new);
425
426 PropagateRequestIdResponseFuture {
427 inner: self.inner.call(req),
428 header_name: self.header_name.clone(),
429 request_id,
430 }
431 }
432}
433
pin_project! {
    /// Response future for [`PropagateRequestId`].
    ///
    /// Resolves to the inner service's response. Before resolving, it fills in the request id
    /// header and/or [`RequestId`] extension where they are missing and a value is available.
    pub struct PropagateRequestIdResponseFuture<F> {
        // The inner service's response future; structurally pinned so it can be polled in place.
        #[pin]
        inner: F,
        // Name of the request id header to look for / write on the response.
        header_name: HeaderName,
        // Id captured from the request's header, if it had one. `take`n out when it is copied
        // onto a response that lacks the header.
        request_id: Option<RequestId>,
    }
}
443
444impl<F, B, E> Future for PropagateRequestIdResponseFuture<F>
445where
446 F: Future<Output = Result<Response<B>, E>>,
447{
448 type Output = Result<Response<B>, E>;
449
450 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
451 let this = self.project();
452 let mut response = ready!(this.inner.poll(cx))?;
453
454 if let Some(current_id) = response.headers().get(&*this.header_name) {
455 if response.extensions().get::<RequestId>().is_none() {
456 let current_id = current_id.clone();
457 response.extensions_mut().insert(RequestId::new(current_id));
458 }
459 } else if let Some(request_id) = this.request_id.take() {
460 response
461 .headers_mut()
462 .insert(this.header_name.clone(), request_id.0.clone());
463 response.extensions_mut().insert(request_id);
464 }
465
466 Poll::Ready(Ok(response))
467 }
468}
469
470/// A [`MakeRequestId`] that generates `UUID`s.
471#[derive(Clone, Copy, Default)]
472pub struct MakeRequestUuid;
473
474impl MakeRequestId for MakeRequestUuid {
475 fn make_request_id<B>(&mut self, _request: &Request<B>) -> Option<RequestId> {
476 let request_id = Uuid::new_v4().to_string().parse().unwrap();
477 Some(RequestId::new(request_id))
478 }
479}
480
#[cfg(test)]
mod tests {
    use crate::test_helpers::Body;
    use crate::ServiceBuilderExt as _;
    use http::Response;
    use std::{
        convert::Infallible,
        sync::{
            atomic::{AtomicU64, Ordering},
            Arc,
        },
    };
    use tower::{ServiceBuilder, ServiceExt};

    // Brings the module under test (`RequestId`, `MakeRequestId`, `HeaderName`, `Uuid`, ...) into
    // scope.
    #[allow(unused_imports)]
    use super::*;

    /// A request with no headers and an empty body.
    fn empty_request() -> Request<Body> {
        Request::builder().body(Body::empty()).unwrap()
    }

    /// A request that already carries `x-request-id: <value>`.
    fn request_with_id(value: &str) -> Request<Body> {
        Request::builder()
            .header("x-request-id", value)
            .body(Body::empty())
            .unwrap()
    }

    /// `MakeRequestId` that hands out "0", "1", "2", ... from a shared counter.
    #[derive(Clone, Default)]
    struct Counter(Arc<AtomicU64>);

    impl MakeRequestId for Counter {
        fn make_request_id<B>(&mut self, _request: &Request<B>) -> Option<RequestId> {
            let next = self.0.fetch_add(1, Ordering::SeqCst);
            let value = HeaderValue::from_str(&next.to_string()).unwrap();
            Some(RequestId::new(value))
        }
    }

    async fn handler(_: Request<Body>) -> Result<Response<Body>, Infallible> {
        Ok(Response::new(Body::empty()))
    }

    #[tokio::test]
    async fn basic() {
        let svc = ServiceBuilder::new()
            .set_x_request_id(Counter::default())
            .propagate_x_request_id()
            .service_fn(handler);

        // generated ids are sequential and show up on the response
        for expected in ["0", "1"] {
            let res = svc.clone().oneshot(empty_request()).await.unwrap();
            assert_eq!(res.headers()["x-request-id"], expected);
        }

        // an id already present on the request wins
        let res = svc.clone().oneshot(request_with_id("foo")).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "foo");

        // the extension makes it to the response as well
        let res = svc.clone().oneshot(empty_request()).await.unwrap();
        assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "2");
    }

    #[tokio::test]
    async fn other_middleware_setting_request_id() {
        let svc = ServiceBuilder::new()
            .override_request_header(
                HeaderName::from_static("x-request-id"),
                HeaderValue::from_str("foo").unwrap(),
            )
            .set_x_request_id(Counter::default())
            .map_request(|request: Request<_>| {
                // `set_x_request_id` should set the extension if it's missing
                assert_eq!(request.extensions().get::<RequestId>().unwrap().0, "foo");
                request
            })
            .propagate_x_request_id()
            .service_fn(handler);

        let req = request_with_id("this-will-be-overriden-by-override_request_header-middleware");
        let res = svc.oneshot(req).await.unwrap();

        assert_eq!(res.headers()["x-request-id"], "foo");
        assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "foo");
    }

    #[tokio::test]
    async fn other_middleware_setting_request_id_on_response() {
        let svc = ServiceBuilder::new()
            .set_x_request_id(Counter::default())
            .propagate_x_request_id()
            .override_response_header(
                HeaderName::from_static("x-request-id"),
                HeaderValue::from_str("foo").unwrap(),
            )
            .service_fn(handler);

        let res = svc.oneshot(request_with_id("foo")).await.unwrap();

        assert_eq!(res.headers()["x-request-id"], "foo");
        assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "foo");
    }

    #[tokio::test]
    async fn uuid() {
        let svc = ServiceBuilder::new()
            .set_x_request_id(MakeRequestUuid)
            .propagate_x_request_id()
            .service_fn(handler);

        // the generated header on the response must parse back as a UUID
        let res = svc.oneshot(empty_request()).await.unwrap();
        let id = res.headers()["x-request-id"].to_str().unwrap();
        id.parse::<Uuid>().unwrap();
    }
}