This commit is contained in:
chayleaf 2024-08-13 05:41:57 +07:00
parent 303b157557
commit e9a6f296df
Signed by: chayleaf
GPG key ID: 78171AD46227E68E
6 changed files with 668 additions and 439 deletions

2
.gitignore vendored
View file

@ -1,2 +1,4 @@
/target /target
/result /result
/unbound-mod-test-config
/unbound-mod-test-data

View file

@ -17,7 +17,7 @@ iptrie = "0.8.5"
libc = "0.2.155" libc = "0.2.155"
mnl = { version = "0.2.2", features = ["mnl-1-0-4"] } mnl = { version = "0.2.2", features = ["mnl-1-0-4"] }
nftnl = { version = "0.6.2", features = ["nftnl-1-1-2"] } nftnl = { version = "0.6.2", features = ["nftnl-1-1-2"] }
nix = { version = "0.29.0", features = ["poll"] } nix = { version = "0.29.0", features = ["poll", "user"] }
radix_trie = "0.2.1" radix_trie = "0.2.1"
serde = { version = "1.0.205", features = ["derive"] } serde = { version = "1.0.205", features = ["derive"] }
serde_json = "1.0.122" serde_json = "1.0.122"

View file

@ -2,6 +2,7 @@ use std::{collections::HashMap, hash::Hash};
use smallvec::{smallvec, SmallVec}; use smallvec::{smallvec, SmallVec};
#[derive(Debug)]
pub enum PrefixSet<T> { pub enum PrefixSet<T> {
Map(HashMap<T, PrefixSet<T>>), Map(HashMap<T, PrefixSet<T>>),
Leaf, Leaf,

File diff suppressed because it is too large Load diff

View file

@ -1,3 +1,4 @@
#![allow(clippy::type_complexity)]
use std::panic::{RefUnwindSafe, UnwindSafe}; use std::panic::{RefUnwindSafe, UnwindSafe};
use unbound::ModuleExtState; use unbound::ModuleExtState;

View file

@ -1,18 +1,23 @@
use std::{ use std::{
cell::Cell, cell::Cell,
ffi::CStr, ffi::CStr,
io, fmt::Display,
io::{self, Write},
net::{Ipv4Addr, Ipv6Addr}, net::{Ipv4Addr, Ipv6Addr},
os::{ os::{
fd::BorrowedFd, fd::BorrowedFd,
raw::{c_char, c_void}, raw::{c_char, c_void},
}, },
rc::Rc, rc::Rc,
sync::mpsc,
}; };
use crate::example::{Helper, DATA_PREFIX};
use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use iptrie::RTrieSet;
use mnl::mnl_sys; use mnl::mnl_sys;
use nftnl::{nftnl_sys, set::SetKey, Batch, FinalizedBatch, MsgType, NlMsg}; use nftnl::{nftnl_sys, set::SetKey, Batch, FinalizedBatch, MsgType, NlMsg};
use smallvec::SmallVec;
fn cidr_bound_ipv4(net: Ipv4Net) -> Option<Ipv4Addr> { fn cidr_bound_ipv4(net: Ipv4Net) -> Option<Ipv4Addr> {
let data = u32::from(net.network()); let data = u32::from(net.network());
@ -324,16 +329,191 @@ pub fn get_sets(socket: &mnl::Socket) -> io::Result<Vec<Set1>> {
Ok(ret) Ok(ret)
} }
fn should_add<T: Helper>(trie: &RTrieSet<T>, elem: &T) -> bool {
*trie.lookup(elem) == T::ZERO
}
fn iter_ip_trie<T: Helper>(trie: &RTrieSet<T>) -> impl '_ + Iterator<Item = T> {
trie.iter().copied().filter(|x| {
if let Some(par) = x.direct_parent() {
should_add(trie, &par)
} else {
*x != T::ZERO
}
})
}
pub(crate) struct NftData<T: Helper> {
ips: RTrieSet<T>,
dirty: bool,
set: Option<Set1>,
name: String,
}
impl<T: Helper> NftData<T> {
pub fn new(name: &str) -> Self {
Self {
set: None,
ips: RTrieSet::new(),
dirty: true,
name: name.to_owned(),
}
}
}
// SAFETY: set is None initially so Set1 is never actually sent
// (and it might be fine to send anyway actually)
unsafe impl<T: Helper + Send> Send for NftData<T> {}
impl<T: Helper> NftData<T>
where
IpNet: From<T>,
{
#[must_use]
pub fn verify(&mut self) -> bool {
if !self.name.is_empty() && self.set.is_none() {
self.ips = RTrieSet::new();
false
} else {
true
}
}
pub fn flush_changes(
&mut self,
socket: &mnl::Socket,
flush_set: bool,
) -> Result<(), io::Error> {
if let Some(set) = self.set.as_mut().filter(|_| self.dirty) {
if flush_set {
println!(
"initializing set {} with ~{} ips (e.g. {:?})",
self.name,
self.ips.len(),
iter_ip_trie(&self.ips).next(),
);
}
set.add_cidrs(socket, flush_set, iter_ip_trie(&self.ips).map(IpNet::from))
} else {
Ok(())
}
}
pub fn extend(&mut self, ips: impl Iterator<Item = T>) {
for ip in ips {
self.insert(ip, true);
}
}
pub fn insert(&mut self, ip: T, allow_empty_set: bool) {
if (if allow_empty_set {
!self.name.is_empty()
} else {
self.set.is_some()
}) && should_add(&self.ips, &ip)
{
self.ips.insert(ip);
self.dirty = true;
}
}
pub fn ips_mut(&mut self) -> &mut RTrieSet<T> {
&mut self.ips
}
#[cfg(test)]
pub fn ip_count(&self) -> usize {
iter_ip_trie(&self.ips).count()
}
pub fn name(&self) -> &str {
&self.name
}
pub fn set_set(&mut self, set: Set1) {
self.set = Some(set);
}
}
pub(crate) fn nftables_thread(
mut rulesets: Vec<(NftData<Ipv4Net>, NftData<Ipv6Net>)>,
rx: mpsc::Receiver<(SmallVec<[usize; 5]>, smallvec::SmallVec<[IpNet; 8]>)>,
) {
fn report(err: impl Display) {
println!("nftables: {err}");
if let Ok(mut file) = std::fs::OpenOptions::new()
.append(true)
.create(true)
.open(format!("{DATA_PREFIX}/nftables.log"))
{
file.write_all((err.to_string() + "\n").as_bytes())
.unwrap_or(());
}
}
let socket = mnl::Socket::new(mnl::Bus::Netfilter).unwrap();
let all_sets = crate::nftables::get_sets(&socket).unwrap();
for set in all_sets {
for ruleset in &mut rulesets {
if set.table_name() == Some("global") && set.family() == libc::NFPROTO_INET as u32 {
if set.name() == Some(ruleset.0.name()) {
println!("found set {}", ruleset.0.name());
ruleset.0.set_set(set);
break;
} else if set.name() == Some(ruleset.1.name()) {
println!("found set {}", ruleset.1.name());
ruleset.1.set_set(set);
break;
}
}
}
}
for ruleset in &mut rulesets {
if !ruleset.0.verify() {
report(format!("set {} not found", ruleset.0.name()));
}
if !ruleset.1.verify() {
report(format!("set {} not found", ruleset.1.name()));
}
}
let mut first = true;
loop {
for ruleset in &mut rulesets {
if let Err(err) = ruleset.0.flush_changes(&socket, first) {
report(err);
}
if let Err(err) = ruleset.1.flush_changes(&socket, first) {
report(err);
}
}
if first {
println!("nftables init done");
first = false;
}
let (rulesets1, ips) = match rx.recv() {
Ok(val) => val,
Err(_) => break,
};
for i in rulesets1.into_iter() {
let ruleset = &mut rulesets[i];
for ip1 in ips.iter().copied() {
match ip1 {
IpNet::V4(ip) => ruleset.0.insert(ip, false),
IpNet::V6(ip) => ruleset.1.insert(ip, false),
}
}
}
}
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use std::net::Ipv6Addr; use std::net::{Ipv4Addr, Ipv6Addr};
use ipnet::Ipv6Net; use ipnet::{Ipv4Net, Ipv6Net};
use iptrie::RTrieSet;
use crate::nftables::{iter_ip_trie, should_add};
use super::get_sets; use super::get_sets;
#[test] #[test]
fn test_nftables() { fn test_nftables() {
if !nix::unistd::Uid::effective().is_root() {
return;
}
let socket = mnl::Socket::new(mnl::Bus::Netfilter).unwrap(); let socket = mnl::Socket::new(mnl::Bus::Netfilter).unwrap();
let sets = get_sets(&socket).unwrap(); let sets = get_sets(&socket).unwrap();
assert!(!sets.is_empty()); assert!(!sets.is_empty());
@ -352,4 +532,24 @@ mod test {
} }
panic!(); panic!();
} }
#[test]
fn test_set() {
let mut trie = RTrieSet::new();
assert!(should_add(
&trie,
&Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap()
));
trie.insert(Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap());
assert!(!should_add(
&trie,
&Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap()
));
trie.insert(Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 31).unwrap());
assert!(dbg!(iter_ip_trie(&trie).collect::<Vec<_>>()).len() == 1);
// contains 0.0.0.0, etc
assert!(dbg!(trie.iter().collect::<Vec<_>>()).len() == 3);
trie.insert(Ipv4Net::new(Ipv4Addr::new(127, 0, 1, 1), 32).unwrap());
assert!(dbg!(iter_ip_trie(&trie).collect::<Vec<_>>()).len() == 2);
}
} }