// UDP tunnel endpoint that seals datagrams with ChaCha20-Poly1305.
//
// Usage: <binary> <config file>. The config consists of `key=value` lines (`#`
// starts a comment, all whitespace is ignored) with the keys `version`, `bind`,
// `connect`, `key_local`, `key_remote` (hex-encoded) and an optional bare
// `server` flag.
//
// Every distinct peer that sends to `bind` gets its own session: a fresh UDP
// socket connected to `connect` plus a task (`thread`) relaying both ways.
// Without `server`, packets from `bind` peers are encrypted with `key_local`
// and forwarded to `connect`, and replies are decrypted with `key_remote`;
// with `server` the two directions are swapped.
//
// Encrypted datagram layout: ciphertext || tag || 12-byte nonce, where
// nonce = protocol version (1 byte) || session id (3 bytes) || send time in
// nanoseconds since the Unix epoch (i64, big-endian, 8 bytes).

use std::{
    collections::HashMap,
    mem,
    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
    sync::{
        atomic::{AtomicU32, Ordering},
        Arc,
    },
    time::Duration,
};

use async_shared_timeout::Timeout;
use bytes::BytesMut;
use chacha20poly1305::{aead::Buffer, AeadCore, AeadInPlace, ChaCha20Poly1305, KeyInit};
use dashmap::DashMap;
use tokio::{net::UdpSocket, sync::mpsc};
use typenum::Unsigned;

/// Length in bytes of the nonce appended to every encrypted datagram.
const NONCE_LEN: usize = 12;
/// Size of the receive buffers; the largest datagram read in one call.
const MAX_DATAGRAM: usize = 65536;
/// Largest accepted distance (ns) between a received nonce timestamp and our clock.
const MAX_CLOCK_SKEW_NS: u64 = 10_000_000_000;
/// Largest backwards wall-clock step (ns) a session absorbs by continuing its own
/// monotonic nonce clock; a larger step ends the session instead.
const MAX_BACKWARD_STEP_NS: u64 = 1_000_000_000;
/// Session lifetime until the first packet on the channel-fed side is processed.
const INITIAL_TIMEOUT: Duration = Duration::from_secs(1);
/// Idle lifetime of a session once it has carried traffic.
const IDLE_TIMEOUT: Duration = Duration::from_secs(90);
/// Packets buffered per session between the accept loop and the session task.
const CHANNEL_CAPACITY: usize = 128;

/// A mutable byte buffer that is either an owned `BytesMut` (a packet handed
/// over through a channel) or a borrowed, reusable `Vec<u8>` (a packet read
/// straight from a socket), so both can be sealed/opened in place.
#[derive(Debug)]
enum BMut<'a> {
    Bytes(BytesMut),
    Vec(&'a mut Vec<u8>),
}

impl Buffer for BMut<'_> {
    fn extend_from_slice(&mut self, other: &[u8]) -> chacha20poly1305::aead::Result<()> {
        match self {
            Self::Bytes(x) => x.extend_from_slice(other),
            Self::Vec(x) => x.extend_from_slice(other),
        }
        Ok(())
    }

    fn truncate(&mut self, len: usize) {
        match self {
            Self::Bytes(x) => x.truncate(len),
            Self::Vec(x) => x.truncate(len),
        }
    }
}

impl AsRef<[u8]> for BMut<'_> {
    fn as_ref(&self) -> &[u8] {
        match self {
            Self::Bytes(x) => x.as_ref(),
            Self::Vec(x) => x.as_ref(),
        }
    }
}

impl AsMut<[u8]> for BMut<'_> {
    fn as_mut(&mut self) -> &mut [u8] {
        match self {
            Self::Bytes(x) => x.as_mut(),
            Self::Vec(x) => x.as_mut(),
        }
    }
}

/// Source of inbound packets for one direction of a session.
enum ChanOrSock {
    /// Packets forwarded by the accept loop in `main` for one peer address.
    Chan(mpsc::Receiver<BytesMut>),
    /// Packets read directly from a socket into a reusable buffer.
    Sock(Arc<UdpSocket>, Vec<u8>),
}

