1use std::collections::HashMap;
2use std::error::Error;
3use std::str::FromStr;
4use std::sync::Arc;
5use std::time::Duration;
6
7use crate::ferron_common::{
8 ErrorLogger, HyperUpgraded, RequestData, ResponseData, ServerConfig, ServerModule,
9 ServerModuleHandlers, SocketData,
10};
11use crate::ferron_common::{HyperResponse, WithRuntime};
12use async_trait::async_trait;
13use futures_util::{SinkExt, StreamExt};
14use http::uri::{PathAndQuery, Scheme};
15use http_body_util::combinators::BoxBody;
16use http_body_util::BodyExt;
17use hyper::body::Bytes;
18use hyper::client::conn::http1::SendRequest;
19use hyper::header::{self, HeaderName, SEC_WEBSOCKET_PROTOCOL};
20use hyper::{Request, StatusCode, Uri, Version};
21use hyper_tungstenite::HyperWebsocket;
22use hyper_util::rt::TokioIo;
23use rustls::pki_types::ServerName;
24use rustls::RootCertStore;
25use rustls_native_certs::load_native_certs;
26use tokio::io::{AsyncRead, AsyncWrite};
27use tokio::net::TcpStream;
28use tokio::runtime::Handle;
29use tokio::sync::RwLock;
30use tokio_rustls::TlsConnector;
31use tokio_tungstenite::tungstenite::client::IntoClientRequest;
32use tokio_tungstenite::tungstenite::ClientRequestBuilder;
33use tokio_tungstenite::Connector;
34
35use crate::ferron_util::no_server_verifier::NoServerVerifier;
36use crate::ferron_util::ttl_cache::TtlCache;
37
38const DEFAULT_CONCURRENT_CONNECTIONS_PER_HOST: u32 = 32;
39
40pub fn server_module_init(
41 config: &ServerConfig,
42) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
43 let mut roots: RootCertStore = RootCertStore::empty();
44 let certs_result = load_native_certs();
45 if !certs_result.errors.is_empty() {
46 Err(anyhow::anyhow!(
47 "Couldn't load the native certificate store: {}",
48 certs_result.errors[0]
49 ))?
50 }
51 let certs = certs_result.certs;
52
53 for cert in certs {
54 match roots.add(cert) {
55 Ok(_) => (),
56 Err(err) => Err(anyhow::anyhow!(
57 "Couldn't add a certificate to the certificate store: {}",
58 err
59 ))?,
60 }
61 }
62
63 let mut connections_vec = Vec::new();
64 for _ in 0..DEFAULT_CONCURRENT_CONNECTIONS_PER_HOST {
65 connections_vec.push(RwLock::new(HashMap::new()));
66 }
67 Ok(Box::new(ReverseProxyModule::new(
68 Arc::new(roots),
69 Arc::new(connections_vec),
70 Arc::new(RwLock::new(TtlCache::new(Duration::from_millis(
71 config["global"]["loadBalancerHealthCheckWindow"]
72 .as_i64()
73 .unwrap_or(5000) as u64,
74 )))),
75 )))
76}
77
78#[allow(clippy::type_complexity)]
79struct ReverseProxyModule {
80 roots: Arc<RootCertStore>,
81 connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>>>,
82 failed_backends: Arc<RwLock<TtlCache<String, u64>>>,
83}
84
85impl ReverseProxyModule {
86 #[allow(clippy::type_complexity)]
87 fn new(
88 roots: Arc<RootCertStore>,
89 connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>>>,
90 failed_backends: Arc<RwLock<TtlCache<String, u64>>>,
91 ) -> Self {
92 Self {
93 roots,
94 connections,
95 failed_backends,
96 }
97 }
98}
99
100impl ServerModule for ReverseProxyModule {
101 fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
102 Box::new(ReverseProxyModuleHandlers {
103 roots: self.roots.clone(),
104 connections: self.connections.clone(),
105 failed_backends: self.failed_backends.clone(),
106 handle,
107 })
108 }
109}
110
111#[allow(clippy::type_complexity)]
112struct ReverseProxyModuleHandlers {
113 handle: Handle,
114 roots: Arc<RootCertStore>,
115 connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>>>,
116 failed_backends: Arc<RwLock<TtlCache<String, u64>>>,
117}
118
119#[async_trait]
120impl ServerModuleHandlers for ReverseProxyModuleHandlers {
121 async fn request_handler(
122 &mut self,
123 request: RequestData,
124 config: &ServerConfig,
125 socket_data: &SocketData,
126 error_logger: &ErrorLogger,
127 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
128 WithRuntime::new(self.handle.clone(), async move {
129 let enable_health_check = config["enableLoadBalancerHealthCheck"]
130 .as_bool()
131 .unwrap_or(false);
132 let health_check_max_fails = config["loadBalancerHealthCheckMaximumFails"]
133 .as_i64()
134 .unwrap_or(3) as u64;
135 let disable_certificate_verification = config["disableProxyCertificateVerification"]
136 .as_bool()
137 .unwrap_or(false);
138 let proxy_intercept_errors = config["proxyInterceptErrors"].as_bool().unwrap_or(false);
139 if let Some(proxy_to) = determine_proxy_to(
140 config,
141 socket_data.encrypted,
142 &self.failed_backends,
143 enable_health_check,
144 health_check_max_fails,
145 )
146 .await
147 {
148 let (hyper_request, _, _, _) = request.into_parts();
149 let (mut hyper_request_parts, request_body) = hyper_request.into_parts();
150
151 let proxy_request_url = proxy_to.parse::<hyper::Uri>()?;
152 let scheme_str = proxy_request_url.scheme_str();
153 let mut encrypted = false;
154
155 match scheme_str {
156 Some("http") => {
157 encrypted = false;
158 }
159 Some("https") => {
160 encrypted = true;
161 }
162 _ => Err(anyhow::anyhow!(
163 "Only HTTP and HTTPS reverse proxy URLs are supported."
164 ))?,
165 };
166
167 let host = match proxy_request_url.host() {
168 Some(host) => host,
169 None => Err(anyhow::anyhow!(
170 "The reverse proxy URL doesn't include the host"
171 ))?,
172 };
173
174 let port = proxy_request_url.port_u16().unwrap_or(match scheme_str {
175 Some("http") => 80,
176 Some("https") => 443,
177 _ => 80,
178 });
179
180 let addr = format!("{host}:{port}");
181 let authority = proxy_request_url.authority().cloned();
182
183 let hyper_request_path = hyper_request_parts.uri.path();
184
185 let path = match hyper_request_path.as_bytes().first() {
186 Some(b'/') => {
187 let mut proxy_request_path = proxy_request_url.path();
188 while proxy_request_path.as_bytes().last().copied() == Some(b'/') {
189 proxy_request_path = &proxy_request_path[..(proxy_request_path.len() - 1)];
190 }
191 format!("{proxy_request_path}{hyper_request_path}")
192 }
193 _ => hyper_request_path.to_string(),
194 };
195
196 hyper_request_parts.uri = Uri::from_str(&format!(
197 "{}{}",
198 path,
199 match hyper_request_parts.uri.query() {
200 Some(query) => format!("?{query}"),
201 None => "".to_string(),
202 }
203 ))?;
204
205 let original_host = hyper_request_parts.headers.get(header::HOST).cloned();
206
207 match authority {
209 Some(authority) => {
210 hyper_request_parts
211 .headers
212 .insert(header::HOST, authority.to_string().parse()?);
213 }
214 None => {
215 hyper_request_parts.headers.remove(header::HOST);
216 }
217 }
218
219 hyper_request_parts
221 .headers
222 .insert(header::CONNECTION, "keep-alive".parse()?);
223
224 if config["disableProxyXForwarded"].as_bool().unwrap_or(false) {
226 hyper_request_parts
227 .headers
228 .remove(HeaderName::from_static("x-forwarder-for"));
229 hyper_request_parts
230 .headers
231 .remove(HeaderName::from_static("x-forwarded-proto"));
232 hyper_request_parts
233 .headers
234 .remove(HeaderName::from_static("x-forwarded-host"));
235 } else {
236 hyper_request_parts.headers.insert(
237 HeaderName::from_static("x-forwarded-for"),
238 socket_data
239 .remote_addr
240 .ip()
241 .to_canonical()
242 .to_string()
243 .parse()?,
244 );
245
246 if socket_data.encrypted {
247 hyper_request_parts.headers.insert(
248 HeaderName::from_static("x-forwarded-proto"),
249 "https".parse()?,
250 );
251 } else {
252 hyper_request_parts.headers.insert(
253 HeaderName::from_static("x-forwarded-proto"),
254 "http".parse()?,
255 );
256 }
257
258 if let Some(original_host) = original_host {
259 hyper_request_parts
260 .headers
261 .insert("x-forwarded-host", original_host);
262 }
263 }
264
265 hyper_request_parts.version = Version::HTTP_11;
266
267 let proxy_request = Request::from_parts(hyper_request_parts, request_body);
268
269 let connections = &self.connections[rand::random_range(..self.connections.len())];
270
271 let rwlock_read = connections.read().await;
272 let sender_read_option = rwlock_read.get(&addr);
273
274 if let Some(sender_read) = sender_read_option {
275 if !sender_read.is_closed() {
276 drop(rwlock_read);
277 let mut rwlock_write = connections.write().await;
278 let sender_option = rwlock_write.get_mut(&addr);
279
280 if let Some(sender) = sender_option {
281 if !sender.is_closed() && sender.ready().await.is_ok() {
282 let result = http_proxy_kept_alive(
283 sender,
284 proxy_request,
285 error_logger,
286 proxy_intercept_errors,
287 )
288 .await;
289 drop(rwlock_write);
290 return result;
291 } else {
292 drop(rwlock_write);
293 }
294 } else {
295 drop(rwlock_write);
296 }
297 } else {
298 drop(rwlock_read);
299 }
300 } else {
301 drop(rwlock_read);
302 }
303
304 let stream = match TcpStream::connect(&addr).await {
305 Ok(stream) => stream,
306 Err(err) => {
307 if enable_health_check {
308 let mut failed_backends_write = self.failed_backends.write().await;
309 let proxy_to = proxy_to.clone();
310 let failed_attempts = failed_backends_write.get(&proxy_to);
311 failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
312 }
313 match err.kind() {
314 tokio::io::ErrorKind::ConnectionRefused
315 | tokio::io::ErrorKind::NotFound
316 | tokio::io::ErrorKind::HostUnreachable => {
317 error_logger
318 .log(&format!("Service unavailable: {err}"))
319 .await;
320 return Ok(
321 ResponseData::builder_without_request()
322 .status(StatusCode::SERVICE_UNAVAILABLE)
323 .build(),
324 );
325 }
326 tokio::io::ErrorKind::TimedOut => {
327 error_logger.log(&format!("Gateway timeout: {err}")).await;
328 return Ok(
329 ResponseData::builder_without_request()
330 .status(StatusCode::GATEWAY_TIMEOUT)
331 .build(),
332 );
333 }
334 _ => {
335 error_logger.log(&format!("Bad gateway: {err}")).await;
336 return Ok(
337 ResponseData::builder_without_request()
338 .status(StatusCode::BAD_GATEWAY)
339 .build(),
340 );
341 }
342 };
343 }
344 };
345
346 match stream.set_nodelay(true) {
347 Ok(_) => (),
348 Err(err) => {
349 if enable_health_check {
350 let mut failed_backends_write = self.failed_backends.write().await;
351 let proxy_to = proxy_to.clone();
352 let failed_attempts = failed_backends_write.get(&proxy_to);
353 failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
354 }
355 error_logger.log(&format!("Bad gateway: {err}")).await;
356 return Ok(
357 ResponseData::builder_without_request()
358 .status(StatusCode::BAD_GATEWAY)
359 .build(),
360 );
361 }
362 };
363
364 let failed_backends_option_borrowed = if enable_health_check {
365 Some(&*self.failed_backends)
366 } else {
367 None
368 };
369
370 if !encrypted {
371 http_proxy(
372 connections,
373 addr,
374 stream,
375 proxy_request,
376 error_logger,
377 proxy_to,
378 failed_backends_option_borrowed,
379 proxy_intercept_errors,
380 )
381 .await
382 } else {
383 let tls_client_config = (if disable_certificate_verification {
384 rustls::ClientConfig::builder()
385 .dangerous()
386 .with_custom_certificate_verifier(Arc::new(NoServerVerifier::new()))
387 } else {
388 rustls::ClientConfig::builder().with_root_certificates(self.roots.clone())
389 })
390 .with_no_client_auth();
391 let connector = TlsConnector::from(Arc::new(tls_client_config));
392 let domain = ServerName::try_from(host)?.to_owned();
393
394 let tls_stream = match connector.connect(domain, stream).await {
395 Ok(stream) => stream,
396 Err(err) => {
397 if enable_health_check {
398 let mut failed_backends_write = self.failed_backends.write().await;
399 let proxy_to = proxy_to.clone();
400 let failed_attempts = failed_backends_write.get(&proxy_to);
401 failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
402 }
403 error_logger.log(&format!("Bad gateway: {err}")).await;
404 return Ok(
405 ResponseData::builder_without_request()
406 .status(StatusCode::BAD_GATEWAY)
407 .build(),
408 );
409 }
410 };
411
412 http_proxy(
413 connections,
414 addr,
415 tls_stream,
416 proxy_request,
417 error_logger,
418 proxy_to,
419 failed_backends_option_borrowed,
420 proxy_intercept_errors,
421 )
422 .await
423 }
424 } else {
425 Ok(ResponseData::builder(request).build())
426 }
427 })
428 .await
429 }
430
431 async fn proxy_request_handler(
432 &mut self,
433 request: RequestData,
434 _config: &ServerConfig,
435 _socket_data: &SocketData,
436 _error_logger: &ErrorLogger,
437 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
438 Ok(ResponseData::builder(request).build())
439 }
440
441 async fn response_modifying_handler(
442 &mut self,
443 response: HyperResponse,
444 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
445 Ok(response)
446 }
447
448 async fn proxy_response_modifying_handler(
449 &mut self,
450 response: HyperResponse,
451 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
452 Ok(response)
453 }
454
455 async fn connect_proxy_request_handler(
456 &mut self,
457 _upgraded_request: HyperUpgraded,
458 _connect_address: &str,
459 _config: &ServerConfig,
460 _socket_data: &SocketData,
461 _error_logger: &ErrorLogger,
462 ) -> Result<(), Box<dyn Error + Send + Sync>> {
463 Ok(())
464 }
465
466 fn does_connect_proxy_requests(&mut self) -> bool {
467 false
468 }
469
470 async fn websocket_request_handler(
471 &mut self,
472 websocket: HyperWebsocket,
473 uri: &hyper::Uri,
474 headers: &hyper::HeaderMap,
475 config: &ServerConfig,
476 socket_data: &SocketData,
477 error_logger: &ErrorLogger,
478 ) -> Result<(), Box<dyn Error + Send + Sync>> {
479 WithRuntime::new(self.handle.clone(), async move {
480 let enable_health_check = config["enableLoadBalancerHealthCheck"]
481 .as_bool()
482 .unwrap_or(false);
483 let health_check_max_fails = config["loadBalancerHealthCheckMaximumFails"]
484 .as_i64()
485 .unwrap_or(3) as u64;
486
487 let disable_certificate_verification = config["disableProxyCertificateVerification"]
488 .as_bool()
489 .unwrap_or(false);
490 if let Some(proxy_to) = determine_proxy_to(
491 config,
492 socket_data.encrypted,
493 &self.failed_backends,
494 enable_health_check,
495 health_check_max_fails,
496 )
497 .await
498 {
499 let proxy_request_url = proxy_to.parse::<hyper::Uri>()?;
500 let scheme_str = proxy_request_url.scheme_str();
501 let mut encrypted = false;
502
503 match scheme_str {
504 Some("http") => {
505 encrypted = false;
506 }
507 Some("https") => {
508 encrypted = true;
509 }
510 _ => Err(anyhow::anyhow!(
511 "Only HTTP and HTTPS reverse proxy URLs are supported."
512 ))?,
513 };
514
515 let request_path = uri.path();
516
517 let path = match request_path.as_bytes().first() {
518 Some(b'/') => {
519 let mut proxy_request_path = proxy_request_url.path();
520 while proxy_request_path.as_bytes().last().copied() == Some(b'/') {
521 proxy_request_path = &proxy_request_path[..(proxy_request_path.len() - 1)];
522 }
523 format!("{proxy_request_path}{request_path}")
524 }
525 _ => request_path.to_string(),
526 };
527
528 let mut proxy_request_url_parts = proxy_request_url.into_parts();
529 proxy_request_url_parts.scheme = if encrypted {
530 Some(Scheme::from_str("wss")?)
531 } else {
532 Some(Scheme::from_str("ws")?)
533 };
534 match uri.path_and_query() {
535 Some(path_and_query) => {
536 let path_and_query_string = match path_and_query.query() {
537 Some(query) => {
538 format!("{path}?{query}")
539 }
540 None => path,
541 };
542 proxy_request_url_parts.path_and_query =
543 Some(PathAndQuery::from_str(&path_and_query_string)?);
544 }
545 None => {
546 proxy_request_url_parts.path_and_query = Some(PathAndQuery::from_str(&path)?);
547 }
548 };
549
550 let proxy_request_url = hyper::Uri::from_parts(proxy_request_url_parts)?;
551
552 let connector = if !encrypted {
553 Connector::Plain
554 } else {
555 Connector::Rustls(Arc::new(
556 (if disable_certificate_verification {
557 rustls::ClientConfig::builder()
558 .dangerous()
559 .with_custom_certificate_verifier(Arc::new(NoServerVerifier::new()))
560 } else {
561 rustls::ClientConfig::builder().with_root_certificates(self.roots.clone())
562 })
563 .with_no_client_auth(),
564 ))
565 };
566
567 let mut proxy_request_builder = ClientRequestBuilder::new(proxy_request_url);
568 for (header_name, header_value) in headers {
569 let header_name_str = header_name.as_str();
570 if header_name == SEC_WEBSOCKET_PROTOCOL {
571 for subprotocol in String::from_utf8_lossy(header_value.as_bytes()).split(",") {
572 proxy_request_builder = proxy_request_builder.with_sub_protocol(subprotocol.trim());
573 }
574 } else if !header_name_str.starts_with("sec-websocket-")
575 && header_name_str != HeaderName::from_static("x-forwarded-for")
576 {
577 proxy_request_builder = proxy_request_builder.with_header(
578 header_name_str,
579 String::from_utf8_lossy(header_value.as_bytes()),
580 );
581 }
582 }
583
584 proxy_request_builder = proxy_request_builder.with_header(
586 "x-forwarded-for",
587 socket_data.remote_addr.ip().to_canonical().to_string(),
588 );
589
590 let proxy_request_constructed = proxy_request_builder.into_client_request()?;
591
592 let client_bi_stream = websocket.await?;
593
594 let (proxy_bi_stream, _) = match tokio_tungstenite::connect_async_tls_with_config(
595 proxy_request_constructed,
596 None,
597 true,
598 Some(connector),
599 )
600 .await
601 {
602 Ok(data) => data,
603 Err(err) => {
604 error_logger
605 .log(&format!("Cannot connect to WebSocket server: {err}"))
606 .await;
607 return Ok(());
608 }
609 };
610
611 let (mut client_sink, mut client_stream) = client_bi_stream.split();
612 let (mut proxy_sink, mut proxy_stream) = proxy_bi_stream.split();
613
614 let client_to_proxy = async {
615 while let Some(Ok(value)) = client_stream.next().await {
616 if proxy_sink.send(value).await.is_err() {
617 break;
618 }
619 }
620 };
621
622 let proxy_to_client = async {
623 while let Some(Ok(value)) = proxy_stream.next().await {
624 if client_sink.send(value).await.is_err() {
625 break;
626 }
627 }
628 };
629
630 tokio::pin!(client_to_proxy);
631 tokio::pin!(proxy_to_client);
632
633 let client_to_proxy_first;
634 tokio::select! {
635 _ = &mut client_to_proxy => {
636 client_to_proxy_first = true;
637 }
638 _ = &mut proxy_to_client => {
639 client_to_proxy_first = false;
640 }
641 }
642
643 if client_to_proxy_first {
644 proxy_to_client.await;
645 } else {
646 client_to_proxy.await;
647 }
648 }
649
650 Ok(())
651 })
652 .await
653 }
654
655 fn does_websocket_requests(&mut self, config: &ServerConfig, socket_data: &SocketData) -> bool {
656 if socket_data.encrypted {
657 let secure_proxy_to = &config["secureProxyTo"];
658 if secure_proxy_to.as_vec().is_some() || secure_proxy_to.as_str().is_some() {
659 return true;
660 }
661 }
662
663 let proxy_to = &config["proxyTo"];
664 proxy_to.as_vec().is_some() || proxy_to.as_str().is_some()
665 }
666}
667
668async fn determine_proxy_to(
669 config: &ServerConfig,
670 encrypted: bool,
671 failed_backends: &RwLock<TtlCache<String, u64>>,
672 enable_health_check: bool,
673 health_check_max_fails: u64,
674) -> Option<String> {
675 let mut proxy_to = None;
676 if encrypted {
680 let secure_proxy_to_yaml = &config["secureProxyTo"];
681 if let Some(secure_proxy_to_vector) = secure_proxy_to_yaml.as_vec() {
682 if enable_health_check {
683 let mut secure_proxy_to_vector = secure_proxy_to_vector.clone();
684 loop {
685 if !secure_proxy_to_vector.is_empty() {
686 let index = rand::random_range(..secure_proxy_to_vector.len());
687 if let Some(secure_proxy_to) = secure_proxy_to_vector[index].as_str() {
688 proxy_to = Some(secure_proxy_to.to_string());
689 let failed_backends_read = failed_backends.read().await;
690 let failed_backend_fails =
691 match failed_backends_read.get(&secure_proxy_to.to_string()) {
692 Some(fails) => fails,
693 None => break,
694 };
695 if failed_backend_fails > health_check_max_fails {
696 secure_proxy_to_vector.remove(index);
697 } else {
698 break;
699 }
700 }
701 } else {
702 break;
703 }
704 }
705 } else if !secure_proxy_to_vector.is_empty() {
706 if let Some(secure_proxy_to) =
707 secure_proxy_to_vector[rand::random_range(..secure_proxy_to_vector.len())].as_str()
708 {
709 proxy_to = Some(secure_proxy_to.to_string());
710 }
711 }
712 } else if let Some(secure_proxy_to) = secure_proxy_to_yaml.as_str() {
713 proxy_to = Some(secure_proxy_to.to_string());
714 }
715 }
716
717 if proxy_to.is_none() {
718 let proxy_to_yaml = &config["proxyTo"];
719 if let Some(proxy_to_vector) = proxy_to_yaml.as_vec() {
720 if enable_health_check {
721 let mut proxy_to_vector = proxy_to_vector.clone();
722 loop {
723 if !proxy_to_vector.is_empty() {
724 let index = rand::random_range(..proxy_to_vector.len());
725 if let Some(proxy_to_str) = proxy_to_vector[index].as_str() {
726 proxy_to = Some(proxy_to_str.to_string());
727 let failed_backends_read = failed_backends.read().await;
728 let failed_backend_fails = match failed_backends_read.get(&proxy_to_str.to_string()) {
729 Some(fails) => fails,
730 None => break,
731 };
732 if failed_backend_fails > health_check_max_fails {
733 proxy_to_vector.remove(index);
734 } else {
735 break;
736 }
737 }
738 } else {
739 break;
740 }
741 }
742 } else if !proxy_to_vector.is_empty() {
743 if let Some(proxy_to_str) =
744 proxy_to_vector[rand::random_range(..proxy_to_vector.len())].as_str()
745 {
746 proxy_to = Some(proxy_to_str.to_string());
747 }
748 }
749 } else if let Some(proxy_to_str) = proxy_to_yaml.as_str() {
750 proxy_to = Some(proxy_to_str.to_string());
751 }
752 }
753
754 proxy_to
755}
756
757#[allow(clippy::too_many_arguments)]
758async fn http_proxy(
759 connections: &RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>,
760 connect_addr: String,
761 stream: impl AsyncRead + AsyncWrite + Send + Unpin + 'static,
762 proxy_request: Request<BoxBody<Bytes, std::io::Error>>,
763 error_logger: &ErrorLogger,
764 proxy_to: String,
765 failed_backends: Option<&tokio::sync::RwLock<TtlCache<std::string::String, u64>>>,
766 proxy_intercept_errors: bool,
767) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
768 let io = TokioIo::new(stream);
769
770 let (mut sender, conn) = match hyper::client::conn::http1::handshake(io).await {
771 Ok(data) => data,
772 Err(err) => {
773 if let Some(failed_backends) = failed_backends {
774 let mut failed_backends_write = failed_backends.write().await;
775 let failed_attempts = failed_backends_write.get(&proxy_to);
776 failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
777 }
778 error_logger.log(&format!("Bad gateway: {err}")).await;
779 return Ok(
780 ResponseData::builder_without_request()
781 .status(StatusCode::BAD_GATEWAY)
782 .build(),
783 );
784 }
785 };
786
787 let send_request = sender.send_request(proxy_request);
788
789 let mut pinned_conn = Box::pin(conn);
790 tokio::pin!(send_request);
791
792 let response;
793
794 loop {
795 tokio::select! {
796 biased;
797
798 proxy_response = &mut send_request => {
799 let proxy_response = match proxy_response {
800 Ok(response) => response,
801 Err(err) => {
802 error_logger.log(&format!("Bad gateway: {err}")).await;
803 return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build());
804 }
805 };
806
807 let status_code = proxy_response.status();
808 response = if proxy_intercept_errors && status_code.as_u16() >= 400 {
809 ResponseData::builder_without_request()
810 .status(status_code)
811 .parallel_fn(async move {
812 pinned_conn.await.unwrap_or_default();
813 })
814 .build()
815 } else {
816 ResponseData::builder_without_request()
817 .response(proxy_response.map(|b| {
818 b.map_err(|e| std::io::Error::other(e.to_string()))
819 .boxed()
820 }))
821 .parallel_fn(async move {
822 pinned_conn.await.unwrap_or_default();
823 })
824 .build()
825 };
826
827 break;
828 },
829 state = &mut pinned_conn => {
830 if state.is_err() {
831 error_logger.log("Bad gateway: incomplete response").await;
832 return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build());
833 }
834 },
835 };
836 }
837
838 if !sender.is_closed() {
839 let mut rwlock_write = connections.write().await;
840 rwlock_write.insert(connect_addr, sender);
841 drop(rwlock_write);
842 }
843
844 Ok(response)
845}
846
847async fn http_proxy_kept_alive(
848 sender: &mut SendRequest<BoxBody<Bytes, std::io::Error>>,
849 proxy_request: Request<BoxBody<Bytes, std::io::Error>>,
850 error_logger: &ErrorLogger,
851 proxy_intercept_errors: bool,
852) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
853 let proxy_response = match sender.send_request(proxy_request).await {
854 Ok(response) => response,
855 Err(err) => {
856 error_logger.log(&format!("Bad gateway: {err}")).await;
857 return Ok(
858 ResponseData::builder_without_request()
859 .status(StatusCode::BAD_GATEWAY)
860 .build(),
861 );
862 }
863 };
864
865 let status_code = proxy_response.status();
866 let response = if proxy_intercept_errors && status_code.as_u16() >= 400 {
867 ResponseData::builder_without_request()
868 .status(status_code)
869 .build()
870 } else {
871 ResponseData::builder_without_request()
872 .response(proxy_response.map(|b| b.map_err(|e| std::io::Error::other(e.to_string())).boxed()))
873 .build()
874 };
875
876 Ok(response)
877}