// ferron/util/cgi_response.rs

1use memchr::memmem::Finder;
2use std::io::Error;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf};
6
/// Size in bytes (16 KiB) of the buffer used while searching for the end of the
/// response head; a head that does not fit in this many bytes is rejected.
const RESPONSE_BUFFER_CAPACITY: usize = 16 * 1024;
9
/// Wrapper around the output stream of a CGI program that separates the
/// response head from what follows it.
///
/// `get_head` buffers output until the head is complete. Afterwards, reading
/// through `AsyncRead` first drains bytes still held in the internal buffer
/// past the current cursor, then continues with the wrapped stream.
///
/// Trait bounds live on the `impl` blocks that need them, not on the type.
pub struct CgiResponse<R> {
  /// Stream the CGI output is read from.
  stream: R,
  /// Bytes read from `stream` while looking for the end of the head.
  response_buf: Vec<u8>,
  /// `None` until `get_head` returns without an I/O error; afterwards the
  /// offset in `response_buf` of the next byte `poll_read` will hand out
  /// (initially the head length, or 0 when no head was found).
  response_head_length: Option<usize>,
}
19
20impl<R> CgiResponse<R>
21where
22  R: AsyncRead + Unpin,
23{
24  // Constructor to create a new CgiResponse instance
25  pub fn new(stream: R) -> Self {
26    Self {
27      stream,
28      response_buf: Vec::with_capacity(RESPONSE_BUFFER_CAPACITY),
29      response_head_length: None,
30    }
31  }
32
33  // Asynchronous method to get the response headers
34  pub async fn get_head(&mut self) -> Result<&[u8], Error> {
35    let mut temp_buf = [0u8; RESPONSE_BUFFER_CAPACITY];
36    let rnrn = Finder::new(b"\r\n\r\n");
37    let nrnr = Finder::new(b"\n\r\n\r");
38    let nn = Finder::new(b"\n\n");
39    let rr = Finder::new(b"\r\r");
40    let to_parse_length;
41
42    loop {
43      // Read data from the stream into the temporary buffer
44      let read_bytes = self.stream.read(&mut temp_buf).await?;
45
46      // If no bytes are read, return an empty response head
47      if read_bytes == 0 {
48        self.response_head_length = Some(0);
49        return Ok(&[0u8; 0]);
50      }
51
52      // If the response buffer exceeds the capacity, return an empty response head
53      if self.response_buf.len() + read_bytes > RESPONSE_BUFFER_CAPACITY {
54        self.response_head_length = Some(0);
55        return Ok(&[0u8; 0]);
56      }
57
58      // Determine the starting point for searching the "\r\n\r\n" sequence
59      let begin_rnrn_or_nrnr_search = self.response_buf.len().saturating_sub(3);
60      let begin_rr_or_nn_search = self.response_buf.len().saturating_sub(1);
61      self.response_buf.extend_from_slice(&temp_buf[..read_bytes]);
62
63      // Search for the "\r\n\r\n" sequence in the response buffer
64      if let Some(rnrn_index) = rnrn.find(&self.response_buf[begin_rnrn_or_nrnr_search..]) {
65        to_parse_length = begin_rnrn_or_nrnr_search + rnrn_index + 4;
66        break;
67      } else if let Some(nrnr_index) = nrnr.find(&self.response_buf[begin_rnrn_or_nrnr_search..]) {
68        to_parse_length = begin_rnrn_or_nrnr_search + nrnr_index + 4;
69        break;
70      } else if let Some(nn_index) = nn.find(&self.response_buf[begin_rr_or_nn_search..]) {
71        to_parse_length = begin_rr_or_nn_search + nn_index + 2;
72        break;
73      } else if let Some(rr_index) = rr.find(&self.response_buf[begin_rr_or_nn_search..]) {
74        to_parse_length = begin_rr_or_nn_search + rr_index + 2;
75        break;
76      }
77    }
78
79    // Set the length of the response header
80    self.response_head_length = Some(to_parse_length);
81
82    // Return the response header as a byte slice
83    Ok(&self.response_buf[..to_parse_length])
84  }
85}
86
87// Implementation of AsyncRead for the CgiResponse struct
88impl<R> AsyncRead for CgiResponse<R>
89where
90  R: AsyncRead + Unpin,
91{
92  fn poll_read(
93    mut self: Pin<&mut Self>,
94    cx: &mut Context<'_>,
95    buf: &mut ReadBuf<'_>,
96  ) -> Poll<std::io::Result<()>> {
97    // If the response header length is known and the buffer contains more data than the header length
98    if let Some(response_head_length) = self.response_head_length {
99      if self.response_buf.len() > response_head_length {
100        let remaining_data = &self.response_buf[response_head_length..];
101        let to_read = remaining_data.len().min(buf.remaining());
102        buf.put_slice(&remaining_data[..to_read]);
103        self.response_head_length = Some(response_head_length + to_read);
104        return Poll::Ready(Ok(()));
105      }
106    }
107
108    // Create a temporary buffer to hold the data to be consumed
109    let stream = Pin::new(&mut self.stream);
110    match stream.poll_read(cx, buf) {
111      Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
112      other => other,
113    }
114  }
115}
116
#[cfg(test)]
mod tests {
  use super::*;
  use tokio::io::AsyncReadExt;
  use tokio_test::io::Builder;