impl ChanOrSock {
    /// Waits for the next packet.
    ///
    /// Returns `None` when the channel is closed or the socket read fails; the
    /// caller's receive loop ends in both cases.
    async fn recv(&mut self) -> Option<BMut<'_>> {
        match self {
            Self::Chan(rx) => rx.recv().await.map(BMut::Bytes),
            Self::Sock(sock, buf) => {
                buf.resize(MAX_DATAGRAM, 0);
                let len = sock.recv(buf).await.ok()?;
                buf.truncate(len);
                Some(BMut::Vec(buf))
            }
        }
    }
}

/// Hands out a distinct id to each session; its low 24 bits go into every nonce.
static SESSION_ID: AtomicU32 = AtomicU32::new(0);
/// Wire-format version written into the first nonce byte and required in the config.
const PROTOCOL_VERSION: u8 = 0;

/// Current wall-clock time in nanoseconds since the Unix epoch.
fn now_nanos() -> i64 {
    chrono::Utc::now()
        .timestamp_nanos_opt()
        .expect("system clock outside the range representable as i64 nanoseconds")
}

/// Builds `[version, session id (low 24 bits, BE), timestamp (i64 ns, BE)]`.
fn build_nonce(session_id: u32, time_ns: i64) -> [u8; NONCE_LEN] {
    let sid = session_id.to_be_bytes();
    let mut nonce = [0u8; NONCE_LEN];
    nonce[0] = PROTOCOL_VERSION;
    // Big-endian, so bytes 1..4 are the LOW 24 bits: these differ between
    // consecutive sessions (the high byte stays 0 for the first 2^24 sessions).
    nonce[1..4].copy_from_slice(&sid[1..4]);
    nonce[4..].copy_from_slice(&time_ns.to_be_bytes());
    nonce
}

/// Extracts the sender's timestamp (ns) from a nonce laid out like `build_nonce`.
fn nonce_timestamp(nonce: &[u8; NONCE_LEN]) -> i64 {
    let mut time = [0u8; 8];
    time.copy_from_slice(&nonce[4..]);
    i64::from_be_bytes(time)
}

/// Whether a nonce timestamp lies within `MAX_CLOCK_SKEW_NS` of `now_ns`.
///
/// This bounds how long a captured packet can be replayed but does not
/// prevent replays inside the window.
fn is_fresh(sent_ns: i64, now_ns: i64) -> bool {
    sent_ns.abs_diff(now_ns) <= MAX_CLOCK_SKEW_NS
}

/// Picks the timestamp for the next outgoing nonce of a session so that it is
/// strictly greater than `last`, keeping nonces unique even when the clock is
/// coarse or steps backwards slightly.
///
/// Returns `None` when the clock has stepped back by more than
/// `MAX_BACKWARD_STEP_NS`; the session should then end so that the next one
/// uses a fresh session id instead of drifting ahead of the receiver's window.
fn next_nonce_time(now: i64, last: i64) -> Option<i64> {
    if now > last {
        Some(now)
    } else if last.abs_diff(now) <= MAX_BACKWARD_STEP_NS {
        last.checked_add(1)
    } else {
        None
    }
}

/// Parses config text into a key/value map.
///
/// `#` starts a comment, all whitespace is removed, empty lines are skipped,
/// `key=value` maps key to value (split at the first `=`), a bare `key` maps to
/// an empty string, and later duplicates overwrite earlier ones.
fn parse_config(text: &str) -> HashMap<String, String> {
    text.lines()
        .map(|line| {
            line.split('#')
                .next()
                .unwrap_or_default()
                .chars()
                .filter(|c| !c.is_whitespace())
                .collect::<String>()
        })
        .filter(|line| !line.is_empty())
        .map(|line| match line.split_once('=') {
            Some((key, value)) => (key.to_owned(), value.to_owned()),
            None => (line, String::new()),
        })
        .collect()
}

/// Looks up a required config key.
///
/// # Panics
/// If the key is absent.
fn config_value<'a>(config: &'a HashMap<String, String>, key: &str) -> &'a str {
    config
        .get(key)
        .map(String::as_str)
        .unwrap_or_else(|| panic!("missing config key `{key}`"))
}

/// Decodes pairs of hex digits into bytes; a trailing unpaired digit is ignored.
///
/// # Panics
/// If any paired character is not a hexadecimal digit.
fn decode_hex(hex: &str) -> Box<[u8]> {
    let digit = |b: u8| char::from(b).to_digit(16).expect("key must be hexadecimal");
    hex.as_bytes()
        .chunks_exact(2)
        // Two hex digits are at most 0xff, so the narrowing cast is lossless.
        .map(|pair| ((digit(pair[0]) << 4) | digit(pair[1])) as u8)
        .collect()
}

