1use std::collections::HashMap;
2use std::error::Error;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use crate::ferron_common::{
7 ErrorLogger, HyperUpgraded, RequestData, ResponseData, ServerConfig, ServerModule,
8 ServerModuleHandlers, SocketData,
9};
10use crate::ferron_common::{HyperResponse, WithRuntime};
11use ahash::RandomState;
12use async_trait::async_trait;
13use cache_control::{Cachability, CacheControl};
14use futures_util::{StreamExt, TryStreamExt};
15use hashlink::LinkedHashMap;
16use http_body_util::{BodyExt, Full, StreamBody};
17use hyper::body::{Bytes, Frame};
18use hyper::header::{HeaderName, HeaderValue};
19use hyper::{header, HeaderMap, Method, Response, StatusCode};
20use hyper_tungstenite::HyperWebsocket;
21use tokio::runtime::Handle;
22use tokio::sync::RwLock;
23
/// Name of the response header that reports the cache outcome (`HIT`, `MISS`
/// or `BYPASS`).
///
/// Kept lowercase: `HeaderName::from_static` only accepts lowercase names and
/// panics on uppercase characters. `HeaderName` stores names lowercase anyway,
/// so the header name sent to clients is unchanged.
const CACHE_HEADER_NAME: &str = "x-ferron-cache";
/// Freshness lifetime, in seconds, applied to a cached response whose
/// `Cache-Control` carries neither `s-maxage` nor `max-age`.
const DEFAULT_MAX_AGE: u64 = 300;
26
27pub fn server_module_init(
28 config: &ServerConfig,
29) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
30 let maximum_cache_entries = config["global"]["maximumCacheEntries"]
31 .as_i64()
32 .map(|v| v as usize);
33
34 Ok(Box::new(CacheModule::new(
35 Arc::new(RwLock::new(LinkedHashMap::with_hasher(RandomState::new()))),
36 Arc::new(RwLock::new(HashMap::with_hasher(RandomState::new()))),
37 maximum_cache_entries,
38 )))
39}
40
41#[allow(clippy::type_complexity)]
42struct CacheModule {
43 cache: Arc<
44 RwLock<
45 LinkedHashMap<
46 String,
47 (
48 StatusCode,
49 HeaderMap,
50 Vec<u8>,
51 Instant,
52 Option<CacheControl>,
53 ),
54 RandomState,
55 >,
56 >,
57 >,
58 vary_cache: Arc<RwLock<HashMap<String, Vec<String>, RandomState>>>,
59 maximum_cache_entries: Option<usize>,
60}
61
62impl CacheModule {
63 #[allow(clippy::type_complexity)]
64 fn new(
65 cache: Arc<
66 RwLock<
67 LinkedHashMap<
68 String,
69 (
70 StatusCode,
71 HeaderMap,
72 Vec<u8>,
73 Instant,
74 Option<CacheControl>,
75 ),
76 RandomState,
77 >,
78 >,
79 >,
80 vary_cache: Arc<RwLock<HashMap<String, Vec<String>, RandomState>>>,
81 maximum_cache_entries: Option<usize>,
82 ) -> Self {
83 Self {
84 cache,
85 vary_cache,
86 maximum_cache_entries,
87 }
88 }
89}
90
91impl ServerModule for CacheModule {
92 fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
93 Box::new(CacheModuleHandlers {
94 cache: self.cache.clone(),
95 vary_cache: self.vary_cache.clone(),
96 maximum_cache_entries: self.maximum_cache_entries,
97 cache_vary_headers_configured: Vec::new(),
98 cache_ignore_headers_configured: Vec::new(),
99 maximum_cached_response_size: None,
100 cache_key: None,
101 request_headers: HeaderMap::new(),
102 has_authorization: false,
103 cached: false,
104 no_store: false,
105 handle,
106 })
107 }
108}
109
110#[allow(clippy::type_complexity)]
111struct CacheModuleHandlers {
112 handle: Handle,
113 cache: Arc<
114 RwLock<
115 LinkedHashMap<
116 String,
117 (
118 StatusCode,
119 HeaderMap,
120 Vec<u8>,
121 Instant,
122 Option<CacheControl>,
123 ),
124 RandomState,
125 >,
126 >,
127 >,
128 vary_cache: Arc<RwLock<HashMap<String, Vec<String>, RandomState>>>,
129 maximum_cache_entries: Option<usize>,
130 cache_vary_headers_configured: Vec<String>,
131 cache_ignore_headers_configured: Vec<String>,
132 maximum_cached_response_size: Option<u64>,
133 cache_key: Option<String>,
134 request_headers: HeaderMap<HeaderValue>,
135 has_authorization: bool,
136 cached: bool,
137 no_store: bool,
138}
139
140#[async_trait]
141impl ServerModuleHandlers for CacheModuleHandlers {
142 async fn request_handler(
143 &mut self,
144 request: RequestData,
145 config: &ServerConfig,
146 socket_data: &SocketData,
147 _error_logger: &ErrorLogger,
148 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
149 WithRuntime::new(self.handle.clone(), async move {
150 self.cache_vary_headers_configured = match config["cacheVaryHeaders"].as_vec() {
151 Some(vector) => {
152 let mut new_vector = Vec::new();
153 for yaml_value in vector.iter() {
154 if let Some(str_value) = yaml_value.as_str() {
155 new_vector.push(str_value.to_string());
156 }
157 }
158 new_vector
159 }
160 None => Vec::new(),
161 };
162 self.cache_ignore_headers_configured = match config["cacheIgnoreHeaders"].as_vec() {
163 Some(vector) => {
164 let mut new_vector = Vec::new();
165 for yaml_value in vector.iter() {
166 if let Some(str_value) = yaml_value.as_str() {
167 new_vector.push(str_value.to_string());
168 }
169 }
170 new_vector
171 }
172 None => Vec::new(),
173 };
174 self.maximum_cached_response_size = config["maximumCachedResponseSize"]
175 .as_i64()
176 .map(|f| f as u64);
177
178 let hyper_request = request.get_hyper_request();
179 let cache_key = format!(
180 "{} {}{}{}{}",
181 hyper_request.method().as_str(),
182 match socket_data.encrypted {
183 false => "http://",
184 true => "https://",
185 },
186 match hyper_request.headers().get(header::HOST) {
187 Some(host) => String::from_utf8_lossy(host.as_bytes()).into_owned(),
188 None => "".to_string(),
189 },
190 hyper_request.uri().path(),
191 match hyper_request.uri().query() {
192 Some(query) => format!("?{query}"),
193 None => "".to_string(),
194 }
195 );
196
197 let request_cache_control = match hyper_request.headers().get(header::CACHE_CONTROL) {
198 Some(value) => CacheControl::from_value(&String::from_utf8_lossy(value.as_bytes())),
199 None => None,
200 };
201
202 let mut no_store = false;
203 let mut no_cache = false;
204
205 if let Some(request_cache_control) = request_cache_control {
206 no_store = request_cache_control.no_store;
207 if let Some(cachability) = request_cache_control.cachability {
208 if cachability == Cachability::NoCache {
209 no_cache = true;
210 }
211 }
212 }
213
214 match hyper_request.method() {
215 &Method::GET | &Method::HEAD => (),
216 _ => {
217 no_store = true;
218 }
219 };
220
221 if no_store {
222 self.no_store = true;
223 return Ok(ResponseData::builder(request).build());
224 }
225
226 if !no_cache {
227 let rwlock_read = self.vary_cache.read().await;
228 let processed_vary = rwlock_read.get(&cache_key);
229 if let Some(processed_vary) = processed_vary {
230 let cache_key_with_vary = format!(
231 "{}\n{}",
232 &cache_key,
233 processed_vary
234 .iter()
235 .map(|header_name| {
236 match hyper_request.headers().get(header_name) {
237 Some(header_value) => format!(
238 "{}: {}",
239 header_name,
240 String::from_utf8_lossy(header_value.as_bytes()).into_owned()
241 ),
242 None => "".to_string(),
243 }
244 })
245 .collect::<Vec<String>>()
246 .join("\n")
247 );
248
249 drop(rwlock_read);
250
251 let rwlock_read = self.cache.read().await;
252 let cached_entry_option = rwlock_read.get(&cache_key_with_vary);
253
254 if let Some((status_code, headers, body, timestamp, response_cache_control)) =
255 cached_entry_option
256 {
257 let max_age = match response_cache_control {
258 Some(response_cache_control) => match response_cache_control.s_max_age {
259 Some(s_max_age) => Some(s_max_age),
260 None => response_cache_control.max_age,
261 },
262 None => None,
263 };
264
265 let mut cached = true;
266
267 if timestamp.elapsed() > max_age.unwrap_or(Duration::from_secs(DEFAULT_MAX_AGE)) {
268 cached = false;
269 }
270
271 if cached {
272 self.cached = true;
273 let mut hyper_response_builder = Response::builder().status(status_code);
274 for (header_name, header_value) in headers.iter() {
275 hyper_response_builder = hyper_response_builder.header(header_name, header_value);
276 }
277 let hyper_response = hyper_response_builder.body(
278 Full::new(Bytes::from(body.clone()))
279 .map_err(|e| match e {})
280 .boxed(),
281 )?;
282 return Ok(
283 ResponseData::builder(request)
284 .response(hyper_response)
285 .build(),
286 );
287 } else {
288 drop(rwlock_read);
289 }
290 }
291 } else {
292 drop(rwlock_read);
293 }
294 }
295
296 self.request_headers = hyper_request.headers().clone();
297 self.cache_key = Some(cache_key);
298 self.has_authorization = hyper_request.headers().contains_key(header::AUTHORIZATION);
299
300 Ok(ResponseData::builder(request).build())
301 })
302 .await
303 }
304
305 async fn proxy_request_handler(
306 &mut self,
307 request: RequestData,
308 _config: &ServerConfig,
309 _socket_data: &SocketData,
310 _error_logger: &ErrorLogger,
311 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
312 Ok(ResponseData::builder(request).build())
313 }
314
315 async fn response_modifying_handler(
316 &mut self,
317 mut response: HyperResponse,
318 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
319 WithRuntime::new(self.handle.clone(), async move {
320 if self.no_store {
321 response.headers_mut().insert(
322 HeaderName::from_static(CACHE_HEADER_NAME),
323 HeaderValue::from_static("BYPASS"),
324 );
325 Ok(response)
326 } else if self.cached {
327 response.headers_mut().insert(
328 HeaderName::from_static(CACHE_HEADER_NAME),
329 HeaderValue::from_static("HIT"),
330 );
331 Ok(response)
332 } else if let Some(cache_key) = &self.cache_key {
333 let (mut response_parts, mut response_body) = response.into_parts();
334 let response_cache_control = match response_parts.headers.get(header::CACHE_CONTROL) {
335 Some(value) => CacheControl::from_value(&String::from_utf8_lossy(value.as_bytes())),
336 None => None,
337 };
338
339 let should_cache_response = match &response_cache_control {
340 Some(response_cache_control) => {
341 let is_private = response_cache_control.cachability == Some(Cachability::Private);
342 let is_public = response_cache_control.cachability == Some(Cachability::Public);
343
344 !response_cache_control.no_store
345 && !is_private
346 && (is_public
347 || (!self.has_authorization
348 && (response_cache_control.max_age.is_some()
349 || response_cache_control.s_max_age.is_some())))
350 }
351 None => false,
352 };
353
354 if should_cache_response {
355 let mut response_body_buffer = Vec::new();
356 let mut maximum_cached_response_size_exceeded = false;
357
358 while let Some(frame) = response_body.frame().await {
359 let frame_unwrapped = frame?;
360 if frame_unwrapped.is_data() {
361 if let Some(bytes) = frame_unwrapped.data_ref() {
362 response_body_buffer.extend_from_slice(bytes);
363 if let Some(maximum_cached_response_size) = self.maximum_cached_response_size {
364 if response_body_buffer.len() as u64 > maximum_cached_response_size {
365 maximum_cached_response_size_exceeded = true;
366 break;
367 }
368 }
369 }
370 }
371 }
372
373 if maximum_cached_response_size_exceeded {
374 let cached_stream =
375 futures_util::stream::once(async move { Ok(Bytes::from(response_body_buffer)) });
376 let response_stream = response_body.into_data_stream();
377 let chained_stream = cached_stream.chain(response_stream);
378 let stream_body = StreamBody::new(chained_stream.map_ok(Frame::data));
379 let response_body = BodyExt::boxed(stream_body);
380 response_parts.headers.insert(
381 HeaderName::from_static(CACHE_HEADER_NAME),
382 HeaderValue::from_static("MISS"),
383 );
384 let response = Response::from_parts(response_parts, response_body);
385 Ok(response)
386 } else {
387 let mut response_vary = match response_parts.headers.get(header::VARY) {
388 Some(value) => String::from_utf8_lossy(value.as_bytes())
389 .split(",")
390 .map(|s| s.trim().to_owned())
391 .collect(),
392 None => Vec::new(),
393 };
394
395 let mut processed_vary_orig = self.cache_vary_headers_configured.clone();
396 processed_vary_orig.append(&mut response_vary);
397
398 let mut processed_vary = processed_vary_orig
399 .iter()
400 .map(|s| s.to_owned())
401 .collect::<Vec<String>>();
402
403 processed_vary.sort_unstable();
404 processed_vary.dedup();
405
406 if !processed_vary.contains(&"*".to_string()) {
407 let cache_key_with_vary = format!(
408 "{}\n{}",
409 &cache_key,
410 processed_vary
411 .iter()
412 .map(|header_name| {
413 match self.request_headers.get(header_name) {
414 Some(header_value) => format!(
415 "{}: {}",
416 header_name,
417 String::from_utf8_lossy(header_value.as_bytes()).into_owned()
418 ),
419 None => "".to_string(),
420 }
421 })
422 .collect::<Vec<String>>()
423 .join("\n")
424 );
425
426 let mut rwlock_write = self.vary_cache.write().await;
427 rwlock_write.insert(cache_key.clone(), processed_vary);
428 drop(rwlock_write);
429
430 let mut written_headers = response_parts.headers.clone();
431 for header in self.cache_ignore_headers_configured.iter() {
432 while written_headers.remove(header).is_some() {}
433 }
434
435 let mut rwlock_write = self.cache.write().await;
436 rwlock_write.retain(|_, (_, _, _, timestamp, response_cache_control)| {
437 let max_age = match response_cache_control {
438 Some(response_cache_control) => match response_cache_control.s_max_age {
439 Some(s_max_age) => Some(s_max_age),
440 None => response_cache_control.max_age,
441 },
442 None => None,
443 };
444
445 timestamp.elapsed() <= max_age.unwrap_or(Duration::from_secs(DEFAULT_MAX_AGE))
446 });
447
448 if let Some(maximum_cache_entries) = self.maximum_cache_entries {
449 while !rwlock_write.is_empty() && rwlock_write.len() >= maximum_cache_entries {
451 rwlock_write.pop_front();
452 }
453 }
454
455 rwlock_write.insert(
457 cache_key_with_vary,
458 (
459 response_parts.status,
460 written_headers,
461 response_body_buffer.clone(),
462 Instant::now(),
463 response_cache_control,
464 ),
465 );
466 drop(rwlock_write);
467 }
468
469 let cached_stream =
470 futures_util::stream::once(async move { Ok(Bytes::from(response_body_buffer)) });
471 let stream_body = StreamBody::new(cached_stream.map_ok(Frame::data));
472 let response_body = BodyExt::boxed(stream_body);
473 response_parts.headers.insert(
474 HeaderName::from_static(CACHE_HEADER_NAME),
475 HeaderValue::from_static("MISS"),
476 );
477 let response = Response::from_parts(response_parts, response_body);
478 Ok(response)
479 }
480 } else {
481 response_parts.headers.insert(
482 HeaderName::from_static(CACHE_HEADER_NAME),
483 HeaderValue::from_static("MISS"),
484 );
485 let response = Response::from_parts(response_parts, response_body);
486 Ok(response)
487 }
488 } else {
489 Ok(response)
490 }
491 })
492 .await
493 }
494
495 async fn proxy_response_modifying_handler(
496 &mut self,
497 response: HyperResponse,
498 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
499 Ok(response)
500 }
501
502 async fn connect_proxy_request_handler(
503 &mut self,
504 _upgraded_request: HyperUpgraded,
505 _connect_address: &str,
506 _config: &ServerConfig,
507 _socket_data: &SocketData,
508 _error_logger: &ErrorLogger,
509 ) -> Result<(), Box<dyn Error + Send + Sync>> {
510 Ok(())
511 }
512
513 fn does_connect_proxy_requests(&mut self) -> bool {
514 false
515 }
516
517 async fn websocket_request_handler(
518 &mut self,
519 _websocket: HyperWebsocket,
520 _uri: &hyper::Uri,
521 _headers: &hyper::HeaderMap,
522 _config: &ServerConfig,
523 _socket_data: &SocketData,
524 _error_logger: &ErrorLogger,
525 ) -> Result<(), Box<dyn Error + Send + Sync>> {
526 Ok(())
527 }
528
529 fn does_websocket_requests(&mut self, _config: &ServerConfig, _socket_data: &SocketData) -> bool {
530 false
531 }
532}