router: riir nftables plugin

This commit is contained in:
chayleaf 2024-08-13 10:42:32 +07:00
parent 03332873d2
commit 63eab7c843
Signed by: chayleaf
GPG key ID: 78171AD46227E68E
5 changed files with 69 additions and 498 deletions

View file

@ -143,6 +143,27 @@
"type": "github"
}
},
"crane_2": {
"inputs": {
"nixpkgs": [
"unbound-rust-mod",
"nixpkgs"
]
},
"locked": {
"lastModified": 1722960479,
"narHash": "sha256-NhCkJJQhD5GUib8zN9JrmYGMwt4lCRp6ZVNzIiYCl0Y=",
"owner": "ipetkov",
"repo": "crane",
"rev": "4c6c77920b8d44cd6660c1621dea6b3fc4b4c4f4",
"type": "github"
},
"original": {
"owner": "ipetkov",
"repo": "crane",
"type": "github"
}
},
"disko": {
"inputs": {
"nixpkgs": [
@ -757,7 +778,8 @@
"notnft": "notnft",
"nur": "nur",
"osu-wine": "osu-wine",
"rust-overlay": "rust-overlay"
"rust-overlay": "rust-overlay",
"unbound-rust-mod": "unbound-rust-mod"
}
},
"rust-overlay": {
@ -908,6 +930,27 @@
"type": "github"
}
},
"unbound-rust-mod": {
"inputs": {
"crane": "crane_2",
"nixpkgs": [
"nixpkgs"
]
},
"locked": {
"lastModified": 1723517492,
"narHash": "sha256-I8+3ZSr/f73TxJmDQbjcbHmZGh7K3PEZRcGqeoXc1fw=",
"ref": "refs/heads/master",
"rev": "00418b649c67129f828182f81ed72c39f262922a",
"revCount": 14,
"type": "git",
"url": "https://git.pavluk.org/chayleaf/unbound-rust-mod.git"
},
"original": {
"type": "git",
"url": "https://git.pavluk.org/chayleaf/unbound-rust-mod.git"
}
},
"utils": {
"inputs": {
"systems": "systems_2"

View file

@ -52,6 +52,10 @@
url = "gitlab:simple-nixos-mailserver/nixos-mailserver";
inputs.nixpkgs.follows = "nixpkgs";
};
unbound-rust-mod = {
url = "git+https://git.pavluk.org/chayleaf/unbound-rust-mod.git";
inputs.nixpkgs.follows = "nixpkgs";
};
flake-compat = {
url = "github:edolstra/flake-compat";
flake = false;
@ -69,6 +73,7 @@
# nixos-router = true;
# notnft = true;
# nixpkgs = true;
# unbound-rust-mod = true;
};
# IRL-related stuff I'd rather not put into git
priv =

View file

@ -32,6 +32,11 @@ in
};
});
inherit (inputs.unbound-rust-mod.packages.${pkgs.system}) unbound-mod;
unbound-full = pkgs.unbound-full.overrideAttrs (old: {
configureFlags = old.configureFlags ++ [ "--with-dynlibmodule" ];
});
buffyboard = pkgs.callPackage ./buffyboard { };
clang-tools_latest = pkgs.clang-tools_16;
clang_latest = pkgs.clang_16;

View file

