tokio_rustls_acme2/
incoming.rs

1use crate::acceptor::{AcmeAccept, AcmeAcceptor};
2use crate::{crypto_provider, AcmeState};
3use core::fmt;
4use futures_util::stream::{FusedStream, FuturesUnordered};
5use futures_util::Stream;
6use std::fmt::Debug;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10use tokio::io::{AsyncRead, AsyncWrite};
11use tokio_rustls::rustls::crypto::CryptoProvider;
12use tokio_rustls::rustls::ServerConfig;
13use tokio_rustls::server::TlsStream;
14use tokio_rustls::Accept;
15
16pub struct Incoming<
17    TCP: AsyncRead + AsyncWrite + Unpin,
18    ETCP,
19    ITCP: Stream<Item = Result<TCP, ETCP>> + Unpin,
20    EC: Debug + 'static,
21    EA: Debug + 'static,
22> {
23    state: AcmeState<EC, EA>,
24    acceptor: AcmeAcceptor,
25    rustls_config: Arc<ServerConfig>,
26    tcp_incoming: Option<ITCP>,
27    acme_accepting: FuturesUnordered<AcmeAccept<TCP>>,
28    tls_accepting: FuturesUnordered<Accept<TCP>>,
29}
30
31impl<TCP: AsyncRead + AsyncWrite + Unpin, ETCP, ITCP: Stream<Item = Result<TCP, ETCP>> + Unpin, EC: Debug + 'static, EA: Debug + 'static> fmt::Debug
32    for Incoming<TCP, ETCP, ITCP, EC, EA>
33{
34    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35        f.debug_struct("Incoming")
36            .field("state", &self.state)
37            .field("acceptor", &self.acceptor)
38            .field("in_progress", &(self.acme_accepting.len() + self.tls_accepting.len()))
39            .field("terminated", &self.is_terminated())
40            .finish_non_exhaustive()
41    }
42}
43
44impl<TCP: AsyncRead + AsyncWrite + Unpin, ETCP, ITCP: Stream<Item = Result<TCP, ETCP>> + Unpin, EC: Debug + 'static, EA: Debug + 'static> Unpin
45    for Incoming<TCP, ETCP, ITCP, EC, EA>
46{
47}
48
49impl<TCP: AsyncRead + AsyncWrite + Unpin, ETCP, ITCP: Stream<Item = Result<TCP, ETCP>> + Unpin, EC: Debug + 'static, EA: Debug + 'static>
50    Incoming<TCP, ETCP, ITCP, EC, EA>
51{
52    #[cfg(any(feature = "ring", feature = "aws-lc-rs"))]
53    pub fn new(tcp_incoming: ITCP, state: AcmeState<EC, EA>, acceptor: AcmeAcceptor, alpn_protocols: Vec<Vec<u8>>) -> Self {
54        Self::new_with_provider(tcp_incoming, state, acceptor, alpn_protocols, crypto_provider().into())
55    }
56
57    /// Same as [Incoming::new], with a specific [CryptoProvider].
58    pub fn new_with_provider(
59        tcp_incoming: ITCP,
60        state: AcmeState<EC, EA>,
61        acceptor: AcmeAcceptor,
62        alpn_protocols: Vec<Vec<u8>>,
63        provider: Arc<CryptoProvider>,
64    ) -> Self {
65        let mut config = ServerConfig::builder_with_provider(provider)
66            .with_safe_default_protocol_versions()
67            .unwrap()
68            .with_no_client_auth()
69            .with_cert_resolver(state.resolver());
70        config.alpn_protocols = alpn_protocols;
71        Self {
72            state,
73            acceptor,
74            rustls_config: Arc::new(config),
75            tcp_incoming: Some(tcp_incoming),
76            acme_accepting: FuturesUnordered::new(),
77            tls_accepting: FuturesUnordered::new(),
78        }
79    }
80}
81
82impl<TCP: AsyncRead + AsyncWrite + Unpin, ETCP, ITCP: Stream<Item = Result<TCP, ETCP>> + Unpin, EC: Debug + 'static, EA: Debug + 'static> Stream
83    for Incoming<TCP, ETCP, ITCP, EC, EA>
84{
85    type Item = Result<TlsStream<TCP>, ETCP>;
86
87    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
88        loop {
89            match Pin::new(&mut self.state).poll_next(cx) {
90                Poll::Ready(Some(event)) => {
91                    match event {
92                        Ok(ok) => log::info!("event: {ok:?}"),
93                        Err(err) => log::error!("event: {err:?}"),
94                    }
95                    continue;
96                }
97                Poll::Ready(None) => unreachable!(),
98                Poll::Pending => {}
99            }
100            match Pin::new(&mut self.acme_accepting).poll_next(cx) {
101                Poll::Ready(Some(Ok(Some(tls)))) => self.tls_accepting.push(tls.into_stream(self.rustls_config.clone())),
102                Poll::Ready(Some(Ok(None))) => {
103                    log::info!("received TLS-ALPN-01 validation request");
104                    continue;
105                }
106                Poll::Ready(Some(Err(err))) => {
107                    log::error!("tls accept failed, {err:?}");
108                    continue;
109                }
110                Poll::Ready(None) | Poll::Pending => {}
111            }
112            match Pin::new(&mut self.tls_accepting).poll_next(cx) {
113                Poll::Ready(Some(Ok(tls))) => return Poll::Ready(Some(Ok(tls))),
114                Poll::Ready(Some(Err(err))) => {
115                    log::error!("tls accept failed, {err:?}");
116                    continue;
117                }
118                Poll::Ready(None) | Poll::Pending => {}
119            }
120            let tcp_incoming = match &mut self.tcp_incoming {
121                Some(tcp_incoming) => tcp_incoming,
122                None => match self.is_terminated() {
123                    true => return Poll::Ready(None),
124                    false => return Poll::Pending,
125                },
126            };
127            match Pin::new(tcp_incoming).poll_next(cx) {
128                Poll::Ready(Some(Ok(tcp))) => self.acme_accepting.push(self.acceptor.accept(tcp)),
129                Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
130                Poll::Ready(None) => drop(self.tcp_incoming.as_mut()),
131                Poll::Pending => return Poll::Pending,
132            }
133        }
134    }
135}
136
137impl<TCP: AsyncRead + AsyncWrite + Unpin, ETCP, ITCP: Stream<Item = Result<TCP, ETCP>> + Unpin, EC: Debug + 'static, EA: Debug + 'static> FusedStream
138    for Incoming<TCP, ETCP, ITCP, EC, EA>
139{
140    fn is_terminated(&self) -> bool {
141        self.tcp_incoming.is_none() && self.acme_accepting.is_terminated() && self.tls_accepting.is_terminated()
142    }
143}