This commit is contained in:
chayleaf 2024-08-13 08:25:39 +07:00
parent 422001fc71
commit be0359f85f
Signed by: chayleaf
GPG key ID: 78171AD46227E68E
5 changed files with 66 additions and 110 deletions

View file

@ -151,7 +151,7 @@ impl<T> IpCache<T> {
let mut val = Some(ignore); let mut val = Some(ignore);
val = None; val = None;
val val
}) });
} }
fn get_maybe_update_rev<F: for<'a> FnOnce(&'a mut smallvec::SmallVec<[T; 4]>)>( fn get_maybe_update_rev<F: for<'a> FnOnce(&'a mut smallvec::SmallVec<[T; 4]>)>(
&self, &self,
@ -236,7 +236,7 @@ impl<T: FromStr> IpCache<T> {
let mut lock = self.0.write().unwrap(); let mut lock = self.0.write().unwrap();
assert!(lock.1.is_empty()); assert!(lock.1.is_empty());
let domains = std::fs::read_dir(dir)?; let domains = std::fs::read_dir(dir)?;
for entry in domains.filter_map(|x| x.ok()) { for entry in domains.filter_map(Result::ok) {
let domain = entry.file_name(); let domain = entry.file_name();
let Some(domain) = domain.to_str() else { let Some(domain) = domain.to_str() else {
continue; continue;
@ -324,7 +324,7 @@ impl ExampleMod {
Ok(ret) Ok(ret)
} }
fn load_json(&mut self, rulesets: &mut [(NftData<Ipv4Net>, NftData<Ipv6Net>)]) { fn load_json(&mut self, rulesets: &mut [(NftData<Ipv4Net>, NftData<Ipv6Net>)]) {
for (k, v) in self.nft_queries.iter_mut() { for (k, v) in &mut self.nft_queries {
let r = &mut rulesets[v.index]; let r = &mut rulesets[v.index];
let mut v_domains = v.domains.write().unwrap(); let mut v_domains = v.domains.write().unwrap();
for base in [CONFIG_PREFIX, DATA_PREFIX] { for base in [CONFIG_PREFIX, DATA_PREFIX] {
@ -333,13 +333,8 @@ impl ExampleMod {
match read_json::<Vec<String>>(file) { match read_json::<Vec<String>>(file) {
Ok(domains) => { Ok(domains) => {
for domain in domains { for domain in domains {
v_domains.insert( v_domains
domain .insert(domain.split('.').rev().map(|x| x.as_bytes().into()));
.split('.')
.rev()
.map(|x| x.as_bytes().into())
.collect::<SmallVec<[DomainSeg; 5]>>(),
);
} }
} }
Err(err) => Self::report2(&self.error_lock, "domains", err), Err(err) => Self::report2(&self.error_lock, "domains", err),
@ -350,13 +345,8 @@ impl ExampleMod {
match read_json::<Vec<DpiInfo>>(file) { match read_json::<Vec<DpiInfo>>(file) {
Ok(dpi_info) => { Ok(dpi_info) => {
for domain in dpi_info.iter().flat_map(|x| &x.domains) { for domain in dpi_info.iter().flat_map(|x| &x.domains) {
v_domains.insert( v_domains
domain .insert(domain.split('.').rev().map(|x| x.as_bytes().into()));
.split('.')
.rev()
.map(|x| x.as_bytes().into())
.collect::<SmallVec<[DomainSeg; 5]>>(),
);
} }
} }
Err(err) => Self::report2(&self.error_lock, "dpi", err), Err(err) => Self::report2(&self.error_lock, "dpi", err),
@ -409,7 +399,7 @@ impl ExampleMod {
if let Some(s) = std::env::var_os("NFT_QUERIES") { if let Some(s) = std::env::var_os("NFT_QUERIES") {
for (i, (name, set4, set6)) in s for (i, (name, set4, set6)) in s
.to_str() .to_str()
.map(|x| x.to_owned()) .map(ToOwned::to_owned)
.ok_or(())? .ok_or(())?
.split(';') .split(';')
.filter_map(|x| x.split_once(':')) .filter_map(|x| x.split_once(':'))
@ -418,11 +408,9 @@ impl ExampleMod {
}) })
.enumerate() .enumerate()
{ {
let (name, dynamic) = if let Some(name) = name.strip_suffix('!') { let (name, dynamic) = name
(name, true) .strip_suffix('!')
} else { .map_or((name, false), |name| (name, true));
(name, false)
};
self.nft_queries.insert( self.nft_queries.insert(
name.to_owned(), name.to_owned(),
NftQuery { NftQuery {
@ -545,7 +533,7 @@ impl ExampleMod {
.filter(|(a, _)| **a == token.as_bytes()) .filter(|(a, _)| **a == token.as_bytes())
.map(|(_, b)| b) .map(|(_, b)| b)
}) { }) {
for (qname, query) in self.nft_queries.iter() { for (qname, query) in &self.nft_queries {
if query.dynamic { if query.dynamic {
if let Some(split_domain) = split_domain if let Some(split_domain) = split_domain
.split_last() .split_last()
@ -606,7 +594,7 @@ impl ExampleMod {
.filter(|(a, _)| **a == token.as_bytes()) .filter(|(a, _)| **a == token.as_bytes())
.map(|(_, b)| b) .map(|(_, b)| b)
}) { }) {
for (qname, query) in self.nft_queries.iter() { for (qname, query) in &self.nft_queries {
if query.dynamic { if query.dynamic {
if let Some(split_domain) = split_domain if let Some(split_domain) = split_domain
.split_last() .split_last()

View file

@ -16,11 +16,7 @@ pub unsafe extern "C" fn init(
env: *mut module_env, env: *mut module_env,
id: ::std::os::raw::c_int, id: ::std::os::raw::c_int,
) -> ::std::os::raw::c_int { ) -> ::std::os::raw::c_int {
if let Some(fac) = crate::MODULE_FACTORY.take() { crate::MODULE_FACTORY.take().map_or(0, |fac| fac(env, id))
fac(env, id)
} else {
0
}
} }
/// Deinitialize module internals. /// Deinitialize module internals.
@ -33,7 +29,7 @@ pub unsafe extern "C" fn deinit(env: *mut module_env, id: ::std::os::raw::c_int)
} }
/// Perform action on pending query. Accepts a new query, or work on pending query. /// Perform action on pending query. Accepts a new query, or work on pending query.
/// You have to set qstate.ext_state on exit. /// You have to set `qstate.ext_state` on exit.
/// The state informs unbound about result and controls the following states. /// The state informs unbound about result and controls the following states.
/// ///
/// # Arguments /// # Arguments
@ -50,7 +46,7 @@ pub unsafe extern "C" fn operate(
entry: *mut outbound_entry, entry: *mut outbound_entry,
) { ) {
if let Some(module) = crate::module() { if let Some(module) = crate::module() {
module.internal_operate(qstate, event, id, entry) module.internal_operate(qstate, event, id, entry);
} }
} }
@ -69,7 +65,7 @@ pub unsafe extern "C" fn inform_super(
super_qstate: *mut module_qstate, super_qstate: *mut module_qstate,
) { ) {
if let Some(module) = crate::module() { if let Some(module) = crate::module() {
module.internal_inform_super(qstate, id, super_qstate) module.internal_inform_super(qstate, id, super_qstate);
} }
} }
@ -78,7 +74,7 @@ pub unsafe extern "C" fn inform_super(
#[no_mangle] #[no_mangle]
pub unsafe extern "C" fn clear(qstate: *mut module_qstate, id: ::std::os::raw::c_int) { pub unsafe extern "C" fn clear(qstate: *mut module_qstate, id: ::std::os::raw::c_int) {
if let Some(module) = crate::module() { if let Some(module) = crate::module() {
module.internal_clear(qstate, id) module.internal_clear(qstate, id);
} }
} }
@ -86,9 +82,7 @@ pub unsafe extern "C" fn clear(qstate: *mut module_qstate, id: ::std::os::raw::c
/// only happens explicitly and is only used to show memory usage to the user. /// only happens explicitly and is only used to show memory usage to the user.
#[no_mangle] #[no_mangle]
pub unsafe extern "C" fn get_mem(env: *mut module_env, id: ::std::os::raw::c_int) -> usize { pub unsafe extern "C" fn get_mem(env: *mut module_env, id: ::std::os::raw::c_int) -> usize {
crate::module() crate::module().map_or(0, |module| module.internal_get_mem(env, id))
.map(|module| module.internal_get_mem(env, id))
.unwrap_or(0)
} }
// function interface assertions // function interface assertions

View file

@ -9,10 +9,9 @@ use unbound::ModuleExtState;
non_snake_case, non_snake_case,
non_upper_case_globals, non_upper_case_globals,
unused_imports, unused_imports,
clippy::useless_transmute, clippy::all,
clippy::type_complexity, clippy::nursery,
clippy::too_many_arguments, clippy::pedantic
clippy::upper_case_acronyms
)] )]
mod bindings; mod bindings;
mod combine; mod combine;
@ -23,10 +22,6 @@ mod exports;
mod nftables; mod nftables;
mod unbound; mod unbound;
pub fn add(left: usize, right: usize) -> usize {
left + right
}
pub trait UnboundMod: Send + Sync + Sized + RefUnwindSafe + UnwindSafe { pub trait UnboundMod: Send + Sync + Sized + RefUnwindSafe + UnwindSafe {
type EnvData; type EnvData;
type QstateData; type QstateData;
@ -97,7 +92,7 @@ unsafe impl<T: UnboundMod> SealedUnboundMod for T {
id: ::std::os::raw::c_int, id: ::std::os::raw::c_int,
) { ) {
std::panic::catch_unwind(|| { std::panic::catch_unwind(|| {
self.deinit(&mut unbound::ModuleEnvMut(env, id, Default::default())) self.deinit(&mut unbound::ModuleEnvMut(env, id, Default::default()));
}) })
.unwrap_or(()); .unwrap_or(());
} }
@ -141,7 +136,7 @@ unsafe impl<T: UnboundMod> SealedUnboundMod for T {
-1, -1,
Default::default(), Default::default(),
)), )),
) );
}) })
.unwrap_or(()); .unwrap_or(());
} }
@ -155,7 +150,7 @@ unsafe impl<T: UnboundMod> SealedUnboundMod for T {
qstate, qstate,
id, id,
Default::default(), Default::default(),
))) )));
}) })
.unwrap_or(()); .unwrap_or(());
} }
@ -188,14 +183,13 @@ pub fn set_unbound_mod<T: 'static + UnboundMod>() {
MODULE_FACTORY MODULE_FACTORY
.set(Box::new(|env, id| { .set(Box::new(|env, id| {
std::panic::catch_unwind(|| { std::panic::catch_unwind(|| {
if let Ok(module) = T::init(&mut unbound::ModuleEnvMut(env, id, Default::default())).map_or(
T::init(&mut unbound::ModuleEnvMut(env, id, Default::default())) 0,
{ |module| {
MODULE.set(Box::new(module)).map_err(|_| ()).unwrap(); MODULE.set(Box::new(module)).map_err(|_| ()).unwrap();
1 1
} else { },
0 )
}
}) })
.unwrap_or(0) .unwrap_or(0)
})) }))
@ -203,14 +197,3 @@ pub fn set_unbound_mod<T: 'static + UnboundMod>() {
.unwrap(); .unwrap();
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn it_works() {
let result = add(2, 2);
assert_eq!(result, 4);
}
}

