diff --git a/src/example.rs b/src/example.rs index e2b0e56..d1dd6b3 100644 --- a/src/example.rs +++ b/src/example.rs @@ -14,7 +14,7 @@ use std::{ use ctor::ctor; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; -use iptrie::{IpPrefix, RTrieSet}; +use iptrie::IpPrefix; use serde::{ de::{Error, Visitor}, Deserialize, @@ -137,16 +137,17 @@ impl Default for IpCache { fn ignore(_: &mut smallvec::SmallVec<[T; 4]>) {} impl IpCache { - fn extend_set_with_domain>( + fn extend_set_with_domain>( &self, - ips: &mut RTrieSet, + ips: &mut NftData, domain_r: IpCacheKey, ) where T: Copy, + IpNet: From, { self.get_maybe_update_rev(domain_r, |val| { if let Some(val) = val { - ips.extend(val.0.iter().copied().map(|x| J::from(x))); + ips.extend(val.0.iter().copied().map(From::from)); } #[allow(unused_assignments)] let mut val = Some(ignore); @@ -381,10 +382,10 @@ impl ExampleMod { let rev_domain = IpCacheKey::from_split_rev_domain(rev_domain.into_iter()); self.caches .0 - .extend_set_with_domain(r.0.ips_mut(), rev_domain.clone()); + .extend_set_with_domain(&mut r.0, rev_domain.clone()); self.caches .1 - .extend_set_with_domain(r.1.ips_mut(), rev_domain.clone()); + .extend_set_with_domain(&mut r.1, rev_domain.clone()); } } } @@ -491,47 +492,46 @@ impl ExampleMod { split_domain: &[&[u8]], qnames: SmallVec<[usize; 5]>, ) -> Result<(), ()> { - println!("adding {ip4:?}/{ip6:?} for {split_domain:?} to {qnames:?}"); - if !ip4.is_empty() || !ip6.is_empty() { - let mut first = true; - let domain = match split_domain - .iter() - .copied() - .map(std::str::from_utf8) - .try_fold(String::new(), |mut s, comp| { - if first { - first = false; - } else { - s.push('.'); - } - s.push_str(comp?); - Ok::<_, std::str::Utf8Error>(s) - }) { - Ok(x) => x, - Err(err) => { - self.report("domain utf-8", err); - return Err(()); + if ip4.is_empty() && ip6.is_empty() { + return Ok(()); + } + let mut first = true; + let domain = match split_domain + .iter() + .copied() + .map(std::str::from_utf8) + .try_fold(String::new(), |mut s, comp| { + if first { + first = false; + } else { + s.push('.'); } - }; - let key = IpCacheKey::from_split_domain(split_domain.iter()); - let mut to_send: SmallVec<[IpNet; 8]> = SmallVec::new(); - to_send.extend(ip4.iter().copied().map(Ipv4Net::from).map(IpNet::from)); - to_send.extend(ip6.iter().copied().map(Ipv6Net::from).map(IpNet::from)); - let keep4 = !ip4.is_empty() && self.caches.0.set(&domain, key.clone(), ip4); - let keep6 = !ip6.is_empty() && self.caches.1.set(&domain, key, ip6); - to_send.retain(|x| x.addr().is_ipv4() && keep4 || x.addr().is_ipv6() && keep6); - if !to_send.is_empty() { - self.ruleset_queue - .as_ref() - .unwrap() - .send((qnames, to_send)) - .unwrap(); + s.push_str(comp?); + Ok::<_, std::str::Utf8Error>(s) + }) { + Ok(x) => x, + Err(err) => { + self.report("domain utf-8", err); + return Err(()); } + }; + let key = IpCacheKey::from_split_domain(split_domain.iter()); + let mut to_send: SmallVec<[IpNet; 8]> = SmallVec::new(); + to_send.extend(ip4.iter().copied().map(Ipv4Net::from).map(IpNet::from)); + to_send.extend(ip6.iter().copied().map(Ipv6Net::from).map(IpNet::from)); + let keep4 = !ip4.is_empty() && self.caches.0.set(&domain, key.clone(), ip4); + let keep6 = !ip6.is_empty() && self.caches.1.set(&domain, key, ip6); + to_send.retain(|x| x.addr().is_ipv4() && keep4 || x.addr().is_ipv6() && keep6); + if !to_send.is_empty() { + self.ruleset_queue + .as_ref() + .unwrap() + .send((qnames, to_send)) + .unwrap(); } Ok(()) } fn run_commands(&self, split_domain: &[&[u8]]) -> Option { - println!("{split_domain:?} {:?}", self.nft_token); if let Some(split_domain) = self.nft_token.as_ref().and_then(|token| { split_domain .split_last() diff --git a/src/nftables.rs b/src/nftables.rs index e3a7435..5ee736d 100644 --- a/src/nftables.rs +++ b/src/nftables.rs @@ -332,8 +332,8 @@ fn iter_ip_trie(trie: &RTrieSet) -> impl '_ + Iterator { } pub(crate) struct NftData { + all_ips: RTrieSet, ips: RTrieSet, - dirty: bool, set: Option, name: String, } @@ -343,7 +343,7 @@ impl NftData { Self { set: None, ips: RTrieSet::new(), - dirty: true, + all_ips: RTrieSet::new(), name: name.to_owned(), } } @@ -361,17 +361,24 @@ where pub fn verify(&mut self) -> bool { if !self.name.is_empty() && self.set.is_none() { self.ips = RTrieSet::new(); + self.all_ips = RTrieSet::new(); false } else { true } } + fn dirty(&self) -> bool { + usize::from(self.ips.len()) > 1 + } pub fn flush_changes( &mut self, socket: &mnl::Socket, flush_set: bool, ) -> Result<(), io::Error> { - if let Some(set) = self.set.as_mut().filter(|_| self.dirty) { + if !self.dirty() { + return Ok(()); + } + if let Some(set) = self.set.as_mut() { if flush_set { println!( "initializing set {} with ~{} ips (e.g. {:?})", @@ -380,7 +387,9 @@ where iter_ip_trie(&self.ips).next(), ); } - set.add_cidrs(socket, flush_set, iter_ip_trie(&self.ips).map(IpNet::from)) + let ret = set.add_cidrs(socket, flush_set, iter_ip_trie(&self.ips).map(IpNet::from)); + self.ips = RTrieSet::new(); + ret } else { Ok(()) } @@ -395,15 +404,12 @@ where !self.name.is_empty() } else { self.set.is_some() - }) && should_add(&self.ips, &ip) + }) && should_add(&self.all_ips, &ip) { self.ips.insert(ip); - self.dirty = true; + self.all_ips.insert(ip); } } - pub fn ips_mut(&mut self) -> &mut RTrieSet { - &mut self.ips - } #[cfg(test)] pub fn ip_count(&self) -> usize { iter_ip_trie(&self.ips).count()