tokio_tungstenite/
lib.rs

//! Async WebSocket usage.
//!
//! This library is an implementation of WebSocket handshakes and streams. It
//! is based on the [`tungstenite`] crate, which implements all required
//! WebSocket protocol logic. This crate therefore mainly adds tokio support
//! (tokio integration) on top of it.
//!
//! Each WebSocket stream implements the required `Stream` and `Sink` traits,
//! so the socket is just a stream of messages coming in and going out.
10
11#![deny(missing_docs, unused_must_use, unused_mut, unused_imports, unused_import_braces)]
12
13pub use tungstenite;
14
15mod compat;
16#[cfg(feature = "connect")]
17mod connect;
18mod handshake;
19#[cfg(feature = "stream")]
20mod stream;
21#[cfg(any(feature = "native-tls", feature = "__rustls-tls", feature = "connect"))]
22mod tls;
23
24use std::io::{Read, Write};
25
26use compat::{cvt, AllowStd, ContextWaker};
27use futures_util::{
28    sink::{Sink, SinkExt},
29    stream::{FusedStream, Stream},
30};
31use log::*;
32use std::{
33    pin::Pin,
34    task::{Context, Poll},
35};
36use tokio::io::{AsyncRead, AsyncWrite};
37
38#[cfg(feature = "handshake")]
39use tungstenite::{
40    client::IntoClientRequest,
41    handshake::{
42        client::{ClientHandshake, Response},
43        server::{Callback, NoCallback},
44        HandshakeError,
45    },
46};
47use tungstenite::{
48    error::Error as WsError,
49    protocol::{Message, Role, WebSocket, WebSocketConfig},
50};
51
52#[cfg(any(feature = "native-tls", feature = "__rustls-tls", feature = "connect"))]
53pub use tls::Connector;
54#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
55pub use tls::{client_async_tls, client_async_tls_with_config};
56
57#[cfg(feature = "connect")]
58pub use connect::{connect_async, connect_async_with_config};
59
60#[cfg(all(any(feature = "native-tls", feature = "__rustls-tls"), feature = "connect"))]
61pub use connect::connect_async_tls_with_config;
62
63#[cfg(feature = "stream")]
64pub use stream::MaybeTlsStream;
65
66use tungstenite::protocol::CloseFrame;
67
/// Performs the client side of a WebSocket handshake over an existing stream.
///
/// `request` may be anything implementing `IntoClientRequest`: a URL string,
/// a URL, or a full `Request`. Passing a `Request` lets the caller add a
/// WebSocket sub-protocol or other custom headers.
///
/// The returned future resolves to the established `WebSocketStream<S>`
/// together with the server's handshake `Response`, or to a `WsError` if the
/// handshake does not succeed.
///
/// This is typically used by clients that have already opened, for example,
/// a TCP connection to the remote server. It is equivalent to
/// `client_async_with_config(request, stream, None)`.
#[cfg(feature = "handshake")]
pub async fn client_async<'a, R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let no_config: Option<WebSocketConfig> = None;
    client_async_with_config(request, stream, no_config).await
}
91
/// Like `client_async()`, but lets the caller supply a WebSocket configuration.
///
/// `client_async()` is this function with `config` set to `None`; see its
/// documentation for how `request` and `stream` are used.
///
/// # Errors
///
/// A handshake `Failure` is returned as the `WsError` it carries. Any other
/// handshake error is converted into `WsError::Io` of kind `Other`, whose
/// message is the handshake error's text.
#[cfg(feature = "handshake")]
pub async fn client_async_with_config<'a, R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let outcome = handshake::client_handshake(stream, move |allow_std| {
        let req = request.into_client_request()?;
        ClientHandshake::start(allow_std, req, config)?.handshake()
    })
    .await;

    match outcome {
        Ok(established) => Ok(established),
        // A real handshake failure already carries a `WsError`; pass it through.
        Err(HandshakeError::Failure(err)) => Err(err),
        // Every other handshake error is flattened into an I/O error with its text.
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
114
/// Accepts a new WebSocket connection over the provided stream.
///
/// Runs the server side of the WebSocket handshake on `stream`, using no
/// header callback and no custom configuration. The returned future resolves
/// to a ready `WebSocketStream<S>`, or to a `WsError` if the handshake fails.
///
/// This is typically used after a socket has been accepted from a
/// `TcpListener`. That socket is passed to this function to perform the
/// server half of accepting a client's WebSocket connection.
#[cfg(feature = "handshake")]
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // Same path as the header-callback variant with `NoCallback` and no config.
    accept_async_with_config(stream, None).await
}
133
/// Like `accept_async()`, but lets the caller supply a WebSocket configuration.
///
/// No header callback is installed (`NoCallback`). See `accept_async()` for
/// more details.
#[cfg(feature = "handshake")]
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let pending = accept_hdr_async_with_config(stream, NoCallback, config);
    pending.await
}
146
/// Accepts a new WebSocket connection over the provided stream, with a
/// header callback.
///
/// Behaves like `accept_async()`, but additionally takes `callback`, which
/// receives the headers of the incoming request and can add extra headers to
/// the reply. The default configuration (`None`) is used.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let default_config = None;
    accept_hdr_async_with_config(stream, callback, default_config).await
}
160
/// Like `accept_hdr_async()`, but lets the caller supply a WebSocket
/// configuration.
///
/// See `accept_hdr_async()` for details on `callback`.
///
/// # Errors
///
/// A handshake `Failure` is returned as the `WsError` it carries. Any other
/// handshake error is converted into `WsError::Io` of kind `Other`, whose
/// message is the handshake error's text.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let handshake_result = handshake::server_handshake(stream, move |allow_std| {
        tungstenite::accept_hdr_with_config(allow_std, callback, config)
    })
    .await;

    match handshake_result {
        Ok(ws) => Ok(ws),
        // Genuine handshake failure: surface the underlying `WsError` unchanged.
        Err(HandshakeError::Failure(err)) => Err(err),
        // Any other handshake error becomes an I/O error carrying its text.
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
181
182/// A wrapper around an underlying raw stream which implements the WebSocket
183/// protocol.
184///
185/// A `WebSocketStream<S>` represents a handshake that has been completed
186/// successfully and both the server and the client are ready for receiving
187/// and sending data. Message from a `WebSocketStream<S>` are accessible
188/// through the respective `Stream` and `Sink`. Check more information about
189/// them in `futures-rs` crate documentation or have a look on the examples
190/// and unit tests for this crate.
191#[derive(Debug)]
192pub struct WebSocketStream<S> {
193    inner: WebSocket<AllowStd<S>>,
194    closing: bool,
195    ended: bool,
196    /// Tungstenite is probably ready to receive more data.
197    ///
198    /// `false` once start_send hits `WouldBlock` errors.
199    /// `true` initially and after `flush`ing.
200    ready: bool,
201}
202
203impl<S> WebSocketStream<S> {
204    /// Convert a raw socket into a WebSocketStream without performing a
205    /// handshake.
206    pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
207    where
208        S: AsyncRead + AsyncWrite + Unpin,
209    {
210        handshake::without_handshake(stream, move |allow_std| {
211            WebSocket::from_raw_socket(allow_std, role, config)
212        })
213        .await
214    }
215
216    /// Convert a raw socket into a WebSocketStream without performing a
217    /// handshake.
218    pub async fn from_partially_read(
219        stream: S,
220        part: Vec<u8>,
221        role: Role,
222        config: Option<WebSocketConfig>,
223    ) -> Self
224    where
225        S: AsyncRead + AsyncWrite + Unpin,
226    {
227        handshake::without_handshake(stream, move |allow_std| {
228            WebSocket::from_partially_read(allow_std, part, role, config)
229        })
230        .await
231    }
232
233    pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
234        Self { inner: ws, closing: false, ended: false, ready: true }
235    }
236
237    fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
238    where
239        S: Unpin,
240        F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
241        AllowStd<S>: Read + Write,
242    {
243        trace!("{}:{} WebSocketStream.with_context", file!(), line!());
244        if let Some((kind, ctx)) = ctx {
245            self.inner.get_mut().set_waker(kind, ctx.waker());
246        }
247        f(&mut self.inner)
248    }
249
250    /// Consumes the `WebSocketStream` and returns the underlying stream.
251    pub fn into_inner(self) -> S {
252        self.inner.into_inner().into_inner()
253    }
254
255    /// Returns a shared reference to the inner stream.
256    pub fn get_ref(&self) -> &S
257    where
258        S: AsyncRead + AsyncWrite + Unpin,
259    {
260        self.inner.get_ref().get_ref()
261    }
262
263    /// Returns a mutable reference to the inner stream.
264    pub fn get_mut(&mut self) -> &mut S
265    where
266        S: AsyncRead + AsyncWrite + Unpin,
267    {
268        self.inner.get_mut().get_mut()
269    }
270
271    /// Returns a reference to the configuration of the tungstenite stream.
272    pub fn get_config(&self) -> &WebSocketConfig {
273        self.inner.get_config()
274    }
275
276    /// Close the underlying web socket
277    pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
278    where
279        S: AsyncRead + AsyncWrite + Unpin,
280    {
281        self.send(Message::Close(msg)).await
282    }
283}
284
285impl<T> Stream for WebSocketStream<T>
286where
287    T: AsyncRead + AsyncWrite + Unpin,
288{
289    type Item = Result<Message, WsError>;
290
291    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
292        trace!("{}:{} Stream.poll_next", file!(), line!());
293
294        // The connection has been closed or a critical error has occurred.
295        // We have already returned the error to the user, the `Stream` is unusable,
296        // so we assume that the stream has been "fused".
297        if self.ended {
298            return Poll::Ready(None);
299        }
300
301        match futures_util::ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
302            trace!("{}:{} Stream.with_context poll_next -> read()", file!(), line!());
303            cvt(s.read())
304        })) {
305            Ok(v) => Poll::Ready(Some(Ok(v))),
306            Err(e) => {
307                self.ended = true;
308                if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
309                    Poll::Ready(None)
310                } else {
311                    Poll::Ready(Some(Err(e)))
312                }
313            }
314        }
315    }
316}
317
318impl<T> FusedStream for WebSocketStream<T>
319where
320    T: AsyncRead + AsyncWrite + Unpin,
321{
322    fn is_terminated(&self) -> bool {
323        self.ended
324    }
325}
326
327impl<T> Sink<Message> for WebSocketStream<T>
328where
329    T: AsyncRead + AsyncWrite + Unpin,
330{
331    type Error = WsError;
332
333    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
334        if self.ready {
335            Poll::Ready(Ok(()))
336        } else {
337            // Currently blocked so try to flush the blockage away
338            (*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush())).map(|r| {
339                self.ready = true;
340                r
341            })
342        }
343    }
344
345    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
346        match (*self).with_context(None, |s| s.write(item)) {
347            Ok(()) => {
348                self.ready = true;
349                Ok(())
350            }
351            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
352                // the message was accepted and queued so not an error
353                // but `poll_ready` will now start trying to flush the block
354                self.ready = false;
355                Ok(())
356            }
357            Err(e) => {
358                self.ready = true;
359                debug!("websocket start_send error: {}", e);
360                Err(e)
361            }
362        }
363    }
364
365    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
366        (*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush())).map(|r| {
367            self.ready = true;
368            match r {
369                // WebSocket connection has just been closed. Flushing completed, not an error.
370                Err(WsError::ConnectionClosed) => Ok(()),
371                other => other,
372            }
373        })
374    }
375
376    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
377        self.ready = true;
378        let res = if self.closing {
379            // After queueing it, we call `flush` to drive the close handshake to completion.
380            (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
381        } else {
382            (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
383        };
384
385        match res {
386            Ok(()) => Poll::Ready(Ok(())),
387            Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
388            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
389                trace!("WouldBlock");
390                self.closing = true;
391                Poll::Pending
392            }
393            Err(err) => {
394                debug!("websocket close error: {}", err);
395                Poll::Ready(Err(err))
396            }
397        }
398    }
399}
400
/// Extracts the host name from a client request's URI, as an owned `String`.
///
/// With the `__rustls-tls` feature enabled, a bracketed IPv6 host such as
/// `[::1]` is returned without its surrounding brackets, because rustls
/// expects bare IPv6 addresses.
///
/// # Errors
///
/// Returns `WsError::Url(UrlError::NoHostName)` when the URI has no host.
#[cfg(any(feature = "connect", feature = "native-tls", feature = "__rustls-tls"))]
#[inline]
fn domain(request: &tungstenite::handshake::client::Request) -> Result<String, WsError> {
    let host = request
        .uri()
        .host()
        .ok_or_else(|| WsError::Url(tungstenite::error::UrlError::NoHostName))?;

    #[cfg(feature = "__rustls-tls")]
    {
        if let Some(bare) = host.strip_prefix('[').and_then(|h| h.strip_suffix(']')) {
            return Ok(bare.to_string());
        }
    }

    Ok(host.to_string())
}
413
#[cfg(test)]
mod tests {
    #[cfg(feature = "connect")]
    use crate::stream::MaybeTlsStream;
    use crate::{compat::AllowStd, WebSocketStream};
    use std::io::{Read, Write};
    #[cfg(feature = "connect")]
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    // Compile-time trait checks: each helper only type-checks when `T`
    // satisfies its bound, so calling it is the whole assertion.
    fn assert_read<T: Read>() {}
    fn assert_write<T: Write>() {}
    fn assert_unpin<T: Unpin>() {}
    #[cfg(feature = "connect")]
    fn assert_async_read<T: AsyncReadExt>() {}
    #[cfg(feature = "connect")]
    fn assert_async_write<T: AsyncWriteExt>() {}

    #[test]
    fn web_socket_stream_has_traits() {
        type Tcp = tokio::net::TcpStream;

        assert_read::<AllowStd<Tcp>>();
        assert_write::<AllowStd<Tcp>>();
        assert_unpin::<WebSocketStream<Tcp>>();

        #[cfg(feature = "connect")]
        {
            assert_async_read::<MaybeTlsStream<Tcp>>();
            assert_async_write::<MaybeTlsStream<Tcp>>();
            assert_unpin::<WebSocketStream<MaybeTlsStream<Tcp>>>();
        }
    }
}