tower_http/
request_id.rs

1//! Set and propagate request ids.
2//!
3//! # Example
4//!
5//! ```
6//! use http::{Request, Response, header::HeaderName};
7//! use tower::{Service, ServiceExt, ServiceBuilder};
8//! use tower_http::request_id::{
9//!     SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId,
10//! };
11//! use http_body_util::Full;
12//! use bytes::Bytes;
13//! use std::sync::{Arc, atomic::{AtomicU64, Ordering}};
14//!
15//! # #[tokio::main]
16//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
17//! # let handler = tower::service_fn(|request: Request<Full<Bytes>>| async move {
18//! #     Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
19//! # });
20//! #
21//! // A `MakeRequestId` that increments an atomic counter
22//! #[derive(Clone, Default)]
23//! struct MyMakeRequestId {
24//!     counter: Arc<AtomicU64>,
25//! }
26//!
27//! impl MakeRequestId for MyMakeRequestId {
28//!     fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId> {
29//!         let request_id = self.counter
30//!             .fetch_add(1, Ordering::SeqCst)
31//!             .to_string()
32//!             .parse()
33//!             .unwrap();
34//!
35//!         Some(RequestId::new(request_id))
36//!     }
37//! }
38//!
39//! let x_request_id = HeaderName::from_static("x-request-id");
40//!
41//! let mut svc = ServiceBuilder::new()
42//!     // set `x-request-id` header on all requests
43//!     .layer(SetRequestIdLayer::new(
44//!         x_request_id.clone(),
45//!         MyMakeRequestId::default(),
46//!     ))
47//!     // propagate `x-request-id` headers from request to response
48//!     .layer(PropagateRequestIdLayer::new(x_request_id))
49//!     .service(handler);
50//!
51//! let request = Request::new(Full::default());
52//! let response = svc.ready().await?.call(request).await?;
53//!
54//! assert_eq!(response.headers()["x-request-id"], "0");
55//! #
56//! # Ok(())
57//! # }
58//! ```
59//!
60//! Additional convenience methods are available on [`ServiceBuilderExt`]:
61//!
62//! ```
63//! use tower_http::ServiceBuilderExt;
64//! # use http::{Request, Response, header::HeaderName};
65//! # use tower::{Service, ServiceExt, ServiceBuilder};
66//! # use tower_http::request_id::{
67//! #     SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId,
68//! # };
69//! # use bytes::Bytes;
70//! # use http_body_util::Full;
71//! # use std::sync::{Arc, atomic::{AtomicU64, Ordering}};
72//! # #[tokio::main]
73//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
74//! # let handler = tower::service_fn(|request: Request<Full<Bytes>>| async move {
75//! #     Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
76//! # });
77//! # #[derive(Clone, Default)]
78//! # struct MyMakeRequestId {
79//! #     counter: Arc<AtomicU64>,
80//! # }
81//! # impl MakeRequestId for MyMakeRequestId {
82//! #     fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId> {
83//! #         let request_id = self.counter
84//! #             .fetch_add(1, Ordering::SeqCst)
85//! #             .to_string()
86//! #             .parse()
87//! #             .unwrap();
88//! #         Some(RequestId::new(request_id))
89//! #     }
90//! # }
91//!
92//! let mut svc = ServiceBuilder::new()
93//!     .set_x_request_id(MyMakeRequestId::default())
94//!     .propagate_x_request_id()
95//!     .service(handler);
96//!
97//! let request = Request::new(Full::default());
98//! let response = svc.ready().await?.call(request).await?;
99//!
100//! assert_eq!(response.headers()["x-request-id"], "0");
101//! #
102//! # Ok(())
103//! # }
104//! ```
105//!
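//! See [`SetRequestId`] and [`PropagateRequestId`] for more details. If the ids don't need to be
//! sequential, [`MakeRequestUuid`] is a ready-made [`MakeRequestId`] that generates random
//! [`Uuid`]s.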
107//!
108//! # Using `Trace`
109//!
110//! To have request ids show up correctly in logs produced by [`Trace`] you must apply the layers
111//! in this order:
112//!
113//! ```
114//! use tower_http::{
115//!     ServiceBuilderExt,
116//!     trace::{TraceLayer, DefaultMakeSpan, DefaultOnResponse},
117//! };
118//! # use http::{Request, Response, header::HeaderName};
119//! # use tower::{Service, ServiceExt, ServiceBuilder};
120//! # use tower_http::request_id::{
121//! #     SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId,
122//! # };
123//! # use http_body_util::Full;
124//! # use bytes::Bytes;
125//! # use std::sync::{Arc, atomic::{AtomicU64, Ordering}};
126//! # #[tokio::main]
127//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
128//! # let handler = tower::service_fn(|request: Request<Full<Bytes>>| async move {
129//! #     Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
130//! # });
131//! # #[derive(Clone, Default)]
132//! # struct MyMakeRequestId {
133//! #     counter: Arc<AtomicU64>,
134//! # }
135//! # impl MakeRequestId for MyMakeRequestId {
136//! #     fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId> {
137//! #         let request_id = self.counter
138//! #             .fetch_add(1, Ordering::SeqCst)
139//! #             .to_string()
140//! #             .parse()
141//! #             .unwrap();
142//! #         Some(RequestId::new(request_id))
143//! #     }
144//! # }
145//!
146//! let svc = ServiceBuilder::new()
147//!     // make sure to set request ids before the request reaches `TraceLayer`
148//!     .set_x_request_id(MyMakeRequestId::default())
149//!     // log requests and responses
150//!     .layer(
151//!         TraceLayer::new_for_http()
152//!             .make_span_with(DefaultMakeSpan::new().include_headers(true))
153//!             .on_response(DefaultOnResponse::new().include_headers(true))
154//!     )
155//!     // propagate the header to the response before the response reaches `TraceLayer`
156//!     .propagate_x_request_id()
157//!     .service(handler);
158//! #
159//! # Ok(())
160//! # }
161//! ```
162//!
163//! # Doesn't override existing headers
164//!
//! [`SetRequestId`] and [`PropagateRequestId`] won't override request ids if they're already
//! present on requests or responses. Among other things, this allows other middleware to
//! conditionally set request ids and use the middleware in this module as a fallback.
168//!
169//! [`ServiceBuilderExt`]: crate::ServiceBuilderExt
170//! [`Uuid`]: https://crates.io/crates/uuid
171//! [`Trace`]: crate::trace::Trace
172
173use http::{
174    header::{HeaderName, HeaderValue},
175    Request, Response,
176};
177use pin_project_lite::pin_project;
178use std::task::{ready, Context, Poll};
179use std::{future::Future, pin::Pin};
180use tower_layer::Layer;
181use tower_service::Service;
182use uuid::Uuid;
183
184pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
185
186/// Trait for producing [`RequestId`]s.
187///
188/// Used by [`SetRequestId`].
189pub trait MakeRequestId {
190    /// Try and produce a [`RequestId`] from the request.
191    fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId>;
192}
193
194/// An identifier for a request.
195#[derive(Debug, Clone)]
196pub struct RequestId(HeaderValue);
197
198impl RequestId {
199    /// Create a new `RequestId` from a [`HeaderValue`].
200    pub fn new(header_value: HeaderValue) -> Self {
201        Self(header_value)
202    }
203
204    /// Gets a reference to the underlying [`HeaderValue`].
205    pub fn header_value(&self) -> &HeaderValue {
206        &self.0
207    }
208
209    /// Consumes `self`, returning the underlying [`HeaderValue`].
210    pub fn into_header_value(self) -> HeaderValue {
211        self.0
212    }
213}
214
215impl From<HeaderValue> for RequestId {
216    fn from(value: HeaderValue) -> Self {
217        Self::new(value)
218    }
219}
220
221/// Set request id headers and extensions on requests.
222///
223/// This layer applies the [`SetRequestId`] middleware.
224///
225/// See the [module docs](self) and [`SetRequestId`] for more details.
226#[derive(Debug, Clone)]
227pub struct SetRequestIdLayer<M> {
228    header_name: HeaderName,
229    make_request_id: M,
230}
231
232impl<M> SetRequestIdLayer<M> {
233    /// Create a new `SetRequestIdLayer`.
234    pub fn new(header_name: HeaderName, make_request_id: M) -> Self
235    where
236        M: MakeRequestId,
237    {
238        SetRequestIdLayer {
239            header_name,
240            make_request_id,
241        }
242    }
243
244    /// Create a new `SetRequestIdLayer` that uses `x-request-id` as the header name.
245    pub fn x_request_id(make_request_id: M) -> Self
246    where
247        M: MakeRequestId,
248    {
249        SetRequestIdLayer::new(X_REQUEST_ID, make_request_id)
250    }
251}
252
253impl<S, M> Layer<S> for SetRequestIdLayer<M>
254where
255    M: Clone + MakeRequestId,
256{
257    type Service = SetRequestId<S, M>;
258
259    fn layer(&self, inner: S) -> Self::Service {
260        SetRequestId::new(
261            inner,
262            self.header_name.clone(),
263            self.make_request_id.clone(),
264        )
265    }
266}
267
268/// Set request id headers and extensions on requests.
269///
270/// See the [module docs](self) for an example.
271///
272/// If [`MakeRequestId::make_request_id`] returns `Some(_)` and the request doesn't already have a
273/// header with the same name, then the header will be inserted.
274///
275/// Additionally [`RequestId`] will be inserted into [`Request::extensions`] so other
276/// services can access it.
277#[derive(Debug, Clone)]
278pub struct SetRequestId<S, M> {
279    inner: S,
280    header_name: HeaderName,
281    make_request_id: M,
282}
283
284impl<S, M> SetRequestId<S, M> {
285    /// Create a new `SetRequestId`.
286    pub fn new(inner: S, header_name: HeaderName, make_request_id: M) -> Self
287    where
288        M: MakeRequestId,
289    {
290        Self {
291            inner,
292            header_name,
293            make_request_id,
294        }
295    }
296
297    /// Create a new `SetRequestId` that uses `x-request-id` as the header name.
298    pub fn x_request_id(inner: S, make_request_id: M) -> Self
299    where
300        M: MakeRequestId,
301    {
302        Self::new(inner, X_REQUEST_ID, make_request_id)
303    }
304
305    define_inner_service_accessors!();
306
307    /// Returns a new [`Layer`] that wraps services with a `SetRequestId` middleware.
308    pub fn layer(header_name: HeaderName, make_request_id: M) -> SetRequestIdLayer<M>
309    where
310        M: MakeRequestId,
311    {
312        SetRequestIdLayer::new(header_name, make_request_id)
313    }
314}
315
316impl<S, M, ReqBody, ResBody> Service<Request<ReqBody>> for SetRequestId<S, M>
317where
318    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
319    M: MakeRequestId,
320{
321    type Response = S::Response;
322    type Error = S::Error;
323    type Future = S::Future;
324
325    #[inline]
326    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
327        self.inner.poll_ready(cx)
328    }
329
330    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
331        if let Some(request_id) = req.headers().get(&self.header_name) {
332            if req.extensions().get::<RequestId>().is_none() {
333                let request_id = request_id.clone();
334                req.extensions_mut().insert(RequestId::new(request_id));
335            }
336        } else if let Some(request_id) = self.make_request_id.make_request_id(&req) {
337            req.extensions_mut().insert(request_id.clone());
338            req.headers_mut()
339                .insert(self.header_name.clone(), request_id.0);
340        }
341
342        self.inner.call(req)
343    }
344}
345
346/// Propagate request ids from requests to responses.
347///
348/// This layer applies the [`PropagateRequestId`] middleware.
349///
350/// See the [module docs](self) and [`PropagateRequestId`] for more details.
351#[derive(Debug, Clone)]
352pub struct PropagateRequestIdLayer {
353    header_name: HeaderName,
354}
355
356impl PropagateRequestIdLayer {
357    /// Create a new `PropagateRequestIdLayer`.
358    pub fn new(header_name: HeaderName) -> Self {
359        PropagateRequestIdLayer { header_name }
360    }
361
362    /// Create a new `PropagateRequestIdLayer` that uses `x-request-id` as the header name.
363    pub fn x_request_id() -> Self {
364        Self::new(X_REQUEST_ID)
365    }
366}
367
368impl<S> Layer<S> for PropagateRequestIdLayer {
369    type Service = PropagateRequestId<S>;
370
371    fn layer(&self, inner: S) -> Self::Service {
372        PropagateRequestId::new(inner, self.header_name.clone())
373    }
374}
375
376/// Propagate request ids from requests to responses.
377///
378/// See the [module docs](self) for an example.
379///
380/// If the request contains a matching header that header will be applied to responses. If a
381/// [`RequestId`] extension is also present it will be propagated as well.
382#[derive(Debug, Clone)]
383pub struct PropagateRequestId<S> {
384    inner: S,
385    header_name: HeaderName,
386}
387
388impl<S> PropagateRequestId<S> {
389    /// Create a new `PropagateRequestId`.
390    pub fn new(inner: S, header_name: HeaderName) -> Self {
391        Self { inner, header_name }
392    }
393
394    /// Create a new `PropagateRequestId` that uses `x-request-id` as the header name.
395    pub fn x_request_id(inner: S) -> Self {
396        Self::new(inner, X_REQUEST_ID)
397    }
398
399    define_inner_service_accessors!();
400
401    /// Returns a new [`Layer`] that wraps services with a `PropagateRequestId` middleware.
402    pub fn layer(header_name: HeaderName) -> PropagateRequestIdLayer {
403        PropagateRequestIdLayer::new(header_name)
404    }
405}
406
407impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for PropagateRequestId<S>
408where
409    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
410{
411    type Response = S::Response;
412    type Error = S::Error;
413    type Future = PropagateRequestIdResponseFuture<S::Future>;
414
415    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
416        self.inner.poll_ready(cx)
417    }
418
419    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
420        let request_id = req
421            .headers()
422            .get(&self.header_name)
423            .cloned()
424            .map(RequestId::new);
425
426        PropagateRequestIdResponseFuture {
427            inner: self.inner.call(req),
428            header_name: self.header_name.clone(),
429            request_id,
430        }
431    }
432}
433
434pin_project! {
435    /// Response future for [`PropagateRequestId`].
436    pub struct PropagateRequestIdResponseFuture<F> {
437        #[pin]
438        inner: F,
439        header_name: HeaderName,
440        request_id: Option<RequestId>,
441    }
442}
443
444impl<F, B, E> Future for PropagateRequestIdResponseFuture<F>
445where
446    F: Future<Output = Result<Response<B>, E>>,
447{
448    type Output = Result<Response<B>, E>;
449
450    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
451        let this = self.project();
452        let mut response = ready!(this.inner.poll(cx))?;
453
454        if let Some(current_id) = response.headers().get(&*this.header_name) {
455            if response.extensions().get::<RequestId>().is_none() {
456                let current_id = current_id.clone();
457                response.extensions_mut().insert(RequestId::new(current_id));
458            }
459        } else if let Some(request_id) = this.request_id.take() {
460            response
461                .headers_mut()
462                .insert(this.header_name.clone(), request_id.0.clone());
463            response.extensions_mut().insert(request_id);
464        }
465
466        Poll::Ready(Ok(response))
467    }
468}
469
470/// A [`MakeRequestId`] that generates `UUID`s.
471#[derive(Clone, Copy, Default)]
472pub struct MakeRequestUuid;
473
474impl MakeRequestId for MakeRequestUuid {
475    fn make_request_id<B>(&mut self, _request: &Request<B>) -> Option<RequestId> {
476        let request_id = Uuid::new_v4().to_string().parse().unwrap();
477        Some(RequestId::new(request_id))
478    }
479}
480
#[cfg(test)]
mod tests {
    use crate::test_helpers::Body;
    use crate::ServiceBuilderExt as _;
    use http::Response;
    use std::{
        convert::Infallible,
        sync::{
            atomic::{AtomicU64, Ordering},
            Arc,
        },
    };
    use tower::{ServiceBuilder, ServiceExt};

    #[allow(unused_imports)]
    use super::*;

    /// Builds a request with an empty body and no request id header.
    fn empty_request() -> Request<Body> {
        Request::builder().body(Body::empty()).unwrap()
    }

    /// Builds a request with an empty body and the given `x-request-id` header.
    fn request_with_id(id: &str) -> Request<Body> {
        Request::builder()
            .header("x-request-id", id)
            .body(Body::empty())
            .unwrap()
    }

    /// Reads the `RequestId` extension off a response, panicking if it's missing.
    fn extension_id<B>(res: &Response<B>) -> &HeaderValue {
        res.extensions().get::<RequestId>().unwrap().header_value()
    }

    /// `MakeRequestId` that hands out "0", "1", "2", ... from a shared counter.
    #[derive(Clone, Default)]
    struct Counter(Arc<AtomicU64>);

    impl MakeRequestId for Counter {
        fn make_request_id<B>(&mut self, _request: &Request<B>) -> Option<RequestId> {
            let next = self.0.fetch_add(1, Ordering::SeqCst);
            let value = HeaderValue::from_str(&next.to_string()).unwrap();
            Some(RequestId::new(value))
        }
    }

    async fn handler(_: Request<Body>) -> Result<Response<Body>, Infallible> {
        Ok(Response::new(Body::empty()))
    }

    #[tokio::test]
    async fn basic() {
        let svc = ServiceBuilder::new()
            .set_x_request_id(Counter::default())
            .propagate_x_request_id()
            .service_fn(handler);

        // generated ids show up on the response, counting up from zero
        let first = svc.clone().oneshot(empty_request()).await.unwrap();
        assert_eq!(first.headers()["x-request-id"], "0");

        let second = svc.clone().oneshot(empty_request()).await.unwrap();
        assert_eq!(second.headers()["x-request-id"], "1");

        // an id already present on the request is not overridden
        let supplied = svc.clone().oneshot(request_with_id("foo")).await.unwrap();
        assert_eq!(supplied.headers()["x-request-id"], "foo");

        // the `RequestId` extension is propagated as well
        let third = svc.clone().oneshot(empty_request()).await.unwrap();
        assert_eq!(extension_id(&third), "2");
    }

    #[tokio::test]
    async fn other_middleware_setting_request_id() {
        let svc = ServiceBuilder::new()
            .override_request_header(
                HeaderName::from_static("x-request-id"),
                HeaderValue::from_str("foo").unwrap(),
            )
            .set_x_request_id(Counter::default())
            .map_request(|request: Request<_>| {
                // `set_x_request_id` should set the extension if it's missing
                assert_eq!(request.extensions().get::<RequestId>().unwrap().0, "foo");
                request
            })
            .propagate_x_request_id()
            .service_fn(handler);

        let req = request_with_id("this-will-be-overriden-by-override_request_header-middleware");
        let res = svc.clone().oneshot(req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "foo");
        assert_eq!(extension_id(&res), "foo");
    }

    #[tokio::test]
    async fn other_middleware_setting_request_id_on_response() {
        let svc = ServiceBuilder::new()
            .set_x_request_id(Counter::default())
            .propagate_x_request_id()
            .override_response_header(
                HeaderName::from_static("x-request-id"),
                HeaderValue::from_str("foo").unwrap(),
            )
            .service_fn(handler);

        let res = svc.clone().oneshot(request_with_id("foo")).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "foo");
        assert_eq!(extension_id(&res), "foo");
    }

    #[tokio::test]
    async fn uuid() {
        let svc = ServiceBuilder::new()
            .set_x_request_id(MakeRequestUuid)
            .propagate_x_request_id()
            .service_fn(handler);

        // the generated header must parse back as a UUID
        let res = svc.clone().oneshot(empty_request()).await.unwrap();
        let id = res.headers()["x-request-id"].to_str().unwrap();
        id.parse::<Uuid>().unwrap();
    }
}
604}