View file

@ -48,7 +48,7 @@ struct FlushSetMsg<'a> {
unsafe impl<'a> NlMsg for FlushSetMsg<'a> { unsafe impl<'a> NlMsg for FlushSetMsg<'a> {
unsafe fn write(&self, buf: *mut std::ffi::c_void, seq: u32, _msg_type: MsgType) { unsafe fn write(&self, buf: *mut std::ffi::c_void, seq: u32, _msg_type: MsgType) {
let header = nftnl_sys::nftnl_nlmsg_build_hdr( let header = nftnl_sys::nftnl_nlmsg_build_hdr(
buf as *mut c_char, buf.cast(),
libc::NFT_MSG_DELSETELEM as u16, libc::NFT_MSG_DELSETELEM as u16,
self.set.family() as u16, self.set.family() as u16,
0, 0,
@ -68,9 +68,7 @@ pub struct SetElemsIter<'a> {
impl<'a> SetElemsIter<'a> { impl<'a> SetElemsIter<'a> {
fn new(set: &'a Set1) -> Self { fn new(set: &'a Set1) -> Self {
let iter = unsafe { nftnl_sys::nftnl_set_elems_iter_create(set.as_mut_ptr()) }; let iter = unsafe { nftnl_sys::nftnl_set_elems_iter_create(set.as_mut_ptr()) };
if iter.is_null() { assert!(!iter.is_null(), "oom");
panic!("oom");
}
SetElemsIter { SetElemsIter {
set, set,
iter, iter,
@ -125,7 +123,7 @@ unsafe impl<'a> NlMsg for SetElemsMsg<'a> {
MsgType::Del => (libc::NFT_MSG_DELSETELEM, 0), MsgType::Del => (libc::NFT_MSG_DELSETELEM, 0),
}; };
let header = nftnl_sys::nftnl_nlmsg_build_hdr( let header = nftnl_sys::nftnl_nlmsg_build_hdr(
buf as *mut c_char, buf.cast(),
type_ as u16, type_ as u16,
self.set.family() as u16, self.set.family() as u16,
flags as u16, flags as u16,
@ -171,7 +169,7 @@ impl Set1 {
pub fn new() -> Self { pub fn new() -> Self {
Self(unsafe { nftnl_sys::nftnl_set_alloc() }) Self(unsafe { nftnl_sys::nftnl_set_alloc() })
} }
pub fn as_mut_ptr(&self) -> *mut nftnl_sys::nftnl_set { pub const fn as_mut_ptr(&self) -> *mut nftnl_sys::nftnl_set {
self.0 self.0
} }
pub fn table_name(&self) -> Option<&str> { pub fn table_name(&self) -> Option<&str> {
@ -195,30 +193,26 @@ impl Set1 {
let data1_len = data1.len() as u32; let data1_len = data1.len() as u32;
unsafe { unsafe {
let elem = nftnl_sys::nftnl_set_elem_alloc(); let elem = nftnl_sys::nftnl_set_elem_alloc();
if elem.is_null() { assert!(!elem.is_null(), "oom");
panic!("oom");
}
nftnl_sys::nftnl_set_elem_set( nftnl_sys::nftnl_set_elem_set(
elem, elem,
nftnl_sys::NFTNL_SET_ELEM_KEY as u16, nftnl_sys::NFTNL_SET_ELEM_KEY as u16,
data1.as_ptr() as *const c_void, data1.as_ptr().cast(),
data1_len, data1_len,
); );
nftnl_sys::nftnl_set_elem_add(self.as_mut_ptr(), elem); nftnl_sys::nftnl_set_elem_add(self.as_mut_ptr(), elem);
let Some(data2) = excl_upper.map(|key| key.data()) else { let Some(data2) = excl_upper.map(SetKey::data) else {
return; return;
}; };
let data2_len = data2.len() as u32; let data2_len = data2.len() as u32;
let elem = nftnl_sys::nftnl_set_elem_alloc(); let elem = nftnl_sys::nftnl_set_elem_alloc();
if elem.is_null() { assert!(!elem.is_null(), "oom");
panic!("oom");
}
nftnl_sys::nftnl_set_elem_set( nftnl_sys::nftnl_set_elem_set(
elem, elem,
nftnl_sys::NFTNL_SET_ELEM_KEY as u16, nftnl_sys::NFTNL_SET_ELEM_KEY as u16,
data2.as_ptr() as *const c_void, data2.as_ptr().cast(),
data2_len, data2_len,
); );
nftnl_sys::nftnl_set_elem_set_u32( nftnl_sys::nftnl_set_elem_set_u32(
@ -244,7 +238,7 @@ impl Set1 {
count += 1; count += 1;
batch.add(&set.flush_msg(), nftnl::MsgType::Del); batch.add(&set.flush_msg(), nftnl::MsgType::Del);
} }
for net in cidrs.into_iter() { for net in cidrs {
if count + 2 > max_batch_size { if count + 2 > max_batch_size {
batch.add_iter(SetElemsIter::new(&set), MsgType::Add); batch.add_iter(SetElemsIter::new(&set), MsgType::Add);
send_and_process(socket, &batch.finalize())?; send_and_process(socket, &batch.finalize())?;
@ -265,7 +259,7 @@ impl Set1 {
send_and_process(socket, &batch.finalize()) send_and_process(socket, &batch.finalize())
} }
fn flush_msg(&self) -> FlushSetMsg<'_> { const fn flush_msg(&self) -> FlushSetMsg<'_> {
FlushSetMsg { set: self } FlushSetMsg { set: self }
} }
} }
@ -282,7 +276,7 @@ pub fn get_sets(socket: &mnl::Socket) -> io::Result<Vec<Set1>> {
let mut ret = Vec::new(); let mut ret = Vec::new();
unsafe { unsafe {
nftnl_sys::nftnl_nlmsg_build_hdr( nftnl_sys::nftnl_nlmsg_build_hdr(
buffer.as_mut_ptr() as *mut c_char, buffer.as_mut_ptr().cast(),
libc::NFT_MSG_GETSET as u16, libc::NFT_MSG_GETSET as u16,
nftnl::ProtoFamily::Inet as u16, nftnl::ProtoFamily::Inet as u16,
(libc::NLM_F_DUMP | libc::NLM_F_ACK) as u16, (libc::NLM_F_DUMP | libc::NLM_F_ACK) as u16,
@ -335,11 +329,8 @@ fn should_add<T: Helper>(trie: &RTrieSet<T>, elem: &T) -> bool {
fn iter_ip_trie<T: Helper>(trie: &RTrieSet<T>) -> impl '_ + Iterator<Item = T> { fn iter_ip_trie<T: Helper>(trie: &RTrieSet<T>) -> impl '_ + Iterator<Item = T> {
trie.iter().copied().filter(|x| { trie.iter().copied().filter(|x| {
if let Some(par) = x.direct_parent() { x.direct_parent()
should_add(trie, &par) .map_or_else(|| *x != T::ZERO, |par| should_add(trie, &par))
} else {
*x != T::ZERO
}
}) })
} }
@ -482,11 +473,10 @@ pub(crate) fn nftables_thread(
println!("nftables init done"); println!("nftables init done");
first = false; first = false;
} }
let (rulesets1, ips) = match rx.recv() { let Ok((rulesets1, ips)) = rx.recv() else {
Ok(val) => val, break;
Err(_) => break,
}; };
for i in rulesets1.into_iter() { for i in rulesets1 {
let ruleset = &mut rulesets[i]; let ruleset = &mut rulesets[i];
for ip1 in ips.iter().copied() { for ip1 in ips.iter().copied() {
match ip1 { match ip1 {

View file

@ -157,7 +157,7 @@ impl<T> ModuleEnvMut<T> {
addr4.sin_port = x.port(); addr4.sin_port = x.port();
addr4.sin_addr.s_addr = (*x.ip()).into(); addr4.sin_addr.s_addr = (*x.ip()).into();
( (
&addr4 as *const _ as *const sockaddr_storage, std::ptr::addr_of!(addr4).cast::<sockaddr_storage>(),
std::mem::size_of_val(&addr4), std::mem::size_of_val(&addr4),
) )
} }
@ -166,29 +166,27 @@ impl<T> ModuleEnvMut<T> {
addr6.sin6_flowinfo = x.flowinfo(); addr6.sin6_flowinfo = x.flowinfo();
addr6.sin6_scope_id = x.scope_id(); addr6.sin6_scope_id = x.scope_id();
( (
&addr6 as *const _ as *const sockaddr_storage, std::ptr::addr_of!(addr6).cast(),
std::mem::size_of_val(&addr6), std::mem::size_of_val(&addr6),
) )
} }
}; };
((*self.0).send_query.unwrap_unchecked())( ((*self.0).send_query.unwrap_unchecked())(
&qinfo.0 as *const _ as *mut _, qinfo.0 .0,
flags, flags,
dnssec as i32, dnssec as i32,
want_dnssec.into(), want_dnssec.into(),
nocaps.into(), nocaps.into(),
check_ratelimit.into(), check_ratelimit.into(),
addr as *mut _, addr.cast_mut(),
addr_len as u32, addr_len as u32,
zone.as_ptr() as *mut _, zone.as_ptr().cast_mut(),
zone.len(), zone.len(),
tcp_upstream.into(), tcp_upstream.into(),
ssl_upstream.into(), ssl_upstream.into(),
tls_auth_name tls_auth_name.map_or_else(ptr::null_mut, |x| x.as_ptr().cast_mut()),
.map(|x| x.as_ptr() as *mut _)
.unwrap_or(ptr::null_mut()),
q.0, q.0,
&mut was_ratelimited as *mut _, std::ptr::addr_of_mut!(was_ratelimited),
) )
}; };
if ret.is_null() { if ret.is_null() {
@ -216,7 +214,7 @@ impl<T> ModuleEnvMut<T> {
let res = unsafe { let res = unsafe {
((*self.0).attach_sub.unwrap_unchecked())( ((*self.0).attach_sub.unwrap_unchecked())(
qstate.0, qstate.0,
&qinfo.0 as *const _ as *mut _, qinfo.0 .0,
qflags, qflags,
prime.into(), prime.into(),
valrec.into(), valrec.into(),
@ -245,7 +243,7 @@ impl<T> ModuleEnvMut<T> {
impl<T> ModuleQstate<'_, T> { impl<T> ModuleQstate<'_, T> {
pub fn qinfo(&self) -> QueryInfo<'_> { pub fn qinfo(&self) -> QueryInfo<'_> {
QueryInfo( QueryInfo(
unsafe { &mut (*self.0).qinfo as *mut query_info }, unsafe { std::ptr::addr_of_mut!((*self.0).qinfo) },
Default::default(), Default::default(),
) )
} }
@ -347,7 +345,7 @@ impl ReplyInfo<'_> {
impl UbPackedRrsetKey<'_> { impl UbPackedRrsetKey<'_> {
pub fn entry(&self) -> LruHashEntry<'_> { pub fn entry(&self) -> LruHashEntry<'_> {
LruHashEntry( LruHashEntry(
unsafe { &mut (*self.0).entry as *mut _ }, unsafe { std::ptr::addr_of_mut!((*self.0).entry) },
Default::default(), Default::default(),
) )
} }
@ -355,7 +353,10 @@ impl UbPackedRrsetKey<'_> {
unsafe { (*self.0).id } unsafe { (*self.0).id }
} }
pub fn rk(&self) -> PackedRrsetKey<'_> { pub fn rk(&self) -> PackedRrsetKey<'_> {
PackedRrsetKey(unsafe { &mut (*self.0).rk as *mut _ }, Default::default()) PackedRrsetKey(
unsafe { std::ptr::addr_of_mut!((*self.0).rk) },
Default::default(),
)
} }
} }
@ -381,7 +382,7 @@ impl PackedRrsetKey<'_> {
impl LruHashEntry<'_> { impl LruHashEntry<'_> {
pub fn data(&self) -> PackedRrsetData<'_> { pub fn data(&self) -> PackedRrsetData<'_> {
// FIXME: shouldnt pthread lock be used here? // FIXME: shouldnt pthread lock be used here?
unsafe { PackedRrsetData((*self.0).data as *mut packed_rrset_data, Default::default()) } unsafe { PackedRrsetData((*self.0).data.cast(), Default::default()) }
} }
} }
@ -639,8 +640,8 @@ pub enum ModuleExtState {
} }
impl ModuleExtState { impl ModuleExtState {
pub(crate) fn importance(&self) -> usize { pub(crate) const fn importance(self) -> usize {
match *self { match self {
Self::Unknown => 0, Self::Unknown => 0,
Self::InitialState => 1, Self::InitialState => 1,
Self::Finished => 2, Self::Finished => 2,