ferron/optional_modules/
fauth.rs

1// The "fauth" module is derived from the "rproxy" module, and inspired by Traefik's ForwardAuth middleware.
2
3use std::collections::HashMap;
4use std::error::Error;
5use std::str::FromStr;
6use std::sync::Arc;
7
8use crate::ferron_common::{
9  ErrorLogger, HyperUpgraded, RequestData, ResponseData, ServerConfig, ServerModule,
10  ServerModuleHandlers, SocketData,
11};
12use crate::ferron_common::{HyperResponse, WithRuntime};
13use async_trait::async_trait;
14use http_body_util::combinators::BoxBody;
15use http_body_util::{BodyExt, Empty};
16use hyper::body::Bytes;
17use hyper::client::conn::http1::SendRequest;
18use hyper::header::{self, HeaderName};
19use hyper::{Method, Request, StatusCode, Uri, Version};
20use hyper_tungstenite::HyperWebsocket;
21use hyper_util::rt::TokioIo;
22use rustls::pki_types::ServerName;
23use rustls::RootCertStore;
24use rustls_native_certs::load_native_certs;
25use tokio::io::{AsyncRead, AsyncWrite};
26use tokio::net::TcpStream;
27use tokio::runtime::Handle;
28use tokio::sync::RwLock;
29use tokio_rustls::TlsConnector;
30
31const DEFAULT_CONCURRENT_CONNECTIONS_PER_HOST: u32 = 32;
32
33pub fn server_module_init(
34  _config: &ServerConfig,
35) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
36  let mut roots: RootCertStore = RootCertStore::empty();
37  let certs_result = load_native_certs();
38  if !certs_result.errors.is_empty() {
39    Err(anyhow::anyhow!(
40      "Couldn't load the native certificate store: {}",
41      certs_result.errors[0]
42    ))?
43  }
44  let certs = certs_result.certs;
45
46  for cert in certs {
47    match roots.add(cert) {
48      Ok(_) => (),
49      Err(err) => Err(anyhow::anyhow!(
50        "Couldn't add a certificate to the certificate store: {}",
51        err
52      ))?,
53    }
54  }
55
56  let mut connections_vec = Vec::new();
57  for _ in 0..DEFAULT_CONCURRENT_CONNECTIONS_PER_HOST {
58    connections_vec.push(RwLock::new(HashMap::new()));
59  }
60  Ok(Box::new(ForwardedAuthenticationModule::new(
61    Arc::new(roots),
62    Arc::new(connections_vec),
63  )))
64}
65
66#[allow(clippy::type_complexity)]
67struct ForwardedAuthenticationModule {
68  roots: Arc<RootCertStore>,
69  connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, hyper::Error>>>>>>,
70}
71
72impl ForwardedAuthenticationModule {
73  #[allow(clippy::type_complexity)]
74  fn new(
75    roots: Arc<RootCertStore>,
76    connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, hyper::Error>>>>>>,
77  ) -> Self {
78    Self { roots, connections }
79  }
80}
81
82impl ServerModule for ForwardedAuthenticationModule {
83  fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
84    Box::new(ForwardedAuthenticationModuleHandlers {
85      roots: self.roots.clone(),
86      connections: self.connections.clone(),
87      handle,
88    })
89  }
90}
91
92#[allow(clippy::type_complexity)]
93struct ForwardedAuthenticationModuleHandlers {
94  handle: Handle,
95  roots: Arc<RootCertStore>,
96  connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, hyper::Error>>>>>>,
97}
98
99#[async_trait]
100impl ServerModuleHandlers for ForwardedAuthenticationModuleHandlers {
101  async fn request_handler(
102    &mut self,
103    request: RequestData,
104    config: &ServerConfig,
105    socket_data: &SocketData,
106    error_logger: &ErrorLogger,
107  ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
108    WithRuntime::new(self.handle.clone(), async move {
109      let mut auth_to = None;
110
111      if let Some(auth_to_str) = config["authTo"].as_str() {
112        auth_to = Some(auth_to_str.to_string());
113      }
114
115      let forwarded_auth_copy_headers = match config["forwardedAuthCopyHeaders"].as_vec() {
116        Some(vector) => {
117          let mut new_vector = Vec::new();
118          for yaml_value in vector.iter() {
119            if let Some(str_value) = yaml_value.as_str() {
120              new_vector.push(str_value.to_string());
121            }
122          }
123          new_vector
124        }
125        None => Vec::new(),
126      };
127
128      if let Some(auth_to) = auth_to {
129        let (hyper_request, auth_user, original_url, error_status_code) = request.into_parts();
130        let (hyper_request_parts, request_body) = hyper_request.into_parts();
131
132        let auth_request_url = auth_to.parse::<hyper::Uri>()?;
133        let scheme_str = auth_request_url.scheme_str();
134        let mut encrypted = false;
135
136        match scheme_str {
137          Some("http") => {
138            encrypted = false;
139          }
140          Some("https") => {
141            encrypted = true;
142          }
143          _ => Err(anyhow::anyhow!(
144            "Only HTTP and HTTPS reverse proxy URLs are supported."
145          ))?,
146        };
147
148        let host = match auth_request_url.host() {
149          Some(host) => host,
150          None => Err(anyhow::anyhow!(
151            "The reverse proxy URL doesn't include the host"
152          ))?,
153        };
154
155        let port = auth_request_url.port_u16().unwrap_or(match scheme_str {
156          Some("http") => 80,
157          Some("https") => 443,
158          _ => 80,
159        });
160
161        let addr = format!("{host}:{port}");
162        let authority = auth_request_url.authority().cloned();
163
164        let hyper_request_path = hyper_request_parts.uri.path();
165
166        let path_and_query = format!(
167          "{}{}",
168          hyper_request_path,
169          match hyper_request_parts.uri.query() {
170            Some(query) => format!("?{query}"),
171            None => "".to_string(),
172          }
173        );
174
175        let mut auth_hyper_request_parts = hyper_request_parts.clone();
176
177        auth_hyper_request_parts.uri = Uri::from_str(&format!(
178          "{}{}",
179          auth_request_url.path(),
180          match auth_request_url.query() {
181            Some(query) => format!("?{query}"),
182            None => "".to_string(),
183          }
184        ))?;
185
186        let original_host = hyper_request_parts.headers.get(header::HOST).cloned();
187
188        // Host header for host identification
189        match authority {
190          Some(authority) => {
191            auth_hyper_request_parts
192              .headers
193              .insert(header::HOST, authority.to_string().parse()?);
194          }
195          None => {
196            auth_hyper_request_parts.headers.remove(header::HOST);
197          }
198        }
199
200        // Connection header to enable HTTP/1.1 keep-alive
201        auth_hyper_request_parts
202          .headers
203          .insert(header::CONNECTION, "keep-alive".parse()?);
204
205        // X-Forwarded-* headers to send the client's data to a forwarded authentication server
206        auth_hyper_request_parts.headers.insert(
207          HeaderName::from_static("x-forwarded-for"),
208          socket_data
209            .remote_addr
210            .ip()
211            .to_canonical()
212            .to_string()
213            .parse()?,
214        );
215
216        if socket_data.encrypted {
217          auth_hyper_request_parts.headers.insert(
218            HeaderName::from_static("x-forwarded-proto"),
219            "https".parse()?,
220          );
221        } else {
222          auth_hyper_request_parts.headers.insert(
223            HeaderName::from_static("x-forwarded-proto"),
224            "http".parse()?,
225          );
226        }
227
228        if let Some(original_host) = original_host {
229          auth_hyper_request_parts
230            .headers
231            .insert(HeaderName::from_static("x-forwarded-host"), original_host);
232        }
233
234        auth_hyper_request_parts.headers.insert(
235          HeaderName::from_static("x-forwarded-uri"),
236          path_and_query.parse()?,
237        );
238
239        auth_hyper_request_parts.headers.insert(
240          HeaderName::from_static("x-forwarded-method"),
241          hyper_request_parts.method.as_str().parse()?,
242        );
243
244        auth_hyper_request_parts.method = Method::GET;
245        auth_hyper_request_parts.version = Version::HTTP_11;
246
247        let auth_request = Request::from_parts(
248          auth_hyper_request_parts,
249          Empty::new().map_err(|e| match e {}).boxed(),
250        );
251        let original_hyper_request = Request::from_parts(hyper_request_parts, request_body);
252        let original_request = RequestData::new(
253          original_hyper_request,
254          auth_user,
255          original_url,
256          error_status_code,
257        );
258
259        let connections = &self.connections[rand::random_range(..self.connections.len())];
260
261        let rwlock_read = connections.read().await;
262        let sender_read_option = rwlock_read.get(&addr);
263
264        if let Some(sender_read) = sender_read_option {
265          if !sender_read.is_closed() {
266            drop(rwlock_read);
267            let mut rwlock_write = connections.write().await;
268            let sender_option = rwlock_write.get_mut(&addr);
269
270            if let Some(sender) = sender_option {
271              if !sender.is_closed() && sender.ready().await.is_ok() {
272                let result = http_forwarded_auth_kept_alive(
273                  sender,
274                  auth_request,
275                  error_logger,
276                  original_request,
277                  forwarded_auth_copy_headers,
278                )
279                .await;
280                drop(rwlock_write);
281                return result;
282              } else {
283                drop(rwlock_write);
284              }
285            } else {
286              drop(rwlock_write);
287            }
288          } else {
289            drop(rwlock_read);
290          }
291        } else {
292          drop(rwlock_read);
293        }
294
295        let stream = match TcpStream::connect(&addr).await {
296          Ok(stream) => stream,
297          Err(err) => {
298            match err.kind() {
299              tokio::io::ErrorKind::ConnectionRefused
300              | tokio::io::ErrorKind::NotFound
301              | tokio::io::ErrorKind::HostUnreachable => {
302                error_logger
303                  .log(&format!("Service unavailable: {err}"))
304                  .await;
305                return Ok(
306                  ResponseData::builder_without_request()
307                    .status(StatusCode::SERVICE_UNAVAILABLE)
308                    .build(),
309                );
310              }
311              tokio::io::ErrorKind::TimedOut => {
312                error_logger.log(&format!("Gateway timeout: {err}")).await;
313                return Ok(
314                  ResponseData::builder_without_request()
315                    .status(StatusCode::GATEWAY_TIMEOUT)
316                    .build(),
317                );
318              }
319              _ => {
320                error_logger.log(&format!("Bad gateway: {err}")).await;
321                return Ok(
322                  ResponseData::builder_without_request()
323                    .status(StatusCode::BAD_GATEWAY)
324                    .build(),
325                );
326              }
327            };
328          }
329        };
330
331        match stream.set_nodelay(true) {
332          Ok(_) => (),
333          Err(err) => {
334            error_logger.log(&format!("Bad gateway: {err}")).await;
335            return Ok(
336              ResponseData::builder_without_request()
337                .status(StatusCode::BAD_GATEWAY)
338                .build(),
339            );
340          }
341        };
342
343        if !encrypted {
344          http_forwarded_auth(
345            connections,
346            addr,
347            stream,
348            auth_request,
349            error_logger,
350            original_request,
351            forwarded_auth_copy_headers,
352          )
353          .await
354        } else {
355          let tls_client_config = rustls::ClientConfig::builder()
356            .with_root_certificates(self.roots.clone())
357            .with_no_client_auth();
358          let connector = TlsConnector::from(Arc::new(tls_client_config));
359          let domain = ServerName::try_from(host)?.to_owned();
360
361          let tls_stream = match connector.connect(domain, stream).await {
362            Ok(stream) => stream,
363            Err(err) => {
364              error_logger.log(&format!("Bad gateway: {err}")).await;
365              return Ok(
366                ResponseData::builder_without_request()
367                  .status(StatusCode::BAD_GATEWAY)
368                  .build(),
369              );
370            }
371          };
372
373          http_forwarded_auth(
374            connections,
375            addr,
376            tls_stream,
377            auth_request,
378            error_logger,
379            original_request,
380            forwarded_auth_copy_headers,
381          )
382          .await
383        }
384      } else {
385        Ok(ResponseData::builder(request).build())
386      }
387    })
388    .await
389  }
390
391  async fn proxy_request_handler(
392    &mut self,
393    request: RequestData,
394    _config: &ServerConfig,
395    _socket_data: &SocketData,
396    _error_logger: &ErrorLogger,
397  ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
398    Ok(ResponseData::builder(request).build())
399  }
400
401  async fn response_modifying_handler(
402    &mut self,
403    response: HyperResponse,
404  ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
405    Ok(response)
406  }
407
408  async fn proxy_response_modifying_handler(
409    &mut self,
410    response: HyperResponse,
411  ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
412    Ok(response)
413  }
414
415  async fn connect_proxy_request_handler(
416    &mut self,
417    _upgraded_request: HyperUpgraded,
418    _connect_address: &str,
419    _config: &ServerConfig,
420    _socket_data: &SocketData,
421    _error_logger: &ErrorLogger,
422  ) -> Result<(), Box<dyn Error + Send + Sync>> {
423    Ok(())
424  }
425
426  fn does_connect_proxy_requests(&mut self) -> bool {
427    false
428  }
429
430  async fn websocket_request_handler(
431    &mut self,
432    _websocket: HyperWebsocket,
433    _uri: &hyper::Uri,
434    _headers: &hyper::HeaderMap,
435    _config: &ServerConfig,
436    _socket_data: &SocketData,
437    _error_logger: &ErrorLogger,
438  ) -> Result<(), Box<dyn Error + Send + Sync>> {
439    Ok(())
440  }
441
442  fn does_websocket_requests(&mut self, _config: &ServerConfig, _socket_data: &SocketData) -> bool {
443    false
444  }
445}
446
447async fn http_forwarded_auth(
448  connections: &RwLock<HashMap<String, SendRequest<BoxBody<Bytes, hyper::Error>>>>,
449  connect_addr: String,
450  stream: impl AsyncRead + AsyncWrite + Send + Unpin + 'static,
451  proxy_request: Request<BoxBody<Bytes, hyper::Error>>,
452  error_logger: &ErrorLogger,
453  mut original_request: RequestData,
454  forwarded_auth_copy_headers: Vec<String>,
455) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
456  let io = TokioIo::new(stream);
457
458  let (mut sender, conn) = match hyper::client::conn::http1::handshake(io).await {
459    Ok(data) => data,
460    Err(err) => {
461      error_logger.log(&format!("Bad gateway: {err}")).await;
462      return Ok(
463        ResponseData::builder_without_request()
464          .status(StatusCode::BAD_GATEWAY)
465          .build(),
466      );
467    }
468  };
469
470  let send_request = sender.send_request(proxy_request);
471
472  let mut pinned_conn = Box::pin(conn);
473  tokio::pin!(send_request);
474
475  let response;
476
477  loop {
478    tokio::select! {
479      biased;
480
481       proxy_response = &mut send_request => {
482        let proxy_response = match proxy_response {
483          Ok(response) => response,
484          Err(err) => {
485            error_logger.log(&format!("Bad gateway: {err}")).await;
486            return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build());
487          }
488        };
489
490        if proxy_response.status().is_success() {
491          if !forwarded_auth_copy_headers.is_empty() {
492            let response_headers = proxy_response.headers();
493            let request_headers = original_request.get_mut_hyper_request().headers_mut();
494            for forwarded_auth_copy_header_string in forwarded_auth_copy_headers.iter() {
495              let forwarded_auth_copy_header= HeaderName::from_str(forwarded_auth_copy_header_string)?;
496              if response_headers.contains_key(&forwarded_auth_copy_header) {
497                while request_headers.remove(&forwarded_auth_copy_header).is_some() {}
498                for header_value in response_headers.get_all(&forwarded_auth_copy_header).iter() {
499                  request_headers.append(&forwarded_auth_copy_header, header_value.clone());
500                }
501              }
502            }
503          }
504          response = ResponseData::builder(original_request).build();
505        } else {
506          response = ResponseData::builder_without_request()
507          .response(proxy_response.map(|b| {
508            b.map_err(|e| std::io::Error::other(e.to_string()))
509              .boxed()
510          }))
511          .parallel_fn(async move {
512            pinned_conn.await.unwrap_or_default();
513          })
514          .build();
515
516        }
517
518        break;
519      },
520      state = &mut pinned_conn => {
521        if state.is_err() {
522          error_logger.log("Bad gateway: incomplete response").await;
523          return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build());
524        }
525      },
526    };
527  }
528
529  if !sender.is_closed() {
530    let mut rwlock_write = connections.write().await;
531    rwlock_write.insert(connect_addr, sender);
532    drop(rwlock_write);
533  }
534
535  Ok(response)
536}
537
538async fn http_forwarded_auth_kept_alive(
539  sender: &mut SendRequest<BoxBody<Bytes, hyper::Error>>,
540  proxy_request: Request<BoxBody<Bytes, hyper::Error>>,
541  error_logger: &ErrorLogger,
542  mut original_request: RequestData,
543  forwarded_auth_copy_headers: Vec<String>,
544) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
545  let proxy_response = match sender.send_request(proxy_request).await {
546    Ok(response) => response,
547    Err(err) => {
548      error_logger.log(&format!("Bad gateway: {err}")).await;
549      return Ok(
550        ResponseData::builder_without_request()
551          .status(StatusCode::BAD_GATEWAY)
552          .build(),
553      );
554    }
555  };
556
557  let response = if proxy_response.status().is_success() {
558    if !forwarded_auth_copy_headers.is_empty() {
559      let response_headers = proxy_response.headers();
560      let request_headers = original_request.get_mut_hyper_request().headers_mut();
561      for forwarded_auth_copy_header_string in forwarded_auth_copy_headers.iter() {
562        let forwarded_auth_copy_header = HeaderName::from_str(forwarded_auth_copy_header_string)?;
563        if response_headers.contains_key(&forwarded_auth_copy_header) {
564          while request_headers
565            .remove(&forwarded_auth_copy_header)
566            .is_some()
567          {}
568          for header_value in response_headers.get_all(&forwarded_auth_copy_header).iter() {
569            request_headers.append(&forwarded_auth_copy_header, header_value.clone());
570          }
571        }
572      }
573    }
574    ResponseData::builder(original_request).build()
575  } else {
576    ResponseData::builder_without_request()
577      .response(proxy_response.map(|b| b.map_err(|e| std::io::Error::other(e.to_string())).boxed()))
578      .build()
579  };
580
581  Ok(response)
582}