def is_valid_ip4(x: str) -> bool:
    """Return True iff *x* parses as an IPv4 address.

    Only ``AddressValueError`` counts as "not an address"; any other
    exception propagates to the caller.
    """
    try:
        ipaddress.IPv4Address(x)
    except ipaddress.AddressValueError:
        return False
    return True
def str2bool(v: str) -> bool:
    """Interpret an environment-variable string as a boolean flag.

    Only a small set of explicit "off" spellings (case-insensitive) and
    the empty string count as False; every other value is True.
    """
    return v.lower() not in ("false", "no", "0", "off", "")
name[:-1] else: - self.overrides = { } + self.overrides = {} - self.timer = None if timeout is None else GLib.timeout_add(timeout, self.timedOut) + self.timer = ( + None if timeout is None else GLib.timeout_add(timeout, self.timedOut) + ) - self.browser_path = avahi.RecordBrowserNew(IF_UNSPEC, PROTO_UNSPEC, name1, dns.rdataclass.IN, type_, 0) + self.browser_path: str = avahi.RecordBrowserNew( + IF_UNSPEC, PROTO_UNSPEC, name1, dns.rdataclass.IN, type_, 0 + ) trampoline[self.browser_path] = self - self.browser = sysbus.get('.Avahi', self.browser_path) - self.dbg('Created RecordBrowser(name=%s, type=%s, getone=%s, timeout=%s)' - % (name1, dns.rdatatype.to_text(type_), getone, timeout)) + self.browser = sysbus.get(".Avahi", self.browser_path) + self.dbg( + f"Created RecordBrowser(name={name1}, type={dns.rdatatype.to_text(type_)}, getone={getone}, timeout={timeout})" + ) - def dbg(self, msg): - dbg('[%s] %s' % (self.browser_path, msg)) + def dbg(self, msg: str): + dbg(f"[{self.browser_path}] {msg}") - def _done(self): + def _done(self) -> None: del trampoline[self.browser_path] - self.dbg('Freeing') + self.dbg("Freeing") self.browser.Free() if self.timer is not None: - self.dbg('Removing timer') + self.dbg("Removing timer") GLib.source_remove(self.timer) self.callback(self.records, self.error) - def itemNew(self, interface, protocol, name, class_, type_, rdata, flags): - self.dbg('Got signal ItemNew') - self.records.append((self.overrides.get(name, name), class_, type_, rdata)) + def itemNew( + self, + interface: int, + protocol: int, + name: str, + class_: int, + type_: int, + rdata: bytes, + flags: int, + ): + self.dbg("Got signal ItemNew") + self.records.append( + ( + self.overrides.get(name, name), + RdataClass(class_), + RdataType(type_), + rdata, + ) + ) if self.getone: self._done() - def itemRemove(self, interface, protocol, name, class_, type_, rdata, flags): - self.dbg('Got signal ItemRemove') - self.records.remove((self.overrides.get(name, name), class_, 
type_, rdata)) + def itemRemove( + self, + interface: int, + protocol: int, + name: str, + class_: int, + type_: int, + rdata: bytes, + flags: int, + ): + self.dbg("Got signal ItemRemove") + self.records.remove( + ( + self.overrides.get(name, name), + RdataClass(class_), + RdataType(type_), + rdata, + ) + ) - def failure(self, error): - self.dbg('Got signal Failure') + def failure(self, error: str): + self.dbg("Got signal Failure") self.error = Exception(error) self._done() - def allForNow(self): - self.dbg('Got signal AllForNow') + def allForNow(self) -> None: + self.dbg("Got signal AllForNow") if self.timer is None: self._done() - def timedOut(self): - self.dbg('Timed out') + def timedOut(self) -> bool: + self.dbg("Timed out") self._done() return False @@ -326,23 +408,30 @@ class RecordBrowser: # This function runs the main event loop for DBus (GLib). This # function must be run in a dedicated worker thread. # -def dbus_main(): +def dbus_main() -> None: global sysbus, avahi, dbus_loop - dbg('Connecting to system DBus') + dbg("Connecting to system DBus") sysbus = SystemBus() - dbg('Subscribing to .Avahi.RecordBrowser signals') - sysbus.con.signal_subscribe('org.freedesktop.Avahi', - 'org.freedesktop.Avahi.RecordBrowser', - None, None, None, 0, signal_dispatcher) + dbg("Subscribing to .Avahi.RecordBrowser signals") + sysbus.con.signal_subscribe( + "org.freedesktop.Avahi", + "org.freedesktop.Avahi.RecordBrowser", + None, + None, + None, + 0, + signal_dispatcher, + ) - avahi = sysbus.get('.Avahi', '/') + avahi = sysbus.get(".Avahi", "/") - dbg("Connected to Avahi Daemon: %s (API %s) [%s]" - % (avahi.GetVersionString(), avahi.GetAPIVersion(), avahi.GetHostNameFqdn())) + dbg( + f"Connected to Avahi Daemon: {avahi.GetVersionString()} (API {avahi.GetAPIVersion()}) [{avahi.GetHostNameFqdn()}]" + ) - dbg('Starting DBus main loop') + dbg("Starting DBus main loop") dbus_loop = GLib.MainLoop() dbus_loop.run() @@ -352,9 +441,20 @@ def dbus_main(): # new RecordBrowser instance 
def start_resolver(
    queue: "Queue[tuple[list[tuple[str, RdataClass, RdataType, bytes]], None] | tuple[None, Exception]]",
    name: str,
    type_: "RdataType",
    timeout: Optional[int] = None,
    getone: bool = True,
) -> bool:
    """Create a RecordBrowser on the DBus (GLib) thread for *name*/*type_*.

    Scheduled via ``GLib.idle_add`` from ``resolve``; the browser delivers
    its records (or error) back to the waiting resolver thread through
    *queue*.

    Returns False so that, when invoked as a GLib idle callback, the idle
    source is removed after a single invocation.  (The previous revision
    was annotated ``-> bool`` but implicitly returned None; None is falsy
    too, so behavior is unchanged — the annotation is now truthful.)
    """
    try:
        RecordBrowser(lambda *v: queue.put_nowait(v), name, type_, timeout, getone)
    except Exception as e:
        # Report the failure to the waiting thread instead of letting it
        # block forever on an empty queue.
        queue.put_nowait((None, e))
    return False
