Refactor: Phase 1 and 2 - Async architecture, JNI fixes, SmolTCP data races, and Tunnel optimizations

This commit is contained in:
ospab 2026-06-03 02:06:06 +03:00
parent 84797f55ab
commit 29e9ef739c
30 changed files with 2079 additions and 864 deletions

15
app-icon.svg Normal file
View File

@ -0,0 +1,15 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 512 512">
<defs>
<linearGradient id="g2" x1="0%" y1="0%" x2="100%" y2="100%">
<stop offset="0%" stop-color="#111827" />
<stop offset="100%" stop-color="#374151" />
</linearGradient>
<linearGradient id="g2_path" x1="0%" y1="0%" x2="100%" y2="100%">
<stop offset="0%" stop-color="#3B82F6" />
<stop offset="100%" stop-color="#14B8A6" />
</linearGradient>
</defs>
<rect width="512" height="512" rx="120" fill="url(#g2)" />
<path d="M144 256c0-61.9 50.1-112 112-112s112 50.1 112 112-50.1 112-112 112S144 317.9 144 256zm-48 0c0 88.4 71.6 160 160 160s160-71.6 160-160S344.4 96 256 96 96 167.6 96 256z" fill="url(#g2_path)"/>
<circle cx="256" cy="256" r="40" fill="#F59E0B" />
</svg>

After

Width:  |  Height:  |  Size: 779 B

View File

