ferron/util/
ip_blocklist.rs1use std::cmp::Ordering;
2use std::collections::HashSet;
3use std::net::{IpAddr, Ipv6Addr};
4
5use cidr::IpCidr;
6
7#[derive(Clone, Debug, PartialEq, Eq)]
9pub struct IpBlockList {
10 blocked_ips: HashSet<IpAddr>,
11 blocked_cidrs: HashSet<IpCidr>,
12}
13
14impl Default for IpBlockList {
15 fn default() -> Self {
16 Self::new()
17 }
18}
19
20impl Ord for IpBlockList {
21 fn cmp(&self, other: &Self) -> Ordering {
22 self
23 .blocked_ips
24 .iter()
25 .cmp(other.blocked_ips.iter())
26 .then(self.blocked_cidrs.iter().cmp(other.blocked_cidrs.iter()))
27 }
28}
29
30impl PartialOrd for IpBlockList {
31 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
32 Some(self.cmp(other))
33 }
34}
35
36impl IpBlockList {
37 pub fn new() -> Self {
39 Self {
40 blocked_ips: HashSet::new(),
41 blocked_cidrs: HashSet::new(),
42 }
43 }
44
45 pub fn load_from_vec(&mut self, ip_list: Vec<&str>) {
47 for ip_str in ip_list {
48 match ip_str {
49 "localhost" => {
50 self.blocked_ips.insert(IpAddr::V6(Ipv6Addr::LOCALHOST));
51 }
52 _ => {
53 if let Ok(ip) = ip_str.parse::<IpAddr>() {
54 self.blocked_ips.insert(ip.to_canonical());
55 } else if let Ok(ip_cidr) = ip_str.parse::<IpCidr>() {
56 self.blocked_cidrs.insert(ip_cidr);
57 }
58 }
59 }
60 }
61 }
62
63 pub fn is_blocked(&self, ip: IpAddr) -> bool {
65 self.blocked_ips.contains(&ip.to_canonical())
66 || self
67 .blocked_cidrs
68 .iter()
69 .any(|cidr| cidr.contains(&ip.to_canonical()))
70 }
71}
72
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a block list pre-loaded with `entries`.
    fn block_list_from(entries: Vec<&str>) -> IpBlockList {
        let mut list = IpBlockList::new();
        list.load_from_vec(entries);
        list
    }

    /// Parses `addr` and reports whether `list` blocks it.
    fn blocks(list: &IpBlockList, addr: &str) -> bool {
        list.is_blocked(addr.parse().unwrap())
    }

    #[test]
    fn test_ip_block_list() {
        let list = block_list_from(vec!["192.168.1.1", "10.0.0.1"]);

        assert!(blocks(&list, "192.168.1.1"));
        assert!(blocks(&list, "10.0.0.1"));
        assert!(!blocks(&list, "8.8.8.8"));
    }

    #[test]
    fn test_ip_cidr_block_list() {
        let list = block_list_from(vec!["192.168.1.0/24", "10.0.0.0/8"]);

        assert!(blocks(&list, "192.168.1.1"));
        assert!(blocks(&list, "10.0.0.1"));
        assert!(!blocks(&list, "8.8.8.8"));
    }
}