@ -178,12 +178,8 @@
#
import gi
import ipaddress
import json
import os
import subprocess
import pydbus
import pytricia # type: ignore
import re
import array
import threading
@ -205,22 +201,6 @@ from typing import TypedDict, Optional, Any
IF_UNSPEC = -1
PROTO_UNSPEC = -1
Domains = dict[str, "Domains | bool"]
class NftQuery(TypedDict):
domains: Domains
ips4: pytricia.PyTricia
ips6: pytricia.PyTricia
name4: str
name6: str
dynamic: bool
NFT_QUERIES: dict[str, NftQuery] = {}
# dynamic query update token
NFT_TOKEN: str = ""
DOMAIN_NAME_OVERRIDES: dict[str, str] = {}
DEBUG = False
MDNS_TTL: int
@ -230,8 +210,6 @@ MDNS_REJECT_TYPES: list[RdataType]
MDNS_ACCEPT_TYPES: list[RdataType]
MDNS_REJECT_NAMES: Optional[re.Pattern]
MDNS_ACCEPT_NAMES: Optional[re.Pattern]
REJECT_A: Optional[re.Pattern] = None
REJECT_AAAA: Optional[re.Pattern] = None
sysbus: pydbus.bus.Bus
avahi: Any # pydbus.proxy.ProxyObject
@ -241,22 +219,6 @@ dbus_thread: threading.Thread
dbus_loop: Any
def is_valid_ip4(x: str) -> bool:
try:
_ = ipaddress.IPv4Address(x)
return True
except ipaddress.AddressValueError:
return False
def is_valid_ip6(x: str) -> bool:
try:
_ = ipaddress.IPv6Address(x)
return True
except ipaddress.AddressValueError:
return False
def str2bool(v: str) -> bool:
if v.lower() in ["false", "no", "0", "off", ""]:
return False
@ -502,185 +464,11 @@ def parse_type_list(lst: str) -> list[RdataType]:
)
def build_ipset(ips: list[str]) -> pytricia.PyTricia:
pyt = pytricia.PyTricia()
for ip in ips:
try:
pyt.insert(ip, None)
except:
with open("/var/lib/unbound/error.log", "at") as f:
f.write(f"Warning: couldn't insert ip {ip}:\n")
traceback.print_exc(file=f)
return pyt
IP_Q = pytricia.PyTricia()
IP_Q_LEN = 0
def add_ips(set: str, ipv6: bool, ips: list[str], flush: bool = False):
global IP_Q, IP_Q_LEN
for ip in ips:
try:
IP_Q.insert(ip, None)
except:
with open("/var/lib/unbound/error.log", "at") as f:
f.write(f"Warning 2: couldn't insert ip {ip}:\n")
traceback.print_exc(file=f)
IP_Q_LEN += len(ips)
if IP_Q_LEN < 16:
return
# with open('/var/lib/unbound/info.log', 'at') as f:
# print('set', set, 'ipv6', ipv6, 'ips', ips, file=f)
pyt = IP_Q
IP_Q = pytricia.PyTricia()
ruleset: list[dict] = []
if flush:
ruleset.append(
{"flush": {"set": {"family": "inet", "table": "global", "name": set}}}
)
elems: list[str | dict] = []
if ipv6:
maxn = 128
is_valid: Callable[[str], bool] = is_valid_ip6
else:
maxn = 32
is_valid = is_valid_ip4
for ip in pyt.keys():
try:
if pyt.parent(ip) != None:
continue
except:
pass
if "/" not in ip:
n: int = maxn
else:
ip, n0 = ip.split("/")
try:
n = int(n0)
except:
continue
if not is_valid(ip):
continue
if n == maxn:
elems.append(ip)
else:
elems.append({"prefix": {"addr": ip, "len": n}})
# with open('/var/lib/unbound/info.log', 'at') as f:
# print('elems', elems, file=f)
if len(elems) == 0:
return
ruleset.append(
{
"add": {
"element": {
"family": "inet",
"table": "global",
"name": set,
"elem": elems,
}
}
}
)
data: bytes = json.dumps({"nftables": ruleset}).encode("utf-8")
# with open('/var/lib/unbound/info.log', 'at') as f:
# print('data', data, file=f)
try:
if flush:
out = subprocess.run(
["/run/current-system/sw/bin/nft", "-j", "-f", "/dev/stdin"],
capture_output=True,
input=data,
)
# with open('/var/lib/unbound/info.log', 'at') as f:
# print('out', out, file=f)
if out.returncode != 0:
with open("/var/lib/unbound/nftables.log", "wb") as f:
f.write(b"Error running nftables ruleset. Ruleset:\n")
f.write(data)
f.write(b"\n")
f.write(b"stdout:\n")
f.write(out.stdout)
f.write(b"\nstderr:\n")
f.write(out.stderr)
f.write(b"\n")
else:
proc = subprocess.Popen(
["/run/current-system/sw/bin/nft", "-j", "-f", "/dev/stdin"],
stdin=subprocess.PIPE,
)
assert proc.stdin is not None
proc.stdin.write(data)
proc.stdin.write(b"\n")
proc.stdin.close()
except:
with open("/var/lib/unbound/error.log", "at") as f:
f.write(f"While adding ips for set {set}:\n")
traceback.print_exc(file=f)
def add_split_domain(domains: Domains, split_domain: list[str]):
if not split_domain:
return
split_domain = split_domain[:]
if split_domain and split_domain[-1] == "*":
split_domain.pop()
if not split_domain:
return
while len(split_domain) > 1:
key = split_domain[-1]
if key in domains.keys():
domains1 = domains[key]
if isinstance(domains1, bool):
return
else:
domains1 = {}
domains[key] = domains1
domains = domains1
split_domain.pop()
domains[split_domain[-1]] = True
def build_domains(domains: list[str]) -> Domains:
ret: Domains = {}
for domain in domains:
add_split_domain(ret, domain.split("."))
return ret
def lookup_domain(domains: Domains, domain: str) -> bool:
split_domain: list[str] = domain.split(".")
while len(split_domain):
key: str = split_domain[-1]
split_domain = split_domain[:-1]
domains1 = domains.get(key, False)
if isinstance(domains1, bool):
return domains1
domains = domains1
return False
class DpiInfo(TypedDict):
domains: list[str]
name: str
restriction: dict
def init(*args: Any, **kwargs: Any):
global dbus_thread, DEBUG
global MDNS_TTL, MDNS_GETONE, MDNS_TIMEOUT
global MDNS_REJECT_TYPES, MDNS_ACCEPT_TYPES
global MDNS_REJECT_NAMES, MDNS_ACCEPT_NAMES
global NFT_QUERIES, NFT_TOKEN, DOMAIN_NAME_OVERRIDES
global REJECT_A, REJECT_AAAA
w = os.environ.get("REJECT_A", None)
if w is not None:
REJECT_A = re.compile(w)
w = os.environ.get("REJECT_AAAA", None)
if w is not None:
REJECT_AAAA = re.compile(w)
domain_name_overrides: str = os.environ.get("DOMAIN_NAME_OVERRIDES", "")
if domain_name_overrides:
@ -689,92 +477,6 @@ def init(*args: Any, **kwargs: Any):
DOMAIN_NAME_OVERRIDES[k1] = v1
DOMAIN_NAME_OVERRIDES[k1 + "."] = v1 + "."
NFT_TOKEN = os.environ.get("NFT_TOKEN", "")
nft_queries: str = os.environ.get("NFT_QUERIES", "")
if nft_queries:
for query in nft_queries.split(";"):
name, sets = query.split(":")
dynamic = False
if name.endswith("!"):
name = name.rstrip("!")
dynamic = True
set4, set6 = sets.split(",")
NFT_QUERIES[name] = {
"domains": {},
"ips4": [],
"ips6": [],
"name4": set4,
"name6": set6,
"dynamic": dynamic,
}
for k, v in NFT_QUERIES.items():
all_domains: list[str] = []
for base in ["/etc/unbound", "/var/lib/unbound"]:
try:
with open(f"{base}/{k}_domains.json", "rt", encoding="utf-8") as f:
domains: list[str] = json.load(f)
all_domains.extend(domains)
except FileNotFoundError:
pass
except:
with open("/var/lib/unbound/error.log", "at") as f:
traceback.print_exc(file=f)
try:
with open(f"{base}/{k}_dpi.json", "rt", encoding="utf-8") as f:
dpi: list[DpiInfo] = json.load(f)
for dpi_info in dpi:
all_domains.extend(dpi_info["domains"])
except FileNotFoundError:
pass
except:
with open("/var/lib/unbound/error.log", "at") as f:
traceback.print_exc(file=f)
try:
with open(f"{base}/{k}_ips.json", "rt", encoding="utf-8") as f:
ips: list[str] = json.load(f)
v["ips4"].extend(filter(lambda x: "." in x, ips))
v["ips6"].extend(filter(lambda x: ":" in x, ips))
except FileNotFoundError:
pass
except:
with open("/var/lib/unbound/error.log", "at") as f:
traceback.print_exc(file=f)
v["domains"] = build_domains(all_domains)
# cached resolved domains
try:
os.makedirs("/var/lib/unbound/domains4/", exist_ok=True)
for x in os.listdir("/var/lib/unbound/domains4/"):
with open(f"/var/lib/unbound/domains4/{x}", "rt") as f:
data = f.read().split("\n")
for k, v in NFT_QUERIES.items():
if lookup_domain(v["domains"], x):
v["ips4"].extend(data)
except:
with open("/var/lib/unbound/error.log", "at") as f:
traceback.print_exc(file=f)
try:
os.makedirs("/var/lib/unbound/domains6/", exist_ok=True)
for x in os.listdir("/var/lib/unbound/domains6/"):
with open(f"/var/lib/unbound/domains6/{x}", "rt") as f:
data = f.read().split("\n")
for k, v in NFT_QUERIES.items():
if lookup_domain(v["domains"], x):
v["ips6"].extend(data)
except:
with open("/var/lib/unbound/error.log", "at") as f:
traceback.print_exc(file=f)
# finally, add the ips to nftables
for k, v in NFT_QUERIES.items():
if v["ips4"] and v["name4"]:
add_ips(v["name4"], False, v["ips4"], flush=True)
if v["ips6"] and v["name6"]:
add_ips(v["name6"], True, v["ips6"], flush=True)
v["ips4"] = build_ipset(v["ips4"])
v["ips6"] = build_ipset(v["ips6"])
DEBUG = str2bool(os.environ.get("DEBUG", str(False)))
MDNS_TTL = int(os.environ.get("MDNS_TTL", 120))
@ -788,20 +490,20 @@ def init(*args: Any, **kwargs: Any):
if MDNS_ACCEPT_TYPES:
dbg(f"ONLY resolving the following types via Avahi: {MDNS_ACCEPT_TYPES}")
v2 = os.environ.get("MDNS_REJECT_NAMES", None)
MDNS_REJECT_NAMES = re.compile(v2, flags=re.I | re.S) if v2 is not None else None
v = os.environ.get("MDNS_REJECT_NAMES", None)
MDNS_REJECT_NAMES = re.compile(v, flags=re.I | re.S) if v is not None else None
if MDNS_REJECT_NAMES is not None:
dbg(f"Names NOT resolved via Avahi: {MDNS_REJECT_NAMES.pattern}")
v2 = os.environ.get("MDNS_ACCEPT_NAMES", None)
MDNS_ACCEPT_NAMES = re.compile(v2, flags=re.I | re.S) if v2 is not None else None
v = os.environ.get("MDNS_ACCEPT_NAMES", None)
MDNS_ACCEPT_NAMES = re.compile(v, flags=re.I | re.S) if v is not None else None
if MDNS_ACCEPT_NAMES is not None:
dbg(
f"ONLY resolving the following names via Avahi: {MDNS_ACCEPT_NAMES.pattern}"
)
v2 = os.environ.get("MDNS_TIMEOUT", None)
MDNS_TIMEOUT = int(v2) if v2 is not None else None
v = os.environ.get("MDNS_TIMEOUT", None)
MDNS_TIMEOUT = int(v) if v is not None else None
if MDNS_TIMEOUT is not None:
dbg(f"Avahi request timeout: {MDNS_TIMEOUT}")
@ -861,139 +563,6 @@ def operate(id, event, qstate, qdata) -> bool:
n2: str = name.rstrip(".")
if NFT_TOKEN and n2.endswith(f"{NFT_TOKEN}"):
if n2.endswith(f".{NFT_TOKEN}"):
n3 = n2.removesuffix(f".{NFT_TOKEN}")
for k, v in NFT_QUERIES.items():
if v["dynamic"] and n3.endswith(f".{k}"):
n4 = n3.removesuffix(f".{k}")
qdomains = v["domains"]
if not lookup_domain(qdomains, n4):
add_split_domain(qdomains, n4.split("."))
old = []
if os.path.exists(f"/var/lib/unbound/{k}_domains.json"):
with open(f"/var/lib/unbound/{k}_domains.json", "rt") as f:
old = json.load(f)
os.rename(
f"/var/lib/unbound/{k}_domains.json",
f"/var/lib/unbound/{k}_domains.json.bak",
)
old.append(n4)
with open(f"/var/lib/unbound/{k}_domains.json", "wt") as f:
json.dump(old, f)
elif n2.endswith(f".tmp{NFT_TOKEN}"):
n3 = n2.removesuffix(f".tmp{NFT_TOKEN}")
for k, v in NFT_QUERIES.items():
if v["dynamic"] and n3.endswith(f".{k}"):
n4 = n3.removesuffix(f".{k}")
qdomains = v["domains"]
if not lookup_domain(qdomains, n4):
add_split_domain(qdomains, n4.split("."))
return True
qnames: list[str] = []
for k, v in NFT_QUERIES.items():
if lookup_domain(v["domains"], n2):
qnames.append(k)
# THIS IS PAIN
if qnames:
try:
ip4: list[str] = []
ip6: list[str] = []
if qstate.return_msg and qstate.return_msg.rep:
rep = qstate.return_msg.rep
for i in range(rep.rrset_count):
d = rep.rrsets[i].entry.data
rk = rep.rrsets[i].rk
# IN
if rk.rrset_class != 256:
continue
for j in range(0, d.count + d.rrsig_count):
wire = array.array("B", d.rr_data[j]).tobytes()
# A, AAAA
if (
rk.type == 256
and len(wire) == 4 + 2
and wire[:2] == b"\x00\x04"
):
ip4.append(".".join(str(x) for x in wire[2:]))
elif (
rk.type == 7168
and len(wire) == 16 + 2
and wire[:2] == b"\x00\x10"
):
b = list(hex(x)[2:].zfill(2) for x in wire[2:])
ip6.append(
":".join(
"".join(b[x : x + 2]) for x in range(0, len(b), 2)
)
)
changed4 = False
changed6 = False
if ip4:
new_data = "\n".join(sorted(ip4))
try:
with open("/var/lib/unbound/domains4/" + n2, "rt") as f:
old_data = f.read()
except:
old_data = ""
if old_data != new_data:
changed4 = True
with open("/var/lib/unbound/domains4/" + n2, "wt") as f:
f.write(new_data)
if ip6:
new_data = "\n".join(sorted(ip6))
try:
with open("/var/lib/unbound/domains6/" + n2, "rt") as f:
old_data = f.read()
except:
old_data = ""
if old_data != new_data:
changed6 = True
with open("/var/lib/unbound/domains6/" + n2, "wt") as f:
f.write(new_data)
if changed4:
for qname in qnames:
q = NFT_QUERIES[qname]
name4 = q["name4"]
ips4 = q["ips4"]
if name4:
ip2 = []
for ip in ip4:
exists = False
try:
if ips4.has_key(ip) or ips4.parent(ip) != None:
exists = True
except:
pass
if not exists:
ips4.insert(ip, None)
ip2.append(ip)
if ip2:
add_ips(name4, False, ip2)
if changed6:
for qname in qnames:
q = NFT_QUERIES[qname]
name6 = q["name6"]
ips6 = q["ips6"]
if name6:
ip2 = []
for ip in ip6:
exists = False
try:
if ips6.has_key(ip) or ips6.parent(ip) != None:
exists = True
except:
pass
if not exists:
ips6.insert(ip, None)
ip2.append(ip)
if ip2:
add_ips(name6, True, ip2)
except:
with open("/var/lib/unbound/error.log", "at") as f:
traceback.print_exc(file=f)
if event == MODULE_EVENT_NEW or event == MODULE_EVENT_PASS:
qstate.ext_state[id] = MODULE_WAIT_MODULE
return True
@ -1005,57 +574,6 @@ def operate(id, event, qstate, qdata) -> bool:
qstate.ext_state[id] = MODULE_FINISHED
rej_a = REJECT_A and REJECT_A.match(n2)
rej_aaaa = REJECT_AAAA and REJECT_AAAA.match(n2)
if rej_a or rej_aaaa:
if qstate.return_msg and qstate.return_msg.rep:
rep = qstate.return_msg.rep
have_other = False
changed = False
msg = DNSMessage(
qstate.qinfo.qname_str,
qstate.qinfo.qtype,
qstate.qinfo.qclass,
qstate.query_flags,
)
for i in range(rep.rrset_count):
d = rep.rrsets[i].entry.data
rk = rep.rrsets[i].rk
if rk.rrset_class == 256 and (
rej_a and rk.type == 256 or rej_aaaa and rk.type == 7168
):
changed = True
continue
if rk.rrset_class == 256 and (
rej_aaaa
and not rej_a
and rk.type == 256
or rej_a
and not rej_aaaa
and rk.type == 7168
):
have_other = True
# IN
for j in range(0, d.count):
if rk.type == 256 and rej_a:
continue
elif rk.type == 7168 and rej_aaaa:
continue
msg.answer.append(
rr2text(
(rk.dname_str, rk.rrset_class, rk.type, d.rr_data[j]), d.ttl
)
)
if changed and not have_other:
# reject
qstate.ext_state[id] = MODULE_ERROR
return True
elif changed:
# replace
if not msg.set_return_msg(qstate):
qstate.ext_state[id] = MODULE_ERROR
return True
# Only resolve via Avahi if we got NXDOMAIN from the upstream DNS
# server, or if we could not reach the upstream DNS server. If we
# got some records for the name from the upstream DNS server
@ -1119,13 +637,11 @@ def operate(id, event, qstate, qdata) -> bool:
if not m.set_return_msg(qstate):
raise Exception("Error in set_return_msg")
# For some reason this breaks everything! Unbound responds with SERVFAIL instead of using the cache
# i.e. the first response is fine, but loading it from cache just doesn't work
# Resolution via Avahi works fast anyway so whatever
# if not storeQueryInCache(qstate, qstate.return_msg.qinfo, qstate.return_msg.rep, 0):
# raise Exception("Error in storeQueryInCache")
qstate.return_msg.rep.security = 2
if not storeQueryInCache(qstate, qstate.return_msg.qinfo, qstate.return_msg.rep, 0):
raise Exception("Error in storeQueryInCache")
qstate.return_rcode = Rcode.NOERROR
return True

