Compare commits

...

2 commits

Author SHA1 Message Date
chayleaf feca3758a9
unbound bindings improvements 2024-08-13 12:00:58 +07:00
chayleaf 4f43475c76
make building the example optional 2024-08-13 11:16:38 +07:00
7 changed files with 321 additions and 328 deletions

View file

@ -9,20 +9,34 @@ crate-type = ["rlib", "cdylib"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
boxcar = "0.2.5" boxcar = { version = "0.2.5", optional = true }
ctor = { version = "0.2.8", optional = true } ctor = { version = "0.2.8", optional = true }
filetime = "0.2.24" filetime = { version = "0.2.24", optional = true }
ipnet = { version = "2.9.0", features = ["serde"] } ipnet = { version = "2.9.0", features = ["serde"], optional = true }
iptrie = "0.8.5" iptrie = { version = "0.8.5", optional = true }
libc = "0.2.155" libc = { version = "0.2.155", optional = true }
mnl = { version = "0.2.2", features = ["mnl-1-0-4"] } mnl = { version = "0.2.2", features = ["mnl-1-0-4"], optional = true }
nftnl = { version = "0.6.2", features = ["nftnl-1-1-2"] } nftnl = { version = "0.6.2", features = ["nftnl-1-1-2"], optional = true }
nix = { version = "0.29.0", features = ["poll", "user"] } nix = { version = "0.29.0", features = ["poll", "user"], optional = true }
radix_trie = "0.2.1" radix_trie = { version = "0.2.1", optional = true }
serde = { version = "1.0.205", features = ["derive"] } serde = { version = "1.0.205", features = ["derive"], optional = true }
serde_json = "1.0.122" serde_json = { version = "1.0.122", optional = true }
smallvec = "1.13.2" smallvec = { version = "1.13.2", optional = true }
[features] [features]
example = ["ctor"] example = [
"boxcar",
"ctor",
"filetime",
"ipnet",
"iptrie",
"libc",
"mnl",
"nftnl",
"nix",
"radix_trie",
"serde",
"serde_json",
"smallvec",
]
default = ["example"] default = ["example"]

View file

@ -22,15 +22,23 @@ use serde::{
use smallvec::SmallVec; use smallvec::SmallVec;
use crate::{ use crate::{
domain_tree::PrefixSet,
nftables::{nftables_thread, NftData},
unbound::{rr_class, rr_type, ModuleEvent, ModuleExtState, ReplyInfo}, unbound::{rr_class, rr_type, ModuleEvent, ModuleExtState, ReplyInfo},
UnboundMod, UnboundMod,
}; };
use domain_tree::PrefixSet;
use nftables::{nftables_thread, NftData};
mod domain_tree;
mod nftables;
type Domain = SmallVec<[u8; 32]>; type Domain = SmallVec<[u8; 32]>;
type DomainSeg = SmallVec<[u8; 16]>; type DomainSeg = SmallVec<[u8; 16]>;
#[ctor]
fn setup() {
crate::set_unbound_mod::<ExampleMod>();
}
struct IpNetDeser(IpNet); struct IpNetDeser(IpNet);
struct IpNetVisitor; struct IpNetVisitor;
impl<'de> Visitor<'de> for IpNetVisitor { impl<'de> Visitor<'de> for IpNetVisitor {
@ -462,7 +470,9 @@ impl ExampleMod {
let mut ip6: SmallVec<[Ipv6Addr; 4]> = SmallVec::new(); let mut ip6: SmallVec<[Ipv6Addr; 4]> = SmallVec::new();
for rrset in rep.rrsets() { for rrset in rep.rrsets() {
let entry = rrset.entry(); let entry = rrset.entry();
let d = entry.data(); let Some(d) = entry.data() else {
continue;
};
let rk = rrset.rk(); let rk = rrset.rk();
if rk.rrset_class() != rr_class::IN { if rk.rrset_class() != rr_class::IN {
continue; continue;
@ -738,11 +748,6 @@ impl UnboundMod for ExampleMod {
} }
} }
#[ctor]
fn setup() {
crate::set_unbound_mod::<ExampleMod>();
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use std::{net::Ipv4Addr, os::unix::fs::MetadataExt, path::PathBuf, str::FromStr, sync::mpsc}; use std::{net::Ipv4Addr, os::unix::fs::MetadataExt, path::PathBuf, str::FromStr, sync::mpsc};

View file

@ -471,7 +471,7 @@ pub(crate) fn nftables_thread(
} }
} }
let socket = mnl::Socket::new(mnl::Bus::Netfilter).unwrap(); let socket = mnl::Socket::new(mnl::Bus::Netfilter).unwrap();
let all_sets = crate::nftables::get_sets(&socket).unwrap(); let all_sets = get_sets(&socket).unwrap();
for set in all_sets { for set in all_sets {
for ruleset in &mut rulesets { for ruleset in &mut rulesets {
if set.table_name_str() == Some("global") && set.family() == libc::NFPROTO_INET as u32 { if set.table_name_str() == Some("global") && set.family() == libc::NFPROTO_INET as u32 {
@ -531,9 +531,7 @@ mod test {
use ipnet::{Ipv4Net, Ipv6Net}; use ipnet::{Ipv4Net, Ipv6Net};
use iptrie::RTrieSet; use iptrie::RTrieSet;
use crate::nftables::{iter_ip_trie, should_add}; use super::{get_sets, iter_ip_trie, should_add};
use super::get_sets;
#[test] #[test]
fn test_nftables() { fn test_nftables() {

View file

@ -15,11 +15,9 @@ use unbound::ModuleExtState;
)] )]
mod bindings; mod bindings;
mod combine; mod combine;
mod domain_tree;
#[cfg(feature = "example")] #[cfg(feature = "example")]
mod example; mod example;
mod exports; mod exports;
mod nftables;
mod unbound; mod unbound;
pub trait UnboundMod: Send + Sync + Sized + RefUnwindSafe + UnwindSafe { pub trait UnboundMod: Send + Sync + Sized + RefUnwindSafe + UnwindSafe {
@ -92,7 +90,7 @@ unsafe impl<T: UnboundMod> SealedUnboundMod for T {
id: ::std::os::raw::c_int, id: ::std::os::raw::c_int,
) { ) {
std::panic::catch_unwind(|| { std::panic::catch_unwind(|| {
self.deinit(&mut unbound::ModuleEnvMut(env, id, Default::default())); self.deinit(&mut unbound::ModuleEnvMut::from_raw(env, id).unwrap());
}) })
.unwrap_or(()); .unwrap_or(());
} }
@ -105,15 +103,13 @@ unsafe impl<T: UnboundMod> SealedUnboundMod for T {
) { ) {
std::panic::catch_unwind(|| { std::panic::catch_unwind(|| {
if let Some(ext_state) = self.operate( if let Some(ext_state) = self.operate(
&mut unbound::ModuleQstateMut(unbound::ModuleQstate( &mut unbound::ModuleQstateMut::from_raw(qstate, id).unwrap(),
qstate,
id,
Default::default(),
)),
event.into(), event.into(),
&mut unbound::OutboundEntryMut(entry, Default::default()), &mut unbound::OutboundEntryMut::from_raw(entry).unwrap(),
) { ) {
(*qstate).ext_state[id as usize] = ext_state as bindings::module_ext_state; if let Some(id) = unbound::check_id(id) {
(*qstate).ext_state[id] = ext_state as bindings::module_ext_state;
}
} }
}) })
.unwrap_or(()); .unwrap_or(());
@ -126,16 +122,8 @@ unsafe impl<T: UnboundMod> SealedUnboundMod for T {
) { ) {
std::panic::catch_unwind(|| { std::panic::catch_unwind(|| {
self.inform_super( self.inform_super(
&mut unbound::ModuleQstateMut(unbound::ModuleQstate( &mut unbound::ModuleQstateMut::from_raw(qstate, id).unwrap(),
qstate, &mut unbound::ModuleQstateMut::from_raw(super_qstate, -1).unwrap(),
id,
Default::default(),
)),
&mut unbound::ModuleQstateMut(unbound::ModuleQstate(
super_qstate,
-1,
Default::default(),
)),
); );
}) })
.unwrap_or(()); .unwrap_or(());
@ -146,11 +134,7 @@ unsafe impl<T: UnboundMod> SealedUnboundMod for T {
id: ::std::os::raw::c_int, id: ::std::os::raw::c_int,
) { ) {
std::panic::catch_unwind(|| { std::panic::catch_unwind(|| {
self.clear(&mut unbound::ModuleQstateMut(unbound::ModuleQstate( self.clear(&mut unbound::ModuleQstateMut::from_raw(qstate, id).unwrap());
qstate,
id,
Default::default(),
)));
}) })
.unwrap_or(()); .unwrap_or(());
} }
@ -160,7 +144,7 @@ unsafe impl<T: UnboundMod> SealedUnboundMod for T {
id: ::std::os::raw::c_int, id: ::std::os::raw::c_int,
) -> usize { ) -> usize {
std::panic::catch_unwind(|| { std::panic::catch_unwind(|| {
self.get_mem(&mut unbound::ModuleEnvMut(env, id, Default::default())) self.get_mem(&mut unbound::ModuleEnvMut::from_raw(env, id).unwrap())
}) })
.unwrap_or(0) .unwrap_or(0)
} }
@ -183,13 +167,12 @@ pub fn set_unbound_mod<T: 'static + UnboundMod>() {
MODULE_FACTORY MODULE_FACTORY
.set(Box::new(|env, id| { .set(Box::new(|env, id| {
std::panic::catch_unwind(|| { std::panic::catch_unwind(|| {
T::init(&mut unbound::ModuleEnvMut(env, id, Default::default())).map_or( unbound::ModuleEnvMut::from_raw(env, id)
0, .and_then(|mut env| T::init(&mut env).ok())
|module| { .map_or(0, |module| {
MODULE.set(Box::new(module)).map_err(|_| ()).unwrap(); MODULE.set(Box::new(module)).map_err(|_| ()).unwrap();
1 1
}, })
)
}) })
.unwrap_or(0) .unwrap_or(0)
})) }))

View file

@ -1,15 +0,0 @@
fn run<T: ToString>(
family: &str,
table: &str,
set: &str,
flush: bool,
items: impl IntoIterator<T>,
) {
let nft = libnftables1_sys::Nftables::new();
let mut cmd = String::new();
if flush {
cmd.push_str(&format!("flush set {family} {table} {set}"));
nft.run_cmd(c)
}
nft.set_numeric_time
}

View file

@ -1,38 +1,68 @@
#![allow(dead_code)] #![allow(dead_code)]
use crate::bindings::{ use crate::bindings::{
self, config_file, dns_msg, in6_addr, in6_addr__bindgen_ty_1, in_addr, infra_cache, key_cache, self, config_file, dns_msg, infra_cache, key_cache, lruhash_entry, module_env, module_ev,
lruhash_entry, module_env, module_ev, module_ext_state, module_qstate, outbound_entry, module_ext_state, module_qstate, outbound_entry, packed_rrset_data, packed_rrset_key,
packed_rrset_data, packed_rrset_key, query_info, reply_info, rrset_cache, rrset_id_type, query_info, reply_info, rrset_cache, rrset_id_type, rrset_trust, sec_status, slabhash,
rrset_trust, sec_status, slabhash, sldns_enum_ede_code, sockaddr_in, sockaddr_in6, sldns_enum_ede_code, ub_packed_rrset_key,
sockaddr_storage, ub_packed_rrset_key, AF_INET, AF_INET6,
}; };
use std::{ use std::{
ffi::CStr, marker::PhantomData, net::SocketAddr, ops::Deref, os::raw::c_char, ptr, ffi::CStr,
marker::PhantomData,
ops::Deref,
os::raw::{c_char, c_int},
ptr,
time::Duration, time::Duration,
}; };
pub struct ConfigFileMut<'a>( macro_rules! create_struct {
pub(crate) *mut config_file, ($ptr:tt, $name:tt, $mut:tt) => {
PhantomData<&'a mut config_file>, pub struct $name<'a>(pub(crate) *mut $ptr, pub(crate) PhantomData<&'a $ptr>);
); pub struct $mut<'a>(pub(crate) $name<'a>);
pub struct SlabHashMut<'a>(pub(crate) *mut slabhash, PhantomData<&'a mut slabhash>); impl<'a> Deref for $mut<'a> {
pub struct RrsetCacheMut<'a>( type Target = $name<'a>;
pub(crate) *mut rrset_cache, fn deref(&self) -> &Self::Target {
PhantomData<&'a mut rrset_cache>, &self.0
); }
pub struct InfraCacheMut<'a>( }
pub(crate) *mut infra_cache, impl<'a> $name<'a> {
PhantomData<&'a mut infra_cache>, pub const fn as_ptr(&self) -> *const $ptr {
); self.0.cast_const()
pub struct KeyCacheMut<'a>(pub(crate) *mut key_cache, PhantomData<&'a mut key_cache>); }
pub struct ModuleEnvMut<T>( pub unsafe fn from_raw(raw: *const $ptr) -> Option<Self> {
(!raw.is_null()).then_some(Self(raw.cast_mut(), PhantomData))
}
}
impl<'a> $mut<'a> {
pub fn as_mut_ptr(&mut self) -> *mut $ptr {
self.0 .0
}
pub unsafe fn from_raw(raw: *mut $ptr) -> Option<Self> {
(!raw.is_null()).then_some(Self($name(raw, PhantomData)))
}
}
};
}
create_struct!(config_file, ConfigFile, ConfigFileMut);
create_struct!(slabhash, SlabHash, SlabHashMut);
create_struct!(rrset_cache, RrsetCache, RrsetCacheMut);
create_struct!(infra_cache, InfraCache, InfraCacheMut);
create_struct!(key_cache, KeyCache, KeyCacheMut);
pub struct ModuleEnv<'a, T>(
pub(crate) *mut module_env, pub(crate) *mut module_env,
pub(crate) std::ffi::c_int, pub(crate) c_int,
pub(crate) PhantomData<T>, pub(crate) PhantomData<&'a T>,
); );
pub struct ModuleEnvMut<'a, T>(pub(crate) ModuleEnv<'a, T>);
impl<'a, T> Deref for ModuleEnvMut<'a, T> {
type Target = ModuleEnv<'a, T>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
pub struct ModuleQstate<'a, T>( pub struct ModuleQstate<'a, T>(
pub(crate) *mut module_qstate, pub(crate) *mut module_qstate,
pub(crate) std::ffi::c_int, pub(crate) c_int,
pub(crate) PhantomData<&'a mut T>, pub(crate) PhantomData<&'a mut T>,
); );
pub struct ModuleQstateMut<'a, T>(pub(crate) ModuleQstate<'a, T>); pub struct ModuleQstateMut<'a, T>(pub(crate) ModuleQstate<'a, T>);
@ -42,198 +72,176 @@ impl<'a, T> Deref for ModuleQstateMut<'a, T> {
&self.0 &self.0
} }
} }
pub struct OutboundEntryMut<'a>( create_struct!(outbound_entry, OutboundEntry, OutboundEntryMut);
pub(crate) *mut outbound_entry, create_struct!(query_info, QueryInfo, QueryInfoMut);
pub(crate) PhantomData<&'a mut outbound_entry>, create_struct!(dns_msg, DnsMsg, DnsMsgMut);
); create_struct!(reply_info, ReplyInfo, ReplyInfoMut);
pub struct QueryInfo<'a>( create_struct!(ub_packed_rrset_key, UbPackedRrsetKey, UbPackedRrsetKeyMut);
pub(crate) *mut query_info, create_struct!(lruhash_entry, LruHashEntry, LruHashEntryMut);
pub(crate) PhantomData<&'a mut query_info>, create_struct!(packed_rrset_key, PackedRrsetKey, PackedRrsetKeyMut);
); create_struct!(packed_rrset_data, PackedRrsetData, PackedRrsetDataMut);
pub struct QueryInfoMut<'a>(QueryInfo<'a>);
impl<'a> Deref for QueryInfoMut<'a> {
type Target = QueryInfo<'a>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
pub struct DnsMsg<'a>(
pub(crate) *mut dns_msg,
pub(crate) PhantomData<&'a mut dns_msg>,
);
pub struct DnsMsgMut<'a>(DnsMsg<'a>);
impl<'a> Deref for DnsMsgMut<'a> {
type Target = DnsMsg<'a>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
pub struct ReplyInfo<'a>(
pub(crate) *mut reply_info,
pub(crate) PhantomData<&'a mut reply_info>,
);
pub struct UbPackedRrsetKey<'a>(
pub(crate) *mut ub_packed_rrset_key,
pub(crate) PhantomData<&'a mut ub_packed_rrset_key>,
);
pub struct LruHashEntry<'a>(
pub(crate) *mut lruhash_entry,
pub(crate) PhantomData<&'a mut lruhash_entry>,
);
pub struct PackedRrsetKey<'a>(
pub(crate) *mut packed_rrset_key,
pub(crate) PhantomData<&'a mut packed_rrset_key>,
);
pub struct PackedRrsetData<'a>(
pub(crate) *mut packed_rrset_data,
pub(crate) PhantomData<&'a mut packed_rrset_data>,
);
impl<'a> QueryInfo<'a> { impl<'a> QueryInfo<'a> {
pub fn qname(&self) -> &CStr { pub fn qname(&self) -> &CStr {
unsafe { CStr::from_ptr((*self.0).qname as *const c_char) } unsafe { CStr::from_ptr((*self.as_ptr()).qname as *const c_char) }
} }
pub fn qtype(&self) -> u16 { pub fn qtype(&self) -> u16 {
unsafe { (*self.0).qtype } unsafe { (*self.as_ptr()).qtype }
} }
pub fn qclass(&self) -> u16 { pub fn qclass(&self) -> u16 {
unsafe { (*self.0).qclass } unsafe { (*self.as_ptr()).qclass }
} }
} }
impl<T> ModuleEnvMut<T> { impl<'a, T> ModuleEnv<'a, T> {
pub fn config_file_mut(&mut self) -> ConfigFileMut<'_> { pub unsafe fn from_raw(raw: *mut bindings::module_env, id: c_int) -> Option<Self> {
ConfigFileMut(unsafe { (*self.0).cfg }, Default::default()) (!raw.is_null()).then_some(Self(raw, id, PhantomData))
} }
pub fn msg_cache_mut(&mut self) -> SlabHashMut<'_> { pub const fn as_ptr(&self) -> *const module_env {
SlabHashMut(unsafe { (*self.0).msg_cache }, Default::default()) self.0.cast_const()
} }
pub fn rrset_cache_mut(&mut self) -> RrsetCacheMut<'_> { pub fn config_file(&self) -> ConfigFile<'_> {
RrsetCacheMut(unsafe { (*self.0).rrset_cache }, Default::default()) unsafe { ConfigFile::from_raw((*self.as_ptr()).cfg).unwrap() }
} }
pub fn infra_cache_mut(&mut self) -> InfraCacheMut<'_> { pub fn msg_cache(&self) -> SlabHash<'_> {
InfraCacheMut(unsafe { (*self.0).infra_cache }, Default::default()) unsafe { SlabHash::from_raw((*self.as_ptr()).msg_cache) }.unwrap()
} }
pub fn key_cache_mut(&mut self) -> KeyCacheMut<'_> { pub fn rrset_cache(&self) -> RrsetCache<'_> {
KeyCacheMut(unsafe { (*self.0).key_cache }, Default::default()) unsafe { RrsetCache::from_raw((*self.as_ptr()).rrset_cache) }.unwrap()
} }
#[allow(clippy::too_many_arguments)] pub fn infra_cache(&self) -> InfraCache<'_> {
pub fn send_query<Y>( unsafe { InfraCache::from_raw((*self.as_ptr()).infra_cache) }.unwrap()
&mut self,
qinfo: &QueryInfoMut,
flags: u16,
dnssec: u32,
want_dnssec: bool,
nocaps: bool,
check_ratelimit: bool,
addr: SocketAddr,
zone: &[u8],
tcp_upstream: bool,
ssl_upstream: bool,
tls_auth_name: Option<&CStr>,
q: &mut ModuleQstate<Y>,
) -> (Option<OutboundEntryMut<'_>>, bool) {
let mut was_ratelimited = 0;
let ret = unsafe {
let mut addr4 = sockaddr_in {
sin_port: 0,
sin_addr: in_addr { s_addr: 0 },
sin_zero: [0u8; 8],
sin_family: AF_INET as u16,
};
let mut addr6 = sockaddr_in6 {
sin6_port: 0,
sin6_addr: in6_addr {
__in6_u: in6_addr__bindgen_ty_1 {
__u6_addr8: [0u8; 16],
},
},
sin6_family: AF_INET6 as u16,
sin6_flowinfo: 0,
sin6_scope_id: 0,
};
let (addr, addr_len) = match addr {
SocketAddr::V4(x) => {
addr4.sin_port = x.port();
addr4.sin_addr.s_addr = (*x.ip()).into();
(
std::ptr::addr_of!(addr4).cast::<sockaddr_storage>(),
std::mem::size_of_val(&addr4),
)
} }
SocketAddr::V6(x) => { pub fn key_cache(&self) -> KeyCache<'_> {
addr6.sin6_addr.__in6_u.__u6_addr8 = x.ip().octets(); unsafe { KeyCache::from_raw((*self.as_ptr()).key_cache) }.unwrap()
addr6.sin6_flowinfo = x.flowinfo();
addr6.sin6_scope_id = x.scope_id();
(
std::ptr::addr_of!(addr6).cast(),
std::mem::size_of_val(&addr6),
)
}
};
((*self.0).send_query.unwrap_unchecked())(
qinfo.0 .0,
flags,
dnssec as i32,
want_dnssec.into(),
nocaps.into(),
check_ratelimit.into(),
addr.cast_mut(),
addr_len as u32,
zone.as_ptr().cast_mut(),
zone.len(),
tcp_upstream.into(),
ssl_upstream.into(),
tls_auth_name.map_or_else(ptr::null_mut, |x| x.as_ptr().cast_mut()),
q.0,
std::ptr::addr_of_mut!(was_ratelimited),
)
};
if ret.is_null() {
(None, was_ratelimited != 0)
} else {
(
Some(OutboundEntryMut(ret, Default::default())),
was_ratelimited != 0,
)
} }
} }
pub fn detach_subs<Y>(&mut self, qstate: &mut ModuleQstate<Y>) { impl<'a, T> ModuleEnvMut<'a, T> {
unsafe { (*self.0).detach_subs.unwrap_unchecked()(qstate.0) } pub unsafe fn from_raw(raw: *mut bindings::module_env, id: c_int) -> Option<Self> {
ModuleEnv::from_raw(raw, id).map(Self)
} }
unsafe fn attach_sub<Y>( pub fn as_mut_ptr(&mut self) -> *mut module_env {
&mut self, self.0 .0
qstate: &mut ModuleQstate<Y>,
qinfo: &QueryInfoMut,
qflags: u16,
prime: bool,
valrec: bool,
init_sub: impl FnOnce(*mut module_qstate) -> Result<(), ()>,
) -> Result<Option<ModuleQstate<'_, ()>>, ()> {
let mut newq: *mut module_qstate = ptr::null_mut();
let res = unsafe {
((*self.0).attach_sub.unwrap_unchecked())(
qstate.0,
qinfo.0 .0,
qflags,
prime.into(),
valrec.into(),
&mut newq as _,
)
};
if res != 0 {
Ok(if newq.is_null() {
None
} else if init_sub(newq).is_ok() {
Some(ModuleQstate(newq, qstate.1, Default::default()))
} else {
unsafe { ((*self.0).kill_sub.unwrap_unchecked())(newq) }
return Err(());
})
} else {
Err(())
} }
// FIXME: what lifetime to use?
// #[allow(clippy::too_many_arguments)]
// pub fn send_query<Y>(
// &mut self,
// qinfo: &QueryInfoMut,
// flags: u16,
// dnssec: u32,
// want_dnssec: bool,
// nocaps: bool,
// check_ratelimit: bool,
// addr: SocketAddr,
// zone: &[u8],
// tcp_upstream: bool,
// ssl_upstream: bool,
// tls_auth_name: Option<&CStr>,
// q: &mut ModuleQstate<Y>,
// ) -> (Option<OutboundEntryMut<'_>>, bool) {
// let mut was_ratelimited = 0;
// let ret = unsafe {
// let mut addr4 = sockaddr_in {
// sin_port: 0,
// sin_addr: in_addr { s_addr: 0 },
// sin_zero: [0u8; 8],
// sin_family: AF_INET as u16,
// };
// let mut addr6 = sockaddr_in6 {
// sin6_port: 0,
// sin6_addr: in6_addr {
// __in6_u: in6_addr__bindgen_ty_1 {
// __u6_addr8: [0u8; 16],
// },
// },
// sin6_family: AF_INET6 as u16,
// sin6_flowinfo: 0,
// sin6_scope_id: 0,
// };
// let (addr, addr_len) = match addr {
// SocketAddr::V4(x) => {
// addr4.sin_port = x.port();
// addr4.sin_addr.s_addr = (*x.ip()).into();
// (
// ptr::addr_of!(addr4).cast::<sockaddr_storage>(),
// mem::size_of_val(&addr4),
// )
// }
// SocketAddr::V6(x) => {
// addr6.sin6_addr.__in6_u.__u6_addr8 = x.ip().octets();
// addr6.sin6_flowinfo = x.flowinfo();
// addr6.sin6_scope_id = x.scope_id();
// (
// ptr::addr_of!(addr6).cast(),
// mem::size_of_val(&addr6),
// )
// }
// };
// ((*self.as_ptr()).send_query.unwrap_unchecked())(
// qinfo.as_ptr(),
// flags,
// dnssec as i32,
// want_dnssec.into(),
// nocaps.into(),
// check_ratelimit.into(),
// addr.cast_mut(),
// addr_len as u32,
// zone.as_ptr().cast_mut(),
// zone.len(),
// tcp_upstream.into(),
// ssl_upstream.into(),
// tls_auth_name.map_or_else(ptr::null_mut, |x| x.as_ptr().cast_mut()),
// q.as_ptr(),
// ptr::addr_of_mut!(was_ratelimited),
// )
// };
// if ret.is_null() {
// (None, was_ratelimited != 0)
// } else {
// (
// Some(OutboundEntryMut(OutboundEntry(ret, PhantomData))),
// was_ratelimited != 0,
// )
// }
// }
pub fn detach_subs<Y>(&mut self, qstate: &mut ModuleQstateMut<Y>) {
unsafe { (*self.as_ptr()).detach_subs.unwrap()(qstate.as_mut_ptr()) }
} }
// FIXME: what lifetime to use?
// unsafe fn attach_sub<Y>(
// &mut self,
// qstate: &mut ModuleQstate<Y>,
// qinfo: &QueryInfoMut,
// qflags: u16,
// prime: bool,
// valrec: bool,
// init_sub: impl FnOnce(*mut module_qstate) -> Result<(), ()>,
// ) -> Result<Option<ModuleQstate<'_, ()>>, ()> {
// let mut newq: *mut module_qstate = ptr::null_mut();
// let res = unsafe {
// ((*self.as_ptr()).attach_sub.unwrap_unchecked())(
// qstate.as_ptr(),
// qinfo.as_ptr(),
// qflags,
// prime.into(),
// valrec.into(),
// &mut newq as _,
// )
// };
// if res != 0 {
// Ok(if newq.is_null() {
// None
// } else if init_sub(newq).is_ok() {
// Some(ModuleQstate(newq, qstate.1, PhantomData))
// } else {
// unsafe { ((*self.as_ptr()).kill_sub.unwrap_unchecked())(newq) }
// return Err(());
// })
// } else {
// Err(())
// }
// }
// add_sub: TODO similar to above // add_sub: TODO similar to above
// detect_cycle: TODO // detect_cycle: TODO
// (note that &mut T is wrapped in dynmod stuff) // (note that &mut T is wrapped in dynmod stuff)
@ -241,21 +249,29 @@ impl<T> ModuleEnvMut<T> {
} }
impl<T> ModuleQstate<'_, T> { impl<T> ModuleQstate<'_, T> {
pub unsafe fn from_raw(raw: *mut bindings::module_qstate, id: c_int) -> Option<Self> {
(!raw.is_null()).then_some(Self(raw, id, PhantomData))
}
pub const fn as_ptr(&self) -> *const module_qstate {
self.0.cast_const()
}
pub fn qinfo(&self) -> QueryInfo<'_> { pub fn qinfo(&self) -> QueryInfo<'_> {
QueryInfo( unsafe { QueryInfo::from_raw(ptr::addr_of!((*self.as_ptr()).qinfo).cast_mut()).unwrap() }
unsafe { std::ptr::addr_of_mut!((*self.0).qinfo) },
Default::default(),
)
} }
pub fn return_msg(&self) -> Option<DnsMsg<'_>> { pub fn return_msg(&self) -> Option<DnsMsg<'_>> {
if unsafe { (*self.0).return_msg.is_null() } { unsafe { DnsMsg::from_raw((*self.as_ptr()).return_msg) }
None
} else {
Some(DnsMsg(unsafe { (*self.0).return_msg }, Default::default()))
} }
} }
pub(crate) fn check_id(id: i32) -> Option<usize> {
(id >= 0 && id < bindings::MAX_MODULE as i32).then_some(id as usize)
} }
impl<T> ModuleQstateMut<'_, T> { impl<T> ModuleQstateMut<'_, T> {
pub unsafe fn from_raw(raw: *mut bindings::module_qstate, id: c_int) -> Option<Self> {
ModuleQstate::from_raw(raw, id).map(Self)
}
pub fn as_mut_ptr(&mut self) -> *mut module_qstate {
self.0 .0
}
pub fn qinfo_mut(&mut self) -> QueryInfoMut<'_> { pub fn qinfo_mut(&mut self) -> QueryInfoMut<'_> {
QueryInfoMut(self.qinfo()) QueryInfoMut(self.qinfo())
} }
@ -264,158 +280,150 @@ impl<T> ModuleQstateMut<'_, T> {
} }
pub fn set_ext_state(&mut self, state: ModuleExtState) { pub fn set_ext_state(&mut self, state: ModuleExtState) {
unsafe { unsafe {
(*self.0 .0).ext_state[self.1 as usize] = state as module_ext_state; if let Some(id) = check_id(self.1) {
(*self.as_mut_ptr()).ext_state[id] = state as module_ext_state;
}
} }
} }
} }
impl DnsMsg<'_> { impl DnsMsg<'_> {
pub fn rep(&self) -> Option<ReplyInfo<'_>> { pub fn rep(&self) -> Option<ReplyInfo<'_>> {
if unsafe { (*self.0).rep.is_null() } { unsafe { ReplyInfo::from_raw((*self.as_ptr()).rep) }
None
} else {
Some(ReplyInfo(unsafe { (*self.0).rep }, Default::default()))
}
} }
} }
impl ReplyInfo<'_> { impl ReplyInfo<'_> {
pub fn flags(&self) -> u16 { pub fn flags(&self) -> u16 {
unsafe { (*self.0).flags } unsafe { (*self.as_ptr()).flags }
} }
pub fn authoritative(&self) -> bool { pub fn authoritative(&self) -> bool {
unsafe { (*self.0).authoritative != 0 } unsafe { (*self.as_ptr()).authoritative != 0 }
} }
pub fn qdcount(&self) -> u8 { pub fn qdcount(&self) -> u8 {
unsafe { (*self.0).qdcount } unsafe { (*self.as_ptr()).qdcount }
} }
pub fn padding(&self) -> u32 { pub fn padding(&self) -> u32 {
unsafe { (*self.0).padding } unsafe { (*self.as_ptr()).padding }
} }
pub fn ttl(&self) -> Option<Duration> { pub fn ttl(&self) -> Option<Duration> {
(unsafe { (*self.0).ttl }) (unsafe { (*self.as_ptr()).ttl })
.try_into() .try_into()
.map(Duration::from_secs) .map(Duration::from_secs)
.ok() .ok()
} }
pub fn prefetch_ttl(&self) -> Option<Duration> { pub fn prefetch_ttl(&self) -> Option<Duration> {
(unsafe { (*self.0).prefetch_ttl }) (unsafe { (*self.as_ptr()).prefetch_ttl })
.try_into() .try_into()
.map(Duration::from_secs) .map(Duration::from_secs)
.ok() .ok()
} }
pub fn serve_expired_ttl(&self) -> Option<Duration> { pub fn serve_expired_ttl(&self) -> Option<Duration> {
(unsafe { (*self.0).serve_expired_ttl }) (unsafe { (*self.as_ptr()).serve_expired_ttl })
.try_into() .try_into()
.map(Duration::from_secs) .map(Duration::from_secs)
.ok() .ok()
} }
pub fn security(&self) -> SecStatus { pub fn security(&self) -> SecStatus {
SecStatus::from(unsafe { (*self.0).security }) SecStatus::from(unsafe { (*self.as_ptr()).security })
} }
pub fn reason_bogus(&self) -> SldnsEdeCode { pub fn reason_bogus(&self) -> SldnsEdeCode {
SldnsEdeCode::from(unsafe { (*self.0).reason_bogus }) SldnsEdeCode::from(unsafe { (*self.as_ptr()).reason_bogus })
} }
pub fn reason_bogus_str(&self) -> Option<&CStr> { pub fn reason_bogus_str(&self) -> Option<&CStr> {
if unsafe { (*self.0).reason_bogus_str.is_null() } { if unsafe { (*self.as_ptr()).reason_bogus_str.is_null() } {
None None
} else { } else {
Some(unsafe { CStr::from_ptr((*self.0).reason_bogus_str) }) Some(unsafe { CStr::from_ptr((*self.as_ptr()).reason_bogus_str) })
} }
} }
pub fn an_numrrsets(&self) -> usize { pub fn an_numrrsets(&self) -> usize {
unsafe { (*self.0).an_numrrsets } unsafe { (*self.as_ptr()).an_numrrsets }
} }
pub fn ns_numrrsets(&self) -> usize { pub fn ns_numrrsets(&self) -> usize {
unsafe { (*self.0).ns_numrrsets } unsafe { (*self.as_ptr()).ns_numrrsets }
} }
pub fn ar_numrrsets(&self) -> usize { pub fn ar_numrrsets(&self) -> usize {
unsafe { (*self.0).ar_numrrsets } unsafe { (*self.as_ptr()).ar_numrrsets }
} }
pub fn rrset_count(&self) -> usize { pub fn rrset_count(&self) -> usize {
unsafe { (*self.0).rrset_count } unsafe { (*self.as_ptr()).rrset_count }
} }
pub fn rrsets(&self) -> impl '_ + Iterator<Item = UbPackedRrsetKey<'_>> { pub fn rrsets(&self) -> impl '_ + Iterator<Item = UbPackedRrsetKey<'_>> {
let total = self.rrset_count(); let total = self.rrset_count();
let rrsets = unsafe { (*self.0).rrsets }; let rrsets = unsafe { (*self.as_ptr()).rrsets };
(0..total).map(move |i| UbPackedRrsetKey(unsafe { *rrsets.add(i) }, Default::default())) (0..total).filter_map(move |i| unsafe { UbPackedRrsetKey::from_raw(*rrsets.add(i)) })
} }
} }
impl UbPackedRrsetKey<'_> { impl UbPackedRrsetKey<'_> {
pub fn entry(&self) -> LruHashEntry<'_> { pub fn entry(&self) -> LruHashEntry<'_> {
LruHashEntry( unsafe { LruHashEntry::from_raw(ptr::addr_of!((*self.as_ptr()).entry).cast_mut()).unwrap() }
unsafe { std::ptr::addr_of_mut!((*self.0).entry) },
Default::default(),
)
} }
pub fn id(&self) -> RrsetIdType { pub fn id(&self) -> RrsetIdType {
unsafe { (*self.0).id } unsafe { (*self.as_ptr()).id }
} }
pub fn rk(&self) -> PackedRrsetKey<'_> { pub fn rk(&self) -> PackedRrsetKey<'_> {
PackedRrsetKey( unsafe { PackedRrsetKey::from_raw(ptr::addr_of!((*self.as_ptr()).rk).cast_mut()).unwrap() }
unsafe { std::ptr::addr_of_mut!((*self.0).rk) },
Default::default(),
)
} }
} }
impl PackedRrsetKey<'_> { impl PackedRrsetKey<'_> {
pub fn dname(&self) -> Option<&'_ CStr> { pub fn dname(&self) -> Option<&'_ CStr> {
if unsafe { (*self.0).dname.is_null() } { if unsafe { (*self.as_ptr()).dname.is_null() } {
None None
} else { } else {
Some(unsafe { CStr::from_ptr((*self.0).dname as *const c_char) }) Some(unsafe { CStr::from_ptr((*self.as_ptr()).dname as *const c_char) })
} }
} }
pub fn flags(&self) -> u32 { pub fn flags(&self) -> u32 {
unsafe { (*self.0).flags } unsafe { (*self.as_ptr()).flags }
} }
pub fn type_(&self) -> u16 { pub fn type_(&self) -> u16 {
u16::from_be(unsafe { (*self.0).type_ }) u16::from_be(unsafe { (*self.as_ptr()).type_ })
} }
pub fn rrset_class(&self) -> u16 { pub fn rrset_class(&self) -> u16 {
u16::from_be(unsafe { (*self.0).rrset_class }) u16::from_be(unsafe { (*self.as_ptr()).rrset_class })
} }
} }
impl LruHashEntry<'_> { impl LruHashEntry<'_> {
pub fn data(&self) -> PackedRrsetData<'_> { pub fn data(&self) -> Option<PackedRrsetData<'_>> {
// FIXME: shouldnt pthread lock be used here? // FIXME: shouldnt pthread lock be used here?
unsafe { PackedRrsetData((*self.0).data.cast(), Default::default()) } unsafe { PackedRrsetData::from_raw((*self.as_ptr()).data.cast()) }
} }
} }
impl PackedRrsetData<'_> { impl PackedRrsetData<'_> {
pub fn ttl_add(&self) -> Option<Duration> { pub fn ttl_add(&self) -> Option<Duration> {
(unsafe { (*self.0).ttl_add }) (unsafe { (*self.as_ptr()).ttl_add })
.try_into() .try_into()
.map(Duration::from_secs) .map(Duration::from_secs)
.ok() .ok()
} }
pub fn ttl(&self) -> Option<Duration> { pub fn ttl(&self) -> Option<Duration> {
(unsafe { (*self.0).ttl }) (unsafe { (*self.as_ptr()).ttl })
.try_into() .try_into()
.map(Duration::from_secs) .map(Duration::from_secs)
.ok() .ok()
} }
pub fn count(&self) -> usize { pub fn count(&self) -> usize {
unsafe { (*self.0).count } unsafe { (*self.as_ptr()).count }
} }
pub fn rrsig_count(&self) -> usize { pub fn rrsig_count(&self) -> usize {
unsafe { (*self.0).rrsig_count } unsafe { (*self.as_ptr()).rrsig_count }
} }
pub fn trust(&self) -> RrsetTrust { pub fn trust(&self) -> RrsetTrust {
RrsetTrust::from(unsafe { (*self.0).trust }) RrsetTrust::from(unsafe { (*self.as_ptr()).trust })
} }
pub fn security(&self) -> SecStatus { pub fn security(&self) -> SecStatus {
SecStatus::from(unsafe { (*self.0).security }) SecStatus::from(unsafe { (*self.as_ptr()).security })
} }
pub fn rr_data(&self) -> impl '_ + Iterator<Item = (&[u8], Option<Duration>)> { pub fn rr_data(&self) -> impl '_ + Iterator<Item = (&[u8], Option<Duration>)> {
let total = self.count(); let total = self.count();
let ttl = unsafe { (*self.0).rr_ttl }; let ttl = unsafe { (*self.as_ptr()).rr_ttl };
let len = unsafe { (*self.0).rr_len }; let len = unsafe { (*self.as_ptr()).rr_len };
let data = unsafe { (*self.0).rr_data }; let data = unsafe { (*self.as_ptr()).rr_data };
(0..total).map(move |i| unsafe { (0..total).map(move |i| unsafe {
( (
std::slice::from_raw_parts(*data.add(i), *len.add(i)), std::slice::from_raw_parts(*data.add(i), *len.add(i)),
@ -426,8 +434,8 @@ impl PackedRrsetData<'_> {
pub fn rrsig_data(&self) -> impl '_ + Iterator<Item = &[u8]> { pub fn rrsig_data(&self) -> impl '_ + Iterator<Item = &[u8]> {
let total = self.count(); let total = self.count();
let total2 = self.rrsig_count(); let total2 = self.rrsig_count();
let len = unsafe { (*self.0).rr_len }; let len = unsafe { (*self.as_ptr()).rr_len };
let data = unsafe { (*self.0).rr_data }; let data = unsafe { (*self.as_ptr()).rr_data };
(total..total + total2) (total..total + total2)
.map(move |i| unsafe { std::slice::from_raw_parts(*data.add(i), *len.add(i)) }) .map(move |i| unsafe { std::slice::from_raw_parts(*data.add(i), *len.add(i)) })
} }