ferron/optional_modules/cache.rs

1use std::collections::HashMap;
2use std::error::Error;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use crate::ferron_common::{
7  ErrorLogger, HyperUpgraded, RequestData, ResponseData, ServerConfig, ServerModule,
8  ServerModuleHandlers, SocketData,
9};
10use crate::ferron_common::{HyperResponse, WithRuntime};
11use ahash::RandomState;
12use async_trait::async_trait;
13use cache_control::{Cachability, CacheControl};
14use futures_util::{StreamExt, TryStreamExt};
15use hashlink::LinkedHashMap;
16use http_body_util::{BodyExt, Full, StreamBody};
17use hyper::body::{Bytes, Frame};
18use hyper::header::{HeaderName, HeaderValue};
19use hyper::{header, HeaderMap, Method, Response, StatusCode};
20use hyper_tungstenite::HyperWebsocket;
21use tokio::runtime::Handle;
22use tokio::sync::RwLock;
23
/// Response header reporting the cache outcome (`HIT`, `MISS` or `BYPASS`).
///
/// Kept lowercase: `HeaderName::from_static` panics on names containing uppercase
/// characters, and HTTP header names are case-insensitive, so clients see no difference.
const CACHE_HEADER_NAME: &str = "x-ferron-cache";
/// Freshness lifetime, in seconds, used for a stored response when neither `s-maxage`
/// nor `max-age` is available.
const DEFAULT_MAX_AGE: u64 = 300;
26
27pub fn server_module_init(
28  config: &ServerConfig,
29) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
30  let maximum_cache_entries = config["global"]["maximumCacheEntries"]
31    .as_i64()
32    .map(|v| v as usize);
33
34  Ok(Box::new(CacheModule::new(
35    Arc::new(RwLock::new(LinkedHashMap::with_hasher(RandomState::new()))),
36    Arc::new(RwLock::new(HashMap::with_hasher(RandomState::new()))),
37    maximum_cache_entries,
38  )))
39}
40
41#[allow(clippy::type_complexity)]
42struct CacheModule {
43  cache: Arc<
44    RwLock<
45      LinkedHashMap<
46        String,
47        (
48          StatusCode,
49          HeaderMap,
50          Vec<u8>,
51          Instant,
52          Option<CacheControl>,
53        ),
54        RandomState,
55      >,
56    >,
57  >,
58  vary_cache: Arc<RwLock<HashMap<String, Vec<String>, RandomState>>>,
59  maximum_cache_entries: Option<usize>,
60}
61
62impl CacheModule {
63  #[allow(clippy::type_complexity)]
64  fn new(
65    cache: Arc<
66      RwLock<
67        LinkedHashMap<
68          String,
69          (
70            StatusCode,
71            HeaderMap,
72            Vec<u8>,
73            Instant,
74            Option<CacheControl>,
75          ),
76          RandomState,
77        >,
78      >,
79    >,
80    vary_cache: Arc<RwLock<HashMap<String, Vec<String>, RandomState>>>,
81    maximum_cache_entries: Option<usize>,
82  ) -> Self {
83    Self {
84      cache,
85      vary_cache,
86      maximum_cache_entries,
87    }
88  }
89}
90
91impl ServerModule for CacheModule {
92  fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
93    Box::new(CacheModuleHandlers {
94      cache: self.cache.clone(),
95      vary_cache: self.vary_cache.clone(),
96      maximum_cache_entries: self.maximum_cache_entries,
97      cache_vary_headers_configured: Vec::new(),
98      cache_ignore_headers_configured: Vec::new(),
99      maximum_cached_response_size: None,
100      cache_key: None,
101      request_headers: HeaderMap::new(),
102      has_authorization: false,
103      cached: false,
104      no_store: false,
105      handle,
106    })
107  }
108}
109
110#[allow(clippy::type_complexity)]
111struct CacheModuleHandlers {
112  handle: Handle,
113  cache: Arc<
114    RwLock<
115      LinkedHashMap<
116        String,
117        (
118          StatusCode,
119          HeaderMap,
120          Vec<u8>,
121          Instant,
122          Option<CacheControl>,
123        ),
124        RandomState,
125      >,
126    >,
127  >,
128  vary_cache: Arc<RwLock<HashMap<String, Vec<String>, RandomState>>>,
129  maximum_cache_entries: Option<usize>,
130  cache_vary_headers_configured: Vec<String>,
131  cache_ignore_headers_configured: Vec<String>,
132  maximum_cached_response_size: Option<u64>,
133  cache_key: Option<String>,
134  request_headers: HeaderMap<HeaderValue>,
135  has_authorization: bool,
136  cached: bool,
137  no_store: bool,
138}
139
140#[async_trait]
141impl ServerModuleHandlers for CacheModuleHandlers {
142  async fn request_handler(
143    &mut self,
144    request: RequestData,
145    config: &ServerConfig,
146    socket_data: &SocketData,
147    _error_logger: &ErrorLogger,
148  ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
149    WithRuntime::new(self.handle.clone(), async move {
150      self.cache_vary_headers_configured = match config["cacheVaryHeaders"].as_vec() {
151        Some(vector) => {
152          let mut new_vector = Vec::new();
153          for yaml_value in vector.iter() {
154            if let Some(str_value) = yaml_value.as_str() {
155              new_vector.push(str_value.to_string());
156            }
157          }
158          new_vector
159        }
160        None => Vec::new(),
161      };
162      self.cache_ignore_headers_configured = match config["cacheIgnoreHeaders"].as_vec() {
163        Some(vector) => {
164          let mut new_vector = Vec::new();
165          for yaml_value in vector.iter() {
166            if let Some(str_value) = yaml_value.as_str() {
167              new_vector.push(str_value.to_string());
168            }
169          }
170          new_vector
171        }
172        None => Vec::new(),
173      };
174      self.maximum_cached_response_size = config["maximumCachedResponseSize"]
175        .as_i64()
176        .map(|f| f as u64);
177
178      let hyper_request = request.get_hyper_request();
179      let cache_key = format!(
180        "{} {}{}{}{}",
181        hyper_request.method().as_str(),
182        match socket_data.encrypted {
183          false => "http://",
184          true => "https://",
185        },
186        match hyper_request.headers().get(header::HOST) {
187          Some(host) => String::from_utf8_lossy(host.as_bytes()).into_owned(),
188          None => "".to_string(),
189        },
190        hyper_request.uri().path(),
191        match hyper_request.uri().query() {
192          Some(query) => format!("?{query}"),
193          None => "".to_string(),
194        }
195      );
196
197      let request_cache_control = match hyper_request.headers().get(header::CACHE_CONTROL) {
198        Some(value) => CacheControl::from_value(&String::from_utf8_lossy(value.as_bytes())),
199        None => None,
200      };
201
202      let mut no_store = false;
203      let mut no_cache = false;
204
205      if let Some(request_cache_control) = request_cache_control {
206        no_store = request_cache_control.no_store;
207        if let Some(cachability) = request_cache_control.cachability {
208          if cachability == Cachability::NoCache {
209            no_cache = true;
210          }
211        }
212      }
213
214      match hyper_request.method() {
215        &Method::GET | &Method::HEAD => (),
216        _ => {
217          no_store = true;
218        }
219      };
220
221      if no_store {
222        self.no_store = true;
223        return Ok(ResponseData::builder(request).build());
224      }
225
226      if !no_cache {
227        let rwlock_read = self.vary_cache.read().await;
228        let processed_vary = rwlock_read.get(&cache_key);
229        if let Some(processed_vary) = processed_vary {
230          let cache_key_with_vary = format!(
231            "{}\n{}",
232            &cache_key,
233            processed_vary
234              .iter()
235              .map(|header_name| {
236                match hyper_request.headers().get(header_name) {
237                  Some(header_value) => format!(
238                    "{}: {}",
239                    header_name,
240                    String::from_utf8_lossy(header_value.as_bytes()).into_owned()
241                  ),
242                  None => "".to_string(),
243                }
244              })
245              .collect::<Vec<String>>()
246              .join("\n")
247          );
248
249          drop(rwlock_read);
250
251          let rwlock_read = self.cache.read().await;
252          let cached_entry_option = rwlock_read.get(&cache_key_with_vary);
253
254          if let Some((status_code, headers, body, timestamp, response_cache_control)) =
255            cached_entry_option
256          {
257            let max_age = match response_cache_control {
258              Some(response_cache_control) => match response_cache_control.s_max_age {
259                Some(s_max_age) => Some(s_max_age),
260                None => response_cache_control.max_age,
261              },
262              None => None,
263            };
264
265            let mut cached = true;
266
267            if timestamp.elapsed() > max_age.unwrap_or(Duration::from_secs(DEFAULT_MAX_AGE)) {
268              cached = false;
269            }
270
271            if cached {
272              self.cached = true;
273              let mut hyper_response_builder = Response::builder().status(status_code);
274              for (header_name, header_value) in headers.iter() {
275                hyper_response_builder = hyper_response_builder.header(header_name, header_value);
276              }
277              let hyper_response = hyper_response_builder.body(
278                Full::new(Bytes::from(body.clone()))
279                  .map_err(|e| match e {})
280                  .boxed(),
281              )?;
282              return Ok(
283                ResponseData::builder(request)
284                  .response(hyper_response)
285                  .build(),
286              );
287            } else {
288              drop(rwlock_read);
289            }
290          }
291        } else {
292          drop(rwlock_read);
293        }
294      }
295
296      self.request_headers = hyper_request.headers().clone();
297      self.cache_key = Some(cache_key);
298      self.has_authorization = hyper_request.headers().contains_key(header::AUTHORIZATION);
299
300      Ok(ResponseData::builder(request).build())
301    })
302    .await
303  }
304
305  async fn proxy_request_handler(
306    &mut self,
307    request: RequestData,
308    _config: &ServerConfig,
309    _socket_data: &SocketData,
310    _error_logger: &ErrorLogger,
311  ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
312    Ok(ResponseData::builder(request).build())
313  }
314
315  async fn response_modifying_handler(
316    &mut self,
317    mut response: HyperResponse,
318  ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
319    WithRuntime::new(self.handle.clone(), async move {
320      if self.no_store {
321        response.headers_mut().insert(
322          HeaderName::from_static(CACHE_HEADER_NAME),
323          HeaderValue::from_static("BYPASS"),
324        );
325        Ok(response)
326      } else if self.cached {
327        response.headers_mut().insert(
328          HeaderName::from_static(CACHE_HEADER_NAME),
329          HeaderValue::from_static("HIT"),
330        );
331        Ok(response)
332      } else if let Some(cache_key) = &self.cache_key {
333        let (mut response_parts, mut response_body) = response.into_parts();
334        let response_cache_control = match response_parts.headers.get(header::CACHE_CONTROL) {
335          Some(value) => CacheControl::from_value(&String::from_utf8_lossy(value.as_bytes())),
336          None => None,
337        };
338
339        let should_cache_response = match &response_cache_control {
340          Some(response_cache_control) => {
341            let is_private = response_cache_control.cachability == Some(Cachability::Private);
342            let is_public = response_cache_control.cachability == Some(Cachability::Public);
343
344            !response_cache_control.no_store
345              && !is_private
346              && (is_public
347                || (!self.has_authorization
348                  && (response_cache_control.max_age.is_some()
349                    || response_cache_control.s_max_age.is_some())))
350          }
351          None => false,
352        };
353
354        if should_cache_response {
355          let mut response_body_buffer = Vec::new();
356          let mut maximum_cached_response_size_exceeded = false;
357
358          while let Some(frame) = response_body.frame().await {
359            let frame_unwrapped = frame?;
360            if frame_unwrapped.is_data() {
361              if let Some(bytes) = frame_unwrapped.data_ref() {
362                response_body_buffer.extend_from_slice(bytes);
363                if let Some(maximum_cached_response_size) = self.maximum_cached_response_size {
364                  if response_body_buffer.len() as u64 > maximum_cached_response_size {
365                    maximum_cached_response_size_exceeded = true;
366                    break;
367                  }
368                }
369              }
370            }
371          }
372
373          if maximum_cached_response_size_exceeded {
374            let cached_stream =
375              futures_util::stream::once(async move { Ok(Bytes::from(response_body_buffer)) });
376            let response_stream = response_body.into_data_stream();
377            let chained_stream = cached_stream.chain(response_stream);
378            let stream_body = StreamBody::new(chained_stream.map_ok(Frame::data));
379            let response_body = BodyExt::boxed(stream_body);
380            response_parts.headers.insert(
381              HeaderName::from_static(CACHE_HEADER_NAME),
382              HeaderValue::from_static("MISS"),
383            );
384            let response = Response::from_parts(response_parts, response_body);
385            Ok(response)
386          } else {
387            let mut response_vary = match response_parts.headers.get(header::VARY) {
388              Some(value) => String::from_utf8_lossy(value.as_bytes())
389                .split(",")
390                .map(|s| s.trim().to_owned())
391                .collect(),
392              None => Vec::new(),
393            };
394
395            let mut processed_vary_orig = self.cache_vary_headers_configured.clone();
396            processed_vary_orig.append(&mut response_vary);
397
398            let mut processed_vary = processed_vary_orig
399              .iter()
400              .map(|s| s.to_owned())
401              .collect::<Vec<String>>();
402
403            processed_vary.sort_unstable();
404            processed_vary.dedup();
405
406            if !processed_vary.contains(&"*".to_string()) {
407              let cache_key_with_vary = format!(
408                "{}\n{}",
409                &cache_key,
410                processed_vary
411                  .iter()
412                  .map(|header_name| {
413                    match self.request_headers.get(header_name) {
414                      Some(header_value) => format!(
415                        "{}: {}",
416                        header_name,
417                        String::from_utf8_lossy(header_value.as_bytes()).into_owned()
418                      ),
419                      None => "".to_string(),
420                    }
421                  })
422                  .collect::<Vec<String>>()
423                  .join("\n")
424              );
425
426              let mut rwlock_write = self.vary_cache.write().await;
427              rwlock_write.insert(cache_key.clone(), processed_vary);
428              drop(rwlock_write);
429
430              let mut written_headers = response_parts.headers.clone();
431              for header in self.cache_ignore_headers_configured.iter() {
432                while written_headers.remove(header).is_some() {}
433              }
434
435              let mut rwlock_write = self.cache.write().await;
436              rwlock_write.retain(|_, (_, _, _, timestamp, response_cache_control)| {
437                let max_age = match response_cache_control {
438                  Some(response_cache_control) => match response_cache_control.s_max_age {
439                    Some(s_max_age) => Some(s_max_age),
440                    None => response_cache_control.max_age,
441                  },
442                  None => None,
443                };
444
445                timestamp.elapsed() <= max_age.unwrap_or(Duration::from_secs(DEFAULT_MAX_AGE))
446              });
447
448              if let Some(maximum_cache_entries) = self.maximum_cache_entries {
449                // Remove a value at the front of the list
450                while !rwlock_write.is_empty() && rwlock_write.len() >= maximum_cache_entries {
451                  rwlock_write.pop_front();
452                }
453              }
454
455              // This inserts a value at the back of the list
456              rwlock_write.insert(
457                cache_key_with_vary,
458                (
459                  response_parts.status,
460                  written_headers,
461                  response_body_buffer.clone(),
462                  Instant::now(),
463                  response_cache_control,
464                ),
465              );
466              drop(rwlock_write);
467            }
468
469            let cached_stream =
470              futures_util::stream::once(async move { Ok(Bytes::from(response_body_buffer)) });
471            let stream_body = StreamBody::new(cached_stream.map_ok(Frame::data));
472            let response_body = BodyExt::boxed(stream_body);
473            response_parts.headers.insert(
474              HeaderName::from_static(CACHE_HEADER_NAME),
475              HeaderValue::from_static("MISS"),
476            );
477            let response = Response::from_parts(response_parts, response_body);
478            Ok(response)
479          }
480        } else {
481          response_parts.headers.insert(
482            HeaderName::from_static(CACHE_HEADER_NAME),
483            HeaderValue::from_static("MISS"),
484          );
485          let response = Response::from_parts(response_parts, response_body);
486          Ok(response)
487        }
488      } else {
489        Ok(response)
490      }
491    })
492    .await
493  }
494
495  async fn proxy_response_modifying_handler(
496    &mut self,
497    response: HyperResponse,
498  ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
499    Ok(response)
500  }
501
502  async fn connect_proxy_request_handler(
503    &mut self,
504    _upgraded_request: HyperUpgraded,
505    _connect_address: &str,
506    _config: &ServerConfig,
507    _socket_data: &SocketData,
508    _error_logger: &ErrorLogger,
509  ) -> Result<(), Box<dyn Error + Send + Sync>> {
510    Ok(())
511  }
512
513  fn does_connect_proxy_requests(&mut self) -> bool {
514    false
515  }
516
517  async fn websocket_request_handler(
518    &mut self,
519    _websocket: HyperWebsocket,
520    _uri: &hyper::Uri,
521    _headers: &hyper::HeaderMap,
522    _config: &ServerConfig,
523    _socket_data: &SocketData,
524    _error_logger: &ErrorLogger,
525  ) -> Result<(), Box<dyn Error + Send + Sync>> {
526    Ok(())
527  }
528
529  fn does_websocket_requests(&mut self, _config: &ServerConfig, _socket_data: &SocketData) -> bool {
530    false
531  }
532}