diff --git a/ostp-client/src/bridge.rs b/ostp-client/src/bridge.rs index a29b799..5c25fd7 100644 --- a/ostp-client/src/bridge.rs +++ b/ostp-client/src/bridge.rs @@ -200,6 +200,12 @@ impl Bridge { self.last_rtt_ms = now.saturating_sub(ts) as f64; self.metrics.rtt_ms.store(self.last_rtt_ms as u32, Ordering::Relaxed); } + RelayMessage::UdpAssociate => { + // Should not be received by client, ignore + } + RelayMessage::UdpData(target, data) => { + let _ = proxy_tx.send((stream_id, ProxyToClientMsg::UdpData(target, Bytes::from(data)))); + } RelayMessage::KeepAlive | RelayMessage::Ping(_) | RelayMessage::Connect(_) => {} } } @@ -581,6 +587,13 @@ impl Bridge { let _ = tx.send(UiEvent::Log(format!("Proxy CONNECT stream_id={stream_id} target={target}"))).await; (stream_id, RelayMessage::Connect(target), false) } + ProxyEvent::UdpAssociate { stream_id } => { + let _ = tx.send(UiEvent::Log(format!("Proxy UDP ASSOCIATE stream_id={stream_id}"))).await; + (stream_id, RelayMessage::UdpAssociate, false) + } + ProxyEvent::UdpData { stream_id, target, payload } => { + (stream_id, RelayMessage::UdpData(target, payload.to_vec()), false) + } ProxyEvent::Data { stream_id, payload } => (stream_id, RelayMessage::Data(payload.to_vec()), false), ProxyEvent::Close { stream_id } => { let _ = tx.send(UiEvent::Log(format!("Proxy CLOSE stream_id={stream_id}"))).await; diff --git a/ostp-client/src/tunnel/mod.rs b/ostp-client/src/tunnel/mod.rs index d348733..15a0456 100644 --- a/ostp-client/src/tunnel/mod.rs +++ b/ostp-client/src/tunnel/mod.rs @@ -41,6 +41,14 @@ pub enum ProxyEvent { stream_id: u16, target: String, }, + UdpAssociate { + stream_id: u16, + }, + UdpData { + stream_id: u16, + target: String, + payload: bytes::Bytes, + }, Data { stream_id: u16, payload: bytes::Bytes, @@ -54,6 +62,7 @@ pub enum ProxyEvent { pub enum ProxyToClientMsg { ConnectOk, Data(bytes::Bytes), + UdpData(String, bytes::Bytes), Close, Error(String), } diff --git a/ostp-client/src/tunnel/native_handler.rs b/ostp-client/src/tunnel/native_handler.rs index d446ab2..a00cae1 100644 --- a/ostp-client/src/tunnel/native_handler.rs +++ b/ostp-client/src/tunnel/native_handler.rs @@ -126,7 +126,7 @@ pub async fn run_native_tunnel( .build()?; let mut runner_task = tokio::spawn(async move { - if let Some(mut runner) = tcp_runner { + if let Some(runner) = tcp_runner { let _ = runner.await; } }); diff --git a/ostp-client/src/tunnel/proxy.rs b/ostp-client/src/tunnel/proxy.rs index 878c495..24bc2ff 100644 --- a/ostp-client/src/tunnel/proxy.rs +++ b/ostp-client/src/tunnel/proxy.rs @@ -1,7 +1,8 @@ use std::collections::HashMap; use anyhow::{anyhow, Context, Result}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::{TcpListener, TcpStream}; +use tokio::net::{TcpListener, TcpStream, UdpSocket}; +use std::sync::Arc; use tokio::sync::{mpsc, watch}; use tokio::time::{timeout, Duration}; @@ -140,6 +141,104 @@ impl Drop for StreamGuard { } } +async fn handle_udp_associate( + mut client_tcp: TcpStream, + udp_socket: tokio::net::UdpSocket, + stream_id: u16, + event_tx: mpsc::Sender, + mut rx: mpsc::UnboundedReceiver, + close_tx: mpsc::Sender, + _debug: bool, +) -> Result<()> { + let mut client_udp_addr = None; + let mut buf = vec![0u8; 65536]; + + let udp_socket = Arc::new(udp_socket); + let sock_rx = udp_socket.clone(); + let sock_tx = udp_socket; + + let mut tcp_buf = [0u8; 1]; + loop { + tokio::select! { + res = client_tcp.read(&mut tcp_buf) => { + let n = res?; + if n == 0 { break; } + } + res = sock_rx.recv_from(&mut buf) => { + let (len, addr) = res?; + if client_udp_addr.is_none() { + client_udp_addr = Some(addr); + } + if len < 4 { continue; } + let frag = buf[2]; + if frag != 0 { continue; } // Fragmented UDP not supported + let atyp = buf[3]; + let (header_len, target) = match atyp { + 0x01 => { + if len < 10 { continue; } + let ip = std::net::Ipv4Addr::new(buf[4], buf[5], buf[6], buf[7]); + let port = u16::from_be_bytes([buf[8], buf[9]]); + (10, format!("{}:{}", ip, port)) + } + 0x03 => { + if len < 5 { continue; } + let domain_len = buf[4] as usize; + if len < 5 + domain_len + 2 { continue; } + let domain = String::from_utf8_lossy(&buf[5..5+domain_len]); + let port = u16::from_be_bytes([buf[5+domain_len], buf[5+domain_len+1]]); + (5 + domain_len + 2, format!("{}:{}", domain, port)) + } + 0x04 => { + if len < 22 { continue; } + let mut octets = [0u8; 16]; + octets.copy_from_slice(&buf[4..20]); + let ip = std::net::Ipv6Addr::from(octets); + let port = u16::from_be_bytes([buf[20], buf[21]]); + (22, format!("[{}]:{}", ip, port)) + } + _ => continue, + }; + let payload = bytes::Bytes::copy_from_slice(&buf[header_len..len]); + let _ = event_tx.send(ProxyEvent::UdpData { stream_id, target, payload }).await; + } + msg = rx.recv() => { + match msg { + Some(ProxyToClientMsg::UdpData(target, data)) => { + if let Some(client_addr) = client_udp_addr { + let mut packet = vec![0x00, 0x00, 0x00]; + let mut parts = target.rsplitn(2, ':'); + let port_str = parts.next().unwrap_or("0"); + let host_str = parts.next().unwrap_or(&target); + let host_str = host_str.trim_start_matches('[').trim_end_matches(']'); + let port = port_str.parse::().unwrap_or(0); + + if let Ok(ipv4) = host_str.parse::() { + packet.push(0x01); + packet.extend_from_slice(&ipv4.octets()); + } else if let Ok(ipv6) = host_str.parse::() { + packet.push(0x04); + packet.extend_from_slice(&ipv6.octets()); + } else { + packet.push(0x03); + let bytes = host_str.as_bytes(); + packet.push(bytes.len() as u8); + packet.extend_from_slice(bytes); + } + packet.extend_from_slice(&port.to_be_bytes()); + packet.extend_from_slice(&data); + let _ = sock_tx.send_to(&packet, client_addr).await; + } + } + Some(ProxyToClientMsg::Close) | Some(ProxyToClientMsg::Error(_)) | None => break, + _ => {} + } + } + } + } + let _ = close_tx.send(stream_id).await; + Ok(()) +} + async fn handle_proxy_client( mut client: TcpStream, stream_id: u16, @@ -178,8 +277,10 @@ async fn handle_proxy_client( if req[0] != 0x05 { return Err(anyhow!("SOCKS5 request version mismatch")); } - if req[1] != 0x01 { - // Not CONNECT — send COMMAND NOT SUPPORTED + + let is_udp = req[1] == 0x03; + if req[1] != 0x01 && !is_udp { + // Not CONNECT and Not UDP ASSOCIATE — send COMMAND NOT SUPPORTED client.write_all(&[0x05, 0x07, 0x00, 0x01, 0, 0, 0, 0, 0, 0]).await?; return Err(anyhow!("unsupported SOCKS5 command {}", req[1])); } @@ -217,6 +318,18 @@ async fn handle_proxy_client( } }; + if is_udp { + if debug { tracing::info!("proxy UDP ASSOCIATE stream_id={stream_id}"); } + let udp_socket = UdpSocket::bind("127.0.0.1:0").await?; + let port = udp_socket.local_addr()?.port(); + let mut reply = vec![0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1]; + reply.extend_from_slice(&port.to_be_bytes()); + client.write_all(&reply).await?; + + event_tx.send(ProxyEvent::UdpAssociate { stream_id }).await?; + return handle_udp_associate(client, udp_socket, stream_id, event_tx, rx, close_tx, debug).await; + } + if debug { tracing::info!("proxy CONNECT stream_id={stream_id} target={target}"); } @@ -386,7 +499,7 @@ async fn handle_proxy_client( Some(ProxyToClientMsg::Close) | Some(ProxyToClientMsg::Error(_)) | None => { break; } - Some(ProxyToClientMsg::ConnectOk) => {} // ignored after connect phase + Some(ProxyToClientMsg::ConnectOk) | Some(ProxyToClientMsg::UdpData(_, _)) => {} // ignored after connect phase } } } diff --git a/ostp-client/src/tunnel/wintun_handler.rs b/ostp-client/src/tunnel/wintun_handler.rs index 158dc33..c7b9d14 100644 --- a/ostp-client/src/tunnel/wintun_handler.rs +++ b/ostp-client/src/tunnel/wintun_handler.rs @@ -45,7 +45,16 @@ pub async fn run_wintun_tunnel( let exe = std::env::current_exe()?; let dir = exe.parent().ok_or_else(|| anyhow!("failed to get binary directory"))?; - let tun2socks_exe = dir.join("tun2socks.exe"); + + let mut tun2socks_exe = dir.join("tun2socks.exe"); + if !tun2socks_exe.exists() { + if let Ok(cwd) = std::env::current_dir() { + let cwd_candidate = cwd.join("tun2socks.exe"); + if cwd_candidate.exists() { + tun2socks_exe = cwd_candidate; + } + } + } if !tun2socks_exe.exists() { return Err(anyhow!( @@ -59,14 +68,15 @@ pub async fn run_wintun_tunnel( // 1. Delete stale TUN adapter if it exists from a previous run. // This prevents wintun from creating "ostp_tun 2", "ostp_tun 3", etc. + // Actually, tun2socks can reuse the existing adapter if we just leave it alone. + // We only clear old IP addresses and routes on it. tracing::info!("Cleaning up stale TUN adapter..."); let _ = tokio::task::spawn_blocking(move || { Command::new("powershell") .creation_flags(CREATE_NO_WINDOW) .args(["-NoProfile", "-Command", &format!( - "Get-NetAdapter -Name '{TUN_NAME}*' -ErrorAction SilentlyContinue | \ - Disable-NetAdapter -Confirm:$false -ErrorAction SilentlyContinue; \ - netsh interface set interface \"{TUN_NAME}\" admin=disable 2>$null" + "Remove-NetIPAddress -InterfaceAlias '{TUN_NAME}' -Confirm:$false -ErrorAction SilentlyContinue; \ + Remove-NetRoute -InterfaceAlias '{TUN_NAME}' -Confirm:$false -ErrorAction SilentlyContinue" )]) .output() }).await; @@ -173,10 +183,11 @@ pub async fn run_wintun_tunnel( // 6. Configure the adapter (IP, metric, MTU, DNS) tracing::info!("Applying network configuration..."); let mut net_setup = format!( - "netsh interface ipv4 set address name=\"{TUN_NAME}\" static 10.1.0.2 255.255.255.0 10.1.0.1\n\ + "netsh interface ipv4 set address name=\"{TUN_NAME}\" static 10.1.0.2 255.255.255.0\n\ netsh interface ipv4 set subinterface \"{TUN_NAME}\" mtu={} store=persistent\n\ netsh interface ipv4 set interface name=\"{TUN_NAME}\" metric=1\n\ - New-NetRoute -DestinationPrefix '0.0.0.0/0' -InterfaceAlias '{TUN_NAME}' -NextHop '10.1.0.1' -RouteMetric 1 -ErrorAction SilentlyContinue\n", + New-NetRoute -DestinationPrefix '0.0.0.0/1' -InterfaceAlias '{TUN_NAME}' -RouteMetric 1 -ErrorAction SilentlyContinue\n\ + New-NetRoute -DestinationPrefix '128.0.0.0/1' -InterfaceAlias '{TUN_NAME}' -RouteMetric 1 -ErrorAction SilentlyContinue\n", config.ostp.mtu ); diff --git a/ostp-core/src/congestion.rs b/ostp-core/src/congestion.rs index 48016d5..5ecb36a 100644 --- a/ostp-core/src/congestion.rs +++ b/ostp-core/src/congestion.rs @@ -49,6 +49,7 @@ enum Phase { /// Probe bandwidth: cycle through pacing gains ProbeBandwidth, /// Periodically drain the queue to measure true min RTT + #[allow(dead_code)] ProbeRtt, } @@ -257,6 +258,7 @@ impl CongestionController { } } + #[allow(dead_code)] fn bandwidth_delay_product(&self) -> u64 { // BDP = max_bandwidth * min_rtt let bw = if self.max_bandwidth > 0 { diff --git a/ostp-core/src/protocol.rs b/ostp-core/src/protocol.rs index 153ed07..d74c751 100644 --- a/ostp-core/src/protocol.rs +++ b/ostp-core/src/protocol.rs @@ -100,7 +100,7 @@ pub struct ProtocolMachine { /// Key-derived handshake padding range handshake_pad_min: usize, handshake_pad_max: usize, - mtu: usize, + _mtu: usize, } #[derive(Debug, Clone)] @@ -146,7 +146,7 @@ impl ProtocolMachine { cc: CongestionController::new(config.mtu as u64), handshake_pad_min: config.handshake_pad_min.max(8), handshake_pad_max: config.handshake_pad_max.max(config.handshake_pad_min + 16), - mtu: config.mtu, + _mtu: config.mtu, }) } diff --git a/ostp-core/src/relay.rs b/ostp-core/src/relay.rs index bc0f1fc..4301ac0 100644 --- a/ostp-core/src/relay.rs +++ b/ostp-core/src/relay.rs @@ -10,6 +10,8 @@ pub enum RelayMessage { Error(String), Ping(u64), Pong(u64), + UdpAssociate, + UdpData(String, Vec), } impl RelayMessage { @@ -23,6 +25,17 @@ impl RelayMessage { RelayMessage::Error(msg) => encode_with_len(6, msg.as_bytes()), RelayMessage::Ping(ts) => encode_with_len(7, &ts.to_be_bytes()), RelayMessage::Pong(ts) => encode_with_len(8, &ts.to_be_bytes()), + RelayMessage::UdpAssociate => vec![9], + RelayMessage::UdpData(addr, data) => { + let addr_bytes = addr.as_bytes(); + let mut buf = Vec::with_capacity(1 + 2 + addr_bytes.len() + 2 + data.len()); + buf.push(10); + buf.extend_from_slice(&(addr_bytes.len() as u16).to_be_bytes()); + buf.extend_from_slice(addr_bytes); + buf.extend_from_slice(&(data.len() as u16).to_be_bytes()); + buf.extend_from_slice(data); + buf + } } } @@ -56,11 +69,29 @@ impl RelayMessage { } 8 => { let payload = decode_with_len(&input[1..])?; - if payload.len() != 8 { return Err(anyhow!("invalid pong payload len")); } - let ts = u64::from_be_bytes(payload.try_into().unwrap()); - Ok(RelayMessage::Pong(ts)) + if payload.len() != 8 { + return Err(anyhow!("invalid pong payload")); + } + let mut ts = [0u8; 8]; + ts.copy_from_slice(payload); + Ok(RelayMessage::Pong(u64::from_be_bytes(ts))) } - t => Err(anyhow!("unknown relay message type {t}")), + 9 => Ok(RelayMessage::UdpAssociate), + 10 => { + if input.len() < 3 { return Err(anyhow!("invalid udp data")); } + let addr_len = u16::from_be_bytes([input[1], input[2]]) as usize; + if input.len() < 3 + addr_len + 2 { return Err(anyhow!("invalid udp data")); } + let addr = String::from_utf8(input[3..3+addr_len].to_vec()) + .map_err(|_| anyhow!("invalid utf8 in udp addr"))?; + + let data_offset = 3 + addr_len; + let data_len = u16::from_be_bytes([input[data_offset], input[data_offset+1]]) as usize; + if input.len() < data_offset + 2 + data_len { return Err(anyhow!("invalid udp data")); } + + let data = input[data_offset+2..data_offset+2+data_len].to_vec(); + Ok(RelayMessage::UdpData(addr, data)) + } + _ => Err(anyhow!("unknown relay message type {}", input[0])), } } } diff --git a/ostp-gui/src-tauri/Cargo.lock b/ostp-gui/src-tauri/Cargo.lock index 57eb0cd..0e38548 100644 --- a/ostp-gui/src-tauri/Cargo.lock +++ b/ostp-gui/src-tauri/Cargo.lock @@ -2632,7 +2632,7 @@ dependencies = [ [[package]] name = "ostp-client" -version = "0.2.60" +version = "0.2.61" dependencies = [ "anyhow", "base64 0.22.1", @@ -2662,7 +2662,7 @@ dependencies = [ [[package]] name = "ostp-core" -version = "0.2.60" +version = "0.2.61" dependencies = [ "anyhow", "bytes", diff --git a/ostp-gui/src-tauri/src/lib.rs b/ostp-gui/src-tauri/src/lib.rs index 5cb3fac..fff40ba 100644 --- a/ostp-gui/src-tauri/src/lib.rs +++ b/ostp-gui/src-tauri/src/lib.rs @@ -484,14 +484,19 @@ fn find_helper_exe() -> Option { } #[cfg(target_os = "windows")] -fn launch_as_admin(exe: &PathBuf) -> Result<()> { +fn launch_as_admin(exe: &std::path::PathBuf) -> anyhow::Result<()> { use std::ffi::OsStr; use std::os::windows::ffi::OsStrExt; use std::ptr::null_mut; + use std::path::Path; let exe_wstr: Vec = exe.as_os_str().encode_wide().chain(Some(0)).collect(); let verb_wstr: Vec = OsStr::new("runas").encode_wide().chain(Some(0)).collect(); #[link(name = "shell32")] extern "system" { fn ShellExecuteW(h: *mut std::ffi::c_void, op: *const u16, f: *const u16, p: *const u16, d: *const u16, s: i32) -> isize; } - let dir_wstr: Vec = exe.parent().unwrap_or(Path::new(".")).as_os_str().encode_wide().chain(Some(0)).collect(); + + // Use the GUI executable's directory as the working directory so dependencies are found + let cwd_path = std::env::current_exe().unwrap_or_else(|_| std::path::PathBuf::from(".")); + let dir_wstr: Vec = cwd_path.parent().unwrap_or(Path::new(".")).as_os_str().encode_wide().chain(Some(0)).collect(); + let ret = unsafe { ShellExecuteW(null_mut(), verb_wstr.as_ptr(), exe_wstr.as_ptr(), null_mut(), dir_wstr.as_ptr(), 0) }; if ret <= 32 { anyhow::bail!("UAC denied or helper missing."); } Ok(()) diff --git a/ostp-gui/src/main.js b/ostp-gui/src/main.js index fc2b7ef..cd78cec 100644 --- a/ostp-gui/src/main.js +++ b/ostp-gui/src/main.js @@ -123,6 +123,7 @@ function setState(next) { statusSub.textContent = t('hint_connecting'); connInfo.classList.add('hidden'); clearInterval(uptimeTimer); + uptimeTimer = null; uptimeSecs = 0; } else if (next === 'connected') { @@ -141,6 +142,7 @@ function setState(next) { // Start uptime counter if (!uptimeTimer) { uptimeSecs = 0; + statusSub.textContent = fmtTime(uptimeSecs); uptimeTimer = setInterval(() => { uptimeSecs++; statusSub.textContent = fmtTime(uptimeSecs); diff --git a/ostp-server/src/lib.rs b/ostp-server/src/lib.rs index fc017f2..b0148e2 100644 --- a/ostp-server/src/lib.rs +++ b/ostp-server/src/lib.rs @@ -59,6 +59,7 @@ pub(crate) enum UiEvent { pub(crate) struct RemoteState { pub data_tx: mpsc::UnboundedSender, + pub udp_tx: Option>, pub cancel_tx: mpsc::Sender<()>, #[allow(dead_code)] pub is_dns: bool, @@ -360,6 +361,7 @@ async fn run_server_loop( ) -> Result<()> { let mut remotes: HashMap<(u32, u16), RemoteState> = HashMap::new(); let (stream_tx, mut stream_rx) = mpsc::unbounded_channel::<(u32, u16, Vec)>(); + let (udp_reply_tx, mut udp_reply_rx) = mpsc::unbounded_channel::<(u32, u16, String, Vec)>(); let (connect_tx, mut connect_rx) = mpsc::unbounded_channel::<(u32, u16, String, Result<(tokio::net::tcp::OwnedWriteHalf, mpsc::Sender<()>), String>)>(); let tcp_map = std::sync::Arc::new(tokio::sync::RwLock::new(HashMap::new())); @@ -545,6 +547,7 @@ async fn run_server_loop( &mut remotes, &ui_event_tx, stream_tx.clone(), + udp_reply_tx.clone(), connect_tx.clone(), outbound.clone(), dns_server.clone(), @@ -568,6 +571,9 @@ async fn run_server_loop( let _ = relay::send_relay_to_stream(session_id, stream_id, RelayMessage::Data(data), &mut dispatcher, &socket, &ui_event_tx).await; } } + Some((session_id, stream_id, target, data)) = udp_reply_rx.recv() => { + let _ = relay::send_relay_to_stream(session_id, stream_id, RelayMessage::UdpData(target, data), &mut dispatcher, &socket, &ui_event_tx).await; + } Some((session_id, stream_id, target, res)) = connect_rx.recv() => { match res { Ok((writer, cancel_tx)) => { @@ -580,7 +586,7 @@ async fn run_server_loop( } } }); - remotes.insert((session_id, stream_id), RemoteState { data_tx, cancel_tx, is_dns: false }); + remotes.insert((session_id, stream_id), RemoteState { data_tx, udp_tx: None, cancel_tx, is_dns: false }); let _ = relay::send_relay_to_stream(session_id, stream_id, RelayMessage::ConnectOk, &mut dispatcher, &socket, &ui_event_tx).await; let _ = ui_event_tx.send(UiEvent::Log(format!("Relay CONNECT ok for [{session_id}:{stream_id}] -> {target}"))); } diff --git a/ostp-server/src/relay.rs b/ostp-server/src/relay.rs index 7736603..c04d3e5 100644 --- a/ostp-server/src/relay.rs +++ b/ostp-server/src/relay.rs @@ -21,6 +21,7 @@ pub async fn handle_relay_message( remotes: &mut HashMap<(u32, u16), RemoteState>, ui_event_tx: &mpsc::UnboundedSender, stream_tx: mpsc::UnboundedSender<(u32, u16, Vec)>, + udp_reply_tx: mpsc::UnboundedSender<(u32, u16, String, Vec)>, connect_tx: mpsc::UnboundedSender<(u32, u16, String, Result<(tokio::net::tcp::OwnedWriteHalf, mpsc::Sender<()>), String>)>, outbound_cfg: Option, dns_server: std::sync::Arc, @@ -52,6 +53,7 @@ pub async fn handle_relay_message( remotes.insert((session_id, stream_id), RemoteState { data_tx: dns_query_tx, + udp_tx: None, cancel_tx, is_dns: true, }); @@ -121,12 +123,82 @@ pub async fn handle_relay_message( send_relay_to_stream(session_id, stream_id, RelayMessage::Pong(ts), dispatcher, socket, ui_event_tx).await?; } RelayMessage::Pong(_) => {} + RelayMessage::UdpAssociate => { + if debug { + let _ = ui_event_tx.send(UiEvent::Log(format!("Relay UDP ASSOCIATE stream_id={stream_id}"))); + } + let server_udp = match UdpSocket::bind("0.0.0.0:0").await { + Ok(s) => std::sync::Arc::new(s), + Err(e) => { + let _ = ui_event_tx.send(UiEvent::Log(format!("UDP bind failed: {e}"))); + return Ok(()); + } + }; + + let (udp_tx, mut udp_rx) = mpsc::unbounded_channel::<(String, Bytes)>(); + let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1); + let (dummy_data_tx, _) = mpsc::unbounded_channel::(); + + // Outbound UDP loop (tunnel -> target) + let tx_sock = server_udp.clone(); + let dns_srv = dns_server.clone(); + let udp_reply_clone_dns = udp_reply_tx.clone(); + let client_ip = peer_addr.ip(); + tokio::spawn(async move { + while let Some((target, data)) = udp_rx.recv().await { + let is_internal_dns = target == "10.1.0.1:53" && dns_srv.config.read().await.enabled; + if is_internal_dns { + if let Some(resp_bytes) = dns_srv.resolve(&data, client_ip).await { + let _ = udp_reply_clone_dns.send((session_id, stream_id, target, resp_bytes)); + } + } else { + let _ = tx_sock.send_to(&data, &target).await; + } + } + }); + + // Inbound UDP loop (target -> tunnel) + let rx_sock = server_udp.clone(); + let udp_reply_clone = udp_reply_tx.clone(); + tokio::spawn(async move { + let mut buf = vec![0u8; 65536]; + loop { + tokio::select! { + _ = cancel_rx.recv() => break, + res = rx_sock.recv_from(&mut buf) => { + match res { + Ok((len, addr)) => { + let _ = udp_reply_clone.send((session_id, stream_id, addr.to_string(), buf[..len].to_vec())); + } + Err(_) => break, + } + } + } + } + }); + + remotes.insert((session_id, stream_id), RemoteState { + data_tx: dummy_data_tx, + udp_tx: Some(udp_tx), + cancel_tx, + is_dns: false, + }); + + send_relay_to_stream(session_id, stream_id, RelayMessage::ConnectOk, dispatcher, socket, ui_event_tx).await?; + } + RelayMessage::UdpData(target, data) => { + if let Some(remote) = remotes.get_mut(&(session_id, stream_id)) { + if let Some(ref udp_tx) = remote.udp_tx { + let _ = udp_tx.send((target, Bytes::from(data))); + } + } else { + let _ = ui_event_tx.send(UiEvent::Log(format!("Relay UDP DATA for unknown stream [{session_id}:{stream_id}]"))); + } + } } Ok(()) } - - pub async fn send_relay_to_stream( session_id: u32, stream_id: u16, diff --git a/ostp-tun-helper/src/main.rs b/ostp-tun-helper/src/main.rs index 8b34a17..aff6e6d 100644 --- a/ostp-tun-helper/src/main.rs +++ b/ostp-tun-helper/src/main.rs @@ -48,6 +48,12 @@ struct TunnelState { #[tokio::main] async fn main() -> Result<()> { + if let Ok(exe) = std::env::current_exe() { + if let Some(dir) = exe.parent() { + let _ = std::env::set_current_dir(dir); + } + } + log_to_file("Helper started (TCP mode)"); if let Err(e) = run_server().await { log_to_file(&format!("Fatal error: {}", e));