diff --git a/Cargo.lock b/Cargo.lock index 38310a1..e696a54 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -50,12 +50,6 @@ dependencies = [ "synstructure", ] -[[package]] -name = "heck" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" - [[package]] name = "ipnet" version = "2.9.0" @@ -98,6 +92,17 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +[[package]] +name = "mnl" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1a5469630da93e1813bb257964c0ccee3b26b6879dd858039ddec35cc8681ed" +dependencies = [ + "libc", + "log", + "mnl-sys", +] + [[package]] name = "mnl-sys" version = "0.2.1" @@ -108,20 +113,6 @@ dependencies = [ "pkg-config", ] -[[package]] -name = "nftables" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b5081f5cae4d24af558828494371be6672d02a693e9b4cbca4a0cbae5443e6f" -dependencies = [ - "serde", - "serde_json", - "serde_path_to_error", - "strum", - "strum_macros", - "thiserror", -] - [[package]] name = "nftnl" version = "0.6.2" @@ -262,41 +253,12 @@ dependencies = [ "serde", ] -[[package]] -name = "serde_path_to_error" -version = "0.1.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af99884400da37c88f5e9146b7f1fd0fbcae8f6eec4e9da38b67d05486f814a6" -dependencies = [ - "itoa", - "serde", -] - [[package]] name = "smallvec" version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" -[[package]] -name = "strum" -version = "0.26.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" - -[[package]] -name = "strum_macros" -version = "0.26.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" -dependencies = [ - "heck", - "proc-macro2", - "quote", - "rustversion", - "syn 2.0.72", -] - [[package]] name = "syn" version = "1.0.109" @@ -331,26 +293,6 @@ dependencies = [ "unicode-xid", ] -[[package]] -name = "thiserror" -version = "1.0.63" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" -dependencies = [ - "thiserror-impl", -] - -[[package]] -name = "thiserror-impl" -version = "1.0.63" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.72", -] - [[package]] name = "unbound-mod" version = "0.1.0" @@ -360,8 +302,8 @@ dependencies = [ "ipnet", "iptrie", "libc", + "mnl", "mnl-sys", - "nftables", "nftnl", "nftnl-sys", "prefix-tree", diff --git a/Cargo.toml b/Cargo.toml index 3c53cc6..dd8b562 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,8 +14,8 @@ ctor = { version = "0.2.8", optional = true } ipnet = { version = "2.9.0", features = ["serde"] } iptrie = "0.8.5" libc = "0.2.155" +mnl = { version = "0.2.2", features = ["mnl-1-0-4"] } mnl-sys = { version = "0.2.1", features = ["mnl-1-0-4"] } -nftables = "0.4.1" nftnl = { version = "0.6.2", features = ["nftnl-1-1-2"] } nftnl-sys = { version = "0.6.1", features = ["nftnl-1-1-2"] } prefix-tree = "0.5.0" diff --git a/FIXME b/FIXME new file mode 100644 index 0000000..6099071 --- /dev/null +++ b/FIXME @@ -0,0 +1,2 @@ +nftables +token is after, not before diff --git a/flake.nix b/flake.nix index 6ebd650..3a6f16e 100644 --- a/flake.nix +++ b/flake.nix @@ -22,10 +22,11 @@ devShells.x86_64-linux.default = let pkgs = import nixpkgs { system = "x86_64-linux"; }; - in pkgs.mkShell { + in pkgs.mkShell rec { name = "unbound-rust-mod-shell"; LIBMNL_LIB_DIR = "${nixpkgs.lib.getLib pkgs.libmnl}/lib"; LIBNFTNL_LIB_DIR = "${nixpkgs.lib.getLib pkgs.libnftnl}/lib"; + LD_LIBRARY_PATH = "${LIBMNL_LIB_DIR}:${LIBNFTNL_LIB_DIR}"; }; }; } diff --git a/src/combine.rs b/src/combine.rs new file mode 100644 index 0000000..e135094 --- /dev/null +++ b/src/combine.rs @@ -0,0 +1,76 @@ +use std::panic::{RefUnwindSafe, UnwindSafe}; + +use crate::UnboundMod; + +macro_rules! impl_tuple { + ($($i:tt $t:tt),*) => { + impl UnboundMod for (A, $($t, )*) + where + A: UnboundMod + UnwindSafe + RefUnwindSafe, + $($t: UnboundMod + + UnwindSafe + + RefUnwindSafe,)* + { + type EnvData = A::EnvData; + type QstateData = A::QstateData; + fn init(env: &mut crate::unbound::ModuleEnv) -> Result { + Ok((A::init(env)?, $($t::init(env)?, )*)) + } + fn clear(&self, qstate: &mut crate::unbound::ModuleQstate) { + self.0.clear(qstate); + $(self.$i.clear(qstate);)* + } + fn deinit(self, env: &mut crate::unbound::ModuleEnv) { + self.0.deinit(env); + $(self.$i.deinit(env);)* + } + fn operate( + &self, + qstate: &mut crate::unbound::ModuleQstate, + event: crate::unbound::ModuleEvent, + entry: &mut crate::unbound::OutboundEntryMut, + ) { + self.0.operate(qstate, event, entry); + $(self.$i.operate(qstate, event, entry);)* + } + fn get_mem(&self, env: &mut crate::unbound::ModuleEnv) -> usize { + self.0.get_mem(env) $(* self.$i.get_mem(env))* + } + fn inform_super( + &self, + qstate: &mut crate::unbound::ModuleQstate, + super_qstate: &mut crate::unbound::ModuleQstate, + ) { + self.0.inform_super(qstate, super_qstate); + $(self.$i.inform_super(qstate, super_qstate);)* + } + } + }; +} + +impl_tuple!(); +impl_tuple!(1 B); +impl_tuple!(1 B, 2 C); +impl_tuple!(1 B, 2 C, 3 D); +impl_tuple!(1 B, 2 C, 3 D, 4 E); +impl_tuple!(1 B, 2 C, 3 D, 4 E, 5 F); +impl_tuple!(1 B, 2 C, 3 D, 4 E, 5 F, 6 G); +impl_tuple!(1 B, 2 C, 3 D, 4 E, 5 F, 6 G, 7 H); +impl_tuple!(1 B, 2 C, 3 D, 4 E, 5 F, 6 G, 7 H, 8 I); +impl_tuple!(1 B, 2 C, 3 D, 4 E, 5 F, 6 G, 7 H, 8 I, 9 J); +impl_tuple!(1 B, 2 C, 3 D, 4 E, 5 F, 6 G, 7 H, 8 I, 9 J, 10 K); +impl_tuple!(1 B, 2 C, 3 D, 4 E, 5 F, 6 G, 7 H, 8 I, 9 J, 10 K, 11 L); +impl_tuple!(1 B, 2 C, 3 D, 4 E, 5 F, 6 G, 7 H, 8 I, 9 J, 10 K, 11 L, 12 M); +impl_tuple!(1 B, 2 C, 3 D, 4 E, 5 F, 6 G, 7 H, 8 I, 9 J, 10 K, 11 L, 12 M, 13 N); +impl_tuple!(1 B, 2 C, 3 D, 4 E, 5 F, 6 G, 7 H, 8 I, 9 J, 10 K, 11 L, 12 M, 13 N, 14 O); +impl_tuple!(1 B, 2 C, 3 D, 4 E, 5 F, 6 G, 7 H, 8 I, 9 J, 10 K, 11 L, 12 M, 13 N, 14 O, 15 P); +impl_tuple!(1 B, 2 C, 3 D, 4 E, 5 F, 6 G, 7 H, 8 I, 9 J, 10 K, 11 L, 12 M, 13 N, 14 O, 15 P, 16 Q); +impl_tuple!(1 B, 2 C, 3 D, 4 E, 5 F, 6 G, 7 H, 8 I, 9 J, 10 K, 11 L, 12 M, 13 N, 14 O, 15 P, 16 Q, 17 R); +impl_tuple!(1 B, 2 C, 3 D, 4 E, 5 F, 6 G, 7 H, 8 I, 9 J, 10 K, 11 L, 12 M, 13 N, 14 O, 15 P, 16 Q, 17 R, 18 S); +impl_tuple!(1 B, 2 C, 3 D, 4 E, 5 F, 6 G, 7 H, 8 I, 9 J, 10 K, 11 L, 12 M, 13 N, 14 O, 15 P, 16 Q, 17 R, 18 S, 19 T); +impl_tuple!(1 B, 2 C, 3 D, 4 E, 5 F, 6 G, 7 H, 8 I, 9 J, 10 K, 11 L, 12 M, 13 N, 14 O, 15 P, 16 Q, 17 R, 18 S, 19 T, 20 U); +impl_tuple!(1 B, 2 C, 3 D, 4 E, 5 F, 6 G, 7 H, 8 I, 9 J, 10 K, 11 L, 12 M, 13 N, 14 O, 15 P, 16 Q, 17 R, 18 S, 19 T, 20 U, 21 V); +impl_tuple!(1 B, 2 C, 3 D, 4 E, 5 F, 6 G, 7 H, 8 I, 9 J, 10 K, 11 L, 12 M, 13 N, 14 O, 15 P, 16 Q, 17 R, 18 S, 19 T, 20 U, 21 V, 22 W); +impl_tuple!(1 B, 2 C, 3 D, 4 E, 5 F, 6 G, 7 H, 8 I, 9 J, 10 K, 11 L, 12 M, 13 N, 14 O, 15 P, 16 Q, 17 R, 18 S, 19 T, 20 U, 21 V, 22 W, 23 X); +impl_tuple!(1 B, 2 C, 3 D, 4 E, 5 F, 6 G, 7 H, 8 I, 9 J, 10 K, 11 L, 12 M, 13 N, 14 O, 15 P, 16 Q, 17 R, 18 S, 19 T, 20 U, 21 V, 22 W, 23 X, 24 Y); +impl_tuple!(1 B, 2 C, 3 D, 4 E, 5 F, 6 G, 7 H, 8 I, 9 J, 10 K, 11 L, 12 M, 13 N, 14 O, 15 P, 16 Q, 17 R, 18 S, 19 T, 20 U, 21 V, 22 W, 23 X, 24 Y, 25 Z); diff --git a/src/lib.rs b/src/lib.rs index 2f860c3..4d85b4b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,9 +8,11 @@ use std::panic::{RefUnwindSafe, UnwindSafe}; unused_imports )] mod bindings; +mod combine; #[cfg(feature = "example")] mod example; mod exports; +mod nftables; mod unbound; pub fn add(left: usize, right: usize) -> usize { diff --git a/src/nftables.rs b/src/nftables.rs new file mode 100644 index 0000000..7c94887 --- /dev/null +++ b/src/nftables.rs @@ -0,0 +1,239 @@ +use std::{ + cell::Cell, + io, + net::{Ipv4Addr, Ipv6Addr}, + os::raw::{c_char, c_void}, + rc::Rc, +}; + +use ipnet::{Ipv4Net, Ipv6Net}; +use nftnl::{ + set::{Set, SetKey}, + FinalizedBatch, MsgType, NlMsg, +}; + +// internally represented as a range +struct Cidr(T); +impl SetKey for Cidr { + const TYPE: u32 = Ipv4Addr::TYPE; + const LEN: u32 = Ipv4Addr::LEN * 2; + fn data(&self) -> Box<[u8]> { + let data = u32::from_be_bytes(self.0.network().octets()); + let mask = u32::from_be_bytes(self.0.netmask().octets()); + let mut ret = [0u8; (Self::LEN) as usize]; + ret[..(Ipv4Addr::LEN as usize)].copy_from_slice(&self.0.network().octets()); + ret[(Ipv4Addr::LEN as usize)..].copy_from_slice(&u32::to_be_bytes(!mask | data)); + println!("{ret:?} {:?}", self.0.addr().data()); + Box::new(ret) + } +} +impl SetKey for Cidr { + const TYPE: u32 = Ipv6Addr::TYPE; + const LEN: u32 = Ipv6Addr::LEN * 2; + fn data(&self) -> Box<[u8]> { + let data = u128::from_be_bytes(self.0.network().octets()); + let mask = u128::from_be_bytes(self.0.netmask().octets()); + let mut ret = [0u8; (Self::LEN) as usize]; + ret[..(Ipv6Addr::LEN as usize)].copy_from_slice(&self.0.network().octets()); + ret[(Ipv6Addr::LEN as usize)..].copy_from_slice(&u128::to_be_bytes(!mask | data)); + Box::new(ret) + } +} + +struct FlushSetMsg<'a, T> { + set: &'a Set<'a, T>, +} +unsafe impl<'a, T> NlMsg for FlushSetMsg<'a, T> { + unsafe fn write(&self, buf: *mut std::ffi::c_void, seq: u32, _msg_type: MsgType) { + let header = nftnl_sys::nftnl_nlmsg_build_hdr( + buf as *mut c_char, + libc::NFT_MSG_DELSETELEM as u16, + self.set.get_family() as u16, + 0, + seq, + ); + nftnl_sys::nftnl_set_elems_nlmsg_build_payload(header, self.set.as_ptr()); + } +} + +pub fn send_and_process(batch: &FinalizedBatch) -> io::Result<()> { + let socket = mnl::Socket::new(mnl::Bus::Netfilter)?; + eprintln!("a"); + socket.send_all(batch)?; + eprintln!("b"); + let portid = socket.portid(); + let mut buf = vec![0; nftnl::nft_nlmsg_maxsize() as usize]; + loop { + eprintln!("c"); + let n = socket.recv(&mut buf[..])?; + eprintln!("d {n}"); + if n == 0 { + break; + } + match mnl::cb_run(&buf[..n], 2, portid)? { + mnl::CbResult::Stop => { + println!("stop"); + break; + } + mnl::CbResult::Ok => { + println!("ok"); + } + } + } + Ok(()) +} + +pub struct SetElemsIter<'a, K> { + set: &'a Set<'a, K>, + iter: *mut nftnl_sys::nftnl_set_elems_iter, + ret: Rc>, + is_first: bool, +} + +impl<'a, K> SetElemsIter<'a, K> { + fn new(set: &'a Set<'a, K>) -> Self { + let iter = unsafe { nftnl_sys::nftnl_set_elems_iter_create(set.as_ptr()) }; + if iter.is_null() { + panic!("oom"); + } + SetElemsIter { + set, + iter, + ret: Rc::new(Cell::new(1)), + is_first: true, + } + } +} + +impl<'a, K: 'a> Iterator for SetElemsIter<'a, K> { + type Item = SetElemsMsg<'a, K>; + + fn next(&mut self) -> Option { + if self.is_first { + self.is_first = false; + } else { + unsafe { nftnl_sys::nftnl_set_elems_iter_next(self.iter).is_null() }; + } + if self.ret.get() <= 0 + || unsafe { nftnl_sys::nftnl_set_elems_iter_cur(self.iter).is_null() } + { + None + } else { + Some(SetElemsMsg { + set: self.set, + iter: self.iter, + ret: self.ret.clone(), + }) + } + } +} + +impl<'a, K> Drop for SetElemsIter<'a, K> { + fn drop(&mut self) { + unsafe { nftnl_sys::nftnl_set_elems_iter_destroy(self.iter) }; + } +} + +pub struct SetElemsMsg<'a, K> { + set: &'a Set<'a, K>, + iter: *mut nftnl_sys::nftnl_set_elems_iter, + ret: Rc>, +} + +unsafe impl<'a, K> NlMsg for SetElemsMsg<'a, K> { + unsafe fn write(&self, buf: *mut c_void, seq: u32, msg_type: MsgType) { + let (type_, flags) = match msg_type { + MsgType::Add => ( + libc::NFT_MSG_NEWSETELEM, + libc::NLM_F_CREATE | libc::NLM_F_EXCL | libc::NLM_F_ACK, + ), + MsgType::Del => (libc::NFT_MSG_DELSETELEM, libc::NLM_F_ACK), + }; + let header = nftnl_sys::nftnl_nlmsg_build_hdr( + buf as *mut c_char, + type_ as u16, + self.set.get_family() as u16, + flags as u16, + seq, + ); + self.ret + .set(nftnl_sys::nftnl_set_elems_nlmsg_build_payload_iter( + header, self.iter, + )); + } +} + +fn add(set: &Set, key: &K) { + let data = key.data(); + let data_len = data.len() as u32; + unsafe { + let elem = nftnl_sys::nftnl_set_elem_alloc(); + if elem.is_null() { + panic!("oom"); + } + nftnl_sys::nftnl_set_elem_set( + elem, + nftnl_sys::NFTNL_SET_ELEM_KEY as u16, + data.as_ptr() as *const c_void, + data_len / 2, + ); + nftnl_sys::nftnl_set_elem_set_u32( + elem, + nftnl_sys::NFTNL_SET_ELEM_FLAGS as u16, + 1, + ); + nftnl_sys::nftnl_set_elem_add(set.as_ptr(), elem); + + let elem = nftnl_sys::nftnl_set_elem_alloc(); + if elem.is_null() { + panic!("oom"); + } + nftnl_sys::nftnl_set_elem_set( + elem, + nftnl_sys::NFTNL_SET_ELEM_KEY as u16, + data.as_ptr().add((data_len / 2) as usize) as *const c_void, + data_len / 2, + ); + // nftnl_sys::nftnl_set_elem_set_u32( + // elem, + // nftnl_sys::NFTNL_SET_ELEM_FLAGS as u16, + // libc::NFT_SET_ELEM_INTERVAL_END as u32, + // ); + nftnl_sys::nftnl_set_elem_add(set.as_ptr(), elem); + } +} + +#[cfg(test)] +mod test { + use ipnet::Ipv4Net; + use std::{ + ffi::CString, + net::{IpAddr, Ipv4Addr}, + }; + + use super::{add, send_and_process, Cidr, FlushSetMsg, SetElemsIter}; + + #[test] + fn test_nftables() { + let table = nftnl::Table::new( + &CString::from_vec_with_nul(b"test\0".to_vec()).unwrap(), + nftnl::ProtoFamily::Inet, + ); + let mut batch = nftnl::Batch::new(); + let mut set4 = nftnl::set::Set::<_>::new( + &CString::from_vec_with_nul(b"test4\0".to_vec()).unwrap(), + 0, + &table, + nftnl::ProtoFamily::Inet, + ); + batch.add(&FlushSetMsg { set: &set4 }, nftnl::MsgType::Del); + add( + &set4, + &Cidr(Ipv4Net::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap()), + ); + // set4.add(&Ipv4Addr::new(127, 0, 0, 1)); + let mut iter = SetElemsIter::new(&set4); + batch.add_iter(iter, nftnl::MsgType::Add); + send_and_process(&batch.finalize()).unwrap(); + } +} diff --git a/src/unbound.rs b/src/unbound.rs index de7c931..6136836 100644 --- a/src/unbound.rs +++ b/src/unbound.rs @@ -401,6 +401,8 @@ impl PackedRrsetData<'_> { type RrsetIdType = rrset_id_type; #[non_exhaustive] +#[repr(u32)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] pub enum ModuleEvent { /// new query New = 0, @@ -434,6 +436,9 @@ impl From for ModuleEvent { } } +#[non_exhaustive] +#[repr(u32)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] pub enum SecStatus { /// UNCHECKED means that object has yet to be validated. Unchecked = 0, @@ -464,6 +469,9 @@ impl From for SecStatus { } } +#[non_exhaustive] +#[repr(i32)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] pub enum SldnsEdeCode { None = -1, Other = 0, @@ -527,6 +535,9 @@ impl From for SldnsEdeCode { } } +#[non_exhaustive] +#[repr(u32)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] pub enum RrsetTrust { /// Initial value for trust None = 0,