diff --git a/src/example.rs b/src/example.rs index 490a792..5cb1291 100644 --- a/src/example.rs +++ b/src/example.rs @@ -112,6 +112,7 @@ impl IpCacheKey { Self(split_rev_domain.fold(Domain::new(), |mut ret, seg| { if first { first = false; + } else { ret.push(b'.'); } ret.extend_from_slice(seg.as_ref()); @@ -389,10 +390,10 @@ impl ExampleMod { } fn load_env(&mut self) -> Result, NftData)>, ()> { self.nft_token = std::env::var_os("NFT_TOKEN") - .map(|x| x.to_str().ok_or(()).map(|s| s.to_owned() + ".")) + .map(|x| x.to_str().ok_or(()).map(|s| s.to_owned())) .transpose()?; self.tmp_nft_token = std::env::var_os("NFT_TOKEN") - .map(|x| x.to_str().ok_or(()).map(|s| format!("tmp{s}."))) + .map(|x| x.to_str().ok_or(()).map(|s| format!("tmp{s}"))) .transpose()?; let mut rulesets = Vec::new(); assert!(self.nft_queries.is_empty()); @@ -492,12 +493,15 @@ impl ExampleMod { ) -> Result<(), ()> { println!("adding {ip4:?}/{ip6:?} for {split_domain:?} to {qnames:?}"); if !ip4.is_empty() || !ip6.is_empty() { + let mut first = true; let domain = match split_domain .iter() .copied() .map(std::str::from_utf8) .try_fold(String::new(), |mut s, comp| { - if !s.is_empty() { + if first { + first = false; + } else { s.push('.'); } s.push_str(comp?); @@ -527,6 +531,7 @@ impl ExampleMod { Ok(()) } fn run_commands(&self, split_domain: &[&[u8]]) -> Option { + println!("{split_domain:?} {:?}", self.nft_token); if let Some(split_domain) = self.nft_token.as_ref().and_then(|token| { split_domain .split_last() @@ -544,12 +549,15 @@ impl ExampleMod { if domains.insert(split_domain.iter().copied().rev().map(From::from)) { drop(domains); let file_name = format!("{DATA_PREFIX}/{qname}_domains.json"); + let mut first = false; let domain = match split_domain .iter() .copied() .map(std::str::from_utf8) .try_fold(String::new(), |mut s, comp| { - if !s.is_empty() { + if !first { + first = true; + } else { s.push('.'); } s.push_str(comp?); @@ -584,10 +592,11 @@ impl ExampleMod { Err(err) => self.report("domains create", err), } } + return Some(ModuleExtState::Finished); } } } - return Some(ModuleExtState::Finished); + return Some(ModuleExtState::Error); } else if let Some(split_domain) = self.tmp_nft_token.as_ref().and_then(|token| { split_domain .split_last() @@ -603,10 +612,11 @@ impl ExampleMod { { let mut domains = query.domains.write().unwrap(); domains.insert(split_domain.iter().copied().rev().map(From::from)); + return Some(ModuleExtState::Finished); } } } - return Some(ModuleExtState::Finished); + return Some(ModuleExtState::Error); } None } @@ -709,9 +719,7 @@ impl UnboundMod for ExampleMod { } let info = qstate.qinfo(); let name = info.qname().to_bytes(); - // let rev_domain = name.strip_suffix(b".").unwrap_or(name); let split_domain = unwire_domain(name); - println!("handling {split_domain:?}"); if let Some(val) = self.run_commands(&split_domain) { return Some(val); } @@ -742,7 +750,10 @@ mod test { use ipnet::IpNet; use smallvec::smallvec; - use crate::example::{ignore, ExampleMod, IpCacheKey, IpNetDeser, DATA_PREFIX}; + use crate::{ + example::{ignore, ExampleMod, IpCacheKey, IpNetDeser, DATA_PREFIX}, + unbound::ModuleExtState, + }; #[test] fn test() { @@ -860,13 +871,25 @@ mod test { ) .unwrap(); - t.run_commands(&[&b"w"[..], &b"com"[..], &b"q"[..], &b"token"[..]]) - .unwrap(); - t.run_commands(&[&b"e"[..], &b"com"[..], &b"q"[..], &b"tmptoken"[..]]) - .unwrap(); - assert!(t - .run_commands(&[&b"e"[..], &b"com"[..], &b"w"[..], &b"tmptoken"[..]]) - .is_none()); + assert_eq!( + t.run_commands(&[&b"w"[..], &b"com"[..], &b"q"[..], &b"token"[..]]) + .unwrap(), + ModuleExtState::Finished + ); + assert_eq!( + t.run_commands(&[&b"w"[..], &b"com"[..], &b"q"[..], &b"wrongtoken"[..]]), + None + ); + assert_eq!( + t.run_commands(&[&b"e"[..], &b"com"[..], &b"q"[..], &b"tmptoken"[..]]) + .unwrap(), + ModuleExtState::Finished + ); + assert_eq!( + t.run_commands(&[&b"e"[..], &b"com"[..], &b"w"[..], &b"tmptoken"[..]]) + .unwrap(), + ModuleExtState::Error + ); let split_domain = [&b"e"[..], &b"com"[..]]; let qnames = t.get_qnames(&split_domain); diff --git a/src/nftables.rs b/src/nftables.rs index b25423b..e3a7435 100644 --- a/src/nftables.rs +++ b/src/nftables.rs @@ -4,10 +4,7 @@ use std::{ fmt::Display, io::{self, Write}, net::{Ipv4Addr, Ipv6Addr}, - os::{ - fd::BorrowedFd, - raw::{c_char, c_void}, - }, + os::{fd::BorrowedFd, raw::c_void}, rc::Rc, sync::mpsc, }; @@ -46,7 +43,7 @@ struct FlushSetMsg<'a> { set: &'a Set1, } unsafe impl<'a> NlMsg for FlushSetMsg<'a> { - unsafe fn write(&self, buf: *mut std::ffi::c_void, seq: u32, _msg_type: MsgType) { + unsafe fn write(&self, buf: *mut c_void, seq: u32, _msg_type: MsgType) { let header = nftnl_sys::nftnl_nlmsg_build_hdr( buf.cast(), libc::NFT_MSG_DELSETELEM as u16,