some deduplication

This commit is contained in:
chayleaf 2024-08-12 22:58:50 +07:00
parent 14346134b5
commit 303b157557
Signed by: chayleaf
GPG key ID: 78171AD46227E68E
5 changed files with 170 additions and 154 deletions

View file

@ -1,5 +1,6 @@
use std::panic::{RefUnwindSafe, UnwindSafe}; use std::panic::{RefUnwindSafe, UnwindSafe};
use crate::unbound::ModuleExtState;
use crate::UnboundMod; use crate::UnboundMod;
macro_rules! impl_tuple { macro_rules! impl_tuple {
@ -29,9 +30,15 @@ macro_rules! impl_tuple {
qstate: &mut crate::unbound::ModuleQstate<Self::QstateData>, qstate: &mut crate::unbound::ModuleQstate<Self::QstateData>,
event: crate::unbound::ModuleEvent, event: crate::unbound::ModuleEvent,
entry: &mut crate::unbound::OutboundEntryMut, entry: &mut crate::unbound::OutboundEntryMut,
) { ) -> Option<ModuleExtState> {
self.0.operate(qstate, event, entry); #[allow(unused_mut)]
$(self.$i.operate(qstate, event, entry);)* let mut ret = self.0.operate(qstate, event, entry);
$(if let Some(state) = self.$i.operate(qstate, event, entry) {
if !matches!(ret, Some(ret) if ret.importance() >= state.importance()) {
ret = Some(state);
}
})*
ret
} }
fn get_mem(&self, env: &mut crate::unbound::ModuleEnv<Self::EnvData>) -> usize { fn get_mem(&self, env: &mut crate::unbound::ModuleEnv<Self::EnvData>) -> usize {
self.0.get_mem(env) $(* self.$i.get_mem(env))* self.0.get_mem(env) $(* self.$i.get_mem(env))*

View file

@ -63,7 +63,7 @@ impl<T: Hash + Eq> PrefixSet<T> {
} }
struct Iter<'a, T>( struct Iter<'a, T>(
SmallVec<[std::collections::hash_map::Iter<'a, T, PrefixSet<T>>; 8]>, SmallVec<[std::collections::hash_map::Iter<'a, T, PrefixSet<T>>; 9]>,
SmallVec<[&'a T; 8]>, SmallVec<[&'a T; 8]>,
); );
@ -73,9 +73,7 @@ impl<'a, T> Iterator for Iter<'a, T> {
while let Some(it) = self.0.last_mut() { while let Some(it) = self.0.last_mut() {
let Some((k, v)) = it.next() else { let Some((k, v)) = it.next() else {
self.0.pop(); self.0.pop();
if self.1.pop().is_none() { self.1.pop()?;
return None;
}
continue; continue;
}; };
self.1.push(k); self.1.push(k);

View file

@ -16,7 +16,7 @@ use std::{
use ctor::ctor; use ctor::ctor;
use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use iptrie::{IpPrefix, RTrieSet}; use iptrie::{IpPrefix, IpRootPrefix, RTrieSet};
use serde::{ use serde::{
de::{Error, Visitor}, de::{Error, Visitor},
Deserialize, Deserialize,
@ -87,8 +87,7 @@ struct ExampleMod {
nft_token: Option<String>, nft_token: Option<String>,
tmp_nft_token: Option<String>, tmp_nft_token: Option<String>,
nft_queries: HashMap<String, NftQuery>, nft_queries: HashMap<String, NftQuery>,
cache4: IpCache<Ipv4Addr>, caches: (IpCache<Ipv4Addr>, IpCache<Ipv6Addr>),
cache6: IpCache<Ipv6Addr>,
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
ruleset_queue: Option<mpsc::Sender<(SmallVec<[usize; 5]>, smallvec::SmallVec<[IpNet; 8]>)>>, ruleset_queue: Option<mpsc::Sender<(SmallVec<[usize; 5]>, smallvec::SmallVec<[IpNet; 8]>)>>,
error_lock: Mutex<()>, error_lock: Mutex<()>,
@ -123,31 +122,48 @@ impl<T> Default for IpCache<T> {
} }
impl<T> IpCache<T> { impl<T> IpCache<T> {
fn get_maybe_update_rev( fn extend_set_with_domain<J: IpPrefix + From<T>>(&self, ips: &mut RTrieSet<J>, domain_r: Domain)
where
T: Copy,
{
self.get_maybe_update_rev(domain_r, |val| {
if let Some(val) = val {
ips.extend(val.0.iter().copied().map(|x| J::from(x)));
}
fn ignore<T>(_: &mut smallvec::SmallVec<[T; 4]>) {}
#[allow(unused_assignments)]
let mut val = Some(ignore::<T>);
val = None;
val
})
}
fn get_maybe_update_rev<F: for<'a> FnOnce(&'a mut smallvec::SmallVec<[T; 4]>)>(
&self, &self,
domain_r: Domain, domain_r: Domain,
upd: impl FnOnce(Option<&smallvec::SmallVec<[T; 4]>>) -> Option<smallvec::SmallVec<[T; 4]>>, upd: impl FnOnce(Option<(&smallvec::SmallVec<[T; 4]>, &Mutex<()>, &AtomicBool)>) -> Option<F>,
) { ) {
let lock = self.0.read().unwrap(); let lock = self.0.read().unwrap();
let domain_r = IpCacheKey(domain_r); let domain_r = IpCacheKey(domain_r);
let key = lock.0.get(&domain_r).copied(); let key = lock.0.get(&domain_r).copied();
if let Some(val) = if let Some(key) = key { if let Some(val) = if let Some(x) = key.and_then(|key| lock.1.get(key)) {
upd(lock.1.get(key).map(|x| x.0.read().unwrap()).as_deref()) upd(Some((&x.0.read().unwrap(), &x.1, &x.2)))
} else { } else {
upd(None) upd(None)
} { } {
if let Some(key) = key { if let Some(key) = key {
*lock.1.get(key).unwrap().0.write().unwrap() = val; val(&mut *lock.1.get(key).unwrap().0.write().unwrap());
} else { } else {
drop(lock); drop(lock);
let mut lock = self.0.write().unwrap(); let mut lock = self.0.write().unwrap();
if let Some(key) = lock.0.get(&domain_r).copied() { if let Some(key) = lock.0.get(&domain_r).copied() {
*lock.1.get(key).unwrap().0.write().unwrap() = val; val(&mut *lock.1.get(key).unwrap().0.write().unwrap());
} else { } else {
let key = lock.1.len(); let key = lock.1.len();
lock.0.insert(domain_r, key); lock.0.insert(domain_r, key);
let mut v = SmallVec::new();
val(&mut v);
lock.1 lock.1
.push((RwLock::new(val), Mutex::new(()), AtomicBool::new(false))); .push((RwLock::new(v), Mutex::new(()), AtomicBool::new(true)));
} }
} }
} }
@ -156,62 +172,39 @@ impl<T> IpCache<T> {
impl<T: ToString + PartialEq> IpCache<T> { impl<T: ToString + PartialEq> IpCache<T> {
fn set(&self, domain: &str, domain_r: IpCacheKey, val: smallvec::SmallVec<[T; 4]>) -> bool { fn set(&self, domain: &str, domain_r: IpCacheKey, val: smallvec::SmallVec<[T; 4]>) -> bool {
let lock = self.0.read().unwrap(); let mut ret = true;
let key = lock.0.get(&domain_r).copied(); let ret1 = &mut ret;
let to_write = val
.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>()
.join("\n");
let mut path = self.1.clone(); let mut path = self.1.clone();
path.push(domain); path.push(domain);
let path1 = &path; self.get_maybe_update_rev(domain_r.0, |ips| {
let finish = move |_lock| { if let Some(ips) = ips.as_ref().filter(|x| x.0 == &val) {
let Ok(mut file) = File::create(path1) else { *ret1 = false;
return; if ips
}; .2
file.write_all(to_write.as_bytes()).unwrap_or(());
};
if let Some(key) = key {
let v = lock.1.get(key).unwrap();
let mut lock = v.0.write().unwrap();
if *lock == val {
if v.2
.compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed) .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
.is_ok() .is_ok()
{ {
let _ = filetime::set_file_mtime(path, filetime::FileTime::now()); let _ = filetime::set_file_mtime(path, filetime::FileTime::now());
} }
return false; return None;
} }
*lock = val; Some(|ips: &mut SmallVec<_>| {
finish(lock); let Ok(mut file) = File::create(path) else {
} else { *ips = val;
drop(lock); return;
let mut lock = self.0.write().unwrap();
let key = if let Some(key) = lock.0.get(&domain_r).copied() {
key
} else {
let key = lock.1.len();
lock.0.insert(domain_r, key);
lock.1
.push((RwLock::new(val), Mutex::new(()), AtomicBool::new(false)));
key
}; };
drop(lock); let to_write = val.iter().fold(String::new(), |mut s, ip| {
finish( if !s.is_empty() {
self.0 s.push('\n');
.read()
.unwrap()
.1
.get(key)
.unwrap()
.0
.write()
.unwrap(),
);
} }
true s.push_str(&ip.to_string());
s
});
file.write_all(to_write.as_bytes()).unwrap_or(());
*ips = val;
})
});
ret
} }
} }
@ -280,6 +273,47 @@ struct NftData<T: IpPrefix> {
name: String, name: String,
} }
impl<T: IpPrefix + IpRootPrefix + Helper> NftData<T>
where
IpNet: From<T>,
{
#[must_use]
fn verify(&mut self) -> bool {
if !self.name.is_empty() && self.set.is_none() {
self.ips = RTrieSet::new();
false
} else {
true
}
}
fn flush_changes(&mut self, socket: &mnl::Socket, flush_set: bool) -> Result<(), io::Error> {
if let Some(set) = self.set.as_mut().filter(|_| self.dirty) {
if flush_set {
println!(
"initializing set {} with ~{} ips (e.g. {:?})",
self.name,
self.ips.len(),
iter_ip_trie(&self.ips).next(),
);
}
set.add_cidrs(socket, flush_set, iter_ip_trie(&self.ips).map(IpNet::from))
} else {
Ok(())
}
}
fn extend(&mut self, ips: impl Iterator<Item = T>) {
for ip in ips {
self.insert(ip);
}
}
fn insert(&mut self, ip: T) {
if self.set.is_some() && should_add(&self.ips, &ip) {
self.ips.insert(ip);
self.dirty = true;
}
}
}
// SAFETY: set are None initially and are never actually sent // SAFETY: set are None initially and are never actually sent
// (and Set1 might be fine to send anyway actually) // (and Set1 might be fine to send anyway actually)
unsafe impl<T: IpPrefix + Send> Send for NftData<T> {} unsafe impl<T: IpPrefix + Send> Send for NftData<T> {}
@ -444,10 +478,10 @@ impl UnboundMod for ExampleMod {
} }
// load cached domains // load cached domains
if let Err(err) = ret.cache4.load(Path::new("/var/lib/unbound/domains4/")) { if let Err(err) = ret.caches.0.load(Path::new("/var/lib/unbound/domains4/")) {
ret.report("domains4", err); ret.report("domains4", err);
} }
if let Err(err) = ret.cache6.load(Path::new("/var/lib/unbound/domains6/")) { if let Err(err) = ret.caches.1.load(Path::new("/var/lib/unbound/domains6/")) {
ret.report("domains6", err); ret.report("domains6", err);
} }
@ -494,14 +528,14 @@ impl UnboundMod for ExampleMod {
println!("loading {base}/{k}_ips.json"); println!("loading {base}/{k}_ips.json");
match read_json::<Vec<IpNetDeser>>(file) { match read_json::<Vec<IpNetDeser>>(file) {
Ok(ips) => { Ok(ips) => {
r.0.ips.extend(ips.iter().filter_map(|x| { r.0.extend(ips.iter().filter_map(|x| {
if let IpNet::V4(x) = x.0 { if let IpNet::V4(x) = x.0 {
Some(x) Some(x)
} else { } else {
None None
} }
})); }));
r.1.ips.extend(ips.iter().filter_map(|x| { r.1.extend(ips.iter().filter_map(|x| {
if let IpNet::V6(x) = x.0 { if let IpNet::V6(x) = x.0 {
Some(x) Some(x)
} else { } else {
@ -520,18 +554,12 @@ impl UnboundMod for ExampleMod {
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(&b"."[..]) .join(&b"."[..])
.into(); .into();
ret.cache4.get_maybe_update_rev(rev_domain.clone(), |val| { ret.caches
if let Some(val) = val { .0
r.0.ips.extend(val.iter().map(|x| Ipv4Net::from(*x))); .extend_set_with_domain(&mut r.0.ips, rev_domain.clone());
} ret.caches
None .1
}); .extend_set_with_domain(&mut r.1.ips, rev_domain.clone());
ret.cache6.get_maybe_update_rev(rev_domain, |val| {
if let Some(val) = val {
r.1.ips.extend(val.iter().map(|x| Ipv6Net::from(*x)));
}
None
});
} }
} }
@ -574,53 +602,23 @@ impl UnboundMod for ExampleMod {
} }
} }
for ruleset in &mut rulesets { for ruleset in &mut rulesets {
if !ruleset.0.name.is_empty() && ruleset.0.set.is_none() { if !ruleset.0.verify() {
report(format!("set {} not found", ruleset.0.name)); report(format!("set {} not found", ruleset.0.name));
ruleset.0.ips = RTrieSet::new();
} }
if !ruleset.1.name.is_empty() && ruleset.1.set.is_none() { if !ruleset.1.verify() {
report(format!("set {} not found", ruleset.1.name)); report(format!("set {} not found", ruleset.1.name));
ruleset.1.ips = RTrieSet::new();
} }
} }
let mut first = true; let mut first = true;
loop { loop {
for ruleset in &mut rulesets { for ruleset in &mut rulesets {
if let Some(set) = ruleset.0.set.as_mut().filter(|_| ruleset.0.dirty) { if let Err(err) = ruleset.0.flush_changes(&socket, first) {
if first {
println!(
"initializing set {} with ~{} ips (e.g. {:?})",
ruleset.0.name,
ruleset.0.ips.len(),
iter_ip_trie(&ruleset.0.ips).next(),
);
}
if let Err(err) = set.add_cidrs(
&socket,
first,
iter_ip_trie(&ruleset.0.ips).map(IpNet::V4),
) {
report(err); report(err);
} }
} if let Err(err) = ruleset.1.flush_changes(&socket, first) {
if let Some(set) = ruleset.1.set.as_mut().filter(|_| ruleset.1.dirty) {
if first {
println!(
"initializing set {} with ~{} ips (e.g. {:?})",
ruleset.1.name,
ruleset.1.ips.len(),
iter_ip_trie(&ruleset.1.ips).next(),
);
}
if let Err(err) = set.add_cidrs(
&socket,
first,
iter_ip_trie(&ruleset.1.ips).map(IpNet::V6),
) {
report(err); report(err);
} }
} }
}
if first { if first {
println!("nftables init done"); println!("nftables init done");
first = false; first = false;
@ -634,18 +632,8 @@ impl UnboundMod for ExampleMod {
let ruleset = &mut rulesets[i]; let ruleset = &mut rulesets[i];
for ip1 in ips.iter().copied() { for ip1 in ips.iter().copied() {
match ip1 { match ip1 {
IpNet::V4(ip) => { IpNet::V4(ip) => ruleset.0.insert(ip),
if ruleset.0.set.is_some() && !should_add(&ruleset.0.ips, &ip) { IpNet::V6(ip) => ruleset.1.insert(ip),
ruleset.0.ips.insert(ip);
ruleset.0.dirty = true;
}
}
IpNet::V6(ip) => {
if ruleset.1.set.is_some() && !should_add(&ruleset.1.ips, &ip) {
ruleset.1.ips.insert(ip);
ruleset.1.dirty = true;
}
}
} }
} }
} }
@ -662,16 +650,14 @@ impl UnboundMod for ExampleMod {
qstate: &mut crate::unbound::ModuleQstate<Self::QstateData>, qstate: &mut crate::unbound::ModuleQstate<Self::QstateData>,
event: ModuleEvent, event: ModuleEvent,
_entry: &mut crate::unbound::OutboundEntryMut, _entry: &mut crate::unbound::OutboundEntryMut,
) { ) -> Option<ModuleExtState> {
match event { match event {
ModuleEvent::New | ModuleEvent::Pass => { ModuleEvent::New | ModuleEvent::Pass => {
qstate.set_ext_state(ModuleExtState::WaitModule); return Some(ModuleExtState::WaitModule);
return;
} }
ModuleEvent::ModDone => {} ModuleEvent::ModDone => {}
_ => { _ => {
qstate.set_ext_state(ModuleExtState::Error); return Some(ModuleExtState::Error);
return;
} }
} }
let info = qstate.qinfo_mut(); let info = qstate.qinfo_mut();
@ -734,8 +720,7 @@ impl UnboundMod for ExampleMod {
} }
} }
} }
qstate.set_ext_state(ModuleExtState::Finished); return Some(ModuleExtState::Finished);
return;
} else if let Some(rev_domain) = self } else if let Some(rev_domain) = self
.tmp_nft_token .tmp_nft_token
.as_ref() .as_ref()
@ -755,8 +740,7 @@ impl UnboundMod for ExampleMod {
} }
} }
} }
qstate.set_ext_state(ModuleExtState::Finished); return Some(ModuleExtState::Finished);
return;
} }
let split_rev_domain = rev_domain let split_rev_domain = rev_domain
.split(|x| *x == b'.') .split(|x| *x == b'.')
@ -769,8 +753,7 @@ impl UnboundMod for ExampleMod {
} }
} }
if qnames.is_empty() { if qnames.is_empty() {
qstate.set_ext_state(ModuleExtState::Finished); return Some(ModuleExtState::Finished);
return;
} }
if let Some(ret) = qstate.return_msg_mut() { if let Some(ret) = qstate.return_msg_mut() {
if let Some(rep) = ret.rep() { if let Some(rep) = ret.rep() {
@ -812,8 +795,7 @@ impl UnboundMod for ExampleMod {
} }
Err(err) => { Err(err) => {
self.report("domain utf-8", err); self.report("domain utf-8", err);
qstate.set_ext_state(ModuleExtState::Error); return Some(ModuleExtState::Error);
return;
} }
}; };
let mut split_rev_domain = split_rev_domain.into_iter(); let mut split_rev_domain = split_rev_domain.into_iter();
@ -829,13 +811,17 @@ impl UnboundMod for ExampleMod {
to_send.extend(ip4.iter().copied().map(Ipv4Net::from).map(IpNet::from)); to_send.extend(ip4.iter().copied().map(Ipv4Net::from).map(IpNet::from));
to_send.extend(ip6.iter().copied().map(Ipv6Net::from).map(IpNet::from)); to_send.extend(ip6.iter().copied().map(Ipv6Net::from).map(IpNet::from));
let keep4 = !ip4.is_empty() let keep4 = !ip4.is_empty()
&& self && self.caches.0.set(
.cache4 &domain,
.set(&domain, IpCacheKey(joined_rev_domain.clone()), ip4); IpCacheKey(joined_rev_domain.clone()),
ip4,
);
let keep6 = !ip6.is_empty() let keep6 = !ip6.is_empty()
&& self && self.caches.1.set(
.cache6 &domain,
.set(&domain, IpCacheKey(joined_rev_domain.clone()), ip6); IpCacheKey(joined_rev_domain.clone()),
ip6,
);
to_send to_send
.retain(|x| x.addr().is_ipv4() && keep4 || x.addr().is_ipv6() && keep6); .retain(|x| x.addr().is_ipv4() && keep4 || x.addr().is_ipv6() && keep6);
if !to_send.is_empty() { if !to_send.is_empty() {
@ -849,7 +835,7 @@ impl UnboundMod for ExampleMod {
} }
} }
} }
qstate.set_ext_state(ModuleExtState::Finished); Some(ModuleExtState::Finished)
} }
} }

