1use std::convert::TryFrom;
2use std::fmt;
3use std::hash::{Hash, Hasher};
4use std::str::FromStr;
5
6use bytes::Bytes;
7
8use super::{ErrorKind, InvalidUri};
9use crate::byte_str::ByteStr;
10
/// A URI scheme such as `http` or `https`.
///
/// The value wraps the crate-internal `Scheme2` state; equality, hashing and
/// formatting are provided by the trait impls in this module.
#[derive(Clone)]
pub struct Scheme {
    // Internal representation; visible to the parent module only.
    pub(super) inner: Scheme2,
}
16
/// Internal scheme state, generic over the payload carried for a
/// non-standard scheme.
///
/// `T` defaults to `Box<ByteStr>` (the owned scheme name stored inside
/// `Scheme`). The parsers in this module instantiate it with `usize`
/// (`parse`) and `()` (`parse_exact`) instead.
#[derive(Clone, Debug)]
pub(super) enum Scheme2<T = Box<ByteStr>> {
    // No scheme.
    None,
    // One of the built-in protocols (`http` / `https`).
    Standard(Protocol),
    // Any other scheme; the payload type depends on the user (see above).
    Other(T),
}
23
/// The schemes that get a dedicated, allocation-free representation.
#[derive(Copy, Clone, Debug)]
pub(super) enum Protocol {
    // Rendered as "http" by `Scheme::as_str`.
    Http,
    // Rendered as "https" by `Scheme::as_str`.
    Https,
}
29
30impl Scheme {
31    pub const HTTP: Scheme = Scheme {
33        inner: Scheme2::Standard(Protocol::Http),
34    };
35
36    pub const HTTPS: Scheme = Scheme {
38        inner: Scheme2::Standard(Protocol::Https),
39    };
40
41    pub(super) fn empty() -> Self {
42        Scheme {
43            inner: Scheme2::None,
44        }
45    }
46
47    #[inline]
57    pub fn as_str(&self) -> &str {
58        use self::Protocol::*;
59        use self::Scheme2::*;
60
61        match self.inner {
62            Standard(Http) => "http",
63            Standard(Https) => "https",
64            Other(ref v) => &v[..],
65            None => unreachable!(),
66        }
67    }
68}
69
70impl<'a> TryFrom<&'a [u8]> for Scheme {
71    type Error = InvalidUri;
72    #[inline]
73    fn try_from(s: &'a [u8]) -> Result<Self, Self::Error> {
74        use self::Scheme2::*;
75
76        match Scheme2::parse_exact(s)? {
77            None => Err(ErrorKind::InvalidScheme.into()),
78            Standard(p) => Ok(Standard(p).into()),
79            Other(_) => {
80                let bytes = Bytes::copy_from_slice(s);
81
82                let string = unsafe { ByteStr::from_utf8_unchecked(bytes) };
85
86                Ok(Other(Box::new(string)).into())
87            }
88        }
89    }
90}
91
92impl<'a> TryFrom<&'a str> for Scheme {
93    type Error = InvalidUri;
94    #[inline]
95    fn try_from(s: &'a str) -> Result<Self, Self::Error> {
96        TryFrom::try_from(s.as_bytes())
97    }
98}
99
100impl FromStr for Scheme {
101    type Err = InvalidUri;
102
103    fn from_str(s: &str) -> Result<Self, Self::Err> {
104        TryFrom::try_from(s)
105    }
106}
107
108impl fmt::Debug for Scheme {
109    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110        fmt::Debug::fmt(self.as_str(), f)
111    }
112}
113
114impl fmt::Display for Scheme {
115    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116        f.write_str(self.as_str())
117    }
118}
119
120impl AsRef<str> for Scheme {
121    #[inline]
122    fn as_ref(&self) -> &str {
123        self.as_str()
124    }
125}
126
127impl PartialEq for Scheme {
128    fn eq(&self, other: &Scheme) -> bool {
129        use self::Protocol::*;
130        use self::Scheme2::*;
131
132        match (&self.inner, &other.inner) {
133            (&Standard(Http), &Standard(Http)) => true,
134            (&Standard(Https), &Standard(Https)) => true,
135            (Other(a), Other(b)) => a.eq_ignore_ascii_case(b),
136            (&None, _) | (_, &None) => unreachable!(),
137            _ => false,
138        }
139    }
140}
141
// Marker impl: opts `Scheme` into `Eq`, building on the `PartialEq` impl for
// `Scheme`. It adds no methods of its own.
impl Eq for Scheme {}
143
144impl PartialEq<str> for Scheme {
154    fn eq(&self, other: &str) -> bool {
155        self.as_str().eq_ignore_ascii_case(other)
156    }
157}
158
159impl PartialEq<Scheme> for str {
161    fn eq(&self, other: &Scheme) -> bool {
162        other == self
163    }
164}
165
166impl Hash for Scheme {
168    fn hash<H>(&self, state: &mut H)
169    where
170        H: Hasher,
171    {
172        match self.inner {
173            Scheme2::None => (),
174            Scheme2::Standard(Protocol::Http) => state.write_u8(1),
175            Scheme2::Standard(Protocol::Https) => state.write_u8(2),
176            Scheme2::Other(ref other) => {
177                other.len().hash(state);
178                for &b in other.as_bytes() {
179                    state.write_u8(b.to_ascii_lowercase());
180                }
181            }
182        }
183    }
184}
185
186impl<T> Scheme2<T> {
187    pub(super) fn is_none(&self) -> bool {
188        matches!(*self, Scheme2::None)
189    }
190}
191
// Maximum accepted scheme length in bytes. `Scheme2::parse_exact` and
// `Scheme2::parse` report `ErrorKind::SchemeTooLong` for longer schemes.
const MAX_SCHEME_LEN: usize = 64;
195
/// Byte-indexed lookup table used while scanning a scheme.
///
/// An entry equals its own index when that byte may appear while scanning
/// (ASCII letters, digits, `+`, `-`, `.`, `~`, and `:`), and is `0` for every
/// other byte. `:` is present so `Scheme2::parse` can recognise the scheme
/// terminator; `Scheme2::parse_exact` rejects it explicitly.
const SCHEME_CHARS: [u8; 256] = {
    let mut table = [0u8; 256];
    let mut idx = 0;
    while idx < 256 {
        let byte = idx as u8;
        table[idx] = match byte {
            b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z' => byte,
            b'+' | b'-' | b'.' | b':' | b'~' => byte,
            _ => 0,
        };
        idx += 1;
    }
    table
};
234
235impl Scheme2<usize> {
236    fn parse_exact(s: &[u8]) -> Result<Scheme2<()>, InvalidUri> {
238        match s {
239            b"http" => Ok(Protocol::Http.into()),
240            b"https" => Ok(Protocol::Https.into()),
241            _ => {
242                if s.len() > MAX_SCHEME_LEN {
243                    return Err(ErrorKind::SchemeTooLong.into());
244                }
245
246                for &b in s {
249                    match SCHEME_CHARS[b as usize] {
250                        b':' => {
251                            return Err(ErrorKind::InvalidScheme.into());
253                        }
254                        0 => {
255                            return Err(ErrorKind::InvalidScheme.into());
256                        }
257                        _ => {}
258                    }
259                }
260
261                Ok(Scheme2::Other(()))
262            }
263        }
264    }
265
266    pub(super) fn parse(s: &[u8]) -> Result<Scheme2<usize>, InvalidUri> {
267        if s.len() >= 7 {
268            if s[..7].eq_ignore_ascii_case(b"http://") {
270                return Ok(Protocol::Http.into());
272            }
273        }
274
275        if s.len() >= 8 {
276            if s[..8].eq_ignore_ascii_case(b"https://") {
278                return Ok(Protocol::Https.into());
279            }
280        }
281
282        if s.len() > 3 {
283            for i in 0..s.len() {
284                let b = s[i];
285
286                match SCHEME_CHARS[b as usize] {
287                    b':' => {
288                        if s.len() < i + 3 {
290                            break;
291                        }
292
293                        if &s[i + 1..i + 3] != b"//" {
295                            break;
296                        }
297
298                        if i > MAX_SCHEME_LEN {
299                            return Err(ErrorKind::SchemeTooLong.into());
300                        }
301
302                        return Ok(Scheme2::Other(i));
304                    }
305                    0 => break,
307                    _ => {}
308                }
309            }
310        }
311
312        Ok(Scheme2::None)
313    }
314}
315
316impl Protocol {
317    pub(super) fn len(&self) -> usize {
318        match *self {
319            Protocol::Http => 4,
320            Protocol::Https => 5,
321        }
322    }
323}
324
325impl<T> From<Protocol> for Scheme2<T> {
326    fn from(src: Protocol) -> Self {
327        Scheme2::Standard(src)
328    }
329}
330
331#[doc(hidden)]
332impl From<Scheme2> for Scheme {
333    fn from(src: Scheme2) -> Self {
334        Scheme { inner: src }
335    }
336}
337
#[cfg(test)]
mod test {
    use super::*;

