Compare commits
2 commits
14346134b5
...
e9a6f296df
Author | SHA1 | Date | |
---|---|---|---|
chayleaf | e9a6f296df | ||
chayleaf | 303b157557 |
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -1,2 +1,4 @@
|
||||||
/target
|
/target
|
||||||
/result
|
/result
|
||||||
|
/unbound-mod-test-config
|
||||||
|
/unbound-mod-test-data
|
||||||
|
|
|
@ -17,7 +17,7 @@ iptrie = "0.8.5"
|
||||||
libc = "0.2.155"
|
libc = "0.2.155"
|
||||||
mnl = { version = "0.2.2", features = ["mnl-1-0-4"] }
|
mnl = { version = "0.2.2", features = ["mnl-1-0-4"] }
|
||||||
nftnl = { version = "0.6.2", features = ["nftnl-1-1-2"] }
|
nftnl = { version = "0.6.2", features = ["nftnl-1-1-2"] }
|
||||||
nix = { version = "0.29.0", features = ["poll"] }
|
nix = { version = "0.29.0", features = ["poll", "user"] }
|
||||||
radix_trie = "0.2.1"
|
radix_trie = "0.2.1"
|
||||||
serde = { version = "1.0.205", features = ["derive"] }
|
serde = { version = "1.0.205", features = ["derive"] }
|
||||||
serde_json = "1.0.122"
|
serde_json = "1.0.122"
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
use std::panic::{RefUnwindSafe, UnwindSafe};
|
use std::panic::{RefUnwindSafe, UnwindSafe};
|
||||||
|
|
||||||
|
use crate::unbound::ModuleExtState;
|
||||||
use crate::UnboundMod;
|
use crate::UnboundMod;
|
||||||
|
|
||||||
macro_rules! impl_tuple {
|
macro_rules! impl_tuple {
|
||||||
|
@ -29,9 +30,15 @@ macro_rules! impl_tuple {
|
||||||
qstate: &mut crate::unbound::ModuleQstate<Self::QstateData>,
|
qstate: &mut crate::unbound::ModuleQstate<Self::QstateData>,
|
||||||
event: crate::unbound::ModuleEvent,
|
event: crate::unbound::ModuleEvent,
|
||||||
entry: &mut crate::unbound::OutboundEntryMut,
|
entry: &mut crate::unbound::OutboundEntryMut,
|
||||||
) {
|
) -> Option<ModuleExtState> {
|
||||||
self.0.operate(qstate, event, entry);
|
#[allow(unused_mut)]
|
||||||
$(self.$i.operate(qstate, event, entry);)*
|
let mut ret = self.0.operate(qstate, event, entry);
|
||||||
|
$(if let Some(state) = self.$i.operate(qstate, event, entry) {
|
||||||
|
if !matches!(ret, Some(ret) if ret.importance() >= state.importance()) {
|
||||||
|
ret = Some(state);
|
||||||
|
}
|
||||||
|
})*
|
||||||
|
ret
|
||||||
}
|
}
|
||||||
fn get_mem(&self, env: &mut crate::unbound::ModuleEnv<Self::EnvData>) -> usize {
|
fn get_mem(&self, env: &mut crate::unbound::ModuleEnv<Self::EnvData>) -> usize {
|
||||||
self.0.get_mem(env) $(* self.$i.get_mem(env))*
|
self.0.get_mem(env) $(* self.$i.get_mem(env))*
|
||||||
|
|
|
@ -2,6 +2,7 @@ use std::{collections::HashMap, hash::Hash};
|
||||||
|
|
||||||
use smallvec::{smallvec, SmallVec};
|
use smallvec::{smallvec, SmallVec};
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
pub enum PrefixSet<T> {
|
pub enum PrefixSet<T> {
|
||||||
Map(HashMap<T, PrefixSet<T>>),
|
Map(HashMap<T, PrefixSet<T>>),
|
||||||
Leaf,
|
Leaf,
|
||||||
|
@ -63,7 +64,7 @@ impl<T: Hash + Eq> PrefixSet<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Iter<'a, T>(
|
struct Iter<'a, T>(
|
||||||
SmallVec<[std::collections::hash_map::Iter<'a, T, PrefixSet<T>>; 8]>,
|
SmallVec<[std::collections::hash_map::Iter<'a, T, PrefixSet<T>>; 9]>,
|
||||||
SmallVec<[&'a T; 8]>,
|
SmallVec<[&'a T; 8]>,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -73,9 +74,7 @@ impl<'a, T> Iterator for Iter<'a, T> {
|
||||||
while let Some(it) = self.0.last_mut() {
|
while let Some(it) = self.0.last_mut() {
|
||||||
let Some((k, v)) = it.next() else {
|
let Some((k, v)) = it.next() else {
|
||||||
self.0.pop();
|
self.0.pop();
|
||||||
if self.1.pop().is_none() {
|
self.1.pop()?;
|
||||||
return None;
|
|
||||||
}
|
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
self.1.push(k);
|
self.1.push(k);
|
||||||
|
|
1089
src/example.rs
1089
src/example.rs
File diff suppressed because it is too large
Load diff
21
src/lib.rs
21
src/lib.rs
|
@ -1,11 +1,18 @@
|
||||||
|
#![allow(clippy::type_complexity)]
|
||||||
use std::panic::{RefUnwindSafe, UnwindSafe};
|
use std::panic::{RefUnwindSafe, UnwindSafe};
|
||||||
|
|
||||||
|
use unbound::ModuleExtState;
|
||||||
#[allow(
|
#[allow(
|
||||||
dead_code,
|
dead_code,
|
||||||
improper_ctypes,
|
improper_ctypes,
|
||||||
non_camel_case_types,
|
non_camel_case_types,
|
||||||
non_snake_case,
|
non_snake_case,
|
||||||
non_upper_case_globals,
|
non_upper_case_globals,
|
||||||
unused_imports
|
unused_imports,
|
||||||
|
clippy::useless_transmute,
|
||||||
|
clippy::type_complexity,
|
||||||
|
clippy::too_many_arguments,
|
||||||
|
clippy::upper_case_acronyms
|
||||||
)]
|
)]
|
||||||
mod bindings;
|
mod bindings;
|
||||||
mod combine;
|
mod combine;
|
||||||
|
@ -33,7 +40,8 @@ pub trait UnboundMod: Send + Sync + Sized + RefUnwindSafe + UnwindSafe {
|
||||||
_qstate: &mut unbound::ModuleQstate<Self::QstateData>,
|
_qstate: &mut unbound::ModuleQstate<Self::QstateData>,
|
||||||
_event: unbound::ModuleEvent,
|
_event: unbound::ModuleEvent,
|
||||||
_entry: &mut unbound::OutboundEntryMut,
|
_entry: &mut unbound::OutboundEntryMut,
|
||||||
) {
|
) -> Option<ModuleExtState> {
|
||||||
|
Some(ModuleExtState::Finished)
|
||||||
}
|
}
|
||||||
fn inform_super(
|
fn inform_super(
|
||||||
&self,
|
&self,
|
||||||
|
@ -101,11 +109,14 @@ unsafe impl<T: UnboundMod> SealedUnboundMod for T {
|
||||||
entry: *mut bindings::outbound_entry,
|
entry: *mut bindings::outbound_entry,
|
||||||
) {
|
) {
|
||||||
std::panic::catch_unwind(|| {
|
std::panic::catch_unwind(|| {
|
||||||
self.operate(
|
let mut qstate = unbound::ModuleQstate(qstate, id, Default::default());
|
||||||
&mut unbound::ModuleQstate(qstate, id, Default::default()),
|
if let Some(ext_state) = self.operate(
|
||||||
|
&mut qstate,
|
||||||
event.into(),
|
event.into(),
|
||||||
&mut unbound::OutboundEntryMut(entry, Default::default()),
|
&mut unbound::OutboundEntryMut(entry, Default::default()),
|
||||||
)
|
) {
|
||||||
|
qstate.set_ext_state(ext_state);
|
||||||
|
}
|
||||||
})
|
})
|
||||||
.unwrap_or(());
|
.unwrap_or(());
|
||||||
}
|
}
|
||||||
|
|
206
src/nftables.rs
206
src/nftables.rs
|
@ -1,18 +1,23 @@
|
||||||
use std::{
|
use std::{
|
||||||
cell::Cell,
|
cell::Cell,
|
||||||
ffi::CStr,
|
ffi::CStr,
|
||||||
io,
|
fmt::Display,
|
||||||
|
io::{self, Write},
|
||||||
net::{Ipv4Addr, Ipv6Addr},
|
net::{Ipv4Addr, Ipv6Addr},
|
||||||
os::{
|
os::{
|
||||||
fd::BorrowedFd,
|
fd::BorrowedFd,
|
||||||
raw::{c_char, c_void},
|
raw::{c_char, c_void},
|
||||||
},
|
},
|
||||||
rc::Rc,
|
rc::Rc,
|
||||||
|
sync::mpsc,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use crate::example::{Helper, DATA_PREFIX};
|
||||||
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
|
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
|
||||||
|
use iptrie::RTrieSet;
|
||||||
use mnl::mnl_sys;
|
use mnl::mnl_sys;
|
||||||
use nftnl::{nftnl_sys, set::SetKey, Batch, FinalizedBatch, MsgType, NlMsg};
|
use nftnl::{nftnl_sys, set::SetKey, Batch, FinalizedBatch, MsgType, NlMsg};
|
||||||
|
use smallvec::SmallVec;
|
||||||
|
|
||||||
fn cidr_bound_ipv4(net: Ipv4Net) -> Option<Ipv4Addr> {
|
fn cidr_bound_ipv4(net: Ipv4Net) -> Option<Ipv4Addr> {
|
||||||
let data = u32::from(net.network());
|
let data = u32::from(net.network());
|
||||||
|
@ -324,16 +329,191 @@ pub fn get_sets(socket: &mnl::Socket) -> io::Result<Vec<Set1>> {
|
||||||
Ok(ret)
|
Ok(ret)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn should_add<T: Helper>(trie: &RTrieSet<T>, elem: &T) -> bool {
|
||||||
|
*trie.lookup(elem) == T::ZERO
|
||||||
|
}
|
||||||
|
|
||||||
|
fn iter_ip_trie<T: Helper>(trie: &RTrieSet<T>) -> impl '_ + Iterator<Item = T> {
|
||||||
|
trie.iter().copied().filter(|x| {
|
||||||
|
if let Some(par) = x.direct_parent() {
|
||||||
|
should_add(trie, &par)
|
||||||
|
} else {
|
||||||
|
*x != T::ZERO
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct NftData<T: Helper> {
|
||||||
|
ips: RTrieSet<T>,
|
||||||
|
dirty: bool,
|
||||||
|
set: Option<Set1>,
|
||||||
|
name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Helper> NftData<T> {
|
||||||
|
pub fn new(name: &str) -> Self {
|
||||||
|
Self {
|
||||||
|
set: None,
|
||||||
|
ips: RTrieSet::new(),
|
||||||
|
dirty: true,
|
||||||
|
name: name.to_owned(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SAFETY: set is None initially so Set1 is never actually sent
|
||||||
|
// (and it might be fine to send anyway actually)
|
||||||
|
unsafe impl<T: Helper + Send> Send for NftData<T> {}
|
||||||
|
|
||||||
|
impl<T: Helper> NftData<T>
|
||||||
|
where
|
||||||
|
IpNet: From<T>,
|
||||||
|
{
|
||||||
|
#[must_use]
|
||||||
|
pub fn verify(&mut self) -> bool {
|
||||||
|
if !self.name.is_empty() && self.set.is_none() {
|
||||||
|
self.ips = RTrieSet::new();
|
||||||
|
false
|
||||||
|
} else {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn flush_changes(
|
||||||
|
&mut self,
|
||||||
|
socket: &mnl::Socket,
|
||||||
|
flush_set: bool,
|
||||||
|
) -> Result<(), io::Error> {
|
||||||
|
if let Some(set) = self.set.as_mut().filter(|_| self.dirty) {
|
||||||
|
if flush_set {
|
||||||
|
println!(
|
||||||
|
"initializing set {} with ~{} ips (e.g. {:?})",
|
||||||
|
self.name,
|
||||||
|
self.ips.len(),
|
||||||
|
iter_ip_trie(&self.ips).next(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
set.add_cidrs(socket, flush_set, iter_ip_trie(&self.ips).map(IpNet::from))
|
||||||
|
} else {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn extend(&mut self, ips: impl Iterator<Item = T>) {
|
||||||
|
for ip in ips {
|
||||||
|
self.insert(ip, true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn insert(&mut self, ip: T, allow_empty_set: bool) {
|
||||||
|
if (if allow_empty_set {
|
||||||
|
!self.name.is_empty()
|
||||||
|
} else {
|
||||||
|
self.set.is_some()
|
||||||
|
}) && should_add(&self.ips, &ip)
|
||||||
|
{
|
||||||
|
self.ips.insert(ip);
|
||||||
|
self.dirty = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn ips_mut(&mut self) -> &mut RTrieSet<T> {
|
||||||
|
&mut self.ips
|
||||||
|
}
|
||||||
|
#[cfg(test)]
|
||||||
|
pub fn ip_count(&self) -> usize {
|
||||||
|
iter_ip_trie(&self.ips).count()
|
||||||
|
}
|
||||||
|
pub fn name(&self) -> &str {
|
||||||
|
&self.name
|
||||||
|
}
|
||||||
|
pub fn set_set(&mut self, set: Set1) {
|
||||||
|
self.set = Some(set);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn nftables_thread(
|
||||||
|
mut rulesets: Vec<(NftData<Ipv4Net>, NftData<Ipv6Net>)>,
|
||||||
|
rx: mpsc::Receiver<(SmallVec<[usize; 5]>, smallvec::SmallVec<[IpNet; 8]>)>,
|
||||||
|
) {
|
||||||
|
fn report(err: impl Display) {
|
||||||
|
println!("nftables: {err}");
|
||||||
|
if let Ok(mut file) = std::fs::OpenOptions::new()
|
||||||
|
.append(true)
|
||||||
|
.create(true)
|
||||||
|
.open(format!("{DATA_PREFIX}/nftables.log"))
|
||||||
|
{
|
||||||
|
file.write_all((err.to_string() + "\n").as_bytes())
|
||||||
|
.unwrap_or(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let socket = mnl::Socket::new(mnl::Bus::Netfilter).unwrap();
|
||||||
|
let all_sets = crate::nftables::get_sets(&socket).unwrap();
|
||||||
|
for set in all_sets {
|
||||||
|
for ruleset in &mut rulesets {
|
||||||
|
if set.table_name() == Some("global") && set.family() == libc::NFPROTO_INET as u32 {
|
||||||
|
if set.name() == Some(ruleset.0.name()) {
|
||||||
|
println!("found set {}", ruleset.0.name());
|
||||||
|
ruleset.0.set_set(set);
|
||||||
|
break;
|
||||||
|
} else if set.name() == Some(ruleset.1.name()) {
|
||||||
|
println!("found set {}", ruleset.1.name());
|
||||||
|
ruleset.1.set_set(set);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for ruleset in &mut rulesets {
|
||||||
|
if !ruleset.0.verify() {
|
||||||
|
report(format!("set {} not found", ruleset.0.name()));
|
||||||
|
}
|
||||||
|
if !ruleset.1.verify() {
|
||||||
|
report(format!("set {} not found", ruleset.1.name()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mut first = true;
|
||||||
|
loop {
|
||||||
|
for ruleset in &mut rulesets {
|
||||||
|
if let Err(err) = ruleset.0.flush_changes(&socket, first) {
|
||||||
|
report(err);
|
||||||
|
}
|
||||||
|
if let Err(err) = ruleset.1.flush_changes(&socket, first) {
|
||||||
|
report(err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if first {
|
||||||
|
println!("nftables init done");
|
||||||
|
first = false;
|
||||||
|
}
|
||||||
|
let (rulesets1, ips) = match rx.recv() {
|
||||||
|
Ok(val) => val,
|
||||||
|
Err(_) => break,
|
||||||
|
};
|
||||||
|
for i in rulesets1.into_iter() {
|
||||||
|
let ruleset = &mut rulesets[i];
|
||||||
|
for ip1 in ips.iter().copied() {
|
||||||
|
match ip1 {
|
||||||
|
IpNet::V4(ip) => ruleset.0.insert(ip, false),
|
||||||
|
IpNet::V6(ip) => ruleset.1.insert(ip, false),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test {
|
mod test {
|
||||||
use std::net::Ipv6Addr;
|
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||||
|
|
||||||
use ipnet::Ipv6Net;
|
use ipnet::{Ipv4Net, Ipv6Net};
|
||||||
|
use iptrie::RTrieSet;
|
||||||
|
|
||||||
|
use crate::nftables::{iter_ip_trie, should_add};
|
||||||
|
|
||||||
use super::get_sets;
|
use super::get_sets;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_nftables() {
|
fn test_nftables() {
|
||||||
|
if !nix::unistd::Uid::effective().is_root() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
let socket = mnl::Socket::new(mnl::Bus::Netfilter).unwrap();
|
let socket = mnl::Socket::new(mnl::Bus::Netfilter).unwrap();
|
||||||
let sets = get_sets(&socket).unwrap();
|
let sets = get_sets(&socket).unwrap();
|
||||||
assert!(!sets.is_empty());
|
assert!(!sets.is_empty());
|
||||||
|
@ -352,4 +532,24 @@ mod test {
|
||||||
}
|
}
|
||||||
panic!();
|
panic!();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_set() {
|
||||||
|
let mut trie = RTrieSet::new();
|
||||||
|
assert!(should_add(
|
||||||
|
&trie,
|
||||||
|
&Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap()
|
||||||
|
));
|
||||||
|
trie.insert(Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap());
|
||||||
|
assert!(!should_add(
|
||||||
|
&trie,
|
||||||
|
&Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap()
|
||||||
|
));
|
||||||
|
trie.insert(Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 31).unwrap());
|
||||||
|
assert!(dbg!(iter_ip_trie(&trie).collect::<Vec<_>>()).len() == 1);
|
||||||
|
// contains 0.0.0.0, etc
|
||||||
|
assert!(dbg!(trie.iter().collect::<Vec<_>>()).len() == 3);
|
||||||
|
trie.insert(Ipv4Net::new(Ipv4Addr::new(127, 0, 1, 1), 32).unwrap());
|
||||||
|
assert!(dbg!(iter_ip_trie(&trie).collect::<Vec<_>>()).len() == 2);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -609,6 +609,21 @@ pub enum ModuleExtState {
|
||||||
Unknown = 99,
|
Unknown = 99,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ModuleExtState {
|
||||||
|
pub(crate) fn importance(&self) -> usize {
|
||||||
|
match *self {
|
||||||
|
Self::Unknown => 0,
|
||||||
|
Self::InitialState => 1,
|
||||||
|
Self::Finished => 2,
|
||||||
|
Self::WaitModule => 3,
|
||||||
|
Self::RestartNext => 4,
|
||||||
|
Self::WaitSubquery => 5,
|
||||||
|
Self::WaitReply => 6,
|
||||||
|
Self::Error => 7,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<module_ext_state> for ModuleExtState {
|
impl From<module_ext_state> for ModuleExtState {
|
||||||
fn from(value: module_ext_state) -> Self {
|
fn from(value: module_ext_state) -> Self {
|
||||||
match value {
|
match value {
|
||||||
|
|
Loading…
Reference in a new issue