1use std::{
55    future::Future,
56    pin::Pin,
57    task::{Context, Poll},
58    time::{Duration, Instant},
59};
60
61use hyper::rt::{Executor, Sleep, Timer};
62use pin_project_lite::pin_project;
63
64#[cfg(feature = "tracing")]
65use tracing::instrument::Instrument;
66
67pub use self::{with_hyper_io::WithHyperIo, with_tokio_io::WithTokioIo};
68
69mod with_hyper_io;
70mod with_tokio_io;
71
/// Future executor that spawns futures onto the Tokio runtime using
/// `tokio::spawn`.
///
/// Marked `#[non_exhaustive]`, so code outside this crate constructs it with
/// [`TokioExecutor::new`] or `Default` rather than a struct literal.
///
/// NOTE(review): `tokio::spawn` requires a current Tokio runtime context, so
/// `execute` presumably must be called from inside one — confirm for callers.
#[non_exhaustive]
#[derive(Default, Debug, Clone)]
pub struct TokioExecutor {}
76
pin_project! {
    /// Adapter between Tokio's and hyper's IO traits.
    ///
    /// If `T` implements Tokio's `AsyncRead`/`AsyncWrite`, `TokioIo<T>`
    /// implements hyper's `rt::Read`/`rt::Write`. Conversely, if `T` implements
    /// hyper's `rt::Read`/`rt::Write`, `TokioIo<T>` implements Tokio's
    /// `AsyncRead`/`AsyncWrite`.
    #[derive(Debug)]
    pub struct TokioIo<T> {
        // Structurally pinned so `poll_*` calls can be forwarded through a
        // `Pin<&mut T>` projection.
        #[pin]
        inner: T,
    }
}
87
/// A hyper [`Timer`] whose sleeps are backed by Tokio's timer
/// (`tokio::time::sleep` / `tokio::time::sleep_until`).
///
/// Marked `#[non_exhaustive]`, so code outside this crate constructs it with
/// [`TokioTimer::new`] or `Default`.
#[non_exhaustive]
#[derive(Default, Clone, Debug)]
pub struct TokioTimer;
92
pin_project! {
    // A hyper `Sleep` that wraps a `tokio::time::Sleep`; this is the concrete
    // type boxed by `TokioTimer::sleep` / `TokioTimer::sleep_until`.
    #[derive(Debug)]
    struct TokioSleep {
        // Pinned because `tokio::time::Sleep` is polled and reset through
        // `Pin<&mut Sleep>`.
        #[pin]
        inner: tokio::time::Sleep,
    }
}
102
103impl<Fut> Executor<Fut> for TokioExecutor
106where
107    Fut: Future + Send + 'static,
108    Fut::Output: Send + 'static,
109{
110    fn execute(&self, fut: Fut) {
111        #[cfg(feature = "tracing")]
112        tokio::spawn(fut.in_current_span());
113
114        #[cfg(not(feature = "tracing"))]
115        tokio::spawn(fut);
116    }
117}
118
119impl TokioExecutor {
120    pub fn new() -> Self {
122        Self {}
123    }
124}
125
126impl<T> TokioIo<T> {
129    pub fn new(inner: T) -> Self {
131        Self { inner }
132    }
133
134    pub fn inner(&self) -> &T {
136        &self.inner
137    }
138
139    pub fn inner_mut(&mut self) -> &mut T {
141        &mut self.inner
142    }
143
144    pub fn into_inner(self) -> T {
146        self.inner
147    }
148}
149
impl<T> hyper::rt::Read for TokioIo<T>
where
    T: tokio::io::AsyncRead,
{
    /// Reads from the inner Tokio reader directly into hyper's (possibly
    /// uninitialized) buffer, then advances hyper's cursor by the number of
    /// bytes the reader filled.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        mut buf: hyper::rt::ReadBufCursor<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        // SAFETY: `buf.as_mut()` exposes hyper's unfilled region as
        // `MaybeUninit<u8>` slots. They are only handed to a fresh Tokio
        // `ReadBuf`, whose safe API lets the reader write bytes but not
        // de-initialize them.
        // NOTE(review): relies on hyper's documented `ReadBufCursor::as_mut`
        // contract — confirm against the hyper version in use.
        let n = unsafe {
            let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut());
            match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) {
                // Only the filled prefix is reported back to hyper.
                Poll::Ready(Ok(())) => tbuf.filled().len(),
                // Pending or an error is forwarded unchanged; `buf` is not advanced.
                other => return other,
            }
        };

        // SAFETY: the inner reader filled (and therefore initialized) exactly
        // `n` bytes starting at hyper's cursor position, through `tbuf` above.
        unsafe {
            buf.advance(n);
        }
        Poll::Ready(Ok(()))
    }
}
173
174impl<T> hyper::rt::Write for TokioIo<T>
175where
176    T: tokio::io::AsyncWrite,
177{
178    fn poll_write(
179        self: Pin<&mut Self>,
180        cx: &mut Context<'_>,
181        buf: &[u8],
182    ) -> Poll<Result<usize, std::io::Error>> {
183        tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
184    }
185
186    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
187        tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
188    }
189
190    fn poll_shutdown(
191        self: Pin<&mut Self>,
192        cx: &mut Context<'_>,
193    ) -> Poll<Result<(), std::io::Error>> {
194        tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
195    }
196
197    fn is_write_vectored(&self) -> bool {
198        tokio::io::AsyncWrite::is_write_vectored(&self.inner)
199    }
200
201    fn poll_write_vectored(
202        self: Pin<&mut Self>,
203        cx: &mut Context<'_>,
204        bufs: &[std::io::IoSlice<'_>],
205    ) -> Poll<Result<usize, std::io::Error>> {
206        tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
207    }
208}
209
impl<T> tokio::io::AsyncRead for TokioIo<T>
where
    T: hyper::rt::Read,
{
    /// Reads from the inner hyper reader directly into the unfilled tail of
    /// Tokio's buffer, then updates Tokio's initialized and filled markers.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        tbuf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        // Length already filled before this call; new data lands right after it.
        let filled = tbuf.filled().len();
        // SAFETY: `unfilled_mut` hands out the unfilled (possibly
        // uninitialized) tail of `tbuf`. It is only wrapped in a hyper
        // `ReadBuf` that the inner reader writes into, so previously
        // initialized bytes are not de-initialized.
        let sub_filled = unsafe {
            let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut());

            match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) {
                // Number of bytes the inner reader produced in this call.
                Poll::Ready(Ok(())) => buf.filled().len(),
                // Pending or an error is returned as-is; `tbuf` markers are untouched.
                other => return other,
            }
        };

        let n_filled = filled + sub_filled;
        // At least `sub_filled` bytes past the old filled mark had to have been
        // initialized. Per tokio's `ReadBuf` docs, `assume_init` counts from the
        // start of the unfilled region (hence the relative count), while
        // `set_filled` takes an absolute length (hence `filled + sub_filled`).
        let n_init = sub_filled;
        unsafe {
            tbuf.assume_init(n_init);
            tbuf.set_filled(n_filled);
        }

        Poll::Ready(Ok(()))
    }
}
241
242impl<T> tokio::io::AsyncWrite for TokioIo<T>
243where
244    T: hyper::rt::Write,
245{
246    fn poll_write(
247        self: Pin<&mut Self>,
248        cx: &mut Context<'_>,
249        buf: &[u8],
250    ) -> Poll<Result<usize, std::io::Error>> {
251        hyper::rt::Write::poll_write(self.project().inner, cx, buf)
252    }
253
254    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
255        hyper::rt::Write::poll_flush(self.project().inner, cx)
256    }
257
258    fn poll_shutdown(
259        self: Pin<&mut Self>,
260        cx: &mut Context<'_>,
261    ) -> Poll<Result<(), std::io::Error>> {
262        hyper::rt::Write::poll_shutdown(self.project().inner, cx)
263    }
264
265    fn is_write_vectored(&self) -> bool {
266        hyper::rt::Write::is_write_vectored(&self.inner)
267    }
268
269    fn poll_write_vectored(
270        self: Pin<&mut Self>,
271        cx: &mut Context<'_>,
272        bufs: &[std::io::IoSlice<'_>],
273    ) -> Poll<Result<usize, std::io::Error>> {
274        hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs)
275    }
276}
277
278impl Timer for TokioTimer {
281    fn sleep(&self, duration: Duration) -> Pin<Box<dyn Sleep>> {
282        Box::pin(TokioSleep {
283            inner: tokio::time::sleep(duration),
284        })
285    }
286
287    fn sleep_until(&self, deadline: Instant) -> Pin<Box<dyn Sleep>> {
288        Box::pin(TokioSleep {
289            inner: tokio::time::sleep_until(deadline.into()),
290        })
291    }
292
293    fn reset(&self, sleep: &mut Pin<Box<dyn Sleep>>, new_deadline: Instant) {
294        if let Some(sleep) = sleep.as_mut().downcast_mut_pin::<TokioSleep>() {
295            sleep.reset(new_deadline)
296        }
297    }
298}
299
300impl TokioTimer {
301    pub fn new() -> Self {
303        Self {}
304    }
305}
306
307impl Future for TokioSleep {
308    type Output = ();
309
310    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
311        self.project().inner.poll(cx)
312    }
313}
314
// Marker impl with no methods: `TokioSleep` gets its behavior from its
// `Future<Output = ()>` impl, and this lets `TokioTimer` box it as a hyper
// `dyn Sleep`.
impl Sleep for TokioSleep {}
316
317impl TokioSleep {
318    fn reset(self: Pin<&mut Self>, deadline: Instant) {
319        self.project().inner.as_mut().reset(deadline.into());
320    }
321}
322
#[cfg(test)]
mod tests {
    use crate::rt::TokioExecutor;
    use hyper::rt::Executor;
    use tokio::sync::oneshot;

    /// A future handed to `TokioExecutor::execute` must actually be run: the
    /// spawned task signals completion over a oneshot channel.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn simple_execute() -> Result<(), Box<dyn std::error::Error>> {
        let (done_tx, done_rx) = oneshot::channel();
        TokioExecutor::new().execute(async move {
            done_tx.send(()).unwrap();
        });
        done_rx.await?;
        Ok(())
    }
}