From f4830f043ffd84ada977b1bf55a7fd457e94db49 Mon Sep 17 00:00:00 2001 From: ospab Date: Fri, 29 May 2026 13:59:59 +0300 Subject: [PATCH] feat: implement optional WSS framing for DPI bypass & extract framing logic --- Cargo.lock | 1 - ostp-client/src/bridge.rs | 4 +- ostp-client/src/config.rs | 6 + ostp-client/src/transport/xhttp.rs | 115 ++++++---- ostp-core/src/framing/mod.rs | 2 + ostp-core/src/framing/wss.rs | 74 +++++++ ostp-core/src/protocol.rs | 342 +++++++++++++++-------------- ostp-flutter/lib/main.dart | 16 +- ostp-gui/src-tauri/Cargo.toml | 1 + ostp-gui/src-tauri/src/lib.rs | 23 +- ostp-gui/src/index.html | 13 ++ ostp-gui/src/main.js | 3 + ostp-jni/Cargo.toml | 1 - ostp-jni/src/lib.rs | 44 ++-- ostp-server/src/lib.rs | 255 ++++++++++++--------- ostp-server/src/transport/uot.rs | 108 ++++++--- ostp-tun-helper/src/main.rs | 36 ++- ostp/src/main.rs | 8 +- 18 files changed, 676 insertions(+), 376 deletions(-) create mode 100644 ostp-core/src/framing/wss.rs diff --git a/Cargo.lock b/Cargo.lock index c03c813..343569f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1422,7 +1422,6 @@ dependencies = [ "anyhow", "bytes", "jni", - "lazy_static", "ostp-client", "ostp-core", "portable-atomic", diff --git a/ostp-client/src/bridge.rs b/ostp-client/src/bridge.rs index e93f845..809b9f7 100644 --- a/ostp-client/src/bridge.rs +++ b/ostp-client/src/bridge.rs @@ -67,6 +67,7 @@ pub struct Bridge { pub transport_mode: String, pub stealth_sni: String, pub stealth_port: u16, + pub wss: bool, pub mtu: usize, pub reality_enabled: bool, @@ -99,6 +100,7 @@ impl Bridge { transport_mode: config.transport.mode.clone(), stealth_sni: config.transport.stealth_sni.clone(), stealth_port: config.transport.stealth_port, + wss: config.transport.wss, mtu: config.ostp.mtu, reality_enabled: !config.reality.pbk.is_empty(), @@ -905,7 +907,7 @@ impl Bridge { port }; let (tx, rx) = crate::transport::xhttp::connect_xhttp( - target_ip, uot_port, &self.stealth_sni, &self.access_key, self.reality_enabled + target_ip, uot_port, &self.stealth_sni, &self.access_key, self.reality_enabled, self.wss ).await?; Ok(crate::transport::Transport::Uot { tx, rx }) } else { diff --git a/ostp-client/src/config.rs b/ostp-client/src/config.rs index 599005b..8cb1f78 100644 --- a/ostp-client/src/config.rs +++ b/ostp-client/src/config.rs @@ -79,6 +79,9 @@ pub struct TransportConfig { /// TCP Port for the stealth connection #[serde(default = "default_stealth_port")] pub stealth_port: u16, + /// Enable strict RFC 6455 WebSocket framing + #[serde(default)] + pub wss: bool, } fn default_transport_mode() -> String { "udp".to_string() } @@ -90,6 +93,7 @@ impl Default for TransportConfig { mode: default_transport_mode(), stealth_sni: String::new(), stealth_port: default_stealth_port(), + wss: false, } } } @@ -185,6 +189,7 @@ struct RawTransportSection { mode: Option, stealth_sni: Option, stealth_port: Option, + wss: Option, } #[derive(Debug, Deserialize)] @@ -274,6 +279,7 @@ impl ClientConfig { mode: raw.transport.as_ref().and_then(|t| t.mode.clone()).unwrap_or_else(|| "udp".to_string()), stealth_sni: raw.transport.as_ref().and_then(|t| t.stealth_sni.clone()).unwrap_or_else(|| "microsoft.com".to_string()), stealth_port: raw.transport.as_ref().and_then(|t| t.stealth_port).unwrap_or(443), + wss: raw.transport.as_ref().and_then(|t| t.wss).unwrap_or(false), }, exclusions: ExclusionConfig { domains: exclusions.domains.unwrap_or_default(), diff --git a/ostp-client/src/transport/xhttp.rs b/ostp-client/src/transport/xhttp.rs index 4ede3df..3e5dd36 100644 --- a/ostp-client/src/transport/xhttp.rs +++ b/ostp-client/src/transport/xhttp.rs @@ -12,6 +12,7 @@ use rustls::ClientConfig; use rustls::pki_types::ServerName; use std::sync::Arc as StdArc; use tokio_rustls::TlsConnector; +use ostp_core::framing::wss::{encode_wss_frame, decode_wss_frame, WssFrameResult}; mod danger { use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; @@ -73,6 +74,7 @@ pub async fn connect_xhttp( sni: &str, access_key: &[u8], tls_enabled: bool, + wss: bool, ) -> Result<(mpsc::Sender, Arc>>)> { let addr = std::net::SocketAddr::new(target_ip, port); let tcp_stream = TcpStream::connect(addr).await @@ -96,9 +98,9 @@ pub async fn connect_xhttp( let tls_stream = connector.connect(server_name, tcp_stream).await .with_context(|| "TLS handshake failed")?; - xhttp_handshake_and_loop(tls_stream, target_ip, sni, access_key).await + xhttp_handshake_and_loop(tls_stream, target_ip, sni, access_key, wss).await } else { - xhttp_handshake_and_loop(tcp_stream, target_ip, sni, access_key).await + xhttp_handshake_and_loop(tcp_stream, target_ip, sni, access_key, wss).await } } @@ -107,6 +109,7 @@ async fn xhttp_handshake_and_loop( target_ip: IpAddr, sni: &str, access_key: &[u8], + wss: bool, ) -> Result<(mpsc::Sender, Arc>>)> where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, @@ -126,18 +129,27 @@ where let http_host = if sni.is_empty() { target_ip.to_string() } else { sni.to_string() }; - // 2. Send fake WebSocket upgrade — looks like a legit browser request to bypass DPI/proxies. - let req = format!( - "GET /stream HTTP/1.1\r\n\ - Host: {}\r\n\ - Upgrade: websocket\r\n\ - Connection: upgrade\r\n\ - Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ - Sec-WebSocket-Version: 13\r\n\ - Authorization: Bearer {}\r\n\ - \r\n", - http_host, auth_token - ); + let req = if wss { + format!( + "GET /wss HTTP/1.1\r\n\ + Host: {}\r\n\ + Upgrade: websocket\r\n\ + Connection: upgrade\r\n\ + Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ + Sec-WebSocket-Version: 13\r\n\ + Authorization: Bearer {}\r\n\ + \r\n", + http_host, auth_token + ) + } else { + format!( + "GET /stream HTTP/1.1\r\n\ + Host: {}\r\n\ + Authorization: Bearer {}\r\n\ + \r\n", + http_host, auth_token + ) + }; stream.write_all(req.as_bytes()).await?; stream.flush().await?; @@ -168,13 +180,14 @@ where // 5. Split into read/write halves and start UoT loops let (rx, tx) = tokio::io::split(stream); - start_uot_loops(rx, tx, leftover) + start_uot_loops(rx, tx, leftover, wss) } fn start_uot_loops( mut net_rx: R, mut net_tx: W, - leftover: Vec + leftover: Vec, + wss: bool, ) -> Result<(mpsc::Sender, Arc>>)> where R: tokio::io::AsyncRead + Unpin + Send + 'static, @@ -183,39 +196,65 @@ where let (app_tx, mut tx_rx) = mpsc::channel::(16384); let (rx_tx, app_rx) = mpsc::channel::(16384); - // TX Loop (App -> UoT -> Network): prefix each frame with u16 BE length + // TX Loop (App -> UoT -> Network) tokio::spawn(async move { while let Some(frame) = tx_rx.recv().await { - let len = frame.len() as u16; - if net_tx.write_u16(len).await.is_err() { break; } - if net_tx.write_all(&frame).await.is_err() { break; } + let len = frame.len(); + if wss { + let header = encode_wss_frame(&frame, true); + if net_tx.write_all(&header).await.is_err() { break; } + } else { + let len_u16 = len as u16; + if net_tx.write_u16(len_u16).await.is_err() { break; } + if net_tx.write_all(&frame).await.is_err() { break; } + } } }); - // RX Loop (Network -> UoT -> App): parse [u16 len][payload] frames + // RX Loop (Network -> UoT -> App) tokio::spawn(async move { let mut buffer = BytesMut::from(&leftover[..]); loop { - while buffer.len() < 2 { - let mut temp = [0u8; 4096]; - match net_rx.read(&mut temp).await { - Ok(0) | Err(_) => return, - Ok(n) => buffer.extend_from_slice(&temp[..n]), + if wss { + // Parse WSS frame (from server, so NOT masked) + match decode_wss_frame(&buffer) { + WssFrameResult::Incomplete => { + let mut temp = [0u8; 4096]; + match net_rx.read(&mut temp).await { + Ok(0) | Err(_) => return, + Ok(n) => buffer.extend_from_slice(&temp[..n]), + } + } + WssFrameResult::Frame { payload, total_len } => { + let _ = buffer.split_to(total_len); + if rx_tx.send(Bytes::from(payload)).await.is_err() { + break; + } + } } - } - let len = u16::from_be_bytes([buffer[0], buffer[1]]) as usize; - - while buffer.len() < 2 + len { - let mut temp = [0u8; 4096]; - match net_rx.read(&mut temp).await { - Ok(0) | Err(_) => return, - Ok(n) => buffer.extend_from_slice(&temp[..n]), + } else { + // Parse raw u16 framing + while buffer.len() < 2 { + let mut temp = [0u8; 4096]; + match net_rx.read(&mut temp).await { + Ok(0) | Err(_) => return, + Ok(n) => buffer.extend_from_slice(&temp[..n]), + } } - } + let len = u16::from_be_bytes([buffer[0], buffer[1]]) as usize; - let packet = buffer.split_to(2 + len); - if rx_tx.send(Bytes::from(packet[2..].to_vec())).await.is_err() { - break; + while buffer.len() < 2 + len { + let mut temp = [0u8; 4096]; + match net_rx.read(&mut temp).await { + Ok(0) | Err(_) => return, + Ok(n) => buffer.extend_from_slice(&temp[..n]), + } + } + + let packet = buffer.split_to(2 + len); + if rx_tx.send(Bytes::from(packet[2..].to_vec())).await.is_err() { + break; + } } } }); diff --git a/ostp-core/src/framing/mod.rs b/ostp-core/src/framing/mod.rs index 24eaed0..5cc44af 100644 --- a/ostp-core/src/framing/mod.rs +++ b/ostp-core/src/framing/mod.rs @@ -1,5 +1,7 @@ pub mod frame; pub mod padding; +pub mod wss; pub use frame::{FrameHeader, FrameKind, FramedPacket}; pub use padding::{AdaptivePadder, PaddingStrategy, TrafficProfile}; +pub use wss::{encode_wss_frame, decode_wss_frame, WssFrameResult}; diff --git a/ostp-core/src/framing/wss.rs b/ostp-core/src/framing/wss.rs new file mode 100644 index 0000000..19d0d77 --- /dev/null +++ b/ostp-core/src/framing/wss.rs @@ -0,0 +1,74 @@ +use rand::RngCore; + +pub enum WssFrameResult { + Incomplete, + Frame { payload: Vec, total_len: usize }, +} + +pub fn encode_wss_frame(payload: &[u8], masked: bool) -> Vec { + let len = payload.len(); + let mut header = Vec::with_capacity(14 + len); + header.push(0x82); // FIN + Binary + + let mask_bit = if masked { 0x80 } else { 0x00 }; + + if len <= 125 { + header.push(mask_bit | (len as u8)); + } else if len <= 65535 { + header.push(mask_bit | 126); + header.extend_from_slice(&(len as u16).to_be_bytes()); + } else { + header.push(mask_bit | 127); + header.extend_from_slice(&(len as u64).to_be_bytes()); + } + + if masked { + let mut mask = [0u8; 4]; + rand::thread_rng().fill_bytes(&mut mask); + header.extend_from_slice(&mask); + + for (i, &b) in payload.iter().enumerate() { + header.push(b ^ mask[i % 4]); + } + } else { + header.extend_from_slice(payload); + } + + header +} + +pub fn decode_wss_frame(buffer: &[u8]) -> WssFrameResult { + if buffer.len() < 2 { + return WssFrameResult::Incomplete; + } + let is_masked = (buffer[1] & 0x80) != 0; + let payload_len_7 = (buffer[1] & 0x7F) as usize; + + let (header_len, payload_len) = if payload_len_7 == 126 { + if buffer.len() < 4 { return WssFrameResult::Incomplete; } + (4, u16::from_be_bytes([buffer[2], buffer[3]]) as usize) + } else if payload_len_7 == 127 { + if buffer.len() < 10 { return WssFrameResult::Incomplete; } + (10, u64::from_be_bytes([buffer[2], buffer[3], buffer[4], buffer[5], buffer[6], buffer[7], buffer[8], buffer[9]]) as usize) + } else { + (2, payload_len_7) + }; + + let mask_offset = header_len; + let full_header_len = header_len + if is_masked { 4 } else { 0 }; + let total_frame_len = full_header_len + payload_len; + + if buffer.len() < total_frame_len { + return WssFrameResult::Incomplete; + } + + let mut payload = buffer[full_header_len..total_frame_len].to_vec(); + if is_masked { + let mask = [buffer[mask_offset], buffer[mask_offset+1], buffer[mask_offset+2], buffer[mask_offset+3]]; + for (i, b) in payload.iter_mut().enumerate() { + *b ^= mask[i % 4]; + } + } + + WssFrameResult::Frame { payload, total_len: total_frame_len } +} diff --git a/ostp-core/src/protocol.rs b/ostp-core/src/protocol.rs index b95e204..463186b 100644 --- a/ostp-core/src/protocol.rs +++ b/ostp-core/src/protocol.rs @@ -235,187 +235,195 @@ impl ProtocolMachine { } if self.state == OstpState::Handshaking { - // Wire format: [session_id:4][noise_len:2][noise_payload:N][random_padding:*] - // Extract noise_len to pass exactly the right bytes to snow - if raw_vec.len() < 6 { - return Err(ProtocolError::Framing("handshake too short for length prefix".to_string())); - } - let noise_len = u16::from_be_bytes([raw_vec[4], raw_vec[5]]) as usize; - if raw_vec.len() < 6 + noise_len { - return Err(ProtocolError::Framing(format!( - "handshake truncated: expected {} noise bytes, got {}", - noise_len, raw_vec.len() - 6 - ))); - } - tracing::info!("handle_inbound: raw_vec.len()={}, noise_len={}, raw_vec[0..6]={:?}", raw_vec.len(), noise_len, &raw_vec[0..6]); - - let mut read_out = vec![0_u8; 1024]; - let n = self.noise.read_handshake(&raw_vec[6..6 + noise_len], &mut read_out).map_err(|e| { - ProtocolError::Crypto(format!("noise-read: {:?} (raw_len={}, noise_len={})", e, raw_vec.len(), noise_len)) - })?; - read_out.truncate(n); - - let response = match self.role { - NoiseRole::Responder => { - let mut write_out = vec![0_u8; 1024]; - let out_n = self.noise.write_handshake(&self.handshake_payload, &mut write_out)?; - write_out.truncate(out_n); - Some(self.wrap_datagram_handshake(&write_out)?) - } - NoiseRole::Initiator => None, - }; - - let mut key = [0_u8; 32]; - self.noise.handshake_hash(&mut key)?; - let (send_key, recv_key) = derive_split_keys(&key, self.role); - self.send_cipher = Some(SessionCipher::new(&send_key)); - self.recv_cipher = Some(SessionCipher::new(&recv_key)); - self.state = OstpState::Established; - - let extracted_payload = read_out[..n].to_vec(); - - Ok(ProtocolAction::HandshakePayload(Bytes::from(extracted_payload), response)) + self.handle_handshake_inbound(&raw_vec) } else if self.state == OstpState::Established { - if raw_vec.len() < 12 { - return Err(ProtocolError::Framing("data datagram too short".to_string())); + self.handle_data_inbound(&raw_vec) + } else { + Ok(ProtocolAction::Noop) + } + } + + fn handle_handshake_inbound(&mut self, raw_vec: &[u8]) -> Result { + // Wire format: [session_id:4][noise_len:2][noise_payload:N][random_padding:*] + // Extract noise_len to pass exactly the right bytes to snow + if raw_vec.len() < 6 { + return Err(ProtocolError::Framing("handshake too short for length prefix".to_string())); + } + let noise_len = u16::from_be_bytes([raw_vec[4], raw_vec[5]]) as usize; + if raw_vec.len() < 6 + noise_len { + return Err(ProtocolError::Framing(format!( + "handshake truncated: expected {} noise bytes, got {}", + noise_len, raw_vec.len() - 6 + ))); + } + tracing::info!("handle_inbound: raw_vec.len()={}, noise_len={}, raw_vec[0..6]={:?}", raw_vec.len(), noise_len, &raw_vec[0..6]); + + let mut read_out = vec![0_u8; 1024]; + let n = self.noise.read_handshake(&raw_vec[6..6 + noise_len], &mut read_out).map_err(|e| { + ProtocolError::Crypto(format!("noise-read: {:?} (raw_len={}, noise_len={})", e, raw_vec.len(), noise_len)) + })?; + read_out.truncate(n); + + let response = match self.role { + NoiseRole::Responder => { + let mut write_out = vec![0_u8; 1024]; + let out_n = self.noise.write_handshake(&self.handshake_payload, &mut write_out)?; + write_out.truncate(out_n); + Some(self.wrap_datagram_handshake(&write_out)?) } - let nonce = u64::from_be_bytes(raw_vec[4..12].try_into().unwrap()); - - if nonce < self.expected_recv_nonce { - // Duplicate — the ACK we sent was likely lost or delayed. - tracing::debug!("Duplicate frame nonce={} (expected {}), forcing ACK", nonce, self.expected_recv_nonce); - if let Some(ack_frame) = self.force_build_ack()? { - return Ok(ProtocolAction::SendDatagram(ack_frame)); + NoiseRole::Initiator => None, + }; + + let mut key = [0_u8; 32]; + self.noise.handshake_hash(&mut key)?; + let (send_key, recv_key) = derive_split_keys(&key, self.role); + self.send_cipher = Some(SessionCipher::new(&send_key)); + self.recv_cipher = Some(SessionCipher::new(&recv_key)); + self.state = OstpState::Established; + + let extracted_payload = read_out[..n].to_vec(); + + Ok(ProtocolAction::HandshakePayload(Bytes::from(extracted_payload), response)) + } + + fn handle_data_inbound(&mut self, raw_vec: &[u8]) -> Result { + if raw_vec.len() < 12 { + return Err(ProtocolError::Framing("data datagram too short".to_string())); + } + let nonce = u64::from_be_bytes(raw_vec[4..12].try_into().unwrap()); + + if nonce < self.expected_recv_nonce { + // Duplicate — the ACK we sent was likely lost or delayed. + tracing::debug!("Duplicate frame nonce={} (expected {}), forcing ACK", nonce, self.expected_recv_nonce); + if let Some(ack_frame) = self.force_build_ack()? { + return Ok(ProtocolAction::SendDatagram(ack_frame)); + } + return Ok(ProtocolAction::Noop); + } + + if nonce > self.expected_recv_nonce + self.max_reorder { + tracing::debug!("Frame nonce={} exceeds max reorder window (expected={}, max_gap={}), sending NACK", + nonce, self.expected_recv_nonce, self.max_reorder + ); + if let Ok(nack_frame) = self.build_control_datagram( + 0, + FrameKind::Nack, + Bytes::copy_from_slice(&self.expected_recv_nonce.to_be_bytes()), + ) { + return Ok(ProtocolAction::SendDatagram(nack_frame)); + } + return Ok(ProtocolAction::Noop); + } + + let ciphertext = &raw_vec[12..]; + let cipher = self.recv_cipher.as_ref().ok_or_else(|| { + ProtocolError::State("missing recv cipher".to_string()) + })?; + + let session_id_bytes = self.session_id.to_be_bytes(); + let plaintext = cipher.decrypt(nonce, ciphertext, &session_id_bytes)?; + + let packet = FramedPacket::decode_zero_copy(Bytes::from(plaintext))?; + + let mut outbound_actions = Vec::new(); + + // Fast path processing for Nacks: act immediately, bypass sequence queue + if packet.header.kind == FrameKind::Nack + && packet.payload.len() >= 8 { + let req_nonce = u64::from_be_bytes(packet.payload[..8].try_into().unwrap()); + if let Some(cached_frame) = self.lookup_sent_frame(req_nonce) { + tracing::debug!("NACK received: retransmitting nonce={}", req_nonce); + self.cc.on_loss(cached_frame.len() as u64); + outbound_actions.push(ProtocolAction::SendDatagram(cached_frame)); + } else { + tracing::debug!("NACK received: nonce={} not found in sent_history (evicted)", req_nonce); + // Estimate ~1200 bytes lost for evicted frames + self.cc.on_loss(1200); } - return Ok(ProtocolAction::Noop); } - if nonce > self.expected_recv_nonce + self.max_reorder { - tracing::debug!("Frame nonce={} exceeds max reorder window (expected={}, max_gap={}), sending NACK", - nonce, self.expected_recv_nonce, self.max_reorder - ); - if let Ok(nack_frame) = self.build_control_datagram( - 0, - FrameKind::Nack, - Bytes::copy_from_slice(&self.expected_recv_nonce.to_be_bytes()), - ) { - return Ok(ProtocolAction::SendDatagram(nack_frame)); - } - return Ok(ProtocolAction::Noop); - } + if packet.header.kind == FrameKind::Ack { + let ranges = parse_ack_ranges(&packet.payload)?; + self.drop_acked_frames(&ranges); + } - let ciphertext = &raw_vec[12..]; - let cipher = self.recv_cipher.as_ref().ok_or_else(|| { - ProtocolError::State("missing recv cipher".to_string()) + let action = match packet.header.kind { + FrameKind::Data => { + ProtocolAction::DeliverApp(packet.header.stream_id, packet.payload) + } + FrameKind::Resume => { + // 0-RTT: treat early data as application data + tracing::info!("0-RTT Resume frame received, processing early data"); + ProtocolAction::DeliverApp(packet.header.stream_id, packet.payload) + } + FrameKind::Close => { + tracing::info!("Received Close frame, terminating session"); + self.state = OstpState::Closed; + ProtocolAction::Noop + } + FrameKind::KeepAlive => ProtocolAction::Noop, + _ => ProtocolAction::Noop, + }; + + let mut app_actions = Vec::new(); + + if matches!(packet.header.kind, FrameKind::Data | FrameKind::Close | FrameKind::KeepAlive) { + self.ack_pending = true; + } + + if nonce == self.expected_recv_nonce { + app_actions.push(action); + self.expected_recv_nonce = self.expected_recv_nonce.checked_add(1).ok_or_else(|| { + ProtocolError::Crypto("recv nonce sequence exhausted".to_string()) })?; + self.last_recv_advance = Instant::now(); - let session_id_bytes = self.session_id.to_be_bytes(); - let plaintext = cipher.decrypt(nonce, ciphertext, &session_id_bytes)?; - - let packet = FramedPacket::decode_zero_copy(Bytes::from(plaintext))?; - - let mut outbound_actions = Vec::new(); - - // Fast path processing for Nacks: act immediately, bypass sequence queue - if packet.header.kind == FrameKind::Nack - && packet.payload.len() >= 8 { - let req_nonce = u64::from_be_bytes(packet.payload[..8].try_into().unwrap()); - if let Some(cached_frame) = self.lookup_sent_frame(req_nonce) { - tracing::debug!("NACK received: retransmitting nonce={}", req_nonce); - self.cc.on_loss(cached_frame.len() as u64); - outbound_actions.push(ProtocolAction::SendDatagram(cached_frame)); - } else { - tracing::debug!("NACK received: nonce={} not found in sent_history (evicted)", req_nonce); - // Estimate ~1200 bytes lost for evicted frames - self.cc.on_loss(1200); - } - } - - if packet.header.kind == FrameKind::Ack { - let ranges = parse_ack_ranges(&packet.payload)?; - self.drop_acked_frames(&ranges); - } - - let action = match packet.header.kind { - FrameKind::Data => { - ProtocolAction::DeliverApp(packet.header.stream_id, packet.payload) - } - FrameKind::Resume => { - // 0-RTT: treat early data as application data - tracing::info!("0-RTT Resume frame received, processing early data"); - ProtocolAction::DeliverApp(packet.header.stream_id, packet.payload) - } - FrameKind::Close => { - tracing::info!("Received Close frame, terminating session"); - self.state = OstpState::Closed; - ProtocolAction::Noop - } - FrameKind::KeepAlive => ProtocolAction::Noop, - _ => ProtocolAction::Noop, - }; - - let mut app_actions = Vec::new(); - - if matches!(packet.header.kind, FrameKind::Data | FrameKind::Close | FrameKind::KeepAlive) { - self.ack_pending = true; - } - - if nonce == self.expected_recv_nonce { - app_actions.push(action); + // Drain continuous queue + while let Some(buffered_action) = self.reorder_buffer.remove(&self.expected_recv_nonce) { + app_actions.push(buffered_action); self.expected_recv_nonce = self.expected_recv_nonce.checked_add(1).ok_or_else(|| { ProtocolError::Crypto("recv nonce sequence exhausted".to_string()) })?; - self.last_recv_advance = Instant::now(); - - // Drain continuous queue - while let Some(buffered_action) = self.reorder_buffer.remove(&self.expected_recv_nonce) { - app_actions.push(buffered_action); - self.expected_recv_nonce = self.expected_recv_nonce.checked_add(1).ok_or_else(|| { - ProtocolError::Crypto("recv nonce sequence exhausted".to_string()) - })?; - } - self.last_recv_advance = Instant::now(); - } else { - // Gap detected - if self.reorder_buffer.len() < self.max_reorder_buffer { - self.reorder_buffer.insert(nonce, action); - } else { - tracing::warn!("Reorder buffer full ({}/{}), dropping frame nonce={}", - self.reorder_buffer.len(), self.max_reorder_buffer, nonce - ); - } - - // Rate-limited NACK: send at most once per 30ms to prevent retransmit storms. - // Under high load with natural UDP reordering, sending a NACK per packet - // causes exponential retransmit explosion that saturates the channel. - let nack_cooldown = Duration::from_millis(30); - if self.last_nack_sent.elapsed() >= nack_cooldown { - self.last_nack_sent = Instant::now(); - let nack_payload = self.expected_recv_nonce.to_be_bytes(); - if let Ok(nack_frame) = self.build_control_datagram(0, FrameKind::Nack, Bytes::copy_from_slice(&nack_payload)) { - outbound_actions.push(ProtocolAction::SendDatagram(nack_frame)); - } - } - } - - if let Some(ack_frame) = self.build_ack_if_due()? { - outbound_actions.push(ProtocolAction::SendDatagram(ack_frame)); - } - - // Collate both types of output (application payloads and wire actions like Nacks/Retransmissions) - let mut all_actions = Vec::new(); - all_actions.extend(outbound_actions); - all_actions.extend(app_actions); - - if all_actions.is_empty() { - Ok(ProtocolAction::Noop) - } else if all_actions.len() == 1 { - Ok(all_actions.pop().unwrap()) - } else { - Ok(ProtocolAction::Multiple(all_actions)) } + self.last_recv_advance = Instant::now(); } else { + // Gap detected + if self.reorder_buffer.len() < self.max_reorder_buffer { + self.reorder_buffer.insert(nonce, action); + } else { + tracing::warn!("Reorder buffer full ({}/{}), dropping frame nonce={}", + self.reorder_buffer.len(), self.max_reorder_buffer, nonce + ); + } + + // Rate-limited NACK: send at most once per 30ms to prevent retransmit storms. + // Under high load with natural UDP reordering, sending a NACK per packet + // causes exponential retransmit explosion that saturates the channel. + let nack_cooldown = Duration::from_millis(30); + if self.last_nack_sent.elapsed() >= nack_cooldown { + self.last_nack_sent = Instant::now(); + let nack_payload = self.expected_recv_nonce.to_be_bytes(); + if let Ok(nack_frame) = self.build_control_datagram(0, FrameKind::Nack, Bytes::copy_from_slice(&nack_payload)) { + outbound_actions.push(ProtocolAction::SendDatagram(nack_frame)); + } + } + } + + if let Some(ack_frame) = self.build_ack_if_due()? { + outbound_actions.push(ProtocolAction::SendDatagram(ack_frame)); + } + + // Collate both types of output (application payloads and wire actions like Nacks/Retransmissions) + let mut all_actions = Vec::new(); + all_actions.extend(outbound_actions); + all_actions.extend(app_actions); + + if all_actions.is_empty() { Ok(ProtocolAction::Noop) + } else if all_actions.len() == 1 { + Ok(all_actions.pop().unwrap()) + } else { + Ok(ProtocolAction::Multiple(all_actions)) } } diff --git a/ostp-flutter/lib/main.dart b/ostp-flutter/lib/main.dart index 2ee2592..234670d 100644 --- a/ostp-flutter/lib/main.dart +++ b/ostp-flutter/lib/main.dart @@ -113,6 +113,7 @@ class _HomeScreenState extends State with TickerProviderStateMixin { final transportMode = widget.prefs.getString('transport_mode') ?? 'udp'; final stealthSni = widget.prefs.getString('stealth_sni') ?? 'vk.com'; final stealthPort = widget.prefs.getString('stealth_port') ?? '443'; + final wss = widget.prefs.getBool('wss') ?? false; final mtu = widget.prefs.getString('mtu') ?? '1350'; final muxEnabled = widget.prefs.getBool('mux_enabled') ?? false; final muxSessions = widget.prefs.getString('mux_sessions') ?? '2'; @@ -141,6 +142,7 @@ class _HomeScreenState extends State with TickerProviderStateMixin { "mode": transportMode, "stealth_sni": stealthSni, "stealth_port": int.tryParse(stealthPort) ?? 443, + "wss": wss, }, "multiplex": { "enabled": muxEnabled, @@ -209,6 +211,7 @@ class _HomeScreenState extends State with TickerProviderStateMixin { final transportMode = widget.prefs.getString('transport_mode') ?? 'udp'; final stealthSni = widget.prefs.getString('stealth_sni') ?? 'vk.com'; final stealthPort = widget.prefs.getString('stealth_port') ?? '443'; + final wss = widget.prefs.getBool('wss') ?? false; final mtu = widget.prefs.getString('mtu') ?? '1350'; final muxEnabled = widget.prefs.getBool('mux_enabled') ?? false; final muxSessions = widget.prefs.getString('mux_sessions') ?? '2'; @@ -237,6 +240,7 @@ class _HomeScreenState extends State with TickerProviderStateMixin { "mode": transportMode, "stealth_sni": stealthSni, "stealth_port": int.tryParse(stealthPort) ?? 443, + "wss": wss, }, "multiplex": { "enabled": muxEnabled, @@ -861,7 +865,8 @@ class _SettingsScreenState extends State { bool _obscureKey = true; bool _debugMode = false; - String _transportMode = 'udp'; // 'udp' | 'wss' + bool _wss = false; + String _transportMode = 'udp'; // 'udp' | 'uot' String _tunStack = 'ostp'; // 'system' | 'ostp' bool _muxEnabled = false; late TextEditingController _muxSessionsCtrl; @@ -883,6 +888,7 @@ class _SettingsScreenState extends State { _stealthPortCtrl = TextEditingController(text: widget.prefs.getString('stealth_port') ?? '443'); _pbkCtrl = TextEditingController(text: widget.prefs.getString('pbk') ?? ''); _sidCtrl = TextEditingController(text: widget.prefs.getString('sid') ?? ''); + _wss = widget.prefs.getBool('wss') ?? false; _transportMode = widget.prefs.getString('transport_mode') ?? 'udp'; _tunStack = widget.prefs.getString('tun_stack') ?? 'ostp'; _debugMode = widget.prefs.getBool('debug_mode') ?? false; @@ -921,6 +927,7 @@ class _SettingsScreenState extends State { widget.prefs.setString('ex_ips', _ipsCtrl.text.trim()); widget.prefs.setString('ex_processes', _processesCtrl.text.trim()); widget.prefs.setBool('debug_mode', _debugMode); + widget.prefs.setBool('wss', _wss); widget.prefs.setString('transport_mode', _transportMode); widget.prefs.setString('tun_stack', _tunStack); widget.prefs.setString('stealth_sni', _stealthSniCtrl.text.trim()); @@ -1060,6 +1067,7 @@ class _SettingsScreenState extends State { _stealthSniCtrl.text = uri.queryParameters['sni'] ?? ''; _pbkCtrl.text = uri.queryParameters['pbk'] ?? ''; _sidCtrl.text = uri.queryParameters['sid'] ?? ''; + _wss = uri.queryParameters['wss'] == 'true'; final type = uri.queryParameters['type'] ?? 'udp'; _transportMode = type == 'tcp' || type == 'http' ? 'uot' : 'udp'; _owndns = uri.queryParameters['owndns'] == 'true'; @@ -1221,6 +1229,12 @@ class _SettingsScreenState extends State { ); }), const SizedBox(height: 16), + _buildToggle('WebSocket (WSS)', 'Использовать RFC 6455 (для строгого DPI)', _wss, (val) { + setState(() { + _wss = val; + }); + }), + const SizedBox(height: 16), _buildTextField('Reality PublicKey (pbk)', _pbkCtrl, hint: 'Оставьте пустым для отключения Reality'), _buildTextField('Reality ShortId (sid)', _sidCtrl, hint: 'Опционально (необязательно)'), ], diff --git a/ostp-gui/src-tauri/Cargo.toml b/ostp-gui/src-tauri/Cargo.toml index bd6ca4e..e877eec 100644 --- a/ostp-gui/src-tauri/Cargo.toml +++ b/ostp-gui/src-tauri/Cargo.toml @@ -27,4 +27,5 @@ anyhow = "1" ostp-client = { path = "../../ostp-client" } portable-atomic = "1" json_comments = "0.2" +rand = "0.8" diff --git a/ostp-gui/src-tauri/src/lib.rs b/ostp-gui/src-tauri/src/lib.rs index fff40ba..a140cc3 100644 --- a/ostp-gui/src-tauri/src/lib.rs +++ b/ostp-gui/src-tauri/src/lib.rs @@ -104,6 +104,7 @@ struct InProcessState { struct HelperState { pipe_state: Arc>, cmd_tx: tokio::sync::mpsc::Sender, + token: String, } enum TunnelHandle { @@ -294,7 +295,11 @@ async fn stop_tunnel(state: tauri::State<'_, AppState>) -> Result ).await; } Some(TunnelHandle::Helper(h)) => { - let _ = h.cmd_tx.send("{\"cmd\":\"stop\"}\n".to_string()).await; + let stop_cmd = serde_json::json!({ + "cmd": "stop", + "token": h.token + }).to_string(); + let _ = h.cmd_tx.send(format!("{}\n", stop_cmd)).await; } } Ok(true) @@ -378,8 +383,9 @@ async fn start_tun_via_helper( .output(); } + let auth_token = rand::random::().to_string(); let helper_exe = find_helper_exe().ok_or_else(|| "ostp-tun-helper.exe not found.".to_string())?; - launch_as_admin(&helper_exe).map_err(|e| format!("Failed to launch helper: {}", e))?; + launch_as_admin(&helper_exe, &auth_token).map_err(|e| format!("Failed to launch helper: {}", e))?; tokio::time::sleep(std::time::Duration::from_millis(1500)).await; let socket = tokio::time::timeout(std::time::Duration::from_secs(60), async { @@ -396,7 +402,8 @@ async fn start_tun_via_helper( let mapped = map_to_client_config(raw, "tun"); let start_cmd = serde_json::json!({ "cmd": "start", - "config": serde_json::to_string(&mapped).unwrap_or_default() + "config": serde_json::to_string(&mapped).unwrap_or_default(), + "token": auth_token }).to_string(); let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::channel::(16); @@ -434,7 +441,7 @@ async fn start_tun_via_helper( state_for_task.lock().await.connection_state = 0; }); - guard.tunnel = Some(TunnelHandle::Helper(HelperState { pipe_state, cmd_tx })); + guard.tunnel = Some(TunnelHandle::Helper(HelperState { pipe_state, cmd_tx, token: auth_token })); Ok(true) } @@ -484,26 +491,28 @@ fn find_helper_exe() -> Option { } #[cfg(target_os = "windows")] -fn launch_as_admin(exe: &std::path::PathBuf) -> anyhow::Result<()> { +fn launch_as_admin(exe: &std::path::PathBuf, token: &str) -> anyhow::Result<()> { use std::ffi::OsStr; use std::os::windows::ffi::OsStrExt; use std::ptr::null_mut; use std::path::Path; let exe_wstr: Vec = exe.as_os_str().encode_wide().chain(Some(0)).collect(); let verb_wstr: Vec = OsStr::new("runas").encode_wide().chain(Some(0)).collect(); + let params_str = format!("--token {}", token); + let params_wstr: Vec = OsStr::new(¶ms_str).encode_wide().chain(Some(0)).collect(); #[link(name = "shell32")] extern "system" { fn ShellExecuteW(h: *mut std::ffi::c_void, op: *const u16, f: *const u16, p: *const u16, d: *const u16, s: i32) -> isize; } // Use the GUI executable's directory as the working directory so dependencies are found let cwd_path = std::env::current_exe().unwrap_or_else(|_| std::path::PathBuf::from(".")); let dir_wstr: Vec = cwd_path.parent().unwrap_or(Path::new(".")).as_os_str().encode_wide().chain(Some(0)).collect(); - let ret = unsafe { ShellExecuteW(null_mut(), verb_wstr.as_ptr(), exe_wstr.as_ptr(), null_mut(), dir_wstr.as_ptr(), 0) }; + let ret = unsafe { ShellExecuteW(null_mut(), verb_wstr.as_ptr(), exe_wstr.as_ptr(), params_wstr.as_ptr(), dir_wstr.as_ptr(), 0) }; if ret <= 32 { anyhow::bail!("UAC denied or helper missing."); } Ok(()) } #[cfg(not(target_os = "windows"))] -fn launch_as_admin(_exe: &PathBuf) -> Result<()> { anyhow::bail!("Windows only."); } +fn launch_as_admin(_exe: &PathBuf, _token: &str) -> Result<()> { anyhow::bail!("Windows only."); } #[cfg_attr(mobile, tauri::mobile_entry_point)] pub fn run() { diff --git a/ostp-gui/src/index.html b/ostp-gui/src/index.html index f153aec..718c817 100644 --- a/ostp-gui/src/index.html +++ b/ostp-gui/src/index.html @@ -205,6 +205,19 @@ +
+
+ WebSocket (WSS) + Use RFC 6455 framing for strict DPI bypass +
+ +
+
diff --git a/ostp-gui/src/main.js b/ostp-gui/src/main.js index 5d55d91..47fc5b3 100644 --- a/ostp-gui/src/main.js +++ b/ostp-gui/src/main.js @@ -45,6 +45,7 @@ const inOwndns = $('in-owndns'); const groupCustomDns = $('group-custom-dns'); const inTransport = $('in-transport'); const inSni = $('in-stealth-sni'); +const inWss = $('in-wss'); const inPbk = $('in-pbk'); const inSid = $('in-sid'); const inMtu = $('in-mtu'); @@ -239,6 +240,7 @@ async function loadConfigIntoForm() { inSocks.value = c.socks5_bind || '127.0.0.1:1088'; inTransport.value = c.transport?.mode || 'udp'; inSni.value = c.transport?.stealth_sni || ''; + inWss.checked = !!c.transport?.wss; inPbk.value = c.reality?.pbk || ''; inSid.value = c.reality?.sid || ''; inMtu.value = c.mtu || ''; @@ -292,6 +294,7 @@ async function handleSave(silent = false) { rawConfig.transport = rawConfig.transport || {}; rawConfig.transport.mode = inTransport.value; rawConfig.transport.stealth_sni = inSni.value.trim() || undefined; + rawConfig.transport.wss = inWss.checked; const pbk = inPbk.value.trim(); if (pbk) { diff --git a/ostp-jni/Cargo.toml b/ostp-jni/Cargo.toml index bb1e0a3..1ce9757 100644 --- a/ostp-jni/Cargo.toml +++ b/ostp-jni/Cargo.toml @@ -16,7 +16,6 @@ ostp-core = { path = "../ostp-core" } ostp-client = { path = "../ostp-client" } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -lazy_static = "1.4" portable-atomic = { workspace = true } tracing-subscriber = "0.3.23" tracing.workspace = true diff --git a/ostp-jni/src/lib.rs b/ostp-jni/src/lib.rs index 2ceda28..278afde 100644 --- a/ostp-jni/src/lib.rs +++ b/ostp-jni/src/lib.rs @@ -1,7 +1,7 @@ use jni::objects::{JClass, JString}; use jni::sys::{jboolean, jstring}; use jni::JNIEnv; -use lazy_static::lazy_static; + use std::collections::VecDeque; use std::sync::{atomic::Ordering, Arc, Mutex}; use tokio::runtime::Runtime; @@ -12,13 +12,19 @@ use ostp_client::tunnel; use ostp_client::app::{BridgeCommand, UiEvent}; use std::io::Write; +static LOG_TX: std::sync::OnceLock> = std::sync::OnceLock::new(); + struct JniLogWriter; impl Write for JniLogWriter { fn write(&mut self, buf: &[u8]) -> std::io::Result { let s = String::from_utf8_lossy(buf).trim().to_string(); if !s.is_empty() { - add_log(s); + if let Some(tx) = LOG_TX.get() { + let _ = tx.send(s); + } else { + add_log(s); + } } Ok(buf.len()) } @@ -38,6 +44,14 @@ static TRACING_INIT: std::sync::Once = std::sync::Once::new(); fn init_tracing() { TRACING_INIT.call_once(|| { + let (tx, rx) = std::sync::mpsc::channel::(); + LOG_TX.set(tx).ok(); + std::thread::spawn(move || { + while let Ok(text) = rx.recv() { + add_log(text); + } + }); + let subscriber = tracing_subscriber::fmt() .with_writer(JniLogWriter) .with_ansi(false) @@ -54,19 +68,23 @@ struct SdkState { cmd_tx: Option>, } -lazy_static! { - static ref STATE: Mutex = Mutex::new(SdkState { - runtime: None, - shutdown_tx: None, - metrics: None, - tun_child: None, - cmd_tx: None, - }); - static ref LOGS: Mutex> = Mutex::new(VecDeque::new()); - static ref JVM: Mutex> = Mutex::new(None); - static ref CLASS_REF: Mutex> = Mutex::new(None); +impl SdkState { + const fn new() -> Self { + Self { + runtime: None, + shutdown_tx: None, + metrics: None, + tun_child: None, + cmd_tx: None, + } + } } +static STATE: Mutex = Mutex::new(SdkState::new()); +static LOGS: Mutex> = Mutex::new(VecDeque::new()); +static JVM: Mutex> = Mutex::new(None); +static CLASS_REF: Mutex> = Mutex::new(None); + fn add_log(text: String) { if let Ok(mut guard) = LOGS.lock() { if guard.len() >= 1000 { diff --git a/ostp-server/src/lib.rs b/ostp-server/src/lib.rs index 92683d7..a710591 100644 --- a/ostp-server/src/lib.rs +++ b/ostp-server/src/lib.rs @@ -470,7 +470,6 @@ async fn run_server_loop( let mut last_empty_app_log = Instant::now() - Duration::from_secs(10); let mut peer_last_seen: HashMap = HashMap::new(); let mut peer_available: HashMap = HashMap::new(); - let peer_timeout = Duration::from_secs(45); loop { tokio::select! { @@ -489,74 +488,13 @@ async fn run_server_loop( } received = udp_rx.recv() => { if let Some((packet, peer)) = received { - let size = packet.len(); - match dispatcher.on_datagram(peer, packet) { - Ok(DispatchOutcome::Unauthorized) => { - let _ = ui_event_tx.send(UiEvent::UnauthorizedProbe { peer: peer.ip(), bytes: size }); - } - Ok(DispatchOutcome::Accepted { responses, app_payloads, peer_addr }) => { - let peer_ip = peer_addr.ip(); - let now = Instant::now(); - peer_last_seen.insert(peer_ip, now); - if !peer_available.get(&peer_ip).copied().unwrap_or(false) { - peer_available.insert(peer_ip, true); - let is_tcp = tcp_map.read().await.contains_key(&peer_addr); - let proto = if is_tcp { "TCP (UoT)" } else { "UDP" }; - let _ = ui_event_tx.send(UiEvent::Log(format!("Client {peer_ip} connected via {proto}"))); - } - - if app_payloads.is_empty() && now.duration_since(last_empty_app_log) > Duration::from_secs(5) { - last_empty_app_log = now; - let _ = ui_event_tx.send(UiEvent::Log(format!( - "Accepted datagrams from {peer_ip} with no app payloads (responses={})", - responses.len() - ))); - } - let _ = ui_event_tx.send(UiEvent::Rx { peer: peer_ip, bytes: size }); - - for resp in responses { - let resp_len = resp.len(); - let mut sent_tcp = false; - { - let map = tcp_map.read().await; - if let Some(tx) = map.get(&peer_addr) { - let _ = tx.try_send(resp.clone()); - sent_tcp = true; - } - } - if !sent_tcp { - let _ = socket.send_to(&resp, peer_addr).await?; - } - let _ = ui_event_tx.send(UiEvent::Tx { peer: peer_ip, bytes: resp_len }); - } - - for (session_id, stream_id, payload) in app_payloads { - let _ = ui_event_tx.send(UiEvent::Log(format!( - "Deliver app payload sid={session_id} stream={stream_id} bytes={}", - payload.len() - ))); - relay::handle_relay_message( - peer_addr, - session_id, - stream_id, - payload, - &mut dispatcher, - &socket, - &mut remotes, - &ui_event_tx, - stream_tx.clone(), - udp_reply_tx.clone(), - connect_tx.clone(), - outbound.clone(), - dns_server.clone(), - debug, - &tcp_map, - ).await?; - } - } - Err(err) => { - let _ = ui_event_tx.send(UiEvent::Log(format!("Protocol error for {peer}: {err}"))); - } + if let Err(e) = handle_udp_packet( + packet, peer, &mut dispatcher, &tcp_map, &socket, &mut remotes, &ui_event_tx, + stream_tx.clone(), udp_reply_tx.clone(), connect_tx.clone(), + outbound.clone(), dns_server.clone(), debug, + &mut peer_last_seen, &mut peer_available, &mut last_empty_app_log + ).await { + tracing::error!("handle_udp_packet error: {}", e); } } } @@ -596,41 +534,11 @@ async fn run_server_loop( } } _ = retransmit_tick.tick() => { - let now = Instant::now(); - for (peer_ip, last_seen) in peer_last_seen.iter() { - let is_available = peer_available.get(peer_ip).copied().unwrap_or(false); - if is_available && now.duration_since(*last_seen) > peer_timeout { - peer_available.insert(*peer_ip, false); - let _ = ui_event_tx.send(UiEvent::Log(format!("Client {peer_ip} disconnected (timeout)"))); - } - } - let (frames, dropped_sessions) = dispatcher.on_tick(); - for (frame, peer_addr) in frames { - let mut sent_tcp = false; - { - let map = tcp_map.read().await; - if let Some(tx) = map.get(&peer_addr) { - let _ = tx.try_send(frame.clone()); - sent_tcp = true; - } - } - if !sent_tcp { - let _ = socket.send_to(&frame, peer_addr).await?; - } - } - for sid in dropped_sessions { - let _ = ui_event_tx.send(UiEvent::Log(format!("Session {sid} expired, releasing resources"))); - let mut streams_to_cancel = Vec::new(); - for &(session_id, stream_id) in remotes.keys() { - if session_id == sid { - streams_to_cancel.push((session_id, stream_id)); - } - } - for key in streams_to_cancel { - if let Some(state) = remotes.remove(&key) { - let _ = state.cancel_tx.try_send(()); - } - } + if let Err(e) = handle_tick( + &mut dispatcher, &tcp_map, &socket, &mut remotes, &ui_event_tx, + &mut peer_last_seen, &mut peer_available + ).await { + tracing::error!("handle_tick error: {}", e); } } } @@ -638,3 +546,142 @@ async fn run_server_loop( Ok(()) } + +async fn handle_udp_packet( + packet: Bytes, + peer: std::net::SocketAddr, + dispatcher: &mut Dispatcher, + tcp_map: &std::sync::Arc>>>, + socket: &std::sync::Arc, + remotes: &mut HashMap<(u32, u16), RemoteState>, + ui_event_tx: &mpsc::UnboundedSender, + stream_tx: mpsc::UnboundedSender<(u32, u16, Vec)>, + udp_reply_tx: mpsc::UnboundedSender<(u32, u16, String, Vec)>, + connect_tx: mpsc::UnboundedSender<(u32, u16, String, Result<(tokio::net::tcp::OwnedWriteHalf, mpsc::Sender<()>), String>)>, + outbound: Option, + dns_server: std::sync::Arc, + debug: bool, + peer_last_seen: &mut HashMap, + peer_available: &mut HashMap, + last_empty_app_log: &mut Instant, +) -> Result<()> { + let size = packet.len(); + match dispatcher.on_datagram(peer, packet) { + Ok(DispatchOutcome::Unauthorized) => { + let _ = ui_event_tx.send(UiEvent::UnauthorizedProbe { peer: peer.ip(), bytes: size }); + } + Ok(DispatchOutcome::Accepted { responses, app_payloads, peer_addr }) => { + let peer_ip = peer_addr.ip(); + let now = Instant::now(); + peer_last_seen.insert(peer_ip, now); + if !peer_available.get(&peer_ip).copied().unwrap_or(false) { + peer_available.insert(peer_ip, true); + let is_tcp = tcp_map.read().await.contains_key(&peer_addr); + let proto = if is_tcp { "TCP (UoT)" } else { "UDP" }; + let _ = ui_event_tx.send(UiEvent::Log(format!("Client {peer_ip} connected via {proto}"))); + } + + if app_payloads.is_empty() && now.duration_since(*last_empty_app_log) > Duration::from_secs(5) { + *last_empty_app_log = now; + let _ = ui_event_tx.send(UiEvent::Log(format!( + "Accepted datagrams from {peer_ip} with no app payloads (responses={})", + responses.len() + ))); + } + let _ = ui_event_tx.send(UiEvent::Rx { peer: peer_ip, bytes: size }); + + for resp in responses { + let resp_len = resp.len(); + let mut sent_tcp = false; + { + let map = tcp_map.read().await; + if let Some(tx) = map.get(&peer_addr) { + let _ = tx.try_send(resp.clone()); + sent_tcp = true; + } + } + if !sent_tcp { + let _ = socket.send_to(&resp, peer_addr).await?; + } + let _ = ui_event_tx.send(UiEvent::Tx { peer: peer_ip, bytes: resp_len }); + } + + for (session_id, stream_id, payload) in app_payloads { + let _ = ui_event_tx.send(UiEvent::Log(format!( + "Deliver app payload sid={session_id} stream={stream_id} bytes={}", + payload.len() + ))); + relay::handle_relay_message( + peer_addr, + session_id, + stream_id, + payload, + dispatcher, + socket, + remotes, + ui_event_tx, + stream_tx.clone(), + udp_reply_tx.clone(), + connect_tx.clone(), + outbound.clone(), + dns_server.clone(), + debug, + tcp_map, + ).await?; + } + } + Err(err) => { + let _ = ui_event_tx.send(UiEvent::Log(format!("Protocol error for {peer}: {err}"))); + } + } + Ok(()) +} + +async fn handle_tick( + dispatcher: &mut Dispatcher, + tcp_map: &std::sync::Arc>>>, + socket: &std::sync::Arc, + remotes: &mut HashMap<(u32, u16), RemoteState>, + ui_event_tx: &mpsc::UnboundedSender, + peer_last_seen: &mut HashMap, + peer_available: &mut HashMap, +) -> Result<()> { + let now = Instant::now(); + let peer_timeout = Duration::from_secs(45); + for (peer_ip, last_seen) in peer_last_seen.iter() { + let is_available = peer_available.get(peer_ip).copied().unwrap_or(false); + if is_available && now.duration_since(*last_seen) > peer_timeout { + peer_available.insert(*peer_ip, false); + let _ = ui_event_tx.send(UiEvent::Log(format!("Client {peer_ip} disconnected (timeout)"))); + } + } + let (frames, dropped_sessions) = dispatcher.on_tick(); + for (frame, peer_addr) in frames { + let mut sent_tcp = false; + { + let map = tcp_map.read().await; + if let Some(tx) = map.get(&peer_addr) { + let _ = tx.try_send(frame.clone()); + sent_tcp = true; + } + } + if !sent_tcp { + let _ = socket.send_to(&frame, peer_addr).await?; + } + } + for sid in dropped_sessions { + let _ = ui_event_tx.send(UiEvent::Log(format!("Session {sid} expired, releasing resources"))); + let mut streams_to_cancel = Vec::new(); + for &(session_id, stream_id) in remotes.keys() { + if session_id == sid { + streams_to_cancel.push((session_id, stream_id)); + } + } + for key in streams_to_cancel { + if let Some(state) = remotes.remove(&key) { + let _ = state.cancel_tx.try_send(()); + } + } + } + Ok(()) +} diff --git a/ostp-server/src/transport/uot.rs b/ostp-server/src/transport/uot.rs index 2ecb9ae..05011cc 100644 --- a/ostp-server/src/transport/uot.rs +++ b/ostp-server/src/transport/uot.rs @@ -9,6 +9,7 @@ use std::time::{SystemTime, UNIX_EPOCH}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::sync::{mpsc, RwLock}; use tracing::info; +use ostp_core::framing::wss::{encode_wss_frame, decode_wss_frame, WssFrameResult}; pub async fn handle_tcp_connection( mut stream: S, @@ -40,10 +41,14 @@ where let headers_str = String::from_utf8_lossy(&buf[..header_len]); // Fast-fail scanner bots - if !headers_str.starts_with("GET /stream HTTP/1.1\r\n") { + let wss = if headers_str.starts_with("GET /wss HTTP/1.1\r\n") { + true + } else if headers_str.starts_with("GET /stream HTTP/1.1\r\n") { + false + } else { send_404(&mut stream).await?; anyhow::bail!("invalid request line"); - } + }; // Extract Authorization or Cookie for signature let mut signature_base64 = None; @@ -109,9 +114,13 @@ where anyhow::bail!("unauthorized (invalid HMAC)"); } - // Reply 101 Switching Protocols - let response = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nX-Ostp-Server: 1\r\n\r\n"; - stream.write_all(response.as_bytes()).await?; + if wss { + let response = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nX-Ostp-Server: 1\r\n\r\n"; + stream.write_all(response.as_bytes()).await?; + } else { + let response = "HTTP/1.1 200 OK\r\nX-Ostp-Server: 1\r\nContent-Type: application/octet-stream\r\n\r\n"; + stream.write_all(response.as_bytes()).await?; + } info!("UoT client authenticated from {}", peer_addr); @@ -132,11 +141,16 @@ where let tcp_map_clone = tcp_map.clone(); let writer_task = tokio::spawn(async move { while let Some(packet) = rx.recv().await { - let mut out = BytesMut::with_capacity(2 + packet.len()); - out.put_u16(packet.len() as u16); - out.put_slice(&packet); - if write_half.write_all(&out).await.is_err() { - break; + if wss { + let header = encode_wss_frame(&packet, false); // Server sends unmasked WSS frames + if write_half.write_all(&header).await.is_err() { break; } + } else { + let mut out = BytesMut::with_capacity(2 + packet.len()); + out.put_u16(packet.len() as u16); + out.put_slice(&packet); + if write_half.write_all(&out).await.is_err() { + break; + } } } // Cleanup on writer exit @@ -146,35 +160,57 @@ where // Reader loop let mut buffer = BytesMut::from(leftover); loop { - while buffer.len() < 2 { - let mut temp = [0u8; 1024]; - match read_half.read(&mut temp).await { - Ok(0) | Err(_) => { - writer_task.abort(); - tcp_map.write().await.remove(&peer_addr); - return Ok(()); + if wss { + match decode_wss_frame(&buffer) { + WssFrameResult::Incomplete => { + let mut temp = [0u8; 1024]; + match read_half.read(&mut temp).await { + Ok(0) | Err(_) => { + writer_task.abort(); + tcp_map.write().await.remove(&peer_addr); + return Ok(()); + } + Ok(n) => buffer.extend_from_slice(&temp[..n]), + } + } + WssFrameResult::Frame { payload, total_len } => { + let _ = buffer.split_to(total_len); + if udp_tx.send((Bytes::from(payload), peer_addr)).await.is_err() { + break; + } + } } - Ok(n) => buffer.extend_from_slice(&temp[..n]), - } - } - - let len = u16::from_be_bytes([buffer[0], buffer[1]]) as usize; - - while buffer.len() < 2 + len { - let mut temp = [0u8; 1024]; - match read_half.read(&mut temp).await { - Ok(0) | Err(_) => { - writer_task.abort(); - tcp_map.write().await.remove(&peer_addr); - return Ok(()); + } else { + while buffer.len() < 2 { + let mut temp = [0u8; 1024]; + match read_half.read(&mut temp).await { + Ok(0) | Err(_) => { + writer_task.abort(); + tcp_map.write().await.remove(&peer_addr); + return Ok(()); + } + Ok(n) => buffer.extend_from_slice(&temp[..n]), } - Ok(n) => buffer.extend_from_slice(&temp[..n]), } - } - - let packet = buffer.split_to(2 + len); - if udp_tx.send((Bytes::from(packet[2..].to_vec()), peer_addr)).await.is_err() { - break; + + let len = u16::from_be_bytes([buffer[0], buffer[1]]) as usize; + + while buffer.len() < 2 + len { + let mut temp = [0u8; 1024]; + match read_half.read(&mut temp).await { + Ok(0) | Err(_) => { + writer_task.abort(); + tcp_map.write().await.remove(&peer_addr); + return Ok(()); + } + Ok(n) => buffer.extend_from_slice(&temp[..n]), + } + } + + let packet = buffer.split_to(2 + len); + if udp_tx.send((Bytes::from(packet[2..].to_vec()), peer_addr)).await.is_err() { + break; + } } } diff --git a/ostp-tun-helper/src/main.rs b/ostp-tun-helper/src/main.rs index aff6e6d..635149c 100644 --- a/ostp-tun-helper/src/main.rs +++ b/ostp-tun-helper/src/main.rs @@ -27,8 +27,8 @@ const BIND_ADDR: &str = "127.0.0.1:53211"; #[derive(Deserialize)] #[serde(tag = "cmd", rename_all = "lowercase")] enum GuiCmd { - Start { config: String }, - Stop, + Start { config: String, token: String }, + Stop { token: String }, } #[derive(Serialize)] @@ -54,15 +54,29 @@ async fn main() -> Result<()> { } } + let mut expected_token = String::new(); + let args: Vec = std::env::args().collect(); + for i in 1..args.len() { + if args[i] == "--token" && i + 1 < args.len() { + expected_token = args[i + 1].clone(); + } + } + log_to_file("Helper started (TCP mode)"); - if let Err(e) = run_server().await { + + if expected_token.is_empty() { + log_to_file("FATAL: --token argument is required for security. Unauthorized access denied."); + return Err(anyhow::anyhow!("--token argument is required")); + } + + if let Err(e) = run_server(expected_token).await { log_to_file(&format!("Fatal error: {}", e)); } log_to_file("Helper exiting"); Ok(()) } -async fn run_server() -> Result<()> { +async fn run_server(expected_token: String) -> Result<()> { let state = Arc::new(Mutex::new(TunnelState { shutdown_tx: None, metrics: None, @@ -127,7 +141,12 @@ async fn run_server() -> Result<()> { }; match cmd { - GuiCmd::Start { config } => { + GuiCmd::Start { config, token } => { + if token != expected_token { + log_to_file("Received START command with invalid token"); + send_msg(HelperMsg::Error { message: "Invalid authorization token".to_string() }); + continue; + } log_to_file("Received START command"); { let mut st = state.lock().await; @@ -202,7 +221,12 @@ async fn run_server() -> Result<()> { send_msg(HelperMsg::Status { value: 1 }); } - GuiCmd::Stop => { + GuiCmd::Stop { token } => { + if token != expected_token { + log_to_file("Received STOP command with invalid token"); + send_msg(HelperMsg::Error { message: "Invalid authorization token".to_string() }); + continue; + } log_to_file("Received STOP command"); let mut st = state.lock().await; if let Some(tx) = st.shutdown_tx.take() { diff --git a/ostp/src/main.rs b/ostp/src/main.rs index f673600..18165a3 100644 --- a/ostp/src/main.rs +++ b/ostp/src/main.rs @@ -72,6 +72,7 @@ fn parse_ostp_link(link: &str) -> Result { let mut transport_mode = String::from("udp"); let mut tun_enabled = false; let mut tun_dns = None; + let mut wss_enabled = false; for (k, v) in parsed.query_pairs() { match k.as_ref() { @@ -83,6 +84,7 @@ fn parse_ostp_link(link: &str) -> Result { "type" => transport_mode = v.into_owned(), "tun" => tun_enabled = v == "true", "dns" => tun_dns = Some(v.into_owned()), + "wss" => wss_enabled = v == "true", _ => {} } } @@ -95,6 +97,7 @@ fn parse_ostp_link(link: &str) -> Result { mode: Some(transport_mode), stealth_sni: Some(sni.clone()), stealth_port: Some(443), + wss: Some(wss_enabled), }), socks5_bind: Some("127.0.0.1:1088".to_string()), tun: Some(TunConfig { @@ -331,6 +334,7 @@ struct TransportConfigRaw { mode: Option, stealth_sni: Option, stealth_port: Option, + wss: Option, } #[derive(Debug, Deserialize, Serialize, Clone)] @@ -727,7 +731,8 @@ async fn run_app() -> Result<()> { "transport": {{ "mode": "udp", "stealth_sni": "www.microsoft.com", - "stealth_port": 443 + "stealth_port": 443, + "wss": false }}, "mux": {{ @@ -1098,6 +1103,7 @@ async fn run_client_directly(client_cfg: ClientConfig) -> Result<()> { mode: client_cfg.transport.as_ref().and_then(|t| t.mode.clone()).unwrap_or_else(|| "udp".to_string()), stealth_sni: client_cfg.transport.as_ref().and_then(|t| t.stealth_sni.clone()).unwrap_or_else(|| "microsoft.com".to_string()), stealth_port: client_cfg.transport.as_ref().and_then(|t| t.stealth_port).unwrap_or(443), + wss: client_cfg.transport.as_ref().and_then(|t| t.wss).unwrap_or(false), }, dns_server: client_cfg.tun.as_ref().and_then(|t| t.dns.clone()), };