@ -16,6 +16,7 @@ pub(super) struct VirtualDevice {
in_buf: UnboundedReceiver<Vec<u8>>, in_buf: UnboundedReceiver<Vec<u8>>,
out_buf: Sender<AnyIpPktFrame>, out_buf: Sender<AnyIpPktFrame>,
mtu: usize, mtu: usize,
cached_packet: Option<Vec<u8>>,
} }
impl VirtualDevice { impl VirtualDevice {
@ -31,6 +32,7 @@ impl VirtualDevice {
in_buf: iface_ingress_rx, in_buf: iface_ingress_rx,
out_buf: iface_egress_tx, out_buf: iface_egress_tx,
mtu, mtu,
cached_packet: None,
}, },
iface_ingress_tx, iface_ingress_tx,
iface_ingress_tx_avail, iface_ingress_tx_avail,
@ -43,12 +45,18 @@ impl Device for VirtualDevice {
type TxToken<'a> = VirtualTxToken<'a>; type TxToken<'a> = VirtualTxToken<'a>;
fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
let Ok(buffer) = self.in_buf.try_recv() else { let buffer = if let Some(buf) = self.cached_packet.take() {
buf
} else {
let Ok(buf) = self.in_buf.try_recv() else {
self.in_buf_avail.store(false, Ordering::Release); self.in_buf_avail.store(false, Ordering::Release);
return None; return None;
}; };
buf
};
let Ok(permit) = self.out_buf.try_reserve() else { let Ok(permit) = self.out_buf.try_reserve() else {
self.cached_packet = Some(buffer);
self.in_buf_avail.store(false, Ordering::Release); self.in_buf_avail.store(false, Ordering::Release);
return None; return None;
}; };

View File

@ -12,23 +12,24 @@ use std::{
/// require two sets of API interfaces in single-threaded and multi-threaded. /// require two sets of API interfaces in single-threaded and multi-threaded.
/// ///
/// [BoxFuture in crate futures utils]: https://docs.rs/futures-util/latest/futures_util/future/type.BoxFuture.html /// [BoxFuture in crate futures utils]: https://docs.rs/futures-util/latest/futures_util/future/type.BoxFuture.html
pub struct BoxFuture<'a, T>(Pin<Box<dyn Future<Output = T> + 'a>>); pub struct BoxFuture<'a, T>(Pin<Box<dyn Future<Output = T> + Send + 'a>>);
impl<'a, T> BoxFuture<'a, T> { impl<'a, T> BoxFuture<'a, T> {
pub fn new<F>(f: F) -> BoxFuture<'a, T> pub fn new<F>(f: F) -> BoxFuture<'a, T>
where where
F: IntoFuture<Output = T> + 'a, F: IntoFuture<Output = T> + Send + 'a,
F::IntoFuture: Send + 'a,
{ {
BoxFuture(Box::pin(f.into_future())) BoxFuture(Box::pin(f.into_future()))
} }
#[allow(unused)] #[allow(unused)]
pub fn wrap(f: Pin<Box<dyn Future<Output = T> + 'a>>) -> BoxFuture<'a, T> { pub fn wrap(f: Pin<Box<dyn Future<Output = T> + Send + 'a>>) -> BoxFuture<'a, T> {
BoxFuture(f) BoxFuture(f)
} }
} }
unsafe impl<T: Send> Send for BoxFuture<'_, T> {}
impl<T> Future for BoxFuture<'_, T> { impl<T> Future for BoxFuture<'_, T> {
type Output = T; type Output = T;

View File

@ -29,3 +29,4 @@ libc = "0.2.186"
x25519-dalek = "2.0.1" x25519-dalek = "2.0.1"
chacha20poly1305.workspace = true chacha20poly1305.workspace = true
hex = "0.4.3" hex = "0.4.3"
winapi = { version = "0.3.9", features = ["iphlpapi", "tcpmib", "processthreadsapi", "psapi", "handleapi", "winerror", "minwindef", "winnt"] }

View File

@ -117,6 +117,7 @@ impl Bridge {
}) })
} }
pub async fn run( pub async fn run(
mut self, mut self,
tx: mpsc::Sender<UiEvent>, tx: mpsc::Sender<UiEvent>,
@ -137,7 +138,7 @@ impl Bridge {
let mut sessions_opt: Option<Vec<SessionState>> = None; let mut sessions_opt: Option<Vec<SessionState>> = None;
let mut udp_rx_opt: Option<mpsc::Receiver<(usize, Bytes)>> = None; let mut udp_rx_opt: Option<mpsc::Receiver<(usize, Bytes)>> = None;
let mut _proxy_guard: Option<crate::sysproxy::SystemProxyGuard> = None; let mut proxy_guard: Option<crate::sysproxy::SystemProxyGuard> = None;
let mut stream_map: std::collections::HashMap<u16, usize> = std::collections::HashMap::new(); let mut stream_map: std::collections::HashMap<u16, usize> = std::collections::HashMap::new();
loop { loop {
@ -147,7 +148,11 @@ impl Bridge {
if *shutdown.borrow() { if *shutdown.borrow() {
self.running = false; self.running = false;
self.metrics.connection_state.store(0, Ordering::Relaxed); self.metrics.connection_state.store(0, Ordering::Relaxed);
_proxy_guard = None; proxy_guard = None;
sessions_opt = None;
udp_rx_opt = None;
stream_map.clear();
self.reset_proxy_streams(&tx, &proxy_tx, "manual stop");
break; break;
} }
} }
@ -157,6 +162,50 @@ impl Bridge {
None => std::future::pending().await, None => std::future::pending().await,
} }
}, if self.running => { }, if self.running => {
self.handle_inbound_udp(udp_msg, &mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx).await;
}
cmd = bridge_rx.recv() => {
if !self.handle_bridge_cmd(cmd, &mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx).await {
break;
}
}
_ = metrics_tick.tick() => {
if self.running {
self.emit_metrics(&tx).await;
}
}
_ = keepalive_tick.tick() => {
if self.running {
self.handle_keepalive(&mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx, &mut proxy_rx).await;
}
}
_ = retransmit_tick.tick() => {
if self.running {
self.handle_retransmit(&mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx).await;
}
}
proxy_ev = proxy_rx.recv(), if self.running && sessions_opt.as_ref().map(|s| {
s.iter().any(|ses| ses.machine.in_flight_count() < ses.machine.cwnd_packets().clamp(16, 16384))
}).unwrap_or(true) => {
self.handle_proxy_event(proxy_ev, &mut sessions_opt, &mut stream_map, &tx, &proxy_tx).await;
}
}
}
tx.send(UiEvent::Log("Bridge stopped".to_string())).await.ok();
Ok(())
}
async fn handle_inbound_udp(
&mut self,
udp_msg: Option<(usize, Bytes)>,
sessions_opt: &mut Option<Vec<SessionState>>,
udp_rx_opt: &mut Option<mpsc::Receiver<(usize, Bytes)>>,
proxy_guard: &mut Option<crate::sysproxy::SystemProxyGuard>,
stream_map: &mut std::collections::HashMap<u16, usize>,
tx: &mpsc::Sender<UiEvent>,
proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>,
) {
match udp_msg { match udp_msg {
Some((session_index, inbound)) => { Some((session_index, inbound)) => {
self.metrics.bytes_recv.fetch_add(inbound.len() as u64, Ordering::Relaxed); self.metrics.bytes_recv.fetch_add(inbound.len() as u64, Ordering::Relaxed);
@ -169,7 +218,7 @@ impl Bridge {
Err(e) => { Err(e) => {
let _ = tx.send(UiEvent::Log(format!("Protocol decrypt error: {e}"))).await; let _ = tx.send(UiEvent::Log(format!("Protocol decrypt error: {e}"))).await;
tracing::warn!("Inbound protocol error (session {}): {}", session_index, e); tracing::warn!("Inbound protocol error (session {}): {}", session_index, e);
continue; return;
} }
}; };
@ -206,9 +255,7 @@ impl Bridge {
self.last_rtt_ms = now.saturating_sub(ts) as f64; self.last_rtt_ms = now.saturating_sub(ts) as f64;
self.metrics.rtt_ms.store(self.last_rtt_ms as u32, Ordering::Relaxed); self.metrics.rtt_ms.store(self.last_rtt_ms as u32, Ordering::Relaxed);
} }
RelayMessage::UdpAssociate => { RelayMessage::UdpAssociate => {}
// Should not be received by client, ignore
}
RelayMessage::UdpData(target, data) => { RelayMessage::UdpData(target, data) => {
let _ = proxy_tx.send((stream_id, ProxyToClientMsg::UdpData(target, Bytes::from(data)))); let _ = proxy_tx.send((stream_id, ProxyToClientMsg::UdpData(target, Bytes::from(data))));
} }
@ -234,24 +281,34 @@ impl Bridge {
None => { None => {
let _ = tx.send(UiEvent::Log("UDP channel closed, resetting connection".to_string())).await; let _ = tx.send(UiEvent::Log("UDP channel closed, resetting connection".to_string())).await;
self.running = false; self.running = false;
crate::sysproxy::disable_windows_proxy(); crate::sysproxy::disable_system_proxy();
sessions_opt = None; *sessions_opt = None;
udp_rx_opt = None; *udp_rx_opt = None;
stream_map.clear(); stream_map.clear();
self.reset_proxy_streams(&tx, &proxy_tx, "udp reader closed"); self.reset_proxy_streams(&tx, &proxy_tx, "udp reader closed");
let _ = tx.send(UiEvent::TunnelStopped).await; let _ = tx.send(UiEvent::TunnelStopped).await;
} }
} }
} }
cmd = bridge_rx.recv() => {
async fn handle_bridge_cmd(
&mut self,
cmd: Option<BridgeCommand>,
sessions_opt: &mut Option<Vec<SessionState>>,
udp_rx_opt: &mut Option<mpsc::Receiver<(usize, Bytes)>>,
proxy_guard: &mut Option<crate::sysproxy::SystemProxyGuard>,
stream_map: &mut std::collections::HashMap<u16, usize>,
tx: &mpsc::Sender<UiEvent>,
proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>,
) -> bool {
match cmd { match cmd {
Some(BridgeCommand::ToggleTunnel) => { Some(BridgeCommand::ToggleTunnel) => {
if self.running { if self.running {
self.running = false; self.running = false;
self.metrics.connection_state.store(0, Ordering::Relaxed); self.metrics.connection_state.store(0, Ordering::Relaxed);
_proxy_guard = None; *proxy_guard = None;
sessions_opt = None; *sessions_opt = None;
udp_rx_opt = None; *udp_rx_opt = None;
stream_map.clear(); stream_map.clear();
self.reset_proxy_streams(&tx, &proxy_tx, "manual stop"); self.reset_proxy_streams(&tx, &proxy_tx, "manual stop");
tx.send(UiEvent::TunnelStopped).await.ok(); tx.send(UiEvent::TunnelStopped).await.ok();
@ -263,7 +320,7 @@ impl Bridge {
self.metrics.connection_state.store(1, Ordering::Relaxed); self.metrics.connection_state.store(1, Ordering::Relaxed);
let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 }; let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 };
let (udp_tx, udp_rx) = mpsc::channel(100000); // Increased for high-speed traffic stability let (udp_tx, udp_rx) = mpsc::channel(100000);
let mut sessions = Vec::with_capacity(session_count); let mut sessions = Vec::with_capacity(session_count);
let mut rtt_sum = 0.0; let mut rtt_sum = 0.0;
let mut successful_sessions = 0; let mut successful_sessions = 0;
@ -287,9 +344,6 @@ impl Bridge {
} }
} }
Err(e) => { Err(e) => {
// Under Windows/Winsock, transient UDP socket errors (like WSAECONNRESET) are returned
// as Err(ConnectionReset). We MUST NOT break the loop on transient errors, otherwise the
// download path will be permanently killed while the upload path keeps running.
tracing::warn!("UDP socket recv error (session {}): {}", session_index, e); tracing::warn!("UDP socket recv error (session {}): {}", session_index, e);
tokio::time::sleep(std::time::Duration::from_millis(10)).await; tokio::time::sleep(std::time::Duration::from_millis(10)).await;
} }
@ -308,22 +362,22 @@ impl Bridge {
} }
if sessions.is_empty() { if sessions.is_empty() {
_proxy_guard = None; *proxy_guard = None;
tx.send(UiEvent::Log("All multiplexed handshake attempts failed. Connection aborted.".to_string())).await.ok(); tx.send(UiEvent::Log("All multiplexed handshake attempts failed. Connection aborted.".to_string())).await.ok();
tx.send(UiEvent::TunnelStopped).await.ok(); tx.send(UiEvent::TunnelStopped).await.ok();
self.metrics.connection_state.store(0, Ordering::Relaxed); self.metrics.connection_state.store(0, Ordering::Relaxed);
continue; return true;
} }
udp_rx_opt = Some(udp_rx); *udp_rx_opt = Some(udp_rx);
sessions_opt = Some(sessions); *sessions_opt = Some(sessions);
self.last_rtt_ms = rtt_sum / successful_sessions as f64; self.last_rtt_ms = rtt_sum / successful_sessions as f64;
self.running = true; self.running = true;
self.last_sample_at = Instant::now(); self.last_sample_at = Instant::now();
self.last_valid_recv = Instant::now(); self.last_valid_recv = Instant::now();
let sys_proxy_addr = self.proxy_addr.replace("0.0.0.0:", "127.0.0.1:"); let sys_proxy_addr = self.proxy_addr.replace("0.0.0.0:", "127.0.0.1:");
_proxy_guard = Some(crate::sysproxy::SystemProxyGuard::enable(&sys_proxy_addr)); *proxy_guard = Some(crate::sysproxy::SystemProxyGuard::enable(&sys_proxy_addr));
tx.send(UiEvent::Metrics { tx.send(UiEvent::Metrics {
status: ConnectionStatus::Established, status: ConnectionStatus::Established,
@ -334,7 +388,6 @@ impl Bridge {
let start_msg = if self.mode == "tun" { "TUN tunnel established" } else { "Connection established" }; let start_msg = if self.mode == "tun" { "TUN tunnel established" } else { "Connection established" };
tx.send(UiEvent::Log(start_msg.to_string())).await.ok(); tx.send(UiEvent::Log(start_msg.to_string())).await.ok();
// Send an immediate Ping so the UI updates without a 60s delay
for session in sessions_opt.as_mut().unwrap().iter_mut() { for session in sessions_opt.as_mut().unwrap().iter_mut() {
let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64; let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64;
let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode()); let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode());
@ -352,11 +405,9 @@ impl Bridge {
} }
Some(BridgeCommand::NetworkChanged) => { Some(BridgeCommand::NetworkChanged) => {
if self.running { if self.running {
// Network changed (e.g. WiFi→LTE): IP address changed, existing UDP
// socket is dead. Trigger immediate reconnect without waiting for stall.
let _ = tx.send(UiEvent::Log("Network changed — starting immediate reconnect".to_string())).await; let _ = tx.send(UiEvent::Log("Network changed — starting immediate reconnect".to_string())).await;
self.metrics.connection_state.store(1, Ordering::Relaxed); self.metrics.connection_state.store(1, Ordering::Relaxed);
self.last_valid_recv = Instant::now() - Duration::from_secs(100); // force stall path self.last_valid_recv = Instant::now() - Duration::from_secs(100);
let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 }; let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 };
let (udp_tx, udp_rx) = mpsc::channel(100000); let (udp_tx, udp_rx) = mpsc::channel(100000);
@ -398,8 +449,8 @@ impl Bridge {
} }
if !new_sessions.is_empty() { if !new_sessions.is_empty() {
sessions_opt = Some(new_sessions); *sessions_opt = Some(new_sessions);
udp_rx_opt = Some(udp_rx); *udp_rx_opt = Some(udp_rx);
self.last_rtt_ms = rtt_sum / successful_sessions as f64; self.last_rtt_ms = rtt_sum / successful_sessions as f64;
self.last_valid_recv = Instant::now(); self.last_valid_recv = Instant::now();
stream_map.clear(); stream_map.clear();
@ -419,11 +470,10 @@ impl Bridge {
if self.running { if self.running {
self.running = false; self.running = false;
self.metrics.connection_state.store(0, Ordering::Relaxed); self.metrics.connection_state.store(0, Ordering::Relaxed);
_proxy_guard = None; *proxy_guard = None;
sessions_opt = None; *sessions_opt = None;
stream_map.clear(); stream_map.clear();
self.reset_proxy_streams(&tx, &proxy_tx, "config reload"); self.reset_proxy_streams(&tx, &proxy_tx, "config reload");
// User logic handles UI restart
let _ = tx.send(UiEvent::TunnelStopped).await; let _ = tx.send(UiEvent::TunnelStopped).await;
} }
} }
@ -434,36 +484,39 @@ impl Bridge {
} }
Some(BridgeCommand::Shutdown) | None => { Some(BridgeCommand::Shutdown) | None => {
self.running = false; self.running = false;
_proxy_guard = None; *proxy_guard = None;
break; return false;
} }
} }
true
} }
_ = metrics_tick.tick() => {
if self.running { async fn handle_keepalive(
self.emit_metrics(&tx).await; &mut self,
} sessions_opt: &mut Option<Vec<SessionState>>,
} udp_rx_opt: &mut Option<mpsc::Receiver<(usize, Bytes)>>,
_ = keepalive_tick.tick() => { proxy_guard: &mut Option<crate::sysproxy::SystemProxyGuard>,
if self.running { stream_map: &mut std::collections::HashMap<u16, usize>,
// 1. Connection Liveness Check & Silent Background Reconnect tx: &mpsc::Sender<UiEvent>,
proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>,
proxy_rx: &mut mpsc::Receiver<ProxyEvent>,
) {
if self.last_valid_recv.elapsed().as_secs() > 25 { if self.last_valid_recv.elapsed().as_secs() > 25 {
let elapsed = self.last_valid_recv.elapsed().as_secs(); let elapsed = self.last_valid_recv.elapsed().as_secs();
if elapsed > 180 { if elapsed > 180 {
// Hard timeout after 3 minutes of total silence
let _ = tx.send(UiEvent::Log("Connection permanently lost (3-minute hard timeout). Stopping tunnel.".into())).await; let _ = tx.send(UiEvent::Log("Connection permanently lost (3-minute hard timeout). Stopping tunnel.".into())).await;
self.running = false; self.running = false;
_proxy_guard = None; *proxy_guard = None;
sessions_opt = None; *sessions_opt = None;
stream_map.clear(); stream_map.clear();
self.reset_proxy_streams(&tx, &proxy_tx, "keepalive hard timeout"); self.reset_proxy_streams(&tx, &proxy_tx, "keepalive hard timeout");
let _ = tx.send(UiEvent::TunnelStopped).await; let _ = tx.send(UiEvent::TunnelStopped).await;
self.metrics.connection_state.store(0, Ordering::Relaxed); self.metrics.connection_state.store(0, Ordering::Relaxed);
continue; return;
} }
let _ = tx.send(UiEvent::Log(format!("Connection stall detected ({}s silence). Attempting background reconnect...", elapsed))).await; let _ = tx.send(UiEvent::Log(format!("Connection stall detected ({}s silence). Attempting background reconnect...", elapsed))).await;
self.metrics.connection_state.store(1, Ordering::Relaxed); // State: Connecting (Handshake) self.metrics.connection_state.store(1, Ordering::Relaxed);
let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 }; let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 };
let (udp_tx, udp_rx) = mpsc::channel(100000); let (udp_tx, udp_rx) = mpsc::channel(100000);
@ -508,14 +561,13 @@ impl Bridge {
} }
if !new_sessions.is_empty() { if !new_sessions.is_empty() {
sessions_opt = Some(new_sessions); *sessions_opt = Some(new_sessions);
udp_rx_opt = Some(udp_rx); *udp_rx_opt = Some(udp_rx);
self.last_rtt_ms = rtt_sum / successful_sessions as f64; self.last_rtt_ms = rtt_sum / successful_sessions as f64;
self.last_valid_recv = Instant::now(); self.last_valid_recv = Instant::now();
self.metrics.connection_state.store(2, Ordering::Relaxed); // State: Connected self.metrics.connection_state.store(2, Ordering::Relaxed);
let _ = tx.send(UiEvent::Log("Background reconnect successful! Connection restored.".into())).await; let _ = tx.send(UiEvent::Log("Background reconnect successful! Connection restored.".into())).await;
// Send an immediate Ping so the UI updates without a 60s delay
for session in sessions_opt.as_mut().unwrap().iter_mut() { for session in sessions_opt.as_mut().unwrap().iter_mut() {
let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64; let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64;
let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode()); let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode());
@ -525,15 +577,9 @@ impl Bridge {
} }
} }
// FIX: Clear existing proxy streams. Since we are on a NEW session_id,
// the server does not know about our existing streams. Closing them
// forces local apps/TUN to immediately recreate them and send proper
// Connect/UdpAssociate over the new session, avoiding a 5-minute blackhole.
stream_map.clear(); stream_map.clear();
self.reset_proxy_streams(&tx, &proxy_tx, "background reconnect"); self.reset_proxy_streams(&tx, &proxy_tx, "background reconnect");
// FIX: Flush all stale proxy messages accumulated during the stall/reconnect
// This prevents a massive post-reconnect UDP burst that causes mobile carriers to drop all packets
let mut flushed = 0; let mut flushed = 0;
while let Ok(stale) = proxy_rx.try_recv() { while let Ok(stale) = proxy_rx.try_recv() {
if let ProxyEvent::NewStream { stream_id, .. } = stale { if let ProxyEvent::NewStream { stream_id, .. } = stale {
@ -549,20 +595,15 @@ impl Bridge {
} }
} }
// 2. Active Keep-Alive / Heartbeat
if let Some(sessions) = sessions_opt.as_mut() { if let Some(sessions) = sessions_opt.as_mut() {
for session in sessions.iter_mut() { for session in sessions.iter_mut() {
// Send Ping (Internal RTT Metric)
let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64; let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64;
let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode()); let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode());
if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ping_payload)) { if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ping_payload)) {
// Must go through send_datagram() for TURN-mode wrapping;
// raw socket.send() bypasses the ChannelData header and breaks RTT in TURN.
let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await; let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await;
self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed);
} }
// Send Relay KeepAlive (Force NAT/Server Persistence)
let ka_payload = Bytes::from(RelayMessage::KeepAlive.encode()); let ka_payload = Bytes::from(RelayMessage::KeepAlive.encode());
if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ka_payload)) { if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ka_payload)) {
let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await; let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await;
@ -571,9 +612,16 @@ impl Bridge {
} }
} }
} }
}
_ = retransmit_tick.tick() => { async fn handle_retransmit(
if self.running { &mut self,
sessions_opt: &mut Option<Vec<SessionState>>,
udp_rx_opt: &mut Option<mpsc::Receiver<(usize, Bytes)>>,
proxy_guard: &mut Option<crate::sysproxy::SystemProxyGuard>,
stream_map: &mut std::collections::HashMap<u16, usize>,
tx: &mpsc::Sender<UiEvent>,
proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>,
) {
let mut fatal_err = None; let mut fatal_err = None;
if let Some(sessions) = sessions_opt.as_mut() { if let Some(sessions) = sessions_opt.as_mut() {
for session in sessions.iter_mut() { for session in sessions.iter_mut() {
@ -606,27 +654,31 @@ impl Bridge {
if let Some(e) = fatal_err { if let Some(e) = fatal_err {
let _ = tx.send(UiEvent::Log(format!("Protocol tick fatal error: {e}"))).await; let _ = tx.send(UiEvent::Log(format!("Protocol tick fatal error: {e}"))).await;
self.running = false; self.running = false;
_proxy_guard = None; *proxy_guard = None;
sessions_opt = None; *sessions_opt = None;
udp_rx_opt = None; *udp_rx_opt = None;
stream_map.clear(); stream_map.clear();
self.reset_proxy_streams(&tx, &proxy_tx, "protocol fatal error"); self.reset_proxy_streams(&tx, &proxy_tx, "protocol fatal error");
let _ = tx.send(UiEvent::TunnelStopped).await; let _ = tx.send(UiEvent::TunnelStopped).await;
self.metrics.connection_state.store(0, Ordering::Relaxed); self.metrics.connection_state.store(0, Ordering::Relaxed);
} }
} }
}
proxy_ev = proxy_rx.recv(), if self.running && sessions_opt.as_ref().map(|s| { async fn handle_proxy_event(
// Backpressure: suspend proxy reads when ARQ window is saturated across ALL sessions &mut self,
s.iter().any(|ses| ses.machine.in_flight_count() < ses.machine.cwnd_packets().clamp(16, 16384)) proxy_ev: Option<ProxyEvent>,
}).unwrap_or(true) => { sessions_opt: &mut Option<Vec<SessionState>>,
stream_map: &mut std::collections::HashMap<u16, usize>,
tx: &mpsc::Sender<UiEvent>,
proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>,
) {
if let Some(ev) = proxy_ev { if let Some(ev) = proxy_ev {
if let Some(sessions) = sessions_opt.as_mut() { if let Some(sessions) = sessions_opt.as_mut() {
if sessions.is_empty() { if sessions.is_empty() {
if let ProxyEvent::NewStream { stream_id, .. } = ev { if let ProxyEvent::NewStream { stream_id, .. } = ev {
let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("tunnel stopped".into()))); let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("tunnel stopped".into())));
} }
continue; return;
} }
let (stream_id, relay_msg, is_close) = match ev { let (stream_id, relay_msg, is_close) = match ev {
ProxyEvent::NewStream { stream_id, target } => { ProxyEvent::NewStream { stream_id, target } => {
@ -648,7 +700,6 @@ impl Bridge {
}; };
let len = sessions.len(); let len = sessions.len();
let session_index = *stream_map.entry(stream_id).or_insert_with(|| { let session_index = *stream_map.entry(stream_id).or_insert_with(|| {
// §8 FIX: Load balance multiplexed streams randomly across available connection sockets
rand::thread_rng().gen_range(0..len) rand::thread_rng().gen_range(0..len)
}); });
if is_close { if is_close {
@ -660,10 +711,7 @@ impl Bridge {
Ok(ProtocolAction::SendDatagram(frame)) => { Ok(ProtocolAction::SendDatagram(frame)) => {
if send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await.is_ok() { if send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await.is_ok() {
self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed); self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed);
tracing::trace!( tracing::trace!("Outbound datagram sent stream_id={stream_id} bytes={}", frame.len());
"Outbound datagram sent stream_id={stream_id} bytes={}",
frame.len()
);
} }
} }
Ok(ProtocolAction::Multiple(list)) => { Ok(ProtocolAction::Multiple(list)) => {
@ -676,19 +724,13 @@ impl Bridge {
} }
} }
} }
tracing::trace!( tracing::trace!("Outbound datagram batch stream_id={stream_id} sent={sent}");
"Outbound datagram batch stream_id={stream_id} sent={sent}"
);
} }
Ok(ProtocolAction::Noop) => { Ok(ProtocolAction::Noop) => {
tracing::trace!( tracing::trace!("Outbound datagram noop stream_id={stream_id}");
"Outbound datagram noop stream_id={stream_id}"
);
} }
Ok(_) => { Ok(_) => {
tracing::trace!( tracing::trace!("Outbound datagram unexpected action stream_id={stream_id}");
"Outbound datagram unexpected action stream_id={stream_id}"
);
} }
Err(e) => { Err(e) => {
tracing::warn!("Protocol error packing outbound stream_id={}: {}", stream_id, e); tracing::warn!("Protocol error packing outbound stream_id={}: {}", stream_id, e);
@ -696,7 +738,6 @@ impl Bridge {
} }
} }
} else { } else {
// Drop it, not connected
if let ProxyEvent::NewStream { stream_id, .. } = ev { if let ProxyEvent::NewStream { stream_id, .. } = ev {
let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("tunnel stopped".into()))); let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("tunnel stopped".into())));
} }
@ -705,13 +746,6 @@ impl Bridge {
} }
}
}
tx.send(UiEvent::Log("Bridge stopped".to_string())).await.ok();
Ok(())
}
fn reset_proxy_streams( fn reset_proxy_streams(
&self, &self,
tx: &mpsc::Sender<UiEvent>, tx: &mpsc::Sender<UiEvent>,

View File

@ -58,7 +58,7 @@ pub struct OstpConfig {
fn default_keepalive() -> u64 { 5 } fn default_keepalive() -> u64 { 5 }
fn default_mtu() -> usize { 1280 } fn default_mtu() -> usize { 1140 }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalProxyConfig { pub struct LocalProxyConfig {

View File

@ -82,7 +82,7 @@ pub fn enable_windows_proxy(proxy_addr: &str) {
} }
#[cfg(target_os = "windows")] #[cfg(target_os = "windows")]
pub fn disable_windows_proxy() { pub fn disable_system_proxy() {
tracing::info!("Disabling Windows system proxy"); tracing::info!("Disabling Windows system proxy");
let _ = Command::new("reg") let _ = Command::new("reg")
.creation_flags(CREATE_NO_WINDOW) .creation_flags(CREATE_NO_WINDOW)
@ -188,10 +188,6 @@ pub fn enable_system_proxy(proxy_addr: &str) {
enable_windows_proxy(proxy_addr); enable_windows_proxy(proxy_addr);
} }
#[cfg(target_os = "windows")]
pub fn disable_system_proxy() {
disable_windows_proxy();
}
pub struct SystemProxyGuard { pub struct SystemProxyGuard {
active: bool, active: bool,

View File

@ -0,0 +1,134 @@
use crate::config::ExclusionConfig;
use std::time::Duration;
use tokio::time::timeout;
#[derive(Clone)]
pub struct ExclusionMatcher {
pub domain_suffix: Vec<String>,
pub cidrs: Vec<Cidr>,
pub processes: Vec<String>,
pub physical_if_index: Option<u32>,
pub physical_if_name: Option<String>,
}
impl ExclusionMatcher {
pub fn new(
exclusions: &ExclusionConfig,
physical_if_index: Option<u32>,
physical_if_name: Option<String>,
) -> Self {
let mut cidrs = Vec::new();
for ip in &exclusions.ips {
if let Some(cidr) = parse_cidr(ip) {
cidrs.push(cidr);
}
}
let processes = exclusions.processes.iter()
.map(|p| p.trim().to_lowercase())
.filter(|p| !p.is_empty())
.collect();
Self {
domain_suffix: exclusions
.domains
.iter()
.map(|d| d.trim().trim_start_matches('.').to_lowercase())
.filter(|d| !d.is_empty())
.collect(),
cidrs,
processes,
physical_if_index,
physical_if_name,
}
}
pub async fn should_bypass_target(&self, host: &str, port: u16, timeout_value: Duration) -> bool {
if self.match_domain(host) {
return true;
}
if self.cidrs.is_empty() {
return false;
}
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
return self.match_ip(&ip);
}
let lookup_target = (host.to_string(), port);
match timeout(timeout_value, tokio::net::lookup_host(lookup_target)).await {
Ok(Ok(addrs)) => addrs.into_iter().any(|addr| self.match_ip(&addr.ip())),
_ => false,
}
}
pub fn match_domain(&self, host: &str) -> bool {
if self.domain_suffix.is_empty() {
return false;
}
let host = host.trim_end_matches('.').to_lowercase();
self.domain_suffix.iter().any(|suffix| {
host == *suffix || host.ends_with(&format!(".{suffix}"))
})
}
pub fn match_ip(&self, ip: &std::net::IpAddr) -> bool {
self.cidrs.iter().any(|cidr| cidr.contains(ip))
}
pub fn match_process(&self, process_name: &str) -> bool {
if self.processes.is_empty() {
return false;
}
let p = process_name.to_lowercase();
self.processes.iter().any(|ex| p.contains(ex))
}
}
#[derive(Clone)]
pub enum Cidr {
V4(u32, u8),
V6(u128, u8),
}
impl Cidr {
pub fn contains(&self, ip: &std::net::IpAddr) -> bool {
match (self, ip) {
(Cidr::V4(net, bits), std::net::IpAddr::V4(addr)) => {
let mask = if *bits == 0 { 0 } else { u32::MAX << (32 - bits) };
let ip = u32::from_be_bytes(addr.octets());
(ip & mask) == (*net & mask)
}
(Cidr::V6(net, bits), std::net::IpAddr::V6(addr)) => {
let mask = if *bits == 0 { 0 } else { u128::MAX << (128 - bits) };
let ip = u128::from_be_bytes(addr.octets());
(ip & mask) == (*net & mask)
}
_ => false,
}
}
}
pub fn parse_cidr(s: &str) -> Option<Cidr> {
let parts: Vec<&str> = s.split('/').collect();
if parts.is_empty() || parts.len() > 2 {
return None;
}
if let Ok(ip) = parts[0].parse::<std::net::IpAddr>() {
let bits = if parts.len() == 2 {
parts[1].parse::<u8>().ok()?
} else {
match ip {
std::net::IpAddr::V4(_) => 32,
std::net::IpAddr::V6(_) => 128,
}
};
match ip {
std::net::IpAddr::V4(v4) => Some(Cidr::V4(u32::from_be_bytes(v4.octets()), bits)),
std::net::IpAddr::V6(v6) => Some(Cidr::V6(u128::from_be_bytes(v6.octets()), bits)),
}
} else {
None
}
}

View File

@ -61,3 +61,7 @@ pub async fn run_local_proxy(
} }
pub mod exclusion;
pub mod process_lookup;
pub mod sni_sniff;

View File

@ -426,12 +426,63 @@ pub async fn run_native_tunnel_from_fd(
} }
}); });
let matcher = crate::tunnel::exclusion::ExclusionMatcher::new(&config.exclusions, None, None);
let mut tcp_accept_task = tokio::spawn(async move { let mut tcp_accept_task = tokio::spawn(async move {
if let Some(mut listener) = tcp_listener { if let Some(mut listener) = tcp_listener {
while let Some((mut stream, _local, remote)) = listener.next().await { while let Some((mut stream, local, remote)) = listener.next().await {
let proxy_addr = proxy_addr.clone(); let proxy_addr = proxy_addr.clone();
let matcher = matcher.clone();
tokio::spawn(async move { tokio::spawn(async move {
if debug { tracing::info!("Native TUN intercepted TCP to {}", remote); } if debug { tracing::info!("Native TUN intercepted TCP {local} -> {remote}"); }
// Peak first chunk to see SNI
let mut sniff_buf = [0u8; 1500];
let sniff_len = match tokio::time::timeout(std::time::Duration::from_millis(50), stream.read(&mut sniff_buf)).await {
Ok(Ok(n)) => n,
_ => 0, // Timeout or error
};
let mut should_bypass = false;
// 1. Check SNI
if sniff_len > 0 {
if let Some(sni) = crate::tunnel::sni_sniff::extract_sni(&sniff_buf[..sniff_len]) {
if debug { tracing::info!("Native TUN sniffed SNI: {}", sni); }
if matcher.match_domain(&sni) {
should_bypass = true;
}
}
}
// 2. Check Process
if !should_bypass {
if let Some(exe) = crate::tunnel::process_lookup::get_process_name_from_port(local.port()) {
if debug { tracing::info!("Native TUN source port {} maps to EXE: {}", local.port(), exe); }
if matcher.match_process(&exe) {
should_bypass = true;
}
}
}
// 3. Check Target IP
if !should_bypass {
if matcher.match_ip(&remote.ip()) {
should_bypass = true;
}
}
if should_bypass {
if debug { tracing::info!("Native TUN BYPASS matched for {}", remote); }
if let Ok(mut direct) = tokio::time::timeout(std::time::Duration::from_secs(5), tokio::net::TcpStream::connect(remote)).await.unwrap_or(Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "Direct connect timeout"))) {
if sniff_len > 0 {
let _ = direct.write_all(&sniff_buf[..sniff_len]).await;
}
let _ = tokio::io::copy_bidirectional(&mut stream, &mut direct).await;
}
return;
}
if let Ok(mut socks) = tokio::net::TcpStream::connect(&proxy_addr).await { if let Ok(mut socks) = tokio::net::TcpStream::connect(&proxy_addr).await {
if socks.write_all(&[5, 1, 0]).await.is_err() { return; } if socks.write_all(&[5, 1, 0]).await.is_err() { return; }
let mut buf = [0u8; 2]; let mut buf = [0u8; 2];
@ -456,6 +507,11 @@ pub async fn run_native_tunnel_from_fd(
let mut rep = [0u8; 10]; let mut rep = [0u8; 10];
if socks.read_exact(&mut rep).await.is_err() || rep[1] != 0 { return; } if socks.read_exact(&mut rep).await.is_err() || rep[1] != 0 { return; }
// Write sniffed buffer to socks
if sniff_len > 0 {
if socks.write_all(&sniff_buf[..sniff_len]).await.is_err() { return; }
}
let _ = tokio::io::copy_bidirectional(&mut stream, &mut socks).await; let _ = tokio::io::copy_bidirectional(&mut stream, &mut socks).await;
} }
}); });

View File

@ -0,0 +1,142 @@
#[cfg(target_os = "windows")]
pub fn get_process_name_from_port(port: u16) -> Option<String> {
use winapi::shared::minwindef::{DWORD, ULONG};
use winapi::shared::winerror::ERROR_INSUFFICIENT_BUFFER;
use winapi::um::iphlpapi::GetExtendedTcpTable;
use winapi::shared::tcpmib::{MIB_TCPTABLE_OWNER_PID, MIB_TCPROW_OWNER_PID};
let mut size: ULONG = 0;
let table_class = 5; // TCP_TABLE_OWNER_PID_ALL
let mut table = vec![0u8; 1024];
unsafe {
let mut ret = GetExtendedTcpTable(
table.as_mut_ptr() as *mut _,
&mut size,
0,
2, // AF_INET
table_class,
0,
);
if ret == ERROR_INSUFFICIENT_BUFFER {
table.resize(size as usize, 0);
ret = GetExtendedTcpTable(
table.as_mut_ptr() as *mut _,
&mut size,
0,
2, // AF_INET
table_class,
0,
);
}
if ret == 0 {
let tcp_table = &*(table.as_ptr() as *const MIB_TCPTABLE_OWNER_PID);
let row_ptr = &tcp_table.table[0] as *const MIB_TCPROW_OWNER_PID;
for i in 0..tcp_table.dwNumEntries {
let row = &*row_ptr.add(i as usize);
// Local port is in network byte order
let local_port = u16::from_be(row.dwLocalPort as u16);
if local_port == port {
return get_process_name_from_pid(row.dwOwningPid);
}
}
}
}
None
}
#[cfg(target_os = "windows")]
fn get_process_name_from_pid(pid: u32) -> Option<String> {
use winapi::um::processthreadsapi::OpenProcess;
use winapi::um::psapi::GetModuleBaseNameW;
use winapi::um::winnt::{PROCESS_QUERY_INFORMATION, PROCESS_VM_READ};
use winapi::um::handleapi::CloseHandle;
use std::os::windows::ffi::OsStringExt;
unsafe {
let handle = OpenProcess(PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, 0, pid);
if handle.is_null() {
return None;
}
let mut buffer = [0u16; 1024];
let len = GetModuleBaseNameW(handle, std::ptr::null_mut(), buffer.as_mut_ptr(), buffer.len() as u32);
CloseHandle(handle);
if len > 0 {
let name = std::ffi::OsString::from_wide(&buffer[..len as usize]);
return Some(name.to_string_lossy().into_owned());
}
}
None
}
#[cfg(target_os = "linux")]
pub fn get_process_name_from_port(port: u16) -> Option<String> {
use std::fs;
use std::io::{BufRead, BufReader};
let mut target_inode = None;
let hex_port = format!("{:04X}", port);
let check_net_file = |path: &str| -> Option<u64> {
let file = fs::File::open(path).ok()?;
let reader = BufReader::new(file);
for line in reader.lines().skip(1).filter_map(Result::ok) {
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() >= 10 {
let local_addr = parts[1];
if local_addr.ends_with(&format!(":{}", hex_port)) {
if let Ok(inode) = parts[9].parse::<u64>() {
return Some(inode);
}
}
}
}
None
};
target_inode = check_net_file("/proc/net/tcp")
.or_else(|| check_net_file("/proc/net/tcp6"))
.or_else(|| check_net_file("/proc/net/udp"))
.or_else(|| check_net_file("/proc/net/udp6"));
let target_inode = target_inode?;
let socket_str = format!("socket:[{}]", target_inode);
for entry in fs::read_dir("/proc").ok()?.filter_map(Result::ok) {
let file_name = entry.file_name();
let pid_str = file_name.to_string_lossy();
if !pid_str.chars().all(char::is_numeric) {
continue;
}
let fd_dir = entry.path().join("fd");
if let Ok(fd_entries) = fs::read_dir(fd_dir) {
for fd_entry in fd_entries.filter_map(Result::ok) {
if let Ok(target) = fs::read_link(fd_entry.path()) {
if target.to_string_lossy() == socket_str {
let exe_path = entry.path().join("exe");
if let Ok(exe_link) = fs::read_link(exe_path) {
if let Some(name) = exe_link.file_name() {
return Some(name.to_string_lossy().into_owned());
}
}
if let Ok(comm) = fs::read_to_string(entry.path().join("comm")) {
return Some(comm.trim().to_string());
}
}
}
}
}
}
None
}
#[cfg(not(any(target_os = "windows", target_os = "linux")))]
pub fn get_process_name_from_port(_port: u16) -> Option<String> {
None
}

View File

@ -1,4 +1,5 @@
use std::collections::HashMap; use std::collections::HashMap;
use crate::tunnel::exclusion::{ExclusionMatcher, Cidr};
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, Context, Result};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream, UdpSocket}; use tokio::net::{TcpListener, TcpStream, UdpSocket};
@ -421,8 +422,10 @@ async fn handle_udp_associate(
}; };
let payload = bytes::Bytes::copy_from_slice(&buf[header_len..len]); let payload = bytes::Bytes::copy_from_slice(&buf[header_len..len]);
let target_host = if let Some((host, _)) = split_host_port(&target) { host } else { target.clone() };
let target_port = match split_host_port(&target) { Some((_, p)) => p, None => 0 };
// Check if target should bypass the tunnel // Check if target should bypass the tunnel
if matcher.should_bypass(&target, connect_timeout).await { if matcher.should_bypass_target(&target_host, target_port, connect_timeout).await {
if debug { if debug {
tracing::info!("proxy UDP BYPASS target={}", target); tracing::info!("proxy UDP BYPASS target={}", target);
} }
@ -668,7 +671,9 @@ async fn handle_proxy_client(
if debug { if debug {
tracing::info!("proxy CONNECT stream_id={stream_id} target={target}"); tracing::info!("proxy CONNECT stream_id={stream_id} target={target}");
} }
if matcher.should_bypass(&target, connect_timeout).await { let target_host = if let Some((host, _)) = split_host_port(&target) { host } else { target.clone() };
let target_port = match split_host_port(&target) { Some((_, p)) => p, None => 0 };
if matcher.should_bypass_target(&target_host, target_port, connect_timeout).await {
return direct_connect_socks5( return direct_connect_socks5(
client, client,
stream_id, stream_id,
@ -753,7 +758,9 @@ async fn handle_proxy_client(
if debug { if debug {
tracing::info!("proxy CONNECT stream_id={stream_id} target={target}"); tracing::info!("proxy CONNECT stream_id={stream_id} target={target}");
} }
if matcher.should_bypass(&target, connect_timeout).await { let target_host = if let Some((host, _)) = split_host_port(&target) { host } else { target.clone() };
let target_port = match split_host_port(&target) { Some((_, p)) => p, None => 443 };
if matcher.should_bypass_target(&target_host, target_port, connect_timeout).await {
return direct_connect_http( return direct_connect_http(
client, client,
stream_id, stream_id,
@ -854,129 +861,6 @@ async fn handle_proxy_client(
Ok(()) Ok(())
} }
#[derive(Clone)]
struct ExclusionMatcher {
domain_suffix: Vec<String>,
cidrs: Vec<Cidr>,
physical_if_index: Option<u32>,
physical_if_name: Option<String>,
}
impl ExclusionMatcher {
fn new(
exclusions: &ExclusionConfig,
physical_if_index: Option<u32>,
physical_if_name: Option<String>,
) -> Self {
let mut cidrs = Vec::new();
for ip in &exclusions.ips {
if let Some(cidr) = parse_cidr(ip) {
cidrs.push(cidr);
}
}
Self {
domain_suffix: exclusions
.domains
.iter()
.map(|d| d.trim().trim_start_matches('.').to_lowercase())
.filter(|d| !d.is_empty())
.collect(),
cidrs,
physical_if_index,
physical_if_name,
}
}
async fn should_bypass(&self, target: &str, timeout_value: Duration) -> bool {
let (host, port) = match split_host_port(target) {
Some(v) => v,
None => return false,
};
if self.match_domain(&host) {
return true;
}
if self.cidrs.is_empty() {
return false;
}
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
return self.match_ip(&ip);
}
let lookup_target = (host.clone(), port);
match timeout(timeout_value, tokio::net::lookup_host(lookup_target)).await {
Ok(Ok(addrs)) => addrs.into_iter().any(|addr| self.match_ip(&addr.ip())),
_ => false,
}
}
fn match_domain(&self, host: &str) -> bool {
if self.domain_suffix.is_empty() {
return false;
}
let host = host.trim_end_matches('.').to_lowercase();
self.domain_suffix.iter().any(|suffix| {
host == *suffix || host.ends_with(&format!(".{suffix}"))
})
}
fn match_ip(&self, ip: &std::net::IpAddr) -> bool {
self.cidrs.iter().any(|cidr| cidr.contains(ip))
}
}
#[derive(Clone)]
enum Cidr {
V4(u32, u8),
V6(u128, u8),
}
impl Cidr {
fn contains(&self, ip: &std::net::IpAddr) -> bool {
match (self, ip) {
(Cidr::V4(net, bits), std::net::IpAddr::V4(addr)) => {
let mask = if *bits == 0 { 0 } else { u32::MAX << (32 - bits) };
let ip = u32::from_be_bytes(addr.octets());
(ip & mask) == (*net & mask)
}
(Cidr::V6(net, bits), std::net::IpAddr::V6(addr)) => {
let mask = if *bits == 0 { 0 } else { u128::MAX << (128 - bits) };
let ip = u128::from_be_bytes(addr.octets());
(ip & mask) == (*net & mask)
}
_ => false,
}
}
}
fn parse_cidr(value: &str) -> Option<Cidr> {
let value = value.trim();
if value.is_empty() {
return None;
}
if let Some((addr_str, bits_str)) = value.split_once('/') {
let bits: u8 = bits_str.parse().ok()?;
if let Ok(addr) = addr_str.parse::<std::net::IpAddr>() {
return match addr {
std::net::IpAddr::V4(v4) => Some(Cidr::V4(u32::from_be_bytes(v4.octets()), bits.min(32))),
std::net::IpAddr::V6(v6) => Some(Cidr::V6(u128::from_be_bytes(v6.octets()), bits.min(128))),
};
}
}
if let Ok(addr) = value.parse::<std::net::IpAddr>() {
return match addr {
std::net::IpAddr::V4(v4) => Some(Cidr::V4(u32::from_be_bytes(v4.octets()), 32)),
std::net::IpAddr::V6(v6) => Some(Cidr::V6(u128::from_be_bytes(v6.octets()), 128)),
};
}
None
}
fn split_host_port(target: &str) -> Option<(String, u16)> { fn split_host_port(target: &str) -> Option<(String, u16)> {
if let Some((host, port)) = target.rsplit_once(':') { if let Some((host, port)) = target.rsplit_once(':') {

View File

@ -0,0 +1,73 @@
pub fn extract_sni(data: &[u8]) -> Option<String> {
// Basic TLS ClientHello parser
// Must be at least 43 bytes to contain anything useful
if data.len() < 43 {
return None;
}
// TLS Record layer: Handshake (22)
if data[0] != 0x16 {
return None;
}
// Record layer version: 0x0301 (TLS 1.0) or 0x0303 (TLS 1.2)
if data[1] != 0x03 {
return None;
}
// Handshake type: ClientHello (1)
if data[5] != 0x01 {
return None;
}
let mut pos = 43; // Skip fixed ClientHello header
// Skip Session ID
if pos >= data.len() { return None; }
let session_id_len = data[pos] as usize;
pos += 1 + session_id_len;
// Skip Cipher Suites
if pos + 2 > data.len() { return None; }
let cipher_suites_len = ((data[pos] as usize) << 8) | (data[pos + 1] as usize);
pos += 2 + cipher_suites_len;
// Skip Compression Methods
if pos >= data.len() { return None; }
let comp_methods_len = data[pos] as usize;
pos += 1 + comp_methods_len;
// Extensions
if pos + 2 > data.len() { return None; }
let extensions_len = ((data[pos] as usize) << 8) | (data[pos + 1] as usize);
pos += 2;
let extensions_end = pos + extensions_len;
if extensions_end > data.len() { return None; }
while pos + 4 <= extensions_end {
let ext_type = ((data[pos] as usize) << 8) | (data[pos + 1] as usize);
let ext_len = ((data[pos + 2] as usize) << 8) | (data[pos + 3] as usize);
pos += 4;
if ext_type == 0x0000 { // Server Name Indication (SNI)
if pos + 5 <= extensions_end {
let list_len = ((data[pos] as usize) << 8) | (data[pos + 1] as usize);
let name_type = data[pos + 2];
if name_type == 0 { // Hostname
let name_len = ((data[pos + 3] as usize) << 8) | (data[pos + 4] as usize);
if pos + 5 + name_len <= extensions_end {
let sni_bytes = &data[pos + 5..pos + 5 + name_len];
if let Ok(sni) = std::str::from_utf8(sni_bytes) {
return Some(sni.to_string());
}
}
}
}
break;
}
pos += ext_len;
}
None
}

View File

@ -5,7 +5,6 @@
//! This replaces the fixed `retransmit_budget = 8` with an adaptive //! This replaces the fixed `retransmit_budget = 8` with an adaptive
//! congestion window that responds to network conditions. //! congestion window that responds to network conditions.
use std::collections::VecDeque;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
/// Congestion control state for a single OSTP session. /// Congestion control state for a single OSTP session.
@ -16,14 +15,8 @@ pub struct CongestionController {
ssthresh: u64, ssthresh: u64,
/// Current phase /// Current phase
phase: Phase, phase: Phase,
/// Minimum RTT observed (used for BDP calculation) /// Minimum RTT observed
min_rtt: Duration, min_rtt: Duration,
/// Maximum bandwidth observed (bytes/sec)
max_bandwidth: u64,
/// RTT samples for smoothing
rtt_samples: VecDeque<RttSample>,
/// Bandwidth samples
bw_samples: VecDeque<BwSample>,
/// Bytes currently in flight (unacknowledged) /// Bytes currently in flight (unacknowledged)
bytes_in_flight: u64, bytes_in_flight: u64,
/// Total bytes acknowledged (for bandwidth estimation) /// Total bytes acknowledged (for bandwidth estimation)
@ -36,8 +29,6 @@ pub struct CongestionController {
pacing_rate: u64, pacing_rate: u64,
/// MTU estimate (used for cwnd → packet count conversion) /// MTU estimate (used for cwnd → packet count conversion)
mtu: u64, mtu: u64,
/// Probe RTT phase timer
probe_rtt_timer: Option<Instant>,
/// Min RTT expiry: re-probe after 10 seconds /// Min RTT expiry: re-probe after 10 seconds
min_rtt_stamp: Instant, min_rtt_stamp: Instant,
} }
@ -48,35 +39,14 @@ enum Phase {
SlowStart, SlowStart,
/// Probe bandwidth: cycle through pacing gains /// Probe bandwidth: cycle through pacing gains
ProbeBandwidth, ProbeBandwidth,
/// Periodically drain the queue to measure true min RTT
#[allow(dead_code)]
ProbeRtt,
} }
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct RttSample {
rtt: Duration,
time: Instant,
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct BwSample {
bytes_per_sec: u64,
time: Instant,
}
/// Maximum number of samples to keep for windowed min/max
const MAX_SAMPLES: usize = 32;
/// Initial congestion window: 10 packets × MTU /// Initial congestion window: 10 packets × MTU
const INITIAL_CWND_PACKETS: u64 = 10; const INITIAL_CWND_PACKETS: u64 = 10;
/// Minimum cwnd: 2 packets /// Minimum cwnd: 2 packets
const MIN_CWND_PACKETS: u64 = 2; const MIN_CWND_PACKETS: u64 = 2;
/// Min RTT expiry window (after which we re-probe) /// Min RTT expiry window (after which we re-probe)
const MIN_RTT_EXPIRY: Duration = Duration::from_secs(10); const MIN_RTT_EXPIRY: Duration = Duration::from_secs(10);
/// ProbeRTT drain duration
const PROBE_RTT_DURATION: Duration = Duration::from_millis(200);
impl CongestionController { impl CongestionController {
pub fn new(mtu: u64) -> Self { pub fn new(mtu: u64) -> Self {
@ -87,16 +57,12 @@ impl CongestionController {
ssthresh: u64::MAX, ssthresh: u64::MAX,
phase: Phase::SlowStart, phase: Phase::SlowStart,
min_rtt: Duration::from_millis(100), // Conservative initial estimate min_rtt: Duration::from_millis(100), // Conservative initial estimate
max_bandwidth: 0,
rtt_samples: VecDeque::with_capacity(MAX_SAMPLES),
bw_samples: VecDeque::with_capacity(MAX_SAMPLES),
bytes_in_flight: 0, bytes_in_flight: 0,
total_acked: 0, total_acked: 0,
last_ack_time: now, last_ack_time: now,
loss_count: 0, loss_count: 0,
pacing_rate: initial_cwnd * 10, // initial: ~10 windows/sec pacing_rate: initial_cwnd * 10, // initial: ~10 windows/sec
mtu, mtu,
probe_rtt_timer: None,
min_rtt_stamp: now, min_rtt_stamp: now,
} }
} }
@ -169,29 +135,7 @@ impl CongestionController {
// TCP Reno Additive Increase: increase cwnd by ~1 MTU per RTT // TCP Reno Additive Increase: increase cwnd by ~1 MTU per RTT
self.cwnd = self.cwnd.saturating_add(bytes * self.mtu / self.cwnd.max(1)); self.cwnd = self.cwnd.saturating_add(bytes * self.mtu / self.cwnd.max(1));
} }
Phase::ProbeRtt => {
// Drain down to 4 packets to measure true min RTT
self.cwnd = MIN_CWND_PACKETS * self.mtu * 2;
if let Some(timer) = self.probe_rtt_timer {
if now.duration_since(timer) >= PROBE_RTT_DURATION {
// ProbeRTT complete, return to ProbeBandwidth
self.phase = Phase::ProbeBandwidth;
self.probe_rtt_timer = None;
self.cwnd = (MIN_CWND_PACKETS * self.mtu * 4).max(self.cwnd);
tracing::debug!(cwnd = self.cwnd, min_rtt = ?self.min_rtt, "congestion: probe RTT complete");
} }
}
}
}
/*
// Periodically enter ProbeRTT to refresh min_rtt
if now.duration_since(self.min_rtt_stamp) >= MIN_RTT_EXPIRY && self.phase != Phase::ProbeRtt {
self.phase = Phase::ProbeRtt;
self.probe_rtt_timer = Some(now);
tracing::debug!("congestion: entering probe RTT phase");
}
*/
self.update_pacing_rate(); self.update_pacing_rate();
self.last_ack_time = now; self.last_ack_time = now;
@ -215,9 +159,6 @@ impl CongestionController {
self.cwnd = (self.cwnd * 7 / 10).max(MIN_CWND_PACKETS * self.mtu); self.cwnd = (self.cwnd * 7 / 10).max(MIN_CWND_PACKETS * self.mtu);
tracing::debug!(cwnd = self.cwnd, "congestion: loss, cwnd reduced"); tracing::debug!(cwnd = self.cwnd, "congestion: loss, cwnd reduced");
} }
Phase::ProbeRtt => {
// Don't react to loss during ProbeRTT
}
} }
self.update_pacing_rate(); self.update_pacing_rate();
@ -236,40 +177,16 @@ impl CongestionController {
self.min_rtt = rtt; self.min_rtt = rtt;
self.min_rtt_stamp = now; self.min_rtt_stamp = now;
} }
// Keep sample history
self.rtt_samples.push_back(RttSample { rtt, time: now });
while self.rtt_samples.len() > MAX_SAMPLES {
self.rtt_samples.pop_front();
}
} }
fn update_bandwidth(&mut self, acked_bytes: u64, now: Instant) { fn update_bandwidth(&mut self, _acked_bytes: u64, now: Instant) {
let elapsed = now.duration_since(self.last_ack_time); let elapsed = now.duration_since(self.last_ack_time);
if elapsed.as_micros() > 0 { if elapsed.as_micros() > 0 {
let bw = acked_bytes * 1_000_000 / elapsed.as_micros() as u64; // Removed bw_samples tracking
if bw > self.max_bandwidth {
self.max_bandwidth = bw;
}
self.bw_samples.push_back(BwSample { bytes_per_sec: bw, time: now });
while self.bw_samples.len() > MAX_SAMPLES {
self.bw_samples.pop_front();
}
} }
} }
#[allow(dead_code)]
fn bandwidth_delay_product(&self) -> u64 {
// BDP = max_bandwidth * min_rtt
let bw = if self.max_bandwidth > 0 {
self.max_bandwidth
} else {
// Fallback: assume 10 Mbps
1_250_000
};
let rtt_secs = self.min_rtt.as_secs_f64();
(bw as f64 * rtt_secs) as u64
}
fn update_pacing_rate(&mut self) { fn update_pacing_rate(&mut self) {
// Pacing rate = cwnd / min_rtt (with gain) // Pacing rate = cwnd / min_rtt (with gain)

View File

@ -290,7 +290,7 @@ impl ProtocolMachine {
if raw_vec.len() < 12 { if raw_vec.len() < 12 {
return Err(ProtocolError::Framing("data datagram too short".to_string())); return Err(ProtocolError::Framing("data datagram too short".to_string()));
} }
let nonce = u64::from_be_bytes(raw_vec[4..12].try_into().unwrap()); let nonce = u64::from_be_bytes(raw_vec[4..12].try_into().map_err(|_| ProtocolError::Framing("data datagram too short for nonce".into()))?);
if nonce < self.expected_recv_nonce { if nonce < self.expected_recv_nonce {
// Duplicate — the ACK we sent was likely lost or delayed. // Duplicate — the ACK we sent was likely lost or delayed.
@ -330,7 +330,7 @@ impl ProtocolMachine {
// Fast path processing for Nacks: act immediately, bypass sequence queue // Fast path processing for Nacks: act immediately, bypass sequence queue
if packet.header.kind == FrameKind::Nack if packet.header.kind == FrameKind::Nack
&& packet.payload.len() >= 8 { && packet.payload.len() >= 8 {
let req_nonce = u64::from_be_bytes(packet.payload[..8].try_into().unwrap()); let req_nonce = u64::from_be_bytes(packet.payload[..8].try_into().map_err(|_| ProtocolError::Framing("nack payload too short".into()))?);
if let Some(cached_frame) = self.lookup_sent_frame(req_nonce) { if let Some(cached_frame) = self.lookup_sent_frame(req_nonce) {
tracing::debug!("NACK received: retransmitting nonce={}", req_nonce); tracing::debug!("NACK received: retransmitting nonce={}", req_nonce);
self.cc.on_loss(cached_frame.len() as u64); self.cc.on_loss(cached_frame.len() as u64);
@ -733,8 +733,8 @@ fn parse_ack_ranges(payload: &[u8]) -> Result<Vec<(u64, u64)>, ProtocolError> {
let mut ranges = Vec::with_capacity(count); let mut ranges = Vec::with_capacity(count);
let mut idx = 1; let mut idx = 1;
for _ in 0..count { for _ in 0..count {
let start = u64::from_be_bytes(payload[idx..idx + 8].try_into().unwrap()); let start = u64::from_be_bytes(payload[idx..idx + 8].try_into().map_err(|_| ProtocolError::Framing("ack range start invalid".into()))?);
let end = u64::from_be_bytes(payload[idx + 8..idx + 16].try_into().unwrap()); let end = u64::from_be_bytes(payload[idx + 8..idx + 16].try_into().map_err(|_| ProtocolError::Framing("ack range end invalid".into()))?);
ranges.push((start, end)); ranges.push((start, end));
idx += 16; idx += 16;
} }

View File

@ -64,7 +64,7 @@ impl RelayMessage {
7 => { 7 => {
let payload = decode_with_len(&input[1..])?; let payload = decode_with_len(&input[1..])?;
if payload.len() != 8 { return Err(anyhow!("invalid ping payload len")); } if payload.len() != 8 { return Err(anyhow!("invalid ping payload len")); }
let ts = u64::from_be_bytes(payload.try_into().unwrap()); let ts = u64::from_be_bytes(payload.try_into().map_err(|_| anyhow!("invalid ping payload size"))?);
Ok(RelayMessage::Ping(ts)) Ok(RelayMessage::Ping(ts))
} }
8 => { 8 => {

View File

@ -0,0 +1,3 @@
description: This file stores settings for Dart & Flutter DevTools.
documentation: https://docs.flutter.dev/tools/devtools/extensions#configure-extension-enablement-states
extensions:

View File

@ -96,6 +96,14 @@ packages:
description: flutter description: flutter
source: sdk source: sdk
version: "0.0.0" version: "0.0.0"
json_annotation:
dependency: transitive
description:
name: json_annotation
sha256: "2a743920d81b7910627f68ee2c9ac1fc0bfee32b9fc3403587d7c6791ca12f80"
url: "https://pub.dev"
source: hosted
version: "4.12.0"
leak_tracker: leak_tracker:
dependency: transitive dependency: transitive
description: description:
@ -144,6 +152,14 @@ packages:
url: "https://pub.dev" url: "https://pub.dev"
source: hosted source: hosted
version: "0.13.0" version: "0.13.0"
menu_base:
dependency: transitive
description:
name: menu_base
sha256: "820368014a171bd1241030278e6c2617354f492f5c703d7b7d4570a6b8b84405"
url: "https://pub.dev"
source: hosted
version: "0.1.1"
meta: meta:
dependency: transitive dependency: transitive
description: description:
@ -208,6 +224,46 @@ packages:
url: "https://pub.dev" url: "https://pub.dev"
source: hosted source: hosted
version: "2.1.8" version: "2.1.8"
screen_retriever:
dependency: transitive
description:
name: screen_retriever
sha256: "570dbc8e4f70bac451e0efc9c9bb19fa2d6799a11e6ef04f946d7886d2e23d0c"
url: "https://pub.dev"
source: hosted
version: "0.2.0"
screen_retriever_linux:
dependency: transitive
description:
name: screen_retriever_linux
sha256: f7f8120c92ef0784e58491ab664d01efda79a922b025ff286e29aa123ea3dd18
url: "https://pub.dev"
source: hosted
version: "0.2.0"
screen_retriever_macos:
dependency: transitive
description:
name: screen_retriever_macos
sha256: "71f956e65c97315dd661d71f828708bd97b6d358e776f1a30d5aa7d22d78a149"
url: "https://pub.dev"
source: hosted
version: "0.2.0"
screen_retriever_platform_interface:
dependency: transitive
description:
name: screen_retriever_platform_interface
sha256: ee197f4581ff0d5608587819af40490748e1e39e648d7680ecf95c05197240c0
url: "https://pub.dev"
source: hosted
version: "0.2.0"
screen_retriever_windows:
dependency: transitive
description:
name: screen_retriever_windows
sha256: "449ee257f03ca98a57288ee526a301a430a344a161f9202b4fcc38576716fe13"
url: "https://pub.dev"
source: hosted
version: "0.2.0"
shared_preferences: shared_preferences:
dependency: "direct main" dependency: "direct main"
description: description:
@ -264,6 +320,14 @@ packages:
url: "https://pub.dev" url: "https://pub.dev"
source: hosted source: hosted
version: "2.4.1" version: "2.4.1"
shortid:
dependency: transitive
description:
name: shortid
sha256: d0b40e3dbb50497dad107e19c54ca7de0d1a274eb9b4404991e443dadb9ebedb
url: "https://pub.dev"
source: hosted
version: "0.1.2"
sky_engine: sky_engine:
dependency: transitive dependency: transitive
description: flutter description: flutter
@ -317,6 +381,14 @@ packages:
url: "https://pub.dev" url: "https://pub.dev"
source: hosted source: hosted
version: "0.7.10" version: "0.7.10"
tray_manager:
dependency: "direct main"
description:
name: tray_manager
sha256: c5fd83b0ae4d80be6eaedfad87aaefab8787b333b8ebd064b0e442a81006035b
url: "https://pub.dev"
source: hosted
version: "0.5.2"
vector_math: vector_math:
dependency: transitive dependency: transitive
description: description:
@ -341,6 +413,14 @@ packages:
url: "https://pub.dev" url: "https://pub.dev"
source: hosted source: hosted
version: "1.1.1" version: "1.1.1"
window_manager:
dependency: "direct main"
description:
name: window_manager
sha256: "7eb6d6c4164ec08e1bf978d6e733f3cebe792e2a23fb07cbca25c2872bfdbdcd"
url: "https://pub.dev"
source: hosted
version: "0.5.1"
xdg_directories: xdg_directories:
dependency: transitive dependency: transitive
description: description:

View File

@ -36,6 +36,8 @@ dependencies:
cupertino_icons: ^1.0.8 cupertino_icons: ^1.0.8
shared_preferences: ^2.5.5 shared_preferences: ^2.5.5
mobile_scanner: ^5.0.0 mobile_scanner: ^5.0.0
window_manager: ^0.5.1
tray_manager: ^0.5.2
dev_dependencies: dev_dependencies:
flutter_test: flutter_test:

View File

@ -2641,7 +2641,7 @@ dependencies = [
[[package]] [[package]]
name = "ostp-client" name = "ostp-client"
version = "0.2.73" version = "0.2.79"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"base64 0.22.1", "base64 0.22.1",
@ -2666,12 +2666,13 @@ dependencies = [
"tracing", "tracing",
"tun", "tun",
"webpki-roots 0.26.11", "webpki-roots 0.26.11",
"winapi",
"x25519-dalek", "x25519-dalek",
] ]
[[package]] [[package]]
name = "ostp-core" name = "ostp-core"
version = "0.2.73" version = "0.2.79"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bytes", "bytes",

View File

@ -106,6 +106,7 @@ struct HelperState {
pipe_state: Arc<Mutex<HelperPipeState>>, pipe_state: Arc<Mutex<HelperPipeState>>,
cmd_tx: tokio::sync::mpsc::Sender<String>, cmd_tx: tokio::sync::mpsc::Sender<String>,
token: String, token: String,
port: u16,
} }
enum TunnelHandle { enum TunnelHandle {
@ -282,6 +283,37 @@ async fn get_metrics(state: tauri::State<'_, AppState>) -> Result<Option<UIMetri
} }
} }
#[tauri::command]
async fn reload_tunnel(state: tauri::State<'_, AppState>) -> Result<bool, String> {
let mut guard = state.0.lock().await;
if guard.tunnel.is_none() {
return Ok(false);
}
let config_path = get_config_path();
let config_str = std::fs::read_to_string(&config_path)
.map_err(|e| format!("Read config error: {}", e))?;
match &guard.tunnel {
Some(TunnelHandle::Helper(h)) => {
let cmd = format!(
"{{\"cmd\":\"reload\",\"config\":{},\"token\":\"{}\"}}\n",
serde_json::to_string(&config_str).unwrap(),
h.token
);
let _ = h.cmd_tx.send(cmd).await;
}
Some(TunnelHandle::InProcess(s)) => {
// Restarting in-process tunnel is not supported without re-calling start_tunnel,
// but we can just abort and we should really call start_tunnel again.
// For now, return false.
return Ok(false);
}
None => {}
}
Ok(true)
}
#[tauri::command] #[tauri::command]
async fn stop_tunnel(state: tauri::State<'_, AppState>) -> Result<bool, String> { async fn stop_tunnel(state: tauri::State<'_, AppState>) -> Result<bool, String> {
let mut guard = state.0.lock().await; let mut guard = state.0.lock().await;
@ -375,24 +407,19 @@ async fn start_tun_via_helper(
guard: &mut AppStateInner, guard: &mut AppStateInner,
raw: &ClientConfigRaw, raw: &ClientConfigRaw,
) -> Result<bool, String> { ) -> Result<bool, String> {
#[cfg(target_os = "windows")] let port = {
{ let listener = std::net::TcpListener::bind("127.0.0.1:0").map_err(|e| format!("Bind error: {}", e))?;
// Kill any existing helper processes to prevent os error 10048 (port already in use) listener.local_addr().unwrap().port()
use std::os::windows::process::CommandExt; };
let _ = std::process::Command::new("taskkill")
.args(["/F", "/IM", "ostp-tun-helper.exe"])
.creation_flags(0x08000000)
.output();
}
let auth_token = rand::random::<u64>().to_string(); let auth_token = rand::random::<u64>().to_string();
let helper_exe = find_helper_exe().ok_or_else(|| "ostp-tun-helper.exe not found.".to_string())?; let helper_exe = find_helper_exe().ok_or_else(|| "ostp-tun-helper.exe not found.".to_string())?;
launch_as_admin(&helper_exe, &auth_token).map_err(|e| format!("Failed to launch helper: {}", e))?; launch_as_admin(&helper_exe, &auth_token, port).map_err(|e| format!("Failed to launch helper: {}", e))?;
tokio::time::sleep(std::time::Duration::from_millis(1500)).await; tokio::time::sleep(std::time::Duration::from_millis(1500)).await;
let socket = tokio::time::timeout(std::time::Duration::from_secs(60), async { let socket = tokio::time::timeout(std::time::Duration::from_secs(60), async {
loop { loop {
match tokio::net::TcpStream::connect("127.0.0.1:53211").await { match tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)).await {
Ok(s) => return Ok::<_, std::io::Error>(s), Ok(s) => return Ok::<_, std::io::Error>(s),
Err(_) => tokio::time::sleep(std::time::Duration::from_millis(200)).await, Err(_) => tokio::time::sleep(std::time::Duration::from_millis(200)).await,
} }
@ -443,7 +470,7 @@ async fn start_tun_via_helper(
state_for_task.lock().await.connection_state = 0; state_for_task.lock().await.connection_state = 0;
}); });
guard.tunnel = Some(TunnelHandle::Helper(HelperState { pipe_state, cmd_tx, token: auth_token })); guard.tunnel = Some(TunnelHandle::Helper(HelperState { pipe_state, cmd_tx, token: auth_token, port }));
Ok(true) Ok(true)
} }
@ -493,14 +520,14 @@ fn find_helper_exe() -> Option<PathBuf> {
} }
#[cfg(target_os = "windows")] #[cfg(target_os = "windows")]
fn launch_as_admin(exe: &std::path::PathBuf, token: &str) -> anyhow::Result<()> { fn launch_as_admin(exe: &std::path::PathBuf, token: &str, port: u16) -> anyhow::Result<()> {
use std::ffi::OsStr; use std::ffi::OsStr;
use std::os::windows::ffi::OsStrExt; use std::os::windows::ffi::OsStrExt;
use std::ptr::null_mut; use std::ptr::null_mut;
let exe_wstr: Vec<u16> = exe.as_os_str().encode_wide().chain(Some(0)).collect(); let exe_wstr: Vec<u16> = exe.as_os_str().encode_wide().chain(Some(0)).collect();
let verb_wstr: Vec<u16> = OsStr::new("runas").encode_wide().chain(Some(0)).collect(); let verb_wstr: Vec<u16> = OsStr::new("runas").encode_wide().chain(Some(0)).collect();
let params_str = format!("--token {}", token); let params_str = format!("--token {} --port {}", token, port);
let params_wstr: Vec<u16> = OsStr::new(&params_str).encode_wide().chain(Some(0)).collect(); let params_wstr: Vec<u16> = OsStr::new(&params_str).encode_wide().chain(Some(0)).collect();
#[link(name = "shell32")] extern "system" { fn ShellExecuteW(h: *mut std::ffi::c_void, op: *const u16, f: *const u16, p: *const u16, d: *const u16, s: i32) -> isize; } #[link(name = "shell32")] extern "system" { fn ShellExecuteW(h: *mut std::ffi::c_void, op: *const u16, f: *const u16, p: *const u16, d: *const u16, s: i32) -> isize; }
@ -514,7 +541,7 @@ fn launch_as_admin(exe: &std::path::PathBuf, token: &str) -> anyhow::Result<()>
} }
#[cfg(not(target_os = "windows"))] #[cfg(not(target_os = "windows"))]
fn launch_as_admin(_exe: &PathBuf, _token: &str) -> Result<()> { anyhow::bail!("Windows only."); } fn launch_as_admin(_exe: &PathBuf, _token: &str, _port: u16) -> Result<()> { anyhow::bail!("Windows only."); }
#[cfg_attr(mobile, tauri::mobile_entry_point)] #[cfg_attr(mobile, tauri::mobile_entry_point)]
pub fn run() { pub fn run() {
@ -607,7 +634,7 @@ pub fn run() {
} }
_ => {} _ => {}
}) })
.invoke_handler(tauri::generate_handler![start_tunnel, stop_tunnel, get_tunnel_status, get_metrics, get_config, save_config]) .invoke_handler(tauri::generate_handler![start_tunnel, stop_tunnel, reload_tunnel, get_tunnel_status, get_metrics, get_config, save_config])
.run(tauri::generate_context!()) .run(tauri::generate_context!())
.expect("error while running tauri application"); .expect("error while running tauri application");
} }

View File

@ -319,6 +319,8 @@ async function handleSave(silent = false) {
rawConfig.tun = rawConfig.tun || {}; rawConfig.tun = rawConfig.tun || {};
rawConfig.tun.enable = inTun.checked; rawConfig.tun.enable = inTun.checked;
rawConfig.tun.wintun_path = rawConfig.tun.wintun_path || './wintun.dll';
rawConfig.tun.ipv4_address = rawConfig.tun.ipv4_address || '10.1.0.2/24';
rawConfig.tun.stack = 'ostp'; rawConfig.tun.stack = 'ostp';
// owndns: if toggle is on, always write 10.1.0.1; otherwise use the custom field // owndns: if toggle is on, always write 10.1.0.1; otherwise use the custom field
rawConfig.tun.dns = inOwndns.checked ? '10.1.0.1' : (inDns.value.trim() || null); rawConfig.tun.dns = inOwndns.checked ? '10.1.0.1' : (inDns.value.trim() || null);
@ -477,7 +479,14 @@ window.addEventListener('DOMContentLoaded', async () => {
// Auto-save wiring // Auto-save wiring
const formInputs = document.querySelectorAll('#settings-screen input:not(#in-import-url), #settings-screen textarea, #settings-screen select'); const formInputs = document.querySelectorAll('#settings-screen input:not(#in-import-url), #settings-screen textarea, #settings-screen select');
formInputs.forEach(el => { formInputs.forEach(el => {
el.addEventListener('input', scheduleAutoSave); el.addEventListener('input', () => {
scheduleAutoSave();
if (appState === 'connected') {
if (window.__TAURI__ && window.__TAURI__.invoke) {
window.__TAURI__.invoke('reload_tunnel');
}
}
});
el.addEventListener('change', scheduleAutoSave); el.addEventListener('change', scheduleAutoSave);
}); });

View File

@ -32,12 +32,14 @@ class OstpClientSdk private constructor(private val context: Context) {
// ── Native JNI bindings ─────────────────────────────────────────────────── // ── Native JNI bindings ───────────────────────────────────────────────────
private external fun nativeStartClient(configJson: String): Boolean private external fun startClient(configJson: String, fd: Int, t2sBinPath: String, localProxy: String): Boolean
private external fun nativeStopClient(): Boolean private external fun stopClient(): Boolean
private external fun nativeGetMetrics(): String private external fun getMetrics(): String
private external fun nativeGetLogs(): String private external fun getLogs(): String
private external fun addLog(logMsg: String)
private external fun notifyNetworkChanged() private external fun notifyNetworkChanged()
// ── Public data models ──────────────────────────────────────────────────── // ── Public data models ────────────────────────────────────────────────────
/** /**
@ -175,7 +177,8 @@ class OstpClientSdk private constructor(private val context: Context) {
_state.value = TunnelState.Connecting _state.value = TunnelState.Connecting
val json = config.toNativeJson() val json = config.toNativeJson()
val ok = nativeStartClient(json) // Default values for fd, t2sBinPath, localProxy for proxy mode
val ok = startClient(json, -1, "", config.proxyBind)
if (!ok) { if (!ok) {
_state.value = TunnelState.Failed("Native layer rejected config") _state.value = TunnelState.Failed("Native layer rejected config")
started.set(false) started.set(false)
@ -197,7 +200,7 @@ class OstpClientSdk private constructor(private val context: Context) {
pollingJob?.cancel() pollingJob?.cancel()
networkCallbackJob?.cancel() networkCallbackJob?.cancel()
nativeStopClient() stopClient()
unregisterNetworkCallback() unregisterNetworkCallback()
_state.value = TunnelState.Idle _state.value = TunnelState.Idle
emitLog("OSTP SDK stopped") emitLog("OSTP SDK stopped")
@ -209,7 +212,7 @@ class OstpClientSdk private constructor(private val context: Context) {
*/ */
fun drainLogs(): List<String> { fun drainLogs(): List<String> {
return try { return try {
val array = JSONArray(nativeGetLogs()) val array = JSONArray(getLogs())
(0 until array.length()).map { array.getString(it) } (0 until array.length()).map { array.getString(it) }
} catch (_: Exception) { } catch (_: Exception) {
emptyList() emptyList()
@ -218,7 +221,7 @@ class OstpClientSdk private constructor(private val context: Context) {
/** Read the latest [Metrics] snapshot. Returns zeroed metrics if tunnel is idle. */ /** Read the latest [Metrics] snapshot. Returns zeroed metrics if tunnel is idle. */
fun getMetrics(): Metrics { fun getMetrics(): Metrics {
return parseMetrics(nativeGetMetrics()) return parseMetrics(getMetrics())
} }
// ── Internal helpers ────────────────────────────────────────────────────── // ── Internal helpers ──────────────────────────────────────────────────────
@ -247,7 +250,7 @@ class OstpClientSdk private constructor(private val context: Context) {
} }
// Update state based on metrics availability // Update state based on metrics availability
val metrics = parseMetrics(nativeGetMetrics()) val metrics = parseMetrics(getMetrics())
if (wasConnected) { if (wasConnected) {
_state.value = TunnelState.Connected(metrics) _state.value = TunnelState.Connected(metrics)
} }
@ -320,6 +323,33 @@ class OstpClientSdk private constructor(private val context: Context) {
@Volatile @Volatile
private var instance: OstpClientSdk? = null private var instance: OstpClientSdk? = null
@JvmStatic
fun protectSocket(fd: Int): Boolean {
var retries = 5
while (retries > 0) {
// We use reflection or explicit class to get the VpnService instance
try {
val serviceClass = Class.forName("com.ospab.ostp_client.OstpVpnService")
val instanceField = serviceClass.getDeclaredField("instance")
instanceField.isAccessible = true
val service = instanceField.get(null)
if (service != null) {
val protectMethod = serviceClass.getMethod("protect", Int::class.javaPrimitiveType)
val res = protectMethod.invoke(service, fd) as Boolean
android.util.Log.i("OstpClientSdk", "VpnService.protect(socketFd=$fd) -> success=$res")
return res
}
} catch (e: Exception) {
android.util.Log.w("OstpClientSdk", "Error accessing VpnService via reflection: \${e.message}")
}
android.util.Log.w("OstpClientSdk", "VpnService instance not available! Retrying... (\$retries left)")
Thread.sleep(200)
retries--
}
android.util.Log.e("OstpClientSdk", "VpnService instance is null! Cannot protect socketFd=\$fd")
return false
}
/** /**
* Get the singleton SDK instance. * Get the singleton SDK instance.
* Must be called with an Application context to avoid memory leaks. * Must be called with an Application context to avoid memory leaks.

View File

@ -3,7 +3,7 @@ use jni::sys::{jboolean, jstring};
use jni::JNIEnv; use jni::JNIEnv;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::sync::{atomic::Ordering, Arc, Mutex}; use std::sync::{atomic::Ordering, Arc, Mutex, RwLock};
use tokio::runtime::Runtime; use tokio::runtime::Runtime;
use tokio::sync::{mpsc, watch}; use tokio::sync::{mpsc, watch};
use ostp_client::bridge::{Bridge, BridgeMetrics}; use ostp_client::bridge::{Bridge, BridgeMetrics};
@ -80,13 +80,13 @@ impl SdkState {
} }
} }
static STATE: Mutex<SdkState> = Mutex::new(SdkState::new()); static STATE: RwLock<SdkState> = RwLock::new(SdkState::new());
static LOGS: Mutex<VecDeque<String>> = Mutex::new(VecDeque::new()); static LOGS: RwLock<VecDeque<String>> = RwLock::new(VecDeque::new());
static JVM: Mutex<Option<jni::JavaVM>> = Mutex::new(None); static JVM: RwLock<Option<jni::JavaVM>> = RwLock::new(None);
static CLASS_REF: Mutex<Option<jni::objects::GlobalRef>> = Mutex::new(None); static CLASS_REF: RwLock<Option<jni::objects::GlobalRef>> = RwLock::new(None);
fn add_log(text: String) { fn add_log(text: String) {
if let Ok(mut guard) = LOGS.lock() { if let Ok(mut guard) = LOGS.write() {
if guard.len() >= 1000 { if guard.len() >= 1000 {
guard.pop_front(); guard.pop_front();
} }
@ -95,7 +95,7 @@ fn add_log(text: String) {
} }
#[no_mangle] #[no_mangle]
pub extern "system" fn Java_net_ostp_client_OstpClientSdk_startClient( pub extern "system" fn Java_net_ostp_client_OstpClientSdk_nativeStartClient(
mut env: JNIEnv, mut env: JNIEnv,
_class: JClass, _class: JClass,
config_json: JString, config_json: JString,
@ -103,7 +103,7 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_startClient(
t2s_bin_path: JString, t2s_bin_path: JString,
local_proxy: JString, local_proxy: JString,
) -> jboolean { ) -> jboolean {
let mut state = match STATE.lock() { let mut state = match STATE.write() {
Ok(s) => s, Ok(s) => s,
Err(_) => return jni::sys::JNI_FALSE, Err(_) => return jni::sys::JNI_FALSE,
}; };
@ -116,25 +116,25 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_startClient(
init_tracing(); init_tracing();
if let Ok(jvm) = env.get_java_vm() { if let Ok(jvm) = env.get_java_vm() {
if let Ok(mut guard) = JVM.lock() { if let Ok(mut guard) = JVM.write() {
*guard = Some(jvm); *guard = Some(jvm);
} }
} }
if let Ok(cls) = env.find_class("net/ostp/client/OstpClientSdk") { if let Ok(cls) = env.find_class("net/ostp/client/OstpClientSdk") {
if let Ok(global_cls) = env.new_global_ref(cls) { if let Ok(global_cls) = env.new_global_ref(cls) {
if let Ok(mut guard) = CLASS_REF.lock() { if let Ok(mut guard) = CLASS_REF.write() {
*guard = Some(global_cls); *guard = Some(global_cls);
} }
} }
} }
ostp_client::bridge::set_socket_protector(|fd| { ostp_client::bridge::set_socket_protector(|fd| {
let jvm_guard = match JVM.lock() { let jvm_guard = match JVM.read() {
Ok(g) => g, Ok(g) => g,
Err(_) => return false, Err(_) => return false,
}; };
let class_guard = match CLASS_REF.lock() { let class_guard = match CLASS_REF.read() {
Ok(g) => g, Ok(g) => g,
Err(_) => return false, Err(_) => return false,
}; };
@ -346,12 +346,24 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_startClient(
} }
#[no_mangle] #[no_mangle]
pub extern "system" fn Java_net_ostp_client_OstpClientSdk_stopClient( pub extern "system" fn Java_net_ostp_client_OstpClientSdk_startClient(
env: JNIEnv,
class: JClass,
config_json: JString,
fd: jni::sys::jint,
t2s_bin_path: JString,
local_proxy: JString,
) -> jboolean {
Java_net_ostp_client_OstpClientSdk_nativeStartClient(env, class, config_json, fd, t2s_bin_path, local_proxy)
}
#[no_mangle]
pub extern "system" fn Java_net_ostp_client_OstpClientSdk_nativeStopClient(
_env: JNIEnv, _env: JNIEnv,
_class: JClass, _class: JClass,
) -> jboolean { ) -> jboolean {
let (tun_child, shutdown_tx, runtime) = { let (tun_child, shutdown_tx, runtime) = {
let mut state = match STATE.lock() { let mut state = match STATE.write() {
Ok(s) => s, Ok(s) => s,
Err(_) => return jni::sys::JNI_FALSE, Err(_) => return jni::sys::JNI_FALSE,
}; };
@ -381,11 +393,19 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_stopClient(
} }
#[no_mangle] #[no_mangle]
pub extern "system" fn Java_net_ostp_client_OstpClientSdk_getMetrics( pub extern "system" fn Java_net_ostp_client_OstpClientSdk_stopClient(
env: JNIEnv,
class: JClass,
) -> jboolean {
Java_net_ostp_client_OstpClientSdk_nativeStopClient(env, class)
}
#[no_mangle]
pub extern "system" fn Java_net_ostp_client_OstpClientSdk_nativeGetMetrics(
env: JNIEnv, env: JNIEnv,
_class: JClass, _class: JClass,
) -> jstring { ) -> jstring {
let state = match STATE.lock() { let state = match STATE.read() {
Ok(s) => s, Ok(s) => s,
Err(_) => return match env.new_string("{}") { Err(_) => return match env.new_string("{}") {
Ok(s) => s.into_raw(), Ok(s) => s.into_raw(),
@ -415,11 +435,19 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_getMetrics(
} }
#[no_mangle] #[no_mangle]
pub extern "system" fn Java_net_ostp_client_OstpClientSdk_getLogs( pub extern "system" fn Java_net_ostp_client_OstpClientSdk_getMetrics(
env: JNIEnv,
class: JClass,
) -> jstring {
Java_net_ostp_client_OstpClientSdk_nativeGetMetrics(env, class)
}
#[no_mangle]
pub extern "system" fn Java_net_ostp_client_OstpClientSdk_nativeGetLogs(
env: JNIEnv, env: JNIEnv,
_class: JClass, _class: JClass,
) -> jstring { ) -> jstring {
let logs_vec: Vec<String> = match LOGS.lock() { let logs_vec: Vec<String> = match LOGS.write() {
Ok(mut guard) => guard.drain(..).collect(), Ok(mut guard) => guard.drain(..).collect(),
Err(_) => Vec::new(), Err(_) => Vec::new(),
}; };
@ -435,6 +463,14 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_getLogs(
} }
} }
#[no_mangle]
pub extern "system" fn Java_net_ostp_client_OstpClientSdk_getLogs(
env: JNIEnv,
class: JClass,
) -> jstring {
Java_net_ostp_client_OstpClientSdk_nativeGetLogs(env, class)
}
#[no_mangle] #[no_mangle]
pub extern "system" fn Java_net_ostp_client_OstpClientSdk_addLog( pub extern "system" fn Java_net_ostp_client_OstpClientSdk_addLog(
mut env: JNIEnv, mut env: JNIEnv,
@ -454,7 +490,7 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_notifyNetworkChanged(
_env: JNIEnv, _env: JNIEnv,
_class: JClass, _class: JClass,
) { ) {
let state = match STATE.lock() { let state = match STATE.read() {
Ok(s) => s, Ok(s) => s,
Err(_) => return, Err(_) => return,
}; };

View File

@ -310,7 +310,7 @@ fn check_token(state: &ApiState, headers: &axum::http::HeaderMap) -> bool {
if let Some(value) = headers.get("authorization") { if let Some(value) = headers.get("authorization") {
if let Ok(val) = value.to_str() { if let Ok(val) = value.to_str() {
if let Some(token) = val.strip_prefix("Bearer ") { if let Some(token) = val.strip_prefix("Bearer ") {
let current_session = state.session_token.read().unwrap().clone(); let current_session = state.session_token.read().unwrap_or_else(|e| e.into_inner()).clone();
if let Some(session) = current_session { if let Some(session) = current_session {
if token == session { if token == session {
allowed = true; allowed = true;
@ -353,7 +353,7 @@ async fn handle_login(
if hash_hex == state.password_hash { if hash_hex == state.password_hash {
let token = uuid::Uuid::new_v4().to_string(); let token = uuid::Uuid::new_v4().to_string();
*state.session_token.write().unwrap() = Some(token.clone()); *state.session_token.write().unwrap_or_else(|e| e.into_inner()) = Some(token.clone());
(StatusCode::OK, ApiResponse::success(LoginResponse { token })) (StatusCode::OK, ApiResponse::success(LoginResponse { token }))
} else { } else {
api_unauthorized::<LoginResponse>() api_unauthorized::<LoginResponse>()
@ -377,7 +377,7 @@ fn save_config_keys(state: &ApiState) -> Result<(), String> {
let mut json_val: serde_json::Value = serde_json::from_str(&content_str) let mut json_val: serde_json::Value = serde_json::from_str(&content_str)
.map_err(|e| format!("failed to parse config JSON: {}", e))?; .map_err(|e| format!("failed to parse config JSON: {}", e))?;
let keys = state.access_keys.read().unwrap(); let keys = state.access_keys.read().unwrap_or_else(|e| e.into_inner());
let mut access_keys_json = Vec::new(); let mut access_keys_json = Vec::new();
for (k, m) in keys.iter() { for (k, m) in keys.iter() {
if m.name.is_none() && m.limit_bytes.is_none() { if m.name.is_none() && m.limit_bytes.is_none() {
@ -511,8 +511,8 @@ async fn handle_status(
return api_unauthorized::<ServerStatus>(); return api_unauthorized::<ServerStatus>();
} }
let keys = state.access_keys.read().unwrap(); let keys = state.access_keys.read().unwrap_or_else(|e| e.into_inner());
let stats = state.user_stats.read().unwrap(); let stats = state.user_stats.read().unwrap_or_else(|e| e.into_inner());
let online = stats.values() let online = stats.values()
.filter(|us| { .filter(|us| {
let total = us.bytes_up.load(Ordering::Relaxed) + us.bytes_down.load(Ordering::Relaxed); let total = us.bytes_up.load(Ordering::Relaxed) + us.bytes_down.load(Ordering::Relaxed);
@ -538,8 +538,8 @@ async fn handle_list_users(
return api_unauthorized::<Vec<UserStatsSnapshot>>(); return api_unauthorized::<Vec<UserStatsSnapshot>>();
} }
let keys = state.access_keys.read().unwrap(); let keys = state.access_keys.read().unwrap_or_else(|e| e.into_inner());
let stats = state.user_stats.read().unwrap(); let stats = state.user_stats.read().unwrap_or_else(|e| e.into_inner());
let mut users: Vec<UserStatsSnapshot> = keys.iter().map(|(key, meta)| { let mut users: Vec<UserStatsSnapshot> = keys.iter().map(|(key, meta)| {
if let Some(us) = stats.get(key) { if let Some(us) = stats.get(key) {
@ -579,13 +579,13 @@ async fn handle_get_user(
return api_unauthorized::<UserStatsSnapshot>(); return api_unauthorized::<UserStatsSnapshot>();
} }
let keys = state.access_keys.read().unwrap(); let keys = state.access_keys.read().unwrap_or_else(|e| e.into_inner());
let meta = match keys.get(&key) { let meta = match keys.get(&key) {
Some(m) => m.clone(), Some(m) => m.clone(),
None => return api_error("user not found"), None => return api_error("user not found"),
}; };
let stats = state.user_stats.read().unwrap(); let stats = state.user_stats.read().unwrap_or_else(|e| e.into_inner());
let snapshot = if let Some(us) = stats.get(&key) { let snapshot = if let Some(us) = stats.get(&key) {
UserStatsSnapshot { UserStatsSnapshot {
access_key: key.clone(), access_key: key.clone(),
@ -628,11 +628,11 @@ async fn handle_create_user(
}); });
{ {
let mut keys = state.access_keys.write().unwrap(); let mut keys = state.access_keys.write().unwrap_or_else(|e| e.into_inner());
keys.insert(key.clone(), UserMeta { name: body.name.clone(), limit_bytes: body.limit_bytes }); keys.insert(key.clone(), UserMeta { name: body.name.clone(), limit_bytes: body.limit_bytes });
} }
let mut stats = state.user_stats.write().unwrap(); let mut stats = state.user_stats.write().unwrap_or_else(|e| e.into_inner());
stats.insert(key.clone(), Arc::new(UserStats::new(body.limit_bytes))); stats.insert(key.clone(), Arc::new(UserStats::new(body.limit_bytes)));
drop(stats); drop(stats);
@ -655,14 +655,14 @@ async fn delete_user(
} }
{ {
let mut keys = state.access_keys.write().unwrap(); let mut keys = state.access_keys.write().unwrap_or_else(|e| e.into_inner());
if keys.remove(&key).is_none() { if keys.remove(&key).is_none() {
return api_error::<String>("User not found"); return api_error::<String>("User not found");
} }
} }
{ {
let mut stats = state.user_stats.write().unwrap(); let mut stats = state.user_stats.write().unwrap_or_else(|e| e.into_inner());
stats.remove(&key); stats.remove(&key);
} }
@ -685,7 +685,7 @@ async fn update_user(
} }
{ {
let mut keys = state.access_keys.write().unwrap(); let mut keys = state.access_keys.write().unwrap_or_else(|e| e.into_inner());
if let Some(meta) = keys.get_mut(&key) { if let Some(meta) = keys.get_mut(&key) {
meta.name = body.name.clone(); meta.name = body.name.clone();
meta.limit_bytes = body.limit_bytes; meta.limit_bytes = body.limit_bytes;
@ -695,7 +695,7 @@ async fn update_user(
} }
{ {
let mut stats = state.user_stats.write().unwrap(); let mut stats = state.user_stats.write().unwrap_or_else(|e| e.into_inner());
let entry = stats.entry(key.clone()) let entry = stats.entry(key.clone())
.or_insert_with(|| Arc::new(UserStats::new(body.limit_bytes))); .or_insert_with(|| Arc::new(UserStats::new(body.limit_bytes)));
@ -727,7 +727,7 @@ async fn handle_set_limit(
} }
{ {
let mut keys = state.access_keys.write().unwrap(); let mut keys = state.access_keys.write().unwrap_or_else(|e| e.into_inner());
if let Some(meta) = keys.get_mut(&key) { if let Some(meta) = keys.get_mut(&key) {
meta.limit_bytes = body.limit_bytes; meta.limit_bytes = body.limit_bytes;
} else { } else {
@ -735,7 +735,7 @@ async fn handle_set_limit(
} }
} }
let mut stats = state.user_stats.write().unwrap(); let mut stats = state.user_stats.write().unwrap_or_else(|e| e.into_inner());
let entry = stats.entry(key.clone()) let entry = stats.entry(key.clone())
.or_insert_with(|| Arc::new(UserStats::new(body.limit_bytes))); .or_insert_with(|| Arc::new(UserStats::new(body.limit_bytes)));
@ -765,7 +765,7 @@ async fn handle_reset_stats(
return api_unauthorized::<bool>(); return api_unauthorized::<bool>();
} }
let mut stats = state.user_stats.write().unwrap(); let mut stats = state.user_stats.write().unwrap_or_else(|e| e.into_inner());
if let Some(us) = stats.get(&key) { if let Some(us) = stats.get(&key) {
let limit = us.limit_bytes; let limit = us.limit_bytes;
stats.insert(key.clone(), Arc::new(UserStats::new(limit))); stats.insert(key.clone(), Arc::new(UserStats::new(limit)));
@ -793,7 +793,7 @@ async fn handle_subscribe(
// Validate that the key exists in a tightly scoped block to drop the guard // Validate that the key exists in a tightly scoped block to drop the guard
let key_exists = { let key_exists = {
let keys = state.access_keys.read().unwrap(); let keys = state.access_keys.read().unwrap_or_else(|e| e.into_inner());
keys.contains_key(&key) keys.contains_key(&key)
}; };

View File

@ -86,7 +86,7 @@ pub struct Dispatcher {
impl Dispatcher { impl Dispatcher {
pub fn new(machine_config: ProtocolConfig, access_keys: Arc<RwLock<HashMap<String, crate::api::UserMeta>>>) -> Self { pub fn new(machine_config: ProtocolConfig, access_keys: Arc<RwLock<HashMap<String, crate::api::UserMeta>>>) -> Self {
let mut initial_stats = HashMap::new(); let mut initial_stats = HashMap::new();
for (key, meta) in access_keys.read().unwrap().iter() { for (key, meta) in access_keys.read().unwrap_or_else(|e| e.into_inner()).iter() {
initial_stats.insert(key.clone(), Arc::new(UserStats::new(meta.limit_bytes))); initial_stats.insert(key.clone(), Arc::new(UserStats::new(meta.limit_bytes)));
} }
Self { Self {
@ -108,7 +108,7 @@ impl Dispatcher {
/// Snapshot all user stats for API responses. /// Snapshot all user stats for API responses.
pub fn snapshot_all_users(&self) -> Vec<UserStatsSnapshot> { pub fn snapshot_all_users(&self) -> Vec<UserStatsSnapshot> {
let stats = self.user_stats.read().unwrap(); let stats = self.user_stats.read().unwrap_or_else(|e| e.into_inner());
let online_keys: std::collections::HashSet<String> = self.peer_machines.values() let online_keys: std::collections::HashSet<String> = self.peer_machines.values()
.map(|ps| ps.access_key.clone()) .map(|ps| ps.access_key.clone())
.collect(); .collect();
@ -125,15 +125,15 @@ impl Dispatcher {
/// Get or create stats entry for a user key. /// Get or create stats entry for a user key.
fn get_or_create_user_stats(&self, key: &str) -> Arc<UserStats> { fn get_or_create_user_stats(&self, key: &str) -> Arc<UserStats> {
let stats = self.user_stats.read().unwrap(); let stats = self.user_stats.read().unwrap_or_else(|e| e.into_inner());
if let Some(existing) = stats.get(key) { if let Some(existing) = stats.get(key) {
return existing.clone(); return existing.clone();
} }
drop(stats); drop(stats);
let limit_bytes = self.access_keys.read().unwrap().get(key).and_then(|m| m.limit_bytes); let limit_bytes = self.access_keys.read().unwrap_or_else(|e| e.into_inner()).get(key).and_then(|m| m.limit_bytes);
let mut stats = self.user_stats.write().unwrap(); let mut stats = self.user_stats.write().unwrap_or_else(|e| e.into_inner());
stats.entry(key.to_string()) stats.entry(key.to_string())
.or_insert_with(|| Arc::new(UserStats::new(limit_bytes))) .or_insert_with(|| Arc::new(UserStats::new(limit_bytes)))
.clone() .clone()
@ -141,7 +141,7 @@ impl Dispatcher {
/// Set traffic limit for a user. /// Set traffic limit for a user.
pub fn set_user_limit(&self, key: &str, limit: Option<u64>) { pub fn set_user_limit(&self, key: &str, limit: Option<u64>) {
let mut stats = self.user_stats.write().unwrap(); let mut stats = self.user_stats.write().unwrap_or_else(|e| e.into_inner());
let entry = stats.entry(key.to_string()) let entry = stats.entry(key.to_string())
.or_insert_with(|| Arc::new(UserStats::new(limit))); .or_insert_with(|| Arc::new(UserStats::new(limit)));
// Replace the entry with new limit (stats reset) // Replace the entry with new limit (stats reset)
@ -212,7 +212,7 @@ impl Dispatcher {
let key_opt = self.peer_machines.get(&session_id).map(|ps| ps.access_key.clone()); let key_opt = self.peer_machines.get(&session_id).map(|ps| ps.access_key.clone());
if let Some(access_key) = key_opt { if let Some(access_key) = key_opt {
// Check if key is still valid and not over limit // Check if key is still valid and not over limit
let key_valid = self.access_keys.read().unwrap().contains_key(&access_key); let key_valid = self.access_keys.read().unwrap_or_else(|e| e.into_inner()).contains_key(&access_key);
let user_stats = self.get_or_create_user_stats(&access_key); let user_stats = self.get_or_create_user_stats(&access_key);
if !key_valid || user_stats.is_over_limit() { if !key_valid || user_stats.is_over_limit() {
tracing::info!("Dropping session {} for key {} (valid={}, over_limit={})", tracing::info!("Dropping session {} for key {} (valid={}, over_limit={})",
@ -280,7 +280,7 @@ impl Dispatcher {
} }
// Not an existing session — try each registered access key's derived obfuscation key // Not an existing session — try each registered access key's derived obfuscation key
let keys_snapshot: Vec<String> = self.access_keys.read().unwrap().keys().cloned().collect(); let keys_snapshot: Vec<String> = self.access_keys.read().unwrap_or_else(|e| e.into_inner()).keys().cloned().collect();
for candidate_key in keys_snapshot { for candidate_key in keys_snapshot {
let secrets = ostp_core::crypto::derive_all_secrets(candidate_key.as_bytes()); let secrets = ostp_core::crypto::derive_all_secrets(candidate_key.as_bytes());
@ -430,7 +430,7 @@ impl Dispatcher {
// Gather expired or invalid sessions // Gather expired or invalid sessions
for (&sid, peer_state) in &self.peer_machines { for (&sid, peer_state) in &self.peer_machines {
let key_valid = self.access_keys.read().unwrap().contains_key(&peer_state.access_key); let key_valid = self.access_keys.read().unwrap_or_else(|e| e.into_inner()).contains_key(&peer_state.access_key);
let user_stats = self.get_or_create_user_stats(&peer_state.access_key); let user_stats = self.get_or_create_user_stats(&peer_state.access_key);
if now.duration_since(peer_state.last_seen) > timeout_dur || !key_valid || user_stats.is_over_limit() { if now.duration_since(peer_state.last_seen) > timeout_dur || !key_valid || user_stats.is_over_limit() {
expired.push(sid); expired.push(sid);
@ -441,7 +441,7 @@ impl Dispatcher {
for sid in &expired { for sid in &expired {
let peer_state_opt = self.peer_machines.get(sid); let peer_state_opt = self.peer_machines.get(sid);
let reason = if let Some(ps) = peer_state_opt { let reason = if let Some(ps) = peer_state_opt {
let key_valid = self.access_keys.read().unwrap().contains_key(&ps.access_key); let key_valid = self.access_keys.read().unwrap_or_else(|e| e.into_inner()).contains_key(&ps.access_key);
let user_stats = self.get_or_create_user_stats(&ps.access_key); let user_stats = self.get_or_create_user_stats(&ps.access_key);
if now.duration_since(ps.last_seen) > timeout_dur { if now.duration_since(ps.last_seen) > timeout_dur {
"inactive >5min" "inactive >5min"
@ -504,15 +504,15 @@ fn get_or_create_stats(
key: &str, key: &str,
) -> Arc<UserStats> { ) -> Arc<UserStats> {
{ {
let stats = user_stats.read().unwrap(); let stats = user_stats.read().unwrap_or_else(|e| e.into_inner());
if let Some(existing) = stats.get(key) { if let Some(existing) = stats.get(key) {
return existing.clone(); return existing.clone();
} }
} }
let limit_bytes = access_keys.read().unwrap().get(key).and_then(|m| m.limit_bytes); let limit_bytes = access_keys.read().unwrap_or_else(|e| e.into_inner()).get(key).and_then(|m| m.limit_bytes);
let mut stats = user_stats.write().unwrap(); let mut stats = user_stats.write().unwrap_or_else(|e| e.into_inner());
stats.entry(key.to_string()) stats.entry(key.to_string())
.or_insert_with(|| Arc::new(UserStats::new(limit_bytes))) .or_insert_with(|| Arc::new(UserStats::new(limit_bytes)))
.clone() .clone()

View File

@ -195,11 +195,11 @@ pub async fn run_server(
} }
// 1. Update shared_keys // 1. Update shared_keys
let mut keys_lock = shared_keys_clone.write().unwrap(); let mut keys_lock = shared_keys_clone.write().unwrap_or_else(|e| e.into_inner());
*keys_lock = new_keys.clone(); *keys_lock = new_keys.clone();
// 2. Synchronize user_stats limits & cleanup deleted keys // 2. Synchronize user_stats limits & cleanup deleted keys
let mut stats_lock = user_stats_clone.write().unwrap(); let mut stats_lock = user_stats_clone.write().unwrap_or_else(|e| e.into_inner());
stats_lock.retain(|k, _| new_keys.contains_key(k)); stats_lock.retain(|k, _| new_keys.contains_key(k));
for (k, meta) in &new_keys { for (k, meta) in &new_keys {
@ -308,7 +308,7 @@ pub async fn run_server(
} }
}); });
let key_count = shared_keys.read().unwrap().len(); let key_count = shared_keys.read().unwrap_or_else(|e| e.into_inner()).len();
tracing::info!(listeners = bind_addrs.len(), keys = key_count, "server started"); tracing::info!(listeners = bind_addrs.len(), keys = key_count, "server started");
tracing::info!("ARQ config: max_reorder=16384, reorder_buf=8192, sent_history=32768, rto=100ms"); tracing::info!("ARQ config: max_reorder=16384, reorder_buf=8192, sent_history=32768, rto=100ms");
let reality_config_arc = reality_config.map(std::sync::Arc::new); let reality_config_arc = reality_config.map(std::sync::Arc::new);
@ -434,7 +434,7 @@ async fn run_server_loop(
if debug { if debug {
let _ = ui_event_tx.send(UiEvent::Log("Server loop started".to_string())); let _ = ui_event_tx.send(UiEvent::Log("Server loop started".to_string()));
let _ = ui_event_tx.send(UiEvent::KeyCount(shared_keys.read().unwrap().len())); let _ = ui_event_tx.send(UiEvent::KeyCount(shared_keys.read().unwrap_or_else(|e| e.into_inner()).len()));
} }
let mut retransmit_tick = interval(Duration::from_millis(10)); let mut retransmit_tick = interval(Duration::from_millis(10));
@ -448,7 +448,7 @@ async fn run_server_loop(
match cmd { match cmd {
Some(UiCommand::CreateClientKey) => { Some(UiCommand::CreateClientKey) => {
let key = format!("ostp_key_{}", std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs()); let key = format!("ostp_key_{}", std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs());
shared_keys.write().unwrap().insert(key.clone(), crate::api::UserMeta { name: None, limit_bytes: None }); shared_keys.write().unwrap_or_else(|e| e.into_inner()).insert(key.clone(), crate::api::UserMeta { name: None, limit_bytes: None });
let _ = ui_event_tx.send(UiEvent::KeyCreated { key }); let _ = ui_event_tx.send(UiEvent::KeyCreated { key });
} }
Some(UiCommand::Shutdown) | None => { Some(UiCommand::Shutdown) | None => {

View File

@ -48,7 +48,7 @@ pub async fn run_relay_node(cfg: RelayConfig) -> Result<()> {
if let Err(e) = sync_keys(&cfg, &shared_keys).await { if let Err(e) = sync_keys(&cfg, &shared_keys).await {
tracing::warn!("Relay: initial key sync failed: {}. Will retry.", e); tracing::warn!("Relay: initial key sync failed: {}. Will retry.", e);
} else { } else {
let count = shared_keys.read().unwrap().len(); let count = shared_keys.read().unwrap_or_else(|e| e.into_inner()).len();
tracing::info!("Relay: synced {} access key(s) from upstream API", count); tracing::info!("Relay: synced {} access key(s) from upstream API", count);
} }
@ -190,7 +190,7 @@ async fn run_udp_relay(cfg: RelayConfig, shared_keys: SharedKeys) -> Result<()>
let ts_bytes: [u8; 8] = packet[0..8].try_into().unwrap(); let ts_bytes: [u8; 8] = packet[0..8].try_into().unwrap();
let provided_mac = &packet[8..40]; let provided_mac = &packet[8..40];
let keys_guard = keys.read().unwrap(); let keys_guard = keys.read().unwrap_or_else(|e| e.into_inner());
if !verify_hmac(&ts_bytes, provided_mac, &keys_guard) { if !verify_hmac(&ts_bytes, provided_mac, &keys_guard) {
tracing::debug!("Relay UDP: unauthorized probe from {}, dropped", peer); tracing::debug!("Relay UDP: unauthorized probe from {}, dropped", peer);
@ -369,7 +369,7 @@ async fn handle_tcp_client(
// Проверяем по синхронизированным ключам // Проверяем по синхронизированным ключам
let authorized = { let authorized = {
let keys = shared_keys.read().unwrap(); let keys = shared_keys.read().unwrap_or_else(|e| e.into_inner());
verify_hmac(&ts_bytes, provided_mac, &keys) verify_hmac(&ts_bytes, provided_mac, &keys)
}; };

View File

@ -5,7 +5,6 @@ use anyhow::Result;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use std::fs::OpenOptions;
use std::io::Write as _; use std::io::Write as _;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::sync::{watch, Mutex}; use tokio::sync::{watch, Mutex};
@ -13,21 +12,25 @@ use tokio::net::TcpListener;
use portable_atomic::Ordering; use portable_atomic::Ordering;
fn log_to_file(msg: &str) { fn log_to_file(msg: &str) {
let msg = msg.to_string();
tokio::task::spawn_blocking(move || {
let path = std::env::current_exe() let path = std::env::current_exe()
.ok() .ok()
.and_then(|p| p.parent().map(|d| d.join("ostp-helper.log"))) .and_then(|p| p.parent().map(|d| d.join("ostp-helper.log")))
.unwrap_or_else(|| std::path::PathBuf::from("ostp-helper.log")); .unwrap_or_else(|| std::path::PathBuf::from("ostp-helper.log"));
if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) { if let Ok(mut file) = std::fs::OpenOptions::new().create(true).append(true).open(path) {
let _ = writeln!(file, "[{}] {}", chrono::Local::now().format("%Y-%m-%d %H:%M:%S"), msg); let _ = writeln!(file, "[{}] {}", chrono::Local::now().format("%Y-%m-%d %H:%M:%S"), msg);
} }
});
} }
const BIND_ADDR: &str = "127.0.0.1:53211";
#[derive(Deserialize)] #[derive(Deserialize)]
#[serde(tag = "cmd", rename_all = "lowercase")] #[serde(tag = "cmd", rename_all = "lowercase")]
enum GuiCmd { enum GuiCmd {
Start { config: String, token: String }, Start { config: String, token: String },
Reload { config: String, token: String },
Stop { token: String }, Stop { token: String },
} }
@ -55,10 +58,13 @@ async fn main() -> Result<()> {
} }
let mut expected_token = String::new(); let mut expected_token = String::new();
let mut port = 53211u16;
let args: Vec<String> = std::env::args().collect(); let args: Vec<String> = std::env::args().collect();
for i in 1..args.len() { for i in 1..args.len() {
if args[i] == "--token" && i + 1 < args.len() { if args[i] == "--token" && i + 1 < args.len() {
expected_token = args[i + 1].clone(); expected_token = args[i + 1].clone();
} else if args[i] == "--port" && i + 1 < args.len() {
port = args[i + 1].parse().unwrap_or(53211);
} }
} }
@ -69,21 +75,22 @@ async fn main() -> Result<()> {
return Err(anyhow::anyhow!("--token argument is required")); return Err(anyhow::anyhow!("--token argument is required"));
} }
if let Err(e) = run_server(expected_token).await { if let Err(e) = run_server(expected_token, port).await {
log_to_file(&format!("Fatal error: {}", e)); log_to_file(&format!("Fatal error: {}", e));
} }
log_to_file("Helper exiting"); log_to_file("Helper exiting");
Ok(()) Ok(())
} }
async fn run_server(expected_token: String) -> Result<()> { async fn run_server(expected_token: String, port: u16) -> Result<()> {
let state = Arc::new(Mutex::new(TunnelState { let state = Arc::new(Mutex::new(TunnelState {
shutdown_tx: None, shutdown_tx: None,
metrics: None, metrics: None,
})); }));
log_to_file(&format!("Attempting to bind to {}", BIND_ADDR)); let bind_addr = format!("127.0.0.1:{}", port);
let listener = TcpListener::bind(BIND_ADDR).await.map_err(|e| { log_to_file(&format!("Attempting to bind to {}", bind_addr));
let listener = TcpListener::bind(&bind_addr).await.map_err(|e| {
log_to_file(&format!("Bind failed: {}", e)); log_to_file(&format!("Bind failed: {}", e));
e e
})?; })?;
@ -182,9 +189,10 @@ async fn run_server(expected_token: String) -> Result<()> {
let metrics_for_runner = metrics.clone(); let metrics_for_runner = metrics.clone();
let writer_for_err = writer.clone(); let writer_for_err = writer.clone();
let shutdown_rx_for_core = shutdown_rx.clone();
tokio::spawn(async move { tokio::spawn(async move {
log_to_file("Starting tunnel core..."); log_to_file("Starting tunnel core...");
match ostp_client::runner::run_client_core(cfg, metrics_for_runner, shutdown_rx).await { match ostp_client::runner::run_client_core(cfg, metrics_for_runner, shutdown_rx_for_core).await {
Ok(_) => { log_to_file("Tunnel core stopped normally"); } Ok(_) => { log_to_file("Tunnel core stopped normally"); }
Err(e) => { Err(e) => {
log_to_file(&format!("Tunnel core error: {}", e)); log_to_file(&format!("Tunnel core error: {}", e));
@ -197,10 +205,17 @@ async fn run_server(expected_token: String) -> Result<()> {
let writer_tick = writer.clone(); let writer_tick = writer.clone();
let metrics_tick = metrics.clone(); let metrics_tick = metrics.clone();
let mut shutdown_rx_tick = shutdown_rx.clone();
tokio::spawn(async move { tokio::spawn(async move {
let mut last_state = 99u8; let mut last_state = 99u8;
loop { loop {
tokio::time::sleep(Duration::from_secs(1)).await; tokio::select! {
_ = tokio::time::sleep(Duration::from_secs(1)) => {}
_ = shutdown_rx_tick.changed() => {
if *shutdown_rx_tick.borrow() { break; }
}
}
let cs = metrics_tick.connection_state.load(Ordering::Relaxed); let cs = metrics_tick.connection_state.load(Ordering::Relaxed);
let sent = metrics_tick.bytes_sent.load(Ordering::Relaxed); let sent = metrics_tick.bytes_sent.load(Ordering::Relaxed);
let recv = metrics_tick.bytes_recv.load(Ordering::Relaxed); let recv = metrics_tick.bytes_recv.load(Ordering::Relaxed);
@ -221,6 +236,95 @@ async fn run_server(expected_token: String) -> Result<()> {
send_msg(HelperMsg::Status { value: 1 }); send_msg(HelperMsg::Status { value: 1 });
} }
GuiCmd::Reload { config, token } => {
if token != expected_token {
send_msg(HelperMsg::Error { message: "Invalid authorization token".to_string() });
continue;
}
log_to_file("Received RELOAD command");
// Signal shutdown to current core
{
let mut st = state.lock().await;
if let Some(tx) = st.shutdown_tx.take() {
let _ = tx.send(true);
}
tokio::time::sleep(Duration::from_millis(500)).await; // give it time to shutdown cleanly
}
let cfg: ostp_client::config::ClientConfig = match serde_json::from_str(&config) {
Ok(c) => c,
Err(e) => {
send_msg(HelperMsg::Error { message: format!("Config parse error during reload: {}", e) });
continue;
}
};
let metrics = Arc::new(ostp_client::bridge::BridgeMetrics {
bytes_sent: portable_atomic::AtomicU64::new(0),
bytes_recv: portable_atomic::AtomicU64::new(0),
connection_state: portable_atomic::AtomicU8::new(0),
rtt_ms: portable_atomic::AtomicU32::new(0),
});
let (shutdown_tx, shutdown_rx) = watch::channel(false);
{
let mut st = state.lock().await;
st.shutdown_tx = Some(shutdown_tx);
st.metrics = Some(metrics.clone());
}
let metrics_for_runner = metrics.clone();
let writer_for_err = writer.clone();
let shutdown_rx_for_core = shutdown_rx.clone();
tokio::spawn(async move {
log_to_file("Restarting tunnel core for reload...");
match ostp_client::runner::run_client_core(cfg, metrics_for_runner, shutdown_rx_for_core).await {
Ok(_) => { log_to_file("Reloaded core stopped normally"); }
Err(e) => {
let json = serde_json::to_string(&HelperMsg::Error { message: e.to_string() }).unwrap_or_default();
let mut w = writer_for_err.lock().await;
let _ = w.write_all(format!("{}\n", json).as_bytes()).await;
}
}
});
// Status tick loop is already running and using old metrics?
// Wait! We re-created metrics, so the old tick loop will continue reporting old metrics (which are disconnected)!
// We should probably share the tick loop or spawn a new one and let the old one die.
// It's easier if `metrics` in state is a generic watcher, but since we re-spawned it:
let writer_tick = writer.clone();
let metrics_tick = metrics.clone();
let mut shutdown_rx_tick = shutdown_rx.clone();
tokio::spawn(async move {
let mut last_state = 99u8;
loop {
tokio::select! {
_ = tokio::time::sleep(Duration::from_secs(1)) => {}
_ = shutdown_rx_tick.changed() => {
if *shutdown_rx_tick.borrow() { break; }
}
}
let cs = metrics_tick.connection_state.load(Ordering::Relaxed);
let sent = metrics_tick.bytes_sent.load(Ordering::Relaxed);
let recv = metrics_tick.bytes_recv.load(Ordering::Relaxed);
let rtt = metrics_tick.rtt_ms.load(Ordering::Relaxed);
let mut w = writer_tick.lock().await;
if cs != last_state {
last_state = cs;
let json = serde_json::to_string(&HelperMsg::Status { value: cs }).unwrap_or_default();
if w.write_all(format!("{}\n", json).as_bytes()).await.is_err() { break; }
}
let json = serde_json::to_string(&HelperMsg::Metrics { bytes_sent: sent, bytes_recv: recv, rtt_ms: rtt }).unwrap_or_default();
if w.write_all(format!("{}\n", json).as_bytes()).await.is_err() { break; }
drop(w);
}
});
send_msg(HelperMsg::Status { value: 1 });
}
GuiCmd::Stop { token } => { GuiCmd::Stop { token } => {
if token != expected_token { if token != expected_token {
log_to_file("Received STOP command with invalid token"); log_to_file("Received STOP command with invalid token");

658
refactor.py Normal file
View File

@ -0,0 +1,658 @@
import sys
import re
with open("d:/ospab-projects/ostp/ostp-client/src/bridge.rs", "r", encoding="utf-8") as f:
code = f.read()
start_idx = code.find(" pub async fn run(")
end_idx = -1
brace_count = 0
in_run = False
for i in range(start_idx, len(code)):
if code[i] == '{':
in_run = True
brace_count += 1
elif code[i] == '}':
if in_run:
brace_count -= 1
if brace_count == 0:
end_idx = i + 1
break
prefix = code[:start_idx]
suffix = code[end_idx:]
# Define the new run function and helpers
new_run_and_helpers = """
pub async fn run(
mut self,
tx: mpsc::Sender<UiEvent>,
mut bridge_rx: mpsc::Receiver<BridgeCommand>,
mut shutdown: watch::Receiver<bool>,
mut proxy_rx: mpsc::Receiver<ProxyEvent>,
proxy_tx: mpsc::UnboundedSender<(u16, ProxyToClientMsg)>,
) -> Result<()> {
let mut metrics_tick = interval(Duration::from_millis(500));
let mut keepalive_tick = tokio::time::interval(Duration::from_secs(self.keepalive_interval_sec.max(1)));
let mut retransmit_tick = tokio::time::interval(Duration::from_millis(10));
let init_msg = if self.mode == "tun" {
"Bridge initialized (TUN mode)".to_string()
} else {
"Bridge initialized (proxy mode)".to_string()
};
tx.send(UiEvent::Log(init_msg)).await.ok();
let mut sessions_opt: Option<Vec<SessionState>> = None;
let mut udp_rx_opt: Option<mpsc::Receiver<(usize, Bytes)>> = None;
let mut proxy_guard: Option<crate::sysproxy::SystemProxyGuard> = None;
let mut stream_map: std::collections::HashMap<u16, usize> = std::collections::HashMap::new();
loop {
tokio::select! {
biased;
_ = shutdown.changed() => {
if *shutdown.borrow() {
self.running = false;
self.metrics.connection_state.store(0, Ordering::Relaxed);
proxy_guard = None;
sessions_opt = None;
udp_rx_opt = None;
stream_map.clear();
self.reset_proxy_streams(&tx, &proxy_tx, "manual stop");
break;
}
}
udp_msg = async {
match udp_rx_opt.as_mut() {
Some(rx) => rx.recv().await,
None => std::future::pending().await,
}
}, if self.running => {
self.handle_inbound_udp(udp_msg, &mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx).await;
}
cmd = bridge_rx.recv() => {
if !self.handle_bridge_cmd(cmd, &mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx).await {
break;
}
}
_ = metrics_tick.tick() => {
if self.running {
self.emit_metrics(&tx).await;
}
}
_ = keepalive_tick.tick() => {
if self.running {
self.handle_keepalive(&mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx, &mut proxy_rx).await;
}
}
_ = retransmit_tick.tick() => {
if self.running {
self.handle_retransmit(&mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx).await;
}
}
proxy_ev = proxy_rx.recv(), if self.running && sessions_opt.as_ref().map(|s| {
s.iter().any(|ses| ses.machine.in_flight_count() < ses.machine.cwnd_packets().clamp(16, 16384))
}).unwrap_or(true) => {
self.handle_proxy_event(proxy_ev, &mut sessions_opt, &mut stream_map, &tx, &proxy_tx).await;
}
}
}
tx.send(UiEvent::Log("Bridge stopped".to_string())).await.ok();
Ok(())
}
async fn handle_inbound_udp(
&mut self,
udp_msg: Option<(usize, Bytes)>,
sessions_opt: &mut Option<Vec<SessionState>>,
udp_rx_opt: &mut Option<mpsc::Receiver<(usize, Bytes)>>,
proxy_guard: &mut Option<crate::sysproxy::SystemProxyGuard>,
stream_map: &mut std::collections::HashMap<u16, usize>,
tx: &mpsc::Sender<UiEvent>,
proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>,
) {
match udp_msg {
Some((session_index, inbound)) => {
self.metrics.bytes_recv.fetch_add(inbound.len() as u64, Ordering::Relaxed);
self.last_valid_recv = Instant::now();
if let Some(sessions) = sessions_opt.as_mut() {
if session_index < sessions.len() {
let session = &mut sessions[session_index];
let initial_action = match session.machine.on_event(OstpEvent::Inbound(inbound)) {
Ok(a) => a,
Err(e) => {
let _ = tx.send(UiEvent::Log(format!("Protocol decrypt error: {e}"))).await;
tracing::warn!("Inbound protocol error (session {}): {}", session_index, e);
return;
}
};
let mut actions_queue = std::collections::VecDeque::new();
actions_queue.push_back(initial_action);
while let Some(current_action) = actions_queue.pop_front() {
match current_action {
ProtocolAction::Multiple(nested) => {
for a in nested {
actions_queue.push_back(a);
}
}
ProtocolAction::DeliverApp(stream_id, dec_payload) => {
match RelayMessage::decode(&dec_payload) {
Ok(relay_msg) => {
match relay_msg {
RelayMessage::ConnectOk => {
let _ = tx.send(UiEvent::Log(format!("Relay CONNECT OK stream_id={stream_id}"))).await;
let _ = proxy_tx.send((stream_id, ProxyToClientMsg::ConnectOk));
}
RelayMessage::Data(data) => {
let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Data(Bytes::from(data))));
}
RelayMessage::Close => {
let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Close));
}
RelayMessage::Error(msg) => {
let _ = tx.send(UiEvent::Log(format!("Relay error for stream {stream_id}: {msg}"))).await;
let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error(msg)));
}
RelayMessage::Pong(ts) => {
let now = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64;
self.last_rtt_ms = now.saturating_sub(ts) as f64;
self.metrics.rtt_ms.store(self.last_rtt_ms as u32, Ordering::Relaxed);
}
RelayMessage::UdpAssociate => {}
RelayMessage::UdpData(target, data) => {
let _ = proxy_tx.send((stream_id, ProxyToClientMsg::UdpData(target, Bytes::from(data))));
}
RelayMessage::KeepAlive | RelayMessage::Ping(_) | RelayMessage::Connect(_) => {}
}
}
Err(err) => {
let _ = tx.send(UiEvent::Log(format!("Relay decode error for stream {stream_id}: {err}"))).await;
let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("relay decode failed".to_string())));
}
}
}
ProtocolAction::SendDatagram(frame) => {
let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await;
self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed);
}
_ => {}
}
}
}
}
}
None => {
let _ = tx.send(UiEvent::Log("UDP channel closed, resetting connection".to_string())).await;
self.running = false;
crate::sysproxy::disable_system_proxy();
*sessions_opt = None;
*udp_rx_opt = None;
stream_map.clear();
self.reset_proxy_streams(&tx, &proxy_tx, "udp reader closed");
let _ = tx.send(UiEvent::TunnelStopped).await;
}
}
}
async fn handle_bridge_cmd(
&mut self,
cmd: Option<BridgeCommand>,
sessions_opt: &mut Option<Vec<SessionState>>,
udp_rx_opt: &mut Option<mpsc::Receiver<(usize, Bytes)>>,
proxy_guard: &mut Option<crate::sysproxy::SystemProxyGuard>,
stream_map: &mut std::collections::HashMap<u16, usize>,
tx: &mpsc::Sender<UiEvent>,
proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>,
) -> bool {
match cmd {
Some(BridgeCommand::ToggleTunnel) => {
if self.running {
self.running = false;
self.metrics.connection_state.store(0, Ordering::Relaxed);
*proxy_guard = None;
*sessions_opt = None;
*udp_rx_opt = None;
stream_map.clear();
self.reset_proxy_streams(&tx, &proxy_tx, "manual stop");
tx.send(UiEvent::TunnelStopped).await.ok();
let stop_msg = if self.mode == "tun" { "TUN tunnel stopped" } else { "Bridge stopped" };
tx.send(UiEvent::Log(stop_msg.to_string())).await.ok();
} else {
tx.send(UiEvent::Log("Connecting to remote server...".to_string())).await.ok();
tx.send(UiEvent::Metrics { status: ConnectionStatus::Handshaking, rtt_ms: 0.0, throughput_bps: 0 }).await.ok();
self.metrics.connection_state.store(1, Ordering::Relaxed);
let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 };
let (udp_tx, udp_rx) = mpsc::channel(100000);
let mut sessions = Vec::with_capacity(session_count);
let mut rtt_sum = 0.0;
let mut successful_sessions = 0;
for idx in 0..session_count {
let session_id: u32 = rand::thread_rng().gen();
match self.perform_handshake_with_id(&tx, session_id).await {
Ok((sock, mach, rtt)) => {
let session_index = sessions.len();
let socket_clone = sock.clone();
let udp_tx_clone = udp_tx.clone();
tokio::spawn(async move {
let mut buf = vec![0_u8; 65535];
loop {
match socket_clone.recv(&mut buf).await {
Ok(n) => {
let inbound = Bytes::copy_from_slice(&buf[..n]);
if udp_tx_clone.send((session_index, inbound)).await.is_err() {
break;
}
}
Err(e) => {
tracing::warn!("UDP socket recv error (session {}): {}", session_index, e);
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
}
}
}
});
sessions.push(SessionState { socket: sock, machine: mach });
rtt_sum += rtt;
successful_sessions += 1;
}
Err(err) => {
tx.send(UiEvent::Log(format!("Multiplex session {}/{} handshake failed: {}. Continuing with remaining sessions...", idx + 1, session_count, err))).await.ok();
}
}
}
if sessions.is_empty() {
*proxy_guard = None;
tx.send(UiEvent::Log("All multiplexed handshake attempts failed. Connection aborted.".to_string())).await.ok();
tx.send(UiEvent::TunnelStopped).await.ok();
self.metrics.connection_state.store(0, Ordering::Relaxed);
return True;
}
*udp_rx_opt = Some(udp_rx);
*sessions_opt = Some(sessions);
self.last_rtt_ms = rtt_sum / successful_sessions as f64;
self.running = true;
self.last_sample_at = Instant::now();
self.last_valid_recv = Instant::now();
let sys_proxy_addr = self.proxy_addr.replace("0.0.0.0:", "127.0.0.1:");
*proxy_guard = Some(crate::sysproxy::SystemProxyGuard::enable(&sys_proxy_addr));
tx.send(UiEvent::Metrics {
status: ConnectionStatus::Established,
rtt_ms: self.last_rtt_ms,
throughput_bps: 0,
}).await.ok();
self.metrics.connection_state.store(2, Ordering::Relaxed);
let start_msg = if self.mode == "tun" { "TUN tunnel established" } else { "Connection established" };
tx.send(UiEvent::Log(start_msg.to_string())).await.ok();
for session in sessions_opt.as_mut().unwrap().iter_mut() {
let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64;
let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode());
if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ping_payload)) {
let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp").await;
self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed);
}
}
}
}
Some(BridgeCommand::NextProfile) => {
self.profile = next_profile(self.profile);
tx.send(UiEvent::ProfileChanged(self.profile)).await.ok();
tx.send(UiEvent::Log(format!("Obfuscation profile switched to {:?}", self.profile))).await.ok();
}
Some(BridgeCommand::NetworkChanged) => {
if self.running {
let _ = tx.send(UiEvent::Log("Network changed — starting immediate reconnect".to_string())).await;
self.metrics.connection_state.store(1, Ordering::Relaxed);
self.last_valid_recv = Instant::now() - Duration::from_secs(100);
let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 };
let (udp_tx, udp_rx) = mpsc::channel(100000);
let mut new_sessions = Vec::with_capacity(session_count);
let mut successful_sessions = 0;
let mut rtt_sum = 0.0;
for idx in 0..session_count {
let session_id: u32 = rand::thread_rng().gen();
match self.perform_handshake_with_id(&tx, session_id).await {
Ok((sock, mach, rtt)) => {
let session_index = new_sessions.len();
let socket_clone = sock.clone();
let udp_tx_clone = udp_tx.clone();
tokio::spawn(async move {
let mut buf = vec![0_u8; 65535];
loop {
match socket_clone.recv(&mut buf).await {
Ok(n) => {
let inbound = Bytes::copy_from_slice(&buf[..n]);
if udp_tx_clone.send((session_index, inbound)).await.is_err() { break; }
}
Err(e) => {
tracing::warn!("UDP recv error (network-change session {}): {}", session_index, e);
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
}
}
}
});
new_sessions.push(SessionState { socket: sock, machine: mach });
rtt_sum += rtt;
successful_sessions += 1;
}
Err(err) => {
let _ = tx.send(UiEvent::Log(format!("NetworkChanged reconnect session {}/{} failed: {}", idx + 1, session_count, err))).await;
}
}
}
if !new_sessions.is_empty() {
*sessions_opt = Some(new_sessions);
*udp_rx_opt = Some(udp_rx);
self.last_rtt_ms = rtt_sum / successful_sessions as f64;
self.last_valid_recv = Instant::now();
stream_map.clear();
self.reset_proxy_streams(&tx, &proxy_tx, "network changed");
self.metrics.connection_state.store(2, Ordering::Relaxed);
let _ = tx.send(UiEvent::Log("NetworkChanged reconnect successful!".to_string())).await;
} else {
let _ = tx.send(UiEvent::Log("NetworkChanged reconnect failed — will retry on keepalive tick".to_string())).await;
}
}
}
Some(BridgeCommand::ReloadConfig) => {
match ClientConfig::reload_from_json_near_binary() {
Ok(cfg) => {
self.apply_runtime_config(&cfg);
tx.send(UiEvent::Log("Runtime config reloaded".to_string())).await.ok();
if self.running {
self.running = false;
self.metrics.connection_state.store(0, Ordering::Relaxed);
*proxy_guard = None;
*sessions_opt = None;
stream_map.clear();
self.reset_proxy_streams(&tx, &proxy_tx, "config reload");
let _ = tx.send(UiEvent::TunnelStopped).await;
}
}
Err(err) => {
let _ = tx.send(UiEvent::Log(format!("Config reload failed: {err}"))).await;
}
}
}
Some(BridgeCommand::Shutdown) | None => {
self.running = false;
*proxy_guard = None;
return False;
}
}
True
}
async fn handle_keepalive(
&mut self,
sessions_opt: &mut Option<Vec<SessionState>>,
udp_rx_opt: &mut Option<mpsc::Receiver<(usize, Bytes)>>,
proxy_guard: &mut Option<crate::sysproxy::SystemProxyGuard>,
stream_map: &mut std::collections::HashMap<u16, usize>,
tx: &mpsc::Sender<UiEvent>,
proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>,
proxy_rx: &mut mpsc::Receiver<ProxyEvent>,
) {
if self.last_valid_recv.elapsed().as_secs() > 25 {
let elapsed = self.last_valid_recv.elapsed().as_secs();
if elapsed > 180 {
let _ = tx.send(UiEvent::Log("Connection permanently lost (3-minute hard timeout). Stopping tunnel.".into())).await;
self.running = false;
*proxy_guard = None;
*sessions_opt = None;
stream_map.clear();
self.reset_proxy_streams(&tx, &proxy_tx, "keepalive hard timeout");
let _ = tx.send(UiEvent::TunnelStopped).await;
self.metrics.connection_state.store(0, Ordering::Relaxed);
return;
}
let _ = tx.send(UiEvent::Log(format!("Connection stall detected ({}s silence). Attempting background reconnect...", elapsed))).await;
self.metrics.connection_state.store(1, Ordering::Relaxed);
let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 };
let (udp_tx, udp_rx) = mpsc::channel(100000);
let mut new_sessions = Vec::with_capacity(session_count);
let mut successful_sessions = 0;
let mut rtt_sum = 0.0;
for idx in 0..session_count {
let session_id: u32 = rand::thread_rng().gen();
match self.perform_handshake_with_id(&tx, session_id).await {
Ok((sock, mach, rtt)) => {
let session_index = new_sessions.len();
let socket_clone = sock.clone();
let udp_tx_clone = udp_tx.clone();
tokio::spawn(async move {
let mut buf = vec![0_u8; 65535];
loop {
match socket_clone.recv(&mut buf).await {
Ok(n) => {
let inbound = Bytes::copy_from_slice(&buf[..n]);
if udp_tx_clone.send((session_index, inbound)).await.is_err() {
break;
}
}
Err(e) => {
tracing::warn!("UDP socket recv error (reconnect session {}): {}", session_index, e);
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
}
}
}
});
new_sessions.push(SessionState { socket: sock, machine: mach });
rtt_sum += rtt;
successful_sessions += 1;
}
Err(err) => {
let _ = tx.send(UiEvent::Log(format!("Background reconnect session {}/{} failed: {}", idx + 1, session_count, err))).await;
}
}
}
if !new_sessions.is_empty() {
*sessions_opt = Some(new_sessions);
*udp_rx_opt = Some(udp_rx);
self.last_rtt_ms = rtt_sum / successful_sessions as f64;
self.last_valid_recv = Instant::now();
self.metrics.connection_state.store(2, Ordering::Relaxed);
let _ = tx.send(UiEvent::Log("Background reconnect successful! Connection restored.".into())).await;
for session in sessions_opt.as_mut().unwrap().iter_mut() {
let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64;
let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode());
if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ping_payload)) {
let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp").await;
self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed);
}
}
stream_map.clear();
self.reset_proxy_streams(&tx, &proxy_tx, "background reconnect");
let mut flushed = 0;
while let Ok(stale) = proxy_rx.try_recv() {
if let ProxyEvent::NewStream { stream_id, .. } = stale {
let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("connection reset".into())));
}
flushed += 1;
}
if flushed > 0 {
let _ = tx.send(UiEvent::Log(format!("Flushed {} stale proxy messages to prevent UDP burst", flushed))).await;
}
} else {
let _ = tx.send(UiEvent::Log("Background reconnect failed. Will retry on next tick...".into())).await;
}
}
if let Some(sessions) = sessions_opt.as_mut() {
for session in sessions.iter_mut() {
let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64;
let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode());
if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ping_payload)) {
let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await;
self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed);
}
let ka_payload = Bytes::from(RelayMessage::KeepAlive.encode());
if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ka_payload)) {
let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await;
self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed);
}
}
}
}
async fn handle_retransmit(
&mut self,
sessions_opt: &mut Option<Vec<SessionState>>,
udp_rx_opt: &mut Option<mpsc::Receiver<(usize, Bytes)>>,
proxy_guard: &mut Option<crate::sysproxy::SystemProxyGuard>,
stream_map: &mut std::collections::HashMap<u16, usize>,
tx: &mpsc::Sender<UiEvent>,
proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>,
) {
let mut fatal_err = None;
if let Some(sessions) = sessions_opt.as_mut() {
for session in sessions.iter_mut() {
match session.machine.on_event(OstpEvent::Tick) {
Ok(action) => {
let mut queue = vec![action];
while let Some(current_action) = queue.pop() {
match current_action {
ProtocolAction::Multiple(nested) => {
for a in nested {
queue.push(a);
}
}
ProtocolAction::SendDatagram(frame) => {
let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await;
self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed);
}
_ => {}
}
}
}
Err(e) => {
fatal_err = Some(e);
break;
}
}
}
}
if let Some(e) = fatal_err {
let _ = tx.send(UiEvent::Log(format!("Protocol tick fatal error: {e}"))).await;
self.running = false;
*proxy_guard = None;
*sessions_opt = None;
*udp_rx_opt = None;
stream_map.clear();
self.reset_proxy_streams(&tx, &proxy_tx, "protocol fatal error");
let _ = tx.send(UiEvent::TunnelStopped).await;
self.metrics.connection_state.store(0, Ordering::Relaxed);
}
}
async fn handle_proxy_event(
&mut self,
proxy_ev: Option<ProxyEvent>,
sessions_opt: &mut Option<Vec<SessionState>>,
stream_map: &mut std::collections::HashMap<u16, usize>,
tx: &mpsc::Sender<UiEvent>,
proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>,
) {
if let Some(ev) = proxy_ev {
if let Some(sessions) = sessions_opt.as_mut() {
if sessions.is_empty() {
if let ProxyEvent::NewStream { stream_id, .. } = ev {
let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("tunnel stopped".into())));
}
return;
}
let (stream_id, relay_msg, is_close) = match ev {
ProxyEvent::NewStream { stream_id, target } => {
let _ = tx.send(UiEvent::Log(format!("Proxy CONNECT stream_id={stream_id} target={target}"))).await;
(stream_id, RelayMessage::Connect(target), false)
}
ProxyEvent::UdpAssociate { stream_id } => {
let _ = tx.send(UiEvent::Log(format!("Proxy UDP ASSOCIATE stream_id={stream_id}"))).await;
(stream_id, RelayMessage::UdpAssociate, false)
}
ProxyEvent::UdpData { stream_id, target, payload } => {
(stream_id, RelayMessage::UdpData(target, payload.to_vec()), false)
}
ProxyEvent::Data { stream_id, payload } => (stream_id, RelayMessage::Data(payload.to_vec()), false),
ProxyEvent::Close { stream_id } => {
let _ = tx.send(UiEvent::Log(format!("Proxy CLOSE stream_id={stream_id}"))).await;
(stream_id, RelayMessage::Close, true)
}
};
let len = sessions.len();
let session_index = *stream_map.entry(stream_id).or_insert_with(|| {
rand::thread_rng().gen_range(0..len)
});
if is_close {
stream_map.remove(&stream_id);
}
let session = &mut sessions[session_index];
let out_payload = Bytes::from(relay_msg.encode());
match session.machine.on_event(OstpEvent::Outbound(stream_id, out_payload)) {
Ok(ProtocolAction::SendDatagram(frame)) => {
if send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await.is_ok() {
self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed);
tracing::trace!("Outbound datagram sent stream_id={stream_id} bytes={}", frame.len());
}
}
Ok(ProtocolAction::Multiple(list)) => {
let mut sent = 0usize;
for item in list {
if let ProtocolAction::SendDatagram(frame) = item {
if send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await.is_ok() {
self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed);
sent += 1;
}
}
}
tracing::trace!("Outbound datagram batch stream_id={stream_id} sent={sent}");
}
Ok(ProtocolAction::Noop) => {
tracing::trace!("Outbound datagram noop stream_id={stream_id}");
}
Ok(_) => {
tracing::trace!("Outbound datagram unexpected action stream_id={stream_id}");
}
Err(e) => {
tracing::warn!("Protocol error packing outbound stream_id={}: {}", stream_id, e);
let _ = tx.send(UiEvent::Log(format!("Protocol error packing TCP: {e}"))).await;
}
}
} else {
if let ProxyEvent::NewStream { stream_id, .. } = ev {
let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("tunnel stopped".into())));
}
}
}
}
"""
with open("d:/ospab-projects/ostp/ostp-client/src/bridge.rs", "w", encoding="utf-8") as f:
f.write(prefix + new_run_and_helpers + suffix)
print("Done")