open('/var/lib/unbound/error.log', 'at') as f: - f.write(f'Warning: couldn\'t insert ip {ip}:\n') + with open("/var/lib/unbound/error.log", "at") as f: + f.write(f"Warning: couldn't insert ip {ip}:\n") traceback.print_exc(file=f) return pyt -def add_ips(set: str, ipv6: bool, ips: list, flush: bool = False): - #with open('/var/lib/unbound/info.log', 'at') as f: - #print('set', set, 'ipv6', ipv6, 'ips', ips, file=f) + +def add_ips(set: str, ipv6: bool, ips: list[str], flush: bool = False): + # with open('/var/lib/unbound/info.log', 'at') as f: + # print('set', set, 'ipv6', ipv6, 'ips', ips, file=f) pyt = build_ipset(ips) - ruleset: list = [ ] + ruleset: list[dict] = [] if flush: - ruleset.append({"flush":{"set":{"family":"inet","table":"global","name":set}}}) - elems: list = [] + ruleset.append( + {"flush": {"set": {"family": "inet", "table": "global", "name": set}}} + ) + elems: list[str | dict] = [] if ipv6: maxn = 128 - is_valid = is_valid_ip6 + is_valid: Callable[[str], bool] = is_valid_ip6 else: maxn = 32 is_valid = is_valid_ip4 @@ -420,10 +534,10 @@ def add_ips(set: str, ipv6: bool, ips: list, flush: bool = False): continue except: pass - if '/' not in ip: - n = maxn + if "/" not in ip: + n: int = maxn else: - ip, n0 = ip.split('/') + ip, n0 = ip.split("/") try: n = int(n0) except: @@ -433,308 +547,374 @@ def add_ips(set: str, ipv6: bool, ips: list, flush: bool = False): if n == maxn: elems.append(ip) else: - elems.append({"prefix":{"addr":ip,"len":n}}) - #with open('/var/lib/unbound/info.log', 'at') as f: - #print('elems', elems, file=f) + elems.append({"prefix": {"addr": ip, "len": n}}) + # with open('/var/lib/unbound/info.log', 'at') as f: + # print('elems', elems, file=f) if len(elems) == 0: return - ruleset.append({"add":{"element":{"family":"inet","table":"global","name":set,"elem":elems}}}) - data = json.dumps({"nftables":ruleset}).encode('utf-8') - #with open('/var/lib/unbound/info.log', 'at') as f: - #print('data', data, file=f) + ruleset.append( + { + 
def add_split_domain(domains: "Domains", split_domain: list[str]) -> None:
    """Insert a domain (pre-split into labels) into the nested domain tree.

    The tree is keyed from the TLD downwards: ``["www", "example", "com"]``
    becomes ``{"com": {"example": {"www": {"__IsTrue__": True}}}}``.  The
    ``"__IsTrue__"`` marker on the final node distinguishes a real entry
    from a node that exists only as a parent of deeper entries.

    Mutates *domains* in place.

    Bug fix: the previous revision advanced the wrong variable
    (``domains = domains1[key]`` instead of the cursor ``domains1``), so
    the walk never descended — every label was created at the root and the
    marker was set on the root node.
    """
    node: dict = domains
    while split_domain:
        # Consume labels from the end (TLD first), creating intermediate
        # nodes as needed, and descend with the traversal cursor.
        key = split_domain[-1]
        if key not in node:
            node[key] = {}
        node = node[key]
        split_domain = split_domain[:-1]
    node["__IsTrue__"] = True
