use std::{
    collections::HashMap,
    mem,
    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
    sync::{
        atomic::{AtomicU32, Ordering},
        Arc,
    },
    time::Duration,
};

use async_shared_timeout::Timeout;
use bytes::BytesMut;
use chacha20poly1305::{aead::Buffer, AeadCore, AeadInPlace, ChaCha20Poly1305, KeyInit};
use dashmap::DashMap;
use tokio::{net::UdpSocket, sync::mpsc};
use typenum::Unsigned;

/// A byte buffer the AEAD cipher can encrypt or decrypt in place, backed either by an
/// owned [`BytesMut`] or by a mutably borrowed `Vec<u8>`.
#[derive(Debug)]
enum BMut<'a> {
    /// Owned storage.
    Bytes(BytesMut),
    /// Storage borrowed from elsewhere (e.g. a reusable receive buffer).
    Vec(&'a mut Vec<u8>),
}

impl BMut<'_> {
    /// Views the current contents as a shared slice.
    fn as_slice(&self) -> &[u8] {
        match self {
            BMut::Bytes(owned) => &owned[..],
            BMut::Vec(borrowed) => &borrowed[..],
        }
    }

    /// Views the current contents as a mutable slice.
    fn as_mut_slice(&mut self) -> &mut [u8] {
        match self {
            BMut::Bytes(owned) => &mut owned[..],
            BMut::Vec(borrowed) => &mut borrowed[..],
        }
    }
}

/// Growing and shrinking forward to the backing storage; growing never reports an error.
impl Buffer for BMut<'_> {
    fn extend_from_slice(&mut self, other: &[u8]) -> chacha20poly1305::aead::Result<()> {
        match self {
            BMut::Bytes(owned) => owned.extend_from_slice(other),
            BMut::Vec(borrowed) => borrowed.extend_from_slice(other),
        }
        Ok(())
    }

    fn truncate(&mut self, len: usize) {
        match self {
            BMut::Bytes(owned) => owned.truncate(len),
            BMut::Vec(borrowed) => borrowed.truncate(len),
        }
    }
}

impl AsRef<[u8]> for BMut<'_> {
    fn as_ref(&self) -> &[u8] {
        self.as_slice()
    }
}

impl AsMut<[u8]> for BMut<'_> {
    fn as_mut(&mut self) -> &mut [u8] {
        self.as_mut_slice()
    }
}

/// Where one direction of a session reads its incoming datagrams from.
enum ChanOrSock {
    /// Datagrams handed over through a channel.
    Chan(mpsc::Receiver<BytesMut>),
    /// A socket read directly, plus a receive buffer reused across calls.
    Sock(Arc<UdpSocket>, Vec<u8>),
}

impl ChanOrSock {
    /// Waits for the next datagram.
    ///
    /// Returns `None` once the channel is closed, or as soon as the socket reports any
    /// receive error. For the socket variant the returned buffer borrows `self`'s
    /// scratch `Vec`, sized to exactly the received datagram.
    async fn recv(&mut self) -> Option<BMut<'_>> {
        let (sock, scratch) = match self {
            Self::Chan(rx) => return rx.recv().await.map(BMut::Bytes),
            Self::Sock(sock, scratch) => (sock, scratch),
        };
        // Make room for the largest possible datagram, then shrink to what arrived.
        scratch.resize(65536, 0);
        match sock.recv(scratch).await {
            Ok(received) => {
                scratch.truncate(received);
                Some(BMut::Vec(scratch))
            }
            Err(_) => None,
        }
    }
}

/// Process-wide session counter; each relay session takes the next value and embeds
/// (part of) it in the nonces of the datagrams it encrypts.
static SESSION_ID: AtomicU32 = AtomicU32::new(0);
/// Wire-format version, written as the first nonce byte of every encrypted datagram.
const PROTOCOL_VERSION: u8 = 0;

