more fixes

This commit is contained in:
chayleaf 2024-08-11 00:54:48 +07:00
parent 573ed066b7
commit 3df012a6df
Signed by: chayleaf
GPG key ID: 78171AD46227E68E
8 changed files with 494 additions and 299 deletions

29
Cargo.lock generated
View file

@ -8,6 +8,12 @@ version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bitflags"
version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
[[package]] [[package]]
name = "boxcar" name = "boxcar"
version = "0.2.5" version = "0.2.5"
@ -20,6 +26,12 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cfg_aliases"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]] [[package]]
name = "ctor" name = "ctor"
version = "0.2.8" version = "0.2.8"
@ -119,7 +131,7 @@ version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9201688bd0bc571dfa4c21ce0a525480c8b782776cf88e12571fa89108dd920" checksum = "e9201688bd0bc571dfa4c21ce0a525480c8b782776cf88e12571fa89108dd920"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"err-derive", "err-derive",
"log", "log",
"nftnl-sys", "nftnl-sys",
@ -145,6 +157,18 @@ dependencies = [
"smallvec", "smallvec",
] ]
[[package]]
name = "nix"
version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46"
dependencies = [
"bitflags 2.6.0",
"cfg-if",
"cfg_aliases",
"libc",
]
[[package]] [[package]]
name = "pkg-config" name = "pkg-config"
version = "0.3.30" version = "0.3.30"
@ -303,9 +327,8 @@ dependencies = [
"iptrie", "iptrie",
"libc", "libc",
"mnl", "mnl",
"mnl-sys",
"nftnl", "nftnl",
"nftnl-sys", "nix",
"prefix-tree", "prefix-tree",
"radix_trie", "radix_trie",
"serde", "serde",

View file

@ -15,9 +15,8 @@ ipnet = { version = "2.9.0", features = ["serde"] }
iptrie = "0.8.5" iptrie = "0.8.5"
libc = "0.2.155" libc = "0.2.155"
mnl = { version = "0.2.2", features = ["mnl-1-0-4"] } mnl = { version = "0.2.2", features = ["mnl-1-0-4"] }
mnl-sys = { version = "0.2.1", features = ["mnl-1-0-4"] }
nftnl = { version = "0.6.2", features = ["nftnl-1-1-2"] } nftnl = { version = "0.6.2", features = ["nftnl-1-1-2"] }
nftnl-sys = { version = "0.6.1", features = ["nftnl-1-1-2"] } nix = { version = "0.29.0", features = ["poll"] }
prefix-tree = "0.5.0" prefix-tree = "0.5.0"
radix_trie = "0.2.1" radix_trie = "0.2.1"
serde = { version = "1.0.205", features = ["derive"] } serde = { version = "1.0.205", features = ["derive"] }

2
FIXME
View file

@ -1,2 +0,0 @@
nftables
token is after, not before

View file

@ -25,7 +25,9 @@
in pkgs.mkShell rec { in pkgs.mkShell rec {
name = "unbound-rust-mod-shell"; name = "unbound-rust-mod-shell";
LIBMNL_LIB_DIR = "${nixpkgs.lib.getLib pkgs.libmnl}/lib"; LIBMNL_LIB_DIR = "${nixpkgs.lib.getLib pkgs.libmnl}/lib";
LIBNFTNL_LIB_DIR = "${nixpkgs.lib.getLib pkgs.libnftnl}/lib"; LIBNFTNL_LIB_DIR = "${nixpkgs.lib.getLib (pkgs.libnftnl.overrideAttrs (old: {
patches = (old.patches or []) ++ [ ./libnftnl-fix.patch ];
}))}/lib";
LD_LIBRARY_PATH = "${LIBMNL_LIB_DIR}:${LIBNFTNL_LIB_DIR}"; LD_LIBRARY_PATH = "${LIBMNL_LIB_DIR}:${LIBNFTNL_LIB_DIR}";
}; };
}; };

24
libnftnl-fix.patch Normal file
View file

