Fix Windows TUN NLA delays, UI timer, and Android UDP DNS resolution

This commit is contained in:
ospab 2026-05-28 18:19:01 +03:00
parent a0292b6087
commit 1b836b26ab
14 changed files with 294 additions and 24 deletions

View File

@ -200,6 +200,12 @@ impl Bridge {
self.last_rtt_ms = now.saturating_sub(ts) as f64; self.last_rtt_ms = now.saturating_sub(ts) as f64;
self.metrics.rtt_ms.store(self.last_rtt_ms as u32, Ordering::Relaxed); self.metrics.rtt_ms.store(self.last_rtt_ms as u32, Ordering::Relaxed);
} }
RelayMessage::UdpAssociate => {
// Should not be received by client, ignore
}
RelayMessage::UdpData(target, data) => {
let _ = proxy_tx.send((stream_id, ProxyToClientMsg::UdpData(target, Bytes::from(data))));
}
RelayMessage::KeepAlive | RelayMessage::Ping(_) | RelayMessage::Connect(_) => {} RelayMessage::KeepAlive | RelayMessage::Ping(_) | RelayMessage::Connect(_) => {}
} }
} }
@ -581,6 +587,13 @@ impl Bridge {
let _ = tx.send(UiEvent::Log(format!("Proxy CONNECT stream_id={stream_id} target={target}"))).await; let _ = tx.send(UiEvent::Log(format!("Proxy CONNECT stream_id={stream_id} target={target}"))).await;
(stream_id, RelayMessage::Connect(target), false) (stream_id, RelayMessage::Connect(target), false)
} }
ProxyEvent::UdpAssociate { stream_id } => {
let _ = tx.send(UiEvent::Log(format!("Proxy UDP ASSOCIATE stream_id={stream_id}"))).await;
(stream_id, RelayMessage::UdpAssociate, false)
}
ProxyEvent::UdpData { stream_id, target, payload } => {
(stream_id, RelayMessage::UdpData(target, payload.to_vec()), false)
}
ProxyEvent::Data { stream_id, payload } => (stream_id, RelayMessage::Data(payload.to_vec()), false), ProxyEvent::Data { stream_id, payload } => (stream_id, RelayMessage::Data(payload.to_vec()), false),
ProxyEvent::Close { stream_id } => { ProxyEvent::Close { stream_id } => {
let _ = tx.send(UiEvent::Log(format!("Proxy CLOSE stream_id={stream_id}"))).await; let _ = tx.send(UiEvent::Log(format!("Proxy CLOSE stream_id={stream_id}"))).await;

View File

@ -41,6 +41,14 @@ pub enum ProxyEvent {
stream_id: u16, stream_id: u16,
target: String, target: String,
}, },
UdpAssociate {
stream_id: u16,
},
UdpData {
stream_id: u16,
target: String,
payload: bytes::Bytes,
},
Data { Data {
stream_id: u16, stream_id: u16,
payload: bytes::Bytes, payload: bytes::Bytes,
@ -54,6 +62,7 @@ pub enum ProxyEvent {
pub enum ProxyToClientMsg { pub enum ProxyToClientMsg {
ConnectOk, ConnectOk,
Data(bytes::Bytes), Data(bytes::Bytes),
UdpData(String, bytes::Bytes),
Close, Close,
Error(String), Error(String),
} }

View File

@ -126,7 +126,7 @@ pub async fn run_native_tunnel(
.build()?; .build()?;
let mut runner_task = tokio::spawn(async move { let mut runner_task = tokio::spawn(async move {
if let Some(mut runner) = tcp_runner { if let Some(runner) = tcp_runner {
let _ = runner.await; let _ = runner.await;
} }
}); });

View File

@ -1,7 +1,8 @@
use std::collections::HashMap; use std::collections::HashMap;
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, Context, Result};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream, UdpSocket};
use std::sync::Arc;
use tokio::sync::{mpsc, watch}; use tokio::sync::{mpsc, watch};
use tokio::time::{timeout, Duration}; use tokio::time::{timeout, Duration};
@ -140,6 +141,104 @@ impl Drop for StreamGuard {
} }
} }
async fn handle_udp_associate(
mut client_tcp: TcpStream,
udp_socket: tokio::net::UdpSocket,
stream_id: u16,
event_tx: mpsc::Sender<ProxyEvent>,
mut rx: mpsc::UnboundedReceiver<ProxyToClientMsg>,
close_tx: mpsc::Sender<u16>,
_debug: bool,
) -> Result<()> {
let mut client_udp_addr = None;
let mut buf = vec![0u8; 65536];
let udp_socket = Arc::new(udp_socket);
let sock_rx = udp_socket.clone();
let sock_tx = udp_socket;
let mut tcp_buf = [0u8; 1];
loop {
tokio::select! {
res = client_tcp.read(&mut tcp_buf) => {
let n = res?;
if n == 0 { break; }
}
res = sock_rx.recv_from(&mut buf) => {
let (len, addr) = res?;
if client_udp_addr.is_none() {
client_udp_addr = Some(addr);
}
if len < 4 { continue; }
let frag = buf[2];
if frag != 0 { continue; } // Fragmented UDP not supported
let atyp = buf[3];
let (header_len, target) = match atyp {
0x01 => {
if len < 10 { continue; }
let ip = std::net::Ipv4Addr::new(buf[4], buf[5], buf[6], buf[7]);
let port = u16::from_be_bytes([buf[8], buf[9]]);
(10, format!("{}:{}", ip, port))
}
0x03 => {
if len < 5 { continue; }
let domain_len = buf[4] as usize;
if len < 5 + domain_len + 2 { continue; }
let domain = String::from_utf8_lossy(&buf[5..5+domain_len]);
let port = u16::from_be_bytes([buf[5+domain_len], buf[5+domain_len+1]]);
(5 + domain_len + 2, format!("{}:{}", domain, port))
}
0x04 => {
if len < 22 { continue; }
let mut octets = [0u8; 16];
octets.copy_from_slice(&buf[4..20]);
let ip = std::net::Ipv6Addr::from(octets);
let port = u16::from_be_bytes([buf[20], buf[21]]);
(22, format!("[{}]:{}", ip, port))
}
_ => continue,
};
let payload = bytes::Bytes::copy_from_slice(&buf[header_len..len]);
let _ = event_tx.send(ProxyEvent::UdpData { stream_id, target, payload }).await;
}
msg = rx.recv() => {
match msg {
Some(ProxyToClientMsg::UdpData(target, data)) => {
if let Some(client_addr) = client_udp_addr {
let mut packet = vec![0x00, 0x00, 0x00];
let mut parts = target.rsplitn(2, ':');
let port_str = parts.next().unwrap_or("0");
let host_str = parts.next().unwrap_or(&target);
let host_str = host_str.trim_start_matches('[').trim_end_matches(']');
let port = port_str.parse::<u16>().unwrap_or(0);
if let Ok(ipv4) = host_str.parse::<std::net::Ipv4Addr>() {
packet.push(0x01);
packet.extend_from_slice(&ipv4.octets());
} else if let Ok(ipv6) = host_str.parse::<std::net::Ipv6Addr>() {
packet.push(0x04);
packet.extend_from_slice(&ipv6.octets());
} else {
packet.push(0x03);
let bytes = host_str.as_bytes();
packet.push(bytes.len() as u8);
packet.extend_from_slice(bytes);
}
packet.extend_from_slice(&port.to_be_bytes());
packet.extend_from_slice(&data);
let _ = sock_tx.send_to(&packet, client_addr).await;
}
}
Some(ProxyToClientMsg::Close) | Some(ProxyToClientMsg::Error(_)) | None => break,
_ => {}
}
}
}
}
let _ = close_tx.send(stream_id).await;
Ok(())
}
async fn handle_proxy_client( async fn handle_proxy_client(
mut client: TcpStream, mut client: TcpStream,
stream_id: u16, stream_id: u16,
@ -178,8 +277,10 @@ async fn handle_proxy_client(
if req[0] != 0x05 { if req[0] != 0x05 {
return Err(anyhow!("SOCKS5 request version mismatch")); return Err(anyhow!("SOCKS5 request version mismatch"));
} }
if req[1] != 0x01 {
// Not CONNECT — send COMMAND NOT SUPPORTED let is_udp = req[1] == 0x03;
if req[1] != 0x01 && !is_udp {
// Not CONNECT and Not UDP ASSOCIATE — send COMMAND NOT SUPPORTED
client.write_all(&[0x05, 0x07, 0x00, 0x01, 0, 0, 0, 0, 0, 0]).await?; client.write_all(&[0x05, 0x07, 0x00, 0x01, 0, 0, 0, 0, 0, 0]).await?;
return Err(anyhow!("unsupported SOCKS5 command {}", req[1])); return Err(anyhow!("unsupported SOCKS5 command {}", req[1]));
} }
@ -217,6 +318,18 @@ async fn handle_proxy_client(
} }
}; };
if is_udp {
if debug { tracing::info!("proxy UDP ASSOCIATE stream_id={stream_id}"); }
let udp_socket = UdpSocket::bind("127.0.0.1:0").await?;
let port = udp_socket.local_addr()?.port();
let mut reply = vec![0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1];
reply.extend_from_slice(&port.to_be_bytes());
client.write_all(&reply).await?;
event_tx.send(ProxyEvent::UdpAssociate { stream_id }).await?;
return handle_udp_associate(client, udp_socket, stream_id, event_tx, rx, close_tx, debug).await;
}
if debug { if debug {
tracing::info!("proxy CONNECT stream_id={stream_id} target={target}"); tracing::info!("proxy CONNECT stream_id={stream_id} target={target}");
} }
@ -386,7 +499,7 @@ async fn handle_proxy_client(
Some(ProxyToClientMsg::Close) | Some(ProxyToClientMsg::Error(_)) | None => { Some(ProxyToClientMsg::Close) | Some(ProxyToClientMsg::Error(_)) | None => {
break; break;
} }
Some(ProxyToClientMsg::ConnectOk) => {} // ignored after connect phase Some(ProxyToClientMsg::ConnectOk) | Some(ProxyToClientMsg::UdpData(_, _)) => {} // ignored after connect phase
} }
} }
} }