/// Opens a UDP socket on an ephemeral port of `dst`'s address family and
/// connects it to `dst`.
async fn connect_upstream(dst: SocketAddr) -> std::io::Result<UdpSocket> {
    let any: IpAddr = if dst.is_ipv6() {
        Ipv6Addr::UNSPECIFIED.into()
    } else {
        Ipv4Addr::UNSPECIFIED.into()
    };
    let sock = UdpSocket::bind((any, 0)).await?;
    sock.connect(dst).await?;
    Ok(sock)
}

/// Sends `data` to `addr`. On a connected socket (whose peer is `addr`) plain
/// `send` is used, because BSD/macOS reject `sendto` with an address on a
/// connected UDP socket (EISCONN).
async fn send_packet(
    sock: &UdpSocket,
    connected: bool,
    data: &[u8],
    addr: SocketAddr,
) -> std::io::Result<usize> {
    if connected {
        sock.send(data).await
    } else {
        sock.send_to(data, addr).await
    }
}

/// Runs one tunnel session until it times out or either direction stops.
///
/// * `enc` — the encrypted side. Packets from its source are freshness-checked,
///   opened with `cipher_remote` and sent through `dec`'s socket to `dec`'s address.
/// * `dec` — the plaintext side. Packets from its source are sealed with
///   `cipher_local` (nonce appended) and sent through `enc`'s socket to `enc`'s address.
///
/// Exactly one of the two sources is the channel fed by `main`; only its
/// traffic keeps the session alive, and its peer address is removed from
/// `chans` when the session ends.
async fn thread(
    (mut rx_enc, sock_enc, enc_addr): (ChanOrSock, Arc<UdpSocket>, SocketAddr),
    (mut rx_dec, sock_dec, dec_addr): (ChanOrSock, Arc<UdpSocket>, SocketAddr),
    server: bool,
    chans: Arc<DashMap<SocketAddr, mpsc::Sender<BytesMut>>>,
    cipher_remote: ChaCha20Poly1305,
    cipher_local: ChaCha20Poly1305,
) {
    // `main` keys `chans` by the peer feeding the channel side: in server mode
    // that side was swapped into `enc`, otherwise it is `dec`.
    let chan_addr = if server { enc_addr } else { dec_addr };
    let enc_connected = sock_enc.peer_addr().is_ok();
    let dec_connected = sock_dec.peer_addr().is_ok();

    let timeout = Timeout::new(async_shared_timeout::runtime::Tokio::new(), INITIAL_TIMEOUT);
    let session_id = SESSION_ID.fetch_add(1, Ordering::SeqCst);

    // Extends the session: switches to the long idle timeout on first use,
    // then restarts the countdown.
    let keep_alive = |armed: &mut bool| {
        if !*armed {
            *armed = true;
            timeout.set_default_timeout(IDLE_TIMEOUT);
        }
        timeout.reset();
    };

    let remote_to_local = async {
        let mut timeout_armed = false;
        while let Some(mut buf) = rx_enc.recv().await {
            // Too short to even carry a nonce.
            let Some(body_len) = buf.len().checked_sub(NONCE_LEN) else {
                continue;
            };
            let nonce = <[u8; NONCE_LEN]>::try_from(&buf.as_ref()[body_len..])
                .expect("tail slice is exactly NONCE_LEN bytes");
            if !is_fresh(nonce_timestamp(&nonce), now_nanos()) {
                eprintln!("got invalid nonce");
                continue;
            }
            buf.truncate(body_len);
            if let Err(err) = cipher_remote.decrypt_in_place((&nonce).into(), b"", &mut buf) {
                eprintln!("decrypt error: {err}");
                continue;
            }
            // In server mode this direction is the channel-fed side.
            if server {
                keep_alive(&mut timeout_armed);
            }
            if let Err(err) = send_packet(&sock_dec, dec_connected, buf.as_ref(), dec_addr).await {
                eprintln!("decrypted send error: {err}");
            }
        }
    };

    let local_to_remote = async {
        let mut timeout_armed = false;
        let mut last_time = i64::MIN;
        while let Some(mut buf) = rx_dec.recv().await {
            let Some(time) = next_nonce_time(now_nanos(), last_time) else {
                eprintln!("system clock stepped backwards; ending session");
                break;
            };
            last_time = time;
            let nonce = build_nonce(session_id, time);
            if let Err(err) = cipher_local.encrypt_in_place((&nonce).into(), b"", &mut buf) {
                eprintln!("encrypt error: {err}");
                continue;
            }
            buf.extend_from_slice(&nonce)
                .expect("BMut::extend_from_slice never fails");
            // In client mode this direction is the channel-fed side.
            if !server {
                keep_alive(&mut timeout_armed);
            }
            if let Err(err) = send_packet(&sock_enc, enc_connected, buf.as_ref(), enc_addr).await {
                eprintln!("encrypted send error: {err}");
            }
        }
    };

    tokio::select! {
        _ = timeout.wait() => {}
        _ = remote_to_local => {}
        _ = local_to_remote => {}
    }
    // Runs before the receivers are dropped, so `main` cannot have installed a
    // replacement entry for this peer yet.
    chans.remove(&chan_addr);
}

