From 29e9ef739c95c8f31b93ed883d6c99662837d979 Mon Sep 17 00:00:00 2001 From: ospab Date: Wed, 3 Jun 2026 02:06:06 +0300 Subject: [PATCH] Refactor: Phase 1 and 2 - Async architecture, JNI fixes, SmolTCP data races, and Tunnel optimizations --- app-icon.svg | 15 + netstack-smoltcp/src/device.rs | 14 +- netstack-smoltcp/src/runner.rs | 9 +- ostp-client/Cargo.toml | 1 + ostp-client/src/bridge.rs | 1092 +++++++++++----------- ostp-client/src/config.rs | 2 +- ostp-client/src/sysproxy.rs | 6 +- ostp-client/src/tunnel/exclusion.rs | 134 +++ ostp-client/src/tunnel/mod.rs | 4 + ostp-client/src/tunnel/native_handler.rs | 60 +- ostp-client/src/tunnel/process_lookup.rs | 142 +++ ostp-client/src/tunnel/proxy.rs | 136 +-- ostp-client/src/tunnel/sni_sniff.rs | 73 ++ ostp-core/src/congestion.rs | 91 +- ostp-core/src/protocol.rs | 8 +- ostp-core/src/relay.rs | 2 +- ostp-flutter/devtools_options.yaml | 3 + ostp-flutter/pubspec.lock | 80 ++ ostp-flutter/pubspec.yaml | 2 + ostp-gui/src-tauri/Cargo.lock | 5 +- ostp-gui/src-tauri/src/lib.rs | 59 +- ostp-gui/src/main.js | 11 +- ostp-jni/OstpClientSdk.kt | 48 +- ostp-jni/src/lib.rs | 74 +- ostp-server/src/api.rs | 38 +- ostp-server/src/dispatcher.rs | 26 +- ostp-server/src/lib.rs | 10 +- ostp-server/src/relay_node.rs | 6 +- ostp-tun-helper/src/main.rs | 134 ++- refactor.py | 658 +++++++++++++ 30 files changed, 2079 insertions(+), 864 deletions(-) create mode 100644 app-icon.svg create mode 100644 ostp-client/src/tunnel/exclusion.rs create mode 100644 ostp-client/src/tunnel/process_lookup.rs create mode 100644 ostp-client/src/tunnel/sni_sniff.rs create mode 100644 ostp-flutter/devtools_options.yaml create mode 100644 refactor.py diff --git a/app-icon.svg b/app-icon.svg new file mode 100644 index 0000000..24d1384 --- /dev/null +++ b/app-icon.svg @@ -0,0 +1,15 @@ + + + + + + + + + + + + + + + diff --git a/netstack-smoltcp/src/device.rs b/netstack-smoltcp/src/device.rs index e9cabbf..9369910 100644 --- a/netstack-smoltcp/src/device.rs +++ b/netstack-smoltcp/src/device.rs @@ -16,6 +16,7 @@ pub(super) struct VirtualDevice { in_buf: UnboundedReceiver>, out_buf: Sender, mtu: usize, + cached_packet: Option>, } impl VirtualDevice { @@ -31,6 +32,7 @@ impl VirtualDevice { in_buf: iface_ingress_rx, out_buf: iface_egress_tx, mtu, + cached_packet: None, }, iface_ingress_tx, iface_ingress_tx_avail, @@ -43,12 +45,18 @@ impl Device for VirtualDevice { type TxToken<'a> = VirtualTxToken<'a>; fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { - let Ok(buffer) = self.in_buf.try_recv() else { - self.in_buf_avail.store(false, Ordering::Release); - return None; + let buffer = if let Some(buf) = self.cached_packet.take() { + buf + } else { + let Ok(buf) = self.in_buf.try_recv() else { + self.in_buf_avail.store(false, Ordering::Release); + return None; + }; + buf }; let Ok(permit) = self.out_buf.try_reserve() else { + self.cached_packet = Some(buffer); self.in_buf_avail.store(false, Ordering::Release); return None; }; diff --git a/netstack-smoltcp/src/runner.rs b/netstack-smoltcp/src/runner.rs index 63b9675..dedf7e8 100644 --- a/netstack-smoltcp/src/runner.rs +++ b/netstack-smoltcp/src/runner.rs @@ -12,23 +12,24 @@ use std::{ /// require two sets of API interfaces in single-threaded and multi-threaded. /// /// [BoxFuture in crate futures utils]: https://docs.rs/futures-util/latest/futures_util/future/type.BoxFuture.html -pub struct BoxFuture<'a, T>(Pin + 'a>>); +pub struct BoxFuture<'a, T>(Pin + Send + 'a>>); impl<'a, T> BoxFuture<'a, T> { pub fn new(f: F) -> BoxFuture<'a, T> where - F: IntoFuture + 'a, + F: IntoFuture + Send + 'a, + F::IntoFuture: Send + 'a, { BoxFuture(Box::pin(f.into_future())) } #[allow(unused)] - pub fn wrap(f: Pin + 'a>>) -> BoxFuture<'a, T> { + pub fn wrap(f: Pin + Send + 'a>>) -> BoxFuture<'a, T> { BoxFuture(f) } } -unsafe impl Send for BoxFuture<'_, T> {} + impl Future for BoxFuture<'_, T> { type Output = T; diff --git a/ostp-client/Cargo.toml b/ostp-client/Cargo.toml index 6d63660..dd6c83f 100644 --- a/ostp-client/Cargo.toml +++ b/ostp-client/Cargo.toml @@ -29,3 +29,4 @@ libc = "0.2.186" x25519-dalek = "2.0.1" chacha20poly1305.workspace = true hex = "0.4.3" +winapi = { version = "0.3.9", features = ["iphlpapi", "tcpmib", "processthreadsapi", "psapi", "handleapi", "winerror", "minwindef", "winnt"] } diff --git a/ostp-client/src/bridge.rs b/ostp-client/src/bridge.rs index 852f8c3..602df97 100644 --- a/ostp-client/src/bridge.rs +++ b/ostp-client/src/bridge.rs @@ -117,6 +117,7 @@ impl Bridge { }) } + pub async fn run( mut self, tx: mpsc::Sender, @@ -137,7 +138,7 @@ impl Bridge { let mut sessions_opt: Option> = None; let mut udp_rx_opt: Option> = None; - let mut _proxy_guard: Option = None; + let mut proxy_guard: Option = None; let mut stream_map: std::collections::HashMap = std::collections::HashMap::new(); loop { @@ -147,7 +148,11 @@ impl Bridge { if *shutdown.borrow() { self.running = false; self.metrics.connection_state.store(0, Ordering::Relaxed); - _proxy_guard = None; + proxy_guard = None; + sessions_opt = None; + udp_rx_opt = None; + stream_map.clear(); + self.reset_proxy_streams(&tx, &proxy_tx, "manual stop"); break; } } @@ -157,286 +162,11 @@ impl Bridge { None => std::future::pending().await, } }, if self.running => { - match udp_msg { - Some((session_index, inbound)) => { - self.metrics.bytes_recv.fetch_add(inbound.len() as u64, Ordering::Relaxed); - self.last_valid_recv = Instant::now(); - if let Some(sessions) = sessions_opt.as_mut() { - if session_index < sessions.len() { - let session = &mut sessions[session_index]; - let initial_action = match session.machine.on_event(OstpEvent::Inbound(inbound)) { - Ok(a) => a, - Err(e) => { - let _ = tx.send(UiEvent::Log(format!("Protocol decrypt error: {e}"))).await; - tracing::warn!("Inbound protocol error (session {}): {}", session_index, e); - continue; - } - }; - - let mut actions_queue = std::collections::VecDeque::new(); - actions_queue.push_back(initial_action); - - while let Some(current_action) = actions_queue.pop_front() { - match current_action { - ProtocolAction::Multiple(nested) => { - for a in nested { - actions_queue.push_back(a); - } - } - ProtocolAction::DeliverApp(stream_id, dec_payload) => { - match RelayMessage::decode(&dec_payload) { - Ok(relay_msg) => { - match relay_msg { - RelayMessage::ConnectOk => { - let _ = tx.send(UiEvent::Log(format!("Relay CONNECT OK stream_id={stream_id}"))).await; - let _ = proxy_tx.send((stream_id, ProxyToClientMsg::ConnectOk)); - } - RelayMessage::Data(data) => { - let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Data(Bytes::from(data)))); - } - RelayMessage::Close => { - let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Close)); - } - RelayMessage::Error(msg) => { - let _ = tx.send(UiEvent::Log(format!("Relay error for stream {stream_id}: {msg}"))).await; - let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error(msg))); - } - RelayMessage::Pong(ts) => { - let now = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64; - self.last_rtt_ms = now.saturating_sub(ts) as f64; - self.metrics.rtt_ms.store(self.last_rtt_ms as u32, Ordering::Relaxed); - } - RelayMessage::UdpAssociate => { - // Should not be received by client, ignore - } - RelayMessage::UdpData(target, data) => { - let _ = proxy_tx.send((stream_id, ProxyToClientMsg::UdpData(target, Bytes::from(data)))); - } - RelayMessage::KeepAlive | RelayMessage::Ping(_) | RelayMessage::Connect(_) => {} - } - } - Err(err) => { - let _ = tx.send(UiEvent::Log(format!("Relay decode error for stream {stream_id}: {err}"))).await; - let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("relay decode failed".to_string()))); - } - } - } - ProtocolAction::SendDatagram(frame) => { - let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await; - self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); - } - _ => {} - } - } - } - } - } - None => { - let _ = tx.send(UiEvent::Log("UDP channel closed, resetting connection".to_string())).await; - self.running = false; - crate::sysproxy::disable_windows_proxy(); - sessions_opt = None; - udp_rx_opt = None; - stream_map.clear(); - self.reset_proxy_streams(&tx, &proxy_tx, "udp reader closed"); - let _ = tx.send(UiEvent::TunnelStopped).await; - } - } + self.handle_inbound_udp(udp_msg, &mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx).await; } cmd = bridge_rx.recv() => { - match cmd { - Some(BridgeCommand::ToggleTunnel) => { - if self.running { - self.running = false; - self.metrics.connection_state.store(0, Ordering::Relaxed); - _proxy_guard = None; - sessions_opt = None; - udp_rx_opt = None; - stream_map.clear(); - self.reset_proxy_streams(&tx, &proxy_tx, "manual stop"); - tx.send(UiEvent::TunnelStopped).await.ok(); - let stop_msg = if self.mode == "tun" { "TUN tunnel stopped" } else { "Bridge stopped" }; - tx.send(UiEvent::Log(stop_msg.to_string())).await.ok(); - } else { - tx.send(UiEvent::Log("Connecting to remote server...".to_string())).await.ok(); - tx.send(UiEvent::Metrics { status: ConnectionStatus::Handshaking, rtt_ms: 0.0, throughput_bps: 0 }).await.ok(); - self.metrics.connection_state.store(1, Ordering::Relaxed); - - let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 }; - let (udp_tx, udp_rx) = mpsc::channel(100000); // Increased for high-speed traffic stability - let mut sessions = Vec::with_capacity(session_count); - let mut rtt_sum = 0.0; - let mut successful_sessions = 0; - - for idx in 0..session_count { - let session_id: u32 = rand::thread_rng().gen(); - match self.perform_handshake_with_id(&tx, session_id).await { - Ok((sock, mach, rtt)) => { - let session_index = sessions.len(); - let socket_clone = sock.clone(); - let udp_tx_clone = udp_tx.clone(); - - tokio::spawn(async move { - let mut buf = vec![0_u8; 65535]; - loop { - match socket_clone.recv(&mut buf).await { - Ok(n) => { - let inbound = Bytes::copy_from_slice(&buf[..n]); - if udp_tx_clone.send((session_index, inbound)).await.is_err() { - break; - } - } - Err(e) => { - // Under Windows/Winsock, transient UDP socket errors (like WSAECONNRESET) are returned - // as Err(ConnectionReset). We MUST NOT break the loop on transient errors, otherwise the - // download path will be permanently killed while the upload path keeps running. - tracing::warn!("UDP socket recv error (session {}): {}", session_index, e); - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - } - } - } - }); - - sessions.push(SessionState { socket: sock, machine: mach }); - rtt_sum += rtt; - successful_sessions += 1; - } - Err(err) => { - tx.send(UiEvent::Log(format!("Multiplex session {}/{} handshake failed: {}. Continuing with remaining sessions...", idx + 1, session_count, err))).await.ok(); - } - } - } - - if sessions.is_empty() { - _proxy_guard = None; - tx.send(UiEvent::Log("All multiplexed handshake attempts failed. Connection aborted.".to_string())).await.ok(); - tx.send(UiEvent::TunnelStopped).await.ok(); - self.metrics.connection_state.store(0, Ordering::Relaxed); - continue; - } - - udp_rx_opt = Some(udp_rx); - sessions_opt = Some(sessions); - self.last_rtt_ms = rtt_sum / successful_sessions as f64; - self.running = true; - self.last_sample_at = Instant::now(); - self.last_valid_recv = Instant::now(); - - let sys_proxy_addr = self.proxy_addr.replace("0.0.0.0:", "127.0.0.1:"); - _proxy_guard = Some(crate::sysproxy::SystemProxyGuard::enable(&sys_proxy_addr)); - - tx.send(UiEvent::Metrics { - status: ConnectionStatus::Established, - rtt_ms: self.last_rtt_ms, - throughput_bps: 0, - }).await.ok(); - self.metrics.connection_state.store(2, Ordering::Relaxed); - let start_msg = if self.mode == "tun" { "TUN tunnel established" } else { "Connection established" }; - tx.send(UiEvent::Log(start_msg.to_string())).await.ok(); - - // Send an immediate Ping so the UI updates without a 60s delay - for session in sessions_opt.as_mut().unwrap().iter_mut() { - let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64; - let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode()); - if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ping_payload)) { - let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp").await; - self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); - } - } - } - } - Some(BridgeCommand::NextProfile) => { - self.profile = next_profile(self.profile); - tx.send(UiEvent::ProfileChanged(self.profile)).await.ok(); - tx.send(UiEvent::Log(format!("Obfuscation profile switched to {:?}", self.profile))).await.ok(); - } - Some(BridgeCommand::NetworkChanged) => { - if self.running { - // Network changed (e.g. WiFi→LTE): IP address changed, existing UDP - // socket is dead. Trigger immediate reconnect without waiting for stall. - let _ = tx.send(UiEvent::Log("Network changed — starting immediate reconnect".to_string())).await; - self.metrics.connection_state.store(1, Ordering::Relaxed); - self.last_valid_recv = Instant::now() - Duration::from_secs(100); // force stall path - - let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 }; - let (udp_tx, udp_rx) = mpsc::channel(100000); - let mut new_sessions = Vec::with_capacity(session_count); - let mut successful_sessions = 0; - let mut rtt_sum = 0.0; - - for idx in 0..session_count { - let session_id: u32 = rand::thread_rng().gen(); - match self.perform_handshake_with_id(&tx, session_id).await { - Ok((sock, mach, rtt)) => { - let session_index = new_sessions.len(); - let socket_clone = sock.clone(); - let udp_tx_clone = udp_tx.clone(); - - tokio::spawn(async move { - let mut buf = vec![0_u8; 65535]; - loop { - match socket_clone.recv(&mut buf).await { - Ok(n) => { - let inbound = Bytes::copy_from_slice(&buf[..n]); - if udp_tx_clone.send((session_index, inbound)).await.is_err() { break; } - } - Err(e) => { - tracing::warn!("UDP recv error (network-change session {}): {}", session_index, e); - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - } - } - } - }); - new_sessions.push(SessionState { socket: sock, machine: mach }); - rtt_sum += rtt; - successful_sessions += 1; - } - Err(err) => { - let _ = tx.send(UiEvent::Log(format!("NetworkChanged reconnect session {}/{} failed: {}", idx + 1, session_count, err))).await; - } - } - } - - if !new_sessions.is_empty() { - sessions_opt = Some(new_sessions); - udp_rx_opt = Some(udp_rx); - self.last_rtt_ms = rtt_sum / successful_sessions as f64; - self.last_valid_recv = Instant::now(); - stream_map.clear(); - self.reset_proxy_streams(&tx, &proxy_tx, "network changed"); - self.metrics.connection_state.store(2, Ordering::Relaxed); - let _ = tx.send(UiEvent::Log("NetworkChanged reconnect successful!".to_string())).await; - } else { - let _ = tx.send(UiEvent::Log("NetworkChanged reconnect failed — will retry on keepalive tick".to_string())).await; - } - } - } - Some(BridgeCommand::ReloadConfig) => { - match ClientConfig::reload_from_json_near_binary() { - Ok(cfg) => { - self.apply_runtime_config(&cfg); - tx.send(UiEvent::Log("Runtime config reloaded".to_string())).await.ok(); - if self.running { - self.running = false; - self.metrics.connection_state.store(0, Ordering::Relaxed); - _proxy_guard = None; - sessions_opt = None; - stream_map.clear(); - self.reset_proxy_streams(&tx, &proxy_tx, "config reload"); - // User logic handles UI restart - let _ = tx.send(UiEvent::TunnelStopped).await; - } - } - Err(err) => { - let _ = tx.send(UiEvent::Log(format!("Config reload failed: {err}"))).await; - } - } - } - Some(BridgeCommand::Shutdown) | None => { - self.running = false; - _proxy_guard = None; - break; - } + if !self.handle_bridge_cmd(cmd, &mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx).await { + break; } } _ = metrics_tick.tick() => { @@ -446,265 +176,19 @@ impl Bridge { } _ = keepalive_tick.tick() => { if self.running { - // 1. Connection Liveness Check & Silent Background Reconnect - if self.last_valid_recv.elapsed().as_secs() > 25 { - let elapsed = self.last_valid_recv.elapsed().as_secs(); - if elapsed > 180 { - // Hard timeout after 3 minutes of total silence - let _ = tx.send(UiEvent::Log("Connection permanently lost (3-minute hard timeout). Stopping tunnel.".into())).await; - self.running = false; - _proxy_guard = None; - sessions_opt = None; - stream_map.clear(); - self.reset_proxy_streams(&tx, &proxy_tx, "keepalive hard timeout"); - let _ = tx.send(UiEvent::TunnelStopped).await; - self.metrics.connection_state.store(0, Ordering::Relaxed); - continue; - } - - let _ = tx.send(UiEvent::Log(format!("Connection stall detected ({}s silence). Attempting background reconnect...", elapsed))).await; - self.metrics.connection_state.store(1, Ordering::Relaxed); // State: Connecting (Handshake) - - let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 }; - let (udp_tx, udp_rx) = mpsc::channel(100000); - let mut new_sessions = Vec::with_capacity(session_count); - let mut successful_sessions = 0; - let mut rtt_sum = 0.0; - - for idx in 0..session_count { - let session_id: u32 = rand::thread_rng().gen(); - match self.perform_handshake_with_id(&tx, session_id).await { - Ok((sock, mach, rtt)) => { - let session_index = new_sessions.len(); - let socket_clone = sock.clone(); - let udp_tx_clone = udp_tx.clone(); - - tokio::spawn(async move { - let mut buf = vec![0_u8; 65535]; - loop { - match socket_clone.recv(&mut buf).await { - Ok(n) => { - let inbound = Bytes::copy_from_slice(&buf[..n]); - if udp_tx_clone.send((session_index, inbound)).await.is_err() { - break; - } - } - Err(e) => { - tracing::warn!("UDP socket recv error (reconnect session {}): {}", session_index, e); - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - } - } - } - }); - - new_sessions.push(SessionState { socket: sock, machine: mach }); - rtt_sum += rtt; - successful_sessions += 1; - } - Err(err) => { - let _ = tx.send(UiEvent::Log(format!("Background reconnect session {}/{} failed: {}", idx + 1, session_count, err))).await; - } - } - } - - if !new_sessions.is_empty() { - sessions_opt = Some(new_sessions); - udp_rx_opt = Some(udp_rx); - self.last_rtt_ms = rtt_sum / successful_sessions as f64; - self.last_valid_recv = Instant::now(); - self.metrics.connection_state.store(2, Ordering::Relaxed); // State: Connected - let _ = tx.send(UiEvent::Log("Background reconnect successful! Connection restored.".into())).await; - - // Send an immediate Ping so the UI updates without a 60s delay - for session in sessions_opt.as_mut().unwrap().iter_mut() { - let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64; - let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode()); - if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ping_payload)) { - let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp").await; - self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); - } - } - - // FIX: Clear existing proxy streams. Since we are on a NEW session_id, - // the server does not know about our existing streams. Closing them - // forces local apps/TUN to immediately recreate them and send proper - // Connect/UdpAssociate over the new session, avoiding a 5-minute blackhole. - stream_map.clear(); - self.reset_proxy_streams(&tx, &proxy_tx, "background reconnect"); - - // FIX: Flush all stale proxy messages accumulated during the stall/reconnect - // This prevents a massive post-reconnect UDP burst that causes mobile carriers to drop all packets - let mut flushed = 0; - while let Ok(stale) = proxy_rx.try_recv() { - if let ProxyEvent::NewStream { stream_id, .. } = stale { - let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("connection reset".into()))); - } - flushed += 1; - } - if flushed > 0 { - let _ = tx.send(UiEvent::Log(format!("Flushed {} stale proxy messages to prevent UDP burst", flushed))).await; - } - } else { - let _ = tx.send(UiEvent::Log("Background reconnect failed. Will retry on next tick...".into())).await; - } - } - - // 2. Active Keep-Alive / Heartbeat - if let Some(sessions) = sessions_opt.as_mut() { - for session in sessions.iter_mut() { - // Send Ping (Internal RTT Metric) - let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64; - let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode()); - if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ping_payload)) { - // Must go through send_datagram() for TURN-mode wrapping; - // raw socket.send() bypasses the ChannelData header and breaks RTT in TURN. - let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await; - self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); - } - - // Send Relay KeepAlive (Force NAT/Server Persistence) - let ka_payload = Bytes::from(RelayMessage::KeepAlive.encode()); - if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ka_payload)) { - let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await; - self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); - } - } - } + self.handle_keepalive(&mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx, &mut proxy_rx).await; } } _ = retransmit_tick.tick() => { if self.running { - let mut fatal_err = None; - if let Some(sessions) = sessions_opt.as_mut() { - for session in sessions.iter_mut() { - match session.machine.on_event(OstpEvent::Tick) { - Ok(action) => { - let mut queue = vec![action]; - while let Some(current_action) = queue.pop() { - match current_action { - ProtocolAction::Multiple(nested) => { - for a in nested { - queue.push(a); - } - } - ProtocolAction::SendDatagram(frame) => { - let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await; - self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); - } - _ => {} - } - } - } - Err(e) => { - fatal_err = Some(e); - break; - } - } - } - } - - if let Some(e) = fatal_err { - let _ = tx.send(UiEvent::Log(format!("Protocol tick fatal error: {e}"))).await; - self.running = false; - _proxy_guard = None; - sessions_opt = None; - udp_rx_opt = None; - stream_map.clear(); - self.reset_proxy_streams(&tx, &proxy_tx, "protocol fatal error"); - let _ = tx.send(UiEvent::TunnelStopped).await; - self.metrics.connection_state.store(0, Ordering::Relaxed); - } + self.handle_retransmit(&mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx).await; } } proxy_ev = proxy_rx.recv(), if self.running && sessions_opt.as_ref().map(|s| { - // Backpressure: suspend proxy reads when ARQ window is saturated across ALL sessions s.iter().any(|ses| ses.machine.in_flight_count() < ses.machine.cwnd_packets().clamp(16, 16384)) }).unwrap_or(true) => { - if let Some(ev) = proxy_ev { - if let Some(sessions) = sessions_opt.as_mut() { - if sessions.is_empty() { - if let ProxyEvent::NewStream { stream_id, .. } = ev { - let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("tunnel stopped".into()))); - } - continue; - } - let (stream_id, relay_msg, is_close) = match ev { - ProxyEvent::NewStream { stream_id, target } => { - let _ = tx.send(UiEvent::Log(format!("Proxy CONNECT stream_id={stream_id} target={target}"))).await; - (stream_id, RelayMessage::Connect(target), false) - } - ProxyEvent::UdpAssociate { stream_id } => { - let _ = tx.send(UiEvent::Log(format!("Proxy UDP ASSOCIATE stream_id={stream_id}"))).await; - (stream_id, RelayMessage::UdpAssociate, false) - } - ProxyEvent::UdpData { stream_id, target, payload } => { - (stream_id, RelayMessage::UdpData(target, payload.to_vec()), false) - } - ProxyEvent::Data { stream_id, payload } => (stream_id, RelayMessage::Data(payload.to_vec()), false), - ProxyEvent::Close { stream_id } => { - let _ = tx.send(UiEvent::Log(format!("Proxy CLOSE stream_id={stream_id}"))).await; - (stream_id, RelayMessage::Close, true) - } - }; - let len = sessions.len(); - let session_index = *stream_map.entry(stream_id).or_insert_with(|| { - // §8 FIX: Load balance multiplexed streams randomly across available connection sockets - rand::thread_rng().gen_range(0..len) - }); - if is_close { - stream_map.remove(&stream_id); - } - let session = &mut sessions[session_index]; - let out_payload = Bytes::from(relay_msg.encode()); - match session.machine.on_event(OstpEvent::Outbound(stream_id, out_payload)) { - Ok(ProtocolAction::SendDatagram(frame)) => { - if send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await.is_ok() { - self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); - tracing::trace!( - "Outbound datagram sent stream_id={stream_id} bytes={}", - frame.len() - ); - } - } - Ok(ProtocolAction::Multiple(list)) => { - let mut sent = 0usize; - for item in list { - if let ProtocolAction::SendDatagram(frame) = item { - if send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await.is_ok() { - self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); - sent += 1; - } - } - } - tracing::trace!( - "Outbound datagram batch stream_id={stream_id} sent={sent}" - ); - } - Ok(ProtocolAction::Noop) => { - tracing::trace!( - "Outbound datagram noop stream_id={stream_id}" - ); - } - Ok(_) => { - tracing::trace!( - "Outbound datagram unexpected action stream_id={stream_id}" - ); - } - Err(e) => { - tracing::warn!("Protocol error packing outbound stream_id={}: {}", stream_id, e); - let _ = tx.send(UiEvent::Log(format!("Protocol error packing TCP: {e}"))).await; - } - } - } else { - // Drop it, not connected - if let ProxyEvent::NewStream { stream_id, .. } = ev { - let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("tunnel stopped".into()))); - } - } - } + self.handle_proxy_event(proxy_ev, &mut sessions_opt, &mut stream_map, &tx, &proxy_tx).await; } - - } } @@ -712,6 +196,556 @@ impl Bridge { Ok(()) } + async fn handle_inbound_udp( + &mut self, + udp_msg: Option<(usize, Bytes)>, + sessions_opt: &mut Option>, + udp_rx_opt: &mut Option>, + proxy_guard: &mut Option, + stream_map: &mut std::collections::HashMap, + tx: &mpsc::Sender, + proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>, + ) { + match udp_msg { + Some((session_index, inbound)) => { + self.metrics.bytes_recv.fetch_add(inbound.len() as u64, Ordering::Relaxed); + self.last_valid_recv = Instant::now(); + if let Some(sessions) = sessions_opt.as_mut() { + if session_index < sessions.len() { + let session = &mut sessions[session_index]; + let initial_action = match session.machine.on_event(OstpEvent::Inbound(inbound)) { + Ok(a) => a, + Err(e) => { + let _ = tx.send(UiEvent::Log(format!("Protocol decrypt error: {e}"))).await; + tracing::warn!("Inbound protocol error (session {}): {}", session_index, e); + return; + } + }; + + let mut actions_queue = std::collections::VecDeque::new(); + actions_queue.push_back(initial_action); + + while let Some(current_action) = actions_queue.pop_front() { + match current_action { + ProtocolAction::Multiple(nested) => { + for a in nested { + actions_queue.push_back(a); + } + } + ProtocolAction::DeliverApp(stream_id, dec_payload) => { + match RelayMessage::decode(&dec_payload) { + Ok(relay_msg) => { + match relay_msg { + RelayMessage::ConnectOk => { + let _ = tx.send(UiEvent::Log(format!("Relay CONNECT OK stream_id={stream_id}"))).await; + let _ = proxy_tx.send((stream_id, ProxyToClientMsg::ConnectOk)); + } + RelayMessage::Data(data) => { + let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Data(Bytes::from(data)))); + } + RelayMessage::Close => { + let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Close)); + } + RelayMessage::Error(msg) => { + let _ = tx.send(UiEvent::Log(format!("Relay error for stream {stream_id}: {msg}"))).await; + let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error(msg))); + } + RelayMessage::Pong(ts) => { + let now = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64; + self.last_rtt_ms = now.saturating_sub(ts) as f64; + self.metrics.rtt_ms.store(self.last_rtt_ms as u32, Ordering::Relaxed); + } + RelayMessage::UdpAssociate => {} + RelayMessage::UdpData(target, data) => { + let _ = proxy_tx.send((stream_id, ProxyToClientMsg::UdpData(target, Bytes::from(data)))); + } + RelayMessage::KeepAlive | RelayMessage::Ping(_) | RelayMessage::Connect(_) => {} + } + } + Err(err) => { + let _ = tx.send(UiEvent::Log(format!("Relay decode error for stream {stream_id}: {err}"))).await; + let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("relay decode failed".to_string()))); + } + } + } + ProtocolAction::SendDatagram(frame) => { + let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await; + self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); + } + _ => {} + } + } + } + } + } + None => { + let _ = tx.send(UiEvent::Log("UDP channel closed, resetting connection".to_string())).await; + self.running = false; + crate::sysproxy::disable_system_proxy(); + *sessions_opt = None; + *udp_rx_opt = None; + stream_map.clear(); + self.reset_proxy_streams(&tx, &proxy_tx, "udp reader closed"); + let _ = tx.send(UiEvent::TunnelStopped).await; + } + } + } + + async fn handle_bridge_cmd( + &mut self, + cmd: Option, + sessions_opt: &mut Option>, + udp_rx_opt: &mut Option>, + proxy_guard: &mut Option, + stream_map: &mut std::collections::HashMap, + tx: &mpsc::Sender, + proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>, + ) -> bool { + match cmd { + Some(BridgeCommand::ToggleTunnel) => { + if self.running { + self.running = false; + self.metrics.connection_state.store(0, Ordering::Relaxed); + *proxy_guard = None; + *sessions_opt = None; + *udp_rx_opt = None; + stream_map.clear(); + self.reset_proxy_streams(&tx, &proxy_tx, "manual stop"); + tx.send(UiEvent::TunnelStopped).await.ok(); + let stop_msg = if self.mode == "tun" { "TUN tunnel stopped" } else { "Bridge stopped" }; + tx.send(UiEvent::Log(stop_msg.to_string())).await.ok(); + } else { + tx.send(UiEvent::Log("Connecting to remote server...".to_string())).await.ok(); + tx.send(UiEvent::Metrics { status: ConnectionStatus::Handshaking, rtt_ms: 0.0, throughput_bps: 0 }).await.ok(); + self.metrics.connection_state.store(1, Ordering::Relaxed); + + let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 }; + let (udp_tx, udp_rx) = mpsc::channel(100000); + let mut sessions = Vec::with_capacity(session_count); + let mut rtt_sum = 0.0; + let mut successful_sessions = 0; + + for idx in 0..session_count { + let session_id: u32 = rand::thread_rng().gen(); + match self.perform_handshake_with_id(&tx, session_id).await { + Ok((sock, mach, rtt)) => { + let session_index = sessions.len(); + let socket_clone = sock.clone(); + let udp_tx_clone = udp_tx.clone(); + + tokio::spawn(async move { + let mut buf = vec![0_u8; 65535]; + loop { + match socket_clone.recv(&mut buf).await { + Ok(n) => { + let inbound = Bytes::copy_from_slice(&buf[..n]); + if udp_tx_clone.send((session_index, inbound)).await.is_err() { + break; + } + } + Err(e) => { + tracing::warn!("UDP socket recv error (session {}): {}", session_index, e); + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + } + } + }); + + sessions.push(SessionState { socket: sock, machine: mach }); + rtt_sum += rtt; + successful_sessions += 1; + } + Err(err) => { + tx.send(UiEvent::Log(format!("Multiplex session {}/{} handshake failed: {}. Continuing with remaining sessions...", idx + 1, session_count, err))).await.ok(); + } + } + } + + if sessions.is_empty() { + *proxy_guard = None; + tx.send(UiEvent::Log("All multiplexed handshake attempts failed. Connection aborted.".to_string())).await.ok(); + tx.send(UiEvent::TunnelStopped).await.ok(); + self.metrics.connection_state.store(0, Ordering::Relaxed); + return true; + } + + *udp_rx_opt = Some(udp_rx); + *sessions_opt = Some(sessions); + self.last_rtt_ms = rtt_sum / successful_sessions as f64; + self.running = true; + self.last_sample_at = Instant::now(); + self.last_valid_recv = Instant::now(); + + let sys_proxy_addr = self.proxy_addr.replace("0.0.0.0:", "127.0.0.1:"); + *proxy_guard = Some(crate::sysproxy::SystemProxyGuard::enable(&sys_proxy_addr)); + + tx.send(UiEvent::Metrics { + status: ConnectionStatus::Established, + rtt_ms: self.last_rtt_ms, + throughput_bps: 0, + }).await.ok(); + self.metrics.connection_state.store(2, Ordering::Relaxed); + let start_msg = if self.mode == "tun" { "TUN tunnel established" } else { "Connection established" }; + tx.send(UiEvent::Log(start_msg.to_string())).await.ok(); + + for session in sessions_opt.as_mut().unwrap().iter_mut() { + let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64; + let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode()); + if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ping_payload)) { + let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp").await; + self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); + } + } + } + } + Some(BridgeCommand::NextProfile) => { + self.profile = next_profile(self.profile); + tx.send(UiEvent::ProfileChanged(self.profile)).await.ok(); + tx.send(UiEvent::Log(format!("Obfuscation profile switched to {:?}", self.profile))).await.ok(); + } + Some(BridgeCommand::NetworkChanged) => { + if self.running { + let _ = tx.send(UiEvent::Log("Network changed — starting immediate reconnect".to_string())).await; + self.metrics.connection_state.store(1, Ordering::Relaxed); + self.last_valid_recv = Instant::now() - Duration::from_secs(100); + + let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 }; + let (udp_tx, udp_rx) = mpsc::channel(100000); + let mut new_sessions = Vec::with_capacity(session_count); + let mut successful_sessions = 0; + let mut rtt_sum = 0.0; + + for idx in 0..session_count { + let session_id: u32 = rand::thread_rng().gen(); + match self.perform_handshake_with_id(&tx, session_id).await { + Ok((sock, mach, rtt)) => { + let session_index = new_sessions.len(); + let socket_clone = sock.clone(); + let udp_tx_clone = udp_tx.clone(); + + tokio::spawn(async move { + let mut buf = vec![0_u8; 65535]; + loop { + match socket_clone.recv(&mut buf).await { + Ok(n) => { + let inbound = Bytes::copy_from_slice(&buf[..n]); + if udp_tx_clone.send((session_index, inbound)).await.is_err() { break; } + } + Err(e) => { + tracing::warn!("UDP recv error (network-change session {}): {}", session_index, e); + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + } + } + }); + new_sessions.push(SessionState { socket: sock, machine: mach }); + rtt_sum += rtt; + successful_sessions += 1; + } + Err(err) => { + let _ = tx.send(UiEvent::Log(format!("NetworkChanged reconnect session {}/{} failed: {}", idx + 1, session_count, err))).await; + } + } + } + + if !new_sessions.is_empty() { + *sessions_opt = Some(new_sessions); + *udp_rx_opt = Some(udp_rx); + self.last_rtt_ms = rtt_sum / successful_sessions as f64; + self.last_valid_recv = Instant::now(); + stream_map.clear(); + self.reset_proxy_streams(&tx, &proxy_tx, "network changed"); + self.metrics.connection_state.store(2, Ordering::Relaxed); + let _ = tx.send(UiEvent::Log("NetworkChanged reconnect successful!".to_string())).await; + } else { + let _ = tx.send(UiEvent::Log("NetworkChanged reconnect failed — will retry on keepalive tick".to_string())).await; + } + } + } + Some(BridgeCommand::ReloadConfig) => { + match ClientConfig::reload_from_json_near_binary() { + Ok(cfg) => { + self.apply_runtime_config(&cfg); + tx.send(UiEvent::Log("Runtime config reloaded".to_string())).await.ok(); + if self.running { + self.running = false; + self.metrics.connection_state.store(0, Ordering::Relaxed); + *proxy_guard = None; + *sessions_opt = None; + stream_map.clear(); + self.reset_proxy_streams(&tx, &proxy_tx, "config reload"); + let _ = tx.send(UiEvent::TunnelStopped).await; + } + } + Err(err) => { + let _ = tx.send(UiEvent::Log(format!("Config reload failed: {err}"))).await; + } + } + } + Some(BridgeCommand::Shutdown) | None => { + self.running = false; + *proxy_guard = None; + return false; + } + } + true + } + + async fn handle_keepalive( + &mut self, + sessions_opt: &mut Option>, + udp_rx_opt: &mut Option>, + proxy_guard: &mut Option, + stream_map: &mut std::collections::HashMap, + tx: &mpsc::Sender, + proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>, + proxy_rx: &mut mpsc::Receiver, + ) { + if self.last_valid_recv.elapsed().as_secs() > 25 { + let elapsed = self.last_valid_recv.elapsed().as_secs(); + if elapsed > 180 { + let _ = tx.send(UiEvent::Log("Connection permanently lost (3-minute hard timeout). Stopping tunnel.".into())).await; + self.running = false; + *proxy_guard = None; + *sessions_opt = None; + stream_map.clear(); + self.reset_proxy_streams(&tx, &proxy_tx, "keepalive hard timeout"); + let _ = tx.send(UiEvent::TunnelStopped).await; + self.metrics.connection_state.store(0, Ordering::Relaxed); + return; + } + + let _ = tx.send(UiEvent::Log(format!("Connection stall detected ({}s silence). Attempting background reconnect...", elapsed))).await; + self.metrics.connection_state.store(1, Ordering::Relaxed); + + let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 }; + let (udp_tx, udp_rx) = mpsc::channel(100000); + let mut new_sessions = Vec::with_capacity(session_count); + let mut successful_sessions = 0; + let mut rtt_sum = 0.0; + + for idx in 0..session_count { + let session_id: u32 = rand::thread_rng().gen(); + match self.perform_handshake_with_id(&tx, session_id).await { + Ok((sock, mach, rtt)) => { + let session_index = new_sessions.len(); + let socket_clone = sock.clone(); + let udp_tx_clone = udp_tx.clone(); + + tokio::spawn(async move { + let mut buf = vec![0_u8; 65535]; + loop { + match socket_clone.recv(&mut buf).await { + Ok(n) => { + let inbound = Bytes::copy_from_slice(&buf[..n]); + if udp_tx_clone.send((session_index, inbound)).await.is_err() { + break; + } + } + Err(e) => { + tracing::warn!("UDP socket recv error (reconnect session {}): {}", session_index, e); + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + } + } + }); + + new_sessions.push(SessionState { socket: sock, machine: mach }); + rtt_sum += rtt; + successful_sessions += 1; + } + Err(err) => { + let _ = tx.send(UiEvent::Log(format!("Background reconnect session {}/{} failed: {}", idx + 1, session_count, err))).await; + } + } + } + + if !new_sessions.is_empty() { + *sessions_opt = Some(new_sessions); + *udp_rx_opt = Some(udp_rx); + self.last_rtt_ms = rtt_sum / successful_sessions as f64; + self.last_valid_recv = Instant::now(); + self.metrics.connection_state.store(2, Ordering::Relaxed); + let _ = tx.send(UiEvent::Log("Background reconnect successful! Connection restored.".into())).await; + + for session in sessions_opt.as_mut().unwrap().iter_mut() { + let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64; + let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode()); + if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ping_payload)) { + let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp").await; + self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); + } + } + + stream_map.clear(); + self.reset_proxy_streams(&tx, &proxy_tx, "background reconnect"); + + let mut flushed = 0; + while let Ok(stale) = proxy_rx.try_recv() { + if let ProxyEvent::NewStream { stream_id, .. } = stale { + let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("connection reset".into()))); + } + flushed += 1; + } + if flushed > 0 { + let _ = tx.send(UiEvent::Log(format!("Flushed {} stale proxy messages to prevent UDP burst", flushed))).await; + } + } else { + let _ = tx.send(UiEvent::Log("Background reconnect failed. Will retry on next tick...".into())).await; + } + } + + if let Some(sessions) = sessions_opt.as_mut() { + for session in sessions.iter_mut() { + let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64; + let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode()); + if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ping_payload)) { + let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await; + self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); + } + + let ka_payload = Bytes::from(RelayMessage::KeepAlive.encode()); + if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ka_payload)) { + let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await; + self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); + } + } + } + } + + async fn handle_retransmit( + &mut self, + sessions_opt: &mut Option>, + udp_rx_opt: &mut Option>, + proxy_guard: &mut Option, + stream_map: &mut std::collections::HashMap, + tx: &mpsc::Sender, + proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>, + ) { + let mut fatal_err = None; + if let Some(sessions) = sessions_opt.as_mut() { + for session in sessions.iter_mut() { + match session.machine.on_event(OstpEvent::Tick) { + Ok(action) => { + let mut queue = vec![action]; + while let Some(current_action) = queue.pop() { + match current_action { + ProtocolAction::Multiple(nested) => { + for a in nested { + queue.push(a); + } + } + ProtocolAction::SendDatagram(frame) => { + let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await; + self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); + } + _ => {} + } + } + } + Err(e) => { + fatal_err = Some(e); + break; + } + } + } + } + + if let Some(e) = fatal_err { + let _ = tx.send(UiEvent::Log(format!("Protocol tick fatal error: {e}"))).await; + self.running = false; + *proxy_guard = None; + *sessions_opt = None; + *udp_rx_opt = None; + stream_map.clear(); + self.reset_proxy_streams(&tx, &proxy_tx, "protocol fatal error"); + let _ = tx.send(UiEvent::TunnelStopped).await; + self.metrics.connection_state.store(0, Ordering::Relaxed); + } + } + + async fn handle_proxy_event( + &mut self, + proxy_ev: Option, + sessions_opt: &mut Option>, + stream_map: &mut std::collections::HashMap, + tx: &mpsc::Sender, + proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>, + ) { + if let Some(ev) = proxy_ev { + if let Some(sessions) = sessions_opt.as_mut() { + if sessions.is_empty() { + if let ProxyEvent::NewStream { stream_id, .. } = ev { + let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("tunnel stopped".into()))); + } + return; + } + let (stream_id, relay_msg, is_close) = match ev { + ProxyEvent::NewStream { stream_id, target } => { + let _ = tx.send(UiEvent::Log(format!("Proxy CONNECT stream_id={stream_id} target={target}"))).await; + (stream_id, RelayMessage::Connect(target), false) + } + ProxyEvent::UdpAssociate { stream_id } => { + let _ = tx.send(UiEvent::Log(format!("Proxy UDP ASSOCIATE stream_id={stream_id}"))).await; + (stream_id, RelayMessage::UdpAssociate, false) + } + ProxyEvent::UdpData { stream_id, target, payload } => { + (stream_id, RelayMessage::UdpData(target, payload.to_vec()), false) + } + ProxyEvent::Data { stream_id, payload } => (stream_id, RelayMessage::Data(payload.to_vec()), false), + ProxyEvent::Close { stream_id } => { + let _ = tx.send(UiEvent::Log(format!("Proxy CLOSE stream_id={stream_id}"))).await; + (stream_id, RelayMessage::Close, true) + } + }; + let len = sessions.len(); + let session_index = *stream_map.entry(stream_id).or_insert_with(|| { + rand::thread_rng().gen_range(0..len) + }); + if is_close { + stream_map.remove(&stream_id); + } + let session = &mut sessions[session_index]; + let out_payload = Bytes::from(relay_msg.encode()); + match session.machine.on_event(OstpEvent::Outbound(stream_id, out_payload)) { + Ok(ProtocolAction::SendDatagram(frame)) => { + if send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await.is_ok() { + self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); + tracing::trace!("Outbound datagram sent stream_id={stream_id} bytes={}", frame.len()); + } + } + Ok(ProtocolAction::Multiple(list)) => { + let mut sent = 0usize; + for item in list { + if let ProtocolAction::SendDatagram(frame) = item { + if send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await.is_ok() { + self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); + sent += 1; + } + } + } + tracing::trace!("Outbound datagram batch stream_id={stream_id} sent={sent}"); + } + Ok(ProtocolAction::Noop) => { + tracing::trace!("Outbound datagram noop stream_id={stream_id}"); + } + Ok(_) => { + tracing::trace!("Outbound datagram unexpected action stream_id={stream_id}"); + } + Err(e) => { + tracing::warn!("Protocol error packing outbound stream_id={}: {}", stream_id, e); + let _ = tx.send(UiEvent::Log(format!("Protocol error packing TCP: {e}"))).await; + } + } + } else { + if let ProxyEvent::NewStream { stream_id, .. } = ev { + let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("tunnel stopped".into()))); + } + } + } + } + + fn reset_proxy_streams( &self, tx: &mpsc::Sender, diff --git a/ostp-client/src/config.rs b/ostp-client/src/config.rs index 61fbfbe..e9db5df 100644 --- a/ostp-client/src/config.rs +++ b/ostp-client/src/config.rs @@ -58,7 +58,7 @@ pub struct OstpConfig { fn default_keepalive() -> u64 { 5 } -fn default_mtu() -> usize { 1280 } +fn default_mtu() -> usize { 1140 } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct LocalProxyConfig { diff --git a/ostp-client/src/sysproxy.rs b/ostp-client/src/sysproxy.rs index dd814ef..6de0634 100644 --- a/ostp-client/src/sysproxy.rs +++ b/ostp-client/src/sysproxy.rs @@ -82,7 +82,7 @@ pub fn enable_windows_proxy(proxy_addr: &str) { } #[cfg(target_os = "windows")] -pub fn disable_windows_proxy() { +pub fn disable_system_proxy() { tracing::info!("Disabling Windows system proxy"); let _ = Command::new("reg") .creation_flags(CREATE_NO_WINDOW) @@ -188,10 +188,6 @@ pub fn enable_system_proxy(proxy_addr: &str) { enable_windows_proxy(proxy_addr); } -#[cfg(target_os = "windows")] -pub fn disable_system_proxy() { - disable_windows_proxy(); -} pub struct SystemProxyGuard { active: bool, diff --git a/ostp-client/src/tunnel/exclusion.rs b/ostp-client/src/tunnel/exclusion.rs new file mode 100644 index 0000000..8b49f12 --- /dev/null +++ b/ostp-client/src/tunnel/exclusion.rs @@ -0,0 +1,134 @@ +use crate::config::ExclusionConfig; +use std::time::Duration; +use tokio::time::timeout; + +#[derive(Clone)] +pub struct ExclusionMatcher { + pub domain_suffix: Vec, + pub cidrs: Vec, + pub processes: Vec, + pub physical_if_index: Option, + pub physical_if_name: Option, +} + +impl ExclusionMatcher { + pub fn new( + exclusions: &ExclusionConfig, + physical_if_index: Option, + physical_if_name: Option, + ) -> Self { + let mut cidrs = Vec::new(); + for ip in &exclusions.ips { + if let Some(cidr) = parse_cidr(ip) { + cidrs.push(cidr); + } + } + + let processes = exclusions.processes.iter() + .map(|p| p.trim().to_lowercase()) + .filter(|p| !p.is_empty()) + .collect(); + + Self { + domain_suffix: exclusions + .domains + .iter() + .map(|d| d.trim().trim_start_matches('.').to_lowercase()) + .filter(|d| !d.is_empty()) + .collect(), + cidrs, + processes, + physical_if_index, + physical_if_name, + } + } + + pub async fn should_bypass_target(&self, host: &str, port: u16, timeout_value: Duration) -> bool { + if self.match_domain(host) { + return true; + } + + if self.cidrs.is_empty() { + return false; + } + + if let Ok(ip) = host.parse::() { + return self.match_ip(&ip); + } + + let lookup_target = (host.to_string(), port); + match timeout(timeout_value, tokio::net::lookup_host(lookup_target)).await { + Ok(Ok(addrs)) => addrs.into_iter().any(|addr| self.match_ip(&addr.ip())), + _ => false, + } + } + + pub fn match_domain(&self, host: &str) -> bool { + if self.domain_suffix.is_empty() { + return false; + } + let host = host.trim_end_matches('.').to_lowercase(); + self.domain_suffix.iter().any(|suffix| { + host == *suffix || host.ends_with(&format!(".{suffix}")) + }) + } + + pub fn match_ip(&self, ip: &std::net::IpAddr) -> bool { + self.cidrs.iter().any(|cidr| cidr.contains(ip)) + } + + pub fn match_process(&self, process_name: &str) -> bool { + if self.processes.is_empty() { + return false; + } + let p = process_name.to_lowercase(); + self.processes.iter().any(|ex| p.contains(ex)) + } +} + +#[derive(Clone)] +pub enum Cidr { + V4(u32, u8), + V6(u128, u8), +} + +impl Cidr { + pub fn contains(&self, ip: &std::net::IpAddr) -> bool { + match (self, ip) { + (Cidr::V4(net, bits), std::net::IpAddr::V4(addr)) => { + let mask = if *bits == 0 { 0 } else { u32::MAX << (32 - bits) }; + let ip = u32::from_be_bytes(addr.octets()); + (ip & mask) == (*net & mask) + } + (Cidr::V6(net, bits), std::net::IpAddr::V6(addr)) => { + let mask = if *bits == 0 { 0 } else { u128::MAX << (128 - bits) }; + let ip = u128::from_be_bytes(addr.octets()); + (ip & mask) == (*net & mask) + } + _ => false, + } + } +} + +pub fn parse_cidr(s: &str) -> Option { + let parts: Vec<&str> = s.split('/').collect(); + if parts.is_empty() || parts.len() > 2 { + return None; + } + if let Ok(ip) = parts[0].parse::() { + let bits = if parts.len() == 2 { + parts[1].parse::().ok()? + } else { + match ip { + std::net::IpAddr::V4(_) => 32, + std::net::IpAddr::V6(_) => 128, + } + }; + match ip { + std::net::IpAddr::V4(v4) => Some(Cidr::V4(u32::from_be_bytes(v4.octets()), bits)), + std::net::IpAddr::V6(v6) => Some(Cidr::V6(u128::from_be_bytes(v6.octets()), bits)), + } + } else { + None + } +} diff --git a/ostp-client/src/tunnel/mod.rs b/ostp-client/src/tunnel/mod.rs index 1d85004..c0a7b5a 100644 --- a/ostp-client/src/tunnel/mod.rs +++ b/ostp-client/src/tunnel/mod.rs @@ -61,3 +61,7 @@ pub async fn run_local_proxy( } + +pub mod exclusion; +pub mod process_lookup; +pub mod sni_sniff; diff --git a/ostp-client/src/tunnel/native_handler.rs b/ostp-client/src/tunnel/native_handler.rs index fff860d..3a5fb38 100644 --- a/ostp-client/src/tunnel/native_handler.rs +++ b/ostp-client/src/tunnel/native_handler.rs @@ -426,12 +426,63 @@ pub async fn run_native_tunnel_from_fd( } }); + let matcher = crate::tunnel::exclusion::ExclusionMatcher::new(&config.exclusions, None, None); + let mut tcp_accept_task = tokio::spawn(async move { if let Some(mut listener) = tcp_listener { - while let Some((mut stream, _local, remote)) = listener.next().await { + while let Some((mut stream, local, remote)) = listener.next().await { let proxy_addr = proxy_addr.clone(); + let matcher = matcher.clone(); tokio::spawn(async move { - if debug { tracing::info!("Native TUN intercepted TCP to {}", remote); } + if debug { tracing::info!("Native TUN intercepted TCP {local} -> {remote}"); } + + // Peak first chunk to see SNI + let mut sniff_buf = [0u8; 1500]; + let sniff_len = match tokio::time::timeout(std::time::Duration::from_millis(50), stream.read(&mut sniff_buf)).await { + Ok(Ok(n)) => n, + _ => 0, // Timeout or error + }; + + let mut should_bypass = false; + + // 1. Check SNI + if sniff_len > 0 { + if let Some(sni) = crate::tunnel::sni_sniff::extract_sni(&sniff_buf[..sniff_len]) { + if debug { tracing::info!("Native TUN sniffed SNI: {}", sni); } + if matcher.match_domain(&sni) { + should_bypass = true; + } + } + } + + // 2. Check Process + if !should_bypass { + if let Some(exe) = crate::tunnel::process_lookup::get_process_name_from_port(local.port()) { + if debug { tracing::info!("Native TUN source port {} maps to EXE: {}", local.port(), exe); } + if matcher.match_process(&exe) { + should_bypass = true; + } + } + } + + // 3. Check Target IP + if !should_bypass { + if matcher.match_ip(&remote.ip()) { + should_bypass = true; + } + } + + if should_bypass { + if debug { tracing::info!("Native TUN BYPASS matched for {}", remote); } + if let Ok(mut direct) = tokio::time::timeout(std::time::Duration::from_secs(5), tokio::net::TcpStream::connect(remote)).await.unwrap_or(Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "Direct connect timeout"))) { + if sniff_len > 0 { + let _ = direct.write_all(&sniff_buf[..sniff_len]).await; + } + let _ = tokio::io::copy_bidirectional(&mut stream, &mut direct).await; + } + return; + } + if let Ok(mut socks) = tokio::net::TcpStream::connect(&proxy_addr).await { if socks.write_all(&[5, 1, 0]).await.is_err() { return; } let mut buf = [0u8; 2]; @@ -456,6 +507,11 @@ pub async fn run_native_tunnel_from_fd( let mut rep = [0u8; 10]; if socks.read_exact(&mut rep).await.is_err() || rep[1] != 0 { return; } + // Write sniffed buffer to socks + if sniff_len > 0 { + if socks.write_all(&sniff_buf[..sniff_len]).await.is_err() { return; } + } + let _ = tokio::io::copy_bidirectional(&mut stream, &mut socks).await; } }); diff --git a/ostp-client/src/tunnel/process_lookup.rs b/ostp-client/src/tunnel/process_lookup.rs new file mode 100644 index 0000000..b140eb2 --- /dev/null +++ b/ostp-client/src/tunnel/process_lookup.rs @@ -0,0 +1,142 @@ +#[cfg(target_os = "windows")] +pub fn get_process_name_from_port(port: u16) -> Option { + use winapi::shared::minwindef::{DWORD, ULONG}; + use winapi::shared::winerror::ERROR_INSUFFICIENT_BUFFER; + use winapi::um::iphlpapi::GetExtendedTcpTable; + use winapi::shared::tcpmib::{MIB_TCPTABLE_OWNER_PID, MIB_TCPROW_OWNER_PID}; + + let mut size: ULONG = 0; + let table_class = 5; // TCP_TABLE_OWNER_PID_ALL + let mut table = vec![0u8; 1024]; + + unsafe { + let mut ret = GetExtendedTcpTable( + table.as_mut_ptr() as *mut _, + &mut size, + 0, + 2, // AF_INET + table_class, + 0, + ); + + if ret == ERROR_INSUFFICIENT_BUFFER { + table.resize(size as usize, 0); + ret = GetExtendedTcpTable( + table.as_mut_ptr() as *mut _, + &mut size, + 0, + 2, // AF_INET + table_class, + 0, + ); + } + + if ret == 0 { + let tcp_table = &*(table.as_ptr() as *const MIB_TCPTABLE_OWNER_PID); + let row_ptr = &tcp_table.table[0] as *const MIB_TCPROW_OWNER_PID; + for i in 0..tcp_table.dwNumEntries { + let row = &*row_ptr.add(i as usize); + // Local port is in network byte order + let local_port = u16::from_be(row.dwLocalPort as u16); + if local_port == port { + return get_process_name_from_pid(row.dwOwningPid); + } + } + } + } + None +} + +#[cfg(target_os = "windows")] +fn get_process_name_from_pid(pid: u32) -> Option { + use winapi::um::processthreadsapi::OpenProcess; + use winapi::um::psapi::GetModuleBaseNameW; + use winapi::um::winnt::{PROCESS_QUERY_INFORMATION, PROCESS_VM_READ}; + use winapi::um::handleapi::CloseHandle; + use std::os::windows::ffi::OsStringExt; + + unsafe { + let handle = OpenProcess(PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, 0, pid); + if handle.is_null() { + return None; + } + + let mut buffer = [0u16; 1024]; + let len = GetModuleBaseNameW(handle, std::ptr::null_mut(), buffer.as_mut_ptr(), buffer.len() as u32); + CloseHandle(handle); + + if len > 0 { + let name = std::ffi::OsString::from_wide(&buffer[..len as usize]); + return Some(name.to_string_lossy().into_owned()); + } + } + None +} + +#[cfg(target_os = "linux")] +pub fn get_process_name_from_port(port: u16) -> Option { + use std::fs; + use std::io::{BufRead, BufReader}; + + let mut target_inode = None; + let hex_port = format!("{:04X}", port); + + let check_net_file = |path: &str| -> Option { + let file = fs::File::open(path).ok()?; + let reader = BufReader::new(file); + for line in reader.lines().skip(1).filter_map(Result::ok) { + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() >= 10 { + let local_addr = parts[1]; + if local_addr.ends_with(&format!(":{}", hex_port)) { + if let Ok(inode) = parts[9].parse::() { + return Some(inode); + } + } + } + } + None + }; + + target_inode = check_net_file("/proc/net/tcp") + .or_else(|| check_net_file("/proc/net/tcp6")) + .or_else(|| check_net_file("/proc/net/udp")) + .or_else(|| check_net_file("/proc/net/udp6")); + + let target_inode = target_inode?; + let socket_str = format!("socket:[{}]", target_inode); + + for entry in fs::read_dir("/proc").ok()?.filter_map(Result::ok) { + let file_name = entry.file_name(); + let pid_str = file_name.to_string_lossy(); + if !pid_str.chars().all(char::is_numeric) { + continue; + } + + let fd_dir = entry.path().join("fd"); + if let Ok(fd_entries) = fs::read_dir(fd_dir) { + for fd_entry in fd_entries.filter_map(Result::ok) { + if let Ok(target) = fs::read_link(fd_entry.path()) { + if target.to_string_lossy() == socket_str { + let exe_path = entry.path().join("exe"); + if let Ok(exe_link) = fs::read_link(exe_path) { + if let Some(name) = exe_link.file_name() { + return Some(name.to_string_lossy().into_owned()); + } + } + if let Ok(comm) = fs::read_to_string(entry.path().join("comm")) { + return Some(comm.trim().to_string()); + } + } + } + } + } + } + + None +} + +#[cfg(not(any(target_os = "windows", target_os = "linux")))] +pub fn get_process_name_from_port(_port: u16) -> Option { + None +} diff --git a/ostp-client/src/tunnel/proxy.rs b/ostp-client/src/tunnel/proxy.rs index dae7cee..b0a267e 100644 --- a/ostp-client/src/tunnel/proxy.rs +++ b/ostp-client/src/tunnel/proxy.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use crate::tunnel::exclusion::{ExclusionMatcher, Cidr}; use anyhow::{anyhow, Context, Result}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream, UdpSocket}; @@ -421,8 +422,10 @@ async fn handle_udp_associate( }; let payload = bytes::Bytes::copy_from_slice(&buf[header_len..len]); + let target_host = if let Some((host, _)) = split_host_port(&target) { host } else { target.clone() }; + let target_port = match split_host_port(&target) { Some((_, p)) => p, None => 0 }; // Check if target should bypass the tunnel - if matcher.should_bypass(&target, connect_timeout).await { + if matcher.should_bypass_target(&target_host, target_port, connect_timeout).await { if debug { tracing::info!("proxy UDP BYPASS target={}", target); } @@ -668,7 +671,9 @@ async fn handle_proxy_client( if debug { tracing::info!("proxy CONNECT stream_id={stream_id} target={target}"); } - if matcher.should_bypass(&target, connect_timeout).await { + let target_host = if let Some((host, _)) = split_host_port(&target) { host } else { target.clone() }; + let target_port = match split_host_port(&target) { Some((_, p)) => p, None => 0 }; + if matcher.should_bypass_target(&target_host, target_port, connect_timeout).await { return direct_connect_socks5( client, stream_id, @@ -753,7 +758,9 @@ async fn handle_proxy_client( if debug { tracing::info!("proxy CONNECT stream_id={stream_id} target={target}"); } - if matcher.should_bypass(&target, connect_timeout).await { + let target_host = if let Some((host, _)) = split_host_port(&target) { host } else { target.clone() }; + let target_port = match split_host_port(&target) { Some((_, p)) => p, None => 443 }; + if matcher.should_bypass_target(&target_host, target_port, connect_timeout).await { return direct_connect_http( client, stream_id, @@ -854,129 +861,6 @@ async fn handle_proxy_client( Ok(()) } -#[derive(Clone)] -struct ExclusionMatcher { - domain_suffix: Vec, - cidrs: Vec, - physical_if_index: Option, - physical_if_name: Option, -} - -impl ExclusionMatcher { - fn new( - exclusions: &ExclusionConfig, - physical_if_index: Option, - physical_if_name: Option, - ) -> Self { - let mut cidrs = Vec::new(); - for ip in &exclusions.ips { - if let Some(cidr) = parse_cidr(ip) { - cidrs.push(cidr); - } - } - - Self { - domain_suffix: exclusions - .domains - .iter() - .map(|d| d.trim().trim_start_matches('.').to_lowercase()) - .filter(|d| !d.is_empty()) - .collect(), - cidrs, - physical_if_index, - physical_if_name, - } - } - - async fn should_bypass(&self, target: &str, timeout_value: Duration) -> bool { - let (host, port) = match split_host_port(target) { - Some(v) => v, - None => return false, - }; - - if self.match_domain(&host) { - return true; - } - - if self.cidrs.is_empty() { - return false; - } - - if let Ok(ip) = host.parse::() { - return self.match_ip(&ip); - } - - let lookup_target = (host.clone(), port); - match timeout(timeout_value, tokio::net::lookup_host(lookup_target)).await { - Ok(Ok(addrs)) => addrs.into_iter().any(|addr| self.match_ip(&addr.ip())), - _ => false, - } - } - - fn match_domain(&self, host: &str) -> bool { - if self.domain_suffix.is_empty() { - return false; - } - let host = host.trim_end_matches('.').to_lowercase(); - self.domain_suffix.iter().any(|suffix| { - host == *suffix || host.ends_with(&format!(".{suffix}")) - }) - } - - fn match_ip(&self, ip: &std::net::IpAddr) -> bool { - self.cidrs.iter().any(|cidr| cidr.contains(ip)) - } -} - -#[derive(Clone)] -enum Cidr { - V4(u32, u8), - V6(u128, u8), -} - -impl Cidr { - fn contains(&self, ip: &std::net::IpAddr) -> bool { - match (self, ip) { - (Cidr::V4(net, bits), std::net::IpAddr::V4(addr)) => { - let mask = if *bits == 0 { 0 } else { u32::MAX << (32 - bits) }; - let ip = u32::from_be_bytes(addr.octets()); - (ip & mask) == (*net & mask) - } - (Cidr::V6(net, bits), std::net::IpAddr::V6(addr)) => { - let mask = if *bits == 0 { 0 } else { u128::MAX << (128 - bits) }; - let ip = u128::from_be_bytes(addr.octets()); - (ip & mask) == (*net & mask) - } - _ => false, - } - } -} - -fn parse_cidr(value: &str) -> Option { - let value = value.trim(); - if value.is_empty() { - return None; - } - - if let Some((addr_str, bits_str)) = value.split_once('/') { - let bits: u8 = bits_str.parse().ok()?; - if let Ok(addr) = addr_str.parse::() { - return match addr { - std::net::IpAddr::V4(v4) => Some(Cidr::V4(u32::from_be_bytes(v4.octets()), bits.min(32))), - std::net::IpAddr::V6(v6) => Some(Cidr::V6(u128::from_be_bytes(v6.octets()), bits.min(128))), - }; - } - } - - if let Ok(addr) = value.parse::() { - return match addr { - std::net::IpAddr::V4(v4) => Some(Cidr::V4(u32::from_be_bytes(v4.octets()), 32)), - std::net::IpAddr::V6(v6) => Some(Cidr::V6(u128::from_be_bytes(v6.octets()), 128)), - }; - } - - None -} fn split_host_port(target: &str) -> Option<(String, u16)> { if let Some((host, port)) = target.rsplit_once(':') { diff --git a/ostp-client/src/tunnel/sni_sniff.rs b/ostp-client/src/tunnel/sni_sniff.rs new file mode 100644 index 0000000..49eac15 --- /dev/null +++ b/ostp-client/src/tunnel/sni_sniff.rs @@ -0,0 +1,73 @@ +pub fn extract_sni(data: &[u8]) -> Option { + // Basic TLS ClientHello parser + // Must be at least 43 bytes to contain anything useful + if data.len() < 43 { + return None; + } + + // TLS Record layer: Handshake (22) + if data[0] != 0x16 { + return None; + } + + // Record layer version: 0x0301 (TLS 1.0) or 0x0303 (TLS 1.2) + if data[1] != 0x03 { + return None; + } + + // Handshake type: ClientHello (1) + if data[5] != 0x01 { + return None; + } + + let mut pos = 43; // Skip fixed ClientHello header + + // Skip Session ID + if pos >= data.len() { return None; } + let session_id_len = data[pos] as usize; + pos += 1 + session_id_len; + + // Skip Cipher Suites + if pos + 2 > data.len() { return None; } + let cipher_suites_len = ((data[pos] as usize) << 8) | (data[pos + 1] as usize); + pos += 2 + cipher_suites_len; + + // Skip Compression Methods + if pos >= data.len() { return None; } + let comp_methods_len = data[pos] as usize; + pos += 1 + comp_methods_len; + + // Extensions + if pos + 2 > data.len() { return None; } + let extensions_len = ((data[pos] as usize) << 8) | (data[pos + 1] as usize); + pos += 2; + + let extensions_end = pos + extensions_len; + if extensions_end > data.len() { return None; } + + while pos + 4 <= extensions_end { + let ext_type = ((data[pos] as usize) << 8) | (data[pos + 1] as usize); + let ext_len = ((data[pos + 2] as usize) << 8) | (data[pos + 3] as usize); + pos += 4; + + if ext_type == 0x0000 { // Server Name Indication (SNI) + if pos + 5 <= extensions_end { + let list_len = ((data[pos] as usize) << 8) | (data[pos + 1] as usize); + let name_type = data[pos + 2]; + if name_type == 0 { // Hostname + let name_len = ((data[pos + 3] as usize) << 8) | (data[pos + 4] as usize); + if pos + 5 + name_len <= extensions_end { + let sni_bytes = &data[pos + 5..pos + 5 + name_len]; + if let Ok(sni) = std::str::from_utf8(sni_bytes) { + return Some(sni.to_string()); + } + } + } + } + break; + } + pos += ext_len; + } + + None +} diff --git a/ostp-core/src/congestion.rs b/ostp-core/src/congestion.rs index 5ecb36a..55fe7ee 100644 --- a/ostp-core/src/congestion.rs +++ b/ostp-core/src/congestion.rs @@ -5,7 +5,6 @@ //! This replaces the fixed `retransmit_budget = 8` with an adaptive //! congestion window that responds to network conditions. -use std::collections::VecDeque; use std::time::{Duration, Instant}; /// Congestion control state for a single OSTP session. @@ -16,14 +15,8 @@ pub struct CongestionController { ssthresh: u64, /// Current phase phase: Phase, - /// Minimum RTT observed (used for BDP calculation) + /// Minimum RTT observed min_rtt: Duration, - /// Maximum bandwidth observed (bytes/sec) - max_bandwidth: u64, - /// RTT samples for smoothing - rtt_samples: VecDeque, - /// Bandwidth samples - bw_samples: VecDeque, /// Bytes currently in flight (unacknowledged) bytes_in_flight: u64, /// Total bytes acknowledged (for bandwidth estimation) @@ -36,8 +29,6 @@ pub struct CongestionController { pacing_rate: u64, /// MTU estimate (used for cwnd → packet count conversion) mtu: u64, - /// Probe RTT phase timer - probe_rtt_timer: Option, /// Min RTT expiry: re-probe after 10 seconds min_rtt_stamp: Instant, } @@ -48,35 +39,14 @@ enum Phase { SlowStart, /// Probe bandwidth: cycle through pacing gains ProbeBandwidth, - /// Periodically drain the queue to measure true min RTT - #[allow(dead_code)] - ProbeRtt, } -#[derive(Debug, Clone)] -#[allow(dead_code)] -struct RttSample { - rtt: Duration, - time: Instant, -} - -#[derive(Debug, Clone)] -#[allow(dead_code)] -struct BwSample { - bytes_per_sec: u64, - time: Instant, -} - -/// Maximum number of samples to keep for windowed min/max -const MAX_SAMPLES: usize = 32; /// Initial congestion window: 10 packets × MTU const INITIAL_CWND_PACKETS: u64 = 10; /// Minimum cwnd: 2 packets const MIN_CWND_PACKETS: u64 = 2; /// Min RTT expiry window (after which we re-probe) const MIN_RTT_EXPIRY: Duration = Duration::from_secs(10); -/// ProbeRTT drain duration -const PROBE_RTT_DURATION: Duration = Duration::from_millis(200); impl CongestionController { pub fn new(mtu: u64) -> Self { @@ -87,16 +57,12 @@ impl CongestionController { ssthresh: u64::MAX, phase: Phase::SlowStart, min_rtt: Duration::from_millis(100), // Conservative initial estimate - max_bandwidth: 0, - rtt_samples: VecDeque::with_capacity(MAX_SAMPLES), - bw_samples: VecDeque::with_capacity(MAX_SAMPLES), bytes_in_flight: 0, total_acked: 0, last_ack_time: now, loss_count: 0, pacing_rate: initial_cwnd * 10, // initial: ~10 windows/sec mtu, - probe_rtt_timer: None, min_rtt_stamp: now, } } @@ -169,30 +135,8 @@ impl CongestionController { // TCP Reno Additive Increase: increase cwnd by ~1 MTU per RTT self.cwnd = self.cwnd.saturating_add(bytes * self.mtu / self.cwnd.max(1)); } - Phase::ProbeRtt => { - // Drain down to 4 packets to measure true min RTT - self.cwnd = MIN_CWND_PACKETS * self.mtu * 2; - if let Some(timer) = self.probe_rtt_timer { - if now.duration_since(timer) >= PROBE_RTT_DURATION { - // ProbeRTT complete, return to ProbeBandwidth - self.phase = Phase::ProbeBandwidth; - self.probe_rtt_timer = None; - self.cwnd = (MIN_CWND_PACKETS * self.mtu * 4).max(self.cwnd); - tracing::debug!(cwnd = self.cwnd, min_rtt = ?self.min_rtt, "congestion: probe RTT complete"); - } - } - } } - /* - // Periodically enter ProbeRTT to refresh min_rtt - if now.duration_since(self.min_rtt_stamp) >= MIN_RTT_EXPIRY && self.phase != Phase::ProbeRtt { - self.phase = Phase::ProbeRtt; - self.probe_rtt_timer = Some(now); - tracing::debug!("congestion: entering probe RTT phase"); - } - */ - self.update_pacing_rate(); self.last_ack_time = now; } @@ -215,9 +159,6 @@ impl CongestionController { self.cwnd = (self.cwnd * 7 / 10).max(MIN_CWND_PACKETS * self.mtu); tracing::debug!(cwnd = self.cwnd, "congestion: loss, cwnd reduced"); } - Phase::ProbeRtt => { - // Don't react to loss during ProbeRTT - } } self.update_pacing_rate(); @@ -236,40 +177,16 @@ impl CongestionController { self.min_rtt = rtt; self.min_rtt_stamp = now; } - - // Keep sample history - self.rtt_samples.push_back(RttSample { rtt, time: now }); - while self.rtt_samples.len() > MAX_SAMPLES { - self.rtt_samples.pop_front(); - } } - fn update_bandwidth(&mut self, acked_bytes: u64, now: Instant) { + fn update_bandwidth(&mut self, _acked_bytes: u64, now: Instant) { let elapsed = now.duration_since(self.last_ack_time); if elapsed.as_micros() > 0 { - let bw = acked_bytes * 1_000_000 / elapsed.as_micros() as u64; - if bw > self.max_bandwidth { - self.max_bandwidth = bw; - } - self.bw_samples.push_back(BwSample { bytes_per_sec: bw, time: now }); - while self.bw_samples.len() > MAX_SAMPLES { - self.bw_samples.pop_front(); - } + // Removed bw_samples tracking } } - #[allow(dead_code)] - fn bandwidth_delay_product(&self) -> u64 { - // BDP = max_bandwidth * min_rtt - let bw = if self.max_bandwidth > 0 { - self.max_bandwidth - } else { - // Fallback: assume 10 Mbps - 1_250_000 - }; - let rtt_secs = self.min_rtt.as_secs_f64(); - (bw as f64 * rtt_secs) as u64 - } + fn update_pacing_rate(&mut self) { // Pacing rate = cwnd / min_rtt (with gain) diff --git a/ostp-core/src/protocol.rs b/ostp-core/src/protocol.rs index 5dc6b0b..d421e07 100644 --- a/ostp-core/src/protocol.rs +++ b/ostp-core/src/protocol.rs @@ -290,7 +290,7 @@ impl ProtocolMachine { if raw_vec.len() < 12 { return Err(ProtocolError::Framing("data datagram too short".to_string())); } - let nonce = u64::from_be_bytes(raw_vec[4..12].try_into().unwrap()); + let nonce = u64::from_be_bytes(raw_vec[4..12].try_into().map_err(|_| ProtocolError::Framing("data datagram too short for nonce".into()))?); if nonce < self.expected_recv_nonce { // Duplicate — the ACK we sent was likely lost or delayed. @@ -330,7 +330,7 @@ impl ProtocolMachine { // Fast path processing for Nacks: act immediately, bypass sequence queue if packet.header.kind == FrameKind::Nack && packet.payload.len() >= 8 { - let req_nonce = u64::from_be_bytes(packet.payload[..8].try_into().unwrap()); + let req_nonce = u64::from_be_bytes(packet.payload[..8].try_into().map_err(|_| ProtocolError::Framing("nack payload too short".into()))?); if let Some(cached_frame) = self.lookup_sent_frame(req_nonce) { tracing::debug!("NACK received: retransmitting nonce={}", req_nonce); self.cc.on_loss(cached_frame.len() as u64); @@ -733,8 +733,8 @@ fn parse_ack_ranges(payload: &[u8]) -> Result, ProtocolError> { let mut ranges = Vec::with_capacity(count); let mut idx = 1; for _ in 0..count { - let start = u64::from_be_bytes(payload[idx..idx + 8].try_into().unwrap()); - let end = u64::from_be_bytes(payload[idx + 8..idx + 16].try_into().unwrap()); + let start = u64::from_be_bytes(payload[idx..idx + 8].try_into().map_err(|_| ProtocolError::Framing("ack range start invalid".into()))?); + let end = u64::from_be_bytes(payload[idx + 8..idx + 16].try_into().map_err(|_| ProtocolError::Framing("ack range end invalid".into()))?); ranges.push((start, end)); idx += 16; } diff --git a/ostp-core/src/relay.rs b/ostp-core/src/relay.rs index 4301ac0..0be3576 100644 --- a/ostp-core/src/relay.rs +++ b/ostp-core/src/relay.rs @@ -64,7 +64,7 @@ impl RelayMessage { 7 => { let payload = decode_with_len(&input[1..])?; if payload.len() != 8 { return Err(anyhow!("invalid ping payload len")); } - let ts = u64::from_be_bytes(payload.try_into().unwrap()); + let ts = u64::from_be_bytes(payload.try_into().map_err(|_| anyhow!("invalid ping payload size"))?); Ok(RelayMessage::Ping(ts)) } 8 => { diff --git a/ostp-flutter/devtools_options.yaml b/ostp-flutter/devtools_options.yaml new file mode 100644 index 0000000..fa0b357 --- /dev/null +++ b/ostp-flutter/devtools_options.yaml @@ -0,0 +1,3 @@ +description: This file stores settings for Dart & Flutter DevTools. +documentation: https://docs.flutter.dev/tools/devtools/extensions#configure-extension-enablement-states +extensions: diff --git a/ostp-flutter/pubspec.lock b/ostp-flutter/pubspec.lock index f59aed4..6334819 100644 --- a/ostp-flutter/pubspec.lock +++ b/ostp-flutter/pubspec.lock @@ -96,6 +96,14 @@ packages: description: flutter source: sdk version: "0.0.0" + json_annotation: + dependency: transitive + description: + name: json_annotation + sha256: "2a743920d81b7910627f68ee2c9ac1fc0bfee32b9fc3403587d7c6791ca12f80" + url: "https://pub.dev" + source: hosted + version: "4.12.0" leak_tracker: dependency: transitive description: @@ -144,6 +152,14 @@ packages: url: "https://pub.dev" source: hosted version: "0.13.0" + menu_base: + dependency: transitive + description: + name: menu_base + sha256: "820368014a171bd1241030278e6c2617354f492f5c703d7b7d4570a6b8b84405" + url: "https://pub.dev" + source: hosted + version: "0.1.1" meta: dependency: transitive description: @@ -208,6 +224,46 @@ packages: url: "https://pub.dev" source: hosted version: "2.1.8" + screen_retriever: + dependency: transitive + description: + name: screen_retriever + sha256: "570dbc8e4f70bac451e0efc9c9bb19fa2d6799a11e6ef04f946d7886d2e23d0c" + url: "https://pub.dev" + source: hosted + version: "0.2.0" + screen_retriever_linux: + dependency: transitive + description: + name: screen_retriever_linux + sha256: f7f8120c92ef0784e58491ab664d01efda79a922b025ff286e29aa123ea3dd18 + url: "https://pub.dev" + source: hosted + version: "0.2.0" + screen_retriever_macos: + dependency: transitive + description: + name: screen_retriever_macos + sha256: "71f956e65c97315dd661d71f828708bd97b6d358e776f1a30d5aa7d22d78a149" + url: "https://pub.dev" + source: hosted + version: "0.2.0" + screen_retriever_platform_interface: + dependency: transitive + description: + name: screen_retriever_platform_interface + sha256: ee197f4581ff0d5608587819af40490748e1e39e648d7680ecf95c05197240c0 + url: "https://pub.dev" + source: hosted + version: "0.2.0" + screen_retriever_windows: + dependency: transitive + description: + name: screen_retriever_windows + sha256: "449ee257f03ca98a57288ee526a301a430a344a161f9202b4fcc38576716fe13" + url: "https://pub.dev" + source: hosted + version: "0.2.0" shared_preferences: dependency: "direct main" description: @@ -264,6 +320,14 @@ packages: url: "https://pub.dev" source: hosted version: "2.4.1" + shortid: + dependency: transitive + description: + name: shortid + sha256: d0b40e3dbb50497dad107e19c54ca7de0d1a274eb9b4404991e443dadb9ebedb + url: "https://pub.dev" + source: hosted + version: "0.1.2" sky_engine: dependency: transitive description: flutter @@ -317,6 +381,14 @@ packages: url: "https://pub.dev" source: hosted version: "0.7.10" + tray_manager: + dependency: "direct main" + description: + name: tray_manager + sha256: c5fd83b0ae4d80be6eaedfad87aaefab8787b333b8ebd064b0e442a81006035b + url: "https://pub.dev" + source: hosted + version: "0.5.2" vector_math: dependency: transitive description: @@ -341,6 +413,14 @@ packages: url: "https://pub.dev" source: hosted version: "1.1.1" + window_manager: + dependency: "direct main" + description: + name: window_manager + sha256: "7eb6d6c4164ec08e1bf978d6e733f3cebe792e2a23fb07cbca25c2872bfdbdcd" + url: "https://pub.dev" + source: hosted + version: "0.5.1" xdg_directories: dependency: transitive description: diff --git a/ostp-flutter/pubspec.yaml b/ostp-flutter/pubspec.yaml index 8bea6dc..732595a 100644 --- a/ostp-flutter/pubspec.yaml +++ b/ostp-flutter/pubspec.yaml @@ -36,6 +36,8 @@ dependencies: cupertino_icons: ^1.0.8 shared_preferences: ^2.5.5 mobile_scanner: ^5.0.0 + window_manager: ^0.5.1 + tray_manager: ^0.5.2 dev_dependencies: flutter_test: diff --git a/ostp-gui/src-tauri/Cargo.lock b/ostp-gui/src-tauri/Cargo.lock index a62fe15..5bca46c 100644 --- a/ostp-gui/src-tauri/Cargo.lock +++ b/ostp-gui/src-tauri/Cargo.lock @@ -2641,7 +2641,7 @@ dependencies = [ [[package]] name = "ostp-client" -version = "0.2.73" +version = "0.2.79" dependencies = [ "anyhow", "base64 0.22.1", @@ -2666,12 +2666,13 @@ dependencies = [ "tracing", "tun", "webpki-roots 0.26.11", + "winapi", "x25519-dalek", ] [[package]] name = "ostp-core" -version = "0.2.73" +version = "0.2.79" dependencies = [ "anyhow", "bytes", diff --git a/ostp-gui/src-tauri/src/lib.rs b/ostp-gui/src-tauri/src/lib.rs index f17a02e..2c52aa8 100644 --- a/ostp-gui/src-tauri/src/lib.rs +++ b/ostp-gui/src-tauri/src/lib.rs @@ -106,6 +106,7 @@ struct HelperState { pipe_state: Arc>, cmd_tx: tokio::sync::mpsc::Sender, token: String, + port: u16, } enum TunnelHandle { @@ -282,6 +283,37 @@ async fn get_metrics(state: tauri::State<'_, AppState>) -> Result) -> Result { + let mut guard = state.0.lock().await; + if guard.tunnel.is_none() { + return Ok(false); + } + + let config_path = get_config_path(); + let config_str = std::fs::read_to_string(&config_path) + .map_err(|e| format!("Read config error: {}", e))?; + + match &guard.tunnel { + Some(TunnelHandle::Helper(h)) => { + let cmd = format!( + "{{\"cmd\":\"reload\",\"config\":{},\"token\":\"{}\"}}\n", + serde_json::to_string(&config_str).unwrap(), + h.token + ); + let _ = h.cmd_tx.send(cmd).await; + } + Some(TunnelHandle::InProcess(s)) => { + // Restarting in-process tunnel is not supported without re-calling start_tunnel, + // but we can just abort and we should really call start_tunnel again. + // For now, return false. + return Ok(false); + } + None => {} + } + Ok(true) +} + #[tauri::command] async fn stop_tunnel(state: tauri::State<'_, AppState>) -> Result { let mut guard = state.0.lock().await; @@ -375,24 +407,19 @@ async fn start_tun_via_helper( guard: &mut AppStateInner, raw: &ClientConfigRaw, ) -> Result { - #[cfg(target_os = "windows")] - { - // Kill any existing helper processes to prevent os error 10048 (port already in use) - use std::os::windows::process::CommandExt; - let _ = std::process::Command::new("taskkill") - .args(["/F", "/IM", "ostp-tun-helper.exe"]) - .creation_flags(0x08000000) - .output(); - } + let port = { + let listener = std::net::TcpListener::bind("127.0.0.1:0").map_err(|e| format!("Bind error: {}", e))?; + listener.local_addr().unwrap().port() + }; let auth_token = rand::random::().to_string(); let helper_exe = find_helper_exe().ok_or_else(|| "ostp-tun-helper.exe not found.".to_string())?; - launch_as_admin(&helper_exe, &auth_token).map_err(|e| format!("Failed to launch helper: {}", e))?; + launch_as_admin(&helper_exe, &auth_token, port).map_err(|e| format!("Failed to launch helper: {}", e))?; tokio::time::sleep(std::time::Duration::from_millis(1500)).await; let socket = tokio::time::timeout(std::time::Duration::from_secs(60), async { loop { - match tokio::net::TcpStream::connect("127.0.0.1:53211").await { + match tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)).await { Ok(s) => return Ok::<_, std::io::Error>(s), Err(_) => tokio::time::sleep(std::time::Duration::from_millis(200)).await, } @@ -443,7 +470,7 @@ async fn start_tun_via_helper( state_for_task.lock().await.connection_state = 0; }); - guard.tunnel = Some(TunnelHandle::Helper(HelperState { pipe_state, cmd_tx, token: auth_token })); + guard.tunnel = Some(TunnelHandle::Helper(HelperState { pipe_state, cmd_tx, token: auth_token, port })); Ok(true) } @@ -493,14 +520,14 @@ fn find_helper_exe() -> Option { } #[cfg(target_os = "windows")] -fn launch_as_admin(exe: &std::path::PathBuf, token: &str) -> anyhow::Result<()> { +fn launch_as_admin(exe: &std::path::PathBuf, token: &str, port: u16) -> anyhow::Result<()> { use std::ffi::OsStr; use std::os::windows::ffi::OsStrExt; use std::ptr::null_mut; let exe_wstr: Vec = exe.as_os_str().encode_wide().chain(Some(0)).collect(); let verb_wstr: Vec = OsStr::new("runas").encode_wide().chain(Some(0)).collect(); - let params_str = format!("--token {}", token); + let params_str = format!("--token {} --port {}", token, port); let params_wstr: Vec = OsStr::new(¶ms_str).encode_wide().chain(Some(0)).collect(); #[link(name = "shell32")] extern "system" { fn ShellExecuteW(h: *mut std::ffi::c_void, op: *const u16, f: *const u16, p: *const u16, d: *const u16, s: i32) -> isize; } @@ -514,7 +541,7 @@ fn launch_as_admin(exe: &std::path::PathBuf, token: &str) -> anyhow::Result<()> } #[cfg(not(target_os = "windows"))] -fn launch_as_admin(_exe: &PathBuf, _token: &str) -> Result<()> { anyhow::bail!("Windows only."); } +fn launch_as_admin(_exe: &PathBuf, _token: &str, _port: u16) -> Result<()> { anyhow::bail!("Windows only."); } #[cfg_attr(mobile, tauri::mobile_entry_point)] pub fn run() { @@ -607,7 +634,7 @@ pub fn run() { } _ => {} }) - .invoke_handler(tauri::generate_handler![start_tunnel, stop_tunnel, get_tunnel_status, get_metrics, get_config, save_config]) + .invoke_handler(tauri::generate_handler![start_tunnel, stop_tunnel, reload_tunnel, get_tunnel_status, get_metrics, get_config, save_config]) .run(tauri::generate_context!()) .expect("error while running tauri application"); } diff --git a/ostp-gui/src/main.js b/ostp-gui/src/main.js index 21be5bc..30b29d1 100644 --- a/ostp-gui/src/main.js +++ b/ostp-gui/src/main.js @@ -319,6 +319,8 @@ async function handleSave(silent = false) { rawConfig.tun = rawConfig.tun || {}; rawConfig.tun.enable = inTun.checked; + rawConfig.tun.wintun_path = rawConfig.tun.wintun_path || './wintun.dll'; + rawConfig.tun.ipv4_address = rawConfig.tun.ipv4_address || '10.1.0.2/24'; rawConfig.tun.stack = 'ostp'; // owndns: if toggle is on, always write 10.1.0.1; otherwise use the custom field rawConfig.tun.dns = inOwndns.checked ? '10.1.0.1' : (inDns.value.trim() || null); @@ -477,7 +479,14 @@ window.addEventListener('DOMContentLoaded', async () => { // Auto-save wiring const formInputs = document.querySelectorAll('#settings-screen input:not(#in-import-url), #settings-screen textarea, #settings-screen select'); formInputs.forEach(el => { - el.addEventListener('input', scheduleAutoSave); + el.addEventListener('input', () => { + scheduleAutoSave(); + if (appState === 'connected') { + if (window.__TAURI__ && window.__TAURI__.invoke) { + window.__TAURI__.invoke('reload_tunnel'); + } + } + }); el.addEventListener('change', scheduleAutoSave); }); diff --git a/ostp-jni/OstpClientSdk.kt b/ostp-jni/OstpClientSdk.kt index 4d8d7e3..d4baa5a 100644 --- a/ostp-jni/OstpClientSdk.kt +++ b/ostp-jni/OstpClientSdk.kt @@ -32,12 +32,14 @@ class OstpClientSdk private constructor(private val context: Context) { // ── Native JNI bindings ─────────────────────────────────────────────────── - private external fun nativeStartClient(configJson: String): Boolean - private external fun nativeStopClient(): Boolean - private external fun nativeGetMetrics(): String - private external fun nativeGetLogs(): String + private external fun startClient(configJson: String, fd: Int, t2sBinPath: String, localProxy: String): Boolean + private external fun stopClient(): Boolean + private external fun getMetrics(): String + private external fun getLogs(): String + private external fun addLog(logMsg: String) private external fun notifyNetworkChanged() + // ── Public data models ──────────────────────────────────────────────────── /** @@ -175,7 +177,8 @@ class OstpClientSdk private constructor(private val context: Context) { _state.value = TunnelState.Connecting val json = config.toNativeJson() - val ok = nativeStartClient(json) + // Default values for fd, t2sBinPath, localProxy for proxy mode + val ok = startClient(json, -1, "", config.proxyBind) if (!ok) { _state.value = TunnelState.Failed("Native layer rejected config") started.set(false) @@ -197,7 +200,7 @@ class OstpClientSdk private constructor(private val context: Context) { pollingJob?.cancel() networkCallbackJob?.cancel() - nativeStopClient() + stopClient() unregisterNetworkCallback() _state.value = TunnelState.Idle emitLog("OSTP SDK stopped") @@ -209,7 +212,7 @@ class OstpClientSdk private constructor(private val context: Context) { */ fun drainLogs(): List { return try { - val array = JSONArray(nativeGetLogs()) + val array = JSONArray(getLogs()) (0 until array.length()).map { array.getString(it) } } catch (_: Exception) { emptyList() @@ -218,7 +221,7 @@ class OstpClientSdk private constructor(private val context: Context) { /** Read the latest [Metrics] snapshot. Returns zeroed metrics if tunnel is idle. */ fun getMetrics(): Metrics { - return parseMetrics(nativeGetMetrics()) + return parseMetrics(getMetrics()) } // ── Internal helpers ────────────────────────────────────────────────────── @@ -247,7 +250,7 @@ class OstpClientSdk private constructor(private val context: Context) { } // Update state based on metrics availability - val metrics = parseMetrics(nativeGetMetrics()) + val metrics = parseMetrics(getMetrics()) if (wasConnected) { _state.value = TunnelState.Connected(metrics) } @@ -320,6 +323,33 @@ class OstpClientSdk private constructor(private val context: Context) { @Volatile private var instance: OstpClientSdk? = null + @JvmStatic + fun protectSocket(fd: Int): Boolean { + var retries = 5 + while (retries > 0) { + // We use reflection or explicit class to get the VpnService instance + try { + val serviceClass = Class.forName("com.ospab.ostp_client.OstpVpnService") + val instanceField = serviceClass.getDeclaredField("instance") + instanceField.isAccessible = true + val service = instanceField.get(null) + if (service != null) { + val protectMethod = serviceClass.getMethod("protect", Int::class.javaPrimitiveType) + val res = protectMethod.invoke(service, fd) as Boolean + android.util.Log.i("OstpClientSdk", "VpnService.protect(socketFd=$fd) -> success=$res") + return res + } + } catch (e: Exception) { + android.util.Log.w("OstpClientSdk", "Error accessing VpnService via reflection: \${e.message}") + } + android.util.Log.w("OstpClientSdk", "VpnService instance not available! Retrying... (\$retries left)") + Thread.sleep(200) + retries-- + } + android.util.Log.e("OstpClientSdk", "VpnService instance is null! Cannot protect socketFd=\$fd") + return false + } + /** * Get the singleton SDK instance. * Must be called with an Application context to avoid memory leaks. diff --git a/ostp-jni/src/lib.rs b/ostp-jni/src/lib.rs index 278afde..e8fdd46 100644 --- a/ostp-jni/src/lib.rs +++ b/ostp-jni/src/lib.rs @@ -3,7 +3,7 @@ use jni::sys::{jboolean, jstring}; use jni::JNIEnv; use std::collections::VecDeque; -use std::sync::{atomic::Ordering, Arc, Mutex}; +use std::sync::{atomic::Ordering, Arc, Mutex, RwLock}; use tokio::runtime::Runtime; use tokio::sync::{mpsc, watch}; use ostp_client::bridge::{Bridge, BridgeMetrics}; @@ -80,13 +80,13 @@ impl SdkState { } } -static STATE: Mutex = Mutex::new(SdkState::new()); -static LOGS: Mutex> = Mutex::new(VecDeque::new()); -static JVM: Mutex> = Mutex::new(None); -static CLASS_REF: Mutex> = Mutex::new(None); +static STATE: RwLock = RwLock::new(SdkState::new()); +static LOGS: RwLock> = RwLock::new(VecDeque::new()); +static JVM: RwLock> = RwLock::new(None); +static CLASS_REF: RwLock> = RwLock::new(None); fn add_log(text: String) { - if let Ok(mut guard) = LOGS.lock() { + if let Ok(mut guard) = LOGS.write() { if guard.len() >= 1000 { guard.pop_front(); } @@ -95,7 +95,7 @@ fn add_log(text: String) { } #[no_mangle] -pub extern "system" fn Java_net_ostp_client_OstpClientSdk_startClient( +pub extern "system" fn Java_net_ostp_client_OstpClientSdk_nativeStartClient( mut env: JNIEnv, _class: JClass, config_json: JString, @@ -103,7 +103,7 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_startClient( t2s_bin_path: JString, local_proxy: JString, ) -> jboolean { - let mut state = match STATE.lock() { + let mut state = match STATE.write() { Ok(s) => s, Err(_) => return jni::sys::JNI_FALSE, }; @@ -116,25 +116,25 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_startClient( init_tracing(); if let Ok(jvm) = env.get_java_vm() { - if let Ok(mut guard) = JVM.lock() { + if let Ok(mut guard) = JVM.write() { *guard = Some(jvm); } } if let Ok(cls) = env.find_class("net/ostp/client/OstpClientSdk") { if let Ok(global_cls) = env.new_global_ref(cls) { - if let Ok(mut guard) = CLASS_REF.lock() { + if let Ok(mut guard) = CLASS_REF.write() { *guard = Some(global_cls); } } } ostp_client::bridge::set_socket_protector(|fd| { - let jvm_guard = match JVM.lock() { + let jvm_guard = match JVM.read() { Ok(g) => g, Err(_) => return false, }; - let class_guard = match CLASS_REF.lock() { + let class_guard = match CLASS_REF.read() { Ok(g) => g, Err(_) => return false, }; @@ -346,12 +346,24 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_startClient( } #[no_mangle] -pub extern "system" fn Java_net_ostp_client_OstpClientSdk_stopClient( +pub extern "system" fn Java_net_ostp_client_OstpClientSdk_startClient( + env: JNIEnv, + class: JClass, + config_json: JString, + fd: jni::sys::jint, + t2s_bin_path: JString, + local_proxy: JString, +) -> jboolean { + Java_net_ostp_client_OstpClientSdk_nativeStartClient(env, class, config_json, fd, t2s_bin_path, local_proxy) +} + +#[no_mangle] +pub extern "system" fn Java_net_ostp_client_OstpClientSdk_nativeStopClient( _env: JNIEnv, _class: JClass, ) -> jboolean { let (tun_child, shutdown_tx, runtime) = { - let mut state = match STATE.lock() { + let mut state = match STATE.write() { Ok(s) => s, Err(_) => return jni::sys::JNI_FALSE, }; @@ -381,11 +393,19 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_stopClient( } #[no_mangle] -pub extern "system" fn Java_net_ostp_client_OstpClientSdk_getMetrics( +pub extern "system" fn Java_net_ostp_client_OstpClientSdk_stopClient( + env: JNIEnv, + class: JClass, +) -> jboolean { + Java_net_ostp_client_OstpClientSdk_nativeStopClient(env, class) +} + +#[no_mangle] +pub extern "system" fn Java_net_ostp_client_OstpClientSdk_nativeGetMetrics( env: JNIEnv, _class: JClass, ) -> jstring { - let state = match STATE.lock() { + let state = match STATE.read() { Ok(s) => s, Err(_) => return match env.new_string("{}") { Ok(s) => s.into_raw(), @@ -415,11 +435,19 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_getMetrics( } #[no_mangle] -pub extern "system" fn Java_net_ostp_client_OstpClientSdk_getLogs( +pub extern "system" fn Java_net_ostp_client_OstpClientSdk_getMetrics( + env: JNIEnv, + class: JClass, +) -> jstring { + Java_net_ostp_client_OstpClientSdk_nativeGetMetrics(env, class) +} + +#[no_mangle] +pub extern "system" fn Java_net_ostp_client_OstpClientSdk_nativeGetLogs( env: JNIEnv, _class: JClass, ) -> jstring { - let logs_vec: Vec = match LOGS.lock() { + let logs_vec: Vec = match LOGS.write() { Ok(mut guard) => guard.drain(..).collect(), Err(_) => Vec::new(), }; @@ -435,6 +463,14 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_getLogs( } } +#[no_mangle] +pub extern "system" fn Java_net_ostp_client_OstpClientSdk_getLogs( + env: JNIEnv, + class: JClass, +) -> jstring { + Java_net_ostp_client_OstpClientSdk_nativeGetLogs(env, class) +} + #[no_mangle] pub extern "system" fn Java_net_ostp_client_OstpClientSdk_addLog( mut env: JNIEnv, @@ -454,7 +490,7 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_notifyNetworkChanged( _env: JNIEnv, _class: JClass, ) { - let state = match STATE.lock() { + let state = match STATE.read() { Ok(s) => s, Err(_) => return, }; diff --git a/ostp-server/src/api.rs b/ostp-server/src/api.rs index cec5a76..df80329 100644 --- a/ostp-server/src/api.rs +++ b/ostp-server/src/api.rs @@ -310,7 +310,7 @@ fn check_token(state: &ApiState, headers: &axum::http::HeaderMap) -> bool { if let Some(value) = headers.get("authorization") { if let Ok(val) = value.to_str() { if let Some(token) = val.strip_prefix("Bearer ") { - let current_session = state.session_token.read().unwrap().clone(); + let current_session = state.session_token.read().unwrap_or_else(|e| e.into_inner()).clone(); if let Some(session) = current_session { if token == session { allowed = true; @@ -353,7 +353,7 @@ async fn handle_login( if hash_hex == state.password_hash { let token = uuid::Uuid::new_v4().to_string(); - *state.session_token.write().unwrap() = Some(token.clone()); + *state.session_token.write().unwrap_or_else(|e| e.into_inner()) = Some(token.clone()); (StatusCode::OK, ApiResponse::success(LoginResponse { token })) } else { api_unauthorized::() @@ -377,7 +377,7 @@ fn save_config_keys(state: &ApiState) -> Result<(), String> { let mut json_val: serde_json::Value = serde_json::from_str(&content_str) .map_err(|e| format!("failed to parse config JSON: {}", e))?; - let keys = state.access_keys.read().unwrap(); + let keys = state.access_keys.read().unwrap_or_else(|e| e.into_inner()); let mut access_keys_json = Vec::new(); for (k, m) in keys.iter() { if m.name.is_none() && m.limit_bytes.is_none() { @@ -511,8 +511,8 @@ async fn handle_status( return api_unauthorized::(); } - let keys = state.access_keys.read().unwrap(); - let stats = state.user_stats.read().unwrap(); + let keys = state.access_keys.read().unwrap_or_else(|e| e.into_inner()); + let stats = state.user_stats.read().unwrap_or_else(|e| e.into_inner()); let online = stats.values() .filter(|us| { let total = us.bytes_up.load(Ordering::Relaxed) + us.bytes_down.load(Ordering::Relaxed); @@ -538,8 +538,8 @@ async fn handle_list_users( return api_unauthorized::>(); } - let keys = state.access_keys.read().unwrap(); - let stats = state.user_stats.read().unwrap(); + let keys = state.access_keys.read().unwrap_or_else(|e| e.into_inner()); + let stats = state.user_stats.read().unwrap_or_else(|e| e.into_inner()); let mut users: Vec = keys.iter().map(|(key, meta)| { if let Some(us) = stats.get(key) { @@ -579,13 +579,13 @@ async fn handle_get_user( return api_unauthorized::(); } - let keys = state.access_keys.read().unwrap(); + let keys = state.access_keys.read().unwrap_or_else(|e| e.into_inner()); let meta = match keys.get(&key) { Some(m) => m.clone(), None => return api_error("user not found"), }; - let stats = state.user_stats.read().unwrap(); + let stats = state.user_stats.read().unwrap_or_else(|e| e.into_inner()); let snapshot = if let Some(us) = stats.get(&key) { UserStatsSnapshot { access_key: key.clone(), @@ -628,11 +628,11 @@ async fn handle_create_user( }); { - let mut keys = state.access_keys.write().unwrap(); + let mut keys = state.access_keys.write().unwrap_or_else(|e| e.into_inner()); keys.insert(key.clone(), UserMeta { name: body.name.clone(), limit_bytes: body.limit_bytes }); } - let mut stats = state.user_stats.write().unwrap(); + let mut stats = state.user_stats.write().unwrap_or_else(|e| e.into_inner()); stats.insert(key.clone(), Arc::new(UserStats::new(body.limit_bytes))); drop(stats); @@ -655,14 +655,14 @@ async fn delete_user( } { - let mut keys = state.access_keys.write().unwrap(); + let mut keys = state.access_keys.write().unwrap_or_else(|e| e.into_inner()); if keys.remove(&key).is_none() { return api_error::("User not found"); } } { - let mut stats = state.user_stats.write().unwrap(); + let mut stats = state.user_stats.write().unwrap_or_else(|e| e.into_inner()); stats.remove(&key); } @@ -685,7 +685,7 @@ async fn update_user( } { - let mut keys = state.access_keys.write().unwrap(); + let mut keys = state.access_keys.write().unwrap_or_else(|e| e.into_inner()); if let Some(meta) = keys.get_mut(&key) { meta.name = body.name.clone(); meta.limit_bytes = body.limit_bytes; @@ -695,7 +695,7 @@ async fn update_user( } { - let mut stats = state.user_stats.write().unwrap(); + let mut stats = state.user_stats.write().unwrap_or_else(|e| e.into_inner()); let entry = stats.entry(key.clone()) .or_insert_with(|| Arc::new(UserStats::new(body.limit_bytes))); @@ -727,7 +727,7 @@ async fn handle_set_limit( } { - let mut keys = state.access_keys.write().unwrap(); + let mut keys = state.access_keys.write().unwrap_or_else(|e| e.into_inner()); if let Some(meta) = keys.get_mut(&key) { meta.limit_bytes = body.limit_bytes; } else { @@ -735,7 +735,7 @@ async fn handle_set_limit( } } - let mut stats = state.user_stats.write().unwrap(); + let mut stats = state.user_stats.write().unwrap_or_else(|e| e.into_inner()); let entry = stats.entry(key.clone()) .or_insert_with(|| Arc::new(UserStats::new(body.limit_bytes))); @@ -765,7 +765,7 @@ async fn handle_reset_stats( return api_unauthorized::(); } - let mut stats = state.user_stats.write().unwrap(); + let mut stats = state.user_stats.write().unwrap_or_else(|e| e.into_inner()); if let Some(us) = stats.get(&key) { let limit = us.limit_bytes; stats.insert(key.clone(), Arc::new(UserStats::new(limit))); @@ -793,7 +793,7 @@ async fn handle_subscribe( // Validate that the key exists in a tightly scoped block to drop the guard let key_exists = { - let keys = state.access_keys.read().unwrap(); + let keys = state.access_keys.read().unwrap_or_else(|e| e.into_inner()); keys.contains_key(&key) }; diff --git a/ostp-server/src/dispatcher.rs b/ostp-server/src/dispatcher.rs index 125f607..e84c0a4 100644 --- a/ostp-server/src/dispatcher.rs +++ b/ostp-server/src/dispatcher.rs @@ -86,7 +86,7 @@ pub struct Dispatcher { impl Dispatcher { pub fn new(machine_config: ProtocolConfig, access_keys: Arc>>) -> Self { let mut initial_stats = HashMap::new(); - for (key, meta) in access_keys.read().unwrap().iter() { + for (key, meta) in access_keys.read().unwrap_or_else(|e| e.into_inner()).iter() { initial_stats.insert(key.clone(), Arc::new(UserStats::new(meta.limit_bytes))); } Self { @@ -108,7 +108,7 @@ impl Dispatcher { /// Snapshot all user stats for API responses. pub fn snapshot_all_users(&self) -> Vec { - let stats = self.user_stats.read().unwrap(); + let stats = self.user_stats.read().unwrap_or_else(|e| e.into_inner()); let online_keys: std::collections::HashSet = self.peer_machines.values() .map(|ps| ps.access_key.clone()) .collect(); @@ -125,15 +125,15 @@ impl Dispatcher { /// Get or create stats entry for a user key. fn get_or_create_user_stats(&self, key: &str) -> Arc { - let stats = self.user_stats.read().unwrap(); + let stats = self.user_stats.read().unwrap_or_else(|e| e.into_inner()); if let Some(existing) = stats.get(key) { return existing.clone(); } drop(stats); - let limit_bytes = self.access_keys.read().unwrap().get(key).and_then(|m| m.limit_bytes); + let limit_bytes = self.access_keys.read().unwrap_or_else(|e| e.into_inner()).get(key).and_then(|m| m.limit_bytes); - let mut stats = self.user_stats.write().unwrap(); + let mut stats = self.user_stats.write().unwrap_or_else(|e| e.into_inner()); stats.entry(key.to_string()) .or_insert_with(|| Arc::new(UserStats::new(limit_bytes))) .clone() @@ -141,7 +141,7 @@ impl Dispatcher { /// Set traffic limit for a user. pub fn set_user_limit(&self, key: &str, limit: Option) { - let mut stats = self.user_stats.write().unwrap(); + let mut stats = self.user_stats.write().unwrap_or_else(|e| e.into_inner()); let entry = stats.entry(key.to_string()) .or_insert_with(|| Arc::new(UserStats::new(limit))); // Replace the entry with new limit (stats reset) @@ -212,7 +212,7 @@ impl Dispatcher { let key_opt = self.peer_machines.get(&session_id).map(|ps| ps.access_key.clone()); if let Some(access_key) = key_opt { // Check if key is still valid and not over limit - let key_valid = self.access_keys.read().unwrap().contains_key(&access_key); + let key_valid = self.access_keys.read().unwrap_or_else(|e| e.into_inner()).contains_key(&access_key); let user_stats = self.get_or_create_user_stats(&access_key); if !key_valid || user_stats.is_over_limit() { tracing::info!("Dropping session {} for key {} (valid={}, over_limit={})", @@ -280,7 +280,7 @@ impl Dispatcher { } // Not an existing session — try each registered access key's derived obfuscation key - let keys_snapshot: Vec = self.access_keys.read().unwrap().keys().cloned().collect(); + let keys_snapshot: Vec = self.access_keys.read().unwrap_or_else(|e| e.into_inner()).keys().cloned().collect(); for candidate_key in keys_snapshot { let secrets = ostp_core::crypto::derive_all_secrets(candidate_key.as_bytes()); @@ -430,7 +430,7 @@ impl Dispatcher { // Gather expired or invalid sessions for (&sid, peer_state) in &self.peer_machines { - let key_valid = self.access_keys.read().unwrap().contains_key(&peer_state.access_key); + let key_valid = self.access_keys.read().unwrap_or_else(|e| e.into_inner()).contains_key(&peer_state.access_key); let user_stats = self.get_or_create_user_stats(&peer_state.access_key); if now.duration_since(peer_state.last_seen) > timeout_dur || !key_valid || user_stats.is_over_limit() { expired.push(sid); @@ -441,7 +441,7 @@ impl Dispatcher { for sid in &expired { let peer_state_opt = self.peer_machines.get(sid); let reason = if let Some(ps) = peer_state_opt { - let key_valid = self.access_keys.read().unwrap().contains_key(&ps.access_key); + let key_valid = self.access_keys.read().unwrap_or_else(|e| e.into_inner()).contains_key(&ps.access_key); let user_stats = self.get_or_create_user_stats(&ps.access_key); if now.duration_since(ps.last_seen) > timeout_dur { "inactive >5min" @@ -504,15 +504,15 @@ fn get_or_create_stats( key: &str, ) -> Arc { { - let stats = user_stats.read().unwrap(); + let stats = user_stats.read().unwrap_or_else(|e| e.into_inner()); if let Some(existing) = stats.get(key) { return existing.clone(); } } - let limit_bytes = access_keys.read().unwrap().get(key).and_then(|m| m.limit_bytes); + let limit_bytes = access_keys.read().unwrap_or_else(|e| e.into_inner()).get(key).and_then(|m| m.limit_bytes); - let mut stats = user_stats.write().unwrap(); + let mut stats = user_stats.write().unwrap_or_else(|e| e.into_inner()); stats.entry(key.to_string()) .or_insert_with(|| Arc::new(UserStats::new(limit_bytes))) .clone() diff --git a/ostp-server/src/lib.rs b/ostp-server/src/lib.rs index aa9f2d4..d837027 100644 --- a/ostp-server/src/lib.rs +++ b/ostp-server/src/lib.rs @@ -195,11 +195,11 @@ pub async fn run_server( } // 1. Update shared_keys - let mut keys_lock = shared_keys_clone.write().unwrap(); + let mut keys_lock = shared_keys_clone.write().unwrap_or_else(|e| e.into_inner()); *keys_lock = new_keys.clone(); // 2. Synchronize user_stats limits & cleanup deleted keys - let mut stats_lock = user_stats_clone.write().unwrap(); + let mut stats_lock = user_stats_clone.write().unwrap_or_else(|e| e.into_inner()); stats_lock.retain(|k, _| new_keys.contains_key(k)); for (k, meta) in &new_keys { @@ -308,7 +308,7 @@ pub async fn run_server( } }); - let key_count = shared_keys.read().unwrap().len(); + let key_count = shared_keys.read().unwrap_or_else(|e| e.into_inner()).len(); tracing::info!(listeners = bind_addrs.len(), keys = key_count, "server started"); tracing::info!("ARQ config: max_reorder=16384, reorder_buf=8192, sent_history=32768, rto=100ms"); let reality_config_arc = reality_config.map(std::sync::Arc::new); @@ -434,7 +434,7 @@ async fn run_server_loop( if debug { let _ = ui_event_tx.send(UiEvent::Log("Server loop started".to_string())); - let _ = ui_event_tx.send(UiEvent::KeyCount(shared_keys.read().unwrap().len())); + let _ = ui_event_tx.send(UiEvent::KeyCount(shared_keys.read().unwrap_or_else(|e| e.into_inner()).len())); } let mut retransmit_tick = interval(Duration::from_millis(10)); @@ -448,7 +448,7 @@ async fn run_server_loop( match cmd { Some(UiCommand::CreateClientKey) => { let key = format!("ostp_key_{}", std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs()); - shared_keys.write().unwrap().insert(key.clone(), crate::api::UserMeta { name: None, limit_bytes: None }); + shared_keys.write().unwrap_or_else(|e| e.into_inner()).insert(key.clone(), crate::api::UserMeta { name: None, limit_bytes: None }); let _ = ui_event_tx.send(UiEvent::KeyCreated { key }); } Some(UiCommand::Shutdown) | None => { diff --git a/ostp-server/src/relay_node.rs b/ostp-server/src/relay_node.rs index 37975e5..919ab58 100644 --- a/ostp-server/src/relay_node.rs +++ b/ostp-server/src/relay_node.rs @@ -48,7 +48,7 @@ pub async fn run_relay_node(cfg: RelayConfig) -> Result<()> { if let Err(e) = sync_keys(&cfg, &shared_keys).await { tracing::warn!("Relay: initial key sync failed: {}. Will retry.", e); } else { - let count = shared_keys.read().unwrap().len(); + let count = shared_keys.read().unwrap_or_else(|e| e.into_inner()).len(); tracing::info!("Relay: synced {} access key(s) from upstream API", count); } @@ -190,7 +190,7 @@ async fn run_udp_relay(cfg: RelayConfig, shared_keys: SharedKeys) -> Result<()> let ts_bytes: [u8; 8] = packet[0..8].try_into().unwrap(); let provided_mac = &packet[8..40]; - let keys_guard = keys.read().unwrap(); + let keys_guard = keys.read().unwrap_or_else(|e| e.into_inner()); if !verify_hmac(&ts_bytes, provided_mac, &keys_guard) { tracing::debug!("Relay UDP: unauthorized probe from {}, dropped", peer); @@ -369,7 +369,7 @@ async fn handle_tcp_client( // Проверяем по синхронизированным ключам let authorized = { - let keys = shared_keys.read().unwrap(); + let keys = shared_keys.read().unwrap_or_else(|e| e.into_inner()); verify_hmac(&ts_bytes, provided_mac, &keys) }; diff --git a/ostp-tun-helper/src/main.rs b/ostp-tun-helper/src/main.rs index 635149c..9dc93b8 100644 --- a/ostp-tun-helper/src/main.rs +++ b/ostp-tun-helper/src/main.rs @@ -5,7 +5,6 @@ use anyhow::Result; use serde::{Deserialize, Serialize}; use std::sync::Arc; use std::time::Duration; -use std::fs::OpenOptions; use std::io::Write as _; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use tokio::sync::{watch, Mutex}; @@ -13,21 +12,25 @@ use tokio::net::TcpListener; use portable_atomic::Ordering; fn log_to_file(msg: &str) { - let path = std::env::current_exe() - .ok() - .and_then(|p| p.parent().map(|d| d.join("ostp-helper.log"))) - .unwrap_or_else(|| std::path::PathBuf::from("ostp-helper.log")); - if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) { - let _ = writeln!(file, "[{}] {}", chrono::Local::now().format("%Y-%m-%d %H:%M:%S"), msg); - } + let msg = msg.to_string(); + tokio::task::spawn_blocking(move || { + let path = std::env::current_exe() + .ok() + .and_then(|p| p.parent().map(|d| d.join("ostp-helper.log"))) + .unwrap_or_else(|| std::path::PathBuf::from("ostp-helper.log")); + if let Ok(mut file) = std::fs::OpenOptions::new().create(true).append(true).open(path) { + let _ = writeln!(file, "[{}] {}", chrono::Local::now().format("%Y-%m-%d %H:%M:%S"), msg); + } + }); } -const BIND_ADDR: &str = "127.0.0.1:53211"; + #[derive(Deserialize)] #[serde(tag = "cmd", rename_all = "lowercase")] enum GuiCmd { Start { config: String, token: String }, + Reload { config: String, token: String }, Stop { token: String }, } @@ -55,10 +58,13 @@ async fn main() -> Result<()> { } let mut expected_token = String::new(); + let mut port = 53211u16; let args: Vec = std::env::args().collect(); for i in 1..args.len() { if args[i] == "--token" && i + 1 < args.len() { expected_token = args[i + 1].clone(); + } else if args[i] == "--port" && i + 1 < args.len() { + port = args[i + 1].parse().unwrap_or(53211); } } @@ -69,21 +75,22 @@ async fn main() -> Result<()> { return Err(anyhow::anyhow!("--token argument is required")); } - if let Err(e) = run_server(expected_token).await { + if let Err(e) = run_server(expected_token, port).await { log_to_file(&format!("Fatal error: {}", e)); } log_to_file("Helper exiting"); Ok(()) } -async fn run_server(expected_token: String) -> Result<()> { +async fn run_server(expected_token: String, port: u16) -> Result<()> { let state = Arc::new(Mutex::new(TunnelState { shutdown_tx: None, metrics: None, })); - log_to_file(&format!("Attempting to bind to {}", BIND_ADDR)); - let listener = TcpListener::bind(BIND_ADDR).await.map_err(|e| { + let bind_addr = format!("127.0.0.1:{}", port); + log_to_file(&format!("Attempting to bind to {}", bind_addr)); + let listener = TcpListener::bind(&bind_addr).await.map_err(|e| { log_to_file(&format!("Bind failed: {}", e)); e })?; @@ -182,9 +189,10 @@ async fn run_server(expected_token: String) -> Result<()> { let metrics_for_runner = metrics.clone(); let writer_for_err = writer.clone(); + let shutdown_rx_for_core = shutdown_rx.clone(); tokio::spawn(async move { log_to_file("Starting tunnel core..."); - match ostp_client::runner::run_client_core(cfg, metrics_for_runner, shutdown_rx).await { + match ostp_client::runner::run_client_core(cfg, metrics_for_runner, shutdown_rx_for_core).await { Ok(_) => { log_to_file("Tunnel core stopped normally"); } Err(e) => { log_to_file(&format!("Tunnel core error: {}", e)); @@ -197,10 +205,17 @@ async fn run_server(expected_token: String) -> Result<()> { let writer_tick = writer.clone(); let metrics_tick = metrics.clone(); + let mut shutdown_rx_tick = shutdown_rx.clone(); tokio::spawn(async move { let mut last_state = 99u8; loop { - tokio::time::sleep(Duration::from_secs(1)).await; + tokio::select! { + _ = tokio::time::sleep(Duration::from_secs(1)) => {} + _ = shutdown_rx_tick.changed() => { + if *shutdown_rx_tick.borrow() { break; } + } + } + let cs = metrics_tick.connection_state.load(Ordering::Relaxed); let sent = metrics_tick.bytes_sent.load(Ordering::Relaxed); let recv = metrics_tick.bytes_recv.load(Ordering::Relaxed); @@ -221,6 +236,95 @@ async fn run_server(expected_token: String) -> Result<()> { send_msg(HelperMsg::Status { value: 1 }); } + GuiCmd::Reload { config, token } => { + if token != expected_token { + send_msg(HelperMsg::Error { message: "Invalid authorization token".to_string() }); + continue; + } + log_to_file("Received RELOAD command"); + + // Signal shutdown to current core + { + let mut st = state.lock().await; + if let Some(tx) = st.shutdown_tx.take() { + let _ = tx.send(true); + } + tokio::time::sleep(Duration::from_millis(500)).await; // give it time to shutdown cleanly + } + + let cfg: ostp_client::config::ClientConfig = match serde_json::from_str(&config) { + Ok(c) => c, + Err(e) => { + send_msg(HelperMsg::Error { message: format!("Config parse error during reload: {}", e) }); + continue; + } + }; + + let metrics = Arc::new(ostp_client::bridge::BridgeMetrics { + bytes_sent: portable_atomic::AtomicU64::new(0), + bytes_recv: portable_atomic::AtomicU64::new(0), + connection_state: portable_atomic::AtomicU8::new(0), + rtt_ms: portable_atomic::AtomicU32::new(0), + }); + + let (shutdown_tx, shutdown_rx) = watch::channel(false); + + { + let mut st = state.lock().await; + st.shutdown_tx = Some(shutdown_tx); + st.metrics = Some(metrics.clone()); + } + + let metrics_for_runner = metrics.clone(); + let writer_for_err = writer.clone(); + let shutdown_rx_for_core = shutdown_rx.clone(); + tokio::spawn(async move { + log_to_file("Restarting tunnel core for reload..."); + match ostp_client::runner::run_client_core(cfg, metrics_for_runner, shutdown_rx_for_core).await { + Ok(_) => { log_to_file("Reloaded core stopped normally"); } + Err(e) => { + let json = serde_json::to_string(&HelperMsg::Error { message: e.to_string() }).unwrap_or_default(); + let mut w = writer_for_err.lock().await; + let _ = w.write_all(format!("{}\n", json).as_bytes()).await; + } + } + }); + + // Status tick loop is already running and using old metrics? + // Wait! We re-created metrics, so the old tick loop will continue reporting old metrics (which are disconnected)! + // We should probably share the tick loop or spawn a new one and let the old one die. + // It's easier if `metrics` in state is a generic watcher, but since we re-spawned it: + let writer_tick = writer.clone(); + let metrics_tick = metrics.clone(); + let mut shutdown_rx_tick = shutdown_rx.clone(); + tokio::spawn(async move { + let mut last_state = 99u8; + loop { + tokio::select! { + _ = tokio::time::sleep(Duration::from_secs(1)) => {} + _ = shutdown_rx_tick.changed() => { + if *shutdown_rx_tick.borrow() { break; } + } + } + let cs = metrics_tick.connection_state.load(Ordering::Relaxed); + let sent = metrics_tick.bytes_sent.load(Ordering::Relaxed); + let recv = metrics_tick.bytes_recv.load(Ordering::Relaxed); + let rtt = metrics_tick.rtt_ms.load(Ordering::Relaxed); + + let mut w = writer_tick.lock().await; + if cs != last_state { + last_state = cs; + let json = serde_json::to_string(&HelperMsg::Status { value: cs }).unwrap_or_default(); + if w.write_all(format!("{}\n", json).as_bytes()).await.is_err() { break; } + } + let json = serde_json::to_string(&HelperMsg::Metrics { bytes_sent: sent, bytes_recv: recv, rtt_ms: rtt }).unwrap_or_default(); + if w.write_all(format!("{}\n", json).as_bytes()).await.is_err() { break; } + drop(w); + } + }); + + send_msg(HelperMsg::Status { value: 1 }); + } GuiCmd::Stop { token } => { if token != expected_token { log_to_file("Received STOP command with invalid token"); diff --git a/refactor.py b/refactor.py new file mode 100644 index 0000000..12b9fc4 --- /dev/null +++ b/refactor.py @@ -0,0 +1,658 @@ +import sys +import re + +with open("d:/ospab-projects/ostp/ostp-client/src/bridge.rs", "r", encoding="utf-8") as f: + code = f.read() + +start_idx = code.find(" pub async fn run(") +end_idx = -1 +brace_count = 0 +in_run = False +for i in range(start_idx, len(code)): + if code[i] == '{': + in_run = True + brace_count += 1 + elif code[i] == '}': + if in_run: + brace_count -= 1 + if brace_count == 0: + end_idx = i + 1 + break + +prefix = code[:start_idx] +suffix = code[end_idx:] + +# Define the new run function and helpers +new_run_and_helpers = """ + pub async fn run( + mut self, + tx: mpsc::Sender, + mut bridge_rx: mpsc::Receiver, + mut shutdown: watch::Receiver, + mut proxy_rx: mpsc::Receiver, + proxy_tx: mpsc::UnboundedSender<(u16, ProxyToClientMsg)>, + ) -> Result<()> { + let mut metrics_tick = interval(Duration::from_millis(500)); + let mut keepalive_tick = tokio::time::interval(Duration::from_secs(self.keepalive_interval_sec.max(1))); + let mut retransmit_tick = tokio::time::interval(Duration::from_millis(10)); + let init_msg = if self.mode == "tun" { + "Bridge initialized (TUN mode)".to_string() + } else { + "Bridge initialized (proxy mode)".to_string() + }; + tx.send(UiEvent::Log(init_msg)).await.ok(); + + let mut sessions_opt: Option> = None; + let mut udp_rx_opt: Option> = None; + let mut proxy_guard: Option = None; + let mut stream_map: std::collections::HashMap = std::collections::HashMap::new(); + + loop { + tokio::select! { + biased; + _ = shutdown.changed() => { + if *shutdown.borrow() { + self.running = false; + self.metrics.connection_state.store(0, Ordering::Relaxed); + proxy_guard = None; + sessions_opt = None; + udp_rx_opt = None; + stream_map.clear(); + self.reset_proxy_streams(&tx, &proxy_tx, "manual stop"); + break; + } + } + udp_msg = async { + match udp_rx_opt.as_mut() { + Some(rx) => rx.recv().await, + None => std::future::pending().await, + } + }, if self.running => { + self.handle_inbound_udp(udp_msg, &mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx).await; + } + cmd = bridge_rx.recv() => { + if !self.handle_bridge_cmd(cmd, &mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx).await { + break; + } + } + _ = metrics_tick.tick() => { + if self.running { + self.emit_metrics(&tx).await; + } + } + _ = keepalive_tick.tick() => { + if self.running { + self.handle_keepalive(&mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx, &mut proxy_rx).await; + } + } + _ = retransmit_tick.tick() => { + if self.running { + self.handle_retransmit(&mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx).await; + } + } + proxy_ev = proxy_rx.recv(), if self.running && sessions_opt.as_ref().map(|s| { + s.iter().any(|ses| ses.machine.in_flight_count() < ses.machine.cwnd_packets().clamp(16, 16384)) + }).unwrap_or(true) => { + self.handle_proxy_event(proxy_ev, &mut sessions_opt, &mut stream_map, &tx, &proxy_tx).await; + } + } + } + + tx.send(UiEvent::Log("Bridge stopped".to_string())).await.ok(); + Ok(()) + } + + async fn handle_inbound_udp( + &mut self, + udp_msg: Option<(usize, Bytes)>, + sessions_opt: &mut Option>, + udp_rx_opt: &mut Option>, + proxy_guard: &mut Option, + stream_map: &mut std::collections::HashMap, + tx: &mpsc::Sender, + proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>, + ) { + match udp_msg { + Some((session_index, inbound)) => { + self.metrics.bytes_recv.fetch_add(inbound.len() as u64, Ordering::Relaxed); + self.last_valid_recv = Instant::now(); + if let Some(sessions) = sessions_opt.as_mut() { + if session_index < sessions.len() { + let session = &mut sessions[session_index]; + let initial_action = match session.machine.on_event(OstpEvent::Inbound(inbound)) { + Ok(a) => a, + Err(e) => { + let _ = tx.send(UiEvent::Log(format!("Protocol decrypt error: {e}"))).await; + tracing::warn!("Inbound protocol error (session {}): {}", session_index, e); + return; + } + }; + + let mut actions_queue = std::collections::VecDeque::new(); + actions_queue.push_back(initial_action); + + while let Some(current_action) = actions_queue.pop_front() { + match current_action { + ProtocolAction::Multiple(nested) => { + for a in nested { + actions_queue.push_back(a); + } + } + ProtocolAction::DeliverApp(stream_id, dec_payload) => { + match RelayMessage::decode(&dec_payload) { + Ok(relay_msg) => { + match relay_msg { + RelayMessage::ConnectOk => { + let _ = tx.send(UiEvent::Log(format!("Relay CONNECT OK stream_id={stream_id}"))).await; + let _ = proxy_tx.send((stream_id, ProxyToClientMsg::ConnectOk)); + } + RelayMessage::Data(data) => { + let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Data(Bytes::from(data)))); + } + RelayMessage::Close => { + let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Close)); + } + RelayMessage::Error(msg) => { + let _ = tx.send(UiEvent::Log(format!("Relay error for stream {stream_id}: {msg}"))).await; + let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error(msg))); + } + RelayMessage::Pong(ts) => { + let now = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64; + self.last_rtt_ms = now.saturating_sub(ts) as f64; + self.metrics.rtt_ms.store(self.last_rtt_ms as u32, Ordering::Relaxed); + } + RelayMessage::UdpAssociate => {} + RelayMessage::UdpData(target, data) => { + let _ = proxy_tx.send((stream_id, ProxyToClientMsg::UdpData(target, Bytes::from(data)))); + } + RelayMessage::KeepAlive | RelayMessage::Ping(_) | RelayMessage::Connect(_) => {} + } + } + Err(err) => { + let _ = tx.send(UiEvent::Log(format!("Relay decode error for stream {stream_id}: {err}"))).await; + let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("relay decode failed".to_string()))); + } + } + } + ProtocolAction::SendDatagram(frame) => { + let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await; + self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); + } + _ => {} + } + } + } + } + } + None => { + let _ = tx.send(UiEvent::Log("UDP channel closed, resetting connection".to_string())).await; + self.running = false; + crate::sysproxy::disable_system_proxy(); + *sessions_opt = None; + *udp_rx_opt = None; + stream_map.clear(); + self.reset_proxy_streams(&tx, &proxy_tx, "udp reader closed"); + let _ = tx.send(UiEvent::TunnelStopped).await; + } + } + } + + async fn handle_bridge_cmd( + &mut self, + cmd: Option, + sessions_opt: &mut Option>, + udp_rx_opt: &mut Option>, + proxy_guard: &mut Option, + stream_map: &mut std::collections::HashMap, + tx: &mpsc::Sender, + proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>, + ) -> bool { + match cmd { + Some(BridgeCommand::ToggleTunnel) => { + if self.running { + self.running = false; + self.metrics.connection_state.store(0, Ordering::Relaxed); + *proxy_guard = None; + *sessions_opt = None; + *udp_rx_opt = None; + stream_map.clear(); + self.reset_proxy_streams(&tx, &proxy_tx, "manual stop"); + tx.send(UiEvent::TunnelStopped).await.ok(); + let stop_msg = if self.mode == "tun" { "TUN tunnel stopped" } else { "Bridge stopped" }; + tx.send(UiEvent::Log(stop_msg.to_string())).await.ok(); + } else { + tx.send(UiEvent::Log("Connecting to remote server...".to_string())).await.ok(); + tx.send(UiEvent::Metrics { status: ConnectionStatus::Handshaking, rtt_ms: 0.0, throughput_bps: 0 }).await.ok(); + self.metrics.connection_state.store(1, Ordering::Relaxed); + + let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 }; + let (udp_tx, udp_rx) = mpsc::channel(100000); + let mut sessions = Vec::with_capacity(session_count); + let mut rtt_sum = 0.0; + let mut successful_sessions = 0; + + for idx in 0..session_count { + let session_id: u32 = rand::thread_rng().gen(); + match self.perform_handshake_with_id(&tx, session_id).await { + Ok((sock, mach, rtt)) => { + let session_index = sessions.len(); + let socket_clone = sock.clone(); + let udp_tx_clone = udp_tx.clone(); + + tokio::spawn(async move { + let mut buf = vec![0_u8; 65535]; + loop { + match socket_clone.recv(&mut buf).await { + Ok(n) => { + let inbound = Bytes::copy_from_slice(&buf[..n]); + if udp_tx_clone.send((session_index, inbound)).await.is_err() { + break; + } + } + Err(e) => { + tracing::warn!("UDP socket recv error (session {}): {}", session_index, e); + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + } + } + }); + + sessions.push(SessionState { socket: sock, machine: mach }); + rtt_sum += rtt; + successful_sessions += 1; + } + Err(err) => { + tx.send(UiEvent::Log(format!("Multiplex session {}/{} handshake failed: {}. Continuing with remaining sessions...", idx + 1, session_count, err))).await.ok(); + } + } + } + + if sessions.is_empty() { + *proxy_guard = None; + tx.send(UiEvent::Log("All multiplexed handshake attempts failed. Connection aborted.".to_string())).await.ok(); + tx.send(UiEvent::TunnelStopped).await.ok(); + self.metrics.connection_state.store(0, Ordering::Relaxed); + return True; + } + + *udp_rx_opt = Some(udp_rx); + *sessions_opt = Some(sessions); + self.last_rtt_ms = rtt_sum / successful_sessions as f64; + self.running = true; + self.last_sample_at = Instant::now(); + self.last_valid_recv = Instant::now(); + + let sys_proxy_addr = self.proxy_addr.replace("0.0.0.0:", "127.0.0.1:"); + *proxy_guard = Some(crate::sysproxy::SystemProxyGuard::enable(&sys_proxy_addr)); + + tx.send(UiEvent::Metrics { + status: ConnectionStatus::Established, + rtt_ms: self.last_rtt_ms, + throughput_bps: 0, + }).await.ok(); + self.metrics.connection_state.store(2, Ordering::Relaxed); + let start_msg = if self.mode == "tun" { "TUN tunnel established" } else { "Connection established" }; + tx.send(UiEvent::Log(start_msg.to_string())).await.ok(); + + for session in sessions_opt.as_mut().unwrap().iter_mut() { + let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64; + let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode()); + if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ping_payload)) { + let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp").await; + self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); + } + } + } + } + Some(BridgeCommand::NextProfile) => { + self.profile = next_profile(self.profile); + tx.send(UiEvent::ProfileChanged(self.profile)).await.ok(); + tx.send(UiEvent::Log(format!("Obfuscation profile switched to {:?}", self.profile))).await.ok(); + } + Some(BridgeCommand::NetworkChanged) => { + if self.running { + let _ = tx.send(UiEvent::Log("Network changed — starting immediate reconnect".to_string())).await; + self.metrics.connection_state.store(1, Ordering::Relaxed); + self.last_valid_recv = Instant::now() - Duration::from_secs(100); + + let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 }; + let (udp_tx, udp_rx) = mpsc::channel(100000); + let mut new_sessions = Vec::with_capacity(session_count); + let mut successful_sessions = 0; + let mut rtt_sum = 0.0; + + for idx in 0..session_count { + let session_id: u32 = rand::thread_rng().gen(); + match self.perform_handshake_with_id(&tx, session_id).await { + Ok((sock, mach, rtt)) => { + let session_index = new_sessions.len(); + let socket_clone = sock.clone(); + let udp_tx_clone = udp_tx.clone(); + + tokio::spawn(async move { + let mut buf = vec![0_u8; 65535]; + loop { + match socket_clone.recv(&mut buf).await { + Ok(n) => { + let inbound = Bytes::copy_from_slice(&buf[..n]); + if udp_tx_clone.send((session_index, inbound)).await.is_err() { break; } + } + Err(e) => { + tracing::warn!("UDP recv error (network-change session {}): {}", session_index, e); + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + } + } + }); + new_sessions.push(SessionState { socket: sock, machine: mach }); + rtt_sum += rtt; + successful_sessions += 1; + } + Err(err) => { + let _ = tx.send(UiEvent::Log(format!("NetworkChanged reconnect session {}/{} failed: {}", idx + 1, session_count, err))).await; + } + } + } + + if !new_sessions.is_empty() { + *sessions_opt = Some(new_sessions); + *udp_rx_opt = Some(udp_rx); + self.last_rtt_ms = rtt_sum / successful_sessions as f64; + self.last_valid_recv = Instant::now(); + stream_map.clear(); + self.reset_proxy_streams(&tx, &proxy_tx, "network changed"); + self.metrics.connection_state.store(2, Ordering::Relaxed); + let _ = tx.send(UiEvent::Log("NetworkChanged reconnect successful!".to_string())).await; + } else { + let _ = tx.send(UiEvent::Log("NetworkChanged reconnect failed — will retry on keepalive tick".to_string())).await; + } + } + } + Some(BridgeCommand::ReloadConfig) => { + match ClientConfig::reload_from_json_near_binary() { + Ok(cfg) => { + self.apply_runtime_config(&cfg); + tx.send(UiEvent::Log("Runtime config reloaded".to_string())).await.ok(); + if self.running { + self.running = false; + self.metrics.connection_state.store(0, Ordering::Relaxed); + *proxy_guard = None; + *sessions_opt = None; + stream_map.clear(); + self.reset_proxy_streams(&tx, &proxy_tx, "config reload"); + let _ = tx.send(UiEvent::TunnelStopped).await; + } + } + Err(err) => { + let _ = tx.send(UiEvent::Log(format!("Config reload failed: {err}"))).await; + } + } + } + Some(BridgeCommand::Shutdown) | None => { + self.running = false; + *proxy_guard = None; + return False; + } + } + True + } + + async fn handle_keepalive( + &mut self, + sessions_opt: &mut Option>, + udp_rx_opt: &mut Option>, + proxy_guard: &mut Option, + stream_map: &mut std::collections::HashMap, + tx: &mpsc::Sender, + proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>, + proxy_rx: &mut mpsc::Receiver, + ) { + if self.last_valid_recv.elapsed().as_secs() > 25 { + let elapsed = self.last_valid_recv.elapsed().as_secs(); + if elapsed > 180 { + let _ = tx.send(UiEvent::Log("Connection permanently lost (3-minute hard timeout). Stopping tunnel.".into())).await; + self.running = false; + *proxy_guard = None; + *sessions_opt = None; + stream_map.clear(); + self.reset_proxy_streams(&tx, &proxy_tx, "keepalive hard timeout"); + let _ = tx.send(UiEvent::TunnelStopped).await; + self.metrics.connection_state.store(0, Ordering::Relaxed); + return; + } + + let _ = tx.send(UiEvent::Log(format!("Connection stall detected ({}s silence). Attempting background reconnect...", elapsed))).await; + self.metrics.connection_state.store(1, Ordering::Relaxed); + + let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 }; + let (udp_tx, udp_rx) = mpsc::channel(100000); + let mut new_sessions = Vec::with_capacity(session_count); + let mut successful_sessions = 0; + let mut rtt_sum = 0.0; + + for idx in 0..session_count { + let session_id: u32 = rand::thread_rng().gen(); + match self.perform_handshake_with_id(&tx, session_id).await { + Ok((sock, mach, rtt)) => { + let session_index = new_sessions.len(); + let socket_clone = sock.clone(); + let udp_tx_clone = udp_tx.clone(); + + tokio::spawn(async move { + let mut buf = vec![0_u8; 65535]; + loop { + match socket_clone.recv(&mut buf).await { + Ok(n) => { + let inbound = Bytes::copy_from_slice(&buf[..n]); + if udp_tx_clone.send((session_index, inbound)).await.is_err() { + break; + } + } + Err(e) => { + tracing::warn!("UDP socket recv error (reconnect session {}): {}", session_index, e); + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + } + } + }); + + new_sessions.push(SessionState { socket: sock, machine: mach }); + rtt_sum += rtt; + successful_sessions += 1; + } + Err(err) => { + let _ = tx.send(UiEvent::Log(format!("Background reconnect session {}/{} failed: {}", idx + 1, session_count, err))).await; + } + } + } + + if !new_sessions.is_empty() { + *sessions_opt = Some(new_sessions); + *udp_rx_opt = Some(udp_rx); + self.last_rtt_ms = rtt_sum / successful_sessions as f64; + self.last_valid_recv = Instant::now(); + self.metrics.connection_state.store(2, Ordering::Relaxed); + let _ = tx.send(UiEvent::Log("Background reconnect successful! Connection restored.".into())).await; + + for session in sessions_opt.as_mut().unwrap().iter_mut() { + let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64; + let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode()); + if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ping_payload)) { + let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp").await; + self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); + } + } + + stream_map.clear(); + self.reset_proxy_streams(&tx, &proxy_tx, "background reconnect"); + + let mut flushed = 0; + while let Ok(stale) = proxy_rx.try_recv() { + if let ProxyEvent::NewStream { stream_id, .. } = stale { + let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("connection reset".into()))); + } + flushed += 1; + } + if flushed > 0 { + let _ = tx.send(UiEvent::Log(format!("Flushed {} stale proxy messages to prevent UDP burst", flushed))).await; + } + } else { + let _ = tx.send(UiEvent::Log("Background reconnect failed. Will retry on next tick...".into())).await; + } + } + + if let Some(sessions) = sessions_opt.as_mut() { + for session in sessions.iter_mut() { + let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64; + let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode()); + if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ping_payload)) { + let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await; + self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); + } + + let ka_payload = Bytes::from(RelayMessage::KeepAlive.encode()); + if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ka_payload)) { + let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await; + self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); + } + } + } + } + + async fn handle_retransmit( + &mut self, + sessions_opt: &mut Option>, + udp_rx_opt: &mut Option>, + proxy_guard: &mut Option, + stream_map: &mut std::collections::HashMap, + tx: &mpsc::Sender, + proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>, + ) { + let mut fatal_err = None; + if let Some(sessions) = sessions_opt.as_mut() { + for session in sessions.iter_mut() { + match session.machine.on_event(OstpEvent::Tick) { + Ok(action) => { + let mut queue = vec![action]; + while let Some(current_action) = queue.pop() { + match current_action { + ProtocolAction::Multiple(nested) => { + for a in nested { + queue.push(a); + } + } + ProtocolAction::SendDatagram(frame) => { + let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await; + self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); + } + _ => {} + } + } + } + Err(e) => { + fatal_err = Some(e); + break; + } + } + } + } + + if let Some(e) = fatal_err { + let _ = tx.send(UiEvent::Log(format!("Protocol tick fatal error: {e}"))).await; + self.running = false; + *proxy_guard = None; + *sessions_opt = None; + *udp_rx_opt = None; + stream_map.clear(); + self.reset_proxy_streams(&tx, &proxy_tx, "protocol fatal error"); + let _ = tx.send(UiEvent::TunnelStopped).await; + self.metrics.connection_state.store(0, Ordering::Relaxed); + } + } + + async fn handle_proxy_event( + &mut self, + proxy_ev: Option, + sessions_opt: &mut Option>, + stream_map: &mut std::collections::HashMap, + tx: &mpsc::Sender, + proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>, + ) { + if let Some(ev) = proxy_ev { + if let Some(sessions) = sessions_opt.as_mut() { + if sessions.is_empty() { + if let ProxyEvent::NewStream { stream_id, .. } = ev { + let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("tunnel stopped".into()))); + } + return; + } + let (stream_id, relay_msg, is_close) = match ev { + ProxyEvent::NewStream { stream_id, target } => { + let _ = tx.send(UiEvent::Log(format!("Proxy CONNECT stream_id={stream_id} target={target}"))).await; + (stream_id, RelayMessage::Connect(target), false) + } + ProxyEvent::UdpAssociate { stream_id } => { + let _ = tx.send(UiEvent::Log(format!("Proxy UDP ASSOCIATE stream_id={stream_id}"))).await; + (stream_id, RelayMessage::UdpAssociate, false) + } + ProxyEvent::UdpData { stream_id, target, payload } => { + (stream_id, RelayMessage::UdpData(target, payload.to_vec()), false) + } + ProxyEvent::Data { stream_id, payload } => (stream_id, RelayMessage::Data(payload.to_vec()), false), + ProxyEvent::Close { stream_id } => { + let _ = tx.send(UiEvent::Log(format!("Proxy CLOSE stream_id={stream_id}"))).await; + (stream_id, RelayMessage::Close, true) + } + }; + let len = sessions.len(); + let session_index = *stream_map.entry(stream_id).or_insert_with(|| { + rand::thread_rng().gen_range(0..len) + }); + if is_close { + stream_map.remove(&stream_id); + } + let session = &mut sessions[session_index]; + let out_payload = Bytes::from(relay_msg.encode()); + match session.machine.on_event(OstpEvent::Outbound(stream_id, out_payload)) { + Ok(ProtocolAction::SendDatagram(frame)) => { + if send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await.is_ok() { + self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); + tracing::trace!("Outbound datagram sent stream_id={stream_id} bytes={}", frame.len()); + } + } + Ok(ProtocolAction::Multiple(list)) => { + let mut sent = 0usize; + for item in list { + if let ProtocolAction::SendDatagram(frame) = item { + if send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await.is_ok() { + self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); + sent += 1; + } + } + } + tracing::trace!("Outbound datagram batch stream_id={stream_id} sent={sent}"); + } + Ok(ProtocolAction::Noop) => { + tracing::trace!("Outbound datagram noop stream_id={stream_id}"); + } + Ok(_) => { + tracing::trace!("Outbound datagram unexpected action stream_id={stream_id}"); + } + Err(e) => { + tracing::warn!("Protocol error packing outbound stream_id={}: {}", stream_id, e); + let _ = tx.send(UiEvent::Log(format!("Protocol error packing TCP: {e}"))).await; + } + } + } else { + if let ProxyEvent::NewStream { stream_id, .. } = ev { + let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("tunnel stopped".into()))); + } + } + } + } +""" + +with open("d:/ospab-projects/ostp/ostp-client/src/bridge.rs", "w", encoding="utf-8") as f: + f.write(prefix + new_run_and_helpers + suffix) + +print("Done")