fckrkn/src/main.rs

253 lines
7.9 KiB
Rust
Raw Normal View History

2024-02-01 04:52:23 +07:00
use std::{
collections::HashMap,
mem,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
sync::{
atomic::{AtomicU32, Ordering},
Arc,
},
time::Duration,
};
use async_shared_timeout::Timeout;
use bytes::BytesMut;
use chacha20poly1305::{aead::Buffer, AeadCore, AeadInPlace, ChaCha20Poly1305, KeyInit};
use dashmap::DashMap;
use tokio::{net::UdpSocket, sync::mpsc};
use typenum::Unsigned;
/// A mutable byte buffer that is either an owned `BytesMut` or a mutable
/// borrow of a `Vec<u8>`.
///
/// The trait impls on this type forward every operation to whichever
/// variant is present, so both kinds can be passed to the in-place AEAD calls.
#[derive(Debug)]
enum BMut<'a> {
/// Owned buffer; `ChanOrSock::recv` produces this for channel messages.
Bytes(BytesMut),
/// Borrowed buffer; `ChanOrSock::recv` produces this for socket reads.
Vec(&'a mut Vec<u8>),
}
/// Lets the in-place AEAD routines grow and shrink either buffer flavour.
impl Buffer for BMut<'_> {
    /// Appends `other` to the end of the wrapped buffer; always returns `Ok`.
    fn extend_from_slice(&mut self, other: &[u8]) -> chacha20poly1305::aead::Result<()> {
        match self {
            Self::Bytes(owned) => owned.extend_from_slice(other),
            Self::Vec(borrowed) => borrowed.extend_from_slice(other),
        }
        Ok(())
    }

    /// Shortens the wrapped buffer to `len` bytes.
    fn truncate(&mut self, len: usize) {
        match self {
            Self::Bytes(owned) => owned.truncate(len),
            Self::Vec(borrowed) => borrowed.truncate(len),
        }
    }
}
/// Read-only view of the bytes currently held by the buffer.
impl AsRef<[u8]> for BMut<'_> {
    fn as_ref(&self) -> &[u8] {
        let bytes: &[u8] = match self {
            Self::Bytes(owned) => owned.as_ref(),
            Self::Vec(borrowed) => borrowed.as_ref(),
        };
        bytes
    }
}
/// Mutable view of the bytes currently held by the buffer.
impl AsMut<[u8]> for BMut<'_> {
    fn as_mut(&mut self) -> &mut [u8] {
        let bytes: &mut [u8] = match self {
            Self::Bytes(owned) => owned.as_mut(),
            Self::Vec(borrowed) => borrowed.as_mut(),
        };
        bytes
    }
}
/// Where a session reads its datagrams from: an in-process channel or a UDP socket.
enum ChanOrSock {
/// Datagrams arrive as `BytesMut` messages on an mpsc channel.
Chan(mpsc::Receiver<BytesMut>),
/// Datagrams are read from the socket into the accompanying `Vec<u8>`,
/// which `recv` resizes and reuses as its receive buffer on every call.
Sock(Arc<UdpSocket>, Vec<u8>),
}
impl ChanOrSock {
    /// Waits for the next datagram from whichever source this value wraps.
    ///
    /// Returns `None` once the channel yields no more messages or when the
    /// socket read reports an error. For the socket variant the returned
    /// buffer borrows the internal `Vec`, trimmed to the received length.
    async fn recv(&mut self) -> Option<BMut<'_>> {
        match self {
            Self::Chan(receiver) => {
                let packet = receiver.recv().await?;
                Some(BMut::Bytes(packet))
            }
            Self::Sock(socket, scratch) => {
                // Restore the full-size window before every read; the previous
                // call (or in-place encryption) may have changed the length.
                scratch.resize(65536, 0);
                match socket.recv(scratch).await {
                    Ok(received) => {
                        scratch.truncate(received);
                        Some(BMut::Vec(scratch))
                    }
                    Err(_) => None,
                }
            }
        }
    }
}
/// Process-wide session counter. Every call to `thread` takes the next value
/// with `fetch_add` and embeds part of it in the nonces that session generates.
static SESSION_ID: AtomicU32 = AtomicU32::new(0);
/// Protocol version; written as the first byte of every nonce built in `thread`.
static PROTOCOL_VERSION: u8 = 0;
/// Drives a single tunnel session until it goes idle or one of its packet
/// sources ends.
///
/// Two directions run concurrently:
///
/// * **remote → local**: every datagram from `rx_enc` must end with a 12-byte
///   nonce. Datagrams whose embedded timestamp is more than 10 s away from
///   the local clock, or that fail authentication with `cipher_remote`, are
///   dropped. The plaintext is sent through `sock_dec` to `dec_addr`.
/// * **local → remote**: every datagram from `rx_dec` is encrypted with
///   `cipher_local` under a freshly built nonce, the nonce is appended, and
///   the result is sent through `sock_enc` to `enc_addr`.
///
/// Nonce layout (12 bytes): `PROTOCOL_VERSION`, the three low-order bytes of
/// the session id, then the current UTC time in nanoseconds as a big-endian
/// `i64`.
///
/// The shared timeout is created with 1 s. It is switched to 90 s and reset
/// on traffic that indicates a live peer: a successfully decrypted datagram
/// when `server` is true, an encrypted outgoing datagram otherwise.
///
/// When the session ends, its sender is removed from `chans`.
async fn thread(
    (mut rx_enc, sock_enc, enc_addr): (ChanOrSock, Arc<UdpSocket>, SocketAddr),
    (mut rx_dec, sock_dec, dec_addr): (ChanOrSock, Arc<UdpSocket>, SocketAddr),
    server: bool,
    chans: Arc<DashMap<SocketAddr, mpsc::Sender<BytesMut>>>,
    cipher_remote: ChaCha20Poly1305,
    cipher_local: ChaCha20Poly1305,
) {
    const NONCE_LEN: usize = 12;
    const MAX_CLOCK_SKEW_NANOS: u64 = 10_000_000_000;
    const INITIAL_TIMEOUT: Duration = Duration::from_secs(1);
    const ACTIVE_TIMEOUT: Duration = Duration::from_secs(90);

    let timeout = Timeout::new(async_shared_timeout::runtime::Tokio::new(), INITIAL_TIMEOUT);
    let session_id = SESSION_ID.fetch_add(1, Ordering::SeqCst).to_be_bytes();
    // `main` registers the session under the address of the peer that feeds
    // the channel: the encrypted side on a server, the plaintext side on a
    // client. Using `enc_addr` unconditionally would leave client entries
    // behind forever.
    let registry_key = if server { enc_addr } else { dec_addr };

    let remote_to_local = async {
        let mut timeout_set = false;
        while let Some(mut buf) = rx_enc.recv().await {
            if buf.len() < NONCE_LEN {
                continue;
            }
            let nonce =
                <[u8; NONCE_LEN]>::try_from(&buf.as_ref()[buf.len() - NONCE_LEN..]).unwrap();
            let time = chrono::Utc::now().timestamp_nanos_opt().unwrap();
            // Bytes 4..12 of the nonce carry the sender's timestamp.
            let mut stamp = [0u8; 8];
            stamp.copy_from_slice(&nonce[4..]);
            if i64::from_be_bytes(stamp).abs_diff(time) > MAX_CLOCK_SKEW_NANOS {
                eprintln!("got invalid nonce");
                continue;
            }
            buf.truncate(buf.len() - NONCE_LEN);
            if let Err(err) = cipher_remote.decrypt_in_place((&nonce).into(), b"", &mut buf) {
                eprintln!("decrypt error: {err}");
                continue;
            }
            if server {
                // Only authenticated traffic keeps a server-side session alive.
                if !timeout_set {
                    timeout_set = true;
                    timeout.set_default_timeout(ACTIVE_TIMEOUT);
                }
                timeout.reset();
            }
            if let Err(err) = sock_dec.as_ref().send_to(buf.as_ref(), dec_addr).await {
                eprintln!("decrypted send error: {err}");
            }
        }
    };

    let local_to_remote = async {
        let mut timeout_set = false;
        while let Some(mut buf) = rx_dec.recv().await {
            let time = chrono::Utc::now()
                .timestamp_nanos_opt()
                .unwrap()
                .to_be_bytes();
            // The session id is a big-endian counter, so its *low-order* bytes
            // are the ones that differ between consecutive sessions. Embedding
            // the high-order bytes would give the first 256 sessions the same
            // prefix and allow nonce collisions across sessions sharing a key.
            // NOTE(review): two datagrams of one session stamped with the same
            // clock reading would still share a nonce.
            let mut nonce = [0u8; NONCE_LEN];
            nonce[0] = PROTOCOL_VERSION;
            nonce[1..4].copy_from_slice(&session_id[1..]);
            nonce[4..].copy_from_slice(&time);
            if let Err(err) = cipher_local.encrypt_in_place((&nonce).into(), b"", &mut buf) {
                eprintln!("encrypt error: {err}");
                continue;
            }
            buf.extend_from_slice(&nonce).unwrap();
            if !server {
                if !timeout_set {
                    timeout_set = true;
                    timeout.set_default_timeout(ACTIVE_TIMEOUT);
                }
                timeout.reset();
            }
            if let Err(err) = sock_enc.as_ref().send_to(buf.as_ref(), enc_addr).await {
                eprintln!("encrypted send error: {err}");
            }
        }
    };

    tokio::select! {
        _ = timeout.wait() => {}
        _ = remote_to_local => {}
        _ = local_to_remote => {}
    }

    // Close the channel before touching the registry so a sender waiting on a
    // full queue fails instead of waiting on a receiver nobody polls. Then
    // remove the entry only if it is still this (now closed) session's sender,
    // so a replacement session registered in the meantime is left alone.
    drop(rx_enc);
    drop(rx_dec);
    chans.remove_if(&registry_key, |_, tx| tx.is_closed());
}
#[tokio::main]
async fn main() {
assert_eq!(<ChaCha20Poly1305 as AeadCore>::NonceSize::to_usize(), 12);
let mut args = std::env::args().skip(1);
let data = std::fs::read_to_string(args.next().unwrap())
.unwrap()
.lines()
.map(|x| {
x.split('#')
.next()
.unwrap()
.chars()
.filter(|x| !x.is_whitespace())
.collect::<String>()
})
.filter(|x| !x.is_empty())
.map(|x| {
let (a, b) = x.split_once('=').unwrap_or((&x, ""));
(a.to_owned(), b.to_owned())
})
.collect::<HashMap<_, _>>();
let ver: u8 = data.get("version").unwrap().parse().unwrap();
assert!(matches!(ver, 0));
let src: SocketAddr = data.get("bind").unwrap().parse().unwrap();
let dst: SocketAddr = data.get("connect").unwrap().parse().unwrap();
let h2b = |x: &str| {
x.as_bytes()
.chunks_exact(2)
.map(|x| u8::from_str_radix(std::str::from_utf8(x).unwrap(), 16).unwrap())
.collect::<Box<[u8]>>()
};
let key_local = h2b(data.get("key_local").unwrap());
let key_remote = h2b(data.get("key_remote").unwrap());
let server = data.contains_key("server");
let cipher_local = ChaCha20Poly1305::new_from_slice(&key_local).unwrap();
let cipher_remote = ChaCha20Poly1305::new_from_slice(&key_remote).unwrap();
let sock = Arc::new(UdpSocket::bind(src).await.unwrap());
let chans = Arc::new(DashMap::<SocketAddr, mpsc::Sender<_>>::new());
let mut buf = BytesMut::with_capacity(65536);
loop {
buf.resize(65536, 0);
let Ok((len, addr)) = sock.recv_from(&mut buf).await else {
continue;
};
if let Some(ch) = chans.get(&addr) {
if ch.send(buf.split_to(len)).await.is_ok() {
continue;
}
}
let (tx, rx) = mpsc::channel(128);
tx.send(buf.split_to(len)).await.unwrap();
chans.insert(addr, tx);
let sock2 = Arc::new(
UdpSocket::bind((
if dst.is_ipv6() {
IpAddr::V6(Ipv6Addr::UNSPECIFIED)
} else {
IpAddr::V4(Ipv4Addr::UNSPECIFIED)
},
0,
))
.await
.unwrap(),
);
sock2.connect(dst).await.unwrap();
let (mut enc, mut dec) = (
(ChanOrSock::Sock(sock2.clone(), vec![]), sock2, dst),
(ChanOrSock::Chan(rx), sock.clone(), addr),
);
if server {
mem::swap(&mut enc, &mut dec);
}
tokio::spawn(thread(
enc,
dec,
server,
chans.clone(),
cipher_remote.clone(),
cipher_local.clone(),
));
}
}