From 303b157557776e98283442c9e4302cc4a34de79f Mon Sep 17 00:00:00 2001 From: chayleaf Date: Mon, 12 Aug 2024 22:58:50 +0700 Subject: [PATCH] some deduplication --- src/combine.rs | 13 ++- src/domain_tree.rs | 6 +- src/example.rs | 270 +++++++++++++++++++++------------------------ src/lib.rs | 20 +++- src/unbound.rs | 15 +++ 5 files changed, 170 insertions(+), 154 deletions(-) diff --git a/src/combine.rs b/src/combine.rs index e135094..434739f 100644 --- a/src/combine.rs +++ b/src/combine.rs @@ -1,5 +1,6 @@ use std::panic::{RefUnwindSafe, UnwindSafe}; +use crate::unbound::ModuleExtState; use crate::UnboundMod; macro_rules! impl_tuple { @@ -29,9 +30,15 @@ macro_rules! impl_tuple { qstate: &mut crate::unbound::ModuleQstate, event: crate::unbound::ModuleEvent, entry: &mut crate::unbound::OutboundEntryMut, - ) { - self.0.operate(qstate, event, entry); - $(self.$i.operate(qstate, event, entry);)* + ) -> Option { + #[allow(unused_mut)] + let mut ret = self.0.operate(qstate, event, entry); + $(if let Some(state) = self.$i.operate(qstate, event, entry) { + if !matches!(ret, Some(ret) if ret.importance() >= state.importance()) { + ret = Some(state); + } + })* + ret } fn get_mem(&self, env: &mut crate::unbound::ModuleEnv) -> usize { self.0.get_mem(env) $(* self.$i.get_mem(env))* diff --git a/src/domain_tree.rs b/src/domain_tree.rs index eaf47c2..cfebae9 100644 --- a/src/domain_tree.rs +++ b/src/domain_tree.rs @@ -63,7 +63,7 @@ impl PrefixSet { } struct Iter<'a, T>( - SmallVec<[std::collections::hash_map::Iter<'a, T, PrefixSet>; 8]>, + SmallVec<[std::collections::hash_map::Iter<'a, T, PrefixSet>; 9]>, SmallVec<[&'a T; 8]>, ); @@ -73,9 +73,7 @@ impl<'a, T> Iterator for Iter<'a, T> { while let Some(it) = self.0.last_mut() { let Some((k, v)) = it.next() else { self.0.pop(); - if self.1.pop().is_none() { - return None; - } + self.1.pop()?; continue; }; self.1.push(k); diff --git a/src/example.rs b/src/example.rs index aad2158..fa491c0 100644 --- a/src/example.rs +++ b/src/example.rs @@ -16,7 +16,7 @@ use std::{ use ctor::ctor; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; -use iptrie::{IpPrefix, RTrieSet}; +use iptrie::{IpPrefix, IpRootPrefix, RTrieSet}; use serde::{ de::{Error, Visitor}, Deserialize, @@ -87,8 +87,7 @@ struct ExampleMod { nft_token: Option, tmp_nft_token: Option, nft_queries: HashMap, - cache4: IpCache, - cache6: IpCache, + caches: (IpCache, IpCache), #[allow(clippy::type_complexity)] ruleset_queue: Option, smallvec::SmallVec<[IpNet; 8]>)>>, error_lock: Mutex<()>, @@ -123,31 +122,48 @@ impl Default for IpCache { } impl IpCache { - fn get_maybe_update_rev( + fn extend_set_with_domain>(&self, ips: &mut RTrieSet, domain_r: Domain) + where + T: Copy, + { + self.get_maybe_update_rev(domain_r, |val| { + if let Some(val) = val { + ips.extend(val.0.iter().copied().map(|x| J::from(x))); + } + fn ignore(_: &mut smallvec::SmallVec<[T; 4]>) {} + #[allow(unused_assignments)] + let mut val = Some(ignore::); + val = None; + val + }) + } + fn get_maybe_update_rev FnOnce(&'a mut smallvec::SmallVec<[T; 4]>)>( &self, domain_r: Domain, - upd: impl FnOnce(Option<&smallvec::SmallVec<[T; 4]>>) -> Option>, + upd: impl FnOnce(Option<(&smallvec::SmallVec<[T; 4]>, &Mutex<()>, &AtomicBool)>) -> Option, ) { let lock = self.0.read().unwrap(); let domain_r = IpCacheKey(domain_r); let key = lock.0.get(&domain_r).copied(); - if let Some(val) = if let Some(key) = key { - upd(lock.1.get(key).map(|x| x.0.read().unwrap()).as_deref()) + if let Some(val) = if let Some(x) = key.and_then(|key| lock.1.get(key)) { + upd(Some((&x.0.read().unwrap(), &x.1, &x.2))) } else { upd(None) } { if let Some(key) = key { - *lock.1.get(key).unwrap().0.write().unwrap() = val; + val(&mut *lock.1.get(key).unwrap().0.write().unwrap()); } else { drop(lock); let mut lock = self.0.write().unwrap(); if let Some(key) = lock.0.get(&domain_r).copied() { - *lock.1.get(key).unwrap().0.write().unwrap() = val; + val(&mut *lock.1.get(key).unwrap().0.write().unwrap()); } else { let key = lock.1.len(); lock.0.insert(domain_r, key); + let mut v = SmallVec::new(); + val(&mut v); lock.1 - .push((RwLock::new(val), Mutex::new(()), AtomicBool::new(false))); + .push((RwLock::new(v), Mutex::new(()), AtomicBool::new(true))); } } } @@ -156,62 +172,39 @@ impl IpCache { impl IpCache { fn set(&self, domain: &str, domain_r: IpCacheKey, val: smallvec::SmallVec<[T; 4]>) -> bool { - let lock = self.0.read().unwrap(); - let key = lock.0.get(&domain_r).copied(); - let to_write = val - .iter() - .map(|x| x.to_string()) - .collect::>() - .join("\n"); + let mut ret = true; + let ret1 = &mut ret; let mut path = self.1.clone(); path.push(domain); - let path1 = &path; - let finish = move |_lock| { - let Ok(mut file) = File::create(path1) else { - return; - }; - file.write_all(to_write.as_bytes()).unwrap_or(()); - }; - if let Some(key) = key { - let v = lock.1.get(key).unwrap(); - let mut lock = v.0.write().unwrap(); - if *lock == val { - if v.2 + self.get_maybe_update_rev(domain_r.0, |ips| { + if let Some(ips) = ips.as_ref().filter(|x| x.0 == &val) { + *ret1 = false; + if ips + .2 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed) .is_ok() { let _ = filetime::set_file_mtime(path, filetime::FileTime::now()); } - return false; + return None; } - *lock = val; - finish(lock); - } else { - drop(lock); - let mut lock = self.0.write().unwrap(); - let key = if let Some(key) = lock.0.get(&domain_r).copied() { - key - } else { - let key = lock.1.len(); - lock.0.insert(domain_r, key); - lock.1 - .push((RwLock::new(val), Mutex::new(()), AtomicBool::new(false))); - key - }; - drop(lock); - finish( - self.0 - .read() - .unwrap() - .1 - .get(key) - .unwrap() - .0 - .write() - .unwrap(), - ); - } - true + Some(|ips: &mut SmallVec<_>| { + let Ok(mut file) = File::create(path) else { + *ips = val; + return; + }; + let to_write = val.iter().fold(String::new(), |mut s, ip| { + if !s.is_empty() { + s.push('\n'); + } + s.push_str(&ip.to_string()); + s + }); + file.write_all(to_write.as_bytes()).unwrap_or(()); + *ips = val; + }) + }); + ret } } @@ -280,6 +273,47 @@ struct NftData { name: String, } +impl NftData +where + IpNet: From, +{ + #[must_use] + fn verify(&mut self) -> bool { + if !self.name.is_empty() && self.set.is_none() { + self.ips = RTrieSet::new(); + false + } else { + true + } + } + fn flush_changes(&mut self, socket: &mnl::Socket, flush_set: bool) -> Result<(), io::Error> { + if let Some(set) = self.set.as_mut().filter(|_| self.dirty) { + if flush_set { + println!( + "initializing set {} with ~{} ips (e.g. {:?})", + self.name, + self.ips.len(), + iter_ip_trie(&self.ips).next(), + ); + } + set.add_cidrs(socket, flush_set, iter_ip_trie(&self.ips).map(IpNet::from)) + } else { + Ok(()) + } + } + fn extend(&mut self, ips: impl Iterator) { + for ip in ips { + self.insert(ip); + } + } + fn insert(&mut self, ip: T) { + if self.set.is_some() && should_add(&self.ips, &ip) { + self.ips.insert(ip); + self.dirty = true; + } + } +} + // SAFETY: set are None initially and are never actually sent // (and Set1 might be fine to send anyway actually) unsafe impl Send for NftData {} @@ -444,10 +478,10 @@ impl UnboundMod for ExampleMod { } // load cached domains - if let Err(err) = ret.cache4.load(Path::new("/var/lib/unbound/domains4/")) { + if let Err(err) = ret.caches.0.load(Path::new("/var/lib/unbound/domains4/")) { ret.report("domains4", err); } - if let Err(err) = ret.cache6.load(Path::new("/var/lib/unbound/domains6/")) { + if let Err(err) = ret.caches.1.load(Path::new("/var/lib/unbound/domains6/")) { ret.report("domains6", err); } @@ -494,14 +528,14 @@ impl UnboundMod for ExampleMod { println!("loading {base}/{k}_ips.json"); match read_json::>(file) { Ok(ips) => { - r.0.ips.extend(ips.iter().filter_map(|x| { + r.0.extend(ips.iter().filter_map(|x| { if let IpNet::V4(x) = x.0 { Some(x) } else { None } })); - r.1.ips.extend(ips.iter().filter_map(|x| { + r.1.extend(ips.iter().filter_map(|x| { if let IpNet::V6(x) = x.0 { Some(x) } else { @@ -520,18 +554,12 @@ impl UnboundMod for ExampleMod { .collect::>() .join(&b"."[..]) .into(); - ret.cache4.get_maybe_update_rev(rev_domain.clone(), |val| { - if let Some(val) = val { - r.0.ips.extend(val.iter().map(|x| Ipv4Net::from(*x))); - } - None - }); - ret.cache6.get_maybe_update_rev(rev_domain, |val| { - if let Some(val) = val { - r.1.ips.extend(val.iter().map(|x| Ipv6Net::from(*x))); - } - None - }); + ret.caches + .0 + .extend_set_with_domain(&mut r.0.ips, rev_domain.clone()); + ret.caches + .1 + .extend_set_with_domain(&mut r.1.ips, rev_domain.clone()); } } @@ -574,51 +602,21 @@ impl UnboundMod for ExampleMod { } } for ruleset in &mut rulesets { - if !ruleset.0.name.is_empty() && ruleset.0.set.is_none() { + if !ruleset.0.verify() { report(format!("set {} not found", ruleset.0.name)); - ruleset.0.ips = RTrieSet::new(); } - if !ruleset.1.name.is_empty() && ruleset.1.set.is_none() { + if !ruleset.1.verify() { report(format!("set {} not found", ruleset.1.name)); - ruleset.1.ips = RTrieSet::new(); } } let mut first = true; loop { for ruleset in &mut rulesets { - if let Some(set) = ruleset.0.set.as_mut().filter(|_| ruleset.0.dirty) { - if first { - println!( - "initializing set {} with ~{} ips (e.g. {:?})", - ruleset.0.name, - ruleset.0.ips.len(), - iter_ip_trie(&ruleset.0.ips).next(), - ); - } - if let Err(err) = set.add_cidrs( - &socket, - first, - iter_ip_trie(&ruleset.0.ips).map(IpNet::V4), - ) { - report(err); - } + if let Err(err) = ruleset.0.flush_changes(&socket, first) { + report(err); } - if let Some(set) = ruleset.1.set.as_mut().filter(|_| ruleset.1.dirty) { - if first { - println!( - "initializing set {} with ~{} ips (e.g. {:?})", - ruleset.1.name, - ruleset.1.ips.len(), - iter_ip_trie(&ruleset.1.ips).next(), - ); - } - if let Err(err) = set.add_cidrs( - &socket, - first, - iter_ip_trie(&ruleset.1.ips).map(IpNet::V6), - ) { - report(err); - } + if let Err(err) = ruleset.1.flush_changes(&socket, first) { + report(err); } } if first { @@ -634,18 +632,8 @@ impl UnboundMod for ExampleMod { let ruleset = &mut rulesets[i]; for ip1 in ips.iter().copied() { match ip1 { - IpNet::V4(ip) => { - if ruleset.0.set.is_some() && !should_add(&ruleset.0.ips, &ip) { - ruleset.0.ips.insert(ip); - ruleset.0.dirty = true; - } - } - IpNet::V6(ip) => { - if ruleset.1.set.is_some() && !should_add(&ruleset.1.ips, &ip) { - ruleset.1.ips.insert(ip); - ruleset.1.dirty = true; - } - } + IpNet::V4(ip) => ruleset.0.insert(ip), + IpNet::V6(ip) => ruleset.1.insert(ip), } } } @@ -662,16 +650,14 @@ impl UnboundMod for ExampleMod { qstate: &mut crate::unbound::ModuleQstate, event: ModuleEvent, _entry: &mut crate::unbound::OutboundEntryMut, - ) { + ) -> Option { match event { ModuleEvent::New | ModuleEvent::Pass => { - qstate.set_ext_state(ModuleExtState::WaitModule); - return; + return Some(ModuleExtState::WaitModule); } ModuleEvent::ModDone => {} _ => { - qstate.set_ext_state(ModuleExtState::Error); - return; + return Some(ModuleExtState::Error); } } let info = qstate.qinfo_mut(); @@ -734,8 +720,7 @@ impl UnboundMod for ExampleMod { } } } - qstate.set_ext_state(ModuleExtState::Finished); - return; + return Some(ModuleExtState::Finished); } else if let Some(rev_domain) = self .tmp_nft_token .as_ref() @@ -755,8 +740,7 @@ impl UnboundMod for ExampleMod { } } } - qstate.set_ext_state(ModuleExtState::Finished); - return; + return Some(ModuleExtState::Finished); } let split_rev_domain = rev_domain .split(|x| *x == b'.') @@ -769,8 +753,7 @@ impl UnboundMod for ExampleMod { } } if qnames.is_empty() { - qstate.set_ext_state(ModuleExtState::Finished); - return; + return Some(ModuleExtState::Finished); } if let Some(ret) = qstate.return_msg_mut() { if let Some(rep) = ret.rep() { @@ -812,8 +795,7 @@ impl UnboundMod for ExampleMod { } Err(err) => { self.report("domain utf-8", err); - qstate.set_ext_state(ModuleExtState::Error); - return; + return Some(ModuleExtState::Error); } }; let mut split_rev_domain = split_rev_domain.into_iter(); @@ -829,13 +811,17 @@ impl UnboundMod for ExampleMod { to_send.extend(ip4.iter().copied().map(Ipv4Net::from).map(IpNet::from)); to_send.extend(ip6.iter().copied().map(Ipv6Net::from).map(IpNet::from)); let keep4 = !ip4.is_empty() - && self - .cache4 - .set(&domain, IpCacheKey(joined_rev_domain.clone()), ip4); + && self.caches.0.set( + &domain, + IpCacheKey(joined_rev_domain.clone()), + ip4, + ); let keep6 = !ip6.is_empty() - && self - .cache6 - .set(&domain, IpCacheKey(joined_rev_domain.clone()), ip6); + && self.caches.1.set( + &domain, + IpCacheKey(joined_rev_domain.clone()), + ip6, + ); to_send .retain(|x| x.addr().is_ipv4() && keep4 || x.addr().is_ipv6() && keep6); if !to_send.is_empty() { @@ -849,7 +835,7 @@ impl UnboundMod for ExampleMod { } } } - qstate.set_ext_state(ModuleExtState::Finished); + Some(ModuleExtState::Finished) } } diff --git a/src/lib.rs b/src/lib.rs index d470440..e13456d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,17 @@ use std::panic::{RefUnwindSafe, UnwindSafe}; + +use unbound::ModuleExtState; #[allow( dead_code, improper_ctypes, non_camel_case_types, non_snake_case, non_upper_case_globals, - unused_imports + unused_imports, + clippy::useless_transmute, + clippy::type_complexity, + clippy::too_many_arguments, + clippy::upper_case_acronyms )] mod bindings; mod combine; @@ -33,7 +39,8 @@ pub trait UnboundMod: Send + Sync + Sized + RefUnwindSafe + UnwindSafe { _qstate: &mut unbound::ModuleQstate, _event: unbound::ModuleEvent, _entry: &mut unbound::OutboundEntryMut, - ) { + ) -> Option { + Some(ModuleExtState::Finished) } fn inform_super( &self, @@ -101,11 +108,14 @@ unsafe impl SealedUnboundMod for T { entry: *mut bindings::outbound_entry, ) { std::panic::catch_unwind(|| { - self.operate( - &mut unbound::ModuleQstate(qstate, id, Default::default()), + let mut qstate = unbound::ModuleQstate(qstate, id, Default::default()); + if let Some(ext_state) = self.operate( + &mut qstate, event.into(), &mut unbound::OutboundEntryMut(entry, Default::default()), - ) + ) { + qstate.set_ext_state(ext_state); + } }) .unwrap_or(()); } diff --git a/src/unbound.rs b/src/unbound.rs index 791e79b..6ac6812 100644 --- a/src/unbound.rs +++ b/src/unbound.rs @@ -609,6 +609,21 @@ pub enum ModuleExtState { Unknown = 99, } +impl ModuleExtState { + pub(crate) fn importance(&self) -> usize { + match *self { + Self::Unknown => 0, + Self::InitialState => 1, + Self::Finished => 2, + Self::WaitModule => 3, + Self::RestartNext => 4, + Self::WaitSubquery => 5, + Self::WaitReply => 6, + Self::Error => 7, + } + } +} + impl From for ModuleExtState { fn from(value: module_ext_state) -> Self { match value {