ferron/optional_modules/
rproxy.rs

1use std::collections::HashMap;
2use std::error::Error;
3use std::str::FromStr;
4use std::sync::Arc;
5use std::time::Duration;
6
7use crate::ferron_common::{
8  ErrorLogger, HyperUpgraded, RequestData, ResponseData, ServerConfig, ServerModule,
9  ServerModuleHandlers, SocketData,
10};
11use crate::ferron_common::{HyperResponse, WithRuntime};
12use async_trait::async_trait;
13use futures_util::{SinkExt, StreamExt};
14use http::header::SEC_WEBSOCKET_PROTOCOL;
15use http::uri::{PathAndQuery, Scheme};
16use http_body_util::combinators::BoxBody;
17use http_body_util::BodyExt;
18use hyper::body::Bytes;
19use hyper::client::conn::http1::SendRequest;
20use hyper::{header, Request, StatusCode, Uri, Version};
21use hyper_tungstenite::HyperWebsocket;
22use hyper_util::rt::TokioIo;
23use rustls::pki_types::ServerName;
24use rustls::RootCertStore;
25use rustls_native_certs::load_native_certs;
26use tokio::io::{AsyncRead, AsyncWrite};
27use tokio::net::TcpStream;
28use tokio::runtime::Handle;
29use tokio::sync::RwLock;
30use tokio_rustls::TlsConnector;
31use tokio_tungstenite::tungstenite::client::IntoClientRequest;
32use tokio_tungstenite::tungstenite::ClientRequestBuilder;
33use tokio_tungstenite::Connector;
34
35use crate::ferron_util::no_server_verifier::NoServerVerifier;
36use crate::ferron_util::ttl_cache::TtlCache;
37
38const DEFAULT_CONCURRENT_CONNECTIONS_PER_HOST: u32 = 32;
39
40pub fn server_module_init(
41  config: &ServerConfig,
42) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
43  let mut roots: RootCertStore = RootCertStore::empty();
44  let certs_result = load_native_certs();
45  if !certs_result.errors.is_empty() {
46    Err(anyhow::anyhow!(format!(
47      "Couldn't load the native certificate store: {}",
48      certs_result.errors[0]
49    )))?
50  }
51  let certs = certs_result.certs;
52
53  for cert in certs {
54    match roots.add(cert) {
55      Ok(_) => (),
56      Err(err) => Err(anyhow::anyhow!(format!(
57        "Couldn't add a certificate to the certificate store: {}",
58        err
59      )))?,
60    }
61  }
62
63  let mut connections_vec = Vec::new();
64  for _ in 0..DEFAULT_CONCURRENT_CONNECTIONS_PER_HOST {
65    connections_vec.push(RwLock::new(HashMap::new()));
66  }
67  Ok(Box::new(ReverseProxyModule::new(
68    Arc::new(roots),
69    Arc::new(connections_vec),
70    Arc::new(RwLock::new(TtlCache::new(Duration::from_millis(
71      config["global"]["loadBalancerHealthCheckWindow"]
72        .as_i64()
73        .unwrap_or(5000) as u64,
74    )))),
75  )))
76}
77
78#[allow(clippy::type_complexity)]
79struct ReverseProxyModule {
80  roots: Arc<RootCertStore>,
81  connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>>>,
82  failed_backends: Arc<RwLock<TtlCache<String, u64>>>,
83}
84
85impl ReverseProxyModule {
86  #[allow(clippy::type_complexity)]
87  fn new(
88    roots: Arc<RootCertStore>,
89    connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>>>,
90    failed_backends: Arc<RwLock<TtlCache<String, u64>>>,
91  ) -> Self {
92    Self {
93      roots,
94      connections,
95      failed_backends,
96    }
97  }
98}
99
100impl ServerModule for ReverseProxyModule {
101  fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
102    Box::new(ReverseProxyModuleHandlers {
103      roots: self.roots.clone(),
104      connections: self.connections.clone(),
105      failed_backends: self.failed_backends.clone(),
106      handle,
107    })
108  }
109}
110
111#[allow(clippy::type_complexity)]
112struct ReverseProxyModuleHandlers {
113  handle: Handle,
114  roots: Arc<RootCertStore>,
115  connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>>>,
116  failed_backends: Arc<RwLock<TtlCache<String, u64>>>,
117}
118
119#[async_trait]
120impl ServerModuleHandlers for ReverseProxyModuleHandlers {
121  async fn request_handler(
122    &mut self,
123    request: RequestData,
124    config: &ServerConfig,
125    socket_data: &SocketData,
126    error_logger: &ErrorLogger,
127  ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
128    WithRuntime::new(self.handle.clone(), async move {
129      let enable_health_check = config["enableLoadBalancerHealthCheck"]
130        .as_bool()
131        .unwrap_or(false);
132      let health_check_max_fails = config["loadBalancerHealthCheckMaximumFails"]
133        .as_i64()
134        .unwrap_or(3) as u64;
135      let disable_certificate_verification = config["disableProxyCertificateVerification"]
136        .as_bool()
137        .unwrap_or(false);
138      let proxy_intercept_errors = config["proxyInterceptErrors"].as_bool().unwrap_or(false);
139      if let Some(proxy_to) = determine_proxy_to(
140        config,
141        socket_data.encrypted,
142        &self.failed_backends,
143        enable_health_check,
144        health_check_max_fails,
145      )
146      .await
147      {
148        let (hyper_request, _, _, _) = request.into_parts();
149        let (mut hyper_request_parts, request_body) = hyper_request.into_parts();
150
151        let proxy_request_url = proxy_to.parse::<hyper::Uri>()?;
152        let scheme_str = proxy_request_url.scheme_str();
153        let mut encrypted = false;
154
155        match scheme_str {
156          Some("http") => {
157            encrypted = false;
158          }
159          Some("https") => {
160            encrypted = true;
161          }
162          _ => Err(anyhow::anyhow!(
163            "Only HTTP and HTTPS reverse proxy URLs are supported."
164          ))?,
165        };
166
167        let host = match proxy_request_url.host() {
168          Some(host) => host,
169          None => Err(anyhow::anyhow!(
170            "The reverse proxy URL doesn't include the host"
171          ))?,
172        };
173
174        let port = proxy_request_url.port_u16().unwrap_or(match scheme_str {
175          Some("http") => 80,
176          Some("https") => 443,
177          _ => 80,
178        });
179
180        let addr = format!("{host}:{port}");
181        let authority = proxy_request_url.authority().cloned();
182
183        let hyper_request_path = hyper_request_parts.uri.path();
184
185        let path = match hyper_request_path.as_bytes().first() {
186          Some(b'/') => {
187            let mut proxy_request_path = proxy_request_url.path();
188            while proxy_request_path.as_bytes().last().copied() == Some(b'/') {
189              proxy_request_path = &proxy_request_path[..(proxy_request_path.len() - 1)];
190            }
191            format!("{proxy_request_path}{hyper_request_path}")
192          }
193          _ => hyper_request_path.to_string(),
194        };
195
196        hyper_request_parts.uri = Uri::from_str(&format!(
197          "{}{}",
198          path,
199          match hyper_request_parts.uri.query() {
200            Some(query) => format!("?{query}"),
201            None => "".to_string(),
202          }
203        ))?;
204
205        let original_host = hyper_request_parts.headers.get(header::HOST).cloned();
206
207        // Host header for host identification
208        match authority {
209          Some(authority) => {
210            hyper_request_parts
211              .headers
212              .insert(header::HOST, authority.to_string().parse()?);
213          }
214          None => {
215            hyper_request_parts.headers.remove(header::HOST);
216          }
217        }
218
219        // Connection header to enable HTTP/1.1 keep-alive
220        hyper_request_parts
221          .headers
222          .insert(header::CONNECTION, "keep-alive".parse()?);
223
224        // X-Forwarded-* headers to send the client's data to a server that's behind the reverse proxy
225        if config["disableProxyXForwarded"].as_bool().unwrap_or(false) {
226          hyper_request_parts.headers.remove("x-forwarder-for");
227          hyper_request_parts.headers.remove("x-forwarded-proto");
228          hyper_request_parts.headers.remove("x-forwarded-host");
229        } else {
230          hyper_request_parts.headers.insert(
231            "x-forwarded-for",
232            socket_data
233              .remote_addr
234              .ip()
235              .to_canonical()
236              .to_string()
237              .parse()?,
238          );
239
240          if socket_data.encrypted {
241            hyper_request_parts
242              .headers
243              .insert("x-forwarded-proto", "https".parse()?);
244          } else {
245            hyper_request_parts
246              .headers
247              .insert("x-forwarded-proto", "http".parse()?);
248          }
249
250          if let Some(original_host) = original_host {
251            hyper_request_parts
252              .headers
253              .insert("x-forwarded-host", original_host);
254          }
255        }
256
257        hyper_request_parts.version = Version::HTTP_11;
258
259        let proxy_request = Request::from_parts(hyper_request_parts, request_body);
260
261        let connections = &self.connections[rand::random_range(..self.connections.len())];
262
263        let rwlock_read = connections.read().await;
264        let sender_read_option = rwlock_read.get(&addr);
265
266        if let Some(sender_read) = sender_read_option {
267          if !sender_read.is_closed() {
268            drop(rwlock_read);
269            let mut rwlock_write = connections.write().await;
270            let sender_option = rwlock_write.get_mut(&addr);
271
272            if let Some(sender) = sender_option {
273              if !sender.is_closed() && sender.ready().await.is_ok() {
274                let result = http_proxy_kept_alive(
275                  sender,
276                  proxy_request,
277                  error_logger,
278                  proxy_intercept_errors,
279                )
280                .await;
281                drop(rwlock_write);
282                return result;
283              } else {
284                drop(rwlock_write);
285              }
286            } else {
287              drop(rwlock_write);
288            }
289          } else {
290            drop(rwlock_read);
291          }
292        } else {
293          drop(rwlock_read);
294        }
295
296        let stream = match TcpStream::connect(&addr).await {
297          Ok(stream) => stream,
298          Err(err) => {
299            if enable_health_check {
300              let mut failed_backends_write = self.failed_backends.write().await;
301              let proxy_to = proxy_to.clone();
302              let failed_attempts = failed_backends_write.get(&proxy_to);
303              failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
304            }
305            match err.kind() {
306              tokio::io::ErrorKind::ConnectionRefused
307              | tokio::io::ErrorKind::NotFound
308              | tokio::io::ErrorKind::HostUnreachable => {
309                error_logger
310                  .log(&format!("Service unavailable: {err}"))
311                  .await;
312                return Ok(
313                  ResponseData::builder_without_request()
314                    .status(StatusCode::SERVICE_UNAVAILABLE)
315                    .build(),
316                );
317              }
318              tokio::io::ErrorKind::TimedOut => {
319                error_logger.log(&format!("Gateway timeout: {err}")).await;
320                return Ok(
321                  ResponseData::builder_without_request()
322                    .status(StatusCode::GATEWAY_TIMEOUT)
323                    .build(),
324                );
325              }
326              _ => {
327                error_logger.log(&format!("Bad gateway: {err}")).await;
328                return Ok(
329                  ResponseData::builder_without_request()
330                    .status(StatusCode::BAD_GATEWAY)
331                    .build(),
332                );
333              }
334            };
335          }
336        };
337
338        match stream.set_nodelay(true) {
339          Ok(_) => (),
340          Err(err) => {
341            if enable_health_check {
342              let mut failed_backends_write = self.failed_backends.write().await;
343              let proxy_to = proxy_to.clone();
344              let failed_attempts = failed_backends_write.get(&proxy_to);
345              failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
346            }
347            error_logger.log(&format!("Bad gateway: {err}")).await;
348            return Ok(
349              ResponseData::builder_without_request()
350                .status(StatusCode::BAD_GATEWAY)
351                .build(),
352            );
353          }
354        };
355
356        let failed_backends_option_borrowed = if enable_health_check {
357          Some(&*self.failed_backends)
358        } else {
359          None
360        };
361
362        if !encrypted {
363          http_proxy(
364            connections,
365            addr,
366            stream,
367            proxy_request,
368            error_logger,
369            proxy_to,
370            failed_backends_option_borrowed,
371            proxy_intercept_errors,
372          )
373          .await
374        } else {
375          let tls_client_config = (if disable_certificate_verification {
376            rustls::ClientConfig::builder()
377              .dangerous()
378              .with_custom_certificate_verifier(Arc::new(NoServerVerifier::new()))
379          } else {
380            rustls::ClientConfig::builder().with_root_certificates(self.roots.clone())
381          })
382          .with_no_client_auth();
383          let connector = TlsConnector::from(Arc::new(tls_client_config));
384          let domain = ServerName::try_from(host)?.to_owned();
385
386          let tls_stream = match connector.connect(domain, stream).await {
387            Ok(stream) => stream,
388            Err(err) => {
389              if enable_health_check {
390                let mut failed_backends_write = self.failed_backends.write().await;
391                let proxy_to = proxy_to.clone();
392                let failed_attempts = failed_backends_write.get(&proxy_to);
393                failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
394              }
395              error_logger.log(&format!("Bad gateway: {err}")).await;
396              return Ok(
397                ResponseData::builder_without_request()
398                  .status(StatusCode::BAD_GATEWAY)
399                  .build(),
400              );
401            }
402          };
403
404          http_proxy(
405            connections,
406            addr,
407            tls_stream,
408            proxy_request,
409            error_logger,
410            proxy_to,
411            failed_backends_option_borrowed,
412            proxy_intercept_errors,
413          )
414          .await
415        }
416      } else {
417        Ok(ResponseData::builder(request).build())
418      }
419    })
420    .await
421  }
422
423  async fn proxy_request_handler(
424    &mut self,
425    request: RequestData,
426    _config: &ServerConfig,
427    _socket_data: &SocketData,
428    _error_logger: &ErrorLogger,
429  ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
430    Ok(ResponseData::builder(request).build())
431  }
432
433  async fn response_modifying_handler(
434    &mut self,
435    response: HyperResponse,
436  ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
437    Ok(response)
438  }
439
440  async fn proxy_response_modifying_handler(
441    &mut self,
442    response: HyperResponse,
443  ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
444    Ok(response)
445  }
446
447  async fn connect_proxy_request_handler(
448    &mut self,
449    _upgraded_request: HyperUpgraded,
450    _connect_address: &str,
451    _config: &ServerConfig,
452    _socket_data: &SocketData,
453    _error_logger: &ErrorLogger,
454  ) -> Result<(), Box<dyn Error + Send + Sync>> {
455    Ok(())
456  }
457
458  fn does_connect_proxy_requests(&mut self) -> bool {
459    false
460  }
461
462  async fn websocket_request_handler(
463    &mut self,
464    websocket: HyperWebsocket,
465    uri: &hyper::Uri,
466    headers: &hyper::HeaderMap,
467    config: &ServerConfig,
468    socket_data: &SocketData,
469    error_logger: &ErrorLogger,
470  ) -> Result<(), Box<dyn Error + Send + Sync>> {
471    WithRuntime::new(self.handle.clone(), async move {
472      let enable_health_check = config["enableLoadBalancerHealthCheck"]
473        .as_bool()
474        .unwrap_or(false);
475      let health_check_max_fails = config["loadBalancerHealthCheckMaximumFails"]
476        .as_i64()
477        .unwrap_or(3) as u64;
478
479      let disable_certificate_verification = config["disableProxyCertificateVerification"]
480        .as_bool()
481        .unwrap_or(false);
482      if let Some(proxy_to) = determine_proxy_to(
483        config,
484        socket_data.encrypted,
485        &self.failed_backends,
486        enable_health_check,
487        health_check_max_fails,
488      )
489      .await
490      {
491        let proxy_request_url = proxy_to.parse::<hyper::Uri>()?;
492        let scheme_str = proxy_request_url.scheme_str();
493        let mut encrypted = false;
494
495        match scheme_str {
496          Some("http") => {
497            encrypted = false;
498          }
499          Some("https") => {
500            encrypted = true;
501          }
502          _ => Err(anyhow::anyhow!(
503            "Only HTTP and HTTPS reverse proxy URLs are supported."
504          ))?,
505        };
506
507        let request_path = uri.path();
508
509        let path = match request_path.as_bytes().first() {
510          Some(b'/') => {
511            let mut proxy_request_path = proxy_request_url.path();
512            while proxy_request_path.as_bytes().last().copied() == Some(b'/') {
513              proxy_request_path = &proxy_request_path[..(proxy_request_path.len() - 1)];
514            }
515            format!("{proxy_request_path}{request_path}")
516          }
517          _ => request_path.to_string(),
518        };
519
520        let mut proxy_request_url_parts = proxy_request_url.into_parts();
521        proxy_request_url_parts.scheme = if encrypted {
522          Some(Scheme::from_str("wss")?)
523        } else {
524          Some(Scheme::from_str("ws")?)
525        };
526        match uri.path_and_query() {
527          Some(path_and_query) => {
528            let path_and_query_string = match path_and_query.query() {
529              Some(query) => {
530                format!("{path}?{query}")
531              }
532              None => path,
533            };
534            proxy_request_url_parts.path_and_query =
535              Some(PathAndQuery::from_str(&path_and_query_string)?);
536          }
537          None => {
538            proxy_request_url_parts.path_and_query = Some(PathAndQuery::from_str(&path)?);
539          }
540        };
541
542        let proxy_request_url = hyper::Uri::from_parts(proxy_request_url_parts)?;
543
544        let connector = if !encrypted {
545          Connector::Plain
546        } else {
547          Connector::Rustls(Arc::new(
548            (if disable_certificate_verification {
549              rustls::ClientConfig::builder()
550                .dangerous()
551                .with_custom_certificate_verifier(Arc::new(NoServerVerifier::new()))
552            } else {
553              rustls::ClientConfig::builder().with_root_certificates(self.roots.clone())
554            })
555            .with_no_client_auth(),
556          ))
557        };
558
559        let mut proxy_request_builder = ClientRequestBuilder::new(proxy_request_url);
560        for (header_name, header_value) in headers {
561          let header_name_str = header_name.as_str();
562          if header_name == SEC_WEBSOCKET_PROTOCOL {
563            for subprotocol in String::from_utf8_lossy(header_value.as_bytes()).split(",") {
564              proxy_request_builder = proxy_request_builder.with_sub_protocol(subprotocol.trim());
565            }
566          } else if !header_name_str.starts_with("sec-websocket-")
567            && header_name_str != "x-forwarded-for"
568          {
569            proxy_request_builder = proxy_request_builder.with_header(
570              header_name_str,
571              String::from_utf8_lossy(header_value.as_bytes()),
572            );
573          }
574        }
575
576        // Add X-Forwarded-For header
577        proxy_request_builder = proxy_request_builder.with_header(
578          "x-forwarded-for",
579          socket_data.remote_addr.ip().to_canonical().to_string(),
580        );
581
582        let proxy_request_constructed = proxy_request_builder.into_client_request()?;
583
584        let client_bi_stream = websocket.await?;
585
586        let (proxy_bi_stream, _) = match tokio_tungstenite::connect_async_tls_with_config(
587          proxy_request_constructed,
588          None,
589          true,
590          Some(connector),
591        )
592        .await
593        {
594          Ok(data) => data,
595          Err(err) => {
596            error_logger
597              .log(&format!("Cannot connect to WebSocket server: {err}"))
598              .await;
599            return Ok(());
600          }
601        };
602
603        let (mut client_sink, mut client_stream) = client_bi_stream.split();
604        let (mut proxy_sink, mut proxy_stream) = proxy_bi_stream.split();
605
606        let client_to_proxy = async {
607          while let Some(Ok(value)) = client_stream.next().await {
608            if proxy_sink.send(value).await.is_err() {
609              break;
610            }
611          }
612        };
613
614        let proxy_to_client = async {
615          while let Some(Ok(value)) = proxy_stream.next().await {
616            if client_sink.send(value).await.is_err() {
617              break;
618            }
619          }
620        };
621
622        tokio::pin!(client_to_proxy);
623        tokio::pin!(proxy_to_client);
624
625        let client_to_proxy_first;
626        tokio::select! {
627          _ = &mut client_to_proxy => {
628            client_to_proxy_first = true;
629          }
630          _ = &mut proxy_to_client => {
631            client_to_proxy_first = false;
632          }
633        }
634
635        if client_to_proxy_first {
636          proxy_to_client.await;
637        } else {
638          client_to_proxy.await;
639        }
640      }
641
642      Ok(())
643    })
644    .await
645  }
646
647  fn does_websocket_requests(&mut self, config: &ServerConfig, socket_data: &SocketData) -> bool {
648    if socket_data.encrypted {
649      let secure_proxy_to = &config["secureProxyTo"];
650      if secure_proxy_to.as_vec().is_some() || secure_proxy_to.as_str().is_some() {
651        return true;
652      }
653    }
654
655    let proxy_to = &config["proxyTo"];
656    proxy_to.as_vec().is_some() || proxy_to.as_str().is_some()
657  }
658}
659
660async fn determine_proxy_to(
661  config: &ServerConfig,
662  encrypted: bool,
663  failed_backends: &RwLock<TtlCache<String, u64>>,
664  enable_health_check: bool,
665  health_check_max_fails: u64,
666) -> Option<String> {
667  let mut proxy_to = None;
668  // When the array is supplied with non-string values, the reverse proxy may have undesirable behavior
669  // The "proxyTo" and "secureProxyTo" are validated though.
670
671  if encrypted {
672    let secure_proxy_to_yaml = &config["secureProxyTo"];
673    if let Some(secure_proxy_to_vector) = secure_proxy_to_yaml.as_vec() {
674      if enable_health_check {
675        let mut secure_proxy_to_vector = secure_proxy_to_vector.clone();
676        loop {
677          if !secure_proxy_to_vector.is_empty() {
678            let index = rand::random_range(..secure_proxy_to_vector.len());
679            if let Some(secure_proxy_to) = secure_proxy_to_vector[index].as_str() {
680              proxy_to = Some(secure_proxy_to.to_string());
681              let failed_backends_read = failed_backends.read().await;
682              let failed_backend_fails =
683                match failed_backends_read.get(&secure_proxy_to.to_string()) {
684                  Some(fails) => fails,
685                  None => break,
686                };
687              if failed_backend_fails > health_check_max_fails {
688                secure_proxy_to_vector.remove(index);
689              } else {
690                break;
691              }
692            }
693          } else {
694            break;
695          }
696        }
697      } else if !secure_proxy_to_vector.is_empty() {
698        if let Some(secure_proxy_to) =
699          secure_proxy_to_vector[rand::random_range(..secure_proxy_to_vector.len())].as_str()
700        {
701          proxy_to = Some(secure_proxy_to.to_string());
702        }
703      }
704    } else if let Some(secure_proxy_to) = secure_proxy_to_yaml.as_str() {
705      proxy_to = Some(secure_proxy_to.to_string());
706    }
707  }
708
709  if proxy_to.is_none() {
710    let proxy_to_yaml = &config["proxyTo"];
711    if let Some(proxy_to_vector) = proxy_to_yaml.as_vec() {
712      if enable_health_check {
713        let mut proxy_to_vector = proxy_to_vector.clone();
714        loop {
715          if !proxy_to_vector.is_empty() {
716            let index = rand::random_range(..proxy_to_vector.len());
717            if let Some(proxy_to_str) = proxy_to_vector[index].as_str() {
718              proxy_to = Some(proxy_to_str.to_string());
719              let failed_backends_read = failed_backends.read().await;
720              let failed_backend_fails = match failed_backends_read.get(&proxy_to_str.to_string()) {
721                Some(fails) => fails,
722                None => break,
723              };
724              if failed_backend_fails > health_check_max_fails {
725                proxy_to_vector.remove(index);
726              } else {
727                break;
728              }
729            }
730          } else {
731            break;
732          }
733        }
734      } else if !proxy_to_vector.is_empty() {
735        if let Some(proxy_to_str) =
736          proxy_to_vector[rand::random_range(..proxy_to_vector.len())].as_str()
737        {
738          proxy_to = Some(proxy_to_str.to_string());
739        }
740      }
741    } else if let Some(proxy_to_str) = proxy_to_yaml.as_str() {
742      proxy_to = Some(proxy_to_str.to_string());
743    }
744  }
745
746  proxy_to
747}
748
749#[allow(clippy::too_many_arguments)]
750async fn http_proxy(
751  connections: &RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>,
752  connect_addr: String,
753  stream: impl AsyncRead + AsyncWrite + Send + Unpin + 'static,
754  proxy_request: Request<BoxBody<Bytes, std::io::Error>>,
755  error_logger: &ErrorLogger,
756  proxy_to: String,
757  failed_backends: Option<&tokio::sync::RwLock<TtlCache<std::string::String, u64>>>,
758  proxy_intercept_errors: bool,
759) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
760  let io = TokioIo::new(stream);
761
762  let (mut sender, conn) = match hyper::client::conn::http1::handshake(io).await {
763    Ok(data) => data,
764    Err(err) => {
765      if let Some(failed_backends) = failed_backends {
766        let mut failed_backends_write = failed_backends.write().await;
767        let failed_attempts = failed_backends_write.get(&proxy_to);
768        failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
769      }
770      error_logger.log(&format!("Bad gateway: {err}")).await;
771      return Ok(
772        ResponseData::builder_without_request()
773          .status(StatusCode::BAD_GATEWAY)
774          .build(),
775      );
776    }
777  };
778
779  let send_request = sender.send_request(proxy_request);
780
781  let mut pinned_conn = Box::pin(conn);
782  tokio::pin!(send_request);
783
784  let response;
785
786  loop {
787    tokio::select! {
788      biased;
789
790      proxy_response = &mut send_request => {
791        let proxy_response = match proxy_response {
792          Ok(response) => response,
793          Err(err) => {
794            error_logger.log(&format!("Bad gateway: {err}")).await;
795            return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build());
796          }
797        };
798
799        let status_code = proxy_response.status();
800        response = if proxy_intercept_errors && status_code.as_u16() >= 400 {
801          ResponseData::builder_without_request()
802          .status(status_code)
803          .parallel_fn(async move {
804            pinned_conn.await.unwrap_or_default();
805          })
806          .build()
807        } else {
808          ResponseData::builder_without_request()
809          .response(proxy_response.map(|b| {
810            b.map_err(|e| std::io::Error::other(e.to_string()))
811              .boxed()
812          }))
813          .parallel_fn(async move {
814            pinned_conn.await.unwrap_or_default();
815          })
816          .build()
817        };
818
819        break;
820      },
821      state = &mut pinned_conn => {
822        if state.is_err() {
823          error_logger.log("Bad gateway: incomplete response").await;
824          return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build());
825        }
826      },
827    };
828  }
829
830  if !sender.is_closed() {
831    let mut rwlock_write = connections.write().await;
832    rwlock_write.insert(connect_addr, sender);
833    drop(rwlock_write);
834  }
835
836  Ok(response)
837}
838
839async fn http_proxy_kept_alive(
840  sender: &mut SendRequest<BoxBody<Bytes, std::io::Error>>,
841  proxy_request: Request<BoxBody<Bytes, std::io::Error>>,
842  error_logger: &ErrorLogger,
843  proxy_intercept_errors: bool,
844) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
845  let proxy_response = match sender.send_request(proxy_request).await {
846    Ok(response) => response,
847    Err(err) => {
848      error_logger.log(&format!("Bad gateway: {err}")).await;
849      return Ok(
850        ResponseData::builder_without_request()
851          .status(StatusCode::BAD_GATEWAY)
852          .build(),
853      );
854    }
855  };
856
857  let status_code = proxy_response.status();
858  let response = if proxy_intercept_errors && status_code.as_u16() >= 400 {
859    ResponseData::builder_without_request()
860      .status(status_code)
861      .build()
862  } else {
863    ResponseData::builder_without_request()
864      .response(proxy_response.map(|b| b.map_err(|e| std::io::Error::other(e.to_string())).boxed()))
865      .build()
866  };
867
868  Ok(response)
869}