1use std::any::TypeId;
43use std::error::Error as StdError;
44use std::fmt;
45use std::future::Future;
46use std::io;
47use std::pin::Pin;
48use std::sync::{Arc, Mutex};
49use std::task::{Context, Poll};
50
51use crate::rt::{Read, ReadBufCursor, Write};
52use bytes::Bytes;
53use tokio::sync::oneshot;
54
55use crate::common::io::Rewind;
56
/// The IO object handed out once an upgrade completes.
///
/// It owns the type-erased IO (`Box<dyn Io + Send>`) inside a `Rewind` that is
/// seeded with the `read_buf` given to `Upgraded::new`. `Read`/`Write` calls
/// are forwarded to that `Rewind`, and `Upgraded::downcast` can recover the
/// concrete IO type together with the buffered bytes.
pub struct Upgraded {
    // Type-erased IO plus any bytes buffered in front of it.
    io: Rewind<Box<dyn Io + Send>>,
}
68
/// A future that resolves to the [`Upgraded`] IO of a message once its
/// upgrade completes.
///
/// Clones share one receiver (`Arc<Mutex<..>>`), so at most one of them can
/// obtain the `Upgraded` value.
#[derive(Clone)]
pub struct OnUpgrade {
    // `None` when no upgrade is associated with the message (see
    // `OnUpgrade::none`); `Some` when created by `pending()`.
    rx: Option<Arc<Mutex<oneshot::Receiver<crate::Result<Upgraded>>>>>,
}
76
/// The pieces of an [`Upgraded`] after a successful [`Upgraded::downcast`].
#[derive(Debug)]
#[non_exhaustive]
pub struct Parts<T> {
    /// The original IO object, recovered as its concrete type `T`.
    pub io: T,
    /// Bytes that were still buffered in front of `io` inside the `Upgraded`.
    ///
    /// NOTE(review): presumably bytes already read off the connection that
    /// must be consumed before reading more from `io` — confirm against
    /// `Rewind`.
    pub read_buf: Bytes,
}
96
97pub fn on<T: sealed::CanUpgrade>(msg: T) -> OnUpgrade {
106    msg.on_upgrade()
107}
108
/// Sending half of an upgrade created by `pending()`.
///
/// Resolving it (`fulfill`, or `manual` with `http1`) delivers a result to the
/// paired `OnUpgrade`. Dropping it unresolved makes that `OnUpgrade` resolve
/// to a canceled error.
#[cfg(all(
    any(feature = "client", feature = "server"),
    any(feature = "http1", feature = "http2"),
))]
pub(super) struct Pending {
    tx: oneshot::Sender<crate::Result<Upgraded>>,
}
116
#[cfg(all(
    any(feature = "client", feature = "server"),
    any(feature = "http1", feature = "http2"),
))]
/// Creates a linked `(Pending, OnUpgrade)` pair over a fresh oneshot channel.
///
/// The receiver is put behind `Arc<Mutex<..>>` so the returned `OnUpgrade`
/// can be cloned while all clones still share the single receiver.
pub(super) fn pending() -> (Pending, OnUpgrade) {
    let (sender, receiver) = oneshot::channel();
    let shared_receiver = Arc::new(Mutex::new(receiver));

    let pending = Pending { tx: sender };
    let on_upgrade = OnUpgrade {
        rx: Some(shared_receiver),
    };
    (pending, on_upgrade)
}
130
131impl Upgraded {
134    #[cfg(all(
135        any(feature = "client", feature = "server"),
136        any(feature = "http1", feature = "http2")
137    ))]
138    pub(super) fn new<T>(io: T, read_buf: Bytes) -> Self
139    where
140        T: Read + Write + Unpin + Send + 'static,
141    {
142        Upgraded {
143            io: Rewind::new_buffered(Box::new(io), read_buf),
144        }
145    }
146
147    pub fn downcast<T: Read + Write + Unpin + 'static>(self) -> Result<Parts<T>, Self> {
152        let (io, buf) = self.io.into_inner();
153        match io.__hyper_downcast() {
154            Ok(t) => Ok(Parts {
155                io: *t,
156                read_buf: buf,
157            }),
158            Err(io) => Err(Upgraded {
159                io: Rewind::new_buffered(io, buf),
160            }),
161        }
162    }
163}
164
165impl Read for Upgraded {
166    fn poll_read(
167        mut self: Pin<&mut Self>,
168        cx: &mut Context<'_>,
169        buf: ReadBufCursor<'_>,
170    ) -> Poll<io::Result<()>> {
171        Pin::new(&mut self.io).poll_read(cx, buf)
172    }
173}
174
175impl Write for Upgraded {
176    fn poll_write(
177        mut self: Pin<&mut Self>,
178        cx: &mut Context<'_>,
179        buf: &[u8],
180    ) -> Poll<io::Result<usize>> {
181        Pin::new(&mut self.io).poll_write(cx, buf)
182    }
183
184    fn poll_write_vectored(
185        mut self: Pin<&mut Self>,
186        cx: &mut Context<'_>,
187        bufs: &[io::IoSlice<'_>],
188    ) -> Poll<io::Result<usize>> {
189        Pin::new(&mut self.io).poll_write_vectored(cx, bufs)
190    }
191
192    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
193        Pin::new(&mut self.io).poll_flush(cx)
194    }
195
196    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
197        Pin::new(&mut self.io).poll_shutdown(cx)
198    }
199
200    fn is_write_vectored(&self) -> bool {
201        self.io.is_write_vectored()
202    }
203}
204
205impl fmt::Debug for Upgraded {
206    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207        f.debug_struct("Upgraded").finish()
208    }
209}
210
211impl OnUpgrade {
214    pub(super) fn none() -> Self {
215        OnUpgrade { rx: None }
216    }
217
218    #[cfg(all(any(feature = "client", feature = "server"), feature = "http1"))]
219    pub(super) fn is_none(&self) -> bool {
220        self.rx.is_none()
221    }
222}
223
224impl Future for OnUpgrade {
225    type Output = Result<Upgraded, crate::Error>;
226
227    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
228        match self.rx {
229            Some(ref rx) => Pin::new(&mut *rx.lock().unwrap())
230                .poll(cx)
231                .map(|res| match res {
232                    Ok(Ok(upgraded)) => Ok(upgraded),
233                    Ok(Err(err)) => Err(err),
234                    Err(_oneshot_canceled) => {
235                        Err(crate::Error::new_canceled().with(UpgradeExpected))
236                    }
237                }),
238            None => Poll::Ready(Err(crate::Error::new_user_no_upgrade())),
239        }
240    }
241}
242
243impl fmt::Debug for OnUpgrade {
244    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
245        f.debug_struct("OnUpgrade").finish()
246    }
247}
248
#[cfg(all(
    any(feature = "client", feature = "server"),
    any(feature = "http1", feature = "http2")
))]
impl Pending {
    /// Completes the upgrade by handing `upgraded` to the paired `OnUpgrade`.
    pub(super) fn fulfill(self, upgraded: Upgraded) {
        trace!("pending upgrade fulfill");
        // `send` fails only when the receiving side is gone (every
        // `OnUpgrade` clone dropped); nobody is waiting, so ignore it.
        let _ = self.tx.send(Ok(upgraded));
    }

