more fixes or something

This commit is contained in:
chayleaf 2024-08-13 09:51:32 +07:00
parent 54623b7fb0
commit 00418b649c
Signed by: chayleaf
GPG key ID: 78171AD46227E68E
2 changed files with 56 additions and 50 deletions

View file

@ -14,7 +14,7 @@ use std::{
use ctor::ctor; use ctor::ctor;
use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use iptrie::{IpPrefix, RTrieSet}; use iptrie::IpPrefix;
use serde::{ use serde::{
de::{Error, Visitor}, de::{Error, Visitor},
Deserialize, Deserialize,
@ -137,16 +137,17 @@ impl<T> Default for IpCache<T> {
fn ignore<T>(_: &mut smallvec::SmallVec<[T; 4]>) {} fn ignore<T>(_: &mut smallvec::SmallVec<[T; 4]>) {}
impl<T> IpCache<T> { impl<T> IpCache<T> {
fn extend_set_with_domain<J: IpPrefix + From<T>>( fn extend_set_with_domain<J: Helper + From<T>>(
&self, &self,
ips: &mut RTrieSet<J>, ips: &mut NftData<J>,
domain_r: IpCacheKey, domain_r: IpCacheKey,
) where ) where
T: Copy, T: Copy,
IpNet: From<J>,
{ {
self.get_maybe_update_rev(domain_r, |val| { self.get_maybe_update_rev(domain_r, |val| {
if let Some(val) = val { if let Some(val) = val {
ips.extend(val.0.iter().copied().map(|x| J::from(x))); ips.extend(val.0.iter().copied().map(From::from));
} }
#[allow(unused_assignments)] #[allow(unused_assignments)]
let mut val = Some(ignore); let mut val = Some(ignore);
@ -381,10 +382,10 @@ impl ExampleMod {
let rev_domain = IpCacheKey::from_split_rev_domain(rev_domain.into_iter()); let rev_domain = IpCacheKey::from_split_rev_domain(rev_domain.into_iter());
self.caches self.caches
.0 .0
.extend_set_with_domain(r.0.ips_mut(), rev_domain.clone()); .extend_set_with_domain(&mut r.0, rev_domain.clone());
self.caches self.caches
.1 .1
.extend_set_with_domain(r.1.ips_mut(), rev_domain.clone()); .extend_set_with_domain(&mut r.1, rev_domain.clone());
} }
} }
} }
@ -491,47 +492,46 @@ impl ExampleMod {
split_domain: &[&[u8]], split_domain: &[&[u8]],
qnames: SmallVec<[usize; 5]>, qnames: SmallVec<[usize; 5]>,
) -> Result<(), ()> { ) -> Result<(), ()> {
println!("adding {ip4:?}/{ip6:?} for {split_domain:?} to {qnames:?}"); if ip4.is_empty() && ip6.is_empty() {
if !ip4.is_empty() || !ip6.is_empty() { return Ok(());
let mut first = true; }
let domain = match split_domain let mut first = true;
.iter() let domain = match split_domain
.copied() .iter()
.map(std::str::from_utf8) .copied()
.try_fold(String::new(), |mut s, comp| { .map(std::str::from_utf8)
if first { .try_fold(String::new(), |mut s, comp| {
first = false; if first {
} else { first = false;
s.push('.'); } else {
} s.push('.');
s.push_str(comp?);
Ok::<_, std::str::Utf8Error>(s)
}) {
Ok(x) => x,
Err(err) => {
self.report("domain utf-8", err);
return Err(());
} }
}; s.push_str(comp?);
let key = IpCacheKey::from_split_domain(split_domain.iter()); Ok::<_, std::str::Utf8Error>(s)
let mut to_send: SmallVec<[IpNet; 8]> = SmallVec::new(); }) {
to_send.extend(ip4.iter().copied().map(Ipv4Net::from).map(IpNet::from)); Ok(x) => x,
to_send.extend(ip6.iter().copied().map(Ipv6Net::from).map(IpNet::from)); Err(err) => {
let keep4 = !ip4.is_empty() && self.caches.0.set(&domain, key.clone(), ip4); self.report("domain utf-8", err);
let keep6 = !ip6.is_empty() && self.caches.1.set(&domain, key, ip6); return Err(());
to_send.retain(|x| x.addr().is_ipv4() && keep4 || x.addr().is_ipv6() && keep6);
if !to_send.is_empty() {
self.ruleset_queue
.as_ref()
.unwrap()
.send((qnames, to_send))
.unwrap();
} }
};
let key = IpCacheKey::from_split_domain(split_domain.iter());
let mut to_send: SmallVec<[IpNet; 8]> = SmallVec::new();
to_send.extend(ip4.iter().copied().map(Ipv4Net::from).map(IpNet::from));
to_send.extend(ip6.iter().copied().map(Ipv6Net::from).map(IpNet::from));
let keep4 = !ip4.is_empty() && self.caches.0.set(&domain, key.clone(), ip4);
let keep6 = !ip6.is_empty() && self.caches.1.set(&domain, key, ip6);
to_send.retain(|x| x.addr().is_ipv4() && keep4 || x.addr().is_ipv6() && keep6);
if !to_send.is_empty() {
self.ruleset_queue
.as_ref()
.unwrap()
.send((qnames, to_send))
.unwrap();
} }
Ok(()) Ok(())
} }
fn run_commands(&self, split_domain: &[&[u8]]) -> Option<ModuleExtState> { fn run_commands(&self, split_domain: &[&[u8]]) -> Option<ModuleExtState> {
println!("{split_domain:?} {:?}", self.nft_token);
if let Some(split_domain) = self.nft_token.as_ref().and_then(|token| { if let Some(split_domain) = self.nft_token.as_ref().and_then(|token| {
split_domain split_domain
.split_last() .split_last()

View file

@ -332,8 +332,8 @@ fn iter_ip_trie<T: Helper>(trie: &RTrieSet<T>) -> impl '_ + Iterator<Item = T> {
} }
pub(crate) struct NftData<T: Helper> { pub(crate) struct NftData<T: Helper> {
all_ips: RTrieSet<T>,
ips: RTrieSet<T>, ips: RTrieSet<T>,
dirty: bool,
set: Option<Set1>, set: Option<Set1>,
name: String, name: String,
} }
@ -343,7 +343,7 @@ impl<T: Helper> NftData<T> {
Self { Self {
set: None, set: None,
ips: RTrieSet::new(), ips: RTrieSet::new(),
dirty: true, all_ips: RTrieSet::new(),
name: name.to_owned(), name: name.to_owned(),
} }
} }
@ -361,17 +361,24 @@ where
pub fn verify(&mut self) -> bool { pub fn verify(&mut self) -> bool {
if !self.name.is_empty() && self.set.is_none() { if !self.name.is_empty() && self.set.is_none() {
self.ips = RTrieSet::new(); self.ips = RTrieSet::new();
self.all_ips = RTrieSet::new();
false false
} else { } else {
true true
} }
} }
fn dirty(&self) -> bool {
usize::from(self.ips.len()) > 1
}
pub fn flush_changes( pub fn flush_changes(
&mut self, &mut self,
socket: &mnl::Socket, socket: &mnl::Socket,
flush_set: bool, flush_set: bool,
) -> Result<(), io::Error> { ) -> Result<(), io::Error> {
if let Some(set) = self.set.as_mut().filter(|_| self.dirty) { if !self.dirty() {
return Ok(());
}
if let Some(set) = self.set.as_mut() {
if flush_set { if flush_set {
println!( println!(
"initializing set {} with ~{} ips (e.g. {:?})", "initializing set {} with ~{} ips (e.g. {:?})",
@ -380,7 +387,9 @@ where
iter_ip_trie(&self.ips).next(), iter_ip_trie(&self.ips).next(),
); );
} }
set.add_cidrs(socket, flush_set, iter_ip_trie(&self.ips).map(IpNet::from)) let ret = set.add_cidrs(socket, flush_set, iter_ip_trie(&self.ips).map(IpNet::from));
self.ips = RTrieSet::new();
ret
} else { } else {
Ok(()) Ok(())
} }
@ -395,15 +404,12 @@ where
!self.name.is_empty() !self.name.is_empty()
} else { } else {
self.set.is_some() self.set.is_some()
}) && should_add(&self.ips, &ip) }) && should_add(&self.all_ips, &ip)
{ {
self.ips.insert(ip); self.ips.insert(ip);
self.dirty = true; self.all_ips.insert(ip);
} }
} }
pub fn ips_mut(&mut self) -> &mut RTrieSet<T> {
&mut self.ips
}
#[cfg(test)] #[cfg(test)]
pub fn ip_count(&self) -> usize { pub fn ip_count(&self) -> usize {
iter_ip_trie(&self.ips).count() iter_ip_trie(&self.ips).count()