From 3df012a6dfd1ed84358f3581e5f7f200576fa1e3 Mon Sep 17 00:00:00 2001 From: chayleaf Date: Sun, 11 Aug 2024 00:54:48 +0700 Subject: [PATCH] more fixes --- Cargo.lock | 29 +++- Cargo.toml | 3 +- FIXME | 2 - flake.nix | 4 +- libnftnl-fix.patch | 24 +++ src/example.rs | 333 ++++++++++++++++++++------------------ src/nftables.rs | 383 ++++++++++++++++++++++++++++---------------- src/nftables_lib.rs | 15 ++ 8 files changed, 494 insertions(+), 299 deletions(-) delete mode 100644 FIXME create mode 100644 libnftnl-fix.patch create mode 100644 src/nftables_lib.rs diff --git a/Cargo.lock b/Cargo.lock index e696a54..17b333d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,12 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bitflags" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" + [[package]] name = "boxcar" version = "0.2.5" @@ -20,6 +26,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "ctor" version = "0.2.8" @@ -119,7 +131,7 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e9201688bd0bc571dfa4c21ce0a525480c8b782776cf88e12571fa89108dd920" dependencies = [ - "bitflags", + "bitflags 1.3.2", "err-derive", "log", "nftnl-sys", @@ -145,6 +157,18 @@ dependencies = [ "smallvec", ] +[[package]] +name = "nix" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +dependencies = [ + "bitflags 2.6.0", + "cfg-if", + "cfg_aliases", + "libc", +] + [[package]] name = "pkg-config" version = "0.3.30" @@ -303,9 +327,8 @@ dependencies = [ "iptrie", "libc", "mnl", - "mnl-sys", "nftnl", - "nftnl-sys", + "nix", "prefix-tree", "radix_trie", "serde", diff --git a/Cargo.toml b/Cargo.toml index dd8b562..02df0bc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,9 +15,8 @@ ipnet = { version = "2.9.0", features = ["serde"] } iptrie = "0.8.5" libc = "0.2.155" mnl = { version = "0.2.2", features = ["mnl-1-0-4"] } -mnl-sys = { version = "0.2.1", features = ["mnl-1-0-4"] } nftnl = { version = "0.6.2", features = ["nftnl-1-1-2"] } -nftnl-sys = { version = "0.6.1", features = ["nftnl-1-1-2"] } +nix = { version = "0.29.0", features = ["poll"] } prefix-tree = "0.5.0" radix_trie = "0.2.1" serde = { version = "1.0.205", features = ["derive"] } diff --git a/FIXME b/FIXME deleted file mode 100644 index 6099071..0000000 --- a/FIXME +++ /dev/null @@ -1,2 +0,0 @@ -nftables -token is after, not before diff --git a/flake.nix b/flake.nix index 3a6f16e..655c1e7 100644 --- a/flake.nix +++ b/flake.nix @@ -25,7 +25,9 @@ in pkgs.mkShell rec { name = "unbound-rust-mod-shell"; LIBMNL_LIB_DIR = "${nixpkgs.lib.getLib pkgs.libmnl}/lib"; - LIBNFTNL_LIB_DIR = "${nixpkgs.lib.getLib pkgs.libnftnl}/lib"; + LIBNFTNL_LIB_DIR = "${nixpkgs.lib.getLib (pkgs.libnftnl.overrideAttrs (old: { + patches = (old.patches or []) ++ [ ./libnftnl-fix.patch ]; + }))}/lib"; LD_LIBRARY_PATH = "${LIBMNL_LIB_DIR}:${LIBNFTNL_LIB_DIR}"; }; }; diff --git a/libnftnl-fix.patch b/libnftnl-fix.patch new file mode 100644 index 0000000..e2b37e0 --- /dev/null +++ b/libnftnl-fix.patch @@ -0,0 +1,24 @@ +diff --git a/src/libnftnl.map b/src/libnftnl.map +index 8fffff1..3f660de 100644 +--- a/src/libnftnl.map ++++ b/src/libnftnl.map +@@ -129,6 +129,7 @@ global: + nftnl_set_get_str; + nftnl_set_get_u32; + nftnl_set_get_u64; ++ nftnl_set_clone; + nftnl_set_nlmsg_build_payload; + nftnl_set_nlmsg_parse; + nftnl_set_parse; +diff --git a/src/set.c b/src/set.c +index 07e332d..c5f9518 100644 +--- a/src/set.c ++++ b/src/set.c +@@ -352,6 +352,7 @@ uint64_t nftnl_set_get_u64(const struct nftnl_set *s, uint16_t attr) + return val ? *val : 0; + } + ++EXPORT_SYMBOL(nftnl_set_clone); + struct nftnl_set *nftnl_set_clone(const struct nftnl_set *set) + { + struct nftnl_set *newset; diff --git a/src/example.rs b/src/example.rs index b403380..adfe848 100644 --- a/src/example.rs +++ b/src/example.rs @@ -1,29 +1,27 @@ use std::{ collections::HashMap, - ffi::CString, fmt::Display, fs::File, io::{self, BufRead, BufReader, Write}, net::{Ipv4Addr, Ipv6Addr}, - os::raw::c_char, path::{Path, PathBuf}, str::FromStr, sync::{ mpsc::{self, RecvError}, Mutex, RwLock, }, - time::{Duration, Instant, SystemTime}, + time::{Duration, SystemTime}, }; use ctor::ctor; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; -use iptrie::{Ipv4Prefix, Ipv6Prefix, RTrieSet}; -use nftnl::set::SetKey; +use iptrie::{IpPrefix, RTrieSet}; use prefix_tree::PrefixSet; use serde::Deserialize; use smallvec::SmallVec; use crate::{ + nftables::Set1, unbound::{rr_class, rr_type}, UnboundMod, }; @@ -199,12 +197,19 @@ impl IpCache { } struct NftData { - ips4: iptrie::Ipv4RTrieSet, - ips6: iptrie::Ipv6RTrieSet, - name4: CString, - name6: CString, + ips4: RTrieSet, + ips6: RTrieSet, + dirty4: bool, + dirty6: bool, + set4: Option, + set6: Option, + name4: String, + name6: String, } +// SAFETY: set4/set6 are None initially and are never actually sent +unsafe impl Send for NftData {} + struct NftQuery { domains: RwLock>, dynamic: bool, @@ -242,16 +247,61 @@ struct DpiInfo { // restriction: {"code": "ban"} } +trait Helper: iptrie::IpPrefix + PartialEq { + const ZERO: Self; + fn direct_parent(&self) -> Option; +} + +impl Helper for Ipv4Net { + const ZERO: Self = match Self::new(Ipv4Addr::UNSPECIFIED, 0) { + Ok(x) => x, + #[allow(clippy::empty_loop)] + Err(_) => loop {}, + }; + fn direct_parent(&self) -> Option { + self.len() + .checked_sub(1) + .and_then(|x| Self::new(self.bitslot().into(), x).ok()) + } +} + +impl Helper for Ipv6Net { + const ZERO: Self = match Self::new(Ipv6Addr::UNSPECIFIED, 0) { + Ok(x) => x, + #[allow(clippy::empty_loop)] + Err(_) => loop {}, + }; + fn direct_parent(&self) -> Option { + self.len() + .checked_sub(1) + .and_then(|x| Self::new(self.bitslot().into(), x).ok()) + } +} + +fn should_add(trie: &RTrieSet, elem: &T) -> bool { + *trie.lookup(elem) == T::ZERO +} + +fn iter_ip_trie(trie: &RTrieSet) -> impl '_ + Iterator { + trie.iter().copied().filter(|x| { + if let Some(par) = x.direct_parent() { + should_add(trie, &par) + } else { + *x != T::ZERO + } + }) +} + impl UnboundMod for ExampleMod { type EnvData = (); type QstateData = (); fn init(_env: &mut crate::unbound::ModuleEnv) -> Result { let mut ret = Self { nft_token: std::env::var_os("NFT_TOKEN") - .map(|x| x.to_str().ok_or(()).map(|s| ".".to_owned() + s)) + .map(|x| x.to_str().ok_or(()).map(|s| s.to_owned() + ".")) .transpose()?, tmp_nft_token: std::env::var_os("NFT_TOKEN") - .map(|x| x.to_str().ok_or(()).map(|s| ".tmp".to_owned() + s)) + .map(|x| x.to_str().ok_or(()).map(|s| s.to_owned() + ".tmp.")) .transpose()?, ..Self::default() }; @@ -295,10 +345,14 @@ impl UnboundMod for ExampleMod { }, ); rulesets.push(NftData { + set4: None, + set6: None, ips4: RTrieSet::new(), ips6: RTrieSet::new(), - name4: CString::from_vec_with_nul((set4.to_owned() + "\0").into()).unwrap(), - name6: CString::from_vec_with_nul((set6.to_owned() + "\0").into()).unwrap(), + dirty4: true, + dirty6: true, + name4: set4.to_owned(), + name6: set6.to_owned(), }); } } @@ -352,14 +406,14 @@ impl UnboundMod for ExampleMod { Ok(ips) => { r.ips4.extend(ips.iter().filter_map(|x| { if let IpNet::V4(x) = x { - Ipv4Prefix::new(x.addr(), x.prefix_len()).ok() + Some(*x) } else { None } })); r.ips6.extend(ips.iter().filter_map(|x| { if let IpNet::V6(x) = x { - Ipv6Prefix::new(x.addr(), x.prefix_len()).ok() + Some(*x) } else { None } @@ -378,7 +432,7 @@ impl UnboundMod for ExampleMod { .into(), |val| { if let Some(val) = val { - r.ips4.extend(val.iter().map(|x| Ipv4Prefix::from(*x))); + r.ips4.extend(val.iter().map(|x| Ipv4Net::from(*x))); } None }, @@ -392,7 +446,7 @@ impl UnboundMod for ExampleMod { .into(), |val| { if let Some(val) = val { - r.ips6.extend(val.iter().map(|x| Ipv6Prefix::from(*x))); + r.ips6.extend(val.iter().map(|x| Ipv6Net::from(*x))); } None }, @@ -403,162 +457,97 @@ impl UnboundMod for ExampleMod { // add stuff to nftables let (tx, rx) = mpsc::channel(); + ret.ruleset_queue = Some(tx); std::thread::spawn(move || { - let table = nftnl::Table::new( - &CString::from_vec_with_nul(b"global\0".to_vec()).unwrap(), - nftnl::ProtoFamily::Inet, - ); - let mut first = true; - let mut bufs = vec![Vec::::new(); rulesets.len()]; - let mut len = 0; - let mut queue_start = Instant::now(); - loop { - let res = if len == 0 { - match rx.recv() { - Ok(val) => { - queue_start = Instant::now(); - Some(val) - } - Err(RecvError) => break, + fn report(err: impl Display) { + if let Ok(mut file) = std::fs::OpenOptions::new() + .append(true) + .open("/var/lib/unbound/nftables.log") + { + if file.write_all(err.to_string().as_bytes()).is_err() { + return; } - } else { - match rx.recv_timeout((queue_start + Duration::from_secs(30)) - Instant::now()) + if file.write_all(b"\n").is_err() { + return; + } + file.flush().unwrap_or(()); + } + } + let socket = mnl::Socket::new(mnl::Bus::Netfilter).unwrap(); + let all_sets = crate::nftables::get_sets(&socket).unwrap(); + for set in all_sets { + for ruleset in &mut rulesets { + if set.table_name() == Some("global") + && set.family() == libc::NFPROTO_INET as u32 { - Ok(val) => Some(val), - Err(mpsc::RecvTimeoutError::Timeout) => None, - Err(mpsc::RecvTimeoutError::Disconnected) => break, + if set.name() == Some(&ruleset.name4) { + ruleset.set4 = Some(set.clone()); + } else if set.name() == Some(&ruleset.name6) { + ruleset.set6 = Some(set.clone()); + } } + } + } + for ruleset in &mut rulesets { + if !ruleset.name4.is_empty() && ruleset.set4.is_none() { + report(format!("set {} not found", ruleset.name4)); + ruleset.ips4 = RTrieSet::new(); + } + if !ruleset.name6.is_empty() && ruleset.set6.is_none() { + report(format!("set {} not found", ruleset.name6)); + ruleset.ips6 = RTrieSet::new(); + } + } + let mut first = true; + loop { + for ruleset in &mut rulesets { + if let Some(set) = ruleset.set4.as_mut().filter(|_| ruleset.dirty4) { + if let Err(err) = set.add_cidrs( + &socket, + first, + iter_ip_trie(&ruleset.ips4).map(IpNet::V4), + ) { + report(err); + } + } + if let Some(set) = ruleset.set6.as_mut().filter(|_| ruleset.dirty6) { + if let Err(err) = set.add_cidrs( + &socket, + first, + iter_ip_trie(&ruleset.ips6).map(IpNet::V6), + ) { + report(err); + } + } + } + first = false; + let res = match rx.recv() { + Ok(val) => Some(val), + Err(RecvError) => break, }; - let do_it = - res.is_none() || (Instant::now() - queue_start) > Duration::from_secs(25); if let Some((rulesets1, ips)) = res { - for ruleset in rulesets1 { + for i in rulesets1.into_iter() { + let ruleset = &mut rulesets[i]; for ip1 in ips.iter().copied() { match ip1 { IpNet::V4(ip) => { - if !rulesets[ruleset].ips4.contains(&ip) { - rulesets[ruleset].ips4.insert(ip.into()); - bufs[ruleset].push(ip1); + if ruleset.set4.is_some() && !should_add(&ruleset.ips4, &ip) { + ruleset.ips4.insert(ip); + ruleset.dirty4 = true; } } IpNet::V6(ip) => { - if !rulesets[ruleset].ips6.contains(&ip) { - rulesets[ruleset].ips6.insert(ip.into()); - bufs[ruleset].push(ip1); - len += 1; + if ruleset.set6.is_some() && !should_add(&ruleset.ips6, &ip) { + ruleset.ips6.insert(ip); + ruleset.dirty6 = true; } } } } } } - struct FlushSetMsg<'a, T> { - set: &'a nftnl::set::Set<'a, T>, - } - unsafe impl<'a, T> nftnl::NlMsg for FlushSetMsg<'a, T> { - unsafe fn write( - &self, - buf: *mut std::ffi::c_void, - seq: u32, - _msg_type: nftnl::MsgType, - ) { - let header = nftnl_sys::nftnl_nlmsg_build_hdr( - buf as *mut c_char, - libc::NFT_MSG_DELSETELEM as u16, - self.set.get_family() as u16, - 0, - seq, - ); - nftnl_sys::nftnl_set_elems_nlmsg_build_payload(header, self.set.as_ptr()); - } - } - if do_it || len >= 128 { - let mut batch = nftnl::Batch::new(); - for (ruleset, buf) in rulesets.iter().zip(bufs.iter_mut()) { - // internally represented as a range - struct Cidr(T); - impl SetKey for Cidr { - const TYPE: u32 = Ipv4Addr::TYPE; - const LEN: u32 = Ipv4Addr::LEN * 2; - fn data(&self) -> Box<[u8]> { - let data = u32::from_be_bytes(self.0.network().octets()); - let mask = u32::from_be_bytes(self.0.netmask().octets()); - let mut ret = [0u8; (Self::LEN) as usize]; - ret[..(Self::LEN as usize)] - .copy_from_slice(&self.0.network().octets()); - ret[(Self::LEN as usize)..] - .copy_from_slice(&u32::to_be_bytes(!mask | data)); - Box::new(ret) - } - } - impl SetKey for Cidr { - const TYPE: u32 = Ipv6Addr::TYPE; - const LEN: u32 = Ipv6Addr::LEN * 2; - fn data(&self) -> Box<[u8]> { - let data = u128::from_be_bytes(self.0.network().octets()); - let mask = u128::from_be_bytes(self.0.netmask().octets()); - let mut ret = [0u8; (Self::LEN) as usize]; - ret[..(Self::LEN as usize)] - .copy_from_slice(&self.0.network().octets()); - ret[(Self::LEN as usize)..] - .copy_from_slice(&u128::to_be_bytes(!mask | data)); - Box::new(ret) - } - } - let set4 = nftnl::set::Set::>::new( - &ruleset.name4, - 0, - &table, - nftnl::ProtoFamily::Ipv4, - ); - let set6 = nftnl::set::Set::>::new( - &ruleset.name6, - 0, - &table, - nftnl::ProtoFamily::Ipv6, - ); - if first { - batch.add(&FlushSetMsg { set: &set4 }, nftnl::MsgType::Del); - batch.add(&FlushSetMsg { set: &set6 }, nftnl::MsgType::Del); - } - let mut set4 = nftnl::set::Set::new( - &ruleset.name4, - 0, - &table, - nftnl::ProtoFamily::Ipv4, - ); - let mut set6 = nftnl::set::Set::new( - &ruleset.name6, - 0, - &table, - nftnl::ProtoFamily::Ipv6, - ); - let mut added4 = false; - let mut added6 = false; - for ip in buf.drain(..) { - match ip { - IpNet::V4(ip) => { - set4.add(&Cidr(ip)); - added4 = true; - } - IpNet::V6(ip) => { - set6.add(&Cidr(ip)); - added6 = true; - } - } - } - if added4 { - batch.add_iter(set4.elems_iter(), nftnl::MsgType::Add); - } - if added6 { - batch.add_iter(set6.elems_iter(), nftnl::MsgType::Add); - } - } - len = 0; - first = false; - } } }); @@ -577,12 +566,12 @@ impl UnboundMod for ExampleMod { if let Some(rev_domain) = self .nft_token .as_ref() - .and_then(|token| rev_domain.strip_suffix(token.as_bytes())) + .and_then(|token| rev_domain.strip_prefix(token.as_bytes())) { for (qname, query) in self.nft_queries.iter() { if query.dynamic && rev_domain.ends_with(qname.as_bytes()) { if let Some(rev_domain) = - rev_domain.strip_suffix((".".to_owned() + qname).as_bytes()) + rev_domain.strip_prefix((qname.to_owned() + ".").as_bytes()) { let rev_domain = rev_domain .split(|x| *x == b'.') @@ -635,12 +624,12 @@ impl UnboundMod for ExampleMod { } else if let Some(rev_domain) = self .tmp_nft_token .as_ref() - .and_then(|token| rev_domain.strip_suffix(token.as_bytes())) + .and_then(|token| rev_domain.strip_prefix(token.as_bytes())) { for (qname, query) in self.nft_queries.iter() { if query.dynamic && rev_domain.ends_with(qname.as_bytes()) { if let Some(rev_domain) = - rev_domain.strip_suffix((".".to_owned() + qname).as_bytes()) + rev_domain.strip_prefix((qname.to_owned() + ".").as_bytes()) { let rev_domain = rev_domain .split(|x| *x == b'.') @@ -749,3 +738,33 @@ impl UnboundMod for ExampleMod { fn setup() { crate::set_unbound_mod::(); } + +#[cfg(test)] +mod test { + use std::net::Ipv4Addr; + + use ipnet::Ipv4Net; + use iptrie::RTrieSet; + + use crate::example::{iter_ip_trie, should_add}; + + #[test] + fn test() { + let mut trie = RTrieSet::new(); + assert!(should_add( + &trie, + &Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap() + )); + trie.insert(Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap()); + assert!(!should_add( + &trie, + &Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap() + )); + trie.insert(Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 31).unwrap()); + assert!(dbg!(iter_ip_trie(&trie).collect::>()).len() == 1); + // contains 0.0.0.0, etc + assert!(dbg!(trie.iter().collect::>()).len() == 3); + trie.insert(Ipv4Net::new(Ipv4Addr::new(127, 0, 1, 1), 32).unwrap()); + assert!(dbg!(iter_ip_trie(&trie).collect::>()).len() == 2); + } +} diff --git a/src/nftables.rs b/src/nftables.rs index 7c94887..7eab454 100644 --- a/src/nftables.rs +++ b/src/nftables.rs @@ -1,98 +1,68 @@ use std::{ cell::Cell, + ffi::CStr, io, net::{Ipv4Addr, Ipv6Addr}, - os::raw::{c_char, c_void}, + os::{ + fd::BorrowedFd, + raw::{c_char, c_void}, + }, rc::Rc, }; -use ipnet::{Ipv4Net, Ipv6Net}; -use nftnl::{ - set::{Set, SetKey}, - FinalizedBatch, MsgType, NlMsg, -}; +use ipnet::{IpNet, Ipv4Net, Ipv6Net}; +use nftnl::{nftnl_sys, set::SetKey, Batch, FinalizedBatch, MsgType, NlMsg}; +use mnl::mnl_sys; -// internally represented as a range -struct Cidr(T); -impl SetKey for Cidr { - const TYPE: u32 = Ipv4Addr::TYPE; - const LEN: u32 = Ipv4Addr::LEN * 2; - fn data(&self) -> Box<[u8]> { - let data = u32::from_be_bytes(self.0.network().octets()); - let mask = u32::from_be_bytes(self.0.netmask().octets()); - let mut ret = [0u8; (Self::LEN) as usize]; - ret[..(Ipv4Addr::LEN as usize)].copy_from_slice(&self.0.network().octets()); - ret[(Ipv4Addr::LEN as usize)..].copy_from_slice(&u32::to_be_bytes(!mask | data)); - println!("{ret:?} {:?}", self.0.addr().data()); - Box::new(ret) - } -} -impl SetKey for Cidr { - const TYPE: u32 = Ipv6Addr::TYPE; - const LEN: u32 = Ipv6Addr::LEN * 2; - fn data(&self) -> Box<[u8]> { - let data = u128::from_be_bytes(self.0.network().octets()); - let mask = u128::from_be_bytes(self.0.netmask().octets()); - let mut ret = [0u8; (Self::LEN) as usize]; - ret[..(Ipv6Addr::LEN as usize)].copy_from_slice(&self.0.network().octets()); - ret[(Ipv6Addr::LEN as usize)..].copy_from_slice(&u128::to_be_bytes(!mask | data)); - Box::new(ret) +fn cidr_bound_ipv4(net: Ipv4Net) -> Option { + let data = u32::from(net.network()); + let mask = u32::from(net.netmask()); + let ip = (!mask | data).wrapping_add(1); + if ip == 0 { + None + } else { + Some(ip.into()) } } -struct FlushSetMsg<'a, T> { - set: &'a Set<'a, T>, +fn cidr_bound_ipv6(net: Ipv6Net) -> Option { + let data = u128::from_be_bytes(net.network().octets()); + let mask = u128::from_be_bytes(net.netmask().octets()); + let ip = (!mask | data).wrapping_add(1); + if ip == 0 { + None + } else { + Some(ip.into()) + } } -unsafe impl<'a, T> NlMsg for FlushSetMsg<'a, T> { + +#[must_use] +struct FlushSetMsg<'a> { + set: &'a Set1, +} +unsafe impl<'a> NlMsg for FlushSetMsg<'a> { unsafe fn write(&self, buf: *mut std::ffi::c_void, seq: u32, _msg_type: MsgType) { let header = nftnl_sys::nftnl_nlmsg_build_hdr( buf as *mut c_char, libc::NFT_MSG_DELSETELEM as u16, - self.set.get_family() as u16, + self.set.family() as u16, 0, seq, ); - nftnl_sys::nftnl_set_elems_nlmsg_build_payload(header, self.set.as_ptr()); + nftnl_sys::nftnl_set_elems_nlmsg_build_payload(header, self.set.as_mut_ptr()); } } -pub fn send_and_process(batch: &FinalizedBatch) -> io::Result<()> { - let socket = mnl::Socket::new(mnl::Bus::Netfilter)?; - eprintln!("a"); - socket.send_all(batch)?; - eprintln!("b"); - let portid = socket.portid(); - let mut buf = vec![0; nftnl::nft_nlmsg_maxsize() as usize]; - loop { - eprintln!("c"); - let n = socket.recv(&mut buf[..])?; - eprintln!("d {n}"); - if n == 0 { - break; - } - match mnl::cb_run(&buf[..n], 2, portid)? { - mnl::CbResult::Stop => { - println!("stop"); - break; - } - mnl::CbResult::Ok => { - println!("ok"); - } - } - } - Ok(()) -} - -pub struct SetElemsIter<'a, K> { - set: &'a Set<'a, K>, +pub struct SetElemsIter<'a> { + set: &'a Set1, iter: *mut nftnl_sys::nftnl_set_elems_iter, ret: Rc>, is_first: bool, } -impl<'a, K> SetElemsIter<'a, K> { - fn new(set: &'a Set<'a, K>) -> Self { - let iter = unsafe { nftnl_sys::nftnl_set_elems_iter_create(set.as_ptr()) }; +impl<'a> SetElemsIter<'a> { + fn new(set: &'a Set1) -> Self { + let iter = unsafe { nftnl_sys::nftnl_set_elems_iter_create(set.as_mut_ptr()) }; if iter.is_null() { panic!("oom"); } @@ -105,8 +75,8 @@ impl<'a, K> SetElemsIter<'a, K> { } } -impl<'a, K: 'a> Iterator for SetElemsIter<'a, K> { - type Item = SetElemsMsg<'a, K>; +impl<'a> Iterator for SetElemsIter<'a> { + type Item = SetElemsMsg<'a>; fn next(&mut self) -> Option { if self.is_first { @@ -128,31 +98,31 @@ impl<'a, K: 'a> Iterator for SetElemsIter<'a, K> { } } -impl<'a, K> Drop for SetElemsIter<'a, K> { +impl<'a> Drop for SetElemsIter<'a> { fn drop(&mut self) { unsafe { nftnl_sys::nftnl_set_elems_iter_destroy(self.iter) }; } } -pub struct SetElemsMsg<'a, K> { - set: &'a Set<'a, K>, +pub struct SetElemsMsg<'a> { + set: &'a Set1, iter: *mut nftnl_sys::nftnl_set_elems_iter, ret: Rc>, } -unsafe impl<'a, K> NlMsg for SetElemsMsg<'a, K> { +unsafe impl<'a> NlMsg for SetElemsMsg<'a> { unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { let (type_, flags) = match msg_type { MsgType::Add => ( libc::NFT_MSG_NEWSETELEM, - libc::NLM_F_CREATE | libc::NLM_F_EXCL | libc::NLM_F_ACK, + libc::NLM_F_CREATE | libc::NLM_F_EXCL, ), - MsgType::Del => (libc::NFT_MSG_DELSETELEM, libc::NLM_F_ACK), + MsgType::Del => (libc::NFT_MSG_DELSETELEM, 0), }; let header = nftnl_sys::nftnl_nlmsg_build_hdr( buf as *mut c_char, type_ as u16, - self.set.get_family() as u16, + self.set.family() as u16, flags as u16, seq, ); @@ -163,55 +133,199 @@ unsafe impl<'a, K> NlMsg for SetElemsMsg<'a, K> { } } -fn add(set: &Set, key: &K) { - let data = key.data(); - let data_len = data.len() as u32; - unsafe { - let elem = nftnl_sys::nftnl_set_elem_alloc(); - if elem.is_null() { - panic!("oom"); +fn send_and_process(socket: &mnl::Socket, batch: &FinalizedBatch) -> io::Result<()> { + socket.send_all(batch)?; + let portid = socket.portid(); + let mut buf = vec![0; nftnl::nft_nlmsg_maxsize() as usize]; + let fd = unsafe { mnl_sys::mnl_socket_get_fd(socket.as_raw_socket()) }; + let mut readfds = nix::sys::select::FdSet::new(); + let fd1 = unsafe { BorrowedFd::borrow_raw(fd) }; + let mut tv = nix::sys::time::TimeVal::new(0, 0); + loop { + readfds.clear(); + readfds.insert(fd1); + if nix::sys::select::select(fd + 1, &mut readfds, None, None, &mut tv)? <= 0 { + break; } - nftnl_sys::nftnl_set_elem_set( - elem, - nftnl_sys::NFTNL_SET_ELEM_KEY as u16, - data.as_ptr() as *const c_void, - data_len / 2, - ); - nftnl_sys::nftnl_set_elem_set_u32( - elem, - nftnl_sys::NFTNL_SET_ELEM_FLAGS as u16, - 1, - ); - nftnl_sys::nftnl_set_elem_add(set.as_ptr(), elem); - - let elem = nftnl_sys::nftnl_set_elem_alloc(); - if elem.is_null() { - panic!("oom"); + if !readfds.contains(fd1) { + break; + } + let msglen = socket.recv(&mut buf)?; + match mnl::cb_run(&buf[..msglen], 0, portid)? { + mnl::CbResult::Stop => { + break; + } + mnl::CbResult::Ok => (), } - nftnl_sys::nftnl_set_elem_set( - elem, - nftnl_sys::NFTNL_SET_ELEM_KEY as u16, - data.as_ptr().add((data_len / 2) as usize) as *const c_void, - data_len / 2, - ); - // nftnl_sys::nftnl_set_elem_set_u32( - // elem, - // nftnl_sys::NFTNL_SET_ELEM_FLAGS as u16, - // libc::NFT_SET_ELEM_INTERVAL_END as u32, - // ); - nftnl_sys::nftnl_set_elem_add(set.as_ptr(), elem); } + Ok(()) +} + +pub struct Set1(*mut nftnl_sys::nftnl_set); +impl Set1 { + pub fn new() -> Self { + Self(unsafe { nftnl_sys::nftnl_set_alloc() }) + } + pub fn as_mut_ptr(&self) -> *mut nftnl_sys::nftnl_set { + self.0 + } + pub fn table_name(&self) -> Option<&str> { + let ret = + unsafe { nftnl_sys::nftnl_set_get_str(self.0, nftnl_sys::NFTNL_SET_TABLE as u16) }; + (!ret.is_null()) + .then(|| unsafe { CStr::from_ptr(ret) }.to_str().ok()) + .flatten() + } + pub fn name(&self) -> Option<&str> { + let ret = unsafe { nftnl_sys::nftnl_set_get_str(self.0, nftnl_sys::NFTNL_SET_NAME as u16) }; + (!ret.is_null()) + .then(|| unsafe { CStr::from_ptr(ret) }.to_str().ok()) + .flatten() + } + pub fn family(&self) -> u32 { + unsafe { nftnl_sys::nftnl_set_get_u32(self.0, nftnl_sys::NFTNL_SET_FAMILY as u16) } + } + pub fn add_range(&mut self, lower: &K, excl_upper: Option<&K>) { + let data1 = lower.data(); + let data1_len = data1.len() as u32; + unsafe { + let elem = nftnl_sys::nftnl_set_elem_alloc(); + if elem.is_null() { + panic!("oom"); + } + nftnl_sys::nftnl_set_elem_set( + elem, + nftnl_sys::NFTNL_SET_ELEM_KEY as u16, + data1.as_ptr() as *const c_void, + data1_len, + ); + nftnl_sys::nftnl_set_elem_add(self.as_mut_ptr(), elem); + + let Some(data2) = excl_upper.map(|key| key.data()) else { + return; + }; + let data2_len = data2.len() as u32; + + let elem = nftnl_sys::nftnl_set_elem_alloc(); + if elem.is_null() { + panic!("oom"); + } + nftnl_sys::nftnl_set_elem_set( + elem, + nftnl_sys::NFTNL_SET_ELEM_KEY as u16, + data2.as_ptr() as *const c_void, + data2_len, + ); + nftnl_sys::nftnl_set_elem_set_u32( + elem, + nftnl_sys::NFTNL_SET_ELEM_FLAGS as u16, + libc::NFT_SET_ELEM_INTERVAL_END as u32, + ); + nftnl_sys::nftnl_set_elem_add(self.as_mut_ptr(), elem); + } + } + pub fn add_cidrs(&self, socket: &mnl::Socket, flush: bool, cidrs: impl IntoIterator) -> io::Result<()> { + let mut batch = Batch::new(); + // FIXME: why 2048? + let max_batch_size = 2048; + let mut count = 0; + let mut set = self.clone(); + if flush { + count += 1; + batch.add(&set.flush_msg(), nftnl::MsgType::Del); + } + for net in cidrs.into_iter() { + if count + 2 > max_batch_size { + batch.add_iter(SetElemsIter::new(&set), MsgType::Add); + send_and_process(socket, &batch.finalize())?; + set = self.clone(); + batch = Batch::new(); + } + match net { + IpNet::V4(ip) => { + set.add_range(&ip.network(), cidr_bound_ipv4(ip).as_ref()); + } + IpNet::V6(ip) => { + set.add_range(&ip.network(), cidr_bound_ipv6(ip).as_ref()); + } + } + count += 2; + } + batch.add_iter(SetElemsIter::new(&set), MsgType::Add); + send_and_process(socket, &batch.finalize()) + } + + fn flush_msg(&self) -> FlushSetMsg<'_> { + FlushSetMsg { set: self } + } +} + +impl Clone for Set1 { + fn clone(&self) -> Self { + Self(unsafe { nftnl_sys::nftnl_set_clone(self.0) }) + } +} + +pub fn get_sets(socket: &mnl::Socket) -> io::Result> { + let mut buffer = vec![0; nftnl::nft_nlmsg_maxsize() as usize]; + let seq = 0; + let mut ret = Vec::new(); + unsafe { + nftnl_sys::nftnl_nlmsg_build_hdr( + buffer.as_mut_ptr() as *mut c_char, + libc::NFT_MSG_GETSET as u16, + nftnl::ProtoFamily::Inet as u16, + (libc::NLM_F_DUMP | libc::NLM_F_ACK) as u16, + seq, + ); + } + let cb = |header: &libc::nlmsghdr, ret: &mut Vec| -> libc::c_int { + unsafe { + let set = Set1::new(); + let err = nftnl_sys::nftnl_set_nlmsg_parse(header, set.0); + if err < 0 { + return err; + } + ret.push(set); + }; + 1 + }; + socket.send(&buffer[..])?; + + // Try to parse the messages coming back from netfilter. This part is still very unclear. + let portid = socket.portid(); + let mut buf = vec![0; nftnl::nft_nlmsg_maxsize() as usize]; + let fd = unsafe { mnl_sys::mnl_socket_get_fd(socket.as_raw_socket()) }; + let mut readfds = nix::sys::select::FdSet::new(); + let fd1 = unsafe { BorrowedFd::borrow_raw(fd) }; + let mut tv = nix::sys::time::TimeVal::new(0, 0); + loop { + readfds.clear(); + readfds.insert(fd1); + if nix::sys::select::select(fd + 1, &mut readfds, None, None, &mut tv)? <= 0 { + break; + } + if !readfds.contains(fd1) { + break; + } + let msglen = socket.recv(&mut buf)?; + match mnl::cb_run2(&buf[..msglen], 0, portid, cb, &mut ret)? { + mnl::CbResult::Stop => { + break; + } + mnl::CbResult::Ok => (), + } + } + Ok(ret) } #[cfg(test)] mod test { - use ipnet::Ipv4Net; - use std::{ - ffi::CString, - net::{IpAddr, Ipv4Addr}, - }; + use std::{ffi::CString, net::Ipv6Addr}; - use super::{add, send_and_process, Cidr, FlushSetMsg, SetElemsIter}; + use ipnet::Ipv6Net; + + use super::get_sets; #[test] fn test_nftables() { @@ -219,21 +333,22 @@ mod test { &CString::from_vec_with_nul(b"test\0".to_vec()).unwrap(), nftnl::ProtoFamily::Inet, ); - let mut batch = nftnl::Batch::new(); - let mut set4 = nftnl::set::Set::<_>::new( - &CString::from_vec_with_nul(b"test4\0".to_vec()).unwrap(), - 0, - &table, - nftnl::ProtoFamily::Inet, - ); - batch.add(&FlushSetMsg { set: &set4 }, nftnl::MsgType::Del); - add( - &set4, - &Cidr(Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap()), - ); - // set4.add(&Ipv4Addr::new(127, 0, 0, 1)); - let mut iter = SetElemsIter::new(&set4); - batch.add_iter(iter, nftnl::MsgType::Add); - send_and_process(&batch.finalize()).unwrap(); + let socket = mnl::Socket::new(mnl::Bus::Netfilter).unwrap(); + let sets = get_sets(&socket).unwrap(); + assert!(!sets.is_empty()); + for set in sets { + if set.table_name() != Some("test") || set.name() != Some("test7") { + continue; + } + set.add_cidrs( + &socket, + true, + (0u128..8192u128) + .map(|x| ipnet::IpNet::V6(Ipv6Net::new(Ipv6Addr::from(x << 1), 127).unwrap())), + ) + .unwrap(); + return; + } + panic!(); } } diff --git a/src/nftables_lib.rs b/src/nftables_lib.rs new file mode 100644 index 0000000..0a08bda --- /dev/null +++ b/src/nftables_lib.rs @@ -0,0 +1,15 @@ +fn run( + family: &str, + table: &str, + set: &str, + flush: bool, + items: impl IntoIterator, +) { + let nft = libnftables1_sys::Nftables::new(); + let mut cmd = String::new(); + if flush { + cmd.push_str(&format!("flush set {family} {table} {set}")); + nft.run_cmd(c) + } + nft.set_numeric_time +}