@ -0,0 +1,24 @@
diff --git a/src/libnftnl.map b/src/libnftnl.map
index 8fffff1..3f660de 100644
--- a/src/libnftnl.map
+++ b/src/libnftnl.map
@@ -129,6 +129,7 @@ global:
nftnl_set_get_str;
nftnl_set_get_u32;
nftnl_set_get_u64;
+ nftnl_set_clone;
nftnl_set_nlmsg_build_payload;
nftnl_set_nlmsg_parse;
nftnl_set_parse;
diff --git a/src/set.c b/src/set.c
index 07e332d..c5f9518 100644
--- a/src/set.c
+++ b/src/set.c
@@ -352,6 +352,7 @@ uint64_t nftnl_set_get_u64(const struct nftnl_set *s, uint16_t attr)
return val ? *val : 0;
}
+EXPORT_SYMBOL(nftnl_set_clone);
struct nftnl_set *nftnl_set_clone(const struct nftnl_set *set)
{
struct nftnl_set *newset;

View file

@ -1,29 +1,27 @@
use std::{ use std::{
collections::HashMap, collections::HashMap,
ffi::CString,
fmt::Display, fmt::Display,
fs::File, fs::File,
io::{self, BufRead, BufReader, Write}, io::{self, BufRead, BufReader, Write},
net::{Ipv4Addr, Ipv6Addr}, net::{Ipv4Addr, Ipv6Addr},
os::raw::c_char,
path::{Path, PathBuf}, path::{Path, PathBuf},
str::FromStr, str::FromStr,
sync::{ sync::{
mpsc::{self, RecvError}, mpsc::{self, RecvError},
Mutex, RwLock, Mutex, RwLock,
}, },
time::{Duration, Instant, SystemTime}, time::{Duration, SystemTime},
}; };
use ctor::ctor; use ctor::ctor;
use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use iptrie::{Ipv4Prefix, Ipv6Prefix, RTrieSet}; use iptrie::{IpPrefix, RTrieSet};
use nftnl::set::SetKey;
use prefix_tree::PrefixSet; use prefix_tree::PrefixSet;
use serde::Deserialize; use serde::Deserialize;
use smallvec::SmallVec; use smallvec::SmallVec;
use crate::{ use crate::{
nftables::Set1,
unbound::{rr_class, rr_type}, unbound::{rr_class, rr_type},
UnboundMod, UnboundMod,
}; };
@ -199,12 +197,19 @@ impl<T: FromStr> IpCache<T> {
} }
struct NftData { struct NftData {
ips4: iptrie::Ipv4RTrieSet, ips4: RTrieSet<Ipv4Net>,
ips6: iptrie::Ipv6RTrieSet, ips6: RTrieSet<Ipv6Net>,
name4: CString, dirty4: bool,
name6: CString, dirty6: bool,
set4: Option<Set1>,
set6: Option<Set1>,
name4: String,
name6: String,
} }
// SAFETY: set4/set6 are None initially and are never actually sent
unsafe impl Send for NftData {}
struct NftQuery { struct NftQuery {
domains: RwLock<prefix_tree::PrefixSet<DomainSeg>>, domains: RwLock<prefix_tree::PrefixSet<DomainSeg>>,
dynamic: bool, dynamic: bool,
@ -242,16 +247,61 @@ struct DpiInfo {
// restriction: {"code": "ban"} // restriction: {"code": "ban"}
} }
trait Helper: iptrie::IpPrefix + PartialEq {
const ZERO: Self;
fn direct_parent(&self) -> Option<Self>;
}
impl Helper for Ipv4Net {
const ZERO: Self = match Self::new(Ipv4Addr::UNSPECIFIED, 0) {
Ok(x) => x,
#[allow(clippy::empty_loop)]
Err(_) => loop {},
};
fn direct_parent(&self) -> Option<Self> {
self.len()
.checked_sub(1)
.and_then(|x| Self::new(self.bitslot().into(), x).ok())
}
}
impl Helper for Ipv6Net {
const ZERO: Self = match Self::new(Ipv6Addr::UNSPECIFIED, 0) {
Ok(x) => x,
#[allow(clippy::empty_loop)]
Err(_) => loop {},
};
fn direct_parent(&self) -> Option<Self> {
self.len()
.checked_sub(1)
.and_then(|x| Self::new(self.bitslot().into(), x).ok())
}
}
fn should_add<T: Helper>(trie: &RTrieSet<T>, elem: &T) -> bool {
*trie.lookup(elem) == T::ZERO
}
fn iter_ip_trie<T: Helper>(trie: &RTrieSet<T>) -> impl '_ + Iterator<Item = T> {
trie.iter().copied().filter(|x| {
if let Some(par) = x.direct_parent() {
should_add(trie, &par)
} else {
*x != T::ZERO
}
})
}
impl UnboundMod for ExampleMod { impl UnboundMod for ExampleMod {
type EnvData = (); type EnvData = ();
type QstateData = (); type QstateData = ();
fn init(_env: &mut crate::unbound::ModuleEnv<Self::EnvData>) -> Result<Self, ()> { fn init(_env: &mut crate::unbound::ModuleEnv<Self::EnvData>) -> Result<Self, ()> {
let mut ret = Self { let mut ret = Self {
nft_token: std::env::var_os("NFT_TOKEN") nft_token: std::env::var_os("NFT_TOKEN")
.map(|x| x.to_str().ok_or(()).map(|s| ".".to_owned() + s)) .map(|x| x.to_str().ok_or(()).map(|s| s.to_owned() + "."))
.transpose()?, .transpose()?,
tmp_nft_token: std::env::var_os("NFT_TOKEN") tmp_nft_token: std::env::var_os("NFT_TOKEN")
.map(|x| x.to_str().ok_or(()).map(|s| ".tmp".to_owned() + s)) .map(|x| x.to_str().ok_or(()).map(|s| s.to_owned() + ".tmp."))
.transpose()?, .transpose()?,
..Self::default() ..Self::default()
}; };
@ -295,10 +345,14 @@ impl UnboundMod for ExampleMod {
}, },
); );
rulesets.push(NftData { rulesets.push(NftData {
set4: None,
set6: None,
ips4: RTrieSet::new(), ips4: RTrieSet::new(),
ips6: RTrieSet::new(), ips6: RTrieSet::new(),
name4: CString::from_vec_with_nul((set4.to_owned() + "\0").into()).unwrap(), dirty4: true,
name6: CString::from_vec_with_nul((set6.to_owned() + "\0").into()).unwrap(), dirty6: true,
name4: set4.to_owned(),
name6: set6.to_owned(),
}); });
} }
} }
@ -352,14 +406,14 @@ impl UnboundMod for ExampleMod {
Ok(ips) => { Ok(ips) => {
r.ips4.extend(ips.iter().filter_map(|x| { r.ips4.extend(ips.iter().filter_map(|x| {
if let IpNet::V4(x) = x { if let IpNet::V4(x) = x {
Ipv4Prefix::new(x.addr(), x.prefix_len()).ok() Some(*x)
} else { } else {
None None
} }
})); }));
r.ips6.extend(ips.iter().filter_map(|x| { r.ips6.extend(ips.iter().filter_map(|x| {
if let IpNet::V6(x) = x { if let IpNet::V6(x) = x {
Ipv6Prefix::new(x.addr(), x.prefix_len()).ok() Some(*x)
} else { } else {
None None
} }
@ -378,7 +432,7 @@ impl UnboundMod for ExampleMod {
.into(), .into(),
|val| { |val| {
if let Some(val) = val { if let Some(val) = val {
r.ips4.extend(val.iter().map(|x| Ipv4Prefix::from(*x))); r.ips4.extend(val.iter().map(|x| Ipv4Net::from(*x)));
} }
None None
}, },
@ -392,7 +446,7 @@ impl UnboundMod for ExampleMod {
.into(), .into(),
|val| { |val| {
if let Some(val) = val { if let Some(val) = val {
r.ips6.extend(val.iter().map(|x| Ipv6Prefix::from(*x))); r.ips6.extend(val.iter().map(|x| Ipv6Net::from(*x)));
} }
None None
}, },
@ -403,162 +457,97 @@ impl UnboundMod for ExampleMod {
// add stuff to nftables // add stuff to nftables
let (tx, rx) = mpsc::channel(); let (tx, rx) = mpsc::channel();
ret.ruleset_queue = Some(tx); ret.ruleset_queue = Some(tx);
std::thread::spawn(move || { std::thread::spawn(move || {
let table = nftnl::Table::new( fn report(err: impl Display) {
&CString::from_vec_with_nul(b"global\0".to_vec()).unwrap(), if let Ok(mut file) = std::fs::OpenOptions::new()
nftnl::ProtoFamily::Inet, .append(true)
); .open("/var/lib/unbound/nftables.log")
let mut first = true;
let mut bufs = vec![Vec::<IpNet>::new(); rulesets.len()];
let mut len = 0;
let mut queue_start = Instant::now();
loop {
let res = if len == 0 {
match rx.recv() {
Ok(val) => {
queue_start = Instant::now();
Some(val)
}
Err(RecvError) => break,
}
} else {
match rx.recv_timeout((queue_start + Duration::from_secs(30)) - Instant::now())
{ {
Ok(val) => Some(val), if file.write_all(err.to_string().as_bytes()).is_err() {
Err(mpsc::RecvTimeoutError::Timeout) => None, return;
Err(mpsc::RecvTimeoutError::Disconnected) => break,
} }
if file.write_all(b"\n").is_err() {
return;
}
file.flush().unwrap_or(());
}
}
let socket = mnl::Socket::new(mnl::Bus::Netfilter).unwrap();
let all_sets = crate::nftables::get_sets(&socket).unwrap();
for set in all_sets {
for ruleset in &mut rulesets {
if set.table_name() == Some("global")
&& set.family() == libc::NFPROTO_INET as u32
{
if set.name() == Some(&ruleset.name4) {
ruleset.set4 = Some(set.clone());
} else if set.name() == Some(&ruleset.name6) {
ruleset.set6 = Some(set.clone());
}
}
}
}
for ruleset in &mut rulesets {
if !ruleset.name4.is_empty() && ruleset.set4.is_none() {
report(format!("set {} not found", ruleset.name4));
ruleset.ips4 = RTrieSet::new();
}
if !ruleset.name6.is_empty() && ruleset.set6.is_none() {
report(format!("set {} not found", ruleset.name6));
ruleset.ips6 = RTrieSet::new();
}
}
let mut first = true;
loop {
for ruleset in &mut rulesets {
if let Some(set) = ruleset.set4.as_mut().filter(|_| ruleset.dirty4) {
if let Err(err) = set.add_cidrs(
&socket,
first,
iter_ip_trie(&ruleset.ips4).map(IpNet::V4),
) {
report(err);
}
}
if let Some(set) = ruleset.set6.as_mut().filter(|_| ruleset.dirty6) {
if let Err(err) = set.add_cidrs(
&socket,
first,
iter_ip_trie(&ruleset.ips6).map(IpNet::V6),
) {
report(err);
}
}
}
first = false;
let res = match rx.recv() {
Ok(val) => Some(val),
Err(RecvError) => break,
}; };
let do_it =
res.is_none() || (Instant::now() - queue_start) > Duration::from_secs(25);
if let Some((rulesets1, ips)) = res { if let Some((rulesets1, ips)) = res {
for ruleset in rulesets1 { for i in rulesets1.into_iter() {
let ruleset = &mut rulesets[i];
for ip1 in ips.iter().copied() { for ip1 in ips.iter().copied() {
match ip1 { match ip1 {
IpNet::V4(ip) => { IpNet::V4(ip) => {
if !rulesets[ruleset].ips4.contains(&ip) { if ruleset.set4.is_some() && !should_add(&ruleset.ips4, &ip) {
rulesets[ruleset].ips4.insert(ip.into()); ruleset.ips4.insert(ip);
bufs[ruleset].push(ip1); ruleset.dirty4 = true;
} }
} }
IpNet::V6(ip) => { IpNet::V6(ip) => {
if !rulesets[ruleset].ips6.contains(&ip) { if ruleset.set6.is_some() && !should_add(&ruleset.ips6, &ip) {
rulesets[ruleset].ips6.insert(ip.into()); ruleset.ips6.insert(ip);
bufs[ruleset].push(ip1); ruleset.dirty6 = true;
len += 1;
} }
} }
} }
} }
} }
} }
struct FlushSetMsg<'a, T> {
set: &'a nftnl::set::Set<'a, T>,
}
unsafe impl<'a, T> nftnl::NlMsg for FlushSetMsg<'a, T> {
unsafe fn write(
&self,
buf: *mut std::ffi::c_void,
seq: u32,
_msg_type: nftnl::MsgType,
) {
let header = nftnl_sys::nftnl_nlmsg_build_hdr(
buf as *mut c_char,
libc::NFT_MSG_DELSETELEM as u16,
self.set.get_family() as u16,
0,
seq,
);
nftnl_sys::nftnl_set_elems_nlmsg_build_payload(header, self.set.as_ptr());
}
}
if do_it || len >= 128 {
let mut batch = nftnl::Batch::new();
for (ruleset, buf) in rulesets.iter().zip(bufs.iter_mut()) {
// internally represented as a range
struct Cidr<T>(T);
impl SetKey for Cidr<Ipv4Net> {
const TYPE: u32 = Ipv4Addr::TYPE;
const LEN: u32 = Ipv4Addr::LEN * 2;
fn data(&self) -> Box<[u8]> {
let data = u32::from_be_bytes(self.0.network().octets());
let mask = u32::from_be_bytes(self.0.netmask().octets());
let mut ret = [0u8; (Self::LEN) as usize];
ret[..(Self::LEN as usize)]
.copy_from_slice(&self.0.network().octets());
ret[(Self::LEN as usize)..]
.copy_from_slice(&u32::to_be_bytes(!mask | data));
Box::new(ret)
}
}
impl SetKey for Cidr<Ipv6Net> {
const TYPE: u32 = Ipv6Addr::TYPE;
const LEN: u32 = Ipv6Addr::LEN * 2;
fn data(&self) -> Box<[u8]> {
let data = u128::from_be_bytes(self.0.network().octets());
let mask = u128::from_be_bytes(self.0.netmask().octets());
let mut ret = [0u8; (Self::LEN) as usize];
ret[..(Self::LEN as usize)]
.copy_from_slice(&self.0.network().octets());
ret[(Self::LEN as usize)..]
.copy_from_slice(&u128::to_be_bytes(!mask | data));
Box::new(ret)
}
}
let set4 = nftnl::set::Set::<Cidr<Ipv4Net>>::new(
&ruleset.name4,
0,
&table,
nftnl::ProtoFamily::Ipv4,
);
let set6 = nftnl::set::Set::<Cidr<Ipv6Net>>::new(
&ruleset.name6,
0,
&table,
nftnl::ProtoFamily::Ipv6,
);
if first {
batch.add(&FlushSetMsg { set: &set4 }, nftnl::MsgType::Del);
batch.add(&FlushSetMsg { set: &set6 }, nftnl::MsgType::Del);
}
let mut set4 = nftnl::set::Set::new(
&ruleset.name4,
0,
&table,
nftnl::ProtoFamily::Ipv4,
);
let mut set6 = nftnl::set::Set::new(
&ruleset.name6,
0,
&table,
nftnl::ProtoFamily::Ipv6,
);
let mut added4 = false;
let mut added6 = false;
for ip in buf.drain(..) {
match ip {
IpNet::V4(ip) => {
set4.add(&Cidr(ip));
added4 = true;
}
IpNet::V6(ip) => {
set6.add(&Cidr(ip));
added6 = true;
}
}
}
if added4 {
batch.add_iter(set4.elems_iter(), nftnl::MsgType::Add);
}
if added6 {
batch.add_iter(set6.elems_iter(), nftnl::MsgType::Add);
}
}
len = 0;
first = false;
}
} }
}); });
@ -577,12 +566,12 @@ impl UnboundMod for ExampleMod {
if let Some(rev_domain) = self if let Some(rev_domain) = self
.nft_token .nft_token
.as_ref() .as_ref()
.and_then(|token| rev_domain.strip_suffix(token.as_bytes())) .and_then(|token| rev_domain.strip_prefix(token.as_bytes()))
{ {
for (qname, query) in self.nft_queries.iter() { for (qname, query) in self.nft_queries.iter() {
if query.dynamic && rev_domain.ends_with(qname.as_bytes()) { if query.dynamic && rev_domain.ends_with(qname.as_bytes()) {
if let Some(rev_domain) = if let Some(rev_domain) =
rev_domain.strip_suffix((".".to_owned() + qname).as_bytes()) rev_domain.strip_prefix((qname.to_owned() + ".").as_bytes())
{ {
let rev_domain = rev_domain let rev_domain = rev_domain
.split(|x| *x == b'.') .split(|x| *x == b'.')
@ -635,12 +624,12 @@ impl UnboundMod for ExampleMod {
} else if let Some(rev_domain) = self } else if let Some(rev_domain) = self
.tmp_nft_token .tmp_nft_token
.as_ref() .as_ref()
.and_then(|token| rev_domain.strip_suffix(token.as_bytes())) .and_then(|token| rev_domain.strip_prefix(token.as_bytes()))
{ {
for (qname, query) in self.nft_queries.iter() { for (qname, query) in self.nft_queries.iter() {
if query.dynamic && rev_domain.ends_with(qname.as_bytes()) { if query.dynamic && rev_domain.ends_with(qname.as_bytes()) {
if let Some(rev_domain) = if let Some(rev_domain) =
rev_domain.strip_suffix((".".to_owned() + qname).as_bytes()) rev_domain.strip_prefix((qname.to_owned() + ".").as_bytes())
{ {
let rev_domain = rev_domain let rev_domain = rev_domain
.split(|x| *x == b'.') .split(|x| *x == b'.')
@ -749,3 +738,33 @@ impl UnboundMod for ExampleMod {
fn setup() { fn setup() {
crate::set_unbound_mod::<ExampleMod>(); crate::set_unbound_mod::<ExampleMod>();
} }
#[cfg(test)]
mod test {
use std::net::Ipv4Addr;
use ipnet::Ipv4Net;
use iptrie::RTrieSet;
use crate::example::{iter_ip_trie, should_add};
#[test]
fn test() {
let mut trie = RTrieSet::new();
assert!(should_add(
&trie,
&Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap()
));
trie.insert(Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap());
assert!(!should_add(
&trie,
&Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap()
));
trie.insert(Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 31).unwrap());
assert!(dbg!(iter_ip_trie(&trie).collect::<Vec<_>>()).len() == 1);
// contains 0.0.0.0, etc
assert!(dbg!(trie.iter().collect::<Vec<_>>()).len() == 3);
trie.insert(Ipv4Net::new(Ipv4Addr::new(127, 0, 1, 1), 32).unwrap());
assert!(dbg!(iter_ip_trie(&trie).collect::<Vec<_>>()).len() == 2);
}
}

View file

@ -1,98 +1,68 @@
use std::{ use std::{
cell::Cell, cell::Cell,
ffi::CStr,
io, io,
net::{Ipv4Addr, Ipv6Addr}, net::{Ipv4Addr, Ipv6Addr},
os::raw::{c_char, c_void}, os::{
fd::BorrowedFd,
raw::{c_char, c_void},
},
rc::Rc, rc::Rc,
}; };
use ipnet::{Ipv4Net, Ipv6Net}; use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use nftnl::{ use nftnl::{nftnl_sys, set::SetKey, Batch, FinalizedBatch, MsgType, NlMsg};
set::{Set, SetKey}, use mnl::mnl_sys;
FinalizedBatch, MsgType, NlMsg,
};
// internally represented as a range fn cidr_bound_ipv4(net: Ipv4Net) -> Option<Ipv4Addr> {
struct Cidr<T>(T); let data = u32::from(net.network());
impl SetKey for Cidr<Ipv4Net> { let mask = u32::from(net.netmask());
const TYPE: u32 = Ipv4Addr::TYPE; let ip = (!mask | data).wrapping_add(1);
const LEN: u32 = Ipv4Addr::LEN * 2; if ip == 0 {
fn data(&self) -> Box<[u8]> { None
let data = u32::from_be_bytes(self.0.network().octets()); } else {
let mask = u32::from_be_bytes(self.0.netmask().octets()); Some(ip.into())
let mut ret = [0u8; (Self::LEN) as usize];
ret[..(Ipv4Addr::LEN as usize)].copy_from_slice(&self.0.network().octets());
ret[(Ipv4Addr::LEN as usize)..].copy_from_slice(&u32::to_be_bytes(!mask | data));
println!("{ret:?} {:?}", self.0.addr().data());
Box::new(ret)
}
}
impl SetKey for Cidr<Ipv6Net> {
const TYPE: u32 = Ipv6Addr::TYPE;
const LEN: u32 = Ipv6Addr::LEN * 2;
fn data(&self) -> Box<[u8]> {
let data = u128::from_be_bytes(self.0.network().octets());
let mask = u128::from_be_bytes(self.0.netmask().octets());
let mut ret = [0u8; (Self::LEN) as usize];
ret[..(Ipv6Addr::LEN as usize)].copy_from_slice(&self.0.network().octets());
ret[(Ipv6Addr::LEN as usize)..].copy_from_slice(&u128::to_be_bytes(!mask | data));
Box::new(ret)
} }
} }
struct FlushSetMsg<'a, T> { fn cidr_bound_ipv6(net: Ipv6Net) -> Option<Ipv6Addr> {
set: &'a Set<'a, T>, let data = u128::from_be_bytes(net.network().octets());
let mask = u128::from_be_bytes(net.netmask().octets());
let ip = (!mask | data).wrapping_add(1);
if ip == 0 {
None
} else {
Some(ip.into())
} }
unsafe impl<'a, T> NlMsg for FlushSetMsg<'a, T> { }
#[must_use]
struct FlushSetMsg<'a> {
set: &'a Set1,
}
unsafe impl<'a> NlMsg for FlushSetMsg<'a> {
unsafe fn write(&self, buf: *mut std::ffi::c_void, seq: u32, _msg_type: MsgType) { unsafe fn write(&self, buf: *mut std::ffi::c_void, seq: u32, _msg_type: MsgType) {
let header = nftnl_sys::nftnl_nlmsg_build_hdr( let header = nftnl_sys::nftnl_nlmsg_build_hdr(
buf as *mut c_char, buf as *mut c_char,
libc::NFT_MSG_DELSETELEM as u16, libc::NFT_MSG_DELSETELEM as u16,
self.set.get_family() as u16, self.set.family() as u16,
0, 0,
seq, seq,
); );
nftnl_sys::nftnl_set_elems_nlmsg_build_payload(header, self.set.as_ptr()); nftnl_sys::nftnl_set_elems_nlmsg_build_payload(header, self.set.as_mut_ptr());
} }
} }
pub fn send_and_process(batch: &FinalizedBatch) -> io::Result<()> { pub struct SetElemsIter<'a> {
let socket = mnl::Socket::new(mnl::Bus::Netfilter)?; set: &'a Set1,
eprintln!("a");
socket.send_all(batch)?;
eprintln!("b");
let portid = socket.portid();
let mut buf = vec![0; nftnl::nft_nlmsg_maxsize() as usize];
loop {
eprintln!("c");
let n = socket.recv(&mut buf[..])?;
eprintln!("d {n}");
if n == 0 {
break;
}
match mnl::cb_run(&buf[..n], 2, portid)? {
mnl::CbResult::Stop => {
println!("stop");
break;
}
mnl::CbResult::Ok => {
println!("ok");
}
}
}
Ok(())
}
pub struct SetElemsIter<'a, K> {
set: &'a Set<'a, K>,
iter: *mut nftnl_sys::nftnl_set_elems_iter, iter: *mut nftnl_sys::nftnl_set_elems_iter,
ret: Rc<Cell<i32>>, ret: Rc<Cell<i32>>,
is_first: bool, is_first: bool,
} }
impl<'a, K> SetElemsIter<'a, K> { impl<'a> SetElemsIter<'a> {
fn new(set: &'a Set<'a, K>) -> Self { fn new(set: &'a Set1) -> Self {
let iter = unsafe { nftnl_sys::nftnl_set_elems_iter_create(set.as_ptr()) }; let iter = unsafe { nftnl_sys::nftnl_set_elems_iter_create(set.as_mut_ptr()) };
if iter.is_null() { if iter.is_null() {
panic!("oom"); panic!("oom");
} }
@ -105,8 +75,8 @@ impl<'a, K> SetElemsIter<'a, K> {
} }
} }
impl<'a, K: 'a> Iterator for SetElemsIter<'a, K> { impl<'a> Iterator for SetElemsIter<'a> {
type Item = SetElemsMsg<'a, K>; type Item = SetElemsMsg<'a>;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
if self.is_first { if self.is_first {
@ -128,31 +98,31 @@ impl<'a, K: 'a> Iterator for SetElemsIter<'a, K> {
} }
} }
impl<'a, K> Drop for SetElemsIter<'a, K> { impl<'a> Drop for SetElemsIter<'a> {
fn drop(&mut self) { fn drop(&mut self) {
unsafe { nftnl_sys::nftnl_set_elems_iter_destroy(self.iter) }; unsafe { nftnl_sys::nftnl_set_elems_iter_destroy(self.iter) };
} }
} }
pub struct SetElemsMsg<'a, K> { pub struct SetElemsMsg<'a> {
set: &'a Set<'a, K>, set: &'a Set1,
iter: *mut nftnl_sys::nftnl_set_elems_iter, iter: *mut nftnl_sys::nftnl_set_elems_iter,
ret: Rc<Cell<i32>>, ret: Rc<Cell<i32>>,
} }
unsafe impl<'a, K> NlMsg for SetElemsMsg<'a, K> { unsafe impl<'a> NlMsg for SetElemsMsg<'a> {
unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) {
let (type_, flags) = match msg_type { let (type_, flags) = match msg_type {
MsgType::Add => ( MsgType::Add => (
libc::NFT_MSG_NEWSETELEM, libc::NFT_MSG_NEWSETELEM,
libc::NLM_F_CREATE | libc::NLM_F_EXCL | libc::NLM_F_ACK, libc::NLM_F_CREATE | libc::NLM_F_EXCL,
), ),
MsgType::Del => (libc::NFT_MSG_DELSETELEM, libc::NLM_F_ACK), MsgType::Del => (libc::NFT_MSG_DELSETELEM, 0),
}; };
let header = nftnl_sys::nftnl_nlmsg_build_hdr( let header = nftnl_sys::nftnl_nlmsg_build_hdr(
buf as *mut c_char, buf as *mut c_char,
type_ as u16, type_ as u16,
self.set.get_family() as u16, self.set.family() as u16,
flags as u16, flags as u16,
seq, seq,
); );
@ -163,9 +133,61 @@ unsafe impl<'a, K> NlMsg for SetElemsMsg<'a, K> {
} }
} }
fn add<K: SetKey>(set: &Set<K>, key: &K) { fn send_and_process(socket: &mnl::Socket, batch: &FinalizedBatch) -> io::Result<()> {
let data = key.data(); socket.send_all(batch)?;
let data_len = data.len() as u32; let portid = socket.portid();
let mut buf = vec![0; nftnl::nft_nlmsg_maxsize() as usize];
let fd = unsafe { mnl_sys::mnl_socket_get_fd(socket.as_raw_socket()) };
let mut readfds = nix::sys::select::FdSet::new();
let fd1 = unsafe { BorrowedFd::borrow_raw(fd) };
let mut tv = nix::sys::time::TimeVal::new(0, 0);
loop {
readfds.clear();
readfds.insert(fd1);
if nix::sys::select::select(fd + 1, &mut readfds, None, None, &mut tv)? <= 0 {
break;
}
if !readfds.contains(fd1) {
break;
}
let msglen = socket.recv(&mut buf)?;
match mnl::cb_run(&buf[..msglen], 0, portid)? {
mnl::CbResult::Stop => {
break;
}
mnl::CbResult::Ok => (),
}
}
Ok(())
}
pub struct Set1(*mut nftnl_sys::nftnl_set);
impl Set1 {
pub fn new() -> Self {
Self(unsafe { nftnl_sys::nftnl_set_alloc() })
}
pub fn as_mut_ptr(&self) -> *mut nftnl_sys::nftnl_set {
self.0
}
pub fn table_name(&self) -> Option<&str> {
let ret =
unsafe { nftnl_sys::nftnl_set_get_str(self.0, nftnl_sys::NFTNL_SET_TABLE as u16) };
(!ret.is_null())
.then(|| unsafe { CStr::from_ptr(ret) }.to_str().ok())
.flatten()
}
pub fn name(&self) -> Option<&str> {
let ret = unsafe { nftnl_sys::nftnl_set_get_str(self.0, nftnl_sys::NFTNL_SET_NAME as u16) };
(!ret.is_null())
.then(|| unsafe { CStr::from_ptr(ret) }.to_str().ok())
.flatten()
}
pub fn family(&self) -> u32 {
unsafe { nftnl_sys::nftnl_set_get_u32(self.0, nftnl_sys::NFTNL_SET_FAMILY as u16) }
}
pub fn add_range<K: SetKey>(&mut self, lower: &K, excl_upper: Option<&K>) {
let data1 = lower.data();
let data1_len = data1.len() as u32;
unsafe { unsafe {
let elem = nftnl_sys::nftnl_set_elem_alloc(); let elem = nftnl_sys::nftnl_set_elem_alloc();
if elem.is_null() { if elem.is_null() {
@ -174,15 +196,15 @@ fn add<K: SetKey>(set: &Set<K>, key: &K) {
nftnl_sys::nftnl_set_elem_set( nftnl_sys::nftnl_set_elem_set(
elem, elem,
nftnl_sys::NFTNL_SET_ELEM_KEY as u16, nftnl_sys::NFTNL_SET_ELEM_KEY as u16,
data.as_ptr() as *const c_void, data1.as_ptr() as *const c_void,
data_len / 2, data1_len,
); );
nftnl_sys::nftnl_set_elem_set_u32( nftnl_sys::nftnl_set_elem_add(self.as_mut_ptr(), elem);
elem,
nftnl_sys::NFTNL_SET_ELEM_FLAGS as u16, let Some(data2) = excl_upper.map(|key| key.data()) else {
1, return;
); };
nftnl_sys::nftnl_set_elem_add(set.as_ptr(), elem); let data2_len = data2.len() as u32;
let elem = nftnl_sys::nftnl_set_elem_alloc(); let elem = nftnl_sys::nftnl_set_elem_alloc();
if elem.is_null() { if elem.is_null() {
@ -191,27 +213,119 @@ fn add<K: SetKey>(set: &Set<K>, key: &K) {
nftnl_sys::nftnl_set_elem_set( nftnl_sys::nftnl_set_elem_set(
elem, elem,
nftnl_sys::NFTNL_SET_ELEM_KEY as u16, nftnl_sys::NFTNL_SET_ELEM_KEY as u16,
data.as_ptr().add((data_len / 2) as usize) as *const c_void, data2.as_ptr() as *const c_void,
data_len / 2, data2_len,
); );
// nftnl_sys::nftnl_set_elem_set_u32( nftnl_sys::nftnl_set_elem_set_u32(
// elem, elem,
// nftnl_sys::NFTNL_SET_ELEM_FLAGS as u16, nftnl_sys::NFTNL_SET_ELEM_FLAGS as u16,
// libc::NFT_SET_ELEM_INTERVAL_END as u32, libc::NFT_SET_ELEM_INTERVAL_END as u32,
// ); );
nftnl_sys::nftnl_set_elem_add(set.as_ptr(), elem); nftnl_sys::nftnl_set_elem_add(self.as_mut_ptr(), elem);
} }
} }
pub fn add_cidrs(&self, socket: &mnl::Socket, flush: bool, cidrs: impl IntoIterator<Item = IpNet>) -> io::Result<()> {
let mut batch = Batch::new();
// FIXME: why 2048?
let max_batch_size = 2048;
let mut count = 0;
let mut set = self.clone();
if flush {
count += 1;
batch.add(&set.flush_msg(), nftnl::MsgType::Del);
}
for net in cidrs.into_iter() {
if count + 2 > max_batch_size {
batch.add_iter(SetElemsIter::new(&set), MsgType::Add);
send_and_process(socket, &batch.finalize())?;
set = self.clone();
batch = Batch::new();
}
match net {
IpNet::V4(ip) => {
set.add_range(&ip.network(), cidr_bound_ipv4(ip).as_ref());
}
IpNet::V6(ip) => {
set.add_range(&ip.network(), cidr_bound_ipv6(ip).as_ref());
}
}
count += 2;
}
batch.add_iter(SetElemsIter::new(&set), MsgType::Add);
send_and_process(socket, &batch.finalize())
}
fn flush_msg(&self) -> FlushSetMsg<'_> {
FlushSetMsg { set: self }
}
}
impl Clone for Set1 {
fn clone(&self) -> Self {
Self(unsafe { nftnl_sys::nftnl_set_clone(self.0) })
}
}
pub fn get_sets(socket: &mnl::Socket) -> io::Result<Vec<Set1>> {
let mut buffer = vec![0; nftnl::nft_nlmsg_maxsize() as usize];
let seq = 0;
let mut ret = Vec::new();
unsafe {
nftnl_sys::nftnl_nlmsg_build_hdr(
buffer.as_mut_ptr() as *mut c_char,
libc::NFT_MSG_GETSET as u16,
nftnl::ProtoFamily::Inet as u16,
(libc::NLM_F_DUMP | libc::NLM_F_ACK) as u16,
seq,
);
}
let cb = |header: &libc::nlmsghdr, ret: &mut Vec<Set1>| -> libc::c_int {
unsafe {
let set = Set1::new();
let err = nftnl_sys::nftnl_set_nlmsg_parse(header, set.0);
if err < 0 {
return err;
}
ret.push(set);
};
1
};
socket.send(&buffer[..])?;
// Try to parse the messages coming back from netfilter. This part is still very unclear.
let portid = socket.portid();
let mut buf = vec![0; nftnl::nft_nlmsg_maxsize() as usize];
let fd = unsafe { mnl_sys::mnl_socket_get_fd(socket.as_raw_socket()) };
let mut readfds = nix::sys::select::FdSet::new();
let fd1 = unsafe { BorrowedFd::borrow_raw(fd) };
let mut tv = nix::sys::time::TimeVal::new(0, 0);
loop {
readfds.clear();
readfds.insert(fd1);
if nix::sys::select::select(fd + 1, &mut readfds, None, None, &mut tv)? <= 0 {
break;
}
if !readfds.contains(fd1) {
break;
}
let msglen = socket.recv(&mut buf)?;
match mnl::cb_run2(&buf[..msglen], 0, portid, cb, &mut ret)? {
mnl::CbResult::Stop => {
break;
}
mnl::CbResult::Ok => (),
}
}
Ok(ret)
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use ipnet::Ipv4Net; use std::{ffi::CString, net::Ipv6Addr};
use std::{
ffi::CString,
net::{IpAddr, Ipv4Addr},
};
use super::{add, send_and_process, Cidr, FlushSetMsg, SetElemsIter}; use ipnet::Ipv6Net;
use super::get_sets;
#[test] #[test]
fn test_nftables() { fn test_nftables() {
@ -219,21 +333,22 @@ mod test {
&CString::from_vec_with_nul(b"test\0".to_vec()).unwrap(), &CString::from_vec_with_nul(b"test\0".to_vec()).unwrap(),
nftnl::ProtoFamily::Inet, nftnl::ProtoFamily::Inet,
); );
let mut batch = nftnl::Batch::new(); let socket = mnl::Socket::new(mnl::Bus::Netfilter).unwrap();
let mut set4 = nftnl::set::Set::<_>::new( let sets = get_sets(&socket).unwrap();
&CString::from_vec_with_nul(b"test4\0".to_vec()).unwrap(), assert!(!sets.is_empty());
0, for set in sets {
&table, if set.table_name() != Some("test") || set.name() != Some("test7") {
nftnl::ProtoFamily::Inet, continue;
); }
batch.add(&FlushSetMsg { set: &set4 }, nftnl::MsgType::Del); set.add_cidrs(
add( &socket,
&set4, true,
&Cidr(Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap()), (0u128..8192u128)
); .map(|x| ipnet::IpNet::V6(Ipv6Net::new(Ipv6Addr::from(x << 1), 127).unwrap())),
// set4.add(&Ipv4Addr::new(127, 0, 0, 1)); )
let mut iter = SetElemsIter::new(&set4); .unwrap();
batch.add_iter(iter, nftnl::MsgType::Add); return;
send_and_process(&batch.finalize()).unwrap(); }
panic!();
} }
} }

15
src/nftables_lib.rs Normal file
View file

@ -0,0 +1,15 @@
fn run<T: ToString>(
family: &str,
table: &str,
set: &str,
flush: bool,
items: impl IntoIterator<T>,
) {
let nft = libnftables1_sys::Nftables::new();
let mut cmd = String::new();
if flush {
cmd.push_str(&format!("flush set {family} {table} {set}"));
nft.run_cmd(c)
}
nft.set_numeric_time
}