View File

@ -45,7 +45,16 @@ pub async fn run_wintun_tunnel(
let exe = std::env::current_exe()?; let exe = std::env::current_exe()?;
let dir = exe.parent().ok_or_else(|| anyhow!("failed to get binary directory"))?; let dir = exe.parent().ok_or_else(|| anyhow!("failed to get binary directory"))?;
let tun2socks_exe = dir.join("tun2socks.exe");
let mut tun2socks_exe = dir.join("tun2socks.exe");
if !tun2socks_exe.exists() {
if let Ok(cwd) = std::env::current_dir() {
let cwd_candidate = cwd.join("tun2socks.exe");
if cwd_candidate.exists() {
tun2socks_exe = cwd_candidate;
}
}
}
if !tun2socks_exe.exists() { if !tun2socks_exe.exists() {
return Err(anyhow!( return Err(anyhow!(
@ -59,14 +68,15 @@ pub async fn run_wintun_tunnel(
// 1. Delete stale TUN adapter if it exists from a previous run. // 1. Delete stale TUN adapter if it exists from a previous run.
// This prevents wintun from creating "ostp_tun 2", "ostp_tun 3", etc. // This prevents wintun from creating "ostp_tun 2", "ostp_tun 3", etc.
// Actually, tun2socks can reuse the existing adapter if we just leave it alone.
// We only clear old IP addresses and routes on it.
tracing::info!("Cleaning up stale TUN adapter..."); tracing::info!("Cleaning up stale TUN adapter...");
let _ = tokio::task::spawn_blocking(move || { let _ = tokio::task::spawn_blocking(move || {
Command::new("powershell") Command::new("powershell")
.creation_flags(CREATE_NO_WINDOW) .creation_flags(CREATE_NO_WINDOW)
.args(["-NoProfile", "-Command", &format!( .args(["-NoProfile", "-Command", &format!(
"Get-NetAdapter -Name '{TUN_NAME}*' -ErrorAction SilentlyContinue | \ "Remove-NetIPAddress -InterfaceAlias '{TUN_NAME}' -Confirm:$false -ErrorAction SilentlyContinue; \
Disable-NetAdapter -Confirm:$false -ErrorAction SilentlyContinue; \ Remove-NetRoute -InterfaceAlias '{TUN_NAME}' -Confirm:$false -ErrorAction SilentlyContinue"
netsh interface set interface \"{TUN_NAME}\" admin=disable 2>$null"
)]) )])
.output() .output()
}).await; }).await;
@ -173,10 +183,11 @@ pub async fn run_wintun_tunnel(
// 6. Configure the adapter (IP, metric, MTU, DNS) // 6. Configure the adapter (IP, metric, MTU, DNS)
tracing::info!("Applying network configuration..."); tracing::info!("Applying network configuration...");
let mut net_setup = format!( let mut net_setup = format!(
"netsh interface ipv4 set address name=\"{TUN_NAME}\" static 10.1.0.2 255.255.255.0 10.1.0.1\n\ "netsh interface ipv4 set address name=\"{TUN_NAME}\" static 10.1.0.2 255.255.255.0\n\
netsh interface ipv4 set subinterface \"{TUN_NAME}\" mtu={} store=persistent\n\ netsh interface ipv4 set subinterface \"{TUN_NAME}\" mtu={} store=persistent\n\
netsh interface ipv4 set interface name=\"{TUN_NAME}\" metric=1\n\ netsh interface ipv4 set interface name=\"{TUN_NAME}\" metric=1\n\
New-NetRoute -DestinationPrefix '0.0.0.0/0' -InterfaceAlias '{TUN_NAME}' -NextHop '10.1.0.1' -RouteMetric 1 -ErrorAction SilentlyContinue\n", New-NetRoute -DestinationPrefix '0.0.0.0/1' -InterfaceAlias '{TUN_NAME}' -RouteMetric 1 -ErrorAction SilentlyContinue\n\
New-NetRoute -DestinationPrefix '128.0.0.0/1' -InterfaceAlias '{TUN_NAME}' -RouteMetric 1 -ErrorAction SilentlyContinue\n",
config.ostp.mtu config.ostp.mtu
); );

View File

@ -49,6 +49,7 @@ enum Phase {
/// Probe bandwidth: cycle through pacing gains /// Probe bandwidth: cycle through pacing gains
ProbeBandwidth, ProbeBandwidth,
/// Periodically drain the queue to measure true min RTT /// Periodically drain the queue to measure true min RTT
#[allow(dead_code)]
ProbeRtt, ProbeRtt,
} }
@ -257,6 +258,7 @@ impl CongestionController {
} }
} }
#[allow(dead_code)]
fn bandwidth_delay_product(&self) -> u64 { fn bandwidth_delay_product(&self) -> u64 {
// BDP = max_bandwidth * min_rtt // BDP = max_bandwidth * min_rtt
let bw = if self.max_bandwidth > 0 { let bw = if self.max_bandwidth > 0 {

View File

@ -100,7 +100,7 @@ pub struct ProtocolMachine {
/// Key-derived handshake padding range /// Key-derived handshake padding range
handshake_pad_min: usize, handshake_pad_min: usize,
handshake_pad_max: usize, handshake_pad_max: usize,
mtu: usize, _mtu: usize,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -146,7 +146,7 @@ impl ProtocolMachine {
cc: CongestionController::new(config.mtu as u64), cc: CongestionController::new(config.mtu as u64),
handshake_pad_min: config.handshake_pad_min.max(8), handshake_pad_min: config.handshake_pad_min.max(8),
handshake_pad_max: config.handshake_pad_max.max(config.handshake_pad_min + 16), handshake_pad_max: config.handshake_pad_max.max(config.handshake_pad_min + 16),
mtu: config.mtu, _mtu: config.mtu,
}) })
} }

View File

@ -10,6 +10,8 @@ pub enum RelayMessage {
Error(String), Error(String),
Ping(u64), Ping(u64),
Pong(u64), Pong(u64),
UdpAssociate,
UdpData(String, Vec<u8>),
} }
impl RelayMessage { impl RelayMessage {
@ -23,6 +25,17 @@ impl RelayMessage {
RelayMessage::Error(msg) => encode_with_len(6, msg.as_bytes()), RelayMessage::Error(msg) => encode_with_len(6, msg.as_bytes()),
RelayMessage::Ping(ts) => encode_with_len(7, &ts.to_be_bytes()), RelayMessage::Ping(ts) => encode_with_len(7, &ts.to_be_bytes()),
RelayMessage::Pong(ts) => encode_with_len(8, &ts.to_be_bytes()), RelayMessage::Pong(ts) => encode_with_len(8, &ts.to_be_bytes()),
RelayMessage::UdpAssociate => vec![9],
RelayMessage::UdpData(addr, data) => {
let addr_bytes = addr.as_bytes();
let mut buf = Vec::with_capacity(1 + 2 + addr_bytes.len() + 2 + data.len());
buf.push(10);
buf.extend_from_slice(&(addr_bytes.len() as u16).to_be_bytes());
buf.extend_from_slice(addr_bytes);
buf.extend_from_slice(&(data.len() as u16).to_be_bytes());
buf.extend_from_slice(data);
buf
}
} }
} }
@ -56,11 +69,29 @@ impl RelayMessage {
} }
8 => { 8 => {
let payload = decode_with_len(&input[1..])?; let payload = decode_with_len(&input[1..])?;
if payload.len() != 8 { return Err(anyhow!("invalid pong payload len")); } if payload.len() != 8 {
let ts = u64::from_be_bytes(payload.try_into().unwrap()); return Err(anyhow!("invalid pong payload"));
Ok(RelayMessage::Pong(ts))
} }
t => Err(anyhow!("unknown relay message type {t}")), let mut ts = [0u8; 8];
ts.copy_from_slice(payload);
Ok(RelayMessage::Pong(u64::from_be_bytes(ts)))
}
9 => Ok(RelayMessage::UdpAssociate),
10 => {
if input.len() < 3 { return Err(anyhow!("invalid udp data")); }
let addr_len = u16::from_be_bytes([input[1], input[2]]) as usize;
if input.len() < 3 + addr_len + 2 { return Err(anyhow!("invalid udp data")); }
let addr = String::from_utf8(input[3..3+addr_len].to_vec())
.map_err(|_| anyhow!("invalid utf8 in udp addr"))?;
let data_offset = 3 + addr_len;
let data_len = u16::from_be_bytes([input[data_offset], input[data_offset+1]]) as usize;
if input.len() < data_offset + 2 + data_len { return Err(anyhow!("invalid udp data")); }
let data = input[data_offset+2..data_offset+2+data_len].to_vec();
Ok(RelayMessage::UdpData(addr, data))
}
_ => Err(anyhow!("unknown relay message type {}", input[0])),
} }
} }
} }

