ferron/optional_modules/
rproxy.rs

1use std::collections::HashMap;
2use std::error::Error;
3use std::str::FromStr;
4use std::sync::Arc;
5use std::time::Duration;
6
7use crate::ferron_common::{
8  ErrorLogger, HyperUpgraded, RequestData, ResponseData, ServerConfig, ServerModule,
9  ServerModuleHandlers, SocketData,
10};
11use crate::ferron_common::{HyperResponse, WithRuntime};
12use async_trait::async_trait;
13use futures_util::{SinkExt, StreamExt};
14use http::uri::{PathAndQuery, Scheme};
15use http_body_util::combinators::BoxBody;
16use http_body_util::BodyExt;
17use hyper::body::Bytes;
18use hyper::client::conn::http1::SendRequest;
19use hyper::header::{self, HeaderName, SEC_WEBSOCKET_PROTOCOL};
20use hyper::{Request, StatusCode, Uri, Version};
21use hyper_tungstenite::HyperWebsocket;
22use hyper_util::rt::TokioIo;
23use rustls::pki_types::ServerName;
24use rustls::RootCertStore;
25use rustls_native_certs::load_native_certs;
26use tokio::io::{AsyncRead, AsyncWrite};
27use tokio::net::TcpStream;
28use tokio::runtime::Handle;
29use tokio::sync::RwLock;
30use tokio_rustls::TlsConnector;
31use tokio_tungstenite::tungstenite::client::IntoClientRequest;
32use tokio_tungstenite::tungstenite::ClientRequestBuilder;
33use tokio_tungstenite::Connector;
34
35use crate::ferron_util::no_server_verifier::NoServerVerifier;
36use crate::ferron_util::ttl_cache::TtlCache;
37
38const DEFAULT_CONCURRENT_CONNECTIONS_PER_HOST: u32 = 32;
39
40pub fn server_module_init(
41  config: &ServerConfig,
42) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
43  let mut roots: RootCertStore = RootCertStore::empty();
44  let certs_result = load_native_certs();
45  if !certs_result.errors.is_empty() {
46    Err(anyhow::anyhow!(
47      "Couldn't load the native certificate store: {}",
48      certs_result.errors[0]
49    ))?
50  }
51  let certs = certs_result.certs;
52
53  for cert in certs {
54    match roots.add(cert) {
55      Ok(_) => (),
56      Err(err) => Err(anyhow::anyhow!(
57        "Couldn't add a certificate to the certificate store: {}",
58        err
59      ))?,
60    }
61  }
62
63  let mut connections_vec = Vec::new();
64  for _ in 0..DEFAULT_CONCURRENT_CONNECTIONS_PER_HOST {
65    connections_vec.push(RwLock::new(HashMap::new()));
66  }
67  Ok(Box::new(ReverseProxyModule::new(
68    Arc::new(roots),
69    Arc::new(connections_vec),
70    Arc::new(RwLock::new(TtlCache::new(Duration::from_millis(
71      config["global"]["loadBalancerHealthCheckWindow"]
72        .as_i64()
73        .unwrap_or(5000) as u64,
74    )))),
75  )))
76}
77
/// Reverse proxy module state created once by `server_module_init` and shared (through `Arc`s)
/// with every handler object returned from `get_handlers`.
#[allow(clippy::type_complexity)]
struct ReverseProxyModule {
  /// Root certificates used to verify HTTPS backends.
  roots: Arc<RootCertStore>,
  /// Keep-alive connection maps; each map is keyed by the backend `host:port` address.
  connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>>>,
  /// Failure counters per backend URL, stored in a cache whose entries expire after a TTL.
  failed_backends: Arc<RwLock<TtlCache<String, u64>>>,
}
84
85impl ReverseProxyModule {
86  #[allow(clippy::type_complexity)]
87  fn new(
88    roots: Arc<RootCertStore>,
89    connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>>>,
90    failed_backends: Arc<RwLock<TtlCache<String, u64>>>,
91  ) -> Self {
92    Self {
93      roots,
94      connections,
95      failed_backends,
96    }
97  }
98}
99
100impl ServerModule for ReverseProxyModule {
101  fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
102    Box::new(ReverseProxyModuleHandlers {
103      roots: self.roots.clone(),
104      connections: self.connections.clone(),
105      failed_backends: self.failed_backends.clone(),
106      handle,
107    })
108  }
109}
110
/// Handler object of the reverse proxy module; holds clones of the module's shared state.
#[allow(clippy::type_complexity)]
struct ReverseProxyModuleHandlers {
  /// Tokio runtime handle that the request and WebSocket handler futures are wrapped with.
  handle: Handle,
  /// Root certificates used to verify HTTPS backends.
  roots: Arc<RootCertStore>,
  /// Keep-alive connection maps; each map is keyed by the backend `host:port` address.
  connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>>>,
  /// Failure counters per backend URL, stored in a cache whose entries expire after a TTL.
  failed_backends: Arc<RwLock<TtlCache<String, u64>>>,
}
118
119#[async_trait]
120impl ServerModuleHandlers for ReverseProxyModuleHandlers {
121  async fn request_handler(
122    &mut self,
123    request: RequestData,
124    config: &ServerConfig,
125    socket_data: &SocketData,
126    error_logger: &ErrorLogger,
127  ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
128    WithRuntime::new(self.handle.clone(), async move {
129      let enable_health_check = config["enableLoadBalancerHealthCheck"]
130        .as_bool()
131        .unwrap_or(false);
132      let health_check_max_fails = config["loadBalancerHealthCheckMaximumFails"]
133        .as_i64()
134        .unwrap_or(3) as u64;
135      let disable_certificate_verification = config["disableProxyCertificateVerification"]
136        .as_bool()
137        .unwrap_or(false);
138      let proxy_intercept_errors = config["proxyInterceptErrors"].as_bool().unwrap_or(false);
139      if let Some(proxy_to) = determine_proxy_to(
140        config,
141        socket_data.encrypted,
142        &self.failed_backends,
143        enable_health_check,
144        health_check_max_fails,
145      )
146      .await
147      {
148        let (hyper_request, _, _, _) = request.into_parts();
149        let (mut hyper_request_parts, request_body) = hyper_request.into_parts();
150
151        let proxy_request_url = proxy_to.parse::<hyper::Uri>()?;
152        let scheme_str = proxy_request_url.scheme_str();
153        let mut encrypted = false;
154
155        match scheme_str {
156          Some("http") => {
157            encrypted = false;
158          }
159          Some("https") => {
160            encrypted = true;
161          }
162          _ => Err(anyhow::anyhow!(
163            "Only HTTP and HTTPS reverse proxy URLs are supported."
164          ))?,
165        };
166
167        let host = match proxy_request_url.host() {
168          Some(host) => host,
169          None => Err(anyhow::anyhow!(
170            "The reverse proxy URL doesn't include the host"
171          ))?,
172        };
173
174        let port = proxy_request_url.port_u16().unwrap_or(match scheme_str {
175          Some("http") => 80,
176          Some("https") => 443,
177          _ => 80,
178        });
179
180        let addr = format!("{host}:{port}");
181        let authority = proxy_request_url.authority().cloned();
182
183        let hyper_request_path = hyper_request_parts.uri.path();
184
185        let path = match hyper_request_path.as_bytes().first() {
186          Some(b'/') => {
187            let mut proxy_request_path = proxy_request_url.path();
188            while proxy_request_path.as_bytes().last().copied() == Some(b'/') {
189              proxy_request_path = &proxy_request_path[..(proxy_request_path.len() - 1)];
190            }
191            format!("{proxy_request_path}{hyper_request_path}")
192          }
193          _ => hyper_request_path.to_string(),
194        };
195
196        hyper_request_parts.uri = Uri::from_str(&format!(
197          "{}{}",
198          path,
199          match hyper_request_parts.uri.query() {
200            Some(query) => format!("?{query}"),
201            None => "".to_string(),
202          }
203        ))?;
204
205        let original_host = hyper_request_parts.headers.get(header::HOST).cloned();
206
207        // Host header for host identification
208        match authority {
209          Some(authority) => {
210            hyper_request_parts
211              .headers
212              .insert(header::HOST, authority.to_string().parse()?);
213          }
214          None => {
215            hyper_request_parts.headers.remove(header::HOST);
216          }
217        }
218
219        // Connection header to enable HTTP/1.1 keep-alive
220        hyper_request_parts
221          .headers
222          .insert(header::CONNECTION, "keep-alive".parse()?);
223
224        // X-Forwarded-* headers to send the client's data to a server that's behind the reverse proxy
225        if config["disableProxyXForwarded"].as_bool().unwrap_or(false) {
226          hyper_request_parts
227            .headers
228            .remove(HeaderName::from_static("x-forwarder-for"));
229          hyper_request_parts
230            .headers
231            .remove(HeaderName::from_static("x-forwarded-proto"));
232          hyper_request_parts
233            .headers
234            .remove(HeaderName::from_static("x-forwarded-host"));
235        } else {
236          hyper_request_parts.headers.insert(
237            HeaderName::from_static("x-forwarded-for"),
238            socket_data
239              .remote_addr
240              .ip()
241              .to_canonical()
242              .to_string()
243              .parse()?,
244          );
245
246          if socket_data.encrypted {
247            hyper_request_parts.headers.insert(
248              HeaderName::from_static("x-forwarded-proto"),
249              "https".parse()?,
250            );
251          } else {
252            hyper_request_parts.headers.insert(
253              HeaderName::from_static("x-forwarded-proto"),
254              "http".parse()?,
255            );
256          }
257
258          if let Some(original_host) = original_host {
259            hyper_request_parts
260              .headers
261              .insert("x-forwarded-host", original_host);
262          }
263        }
264
265        hyper_request_parts.version = Version::HTTP_11;
266
267        let proxy_request = Request::from_parts(hyper_request_parts, request_body);
268
269        let connections = &self.connections[rand::random_range(..self.connections.len())];
270
271        let rwlock_read = connections.read().await;
272        let sender_read_option = rwlock_read.get(&addr);
273
274        if let Some(sender_read) = sender_read_option {
275          if !sender_read.is_closed() {
276            drop(rwlock_read);
277            let mut rwlock_write = connections.write().await;
278            let sender_option = rwlock_write.get_mut(&addr);
279
280            if let Some(sender) = sender_option {
281              if !sender.is_closed() && sender.ready().await.is_ok() {
282                let result = http_proxy_kept_alive(
283                  sender,
284                  proxy_request,
285                  error_logger,
286                  proxy_intercept_errors,
287                )
288                .await;
289                drop(rwlock_write);
290                return result;
291              } else {
292                drop(rwlock_write);
293              }
294            } else {
295              drop(rwlock_write);
296            }
297          } else {
298            drop(rwlock_read);
299          }
300        } else {
301          drop(rwlock_read);
302        }
303
304        let stream = match TcpStream::connect(&addr).await {
305          Ok(stream) => stream,
306          Err(err) => {
307            if enable_health_check {
308              let mut failed_backends_write = self.failed_backends.write().await;
309              let proxy_to = proxy_to.clone();
310              let failed_attempts = failed_backends_write.get(&proxy_to);
311              failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
312            }
313            match err.kind() {
314              tokio::io::ErrorKind::ConnectionRefused
315              | tokio::io::ErrorKind::NotFound
316              | tokio::io::ErrorKind::HostUnreachable => {
317                error_logger
318                  .log(&format!("Service unavailable: {err}"))
319                  .await;
320                return Ok(
321                  ResponseData::builder_without_request()
322                    .status(StatusCode::SERVICE_UNAVAILABLE)
323                    .build(),
324                );
325              }
326              tokio::io::ErrorKind::TimedOut => {
327                error_logger.log(&format!("Gateway timeout: {err}")).await;
328                return Ok(
329                  ResponseData::builder_without_request()
330                    .status(StatusCode::GATEWAY_TIMEOUT)
331                    .build(),
332                );
333              }
334              _ => {
335                error_logger.log(&format!("Bad gateway: {err}")).await;
336                return Ok(
337                  ResponseData::builder_without_request()
338                    .status(StatusCode::BAD_GATEWAY)
339                    .build(),
340                );
341              }
342            };
343          }
344        };
345
346        match stream.set_nodelay(true) {
347          Ok(_) => (),
348          Err(err) => {
349            if enable_health_check {
350              let mut failed_backends_write = self.failed_backends.write().await;
351              let proxy_to = proxy_to.clone();
352              let failed_attempts = failed_backends_write.get(&proxy_to);
353              failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
354            }
355            error_logger.log(&format!("Bad gateway: {err}")).await;
356            return Ok(
357              ResponseData::builder_without_request()
358                .status(StatusCode::BAD_GATEWAY)
359                .build(),
360            );
361          }
362        };
363
364        let failed_backends_option_borrowed = if enable_health_check {
365          Some(&*self.failed_backends)
366        } else {
367          None
368        };
369
370        if !encrypted {
371          http_proxy(
372            connections,
373            addr,
374            stream,
375            proxy_request,
376            error_logger,
377            proxy_to,
378            failed_backends_option_borrowed,
379            proxy_intercept_errors,
380          )
381          .await
382        } else {
383          let tls_client_config = (if disable_certificate_verification {
384            rustls::ClientConfig::builder()
385              .dangerous()
386              .with_custom_certificate_verifier(Arc::new(NoServerVerifier::new()))
387          } else {
388            rustls::ClientConfig::builder().with_root_certificates(self.roots.clone())
389          })
390          .with_no_client_auth();
391          let connector = TlsConnector::from(Arc::new(tls_client_config));
392          let domain = ServerName::try_from(host)?.to_owned();
393
394          let tls_stream = match connector.connect(domain, stream).await {
395            Ok(stream) => stream,
396            Err(err) => {
397              if enable_health_check {
398                let mut failed_backends_write = self.failed_backends.write().await;
399                let proxy_to = proxy_to.clone();
400                let failed_attempts = failed_backends_write.get(&proxy_to);
401                failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
402              }
403              error_logger.log(&format!("Bad gateway: {err}")).await;
404              return Ok(
405                ResponseData::builder_without_request()
406                  .status(StatusCode::BAD_GATEWAY)
407                  .build(),
408              );
409            }
410          };
411
412          http_proxy(
413            connections,
414            addr,
415            tls_stream,
416            proxy_request,
417            error_logger,
418            proxy_to,
419            failed_backends_option_borrowed,
420            proxy_intercept_errors,
421          )
422          .await
423        }
424      } else {
425        Ok(ResponseData::builder(request).build())
426      }
427    })
428    .await
429  }
430
431  async fn proxy_request_handler(
432    &mut self,
433    request: RequestData,
434    _config: &ServerConfig,
435    _socket_data: &SocketData,
436    _error_logger: &ErrorLogger,
437  ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
438    Ok(ResponseData::builder(request).build())
439  }
440
441  async fn response_modifying_handler(
442    &mut self,
443    response: HyperResponse,
444  ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
445    Ok(response)
446  }
447
448  async fn proxy_response_modifying_handler(
449    &mut self,
450    response: HyperResponse,
451  ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
452    Ok(response)
453  }
454
455  async fn connect_proxy_request_handler(
456    &mut self,
457    _upgraded_request: HyperUpgraded,
458    _connect_address: &str,
459    _config: &ServerConfig,
460    _socket_data: &SocketData,
461    _error_logger: &ErrorLogger,
462  ) -> Result<(), Box<dyn Error + Send + Sync>> {
463    Ok(())
464  }
465
466  fn does_connect_proxy_requests(&mut self) -> bool {
467    false
468  }
469
470  async fn websocket_request_handler(
471    &mut self,
472    websocket: HyperWebsocket,
473    uri: &hyper::Uri,
474    headers: &hyper::HeaderMap,
475    config: &ServerConfig,
476    socket_data: &SocketData,
477    error_logger: &ErrorLogger,
478  ) -> Result<(), Box<dyn Error + Send + Sync>> {
479    WithRuntime::new(self.handle.clone(), async move {
480      let enable_health_check = config["enableLoadBalancerHealthCheck"]
481        .as_bool()
482        .unwrap_or(false);
483      let health_check_max_fails = config["loadBalancerHealthCheckMaximumFails"]
484        .as_i64()
485        .unwrap_or(3) as u64;
486
487      let disable_certificate_verification = config["disableProxyCertificateVerification"]
488        .as_bool()
489        .unwrap_or(false);
490      if let Some(proxy_to) = determine_proxy_to(
491        config,
492        socket_data.encrypted,
493        &self.failed_backends,
494        enable_health_check,
495        health_check_max_fails,
496      )
497      .await
498      {
499        let proxy_request_url = proxy_to.parse::<hyper::Uri>()?;
500        let scheme_str = proxy_request_url.scheme_str();
501        let mut encrypted = false;
502
503        match scheme_str {
504          Some("http") => {
505            encrypted = false;
506          }
507          Some("https") => {
508            encrypted = true;
509          }
510          _ => Err(anyhow::anyhow!(
511            "Only HTTP and HTTPS reverse proxy URLs are supported."
512          ))?,
513        };
514
515        let request_path = uri.path();
516
517        let path = match request_path.as_bytes().first() {
518          Some(b'/') => {
519            let mut proxy_request_path = proxy_request_url.path();
520            while proxy_request_path.as_bytes().last().copied() == Some(b'/') {
521              proxy_request_path = &proxy_request_path[..(proxy_request_path.len() - 1)];
522            }
523            format!("{proxy_request_path}{request_path}")
524          }
525          _ => request_path.to_string(),
526        };
527
528        let mut proxy_request_url_parts = proxy_request_url.into_parts();
529        proxy_request_url_parts.scheme = if encrypted {
530          Some(Scheme::from_str("wss")?)
531        } else {
532          Some(Scheme::from_str("ws")?)
533        };
534        match uri.path_and_query() {
535          Some(path_and_query) => {
536            let path_and_query_string = match path_and_query.query() {
537              Some(query) => {
538                format!("{path}?{query}")
539              }
540              None => path,
541            };
542            proxy_request_url_parts.path_and_query =
543              Some(PathAndQuery::from_str(&path_and_query_string)?);
544          }
545          None => {
546            proxy_request_url_parts.path_and_query = Some(PathAndQuery::from_str(&path)?);
547          }
548        };
549
550        let proxy_request_url = hyper::Uri::from_parts(proxy_request_url_parts)?;
551
552        let connector = if !encrypted {
553          Connector::Plain
554        } else {
555          Connector::Rustls(Arc::new(
556            (if disable_certificate_verification {
557              rustls::ClientConfig::builder()
558                .dangerous()
559                .with_custom_certificate_verifier(Arc::new(NoServerVerifier::new()))
560            } else {
561              rustls::ClientConfig::builder().with_root_certificates(self.roots.clone())
562            })
563            .with_no_client_auth(),
564          ))
565        };
566
567        let mut proxy_request_builder = ClientRequestBuilder::new(proxy_request_url);
568        for (header_name, header_value) in headers {
569          let header_name_str = header_name.as_str();
570          if header_name == SEC_WEBSOCKET_PROTOCOL {
571            for subprotocol in String::from_utf8_lossy(header_value.as_bytes()).split(",") {
572              proxy_request_builder = proxy_request_builder.with_sub_protocol(subprotocol.trim());
573            }
574          } else if !header_name_str.starts_with("sec-websocket-")
575            && header_name_str != HeaderName::from_static("x-forwarded-for")
576          {
577            proxy_request_builder = proxy_request_builder.with_header(
578              header_name_str,
579              String::from_utf8_lossy(header_value.as_bytes()),
580            );
581          }
582        }
583
584        // Add X-Forwarded-For header
585        proxy_request_builder = proxy_request_builder.with_header(
586          "x-forwarded-for",
587          socket_data.remote_addr.ip().to_canonical().to_string(),
588        );
589
590        let proxy_request_constructed = proxy_request_builder.into_client_request()?;
591
592        let client_bi_stream = websocket.await?;
593
594        let (proxy_bi_stream, _) = match tokio_tungstenite::connect_async_tls_with_config(
595          proxy_request_constructed,
596          None,
597          true,
598          Some(connector),
599        )
600        .await
601        {
602          Ok(data) => data,
603          Err(err) => {
604            error_logger
605              .log(&format!("Cannot connect to WebSocket server: {err}"))
606              .await;
607            return Ok(());
608          }
609        };
610
611        let (mut client_sink, mut client_stream) = client_bi_stream.split();
612        let (mut proxy_sink, mut proxy_stream) = proxy_bi_stream.split();
613
614        let client_to_proxy = async {
615          while let Some(Ok(value)) = client_stream.next().await {
616            if proxy_sink.send(value).await.is_err() {
617              break;
618            }
619          }
620        };
621
622        let proxy_to_client = async {
623          while let Some(Ok(value)) = proxy_stream.next().await {
624            if client_sink.send(value).await.is_err() {
625              break;
626            }
627          }
628        };
629
630        tokio::pin!(client_to_proxy);
631        tokio::pin!(proxy_to_client);
632
633        let client_to_proxy_first;
634        tokio::select! {
635          _ = &mut client_to_proxy => {
636            client_to_proxy_first = true;
637          }
638          _ = &mut proxy_to_client => {
639            client_to_proxy_first = false;
640          }
641        }
642
643        if client_to_proxy_first {
644          proxy_to_client.await;
645        } else {
646          client_to_proxy.await;
647        }
648      }
649
650      Ok(())
651    })
652    .await
653  }
654
655  fn does_websocket_requests(&mut self, config: &ServerConfig, socket_data: &SocketData) -> bool {
656    if socket_data.encrypted {
657      let secure_proxy_to = &config["secureProxyTo"];
658      if secure_proxy_to.as_vec().is_some() || secure_proxy_to.as_str().is_some() {
659        return true;
660      }
661    }
662
663    let proxy_to = &config["proxyTo"];
664    proxy_to.as_vec().is_some() || proxy_to.as_str().is_some()
665  }
666}
667
668async fn determine_proxy_to(
669  config: &ServerConfig,
670  encrypted: bool,
671  failed_backends: &RwLock<TtlCache<String, u64>>,
672  enable_health_check: bool,
673  health_check_max_fails: u64,
674) -> Option<String> {
675  let mut proxy_to = None;
676  // When the array is supplied with non-string values, the reverse proxy may have undesirable behavior
677  // The "proxyTo" and "secureProxyTo" are validated though.
678
679  if encrypted {
680    let secure_proxy_to_yaml = &config["secureProxyTo"];
681    if let Some(secure_proxy_to_vector) = secure_proxy_to_yaml.as_vec() {
682      if enable_health_check {
683        let mut secure_proxy_to_vector = secure_proxy_to_vector.clone();
684        loop {
685          if !secure_proxy_to_vector.is_empty() {
686            let index = rand::random_range(..secure_proxy_to_vector.len());
687            if let Some(secure_proxy_to) = secure_proxy_to_vector[index].as_str() {
688              proxy_to = Some(secure_proxy_to.to_string());
689              let failed_backends_read = failed_backends.read().await;
690              let failed_backend_fails =
691                match failed_backends_read.get(&secure_proxy_to.to_string()) {
692                  Some(fails) => fails,
693                  None => break,
694                };
695              if failed_backend_fails > health_check_max_fails {
696                secure_proxy_to_vector.remove(index);
697              } else {
698                break;
699              }
700            }
701          } else {
702            break;
703          }
704        }
705      } else if !secure_proxy_to_vector.is_empty() {
706        if let Some(secure_proxy_to) =
707          secure_proxy_to_vector[rand::random_range(..secure_proxy_to_vector.len())].as_str()
708        {
709          proxy_to = Some(secure_proxy_to.to_string());
710        }
711      }
712    } else if let Some(secure_proxy_to) = secure_proxy_to_yaml.as_str() {
713      proxy_to = Some(secure_proxy_to.to_string());
714    }
715  }
716
717  if proxy_to.is_none() {
718    let proxy_to_yaml = &config["proxyTo"];
719    if let Some(proxy_to_vector) = proxy_to_yaml.as_vec() {
720      if enable_health_check {
721        let mut proxy_to_vector = proxy_to_vector.clone();
722        loop {
723          if !proxy_to_vector.is_empty() {
724            let index = rand::random_range(..proxy_to_vector.len());
725            if let Some(proxy_to_str) = proxy_to_vector[index].as_str() {
726              proxy_to = Some(proxy_to_str.to_string());
727              let failed_backends_read = failed_backends.read().await;
728              let failed_backend_fails = match failed_backends_read.get(&proxy_to_str.to_string()) {
729                Some(fails) => fails,
730                None => break,
731              };
732              if failed_backend_fails > health_check_max_fails {
733                proxy_to_vector.remove(index);
734              } else {
735                break;
736              }
737            }
738          } else {
739            break;
740          }
741        }
742      } else if !proxy_to_vector.is_empty() {
743        if let Some(proxy_to_str) =
744          proxy_to_vector[rand::random_range(..proxy_to_vector.len())].as_str()
745        {
746          proxy_to = Some(proxy_to_str.to_string());
747        }
748      }
749    } else if let Some(proxy_to_str) = proxy_to_yaml.as_str() {
750      proxy_to = Some(proxy_to_str.to_string());
751    }
752  }
753
754  proxy_to
755}
756
757#[allow(clippy::too_many_arguments)]
758async fn http_proxy(
759  connections: &RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>,
760  connect_addr: String,
761  stream: impl AsyncRead + AsyncWrite + Send + Unpin + 'static,
762  proxy_request: Request<BoxBody<Bytes, std::io::Error>>,
763  error_logger: &ErrorLogger,
764  proxy_to: String,
765  failed_backends: Option<&tokio::sync::RwLock<TtlCache<std::string::String, u64>>>,
766  proxy_intercept_errors: bool,
767) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
768  let io = TokioIo::new(stream);
769
770  let (mut sender, conn) = match hyper::client::conn::http1::handshake(io).await {
771    Ok(data) => data,
772    Err(err) => {
773      if let Some(failed_backends) = failed_backends {
774        let mut failed_backends_write = failed_backends.write().await;
775        let failed_attempts = failed_backends_write.get(&proxy_to);
776        failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
777      }
778      error_logger.log(&format!("Bad gateway: {err}")).await;
779      return Ok(
780        ResponseData::builder_without_request()
781          .status(StatusCode::BAD_GATEWAY)
782          .build(),
783      );
784    }
785  };
786
787  let send_request = sender.send_request(proxy_request);
788
789  let mut pinned_conn = Box::pin(conn);
790  tokio::pin!(send_request);
791
792  let response;
793
794  loop {
795    tokio::select! {
796      biased;
797
798      proxy_response = &mut send_request => {
799        let proxy_response = match proxy_response {
800          Ok(response) => response,
801          Err(err) => {
802            error_logger.log(&format!("Bad gateway: {err}")).await;
803            return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build());
804          }
805        };
806
807        let status_code = proxy_response.status();
808        response = if proxy_intercept_errors && status_code.as_u16() >= 400 {
809          ResponseData::builder_without_request()
810          .status(status_code)
811          .parallel_fn(async move {
812            pinned_conn.await.unwrap_or_default();
813          })
814          .build()
815        } else {
816          ResponseData::builder_without_request()
817          .response(proxy_response.map(|b| {
818            b.map_err(|e| std::io::Error::other(e.to_string()))
819              .boxed()
820          }))
821          .parallel_fn(async move {
822            pinned_conn.await.unwrap_or_default();
823          })
824          .build()
825        };
826
827        break;
828      },
829      state = &mut pinned_conn => {
830        if state.is_err() {
831          error_logger.log("Bad gateway: incomplete response").await;
832          return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build());
833        }
834      },
835    };
836  }
837
838  if !sender.is_closed() {
839    let mut rwlock_write = connections.write().await;
840    rwlock_write.insert(connect_addr, sender);
841    drop(rwlock_write);
842  }
843
844  Ok(response)
845}
846
847async fn http_proxy_kept_alive(
848  sender: &mut SendRequest<BoxBody<Bytes, std::io::Error>>,
849  proxy_request: Request<BoxBody<Bytes, std::io::Error>>,
850  error_logger: &ErrorLogger,
851  proxy_intercept_errors: bool,
852) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
853  let proxy_response = match sender.send_request(proxy_request).await {
854    Ok(response) => response,
855    Err(err) => {
856      error_logger.log(&format!("Bad gateway: {err}")).await;
857      return Ok(
858        ResponseData::builder_without_request()
859          .status(StatusCode::BAD_GATEWAY)
860          .build(),
861      );
862    }
863  };
864
865  let status_code = proxy_response.status();
866  let response = if proxy_intercept_errors && status_code.as_u16() >= 400 {
867    ResponseData::builder_without_request()
868      .status(status_code)
869      .build()
870  } else {
871    ResponseData::builder_without_request()
872      .response(proxy_response.map(|b| b.map_err(|e| std::io::Error::other(e.to_string())).boxed()))
873      .build()
874  };
875
876  Ok(response)
877}