253 lines
7.9 KiB
Rust
253 lines
7.9 KiB
Rust
|
use std::{
|
||
|
collections::HashMap,
|
||
|
mem,
|
||
|
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
|
||
|
sync::{
|
||
|
atomic::{AtomicU32, Ordering},
|
||
|
Arc,
|
||
|
},
|
||
|
time::Duration,
|
||
|
};
|
||
|
|
||
|
use async_shared_timeout::Timeout;
|
||
|
use bytes::BytesMut;
|
||
|
use chacha20poly1305::{aead::Buffer, AeadCore, AeadInPlace, ChaCha20Poly1305, KeyInit};
|
||
|
use dashmap::DashMap;
|
||
|
use tokio::{net::UdpSocket, sync::mpsc};
|
||
|
use typenum::Unsigned;
|
||
|
|
||
|
#[derive(Debug)]
|
||
|
enum BMut<'a> {
|
||
|
Bytes(BytesMut),
|
||
|
Vec(&'a mut Vec<u8>),
|
||
|
}
|
||
|
|
||
|
impl<'a> Buffer for BMut<'a> {
|
||
|
fn extend_from_slice(&mut self, other: &[u8]) -> chacha20poly1305::aead::Result<()> {
|
||
|
match self {
|
||
|
Self::Bytes(x) => x.extend_from_slice(other),
|
||
|
Self::Vec(x) => x.extend_from_slice(other),
|
||
|
}
|
||
|
Ok(())
|
||
|
}
|
||
|
fn truncate(&mut self, len: usize) {
|
||
|
match self {
|
||
|
Self::Bytes(x) => x.truncate(len),
|
||
|
Self::Vec(x) => x.truncate(len),
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
impl<'a> AsRef<[u8]> for BMut<'a> {
|
||
|
fn as_ref(&self) -> &[u8] {
|
||
|
match self {
|
||
|
Self::Bytes(x) => x.as_ref(),
|
||
|
Self::Vec(x) => x.as_ref(),
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
impl<'a> AsMut<[u8]> for BMut<'a> {
|
||
|
fn as_mut(&mut self) -> &mut [u8] {
|
||
|
match self {
|
||
|
Self::Bytes(x) => x.as_mut(),
|
||
|
Self::Vec(x) => x.as_mut(),
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
enum ChanOrSock {
|
||
|
Chan(mpsc::Receiver<BytesMut>),
|
||
|
Sock(Arc<UdpSocket>, Vec<u8>),
|
||
|
}
|
||
|
|
||
|
impl ChanOrSock {
|
||
|
async fn recv(&mut self) -> Option<BMut<'_>> {
|
||
|
match self {
|
||
|
Self::Chan(x) => x.recv().await.map(BMut::Bytes),
|
||
|
Self::Sock(sock, buf) => {
|
||
|
buf.resize(65536, 0);
|
||
|
let len = sock.recv(buf).await.ok()?;
|
||
|
buf.truncate(len);
|
||
|
Some(BMut::Vec(buf))
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
/// Counter that gives every spawned session task its own id.
///
/// It is incremented once per session and wraps on overflow. It is process-local and
/// starts at 0 on every run.
static SESSION_ID: AtomicU32 = AtomicU32::new(0);

/// Protocol version, written as the first byte of every outgoing nonce.
///
/// This is a `const` rather than a `static`: it is a plain compile-time value, needs no
/// fixed address, and can then be used in constant expressions.
const PROTOCOL_VERSION: u8 = 0;
|
||
|
|
||
|
async fn thread(
|
||
|
(mut rx_enc, sock_enc, enc_addr): (ChanOrSock, Arc<UdpSocket>, SocketAddr),
|
||
|
(mut rx_dec, sock_dec, dec_addr): (ChanOrSock, Arc<UdpSocket>, SocketAddr),
|
||
|
server: bool,
|
||
|
chans: Arc<DashMap<SocketAddr, mpsc::Sender<BytesMut>>>,
|
||
|
cipher_remote: ChaCha20Poly1305,
|
||
|
cipher_local: ChaCha20Poly1305,
|
||
|
) {
|
||
|
let timeout = Timeout::new(
|
||
|
async_shared_timeout::runtime::Tokio::new(),
|
||
|
Duration::from_secs(1),
|
||
|
);
|
||
|
let session_id = SESSION_ID.fetch_add(1, Ordering::SeqCst).to_be_bytes();
|
||
|
let remote_to_local = async {
|
||
|
let mut timeout_set = false;
|
||
|
while let Some(mut buf) = rx_enc.recv().await {
|
||
|
if buf.len() < 12 {
|
||
|
continue;
|
||
|
}
|
||
|
let nonce = <[u8; 12]>::try_from(&buf.as_ref()[buf.len() - 12..]).unwrap();
|
||
|
let time = chrono::Utc::now().timestamp_nanos_opt().unwrap();
|
||
|
let nonce1 = [
|
||
|
nonce[4], nonce[5], nonce[6], nonce[7], nonce[8], nonce[9], nonce[10], nonce[11],
|
||
|
];
|
||
|
if i64::from_be_bytes(nonce1).abs_diff(time) > 10_000_000_000 {
|
||
|
eprintln!("got invalid nonce");
|
||
|
continue;
|
||
|
}
|
||
|
buf.truncate(buf.len() - 12);
|
||
|
if let Err(err) = cipher_remote.decrypt_in_place((&nonce).into(), b"", &mut buf) {
|
||
|
eprintln!("decrypt error: {err}");
|
||
|
continue;
|
||
|
}
|
||
|
if server {
|
||
|
if !timeout_set {
|
||
|
timeout_set = true;
|
||
|
timeout.set_default_timeout(Duration::from_secs(90));
|
||
|
}
|
||
|
timeout.reset();
|
||
|
}
|
||
|
|
||
|
if let Err(err) = sock_dec.as_ref().send_to(buf.as_ref(), dec_addr).await {
|
||
|
eprintln!("decrypted send error: {err}");
|
||
|
}
|
||
|
}
|
||
|
};
|
||
|
let local_to_remote = async {
|
||
|
let mut timeout_set = false;
|
||
|
while let Some(mut buf) = rx_dec.recv().await {
|
||
|
let time = chrono::Utc::now()
|
||
|
.timestamp_nanos_opt()
|
||
|
.unwrap()
|
||
|
.to_be_bytes();
|
||
|
let nonce = [
|
||
|
PROTOCOL_VERSION,
|
||
|
session_id[0],
|
||
|
session_id[1],
|
||
|
session_id[2],
|
||
|
time[0],
|
||
|
time[1],
|
||
|
time[2],
|
||
|
time[3],
|
||
|
time[4],
|
||
|
time[5],
|
||
|
time[6],
|
||
|
time[7],
|
||
|
];
|
||
|
if let Err(err) = cipher_local.encrypt_in_place((&nonce).into(), b"", &mut buf) {
|
||
|
eprintln!("encrypt error: {err}");
|
||
|
continue;
|
||
|
}
|
||
|
buf.extend_from_slice(&nonce).unwrap();
|
||
|
if !server {
|
||
|
if !timeout_set {
|
||
|
timeout_set = true;
|
||
|
timeout.set_default_timeout(Duration::from_secs(90));
|
||
|
}
|
||
|
timeout.reset();
|
||
|
}
|
||
|
if let Err(err) = sock_enc.as_ref().send_to(buf.as_ref(), enc_addr).await {
|
||
|
eprintln!("encrypted send error: {err}");
|
||
|
}
|
||
|
}
|
||
|
};
|
||
|
tokio::select! {
|
||
|
_ = timeout.wait() => {}
|
||
|
_ = remote_to_local => {}
|
||
|
_ = local_to_remote => {}
|
||
|
}
|
||
|
chans.remove(&enc_addr);
|
||
|
}
|
||
|
|
||
|
#[tokio::main]
|
||
|
async fn main() {
|
||
|
assert_eq!(<ChaCha20Poly1305 as AeadCore>::NonceSize::to_usize(), 12);
|
||
|
let mut args = std::env::args().skip(1);
|
||
|
let data = std::fs::read_to_string(args.next().unwrap())
|
||
|
.unwrap()
|
||
|
.lines()
|
||
|
.map(|x| {
|
||
|
x.split('#')
|
||
|
.next()
|
||
|
.unwrap()
|
||
|
.chars()
|
||
|
.filter(|x| !x.is_whitespace())
|
||
|
.collect::<String>()
|
||
|
})
|
||
|
.filter(|x| !x.is_empty())
|
||
|
.map(|x| {
|
||
|
let (a, b) = x.split_once('=').unwrap_or((&x, ""));
|
||
|
(a.to_owned(), b.to_owned())
|
||
|
})
|
||
|
.collect::<HashMap<_, _>>();
|
||
|
let ver: u8 = data.get("version").unwrap().parse().unwrap();
|
||
|
assert!(matches!(ver, 0));
|
||
|
let src: SocketAddr = data.get("bind").unwrap().parse().unwrap();
|
||
|
let dst: SocketAddr = data.get("connect").unwrap().parse().unwrap();
|
||
|
let h2b = |x: &str| {
|
||
|
x.as_bytes()
|
||
|
.chunks_exact(2)
|
||
|
.map(|x| u8::from_str_radix(std::str::from_utf8(x).unwrap(), 16).unwrap())
|
||
|
.collect::<Box<[u8]>>()
|
||
|
};
|
||
|
let key_local = h2b(data.get("key_local").unwrap());
|
||
|
let key_remote = h2b(data.get("key_remote").unwrap());
|
||
|
let server = data.contains_key("server");
|
||
|
let cipher_local = ChaCha20Poly1305::new_from_slice(&key_local).unwrap();
|
||
|
let cipher_remote = ChaCha20Poly1305::new_from_slice(&key_remote).unwrap();
|
||
|
let sock = Arc::new(UdpSocket::bind(src).await.unwrap());
|
||
|
let chans = Arc::new(DashMap::<SocketAddr, mpsc::Sender<_>>::new());
|
||
|
let mut buf = BytesMut::with_capacity(65536);
|
||
|
loop {
|
||
|
buf.resize(65536, 0);
|
||
|
let Ok((len, addr)) = sock.recv_from(&mut buf).await else {
|
||
|
continue;
|
||
|
};
|
||
|
if let Some(ch) = chans.get(&addr) {
|
||
|
if ch.send(buf.split_to(len)).await.is_ok() {
|
||
|
continue;
|
||
|
}
|
||
|
}
|
||
|
let (tx, rx) = mpsc::channel(128);
|
||
|
tx.send(buf.split_to(len)).await.unwrap();
|
||
|
chans.insert(addr, tx);
|
||
|
let sock2 = Arc::new(
|
||
|
UdpSocket::bind((
|
||
|
if dst.is_ipv6() {
|
||
|
IpAddr::V6(Ipv6Addr::UNSPECIFIED)
|
||
|
} else {
|
||
|
IpAddr::V4(Ipv4Addr::UNSPECIFIED)
|
||
|
},
|
||
|
0,
|
||
|
))
|
||
|
.await
|
||
|
.unwrap(),
|
||
|
);
|
||
|
sock2.connect(dst).await.unwrap();
|
||
|
let (mut enc, mut dec) = (
|
||
|
(ChanOrSock::Sock(sock2.clone(), vec![]), sock2, dst),
|
||
|
(ChanOrSock::Chan(rx), sock.clone(), addr),
|
||
|
);
|
||
|
if server {
|
||
|
mem::swap(&mut enc, &mut dec);
|
||
|
}
|
||
|
tokio::spawn(thread(
|
||
|
enc,
|
||
|
dec,
|
||
|
server,
|
||
|
chans.clone(),
|
||
|
cipher_remote.clone(),
|
||
|
cipher_local.clone(),
|
||
|
));
|
||
|
}
|
||
|
}
|