1use super::{InsertHeaderMode, MakeHeaderValue};
98use http::{header::HeaderName, Request, Response};
99use pin_project_lite::pin_project;
100use std::{
101    fmt,
102    future::Future,
103    pin::Pin,
104    task::{ready, Context, Poll},
105};
106use tower_layer::Layer;
107use tower_service::Service;
108
109pub struct SetResponseHeaderLayer<M> {
113    header_name: HeaderName,
114    make: M,
115    mode: InsertHeaderMode,
116}
117
118impl<M> fmt::Debug for SetResponseHeaderLayer<M> {
119    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
120        f.debug_struct("SetResponseHeaderLayer")
121            .field("header_name", &self.header_name)
122            .field("mode", &self.mode)
123            .field("make", &std::any::type_name::<M>())
124            .finish()
125    }
126}
127
128impl<M> SetResponseHeaderLayer<M> {
129    pub fn overriding(header_name: HeaderName, make: M) -> Self {
134        Self::new(header_name, make, InsertHeaderMode::Override)
135    }
136
137    pub fn appending(header_name: HeaderName, make: M) -> Self {
142        Self::new(header_name, make, InsertHeaderMode::Append)
143    }
144
145    pub fn if_not_present(header_name: HeaderName, make: M) -> Self {
149        Self::new(header_name, make, InsertHeaderMode::IfNotPresent)
150    }
151
152    fn new(header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
153        Self {
154            make,
155            header_name,
156            mode,
157        }
158    }
159}
160
161impl<S, M> Layer<S> for SetResponseHeaderLayer<M>
162where
163    M: Clone,
164{
165    type Service = SetResponseHeader<S, M>;
166
167    fn layer(&self, inner: S) -> Self::Service {
168        SetResponseHeader {
169            inner,
170            header_name: self.header_name.clone(),
171            make: self.make.clone(),
172            mode: self.mode,
173        }
174    }
175}
176
177impl<M> Clone for SetResponseHeaderLayer<M>
178where
179    M: Clone,
180{
181    fn clone(&self) -> Self {
182        Self {
183            make: self.make.clone(),
184            header_name: self.header_name.clone(),
185            mode: self.mode,
186        }
187    }
188}
189
190#[derive(Clone)]
192pub struct SetResponseHeader<S, M> {
193    inner: S,
194    header_name: HeaderName,
195    make: M,
196    mode: InsertHeaderMode,
197}
198
199impl<S, M> SetResponseHeader<S, M> {
200    pub fn overriding(inner: S, header_name: HeaderName, make: M) -> Self {
205        Self::new(inner, header_name, make, InsertHeaderMode::Override)
206    }
207
208    pub fn appending(inner: S, header_name: HeaderName, make: M) -> Self {
213        Self::new(inner, header_name, make, InsertHeaderMode::Append)
214    }
215
216    pub fn if_not_present(inner: S, header_name: HeaderName, make: M) -> Self {
220        Self::new(inner, header_name, make, InsertHeaderMode::IfNotPresent)
221    }
222
223    fn new(inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
224        Self {
225            inner,
226            header_name,
227            make,
228            mode,
229        }
230    }
231
232    define_inner_service_accessors!();
233}
234
235impl<S, M> fmt::Debug for SetResponseHeader<S, M>
236where
237    S: fmt::Debug,
238{
239    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
240        f.debug_struct("SetResponseHeader")
241            .field("inner", &self.inner)
242            .field("header_name", &self.header_name)
243            .field("mode", &self.mode)
244            .field("make", &std::any::type_name::<M>())
245            .finish()
246    }
247}
248
249impl<ReqBody, ResBody, S, M> Service<Request<ReqBody>> for SetResponseHeader<S, M>
250where
251    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
252    M: MakeHeaderValue<Response<ResBody>> + Clone,
253{
254    type Response = S::Response;
255    type Error = S::Error;
256    type Future = ResponseFuture<S::Future, M>;
257
258    #[inline]
259    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
260        self.inner.poll_ready(cx)
261    }
262
263    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
264        ResponseFuture {
265            future: self.inner.call(req),
266            header_name: self.header_name.clone(),
267            make: self.make.clone(),
268            mode: self.mode,
269        }
270    }
271}
272
pin_project! {
    /// Response future for [`SetResponseHeader`].
    ///
    /// Drives the inner service's future and, once it resolves with `Ok`,
    /// inserts the configured header into the response.
    #[derive(Debug)]
    pub struct ResponseFuture<F, M> {
        // The inner service's response future; structurally pinned.
        #[pin]
        future: F,
        // Name of the header to set once the response is available.
        header_name: HeaderName,
        // Value producer, cloned from the service for this request.
        make: M,
        // Insertion strategy chosen by the service's constructor.
        mode: InsertHeaderMode,
    }
}
284
285impl<F, ResBody, E, M> Future for ResponseFuture<F, M>
286where
287    F: Future<Output = Result<Response<ResBody>, E>>,
288    M: MakeHeaderValue<Response<ResBody>>,
289{
290    type Output = F::Output;
291
292    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
293        let this = self.project();
294        let mut res = ready!(this.future.poll(cx)?);
295
296        this.mode.apply(this.header_name, &mut res, &mut *this.make);
297
298        Poll::Ready(Ok(res))
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::Body;
    use http::{header, HeaderValue};
    use std::convert::Infallible;
    use tower::{service_fn, ServiceExt};

    /// Stand-in for the wrapped service: an empty response that carries
    /// `content_type` as its `Content-Type` header when one is given.
    async fn upstream(content_type: Option<&'static str>) -> Result<Response<Body>, Infallible> {
        let mut builder = Response::builder();
        if let Some(value) = content_type {
            builder = builder.header(header::CONTENT_TYPE, value);
        }
        Ok(builder.body(Body::empty()).unwrap())
    }

    /// Every `Content-Type` value on `res`, in the order they appear.
    fn content_types(res: &Response<Body>) -> Vec<&HeaderValue> {
        res.headers().get_all(header::CONTENT_TYPE).iter().collect()
    }

    #[tokio::test]
    async fn test_override_mode() {
        let svc = SetResponseHeader::overriding(
            service_fn(|_req: Request<Body>| upstream(Some("good-content"))),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        assert_eq!(content_types(&res), ["text/html"]);
    }

    #[tokio::test]
    async fn test_append_mode() {
        let svc = SetResponseHeader::appending(
            service_fn(|_req: Request<Body>| upstream(Some("good-content"))),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        assert_eq!(content_types(&res), ["good-content", "text/html"]);
    }

    #[tokio::test]
    async fn test_skip_if_present_mode() {
        let svc = SetResponseHeader::if_not_present(
            service_fn(|_req: Request<Body>| upstream(Some("good-content"))),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        assert_eq!(content_types(&res), ["good-content"]);
    }

    #[tokio::test]
    async fn test_skip_if_present_mode_when_not_present() {
        let svc = SetResponseHeader::if_not_present(
            service_fn(|_req: Request<Body>| upstream(None)),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        assert_eq!(content_types(&res), ["text/html"]);
    }
}