From e9a6f296dfef8b4382029fdb409038e77891bb19 Mon Sep 17 00:00:00 2001 From: chayleaf Date: Tue, 13 Aug 2024 05:41:57 +0700 Subject: [PATCH] bugfixes --- .gitignore | 2 + Cargo.toml | 2 +- src/domain_tree.rs | 1 + src/example.rs | 895 +++++++++++++++++++++++---------------------- src/lib.rs | 1 + src/nftables.rs | 206 ++++++++++- 6 files changed, 668 insertions(+), 439 deletions(-) diff --git a/.gitignore b/.gitignore index d787b70..ba48bdf 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ /target /result +/unbound-mod-test-config +/unbound-mod-test-data diff --git a/Cargo.toml b/Cargo.toml index 0f50c38..2352cfe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ iptrie = "0.8.5" libc = "0.2.155" mnl = { version = "0.2.2", features = ["mnl-1-0-4"] } nftnl = { version = "0.6.2", features = ["nftnl-1-1-2"] } -nix = { version = "0.29.0", features = ["poll"] } +nix = { version = "0.29.0", features = ["poll", "user"] } radix_trie = "0.2.1" serde = { version = "1.0.205", features = ["derive"] } serde_json = "1.0.122" diff --git a/src/domain_tree.rs b/src/domain_tree.rs index cfebae9..67d67a3 100644 --- a/src/domain_tree.rs +++ b/src/domain_tree.rs @@ -2,6 +2,7 @@ use std::{collections::HashMap, hash::Hash}; use smallvec::{smallvec, SmallVec}; +#[derive(Debug)] pub enum PrefixSet { Map(HashMap>), Leaf, diff --git a/src/example.rs b/src/example.rs index fa491c0..e1485cb 100644 --- a/src/example.rs +++ b/src/example.rs @@ -8,15 +8,13 @@ use std::{ str::FromStr, sync::{ atomic::{AtomicBool, Ordering}, - mpsc::{self, RecvError}, - Mutex, RwLock, + mpsc, Mutex, RwLock, }, - time::{Duration, SystemTime}, }; use ctor::ctor; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; -use iptrie::{IpPrefix, IpRootPrefix, RTrieSet}; +use iptrie::{IpPrefix, RTrieSet}; use serde::{ de::{Error, Visitor}, Deserialize, @@ -25,8 +23,8 @@ use smallvec::SmallVec; use crate::{ domain_tree::PrefixSet, - nftables::Set1, - unbound::{rr_class, rr_type, ModuleEvent, ModuleExtState}, + nftables::{nftables_thread, NftData}, + unbound::{rr_class, rr_type, ModuleEvent, ModuleExtState, ReplyInfo}, UnboundMod, }; @@ -83,18 +81,15 @@ impl<'de> Deserialize<'de> for IpNetDeser { #[derive(Default)] struct ExampleMod { - domain_name_overrides: HashMap, nft_token: Option, tmp_nft_token: Option, nft_queries: HashMap, caches: (IpCache, IpCache), - #[allow(clippy::type_complexity)] ruleset_queue: Option, smallvec::SmallVec<[IpNet; 8]>)>>, error_lock: Mutex<()>, domains_write_lock: Mutex<()>, } -#[allow(clippy::type_complexity)] struct IpCache( RwLock<( radix_trie::Trie, @@ -120,6 +115,7 @@ impl Default for IpCache { ) } } +fn ignore(_: &mut smallvec::SmallVec<[T; 4]>) {} impl IpCache { fn extend_set_with_domain>(&self, ips: &mut RTrieSet, domain_r: Domain) @@ -130,9 +126,8 @@ impl IpCache { if let Some(val) = val { ips.extend(val.0.iter().copied().map(|x| J::from(x))); } - fn ignore(_: &mut smallvec::SmallVec<[T; 4]>) {} #[allow(unused_assignments)] - let mut val = Some(ignore::); + let mut val = Some(ignore); val = None; val }) @@ -184,13 +179,18 @@ impl IpCache { .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed) .is_ok() { - let _ = filetime::set_file_mtime(path, filetime::FileTime::now()); + let _res = filetime::set_file_mtime(&path, filetime::FileTime::now()); + #[cfg(test)] + _res.unwrap(); } return None; } Some(|ips: &mut SmallVec<_>| { let Ok(mut file) = File::create(path) else { *ips = val; + #[cfg(test)] + panic!(); + #[cfg(not(test))] return; }; let to_write = val.iter().fold(String::new(), |mut s, ip| { @@ -211,6 +211,7 @@ impl IpCache { impl IpCache { fn load(&mut self, dir: &Path) -> Result<(), io::Error> { println!("loading {dir:?}"); + self.1 = dir.to_owned(); std::fs::create_dir_all(dir)?; let mut lock = self.0.write().unwrap(); assert!(lock.1.is_empty()); @@ -220,16 +221,16 @@ impl IpCache { let Some(domain) = domain.to_str() else { continue; }; - if let Some(age) = entry + /*if let Some(age) = entry .metadata() .and_then(|x| x.modified()) .ok() - .and_then(|x| SystemTime::now().duration_since(x).ok()) + .and_then(|x| std::time::SystemTime::now().duration_since(x).ok()) { - if age > Duration::from_secs(60 * 60 * 24 * 7) { + if age > std::time::Duration::from_secs(60 * 60 * 24 * 7) { continue; } - } + }*/ let Ok(reader) = std::fs::File::open(entry.path()) else { continue; }; @@ -266,230 +267,55 @@ impl IpCache { } } -struct NftData { - ips: RTrieSet, - dirty: bool, - set: Option, - name: String, -} - -impl NftData -where - IpNet: From, -{ - #[must_use] - fn verify(&mut self) -> bool { - if !self.name.is_empty() && self.set.is_none() { - self.ips = RTrieSet::new(); - false - } else { - true - } - } - fn flush_changes(&mut self, socket: &mnl::Socket, flush_set: bool) -> Result<(), io::Error> { - if let Some(set) = self.set.as_mut().filter(|_| self.dirty) { - if flush_set { - println!( - "initializing set {} with ~{} ips (e.g. {:?})", - self.name, - self.ips.len(), - iter_ip_trie(&self.ips).next(), - ); - } - set.add_cidrs(socket, flush_set, iter_ip_trie(&self.ips).map(IpNet::from)) - } else { - Ok(()) - } - } - fn extend(&mut self, ips: impl Iterator) { - for ip in ips { - self.insert(ip); - } - } - fn insert(&mut self, ip: T) { - if self.set.is_some() && should_add(&self.ips, &ip) { - self.ips.insert(ip); - self.dirty = true; - } - } -} - -// SAFETY: set are None initially and are never actually sent -// (and Set1 might be fine to send anyway actually) -unsafe impl Send for NftData {} - +#[derive(Debug)] struct NftQuery { domains: RwLock>, dynamic: bool, index: usize, } +#[cfg(debug_assertions)] +pub(crate) const DATA_PREFIX: &str = "unbound-mod-test-data"; +#[cfg(debug_assertions)] +pub(crate) const CONFIG_PREFIX: &str = "unbound-mod-test-config"; + +#[cfg(not(debug_assertions))] +pub(crate) const PATH_PREFIX: &str = "/var/lib/unbound"; +#[cfg(not(debug_assertions))] +pub(crate) const CONFIG_PREFIX: &str = "/etc/unbound"; + impl ExampleMod { - fn report(&self, code: &str, err: impl Display) { - println!("{code}: {err}"); - if let Ok(mut file) = std::fs::OpenOptions::new() - .append(true) - .open("/var/lib/unbound/error.log") - { - let _lock = self.error_lock.lock(); - if file.write_all(code.as_bytes()).is_err() { - return; - } - if file.write_all(b": ").is_err() { - return; - } - if file.write_all(err.to_string().as_bytes()).is_err() { - return; - } - if file.write_all(b"\n").is_err() { - return; - } - file.flush().unwrap_or(()); - } - } -} + fn new() -> Result { + let mut ret = Self::default(); + let mut rulesets = ret.load_env()?; -#[derive(Deserialize)] -struct DpiInfo { - domains: Vec, - // name: String, - // restriction: {"code": "ban"} -} - -trait Helper: iptrie::IpPrefix + PartialEq { - const ZERO: Self; - fn direct_parent(&self) -> Option; -} - -impl Helper for Ipv4Net { - const ZERO: Self = match Self::new(Ipv4Addr::UNSPECIFIED, 0) { - Ok(x) => x, - #[allow(clippy::empty_loop)] - Err(_) => loop {}, - }; - fn direct_parent(&self) -> Option { - self.len() - .checked_sub(1) - .and_then(|x| Self::new(self.bitslot().into(), x).ok()) - } -} - -impl Helper for Ipv6Net { - const ZERO: Self = match Self::new(Ipv6Addr::UNSPECIFIED, 0) { - Ok(x) => x, - #[allow(clippy::empty_loop)] - Err(_) => loop {}, - }; - fn direct_parent(&self) -> Option { - self.len() - .checked_sub(1) - .and_then(|x| Self::new(self.bitslot().into(), x).ok()) - } -} - -fn should_add(trie: &RTrieSet, elem: &T) -> bool { - *trie.lookup(elem) == T::ZERO -} - -fn iter_ip_trie(trie: &RTrieSet) -> impl '_ + Iterator { - trie.iter().copied().filter(|x| { - if let Some(par) = x.direct_parent() { - should_add(trie, &par) - } else { - *x != T::ZERO - } - }) -} - -fn read_json Deserialize<'a>>(mut f: File) -> Result { - let mut data = Vec::new(); - f.read_to_end(&mut data) - .map_err(serde_json::Error::custom)?; - serde_json::from_slice(&data) -} - -impl UnboundMod for ExampleMod { - type EnvData = (); - type QstateData = (); - fn init(_env: &mut crate::unbound::ModuleEnv) -> Result { - let mut ret = Self { - nft_token: std::env::var_os("NFT_TOKEN") - .map(|x| x.to_str().ok_or(()).map(|s| s.to_owned() + ".")) - .transpose()?, - tmp_nft_token: std::env::var_os("NFT_TOKEN") - .map(|x| x.to_str().ok_or(()).map(|s| s.to_owned() + ".tmp.")) - .transpose()?, - ..Self::default() - }; - if let Some(s) = std::env::var_os("DOMAIN_NAME_OVERRIDES") { - for (k, v) in s - .to_str() - .map(|x| x.to_owned()) - .ok_or(())? - .split(';') - .filter_map(|x| x.split_once("->")) - { - ret.domain_name_overrides - .insert(k.as_bytes().into(), v.as_bytes().into()); - } - } - let mut nft_queries = HashMap::new(); - let mut rulesets = Vec::new(); - if let Some(s) = std::env::var_os("NFT_QUERIES") { - for (i, (name, set4, set6)) in s - .to_str() - .map(|x| x.to_owned()) - .ok_or(())? - .split(';') - .filter_map(|x| x.split_once(':')) - .filter_map(|(name, sets)| { - sets.split_once(',').map(|(set4, set6)| (name, set4, set6)) - }) - .enumerate() - { - let (name, dynamic) = if let Some(name) = name.strip_suffix('!') { - (name, true) - } else { - (name, false) - }; - nft_queries.insert( - name.to_owned(), - NftQuery { - domains: RwLock::new(PrefixSet::new()), - dynamic, - index: i, - }, - ); - rulesets.push(( - NftData { - set: None, - ips: RTrieSet::new(), - dirty: true, - name: set4.to_owned(), - }, - NftData { - set: None, - ips: RTrieSet::new(), - dirty: true, - name: set6.to_owned(), - }, - )); - } - } - - // load cached domains - if let Err(err) = ret.caches.0.load(Path::new("/var/lib/unbound/domains4/")) { + let mut base_path = PathBuf::from_str(DATA_PREFIX).unwrap(); + base_path.push("domains4"); + if let Err(err) = ret.caches.0.load(&base_path) { ret.report("domains4", err); } - if let Err(err) = ret.caches.1.load(Path::new("/var/lib/unbound/domains6/")) { + base_path.pop(); + base_path.push("domains6"); + if let Err(err) = ret.caches.1.load(&base_path) { ret.report("domains6", err); } - // load json files - for (k, v) in nft_queries.iter_mut() { + ret.load_json(&mut rulesets); + + // it takes like 10 seconds to initialize nftables, so move it to a separate thread + let (tx, rx) = mpsc::channel(); + ret.ruleset_queue = Some(tx); + std::thread::spawn(move || nftables_thread(rulesets, rx)); + + println!("loaded"); + + Ok(ret) + } + fn load_json(&mut self, rulesets: &mut [(NftData, NftData)]) { + for (k, v) in self.nft_queries.iter_mut() { let r = &mut rulesets[v.index]; let mut v_domains = v.domains.write().unwrap(); - for base in ["/etc/unbound", "/var/lib/unbound"] { + for base in [CONFIG_PREFIX, DATA_PREFIX] { if let Ok(file) = std::fs::File::open(format!("{base}/{k}_domains.json")) { println!("loading {base}/{k}_domains.json"); match read_json::>(file) { @@ -504,7 +330,7 @@ impl UnboundMod for ExampleMod { ); } } - Err(err) => ret.report("domains", err), + Err(err) => Self::report2(&self.error_lock, "domains", err), } } if let Ok(file) = std::fs::File::open(format!("{base}/{k}_dpi.json")) { @@ -521,7 +347,7 @@ impl UnboundMod for ExampleMod { ); } } - Err(err) => ret.report("dpi", err), + Err(err) => Self::report2(&self.error_lock, "dpi", err), } } if let Ok(file) = std::fs::File::open(format!("{base}/{k}_ips.json")) { @@ -543,7 +369,7 @@ impl UnboundMod for ExampleMod { } })); } - Err(err) => ret.report("ips", err), + Err(err) => Self::report2(&self.error_lock, "ips", err), } } } @@ -554,122 +380,177 @@ impl UnboundMod for ExampleMod { .collect::>() .join(&b"."[..]) .into(); - ret.caches + self.caches .0 - .extend_set_with_domain(&mut r.0.ips, rev_domain.clone()); - ret.caches + .extend_set_with_domain(r.0.ips_mut(), rev_domain.clone()); + self.caches .1 - .extend_set_with_domain(&mut r.1.ips, rev_domain.clone()); + .extend_set_with_domain(r.1.ips_mut(), rev_domain.clone()); } } - - // add stuff to nftables - let (tx, rx) = mpsc::channel(); - - ret.ruleset_queue = Some(tx); - - std::thread::spawn(move || { - fn report(err: impl Display) { - println!("nftables: {err}"); - if let Ok(mut file) = std::fs::OpenOptions::new() - .append(true) - .open("/var/lib/unbound/nftables.log") - { - if file.write_all(err.to_string().as_bytes()).is_err() { - return; - } - if file.write_all(b"\n").is_err() { - return; - } - file.flush().unwrap_or(()); - } - } - let socket = mnl::Socket::new(mnl::Bus::Netfilter).unwrap(); - let all_sets = crate::nftables::get_sets(&socket).unwrap(); - for set in all_sets { - for ruleset in &mut rulesets { - if set.table_name() == Some("global") - && set.family() == libc::NFPROTO_INET as u32 - { - if set.name() == Some(&ruleset.0.name) { - println!("found set {}", ruleset.0.name); - ruleset.0.set = Some(set.clone()); - } else if set.name() == Some(&ruleset.1.name) { - println!("found set {}", ruleset.1.name); - ruleset.1.set = Some(set.clone()); - } - } - } - } - for ruleset in &mut rulesets { - if !ruleset.0.verify() { - report(format!("set {} not found", ruleset.0.name)); - } - if !ruleset.1.verify() { - report(format!("set {} not found", ruleset.1.name)); - } - } - let mut first = true; - loop { - for ruleset in &mut rulesets { - if let Err(err) = ruleset.0.flush_changes(&socket, first) { - report(err); - } - if let Err(err) = ruleset.1.flush_changes(&socket, first) { - report(err); - } - } - if first { - println!("nftables init done"); - first = false; - } - let res = match rx.recv() { - Ok(val) => Some(val), - Err(RecvError) => break, - }; - if let Some((rulesets1, ips)) = res { - for i in rulesets1.into_iter() { - let ruleset = &mut rulesets[i]; - for ip1 in ips.iter().copied() { - match ip1 { - IpNet::V4(ip) => ruleset.0.insert(ip), - IpNet::V6(ip) => ruleset.1.insert(ip), - } - } - } - } - } - }); - println!("loaded"); - - Ok(ret) } - - fn operate( - &self, - qstate: &mut crate::unbound::ModuleQstate, - event: ModuleEvent, - _entry: &mut crate::unbound::OutboundEntryMut, - ) -> Option { - match event { - ModuleEvent::New | ModuleEvent::Pass => { - return Some(ModuleExtState::WaitModule); - } - ModuleEvent::ModDone => {} - _ => { - return Some(ModuleExtState::Error); + fn load_env(&mut self) -> Result, NftData)>, ()> { + self.nft_token = std::env::var_os("NFT_TOKEN") + .map(|x| x.to_str().ok_or(()).map(|s| s.to_owned() + ".")) + .transpose()?; + self.tmp_nft_token = std::env::var_os("NFT_TOKEN") + .map(|x| x.to_str().ok_or(()).map(|s| format!("tmp{s}."))) + .transpose()?; + let mut rulesets = Vec::new(); + assert!(self.nft_queries.is_empty()); + if let Some(s) = std::env::var_os("NFT_QUERIES") { + for (i, (name, set4, set6)) in s + .to_str() + .map(|x| x.to_owned()) + .ok_or(())? + .split(';') + .filter_map(|x| x.split_once(':')) + .filter_map(|(name, sets)| { + sets.split_once(',').map(|(set4, set6)| (name, set4, set6)) + }) + .enumerate() + { + let (name, dynamic) = if let Some(name) = name.strip_suffix('!') { + (name, true) + } else { + (name, false) + }; + self.nft_queries.insert( + name.to_owned(), + NftQuery { + domains: RwLock::new(PrefixSet::new()), + dynamic, + index: i, + }, + ); + rulesets.push((NftData::new(set4), NftData::new(set6))); } } - let info = qstate.qinfo_mut(); - let name = info.qname().to_bytes(); - let rev_domain = name.strip_suffix(b".").unwrap_or(name); + Ok(rulesets) + } + fn report2(error_lock: &Mutex<()>, code: &str, err: impl Display) { + println!("{code}: {err}"); + if let Ok(mut file) = std::fs::OpenOptions::new() + .append(true) + .create(true) + .open(format!("{DATA_PREFIX}/error.log")) + { + let _lock = error_lock.lock(); + if file.write_all(code.as_bytes()).is_err() { + return; + } + if file.write_all(b": ").is_err() { + return; + } + if file.write_all(err.to_string().as_bytes()).is_err() { + return; + } + if file.write_all(b"\n").is_err() { + return; + } + file.flush().unwrap_or(()); + } + } + fn report(&self, code: &str, err: impl Display) { + Self::report2(&self.error_lock, code, err); + } + fn handle_reply_info( + &self, + split_rev_domain: SmallVec<[DomainSeg; 5]>, + qnames: SmallVec<[usize; 5]>, + rep: &ReplyInfo<'_>, + ) -> Result<(), ()> { + let mut ip4: SmallVec<[Ipv4Addr; 4]> = SmallVec::new(); + let mut ip6: SmallVec<[Ipv6Addr; 4]> = SmallVec::new(); + for rrset in rep.rrsets() { + let entry = rrset.entry(); + let d = entry.data(); + let rk = rrset.rk(); + if rk.rrset_class() != rr_class::IN { + continue; + } + for (data, _ttl) in d.rr_data() { + match rk.type_() { + rr_type::A if data.len() == 2 + 4 && &data[..2] == b"\0\x04" => { + ip4.push(Ipv4Addr::from( + <[u8; 4]>::try_from(&data[2..2 + 4]).unwrap(), + )); + } + rr_type::AAAA if data.len() == 2 + 16 && &data[..2] == b"\0\x10" => { + ip6.push(Ipv6Addr::from( + <[u8; 16]>::try_from(&data[2..2 + 16]).unwrap(), + )); + } + _ => {} + } + } + } + self.add_ips(ip4, ip6, split_rev_domain, qnames) + } + fn add_ips( + &self, + ip4: SmallVec<[Ipv4Addr; 4]>, + ip6: SmallVec<[Ipv6Addr; 4]>, + split_rev_domain: SmallVec<[DomainSeg; 5]>, + qnames: SmallVec<[usize; 5]>, + ) -> Result<(), ()> { + if !ip4.is_empty() || !ip6.is_empty() { + let domain = match split_rev_domain + .iter() + .rev() + .map(|x| String::from_utf8(x.to_vec()).map(|x| x + ".")) + .collect::>() + { + Ok(mut x) => { + x.pop(); + x + } + Err(err) => { + self.report("domain utf-8", err); + return Err(()); + } + }; + let mut split_rev_domain = split_rev_domain.into_iter(); + if let Some(first) = split_rev_domain.next() { + let first: Domain = first.to_vec().into(); + let joined_rev_domain = split_rev_domain.fold(first, |mut res, mut next| { + res.push(b'.'); + res.append(&mut next); + res + }); + let mut to_send: SmallVec<[IpNet; 8]> = SmallVec::new(); + to_send.extend(ip4.iter().copied().map(Ipv4Net::from).map(IpNet::from)); + to_send.extend(ip6.iter().copied().map(Ipv6Net::from).map(IpNet::from)); + let keep4 = !ip4.is_empty() + && self + .caches + .0 + .set(&domain, IpCacheKey(joined_rev_domain.clone()), ip4); + let keep6 = !ip6.is_empty() + && self + .caches + .1 + .set(&domain, IpCacheKey(joined_rev_domain.clone()), ip6); + to_send.retain(|x| x.addr().is_ipv4() && keep4 || x.addr().is_ipv6() && keep6); + if !to_send.is_empty() { + self.ruleset_queue + .as_ref() + .unwrap() + .send((qnames, to_send)) + .unwrap(); + } + } + } + Ok(()) + } + fn run_commands(&self, rev_domain: &[u8]) -> Option { if let Some(rev_domain) = self .nft_token .as_ref() .and_then(|token| rev_domain.strip_prefix(token.as_bytes())) { for (qname, query) in self.nft_queries.iter() { - if query.dynamic && rev_domain.ends_with(qname.as_bytes()) { + if query.dynamic && rev_domain.starts_with(qname.as_bytes()) { if let Some(rev_domain) = rev_domain.strip_prefix((qname.to_owned() + ".").as_bytes()) { @@ -680,7 +561,7 @@ impl UnboundMod for ExampleMod { let mut domains = query.domains.write().unwrap(); if domains.insert(rev_domain.clone()) { drop(domains); - let file_name = format!("/var/lib/unbound/{qname}_domains.json"); + let file_name = format!("{DATA_PREFIX}/{qname}_domains.json"); let domain = match String::from_utf8( rev_domain .iter() @@ -727,7 +608,7 @@ impl UnboundMod for ExampleMod { .and_then(|token| rev_domain.strip_prefix(token.as_bytes())) { for (qname, query) in self.nft_queries.iter() { - if query.dynamic && rev_domain.ends_with(qname.as_bytes()) { + if query.dynamic && rev_domain.starts_with(qname.as_bytes()) { if let Some(rev_domain) = rev_domain.strip_prefix((qname.to_owned() + ".").as_bytes()) { @@ -742,96 +623,107 @@ impl UnboundMod for ExampleMod { } return Some(ModuleExtState::Finished); } + None + } + fn get_qnames(&self, split_rev_domain: &SmallVec<[DomainSeg; 5]>) -> SmallVec<[usize; 5]> { + let mut qnames: SmallVec<[usize; 5]> = SmallVec::new(); + for query in self.nft_queries.values() { + if query.domains.read().unwrap().contains(split_rev_domain) { + qnames.push(query.index); + } + } + qnames + } +} + +#[derive(Deserialize)] +struct DpiInfo { + domains: Vec, + // name: String, + // restriction: {"code": "ban"} +} + +pub(crate) trait Helper: iptrie::IpPrefix + iptrie::IpRootPrefix + PartialEq { + const ZERO: Self; + fn direct_parent(&self) -> Option; +} + +impl Helper for Ipv4Net { + const ZERO: Self = match Self::new(Ipv4Addr::UNSPECIFIED, 0) { + Ok(x) => x, + #[allow(clippy::empty_loop)] + Err(_) => loop {}, + }; + fn direct_parent(&self) -> Option { + self.len() + .checked_sub(1) + .and_then(|x| Self::new(self.bitslot().into(), x).ok()) + } +} + +impl Helper for Ipv6Net { + const ZERO: Self = match Self::new(Ipv6Addr::UNSPECIFIED, 0) { + Ok(x) => x, + #[allow(clippy::empty_loop)] + Err(_) => loop {}, + }; + fn direct_parent(&self) -> Option { + self.len() + .checked_sub(1) + .and_then(|x| Self::new(self.bitslot().into(), x).ok()) + } +} +fn read_json Deserialize<'a>>(mut f: File) -> Result { + let mut data = Vec::new(); + f.read_to_end(&mut data) + .map_err(serde_json::Error::custom)?; + serde_json::from_slice(&data) +} + +impl UnboundMod for ExampleMod { + type EnvData = (); + type QstateData = (); + + fn init(_env: &mut crate::unbound::ModuleEnv) -> Result { + Self::new() + } + + fn operate( + &self, + qstate: &mut crate::unbound::ModuleQstate, + event: ModuleEvent, + _entry: &mut crate::unbound::OutboundEntryMut, + ) -> Option { + match event { + ModuleEvent::New | ModuleEvent::Pass => { + return Some(ModuleExtState::WaitModule); + } + ModuleEvent::ModDone => {} + _ => { + return Some(ModuleExtState::Error); + } + } + let info = qstate.qinfo_mut(); + let name = info.qname().to_bytes(); + let rev_domain = name.strip_suffix(b".").unwrap_or(name); + if let Some(val) = self.run_commands(rev_domain) { + return Some(val); + } let split_rev_domain = rev_domain .split(|x| *x == b'.') .map(|x| x.into()) .collect::>(); - let mut qnames: SmallVec<[usize; 5]> = SmallVec::new(); - for query in self.nft_queries.values() { - if query.domains.read().unwrap().contains(&split_rev_domain) { - qnames.push(query.index); - } - } + let qnames = self.get_qnames(&split_rev_domain); if qnames.is_empty() { return Some(ModuleExtState::Finished); } if let Some(ret) = qstate.return_msg_mut() { if let Some(rep) = ret.rep() { - let mut ip4: SmallVec<[Ipv4Addr; 4]> = SmallVec::new(); - let mut ip6: SmallVec<[Ipv6Addr; 4]> = SmallVec::new(); - for rrset in rep.rrsets() { - let entry = rrset.entry(); - let d = entry.data(); - let rk = rrset.rk(); - if rk.rrset_class() != rr_class::IN { - continue; - } - for (data, _ttl) in d.rr_data() { - match rk.type_() { - rr_type::A if data.len() == 2 + 4 && &data[..2] == b"\0\x04" => { - ip4.push(Ipv4Addr::from( - <[u8; 4]>::try_from(&data[2..2 + 4]).unwrap(), - )); - } - rr_type::AAAA if data.len() == 2 + 16 && &data[..2] == b"\0\x10" => { - ip6.push(Ipv6Addr::from( - <[u8; 16]>::try_from(&data[2..2 + 16]).unwrap(), - )); - } - _ => {} - } - } - } - if !ip4.is_empty() || !ip6.is_empty() { - let domain = match split_rev_domain - .iter() - .rev() - .map(|x| String::from_utf8(x.to_vec()).map(|x| x + ".")) - .collect::>() - { - Ok(mut x) => { - x.pop(); - x - } - Err(err) => { - self.report("domain utf-8", err); - return Some(ModuleExtState::Error); - } - }; - let mut split_rev_domain = split_rev_domain.into_iter(); - if let Some(first) = split_rev_domain.next() { - let first: Domain = first.to_vec().into(); - let joined_rev_domain = - split_rev_domain.fold(first, |mut res, mut next| { - res.push(b'.'); - res.append(&mut next); - res - }); - let mut to_send: SmallVec<[IpNet; 8]> = SmallVec::new(); - to_send.extend(ip4.iter().copied().map(Ipv4Net::from).map(IpNet::from)); - to_send.extend(ip6.iter().copied().map(Ipv6Net::from).map(IpNet::from)); - let keep4 = !ip4.is_empty() - && self.caches.0.set( - &domain, - IpCacheKey(joined_rev_domain.clone()), - ip4, - ); - let keep6 = !ip6.is_empty() - && self.caches.1.set( - &domain, - IpCacheKey(joined_rev_domain.clone()), - ip6, - ); - to_send - .retain(|x| x.addr().is_ipv4() && keep4 || x.addr().is_ipv6() && keep6); - if !to_send.is_empty() { - self.ruleset_queue - .as_ref() - .unwrap() - .send((qnames, to_send)) - .unwrap(); - } - } + if self + .handle_reply_info(split_rev_domain, qnames, &rep) + .is_err() + { + return Some(ModuleExtState::Error); } } } @@ -846,31 +738,164 @@ fn setup() { #[cfg(test)] mod test { - use std::net::Ipv4Addr; + use std::{net::Ipv4Addr, os::unix::fs::MetadataExt, path::PathBuf, str::FromStr, sync::mpsc}; - use ipnet::Ipv4Net; - use iptrie::RTrieSet; + use ipnet::IpNet; + use smallvec::{smallvec, SmallVec}; - use crate::example::{iter_ip_trie, should_add, IpNetDeser}; + use crate::example::{ignore, ExampleMod, IpNetDeser, DATA_PREFIX}; #[test] fn test() { - let mut trie = RTrieSet::new(); - assert!(should_add( - &trie, - &Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap() - )); - trie.insert(Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap()); - assert!(!should_add( - &trie, - &Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap() - )); - trie.insert(Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 31).unwrap()); - assert!(dbg!(iter_ip_trie(&trie).collect::>()).len() == 1); - // contains 0.0.0.0, etc - assert!(dbg!(trie.iter().collect::>()).len() == 3); - trie.insert(Ipv4Net::new(Ipv4Addr::new(127, 0, 1, 1), 32).unwrap()); - assert!(dbg!(iter_ip_trie(&trie).collect::>()).len() == 2); - assert!(serde_json::from_str::>(r#"["127.0.0.1/8","127.0.0.1"]"#).is_ok()) + assert!(serde_json::from_str::>(r#"["127.0.0.1/8","127.0.0.1"]"#).is_ok()); + #[cfg(not(debug_assertions))] + return; + + std::fs::remove_dir_all(DATA_PREFIX).unwrap_or(()); + + std::env::set_var("NFT_TOKEN", "token"); + std::env::set_var("NFT_QUERIES", "q!:set_a,set_b;w:set_c,set_d"); + + std::fs::create_dir_all(DATA_PREFIX.to_string() + "/domains4").unwrap(); + std::fs::write( + DATA_PREFIX.to_string() + "/domains4/a.com", + "1.2.3.4\n5.6.7.8", + ) + .unwrap(); + filetime::set_file_mtime( + DATA_PREFIX.to_string() + "/domains4/a.com", + filetime::FileTime::zero(), + ) + .unwrap(); + + std::fs::write(DATA_PREFIX.to_string() + "/domains4/b.com", "8.7.6.5").unwrap(); + std::fs::write( + DATA_PREFIX.to_string() + "/q_domains.json", + r#"["a.com","c.com"]"#, + ) + .unwrap(); + std::fs::write(DATA_PREFIX.to_string() + "/q_ips.json", r#"["4.4.4.4"]"#).unwrap(); + std::fs::write(DATA_PREFIX.to_string() + "/w_domains.json", r#"["c.com"]"#).unwrap(); + std::fs::write(DATA_PREFIX.to_string() + "/w_ips.json", r#"["5.5.5.5"]"#).unwrap(); + + let mut t = ExampleMod::default(); + let mut rulesets = t.load_env().unwrap(); + assert!(t.nft_queries.len() == 2 && rulesets.len() == t.nft_queries.len()); + assert!(t.nft_queries.get("q").unwrap().dynamic); + assert!(!t.nft_queries.get("w").unwrap().dynamic); + + t.report("", ""); + std::fs::metadata(DATA_PREFIX.to_string() + "/error.log").unwrap(); + + let mut base_path = PathBuf::from_str(DATA_PREFIX).unwrap(); + base_path.push("domains4"); + t.caches.0.load(&base_path).unwrap(); + base_path.pop(); + base_path.push("domains6"); + t.caches.1.load(&base_path).unwrap(); + + t.caches + .0 + .get_maybe_update_rev("com.a".as_bytes().into(), |x| { + assert!(x.unwrap().0.len() == 2); + #[allow(unused_assignments)] + let mut val = Some(ignore); + val = None; + val + }); + t.caches + .0 + .get_maybe_update_rev("com.b".as_bytes().into(), |x| { + assert!(x.unwrap().0.len() == 1); + #[allow(unused_assignments)] + let mut val = Some(ignore); + val = None; + val + }); + + t.load_json(&mut rulesets); + + assert_eq!(rulesets[0].0.ip_count(), 3); + assert_eq!(rulesets[1].0.ip_count(), 1); + + let (tx, rx) = mpsc::channel(); + let (tx2, rx2) = mpsc::channel(); + + t.ruleset_queue = Some(tx); + + std::thread::spawn(move || { + while let Ok((rulesets1, ips)) = rx.recv() { + for i in rulesets1.into_iter() { + let ruleset = &mut rulesets[i]; + for ip1 in ips.iter().copied() { + match ip1 { + IpNet::V4(ip) => ruleset.0.insert(ip, true), + IpNet::V6(ip) => ruleset.1.insert(ip, true), + } + } + } + } + tx2.send(rulesets).unwrap(); + }); + + let split_rev_domain = smallvec![SmallVec::from(&b"com"[..]), SmallVec::from(&b"c"[..])]; + let qnames = t.get_qnames(&split_rev_domain); + assert_eq!(qnames.len(), 2); + t.add_ips( + smallvec![Ipv4Addr::new(7, 7, 7, 7), Ipv4Addr::new(6, 6, 6, 6)], + smallvec![], + split_rev_domain, + qnames, + ) + .unwrap(); + + let split_rev_domain = smallvec![SmallVec::from(&b"com"[..]), SmallVec::from(&b"a"[..])]; + let qnames = t.get_qnames(&split_rev_domain); + t.add_ips( + smallvec![Ipv4Addr::new(1, 2, 3, 4), Ipv4Addr::new(5, 6, 7, 8)], + smallvec![], + split_rev_domain, + qnames, + ) + .unwrap(); + + t.run_commands(b"token.q.com.w").unwrap(); + t.run_commands(b"tmptoken.q.com.e").unwrap(); + + let split_rev_domain = smallvec![SmallVec::from(&b"com"[..]), SmallVec::from(&b"e"[..])]; + let qnames = t.get_qnames(&split_rev_domain); + assert_eq!(qnames.len(), 1); + t.add_ips( + smallvec![Ipv4Addr::new(8, 8, 8, 8)], + smallvec![], + split_rev_domain, + qnames, + ) + .unwrap(); + + let split_rev_domain = smallvec![SmallVec::from(&b"com"[..]), SmallVec::from(&b"w"[..])]; + let qnames = t.get_qnames(&split_rev_domain); + assert_eq!(qnames.len(), 1); + t.add_ips( + smallvec![Ipv4Addr::new(9, 8, 8, 8)], + smallvec![], + split_rev_domain, + qnames, + ) + .unwrap(); + + drop(t); + let rulesets = rx2.recv().unwrap(); + + std::fs::metadata(DATA_PREFIX.to_owned() + "/domains4/w.com").unwrap(); + assert_ne!( + std::fs::metadata(DATA_PREFIX.to_string() + "/domains4/a.com") + .unwrap() + .mtime(), + 0 + ); + + assert_eq!(rulesets[0].0.ip_count(), 7); + assert_eq!(rulesets[1].0.ip_count(), 3); } } diff --git a/src/lib.rs b/src/lib.rs index e13456d..739de1a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +#![allow(clippy::type_complexity)] use std::panic::{RefUnwindSafe, UnwindSafe}; use unbound::ModuleExtState; diff --git a/src/nftables.rs b/src/nftables.rs index 584b67f..877dc3b 100644 --- a/src/nftables.rs +++ b/src/nftables.rs @@ -1,18 +1,23 @@ use std::{ cell::Cell, ffi::CStr, - io, + fmt::Display, + io::{self, Write}, net::{Ipv4Addr, Ipv6Addr}, os::{ fd::BorrowedFd, raw::{c_char, c_void}, }, rc::Rc, + sync::mpsc, }; +use crate::example::{Helper, DATA_PREFIX}; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; +use iptrie::RTrieSet; use mnl::mnl_sys; use nftnl::{nftnl_sys, set::SetKey, Batch, FinalizedBatch, MsgType, NlMsg}; +use smallvec::SmallVec; fn cidr_bound_ipv4(net: Ipv4Net) -> Option { let data = u32::from(net.network()); @@ -324,16 +329,191 @@ pub fn get_sets(socket: &mnl::Socket) -> io::Result> { Ok(ret) } +fn should_add(trie: &RTrieSet, elem: &T) -> bool { + *trie.lookup(elem) == T::ZERO +} + +fn iter_ip_trie(trie: &RTrieSet) -> impl '_ + Iterator { + trie.iter().copied().filter(|x| { + if let Some(par) = x.direct_parent() { + should_add(trie, &par) + } else { + *x != T::ZERO + } + }) +} + +pub(crate) struct NftData { + ips: RTrieSet, + dirty: bool, + set: Option, + name: String, +} + +impl NftData { + pub fn new(name: &str) -> Self { + Self { + set: None, + ips: RTrieSet::new(), + dirty: true, + name: name.to_owned(), + } + } +} + +// SAFETY: set is None initially so Set1 is never actually sent +// (and it might be fine to send anyway actually) +unsafe impl Send for NftData {} + +impl NftData +where + IpNet: From, +{ + #[must_use] + pub fn verify(&mut self) -> bool { + if !self.name.is_empty() && self.set.is_none() { + self.ips = RTrieSet::new(); + false + } else { + true + } + } + pub fn flush_changes( + &mut self, + socket: &mnl::Socket, + flush_set: bool, + ) -> Result<(), io::Error> { + if let Some(set) = self.set.as_mut().filter(|_| self.dirty) { + if flush_set { + println!( + "initializing set {} with ~{} ips (e.g. {:?})", + self.name, + self.ips.len(), + iter_ip_trie(&self.ips).next(), + ); + } + set.add_cidrs(socket, flush_set, iter_ip_trie(&self.ips).map(IpNet::from)) + } else { + Ok(()) + } + } + pub fn extend(&mut self, ips: impl Iterator) { + for ip in ips { + self.insert(ip, true); + } + } + pub fn insert(&mut self, ip: T, allow_empty_set: bool) { + if (if allow_empty_set { + !self.name.is_empty() + } else { + self.set.is_some() + }) && should_add(&self.ips, &ip) + { + self.ips.insert(ip); + self.dirty = true; + } + } + pub fn ips_mut(&mut self) -> &mut RTrieSet { + &mut self.ips + } + #[cfg(test)] + pub fn ip_count(&self) -> usize { + iter_ip_trie(&self.ips).count() + } + pub fn name(&self) -> &str { + &self.name + } + pub fn set_set(&mut self, set: Set1) { + self.set = Some(set); + } +} + +pub(crate) fn nftables_thread( + mut rulesets: Vec<(NftData, NftData)>, + rx: mpsc::Receiver<(SmallVec<[usize; 5]>, smallvec::SmallVec<[IpNet; 8]>)>, +) { + fn report(err: impl Display) { + println!("nftables: {err}"); + if let Ok(mut file) = std::fs::OpenOptions::new() + .append(true) + .create(true) + .open(format!("{DATA_PREFIX}/nftables.log")) + { + file.write_all((err.to_string() + "\n").as_bytes()) + .unwrap_or(()); + } + } + let socket = mnl::Socket::new(mnl::Bus::Netfilter).unwrap(); + let all_sets = crate::nftables::get_sets(&socket).unwrap(); + for set in all_sets { + for ruleset in &mut rulesets { + if set.table_name() == Some("global") && set.family() == libc::NFPROTO_INET as u32 { + if set.name() == Some(ruleset.0.name()) { + println!("found set {}", ruleset.0.name()); + ruleset.0.set_set(set); + break; + } else if set.name() == Some(ruleset.1.name()) { + println!("found set {}", ruleset.1.name()); + ruleset.1.set_set(set); + break; + } + } + } + } + for ruleset in &mut rulesets { + if !ruleset.0.verify() { + report(format!("set {} not found", ruleset.0.name())); + } + if !ruleset.1.verify() { + report(format!("set {} not found", ruleset.1.name())); + } + } + let mut first = true; + loop { + for ruleset in &mut rulesets { + if let Err(err) = ruleset.0.flush_changes(&socket, first) { + report(err); + } + if let Err(err) = ruleset.1.flush_changes(&socket, first) { + report(err); + } + } + if first { + println!("nftables init done"); + first = false; + } + let (rulesets1, ips) = match rx.recv() { + Ok(val) => val, + Err(_) => break, + }; + for i in rulesets1.into_iter() { + let ruleset = &mut rulesets[i]; + for ip1 in ips.iter().copied() { + match ip1 { + IpNet::V4(ip) => ruleset.0.insert(ip, false), + IpNet::V6(ip) => ruleset.1.insert(ip, false), + } + } + } + } +} + #[cfg(test)] mod test { - use std::net::Ipv6Addr; + use std::net::{Ipv4Addr, Ipv6Addr}; - use ipnet::Ipv6Net; + use ipnet::{Ipv4Net, Ipv6Net}; + use iptrie::RTrieSet; + + use crate::nftables::{iter_ip_trie, should_add}; use super::get_sets; #[test] fn test_nftables() { + if !nix::unistd::Uid::effective().is_root() { + return; + } let socket = mnl::Socket::new(mnl::Bus::Netfilter).unwrap(); let sets = get_sets(&socket).unwrap(); assert!(!sets.is_empty()); @@ -352,4 +532,24 @@ mod test { } panic!(); } + + #[test] + fn test_set() { + let mut trie = RTrieSet::new(); + assert!(should_add( + &trie, + &Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap() + )); + trie.insert(Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap()); + assert!(!should_add( + &trie, + &Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap() + )); + trie.insert(Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 31).unwrap()); + assert!(dbg!(iter_ip_trie(&trie).collect::>()).len() == 1); + // contains 0.0.0.0, etc + assert!(dbg!(trie.iter().collect::>()).len() == 3); + trie.insert(Ipv4Net::new(Ipv4Addr::new(127, 0, 1, 1), 32).unwrap()); + assert!(dbg!(iter_ip_trie(&trie).collect::>()).len() == 2); + } }