ferron/util/
cgi_response.rs1use memchr::memmem::Finder;
2use std::io::Error;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf};
6
7const RESPONSE_BUFFER_CAPACITY: usize = 16384;
9
10pub struct CgiResponse<R>
12where
13 R: AsyncRead + Unpin,
14{
15 stream: R,
16 response_buf: Vec<u8>,
17 response_head_length: Option<usize>,
18}
19
20impl<R> CgiResponse<R>
21where
22 R: AsyncRead + Unpin,
23{
24 pub fn new(stream: R) -> Self {
26 Self {
27 stream,
28 response_buf: Vec::with_capacity(RESPONSE_BUFFER_CAPACITY),
29 response_head_length: None,
30 }
31 }
32
33 pub async fn get_head(&mut self) -> Result<&[u8], Error> {
35 let mut temp_buf = [0u8; RESPONSE_BUFFER_CAPACITY];
36 let rnrn = Finder::new(b"\r\n\r\n");
37 let nrnr = Finder::new(b"\n\r\n\r");
38 let nn = Finder::new(b"\n\n");
39 let rr = Finder::new(b"\r\r");
40 let to_parse_length;
41
42 loop {
43 let read_bytes = self.stream.read(&mut temp_buf).await?;
45
46 if read_bytes == 0 {
48 self.response_head_length = Some(0);
49 return Ok(&[0u8; 0]);
50 }
51
52 if self.response_buf.len() + read_bytes > RESPONSE_BUFFER_CAPACITY {
54 self.response_head_length = Some(0);
55 return Ok(&[0u8; 0]);
56 }
57
58 let begin_rnrn_or_nrnr_search = self.response_buf.len().saturating_sub(3);
60 let begin_rr_or_nn_search = self.response_buf.len().saturating_sub(1);
61 self.response_buf.extend_from_slice(&temp_buf[..read_bytes]);
62
63 if let Some(rnrn_index) = rnrn.find(&self.response_buf[begin_rnrn_or_nrnr_search..]) {
65 to_parse_length = begin_rnrn_or_nrnr_search + rnrn_index + 4;
66 break;
67 } else if let Some(nrnr_index) = nrnr.find(&self.response_buf[begin_rnrn_or_nrnr_search..]) {
68 to_parse_length = begin_rnrn_or_nrnr_search + nrnr_index + 4;
69 break;
70 } else if let Some(nn_index) = nn.find(&self.response_buf[begin_rr_or_nn_search..]) {
71 to_parse_length = begin_rr_or_nn_search + nn_index + 2;
72 break;
73 } else if let Some(rr_index) = rr.find(&self.response_buf[begin_rr_or_nn_search..]) {
74 to_parse_length = begin_rr_or_nn_search + rr_index + 2;
75 break;
76 }
77 }
78
79 self.response_head_length = Some(to_parse_length);
81
82 Ok(&self.response_buf[..to_parse_length])
84 }
85}
86
87impl<R> AsyncRead for CgiResponse<R>
89where
90 R: AsyncRead + Unpin,
91{
92 fn poll_read(
93 mut self: Pin<&mut Self>,
94 cx: &mut Context<'_>,
95 buf: &mut ReadBuf<'_>,
96 ) -> Poll<std::io::Result<()>> {
97 if let Some(response_head_length) = self.response_head_length {
99 if self.response_buf.len() > response_head_length {
100 let remaining_data = &self.response_buf[response_head_length..];
101 let to_read = remaining_data.len().min(buf.remaining());
102 buf.put_slice(&remaining_data[..to_read]);
103 self.response_head_length = Some(response_head_length + to_read);
104 return Poll::Ready(Ok(()));
105 }
106 }
107
108 let stream = Pin::new(&mut self.stream);
110 match stream.poll_read(cx, buf) {
111 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
112 other => other,
113 }
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;
    use tokio_test::io::Builder;

    /// A CRLF-terminated head is returned in full, terminator included.
    #[tokio::test]
    async fn test_get_head() {
        let data = b"Content-Type: text/plain\r\n\r\n";
        let mut stream = Builder::new().read(data).build();
        let mut response = CgiResponse::new(&mut stream);

        let head = response.get_head().await.unwrap();
        assert_eq!(head, b"Content-Type: text/plain\r\n\r\n");
    }

    /// A head terminated by two bare line feeds is recognized.
    #[tokio::test]
    async fn test_get_head_nn() {
        let data = b"Content-Type: text/plain\n\n";
        let mut stream = Builder::new().read(data).build();
        let mut response = CgiResponse::new(&mut stream);

        let head = response.get_head().await.unwrap();
        assert_eq!(head, b"Content-Type: text/plain\n\n");
    }

    /// A head terminated by two bare carriage returns is recognized.
    #[tokio::test]
    async fn test_get_head_rr() {
        let data = b"Content-Type: text/plain\r\r";
        let mut stream = Builder::new().read(data).build();
        let mut response = CgiResponse::new(&mut stream);

        let head = response.get_head().await.unwrap();
        assert_eq!(head, b"Content-Type: text/plain\r\r");
    }

    /// A terminator that straddles two separate reads is still found, and the
    /// bytes after it are served as the body.
    #[tokio::test]
    async fn test_get_head_terminator_split_across_reads() {
        let mut stream = Builder::new()
            .read(b"Content-Type: text/plain\r\n")
            .read(b"\r\nHello")
            .build();
        let mut response = CgiResponse::new(&mut stream);

        let head = response.get_head().await.unwrap();
        assert_eq!(head, b"Content-Type: text/plain\r\n\r\n");

        let mut body = [0u8; 5];
        let n = response.read(&mut body).await.unwrap();
        assert_eq!(&body[..n], b"Hello");
    }

    /// Input exceeding the buffer capacity without a terminator yields an
    /// empty head.
    #[tokio::test]
    async fn test_get_head_large_headers() {
        let data = b"Content-Type: text/plain\r\n";
        let large_header = vec![b'A'; RESPONSE_BUFFER_CAPACITY + 10];
        let mut stream = Builder::new().read(data).read(&large_header).build();
        let mut response = CgiResponse::new(&mut stream);

        let result = response.get_head().await;
        assert_eq!(result.unwrap().len(), 0);

        // Drain what is left in the mock stream; presumably needed so the mock
        // does not complain about unread data when it is dropped.
        let mut remaining_data = vec![0u8; RESPONSE_BUFFER_CAPACITY + 10];
        let _ = response.stream.read(&mut remaining_data).await;
    }

    /// A stream that ends before any terminator yields an empty head.
    #[tokio::test]
    async fn test_get_head_premature_eof() {
        let data = b"Content-Type: text/plain\r\n";
        let mut stream = Builder::new().read(data).build();
        let mut response = CgiResponse::new(&mut stream);

        let result = response.get_head().await;
        assert_eq!(result.unwrap().len(), 0);
    }

    /// Bytes buffered past the head are served by the `AsyncRead` impl.
    #[tokio::test]
    async fn test_poll_read() {
        let data = b"Content-Type: text/plain\r\n\r\nHello, world!";
        let mut stream = Builder::new().read(data).build();
        let mut response = CgiResponse::new(&mut stream);

        let head = response.get_head().await.unwrap();
        assert_eq!(head, b"Content-Type: text/plain\r\n\r\n");

        let mut buf = vec![0u8; 13];
        let n = response.read(&mut buf).await.unwrap();
        assert_eq!(n, 13);
        assert_eq!(&buf[..n], b"Hello, world!");
    }

    /// Buffered body bytes are handed out incrementally when the caller's
    /// buffer is smaller than what is pending.
    #[tokio::test]
    async fn test_poll_read_small_buffer() {
        let data = b"X: y\n\nabcdef";
        let mut stream = Builder::new().read(data).build();
        let mut response = CgiResponse::new(&mut stream);

        let head = response.get_head().await.unwrap();
        assert_eq!(head, b"X: y\n\n");

        let mut buf = [0u8; 4];
        let n = response.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..n], b"abcd");

        let n = response.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..n], b"ef");
    }
}