diff --git a/src/combine.rs b/src/combine.rs index 434739f..16cae70 100644 --- a/src/combine.rs +++ b/src/combine.rs @@ -14,20 +14,20 @@ macro_rules! impl_tuple { { type EnvData = A::EnvData; type QstateData = A::QstateData; - fn init(env: &mut crate::unbound::ModuleEnv) -> Result { + fn init(env: &mut crate::unbound::ModuleEnvMut) -> Result { Ok((A::init(env)?, $($t::init(env)?, )*)) } - fn clear(&self, qstate: &mut crate::unbound::ModuleQstate) { + fn clear(&self, qstate: &mut crate::unbound::ModuleQstateMut) { self.0.clear(qstate); $(self.$i.clear(qstate);)* } - fn deinit(self, env: &mut crate::unbound::ModuleEnv) { + fn deinit(self, env: &mut crate::unbound::ModuleEnvMut) { self.0.deinit(env); $(self.$i.deinit(env);)* } fn operate( &self, - qstate: &mut crate::unbound::ModuleQstate, + qstate: &mut crate::unbound::ModuleQstateMut, event: crate::unbound::ModuleEvent, entry: &mut crate::unbound::OutboundEntryMut, ) -> Option { @@ -40,13 +40,13 @@ macro_rules! impl_tuple { })* ret } - fn get_mem(&self, env: &mut crate::unbound::ModuleEnv) -> usize { + fn get_mem(&self, env: &mut crate::unbound::ModuleEnvMut) -> usize { self.0.get_mem(env) $(* self.$i.get_mem(env))* } fn inform_super( &self, - qstate: &mut crate::unbound::ModuleQstate, - super_qstate: &mut crate::unbound::ModuleQstate, + qstate: &mut crate::unbound::ModuleQstateMut, + super_qstate: &mut crate::unbound::ModuleQstateMut, ) { self.0.inform_super(qstate, super_qstate); $(self.$i.inform_super(qstate, super_qstate);)* diff --git a/src/domain_tree.rs b/src/domain_tree.rs index 67d67a3..417353f 100644 --- a/src/domain_tree.rs +++ b/src/domain_tree.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, hash::Hash}; +use std::{borrow::Borrow, collections::HashMap, hash::Hash}; use smallvec::{smallvec, SmallVec}; @@ -36,9 +36,10 @@ impl PrefixSet { } } } - pub fn contains<'a>(&self, val: impl IntoIterator) -> bool + pub fn contains<'a, Y>(&self, val: impl IntoIterator) -> bool where - T: 'a, + T: 'a + Borrow, + Y: 'a + ?Sized + Eq + Hash, { match self { Self::Leaf => true, diff --git a/src/example.rs b/src/example.rs index e1485cb..b8c27c1 100644 --- a/src/example.rs +++ b/src/example.rs @@ -99,8 +99,26 @@ struct IpCache( ); #[repr(transparent)] -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] struct IpCacheKey(Domain); +impl IpCacheKey { + fn from_split_domain>( + split_domain: impl DoubleEndedIterator + Iterator, + ) -> Self { + Self::from_split_rev_domain(split_domain.rev()) + } + fn from_split_rev_domain>(split_rev_domain: impl Iterator) -> Self { + let mut first = true; + Self(split_rev_domain.fold(Domain::new(), |mut ret, seg| { + if first { + first = false; + ret.push(b'.'); + } + ret.extend_from_slice(seg.as_ref()); + ret + })) + } +} impl radix_trie::TrieKey for IpCacheKey { fn encode_bytes(&self) -> Vec { self.0.to_vec() @@ -118,8 +136,11 @@ impl Default for IpCache { fn ignore(_: &mut smallvec::SmallVec<[T; 4]>) {} impl IpCache { - fn extend_set_with_domain>(&self, ips: &mut RTrieSet, domain_r: Domain) - where + fn extend_set_with_domain>( + &self, + ips: &mut RTrieSet, + domain_r: IpCacheKey, + ) where T: Copy, { self.get_maybe_update_rev(domain_r, |val| { @@ -134,11 +155,10 @@ impl IpCache { } fn get_maybe_update_rev FnOnce(&'a mut smallvec::SmallVec<[T; 4]>)>( &self, - domain_r: Domain, + domain_r: IpCacheKey, upd: impl FnOnce(Option<(&smallvec::SmallVec<[T; 4]>, &Mutex<()>, &AtomicBool)>) -> Option, ) { let lock = self.0.read().unwrap(); - let domain_r = IpCacheKey(domain_r); let key = lock.0.get(&domain_r).copied(); if let Some(val) = if let Some(x) = key.and_then(|key| lock.1.get(key)) { upd(Some((&x.0.read().unwrap(), &x.1, &x.2))) @@ -171,7 +191,7 @@ impl IpCache { let ret1 = &mut ret; let mut path = self.1.clone(); path.push(domain); - self.get_maybe_update_rev(domain_r.0, |ips| { + self.get_maybe_update_rev(domain_r, |ips| { if let Some(ips) = ips.as_ref().filter(|x| x.0 == &val) { *ret1 = false; if ips @@ -234,15 +254,7 @@ impl IpCache { let Ok(reader) = std::fs::File::open(entry.path()) else { continue; }; - let domain_r = IpCacheKey( - domain - .split('.') - .rev() - .map(|x| x.as_bytes()) - .collect::>() - .join(&b"."[..]) - .into(), - ); + let domain_r = IpCacheKey::from_split_domain(domain.split('.')); let mut reader = BufReader::new(reader); let mut line = String::new(); let mut ips = SmallVec::new(); @@ -280,7 +292,7 @@ pub(crate) const DATA_PREFIX: &str = "unbound-mod-test-data"; pub(crate) const CONFIG_PREFIX: &str = "unbound-mod-test-config"; #[cfg(not(debug_assertions))] -pub(crate) const PATH_PREFIX: &str = "/var/lib/unbound"; +pub(crate) const DATA_PREFIX: &str = "/var/lib/unbound"; #[cfg(not(debug_assertions))] pub(crate) const CONFIG_PREFIX: &str = "/etc/unbound"; @@ -375,11 +387,7 @@ impl ExampleMod { } println!("loading cached domain ips for {k}"); for rev_domain in v_domains.iter() { - let rev_domain: SmallVec<_> = rev_domain - .map(|x| x.as_slice()) - .collect::>() - .join(&b"."[..]) - .into(); + let rev_domain = IpCacheKey::from_split_rev_domain(rev_domain.into_iter()); self.caches .0 .extend_set_with_domain(r.0.ips_mut(), rev_domain.clone()); @@ -456,7 +464,7 @@ impl ExampleMod { } fn handle_reply_info( &self, - split_rev_domain: SmallVec<[DomainSeg; 5]>, + split_domain: &[&[u8]], qnames: SmallVec<[usize; 5]>, rep: &ReplyInfo<'_>, ) -> Result<(), ()> { @@ -485,91 +493,80 @@ impl ExampleMod { } } } - self.add_ips(ip4, ip6, split_rev_domain, qnames) + self.add_ips(ip4, ip6, split_domain, qnames) } fn add_ips( &self, ip4: SmallVec<[Ipv4Addr; 4]>, ip6: SmallVec<[Ipv6Addr; 4]>, - split_rev_domain: SmallVec<[DomainSeg; 5]>, + split_domain: &[&[u8]], qnames: SmallVec<[usize; 5]>, ) -> Result<(), ()> { + println!("adding {ip4:?}/{ip6:?} for {split_domain:?} to {qnames:?}"); if !ip4.is_empty() || !ip6.is_empty() { - let domain = match split_rev_domain + let domain = match split_domain .iter() - .rev() - .map(|x| String::from_utf8(x.to_vec()).map(|x| x + ".")) - .collect::>() - { - Ok(mut x) => { - x.pop(); - x - } + .copied() + .map(std::str::from_utf8) + .try_fold(String::new(), |mut s, comp| { + if !s.is_empty() { + s.push('.'); + } + s.push_str(comp?); + Ok::<_, std::str::Utf8Error>(s) + }) { + Ok(x) => x, Err(err) => { self.report("domain utf-8", err); return Err(()); } }; - let mut split_rev_domain = split_rev_domain.into_iter(); - if let Some(first) = split_rev_domain.next() { - let first: Domain = first.to_vec().into(); - let joined_rev_domain = split_rev_domain.fold(first, |mut res, mut next| { - res.push(b'.'); - res.append(&mut next); - res - }); - let mut to_send: SmallVec<[IpNet; 8]> = SmallVec::new(); - to_send.extend(ip4.iter().copied().map(Ipv4Net::from).map(IpNet::from)); - to_send.extend(ip6.iter().copied().map(Ipv6Net::from).map(IpNet::from)); - let keep4 = !ip4.is_empty() - && self - .caches - .0 - .set(&domain, IpCacheKey(joined_rev_domain.clone()), ip4); - let keep6 = !ip6.is_empty() - && self - .caches - .1 - .set(&domain, IpCacheKey(joined_rev_domain.clone()), ip6); - to_send.retain(|x| x.addr().is_ipv4() && keep4 || x.addr().is_ipv6() && keep6); - if !to_send.is_empty() { - self.ruleset_queue - .as_ref() - .unwrap() - .send((qnames, to_send)) - .unwrap(); - } + let key = IpCacheKey::from_split_domain(split_domain.iter()); + let mut to_send: SmallVec<[IpNet; 8]> = SmallVec::new(); + to_send.extend(ip4.iter().copied().map(Ipv4Net::from).map(IpNet::from)); + to_send.extend(ip6.iter().copied().map(Ipv6Net::from).map(IpNet::from)); + let keep4 = !ip4.is_empty() && self.caches.0.set(&domain, key.clone(), ip4); + let keep6 = !ip6.is_empty() && self.caches.1.set(&domain, key, ip6); + to_send.retain(|x| x.addr().is_ipv4() && keep4 || x.addr().is_ipv6() && keep6); + if !to_send.is_empty() { + self.ruleset_queue + .as_ref() + .unwrap() + .send((qnames, to_send)) + .unwrap(); } } Ok(()) } - fn run_commands(&self, rev_domain: &[u8]) -> Option { - if let Some(rev_domain) = self - .nft_token - .as_ref() - .and_then(|token| rev_domain.strip_prefix(token.as_bytes())) - { + fn run_commands(&self, split_domain: &[&[u8]]) -> Option { + if let Some(split_domain) = self.nft_token.as_ref().and_then(|token| { + split_domain + .split_last() + .filter(|(a, _)| **a == token.as_bytes()) + .map(|(_, b)| b) + }) { for (qname, query) in self.nft_queries.iter() { - if query.dynamic && rev_domain.starts_with(qname.as_bytes()) { - if let Some(rev_domain) = - rev_domain.strip_prefix((qname.to_owned() + ".").as_bytes()) + if query.dynamic { + if let Some(split_domain) = split_domain + .split_last() + .filter(|(a, _)| **a == qname.as_bytes()) + .map(|(_, b)| b) { - let rev_domain = rev_domain - .split(|x| *x == b'.') - .map(|x| x.into()) - .collect::>(); let mut domains = query.domains.write().unwrap(); - if domains.insert(rev_domain.clone()) { + if domains.insert(split_domain.iter().copied().rev().map(From::from)) { drop(domains); let file_name = format!("{DATA_PREFIX}/{qname}_domains.json"); - let domain = match String::from_utf8( - rev_domain - .iter() - .rev() - .map(|x| x.as_slice()) - .collect::>() - .join(&b"."[..]), - ) { + let domain = match split_domain + .iter() + .copied() + .map(std::str::from_utf8) + .try_fold(String::new(), |mut s, comp| { + if !s.is_empty() { + s.push('.'); + } + s.push_str(comp?); + Ok::<_, std::str::Utf8Error>(s) + }) { Ok(x) => x, Err(err) => { self.report("domain utf-8", err); @@ -577,6 +574,7 @@ impl ExampleMod { } }; let _lock = self.domains_write_lock.lock().unwrap(); + println!("adding {domain} to {qname}"); let mut old: Vec = if let Ok(file) = File::open(&file_name) { match read_json(file) { Ok(x) => x, @@ -602,22 +600,21 @@ impl ExampleMod { } } return Some(ModuleExtState::Finished); - } else if let Some(rev_domain) = self - .tmp_nft_token - .as_ref() - .and_then(|token| rev_domain.strip_prefix(token.as_bytes())) - { + } else if let Some(split_domain) = self.tmp_nft_token.as_ref().and_then(|token| { + split_domain + .split_last() + .filter(|(a, _)| **a == token.as_bytes()) + .map(|(_, b)| b) + }) { for (qname, query) in self.nft_queries.iter() { - if query.dynamic && rev_domain.starts_with(qname.as_bytes()) { - if let Some(rev_domain) = - rev_domain.strip_prefix((qname.to_owned() + ".").as_bytes()) + if query.dynamic { + if let Some(split_domain) = split_domain + .split_last() + .filter(|(a, _)| **a == qname.as_bytes()) + .map(|(_, b)| b) { - let rev_domain = rev_domain - .split(|x| *x == b'.') - .map(|x| x.into()) - .collect::>(); let mut domains = query.domains.write().unwrap(); - domains.insert(rev_domain.clone()); + domains.insert(split_domain.iter().copied().rev().map(From::from)); } } } @@ -625,10 +622,15 @@ impl ExampleMod { } None } - fn get_qnames(&self, split_rev_domain: &SmallVec<[DomainSeg; 5]>) -> SmallVec<[usize; 5]> { + fn get_qnames(&self, split_domain: &[&[u8]]) -> SmallVec<[usize; 5]> { let mut qnames: SmallVec<[usize; 5]> = SmallVec::new(); for query in self.nft_queries.values() { - if query.domains.read().unwrap().contains(split_rev_domain) { + if query + .domains + .read() + .unwrap() + .contains(split_domain.iter().copied().rev().map(From::from)) + { qnames.push(query.index); } } @@ -680,17 +682,31 @@ fn read_json Deserialize<'a>>(mut f: File) -> Result SmallVec<[&[u8]; 8]> { + let mut i = 0; + let mut ret = SmallVec::new(); + while let Some(val) = domain.get(i).map(|x| *x as usize) { + i += 1; + if let Some(val) = domain.get(i..i + val) { + ret.push(val); + } + i += val; + } + ret +} + impl UnboundMod for ExampleMod { type EnvData = (); type QstateData = (); - fn init(_env: &mut crate::unbound::ModuleEnv) -> Result { + fn init(_env: &mut crate::unbound::ModuleEnvMut) -> Result { Self::new() } fn operate( &self, - qstate: &mut crate::unbound::ModuleQstate, + qstate: &mut crate::unbound::ModuleQstateMut, event: ModuleEvent, _entry: &mut crate::unbound::OutboundEntryMut, ) -> Option { @@ -703,26 +719,21 @@ impl UnboundMod for ExampleMod { return Some(ModuleExtState::Error); } } - let info = qstate.qinfo_mut(); + let info = qstate.qinfo(); let name = info.qname().to_bytes(); - let rev_domain = name.strip_suffix(b".").unwrap_or(name); - if let Some(val) = self.run_commands(rev_domain) { + // let rev_domain = name.strip_suffix(b".").unwrap_or(name); + let split_domain = unwire_domain(name); + println!("handling {split_domain:?}"); + if let Some(val) = self.run_commands(&split_domain) { return Some(val); } - let split_rev_domain = rev_domain - .split(|x| *x == b'.') - .map(|x| x.into()) - .collect::>(); - let qnames = self.get_qnames(&split_rev_domain); + let qnames = self.get_qnames(&split_domain); if qnames.is_empty() { return Some(ModuleExtState::Finished); } - if let Some(ret) = qstate.return_msg_mut() { + if let Some(ret) = qstate.return_msg() { if let Some(rep) = ret.rep() { - if self - .handle_reply_info(split_rev_domain, qnames, &rep) - .is_err() - { + if self.handle_reply_info(&split_domain, qnames, &rep).is_err() { return Some(ModuleExtState::Error); } } @@ -741,9 +752,9 @@ mod test { use std::{net::Ipv4Addr, os::unix::fs::MetadataExt, path::PathBuf, str::FromStr, sync::mpsc}; use ipnet::IpNet; - use smallvec::{smallvec, SmallVec}; + use smallvec::smallvec; - use crate::example::{ignore, ExampleMod, IpNetDeser, DATA_PREFIX}; + use crate::example::{ignore, ExampleMod, IpCacheKey, IpNetDeser, DATA_PREFIX}; #[test] fn test() { @@ -794,24 +805,26 @@ mod test { base_path.push("domains6"); t.caches.1.load(&base_path).unwrap(); - t.caches - .0 - .get_maybe_update_rev("com.a".as_bytes().into(), |x| { + t.caches.0.get_maybe_update_rev( + IpCacheKey::from_split_domain(["a", "com"].into_iter()), + |x| { assert!(x.unwrap().0.len() == 2); #[allow(unused_assignments)] let mut val = Some(ignore); val = None; val - }); - t.caches - .0 - .get_maybe_update_rev("com.b".as_bytes().into(), |x| { + }, + ); + t.caches.0.get_maybe_update_rev( + IpCacheKey::from_split_domain(["b", "com"].into_iter()), + |x| { assert!(x.unwrap().0.len() == 1); #[allow(unused_assignments)] let mut val = Some(ignore); val = None; val - }); + }, + ); t.load_json(&mut rulesets); @@ -838,48 +851,53 @@ mod test { tx2.send(rulesets).unwrap(); }); - let split_rev_domain = smallvec![SmallVec::from(&b"com"[..]), SmallVec::from(&b"c"[..])]; - let qnames = t.get_qnames(&split_rev_domain); + let split_domain = [&b"c"[..], &b"com"[..]]; + let qnames = t.get_qnames(&split_domain); assert_eq!(qnames.len(), 2); t.add_ips( smallvec![Ipv4Addr::new(7, 7, 7, 7), Ipv4Addr::new(6, 6, 6, 6)], smallvec![], - split_rev_domain, + &split_domain, qnames, ) .unwrap(); - let split_rev_domain = smallvec![SmallVec::from(&b"com"[..]), SmallVec::from(&b"a"[..])]; - let qnames = t.get_qnames(&split_rev_domain); + let split_domain = [&b"a"[..], &b"com"[..]]; + let qnames = t.get_qnames(&split_domain); t.add_ips( smallvec![Ipv4Addr::new(1, 2, 3, 4), Ipv4Addr::new(5, 6, 7, 8)], smallvec![], - split_rev_domain, + &split_domain, qnames, ) .unwrap(); - t.run_commands(b"token.q.com.w").unwrap(); - t.run_commands(b"tmptoken.q.com.e").unwrap(); + t.run_commands(&[&b"w"[..], &b"com"[..], &b"q"[..], &b"token"[..]]) + .unwrap(); + t.run_commands(&[&b"e"[..], &b"com"[..], &b"q"[..], &b"tmptoken"[..]]) + .unwrap(); + assert!(t + .run_commands(&[&b"e"[..], &b"com"[..], &b"w"[..], &b"tmptoken"[..]]) + .is_none()); - let split_rev_domain = smallvec![SmallVec::from(&b"com"[..]), SmallVec::from(&b"e"[..])]; - let qnames = t.get_qnames(&split_rev_domain); + let split_domain = [&b"e"[..], &b"com"[..]]; + let qnames = t.get_qnames(&split_domain); assert_eq!(qnames.len(), 1); t.add_ips( smallvec![Ipv4Addr::new(8, 8, 8, 8)], smallvec![], - split_rev_domain, + &split_domain, qnames, ) .unwrap(); - let split_rev_domain = smallvec![SmallVec::from(&b"com"[..]), SmallVec::from(&b"w"[..])]; - let qnames = t.get_qnames(&split_rev_domain); + let split_domain = [&b"w"[..], &b"com"[..]]; + let qnames = t.get_qnames(&split_domain); assert_eq!(qnames.len(), 1); t.add_ips( smallvec![Ipv4Addr::new(9, 8, 8, 8)], smallvec![], - split_rev_domain, + &split_domain, qnames, ) .unwrap(); diff --git a/src/lib.rs b/src/lib.rs index 739de1a..061913f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -31,13 +31,13 @@ pub trait UnboundMod: Send + Sync + Sized + RefUnwindSafe + UnwindSafe { type EnvData; type QstateData; #[allow(clippy::result_unit_err)] - fn init(_env: &mut unbound::ModuleEnv) -> Result { + fn init(_env: &mut unbound::ModuleEnvMut) -> Result { Err(()) } - fn deinit(self, _env: &mut unbound::ModuleEnv) {} + fn deinit(self, _env: &mut unbound::ModuleEnvMut) {} fn operate( &self, - _qstate: &mut unbound::ModuleQstate, + _qstate: &mut unbound::ModuleQstateMut, _event: unbound::ModuleEvent, _entry: &mut unbound::OutboundEntryMut, ) -> Option { @@ -45,13 +45,13 @@ pub trait UnboundMod: Send + Sync + Sized + RefUnwindSafe + UnwindSafe { } fn inform_super( &self, - _qstate: &mut unbound::ModuleQstate, - _super_qstate: &mut unbound::ModuleQstate<::std::ffi::c_void>, + _qstate: &mut unbound::ModuleQstateMut, + _super_qstate: &mut unbound::ModuleQstateMut<::std::ffi::c_void>, ) { } - fn clear(&self, _qstate: &mut unbound::ModuleQstate) {} + fn clear(&self, _qstate: &mut unbound::ModuleQstateMut) {} - fn get_mem(&self, _env: &mut unbound::ModuleEnv) -> usize { + fn get_mem(&self, _env: &mut unbound::ModuleEnvMut) -> usize { 0 } } @@ -97,7 +97,7 @@ unsafe impl SealedUnboundMod for T { id: ::std::os::raw::c_int, ) { std::panic::catch_unwind(|| { - self.deinit(&mut unbound::ModuleEnv(env, id, Default::default())) + self.deinit(&mut unbound::ModuleEnvMut(env, id, Default::default())) }) .unwrap_or(()); } @@ -109,13 +109,16 @@ unsafe impl SealedUnboundMod for T { entry: *mut bindings::outbound_entry, ) { std::panic::catch_unwind(|| { - let mut qstate = unbound::ModuleQstate(qstate, id, Default::default()); if let Some(ext_state) = self.operate( - &mut qstate, + &mut unbound::ModuleQstateMut(unbound::ModuleQstate( + qstate, + id, + Default::default(), + )), event.into(), &mut unbound::OutboundEntryMut(entry, Default::default()), ) { - qstate.set_ext_state(ext_state); + (*qstate).ext_state[id as usize] = ext_state as bindings::module_ext_state; } }) .unwrap_or(()); @@ -128,8 +131,16 @@ unsafe impl SealedUnboundMod for T { ) { std::panic::catch_unwind(|| { self.inform_super( - &mut unbound::ModuleQstate(qstate, id, Default::default()), - &mut unbound::ModuleQstate(super_qstate, -1, Default::default()), + &mut unbound::ModuleQstateMut(unbound::ModuleQstate( + qstate, + id, + Default::default(), + )), + &mut unbound::ModuleQstateMut(unbound::ModuleQstate( + super_qstate, + -1, + Default::default(), + )), ) }) .unwrap_or(()); @@ -140,7 +151,11 @@ unsafe impl SealedUnboundMod for T { id: ::std::os::raw::c_int, ) { std::panic::catch_unwind(|| { - self.clear(&mut unbound::ModuleQstate(qstate, id, Default::default())) + self.clear(&mut unbound::ModuleQstateMut(unbound::ModuleQstate( + qstate, + id, + Default::default(), + ))) }) .unwrap_or(()); } @@ -150,7 +165,7 @@ unsafe impl SealedUnboundMod for T { id: ::std::os::raw::c_int, ) -> usize { std::panic::catch_unwind(|| { - self.get_mem(&mut unbound::ModuleEnv(env, id, Default::default())) + self.get_mem(&mut unbound::ModuleEnvMut(env, id, Default::default())) }) .unwrap_or(0) } @@ -174,7 +189,7 @@ pub fn set_unbound_mod() { .set(Box::new(|env, id| { std::panic::catch_unwind(|| { if let Ok(module) = - T::init(&mut unbound::ModuleEnv(env, id, Default::default())) + T::init(&mut unbound::ModuleEnvMut(env, id, Default::default())) { MODULE.set(Box::new(module)).map_err(|_| ()).unwrap(); 1 diff --git a/src/unbound.rs b/src/unbound.rs index 6ac6812..6305e8e 100644 --- a/src/unbound.rs +++ b/src/unbound.rs @@ -6,7 +6,10 @@ use crate::bindings::{ rrset_trust, sec_status, slabhash, sldns_enum_ede_code, sockaddr_in, sockaddr_in6, sockaddr_storage, ub_packed_rrset_key, AF_INET, AF_INET6, }; -use std::{ffi::CStr, marker::PhantomData, net::SocketAddr, os::raw::c_char, ptr, time::Duration}; +use std::{ + ffi::CStr, marker::PhantomData, net::SocketAddr, ops::Deref, os::raw::c_char, ptr, + time::Duration, +}; pub struct ConfigFileMut<'a>( pub(crate) *mut config_file, @@ -22,7 +25,7 @@ pub struct InfraCacheMut<'a>( PhantomData<&'a mut infra_cache>, ); pub struct KeyCacheMut<'a>(pub(crate) *mut key_cache, PhantomData<&'a mut key_cache>); -pub struct ModuleEnv( +pub struct ModuleEnvMut( pub(crate) *mut module_env, pub(crate) std::ffi::c_int, pub(crate) PhantomData, @@ -32,18 +35,39 @@ pub struct ModuleQstate<'a, T>( pub(crate) std::ffi::c_int, pub(crate) PhantomData<&'a mut T>, ); +pub struct ModuleQstateMut<'a, T>(pub(crate) ModuleQstate<'a, T>); +impl<'a, T> Deref for ModuleQstateMut<'a, T> { + type Target = ModuleQstate<'a, T>; + fn deref(&self) -> &Self::Target { + &self.0 + } +} pub struct OutboundEntryMut<'a>( pub(crate) *mut outbound_entry, pub(crate) PhantomData<&'a mut outbound_entry>, ); -pub struct QueryInfoMut<'a>( +pub struct QueryInfo<'a>( pub(crate) *mut query_info, pub(crate) PhantomData<&'a mut query_info>, ); -pub struct DnsMsgMut<'a>( +pub struct QueryInfoMut<'a>(QueryInfo<'a>); +impl<'a> Deref for QueryInfoMut<'a> { + type Target = QueryInfo<'a>; + fn deref(&self) -> &Self::Target { + &self.0 + } +} +pub struct DnsMsg<'a>( pub(crate) *mut dns_msg, pub(crate) PhantomData<&'a mut dns_msg>, ); +pub struct DnsMsgMut<'a>(DnsMsg<'a>); +impl<'a> Deref for DnsMsgMut<'a> { + type Target = DnsMsg<'a>; + fn deref(&self) -> &Self::Target { + &self.0 + } +} pub struct ReplyInfo<'a>( pub(crate) *mut reply_info, pub(crate) PhantomData<&'a mut reply_info>, @@ -65,7 +89,7 @@ pub struct PackedRrsetData<'a>( pub(crate) PhantomData<&'a mut packed_rrset_data>, ); -impl<'a> QueryInfoMut<'a> { +impl<'a> QueryInfo<'a> { pub fn qname(&self) -> &CStr { unsafe { CStr::from_ptr((*self.0).qname as *const c_char) } } @@ -77,7 +101,7 @@ impl<'a> QueryInfoMut<'a> { } } -impl ModuleEnv { +impl ModuleEnvMut { pub fn config_file_mut(&mut self) -> ConfigFileMut<'_> { ConfigFileMut(unsafe { (*self.0).cfg }, Default::default()) } @@ -219,30 +243,35 @@ impl ModuleEnv { } impl ModuleQstate<'_, T> { - pub fn qinfo_mut(&mut self) -> QueryInfoMut<'_> { - QueryInfoMut( + pub fn qinfo(&self) -> QueryInfo<'_> { + QueryInfo( unsafe { &mut (*self.0).qinfo as *mut query_info }, Default::default(), ) } - pub fn return_msg_mut(&mut self) -> Option> { + pub fn return_msg(&self) -> Option> { if unsafe { (*self.0).return_msg.is_null() } { None } else { - Some(DnsMsgMut( - unsafe { (*self.0).return_msg }, - Default::default(), - )) + Some(DnsMsg(unsafe { (*self.0).return_msg }, Default::default())) } } +} +impl ModuleQstateMut<'_, T> { + pub fn qinfo_mut(&mut self) -> QueryInfoMut<'_> { + QueryInfoMut(self.qinfo()) + } + pub fn return_msg_mut(&mut self) -> Option> { + self.return_msg().map(DnsMsgMut) + } pub fn set_ext_state(&mut self, state: ModuleExtState) { unsafe { - (*self.0).ext_state[self.1 as usize] = state as module_ext_state; + (*self.0 .0).ext_state[self.1 as usize] = state as module_ext_state; } } } -impl DnsMsgMut<'_> { +impl DnsMsg<'_> { pub fn rep(&self) -> Option> { if unsafe { (*self.0).rep.is_null() } { None