diff --git a/flake.nix b/flake.nix index ff71e75..ebf8a8f 100644 --- a/flake.nix +++ b/flake.nix @@ -37,9 +37,7 @@ }; doCheck = false; LIBMNL_LIB_DIR = "${nixpkgs.lib.getLib pkgs.libmnl}/lib"; - LIBNFTNL_LIB_DIR = "${nixpkgs.lib.getLib (pkgs.libnftnl.overrideAttrs (old: { - patches = (old.patches or []) ++ [ ./libnftnl-fix.patch ]; - }))}/lib"; + LIBNFTNL_LIB_DIR = "${nixpkgs.lib.getLib pkgs.libnftnl}/lib"; }; default = unbound-mod; }); @@ -52,9 +50,7 @@ pkgs.nftables ]; LIBMNL_LIB_DIR = "${nixpkgs.lib.getLib pkgs.libmnl}/lib"; - LIBNFTNL_LIB_DIR = "${nixpkgs.lib.getLib (pkgs.libnftnl.overrideAttrs (old: { - patches = (old.patches or []) ++ [ ./libnftnl-fix.patch ]; - }))}/lib"; + LIBNFTNL_LIB_DIR = "${nixpkgs.lib.getLib pkgs.libnftnl}/lib"; LD_LIBRARY_PATH = "${LIBMNL_LIB_DIR}:${LIBNFTNL_LIB_DIR}"; }; }); diff --git a/libnftnl-fix.patch b/libnftnl-fix.patch deleted file mode 100644 index e2b37e0..0000000 --- a/libnftnl-fix.patch +++ /dev/null @@ -1,24 +0,0 @@ -diff --git a/src/libnftnl.map b/src/libnftnl.map -index 8fffff1..3f660de 100644 ---- a/src/libnftnl.map -+++ b/src/libnftnl.map -@@ -129,6 +129,7 @@ global: - nftnl_set_get_str; - nftnl_set_get_u32; - nftnl_set_get_u64; -+ nftnl_set_clone; - nftnl_set_nlmsg_build_payload; - nftnl_set_nlmsg_parse; - nftnl_set_parse; -diff --git a/src/set.c b/src/set.c -index 07e332d..c5f9518 100644 ---- a/src/set.c -+++ b/src/set.c -@@ -352,6 +352,7 @@ uint64_t nftnl_set_get_u64(const struct nftnl_set *s, uint16_t attr) - return val ? *val : 0; - } - -+EXPORT_SYMBOL(nftnl_set_clone); - struct nftnl_set *nftnl_set_clone(const struct nftnl_set *set) - { - struct nftnl_set *newset; diff --git a/src/example.rs b/src/example.rs index d1dd6b3..7b38fe6 100644 --- a/src/example.rs +++ b/src/example.rs @@ -391,7 +391,7 @@ impl ExampleMod { } fn load_env(&mut self) -> Result, NftData)>, ()> { self.nft_token = std::env::var_os("NFT_TOKEN") - .map(|x| x.to_str().ok_or(()).map(|s| s.to_owned())) + .map(|x| x.to_str().ok_or(()).map(ToOwned::to_owned)) .transpose()?; self.tmp_nft_token = std::env::var_os("NFT_TOKEN") .map(|x| x.to_str().ok_or(()).map(|s| format!("tmp{s}"))) diff --git a/src/nftables.rs b/src/nftables.rs index 5ee736d..61a0b93 100644 --- a/src/nftables.rs +++ b/src/nftables.rs @@ -169,22 +169,47 @@ impl Set1 { pub const fn as_mut_ptr(&self) -> *mut nftnl_sys::nftnl_set { self.0 } - pub fn table_name(&self) -> Option<&str> { + pub fn table_name(&self) -> Option<&CStr> { let ret = unsafe { nftnl_sys::nftnl_set_get_str(self.0, nftnl_sys::NFTNL_SET_TABLE as u16) }; - (!ret.is_null()) - .then(|| unsafe { CStr::from_ptr(ret) }.to_str().ok()) - .flatten() + (!ret.is_null()).then(|| unsafe { CStr::from_ptr(ret) }) } - pub fn name(&self) -> Option<&str> { + pub fn table_name_str(&self) -> Option<&str> { + self.table_name().and_then(|s| s.to_str().ok()) + } + pub fn set_table_name(&mut self, s: &CStr) -> Result<(), ()> { + if unsafe { + nftnl_sys::nftnl_set_set_str(self.0, nftnl_sys::NFTNL_SET_TABLE as u16, s.as_ptr()) + } == 0 + { + Ok(()) + } else { + Err(()) + } + } + pub fn name(&self) -> Option<&CStr> { let ret = unsafe { nftnl_sys::nftnl_set_get_str(self.0, nftnl_sys::NFTNL_SET_NAME as u16) }; - (!ret.is_null()) - .then(|| unsafe { CStr::from_ptr(ret) }.to_str().ok()) - .flatten() + (!ret.is_null()).then(|| unsafe { CStr::from_ptr(ret) }) + } + pub fn name_str(&self) -> Option<&str> { + self.name().and_then(|s| s.to_str().ok()) + } + pub fn set_name(&mut self, s: &CStr) -> Result<(), ()> { + if unsafe { + nftnl_sys::nftnl_set_set_str(self.0, nftnl_sys::NFTNL_SET_NAME as u16, s.as_ptr()) + } == 0 + { + Ok(()) + } else { + Err(()) + } } pub fn family(&self) -> u32 { unsafe { nftnl_sys::nftnl_set_get_u32(self.0, nftnl_sys::NFTNL_SET_FAMILY as u16) } } + pub fn set_family(&mut self, val: u32) { + unsafe { nftnl_sys::nftnl_set_set_u32(self.0, nftnl_sys::NFTNL_SET_FAMILY as u16, val) } + } pub fn add_range(&mut self, lower: &K, excl_upper: Option<&K>) { let data1 = lower.data(); let data1_len = data1.len() as u32; @@ -230,7 +255,21 @@ impl Set1 { // FIXME: why 2048? let max_batch_size = 2048; let mut count = 0; - let mut set = self.clone(); + let clone_self = || { + let mut set = Self::new(); + if let Some(s) = self.table_name() { + set.set_table_name(s).expect("oom"); + } + if let Some(s) = self.name() { + set.set_name(s).expect("oom"); + } + let family = self.family(); + if family != 0 { + set.set_family(self.family()); + } + set + }; + let mut set = clone_self(); if flush { count += 1; batch.add(&set.flush_msg(), nftnl::MsgType::Del); @@ -239,7 +278,7 @@ impl Set1 { if count + 2 > max_batch_size { batch.add_iter(SetElemsIter::new(&set), MsgType::Add); send_and_process(socket, &batch.finalize())?; - set = self.clone(); + set = clone_self(); batch = Batch::new(); } match net { @@ -261,12 +300,6 @@ impl Set1 { } } -impl Clone for Set1 { - fn clone(&self) -> Self { - Self(unsafe { nftnl_sys::nftnl_set_clone(self.0) }) - } -} - pub fn get_sets(socket: &mnl::Socket) -> io::Result> { let mut buffer = vec![0; nftnl::nft_nlmsg_maxsize() as usize]; let seq = 0; @@ -441,12 +474,12 @@ pub(crate) fn nftables_thread( let all_sets = crate::nftables::get_sets(&socket).unwrap(); for set in all_sets { for ruleset in &mut rulesets { - if set.table_name() == Some("global") && set.family() == libc::NFPROTO_INET as u32 { - if set.name() == Some(ruleset.0.name()) { + if set.table_name_str() == Some("global") && set.family() == libc::NFPROTO_INET as u32 { + if set.name_str() == Some(ruleset.0.name()) { println!("found set {}", ruleset.0.name()); ruleset.0.set_set(set); break; - } else if set.name() == Some(ruleset.1.name()) { + } else if set.name_str() == Some(ruleset.1.name()) { println!("found set {}", ruleset.1.name()); ruleset.1.set_set(set); break; @@ -511,9 +544,11 @@ mod test { let sets = get_sets(&socket).unwrap(); assert!(!sets.is_empty()); for set in sets { - if set.table_name() != Some("test") || set.name() != Some("test7") { + // add set inet test test7 { type ipv6_addr ; flags interval ; } + if set.table_name_str() != Some("test") || set.name_str() != Some("test7") { continue; } + // must end with ::3ffe/127 set.add_cidrs( &socket, true,