    /// A parsed scheme compares equal to the string it was parsed from.
    #[test]
    fn scheme_eq_to_str() {
        assert_eq!(&scheme("http"), "http");
        assert_eq!(&scheme("https"), "https");
        assert_eq!(&scheme("ftp"), "ftp");
        assert_eq!(&scheme("my+funky+scheme"), "my+funky+scheme");
    }

    /// Bytes outside the scheme character set are rejected.
    #[test]
    fn invalid_scheme_is_error() {
        // `_` is not a scheme character.
        Scheme::try_from("my_funky_scheme").expect_err("Unexpectedly valid Scheme");

        // A lone 0xC0 byte is not ASCII (and not valid UTF-8).
        Scheme::try_from([0xC0].as_ref()).expect_err("Unexpectedly valid Scheme");
    }

    /// Schemes up to `MAX_SCHEME_LEN` bytes are accepted; longer ones are not.
    #[test]
    fn overlong_scheme_is_error() {
        let at_limit = "a".repeat(MAX_SCHEME_LEN);
        Scheme::try_from(at_limit.as_str()).expect("Scheme at the length limit should be valid");

        let over_limit = "a".repeat(MAX_SCHEME_LEN + 1);
        Scheme::try_from(over_limit.as_str()).expect_err("Unexpectedly valid Scheme");
    }

    /// Parses `s`, panicking with the offending input on failure.
    fn scheme(s: &str) -> Scheme {
        // Build the panic message lazily so nothing is formatted or allocated
        // on the success path.
        s.parse::<Scheme>()
            .unwrap_or_else(|err| panic!("Invalid scheme: {}: {:?}", s, err))
    }
}