more bugfixes

This commit is contained in:
chayleaf 2024-08-13 08:01:58 +07:00
parent e9a6f296df
commit 422001fc71
Signed by: chayleaf
GPG key ID: 78171AD46227E68E
5 changed files with 242 additions and 179 deletions

View file

@ -14,20 +14,20 @@ macro_rules! impl_tuple {
{ {
type EnvData = A::EnvData; type EnvData = A::EnvData;
type QstateData = A::QstateData; type QstateData = A::QstateData;
fn init(env: &mut crate::unbound::ModuleEnv<Self::EnvData>) -> Result<Self, ()> { fn init(env: &mut crate::unbound::ModuleEnvMut<Self::EnvData>) -> Result<Self, ()> {
Ok((A::init(env)?, $($t::init(env)?, )*)) Ok((A::init(env)?, $($t::init(env)?, )*))
} }
fn clear(&self, qstate: &mut crate::unbound::ModuleQstate<Self::QstateData>) { fn clear(&self, qstate: &mut crate::unbound::ModuleQstateMut<Self::QstateData>) {
self.0.clear(qstate); self.0.clear(qstate);
$(self.$i.clear(qstate);)* $(self.$i.clear(qstate);)*
} }
fn deinit(self, env: &mut crate::unbound::ModuleEnv<Self::EnvData>) { fn deinit(self, env: &mut crate::unbound::ModuleEnvMut<Self::EnvData>) {
self.0.deinit(env); self.0.deinit(env);
$(self.$i.deinit(env);)* $(self.$i.deinit(env);)*
} }
fn operate( fn operate(
&self, &self,
qstate: &mut crate::unbound::ModuleQstate<Self::QstateData>, qstate: &mut crate::unbound::ModuleQstateMut<Self::QstateData>,
event: crate::unbound::ModuleEvent, event: crate::unbound::ModuleEvent,
entry: &mut crate::unbound::OutboundEntryMut, entry: &mut crate::unbound::OutboundEntryMut,
) -> Option<ModuleExtState> { ) -> Option<ModuleExtState> {
@ -40,13 +40,13 @@ macro_rules! impl_tuple {
})* })*
ret ret
} }
fn get_mem(&self, env: &mut crate::unbound::ModuleEnv<Self::EnvData>) -> usize { fn get_mem(&self, env: &mut crate::unbound::ModuleEnvMut<Self::EnvData>) -> usize {
self.0.get_mem(env) $(* self.$i.get_mem(env))* self.0.get_mem(env) $(* self.$i.get_mem(env))*
} }
fn inform_super( fn inform_super(
&self, &self,
qstate: &mut crate::unbound::ModuleQstate<Self::QstateData>, qstate: &mut crate::unbound::ModuleQstateMut<Self::QstateData>,
super_qstate: &mut crate::unbound::ModuleQstate<std::ffi::c_void>, super_qstate: &mut crate::unbound::ModuleQstateMut<std::ffi::c_void>,
) { ) {
self.0.inform_super(qstate, super_qstate); self.0.inform_super(qstate, super_qstate);
$(self.$i.inform_super(qstate, super_qstate);)* $(self.$i.inform_super(qstate, super_qstate);)*

View file

@ -1,4 +1,4 @@
use std::{collections::HashMap, hash::Hash}; use std::{borrow::Borrow, collections::HashMap, hash::Hash};
use smallvec::{smallvec, SmallVec}; use smallvec::{smallvec, SmallVec};
@ -36,9 +36,10 @@ impl<T: Hash + Eq> PrefixSet<T> {
} }
} }
} }
pub fn contains<'a>(&self, val: impl IntoIterator<Item = &'a T>) -> bool pub fn contains<'a, Y>(&self, val: impl IntoIterator<Item = &'a Y>) -> bool
where where
T: 'a, T: 'a + Borrow<Y>,
Y: 'a + ?Sized + Eq + Hash,
{ {
match self { match self {
Self::Leaf => true, Self::Leaf => true,

View file

@ -99,8 +99,26 @@ struct IpCache<T>(
); );
#[repr(transparent)] #[repr(transparent)]
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct IpCacheKey(Domain); struct IpCacheKey(Domain);
impl IpCacheKey {
fn from_split_domain<T: AsRef<[u8]>>(
split_domain: impl DoubleEndedIterator + Iterator<Item = T>,
) -> Self {
Self::from_split_rev_domain(split_domain.rev())
}
fn from_split_rev_domain<T: AsRef<[u8]>>(split_rev_domain: impl Iterator<Item = T>) -> Self {
let mut first = true;
Self(split_rev_domain.fold(Domain::new(), |mut ret, seg| {
if first {
first = false;
ret.push(b'.');
}
ret.extend_from_slice(seg.as_ref());
ret
}))
}
}
impl radix_trie::TrieKey for IpCacheKey { impl radix_trie::TrieKey for IpCacheKey {
fn encode_bytes(&self) -> Vec<u8> { fn encode_bytes(&self) -> Vec<u8> {
self.0.to_vec() self.0.to_vec()
@ -118,8 +136,11 @@ impl<T> Default for IpCache<T> {
fn ignore<T>(_: &mut smallvec::SmallVec<[T; 4]>) {} fn ignore<T>(_: &mut smallvec::SmallVec<[T; 4]>) {}
impl<T> IpCache<T> { impl<T> IpCache<T> {
fn extend_set_with_domain<J: IpPrefix + From<T>>(&self, ips: &mut RTrieSet<J>, domain_r: Domain) fn extend_set_with_domain<J: IpPrefix + From<T>>(
where &self,
ips: &mut RTrieSet<J>,
domain_r: IpCacheKey,
) where
T: Copy, T: Copy,
{ {
self.get_maybe_update_rev(domain_r, |val| { self.get_maybe_update_rev(domain_r, |val| {
@ -134,11 +155,10 @@ impl<T> IpCache<T> {
} }
fn get_maybe_update_rev<F: for<'a> FnOnce(&'a mut smallvec::SmallVec<[T; 4]>)>( fn get_maybe_update_rev<F: for<'a> FnOnce(&'a mut smallvec::SmallVec<[T; 4]>)>(
&self, &self,
domain_r: Domain, domain_r: IpCacheKey,
upd: impl FnOnce(Option<(&smallvec::SmallVec<[T; 4]>, &Mutex<()>, &AtomicBool)>) -> Option<F>, upd: impl FnOnce(Option<(&smallvec::SmallVec<[T; 4]>, &Mutex<()>, &AtomicBool)>) -> Option<F>,
) { ) {
let lock = self.0.read().unwrap(); let lock = self.0.read().unwrap();
let domain_r = IpCacheKey(domain_r);
let key = lock.0.get(&domain_r).copied(); let key = lock.0.get(&domain_r).copied();
if let Some(val) = if let Some(x) = key.and_then(|key| lock.1.get(key)) { if let Some(val) = if let Some(x) = key.and_then(|key| lock.1.get(key)) {
upd(Some((&x.0.read().unwrap(), &x.1, &x.2))) upd(Some((&x.0.read().unwrap(), &x.1, &x.2)))
@ -171,7 +191,7 @@ impl<T: ToString + PartialEq> IpCache<T> {
let ret1 = &mut ret; let ret1 = &mut ret;
let mut path = self.1.clone(); let mut path = self.1.clone();
path.push(domain); path.push(domain);
self.get_maybe_update_rev(domain_r.0, |ips| { self.get_maybe_update_rev(domain_r, |ips| {
if let Some(ips) = ips.as_ref().filter(|x| x.0 == &val) { if let Some(ips) = ips.as_ref().filter(|x| x.0 == &val) {
*ret1 = false; *ret1 = false;
if ips if ips
@ -234,15 +254,7 @@ impl<T: FromStr> IpCache<T> {
let Ok(reader) = std::fs::File::open(entry.path()) else { let Ok(reader) = std::fs::File::open(entry.path()) else {
continue; continue;
}; };
let domain_r = IpCacheKey( let domain_r = IpCacheKey::from_split_domain(domain.split('.'));
domain
.split('.')
.rev()
.map(|x| x.as_bytes())
.collect::<Vec<_>>()
.join(&b"."[..])
.into(),
);
let mut reader = BufReader::new(reader); let mut reader = BufReader::new(reader);
let mut line = String::new(); let mut line = String::new();
let mut ips = SmallVec::new(); let mut ips = SmallVec::new();
@ -280,7 +292,7 @@ pub(crate) const DATA_PREFIX: &str = "unbound-mod-test-data";
pub(crate) const CONFIG_PREFIX: &str = "unbound-mod-test-config"; pub(crate) const CONFIG_PREFIX: &str = "unbound-mod-test-config";
#[cfg(not(debug_assertions))] #[cfg(not(debug_assertions))]
pub(crate) const PATH_PREFIX: &str = "/var/lib/unbound"; pub(crate) const DATA_PREFIX: &str = "/var/lib/unbound";
#[cfg(not(debug_assertions))] #[cfg(not(debug_assertions))]
pub(crate) const CONFIG_PREFIX: &str = "/etc/unbound"; pub(crate) const CONFIG_PREFIX: &str = "/etc/unbound";
@ -375,11 +387,7 @@ impl ExampleMod {
} }
println!("loading cached domain ips for {k}"); println!("loading cached domain ips for {k}");
for rev_domain in v_domains.iter() { for rev_domain in v_domains.iter() {
let rev_domain: SmallVec<_> = rev_domain let rev_domain = IpCacheKey::from_split_rev_domain(rev_domain.into_iter());
.map(|x| x.as_slice())
.collect::<Vec<_>>()
.join(&b"."[..])
.into();
self.caches self.caches
.0 .0
.extend_set_with_domain(r.0.ips_mut(), rev_domain.clone()); .extend_set_with_domain(r.0.ips_mut(), rev_domain.clone());
@ -456,7 +464,7 @@ impl ExampleMod {
} }
fn handle_reply_info( fn handle_reply_info(
&self, &self,
split_rev_domain: SmallVec<[DomainSeg; 5]>, split_domain: &[&[u8]],
qnames: SmallVec<[usize; 5]>, qnames: SmallVec<[usize; 5]>,
rep: &ReplyInfo<'_>, rep: &ReplyInfo<'_>,
) -> Result<(), ()> { ) -> Result<(), ()> {
@ -485,52 +493,40 @@ impl ExampleMod {
} }
} }
} }
self.add_ips(ip4, ip6, split_rev_domain, qnames) self.add_ips(ip4, ip6, split_domain, qnames)
} }
fn add_ips( fn add_ips(
&self, &self,
ip4: SmallVec<[Ipv4Addr; 4]>, ip4: SmallVec<[Ipv4Addr; 4]>,
ip6: SmallVec<[Ipv6Addr; 4]>, ip6: SmallVec<[Ipv6Addr; 4]>,
split_rev_domain: SmallVec<[DomainSeg; 5]>, split_domain: &[&[u8]],
qnames: SmallVec<[usize; 5]>, qnames: SmallVec<[usize; 5]>,
) -> Result<(), ()> { ) -> Result<(), ()> {
println!("adding {ip4:?}/{ip6:?} for {split_domain:?} to {qnames:?}");
if !ip4.is_empty() || !ip6.is_empty() { if !ip4.is_empty() || !ip6.is_empty() {
let domain = match split_rev_domain let domain = match split_domain
.iter() .iter()
.rev() .copied()
.map(|x| String::from_utf8(x.to_vec()).map(|x| x + ".")) .map(std::str::from_utf8)
.collect::<Result<String, _>>() .try_fold(String::new(), |mut s, comp| {
{ if !s.is_empty() {
Ok(mut x) => { s.push('.');
x.pop();
x
} }
s.push_str(comp?);
Ok::<_, std::str::Utf8Error>(s)
}) {
Ok(x) => x,
Err(err) => { Err(err) => {
self.report("domain utf-8", err); self.report("domain utf-8", err);
return Err(()); return Err(());
} }
}; };
let mut split_rev_domain = split_rev_domain.into_iter(); let key = IpCacheKey::from_split_domain(split_domain.iter());
if let Some(first) = split_rev_domain.next() {
let first: Domain = first.to_vec().into();
let joined_rev_domain = split_rev_domain.fold(first, |mut res, mut next| {
res.push(b'.');
res.append(&mut next);
res
});
let mut to_send: SmallVec<[IpNet; 8]> = SmallVec::new(); let mut to_send: SmallVec<[IpNet; 8]> = SmallVec::new();
to_send.extend(ip4.iter().copied().map(Ipv4Net::from).map(IpNet::from)); to_send.extend(ip4.iter().copied().map(Ipv4Net::from).map(IpNet::from));
to_send.extend(ip6.iter().copied().map(Ipv6Net::from).map(IpNet::from)); to_send.extend(ip6.iter().copied().map(Ipv6Net::from).map(IpNet::from));
let keep4 = !ip4.is_empty() let keep4 = !ip4.is_empty() && self.caches.0.set(&domain, key.clone(), ip4);
&& self let keep6 = !ip6.is_empty() && self.caches.1.set(&domain, key, ip6);
.caches
.0
.set(&domain, IpCacheKey(joined_rev_domain.clone()), ip4);
let keep6 = !ip6.is_empty()
&& self
.caches
.1
.set(&domain, IpCacheKey(joined_rev_domain.clone()), ip6);
to_send.retain(|x| x.addr().is_ipv4() && keep4 || x.addr().is_ipv6() && keep6); to_send.retain(|x| x.addr().is_ipv4() && keep4 || x.addr().is_ipv6() && keep6);
if !to_send.is_empty() { if !to_send.is_empty() {
self.ruleset_queue self.ruleset_queue
@ -540,36 +536,37 @@ impl ExampleMod {
.unwrap(); .unwrap();
} }
} }
}
Ok(()) Ok(())
} }
fn run_commands(&self, rev_domain: &[u8]) -> Option<ModuleExtState> { fn run_commands(&self, split_domain: &[&[u8]]) -> Option<ModuleExtState> {
if let Some(rev_domain) = self if let Some(split_domain) = self.nft_token.as_ref().and_then(|token| {
.nft_token split_domain
.as_ref() .split_last()
.and_then(|token| rev_domain.strip_prefix(token.as_bytes())) .filter(|(a, _)| **a == token.as_bytes())
{ .map(|(_, b)| b)
}) {
for (qname, query) in self.nft_queries.iter() { for (qname, query) in self.nft_queries.iter() {
if query.dynamic && rev_domain.starts_with(qname.as_bytes()) { if query.dynamic {
if let Some(rev_domain) = if let Some(split_domain) = split_domain
rev_domain.strip_prefix((qname.to_owned() + ".").as_bytes()) .split_last()
.filter(|(a, _)| **a == qname.as_bytes())
.map(|(_, b)| b)
{ {
let rev_domain = rev_domain
.split(|x| *x == b'.')
.map(|x| x.into())
.collect::<SmallVec<[_; 5]>>();
let mut domains = query.domains.write().unwrap(); let mut domains = query.domains.write().unwrap();
if domains.insert(rev_domain.clone()) { if domains.insert(split_domain.iter().copied().rev().map(From::from)) {
drop(domains); drop(domains);
let file_name = format!("{DATA_PREFIX}/{qname}_domains.json"); let file_name = format!("{DATA_PREFIX}/{qname}_domains.json");
let domain = match String::from_utf8( let domain = match split_domain
rev_domain
.iter() .iter()
.rev() .copied()
.map(|x| x.as_slice()) .map(std::str::from_utf8)
.collect::<Vec<_>>() .try_fold(String::new(), |mut s, comp| {
.join(&b"."[..]), if !s.is_empty() {
) { s.push('.');
}
s.push_str(comp?);
Ok::<_, std::str::Utf8Error>(s)
}) {
Ok(x) => x, Ok(x) => x,
Err(err) => { Err(err) => {
self.report("domain utf-8", err); self.report("domain utf-8", err);
@ -577,6 +574,7 @@ impl ExampleMod {
} }
}; };
let _lock = self.domains_write_lock.lock().unwrap(); let _lock = self.domains_write_lock.lock().unwrap();
println!("adding {domain} to {qname}");
let mut old: Vec<String> = if let Ok(file) = File::open(&file_name) { let mut old: Vec<String> = if let Ok(file) = File::open(&file_name) {
match read_json(file) { match read_json(file) {
Ok(x) => x, Ok(x) => x,
@ -602,22 +600,21 @@ impl ExampleMod {
} }
} }
return Some(ModuleExtState::Finished); return Some(ModuleExtState::Finished);
} else if let Some(rev_domain) = self } else if let Some(split_domain) = self.tmp_nft_token.as_ref().and_then(|token| {
.tmp_nft_token split_domain
.as_ref() .split_last()
.and_then(|token| rev_domain.strip_prefix(token.as_bytes())) .filter(|(a, _)| **a == token.as_bytes())
{ .map(|(_, b)| b)
}) {
for (qname, query) in self.nft_queries.iter() { for (qname, query) in self.nft_queries.iter() {
if query.dynamic && rev_domain.starts_with(qname.as_bytes()) { if query.dynamic {
if let Some(rev_domain) = if let Some(split_domain) = split_domain
rev_domain.strip_prefix((qname.to_owned() + ".").as_bytes()) .split_last()
.filter(|(a, _)| **a == qname.as_bytes())
.map(|(_, b)| b)
{ {
let rev_domain = rev_domain
.split(|x| *x == b'.')
.map(|x| x.into())
.collect::<SmallVec<[_; 5]>>();
let mut domains = query.domains.write().unwrap(); let mut domains = query.domains.write().unwrap();
domains.insert(rev_domain.clone()); domains.insert(split_domain.iter().copied().rev().map(From::from));
} }
} }
} }
@ -625,10 +622,15 @@ impl ExampleMod {
} }
None None
} }
fn get_qnames(&self, split_rev_domain: &SmallVec<[DomainSeg; 5]>) -> SmallVec<[usize; 5]> { fn get_qnames(&self, split_domain: &[&[u8]]) -> SmallVec<[usize; 5]> {
let mut qnames: SmallVec<[usize; 5]> = SmallVec::new(); let mut qnames: SmallVec<[usize; 5]> = SmallVec::new();
for query in self.nft_queries.values() { for query in self.nft_queries.values() {
if query.domains.read().unwrap().contains(split_rev_domain) { if query
.domains
.read()
.unwrap()
.contains(split_domain.iter().copied().rev().map(From::from))
{
qnames.push(query.index); qnames.push(query.index);
} }
} }
@ -680,17 +682,31 @@ fn read_json<T: 'static + for<'a> Deserialize<'a>>(mut f: File) -> Result<T, ser
serde_json::from_slice(&data) serde_json::from_slice(&data)
} }
// \x06google\x03com
fn unwire_domain(domain: &[u8]) -> SmallVec<[&[u8]; 8]> {
let mut i = 0;
let mut ret = SmallVec::new();
while let Some(val) = domain.get(i).map(|x| *x as usize) {
i += 1;
if let Some(val) = domain.get(i..i + val) {
ret.push(val);
}
i += val;
}
ret
}
impl UnboundMod for ExampleMod { impl UnboundMod for ExampleMod {
type EnvData = (); type EnvData = ();
type QstateData = (); type QstateData = ();
fn init(_env: &mut crate::unbound::ModuleEnv<Self::EnvData>) -> Result<Self, ()> { fn init(_env: &mut crate::unbound::ModuleEnvMut<Self::EnvData>) -> Result<Self, ()> {
Self::new() Self::new()
} }
fn operate( fn operate(
&self, &self,
qstate: &mut crate::unbound::ModuleQstate<Self::QstateData>, qstate: &mut crate::unbound::ModuleQstateMut<Self::QstateData>,
event: ModuleEvent, event: ModuleEvent,
_entry: &mut crate::unbound::OutboundEntryMut, _entry: &mut crate::unbound::OutboundEntryMut,
) -> Option<ModuleExtState> { ) -> Option<ModuleExtState> {
@ -703,26 +719,21 @@ impl UnboundMod for ExampleMod {
return Some(ModuleExtState::Error); return Some(ModuleExtState::Error);
} }
} }
let info = qstate.qinfo_mut(); let info = qstate.qinfo();
let name = info.qname().to_bytes(); let name = info.qname().to_bytes();
let rev_domain = name.strip_suffix(b".").unwrap_or(name); // let rev_domain = name.strip_suffix(b".").unwrap_or(name);
if let Some(val) = self.run_commands(rev_domain) { let split_domain = unwire_domain(name);
println!("handling {split_domain:?}");
if let Some(val) = self.run_commands(&split_domain) {
return Some(val); return Some(val);
} }
let split_rev_domain = rev_domain let qnames = self.get_qnames(&split_domain);
.split(|x| *x == b'.')
.map(|x| x.into())
.collect::<SmallVec<[_; 5]>>();
let qnames = self.get_qnames(&split_rev_domain);
if qnames.is_empty() { if qnames.is_empty() {
return Some(ModuleExtState::Finished); return Some(ModuleExtState::Finished);
} }
if let Some(ret) = qstate.return_msg_mut() { if let Some(ret) = qstate.return_msg() {
if let Some(rep) = ret.rep() { if let Some(rep) = ret.rep() {
if self if self.handle_reply_info(&split_domain, qnames, &rep).is_err() {
.handle_reply_info(split_rev_domain, qnames, &rep)
.is_err()
{
return Some(ModuleExtState::Error); return Some(ModuleExtState::Error);
} }
} }
@ -741,9 +752,9 @@ mod test {
use std::{net::Ipv4Addr, os::unix::fs::MetadataExt, path::PathBuf, str::FromStr, sync::mpsc}; use std::{net::Ipv4Addr, os::unix::fs::MetadataExt, path::PathBuf, str::FromStr, sync::mpsc};
use ipnet::IpNet; use ipnet::IpNet;
use smallvec::{smallvec, SmallVec}; use smallvec::smallvec;
use crate::example::{ignore, ExampleMod, IpNetDeser, DATA_PREFIX}; use crate::example::{ignore, ExampleMod, IpCacheKey, IpNetDeser, DATA_PREFIX};
#[test] #[test]
fn test() { fn test() {
@ -794,24 +805,26 @@ mod test {
base_path.push("domains6"); base_path.push("domains6");
t.caches.1.load(&base_path).unwrap(); t.caches.1.load(&base_path).unwrap();
t.caches t.caches.0.get_maybe_update_rev(
.0 IpCacheKey::from_split_domain(["a", "com"].into_iter()),
.get_maybe_update_rev("com.a".as_bytes().into(), |x| { |x| {
assert!(x.unwrap().0.len() == 2); assert!(x.unwrap().0.len() == 2);
#[allow(unused_assignments)] #[allow(unused_assignments)]
let mut val = Some(ignore); let mut val = Some(ignore);
val = None; val = None;
val val
}); },
t.caches );
.0 t.caches.0.get_maybe_update_rev(
.get_maybe_update_rev("com.b".as_bytes().into(), |x| { IpCacheKey::from_split_domain(["b", "com"].into_iter()),
|x| {
assert!(x.unwrap().0.len() == 1); assert!(x.unwrap().0.len() == 1);
#[allow(unused_assignments)] #[allow(unused_assignments)]
let mut val = Some(ignore); let mut val = Some(ignore);
val = None; val = None;
val val
}); },
);
t.load_json(&mut rulesets); t.load_json(&mut rulesets);
@ -838,48 +851,53 @@ mod test {
tx2.send(rulesets).unwrap(); tx2.send(rulesets).unwrap();
}); });
let split_rev_domain = smallvec![SmallVec::from(&b"com"[..]), SmallVec::from(&b"c"[..])]; let split_domain = [&b"c"[..], &b"com"[..]];
let qnames = t.get_qnames(&split_rev_domain); let qnames = t.get_qnames(&split_domain);
assert_eq!(qnames.len(), 2); assert_eq!(qnames.len(), 2);
t.add_ips( t.add_ips(
smallvec![Ipv4Addr::new(7, 7, 7, 7), Ipv4Addr::new(6, 6, 6, 6)], smallvec![Ipv4Addr::new(7, 7, 7, 7), Ipv4Addr::new(6, 6, 6, 6)],
smallvec![], smallvec![],
split_rev_domain, &split_domain,
qnames, qnames,
) )
.unwrap(); .unwrap();
let split_rev_domain = smallvec![SmallVec::from(&b"com"[..]), SmallVec::from(&b"a"[..])]; let split_domain = [&b"a"[..], &b"com"[..]];
let qnames = t.get_qnames(&split_rev_domain); let qnames = t.get_qnames(&split_domain);
t.add_ips( t.add_ips(
smallvec![Ipv4Addr::new(1, 2, 3, 4), Ipv4Addr::new(5, 6, 7, 8)], smallvec![Ipv4Addr::new(1, 2, 3, 4), Ipv4Addr::new(5, 6, 7, 8)],
smallvec![], smallvec![],
split_rev_domain, &split_domain,
qnames, qnames,
) )
.unwrap(); .unwrap();
t.run_commands(b"token.q.com.w").unwrap(); t.run_commands(&[&b"w"[..], &b"com"[..], &b"q"[..], &b"token"[..]])
t.run_commands(b"tmptoken.q.com.e").unwrap(); .unwrap();
t.run_commands(&[&b"e"[..], &b"com"[..], &b"q"[..], &b"tmptoken"[..]])
.unwrap();
assert!(t
.run_commands(&[&b"e"[..], &b"com"[..], &b"w"[..], &b"tmptoken"[..]])
.is_none());
let split_rev_domain = smallvec![SmallVec::from(&b"com"[..]), SmallVec::from(&b"e"[..])]; let split_domain = [&b"e"[..], &b"com"[..]];
let qnames = t.get_qnames(&split_rev_domain); let qnames = t.get_qnames(&split_domain);
assert_eq!(qnames.len(), 1); assert_eq!(qnames.len(), 1);
t.add_ips( t.add_ips(
smallvec![Ipv4Addr::new(8, 8, 8, 8)], smallvec![Ipv4Addr::new(8, 8, 8, 8)],
smallvec![], smallvec![],
split_rev_domain, &split_domain,
qnames, qnames,
) )
.unwrap(); .unwrap();
let split_rev_domain = smallvec![SmallVec::from(&b"com"[..]), SmallVec::from(&b"w"[..])]; let split_domain = [&b"w"[..], &b"com"[..]];
let qnames = t.get_qnames(&split_rev_domain); let qnames = t.get_qnames(&split_domain);
assert_eq!(qnames.len(), 1); assert_eq!(qnames.len(), 1);
t.add_ips( t.add_ips(
smallvec![Ipv4Addr::new(9, 8, 8, 8)], smallvec![Ipv4Addr::new(9, 8, 8, 8)],
smallvec![], smallvec![],
split_rev_domain, &split_domain,
qnames, qnames,
) )
.unwrap(); .unwrap();

View file

@ -31,13 +31,13 @@ pub trait UnboundMod: Send + Sync + Sized + RefUnwindSafe + UnwindSafe {
type EnvData; type EnvData;
type QstateData; type QstateData;
#[allow(clippy::result_unit_err)] #[allow(clippy::result_unit_err)]
fn init(_env: &mut unbound::ModuleEnv<Self::EnvData>) -> Result<Self, ()> { fn init(_env: &mut unbound::ModuleEnvMut<Self::EnvData>) -> Result<Self, ()> {
Err(()) Err(())
} }
fn deinit(self, _env: &mut unbound::ModuleEnv<Self::EnvData>) {} fn deinit(self, _env: &mut unbound::ModuleEnvMut<Self::EnvData>) {}
fn operate( fn operate(
&self, &self,
_qstate: &mut unbound::ModuleQstate<Self::QstateData>, _qstate: &mut unbound::ModuleQstateMut<Self::QstateData>,
_event: unbound::ModuleEvent, _event: unbound::ModuleEvent,
_entry: &mut unbound::OutboundEntryMut, _entry: &mut unbound::OutboundEntryMut,
) -> Option<ModuleExtState> { ) -> Option<ModuleExtState> {
@ -45,13 +45,13 @@ pub trait UnboundMod: Send + Sync + Sized + RefUnwindSafe + UnwindSafe {
} }
fn inform_super( fn inform_super(
&self, &self,
_qstate: &mut unbound::ModuleQstate<Self::QstateData>, _qstate: &mut unbound::ModuleQstateMut<Self::QstateData>,
_super_qstate: &mut unbound::ModuleQstate<::std::ffi::c_void>, _super_qstate: &mut unbound::ModuleQstateMut<::std::ffi::c_void>,
) { ) {
} }
fn clear(&self, _qstate: &mut unbound::ModuleQstate<Self::QstateData>) {} fn clear(&self, _qstate: &mut unbound::ModuleQstateMut<Self::QstateData>) {}
fn get_mem(&self, _env: &mut unbound::ModuleEnv<Self::EnvData>) -> usize { fn get_mem(&self, _env: &mut unbound::ModuleEnvMut<Self::EnvData>) -> usize {
0 0
} }
} }
@ -97,7 +97,7 @@ unsafe impl<T: UnboundMod> SealedUnboundMod for T {
id: ::std::os::raw::c_int, id: ::std::os::raw::c_int,
) { ) {
std::panic::catch_unwind(|| { std::panic::catch_unwind(|| {
self.deinit(&mut unbound::ModuleEnv(env, id, Default::default())) self.deinit(&mut unbound::ModuleEnvMut(env, id, Default::default()))
}) })
.unwrap_or(()); .unwrap_or(());
} }
@ -109,13 +109,16 @@ unsafe impl<T: UnboundMod> SealedUnboundMod for T {
entry: *mut bindings::outbound_entry, entry: *mut bindings::outbound_entry,
) { ) {
std::panic::catch_unwind(|| { std::panic::catch_unwind(|| {
let mut qstate = unbound::ModuleQstate(qstate, id, Default::default());
if let Some(ext_state) = self.operate( if let Some(ext_state) = self.operate(
&mut qstate, &mut unbound::ModuleQstateMut(unbound::ModuleQstate(
qstate,
id,
Default::default(),
)),
event.into(), event.into(),
&mut unbound::OutboundEntryMut(entry, Default::default()), &mut unbound::OutboundEntryMut(entry, Default::default()),
) { ) {
qstate.set_ext_state(ext_state); (*qstate).ext_state[id as usize] = ext_state as bindings::module_ext_state;
} }
}) })
.unwrap_or(()); .unwrap_or(());
@ -128,8 +131,16 @@ unsafe impl<T: UnboundMod> SealedUnboundMod for T {
) { ) {
std::panic::catch_unwind(|| { std::panic::catch_unwind(|| {
self.inform_super( self.inform_super(
&mut unbound::ModuleQstate(qstate, id, Default::default()), &mut unbound::ModuleQstateMut(unbound::ModuleQstate(
&mut unbound::ModuleQstate(super_qstate, -1, Default::default()), qstate,
id,
Default::default(),
)),
&mut unbound::ModuleQstateMut(unbound::ModuleQstate(
super_qstate,
-1,
Default::default(),
)),
) )
}) })
.unwrap_or(()); .unwrap_or(());
@ -140,7 +151,11 @@ unsafe impl<T: UnboundMod> SealedUnboundMod for T {
id: ::std::os::raw::c_int, id: ::std::os::raw::c_int,
) { ) {
std::panic::catch_unwind(|| { std::panic::catch_unwind(|| {
self.clear(&mut unbound::ModuleQstate(qstate, id, Default::default())) self.clear(&mut unbound::ModuleQstateMut(unbound::ModuleQstate(
qstate,
id,
Default::default(),
)))
}) })
.unwrap_or(()); .unwrap_or(());
} }
@ -150,7 +165,7 @@ unsafe impl<T: UnboundMod> SealedUnboundMod for T {
id: ::std::os::raw::c_int, id: ::std::os::raw::c_int,
) -> usize { ) -> usize {
std::panic::catch_unwind(|| { std::panic::catch_unwind(|| {
self.get_mem(&mut unbound::ModuleEnv(env, id, Default::default())) self.get_mem(&mut unbound::ModuleEnvMut(env, id, Default::default()))
}) })
.unwrap_or(0) .unwrap_or(0)
} }
@ -174,7 +189,7 @@ pub fn set_unbound_mod<T: 'static + UnboundMod>() {
.set(Box::new(|env, id| { .set(Box::new(|env, id| {
std::panic::catch_unwind(|| { std::panic::catch_unwind(|| {
if let Ok(module) = if let Ok(module) =
T::init(&mut unbound::ModuleEnv(env, id, Default::default())) T::init(&mut unbound::ModuleEnvMut(env, id, Default::default()))
{ {
MODULE.set(Box::new(module)).map_err(|_| ()).unwrap(); MODULE.set(Box::new(module)).map_err(|_| ()).unwrap();
1 1

View file

@ -6,7 +6,10 @@ use crate::bindings::{
rrset_trust, sec_status, slabhash, sldns_enum_ede_code, sockaddr_in, sockaddr_in6, rrset_trust, sec_status, slabhash, sldns_enum_ede_code, sockaddr_in, sockaddr_in6,
sockaddr_storage, ub_packed_rrset_key, AF_INET, AF_INET6, sockaddr_storage, ub_packed_rrset_key, AF_INET, AF_INET6,
}; };
use std::{ffi::CStr, marker::PhantomData, net::SocketAddr, os::raw::c_char, ptr, time::Duration}; use std::{
ffi::CStr, marker::PhantomData, net::SocketAddr, ops::Deref, os::raw::c_char, ptr,
time::Duration,
};
pub struct ConfigFileMut<'a>( pub struct ConfigFileMut<'a>(
pub(crate) *mut config_file, pub(crate) *mut config_file,
@ -22,7 +25,7 @@ pub struct InfraCacheMut<'a>(
PhantomData<&'a mut infra_cache>, PhantomData<&'a mut infra_cache>,
); );
pub struct KeyCacheMut<'a>(pub(crate) *mut key_cache, PhantomData<&'a mut key_cache>); pub struct KeyCacheMut<'a>(pub(crate) *mut key_cache, PhantomData<&'a mut key_cache>);
pub struct ModuleEnv<T>( pub struct ModuleEnvMut<T>(
pub(crate) *mut module_env, pub(crate) *mut module_env,
pub(crate) std::ffi::c_int, pub(crate) std::ffi::c_int,
pub(crate) PhantomData<T>, pub(crate) PhantomData<T>,
@ -32,18 +35,39 @@ pub struct ModuleQstate<'a, T>(
pub(crate) std::ffi::c_int, pub(crate) std::ffi::c_int,
pub(crate) PhantomData<&'a mut T>, pub(crate) PhantomData<&'a mut T>,
); );
pub struct ModuleQstateMut<'a, T>(pub(crate) ModuleQstate<'a, T>);
impl<'a, T> Deref for ModuleQstateMut<'a, T> {
type Target = ModuleQstate<'a, T>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
pub struct OutboundEntryMut<'a>( pub struct OutboundEntryMut<'a>(
pub(crate) *mut outbound_entry, pub(crate) *mut outbound_entry,
pub(crate) PhantomData<&'a mut outbound_entry>, pub(crate) PhantomData<&'a mut outbound_entry>,
); );
pub struct QueryInfoMut<'a>( pub struct QueryInfo<'a>(
pub(crate) *mut query_info, pub(crate) *mut query_info,
pub(crate) PhantomData<&'a mut query_info>, pub(crate) PhantomData<&'a mut query_info>,
); );
pub struct DnsMsgMut<'a>( pub struct QueryInfoMut<'a>(QueryInfo<'a>);
impl<'a> Deref for QueryInfoMut<'a> {
type Target = QueryInfo<'a>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
pub struct DnsMsg<'a>(
pub(crate) *mut dns_msg, pub(crate) *mut dns_msg,
pub(crate) PhantomData<&'a mut dns_msg>, pub(crate) PhantomData<&'a mut dns_msg>,
); );
pub struct DnsMsgMut<'a>(DnsMsg<'a>);
impl<'a> Deref for DnsMsgMut<'a> {
type Target = DnsMsg<'a>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
pub struct ReplyInfo<'a>( pub struct ReplyInfo<'a>(
pub(crate) *mut reply_info, pub(crate) *mut reply_info,
pub(crate) PhantomData<&'a mut reply_info>, pub(crate) PhantomData<&'a mut reply_info>,
@ -65,7 +89,7 @@ pub struct PackedRrsetData<'a>(
pub(crate) PhantomData<&'a mut packed_rrset_data>, pub(crate) PhantomData<&'a mut packed_rrset_data>,
); );
impl<'a> QueryInfoMut<'a> { impl<'a> QueryInfo<'a> {
pub fn qname(&self) -> &CStr { pub fn qname(&self) -> &CStr {
unsafe { CStr::from_ptr((*self.0).qname as *const c_char) } unsafe { CStr::from_ptr((*self.0).qname as *const c_char) }
} }
@ -77,7 +101,7 @@ impl<'a> QueryInfoMut<'a> {
} }
} }
impl<T> ModuleEnv<T> { impl<T> ModuleEnvMut<T> {
pub fn config_file_mut(&mut self) -> ConfigFileMut<'_> { pub fn config_file_mut(&mut self) -> ConfigFileMut<'_> {
ConfigFileMut(unsafe { (*self.0).cfg }, Default::default()) ConfigFileMut(unsafe { (*self.0).cfg }, Default::default())
} }
@ -219,30 +243,35 @@ impl<T> ModuleEnv<T> {
} }
impl<T> ModuleQstate<'_, T> { impl<T> ModuleQstate<'_, T> {
pub fn qinfo_mut(&mut self) -> QueryInfoMut<'_> { pub fn qinfo(&self) -> QueryInfo<'_> {
QueryInfoMut( QueryInfo(
unsafe { &mut (*self.0).qinfo as *mut query_info }, unsafe { &mut (*self.0).qinfo as *mut query_info },
Default::default(), Default::default(),
) )
} }
pub fn return_msg_mut(&mut self) -> Option<DnsMsgMut<'_>> { pub fn return_msg(&self) -> Option<DnsMsg<'_>> {
if unsafe { (*self.0).return_msg.is_null() } { if unsafe { (*self.0).return_msg.is_null() } {
None None
} else { } else {
Some(DnsMsgMut( Some(DnsMsg(unsafe { (*self.0).return_msg }, Default::default()))
unsafe { (*self.0).return_msg },
Default::default(),
))
} }
} }
}
impl<T> ModuleQstateMut<'_, T> {
pub fn qinfo_mut(&mut self) -> QueryInfoMut<'_> {
QueryInfoMut(self.qinfo())
}
pub fn return_msg_mut(&mut self) -> Option<DnsMsgMut<'_>> {
self.return_msg().map(DnsMsgMut)
}
pub fn set_ext_state(&mut self, state: ModuleExtState) { pub fn set_ext_state(&mut self, state: ModuleExtState) {
unsafe { unsafe {
(*self.0).ext_state[self.1 as usize] = state as module_ext_state; (*self.0 .0).ext_state[self.1 as usize] = state as module_ext_state;
} }
} }
} }
impl DnsMsgMut<'_> { impl DnsMsg<'_> {
pub fn rep(&self) -> Option<ReplyInfo<'_>> { pub fn rep(&self) -> Option<ReplyInfo<'_>> {
if unsafe { (*self.0).rep.is_null() } { if unsafe { (*self.0).rep.is_null() } {
None None