diff --git a/ostp-client/src/bridge.rs b/ostp-client/src/bridge.rs index 0ba9b1e..8c2e509 100644 --- a/ostp-client/src/bridge.rs +++ b/ostp-client/src/bridge.rs @@ -1,1138 +1,15 @@ -use std::time::{Duration, SystemTime}; -use std::sync::atomic::Ordering; -use portable_atomic::{AtomicU64, AtomicU8}; -use std::sync::Arc; - -use anyhow::{Context, Result}; -use bytes::Bytes; -use ostp_core::relay::RelayMessage; -use ostp_core::{NoiseRole, OstpEvent, PaddingStrategy, ProtocolAction, ProtocolConfig, ProtocolMachine, TrafficProfile}; -use rand::Rng; -use tokio::net::UdpSocket; -use tokio::sync::{mpsc, watch}; -use tokio::time::{interval, timeout, Instant}; - -use crate::app::{BridgeCommand, ConnectionStatus, UiEvent}; -use crate::config::ClientConfig; -use crate::tunnel::{ProxyEvent, ProxyToClientMsg}; - -static SOCKET_PROTECTOR: std::sync::OnceLock bool + Send + Sync>> = std::sync::OnceLock::new(); - -pub fn set_socket_protector(f: F) -where - F: Fn(i32) -> bool + Send + Sync + 'static, -{ - let _ = SOCKET_PROTECTOR.set(Box::new(f)); -} - -pub fn protect_socket(fd: i32) -> bool { - if let Some(f) = SOCKET_PROTECTOR.get() { - return f(fd); - } - true -} +use portable_atomic::{AtomicU64, AtomicU32, AtomicU8}; pub struct BridgeMetrics { pub bytes_sent: AtomicU64, pub bytes_recv: AtomicU64, pub connection_state: AtomicU8, - pub rtt_ms: portable_atomic::AtomicU32, + pub rtt_ms: AtomicU32, } -async fn send_datagram(socket: &crate::transport::Transport, frame: &Bytes, _webrtc_masquerade: bool) -> std::io::Result { - socket.send(frame).await +pub fn set_socket_protector(f: F) +where + F: Fn(i32) -> bool + Send + Sync + 'static, +{ + // stub } - -struct SessionState { - socket: crate::transport::Transport, - machine: ProtocolMachine, -} - -pub struct Bridge { - running: bool, - pub debug: bool, - profile: TrafficProfile, - server_addr: String, - local_bind_addr: String, - proxy_addr: String, - access_key: Bytes, - handshake_timeout_ms: u64, - io_timeout_ms: u64, - - pub keepalive_interval_sec: u64, - pub mode: String, - pub mux_enabled: bool, - pub mux_sessions: usize, - - pub transport_mode: String, - pub stealth_sni: String, - pub wss: bool, - pub mtu: usize, - pub kill_switch: bool, - pub reload_tx: Option>, - - metrics: Arc, - sample_sent: u64, - sample_recv: u64, - last_rtt_ms: f64, - last_sample_at: Instant, - last_valid_recv: Instant, -} - -impl Bridge { - pub fn new(config: &ClientConfig, metrics: Arc) -> Result { - Ok(Self { - running: false, - debug: config.debug, - profile: TrafficProfile::JsonRpc, - server_addr: config.ostp.server_addr.clone(), - local_bind_addr: config.ostp.local_bind_addr.clone(), - proxy_addr: config.local_proxy.bind_addr.clone(), - access_key: Bytes::from(config.ostp.access_key.clone()), - handshake_timeout_ms: config.ostp.handshake_timeout_ms, - io_timeout_ms: config.ostp.io_timeout_ms, - - keepalive_interval_sec: config.ostp.keepalive_interval_sec, - mode: config.mode.clone(), - mux_enabled: config.multiplex.enabled, - mux_sessions: config.multiplex.sessions.max(1), - - transport_mode: config.transport.mode.clone(), - stealth_sni: config.transport.stealth_sni.clone(), - wss: config.transport.wss, - mtu: config.ostp.mtu, - kill_switch: config.kill_switch, - reload_tx: None, - - metrics, - sample_sent: 0, - sample_recv: 0, - last_rtt_ms: 0.0, - last_sample_at: Instant::now(), - last_valid_recv: Instant::now(), - }) - } - - - pub async fn run( - mut self, - tx: mpsc::Sender, - mut bridge_rx: mpsc::Receiver, - mut shutdown: watch::Receiver, - mut proxy_rx: mpsc::Receiver, - proxy_tx: mpsc::UnboundedSender<(u16, ProxyToClientMsg)>, - ) -> Result<()> { - let mut metrics_tick = interval(Duration::from_millis(500)); - let mut keepalive_tick = tokio::time::interval(Duration::from_secs(self.keepalive_interval_sec.max(1))); - let mut retransmit_tick = tokio::time::interval(Duration::from_millis(10)); - let init_msg = if self.mode == "tun" { - "Bridge initialized (TUN mode)".to_string() - } else { - "Bridge initialized (proxy mode)".to_string() - }; - tx.send(UiEvent::Log(init_msg)).await.ok(); - - let mut sessions_opt: Option> = None; - let mut udp_rx_opt: Option> = None; - let mut proxy_guard: Option = None; - let mut stream_map: std::collections::HashMap = std::collections::HashMap::new(); - - loop { - tokio::select! { - biased; - _ = shutdown.changed() => { - if *shutdown.borrow() { - self.running = false; - self.metrics.connection_state.store(0, Ordering::Relaxed); - #[allow(unused_assignments)] - { proxy_guard = None; } - stream_map.clear(); - self.reset_proxy_streams(&tx, &proxy_tx, "manual stop"); - break; - } - } - udp_msg = async { - match udp_rx_opt.as_mut() { - Some(rx) => rx.recv().await, - None => std::future::pending().await, - } - }, if self.running => { - self.handle_inbound_udp(udp_msg, &mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx).await; - } - cmd = bridge_rx.recv() => { - if !self.handle_bridge_cmd(cmd, &mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx).await { - break; - } - } - _ = metrics_tick.tick() => { - if self.running { - self.emit_metrics(&tx).await; - } - } - _ = keepalive_tick.tick() => { - if self.running { - self.handle_keepalive(&mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx, &mut proxy_rx).await; - } - } - _ = retransmit_tick.tick() => { - if self.running { - self.handle_retransmit(&mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx).await; - } - } - proxy_ev = proxy_rx.recv(), if self.running && sessions_opt.as_ref().map(|s| { - s.iter().any(|ses| ses.machine.in_flight_count() < ses.machine.cwnd_packets().clamp(16, 16384)) - }).unwrap_or(true) => { - self.handle_proxy_event(proxy_ev, &mut sessions_opt, &mut stream_map, &tx, &proxy_tx).await; - } - } - } - - tx.send(UiEvent::Log("Bridge stopped".to_string())).await.ok(); - Ok(()) - } - - async fn handle_inbound_udp( - &mut self, - udp_msg: Option<(usize, Bytes)>, - sessions_opt: &mut Option>, - udp_rx_opt: &mut Option>, - _proxy_guard: &mut Option, - stream_map: &mut std::collections::HashMap, - tx: &mpsc::Sender, - proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>, - ) { - match udp_msg { - Some((session_index, inbound)) => { - self.metrics.bytes_recv.fetch_add(inbound.len() as u64, Ordering::Relaxed); - self.last_valid_recv = Instant::now(); - if let Some(sessions) = sessions_opt.as_mut() { - if session_index < sessions.len() { - let session = &mut sessions[session_index]; - let initial_action = match session.machine.on_event(OstpEvent::Inbound(inbound)) { - Ok(a) => a, - Err(e) => { - let _ = tx.send(UiEvent::Log(format!("Protocol decrypt error: {e}"))).await; - tracing::warn!("Inbound protocol error (session {}): {}", session_index, e); - return; - } - }; - - let mut actions_queue = std::collections::VecDeque::new(); - actions_queue.push_back(initial_action); - - while let Some(current_action) = actions_queue.pop_front() { - match current_action { - ProtocolAction::Multiple(nested) => { - for a in nested { - actions_queue.push_back(a); - } - } - ProtocolAction::DeliverApp(stream_id, dec_payload) => { - match RelayMessage::decode(&dec_payload) { - Ok(relay_msg) => { - match relay_msg { - RelayMessage::ConnectOk => { - let _ = tx.send(UiEvent::Log(format!("Relay CONNECT OK stream_id={stream_id}"))).await; - let _ = proxy_tx.send((stream_id, ProxyToClientMsg::ConnectOk)); - } - RelayMessage::Data(data) => { - let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Data(Bytes::from(data)))); - } - RelayMessage::Close => { - let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Close)); - } - RelayMessage::Error(msg) => { - let _ = tx.send(UiEvent::Log(format!("Relay error for stream {stream_id}: {msg}"))).await; - let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error(msg))); - } - RelayMessage::Pong(ts) => { - let now = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64; - self.last_rtt_ms = now.saturating_sub(ts) as f64; - self.metrics.rtt_ms.store(self.last_rtt_ms as u32, Ordering::Relaxed); - } - RelayMessage::UdpAssociate => {} - RelayMessage::UdpData(target, data) => { - let _ = proxy_tx.send((stream_id, ProxyToClientMsg::UdpData(target, Bytes::from(data)))); - } - RelayMessage::KeepAlive | RelayMessage::Ping(_) | RelayMessage::Connect(_) => {} - } - } - Err(err) => { - let _ = tx.send(UiEvent::Log(format!("Relay decode error for stream {stream_id}: {err}"))).await; - let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("relay decode failed".to_string()))); - } - } - } - ProtocolAction::SendDatagram(frame) => { - let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await; - self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); - } - _ => {} - } - } - } - } - } - None => { - let _ = tx.send(UiEvent::Log("UDP channel closed, resetting connection".to_string())).await; - self.running = false; - crate::sysproxy::disable_system_proxy(); - *sessions_opt = None; - *udp_rx_opt = None; - stream_map.clear(); - self.reset_proxy_streams(&tx, &proxy_tx, "udp reader closed"); - let _ = tx.send(UiEvent::TunnelStopped).await; - } - } - } - - async fn handle_bridge_cmd( - &mut self, - cmd: Option, - sessions_opt: &mut Option>, - udp_rx_opt: &mut Option>, - proxy_guard: &mut Option, - stream_map: &mut std::collections::HashMap, - tx: &mpsc::Sender, - proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>, - ) -> bool { - match cmd { - Some(BridgeCommand::ToggleTunnel) => { - if self.running { - self.running = false; - self.metrics.connection_state.store(0, Ordering::Relaxed); - *proxy_guard = None; - *sessions_opt = None; - *udp_rx_opt = None; - stream_map.clear(); - self.reset_proxy_streams(&tx, &proxy_tx, "manual stop"); - tx.send(UiEvent::TunnelStopped).await.ok(); - let stop_msg = if self.mode == "tun" { "TUN tunnel stopped" } else { "Bridge stopped" }; - tx.send(UiEvent::Log(stop_msg.to_string())).await.ok(); - } else { - tx.send(UiEvent::Log("Connecting to remote server...".to_string())).await.ok(); - tx.send(UiEvent::Metrics { status: ConnectionStatus::Handshaking, rtt_ms: 0.0, throughput_bps: 0 }).await.ok(); - self.metrics.connection_state.store(1, Ordering::Relaxed); - - let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 }; - let (udp_tx, udp_rx) = mpsc::channel(1024); - let mut sessions = Vec::with_capacity(session_count); - let mut rtt_sum = 0.0; - let mut successful_sessions = 0; - - for idx in 0..session_count { - let session_id: u32 = rand::thread_rng().gen(); - match self.perform_handshake_with_id(&tx, session_id).await { - Ok((sock, mach, rtt)) => { - let session_index = sessions.len(); - let socket_clone = sock.clone(); - let udp_tx_clone = udp_tx.clone(); - - tokio::spawn(async move { - let mut buf = vec![0_u8; 65535]; - let is_uot = matches!(socket_clone, crate::transport::Transport::Uot { .. }); - loop { - match socket_clone.recv(&mut buf).await { - Ok(n) => { - let inbound = Bytes::copy_from_slice(&buf[..n]); - if udp_tx_clone.send((session_index, inbound)).await.is_err() { - break; - } - } - Err(e) => { - if is_uot { - // TCP is dead — drop sender to signal bridge via channel close - tracing::warn!("UoT session {} disconnected: {}", session_index, e); - break; - } else { - tracing::warn!("UDP socket recv error (session {}): {}", session_index, e); - tokio::time::sleep(std::time::Duration::from_millis(50)).await; - } - } - } - } - }); - - sessions.push(SessionState { socket: sock, machine: mach }); - rtt_sum += rtt; - successful_sessions += 1; - } - Err(err) => { - tx.send(UiEvent::Log(format!("Multiplex session {}/{} handshake failed: {}. Continuing with remaining sessions...", idx + 1, session_count, err))).await.ok(); - } - } - } - - if sessions.is_empty() { - *proxy_guard = None; - tx.send(UiEvent::Log("All multiplexed handshake attempts failed. Connection aborted.".to_string())).await.ok(); - tx.send(UiEvent::TunnelStopped).await.ok(); - self.metrics.connection_state.store(0, Ordering::Relaxed); - return true; - } - - *udp_rx_opt = Some(udp_rx); - *sessions_opt = Some(sessions); - self.last_rtt_ms = rtt_sum / successful_sessions as f64; - self.running = true; - self.last_sample_at = Instant::now(); - self.last_valid_recv = Instant::now(); - - let sys_proxy_addr = self.proxy_addr.replace("0.0.0.0:", "127.0.0.1:"); - *proxy_guard = Some(crate::sysproxy::SystemProxyGuard::enable(&sys_proxy_addr)); - - tx.send(UiEvent::Metrics { - status: ConnectionStatus::Established, - rtt_ms: self.last_rtt_ms, - throughput_bps: 0, - }).await.ok(); - self.metrics.connection_state.store(2, Ordering::Relaxed); - let start_msg = if self.mode == "tun" { "TUN tunnel established" } else { "Connection established" }; - tx.send(UiEvent::Log(start_msg.to_string())).await.ok(); - - for session in sessions_opt.as_mut().unwrap().iter_mut() { - let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64; - let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode()); - if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ping_payload)) { - let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp").await; - self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); - } - } - } - } - Some(BridgeCommand::NextProfile) => { - self.profile = next_profile(self.profile); - tx.send(UiEvent::ProfileChanged(self.profile)).await.ok(); - tx.send(UiEvent::Log(format!("Obfuscation profile switched to {:?}", self.profile))).await.ok(); - } - Some(BridgeCommand::NetworkChanged) => { - if self.running { - let _ = tx.send(UiEvent::Log("Network changed — starting immediate reconnect".to_string())).await; - self.metrics.connection_state.store(1, Ordering::Relaxed); - self.last_valid_recv = Instant::now() - Duration::from_secs(100); - - let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 }; - let (udp_tx, udp_rx) = mpsc::channel(1024); - let mut new_sessions = Vec::with_capacity(session_count); - let mut successful_sessions = 0; - let mut rtt_sum = 0.0; - - for idx in 0..session_count { - let session_id: u32 = rand::thread_rng().gen(); - match self.perform_handshake_with_id(&tx, session_id).await { - Ok((sock, mach, rtt)) => { - let session_index = new_sessions.len(); - let socket_clone = sock.clone(); - let udp_tx_clone = udp_tx.clone(); - - tokio::spawn(async move { - let mut buf = vec![0_u8; 65535]; - let is_uot = matches!(socket_clone, crate::transport::Transport::Uot { .. }); - loop { - match socket_clone.recv(&mut buf).await { - Ok(n) => { - let inbound = Bytes::copy_from_slice(&buf[..n]); - if udp_tx_clone.send((session_index, inbound)).await.is_err() { break; } - } - Err(e) => { - if is_uot { - tracing::warn!("UoT network-change session {} disconnected: {}", session_index, e); - break; - } else { - tracing::warn!("UDP recv error (network-change session {}): {}", session_index, e); - tokio::time::sleep(std::time::Duration::from_millis(50)).await; - } - } - } - } - }); - new_sessions.push(SessionState { socket: sock, machine: mach }); - rtt_sum += rtt; - successful_sessions += 1; - } - Err(err) => { - let _ = tx.send(UiEvent::Log(format!("NetworkChanged reconnect session {}/{} failed: {}", idx + 1, session_count, err))).await; - } - } - } - - if !new_sessions.is_empty() { - *sessions_opt = Some(new_sessions); - *udp_rx_opt = Some(udp_rx); - self.last_rtt_ms = rtt_sum / successful_sessions as f64; - self.last_valid_recv = Instant::now(); - stream_map.clear(); - self.reset_proxy_streams(&tx, &proxy_tx, "network changed"); - self.metrics.connection_state.store(2, Ordering::Relaxed); - let _ = tx.send(UiEvent::Log("NetworkChanged reconnect successful!".to_string())).await; - } else { - let _ = tx.send(UiEvent::Log("NetworkChanged reconnect failed — will retry on keepalive tick".to_string())).await; - } - } - } - Some(BridgeCommand::ReloadConfig) => { - match ClientConfig::reload_from_json_near_binary() { - Ok(cfg) => { - let old_server = self.server_addr.clone(); - let old_mode = self.mode.clone(); - let old_transport = self.transport_mode.clone(); - - self.apply_runtime_config(&cfg); - - let requires_restart = self.server_addr != old_server || - self.mode != old_mode || - self.transport_mode != old_transport; - - if !requires_restart { - if let Some(tx_watch) = &self.reload_tx { - let _ = tx_watch.send(cfg.exclusions.clone()); - } - tx.send(UiEvent::Log("Exclusions updated in real-time (hot reload)".to_string())).await.ok(); - } else { - tx.send(UiEvent::Log("Runtime config reloaded. Restarting tunnel due to critical parameter changes.".to_string())).await.ok(); - if self.running { - self.running = false; - self.metrics.connection_state.store(0, Ordering::Relaxed); - *proxy_guard = None; - *sessions_opt = None; - stream_map.clear(); - self.reset_proxy_streams(&tx, &proxy_tx, "config reload"); - let _ = tx.send(UiEvent::TunnelStopped).await; - } - } - } - Err(err) => { - let _ = tx.send(UiEvent::Log(format!("Config reload failed: {err}"))).await; - } - } - } - Some(BridgeCommand::Shutdown) | None => { - self.running = false; - *proxy_guard = None; - return false; - } - } - true - } - - async fn handle_keepalive( - &mut self, - sessions_opt: &mut Option>, - udp_rx_opt: &mut Option>, - proxy_guard: &mut Option, - stream_map: &mut std::collections::HashMap, - tx: &mpsc::Sender, - proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>, - proxy_rx: &mut mpsc::Receiver, - ) { - if self.last_valid_recv.elapsed().as_secs() > 25 { - let elapsed = self.last_valid_recv.elapsed().as_secs(); - if elapsed > 180 { - if self.kill_switch { - let _ = tx.send(UiEvent::Log(format!("Connection stall ({}s). Kill Switch is ON, retrying reconnect indefinitely...", elapsed))).await; - } else { - let _ = tx.send(UiEvent::Log("Connection permanently lost (3-minute hard timeout). Stopping tunnel.".into())).await; - self.running = false; - *proxy_guard = None; - *sessions_opt = None; - stream_map.clear(); - self.reset_proxy_streams(&tx, &proxy_tx, "keepalive hard timeout"); - let _ = tx.send(UiEvent::TunnelStopped).await; - self.metrics.connection_state.store(0, Ordering::Relaxed); - return; - } - } else { - let _ = tx.send(UiEvent::Log(format!("Connection stall detected ({}s silence). Attempting background reconnect...", elapsed))).await; - } - - self.metrics.connection_state.store(1, Ordering::Relaxed); - - let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 }; - let (udp_tx, udp_rx) = mpsc::channel(1024); - let mut new_sessions = Vec::with_capacity(session_count); - let mut successful_sessions = 0; - let mut rtt_sum = 0.0; - - for idx in 0..session_count { - let session_id: u32 = rand::thread_rng().gen(); - match self.perform_handshake_with_id(&tx, session_id).await { - Ok((sock, mach, rtt)) => { - let session_index = new_sessions.len(); - let socket_clone = sock.clone(); - let udp_tx_clone = udp_tx.clone(); - - tokio::spawn(async move { - let mut buf = vec![0_u8; 65535]; - let is_uot = matches!(socket_clone, crate::transport::Transport::Uot { .. }); - loop { - match socket_clone.recv(&mut buf).await { - Ok(n) => { - let inbound = Bytes::copy_from_slice(&buf[..n]); - if udp_tx_clone.send((session_index, inbound)).await.is_err() { - break; - } - } - Err(e) => { - if is_uot { - tracing::warn!("UoT reconnect session {} disconnected: {}", session_index, e); - break; - } else { - tracing::warn!("UDP socket recv error (reconnect session {}): {}", session_index, e); - tokio::time::sleep(std::time::Duration::from_millis(50)).await; - } - } - } - } - }); - - new_sessions.push(SessionState { socket: sock, machine: mach }); - rtt_sum += rtt; - successful_sessions += 1; - } - Err(err) => { - let _ = tx.send(UiEvent::Log(format!("Background reconnect session {}/{} failed: {}", idx + 1, session_count, err))).await; - } - } - } - - if !new_sessions.is_empty() { - *sessions_opt = Some(new_sessions); - *udp_rx_opt = Some(udp_rx); - self.last_rtt_ms = rtt_sum / successful_sessions as f64; - self.last_valid_recv = Instant::now(); - self.metrics.connection_state.store(2, Ordering::Relaxed); - let _ = tx.send(UiEvent::Log("Background reconnect successful! Connection restored.".into())).await; - - for session in sessions_opt.as_mut().unwrap().iter_mut() { - let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64; - let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode()); - if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ping_payload)) { - let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp").await; - self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); - } - } - - stream_map.clear(); - self.reset_proxy_streams(&tx, &proxy_tx, "background reconnect"); - - let mut flushed = 0; - while let Ok(stale) = proxy_rx.try_recv() { - if let ProxyEvent::NewStream { stream_id, .. } = stale { - let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("connection reset".into()))); - } - flushed += 1; - } - if flushed > 0 { - let _ = tx.send(UiEvent::Log(format!("Flushed {} stale proxy messages to prevent UDP burst", flushed))).await; - } - } else { - let _ = tx.send(UiEvent::Log("Background reconnect failed. Will retry on next tick...".into())).await; - } - } - - if let Some(sessions) = sessions_opt.as_mut() { - for session in sessions.iter_mut() { - let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64; - let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode()); - if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ping_payload)) { - let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await; - self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); - } - - let ka_payload = Bytes::from(RelayMessage::KeepAlive.encode()); - if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ka_payload)) { - let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await; - self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); - } - } - } - } - - async fn handle_retransmit( - &mut self, - sessions_opt: &mut Option>, - udp_rx_opt: &mut Option>, - proxy_guard: &mut Option, - stream_map: &mut std::collections::HashMap, - tx: &mpsc::Sender, - proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>, - ) { - let mut fatal_err = None; - if let Some(sessions) = sessions_opt.as_mut() { - for session in sessions.iter_mut() { - match session.machine.on_event(OstpEvent::Tick) { - Ok(action) => { - let mut queue = vec![action]; - while let Some(current_action) = queue.pop() { - match current_action { - ProtocolAction::Multiple(nested) => { - for a in nested { - queue.push(a); - } - } - ProtocolAction::SendDatagram(frame) => { - let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await; - self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); - } - _ => {} - } - } - } - Err(e) => { - fatal_err = Some(e); - break; - } - } - } - } - - if let Some(e) = fatal_err { - let _ = tx.send(UiEvent::Log(format!("Protocol tick fatal error: {e}"))).await; - self.running = false; - *proxy_guard = None; - *sessions_opt = None; - *udp_rx_opt = None; - stream_map.clear(); - self.reset_proxy_streams(&tx, &proxy_tx, "protocol fatal error"); - let _ = tx.send(UiEvent::TunnelStopped).await; - self.metrics.connection_state.store(0, Ordering::Relaxed); - } - } - - async fn handle_proxy_event( - &mut self, - proxy_ev: Option, - sessions_opt: &mut Option>, - stream_map: &mut std::collections::HashMap, - tx: &mpsc::Sender, - proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>, - ) { - if let Some(ev) = proxy_ev { - if let Some(sessions) = sessions_opt.as_mut() { - if sessions.is_empty() { - if let ProxyEvent::NewStream { stream_id, .. } = ev { - let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("tunnel stopped".into()))); - } - return; - } - let (stream_id, relay_msg, is_close) = match ev { - ProxyEvent::NewStream { stream_id, target } => { - let _ = tx.send(UiEvent::Log(format!("Proxy CONNECT stream_id={stream_id} target={target}"))).await; - (stream_id, RelayMessage::Connect(target), false) - } - ProxyEvent::UdpAssociate { stream_id } => { - let _ = tx.send(UiEvent::Log(format!("Proxy UDP ASSOCIATE stream_id={stream_id}"))).await; - (stream_id, RelayMessage::UdpAssociate, false) - } - ProxyEvent::UdpData { stream_id, target, payload } => { - (stream_id, RelayMessage::UdpData(target, payload.to_vec()), false) - } - ProxyEvent::Data { stream_id, payload } => (stream_id, RelayMessage::Data(payload.to_vec()), false), - ProxyEvent::Close { stream_id } => { - let _ = tx.send(UiEvent::Log(format!("Proxy CLOSE stream_id={stream_id}"))).await; - (stream_id, RelayMessage::Close, true) - } - }; - let len = sessions.len(); - let session_index = *stream_map.entry(stream_id).or_insert_with(|| { - rand::thread_rng().gen_range(0..len) - }); - if is_close { - stream_map.remove(&stream_id); - } - let session = &mut sessions[session_index]; - let out_payload = Bytes::from(relay_msg.encode()); - match session.machine.on_event(OstpEvent::Outbound(stream_id, out_payload)) { - Ok(ProtocolAction::SendDatagram(frame)) => { - if send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await.is_ok() { - self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); - tracing::trace!("Outbound datagram sent stream_id={stream_id} bytes={}", frame.len()); - } - } - Ok(ProtocolAction::Multiple(list)) => { - let mut sent = 0usize; - for item in list { - if let ProtocolAction::SendDatagram(frame) = item { - if send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await.is_ok() { - self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); - sent += 1; - } - } - } - tracing::trace!("Outbound datagram batch stream_id={stream_id} sent={sent}"); - } - Ok(ProtocolAction::Noop) => { - tracing::trace!("Outbound datagram noop stream_id={stream_id}"); - } - Ok(_) => { - tracing::trace!("Outbound datagram unexpected action stream_id={stream_id}"); - } - Err(e) => { - tracing::warn!("Protocol error packing outbound stream_id={}: {}", stream_id, e); - let _ = tx.send(UiEvent::Log(format!("Protocol error packing TCP: {e}"))).await; - } - } - } else { - if let ProxyEvent::NewStream { stream_id, .. } = ev { - let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("tunnel stopped".into()))); - } - } - } - } - - - fn reset_proxy_streams( - &self, - tx: &mpsc::Sender, - proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>, - reason: &str, - ) { - if proxy_tx - .send((0, ProxyToClientMsg::Close)) - .is_err() - { - let tx_clone = tx.clone(); - let reason_str = reason.to_string(); - tokio::spawn(async move { - let _ = tx_clone - .send(UiEvent::Log(format!( - "Failed to reset local proxy streams ({reason_str})" - ))) - .await; - }); - } - } - - async fn emit_metrics(&mut self, tx: &mpsc::Sender) { - let now = Instant::now(); - let elapsed = now.duration_since(self.last_sample_at).as_secs_f64().max(0.001); - self.last_sample_at = now; - - let cur_sent = self.metrics.bytes_sent.load(Ordering::Relaxed); - let cur_recv = self.metrics.bytes_recv.load(Ordering::Relaxed); - - let sent_delta = cur_sent.saturating_sub(self.sample_sent); - let recv_delta = cur_recv.saturating_sub(self.sample_recv); - - self.sample_sent = cur_sent; - self.sample_recv = cur_recv; - - let outgoing = (sent_delta as f64 / elapsed) as u64; - let incoming = (recv_delta as f64 / elapsed) as u64; - let throughput = incoming.saturating_add(outgoing); - - tx.send(UiEvent::Traffic { incoming_bps: incoming, outgoing_bps: outgoing }).await.ok(); - - // Dynamically report connection status based on whether we have received server packets recently (last 10 seconds) - let is_healthy = self.last_valid_recv.elapsed() < Duration::from_secs(10); - let status = if is_healthy { - self.metrics.connection_state.store(2, Ordering::Relaxed); - ConnectionStatus::Established - } else { - self.metrics.connection_state.store(1, Ordering::Relaxed); - ConnectionStatus::Handshaking - }; - - tx.send(UiEvent::Metrics { - status, - rtt_ms: self.last_rtt_ms, - throughput_bps: throughput, - }).await.ok(); - } - - async fn perform_handshake_with_id( - &mut self, - tx: &mpsc::Sender, - session_id: u32, - ) -> Result<(crate::transport::Transport, ProtocolMachine, f64)> { - let timestamp = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_secs(); - - let mut handshake_payload = Vec::with_capacity(8 + 4 + self.access_key.len()); - handshake_payload.extend_from_slice(×tamp.to_be_bytes()); - handshake_payload.extend_from_slice(&session_id.to_be_bytes()); - handshake_payload.extend_from_slice(&self.access_key); - - let secrets = ostp_core::crypto::derive_all_secrets(&self.access_key); - - let mut resolved_addrs: Vec = match tokio::net::lookup_host(&self.server_addr).await { - Ok(addrs) => addrs.collect(), - Err(e) => return Err(anyhow::anyhow!("failed to resolve server address {}: {}", self.server_addr, e)), - }; - resolved_addrs.sort_by_key(|addr| if addr.is_ipv6() { 0 } else { 1 }); - - let mut last_err = anyhow::anyhow!("no IP addresses resolved for {}", self.server_addr); - - for target_addr in resolved_addrs { - let target_ip = target_addr.ip(); - let port = target_addr.port(); - - tx.send(UiEvent::Log(format!("Connecting to remote server: {}...", target_addr))).await.ok(); - - let socket = match self.try_connect_transport(target_ip, port).await { - Ok(sock) => sock, - Err(e) => { - if let std::net::IpAddr::V4(ipv4) = target_ip { - tx.send(UiEvent::Log(format!("Direct IPv4 connection failed: {}. Trying NAT64 fallback...", e))).await.ok(); - let nat64_ipv6 = synthesize_nat64(ipv4).await; - match self.try_connect_transport(std::net::IpAddr::V6(nat64_ipv6), port).await { - Ok(sock) => sock, - Err(fallback_err) => { - last_err = anyhow::anyhow!("Direct IPv4 failed: {}. NAT64 fallback failed: {}", e, fallback_err); - continue; - } - } - } else { - last_err = anyhow::anyhow!("Connection to {} failed: {}", target_addr, e); - continue; - } - } - }; - - let mut machine = ProtocolMachine::new(ProtocolConfig { - role: NoiseRole::Initiator, - psk: secrets.psk, - session_id, - handshake_payload: handshake_payload.clone(), - padding_strategy: PaddingStrategy::Profile(self.profile), - obfuscation_key: secrets.obfuscation_key, - max_reorder: 16384, - max_reorder_buffer: 8192, - ack_delay_ms: 5, - rto_ms: 100, - max_retries: 8, - max_sent_history: 32768, - handshake_pad_min: secrets.handshake_pad_min, - handshake_pad_max: secrets.handshake_pad_max, - mtu: self.mtu, - max_padding: self.mtu.saturating_sub(48).max(256), - })?; - - let start = Instant::now(); - let action = match machine.on_event(OstpEvent::Start) { - Ok(a) => a, - Err(e) => { - last_err = anyhow::anyhow!("protocol start error: {}", e); - continue; - } - }; - - let handshake_frame = match action { - ProtocolAction::SendDatagram(frame) => frame, - _ => { - last_err = anyhow::anyhow!("protocol did not emit handshake datagram"); - continue; - } - }; - - let mut buf = vec![0_u8; 4096]; - let mut size = 0; - let mut success = false; - - let is_uot = matches!(socket, crate::transport::Transport::Uot { .. }); - let (attempt_limit, attempt_timeout_ms) = if is_uot { (1, 8000) } else { (4, 1200) }; - - for attempt in 0..attempt_limit { - if attempt > 0 { - tx.send(UiEvent::Log(format!("Handshake attempt {} lost. Retransmitting...", attempt))).await.ok(); - } - if send_datagram(&socket, &handshake_frame, self.transport_mode == "udp").await.is_ok() { - self.metrics.bytes_sent.fetch_add(handshake_frame.len() as u64, Ordering::Relaxed); - } - - match timeout(Duration::from_millis(attempt_timeout_ms), socket.recv(&mut buf)).await { - Ok(Ok(n)) => { - size = n; - success = true; - break; - } - _ => {} - } - } - - let (final_socket, size) = if success { - (socket, size) - } else { - if let std::net::IpAddr::V4(ipv4) = target_ip { - tx.send(UiEvent::Log("Direct IPv4 handshake timed out. Trying NAT64 fallback...".to_string())).await.ok(); - let nat64_ipv6 = synthesize_nat64(ipv4).await; - match self.try_connect_transport(std::net::IpAddr::V6(nat64_ipv6), port).await { - Ok(fallback_socket) => { - let mut fallback_success = false; - for attempt in 0..4 { - if attempt > 0 { - tx.send(UiEvent::Log(format!("NAT64 handshake attempt {} lost. Retransmitting...", attempt))).await.ok(); - } - if send_datagram(&fallback_socket, &handshake_frame, self.transport_mode == "udp").await.is_ok() { - self.metrics.bytes_sent.fetch_add(handshake_frame.len() as u64, Ordering::Relaxed); - } - match timeout(Duration::from_millis(1200), fallback_socket.recv(&mut buf)).await { - Ok(Ok(n)) => { - size = n; - fallback_success = true; - break; - } - _ => {} - } - } - if fallback_success { - tx.send(UiEvent::Log("NAT64 fallback handshake successful!".to_string())).await.ok(); - (fallback_socket, size) - } else { - last_err = anyhow::anyhow!("NAT64 handshake failed after 4 attempts"); - continue; - } - } - Err(e) => { - last_err = anyhow::anyhow!("NAT64 fallback socket creation failed: {}", e); - continue; - } - } - } else { - last_err = anyhow::anyhow!("Direct handshake failed after attempts"); - continue; - } - }; - - let socket = final_socket; - self.metrics.bytes_recv.fetch_add(size as u64, Ordering::Relaxed); - tracing::info!("Handshake response received: {} bytes", size); - - let inbound = Bytes::copy_from_slice(&buf[..size]); - if let Err(e) = machine.on_event(OstpEvent::Inbound(inbound)) { - last_err = anyhow::anyhow!("Protocol invalid response: {}", e); - continue; - } - let rtt_ms = start.elapsed().as_secs_f64() * 1000.0; - tracing::info!("Handshake complete: session={:#010x} rtt={:.1}ms", session_id, rtt_ms); - - return Ok((socket, machine, rtt_ms)); - } - - Err(last_err) - } - - fn apply_runtime_config(&mut self, cfg: &ClientConfig) { - self.server_addr = cfg.ostp.server_addr.clone(); - self.local_bind_addr = cfg.ostp.local_bind_addr.clone(); - self.proxy_addr = cfg.local_proxy.bind_addr.clone(); - self.access_key = Bytes::from(cfg.ostp.access_key.clone()); - self.handshake_timeout_ms = cfg.ostp.handshake_timeout_ms; - self.io_timeout_ms = cfg.ostp.io_timeout_ms; - self.mode = cfg.mode.clone(); // Bug fix: mode was never updated on hot-reload - self.mux_enabled = cfg.multiplex.enabled; - self.mux_sessions = cfg.multiplex.sessions.max(1); - self.transport_mode = cfg.transport.mode.clone(); - self.stealth_sni = cfg.transport.stealth_sni.clone(); - self.wss = cfg.transport.wss; // Fix: wss was not updated on hot-reload - self.mtu = cfg.ostp.mtu; - self.keepalive_interval_sec = cfg.ostp.keepalive_interval_sec; - self.kill_switch = cfg.kill_switch; - } - - async fn try_connect_transport( - &self, - target_ip: std::net::IpAddr, - port: u16, - ) -> Result { - let mode = self.transport_mode.to_lowercase(); - if mode == "uot" || mode == "tcp" { - let stream = tokio::net::TcpStream::connect((target_ip, port)).await?; - let _ = stream.set_nodelay(true); - let (mut read_half, mut write_half) = stream.into_split(); - - let (tx_out, mut rx_out) = tokio::sync::mpsc::channel::(1024); - let (tx_in, rx_in) = tokio::sync::mpsc::channel::(1024); - - // Task to write from rx_out to tcp stream - tokio::spawn(async move { - use tokio::io::AsyncWriteExt; - while let Some(data) = rx_out.recv().await { - let mut len_buf = [0u8; 2]; - len_buf.copy_from_slice(&(data.len() as u16).to_be_bytes()); - if write_half.write_all(&len_buf).await.is_err() { break; } - if write_half.write_all(&data).await.is_err() { break; } - } - }); - - // Task to read from tcp stream to tx_in - let tx_in_clone = tx_in.clone(); - tokio::spawn(async move { - use tokio::io::AsyncReadExt; - loop { - let mut len_buf = [0u8; 2]; - if read_half.read_exact(&mut len_buf).await.is_err() { break; } - let len = u16::from_be_bytes(len_buf) as usize; - let mut data = vec![0u8; len]; - if read_half.read_exact(&mut data).await.is_err() { break; } - if tx_in_clone.send(bytes::Bytes::from(data)).await.is_err() { break; } - } - }); - - Ok(crate::transport::Transport::Uot { tx: tx_out, rx: std::sync::Arc::new(tokio::sync::Mutex::new(rx_in)) }) - } else { - let is_ipv6 = target_ip.is_ipv6(); - let domain = if is_ipv6 { socket2::Domain::IPV6 } else { socket2::Domain::IPV4 }; - let bind_addr = if is_ipv6 { - std::net::SocketAddr::new(std::net::IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED), 0) - } else { - std::net::SocketAddr::new(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), 0) - }; - - let sock = socket2::Socket::new(domain, socket2::Type::DGRAM, Some(socket2::Protocol::UDP))?; - #[cfg(unix)] - { - use std::os::unix::io::AsRawFd; - protect_socket(sock.as_raw_fd()); - } - let _ = sock.set_recv_buffer_size(33554432); // 32MB - let _ = sock.set_send_buffer_size(33554432); // 32MB - let actual_recv = sock.recv_buffer_size().unwrap_or(0); - let actual_send = sock.send_buffer_size().unwrap_or(0); - tracing::info!("UDP socket buffers: recv={}KB send={}KB", actual_recv / 1024, actual_send / 1024); - sock.bind(&bind_addr.into())?; - sock.set_nonblocking(true)?; - let socket = UdpSocket::from_std(sock.into())?; - - let connect_addr = std::net::SocketAddr::new(target_ip, port); - socket.connect(connect_addr).await.with_context(|| format!("failed to connect udp to {}", connect_addr))?; - Ok(crate::transport::Transport::Udp(Arc::new(socket))) - } - } -} - -fn next_profile(current: TrafficProfile) -> TrafficProfile { - match current { - TrafficProfile::JsonRpc => TrafficProfile::HttpsBurst, - TrafficProfile::HttpsBurst => TrafficProfile::VideoStream, - TrafficProfile::VideoStream => TrafficProfile::JsonRpc, - } -} - -async fn synthesize_nat64(ip: std::net::Ipv4Addr) -> std::net::Ipv6Addr { - let mut prefix = [0x00, 0x64, 0xff, 0x9b, 0, 0, 0, 0, 0, 0, 0, 0]; - if let Ok(addrs) = tokio::net::lookup_host("ipv4only.arpa:80").await { - for addr in addrs { - if let std::net::SocketAddr::V6(v6) = addr { - let octets = v6.ip().octets(); - prefix.copy_from_slice(&octets[0..12]); - break; - } - } - } - let octets = ip.octets(); - std::net::Ipv6Addr::new( - ((prefix[0] as u16) << 8) | prefix[1] as u16, - ((prefix[2] as u16) << 8) | prefix[3] as u16, - ((prefix[4] as u16) << 8) | prefix[5] as u16, - ((prefix[6] as u16) << 8) | prefix[7] as u16, - ((prefix[8] as u16) << 8) | prefix[9] as u16, - ((prefix[10] as u16) << 8) | prefix[11] as u16, - ((octets[0] as u16) << 8) | octets[1] as u16, - ((octets[2] as u16) << 8) | octets[3] as u16, - ) -} - - diff --git a/ostp-client/src/bridge.rs.bak b/ostp-client/src/bridge.rs.bak new file mode 100644 index 0000000..72e7013 Binary files /dev/null and b/ostp-client/src/bridge.rs.bak differ diff --git a/ostp-client/src/config.rs b/ostp-client/src/config.rs index 45bf19b..9567a86 100644 --- a/ostp-client/src/config.rs +++ b/ostp-client/src/config.rs @@ -1,87 +1,95 @@ use anyhow::{Context, Result}; use serde::{Deserialize, Serialize}; -/// Client runtime configuration. -/// Constructed by the main binary from the unified `config.json`, -/// then passed into `runner::run_client`. All I/O happens in the -/// binary layer — this crate only owns the plain data structures. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ClientConfig { - pub mode: String, #[serde(default)] - pub debug: bool, - pub ostp: OstpConfig, - pub local_proxy: LocalProxyConfig, + pub log: LogConfig, #[serde(default)] - pub transport: TransportConfig, + pub inbounds: Vec, #[serde(default)] - pub exclusions: ExclusionConfig, + pub outbounds: Vec, #[serde(default)] - pub multiplex: MultiplexConfig, - pub dns_server: Option, - #[serde(default = "default_tun_stack")] - pub tun_stack: String, - #[serde(default)] - pub kill_switch: bool, + pub routing: RoutingConfig, #[serde(default, skip_serializing_if = "Option::is_none")] pub gui: Option, } -fn default_tun_stack() -> String { "system".to_string() } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LogConfig { + #[serde(default = "default_log_level")] + pub level: String, +} -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct ExclusionConfig { - #[serde(default)] - pub domains: Vec, - #[serde(default)] - pub ips: Vec, - #[serde(default)] - pub processes: Vec, +impl Default for LogConfig { + fn default() -> Self { + Self { level: default_log_level() } + } +} + +fn default_log_level() -> String { "info".to_string() } +fn default_true() -> bool { true } +pub fn default_mtu() -> usize { 1140 } + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum InboundConfig { + Tun { + tag: String, + #[serde(default = "default_true")] + auto_route: bool, + #[serde(default = "default_mtu")] + mtu: usize, + }, + LocalProxy { + tag: String, + protocol: String, // "socks" or "http" + listen: String, + port: u16, + }, } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct MultiplexConfig { - pub enabled: bool, - pub sessions: usize, +#[serde(tag = "type", rename_all = "snake_case")] +pub enum OutboundConfig { + Selector { + tag: String, + outbounds: Vec, + default: Option, + }, + Urltest { + tag: String, + outbounds: Vec, + url: Option, + interval: Option, + }, + Ostp { + tag: String, + server: String, + port: u16, + access_key: String, + #[serde(default)] + transport: TransportConfig, + #[serde(default)] + multiplex: MultiplexConfig, + }, + Direct { + tag: String, + }, + Socks { + tag: String, + server: String, + port: u16, + }, + Block { + tag: String, + }, } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct OstpConfig { - pub server_addr: String, - pub local_bind_addr: String, - #[serde(alias = "auth_token")] - pub access_key: String, - pub handshake_timeout_ms: u64, - pub io_timeout_ms: u64, - #[serde(default = "default_mtu")] - pub mtu: usize, - #[serde(default = "default_keepalive")] - pub keepalive_interval_sec: u64, -} - -fn default_keepalive() -> u64 { 5 } - -fn default_mtu() -> usize { 1140 } - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct LocalProxyConfig { - pub bind_addr: String, - pub connect_timeout_ms: u64, -} - -/// Transport layer configuration. -/// `mode` = "udp" (default) or "uot" (UDP over TCP with xHTTP stealth). #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TransportConfig { - /// "udp" or "uot" #[serde(default = "default_transport_mode")] - pub mode: String, - /// TLS SNI and HTTP Host for stealth routing - #[serde(default)] - pub stealth_sni: String, - /// Enable strict RFC 6455 WebSocket framing - #[serde(default)] - pub wss: bool, + pub r#type: String, // "udp" or "uot" } fn default_transport_mode() -> String { "udp".to_string() } @@ -89,58 +97,20 @@ fn default_transport_mode() -> String { "udp".to_string() } impl Default for TransportConfig { fn default() -> Self { Self { - mode: default_transport_mode(), - stealth_sni: String::new(), - wss: false, + r#type: default_transport_mode(), } } } - - - - -impl Default for OstpConfig { - fn default() -> Self { - Self { - server_addr: "127.0.0.1:50000".to_string(), - local_bind_addr: "0.0.0.0:0".to_string(), - access_key: String::new(), - handshake_timeout_ms: 5000, - io_timeout_ms: 2500, - mtu: default_mtu(), - keepalive_interval_sec: default_keepalive(), - } - } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MultiplexConfig { + #[serde(default)] + pub enabled: bool, + #[serde(default = "default_mux_sessions")] + pub sessions: usize, } -impl Default for LocalProxyConfig { - fn default() -> Self { - Self { - bind_addr: "127.0.0.1:1088".to_string(), - connect_timeout_ms: 15000, - } - } -} - - -impl Default for ClientConfig { - fn default() -> Self { - Self { - mode: "proxy".to_string(), - debug: false, - ostp: OstpConfig::default(), - local_proxy: LocalProxyConfig::default(), - transport: TransportConfig::default(), - exclusions: ExclusionConfig::default(), - multiplex: MultiplexConfig::default(), - dns_server: None, - tun_stack: "system".to_string(), - kill_switch: false, - gui: None, - } - } -} +fn default_mux_sessions() -> usize { 1 } impl Default for MultiplexConfig { fn default() -> Self { @@ -151,57 +121,30 @@ impl Default for MultiplexConfig { } } -/// Unified shape of `config.json` as seen by the client. -/// Used only for hot-reloading (`BridgeCommand::ReloadConfig`). -#[derive(Debug, Deserialize)] -struct RawUnifiedConfig { - #[allow(dead_code)] - mode: String, - debug: Option, - server: Option, - access_key: Option, - mtu: Option, - socks5_bind: Option, - tun: Option, - exclude: Option, - mux: Option, - transport: Option, - gui: Option, +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct RoutingConfig { + #[serde(default)] + pub rules: Vec, + #[serde(default)] + pub default_outbound: String, } -#[derive(Debug, Deserialize)] -struct RawTransportSection { - mode: Option, - stealth_sni: Option, - wss: Option, +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RoutingRule { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub domain_suffix: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub ip_cidr: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub process_name: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub inbound_tag: Option>, + pub outbound: String, } -#[derive(Debug, Deserialize)] -struct RawTunSection { - enable: Option, - dns: Option, - stack: Option, - kill_switch: Option, -} - -#[derive(Debug, Deserialize)] -struct RawExcludeSection { - domains: Option>, - ips: Option>, - processes: Option>, -} - -#[derive(Debug, Deserialize)] -struct RawMuxSection { - enabled: Option, - sessions: Option, -} - - - impl ClientConfig { /// Hot-reload from `config.json` placed next to the running binary. - /// Returns a new `ClientConfig` built from the unified JSON format. + /// Returns a new `ClientConfig` built from the JSON format. pub fn reload_from_json_near_binary() -> Result { let exe = std::env::current_exe().context("cannot resolve binary path")?; let dir = exe.parent().context("cannot resolve binary directory")?; @@ -210,58 +153,9 @@ impl ClientConfig { let raw = std::fs::read_to_string(&path) .with_context(|| format!("failed to read {}", path.display()))?; let mut stripped = json_comments::StripComments::new(raw.as_bytes()); - let raw: RawUnifiedConfig = serde_json::from_reader(&mut stripped) + let config: ClientConfig = serde_json::from_reader(&mut stripped) .with_context(|| format!("failed to parse {}", path.display()))?; - let is_tun = raw.tun.as_ref().and_then(|t| t.enable).unwrap_or(false); - let server = raw.server.unwrap_or_else(|| "127.0.0.1:50000".to_string()); - let key = raw.access_key.unwrap_or_default(); - let mtu = raw.mtu.unwrap_or(default_mtu()); - let socks5 = raw.socks5_bind.unwrap_or_else(|| "127.0.0.1:1088".to_string()); - let exclusions = raw.exclude.unwrap_or(RawExcludeSection { - domains: None, - ips: None, - processes: None, - }); - let mux = raw.mux.unwrap_or(RawMuxSection { - enabled: None, - sessions: None, - }); - - Ok(ClientConfig { - mode: if is_tun { "tun".to_string() } else { "proxy".to_string() }, - debug: raw.debug.unwrap_or(false), - ostp: OstpConfig { - server_addr: server, - local_bind_addr: "0.0.0.0:0".to_string(), - access_key: key, - handshake_timeout_ms: 5000, - io_timeout_ms: 2500, - mtu, - keepalive_interval_sec: default_keepalive(), - }, - local_proxy: LocalProxyConfig { - bind_addr: socks5, - connect_timeout_ms: 15000, - }, - transport: TransportConfig { - mode: raw.transport.as_ref().and_then(|t| t.mode.clone()).unwrap_or_else(default_transport_mode), - stealth_sni: raw.transport.as_ref().and_then(|t| t.stealth_sni.clone()).unwrap_or_default(), - wss: raw.transport.as_ref().and_then(|t| t.wss).unwrap_or(false), - }, - exclusions: ExclusionConfig { - domains: exclusions.domains.unwrap_or_default(), - ips: exclusions.ips.unwrap_or_default(), - processes: exclusions.processes.unwrap_or_default(), - }, - multiplex: MultiplexConfig { - enabled: mux.enabled.unwrap_or(false), - sessions: mux.sessions.unwrap_or(1), - }, - dns_server: raw.tun.as_ref().and_then(|t| t.dns.clone()), - tun_stack: raw.tun.as_ref().and_then(|t| t.stack.clone()).unwrap_or_else(|| "system".to_string()), - kill_switch: raw.tun.as_ref().and_then(|t| t.kill_switch).unwrap_or(false), - gui: raw.gui, - }) + Ok(config) } } diff --git a/ostp-client/src/runner.rs b/ostp-client/src/runner.rs index 44a2055..fda5c28 100644 --- a/ostp-client/src/runner.rs +++ b/ostp-client/src/runner.rs @@ -1,436 +1,74 @@ use anyhow::Result; +use std::sync::Arc; use tokio::sync::{mpsc, watch}; -use crate::app::BridgeCommand; -use crate::bridge::{Bridge, BridgeMetrics}; -use crate::signal::wait_for_shutdown_signal; -use crate::tunnel; -use std::sync::Arc; -use std::fs::OpenOptions; -use std::io::Write as _; - -fn log_to_core_file(msg: &str) { - let path = std::env::current_exe() - .ok() - .and_then(|p| p.parent().map(|d| d.join("ostp-core.log"))) - .unwrap_or_else(|| std::path::PathBuf::from("ostp-core.log")); - if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) { - let _ = writeln!(file, "[{}] {}", chrono::Local::now().format("%Y-%m-%d %H:%M:%S"), msg); - } -} - -#[cfg(target_os = "windows")] -#[link(name = "kernel32")] -extern "system" { - fn FreeConsole() -> i32; - fn GetConsoleWindow() -> *mut std::ffi::c_void; -} - -#[cfg(target_os = "windows")] -#[link(name = "user32")] -extern "system" { - fn ShowWindow(hwnd: *mut std::ffi::c_void, cmd_show: i32) -> i32; -} - -fn hide_console() { - #[cfg(target_os = "windows")] - unsafe { - let hwnd = GetConsoleWindow(); - if !hwnd.is_null() { - ShowWindow(hwnd, 0); // SW_HIDE = 0 - } - FreeConsole(); - } -} - -#[cfg(target_os = "windows")] -pub fn is_admin() -> bool { - std::process::Command::new("net") - .arg("session") - .stdout(std::process::Stdio::null()) - .stderr(std::process::Stdio::null()) - .status() - .map(|s| s.success()) - .unwrap_or(false) -} - -#[cfg(target_os = "windows")] -fn relaunch_as_admin() -> Result<()> { - use std::ffi::OsStr; - use std::os::windows::ffi::OsStrExt; - use std::ptr::null_mut; - - let exe = std::env::current_exe()?; - let exe_wstr: Vec = exe.as_os_str().encode_wide().chain(Some(0)).collect(); - - let mut args_joined = String::new(); - for arg in std::env::args().skip(1) { - if !args_joined.is_empty() { - args_joined.push(' '); - } - args_joined.push('"'); - args_joined.push_str(&arg.replace('"', "\\\"")); - args_joined.push('"'); - } - let args_wstr: Vec = OsStr::new(&args_joined).encode_wide().chain(Some(0)).collect(); - - let dir = std::env::current_dir()?; - let dir_wstr: Vec = dir.as_os_str().encode_wide().chain(Some(0)).collect(); - - let verb_wstr: Vec = OsStr::new("runas").encode_wide().chain(Some(0)).collect(); - - #[link(name = "shell32")] - extern "system" { - fn ShellExecuteW( - hwnd: *mut std::ffi::c_void, - lpOperation: *const u16, - lpFile: *const u16, - lpParameters: *const u16, - lpDirectory: *const u16, - nShowCmd: i32, - ) -> isize; - } - - unsafe { - let ret = ShellExecuteW( - null_mut(), - verb_wstr.as_ptr(), - exe_wstr.as_ptr(), - args_wstr.as_ptr(), - dir_wstr.as_ptr(), - 1, // SW_SHOWNORMAL = 1 - ); - if ret <= 32 { - return Err(anyhow::anyhow!( - "Windows UAC Elevation failed or was denied by policy (ShellExecuteW code: {})", - ret - )); - } - } - - std::process::exit(0); -} - -#[cfg(target_os = "linux")] -pub fn is_root() -> bool { - unsafe { libc::geteuid() == 0 } -} - -#[cfg(target_os = "linux")] -fn relaunch_as_root() -> Result<()> { - use std::io::IsTerminal; - let exe = std::env::current_exe()?; - let args: Vec = std::env::args().skip(1).collect(); - - let is_gui = std::env::var("DISPLAY").is_ok() || std::env::var("WAYLAND_DISPLAY").is_ok(); - let is_term = std::io::stdout().is_terminal(); - - let mut cmd = if is_gui && !is_term { - let mut c = std::process::Command::new("pkexec"); - c.arg(exe); - c - } else { - let mut c = std::process::Command::new("sudo"); - c.arg(exe); - c - }; - - cmd.args(&args); - - let status = cmd.status().map_err(|e| anyhow::anyhow!("Failed to execute privilege escalation command: {}", e))?; - - if !status.success() { - return Err(anyhow::anyhow!("Privilege escalation failed or was denied.")); - } - - std::process::exit(0); -} - -pub async fn run_client(config: crate::config::ClientConfig) -> Result<()> { - #[cfg(target_os = "windows")] - if config.mode == "tun" && !is_admin() { - println!("[ostp] TUN mode requires administrator privileges. Relaunching..."); - relaunch_as_admin()?; - } - - #[cfg(target_os = "linux")] - if config.mode == "tun" && !is_root() { - println!("[ostp] TUN mode requires root privileges. Requesting sudo/pkexec elevation..."); - relaunch_as_root()?; - } - - let bg = std::env::args().any(|a| a == "--bg"); - - if bg { - hide_console(); - } - - let metrics = Arc::new(BridgeMetrics { - bytes_sent: portable_atomic::AtomicU64::new(0), - bytes_recv: portable_atomic::AtomicU64::new(0), - connection_state: portable_atomic::AtomicU8::new(0), - rtt_ms: portable_atomic::AtomicU32::new(0), - }); - - let (shutdown_tx, shutdown_rx) = watch::channel(false); - - tokio::spawn(async move { - if wait_for_shutdown_signal().await.is_ok() { - let _ = shutdown_tx.send(true); - } - }); - - run_client_core(config, metrics, shutdown_rx, None).await -} +use crate::app::{BridgeCommand, ConnectionStatus, UiEvent}; +use crate::config::{ClientConfig, InboundConfig}; +use crate::tunnel::balancer::Balancer; +use crate::tunnel::outbounds::OutboundManager; +use crate::tunnel::router::Router; pub async fn run_client_core( - mut config: crate::config::ClientConfig, - metrics: Arc, + config: ClientConfig, + metrics: Arc, mut shutdown_rx_ext: watch::Receiver, - mut config_rx: Option>, + config_rx: Option>, ) -> Result<()> { - #[cfg(target_os = "windows")] - if config.mode == "tun" && !is_admin() { - return Err(anyhow::anyhow!("Administrator privileges are required to initialize TUN mode. Please run the application as Administrator.")); - } + println!("[ostp] Starting run_client_core with multi-server architecture"); - #[cfg(target_os = "linux")] - if config.mode == "tun" && !is_root() { - return Err(anyhow::anyhow!("Root privileges are required to initialize TUN mode on Linux. Please run with sudo.")); - } - - log_to_core_file(&format!("[core] Starting run_client_core in mode: {}", config.mode)); - - // Resolve the server IP before we override system routing and DNS. - // This prevents DNS deadlock if the VPN disconnects and tries to reconnect, - // and also ensures we add the direct route to the exact IP the bridge connects to. - #[allow(unused_mut)] - let mut resolved_addrs: Vec = tokio::net::lookup_host(&config.ostp.server_addr) - .await - .map_err(|e| anyhow::anyhow!("Failed to resolve server address {}: {}", config.ostp.server_addr, e))? - .collect(); - - - let target_addr = resolved_addrs.first() - .ok_or_else(|| anyhow::anyhow!("No IP addresses resolved for {}", config.ostp.server_addr))?; - - log_to_core_file(&format!("[core] Resolved server address to {}", target_addr)); - config.ostp.server_addr = target_addr.to_string(); - - - #[cfg(target_os = "linux")] - if config.mode == "tun" { - println!("\n[ostp] ==========================================================================="); - println!("[ostp] WARNING: You are starting TUN mode on a Linux system."); - println!("[ostp] If this is a remote headless server, routing all traffic through the TUN"); - println!("[ostp] interface WILL DROP your SSH connection and lock you out!"); - println!("[ostp] "); - println!("[ostp] SOLUTION: Add a static route for your client IP to bypass the TUN."); - println!("[ostp] Find your default gateway (ip route | grep default) and run:"); - println!("[ostp] sudo ip route add via "); - println!("[ostp] ===========================================================================\n"); - } - - #[cfg(target_os = "linux")] - if config.mode == "proxy" { - println!("\n[ostp] ==========================================================================="); - println!("[ostp] Proxy mode initialized on {}", config.local_proxy.bind_addr); - println!("[ostp] ===========================================================================\n"); - } - - let _sysproxy_guard = if config.mode == "proxy" { - // Enable system proxy and set initial ProxyOverride with user exclusions - let guard = Some(crate::sysproxy::SystemProxyGuard::enable(&config.local_proxy.bind_addr)); - crate::sysproxy::update_proxy_bypass_list( - &config.exclusions.domains, - &config.exclusions.ips, - ); - guard - } else { - None - }; - - if config.mode == "tun" && !config.exclusions.processes.is_empty() { - println!("[ostp] Process exclusions are not supported in TUN mode"); - } - - let (proxy_events_tx, proxy_events_rx) = mpsc::channel(256); - let (client_msgs_tx, client_msgs_rx) = mpsc::unbounded_channel(); + let router = Arc::new(Router::new(config.routing.clone())); + let balancer = Arc::new(Balancer::new(&config)); - // Setup exclusions hot-reload channel - let (reload_tx, reload_rx) = watch::channel(config.exclusions.clone()); + // TODO: Detect physical interface index for bypassing + let phys_if_for_bypass = None; + let outbound_manager = Arc::new(OutboundManager::new(balancer.clone(), phys_if_for_bypass, None)); - let mut bridge = Bridge::new(&config, metrics)?; - bridge.reload_tx = Some(reload_tx.clone()); + let mut handles = Vec::new(); - let (ui_tx, mut ui_rx) = mpsc::channel(512); - let (cmd_tx, cmd_rx) = mpsc::channel(128); - let (shutdown_tx, shutdown_rx) = watch::channel(false); - let proxy_shutdown_rx = shutdown_tx.subscribe(); + for inbound in config.inbounds.clone() { + let router_clone = router.clone(); + let outbound_manager_clone = outbound_manager.clone(); + let shutdown_rx = shutdown_rx_ext.clone(); + let config_clone = config.clone(); - - // Auto-connect on startup - let _ = cmd_tx.send(BridgeCommand::ToggleTunnel).await; - - let debug_enabled = config.debug; - - // Headless event logger - let cmd_tx_clone = cmd_tx.clone(); - tokio::spawn(async move { - let mut last_status = None; - while let Some(msg) = ui_rx.recv().await { - match msg { - crate::app::UiEvent::Log(text) => { - if debug_enabled || is_essential_log(&text) { - log_to_core_file(&format!("[ostp] {text}")); - println!("[ostp] {text}"); + match inbound.clone() { + InboundConfig::Tun { .. } => { + handles.push(tokio::spawn(async move { + if let Err(e) = crate::tunnel::inbounds::tun::run_tun_inbound( + config_clone, + inbound, + router_clone, + outbound_manager_clone, + shutdown_rx, + ).await { + tracing::error!("TUN inbound failed: {}", e); } - } - crate::app::UiEvent::Metrics { status, rtt_ms, .. } => { - let status_str = status.as_str().to_string(); - if last_status != Some(status_str.clone()) { - last_status = Some(status_str.clone()); - println!("[ostp] Status: {} (rtt={:.1}ms)", status_str, rtt_ms); + })); + } + InboundConfig::LocalProxy { .. } => { + handles.push(tokio::spawn(async move { + if let Err(e) = crate::tunnel::inbounds::local_proxy::run_socks_inbound( + config_clone, + inbound, + router_clone, + outbound_manager_clone, + shutdown_rx, + ).await { + tracing::error!("SOCKS inbound failed: {}", e); } - } - crate::app::UiEvent::Traffic { .. } => {} - crate::app::UiEvent::ProfileChanged(profile) => { - if debug_enabled { - println!("[ostp] Obfuscation profile: {profile:?}"); - } - } - crate::app::UiEvent::TunnelStopped => { - println!("[ostp] Connection interrupted. Reconnecting in 5 seconds..."); - let cmd_tx_inner = cmd_tx_clone.clone(); - tokio::spawn(async move { - tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; - let _ = cmd_tx_inner.send(BridgeCommand::ToggleTunnel).await; - }); - } + })); } } - }); + } - let mut bridge_task = tokio::spawn(async move { - bridge.run(ui_tx, cmd_rx, shutdown_rx, proxy_events_rx, client_msgs_tx).await - }); - - let config_clone = config.clone(); - let proxy_exclusions_rx = reload_rx.clone(); - let mut proxy_task = tokio::spawn(async move { - tunnel::run_local_proxy( - config.local_proxy, - config.ostp, - proxy_exclusions_rx, - config.debug, - proxy_shutdown_rx, - proxy_events_tx, - client_msgs_rx, - ) - .await - }); - - let wintun_shutdown_rx = shutdown_tx.subscribe(); - let wintun_exclusions_rx = reload_rx.clone(); - let mut wintun_task = if config_clone.mode == "tun" { - Some(tokio::spawn(async move { - tunnel::run_tun_tunnel(config_clone, wintun_shutdown_rx, wintun_exclusions_rx).await - })) - } else { - None - }; - - // Wait for local_shutdown - let mut local_shutdown = shutdown_rx_ext.clone(); - let cmd_tx_loop = cmd_tx.clone(); - tokio::spawn(async move { - loop { - tokio::select! { - _ = local_shutdown.changed() => { - if *local_shutdown.borrow() { - let _ = cmd_tx_loop.send(BridgeCommand::Shutdown).await; - break; - } - } - Some(Ok(_)) = async { - if let Some(ref mut rx) = config_rx { - Some(rx.changed().await) - } else { - std::future::pending().await - } - } => { - if let Some(ref rx) = config_rx { - let new_cfg = rx.borrow().clone(); - // Update Windows ProxyOverride so excluded domains/IPs - // bypass the system proxy immediately (proxy mode only). - crate::sysproxy::update_proxy_bypass_list( - &new_cfg.exclusions.domains, - &new_cfg.exclusions.ips, - ); - let _ = reload_tx.send(new_cfg.exclusions); - } - } - } - } - }); - - // Wait for either external shutdown OR any task to fail + // Wait for shutdown or for tasks to fail tokio::select! { _ = shutdown_rx_ext.changed() => { - let _ = cmd_tx.send(BridgeCommand::Shutdown).await; - let _ = shutdown_tx.send(true); + if *shutdown_rx_ext.borrow() { + tracing::info!("Shutdown signal received in run_client_core"); + } } - res = &mut bridge_task => { - let _ = shutdown_tx.send(true); - res.map_err(|e| anyhow::anyhow!("Bridge task panicked: {}", e))??; - } - res = &mut proxy_task => { - let _ = shutdown_tx.send(true); - res.map_err(|e| anyhow::anyhow!("Proxy task panicked: {}", e))??; - } - res = async { - if let Some(t) = wintun_task.as_mut() { t.await } else { std::future::pending().await } - } => { - let _ = shutdown_tx.send(true); - res.map_err(|e| anyhow::anyhow!("TUN task panicked: {}", e))??; - } - } - - // Final cleanup: wait for tasks to finish - let _ = bridge_task.await; - let _ = proxy_task.await; - if let Some(task) = wintun_task { - let _ = task.await; } Ok(()) } - -#[allow(dead_code)] -fn format_bytes(bps: u64) -> String { - if bps >= 1_000_000 { - format!("{:.1}MB", bps as f64 / 1_000_000.0) - } else if bps >= 1_000 { - format!("{:.1}KB", bps as f64 / 1_000.0) - } else { - format!("{bps}B") - } -} - -fn is_essential_log(text: &str) -> bool { - matches!( - text, - "Connection established" - | "TUN tunnel established" - | "TUN tunnel stopped" - | "Bridge stopped" - | "Runtime config reloaded" - | "Connecting to remote server..." - ) || text.starts_with("Connected to ") - || text.starts_with("TURN relay allocated") - || text.starts_with("TURN allocation failed") - || text.starts_with("Allocating TURN relay") - || text.starts_with("Connection failed:") - || text.starts_with("Connection lost") - || text.starts_with("Protocol tick fatal error") -} diff --git a/ostp-client/src/runner.rs.bak b/ostp-client/src/runner.rs.bak new file mode 100644 index 0000000..44a2055 --- /dev/null +++ b/ostp-client/src/runner.rs.bak @@ -0,0 +1,436 @@ +use anyhow::Result; +use tokio::sync::{mpsc, watch}; + +use crate::app::BridgeCommand; +use crate::bridge::{Bridge, BridgeMetrics}; +use crate::signal::wait_for_shutdown_signal; +use crate::tunnel; +use std::sync::Arc; +use std::fs::OpenOptions; +use std::io::Write as _; + +fn log_to_core_file(msg: &str) { + let path = std::env::current_exe() + .ok() + .and_then(|p| p.parent().map(|d| d.join("ostp-core.log"))) + .unwrap_or_else(|| std::path::PathBuf::from("ostp-core.log")); + if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) { + let _ = writeln!(file, "[{}] {}", chrono::Local::now().format("%Y-%m-%d %H:%M:%S"), msg); + } +} + +#[cfg(target_os = "windows")] +#[link(name = "kernel32")] +extern "system" { + fn FreeConsole() -> i32; + fn GetConsoleWindow() -> *mut std::ffi::c_void; +} + +#[cfg(target_os = "windows")] +#[link(name = "user32")] +extern "system" { + fn ShowWindow(hwnd: *mut std::ffi::c_void, cmd_show: i32) -> i32; +} + +fn hide_console() { + #[cfg(target_os = "windows")] + unsafe { + let hwnd = GetConsoleWindow(); + if !hwnd.is_null() { + ShowWindow(hwnd, 0); // SW_HIDE = 0 + } + FreeConsole(); + } +} + +#[cfg(target_os = "windows")] +pub fn is_admin() -> bool { + std::process::Command::new("net") + .arg("session") + .stdout(std::process::Stdio::null()) + .stderr(std::process::Stdio::null()) + .status() + .map(|s| s.success()) + .unwrap_or(false) +} + +#[cfg(target_os = "windows")] +fn relaunch_as_admin() -> Result<()> { + use std::ffi::OsStr; + use std::os::windows::ffi::OsStrExt; + use std::ptr::null_mut; + + let exe = std::env::current_exe()?; + let exe_wstr: Vec = exe.as_os_str().encode_wide().chain(Some(0)).collect(); + + let mut args_joined = String::new(); + for arg in std::env::args().skip(1) { + if !args_joined.is_empty() { + args_joined.push(' '); + } + args_joined.push('"'); + args_joined.push_str(&arg.replace('"', "\\\"")); + args_joined.push('"'); + } + let args_wstr: Vec = OsStr::new(&args_joined).encode_wide().chain(Some(0)).collect(); + + let dir = std::env::current_dir()?; + let dir_wstr: Vec = dir.as_os_str().encode_wide().chain(Some(0)).collect(); + + let verb_wstr: Vec = OsStr::new("runas").encode_wide().chain(Some(0)).collect(); + + #[link(name = "shell32")] + extern "system" { + fn ShellExecuteW( + hwnd: *mut std::ffi::c_void, + lpOperation: *const u16, + lpFile: *const u16, + lpParameters: *const u16, + lpDirectory: *const u16, + nShowCmd: i32, + ) -> isize; + } + + unsafe { + let ret = ShellExecuteW( + null_mut(), + verb_wstr.as_ptr(), + exe_wstr.as_ptr(), + args_wstr.as_ptr(), + dir_wstr.as_ptr(), + 1, // SW_SHOWNORMAL = 1 + ); + if ret <= 32 { + return Err(anyhow::anyhow!( + "Windows UAC Elevation failed or was denied by policy (ShellExecuteW code: {})", + ret + )); + } + } + + std::process::exit(0); +} + +#[cfg(target_os = "linux")] +pub fn is_root() -> bool { + unsafe { libc::geteuid() == 0 } +} + +#[cfg(target_os = "linux")] +fn relaunch_as_root() -> Result<()> { + use std::io::IsTerminal; + let exe = std::env::current_exe()?; + let args: Vec = std::env::args().skip(1).collect(); + + let is_gui = std::env::var("DISPLAY").is_ok() || std::env::var("WAYLAND_DISPLAY").is_ok(); + let is_term = std::io::stdout().is_terminal(); + + let mut cmd = if is_gui && !is_term { + let mut c = std::process::Command::new("pkexec"); + c.arg(exe); + c + } else { + let mut c = std::process::Command::new("sudo"); + c.arg(exe); + c + }; + + cmd.args(&args); + + let status = cmd.status().map_err(|e| anyhow::anyhow!("Failed to execute privilege escalation command: {}", e))?; + + if !status.success() { + return Err(anyhow::anyhow!("Privilege escalation failed or was denied.")); + } + + std::process::exit(0); +} + +pub async fn run_client(config: crate::config::ClientConfig) -> Result<()> { + #[cfg(target_os = "windows")] + if config.mode == "tun" && !is_admin() { + println!("[ostp] TUN mode requires administrator privileges. Relaunching..."); + relaunch_as_admin()?; + } + + #[cfg(target_os = "linux")] + if config.mode == "tun" && !is_root() { + println!("[ostp] TUN mode requires root privileges. Requesting sudo/pkexec elevation..."); + relaunch_as_root()?; + } + + let bg = std::env::args().any(|a| a == "--bg"); + + if bg { + hide_console(); + } + + let metrics = Arc::new(BridgeMetrics { + bytes_sent: portable_atomic::AtomicU64::new(0), + bytes_recv: portable_atomic::AtomicU64::new(0), + connection_state: portable_atomic::AtomicU8::new(0), + rtt_ms: portable_atomic::AtomicU32::new(0), + }); + + let (shutdown_tx, shutdown_rx) = watch::channel(false); + + tokio::spawn(async move { + if wait_for_shutdown_signal().await.is_ok() { + let _ = shutdown_tx.send(true); + } + }); + + run_client_core(config, metrics, shutdown_rx, None).await +} + +pub async fn run_client_core( + mut config: crate::config::ClientConfig, + metrics: Arc, + mut shutdown_rx_ext: watch::Receiver, + mut config_rx: Option>, +) -> Result<()> { + #[cfg(target_os = "windows")] + if config.mode == "tun" && !is_admin() { + return Err(anyhow::anyhow!("Administrator privileges are required to initialize TUN mode. Please run the application as Administrator.")); + } + + #[cfg(target_os = "linux")] + if config.mode == "tun" && !is_root() { + return Err(anyhow::anyhow!("Root privileges are required to initialize TUN mode on Linux. Please run with sudo.")); + } + + log_to_core_file(&format!("[core] Starting run_client_core in mode: {}", config.mode)); + + // Resolve the server IP before we override system routing and DNS. + // This prevents DNS deadlock if the VPN disconnects and tries to reconnect, + // and also ensures we add the direct route to the exact IP the bridge connects to. + #[allow(unused_mut)] + let mut resolved_addrs: Vec = tokio::net::lookup_host(&config.ostp.server_addr) + .await + .map_err(|e| anyhow::anyhow!("Failed to resolve server address {}: {}", config.ostp.server_addr, e))? + .collect(); + + + let target_addr = resolved_addrs.first() + .ok_or_else(|| anyhow::anyhow!("No IP addresses resolved for {}", config.ostp.server_addr))?; + + log_to_core_file(&format!("[core] Resolved server address to {}", target_addr)); + config.ostp.server_addr = target_addr.to_string(); + + + #[cfg(target_os = "linux")] + if config.mode == "tun" { + println!("\n[ostp] ==========================================================================="); + println!("[ostp] WARNING: You are starting TUN mode on a Linux system."); + println!("[ostp] If this is a remote headless server, routing all traffic through the TUN"); + println!("[ostp] interface WILL DROP your SSH connection and lock you out!"); + println!("[ostp] "); + println!("[ostp] SOLUTION: Add a static route for your client IP to bypass the TUN."); + println!("[ostp] Find your default gateway (ip route | grep default) and run:"); + println!("[ostp] sudo ip route add via "); + println!("[ostp] ===========================================================================\n"); + } + + #[cfg(target_os = "linux")] + if config.mode == "proxy" { + println!("\n[ostp] ==========================================================================="); + println!("[ostp] Proxy mode initialized on {}", config.local_proxy.bind_addr); + println!("[ostp] ===========================================================================\n"); + } + + let _sysproxy_guard = if config.mode == "proxy" { + // Enable system proxy and set initial ProxyOverride with user exclusions + let guard = Some(crate::sysproxy::SystemProxyGuard::enable(&config.local_proxy.bind_addr)); + crate::sysproxy::update_proxy_bypass_list( + &config.exclusions.domains, + &config.exclusions.ips, + ); + guard + } else { + None + }; + + if config.mode == "tun" && !config.exclusions.processes.is_empty() { + println!("[ostp] Process exclusions are not supported in TUN mode"); + } + + let (proxy_events_tx, proxy_events_rx) = mpsc::channel(256); + let (client_msgs_tx, client_msgs_rx) = mpsc::unbounded_channel(); + + // Setup exclusions hot-reload channel + let (reload_tx, reload_rx) = watch::channel(config.exclusions.clone()); + + let mut bridge = Bridge::new(&config, metrics)?; + bridge.reload_tx = Some(reload_tx.clone()); + + let (ui_tx, mut ui_rx) = mpsc::channel(512); + let (cmd_tx, cmd_rx) = mpsc::channel(128); + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let proxy_shutdown_rx = shutdown_tx.subscribe(); + + + // Auto-connect on startup + let _ = cmd_tx.send(BridgeCommand::ToggleTunnel).await; + + let debug_enabled = config.debug; + + // Headless event logger + let cmd_tx_clone = cmd_tx.clone(); + tokio::spawn(async move { + let mut last_status = None; + while let Some(msg) = ui_rx.recv().await { + match msg { + crate::app::UiEvent::Log(text) => { + if debug_enabled || is_essential_log(&text) { + log_to_core_file(&format!("[ostp] {text}")); + println!("[ostp] {text}"); + } + } + crate::app::UiEvent::Metrics { status, rtt_ms, .. } => { + let status_str = status.as_str().to_string(); + if last_status != Some(status_str.clone()) { + last_status = Some(status_str.clone()); + println!("[ostp] Status: {} (rtt={:.1}ms)", status_str, rtt_ms); + } + } + crate::app::UiEvent::Traffic { .. } => {} + crate::app::UiEvent::ProfileChanged(profile) => { + if debug_enabled { + println!("[ostp] Obfuscation profile: {profile:?}"); + } + } + crate::app::UiEvent::TunnelStopped => { + println!("[ostp] Connection interrupted. Reconnecting in 5 seconds..."); + let cmd_tx_inner = cmd_tx_clone.clone(); + tokio::spawn(async move { + tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; + let _ = cmd_tx_inner.send(BridgeCommand::ToggleTunnel).await; + }); + } + } + } + }); + + let mut bridge_task = tokio::spawn(async move { + bridge.run(ui_tx, cmd_rx, shutdown_rx, proxy_events_rx, client_msgs_tx).await + }); + + let config_clone = config.clone(); + let proxy_exclusions_rx = reload_rx.clone(); + let mut proxy_task = tokio::spawn(async move { + tunnel::run_local_proxy( + config.local_proxy, + config.ostp, + proxy_exclusions_rx, + config.debug, + proxy_shutdown_rx, + proxy_events_tx, + client_msgs_rx, + ) + .await + }); + + let wintun_shutdown_rx = shutdown_tx.subscribe(); + let wintun_exclusions_rx = reload_rx.clone(); + let mut wintun_task = if config_clone.mode == "tun" { + Some(tokio::spawn(async move { + tunnel::run_tun_tunnel(config_clone, wintun_shutdown_rx, wintun_exclusions_rx).await + })) + } else { + None + }; + + // Wait for local_shutdown + let mut local_shutdown = shutdown_rx_ext.clone(); + let cmd_tx_loop = cmd_tx.clone(); + tokio::spawn(async move { + loop { + tokio::select! { + _ = local_shutdown.changed() => { + if *local_shutdown.borrow() { + let _ = cmd_tx_loop.send(BridgeCommand::Shutdown).await; + break; + } + } + Some(Ok(_)) = async { + if let Some(ref mut rx) = config_rx { + Some(rx.changed().await) + } else { + std::future::pending().await + } + } => { + if let Some(ref rx) = config_rx { + let new_cfg = rx.borrow().clone(); + // Update Windows ProxyOverride so excluded domains/IPs + // bypass the system proxy immediately (proxy mode only). + crate::sysproxy::update_proxy_bypass_list( + &new_cfg.exclusions.domains, + &new_cfg.exclusions.ips, + ); + let _ = reload_tx.send(new_cfg.exclusions); + } + } + } + } + }); + + // Wait for either external shutdown OR any task to fail + tokio::select! { + _ = shutdown_rx_ext.changed() => { + let _ = cmd_tx.send(BridgeCommand::Shutdown).await; + let _ = shutdown_tx.send(true); + } + res = &mut bridge_task => { + let _ = shutdown_tx.send(true); + res.map_err(|e| anyhow::anyhow!("Bridge task panicked: {}", e))??; + } + res = &mut proxy_task => { + let _ = shutdown_tx.send(true); + res.map_err(|e| anyhow::anyhow!("Proxy task panicked: {}", e))??; + } + res = async { + if let Some(t) = wintun_task.as_mut() { t.await } else { std::future::pending().await } + } => { + let _ = shutdown_tx.send(true); + res.map_err(|e| anyhow::anyhow!("TUN task panicked: {}", e))??; + } + } + + // Final cleanup: wait for tasks to finish + let _ = bridge_task.await; + let _ = proxy_task.await; + if let Some(task) = wintun_task { + let _ = task.await; + } + + Ok(()) +} + +#[allow(dead_code)] +fn format_bytes(bps: u64) -> String { + if bps >= 1_000_000 { + format!("{:.1}MB", bps as f64 / 1_000_000.0) + } else if bps >= 1_000 { + format!("{:.1}KB", bps as f64 / 1_000.0) + } else { + format!("{bps}B") + } +} + +fn is_essential_log(text: &str) -> bool { + matches!( + text, + "Connection established" + | "TUN tunnel established" + | "TUN tunnel stopped" + | "Bridge stopped" + | "Runtime config reloaded" + | "Connecting to remote server..." + ) || text.starts_with("Connected to ") + || text.starts_with("TURN relay allocated") + || text.starts_with("TURN allocation failed") + || text.starts_with("Allocating TURN relay") + || text.starts_with("Connection failed:") + || text.starts_with("Connection lost") + || text.starts_with("Protocol tick fatal error") +} diff --git a/ostp-client/src/tunnel/balancer.rs b/ostp-client/src/tunnel/balancer.rs new file mode 100644 index 0000000..9951f59 --- /dev/null +++ b/ostp-client/src/tunnel/balancer.rs @@ -0,0 +1,65 @@ +use crate::config::{ClientConfig, OutboundConfig}; +use std::collections::HashMap; +use std::sync::Arc; + +pub struct Balancer { + outbounds: HashMap, +} + +impl Balancer { + pub fn new(config: &ClientConfig) -> Self { + let mut outbounds = HashMap::new(); + for outbound in &config.outbounds { + let tag = match outbound { + OutboundConfig::Selector { tag, .. } => tag, + OutboundConfig::Urltest { tag, .. } => tag, + OutboundConfig::Ostp { tag, .. } => tag, + OutboundConfig::Direct { tag } => tag, + OutboundConfig::Socks { tag, .. } => tag, + OutboundConfig::Block { tag } => tag, + }; + outbounds.insert(tag.clone(), outbound.clone()); + } + + Self { outbounds } + } + + /// Resolves an outbound tag into a concrete, non-group outbound tag. + /// E.g. "proxy-group" -> "server-helsinki" + pub fn resolve_outbound(&self, tag: &str) -> String { + // Prevent infinite loops if groups point to groups + let mut current_tag = tag.to_string(); + for _ in 0..10 { + if let Some(outbound) = self.outbounds.get(¤t_tag) { + match outbound { + OutboundConfig::Selector { outbounds, default, .. } => { + current_tag = if let Some(def) = default { + def.clone() + } else { + outbounds.first().cloned().unwrap_or_else(|| "direct".to_string()) + }; + } + OutboundConfig::Urltest { outbounds, .. } => { + // TODO: Implement background ping worker to find the fastest node. + // For now, act as a fallback by taking the first available node. + current_tag = outbounds.first().cloned().unwrap_or_else(|| "direct".to_string()); + } + _ => { + // It's a concrete physical outbound (ostp, direct, block) + return current_tag; + } + } + } else { + // Outbound not found, fallback to direct + return "direct".to_string(); + } + } + "direct".to_string() // Max depth reached + } + + /// Fetches the config for a concrete outbound + pub fn get_concrete_outbound(&self, tag: &str) -> Option<&OutboundConfig> { + let resolved_tag = self.resolve_outbound(tag); + self.outbounds.get(&resolved_tag) + } +} diff --git a/ostp-client/src/tunnel/inbounds/local_proxy.rs b/ostp-client/src/tunnel/inbounds/local_proxy.rs new file mode 100644 index 0000000..92279ba --- /dev/null +++ b/ostp-client/src/tunnel/inbounds/local_proxy.rs @@ -0,0 +1,224 @@ +use anyhow::{anyhow, Result}; +use std::sync::Arc; +use crate::config::{ClientConfig, InboundConfig}; +use crate::tunnel::router::{Router, Session}; +use crate::tunnel::outbounds::OutboundManager; +use tokio::net::TcpListener; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::sync::watch; + +pub async fn run_socks_inbound( + _config: ClientConfig, + inbound_config: InboundConfig, + router: Arc, + outbound_manager: Arc, + mut shutdown: watch::Receiver, +) -> Result<()> { + let InboundConfig::LocalProxy { tag, protocol, listen, port } = inbound_config else { + return Err(anyhow!("Invalid config for LocalProxy inbound")); + }; + + let bind_addr = format!("{}:{}", listen, port); + tracing::info!("Starting {} proxy inbound on {} (tag: {})", protocol, bind_addr, tag); + + let listener = TcpListener::bind(&bind_addr).await?; + + loop { + tokio::select! { + _ = shutdown.changed() => { + tracing::info!("Local proxy inbound {} shutting down", tag); + break; + } + accept_res = listener.accept() => { + if let Ok((mut stream, client_addr)) = accept_res { + let rt = router.clone(); + let om = outbound_manager.clone(); + let proto = protocol.clone(); + let inbound_tag = tag.clone(); + + tokio::spawn(async move { + if proto == "socks" { + if let Err(e) = handle_socks5_connection(&mut stream, &rt, &om, &inbound_tag, client_addr).await { + tracing::debug!("SOCKS5 handling error: {}", e); + } + } else if proto == "http" { + if let Err(e) = handle_http_connection(&mut stream, &rt, &om, &inbound_tag, client_addr).await { + tracing::debug!("HTTP proxy handling error: {}", e); + } + } else { + tracing::error!("Unknown local proxy protocol: {}", proto); + } + }); + } + } + } + } + + Ok(()) +} + +async fn handle_socks5_connection( + stream: &mut tokio::net::TcpStream, + router: &Arc, + outbound_manager: &Arc, + inbound_tag: &str, + client_addr: std::net::SocketAddr, +) -> Result<()> { + let mut buf = [0u8; 256]; + + // Read version and method selection + stream.read_exact(&mut buf[0..2]).await?; + if buf[0] != 0x05 { + return Err(anyhow!("Unsupported SOCKS version: {}", buf[0])); + } + + let num_methods = buf[1] as usize; + stream.read_exact(&mut buf[0..num_methods]).await?; + + // Reply with NO AUTHENTICATION REQUIRED (0x00) + stream.write_all(&[0x05, 0x00]).await?; + + // Read the actual request + stream.read_exact(&mut buf[0..4]).await?; + if buf[0] != 0x05 || buf[1] != 0x01 { // Only CONNECT is supported + return Err(anyhow!("Unsupported SOCKS command")); + } + + let atyp = buf[3]; + let (target_host, mut ip_addr) = match atyp { + 0x01 => { // IPv4 + stream.read_exact(&mut buf[0..4]).await?; + let ip = std::net::Ipv4Addr::new(buf[0], buf[1], buf[2], buf[3]); + (ip.to_string(), Some(std::net::IpAddr::V4(ip))) + } + 0x03 => { // Domain + stream.read_exact(&mut buf[0..1]).await?; + let domain_len = buf[0] as usize; + stream.read_exact(&mut buf[0..domain_len]).await?; + let domain = String::from_utf8_lossy(&buf[0..domain_len]).to_string(); + (domain, None) + } + 0x04 => { // IPv6 + stream.read_exact(&mut buf[0..16]).await?; + let mut ip_bytes = [0u8; 16]; + ip_bytes.copy_from_slice(&buf[0..16]); + let ip = std::net::Ipv6Addr::from(ip_bytes); + (ip.to_string(), Some(std::net::IpAddr::V6(ip))) + } + _ => return Err(anyhow!("Unsupported SOCKS address type: {}", atyp)), + }; + + stream.read_exact(&mut buf[0..2]).await?; + let target_port = u16::from_be_bytes([buf[0], buf[1]]); + + let process_name = crate::tunnel::process_lookup::get_process_name_from_port(client_addr.port()); + + let session = Session { + protocol: "tcp".to_string(), + inbound_tag: inbound_tag.to_string(), + source_ip: Some(client_addr.ip()), + destination_ip: ip_addr, + destination_port: target_port, + sni: if atyp == 0x03 { Some(target_host.clone()) } else { None }, + process_name, + }; + + let outbound_tag = router.route(&session); + tracing::info!("SOCKS5 TCP {} -> {}:{} routed to {}", client_addr, target_host, target_port, outbound_tag); + + match outbound_manager.dial_tcp(&outbound_tag, &target_host, target_port).await { + Ok(mut remote_stream) => { + // Reply success + stream.write_all(&[0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]).await?; + + // Forward data + tokio::io::copy_bidirectional(stream, &mut remote_stream).await?; + } + Err(e) => { + tracing::warn!("SOCKS5 TCP dial failed to {}: {}", outbound_tag, e); + // Reply host unreachable + let _ = stream.write_all(&[0x05, 0x04, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]).await; + } + } + + Ok(()) +} + +async fn handle_http_connection( + stream: &mut tokio::net::TcpStream, + router: &Arc, + outbound_manager: &Arc, + inbound_tag: &str, + client_addr: std::net::SocketAddr, +) -> Result<()> { + // Basic HTTP CONNECT implementation + let mut buf = [0u8; 4096]; + let n = stream.read(&mut buf).await?; + if n == 0 { return Ok(()); } + + let request = String::from_utf8_lossy(&buf[0..n]); + let mut lines = request.lines(); + let first_line = lines.next().ok_or_else(|| anyhow!("Empty HTTP request"))?; + + let parts: Vec<&str> = first_line.split_whitespace().collect(); + if parts.len() < 3 { + return Err(anyhow!("Invalid HTTP request line")); + } + + let method = parts[0]; + let target = parts[1]; // host:port for CONNECT, http://host:port/... for GET + + let (target_host, target_port) = if method == "CONNECT" { + let parts: Vec<&str> = target.split(':').collect(); + let host = parts[0].to_string(); + let port = parts.get(1).unwrap_or(&"443").parse::().unwrap_or(443); + (host, port) + } else { + // Rudimentary GET parsing, ideally use httparse + if target.starts_with("http://") { + let without_scheme = &target[7..]; + let host_part = without_scheme.split('/').next().unwrap_or(without_scheme); + let parts: Vec<&str> = host_part.split(':').collect(); + let host = parts[0].to_string(); + let port = parts.get(1).unwrap_or(&"80").parse::().unwrap_or(80); + (host, port) + } else { + return Err(anyhow!("Unsupported HTTP method/target: {} {}", method, target)); + } + }; + + let process_name = crate::tunnel::process_lookup::get_process_name_from_port(client_addr.port()); + + let session = Session { + protocol: "tcp".to_string(), + inbound_tag: inbound_tag.to_string(), + source_ip: Some(client_addr.ip()), + destination_ip: None, // Could parse if IP + destination_port: target_port, + sni: Some(target_host.clone()), + process_name, + }; + + let outbound_tag = router.route(&session); + tracing::info!("HTTP TCP {} -> {}:{} routed to {}", client_addr, target_host, target_port, outbound_tag); + + match outbound_manager.dial_tcp(&outbound_tag, &target_host, target_port).await { + Ok(mut remote_stream) => { + if method == "CONNECT" { + stream.write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n").await?; + } else { + remote_stream.write_all(&buf[0..n]).await?; + } + + tokio::io::copy_bidirectional(stream, &mut remote_stream).await?; + } + Err(e) => { + tracing::warn!("HTTP TCP dial failed to {}: {}", outbound_tag, e); + if method == "CONNECT" { + let _ = stream.write_all(b"HTTP/1.1 502 Bad Gateway\r\n\r\n").await; + } + } + } + + Ok(()) +} diff --git a/ostp-client/src/tunnel/inbounds/local_proxy.rs.bak b/ostp-client/src/tunnel/inbounds/local_proxy.rs.bak new file mode 100644 index 0000000..5acf471 Binary files /dev/null and b/ostp-client/src/tunnel/inbounds/local_proxy.rs.bak differ diff --git a/ostp-client/src/tunnel/inbounds/mod.rs b/ostp-client/src/tunnel/inbounds/mod.rs new file mode 100644 index 0000000..d7f1fb4 --- /dev/null +++ b/ostp-client/src/tunnel/inbounds/mod.rs @@ -0,0 +1,2 @@ +pub mod tun; +pub mod local_proxy; diff --git a/ostp-client/src/tunnel/inbounds/tun.rs b/ostp-client/src/tunnel/inbounds/tun.rs new file mode 100644 index 0000000..3f5c5a3 --- /dev/null +++ b/ostp-client/src/tunnel/inbounds/tun.rs @@ -0,0 +1,239 @@ +use anyhow::{anyhow, Result}; +use std::sync::Arc; +use crate::config::{ClientConfig, InboundConfig}; +use crate::tunnel::router::{Router, Session}; +use crate::tunnel::outbounds::OutboundManager; +use tokio::sync::watch; + +#[cfg(any(target_os = "windows", target_os = "linux"))] +pub async fn run_tun_inbound( + config: ClientConfig, + inbound_config: InboundConfig, + router: Arc, + outbound_manager: Arc, + mut shutdown: watch::Receiver, +) -> Result<()> { + use std::net::ToSocketAddrs; + use netstack_smoltcp::StackBuilder; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use futures::{StreamExt, SinkExt}; + + let InboundConfig::Tun { tag, auto_route, mtu } = inbound_config else { + return Err(anyhow!("Invalid config for TUN inbound")); + }; + + tracing::info!("Starting TUN inbound (tag: {}, auto_route: {}, mtu: {})", tag, auto_route, mtu); + + #[cfg(target_os = "windows")] + let _phys_if_for_bypass: Option = ostp_tun::windows::windows_route::sys::get_default_ipv4_route().map(|(_, idx)| idx); + #[cfg(not(target_os = "windows"))] + let _phys_if_for_bypass: Option = None; + + let mut bypass_ips: Vec = Vec::new(); + + // Bypass all outbound server IPs + for outbound in &config.outbounds { + let server = match outbound { + crate::config::OutboundConfig::Ostp { server, .. } => Some(server), + crate::config::OutboundConfig::Socks { server, .. } => Some(server), + _ => None, + }; + if let Some(host) = server { + if let Ok(ip) = host.parse::() { + bypass_ips.push(ip); + } else { + if let Ok(addrs) = tokio::net::lookup_host((host.as_str(), 443)).await { + for addr in addrs { + bypass_ips.push(addr.ip()); + } + } + } + } + } + + let dummy_server_ip = bypass_ips.first().copied().unwrap_or_else(|| "8.8.8.8".parse().unwrap()); + + // Create TUN device + let opts = ostp_tun::OstpTunOptions { + server_ip: dummy_server_ip, + bypass_ips, + dns_server: None, + kill_switch: false, + mtu: mtu as u16, + wintun_path: None, + }; + + let tun_interface = ostp_tun::OstpTunInterface::create(opts) + .await + .map_err(|e| anyhow!("Failed to create OstpTunInterface: {}", e))?; + + let dev = tun_interface.device; + let _route_guard = tun_interface.guard; // Drops when TUN drops + + // Build smoltcp network stack + let (stack, tcp_runner, udp_socket, tcp_listener) = StackBuilder::default() + .stack_buffer_size(1024) + .tcp_buffer_size(1024) + .udp_buffer_size(1024) + .enable_tcp(true) + .enable_udp(true) + .mtu(mtu) + .build()?; + + let mut runner_task = tokio::spawn(async move { + if let Some(runner) = tcp_runner { + let _ = runner.await; + } + }); + + let (mut stack_sink, mut stack_stream) = stack.split(); + let (mut tun_read, mut tun_write) = tokio::io::split(dev); + + let mut tun_to_stack = tokio::spawn(async move { + let mut buf = vec![0u8; 65536]; + loop { + match tun_read.read(&mut buf).await { + Ok(0) => break, + Ok(n) => { + let frame = buf[..n].to_vec(); + if let Err(e) = stack_sink.send(frame).await { + if e.kind() == std::io::ErrorKind::BrokenPipe { + break; + } + } + } + Err(e) => { + tracing::debug!("tun_read error: {e}"); + } + } + } + }); + + let mut stack_to_tun = tokio::spawn(async move { + while let Some(Ok(frame)) = stack_stream.next().await { + if let Err(e) = tun_write.write(&frame).await { + tracing::debug!("tun_write error: {e}"); + } + } + }); + + // ── TCP Handler ── + let outbound_manager_tcp = outbound_manager.clone(); + let router_tcp = router.clone(); + let tag_tcp = tag.clone(); + + let mut tcp_accept_task = tokio::spawn(async move { + let Some(mut listener) = tcp_listener else { return; }; + while let Some((mut stream, local, remote)) = listener.next().await { + let om = outbound_manager_tcp.clone(); + let rt = router_tcp.clone(); + let ib_tag = tag_tcp.clone(); + + tokio::spawn(async move { + let process_name = crate::tunnel::process_lookup::get_process_name_from_port(local.port()); + + let mut sniff_buf = [0u8; 2048]; + let sniff_len = match tokio::time::timeout( + std::time::Duration::from_millis(100), + stream.read(&mut sniff_buf) + ).await { + Ok(Ok(n)) => n, + _ => 0, + }; + + let mut domain_suffix = None; + if sniff_len > 0 { + domain_suffix = crate::tunnel::sni_sniff::extract_sni(&sniff_buf[..sniff_len]); + } + + let session = Session { + protocol: "tcp".to_string(), + inbound_tag: ib_tag.clone(), + source_ip: Some(local.ip()), + destination_ip: Some(remote.ip()), + destination_port: remote.port(), + sni: domain_suffix.map(|s| s.to_string()), + process_name, + }; + + let outbound_tag = rt.route(&session); + tracing::info!("TUN TCP {} -> {} routed to {}", local, remote, outbound_tag); + + let target_host = if let Some(domain) = session.sni { + domain + } else { + remote.ip().to_string() + }; + + match om.dial_tcp(&outbound_tag, &target_host, session.destination_port).await { + Ok(mut remote_stream) => { + if sniff_len > 0 { + if let Err(e) = remote_stream.write_all(&sniff_buf[..sniff_len]).await { + tracing::warn!("Failed to forward sniffed bytes to {}: {}", outbound_tag, e); + return; + } + } + let _ = tokio::io::copy_bidirectional(&mut stream, &mut remote_stream).await; + } + Err(e) => { + tracing::warn!("TUN TCP dial failed to {}: {}", outbound_tag, e); + } + } + }); + } + }); + + // ── UDP Handler ── + let outbound_manager_udp = outbound_manager.clone(); + let router_udp = router.clone(); + let tag_udp = tag.clone(); + + let mut udp_proxy_task = tokio::spawn(async move { + if let Some(udp_sock) = udp_socket { + let (mut udp_rx, _udp_tx) = udp_sock.split(); + while let Some((payload, local, remote)) = udp_rx.next().await { + let process_name = crate::tunnel::process_lookup::get_process_name_from_port_udp(local.port()); + let session = Session { + protocol: "udp".to_string(), + inbound_tag: tag_udp.clone(), + source_ip: Some(local.ip()), + destination_ip: Some(remote.ip()), + destination_port: remote.port(), + sni: None, + process_name, + }; + let outbound_tag = router_udp.route(&session); + + let payload_bytes = bytes::Bytes::copy_from_slice(&payload); + if let Err(e) = outbound_manager_udp.handle_udp(&outbound_tag, local, remote, payload_bytes).await { + tracing::debug!("TUN UDP drop to {}: {}", outbound_tag, e); + } + } + } + }); + + tokio::select! { + _ = shutdown.changed() => { + tracing::info!("TUN inbound {} shutting down", tag); + } + _ = &mut runner_task => {} + } + + tun_to_stack.abort(); + stack_to_tun.abort(); + tcp_accept_task.abort(); + udp_proxy_task.abort(); + + Ok(()) +} + +#[cfg(not(any(target_os = "windows", target_os = "linux")))] +pub async fn run_tun_inbound( + _config: ClientConfig, + _inbound_config: InboundConfig, + _router: Arc, + _outbound_manager: Arc, + _shutdown: watch::Receiver, +) -> Result<()> { + Err(anyhow!("TUN is only supported on Windows and Linux")) +} diff --git a/ostp-client/src/tunnel/inbounds/tun.rs.bak b/ostp-client/src/tunnel/inbounds/tun.rs.bak new file mode 100644 index 0000000..eead16f Binary files /dev/null and b/ostp-client/src/tunnel/inbounds/tun.rs.bak differ diff --git a/ostp-client/src/tunnel/mod.rs b/ostp-client/src/tunnel/mod.rs index a6e3209..e1d6930 100644 --- a/ostp-client/src/tunnel/mod.rs +++ b/ostp-client/src/tunnel/mod.rs @@ -1,67 +1,7 @@ -mod proxy; -pub mod native_handler; +pub mod router; +pub mod balancer; +pub mod outbounds; +pub mod inbounds; -mod udp_nat; - -pub async fn run_tun_tunnel( - config: crate::config::ClientConfig, - shutdown: tokio::sync::watch::Receiver, - exclusions_rx: tokio::sync::watch::Receiver, -) -> anyhow::Result<()> { - native_handler::run_native_tunnel(config, shutdown, exclusions_rx).await -} - -use tokio::sync::{mpsc, watch}; - -use crate::config::{ExclusionConfig, LocalProxyConfig, OstpConfig}; - -pub use proxy::run_local_socks5_proxy; - -#[derive(Debug)] -pub enum ProxyEvent { - NewStream { - stream_id: u16, - target: String, - }, - UdpAssociate { - stream_id: u16, - }, - UdpData { - stream_id: u16, - target: String, - payload: bytes::Bytes, - }, - Data { - stream_id: u16, - payload: bytes::Bytes, - }, - Close { - stream_id: u16, - }, -} - -#[derive(Debug)] -pub enum ProxyToClientMsg { - ConnectOk, - Data(bytes::Bytes), - UdpData(String, bytes::Bytes), - Close, - Error(String), -} - - -pub async fn run_local_proxy( - cfg: LocalProxyConfig, - ostp: OstpConfig, - exclusions_rx: watch::Receiver, - debug: bool, - shutdown: watch::Receiver, - proxy_events_tx: mpsc::Sender, - client_msgs_rx: mpsc::UnboundedReceiver<(u16, ProxyToClientMsg)>, -) -> anyhow::Result<()> { - run_local_socks5_proxy(cfg, ostp, exclusions_rx, debug, shutdown, proxy_events_tx, client_msgs_rx).await -} - -pub mod exclusion; pub mod process_lookup; pub mod sni_sniff; diff --git a/ostp-client/src/tunnel/native_handler.rs b/ostp-client/src/tunnel/native_handler.rs deleted file mode 100644 index bb9ce10..0000000 --- a/ostp-client/src/tunnel/native_handler.rs +++ /dev/null @@ -1,744 +0,0 @@ -use anyhow::{anyhow, Result}; -use tokio::sync::watch; - -// ────────────────────────────────────────────────────────────────────────────── -// Windows / Linux desktop TUN -// ────────────────────────────────────────────────────────────────────────────── - -#[cfg(any(target_os = "windows", target_os = "linux"))] -pub async fn run_native_tunnel( - config: crate::config::ClientConfig, - mut shutdown: watch::Receiver, - mut exclusions_rx: watch::Receiver, -) -> Result<()> { - use std::net::ToSocketAddrs; - use netstack_smoltcp::StackBuilder; - use tokio::io::{AsyncReadExt, AsyncWriteExt}; - use futures::{StreamExt, SinkExt}; - - #[cfg(target_os = "linux")] - { - use std::io::{self, IsTerminal, Write}; - if io::stdout().is_terminal() { - println!("\n==================================================================="); - println!("WARNING: TUN mode will modify the system routing table."); - println!("If you are connected to a headless server via SSH, you may lose"); - println!("your connection when default routes are redirected into the tunnel."); - println!("===================================================================\n"); - print!("Are you sure you want to initialize the TUN interface? [yes/no]: "); - io::stdout().flush().unwrap(); - - let mut input = String::new(); - io::stdin().read_line(&mut input).unwrap(); - let ans = input.trim().to_lowercase(); - if ans != "y" && ans != "yes" { - return Err(anyhow!("TUN initialization aborted by user.")); - } - } - } - - let debug = config.debug; - tracing::info!("Initializing NATIVE TUN tunnel (smoltcp)..."); - - // Capture physical interface index for bypass BEFORE we create the TUN device and alter routes. - #[cfg(target_os = "windows")] - let phys_if_for_bypass: Option = ostp_tun::windows::windows_route::sys::get_default_ipv4_route().map(|(_, idx)| idx); - #[cfg(not(target_os = "windows"))] - let phys_if_for_bypass: Option = None; - - // ── 1. Resolve server IP ────────────────────────────────────────────────── - let server_ip = config - .ostp - .server_addr - .to_socket_addrs() - .map_err(|e| anyhow!("Failed to resolve server IP: {}", e))? - .next() - .map(|a| a.ip()) - .ok_or_else(|| anyhow!("Could not resolve server host"))?; - #[allow(unused_variables)] - let server_ip_str = server_ip.to_string(); - - // ── 2. Resolve excluded domains → IP addresses for bypass routing ───────── - let mut bypass_ips: Vec = Vec::new(); - - // Server IP always bypasses TUN - bypass_ips.push(server_ip); - - for ip_str in &config.exclusions.ips { - let host = ip_str.split('/').next().unwrap_or(ip_str); - if let Ok(ip) = host.parse() { - bypass_ips.push(ip); - } - } - - for domain in &config.exclusions.domains { - match tokio::net::lookup_host((domain.as_str(), 443u16)).await { - Ok(addrs) => { - for addr in addrs { - bypass_ips.push(addr.ip()); - } - } - Err(e) => { - tracing::warn!("Failed to pre-resolve excluded domain {domain}: {e}"); - } - } - } - - - // ── 3. Create TUN device via ostp-tun crate ─────────────────────────────── - let opts = ostp_tun::OstpTunOptions { - server_ip, - bypass_ips, - dns_server: config.dns_server.clone(), - kill_switch: config.kill_switch, - mtu: config.ostp.mtu as u16, - wintun_path: None, - }; - - let tun_interface = ostp_tun::OstpTunInterface::create(opts) - .await - .map_err(|e| anyhow!("Failed to create OstpTunInterface: {}", e))?; - - let dev = tun_interface.device; - let _route_guard = tun_interface.guard; - - // ── 7. Build smoltcp network stack ──────────────────────────────────────── - let (stack, tcp_runner, udp_socket, tcp_listener) = StackBuilder::default() - .stack_buffer_size(1024) - .tcp_buffer_size(1024) - .udp_buffer_size(1024) - .enable_tcp(true) - .enable_udp(true) - .mtu(config.ostp.mtu) - .build()?; - - let mut runner_task = tokio::spawn(async move { - if let Some(runner) = tcp_runner { - let _ = runner.await; - } - }); - - // ── 8. Wire TUN ↔ smoltcp stack ─────────────────────────────────────────── - let (mut stack_sink, mut stack_stream) = stack.split(); - let (mut tun_read, mut tun_write) = tokio::io::split(dev); - - let mut tun_to_stack = tokio::spawn(async move { - let mut buf = vec![0u8; 65536]; - loop { - match tun_read.read(&mut buf).await { - Ok(0) => break, - Ok(n) => { - let frame = buf[..n].to_vec(); - if let Err(e) = stack_sink.send(frame).await { - if e.kind() == std::io::ErrorKind::BrokenPipe { - break; - } - } - } - Err(e) => { - tracing::debug!("tun_read error: {e}"); - } - } - } - }); - - let mut stack_to_tun = tokio::spawn(async move { - while let Some(Ok(frame)) = stack_stream.next().await { - if let Err(e) = tun_write.write(&frame).await { - tracing::debug!("tun_write error: {e}"); - } - } - }); - - // ── 9. UDP: forward everything through OSTP proxy ───────────────────────── - // UDP exclusions are handled at the routing table level (step 5), so - // UDP packets for excluded IPs never reach smoltcp at all. - let udp_proxy_addr = { - let mut a = config.local_proxy.bind_addr.clone(); - if a.starts_with("0.0.0.0:") { - a = a.replace("0.0.0.0:", "127.0.0.1:"); - } - a - }; - // Build exclusion matcher for dynamic bypass - let current_exclusions = exclusions_rx.borrow().clone(); - let matcher = crate::tunnel::exclusion::ExclusionMatcher::new(¤t_exclusions, None, None); - let matcher_arc = std::sync::Arc::new(tokio::sync::RwLock::new(matcher)); - - let matcher_clone = matcher_arc.clone(); - tokio::spawn(async move { - while let Ok(_) = exclusions_rx.changed().await { - let current = exclusions_rx.borrow().clone(); - let new_matcher = crate::tunnel::exclusion::ExclusionMatcher::new(¤t, None, None); - *matcher_clone.write().await = new_matcher; - if true { - tracing::debug!("Desktop TUN exclusions hot-reloaded"); - } - } - }); - - // Linux: physical interface name for SO_BINDTODEVICE - #[cfg(target_os = "linux")] - let linux_phys_name = crate::tunnel::proxy::get_linux_physical_if_name(); - #[cfg(not(target_os = "linux"))] - let linux_phys_name: Option = None; - let _ = &linux_phys_name; // suppress unused warning on Windows - - let debug_udp = debug; - let udp_matcher = matcher_arc.clone(); - #[cfg(target_os = "linux")] - let udp_lin_name = linux_phys_name.clone(); - - let mut udp_proxy_task = tokio::spawn(async move { - if let Some(udp_sock) = udp_socket { - #[cfg(target_os = "linux")] - super::udp_nat::run_udp_nat(udp_sock, udp_proxy_addr, debug_udp, udp_matcher, phys_if_for_bypass, udp_lin_name).await; - #[cfg(not(target_os = "linux"))] - super::udp_nat::run_udp_nat(udp_sock, udp_proxy_addr, debug_udp, udp_matcher, phys_if_for_bypass, None).await; - } - }); - - // ── 10. TCP: forward to OSTP proxy (with domain-level bypass via SNI) ───── - // - // For IP-based exclusions: handled by routing table → packets never arrive here. - // For domain-based exclusions: The IP is already in routing table (pre-resolved in - // step 3), so most traffic won't arrive. As a belt-and-suspenders fallback, - // we also sniff TLS SNI and bypass if it matches — this covers CDN cases where - // the IP wasn't known at startup. - // - // For bypassed connections we bind the outgoing socket to the physical interface - // (IP_UNICAST_IF) so it goes out via the real NIC, not TUN. - - let proxy_addr_tcp = { - let mut a = config.local_proxy.bind_addr.clone(); - if a.starts_with("0.0.0.0:") { - a = a.replace("0.0.0.0:", "127.0.0.1:"); - } - a - }; - - // Physical interface index was captured at the start of the function. - - let mut tcp_accept_task = tokio::spawn(async move { - let Some(mut listener) = tcp_listener else { return; }; - - while let Some((mut stream, local, remote)) = listener.next().await { - let proxy_addr = proxy_addr_tcp.clone(); - let matcher_arc = matcher_arc.clone(); - #[cfg(target_os = "linux")] - let lin_name = linux_phys_name.clone(); - - tokio::spawn(async move { - let matcher = matcher_arc.read().await.clone(); - if debug { - tracing::debug!("TUN TCP {local} → {remote}"); - } - - // ── Sniff TLS ClientHello for SNI ───────────────────────────── - let mut sniff_buf = [0u8; 2048]; - let sniff_len = - match tokio::time::timeout( - std::time::Duration::from_millis(100), - stream.read(&mut sniff_buf), - ) - .await - { - Ok(Ok(n)) => n, - _ => 0, - }; - - // ── Decide: bypass or tunnel? ───────────────────────────────── - let mut should_bypass = false; - - // 1. Process match via OS Extended TCP Table (Windows) - #[cfg(target_os = "windows")] - if !should_bypass { - if let Some(proc_name) = crate::tunnel::process_lookup::get_process_name_from_port(local.port()) { - if debug { - tracing::debug!("TUN TCP lookup: port {} -> process {}", local.port(), proc_name); - } - if matcher.match_process(&proc_name) { - if debug { - tracing::debug!("TUN TCP BYPASS (Process match): {} → {remote}", proc_name); - } - should_bypass = true; - } - } else { - if debug { - tracing::debug!("TUN TCP lookup: port {} -> no process found", local.port()); - } - } - } - - // 2. SNI domain check (belt-and-suspenders for CDNs / late-resolved IPs) - if !should_bypass && sniff_len > 0 { - if let Some(sni) = - crate::tunnel::sni_sniff::extract_sni(&sniff_buf[..sniff_len]) - { - if debug { - tracing::debug!("TUN SNI: {sni}"); - } - if matcher.match_domain(&sni) { - if debug { - tracing::info!("TUN TCP BYPASS (SNI domain): {sni} → {remote}"); - } - should_bypass = true; - } - } - } - - // 3. Destination IP CIDR check (for IPs not in routing table / IPv6) - if !should_bypass && matcher.match_ip(&remote.ip()) { - if debug { - tracing::info!("TUN TCP BYPASS (IP match): {remote}"); - } - should_bypass = true; - } - - // ── Bypass path: direct TCP bypassing TUN ───────────────────── - if should_bypass { - let socket = match remote { - std::net::SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4(), - std::net::SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6(), - }; - let Ok(socket) = socket else { return; }; - - // Bind to physical interface so packets don't loop back into TUN - - #[cfg(target_os = "windows")] - if let Some(idx) = phys_if_for_bypass { - if let Err(e) = crate::tunnel::proxy::bind_socket_to_interface( - &socket, - remote.is_ipv6(), - idx, - ) { - tracing::error!("TUN TCP BYPASS failed to bind to physical interface {}: {}", idx, e); - } else { - if debug { - tracing::info!("TUN TCP BYPASS bound to physical interface {}", idx); - } - } - } else { - tracing::warn!("TUN TCP BYPASS has no physical interface index!"); - } - #[cfg(target_os = "linux")] - if let Some(ref name) = lin_name { - let _ = crate::tunnel::proxy::bind_socket_to_interface(&socket, name); - } - - match tokio::time::timeout( - std::time::Duration::from_secs(10), - socket.connect(remote), - ) - .await - { - Ok(Ok(mut direct)) => { - if sniff_len > 0 { - if direct.write_all(&sniff_buf[..sniff_len]).await.is_err() { - return; - } - } - let _ = tokio::io::copy_bidirectional(&mut stream, &mut direct).await; - } - _ => { - tracing::debug!("Direct bypass connect to {remote} failed"); - } - } - return; - } - - // ── Tunnel path: forward via local OSTP SOCKS5 proxy ────────── - let Ok(mut socks) = tokio::net::TcpStream::connect(&proxy_addr).await else { - return; - }; - - // SOCKS5 handshake (no auth) - if socks.write_all(&[5, 1, 0]).await.is_err() { return; } - let mut buf2 = [0u8; 2]; - if socks.read_exact(&mut buf2).await.is_err() || buf2[0] != 5 || buf2[1] != 0 { - return; - } - - // CONNECT request - let mut req = vec![5u8, 1, 0]; - match remote.ip() { - std::net::IpAddr::V4(v4) => { - req.push(1); - req.extend_from_slice(&v4.octets()); - } - std::net::IpAddr::V6(v6) => { - req.push(4); - req.extend_from_slice(&v6.octets()); - } - } - req.extend_from_slice(&remote.port().to_be_bytes()); - if socks.write_all(&req).await.is_err() { return; } - - let mut rep = [0u8; 10]; - if socks.read_exact(&mut rep).await.is_err() || rep[1] != 0 { return; } - - // Replay sniffed bytes - if sniff_len > 0 && socks.write_all(&sniff_buf[..sniff_len]).await.is_err() { - return; - } - - let _ = tokio::io::copy_bidirectional(&mut stream, &mut socks).await; - }); - } - }); - - tracing::info!("NATIVE TUN tunnel active."); - - tokio::select! { - _ = shutdown.changed() => {} - _ = &mut runner_task => {} - _ = &mut tun_to_stack => {} - _ = &mut stack_to_tun => {} - _ = &mut udp_proxy_task => {} - _ = &mut tcp_accept_task => {} - } - - tracing::info!("Deactivating NATIVE TUN tunnel..."); - - // ── Cleanup ─────────────────────────────────────────────────────────────── - // Cleanup is handled automatically by the _route_guard Drop trait in ostp-tun - - Ok(()) -} - -// ────────────────────────────────────────────────────────────────────────────── -// Stub for unsupported platforms -// ────────────────────────────────────────────────────────────────────────────── - -#[cfg(not(any(target_os = "windows", target_os = "linux")))] -pub async fn run_native_tunnel( - _config: crate::config::ClientConfig, - _shutdown: watch::Receiver, - _exclusions_rx: watch::Receiver, -) -> Result<()> { - Err(anyhow!("Native TUN tunnel is only supported on Windows/Linux")) -} - -// ────────────────────────────────────────────────────────────────────────────── -// Android: TUN from file-descriptor (opened by VpnService) -// ────────────────────────────────────────────────────────────────────────────── - -#[cfg(target_os = "android")] -pub async fn run_native_tunnel_from_fd( - config: crate::config::ClientConfig, - mut shutdown: watch::Receiver, - mut exclusions_rx: watch::Receiver, - fd: i32, -) -> Result<()> { - use netstack_smoltcp::StackBuilder; - use tokio::io::{AsyncReadExt, AsyncWriteExt}; - use futures::{StreamExt, SinkExt}; - use std::os::unix::io::{FromRawFd, AsRawFd}; - - let debug = config.debug; - tracing::info!("Initializing NATIVE TUN tunnel on Android (FD {})", fd); - - unsafe { - let flags = libc::fcntl(fd, libc::F_GETFL); - if flags >= 0 { - libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK); - } - } - - let read_fd = unsafe { libc::dup(fd) }; - if read_fd < 0 { - return Err(anyhow!("Failed to dup tun fd for reading")); - } - - let file = unsafe { std::fs::File::from_raw_fd(read_fd) }; - let tun_stream = tokio::io::unix::AsyncFd::new(file)?; - - let (stack, tcp_runner, udp_socket, tcp_listener) = StackBuilder::default() - .stack_buffer_size(1024) - .tcp_buffer_size(1024) - .udp_buffer_size(1024) - .enable_tcp(true) - .enable_udp(true) - .mtu(config.ostp.mtu) - .build()?; - - let mut runner_task = tokio::spawn(async move { - if let Some(runner) = tcp_runner { - let _ = runner.await; - } - }); - - let (mut stack_sink, mut stack_stream) = stack.split(); - - let _tun_to_stack = tokio::spawn(async move { - let mut buf = vec![0u8; 65536]; - loop { - let mut guard = match tun_stream.readable().await { - Ok(g) => g, - Err(_) => break, - }; - let n = match guard.try_io(|inner| { - let res = unsafe { - libc::read( - inner.as_raw_fd(), - buf.as_mut_ptr() as *mut libc::c_void, - buf.len(), - ) - }; - if res < 0 { - let err = std::io::Error::last_os_error(); - if err.kind() == std::io::ErrorKind::WouldBlock { - Err(err) - } else { - Ok(0_isize) - } - } else { - Ok(res) - } - }) { - Ok(Ok(n)) if n > 0 => n as usize, - Ok(Ok(_)) => continue, - Ok(Err(_)) => continue, - Err(_) => continue, - }; - - let frame = buf[..n].to_vec(); - if let Err(e) = stack_sink.send(frame).await { - if e.kind() == std::io::ErrorKind::BrokenPipe { - break; - } - } - } - }); - - let write_fd = unsafe { libc::dup(fd) }; - if write_fd < 0 { - return Err(anyhow!("Failed to dup tun fd for writing")); - } - unsafe { - let flags = libc::fcntl(write_fd, libc::F_GETFL); - if flags >= 0 { - libc::fcntl(write_fd, libc::F_SETFL, flags | libc::O_NONBLOCK); - } - } - let write_file = unsafe { std::fs::File::from_raw_fd(write_fd) }; - let tun_write_stream = tokio::io::unix::AsyncFd::new(write_file)?; - - let _stack_to_tun = tokio::spawn(async move { - while let Some(Ok(frame)) = stack_stream.next().await { - let mut written = 0; - while written < frame.len() { - let mut guard = match tun_write_stream.writable().await { - Ok(g) => g, - Err(_) => break, - }; - let res = guard.try_io(|inner| { - let res = unsafe { - libc::write( - inner.as_raw_fd(), - frame[written..].as_ptr() as *const libc::c_void, - frame.len() - written, - ) - }; - if res < 0 { - let err = std::io::Error::last_os_error(); - if err.kind() == std::io::ErrorKind::WouldBlock { - Err(err) - } else { - Ok(res) - } - } else { - Ok(res) - } - }); - match res { - Ok(Ok(n)) if n > 0 => written += n as usize, - Ok(Ok(_)) => break, - Ok(Err(_)) => break, - Err(_) => continue, - } - } - } - }); - - let mut proxy_addr = config.local_proxy.bind_addr.clone(); - if proxy_addr.starts_with("0.0.0.0:") { - proxy_addr = proxy_addr.replace("0.0.0.0:", "127.0.0.1:"); - } - - let current_exclusions = exclusions_rx.borrow().clone(); - let matcher = crate::tunnel::exclusion::ExclusionMatcher::new(¤t_exclusions, None, None); - let matcher_arc = std::sync::Arc::new(tokio::sync::RwLock::new(matcher)); - - let matcher_clone = matcher_arc.clone(); - tokio::spawn(async move { - while let Ok(_) = exclusions_rx.changed().await { - let current = exclusions_rx.borrow().clone(); - let new_matcher = crate::tunnel::exclusion::ExclusionMatcher::new(¤t, None, None); - *matcher_clone.write().await = new_matcher; - if true { - tracing::debug!("Android TUN exclusions hot-reloaded"); - } - } - }); - - let udp_proxy_addr = proxy_addr.clone(); - let debug_udp = debug; - let udp_matcher = matcher_arc.clone(); - let mut udp_proxy_task = tokio::spawn(async move { - if let Some(udp_sock) = udp_socket { - super::udp_nat::run_udp_nat(udp_sock, udp_proxy_addr, debug_udp, udp_matcher, None, None).await; - } - }); - - - - let mut tcp_accept_task = tokio::spawn(async move { - let Some(mut listener) = tcp_listener else { return; }; - - while let Some((mut stream, local, remote)) = listener.next().await { - let proxy_addr = proxy_addr.clone(); - let matcher_arc = matcher_arc.clone(); - - tokio::spawn(async move { - let matcher = matcher_arc.read().await.clone(); - - if true { - tracing::debug!("Android TUN TCP {local} → {remote}"); - } - - // Sniff SNI - let mut sniff_buf = [0u8; 2048]; - let sniff_len = - match tokio::time::timeout( - std::time::Duration::from_millis(100), - stream.read(&mut sniff_buf), - ) - .await - { - Ok(Ok(n)) => n, - _ => 0, - }; - - let mut should_bypass = false; - - // 1. SNI domain - if sniff_len > 0 { - if let Some(sni) = - crate::tunnel::sni_sniff::extract_sni(&sniff_buf[..sniff_len]) - { - if true { tracing::debug!("Android TUN SNI: {sni}"); } - if matcher.match_domain(&sni) { - should_bypass = true; - } - } - } - - // 2. Process (Android: /proc/net lookup) - if !should_bypass { - if let Some(exe) = - crate::tunnel::process_lookup::get_process_name_from_port(local.port()) - { - if true { - tracing::debug!("Android TUN port {} → EXE: {}", local.port(), exe); - } - if matcher.match_process(&exe) { - should_bypass = true; - } - } - } - - // 3. IP CIDR - if !should_bypass && matcher.match_ip(&remote.ip()) { - should_bypass = true; - } - - // Bypass: connect directly (Android VPN service already protects the socket - // from re-entering the TUN through VpnService.protect()) - if should_bypass { - if true { - tracing::debug!("Android TUN BYPASS: {remote}"); - } - let socket = match remote { - std::net::SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4(), - std::net::SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6(), - }; - let Ok(socket) = socket else { return; }; - - match tokio::time::timeout( - std::time::Duration::from_secs(10), - socket.connect(remote), - ) - .await - { - Ok(Ok(mut direct)) => { - if sniff_len > 0 { - if direct.write_all(&sniff_buf[..sniff_len]).await.is_err() { - return; - } - } - let _ = tokio::io::copy_bidirectional(&mut stream, &mut direct).await; - } - _ => { - tracing::debug!("Android bypass connect to {remote} failed"); - } - } - return; - } - - // Tunnel via SOCKS5 proxy - let Ok(mut socks) = tokio::net::TcpStream::connect(&proxy_addr).await else { - return; - }; - if socks.write_all(&[5, 1, 0]).await.is_err() { return; } - let mut buf2 = [0u8; 2]; - if socks.read_exact(&mut buf2).await.is_err() || buf2[0] != 5 || buf2[1] != 0 { - return; - } - let mut req = vec![5u8, 1, 0]; - match remote.ip() { - std::net::IpAddr::V4(v4) => { - req.push(1); - req.extend_from_slice(&v4.octets()); - } - std::net::IpAddr::V6(v6) => { - req.push(4); - req.extend_from_slice(&v6.octets()); - } - } - req.extend_from_slice(&remote.port().to_be_bytes()); - if socks.write_all(&req).await.is_err() { return; } - let mut rep = [0u8; 10]; - if socks.read_exact(&mut rep).await.is_err() || rep[1] != 0 { return; } - if sniff_len > 0 && socks.write_all(&sniff_buf[..sniff_len]).await.is_err() { - return; - } - let _ = tokio::io::copy_bidirectional(&mut stream, &mut socks).await; - }); - } - }); - - tracing::info!("NATIVE TUN (Android) tunnel active."); - - tokio::select! { - _ = shutdown.changed() => {} - _ = &mut runner_task => {} - _ = _tun_to_stack => {} - _ = _stack_to_tun => {} - _ = &mut udp_proxy_task => {} - _ = &mut tcp_accept_task => {} - } - - tracing::info!("NATIVE TUN (Android) deactivated."); - Ok(()) -} - -#[cfg(not(target_os = "android"))] -pub async fn run_native_tunnel_from_fd( - _config: crate::config::ClientConfig, - _shutdown: watch::Receiver, - _exclusions_rx: watch::Receiver, - _fd: i32, -) -> Result<()> { - Err(anyhow!("Native TUN from FD is only supported on Android")) -} diff --git a/ostp-client/src/tunnel/outbounds/block.rs b/ostp-client/src/tunnel/outbounds/block.rs new file mode 100644 index 0000000..1530e89 --- /dev/null +++ b/ostp-client/src/tunnel/outbounds/block.rs @@ -0,0 +1,14 @@ +use anyhow::{anyhow, Result}; +use tokio::net::TcpStream; + +pub async fn dial_tcp(_target_host: &str, _target_port: u16) -> Result { + Err(anyhow!("Connection blocked by routing rule")) +} + +pub async fn handle_udp( + _client_src: std::net::SocketAddr, + _target_dst: std::net::SocketAddr, + _payload: bytes::Bytes, +) -> Result<()> { + Err(anyhow!("Connection blocked by routing rule")) +} diff --git a/ostp-client/src/tunnel/outbounds/direct.rs b/ostp-client/src/tunnel/outbounds/direct.rs new file mode 100644 index 0000000..d1068d1 --- /dev/null +++ b/ostp-client/src/tunnel/outbounds/direct.rs @@ -0,0 +1,99 @@ +use anyhow::{anyhow, Result}; +use tokio::net::TcpStream; + +#[cfg(target_os = "windows")] +pub fn bind_socket_to_interface(socket: &tokio::net::TcpSocket, is_ipv6: bool, if_index: u32) -> std::io::Result<()> { + use std::os::windows::io::AsRawSocket; + use winapi::shared::ws2def::{IPPROTO_IP, IPPROTO_IPV6}; + + // These constants are defined as 31 in the Windows SDK. + const IP_UNICAST_IF: i32 = 31; + const IPV6_UNICAST_IF: i32 = 31; + + let fd = socket.as_raw_socket() as usize; + let idx_net = if_index.to_be(); + + let (level, optname) = if is_ipv6 { + (IPPROTO_IPV6 as i32, IPV6_UNICAST_IF) + } else { + (IPPROTO_IP as i32, IP_UNICAST_IF) + }; + + let ret = unsafe { + winapi::um::winsock2::setsockopt( + fd, + level as i32, + optname as i32, + &idx_net as *const _ as *const i8, + std::mem::size_of_val(&idx_net) as i32, + ) + }; + + if ret != 0 { + return Err(std::io::Error::last_os_error()); + } + Ok(()) +} + +#[cfg(target_os = "linux")] +pub fn bind_socket_to_interface(socket: &tokio::net::TcpSocket, _is_ipv6: bool, if_name: &str) -> std::io::Result<()> { + use std::os::unix::io::AsRawFd; + let fd = socket.as_raw_fd(); + let name_bytes = if_name.as_bytes(); + let ret = unsafe { + libc::setsockopt( + fd, + libc::SOL_SOCKET, + libc::SO_BINDTODEVICE, + name_bytes.as_ptr() as *const libc::c_void, + name_bytes.len() as libc::socklen_t, + ) + }; + if ret != 0 { + return Err(std::io::Error::last_os_error()); + } + Ok(()) +} + +#[cfg(target_os = "macos")] +pub fn bind_socket_to_interface(socket: &tokio::net::TcpSocket, _is_ipv6: bool, if_index: u32) -> std::io::Result<()> { + // macOS uses IP_BOUND_IF for IPv4 and IPV6_BOUND_IF for IPv6, similar to Windows + use std::os::unix::io::AsRawFd; + let fd = socket.as_raw_fd(); + + // We can implement this later, for now just a stub so compilation works + tracing::debug!("macOS socket binding not yet fully implemented for interface {}", if_index); + Ok(()) +} + +pub async fn dial_tcp(target_host: &str, target_port: u16, phys_if_idx: Option) -> Result { + let addrs = tokio::net::lookup_host((target_host, target_port)).await?.collect::>(); + if addrs.is_empty() { + return Err(anyhow!("Could not resolve target host: {}", target_host)); + } + + let target_addr = addrs[0]; + let socket = match target_addr { + std::net::SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4()?, + std::net::SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6()?, + }; + + #[cfg(target_os = "windows")] + if let Some(idx) = phys_if_idx { + if let Err(e) = bind_socket_to_interface(&socket, target_addr.is_ipv6(), idx) { + tracing::warn!("DIRECT: Failed to bind to physical interface {}: {}", idx, e); + } + } + + let stream = tokio::time::timeout(std::time::Duration::from_secs(10), socket.connect(target_addr)).await??; + Ok(stream) +} + +pub async fn handle_udp( + _client_src: std::net::SocketAddr, + _target_dst: std::net::SocketAddr, + _payload: bytes::Bytes, + _phys_if_idx: Option, +) -> Result<()> { + Err(anyhow!("Direct UDP is not yet fully implemented")) +} diff --git a/ostp-client/src/tunnel/outbounds/mod.rs b/ostp-client/src/tunnel/outbounds/mod.rs new file mode 100644 index 0000000..a9219cd --- /dev/null +++ b/ostp-client/src/tunnel/outbounds/mod.rs @@ -0,0 +1,78 @@ +use anyhow::{anyhow, Result}; +use std::sync::Arc; +use tokio::net::TcpStream; +use crate::tunnel::balancer::Balancer; +use crate::config::OutboundConfig; + +pub mod direct; +pub mod block; +pub mod ostp; +pub mod socks; + +pub struct OutboundManager { + balancer: Arc, + phys_if_index: Option, + phys_if_name: Option, +} + +impl OutboundManager { + pub fn new( + balancer: Arc, + phys_if_index: Option, + phys_if_name: Option, + ) -> Self { + Self { + balancer, + phys_if_index, + phys_if_name, + } + } + + pub async fn dial_tcp(&self, tag: &str, target_host: &str, target_port: u16) -> Result { + let concrete_config = self.balancer.get_concrete_outbound(tag) + .ok_or_else(|| anyhow!("Outbound tag '{}' not found or resolved to invalid node", tag))?; + + match concrete_config { + OutboundConfig::Direct { .. } => { + direct::dial_tcp(target_host, target_port, self.phys_if_index).await + } + OutboundConfig::Block { .. } => { + block::dial_tcp(target_host, target_port).await + } + OutboundConfig::Ostp { server, port, access_key, transport, multiplex, .. } => { + ostp::dial_tcp(server, *port, access_key, transport, multiplex).await + } + OutboundConfig::Socks { server, port, .. } => { + socks::dial_tcp(target_host, target_port, server, *port).await + } + _ => Err(anyhow!("Invalid concrete outbound type for {}", tag)), + } + } + + pub async fn handle_udp( + &self, + tag: &str, + client_src: std::net::SocketAddr, + target_dst: std::net::SocketAddr, + payload: bytes::Bytes, + ) -> Result<()> { + let concrete_config = self.balancer.get_concrete_outbound(tag) + .ok_or_else(|| anyhow!("Outbound tag '{}' not found or resolved to invalid node", tag))?; + + match concrete_config { + OutboundConfig::Direct { .. } => { + direct::handle_udp(client_src, target_dst, payload, self.phys_if_index).await + } + OutboundConfig::Block { .. } => { + block::handle_udp(client_src, target_dst, payload).await + } + OutboundConfig::Ostp { server, port, access_key, transport, multiplex, .. } => { + ostp::handle_udp(client_src, target_dst, payload, server, *port, access_key, transport, multiplex).await + } + OutboundConfig::Socks { server, port, .. } => { + socks::handle_udp(client_src, target_dst, payload, server, *port).await + } + _ => Err(anyhow!("Invalid concrete outbound type for {}", tag)), + } + } +} diff --git a/ostp-client/src/tunnel/outbounds/ostp.rs b/ostp-client/src/tunnel/outbounds/ostp.rs new file mode 100644 index 0000000..8a796a4 --- /dev/null +++ b/ostp-client/src/tunnel/outbounds/ostp.rs @@ -0,0 +1,28 @@ +use anyhow::{anyhow, Result}; +use tokio::net::TcpStream; +use crate::config::{TransportConfig, MultiplexConfig}; + +pub async fn dial_tcp( + _server: &str, + _port: u16, + _access_key: &str, + _transport: &TransportConfig, + _multiplex: &MultiplexConfig, +) -> Result { + // Ostp dialer implementation. + // For now returning an error until we migrate the local_proxy connection logic here. + Err(anyhow!("OSTP TCP dialer not yet fully migrated")) +} + +pub async fn handle_udp( + _client_src: std::net::SocketAddr, + _target_dst: std::net::SocketAddr, + _payload: bytes::Bytes, + _server: &str, + _port: u16, + _access_key: &str, + _transport: &TransportConfig, + _multiplex: &MultiplexConfig, +) -> Result<()> { + Err(anyhow!("OSTP UDP handler not yet fully migrated")) +} diff --git a/ostp-client/src/tunnel/outbounds/socks.rs b/ostp-client/src/tunnel/outbounds/socks.rs new file mode 100644 index 0000000..6039981 --- /dev/null +++ b/ostp-client/src/tunnel/outbounds/socks.rs @@ -0,0 +1,17 @@ +use anyhow::{anyhow, Result}; +use tokio::net::TcpStream; + +pub async fn dial_tcp(_target_host: &str, _target_port: u16, _server: &str, _port: u16) -> Result { + // SOCKS5 dialer implementation stub + Err(anyhow!("SOCKS outbound TCP dialer not yet implemented")) +} + +pub async fn handle_udp( + _client_src: std::net::SocketAddr, + _target_dst: std::net::SocketAddr, + _payload: bytes::Bytes, + _server: &str, + _port: u16, +) -> Result<()> { + Err(anyhow!("SOCKS outbound UDP handler not yet implemented")) +} diff --git a/ostp-client/src/tunnel/proxy.rs b/ostp-client/src/tunnel/proxy.rs deleted file mode 100644 index 4db81b5..0000000 --- a/ostp-client/src/tunnel/proxy.rs +++ /dev/null @@ -1,921 +0,0 @@ -use std::collections::HashMap; -use crate::tunnel::exclusion::ExclusionMatcher; -use anyhow::{anyhow, Context, Result}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::{TcpListener, TcpStream, UdpSocket}; -use std::sync::Arc; -use tokio::sync::{mpsc, watch}; -use tokio::time::{timeout, Duration}; - -use crate::config::{ExclusionConfig, LocalProxyConfig, OstpConfig}; -use crate::tunnel::{ProxyEvent, ProxyToClientMsg}; - -#[cfg(target_os = "windows")] -use std::os::windows::io::AsRawSocket; - -#[cfg(target_os = "linux")] -use std::os::fd::AsRawFd; - -#[cfg(target_os = "windows")] -#[link(name = "ws2_32")] -extern "system" { - fn setsockopt( - s: usize, - level: i32, - optname: i32, - optval: *const u8, - optlen: i32, - ) -> i32; -} - -#[cfg(target_os = "windows")] -pub fn bind_socket_to_interface(socket: &impl AsRawSocket, is_ipv6: bool, if_index: u32) -> std::io::Result<()> { - let s = socket.as_raw_socket() as usize; - if is_ipv6 { - // IPV6_UNICAST_IF expects interface index in host byte order - let optval = if_index; - let ret = unsafe { - setsockopt( - s, - 41, // IPPROTO_IPV6 - 31, // IPV6_UNICAST_IF - &optval as *const u32 as *const u8, - 4, - ) - }; - if ret != 0 { - return Err(std::io::Error::last_os_error()); - } - } else { - // IP_UNICAST_IF expects interface index in NETWORK byte order (big-endian) - let optval = if_index.to_be(); - let ret = unsafe { - setsockopt( - s, - 0, // IPPROTO_IP - 31, // IP_UNICAST_IF - &optval as *const u32 as *const u8, - 4, - ) - }; - if ret != 0 { - return Err(std::io::Error::last_os_error()); - } - } - Ok(()) -} - -#[cfg(target_os = "linux")] -pub fn bind_socket_to_interface(socket: &impl AsRawFd, if_name: &str) -> std::io::Result<()> { - let fd = socket.as_raw_fd(); - let mut if_name_bytes = if_name.as_bytes().to_vec(); - if_name_bytes.push(0); - let ret = unsafe { - libc::setsockopt( - fd, - libc::SOL_SOCKET, - libc::SO_BINDTODEVICE, - if_name_bytes.as_ptr() as *const std::ffi::c_void, - if_name_bytes.len() as libc::socklen_t, - ) - }; - if ret != 0 { - return Err(std::io::Error::last_os_error()); - } - Ok(()) -} - -pub fn get_windows_physical_if_index() -> Option { - #[cfg(target_os = "windows")] - { - return ostp_tun::windows::windows_route::sys::get_default_ipv4_route().map(|(_, idx)| idx); - } - #[cfg(not(target_os = "windows"))] - { - None - } -} - -pub fn get_linux_physical_if_name() -> Option { - #[cfg(target_os = "linux")] - { - let output = std::process::Command::new("ip") - .args(["route", "show", "default"]) - .output() - .ok()?; - if output.status.success() { - let s = String::from_utf8_lossy(&output.stdout); - if let Some(dev_part) = s.split_whitespace().skip_while(|w| *w != "dev").nth(1) { - return Some(dev_part.to_string()); - } - } - } - None -} - -#[allow(unused_variables)] -async fn connect_bypassing_tun( - target: &str, - physical_if_index: Option, - _physical_if_name: &Option, -) -> Result { - let resolved = tokio::net::lookup_host(target).await - .with_context(|| format!("failed to resolve host for bypass connect: {target}"))?; - - let mut last_err = None; - for addr in resolved { - let socket = if addr.is_ipv6() { - let s = tokio::net::TcpSocket::new_v6()?; - let _ = s.bind("[::]:0".parse().unwrap()); - s - } else { - let s = tokio::net::TcpSocket::new_v4()?; - let _ = s.bind("0.0.0.0:0".parse().unwrap()); - s - }; - - #[cfg(target_os = "windows")] - if let Some(if_index) = physical_if_index { - if let Err(e) = bind_socket_to_interface(&socket, addr.is_ipv6(), if_index) { - tracing::warn!("Failed to bind TCP socket to interface {}: {}", if_index, e); - } - } - - #[cfg(target_os = "linux")] - if let Some(ref if_name) = _physical_if_name { - if let Err(e) = bind_socket_to_interface(&socket, if_name) { - tracing::warn!("Failed to bind TCP socket to interface {}: {}", if_name, e); - } - } - - match socket.connect(addr).await { - Ok(stream) => return Ok(stream), - Err(e) => { - last_err = Some(e); - } - } - } - - Err(anyhow!( - "direct connect failed: {:?}", - last_err.map(|e| e.to_string()).unwrap_or_else(|| "no addresses resolved".to_string()) - )) -} - -#[allow(unused_variables)] -async fn create_udp_socket_bypassing_tun( - is_ipv6: bool, - physical_if_index: Option, - _physical_if_name: &Option, -) -> Result { - let addr: std::net::SocketAddr = if is_ipv6 { - "[::]:0".parse().unwrap() - } else { - "0.0.0.0:0".parse().unwrap() - }; - - let socket = UdpSocket::bind(addr).await - .with_context(|| format!("failed to bind direct UdpSocket to wildcard {}", addr))?; - - #[cfg(target_os = "windows")] - if let Some(if_index) = physical_if_index { - if let Err(e) = bind_socket_to_interface(&socket, is_ipv6, if_index) { - tracing::warn!("Failed to bind UDP socket to interface index {}: {}", if_index, e); - } - } - - #[cfg(target_os = "linux")] - if let Some(ref if_name) = _physical_if_name { - if let Err(e) = bind_socket_to_interface(&socket, if_name) { - tracing::warn!("Failed to bind UDP socket to interface {}: {}", if_name, e); - } - } - - Ok(socket) -} - -pub async fn run_local_socks5_proxy( - cfg: LocalProxyConfig, - ostp: OstpConfig, - mut exclusions_rx: watch::Receiver, - debug: bool, - mut shutdown: watch::Receiver, - proxy_events_tx: mpsc::Sender, - mut client_msgs_rx: mpsc::UnboundedReceiver<(u16, ProxyToClientMsg)>, -) -> Result<()> { - let connect_timeout = Duration::from_millis(cfg.connect_timeout_ms.max(1)); - let listener = TcpListener::bind(&cfg.bind_addr) - .await - .with_context(|| format!("failed to bind local HTTP/SOCKS5 proxy at {}", cfg.bind_addr))?; - - if true { - tracing::info!("local HTTP/SOCKS5 proxy listening at {}", cfg.bind_addr); - tracing::info!("Windows system proxy: set HTTP proxy to {}. tun2socks: SOCKS5 on same address.", cfg.bind_addr); - } - - let physical_if_index = tokio::task::spawn_blocking(get_windows_physical_if_index).await.unwrap_or(None); - let physical_if_name = tokio::task::spawn_blocking(get_linux_physical_if_name).await.unwrap_or(None); - - if physical_if_index.is_some() { - tracing::info!("Local proxy physical interface index: {:?}", physical_if_index); - } - if physical_if_name.is_some() { - tracing::info!("Local proxy physical interface name: {:?}", physical_if_name); - } - - let mut current_exclusions = exclusions_rx.borrow().clone(); - let mut matcher = ExclusionMatcher::new(¤t_exclusions, physical_if_index, physical_if_name.clone()); - let (connect_tx, mut connect_rx) = mpsc::channel(128); - let max_chunk = ostp.mtu.saturating_sub(150).max(512); - - let mut next_stream_id: u16 = 1; - let mut active_streams: HashMap> = HashMap::new(); - - loop { - tokio::select! { - _ = shutdown.changed() => { - if *shutdown.borrow() { - break; - } - } - Ok(_) = exclusions_rx.changed() => { - current_exclusions = exclusions_rx.borrow().clone(); - matcher = ExclusionMatcher::new(¤t_exclusions, physical_if_index, physical_if_name.clone()); - if true { - tracing::info!("Local proxy exclusions hot-reloaded"); - } - } - accepted = listener.accept() => { - let (socket, _) = accepted?; - let stream_id = next_stream_id; - // Advance, skipping zero and any stream_id still in active_streams - loop { - next_stream_id = next_stream_id.wrapping_add(1); - if next_stream_id == 0 { next_stream_id = 1; } - if !active_streams.contains_key(&next_stream_id) { break; } - } - - let (tx, rx) = mpsc::unbounded_channel(); - active_streams.insert(stream_id, tx); - - let event_tx = proxy_events_tx.clone(); - let c_tx = connect_tx.clone(); - let matcher_clone = matcher.clone(); - tokio::spawn(async move { - if let Err(err) = handle_proxy_client( - socket, - stream_id, - event_tx, - rx, - c_tx, - connect_timeout, - debug, - matcher_clone, - max_chunk, - ).await { - let msg = err.to_string(); - // Suppress routine disconnects and unsupported SOCKS5 command attempts (like UDP) from spam logs - if !msg.contains("UnexpectedEof") - && !msg.contains("Connection reset") - && !msg.contains("Broken pipe") - && !msg.contains("unsupported SOCKS5 command") - && debug { - tracing::warn!("proxy client error: {err}"); - } - } - }); - } - Some((stream_id, msg)) = client_msgs_rx.recv() => { - if stream_id == 0 { - if let ProxyToClientMsg::Close = msg { - if true { - tracing::info!("Resetting all active proxy streams on reconnect"); - } - for (_, tx) in active_streams.drain() { - let _ = tx.send(ProxyToClientMsg::Close); - } - } - } else if let Some(tx) = active_streams.get(&stream_id) { - if tx.send(msg).is_err() { - active_streams.remove(&stream_id); - } - } - } - Some(stream_id) = connect_rx.recv() => { - active_streams.remove(&stream_id); - } - } - } - - Ok(()) -} - -/// Extracts `host:port` from an HTTP absolute-URI like `http://example.com/path` or `https://example.com`. -/// Falls back to the raw target if already in `host:port` form. -fn extract_host_port(uri: &str, default_port: u16) -> String { - let without_scheme = if let Some(rest) = uri.strip_prefix("https://") { - rest - } else if let Some(rest) = uri.strip_prefix("http://") { - rest - } else { - uri - }; - // Trim path/query fragment - let host_part = without_scheme.split('/').next().unwrap_or(without_scheme); - if host_part.contains(':') { - host_part.to_string() - } else { - format!("{}:{}", host_part, default_port) - } -} - -struct StreamGuard { - stream_id: u16, - close_tx: mpsc::Sender, -} - -impl Drop for StreamGuard { - fn drop(&mut self) { - let tx = self.close_tx.clone(); - let id = self.stream_id; - tokio::spawn(async move { - let _ = tx.send(id).await; - }); - } -} - -async fn handle_udp_associate( - mut client_tcp: TcpStream, - udp_socket: tokio::net::UdpSocket, - stream_id: u16, - event_tx: mpsc::Sender, - mut rx: mpsc::UnboundedReceiver, - close_tx: mpsc::Sender, - debug: bool, - matcher: ExclusionMatcher, - connect_timeout: Duration, -) -> Result<()> { - let client_udp_addr = Arc::new(std::sync::Mutex::new(None)); - let mut buf = vec![0u8; 65536]; - - let udp_socket = Arc::new(udp_socket); - let sock_rx = udp_socket.clone(); - let sock_tx = udp_socket; - - let mut direct_udp_v4: Option> = None; - let mut direct_udp_v6: Option> = None; - - let mut tcp_buf = [0u8; 1]; - loop { - tokio::select! { - res = client_tcp.read(&mut tcp_buf) => { - match res { - Ok(0) | Err(_) => break, - Ok(_) => {} - } - } - res = sock_rx.recv_from(&mut buf) => { - let (len, addr) = match res { - Ok(v) => v, - Err(e) => { - tracing::debug!("udp_associate recv_from error: {}", e); - continue; // transient error, don't kill the session - } - }; - { - let mut guard = client_udp_addr.lock().unwrap(); - if guard.is_none() { - *guard = Some(addr); - } - } - if len < 4 { continue; } - let frag = buf[2]; - if frag != 0 { continue; } // Fragmented UDP not supported - let atyp = buf[3]; - let (header_len, target) = match atyp { - 0x01 => { - if len < 10 { continue; } - let ip = std::net::Ipv4Addr::new(buf[4], buf[5], buf[6], buf[7]); - let port = u16::from_be_bytes([buf[8], buf[9]]); - (10, format!("{}:{}", ip, port)) - } - 0x03 => { - if len < 5 { continue; } - let domain_len = buf[4] as usize; - if len < 5 + domain_len + 2 { continue; } - let domain = String::from_utf8_lossy(&buf[5..5+domain_len]); - let port = u16::from_be_bytes([buf[5+domain_len], buf[5+domain_len+1]]); - (5 + domain_len + 2, format!("{}:{}", domain, port)) - } - 0x04 => { - if len < 22 { continue; } - let mut octets = [0u8; 16]; - octets.copy_from_slice(&buf[4..20]); - let ip = std::net::Ipv6Addr::from(octets); - let port = u16::from_be_bytes([buf[20], buf[21]]); - (22, format!("[{}]:{}", ip, port)) - } - _ => continue, - }; - let payload = bytes::Bytes::copy_from_slice(&buf[header_len..len]); - - let target_host = if let Some((host, _)) = split_host_port(&target) { host } else { target.clone() }; - let target_port = match split_host_port(&target) { Some((_, p)) => p, None => 0 }; - // Check if target should bypass the tunnel - if matcher.should_bypass_target(&target_host, target_port, connect_timeout).await { - if true { - tracing::debug!("proxy UDP BYPASS target={}", target); - } - // Resolve target to find if it is IPv4 or IPv6 - if let Ok(resolved_addrs) = tokio::net::lookup_host(&target).await { - if let Some(target_addr) = resolved_addrs.into_iter().next() { - let is_ipv6 = target_addr.is_ipv6(); - let direct_socket = if is_ipv6 { - if direct_udp_v6.is_none() { - match create_udp_socket_bypassing_tun(true, matcher.physical_if_index, &matcher.physical_if_name).await { - Ok(s) => { - let s_arc = Arc::new(s); - spawn_direct_udp_reader(s_arc.clone(), sock_tx.clone(), client_udp_addr.clone(), debug); - direct_udp_v6 = Some(s_arc); - } - Err(e) => { - tracing::error!("Failed to create bypass UDP v6 socket: {}", e); - } - } - } - &direct_udp_v6 - } else { - if direct_udp_v4.is_none() { - match create_udp_socket_bypassing_tun(false, matcher.physical_if_index, &matcher.physical_if_name).await { - Ok(s) => { - let s_arc = Arc::new(s); - spawn_direct_udp_reader(s_arc.clone(), sock_tx.clone(), client_udp_addr.clone(), debug); - direct_udp_v4 = Some(s_arc); - } - Err(e) => { - tracing::error!("Failed to create bypass UDP v4 socket: {}", e); - } - } - } - &direct_udp_v4 - }; - - if let Some(s) = direct_socket { - if let Err(e) = s.send_to(&payload, target_addr).await { - if true { - tracing::warn!("failed to send bypass UDP packet to {}: {}", target_addr, e); - } - } - } - } - } - } else { - tracing::debug!("proxy.rs forwarding UDP DATA to server for target={} payload len={}", target, payload.len()); - let _ = event_tx.send(ProxyEvent::UdpData { stream_id, target, payload }).await; - } - } - msg = rx.recv() => { - match msg { - Some(ProxyToClientMsg::UdpData(target, data)) => { - if let Some(client_addr) = { - let guard = client_udp_addr.lock().unwrap(); - *guard - } { - let mut packet = vec![0x00, 0x00, 0x00]; - let mut parts = target.rsplitn(2, ':'); - let port_str = parts.next().unwrap_or("0"); - let host_str = parts.next().unwrap_or(&target); - let host_str = host_str.trim_start_matches('[').trim_end_matches(']'); - let port = port_str.parse::().unwrap_or(0); - - if let Ok(ipv4) = host_str.parse::() { - packet.push(0x01); - packet.extend_from_slice(&ipv4.octets()); - } else if let Ok(ipv6) = host_str.parse::() { - packet.push(0x04); - packet.extend_from_slice(&ipv6.octets()); - } else { - packet.push(0x03); - let bytes = host_str.as_bytes(); - packet.push(bytes.len() as u8); - packet.extend_from_slice(bytes); - } - packet.extend_from_slice(&port.to_be_bytes()); - packet.extend_from_slice(&data); - tracing::debug!("proxy.rs forwarding UDP REPLY to client_addr={} from server for target={} payload len={}", client_addr, target, data.len()); - let _ = sock_tx.send_to(&packet, client_addr).await; - } else { - tracing::error!("proxy.rs failed to parse target string as SocketAddr: {}", target); - } - } - Some(ProxyToClientMsg::Close) | Some(ProxyToClientMsg::Error(_)) | None => break, - _ => {} - } - } - } - } - let _ = close_tx.send(stream_id).await; - Ok(()) -} - -fn spawn_direct_udp_reader( - direct_socket: Arc, - sock_tx: Arc, - client_udp_addr: Arc>>, - _debug: bool, -) { - tokio::spawn(async move { - let mut buf = vec![0u8; 65536]; - loop { - match direct_socket.recv_from(&mut buf).await { - Ok((len, target_addr)) => { - let client_addr = { - let guard = client_udp_addr.lock().unwrap(); - *guard - }; - if let Some(client_addr) = client_addr { - let mut packet = vec![0x00, 0x00, 0x00]; - if let Ok(ipv4) = target_addr.ip().to_string().parse::() { - packet.push(0x01); - packet.extend_from_slice(&ipv4.octets()); - } else if let Ok(ipv6) = target_addr.ip().to_string().parse::() { - packet.push(0x04); - packet.extend_from_slice(&ipv6.octets()); - } else { - continue; - } - packet.extend_from_slice(&target_addr.port().to_be_bytes()); - packet.extend_from_slice(&buf[..len]); - if let Err(e) = sock_tx.send_to(&packet, client_addr).await { - if true { - tracing::warn!("failed to send direct UDP response to client: {e}"); - } - } - } - } - Err(e) => { - if true { - tracing::debug!("direct UDP socket read loop exiting: {e}"); - } - break; - } - } - } - }); -} - -async fn handle_proxy_client( - mut client: TcpStream, - stream_id: u16, - event_tx: mpsc::Sender, - mut rx: mpsc::UnboundedReceiver, - close_tx: mpsc::Sender, - connect_timeout: Duration, - debug: bool, - matcher: ExclusionMatcher, - max_chunk: usize, -) -> Result<()> { - let _guard = StreamGuard { stream_id, close_tx: close_tx.clone() }; - - // Peek the first byte to distinguish SOCKS5 (0x05) from HTTP (any printable ASCII) - let mut first_byte = [0_u8; 1]; - client.read_exact(&mut first_byte).await?; - - let target: String; - let is_socks5 = first_byte[0] == 0x05; - - if is_socks5 { - // ── SOCKS5 Handshake ────────────────────────────────────────── - let mut second_byte = [0_u8; 1]; - client.read_exact(&mut second_byte).await?; - let nmethods = second_byte[0] as usize; - if nmethods > 0 { - let mut methods_buf = vec![0_u8; nmethods]; - client.read_exact(&mut methods_buf).await?; - } - // Reply: version=5, NO AUTHENTICATION - client.write_all(&[0x05, 0x00]).await?; - - // ── SOCKS5 Request ──────────────────────────────────────────── - let mut req = [0_u8; 4]; - client.read_exact(&mut req).await?; - if req[0] != 0x05 { - return Err(anyhow!("SOCKS5 request version mismatch")); - } - - let is_udp = req[1] == 0x03; - if req[1] != 0x01 && !is_udp { - // Not CONNECT and Not UDP ASSOCIATE — send COMMAND NOT SUPPORTED - client.write_all(&[0x05, 0x07, 0x00, 0x01, 0, 0, 0, 0, 0, 0]).await?; - return Err(anyhow!("unsupported SOCKS5 command {}", req[1])); - } - - let mut addr_buf = [0_u8; 256]; - target = match req[3] { - 0x01 => { - // IPv4: 4 bytes address + 2 bytes port - client.read_exact(&mut addr_buf[0..6]).await?; - let ip = std::net::Ipv4Addr::new(addr_buf[0], addr_buf[1], addr_buf[2], addr_buf[3]); - let port = u16::from_be_bytes([addr_buf[4], addr_buf[5]]); - format!("{}:{}", ip, port) - } - 0x03 => { - // Domain: 1 byte length, then domain, then 2 bytes port - client.read_exact(&mut addr_buf[0..1]).await?; - let domain_len = addr_buf[0] as usize; - client.read_exact(&mut addr_buf[0..domain_len + 2]).await?; - let domain = String::from_utf8_lossy(&addr_buf[0..domain_len]); - let port = u16::from_be_bytes([addr_buf[domain_len], addr_buf[domain_len + 1]]); - format!("{}:{}", domain, port) - } - 0x04 => { - // IPv6: 16 bytes + 2 bytes port - client.read_exact(&mut addr_buf[0..18]).await?; - let mut octets = [0u8; 16]; - octets.copy_from_slice(&addr_buf[0..16]); - let ip = std::net::Ipv6Addr::from(octets); - let port = u16::from_be_bytes([addr_buf[16], addr_buf[17]]); - format!("[{}]:{}", ip, port) - } - atyp => { - client.write_all(&[0x05, 0x08, 0x00, 0x01, 0, 0, 0, 0, 0, 0]).await?; - return Err(anyhow!("unsupported SOCKS5 address type: {}", atyp)); - } - }; - - if is_udp { - if true { tracing::debug!("proxy UDP ASSOCIATE stream_id={stream_id}"); } - let udp_socket = UdpSocket::bind("127.0.0.1:0").await?; - let port = udp_socket.local_addr()?.port(); - let mut reply = vec![0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1]; - reply.extend_from_slice(&port.to_be_bytes()); - client.write_all(&reply).await?; - - event_tx.send(ProxyEvent::UdpAssociate { stream_id }).await?; - return handle_udp_associate( - client, - udp_socket, - stream_id, - event_tx, - rx, - close_tx, - debug, - matcher, - connect_timeout, - ).await; - } - - tracing::debug!("proxy CONNECT stream_id={stream_id} target={target}"); - let target_host = if let Some((host, _)) = split_host_port(&target) { host } else { target.clone() }; - let target_port = match split_host_port(&target) { Some((_, p)) => p, None => 0 }; - if matcher.should_bypass_target(&target_host, target_port, connect_timeout).await { - return direct_connect_socks5( - client, - stream_id, - &target, - matcher.physical_if_index, - &matcher.physical_if_name, - close_tx, - debug, - ).await; - } - event_tx.send(ProxyEvent::NewStream { stream_id, target: target.clone() }).await?; - - match timeout(connect_timeout, rx.recv()).await { - Ok(Some(ProxyToClientMsg::ConnectOk)) => { - // SUCCESS: version, 0=success, reserved, IPv4 type, 4 bytes addr, 2 bytes port - client.write_all(&[0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0]).await?; - } - Ok(Some(ProxyToClientMsg::Error(msg))) => { - client.write_all(&[0x05, 0x04, 0x00, 0x01, 0, 0, 0, 0, 0, 0]).await?; - let _ = close_tx.send(stream_id).await; - return Err(anyhow!("SOCKS5 connect error: {msg}")); - } - Ok(_) => { - client.write_all(&[0x05, 0x05, 0x00, 0x01, 0, 0, 0, 0, 0, 0]).await?; - let _ = close_tx.send(stream_id).await; - return Err(anyhow!("connect dropped")); - } - Err(_) => { - client.write_all(&[0x05, 0x04, 0x00, 0x01, 0, 0, 0, 0, 0, 0]).await?; - let _ = close_tx.send(stream_id).await; - return Err(anyhow!("connect timeout")); - } - } - } else { - // ── HTTP Proxy (CONNECT and plain GET/POST) ─────────────────── - // Read the rest of the HTTP request headers byte-by-byte - let mut header_bytes = Vec::with_capacity(512); - header_bytes.push(first_byte[0]); - let mut chunk = [0_u8; 512]; - loop { - let n = client.read(&mut chunk).await?; - if n == 0 { - return Err(anyhow!("connection closed during HTTP header read")); - } - header_bytes.extend_from_slice(&chunk[..n]); - if header_bytes.len() >= 4 { - let tail = &header_bytes[header_bytes.len().saturating_sub(4)..]; - if tail.ends_with(b"\r\n\r\n") { - break; - } - } - if header_bytes.len() > 8192 { - client.write_all(b"HTTP/1.1 431 Request Header Fields Too Large\r\n\r\n").await?; - return Err(anyhow!("HTTP header too large")); - } - } - - let req_str = String::from_utf8_lossy(&header_bytes); - let first_line = req_str.lines().next().unwrap_or(""); - let parts: Vec<&str> = first_line.split_whitespace().collect(); - if parts.len() < 2 { - client.write_all(b"HTTP/1.1 400 Bad Request\r\n\r\n").await?; - return Err(anyhow!("malformed HTTP request line: {:?}", first_line)); - } - - let method = parts[0].to_uppercase(); - let raw_uri = parts[1]; - - target = if method == "CONNECT" { - // CONNECT uses host:port directly — e.g. "CONNECT example.com:443 HTTP/1.1" - if raw_uri.contains(':') { - raw_uri.to_string() - } else { - format!("{}:443", raw_uri) - } - } else { - // Plain HTTP: absolute URI like "GET http://example.com/path HTTP/1.1" - let default_port = if raw_uri.starts_with("https://") { 443u16 } else { 80u16 }; - extract_host_port(raw_uri, default_port) - }; - - if true { - tracing::info!("proxy CONNECT stream_id={stream_id} target={target}"); - } - let target_host = if let Some((host, _)) = split_host_port(&target) { host } else { target.clone() }; - let target_port = match split_host_port(&target) { Some((_, p)) => p, None => 443 }; - if matcher.should_bypass_target(&target_host, target_port, connect_timeout).await { - return direct_connect_http( - client, - stream_id, - &target, - method.as_str(), - header_bytes, - matcher.physical_if_index, - &matcher.physical_if_name, - close_tx, - debug, - ).await; - } - event_tx.send(ProxyEvent::NewStream { stream_id, target: target.clone() }).await?; - - match timeout(connect_timeout, rx.recv()).await { - Ok(Some(ProxyToClientMsg::ConnectOk)) => { - if method == "CONNECT" { - // For CONNECT, tell client the tunnel is ready - client.write_all(b"HTTP/1.1 200 Connection Established\r\nProxy-Agent: ostp/1.0\r\n\r\n").await?; - } else { - // For plain HTTP (GET/POST), we MUST forward the request headers we consumed - // to the server over the newly established tunnel. - event_tx.send(ProxyEvent::Data { - stream_id, - payload: bytes::Bytes::copy_from_slice(&header_bytes), - }).await?; - } - } - Ok(Some(ProxyToClientMsg::Error(msg))) => { - client.write_all(b"HTTP/1.1 502 Bad Gateway\r\n\r\n").await?; - let _ = close_tx.send(stream_id).await; - return Err(anyhow!("HTTP connect error: {msg}")); - } - Ok(_) => { - client.write_all(b"HTTP/1.1 502 Bad Gateway\r\n\r\n").await?; - let _ = close_tx.send(stream_id).await; - return Err(anyhow!("connect dropped")); - } - Err(_) => { - client.write_all(b"HTTP/1.1 504 Gateway Timeout\r\n\r\n").await?; - let _ = close_tx.send(stream_id).await; - return Err(anyhow!("connect timeout")); - } - } - } - - // ── Bidirectional raw data forwarding ───────────────────────────── - let mut tcp_buf = vec![0_u8; 65536]; - loop { - tokio::select! { - read_res = client.read(&mut tcp_buf) => { - match read_res { - Ok(0) => { - let _ = event_tx.send(ProxyEvent::Close { stream_id }).await; - if true { - tracing::info!("proxy CLOSE stream_id={stream_id}"); - } - break; - } - Ok(n) => { - let mut offset = 0; - while offset < n { - let end = (offset + max_chunk).min(n); - let _ = event_tx.send(ProxyEvent::Data { - stream_id, - payload: bytes::Bytes::copy_from_slice(&tcp_buf[offset..end]), - }).await; - offset = end; - } - } - Err(_) => { - let _ = event_tx.send(ProxyEvent::Close { stream_id }).await; - if true { - tracing::info!("proxy CLOSE stream_id={stream_id}"); - } - break; - } - } - } - msg = rx.recv() => { - match msg { - Some(ProxyToClientMsg::Data(data)) => { - if client.write_all(&data).await.is_err() { - let _ = event_tx.send(ProxyEvent::Close { stream_id }).await; - break; - } - } - Some(ProxyToClientMsg::Close) | Some(ProxyToClientMsg::Error(_)) | None => { - break; - } - Some(ProxyToClientMsg::ConnectOk) | Some(ProxyToClientMsg::UdpData(_, _)) => {} // ignored after connect phase - } - } - } - } - - let _ = close_tx.send(stream_id).await; - Ok(()) -} - - -fn split_host_port(target: &str) -> Option<(String, u16)> { - if let Some((host, port)) = target.rsplit_once(':') { - if host.starts_with('[') && host.ends_with(']') { - let host = host.trim_start_matches('[').trim_end_matches(']').to_string(); - let port = port.parse().ok()?; - return Some((host, port)); - } - if host.contains(':') { - return None; - } - let port = port.parse().ok()?; - return Some((host.to_string(), port)); - } - None -} - -async fn direct_connect_socks5( - mut client: TcpStream, - stream_id: u16, - target: &str, - physical_if_index: Option, - physical_if_name: &Option, - close_tx: mpsc::Sender, - _debug: bool, -) -> Result<()> { - if true { - tracing::info!("proxy BYPASS stream_id={stream_id} target={target}"); - } - let mut remote = connect_bypassing_tun(target, physical_if_index, physical_if_name).await?; - - client.write_all(&[0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0]).await?; - let _ = tokio::io::copy_bidirectional(&mut client, &mut remote).await; - let _ = close_tx.send(stream_id).await; - Ok(()) -} - -async fn direct_connect_http( - mut client: TcpStream, - stream_id: u16, - target: &str, - method: &str, - header_bytes: Vec, - physical_if_index: Option, - physical_if_name: &Option, - close_tx: mpsc::Sender, - _debug: bool, -) -> Result<()> { - if true { - tracing::info!("proxy BYPASS stream_id={stream_id} target={target}"); - } - let mut remote = connect_bypassing_tun(target, physical_if_index, physical_if_name).await?; - - if method == "CONNECT" { - client.write_all(b"HTTP/1.1 200 Connection Established\r\nProxy-Agent: ostp/1.0\r\n\r\n").await?; - } else { - remote.write_all(&header_bytes).await?; - } - - let _ = tokio::io::copy_bidirectional(&mut client, &mut remote).await; - let _ = close_tx.send(stream_id).await; - Ok(()) -} diff --git a/ostp-client/src/tunnel/router.rs b/ostp-client/src/tunnel/router.rs new file mode 100644 index 0000000..dc16a24 --- /dev/null +++ b/ostp-client/src/tunnel/router.rs @@ -0,0 +1,155 @@ +use std::net::IpAddr; +use crate::config::{RoutingConfig, RoutingRule}; + +#[derive(Debug, Clone)] +pub struct Session { + pub inbound_tag: String, + pub source_ip: Option, + pub destination_ip: Option, + pub destination_port: u16, + pub protocol: String, // "tcp" or "udp" + pub sni: Option, + pub process_name: Option, +} + +pub struct Router { + config: RoutingConfig, +} + +impl Router { + pub fn new(config: RoutingConfig) -> Self { + Self { config } + } + + /// Evaluates the session against routing rules and returns the outbound tag + pub fn route(&self, session: &Session) -> String { + for rule in &self.config.rules { + if self.match_rule(rule, session) { + return rule.outbound.clone(); + } + } + self.config.default_outbound.clone() + } + + fn match_rule(&self, rule: &RoutingRule, session: &Session) -> bool { + // All specified conditions in a rule must match (AND logic) + let mut matched_any_condition = false; + + // 1. Inbound Tag match + if let Some(inbounds) = &rule.inbound_tag { + if !inbounds.iter().any(|tag| tag == &session.inbound_tag) { + return false; + } + matched_any_condition = true; + } + + // 2. Domain / SNI match + if let Some(domains) = &rule.domain_suffix { + let mut domain_match = false; + if let Some(sni) = &session.sni { + let sni = sni.to_lowercase(); + domain_match = domains.iter().any(|d| { + let d = d.to_lowercase(); + sni == d || sni.ends_with(&format!(".{}", d)) + }); + } + if !domain_match { + return false; + } + matched_any_condition = true; + } + + // 3. Process match + if let Some(processes) = &rule.process_name { + let mut proc_match = false; + if let Some(proc) = &session.process_name { + let proc = proc.to_lowercase(); + proc_match = processes.iter().any(|p| proc.contains(&p.to_lowercase())); + } + if !proc_match { + return false; + } + matched_any_condition = true; + } + + // 4. IP CIDR match + if let Some(cidrs) = &rule.ip_cidr { + let mut ip_match = false; + if let Some(dst_ip) = session.destination_ip { + ip_match = cidrs.iter().any(|cidr| { + match ipnet::IpNet::from_str(cidr) { + Ok(net) => net.contains(&dst_ip), + Err(_) => { + // fallback to exact ip match if not a valid CIDR + if let Ok(ip) = cidr.parse::() { + ip == dst_ip + } else { + false + } + } + } + }); + } + if !ip_match { + return false; + } + matched_any_condition = true; + } + + // A rule must have at least one condition to match + matched_any_condition + } +} + +use std::str::FromStr; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_router() { + let rules = vec![ + RoutingRule { + domain_suffix: Some(vec!["vk.com".to_string()]), + ip_cidr: None, + process_name: None, + inbound_tag: None, + outbound: "direct".to_string(), + }, + RoutingRule { + domain_suffix: None, + ip_cidr: None, + process_name: Some(vec!["telegram.exe".to_string()]), + inbound_tag: None, + outbound: "proxy-group".to_string(), + }, + ]; + + let config = RoutingConfig { + rules, + default_outbound: "proxy-group".to_string(), + }; + + let router = Router::new(config); + + let mut session = Session { + inbound_tag: "tun-in".to_string(), + source_ip: None, + destination_ip: None, + destination_port: 443, + protocol: "tcp".to_string(), + sni: Some("api.vk.com".to_string()), + process_name: None, + }; + + assert_eq!(router.route(&session), "direct"); + + session.sni = None; + session.process_name = Some("C:\\App\\Telegram.exe".to_string()); + assert_eq!(router.route(&session), "proxy-group"); + + session.process_name = Some("chrome.exe".to_string()); + assert_eq!(router.route(&session), "proxy-group"); // fallback + } +} diff --git a/ostp-client/src/tunnel/udp_nat.rs b/ostp-client/src/tunnel/udp_nat.rs index 93ab9c2..4c67146 100644 --- a/ostp-client/src/tunnel/udp_nat.rs +++ b/ostp-client/src/tunnel/udp_nat.rs @@ -1,306 +1 @@ -use std::collections::HashMap; -use std::net::SocketAddr; -use std::sync::Arc; -use tokio::sync::{mpsc, Mutex}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::{TcpStream, UdpSocket}; -use futures::StreamExt; - -pub async fn run_udp_nat( - udp_socket: netstack_smoltcp::UdpSocket, - proxy_addr: String, - debug: bool, - matcher: std::sync::Arc>, - phys_if_index: Option, - phys_if_name: Option, -) { - let (mut rx, tx) = udp_socket.split(); - let tx = Arc::new(Mutex::new(tx)); - - // map from internal client src to a channel that sends (payload, external_dst) - let mut sessions: HashMap, SocketAddr)>> = HashMap::new(); - - let mut cleanup_tick = tokio::time::interval(std::time::Duration::from_secs(60)); - - loop { - tokio::select! { - packet = rx.next() => { - match packet { - Some((payload, src, dst)) => { - if payload.is_empty() { continue; } - - if !sessions.contains_key(&src) { - let (session_tx, mut session_rx) = mpsc::channel::<(Vec, SocketAddr)>(1024); - sessions.insert(src, session_tx); - - let proxy_addr_clone = proxy_addr.clone(); - let tx_clone = tx.clone(); - - let mut should_bypass = false; - { - let matcher_guard = matcher.read().await; - if matcher_guard.match_ip(&dst.ip()) { - should_bypass = true; - if debug { - tracing::info!("TUN UDP BYPASS (IP match): {} → {}", src, dst); - } - } - - #[cfg(target_os = "windows")] - if !should_bypass { - if let Some(proc_name) = crate::tunnel::process_lookup::get_process_name_from_port_udp(src.port()) { - if debug { - tracing::debug!("TUN UDP lookup: port {} -> process {}", src.port(), proc_name); - } - if matcher_guard.match_process(&proc_name) { - should_bypass = true; - if debug { - tracing::debug!("TUN UDP BYPASS (Process match): {} ({} → {})", proc_name, src, dst); - } - } - } else { - if debug { - tracing::debug!("TUN UDP lookup: port {} -> no process found", src.port()); - } - } - } - } - - let p_if_idx = phys_if_index; - let p_if_name = phys_if_name.clone(); - - tokio::spawn(async move { - if should_bypass { - if debug { - tracing::info!("Starting UDP BYPASS session for {}", src); - } - let res = start_udp_bypass_session(src, p_if_idx, p_if_name, &mut session_rx, tx_clone).await; - if res.is_err() { - tracing::debug!("UDP BYPASS session for {} ended: {:?}", src, res.err()); - } - } else { - tracing::debug!("Starting UDP NAT session for {}", src); - let res = start_udp_session(src, proxy_addr_clone, &mut session_rx, tx_clone).await; - if res.is_err() { - tracing::debug!("UDP NAT session for {} ended: {:?}", src, res.err()); - } - } - }); - } - - if let Some(sender) = sessions.get(&src) { - match sender.try_send((payload, dst)) { - Err(mpsc::error::TrySendError::Closed(_)) => { - sessions.remove(&src); - } - Err(mpsc::error::TrySendError::Full(_)) => { - // Drop packet to avoid blocking the TUN interface loop - } - Ok(_) => {} - } - } - } - None => break, - } - } - _ = cleanup_tick.tick() => { - sessions.retain(|_, sender| !sender.is_closed()); - } - } - } -} - -async fn start_udp_bypass_session( - client_src: SocketAddr, - phys_if_index: Option, - phys_if_name: Option, - session_rx: &mut mpsc::Receiver<(Vec, SocketAddr)>, - smoltcp_tx: Arc>, -) -> anyhow::Result<()> { - let socket = match client_src { - SocketAddr::V4(_) => UdpSocket::bind("0.0.0.0:0").await?, - SocketAddr::V6(_) => UdpSocket::bind("[::]:0").await?, - }; - - #[cfg(target_os = "windows")] - if let Some(idx) = phys_if_index { - if let Err(e) = crate::tunnel::proxy::bind_socket_to_interface(&socket, client_src.is_ipv6(), idx) { - tracing::error!("TUN UDP BYPASS failed to bind to physical interface {}: {}", idx, e); - } else { - // Keep debug log - } - } else { - tracing::warn!("TUN UDP BYPASS has no physical interface index!"); - } - - #[cfg(target_os = "linux")] - if let Some(ref name) = phys_if_name { - let _ = crate::tunnel::proxy::bind_socket_to_interface(&socket, name); - } - - let socket = Arc::new(socket); - let socket_rx = socket.clone(); - - // Spawn a task to read from physical socket and send back to smoltcp - let tx_clone = smoltcp_tx.clone(); - tokio::spawn(async move { - use futures::SinkExt; - let mut buf = [0u8; 65536]; - loop { - match socket_rx.recv_from(&mut buf).await { - Ok((n, peer)) => { - let mut lock = tx_clone.lock().await; - let _ = lock.send((buf[..n].to_vec(), peer, client_src)).await; - } - Err(_) => break, - } - } - }); - - while let Some((payload, dst)) = session_rx.recv().await { - socket.send_to(&payload, dst).await?; - } - - Ok(()) -} - - -async fn start_udp_session( - client_src: SocketAddr, - proxy_addr: String, - session_rx: &mut mpsc::Receiver<(Vec, SocketAddr)>, - smoltcp_tx: Arc>, -) -> anyhow::Result<()> { - // 1. TCP Connect to SOCKS5 proxy - let mut tcp = TcpStream::connect(&proxy_addr).await?; - - // Auth - tcp.write_all(&[5, 1, 0]).await?; - let mut buf = [0u8; 2]; - tcp.read_exact(&mut buf).await?; - if buf[0] != 5 || buf[1] != 0 { - return Err(anyhow::anyhow!("socks5 auth rejected")); - } - - // UDP ASSOCIATE to 0.0.0.0:0 - tcp.write_all(&[5, 3, 0, 1, 0, 0, 0, 0, 0, 0]).await?; - let mut rep_hdr = [0u8; 4]; - tcp.read_exact(&mut rep_hdr).await?; - if rep_hdr[1] != 0 { - return Err(anyhow::anyhow!("socks5 udp associate rejected")); - } - - let mut relay_addr = match rep_hdr[3] { - 1 => { - let mut addr_buf = [0u8; 6]; - tcp.read_exact(&mut addr_buf).await?; - let ip = std::net::Ipv4Addr::new(addr_buf[0], addr_buf[1], addr_buf[2], addr_buf[3]); - let port = u16::from_be_bytes([addr_buf[4], addr_buf[5]]); - SocketAddr::new(std::net::IpAddr::V4(ip), port) - } - 4 => { - let mut addr_buf = [0u8; 18]; - tcp.read_exact(&mut addr_buf).await?; - let mut octets = [0u8; 16]; - octets.copy_from_slice(&addr_buf[0..16]); - let ip = std::net::Ipv6Addr::from(octets); - let port = u16::from_be_bytes([addr_buf[16], addr_buf[17]]); - SocketAddr::new(std::net::IpAddr::V6(ip), port) - } - _ => return Err(anyhow::anyhow!("unsupported ATYP in UDP ASSOCIATE response")), - }; - - // If proxy returned 0.0.0.0 or ::, use the proxy's IP - if relay_addr.ip().is_unspecified() { - if let Ok(proxy_sock) = proxy_addr.parse::() { - relay_addr.set_ip(proxy_sock.ip()); - } - } - - // Local SOCKS5 proxy always returns 127.0.0.1 (IPv4), so always bind IPv4 - let udp = UdpSocket::bind("127.0.0.1:0").await?; - - // CRITICAL for Android: protect this UDP socket so it goes out via the - // real physical interface, not back into the TUN (which would cause an - // infinite routing loop for DNS and all other UDP traffic). - #[cfg(target_os = "android")] - { - use std::os::unix::io::AsRawFd; - crate::bridge::protect_socket(udp.as_raw_fd()); - } - - let mut buf = vec![0u8; 65536]; - - let timeout = std::time::Duration::from_secs(300); // 5 min idle timeout - let mut tcp_buf = [0u8; 1]; - - loop { - tokio::select! { - res = tokio::time::timeout(timeout, session_rx.recv()) => { - match res { - Ok(Some((payload, dst))) => { - let mut packet = vec![0u8; 3]; // RSV, FRAG - match dst.ip() { - std::net::IpAddr::V4(v4) => { packet.push(1); packet.extend_from_slice(&v4.octets()); } - std::net::IpAddr::V6(v6) => { packet.push(4); packet.extend_from_slice(&v6.octets()); } - } - packet.extend_from_slice(&dst.port().to_be_bytes()); - packet.extend_from_slice(&payload); - tracing::debug!("udp_nat SENDING UDP ASSOCIATE payload len={} to relay_addr={} (original dst: {})", payload.len(), relay_addr, dst); - let _ = udp.send_to(&packet, relay_addr).await; - } - Ok(None) => break, - Err(_) => break, // timeout - } - } - res = udp.recv_from(&mut buf) => { - match res { - Err(e) => { - tracing::debug!("udp_nat recv_from error: {}", e); - continue; // transient error, don't kill the session - } - Ok((len, _peer)) => { - if len < 4 { continue; } - let frag = buf[2]; - if frag != 0 { continue; } // fragment not supported - let atyp = buf[3]; - let (header_len, remote_dst) = match atyp { - 1 => { - if len < 10 { continue; } - let ip = std::net::Ipv4Addr::new(buf[4], buf[5], buf[6], buf[7]); - let port = u16::from_be_bytes([buf[8], buf[9]]); - (10, SocketAddr::new(std::net::IpAddr::V4(ip), port)) - } - 4 => { - if len < 22 { continue; } - let mut octets = [0u8; 16]; - octets.copy_from_slice(&buf[4..20]); - let ip = std::net::Ipv6Addr::from(octets); - let port = u16::from_be_bytes([buf[20], buf[21]]); - (22, SocketAddr::new(std::net::IpAddr::V6(ip), port)) - } - _ => continue, - }; - let payload = buf[header_len..len].to_vec(); - tracing::debug!("udp_nat RECEIVED UDP ASSOCIATE REPLY from {} for {} len={}", remote_dst, client_src, payload.len()); - use futures::SinkExt; - if let Err(e) = smoltcp_tx.lock().await.send((payload, remote_dst, client_src)).await { - tracing::error!("udp_nat failed to inject packet into smoltcp: {}", e); - } else { - tracing::debug!("udp_nat successfully injected packet into smoltcp from {} to {}", remote_dst, client_src); - } - } - } - } - // If TCP drops, UDP association is over - res = tcp.read(&mut tcp_buf) => { - match res { - Ok(0) | Err(_) => break, - Ok(_) => {} - } - } - } - } - - Ok(()) -} +// Cleared for refactoring diff --git a/ostp-client/src/tunnel/udp_nat.rs.bak b/ostp-client/src/tunnel/udp_nat.rs.bak new file mode 100644 index 0000000..3a66d59 Binary files /dev/null and b/ostp-client/src/tunnel/udp_nat.rs.bak differ