async fn thread(
    (mut rx_enc, sock_enc, enc_addr): (ChanOrSock, Arc<UdpSocket>, SocketAddr),
    (mut rx_dec, sock_dec, dec_addr): (ChanOrSock, Arc<UdpSocket>, SocketAddr),
    server: bool,
    chans: Arc<DashMap<SocketAddr, mpsc::Sender<BytesMut>>>,
    cipher_remote: ChaCha20Poly1305,
    cipher_local: ChaCha20Poly1305,
) {
    let timeout = Timeout::new(
        async_shared_timeout::runtime::Tokio::new(),
        Duration::from_secs(1),
    );
    let session_id = SESSION_ID.fetch_add(1, Ordering::SeqCst).to_be_bytes();
    let remote_to_local = async {
        let mut timeout_set = false;
        while let Some(mut buf) = rx_enc.recv().await {
            if buf.len() < 12 {
                continue;
            }
            let nonce = <[u8; 12]>::try_from(&buf.as_ref()[buf.len() - 12..]).unwrap();
            let time = chrono::Utc::now().timestamp_nanos_opt().unwrap();
            let nonce1 = [
                nonce[4], nonce[5], nonce[6], nonce[7], nonce[8], nonce[9], nonce[10], nonce[11],
            ];
            if i64::from_be_bytes(nonce1).abs_diff(time) > 10_000_000_000 {
                eprintln!("got invalid nonce");
                continue;
            }
            buf.truncate(buf.len() - 12);
            if let Err(err) = cipher_remote.decrypt_in_place((&nonce).into(), b"", &mut buf) {
                eprintln!("decrypt error: {err}");
                continue;
            }
            if server {
                if !timeout_set {
                    timeout_set = true;
                    timeout.set_default_timeout(Duration::from_secs(90));
                }
                timeout.reset();
            }

            if let Err(err) = sock_dec.as_ref().send_to(buf.as_ref(), dec_addr).await {
                eprintln!("decrypted send error: {err}");
            }
        }
    };
    let local_to_remote = async {
        let mut timeout_set = false;
        while let Some(mut buf) = rx_dec.recv().await {
            let time = chrono::Utc::now()
                .timestamp_nanos_opt()
                .unwrap()
                .to_be_bytes();
            let nonce = [
                PROTOCOL_VERSION,
                session_id[0],
                session_id[1],
                session_id[2],
                time[0],
                time[1],
                time[2],
                time[3],
                time[4],
                time[5],
                time[6],
                time[7],
            ];
            if let Err(err) = cipher_local.encrypt_in_place((&nonce).into(), b"", &mut buf) {
                eprintln!("encrypt error: {err}");
                continue;
            }
            buf.extend_from_slice(&nonce).unwrap();
            if !server {
                if !timeout_set {
                    timeout_set = true;
                    timeout.set_default_timeout(Duration::from_secs(90));
                }
                timeout.reset();
            }
            if let Err(err) = sock_enc.as_ref().send_to(buf.as_ref(), enc_addr).await {
                eprintln!("encrypted send error: {err}");
            }
        }
    };
    tokio::select! {
        _ = timeout.wait() => {}
        _ = remote_to_local => {}
        _ = local_to_remote => {}
    }
    chans.remove(&enc_addr);
}

#[tokio::main]
async fn main() {
    assert_eq!(<ChaCha20Poly1305 as AeadCore>::NonceSize::to_usize(), 12);
    let mut args = std::env::args().skip(1);
    let data = std::fs::read_to_string(args.next().unwrap())
        .unwrap()
        .lines()
        .map(|x| {
            x.split('#')
                .next()
                .unwrap()
                .chars()
                .filter(|x| !x.is_whitespace())
                .collect::<String>()
        })
        .filter(|x| !x.is_empty())
        .map(|x| {
            let (a, b) = x.split_once('=').unwrap_or((&x, ""));
            (a.to_owned(), b.to_owned())
        })
        .collect::<HashMap<_, _>>();
    let ver: u8 = data.get("version").unwrap().parse().unwrap();
    assert!(matches!(ver, 0));
    let src: SocketAddr = data.get("bind").unwrap().parse().unwrap();
    let dst: SocketAddr = data.get("connect").unwrap().parse().unwrap();
    let h2b = |x: &str| {
        x.as_bytes()
            .chunks_exact(2)
            .map(|x| u8::from_str_radix(std::str::from_utf8(x).unwrap(), 16).unwrap())
            .collect::<Box<[u8]>>()
    };
    let key_local = h2b(data.get("key_local").unwrap());
    let key_remote = h2b(data.get("key_remote").unwrap());
    let server = data.contains_key("server");
    let cipher_local = ChaCha20Poly1305::new_from_slice(&key_local).unwrap();
    let cipher_remote = ChaCha20Poly1305::new_from_slice(&key_remote).unwrap();
    let sock = Arc::new(UdpSocket::bind(src).await.unwrap());
    let chans = Arc::new(DashMap::<SocketAddr, mpsc::Sender<_>>::new());
    let mut buf = BytesMut::with_capacity(65536);
    loop {
        buf.resize(65536, 0);
        let Ok((len, addr)) = sock.recv_from(&mut buf).await else {
            continue;
        };
        if let Some(ch) = chans.get(&addr) {
            if ch.send(buf.split_to(len)).await.is_ok() {
                continue;
            }
        }
        let (tx, rx) = mpsc::channel(128);
        tx.send(buf.split_to(len)).await.unwrap();
        chans.insert(addr, tx);
        let sock2 = Arc::new(
            UdpSocket::bind((
                if dst.is_ipv6() {
                    IpAddr::V6(Ipv6Addr::UNSPECIFIED)
                } else {
                    IpAddr::V4(Ipv4Addr::UNSPECIFIED)
                },
                0,
            ))
            .await
            .unwrap(),
        );
        sock2.connect(dst).await.unwrap();
        let (mut enc, mut dec) = (
            (ChanOrSock::Sock(sock2.clone(), vec![]), sock2, dst),
            (ChanOrSock::Chan(rx), sock.clone(), addr),
        );
        if server {
            mem::swap(&mut enc, &mut dec);
        }
        tokio::spawn(thread(
            enc,
            dec,
            server,
            chans.clone(),
            cipher_remote.clone(),
            cipher_local.clone(),
        ));
    }
}