  // A CRLF-terminated head is returned exactly, blank line included.
  #[tokio::test]
  async fn test_get_head() {
    let input: &[u8] = b"Content-Type: text/plain\r\n\r\n";
    let mut mock = Builder::new().read(input).build();
    let mut cgi = CgiResponse::new(&mut mock);

    assert_eq!(cgi.get_head().await.unwrap(), input);
  }

  // A bare-LF blank line also terminates the head.
  #[tokio::test]
  async fn test_get_head_nn() {
    let input: &[u8] = b"Content-Type: text/plain\n\n";
    let mut mock = Builder::new().read(input).build();
    let mut cgi = CgiResponse::new(&mut mock);

    assert_eq!(cgi.get_head().await.unwrap(), input);
  }

  // Output that overflows the head buffer without a terminator yields an empty head.
  #[tokio::test]
  async fn test_get_head_large_headers() {
    let first_line: &[u8] = b"Content-Type: text/plain\r\n";
    let filler = vec![b'A'; RESPONSE_BUFFER_CAPACITY + 10];
    let mut mock = Builder::new().read(first_line).read(&filler).build();
    let mut cgi = CgiResponse::new(&mut mock);

    assert!(cgi.get_head().await.unwrap().is_empty());

    // Drain whatever scripted data is left so the mock does not panic.
    let mut sink = vec![0u8; RESPONSE_BUFFER_CAPACITY + 10];
    let _ = cgi.stream.read(&mut sink).await;
  }

  // End of stream before the blank line yields an empty head.
  #[tokio::test]
  async fn test_get_head_premature_eof() {
    let mut mock = Builder::new().read(b"Content-Type: text/plain\r\n").build();
    let mut cgi = CgiResponse::new(&mut mock);

    assert!(cgi.get_head().await.unwrap().is_empty());
  }

  // Bytes that arrived together with the head are served by subsequent reads.
  #[tokio::test]
  async fn test_poll_read() {
    let mut mock = Builder::new()
      .read(b"Content-Type: text/plain\r\n\r\nHello, world!")
      .build();
    let mut cgi = CgiResponse::new(&mut mock);

    assert_eq!(cgi.get_head().await.unwrap(), b"Content-Type: text/plain\r\n\r\n");

    let mut body = [0u8; 13];
    let count = cgi.read(&mut body).await.unwrap();
    assert_eq!(count, 13);
    assert_eq!(&body[..count], b"Hello, world!");
  }
}