From 14346134b555e9b2c315d23d67535779a5fa85ef Mon Sep 17 00:00:00 2001 From: chayleaf Date: Mon, 12 Aug 2024 11:24:01 +0700 Subject: [PATCH] almost done --- Cargo.lock | 113 ++++++++++++++- Cargo.toml | 2 +- FIXME | 2 + flake.nix | 5 +- src/domain_tree.rs | 130 +++++++++++++++++ src/example.rs | 340 ++++++++++++++++++++++++++++++--------------- src/lib.rs | 1 + src/unbound.rs | 135 +++++++++++------- 8 files changed, 559 insertions(+), 169 deletions(-) create mode 100644 FIXME create mode 100644 src/domain_tree.rs diff --git a/Cargo.lock b/Cargo.lock index 17b333d..858c65a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -62,6 +62,18 @@ dependencies = [ "synstructure", ] +[[package]] +name = "filetime" +version = "0.2.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf401df4a4e3872c4fe8151134cf483738e74b67fc934d6532c882b3d24a4550" +dependencies = [ + "cfg-if", + "libc", + "libredox", + "windows-sys", +] + [[package]] name = "ipnet" version = "2.9.0" @@ -92,6 +104,17 @@ version = "0.2.155" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +[[package]] +name = "libredox" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" +dependencies = [ + "bitflags 2.6.0", + "libc", + "redox_syscall", +] + [[package]] name = "log" version = "0.4.22" @@ -175,12 +198,6 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" -[[package]] -name = "prefix-tree" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f499660c89b7cfbbf11bb2faefe26a187062d7ff0f06bc4aba434328213f044" - [[package]] name = "proc-macro-error" version = "1.0.4" @@ -233,6 +250,15 @@ dependencies = [ "nibble_vec", ] +[[package]] +name = "redox_syscall" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4" +dependencies = [ + "bitflags 2.6.0", +] + [[package]] name = "rustversion" version = "1.0.17" @@ -323,13 +349,13 @@ version = "0.1.0" dependencies = [ "boxcar", "ctor", + "filetime", "ipnet", "iptrie", "libc", "mnl", "nftnl", "nix", - "prefix-tree", "radix_trie", "serde", "serde_json", @@ -353,3 +379,76 @@ name = "version_check" version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" diff --git a/Cargo.toml b/Cargo.toml index 02df0bc..0f50c38 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,13 +11,13 @@ crate-type = ["rlib", "cdylib"] [dependencies] boxcar = "0.2.5" ctor = { version = "0.2.8", optional = true } +filetime = "0.2.24" ipnet = { version = "2.9.0", features = ["serde"] } iptrie = "0.8.5" libc = "0.2.155" mnl = { version = "0.2.2", features = ["mnl-1-0-4"] } nftnl = { version = "0.6.2", features = ["nftnl-1-1-2"] } nix = { version = "0.29.0", features = ["poll"] } -prefix-tree = "0.5.0" radix_trie = "0.2.1" serde = { version = "1.0.205", features = ["derive"] } serde_json = "1.0.122" diff --git a/FIXME b/FIXME new file mode 100644 index 0000000..e4d2b80 --- /dev/null +++ b/FIXME @@ -0,0 +1,2 @@ +seemingly log files dont actually work +cant set mtime? or doesnt update shit? or doesnt write? questions, questions diff --git a/flake.nix b/flake.nix index 5eb126e..ff71e75 100644 --- a/flake.nix +++ b/flake.nix @@ -47,7 +47,10 @@ devShells = gen (pkgs: { default = pkgs.mkShell rec { name = "unbound-rust-mod-shell"; - nativeBuildInputs = [ pkgs.rustc pkgs.cargo pkgs.nftables ]; + nativeBuildInputs = [ + # pkgs.rustc pkgs.cargo + pkgs.nftables + ]; LIBMNL_LIB_DIR = "${nixpkgs.lib.getLib pkgs.libmnl}/lib"; LIBNFTNL_LIB_DIR = "${nixpkgs.lib.getLib (pkgs.libnftnl.overrideAttrs (old: { patches = (old.patches or []) ++ [ ./libnftnl-fix.patch ]; diff --git a/src/domain_tree.rs b/src/domain_tree.rs new file mode 100644 index 0000000..eaf47c2 --- /dev/null +++ b/src/domain_tree.rs @@ -0,0 +1,130 @@ +use std::{collections::HashMap, hash::Hash}; + +use smallvec::{smallvec, SmallVec}; + +pub enum PrefixSet { + Map(HashMap>), + Leaf, +} + +impl Default for PrefixSet { + fn default() -> Self { + Self::new() + } +} + +impl PrefixSet { + pub fn new() -> Self { + Self::Map(HashMap::new()) + } +} + +impl PrefixSet { + // returns whether its new + pub fn insert(&mut self, val: impl IntoIterator) -> bool { + match self { + Self::Leaf => false, + Self::Map(map) => { + let mut it = val.into_iter(); + if let Some(k) = it.next() { + map.entry(k).or_default().insert(it) + } else { + *self = Self::Leaf; + true + } + } + } + } + pub fn contains<'a>(&self, val: impl IntoIterator) -> bool + where + T: 'a, + { + match self { + Self::Leaf => true, + Self::Map(map) => { + let mut it = val.into_iter(); + if let Some(k) = it.next() { + let Some(next) = map.get(k) else { + return false; + }; + next.contains(it) + } else { + true + } + } + } + } + pub fn iter(&self) -> impl Iterator> { + match self { + Self::Leaf => Iter(SmallVec::new(), SmallVec::new()), + Self::Map(map) => Iter(smallvec![map.iter()], smallvec![]), + } + } +} + +struct Iter<'a, T>( + SmallVec<[std::collections::hash_map::Iter<'a, T, PrefixSet>; 8]>, + SmallVec<[&'a T; 8]>, +); + +impl<'a, T> Iterator for Iter<'a, T> { + type Item = smallvec::IntoIter<[&'a T; 8]>; + fn next(&mut self) -> Option { + while let Some(it) = self.0.last_mut() { + let Some((k, v)) = it.next() else { + self.0.pop(); + if self.1.pop().is_none() { + return None; + } + continue; + }; + self.1.push(k); + match v { + PrefixSet::Leaf => { + let ret = self.1.clone().into_iter(); + self.1.pop(); + return Some(ret); + } + PrefixSet::Map(m) => { + self.0.push(m.iter()); + } + } + } + None + } +} + +#[cfg(test)] +mod tests { + use super::PrefixSet; + + #[test] + fn test() { + let mut tree = PrefixSet::<&str>::new(); + assert!(tree.insert(["a", "b", "c"])); + assert!(tree.insert(["b", "c", "d"])); + assert!(tree.insert(["a", "b"])); + assert!(!tree.insert(["a", "b", "c"])); + assert!(tree.contains([&"a", &"b", &"c"])); + assert!(!tree.contains([&"a", &"c"])); + let mut it = tree.iter(); + assert!(matches!( + it.next() + .unwrap() + .into_iter() + .copied() + .collect::() + .as_str(), + "ab" | "bcd" + )); + assert!(matches!( + it.next() + .unwrap() + .into_iter() + .copied() + .collect::() + .as_str(), + "ab" | "bcd" + )); + } +} diff --git a/src/example.rs b/src/example.rs index adfe848..aad2158 100644 --- a/src/example.rs +++ b/src/example.rs @@ -2,11 +2,12 @@ use std::{ collections::HashMap, fmt::Display, fs::File, - io::{self, BufRead, BufReader, Write}, - net::{Ipv4Addr, Ipv6Addr}, + io::{self, BufRead, BufReader, Read, Write}, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, path::{Path, PathBuf}, str::FromStr, sync::{ + atomic::{AtomicBool, Ordering}, mpsc::{self, RecvError}, Mutex, RwLock, }, @@ -16,19 +17,70 @@ use std::{ use ctor::ctor; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use iptrie::{IpPrefix, RTrieSet}; -use prefix_tree::PrefixSet; -use serde::Deserialize; +use serde::{ + de::{Error, Visitor}, + Deserialize, +}; use smallvec::SmallVec; use crate::{ + domain_tree::PrefixSet, nftables::Set1, - unbound::{rr_class, rr_type}, + unbound::{rr_class, rr_type, ModuleEvent, ModuleExtState}, UnboundMod, }; type Domain = SmallVec<[u8; 32]>; type DomainSeg = SmallVec<[u8; 16]>; +struct IpNetDeser(IpNet); +struct IpNetVisitor; +impl<'de> Visitor<'de> for IpNetVisitor { + type Value = IpNetDeser; + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("ip address or cidr") + } + fn visit_borrowed_str(self, v: &'de str) -> Result + where + E: Error, + { + if let Some((a, b)) = v.split_once('/') { + let ip = IpAddr::from_str(a) + .map_err(|_| E::invalid_value(serde::de::Unexpected::Str(v), &self))?; + let len = u8::from_str(b) + .map_err(|_| E::invalid_value(serde::de::Unexpected::Str(v), &self))?; + IpNet::new(ip, len) + } else { + let ip = IpAddr::from_str(v) + .map_err(|_| E::invalid_value(serde::de::Unexpected::Str(v), &self))?; + IpNet::new(ip, if ip.is_ipv6() { 128 } else { 32 }) + } + .map(IpNetDeser) + .map_err(|_| E::invalid_value(serde::de::Unexpected::Str(v), &self)) + } + fn visit_str(self, v: &str) -> Result + where + E: Error, + { + self.visit_borrowed_str(v) + } + fn visit_string(self, v: String) -> Result + where + E: Error, + { + self.visit_borrowed_str(&v) + } +} + +impl<'de> Deserialize<'de> for IpNetDeser { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_str(IpNetVisitor) + } +} + #[derive(Default)] struct ExampleMod { domain_name_overrides: HashMap, @@ -47,13 +99,13 @@ struct ExampleMod { struct IpCache( RwLock<( radix_trie::Trie, - Vec<(RwLock>, Mutex<()>)>, + Vec<(RwLock>, Mutex<()>, AtomicBool)>, )>, PathBuf, ); #[repr(transparent)] -#[derive(PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] struct IpCacheKey(Domain); impl radix_trie::TrieKey for IpCacheKey { fn encode_bytes(&self) -> Vec { @@ -89,9 +141,14 @@ impl IpCache { } else { drop(lock); let mut lock = self.0.write().unwrap(); - let key = lock.1.len(); - lock.0.insert(domain_r, key).unwrap(); - lock.1.push((RwLock::new(val), Mutex::new(()))); + if let Some(key) = lock.0.get(&domain_r).copied() { + *lock.1.get(key).unwrap().0.write().unwrap() = val; + } else { + let key = lock.1.len(); + lock.0.insert(domain_r, key); + lock.1 + .push((RwLock::new(val), Mutex::new(()), AtomicBool::new(false))); + } } } } @@ -108,15 +165,23 @@ impl IpCache { .join("\n"); let mut path = self.1.clone(); path.push(domain); + let path1 = &path; let finish = move |_lock| { - let Ok(mut file) = File::create(path) else { + let Ok(mut file) = File::create(path1) else { return; }; file.write_all(to_write.as_bytes()).unwrap_or(()); }; if let Some(key) = key { - let mut lock = lock.1.get(key).unwrap().0.write().unwrap(); + let v = lock.1.get(key).unwrap(); + let mut lock = v.0.write().unwrap(); if *lock == val { + if v.2 + .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed) + .is_ok() + { + let _ = filetime::set_file_mtime(path, filetime::FileTime::now()); + } return false; } *lock = val; @@ -124,9 +189,15 @@ impl IpCache { } else { drop(lock); let mut lock = self.0.write().unwrap(); - let key = lock.1.len(); - lock.0.insert(domain_r, key).unwrap(); - lock.1.push((RwLock::new(val), Mutex::new(()))); + let key = if let Some(key) = lock.0.get(&domain_r).copied() { + key + } else { + let key = lock.1.len(); + lock.0.insert(domain_r, key); + lock.1 + .push((RwLock::new(val), Mutex::new(()), AtomicBool::new(false))); + key + }; drop(lock); finish( self.0 @@ -146,10 +217,11 @@ impl IpCache { impl IpCache { fn load(&mut self, dir: &Path) -> Result<(), io::Error> { + println!("loading {dir:?}"); std::fs::create_dir_all(dir)?; let mut lock = self.0.write().unwrap(); assert!(lock.1.is_empty()); - let domains = std::fs::read_dir("/var/lib/unbound/domains4/")?; + let domains = std::fs::read_dir(dir)?; for entry in domains.filter_map(|x| x.ok()) { let domain = entry.file_name(); let Some(domain) = domain.to_str() else { @@ -165,6 +237,9 @@ impl IpCache { continue; } } + let Ok(reader) = std::fs::File::open(entry.path()) else { + continue; + }; let domain_r = IpCacheKey( domain .split('.') @@ -174,15 +249,10 @@ impl IpCache { .join(&b"."[..]) .into(), ); - let key = lock.1.len(); - lock.0.insert(domain_r, key).unwrap(); - let Ok(reader) = std::fs::File::open(entry.path()) else { - continue; - }; let mut reader = BufReader::new(reader); let mut line = String::new(); let mut ips = SmallVec::new(); - while reader.read_line(&mut line).is_ok() { + while matches!(reader.read_line(&mut line), Ok(x) if x > 0) { let trimmed = line.trim(); if trimmed.is_empty() { continue; @@ -190,34 +260,39 @@ impl IpCache { ips.extend(T::from_str(trimmed)); line.clear(); } - lock.1.push((RwLock::new(ips), Mutex::new(()))); + if let Some(key) = lock.0.get(&domain_r).copied() { + lock.1[key].0.write().unwrap().extend(ips); + } else { + let key = lock.1.len(); + lock.0.insert(domain_r, key); + lock.1 + .push((RwLock::new(ips), Mutex::new(()), AtomicBool::new(false))); + } } Ok(()) } } -struct NftData { - ips4: RTrieSet, - ips6: RTrieSet, - dirty4: bool, - dirty6: bool, - set4: Option, - set6: Option, - name4: String, - name6: String, +struct NftData { + ips: RTrieSet, + dirty: bool, + set: Option, + name: String, } -// SAFETY: set4/set6 are None initially and are never actually sent -unsafe impl Send for NftData {} +// SAFETY: set are None initially and are never actually sent +// (and Set1 might be fine to send anyway actually) +unsafe impl Send for NftData {} struct NftQuery { - domains: RwLock>, + domains: RwLock>, dynamic: bool, index: usize, } impl ExampleMod { fn report(&self, code: &str, err: impl Display) { + println!("{code}: {err}"); if let Ok(mut file) = std::fs::OpenOptions::new() .append(true) .open("/var/lib/unbound/error.log") @@ -292,6 +367,13 @@ fn iter_ip_trie(trie: &RTrieSet) -> impl '_ + Iterator { }) } +fn read_json Deserialize<'a>>(mut f: File) -> Result { + let mut data = Vec::new(); + f.read_to_end(&mut data) + .map_err(serde_json::Error::custom)?; + serde_json::from_slice(&data) +} + impl UnboundMod for ExampleMod { type EnvData = (); type QstateData = (); @@ -344,16 +426,20 @@ impl UnboundMod for ExampleMod { index: i, }, ); - rulesets.push(NftData { - set4: None, - set6: None, - ips4: RTrieSet::new(), - ips6: RTrieSet::new(), - dirty4: true, - dirty6: true, - name4: set4.to_owned(), - name6: set6.to_owned(), - }); + rulesets.push(( + NftData { + set: None, + ips: RTrieSet::new(), + dirty: true, + name: set4.to_owned(), + }, + NftData { + set: None, + ips: RTrieSet::new(), + dirty: true, + name: set6.to_owned(), + }, + )); } } @@ -366,11 +452,13 @@ impl UnboundMod for ExampleMod { } // load json files - for ((k, v), r) in nft_queries.iter_mut().zip(rulesets.iter_mut()) { + for (k, v) in nft_queries.iter_mut() { + let r = &mut rulesets[v.index]; + let mut v_domains = v.domains.write().unwrap(); for base in ["/etc/unbound", "/var/lib/unbound"] { - let mut v_domains = v.domains.write().unwrap(); if let Ok(file) = std::fs::File::open(format!("{base}/{k}_domains.json")) { - match serde_json::from_reader::<_, Vec>(file) { + println!("loading {base}/{k}_domains.json"); + match read_json::>(file) { Ok(domains) => { for domain in domains { v_domains.insert( @@ -386,7 +474,8 @@ impl UnboundMod for ExampleMod { } } if let Ok(file) = std::fs::File::open(format!("{base}/{k}_dpi.json")) { - match serde_json::from_reader::<_, Vec>(file) { + println!("loading {base}/{k}_dpi.json"); + match read_json::>(file) { Ok(dpi_info) => { for domain in dpi_info.iter().flat_map(|x| &x.domains) { v_domains.insert( @@ -402,18 +491,19 @@ impl UnboundMod for ExampleMod { } } if let Ok(file) = std::fs::File::open(format!("{base}/{k}_ips.json")) { - match serde_json::from_reader::<_, Vec>(file) { + println!("loading {base}/{k}_ips.json"); + match read_json::>(file) { Ok(ips) => { - r.ips4.extend(ips.iter().filter_map(|x| { - if let IpNet::V4(x) = x { - Some(*x) + r.0.ips.extend(ips.iter().filter_map(|x| { + if let IpNet::V4(x) = x.0 { + Some(x) } else { None } })); - r.ips6.extend(ips.iter().filter_map(|x| { - if let IpNet::V6(x) = x { - Some(*x) + r.1.ips.extend(ips.iter().filter_map(|x| { + if let IpNet::V6(x) = x.0 { + Some(x) } else { None } @@ -422,36 +512,26 @@ impl UnboundMod for ExampleMod { Err(err) => ret.report("ips", err), } } - for rev_domain in v_domains.iter() { - ret.cache4.get_maybe_update_rev( - rev_domain - .iter() - .map(|x| x.as_slice()) - .collect::>() - .join(&b"."[..]) - .into(), - |val| { - if let Some(val) = val { - r.ips4.extend(val.iter().map(|x| Ipv4Net::from(*x))); - } - None - }, - ); - ret.cache6.get_maybe_update_rev( - rev_domain - .iter() - .map(|x| x.as_slice()) - .collect::>() - .join(&b"."[..]) - .into(), - |val| { - if let Some(val) = val { - r.ips6.extend(val.iter().map(|x| Ipv6Net::from(*x))); - } - None - }, - ); - } + } + println!("loading cached domain ips for {k}"); + for rev_domain in v_domains.iter() { + let rev_domain: SmallVec<_> = rev_domain + .map(|x| x.as_slice()) + .collect::>() + .join(&b"."[..]) + .into(); + ret.cache4.get_maybe_update_rev(rev_domain.clone(), |val| { + if let Some(val) = val { + r.0.ips.extend(val.iter().map(|x| Ipv4Net::from(*x))); + } + None + }); + ret.cache6.get_maybe_update_rev(rev_domain, |val| { + if let Some(val) = val { + r.1.ips.extend(val.iter().map(|x| Ipv6Net::from(*x))); + } + None + }); } } @@ -462,6 +542,7 @@ impl UnboundMod for ExampleMod { std::thread::spawn(move || { fn report(err: impl Display) { + println!("nftables: {err}"); if let Ok(mut file) = std::fs::OpenOptions::new() .append(true) .open("/var/lib/unbound/nftables.log") @@ -482,47 +563,68 @@ impl UnboundMod for ExampleMod { if set.table_name() == Some("global") && set.family() == libc::NFPROTO_INET as u32 { - if set.name() == Some(&ruleset.name4) { - ruleset.set4 = Some(set.clone()); - } else if set.name() == Some(&ruleset.name6) { - ruleset.set6 = Some(set.clone()); + if set.name() == Some(&ruleset.0.name) { + println!("found set {}", ruleset.0.name); + ruleset.0.set = Some(set.clone()); + } else if set.name() == Some(&ruleset.1.name) { + println!("found set {}", ruleset.1.name); + ruleset.1.set = Some(set.clone()); } } } } for ruleset in &mut rulesets { - if !ruleset.name4.is_empty() && ruleset.set4.is_none() { - report(format!("set {} not found", ruleset.name4)); - ruleset.ips4 = RTrieSet::new(); + if !ruleset.0.name.is_empty() && ruleset.0.set.is_none() { + report(format!("set {} not found", ruleset.0.name)); + ruleset.0.ips = RTrieSet::new(); } - if !ruleset.name6.is_empty() && ruleset.set6.is_none() { - report(format!("set {} not found", ruleset.name6)); - ruleset.ips6 = RTrieSet::new(); + if !ruleset.1.name.is_empty() && ruleset.1.set.is_none() { + report(format!("set {} not found", ruleset.1.name)); + ruleset.1.ips = RTrieSet::new(); } } let mut first = true; loop { for ruleset in &mut rulesets { - if let Some(set) = ruleset.set4.as_mut().filter(|_| ruleset.dirty4) { + if let Some(set) = ruleset.0.set.as_mut().filter(|_| ruleset.0.dirty) { + if first { + println!( + "initializing set {} with ~{} ips (e.g. {:?})", + ruleset.0.name, + ruleset.0.ips.len(), + iter_ip_trie(&ruleset.0.ips).next(), + ); + } if let Err(err) = set.add_cidrs( &socket, first, - iter_ip_trie(&ruleset.ips4).map(IpNet::V4), + iter_ip_trie(&ruleset.0.ips).map(IpNet::V4), ) { report(err); } } - if let Some(set) = ruleset.set6.as_mut().filter(|_| ruleset.dirty6) { + if let Some(set) = ruleset.1.set.as_mut().filter(|_| ruleset.1.dirty) { + if first { + println!( + "initializing set {} with ~{} ips (e.g. {:?})", + ruleset.1.name, + ruleset.1.ips.len(), + iter_ip_trie(&ruleset.1.ips).next(), + ); + } if let Err(err) = set.add_cidrs( &socket, first, - iter_ip_trie(&ruleset.ips6).map(IpNet::V6), + iter_ip_trie(&ruleset.1.ips).map(IpNet::V6), ) { report(err); } } } - first = false; + if first { + println!("nftables init done"); + first = false; + } let res = match rx.recv() { Ok(val) => Some(val), Err(RecvError) => break, @@ -533,15 +635,15 @@ impl UnboundMod for ExampleMod { for ip1 in ips.iter().copied() { match ip1 { IpNet::V4(ip) => { - if ruleset.set4.is_some() && !should_add(&ruleset.ips4, &ip) { - ruleset.ips4.insert(ip); - ruleset.dirty4 = true; + if ruleset.0.set.is_some() && !should_add(&ruleset.0.ips, &ip) { + ruleset.0.ips.insert(ip); + ruleset.0.dirty = true; } } IpNet::V6(ip) => { - if ruleset.set6.is_some() && !should_add(&ruleset.ips6, &ip) { - ruleset.ips6.insert(ip); - ruleset.dirty6 = true; + if ruleset.1.set.is_some() && !should_add(&ruleset.1.ips, &ip) { + ruleset.1.ips.insert(ip); + ruleset.1.dirty = true; } } } @@ -550,6 +652,7 @@ impl UnboundMod for ExampleMod { } } }); + println!("loaded"); Ok(ret) } @@ -557,9 +660,20 @@ impl UnboundMod for ExampleMod { fn operate( &self, qstate: &mut crate::unbound::ModuleQstate, - _event: crate::unbound::ModuleEvent, + event: ModuleEvent, _entry: &mut crate::unbound::OutboundEntryMut, ) { + match event { + ModuleEvent::New | ModuleEvent::Pass => { + qstate.set_ext_state(ModuleExtState::WaitModule); + return; + } + ModuleEvent::ModDone => {} + _ => { + qstate.set_ext_state(ModuleExtState::Error); + return; + } + } let info = qstate.qinfo_mut(); let name = info.qname().to_bytes(); let rev_domain = name.strip_suffix(b".").unwrap_or(name); @@ -597,7 +711,7 @@ impl UnboundMod for ExampleMod { }; let _lock = self.domains_write_lock.lock().unwrap(); let mut old: Vec = if let Ok(file) = File::open(&file_name) { - match serde_json::from_reader(file) { + match read_json(file) { Ok(x) => x, Err(err) => { self.report("domains json", err); @@ -620,6 +734,7 @@ impl UnboundMod for ExampleMod { } } } + qstate.set_ext_state(ModuleExtState::Finished); return; } else if let Some(rev_domain) = self .tmp_nft_token @@ -640,6 +755,7 @@ impl UnboundMod for ExampleMod { } } } + qstate.set_ext_state(ModuleExtState::Finished); return; } let split_rev_domain = rev_domain @@ -653,6 +769,7 @@ impl UnboundMod for ExampleMod { } } if qnames.is_empty() { + qstate.set_ext_state(ModuleExtState::Finished); return; } if let Some(ret) = qstate.return_msg_mut() { @@ -695,6 +812,7 @@ impl UnboundMod for ExampleMod { } Err(err) => { self.report("domain utf-8", err); + qstate.set_ext_state(ModuleExtState::Error); return; } }; @@ -731,6 +849,7 @@ impl UnboundMod for ExampleMod { } } } + qstate.set_ext_state(ModuleExtState::Finished); } } @@ -746,7 +865,7 @@ mod test { use ipnet::Ipv4Net; use iptrie::RTrieSet; - use crate::example::{iter_ip_trie, should_add}; + use crate::example::{iter_ip_trie, should_add, IpNetDeser}; #[test] fn test() { @@ -766,5 +885,6 @@ mod test { assert!(dbg!(trie.iter().collect::>()).len() == 3); trie.insert(Ipv4Net::new(Ipv4Addr::new(127, 0, 1, 1), 32).unwrap()); assert!(dbg!(iter_ip_trie(&trie).collect::>()).len() == 2); + assert!(serde_json::from_str::>(r#"["127.0.0.1/8","127.0.0.1"]"#).is_ok()) } } diff --git a/src/lib.rs b/src/lib.rs index 4d85b4b..d470440 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,6 +9,7 @@ use std::panic::{RefUnwindSafe, UnwindSafe}; )] mod bindings; mod combine; +mod domain_tree; #[cfg(feature = "example")] mod example; mod exports; diff --git a/src/unbound.rs b/src/unbound.rs index 6136836..791e79b 100644 --- a/src/unbound.rs +++ b/src/unbound.rs @@ -1,10 +1,10 @@ #![allow(dead_code)] use crate::bindings::{ - config_file, dns_msg, in6_addr, in6_addr__bindgen_ty_1, in_addr, infra_cache, key_cache, - lruhash_entry, module_env, module_ev, module_qstate, outbound_entry, packed_rrset_data, - packed_rrset_key, query_info, reply_info, rrset_cache, rrset_id_type, rrset_trust, sec_status, - slabhash, sldns_enum_ede_code, sockaddr_in, sockaddr_in6, sockaddr_storage, - ub_packed_rrset_key, AF_INET, AF_INET6, + self, config_file, dns_msg, in6_addr, in6_addr__bindgen_ty_1, in_addr, infra_cache, key_cache, + lruhash_entry, module_env, module_ev, module_ext_state, module_qstate, outbound_entry, + packed_rrset_data, packed_rrset_key, query_info, reply_info, rrset_cache, rrset_id_type, + rrset_trust, sec_status, slabhash, sldns_enum_ede_code, sockaddr_in, sockaddr_in6, + sockaddr_storage, ub_packed_rrset_key, AF_INET, AF_INET6, }; use std::{ffi::CStr, marker::PhantomData, net::SocketAddr, os::raw::c_char, ptr, time::Duration}; @@ -235,6 +235,11 @@ impl ModuleQstate<'_, T> { )) } } + pub fn set_ext_state(&mut self, state: ModuleExtState) { + unsafe { + (*self.0).ext_state[self.1 as usize] = state as module_ext_state; + } + } } impl DnsMsgMut<'_> { @@ -441,29 +446,29 @@ impl From for ModuleEvent { #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] pub enum SecStatus { /// UNCHECKED means that object has yet to be validated. - Unchecked = 0, + Unchecked = bindings::sec_status_sec_status_unchecked, /// BOGUS means that the object (RRset or message) failed to validate\n (according to local policy), but should have validated. - Bogus = 1, + Bogus = bindings::sec_status_sec_status_bogus, /// INDETERMINATE means that the object is insecure, but not\n authoritatively so. Generally this means that the RRset is not\n below a configured trust anchor. - Indeterminate = 2, + Indeterminate = bindings::sec_status_sec_status_indeterminate, /// INSECURE means that the object is authoritatively known to be\n insecure. Generally this means that this RRset is below a trust\n anchor, but also below a verified, insecure delegation. - Insecure = 3, + Insecure = bindings::sec_status_sec_status_insecure, /// SECURE_SENTINEL_FAIL means that the object (RRset or message)\n validated according to local policy but did not succeed in the root\n KSK sentinel test (draft-ietf-dnsop-kskroll-sentinel). - SecureSentinelFail = 4, + SecureSentinelFail = bindings::sec_status_sec_status_secure_sentinel_fail, /// SECURE means that the object (RRset or message) validated\n according to local policy. - Secure = 5, - Unknown = 6, + Secure = bindings::sec_status_sec_status_secure, + Unknown = 99, } impl From for SecStatus { fn from(value: module_ev) -> Self { match value { - 0 => Self::Unchecked, - 1 => Self::Bogus, - 2 => Self::Indeterminate, - 3 => Self::Insecure, - 4 => Self::SecureSentinelFail, - 5 => Self::Secure, + bindings::sec_status_sec_status_unchecked => Self::Unchecked, + bindings::sec_status_sec_status_bogus => Self::Bogus, + bindings::sec_status_sec_status_indeterminate => Self::Indeterminate, + bindings::sec_status_sec_status_insecure => Self::Insecure, + bindings::sec_status_sec_status_secure_sentinel_fail => Self::SecureSentinelFail, + bindings::sec_status_sec_status_secure => Self::Secure, _ => Self::Unknown, } } @@ -540,67 +545,97 @@ impl From for SldnsEdeCode { #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] pub enum RrsetTrust { /// Initial value for trust - None = 0, + None = bindings::rrset_trust_rrset_trust_none, /// Additional information from non-authoritative answers - AddNoAa = 1, + AddNoAa = bindings::rrset_trust_rrset_trust_add_noAA, /// Data from the authority section of a non-authoritative answer - AuthNoAa = 2, + AuthNoAa = bindings::rrset_trust_rrset_trust_auth_noAA, /// Additional information from an authoritative answer - AddAa = 3, + AddAa = bindings::rrset_trust_rrset_trust_add_AA, /// non-authoritative data from the answer section of authoritative answers - NonauthAnsAa = 4, + NonauthAnsAa = bindings::rrset_trust_rrset_trust_nonauth_ans_AA, /// Data from the answer section of a non-authoritative answer - AnsNoAa = 5, + AnsNoAa = bindings::rrset_trust_rrset_trust_ans_noAA, /// Glue from a primary zone, or glue from a zone transfer - Glue = 6, + Glue = bindings::rrset_trust_rrset_trust_glue, /// Data from the authority section of an authoritative answer - AuthAa = 7, + AuthAa = bindings::rrset_trust_rrset_trust_auth_AA, /// The authoritative data included in the answer section of an\n authoritative reply - AnsAa = 8, + AnsAa = bindings::rrset_trust_rrset_trust_ans_AA, /// Data from a zone transfer, other than glue - SecNoglue = 9, + SecNoglue = bindings::rrset_trust_rrset_trust_sec_noglue, /// Data from a primary zone file, other than glue data - PrimNoglue = 10, + PrimNoglue = bindings::rrset_trust_rrset_trust_prim_noglue, /// DNSSEC(rfc4034) validated with trusted keys - Validated = 11, + Validated = bindings::rrset_trust_rrset_trust_validated, /// Ultimately trusted, no more trust is possible, /// trusted keys from the unbound configuration setup. - Ultimate = 12, - Unknown = 13, + Ultimate = bindings::rrset_trust_rrset_trust_ultimate, + Unknown = 99, } impl From for RrsetTrust { fn from(value: rrset_trust) -> Self { match value { - 0 => Self::None, - 1 => Self::AddNoAa, - 2 => Self::AuthNoAa, - 3 => Self::AddAa, - 4 => Self::NonauthAnsAa, - 5 => Self::AnsNoAa, - 6 => Self::Glue, - 7 => Self::AuthAa, - 8 => Self::AnsAa, - 9 => Self::SecNoglue, - 10 => Self::PrimNoglue, - 11 => Self::Validated, - 12 => Self::Ultimate, + bindings::rrset_trust_rrset_trust_none => Self::None, + bindings::rrset_trust_rrset_trust_add_noAA => Self::AddNoAa, + bindings::rrset_trust_rrset_trust_auth_noAA => Self::AuthNoAa, + bindings::rrset_trust_rrset_trust_add_AA => Self::AddAa, + bindings::rrset_trust_rrset_trust_nonauth_ans_AA => Self::NonauthAnsAa, + bindings::rrset_trust_rrset_trust_ans_noAA => Self::AnsNoAa, + bindings::rrset_trust_rrset_trust_glue => Self::Glue, + bindings::rrset_trust_rrset_trust_auth_AA => Self::AuthAa, + bindings::rrset_trust_rrset_trust_ans_AA => Self::AnsAa, + bindings::rrset_trust_rrset_trust_sec_noglue => Self::SecNoglue, + bindings::rrset_trust_rrset_trust_prim_noglue => Self::PrimNoglue, + bindings::rrset_trust_rrset_trust_validated => Self::Validated, + bindings::rrset_trust_rrset_trust_ultimate => Self::Ultimate, + _ => Self::Unknown, + } + } +} + +#[non_exhaustive] +#[repr(u32)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub enum ModuleExtState { + InitialState = bindings::module_ext_state_module_state_initial, + WaitReply = bindings::module_ext_state_module_wait_reply, + WaitModule = bindings::module_ext_state_module_wait_module, + RestartNext = bindings::module_ext_state_module_restart_next, + WaitSubquery = bindings::module_ext_state_module_wait_subquery, + Error = bindings::module_ext_state_module_error, + Finished = bindings::module_ext_state_module_finished, + Unknown = 99, +} + +impl From for ModuleExtState { + fn from(value: module_ext_state) -> Self { + match value { + bindings::module_ext_state_module_state_initial => Self::InitialState, + bindings::module_ext_state_module_wait_reply => Self::WaitReply, + bindings::module_ext_state_module_wait_module => Self::WaitModule, + bindings::module_ext_state_module_restart_next => Self::RestartNext, + bindings::module_ext_state_module_wait_subquery => Self::WaitSubquery, + bindings::module_ext_state_module_error => Self::Error, + bindings::module_ext_state_module_finished => Self::Finished, _ => Self::Unknown, } } } pub mod rr_class { + use crate::bindings; /// the Internet - pub const IN: u16 = 1; + pub const IN: u16 = bindings::sldns_enum_rr_class_LDNS_RR_CLASS_IN as u16; /// Chaos class - pub const CH: u16 = 3; + pub const CH: u16 = bindings::sldns_enum_rr_class_LDNS_RR_CLASS_CH as u16; /// Hesiod (Dyer 87) - pub const HS: u16 = 4; + pub const HS: u16 = bindings::sldns_enum_rr_class_LDNS_RR_CLASS_HS as u16; /// None class, dynamic update - pub const NONE: u16 = 254; + pub const NONE: u16 = bindings::sldns_enum_rr_class_LDNS_RR_CLASS_NONE as u16; /// Any class - pub const ANY: u16 = 255; + pub const ANY: u16 = bindings::sldns_enum_rr_class_LDNS_RR_CLASS_ANY as u16; } pub mod rr_type {