    /// Does not fulfill the upgrade. Instead, signals to the paired
    /// `OnUpgrade` that the upgrade is handled manually.
    #[cfg(feature = "http1")]
    pub(super) fn manual(self) {
        // This fn is already `http1`-only, so the trace needs no extra cfg.
        trace!("pending upgrade handled manually");
        let manual_err = crate::Error::new_user_manual_upgrade();
        // As in `fulfill`, a closed receiver is not an error here.
        let _ = self.tx.send(Err(manual_err));
    }
}
270
/// Error cause attached when an `OnUpgrade` is canceled: an upgrade was being
/// waited on, but the sending side went away without delivering a result.
#[derive(Debug)]
struct UpgradeExpected;

impl fmt::Display for UpgradeExpected {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "upgrade expected but not completed")
    }
}

/// No underlying source; the message above is the whole error.
impl StdError for UpgradeExpected {}
287
288pub(super) trait Io: Read + Write + Unpin + 'static {
291    fn __hyper_type_id(&self) -> TypeId {
292        TypeId::of::<Self>()
293    }
294}
295
296impl<T: Read + Write + Unpin + 'static> Io for T {}
297
impl dyn Io + Send {
    /// Returns `true` if the concrete type behind this trait object is `T`.
    fn __hyper_is<T: Io>(&self) -> bool {
        let t = TypeId::of::<T>();
        self.__hyper_type_id() == t
    }