View file

@ -1,11 +1,17 @@
use std::panic::{RefUnwindSafe, UnwindSafe}; use std::panic::{RefUnwindSafe, UnwindSafe};
use unbound::ModuleExtState;
#[allow( #[allow(
dead_code, dead_code,
improper_ctypes, improper_ctypes,
non_camel_case_types, non_camel_case_types,
non_snake_case, non_snake_case,
non_upper_case_globals, non_upper_case_globals,
unused_imports unused_imports,
clippy::useless_transmute,
clippy::type_complexity,
clippy::too_many_arguments,
clippy::upper_case_acronyms
)] )]
mod bindings; mod bindings;
mod combine; mod combine;
@ -33,7 +39,8 @@ pub trait UnboundMod: Send + Sync + Sized + RefUnwindSafe + UnwindSafe {
_qstate: &mut unbound::ModuleQstate<Self::QstateData>, _qstate: &mut unbound::ModuleQstate<Self::QstateData>,
_event: unbound::ModuleEvent, _event: unbound::ModuleEvent,
_entry: &mut unbound::OutboundEntryMut, _entry: &mut unbound::OutboundEntryMut,
) { ) -> Option<ModuleExtState> {
Some(ModuleExtState::Finished)
} }
fn inform_super( fn inform_super(
&self, &self,
@ -101,11 +108,14 @@ unsafe impl<T: UnboundMod> SealedUnboundMod for T {
entry: *mut bindings::outbound_entry, entry: *mut bindings::outbound_entry,
) { ) {
std::panic::catch_unwind(|| { std::panic::catch_unwind(|| {
self.operate( let mut qstate = unbound::ModuleQstate(qstate, id, Default::default());
&mut unbound::ModuleQstate(qstate, id, Default::default()), if let Some(ext_state) = self.operate(
&mut qstate,
event.into(), event.into(),
&mut unbound::OutboundEntryMut(entry, Default::default()), &mut unbound::OutboundEntryMut(entry, Default::default()),
) ) {
qstate.set_ext_state(ext_state);
}
}) })
.unwrap_or(()); .unwrap_or(());
} }

View file

@ -609,6 +609,21 @@ pub enum ModuleExtState {
Unknown = 99, Unknown = 99,
} }
impl ModuleExtState {
pub(crate) fn importance(&self) -> usize {
match *self {
Self::Unknown => 0,
Self::InitialState => 1,
Self::Finished => 2,
Self::WaitModule => 3,
Self::RestartNext => 4,
Self::WaitSubquery => 5,
Self::WaitReply => 6,
Self::Error => 7,
}
}
}
impl From<module_ext_state> for ModuleExtState { impl From<module_ext_state> for ModuleExtState {
fn from(value: module_ext_state) -> Self { fn from(value: module_ext_state) -> Self {
match value { match value {