def lookup_domain(domains: "Domains", domain: str) -> bool:
    """Return True if *domain* is covered by the nested domain tree.

    Walks the tree label by label from the TLD downwards.  A ``"*"`` child
    carrying the ``"__IsTrue__"`` marker at any level is a wildcard that
    matches all remaining labels; an exact entry matches when the walk
    consumes every label and lands on a node carrying the marker itself
    (a trailing wildcard check also lets ``*.X`` cover a query for ``X``).

    Bug fix: after the walk, the previous revision re-read the tree root
    (``domains``) instead of the node it had walked to (``domains1``), so
    exact matches always failed.  Both tail checks now use the final node,
    restoring the pre-refactor behavior.
    """
    labels: list[str] = domain.split(".")
    node: dict = domains
    while labels:
        # A wildcard at the current level matches any remaining labels.
        star = node.get("*")
        if star is not None and star.get("__IsTrue__", False):
            return True
        node = node.get(labels[-1])
        if node is None:
            # No subtree for this label: the domain is not listed.
            return False
        labels = labels[:-1]
    # All labels consumed: inspect the node we actually walked to.
    star = node.get("*")
    if star is not None and star.get("__IsTrue__", False):
        return True
    return bool(node.get("__IsTrue__", False))
+ for kv in domain_name_overrides.split(";"): + k1, v1 = kv.split("->") + DOMAIN_NAME_OVERRIDES[k1] = v1 + DOMAIN_NAME_OVERRIDES[k1 + "."] = v1 + "." - NFT_TOKEN = os.environ.get('NFT_TOKEN', '') - nft_queries = os.environ.get('NFT_QUERIES', '') + NFT_TOKEN = os.environ.get("NFT_TOKEN", "") + nft_queries: str = os.environ.get("NFT_QUERIES", "") if nft_queries: - for query in nft_queries.split(';'): - name, sets = query.split(':') + for query in nft_queries.split(";"): + name, sets = query.split(":") dynamic = False - if name.endswith('!'): - name = name.rstrip('!') + if name.endswith("!"): + name = name.rstrip("!") dynamic = True - set4, set6 = sets.split(',') - NFT_QUERIES[name] = { 'domains': [], 'ips4': [], 'ips6': [], 'name4': set4, 'name6': set6, 'dynamic': dynamic } + set4, set6 = sets.split(",") + NFT_QUERIES[name] = { + "domains": {}, + "ips4": [], + "ips6": [], + "name4": set4, + "name6": set6, + "dynamic": dynamic, + } for k, v in NFT_QUERIES.items(): - try: - domains = json.load(open(f'/etc/unbound/{k}_domains.json', 'rt', encoding='utf-8')) - v['domains'].extend(domains) - except: - pass - try: - domains = json.load(open(f'/var/lib/unbound/{k}_domains.json', 'rt', encoding='utf-8')) - v['domains'].extend(domains) - except: - pass - v['domains'] = build_domains(v['domains']) - try: - ips = json.load(open(f'/etc/unbound/{k}_ips.json', 'rt', encoding='utf-8')) - v['ips4'].extend(filter(lambda x: '.' in x, ips)) - v['ips6'].extend(filter(lambda x: ':' in x, ips)) - except: - pass - try: - ips = json.load(open(f'/var/lib/unbound/{k}_ips.json', 'rt', encoding='utf-8')) - v['ips4'].extend(filter(lambda x: '.' 
in x, ips)) - v['ips6'].extend(filter(lambda x: ':' in x, ips)) - except: - pass + all_domains: list[str] = [] + for base in ["/etc/unbound", "/var/lib/unbound"]: + try: + with open(f"{base}/{k}_domains.json", "rt", encoding="utf-8") as f: + domains: list[str] = json.load(f) + all_domains.extend(domains) + except FileNotFoundError: + pass + except: + with open("/var/lib/unbound/error.log", "at") as f: + traceback.print_exc(file=f) + try: + with open(f"{base}/{k}_dpi.json", "rt", encoding="utf-8") as f: + dpi: list[DpiInfo] = json.load(f) + for dpi_info in dpi: + all_domains.extend(dpi_info.get("domains", [])) + except FileNotFoundError: + pass + except: + with open("/var/lib/unbound/error.log", "at") as f: + traceback.print_exc(file=f) + try: + with open(f"{base}/{k}_ips.json", "rt", encoding="utf-8") as f: + ips: list[str] = json.load(f) + v["ips4"].extend(filter(lambda x: "." in x, ips)) + v["ips6"].extend(filter(lambda x: ":" in x, ips)) + except FileNotFoundError: + pass + except: + with open("/var/lib/unbound/error.log", "at") as f: + traceback.print_exc(file=f) + v["domains"] = build_domains(all_domains) # cached resolved domains try: - os.makedirs('/var/lib/unbound/domains4/', exist_ok=True) - for x in os.listdir('/var/lib/unbound/domains4/'): - with open('/var/lib/unbound/domains4/' + x, 'rt') as f: - data = f.read().split('\n') + os.makedirs("/var/lib/unbound/domains4/", exist_ok=True) + for x in os.listdir("/var/lib/unbound/domains4/"): + with open(f"/var/lib/unbound/domains4/{x}", "rt") as f: + data = f.read().split("\n") for k, v in NFT_QUERIES.items(): - if lookup_domain(v['domains'], x): - v['ips4'].extend(data) + if lookup_domain(v["domains"], x): + v["ips4"].extend(data) except: - with open('/var/lib/unbound/error.log', 'at') as f: + with open("/var/lib/unbound/error.log", "at") as f: traceback.print_exc(file=f) try: - os.makedirs('/var/lib/unbound/domains6/', exist_ok=True) - for x in os.listdir('/var/lib/unbound/domains6/'): - with 
open('/var/lib/unbound/domains6/' + x, 'rt') as f: - data = f.read().split('\n') + os.makedirs("/var/lib/unbound/domains6/", exist_ok=True) + for x in os.listdir("/var/lib/unbound/domains6/"): + with open(f"/var/lib/unbound/domains6/{x}", "rt") as f: + data = f.read().split("\n") for k, v in NFT_QUERIES.items(): - if lookup_domain(v['domains'], x): - v['ips6'].extend(data) + if lookup_domain(v["domains"], x): + v["ips6"].extend(data) except: - with open('/var/lib/unbound/error.log', 'at') as f: + with open("/var/lib/unbound/error.log", "at") as f: traceback.print_exc(file=f) # finally, add the ips to nftables for k, v in NFT_QUERIES.items(): - if v['ips4'] and v['name4']: - add_ips(v['name4'], False, v['ips4'], flush=True) - if v['ips6'] and v['name6']: - add_ips(v['name6'], True, v['ips6'], flush=True) - v['ips4'] = build_ipset(v['ips4']) - v['ips6'] = build_ipset(v['ips6']) + if v["ips4"] and v["name4"]: + add_ips(v["name4"], False, v["ips4"], flush=True) + if v["ips6"] and v["name6"]: + add_ips(v["name6"], True, v["ips6"], flush=True) + v["ips4"] = build_ipset(v["ips4"]) + v["ips6"] = build_ipset(v["ips6"]) - DEBUG = str2bool(os.environ.get('DEBUG', str(False))) + DEBUG = str2bool(os.environ.get("DEBUG", str(False))) - MDNS_TTL = int(os.environ.get('MDNS_TTL', 120)) - dbg("TTL for records from Avahi: %d" % MDNS_TTL) + MDNS_TTL = int(os.environ.get("MDNS_TTL", 120)) + dbg(f"TTL for records from Avahi: {MDNS_TTL}") - MDNS_REJECT_TYPES = parse_type_list(os.environ.get('MDNS_REJECT_TYPES', '')) + MDNS_REJECT_TYPES = parse_type_list(os.environ.get("MDNS_REJECT_TYPES", "")) if MDNS_REJECT_TYPES: - dbg('Types NOT resolved via Avahi: %s' % MDNS_REJECT_TYPES) + dbg(f"Types NOT resolved via Avahi: {MDNS_REJECT_TYPES}") - MDNS_ACCEPT_TYPES = parse_type_list(os.environ.get('MDNS_ACCEPT_TYPES', '')) + MDNS_ACCEPT_TYPES = parse_type_list(os.environ.get("MDNS_ACCEPT_TYPES", "")) if MDNS_ACCEPT_TYPES: - dbg('ONLY resolving the following types via Avahi: %s' % MDNS_ACCEPT_TYPES) 
+ dbg(f"ONLY resolving the following types via Avahi: {MDNS_ACCEPT_TYPES}") - v = os.environ.get('MDNS_REJECT_NAMES', None) - MDNS_REJECT_NAMES = re.compile(v, flags=re.I | re.S) if v is not None else None + v2 = os.environ.get("MDNS_REJECT_NAMES", None) + MDNS_REJECT_NAMES = re.compile(v2, flags=re.I | re.S) if v2 is not None else None if MDNS_REJECT_NAMES is not None: - dbg('Names NOT resolved via Avahi: %s' % MDNS_REJECT_NAMES.pattern) + dbg(f"Names NOT resolved via Avahi: {MDNS_REJECT_NAMES.pattern}") - v = os.environ.get('MDNS_ACCEPT_NAMES', None) - MDNS_ACCEPT_NAMES = re.compile(v, flags=re.I | re.S) if v is not None else None + v2 = os.environ.get("MDNS_ACCEPT_NAMES", None) + MDNS_ACCEPT_NAMES = re.compile(v2, flags=re.I | re.S) if v2 is not None else None if MDNS_ACCEPT_NAMES is not None: - dbg('ONLY resolving the following names via Avahi: %s' % MDNS_ACCEPT_NAMES.pattern) + dbg( + f"ONLY resolving the following names via Avahi: {MDNS_ACCEPT_NAMES.pattern}" + ) - v = os.environ.get('MDNS_TIMEOUT', None) - MDNS_TIMEOUT = int(v) if v is not None else None + v2 = os.environ.get("MDNS_TIMEOUT", None) + MDNS_TIMEOUT = int(v2) if v2 is not None else None if MDNS_TIMEOUT is not None: - dbg('Avahi request timeout: %s' % MDNS_TIMEOUT) + dbg(f"Avahi request timeout: {MDNS_TIMEOUT}") - MDNS_GETONE = str2bool(os.environ.get('MDNS_GETONE', str(True))) - dbg('Terminate Avahi requests on first record: %s' % MDNS_GETONE) + MDNS_GETONE = str2bool(os.environ.get("MDNS_GETONE", str(True))) + dbg(f"Terminate Avahi requests on first record: {MDNS_GETONE}") dbus_thread = threading.Thread(target=dbus_main) dbus_thread.daemon = True dbus_thread.start() -def deinit(*args, **kwargs): +def deinit(*args, **kwargs) -> bool: dbus_loop.quit() dbus_thread.join() return True -def inform_super(id, qstate, superqstate, qdata): +def inform_super(id, qstate, superqstate, qdata) -> bool: return True -def get_rcode(msg): +MODULE_EVENT_NEW: int +MODULE_EVENT_PASS: int +MODULE_WAIT_MODULE: int 
+MODULE_EVENT_MODDONE: int +MODULE_ERROR: int +MODULE_FINISHED: int +PKT_QR: int +PKT_RD: int +PKT_RA: int +DNSMessage: Callable + + +def get_rcode(msg) -> Rcode: if not msg: - return RCODE_SERVFAIL + return Rcode.SERVFAIL - return msg.rep.flags & 0xf + return Rcode(msg.rep.flags & 0xF) -def rr2text(rec, ttl): +def rr2text(rec: tuple[str, RdataClass, RdataType, bytes], ttl: int) -> str: name, class_, type_, rdata = rec - wire = array.array('B', rdata).tobytes() - return '%s. %d %s %s %s' % ( - name, - ttl, - dns.rdataclass.to_text(class_), - dns.rdatatype.to_text(type_), - dns.rdata.from_wire(class_, type_, wire, 0, len(wire), None)) + wire = array.array("B", rdata).tobytes() + return f"{name}. {ttl} {dns.rdataclass.to_text(class_)} {dns.rdatatype.to_text(type_)} {dns.rdata.from_wire(class_, type_, wire, 0, len(wire), None)}" -def operate(id, event, qstate, qdata): + +def operate(id, event, qstate, qdata) -> bool: global NFT_QUERIES, NFT_TOKEN qi = qstate.qinfo - name = qi.qname_str - type_ = qi.qtype - type_str = dns.rdatatype.to_text(type_) - class_ = qi.qclass - class_str = dns.rdataclass.to_text(class_) - rc = get_rcode(qstate.return_msg) + name: str = qi.qname_str + type_: RdataType = qi.qtype + type_str: str = dns.rdatatype.to_text(type_) + class_: RdataClass = qi.qclass + class_str: str = dns.rdataclass.to_text(class_) + rc: Rcode = get_rcode(qstate.return_msg) - n2 = name.rstrip('.') + n2: str = name.rstrip(".") - if NFT_TOKEN and n2.endswith(f'{NFT_TOKEN}'): - if n2.endswith(f'.{NFT_TOKEN}'): - n3 = n2.removesuffix(f'.{NFT_TOKEN}') + if NFT_TOKEN and n2.endswith(f"{NFT_TOKEN}"): + if n2.endswith(f".{NFT_TOKEN}"): + n3 = n2.removesuffix(f".{NFT_TOKEN}") for k, v in NFT_QUERIES.items(): - if v['dynamic'] and n3.endswith(f'.{k}'): - n4 = n3.removesuffix(f'.{k}') - qdomains = v['domains'] + if v["dynamic"] and n3.endswith(f".{k}"): + n4 = n3.removesuffix(f".{k}") + qdomains = v["domains"] if not lookup_domain(qdomains, n4): - add_split_domain(qdomains, ['*'] + 
n4.split('.')) + add_split_domain(qdomains, ["*"] + n4.split(".")) old = [] - if os.path.exists(f'/var/lib/unbound/{k}_domains.json'): - with open(f'/var/lib/unbound/{k}_domains.json', 'rt') as f: + if os.path.exists(f"/var/lib/unbound/{k}_domains.json"): + with open(f"/var/lib/unbound/{k}_domains.json", "rt") as f: old = json.load(f) - os.rename(f'/var/lib/unbound/{k}_domains.json', f'/var/lib/unbound/{k}_domains.json.bak') - old.append('*.' + n4) - with open(f'/var/lib/unbound/{k}_domains.json', 'wt') as f: + os.rename( + f"/var/lib/unbound/{k}_domains.json", + f"/var/lib/unbound/{k}_domains.json.bak", + ) + old.append("*." + n4) + with open(f"/var/lib/unbound/{k}_domains.json", "wt") as f: json.dump(old, f) - elif n2.endswith(f'.tmp{NFT_TOKEN}'): - n3 = n2.removesuffix(f'.tmp{NFT_TOKEN}') + elif n2.endswith(f".tmp{NFT_TOKEN}"): + n3 = n2.removesuffix(f".tmp{NFT_TOKEN}") for k, v in NFT_QUERIES.items(): - if v['dynamic'] and n3.endswith(f'.{k}'): - n4 = n3.removesuffix(f'.{k}') - qdomains = v['domains'] + if v["dynamic"] and n3.endswith(f".{k}"): + n4 = n3.removesuffix(f".{k}") + qdomains = v["domains"] if not lookup_domain(qdomains, n4): - add_split_domain(qdomains, ['*'] + n4.split('.')) + add_split_domain(qdomains, ["*"] + n4.split(".")) return True - qnames = [] + qnames: list[str] = [] for k, v in NFT_QUERIES.items(): - if lookup_domain(v['domains'], n2): + if lookup_domain(v["domains"], n2): qnames.append(k) # THIS IS PAIN if qnames: try: - ip4 = [] - ip6 = [] + ip4: list[str] = [] + ip6: list[str] = [] if qstate.return_msg and qstate.return_msg.rep: rep = qstate.return_msg.rep for i in range(rep.rrset_count): d = rep.rrsets[i].entry.data rk = rep.rrsets[i].rk for j in range(0, d.count + d.rrsig_count): - wire = array.array('B', d.rr_data[j]).tobytes() + wire = array.array("B", d.rr_data[j]).tobytes() # IN - if rk.rrset_class != 256: continue + if rk.rrset_class != 256: + continue # A, AAAA - if rk.type == 256 and len(wire) == 4+2 and wire[:2] == 
b'\x00\x04': - ip4.append('.'.join(str(x) for x in wire[2:])) - elif rk.type == 7168 and len(wire) == 16+2 and wire[:2] == b'\x00\x10': + if ( + rk.type == 256 + and len(wire) == 4 + 2 + and wire[:2] == b"\x00\x04" + ): + ip4.append(".".join(str(x) for x in wire[2:])) + elif ( + rk.type == 7168 + and len(wire) == 16 + 2 + and wire[:2] == b"\x00\x10" + ): b = list(hex(x)[2:].zfill(2) for x in wire[2:]) - ip6.append(':'.join(''.join(b[x:x+2]) for x in range(0, len(b), 2))) + ip6.append( + ":".join( + "".join(b[x : x + 2]) for x in range(0, len(b), 2) + ) + ) changed4 = False changed6 = False if ip4: - new_data = '\n'.join(sorted(ip4)) + new_data = "\n".join(sorted(ip4)) try: - with open('/var/lib/unbound/domains4/' + n2, 'rt') as f: + with open("/var/lib/unbound/domains4/" + n2, "rt") as f: old_data = f.read() except: - old_data = '' + old_data = "" if old_data != new_data: changed4 = True - with open('/var/lib/unbound/domains4/' + n2, 'wt') as f: + with open("/var/lib/unbound/domains4/" + n2, "wt") as f: f.write(new_data) if ip6: - new_data = '\n'.join(sorted(ip6)) + new_data = "\n".join(sorted(ip6)) try: - with open('/var/lib/unbound/domains6/' + n2, 'rt') as f: + with open("/var/lib/unbound/domains6/" + n2, "rt") as f: old_data = f.read() except: - old_data = '' + old_data = "" if old_data != new_data: changed6 = True - with open('/var/lib/unbound/domains6/' + n2, 'wt') as f: + with open("/var/lib/unbound/domains6/" + n2, "wt") as f: f.write(new_data) if changed4: for qname in qnames: q = NFT_QUERIES[qname] - name4 = q['name4'] - ips4 = q['ips4'] + name4 = q["name4"] + ips4 = q["ips4"] if name4: ip2 = [] for ip in ip4: @@ -752,8 +932,8 @@ def operate(id, event, qstate, qdata): if changed6: for qname in qnames: q = NFT_QUERIES[qname] - name6 = q['name6'] - ips6 = q['ips6'] + name6 = q["name6"] + ips6 = q["ips6"] if name6: ip2 = [] for ip in ip6: @@ -769,7 +949,7 @@ def operate(id, event, qstate, qdata): if ip2: add_ips(name6, True, ip2) except: - with 
open('/var/lib/unbound/error.log', 'at') as f: + with open("/var/lib/unbound/error.log", "at") as f: traceback.print_exc(file=f) if event == MODULE_EVENT_NEW or event == MODULE_EVENT_PASS: @@ -783,64 +963,64 @@ def operate(id, event, qstate, qdata): qstate.ext_state[id] = MODULE_FINISHED - # Only resolve via Avahi if we got NXDOMAIn from the upstream DNS + # Only resolve via Avahi if we got NXDOMAIN from the upstream DNS # server, or if we could not reach the upstream DNS server. If we # got some records for the name from the upstream DNS server # already, do not resolve the record in Avahi. - if rc != RCODE_NXDOMAIN and rc != RCODE_SERVFAIL: + if rc != Rcode.NXDOMAIN and rc != Rcode.SERVFAIL: return True - dbg("Got request for '%s %s %s'" % (name, class_str, type_str)) + dbg(f"Got request for '{name} {class_str} {type_str}'") # Avahi only supports the IN class - if class_ != RR_CLASS_IN: - dbg('Rejected, Avahi only supports the IN class') + if class_ != RdataClass.IN: + dbg("Rejected, Avahi only supports the IN class") return True # Avahi does not support meta queries (e.g., ANY) if dns.rdatatype.is_metatype(type_): - dbg('Rejected, Avahi does not support the type %s' % type_str) + dbg(f"Rejected, Avahi does not support the type {type_str}") return True # If we have a type blacklist and the requested type is on the # list, reject it. if MDNS_REJECT_TYPES and type_ in MDNS_REJECT_TYPES: - dbg('Rejected, type %s is on the blacklist' % type_str) + dbg(f"Rejected, type {type_str} is on the blacklist") return True # If we have a type whitelist and if the requested type is not on # the list, reject it. if MDNS_ACCEPT_TYPES and type_ not in MDNS_ACCEPT_TYPES: - dbg('Rejected, type %s is not on the whitelist' % type_str) + dbg(f"Rejected, type {type_str} is not on the whitelist") return True # If we have a name blacklist and if the requested name matches # the blacklist, reject it. 
if MDNS_REJECT_NAMES is not None: if MDNS_REJECT_NAMES.search(name): - dbg('Rejected, name %s is on the blacklist' % name) + dbg(f"Rejected, name {name} is on the blacklist") return True # If we have a name whitelist and if the requested name does not # match the whitelist, reject it. if MDNS_ACCEPT_NAMES is not None: if not MDNS_ACCEPT_NAMES.search(name): - dbg('Rejected, name %s is not on the whitelist' % name) + dbg(f"Rejected, name {name} is not on the whitelist") return True - dbg("Resolving '%s %s %s' via Avahi" % (name, class_str, type_str)) + dbg(f"Resolving '{name} {class_str} {type_str}' via Avahi") recs = resolve(name, type_, getone=MDNS_GETONE, timeout=MDNS_TIMEOUT) if not recs: - dbg('Result: Not found (NXDOMAIN)') - qstate.return_rcode = RCODE_NXDOMAIN + dbg("Result: Not found (NXDOMAIN)") + qstate.return_rcode = Rcode.NXDOMAIN return True m = DNSMessage(name, type_, class_, PKT_QR | PKT_RD | PKT_RA) for r in recs: s = rr2text(r, MDNS_TTL) - dbg('Result: %s' % s) + dbg(f"Result: {s}") m.answer.append(s) if not m.set_return_msg(qstate): @@ -849,11 +1029,11 @@ def operate(id, event, qstate, qdata): # For some reason this breaks everything! Unbound responds with SERVFAIL instead of using the cache # i.e. the first response is fine, but loading it from cache just doesn't work # Resolution via Avahi works fast anyway so whatever - #if not storeQueryInCache(qstate, qstate.return_msg.qinfo, qstate.return_msg.rep, 0): + # if not storeQueryInCache(qstate, qstate.return_msg.qinfo, qstate.return_msg.rep, 0): # raise Exception("Error in storeQueryInCache") qstate.return_msg.rep.security = 2 - qstate.return_rcode = RCODE_NOERROR + qstate.return_rcode = Rcode.NOERROR return True @@ -864,42 +1044,43 @@ def operate(id, event, qstate, qdata): # run in interactive mode. 
# try: - import unboundmodule + import unboundmodule # type: ignore + embedded = True except ImportError: embedded = False -if __name__ == '__main__' and not embedded: +if __name__ == "__main__" and not embedded: import sys def log_info(msg): print(msg) def log_err(msg): - print('ERROR: %s' % msg, file=sys.stderr) + print(f"ERROR: {msg}", file=sys.stderr) if len(sys.argv) != 3: - print('Usage: %s ' % sys.argv[0]) + print(f"Usage: {sys.argv[0]} ") sys.exit(2) name = sys.argv[1] type_str = sys.argv[2] try: - type_ = dns.rdatatype.from_text(type_str) + type_: RdataType = dns.rdatatype.from_text(type_str) except dns.rdatatype.UnknownRdatatype: - log_err('Unsupported DNS record type "%s"' % type_str) + log_err(f'Unsupported DNS record type "{type_str}"') sys.exit(2) if dns.rdatatype.is_metatype(type_): - log_err('Meta record type "%s" cannot be resolved via Avahi' % type_str) + log_err(f'Meta record type "{type_str}" cannot be resolved via Avahi') sys.exit(2) init() try: recs = resolve(name, type_, getone=MDNS_GETONE, timeout=MDNS_TIMEOUT) if not len(recs): - print('%s not found (NXDOMAIN)' % name) + print(f"{name} not found (NXDOMAIN)") sys.exit(1) for r in recs: diff --git a/system/hosts/router/default.nix b/system/hosts/router/default.nix index 390ba63..7d56c44 100644 --- a/system/hosts/router/default.nix +++ b/system/hosts/router/default.nix @@ -894,10 +894,12 @@ in { # fetch vpn_ips.json and vpn_domains.json for unbound script = '' BLACKLIST=$(${pkgs.coreutils}/bin/mktemp) || exit 1 - ${pkgs.curl}/bin/curl "https://reestr.rublacklist.net/api/v2/ips/json/" -o "$BLACKLIST" || exit 1 + ${pkgs.curl}/bin/curl "https://reestr.rublacklist.net/api/v3/ips/" -o "$BLACKLIST" || exit 1 ${pkgs.jq}/bin/jq ".[0:0]" "$BLACKLIST" && chown unbound:unbound "$BLACKLIST" && mv "$BLACKLIST" /var/lib/unbound/vpn_ips.json - ${pkgs.curl}/bin/curl "https://reestr.rublacklist.net/api/v2/domains/json/" -o "$BLACKLIST" || exit 1 + ${pkgs.curl}/bin/curl 
"https://reestr.rublacklist.net/api/v3/domains/" -o "$BLACKLIST" || exit 1 ${pkgs.jq}/bin/jq ".[0:0]" "$BLACKLIST" && chown unbound:unbound "$BLACKLIST" && mv "$BLACKLIST" /var/lib/unbound/vpn_domains.json + ${pkgs.curl}/bin/curl "https://reestr.rublacklist.net/api/v3/dpi/" -o "$BLACKLIST" || exit 1 + ${pkgs.jq}/bin/jq ".[0:0]" "$BLACKLIST" && chown unbound:unbound "$BLACKLIST" && mv "$BLACKLIST" /var/lib/unbound/vpn_dpi.json ''; serviceConfig = { Type = "oneshot";