    /// Converts this `Box<dyn Io + Send>` back into a `Box<T>` when the
    /// concrete type is `T`; otherwise returns the box unchanged in `Err`.
    fn __hyper_downcast<T: Io>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
        if self.__hyper_is::<T>() {
            // SAFETY: `__hyper_is::<T>()` just confirmed that the concrete
            // type behind this trait object is `T`. That check cannot be
            // fooled by an overridden `__hyper_type_id`. Only `Sized` types
            // can be unsized into a `dyn Io`, and every `Sized` type that can
            // implement `Io` (i.e. satisfies its supertraits) is covered by
            // the blanket `impl<T: Read + Write + Unpin + 'static> Io for T`.
            // Coherence therefore rules out any other impl, and the blanket
            // impl uses the default `TypeId::of::<Self>()`. So the allocation
            // was created as a `Box<T>`. Casting the fat pointer to a thin
            // `*mut T` keeps the data pointer and drops the vtable, so
            // rebuilding a `Box<T>` from it is sound. This mirrors
            // `std::error::Error::downcast`.
            unsafe {
                let raw: *mut dyn Io = Box::into_raw(self);
                Ok(Box::from_raw(raw as *mut T))
            }
        } else {
            Err(self)
        }
    }
}
316
317mod sealed {
318    use super::OnUpgrade;
319
320    pub trait CanUpgrade {
321        fn on_upgrade(self) -> OnUpgrade;
322    }
323
324    impl<B> CanUpgrade for http::Request<B> {
325        fn on_upgrade(mut self) -> OnUpgrade {
326            self.extensions_mut()
327                .remove::<OnUpgrade>()
328                .unwrap_or_else(OnUpgrade::none)
329        }
330    }
331
332    impl<B> CanUpgrade for &'_ mut http::Request<B> {
333        fn on_upgrade(self) -> OnUpgrade {
334            self.extensions_mut()
335                .remove::<OnUpgrade>()
336                .unwrap_or_else(OnUpgrade::none)
337        }
338    }
339
340    impl<B> CanUpgrade for http::Response<B> {
341        fn on_upgrade(mut self) -> OnUpgrade {
342            self.extensions_mut()
343                .remove::<OnUpgrade>()
344                .unwrap_or_else(OnUpgrade::none)
345        }
346    }
347
348    impl<B> CanUpgrade for &'_ mut http::Response<B> {
349        fn on_upgrade(self) -> OnUpgrade {
350            self.extensions_mut()
351                .remove::<OnUpgrade>()
352                .unwrap_or_else(OnUpgrade::none)
353        }
354    }
355}
356
#[cfg(all(
    any(feature = "client", feature = "server"),
    any(feature = "http1", feature = "http2"),
))]
#[cfg(test)]
mod tests {
    use super::*;

    /// A downcast to the wrong type must hand the `Upgraded` back so that a
    /// second downcast, to the real IO type, still succeeds.
    #[test]
    fn upgraded_downcast() {
        type NotMock = crate::common::io::Compat<std::io::Cursor<Vec<u8>>>;

        let original = Upgraded::new(Mock, Bytes::new());
        let returned = original.downcast::<NotMock>().unwrap_err();
        returned.downcast::<Mock>().unwrap();
    }

    /// Minimal IO stand-in. Writes are accepted and discarded; any other
    /// operation is unexpected in these tests and panics.
    struct Mock;

    impl Read for Mock {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: ReadBufCursor<'_>,
        ) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_read")
        }
    }

    impl Write for Mock {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            // Report the whole buffer as written.
            let written = buf.len();
            Poll::Ready(Ok(written))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_flush")
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_shutdown")
        }
    }
}