View file

@ -4,7 +4,8 @@
, lib
, router-lib
, server-config
, ... }:
, ...
}:
let
cfg = config.router-settings;
@ -857,7 +858,7 @@ in {
access-control = [ "${netCidrs.netns4} allow" "${netCidrs.netns6} allow" "${netCidrs.lan4} allow" "${netCidrs.lan6} allow" ];
aggressive-nsec = true;
do-ip6 = true;
module-config = ''"validator python iterator"'';
module-config = ''"validator dynlib python iterator"'';
local-zone = [
# incompatible with avahi resolver
# ''"local." static''
@ -889,6 +890,7 @@ in {
# normally it would refer to the flake path, but then the service changes on every flake update
# instead, write a new file in nix store
python.python-script = builtins.toFile "avahi-resolver-v2.py" (builtins.readFile ./avahi-resolver-v2.py);
dynlib.dynlib-file = "${pkgs.unbound-mod}/lib/libunbound_mod.so";
remote-control.control-enable = true;
};
};
@ -908,7 +910,7 @@ in {
networking.hosts."${serverAddress6}" = hosted-domains;
systemd.services.unbound = lib.mkIf config.services.unbound.enable {
environment.PYTHONPATH = let
unbound-python = pkgs.python3.withPackages (ps: with ps; [ pydbus dnspython requests pytricia nftables ]);
unbound-python = pkgs.python3.withPackages (ps: with ps; [ pydbus dnspython ]);
in
"${unbound-python}/${unbound-python.sitePackages}";
# see https://github.com/NixOS/nixpkgs/pull/310514