1use std::collections::HashMap;
4use std::error::Error;
5use std::str::FromStr;
6use std::sync::Arc;
7
8use crate::ferron_common::{
9 ErrorLogger, HyperUpgraded, RequestData, ResponseData, ServerConfig, ServerModule,
10 ServerModuleHandlers, SocketData,
11};
12use crate::ferron_common::{HyperResponse, WithRuntime};
13use async_trait::async_trait;
14use http_body_util::combinators::BoxBody;
15use http_body_util::{BodyExt, Empty};
16use hyper::body::Bytes;
17use hyper::client::conn::http1::SendRequest;
18use hyper::header::{self, HeaderName};
19use hyper::{Method, Request, StatusCode, Uri, Version};
20use hyper_tungstenite::HyperWebsocket;
21use hyper_util::rt::TokioIo;
22use rustls::pki_types::ServerName;
23use rustls::RootCertStore;
24use rustls_native_certs::load_native_certs;
25use tokio::io::{AsyncRead, AsyncWrite};
26use tokio::net::TcpStream;
27use tokio::runtime::Handle;
28use tokio::sync::RwLock;
29use tokio_rustls::TlsConnector;
30
31const DEFAULT_CONCURRENT_CONNECTIONS_PER_HOST: u32 = 32;
32
33pub fn server_module_init(
34 _config: &ServerConfig,
35) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
36 let mut roots: RootCertStore = RootCertStore::empty();
37 let certs_result = load_native_certs();
38 if !certs_result.errors.is_empty() {
39 Err(anyhow::anyhow!(
40 "Couldn't load the native certificate store: {}",
41 certs_result.errors[0]
42 ))?
43 }
44 let certs = certs_result.certs;
45
46 for cert in certs {
47 match roots.add(cert) {
48 Ok(_) => (),
49 Err(err) => Err(anyhow::anyhow!(
50 "Couldn't add a certificate to the certificate store: {}",
51 err
52 ))?,
53 }
54 }
55
56 let mut connections_vec = Vec::new();
57 for _ in 0..DEFAULT_CONCURRENT_CONNECTIONS_PER_HOST {
58 connections_vec.push(RwLock::new(HashMap::new()));
59 }
60 Ok(Box::new(ForwardedAuthenticationModule::new(
61 Arc::new(roots),
62 Arc::new(connections_vec),
63 )))
64}
65
66#[allow(clippy::type_complexity)]
67struct ForwardedAuthenticationModule {
68 roots: Arc<RootCertStore>,
69 connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, hyper::Error>>>>>>,
70}
71
72impl ForwardedAuthenticationModule {
73 #[allow(clippy::type_complexity)]
74 fn new(
75 roots: Arc<RootCertStore>,
76 connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, hyper::Error>>>>>>,
77 ) -> Self {
78 Self { roots, connections }
79 }
80}
81
82impl ServerModule for ForwardedAuthenticationModule {
83 fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
84 Box::new(ForwardedAuthenticationModuleHandlers {
85 roots: self.roots.clone(),
86 connections: self.connections.clone(),
87 handle,
88 })
89 }
90}
91
92#[allow(clippy::type_complexity)]
93struct ForwardedAuthenticationModuleHandlers {
94 handle: Handle,
95 roots: Arc<RootCertStore>,
96 connections: Arc<Vec<RwLock<HashMap<String, SendRequest<BoxBody<Bytes, hyper::Error>>>>>>,
97}
98
99#[async_trait]
100impl ServerModuleHandlers for ForwardedAuthenticationModuleHandlers {
101 async fn request_handler(
102 &mut self,
103 request: RequestData,
104 config: &ServerConfig,
105 socket_data: &SocketData,
106 error_logger: &ErrorLogger,
107 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
108 WithRuntime::new(self.handle.clone(), async move {
109 let mut auth_to = None;
110
111 if let Some(auth_to_str) = config["authTo"].as_str() {
112 auth_to = Some(auth_to_str.to_string());
113 }
114
115 let forwarded_auth_copy_headers = match config["forwardedAuthCopyHeaders"].as_vec() {
116 Some(vector) => {
117 let mut new_vector = Vec::new();
118 for yaml_value in vector.iter() {
119 if let Some(str_value) = yaml_value.as_str() {
120 new_vector.push(str_value.to_string());
121 }
122 }
123 new_vector
124 }
125 None => Vec::new(),
126 };
127
128 if let Some(auth_to) = auth_to {
129 let (hyper_request, auth_user, original_url, error_status_code) = request.into_parts();
130 let (hyper_request_parts, request_body) = hyper_request.into_parts();
131
132 let auth_request_url = auth_to.parse::<hyper::Uri>()?;
133 let scheme_str = auth_request_url.scheme_str();
134 let mut encrypted = false;
135
136 match scheme_str {
137 Some("http") => {
138 encrypted = false;
139 }
140 Some("https") => {
141 encrypted = true;
142 }
143 _ => Err(anyhow::anyhow!(
144 "Only HTTP and HTTPS reverse proxy URLs are supported."
145 ))?,
146 };
147
148 let host = match auth_request_url.host() {
149 Some(host) => host,
150 None => Err(anyhow::anyhow!(
151 "The reverse proxy URL doesn't include the host"
152 ))?,
153 };
154
155 let port = auth_request_url.port_u16().unwrap_or(match scheme_str {
156 Some("http") => 80,
157 Some("https") => 443,
158 _ => 80,
159 });
160
161 let addr = format!("{host}:{port}");
162 let authority = auth_request_url.authority().cloned();
163
164 let hyper_request_path = hyper_request_parts.uri.path();
165
166 let path_and_query = format!(
167 "{}{}",
168 hyper_request_path,
169 match hyper_request_parts.uri.query() {
170 Some(query) => format!("?{query}"),
171 None => "".to_string(),
172 }
173 );
174
175 let mut auth_hyper_request_parts = hyper_request_parts.clone();
176
177 auth_hyper_request_parts.uri = Uri::from_str(&format!(
178 "{}{}",
179 auth_request_url.path(),
180 match auth_request_url.query() {
181 Some(query) => format!("?{query}"),
182 None => "".to_string(),
183 }
184 ))?;
185
186 let original_host = hyper_request_parts.headers.get(header::HOST).cloned();
187
188 match authority {
190 Some(authority) => {
191 auth_hyper_request_parts
192 .headers
193 .insert(header::HOST, authority.to_string().parse()?);
194 }
195 None => {
196 auth_hyper_request_parts.headers.remove(header::HOST);
197 }
198 }
199
200 auth_hyper_request_parts
202 .headers
203 .insert(header::CONNECTION, "keep-alive".parse()?);
204
205 auth_hyper_request_parts.headers.insert(
207 HeaderName::from_static("x-forwarded-for"),
208 socket_data
209 .remote_addr
210 .ip()
211 .to_canonical()
212 .to_string()
213 .parse()?,
214 );
215
216 if socket_data.encrypted {
217 auth_hyper_request_parts.headers.insert(
218 HeaderName::from_static("x-forwarded-proto"),
219 "https".parse()?,
220 );
221 } else {
222 auth_hyper_request_parts.headers.insert(
223 HeaderName::from_static("x-forwarded-proto"),
224 "http".parse()?,
225 );
226 }
227
228 if let Some(original_host) = original_host {
229 auth_hyper_request_parts
230 .headers
231 .insert(HeaderName::from_static("x-forwarded-host"), original_host);
232 }
233
234 auth_hyper_request_parts.headers.insert(
235 HeaderName::from_static("x-forwarded-uri"),
236 path_and_query.parse()?,
237 );
238
239 auth_hyper_request_parts.headers.insert(
240 HeaderName::from_static("x-forwarded-method"),
241 hyper_request_parts.method.as_str().parse()?,
242 );
243
244 auth_hyper_request_parts.method = Method::GET;
245 auth_hyper_request_parts.version = Version::HTTP_11;
246
247 let auth_request = Request::from_parts(
248 auth_hyper_request_parts,
249 Empty::new().map_err(|e| match e {}).boxed(),
250 );
251 let original_hyper_request = Request::from_parts(hyper_request_parts, request_body);
252 let original_request = RequestData::new(
253 original_hyper_request,
254 auth_user,
255 original_url,
256 error_status_code,
257 );
258
259 let connections = &self.connections[rand::random_range(..self.connections.len())];
260
261 let rwlock_read = connections.read().await;
262 let sender_read_option = rwlock_read.get(&addr);
263
264 if let Some(sender_read) = sender_read_option {
265 if !sender_read.is_closed() {
266 drop(rwlock_read);
267 let mut rwlock_write = connections.write().await;
268 let sender_option = rwlock_write.get_mut(&addr);
269
270 if let Some(sender) = sender_option {
271 if !sender.is_closed() && sender.ready().await.is_ok() {
272 let result = http_forwarded_auth_kept_alive(
273 sender,
274 auth_request,
275 error_logger,
276 original_request,
277 forwarded_auth_copy_headers,
278 )
279 .await;
280 drop(rwlock_write);
281 return result;
282 } else {
283 drop(rwlock_write);
284 }
285 } else {
286 drop(rwlock_write);
287 }
288 } else {
289 drop(rwlock_read);
290 }
291 } else {
292 drop(rwlock_read);
293 }
294
295 let stream = match TcpStream::connect(&addr).await {
296 Ok(stream) => stream,
297 Err(err) => {
298 match err.kind() {
299 tokio::io::ErrorKind::ConnectionRefused
300 | tokio::io::ErrorKind::NotFound
301 | tokio::io::ErrorKind::HostUnreachable => {
302 error_logger
303 .log(&format!("Service unavailable: {err}"))
304 .await;
305 return Ok(
306 ResponseData::builder_without_request()
307 .status(StatusCode::SERVICE_UNAVAILABLE)
308 .build(),
309 );
310 }
311 tokio::io::ErrorKind::TimedOut => {
312 error_logger.log(&format!("Gateway timeout: {err}")).await;
313 return Ok(
314 ResponseData::builder_without_request()
315 .status(StatusCode::GATEWAY_TIMEOUT)
316 .build(),
317 );
318 }
319 _ => {
320 error_logger.log(&format!("Bad gateway: {err}")).await;
321 return Ok(
322 ResponseData::builder_without_request()
323 .status(StatusCode::BAD_GATEWAY)
324 .build(),
325 );
326 }
327 };
328 }
329 };
330
331 match stream.set_nodelay(true) {
332 Ok(_) => (),
333 Err(err) => {
334 error_logger.log(&format!("Bad gateway: {err}")).await;
335 return Ok(
336 ResponseData::builder_without_request()
337 .status(StatusCode::BAD_GATEWAY)
338 .build(),
339 );
340 }
341 };
342
343 if !encrypted {
344 http_forwarded_auth(
345 connections,
346 addr,
347 stream,
348 auth_request,
349 error_logger,
350 original_request,
351 forwarded_auth_copy_headers,
352 )
353 .await
354 } else {
355 let tls_client_config = rustls::ClientConfig::builder()
356 .with_root_certificates(self.roots.clone())
357 .with_no_client_auth();
358 let connector = TlsConnector::from(Arc::new(tls_client_config));
359 let domain = ServerName::try_from(host)?.to_owned();
360
361 let tls_stream = match connector.connect(domain, stream).await {
362 Ok(stream) => stream,
363 Err(err) => {
364 error_logger.log(&format!("Bad gateway: {err}")).await;
365 return Ok(
366 ResponseData::builder_without_request()
367 .status(StatusCode::BAD_GATEWAY)
368 .build(),
369 );
370 }
371 };
372
373 http_forwarded_auth(
374 connections,
375 addr,
376 tls_stream,
377 auth_request,
378 error_logger,
379 original_request,
380 forwarded_auth_copy_headers,
381 )
382 .await
383 }
384 } else {
385 Ok(ResponseData::builder(request).build())
386 }
387 })
388 .await
389 }
390
391 async fn proxy_request_handler(
392 &mut self,
393 request: RequestData,
394 _config: &ServerConfig,
395 _socket_data: &SocketData,
396 _error_logger: &ErrorLogger,
397 ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
398 Ok(ResponseData::builder(request).build())
399 }
400
401 async fn response_modifying_handler(
402 &mut self,
403 response: HyperResponse,
404 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
405 Ok(response)
406 }
407
408 async fn proxy_response_modifying_handler(
409 &mut self,
410 response: HyperResponse,
411 ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
412 Ok(response)
413 }
414
415 async fn connect_proxy_request_handler(
416 &mut self,
417 _upgraded_request: HyperUpgraded,
418 _connect_address: &str,
419 _config: &ServerConfig,
420 _socket_data: &SocketData,
421 _error_logger: &ErrorLogger,
422 ) -> Result<(), Box<dyn Error + Send + Sync>> {
423 Ok(())
424 }
425
426 fn does_connect_proxy_requests(&mut self) -> bool {
427 false
428 }
429
430 async fn websocket_request_handler(
431 &mut self,
432 _websocket: HyperWebsocket,
433 _uri: &hyper::Uri,
434 _headers: &hyper::HeaderMap,
435 _config: &ServerConfig,
436 _socket_data: &SocketData,
437 _error_logger: &ErrorLogger,
438 ) -> Result<(), Box<dyn Error + Send + Sync>> {
439 Ok(())
440 }
441
442 fn does_websocket_requests(&mut self, _config: &ServerConfig, _socket_data: &SocketData) -> bool {
443 false
444 }
445}
446
447async fn http_forwarded_auth(
448 connections: &RwLock<HashMap<String, SendRequest<BoxBody<Bytes, hyper::Error>>>>,
449 connect_addr: String,
450 stream: impl AsyncRead + AsyncWrite + Send + Unpin + 'static,
451 proxy_request: Request<BoxBody<Bytes, hyper::Error>>,
452 error_logger: &ErrorLogger,
453 mut original_request: RequestData,
454 forwarded_auth_copy_headers: Vec<String>,
455) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
456 let io = TokioIo::new(stream);
457
458 let (mut sender, conn) = match hyper::client::conn::http1::handshake(io).await {
459 Ok(data) => data,
460 Err(err) => {
461 error_logger.log(&format!("Bad gateway: {err}")).await;
462 return Ok(
463 ResponseData::builder_without_request()
464 .status(StatusCode::BAD_GATEWAY)
465 .build(),
466 );
467 }
468 };
469
470 let send_request = sender.send_request(proxy_request);
471
472 let mut pinned_conn = Box::pin(conn);
473 tokio::pin!(send_request);
474
475 let response;
476
477 loop {
478 tokio::select! {
479 biased;
480
481 proxy_response = &mut send_request => {
482 let proxy_response = match proxy_response {
483 Ok(response) => response,
484 Err(err) => {
485 error_logger.log(&format!("Bad gateway: {err}")).await;
486 return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build());
487 }
488 };
489
490 if proxy_response.status().is_success() {
491 if !forwarded_auth_copy_headers.is_empty() {
492 let response_headers = proxy_response.headers();
493 let request_headers = original_request.get_mut_hyper_request().headers_mut();
494 for forwarded_auth_copy_header_string in forwarded_auth_copy_headers.iter() {
495 let forwarded_auth_copy_header= HeaderName::from_str(forwarded_auth_copy_header_string)?;
496 if response_headers.contains_key(&forwarded_auth_copy_header) {
497 while request_headers.remove(&forwarded_auth_copy_header).is_some() {}
498 for header_value in response_headers.get_all(&forwarded_auth_copy_header).iter() {
499 request_headers.append(&forwarded_auth_copy_header, header_value.clone());
500 }
501 }
502 }
503 }
504 response = ResponseData::builder(original_request).build();
505 } else {
506 response = ResponseData::builder_without_request()
507 .response(proxy_response.map(|b| {
508 b.map_err(|e| std::io::Error::other(e.to_string()))
509 .boxed()
510 }))
511 .parallel_fn(async move {
512 pinned_conn.await.unwrap_or_default();
513 })
514 .build();
515
516 }
517
518 break;
519 },
520 state = &mut pinned_conn => {
521 if state.is_err() {
522 error_logger.log("Bad gateway: incomplete response").await;
523 return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build());
524 }
525 },
526 };
527 }
528
529 if !sender.is_closed() {
530 let mut rwlock_write = connections.write().await;
531 rwlock_write.insert(connect_addr, sender);
532 drop(rwlock_write);
533 }
534
535 Ok(response)
536}
537
538async fn http_forwarded_auth_kept_alive(
539 sender: &mut SendRequest<BoxBody<Bytes, hyper::Error>>,
540 proxy_request: Request<BoxBody<Bytes, hyper::Error>>,
541 error_logger: &ErrorLogger,
542 mut original_request: RequestData,
543 forwarded_auth_copy_headers: Vec<String>,
544) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
545 let proxy_response = match sender.send_request(proxy_request).await {
546 Ok(response) => response,
547 Err(err) => {
548 error_logger.log(&format!("Bad gateway: {err}")).await;
549 return Ok(
550 ResponseData::builder_without_request()
551 .status(StatusCode::BAD_GATEWAY)
552 .build(),
553 );
554 }
555 };
556
557 let response = if proxy_response.status().is_success() {
558 if !forwarded_auth_copy_headers.is_empty() {
559 let response_headers = proxy_response.headers();
560 let request_headers = original_request.get_mut_hyper_request().headers_mut();
561 for forwarded_auth_copy_header_string in forwarded_auth_copy_headers.iter() {
562 let forwarded_auth_copy_header = HeaderName::from_str(forwarded_auth_copy_header_string)?;
563 if response_headers.contains_key(&forwarded_auth_copy_header) {
564 while request_headers
565 .remove(&forwarded_auth_copy_header)
566 .is_some()
567 {}
568 for header_value in response_headers.get_all(&forwarded_auth_copy_header).iter() {
569 request_headers.append(&forwarded_auth_copy_header, header_value.clone());
570 }
571 }
572 }
573 }
574 ResponseData::builder(original_request).build()
575 } else {
576 ResponseData::builder_without_request()
577 .response(proxy_response.map(|b| b.map_err(|e| std::io::Error::other(e.to_string())).boxed()))
578 .build()
579 };
580
581 Ok(response)
582}