1use std::collections::HashMap;
2use std::error::Error;
3use std::str::FromStr;
4use std::sync::Arc;
5use std::time::Duration;
6
7use crate::ferron_common::{
8 ErrorLogger, HyperUpgraded, RequestData, ResponseData, ServerConfig, ServerModule,
9 ServerModuleHandlers, SocketData,
10};
11use crate::ferron_common::{HyperResponse, WithRuntime};
12use async_trait::async_trait;
13use futures_util::{SinkExt, StreamExt};
14use http::header::SEC_WEBSOCKET_PROTOCOL;
15use http::uri::{PathAndQuery, Scheme};
16use http_body_util::combinators::BoxBody;
17use http_body_util::BodyExt;
18use hyper::body::Bytes;
19use hyper::client::conn::http1::SendRequest;
20use hyper::{header, Request, StatusCode, Uri, Version};
21use hyper_tungstenite::HyperWebsocket;
22use hyper_util::rt::TokioIo;
23use rustls::pki_types::ServerName;
24use rustls::RootCertStore;
25use rustls_native_certs::load_native_certs;
26use tokio::io::{AsyncRead, AsyncWrite};
27use tokio::net::TcpStream;
28use tokio::runtime::Handle;
29use tokio::sync::RwLock;
30use tokio_rustls::TlsConnector;
31use tokio_tungstenite::tungstenite::client::IntoClientRequest;
32use tokio_tungstenite::tungstenite::ClientRequestBuilder;
33use tokio_tungstenite::Connector;
34
35use crate::ferron_util::no_server_verifier::NoServerVerifier;
36use crate::ferron_util::ttl_cache::TtlCache;
37
38const DEFAULT_CONCURRENT_CONNECTIONS_PER_HOST: u32 = 32;
39
40pub fn server_module_init(
41 config: &ServerConfig,
42) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
43 let mut roots: RootCertStore = RootCertStore::empty();
44 let certs_result = load_native_certs();
45 if !certs_result.errors.is_empty() {
46 Err(anyhow::anyhow!(format!(
47 "Couldn't load the native certificate store: {}",
48 certs_result.errors[0]
49 )))?
50 }
51 let certs = certs_result.certs;
52
53 for cert in certs {
54 match roots.add(cert) {
55 Ok(_) => (),
56 Err(err) => Err(anyhow::anyhow!(format!(
57 "Couldn't add a certificate to the certificate store: {}",
58 err
59 )))?,
60 }
61 }
62
63 let mut connections_vec = Vec::new();
64 for _ in 0..DEFAULT_CONCURRENT_CONNECTIONS_PER_HOST {
65 connections_vec.push(RwLock::new(HashMap::new()));
66 }
67 Ok(Box::new(ReverseProxyModule::new(
68 Arc::new(roots),
69 Arc::new(connections_vec),
70 Arc::new(RwLock::new(TtlCache::new(Duration::from_millis(
71 config["global"]["loadBalancerHealthCheckWindow"]
72 .as_i64()
73 .unwrap_or(5000) as u64,
74 )))),
75 )))
76}
77
/// Reverse proxy server module created by `server_module_init`.
///
/// Holds the state shared by every `ReverseProxyModuleHandlers` that `get_handlers` creates.
// The pool type is spelled out in full rather than through an alias, hence the allow.
#[allow(clippy::type_complexity)]
struct ReverseProxyModule {
    /// Root certificates loaded from the native store. Used to verify HTTPS backends unless
    /// `disableProxyCertificateVerification` is set.
    roots: Arc<RootCertStore>,
    /// Keep-alive connection pool split into lock-protected shards. Each shard maps a backend
    /// `host:port` address to an HTTP/1.1 sender that may be reused.
    connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>>>,
    /// Failure counts keyed by backend URL, consulted by the load balancer health check. The
    /// cache is created with the `loadBalancerHealthCheckWindow` duration.
    failed_backends: Arc<RwLock<TtlCache<String, u64>>>,
}
84
85impl ReverseProxyModule {
86 #[allow(clippy::type_complexity)]
87 fn new(
88 roots: Arc<RootCertStore>,
89 connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>>>,
90 failed_backends: Arc<RwLock<TtlCache<String, u64>>>,
91 ) -> Self {
92 Self {
93 roots,
94 connections,
95 failed_backends,
96 }
97 }
98}
99
100impl ServerModule for ReverseProxyModule {
101 fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
102 Box::new(ReverseProxyModuleHandlers {
103 roots: self.roots.clone(),
104 connections: self.connections.clone(),
105 failed_backends: self.failed_backends.clone(),
106 handle,
107 })
108 }
109}
110
/// Request handlers produced by `ReverseProxyModule::get_handlers`.
// The pool type is spelled out in full rather than through an alias, hence the allow.
#[allow(clippy::type_complexity)]
struct ReverseProxyModuleHandlers {
    /// Runtime handle that handler futures are run on (via `WithRuntime`).
    handle: Handle,
    /// Trust anchors for verifying HTTPS/WSS backends (shared with the module).
    roots: Arc<RootCertStore>,
    /// Sharded keep-alive pool keyed by backend `host:port` (shared with the module).
    connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>>>,
    /// Per-backend-URL failure counts used by the health check (shared with the module).
    failed_backends: Arc<RwLock<TtlCache<String, u64>>>,
}
118
119#[async_trait]
120impl ServerModuleHandlers for ReverseProxyModuleHandlers {
121 async fn request_handler(
122 &mut self,
123 request: RequestData,
124 config: &ServerConfig,
125 socket_data: &SocketData,
126 error_logger: &ErrorLogger,
127 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
128 WithRuntime::new(self.handle.clone(), async move {
129 let enable_health_check = config["enableLoadBalancerHealthCheck"]
130 .as_bool()
131 .unwrap_or(false);
132 let health_check_max_fails = config["loadBalancerHealthCheckMaximumFails"]
133 .as_i64()
134 .unwrap_or(3) as u64;
135 let disable_certificate_verification = config["disableProxyCertificateVerification"]
136 .as_bool()
137 .unwrap_or(false);
138 let proxy_intercept_errors = config["proxyInterceptErrors"].as_bool().unwrap_or(false);
139 if let Some(proxy_to) = determine_proxy_to(
140 config,
141 socket_data.encrypted,
142 &self.failed_backends,
143 enable_health_check,
144 health_check_max_fails,
145 )
146 .await
147 {
148 let (hyper_request, _, _, _) = request.into_parts();
149 let (mut hyper_request_parts, request_body) = hyper_request.into_parts();
150
151 let proxy_request_url = proxy_to.parse::<hyper::Uri>()?;
152 let scheme_str = proxy_request_url.scheme_str();
153 let mut encrypted = false;
154
155 match scheme_str {
156 Some("http") => {
157 encrypted = false;
158 }
159 Some("https") => {
160 encrypted = true;
161 }
162 _ => Err(anyhow::anyhow!(
163 "Only HTTP and HTTPS reverse proxy URLs are supported."
164 ))?,
165 };
166
167 let host = match proxy_request_url.host() {
168 Some(host) => host,
169 None => Err(anyhow::anyhow!(
170 "The reverse proxy URL doesn't include the host"
171 ))?,
172 };
173
174 let port = proxy_request_url.port_u16().unwrap_or(match scheme_str {
175 Some("http") => 80,
176 Some("https") => 443,
177 _ => 80,
178 });
179
180 let addr = format!("{host}:{port}");
181 let authority = proxy_request_url.authority().cloned();
182
183 let hyper_request_path = hyper_request_parts.uri.path();
184
185 let path = match hyper_request_path.as_bytes().first() {
186 Some(b'/') => {
187 let mut proxy_request_path = proxy_request_url.path();
188 while proxy_request_path.as_bytes().last().copied() == Some(b'/') {
189 proxy_request_path = &proxy_request_path[..(proxy_request_path.len() - 1)];
190 }
191 format!("{proxy_request_path}{hyper_request_path}")
192 }
193 _ => hyper_request_path.to_string(),
194 };
195
196 hyper_request_parts.uri = Uri::from_str(&format!(
197 "{}{}",
198 path,
199 match hyper_request_parts.uri.query() {
200 Some(query) => format!("?{query}"),
201 None => "".to_string(),
202 }
203 ))?;
204
205 let original_host = hyper_request_parts.headers.get(header::HOST).cloned();
206
207 match authority {
209 Some(authority) => {
210 hyper_request_parts
211 .headers
212 .insert(header::HOST, authority.to_string().parse()?);
213 }
214 None => {
215 hyper_request_parts.headers.remove(header::HOST);
216 }
217 }
218
219 hyper_request_parts
221 .headers
222 .insert(header::CONNECTION, "keep-alive".parse()?);
223
224 if config["disableProxyXForwarded"].as_bool().unwrap_or(false) {
226 hyper_request_parts.headers.remove("x-forwarder-for");
227 hyper_request_parts.headers.remove("x-forwarded-proto");
228 hyper_request_parts.headers.remove("x-forwarded-host");
229 } else {
230 hyper_request_parts.headers.insert(
231 "x-forwarded-for",
232 socket_data
233 .remote_addr
234 .ip()
235 .to_canonical()
236 .to_string()
237 .parse()?,
238 );
239
240 if socket_data.encrypted {
241 hyper_request_parts
242 .headers
243 .insert("x-forwarded-proto", "https".parse()?);
244 } else {
245 hyper_request_parts
246 .headers
247 .insert("x-forwarded-proto", "http".parse()?);
248 }
249
250 if let Some(original_host) = original_host {
251 hyper_request_parts
252 .headers
253 .insert("x-forwarded-host", original_host);
254 }
255 }
256
257 hyper_request_parts.version = Version::HTTP_11;
258
259 let proxy_request = Request::from_parts(hyper_request_parts, request_body);
260
261 let connections = &self.connections[rand::random_range(..self.connections.len())];
262
263 let rwlock_read = connections.read().await;
264 let sender_read_option = rwlock_read.get(&addr);
265
266 if let Some(sender_read) = sender_read_option {
267 if !sender_read.is_closed() {
268 drop(rwlock_read);
269 let mut rwlock_write = connections.write().await;
270 let sender_option = rwlock_write.get_mut(&addr);
271
272 if let Some(sender) = sender_option {
273 if !sender.is_closed() && sender.ready().await.is_ok() {
274 let result = http_proxy_kept_alive(
275 sender,
276 proxy_request,
277 error_logger,
278 proxy_intercept_errors,
279 )
280 .await;
281 drop(rwlock_write);
282 return result;
283 } else {
284 drop(rwlock_write);
285 }
286 } else {
287 drop(rwlock_write);
288 }
289 } else {
290 drop(rwlock_read);
291 }
292 } else {
293 drop(rwlock_read);
294 }
295
296 let stream = match TcpStream::connect(&addr).await {
297 Ok(stream) => stream,
298 Err(err) => {
299 if enable_health_check {
300 let mut failed_backends_write = self.failed_backends.write().await;
301 let proxy_to = proxy_to.clone();
302 let failed_attempts = failed_backends_write.get(&proxy_to);
303 failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
304 }
305 match err.kind() {
306 tokio::io::ErrorKind::ConnectionRefused
307 | tokio::io::ErrorKind::NotFound
308 | tokio::io::ErrorKind::HostUnreachable => {
309 error_logger
310 .log(&format!("Service unavailable: {err}"))
311 .await;
312 return Ok(
313 ResponseData::builder_without_request()
314 .status(StatusCode::SERVICE_UNAVAILABLE)
315 .build(),
316 );
317 }
318 tokio::io::ErrorKind::TimedOut => {
319 error_logger.log(&format!("Gateway timeout: {err}")).await;
320 return Ok(
321 ResponseData::builder_without_request()
322 .status(StatusCode::GATEWAY_TIMEOUT)
323 .build(),
324 );
325 }
326 _ => {
327 error_logger.log(&format!("Bad gateway: {err}")).await;
328 return Ok(
329 ResponseData::builder_without_request()
330 .status(StatusCode::BAD_GATEWAY)
331 .build(),
332 );
333 }
334 };
335 }
336 };
337
338 match stream.set_nodelay(true) {
339 Ok(_) => (),
340 Err(err) => {
341 if enable_health_check {
342 let mut failed_backends_write = self.failed_backends.write().await;
343 let proxy_to = proxy_to.clone();
344 let failed_attempts = failed_backends_write.get(&proxy_to);
345 failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
346 }
347 error_logger.log(&format!("Bad gateway: {err}")).await;
348 return Ok(
349 ResponseData::builder_without_request()
350 .status(StatusCode::BAD_GATEWAY)
351 .build(),
352 );
353 }
354 };
355
356 let failed_backends_option_borrowed = if enable_health_check {
357 Some(&*self.failed_backends)
358 } else {
359 None
360 };
361
362 if !encrypted {
363 http_proxy(
364 connections,
365 addr,
366 stream,
367 proxy_request,
368 error_logger,
369 proxy_to,
370 failed_backends_option_borrowed,
371 proxy_intercept_errors,
372 )
373 .await
374 } else {
375 let tls_client_config = (if disable_certificate_verification {
376 rustls::ClientConfig::builder()
377 .dangerous()
378 .with_custom_certificate_verifier(Arc::new(NoServerVerifier::new()))
379 } else {
380 rustls::ClientConfig::builder().with_root_certificates(self.roots.clone())
381 })
382 .with_no_client_auth();
383 let connector = TlsConnector::from(Arc::new(tls_client_config));
384 let domain = ServerName::try_from(host)?.to_owned();
385
386 let tls_stream = match connector.connect(domain, stream).await {
387 Ok(stream) => stream,
388 Err(err) => {
389 if enable_health_check {
390 let mut failed_backends_write = self.failed_backends.write().await;
391 let proxy_to = proxy_to.clone();
392 let failed_attempts = failed_backends_write.get(&proxy_to);
393 failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
394 }
395 error_logger.log(&format!("Bad gateway: {err}")).await;
396 return Ok(
397 ResponseData::builder_without_request()
398 .status(StatusCode::BAD_GATEWAY)
399 .build(),
400 );
401 }
402 };
403
404 http_proxy(
405 connections,
406 addr,
407 tls_stream,
408 proxy_request,
409 error_logger,
410 proxy_to,
411 failed_backends_option_borrowed,
412 proxy_intercept_errors,
413 )
414 .await
415 }
416 } else {
417 Ok(ResponseData::builder(request).build())
418 }
419 })
420 .await
421 }
422
423 async fn proxy_request_handler(
424 &mut self,
425 request: RequestData,
426 _config: &ServerConfig,
427 _socket_data: &SocketData,
428 _error_logger: &ErrorLogger,
429 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
430 Ok(ResponseData::builder(request).build())
431 }
432
433 async fn response_modifying_handler(
434 &mut self,
435 response: HyperResponse,
436 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
437 Ok(response)
438 }
439
440 async fn proxy_response_modifying_handler(
441 &mut self,
442 response: HyperResponse,
443 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
444 Ok(response)
445 }
446
447 async fn connect_proxy_request_handler(
448 &mut self,
449 _upgraded_request: HyperUpgraded,
450 _connect_address: &str,
451 _config: &ServerConfig,
452 _socket_data: &SocketData,
453 _error_logger: &ErrorLogger,
454 ) -> Result<(), Box<dyn Error + Send + Sync>> {
455 Ok(())
456 }
457
458 fn does_connect_proxy_requests(&mut self) -> bool {
459 false
460 }
461
462 async fn websocket_request_handler(
463 &mut self,
464 websocket: HyperWebsocket,
465 uri: &hyper::Uri,
466 headers: &hyper::HeaderMap,
467 config: &ServerConfig,
468 socket_data: &SocketData,
469 error_logger: &ErrorLogger,
470 ) -> Result<(), Box<dyn Error + Send + Sync>> {
471 WithRuntime::new(self.handle.clone(), async move {
472 let enable_health_check = config["enableLoadBalancerHealthCheck"]
473 .as_bool()
474 .unwrap_or(false);
475 let health_check_max_fails = config["loadBalancerHealthCheckMaximumFails"]
476 .as_i64()
477 .unwrap_or(3) as u64;
478
479 let disable_certificate_verification = config["disableProxyCertificateVerification"]
480 .as_bool()
481 .unwrap_or(false);
482 if let Some(proxy_to) = determine_proxy_to(
483 config,
484 socket_data.encrypted,
485 &self.failed_backends,
486 enable_health_check,
487 health_check_max_fails,
488 )
489 .await
490 {
491 let proxy_request_url = proxy_to.parse::<hyper::Uri>()?;
492 let scheme_str = proxy_request_url.scheme_str();
493 let mut encrypted = false;
494
495 match scheme_str {
496 Some("http") => {
497 encrypted = false;
498 }
499 Some("https") => {
500 encrypted = true;
501 }
502 _ => Err(anyhow::anyhow!(
503 "Only HTTP and HTTPS reverse proxy URLs are supported."
504 ))?,
505 };
506
507 let request_path = uri.path();
508
509 let path = match request_path.as_bytes().first() {
510 Some(b'/') => {
511 let mut proxy_request_path = proxy_request_url.path();
512 while proxy_request_path.as_bytes().last().copied() == Some(b'/') {
513 proxy_request_path = &proxy_request_path[..(proxy_request_path.len() - 1)];
514 }
515 format!("{proxy_request_path}{request_path}")
516 }
517 _ => request_path.to_string(),
518 };
519
520 let mut proxy_request_url_parts = proxy_request_url.into_parts();
521 proxy_request_url_parts.scheme = if encrypted {
522 Some(Scheme::from_str("wss")?)
523 } else {
524 Some(Scheme::from_str("ws")?)
525 };
526 match uri.path_and_query() {
527 Some(path_and_query) => {
528 let path_and_query_string = match path_and_query.query() {
529 Some(query) => {
530 format!("{path}?{query}")
531 }
532 None => path,
533 };
534 proxy_request_url_parts.path_and_query =
535 Some(PathAndQuery::from_str(&path_and_query_string)?);
536 }
537 None => {
538 proxy_request_url_parts.path_and_query = Some(PathAndQuery::from_str(&path)?);
539 }
540 };
541
542 let proxy_request_url = hyper::Uri::from_parts(proxy_request_url_parts)?;
543
544 let connector = if !encrypted {
545 Connector::Plain
546 } else {
547 Connector::Rustls(Arc::new(
548 (if disable_certificate_verification {
549 rustls::ClientConfig::builder()
550 .dangerous()
551 .with_custom_certificate_verifier(Arc::new(NoServerVerifier::new()))
552 } else {
553 rustls::ClientConfig::builder().with_root_certificates(self.roots.clone())
554 })
555 .with_no_client_auth(),
556 ))
557 };
558
559 let mut proxy_request_builder = ClientRequestBuilder::new(proxy_request_url);
560 for (header_name, header_value) in headers {
561 let header_name_str = header_name.as_str();
562 if header_name == SEC_WEBSOCKET_PROTOCOL {
563 for subprotocol in String::from_utf8_lossy(header_value.as_bytes()).split(",") {
564 proxy_request_builder = proxy_request_builder.with_sub_protocol(subprotocol.trim());
565 }
566 } else if !header_name_str.starts_with("sec-websocket-")
567 && header_name_str != "x-forwarded-for"
568 {
569 proxy_request_builder = proxy_request_builder.with_header(
570 header_name_str,
571 String::from_utf8_lossy(header_value.as_bytes()),
572 );
573 }
574 }
575
576 proxy_request_builder = proxy_request_builder.with_header(
578 "x-forwarded-for",
579 socket_data.remote_addr.ip().to_canonical().to_string(),
580 );
581
582 let proxy_request_constructed = proxy_request_builder.into_client_request()?;
583
584 let client_bi_stream = websocket.await?;
585
586 let (proxy_bi_stream, _) = match tokio_tungstenite::connect_async_tls_with_config(
587 proxy_request_constructed,
588 None,
589 true,
590 Some(connector),
591 )
592 .await
593 {
594 Ok(data) => data,
595 Err(err) => {
596 error_logger
597 .log(&format!("Cannot connect to WebSocket server: {err}"))
598 .await;
599 return Ok(());
600 }
601 };
602
603 let (mut client_sink, mut client_stream) = client_bi_stream.split();
604 let (mut proxy_sink, mut proxy_stream) = proxy_bi_stream.split();
605
606 let client_to_proxy = async {
607 while let Some(Ok(value)) = client_stream.next().await {
608 if proxy_sink.send(value).await.is_err() {
609 break;
610 }
611 }
612 };
613
614 let proxy_to_client = async {
615 while let Some(Ok(value)) = proxy_stream.next().await {
616 if client_sink.send(value).await.is_err() {
617 break;
618 }
619 }
620 };
621
622 tokio::pin!(client_to_proxy);
623 tokio::pin!(proxy_to_client);
624
625 let client_to_proxy_first;
626 tokio::select! {
627 _ = &mut client_to_proxy => {
628 client_to_proxy_first = true;
629 }
630 _ = &mut proxy_to_client => {
631 client_to_proxy_first = false;
632 }
633 }
634
635 if client_to_proxy_first {
636 proxy_to_client.await;
637 } else {
638 client_to_proxy.await;
639 }
640 }
641
642 Ok(())
643 })
644 .await
645 }
646
647 fn does_websocket_requests(&mut self, config: &ServerConfig, socket_data: &SocketData) -> bool {
648 if socket_data.encrypted {
649 let secure_proxy_to = &config["secureProxyTo"];
650 if secure_proxy_to.as_vec().is_some() || secure_proxy_to.as_str().is_some() {
651 return true;
652 }
653 }
654
655 let proxy_to = &config["proxyTo"];
656 proxy_to.as_vec().is_some() || proxy_to.as_str().is_some()
657 }
658}
659
660async fn determine_proxy_to(
661 config: &ServerConfig,
662 encrypted: bool,
663 failed_backends: &RwLock<TtlCache<String, u64>>,
664 enable_health_check: bool,
665 health_check_max_fails: u64,
666) -> Option<String> {
667 let mut proxy_to = None;
668 if encrypted {
672 let secure_proxy_to_yaml = &config["secureProxyTo"];
673 if let Some(secure_proxy_to_vector) = secure_proxy_to_yaml.as_vec() {
674 if enable_health_check {
675 let mut secure_proxy_to_vector = secure_proxy_to_vector.clone();
676 loop {
677 if !secure_proxy_to_vector.is_empty() {
678 let index = rand::random_range(..secure_proxy_to_vector.len());
679 if let Some(secure_proxy_to) = secure_proxy_to_vector[index].as_str() {
680 proxy_to = Some(secure_proxy_to.to_string());
681 let failed_backends_read = failed_backends.read().await;
682 let failed_backend_fails =
683 match failed_backends_read.get(&secure_proxy_to.to_string()) {
684 Some(fails) => fails,
685 None => break,
686 };
687 if failed_backend_fails > health_check_max_fails {
688 secure_proxy_to_vector.remove(index);
689 } else {
690 break;
691 }
692 }
693 } else {
694 break;
695 }
696 }
697 } else if !secure_proxy_to_vector.is_empty() {
698 if let Some(secure_proxy_to) =
699 secure_proxy_to_vector[rand::random_range(..secure_proxy_to_vector.len())].as_str()
700 {
701 proxy_to = Some(secure_proxy_to.to_string());
702 }
703 }
704 } else if let Some(secure_proxy_to) = secure_proxy_to_yaml.as_str() {
705 proxy_to = Some(secure_proxy_to.to_string());
706 }
707 }
708
709 if proxy_to.is_none() {
710 let proxy_to_yaml = &config["proxyTo"];
711 if let Some(proxy_to_vector) = proxy_to_yaml.as_vec() {
712 if enable_health_check {
713 let mut proxy_to_vector = proxy_to_vector.clone();
714 loop {
715 if !proxy_to_vector.is_empty() {
716 let index = rand::random_range(..proxy_to_vector.len());
717 if let Some(proxy_to_str) = proxy_to_vector[index].as_str() {
718 proxy_to = Some(proxy_to_str.to_string());
719 let failed_backends_read = failed_backends.read().await;
720 let failed_backend_fails = match failed_backends_read.get(&proxy_to_str.to_string()) {
721 Some(fails) => fails,
722 None => break,
723 };
724 if failed_backend_fails > health_check_max_fails {
725 proxy_to_vector.remove(index);
726 } else {
727 break;
728 }
729 }
730 } else {
731 break;
732 }
733 }
734 } else if !proxy_to_vector.is_empty() {
735 if let Some(proxy_to_str) =
736 proxy_to_vector[rand::random_range(..proxy_to_vector.len())].as_str()
737 {
738 proxy_to = Some(proxy_to_str.to_string());
739 }
740 }
741 } else if let Some(proxy_to_str) = proxy_to_yaml.as_str() {
742 proxy_to = Some(proxy_to_str.to_string());
743 }
744 }
745
746 proxy_to
747}
748
749#[allow(clippy::too_many_arguments)]
750async fn http_proxy(
751 connections: &RwLock<HashMap<String, SendRequest<BoxBody<Bytes, std::io::Error>>>>,
752 connect_addr: String,
753 stream: impl AsyncRead + AsyncWrite + Send + Unpin + 'static,
754 proxy_request: Request<BoxBody<Bytes, std::io::Error>>,
755 error_logger: &ErrorLogger,
756 proxy_to: String,
757 failed_backends: Option<&tokio::sync::RwLock<TtlCache<std::string::String, u64>>>,
758 proxy_intercept_errors: bool,
759) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
760 let io = TokioIo::new(stream);
761
762 let (mut sender, conn) = match hyper::client::conn::http1::handshake(io).await {
763 Ok(data) => data,
764 Err(err) => {
765 if let Some(failed_backends) = failed_backends {
766 let mut failed_backends_write = failed_backends.write().await;
767 let failed_attempts = failed_backends_write.get(&proxy_to);
768 failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1));
769 }
770 error_logger.log(&format!("Bad gateway: {err}")).await;
771 return Ok(
772 ResponseData::builder_without_request()
773 .status(StatusCode::BAD_GATEWAY)
774 .build(),
775 );
776 }
777 };
778
779 let send_request = sender.send_request(proxy_request);
780
781 let mut pinned_conn = Box::pin(conn);
782 tokio::pin!(send_request);
783
784 let response;
785
786 loop {
787 tokio::select! {
788 biased;
789
790 proxy_response = &mut send_request => {
791 let proxy_response = match proxy_response {
792 Ok(response) => response,
793 Err(err) => {
794 error_logger.log(&format!("Bad gateway: {err}")).await;
795 return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build());
796 }
797 };
798
799 let status_code = proxy_response.status();
800 response = if proxy_intercept_errors && status_code.as_u16() >= 400 {
801 ResponseData::builder_without_request()
802 .status(status_code)
803 .parallel_fn(async move {
804 pinned_conn.await.unwrap_or_default();
805 })
806 .build()
807 } else {
808 ResponseData::builder_without_request()
809 .response(proxy_response.map(|b| {
810 b.map_err(|e| std::io::Error::other(e.to_string()))
811 .boxed()
812 }))
813 .parallel_fn(async move {
814 pinned_conn.await.unwrap_or_default();
815 })
816 .build()
817 };
818
819 break;
820 },
821 state = &mut pinned_conn => {
822 if state.is_err() {
823 error_logger.log("Bad gateway: incomplete response").await;
824 return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build());
825 }
826 },
827 };
828 }
829
830 if !sender.is_closed() {
831 let mut rwlock_write = connections.write().await;
832 rwlock_write.insert(connect_addr, sender);
833 drop(rwlock_write);
834 }
835
836 Ok(response)
837}
838
839async fn http_proxy_kept_alive(
840 sender: &mut SendRequest<BoxBody<Bytes, std::io::Error>>,
841 proxy_request: Request<BoxBody<Bytes, std::io::Error>>,
842 error_logger: &ErrorLogger,
843 proxy_intercept_errors: bool,
844) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
845 let proxy_response = match sender.send_request(proxy_request).await {
846 Ok(response) => response,
847 Err(err) => {
848 error_logger.log(&format!("Bad gateway: {err}")).await;
849 return Ok(
850 ResponseData::builder_without_request()
851 .status(StatusCode::BAD_GATEWAY)
852 .build(),
853 );
854 }
855 };
856
857 let status_code = proxy_response.status();
858 let response = if proxy_intercept_errors && status_code.as_u16() >= 400 {
859 ResponseData::builder_without_request()
860 .status(status_code)
861 .build()
862 } else {
863 ResponseData::builder_without_request()
864 .response(proxy_response.map(|b| b.map_err(|e| std::io::Error::other(e.to_string())).boxed()))
865 .build()
866 };
867
868 Ok(response)
869}