/// Reads the config, binds `bind`, and demultiplexes incoming datagrams into
/// per-peer sessions, spawning a new session for each unseen (or expired) peer.
#[tokio::main]
async fn main() {
    assert_eq!(<ChaCha20Poly1305 as AeadCore>::NonceSize::to_usize(), NONCE_LEN);

    let config_path = std::env::args()
        .nth(1)
        .expect("usage: pass the config file path as the first argument");
    let data = parse_config(
        &std::fs::read_to_string(config_path).expect("cannot read config file"),
    );

    let ver: u8 = config_value(&data, "version")
        .parse()
        .expect("`version` must be an integer 0-255");
    assert_eq!(ver, PROTOCOL_VERSION, "unsupported protocol version");
    let src: SocketAddr = config_value(&data, "bind")
        .parse()
        .expect("`bind` must be a socket address");
    let dst: SocketAddr = config_value(&data, "connect")
        .parse()
        .expect("`connect` must be a socket address");
    let cipher_local =
        ChaCha20Poly1305::new_from_slice(&decode_hex(config_value(&data, "key_local")))
            .expect("`key_local` must be 32 bytes (64 hex digits)");
    let cipher_remote =
        ChaCha20Poly1305::new_from_slice(&decode_hex(config_value(&data, "key_remote")))
            .expect("`key_remote` must be 32 bytes (64 hex digits)");
    let server = data.contains_key("server");

    let sock = Arc::new(UdpSocket::bind(src).await.expect("cannot bind `bind` address"));
    let chans = Arc::new(DashMap::<SocketAddr, mpsc::Sender<BytesMut>>::new());
    let mut buf = BytesMut::with_capacity(MAX_DATAGRAM);

    loop {
        buf.resize(MAX_DATAGRAM, 0);
        let Ok((len, addr)) = sock.recv_from(&mut buf).await else {
            continue;
        };
        let mut packet = buf.split_to(len);

        // Clone the sender out so no DashMap shard guard is held across `.await`.
        let existing = chans.get(&addr).map(|entry| entry.value().clone());
        if let Some(tx) = existing {
            match tx.send(packet).await {
                Ok(()) => continue,
                // That session has ended: take the packet back and start a new one.
                Err(err) => packet = err.0,
            }
        }

        let sock2 = match connect_upstream(dst).await {
            Ok(s) => Arc::new(s),
            Err(err) => {
                eprintln!("upstream socket error: {err}");
                continue;
            }
        };

        let (tx, rx) = mpsc::channel(CHANNEL_CAPACITY);
        tx.send(packet)
            .await
            .expect("receiver was just created and is still alive");
        chans.insert(addr, tx);

        let (mut enc, mut dec) = (
            (ChanOrSock::Sock(Arc::clone(&sock2), Vec::new()), sock2, dst),
            (ChanOrSock::Chan(rx), Arc::clone(&sock), addr),
        );
        if server {
            mem::swap(&mut enc, &mut dec);
        }
        tokio::spawn(thread(
            enc,
            dec,
            server,
            Arc::clone(&chans),
            cipher_remote.clone(),
            cipher_local.clone(),
        ));
    }
}