View File

@ -2632,7 +2632,7 @@ dependencies = [
[[package]] [[package]]
name = "ostp-client" name = "ostp-client"
version = "0.2.60" version = "0.2.61"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"base64 0.22.1", "base64 0.22.1",
@ -2662,7 +2662,7 @@ dependencies = [
[[package]] [[package]]
name = "ostp-core" name = "ostp-core"
version = "0.2.60" version = "0.2.61"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bytes", "bytes",

View File

@ -484,14 +484,19 @@ fn find_helper_exe() -> Option<PathBuf> {
} }
#[cfg(target_os = "windows")] #[cfg(target_os = "windows")]
fn launch_as_admin(exe: &PathBuf) -> Result<()> { fn launch_as_admin(exe: &std::path::PathBuf) -> anyhow::Result<()> {
use std::ffi::OsStr; use std::ffi::OsStr;
use std::os::windows::ffi::OsStrExt; use std::os::windows::ffi::OsStrExt;
use std::ptr::null_mut; use std::ptr::null_mut;
use std::path::Path;
let exe_wstr: Vec<u16> = exe.as_os_str().encode_wide().chain(Some(0)).collect(); let exe_wstr: Vec<u16> = exe.as_os_str().encode_wide().chain(Some(0)).collect();
let verb_wstr: Vec<u16> = OsStr::new("runas").encode_wide().chain(Some(0)).collect(); let verb_wstr: Vec<u16> = OsStr::new("runas").encode_wide().chain(Some(0)).collect();
#[link(name = "shell32")] extern "system" { fn ShellExecuteW(h: *mut std::ffi::c_void, op: *const u16, f: *const u16, p: *const u16, d: *const u16, s: i32) -> isize; } #[link(name = "shell32")] extern "system" { fn ShellExecuteW(h: *mut std::ffi::c_void, op: *const u16, f: *const u16, p: *const u16, d: *const u16, s: i32) -> isize; }
let dir_wstr: Vec<u16> = exe.parent().unwrap_or(Path::new(".")).as_os_str().encode_wide().chain(Some(0)).collect();
// Use the GUI executable's directory as the working directory so dependencies are found
let cwd_path = std::env::current_exe().unwrap_or_else(|_| std::path::PathBuf::from("."));
let dir_wstr: Vec<u16> = cwd_path.parent().unwrap_or(Path::new(".")).as_os_str().encode_wide().chain(Some(0)).collect();
let ret = unsafe { ShellExecuteW(null_mut(), verb_wstr.as_ptr(), exe_wstr.as_ptr(), null_mut(), dir_wstr.as_ptr(), 0) }; let ret = unsafe { ShellExecuteW(null_mut(), verb_wstr.as_ptr(), exe_wstr.as_ptr(), null_mut(), dir_wstr.as_ptr(), 0) };
if ret <= 32 { anyhow::bail!("UAC denied or helper missing."); } if ret <= 32 { anyhow::bail!("UAC denied or helper missing."); }
Ok(()) Ok(())

View File

@ -123,6 +123,7 @@ function setState(next) {
statusSub.textContent = t('hint_connecting'); statusSub.textContent = t('hint_connecting');
connInfo.classList.add('hidden'); connInfo.classList.add('hidden');
clearInterval(uptimeTimer); clearInterval(uptimeTimer);
uptimeTimer = null;
uptimeSecs = 0; uptimeSecs = 0;
} else if (next === 'connected') { } else if (next === 'connected') {
@ -141,6 +142,7 @@ function setState(next) {
// Start uptime counter // Start uptime counter
if (!uptimeTimer) { if (!uptimeTimer) {
uptimeSecs = 0; uptimeSecs = 0;
statusSub.textContent = fmtTime(uptimeSecs);
uptimeTimer = setInterval(() => { uptimeTimer = setInterval(() => {
uptimeSecs++; uptimeSecs++;
statusSub.textContent = fmtTime(uptimeSecs); statusSub.textContent = fmtTime(uptimeSecs);

View File

@ -59,6 +59,7 @@ pub(crate) enum UiEvent {
pub(crate) struct RemoteState { pub(crate) struct RemoteState {
pub data_tx: mpsc::UnboundedSender<Bytes>, pub data_tx: mpsc::UnboundedSender<Bytes>,
pub udp_tx: Option<mpsc::UnboundedSender<(String, Bytes)>>,
pub cancel_tx: mpsc::Sender<()>, pub cancel_tx: mpsc::Sender<()>,
#[allow(dead_code)] #[allow(dead_code)]
pub is_dns: bool, pub is_dns: bool,
@ -360,6 +361,7 @@ async fn run_server_loop(
) -> Result<()> { ) -> Result<()> {
let mut remotes: HashMap<(u32, u16), RemoteState> = HashMap::new(); let mut remotes: HashMap<(u32, u16), RemoteState> = HashMap::new();
let (stream_tx, mut stream_rx) = mpsc::unbounded_channel::<(u32, u16, Vec<u8>)>(); let (stream_tx, mut stream_rx) = mpsc::unbounded_channel::<(u32, u16, Vec<u8>)>();
let (udp_reply_tx, mut udp_reply_rx) = mpsc::unbounded_channel::<(u32, u16, String, Vec<u8>)>();
let (connect_tx, mut connect_rx) = mpsc::unbounded_channel::<(u32, u16, String, Result<(tokio::net::tcp::OwnedWriteHalf, mpsc::Sender<()>), String>)>(); let (connect_tx, mut connect_rx) = mpsc::unbounded_channel::<(u32, u16, String, Result<(tokio::net::tcp::OwnedWriteHalf, mpsc::Sender<()>), String>)>();
let tcp_map = std::sync::Arc::new(tokio::sync::RwLock::new(HashMap::new())); let tcp_map = std::sync::Arc::new(tokio::sync::RwLock::new(HashMap::new()));
@ -545,6 +547,7 @@ async fn run_server_loop(
&mut remotes, &mut remotes,
&ui_event_tx, &ui_event_tx,
stream_tx.clone(), stream_tx.clone(),
udp_reply_tx.clone(),
connect_tx.clone(), connect_tx.clone(),
outbound.clone(), outbound.clone(),
dns_server.clone(), dns_server.clone(),
@ -568,6 +571,9 @@ async fn run_server_loop(
let _ = relay::send_relay_to_stream(session_id, stream_id, RelayMessage::Data(data), &mut dispatcher, &socket, &ui_event_tx).await; let _ = relay::send_relay_to_stream(session_id, stream_id, RelayMessage::Data(data), &mut dispatcher, &socket, &ui_event_tx).await;
} }
} }
Some((session_id, stream_id, target, data)) = udp_reply_rx.recv() => {
let _ = relay::send_relay_to_stream(session_id, stream_id, RelayMessage::UdpData(target, data), &mut dispatcher, &socket, &ui_event_tx).await;
}
Some((session_id, stream_id, target, res)) = connect_rx.recv() => { Some((session_id, stream_id, target, res)) = connect_rx.recv() => {
match res { match res {
Ok((writer, cancel_tx)) => { Ok((writer, cancel_tx)) => {
@ -580,7 +586,7 @@ async fn run_server_loop(
} }
} }
}); });
remotes.insert((session_id, stream_id), RemoteState { data_tx, cancel_tx, is_dns: false }); remotes.insert((session_id, stream_id), RemoteState { data_tx, udp_tx: None, cancel_tx, is_dns: false });
let _ = relay::send_relay_to_stream(session_id, stream_id, RelayMessage::ConnectOk, &mut dispatcher, &socket, &ui_event_tx).await; let _ = relay::send_relay_to_stream(session_id, stream_id, RelayMessage::ConnectOk, &mut dispatcher, &socket, &ui_event_tx).await;
let _ = ui_event_tx.send(UiEvent::Log(format!("Relay CONNECT ok for [{session_id}:{stream_id}] -> {target}"))); let _ = ui_event_tx.send(UiEvent::Log(format!("Relay CONNECT ok for [{session_id}:{stream_id}] -> {target}")));
} }

View File

@ -21,6 +21,7 @@ pub async fn handle_relay_message(
remotes: &mut HashMap<(u32, u16), RemoteState>, remotes: &mut HashMap<(u32, u16), RemoteState>,
ui_event_tx: &mpsc::UnboundedSender<UiEvent>, ui_event_tx: &mpsc::UnboundedSender<UiEvent>,
stream_tx: mpsc::UnboundedSender<(u32, u16, Vec<u8>)>, stream_tx: mpsc::UnboundedSender<(u32, u16, Vec<u8>)>,
udp_reply_tx: mpsc::UnboundedSender<(u32, u16, String, Vec<u8>)>,
connect_tx: mpsc::UnboundedSender<(u32, u16, String, Result<(tokio::net::tcp::OwnedWriteHalf, mpsc::Sender<()>), String>)>, connect_tx: mpsc::UnboundedSender<(u32, u16, String, Result<(tokio::net::tcp::OwnedWriteHalf, mpsc::Sender<()>), String>)>,
outbound_cfg: Option<OutboundConfig>, outbound_cfg: Option<OutboundConfig>,
dns_server: std::sync::Arc<crate::dns::DnsServer>, dns_server: std::sync::Arc<crate::dns::DnsServer>,
@ -52,6 +53,7 @@ pub async fn handle_relay_message(
remotes.insert((session_id, stream_id), RemoteState { remotes.insert((session_id, stream_id), RemoteState {
data_tx: dns_query_tx, data_tx: dns_query_tx,
udp_tx: None,
cancel_tx, cancel_tx,
is_dns: true, is_dns: true,
}); });
@ -121,12 +123,82 @@ pub async fn handle_relay_message(
send_relay_to_stream(session_id, stream_id, RelayMessage::Pong(ts), dispatcher, socket, ui_event_tx).await?; send_relay_to_stream(session_id, stream_id, RelayMessage::Pong(ts), dispatcher, socket, ui_event_tx).await?;
} }
RelayMessage::Pong(_) => {} RelayMessage::Pong(_) => {}
RelayMessage::UdpAssociate => {
if debug {
let _ = ui_event_tx.send(UiEvent::Log(format!("Relay UDP ASSOCIATE stream_id={stream_id}")));
}
let server_udp = match UdpSocket::bind("0.0.0.0:0").await {
Ok(s) => std::sync::Arc::new(s),
Err(e) => {
let _ = ui_event_tx.send(UiEvent::Log(format!("UDP bind failed: {e}")));
return Ok(());
}
};
let (udp_tx, mut udp_rx) = mpsc::unbounded_channel::<(String, Bytes)>();
let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
let (dummy_data_tx, _) = mpsc::unbounded_channel::<Bytes>();
// Outbound UDP loop (tunnel -> target)
let tx_sock = server_udp.clone();
let dns_srv = dns_server.clone();
let udp_reply_clone_dns = udp_reply_tx.clone();
let client_ip = peer_addr.ip();
tokio::spawn(async move {
while let Some((target, data)) = udp_rx.recv().await {
let is_internal_dns = target == "10.1.0.1:53" && dns_srv.config.read().await.enabled;
if is_internal_dns {
if let Some(resp_bytes) = dns_srv.resolve(&data, client_ip).await {
let _ = udp_reply_clone_dns.send((session_id, stream_id, target, resp_bytes));
}
} else {
let _ = tx_sock.send_to(&data, &target).await;
}
}
});
// Inbound UDP loop (target -> tunnel)
let rx_sock = server_udp.clone();
let udp_reply_clone = udp_reply_tx.clone();
tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
loop {
tokio::select! {
_ = cancel_rx.recv() => break,
res = rx_sock.recv_from(&mut buf) => {
match res {
Ok((len, addr)) => {
let _ = udp_reply_clone.send((session_id, stream_id, addr.to_string(), buf[..len].to_vec()));
}
Err(_) => break,
}
}
}
}
});
remotes.insert((session_id, stream_id), RemoteState {
data_tx: dummy_data_tx,
udp_tx: Some(udp_tx),
cancel_tx,
is_dns: false,
});
send_relay_to_stream(session_id, stream_id, RelayMessage::ConnectOk, dispatcher, socket, ui_event_tx).await?;
}
RelayMessage::UdpData(target, data) => {
if let Some(remote) = remotes.get_mut(&(session_id, stream_id)) {
if let Some(ref udp_tx) = remote.udp_tx {
let _ = udp_tx.send((target, Bytes::from(data)));
}
} else {
let _ = ui_event_tx.send(UiEvent::Log(format!("Relay UDP DATA for unknown stream [{session_id}:{stream_id}]")));
}
}
} }
Ok(()) Ok(())
} }
pub async fn send_relay_to_stream( pub async fn send_relay_to_stream(
session_id: u32, session_id: u32,
stream_id: u16, stream_id: u16,

View File

@ -48,6 +48,12 @@ struct TunnelState {
#[tokio::main] #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {
if let Ok(exe) = std::env::current_exe() {
if let Some(dir) = exe.parent() {
let _ = std::env::set_current_dir(dir);
}
}
log_to_file("Helper started (TCP mode)"); log_to_file("Helper started (TCP mode)");
if let Err(e) = run_server().await { if let Err(e) = run_server().await {
log_to_file(&format!("Fatal error: {}", e)); log_to_file(&format!("Fatal error: {}", e));