use anyhow::Result; use bytes::{Buf, BufMut, Bytes, BytesMut}; use hmac::{Hmac, Mac}; use sha2::Sha256; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::{Arc, RwLock as StdRwLock}; use std::time::{SystemTime, UNIX_EPOCH}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::sync::{mpsc, RwLock}; use tracing::info; use tokio::net::TcpStream; use base64::Engine; use std::pin::Pin; use std::task::{Context as TaskContext, Poll}; use chacha20poly1305::{aead::Aead, ChaCha20Poly1305, Nonce}; use x25519_dalek::StaticSecret; use ostp_core::framing::wss::{encode_wss_frame, decode_wss_frame, WssFrameResult}; use ostp_core::crypto::reality::{parse_client_hello, derive_keys, verify_session_id, REALITY_SERVER_HANDSHAKE_RECORDS}; use crate::RealityServerConfig; pub async fn handle_tcp_connection( mut stream: S, peer_addr: SocketAddr, shared_keys: Arc>>, udp_tx: mpsc::Sender<(Bytes, SocketAddr)>, tcp_map: Arc>>>, reality_config: Option>, fb_target: Option, ) -> Result<()> where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, { let mut initial_buf = vec![0u8; 16384]; let mut header_len = 0; // Read the first chunk to determine if it's TLS or HTTP let n = stream.read(&mut initial_buf).await?; if n == 0 { anyhow::bail!("connection closed before data received"); } header_len += n; // Check if it's a TLS record (0x16 0x03 0x01 or 0x16 0x03 0x03) if initial_buf[0] == 0x16 && initial_buf[1] == 0x03 { // It's a TLS record. We need to ensure we read the entire record. if header_len >= 5 { let record_len = 5 + u16::from_be_bytes([initial_buf[3], initial_buf[4]]) as usize; if record_len > initial_buf.len() { anyhow::bail!("TLS record too large"); } while header_len < record_len { let n = stream.read(&mut initial_buf[header_len..record_len]).await?; if n == 0 { anyhow::bail!("connection closed while reading TLS record"); } header_len += n; } } if let Some(rc) = reality_config { return handle_reality_connection(stream, initial_buf[..header_len].to_vec(), peer_addr, shared_keys, udp_tx, tcp_map, rc).await; } else { // Received TLS but Reality is not enabled if let Some(target) = fb_target { tracing::info!("Fallback triggered for {} -> {}", peer_addr, target); let mut dest_stream: TcpStream = TcpStream::connect(&target).await?; dest_stream.write_all(&initial_buf[..header_len]).await?; tokio::io::copy_bidirectional(&mut stream, &mut dest_stream).await?; return Ok(()); } else { anyhow::bail!("received TLS but Reality is not configured and no fallback target"); } } } // Otherwise, assume it's HTTP (Standard xhttp/wss) loop { if initial_buf[..header_len].windows(4).any(|w| w == b"\r\n\r\n") { break; } if header_len == initial_buf.len() { anyhow::bail!("handshake headers too large"); } let n = stream.read(&mut initial_buf[header_len..]).await?; if n == 0 { anyhow::bail!("connection closed before HTTP handshake complete"); } header_len += n; } let headers_str = String::from_utf8_lossy(&initial_buf[..header_len]); let wss = if headers_str.starts_with("GET /wss HTTP/1.1\r\n") { true } else if headers_str.starts_with("GET /stream HTTP/1.1\r\n") { false } else { if let Some(target) = fb_target { tracing::info!("Fallback triggered for {} -> {}", peer_addr, target); let mut dest_stream: TcpStream = TcpStream::connect(&target).await?; dest_stream.write_all(&initial_buf[..header_len]).await?; tokio::io::copy_bidirectional(&mut stream, &mut dest_stream).await?; return Ok(()); } else { let _ = stream.write_all(b"HTTP/1.1 404 Not Found\r\n\r\n").await; anyhow::bail!("invalid request line"); } }; // Extract Authorization let mut signature_base64 = None; for line in headers_str.lines() { let lower = line.to_ascii_lowercase(); if lower.starts_with("authorization: bearer ") { signature_base64 = Some(line[22..].trim().to_string()); } } let sig_b64 = match signature_base64 { Some(s) => s, None => { if let Some(target) = fb_target { tracing::info!("Fallback triggered for {} -> {}", peer_addr, target); let mut dest_stream: TcpStream = TcpStream::connect(&target).await?; dest_stream.write_all(&initial_buf[..header_len]).await?; tokio::io::copy_bidirectional(&mut stream, &mut dest_stream).await?; return Ok(()); } else { let _ = stream.write_all(b"HTTP/1.1 404 Not Found\r\n\r\n").await; anyhow::bail!("missing authorization"); } } }; let sig_bytes = match base64::Engine::decode(&base64::engine::general_purpose::STANDARD_NO_PAD, &sig_b64) { Ok(b) => b, Err(_) => { if let Some(target) = fb_target { tracing::info!("Fallback triggered for {} -> {}", peer_addr, target); let mut dest_stream: TcpStream = TcpStream::connect(&target).await?; dest_stream.write_all(&initial_buf[..header_len]).await?; tokio::io::copy_bidirectional(&mut stream, &mut dest_stream).await?; return Ok(()); } else { let _ = stream.write_all(b"HTTP/1.1 404 Not Found\r\n\r\n").await; anyhow::bail!("invalid base64 signature"); } } }; if sig_bytes.len() < 8 { let _ = stream.write_all(b"HTTP/1.1 404 Not Found\r\n\r\n").await; anyhow::bail!("signature too short"); } let ts_bytes: [u8; 8] = sig_bytes[0..8].try_into().unwrap(); let client_ts = u64::from_be_bytes(ts_bytes); let provided_mac = &sig_bytes[8..]; let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(); if client_ts > now + 30 || client_ts < now.saturating_sub(60) { let _ = stream.write_all(b"HTTP/1.1 404 Not Found\r\n\r\n").await; anyhow::bail!("timestamp out of bounds (replay protection)"); } // Verify HMAC against known keys let keys = { let guard = shared_keys.read().unwrap(); guard.keys().cloned().collect::>() }; let mut authenticated = false; for key in keys { let mut mac = as Mac>::new_from_slice(key.as_bytes()) .unwrap_or_else(|_| as Mac>::new_from_slice(b"default").unwrap()); mac.update(&ts_bytes); if mac.verify_slice(provided_mac).is_ok() { authenticated = true; break; } } if !authenticated { if let Some(target) = fb_target { tracing::info!("Fallback triggered for {} -> {}", peer_addr, target); let mut dest_stream: TcpStream = TcpStream::connect(&target).await?; dest_stream.write_all(&initial_buf[..header_len]).await?; tokio::io::copy_bidirectional(&mut stream, &mut dest_stream).await?; return Ok(()); } else { let _ = stream.write_all(b"HTTP/1.1 404 Not Found\r\n\r\n").await; anyhow::bail!("unauthorized (invalid HMAC)"); } } if wss { let response = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nX-Ostp-Server: 1\r\n\r\n"; stream.write_all(response.as_bytes()).await?; } else { let response = "HTTP/1.1 200 OK\r\nX-Ostp-Server: 1\r\nContent-Type: application/octet-stream\r\n\r\n"; stream.write_all(response.as_bytes()).await?; } info!("UoT client authenticated from {} (xhttp)", peer_addr); start_uot_loops(stream, peer_addr, wss, tcp_map, udp_tx).await } async fn handle_reality_connection( mut stream: S, initial_buf: Vec, peer_addr: SocketAddr, _shared_keys: Arc>>, // Note: Reality uses its own keys (sid) udp_tx: mpsc::Sender<(Bytes, SocketAddr)>, tcp_map: Arc>>>, reality_config: Arc, ) -> Result<()> where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, { // Try to parse ClientHello let parsed_ch = parse_client_hello(&initial_buf); let mut authenticated = false; let mut data_key_opt = None; if let Some(ch) = parsed_ch { // Validate SNI if reality_config.sni_list.contains(&ch.sni) { // Decode Server Private Key if let Ok(priv_bytes) = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(&reality_config.private_key) { if priv_bytes.len() == 32 { let mut secret_bytes = [0u8; 32]; secret_bytes.copy_from_slice(&priv_bytes); let server_priv = StaticSecret::from(secret_bytes); let shared_secret = server_priv.diffie_hellman(&ch.c_pub); let (auth_key, data_key) = derive_keys(shared_secret.as_bytes()); // Attempt to decrypt Session ID if let Some((sid, _ts)) = verify_session_id(&auth_key, &ch.session_id) { // Check if sid is in config let sid_hex = hex::encode(sid); if reality_config.sid == sid_hex { authenticated = true; data_key_opt = Some(data_key); } } } } } } if authenticated { let data_key = data_key_opt.unwrap(); info!("Reality client authenticated from {} (sid matched)", peer_addr); // Build a fake TLS 1.3 server flight that matches what a real server sends. // Must be exactly REALITY_SERVER_HANDSHAKE_RECORDS (5) TLS records: // 1. ServerHello (0x16) - static blob with fake key share // 2. ChangeCipherSpec (0x14) - RFC 8446 §D.4 middlebox compat // 3. Fake EE (0x17) - simulates EncryptedExtensions // 4. Fake Certificate (0x17) - simulates Certificate (big, DPI-realistic) // 5. Fake Finished (0x17) - simulates CertificateVerify + Finished let _ = REALITY_SERVER_HANDSHAKE_RECORDS; // assert constant is imported (= 5) // Record 1: ServerHello (0x16), same static blob as before (valid structure) let server_hello_rec = hex::decode( "160303007a0200007603030000000000000000000000000000000000000000000000\ 000000000000000000000000200000000000000000000000000000000000000000\ 0000000000000000000000000000130100002e002b0002030400330024001d0020\ e29b191a62d0572e9a30d0fb9d08e50bc78d591dfc1dbafbfa533411db1c8e11" ).unwrap(); // Record 2: ChangeCipherSpec (0x14) let ccs_rec: &[u8] = &[0x14, 0x03, 0x03, 0x00, 0x01, 0x01]; // Record 3: Fake EncryptedExtensions (0x17), 108 zero bytes payload let mut fake_ee = vec![0x17u8, 0x03, 0x03, 0x00, 108]; fake_ee.extend_from_slice(&[0u8; 108]); // Record 4: Fake Certificate (0x17), 812 zero bytes (realistic cert size for DPI) let cert_payload_len: u16 = 812; let mut fake_cert = vec![0x17u8, 0x03, 0x03, (cert_payload_len >> 8) as u8, (cert_payload_len & 0xff) as u8]; fake_cert.extend_from_slice(&vec![0u8; cert_payload_len as usize]); // Record 5: Fake Finished (0x17), 52 zero bytes (CertificateVerify + Finished) let mut fake_fin = vec![0x17u8, 0x03, 0x03, 0x00, 52]; fake_fin.extend_from_slice(&[0u8; 52]); let mut server_flight = Vec::with_capacity( server_hello_rec.len() + ccs_rec.len() + fake_ee.len() + fake_cert.len() + fake_fin.len() ); server_flight.extend_from_slice(&server_hello_rec); server_flight.extend_from_slice(ccs_rec); server_flight.extend_from_slice(&fake_ee); server_flight.extend_from_slice(&fake_cert); server_flight.extend_from_slice(&fake_fin); stream.write_all(&server_flight).await?; // The client now sends ClientHello + CCS (6 bytes) as two separate TLS records. // The ClientHello was already consumed into initial_buf above. // The CCS may arrive as a separate TCP segment - drain it from the raw stream // before wrapping in RealityStream so RealityStream only ever sees 0x17 records. { let mut ccs_head = [0u8; 5]; if stream.read_exact(&mut ccs_head).await.is_ok() { // Expected: CCS record 0x14 0x03 0x03 0x00 0x01 // If it's something else (unlikely), we still drain its payload to stay in sync. let ccs_payload_len = u16::from_be_bytes([ccs_head[3], ccs_head[4]]) as usize; if ccs_payload_len <= 64 { let mut _discard = vec![0u8; ccs_payload_len]; let _ = stream.read_exact(&mut _discard).await; } } } let reality_stream = RealityStream::new(stream, data_key); return process_inner_reality_stream(reality_stream, peer_addr, tcp_map, udp_tx).await; } else { // Fallback: act as a transparent proxy to `reality_config.dest` info!("Reality fallback triggered for {} -> {}", peer_addr, reality_config.dest); let mut dest_stream: TcpStream = TcpStream::connect(&reality_config.dest).await?; dest_stream.write_all(&initial_buf).await?; tokio::io::copy_bidirectional(&mut stream, &mut dest_stream).await?; return Ok(()); } } async fn process_inner_reality_stream( mut stream: S, peer_addr: SocketAddr, tcp_map: Arc>>>, udp_tx: mpsc::Sender<(Bytes, SocketAddr)>, ) -> Result<()> where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, { // 1. Read the inner HTTP Handshake let mut buf = [0u8; 4096]; let mut header_len = 0; loop { let n = stream.read(&mut buf[header_len..]).await?; if n == 0 { anyhow::bail!("inner connection closed before handshake complete"); } header_len += n; if buf[..header_len].windows(4).any(|w| w == b"\r\n\r\n") { break; } if header_len == buf.len() { anyhow::bail!("inner handshake headers too large"); } } let headers_str = String::from_utf8_lossy(&buf[..header_len]); let wss = if headers_str.starts_with("GET /wss HTTP/1.1\r\n") { true } else if headers_str.starts_with("GET /stream HTTP/1.1\r\n") { false } else { anyhow::bail!("invalid inner request line"); }; // We skip signature validation because Reality already authenticated the user via Session ID! if wss { let response = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nX-Ostp-Server: 1\r\n\r\n"; stream.write_all(response.as_bytes()).await?; } else { let response = "HTTP/1.1 200 OK\r\nX-Ostp-Server: 1\r\nContent-Type: application/octet-stream\r\n\r\n"; stream.write_all(response.as_bytes()).await?; } start_uot_loops(stream, peer_addr, wss, tcp_map, udp_tx).await } async fn start_uot_loops( stream: S, peer_addr: SocketAddr, wss: bool, tcp_map: Arc>>>, udp_tx: mpsc::Sender<(Bytes, SocketAddr)>, ) -> Result<()> where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, { // Register this connection in the map let (tx, mut rx) = mpsc::channel::(16384); { tcp_map.write().await.insert(peer_addr, tx); } // Process streams let (mut read_half, mut write_half) = tokio::io::split(stream); // Spawn writer task let peer_clone = peer_addr; let tcp_map_clone = tcp_map.clone(); let writer_task = tokio::spawn(async move { while let Some(packet) = rx.recv().await { if wss { let header = encode_wss_frame(&packet, false); // Server sends unmasked WSS frames if write_half.write_all(&header).await.is_err() { break; } } else { let mut out = BytesMut::with_capacity(2 + packet.len()); out.put_u16(packet.len() as u16); out.put_slice(&packet); if write_half.write_all(&out).await.is_err() { break; } } } let _ = tcp_map_clone.write().await.remove(&peer_clone); }); // Spawn reader task let tcp_map_clone2 = tcp_map.clone(); let reader_task = tokio::spawn(async move { if wss { let mut read_buf = BytesMut::with_capacity(65536); let mut tmp = [0u8; 8192]; loop { match read_half.read(&mut tmp).await { Ok(0) => break, Ok(n) => { read_buf.put_slice(&tmp[..n]); loop { match decode_wss_frame(&mut read_buf) { WssFrameResult::Frame { payload, total_len } => { if udp_tx.send((Bytes::from(payload), peer_clone)).await.is_err() { return; } read_buf.advance(total_len); } WssFrameResult::Incomplete => break, } } } Err(_) => break, } } } else { let mut len_buf = [0u8; 2]; loop { if read_half.read_exact(&mut len_buf).await.is_err() { break; } let len = u16::from_be_bytes(len_buf) as usize; if len > 65535 { break; } let mut data = vec![0u8; len]; if read_half.read_exact(&mut data).await.is_err() { break; } if udp_tx.send((Bytes::from(data), peer_clone)).await.is_err() { break; } } } let _ = tcp_map_clone2.write().await.remove(&peer_clone); }); let _ = tokio::join!(writer_task, reader_task); Ok(()) } // ----------------------------------------------------------------------- // RealityStream: Wraps a TCP stream in fake TLS Application Data Records // ----------------------------------------------------------------------- struct RealityStream { inner: S, data_key: ChaCha20Poly1305, rx_nonce: u64, tx_nonce: u64, rx_buf: BytesMut, plaintext_buf: BytesMut, tx_buf: BytesMut, } impl RealityStream { fn new(inner: S, data_key: ChaCha20Poly1305) -> Self { Self { inner, data_key, rx_nonce: 0, tx_nonce: 0, rx_buf: BytesMut::with_capacity(16384), plaintext_buf: BytesMut::new(), tx_buf: BytesMut::new(), } } fn make_nonce(seq: u64) -> [u8; 12] { let mut nonce = [0u8; 12]; nonce[4..12].copy_from_slice(&seq.to_le_bytes()); nonce } } impl tokio::io::AsyncRead for RealityStream { fn poll_read(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll> { loop { if !self.plaintext_buf.is_empty() { let out_len = std::cmp::min(buf.remaining(), self.plaintext_buf.len()); buf.put_slice(&self.plaintext_buf[..out_len]); self.plaintext_buf.advance(out_len); return Poll::Ready(Ok(())); } if self.rx_buf.len() >= 5 { let len = u16::from_be_bytes([self.rx_buf[3], self.rx_buf[4]]) as usize; if self.rx_buf.len() >= 5 + len { if self.rx_buf[0] != 0x17 { return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected application data record"))); } let ciphertext = &self.rx_buf[5..5+len]; let nonce_bytes = Self::make_nonce(self.rx_nonce); let nonce = Nonce::from_slice(&nonce_bytes); match self.data_key.decrypt(nonce, ciphertext) { Ok(plaintext) => { self.rx_nonce += 1; self.plaintext_buf.put_slice(&plaintext); self.rx_buf.advance(5 + len); continue; } Err(_) => return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "reality decrypt failed"))), } } } let mut read_buf = [0u8; 8192]; let mut tokio_buf = tokio::io::ReadBuf::new(&mut read_buf); match Pin::new(&mut self.inner).poll_read(cx, &mut tokio_buf) { Poll::Ready(Ok(())) => { if tokio_buf.filled().is_empty() { return Poll::Ready(Ok(())); } self.rx_buf.put_slice(tokio_buf.filled()); } Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), Poll::Pending => return Poll::Pending, } } } } impl tokio::io::AsyncWrite for RealityStream { fn poll_write(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &[u8]) -> Poll> { let this = self.get_mut(); while !this.tx_buf.is_empty() { match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) { Poll::Ready(Ok(n)) => this.tx_buf.advance(n), Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), Poll::Pending => return Poll::Pending, } } let nonce_bytes = Self::make_nonce(this.tx_nonce); let nonce = Nonce::from_slice(&nonce_bytes); match this.data_key.encrypt(nonce, buf) { Ok(ciphertext) => { this.tx_nonce += 1; this.tx_buf.reserve(5 + ciphertext.len()); this.tx_buf.put_u8(0x17); this.tx_buf.put_u16(0x0303); this.tx_buf.put_u16(ciphertext.len() as u16); this.tx_buf.put_slice(&ciphertext); match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) { Poll::Ready(Ok(n)) => this.tx_buf.advance(n), Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), Poll::Pending => {} } Poll::Ready(Ok(buf.len())) } Err(_) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, "reality encrypt failed"))), } } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { let this = self.get_mut(); while !this.tx_buf.is_empty() { match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) { Poll::Ready(Ok(n)) => this.tx_buf.advance(n), Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), Poll::Pending => return Poll::Pending, } } Pin::new(&mut this.inner).poll_flush(cx) } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { let this = self.get_mut(); while !this.tx_buf.is_empty() { match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) { Poll::Ready(Ok(n)) => this.tx_buf.advance(n), Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), Poll::Pending => return Poll::Pending, } } Pin::new(&mut this.inner).poll_shutdown(cx) } }